[HAL/AMDGPU] Initial host-side AMDGPU HAL implementation (#24298)

This PR lands IREE's native AMDGPU HAL driver: a direct HSA/ROCR backend
that owns queue submission, packet construction, memory placement,
command-buffer recording/replay, profiling, counters, device-library
selection, and future scheduling policy inside IREE instead of routing
normal execution through HIP. The cost is ~70kLoC, but that buys IREE
direct ownership of AMD GPU execution in place of HIP streams and HIP
graphs.

The critical unlocks come from the fact that IREE already knows the real
program structure HIP has to guess at: explicit semaphore frontiers,
queue affinity, memory types, binding tables, reusable command-buffer
blocks, executable metadata, profiling scopes, and replay captures. The
native driver turns that structure directly into AQL packets and
queue-local completion state, which lets us do things HIP cannot
naturally express: low-overhead dynamic command buffers, heterogeneous
HAL device groups, future remote execution, device-side
fixup/scheduling, and profiling/replay from the same command model.

The early numbers show the shape of the win: ~12.5x lower submit
overhead for cross-queue dependency edges, ~22x lower dynamic graph
construction tax versus HIP graphs on a 512-dispatch chain, and ~20x
lower steady-state host CPU time on queue-heavy submission paths. This
is v0, but it is already the architecture we want to optimize: fewer
compatibility layers, more explicit contracts, and a path where AMD GPUs
participate in the full HAL ecosystem instead of living behind a
HIP-shaped abstraction boundary.

This is intentionally a large PR. The driver is not a thin shim around
one runtime call; it is the runtime boundary for AMD GPUs. The branch
contains the native driver plus the AMDGPU-specific hardening that made
the final shape reviewable: command-buffer replay cleanup, queue/pool
integration, profiling producers, target-library selection, device
capability handling, tests, and developer documentation.

The headlines:

- IREE now has a native AMDGPU execution path based on HSA queues and
AQL packets.
- The driver can run normal HAL dispatches and reusable HAL command
buffers without HIP streams or HIP graphs.
- The command-buffer representation is designed as a durable block
program that can be replayed by host processors now and device-side
processors later.
- The profiling path can expose queue, dispatch, executable, counter,
device metric, and ATT/SQTT trace data through the HAL profile tooling.
- The hot paths are structured so static production replay does not pay
for optional profiling, trace, upload, or future device-fixup machinery.

## Why

HIP is a useful compatibility layer and comparison point, but it is not
the right abstraction boundary for the runtime work IREE wants to do.

IREE needs to be able to control:

- how HAL queue operations become AQL/PM4 packets;
- where kernargs, command-buffer templates, transient buffers, and
staging records live;
- how semaphore dependencies map to queue frontiers and completion
epochs;
- how reusable command buffers are recorded, validated, replayed, and
profiled;
- where host work ends and queue-ordered device work begins;
- how to capture profiling data without turning the production queue
into a debug path; and
- how to evolve toward device-side command-buffer scheduling and fixup.

HIP graphs are especially awkward for IREE's dynamic command-buffer use
case. They can be expensive to construct, hard to introspect, and
difficult to shape around IREE's own async allocation and replay
contracts. The native driver gives IREE a graph-like reusable command
stream while keeping the command stream in IREE's own ABI.

## Design Principles

The implementation follows a few constraints that are worth making
explicit for review.

**Own the production hot path.** Queue submission, command-buffer
replay, kernarg formation, packet publication, and completion are
explicit IREE code. Optional features are allowed only when they do not
tax the default path. For example, profiling, ATT/SQTT capture,
queue-control upload rings, and future device-side fixup all have opt-in
storage and control flow.

**Record facts once.** Command buffers are allowed to do work while
recording and finalizing so replay can be simple. Binding counts, patch
counts, packet counts, barrier requirements, prepublication eligibility,
rodata references, and block terminators are recorded in the
command-buffer program instead of rediscovered by
scanning command records during submission.

**Keep host and device processors pointed at the same ABI.** The AMDGPU
command buffer is a block program, not a host-only replay script. The
current host AQL block processor consumes that program; future
device-side processors should consume the same block format for
command-buffer continuations, scheduling, and kernarg fixup.

**Separate invariant clusters.** The driver is split by subsystem rather
than growing one giant queue file. There are distinct files for queue
submission, queue waits, command-buffer block processing, command-buffer
replay, profiling augmentation, staging/file paths, memory operations,
executable handling, topology, device capabilities, and utility rings.

**Fail loud on unsupported strategies.** Unsupported memory paths,
command forms, profiling modes, and device capabilities should fail with
a concrete status instead of silently falling back through the wrong
mechanism.

**Make platform/device variation explicit.** The code names the places
where HSA memory-pool access, HDP publication, topology links, target
IDs, device-library coverage, Linux KFD metrics, and optional ROCm
profiling libraries affect behavior.

## Architecture Overview

### Driver And Device Model

The driver dynamically loads HSA/ROCR, discovers CPU and GPU agents, and
creates logical HAL devices over one or more physical AMDGPU agents.

The main object split is:

- driver: HSA discovery, option parsing, and logical-device creation;
- logical device: HAL-facing device object and shared runtime state;
- physical device: one HSA GPU agent with queues, memory pools,
executable cache, device-library selection, profiling state, device
metrics, and topology facts;
- host queue: HSA queue plus IREE's AQL, kernarg, notification,
completion, and reclaim state; and
- virtual queue: the internal interface used so command-buffer, direct
dispatch, memory, file, and profiling paths route through one queue
contract (sketched below).
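
As a rough sketch of that contract (all names here are hypothetical
stand-ins, not the shipped headers), the virtual queue amounts to a small
vtable that every submission-producing path targets:

```c
#include <stdint.h>

// Hypothetical stand-ins for the real HAL types; illustration only.
typedef int vq_status_t;
typedef struct vq_command_buffer_t vq_command_buffer_t;
typedef struct vq_binding_table_t vq_binding_table_t;

// One queue contract that command-buffer replay, direct dispatch, memory,
// file, and profiling paths all route through; the concrete host queue
// implements these entry points over an HSA AQL queue.
typedef struct vq_vtable_t {
  vq_status_t (*execute)(void* self, vq_command_buffer_t* command_buffer,
                         const vq_binding_table_t* binding_table);
  vq_status_t (*dispatch)(void* self, const void* kernel_object,
                          const uint32_t workgroup_count[3]);
  vq_status_t (*copy)(void* self, void* source, void* target,
                      uint64_t length);
  vq_status_t (*flush_profile_data)(void* self);
} vq_vtable_t;
```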

Device selection defaults to all visible AMDGPU agents and also supports
single-device, UUID-based, and ordinal selection, as well as
multi-device logical devices. The topology code records HSA memory-pool
access, link class, NUMA distance, coherency, atomics, and interop
capability facts so future placement and transfer strategies can reason
about PCIe, xGMI, and other link types without hard-coded assumptions.

### Executables And Device Libraries

AMDGPU executables are loaded from HSACO/code-object data and matched
against the selected physical device. The runtime also embeds AMDGPU
device libraries used for builtin operations such as fill/copy helpers,
timestamp helpers, and dispatch-side utilities.

The device-library target map is single-sourced from generated target
metadata. Builds can select exact targets, LLVM generic targets,
TheRock-style generic families, or product bundles. This keeps package
size and device coverage under explicit build-system control while
letting the runtime fail clearly when a required target was not
embedded.
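
For illustration, the runtime side of that lookup reduces to something
like the following (entry values mirror the generated target map; the
table and function names are hypothetical):

```c
#include <stddef.h>
#include <string.h>

// Hypothetical mirror of the generated target map: each exact gfx target
// maps to the embedded code-object target that covers it.
typedef struct {
  const char* exact_target;        // e.g. "gfx1100"
  const char* code_object_target;  // e.g. "gfx11-generic"
} target_map_entry_t;

static const target_map_entry_t k_target_map[] = {
    {"gfx90a", "gfx90a"},
    {"gfx1100", "gfx11-generic"},
    // ... generated by build_tools/scripts/amdgpu_target_map.py ...
};

// Returns the embedded code-object target for |isa|, or NULL so the
// caller can fail with a concrete "target not embedded" status instead
// of silently falling back to the wrong code object.
static const char* lookup_code_object_target(const char* isa) {
  for (size_t i = 0; i < sizeof(k_target_map) / sizeof(k_target_map[0]);
       ++i) {
    if (strcmp(k_target_map[i].exact_target, isa) == 0) {
      return k_target_map[i].code_object_target;
    }
  }
  return NULL;
}
```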

### Memory, Pools, And Publication

The driver integrates with the HAL pool substrate and AMDGPU HSA memory
pools instead of treating all buffers as generic allocations.

The implementation distinguishes:

- device-local memory;
- CPU-visible fine-grained host memory;
- CPU-visible coarse-grained device memory;
- queue-owned kernarg memory;
- optional queue-control upload memory;
- transient allocation pools;
- file/staging storage; and
- host-side block/slab pools used by queue and profiling data
structures.

HDP publication is represented as a selected capability of the memory
path, not as an ad hoc flush sprinkled through dispatch code. If CPU
writes to memory that the GPU will consume require publication on a
device, the queue-owned memory path knows how to publish those writes
before the relevant packet headers become
visible.
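
A minimal sketch of that ordering contract, assuming a per-agent HDP
flush register of the kind ROCr exposes (on devices that do not require
publication the pointer is simply absent):

```c
#include <stdatomic.h>
#include <stdint.h>

// Sketch: make CPU-written packet payload visible to the GPU before the
// packet header does. |hdp_mem_flush_cntl| stands in for the HDP flush
// register; when the device needs no HDP publication it is NULL and the
// flush is skipped.
static void publish_writes_then_header(volatile uint32_t* hdp_mem_flush_cntl,
                                       _Atomic uint16_t* packet_header,
                                       uint16_t header_value) {
  if (hdp_mem_flush_cntl) {
    *hdp_mem_flush_cntl = 1u;  // push host writes through the HDP
  }
  // Release store: everything written above is ordered before the header
  // transitions from INVALID to a dispatchable packet type.
  atomic_store_explicit(packet_header, header_value, memory_order_release);
}
```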

The default queue-control upload ring is disabled until a production
consumer opts in. That keeps the future device-side fixup path available
without charging every queue an unused HSA allocation.

### Queue Submission And Completion

Host queues own an HSA AQL queue and maintain:

- an AQL ring view for packet reservation/publication;
- a kernarg ring for queue-owned dispatch arguments;
- an epoch/notification ring mapping GPU completions to HAL semaphore
signals;
- a queue frontier snapshot for dependency tracking;
- one completion thread that drains queue epochs and publishes
user-visible semaphore completions;
- optional PM4 IB slots indexed by AQL packet id on hardware that
supports AQL PM4 packets; and
- optional profiling/counter/trace state.

Submission is serialized per queue, but independent queues do not
synchronize with each other. The queue submission path reserves AQL
packets, kernargs, and notification entries before publishing headers.
If admission fails, reclaim is routed through the same
notification/reclaim machinery instead of inventing a
parallel cleanup path.
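
Stripped of admission control, kernarg/notification reservation, and
failure reclaim, the publish step follows the standard HSA AQL pattern;
a simplified sketch:

```c
#include <hsa/hsa.h>
#include <stdatomic.h>
#include <string.h>

// Sketch of reserve-then-publish: until the header word leaves
// HSA_PACKET_TYPE_INVALID the packet processor must ignore the slot, so
// the 16-bit header is stored last with release ordering.
static void publish_one_packet(hsa_queue_t* queue,
                               const hsa_kernel_dispatch_packet_t* packet) {
  uint64_t index = hsa_queue_add_write_index_relaxed(queue, 1);
  // Real code checks |index| against the read index for queue-full
  // admission and routes failures through the reclaim machinery.
  hsa_kernel_dispatch_packet_t* slot =
      (hsa_kernel_dispatch_packet_t*)queue->base_address +
      (index & (queue->size - 1));
  // Copy the body first, leaving the header for last.
  memcpy((uint8_t*)slot + sizeof(uint16_t),
         (const uint8_t*)packet + sizeof(uint16_t),
         sizeof(*packet) - sizeof(uint16_t));
  atomic_store_explicit((_Atomic uint16_t*)&slot->header, packet->header,
                        memory_order_release);
  // Ring the doorbell so the packet processor observes the new index.
  hsa_signal_store_screlease(queue->doorbell_signal,
                             (hsa_signal_value_t)index);
}
```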

HAL ordering is represented by semaphore/frontier dependencies, not by
assuming FIFO execution. The queue frontier machinery lets the driver
elide redundant waits when the dependency is already known to be
satisfied, while preserving correctness when the frontier overflows or
cannot prove elision.
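
A toy sketch of the elision test (real frontier storage, overflow
handling, and epoch mapping are more involved; names are illustrative):

```c
#include <stdbool.h>
#include <stdint.h>

enum { MAX_TRACKED_QUEUES = 8 };  // illustrative fixed tracking window

// Hypothetical frontier snapshot: the highest completion epoch of each
// producer queue that this queue is already ordered after.
typedef struct {
  uint64_t satisfied_epoch[MAX_TRACKED_QUEUES];
} queue_frontier_t;

// Returns true when a wait on (producer_queue, epoch) is provably
// redundant. Unknown producers or frontier overflow return false so the
// real wait is still emitted and correctness is preserved.
static bool frontier_can_elide_wait(const queue_frontier_t* frontier,
                                    uint32_t producer_queue, uint64_t epoch) {
  if (producer_queue >= MAX_TRACKED_QUEUES) return false;
  return frontier->satisfied_epoch[producer_queue] >= epoch;
}
```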

### Direct Dispatch And Builtin Operations

Direct `queue_dispatch` resolves executable metadata, validates dispatch
shape, forms kernargs, retains the executable/buffer resources required
by the submission, and emits AQL packets through the common queue
submission path.

Queue buffer operations are implemented through explicit strategies.
Builtin device kernels cover fill/copy/update paths and are selected
based on alignment, size, and available device-library kernels. The code
leaves room for SDMA, PM4, P2P, and future direct-storage strategies
without conflating those with the current kernel-dispatch path.
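
As a sketch of what "explicit strategies" means in practice (the enum
and thresholds here are invented for illustration):

```c
#include <stdint.h>

// Illustrative strategy enum; the driver's real selection covers
// fill/copy/update and reserves room for SDMA/PM4/P2P paths.
typedef enum {
  FILL_VIA_WIDE_KERNEL,  // aligned bulk fill using a builtin device kernel
  FILL_VIA_BYTE_KERNEL,  // small or unaligned fallback kernel
} fill_strategy_t;

// Explicit selection based on alignment and size: each path is chosen
// deliberately rather than funneling everything through one generic
// mechanism that silently degrades.
static fill_strategy_t select_fill_strategy(uint64_t offset,
                                            uint64_t length) {
  const int aligned16 = ((offset | length) & 15u) == 0;
  return (aligned16 && length >= 256) ? FILL_VIA_WIDE_KERNEL
                                      : FILL_VIA_BYTE_KERNEL;
}
```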

### Command Buffers

The AMDGPU command-buffer ABI is the center of the rewrite.

Recorded command buffers are stored as a program of blocks. Each block
has a fixed header with command counts, binding-source counts,
packet/kernarg worst case, rodata extent, dispatch/profile-marker
counts, barrier metadata, and a terminator. Commands include barriers,
dispatches, fills, copies, updates,
profile markers, branches, conditional branches, and returns.
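
A hypothetical rendering of those recorded per-block facts, to make the
"record facts once" principle concrete (field names and widths are
illustrative, not the shipped ABI):

```c
#include <stdint.h>

// Facts recorded at finalization so replay never rescans command
// records; the block program carries these alongside the commands.
typedef struct {
  uint32_t command_count;         // commands encoded in this block
  uint32_t binding_source_count;  // dynamic binding sources referenced
  uint32_t max_packet_count;      // worst-case AQL packets to reserve
  uint32_t max_kernarg_size;      // worst-case kernarg bytes to reserve
  uint32_t rodata_offset;         // block-owned rodata extent
  uint32_t rodata_length;
  uint16_t dispatch_count;        // for profiling sidecar sizing
  uint16_t profile_marker_count;
  uint16_t barrier_flags;         // barrier requirements
  uint16_t terminator;            // branch / conditional branch / return
} cmd_block_header_t;
```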

The important split is:

- the command buffer owns the durable block program and rodata;
- the AQL block processor consumes one block and writes reserved
packet/kernarg storage;
- host queue replay is the container/orchestration layer that
initializes a processor, invokes blocks, handles continuations, and
integrates with semaphores/reclaim; and
- profiling processors are separate variants that augment replay only
when profiling was explicitly requested.

This shape is deliberate. A block processor is close to a small
interpreter over the block ABI. It is suitable for dedicated tests today
and for device-side processor variants later. Host queue code should not
need to know how every command body becomes AQL packets.

Replay hot paths are specialized:

- static reusable dispatches can use prepublished kernargs;
- all-dynamic dispatches use a direct binding-pointer scatter path;
- mixed static/dynamic reusable dispatches use immutable templates plus
recorded dynamic patch sources;
- indirect dispatch parameters stay on the generic path where required;
and
- profile-disabled replay bypasses profile sidecars and trace/counter
logic.

Dynamic binding sources retain the original `queue_execute` binding
table slot for the entire command-buffer lifetime. There is no per-block
binding remap sidecar, and no finalization scan that rewrites binding
slots. Future
device-side fixup should consume recorded patch records directly: `patch
offset + binding table slot + binding offset`.
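
A sketch of what consuming those patch records looks like, assuming a
per-submission array of resolved binding-table device addresses (struct
layout and names are illustrative):

```c
#include <stdint.h>
#include <string.h>

// One recorded dynamic patch fact: where to write in the kernarg copy,
// which queue_execute binding-table slot to read, and the offset within
// that binding. This matches the "patch offset + binding table slot +
// binding offset" invariant above.
typedef struct {
  uint32_t patch_offset;    // byte offset into the kernarg template copy
  uint32_t binding_slot;    // slot in the submission binding table
  uint32_t binding_offset;  // byte offset within the bound buffer
} kernarg_patch_t;

// Host-side replay sketch: copy the immutable template, then scatter the
// resolved device addresses. A device-side fixup kernel would consume
// the exact same records.
static void apply_patches(uint8_t* kernargs, const uint8_t* template_kernargs,
                          uint32_t kernarg_size,
                          const kernarg_patch_t* patches, uint32_t patch_count,
                          const uint64_t* binding_device_addrs) {
  memcpy(kernargs, template_kernargs, kernarg_size);
  for (uint32_t i = 0; i < patch_count; ++i) {
    uint64_t addr = binding_device_addrs[patches[i].binding_slot] +
                    patches[i].binding_offset;
    memcpy(kernargs + patches[i].patch_offset, &addr, sizeof(addr));
  }
}
```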

### Profiling, Counters, Traces, And Replay

The driver is a first-class producer for the HAL-native profiling and
replay stack.

Supported profiling/data modes include:

- host-side memory and queue events;
- device-side queue timestamps;
- per-dispatch timestamps;
- executable/export metadata;
- hardware/software counters;
- queue-range PMC sampling;
- device metrics from platform-specific sources;
- filtered ATT/SQTT executable traces through dynamically loaded ROCm
profiling libraries; and
- replay captures that can be run, benchmarked, dumped, and profiled
outside the original application.

Normal execution does not require ROCm profiling libraries. The
aqlprofile path is dynamically loaded only for modes that need counters
or executable traces. Linux-specific device-metric support is isolated
behind a platform source so the core driver remains structured for
future Windows and macOS HSA support.

## Performance Evidence

The main apples-to-apples GPU comparison uses the SDXL CLIP prompt
encoder: a real sharktank workload with 792 dispatches, 28 executables,
and enough queue traffic to exercise command-buffer replay and
host/runtime overhead.

Post-cleanup optimized non-Tracy medians:

| Shape | AMDGPU wall | HIP stream wall | AMDGPU vs stream | HIP graph wall | AMDGPU vs graph | AMDGPU host CPU |
| --- | ---: | ---: | ---: | ---: | ---: | ---: |
| c1/d1 | 10.9508 ms | 11.5456 ms | 5.15% faster | 11.6199 ms | 5.76% faster | 0.618 ms |
| c1/d16 | 0.7035 ms/item | 0.7311 ms/item | 3.78% faster | 0.7335 ms/item | 4.09% faster | 0.036 ms/item |
| c2/d16 | 0.7073 ms/item | 0.7298 ms/item | 3.08% faster | 0.7330 ms/item | 3.50% faster | 0.037 ms/item |
| c4/d16 | 0.7066 ms/item | 0.7278 ms/item | 2.92% faster | 0.7288 ms/item | 3.05% faster | 0.037 ms/item |
| c8/d16 | 0.7058 ms/item | 0.7322 ms/item | 3.60% faster | 0.7333 ms/item | 3.75% faster | 0.038 ms/item |

The broader model spread is consistent with the same story: native
AMDGPU is usually ahead of HIP stream, usually ahead of HIP graph when
HIP graph can import the workload, and uses much less host CPU on
queue-heavy paths.

Representative additional rows:

| Workload | Shape | AMDGPU | HIP stream | HIP graph | Notes |
| --- | --- | ---: | ---: | ---: | --- |
| MNIST-12 | c1/d1 | 0.0978 ms | 0.1423 ms | 0.1425 ms | Small classifier, high runtime-overhead sensitivity. |
| SqueezeNet 1.0 | c1/d1 | 1.1428 ms | 1.2043 ms | 1.1988 ms | Compact CNN. |
| toy CLIP bf16 | c1/d1 | 0.2227 ms | 0.2578 ms | 0.2597 ms | Transformer-ish toy encoder. |
| MobileNetV2-12 | c1/d1 | 1.8462 ms | 1.9316 ms | crash | Depthwise/mobile CNN; HIP graph crashes locally. |
| TinyYOLOv2-8 | c1/d1 | 7.6516 ms | 8.0490 ms | 8.5600 ms | Object detection graph. |
| ResNet50-v1-12 | c1/d1 | 9.5364 ms | 9.6900 ms | import fails | HIP graph node limit. |
| SDXL scheduled UNet | c1/d1 body | 204.36 ms | 215.19 ms | 216.43 ms | Direct `run_forward` body. |
| SDXL CLIP prompt encoder | c8/d16 | 0.692 ms | 0.721 ms | 0.725 ms | Byte-identical HSACO/no-prefetch row. |

We also compared raw C HAL command-buffer construction/replay against
raw C HIP graph construction/launch for a 512 dispatch/barrier chain,
avoiding VM overhead on both sides:

| Path | Prebuilt wall | Dynamic wall | Extra wall | Extra wall / dispatch | Extra CPU / dispatch |
| --- | ---: | ---: | ---: | ---: | ---: |
| HAL command buffer, validated | 2096.4 us | 2177.0 us | 80.5 us | 0.157 us | 0.582 us |
| HAL command buffer, unvalidated | 2096.4 us | 2143.3 us | 46.9 us | 0.092 us | 0.526 us |
| HIP graph | 2983.7 us | 4022.9 us | 1039.3 us | 2.030 us | 2.308 us |

That is the key dynamic-command-buffer result: unvalidated HAL
command-buffer recording/replay adds tens of microseconds for the
512-pair chain, while HIP graph construction adds about a millisecond in
the same harness.

Queue-stress microbenchmarks isolate the pathological submission streams
that large distributed and graph-style applications care about. The
current-head HAL rows below use the checked-in AMDGPU `queue_benchmark`
built optimized with release ThinLTO/O3/native flags, pinned to one CPU
and one local RDNA3 GPU. HIP rows use the matching HIP event ping-pong
harness on the same CPU/GPU pin. The end-to-end rows measure 512
cross-queue dependency edges plus one public host-visible completion:

| Shape | AMDGPU end-to-end / edge | HIP end-to-end / edge | Read |
| --- | ---: | ---: | --- |
| Cross-queue dependency edge | 4.58 us | 11.20 us | AMDGPU is 2.4x faster. |
| Edge + 4-byte device copy | 11.65 us | 14.62 us | AMDGPU is 1.25x faster. |
| Edge + 4-byte device fill | 10.98 us | 15.20 us | AMDGPU is 1.38x faster. |
| Edge + tiny dispatch | 10.55 us | 14.59 us | AMDGPU is 1.38x faster. |
| Edge + no-op dispatch packet | 4.56 us | n/a | AMDGPU stays near the pure dependency floor when payload work is empty. |

The pure submit-only dependency row is the sharpest host-path
comparison: AMDGPU submits a cross-queue dependency edge for about 0.42
us/edge, while HIP events cost about 5.23 us/edge in the same pinned
harness. That is about 12.5x less host-side submission overhead for the
synchronization pattern used by tensor-parallel and pipeline-parallel
programs.

This is not just an implementation-speed comparison. HIP stream events
and HIP graphs sit above a compatibility runtime that has to rediscover
intent from streams, events, graph nodes, kernel parameters, and raw
pointer arguments. IREE already has that intent in structured HAL
commands: explicit semaphore frontiers, queue affinity, binding tables,
memory types, command-buffer blocks, and executable metadata. The AMDGPU
HAL can turn those contracts directly into AQL packets and queue-local
completion state without routing every operation
through HIP's public stream/event/graph abstraction.

That structural difference is why the CPU-time story is as important as
the wall-time story. On the SDXL CLIP prompt encoder, AMDGPU runs the
steady-state batched path with roughly 0.036-0.038 ms/item of host CPU
time while HIP stream and HIP graph paths are around 0.74-0.76 ms/item.
That is a roughly 20x host CPU reduction on the queue-heavy path. On
systems with many accelerators, expensive prefill/decode scheduling, or
small CPU budgets, that difference is the difference between the CPU
being orchestration glue and the CPU becoming the
bottleneck.

The same abstraction boundary is also what lets HAL scale beyond HIP's
world model. HAL command buffers, semaphores, queue affinity, memory
files, and device groups can describe local GPUs, CPU devices, remote
devices, and heterogeneous execution without changing the program's
synchronization model. The upcoming remote HAL work can use the same
command/dependency concepts across process or machine boundaries; HIP
cannot represent that kind of heterogeneous or remote execution graph
without collapsing it back into host-side framework logic. This rewrite
puts AMDGPU on the same HAL substrate as local-task, local-sync,
profiling, replay, and future remote execution instead of treating AMD
GPUs as a HIP-shaped island.

Tracy and Perfetto captures were used as structural evidence for queue
shape, host/runtime gaps, worker behavior, dispatch timing, counter
ranges, and device metric sampling. Non-Tracy optimized runs are the
source of the wall-time numbers above.

## Portability And Hardware Coverage

The current implementation has been exercised primarily on local
RDNA3/gfx1100 Linux hardware, but the code is structured for broader
AMDGPU support.

Cross-device preparation in this PR includes:

- target ID parsing and generated target maps for exact, generic,
family, and product-bundle device-library selection;
- explicit HSA memory-pool access and link-topology modeling;
- CPU-visible device-coarse memory capability selection with HDP
publication;
- queue-owned kernarg publication policy;
- PM4 capability detection and AQL PM4 IB infrastructure where
supported;
- generic device-library target selection instead of hard-coding
gfx1100; and
- tests around target IDs, code-object target selection, topology,
memory access, device-library lookup, and PM4/AQL emitters.

Cross-platform preparation includes:

- dynamic HSA loading instead of a direct link dependency;
- platform-isolated Linux KFD/device-metric support;
- optional dynamic loading of ROCm profiling libraries;
- public HAL abstractions for profiling/replay rather than AMDGPU-only
tool hooks; and
- explicit failure for unsupported platform features.

This PR does not claim every modern RDNA/CDNA target is fully proven. It
gives us the driver architecture, target map, and capability seams
required to harden that matrix as more hardware and platform HSA stacks
become available.

## Forward-Looking Work Enabled By This Shape

Several important features are intentionally not completed in this PR,
but the landed architecture is designed around them.

**Device-side dynamic kernarg fixup.** Dynamic command buffers currently
patch queue-owned kernargs on the host. The planned production path is
to upload a small per-submission binding table/control record and
dispatch a device-side fixup kernel that copies template kernargs and
patches dynamic qwords before
payload dispatches execute. The recorded command-buffer patch records
already carry the essential facts: target patch location, original
binding-table slot, and binding offset.
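
A sketch of what that fixup kernel could look like, written in the same
freestanding C style the embedded device libraries are compiled from
(one work-item per patched qword; the template-copy step is elided and
every name is hypothetical):

```c
#include <stdint.h>

// Same record shape as the host-side sketch earlier in this PR.
typedef struct {
  uint32_t patch_offset;    // byte offset into this submission's kernargs
  uint32_t binding_slot;    // slot in the uploaded binding table
  uint32_t binding_offset;  // byte offset within the bound buffer
} kernarg_patch_t;

// Device-side fixup sketch: each work-item resolves one dynamic qword
// from the uploaded per-submission binding table before the payload
// dispatches execute. A real kernel would also copy template kernargs.
__attribute__((amdgpu_kernel)) void fixup_kernargs(
    uint8_t* kernargs, const kernarg_patch_t* patches,
    const uint64_t* binding_table, uint32_t patch_count) {
  uint32_t i = __builtin_amdgcn_workgroup_id_x() * 64u +
               __builtin_amdgcn_workitem_id_x();
  if (i >= patch_count) return;
  kernarg_patch_t p = patches[i];
  *(uint64_t*)(kernargs + p.patch_offset) =
      binding_table[p.binding_slot] + p.binding_offset;
}
```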

**Device-side command-buffer scheduling.** The block-program ABI gives
us a clean path to device-side processors. A device queue can invoke
block processors, advance command-buffer continuations, and schedule
independent blocks without forcing host queue code to understand every
command body.

**Command-buffer control flow.** The ABI already reserves branch,
conditional branch, and return terminators. Host replay currently
supports the subset needed by the landed workloads; the representation
is intentionally shaped so richer control flow can become an execution
feature rather than a new command-buffer
format.

**Binding-table-indirect dispatch ABI.** A future dispatch ABI may avoid
dynamic kernarg pointer fixup by passing an invocation-local binding
table base and loading buffer pointers indirectly in kernels. That needs
compiler/runtime experiments to measure the cost of an extra scalar load
versus raw pointer kernargs, but the current direct binding-table slot
invariant is compatible with that direction.
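
A sketch of the kernel side of that possible ABI (hypothetical; this is
the experiment described above, not a landed dispatch ABI):

```c
#include <stdint.h>

// Instead of a patched raw pointer kernarg, the kernarg segment carries
// an invocation-local binding-table base plus slot/offset immediates and
// the kernel resolves the buffer pointer itself. The extra scalar load
// here is exactly the cost the proposed experiment would measure.
__attribute__((amdgpu_kernel)) void consume_indirect(
    const uint64_t* binding_table,  // invocation-local table base
    uint32_t arg0_slot, uint32_t arg0_offset) {
  float* buffer0 =
      (float*)(binding_table[arg0_slot] + (uint64_t)arg0_offset);
  buffer0[__builtin_amdgcn_workitem_id_x()] += 1.0f;
}
```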

**PM4-backed queues and operations.** The driver now has PM4 emitters,
PM4 program utilities, capability detection, and AQL PM4 IB slots on
supported hardware. That creates room for PM4-backed waits, transfers,
profiling snippets, and potentially lower-level queue strategies where
HSA/AQL alone is not the best mechanism.

**Transfer strategy expansion.** Current transfer paths use explicit
builtin device kernels and staging strategies. The queue/file/memory
split leaves room for SDMA, P2P, direct storage, and topology-aware copy
selection without rewriting the core queue completion path.

**Broader profiling.** CDNA devices should expose richer counter options
than the initial local setup. The queue-range PMC and profile-bundle
infrastructure are meant to scale into that environment without changing
the normal execution path.

## Review Guide

Good entry points for review:

- `runtime/src/iree/hal/drivers/amdgpu/README.md`: user-facing driver
overview, build flags, runtime selection, profiling, and target-library
notes.
- `runtime/src/iree/hal/drivers/amdgpu/api.h`: public driver/device
options.
- `runtime/src/iree/hal/drivers/amdgpu/driver.c`: driver registration,
HSA loading, and device creation.
- `runtime/src/iree/hal/drivers/amdgpu/logical_device.c`: HAL device
methods, profiling/replay integration, and physical-device
orchestration.
- `runtime/src/iree/hal/drivers/amdgpu/physical_device.c`: HSA agent
setup, queue creation, memory pools, executable caches, device
libraries, profiling, and topology state.
- `runtime/src/iree/hal/drivers/amdgpu/host_queue.c`: queue ownership,
completion thread, submission state, and reclaim lifetime.
- `runtime/src/iree/hal/drivers/amdgpu/host_queue_submission.c`: common
submission admission, publication, and failure/reclaim path.
- `runtime/src/iree/hal/drivers/amdgpu/aql_command_buffer.c`:
command-buffer recording, layout, prepublication, dynamic binding
strategy, and block construction.
- `runtime/src/iree/hal/drivers/amdgpu/abi/command_buffer.h`: durable
command-buffer block ABI.
- `runtime/src/iree/hal/drivers/amdgpu/aql_block_processor.c`:
unprofiled AQL block processor.
- `runtime/src/iree/hal/drivers/amdgpu/aql_block_processor_profile.c`:
profiling-augmented block processor.
- `runtime/src/iree/hal/drivers/amdgpu/host_queue_command_buffer*.c`:
host replay orchestration, block submission, packet policy, scratch
storage, and profiling integration.
- `runtime/src/iree/hal/drivers/amdgpu/profile_*.c`: profile producers
for events, metadata, counters, device metrics, and traces.
- `runtime/src/iree/hal/drivers/amdgpu/device/*.c`: embedded device-side
helper kernels and host-side packet/kernarg formation helpers.
- `runtime/src/iree/hal/drivers/amdgpu/util/*.c`: HSA loading, target
IDs, code-object metadata, rings, signals, PM4/AQL emitters, topology,
and KFD utilities.

## Validation

Validation covered both source-level unit tests and workload-level
evidence:

- focused AMDGPU unit tests for HSA loading, target IDs, code-object
metadata, device libraries, topology, capabilities, pools, signals,
rings, emitters, executables, semaphores, allocators, command buffers,
block processors, host queue submission, staging, profiling
metadata/events, and CTS backends;
- AMDGPU HAL CTS dispatch/executable coverage;
- focused Linux Bazel ASAN builds/tests for the AMDGPU runtime targets;
- focused CMake configure/build/test coverage for AMDGPU runtime
libraries and generated CTS artifacts;
- Windows and macOS CMake validation of the shared
HAL/async/profile/replay substrate that this driver depends on;
- SDXL CLIP correctness on both visible local AMDGPU devices with the
same weights, inputs, and expected outputs used for CPU validation;
- SDXL CLIP, SDXL UNet, model-spread, command-buffer-vs-HIP-graph,
Tracy, Perfetto, device-metrics, PMC, and ATT/SQTT profiling runs; and
- pre-commit formatting/check generation hooks for the final branch.

The performance numbers in this PR are from optimized non-Tracy runs on
my machine; YMMV. Tracy, Perfetto, counters, and device metrics were
used to explain structure and validate behavior, not as the source of
wall-clock claims.
diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml
index 963d4f2..9424c62 100644
--- a/.pre-commit-config.yaml
+++ b/.pre-commit-config.yaml
@@ -107,6 +107,14 @@
         entry: ./build_tools/bazel_to_cmake/bazel_to_cmake.py
         files: "^(compiler|runtime|samples|tests|tools)/(.*/)?(CMakeLists.txt)$"
 
+      - id: amdgpu_target_map
+        name: Check AMDGPU target map generated files
+        language: python
+        entry: ./build_tools/scripts/amdgpu_target_map.py --check
+        pass_filenames: false
+        files: >-
+          ^build_tools/scripts/amdgpu_target_map\.py$|^runtime/src/iree/hal/drivers/amdgpu/device/binaries/target_map\.(bzl|cmake|inl)$
+
       - id: check_path_lengths
         name: Check for excessively long path lengths
         language: fail
diff --git a/build_tools/bazel/iree_amdgpu_binary.bzl b/build_tools/bazel/iree_amdgpu_binary.bzl
index ab7a034..b4aa95d 100644
--- a/build_tools/bazel/iree_amdgpu_binary.bzl
+++ b/build_tools/bazel/iree_amdgpu_binary.bzl
@@ -21,13 +21,10 @@
         name: Name of the target.
         target: LLVM `-target` flag.
         arch: LLVM `-march` flag.
-        srcs: source files to pass to clang.
-        internal_hdrs: all headers transitively included by the source files.
-                       Unlike typical Bazel `hdrs`, these are not exposed as
-                       interface headers. This would normally be part of `srcs`,
-                       but separating it was easier for `bazel_to_cmake`, as
-                       CMake does not need this, and making this explicitly
-                       Bazel-only allows using `filegroup` on the Bazel side.
+        srcs: source files or filegroups to pass to clang.
+        internal_hdrs: headers that should invalidate device compilation but
+                       are not compiled as translation units or exposed as
+                       interface headers.
         copts: additional flags to pass to clang.
         linkopts: additional flags to pass to lld.
         **kwargs: any additional attributes to pass to the underlying rules.
@@ -71,43 +68,39 @@
         "-emit-llvm",
     ]
 
-    bitcode_files = []
-
-    for src in srcs:
-        bitcode_out = "%s_%s.bc" % (name, src)
-        bitcode_files.append(bitcode_out)
-        native.genrule(
-            name = "gen_%s" % (bitcode_out),
-            srcs = [src, builtin_headers_dep] + internal_hdrs,
-            outs = [bitcode_out],
-            cmd = " && ".join([
+    archive_out = "%s.a" % (name)
+    source_locations = " ".join(["$(locations %s)" % (src,) for src in srcs])
+    object_dir = "$(@D)/%s.objects" % (name,)
+    native.genrule(
+        name = "archive_%s" % (name),
+        srcs = srcs + [builtin_headers_dep] + internal_hdrs,
+        outs = [archive_out],
+        cmd = " && ".join([
+            "set -e",
+            "object_dir=\"%s\"" % (object_dir,),
+            "rm -rf \"$${object_dir}\"",
+            "mkdir -p \"$${object_dir}\"",
+            "object_index=0",
+            "for src in %s; do %s; object_index=$$((object_index + 1)); done" % (
+                source_locations,
                 " ".join([
                     "$(location %s)" % (clang_tool),
                     " ".join(base_copts + copts),
-                    "-o $(location %s)" % (bitcode_out),
-                    "$(location %s)" % (src),
+                    "-o \"$${object_dir}/$${object_index}.bc\"",
+                    "\"$${src}\"",
                 ]),
-            ]),
-            tools = [clang_tool],
-            message = "Compiling %s to %s..." % (src, bitcode_out),
-            output_to_bindir = 1,
-            **kwargs
-        )
-
-    archive_out = "%s.a" % (name)
-    native.genrule(
-        name = "archive_%s" % (name),
-        srcs = bitcode_files,
-        outs = [archive_out],
-        cmd = " && ".join([
+            ),
             " ".join([
                 "$(location %s)" % (link_tool),
-                " ".join(["$(locations %s)" % (src) for src in bitcode_files]),
+                "\"$${object_dir}\"/*.bc",
                 "-o $(location %s)" % (archive_out),
             ]),
         ]),
-        tools = [link_tool],
-        message = "Archiving bitcode libraries %s to %s..." % (bitcode_files, archive_out),
+        tools = [
+            clang_tool,
+            link_tool,
+        ],
+        message = "Compiling bitcode library %s to %s..." % (srcs, archive_out),
         output_to_bindir = 1,
         **kwargs
     )
diff --git a/build_tools/cmake/iree_amdgpu_binary.cmake b/build_tools/cmake/iree_amdgpu_binary.cmake
index aed8747..1124f04 100644
--- a/build_tools/cmake/iree_amdgpu_binary.cmake
+++ b/build_tools/cmake/iree_amdgpu_binary.cmake
@@ -14,12 +14,8 @@
 # TARGET: LLVM `-target` flag.
 # ARCH: LLVM `-march` flag.
 # SRCS: source files to pass to clang.
-# INTERNAL_HDRS: all headers transitively included by the source files.
-#                Unlike typical Bazel `hdrs`, these are not exposed as
-#                interface headers. This would normally be part of `srcs`,
-#                but separating it was easier for `bazel_to_cmake`, as
-#                CMake does not need this, and making this explicitly
-#                Bazel-only allows using `filegroup` on the Bazel side.
+# INTERNAL_HDRS: headers that should invalidate device compilation but are not
+#                compiled as translation units or exposed as interface headers.
 # COPTS: additional flags to pass to clang.
 # LINKOPTS: additional flags to pass to lld.
 function(iree_amdgpu_binary)
@@ -74,7 +70,12 @@
   set(_BITCODE_FILES)
   foreach(_SRC ${_RULE_SRCS})
     get_filename_component(_BITCODE_SRC_PATH "${_SRC}" REALPATH)
-    string(REGEX REPLACE "[.]c$" "--${_RULE_ARCH}.bc" _BITCODE_FILE ${_SRC})
+    set(_BITCODE_SRC_FRAGMENT "${_SRC}")
+    string(REPLACE "\\" "_" _BITCODE_SRC_FRAGMENT "${_BITCODE_SRC_FRAGMENT}")
+    string(REPLACE "/" "_" _BITCODE_SRC_FRAGMENT "${_BITCODE_SRC_FRAGMENT}")
+    string(REPLACE ":" "_" _BITCODE_SRC_FRAGMENT "${_BITCODE_SRC_FRAGMENT}")
+    string(REPLACE "." "_" _BITCODE_SRC_FRAGMENT "${_BITCODE_SRC_FRAGMENT}")
+    set(_BITCODE_FILE "${_RULE_NAME}_${_BITCODE_SRC_FRAGMENT}.bc")
     list(APPEND _BITCODE_FILES ${_BITCODE_FILE})
     add_custom_command(
       OUTPUT
@@ -89,8 +90,6 @@
         "${IREE_CLANG_BINARY}"
         "${_BITCODE_SRC_PATH}"
         "${_RULE_INTERNAL_HDRS}"
-      MAIN_DEPENDENCY
-        "${_BITCODE_SRC_PATH}"
       COMMENT
         "Compiling ${_SRC} to ${_BITCODE_FILE}"
       VERBATIM
diff --git a/build_tools/scripts/amdgpu_target_map.py b/build_tools/scripts/amdgpu_target_map.py
new file mode 100755
index 0000000..c637a31
--- /dev/null
+++ b/build_tools/scripts/amdgpu_target_map.py
@@ -0,0 +1,468 @@
+#!/usr/bin/env python3
+# Copyright 2026 The IREE Authors
+#
+# Licensed under the Apache License v2.0 with LLVM Exceptions.
+# See https://llvm.org/LICENSE.txt for license information.
+# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+"""Generates shared AMDGPU device library target map fragments.
+
+The map in this file is the source of truth for the small generated tables used
+by Bazel, CMake, and the runtime device-library loader. Keep build logic in
+Starlark/CMake; keep target facts here.
+"""
+
+import argparse
+import difflib
+import re
+import sys
+from pathlib import Path
+
+DEFAULT_TARGET_SELECTIONS = ("all",)
+
+# Each exact target must match an HSA ISA architecture suffix. Each code object
+# target must be accepted by LLVM clang/lld as an AMDGPU -march value. Generic
+# code object coverage follows LLVM generic processor documentation; TheRock
+# family membership follows ROCm/TheRock's cmake/therock_amdgpu_targets.cmake.
+EXACT_TARGET_CODE_OBJECTS = (
+    ("gfx900", "gfx9-generic"),
+    ("gfx902", "gfx9-generic"),
+    ("gfx904", "gfx9-generic"),
+    ("gfx90c", "gfx9-generic"),
+    ("gfx906", "gfx9-generic"),
+    ("gfx908", "gfx908"),
+    ("gfx909", "gfx9-generic"),
+    ("gfx90a", "gfx90a"),
+    ("gfx940", "gfx9-4-generic"),
+    ("gfx941", "gfx9-4-generic"),
+    ("gfx942", "gfx9-4-generic"),
+    ("gfx950", "gfx9-4-generic"),
+    ("gfx1010", "gfx10-1-generic"),
+    ("gfx1011", "gfx10-1-generic"),
+    ("gfx1012", "gfx10-1-generic"),
+    ("gfx1013", "gfx10-1-generic"),
+    ("gfx1030", "gfx10-3-generic"),
+    ("gfx1031", "gfx10-3-generic"),
+    ("gfx1032", "gfx10-3-generic"),
+    ("gfx1033", "gfx10-3-generic"),
+    ("gfx1034", "gfx10-3-generic"),
+    ("gfx1035", "gfx10-3-generic"),
+    ("gfx1036", "gfx10-3-generic"),
+    ("gfx1100", "gfx11-generic"),
+    ("gfx1101", "gfx11-generic"),
+    ("gfx1102", "gfx11-generic"),
+    ("gfx1103", "gfx11-generic"),
+    ("gfx1150", "gfx11-generic"),
+    ("gfx1151", "gfx11-generic"),
+    ("gfx1152", "gfx11-generic"),
+    ("gfx1153", "gfx11-generic"),
+    ("gfx1170", "gfx11-generic"),
+    ("gfx1171", "gfx11-generic"),
+    ("gfx1172", "gfx11-generic"),
+    ("gfx1200", "gfx12-generic"),
+    ("gfx1201", "gfx12-generic"),
+    ("gfx1250", "gfx12-5-generic"),
+    ("gfx1251", "gfx12-5-generic"),
+)
+
+FEATURE_SRAMECC = "sramecc"
+FEATURE_XNACK = "xnack"
+
+# Feature support follows ROCr's ISA registry. A target absent from a feature
+# set does not support that feature; supported targets may still select an
+# explicit on/off mode at runtime.
+TARGET_FEATURE_SUPPORT = {
+    "gfx900": (FEATURE_XNACK,),
+    "gfx902": (FEATURE_XNACK,),
+    "gfx904": (FEATURE_XNACK,),
+    "gfx906": (FEATURE_SRAMECC, FEATURE_XNACK),
+    "gfx908": (FEATURE_SRAMECC, FEATURE_XNACK),
+    "gfx909": (FEATURE_XNACK,),
+    "gfx90c": (FEATURE_XNACK,),
+    "gfx90a": (FEATURE_SRAMECC, FEATURE_XNACK),
+    "gfx940": (FEATURE_SRAMECC, FEATURE_XNACK),
+    "gfx941": (FEATURE_SRAMECC, FEATURE_XNACK),
+    "gfx942": (FEATURE_SRAMECC, FEATURE_XNACK),
+    "gfx950": (FEATURE_SRAMECC, FEATURE_XNACK),
+    "gfx1010": (FEATURE_XNACK,),
+    "gfx1011": (FEATURE_XNACK,),
+    "gfx1012": (FEATURE_XNACK,),
+    "gfx1013": (FEATURE_XNACK,),
+}
+
+ALL_EXACT_TARGETS = object()
+
+TARGET_FAMILIES = (
+    ("all", ALL_EXACT_TARGETS),
+    (
+        "dcgpu-all",
+        ("gfx908", "gfx90a", "gfx940", "gfx941", "gfx942", "gfx950"),
+    ),
+    (
+        "dgpu-all",
+        (
+            "gfx900",
+            "gfx902",
+            "gfx904",
+            "gfx906",
+            "gfx909",
+            "gfx1010",
+            "gfx1011",
+            "gfx1012",
+            "gfx1013",
+            "gfx1030",
+            "gfx1031",
+            "gfx1032",
+            "gfx1034",
+            "gfx1100",
+            "gfx1101",
+            "gfx1102",
+            "gfx1200",
+            "gfx1201",
+        ),
+    ),
+    ("gfx900-dgpu", ("gfx900",)),
+    ("gfx906-dgpu", ("gfx906",)),
+    ("gfx908-dcgpu", ("gfx908",)),
+    ("gfx90a-dcgpu", ("gfx90a",)),
+    ("gfx90c-igpu", ("gfx90c",)),
+    ("gfx94X-all", ("gfx940", "gfx941", "gfx942")),
+    ("gfx94X-dcgpu", ("gfx940", "gfx941", "gfx942")),
+    ("gfx950-all", ("gfx950",)),
+    ("gfx950-dcgpu", ("gfx950",)),
+    ("gfx101X-all", ("gfx1010", "gfx1011", "gfx1012", "gfx1013")),
+    ("gfx101X-dgpu", ("gfx1010", "gfx1011", "gfx1012", "gfx1013")),
+    (
+        "gfx103X-all",
+        (
+            "gfx1030",
+            "gfx1031",
+            "gfx1032",
+            "gfx1033",
+            "gfx1034",
+            "gfx1035",
+            "gfx1036",
+        ),
+    ),
+    ("gfx103X-dgpu", ("gfx1030", "gfx1031", "gfx1032", "gfx1034")),
+    ("gfx103X-igpu", ("gfx1033", "gfx1035", "gfx1036")),
+    ("gfx110X-all", ("gfx1100", "gfx1101", "gfx1102", "gfx1103")),
+    ("gfx110X-dgpu", ("gfx1100", "gfx1101", "gfx1102")),
+    ("gfx110X-igpu", ("gfx1103",)),
+    ("gfx115X-all", ("gfx1150", "gfx1151", "gfx1152", "gfx1153")),
+    ("gfx115X-igpu", ("gfx1150", "gfx1151", "gfx1152", "gfx1153")),
+    ("gfx117X-all", ("gfx1170", "gfx1171", "gfx1172")),
+    ("gfx120X-all", ("gfx1200", "gfx1201")),
+    ("gfx125X-all", ("gfx1250", "gfx1251")),
+    (
+        "igpu-all",
+        (
+            "gfx90c",
+            "gfx1033",
+            "gfx1035",
+            "gfx1036",
+            "gfx1103",
+            "gfx1150",
+            "gfx1151",
+            "gfx1152",
+            "gfx1153",
+        ),
+    ),
+)
+
+
+def find_repo_root():
+    current = Path(__file__).resolve()
+    while current != current.parent:
+        if (current / "runtime" / "src" / "iree").exists():
+            return current
+        current = current.parent
+    print("error: could not find IREE repository root", file=sys.stderr)
+    sys.exit(1)
+
+
+def append_unique(values, new_values):
+    for value in new_values:
+        if value not in values:
+            values.append(value)
+
+
+def exact_targets():
+    return [exact_target for exact_target, _ in EXACT_TARGET_CODE_OBJECTS]
+
+
+def code_object_targets():
+    values = []
+    for _, code_object_target in EXACT_TARGET_CODE_OBJECTS:
+        append_unique(values, [code_object_target])
+    return values
+
+
+def family_targets(targets):
+    if targets is ALL_EXACT_TARGETS:
+        return exact_targets()
+    return list(targets)
+
+
+def target_family_names():
+    return [family for family, _ in TARGET_FAMILIES]
+
+
+def validate_target_map():
+    exact = exact_targets()
+    if len(set(exact)) != len(exact):
+        raise ValueError("duplicate exact AMDGPU targets in target map")
+
+    families = target_family_names()
+    if len(set(families)) != len(families):
+        raise ValueError("duplicate AMDGPU target families in target map")
+
+    exact_set = set(exact)
+    feature_targets = set(TARGET_FEATURE_SUPPORT)
+    unknown_feature_targets = sorted(feature_targets - exact_set)
+    if unknown_feature_targets:
+        raise ValueError(
+            "target feature support references unknown exact targets: {}".format(
+                ", ".join(unknown_feature_targets)
+            )
+        )
+
+    for family, targets in TARGET_FAMILIES:
+        unknown_targets = sorted(set(family_targets(targets)) - exact_set)
+        if unknown_targets:
+            raise ValueError(
+                "target family {} references unknown exact targets: {}".format(
+                    family, ", ".join(unknown_targets)
+                )
+            )
+
+
+def generated_header(comment_prefix, output_path):
+    return "\n".join(
+        [
+            "{} Generated by build_tools/scripts/amdgpu_target_map.py.".format(
+                comment_prefix
+            ),
+            "{} Do not edit directly; edit the map in that script and regenerate.".format(
+                comment_prefix
+            ),
+            "{} Output: {}".format(comment_prefix, output_path),
+        ]
+    )
+
+
+def bzl_list(name, values):
+    lines = ["{} = [".format(name)]
+    lines.extend(['    "{}",'.format(value) for value in values])
+    lines.append("]")
+    return "\n".join(lines)
+
+
+def bzl_string_dict(name, values):
+    lines = ["{} = {{".format(name)]
+    for key, value in values:
+        lines.append('    "{}": "{}",'.format(key, value))
+    lines.append("}")
+    return "\n".join(lines)
+
+
+def bzl_family_dict(name):
+    lines = ["{} = {{".format(name)]
+    for family, targets in TARGET_FAMILIES:
+        values = family_targets(targets)
+        if targets is ALL_EXACT_TARGETS:
+            lines.append(
+                '    "{}": IREE_HAL_AMDGPU_DEVICE_LIBRARY_EXACT_TARGETS,'.format(family)
+            )
+        elif len(values) == 1:
+            lines.append('    "{}": ["{}"],'.format(family, values[0]))
+        else:
+            lines.append('    "{}": ['.format(family))
+            lines.extend(['        "{}",'.format(value) for value in values])
+            lines.append("    ],")
+    lines.append("}")
+    return "\n".join(lines)
+
+
+def render_bzl():
+    output_path = "runtime/src/iree/hal/drivers/amdgpu/device/binaries/target_map.bzl"
+    return (
+        "\n\n".join(
+            [
+                generated_header("#", output_path),
+                bzl_list(
+                    "IREE_HAL_AMDGPU_DEVICE_LIBRARY_DEFAULT_TARGETS",
+                    DEFAULT_TARGET_SELECTIONS,
+                ),
+                bzl_list(
+                    "IREE_HAL_AMDGPU_DEVICE_LIBRARY_EXACT_TARGETS",
+                    exact_targets(),
+                ),
+                bzl_list(
+                    "IREE_HAL_AMDGPU_DEVICE_LIBRARY_CODE_OBJECT_TARGETS",
+                    code_object_targets(),
+                ),
+                bzl_string_dict(
+                    "IREE_HAL_AMDGPU_DEVICE_LIBRARY_EXACT_TARGET_CODE_OBJECTS",
+                    EXACT_TARGET_CODE_OBJECTS,
+                ),
+                bzl_list(
+                    "IREE_HAL_AMDGPU_DEVICE_LIBRARY_TARGET_FAMILY_NAMES",
+                    target_family_names(),
+                ),
+                bzl_family_dict("IREE_HAL_AMDGPU_DEVICE_LIBRARY_TARGET_FAMILIES"),
+            ]
+        )
+        + "\n"
+    )
+
+
+def cmake_list(name, values):
+    lines = ["set({}".format(name)]
+    lines.extend(['  "{}"'.format(value) for value in values])
+    lines.append(")")
+    return "\n".join(lines)
+
+
+def cmake_identifier(value):
+    return re.sub(r"[^A-Za-z0-9_]", "_", value)
+
+
+def render_cmake():
+    output_path = "runtime/src/iree/hal/drivers/amdgpu/device/binaries/target_map.cmake"
+    lines = [
+        generated_header("#", output_path),
+        "",
+        cmake_list("_IREE_HAL_AMDGPU_DEVICE_TARGETS", exact_targets()),
+        "",
+        cmake_list(
+            "_IREE_HAL_AMDGPU_DEVICE_CODE_OBJECT_TARGETS", code_object_targets()
+        ),
+        "",
+    ]
+    for exact_target, code_object_target in EXACT_TARGET_CODE_OBJECTS:
+        lines.append(
+            'set(_IREE_HAL_AMDGPU_DEVICE_TARGET_CODE_OBJECT_{} "{}")'.format(
+                exact_target, code_object_target
+            )
+        )
+    lines.extend(
+        [
+            "",
+            cmake_list(
+                "_IREE_HAL_AMDGPU_DEVICE_TARGET_FAMILIES", target_family_names()
+            ),
+            "",
+        ]
+    )
+    for family, targets in TARGET_FAMILIES:
+        var_name = "_IREE_HAL_AMDGPU_DEVICE_TARGET_FAMILY_{}".format(
+            cmake_identifier(family)
+        )
+        if targets is ALL_EXACT_TARGETS:
+            lines.append("set({}".format(var_name))
+            lines.append("  ${_IREE_HAL_AMDGPU_DEVICE_TARGETS}")
+            lines.append(")")
+        else:
+            lines.append(cmake_list(var_name, family_targets(targets)))
+    lines.append("")
+    return "\n".join(lines)
+
+
+def render_target_id_inl():
+    output_path = "runtime/src/iree/hal/drivers/amdgpu/util/target_id_map.inl"
+    lines = [
+        generated_header("//", output_path),
+        "//",
+        "// Included inside iree_hal_amdgpu_target_id_mappings.",
+        "",
+        "// clang-format off",
+    ]
+    feature_flag_names = {
+        FEATURE_SRAMECC: "IREE_HAL_AMDGPU_TARGET_FEATURE_SUPPORT_SRAMECC",
+        FEATURE_XNACK: "IREE_HAL_AMDGPU_TARGET_FEATURE_SUPPORT_XNACK",
+    }
+    for exact_target, code_object_target in EXACT_TARGET_CODE_OBJECTS:
+        features = TARGET_FEATURE_SUPPORT.get(exact_target, ())
+        feature_flags = " | ".join(feature_flag_names[feature] for feature in features)
+        if not feature_flags:
+            feature_flags = "IREE_HAL_AMDGPU_TARGET_FEATURE_SUPPORT_NONE"
+        lines.append(
+            '{{IREE_SVL("{}"), IREE_SVL("{}"), {}}},'.format(
+                exact_target, code_object_target, feature_flags
+            )
+        )
+    lines.append("")
+    return "\n".join(lines)
+
+
+def generated_outputs(repo_root):
+    binary_output_dir = (
+        repo_root / "runtime/src/iree/hal/drivers/amdgpu/device/binaries"
+    )
+    util_output_dir = repo_root / "runtime/src/iree/hal/drivers/amdgpu/util"
+    return {
+        binary_output_dir / "target_map.bzl": render_bzl(),
+        binary_output_dir / "target_map.cmake": render_cmake(),
+        util_output_dir / "target_id_map.inl": render_target_id_inl(),
+    }
+
+
+def check_outputs(repo_root, outputs):
+    failed = False
+    for path, content in outputs.items():
+        if not path.exists():
+            print("error: {} does not exist".format(path), file=sys.stderr)
+            failed = True
+            continue
+        existing = path.read_text()
+        if existing == content:
+            continue
+        rel_path = path.relative_to(repo_root)
+        print("error: {} is out of date".format(rel_path), file=sys.stderr)
+        diff = difflib.unified_diff(
+            existing.splitlines(keepends=True),
+            content.splitlines(keepends=True),
+            fromfile=str(rel_path),
+            tofile=str(rel_path) + " (generated)",
+        )
+        sys.stderr.writelines(diff)
+        failed = True
+    if failed:
+        print(
+            "Run 'python build_tools/scripts/amdgpu_target_map.py' to regenerate.",
+            file=sys.stderr,
+        )
+        return 1
+    print("AMDGPU target map generated files are up to date.")
+    return 0
+
+
+def write_outputs(outputs):
+    for path, content in outputs.items():
+        path.parent.mkdir(parents=True, exist_ok=True)
+        path.write_text(content)
+        print("Wrote {}".format(path))
+    return 0
+
+
+def main():
+    parser = argparse.ArgumentParser(
+        description="Generate AMDGPU device library target map fragments."
+    )
+    parser.add_argument(
+        "--check",
+        action="store_true",
+        help="Check that generated files are up to date without modifying them.",
+    )
+    args = parser.parse_args()
+
+    validate_target_map()
+    repo_root = find_repo_root()
+    outputs = generated_outputs(repo_root)
+    if args.check:
+        return check_outputs(repo_root, outputs)
+    return write_outputs(outputs)
+
+
+if __name__ == "__main__":
+    sys.exit(main())
diff --git a/build_tools/third_party/hsa-runtime-headers/BUILD.overlay b/build_tools/third_party/hsa-runtime-headers/BUILD.overlay
index b3e0b85..ce568bb 100644
--- a/build_tools/third_party/hsa-runtime-headers/BUILD.overlay
+++ b/build_tools/third_party/hsa-runtime-headers/BUILD.overlay
@@ -9,6 +9,7 @@
 cc_library(
     name = "hsa_runtime_headers",
     hdrs = glob([
+        "include/aqlprofile-sdk/*.h",
         "include/hsa/*.h",
     ]),
     include_prefix = "third_party/hsa-runtime-headers/",
diff --git a/compiler/plugins/target/ROCM/ROCMTarget.cpp b/compiler/plugins/target/ROCM/ROCMTarget.cpp
index cd65e7b..022a640 100644
--- a/compiler/plugins/target/ROCM/ROCMTarget.cpp
+++ b/compiler/plugins/target/ROCM/ROCMTarget.cpp
@@ -75,6 +75,108 @@
   HSACO,
 };
 
+enum class AMDGPUTargetFeatureMode {
+  // Feature mode is not specified by the target ID.
+  Any,
+  // Feature is explicitly disabled.
+  Off,
+  // Feature is explicitly enabled.
+  On,
+};
+
+struct AMDGPUTargetFeatureModes {
+  // SRAM ECC target feature mode.
+  AMDGPUTargetFeatureMode sramecc = AMDGPUTargetFeatureMode::Any;
+  // XNACK target feature mode.
+  AMDGPUTargetFeatureMode xnack = AMDGPUTargetFeatureMode::Any;
+};
+
+static LogicalResult
+setAMDGPUTargetFeatureMode(Location loc, StringRef featureName,
+                           AMDGPUTargetFeatureMode featureMode,
+                           AMDGPUTargetFeatureMode &targetFeatureMode) {
+  if (targetFeatureMode != AMDGPUTargetFeatureMode::Any) {
+    return emitError(loc, "duplicate ROCM target feature '")
+           << featureName << "'";
+  }
+  targetFeatureMode = featureMode;
+  return success();
+}
+
+static FailureOr<AMDGPUTargetFeatureModes>
+parseAMDGPUTargetFeatureModes(Location loc, StringRef targetFeatures) {
+  AMDGPUTargetFeatureModes modes;
+  SmallVector<StringRef> features;
+  llvm::SplitString(targetFeatures, features, ",");
+  for (StringRef rawFeature : features) {
+    AMDGPUTargetFeatureMode featureMode = AMDGPUTargetFeatureMode::Any;
+    StringRef feature = rawFeature;
+    if (feature.consume_front("+")) {
+      featureMode = AMDGPUTargetFeatureMode::On;
+    } else if (feature.consume_front("-")) {
+      featureMode = AMDGPUTargetFeatureMode::Off;
+    } else {
+      emitError(loc, "ROCM target feature must be prefixed with '+' or '-'; "
+                     "but seen '")
+          << rawFeature << "'";
+      return failure();
+    }
+    if (feature == "sramecc") {
+      if (failed(setAMDGPUTargetFeatureMode(loc, feature, featureMode,
+                                            modes.sramecc))) {
+        return failure();
+      }
+    } else if (feature == "xnack") {
+      if (failed(setAMDGPUTargetFeatureMode(loc, feature, featureMode,
+                                            modes.xnack))) {
+        return failure();
+      }
+    } else {
+      // We only support these two features to be set explicitly. Features like
+      // wavefrontsize are controlled and tuned by the compiler.
+      emitError(loc,
+                "ROCM target feature can only be 'sramecc' or 'xnack'; but "
+                "seen '")
+          << feature << "'";
+      return failure();
+    }
+  }
+  return modes;
+}
+
+static void appendAMDGPUTargetFeatureSuffix(std::string &targetID,
+                                            StringRef featureName,
+                                            AMDGPUTargetFeatureMode mode) {
+  switch (mode) {
+  case AMDGPUTargetFeatureMode::Any:
+    return;
+  case AMDGPUTargetFeatureMode::Off:
+    targetID += ":";
+    targetID += featureName;
+    targetID += "-";
+    return;
+  case AMDGPUTargetFeatureMode::On:
+    targetID += ":";
+    targetID += featureName;
+    targetID += "+";
+    return;
+  }
+}
+
+static FailureOr<std::string> buildAMDGPUTargetID(Location loc,
+                                                  StringRef targetArch,
+                                                  StringRef targetFeatures) {
+  FailureOr<AMDGPUTargetFeatureModes> modes =
+      parseAMDGPUTargetFeatureModes(loc, targetFeatures);
+  if (failed(modes)) {
+    return failure();
+  }
+  std::string targetID = targetArch.str();
+  appendAMDGPUTargetFeatureSuffix(targetID, "sramecc", modes->sramecc);
+  appendAMDGPUTargetFeatureSuffix(targetID, "xnack", modes->xnack);
+  return targetID;
+}
+
 struct ROCMOptions {
   std::string target = "";
   std::string targetFeatures = "";
@@ -264,24 +366,9 @@
       return emitError(builder.getUnknownLoc(), "Unknown ROCM target '")
              << target << "'";
     }
-    SmallVector<StringRef> features;
-    llvm::SplitString(targetFeatures, features, ",");
-    for (StringRef f : features) {
-      if (!(f.starts_with("+") || f.starts_with("-"))) {
-        return emitError(builder.getUnknownLoc(),
-                         "ROCM target feature must be prefixed with '+' or "
-                         "'-'; but seen '")
-               << f << "'";
-      }
-      StringRef feature = f.substr(1);
-      if (feature != "sramecc" && feature != "xnack") {
-        // We only support these two features to be set explicitly. Features
-        // like wavefrontsize is controlled and tuned by the compiler.
-        return emitError(builder.getUnknownLoc(),
-                         "ROCM target feature can only be 'sramecc' or "
-                         "'xnack'; but seen '")
-               << feature << "'";
-      }
+    if (failed(parseAMDGPUTargetFeatureModes(builder.getUnknownLoc(),
+                                             targetFeatures))) {
+      return failure();
     }
     return success();
   }
@@ -392,7 +479,13 @@
     addConfig("abi", b.getStringAttr(deviceID));
     std::string format;
     if (deviceID == "amdgpu") {
-      format = targetOptions.target;
+      FailureOr<std::string> targetID =
+          buildAMDGPUTargetID(b.getUnknownLoc(), targetOptions.target,
+                              targetOptions.targetFeatures);
+      if (failed(targetID)) {
+        return nullptr;
+      }
+      format = *targetID;
     } else {
       format = "rocm-hsaco-fb"; // legacy HIP
     }
@@ -922,6 +1015,7 @@
     }
 
     // Wrap the HSACO ELF binary in the requested container type (if any).
+    StringAttr executableBinaryFormat = variantOp.getTarget().getFormat();
     FailureOr<DenseIntElementsAttr> binaryContainer;
     switch (containerType) {
     case ContainerType::Auto: {
@@ -930,8 +1024,14 @@
       break;
     }
     case ContainerType::AMDGPU: {
+      FailureOr<std::string> targetID =
+          buildAMDGPUTargetID(variantOp.getLoc(), targetArch, targetFeatures);
+      if (failed(targetID)) {
+        return failure();
+      }
+      executableBinaryFormat = executableBuilder.getStringAttr(*targetID);
       binaryContainer = serializeAMDGPUBinaryContainer(
-          serializationOptions, variantOp, exportOps, targetHSACO);
+          serializationOptions, variantOp, exportOps, *targetID, targetHSACO);
       break;
     }
     case ContainerType::HIP: {
@@ -957,7 +1057,7 @@
     // Add the binary data to the target executable.
     auto binaryOp = IREE::HAL::ExecutableBinaryOp::create(
         executableBuilder, variantOp.getLoc(), variantOp.getSymName(),
-        variantOp.getTarget().getFormat(), binaryContainer.value());
+        executableBinaryFormat, binaryContainer.value());
     binaryOp.setMimeTypeAttr(
         executableBuilder.getStringAttr("application/x-flatbuffers"));
 
@@ -968,7 +1068,7 @@
   FailureOr<DenseIntElementsAttr> serializeAMDGPUBinaryContainer(
       const SerializationOptions &serializationOptions,
       IREE::HAL::ExecutableVariantOp variantOp,
-      ArrayRef<IREE::HAL::ExecutableExportOp> exportOps,
+      ArrayRef<IREE::HAL::ExecutableExportOp> exportOps, StringRef targetID,
       StringRef hsacoModule) {
     iree_compiler::FlatbufferBuilder builder;
     iree_hal_amdgpu_ExecutableDef_start_as_root(builder);
@@ -1043,7 +1143,7 @@
     }
     auto exportsRef = builder.createOffsetVecDestructive(exportRefs);
 
-    auto isaRef = builder.createString(variantOp.getTarget().getFormat());
+    auto isaRef = builder.createString(targetID);
     iree_hal_amdgpu_ExecutableDef_isa_add(builder, isaRef);
     iree_hal_amdgpu_ExecutableDef_exports_add(builder, exportsRef);
     iree_hal_amdgpu_ExecutableDef_modules_add(builder, modulesRef);
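
The composed strings follow the LLVM AMDGPU target-ID spelling, so the serialized ISA now carries the feature modes. For example (unset modes append no suffix):

```
target=gfx90a,  features=+sramecc,-xnack  ->  gfx90a:sramecc+:xnack-
target=gfx1100, features unset            ->  gfx1100
```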
diff --git a/runtime/src/iree/hal/command_buffer.c b/runtime/src/iree/hal/command_buffer.c
index a2dd06a..1139486 100644
--- a/runtime/src/iree/hal/command_buffer.c
+++ b/runtime/src/iree/hal/command_buffer.c
@@ -123,6 +123,8 @@
       {IREE_HAL_COMMAND_BUFFER_MODE_UNRETAINED, IREE_SVL("UNRETAINED")},
       {IREE_HAL_COMMAND_BUFFER_MODE_RETAIN_PROFILE_METADATA,
        IREE_SVL("RETAIN_PROFILE_METADATA")},
+      {IREE_HAL_COMMAND_BUFFER_MODE_RETAIN_DISPATCH_METADATA,
+       IREE_SVL("RETAIN_DISPATCH_METADATA")},
   };
   return iree_bitfield_format_inline(value, IREE_ARRAYSIZE(mappings), mappings,
                                      out_temp);
diff --git a/runtime/src/iree/hal/command_buffer.h b/runtime/src/iree/hal/command_buffer.h
index d33ddb1..e594b04 100644
--- a/runtime/src/iree/hal/command_buffer.h
+++ b/runtime/src/iree/hal/command_buffer.h
@@ -77,8 +77,16 @@
   // This makes profiling possible for the command buffer but does not enable
   // profiling by itself. Implementations may spend additional recording-time
   // CPU and memory to retain command operation metadata and compact sidecars
-  // used by profiling sessions.
+  // used by profiling sessions. This is intended for rich host profiling that
+  // needs source/correlation records, not minimal production timestamp capture.
   IREE_HAL_COMMAND_BUFFER_MODE_RETAIN_PROFILE_METADATA = 1u << 7,
+
+  // Retains compact dispatch metadata required for command-buffer timestamping.
+  // This makes dispatch timestamp capture possible for the command buffer but
+  // does not enable timestamp capture by itself. Implementations may spend
+  // additional recording-time CPU and memory to retain compact per-dispatch
+  // packet/correlation sidecars without requiring full profile metadata.
+  IREE_HAL_COMMAND_BUFFER_MODE_RETAIN_DISPATCH_METADATA = 1u << 8,
 };
 typedef uint32_t iree_hal_command_buffer_mode_t;
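
Recording-side, the new bit composes with the other mode flags. A fragment (assuming an existing `device` and the stock HAL creation signature):

```c
#include "iree/hal/api.h"

// Reusable command buffer that retains compact per-dispatch sidecars.
// Retention alone captures nothing: a profiling session must still request
// dispatch timing for timestamps to be recorded.
iree_hal_command_buffer_t* command_buffer = NULL;
IREE_RETURN_IF_ERROR(iree_hal_command_buffer_create(
    device, IREE_HAL_COMMAND_BUFFER_MODE_RETAIN_DISPATCH_METADATA,
    IREE_HAL_COMMAND_CATEGORY_ANY, IREE_HAL_QUEUE_AFFINITY_ANY,
    /*binding_capacity=*/0, &command_buffer));
```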
 
diff --git a/runtime/src/iree/hal/device.c b/runtime/src/iree/hal/device.c
index b2806f4..123c07f 100644
--- a/runtime/src/iree/hal/device.c
+++ b/runtime/src/iree/hal/device.c
@@ -521,11 +521,11 @@
           IREE_STATUS_INVALID_ARGUMENT,
           "hardware counter set selections require a counter_sets array");
     }
-    if (!iree_hal_device_profiling_options_requests_counter_samples(options)) {
+    if (!iree_hal_device_profiling_options_requests_counters(options)) {
       return iree_make_status(
           IREE_STATUS_INVALID_ARGUMENT,
-          "hardware counter set selections require the counter-samples "
-          "profiling data family");
+          "hardware counter set selections require a counter profiling data "
+          "family");
     }
     for (iree_host_size_t i = 0; i < options->counter_set_count; ++i) {
       const iree_hal_profile_counter_set_selection_t* counter_set =
@@ -548,12 +548,11 @@
       }
     }
   }
-  if (iree_hal_device_profiling_options_requests_counter_samples(options) &&
+  if (iree_hal_device_profiling_options_requests_counters(options) &&
       options->counter_set_count == 0) {
     return iree_make_status(
         IREE_STATUS_INVALID_ARGUMENT,
-        "counter-samples profiling requires at least one counter set "
-        "selection");
+        "counter profiling requires at least one counter set selection");
   }
 
   const bool data_requested =
diff --git a/runtime/src/iree/hal/device.h b/runtime/src/iree/hal/device.h
index 3a0692d..a6f7996 100644
--- a/runtime/src/iree/hal/device.h
+++ b/runtime/src/iree/hal/device.h
@@ -314,10 +314,15 @@
   // Native timeline semaphore support (not binary emulation).
   IREE_HAL_DEVICE_CAPABILITY_TIMELINE_SEMAPHORES = 1ull << 0,
 
-  // Memory model capabilities.
+  // Device memory is transparently accessible through one address space without
+  // driver-specific per-range access programming.
   IREE_HAL_DEVICE_CAPABILITY_UNIFIED_MEMORY = 1ull << 1,
+  // Host accesses to device-visible memory are coherent without explicit
+  // application-managed cache maintenance.
   IREE_HAL_DEVICE_CAPABILITY_HOST_COHERENT = 1ull << 2,
-  IREE_HAL_DEVICE_CAPABILITY_PEER_COHERENT = 1ull << 3,  // Same-driver.
+  // Same-driver peer memory accesses are coherent without explicit
+  // application-managed cache maintenance.
+  IREE_HAL_DEVICE_CAPABILITY_PEER_COHERENT = 1ull << 3,
 
   // Transfer capabilities.
   // P2P DMA engine can copy directly between this device and peers.
@@ -332,6 +337,11 @@
   // Example: NVLink with large BAR → PEER_ADDRESSABLE + P2P_COPY.
   // Example: PCIe P2P without BAR mapping → P2P_COPY only.
   IREE_HAL_DEVICE_CAPABILITY_PEER_ADDRESSABLE = 1ull << 10,
+  // Shared virtual addressing is available across devices. This means a
+  // driver can make matching virtual addresses meaningful, but it does not
+  // imply those addresses are accessible by default; some runtimes require
+  // per-allocation or per-range access programming before device use.
+  IREE_HAL_DEVICE_CAPABILITY_SHARED_VIRTUAL_ADDRESS = 1ull << 11,
 
   // Concurrency and atomics.
   IREE_HAL_DEVICE_CAPABILITY_CONCURRENT_SAFE = 1ull << 6,
@@ -341,7 +351,7 @@
   // Isolation (MIG, SR-IOV, etc.).
   IREE_HAL_DEVICE_CAPABILITY_ISOLATED = 1ull << 9,
 
-  // Reserved for future use (bits 10-63).
+  // Reserved for future use (bits 12-63).
 };
 
 // Device capabilities for topology edge construction.
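
Consumer-side, the new capability is a plain bit test; in this sketch the `capabilities` value stands in for however a caller obtains the device capability bitfield:

```c
#include "iree/base/api.h"
#include "iree/hal/api.h"

// SHARED_VIRTUAL_ADDRESS promises that matching virtual addresses are
// meaningful across devices, not that they are accessible by default;
// per-range access programming may still be required.
static bool device_supports_sva(uint64_t capabilities) {
  return iree_all_bits_set(
      capabilities, IREE_HAL_DEVICE_CAPABILITY_SHARED_VIRTUAL_ADDRESS);
}
```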
diff --git a/runtime/src/iree/hal/device_test.cc b/runtime/src/iree/hal/device_test.cc
index f8da939..05c13f7 100644
--- a/runtime/src/iree/hal/device_test.cc
+++ b/runtime/src/iree/hal/device_test.cc
@@ -189,6 +189,14 @@
                         Begin(&profiling_options));
 }
 
+TEST_F(DeviceProfilingTest, BeginRejectsCounterRangesWithoutCounterSets) {
+  iree_hal_device_profiling_options_t profiling_options = {0};
+  profiling_options.data_families =
+      IREE_HAL_DEVICE_PROFILING_DATA_COUNTER_RANGES;
+  IREE_EXPECT_STATUS_IS(IREE_STATUS_INVALID_ARGUMENT,
+                        Begin(&profiling_options));
+}
+
 TEST_F(DeviceProfilingTest, BeginRejectsCounterSetWithoutCounterNames) {
   iree_hal_profile_counter_set_selection_t counter_set = {0};
 
diff --git a/runtime/src/iree/hal/drivers/amdgpu/BUILD.bazel b/runtime/src/iree/hal/drivers/amdgpu/BUILD.bazel
index 4b01289..5249475 100644
--- a/runtime/src/iree/hal/drivers/amdgpu/BUILD.bazel
+++ b/runtime/src/iree/hal/drivers/amdgpu/BUILD.bazel
@@ -12,70 +12,182 @@
     licenses = ["notice"],  # Apache 2.0
 )
 
-# TODO(benvanik): implement omitted files.
+iree_runtime_cc_library(
+    name = "queue_affinity",
+    srcs = ["queue_affinity.c"],
+    hdrs = ["queue_affinity.h"],
+    deps = [
+        "//runtime/src/iree/base",
+        "//runtime/src/iree/hal",
+    ],
+)
+
+iree_runtime_cc_library(
+    name = "access_policy",
+    srcs = ["access_policy.c"],
+    hdrs = ["access_policy.h"],
+    deps = [
+        ":queue_affinity",
+        "//runtime/src/iree/base",
+        "//runtime/src/iree/hal",
+        "//runtime/src/iree/hal/drivers/amdgpu/util:libhsa",
+        "//runtime/src/iree/hal/drivers/amdgpu/util:topology",
+    ],
+)
+
+iree_runtime_cc_library(
+    name = "profile_events",
+    srcs = ["profile_events.c"],
+    hdrs = ["profile_events.h"],
+    deps = [
+        "//runtime/src/iree/base",
+        "//runtime/src/iree/base/threading",
+        "//runtime/src/iree/hal",
+        "//runtime/src/iree/hal/utils:profile_event_ring",
+    ],
+)
+
 iree_runtime_cc_library(
     name = "amdgpu",
     srcs = [
         "allocator.c",
         "allocator.h",
+        "aql_block_processor.c",
+        "aql_block_processor.h",
+        "aql_block_processor_profile.c",
+        "aql_block_processor_profile.h",
+        "aql_block_processor_timestamp.c",
+        "aql_block_processor_timestamp.h",
+        "aql_command_buffer.c",
+        "aql_command_buffer.h",
+        "aql_command_buffer_profile.c",
+        "aql_command_buffer_profile.h",
+        "aql_prepublished_kernarg_storage.h",
+        "aql_program_builder.c",
+        "aql_program_builder.h",
+        "aql_program_validation.c",
+        "aql_program_validation.h",
         "buffer.c",
         "buffer.h",
-        "buffer_pool.c",
-        "buffer_pool.h",
-        "channel.c",
-        "channel.h",
-        "command_buffer.c",
-        "command_buffer.h",
-        "device_queue.c",
-        "device_queue.h",
         "driver.c",
         "driver.h",
-        "event.c",
-        "event.h",
         "executable.c",
         "executable.h",
         "executable_cache.c",
         "executable_cache.h",
         "host_queue.c",
         "host_queue.h",
-        "host_service.c",
-        "host_service.h",
+        "host_queue_blit.c",
+        "host_queue_blit.h",
+        "host_queue_command_buffer.c",
+        "host_queue_command_buffer.h",
+        "host_queue_command_buffer_block.c",
+        "host_queue_command_buffer_block.h",
+        "host_queue_command_buffer_packet.c",
+        "host_queue_command_buffer_packet.h",
+        "host_queue_command_buffer_profile.c",
+        "host_queue_command_buffer_profile.h",
+        "host_queue_command_buffer_replay.c",
+        "host_queue_command_buffer_replay.h",
+        "host_queue_command_buffer_scratch.h",
+        "host_queue_dispatch.c",
+        "host_queue_dispatch.h",
+        "host_queue_file.c",
+        "host_queue_file.h",
+        "host_queue_host_call.c",
+        "host_queue_host_call.h",
+        "host_queue_memory.c",
+        "host_queue_memory.h",
+        "host_queue_pending.c",
+        "host_queue_pending.h",
+        "host_queue_pending_operation.h",
+        "host_queue_pending_payload.c",
+        "host_queue_policy.c",
+        "host_queue_policy.h",
+        "host_queue_profile.c",
+        "host_queue_profile.h",
+        "host_queue_profile_events.c",
+        "host_queue_profile_events.h",
+        "host_queue_staging.c",
+        "host_queue_staging.h",
+        "host_queue_submission.c",
+        "host_queue_submission.h",
+        "host_queue_timestamp.c",
+        "host_queue_timestamp.h",
+        "host_queue_waits.c",
+        "host_queue_waits.h",
         "logical_device.c",
         "logical_device.h",
         "physical_device.c",
         "physical_device.h",
+        "physical_device_capabilities.c",
+        "physical_device_capabilities.h",
+        "profile_aqlprofile.c",
+        "profile_aqlprofile.h",
+        "profile_counters.c",
+        "profile_counters.h",
+        "profile_device_metrics.c",
+        "profile_device_metrics.h",
+        "profile_device_metrics_linux.c",
+        "profile_device_metrics_source.h",
+        "profile_metadata.c",
+        "profile_metadata.h",
+        "profile_traces.c",
+        "profile_traces.h",
         "semaphore.c",
         "semaphore.h",
-        "semaphore_pool.c",
-        "semaphore_pool.h",
+        "slab_provider.c",
+        "slab_provider.h",
         "system.c",
         "system.h",
-        # "trace_buffer.c",
-        # "trace_buffer.h",
-        "virtual_queue.c",
+        "transient_buffer.c",
+        "transient_buffer.h",
         "virtual_queue.h",
     ],
     hdrs = [
         "api.h",
     ],
     deps = [
+        ":access_policy",
+        ":profile_events",
+        ":queue_affinity",
         "//runtime/src/iree/async",
         "//runtime/src/iree/async/util:proactor_pool",
         "//runtime/src/iree/base",
+        "//runtime/src/iree/base:core_headers",
         "//runtime/src/iree/base/internal",
         "//runtime/src/iree/base/internal:arena",
         "//runtime/src/iree/base/internal/flatcc:parsing",
         "//runtime/src/iree/base/threading",
         "//runtime/src/iree/base/threading:thread",
         "//runtime/src/iree/hal",
-        "//runtime/src/iree/hal/drivers/amdgpu/device:binaries",
+        "//runtime/src/iree/hal/drivers/amdgpu/abi",
+        "//runtime/src/iree/hal/drivers/amdgpu/device:blit",
+        "//runtime/src/iree/hal/drivers/amdgpu/device:dispatch",
         "//runtime/src/iree/hal/drivers/amdgpu/device:headers",
-        "//runtime/src/iree/hal/drivers/amdgpu/util",
+        "//runtime/src/iree/hal/drivers/amdgpu/device:timestamp",
+        "//runtime/src/iree/hal/drivers/amdgpu/device/binaries:toc",
+        "//runtime/src/iree/hal/drivers/amdgpu/util:block_pool",
+        "//runtime/src/iree/hal/drivers/amdgpu/util:code_object_target",
+        "//runtime/src/iree/hal/drivers/amdgpu/util:device_clock",
+        "//runtime/src/iree/hal/drivers/amdgpu/util:device_library",
+        "//runtime/src/iree/hal/drivers/amdgpu/util:hsaco_metadata",
+        "//runtime/src/iree/hal/drivers/amdgpu/util:info",
+        "//runtime/src/iree/hal/drivers/amdgpu/util:kfd",
+        "//runtime/src/iree/hal/drivers/amdgpu/util:libaqlprofile",
         "//runtime/src/iree/hal/drivers/amdgpu/util:libhsa",
+        "//runtime/src/iree/hal/drivers/amdgpu/util:packet",
+        "//runtime/src/iree/hal/drivers/amdgpu/util:queue_primitives",
+        "//runtime/src/iree/hal/drivers/amdgpu/util:target_id",
         "//runtime/src/iree/hal/drivers/amdgpu/util:topology",
+        "//runtime/src/iree/hal/drivers/amdgpu/util:vmem",
+        "//runtime/src/iree/hal/memory:passthrough_pool",
+        "//runtime/src/iree/hal/memory:slab_provider",
+        "//runtime/src/iree/hal/memory:tlsf_pool",
+        "//runtime/src/iree/hal/memory:tracing",
+        "//runtime/src/iree/hal/utils:elf_format",
         "//runtime/src/iree/hal/utils:executable_debug_info",
         "//runtime/src/iree/hal/utils:executable_header",
-        "//runtime/src/iree/hal/utils:file_transfer",
         "//runtime/src/iree/hal/utils:files",
         "//runtime/src/iree/hal/utils:resource_set",
         "//runtime/src/iree/schemas:amdgpu_executable_def_c_fbs",
@@ -88,63 +200,105 @@
 iree_runtime_cc_library(
     name = "headers",
     hdrs = [
-        "allocator.h",
+        "aql_block_processor.h",
+        "aql_block_processor_timestamp.h",
+        "aql_command_buffer.h",
+        "aql_command_buffer_profile.h",
+        "aql_prepublished_kernarg_storage.h",
+        "aql_program_builder.h",
+        "aql_program_validation.h",
         "buffer.h",
-        "buffer_pool.h",
-        "channel.h",
-        "command_buffer.h",
-        "device_queue.h",
         "driver.h",
-        "event.h",
         "executable.h",
         "executable_cache.h",
         "host_queue.h",
-        "host_service.h",
+        "host_queue_blit.h",
+        "host_queue_command_buffer.h",
+        "host_queue_command_buffer_block.h",
+        "host_queue_command_buffer_profile.h",
+        "host_queue_command_buffer_replay.h",
+        "host_queue_dispatch.h",
+        "host_queue_file.h",
+        "host_queue_host_call.h",
+        "host_queue_memory.h",
+        "host_queue_pending.h",
+        "host_queue_policy.h",
+        "host_queue_profile.h",
+        "host_queue_profile_events.h",
+        "host_queue_staging.h",
+        "host_queue_submission.h",
+        "host_queue_timestamp.h",
+        "host_queue_waits.h",
         "logical_device.h",
         "physical_device.h",
-        # "queue.h",
+        "physical_device_capabilities.h",
+        "profile_aqlprofile.h",
+        "profile_counters.h",
+        "profile_device_metrics.h",
+        "profile_events.h",
+        "profile_metadata.h",
+        "profile_traces.h",
         "semaphore.h",
-        "semaphore_pool.h",
+        "slab_provider.h",
         "system.h",
-        # "trace_buffer.h",
+        "transient_buffer.h",
+        "virtual_queue.h",
     ],
     deps = [
+        ":profile_events",
+        ":queue_affinity",
         "//runtime/src/iree/base",
         "//runtime/src/iree/base/internal",
         "//runtime/src/iree/base/internal:arena",
+        "//runtime/src/iree/base/threading",
         "//runtime/src/iree/hal",
-        "//runtime/src/iree/hal/drivers/amdgpu/device:binaries",
+        "//runtime/src/iree/hal/drivers/amdgpu/abi",
+        "//runtime/src/iree/hal/drivers/amdgpu/device:dispatch",
         "//runtime/src/iree/hal/drivers/amdgpu/device:headers",
-        "//runtime/src/iree/hal/drivers/amdgpu/util",
+        "//runtime/src/iree/hal/drivers/amdgpu/device/binaries:toc",
+        "//runtime/src/iree/hal/drivers/amdgpu/util:block_pool",
+        "//runtime/src/iree/hal/drivers/amdgpu/util:device_library",
+        "//runtime/src/iree/hal/drivers/amdgpu/util:info",
         "//runtime/src/iree/hal/drivers/amdgpu/util:libhsa",
+        "//runtime/src/iree/hal/drivers/amdgpu/util:packet",
+        "//runtime/src/iree/hal/drivers/amdgpu/util:queue_primitives",
+        "//runtime/src/iree/hal/drivers/amdgpu/util:target_id",
         "//runtime/src/iree/hal/drivers/amdgpu/util:topology",
+        "//runtime/src/iree/hal/drivers/amdgpu/util:vmem",
+        "//runtime/src/iree/hal/memory:slab_provider",
+        "//runtime/src/iree/hal/memory:tracing",
     ],
 )
 
 iree_runtime_cc_test(
-    name = "buffer_pool_test",
-    srcs = ["buffer_pool_test.cc"],
+    name = "access_policy_test",
+    srcs = ["access_policy_test.cc"],
     group = "iree-hal-drivers-amdgpu-tests",
-    tags = [
-        "driver=amdgpu",
-        "nodocker",
-    ],
     deps = [
-        ":amdgpu",
-        ":headers",
+        ":access_policy",
         "//runtime/src/iree/base",
-        "//runtime/src/iree/hal/drivers/amdgpu/device:headers",
-        "//runtime/src/iree/hal/drivers/amdgpu/util",
-        "//runtime/src/iree/hal/drivers/amdgpu/util:libhsa",
-        "//runtime/src/iree/hal/drivers/amdgpu/util:topology",
+        "//runtime/src/iree/hal",
         "//runtime/src/iree/testing:gtest",
         "//runtime/src/iree/testing:gtest_main",
     ],
 )
 
 iree_runtime_cc_test(
-    name = "host_service_test",
-    srcs = ["host_service_test.cc"],
+    name = "queue_affinity_test",
+    srcs = ["queue_affinity_test.cc"],
+    group = "iree-hal-drivers-amdgpu-tests",
+    deps = [
+        ":queue_affinity",
+        "//runtime/src/iree/base",
+        "//runtime/src/iree/hal",
+        "//runtime/src/iree/testing:gtest",
+        "//runtime/src/iree/testing:gtest_main",
+    ],
+)
+
+iree_runtime_cc_test(
+    name = "allocator_test",
+    srcs = ["allocator_test.cc"],
     group = "iree-hal-drivers-amdgpu-tests",
     tags = [
         "driver=amdgpu",
@@ -155,9 +309,7 @@
         ":headers",
         "//runtime/src/iree/base",
         "//runtime/src/iree/hal",
-        "//runtime/src/iree/hal/drivers/amdgpu/device:headers",
-        "//runtime/src/iree/hal/drivers/amdgpu/util",
-        "//runtime/src/iree/hal/drivers/amdgpu/util:libhsa",
+        "//runtime/src/iree/hal/cts/util:test_base",
         "//runtime/src/iree/hal/drivers/amdgpu/util:topology",
         "//runtime/src/iree/testing:gtest",
         "//runtime/src/iree/testing:gtest_main",
@@ -165,8 +317,8 @@
 )
 
 iree_runtime_cc_test(
-    name = "semaphore_pool_test",
-    srcs = ["semaphore_pool_test.cc"],
+    name = "aql_block_processor_test",
+    srcs = ["aql_block_processor_test.cc"],
     group = "iree-hal-drivers-amdgpu-tests",
     tags = [
         "driver=amdgpu",
@@ -176,10 +328,273 @@
         ":amdgpu",
         ":headers",
         "//runtime/src/iree/base",
+        "//runtime/src/iree/base/internal:arena",
+        "//runtime/src/iree/hal",
+        "//runtime/src/iree/hal/drivers/amdgpu/abi",
         "//runtime/src/iree/hal/drivers/amdgpu/device:headers",
-        "//runtime/src/iree/hal/drivers/amdgpu/util",
-        "//runtime/src/iree/hal/drivers/amdgpu/util:libhsa",
-        "//runtime/src/iree/hal/drivers/amdgpu/util:topology",
+        "//runtime/src/iree/hal/drivers/amdgpu/util:packet",
+        "//runtime/src/iree/testing:gtest",
+        "//runtime/src/iree/testing:gtest_main",
+    ],
+)
+
+iree_runtime_cc_test(
+    name = "aql_block_processor_timestamp_test",
+    srcs = ["aql_block_processor_timestamp_test.cc"],
+    group = "iree-hal-drivers-amdgpu-tests",
+    tags = [
+        "driver=amdgpu",
+        "nodocker",
+    ],
+    deps = [
+        ":amdgpu",
+        ":headers",
+        "//runtime/src/iree/base",
+        "//runtime/src/iree/hal/drivers/amdgpu/abi",
+        "//runtime/src/iree/hal/drivers/amdgpu/device:headers",
+        "//runtime/src/iree/hal/drivers/amdgpu/util:packet",
+        "//runtime/src/iree/testing:gtest",
+        "//runtime/src/iree/testing:gtest_main",
+    ],
+)
+
+iree_runtime_cc_test(
+    name = "aql_command_buffer_test",
+    srcs = ["aql_command_buffer_test.cc"],
+    group = "iree-hal-drivers-amdgpu-tests",
+    tags = [
+        "driver=amdgpu",
+        "nodocker",
+    ],
+    deps = [
+        ":amdgpu",
+        ":headers",
+        "//runtime/src/iree/base",
+        "//runtime/src/iree/base/internal:arena",
+        "//runtime/src/iree/hal",
+        "//runtime/src/iree/hal/drivers/amdgpu/abi",
+        "//runtime/src/iree/testing:gtest",
+        "//runtime/src/iree/testing:gtest_main",
+    ],
+)
+
+iree_runtime_cc_test(
+    name = "aql_program_builder_test",
+    srcs = ["aql_program_builder_test.cc"],
+    group = "iree-hal-drivers-amdgpu-tests",
+    tags = [
+        "driver=amdgpu",
+        "nodocker",
+    ],
+    deps = [
+        ":amdgpu",
+        ":headers",
+        "//runtime/src/iree/base",
+        "//runtime/src/iree/base/internal:arena",
+        "//runtime/src/iree/hal/drivers/amdgpu/abi",
+        "//runtime/src/iree/testing:gtest",
+        "//runtime/src/iree/testing:gtest_main",
+    ],
+)
+
+iree_runtime_cc_test(
+    name = "driver_options_test",
+    srcs = ["driver_options_test.cc"],
+    group = "iree-hal-drivers-amdgpu-tests",
+    deps = [
+        ":amdgpu",
+        "//runtime/src/iree/base",
+        "//runtime/src/iree/hal",
+        "//runtime/src/iree/testing:gtest",
+        "//runtime/src/iree/testing:gtest_main",
+    ],
+)
+
+iree_runtime_cc_test(
+    name = "executable_test",
+    srcs = ["executable_test.cc"],
+    group = "iree-hal-drivers-amdgpu-tests",
+    deps = [
+        ":amdgpu",
+        ":headers",
+        "//runtime/src/iree/base",
+        "//runtime/src/iree/base:core_headers",
+        "//runtime/src/iree/base/internal/flatcc:building",
+        "//runtime/src/iree/hal",
+        "//runtime/src/iree/schemas:amdgpu_executable_def_c_fbs",
+        "//runtime/src/iree/testing:gtest",
+        "//runtime/src/iree/testing:gtest_main",
+    ],
+)
+
+iree_runtime_cc_test(
+    name = "host_queue_command_buffer_test",
+    srcs = [
+        "host_queue_command_buffer_packet.h",
+        "host_queue_command_buffer_test.cc",
+    ],
+    group = "iree-hal-drivers-amdgpu-tests",
+    tags = [
+        "driver=amdgpu",
+        "nodocker",
+    ],
+    deps = [
+        ":amdgpu",
+        ":headers",
+        ":queue_affinity",
+        "//runtime/src/iree/base",
+        "//runtime/src/iree/hal",
+        "//runtime/src/iree/hal/cts/util:test_base",
+        "//runtime/src/iree/hal/drivers/amdgpu/cts:testdata_amdgpu",
+        "//runtime/src/iree/hal/drivers/amdgpu/util:packet",
+        "//runtime/src/iree/testing:gtest",
+        "//runtime/src/iree/testing:gtest_main",
+    ],
+)
+
+iree_runtime_cc_test(
+    name = "host_queue_pending_test",
+    srcs = ["host_queue_pending_test.cc"],
+    group = "iree-hal-drivers-amdgpu-tests",
+    tags = [
+        "driver=amdgpu",
+        "nodocker",
+    ],
+    deps = [
+        ":amdgpu",
+        ":headers",
+        "//runtime/src/iree/async",
+        "//runtime/src/iree/base",
+        "//runtime/src/iree/base/internal",
+        "//runtime/src/iree/base/threading",
+        "//runtime/src/iree/hal",
+        "//runtime/src/iree/hal/cts/util:test_base",
+        "//runtime/src/iree/hal/drivers/amdgpu/util:packet",
+        "//runtime/src/iree/hal/memory:fixed_block_pool",
+        "//runtime/src/iree/hal/memory:tlsf_pool",
+        "//runtime/src/iree/testing:gtest",
+        "//runtime/src/iree/testing:gtest_main",
+    ],
+)
+
+iree_runtime_cc_test(
+    name = "host_queue_submission_test",
+    srcs = ["host_queue_submission_test.cc"],
+    group = "iree-hal-drivers-amdgpu-tests",
+    tags = [
+        "driver=amdgpu",
+        "nodocker",
+    ],
+    deps = [
+        ":amdgpu",
+        ":headers",
+        "//runtime/src/iree/base",
+        "//runtime/src/iree/hal",
+        "//runtime/src/iree/hal/cts/util:test_base",
+        "//runtime/src/iree/hal/drivers/amdgpu/util:packet",
+        "//runtime/src/iree/testing:gtest",
+        "//runtime/src/iree/testing:gtest_main",
+    ],
+)
+
+iree_runtime_cc_test(
+    name = "host_queue_staging_test",
+    srcs = ["host_queue_staging_test.cc"],
+    group = "iree-hal-drivers-amdgpu-tests",
+    tags = [
+        "driver=amdgpu",
+        "nodocker",
+    ],
+    deps = [
+        ":amdgpu",
+        ":headers",
+        ":queue_affinity",
+        "//runtime/src/iree/base",
+        "//runtime/src/iree/hal",
+        "//runtime/src/iree/hal/cts/util:test_base",
+        "//runtime/src/iree/hal/drivers/amdgpu/util:packet",
+        "//runtime/src/iree/io:file_handle",
+        "//runtime/src/iree/testing:gtest",
+        "//runtime/src/iree/testing:gtest_main",
+    ],
+)
+
+iree_runtime_cc_test(
+    name = "physical_device_capabilities_test",
+    srcs = ["physical_device_capabilities_test.cc"],
+    group = "iree-hal-drivers-amdgpu-tests",
+    deps = [
+        ":amdgpu",
+        ":headers",
+        "//runtime/src/iree/base",
+        "//runtime/src/iree/hal",
+        "//runtime/src/iree/testing:gtest",
+        "//runtime/src/iree/testing:gtest_main",
+    ],
+)
+
+iree_runtime_cc_test(
+    name = "profile_metadata_test",
+    srcs = ["profile_metadata_test.cc"],
+    group = "iree-hal-drivers-amdgpu-tests",
+    deps = [
+        ":amdgpu",
+        ":headers",
+        "//runtime/src/iree/base",
+        "//runtime/src/iree/hal",
+        "//runtime/src/iree/testing:gtest",
+        "//runtime/src/iree/testing:gtest_main",
+    ],
+)
+
+iree_runtime_cc_test(
+    name = "profile_events_test",
+    srcs = ["profile_events_test.cc"],
+    group = "iree-hal-drivers-amdgpu-tests",
+    deps = [
+        ":profile_events",
+        "//runtime/src/iree/base",
+        "//runtime/src/iree/hal",
+        "//runtime/src/iree/testing:gtest",
+        "//runtime/src/iree/testing:gtest_main",
+    ],
+)
+
+iree_runtime_cc_test(
+    name = "semaphore_test",
+    srcs = ["semaphore_test.cc"],
+    group = "iree-hal-drivers-amdgpu-tests",
+    tags = [
+        "driver=amdgpu",
+        "nodocker",
+    ],
+    deps = [
+        ":amdgpu",
+        ":headers",
+        "//runtime/src/iree/async",
+        "//runtime/src/iree/async:platform",
+        "//runtime/src/iree/base",
+        "//runtime/src/iree/hal",
+        "//runtime/src/iree/testing:gtest",
+        "//runtime/src/iree/testing:gtest_main",
+    ],
+)
+
+iree_runtime_cc_test(
+    name = "slab_provider_test",
+    srcs = ["slab_provider_test.cc"],
+    group = "iree-hal-drivers-amdgpu-tests",
+    tags = [
+        "driver=amdgpu",
+        "nodocker",
+    ],
+    deps = [
+        ":amdgpu",
+        ":headers",
+        "//runtime/src/iree/async",
+        "//runtime/src/iree/base",
+        "//runtime/src/iree/hal",
+        "//runtime/src/iree/hal/cts/util:test_base",
         "//runtime/src/iree/testing:gtest",
         "//runtime/src/iree/testing:gtest_main",
     ],
diff --git a/runtime/src/iree/hal/drivers/amdgpu/CMakeLists.txt b/runtime/src/iree/hal/drivers/amdgpu/CMakeLists.txt
index 332f2c9..18f6ba9 100644
--- a/runtime/src/iree/hal/drivers/amdgpu/CMakeLists.txt
+++ b/runtime/src/iree/hal/drivers/amdgpu/CMakeLists.txt
@@ -12,64 +12,188 @@
 
 iree_cc_library(
   NAME
+    queue_affinity
+  HDRS
+    "queue_affinity.h"
+  SRCS
+    "queue_affinity.c"
+  DEPS
+    iree::base
+    iree::hal
+  PUBLIC
+)
+
+iree_cc_library(
+  NAME
+    access_policy
+  HDRS
+    "access_policy.h"
+  SRCS
+    "access_policy.c"
+  DEPS
+    ::queue_affinity
+    iree::base
+    iree::hal
+    iree::hal::drivers::amdgpu::util::libhsa
+    iree::hal::drivers::amdgpu::util::topology
+  PUBLIC
+)
+
+iree_cc_library(
+  NAME
+    profile_events
+  HDRS
+    "profile_events.h"
+  SRCS
+    "profile_events.c"
+  DEPS
+    iree::base
+    iree::base::threading
+    iree::hal
+    iree::hal::utils::profile_event_ring
+  PUBLIC
+)
+
+iree_cc_library(
+  NAME
     amdgpu
   HDRS
     "api.h"
   SRCS
     "allocator.c"
     "allocator.h"
+    "aql_block_processor.c"
+    "aql_block_processor.h"
+    "aql_block_processor_profile.c"
+    "aql_block_processor_profile.h"
+    "aql_block_processor_timestamp.c"
+    "aql_block_processor_timestamp.h"
+    "aql_command_buffer.c"
+    "aql_command_buffer.h"
+    "aql_command_buffer_profile.c"
+    "aql_command_buffer_profile.h"
+    "aql_prepublished_kernarg_storage.h"
+    "aql_program_builder.c"
+    "aql_program_builder.h"
+    "aql_program_validation.c"
+    "aql_program_validation.h"
     "buffer.c"
     "buffer.h"
-    "buffer_pool.c"
-    "buffer_pool.h"
-    "channel.c"
-    "channel.h"
-    "command_buffer.c"
-    "command_buffer.h"
-    "device_queue.c"
-    "device_queue.h"
     "driver.c"
     "driver.h"
-    "event.c"
-    "event.h"
     "executable.c"
     "executable.h"
     "executable_cache.c"
     "executable_cache.h"
     "host_queue.c"
     "host_queue.h"
-    "host_service.c"
-    "host_service.h"
+    "host_queue_blit.c"
+    "host_queue_blit.h"
+    "host_queue_command_buffer.c"
+    "host_queue_command_buffer.h"
+    "host_queue_command_buffer_block.c"
+    "host_queue_command_buffer_block.h"
+    "host_queue_command_buffer_packet.c"
+    "host_queue_command_buffer_packet.h"
+    "host_queue_command_buffer_profile.c"
+    "host_queue_command_buffer_profile.h"
+    "host_queue_command_buffer_replay.c"
+    "host_queue_command_buffer_replay.h"
+    "host_queue_command_buffer_scratch.h"
+    "host_queue_dispatch.c"
+    "host_queue_dispatch.h"
+    "host_queue_file.c"
+    "host_queue_file.h"
+    "host_queue_host_call.c"
+    "host_queue_host_call.h"
+    "host_queue_memory.c"
+    "host_queue_memory.h"
+    "host_queue_pending.c"
+    "host_queue_pending.h"
+    "host_queue_pending_operation.h"
+    "host_queue_pending_payload.c"
+    "host_queue_policy.c"
+    "host_queue_policy.h"
+    "host_queue_profile.c"
+    "host_queue_profile.h"
+    "host_queue_profile_events.c"
+    "host_queue_profile_events.h"
+    "host_queue_staging.c"
+    "host_queue_staging.h"
+    "host_queue_submission.c"
+    "host_queue_submission.h"
+    "host_queue_timestamp.c"
+    "host_queue_timestamp.h"
+    "host_queue_waits.c"
+    "host_queue_waits.h"
     "logical_device.c"
     "logical_device.h"
     "physical_device.c"
     "physical_device.h"
+    "physical_device_capabilities.c"
+    "physical_device_capabilities.h"
+    "profile_aqlprofile.c"
+    "profile_aqlprofile.h"
+    "profile_counters.c"
+    "profile_counters.h"
+    "profile_device_metrics.c"
+    "profile_device_metrics.h"
+    "profile_device_metrics_linux.c"
+    "profile_device_metrics_source.h"
+    "profile_metadata.c"
+    "profile_metadata.h"
+    "profile_traces.c"
+    "profile_traces.h"
     "semaphore.c"
     "semaphore.h"
-    "semaphore_pool.c"
-    "semaphore_pool.h"
+    "slab_provider.c"
+    "slab_provider.h"
     "system.c"
     "system.h"
-    "virtual_queue.c"
+    "transient_buffer.c"
+    "transient_buffer.h"
     "virtual_queue.h"
   DEPS
+    ::access_policy
+    ::profile_events
+    ::queue_affinity
     iree::async
     iree::async::util::proactor_pool
     iree::base
+    iree::base::core_headers
     iree::base::internal
     iree::base::internal::arena
     iree::base::internal::flatcc::parsing
     iree::base::threading
     iree::base::threading::thread
     iree::hal
-    iree::hal::drivers::amdgpu::device::binaries
+    iree::hal::drivers::amdgpu::abi
+    iree::hal::drivers::amdgpu::device::binaries::toc
+    iree::hal::drivers::amdgpu::device::blit
+    iree::hal::drivers::amdgpu::device::dispatch
     iree::hal::drivers::amdgpu::device::headers
-    iree::hal::drivers::amdgpu::util
+    iree::hal::drivers::amdgpu::device::timestamp
+    iree::hal::drivers::amdgpu::util::block_pool
+    iree::hal::drivers::amdgpu::util::code_object_target
+    iree::hal::drivers::amdgpu::util::device_clock
+    iree::hal::drivers::amdgpu::util::device_library
+    iree::hal::drivers::amdgpu::util::hsaco_metadata
+    iree::hal::drivers::amdgpu::util::info
+    iree::hal::drivers::amdgpu::util::kfd
+    iree::hal::drivers::amdgpu::util::libaqlprofile
     iree::hal::drivers::amdgpu::util::libhsa
+    iree::hal::drivers::amdgpu::util::packet
+    iree::hal::drivers::amdgpu::util::queue_primitives
+    iree::hal::drivers::amdgpu::util::target_id
     iree::hal::drivers::amdgpu::util::topology
+    iree::hal::drivers::amdgpu::util::vmem
+    iree::hal::memory::passthrough_pool
+    iree::hal::memory::slab_provider
+    iree::hal::memory::tlsf_pool
+    iree::hal::memory::tracing
+    iree::hal::utils::elf_format
     iree::hal::utils::executable_debug_info
     iree::hal::utils::executable_header
-    iree::hal::utils::file_transfer
     iree::hal::utils::files
     iree::hal::utils::resource_set
     iree::schemas::amdgpu_executable_def_c_fbs
@@ -81,71 +205,116 @@
   NAME
     headers
   HDRS
-    "allocator.h"
+    "aql_block_processor.h"
+    "aql_block_processor_timestamp.h"
+    "aql_command_buffer.h"
+    "aql_command_buffer_profile.h"
+    "aql_prepublished_kernarg_storage.h"
+    "aql_program_builder.h"
+    "aql_program_validation.h"
     "buffer.h"
-    "buffer_pool.h"
-    "channel.h"
-    "command_buffer.h"
-    "device_queue.h"
     "driver.h"
-    "event.h"
     "executable.h"
     "executable_cache.h"
     "host_queue.h"
-    "host_service.h"
+    "host_queue_blit.h"
+    "host_queue_command_buffer.h"
+    "host_queue_command_buffer_block.h"
+    "host_queue_command_buffer_profile.h"
+    "host_queue_command_buffer_replay.h"
+    "host_queue_dispatch.h"
+    "host_queue_file.h"
+    "host_queue_host_call.h"
+    "host_queue_memory.h"
+    "host_queue_pending.h"
+    "host_queue_policy.h"
+    "host_queue_profile.h"
+    "host_queue_profile_events.h"
+    "host_queue_staging.h"
+    "host_queue_submission.h"
+    "host_queue_timestamp.h"
+    "host_queue_waits.h"
     "logical_device.h"
     "physical_device.h"
+    "physical_device_capabilities.h"
+    "profile_aqlprofile.h"
+    "profile_counters.h"
+    "profile_device_metrics.h"
+    "profile_events.h"
+    "profile_metadata.h"
+    "profile_traces.h"
     "semaphore.h"
-    "semaphore_pool.h"
+    "slab_provider.h"
     "system.h"
+    "transient_buffer.h"
+    "virtual_queue.h"
   DEPS
+    ::profile_events
+    ::queue_affinity
     iree::base
     iree::base::internal
     iree::base::internal::arena
+    iree::base::threading
     iree::hal
-    iree::hal::drivers::amdgpu::device::binaries
+    iree::hal::drivers::amdgpu::abi
+    iree::hal::drivers::amdgpu::device::binaries::toc
+    iree::hal::drivers::amdgpu::device::dispatch
     iree::hal::drivers::amdgpu::device::headers
-    iree::hal::drivers::amdgpu::util
+    iree::hal::drivers::amdgpu::util::block_pool
+    iree::hal::drivers::amdgpu::util::device_library
+    iree::hal::drivers::amdgpu::util::info
     iree::hal::drivers::amdgpu::util::libhsa
+    iree::hal::drivers::amdgpu::util::packet
+    iree::hal::drivers::amdgpu::util::queue_primitives
+    iree::hal::drivers::amdgpu::util::target_id
     iree::hal::drivers::amdgpu::util::topology
+    iree::hal::drivers::amdgpu::util::vmem
+    iree::hal::memory::slab_provider
+    iree::hal::memory::tracing
   PUBLIC
 )
 
 iree_cc_test(
   NAME
-    buffer_pool_test
+    access_policy_test
   SRCS
-    "buffer_pool_test.cc"
+    "access_policy_test.cc"
   DEPS
-    ::amdgpu
-    ::headers
+    ::access_policy
     iree::base
-    iree::hal::drivers::amdgpu::device::headers
-    iree::hal::drivers::amdgpu::util
-    iree::hal::drivers::amdgpu::util::libhsa
-    iree::hal::drivers::amdgpu::util::topology
+    iree::hal
     iree::testing::gtest
     iree::testing::gtest_main
-  LABELS
-    "driver=amdgpu"
-    "nodocker"
   GROUP
     "iree-hal-drivers-amdgpu-tests"
 )
 
 iree_cc_test(
   NAME
-    host_service_test
+    queue_affinity_test
   SRCS
-    "host_service_test.cc"
+    "queue_affinity_test.cc"
+  DEPS
+    ::queue_affinity
+    iree::base
+    iree::hal
+    iree::testing::gtest
+    iree::testing::gtest_main
+  GROUP
+    "iree-hal-drivers-amdgpu-tests"
+)
+
+iree_cc_test(
+  NAME
+    allocator_test
+  SRCS
+    "allocator_test.cc"
   DEPS
     ::amdgpu
     ::headers
     iree::base
     iree::hal
-    iree::hal::drivers::amdgpu::device::headers
-    iree::hal::drivers::amdgpu::util
-    iree::hal::drivers::amdgpu::util::libhsa
+    iree::hal::cts::util::test_base
     iree::hal::drivers::amdgpu::util::topology
     iree::testing::gtest
     iree::testing::gtest_main
@@ -158,17 +327,297 @@
 
 iree_cc_test(
   NAME
-    semaphore_pool_test
+    aql_block_processor_test
   SRCS
-    "semaphore_pool_test.cc"
+    "aql_block_processor_test.cc"
   DEPS
     ::amdgpu
     ::headers
     iree::base
+    iree::base::internal::arena
+    iree::hal
+    iree::hal::drivers::amdgpu::abi
     iree::hal::drivers::amdgpu::device::headers
-    iree::hal::drivers::amdgpu::util
-    iree::hal::drivers::amdgpu::util::libhsa
-    iree::hal::drivers::amdgpu::util::topology
+    iree::hal::drivers::amdgpu::util::packet
+    iree::testing::gtest
+    iree::testing::gtest_main
+  LABELS
+    "driver=amdgpu"
+    "nodocker"
+  GROUP
+    "iree-hal-drivers-amdgpu-tests"
+)
+
+iree_cc_test(
+  NAME
+    aql_block_processor_timestamp_test
+  SRCS
+    "aql_block_processor_timestamp_test.cc"
+  DEPS
+    ::amdgpu
+    ::headers
+    iree::base
+    iree::hal::drivers::amdgpu::abi
+    iree::hal::drivers::amdgpu::device::headers
+    iree::hal::drivers::amdgpu::util::packet
+    iree::testing::gtest
+    iree::testing::gtest_main
+  LABELS
+    "driver=amdgpu"
+    "nodocker"
+  GROUP
+    "iree-hal-drivers-amdgpu-tests"
+)
+
+iree_cc_test(
+  NAME
+    aql_command_buffer_test
+  SRCS
+    "aql_command_buffer_test.cc"
+  DEPS
+    ::amdgpu
+    ::headers
+    iree::base
+    iree::base::internal::arena
+    iree::hal
+    iree::hal::drivers::amdgpu::abi
+    iree::testing::gtest
+    iree::testing::gtest_main
+  LABELS
+    "driver=amdgpu"
+    "nodocker"
+  GROUP
+    "iree-hal-drivers-amdgpu-tests"
+)
+
+iree_cc_test(
+  NAME
+    aql_program_builder_test
+  SRCS
+    "aql_program_builder_test.cc"
+  DEPS
+    ::amdgpu
+    ::headers
+    iree::base
+    iree::base::internal::arena
+    iree::hal::drivers::amdgpu::abi
+    iree::testing::gtest
+    iree::testing::gtest_main
+  LABELS
+    "driver=amdgpu"
+    "nodocker"
+  GROUP
+    "iree-hal-drivers-amdgpu-tests"
+)
+
+iree_cc_test(
+  NAME
+    driver_options_test
+  SRCS
+    "driver_options_test.cc"
+  DEPS
+    ::amdgpu
+    iree::base
+    iree::hal
+    iree::testing::gtest
+    iree::testing::gtest_main
+  GROUP
+    "iree-hal-drivers-amdgpu-tests"
+)
+
+iree_cc_test(
+  NAME
+    executable_test
+  SRCS
+    "executable_test.cc"
+  DEPS
+    ::amdgpu
+    ::headers
+    iree::base
+    iree::base::core_headers
+    iree::base::internal::flatcc::building
+    iree::hal
+    iree::schemas::amdgpu_executable_def_c_fbs
+    iree::testing::gtest
+    iree::testing::gtest_main
+  GROUP
+    "iree-hal-drivers-amdgpu-tests"
+)
+
+iree_cc_test(
+  NAME
+    host_queue_command_buffer_test
+  SRCS
+    "host_queue_command_buffer_packet.h"
+    "host_queue_command_buffer_test.cc"
+  DEPS
+    ::amdgpu
+    ::headers
+    ::queue_affinity
+    iree::base
+    iree::hal
+    iree::hal::cts::util::test_base
+    iree::hal::drivers::amdgpu::cts::testdata_amdgpu
+    iree::hal::drivers::amdgpu::util::packet
+    iree::testing::gtest
+    iree::testing::gtest_main
+  LABELS
+    "driver=amdgpu"
+    "nodocker"
+  GROUP
+    "iree-hal-drivers-amdgpu-tests"
+)
+
+iree_cc_test(
+  NAME
+    host_queue_pending_test
+  SRCS
+    "host_queue_pending_test.cc"
+  DEPS
+    ::amdgpu
+    ::headers
+    iree::async
+    iree::base
+    iree::base::internal
+    iree::base::threading
+    iree::hal
+    iree::hal::cts::util::test_base
+    iree::hal::drivers::amdgpu::util::packet
+    iree::hal::memory::fixed_block_pool
+    iree::hal::memory::tlsf_pool
+    iree::testing::gtest
+    iree::testing::gtest_main
+  LABELS
+    "driver=amdgpu"
+    "nodocker"
+  GROUP
+    "iree-hal-drivers-amdgpu-tests"
+)
+
+iree_cc_test(
+  NAME
+    host_queue_submission_test
+  SRCS
+    "host_queue_submission_test.cc"
+  DEPS
+    ::amdgpu
+    ::headers
+    iree::base
+    iree::hal
+    iree::hal::cts::util::test_base
+    iree::hal::drivers::amdgpu::util::packet
+    iree::testing::gtest
+    iree::testing::gtest_main
+  LABELS
+    "driver=amdgpu"
+    "nodocker"
+  GROUP
+    "iree-hal-drivers-amdgpu-tests"
+)
+
+iree_cc_test(
+  NAME
+    host_queue_staging_test
+  SRCS
+    "host_queue_staging_test.cc"
+  DEPS
+    ::amdgpu
+    ::headers
+    ::queue_affinity
+    iree::base
+    iree::hal
+    iree::hal::cts::util::test_base
+    iree::hal::drivers::amdgpu::util::packet
+    iree::io::file_handle
+    iree::testing::gtest
+    iree::testing::gtest_main
+  LABELS
+    "driver=amdgpu"
+    "nodocker"
+  GROUP
+    "iree-hal-drivers-amdgpu-tests"
+)
+
+iree_cc_test(
+  NAME
+    physical_device_capabilities_test
+  SRCS
+    "physical_device_capabilities_test.cc"
+  DEPS
+    ::amdgpu
+    ::headers
+    iree::base
+    iree::hal
+    iree::testing::gtest
+    iree::testing::gtest_main
+  GROUP
+    "iree-hal-drivers-amdgpu-tests"
+)
+
+iree_cc_test(
+  NAME
+    profile_metadata_test
+  SRCS
+    "profile_metadata_test.cc"
+  DEPS
+    ::amdgpu
+    ::headers
+    iree::base
+    iree::hal
+    iree::testing::gtest
+    iree::testing::gtest_main
+  GROUP
+    "iree-hal-drivers-amdgpu-tests"
+)
+
+iree_cc_test(
+  NAME
+    profile_events_test
+  SRCS
+    "profile_events_test.cc"
+  DEPS
+    ::profile_events
+    iree::base
+    iree::hal
+    iree::testing::gtest
+    iree::testing::gtest_main
+  GROUP
+    "iree-hal-drivers-amdgpu-tests"
+)
+
+iree_cc_test(
+  NAME
+    semaphore_test
+  SRCS
+    "semaphore_test.cc"
+  DEPS
+    ::amdgpu
+    ::headers
+    iree::async
+    iree::async::platform
+    iree::base
+    iree::hal
+    iree::testing::gtest
+    iree::testing::gtest_main
+  LABELS
+    "driver=amdgpu"
+    "nodocker"
+  GROUP
+    "iree-hal-drivers-amdgpu-tests"
+)
+
+iree_cc_test(
+  NAME
+    slab_provider_test
+  SRCS
+    "slab_provider_test.cc"
+  DEPS
+    ::amdgpu
+    ::headers
+    iree::async
+    iree::base
+    iree::hal
+    iree::hal::cts::util::test_base
     iree::testing::gtest
     iree::testing::gtest_main
   LABELS
diff --git a/runtime/src/iree/hal/drivers/amdgpu/README.md b/runtime/src/iree/hal/drivers/amdgpu/README.md
index be129c1..05d1861 100644
--- a/runtime/src/iree/hal/drivers/amdgpu/README.md
+++ b/runtime/src/iree/hal/drivers/amdgpu/README.md
@@ -12,19 +12,28 @@
 -DIREE_BUILD_COMPILER=ON
 -DIREE_TARGET_BACKEND_ROCM=ON
 -DIREE_HAL_DRIVER_AMDGPU=ON
--DIREE_HAL_AMDGPU_DEVICE_LIBRARY_TARGETS=gfx1100
+-DIREE_HAL_AMDGPU_DEVICE_LIBRARY_TARGETS=all
 -DIREE_ROCM_TEST_TARGET_CHIP=gfx1100
 ```
 
 ### Bazel
 
+Build the tools with the AMDGPU runtime driver registered and with device
+artifacts compiled for your local GPU architecture:
+
+```sh
+iree-bazel-build //tools:iree-compile //tools:iree-run-module \
+  --iree_drivers=amdgpu,cuda,hip,local-sync,local-task,vulkan \
+  --//build_tools/bazel:rocm_test_target=gfx1100
+```
+
 The ROCM chip target defaults to `gfx1100`. Override for your hardware:
 
 ```sh
 iree-bazel-test --//build_tools/bazel:rocm_test_target=gfx942 //runtime/src/iree/hal/drivers/amdgpu/cts/...
 ```
 
-Substitute the architecture with your own. See [therock_amdgpu_targets.cmake](https://github.com/ROCm/TheRock/blob/main/cmake/therock_amdgpu_targets.cmake#L44) for a list of common targets. Future changes will include family support matching that file.
+Substitute your own architecture. See [therock_amdgpu_targets.cmake](https://github.com/ROCm/TheRock/blob/main/cmake/therock_amdgpu_targets.cmake) for the exact-target and generic-family vocabulary that the embedded device library build mirrors.
 
 Use `amdgpu` to specify devices at runtime:
 
@@ -49,6 +58,104 @@
 iree-compile --iree-hal-target-device=amdgpu ...
 ```
 
+For a direct Bazel-built smoke test, compile with the AMDGPU target device and
+run the resulting VMFB with the AMDGPU HAL driver:
+
+```sh
+bazel-bin/tools/iree-compile \
+  --iree-input-type=stablehlo \
+  --iree-hal-target-device=amdgpu \
+  --iree-rocm-target=gfx1100 \
+  --iree-rocm-bc-dir=bazel-bin/external/_main~iree_extension~amdgpu_device_libs/bitcode \
+  tests/e2e/stablehlo_models/mnist_fake_weights.mlir \
+  -o=/tmp/mnist_fake_amdgpu.vmfb
+
+bazel-bin/tools/iree-run-module \
+  --device=amdgpu \
+  --module=/tmp/mnist_fake_amdgpu.vmfb \
+  --function=predict \
+  --input=1x28x28x1xf32 \
+  --expected_output='1x10xf32=0.1 0.1 0.1 0.1 0.1 0.1 0.1 0.1 0.1 0.1'
+```
+
+Prefer this explicit two-step flow when debugging AMDGPU target-device behavior.
+`iree-run-mlir --device=amdgpu` currently relies on generic device-to-compiler
+flag inference and may select the legacy ROCm/HIP target path instead of the
+AMDGPU HAL target.
+
+To capture a Tracy trace of the same runtime path, use the Bazel trace wrapper.
+The wrapper requires a `tracy-capture` binary in `PATH`, or one supplied through
+`IREE_TRACY_CAPTURE`:
+
+```sh
+IREE_TRACY_CAPTURE=/path/to/tracy-capture \
+  build_tools/bin/iree-bazel-run \
+  --trace \
+  --trace_name=fake_mnist \
+  //tools:iree-run-module \
+  --iree_drivers=amdgpu,cuda,hip,local-sync,local-task,vulkan \
+  --//build_tools/bazel:rocm_test_target=gfx1100 \
+  -- \
+  --device=amdgpu \
+  --module=/tmp/mnist_fake_amdgpu.vmfb \
+  --function=predict \
+  --input=1x28x28x1xf32 \
+  --expected_output='1x10xf32=0.1 0.1 0.1 0.1 0.1 0.1 0.1 0.1 0.1 0.1'
+```
+
+## Driver Shape
+
+The AMDGPU HAL driver is a native HSA/ROCR backend. It does not route through
+HIP or the legacy ROCm HAL driver. The major runtime objects are:
+
+* a driver that discovers HSA agents and creates logical devices;
+* one logical device spanning one or more physical GPU devices;
+* one physical-device object per HSA GPU agent, including queues, memory pools,
+  executable caches, profiling state, and device metrics;
+* host queues that translate HAL queue operations into AQL packet streams;
+* replayable command buffers that store backend command records and emit AQL
+  packets at submission time; and
+* device libraries embedded in the runtime and selected for the target GPU ISA
+  at device creation.
+
+Normal execution depends only on the HSA runtime and the embedded device
+library. Optional profiling modes dynamically load ROCm profiling libraries,
+and only when the selected mode needs them.
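+
+A minimal host-side bring-up sketch using the public HAL C API (error handling
+abbreviated; a build with the driver registered is assumed):
+
+```c
+#include "iree/hal/api.h"
+
+iree_status_t create_amdgpu_device(iree_hal_device_t** out_device) {
+  // Instantiate the driver from the default registry; this discovers the
+  // available HSA agents through ROCR.
+  iree_hal_driver_t* driver = NULL;
+  IREE_RETURN_IF_ERROR(iree_hal_driver_registry_try_create(
+      iree_hal_driver_registry_default(), iree_make_cstring_view("amdgpu"),
+      iree_allocator_system(), &driver));
+  // One logical device spanning the visible physical GPU agents.
+  iree_status_t status = iree_hal_driver_create_default_device(
+      driver, iree_allocator_system(), out_device);
+  iree_hal_driver_release(driver);
+  return status;
+}
+```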
+
+## Profiling and Replay
+
+The AMDGPU driver is one of the primary producers for IREE's HAL-native
+profiling and replay tools:
+
+* `--device_profiling_mode=queue-events,device-queue-events,dispatch-events`
+  captures queue submissions, device-side queue timing, and per-dispatch timing
+  into a `.ireeprof` bundle.
+* `--device_profiling_mode=executable-metadata,counters` adds executable export
+  metadata and selected hardware/software counters.
+* `--device_profiling_mode=executable-traces` captures heavy ATT/SQTT artifacts
+  for filtered dispatches.
+* `--device_replay_output=/tmp/model.ireereplay` records a HAL-level replay
+  stream that can be run, benchmarked, and profiled independently of the
+  original application.
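+
+For example, assuming the existing `--device_profiling_file` sink flag, a
+combined timing capture over the earlier smoke test looks like:
+
+```sh
+iree-run-module \
+  --device=amdgpu \
+  --device_profiling_mode=queue-events,device-queue-events,dispatch-events \
+  --device_profiling_file=/tmp/model.ireeprof \
+  --module=/tmp/mnist_fake_amdgpu.vmfb \
+  --function=predict \
+  --input=1x28x28x1xf32
+```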
+
+Useful inspection commands:
+
+```sh
+iree-profile summary /tmp/model.ireeprof
+iree-profile dispatch --format=jsonl /tmp/model.ireeprof
+iree-profile export --format=ireeperf-jsonl \
+  --output=/tmp/model.ireeperf.jsonl /tmp/model.ireeprof
+uvx --with perfetto --with protobuf python "$(command -v iree-profile-render)" \
+  --format=perfetto \
+  /tmp/model.ireeperf.jsonl -o /tmp/model.pftrace
+iree-profile att --rocm_library_path=/opt/rocm/lib /tmp/model-att.ireeprof
+```
+
+See the website documentation for the full workflows:
+
+* [Device profiling](../../../../../../docs/website/docs/developers/performance/device-profiling.md)
+* [Device replay](../../../../../../docs/website/docs/developers/performance/device-replay.md)
+
 ## Build Notes
 
 ### HSA/ROCR Dependency
@@ -61,6 +168,23 @@
 
 See [HSA/ROCR Library](#hsarocr-library) for more information on our usage.
 
+### ROCm Profiling Dependencies
+
+Counter and ATT/SQTT capture use ROCm's aqlprofile library through a small
+dynamic-loader shim. Normal execution, queue timing, dispatch timing, replay,
+statistics, and Perfetto export do not require this library.
+
+When `counters` or `executable-traces` profiling is requested, the driver looks
+for an aqlprofile-compatible library in this order:
+
+* `IREE_HAL_AMDGPU_LIBAQLPROFILE_PATH`;
+* a library adjacent to the loaded HSA runtime; and
+* the platform dynamic-library search path.
+
+The `iree-profile att` decoder also needs ROCm decode libraries. Pass
+`--rocm_library_path=/opt/rocm/lib`, set `IREE_HAL_AMDGPU_LIBAQLPROFILE_PATH`,
+or rely on the platform search path.
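+
+For example (the library filename is the typical ROCm spelling and may differ
+per install; counter-set selection flags are elided):
+
+```sh
+IREE_HAL_AMDGPU_LIBAQLPROFILE_PATH=/opt/rocm/lib/libhsa-amd-aqlprofile64.so \
+  iree-run-module --device=amdgpu --device_profiling_mode=counters ...
+```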
+
 ### Device Library Compilation
 
 **Required CMake Options**: `-DIREE_BUILD_COMPILER=ON -DIREE_TARGET_BACKEND_ROCM=ON`
@@ -71,4 +195,14 @@
 
 The device library should be compiled automatically when building the AMDGPU HAL driver and gets embedded inside the runtime binary so that no additional files are required at runtime.
 
-The `IREE_HAL_AMDGPU_DEVICE_LIBRARY_TARGETS` CMake variable can be set to a list of target architectures to build the library for and bundle into the AMDGPU HAL library. Architectures not built into the library will fail to instantiate the driver at runtime.
+The `IREE_HAL_AMDGPU_DEVICE_LIBRARY_TARGETS` CMake variable defaults to `all`, which embeds LLVM generic ISA code objects covering every currently known AMDGPU device library target. Packagers can set it to a smaller list of selectors:
+
+* exact target architectures in the HSA ISA spelling, such as `gfx1100`;
+* LLVM generic ISA targets, such as `gfx11-generic`;
+* TheRock-style generic target families, such as `gfx110X-all`; and
+* TheRock-style product bundles, such as `dgpu-all` or `igpu-all`.
+
+Each selector expands to the smallest known compatible code object set instead of one code object per exact GPU. Creating the driver at runtime fails for architectures that were not built into the library.
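+
+For example, a packager targeting MI300-class and RDNA3 hardware might set:
+
+```
+-DIREE_HAL_AMDGPU_DEVICE_LIBRARY_TARGETS="gfx942;gfx110X-all"
+```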
+
+The Bazel build exposes the same selector vocabulary through `//runtime/src/iree/hal/drivers/amdgpu/device/binaries:targets`:
+
+```sh
+iree-bazel-build --//runtime/src/iree/hal/drivers/amdgpu/device/binaries:targets=igpu-all //runtime/src/iree/hal/drivers/amdgpu:amdgpu
+```
+
+See [`device/binaries/README.md`](device/binaries/README.md) for the target map
+update flow, the generated Bazel/CMake/runtime fragments, and the TheRock/LLVM
+sources that should be checked when adding support for a new architecture.
diff --git a/runtime/src/iree/hal/drivers/amdgpu/abi/BUILD.bazel b/runtime/src/iree/hal/drivers/amdgpu/abi/BUILD.bazel
new file mode 100644
index 0000000..cbf7618
--- /dev/null
+++ b/runtime/src/iree/hal/drivers/amdgpu/abi/BUILD.bazel
@@ -0,0 +1,29 @@
+# Copyright 2025 The IREE Authors
+#
+# Licensed under the Apache License v2.0 with LLVM Exceptions.
+# See https://llvm.org/LICENSE.txt for license information.
+# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+
+load("//build_tools/bazel:build_defs.oss.bzl", "iree_runtime_cc_library")
+
+package(
+    default_visibility = ["//visibility:public"],
+    features = ["layering_check"],
+    licenses = ["notice"],  # Apache 2.0
+)
+
+iree_runtime_cc_library(
+    name = "abi",
+    hdrs = [
+        "command_buffer.h",
+        "common.h",
+        "kernel_args.h",
+        "profile.h",
+        "queue.h",
+        "signal.h",
+        "timestamp.h",
+    ],
+    deps = [
+        "@hsa_runtime_headers",
+    ],
+)
diff --git a/runtime/src/iree/hal/drivers/amdgpu/abi/CMakeLists.txt b/runtime/src/iree/hal/drivers/amdgpu/abi/CMakeLists.txt
new file mode 100644
index 0000000..0d014f9
--- /dev/null
+++ b/runtime/src/iree/hal/drivers/amdgpu/abi/CMakeLists.txt
@@ -0,0 +1,29 @@
+################################################################################
+# Autogenerated by build_tools/bazel_to_cmake/bazel_to_cmake.py from           #
+# runtime/src/iree/hal/drivers/amdgpu/abi/BUILD.bazel                          #
+#                                                                              #
+# Use iree_cmake_extra_content from iree/build_defs.oss.bzl to add arbitrary   #
+# CMake-only content.                                                          #
+#                                                                              #
+# To disable autogeneration for this file entirely, delete this header.        #
+################################################################################
+
+iree_add_all_subdirs()
+
+iree_cc_library(
+  NAME
+    abi
+  HDRS
+    "command_buffer.h"
+    "common.h"
+    "kernel_args.h"
+    "profile.h"
+    "queue.h"
+    "signal.h"
+    "timestamp.h"
+  DEPS
+    hsa_runtime::headers
+  PUBLIC
+)
+
+### BAZEL_TO_CMAKE_PRESERVES_ALL_CONTENT_BELOW_THIS_LINE ###
diff --git a/runtime/src/iree/hal/drivers/amdgpu/abi/command_buffer.h b/runtime/src/iree/hal/drivers/amdgpu/abi/command_buffer.h
new file mode 100644
index 0000000..021c50a
--- /dev/null
+++ b/runtime/src/iree/hal/drivers/amdgpu/abi/command_buffer.h
@@ -0,0 +1,492 @@
+// Copyright 2026 The IREE Authors
+//
+// Licensed under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+
+#ifndef IREE_HAL_DRIVERS_AMDGPU_ABI_COMMAND_BUFFER_H_
+#define IREE_HAL_DRIVERS_AMDGPU_ABI_COMMAND_BUFFER_H_
+
+#include "iree/hal/drivers/amdgpu/abi/common.h"
+
+#ifdef __cplusplus
+extern "C" {
+#endif  // __cplusplus
+
+//===----------------------------------------------------------------------===//
+// Command Buffer Program ABI
+//===----------------------------------------------------------------------===//
+
+enum {
+  // Magic value stored in every command-buffer block header.
+  IREE_HAL_AMDGPU_COMMAND_BUFFER_BLOCK_MAGIC = 0x444D4342u,
+  // Version of the block ABI defined in this header.
+  IREE_HAL_AMDGPU_COMMAND_BUFFER_BLOCK_VERSION_0 = 0,
+  // Required alignment for all command records and binding source records.
+  IREE_HAL_AMDGPU_COMMAND_BUFFER_RECORD_ALIGNMENT = 8,
+};
+
+// Opcodes in the AMDGPU command-buffer program.
+typedef enum iree_hal_amdgpu_command_buffer_opcode_e {
+  IREE_HAL_AMDGPU_COMMAND_BUFFER_OPCODE_INVALID = 0,
+  IREE_HAL_AMDGPU_COMMAND_BUFFER_OPCODE_BARRIER = 1,
+  IREE_HAL_AMDGPU_COMMAND_BUFFER_OPCODE_DISPATCH = 2,
+  IREE_HAL_AMDGPU_COMMAND_BUFFER_OPCODE_FILL = 3,
+  IREE_HAL_AMDGPU_COMMAND_BUFFER_OPCODE_COPY = 4,
+  IREE_HAL_AMDGPU_COMMAND_BUFFER_OPCODE_UPDATE = 5,
+  IREE_HAL_AMDGPU_COMMAND_BUFFER_OPCODE_PROFILE_MARKER = 6,
+  IREE_HAL_AMDGPU_COMMAND_BUFFER_OPCODE_BRANCH = 7,
+  IREE_HAL_AMDGPU_COMMAND_BUFFER_OPCODE_COND_BRANCH = 8,
+  IREE_HAL_AMDGPU_COMMAND_BUFFER_OPCODE_RETURN = 9,
+} iree_hal_amdgpu_command_buffer_opcode_t;
+
+// Command flags shared by all command records.
+typedef enum iree_hal_amdgpu_command_buffer_command_flag_bits_e {
+  IREE_HAL_AMDGPU_COMMAND_BUFFER_COMMAND_FLAG_NONE = 0u,
+  // The command writes queue-owned kernarg memory at replay time.
+  IREE_HAL_AMDGPU_COMMAND_BUFFER_COMMAND_FLAG_USES_QUEUE_KERNARGS = 1u << 0,
+  // The command's first payload packet must participate in the command-buffer
+  // execution dependency chain.
+  IREE_HAL_AMDGPU_COMMAND_BUFFER_COMMAND_FLAG_HAS_BARRIER = 1u << 1,
+} iree_hal_amdgpu_command_buffer_command_flag_bits_t;
+
+enum {
+  // First bit of the two-bit acquire fence scope field in command flags.
+  IREE_HAL_AMDGPU_COMMAND_BUFFER_COMMAND_ACQUIRE_SCOPE_SHIFT = 2,
+  // Bit mask of the two-bit acquire fence scope field in command flags.
+  IREE_HAL_AMDGPU_COMMAND_BUFFER_COMMAND_ACQUIRE_SCOPE_MASK = 0x0Cu,
+  // First bit of the two-bit release fence scope field in command flags.
+  IREE_HAL_AMDGPU_COMMAND_BUFFER_COMMAND_RELEASE_SCOPE_SHIFT = 4,
+  // Bit mask of the two-bit release fence scope field in command flags.
+  IREE_HAL_AMDGPU_COMMAND_BUFFER_COMMAND_RELEASE_SCOPE_MASK = 0x30u,
+};
+
+// Binding source flags used to form HAL dispatch kernarg pointer prefixes.
+typedef enum iree_hal_amdgpu_command_buffer_binding_source_flag_bits_e {
+  IREE_HAL_AMDGPU_COMMAND_BUFFER_BINDING_SOURCE_FLAG_NONE = 0u,
+  IREE_HAL_AMDGPU_COMMAND_BUFFER_BINDING_SOURCE_FLAG_DYNAMIC = 1u << 0,
+  IREE_HAL_AMDGPU_COMMAND_BUFFER_BINDING_SOURCE_FLAG_INDIRECT_PARAMETERS =
+      1u << 1,
+  IREE_HAL_AMDGPU_COMMAND_BUFFER_BINDING_SOURCE_FLAG_STATIC_BUFFER = 1u << 2,
+} iree_hal_amdgpu_command_buffer_binding_source_flag_bits_t;
+
+// Dispatch command flags.
+typedef enum iree_hal_amdgpu_command_buffer_dispatch_flag_bits_e {
+  IREE_HAL_AMDGPU_COMMAND_BUFFER_DISPATCH_FLAG_NONE = 0u,
+  IREE_HAL_AMDGPU_COMMAND_BUFFER_DISPATCH_FLAG_INDIRECT_PARAMETERS = 1u << 0,
+} iree_hal_amdgpu_command_buffer_dispatch_flag_bits_t;
+
+// Kernarg formation strategy for a dispatch command.
+typedef enum iree_hal_amdgpu_command_buffer_kernarg_strategy_e {
+  IREE_HAL_AMDGPU_COMMAND_BUFFER_KERNARG_STRATEGY_HAL = 0,
+  IREE_HAL_AMDGPU_COMMAND_BUFFER_KERNARG_STRATEGY_CUSTOM_DIRECT = 1,
+  IREE_HAL_AMDGPU_COMMAND_BUFFER_KERNARG_STRATEGY_INDIRECT = 2,
+  IREE_HAL_AMDGPU_COMMAND_BUFFER_KERNARG_STRATEGY_PREPUBLISHED = 3,
+  IREE_HAL_AMDGPU_COMMAND_BUFFER_KERNARG_STRATEGY_PATCHED_TEMPLATE = 4,
+  IREE_HAL_AMDGPU_COMMAND_BUFFER_KERNARG_STRATEGY_DYNAMIC_BINDINGS = 5,
+} iree_hal_amdgpu_command_buffer_kernarg_strategy_t;
+
+// Binding reference kind constants embedded in command records.
+enum iree_hal_amdgpu_command_buffer_binding_kind_e {
+  IREE_HAL_AMDGPU_COMMAND_BUFFER_BINDING_KIND_INVALID = 0,
+  IREE_HAL_AMDGPU_COMMAND_BUFFER_BINDING_KIND_STATIC = 1,
+  IREE_HAL_AMDGPU_COMMAND_BUFFER_BINDING_KIND_DYNAMIC = 2,
+};
+// Compact binding reference kind storage.
+typedef uint8_t iree_hal_amdgpu_command_buffer_binding_kind_t;
+
+// Block flags reserved for optional block metadata.
+typedef enum iree_hal_amdgpu_command_buffer_block_flag_bits_e {
+  IREE_HAL_AMDGPU_COMMAND_BUFFER_BLOCK_FLAG_NONE = 0u,
+} iree_hal_amdgpu_command_buffer_block_flag_bits_t;
+// Compact block flag storage.
+typedef uint8_t iree_hal_amdgpu_command_buffer_block_flags_t;
+
+// Header stored at byte 0 of every command-buffer block.
+typedef struct IREE_AMDGPU_ALIGNAS(8)
+    iree_hal_amdgpu_command_buffer_block_header_t {
+  // Magic value identifying this as an AMDGPU command-buffer block.
+  uint32_t magic;
+  // ABI version used to interpret this block.
+  uint16_t version;
+  // Byte length of this header.
+  uint16_t header_length;
+  // Ordinal of this block within the command-buffer program.
+  uint32_t block_ordinal;
+  // Branch target block ordinal when |terminator_opcode| is BRANCH, or zero.
+  uint32_t terminator_target_block_ordinal;
+  // Total byte capacity of this block, including this header.
+  uint32_t block_length;
+  // Byte offset from this header to the first command record.
+  uint32_t command_offset;
+  // Total bytes occupied by command records.
+  uint32_t command_length;
+  // Byte offset from this header to the first binding source record.
+  uint32_t binding_source_offset;
+  // Number of command records in this block, including terminators.
+  uint16_t command_count;
+  // Number of binding source records in this block.
+  uint16_t binding_source_count;
+  // Worst-case number of AQL packets emitted when replaying this block.
+  uint32_t aql_packet_count;
+  // Worst-case kernarg bytes emitted when replaying this block.
+  uint32_t kernarg_length;
+  // Number of leading AQL payload packets in the initial unordered span,
+  // including the first packet with a barrier edge. Zero when the block emits
+  // no AQL packets.
+  uint32_t initial_barrier_packet_count;
+  // Byte offset from this header to block-local read-only payload data.
+  uint32_t rodata_offset;
+  // Total bytes occupied by block-local read-only payload data.
+  uint32_t rodata_length;
+  // Number of dispatch command records in this block.
+  uint16_t dispatch_count;
+  // Number of dispatch command records using indirect parameters.
+  uint16_t indirect_dispatch_count;
+  // Number of profile marker command records in this block.
+  uint16_t profile_marker_count;
+  // Terminator opcode from iree_hal_amdgpu_command_buffer_opcode_t.
+  uint8_t terminator_opcode;
+  // Block flags from iree_hal_amdgpu_command_buffer_block_flag_bits_t.
+  iree_hal_amdgpu_command_buffer_block_flags_t flags;
+} iree_hal_amdgpu_command_buffer_block_header_t;
+IREE_AMDGPU_STATIC_ASSERT(
+    sizeof(iree_hal_amdgpu_command_buffer_block_header_t) == 64,
+    "command-buffer block header must stay cache-line sized");
+
+// Header common to every command record.
+typedef struct IREE_AMDGPU_ALIGNAS(8)
+    iree_hal_amdgpu_command_buffer_command_header_t {
+  // Opcode from iree_hal_amdgpu_command_buffer_opcode_t.
+  uint8_t opcode;
+  // Command flags from iree_hal_amdgpu_command_buffer_command_flag_bits_t.
+  uint8_t flags;
+  // Command length in 8-byte qwords, including this header.
+  uint16_t length_qwords;
+  // Program-global command index used for profiling/source attribution.
+  uint32_t command_index;
+} iree_hal_amdgpu_command_buffer_command_header_t;
+IREE_AMDGPU_STATIC_ASSERT(
+    sizeof(iree_hal_amdgpu_command_buffer_command_header_t) == 8,
+    "command header size is part of the command-buffer ABI");
+
+// Returns |flags| with its encoded fence scope fields replaced.
+static inline uint8_t
+iree_hal_amdgpu_command_buffer_command_flags_set_fence_scopes(
+    uint8_t flags, uint8_t acquire_scope, uint8_t release_scope) {
+  flags &=
+      (uint8_t)~(IREE_HAL_AMDGPU_COMMAND_BUFFER_COMMAND_ACQUIRE_SCOPE_MASK |
+                 IREE_HAL_AMDGPU_COMMAND_BUFFER_COMMAND_RELEASE_SCOPE_MASK);
+  flags |=
+      (uint8_t)((acquire_scope & 0x3u)
+                << IREE_HAL_AMDGPU_COMMAND_BUFFER_COMMAND_ACQUIRE_SCOPE_SHIFT);
+  flags |=
+      (uint8_t)((release_scope & 0x3u)
+                << IREE_HAL_AMDGPU_COMMAND_BUFFER_COMMAND_RELEASE_SCOPE_SHIFT);
+  return flags;
+}
+
+// Returns one encoded fence scope field from a command's flag byte.
+static inline uint8_t iree_hal_amdgpu_command_buffer_command_flags_fence_scope(
+    uint8_t flags, uint8_t mask, uint8_t shift) {
+  return (uint8_t)((flags & mask) >> shift);
+}
+
+// Returns the acquire fence scope encoded in a command's flag byte.
+static inline uint8_t
+iree_hal_amdgpu_command_buffer_command_flags_acquire_scope(uint8_t flags) {
+  return iree_hal_amdgpu_command_buffer_command_flags_fence_scope(
+      flags, IREE_HAL_AMDGPU_COMMAND_BUFFER_COMMAND_ACQUIRE_SCOPE_MASK,
+      IREE_HAL_AMDGPU_COMMAND_BUFFER_COMMAND_ACQUIRE_SCOPE_SHIFT);
+}
+
+// Returns the release fence scope encoded in a command's flag byte.
+static inline uint8_t
+iree_hal_amdgpu_command_buffer_command_flags_release_scope(uint8_t flags) {
+  return iree_hal_amdgpu_command_buffer_command_flags_fence_scope(
+      flags, IREE_HAL_AMDGPU_COMMAND_BUFFER_COMMAND_RELEASE_SCOPE_MASK,
+      IREE_HAL_AMDGPU_COMMAND_BUFFER_COMMAND_RELEASE_SCOPE_SHIFT);
+}
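+
+// Round-trip sketch (illustrative only, assuming the two-bit scope values
+// mirror the HSA fence scopes: 0 = none, 1 = agent, 2 = system):
+//   uint8_t flags = IREE_HAL_AMDGPU_COMMAND_BUFFER_COMMAND_FLAG_HAS_BARRIER;
+//   flags = iree_hal_amdgpu_command_buffer_command_flags_set_fence_scopes(
+//       flags, /*acquire_scope=*/1, /*release_scope=*/2);
+//   // iree_hal_amdgpu_command_buffer_command_flags_acquire_scope(flags) == 1
+//   // iree_hal_amdgpu_command_buffer_command_flags_release_scope(flags) == 2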
+
+// Source record used to emit one HAL ABI dispatch binding pointer.
+typedef struct IREE_AMDGPU_ALIGNAS(8)
+    iree_hal_amdgpu_command_buffer_binding_source_t {
+  // Static raw source: final raw device pointer.
+  //
+  // Dynamic source: byte offset added to the queue_execute binding table entry
+  // in |slot|.
+  //
+  // Static-buffer source: byte offset added to the command-buffer static buffer
+  // ordinal in |slot|.
+  uint64_t offset_or_pointer;
+  // Dynamic source queue_execute binding table slot or static buffer ordinal.
+  // Must be zero for raw static sources.
+  uint32_t slot;
+  // Destination HAL ABI binding pointer ordinal for compact patch lists.
+  uint16_t target_binding_ordinal;
+  // Source flags from
+  // iree_hal_amdgpu_command_buffer_binding_source_flag_bits_t.
+  uint8_t flags;
+  // Reserved byte that must be zero in version 0.
+  uint8_t reserved0;
+} iree_hal_amdgpu_command_buffer_binding_source_t;
+IREE_AMDGPU_STATIC_ASSERT(
+    sizeof(iree_hal_amdgpu_command_buffer_binding_source_t) == 16,
+    "binding source size is part of the command-buffer ABI");
+
+// Barrier metadata command. Replay normally folds this into the next
+// packet-bearing command instead of emitting a standalone packet.
+typedef struct IREE_AMDGPU_ALIGNAS(8)
+    iree_hal_amdgpu_command_buffer_barrier_command_t {
+  // Common command record header.
+  iree_hal_amdgpu_command_buffer_command_header_t header;
+  // Acquire fence scope requested by the barrier.
+  uint8_t acquire_scope;
+  // Release fence scope requested by the barrier.
+  uint8_t release_scope;
+  // Barrier flags reserved for visibility-debt lowering.
+  uint16_t barrier_flags;
+  // Reserved bytes that must be zero in version 0.
+  uint32_t reserved0;
+} iree_hal_amdgpu_command_buffer_barrier_command_t;
+IREE_AMDGPU_STATIC_ASSERT(
+    sizeof(iree_hal_amdgpu_command_buffer_barrier_command_t) == 16,
+    "barrier command size must remain qword aligned");
+
+// Dispatch command record.
+typedef struct IREE_AMDGPU_ALIGNAS(8)
+    iree_hal_amdgpu_command_buffer_dispatch_command_t {
+  // Common command record header.
+  iree_hal_amdgpu_command_buffer_command_header_t header;
+  // HSA kernel object for the command buffer's selected physical device.
+  uint64_t kernel_object;
+  // Byte offset from the block header to this dispatch's first binding source.
+  uint32_t binding_source_offset;
+  // Strategy-specific payload reference.
+  // HAL/CUSTOM_DIRECT/INDIRECT: byte offset from this command record to
+  // constants/implicit tail bytes.
+  // PATCHED_TEMPLATE: command-buffer rodata ordinal for the immutable kernarg
+  // template copied into queue-owned kernargs before dynamic binding patches.
+  // PREPUBLISHED: byte offset from the command-buffer prepublished kernarg
+  // storage to the final kernargs.
+  uint32_t payload_reference;
+  // Number of HAL ABI binding pointer slots emitted before the tail payload.
+  uint16_t binding_count;
+  // Total kernarg reservation size in 8-byte qwords.
+  uint16_t kernarg_length_qwords;
+  // Strategy-specific payload count.
+  union {
+    // Inline tail payload size in 8-byte qwords.
+    uint16_t tail_length_qwords;
+    // Number of dynamic binding patch records following this command.
+    uint16_t patch_source_count;
+  } payload;
+  // Kernarg strategy from iree_hal_amdgpu_command_buffer_kernarg_strategy_t.
+  uint8_t kernarg_strategy;
+  // Dispatch flags from iree_hal_amdgpu_command_buffer_dispatch_flag_bits_t.
+  uint8_t dispatch_flags;
+  // AQL dispatch packet setup field.
+  uint16_t setup;
+  // Executable export ordinal used for profiling and diagnostics.
+  uint32_t export_ordinal;
+  // AQL dispatch packet workgroup size fields.
+  uint16_t workgroup_size[3];
+  // Kernarg qword offset of implicit args, or UINT16_MAX when absent.
+  uint16_t implicit_args_offset_qwords;
+  // AQL dispatch packet grid size fields.
+  uint32_t grid_size[3];
+  // AQL dispatch packet private segment size field.
+  uint32_t private_segment_size;
+  // AQL dispatch packet group segment size field.
+  uint32_t group_segment_size;
+  // Session-local profile executable id used for event attribution.
+  uint64_t executable_id;
+} iree_hal_amdgpu_command_buffer_dispatch_command_t;
+IREE_AMDGPU_STATIC_ASSERT(
+    sizeof(iree_hal_amdgpu_command_buffer_dispatch_command_t) == 80,
+    "dispatch command size must remain qword aligned");
+
+// Fill command record.
+typedef struct IREE_AMDGPU_ALIGNAS(8)
+    iree_hal_amdgpu_command_buffer_fill_command_t {
+  // Common command record header.
+  iree_hal_amdgpu_command_buffer_command_header_t header;
+  // Byte offset into the target buffer reference.
+  uint64_t target_offset;
+  // Byte length of the target range.
+  uint64_t length;
+  // Repeated fill pattern stored in the low bytes.
+  uint64_t pattern;
+  // Static buffer ordinal or dynamic binding-table slot.
+  uint32_t target_ordinal;
+  // Binding reference kind from iree_hal_amdgpu_command_buffer_binding_kind_t.
+  iree_hal_amdgpu_command_buffer_binding_kind_t target_kind;
+  // Byte length of the fill pattern.
+  uint8_t pattern_length;
+  // Reserved bytes that must be zero in version 0.
+  uint8_t reserved0[2];
+} iree_hal_amdgpu_command_buffer_fill_command_t;
+IREE_AMDGPU_STATIC_ASSERT(
+    sizeof(iree_hal_amdgpu_command_buffer_fill_command_t) == 40,
+    "fill command size must remain qword aligned");
+
+// Copy command record.
+typedef struct IREE_AMDGPU_ALIGNAS(8)
+    iree_hal_amdgpu_command_buffer_copy_command_t {
+  // Common command record header.
+  iree_hal_amdgpu_command_buffer_command_header_t header;
+  // Byte length of the copied range.
+  uint64_t length;
+  // Byte offset into the source buffer reference.
+  uint64_t source_offset;
+  // Byte offset into the target buffer reference.
+  uint64_t target_offset;
+  // Static buffer ordinal or dynamic binding-table slot for the source.
+  uint32_t source_ordinal;
+  // Static buffer ordinal or dynamic binding-table slot for the target.
+  uint32_t target_ordinal;
+  // Source reference kind from iree_hal_amdgpu_command_buffer_binding_kind_t.
+  iree_hal_amdgpu_command_buffer_binding_kind_t source_kind;
+  // Target reference kind from iree_hal_amdgpu_command_buffer_binding_kind_t.
+  iree_hal_amdgpu_command_buffer_binding_kind_t target_kind;
+  // Reserved bytes that must be zero in version 0.
+  uint8_t reserved0[6];
+} iree_hal_amdgpu_command_buffer_copy_command_t;
+IREE_AMDGPU_STATIC_ASSERT(
+    sizeof(iree_hal_amdgpu_command_buffer_copy_command_t) == 48,
+    "copy command size must remain qword aligned");
+
+// Update command record.
+typedef struct IREE_AMDGPU_ALIGNAS(8)
+    iree_hal_amdgpu_command_buffer_update_command_t {
+  // Common command record header.
+  iree_hal_amdgpu_command_buffer_command_header_t header;
+  // Command-buffer rodata segment ordinal containing the update payload.
+  uint64_t rodata_ordinal;
+  // Byte offset into the target buffer reference.
+  uint64_t target_offset;
+  // Byte length of the update payload and target range.
+  uint32_t length;
+  // Static buffer ordinal or dynamic binding-table slot for the target.
+  uint32_t target_ordinal;
+  // Target reference kind from iree_hal_amdgpu_command_buffer_binding_kind_t.
+  iree_hal_amdgpu_command_buffer_binding_kind_t target_kind;
+  // Reserved bytes that must be zero in version 0.
+  uint8_t reserved0[7];
+} iree_hal_amdgpu_command_buffer_update_command_t;
+IREE_AMDGPU_STATIC_ASSERT(
+    sizeof(iree_hal_amdgpu_command_buffer_update_command_t) == 40,
+    "update command size must remain qword aligned");
+
+// Unconditional branch terminator.
+typedef struct IREE_AMDGPU_ALIGNAS(8)
+    iree_hal_amdgpu_command_buffer_branch_command_t {
+  // Common command record header.
+  iree_hal_amdgpu_command_buffer_command_header_t header;
+  // Target block ordinal.
+  uint32_t target_block_ordinal;
+  // Reserved bytes that must be zero in version 0.
+  uint32_t reserved0;
+} iree_hal_amdgpu_command_buffer_branch_command_t;
+IREE_AMDGPU_STATIC_ASSERT(
+    sizeof(iree_hal_amdgpu_command_buffer_branch_command_t) == 16,
+    "branch command size must remain qword aligned");
+
+// Conditional branch terminator.
+typedef struct IREE_AMDGPU_ALIGNAS(8)
+    iree_hal_amdgpu_command_buffer_cond_branch_command_t {
+  // Common command record header.
+  iree_hal_amdgpu_command_buffer_command_header_t header;
+  // Target block ordinal when the loaded condition value is non-zero.
+  uint32_t true_block_ordinal;
+  // Target block ordinal when the loaded condition value is zero.
+  uint32_t false_block_ordinal;
+  // Condition load width in bytes.
+  uint8_t condition_width;
+  // Reserved bytes that must be zero in version 0.
+  uint8_t reserved0[7];
+} iree_hal_amdgpu_command_buffer_cond_branch_command_t;
+IREE_AMDGPU_STATIC_ASSERT(
+    sizeof(iree_hal_amdgpu_command_buffer_cond_branch_command_t) == 24,
+    "conditional branch command size must remain qword aligned");
+
+// Return terminator.
+typedef struct IREE_AMDGPU_ALIGNAS(8)
+    iree_hal_amdgpu_command_buffer_return_command_t {
+  // Common command record header.
+  iree_hal_amdgpu_command_buffer_command_header_t header;
+} iree_hal_amdgpu_command_buffer_return_command_t;
+IREE_AMDGPU_STATIC_ASSERT(
+    sizeof(iree_hal_amdgpu_command_buffer_return_command_t) == 8,
+    "return command size must remain qword aligned");
+
+// Returns the byte length encoded in a command header.
+static inline size_t iree_hal_amdgpu_command_buffer_command_length(
+    const iree_hal_amdgpu_command_buffer_command_header_t* command) {
+  return (size_t)command->length_qwords *
+         IREE_HAL_AMDGPU_COMMAND_BUFFER_RECORD_ALIGNMENT;
+}
+
+// Returns the first command in |block|.
+static inline iree_hal_amdgpu_command_buffer_command_header_t*
+iree_hal_amdgpu_command_buffer_block_commands(
+    iree_hal_amdgpu_command_buffer_block_header_t* block) {
+  uint8_t* block_base = (uint8_t*)block;
+  uint8_t* command_data = block_base + block->command_offset;
+  return (iree_hal_amdgpu_command_buffer_command_header_t*)command_data;
+}
+
+// Returns the first command in |block|.
+static inline const iree_hal_amdgpu_command_buffer_command_header_t*
+iree_hal_amdgpu_command_buffer_block_commands_const(
+    const iree_hal_amdgpu_command_buffer_block_header_t* block) {
+  const uint8_t* block_base = (const uint8_t*)block;
+  const uint8_t* command_data = block_base + block->command_offset;
+  return (const iree_hal_amdgpu_command_buffer_command_header_t*)command_data;
+}
+
+// Returns the command following |command|.
+static inline iree_hal_amdgpu_command_buffer_command_header_t*
+iree_hal_amdgpu_command_buffer_command_next(
+    iree_hal_amdgpu_command_buffer_command_header_t* command) {
+  uint8_t* command_base = (uint8_t*)command;
+  uint8_t* next_command =
+      command_base + iree_hal_amdgpu_command_buffer_command_length(command);
+  return (iree_hal_amdgpu_command_buffer_command_header_t*)next_command;
+}
+
+// Returns the command following |command|.
+static inline const iree_hal_amdgpu_command_buffer_command_header_t*
+iree_hal_amdgpu_command_buffer_command_next_const(
+    const iree_hal_amdgpu_command_buffer_command_header_t* command) {
+  const uint8_t* command_base = (const uint8_t*)command;
+  const uint8_t* next_command =
+      command_base + iree_hal_amdgpu_command_buffer_command_length(command);
+  return (const iree_hal_amdgpu_command_buffer_command_header_t*)next_command;
+}
+
+// Returns the first binding source record in |block|.
+static inline iree_hal_amdgpu_command_buffer_binding_source_t*
+iree_hal_amdgpu_command_buffer_block_binding_sources(
+    iree_hal_amdgpu_command_buffer_block_header_t* block) {
+  uint8_t* block_base = (uint8_t*)block;
+  uint8_t* binding_source_data = block_base + block->binding_source_offset;
+  return (iree_hal_amdgpu_command_buffer_binding_source_t*)binding_source_data;
+}
+
+// Returns the first binding source record in |block|.
+static inline const iree_hal_amdgpu_command_buffer_binding_source_t*
+iree_hal_amdgpu_command_buffer_block_binding_sources_const(
+    const iree_hal_amdgpu_command_buffer_block_header_t* block) {
+  const uint8_t* block_base = (const uint8_t*)block;
+  const uint8_t* binding_source_data =
+      block_base + block->binding_source_offset;
+  return (const iree_hal_amdgpu_command_buffer_binding_source_t*)
+      binding_source_data;
+}
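+
+// Block traversal sketch using the accessors above (illustrative only):
+//   const iree_hal_amdgpu_command_buffer_command_header_t* command =
+//       iree_hal_amdgpu_command_buffer_block_commands_const(block);
+//   for (uint16_t i = 0; i < block->command_count; ++i) {
+//     // Dispatch on command->opcode; the final record is the terminator.
+//     command = iree_hal_amdgpu_command_buffer_command_next_const(command);
+//   }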
+
+#ifdef __cplusplus
+}  // extern "C"
+#endif  // __cplusplus
+
+#endif  // IREE_HAL_DRIVERS_AMDGPU_ABI_COMMAND_BUFFER_H_
diff --git a/runtime/src/iree/hal/drivers/amdgpu/abi/common.h b/runtime/src/iree/hal/drivers/amdgpu/abi/common.h
new file mode 100644
index 0000000..bca7f1f
--- /dev/null
+++ b/runtime/src/iree/hal/drivers/amdgpu/abi/common.h
@@ -0,0 +1,93 @@
+// Copyright 2025 The IREE Authors
+//
+// Licensed under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+// Minimal type definitions and attributes shared between AMDGPU device code
+// (bare-metal C compiled to GPU bitcode) and host code (standard C compiled
+// for the CPU). This header is the root of the abi/ dependency tree and must
+// have zero dependencies beyond system headers and the HSA type definitions.
+//
+// Device code gets bare-metal typedefs for fixed-width integers and the
+// compiler attribute forms it needs. Host code gets the same logical macros
+// backed by standard C11 and HSA headers.
+//
+// The abi/ headers define only struct layouts, enums, and constants that match
+// what the hardware expects. No operations, no atomics, no device builtins.
+// Those live in device/support/ which re-exports abi/ and adds the
+// implementation machinery.
+
+#ifndef IREE_HAL_DRIVERS_AMDGPU_ABI_COMMON_H_
+#define IREE_HAL_DRIVERS_AMDGPU_ABI_COMMON_H_
+
+//===----------------------------------------------------------------------===//
+// Compiler Configuration
+//===----------------------------------------------------------------------===//
+
+#if defined(__AMDGPU__)
+#define IREE_AMDGPU_TARGET_DEVICE 1
+#else
+#define IREE_AMDGPU_TARGET_HOST 1
+#endif  // __AMDGPU__
+
+#if defined(IREE_AMDGPU_TARGET_DEVICE)
+
+typedef signed char int8_t;
+typedef unsigned char uint8_t;
+typedef short int16_t;
+typedef unsigned short uint16_t;
+typedef int int32_t;
+typedef unsigned int uint32_t;
+typedef long int64_t;
+typedef unsigned long uint64_t;
+
+typedef int64_t ssize_t;
+typedef uint64_t size_t;
+typedef int64_t intptr_t;
+typedef uint64_t uintptr_t;
+
+#define UINT32_MAX 0xFFFFFFFFu
+#define UINT64_MAX 0xFFFFFFFFFFFFFFFFull
+
+#define NULL ((void*)0)
+
+#else
+
+#include <stdbool.h>
+#include <stddef.h>
+#include <stdint.h>
+
+// HSA system type definitions. On the host side several abi/ types are
+// typedef'd directly to their HSA equivalents (e.g. iree_hsa_signal_t) so
+// that they can be used interchangeably with HSA API calls.
+#include "third_party/hsa-runtime-headers/include/hsa/hsa.h"  // IWYU pragma: export
+
+#endif  // IREE_AMDGPU_TARGET_DEVICE
+
+// Both device and host targets use GCC/Clang so the compiler attribute syntax
+// is identical. The AMDGPU driver is Linux-only and always compiled with Clang.
+#define IREE_AMDGPU_RESTRICT __restrict__
+#define IREE_AMDGPU_ALIGNAS(x) __attribute__((aligned(x)))
+#define IREE_AMDGPU_ALIGNOF(x) __alignof__(x)
+#define IREE_AMDGPU_ATTRIBUTE_ALWAYS_INLINE __attribute__((always_inline))
+#define IREE_AMDGPU_ATTRIBUTE_PACKED __attribute__((__packed__))
+
+#if defined(__cplusplus)
+#define IREE_AMDGPU_STATIC_ASSERT(expr, message) static_assert((expr), message)
+#else
+#define IREE_AMDGPU_STATIC_ASSERT(expr, message) _Static_assert((expr), message)
+#endif  // __cplusplus
+
+#if defined(IREE_AMDGPU_TARGET_DEVICE)
+#define IREE_AMDGPU_OFFSETOF(type, field) __builtin_offsetof(type, field)
+#else
+#define IREE_AMDGPU_OFFSETOF(type, field) offsetof(type, field)
+#endif  // IREE_AMDGPU_TARGET_DEVICE
+
+// Tick in the agent domain.
+// This can be converted to the system domain for correlation across agents and
+// the host with hsa_amd_profiling_convert_tick_to_system_domain.
+typedef uint64_t iree_amdgpu_device_tick_t;
+
+#endif  // IREE_HAL_DRIVERS_AMDGPU_ABI_COMMON_H_
diff --git a/runtime/src/iree/hal/drivers/amdgpu/abi/kernel_args.h b/runtime/src/iree/hal/drivers/amdgpu/abi/kernel_args.h
new file mode 100644
index 0000000..fe64f33
--- /dev/null
+++ b/runtime/src/iree/hal/drivers/amdgpu/abi/kernel_args.h
@@ -0,0 +1,333 @@
+// Copyright 2025 The IREE Authors
+//
+// Licensed under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+// Kernel argument struct layouts for AMDGPU dispatch. Covers both the explicit
+// kernel arguments (iree_hal_amdgpu_device_kernel_args_t) used by IREE's
+// compiled kernel dispatch and the implicit kernel arguments
+// (iree_amdgpu_kernel_implicit_args_t) defined by the LLVM AMDGPU backend for
+// OpenCL/HIP compatibility.
+
+#ifndef IREE_HAL_DRIVERS_AMDGPU_ABI_KERNEL_ARGS_H_
+#define IREE_HAL_DRIVERS_AMDGPU_ABI_KERNEL_ARGS_H_
+
+#include "iree/hal/drivers/amdgpu/abi/common.h"
+
+typedef struct iree_hsa_queue_t iree_hsa_queue_t;
+
+//===----------------------------------------------------------------------===//
+// Kernel Arguments
+//===----------------------------------------------------------------------===//
+
+// Explicit kernel arguments for IREE dispatches where the kernel function
+// signature is known at executable load time. These are populated from the
+// code object metadata and cached for the lifetime of the executable so that
+// dispatches can be issued without re-querying the symbol properties.
+// This must match what the kernel was compiled to support.
+typedef struct iree_hal_amdgpu_device_kernel_args_t {
+  // Opaque handle to the kernel object to execute.
+  uint64_t kernel_object;
+  // Dispatch setup parameters. Used to configure kernel dispatch parameters
+  // such as the number of dimensions in the grid. The parameters are
+  // described by hsa_kernel_dispatch_packet_setup_t.
+  uint16_t setup;
+  // XYZ dimensions of work-group, in work-items. Must be greater than 0.
+  // If the grid has fewer than 3 dimensions the unused must be 1.
+  uint16_t workgroup_size[3];
+  // Size in bytes of private memory allocation request (per work-item).
+  uint32_t private_segment_size;
+  // Size in bytes of group memory allocation request (per work-group). Must
+  // not be less than the sum of the group memory used by the kernel (and the
+  // functions it calls directly or indirectly) and the dynamically allocated
+  // group segment variables.
+  uint32_t group_segment_size;
+  // Size of kernarg segment memory that is required to hold the values of the
+  // kernel arguments, in bytes. Must be a multiple of 16.
+  uint16_t kernarg_size;
+  // Alignment (in bytes) of the buffer used to pass arguments to the kernel,
+  // which is the maximum of 16 and the maximum alignment of any of the kernel
+  // arguments.
+  uint16_t kernarg_alignment;
+  // Total number of 4-byte constants used by the dispatch (if a HAL dispatch).
+  uint16_t constant_count;
+  // Total number of bindings used by the dispatch (if a HAL dispatch).
+  uint16_t binding_count;
+  // Reserved for future hot kernel metadata. Must be zero.
+  uint32_t reserved;
+} iree_hal_amdgpu_device_kernel_args_t;
+IREE_AMDGPU_STATIC_ASSERT(
+    sizeof(iree_hal_amdgpu_device_kernel_args_t) <= 64,
+    "keep hot kernel arg structure in as few cache lines as possible; every "
+    "dispatch issued must access this information and it is likely uncached");
+
+// Implicit kernel arguments passed to OpenCL/HIP kernels that use them.
+// Not all kernels require these; the metadata must be checked to detect their
+// use (or the total kernarg size will be larger than the explicit args alone
+// imply).
+// Layout-wise explicit args always start at offset 0 and implicit args follow
+// those with 8-byte alignment.
+//
+// The metadata will contain exact fields and offsets and most driver code will
+// carefully walk to detect, align, pad, and write each field:
+// OpenCL/HIP: (`amd::KernelParameterDescriptor`...)
+// https://github.com/ROCm/clr/blob/5da72f9d524420c43fe3eee44b11ac875d884e0f/rocclr/device/rocm/rocvirtual.cpp#L3197
+//
+// This complex construction was required once upon a time. The LLVM code
+// producing the kernargs layout and metadata handles these cases much more
+// simply by only ever truncating the implicit args at the last used field:
+// https://github.com/llvm/llvm-project/blob/7f1b465c6ae476e59dc90652d58fc648932d23b1/llvm/lib/Target/AMDGPU/AMDGPUHSAMetadataStreamer.cpp#L389
+//
+// Then at some point in time someone was like "meh, who cares about optimizing"
+// and decided to include all of them always 🤦:
+// https://github.com/llvm/llvm-project/blob/7f1b465c6ae476e59dc90652d58fc648932d23b1/llvm/lib/Target/AMDGPU/AMDGPUSubtarget.cpp#L299
+//
+// What this means in practice is that if any implicit arg is used then all will
+// be included and declared in the metadata even if only one is actually read by
+// the kernel -- there's no way for us to know. In the ideal case none of them
+// are read and the kernel function gets the `amdgpu-no-implicitarg-ptr` attr
+// so that all of them can be skipped. Otherwise we reserve the 256 bytes and
+// just splat them all in. This at least keeps our code simple relative to all
+// the implementations that enumerate the metadata and write args one at a time.
+// We really should try to force `amdgpu-no-implicitarg-ptr` when we generate
+// code, though.
+//
+// For our bare-metal C runtime device code we have total freedom and don't use
+// any OpenCL/HIP-related things that would emit the implicit args.
+typedef struct IREE_AMDGPU_ALIGNAS(8) iree_amdgpu_kernel_implicit_args_t {
+  // Grid dispatch workgroup count.
+  // Some languages, such as OpenCL, support a last workgroup in each
+  // dimension being partial. This count only includes the non-partial
+  // workgroup count. This is not the same as the value in the AQL dispatch
+  // packet, which has the grid size in workitems.
+  //
+  // Represented in metadata as:
+  //   hidden_block_count_x
+  //   hidden_block_count_y
+  //   hidden_block_count_z
+  uint32_t block_count[3];  // + 0/4/8
+
+  // Grid dispatch workgroup size.
+  // This size only applies to the non-partial workgroups. This is the same
+  // value as the AQL dispatch packet workgroup size.
+  //
+  // Represented in metadata as:
+  //   hidden_group_size_x
+  //   hidden_group_size_y
+  //   hidden_group_size_z
+  uint16_t group_size[3];  // + 12/14/16
+
+  // Grid dispatch work group size of the partial work group, if it exists.
+  // Any dimension that does not exist must be 0. Only used in OpenCL and can
+  // be 0.
+  //
+  // Represented in metadata as:
+  //   hidden_remainder_x
+  //   hidden_remainder_y
+  //   hidden_remainder_z
+  uint16_t remainder[3];  // + 18/20/22
+
+  uint64_t reserved0;  // + 24 hidden_tool_correlation_id
+  uint64_t reserved1;  // + 32
+
+  // OpenCL grid dispatch global offset.
+  // Always 0 in HIP but still required as the device library functions for
+  // grid locations are shared with OpenCL and unconditionally factor it in.
+  //
+  // Hardcoded to 0 in HIP:
+  // https://github.com/ROCm/clr/blob/a2550e0a9ecaa8f371cb14d08904c51874c37cbe/hipamd/src/hip_module.cpp#L348
+  //
+  // Represented in metadata as:
+  //   hidden_global_offset_x
+  //   hidden_global_offset_y
+  //   hidden_global_offset_z
+  uint64_t global_offset[3];  // + 40/48/56
+
+  // Grid dispatch dimensionality. This is the same value as the AQL
+  // dispatch packet dimensionality. Must be a value between 1 and 3.
+  //
+  // Hardcoded to 3 in HIP:
+  // https://github.com/ROCm/clr/blob/a2550e0a9ecaa8f371cb14d08904c51874c37cbe/hipamd/src/hip_module.cpp#L349
+  //
+  // Represented in metadata as:
+  //   hidden_grid_dims
+  uint16_t grid_dims;  // + 64
+
+  // Fixed-size buffer for `-mprintf-kind=buffered` support.
+  // By default LLVM uses `hostcall` but that's a mess and we avoid it.
+  // `__printf_alloc` in the device library is used to grab this pointer, the
+  // header DWORDs are manipulated, and the contents are written to the buffer.
+  //
+  // struct {
+  //   atomic_uint32_t offset;
+  //   uint32_t size;
+  //   uint8_t data[size];
+  // } printf_buffer_t;
+  //
+  // One of many disappointing parts of this scheme is that constant string
+  // values are interned, MD5 hashed, and stored *externally* in the amdhsa data
+  // blob. In order to print with any constant format string this data blob
+  // needs to be parsed, retained, and referenced every time a printf packet is
+  // processed. It would have been significantly better to embed the table in
+  // the ELF as a global constant instead as then we could reference it on both
+  // host and device and not need to parse the amdhsa blob.
+  //
+  // The contents of the data buffer are best defined by the janky parser code:
+  // https://github.com/ROCm/clr/blob/a2550e0a9ecaa8f371cb14d08904c51874c37cbe/rocclr/device/rocm/rocprintf.cpp#L454
+  // Each printf consists of a control DWORD followed by 8-byte aligned
+  // contents. Effectively:
+  // struct {
+  //   uint32_t is_stderr : 1;       // else stdout
+  //   uint32_t constant : 1;        // constant format string code path
+  //   uint32_t size_in_bytes : 30;  // (including this header)
+  //   uint64_t data[size_in_bytes / 8];
+  // } printf_packet_t;
+  //
+  // To construct the full format data buffer if constant == 1:
+  //  data[0] contains the lower 64 bits of the MD5 hash of the string,
+  //  followed by size_in_bytes-12 bytes of argument data. The data buffer
+  //  needs to be expanded into
+  //  an 8-byte aligned NUL-terminated string with the corresponding hash
+  //  followed by the arguments verbatim. Once reconstituted the subsequent
+  //  logic is the same.
+  //
+  // The data buffer is an 8-byte aligned NUL-terminated string followed by
+  // the argument data. E.g. `hi! %s` would be encoded as `hi! %s` 0x00 0x??
+  // (with the last byte being padding to an 8-byte boundary). The reference
+  // code for formatting the string lives in the CLR:
+  // https://github.com/ROCm/clr/blob/a2550e0a9ecaa8f371cb14d08904c51874c37cbe/rocclr/device/devhcprintf.cpp#L168
+  // Note that the documentation is incorrect about there being a version prefix
+  // and it expects the first uint64_t to contain the format string bytes.
+  //
+  // Note that in another disappointing display of rube-goldbergian development
+  // this implementation for some reason uses uint64_t for its data elements
+  // but never aligns it - meaning that consumer code must use unaligned loads
+  // in order to read the data. The CLR just copies it out each time. One could
+  // think that was for streaming (release the buffer contents early back to
+  // dispatches) but since they fully halt the world and synchronize after every
+  // dispatch containing a print none of that matters and it's just poor
+  // engineering.
+  //
+  // The compiler emits strings in the delimited form of
+  // `"0:0:<format_string_hash>,<actual_format_string>"`. Note that the first
+  // two values should always be 0 and are delimited by `:` while the MD5 hash
+  // is delimited from the format string itself by `,`. There's some special
+  // handling in the CLR for `:` appearing in the format string because whoever
+  // wrote it did a find from the end instead of a prefix consume; there's
+  // handling of \72 (`:`) and other weird things that I'm not sure are needed.
+  // Example from LLVM: `"0:0:8addc4c0362218ac,Hello World!:\n"`.
+  //
+  // The hash is the lower 64 bits of the MD5 hash in hex but we don't care as
+  // it's just a semi-unique value we use to look up the string formats. On load
+  // we sort and do a binary search instead of creating an std::map for every
+  // single print invocation like the CLR does. Just... wow.
+  //
+  // Handling the contents is also overly complicated and poorly documented:
+  // https://github.com/ROCm/clr/blob/a2550e0a9ecaa8f371cb14d08904c51874c37cbe/rocclr/device/devhcprintf.cpp#L168
+  //
+  // See:
+  // https://github.com/ROCm/llvm-project/commit/631c965483e03355cdc1dba578e787b259c4d79d
+  // https://github.com/ROCm/llvm-project/blob/997363823fcc5ccc7b0cc572aad05ba08714bf5f/amd/device-libs/ockl/src/cprintf.cl#L17
+  // https://github.com/ROCm/clr/blob/a2550e0a9ecaa8f371cb14d08904c51874c37cbe/rocclr/device/rocm/rocprintf.cpp#L393
+  //
+  // Note that having a printf in a kernel causes the kernel to dispatch
+  // synchronously :facepalm:. We can't do the same and would need to emit
+  // flush packets (or something) into the control queue. What a mess.
+  // https://github.com/ROCm/clr/blob/a2550e0a9ecaa8f371cb14d08904c51874c37cbe/rocclr/device/rocm/rocvirtual.cpp#L3644
+  // https://github.com/ROCm/clr/blob/a2550e0a9ecaa8f371cb14d08904c51874c37cbe/rocclr/device/rocm/rocprintf.cpp#L428-L429
+  //
+  // Represented in metadata as:
+  //   hidden_printf_buffer
+  void* printf_buffer;  // + 72
+
+  // Used for ASAN, printf, and more modern device memory allocations.
+  // It's bizarre and only "documented" in code and I really hope we don't have
+  // to touch it. Note that due to some LLVM bug sometimes this will be included
+  // in the offset table for a kernel even if it is not used (the
+  // `amdgpu-no-hostcall-ptr` attribute is set). At this point I'm quite sure no
+  // one has ever actually inspected the files produced by the LLVM backend.
+  //
+  // Represented in metadata as:
+  //   hidden_hostcall_buffer
+  void* hostcall_buffer;  // + 80
+
+  // Multi-grid support was deprecated in ROCM 5.x and should never appear in
+  // any program we generate ourselves or care about running.
+  //
+  // Represented in metadata as:
+  //   hidden_multigrid_sync_arg
+  uint64_t deprecated_multigrid_sync_arg;
+
+  // Device memory heap pointer for device malloc/free.
+  // We don't support kernels using this as it requires too much goo for little
+  // payoff. The kernels we run shouldn't be malloc/freeing internally. If they
+  // do we will need to implement the heap API via hostcalls and other silly
+  // things that add a tremendous amount of complexity.
+  //
+  // See:
+  // https://github.com/ROCm/llvm-project/blob/97753eeaa4c79c2db2dcd9f37b7989596a8d4f15/amd/device-libs/ockl/src/dm.cl#L192
+  //
+  // Represented in metadata as:
+  //   hidden_heap_v1
+  uint64_t unused_heap_v1;
+
+  // AQL queue handles are only used by OpenCL device-side enqueue and we do not
+  // support that. We could, probably, by passing in our execution queue but
+  // since HIP has never supported it the use case doesn't exist. If we wanted
+  // to support device-enqueue we'd do it in a structured fashion instead of
+  // letting kernels splat right into the AQL queue.
+  //
+  // See:
+  // https://github.com/ROCm/llvm-project/blob/97753eeaa4c79c2db2dcd9f37b7989596a8d4f15/amd/device-libs/opencl/src/devenq/enqueue.cl#L310
+  //
+  // Represented in metadata as:
+  //   hidden_default_queue
+  uint64_t unused_default_queue;
+
+  // Completion actions were (I believe) an attempt at dynamic parallelism and
+  // HIP has never supported them. Device-side enqueue in OpenCL uses this but
+  // we don't support those kernels.
+  //
+  // See:
+  // https://github.com/ROCm/llvm-project/blob/97753eeaa4c79c2db2dcd9f37b7989596a8d4f15/amd/device-libs/opencl/src/devenq/enqueue.cl#L311
+  //
+  // Represented in metadata as:
+  //   hidden_completion_action
+  uint64_t unused_completion_action;
+
+  // The value of the sharedMemBytes parameter to the dispatch indicating how
+  // much dynamic shared memory was reserved for the kernel. This may be larger
+  // than the requested amount. The total group_segment_size for a dispatch is
+  // the static LDS requirement of the kernel plus this value.
+  //
+  // Represented in metadata as:
+  //   hidden_dynamic_lds_size
+  uint32_t dynamic_lds_size;
+
+  uint8_t reserved[68];
+
+  // Only used by GFX8, which we don't support.
+  //
+  // Represented in metadata as:
+  //   hidden_private_base
+  uint32_t deprecated_private_base;
+
+  // Only used by GFX8, which we don't support.
+  //
+  // Represented in metadata as:
+  //   hidden_shared_base
+  uint32_t deprecated_shared_base;
+
+  // AQL queue the dispatch is running on.
+  // Only used by pre-GFX9 devices, which we don't support.
+  //
+  // Represented in metadata as:
+  //   hidden_queue_ptr
+  iree_hsa_queue_t* deprecated_queue_ptr;
+} iree_amdgpu_kernel_implicit_args_t;
+
+#define IREE_AMDGPU_KERNEL_IMPLICIT_ARGS_SIZE               \
+  (IREE_AMDGPU_OFFSETOF(iree_amdgpu_kernel_implicit_args_t, \
+                        dynamic_lds_size) +                 \
+   sizeof(((iree_amdgpu_kernel_implicit_args_t*)NULL)->dynamic_lds_size))
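+
+// Sizing sketch (illustrative only; |explicit_size| stands in for the byte
+// size of the explicit arguments, which is not a field of this struct):
+// explicit args start at offset 0 and the implicit tail follows with 8-byte
+// alignment, so:
+//   size_t implicit_offset = (explicit_size + 7) & ~(size_t)7;
+//   size_t total_size =
+//       implicit_offset + IREE_AMDGPU_KERNEL_IMPLICIT_ARGS_SIZE;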
+
+#endif  // IREE_HAL_DRIVERS_AMDGPU_ABI_KERNEL_ARGS_H_
diff --git a/runtime/src/iree/hal/drivers/amdgpu/abi/profile.h b/runtime/src/iree/hal/drivers/amdgpu/abi/profile.h
new file mode 100644
index 0000000..f5bedc0
--- /dev/null
+++ b/runtime/src/iree/hal/drivers/amdgpu/abi/profile.h
@@ -0,0 +1,135 @@
+// Copyright 2026 The IREE Authors
+//
+// Licensed under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+
+#ifndef IREE_HAL_DRIVERS_AMDGPU_ABI_PROFILE_H_
+#define IREE_HAL_DRIVERS_AMDGPU_ABI_PROFILE_H_
+
+#include "iree/hal/drivers/amdgpu/abi/timestamp.h"
+
+//===----------------------------------------------------------------------===//
+// Dispatch event records
+//===----------------------------------------------------------------------===//
+
+// Bitfield specifying properties of one AMDGPU dispatch event record.
+typedef uint32_t iree_hal_amdgpu_profile_dispatch_event_flags_t;
+enum iree_hal_amdgpu_profile_dispatch_event_flag_bits_t {
+  IREE_HAL_AMDGPU_PROFILE_DISPATCH_EVENT_FLAG_NONE = 0u,
+  // Dispatch was enqueued through a reusable command buffer.
+  IREE_HAL_AMDGPU_PROFILE_DISPATCH_EVENT_FLAG_COMMAND_BUFFER = 1u << 0,
+  // Workgroup counts were loaded from device memory before dispatch.
+  IREE_HAL_AMDGPU_PROFILE_DISPATCH_EVENT_FLAG_INDIRECT_PARAMETERS = 1u << 1,
+};
+
+// Device-written dispatch execution event.
+//
+// Host submission writes all static metadata before publishing the harvest
+// dispatch. The device-side harvest kernel writes only start_tick/end_tick
+// after queue ordering proves the profiled dispatch completed.
+typedef struct iree_hal_amdgpu_profile_dispatch_event_t {
+  // Size of this record in bytes for forward-compatible parsing.
+  uint32_t record_length;
+  // Flags describing how the dispatch was produced.
+  iree_hal_amdgpu_profile_dispatch_event_flags_t flags;
+  // Producer-defined event identifier unique within the dispatch event stream.
+  uint64_t event_id;
+  // Queue submission epoch containing this dispatch.
+  uint64_t submission_id;
+  // Session-local command-buffer identifier, or 0 for direct queue dispatch.
+  uint64_t command_buffer_id;
+  // Session-local executable identifier, or 0 when unavailable.
+  uint64_t executable_id;
+  // Command ordinal within a command buffer, or UINT32_MAX for direct dispatch.
+  uint32_t command_index;
+  // Executable export ordinal dispatched.
+  uint32_t export_ordinal;
+  // Workgroup counts submitted for each dimension.
+  uint32_t workgroup_count[3];
+  // Workgroup sizes submitted for each dimension.
+  uint32_t workgroup_size[3];
+  // Device timestamp captured when dispatch execution started.
+  uint64_t start_tick;
+  // Device timestamp captured when dispatch execution completed.
+  uint64_t end_tick;
+} iree_hal_amdgpu_profile_dispatch_event_t;
+IREE_AMDGPU_STATIC_ASSERT(
+    sizeof(iree_hal_amdgpu_profile_dispatch_event_t) == 88,
+    "dispatch event record size is part of the profiling ABI");
+IREE_AMDGPU_STATIC_ASSERT(
+    IREE_AMDGPU_OFFSETOF(iree_hal_amdgpu_profile_dispatch_event_t, start_tick) +
+            sizeof(iree_hal_amdgpu_timestamp_range_t) ==
+        sizeof(iree_hal_amdgpu_profile_dispatch_event_t),
+    "dispatch event timestamps must be a trailing timestamp range");
+
+// Returns the timestamp range embedded in |event|.
+static inline iree_hal_amdgpu_timestamp_range_t*
+iree_hal_amdgpu_profile_dispatch_event_ticks(
+    iree_hal_amdgpu_profile_dispatch_event_t* event) {
+  return (iree_hal_amdgpu_timestamp_range_t*)&event->start_tick;
+}
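+
+// Consumer sketch (illustrative only; |cursor| and |end| stand in for stream
+// bounds owned by the profile reader, not defined by this ABI): records are
+// walked by |record_length| so newer producers can append fields safely.
+//   while (cursor + sizeof(uint32_t) <= end) {
+//     const iree_hal_amdgpu_profile_dispatch_event_t* event =
+//         (const iree_hal_amdgpu_profile_dispatch_event_t*)cursor;
+//     // ... consume the fields known to this parser version ...
+//     cursor += event->record_length;
+//   }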
+
+// Fixed timestamp harvest source used to populate profile dispatch events.
+typedef iree_hal_amdgpu_dispatch_timestamp_harvest_source_t
+    iree_hal_amdgpu_profile_dispatch_harvest_source_t;
+IREE_AMDGPU_STATIC_ASSERT(
+    sizeof(iree_hal_amdgpu_profile_dispatch_harvest_source_t) == 16,
+    "dispatch harvest source size is part of the profiling ABI");
+
+// Fixed timestamp harvest arguments used to populate profile dispatch events.
+typedef iree_hal_amdgpu_dispatch_timestamp_harvest_args_t
+    iree_hal_amdgpu_profile_dispatch_harvest_args_t;
+IREE_AMDGPU_STATIC_ASSERT(
+    sizeof(iree_hal_amdgpu_profile_dispatch_harvest_args_t) == 16,
+    "dispatch harvest args must match the kernel ABI");
+
+//===----------------------------------------------------------------------===//
+// Queue device event records
+//===----------------------------------------------------------------------===//
+
+// Device-written queue operation event.
+//
+// Host submission writes all static metadata before publishing the timestamp
+// packet. PM4 timestamp packets write only start_tick/end_tick while the
+// notification ring epoch continues to own readiness and reclaim.
+typedef struct iree_hal_amdgpu_profile_queue_device_event_t {
+  // Size of this record in bytes for forward-compatible parsing.
+  uint32_t record_length;
+  // Kind of queue operation represented by this device event.
+  uint32_t type;
+  // Flags describing queue operation properties.
+  uint32_t flags;
+  // Reserved for future queue device event fields; must be zero.
+  uint32_t reserved0;
+  // Producer-defined event identifier unique within the queue device event
+  // stream.
+  uint64_t event_id;
+  // Queue submission epoch containing this device event.
+  uint64_t submission_id;
+  // Session-local command-buffer identifier, or 0 when not applicable.
+  uint64_t command_buffer_id;
+  // Producer-defined allocation identifier, or 0 when not applicable.
+  uint64_t allocation_id;
+  // Producer-defined stream identifier matching the queue metadata record.
+  uint64_t stream_id;
+  // Type-specific payload byte length, or 0 when not applicable.
+  uint64_t payload_length;
+  // Session-local physical device ordinal associated with this operation.
+  uint32_t physical_device_ordinal;
+  // Session-local queue ordinal associated with this operation.
+  uint32_t queue_ordinal;
+  // Number of encoded payload operations represented by this queue operation.
+  uint32_t operation_count;
+  // Reserved for future queue device event fields; must be zero.
+  uint32_t reserved1;
+  // Device timestamp captured when queue-visible work started.
+  uint64_t start_tick;
+  // Device timestamp captured when queue-visible work completed.
+  uint64_t end_tick;
+} iree_hal_amdgpu_profile_queue_device_event_t;
+IREE_AMDGPU_STATIC_ASSERT(
+    sizeof(iree_hal_amdgpu_profile_queue_device_event_t) == 96,
+    "queue device event record size is part of the profiling ABI");
+
+#endif  // IREE_HAL_DRIVERS_AMDGPU_ABI_PROFILE_H_
diff --git a/runtime/src/iree/hal/drivers/amdgpu/abi/queue.h b/runtime/src/iree/hal/drivers/amdgpu/abi/queue.h
new file mode 100644
index 0000000..30d778f
--- /dev/null
+++ b/runtime/src/iree/hal/drivers/amdgpu/abi/queue.h
@@ -0,0 +1,445 @@
+// Copyright 2025 The IREE Authors
+//
+// Licensed under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+// HSA/AMDGPU AQL queue and packet struct layouts as defined by the HSA spec
+// and AMD extensions. These are the types that appear in device-visible memory
+// and must have identical layout on both host and device.
+//
+// This header contains only type definitions, enums, and constants.
+// Device-side kernel dispatch/work-item helpers live in
+// device/support/kernel.h. Device-side queue index manipulation functions and
+// the cached queue optimization live in device/support/queue.h.
+//
+// Sources:
+// https://hsafoundation.com/wp-content/uploads/2021/02/HSA-SysArch-1.2.pdf
+// https://github.com/ROCm/ROCR-Runtime
+// https://github.com/ROCm/rocMLIR/blob/develop/external/llvm-project/amd/device-libs/README.md
+
+#ifndef IREE_HAL_DRIVERS_AMDGPU_ABI_QUEUE_H_
+#define IREE_HAL_DRIVERS_AMDGPU_ABI_QUEUE_H_
+
+#include "iree/hal/drivers/amdgpu/abi/signal.h"
+
+//===----------------------------------------------------------------------===//
+// HSA/AMDGPU AQL Queue
+//===----------------------------------------------------------------------===//
+
+typedef enum {
+  // Queue supports multiple producers.
+  IREE_HSA_QUEUE_TYPE_MULTI = 0,
+  // Queue only supports a single producer.
+  IREE_HSA_QUEUE_TYPE_SINGLE = 1,
+} iree_hsa_queue_type_t;
+
+// NOTE: this is not our struct and we cannot change it.
+typedef struct iree_hsa_queue_t {
+  // Queue type.
+  iree_hsa_queue_type_t type;
+
+  // Queue features mask. This is a bit-field of iree_hsa_queue_feature_t
+  // values. Applications should ignore any unknown set bits.
+  uint32_t features;
+
+  // Packet storage. Must be accessible on any agents that may operate on it and
+  // aligned to at least 64 (the size of an AQL packet).
+  void* base_address;
+
+  // Signal object used by the application to indicate the ID of a packet that
+  // is ready to be processed. The HSA runtime or hardware packet processor
+  // manages the doorbell signal. If the application tries to replace or destroy
+  // this signal the behavior is undefined.
+  //
+  // If type is HSA_QUEUE_TYPE_SINGLE the doorbell signal value must be
+  // updated in a monotonically increasing fashion. If type is
+  // HSA_QUEUE_TYPE_MULTI the doorbell signal value can be updated with any
+  // value and the act of writing a differing value is enough to wake the
+  // processor. On AMD GPUs today it is reportedly not any more efficient to
+  // use SINGLE queues as the packet processor handles both the same way.
+  iree_hsa_signal_t doorbell_signal;
+
+  // Maximum number of packets the queue can hold. Must be a power of 2.
+  uint32_t size;
+
+  uint32_t reserved1;  // must be 0
+
+  // Queue identifier, which is unique over the lifetime of the application even
+  // if the queue is reallocated.
+  uint64_t id;
+} iree_hsa_queue_t;
+
+#define IREE_AMD_HSA_BITS_CREATE_ENUM_ENTRIES(name, shift, width) \
+  name##_SHIFT = (shift), name##_WIDTH = (width),                 \
+  name = (((1 << (width)) - 1) << (shift))
+// Clears then sets the |mask| field of |dst| in place.
+#define IREE_AMD_HSA_BITS_SET(dst, mask, val)    \
+  do {                                           \
+    (dst) &= (~(1 << mask##_SHIFT) & ~(mask));   \
+    (dst) |= (((val) << mask##_SHIFT) & (mask)); \
+  } while (0)
+
+enum iree_amd_queue_properties_t {
+  IREE_AMD_HSA_BITS_CREATE_ENUM_ENTRIES(
+      IREE_AMD_QUEUE_PROPERTIES_ENABLE_TRAP_HANDLER, 0, 1),
+  // All devices we care about are 64-bit.
+  IREE_AMD_HSA_BITS_CREATE_ENUM_ENTRIES(IREE_AMD_QUEUE_PROPERTIES_IS_PTR64, 1,
+                                        1),
+  IREE_AMD_HSA_BITS_CREATE_ENUM_ENTRIES(
+      IREE_AMD_QUEUE_PROPERTIES_ENABLE_TRAP_HANDLER_DEBUG_SGPRS, 2, 1),
+  // Timestamps will be stored on signals (start_ts/end_ts).
+  IREE_AMD_HSA_BITS_CREATE_ENUM_ENTRIES(
+      IREE_AMD_QUEUE_PROPERTIES_ENABLE_PROFILING, 3, 1),
+  IREE_AMD_HSA_BITS_CREATE_ENUM_ENTRIES(
+      IREE_AMD_QUEUE_PROPERTIES_USE_SCRATCH_ONCE, 4, 1),
+  IREE_AMD_HSA_BITS_CREATE_ENUM_ENTRIES(IREE_AMD_QUEUE_PROPERTIES_RESERVED1, 5,
+                                        27)
+};
+typedef uint32_t iree_amd_queue_properties32_t;
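+
+// Usage sketch (illustrative only; in practice the HSA runtime normally owns
+// this field through its profiling APIs):
+//   iree_amd_queue_properties32_t properties = queue->queue_properties;
+//   IREE_AMD_HSA_BITS_SET(properties,
+//                         IREE_AMD_QUEUE_PROPERTIES_ENABLE_PROFILING, 1);
+//   queue->queue_properties = properties;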
+
+// An AQL packet queue.
+// We generally treat these as opaque except for if we need to read queue
+// properties to check modes - otherwise we just treat any queue handle as
+// an iree_hsa_queue_t.
+typedef struct IREE_AMDGPU_ALIGNAS(64) iree_amd_queue_t {
+  iree_hsa_queue_t hsa_queue;
+  uint32_t caps;
+  uint32_t reserved1[3];
+  volatile uint64_t write_dispatch_id;
+  uint32_t group_segment_aperture_base_hi;
+  uint32_t private_segment_aperture_base_hi;
+  uint32_t max_cu_id;
+  uint32_t max_wave_id;
+  volatile uint64_t max_legacy_doorbell_dispatch_id_plus_1;
+  volatile uint32_t legacy_doorbell_lock;
+  uint32_t reserved2[9];
+  volatile uint64_t read_dispatch_id;
+  uint32_t read_dispatch_id_field_base_byte_offset;
+  uint32_t compute_tmpring_size;
+  uint32_t scratch_resource_descriptor[4];
+  uint64_t scratch_backing_memory_location;
+  uint64_t scratch_backing_memory_byte_size;
+  uint32_t scratch_workitem_byte_size;
+  iree_amd_queue_properties32_t queue_properties;
+  volatile uint64_t scratch_last_used_index; /* async-reclaim */
+  iree_hsa_signal_t queue_inactive_signal;
+  uint32_t reserved4[2];
+  volatile uint64_t alt_scratch_last_used_index; /* async-reclaim */
+  uint64_t alt_scratch_backing_memory_location;  /* async-reclaim */
+  uint64_t alt_scratch_backing_memory_byte_size; /* async-reclaim */
+  uint32_t alt_scratch_dispatch_limit_x;         /* async-reclaim */
+  uint32_t alt_scratch_dispatch_limit_y;         /* async-reclaim */
+  uint32_t alt_scratch_dispatch_limit_z;         /* async-reclaim */
+  uint32_t alt_scratch_wave64_lane_byte_size;    /* async-reclaim */
+  uint32_t alt_compute_tmpring_size;             /* async-reclaim */
+  uint32_t reserved5;
+} iree_amd_queue_t;
+
+//===----------------------------------------------------------------------===//
+// HSA/AMDGPU AQL Packets
+//===----------------------------------------------------------------------===//
+
+typedef enum {
+  // Handled entirely by the packet processor and will vary agent to agent.
+  IREE_HSA_PACKET_TYPE_VENDOR_SPECIFIC = 0,
+  // Invalid packet (not yet populated) that will stall the packet processor.
+  IREE_HSA_PACKET_TYPE_INVALID = 1,
+  // iree_hsa_kernel_dispatch_packet_t
+  IREE_HSA_PACKET_TYPE_KERNEL_DISPATCH = 2,
+  // iree_hsa_barrier_and_packet_t
+  IREE_HSA_PACKET_TYPE_BARRIER_AND = 3,
+  // iree_hsa_agent_dispatch_packet_t
+  IREE_HSA_PACKET_TYPE_AGENT_DISPATCH = 4,
+  // iree_hsa_barrier_or_packet_t
+  IREE_HSA_PACKET_TYPE_BARRIER_OR = 5,
+} iree_hsa_packet_type_t;
+
+// Bit offsets within the header word of various values.
+// We have to perform the bit manipulation ourselves because OpenCL has no
+// bitfields. Crazy.
+//
+// If we did have bitfields the struct would look like:
+// typedef struct {
+//   uint16_t type : 8;
+//   uint16_t barrier : 1;
+//   uint16_t scacquire_fence_scope : 2;
+//   uint16_t screlease_fence_scope : 2;
+//   uint16_t reserved : 3;  // must be 0
+// } iree_hsa_packet_header_t;
+//
+// Since the smallest atomic width is 32-bits and this header is 16-bits any
+// operations updating the header must include the subsequent 16-bits of the
+// packet (e.g. setup for kernel dispatches).
+//
+// See spec 2.9.1 and child entries for the full details.
+typedef enum {
+  // Determines the packet type as processed by the packet processor.
+  // The header is the same for all packets but all other following contents may
+  // change.
+  IREE_HSA_PACKET_HEADER_TYPE = 0,
+  // If set then processing of the packet will only begin when all preceding
+  // packets are complete. There is no implicit fence defined as part of the
+  // barrier and an acquire fence scope must still be specified if any is
+  // required.
+  IREE_HSA_PACKET_HEADER_BARRIER = 8,
+  // A packet memory acquire fence ensures any subsequent global segment or
+  // image loads by any unit of execution that belongs to a dispatch that has
+  // not yet entered the active phase on any queue of the same agent, sees any
+  // data previously released at the scopes specified by the packet acquire
+  // fence.
+  //
+  // Behavior:
+  //   IREE_HSA_FENCE_SCOPE_NONE:
+  //     No fence is applied and the packet relies on an earlier acquire fence
+  //     performed on the agent or acquire fences within the operation (e.g. by
+  //     the kernel).
+  //   IREE_HSA_FENCE_SCOPE_AGENT:
+  //     The acquire fence is applied with agent scope for the global segment.
+  //   IREE_HSA_FENCE_SCOPE_SYSTEM:
+  //     The acquire fence is applied across both agent and system scope for the
+  //     global segment.
+  IREE_HSA_PACKET_HEADER_SCACQUIRE_FENCE_SCOPE = 9,
+  // A packet memory release fence makes any global segment or image data that
+  // was stored by any unit of execution that belonged to a dispatch that has
+  // completed the active phase on any queue of the same agent visible in all
+  // the scopes specified by the packet release fence.
+  //
+  // Behavior:
+  //   IREE_HSA_FENCE_SCOPE_NONE:
+  //     No fence is applied and the packet relies on a later release fence
+  //     performed on the agent or release fences within the operation (e.g. by
+  //     the kernel).
+  //   IREE_HSA_FENCE_SCOPE_AGENT:
+  //     The release fence is applied with agent scope for the global segment.
+  //   IREE_HSA_FENCE_SCOPE_SYSTEM:
+  //     The release fence is applied across both agent and system scope for the
+  //     global segment.
+  IREE_HSA_PACKET_HEADER_SCRELEASE_FENCE_SCOPE = 11,
+} iree_hsa_packet_header_t;
+
+// Width in bits of the sub-fields in iree_hsa_packet_header_t.
+typedef enum {
+  IREE_HSA_PACKET_HEADER_WIDTH_TYPE = 8,
+  IREE_HSA_PACKET_HEADER_WIDTH_BARRIER = 1,
+  IREE_HSA_PACKET_HEADER_WIDTH_SCACQUIRE_FENCE_SCOPE = 2,
+  IREE_HSA_PACKET_HEADER_WIDTH_SCRELEASE_FENCE_SCOPE = 2,
+} iree_hsa_packet_header_width_t;
+
+// Forms a 16-bit AQL packet header.
+#define iree_hsa_make_packet_header(type, is_barrier, scacquire_fence_scope,   \
+                                    screlease_fence_scope)                     \
+  (((type) << IREE_HSA_PACKET_HEADER_TYPE) |                                   \
+   ((is_barrier ? 1 : 0) << IREE_HSA_PACKET_HEADER_BARRIER) |                  \
+   ((scacquire_fence_scope) << IREE_HSA_PACKET_HEADER_SCACQUIRE_FENCE_SCOPE) | \
+   ((screlease_fence_scope) << IREE_HSA_PACKET_HEADER_SCRELEASE_FENCE_SCOPE))
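+
+// Example of publishing a header built with iree_hsa_make_packet_header (a
+// sketch assuming little-endian packing; the atomic store helper name is
+// illustrative, not an API defined here):
+//   uint16_t header = iree_hsa_make_packet_header(
+//       IREE_HSA_PACKET_TYPE_KERNEL_DISPATCH, /*is_barrier=*/false,
+//       IREE_HSA_FENCE_SCOPE_AGENT, IREE_HSA_FENCE_SCOPE_AGENT);
+//   uint16_t setup = 3;  // grid dimension count
+//   // Single 32-bit release store covering both header and setup, per the
+//   // minimum atomic width note above.
+//   iree_amdgpu_atomic_store_u32(
+//       (uint32_t*)&packet->header,
+//       (uint32_t)header | ((uint32_t)setup << 16));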
+
+typedef enum {
+  // No cache management occurs.
+  IREE_HSA_FENCE_SCOPE_NONE = 0,
+  // Invalidates I, K and L1 caches. Changes will be available to any queue on
+  // the same agent but may not be available on any other agent.
+  IREE_HSA_FENCE_SCOPE_AGENT = 1,
+  // Invalidates L1, L2 and flushes L2 caches. Changes will be available on all
+  // agents in the system after the fence completes.
+  IREE_HSA_FENCE_SCOPE_SYSTEM = 2,
+} iree_hsa_fence_scope_t;
+
+// Kernel dispatch (2.9.6 in the spec).
+//
+// Pseudo-code:
+//   for (uint32_t z = 0; z < grid_size[2] / workgroup_size[2]; ++z) {
+//     for (uint32_t y = 0; y < grid_size[1] / workgroup_size[1]; ++y) {
+//       for (uint32_t x = 0; x < grid_size[0] / workgroup_size[0]; ++x) {
+//         kernel_object(*kernarg_address);
+//       }
+//     }
+//   }
+//   iree_hsa_signal_subtract(completion_signal, 1);
+//
+// The acquire fence is applied at the end of the launch phase just before the
+// packet enters the active phase. The release fence is applied at the start of
+// the completion phase of the packet.
+typedef struct iree_hsa_kernel_dispatch_packet_t {
+  // AQL packet header. See iree_hsa_packet_header_t for details.
+  uint16_t header;
+  // Number of grid dimensions (1, 2, or 3 - we always use 3).
+  uint16_t setup;
+  // Work-group size in work-items.
+  uint16_t workgroup_size[3];
+  uint16_t reserved0;  // must be 0
+  // Grid size in work-items.
+  uint32_t grid_size[3];
+  // Total size in bytes of the per-work-item memory.
+  uint32_t private_segment_size;
+  // Total size in bytes of the per-work-group memory.
+  uint32_t group_segment_size;
+  // Kernel object (function) handle as returned from a query on the symbol
+  // of HSA_EXECUTABLE_SYMBOL_INFO_KERNEL_OBJECT.
+  uint64_t kernel_object;
+  // Kernel arguments as required by the function.
+  // Must be 16-byte aligned and live until the dispatch has completed.
+  void* kernarg_address;
+  uint64_t reserved2;  // must be 0
+  // Optional signal indicating completion of all work-groups.
+  iree_hsa_signal_t completion_signal;
+} iree_hsa_kernel_dispatch_packet_t;
+
+// Agent dispatch (2.9.7 in the spec).
+//
+// Pseudo-code:
+//   *return_address = fns[type](arg[0], arg[1], arg[2], arg[3]);
+//   iree_hsa_signal_subtract(completion_signal, 1);
+//
+// The acquire fence is applied at the end of the launch phase just before the
+// packet enters the active phase. The release fence is applied at the start of
+// the completion phase of the packet.
+typedef struct iree_hsa_agent_dispatch_packet_t {
+  // AQL packet header. See iree_hsa_packet_header_t for details.
+  uint16_t header;
+  // Agent-defined type (discriminator).
+  uint16_t type;
+  uint32_t reserved0;  // must be 0
+  // Pointer to store the return value(s) in with the contents and layout
+  // defined by the type.
+  void* return_address;
+  // Arguments to the dispatch as defined by the type.
+  uint64_t arg[4];
+  uint64_t reserved2;  // must be 0
+  // Optional signal indicating completion of the dispatch.
+  iree_hsa_signal_t completion_signal;
+} iree_hsa_agent_dispatch_packet_t;
+
+// Barrier-AND (2.9.8 in the spec).
+// Waits until all dep_signals reach the value 0 at the same time and then
+// decrements the completion_signal. Ignores any 0 (null) signals.
+//
+// Pseudo-code:
+//   do {
+//     bool any_unsatisfied = false;
+//     for (int i = 0; i < 5; ++i) {
+//       if (iree_hsa_signal_load(dep_signal[i]) != 0) any_unsatisfied = true;
+//     }
+//     if (!any_unsatisfied) break;
+//     iree_amdgpu_yield();
+//   } while(true);
+//   iree_hsa_signal_subtract(completion_signal, 1);
+//
+// The acquire fence is processed first in the completion phase of the packet
+// after the barrier condition has been met. The release fence is processed
+// after the acquire fence in the completion phase.
+typedef struct iree_hsa_barrier_and_packet_t {
+  // AQL packet header. See iree_hsa_packet_header_t for details.
+  uint16_t header;
+  uint16_t reserved0;  // must be 0
+  uint32_t reserved1;  // must be 0
+  // Handles for dependent signaling objects to be evaluated by the packet
+  // processor. Any 0 (null) handles are ignored.
+  iree_hsa_signal_t dep_signal[5];
+  uint64_t reserved2;  // must be 0
+  // Signal to decrement when all dep_signals are satisfied.
+  iree_hsa_signal_t completion_signal;
+} iree_hsa_barrier_and_packet_t;
+
+// Barrier-OR (2.9.9 in the spec).
+// Waits until any one dep_signal reaches the value 0 and then decrements the
+// completion_signal. Ignores any 0 (null) signals.
+//
+// Pseudo-code:
+//   do {
+//     bool any_satisfied = false;
+//     for (int i = 0; i < 5; ++i) {
+//       if (iree_hsa_signal_load(dep_signal[i]) == 0) any_satisfied = true;
+//     }
+//     if (any_satisfied) break;
+//     iree_amdgpu_yield();
+//   } while(true);
+//   iree_hsa_signal_subtract(completion_signal, 1);
+//
+// The acquire fence is processed first in the completion phase of the packet
+// after the barrier condition has been met. The release fence is processed
+// after the acquire fence in the completion phase.
+typedef struct iree_hsa_barrier_or_packet_t {
+  // AQL packet header. See iree_hsa_packet_header_t for details.
+  uint16_t header;
+  uint16_t reserved0;  // must be 0
+  uint32_t reserved1;  // must be 0
+  // Handles for dependent signaling objects to be evaluated by the packet
+  // processor. Any 0 (null) handles are ignored.
+  iree_hsa_signal_t dep_signal[5];
+  uint64_t reserved2;  // must be 0
+  // Signal to decrement when any dep_signal is satisfied.
+  iree_hsa_signal_t completion_signal;
+} iree_hsa_barrier_or_packet_t;
+
+typedef enum {
+  // iree_hsa_amd_aql_pm4_ib_packet_t
+  IREE_HSA_AMD_AQL_FORMAT_PM4_IB = 1,
+  // iree_hsa_amd_barrier_value_packet_t
+  IREE_HSA_AMD_AQL_FORMAT_BARRIER_VALUE = 2,
+} iree_hsa_amd_aql_format_t;
+typedef uint8_t iree_hsa_amd_aql_format8_t;
+
+// Prefix of AMD-specific vendor packets.
+typedef struct iree_hsa_amd_vendor_packet_header_t {
+  // AQL packet header. See iree_hsa_packet_header_t for details.
+  uint16_t header;
+  // Secondary type indicating which AMD-specific packet this is.
+  iree_hsa_amd_aql_format8_t AmdFormat;
+  uint8_t reserved;  // must be 0
+} iree_hsa_amd_vendor_packet_header_t;
+
+// PM4 indirect-buffer extension.
+// Executes the PM4 indirect buffer referenced by ib_jump_cmd.
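+//
+// Pseudo-code (a sketch; the CP helper name is illustrative):
+//   cp_execute_pm4_indirect_buffer(ib_jump_cmd);
+//   iree_hsa_signal_subtract(completion_signal, 1);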
+typedef struct iree_hsa_amd_aql_pm4_ib_packet_t {
+  // AMD vendor-specific packet header.
+  iree_hsa_amd_vendor_packet_header_t header;
+  // PM4 INDIRECT_BUFFER packet words.
+  uint32_t ib_jump_cmd[4];
+  // Remaining dword count after the CP consumes the inline PM4 jump.
+  uint32_t dw_cnt_remain;
+  uint32_t reserved[8];  // must be 0
+  // Signal to decrement when the IB completes.
+  iree_hsa_signal_t completion_signal;
+} iree_hsa_amd_aql_pm4_ib_packet_t;
+IREE_AMDGPU_STATIC_ASSERT(sizeof(iree_hsa_amd_aql_pm4_ib_packet_t) == 64,
+                          "PM4-IB packet must be exactly one AQL slot");
+IREE_AMDGPU_STATIC_ASSERT(
+    IREE_AMDGPU_OFFSETOF(iree_hsa_amd_aql_pm4_ib_packet_t, completion_signal) ==
+        56,
+    "PM4-IB completion signal must match ROCR AQL layout");
+
+// Barrier value extension.
+// Halts packet processing and waits for `(signal_value & mask) cond value` to
+// be satisfied before decrementing the completion_signal.
+//
+// Pseudo-code:
+//   do {
+//     if (iree_hsa_evaluate_signal_condition(
+//         /*condition=*/cond,
+//         /*current_value=*/(iree_hsa_signal_load(signal) & mask),
+//         /*desired_value=*/value)) {
+//       break;
+//     }
+//     iree_amdgpu_yield();
+//   } while(true);
+//   iree_hsa_signal_subtract(completion_signal, 1);
+//
+// The acquire fence is processed first in the completion phase of the packet
+// after the barrier condition has been met. The release fence is processed
+// after the acquire fence in the completion phase.
+typedef struct iree_hsa_amd_barrier_value_packet_t {
+  // AMD vendor-specific packet header.
+  iree_hsa_amd_vendor_packet_header_t header;
+  uint32_t reserved0;  // must be 0
+  // Dependent signal object. A 0 (null) signal will be treated as satisfied.
+  iree_hsa_signal_t signal;
+  // Value the masked signal value is compared against (the mask is not
+  // applied to this value).
+  iree_hsa_signal_value_t value;
+  // Bitmask applied to the current signal value.
+  iree_hsa_signal_value_t mask;
+  // Comparison operation.
+  iree_hsa_signal_condition32_t cond;
+  uint32_t reserved1;  // must be 0
+  uint64_t reserved2;  // must be 0
+  uint64_t reserved3;  // must be 0
+  // Signal to decrement when the barrier condition is satisfied.
+  iree_hsa_signal_t completion_signal;
+} iree_hsa_amd_barrier_value_packet_t;
+
+#endif  // IREE_HAL_DRIVERS_AMDGPU_ABI_QUEUE_H_
diff --git a/runtime/src/iree/hal/drivers/amdgpu/abi/signal.h b/runtime/src/iree/hal/drivers/amdgpu/abi/signal.h
new file mode 100644
index 0000000..e5ccd22
--- /dev/null
+++ b/runtime/src/iree/hal/drivers/amdgpu/abi/signal.h
@@ -0,0 +1,180 @@
+// Copyright 2025 The IREE Authors
+//
+// Licensed under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+// HSA/AMDGPU signal struct layout as defined by the HSA spec and the AMD
+// hardware implementation. These are the types that appear in device-visible
+// memory and must have identical layout on both host and device.
+//
+// This header contains only type definitions, enums, and pure-logic helpers.
+// Device-side signal manipulation functions (atomic stores, mailbox pokes,
+// etc.) live in device/support/signal.h.
+//
+// Sources:
+// https://hsafoundation.com/wp-content/uploads/2021/02/HSA-SysArch-1.2.pdf
+// https://github.com/ROCm/ROCR-Runtime
+
+#ifndef IREE_HAL_DRIVERS_AMDGPU_ABI_SIGNAL_H_
+#define IREE_HAL_DRIVERS_AMDGPU_ABI_SIGNAL_H_
+
+#include "iree/hal/drivers/amdgpu/abi/common.h"
+
+typedef struct iree_amd_queue_t iree_amd_queue_t;
+
+//===----------------------------------------------------------------------===//
+// HSA/AMDGPU Signal
+//===----------------------------------------------------------------------===//
+
+#if defined(IREE_AMDGPU_TARGET_DEVICE)
+
+// "Opaque" reference to an iree_amd_signal_t*.
+// A value of 0 indicates a no-op signal (waits will succeed immediately and
+// completions will no-op).
+typedef struct iree_hsa_signal_t {
+  uint64_t handle;
+} iree_hsa_signal_t;
+
+#else
+
+typedef hsa_signal_t iree_hsa_signal_t;
+
+#endif  // IREE_AMDGPU_TARGET_DEVICE
+
+// No-op signal that will immediately succeed when waited on and be ignored when
+// signaling.
+#define iree_hsa_signal_null() (iree_hsa_signal_t){0}
+
+// Returns true if the given signal is null.
+#define iree_hsa_signal_is_null(signal) ((signal).handle == 0)
+
+// Value of a signal.
+// The interpretation of this is dependent on the operation consuming it.
+// With barrier value packets it's user-defined and can be any value.
+// With barrier-and/barrier-or and dispatch packets it acts as a semaphore where
+// a 0 value indicates set and a non-zero value indicates unset. For example,
+// if 3 operations are required to complete before another can proceed it should
+// be set to 3, included as the completion_signal for the 3 operations, and
+// used as the dependent signal in a barrier. As each operation completes it
+// will decrement the value and when it reaches 0 the barrier will succeed and
+// allow the dependent operation to execute.
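+//
+// Sketch of that 3-operation example (helper names illustrative):
+//   iree_hsa_signal_store(sem, 3);    // 3 completions outstanding
+//   op_a.completion_signal = sem;     // each completion decrements by 1
+//   op_b.completion_signal = sem;
+//   op_c.completion_signal = sem;
+//   barrier.dep_signal[0] = sem;      // satisfied once sem reaches 0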
+typedef int64_t iree_hsa_signal_value_t;
+
+// AMD signal kind.
+enum iree_amd_signal_kind_t {
+  // Unassigned (not seen).
+  IREE_AMD_SIGNAL_KIND_INVALID = 0,
+  // User-defined signal that supports all signal operations.
+  IREE_AMD_SIGNAL_KIND_USER = 1,
+  // Agent-defined doorbell (usually the queue's doorbell_signal field).
+  // Only writes are permitted from any agent other than the origin and for our
+  // purposes that means no reads (and thus no waits) ever. Soft queues created
+  // by the user must use IREE_AMD_SIGNAL_KIND_USER as this is reserved for
+  // hardware.
+  IREE_AMD_SIGNAL_KIND_DOORBELL = -1,
+};
+// Storage-width alias for iree_amd_signal_kind_t.
+typedef int64_t iree_amd_signal_kind64_t;
+
+// AMDGPU signal implementation.
+// This is an implementation detail from the perspective of the HSA spec but a
+// stable interface to the current generations of hardware implementing HSA.
+// Signals are just locations in memory and have no special behavior other than
+// how they are initialized. For our purposes there are two types: USER and
+// DOORBELL.
+//
+// Signal values depend on the producer/consumer operations. See
+// `iree_hsa_signal_value_t` for more information.
+//
+// Doorbell signals are firmware/hardware-specific and must only be written to
+// by the host and other agents (that means no waiting either, as that's a
+// read). Only the hardware queues as allocated by the HSA implementation should
+// set these.
+//
+// User signals as presented to the hardware via `iree_amd_signal_t` are like
+// futexes: allocating memory accessible to a set of agents and populating it
+// is enough to create and use the signal and (so long as it's not used
+// afterward) deleting it is just freeing the memory. Special behavior only
+// comes with host interaction: using any host HSA API (`hsa_signal_store_*`,
+// `hsa_signal_wait_*`, etc) is only possible with signals allocated via either
+// `hsa_signal_create` or `hsa_amd_signal_create` as those functions cast to an
+// internal ROCR `Signal` interface. If the signal will only ever be used by our
+// device code, the hardware queues, or our own host code not using the HSA APIs
+// then we don't need to use signals created by HSA. When we do need to interact
+// with the APIs the signals are implemented by two types: busy-wait and
+// interrupt (as implemented in ROCR by `BusyWaitSignal` and `InterruptSignal`).
+// Busy-wait signals are like a futex and _mostly_ exist entirely in user-mode.
+// Interrupt signals are the same but with an additional platform event handle
+// so that `hsaKmtWaitOnEvent` and other kernel-level waits can be performed.
+// For such
+// signals the platform event as returned by `hsaKmtCreateEvent` is stored in
+// the `event_mailbox_ptr` and the value to post is `event_id`. I suspect in
+// modern implementations that could be removed as they could be implemented
+// with a futex when in-process and then the full platform handles would be
+// reserved for IPC.
+//
+// Timestamps on the signal are set by the agent processing the operation.
+// `start_ts` is set when the packet enters the active phase and `end_ts` is set
+// when it completes. These timestamps are in agent-specific ticks and need to
+// be translated into the system domain by scaling by the relative frequencies
+// of the system and the particular agent; on the host
+// `hsa_amd_profiling_convert_tick_to_system_domain` handles the scaling. At
+// its core that method occasionally queries the base timestamps and
+// frequencies of the agents (as they may change over time) and the
+// resynchronization accounts for drift. In order to resolve timestamps fully
+// on-device we do the same thing by polling `AMDKFD_IOC_GET_CLOCK_COUNTERS`
+// and providing the result to the device runtime. Every time the clocks are
+// resynced there's the potential for discontinuities or backwards-rolling
+// timestamps, so we try to only resync once per submission to at least keep
+// all of the times within a submission relatively aligned even if the entire
+// submission may have drifted from the system by the end. Note that because
+// work can happen out-of-order the timestamps on a set of signals may be
+// out-of-order with respect to the system time once resolved; anything using
+// the timestamps needs to handle that or unset the CONCURRENT execution flag
+// on the queue.
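+//
+// A sketch of the tick-to-system conversion under those assumptions, with the
+// base pairs and frequencies captured at each resync:
+//   system_tick = system_base_tick +
+//       (agent_tick - agent_base_tick) * system_frequency / agent_frequency;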
+typedef struct IREE_AMDGPU_ALIGNAS(64) iree_amd_signal_t {
+  iree_amd_signal_kind64_t kind;
+  union {
+    volatile iree_hsa_signal_value_t value;
+    volatile uint64_t* hardware_doorbell_ptr;
+  };
+  uint64_t event_mailbox_ptr;
+  uint32_t event_id;
+  uint32_t reserved1;
+  iree_amdgpu_device_tick_t start_ts;
+  iree_amdgpu_device_tick_t end_ts;
+  iree_amd_queue_t* queue_ptr;
+  uint32_t reserved3[2];
+} iree_amd_signal_t;
+
+// Wait condition operation.
+typedef uint32_t iree_hsa_signal_condition32_t;
+typedef enum {
+  // The two operands are equal.
+  IREE_HSA_SIGNAL_CONDITION_EQ = 0,
+  // The two operands are not equal.
+  IREE_HSA_SIGNAL_CONDITION_NE = 1,
+  // The first operand is less than the second operand.
+  IREE_HSA_SIGNAL_CONDITION_LT = 2,
+  // The first operand is greater than or equal to the second operand.
+  IREE_HSA_SIGNAL_CONDITION_GTE = 3
+} iree_hsa_signal_condition_t;
+
+// Returns true if the given |current_value| matches the expected
+// |desired_value| as defined by |condition|.
+static IREE_AMDGPU_ATTRIBUTE_ALWAYS_INLINE inline bool
+iree_hsa_evaluate_signal_condition(iree_hsa_signal_condition32_t condition,
+                                   iree_hsa_signal_value_t current_value,
+                                   iree_hsa_signal_value_t desired_value) {
+  switch (condition) {
+    default:
+    case IREE_HSA_SIGNAL_CONDITION_EQ:
+      return current_value == desired_value;
+    case IREE_HSA_SIGNAL_CONDITION_NE:
+      return current_value != desired_value;
+    case IREE_HSA_SIGNAL_CONDITION_LT:
+      return current_value < desired_value;
+    case IREE_HSA_SIGNAL_CONDITION_GTE:
+      return current_value >= desired_value;
+  }
+}
+
+#endif  // IREE_HAL_DRIVERS_AMDGPU_ABI_SIGNAL_H_
diff --git a/runtime/src/iree/hal/drivers/amdgpu/abi/timestamp.h b/runtime/src/iree/hal/drivers/amdgpu/abi/timestamp.h
new file mode 100644
index 0000000..9c22e26
--- /dev/null
+++ b/runtime/src/iree/hal/drivers/amdgpu/abi/timestamp.h
@@ -0,0 +1,146 @@
+// Copyright 2026 The IREE Authors
+//
+// Licensed under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+
+#ifndef IREE_HAL_DRIVERS_AMDGPU_ABI_TIMESTAMP_H_
+#define IREE_HAL_DRIVERS_AMDGPU_ABI_TIMESTAMP_H_
+
+#include "iree/hal/drivers/amdgpu/abi/signal.h"
+
+#ifdef __cplusplus
+extern "C" {
+#endif  // __cplusplus
+
+//===----------------------------------------------------------------------===//
+// Timestamp Records
+//===----------------------------------------------------------------------===//
+
+enum {
+  // Version of the timestamp record ABI defined in this header.
+  IREE_HAL_AMDGPU_TIMESTAMP_RECORD_VERSION_0 = 0,
+};
+
+// Timestamp record types.
+typedef enum iree_hal_amdgpu_timestamp_record_type_e {
+  IREE_HAL_AMDGPU_TIMESTAMP_RECORD_TYPE_NONE = 0,
+  IREE_HAL_AMDGPU_TIMESTAMP_RECORD_TYPE_COMMAND_BUFFER = 1,
+  IREE_HAL_AMDGPU_TIMESTAMP_RECORD_TYPE_DISPATCH = 2,
+} iree_hal_amdgpu_timestamp_record_type_t;
+
+// Timestamp tick range written by PM4 packets or device harvest kernels.
+typedef struct iree_hal_amdgpu_timestamp_range_t {
+  // Agent-specific tick captured when the range started.
+  iree_amdgpu_device_tick_t start_tick;
+  // Agent-specific tick captured when the range completed.
+  iree_amdgpu_device_tick_t end_tick;
+} iree_hal_amdgpu_timestamp_range_t;
+IREE_AMDGPU_STATIC_ASSERT(sizeof(iree_hal_amdgpu_timestamp_range_t) == 16,
+                          "timestamp range size is part of the ABI");
+
+// Header common to every fixed binary timestamp record.
+typedef struct IREE_AMDGPU_ALIGNAS(8)
+    iree_hal_amdgpu_timestamp_record_header_t {
+  // Size of the enclosing record in bytes.
+  uint32_t record_length;
+  // ABI version from IREE_HAL_AMDGPU_TIMESTAMP_RECORD_VERSION_*.
+  uint16_t version;
+  // Record type from iree_hal_amdgpu_timestamp_record_type_t.
+  uint16_t type;
+  // Producer-defined ordinal within the record stream for this type.
+  uint32_t record_ordinal;
+  // Reserved bits that must be zero.
+  uint32_t reserved0;
+} iree_hal_amdgpu_timestamp_record_header_t;
+IREE_AMDGPU_STATIC_ASSERT(sizeof(iree_hal_amdgpu_timestamp_record_header_t) ==
+                              16,
+                          "timestamp record header size is part of the ABI");
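+
+// Sketch of walking a record stream via the self-describing header (a
+// consumer-side illustration; `stream_base`/`stream_length` are assumptions):
+//   const uint8_t* ptr = stream_base;
+//   while (ptr < stream_base + stream_length) {
+//     const iree_hal_amdgpu_timestamp_record_header_t* header =
+//         (const iree_hal_amdgpu_timestamp_record_header_t*)ptr;
+//     // Dispatch on header->type, then advance by the embedded length.
+//     ptr += header->record_length;
+//   }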
+
+// Device-written command-buffer execution timestamp record.
+typedef struct IREE_AMDGPU_ALIGNAS(8)
+    iree_hal_amdgpu_command_buffer_timestamp_record_t {
+  // Common timestamp record header. Its record ordinal is the command-buffer
+  // timestamp record ordinal.
+  iree_hal_amdgpu_timestamp_record_header_t header;
+  // Producer-defined command-buffer identifier used for correlation.
+  uint64_t command_buffer_id;
+  // Command-buffer block ordinal when this record describes a block, or
+  // UINT32_MAX when this record describes the whole queue execute.
+  uint32_t block_ordinal;
+  // Reserved bits that must be zero.
+  uint32_t reserved0;
+  // Device tick range captured for the command-buffer execution.
+  iree_hal_amdgpu_timestamp_range_t ticks;
+} iree_hal_amdgpu_command_buffer_timestamp_record_t;
+IREE_AMDGPU_STATIC_ASSERT(
+    sizeof(iree_hal_amdgpu_command_buffer_timestamp_record_t) == 48,
+    "command-buffer timestamp record size is part of the ABI");
+
+// Dispatch timestamp record flags.
+typedef uint32_t iree_hal_amdgpu_dispatch_timestamp_record_flags_t;
+enum iree_hal_amdgpu_dispatch_timestamp_record_flag_bits_t {
+  IREE_HAL_AMDGPU_DISPATCH_TIMESTAMP_RECORD_FLAG_NONE = 0u,
+  // Workgroup counts were read from device memory before dispatch.
+  IREE_HAL_AMDGPU_DISPATCH_TIMESTAMP_RECORD_FLAG_INDIRECT_PARAMETERS = 1u << 0,
+};
+
+// Device-written dispatch execution timestamp record.
+typedef struct IREE_AMDGPU_ALIGNAS(8)
+    iree_hal_amdgpu_dispatch_timestamp_record_t {
+  // Common timestamp record header. Its record ordinal is the dispatch
+  // timestamp record ordinal.
+  iree_hal_amdgpu_timestamp_record_header_t header;
+  // Producer-defined command-buffer identifier, or 0 for direct dispatch.
+  uint64_t command_buffer_id;
+  // Producer-defined executable identifier, or 0 when unavailable.
+  uint64_t executable_id;
+  // Command-buffer block ordinal containing the dispatch, or UINT32_MAX for a
+  // direct dispatch.
+  uint32_t block_ordinal;
+  // Program-global command index, or UINT32_MAX for a direct dispatch.
+  uint32_t command_index;
+  // Executable export ordinal dispatched.
+  uint32_t export_ordinal;
+  // Flags from iree_hal_amdgpu_dispatch_timestamp_record_flag_bits_t.
+  iree_hal_amdgpu_dispatch_timestamp_record_flags_t flags;
+  // Device tick range captured for the dispatch execution.
+  iree_hal_amdgpu_timestamp_range_t ticks;
+} iree_hal_amdgpu_dispatch_timestamp_record_t;
+IREE_AMDGPU_STATIC_ASSERT(sizeof(iree_hal_amdgpu_dispatch_timestamp_record_t) ==
+                              64,
+                          "dispatch timestamp record size is part of the ABI");
+
+//===----------------------------------------------------------------------===//
+// Dispatch Timestamp Harvest ABI
+//===----------------------------------------------------------------------===//
+
+// One device-side dispatch timestamp harvest source.
+typedef struct iree_hal_amdgpu_dispatch_timestamp_harvest_source_t {
+  // Raw AMD completion signal populated by the CP for the timestamped dispatch.
+  const iree_amd_signal_t* completion_signal;
+  // Timestamp range receiving copied completion-signal ticks.
+  iree_hal_amdgpu_timestamp_range_t* ticks;
+} iree_hal_amdgpu_dispatch_timestamp_harvest_source_t;
+IREE_AMDGPU_STATIC_ASSERT(
+    sizeof(iree_hal_amdgpu_dispatch_timestamp_harvest_source_t) == 16,
+    "dispatch timestamp harvest source size is part of the ABI");
+
+// Kernel arguments for the dispatch timestamp harvest builtin.
+typedef struct iree_hal_amdgpu_dispatch_timestamp_harvest_args_t {
+  // Source table with one entry per timestamped dispatch.
+  const iree_hal_amdgpu_dispatch_timestamp_harvest_source_t* sources;
+  // Number of entries in |sources|.
+  uint32_t source_count;
+  // Reserved padding that must be zero.
+  uint32_t reserved0;
+} iree_hal_amdgpu_dispatch_timestamp_harvest_args_t;
+IREE_AMDGPU_STATIC_ASSERT(
+    sizeof(iree_hal_amdgpu_dispatch_timestamp_harvest_args_t) == 16,
+    "dispatch timestamp harvest args size is part of the ABI");
+
+#ifdef __cplusplus
+}  // extern "C"
+#endif  // __cplusplus
+
+#endif  // IREE_HAL_DRIVERS_AMDGPU_ABI_TIMESTAMP_H_
diff --git a/runtime/src/iree/hal/drivers/amdgpu/access_policy.c b/runtime/src/iree/hal/drivers/amdgpu/access_policy.c
new file mode 100644
index 0000000..48b80aa
--- /dev/null
+++ b/runtime/src/iree/hal/drivers/amdgpu/access_policy.c
@@ -0,0 +1,104 @@
+// Copyright 2026 The IREE Authors
+//
+// Licensed under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+
+#include "iree/hal/drivers/amdgpu/access_policy.h"
+
+static bool iree_hal_amdgpu_access_agent_list_contains(
+    const iree_hal_amdgpu_access_agent_list_t* agent_list, hsa_agent_t agent) {
+  for (uint32_t i = 0; i < agent_list->count; ++i) {
+    if (agent_list->values[i].handle == agent.handle) return true;
+  }
+  return false;
+}
+
+static iree_status_t iree_hal_amdgpu_access_agent_list_append_unique(
+    iree_hal_amdgpu_access_agent_list_t* agent_list, hsa_agent_t agent) {
+  if (iree_hal_amdgpu_access_agent_list_contains(agent_list, agent)) {
+    return iree_ok_status();
+  }
+  if (IREE_UNLIKELY(agent_list->count >= IREE_ARRAYSIZE(agent_list->values))) {
+    return iree_make_status(IREE_STATUS_OUT_OF_RANGE,
+                            "AMDGPU access agent list capacity exceeded");
+  }
+  agent_list->values[agent_list->count++] = agent;
+  return iree_ok_status();
+}
+
+iree_status_t iree_hal_amdgpu_access_agent_list_resolve(
+    const iree_hal_amdgpu_topology_t* topology,
+    iree_hal_amdgpu_queue_affinity_domain_t queue_affinity_domain,
+    iree_hal_queue_affinity_t queue_affinity,
+    iree_hal_amdgpu_access_agent_list_t* out_agent_list) {
+  IREE_ASSERT_ARGUMENT(topology);
+  IREE_ASSERT_ARGUMENT(out_agent_list);
+  memset(out_agent_list, 0, sizeof(*out_agent_list));
+
+  if (IREE_UNLIKELY(queue_affinity_domain.physical_device_count >
+                    topology->gpu_agent_count)) {
+    return iree_make_status(IREE_STATUS_OUT_OF_RANGE,
+                            "AMDGPU queue affinity domain device count %" PRIhsz
+                            " exceeds topology GPU agent count %" PRIhsz,
+                            queue_affinity_domain.physical_device_count,
+                            topology->gpu_agent_count);
+  }
+
+  iree_hal_amdgpu_queue_affinity_physical_device_set_t physical_device_set;
+  IREE_RETURN_IF_ERROR(iree_hal_amdgpu_queue_affinity_select_physical_devices(
+      queue_affinity_domain, queue_affinity, &physical_device_set));
+
+  iree_status_t status = iree_ok_status();
+  for (iree_host_size_t physical_device_ordinal = 0;
+       physical_device_ordinal < topology->gpu_agent_count &&
+       iree_status_is_ok(status);
+       ++physical_device_ordinal) {
+    if (!iree_all_bits_set(physical_device_set.physical_device_mask,
+                           ((uint64_t)1) << physical_device_ordinal)) {
+      continue;
+    }
+    const iree_host_size_t cpu_agent_ordinal =
+        topology->gpu_cpu_map[physical_device_ordinal];
+    if (IREE_UNLIKELY(cpu_agent_ordinal >= topology->cpu_agent_count)) {
+      status =
+          iree_make_status(IREE_STATUS_OUT_OF_RANGE,
+                           "AMDGPU topology maps GPU agent ordinal %" PRIhsz
+                           " to invalid CPU agent ordinal %" PRIhsz,
+                           physical_device_ordinal, cpu_agent_ordinal);
+      break;
+    }
+
+    status = iree_hal_amdgpu_access_agent_list_append_unique(
+        out_agent_list, topology->gpu_agents[physical_device_ordinal]);
+    if (iree_status_is_ok(status)) {
+      status = iree_hal_amdgpu_access_agent_list_append_unique(
+          out_agent_list, topology->cpu_agents[cpu_agent_ordinal]);
+    }
+  }
+  return status;
+}
+
+iree_status_t iree_hal_amdgpu_access_allow_agent_list(
+    const iree_hal_amdgpu_libhsa_t* libhsa,
+    const iree_hal_amdgpu_access_agent_list_t* agent_list, const void* ptr) {
+  IREE_ASSERT_ARGUMENT(libhsa);
+  IREE_ASSERT_ARGUMENT(agent_list);
+  IREE_ASSERT_ARGUMENT(ptr);
+  return iree_hsa_amd_agents_allow_access(IREE_LIBHSA(libhsa),
+                                          agent_list->count, agent_list->values,
+                                          /*flags=*/NULL, ptr);
+}
+
+iree_status_t iree_hal_amdgpu_access_lock_host_allocation(
+    const iree_hal_amdgpu_libhsa_t* libhsa,
+    const iree_hal_amdgpu_access_agent_list_t* agent_list, void* host_ptr,
+    iree_device_size_t length, void** out_agent_ptr) {
+  IREE_ASSERT_ARGUMENT(libhsa);
+  IREE_ASSERT_ARGUMENT(agent_list);
+  IREE_ASSERT_ARGUMENT(host_ptr);
+  IREE_ASSERT_ARGUMENT(out_agent_ptr);
+  return iree_hsa_amd_memory_lock(IREE_LIBHSA(libhsa), host_ptr, (size_t)length,
+                                  (hsa_agent_t*)agent_list->values,
+                                  (int)agent_list->count, out_agent_ptr);
+}
diff --git a/runtime/src/iree/hal/drivers/amdgpu/access_policy.h b/runtime/src/iree/hal/drivers/amdgpu/access_policy.h
new file mode 100644
index 0000000..6608d81
--- /dev/null
+++ b/runtime/src/iree/hal/drivers/amdgpu/access_policy.h
@@ -0,0 +1,67 @@
+// Copyright 2026 The IREE Authors
+//
+// Licensed under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+
+#ifndef IREE_HAL_DRIVERS_AMDGPU_ACCESS_POLICY_H_
+#define IREE_HAL_DRIVERS_AMDGPU_ACCESS_POLICY_H_
+
+#include "iree/base/api.h"
+#include "iree/hal/api.h"
+#include "iree/hal/drivers/amdgpu/queue_affinity.h"
+#include "iree/hal/drivers/amdgpu/util/libhsa.h"
+#include "iree/hal/drivers/amdgpu/util/topology.h"
+
+#ifdef __cplusplus
+extern "C" {
+#endif  // __cplusplus
+
+//===----------------------------------------------------------------------===//
+// iree_hal_amdgpu_access_agent_list_t
+//===----------------------------------------------------------------------===//
+
+// Fixed-capacity list of HSA agents that should be granted access to memory.
+//
+// HSA access calls are cold allocation/pinning paths. Keeping this list in
+// caller storage avoids heap allocations while still making access policy an
+// explicit, inspectable decision instead of ad-hoc all-agent grants.
+typedef struct iree_hal_amdgpu_access_agent_list_t {
+  // Number of initialized entries in |values|.
+  uint32_t count;
+
+  // HSA agents to pass to hsa_amd_agents_allow_access or hsa_amd_memory_lock.
+  hsa_agent_t
+      values[IREE_HAL_AMDGPU_MAX_CPU_AGENT + IREE_HAL_AMDGPU_MAX_GPU_AGENT];
+} iree_hal_amdgpu_access_agent_list_t;
+
+// Resolves the HSA agents that may access memory placed for |queue_affinity|.
+//
+// The resulting list contains each selected GPU agent and its nearest CPU agent
+// from |topology|. IREE_HAL_QUEUE_AFFINITY_ANY therefore grants the entire
+// logical device topology, while physical-device-local affinities stay scoped
+// to that physical device. Sharing usage bits define how queues may share the
+// buffer within the requested placement; they do not expand the placement past
+// |queue_affinity|.
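+//
+// Example usage when placing a new device allocation (a sketch; error handling
+// elided):
+//   iree_hal_amdgpu_access_agent_list_t agent_list;
+//   IREE_RETURN_IF_ERROR(iree_hal_amdgpu_access_agent_list_resolve(
+//       topology, queue_affinity_domain, params->queue_affinity,
+//       &agent_list));
+//   IREE_RETURN_IF_ERROR(iree_hal_amdgpu_access_allow_agent_list(
+//       libhsa, &agent_list, device_ptr));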
+iree_status_t iree_hal_amdgpu_access_agent_list_resolve(
+    const iree_hal_amdgpu_topology_t* topology,
+    iree_hal_amdgpu_queue_affinity_domain_t queue_affinity_domain,
+    iree_hal_queue_affinity_t queue_affinity,
+    iree_hal_amdgpu_access_agent_list_t* out_agent_list);
+
+// Grants |agent_list| access to an HSA memory pool allocation.
+iree_status_t iree_hal_amdgpu_access_allow_agent_list(
+    const iree_hal_amdgpu_libhsa_t* libhsa,
+    const iree_hal_amdgpu_access_agent_list_t* agent_list, const void* ptr);
+
+// Pins |host_ptr| for the HSA agents in |agent_list|.
+iree_status_t iree_hal_amdgpu_access_lock_host_allocation(
+    const iree_hal_amdgpu_libhsa_t* libhsa,
+    const iree_hal_amdgpu_access_agent_list_t* agent_list, void* host_ptr,
+    iree_device_size_t length, void** out_agent_ptr);
+
+#ifdef __cplusplus
+}  // extern "C"
+#endif  // __cplusplus
+
+#endif  // IREE_HAL_DRIVERS_AMDGPU_ACCESS_POLICY_H_
diff --git a/runtime/src/iree/hal/drivers/amdgpu/access_policy_test.cc b/runtime/src/iree/hal/drivers/amdgpu/access_policy_test.cc
new file mode 100644
index 0000000..c675a36
--- /dev/null
+++ b/runtime/src/iree/hal/drivers/amdgpu/access_policy_test.cc
@@ -0,0 +1,107 @@
+// Copyright 2026 The IREE Authors
+//
+// Licensed under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+
+#include "iree/hal/drivers/amdgpu/access_policy.h"
+
+#include "iree/testing/gtest.h"
+#include "iree/testing/status_matchers.h"
+
+namespace iree::hal::amdgpu {
+namespace {
+
+static hsa_agent_t MakeAgent(uint64_t handle) { return hsa_agent_t{handle}; }
+
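+// Builds a synthetic topology: 3 GPUs with 2 queues each mapped onto 2 CPUs
+// (GPUs 0-1 nearest CPU 0, GPU 2 nearest CPU 1). Agent handles are arbitrary
+// distinct values; no live HSA runtime is required.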
+static iree_hal_amdgpu_topology_t MakeThreeGpuTopology() {
+  iree_hal_amdgpu_topology_t topology;
+  iree_hal_amdgpu_topology_initialize(&topology);
+  topology.cpu_agent_count = 2;
+  topology.cpu_agents[0] = MakeAgent(100);
+  topology.cpu_agents[1] = MakeAgent(101);
+  topology.gpu_agent_count = 3;
+  topology.gpu_agents[0] = MakeAgent(200);
+  topology.gpu_agents[1] = MakeAgent(201);
+  topology.gpu_agents[2] = MakeAgent(202);
+  topology.gpu_agent_queue_count = 2;
+  topology.gpu_cpu_map[0] = 0;
+  topology.gpu_cpu_map[1] = 0;
+  topology.gpu_cpu_map[2] = 1;
+  topology.all_agent_count = 5;
+  topology.all_agents[0] = topology.cpu_agents[0];
+  topology.all_agents[1] = topology.cpu_agents[1];
+  topology.all_agents[2] = topology.gpu_agents[0];
+  topology.all_agents[3] = topology.gpu_agents[1];
+  topology.all_agents[4] = topology.gpu_agents[2];
+  return topology;
+}
+
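+// Affinity domain spanning all 6 queues (3 physical devices x 2 queues each),
+// so queue bits 2i and 2i+1 select physical device i.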
+static iree_hal_amdgpu_queue_affinity_domain_t ThreeGpuDomain() {
+  return (iree_hal_amdgpu_queue_affinity_domain_t){
+      .supported_affinity = 0x3Full,
+      .physical_device_count = 3,
+      .queue_count_per_physical_device = 2,
+  };
+}
+
+static bool AgentListContains(
+    const iree_hal_amdgpu_access_agent_list_t& agent_list, hsa_agent_t agent) {
+  for (uint32_t i = 0; i < agent_list.count; ++i) {
+    if (agent_list.values[i].handle == agent.handle) return true;
+  }
+  return false;
+}
+
+TEST(AccessPolicyTest, AnySelectsLogicalTopologyAgents) {
+  iree_hal_amdgpu_topology_t topology = MakeThreeGpuTopology();
+
+  iree_hal_amdgpu_access_agent_list_t agent_list;
+  IREE_ASSERT_OK(iree_hal_amdgpu_access_agent_list_resolve(
+      &topology, ThreeGpuDomain(), IREE_HAL_QUEUE_AFFINITY_ANY, &agent_list));
+
+  EXPECT_EQ(agent_list.count, 5u);
+  EXPECT_TRUE(AgentListContains(agent_list, topology.cpu_agents[0]));
+  EXPECT_TRUE(AgentListContains(agent_list, topology.cpu_agents[1]));
+  EXPECT_TRUE(AgentListContains(agent_list, topology.gpu_agents[0]));
+  EXPECT_TRUE(AgentListContains(agent_list, topology.gpu_agents[1]));
+  EXPECT_TRUE(AgentListContains(agent_list, topology.gpu_agents[2]));
+}
+
+TEST(AccessPolicyTest, PhysicalDeviceAffinitySelectsGpuAndNearestCpu) {
+  iree_hal_amdgpu_topology_t topology = MakeThreeGpuTopology();
+
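+  // 0xC sets queue bits 2 and 3, which both map to physical device 1.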
+  iree_hal_amdgpu_access_agent_list_t agent_list;
+  IREE_ASSERT_OK(iree_hal_amdgpu_access_agent_list_resolve(
+      &topology, ThreeGpuDomain(), 0xCull, &agent_list));
+
+  EXPECT_EQ(agent_list.count, 2u);
+  EXPECT_TRUE(AgentListContains(agent_list, topology.cpu_agents[0]));
+  EXPECT_TRUE(AgentListContains(agent_list, topology.gpu_agents[1]));
+}
+
+TEST(AccessPolicyTest, CrossDeviceAffinityDeduplicatesCpuAgents) {
+  iree_hal_amdgpu_topology_t topology = MakeThreeGpuTopology();
+
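+  // 0x5 sets queue bits 0 and 2 (physical devices 0 and 1); both map to CPU 0,
+  // so the shared CPU agent must appear only once in the resolved list.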
+  iree_hal_amdgpu_access_agent_list_t agent_list;
+  IREE_ASSERT_OK(iree_hal_amdgpu_access_agent_list_resolve(
+      &topology, ThreeGpuDomain(), 0x5ull, &agent_list));
+
+  EXPECT_EQ(agent_list.count, 3u);
+  EXPECT_TRUE(AgentListContains(agent_list, topology.cpu_agents[0]));
+  EXPECT_TRUE(AgentListContains(agent_list, topology.gpu_agents[0]));
+  EXPECT_TRUE(AgentListContains(agent_list, topology.gpu_agents[1]));
+}
+
+TEST(AccessPolicyTest, RejectsInvalidGpuCpuMap) {
+  iree_hal_amdgpu_topology_t topology = MakeThreeGpuTopology();
+  topology.gpu_cpu_map[1] = 2;
+
+  iree_hal_amdgpu_access_agent_list_t agent_list;
+  IREE_EXPECT_STATUS_IS(IREE_STATUS_OUT_OF_RANGE,
+                        iree_hal_amdgpu_access_agent_list_resolve(
+                            &topology, ThreeGpuDomain(), 0x4ull, &agent_list));
+}
+
+}  // namespace
+}  // namespace iree::hal::amdgpu
diff --git a/runtime/src/iree/hal/drivers/amdgpu/allocator.c b/runtime/src/iree/hal/drivers/amdgpu/allocator.c
index 60a5a31..74f0ed3 100644
--- a/runtime/src/iree/hal/drivers/amdgpu/allocator.c
+++ b/runtime/src/iree/hal/drivers/amdgpu/allocator.c
@@ -1,4 +1,4 @@
-// Copyright 2025 The IREE Authors
+// Copyright 2026 The IREE Authors
 //
 // Licensed under the Apache License v2.0 with LLVM Exceptions.
 // See https://llvm.org/LICENSE.txt for license information.
@@ -6,31 +6,120 @@
 
 #include "iree/hal/drivers/amdgpu/allocator.h"
 
+#include "iree/hal/drivers/amdgpu/access_policy.h"
 #include "iree/hal/drivers/amdgpu/buffer.h"
+#include "iree/hal/drivers/amdgpu/logical_device.h"
+#include "iree/hal/drivers/amdgpu/queue_affinity.h"
+#include "iree/hal/drivers/amdgpu/system.h"
 #include "iree/hal/drivers/amdgpu/util/topology.h"
+#include "iree/hal/drivers/amdgpu/util/vmem.h"
 
 //===----------------------------------------------------------------------===//
 // iree_hal_amdgpu_allocator_t
 //===----------------------------------------------------------------------===//
 
-// TODO(benvanik): use one ID per address space or pool - each shows as a
-// different track in tracing tools.
 #if IREE_TRACING_FEATURES & IREE_TRACING_FEATURE_ALLOCATION_TRACKING
-static const char* IREE_HAL_AMDGPU_ALLOCATOR_ID = "AMDGPU unpooled";
+static const char* IREE_HAL_AMDGPU_ALLOCATOR_ID = "iree-hal-amdgpu-unpooled";
 #endif  // IREE_TRACING_FEATURE_ALLOCATION_TRACKING
 
+typedef struct iree_hal_amdgpu_allocator_memory_pool_t {
+  // HSA memory pool used for allocations.
+  hsa_amd_memory_pool_t memory_pool;
+
+  // Allocation sizes submitted to HSA are rounded up to this granule.
+  iree_device_size_t allocation_granule;
+
+  // Base-pointer alignment guaranteed by HSA allocations from |memory_pool|.
+  iree_device_size_t allocation_alignment;
+
+  // Maximum single HSA allocation size supported by |memory_pool|.
+  iree_device_size_t max_allocation_size;
+} iree_hal_amdgpu_allocator_memory_pool_t;
+
+typedef struct iree_hal_amdgpu_allocator_memory_pools_t {
+  // Coarse-grained GPU-local pools used for default device-local allocations.
+  iree_hal_amdgpu_allocator_memory_pool_t
+      device_coarse[IREE_HAL_AMDGPU_MAX_GPU_AGENT];
+
+  // Fine-grained GPU-local pools used only for explicit host-visible requests.
+  iree_hal_amdgpu_allocator_memory_pool_t
+      device_fine[IREE_HAL_AMDGPU_MAX_GPU_AGENT];
+
+  // Fine-grained host-local pools nearest to each GPU.
+  iree_hal_amdgpu_allocator_memory_pool_t
+      host_fine[IREE_HAL_AMDGPU_MAX_GPU_AGENT];
+} iree_hal_amdgpu_allocator_memory_pools_t;
+
 typedef struct iree_hal_amdgpu_allocator_t {
+  // HAL resource header for allocator lifetime management.
   iree_hal_resource_t resource;
+
+  // Host allocator used for allocator-owned bookkeeping.
   iree_allocator_t host_allocator;
 
+  // Unowned logical device. Must outlive the allocator.
+  iree_hal_amdgpu_logical_device_t* logical_device;
+
   // Unowned libhsa handle. Must be retained by the owner.
   const iree_hal_amdgpu_libhsa_t* libhsa;
-  // Topology with all CPU and GPU agents.
+
+  // Unowned topology used to resolve queue affinity to a physical device.
   const iree_hal_amdgpu_topology_t* topology;
 
-  IREE_STATISTICS(iree_hal_allocator_statistics_t statistics;)
+  // Cached HSA memory pool properties used for placement and heap queries.
+  iree_hal_amdgpu_allocator_memory_pools_t memory_pools;
+
+  IREE_STATISTICS(
+      // Aggregate allocation statistics reported through the HAL allocator API.
+      iree_hal_allocator_statistics_t statistics;)
 } iree_hal_amdgpu_allocator_t;
 
+typedef struct iree_hal_amdgpu_allocator_placement_t {
+  // HSA memory pool selected for the allocation.
+  const iree_hal_amdgpu_allocator_memory_pool_t* memory_pool;
+
+  // Physical device ordinal owning |memory_pool|.
+  uint32_t physical_device_ordinal;
+
+  // Resolved HAL memory type exposed by the created buffer.
+  iree_hal_memory_type_t memory_type;
+} iree_hal_amdgpu_allocator_placement_t;
+
+typedef struct iree_hal_amdgpu_imported_host_release_data_t {
+  // Unowned libhsa handle used to unlock the imported host allocation.
+  const iree_hal_amdgpu_libhsa_t* libhsa;
+
+  // Unowned logical device used to record the paired unimport event.
+  iree_hal_device_t* profile_device;
+
+  // Original host allocation pointer passed to hsa_amd_memory_lock.
+  void* host_ptr;
+
+  // Length of the imported host allocation in bytes.
+  iree_device_size_t length;
+
+  // HAL memory type bits used to expose the imported buffer.
+  iree_hal_memory_type_t memory_type;
+
+  // HAL buffer usage bits used to expose the imported buffer.
+  iree_hal_buffer_usage_t buffer_usage;
+
+  // Profiling session id owning |profile_allocation_id|.
+  uint64_t profile_session_id;
+
+  // Session-local allocation id for the import/unimport lifecycle.
+  uint64_t profile_allocation_id;
+
+  // Session-local physical device ordinal attributed to this import.
+  uint32_t profile_physical_device_ordinal;
+
+  // Host allocator used to release this thunk after buffer destruction.
+  iree_allocator_t host_allocator;
+
+  // Optional caller callback invoked after HSA has unlocked the host memory.
+  iree_hal_buffer_release_callback_t caller_release_callback;
+} iree_hal_amdgpu_imported_host_release_data_t;
+
 static const iree_hal_allocator_vtable_t iree_hal_amdgpu_allocator_vtable;
 
 static iree_hal_amdgpu_allocator_t* iree_hal_amdgpu_allocator_cast(
@@ -39,31 +128,249 @@
   return (iree_hal_amdgpu_allocator_t*)base_value;
 }
 
+static iree_status_t iree_hal_amdgpu_allocator_query_pool_properties(
+    const iree_hal_amdgpu_libhsa_t* libhsa, hsa_amd_memory_pool_t memory_pool,
+    iree_hal_amdgpu_allocator_memory_pool_t* out_pool) {
+  memset(out_pool, 0, sizeof(*out_pool));
+
+  bool allocation_allowed = false;
+  IREE_RETURN_IF_ERROR(iree_hsa_amd_memory_pool_get_info(
+      IREE_LIBHSA(libhsa), memory_pool,
+      HSA_AMD_MEMORY_POOL_INFO_RUNTIME_ALLOC_ALLOWED, &allocation_allowed));
+  if (!allocation_allowed) {
+    return iree_make_status(
+        IREE_STATUS_INVALID_ARGUMENT,
+        "AMDGPU allocator memory pool does not support runtime allocations");
+  }
+
+  size_t allocation_granule = 0;
+  IREE_RETURN_IF_ERROR(iree_hsa_amd_memory_pool_get_info(
+      IREE_LIBHSA(libhsa), memory_pool,
+      HSA_AMD_MEMORY_POOL_INFO_RUNTIME_ALLOC_GRANULE, &allocation_granule));
+  if (allocation_granule == 0 ||
+      !iree_device_size_is_power_of_two(allocation_granule)) {
+    return iree_make_status(
+        IREE_STATUS_INTERNAL,
+        "invalid HSA runtime allocation granule for an AMDGPU memory pool: "
+        "%" PRIhsz,
+        (iree_host_size_t)allocation_granule);
+  }
+
+  size_t allocation_alignment = 0;
+  IREE_RETURN_IF_ERROR(iree_hsa_amd_memory_pool_get_info(
+      IREE_LIBHSA(libhsa), memory_pool,
+      HSA_AMD_MEMORY_POOL_INFO_RUNTIME_ALLOC_ALIGNMENT, &allocation_alignment));
+  if (allocation_alignment == 0 ||
+      !iree_device_size_is_power_of_two(allocation_alignment)) {
+    return iree_make_status(
+        IREE_STATUS_INTERNAL,
+        "invalid HSA runtime allocation alignment for an AMDGPU memory pool: "
+        "%" PRIhsz,
+        (iree_host_size_t)allocation_alignment);
+  }
+
+  size_t max_allocation_size = 0;
+  IREE_RETURN_IF_ERROR(iree_hsa_amd_memory_pool_get_info(
+      IREE_LIBHSA(libhsa), memory_pool, HSA_AMD_MEMORY_POOL_INFO_ALLOC_MAX_SIZE,
+      &max_allocation_size));
+  if (max_allocation_size == 0) {
+    return iree_make_status(
+        IREE_STATUS_INTERNAL,
+        "invalid HSA max allocation size for an AMDGPU memory pool");
+  }
+
+  out_pool->memory_pool = memory_pool;
+  out_pool->allocation_granule = (iree_device_size_t)allocation_granule;
+  out_pool->allocation_alignment = (iree_device_size_t)allocation_alignment;
+  out_pool->max_allocation_size = (iree_device_size_t)max_allocation_size;
+  return iree_ok_status();
+}
+
+static iree_hal_amdgpu_queue_affinity_domain_t
+iree_hal_amdgpu_allocator_queue_affinity_domain(
+    const iree_hal_amdgpu_allocator_t* allocator) {
+  return (iree_hal_amdgpu_queue_affinity_domain_t){
+      .supported_affinity = allocator->logical_device->queue_affinity_mask,
+      .physical_device_count = allocator->topology->gpu_agent_count,
+      .queue_count_per_physical_device =
+          allocator->topology->gpu_agent_queue_count,
+  };
+}
+
+static bool iree_hal_amdgpu_allocator_select_device_ordinal(
+    const iree_hal_amdgpu_allocator_t* allocator,
+    iree_hal_queue_affinity_t queue_affinity,
+    iree_hal_queue_affinity_t* out_queue_affinity,
+    iree_host_size_t* out_device_ordinal) {
+  const iree_hal_amdgpu_queue_affinity_domain_t domain =
+      iree_hal_amdgpu_allocator_queue_affinity_domain(allocator);
+  iree_hal_amdgpu_queue_affinity_resolved_t resolved;
+  if (!iree_hal_amdgpu_queue_affinity_try_resolve(domain, queue_affinity,
+                                                  &resolved)) {
+    return false;
+  }
+  if (out_queue_affinity) {
+    *out_queue_affinity = resolved.queue_affinity;
+  }
+  *out_device_ordinal = resolved.physical_device_ordinal;
+  return true;
+}
+
+static iree_status_t iree_hal_amdgpu_allocator_resolve_access_agents(
+    const iree_hal_amdgpu_allocator_t* allocator,
+    iree_hal_queue_affinity_t queue_affinity,
+    iree_hal_amdgpu_access_agent_list_t* out_agent_list) {
+  return iree_hal_amdgpu_access_agent_list_resolve(
+      allocator->topology,
+      iree_hal_amdgpu_allocator_queue_affinity_domain(allocator),
+      queue_affinity, out_agent_list);
+}
+
+static bool iree_hal_amdgpu_allocator_resolve_placement(
+    iree_hal_amdgpu_allocator_t* allocator, iree_hal_buffer_params_t* params,
+    iree_hal_amdgpu_allocator_placement_t* out_placement) {
+  memset(out_placement, 0, sizeof(*out_placement));
+
+  iree_host_size_t device_ordinal = 0;
+  iree_hal_queue_affinity_t queue_affinity = 0;
+  if (!iree_hal_amdgpu_allocator_select_device_ordinal(
+          allocator, params->queue_affinity, &queue_affinity,
+          &device_ordinal)) {
+    return false;
+  }
+  params->queue_affinity = queue_affinity;
+
+  const iree_hal_memory_type_t requested_type = params->type;
+  const iree_hal_memory_type_t required_type =
+      requested_type & ~IREE_HAL_MEMORY_TYPE_OPTIMAL;
+  const bool requires_host_local =
+      iree_all_bits_set(required_type, IREE_HAL_MEMORY_TYPE_HOST_LOCAL);
+  const bool requires_host_visible =
+      iree_all_bits_set(required_type, IREE_HAL_MEMORY_TYPE_HOST_VISIBLE);
+  const bool requires_device_local =
+      iree_all_bits_set(required_type, IREE_HAL_MEMORY_TYPE_DEVICE_LOCAL);
+
+  const iree_hal_amdgpu_allocator_memory_pool_t* memory_pool = NULL;
+  iree_hal_memory_type_t memory_type = 0;
+  // Sharing hints do not affect HSA pool selection. Export is omitted because
+  // it requires dedicated platform export support.
+  const iree_hal_buffer_usage_t sharing_usage =
+      IREE_HAL_BUFFER_USAGE_SHARING_REPLICATE |
+      IREE_HAL_BUFFER_USAGE_SHARING_CONCURRENT |
+      IREE_HAL_BUFFER_USAGE_SHARING_IMMUTABLE;
+  iree_hal_buffer_usage_t supported_usage = IREE_HAL_BUFFER_USAGE_TRANSFER |
+                                            IREE_HAL_BUFFER_USAGE_DISPATCH |
+                                            sharing_usage;
+  if (requires_host_local) {
+    if (requires_device_local) return false;
+    memory_pool = &allocator->memory_pools.host_fine[device_ordinal];
+    memory_type =
+        IREE_HAL_MEMORY_TYPE_HOST_LOCAL | IREE_HAL_MEMORY_TYPE_DEVICE_VISIBLE;
+  } else if (requires_host_visible) {
+    memory_pool = &allocator->memory_pools.device_fine[device_ordinal];
+    memory_type = IREE_HAL_MEMORY_TYPE_DEVICE_LOCAL |
+                  IREE_HAL_MEMORY_TYPE_HOST_VISIBLE |
+                  IREE_HAL_MEMORY_TYPE_HOST_COHERENT;
+  } else {
+    memory_pool = &allocator->memory_pools.device_coarse[device_ordinal];
+    memory_type = IREE_HAL_MEMORY_TYPE_DEVICE_LOCAL;
+  }
+  if (!memory_pool->memory_pool.handle) return false;
+  if (!iree_all_bits_set(memory_type, required_type)) return false;
+
+  if (iree_any_bit_set(memory_type, IREE_HAL_MEMORY_TYPE_HOST_VISIBLE)) {
+    supported_usage |= IREE_HAL_BUFFER_USAGE_MAPPING_SCOPED |
+                       IREE_HAL_BUFFER_USAGE_MAPPING_PERSISTENT |
+                       IREE_HAL_BUFFER_USAGE_MAPPING_OPTIONAL |
+                       IREE_HAL_BUFFER_USAGE_MAPPING_ACCESS_RANDOM |
+                       IREE_HAL_BUFFER_USAGE_MAPPING_ACCESS_SEQUENTIAL_WRITE;
+  } else {
+    const iree_hal_buffer_usage_t mapping_usage =
+        IREE_HAL_BUFFER_USAGE_MAPPING_SCOPED |
+        IREE_HAL_BUFFER_USAGE_MAPPING_PERSISTENT |
+        IREE_HAL_BUFFER_USAGE_MAPPING_ACCESS_RANDOM |
+        IREE_HAL_BUFFER_USAGE_MAPPING_ACCESS_SEQUENTIAL_WRITE;
+    if (iree_any_bit_set(params->usage, mapping_usage)) {
+      if (!iree_all_bits_set(params->usage,
+                             IREE_HAL_BUFFER_USAGE_MAPPING_OPTIONAL)) {
+        return false;
+      }
+      params->usage &=
+          ~(mapping_usage | IREE_HAL_BUFFER_USAGE_MAPPING_OPTIONAL);
+    }
+  }
+  if (!iree_all_bits_set(supported_usage, params->usage)) return false;
+
+  params->type = memory_type;
+  params->usage &= supported_usage;
+  out_placement->memory_pool = memory_pool;
+  out_placement->physical_device_ordinal = (uint32_t)device_ordinal;
+  out_placement->memory_type = memory_type;
+  return true;
+}
+
 iree_status_t iree_hal_amdgpu_allocator_create(
+    iree_hal_amdgpu_logical_device_t* logical_device,
     const iree_hal_amdgpu_libhsa_t* libhsa,
     const iree_hal_amdgpu_topology_t* topology, iree_allocator_t host_allocator,
     iree_hal_allocator_t** out_allocator) {
+  IREE_ASSERT_ARGUMENT(logical_device);
   IREE_ASSERT_ARGUMENT(libhsa);
   IREE_ASSERT_ARGUMENT(topology);
   IREE_ASSERT_ARGUMENT(out_allocator);
   IREE_TRACE_ZONE_BEGIN(z0);
+  *out_allocator = NULL;
 
   iree_hal_amdgpu_allocator_t* allocator = NULL;
   IREE_RETURN_AND_END_ZONE_IF_ERROR(
       z0, iree_allocator_malloc(host_allocator, sizeof(*allocator),
                                 (void**)&allocator));
+  memset(allocator, 0, sizeof(*allocator));
   iree_hal_resource_initialize(&iree_hal_amdgpu_allocator_vtable,
                                &allocator->resource);
   allocator->host_allocator = host_allocator;
+  allocator->logical_device = logical_device;
   allocator->libhsa = libhsa;
   allocator->topology = topology;
 
-  // TODO(benvanik): query device heaps, supported features (concurrent
-  // access/etc), and prepare any pools that will be used during allocation.
-  // It's expected that most failures that occur after creation are allocation
-  // request-specific so preparing here will help keep the errors more
-  // localized.
   iree_status_t status = iree_ok_status();
+  for (iree_host_size_t i = 0;
+       i < topology->gpu_agent_count && iree_status_is_ok(status); ++i) {
+    hsa_amd_memory_pool_t device_coarse_pool = {0};
+    status = iree_hal_amdgpu_find_coarse_global_memory_pool(
+        libhsa, topology->gpu_agents[i], &device_coarse_pool);
+    if (iree_status_is_ok(status)) {
+      status = iree_hal_amdgpu_allocator_query_pool_properties(
+          libhsa, device_coarse_pool,
+          &allocator->memory_pools.device_coarse[i]);
+    }
+
+    hsa_amd_memory_pool_t device_fine_pool = {0};
+    if (iree_status_is_ok(status)) {
+      status = iree_hal_amdgpu_find_fine_global_memory_pool(
+          libhsa, topology->gpu_agents[i], &device_fine_pool);
+      if (!iree_status_is_ok(status)) {
+        status = iree_status_annotate_f(
+            status,
+            "AMDGPU allocator requires fine-grained device-local memory for "
+            "host-coherent DEVICE_LOCAL|HOST_VISIBLE allocations on physical "
+            "device %" PRIhsz,
+            i);
+      }
+    }
+    if (iree_status_is_ok(status)) {
+      status = iree_hal_amdgpu_allocator_query_pool_properties(
+          libhsa, device_fine_pool, &allocator->memory_pools.device_fine[i]);
+    }
+
+    if (iree_status_is_ok(status)) {
+      const iree_host_size_t host_ordinal = topology->gpu_cpu_map[i];
+      status = iree_hal_amdgpu_allocator_query_pool_properties(
+          libhsa,
+          logical_device->system->host_memory_pools[host_ordinal].fine_pool,
+          &allocator->memory_pools.host_fine[i]);
+    }
+  }
 
   if (iree_status_is_ok(status)) {
     *out_allocator = (iree_hal_allocator_t*)allocator;
@@ -76,7 +383,6 @@
 
 static void iree_hal_amdgpu_allocator_destroy(
     iree_hal_allocator_t* IREE_RESTRICT base_allocator) {
-  IREE_ASSERT_ARGUMENT(base_allocator);
   iree_hal_amdgpu_allocator_t* allocator =
       iree_hal_amdgpu_allocator_cast(base_allocator);
   IREE_TRACE_ZONE_BEGIN(z0);
@@ -95,16 +401,6 @@
 
 static iree_status_t iree_hal_amdgpu_allocator_trim(
     iree_hal_allocator_t* IREE_RESTRICT base_allocator) {
-  iree_hal_amdgpu_allocator_t* allocator =
-      (iree_hal_amdgpu_allocator_t*)base_allocator;
-
-  // TODO(benvanik): if the allocator is retaining any unused resources they
-  // should be dropped here. If the underlying implementation has pools or
-  // caches it should be notified that a trim is requested. This is called in
-  // low-memory situations or when IREE is not going to be used for awhile (low
-  // power modes or suspension).
-  (void)allocator;
-
   return iree_ok_status();
 }
 
@@ -115,10 +411,30 @@
     iree_hal_amdgpu_allocator_t* allocator =
         iree_hal_amdgpu_allocator_cast(base_allocator);
     memcpy(out_statistics, &allocator->statistics, sizeof(*out_statistics));
-    // TODO(benvanik): update statistics (merge).
   });
 }
 
+static iree_device_size_t iree_hal_amdgpu_allocator_min_pool_limit(
+    iree_device_size_t lhs, iree_device_size_t rhs) {
+  return lhs < rhs ? lhs : rhs;
+}
+
+static void iree_hal_amdgpu_allocator_query_pool_family_limits(
+    const iree_hal_amdgpu_allocator_memory_pool_t* pools,
+    iree_host_size_t pool_count, iree_device_size_t* out_max_allocation_size,
+    iree_device_size_t* out_min_alignment) {
+  iree_device_size_t max_allocation_size = pools[0].max_allocation_size;
+  iree_device_size_t min_alignment = pools[0].allocation_alignment;
+  for (iree_host_size_t i = 1; i < pool_count; ++i) {
+    max_allocation_size = iree_hal_amdgpu_allocator_min_pool_limit(
+        max_allocation_size, pools[i].max_allocation_size);
+    min_alignment = iree_hal_amdgpu_allocator_min_pool_limit(
+        min_alignment, pools[i].allocation_alignment);
+  }
+  *out_max_allocation_size = max_allocation_size;
+  *out_min_alignment = min_alignment;
+}
+
 static iree_status_t iree_hal_amdgpu_allocator_query_memory_heaps(
     iree_hal_allocator_t* IREE_RESTRICT base_allocator,
     iree_host_size_t capacity,
@@ -126,16 +442,71 @@
     iree_host_size_t* IREE_RESTRICT out_count) {
   iree_hal_amdgpu_allocator_t* allocator =
       iree_hal_amdgpu_allocator_cast(base_allocator);
+  const iree_host_size_t heap_count = 3;
+  *out_count = heap_count;
+  if (capacity < heap_count) {
+    // NOTE: lightweight as this is hit in normal pre-sizing usage.
+    return iree_status_from_code(IREE_STATUS_OUT_OF_RANGE);
+  }
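+  // A typical caller performs a two-phase query (a sketch of the existing HAL
+  // convention, not new behavior):
+  //   iree_host_size_t count = 0;
+  //   iree_status_t status = iree_hal_allocator_query_memory_heaps(
+  //       allocator, /*capacity=*/0, /*heaps=*/NULL, &count);
+  //   // OUT_OF_RANGE here; size storage for |count| heaps and call again.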
 
-  // TODO(benvanik): return heap information. This is called at least once with
-  // a capacity that may be 0 (indicating a query for the total count) and the
-  // heaps should only be populated if capacity is sufficient to store all of
-  // them.
-  (void)allocator;
-  iree_status_t status =
-      iree_make_status(IREE_STATUS_UNIMPLEMENTED, "heap query not implemented");
+  memset(heaps, 0, heap_count * sizeof(*heaps));
 
-  return status;
+  iree_device_size_t device_coarse_max_allocation_size = 0;
+  iree_device_size_t device_coarse_min_alignment = 0;
+  iree_hal_amdgpu_allocator_query_pool_family_limits(
+      allocator->memory_pools.device_coarse,
+      allocator->topology->gpu_agent_count, &device_coarse_max_allocation_size,
+      &device_coarse_min_alignment);
+
+  iree_device_size_t device_fine_max_allocation_size = 0;
+  iree_device_size_t device_fine_min_alignment = 0;
+  iree_hal_amdgpu_allocator_query_pool_family_limits(
+      allocator->memory_pools.device_fine, allocator->topology->gpu_agent_count,
+      &device_fine_max_allocation_size, &device_fine_min_alignment);
+
+  iree_device_size_t host_fine_max_allocation_size = 0;
+  iree_device_size_t host_fine_min_alignment = 0;
+  iree_hal_amdgpu_allocator_query_pool_family_limits(
+      allocator->memory_pools.host_fine, allocator->topology->gpu_agent_count,
+      &host_fine_max_allocation_size, &host_fine_min_alignment);
+
+  // Sharing hints do not affect HSA pool selection. SHARING_EXPORT is omitted
+  // because it requires dedicated platform export support that the driver does
+  // not yet provide.
+  const iree_hal_buffer_usage_t sharing_usage =
+      IREE_HAL_BUFFER_USAGE_SHARING_REPLICATE |
+      IREE_HAL_BUFFER_USAGE_SHARING_CONCURRENT |
+      IREE_HAL_BUFFER_USAGE_SHARING_IMMUTABLE;
+  const iree_hal_buffer_usage_t mappable_usage =
+      IREE_HAL_BUFFER_USAGE_TRANSFER | IREE_HAL_BUFFER_USAGE_DISPATCH |
+      sharing_usage | IREE_HAL_BUFFER_USAGE_MAPPING_SCOPED |
+      IREE_HAL_BUFFER_USAGE_MAPPING_PERSISTENT |
+      IREE_HAL_BUFFER_USAGE_MAPPING_OPTIONAL |
+      IREE_HAL_BUFFER_USAGE_MAPPING_ACCESS_RANDOM |
+      IREE_HAL_BUFFER_USAGE_MAPPING_ACCESS_SEQUENTIAL_WRITE;
+
+  // Heap 0: coarse-grained device-local memory.
+  heaps[0].type = IREE_HAL_MEMORY_TYPE_DEVICE_LOCAL;
+  heaps[0].allowed_usage = IREE_HAL_BUFFER_USAGE_TRANSFER |
+                           IREE_HAL_BUFFER_USAGE_DISPATCH | sharing_usage;
+  heaps[0].max_allocation_size = device_coarse_max_allocation_size;
+  heaps[0].min_alignment = device_coarse_min_alignment;
+
+  // Heap 1: fine-grained device-local memory for explicit host visibility.
+  heaps[1].type = IREE_HAL_MEMORY_TYPE_DEVICE_LOCAL |
+                  IREE_HAL_MEMORY_TYPE_HOST_VISIBLE |
+                  IREE_HAL_MEMORY_TYPE_HOST_COHERENT;
+  heaps[1].allowed_usage = mappable_usage;
+  heaps[1].max_allocation_size = device_fine_max_allocation_size;
+  heaps[1].min_alignment = device_fine_min_alignment;
+
+  // Heap 2: fine-grained host-local memory visible to the device.
+  heaps[2].type =
+      IREE_HAL_MEMORY_TYPE_HOST_LOCAL | IREE_HAL_MEMORY_TYPE_DEVICE_VISIBLE;
+  heaps[2].allowed_usage = mappable_usage;
+  heaps[2].max_allocation_size = host_fine_max_allocation_size;
+  heaps[2].min_alignment = host_fine_min_alignment;
+
+  return iree_ok_status();
 }
 
 static iree_hal_buffer_compatibility_t
@@ -146,28 +517,150 @@
   iree_hal_amdgpu_allocator_t* allocator =
       iree_hal_amdgpu_allocator_cast(base_allocator);
 
-  // TODO(benvanik): set compatibility rules based on the implementation.
-  // Note that the user may have requested that the allocator place the
-  // allocation based on whatever is optimal for the indicated usage by
-  // including the IREE_HAL_MEMORY_TYPE_OPTIMAL flag. It's still required that
-  // the implementation meet all the requirements but it is free to place it in
-  // either host or device memory so long as the appropriate bits are updated to
-  // indicate where it landed.
-  (void)allocator;
-  iree_hal_buffer_compatibility_t compatibility =
-      IREE_HAL_BUFFER_COMPATIBILITY_NONE;
+  iree_hal_amdgpu_allocator_placement_t placement;
+  if (!iree_hal_amdgpu_allocator_resolve_placement(allocator, params,
+                                                   &placement)) {
+    return IREE_HAL_BUFFER_COMPATIBILITY_NONE;
+  }
+  if (!iree_device_size_is_valid_alignment(params->min_alignment)) {
+    return IREE_HAL_BUFFER_COMPATIBILITY_NONE;
+  }
 
-  // We are now optimal.
-  params->type &= ~IREE_HAL_MEMORY_TYPE_OPTIMAL;
-
-  // Guard against the corner case where the requested buffer size is 0. The
-  // application is unlikely to do anything when requesting a 0-byte buffer; but
-  // it can happen in real world use cases. So we should at least not crash.
+  // Guard against 0-byte allocations.
   if (*allocation_size == 0) *allocation_size = 4;
 
+  iree_device_size_t aligned_allocation_size = 0;
+  if (!iree_device_size_checked_align(*allocation_size,
+                                      placement.memory_pool->allocation_granule,
+                                      &aligned_allocation_size)) {
+    return IREE_HAL_BUFFER_COMPATIBILITY_NONE;
+  }
+  *allocation_size = aligned_allocation_size;
+
+  const bool allocation_size_valid =
+      aligned_allocation_size <= placement.memory_pool->max_allocation_size;
+  const bool allocation_alignment_valid =
+      params->min_alignment == 0 ||
+      params->min_alignment <= placement.memory_pool->allocation_alignment;
+
+  const bool allocation_compatible =
+      allocation_size_valid && allocation_alignment_valid;
+  const bool import_compatible =
+      allocation_size_valid &&
+      iree_all_bits_set(params->type,
+                        IREE_HAL_MEMORY_TYPE_HOST_LOCAL |
+                            IREE_HAL_MEMORY_TYPE_DEVICE_VISIBLE) &&
+      !iree_all_bits_set(params->type, IREE_HAL_MEMORY_TYPE_DEVICE_LOCAL);
+  if (!allocation_compatible && !import_compatible) {
+    return IREE_HAL_BUFFER_COMPATIBILITY_NONE;
+  }
+
+  iree_hal_buffer_compatibility_t compatibility = 0;
+  if (allocation_compatible) {
+    compatibility |= IREE_HAL_BUFFER_COMPATIBILITY_ALLOCATABLE;
+  }
+
+  if (iree_any_bit_set(params->usage, IREE_HAL_BUFFER_USAGE_TRANSFER)) {
+    compatibility |= IREE_HAL_BUFFER_COMPATIBILITY_QUEUE_TRANSFER;
+  }
+  if (iree_any_bit_set(params->usage, IREE_HAL_BUFFER_USAGE_DISPATCH)) {
+    compatibility |= IREE_HAL_BUFFER_COMPATIBILITY_QUEUE_DISPATCH;
+  }
+
+  if (import_compatible) {
+    compatibility |= IREE_HAL_BUFFER_COMPATIBILITY_IMPORTABLE;
+  }
+  if (iree_all_bits_set(placement.memory_type,
+                        IREE_HAL_MEMORY_TYPE_DEVICE_LOCAL |
+                            IREE_HAL_MEMORY_TYPE_HOST_VISIBLE)) {
+    // Fine-grained GPU-local memory exists to support explicit coherent host
+    // access, but dispatches should prefer coarse-grained device-local memory.
+    // Generic generation helpers use this hint to stage host-produced data
+    // through a transfer instead of generating directly into dispatch inputs.
+    compatibility |= IREE_HAL_BUFFER_COMPATIBILITY_LOW_PERFORMANCE;
+  }
+
   return compatibility;
 }
 
+static void iree_hal_amdgpu_allocator_record_buffer_allocate(
+    iree_hal_amdgpu_allocator_t* allocator,
+    const iree_hal_amdgpu_allocator_placement_t* memory_placement,
+    iree_hal_buffer_params_t params, iree_device_size_t allocation_size,
+    void* host_ptr, iree_hal_buffer_t* buffer) {
+  uint64_t session_id = 0;
+  const uint64_t allocation_id =
+      iree_hal_amdgpu_logical_device_allocate_profile_memory_allocation_id(
+          (iree_hal_device_t*)allocator->logical_device, &session_id);
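+  // Allocation id 0 indicates that no profile session is currently recording
+  // memory events; skip event construction entirely.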
+  if (allocation_id == 0) {
+    return;
+  }
+
+  iree_hal_profile_memory_event_t event =
+      iree_hal_profile_memory_event_default();
+  event.type = IREE_HAL_PROFILE_MEMORY_EVENT_TYPE_BUFFER_ALLOCATE;
+  event.allocation_id = allocation_id;
+  event.pool_id = memory_placement->memory_pool->memory_pool.handle;
+  event.backing_id = (uint64_t)(uintptr_t)host_ptr;
+  event.physical_device_ordinal = memory_placement->physical_device_ordinal;
+  event.memory_type = memory_placement->memory_type;
+  event.buffer_usage = params.usage;
+  event.length = allocation_size;
+  event.alignment = memory_placement->memory_pool->allocation_alignment;
+  if (iree_hal_amdgpu_logical_device_record_profile_memory_event(
+          (iree_hal_device_t*)allocator->logical_device, &event)) {
+    iree_hal_amdgpu_buffer_set_profile_allocation(
+        buffer, session_id, allocation_id, event.pool_id,
+        event.physical_device_ordinal, event.alignment);
+  }
+}
+
+static void iree_hal_amdgpu_allocator_record_buffer_import(
+    iree_hal_amdgpu_allocator_t* allocator, iree_hal_buffer_params_t params,
+    const iree_hal_external_buffer_t* external_buffer,
+    iree_hal_amdgpu_imported_host_release_data_t* release_data) {
+  uint64_t session_id = 0;
+  const uint64_t allocation_id =
+      iree_hal_amdgpu_logical_device_allocate_profile_memory_allocation_id(
+          (iree_hal_device_t*)allocator->logical_device, &session_id);
+  if (allocation_id == 0) {
+    return;
+  }
+
+  uint32_t physical_device_ordinal = UINT32_MAX;
+  iree_host_size_t selected_device_ordinal = 0;
+  if (iree_hal_amdgpu_allocator_select_device_ordinal(
+          allocator, params.queue_affinity, /*out_queue_affinity=*/NULL,
+          &selected_device_ordinal)) {
+    physical_device_ordinal = (uint32_t)selected_device_ordinal;
+  }
+
+  iree_hal_profile_memory_event_t event =
+      iree_hal_profile_memory_event_default();
+  event.type = IREE_HAL_PROFILE_MEMORY_EVENT_TYPE_BUFFER_IMPORT;
+  event.flags = IREE_HAL_PROFILE_MEMORY_EVENT_FLAG_EXTERNALLY_OWNED;
+  event.allocation_id = allocation_id;
+  event.backing_id =
+      (uint64_t)(uintptr_t)external_buffer->handle.host_allocation.ptr;
+  event.physical_device_ordinal = physical_device_ordinal;
+  event.memory_type = params.type;
+  event.buffer_usage = params.usage;
+  event.length = external_buffer->size;
+  event.alignment = 1;
+  if (!iree_hal_amdgpu_logical_device_record_profile_memory_event(
+          (iree_hal_device_t*)allocator->logical_device, &event)) {
+    return;
+  }
+
+  release_data->profile_device = (iree_hal_device_t*)allocator->logical_device;
+  release_data->length = external_buffer->size;
+  release_data->memory_type = params.type;
+  release_data->buffer_usage = params.usage;
+  release_data->profile_session_id = session_id;
+  release_data->profile_allocation_id = allocation_id;
+  release_data->profile_physical_device_ordinal = physical_device_ordinal;
+}
+
 static iree_status_t iree_hal_amdgpu_allocator_allocate_buffer(
     iree_hal_allocator_t* IREE_RESTRICT base_allocator,
     const iree_hal_buffer_params_t* IREE_RESTRICT params,
@@ -175,61 +668,133 @@
     iree_hal_buffer_t** IREE_RESTRICT out_buffer) {
   iree_hal_amdgpu_allocator_t* allocator =
       iree_hal_amdgpu_allocator_cast(base_allocator);
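+  // Allocation flow: resolve an HSA pool placement from |params|, validate
+  // alignment and size against the pool limits, allocate from the pool, grant
+  // agent access per queue affinity, and wrap the pointer in a HAL buffer.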
+  const iree_device_size_t byte_length = allocation_size;
+  IREE_TRACE_ZONE_BEGIN(z0);
+  IREE_TRACE_ZONE_APPEND_VALUE_I64(z0, (int64_t)byte_length);
 
   // Coerce options into those required by the current device.
   iree_hal_buffer_params_t compat_params = *params;
-  iree_hal_buffer_compatibility_t compatibility =
-      iree_hal_amdgpu_allocator_query_buffer_compatibility(
-          base_allocator, &compat_params, &allocation_size);
-  if (!iree_all_bits_set(compatibility,
-                         IREE_HAL_BUFFER_COMPATIBILITY_ALLOCATABLE)) {
-    // TODO(benvanik): make a helper for this.
+  iree_hal_amdgpu_allocator_placement_t memory_placement;
+  if (!iree_hal_amdgpu_allocator_resolve_placement(allocator, &compat_params,
+                                                   &memory_placement)) {
 #if IREE_STATUS_MODE
-    iree_bitfield_string_temp_t temp0, temp1, temp2;
+    iree_bitfield_string_temp_t temp0, temp1;
     iree_string_view_t memory_type_str =
         iree_hal_memory_type_format(params->type, &temp0);
     iree_string_view_t usage_str =
         iree_hal_buffer_usage_format(params->usage, &temp1);
-    iree_string_view_t compatibility_str =
-        iree_hal_buffer_compatibility_format(compatibility, &temp2);
+    IREE_TRACE_ZONE_END(z0);
     return iree_make_status(
         IREE_STATUS_INVALID_ARGUMENT,
         "allocator cannot allocate a buffer with the given parameters; "
-        "memory_type=%.*s, usage=%.*s, compatibility=%.*s",
+        "memory_type=%.*s, usage=%.*s",
         (int)memory_type_str.size, memory_type_str.data, (int)usage_str.size,
-        usage_str.data, (int)compatibility_str.size, compatibility_str.data);
+        usage_str.data);
 #else
+    IREE_TRACE_ZONE_END(z0);
     return iree_make_status(
         IREE_STATUS_INVALID_ARGUMENT,
         "allocator cannot allocate a buffer with the given parameters");
 #endif  // IREE_STATUS_MODE
   }
 
-  // TODO(benvanik): allocate the underlying device memory. The impl_ptr is just
-  // used for accounting and can be an opaque value (handle/etc) so long as it
-  // is consistent between the alloc and free and unique to the buffer while it
-  // is live. An example iree_hal_amdgpu_external_buffer_wrap is provided that
-  // can be used for implementations that are managing memory using underlying
-  // allocators and just wrapping those device pointers in the HAL buffer type.
-  // Other implementations that require more tracking can provide their own
-  // buffer types that do such tracking for them.
-  (void)allocator;
-  void* impl_ptr = NULL;
-  (void)impl_ptr;
+  if (IREE_UNLIKELY(
+          !iree_device_size_is_valid_alignment(params->min_alignment))) {
+    IREE_TRACE_ZONE_END(z0);
+    return iree_make_status(IREE_STATUS_INVALID_ARGUMENT,
+                            "requested AMDGPU allocation alignment %" PRIu64
+                            " is not a power-of-two",
+                            (uint64_t)params->min_alignment);
+  }
+  if (IREE_UNLIKELY(params->min_alignment != 0 &&
+                    params->min_alignment >
+                        memory_placement.memory_pool->allocation_alignment)) {
+    IREE_TRACE_ZONE_END(z0);
+    return iree_make_status(
+        IREE_STATUS_INVALID_ARGUMENT,
+        "requested AMDGPU allocation alignment %" PRIu64
+        " exceeds HSA memory pool alignment %" PRIu64,
+        (uint64_t)params->min_alignment,
+        (uint64_t)memory_placement.memory_pool->allocation_alignment);
+  }
+
+  // Guard against 0-byte allocations and align to the HSA allocation granule.
+  if (allocation_size == 0) allocation_size = 4;
+  if (!iree_device_size_checked_align(
+          allocation_size, memory_placement.memory_pool->allocation_granule,
+          &allocation_size)) {
+    IREE_TRACE_ZONE_END(z0);
+    return iree_make_status(IREE_STATUS_OUT_OF_RANGE,
+                            "allocation size %" PRIdsz
+                            " overflows HSA memory pool allocation granule",
+                            allocation_size);
+  }
+  if (IREE_UNLIKELY(allocation_size >
+                    memory_placement.memory_pool->max_allocation_size)) {
+    IREE_TRACE_ZONE_END(z0);
+    return iree_make_status(
+        IREE_STATUS_OUT_OF_RANGE,
+        "AMDGPU allocation size %" PRIu64
+        " exceeds HSA memory pool max allocation size %" PRIu64,
+        (uint64_t)allocation_size,
+        (uint64_t)memory_placement.memory_pool->max_allocation_size);
+  }
+
+  // Allocate from the resolved HSA memory pool.
+  void* host_ptr = NULL;
+  iree_status_t status = iree_hsa_amd_memory_pool_allocate(
+      IREE_LIBHSA(allocator->libhsa), memory_placement.memory_pool->memory_pool,
+      (size_t)allocation_size, HSA_AMD_MEMORY_POOL_STANDARD_FLAG, &host_ptr);
+
+  // Grant access to the physical devices selected by the buffer placement. A
+  // placement of ANY remains intentionally broad within the logical topology
+  // but never expands to all ROCR-visible platform agents.
+  iree_hal_amdgpu_access_agent_list_t access_agents;
+  if (iree_status_is_ok(status)) {
+    status = iree_hal_amdgpu_allocator_resolve_access_agents(
+        allocator, compat_params.queue_affinity, &access_agents);
+  }
+  if (iree_status_is_ok(status)) {
+    status = iree_hal_amdgpu_access_allow_agent_list(allocator->libhsa,
+                                                     &access_agents, host_ptr);
+  }
+
+  // Wrap in a HAL buffer.
   iree_hal_buffer_t* buffer = NULL;
-  iree_status_t status = iree_make_status(IREE_STATUS_UNIMPLEMENTED,
-                                          "buffer allocation not implemented");
+  if (iree_status_is_ok(status)) {
+    const iree_hal_buffer_placement_t buffer_placement = {
+        .device = (iree_hal_device_t*)allocator->logical_device,
+        .queue_affinity = compat_params.queue_affinity
+                              ? compat_params.queue_affinity
+                              : IREE_HAL_QUEUE_AFFINITY_ANY,
+        .flags = IREE_HAL_BUFFER_PLACEMENT_FLAG_NONE,
+    };
+    status = iree_hal_amdgpu_buffer_create(
+        allocator->libhsa, buffer_placement, memory_placement.memory_type,
+        compat_params.access, compat_params.usage, allocation_size, byte_length,
+        host_ptr, iree_hal_buffer_release_callback_null(),
+        allocator->host_allocator, &buffer);
+  }
 
   if (iree_status_is_ok(status)) {
-    // TODO(benvanik): ensure this accounting is balanced in deallocate_buffer.
-    IREE_TRACE_ALLOC_NAMED(IREE_HAL_AMDGPU_ALLOCATOR_ID, impl_ptr,
-                           allocation_size);
+    IREE_TRACE_ALLOC_NAMED(IREE_HAL_AMDGPU_ALLOCATOR_ID, host_ptr,
+                           (iree_host_size_t)allocation_size);
     IREE_STATISTICS(iree_hal_allocator_statistics_record_alloc(
         &allocator->statistics, compat_params.type, allocation_size));
+    iree_hal_amdgpu_allocator_record_buffer_allocate(
+        allocator, &memory_placement, compat_params, allocation_size, host_ptr,
+        buffer);
     *out_buffer = buffer;
   } else {
+    if (host_ptr) {
+      status = iree_status_join(
+          status, iree_hsa_amd_memory_pool_free(IREE_LIBHSA(allocator->libhsa),
+                                                host_ptr));
+    }
     iree_hal_buffer_release(buffer);
   }
+
+  IREE_TRACE_ZONE_END(z0);
   return status;
 }
 
@@ -238,27 +803,48 @@
     iree_hal_buffer_t* IREE_RESTRICT base_buffer) {
   iree_hal_amdgpu_allocator_t* allocator =
       iree_hal_amdgpu_allocator_cast(base_allocator);
+  IREE_TRACE_ZONE_BEGIN(z0);
+  IREE_TRACE_ZONE_APPEND_VALUE_I64(
+      z0, iree_hal_buffer_allocation_size(base_buffer));
 
-  // TODO(benvanik): free the underlying device memory here. Buffers allocated
-  // from this allocator will call this method to handle cleanup. Note that
-  // because this method is responsible for doing the base
-  // iree_hal_buffer_destroy and the caller assumes the memory has been freed an
-  // implementation could pool the buffer handle and return it in the future.
-  (void)allocator;
-  void* impl_ptr = NULL;
-  (void)impl_ptr;
+  IREE_STATISTICS(iree_hal_allocator_statistics_record_free(
+      &allocator->statistics, iree_hal_buffer_memory_type(base_buffer),
+      iree_hal_buffer_allocation_size(base_buffer)));
 
-  // TODO(benvanik): if the buffer was imported then this accounting may need to
-  // be conditional depending on the implementation.
-  bool was_imported = false;
-  if (!was_imported) {
-    IREE_TRACE_FREE_NAMED(IREE_HAL_AMDGPU_ALLOCATOR_ID, impl_ptr);
-    IREE_STATISTICS(iree_hal_allocator_statistics_record_free(
-        &allocator->statistics, iree_hal_buffer_memory_type(base_buffer),
-        iree_hal_buffer_allocation_size(base_buffer)));
-  }
-
+  // The buffer's destroy method handles freeing the HSA allocation.
   iree_hal_buffer_destroy(base_buffer);
+  IREE_TRACE_ZONE_END(z0);
+}
+
+static void iree_hal_amdgpu_allocator_release_imported_host(
+    void* user_data, iree_hal_buffer_t* buffer) {
+  iree_hal_amdgpu_imported_host_release_data_t* data =
+      (iree_hal_amdgpu_imported_host_release_data_t*)user_data;
+  IREE_TRACE_ZONE_BEGIN(z0);
+
+  iree_hal_amdgpu_hsa_cleanup_assert_success(
+      iree_hsa_amd_memory_unlock_raw(data->libhsa, data->host_ptr));
+  if (data->profile_allocation_id != 0) {
+    iree_hal_profile_memory_event_t event =
+        iree_hal_profile_memory_event_default();
+    event.type = IREE_HAL_PROFILE_MEMORY_EVENT_TYPE_BUFFER_UNIMPORT;
+    event.flags = IREE_HAL_PROFILE_MEMORY_EVENT_FLAG_EXTERNALLY_OWNED;
+    event.allocation_id = data->profile_allocation_id;
+    event.backing_id = (uint64_t)(uintptr_t)data->host_ptr;
+    event.physical_device_ordinal = data->profile_physical_device_ordinal;
+    event.memory_type = data->memory_type;
+    event.buffer_usage = data->buffer_usage;
+    event.length = data->length;
+    event.alignment = 1;
+    iree_hal_amdgpu_logical_device_record_profile_memory_event_for_session(
+        data->profile_device, data->profile_session_id, &event);
+  }
+  if (data->caller_release_callback.fn) {
+    data->caller_release_callback.fn(data->caller_release_callback.user_data,
+                                     buffer);
+  }
+  iree_allocator_free(data->host_allocator, data);
+  IREE_TRACE_ZONE_END(z0);
 }
 
 static iree_status_t iree_hal_amdgpu_allocator_import_buffer(
@@ -267,18 +853,83 @@
     iree_hal_external_buffer_t* IREE_RESTRICT external_buffer,
     iree_hal_buffer_release_callback_t release_callback,
     iree_hal_buffer_t** IREE_RESTRICT out_buffer) {
+  IREE_ASSERT_ARGUMENT(out_buffer);
+  *out_buffer = NULL;
   iree_hal_amdgpu_allocator_t* allocator =
       iree_hal_amdgpu_allocator_cast(base_allocator);
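+  // Host allocation import pins the caller's memory for device access and
+  // wraps the agent-visible pointer; the release callback later unpins it and
+  // emits the matching profile unimport event.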
 
-  // Coerce options into those required by the current device.
+  if (IREE_UNLIKELY(external_buffer->flags !=
+                    IREE_HAL_EXTERNAL_BUFFER_FLAG_NONE)) {
+    return iree_make_status(IREE_STATUS_INVALID_ARGUMENT,
+                            "unsupported AMDGPU external buffer flags: 0x%x",
+                            external_buffer->flags);
+  }
+
+  switch (external_buffer->type) {
+    case IREE_HAL_EXTERNAL_BUFFER_TYPE_HOST_ALLOCATION:
+      break;
+    case IREE_HAL_EXTERNAL_BUFFER_TYPE_DEVICE_ALLOCATION:
+    case IREE_HAL_EXTERNAL_BUFFER_TYPE_OPAQUE_FD:
+    case IREE_HAL_EXTERNAL_BUFFER_TYPE_OPAQUE_WIN32:
+      return iree_make_status(IREE_STATUS_UNIMPLEMENTED,
+                              "AMDGPU external buffer type not supported");
+    default:
+      return iree_make_status(IREE_STATUS_INVALID_ARGUMENT,
+                              "invalid AMDGPU external buffer type");
+  }
+
+  if (IREE_UNLIKELY(!external_buffer->handle.host_allocation.ptr)) {
+    return iree_make_status(IREE_STATUS_INVALID_ARGUMENT,
+                            "host allocation import requires a non-null ptr");
+  }
+  if (IREE_UNLIKELY(
+          !iree_device_size_is_valid_alignment(params->min_alignment))) {
+    return iree_make_status(IREE_STATUS_INVALID_ARGUMENT,
+                            "requested AMDGPU import alignment %" PRIu64
+                            " is not a power-of-two",
+                            (uint64_t)params->min_alignment);
+  }
+  if (IREE_UNLIKELY(params->min_alignment != 0 &&
+                    (params->min_alignment > IREE_HOST_SIZE_MAX ||
+                     !iree_host_ptr_has_alignment(
+                         external_buffer->handle.host_allocation.ptr,
+                         (iree_host_size_t)params->min_alignment)))) {
+    return iree_make_status(
+        IREE_STATUS_INVALID_ARGUMENT,
+        "host allocation import pointer does not satisfy requested AMDGPU "
+        "alignment %" PRIu64,
+        (uint64_t)params->min_alignment);
+  }
+  if (IREE_UNLIKELY(external_buffer->size == 0)) {
+    return iree_make_status(IREE_STATUS_INVALID_ARGUMENT,
+                            "host allocation import requires a non-zero size");
+  }
+  if (IREE_UNLIKELY(
+          iree_all_bits_set(params->type, IREE_HAL_MEMORY_TYPE_DEVICE_LOCAL))) {
+    return iree_make_status(
+        IREE_STATUS_INVALID_ARGUMENT,
+        "unable to import host allocations as device-local memory");
+  }
+
   iree_hal_buffer_params_t compat_params = *params;
+  if (iree_any_bit_set(compat_params.type, IREE_HAL_MEMORY_TYPE_OPTIMAL)) {
+    compat_params.type &= ~IREE_HAL_MEMORY_TYPE_OPTIMAL;
+    compat_params.type |=
+        IREE_HAL_MEMORY_TYPE_HOST_LOCAL | IREE_HAL_MEMORY_TYPE_DEVICE_VISIBLE;
+  }
+  if (!iree_all_bits_set(compat_params.type,
+                         IREE_HAL_MEMORY_TYPE_DEVICE_VISIBLE)) {
+    return iree_make_status(
+        IREE_STATUS_INVALID_ARGUMENT,
+        "host allocation import requires device-visible memory");
+  }
+
   iree_device_size_t allocation_size = external_buffer->size;
   iree_hal_buffer_compatibility_t compatibility =
       iree_hal_amdgpu_allocator_query_buffer_compatibility(
           base_allocator, &compat_params, &allocation_size);
   if (!iree_all_bits_set(compatibility,
                          IREE_HAL_BUFFER_COMPATIBILITY_IMPORTABLE)) {
-    // TODO(benvanik): make a helper for this.
 #if IREE_STATUS_MODE
     iree_bitfield_string_temp_t temp0, temp1, temp2;
     iree_string_view_t memory_type_str =
@@ -300,17 +951,73 @@
 #endif  // IREE_STATUS_MODE
   }
 
-  // TODO(benvanik): switch on external_buffer->type and import the buffer. See
-  // the headers for more information on semantics. Most implementations can
-  // service IREE_HAL_EXTERNAL_BUFFER_TYPE_DEVICE_ALLOCATION by just wrapping
-  // the underlying device pointer. Those that can service
-  // IREE_HAL_EXTERNAL_BUFFER_TYPE_HOST_ALLOCATION may be able to avoid a lot of
-  // additional copies when moving data around between host and device or across
-  // devices from different drivers.
-  (void)allocator;
-  iree_status_t status = iree_make_status(IREE_STATUS_UNIMPLEMENTED,
-                                          "external buffer type not supported");
+  IREE_TRACE_ZONE_BEGIN(z0);
+  IREE_TRACE_ZONE_APPEND_VALUE_I64(z0, external_buffer->size);
+  void* host_ptr = external_buffer->handle.host_allocation.ptr;
+  void* agent_ptr = NULL;
+  iree_hal_amdgpu_access_agent_list_t access_agents;
+  iree_status_t status = iree_hal_amdgpu_allocator_resolve_access_agents(
+      allocator, compat_params.queue_affinity, &access_agents);
+  if (iree_status_is_ok(status)) {
+    status = iree_hal_amdgpu_access_lock_host_allocation(
+        allocator->libhsa, &access_agents, host_ptr, external_buffer->size,
+        &agent_ptr);
+  }
 
+  iree_hal_amdgpu_imported_host_release_data_t* release_data = NULL;
+  if (iree_status_is_ok(status)) {
+    // NOTE: iree_allocator_malloc returns zeroed memory; the profile fields
+    // checked in the release callback start at 0 without an explicit memset.
+    status =
+        iree_allocator_malloc(allocator->host_allocator, sizeof(*release_data),
+                              (void**)&release_data);
+  }
+
+  iree_hal_buffer_t* buffer = NULL;
+  if (iree_status_is_ok(status)) {
+    release_data->libhsa = allocator->libhsa;
+    release_data->host_ptr = host_ptr;
+    release_data->length = external_buffer->size;
+    release_data->memory_type = compat_params.type;
+    release_data->buffer_usage = compat_params.usage;
+    release_data->profile_physical_device_ordinal = UINT32_MAX;
+    release_data->host_allocator = allocator->host_allocator;
+    release_data->caller_release_callback = release_callback;
+    iree_hal_buffer_release_callback_t imported_release_callback = {
+        .fn = iree_hal_amdgpu_allocator_release_imported_host,
+        .user_data = release_data,
+    };
+    const iree_hal_buffer_placement_t placement = {
+        .device = (iree_hal_device_t*)allocator->logical_device,
+        .queue_affinity = compat_params.queue_affinity
+                              ? compat_params.queue_affinity
+                              : IREE_HAL_QUEUE_AFFINITY_ANY,
+        .flags = IREE_HAL_BUFFER_PLACEMENT_FLAG_NONE,
+    };
+    status = iree_hal_amdgpu_buffer_create(
+        allocator->libhsa, placement, compat_params.type, compat_params.access,
+        compat_params.usage, external_buffer->size, external_buffer->size,
+        agent_ptr, imported_release_callback, allocator->host_allocator,
+        &buffer);
+  }
+
+  if (iree_status_is_ok(status)) {
+    iree_hal_amdgpu_allocator_record_buffer_import(
+        allocator, compat_params, external_buffer, release_data);
+    *out_buffer = buffer;
+  } else {
+    if (release_data) {
+      iree_allocator_free(allocator->host_allocator, release_data);
+    }
+    if (agent_ptr) {
+      status = iree_status_join(
+          status,
+          iree_hsa_amd_memory_unlock(IREE_LIBHSA(allocator->libhsa), host_ptr));
+    }
+    iree_hal_buffer_release(buffer);
+  }
+  IREE_TRACE_ZONE_END(z0);
   return status;
 }
 
@@ -320,18 +1027,8 @@
     iree_hal_external_buffer_type_t requested_type,
     iree_hal_external_buffer_flags_t requested_flags,
     iree_hal_external_buffer_t* IREE_RESTRICT out_external_buffer) {
-  iree_hal_amdgpu_allocator_t* allocator =
-      iree_hal_amdgpu_allocator_cast(base_allocator);
-
-  // TODO(benvanik): switch on requested_type and export as appropriate. Most
-  // implementations can service IREE_HAL_EXTERNAL_BUFFER_TYPE_DEVICE_ALLOCATION
-  // by just exposing the underlying device pointer. Those that can service
-  // IREE_HAL_EXTERNAL_BUFFER_TYPE_HOST_ALLOCATION may be able to avoid a lot of
-  // additional copies when moving data around between host and device or across
-  // devices from different drivers.
-  (void)allocator;
   return iree_make_status(IREE_STATUS_UNAVAILABLE,
-                          "external buffer type not supported");
+                          "AMDGPU buffer export not yet implemented");
 }
 
 static bool iree_hal_amdgpu_allocator_supports_virtual_memory(
diff --git a/runtime/src/iree/hal/drivers/amdgpu/allocator.h b/runtime/src/iree/hal/drivers/amdgpu/allocator.h
index 49f1321..a797bf0 100644
--- a/runtime/src/iree/hal/drivers/amdgpu/allocator.h
+++ b/runtime/src/iree/hal/drivers/amdgpu/allocator.h
@@ -1,4 +1,4 @@
-// Copyright 2025 The IREE Authors
+// Copyright 2026 The IREE Authors
 //
 // Licensed under the Apache License v2.0 with LLVM Exceptions.
 // See https://llvm.org/LICENSE.txt for license information.
@@ -11,17 +11,25 @@
 #include "iree/hal/api.h"
 #include "iree/hal/drivers/amdgpu/util/libhsa.h"
 
+typedef struct iree_hal_amdgpu_logical_device_t
+    iree_hal_amdgpu_logical_device_t;
 typedef struct iree_hal_amdgpu_topology_t iree_hal_amdgpu_topology_t;
 
 //===----------------------------------------------------------------------===//
 // iree_hal_amdgpu_allocator_t
 //===----------------------------------------------------------------------===//
 
-// TODO(benvanik): implement allocator and expose ringbuffers
-// (iree_hal_amdgpu_allocator_allocate_ringbuffer, etc).
-
-// Creates a buffer allocator used for persistent allocations and import/export.
+// Creates a buffer allocator that allocates from HSA memory pools.
+//
+// This is a simple direct allocator: each allocate_buffer call maps to an
+// hsa_amd_memory_pool_allocate and each deallocation to an
+// hsa_amd_memory_pool_free. No pooling, no suballocation. Suitable for
+// bootstrapping and testing; the async suballocator (hal/utils/) will replace
+// this on the transient allocation path.
+//
+// |logical_device| is unretained and must outlive the allocator.
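+//
+// Usage (a sketch):
+//   iree_hal_allocator_t* allocator = NULL;
+//   IREE_RETURN_IF_ERROR(iree_hal_amdgpu_allocator_create(
+//       logical_device, libhsa, topology, host_allocator, &allocator));
+//   // ... allocate/import via the base iree_hal_allocator_t API ...
+//   iree_hal_allocator_release(allocator);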
 iree_status_t iree_hal_amdgpu_allocator_create(
+    iree_hal_amdgpu_logical_device_t* logical_device,
     const iree_hal_amdgpu_libhsa_t* libhsa,
     const iree_hal_amdgpu_topology_t* topology, iree_allocator_t host_allocator,
     iree_hal_allocator_t** out_allocator);
diff --git a/runtime/src/iree/hal/drivers/amdgpu/allocator_test.cc b/runtime/src/iree/hal/drivers/amdgpu/allocator_test.cc
new file mode 100644
index 0000000..f6a5ea4
--- /dev/null
+++ b/runtime/src/iree/hal/drivers/amdgpu/allocator_test.cc
@@ -0,0 +1,261 @@
+// Copyright 2026 The IREE Authors
+//
+// Licensed under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+
+#include <array>
+
+#include "iree/hal/api.h"
+#include "iree/hal/cts/util/test_base.h"
+#include "iree/hal/drivers/amdgpu/logical_device.h"
+#include "iree/hal/drivers/amdgpu/util/topology.h"
+#include "iree/testing/gtest.h"
+#include "iree/testing/status_matchers.h"
+
+namespace iree::hal::amdgpu {
+namespace {
+
+class AllocatorTest : public ::testing::Test {
+ protected:
+  static void SetUpTestSuite() {
+    host_allocator_ = iree_allocator_system();
+    iree_status_t status = iree_hal_amdgpu_libhsa_initialize(
+        IREE_HAL_AMDGPU_LIBHSA_FLAG_NONE, iree_string_view_list_empty(),
+        host_allocator_, &libhsa_);
+    if (!iree_status_is_ok(status)) {
+      iree_status_fprint(stderr, status);
+      iree_status_free(status);
+      GTEST_SKIP() << "HSA not available, skipping tests";
+    }
+    IREE_ASSERT_OK(iree_hal_amdgpu_topology_initialize_with_defaults(
+        &libhsa_, &topology_));
+    if (topology_.gpu_agent_count == 0) {
+      GTEST_SKIP() << "no GPU devices available, skipping tests";
+    }
+  }
+
+  static void TearDownTestSuite() {
+    iree_hal_amdgpu_topology_deinitialize(&topology_);
+    iree_hal_amdgpu_libhsa_deinitialize(&libhsa_);
+  }
+
+  class TestLogicalDevice {
+   public:
+    ~TestLogicalDevice() {
+      iree_hal_device_release(base_device_);
+      iree_hal_device_group_release(device_group_);
+    }
+
+    iree_status_t Initialize(const iree_hal_amdgpu_libhsa_t* libhsa,
+                             const iree_hal_amdgpu_topology_t* topology,
+                             iree_allocator_t host_allocator) {
+      iree_hal_amdgpu_logical_device_options_t options;
+      iree_hal_amdgpu_logical_device_options_initialize(&options);
+      IREE_RETURN_IF_ERROR(create_context_.Initialize(host_allocator));
+      IREE_RETURN_IF_ERROR(iree_hal_amdgpu_logical_device_create(
+          IREE_SV("amdgpu"), &options, libhsa, topology,
+          create_context_.params(), host_allocator, &base_device_));
+      return iree_hal_device_group_create_from_device(
+          base_device_, create_context_.frontier_tracker(), host_allocator,
+          &device_group_);
+    }
+
+    iree_hal_allocator_t* allocator() const {
+      return iree_hal_device_allocator(base_device_);
+    }
+
+   private:
+    // Creation context supplying the proactor pool and frontier tracker.
+    iree::hal::cts::DeviceCreateContext create_context_;
+
+    // Test-owned device reference released before the topology-owning group.
+    iree_hal_device_t* base_device_ = NULL;
+
+    // Device group that owns the topology assigned to |base_device_|.
+    iree_hal_device_group_t* device_group_ = NULL;
+  };
+
+  static iree_allocator_t host_allocator_;
+  static iree_hal_amdgpu_libhsa_t libhsa_;
+  static iree_hal_amdgpu_topology_t topology_;
+};
+
+iree_allocator_t AllocatorTest::host_allocator_;
+iree_hal_amdgpu_libhsa_t AllocatorTest::libhsa_;
+iree_hal_amdgpu_topology_t AllocatorTest::topology_;
+
+TEST_F(AllocatorTest, QueryMemoryHeapsReportsHsaLimits) {
+  TestLogicalDevice test_device;
+  IREE_ASSERT_OK(test_device.Initialize(&libhsa_, &topology_, host_allocator_));
+
+  iree_host_size_t heap_count = 0;
+  IREE_EXPECT_STATUS_IS(IREE_STATUS_OUT_OF_RANGE,
+                        iree_hal_allocator_query_memory_heaps(
+                            test_device.allocator(),
+                            /*capacity=*/0, /*heaps=*/NULL, &heap_count));
+  ASSERT_EQ(heap_count, 3u);
+
+  std::array<iree_hal_allocator_memory_heap_t, 3> heaps;
+  IREE_ASSERT_OK(iree_hal_allocator_query_memory_heaps(
+      test_device.allocator(), heaps.size(), heaps.data(), &heap_count));
+  ASSERT_EQ(heap_count, heaps.size());
+
+  for (const auto& heap : heaps) {
+    EXPECT_NE(heap.max_allocation_size, 0u);
+    EXPECT_NE(heap.max_allocation_size, ~(iree_device_size_t)0);
+    EXPECT_NE(heap.min_alignment, 0u);
+    EXPECT_TRUE(iree_device_size_is_power_of_two(heap.min_alignment));
+  }
+}
+
+TEST_F(AllocatorTest, OversizedAllocationIsRejectedByCompatibility) {
+  TestLogicalDevice test_device;
+  IREE_ASSERT_OK(test_device.Initialize(&libhsa_, &topology_, host_allocator_));
+
+  std::array<iree_hal_allocator_memory_heap_t, 3> heaps;
+  iree_host_size_t heap_count = 0;
+  IREE_ASSERT_OK(iree_hal_allocator_query_memory_heaps(
+      test_device.allocator(), heaps.size(), heaps.data(), &heap_count));
+  ASSERT_EQ(heap_count, heaps.size());
+  ASSERT_LT(heaps[0].max_allocation_size, ~(iree_device_size_t)0);
+
+  iree_device_size_t oversized_allocation_size = 0;
+  ASSERT_TRUE(iree_device_size_checked_add(heaps[0].max_allocation_size, 1,
+                                           &oversized_allocation_size));
+
+  iree_hal_buffer_params_t params = {0};
+  params.type = IREE_HAL_MEMORY_TYPE_DEVICE_LOCAL;
+  params.usage = IREE_HAL_BUFFER_USAGE_TRANSFER;
+
+  iree_hal_buffer_params_t resolved_params = {0};
+  iree_device_size_t resolved_allocation_size = 0;
+  const iree_hal_buffer_compatibility_t compatibility =
+      iree_hal_allocator_query_buffer_compatibility(
+          test_device.allocator(), params, oversized_allocation_size,
+          &resolved_params, &resolved_allocation_size);
+  EXPECT_EQ(compatibility, IREE_HAL_BUFFER_COMPATIBILITY_NONE);
+}
+
+TEST_F(AllocatorTest, DeviceLocalHostVisibleMemoryIsLowPerformance) {
+  TestLogicalDevice test_device;
+  IREE_ASSERT_OK(test_device.Initialize(&libhsa_, &topology_, host_allocator_));
+
+  iree_hal_buffer_params_t params = {0};
+  params.type =
+      IREE_HAL_MEMORY_TYPE_DEVICE_LOCAL | IREE_HAL_MEMORY_TYPE_HOST_VISIBLE;
+  params.usage = IREE_HAL_BUFFER_USAGE_TRANSFER |
+                 IREE_HAL_BUFFER_USAGE_DISPATCH |
+                 IREE_HAL_BUFFER_USAGE_MAPPING_SCOPED;
+
+  iree_hal_buffer_params_t resolved_params = {0};
+  iree_device_size_t resolved_allocation_size = 0;
+  const iree_hal_buffer_compatibility_t compatibility =
+      iree_hal_allocator_query_buffer_compatibility(
+          test_device.allocator(), params, /*allocation_size=*/4096,
+          &resolved_params, &resolved_allocation_size);
+  EXPECT_TRUE(iree_all_bits_set(
+      compatibility, IREE_HAL_BUFFER_COMPATIBILITY_ALLOCATABLE |
+                         IREE_HAL_BUFFER_COMPATIBILITY_QUEUE_TRANSFER |
+                         IREE_HAL_BUFFER_COMPATIBILITY_QUEUE_DISPATCH |
+                         IREE_HAL_BUFFER_COMPATIBILITY_LOW_PERFORMANCE));
+  EXPECT_TRUE(iree_all_bits_set(resolved_params.type,
+                                IREE_HAL_MEMORY_TYPE_DEVICE_LOCAL |
+                                    IREE_HAL_MEMORY_TYPE_HOST_VISIBLE |
+                                    IREE_HAL_MEMORY_TYPE_HOST_COHERENT));
+}
+
+TEST_F(AllocatorTest, OverAlignedAllocationIsRejected) {
+  TestLogicalDevice test_device;
+  IREE_ASSERT_OK(test_device.Initialize(&libhsa_, &topology_, host_allocator_));
+
+  std::array<iree_hal_allocator_memory_heap_t, 3> heaps;
+  iree_host_size_t heap_count = 0;
+  IREE_ASSERT_OK(iree_hal_allocator_query_memory_heaps(
+      test_device.allocator(), heaps.size(), heaps.data(), &heap_count));
+  ASSERT_EQ(heap_count, heaps.size());
+
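+  // Isolate the most-significant bit of iree_device_size_t: a power of two
+  // guaranteed to exceed any pool alignment the allocator can report.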
+  const iree_device_size_t over_alignment =
+      ~(iree_device_size_t)0 ^ (~(iree_device_size_t)0 >> 1);
+  ASSERT_TRUE(iree_device_size_is_power_of_two(over_alignment));
+  ASSERT_GT(over_alignment, heaps[0].min_alignment);
+
+  iree_hal_buffer_params_t params = {0};
+  params.type = IREE_HAL_MEMORY_TYPE_DEVICE_LOCAL;
+  params.usage = IREE_HAL_BUFFER_USAGE_TRANSFER;
+  params.min_alignment = over_alignment;
+
+  iree_hal_buffer_params_t resolved_params = {0};
+  iree_device_size_t resolved_allocation_size = 0;
+  const iree_hal_buffer_compatibility_t compatibility =
+      iree_hal_allocator_query_buffer_compatibility(
+          test_device.allocator(), params, /*allocation_size=*/1,
+          &resolved_params, &resolved_allocation_size);
+  EXPECT_EQ(compatibility, IREE_HAL_BUFFER_COMPATIBILITY_NONE);
+
+  iree_hal_buffer_t* buffer = NULL;
+  IREE_EXPECT_STATUS_IS(
+      IREE_STATUS_INVALID_ARGUMENT,
+      iree_hal_allocator_allocate_buffer(test_device.allocator(), params,
+                                         /*allocation_size=*/1, &buffer));
+  EXPECT_EQ(buffer, nullptr);
+}
+
+TEST_F(AllocatorTest, UnsupportedExternalBufferImportsFailLoud) {
+  TestLogicalDevice test_device;
+  IREE_ASSERT_OK(test_device.Initialize(&libhsa_, &topology_, host_allocator_));
+
+  iree_hal_buffer_params_t params = {0};
+  params.type =
+      IREE_HAL_MEMORY_TYPE_HOST_LOCAL | IREE_HAL_MEMORY_TYPE_DEVICE_VISIBLE;
+  params.usage = IREE_HAL_BUFFER_USAGE_TRANSFER;
+
+  std::array<iree_hal_external_buffer_type_t, 3> unsupported_types = {
+      IREE_HAL_EXTERNAL_BUFFER_TYPE_DEVICE_ALLOCATION,
+      IREE_HAL_EXTERNAL_BUFFER_TYPE_OPAQUE_FD,
+      IREE_HAL_EXTERNAL_BUFFER_TYPE_OPAQUE_WIN32,
+  };
+  for (iree_hal_external_buffer_type_t unsupported_type : unsupported_types) {
+    iree_hal_external_buffer_t external_buffer = {};
+    external_buffer.type = unsupported_type;
+    external_buffer.size = 4096;
+
+    iree_hal_buffer_t* buffer = NULL;
+    IREE_EXPECT_STATUS_IS(
+        IREE_STATUS_UNIMPLEMENTED,
+        iree_hal_allocator_import_buffer(
+            test_device.allocator(), params, &external_buffer,
+            iree_hal_buffer_release_callback_null(), &buffer));
+    EXPECT_EQ(buffer, nullptr);
+  }
+}
+
+TEST_F(AllocatorTest, ExternalBufferExportFailsLoud) {
+  TestLogicalDevice test_device;
+  IREE_ASSERT_OK(test_device.Initialize(&libhsa_, &topology_, host_allocator_));
+
+  iree_hal_buffer_params_t params = {0};
+  params.type =
+      IREE_HAL_MEMORY_TYPE_HOST_LOCAL | IREE_HAL_MEMORY_TYPE_DEVICE_VISIBLE;
+  params.usage = IREE_HAL_BUFFER_USAGE_TRANSFER;
+
+  iree_hal_buffer_t* buffer = NULL;
+  IREE_ASSERT_OK(iree_hal_allocator_allocate_buffer(
+      test_device.allocator(), params, /*allocation_size=*/4096, &buffer));
+
+  iree_hal_external_buffer_t external_buffer = {};
+  IREE_EXPECT_STATUS_IS(
+      IREE_STATUS_UNAVAILABLE,
+      iree_hal_allocator_export_buffer(
+          test_device.allocator(), buffer,
+          IREE_HAL_EXTERNAL_BUFFER_TYPE_OPAQUE_WIN32,
+          IREE_HAL_EXTERNAL_BUFFER_FLAG_NONE, &external_buffer));
+  EXPECT_EQ(external_buffer.type, IREE_HAL_EXTERNAL_BUFFER_TYPE_NONE);
+  EXPECT_EQ(external_buffer.size, 0u);
+
+  iree_hal_buffer_release(buffer);
+}
+
+}  // namespace
+}  // namespace iree::hal::amdgpu
diff --git a/runtime/src/iree/hal/drivers/amdgpu/api.h b/runtime/src/iree/hal/drivers/amdgpu/api.h
index 8f1cf5b..996cda0 100644
--- a/runtime/src/iree/hal/drivers/amdgpu/api.h
+++ b/runtime/src/iree/hal/drivers/amdgpu/api.h
@@ -24,17 +24,16 @@
 
 // Controls where the queue operates.
 typedef enum iree_hal_amdgpu_queue_placement_e {
-  // Automatically select where to place the queue based on whether the target
-  // CPU/GPU agent pair supports device placement requirements.
+  // Automatically select the best supported placement. Today this selects the
+  // host queue path because device-side queue scheduling is not implemented.
   IREE_HAL_AMDGPU_QUEUE_PLACEMENT_ANY = 0,
   // Queue executes entirely on the host via iree_hal_amdgpu_host_queue_t.
   // This introduces additional latency on all queue operations but can operate
   // on systems without host/device atomics (PCIe atomics, xGMI, etc). It is
   // also useful for debugging.
   IREE_HAL_AMDGPU_QUEUE_PLACEMENT_HOST,
-  // Queue executes entirely on the device via iree_hal_amdgpu_device_queue_t.
-  // A scheduler kernel handles all queue entry processing without host
-  // involvement.
+  // Queue executes entirely on the device. Not implemented; requests for this
+  // explicit placement fail during device option verification.
   IREE_HAL_AMDGPU_QUEUE_PLACEMENT_DEVICE,
 } iree_hal_amdgpu_queue_placement_t;
 
@@ -44,14 +43,22 @@
 typedef struct iree_hal_amdgpu_logical_device_options_t {
   // Size of a block in each host block pool.
   struct {
+    // Small host block pool options.
     struct {
       // Size in bytes of a small host block. Must be a power of two.
       iree_host_size_t block_size;
     } small;
+    // Large host block pool options.
     struct {
       // Size in bytes of a large host block. Must be a power of two.
       iree_host_size_t block_size;
     } large;
+    // Command-buffer host block pool options.
+    struct {
+      // Usable byte capacity of a command-buffer recording block. Must be a
+      // power of two.
+      iree_host_size_t usable_block_size;
+    } command_buffer;
   } host_block_pools;
 
   // Size of a block in each device block pool.
@@ -70,29 +77,60 @@
     } large;
   } device_block_pools;
 
-  // Controls where queues are placed.
-  // Defaults to IREE_HAL_AMDGPU_QUEUE_PLACEMENT_ANY and selects the optimal
-  // placement based on queried agent properties. If a placement is explicitly
-  // specified all physical devices will use that placement and if any do not
-  // support it initialization will fail (useful for forcing host placement
-  // during debugging/testing).
+  // Default queue-allocation pool policy.
+  struct {
+    // Logical byte length of the default TLSF pool range per physical device.
+    iree_device_size_t range_length;
+
+    // Minimum byte alignment for every default-pool reservation.
+    iree_device_size_t alignment;
+
+    // Maximum death-frontier entry count stored per free TLSF block.
+    uint8_t frontier_capacity;
+  } default_pool;
+
+  // Controls where queues are placed. ANY and HOST currently select host
+  // queues. DEVICE is reserved for future device-side scheduling and fails
+  // loudly until that path is implemented.
   iree_hal_amdgpu_queue_placement_t queue_placement;
 
+  // Per-physical-device host queue policy.
+  struct {
+    // HSA AQL ring capacity in packets for each host queue. Must be a power of
+    // two. Larger rings allow more in-flight packet work before submitters see
+    // AQL backpressure.
+    uint32_t aql_capacity;
+    // Completion/reclaim ring capacity for each host queue. Must be a power of
+    // two. This bounds in-flight host-visible completion epochs before replay
+    // must park and resume after drain.
+    uint32_t notification_capacity;
+    // Kernarg ring capacity in 64-byte blocks for each host queue. Must be a
+    // power of two and at least 2x |aql_capacity| to cover one tail-padding
+    // gap at wrap. Submission admission checks kernarg and AQL capacity
+    // together before publishing packets.
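+    // For example, aql_capacity=4096 implies kernarg_capacity of at least
+    // 8192 (8192 64-byte blocks = 512KiB of kernarg storage).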
+    uint32_t kernarg_capacity;
+    // Device-visible control upload ring capacity in bytes for each host queue.
+    // Zero disables the optional upload ring; non-zero values must be powers of
+    // two. This carries queue-ordered metadata such as device-side
+    // command-buffer fixup inputs without using the file staging pool.
+    uint32_t upload_capacity;
+  } host_queues;
+
   // Preallocates a reasonable number of resources in pools to reduce initial
   // execution latency.
   uint64_t preallocate_pools : 1;
 
-  // Enables dispatch-level tracing (if device instrumentation is compiled in).
-  uint64_t trace_execution : 1;
-
-  // Forces queues to run one entry at a time instead of overlapping or
-  // aggressively scheduling queue entries out-of-order.
+  // Reserved for a future exclusive queue scheduling mode. Unsupported today;
+  // enabling it fails option verification.
   uint64_t exclusive_execution : 1;
 
-  // Uses HSA_WAIT_STATE_ACTIVE for up to the given duration before switching to
-  // HSA_WAIT_STATE_BLOCKED. Above zero this will increase CPU usage in cases
-  // where the waits are long and decrease latency in cases where the waits are
-  // short. When IREE_DURATION_INFINITE waits will use HSA_WAIT_STATE_ACTIVE.
+  // Forces cross-queue wait barriers to use software deferral instead of the
+  // device-side strategy selected from the GPU ISA. Useful for testing the
+  // conservative host-only fallback path.
+  uint64_t force_wait_barrier_defer : 1;
+
+  // Reserved for future HSA active-wait tuning. Must be zero today because no
+  // wait path consumes it yet.
   iree_duration_t wait_active_for_ns;
 } iree_hal_amdgpu_logical_device_options_t;
 
@@ -100,9 +138,9 @@
 IREE_API_EXPORT void iree_hal_amdgpu_logical_device_options_initialize(
     iree_hal_amdgpu_logical_device_options_t* out_options);
 
-// Parses |params| and updates |options|.
-// String views may be set to reference strings in the original parameters and
-// the caller must ensure the options does not outlive the storage.
+// Parses |params| and updates |options|. No AMDGPU logical-device string
+// parameters are currently supported; nonempty lists fail loudly instead of
+// being ignored.
 IREE_API_EXPORT iree_status_t iree_hal_amdgpu_logical_device_options_parse(
     iree_hal_amdgpu_logical_device_options_t* options,
     iree_string_pair_list_t params);
@@ -134,7 +172,8 @@
 // use.
 typedef struct iree_hal_amdgpu_driver_options_t {
   // Search paths (directories or files) for finding the HSA runtime shared
-  // library.
+  // library. Driver creation clones these strings; callers only need to keep
+  // them live until iree_hal_amdgpu_driver_create returns.
   iree_string_view_list_t libhsa_search_paths;
 
   // Default device options when none are provided during device creation.
@@ -145,9 +184,8 @@
 IREE_API_EXPORT void iree_hal_amdgpu_driver_options_initialize(
     iree_hal_amdgpu_driver_options_t* out_options);
 
-// Parses |params| and updates |options|.
-// String views may be set to reference strings in the original parameters and
-// the caller must ensure the options does not outlive the storage.
+// Parses |params| and updates |options|. No AMDGPU driver string parameters are
+// currently supported; nonempty lists fail loudly instead of being ignored.
 IREE_API_EXPORT iree_status_t iree_hal_amdgpu_driver_options_parse(
     iree_hal_amdgpu_driver_options_t* options, iree_string_pair_list_t params);
 
diff --git a/runtime/src/iree/hal/drivers/amdgpu/aql_block_processor.c b/runtime/src/iree/hal/drivers/amdgpu/aql_block_processor.c
new file mode 100644
index 0000000..120c97d
--- /dev/null
+++ b/runtime/src/iree/hal/drivers/amdgpu/aql_block_processor.c
@@ -0,0 +1,1184 @@
+// Copyright 2026 The IREE Authors
+//
+// Licensed under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+
+#include "iree/hal/drivers/amdgpu/aql_block_processor.h"
+
+#include <string.h>
+
+#include "iree/hal/drivers/amdgpu/buffer.h"
+#include "iree/hal/drivers/amdgpu/device/dispatch.h"
+#include "iree/hal/drivers/amdgpu/util/aql_emitter.h"
+
+typedef uint32_t iree_hal_amdgpu_aql_block_processor_packet_flags_t;
+enum iree_hal_amdgpu_aql_block_processor_packet_flag_bits_t {
+  IREE_HAL_AMDGPU_AQL_BLOCK_PROCESSOR_PACKET_FLAG_NONE = 0u,
+  // Packet must participate in the command-buffer execution dependency chain.
+  IREE_HAL_AMDGPU_AQL_BLOCK_PROCESSOR_PACKET_FLAG_EXECUTION_BARRIER = 1u << 0,
+  // Packet carries terminal signal release scope for the block submission.
+  IREE_HAL_AMDGPU_AQL_BLOCK_PROCESSOR_PACKET_FLAG_FINAL = 1u << 1,
+  // First bit of the two-bit acquire fence scope field in packet flags.
+  IREE_HAL_AMDGPU_AQL_BLOCK_PROCESSOR_PACKET_ACQUIRE_SCOPE_SHIFT = 2,
+  // Bit mask of the two-bit acquire fence scope field in packet flags.
+  IREE_HAL_AMDGPU_AQL_BLOCK_PROCESSOR_PACKET_ACQUIRE_SCOPE_MASK = 0x0Cu,
+  // First bit of the two-bit release fence scope field in packet flags.
+  IREE_HAL_AMDGPU_AQL_BLOCK_PROCESSOR_PACKET_RELEASE_SCOPE_SHIFT = 4,
+  // Bit mask of the two-bit release fence scope field in packet flags.
+  IREE_HAL_AMDGPU_AQL_BLOCK_PROCESSOR_PACKET_RELEASE_SCOPE_MASK = 0x30u,
+};
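+
+// Packed flag layout (low bits):
+//   bit 0: EXECUTION_BARRIER   bit 1: FINAL
+//   bits 3:2: acquire fence scope   bits 5:4: release fence scope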
+
+typedef struct iree_hal_amdgpu_aql_block_processor_state_t {
+  // Packet cursors advanced while invoking the block.
+  struct {
+    // Number of recorded block AQL packets consumed from the block.
+    uint32_t recorded;
+    // Number of payload AQL packets emitted into the reserved packet span.
+    uint32_t emitted;
+  } packets;
+  // Queue-owned kernarg cursors advanced while invoking the block.
+  struct {
+    // Number of queue-owned kernarg blocks consumed from the reserved span.
+    uint32_t block;
+  } kernargs;
+} iree_hal_amdgpu_aql_block_processor_state_t;
+
+static iree_hsa_fence_scope_t iree_hal_amdgpu_aql_block_processor_max_scope(
+    iree_hsa_fence_scope_t lhs, iree_hsa_fence_scope_t rhs) {
+  return lhs > rhs ? lhs : rhs;
+}
+
+static iree_hal_amdgpu_aql_block_processor_packet_flags_t
+iree_hal_amdgpu_aql_block_processor_packet_flags_set_fence_scopes(
+    iree_hal_amdgpu_aql_block_processor_packet_flags_t flags,
+    iree_hsa_fence_scope_t acquire_scope,
+    iree_hsa_fence_scope_t release_scope) {
+  flags &= ~(IREE_HAL_AMDGPU_AQL_BLOCK_PROCESSOR_PACKET_ACQUIRE_SCOPE_MASK |
+             IREE_HAL_AMDGPU_AQL_BLOCK_PROCESSOR_PACKET_RELEASE_SCOPE_MASK);
+  flags |= ((uint32_t)acquire_scope & 0x3u)
+           << IREE_HAL_AMDGPU_AQL_BLOCK_PROCESSOR_PACKET_ACQUIRE_SCOPE_SHIFT;
+  flags |= ((uint32_t)release_scope & 0x3u)
+           << IREE_HAL_AMDGPU_AQL_BLOCK_PROCESSOR_PACKET_RELEASE_SCOPE_SHIFT;
+  return flags;
+}
+
+static iree_hsa_fence_scope_t
+iree_hal_amdgpu_aql_block_processor_packet_flags_fence_scope(
+    iree_hal_amdgpu_aql_block_processor_packet_flags_t flags, uint32_t mask,
+    uint32_t shift) {
+  return (iree_hsa_fence_scope_t)((flags & mask) >> shift);
+}
+
+static iree_hsa_fence_scope_t
+iree_hal_amdgpu_aql_block_processor_packet_flags_acquire_scope(
+    iree_hal_amdgpu_aql_block_processor_packet_flags_t flags) {
+  return iree_hal_amdgpu_aql_block_processor_packet_flags_fence_scope(
+      flags, IREE_HAL_AMDGPU_AQL_BLOCK_PROCESSOR_PACKET_ACQUIRE_SCOPE_MASK,
+      IREE_HAL_AMDGPU_AQL_BLOCK_PROCESSOR_PACKET_ACQUIRE_SCOPE_SHIFT);
+}
+
+static iree_hsa_fence_scope_t
+iree_hal_amdgpu_aql_block_processor_packet_flags_release_scope(
+    iree_hal_amdgpu_aql_block_processor_packet_flags_t flags) {
+  return iree_hal_amdgpu_aql_block_processor_packet_flags_fence_scope(
+      flags, IREE_HAL_AMDGPU_AQL_BLOCK_PROCESSOR_PACKET_RELEASE_SCOPE_MASK,
+      IREE_HAL_AMDGPU_AQL_BLOCK_PROCESSOR_PACKET_RELEASE_SCOPE_SHIFT);
+}
+
+static bool iree_hal_amdgpu_aql_block_processor_packet_has_barrier(
+    const iree_hal_amdgpu_aql_block_processor_t* processor,
+    uint32_t packet_index,
+    iree_hal_amdgpu_aql_block_processor_packet_flags_t packet_flags) {
+  return iree_any_bit_set(
+             packet_flags,
+             IREE_HAL_AMDGPU_AQL_BLOCK_PROCESSOR_PACKET_FLAG_EXECUTION_BARRIER |
+                 IREE_HAL_AMDGPU_AQL_BLOCK_PROCESSOR_PACKET_FLAG_FINAL) ||
+         (packet_index == 0 &&
+          (processor->submission.wait_barrier_count > 0 ||
+           processor->submission.inline_acquire_scope !=
+               IREE_HSA_FENCE_SCOPE_NONE));
+}
+
+static iree_hal_amdgpu_aql_packet_control_t
+iree_hal_amdgpu_aql_block_processor_packet_control(
+    const iree_hal_amdgpu_aql_block_processor_t* processor,
+    uint32_t packet_index, iree_hsa_fence_scope_t minimum_acquire_scope,
+    iree_hal_amdgpu_aql_block_processor_packet_flags_t packet_flags) {
+  // The replay hot path is intentionally additive: command recording has
+  // already encoded execution-barrier scopes in |packet_flags|, and submission
+  // policy only overlays wait, queue-kernarg, and terminal-signal visibility.
+  // Do not infer memory hazards here by walking command operands.
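+  // For example, a packet recorded with AGENT release scope that also carries
+  // the FINAL flag is widened to at least the submission's signal release
+  // scope so the terminal signal is visible to external waiters.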
+  const uint32_t logical_packet_index =
+      processor->packets.index_base + packet_index;
+  const bool has_barrier =
+      iree_hal_amdgpu_aql_block_processor_packet_has_barrier(
+          processor, logical_packet_index, packet_flags);
+  const iree_hsa_fence_scope_t execution_acquire_scope =
+      iree_hal_amdgpu_aql_block_processor_packet_flags_acquire_scope(
+          packet_flags);
+  const iree_hsa_fence_scope_t execution_release_scope =
+      iree_hal_amdgpu_aql_block_processor_packet_flags_release_scope(
+          packet_flags);
+  const iree_hsa_fence_scope_t acquire_scope =
+      logical_packet_index == 0
+          ? iree_hal_amdgpu_aql_block_processor_max_scope(
+                execution_acquire_scope,
+                processor->submission.inline_acquire_scope)
+          : execution_acquire_scope;
+  const iree_hsa_fence_scope_t effective_acquire_scope =
+      iree_hal_amdgpu_aql_block_processor_max_scope(acquire_scope,
+                                                    minimum_acquire_scope);
+  iree_hsa_fence_scope_t release_scope = execution_release_scope;
+  if (iree_any_bit_set(packet_flags,
+                       IREE_HAL_AMDGPU_AQL_BLOCK_PROCESSOR_PACKET_FLAG_FINAL)) {
+    release_scope = iree_hal_amdgpu_aql_block_processor_max_scope(
+        release_scope, processor->submission.signal_release_scope);
+  }
+  return iree_hal_amdgpu_aql_packet_control(
+      has_barrier, effective_acquire_scope, release_scope);
+}
+
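+// Translates a recorded command header into processor packet flags: the
+// recorded barrier bit becomes an execution barrier and the recorded
+// acquire/release scopes carry through unchanged.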
+static iree_hal_amdgpu_aql_block_processor_packet_flags_t
+iree_hal_amdgpu_aql_block_processor_command_packet_flags(
+    const iree_hal_amdgpu_command_buffer_command_header_t* command) {
+  iree_hal_amdgpu_aql_block_processor_packet_flags_t packet_flags =
+      IREE_HAL_AMDGPU_AQL_BLOCK_PROCESSOR_PACKET_FLAG_NONE;
+  if (iree_any_bit_set(
+          command->flags,
+          IREE_HAL_AMDGPU_COMMAND_BUFFER_COMMAND_FLAG_HAS_BARRIER)) {
+    packet_flags |=
+        IREE_HAL_AMDGPU_AQL_BLOCK_PROCESSOR_PACKET_FLAG_EXECUTION_BARRIER;
+  }
+  return iree_hal_amdgpu_aql_block_processor_packet_flags_set_fence_scopes(
+      packet_flags,
+      (iree_hsa_fence_scope_t)
+          iree_hal_amdgpu_command_buffer_command_flags_acquire_scope(
+              command->flags),
+      (iree_hsa_fence_scope_t)
+          iree_hal_amdgpu_command_buffer_command_flags_release_scope(
+              command->flags));
+}
+
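+// Merges two packet flag sets by ORing the bit flags and taking the maximum
+// of each fence scope field (scopes are ordered, so ORing the 2-bit fields
+// alone could produce invalid encodings).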
+static iree_hal_amdgpu_aql_block_processor_packet_flags_t
+iree_hal_amdgpu_aql_block_processor_packet_flags_merge(
+    iree_hal_amdgpu_aql_block_processor_packet_flags_t lhs,
+    iree_hal_amdgpu_aql_block_processor_packet_flags_t rhs) {
+  const iree_hsa_fence_scope_t acquire_scope =
+      iree_hal_amdgpu_aql_block_processor_max_scope(
+          iree_hal_amdgpu_aql_block_processor_packet_flags_acquire_scope(lhs),
+          iree_hal_amdgpu_aql_block_processor_packet_flags_acquire_scope(rhs));
+  const iree_hsa_fence_scope_t release_scope =
+      iree_hal_amdgpu_aql_block_processor_max_scope(
+          iree_hal_amdgpu_aql_block_processor_packet_flags_release_scope(lhs),
+          iree_hal_amdgpu_aql_block_processor_packet_flags_release_scope(rhs));
+  return iree_hal_amdgpu_aql_block_processor_packet_flags_set_fence_scopes(
+      lhs | rhs, acquire_scope, release_scope);
+}
+
+static iree_hal_amdgpu_aql_block_processor_packet_flags_t
+iree_hal_amdgpu_aql_block_processor_agent_barrier_packet_flags(void) {
+  return iree_hal_amdgpu_aql_block_processor_packet_flags_set_fence_scopes(
+      IREE_HAL_AMDGPU_AQL_BLOCK_PROCESSOR_PACKET_FLAG_EXECUTION_BARRIER,
+      IREE_HSA_FENCE_SCOPE_AGENT, IREE_HSA_FENCE_SCOPE_AGENT);
+}
+
+static bool iree_hal_amdgpu_aql_block_processor_command_uses_queue_kernargs(
+    const iree_hal_amdgpu_command_buffer_command_header_t* command) {
+  return iree_any_bit_set(
+      command->flags,
+      IREE_HAL_AMDGPU_COMMAND_BUFFER_COMMAND_FLAG_USES_QUEUE_KERNARGS);
+}
+
+static bool iree_hal_amdgpu_aql_block_processor_dispatch_uses_indirect(
+    const iree_hal_amdgpu_command_buffer_dispatch_command_t* dispatch_command) {
+  return iree_any_bit_set(
+      dispatch_command->dispatch_flags,
+      IREE_HAL_AMDGPU_COMMAND_BUFFER_DISPATCH_FLAG_INDIRECT_PARAMETERS);
+}
+
+static bool iree_hal_amdgpu_aql_block_processor_dispatch_uses_prepublished(
+    const iree_hal_amdgpu_command_buffer_dispatch_command_t* dispatch_command) {
+  return dispatch_command->kernarg_strategy ==
+         IREE_HAL_AMDGPU_COMMAND_BUFFER_KERNARG_STRATEGY_PREPUBLISHED;
+}
+
+static bool iree_hal_amdgpu_aql_block_processor_dispatch_uses_queue_kernargs(
+    const iree_hal_amdgpu_command_buffer_dispatch_command_t* dispatch_command) {
+  return !iree_hal_amdgpu_aql_block_processor_dispatch_uses_prepublished(
+      dispatch_command);
+}
+
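+// Returns the number of queue-owned kernarg blocks the target dispatch
+// consumes: zero for prepublished kernargs, otherwise the kernarg length
+// rounded up to whole blocks with a minimum of one.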
+static uint32_t
+iree_hal_amdgpu_aql_block_processor_dispatch_target_kernarg_block_count(
+    const iree_hal_amdgpu_command_buffer_dispatch_command_t* dispatch_command) {
+  if (iree_hal_amdgpu_aql_block_processor_dispatch_uses_prepublished(
+          dispatch_command)) {
+    return 0;
+  }
+  const uint32_t kernarg_length =
+      (uint32_t)dispatch_command->kernarg_length_qwords * 8u;
+  return iree_max(1u,
+                  (uint32_t)iree_host_size_ceil_div(
+                      kernarg_length, sizeof(iree_hal_amdgpu_kernarg_block_t)));
+}
+
+static uint32_t
+iree_hal_amdgpu_aql_block_processor_dispatch_kernarg_block_count(
+    const iree_hal_amdgpu_command_buffer_dispatch_command_t* dispatch_command) {
+  return iree_hal_amdgpu_aql_block_processor_dispatch_target_kernarg_block_count(
+             dispatch_command) +
+         (iree_hal_amdgpu_aql_block_processor_dispatch_uses_indirect(
+              dispatch_command)
+              ? 1u
+              : 0u);
+}
+
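+// Validates |buffer_ref| against the required usage/access and resolves it
+// to an absolute AMDGPU device pointer with offset overflow checks.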
+static iree_status_t iree_hal_amdgpu_aql_block_processor_resolve_buffer_ref_ptr(
+    iree_hal_buffer_ref_t buffer_ref, iree_hal_buffer_usage_t required_usage,
+    iree_hal_memory_access_t required_access, uint8_t** out_device_ptr) {
+  *out_device_ptr = NULL;
+  if (IREE_UNLIKELY(!buffer_ref.buffer)) {
+    return iree_make_status(
+        IREE_STATUS_INVALID_ARGUMENT,
+        "AQL command-buffer dynamic binding resolved to a NULL buffer");
+  }
+  IREE_RETURN_IF_ERROR(iree_hal_buffer_validate_usage(
+      iree_hal_buffer_allowed_usage(buffer_ref.buffer), required_usage));
+  IREE_RETURN_IF_ERROR(iree_hal_buffer_validate_access(
+      iree_hal_buffer_allowed_access(buffer_ref.buffer), required_access));
+  IREE_RETURN_IF_ERROR(iree_hal_buffer_validate_range(
+      buffer_ref.buffer, buffer_ref.offset, buffer_ref.length));
+  iree_hal_buffer_t* allocated_buffer =
+      iree_hal_buffer_allocated_buffer(buffer_ref.buffer);
+  uint8_t* device_ptr =
+      (uint8_t*)iree_hal_amdgpu_buffer_device_pointer(allocated_buffer);
+  if (IREE_UNLIKELY(!device_ptr)) {
+    return iree_make_status(
+        IREE_STATUS_INVALID_ARGUMENT,
+        "AQL command-buffer buffer must be backed by an AMDGPU allocation");
+  }
+  iree_device_size_t device_offset = 0;
+  if (IREE_UNLIKELY(!iree_device_size_checked_add(
+          iree_hal_buffer_byte_offset(buffer_ref.buffer), buffer_ref.offset,
+          &device_offset))) {
+    return iree_make_status(
+        IREE_STATUS_OUT_OF_RANGE,
+        "AQL command-buffer buffer device pointer offset overflows device "
+        "size");
+  }
+  if (IREE_UNLIKELY(device_offset > UINTPTR_MAX)) {
+    return iree_make_status(
+        IREE_STATUS_OUT_OF_RANGE,
+        "AQL command-buffer buffer device pointer offset exceeds host pointer "
+        "size");
+  }
+  *out_device_ptr = device_ptr + (uintptr_t)device_offset;
+  return iree_ok_status();
+}
+
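+// Resolves a recorded buffer reference to a device pointer: STATIC ordinals
+// index buffers retained by the command buffer at record time while DYNAMIC
+// ordinals resolve through the queue_execute binding table.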
+static iree_status_t
+iree_hal_amdgpu_aql_block_processor_resolve_command_buffer_ref(
+    iree_hal_command_buffer_t* command_buffer,
+    iree_hal_buffer_binding_table_t binding_table,
+    iree_hal_amdgpu_command_buffer_binding_kind_t kind, uint32_t ordinal,
+    uint64_t offset, uint64_t length, iree_hal_buffer_usage_t required_usage,
+    iree_hal_memory_access_t required_access,
+    iree_hal_buffer_ref_t* out_buffer_ref, uint8_t** out_device_ptr) {
+  memset(out_buffer_ref, 0, sizeof(*out_buffer_ref));
+  *out_device_ptr = NULL;
+  if (kind == IREE_HAL_AMDGPU_COMMAND_BUFFER_BINDING_KIND_STATIC) {
+    iree_hal_buffer_t* buffer =
+        iree_hal_amdgpu_aql_command_buffer_static_buffer(command_buffer,
+                                                         ordinal);
+    if (IREE_UNLIKELY(!buffer)) {
+      return iree_make_status(
+          IREE_STATUS_INVALID_ARGUMENT,
+          "AQL command-buffer static buffer ordinal %" PRIu32 " is invalid",
+          ordinal);
+    }
+    *out_buffer_ref = iree_hal_make_buffer_ref(buffer, offset, length);
+  } else if (kind == IREE_HAL_AMDGPU_COMMAND_BUFFER_BINDING_KIND_DYNAMIC) {
+    iree_hal_buffer_ref_t dynamic_ref =
+        iree_hal_make_indirect_buffer_ref(ordinal, offset, length);
+    IREE_RETURN_IF_ERROR(iree_hal_buffer_binding_table_resolve_ref(
+        binding_table, dynamic_ref, out_buffer_ref));
+  } else {
+    return iree_make_status(IREE_STATUS_INVALID_ARGUMENT,
+                            "AQL command-buffer binding kind %u is invalid",
+                            kind);
+  }
+  IREE_RETURN_IF_ERROR(
+      iree_hal_amdgpu_aql_block_processor_resolve_buffer_ref_ptr(
+          *out_buffer_ref, required_usage, required_access, out_device_ptr));
+  return iree_ok_status();
+}
+
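+// Resolves a static dispatch binding source to an absolute device address
+// (buffer base plus recorded offset) with overflow checks.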
+static iree_status_t
+iree_hal_amdgpu_aql_block_processor_resolve_static_binding_source_ptr(
+    iree_hal_command_buffer_t* command_buffer,
+    const iree_hal_amdgpu_command_buffer_binding_source_t* binding_source,
+    uint64_t* out_binding_ptr) {
+  *out_binding_ptr = 0;
+  iree_hal_buffer_t* buffer = iree_hal_amdgpu_aql_command_buffer_static_buffer(
+      command_buffer, binding_source->slot);
+  if (IREE_UNLIKELY(!buffer)) {
+    return iree_make_status(
+        IREE_STATUS_INVALID_ARGUMENT,
+        "AQL command-buffer static dispatch binding ordinal %" PRIu32
+        " is invalid",
+        binding_source->slot);
+  }
+  iree_hal_buffer_t* allocated_buffer =
+      iree_hal_buffer_allocated_buffer(buffer);
+  void* device_ptr = iree_hal_amdgpu_buffer_device_pointer(allocated_buffer);
+  if (IREE_UNLIKELY(!device_ptr)) {
+    return iree_make_status(
+        IREE_STATUS_FAILED_PRECONDITION,
+        "AQL command-buffer static dispatch binding has no staged AMDGPU "
+        "backing after queue waits completed");
+  }
+  iree_device_size_t device_offset = 0;
+  if (IREE_UNLIKELY(!iree_device_size_checked_add(
+          iree_hal_buffer_byte_offset(buffer),
+          binding_source->offset_or_pointer, &device_offset))) {
+    return iree_make_status(
+        IREE_STATUS_OUT_OF_RANGE,
+        "AQL command-buffer static dispatch binding pointer offset overflows "
+        "device size");
+  }
+  if (IREE_UNLIKELY(device_offset > UINTPTR_MAX)) {
+    return iree_make_status(
+        IREE_STATUS_OUT_OF_RANGE,
+        "AQL command-buffer static dispatch binding pointer offset exceeds "
+        "host pointer size");
+  }
+  *out_binding_ptr =
+      (uint64_t)((uintptr_t)device_ptr + (uintptr_t)device_offset);
+  return iree_ok_status();
+}
+
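+// Writes the recorded dispatch parameters into |packet|; the header word is
+// produced separately and published last by the caller.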
+static void iree_hal_amdgpu_aql_block_processor_write_dispatch_packet_body(
+    const iree_hal_amdgpu_command_buffer_dispatch_command_t* dispatch_command,
+    iree_hal_amdgpu_aql_packet_t* packet, uint8_t* kernarg_data,
+    iree_hsa_signal_t completion_signal, uint16_t* out_setup) {
+  packet->dispatch.setup = dispatch_command->setup;
+  packet->dispatch.workgroup_size[0] = dispatch_command->workgroup_size[0];
+  packet->dispatch.workgroup_size[1] = dispatch_command->workgroup_size[1];
+  packet->dispatch.workgroup_size[2] = dispatch_command->workgroup_size[2];
+  packet->dispatch.reserved0 = 0;
+  packet->dispatch.grid_size[0] = dispatch_command->grid_size[0];
+  packet->dispatch.grid_size[1] = dispatch_command->grid_size[1];
+  packet->dispatch.grid_size[2] = dispatch_command->grid_size[2];
+  packet->dispatch.private_segment_size =
+      dispatch_command->private_segment_size;
+  packet->dispatch.group_segment_size = dispatch_command->group_segment_size;
+  packet->dispatch.kernel_object = dispatch_command->kernel_object;
+  packet->dispatch.kernarg_address = kernarg_data;
+  packet->dispatch.reserved2 = 0;
+  packet->dispatch.completion_signal = completion_signal;
+  *out_setup = packet->dispatch.setup;
+}
+
+static inline void iree_hal_amdgpu_aql_block_processor_copy_dispatch_tail(
+    const iree_hal_amdgpu_command_buffer_dispatch_command_t* dispatch_command,
+    uint8_t* kernarg_data, iree_host_size_t tail_offset) {
+  const iree_host_size_t tail_length =
+      (iree_host_size_t)dispatch_command->payload.tail_length_qwords * 8u;
+  if (tail_length > 0) {
+    const uint8_t* tail_payload =
+        (const uint8_t*)dispatch_command + dispatch_command->payload_reference;
+    memcpy(kernarg_data + tail_offset, tail_payload, tail_length);
+  }
+}
+
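+// Populates |kernarg_data| for one dispatch according to its recorded
+// kernarg strategy:
+//   DYNAMIC_BINDINGS: all bindings resolve through the binding table.
+//   HAL: per-binding flags select absolute, dynamic, or static sources.
+//   CUSTOM_DIRECT: the recorded tail payload is the whole kernarg segment.
+//   PATCHED_TEMPLATE: a rodata template is copied and then patched with
+//     dynamic binding pointers.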
+static iree_status_t
+iree_hal_amdgpu_aql_block_processor_replay_dispatch_kernargs(
+    const iree_hal_amdgpu_command_buffer_block_header_t* block,
+    iree_hal_command_buffer_t* command_buffer, const uint64_t* binding_ptrs,
+    const iree_hal_amdgpu_command_buffer_dispatch_command_t* dispatch_command,
+    uint8_t* kernarg_data) {
+  switch (dispatch_command->kernarg_strategy) {
+    case IREE_HAL_AMDGPU_COMMAND_BUFFER_KERNARG_STRATEGY_DYNAMIC_BINDINGS: {
+      if (IREE_UNLIKELY(!binding_ptrs)) {
+        return iree_make_status(
+            IREE_STATUS_INVALID_ARGUMENT,
+            "AQL command-buffer dispatch has dynamic bindings but no binding "
+            "table was provided");
+      }
+      uint64_t* binding_dst = (uint64_t*)kernarg_data;
+      const iree_hal_amdgpu_command_buffer_binding_source_t* binding_sources =
+          (const iree_hal_amdgpu_command_buffer_binding_source_t*)  //
+          ((const uint8_t*)block + dispatch_command->binding_source_offset);
+      for (uint16_t i = 0; i < dispatch_command->binding_count; ++i) {
+        const iree_hal_amdgpu_command_buffer_binding_source_t* binding_source =
+            &binding_sources[i];
+        binding_dst[i] = binding_ptrs[binding_source->slot] +
+                         binding_source->offset_or_pointer;
+      }
+      iree_hal_amdgpu_aql_block_processor_copy_dispatch_tail(
+          dispatch_command, kernarg_data,
+          (iree_host_size_t)dispatch_command->binding_count * sizeof(uint64_t));
+      return iree_ok_status();
+    }
+    case IREE_HAL_AMDGPU_COMMAND_BUFFER_KERNARG_STRATEGY_HAL: {
+      uint64_t* binding_dst = (uint64_t*)kernarg_data;
+      const iree_hal_amdgpu_command_buffer_binding_source_t* binding_sources =
+          (const iree_hal_amdgpu_command_buffer_binding_source_t*)  //
+          ((const uint8_t*)block + dispatch_command->binding_source_offset);
+      for (uint16_t i = 0; i < dispatch_command->binding_count; ++i) {
+        const iree_hal_amdgpu_command_buffer_binding_source_t* binding_source =
+            &binding_sources[i];
+        const uint32_t flags = binding_source->flags;
+        if (IREE_LIKELY(
+                flags ==
+                IREE_HAL_AMDGPU_COMMAND_BUFFER_BINDING_SOURCE_FLAG_NONE)) {
+          binding_dst[i] = binding_source->offset_or_pointer;
+        } else if (flags ==
+                   IREE_HAL_AMDGPU_COMMAND_BUFFER_BINDING_SOURCE_FLAG_DYNAMIC) {
+          if (IREE_UNLIKELY(!binding_ptrs)) {
+            return iree_make_status(
+                IREE_STATUS_INVALID_ARGUMENT,
+                "AQL command-buffer dispatch has dynamic bindings but no "
+                "binding table was provided");
+          }
+          binding_dst[i] = binding_ptrs[binding_source->slot] +
+                           binding_source->offset_or_pointer;
+        } else if (
+            flags ==
+            IREE_HAL_AMDGPU_COMMAND_BUFFER_BINDING_SOURCE_FLAG_STATIC_BUFFER) {
+          IREE_RETURN_IF_ERROR(
+              iree_hal_amdgpu_aql_block_processor_resolve_static_binding_source_ptr(
+                  command_buffer, binding_source, &binding_dst[i]));
+        } else {
+          return iree_make_status(
+              IREE_STATUS_INVALID_ARGUMENT,
+              "malformed AQL command-buffer dispatch binding source flags %u",
+              binding_source->flags);
+        }
+      }
+      iree_hal_amdgpu_aql_block_processor_copy_dispatch_tail(
+          dispatch_command, kernarg_data,
+          (iree_host_size_t)dispatch_command->binding_count * sizeof(uint64_t));
+      return iree_ok_status();
+    }
+    case IREE_HAL_AMDGPU_COMMAND_BUFFER_KERNARG_STRATEGY_CUSTOM_DIRECT: {
+      iree_hal_amdgpu_aql_block_processor_copy_dispatch_tail(
+          dispatch_command, kernarg_data, /*tail_offset=*/0);
+      return iree_ok_status();
+    }
+    case IREE_HAL_AMDGPU_COMMAND_BUFFER_KERNARG_STRATEGY_INDIRECT:
+      return iree_make_status(
+          IREE_STATUS_UNIMPLEMENTED,
+          "indirect dispatch arguments are not supported by AMDGPU command "
+          "buffers yet");
+    case IREE_HAL_AMDGPU_COMMAND_BUFFER_KERNARG_STRATEGY_PATCHED_TEMPLATE: {
+      const uint32_t kernarg_length =
+          (uint32_t)dispatch_command->kernarg_length_qwords * 8u;
+      const uint8_t* kernarg_template =
+          iree_hal_amdgpu_aql_command_buffer_rodata(
+              command_buffer, dispatch_command->payload_reference,
+              kernarg_length);
+      if (IREE_UNLIKELY(!kernarg_template)) {
+        return iree_make_status(
+            IREE_STATUS_INVALID_ARGUMENT,
+            "AQL command-buffer patched kernarg template range is invalid");
+      }
+      memcpy(kernarg_data, kernarg_template, kernarg_length);
+      uint64_t* binding_dst = (uint64_t*)kernarg_data;
+      const uint16_t patch_source_count =
+          dispatch_command->payload.patch_source_count;
+      if (IREE_UNLIKELY(!binding_ptrs)) {
+        return iree_make_status(
+            IREE_STATUS_INVALID_ARGUMENT,
+            "AQL command-buffer dispatch has dynamic bindings but no binding "
+            "table was provided");
+      }
+      const iree_hal_amdgpu_command_buffer_binding_source_t* binding_sources =
+          (const iree_hal_amdgpu_command_buffer_binding_source_t*)  //
+          ((const uint8_t*)block + dispatch_command->binding_source_offset);
+      for (uint16_t i = 0; i < patch_source_count; ++i) {
+        const iree_hal_amdgpu_command_buffer_binding_source_t* binding_source =
+            &binding_sources[i];
+        binding_dst[binding_source->target_binding_ordinal] =
+            binding_ptrs[binding_source->slot] +
+            binding_source->offset_or_pointer;
+      }
+      return iree_ok_status();
+    }
+    case IREE_HAL_AMDGPU_COMMAND_BUFFER_KERNARG_STRATEGY_PREPUBLISHED:
+      return iree_make_status(
+          IREE_STATUS_INVALID_ARGUMENT,
+          "prepublished command-buffer dispatch should not rewrite kernargs");
+    default:
+      return iree_make_status(IREE_STATUS_INVALID_ARGUMENT,
+                              "malformed AQL command-buffer kernarg strategy "
+                              "%u",
+                              dispatch_command->kernarg_strategy);
+  }
+}
+
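+// Resolves the workgroup-count source of an indirect dispatch to a device
+// pointer the on-device patch kernel can read.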
+static iree_status_t
+iree_hal_amdgpu_aql_block_processor_replay_dispatch_indirect_params_ptr(
+    const iree_hal_amdgpu_aql_block_processor_t* processor,
+    const iree_hal_amdgpu_command_buffer_binding_source_t* binding_source,
+    const uint32_t** out_workgroup_count_ptr) {
+  *out_workgroup_count_ptr = NULL;
+  switch (binding_source->flags) {
+    case IREE_HAL_AMDGPU_COMMAND_BUFFER_BINDING_SOURCE_FLAG_INDIRECT_PARAMETERS:
+      *out_workgroup_count_ptr =
+          (const uint32_t*)(uintptr_t)binding_source->offset_or_pointer;
+      return iree_ok_status();
+    case IREE_HAL_AMDGPU_COMMAND_BUFFER_BINDING_SOURCE_FLAG_DYNAMIC |
+        IREE_HAL_AMDGPU_COMMAND_BUFFER_BINDING_SOURCE_FLAG_INDIRECT_PARAMETERS: {
+      iree_hal_buffer_ref_t resolved_ref = {0};
+      IREE_RETURN_IF_ERROR(iree_hal_buffer_binding_table_resolve_ref(
+          processor->bindings.table,
+          iree_hal_make_indirect_buffer_ref(binding_source->slot,
+                                            binding_source->offset_or_pointer,
+                                            sizeof(uint32_t[3])),
+          &resolved_ref));
+      uint8_t* device_ptr = NULL;
+      IREE_RETURN_IF_ERROR(
+          iree_hal_amdgpu_aql_block_processor_resolve_buffer_ref_ptr(
+              resolved_ref, IREE_HAL_BUFFER_USAGE_DISPATCH_INDIRECT_PARAMETERS,
+              IREE_HAL_MEMORY_ACCESS_READ, &device_ptr));
+      *out_workgroup_count_ptr = (const uint32_t*)device_ptr;
+      return iree_ok_status();
+    }
+    case IREE_HAL_AMDGPU_COMMAND_BUFFER_BINDING_SOURCE_FLAG_STATIC_BUFFER |
+        IREE_HAL_AMDGPU_COMMAND_BUFFER_BINDING_SOURCE_FLAG_INDIRECT_PARAMETERS: {
+      uint64_t workgroup_count_ptr = 0;
+      IREE_RETURN_IF_ERROR(
+          iree_hal_amdgpu_aql_block_processor_resolve_static_binding_source_ptr(
+              processor->command_buffer, binding_source, &workgroup_count_ptr));
+      *out_workgroup_count_ptr =
+          (const uint32_t*)(uintptr_t)workgroup_count_ptr;
+      return iree_ok_status();
+    }
+    default:
+      return iree_make_status(
+          IREE_STATUS_INVALID_ARGUMENT,
+          "malformed AQL command-buffer indirect parameter source flags %u",
+          binding_source->flags);
+  }
+}
+
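+// Returns the implicit-args pointer within |kernarg_data|, or NULL when the
+// dispatch reserved no implicit args (offset sentinel UINT16_MAX).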
+static iree_amdgpu_kernel_implicit_args_t*
+iree_hal_amdgpu_aql_block_processor_dispatch_implicit_args_ptr(
+    const iree_hal_amdgpu_command_buffer_dispatch_command_t* dispatch_command,
+    uint8_t* kernarg_data) {
+  if (dispatch_command->implicit_args_offset_qwords == UINT16_MAX) {
+    return NULL;
+  }
+  return (iree_amdgpu_kernel_implicit_args_t*)  //
+      (kernarg_data +
+       (iree_host_size_t)dispatch_command->implicit_args_offset_qwords * 8u);
+}
+
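+// Populates one dispatch packet body, either referencing prepublished
+// kernargs or rewriting queue-owned kernarg storage first.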
+static iree_status_t
+iree_hal_amdgpu_aql_block_processor_replay_dispatch_packet_body(
+    const iree_hal_amdgpu_aql_block_processor_t* processor,
+    const iree_hal_amdgpu_command_buffer_block_header_t* block,
+    const iree_hal_amdgpu_command_buffer_dispatch_command_t* dispatch_command,
+    iree_hal_amdgpu_aql_packet_t* packet, uint8_t* kernarg_data,
+    iree_hsa_signal_t completion_signal, uint16_t* out_setup) {
+  if (iree_hal_amdgpu_aql_block_processor_dispatch_uses_prepublished(
+          dispatch_command)) {
+    const uint32_t kernarg_length =
+        (uint32_t)dispatch_command->kernarg_length_qwords * 8u;
+    kernarg_data = iree_hal_amdgpu_aql_command_buffer_prepublished_kernarg(
+        processor->command_buffer, dispatch_command->payload_reference,
+        kernarg_length);
+    if (IREE_UNLIKELY(!kernarg_data)) {
+      return iree_make_status(
+          IREE_STATUS_INVALID_ARGUMENT,
+          "AQL command-buffer prepublished kernarg range is invalid");
+    }
+  } else {
+    IREE_RETURN_IF_ERROR(
+        iree_hal_amdgpu_aql_block_processor_replay_dispatch_kernargs(
+            block, processor->command_buffer, processor->bindings.ptrs,
+            dispatch_command, kernarg_data));
+  }
+  iree_hal_amdgpu_aql_block_processor_write_dispatch_packet_body(
+      dispatch_command, packet, kernarg_data, completion_signal, out_setup);
+  return iree_ok_status();
+}
+
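+// Populates the patch+dispatch packet pair for an indirect dispatch: the
+// dispatch body is written first and a patch dispatch is then emplaced that
+// reads the workgroup counts on device and publishes the dispatch header.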
+static iree_status_t
+iree_hal_amdgpu_aql_block_processor_replay_indirect_dispatch_packet_bodies(
+    const iree_hal_amdgpu_aql_block_processor_t* processor,
+    const iree_hal_amdgpu_command_buffer_block_header_t* block,
+    const iree_hal_amdgpu_command_buffer_dispatch_command_t* dispatch_command,
+    iree_hal_amdgpu_aql_packet_t* patch_packet,
+    iree_hal_amdgpu_aql_packet_t* dispatch_packet, uint8_t* patch_kernarg_data,
+    uint8_t* dispatch_kernarg_data, iree_hsa_signal_t completion_signal,
+    uint16_t dispatch_header, uint16_t* out_patch_setup,
+    uint16_t* out_dispatch_setup) {
+  IREE_RETURN_IF_ERROR(
+      iree_hal_amdgpu_aql_block_processor_replay_dispatch_packet_body(
+          processor, block, dispatch_command, dispatch_packet,
+          dispatch_kernarg_data, completion_signal, out_dispatch_setup));
+
+  const iree_hal_amdgpu_command_buffer_binding_source_t* binding_sources =
+      (const iree_hal_amdgpu_command_buffer_binding_source_t*)  //
+      ((const uint8_t*)block + dispatch_command->binding_source_offset);
+  const iree_hal_amdgpu_command_buffer_binding_source_t*
+      indirect_params_source =
+          &binding_sources[dispatch_command->binding_count];
+  const uint32_t* workgroup_count_ptr = NULL;
+  IREE_RETURN_IF_ERROR(
+      iree_hal_amdgpu_aql_block_processor_replay_dispatch_indirect_params_ptr(
+          processor, indirect_params_source, &workgroup_count_ptr));
+
+  iree_amdgpu_kernel_implicit_args_t* implicit_args =
+      iree_hal_amdgpu_aql_block_processor_dispatch_implicit_args_ptr(
+          dispatch_command, dispatch_kernarg_data);
+  iree_hal_amdgpu_device_dispatch_emplace_indirect_params_patch(
+      &processor->transfer_context->kernels
+           ->iree_hal_amdgpu_device_dispatch_patch_indirect_params,
+      workgroup_count_ptr, &dispatch_packet->dispatch, dispatch_header,
+      *out_dispatch_setup, implicit_args, &patch_packet->dispatch,
+      patch_kernarg_data);
+  *out_patch_setup = patch_packet->dispatch.setup;
+  return iree_ok_status();
+}
+
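+// Emplaces a fill blit dispatch targeting the resolved buffer range.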
+static iree_status_t
+iree_hal_amdgpu_aql_block_processor_replay_fill_packet_body(
+    const iree_hal_amdgpu_aql_block_processor_t* processor,
+    const iree_hal_amdgpu_command_buffer_fill_command_t* fill_command,
+    iree_hal_amdgpu_aql_packet_t* packet,
+    iree_hal_amdgpu_kernarg_block_t* kernarg_block,
+    iree_hsa_signal_t completion_signal, uint16_t* out_setup) {
+  iree_hal_buffer_ref_t target_ref = {0};
+  uint8_t* target_ptr = NULL;
+  IREE_RETURN_IF_ERROR(
+      iree_hal_amdgpu_aql_block_processor_resolve_command_buffer_ref(
+          processor->command_buffer, processor->bindings.table,
+          fill_command->target_kind, fill_command->target_ordinal,
+          fill_command->target_offset, fill_command->length,
+          IREE_HAL_BUFFER_USAGE_TRANSFER_TARGET, IREE_HAL_MEMORY_ACCESS_WRITE,
+          &target_ref, &target_ptr));
+  (void)target_ref;
+  if (IREE_UNLIKELY(!iree_hal_amdgpu_device_buffer_fill_emplace(
+          processor->transfer_context, &packet->dispatch, target_ptr,
+          fill_command->length, fill_command->pattern,
+          fill_command->pattern_length, kernarg_block->data))) {
+    return iree_make_status(IREE_STATUS_INVALID_ARGUMENT,
+                            "unsupported command-buffer fill dispatch shape");
+  }
+  packet->dispatch.completion_signal = completion_signal;
+  *out_setup = packet->dispatch.setup;
+  return iree_ok_status();
+}
+
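+// Emplaces a copy blit dispatch after resolving both ranges and rejecting
+// overlapping spans.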
+static iree_status_t
+iree_hal_amdgpu_aql_block_processor_replay_copy_packet_body(
+    const iree_hal_amdgpu_aql_block_processor_t* processor,
+    const iree_hal_amdgpu_command_buffer_copy_command_t* copy_command,
+    iree_hal_amdgpu_aql_packet_t* packet,
+    iree_hal_amdgpu_kernarg_block_t* kernarg_block,
+    iree_hsa_signal_t completion_signal, uint16_t* out_setup) {
+  iree_hal_buffer_ref_t source_ref = {0};
+  uint8_t* source_ptr = NULL;
+  IREE_RETURN_IF_ERROR(
+      iree_hal_amdgpu_aql_block_processor_resolve_command_buffer_ref(
+          processor->command_buffer, processor->bindings.table,
+          copy_command->source_kind, copy_command->source_ordinal,
+          copy_command->source_offset, copy_command->length,
+          IREE_HAL_BUFFER_USAGE_TRANSFER_SOURCE, IREE_HAL_MEMORY_ACCESS_READ,
+          &source_ref, &source_ptr));
+  iree_hal_buffer_ref_t target_ref = {0};
+  uint8_t* target_ptr = NULL;
+  IREE_RETURN_IF_ERROR(
+      iree_hal_amdgpu_aql_block_processor_resolve_command_buffer_ref(
+          processor->command_buffer, processor->bindings.table,
+          copy_command->target_kind, copy_command->target_ordinal,
+          copy_command->target_offset, copy_command->length,
+          IREE_HAL_BUFFER_USAGE_TRANSFER_TARGET, IREE_HAL_MEMORY_ACCESS_WRITE,
+          &target_ref, &target_ptr));
+
+  if (IREE_UNLIKELY(
+          iree_hal_buffer_test_overlap(source_ref.buffer, source_ref.offset,
+                                       source_ref.length, target_ref.buffer,
+                                       target_ref.offset, target_ref.length) !=
+          IREE_HAL_BUFFER_OVERLAP_DISJOINT)) {
+    return iree_make_status(
+        IREE_STATUS_INVALID_ARGUMENT,
+        "source and target ranges must not overlap within the same buffer");
+  }
+  if (IREE_UNLIKELY(!iree_hal_amdgpu_device_buffer_copy_emplace(
+          processor->transfer_context, &packet->dispatch, source_ptr,
+          target_ptr, copy_command->length, kernarg_block->data))) {
+    return iree_make_status(IREE_STATUS_INVALID_ARGUMENT,
+                            "unsupported command-buffer copy dispatch shape");
+  }
+  packet->dispatch.completion_signal = completion_signal;
+  *out_setup = packet->dispatch.setup;
+  return iree_ok_status();
+}
+
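+// Buffer updates stage their source bytes inline in queue-owned kernarg
+// storage directly after the copy kernarg struct; the required length is
+// the staged-source offset plus the update payload length.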
+static iree_host_size_t
+iree_hal_amdgpu_aql_block_processor_update_kernarg_length(
+    uint32_t source_length) {
+  const iree_host_size_t source_payload_offset =
+      IREE_HAL_AMDGPU_DEVICE_BUFFER_COPY_STAGED_SOURCE_OFFSET;
+  return source_payload_offset + (iree_host_size_t)source_length;
+}
+
+static uint32_t iree_hal_amdgpu_aql_block_processor_update_kernarg_block_count(
+    uint32_t source_length) {
+  return (uint32_t)iree_host_size_ceil_div(
+      iree_hal_amdgpu_aql_block_processor_update_kernarg_length(source_length),
+      sizeof(iree_hal_amdgpu_kernarg_block_t));
+}
+
+static iree_status_t
+iree_hal_amdgpu_aql_block_processor_resolve_update_packet_operands(
+    const iree_hal_amdgpu_aql_block_processor_t* processor,
+    const iree_hal_amdgpu_command_buffer_update_command_t* update_command,
+    const uint8_t** out_source_bytes, uint8_t** out_target_ptr) {
+  *out_source_bytes = NULL;
+  *out_target_ptr = NULL;
+  iree_hal_buffer_ref_t target_ref = {0};
+  IREE_RETURN_IF_ERROR(
+      iree_hal_amdgpu_aql_block_processor_resolve_command_buffer_ref(
+          processor->command_buffer, processor->bindings.table,
+          update_command->target_kind, update_command->target_ordinal,
+          update_command->target_offset, update_command->length,
+          IREE_HAL_BUFFER_USAGE_TRANSFER_TARGET, IREE_HAL_MEMORY_ACCESS_WRITE,
+          &target_ref, out_target_ptr));
+  (void)target_ref;
+  *out_source_bytes = iree_hal_amdgpu_aql_command_buffer_rodata(
+      processor->command_buffer, update_command->rodata_ordinal,
+      update_command->length);
+  if (IREE_UNLIKELY(!*out_source_bytes)) {
+    return iree_make_status(
+        IREE_STATUS_INVALID_ARGUMENT,
+        "AQL command-buffer update rodata range is invalid");
+  }
+  return iree_ok_status();
+}
+
+static iree_status_t
+iree_hal_amdgpu_aql_block_processor_replay_update_packet_body(
+    const iree_hal_amdgpu_aql_block_processor_t* processor,
+    const iree_hal_amdgpu_command_buffer_update_command_t* update_command,
+    iree_hal_amdgpu_aql_packet_t* packet, uint8_t* kernarg_data,
+    iree_host_size_t kernarg_length, iree_hsa_signal_t completion_signal,
+    uint16_t* out_setup) {
+  const uint8_t* source_bytes = NULL;
+  uint8_t* target_ptr = NULL;
+  IREE_RETURN_IF_ERROR(
+      iree_hal_amdgpu_aql_block_processor_resolve_update_packet_operands(
+          processor, update_command, &source_bytes, &target_ptr));
+
+  const iree_host_size_t source_payload_offset =
+      IREE_HAL_AMDGPU_DEVICE_BUFFER_COPY_STAGED_SOURCE_OFFSET;
+  const iree_host_size_t required_kernarg_length =
+      source_payload_offset + (iree_host_size_t)update_command->length;
+  if (IREE_UNLIKELY(required_kernarg_length > kernarg_length)) {
+    return iree_make_status(
+        IREE_STATUS_INVALID_ARGUMENT,
+        "AQL command-buffer update kernarg range is too small");
+  }
+
+  iree_hal_amdgpu_device_buffer_copy_kernargs_t kernargs;
+  memset(&kernargs, 0, sizeof(kernargs));
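+  // Emplace into a local kernarg staging struct using a placeholder source
+  // pointer (the staged-source alignment constant); the staged copy below
+  // patches in the real queue-owned source pointer and kernarg address.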
+  if (IREE_UNLIKELY(!iree_hal_amdgpu_device_buffer_copy_emplace(
+          processor->transfer_context, &packet->dispatch,
+          (const void*)(uintptr_t)
+              IREE_HAL_AMDGPU_DEVICE_BUFFER_COPY_STAGED_SOURCE_ALIGNMENT,
+          target_ptr, update_command->length, &kernargs))) {
+    return iree_make_status(IREE_STATUS_INVALID_ARGUMENT,
+                            "unsupported command-buffer update dispatch shape");
+  }
+
+  uint8_t* staged_source_bytes = kernarg_data + source_payload_offset;
+  memcpy(kernarg_data, &kernargs, sizeof(kernargs));
+  ((iree_hal_amdgpu_device_buffer_copy_kernargs_t*)kernarg_data)->source_ptr =
+      staged_source_bytes;
+  memcpy(staged_source_bytes, source_bytes, update_command->length);
+  packet->dispatch.kernarg_address = kernarg_data;
+  packet->dispatch.completion_signal = completion_signal;
+  *out_setup = packet->dispatch.setup;
+  return iree_ok_status();
+}
+
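+// Returns how many leading payload packets may need to carry the queue
+// kernarg acquire scope: a single packet suffices when the first packet
+// already carries a barrier, otherwise the block's recorded initial barrier
+// packet count bounds the unordered prefix.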
+static uint32_t
+iree_hal_amdgpu_aql_block_processor_payload_acquire_packet_count(
+    const iree_hal_amdgpu_aql_block_processor_t* processor,
+    const iree_hal_amdgpu_command_buffer_block_header_t* block) {
+  if (processor->payload.acquire_scope == IREE_HSA_FENCE_SCOPE_NONE) return 0;
+  if (block->aql_packet_count == 0) return 0;
+  iree_hal_amdgpu_aql_block_processor_packet_flags_t packet_flags =
+      IREE_HAL_AMDGPU_AQL_BLOCK_PROCESSOR_PACKET_FLAG_NONE;
+  if (block->aql_packet_count == 1 &&
+      iree_all_bits_set(
+          processor->flags,
+          IREE_HAL_AMDGPU_AQL_BLOCK_PROCESSOR_FLAG_FINAL_PAYLOAD_PACKET)) {
+    packet_flags |= IREE_HAL_AMDGPU_AQL_BLOCK_PROCESSOR_PACKET_FLAG_FINAL;
+  }
+  if (iree_hal_amdgpu_aql_block_processor_packet_has_barrier(
+          processor, processor->packets.index_base, packet_flags)) {
+    return 1;
+  }
+  return block->initial_barrier_packet_count;
+}
+
+static iree_hsa_fence_scope_t
+iree_hal_amdgpu_aql_block_processor_payload_acquire_scope(
+    const iree_hal_amdgpu_aql_block_processor_t* processor,
+    const iree_hal_amdgpu_aql_block_processor_state_t* state,
+    uint32_t payload_acquire_packet_count, uint32_t packet_index,
+    const iree_hal_amdgpu_command_buffer_command_header_t* command,
+    iree_hal_amdgpu_aql_block_processor_packet_flags_t packet_flags) {
+  if (processor->payload.acquire_scope == IREE_HSA_FENCE_SCOPE_NONE) {
+    return IREE_HSA_FENCE_SCOPE_NONE;
+  }
+  if (state->packets.recorded >= payload_acquire_packet_count) {
+    return IREE_HSA_FENCE_SCOPE_NONE;
+  }
+  if (iree_hal_amdgpu_aql_block_processor_command_uses_queue_kernargs(
+          command)) {
+    return processor->payload.acquire_scope;
+  }
+  if (iree_hal_amdgpu_aql_block_processor_packet_has_barrier(
+          processor, processor->packets.index_base + packet_index,
+          packet_flags)) {
+    return processor->payload.acquire_scope;
+  }
+  return IREE_HSA_FENCE_SCOPE_NONE;
+}
+
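+// Returns the reserved ring packet backing logical payload packet
+// |packet_index|.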
+static iree_hal_amdgpu_aql_packet_t* iree_hal_amdgpu_aql_block_processor_packet(
+    const iree_hal_amdgpu_aql_block_processor_t* processor,
+    uint32_t packet_index) {
+  return iree_hal_amdgpu_aql_ring_packet(
+      processor->packets.ring, processor->packets.first_id + packet_index);
+}
+
+static bool iree_hal_amdgpu_aql_block_processor_is_block_final(
+    const iree_hal_amdgpu_command_buffer_block_header_t* block,
+    const iree_hal_amdgpu_aql_block_processor_state_t* state,
+    uint32_t recorded_packet_count) {
+  return state->packets.recorded + recorded_packet_count ==
+         block->aql_packet_count;
+}
+
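+// Emits one direct dispatch packet: rewrites kernargs unless prepublished,
+// fills the packet body, and records the header for deferred publication.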
+static iree_status_t iree_hal_amdgpu_aql_block_processor_emit_direct_dispatch(
+    const iree_hal_amdgpu_aql_block_processor_t* processor,
+    const iree_hal_amdgpu_command_buffer_block_header_t* block,
+    uint32_t payload_acquire_packet_count,
+    const iree_hal_amdgpu_command_buffer_dispatch_command_t* dispatch_command,
+    iree_hal_amdgpu_aql_block_processor_state_t* state) {
+  const uint32_t dispatch_packet_index = state->packets.emitted;
+  iree_hal_amdgpu_aql_packet_t* packet =
+      iree_hal_amdgpu_aql_block_processor_packet(processor,
+                                                 dispatch_packet_index);
+  uint8_t* kernarg_data = NULL;
+  if (iree_hal_amdgpu_aql_block_processor_dispatch_uses_queue_kernargs(
+          dispatch_command)) {
+    kernarg_data = processor->kernargs.blocks[state->kernargs.block].data;
+  }
+  iree_status_t status =
+      iree_hal_amdgpu_aql_block_processor_replay_dispatch_packet_body(
+          processor, block, dispatch_command, packet, kernarg_data,
+          iree_hsa_signal_null(),
+          &processor->packets.setups[dispatch_packet_index]);
+  if (iree_status_is_ok(status)) {
+    iree_hal_amdgpu_aql_block_processor_packet_flags_t packet_flags =
+        iree_hal_amdgpu_aql_block_processor_command_packet_flags(
+            &dispatch_command->header);
+    if (iree_hal_amdgpu_aql_block_processor_is_block_final(
+            block, state, /*recorded_packet_count=*/1) &&
+        iree_all_bits_set(
+            processor->flags,
+            IREE_HAL_AMDGPU_AQL_BLOCK_PROCESSOR_FLAG_FINAL_PAYLOAD_PACKET)) {
+      packet_flags |= IREE_HAL_AMDGPU_AQL_BLOCK_PROCESSOR_PACKET_FLAG_FINAL;
+    }
+    const iree_hsa_fence_scope_t payload_acquire_scope =
+        iree_hal_amdgpu_aql_block_processor_payload_acquire_scope(
+            processor, state, payload_acquire_packet_count,
+            dispatch_packet_index, &dispatch_command->header, packet_flags);
+    processor->packets.headers[dispatch_packet_index] =
+        iree_hal_amdgpu_aql_make_header(
+            IREE_HSA_PACKET_TYPE_KERNEL_DISPATCH,
+            iree_hal_amdgpu_aql_block_processor_packet_control(
+                processor, dispatch_packet_index, payload_acquire_scope,
+                packet_flags));
+    ++state->packets.emitted;
+    ++state->packets.recorded;
+    state->kernargs.block +=
+        iree_hal_amdgpu_aql_block_processor_dispatch_kernarg_block_count(
+            dispatch_command);
+  }
+  return status;
+}
+
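+// Emits the patch+dispatch packet pair for an indirect dispatch. Only the
+// patch packet header is recorded for publication; the dispatch slot stays
+// INVALID until the patch kernel rewrites it on device.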
+static iree_status_t iree_hal_amdgpu_aql_block_processor_emit_indirect_dispatch(
+    const iree_hal_amdgpu_aql_block_processor_t* processor,
+    const iree_hal_amdgpu_command_buffer_block_header_t* block,
+    uint32_t payload_acquire_packet_count,
+    const iree_hal_amdgpu_command_buffer_dispatch_command_t* dispatch_command,
+    iree_hal_amdgpu_aql_block_processor_state_t* state) {
+  const uint32_t patch_packet_index = state->packets.emitted++;
+  const uint32_t dispatch_packet_index = state->packets.emitted;
+  iree_hal_amdgpu_aql_packet_t* patch_packet =
+      iree_hal_amdgpu_aql_block_processor_packet(processor, patch_packet_index);
+  iree_hal_amdgpu_aql_packet_t* dispatch_packet =
+      iree_hal_amdgpu_aql_block_processor_packet(processor,
+                                                 dispatch_packet_index);
+  iree_hal_amdgpu_aql_block_processor_packet_flags_t dispatch_packet_flags =
+      IREE_HAL_AMDGPU_AQL_BLOCK_PROCESSOR_PACKET_FLAG_NONE;
+  if (iree_hal_amdgpu_aql_block_processor_is_block_final(
+          block, state, /*recorded_packet_count=*/2) &&
+      iree_all_bits_set(
+          processor->flags,
+          IREE_HAL_AMDGPU_AQL_BLOCK_PROCESSOR_FLAG_FINAL_PAYLOAD_PACKET)) {
+    dispatch_packet_flags |=
+        IREE_HAL_AMDGPU_AQL_BLOCK_PROCESSOR_PACKET_FLAG_FINAL;
+  }
+  const uint16_t dispatch_header = iree_hal_amdgpu_aql_make_header(
+      IREE_HSA_PACKET_TYPE_KERNEL_DISPATCH,
+      iree_hal_amdgpu_aql_block_processor_packet_control(
+          processor, dispatch_packet_index, IREE_HSA_FENCE_SCOPE_NONE,
+          dispatch_packet_flags));
+  iree_status_t status =
+      iree_hal_amdgpu_aql_block_processor_replay_indirect_dispatch_packet_bodies(
+          processor, block, dispatch_command, patch_packet, dispatch_packet,
+          processor->kernargs.blocks[state->kernargs.block].data,
+          processor->kernargs.blocks[state->kernargs.block + 1].data,
+          iree_hsa_signal_null(), dispatch_header,
+          &processor->packets.setups[patch_packet_index],
+          &processor->packets.setups[dispatch_packet_index]);
+  if (iree_status_is_ok(status)) {
+    const iree_hal_amdgpu_aql_block_processor_packet_flags_t patch_flags =
+        iree_hal_amdgpu_aql_block_processor_agent_barrier_packet_flags();
+    const iree_hsa_fence_scope_t patch_acquire_scope =
+        iree_hal_amdgpu_aql_block_processor_payload_acquire_scope(
+            processor, state, payload_acquire_packet_count, patch_packet_index,
+            &dispatch_command->header, patch_flags);
+    // The patch dispatch publishes the following dispatch packet header, so it
+    // must retire before the CP observes that slot.
+    processor->packets.headers[patch_packet_index] =
+        iree_hal_amdgpu_aql_make_header(
+            IREE_HSA_PACKET_TYPE_KERNEL_DISPATCH,
+            iree_hal_amdgpu_aql_block_processor_packet_control(
+                processor, patch_packet_index, patch_acquire_scope,
+                patch_flags));
+    // The patch dispatch publishes the target dispatch header after it has
+    // updated dynamic workgroup counts on device.
+    processor->packets.headers[dispatch_packet_index] =
+        IREE_HSA_PACKET_TYPE_INVALID;
+    state->packets.emitted = dispatch_packet_index + /*dispatch packet=*/1;
+    state->packets.recorded += 2;
+    state->kernargs.block +=
+        iree_hal_amdgpu_aql_block_processor_dispatch_kernarg_block_count(
+            dispatch_command);
+  }
+  return status;
+}
+
+static iree_status_t iree_hal_amdgpu_aql_block_processor_emit_dispatch(
+    const iree_hal_amdgpu_aql_block_processor_t* processor,
+    const iree_hal_amdgpu_command_buffer_block_header_t* block,
+    uint32_t payload_acquire_packet_count,
+    const iree_hal_amdgpu_command_buffer_dispatch_command_t* dispatch_command,
+    iree_hal_amdgpu_aql_block_processor_state_t* state) {
+  if (iree_hal_amdgpu_aql_block_processor_dispatch_uses_indirect(
+          dispatch_command)) {
+    return iree_hal_amdgpu_aql_block_processor_emit_indirect_dispatch(
+        processor, block, payload_acquire_packet_count, dispatch_command,
+        state);
+  }
+  return iree_hal_amdgpu_aql_block_processor_emit_direct_dispatch(
+      processor, block, payload_acquire_packet_count, dispatch_command, state);
+}
+
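+// Emits one transfer packet (fill/copy/update) as an emplaced blit dispatch
+// and advances kernarg accounting by the blocks consumed.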
+static iree_status_t iree_hal_amdgpu_aql_block_processor_emit_transfer(
+    const iree_hal_amdgpu_aql_block_processor_t* processor,
+    const iree_hal_amdgpu_command_buffer_block_header_t* block,
+    uint32_t payload_acquire_packet_count,
+    const iree_hal_amdgpu_command_buffer_command_header_t* command,
+    iree_hal_amdgpu_aql_block_processor_state_t* state) {
+  const uint32_t packet_index = state->packets.emitted;
+  iree_hal_amdgpu_aql_block_processor_packet_flags_t packet_flags =
+      iree_hal_amdgpu_aql_block_processor_command_packet_flags(command);
+  if (iree_hal_amdgpu_aql_block_processor_is_block_final(
+          block, state, /*recorded_packet_count=*/1) &&
+      iree_all_bits_set(
+          processor->flags,
+          IREE_HAL_AMDGPU_AQL_BLOCK_PROCESSOR_FLAG_FINAL_PAYLOAD_PACKET)) {
+    packet_flags |= IREE_HAL_AMDGPU_AQL_BLOCK_PROCESSOR_PACKET_FLAG_FINAL;
+  }
+  iree_hal_amdgpu_aql_packet_t* packet =
+      iree_hal_amdgpu_aql_block_processor_packet(processor, packet_index);
+
+  iree_status_t status = iree_ok_status();
+  if (command->opcode == IREE_HAL_AMDGPU_COMMAND_BUFFER_OPCODE_FILL) {
+    status = iree_hal_amdgpu_aql_block_processor_replay_fill_packet_body(
+        processor,
+        (const iree_hal_amdgpu_command_buffer_fill_command_t*)command, packet,
+        &processor->kernargs.blocks[state->kernargs.block],
+        iree_hsa_signal_null(), &processor->packets.setups[packet_index]);
+  } else if (command->opcode == IREE_HAL_AMDGPU_COMMAND_BUFFER_OPCODE_COPY) {
+    status = iree_hal_amdgpu_aql_block_processor_replay_copy_packet_body(
+        processor,
+        (const iree_hal_amdgpu_command_buffer_copy_command_t*)command, packet,
+        &processor->kernargs.blocks[state->kernargs.block],
+        iree_hsa_signal_null(), &processor->packets.setups[packet_index]);
+  } else {
+    const iree_hal_amdgpu_command_buffer_update_command_t* update_command =
+        (const iree_hal_amdgpu_command_buffer_update_command_t*)command;
+    const uint32_t kernarg_block_count =
+        iree_hal_amdgpu_aql_block_processor_update_kernarg_block_count(
+            update_command->length);
+    status = iree_hal_amdgpu_aql_block_processor_replay_update_packet_body(
+        processor, update_command, packet,
+        processor->kernargs.blocks[state->kernargs.block].data,
+        (iree_host_size_t)kernarg_block_count *
+            sizeof(iree_hal_amdgpu_kernarg_block_t),
+        iree_hsa_signal_null(), &processor->packets.setups[packet_index]);
+    if (iree_status_is_ok(status)) {
+      state->kernargs.block += kernarg_block_count - 1u;
+    }
+  }
+  if (iree_status_is_ok(status)) {
+    processor->packets.headers[packet_index] = iree_hal_amdgpu_aql_make_header(
+        IREE_HSA_PACKET_TYPE_KERNEL_DISPATCH,
+        iree_hal_amdgpu_aql_block_processor_packet_control(
+            processor, packet_index,
+            iree_hal_amdgpu_aql_block_processor_payload_acquire_scope(
+                processor, state, payload_acquire_packet_count, packet_index,
+                command, packet_flags),
+            packet_flags));
+    ++state->packets.emitted;
+    ++state->packets.recorded;
+    ++state->kernargs.block;
+  }
+  return status;
+}
+
+void iree_hal_amdgpu_aql_block_processor_initialize(
+    const iree_hal_amdgpu_aql_block_processor_t* params,
+    iree_hal_amdgpu_aql_block_processor_t* out_processor) {
+  *out_processor = *params;
+}
+
+void iree_hal_amdgpu_aql_block_processor_deinitialize(
+    iree_hal_amdgpu_aql_block_processor_t* processor) {
+  memset(processor, 0, sizeof(*processor));
+}
+
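+// Walks the recorded commands in |block|, emitting payload packets until a
+// RETURN/BRANCH terminator is reached, then validates packet and kernarg
+// accounting against the block header and the reservation.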
+iree_status_t iree_hal_amdgpu_aql_block_processor_invoke(
+    const iree_hal_amdgpu_aql_block_processor_t* processor,
+    const iree_hal_amdgpu_command_buffer_block_header_t* block,
+    iree_hal_amdgpu_aql_block_processor_result_t* out_result) {
+  memset(out_result, 0, sizeof(*out_result));
+  const uint32_t payload_acquire_packet_count =
+      iree_hal_amdgpu_aql_block_processor_payload_acquire_packet_count(
+          processor, block);
+  const iree_hal_amdgpu_command_buffer_command_header_t* command =
+      iree_hal_amdgpu_command_buffer_block_commands_const(block);
+  iree_hal_amdgpu_aql_block_processor_state_t state = {0};
+  bool reached_terminator = false;
+  iree_status_t status = iree_ok_status();
+  for (uint16_t i = 0; i < block->command_count && iree_status_is_ok(status) &&
+                       !reached_terminator;
+       ++i) {
+    switch (command->opcode) {
+      case IREE_HAL_AMDGPU_COMMAND_BUFFER_OPCODE_BARRIER:
+        break;
+      case IREE_HAL_AMDGPU_COMMAND_BUFFER_OPCODE_DISPATCH:
+        status = iree_hal_amdgpu_aql_block_processor_emit_dispatch(
+            processor, block, payload_acquire_packet_count,
+            (const iree_hal_amdgpu_command_buffer_dispatch_command_t*)command,
+            &state);
+        break;
+      case IREE_HAL_AMDGPU_COMMAND_BUFFER_OPCODE_FILL:
+      case IREE_HAL_AMDGPU_COMMAND_BUFFER_OPCODE_COPY:
+      case IREE_HAL_AMDGPU_COMMAND_BUFFER_OPCODE_UPDATE:
+        status = iree_hal_amdgpu_aql_block_processor_emit_transfer(
+            processor, block, payload_acquire_packet_count, command, &state);
+        break;
+      case IREE_HAL_AMDGPU_COMMAND_BUFFER_OPCODE_RETURN:
+        out_result->terminator =
+            IREE_HAL_AMDGPU_AQL_BLOCK_PROCESSOR_TERMINATOR_RETURN;
+        reached_terminator = true;
+        break;
+      case IREE_HAL_AMDGPU_COMMAND_BUFFER_OPCODE_BRANCH: {
+        const iree_hal_amdgpu_command_buffer_branch_command_t* branch_command =
+            (const iree_hal_amdgpu_command_buffer_branch_command_t*)command;
+        out_result->terminator =
+            IREE_HAL_AMDGPU_AQL_BLOCK_PROCESSOR_TERMINATOR_BRANCH;
+        out_result->target_block_ordinal = branch_command->target_block_ordinal;
+        reached_terminator = true;
+        break;
+      }
+      case IREE_HAL_AMDGPU_COMMAND_BUFFER_OPCODE_PROFILE_MARKER:
+      case IREE_HAL_AMDGPU_COMMAND_BUFFER_OPCODE_COND_BRANCH:
+        status = iree_make_status(
+            IREE_STATUS_UNIMPLEMENTED,
+            "AQL command-buffer opcode %u replay not yet wired",
+            command->opcode);
+        break;
+      default:
+        status = iree_make_status(IREE_STATUS_INVALID_ARGUMENT,
+                                  "malformed AQL command-buffer opcode %u",
+                                  command->opcode);
+        break;
+    }
+    if (iree_status_is_ok(status) && !reached_terminator) {
+      command = iree_hal_amdgpu_command_buffer_command_next_const(command);
+    }
+  }
+  if (iree_status_is_ok(status) && !reached_terminator) {
+    status = iree_make_status(IREE_STATUS_INVALID_ARGUMENT,
+                              "AQL command-buffer block %" PRIu32
+                              " has no terminator",
+                              block->block_ordinal);
+  }
+  if (iree_status_is_ok(status) &&
+      state.packets.recorded != block->aql_packet_count) {
+    status = iree_make_status(
+        IREE_STATUS_INVALID_ARGUMENT,
+        "AQL command-buffer block %" PRIu32 " consumed %" PRIu32
+        " packets but declares %" PRIu32,
+        block->block_ordinal, state.packets.recorded, block->aql_packet_count);
+  }
+  if (iree_status_is_ok(status) &&
+      state.packets.emitted != processor->packets.count) {
+    status = iree_make_status(
+        IREE_STATUS_INVALID_ARGUMENT,
+        "AQL command-buffer block %" PRIu32 " emitted %" PRIu32
+        " payload packets but reserved %u",
+        block->block_ordinal, state.packets.emitted, processor->packets.count);
+  }
+  if (iree_status_is_ok(status) &&
+      state.kernargs.block != processor->kernargs.count) {
+    status = iree_make_status(
+        IREE_STATUS_INVALID_ARGUMENT,
+        "AQL command-buffer block %" PRIu32 " emitted %" PRIu32
+        " kernarg blocks but reserved %" PRIu32,
+        block->block_ordinal, state.kernargs.block, processor->kernargs.count);
+  }
+  out_result->packets.recorded = state.packets.recorded;
+  out_result->packets.emitted = state.packets.emitted;
+  out_result->kernargs.consumed = state.kernargs.block;
+  return status;
+}
diff --git a/runtime/src/iree/hal/drivers/amdgpu/aql_block_processor.h b/runtime/src/iree/hal/drivers/amdgpu/aql_block_processor.h
new file mode 100644
index 0000000..4965e05
--- /dev/null
+++ b/runtime/src/iree/hal/drivers/amdgpu/aql_block_processor.h
@@ -0,0 +1,128 @@
+// Copyright 2026 The IREE Authors
+//
+// Licensed under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+
+#ifndef IREE_HAL_DRIVERS_AMDGPU_AQL_BLOCK_PROCESSOR_H_
+#define IREE_HAL_DRIVERS_AMDGPU_AQL_BLOCK_PROCESSOR_H_
+
+#include "iree/base/api.h"
+#include "iree/hal/api.h"
+#include "iree/hal/drivers/amdgpu/aql_command_buffer.h"
+#include "iree/hal/drivers/amdgpu/device/blit.h"
+#include "iree/hal/drivers/amdgpu/util/aql_ring.h"
+#include "iree/hal/drivers/amdgpu/util/kernarg_ring.h"
+
+#ifdef __cplusplus
+extern "C" {
+#endif  // __cplusplus
+
+typedef uint32_t iree_hal_amdgpu_aql_block_processor_flags_t;
+enum iree_hal_amdgpu_aql_block_processor_flag_bits_t {
+  IREE_HAL_AMDGPU_AQL_BLOCK_PROCESSOR_FLAG_NONE = 0u,
+  // The final recorded payload packet carries terminal signal release scope.
+  IREE_HAL_AMDGPU_AQL_BLOCK_PROCESSOR_FLAG_FINAL_PAYLOAD_PACKET = 1u << 0,
+};
+
+typedef uint8_t iree_hal_amdgpu_aql_block_processor_terminator_t;
+enum iree_hal_amdgpu_aql_block_processor_terminator_e {
+  IREE_HAL_AMDGPU_AQL_BLOCK_PROCESSOR_TERMINATOR_NONE = 0u,
+  IREE_HAL_AMDGPU_AQL_BLOCK_PROCESSOR_TERMINATOR_RETURN = 1u,
+  IREE_HAL_AMDGPU_AQL_BLOCK_PROCESSOR_TERMINATOR_BRANCH = 2u,
+};
+
+// Host-side base processor for one AQL command-buffer block.
+typedef struct iree_hal_amdgpu_aql_block_processor_t {
+  // Device helper kernel context used for transfer and patch dispatches.
+  const iree_hal_amdgpu_device_buffer_transfer_context_t* transfer_context;
+  // Command buffer that owns static buffers and retained rodata.
+  iree_hal_command_buffer_t* command_buffer;
+  // Queue-execute binding state consumed by dynamic block operands.
+  struct {
+    // Binding table supplied to queue_execute.
+    iree_hal_buffer_binding_table_t table;
+    // Pre-resolved binding pointers indexed by queue_execute binding table
+    // slot.
+    const uint64_t* ptrs;
+  } bindings;
+  // Reserved packet span populated by the processor.
+  struct {
+    // AQL ring containing the reserved packet span.
+    iree_hal_amdgpu_aql_ring_t* ring;
+    // First reserved packet id in |ring|.
+    uint64_t first_id;
+    // Logical packet index of |first_id| within the full submission.
+    uint32_t index_base;
+    // Number of reserved packet slots available to the processor.
+    uint32_t count;
+    // Header words produced by the processor and published by the caller.
+    uint16_t* headers;
+    // Setup words produced with |headers|.
+    uint16_t* setups;
+  } packets;
+  // Reserved queue-owned kernarg storage consumed by the processor.
+  struct {
+    // First reserved kernarg block.
+    iree_hal_amdgpu_kernarg_block_t* blocks;
+    // Number of reserved kernarg blocks.
+    uint32_t count;
+  } kernargs;
+  // Submission overlay precomputed by host queue policy.
+  struct {
+    // Number of wait-barrier prefix packets before the payload span.
+    uint32_t wait_barrier_count;
+    // Acquire scope required by inline wait payloads on logical packet zero.
+    iree_hsa_fence_scope_t inline_acquire_scope;
+    // Release scope required when the payload span carries terminal signals.
+    iree_hsa_fence_scope_t signal_release_scope;
+  } submission;
+  // Queue-owned payload visibility requirements.
+  struct {
+    // Minimum acquire scope required for queue-owned kernarg visibility.
+    iree_hsa_fence_scope_t acquire_scope;
+  } payload;
+  // Flags from iree_hal_amdgpu_aql_block_processor_flag_bits_t.
+  iree_hal_amdgpu_aql_block_processor_flags_t flags;
+} iree_hal_amdgpu_aql_block_processor_t;
+
+// Result of invoking the processor on one block.
+typedef struct iree_hal_amdgpu_aql_block_processor_result_t {
+  // Packet accounting reported by the processor.
+  struct {
+    // Number of recorded block AQL packets consumed.
+    uint32_t recorded;
+    // Number of reserved AQL packets populated.
+    uint32_t emitted;
+  } packets;
+  // Kernarg accounting reported by the processor.
+  struct {
+    // Number of reserved kernarg blocks consumed.
+    uint32_t consumed;
+  } kernargs;
+  // Terminator kind reached by this invocation.
+  iree_hal_amdgpu_aql_block_processor_terminator_t terminator;
+  // Branch target block ordinal when |terminator| is BRANCH.
+  uint32_t target_block_ordinal;
+} iree_hal_amdgpu_aql_block_processor_result_t;
+
+// Initializes |out_processor| with borrowed submission storage.
+void iree_hal_amdgpu_aql_block_processor_initialize(
+    const iree_hal_amdgpu_aql_block_processor_t* params,
+    iree_hal_amdgpu_aql_block_processor_t* out_processor);
+
+// Deinitializes |processor|. This currently releases no resources.
+void iree_hal_amdgpu_aql_block_processor_deinitialize(
+    iree_hal_amdgpu_aql_block_processor_t* processor);
+
+// Invokes |processor| on |block| and populates reserved packet/kernarg storage.
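+//
+// Usage sketch (illustrative only; packet reservation, header publication,
+// and doorbell signaling remain the caller's responsibility):
+//   iree_hal_amdgpu_aql_block_processor_result_t result;
+//   IREE_RETURN_IF_ERROR(iree_hal_amdgpu_aql_block_processor_invoke(
+//       &processor, block, &result));
+//   // Publish the produced headers, then follow result.terminator /
+//   // result.target_block_ordinal to the next block (if any).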
+iree_status_t iree_hal_amdgpu_aql_block_processor_invoke(
+    const iree_hal_amdgpu_aql_block_processor_t* processor,
+    const iree_hal_amdgpu_command_buffer_block_header_t* block,
+    iree_hal_amdgpu_aql_block_processor_result_t* out_result);
+
+#ifdef __cplusplus
+}  // extern "C"
+#endif  // __cplusplus
+
+#endif  // IREE_HAL_DRIVERS_AMDGPU_AQL_BLOCK_PROCESSOR_H_
diff --git a/runtime/src/iree/hal/drivers/amdgpu/aql_block_processor_profile.c b/runtime/src/iree/hal/drivers/amdgpu/aql_block_processor_profile.c
new file mode 100644
index 0000000..de29004
--- /dev/null
+++ b/runtime/src/iree/hal/drivers/amdgpu/aql_block_processor_profile.c
@@ -0,0 +1,1443 @@
+// Copyright 2026 The IREE Authors
+//
+// Licensed under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+
+#include "iree/hal/drivers/amdgpu/aql_block_processor_profile.h"
+
+#include <string.h>
+
+#include "iree/hal/drivers/amdgpu/buffer.h"
+#include "iree/hal/drivers/amdgpu/device/blit.h"
+#include "iree/hal/drivers/amdgpu/device/dispatch.h"
+#include "iree/hal/drivers/amdgpu/host_queue_command_buffer_packet.h"
+#include "iree/hal/drivers/amdgpu/host_queue_command_buffer_profile.h"
+#include "iree/hal/drivers/amdgpu/host_queue_policy.h"
+#include "iree/hal/drivers/amdgpu/profile_counters.h"
+#include "iree/hal/drivers/amdgpu/profile_traces.h"
+#include "iree/hal/drivers/amdgpu/util/aql_emitter.h"
+
+typedef struct iree_hal_amdgpu_aql_block_processor_profile_state_t {
+  // Packet cursors advanced while invoking profiled replay.
+  struct {
+    // Number of recorded block AQL packets consumed from the block.
+    uint32_t recorded;
+    // Number of payload AQL packets emitted into the reserved packet span.
+    uint32_t emitted;
+  } packets;
+  // Queue-owned kernarg cursors advanced while invoking profiled replay.
+  struct {
+    // Number of queue-owned kernarg blocks consumed from the reserved span.
+    uint32_t block;
+  } kernargs;
+  // Profile sidecar cursors advanced while invoking profiled replay.
+  struct {
+    // Number of dispatch profile events consumed from the reservation.
+    uint32_t event;
+  } profile;
+} iree_hal_amdgpu_aql_block_processor_profile_state_t;
+
+typedef uint32_t iree_hal_amdgpu_aql_block_processor_dispatch_profile_flags_t;
+enum iree_hal_amdgpu_aql_block_processor_dispatch_profile_flag_bits_t {
+  IREE_HAL_AMDGPU_AQL_BLOCK_PROCESSOR_DISPATCH_PROFILE_FLAG_NONE = 0u,
+  // Dispatch has a reserved profiling event and completion signal.
+  IREE_HAL_AMDGPU_AQL_BLOCK_PROCESSOR_DISPATCH_PROFILE_FLAG_DISPATCH_PACKET =
+      1u << 0,
+  // Counter PM4 packets wrap this dispatch event.
+  IREE_HAL_AMDGPU_AQL_BLOCK_PROCESSOR_DISPATCH_PROFILE_FLAG_COUNTER_PACKETS =
+      1u << 1,
+  // ATT trace PM4 packets wrap this dispatch event.
+  IREE_HAL_AMDGPU_AQL_BLOCK_PROCESSOR_DISPATCH_PROFILE_FLAG_TRACE_PACKETS =
+      1u << 2,
+  // Dispatch consumes the final recorded packet in the block.
+  IREE_HAL_AMDGPU_AQL_BLOCK_PROCESSOR_DISPATCH_PROFILE_FLAG_BLOCK_FINAL =
+      1u << 3,
+  // Dispatch is the final replayed payload packet in the block.
+  IREE_HAL_AMDGPU_AQL_BLOCK_PROCESSOR_DISPATCH_PROFILE_FLAG_FINAL = 1u << 4,
+};
+
+typedef struct iree_hal_amdgpu_aql_block_processor_dispatch_profile_t {
+  // Flags from
+  // iree_hal_amdgpu_aql_block_processor_dispatch_profile_flag_bits_t.
+  iree_hal_amdgpu_aql_block_processor_dispatch_profile_flags_t flags;
+  // Event-ring position for this dispatch when DISPATCH_PACKET is set.
+  uint64_t event_position;
+  // Selected dispatch profile sidecar when DISPATCH_PACKET is set.
+  const iree_hal_amdgpu_aql_block_processor_profile_dispatch_t* dispatch;
+} iree_hal_amdgpu_aql_block_processor_dispatch_profile_t;
+
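+// Resolves |buffer_ref| to its backing AMDGPU device pointer, validating the
+// requested usage, access, and range before computing the absolute offset.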
+static iree_status_t
+iree_hal_amdgpu_aql_block_processor_profile_resolve_buffer_ref_ptr(
+    iree_hal_buffer_ref_t buffer_ref, iree_hal_buffer_usage_t required_usage,
+    iree_hal_memory_access_t required_access, uint8_t** out_device_ptr) {
+  *out_device_ptr = NULL;
+  if (IREE_UNLIKELY(!buffer_ref.buffer)) {
+    return iree_make_status(
+        IREE_STATUS_INVALID_ARGUMENT,
+        "AQL command-buffer dynamic binding resolved to a NULL buffer");
+  }
+  IREE_RETURN_IF_ERROR(iree_hal_buffer_validate_usage(
+      iree_hal_buffer_allowed_usage(buffer_ref.buffer), required_usage));
+  IREE_RETURN_IF_ERROR(iree_hal_buffer_validate_access(
+      iree_hal_buffer_allowed_access(buffer_ref.buffer), required_access));
+  IREE_RETURN_IF_ERROR(iree_hal_buffer_validate_range(
+      buffer_ref.buffer, buffer_ref.offset, buffer_ref.length));
+  iree_hal_buffer_t* allocated_buffer =
+      iree_hal_buffer_allocated_buffer(buffer_ref.buffer);
+  uint8_t* device_ptr =
+      (uint8_t*)iree_hal_amdgpu_buffer_device_pointer(allocated_buffer);
+  if (IREE_UNLIKELY(!device_ptr)) {
+    return iree_make_status(
+        IREE_STATUS_INVALID_ARGUMENT,
+        "AQL command-buffer buffer must be backed by an AMDGPU allocation");
+  }
+  iree_device_size_t device_offset = 0;
+  if (IREE_UNLIKELY(!iree_device_size_checked_add(
+          iree_hal_buffer_byte_offset(buffer_ref.buffer), buffer_ref.offset,
+          &device_offset))) {
+    return iree_make_status(
+        IREE_STATUS_OUT_OF_RANGE,
+        "AQL command-buffer buffer device pointer offset overflows device "
+        "size");
+  }
+  if (IREE_UNLIKELY(device_offset > UINTPTR_MAX)) {
+    return iree_make_status(
+        IREE_STATUS_OUT_OF_RANGE,
+        "AQL command-buffer buffer device pointer offset exceeds host pointer "
+        "size");
+  }
+  *out_device_ptr = device_ptr + (uintptr_t)device_offset;
+  return iree_ok_status();
+}
+
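+// Resolves a static (command-buffer-owned) or dynamic (binding-table) buffer
+// reference and returns both the HAL buffer ref and its device pointer.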
+static iree_status_t
+iree_hal_amdgpu_aql_block_processor_profile_resolve_command_buffer_ref(
+    iree_hal_command_buffer_t* command_buffer,
+    iree_hal_buffer_binding_table_t binding_table,
+    iree_hal_amdgpu_command_buffer_binding_kind_t kind, uint32_t ordinal,
+    uint64_t offset, uint64_t length, iree_hal_buffer_usage_t required_usage,
+    iree_hal_memory_access_t required_access,
+    iree_hal_buffer_ref_t* out_buffer_ref, uint8_t** out_device_ptr) {
+  memset(out_buffer_ref, 0, sizeof(*out_buffer_ref));
+  *out_device_ptr = NULL;
+  if (kind == IREE_HAL_AMDGPU_COMMAND_BUFFER_BINDING_KIND_STATIC) {
+    iree_hal_buffer_t* buffer =
+        iree_hal_amdgpu_aql_command_buffer_static_buffer(command_buffer,
+                                                         ordinal);
+    if (IREE_UNLIKELY(!buffer)) {
+      return iree_make_status(
+          IREE_STATUS_INVALID_ARGUMENT,
+          "AQL command-buffer static buffer ordinal %" PRIu32 " is invalid",
+          ordinal);
+    }
+    *out_buffer_ref = iree_hal_make_buffer_ref(buffer, offset, length);
+  } else if (kind == IREE_HAL_AMDGPU_COMMAND_BUFFER_BINDING_KIND_DYNAMIC) {
+    iree_hal_buffer_ref_t dynamic_ref =
+        iree_hal_make_indirect_buffer_ref(ordinal, offset, length);
+    IREE_RETURN_IF_ERROR(iree_hal_buffer_binding_table_resolve_ref(
+        binding_table, dynamic_ref, out_buffer_ref));
+  } else {
+    return iree_make_status(IREE_STATUS_INVALID_ARGUMENT,
+                            "AQL command-buffer binding kind %u is invalid",
+                            kind);
+  }
+  IREE_RETURN_IF_ERROR(
+      iree_hal_amdgpu_aql_block_processor_profile_resolve_buffer_ref_ptr(
+          *out_buffer_ref, required_usage, required_access, out_device_ptr));
+  return iree_ok_status();
+}
+
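+// Resolves a static dispatch binding source to an absolute device address by
+// combining the staged buffer base pointer with the recorded offset.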
+static iree_status_t
+iree_hal_amdgpu_aql_block_processor_profile_resolve_static_binding_source_ptr(
+    iree_hal_command_buffer_t* command_buffer,
+    const iree_hal_amdgpu_command_buffer_binding_source_t* binding_source,
+    uint64_t* out_binding_ptr) {
+  *out_binding_ptr = 0;
+  iree_hal_buffer_t* buffer = iree_hal_amdgpu_aql_command_buffer_static_buffer(
+      command_buffer, binding_source->slot);
+  if (IREE_UNLIKELY(!buffer)) {
+    return iree_make_status(
+        IREE_STATUS_INVALID_ARGUMENT,
+        "AQL command-buffer static dispatch binding ordinal %" PRIu32
+        " is invalid",
+        binding_source->slot);
+  }
+  iree_hal_buffer_t* allocated_buffer =
+      iree_hal_buffer_allocated_buffer(buffer);
+  void* device_ptr = iree_hal_amdgpu_buffer_device_pointer(allocated_buffer);
+  if (IREE_UNLIKELY(!device_ptr)) {
+    return iree_make_status(
+        IREE_STATUS_FAILED_PRECONDITION,
+        "AQL command-buffer static dispatch binding has no staged AMDGPU "
+        "backing after queue waits completed");
+  }
+  iree_device_size_t device_offset = 0;
+  if (IREE_UNLIKELY(!iree_device_size_checked_add(
+          iree_hal_buffer_byte_offset(buffer),
+          binding_source->offset_or_pointer, &device_offset))) {
+    return iree_make_status(
+        IREE_STATUS_OUT_OF_RANGE,
+        "AQL command-buffer static dispatch binding pointer offset overflows "
+        "device size");
+  }
+  if (IREE_UNLIKELY(device_offset > UINTPTR_MAX)) {
+    return iree_make_status(
+        IREE_STATUS_OUT_OF_RANGE,
+        "AQL command-buffer static dispatch binding pointer offset exceeds "
+        "host pointer size");
+  }
+  *out_binding_ptr =
+      (uint64_t)((uintptr_t)device_ptr + (uintptr_t)device_offset);
+  return iree_ok_status();
+}
+
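+// Returns true if the logical packet at |packet_index| executes behind a
+// barrier: either its own flags request one or it is packet zero of a
+// submission that carries wait barriers or an inline acquire scope.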
+static bool iree_hal_amdgpu_aql_block_processor_profile_packet_has_barrier(
+    const iree_hal_amdgpu_aql_block_processor_profile_t* processor,
+    uint32_t packet_index,
+    iree_hal_amdgpu_host_queue_command_buffer_packet_flags_t packet_flags) {
+  return iree_any_bit_set(
+             packet_flags,
+             IREE_HAL_AMDGPU_HOST_QUEUE_COMMAND_BUFFER_PACKET_FLAG_EXECUTION_BARRIER |
+                 IREE_HAL_AMDGPU_HOST_QUEUE_COMMAND_BUFFER_PACKET_FLAG_FINAL) ||
+         (packet_index == 0 &&
+          (processor->submission.resolution->barrier_count > 0 ||
+           processor->submission.resolution->inline_acquire_scope !=
+               IREE_HSA_FENCE_SCOPE_NONE));
+}
+
+static iree_hal_amdgpu_host_queue_command_buffer_packet_flags_t
+iree_hal_amdgpu_aql_block_processor_profile_command_packet_flags(
+    const iree_hal_amdgpu_command_buffer_command_header_t* command) {
+  iree_hal_amdgpu_host_queue_command_buffer_packet_flags_t packet_flags =
+      IREE_HAL_AMDGPU_HOST_QUEUE_COMMAND_BUFFER_PACKET_FLAG_NONE;
+  if (iree_any_bit_set(
+          command->flags,
+          IREE_HAL_AMDGPU_COMMAND_BUFFER_COMMAND_FLAG_HAS_BARRIER)) {
+    packet_flags |=
+        IREE_HAL_AMDGPU_HOST_QUEUE_COMMAND_BUFFER_PACKET_FLAG_EXECUTION_BARRIER;
+  }
+  return iree_hal_amdgpu_host_queue_command_buffer_packet_flags_set_fence_scopes(
+      packet_flags,
+      (iree_hsa_fence_scope_t)
+          iree_hal_amdgpu_command_buffer_command_flags_acquire_scope(
+              command->flags),
+      (iree_hsa_fence_scope_t)
+          iree_hal_amdgpu_command_buffer_command_flags_release_scope(
+              command->flags));
+}
+
+static iree_hal_amdgpu_host_queue_command_buffer_packet_flags_t
+iree_hal_amdgpu_aql_block_processor_profile_packet_flags_merge(
+    iree_hal_amdgpu_host_queue_command_buffer_packet_flags_t lhs,
+    iree_hal_amdgpu_host_queue_command_buffer_packet_flags_t rhs) {
+  const iree_hsa_fence_scope_t acquire_scope =
+      iree_hal_amdgpu_host_queue_max_fence_scope(
+          iree_hal_amdgpu_host_queue_command_buffer_packet_flags_acquire_scope(
+              lhs),
+          iree_hal_amdgpu_host_queue_command_buffer_packet_flags_acquire_scope(
+              rhs));
+  const iree_hsa_fence_scope_t release_scope =
+      iree_hal_amdgpu_host_queue_max_fence_scope(
+          iree_hal_amdgpu_host_queue_command_buffer_packet_flags_release_scope(
+              lhs),
+          iree_hal_amdgpu_host_queue_command_buffer_packet_flags_release_scope(
+              rhs));
+  return iree_hal_amdgpu_host_queue_command_buffer_packet_flags_set_fence_scopes(
+      lhs | rhs, acquire_scope, release_scope);
+}
+
+static iree_hal_amdgpu_host_queue_command_buffer_packet_flags_t
+iree_hal_amdgpu_aql_block_processor_profile_agent_barrier_packet_flags(void) {
+  return iree_hal_amdgpu_host_queue_command_buffer_packet_flags_set_fence_scopes(
+      IREE_HAL_AMDGPU_HOST_QUEUE_COMMAND_BUFFER_PACKET_FLAG_EXECUTION_BARRIER,
+      IREE_HSA_FENCE_SCOPE_AGENT, IREE_HSA_FENCE_SCOPE_AGENT);
+}
+
+static bool
+iree_hal_amdgpu_aql_block_processor_profile_command_uses_queue_kernargs(
+    const iree_hal_amdgpu_command_buffer_command_header_t* command) {
+  return iree_any_bit_set(
+      command->flags,
+      IREE_HAL_AMDGPU_COMMAND_BUFFER_COMMAND_FLAG_USES_QUEUE_KERNARGS);
+}
+
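+// Returns the acquire scope a replayed packet must carry so queue-owned
+// kernarg writes are visible to it. Only the first |acquire_packet_count|
+// recorded packets need the scope, and then only if they consume queue
+// kernargs or already execute behind a barrier.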
+static iree_hsa_fence_scope_t
+iree_hal_amdgpu_aql_block_processor_profile_payload_acquire_scope(
+    const iree_hal_amdgpu_aql_block_processor_profile_t* processor,
+    const iree_hal_amdgpu_aql_block_processor_profile_state_t* state,
+    uint32_t packet_index,
+    const iree_hal_amdgpu_command_buffer_command_header_t* command,
+    iree_hal_amdgpu_host_queue_command_buffer_packet_flags_t packet_flags) {
+  if (processor->payload.acquire_scope == IREE_HSA_FENCE_SCOPE_NONE) {
+    return IREE_HSA_FENCE_SCOPE_NONE;
+  }
+  if (state->packets.recorded >= processor->payload.acquire_packet_count) {
+    return IREE_HSA_FENCE_SCOPE_NONE;
+  }
+  if (iree_hal_amdgpu_aql_block_processor_profile_command_uses_queue_kernargs(
+          command)) {
+    return processor->payload.acquire_scope;
+  }
+  if (iree_hal_amdgpu_aql_block_processor_profile_packet_has_barrier(
+          processor, processor->packets.index_base + packet_index,
+          packet_flags)) {
+    return processor->payload.acquire_scope;
+  }
+  return IREE_HSA_FENCE_SCOPE_NONE;
+}
+
+static bool iree_hal_amdgpu_aql_block_processor_profile_dispatch_uses_indirect(
+    const iree_hal_amdgpu_command_buffer_dispatch_command_t* dispatch_command) {
+  return iree_any_bit_set(
+      dispatch_command->dispatch_flags,
+      IREE_HAL_AMDGPU_COMMAND_BUFFER_DISPATCH_FLAG_INDIRECT_PARAMETERS);
+}
+
+static bool
+iree_hal_amdgpu_aql_block_processor_profile_dispatch_uses_prepublished(
+    const iree_hal_amdgpu_command_buffer_dispatch_command_t* dispatch_command) {
+  return dispatch_command->kernarg_strategy ==
+         IREE_HAL_AMDGPU_COMMAND_BUFFER_KERNARG_STRATEGY_PREPUBLISHED;
+}
+
+static bool
+iree_hal_amdgpu_aql_block_processor_profile_dispatch_uses_queue_kernargs(
+    const iree_hal_amdgpu_command_buffer_dispatch_command_t* dispatch_command) {
+  return !iree_hal_amdgpu_aql_block_processor_profile_dispatch_uses_prepublished(
+      dispatch_command);
+}
+
+static uint32_t
+iree_hal_amdgpu_aql_block_processor_profile_dispatch_target_kernarg_block_count(
+    const iree_hal_amdgpu_command_buffer_dispatch_command_t* dispatch_command) {
+  if (iree_hal_amdgpu_aql_block_processor_profile_dispatch_uses_prepublished(
+          dispatch_command)) {
+    return 0;
+  }
+  const uint32_t kernarg_length =
+      (uint32_t)dispatch_command->kernarg_length_qwords * 8u;
+  return iree_max(1u,
+                  (uint32_t)iree_host_size_ceil_div(
+                      kernarg_length, sizeof(iree_hal_amdgpu_kernarg_block_t)));
+}
+
+static uint32_t
+iree_hal_amdgpu_aql_block_processor_profile_dispatch_kernarg_block_count(
+    const iree_hal_amdgpu_command_buffer_dispatch_command_t* dispatch_command) {
+  return iree_hal_amdgpu_aql_block_processor_profile_dispatch_target_kernarg_block_count(
+             dispatch_command) +
+         (iree_hal_amdgpu_aql_block_processor_profile_dispatch_uses_indirect(
+              dispatch_command)
+              ? 1u
+              : 0u);
+}
+
+static void
+iree_hal_amdgpu_aql_block_processor_profile_write_dispatch_packet_body(
+    const iree_hal_amdgpu_command_buffer_dispatch_command_t* dispatch_command,
+    iree_hal_amdgpu_aql_packet_t* packet, uint8_t* kernarg_data,
+    iree_hsa_signal_t completion_signal, uint16_t* out_setup) {
+  packet->dispatch.setup = dispatch_command->setup;
+  packet->dispatch.workgroup_size[0] = dispatch_command->workgroup_size[0];
+  packet->dispatch.workgroup_size[1] = dispatch_command->workgroup_size[1];
+  packet->dispatch.workgroup_size[2] = dispatch_command->workgroup_size[2];
+  packet->dispatch.reserved0 = 0;
+  packet->dispatch.grid_size[0] = dispatch_command->grid_size[0];
+  packet->dispatch.grid_size[1] = dispatch_command->grid_size[1];
+  packet->dispatch.grid_size[2] = dispatch_command->grid_size[2];
+  packet->dispatch.private_segment_size =
+      dispatch_command->private_segment_size;
+  packet->dispatch.group_segment_size = dispatch_command->group_segment_size;
+  packet->dispatch.kernel_object = dispatch_command->kernel_object;
+  packet->dispatch.kernarg_address = kernarg_data;
+  packet->dispatch.reserved2 = 0;
+  packet->dispatch.completion_signal = completion_signal;
+  *out_setup = packet->dispatch.setup;
+}
+
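+// Populates |kernarg_data| for |dispatch_command| according to its kernarg
+// strategy: dynamic bindings, HAL binding sources, a direct custom payload,
+// or a patched rodata template.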
+static iree_status_t
+iree_hal_amdgpu_aql_block_processor_profile_replay_dispatch_kernargs(
+    const iree_hal_amdgpu_aql_block_processor_profile_t* processor,
+    const iree_hal_amdgpu_command_buffer_dispatch_command_t* dispatch_command,
+    uint8_t* kernarg_data) {
+  switch (dispatch_command->kernarg_strategy) {
+    case IREE_HAL_AMDGPU_COMMAND_BUFFER_KERNARG_STRATEGY_DYNAMIC_BINDINGS: {
+      if (IREE_UNLIKELY(!processor->bindings.ptrs)) {
+        return iree_make_status(
+            IREE_STATUS_INVALID_ARGUMENT,
+            "AQL command-buffer dispatch has dynamic bindings but no binding "
+            "table was provided");
+      }
+      uint64_t* binding_dst = (uint64_t*)kernarg_data;
+      const iree_hal_amdgpu_command_buffer_binding_source_t* binding_sources =
+          (const iree_hal_amdgpu_command_buffer_binding_source_t*)
+              ((const uint8_t*)processor->block +
+               dispatch_command->binding_source_offset);
+      for (uint16_t i = 0; i < dispatch_command->binding_count; ++i) {
+        const iree_hal_amdgpu_command_buffer_binding_source_t* binding_source =
+            &binding_sources[i];
+        binding_dst[i] = processor->bindings.ptrs[binding_source->slot] +
+                         binding_source->offset_or_pointer;
+      }
+      const iree_host_size_t tail_length =
+          (iree_host_size_t)dispatch_command->payload.tail_length_qwords * 8u;
+      if (tail_length > 0) {
+        const uint8_t* tail_payload = (const uint8_t*)dispatch_command +
+                                      dispatch_command->payload_reference;
+        memcpy(
+            kernarg_data + (iree_host_size_t)dispatch_command->binding_count *
+                               sizeof(uint64_t),
+            tail_payload, tail_length);
+      }
+      return iree_ok_status();
+    }
+    case IREE_HAL_AMDGPU_COMMAND_BUFFER_KERNARG_STRATEGY_HAL: {
+      uint64_t* binding_dst = (uint64_t*)kernarg_data;
+      const iree_hal_amdgpu_command_buffer_binding_source_t* binding_sources =
+          (const iree_hal_amdgpu_command_buffer_binding_source_t*)
+              ((const uint8_t*)processor->block +
+               dispatch_command->binding_source_offset);
+      for (uint16_t i = 0; i < dispatch_command->binding_count; ++i) {
+        const iree_hal_amdgpu_command_buffer_binding_source_t* binding_source =
+            &binding_sources[i];
+        const uint32_t flags = binding_source->flags;
+        if (IREE_LIKELY(
+                flags ==
+                IREE_HAL_AMDGPU_COMMAND_BUFFER_BINDING_SOURCE_FLAG_NONE)) {
+          binding_dst[i] = binding_source->offset_or_pointer;
+        } else if (flags ==
+                   IREE_HAL_AMDGPU_COMMAND_BUFFER_BINDING_SOURCE_FLAG_DYNAMIC) {
+          if (IREE_UNLIKELY(!processor->bindings.ptrs)) {
+            return iree_make_status(
+                IREE_STATUS_INVALID_ARGUMENT,
+                "AQL command-buffer dispatch has dynamic bindings but no "
+                "binding table was provided");
+          }
+          binding_dst[i] = processor->bindings.ptrs[binding_source->slot] +
+                           binding_source->offset_or_pointer;
+        } else if (
+            flags ==
+            IREE_HAL_AMDGPU_COMMAND_BUFFER_BINDING_SOURCE_FLAG_STATIC_BUFFER) {
+          IREE_RETURN_IF_ERROR(
+              iree_hal_amdgpu_aql_block_processor_profile_resolve_static_binding_source_ptr(
+                  processor->command_buffer, binding_source, &binding_dst[i]));
+        } else {
+          return iree_make_status(
+              IREE_STATUS_INVALID_ARGUMENT,
+              "malformed AQL command-buffer dispatch binding source flags %u",
+              binding_source->flags);
+        }
+      }
+      const iree_host_size_t tail_length =
+          (iree_host_size_t)dispatch_command->payload.tail_length_qwords * 8u;
+      if (tail_length > 0) {
+        const uint8_t* tail_payload = (const uint8_t*)dispatch_command +
+                                      dispatch_command->payload_reference;
+        memcpy(
+            kernarg_data + (iree_host_size_t)dispatch_command->binding_count *
+                               sizeof(uint64_t),
+            tail_payload, tail_length);
+      }
+      return iree_ok_status();
+    }
+    case IREE_HAL_AMDGPU_COMMAND_BUFFER_KERNARG_STRATEGY_CUSTOM_DIRECT: {
+      const iree_host_size_t tail_length =
+          (iree_host_size_t)dispatch_command->payload.tail_length_qwords * 8u;
+      if (tail_length > 0) {
+        const uint8_t* tail_payload = (const uint8_t*)dispatch_command +
+                                      dispatch_command->payload_reference;
+        memcpy(kernarg_data, tail_payload, tail_length);
+      }
+      return iree_ok_status();
+    }
+    case IREE_HAL_AMDGPU_COMMAND_BUFFER_KERNARG_STRATEGY_INDIRECT:
+      return iree_make_status(
+          IREE_STATUS_UNIMPLEMENTED,
+          "indirect dispatch arguments are not supported by AMDGPU command "
+          "buffers yet");
+    case IREE_HAL_AMDGPU_COMMAND_BUFFER_KERNARG_STRATEGY_PATCHED_TEMPLATE: {
+      const uint32_t kernarg_length =
+          (uint32_t)dispatch_command->kernarg_length_qwords * 8u;
+      const uint8_t* kernarg_template =
+          iree_hal_amdgpu_aql_command_buffer_rodata(
+              processor->command_buffer, dispatch_command->payload_reference,
+              kernarg_length);
+      if (IREE_UNLIKELY(!kernarg_template)) {
+        return iree_make_status(
+            IREE_STATUS_INVALID_ARGUMENT,
+            "AQL command-buffer patched kernarg template range is invalid");
+      }
+      memcpy(kernarg_data, kernarg_template, kernarg_length);
+      uint64_t* binding_dst = (uint64_t*)kernarg_data;
+      const uint16_t patch_source_count =
+          dispatch_command->payload.patch_source_count;
+      const uint64_t* binding_ptrs = processor->bindings.ptrs;
+      if (IREE_UNLIKELY(!binding_ptrs)) {
+        return iree_make_status(
+            IREE_STATUS_INVALID_ARGUMENT,
+            "AQL command-buffer dispatch has dynamic bindings but no binding "
+            "table was provided");
+      }
+      const iree_hal_amdgpu_command_buffer_binding_source_t* binding_sources =
+          (const iree_hal_amdgpu_command_buffer_binding_source_t*)
+              ((const uint8_t*)processor->block +
+               dispatch_command->binding_source_offset);
+      for (uint16_t i = 0; i < patch_source_count; ++i) {
+        const iree_hal_amdgpu_command_buffer_binding_source_t* binding_source =
+            &binding_sources[i];
+        binding_dst[binding_source->target_binding_ordinal] =
+            binding_ptrs[binding_source->slot] +
+            binding_source->offset_or_pointer;
+      }
+      return iree_ok_status();
+    }
+    case IREE_HAL_AMDGPU_COMMAND_BUFFER_KERNARG_STRATEGY_PREPUBLISHED:
+      return iree_make_status(
+          IREE_STATUS_INVALID_ARGUMENT,
+          "prepublished command-buffer dispatch should not rewrite kernargs");
+    default:
+      return iree_make_status(IREE_STATUS_INVALID_ARGUMENT,
+                              "malformed AQL command-buffer kernarg strategy "
+                              "%u",
+                              dispatch_command->kernarg_strategy);
+  }
+}
+
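+// Resolves the device pointer to the uint32_t[3] workgroup count consumed by
+// an indirect dispatch from its binding source: a raw device pointer, a
+// dynamic binding-table entry, or a static command-buffer buffer.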
+static iree_status_t
+iree_hal_amdgpu_aql_block_processor_profile_replay_dispatch_indirect_params_ptr(
+    const iree_hal_amdgpu_aql_block_processor_profile_t* processor,
+    const iree_hal_amdgpu_command_buffer_binding_source_t* binding_source,
+    const uint32_t** out_workgroup_count_ptr) {
+  *out_workgroup_count_ptr = NULL;
+  switch (binding_source->flags) {
+    case IREE_HAL_AMDGPU_COMMAND_BUFFER_BINDING_SOURCE_FLAG_INDIRECT_PARAMETERS:
+      *out_workgroup_count_ptr =
+          (const uint32_t*)(uintptr_t)binding_source->offset_or_pointer;
+      return iree_ok_status();
+    case IREE_HAL_AMDGPU_COMMAND_BUFFER_BINDING_SOURCE_FLAG_DYNAMIC |
+        IREE_HAL_AMDGPU_COMMAND_BUFFER_BINDING_SOURCE_FLAG_INDIRECT_PARAMETERS: {
+      iree_hal_buffer_ref_t resolved_ref = {0};
+      IREE_RETURN_IF_ERROR(iree_hal_buffer_binding_table_resolve_ref(
+          processor->bindings.table,
+          iree_hal_make_indirect_buffer_ref(binding_source->slot,
+                                            binding_source->offset_or_pointer,
+                                            sizeof(uint32_t[3])),
+          &resolved_ref));
+      uint8_t* device_ptr = NULL;
+      IREE_RETURN_IF_ERROR(
+          iree_hal_amdgpu_aql_block_processor_profile_resolve_buffer_ref_ptr(
+              resolved_ref, IREE_HAL_BUFFER_USAGE_DISPATCH_INDIRECT_PARAMETERS,
+              IREE_HAL_MEMORY_ACCESS_READ, &device_ptr));
+      *out_workgroup_count_ptr = (const uint32_t*)device_ptr;
+      return iree_ok_status();
+    }
+    case IREE_HAL_AMDGPU_COMMAND_BUFFER_BINDING_SOURCE_FLAG_STATIC_BUFFER |
+        IREE_HAL_AMDGPU_COMMAND_BUFFER_BINDING_SOURCE_FLAG_INDIRECT_PARAMETERS: {
+      uint64_t workgroup_count_ptr = 0;
+      IREE_RETURN_IF_ERROR(
+          iree_hal_amdgpu_aql_block_processor_profile_resolve_static_binding_source_ptr(
+              processor->command_buffer, binding_source, &workgroup_count_ptr));
+      *out_workgroup_count_ptr =
+          (const uint32_t*)(uintptr_t)workgroup_count_ptr;
+      return iree_ok_status();
+    }
+    default:
+      return iree_make_status(
+          IREE_STATUS_INVALID_ARGUMENT,
+          "malformed AQL command-buffer indirect parameter source flags %u",
+          binding_source->flags);
+  }
+}
+
+static iree_amdgpu_kernel_implicit_args_t*
+iree_hal_amdgpu_aql_block_processor_profile_dispatch_implicit_args_ptr(
+    const iree_hal_amdgpu_command_buffer_dispatch_command_t* dispatch_command,
+    uint8_t* kernarg_data) {
+  if (dispatch_command->implicit_args_offset_qwords == UINT16_MAX) {
+    return NULL;
+  }
+  return (iree_amdgpu_kernel_implicit_args_t*)
+      (kernarg_data +
+       (iree_host_size_t)dispatch_command->implicit_args_offset_qwords * 8u);
+}
+
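+// Writes the AQL dispatch packet body for |dispatch_command|, either reusing
+// prepublished kernargs or rewriting queue-owned kernarg storage first.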
+static iree_status_t
+iree_hal_amdgpu_aql_block_processor_profile_replay_dispatch_packet_body(
+    const iree_hal_amdgpu_aql_block_processor_profile_t* processor,
+    const iree_hal_amdgpu_command_buffer_dispatch_command_t* dispatch_command,
+    iree_hal_amdgpu_aql_packet_t* packet, uint8_t* kernarg_data,
+    iree_hsa_signal_t completion_signal, uint16_t* out_setup) {
+  if (iree_hal_amdgpu_aql_block_processor_profile_dispatch_uses_prepublished(
+          dispatch_command)) {
+    const uint32_t kernarg_length =
+        (uint32_t)dispatch_command->kernarg_length_qwords * 8u;
+    kernarg_data = iree_hal_amdgpu_aql_command_buffer_prepublished_kernarg(
+        processor->command_buffer, dispatch_command->payload_reference,
+        kernarg_length);
+    if (IREE_UNLIKELY(!kernarg_data)) {
+      return iree_make_status(
+          IREE_STATUS_INVALID_ARGUMENT,
+          "AQL command-buffer prepublished kernarg range is invalid");
+    }
+  } else {
+    IREE_RETURN_IF_ERROR(
+        iree_hal_amdgpu_aql_block_processor_profile_replay_dispatch_kernargs(
+            processor, dispatch_command, kernarg_data));
+  }
+  iree_hal_amdgpu_aql_block_processor_profile_write_dispatch_packet_body(
+      dispatch_command, packet, kernarg_data, completion_signal, out_setup);
+  return iree_ok_status();
+}
+
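+// Writes the patch+dispatch packet pair for an indirect dispatch: the patch
+// kernel reads the workgroup count on device, rewrites the grid of the
+// trailing dispatch packet, and publishes |dispatch_header| for it.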
+static iree_status_t
+iree_hal_amdgpu_aql_block_processor_profile_replay_indirect_dispatch_packet_bodies(
+    const iree_hal_amdgpu_aql_block_processor_profile_t* processor,
+    const iree_hal_amdgpu_command_buffer_dispatch_command_t* dispatch_command,
+    iree_hal_amdgpu_aql_packet_t* patch_packet,
+    iree_hal_amdgpu_aql_packet_t* dispatch_packet, uint8_t* patch_kernarg_data,
+    uint8_t* dispatch_kernarg_data, iree_hsa_signal_t completion_signal,
+    uint16_t dispatch_header, uint16_t* out_patch_setup,
+    uint16_t* out_dispatch_setup) {
+  IREE_RETURN_IF_ERROR(
+      iree_hal_amdgpu_aql_block_processor_profile_replay_dispatch_packet_body(
+          processor, dispatch_command, dispatch_packet, dispatch_kernarg_data,
+          completion_signal, out_dispatch_setup));
+
+  const iree_hal_amdgpu_command_buffer_binding_source_t* binding_sources =
+      (const iree_hal_amdgpu_command_buffer_binding_source_t*)
+          ((const uint8_t*)processor->block +
+           dispatch_command->binding_source_offset);
+  const iree_hal_amdgpu_command_buffer_binding_source_t*
+      indirect_params_source =
+          &binding_sources[dispatch_command->binding_count];
+  const uint32_t* workgroup_count_ptr = NULL;
+  IREE_RETURN_IF_ERROR(
+      iree_hal_amdgpu_aql_block_processor_profile_replay_dispatch_indirect_params_ptr(
+          processor, indirect_params_source, &workgroup_count_ptr));
+
+  iree_amdgpu_kernel_implicit_args_t* implicit_args =
+      iree_hal_amdgpu_aql_block_processor_profile_dispatch_implicit_args_ptr(
+          dispatch_command, dispatch_kernarg_data);
+  iree_hal_amdgpu_device_dispatch_emplace_indirect_params_patch(
+      &processor->queue->transfer_context->kernels
+           ->iree_hal_amdgpu_device_dispatch_patch_indirect_params,
+      workgroup_count_ptr, &dispatch_packet->dispatch, dispatch_header,
+      *out_dispatch_setup, implicit_args, &patch_packet->dispatch,
+      patch_kernarg_data);
+  *out_patch_setup = patch_packet->dispatch.setup;
+  return iree_ok_status();
+}
+
+static iree_status_t
+iree_hal_amdgpu_aql_block_processor_profile_replay_fill_packet_body(
+    const iree_hal_amdgpu_aql_block_processor_profile_t* processor,
+    const iree_hal_amdgpu_command_buffer_fill_command_t* fill_command,
+    iree_hal_amdgpu_aql_packet_t* packet,
+    iree_hal_amdgpu_kernarg_block_t* kernarg_block,
+    iree_hsa_signal_t completion_signal, uint16_t* out_setup) {
+  iree_hal_buffer_ref_t target_ref = {0};
+  uint8_t* target_ptr = NULL;
+  IREE_RETURN_IF_ERROR(
+      iree_hal_amdgpu_aql_block_processor_profile_resolve_command_buffer_ref(
+          processor->command_buffer, processor->bindings.table,
+          fill_command->target_kind, fill_command->target_ordinal,
+          fill_command->target_offset, fill_command->length,
+          IREE_HAL_BUFFER_USAGE_TRANSFER_TARGET, IREE_HAL_MEMORY_ACCESS_WRITE,
+          &target_ref, &target_ptr));
+  (void)target_ref;
+  if (IREE_UNLIKELY(!iree_hal_amdgpu_device_buffer_fill_emplace(
+          processor->queue->transfer_context, &packet->dispatch, target_ptr,
+          fill_command->length, fill_command->pattern,
+          fill_command->pattern_length, kernarg_block->data))) {
+    return iree_make_status(IREE_STATUS_INVALID_ARGUMENT,
+                            "unsupported command-buffer fill dispatch shape");
+  }
+  packet->dispatch.completion_signal = completion_signal;
+  *out_setup = packet->dispatch.setup;
+  return iree_ok_status();
+}
+
+static iree_status_t
+iree_hal_amdgpu_aql_block_processor_profile_replay_copy_packet_body(
+    const iree_hal_amdgpu_aql_block_processor_profile_t* processor,
+    const iree_hal_amdgpu_command_buffer_copy_command_t* copy_command,
+    iree_hal_amdgpu_aql_packet_t* packet,
+    iree_hal_amdgpu_kernarg_block_t* kernarg_block,
+    iree_hsa_signal_t completion_signal, uint16_t* out_setup) {
+  iree_hal_buffer_ref_t source_ref = {0};
+  uint8_t* source_ptr = NULL;
+  IREE_RETURN_IF_ERROR(
+      iree_hal_amdgpu_aql_block_processor_profile_resolve_command_buffer_ref(
+          processor->command_buffer, processor->bindings.table,
+          copy_command->source_kind, copy_command->source_ordinal,
+          copy_command->source_offset, copy_command->length,
+          IREE_HAL_BUFFER_USAGE_TRANSFER_SOURCE, IREE_HAL_MEMORY_ACCESS_READ,
+          &source_ref, &source_ptr));
+  iree_hal_buffer_ref_t target_ref = {0};
+  uint8_t* target_ptr = NULL;
+  IREE_RETURN_IF_ERROR(
+      iree_hal_amdgpu_aql_block_processor_profile_resolve_command_buffer_ref(
+          processor->command_buffer, processor->bindings.table,
+          copy_command->target_kind, copy_command->target_ordinal,
+          copy_command->target_offset, copy_command->length,
+          IREE_HAL_BUFFER_USAGE_TRANSFER_TARGET, IREE_HAL_MEMORY_ACCESS_WRITE,
+          &target_ref, &target_ptr));
+
+  if (IREE_UNLIKELY(
+          iree_hal_buffer_test_overlap(source_ref.buffer, source_ref.offset,
+                                       source_ref.length, target_ref.buffer,
+                                       target_ref.offset, target_ref.length) !=
+          IREE_HAL_BUFFER_OVERLAP_DISJOINT)) {
+    return iree_make_status(
+        IREE_STATUS_INVALID_ARGUMENT,
+        "source and target ranges must not overlap within the same buffer");
+  }
+  if (IREE_UNLIKELY(!iree_hal_amdgpu_device_buffer_copy_emplace(
+          processor->queue->transfer_context, &packet->dispatch, source_ptr,
+          target_ptr, copy_command->length, kernarg_block->data))) {
+    return iree_make_status(IREE_STATUS_INVALID_ARGUMENT,
+                            "unsupported command-buffer copy dispatch shape");
+  }
+  packet->dispatch.completion_signal = completion_signal;
+  *out_setup = packet->dispatch.setup;
+  return iree_ok_status();
+}
+
+static iree_host_size_t
+iree_hal_amdgpu_aql_block_processor_profile_update_kernarg_length(
+    uint32_t source_length) {
+  const iree_host_size_t source_payload_offset =
+      IREE_HAL_AMDGPU_DEVICE_BUFFER_COPY_STAGED_SOURCE_OFFSET;
+  return source_payload_offset + (iree_host_size_t)source_length;
+}
+
+static uint32_t
+iree_hal_amdgpu_aql_block_processor_profile_update_kernarg_block_count(
+    uint32_t source_length) {
+  return (uint32_t)iree_host_size_ceil_div(
+      iree_hal_amdgpu_aql_block_processor_profile_update_kernarg_length(
+          source_length),
+      sizeof(iree_hal_amdgpu_kernarg_block_t));
+}
+
+static iree_status_t
+iree_hal_amdgpu_aql_block_processor_profile_resolve_update_packet_operands(
+    const iree_hal_amdgpu_aql_block_processor_profile_t* processor,
+    const iree_hal_amdgpu_command_buffer_update_command_t* update_command,
+    const uint8_t** out_source_bytes, uint8_t** out_target_ptr) {
+  *out_source_bytes = NULL;
+  *out_target_ptr = NULL;
+  iree_hal_buffer_ref_t target_ref = {0};
+  IREE_RETURN_IF_ERROR(
+      iree_hal_amdgpu_aql_block_processor_profile_resolve_command_buffer_ref(
+          processor->command_buffer, processor->bindings.table,
+          update_command->target_kind, update_command->target_ordinal,
+          update_command->target_offset, update_command->length,
+          IREE_HAL_BUFFER_USAGE_TRANSFER_TARGET, IREE_HAL_MEMORY_ACCESS_WRITE,
+          &target_ref, out_target_ptr));
+  (void)target_ref;
+  *out_source_bytes = iree_hal_amdgpu_aql_command_buffer_rodata(
+      processor->command_buffer, update_command->rodata_ordinal,
+      update_command->length);
+  if (IREE_UNLIKELY(!*out_source_bytes)) {
+    return iree_make_status(
+        IREE_STATUS_INVALID_ARGUMENT,
+        "AQL command-buffer update rodata range is invalid");
+  }
+  return iree_ok_status();
+}
+
+static iree_status_t
+iree_hal_amdgpu_aql_block_processor_profile_replay_update_packet_body(
+    const iree_hal_amdgpu_aql_block_processor_profile_t* processor,
+    const iree_hal_amdgpu_command_buffer_update_command_t* update_command,
+    iree_hal_amdgpu_aql_packet_t* packet, uint8_t* kernarg_data,
+    iree_host_size_t kernarg_length, iree_hsa_signal_t completion_signal,
+    uint16_t* out_setup) {
+  const uint8_t* source_bytes = NULL;
+  uint8_t* target_ptr = NULL;
+  IREE_RETURN_IF_ERROR(
+      iree_hal_amdgpu_aql_block_processor_profile_resolve_update_packet_operands(
+          processor, update_command, &source_bytes, &target_ptr));
+
+  const iree_host_size_t source_payload_offset =
+      IREE_HAL_AMDGPU_DEVICE_BUFFER_COPY_STAGED_SOURCE_OFFSET;
+  const iree_host_size_t required_kernarg_length =
+      source_payload_offset + (iree_host_size_t)update_command->length;
+  if (IREE_UNLIKELY(required_kernarg_length > kernarg_length)) {
+    return iree_make_status(
+        IREE_STATUS_INVALID_ARGUMENT,
+        "AQL command-buffer update kernarg range is too small");
+  }
+
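+  // Emplace the copy with a placeholder source pointer that conveys only the
+  // staged-source alignment; the real pointer is patched in below once the
+  // source bytes have been staged after the kernarg struct.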
+  iree_hal_amdgpu_device_buffer_copy_kernargs_t kernargs;
+  memset(&kernargs, 0, sizeof(kernargs));
+  if (IREE_UNLIKELY(!iree_hal_amdgpu_device_buffer_copy_emplace(
+          processor->queue->transfer_context, &packet->dispatch,
+          (const void*)(uintptr_t)
+              IREE_HAL_AMDGPU_DEVICE_BUFFER_COPY_STAGED_SOURCE_ALIGNMENT,
+          target_ptr, update_command->length, &kernargs))) {
+    return iree_make_status(IREE_STATUS_INVALID_ARGUMENT,
+                            "unsupported command-buffer update dispatch shape");
+  }
+
+  uint8_t* staged_source_bytes = kernarg_data + source_payload_offset;
+  memcpy(kernarg_data, &kernargs, sizeof(kernargs));
+  ((iree_hal_amdgpu_device_buffer_copy_kernargs_t*)kernarg_data)->source_ptr =
+      staged_source_bytes;
+  memcpy(staged_source_bytes, source_bytes, update_command->length);
+  packet->dispatch.kernarg_address = kernarg_data;
+  packet->dispatch.completion_signal = completion_signal;
+  *out_setup = packet->dispatch.setup;
+  return iree_ok_status();
+}
+
+static iree_hal_amdgpu_aql_packet_t*
+iree_hal_amdgpu_aql_block_processor_profile_packet(
+    const iree_hal_amdgpu_aql_block_processor_profile_t* processor,
+    uint32_t packet_index) {
+  return iree_hal_amdgpu_aql_ring_packet(
+      &processor->queue->aql_ring,
+      processor->packets.first_payload_id + packet_index);
+}
+
+static bool iree_hal_amdgpu_aql_block_processor_profile_is_block_final(
+    const iree_hal_amdgpu_aql_block_processor_profile_t* processor,
+    const iree_hal_amdgpu_aql_block_processor_profile_state_t* state,
+    uint32_t recorded_packet_count) {
+  return state->packets.recorded + recorded_packet_count ==
+         processor->block->aql_packet_count;
+}
+
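+// Computes the AQL packet control word for a logical packet by merging the
+// packet's own flags with the submission-wide policy resolved by the host
+// queue (wait barriers, terminal signals, fence scopes).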
+static iree_hal_amdgpu_aql_packet_control_t
+iree_hal_amdgpu_aql_block_processor_profile_packet_control(
+    const iree_hal_amdgpu_aql_block_processor_profile_t* processor,
+    uint32_t packet_index, iree_hsa_fence_scope_t minimum_acquire_scope,
+    iree_hal_amdgpu_host_queue_command_buffer_packet_flags_t packet_flags) {
+  return iree_hal_amdgpu_host_queue_command_buffer_packet_control(
+      processor->queue, processor->submission.resolution,
+      processor->submission.signal_semaphore_list,
+      processor->packets.index_base + packet_index, minimum_acquire_scope,
+      packet_flags);
+}
+
+static iree_hsa_signal_t
+iree_hal_amdgpu_aql_block_processor_profile_dispatch_completion_signal(
+    iree_hal_amdgpu_host_queue_t* queue,
+    iree_hal_amdgpu_profile_dispatch_event_reservation_t profile_events,
+    uint32_t profile_event_index) {
+  const uint64_t profile_event_position =
+      profile_events.first_event_position + profile_event_index;
+  return iree_hal_amdgpu_host_queue_profiling_completion_signal(
+      queue, profile_event_position);
+}
+
+static bool iree_hal_amdgpu_aql_block_processor_dispatch_profile_has(
+    iree_hal_amdgpu_aql_block_processor_dispatch_profile_t profile,
+    iree_hal_amdgpu_aql_block_processor_dispatch_profile_flags_t flags) {
+  return iree_any_bit_set(profile.flags, flags);
+}
+
+static iree_hal_amdgpu_host_queue_command_buffer_packet_flags_t
+iree_hal_amdgpu_aql_block_processor_profile_packet_flags(
+    iree_hal_amdgpu_aql_block_processor_dispatch_profile_t profile) {
+  iree_hal_amdgpu_host_queue_command_buffer_packet_flags_t flags =
+      IREE_HAL_AMDGPU_HOST_QUEUE_COMMAND_BUFFER_PACKET_FLAG_NONE;
+  if (iree_hal_amdgpu_aql_block_processor_dispatch_profile_has(
+          profile,
+          IREE_HAL_AMDGPU_AQL_BLOCK_PROCESSOR_DISPATCH_PROFILE_FLAG_COUNTER_PACKETS |
+              IREE_HAL_AMDGPU_AQL_BLOCK_PROCESSOR_DISPATCH_PROFILE_FLAG_TRACE_PACKETS)) {
+    flags = iree_hal_amdgpu_aql_block_processor_profile_packet_flags_merge(
+        flags,
+        iree_hal_amdgpu_aql_block_processor_profile_agent_barrier_packet_flags());
+  }
+  if (iree_hal_amdgpu_aql_block_processor_dispatch_profile_has(
+          profile,
+          IREE_HAL_AMDGPU_AQL_BLOCK_PROCESSOR_DISPATCH_PROFILE_FLAG_FINAL)) {
+    flags |= IREE_HAL_AMDGPU_HOST_QUEUE_COMMAND_BUFFER_PACKET_FLAG_FINAL;
+  }
+  return flags;
+}
+
+static iree_hsa_signal_t
+iree_hal_amdgpu_aql_block_processor_profile_completion_signal(
+    const iree_hal_amdgpu_aql_block_processor_profile_t* processor,
+    const iree_hal_amdgpu_aql_block_processor_profile_state_t* state,
+    iree_hal_amdgpu_aql_block_processor_dispatch_profile_t profile) {
+  if (!iree_hal_amdgpu_aql_block_processor_dispatch_profile_has(
+          profile,
+          IREE_HAL_AMDGPU_AQL_BLOCK_PROCESSOR_DISPATCH_PROFILE_FLAG_DISPATCH_PACKET)) {
+    return iree_hsa_signal_null();
+  }
+  return iree_hal_amdgpu_aql_block_processor_profile_dispatch_completion_signal(
+      processor->queue, processor->profile.dispatch_events,
+      state->profile.event);
+}
+
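+// Records the source metadata for this dispatch's reserved profile event and
+// advances the sidecar event cursor.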
+static void iree_hal_amdgpu_aql_block_processor_profile_emit_source(
+    const iree_hal_amdgpu_aql_block_processor_profile_t* processor,
+    iree_hal_amdgpu_aql_block_processor_dispatch_profile_t profile,
+    iree_hal_amdgpu_aql_block_processor_profile_state_t* state) {
+  if (!iree_hal_amdgpu_aql_block_processor_dispatch_profile_has(
+          profile,
+          IREE_HAL_AMDGPU_AQL_BLOCK_PROCESSOR_DISPATCH_PROFILE_FLAG_DISPATCH_PACKET)) {
+    return;
+  }
+  iree_hal_amdgpu_host_queue_record_command_buffer_profile_dispatch_source(
+      processor->queue, processor->profile.command_buffer_id, profile.dispatch,
+      processor->profile.dispatch_events, processor->profile.harvest_sources,
+      &state->profile.event);
+}
+
+static const iree_hal_amdgpu_aql_block_processor_profile_dispatch_t*
+iree_hal_amdgpu_aql_block_processor_profile_current_dispatch(
+    const iree_hal_amdgpu_aql_block_processor_profile_t* processor,
+    const iree_hal_amdgpu_aql_block_processor_profile_state_t* state,
+    uint32_t dispatch_packet_ordinal) {
+  if (state->profile.event >= processor->profile.dispatches.count) {
+    return NULL;
+  }
+  const iree_hal_amdgpu_aql_block_processor_profile_dispatch_t* dispatch =
+      &processor->profile.dispatches.values[state->profile.event];
+  if (IREE_UNLIKELY(!dispatch->summary)) return NULL;
+  if (dispatch->summary->packets.dispatch_ordinal != dispatch_packet_ordinal) {
+    return NULL;
+  }
+  return dispatch;
+}
+
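+// Builds the profiling plan for the next dispatch: whether it consumes a
+// reserved profile event, whether counter/trace packets wrap it, and whether
+// it is the final recorded/replayed packet of the block.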
+static iree_hal_amdgpu_aql_block_processor_dispatch_profile_t
+iree_hal_amdgpu_aql_block_processor_profile_dispatch_profile(
+    const iree_hal_amdgpu_aql_block_processor_profile_t* processor,
+    const iree_hal_amdgpu_aql_block_processor_profile_state_t* state,
+    uint32_t recorded_packet_count, uint32_t dispatch_packet_ordinal) {
+  iree_hal_amdgpu_aql_block_processor_dispatch_profile_flags_t flags =
+      IREE_HAL_AMDGPU_AQL_BLOCK_PROCESSOR_DISPATCH_PROFILE_FLAG_NONE;
+  const iree_hal_amdgpu_aql_block_processor_profile_dispatch_t*
+      profile_dispatch =
+          iree_hal_amdgpu_aql_block_processor_profile_current_dispatch(
+              processor, state, dispatch_packet_ordinal);
+  if (iree_any_bit_set(
+          processor->flags,
+          IREE_HAL_AMDGPU_AQL_BLOCK_PROCESSOR_PROFILE_FLAG_DISPATCH_PACKETS) &&
+      profile_dispatch) {
+    flags |=
+        IREE_HAL_AMDGPU_AQL_BLOCK_PROCESSOR_DISPATCH_PROFILE_FLAG_DISPATCH_PACKET;
+  }
+  if (iree_any_bit_set(
+          flags,
+          IREE_HAL_AMDGPU_AQL_BLOCK_PROCESSOR_DISPATCH_PROFILE_FLAG_DISPATCH_PACKET) &&
+      processor->profile.counter_set_count != 0) {
+    flags |=
+        IREE_HAL_AMDGPU_AQL_BLOCK_PROCESSOR_DISPATCH_PROFILE_FLAG_COUNTER_PACKETS;
+  }
+  if (iree_any_bit_set(
+          flags,
+          IREE_HAL_AMDGPU_AQL_BLOCK_PROCESSOR_DISPATCH_PROFILE_FLAG_DISPATCH_PACKET) &&
+      processor->profile.trace_packet_count != 0) {
+    flags |=
+        IREE_HAL_AMDGPU_AQL_BLOCK_PROCESSOR_DISPATCH_PROFILE_FLAG_TRACE_PACKETS;
+  }
+  if (iree_hal_amdgpu_aql_block_processor_profile_is_block_final(
+          processor, state, recorded_packet_count)) {
+    flags |=
+        IREE_HAL_AMDGPU_AQL_BLOCK_PROCESSOR_DISPATCH_PROFILE_FLAG_BLOCK_FINAL;
+  }
+  if (iree_any_bit_set(
+          flags,
+          IREE_HAL_AMDGPU_AQL_BLOCK_PROCESSOR_DISPATCH_PROFILE_FLAG_BLOCK_FINAL) &&
+      !iree_any_bit_set(
+          processor->flags,
+          IREE_HAL_AMDGPU_AQL_BLOCK_PROCESSOR_PROFILE_FLAG_DISPATCH_PACKETS |
+              IREE_HAL_AMDGPU_AQL_BLOCK_PROCESSOR_PROFILE_FLAG_QUEUE_DEVICE_EVENT)) {
+    flags |= IREE_HAL_AMDGPU_AQL_BLOCK_PROCESSOR_DISPATCH_PROFILE_FLAG_FINAL;
+  }
+  return (iree_hal_amdgpu_aql_block_processor_dispatch_profile_t){
+      .flags = flags,
+      .event_position =
+          iree_any_bit_set(
+              flags,
+              IREE_HAL_AMDGPU_AQL_BLOCK_PROCESSOR_DISPATCH_PROFILE_FLAG_DISPATCH_PACKET)
+              ? processor->profile.dispatch_events.first_event_position +
+                    state->profile.event
+              : 0,
+      .dispatch = profile_dispatch,
+  };
+}
+
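+// Emits one counter-start PM4 packet per configured counter set immediately
+// ahead of a profiled dispatch.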
+static void iree_hal_amdgpu_aql_block_processor_profile_emit_counter_starts(
+    const iree_hal_amdgpu_aql_block_processor_profile_t* processor,
+    uint64_t event_position,
+    iree_hal_amdgpu_aql_packet_control_t packet_control,
+    iree_hal_amdgpu_aql_block_processor_profile_state_t* state) {
+  if (processor->profile.counter_set_count == 0) return;
+  iree_hal_amdgpu_host_queue_emplace_profile_counter_start_packets(
+      processor->queue, event_position, processor->profile.counter_set_count,
+      processor->packets.first_payload_id, state->packets.emitted,
+      packet_control, processor->packets.headers, processor->packets.setups);
+  state->packets.emitted += processor->profile.counter_set_count;
+}
+
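+// Emits paired counter read and stop packets (two per counter set) behind a
+// profiled dispatch, which is why |emitted| advances by 2x the set count.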
+static void iree_hal_amdgpu_aql_block_processor_profile_emit_counter_read_stops(
+    const iree_hal_amdgpu_aql_block_processor_profile_t* processor,
+    uint64_t event_position,
+    iree_hal_amdgpu_aql_block_processor_profile_state_t* state) {
+  if (processor->profile.counter_set_count == 0) return;
+  iree_hal_amdgpu_host_queue_emplace_profile_counter_read_stop_packets(
+      processor->queue, event_position, processor->profile.counter_set_count,
+      processor->packets.first_payload_id, state->packets.emitted,
+      iree_hal_amdgpu_aql_packet_control_barrier(IREE_HSA_FENCE_SCOPE_AGENT,
+                                                 IREE_HSA_FENCE_SCOPE_AGENT),
+      processor->packets.headers, processor->packets.setups);
+  state->packets.emitted += processor->profile.counter_set_count * 2u;
+}
+
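+// Emits the trace-start packet followed by its code-object packet (two
+// packets total) ahead of a profiled dispatch.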
+static void iree_hal_amdgpu_aql_block_processor_profile_emit_trace_start(
+    const iree_hal_amdgpu_aql_block_processor_profile_t* processor,
+    uint64_t event_position,
+    iree_hal_amdgpu_host_queue_command_buffer_packet_flags_t packet_flags,
+    iree_hal_amdgpu_aql_block_processor_profile_state_t* state) {
+  iree_hal_amdgpu_host_queue_emplace_profile_trace_start_packet(
+      processor->queue, event_position, processor->packets.first_payload_id,
+      state->packets.emitted,
+      iree_hal_amdgpu_aql_block_processor_profile_packet_control(
+          processor, state->packets.emitted, IREE_HSA_FENCE_SCOPE_NONE,
+          packet_flags),
+      processor->packets.headers, processor->packets.setups);
+  ++state->packets.emitted;
+  iree_hal_amdgpu_host_queue_emplace_profile_trace_code_object_packet(
+      processor->queue, event_position, processor->packets.first_payload_id,
+      state->packets.emitted,
+      iree_hal_amdgpu_aql_packet_control_barrier(IREE_HSA_FENCE_SCOPE_AGENT,
+                                                 IREE_HSA_FENCE_SCOPE_AGENT),
+      processor->packets.headers, processor->packets.setups);
+  ++state->packets.emitted;
+}
+
+static void iree_hal_amdgpu_aql_block_processor_profile_emit_trace_stop(
+    const iree_hal_amdgpu_aql_block_processor_profile_t* processor,
+    uint64_t event_position,
+    iree_hal_amdgpu_aql_block_processor_profile_state_t* state) {
+  iree_hal_amdgpu_host_queue_emplace_profile_trace_stop_packet(
+      processor->queue, event_position, processor->packets.first_payload_id,
+      state->packets.emitted,
+      iree_hal_amdgpu_aql_packet_control_barrier(IREE_HSA_FENCE_SCOPE_AGENT,
+                                                 IREE_HSA_FENCE_SCOPE_AGENT),
+      processor->packets.headers, processor->packets.setups);
+  ++state->packets.emitted;
+}
+
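+// Emits a directly-parameterized dispatch with any requested profiling
+// wrapper: counter starts and trace start ahead of the dispatch packet,
+// trace stop and counter read/stop packets behind it.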
+static iree_status_t
+iree_hal_amdgpu_aql_block_processor_profile_emit_direct_dispatch(
+    const iree_hal_amdgpu_aql_block_processor_profile_t* processor,
+    const iree_hal_amdgpu_command_buffer_dispatch_command_t* dispatch_command,
+    iree_hal_amdgpu_aql_block_processor_profile_state_t* state) {
+  const iree_hal_amdgpu_aql_block_processor_dispatch_profile_t profile =
+      iree_hal_amdgpu_aql_block_processor_profile_dispatch_profile(
+          processor, state, /*recorded_packet_count=*/1,
+          /*dispatch_packet_ordinal=*/state->packets.recorded);
+  const iree_hal_amdgpu_host_queue_command_buffer_packet_flags_t
+      profile_packet_flags =
+          iree_hal_amdgpu_aql_block_processor_profile_packet_flags(profile);
+  const iree_hal_amdgpu_aql_block_processor_dispatch_profile_flags_t
+      profile_flags = profile.flags;
+
+  if (iree_any_bit_set(
+          profile_flags,
+          IREE_HAL_AMDGPU_AQL_BLOCK_PROCESSOR_DISPATCH_PROFILE_FLAG_COUNTER_PACKETS)) {
+    iree_hal_amdgpu_aql_block_processor_profile_emit_counter_starts(
+        processor, profile.event_position,
+        iree_hal_amdgpu_aql_block_processor_profile_packet_control(
+            processor, state->packets.emitted, IREE_HSA_FENCE_SCOPE_NONE,
+            iree_hal_amdgpu_aql_block_processor_profile_agent_barrier_packet_flags()),
+        state);
+  }
+  if (iree_any_bit_set(
+          profile_flags,
+          IREE_HAL_AMDGPU_AQL_BLOCK_PROCESSOR_DISPATCH_PROFILE_FLAG_TRACE_PACKETS)) {
+    iree_hal_amdgpu_aql_block_processor_profile_emit_trace_start(
+        processor, profile.event_position,
+        iree_hal_amdgpu_aql_block_processor_profile_agent_barrier_packet_flags(),
+        state);
+  }
+
+  const uint32_t dispatch_packet_index = state->packets.emitted;
+  iree_hal_amdgpu_aql_packet_t* packet =
+      iree_hal_amdgpu_aql_block_processor_profile_packet(processor,
+                                                         dispatch_packet_index);
+  const iree_hsa_signal_t completion_signal =
+      iree_hal_amdgpu_aql_block_processor_profile_completion_signal(
+          processor, state, profile);
+  uint8_t* kernarg_data = NULL;
+  if (iree_hal_amdgpu_aql_block_processor_profile_command_uses_queue_kernargs(
+          &dispatch_command->header)) {
+    kernarg_data = processor->kernargs.blocks[state->kernargs.block].data;
+  }
+  iree_status_t status =
+      iree_hal_amdgpu_aql_block_processor_profile_replay_dispatch_packet_body(
+          processor, dispatch_command, packet, kernarg_data, completion_signal,
+          &processor->packets.setups[dispatch_packet_index]);
+  if (iree_status_is_ok(status)) {
+    const iree_hal_amdgpu_host_queue_command_buffer_packet_flags_t packet_flags =
+        iree_hal_amdgpu_aql_block_processor_profile_packet_flags_merge(
+            iree_hal_amdgpu_aql_block_processor_profile_command_packet_flags(
+                &dispatch_command->header),
+            profile_packet_flags);
+    const iree_hsa_fence_scope_t payload_acquire_scope =
+        iree_hal_amdgpu_aql_block_processor_profile_payload_acquire_scope(
+            processor, state, dispatch_packet_index, &dispatch_command->header,
+            packet_flags);
+    iree_hal_amdgpu_aql_block_processor_profile_emit_source(processor, profile,
+                                                            state);
+    processor->packets.headers[dispatch_packet_index] =
+        iree_hal_amdgpu_aql_make_header(
+            IREE_HSA_PACKET_TYPE_KERNEL_DISPATCH,
+            iree_hal_amdgpu_aql_block_processor_profile_packet_control(
+                processor, dispatch_packet_index, payload_acquire_scope,
+                packet_flags));
+    ++state->packets.emitted;
+    ++state->packets.recorded;
+    state->kernargs.block +=
+        iree_hal_amdgpu_aql_block_processor_profile_dispatch_kernarg_block_count(
+            dispatch_command);
+    if (iree_any_bit_set(
+            profile_flags,
+            IREE_HAL_AMDGPU_AQL_BLOCK_PROCESSOR_DISPATCH_PROFILE_FLAG_TRACE_PACKETS)) {
+      iree_hal_amdgpu_aql_block_processor_profile_emit_trace_stop(
+          processor, profile.event_position, state);
+    }
+    if (iree_any_bit_set(
+            profile_flags,
+            IREE_HAL_AMDGPU_AQL_BLOCK_PROCESSOR_DISPATCH_PROFILE_FLAG_COUNTER_PACKETS)) {
+      iree_hal_amdgpu_aql_block_processor_profile_emit_counter_read_stops(
+          processor, profile.event_position, state);
+    }
+  }
+  return status;
+}
+
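+// Emits an indirectly-parameterized dispatch as a patch+dispatch packet pair
+// wrapped by any requested profiling packets. The patch slot is reserved
+// ahead of the profiling prologue so counter/trace packets measure only the
+// real dispatch; the dispatch header stays INVALID until the patch kernel
+// publishes it on device.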
+static iree_status_t
+iree_hal_amdgpu_aql_block_processor_profile_emit_indirect_dispatch(
+    const iree_hal_amdgpu_aql_block_processor_profile_t* processor,
+    const iree_hal_amdgpu_command_buffer_dispatch_command_t* dispatch_command,
+    iree_hal_amdgpu_aql_block_processor_profile_state_t* state) {
+  const iree_hal_amdgpu_aql_block_processor_dispatch_profile_t profile =
+      iree_hal_amdgpu_aql_block_processor_profile_dispatch_profile(
+          processor, state, /*recorded_packet_count=*/2,
+          /*dispatch_packet_ordinal=*/state->packets.recorded + 1u);
+  const iree_hal_amdgpu_host_queue_command_buffer_packet_flags_t
+      profile_packet_flags =
+          iree_hal_amdgpu_aql_block_processor_profile_packet_flags(profile);
+  const iree_hal_amdgpu_aql_block_processor_dispatch_profile_flags_t
+      profile_flags = profile.flags;
+
+  const uint32_t patch_packet_index = state->packets.emitted++;
+  if (iree_any_bit_set(
+          profile_flags,
+          IREE_HAL_AMDGPU_AQL_BLOCK_PROCESSOR_DISPATCH_PROFILE_FLAG_COUNTER_PACKETS)) {
+    iree_hal_amdgpu_aql_block_processor_profile_emit_counter_starts(
+        processor, profile.event_position,
+        iree_hal_amdgpu_aql_block_processor_profile_packet_control(
+            processor, state->packets.emitted, IREE_HSA_FENCE_SCOPE_NONE,
+            iree_hal_amdgpu_aql_block_processor_profile_agent_barrier_packet_flags()),
+        state);
+  }
+  if (iree_any_bit_set(
+          profile_flags,
+          IREE_HAL_AMDGPU_AQL_BLOCK_PROCESSOR_DISPATCH_PROFILE_FLAG_TRACE_PACKETS)) {
+    iree_hal_amdgpu_aql_block_processor_profile_emit_trace_start(
+        processor, profile.event_position,
+        iree_hal_amdgpu_aql_block_processor_profile_agent_barrier_packet_flags(),
+        state);
+  }
+  const uint32_t dispatch_packet_index = state->packets.emitted;
+  iree_hal_amdgpu_aql_packet_t* patch_packet =
+      iree_hal_amdgpu_aql_block_processor_profile_packet(processor,
+                                                         patch_packet_index);
+  iree_hal_amdgpu_aql_packet_t* dispatch_packet =
+      iree_hal_amdgpu_aql_block_processor_profile_packet(processor,
+                                                         dispatch_packet_index);
+  const iree_hsa_signal_t completion_signal =
+      iree_hal_amdgpu_aql_block_processor_profile_completion_signal(
+          processor, state, profile);
+  const uint16_t dispatch_header = iree_hal_amdgpu_aql_make_header(
+      IREE_HSA_PACKET_TYPE_KERNEL_DISPATCH,
+      iree_hal_amdgpu_aql_block_processor_profile_packet_control(
+          processor, dispatch_packet_index, IREE_HSA_FENCE_SCOPE_NONE,
+          profile_packet_flags));
+
+  iree_status_t status =
+      iree_hal_amdgpu_aql_block_processor_profile_replay_indirect_dispatch_packet_bodies(
+          processor, dispatch_command, patch_packet, dispatch_packet,
+          processor->kernargs.blocks[state->kernargs.block].data,
+          processor->kernargs.blocks[state->kernargs.block + 1].data,
+          completion_signal, dispatch_header,
+          &processor->packets.setups[patch_packet_index],
+          &processor->packets.setups[dispatch_packet_index]);
+  if (iree_status_is_ok(status)) {
+    const iree_hal_amdgpu_host_queue_command_buffer_packet_flags_t patch_flags =
+        iree_hal_amdgpu_aql_block_processor_profile_agent_barrier_packet_flags();
+    const iree_hsa_fence_scope_t patch_acquire_scope =
+        iree_hal_amdgpu_aql_block_processor_profile_payload_acquire_scope(
+            processor, state, patch_packet_index, &dispatch_command->header,
+            patch_flags);
+    iree_hal_amdgpu_aql_block_processor_profile_emit_source(processor, profile,
+                                                            state);
+    // The patch dispatch writes the following dispatch packet's header, so it
+    // must retire before the command processor observes that slot.
+    processor->packets.headers[patch_packet_index] =
+        iree_hal_amdgpu_aql_make_header(
+            IREE_HSA_PACKET_TYPE_KERNEL_DISPATCH,
+            iree_hal_amdgpu_aql_block_processor_profile_packet_control(
+                processor, patch_packet_index, patch_acquire_scope,
+                patch_flags));
+    // The patch dispatch publishes the target dispatch header after it has
+    // updated dynamic workgroup counts on device.
+    processor->packets.headers[dispatch_packet_index] =
+        IREE_HSA_PACKET_TYPE_INVALID;
+    state->packets.emitted = dispatch_packet_index + /*dispatch packet=*/1;
+    state->packets.recorded += 2;
+    state->kernargs.block +=
+        iree_hal_amdgpu_aql_block_processor_profile_dispatch_kernarg_block_count(
+            dispatch_command);
+    if (iree_any_bit_set(
+            profile_flags,
+            IREE_HAL_AMDGPU_AQL_BLOCK_PROCESSOR_DISPATCH_PROFILE_FLAG_TRACE_PACKETS)) {
+      iree_hal_amdgpu_aql_block_processor_profile_emit_trace_stop(
+          processor, profile.event_position, state);
+    }
+    if (iree_any_bit_set(
+            profile_flags,
+            IREE_HAL_AMDGPU_AQL_BLOCK_PROCESSOR_DISPATCH_PROFILE_FLAG_COUNTER_PACKETS)) {
+      iree_hal_amdgpu_aql_block_processor_profile_emit_counter_read_stops(
+          processor, profile.event_position, state);
+    }
+  }
+  return status;
+}
+
+static iree_status_t iree_hal_amdgpu_aql_block_processor_profile_emit_dispatch(
+    const iree_hal_amdgpu_aql_block_processor_profile_t* processor,
+    const iree_hal_amdgpu_command_buffer_dispatch_command_t* dispatch_command,
+    iree_hal_amdgpu_aql_block_processor_profile_state_t* state) {
+  if (iree_hal_amdgpu_aql_block_processor_profile_dispatch_uses_indirect(
+          dispatch_command)) {
+    return iree_hal_amdgpu_aql_block_processor_profile_emit_indirect_dispatch(
+        processor, dispatch_command, state);
+  }
+  return iree_hal_amdgpu_aql_block_processor_profile_emit_direct_dispatch(
+      processor, dispatch_command, state);
+}
+
+static iree_status_t iree_hal_amdgpu_aql_block_processor_profile_emit_transfer(
+    const iree_hal_amdgpu_aql_block_processor_profile_t* processor,
+    const iree_hal_amdgpu_command_buffer_command_header_t* command,
+    iree_hal_amdgpu_aql_block_processor_profile_state_t* state) {
+  const uint32_t packet_index = state->packets.emitted;
+  iree_hal_amdgpu_host_queue_command_buffer_packet_flags_t packet_flags =
+      iree_hal_amdgpu_aql_block_processor_profile_command_packet_flags(command);
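+  // The FINAL flag may only be carried by this packet when no trailing
+  // profile packets (dispatch timestamps or the whole-block queue-device
+  // event) follow it in the reserved span.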
+  if (iree_hal_amdgpu_aql_block_processor_profile_is_block_final(
+          processor, state, /*recorded_packet_count=*/1) &&
+      !iree_any_bit_set(
+          processor->flags,
+          IREE_HAL_AMDGPU_AQL_BLOCK_PROCESSOR_PROFILE_FLAG_DISPATCH_PACKETS |
+              IREE_HAL_AMDGPU_AQL_BLOCK_PROCESSOR_PROFILE_FLAG_QUEUE_DEVICE_EVENT)) {
+    packet_flags |= IREE_HAL_AMDGPU_HOST_QUEUE_COMMAND_BUFFER_PACKET_FLAG_FINAL;
+  }
+  iree_hal_amdgpu_aql_packet_t* packet =
+      iree_hal_amdgpu_aql_block_processor_profile_packet(processor,
+                                                         packet_index);
+  const iree_hsa_signal_t completion_signal = iree_hsa_signal_null();
+
+  iree_status_t status = iree_ok_status();
+  if (command->opcode == IREE_HAL_AMDGPU_COMMAND_BUFFER_OPCODE_FILL) {
+    status =
+        iree_hal_amdgpu_aql_block_processor_profile_replay_fill_packet_body(
+            processor,
+            (const iree_hal_amdgpu_command_buffer_fill_command_t*)command,
+            packet, &processor->kernargs.blocks[state->kernargs.block],
+            completion_signal, &processor->packets.setups[packet_index]);
+  } else if (command->opcode == IREE_HAL_AMDGPU_COMMAND_BUFFER_OPCODE_COPY) {
+    status =
+        iree_hal_amdgpu_aql_block_processor_profile_replay_copy_packet_body(
+            processor,
+            (const iree_hal_amdgpu_command_buffer_copy_command_t*)command,
+            packet, &processor->kernargs.blocks[state->kernargs.block],
+            completion_signal, &processor->packets.setups[packet_index]);
+  } else {
+    const iree_hal_amdgpu_command_buffer_update_command_t* update_command =
+        (const iree_hal_amdgpu_command_buffer_update_command_t*)command;
+    const iree_host_size_t kernarg_length =
+        iree_hal_amdgpu_aql_block_processor_profile_update_kernarg_length(
+            update_command->length);
+    const iree_host_size_t kernarg_block_count = iree_host_size_ceil_div(
+        kernarg_length, sizeof(iree_hal_amdgpu_kernarg_block_t));
+    status =
+        iree_hal_amdgpu_aql_block_processor_profile_replay_update_packet_body(
+            processor, update_command, packet,
+            processor->kernargs.blocks[state->kernargs.block].data,
+            kernarg_block_count * sizeof(iree_hal_amdgpu_kernarg_block_t),
+            completion_signal, &processor->packets.setups[packet_index]);
+    if (iree_status_is_ok(status)) {
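+      // Advance past the inline update payload now; the shared success path
+      // below adds the final +1 for this packet's kernarg block.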
+      state->kernargs.block += (uint32_t)kernarg_block_count - 1u;
+    }
+  }
+  if (iree_status_is_ok(status)) {
+    processor->packets.headers[packet_index] = iree_hal_amdgpu_aql_make_header(
+        IREE_HSA_PACKET_TYPE_KERNEL_DISPATCH,
+        iree_hal_amdgpu_aql_block_processor_profile_packet_control(
+            processor, packet_index,
+            iree_hal_amdgpu_aql_block_processor_profile_payload_acquire_scope(
+                processor, state, packet_index, command, packet_flags),
+            packet_flags));
+    ++state->packets.emitted;
+    ++state->packets.recorded;
+    ++state->kernargs.block;
+  }
+  return status;
+}
+
+void iree_hal_amdgpu_aql_block_processor_profile_initialize(
+    const iree_hal_amdgpu_aql_block_processor_profile_t* params,
+    iree_hal_amdgpu_aql_block_processor_profile_t* out_processor) {
+  *out_processor = *params;
+}
+
+void iree_hal_amdgpu_aql_block_processor_profile_deinitialize(
+    iree_hal_amdgpu_aql_block_processor_profile_t* processor) {
+  memset(processor, 0, sizeof(*processor));
+}
+
+iree_status_t iree_hal_amdgpu_aql_block_processor_profile_invoke(
+    const iree_hal_amdgpu_aql_block_processor_profile_t* processor,
+    iree_hal_amdgpu_aql_block_processor_profile_result_t* out_result) {
+  memset(out_result, 0, sizeof(*out_result));
+  if (IREE_UNLIKELY(processor->profile.dispatches.count !=
+                    processor->profile.dispatch_events.event_count)) {
+    return iree_make_status(IREE_STATUS_INVALID_ARGUMENT,
+                            "AQL command-buffer block %" PRIu32
+                            " has %u profile dispatch sidecars "
+                            "but reserved %u profile dispatch events",
+                            processor->block->block_ordinal,
+                            processor->profile.dispatches.count,
+                            processor->profile.dispatch_events.event_count);
+  }
+  const iree_hal_amdgpu_command_buffer_command_header_t* command =
+      iree_hal_amdgpu_command_buffer_block_commands_const(processor->block);
+  iree_hal_amdgpu_aql_block_processor_profile_state_t state = {0};
+  bool reached_terminator = false;
+  iree_status_t status = iree_ok_status();
+  for (uint16_t i = 0; i < processor->block->command_count &&
+                       iree_status_is_ok(status) && !reached_terminator;
+       ++i) {
+    switch (command->opcode) {
+      case IREE_HAL_AMDGPU_COMMAND_BUFFER_OPCODE_BARRIER:
+        break;
+      case IREE_HAL_AMDGPU_COMMAND_BUFFER_OPCODE_DISPATCH: {
+        const iree_hal_amdgpu_command_buffer_dispatch_command_t*
+            dispatch_command =
+                (const iree_hal_amdgpu_command_buffer_dispatch_command_t*)
+                    command;
+        status = iree_hal_amdgpu_aql_block_processor_profile_emit_dispatch(
+            processor, dispatch_command, &state);
+        break;
+      }
+      case IREE_HAL_AMDGPU_COMMAND_BUFFER_OPCODE_FILL:
+      case IREE_HAL_AMDGPU_COMMAND_BUFFER_OPCODE_COPY:
+      case IREE_HAL_AMDGPU_COMMAND_BUFFER_OPCODE_UPDATE:
+        status = iree_hal_amdgpu_aql_block_processor_profile_emit_transfer(
+            processor, command, &state);
+        break;
+      case IREE_HAL_AMDGPU_COMMAND_BUFFER_OPCODE_RETURN:
+        out_result->terminator =
+            IREE_HAL_AMDGPU_AQL_BLOCK_PROCESSOR_PROFILE_TERMINATOR_RETURN;
+        reached_terminator = true;
+        break;
+      case IREE_HAL_AMDGPU_COMMAND_BUFFER_OPCODE_BRANCH: {
+        const iree_hal_amdgpu_command_buffer_branch_command_t* branch_command =
+            (const iree_hal_amdgpu_command_buffer_branch_command_t*)command;
+        out_result->terminator =
+            IREE_HAL_AMDGPU_AQL_BLOCK_PROCESSOR_PROFILE_TERMINATOR_BRANCH;
+        out_result->target_block_ordinal = branch_command->target_block_ordinal;
+        reached_terminator = true;
+        break;
+      }
+      case IREE_HAL_AMDGPU_COMMAND_BUFFER_OPCODE_PROFILE_MARKER:
+      case IREE_HAL_AMDGPU_COMMAND_BUFFER_OPCODE_COND_BRANCH:
+        status = iree_make_status(
+            IREE_STATUS_UNIMPLEMENTED,
+            "AQL command-buffer opcode %u replay not yet wired",
+            command->opcode);
+        break;
+      default:
+        status = iree_make_status(IREE_STATUS_INVALID_ARGUMENT,
+                                  "malformed AQL command-buffer opcode %u",
+                                  command->opcode);
+        break;
+    }
+    if (iree_status_is_ok(status) && !reached_terminator) {
+      command = iree_hal_amdgpu_command_buffer_command_next_const(command);
+    }
+  }
+  if (iree_status_is_ok(status) && !reached_terminator) {
+    status = iree_make_status(IREE_STATUS_INVALID_ARGUMENT,
+                              "AQL command-buffer block %" PRIu32
+                              " has no terminator",
+                              processor->block->block_ordinal);
+  }
+  if (iree_status_is_ok(status) &&
+      state.packets.recorded != processor->block->aql_packet_count) {
+    status = iree_make_status(
+        IREE_STATUS_INVALID_ARGUMENT,
+        "AQL command-buffer block %" PRIu32 " consumed %" PRIu32
+        " packets but declares %" PRIu32,
+        processor->block->block_ordinal, state.packets.recorded,
+        processor->block->aql_packet_count);
+  }
+  if (iree_status_is_ok(status) &&
+      state.packets.emitted != processor->packets.count) {
+    status =
+        iree_make_status(IREE_STATUS_INVALID_ARGUMENT,
+                         "AQL command-buffer block %" PRIu32 " emitted %" PRIu32
+                         " payload packets but reserved %u",
+                         processor->block->block_ordinal, state.packets.emitted,
+                         processor->packets.count);
+  }
+  if (iree_status_is_ok(status) &&
+      state.kernargs.block != processor->kernargs.count) {
+    status =
+        iree_make_status(IREE_STATUS_INVALID_ARGUMENT,
+                         "AQL command-buffer block %" PRIu32 " emitted %" PRIu32
+                         " kernarg blocks but reserved %" PRIu32,
+                         processor->block->block_ordinal, state.kernargs.block,
+                         processor->kernargs.count);
+  }
+  if (iree_status_is_ok(status) &&
+      state.profile.event != processor->profile.dispatch_events.event_count) {
+    status =
+        iree_make_status(IREE_STATUS_INVALID_ARGUMENT,
+                         "AQL command-buffer block %" PRIu32 " emitted %" PRIu32
+                         " profile dispatch events but reserved %u",
+                         processor->block->block_ordinal, state.profile.event,
+                         processor->profile.dispatch_events.event_count);
+  }
+  out_result->packets.recorded = state.packets.recorded;
+  out_result->packets.emitted = state.packets.emitted;
+  out_result->kernargs.consumed = state.kernargs.block;
+  out_result->profile.events = state.profile.event;
+  return status;
+}
diff --git a/runtime/src/iree/hal/drivers/amdgpu/aql_block_processor_profile.h b/runtime/src/iree/hal/drivers/amdgpu/aql_block_processor_profile.h
new file mode 100644
index 0000000..0ba18d6
--- /dev/null
+++ b/runtime/src/iree/hal/drivers/amdgpu/aql_block_processor_profile.h
@@ -0,0 +1,164 @@
+// Copyright 2026 The IREE Authors
+//
+// Licensed under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+
+#ifndef IREE_HAL_DRIVERS_AMDGPU_AQL_BLOCK_PROCESSOR_PROFILE_H_
+#define IREE_HAL_DRIVERS_AMDGPU_AQL_BLOCK_PROCESSOR_PROFILE_H_
+
+#include "iree/base/api.h"
+#include "iree/hal/api.h"
+#include "iree/hal/drivers/amdgpu/aql_command_buffer.h"
+#include "iree/hal/drivers/amdgpu/host_queue.h"
+#include "iree/hal/drivers/amdgpu/host_queue_profile_events.h"
+#include "iree/hal/drivers/amdgpu/host_queue_submission.h"
+#include "iree/hal/drivers/amdgpu/util/kernarg_ring.h"
+
+#ifdef __cplusplus
+extern "C" {
+#endif  // __cplusplus
+
+typedef uint32_t iree_hal_amdgpu_aql_block_processor_profile_flags_t;
+enum iree_hal_amdgpu_aql_block_processor_profile_flag_bits_t {
+  IREE_HAL_AMDGPU_AQL_BLOCK_PROCESSOR_PROFILE_FLAG_NONE = 0u,
+  // This block reserves dispatch timestamp events.
+  IREE_HAL_AMDGPU_AQL_BLOCK_PROCESSOR_PROFILE_FLAG_DISPATCH_PACKETS = 1u << 0,
+  // This block reserves a whole-block queue-device timestamp event.
+  IREE_HAL_AMDGPU_AQL_BLOCK_PROCESSOR_PROFILE_FLAG_QUEUE_DEVICE_EVENT = 1u << 1,
+};
+
+typedef uint8_t iree_hal_amdgpu_aql_block_processor_profile_terminator_t;
+enum iree_hal_amdgpu_aql_block_processor_profile_terminator_e {
+  IREE_HAL_AMDGPU_AQL_BLOCK_PROCESSOR_PROFILE_TERMINATOR_NONE = 0u,
+  IREE_HAL_AMDGPU_AQL_BLOCK_PROCESSOR_PROFILE_TERMINATOR_RETURN = 1u,
+  IREE_HAL_AMDGPU_AQL_BLOCK_PROCESSOR_PROFILE_TERMINATOR_BRANCH = 2u,
+};
+
+// Selected command-buffer dispatch for host profile packet augmentation.
+typedef struct iree_hal_amdgpu_aql_block_processor_profile_dispatch_t {
+  // Retained dispatch summary selected by the active capture filter.
+  const iree_hal_amdgpu_aql_command_buffer_dispatch_summary_t* summary;
+} iree_hal_amdgpu_aql_block_processor_profile_dispatch_t;
+
+// Selected command-buffer dispatches in command order.
+typedef struct iree_hal_amdgpu_aql_block_processor_profile_dispatch_list_t {
+  // Selected dispatch entries allocated from caller-owned scratch storage.
+  const iree_hal_amdgpu_aql_block_processor_profile_dispatch_t* values;
+  // Number of selected dispatch entries.
+  uint32_t count;
+} iree_hal_amdgpu_aql_block_processor_profile_dispatch_list_t;
+
+// Host-side processor for one profiled AQL command-buffer block.
+typedef struct iree_hal_amdgpu_aql_block_processor_profile_t {
+  // Host queue owning the AQL ring and profiling slot storage.
+  iree_hal_amdgpu_host_queue_t* queue;
+  // Command buffer being replayed.
+  iree_hal_command_buffer_t* command_buffer;
+  // Recorded command-buffer block being emitted.
+  const iree_hal_amdgpu_command_buffer_block_header_t* block;
+  // Submission-level packet control inputs.
+  struct {
+    // Wait resolution prefixing this block submission.
+    const iree_hal_amdgpu_wait_resolution_t* resolution;
+    // Final signal list for this block when it is terminal.
+    iree_hal_semaphore_list_t signal_semaphore_list;
+  } submission;
+  // Queue-execute binding state consumed by dynamic block operands.
+  struct {
+    // Binding table supplied to queue_execute.
+    iree_hal_buffer_binding_table_t table;
+    // Pre-resolved binding pointers indexed by queue_execute binding table
+    // slot.
+    const uint64_t* ptrs;
+  } bindings;
+  // Reserved packet span populated by profiled replay.
+  struct {
+    // First reserved AQL payload packet id after wait/profile prefix packets.
+    uint64_t first_payload_id;
+    // Logical packet index of |first_payload_id| within the full submission.
+    uint32_t index_base;
+    // Number of reserved payload packet slots available to the processor.
+    uint32_t count;
+    // Header words produced by profiled replay and published by the caller.
+    uint16_t* headers;
+    // Setup words produced with |headers|.
+    uint16_t* setups;
+  } packets;
+  // Reserved queue-owned kernarg storage consumed by profiled replay.
+  struct {
+    // First reserved kernarg block.
+    iree_hal_amdgpu_kernarg_block_t* blocks;
+    // Number of reserved kernarg blocks.
+    uint32_t count;
+  } kernargs;
+  // Queue-owned payload visibility requirements.
+  struct {
+    // Minimum acquire scope for replayed payload packets in this block.
+    iree_hsa_fence_scope_t acquire_scope;
+    // Number of leading recorded payload packets requiring |acquire_scope|.
+    uint32_t acquire_packet_count;
+  } payload;
+  // Host profile sidecars consumed by profiled replay.
+  struct {
+    // Selected dispatches that receive profile packet augmentation.
+    iree_hal_amdgpu_aql_block_processor_profile_dispatch_list_t dispatches;
+    // Dispatch timestamp event reservation for profiled dispatches.
+    iree_hal_amdgpu_profile_dispatch_event_reservation_t dispatch_events;
+    // Harvest sources written when dispatch timestamp profiling is active.
+    iree_hal_amdgpu_profile_dispatch_harvest_source_t* harvest_sources;
+    // Stable command buffer id used in emitted profile records.
+    uint64_t command_buffer_id;
+    // Number of counter sets emitted around each profiled dispatch.
+    uint32_t counter_set_count;
+    // Number of executable trace packets emitted across this block.
+    uint32_t trace_packet_count;
+  } profile;
+  // Flags from iree_hal_amdgpu_aql_block_processor_profile_flag_bits_t.
+  iree_hal_amdgpu_aql_block_processor_profile_flags_t flags;
+} iree_hal_amdgpu_aql_block_processor_profile_t;
+
+// Result of invoking the profiled processor on one block.
+typedef struct iree_hal_amdgpu_aql_block_processor_profile_result_t {
+  // Packet accounting reported by the processor.
+  struct {
+    // Number of recorded block AQL packets consumed.
+    uint32_t recorded;
+    // Number of reserved AQL packets populated.
+    uint32_t emitted;
+  } packets;
+  // Kernarg accounting reported by the processor.
+  struct {
+    // Number of reserved kernarg blocks consumed.
+    uint32_t consumed;
+  } kernargs;
+  // Profile accounting reported by the processor.
+  struct {
+    // Number of dispatch profile events emitted.
+    uint32_t events;
+  } profile;
+  // Terminator kind reached by this invocation.
+  iree_hal_amdgpu_aql_block_processor_profile_terminator_t terminator;
+  // Branch target block ordinal when |terminator| is BRANCH.
+  uint32_t target_block_ordinal;
+} iree_hal_amdgpu_aql_block_processor_profile_result_t;
+
+// Initializes |out_processor| from |params|; all referenced storage is
+// borrowed from the caller and must remain live until deinitialization.
+void iree_hal_amdgpu_aql_block_processor_profile_initialize(
+    const iree_hal_amdgpu_aql_block_processor_profile_t* params,
+    iree_hal_amdgpu_aql_block_processor_profile_t* out_processor);
+
+// Deinitializes |processor|. This currently releases no resources.
+void iree_hal_amdgpu_aql_block_processor_profile_deinitialize(
+    iree_hal_amdgpu_aql_block_processor_profile_t* processor);
+
+// Invokes |processor| and populates reserved packet/kernarg/profile storage.
+iree_status_t iree_hal_amdgpu_aql_block_processor_profile_invoke(
+    const iree_hal_amdgpu_aql_block_processor_profile_t* processor,
+    iree_hal_amdgpu_aql_block_processor_profile_result_t* out_result);
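+
+// Example (illustrative; reservation sizing and queue plumbing elided, and
+// |queue|/|command_buffer|/|block| stand in for caller-provided values):
+//   iree_hal_amdgpu_aql_block_processor_profile_t params = {0};
+//   params.queue = queue;
+//   params.command_buffer = command_buffer;
+//   params.block = block;
+//   // ... populate packets/kernargs/payload/profile reservations ...
+//   iree_hal_amdgpu_aql_block_processor_profile_t processor;
+//   iree_hal_amdgpu_aql_block_processor_profile_initialize(&params,
+//                                                          &processor);
+//   iree_hal_amdgpu_aql_block_processor_profile_result_t result;
+//   iree_status_t status =
+//       iree_hal_amdgpu_aql_block_processor_profile_invoke(&processor,
+//                                                          &result);
+//   // On success the caller publishes processor.packets.headers to the ring.
+//   iree_hal_amdgpu_aql_block_processor_profile_deinitialize(&processor);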
+
+#ifdef __cplusplus
+}  // extern "C"
+#endif  // __cplusplus
+
+#endif  // IREE_HAL_DRIVERS_AMDGPU_AQL_BLOCK_PROCESSOR_PROFILE_H_
diff --git a/runtime/src/iree/hal/drivers/amdgpu/aql_block_processor_test.cc b/runtime/src/iree/hal/drivers/amdgpu/aql_block_processor_test.cc
new file mode 100644
index 0000000..3e0d195
--- /dev/null
+++ b/runtime/src/iree/hal/drivers/amdgpu/aql_block_processor_test.cc
@@ -0,0 +1,1105 @@
+// Copyright 2026 The IREE Authors
+//
+// Licensed under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+
+#include "iree/hal/drivers/amdgpu/aql_block_processor.h"
+
+#include <array>
+#include <cstring>
+#include <memory>
+
+#include "iree/hal/drivers/amdgpu/abi/queue.h"
+#include "iree/hal/drivers/amdgpu/buffer.h"
+#include "iree/hal/drivers/amdgpu/device/dispatch.h"
+#include "iree/hal/drivers/amdgpu/util/aql_emitter.h"
+#include "iree/testing/gtest.h"
+#include "iree/testing/status_matchers.h"
+
+namespace iree::hal::amdgpu {
+namespace {
+
+struct ReturnBlock {
+  // Block header at the ABI-defined block base.
+  iree_hal_amdgpu_command_buffer_block_header_t header;
+  // Single return terminator command.
+  iree_hal_amdgpu_command_buffer_return_command_t return_command;
+};
+
+struct BranchBlock {
+  // Block header at the ABI-defined block base.
+  iree_hal_amdgpu_command_buffer_block_header_t header;
+  // Single branch terminator command.
+  iree_hal_amdgpu_command_buffer_branch_command_t branch_command;
+};
+
+struct DirectDispatchBlock {
+  // Block header at the ABI-defined block base.
+  iree_hal_amdgpu_command_buffer_block_header_t header;
+  // Custom-direct dispatch command under test.
+  iree_hal_amdgpu_command_buffer_dispatch_command_t dispatch_command;
+  // Inline custom-direct kernarg tail copied into queue-owned kernargs.
+  uint64_t tail[2];
+  // Return terminator following the dispatch command and inline tail.
+  iree_hal_amdgpu_command_buffer_return_command_t return_command;
+};
+
+struct IndirectDispatchBlock {
+  // Block header at the ABI-defined block base.
+  iree_hal_amdgpu_command_buffer_block_header_t header;
+  // Custom indirect dispatch command under test.
+  iree_hal_amdgpu_command_buffer_dispatch_command_t dispatch_command;
+  // Return terminator following the dispatch command.
+  iree_hal_amdgpu_command_buffer_return_command_t return_command;
+  // Indirect parameter source referenced by the dispatch command.
+  iree_hal_amdgpu_command_buffer_binding_source_t indirect_params_source;
+};
+
+template <uint32_t DispatchCount>
+struct DispatchBlock {
+  // Block header at the ABI-defined block base.
+  iree_hal_amdgpu_command_buffer_block_header_t header;
+  // Direct dispatch commands recorded in this block.
+  iree_hal_amdgpu_command_buffer_dispatch_command_t
+      dispatch_commands[DispatchCount];
+  // Return terminator following the dispatch commands.
+  iree_hal_amdgpu_command_buffer_return_command_t return_command;
+};
+
+struct MalformedBlock {
+  // Block header at the ABI-defined block base.
+  iree_hal_amdgpu_command_buffer_block_header_t header;
+  // Non-terminating barrier command used to exercise validation.
+  iree_hal_amdgpu_command_buffer_barrier_command_t barrier_command;
+};
+
+struct PacketHeaderSummary {
+  // Counts accumulated across emitted packet headers.
+  struct {
+    // Number of emitted packet headers summarized.
+    uint32_t total;
+    // Number of emitted packet headers carrying the AQL barrier bit.
+    uint32_t barrier;
+    // Number of emitted packet headers with SYSTEM acquire scope.
+    uint32_t system_acquire;
+    // Number of emitted packet headers with SYSTEM release scope.
+    uint32_t system_release;
+  } counts;
+  // Boundary packet headers from the emitted span.
+  struct {
+    // First emitted packet header, or zero for empty spans.
+    uint16_t first;
+    // Last emitted packet header, or zero for empty spans.
+    uint16_t last;
+  } headers;
+};
+
+struct CommandBufferDeleter {
+  void operator()(iree_hal_command_buffer_t* command_buffer) const {
+    iree_hal_command_buffer_release(command_buffer);
+  }
+};
+
+using CommandBufferPtr =
+    std::unique_ptr<iree_hal_command_buffer_t, CommandBufferDeleter>;
+
+struct BufferDeleter {
+  void operator()(iree_hal_buffer_t* buffer) const {
+    iree_hal_buffer_release(buffer);
+  }
+};
+
+using BufferPtr = std::unique_ptr<iree_hal_buffer_t, BufferDeleter>;
+
+constexpr uint64_t kFillBlockX16KernelObject = 0xF160u;
+constexpr uint64_t kCopyBlockX16KernelObject = 0xC160u;
+constexpr uint64_t kPatchIndirectParamsKernelObject = 0x1D1EC7u;
+
+static void InitializeBlockHeader(
+    uint32_t block_length, uint32_t command_length, uint16_t command_count,
+    uint32_t aql_packet_count, uint32_t kernarg_length,
+    iree_hal_amdgpu_command_buffer_block_header_t* out_header) {
+  std::memset(out_header, 0, sizeof(*out_header));
+  out_header->magic = IREE_HAL_AMDGPU_COMMAND_BUFFER_BLOCK_MAGIC;
+  out_header->version = IREE_HAL_AMDGPU_COMMAND_BUFFER_BLOCK_VERSION_0;
+  out_header->header_length = sizeof(*out_header);
+  out_header->block_length = block_length;
+  out_header->command_offset = sizeof(*out_header);
+  out_header->command_length = command_length;
+  out_header->command_count = command_count;
+  out_header->aql_packet_count = aql_packet_count;
+  out_header->kernarg_length = kernarg_length;
+  out_header->initial_barrier_packet_count = aql_packet_count;
+  out_header->binding_source_offset = block_length;
+  out_header->rodata_offset = block_length;
+}
+
+static void InitializeReturnCommand(
+    iree_hal_amdgpu_command_buffer_return_command_t* out_command) {
+  std::memset(out_command, 0, sizeof(*out_command));
+  out_command->header.opcode = IREE_HAL_AMDGPU_COMMAND_BUFFER_OPCODE_RETURN;
+  out_command->header.length_qwords =
+      sizeof(*out_command) / IREE_HAL_AMDGPU_COMMAND_BUFFER_RECORD_ALIGNMENT;
+}
+
+static void SetReturnTerminator(
+    iree_hal_amdgpu_command_buffer_block_header_t* block_header) {
+  block_header->terminator_opcode =
+      IREE_HAL_AMDGPU_COMMAND_BUFFER_OPCODE_RETURN;
+  block_header->terminator_target_block_ordinal = 0;
+}
+
+static void SetBranchTerminator(
+    uint32_t target_block_ordinal,
+    iree_hal_amdgpu_command_buffer_block_header_t* block_header) {
+  block_header->terminator_opcode =
+      IREE_HAL_AMDGPU_COMMAND_BUFFER_OPCODE_BRANCH;
+  block_header->terminator_target_block_ordinal = target_block_ordinal;
+}
+
+static uint8_t CommandFlags(uint8_t flags, iree_hsa_fence_scope_t acquire_scope,
+                            iree_hsa_fence_scope_t release_scope) {
+  return iree_hal_amdgpu_command_buffer_command_flags_set_fence_scopes(
+      flags, (uint8_t)acquire_scope, (uint8_t)release_scope);
+}
+
+static void InitializeDirectDispatchCommand(
+    uint32_t command_index, uint8_t command_flags,
+    iree_hal_amdgpu_command_buffer_dispatch_command_t* out_command) {
+  std::memset(out_command, 0, sizeof(*out_command));
+  out_command->header.opcode = IREE_HAL_AMDGPU_COMMAND_BUFFER_OPCODE_DISPATCH;
+  out_command->header.flags = command_flags;
+  out_command->header.length_qwords =
+      sizeof(*out_command) / IREE_HAL_AMDGPU_COMMAND_BUFFER_RECORD_ALIGNMENT;
+  out_command->header.command_index = command_index;
+  out_command->kernel_object = 0xABCDEF0000000000ull + command_index;
+  out_command->payload_reference = sizeof(*out_command);
+  out_command->kernarg_strategy =
+      IREE_HAL_AMDGPU_COMMAND_BUFFER_KERNARG_STRATEGY_CUSTOM_DIRECT;
+  out_command->implicit_args_offset_qwords = UINT16_MAX;
+  out_command->setup = 3;
+  out_command->workgroup_size[0] = 1;
+  out_command->workgroup_size[1] = 1;
+  out_command->workgroup_size[2] = 1;
+  out_command->grid_size[0] = 1;
+  out_command->grid_size[1] = 1;
+  out_command->grid_size[2] = 1;
+}
+
+static iree_hal_amdgpu_device_kernel_args_t MakeKernelArgs(
+    uint64_t kernel_object, uint16_t setup, uint16_t workgroup_size_x,
+    uint32_t private_segment_size, uint32_t group_segment_size) {
+  iree_hal_amdgpu_device_kernel_args_t kernel_args = {};
+  kernel_args.kernel_object = kernel_object;
+  kernel_args.kernarg_size = 32;
+  kernel_args.kernarg_alignment = 8;
+  kernel_args.setup = setup;
+  kernel_args.workgroup_size[0] = workgroup_size_x;
+  kernel_args.workgroup_size[1] = 1;
+  kernel_args.workgroup_size[2] = 1;
+  kernel_args.private_segment_size = private_segment_size;
+  kernel_args.group_segment_size = group_segment_size;
+  return kernel_args;
+}
+
+static iree_hal_amdgpu_device_kernels_t MakeTransferKernels() {
+  iree_hal_amdgpu_device_kernels_t kernels = {};
+  kernels.iree_hal_amdgpu_device_buffer_fill_block_x16 =
+      MakeKernelArgs(kFillBlockX16KernelObject, 5, 32, 8, 12);
+  kernels.iree_hal_amdgpu_device_buffer_copy_block_x16 =
+      MakeKernelArgs(kCopyBlockX16KernelObject, 10, 32, 13, 17);
+  kernels.iree_hal_amdgpu_device_dispatch_patch_indirect_params =
+      MakeKernelArgs(kPatchIndirectParamsKernelObject, 12, 1, 3, 7);
+  return kernels;
+}
+
+static iree_hal_amdgpu_device_buffer_transfer_context_t MakeTransferContext(
+    const iree_hal_amdgpu_device_kernels_t* kernels) {
+  iree_hal_amdgpu_device_buffer_transfer_context_t context = {};
+  iree_hal_amdgpu_device_buffer_transfer_context_initialize(
+      kernels, /*compute_unit_count=*/4, /*wavefront_size=*/64, &context);
+  return context;
+}
+
+static void NoOpBufferRelease(void* user_data, iree_hal_buffer_t* buffer) {
+  (void)user_data;
+  (void)buffer;
+}
+
+static iree_hal_buffer_binding_t MakeBinding(iree_hal_buffer_t* buffer,
+                                             iree_device_size_t offset,
+                                             iree_device_size_t length) {
+  iree_hal_buffer_binding_t binding = {};
+  binding.buffer = buffer;
+  binding.offset = offset;
+  binding.length = length;
+  return binding;
+}
+
+static ReturnBlock MakeReturnBlock() {
+  ReturnBlock block;
+  InitializeBlockHeader(sizeof(block), sizeof(block.return_command),
+                        /*command_count=*/1, /*aql_packet_count=*/0,
+                        /*kernarg_length=*/0, &block.header);
+  InitializeReturnCommand(&block.return_command);
+  SetReturnTerminator(&block.header);
+  return block;
+}
+
+static BranchBlock MakeBranchBlock(uint32_t target_block_ordinal) {
+  BranchBlock block;
+  InitializeBlockHeader(sizeof(block), sizeof(block.branch_command),
+                        /*command_count=*/1, /*aql_packet_count=*/0,
+                        /*kernarg_length=*/0, &block.header);
+  std::memset(&block.branch_command, 0, sizeof(block.branch_command));
+  block.branch_command.header.opcode =
+      IREE_HAL_AMDGPU_COMMAND_BUFFER_OPCODE_BRANCH;
+  block.branch_command.header.length_qwords =
+      sizeof(block.branch_command) /
+      IREE_HAL_AMDGPU_COMMAND_BUFFER_RECORD_ALIGNMENT;
+  block.branch_command.target_block_ordinal = target_block_ordinal;
+  SetBranchTerminator(target_block_ordinal, &block.header);
+  return block;
+}
+
+static MalformedBlock MakeUnterminatedBlock() {
+  MalformedBlock block;
+  InitializeBlockHeader(sizeof(block), sizeof(block.barrier_command),
+                        /*command_count=*/1, /*aql_packet_count=*/0,
+                        /*kernarg_length=*/0, &block.header);
+  std::memset(&block.barrier_command, 0, sizeof(block.barrier_command));
+  block.barrier_command.header.opcode =
+      IREE_HAL_AMDGPU_COMMAND_BUFFER_OPCODE_BARRIER;
+  block.barrier_command.header.length_qwords =
+      sizeof(block.barrier_command) /
+      IREE_HAL_AMDGPU_COMMAND_BUFFER_RECORD_ALIGNMENT;
+  return block;
+}
+
+static DirectDispatchBlock MakeDirectDispatchBlock() {
+  DirectDispatchBlock block;
+  const uint32_t dispatch_command_length =
+      sizeof(block.dispatch_command) + sizeof(block.tail);
+  const uint32_t command_length =
+      dispatch_command_length + sizeof(block.return_command);
+  InitializeBlockHeader(sizeof(block), command_length, /*command_count=*/2,
+                        /*aql_packet_count=*/1,
+                        /*kernarg_length=*/sizeof(block.tail), &block.header);
+
+  std::memset(&block.dispatch_command, 0, sizeof(block.dispatch_command));
+  block.dispatch_command.header.opcode =
+      IREE_HAL_AMDGPU_COMMAND_BUFFER_OPCODE_DISPATCH;
+  block.dispatch_command.header.flags =
+      IREE_HAL_AMDGPU_COMMAND_BUFFER_COMMAND_FLAG_USES_QUEUE_KERNARGS;
+  block.dispatch_command.header.length_qwords =
+      dispatch_command_length / IREE_HAL_AMDGPU_COMMAND_BUFFER_RECORD_ALIGNMENT;
+  block.dispatch_command.kernel_object = 0x123456789ABCDEF0ull;
+  block.dispatch_command.payload_reference = sizeof(block.dispatch_command);
+  block.dispatch_command.kernarg_length_qwords =
+      sizeof(block.tail) / sizeof(uint64_t);
+  block.dispatch_command.payload.tail_length_qwords =
+      sizeof(block.tail) / sizeof(uint64_t);
+  block.dispatch_command.kernarg_strategy =
+      IREE_HAL_AMDGPU_COMMAND_BUFFER_KERNARG_STRATEGY_CUSTOM_DIRECT;
+  block.dispatch_command.implicit_args_offset_qwords = UINT16_MAX;
+  block.dispatch_command.setup = 3;
+  block.dispatch_command.workgroup_size[0] = 4;
+  block.dispatch_command.workgroup_size[1] = 2;
+  block.dispatch_command.workgroup_size[2] = 1;
+  block.dispatch_command.grid_size[0] = 64;
+  block.dispatch_command.grid_size[1] = 8;
+  block.dispatch_command.grid_size[2] = 1;
+  block.dispatch_command.private_segment_size = 128;
+  block.dispatch_command.group_segment_size = 256;
+  block.tail[0] = 0x0A0B0C0D0E0F1011ull;
+  block.tail[1] = 0x1213141516171819ull;
+
+  block.header.dispatch_count = 1;
+  InitializeReturnCommand(&block.return_command);
+  SetReturnTerminator(&block.header);
+  return block;
+}
+
+static IndirectDispatchBlock MakeIndirectDispatchBlock(
+    const uint32_t* workgroup_count) {
+  IndirectDispatchBlock block;
+  const uint32_t command_length =
+      sizeof(block.dispatch_command) + sizeof(block.return_command);
+  InitializeBlockHeader(sizeof(block), command_length, /*command_count=*/2,
+                        /*aql_packet_count=*/2,
+                        /*kernarg_length=*/
+                        2 * sizeof(iree_hal_amdgpu_kernarg_block_t),
+                        &block.header);
+  block.header.dispatch_count = 1;
+  block.header.indirect_dispatch_count = 1;
+  block.header.binding_source_count = 1;
+  block.header.binding_source_offset =
+      offsetof(IndirectDispatchBlock, indirect_params_source);
+
+  std::memset(&block.dispatch_command, 0, sizeof(block.dispatch_command));
+  block.dispatch_command.header.opcode =
+      IREE_HAL_AMDGPU_COMMAND_BUFFER_OPCODE_DISPATCH;
+  block.dispatch_command.header.flags =
+      IREE_HAL_AMDGPU_COMMAND_BUFFER_COMMAND_FLAG_USES_QUEUE_KERNARGS;
+  block.dispatch_command.header.length_qwords =
+      sizeof(block.dispatch_command) /
+      IREE_HAL_AMDGPU_COMMAND_BUFFER_RECORD_ALIGNMENT;
+  block.dispatch_command.kernel_object = 0x123456789ABCDEF0ull;
+  block.dispatch_command.binding_source_offset =
+      block.header.binding_source_offset;
+  block.dispatch_command.dispatch_flags =
+      IREE_HAL_AMDGPU_COMMAND_BUFFER_DISPATCH_FLAG_INDIRECT_PARAMETERS;
+  block.dispatch_command.kernarg_strategy =
+      IREE_HAL_AMDGPU_COMMAND_BUFFER_KERNARG_STRATEGY_CUSTOM_DIRECT;
+  block.dispatch_command.implicit_args_offset_qwords = UINT16_MAX;
+  block.dispatch_command.setup = 3;
+  block.dispatch_command.workgroup_size[0] = 4;
+  block.dispatch_command.workgroup_size[1] = 2;
+  block.dispatch_command.workgroup_size[2] = 1;
+  block.dispatch_command.private_segment_size = 128;
+  block.dispatch_command.group_segment_size = 256;
+
+  InitializeReturnCommand(&block.return_command);
+  SetReturnTerminator(&block.header);
+
+  std::memset(&block.indirect_params_source, 0,
+              sizeof(block.indirect_params_source));
+  block.indirect_params_source.flags =
+      IREE_HAL_AMDGPU_COMMAND_BUFFER_BINDING_SOURCE_FLAG_INDIRECT_PARAMETERS;
+  block.indirect_params_source.offset_or_pointer =
+      (uint64_t)(uintptr_t)workgroup_count;
+  return block;
+}
+
+template <uint32_t DispatchCount>
+static DispatchBlock<DispatchCount> MakeDispatchBlock(
+    const uint8_t (&dispatch_command_flags)[DispatchCount]) {
+  DispatchBlock<DispatchCount> block;
+  const uint32_t command_length =
+      DispatchCount * sizeof(block.dispatch_commands[0]) +
+      sizeof(block.return_command);
+  InitializeBlockHeader(sizeof(block), command_length,
+                        /*command_count=*/DispatchCount + 1,
+                        /*aql_packet_count=*/DispatchCount,
+                        /*kernarg_length=*/DispatchCount *
+                            sizeof(iree_hal_amdgpu_kernarg_block_t),
+                        &block.header);
+  block.header.dispatch_count = DispatchCount;
+  for (uint32_t i = 0; i < DispatchCount; ++i) {
+    InitializeDirectDispatchCommand(i, dispatch_command_flags[i],
+                                    &block.dispatch_commands[i]);
+  }
+  InitializeReturnCommand(&block.return_command);
+  SetReturnTerminator(&block.header);
+  return block;
+}
+
+static iree_hal_amdgpu_aql_block_processor_t MakeProcessor(
+    iree_hal_amdgpu_aql_ring_t* ring, uint32_t packet_count,
+    uint16_t* packet_headers, uint16_t* packet_setups,
+    iree_hal_amdgpu_kernarg_block_t* kernarg_blocks,
+    uint32_t kernarg_block_count,
+    iree_hal_amdgpu_aql_block_processor_flags_t flags =
+        IREE_HAL_AMDGPU_AQL_BLOCK_PROCESSOR_FLAG_NONE,
+    iree_hsa_fence_scope_t inline_acquire_scope = IREE_HSA_FENCE_SCOPE_NONE,
+    iree_hsa_fence_scope_t signal_release_scope = IREE_HSA_FENCE_SCOPE_SYSTEM,
+    iree_hsa_fence_scope_t payload_acquire_scope = IREE_HSA_FENCE_SCOPE_SYSTEM,
+    const iree_hal_amdgpu_device_buffer_transfer_context_t* transfer_context =
+        nullptr,
+    iree_hal_command_buffer_t* command_buffer = nullptr,
+    iree_hal_buffer_binding_table_t binding_table = {0, nullptr}) {
+  iree_hal_amdgpu_aql_block_processor_t processor = {};
+  processor.transfer_context = transfer_context;
+  processor.command_buffer = command_buffer;
+  processor.bindings.table = binding_table;
+  processor.packets.ring = ring;
+  processor.packets.first_id = 4;
+  processor.packets.index_base = 0;
+  processor.packets.count = packet_count;
+  processor.packets.headers = packet_headers;
+  processor.packets.setups = packet_setups;
+  processor.kernargs.blocks = kernarg_blocks;
+  processor.kernargs.count = kernarg_block_count;
+  processor.submission.inline_acquire_scope = inline_acquire_scope;
+  processor.submission.signal_release_scope = signal_release_scope;
+  processor.payload.acquire_scope = payload_acquire_scope;
+  processor.flags = flags;
+  return processor;
+}
+
+static uint32_t KernargBlockCount(
+    const iree_hal_amdgpu_command_buffer_block_header_t* block) {
+  return (uint32_t)iree_host_size_ceil_div(
+      block->kernarg_length, sizeof(iree_hal_amdgpu_kernarg_block_t));
+}
+
+static uint16_t AqlHeaderField(uint16_t header, uint32_t bit_offset,
+                               uint32_t bit_width) {
+  return (header >> bit_offset) & ((1u << bit_width) - 1u);
+}
+
+static bool AqlHeaderHasBarrier(uint16_t header) {
+  return AqlHeaderField(header, IREE_HSA_PACKET_HEADER_BARRIER,
+                        IREE_HSA_PACKET_HEADER_WIDTH_BARRIER) != 0;
+}
+
+static iree_hsa_fence_scope_t AqlHeaderAcquireScope(uint16_t header) {
+  return (iree_hsa_fence_scope_t)AqlHeaderField(
+      header, IREE_HSA_PACKET_HEADER_SCACQUIRE_FENCE_SCOPE,
+      IREE_HSA_PACKET_HEADER_WIDTH_SCACQUIRE_FENCE_SCOPE);
+}
+
+static iree_hsa_fence_scope_t AqlHeaderReleaseScope(uint16_t header) {
+  return (iree_hsa_fence_scope_t)AqlHeaderField(
+      header, IREE_HSA_PACKET_HEADER_SCRELEASE_FENCE_SCOPE,
+      IREE_HSA_PACKET_HEADER_WIDTH_SCRELEASE_FENCE_SCOPE);
+}
+
+static PacketHeaderSummary SummarizePacketHeaders(
+    const uint16_t* packet_headers, uint32_t packet_count) {
+  PacketHeaderSummary summary = {};
+  for (uint32_t i = 0; i < packet_count; ++i) {
+    const uint16_t header = packet_headers[i];
+    if (summary.counts.total == 0) summary.headers.first = header;
+    summary.headers.last = header;
+    ++summary.counts.total;
+    if (AqlHeaderHasBarrier(header)) ++summary.counts.barrier;
+    if (AqlHeaderAcquireScope(header) == IREE_HSA_FENCE_SCOPE_SYSTEM) {
+      ++summary.counts.system_acquire;
+    }
+    if (AqlHeaderReleaseScope(header) == IREE_HSA_FENCE_SCOPE_SYSTEM) {
+      ++summary.counts.system_release;
+    }
+  }
+  return summary;
+}
+
+template <uint32_t DispatchCount>
+static iree_status_t InvokeAndSummarizeDispatchBlock(
+    const DispatchBlock<DispatchCount>& block,
+    iree_hal_amdgpu_aql_block_processor_flags_t flags,
+    iree_hsa_fence_scope_t inline_acquire_scope,
+    iree_hsa_fence_scope_t signal_release_scope,
+    iree_hsa_fence_scope_t payload_acquire_scope,
+    PacketHeaderSummary* out_summary) {
+  alignas(64) iree_hal_amdgpu_aql_packet_t packets[8] = {};
+  iree_hal_amdgpu_aql_ring_t ring = {};
+  ring.base = packets;
+  ring.mask = IREE_ARRAYSIZE(packets) - 1u;
+  uint16_t packet_headers[DispatchCount] = {};
+  uint16_t packet_setups[DispatchCount] = {};
+  iree_hal_amdgpu_kernarg_block_t kernarg_blocks[DispatchCount] = {};
+  iree_hal_amdgpu_aql_block_processor_t processor = MakeProcessor(
+      &ring, /*packet_count=*/DispatchCount, packet_headers, packet_setups,
+      kernarg_blocks, /*kernarg_block_count=*/DispatchCount, flags,
+      inline_acquire_scope, signal_release_scope, payload_acquire_scope);
+
+  iree_hal_amdgpu_aql_block_processor_result_t result;
+  IREE_RETURN_IF_ERROR(iree_hal_amdgpu_aql_block_processor_invoke(
+      &processor, &block.header, &result));
+  *out_summary = SummarizePacketHeaders(packet_headers, DispatchCount);
+  return iree_ok_status();
+}
+
+class AqlBlockProcessorRecordedTest : public ::testing::Test {
+ protected:
+  void SetUp() override {
+    IREE_ASSERT_OK(iree_hal_allocator_create_heap(
+        iree_make_cstring_view("aql_block_processor_test"),
+        iree_allocator_system(), iree_allocator_system(), &device_allocator_));
+    iree_hal_amdgpu_profile_metadata_registry_initialize(
+        iree_allocator_system(), &profile_metadata_);
+    IREE_ASSERT_OK(iree_hal_amdgpu_aql_program_block_pool_initialize(
+        block_size_, iree_allocator_system(), &block_pool_));
+  }
+
+  void TearDown() override {
+    iree_arena_block_pool_deinitialize(&block_pool_);
+    iree_hal_amdgpu_profile_metadata_registry_deinitialize(&profile_metadata_);
+    iree_hal_allocator_release(device_allocator_);
+  }
+
+  CommandBufferPtr CreateCommandBuffer(iree_host_size_t binding_capacity) {
+    iree_hal_command_buffer_t* command_buffer = nullptr;
+    IREE_EXPECT_OK(iree_hal_amdgpu_aql_command_buffer_create(
+        device_allocator_, IREE_HAL_COMMAND_BUFFER_MODE_ONE_SHOT,
+        IREE_HAL_COMMAND_CATEGORY_ANY, IREE_HAL_QUEUE_AFFINITY_ANY,
+        binding_capacity, /*device_ordinal=*/0,
+        iree_hal_amdgpu_aql_prepublished_kernarg_storage_disabled(),
+        &profile_metadata_, &block_pool_, &block_pool_, iree_allocator_system(),
+        &command_buffer));
+    return CommandBufferPtr(command_buffer);
+  }
+
+  BufferPtr CreateBuffer(void* storage, iree_device_size_t length) {
+    iree_hal_buffer_release_callback_t release_callback = {};
+    release_callback.fn = NoOpBufferRelease;
+    iree_hal_buffer_t* buffer = nullptr;
+    IREE_EXPECT_OK(iree_hal_amdgpu_buffer_create(
+        /*libhsa=*/nullptr, iree_hal_buffer_placement_undefined(),
+        IREE_HAL_MEMORY_TYPE_DEVICE_LOCAL | IREE_HAL_MEMORY_TYPE_HOST_VISIBLE,
+        IREE_HAL_MEMORY_ACCESS_ALL,
+        IREE_HAL_BUFFER_USAGE_TRANSFER | IREE_HAL_BUFFER_USAGE_DISPATCH, length,
+        length, storage, release_callback, iree_allocator_system(), &buffer));
+    return BufferPtr(buffer);
+  }
+
+ private:
+  // Test allocator borrowed by command buffers for validation.
+  iree_hal_allocator_t* device_allocator_ = nullptr;
+  // Fixed block size used by recorded command-buffer tests.
+  iree_host_size_t block_size_ = 4096;
+  // Program and resource-set block pool borrowed by test command buffers.
+  iree_arena_block_pool_t block_pool_;
+  // Profile metadata registry borrowed by test command buffers.
+  iree_hal_amdgpu_profile_metadata_registry_t profile_metadata_;
+};
+
+TEST(AqlBlockProcessorTest, ReturnTerminatorProducesNoPayload) {
+  ReturnBlock block = MakeReturnBlock();
+  iree_hal_amdgpu_aql_block_processor_t processor =
+      MakeProcessor(/*ring=*/nullptr, /*packet_count=*/0,
+                    /*packet_headers=*/nullptr, /*packet_setups=*/nullptr,
+                    /*kernarg_blocks=*/nullptr, /*kernarg_block_count=*/0);
+
+  iree_hal_amdgpu_aql_block_processor_result_t result;
+  IREE_ASSERT_OK(iree_hal_amdgpu_aql_block_processor_invoke(
+      &processor, &block.header, &result));
+
+  EXPECT_EQ(result.terminator,
+            IREE_HAL_AMDGPU_AQL_BLOCK_PROCESSOR_TERMINATOR_RETURN);
+  EXPECT_EQ(block.header.terminator_opcode,
+            IREE_HAL_AMDGPU_COMMAND_BUFFER_OPCODE_RETURN);
+  EXPECT_EQ(result.packets.recorded, 0u);
+  EXPECT_EQ(result.packets.emitted, 0u);
+  EXPECT_EQ(result.kernargs.consumed, 0u);
+}
+
+TEST(AqlBlockProcessorTest, BranchTerminatorReportsTargetBlock) {
+  BranchBlock block = MakeBranchBlock(/*target_block_ordinal=*/7);
+  iree_hal_amdgpu_aql_block_processor_t processor =
+      MakeProcessor(/*ring=*/nullptr, /*packet_count=*/0,
+                    /*packet_headers=*/nullptr, /*packet_setups=*/nullptr,
+                    /*kernarg_blocks=*/nullptr, /*kernarg_block_count=*/0);
+
+  iree_hal_amdgpu_aql_block_processor_result_t result;
+  IREE_ASSERT_OK(iree_hal_amdgpu_aql_block_processor_invoke(
+      &processor, &block.header, &result));
+
+  EXPECT_EQ(result.terminator,
+            IREE_HAL_AMDGPU_AQL_BLOCK_PROCESSOR_TERMINATOR_BRANCH);
+  EXPECT_EQ(result.target_block_ordinal, 7u);
+  EXPECT_EQ(block.header.terminator_opcode,
+            IREE_HAL_AMDGPU_COMMAND_BUFFER_OPCODE_BRANCH);
+  EXPECT_EQ(block.header.terminator_target_block_ordinal,
+            result.target_block_ordinal);
+  EXPECT_EQ(result.packets.recorded, 0u);
+  EXPECT_EQ(result.packets.emitted, 0u);
+  EXPECT_EQ(result.kernargs.consumed, 0u);
+}
+
+TEST(AqlBlockProcessorTest, UnterminatedBlockFails) {
+  MalformedBlock block = MakeUnterminatedBlock();
+  iree_hal_amdgpu_aql_block_processor_t processor =
+      MakeProcessor(/*ring=*/nullptr, /*packet_count=*/0,
+                    /*packet_headers=*/nullptr, /*packet_setups=*/nullptr,
+                    /*kernarg_blocks=*/nullptr, /*kernarg_block_count=*/0);
+
+  iree_hal_amdgpu_aql_block_processor_result_t result;
+  IREE_EXPECT_STATUS_IS(IREE_STATUS_INVALID_ARGUMENT,
+                        iree_hal_amdgpu_aql_block_processor_invoke(
+                            &processor, &block.header, &result));
+}
+
+TEST(AqlBlockProcessorTest, DirectDispatchPopulatesPacketAndKernarg) {
+  DirectDispatchBlock block = MakeDirectDispatchBlock();
+  alignas(64) iree_hal_amdgpu_aql_packet_t packets[8] = {};
+  iree_hal_amdgpu_aql_ring_t ring = {};
+  ring.base = packets;
+  ring.mask = 7;
+  uint16_t packet_headers[1] = {0xCDCD};
+  uint16_t packet_setups[1] = {0xCDCD};
+  iree_hal_amdgpu_kernarg_block_t kernarg_blocks[1] = {};
+  iree_hal_amdgpu_aql_block_processor_t processor = MakeProcessor(
+      &ring, /*packet_count=*/1, packet_headers, packet_setups, kernarg_blocks,
+      /*kernarg_block_count=*/1,
+      IREE_HAL_AMDGPU_AQL_BLOCK_PROCESSOR_FLAG_FINAL_PAYLOAD_PACKET);
+
+  iree_hal_amdgpu_aql_block_processor_result_t result;
+  IREE_ASSERT_OK(iree_hal_amdgpu_aql_block_processor_invoke(
+      &processor, &block.header, &result));
+
+  EXPECT_EQ(result.terminator,
+            IREE_HAL_AMDGPU_AQL_BLOCK_PROCESSOR_TERMINATOR_RETURN);
+  EXPECT_EQ(result.packets.recorded, 1u);
+  EXPECT_EQ(result.packets.emitted, 1u);
+  EXPECT_EQ(result.kernargs.consumed, 1u);
+
+  const iree_hal_amdgpu_aql_packet_t& packet = packets[4];
+  EXPECT_EQ(packet.dispatch.setup, block.dispatch_command.setup);
+  EXPECT_EQ(packet.dispatch.workgroup_size[0],
+            block.dispatch_command.workgroup_size[0]);
+  EXPECT_EQ(packet.dispatch.grid_size[0], block.dispatch_command.grid_size[0]);
+  EXPECT_EQ(packet.dispatch.private_segment_size,
+            block.dispatch_command.private_segment_size);
+  EXPECT_EQ(packet.dispatch.group_segment_size,
+            block.dispatch_command.group_segment_size);
+  EXPECT_EQ(packet.dispatch.kernel_object,
+            block.dispatch_command.kernel_object);
+  EXPECT_EQ(packet.dispatch.kernarg_address, kernarg_blocks[0].data);
+  EXPECT_EQ(packet.dispatch.completion_signal.handle, 0u);
+  EXPECT_EQ(packet_setups[0], block.dispatch_command.setup);
+  EXPECT_EQ(std::memcmp(kernarg_blocks[0].data, block.tail, sizeof(block.tail)),
+            0);
+
+  EXPECT_EQ(packet_headers[0],
+            iree_hsa_make_packet_header(IREE_HSA_PACKET_TYPE_KERNEL_DISPATCH,
+                                        /*is_barrier=*/true,
+                                        IREE_HSA_FENCE_SCOPE_SYSTEM,
+                                        IREE_HSA_FENCE_SCOPE_SYSTEM));
+}
+
+TEST(AqlBlockProcessorTest,
+     BuilderProducedSplitBlocksInvokeAsBranchThenReturn) {
+  iree_arena_block_pool_t block_pool;
+  IREE_ASSERT_OK(iree_hal_amdgpu_aql_program_block_pool_initialize(
+      /*block_size=*/256, iree_allocator_system(), &block_pool));
+
+  iree_hal_amdgpu_aql_program_builder_t builder;
+  iree_hal_amdgpu_aql_program_builder_initialize(&block_pool, &builder);
+  IREE_ASSERT_OK(iree_hal_amdgpu_aql_program_builder_begin(&builder));
+
+  for (uint32_t i = 0; i < 4; ++i) {
+    iree_hal_amdgpu_command_buffer_command_header_t* command = nullptr;
+    IREE_ASSERT_OK(iree_hal_amdgpu_aql_program_builder_append_command(
+        &builder, IREE_HAL_AMDGPU_COMMAND_BUFFER_OPCODE_DISPATCH,
+        IREE_HAL_AMDGPU_COMMAND_BUFFER_COMMAND_FLAG_NONE,
+        sizeof(iree_hal_amdgpu_command_buffer_dispatch_command_t),
+        /*binding_source_count=*/0, /*aql_packet_count=*/1,
+        sizeof(iree_hal_amdgpu_kernarg_block_t), &command,
+        /*out_binding_sources=*/nullptr));
+    auto* dispatch_command =
+        reinterpret_cast<iree_hal_amdgpu_command_buffer_dispatch_command_t*>(
+            command);
+    dispatch_command->kernel_object = 0xABCDEF0000000000ull + i;
+    dispatch_command->payload_reference = sizeof(*dispatch_command);
+    dispatch_command->kernarg_strategy =
+        IREE_HAL_AMDGPU_COMMAND_BUFFER_KERNARG_STRATEGY_CUSTOM_DIRECT;
+    dispatch_command->implicit_args_offset_qwords = UINT16_MAX;
+    dispatch_command->setup = 3;
+    dispatch_command->workgroup_size[0] = 1;
+    dispatch_command->workgroup_size[1] = 1;
+    dispatch_command->workgroup_size[2] = 1;
+    dispatch_command->grid_size[0] = 1;
+    dispatch_command->grid_size[1] = 1;
+    dispatch_command->grid_size[2] = 1;
+  }
+
+  iree_hal_amdgpu_aql_program_t program = {};
+  IREE_ASSERT_OK(iree_hal_amdgpu_aql_program_builder_end(&builder, &program));
+  iree_hal_amdgpu_aql_program_builder_deinitialize(&builder);
+
+  ASSERT_GE(program.block_count, 2u);
+  const iree_hal_amdgpu_command_buffer_block_header_t* first_block =
+      program.first_block;
+  const iree_hal_amdgpu_command_buffer_block_header_t* second_block =
+      iree_hal_amdgpu_aql_program_block_next(&block_pool, first_block);
+  ASSERT_NE(second_block, nullptr);
+
+  alignas(64) iree_hal_amdgpu_aql_packet_t packets[8] = {};
+  iree_hal_amdgpu_aql_ring_t ring = {};
+  ring.base = packets;
+  ring.mask = IREE_ARRAYSIZE(packets) - 1u;
+
+  uint16_t first_packet_headers[4] = {};
+  uint16_t first_packet_setups[4] = {};
+  iree_hal_amdgpu_kernarg_block_t first_kernarg_blocks[4] = {};
+  iree_hal_amdgpu_aql_block_processor_t first_processor = MakeProcessor(
+      &ring, first_block->aql_packet_count, first_packet_headers,
+      first_packet_setups, first_kernarg_blocks, KernargBlockCount(first_block),
+      IREE_HAL_AMDGPU_AQL_BLOCK_PROCESSOR_FLAG_NONE, IREE_HSA_FENCE_SCOPE_NONE,
+      IREE_HSA_FENCE_SCOPE_NONE, IREE_HSA_FENCE_SCOPE_NONE);
+
+  iree_hal_amdgpu_aql_block_processor_result_t first_result;
+  IREE_ASSERT_OK(iree_hal_amdgpu_aql_block_processor_invoke(
+      &first_processor, first_block, &first_result));
+  EXPECT_EQ(first_result.terminator,
+            IREE_HAL_AMDGPU_AQL_BLOCK_PROCESSOR_TERMINATOR_BRANCH);
+  EXPECT_EQ(first_result.target_block_ordinal, 1u);
+  EXPECT_EQ(first_result.packets.emitted, first_block->aql_packet_count);
+  EXPECT_EQ(first_result.kernargs.consumed, KernargBlockCount(first_block));
+  EXPECT_EQ(first_block->terminator_opcode,
+            IREE_HAL_AMDGPU_COMMAND_BUFFER_OPCODE_BRANCH);
+
+  uint16_t second_packet_headers[4] = {};
+  uint16_t second_packet_setups[4] = {};
+  iree_hal_amdgpu_kernarg_block_t second_kernarg_blocks[4] = {};
+  iree_hal_amdgpu_aql_block_processor_t second_processor = MakeProcessor(
+      &ring, second_block->aql_packet_count, second_packet_headers,
+      second_packet_setups, second_kernarg_blocks,
+      KernargBlockCount(second_block),
+      IREE_HAL_AMDGPU_AQL_BLOCK_PROCESSOR_FLAG_FINAL_PAYLOAD_PACKET,
+      IREE_HSA_FENCE_SCOPE_NONE, IREE_HSA_FENCE_SCOPE_NONE,
+      IREE_HSA_FENCE_SCOPE_NONE);
+
+  iree_hal_amdgpu_aql_block_processor_result_t second_result;
+  IREE_ASSERT_OK(iree_hal_amdgpu_aql_block_processor_invoke(
+      &second_processor, second_block, &second_result));
+  EXPECT_EQ(second_result.terminator,
+            IREE_HAL_AMDGPU_AQL_BLOCK_PROCESSOR_TERMINATOR_RETURN);
+  EXPECT_EQ(second_result.packets.emitted, second_block->aql_packet_count);
+  EXPECT_EQ(second_result.kernargs.consumed, KernargBlockCount(second_block));
+  EXPECT_EQ(second_block->terminator_opcode,
+            IREE_HAL_AMDGPU_COMMAND_BUFFER_OPCODE_RETURN);
+
+  iree_hal_amdgpu_aql_program_release(&program);
+  iree_arena_block_pool_deinitialize(&block_pool);
+}
+
+TEST(AqlBlockProcessorTest, IndirectDispatchEmitsPatchThenUnpublishedDispatch) {
+  const iree_hal_amdgpu_device_kernels_t kernels = MakeTransferKernels();
+  const iree_hal_amdgpu_device_buffer_transfer_context_t transfer_context =
+      MakeTransferContext(&kernels);
+  const uint32_t workgroup_count[3] = {7, 5, 3};
+  IndirectDispatchBlock block = MakeIndirectDispatchBlock(workgroup_count);
+
+  alignas(64) iree_hal_amdgpu_aql_packet_t packets[8] = {};
+  iree_hal_amdgpu_aql_ring_t ring = {};
+  ring.base = packets;
+  ring.mask = IREE_ARRAYSIZE(packets) - 1u;
+  uint16_t packet_headers[2] = {};
+  uint16_t packet_setups[2] = {};
+  iree_hal_amdgpu_kernarg_block_t kernarg_blocks[2] = {};
+  iree_hal_amdgpu_aql_block_processor_t processor = MakeProcessor(
+      &ring, /*packet_count=*/2, packet_headers, packet_setups, kernarg_blocks,
+      /*kernarg_block_count=*/2,
+      IREE_HAL_AMDGPU_AQL_BLOCK_PROCESSOR_FLAG_FINAL_PAYLOAD_PACKET,
+      IREE_HSA_FENCE_SCOPE_NONE, IREE_HSA_FENCE_SCOPE_SYSTEM,
+      IREE_HSA_FENCE_SCOPE_NONE, &transfer_context);
+
+  iree_hal_amdgpu_aql_block_processor_result_t result;
+  IREE_ASSERT_OK(iree_hal_amdgpu_aql_block_processor_invoke(
+      &processor, &block.header, &result));
+
+  EXPECT_EQ(result.terminator,
+            IREE_HAL_AMDGPU_AQL_BLOCK_PROCESSOR_TERMINATOR_RETURN);
+  EXPECT_EQ(result.packets.recorded, 2u);
+  EXPECT_EQ(result.packets.emitted, 2u);
+  EXPECT_EQ(result.kernargs.consumed, 2u);
+
+  const iree_hal_amdgpu_aql_packet_t& patch_packet = packets[4];
+  const iree_hal_amdgpu_aql_packet_t& dispatch_packet = packets[5];
+  EXPECT_EQ(patch_packet.dispatch.kernel_object,
+            kPatchIndirectParamsKernelObject);
+  EXPECT_EQ(dispatch_packet.dispatch.kernel_object,
+            block.dispatch_command.kernel_object);
+  EXPECT_EQ(dispatch_packet.dispatch.kernarg_address, kernarg_blocks[1].data);
+  EXPECT_EQ(packet_headers[1], IREE_HSA_PACKET_TYPE_INVALID);
+
+  const auto* patch_args = reinterpret_cast<
+      const iree_hal_amdgpu_device_dispatch_patch_indirect_params_args_t*>(
+      kernarg_blocks[0].data);
+  EXPECT_EQ(patch_args->workgroup_count, workgroup_count);
+  EXPECT_EQ(patch_args->dispatch_packet, &packets[5].dispatch);
+  EXPECT_EQ(patch_args->implicit_args, nullptr);
+  EXPECT_EQ(patch_args->dispatch_header_setup,
+            (uint32_t)iree_hal_amdgpu_aql_make_header(
+                IREE_HSA_PACKET_TYPE_KERNEL_DISPATCH,
+                iree_hal_amdgpu_aql_packet_control(
+                    /*has_barrier=*/true, IREE_HSA_FENCE_SCOPE_NONE,
+                    IREE_HSA_FENCE_SCOPE_SYSTEM)) |
+                ((uint32_t)packet_setups[1] << 16));
+  EXPECT_TRUE(AqlHeaderHasBarrier(packet_headers[0]));
+  EXPECT_EQ(AqlHeaderAcquireScope(packet_headers[0]),
+            IREE_HSA_FENCE_SCOPE_AGENT);
+  EXPECT_EQ(AqlHeaderReleaseScope(packet_headers[0]),
+            IREE_HSA_FENCE_SCOPE_AGENT);
+}
+
+TEST_F(AqlBlockProcessorRecordedTest, RecordedTransfersEmitBlitPackets) {
+  alignas(16) uint8_t fill_target_storage[1024] = {};
+  alignas(16) uint8_t copy_source_storage[1024] = {};
+  alignas(16) uint8_t copy_target_storage[1024] = {};
+  BufferPtr fill_target_buffer =
+      CreateBuffer(fill_target_storage, sizeof(fill_target_storage));
+  BufferPtr copy_source_buffer =
+      CreateBuffer(copy_source_storage, sizeof(copy_source_storage));
+  BufferPtr copy_target_buffer =
+      CreateBuffer(copy_target_storage, sizeof(copy_target_storage));
+  ASSERT_NE(fill_target_buffer, nullptr);
+  ASSERT_NE(copy_source_buffer, nullptr);
+  ASSERT_NE(copy_target_buffer, nullptr);
+
+  CommandBufferPtr command_buffer = CreateCommandBuffer(/*binding_capacity=*/3);
+  ASSERT_NE(command_buffer, nullptr);
+  const uint32_t fill_pattern = 0xAABBCCDDu;
+  const uint8_t update_source[16] = {
+      0x00, 0x11, 0x22, 0x33, 0x44, 0x55, 0x66, 0x77,
+      0x88, 0x99, 0xAA, 0xBB, 0xCC, 0xDD, 0xEE, 0xFF,
+  };
+
+  IREE_ASSERT_OK(iree_hal_command_buffer_begin(command_buffer.get()));
+  IREE_ASSERT_OK(iree_hal_command_buffer_fill_buffer(
+      command_buffer.get(),
+      iree_hal_make_indirect_buffer_ref(/*buffer_slot=*/0, /*offset=*/32,
+                                        /*length=*/512),
+      &fill_pattern, sizeof(fill_pattern), IREE_HAL_FILL_FLAG_NONE));
+  IREE_ASSERT_OK(iree_hal_command_buffer_copy_buffer(
+      command_buffer.get(),
+      iree_hal_make_indirect_buffer_ref(/*buffer_slot=*/1, /*offset=*/0,
+                                        /*length=*/512),
+      iree_hal_make_indirect_buffer_ref(/*buffer_slot=*/2, /*offset=*/64,
+                                        /*length=*/512),
+      IREE_HAL_COPY_FLAG_NONE));
+  IREE_ASSERT_OK(iree_hal_command_buffer_update_buffer(
+      command_buffer.get(), update_source, /*source_offset=*/0,
+      iree_hal_make_indirect_buffer_ref(/*buffer_slot=*/2, /*offset=*/128,
+                                        sizeof(update_source)),
+      IREE_HAL_UPDATE_FLAG_NONE));
+  IREE_ASSERT_OK(iree_hal_command_buffer_end(command_buffer.get()));
+
+  const iree_hal_amdgpu_aql_program_t* program =
+      iree_hal_amdgpu_aql_command_buffer_program(command_buffer.get());
+  ASSERT_NE(program->first_block, nullptr);
+  ASSERT_EQ(program->block_count, 1u);
+  const iree_hal_amdgpu_command_buffer_block_header_t* block =
+      program->first_block;
+  ASSERT_EQ(block->aql_packet_count, 3u);
+  ASSERT_EQ(KernargBlockCount(block), 3u);
+
+  const iree_hal_amdgpu_device_kernels_t kernels = MakeTransferKernels();
+  const iree_hal_amdgpu_device_buffer_transfer_context_t transfer_context =
+      MakeTransferContext(&kernels);
+
+  const std::array<iree_hal_buffer_binding_t, 3> bindings = {{
+      MakeBinding(fill_target_buffer.get(), /*offset=*/0,
+                  sizeof(fill_target_storage)),
+      MakeBinding(copy_source_buffer.get(), /*offset=*/0,
+                  sizeof(copy_source_storage)),
+      MakeBinding(copy_target_buffer.get(), /*offset=*/0,
+                  sizeof(copy_target_storage)),
+  }};
+  const iree_hal_buffer_binding_table_t binding_table = {bindings.size(),
+                                                         bindings.data()};
+
+  alignas(64) iree_hal_amdgpu_aql_packet_t packets[8] = {};
+  iree_hal_amdgpu_aql_ring_t ring = {};
+  ring.base = packets;
+  ring.mask = IREE_ARRAYSIZE(packets) - 1u;
+  uint16_t packet_headers[3] = {};
+  uint16_t packet_setups[3] = {};
+  iree_hal_amdgpu_kernarg_block_t kernarg_blocks[3] = {};
+  iree_hal_amdgpu_aql_block_processor_t processor = MakeProcessor(
+      &ring, /*packet_count=*/3, packet_headers, packet_setups, kernarg_blocks,
+      /*kernarg_block_count=*/3,
+      IREE_HAL_AMDGPU_AQL_BLOCK_PROCESSOR_FLAG_FINAL_PAYLOAD_PACKET,
+      IREE_HSA_FENCE_SCOPE_NONE, IREE_HSA_FENCE_SCOPE_SYSTEM,
+      IREE_HSA_FENCE_SCOPE_NONE, &transfer_context, command_buffer.get(),
+      binding_table);
+
+  iree_hal_amdgpu_aql_block_processor_result_t result;
+  IREE_ASSERT_OK(
+      iree_hal_amdgpu_aql_block_processor_invoke(&processor, block, &result));
+
+  EXPECT_EQ(result.terminator,
+            IREE_HAL_AMDGPU_AQL_BLOCK_PROCESSOR_TERMINATOR_RETURN);
+  EXPECT_EQ(result.packets.recorded, 3u);
+  EXPECT_EQ(result.packets.emitted, 3u);
+  EXPECT_EQ(result.kernargs.consumed, 3u);
+
+  const iree_hal_amdgpu_aql_packet_t& fill_packet = packets[4];
+  const iree_hal_amdgpu_aql_packet_t& copy_packet = packets[5];
+  const iree_hal_amdgpu_aql_packet_t& update_packet = packets[6];
+  EXPECT_EQ(fill_packet.dispatch.kernel_object, kFillBlockX16KernelObject);
+  EXPECT_EQ(copy_packet.dispatch.kernel_object, kCopyBlockX16KernelObject);
+  EXPECT_EQ(update_packet.dispatch.kernel_object, kCopyBlockX16KernelObject);
+  EXPECT_EQ(fill_packet.dispatch.kernarg_address, kernarg_blocks[0].data);
+  EXPECT_EQ(copy_packet.dispatch.kernarg_address, kernarg_blocks[1].data);
+  EXPECT_EQ(update_packet.dispatch.kernarg_address, kernarg_blocks[2].data);
+
+  const auto* fill_args =
+      reinterpret_cast<const iree_hal_amdgpu_device_buffer_fill_kernargs_t*>(
+          kernarg_blocks[0].data);
+  EXPECT_EQ(fill_args->target_ptr,
+            static_cast<void*>(fill_target_storage + 32));
+  EXPECT_EQ(fill_args->element_length, 32u);
+  EXPECT_EQ(fill_args->pattern, 0xAABBCCDDAABBCCDDull);
+
+  const auto* copy_args =
+      reinterpret_cast<const iree_hal_amdgpu_device_buffer_copy_kernargs_t*>(
+          kernarg_blocks[1].data);
+  EXPECT_EQ(copy_args->source_ptr,
+            static_cast<const void*>(copy_source_storage));
+  EXPECT_EQ(copy_args->target_ptr,
+            static_cast<void*>(copy_target_storage + 64));
+  EXPECT_EQ(copy_args->element_length, 32u);
+
+  const auto* update_args =
+      reinterpret_cast<const iree_hal_amdgpu_device_buffer_copy_kernargs_t*>(
+          kernarg_blocks[2].data);
+  EXPECT_EQ(update_args->target_ptr,
+            static_cast<void*>(copy_target_storage + 128));
+  EXPECT_EQ(update_args->element_length, 1u);
+  ASSERT_NE(update_args->source_ptr, nullptr);
+  EXPECT_EQ(std::memcmp(update_args->source_ptr, update_source,
+                        sizeof(update_source)),
+            0);
+
+  EXPECT_TRUE(AqlHeaderHasBarrier(packet_headers[2]));
+  EXPECT_EQ(AqlHeaderReleaseScope(packet_headers[2]),
+            IREE_HSA_FENCE_SCOPE_SYSTEM);
+}
+
+TEST(AqlBlockProcessorTest,
+     PacketHeadersOmitInteriorBarriersWithoutExecutionBarrier) {
+  const uint8_t dispatch_command_flags[2] = {
+      CommandFlags(
+          IREE_HAL_AMDGPU_COMMAND_BUFFER_COMMAND_FLAG_USES_QUEUE_KERNARGS,
+          IREE_HSA_FENCE_SCOPE_NONE, IREE_HSA_FENCE_SCOPE_NONE),
+      CommandFlags(
+          IREE_HAL_AMDGPU_COMMAND_BUFFER_COMMAND_FLAG_USES_QUEUE_KERNARGS,
+          IREE_HSA_FENCE_SCOPE_NONE, IREE_HSA_FENCE_SCOPE_NONE),
+  };
+  DispatchBlock<2> block = MakeDispatchBlock(dispatch_command_flags);
+
+  PacketHeaderSummary summary = {};
+  IREE_ASSERT_OK(InvokeAndSummarizeDispatchBlock(
+      block, IREE_HAL_AMDGPU_AQL_BLOCK_PROCESSOR_FLAG_FINAL_PAYLOAD_PACKET,
+      IREE_HSA_FENCE_SCOPE_NONE, IREE_HSA_FENCE_SCOPE_NONE,
+      IREE_HSA_FENCE_SCOPE_NONE, &summary));
+
+  EXPECT_EQ(summary.counts.total, 2u);
+  EXPECT_EQ(summary.counts.barrier, 1u);
+  EXPECT_FALSE(AqlHeaderHasBarrier(summary.headers.first));
+  EXPECT_EQ(AqlHeaderReleaseScope(summary.headers.first),
+            IREE_HSA_FENCE_SCOPE_NONE);
+  EXPECT_TRUE(AqlHeaderHasBarrier(summary.headers.last));
+}
+
+TEST(AqlBlockProcessorTest, PacketHeadersBarrierFirstPayloadForInlineWait) {
+  const uint8_t dispatch_command_flags[2] = {
+      CommandFlags(
+          IREE_HAL_AMDGPU_COMMAND_BUFFER_COMMAND_FLAG_USES_QUEUE_KERNARGS,
+          IREE_HSA_FENCE_SCOPE_NONE, IREE_HSA_FENCE_SCOPE_NONE),
+      CommandFlags(
+          IREE_HAL_AMDGPU_COMMAND_BUFFER_COMMAND_FLAG_USES_QUEUE_KERNARGS,
+          IREE_HSA_FENCE_SCOPE_NONE, IREE_HSA_FENCE_SCOPE_NONE),
+  };
+  DispatchBlock<2> block = MakeDispatchBlock(dispatch_command_flags);
+
+  PacketHeaderSummary summary = {};
+  IREE_ASSERT_OK(InvokeAndSummarizeDispatchBlock(
+      block, IREE_HAL_AMDGPU_AQL_BLOCK_PROCESSOR_FLAG_FINAL_PAYLOAD_PACKET,
+      IREE_HSA_FENCE_SCOPE_AGENT, IREE_HSA_FENCE_SCOPE_NONE,
+      IREE_HSA_FENCE_SCOPE_SYSTEM, &summary));
+
+  EXPECT_EQ(summary.counts.total, 2u);
+  EXPECT_EQ(summary.counts.barrier, 2u);
+  EXPECT_TRUE(AqlHeaderHasBarrier(summary.headers.first));
+  EXPECT_EQ(AqlHeaderAcquireScope(summary.headers.first),
+            IREE_HSA_FENCE_SCOPE_SYSTEM);
+  EXPECT_EQ(AqlHeaderReleaseScope(summary.headers.first),
+            IREE_HSA_FENCE_SCOPE_NONE);
+  EXPECT_TRUE(AqlHeaderHasBarrier(summary.headers.last));
+}
+
+TEST(AqlBlockProcessorTest, PacketHeadersPreserveExplicitMemoryBarrierScopes) {
+  const uint8_t dispatch_command_flags[2] = {
+      CommandFlags(
+          IREE_HAL_AMDGPU_COMMAND_BUFFER_COMMAND_FLAG_USES_QUEUE_KERNARGS,
+          IREE_HSA_FENCE_SCOPE_SYSTEM, IREE_HSA_FENCE_SCOPE_SYSTEM),
+      CommandFlags(
+          IREE_HAL_AMDGPU_COMMAND_BUFFER_COMMAND_FLAG_USES_QUEUE_KERNARGS |
+              IREE_HAL_AMDGPU_COMMAND_BUFFER_COMMAND_FLAG_HAS_BARRIER,
+          IREE_HSA_FENCE_SCOPE_SYSTEM, IREE_HSA_FENCE_SCOPE_AGENT),
+  };
+  DispatchBlock<2> block = MakeDispatchBlock(dispatch_command_flags);
+
+  PacketHeaderSummary summary = {};
+  IREE_ASSERT_OK(InvokeAndSummarizeDispatchBlock(
+      block, IREE_HAL_AMDGPU_AQL_BLOCK_PROCESSOR_FLAG_FINAL_PAYLOAD_PACKET,
+      IREE_HSA_FENCE_SCOPE_NONE, IREE_HSA_FENCE_SCOPE_NONE,
+      IREE_HSA_FENCE_SCOPE_NONE, &summary));
+
+  EXPECT_EQ(summary.counts.total, 2u);
+  EXPECT_EQ(summary.counts.barrier, 1u);
+  EXPECT_FALSE(AqlHeaderHasBarrier(summary.headers.first));
+  EXPECT_EQ(AqlHeaderAcquireScope(summary.headers.first),
+            IREE_HSA_FENCE_SCOPE_SYSTEM);
+  EXPECT_EQ(AqlHeaderReleaseScope(summary.headers.first),
+            IREE_HSA_FENCE_SCOPE_SYSTEM);
+  EXPECT_TRUE(AqlHeaderHasBarrier(summary.headers.last));
+  EXPECT_EQ(AqlHeaderAcquireScope(summary.headers.last),
+            IREE_HSA_FENCE_SCOPE_SYSTEM);
+  EXPECT_EQ(AqlHeaderReleaseScope(summary.headers.last),
+            IREE_HSA_FENCE_SCOPE_AGENT);
+}
+
+TEST(AqlBlockProcessorTest, PacketHeadersHonorExplicitExecutionBarrier) {
+  const uint8_t dispatch_command_flags[3] = {
+      CommandFlags(
+          IREE_HAL_AMDGPU_COMMAND_BUFFER_COMMAND_FLAG_USES_QUEUE_KERNARGS,
+          IREE_HSA_FENCE_SCOPE_NONE, IREE_HSA_FENCE_SCOPE_NONE),
+      CommandFlags(
+          IREE_HAL_AMDGPU_COMMAND_BUFFER_COMMAND_FLAG_USES_QUEUE_KERNARGS |
+              IREE_HAL_AMDGPU_COMMAND_BUFFER_COMMAND_FLAG_HAS_BARRIER,
+          IREE_HSA_FENCE_SCOPE_AGENT, IREE_HSA_FENCE_SCOPE_AGENT),
+      CommandFlags(
+          IREE_HAL_AMDGPU_COMMAND_BUFFER_COMMAND_FLAG_USES_QUEUE_KERNARGS,
+          IREE_HSA_FENCE_SCOPE_NONE, IREE_HSA_FENCE_SCOPE_NONE),
+  };
+  DispatchBlock<3> block = MakeDispatchBlock(dispatch_command_flags);
+
+  alignas(64) iree_hal_amdgpu_aql_packet_t packets[8] = {};
+  iree_hal_amdgpu_aql_ring_t ring = {};
+  ring.base = packets;
+  ring.mask = IREE_ARRAYSIZE(packets) - 1u;
+  uint16_t packet_headers[3] = {};
+  uint16_t packet_setups[3] = {};
+  iree_hal_amdgpu_kernarg_block_t kernarg_blocks[3] = {};
+  iree_hal_amdgpu_aql_block_processor_t processor = MakeProcessor(
+      &ring, /*packet_count=*/3, packet_headers, packet_setups, kernarg_blocks,
+      /*kernarg_block_count=*/3,
+      IREE_HAL_AMDGPU_AQL_BLOCK_PROCESSOR_FLAG_FINAL_PAYLOAD_PACKET,
+      IREE_HSA_FENCE_SCOPE_NONE, IREE_HSA_FENCE_SCOPE_NONE,
+      IREE_HSA_FENCE_SCOPE_NONE);
+
+  iree_hal_amdgpu_aql_block_processor_result_t result;
+  IREE_ASSERT_OK(iree_hal_amdgpu_aql_block_processor_invoke(
+      &processor, &block.header, &result));
+  PacketHeaderSummary summary = SummarizePacketHeaders(packet_headers, 3);
+
+  EXPECT_EQ(summary.counts.total, 3u);
+  EXPECT_EQ(summary.counts.barrier, 2u);
+  EXPECT_FALSE(AqlHeaderHasBarrier(packet_headers[0]));
+  EXPECT_TRUE(AqlHeaderHasBarrier(packet_headers[1]));
+  EXPECT_TRUE(AqlHeaderHasBarrier(packet_headers[2]));
+}
+
+TEST(AqlBlockProcessorTest,
+     PacketHeadersApplySystemAcquireOnlyToFirstDynamicKernargPacket) {
+  const uint8_t dispatch_command_flags[2] = {
+      CommandFlags(
+          IREE_HAL_AMDGPU_COMMAND_BUFFER_COMMAND_FLAG_USES_QUEUE_KERNARGS,
+          IREE_HSA_FENCE_SCOPE_NONE, IREE_HSA_FENCE_SCOPE_NONE),
+      CommandFlags(
+          IREE_HAL_AMDGPU_COMMAND_BUFFER_COMMAND_FLAG_USES_QUEUE_KERNARGS,
+          IREE_HSA_FENCE_SCOPE_NONE, IREE_HSA_FENCE_SCOPE_NONE),
+  };
+  DispatchBlock<2> block = MakeDispatchBlock(dispatch_command_flags);
+
+  PacketHeaderSummary summary = {};
+  IREE_ASSERT_OK(InvokeAndSummarizeDispatchBlock(
+      block, IREE_HAL_AMDGPU_AQL_BLOCK_PROCESSOR_FLAG_FINAL_PAYLOAD_PACKET,
+      IREE_HSA_FENCE_SCOPE_SYSTEM, IREE_HSA_FENCE_SCOPE_NONE,
+      IREE_HSA_FENCE_SCOPE_SYSTEM, &summary));
+
+  EXPECT_EQ(summary.counts.total, 2u);
+  EXPECT_EQ(summary.counts.barrier, 2u);
+  EXPECT_EQ(summary.counts.system_acquire, 1u);
+  EXPECT_EQ(summary.counts.system_release, 0u);
+}
+
+}  // namespace
+}  // namespace iree::hal::amdgpu
diff --git a/runtime/src/iree/hal/drivers/amdgpu/aql_block_processor_timestamp.c b/runtime/src/iree/hal/drivers/amdgpu/aql_block_processor_timestamp.c
new file mode 100644
index 0000000..6189ca1
--- /dev/null
+++ b/runtime/src/iree/hal/drivers/amdgpu/aql_block_processor_timestamp.c
@@ -0,0 +1,299 @@
+// Copyright 2026 The IREE Authors
+//
+// Licensed under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+
+#include "iree/hal/drivers/amdgpu/aql_block_processor_timestamp.h"
+
+#include <string.h>
+
+#include "iree/hal/drivers/amdgpu/util/aql_emitter.h"
+
+static bool iree_hal_amdgpu_aql_block_processor_timestamp_has_command_buffer(
+    const iree_hal_amdgpu_aql_block_processor_timestamp_t* processor) {
+  return processor->command_buffer.target.record != NULL;
+}
+
+static bool
+iree_hal_amdgpu_aql_block_processor_timestamp_has_command_buffer_storage(
+    const iree_hal_amdgpu_aql_block_processor_timestamp_t* processor) {
+  return processor->command_buffer.target.record ||
+         processor->command_buffer.packets.start.packet ||
+         processor->command_buffer.packets.start.pm4_ib_slot ||
+         processor->command_buffer.packets.end.packet ||
+         processor->command_buffer.packets.end.pm4_ib_slot;
+}
+
+static bool iree_hal_amdgpu_aql_block_processor_timestamp_has_dispatches(
+    const iree_hal_amdgpu_aql_block_processor_timestamp_t* processor) {
+  return processor->dispatches.count != 0;
+}
+
+static iree_hal_amdgpu_dispatch_timestamp_record_flags_t
+iree_hal_amdgpu_aql_block_processor_timestamp_dispatch_flags(
+    const iree_hal_amdgpu_aql_command_buffer_dispatch_summary_t* summary) {
+  iree_hal_amdgpu_dispatch_timestamp_record_flags_t flags =
+      IREE_HAL_AMDGPU_DISPATCH_TIMESTAMP_RECORD_FLAG_NONE;
+  if (iree_any_bit_set(
+          summary->metadata.dispatch_flags,
+          IREE_HAL_AMDGPU_COMMAND_BUFFER_DISPATCH_FLAG_INDIRECT_PARAMETERS)) {
+    flags |= IREE_HAL_AMDGPU_DISPATCH_TIMESTAMP_RECORD_FLAG_INDIRECT_PARAMETERS;
+  }
+  return flags;
+}
+
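+// An iree_hsa_signal_t handle is the address of the backing
+// iree_amd_signal_t (the ROCR signal representation), so wrapping the raw
+// pointer value yields a signal usable in AQL packet completion fields.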
+static iree_hsa_signal_t
+iree_hal_amdgpu_aql_block_processor_timestamp_completion_signal(
+    const iree_amd_signal_t* signal) {
+  return (iree_hsa_signal_t){.handle = (uint64_t)(uintptr_t)signal};
+}
+
+static void iree_hal_amdgpu_aql_block_processor_timestamp_initialize_header(
+    uint32_t record_length, iree_hal_amdgpu_timestamp_record_type_t type,
+    uint32_t record_ordinal,
+    iree_hal_amdgpu_timestamp_record_header_t* out_header) {
+  out_header->record_length = record_length;
+  out_header->version = IREE_HAL_AMDGPU_TIMESTAMP_RECORD_VERSION_0;
+  out_header->type = type;
+  out_header->record_ordinal = record_ordinal;
+  out_header->reserved0 = 0;
+}
+
+static iree_status_t
+iree_hal_amdgpu_aql_block_processor_timestamp_validate_command_buffer(
+    const iree_hal_amdgpu_aql_block_processor_timestamp_t* processor) {
+  if (!iree_hal_amdgpu_aql_block_processor_timestamp_has_command_buffer(
+          processor) &&
+      iree_hal_amdgpu_aql_block_processor_timestamp_has_command_buffer_storage(
+          processor)) {
+    return iree_make_status(
+        IREE_STATUS_INVALID_ARGUMENT,
+        "command-buffer timestamp mode requires a timestamp record target");
+  }
+  if (!iree_hal_amdgpu_aql_block_processor_timestamp_has_command_buffer(
+          processor)) {
+    return iree_ok_status();
+  }
+  if (IREE_UNLIKELY(!processor->command_buffer.packets.start.packet ||
+                    !processor->command_buffer.packets.start.pm4_ib_slot ||
+                    !processor->command_buffer.packets.end.packet ||
+                    !processor->command_buffer.packets.end.pm4_ib_slot)) {
+    return iree_make_status(
+        IREE_STATUS_INVALID_ARGUMENT,
+        "command-buffer timestamp mode requires start and end PM4 packet "
+        "storage");
+  }
+  return iree_ok_status();
+}
+
+static void iree_hal_amdgpu_aql_block_processor_timestamp_emit_command_buffer(
+    const iree_hal_amdgpu_aql_block_processor_timestamp_t* processor,
+    iree_hal_amdgpu_aql_block_processor_timestamp_result_t* out_result) {
+  iree_hal_amdgpu_command_buffer_timestamp_record_t* record =
+      processor->command_buffer.target.record;
+  memset(record, 0, sizeof(*record));
+  iree_hal_amdgpu_aql_block_processor_timestamp_initialize_header(
+      sizeof(*record), IREE_HAL_AMDGPU_TIMESTAMP_RECORD_TYPE_COMMAND_BUFFER,
+      processor->command_buffer.metadata.record_ordinal, &record->header);
+  record->command_buffer_id =
+      processor->command_buffer.metadata.command_buffer_id;
+  record->block_ordinal = processor->command_buffer.metadata.block_ordinal;
+
+  out_result->command_buffer.start.header =
+      iree_hal_amdgpu_aql_emit_timestamp_start(
+          &processor->command_buffer.packets.start.packet->pm4_ib,
+          processor->command_buffer.packets.start.pm4_ib_slot,
+          processor->command_buffer.packets.start.control,
+          &record->ticks.start_tick, &out_result->command_buffer.start.setup);
+  out_result->command_buffer.end.header =
+      iree_hal_amdgpu_aql_emit_timestamp_end(
+          &processor->command_buffer.packets.end.packet->pm4_ib,
+          processor->command_buffer.packets.end.pm4_ib_slot,
+          processor->command_buffer.packets.end.control,
+          processor->command_buffer.packets.end.completion_signal,
+          &record->ticks.end_tick, &out_result->command_buffer.end.setup);
+}
+
+static iree_status_t
+iree_hal_amdgpu_aql_block_processor_timestamp_validate_dispatches(
+    const iree_hal_amdgpu_aql_block_processor_timestamp_t* processor) {
+  if (!iree_hal_amdgpu_aql_block_processor_timestamp_has_dispatches(
+          processor)) {
+    return iree_ok_status();
+  }
+  if (IREE_UNLIKELY(!processor->dispatches.values)) {
+    return iree_make_status(
+        IREE_STATUS_INVALID_ARGUMENT,
+        "dispatch timestamp mode requires dispatch sidecar records");
+  }
+  if (IREE_UNLIKELY(!processor->harvest.kernel_args ||
+                    !processor->harvest.packet ||
+                    !processor->harvest.kernarg_ptr)) {
+    return iree_make_status(
+        IREE_STATUS_INVALID_ARGUMENT,
+        "dispatch timestamp mode requires a harvest kernel, packet, and "
+        "kernargs");
+  }
+  for (uint32_t i = 0; i < processor->dispatches.count; ++i) {
+    const iree_hal_amdgpu_aql_block_processor_timestamp_dispatch_t* dispatch =
+        &processor->dispatches.values[i];
+    if (IREE_UNLIKELY(dispatch->ordinals.packet_ordinal >=
+                      processor->base.packets.count)) {
+      return iree_make_status(IREE_STATUS_INVALID_ARGUMENT,
+                              "dispatch timestamp packet ordinal %" PRIu32
+                              " exceeds emitted payload packet count %" PRIu32,
+                              dispatch->ordinals.packet_ordinal,
+                              processor->base.packets.count);
+    }
+    if (IREE_UNLIKELY(!dispatch->target.completion_signal ||
+                      !dispatch->target.record)) {
+      return iree_make_status(
+          IREE_STATUS_INVALID_ARGUMENT,
+          "dispatch timestamp mode requires a completion signal and record");
+    }
+  }
+  return iree_ok_status();
+}
+
+static void iree_hal_amdgpu_aql_block_processor_timestamp_emit_dispatches(
+    const iree_hal_amdgpu_aql_block_processor_timestamp_t* processor,
+    iree_hal_amdgpu_aql_block_processor_timestamp_result_t* out_result) {
+  iree_hal_amdgpu_dispatch_timestamp_harvest_source_t* sources =
+      iree_hal_amdgpu_device_timestamp_emplace_dispatch_harvest(
+          processor->harvest.kernel_args, processor->dispatches.count,
+          &processor->harvest.packet->dispatch, processor->harvest.kernarg_ptr);
+  for (uint32_t i = 0; i < processor->dispatches.count; ++i) {
+    const iree_hal_amdgpu_aql_block_processor_timestamp_dispatch_t* dispatch =
+        &processor->dispatches.values[i];
+    iree_hal_amdgpu_dispatch_timestamp_record_t* record =
+        dispatch->target.record;
+    memset(record, 0, sizeof(*record));
+    iree_hal_amdgpu_aql_block_processor_timestamp_initialize_header(
+        sizeof(*record), IREE_HAL_AMDGPU_TIMESTAMP_RECORD_TYPE_DISPATCH,
+        dispatch->ordinals.record_ordinal, &record->header);
+    record->command_buffer_id = dispatch->metadata.command_buffer_id;
+    record->executable_id = dispatch->metadata.executable_id;
+    record->block_ordinal = dispatch->metadata.block_ordinal;
+    record->command_index = dispatch->metadata.command_index;
+    record->export_ordinal = dispatch->metadata.export_ordinal;
+    record->flags = dispatch->metadata.flags;
+
+    iree_hal_amdgpu_aql_packet_t* packet = iree_hal_amdgpu_aql_ring_packet(
+        processor->base.packets.ring,
+        processor->base.packets.first_id + dispatch->ordinals.packet_ordinal);
+    packet->dispatch.completion_signal =
+        iree_hal_amdgpu_aql_block_processor_timestamp_completion_signal(
+            dispatch->target.completion_signal);
+    sources[i].completion_signal = dispatch->target.completion_signal;
+    sources[i].ticks = &record->ticks;
+  }
+
+  processor->harvest.packet->dispatch.completion_signal =
+      processor->harvest.completion_signal;
+  out_result->dispatches.count = processor->dispatches.count;
+  out_result->harvest.header = iree_hal_amdgpu_aql_make_header(
+      IREE_HSA_PACKET_TYPE_KERNEL_DISPATCH, processor->harvest.packet_control);
+  out_result->harvest.setup = processor->harvest.packet->dispatch.setup;
+}
+
+void iree_hal_amdgpu_aql_block_processor_timestamp_initialize(
+    const iree_hal_amdgpu_aql_block_processor_timestamp_t* params,
+    iree_hal_amdgpu_aql_block_processor_timestamp_t* out_processor) {
+  *out_processor = *params;
+}
+
+void iree_hal_amdgpu_aql_block_processor_timestamp_deinitialize(
+    iree_hal_amdgpu_aql_block_processor_timestamp_t* processor) {
+  memset(processor, 0, sizeof(*processor));
+}
+
+iree_status_t
+iree_hal_amdgpu_aql_block_processor_timestamp_dispatch_list_initialize(
+    const iree_hal_amdgpu_aql_block_processor_timestamp_dispatch_list_params_t*
+        params,
+    iree_hal_amdgpu_aql_block_processor_timestamp_dispatch_list_t*
+        out_dispatches) {
+  IREE_ASSERT_ARGUMENT(params);
+  IREE_ASSERT_ARGUMENT(out_dispatches);
+  memset(out_dispatches, 0, sizeof(*out_dispatches));
+  if (params->summaries.count == 0) return iree_ok_status();
+  if (IREE_UNLIKELY(!params->summaries.first || !params->storage.dispatches ||
+                    !params->storage.completion_signals ||
+                    !params->storage.records)) {
+    return iree_make_status(
+        IREE_STATUS_INVALID_ARGUMENT,
+        "dispatch timestamp list requires summaries and target storage");
+  }
+  if (IREE_UNLIKELY(params->metadata.first_record_ordinal >
+                    UINT32_MAX - (params->summaries.count - 1u))) {
+    return iree_make_status(
+        IREE_STATUS_OUT_OF_RANGE,
+        "dispatch timestamp record ordinal range overflows uint32_t");
+  }
+
+  const iree_hal_amdgpu_aql_command_buffer_dispatch_summary_t* summary =
+      params->summaries.first;
+  for (uint32_t summary_ordinal = 0; summary_ordinal < params->summaries.count;
+       ++summary_ordinal) {
+    if (IREE_UNLIKELY(!summary)) {
+      return iree_make_status(
+          IREE_STATUS_INTERNAL,
+          "retained dispatch summary list ended after %" PRIu32 " of %" PRIu32
+          " entries",
+          summary_ordinal, params->summaries.count);
+    }
+
+    iree_hal_amdgpu_aql_block_processor_timestamp_dispatch_t* dispatch =
+        &params->storage.dispatches[summary_ordinal];
+    memset(dispatch, 0, sizeof(*dispatch));
+    dispatch->ordinals.packet_ordinal = summary->packets.dispatch_ordinal;
+    dispatch->ordinals.record_ordinal =
+        params->metadata.first_record_ordinal + summary_ordinal;
+    dispatch->metadata.command_buffer_id = params->metadata.command_buffer_id;
+    dispatch->metadata.executable_id = summary->metadata.executable_id;
+    dispatch->metadata.block_ordinal = params->metadata.block_ordinal;
+    dispatch->metadata.command_index = summary->metadata.command_index;
+    dispatch->metadata.export_ordinal = summary->metadata.export_ordinal;
+    dispatch->metadata.flags =
+        iree_hal_amdgpu_aql_block_processor_timestamp_dispatch_flags(summary);
+    dispatch->target.completion_signal =
+        &params->storage.completion_signals[summary_ordinal];
+    dispatch->target.record = &params->storage.records[summary_ordinal];
+    summary = summary->next;
+  }
+
+  out_dispatches->values = params->storage.dispatches;
+  out_dispatches->count = params->summaries.count;
+  return iree_ok_status();
+}
+
+iree_status_t iree_hal_amdgpu_aql_block_processor_timestamp_invoke(
+    const iree_hal_amdgpu_aql_block_processor_timestamp_t* processor,
+    const iree_hal_amdgpu_command_buffer_block_header_t* block,
+    iree_hal_amdgpu_aql_block_processor_timestamp_result_t* out_result) {
+  memset(out_result, 0, sizeof(*out_result));
+  iree_status_t status =
+      iree_hal_amdgpu_aql_block_processor_timestamp_validate_command_buffer(
+          processor);
+  if (iree_status_is_ok(status)) {
+    status = iree_hal_amdgpu_aql_block_processor_timestamp_validate_dispatches(
+        processor);
+  }
+  if (iree_status_is_ok(status)) {
+    status = iree_hal_amdgpu_aql_block_processor_invoke(&processor->base, block,
+                                                        &out_result->base);
+  }
+  if (iree_status_is_ok(status) &&
+      iree_hal_amdgpu_aql_block_processor_timestamp_has_command_buffer(
+          processor)) {
+    iree_hal_amdgpu_aql_block_processor_timestamp_emit_command_buffer(
+        processor, out_result);
+  }
+  if (iree_status_is_ok(status) &&
+      iree_hal_amdgpu_aql_block_processor_timestamp_has_dispatches(processor)) {
+    iree_hal_amdgpu_aql_block_processor_timestamp_emit_dispatches(processor,
+                                                                  out_result);
+  }
+  return status;
+}
diff --git a/runtime/src/iree/hal/drivers/amdgpu/aql_block_processor_timestamp.h b/runtime/src/iree/hal/drivers/amdgpu/aql_block_processor_timestamp.h
new file mode 100644
index 0000000..3c3e9ad
--- /dev/null
+++ b/runtime/src/iree/hal/drivers/amdgpu/aql_block_processor_timestamp.h
@@ -0,0 +1,219 @@
+// Copyright 2026 The IREE Authors
+//
+// Licensed under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+
+#ifndef IREE_HAL_DRIVERS_AMDGPU_AQL_BLOCK_PROCESSOR_TIMESTAMP_H_
+#define IREE_HAL_DRIVERS_AMDGPU_AQL_BLOCK_PROCESSOR_TIMESTAMP_H_
+
+#include "iree/base/api.h"
+#include "iree/hal/drivers/amdgpu/abi/timestamp.h"
+#include "iree/hal/drivers/amdgpu/aql_block_processor.h"
+#include "iree/hal/drivers/amdgpu/device/timestamp.h"
+#include "iree/hal/drivers/amdgpu/util/pm4_emitter.h"
+
+#ifdef __cplusplus
+extern "C" {
+#endif  // __cplusplus
+
+// Dispatch timestamp sidecar for one already-recorded dispatch packet.
+typedef struct iree_hal_amdgpu_aql_block_processor_timestamp_dispatch_t {
+  // Precomputed ordinals in the payload and timestamp record streams.
+  struct {
+    // Payload packet ordinal whose completion signal must be timestamped.
+    uint32_t packet_ordinal;
+    // Dispatch timestamp record ordinal written into the record header.
+    uint32_t record_ordinal;
+  } ordinals;
+  // Correlation metadata copied into the fixed timestamp record.
+  struct {
+    // Producer-defined command-buffer identifier, or 0 for direct dispatch.
+    uint64_t command_buffer_id;
+    // Producer-defined executable identifier, or 0 when unavailable.
+    uint64_t executable_id;
+    // Command-buffer block ordinal containing this dispatch.
+    uint32_t block_ordinal;
+    // Program-global command index of this dispatch.
+    uint32_t command_index;
+    // Executable export ordinal dispatched.
+    uint32_t export_ordinal;
+    // Flags from iree_hal_amdgpu_dispatch_timestamp_record_flag_bits_t.
+    iree_hal_amdgpu_dispatch_timestamp_record_flags_t flags;
+  } metadata;
+  // Caller-owned storage patched or populated by the timestamp processor.
+  struct {
+    // Raw completion signal that receives CP dispatch timestamps.
+    iree_amd_signal_t* completion_signal;
+    // Fixed binary dispatch timestamp record populated by the processor.
+    iree_hal_amdgpu_dispatch_timestamp_record_t* record;
+  } target;
+} iree_hal_amdgpu_aql_block_processor_timestamp_dispatch_t;
+
+// Dispatch timestamp sidecars in command order.
+typedef struct iree_hal_amdgpu_aql_block_processor_timestamp_dispatch_list_t {
+  // Dispatch timestamp sidecars selected for this block.
+  const iree_hal_amdgpu_aql_block_processor_timestamp_dispatch_t* values;
+  // Number of entries in |values|.
+  uint32_t count;
+} iree_hal_amdgpu_aql_block_processor_timestamp_dispatch_list_t;
+
+// Parameters for materializing dispatch timestamp sidecars from retained
+// command-buffer dispatch summaries.
+typedef struct
+    iree_hal_amdgpu_aql_block_processor_timestamp_dispatch_list_params_t {
+  // Retained command-buffer dispatch summaries in command order.
+  struct {
+    // First retained dispatch summary in a linked block-local list.
+    const iree_hal_amdgpu_aql_command_buffer_dispatch_summary_t* first;
+    // Number of retained dispatch summaries expected in |first|.
+    uint32_t count;
+  } summaries;
+  // Timestamp metadata shared by every materialized dispatch sidecar.
+  struct {
+    // Producer-defined command-buffer identifier used for correlation.
+    uint64_t command_buffer_id;
+    // Command-buffer block ordinal containing these dispatches.
+    uint32_t block_ordinal;
+    // First dispatch timestamp record ordinal assigned to this list.
+    uint32_t first_record_ordinal;
+  } metadata;
+  // Caller-owned storage receiving sidecar records and timestamp targets.
+  struct {
+    // Sidecar array with capacity for |summaries.count| entries.
+    iree_hal_amdgpu_aql_block_processor_timestamp_dispatch_t* dispatches;
+    // Raw completion signal targets with capacity for |summaries.count|
+    // entries.
+    iree_amd_signal_t* completion_signals;
+    // Fixed timestamp records with capacity for |summaries.count| entries.
+    iree_hal_amdgpu_dispatch_timestamp_record_t* records;
+  } storage;
+} iree_hal_amdgpu_aql_block_processor_timestamp_dispatch_list_params_t;
+
+// Opt-in timestamp processor for one AQL command-buffer block.
+typedef struct iree_hal_amdgpu_aql_block_processor_timestamp_t {
+  // Base payload processor configuration. Its packet span excludes timestamp
+  // prefix, suffix, and harvest packets.
+  iree_hal_amdgpu_aql_block_processor_t base;
+  // Optional command-buffer/block timestamp record and PM4 packets.
+  struct {
+    // Correlation metadata copied into the fixed timestamp record.
+    struct {
+      // Command-buffer timestamp record ordinal written into the record header.
+      uint32_t record_ordinal;
+      // Producer-defined command-buffer identifier used for correlation.
+      uint64_t command_buffer_id;
+      // Command-buffer block ordinal, or UINT32_MAX for whole-execute records.
+      uint32_t block_ordinal;
+    } metadata;
+    // Caller-owned storage receiving command-buffer timestamp data.
+    struct {
+      // Fixed binary command-buffer timestamp record, or NULL when disabled.
+      iree_hal_amdgpu_command_buffer_timestamp_record_t* record;
+    } target;
+    // PM4 timestamp packets owned by the enclosing submission.
+    struct {
+      // Start timestamp packet emitted before the payload span.
+      struct {
+        // AQL packet receiving the top-of-pipe timestamp PM4 IB envelope.
+        iree_hal_amdgpu_aql_packet_t* packet;
+        // PM4 IB slot referenced by |packet|.
+        iree_hal_amdgpu_pm4_ib_slot_t* pm4_ib_slot;
+        // Packet control used when publishing |packet|.
+        iree_hal_amdgpu_aql_packet_control_t control;
+      } start;
+      // End timestamp packet emitted after the payload and harvest spans.
+      struct {
+        // AQL packet receiving the bottom-of-pipe timestamp PM4 IB envelope.
+        iree_hal_amdgpu_aql_packet_t* packet;
+        // PM4 IB slot referenced by |packet|.
+        iree_hal_amdgpu_pm4_ib_slot_t* pm4_ib_slot;
+        // Packet control used when publishing |packet|.
+        iree_hal_amdgpu_aql_packet_control_t control;
+        // Optional completion signal decremented when |packet| completes.
+        iree_hsa_signal_t completion_signal;
+      } end;
+    } packets;
+  } command_buffer;
+  // Optional dispatch timestamp records patched into payload packets.
+  iree_hal_amdgpu_aql_block_processor_timestamp_dispatch_list_t dispatches;
+  // Optional dispatch-timestamp harvest packet and kernargs.
+  struct {
+    // Builtin harvest kernel arguments.
+    const iree_hal_amdgpu_device_kernel_args_t* kernel_args;
+    // AQL dispatch packet emitted after payload dispatches complete.
+    iree_hal_amdgpu_aql_packet_t* packet;
+    // Queue-owned kernarg storage for the harvest dispatch.
+    void* kernarg_ptr;
+    // Packet control used when publishing |packet|.
+    iree_hal_amdgpu_aql_packet_control_t packet_control;
+    // Optional completion signal decremented when |packet| completes.
+    iree_hsa_signal_t completion_signal;
+  } harvest;
+} iree_hal_amdgpu_aql_block_processor_timestamp_t;
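+
+// NOTE: both timestamp features are independently opt-in: command-buffer
+// timestamps are enabled by a non-NULL command_buffer.target.record (which
+// then requires both start and end PM4 packet storage) and dispatch
+// timestamps are enabled by a non-zero dispatches.count (which then requires
+// the harvest kernel args, packet, and kernargs). Partial configurations are
+// rejected by invoke-time validation before any payload packets are written.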
+
+// Result of invoking the timestamp processor on one block.
+typedef struct iree_hal_amdgpu_aql_block_processor_timestamp_result_t {
+  // Result reported by the embedded base payload processor.
+  iree_hal_amdgpu_aql_block_processor_result_t base;
+  // Command-buffer timestamp packet metadata produced when enabled.
+  struct {
+    // Start timestamp packet commit metadata.
+    struct {
+      // AQL packet header for the start timestamp packet.
+      uint16_t header;
+      // AQL packet setup word for the start timestamp packet.
+      uint16_t setup;
+    } start;
+    // End timestamp packet commit metadata.
+    struct {
+      // AQL packet header for the end timestamp packet.
+      uint16_t header;
+      // AQL packet setup word for the end timestamp packet.
+      uint16_t setup;
+    } end;
+  } command_buffer;
+  // Dispatch timestamp accounting produced when enabled.
+  struct {
+    // Number of dispatch timestamp records initialized and harvest sources
+    // populated.
+    uint32_t count;
+  } dispatches;
+  // Dispatch timestamp harvest packet metadata produced when enabled.
+  struct {
+    // AQL packet header for the harvest dispatch.
+    uint16_t header;
+    // AQL packet setup word for the harvest dispatch.
+    uint16_t setup;
+  } harvest;
+} iree_hal_amdgpu_aql_block_processor_timestamp_result_t;
+
+// Initializes |out_processor| from |params|. All referenced storage is
+// borrowed from the enclosing submission and must remain live for the
+// lifetime of the processor.
+void iree_hal_amdgpu_aql_block_processor_timestamp_initialize(
+    const iree_hal_amdgpu_aql_block_processor_timestamp_t* params,
+    iree_hal_amdgpu_aql_block_processor_timestamp_t* out_processor);
+
+// Deinitializes |processor|. This currently releases no resources.
+void iree_hal_amdgpu_aql_block_processor_timestamp_deinitialize(
+    iree_hal_amdgpu_aql_block_processor_timestamp_t* processor);
+
+// Materializes dispatch timestamp sidecars from retained command-buffer
+// dispatch summaries and caller-owned target storage.
+iree_status_t
+iree_hal_amdgpu_aql_block_processor_timestamp_dispatch_list_initialize(
+    const iree_hal_amdgpu_aql_block_processor_timestamp_dispatch_list_params_t*
+        params,
+    iree_hal_amdgpu_aql_block_processor_timestamp_dispatch_list_t*
+        out_dispatches);
+
+// Invokes |processor| on |block| and populates payload packets plus any
+// timestamp sidecars selected by the caller.
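+//
+// A minimal usage sketch (mirroring aql_block_processor_timestamp_test.cc;
+// population of |params| and the processor fields is documented above and
+// remains the caller's responsibility):
+//   iree_hal_amdgpu_aql_block_processor_timestamp_dispatch_list_t list;
+//   IREE_RETURN_IF_ERROR(
+//       iree_hal_amdgpu_aql_block_processor_timestamp_dispatch_list_initialize(
+//           &params, &list));
+//   processor.dispatches = list;
+//   iree_hal_amdgpu_aql_block_processor_timestamp_result_t result;
+//   IREE_RETURN_IF_ERROR(iree_hal_amdgpu_aql_block_processor_timestamp_invoke(
+//       &processor, block, &result));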
+iree_status_t iree_hal_amdgpu_aql_block_processor_timestamp_invoke(
+    const iree_hal_amdgpu_aql_block_processor_timestamp_t* processor,
+    const iree_hal_amdgpu_command_buffer_block_header_t* block,
+    iree_hal_amdgpu_aql_block_processor_timestamp_result_t* out_result);
+
+#ifdef __cplusplus
+}  // extern "C"
+#endif  // __cplusplus
+
+#endif  // IREE_HAL_DRIVERS_AMDGPU_AQL_BLOCK_PROCESSOR_TIMESTAMP_H_
diff --git a/runtime/src/iree/hal/drivers/amdgpu/aql_block_processor_timestamp_test.cc b/runtime/src/iree/hal/drivers/amdgpu/aql_block_processor_timestamp_test.cc
new file mode 100644
index 0000000..166b98e
--- /dev/null
+++ b/runtime/src/iree/hal/drivers/amdgpu/aql_block_processor_timestamp_test.cc
@@ -0,0 +1,405 @@
+// Copyright 2026 The IREE Authors
+//
+// Licensed under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+
+#include "iree/hal/drivers/amdgpu/aql_block_processor_timestamp.h"
+
+#include <cstring>
+
+#include "iree/hal/drivers/amdgpu/util/aql_emitter.h"
+#include "iree/testing/gtest.h"
+#include "iree/testing/status_matchers.h"
+
+namespace iree::hal::amdgpu {
+namespace {
+
+struct DirectDispatchBlock {
+  // Block header at the ABI-defined block base.
+  iree_hal_amdgpu_command_buffer_block_header_t header;
+  // Custom-direct dispatch command under test.
+  iree_hal_amdgpu_command_buffer_dispatch_command_t dispatch_command;
+  // Inline custom-direct kernarg tail copied into queue-owned kernargs.
+  uint64_t tail[2];
+  // Return terminator following the dispatch command and inline tail.
+  iree_hal_amdgpu_command_buffer_return_command_t return_command;
+};
+
+static void InitializeBlockHeader(
+    uint32_t block_length, uint32_t command_length, uint16_t command_count,
+    uint32_t aql_packet_count, uint32_t kernarg_length,
+    iree_hal_amdgpu_command_buffer_block_header_t* out_header) {
+  std::memset(out_header, 0, sizeof(*out_header));
+  out_header->magic = IREE_HAL_AMDGPU_COMMAND_BUFFER_BLOCK_MAGIC;
+  out_header->version = IREE_HAL_AMDGPU_COMMAND_BUFFER_BLOCK_VERSION_0;
+  out_header->header_length = sizeof(*out_header);
+  out_header->block_length = block_length;
+  out_header->command_offset = sizeof(*out_header);
+  out_header->command_length = command_length;
+  out_header->command_count = command_count;
+  out_header->aql_packet_count = aql_packet_count;
+  out_header->kernarg_length = kernarg_length;
+  out_header->initial_barrier_packet_count = aql_packet_count;
+  out_header->binding_source_offset = block_length;
+  out_header->rodata_offset = block_length;
+}
+
+static void InitializeReturnCommand(
+    iree_hal_amdgpu_command_buffer_return_command_t* out_command) {
+  std::memset(out_command, 0, sizeof(*out_command));
+  out_command->header.opcode = IREE_HAL_AMDGPU_COMMAND_BUFFER_OPCODE_RETURN;
+  out_command->header.length_qwords =
+      sizeof(*out_command) / IREE_HAL_AMDGPU_COMMAND_BUFFER_RECORD_ALIGNMENT;
+}
+
+static DirectDispatchBlock MakeDirectDispatchBlock() {
+  DirectDispatchBlock block;
+  const uint32_t dispatch_command_length =
+      sizeof(block.dispatch_command) + sizeof(block.tail);
+  const uint32_t command_length =
+      dispatch_command_length + sizeof(block.return_command);
+  InitializeBlockHeader(sizeof(block), command_length, /*command_count=*/2,
+                        /*aql_packet_count=*/1,
+                        /*kernarg_length=*/sizeof(block.tail), &block.header);
+
+  std::memset(&block.dispatch_command, 0, sizeof(block.dispatch_command));
+  block.dispatch_command.header.opcode =
+      IREE_HAL_AMDGPU_COMMAND_BUFFER_OPCODE_DISPATCH;
+  block.dispatch_command.header.flags =
+      IREE_HAL_AMDGPU_COMMAND_BUFFER_COMMAND_FLAG_USES_QUEUE_KERNARGS;
+  block.dispatch_command.header.command_index = 12;
+  block.dispatch_command.header.length_qwords =
+      dispatch_command_length / IREE_HAL_AMDGPU_COMMAND_BUFFER_RECORD_ALIGNMENT;
+  block.dispatch_command.kernel_object = 0x123456789ABCDEF0ull;
+  block.dispatch_command.payload_reference = sizeof(block.dispatch_command);
+  block.dispatch_command.kernarg_length_qwords =
+      sizeof(block.tail) / sizeof(uint64_t);
+  block.dispatch_command.payload.tail_length_qwords =
+      sizeof(block.tail) / sizeof(uint64_t);
+  block.dispatch_command.kernarg_strategy =
+      IREE_HAL_AMDGPU_COMMAND_BUFFER_KERNARG_STRATEGY_CUSTOM_DIRECT;
+  block.dispatch_command.implicit_args_offset_qwords = UINT16_MAX;
+  block.dispatch_command.setup = 3;
+  block.dispatch_command.workgroup_size[0] = 4;
+  block.dispatch_command.workgroup_size[1] = 2;
+  block.dispatch_command.workgroup_size[2] = 1;
+  block.dispatch_command.grid_size[0] = 64;
+  block.dispatch_command.grid_size[1] = 8;
+  block.dispatch_command.grid_size[2] = 1;
+  block.tail[0] = 0x0A0B0C0D0E0F1011ull;
+  block.tail[1] = 0x1213141516171819ull;
+
+  block.header.dispatch_count = 1;
+  block.header.terminator_opcode = IREE_HAL_AMDGPU_COMMAND_BUFFER_OPCODE_RETURN;
+  InitializeReturnCommand(&block.return_command);
+  return block;
+}
+
+static iree_hal_amdgpu_device_kernel_args_t MakeHarvestKernelArgs() {
+  iree_hal_amdgpu_device_kernel_args_t kernel_args = {};
+  kernel_args.kernel_object = 0x12345678ull;
+  kernel_args.setup = 2;
+  kernel_args.workgroup_size[0] = 32;
+  kernel_args.workgroup_size[1] = 1;
+  kernel_args.workgroup_size[2] = 1;
+  kernel_args.kernarg_alignment = 16;
+  return kernel_args;
+}
+
+static iree_hal_amdgpu_aql_block_processor_t MakeBaseProcessor(
+    iree_hal_amdgpu_aql_ring_t* ring, uint16_t* packet_headers,
+    uint16_t* packet_setups, iree_hal_amdgpu_kernarg_block_t* kernarg_blocks) {
+  iree_hal_amdgpu_aql_block_processor_t processor = {};
+  processor.packets.ring = ring;
+  processor.packets.first_id = 4;
+  processor.packets.index_base = 0;
+  processor.packets.count = 1;
+  processor.packets.headers = packet_headers;
+  processor.packets.setups = packet_setups;
+  processor.kernargs.blocks = kernarg_blocks;
+  processor.kernargs.count = 1;
+  processor.submission.signal_release_scope = IREE_HSA_FENCE_SCOPE_SYSTEM;
+  processor.payload.acquire_scope = IREE_HSA_FENCE_SCOPE_NONE;
+  processor.flags =
+      IREE_HAL_AMDGPU_AQL_BLOCK_PROCESSOR_FLAG_FINAL_PAYLOAD_PACKET;
+  return processor;
+}
+
+static uint16_t AqlHeaderField(uint16_t header, uint32_t bit_offset,
+                               uint32_t bit_width) {
+  return (header >> bit_offset) & ((1u << bit_width) - 1u);
+}
+
+static iree_hsa_packet_type_t AqlHeaderType(uint16_t header) {
+  return (iree_hsa_packet_type_t)AqlHeaderField(
+      header, IREE_HSA_PACKET_HEADER_TYPE, IREE_HSA_PACKET_HEADER_WIDTH_TYPE);
+}
+
+TEST(AqlBlockProcessorTimestampTest,
+     CommandBufferTimestampInitializesRecordAndPackets) {
+  DirectDispatchBlock block = MakeDirectDispatchBlock();
+  alignas(64) iree_hal_amdgpu_aql_packet_t packets[8] = {};
+  iree_hal_amdgpu_aql_ring_t ring = {};
+  ring.base = packets;
+  ring.mask = IREE_ARRAYSIZE(packets) - 1u;
+  iree_hal_amdgpu_pm4_ib_slot_t pm4_ib_slots[8] = {};
+  uint16_t packet_headers[1] = {};
+  uint16_t packet_setups[1] = {};
+  iree_hal_amdgpu_kernarg_block_t kernarg_blocks[1] = {};
+  iree_hal_amdgpu_command_buffer_timestamp_record_t record = {};
+
+  iree_hal_amdgpu_aql_block_processor_timestamp_t processor = {};
+  processor.base =
+      MakeBaseProcessor(&ring, packet_headers, packet_setups, kernarg_blocks);
+  processor.command_buffer.metadata.record_ordinal = 7;
+  processor.command_buffer.metadata.command_buffer_id = 0xCAFEull;
+  processor.command_buffer.metadata.block_ordinal = 3;
+  processor.command_buffer.target.record = &record;
+  processor.command_buffer.packets.start.packet = &packets[2];
+  processor.command_buffer.packets.start.pm4_ib_slot = &pm4_ib_slots[2];
+  processor.command_buffer.packets.start.control =
+      iree_hal_amdgpu_aql_packet_control_barrier(IREE_HSA_FENCE_SCOPE_AGENT,
+                                                 IREE_HSA_FENCE_SCOPE_NONE);
+  processor.command_buffer.packets.end.packet = &packets[6];
+  processor.command_buffer.packets.end.pm4_ib_slot = &pm4_ib_slots[6];
+  processor.command_buffer.packets.end.control =
+      iree_hal_amdgpu_aql_packet_control_barrier(IREE_HSA_FENCE_SCOPE_NONE,
+                                                 IREE_HSA_FENCE_SCOPE_SYSTEM);
+  processor.command_buffer.packets.end.completion_signal.handle = 0x1234;
+
+  iree_hal_amdgpu_aql_block_processor_timestamp_result_t result;
+  IREE_ASSERT_OK(iree_hal_amdgpu_aql_block_processor_timestamp_invoke(
+      &processor, &block.header, &result));
+
+  EXPECT_EQ(result.base.terminator,
+            IREE_HAL_AMDGPU_AQL_BLOCK_PROCESSOR_TERMINATOR_RETURN);
+  EXPECT_EQ(record.header.record_length, sizeof(record));
+  EXPECT_EQ(record.header.version, IREE_HAL_AMDGPU_TIMESTAMP_RECORD_VERSION_0);
+  EXPECT_EQ(record.header.type,
+            IREE_HAL_AMDGPU_TIMESTAMP_RECORD_TYPE_COMMAND_BUFFER);
+  EXPECT_EQ(record.header.record_ordinal, 7u);
+  EXPECT_EQ(record.command_buffer_id, 0xCAFEull);
+  EXPECT_EQ(record.block_ordinal, 3u);
+  EXPECT_EQ(record.ticks.start_tick, 0u);
+  EXPECT_EQ(record.ticks.end_tick, 0u);
+
+  EXPECT_EQ(AqlHeaderType(result.command_buffer.start.header),
+            IREE_HSA_PACKET_TYPE_VENDOR_SPECIFIC);
+  EXPECT_EQ(result.command_buffer.start.setup, IREE_HSA_AMD_AQL_FORMAT_PM4_IB);
+  EXPECT_EQ(packets[2].pm4_ib.completion_signal.handle, 0u);
+  EXPECT_EQ(AqlHeaderType(result.command_buffer.end.header),
+            IREE_HSA_PACKET_TYPE_VENDOR_SPECIFIC);
+  EXPECT_EQ(result.command_buffer.end.setup, IREE_HSA_AMD_AQL_FORMAT_PM4_IB);
+  EXPECT_EQ(packets[6].pm4_ib.completion_signal.handle, 0x1234u);
+}
+
+TEST(AqlBlockProcessorTimestampTest,
+     DispatchTimestampPatchesCompletionSignalAndHarvestSource) {
+  DirectDispatchBlock block = MakeDirectDispatchBlock();
+  alignas(64) iree_hal_amdgpu_aql_packet_t packets[8] = {};
+  iree_hal_amdgpu_aql_ring_t ring = {};
+  ring.base = packets;
+  ring.mask = IREE_ARRAYSIZE(packets) - 1u;
+  uint16_t packet_headers[1] = {};
+  uint16_t packet_setups[1] = {};
+  iree_hal_amdgpu_kernarg_block_t kernarg_blocks[2] = {};
+  iree_amd_signal_t completion_signal = {};
+  iree_hal_amdgpu_dispatch_timestamp_record_t record = {};
+  iree_hal_amdgpu_device_kernel_args_t harvest_kernel_args =
+      MakeHarvestKernelArgs();
+
+  iree_hal_amdgpu_aql_block_processor_timestamp_dispatch_t dispatch = {};
+  dispatch.ordinals.packet_ordinal = 0;
+  dispatch.ordinals.record_ordinal = 5;
+  dispatch.metadata.command_buffer_id = 0xBEEF;
+  dispatch.metadata.executable_id = 0xFEED;
+  dispatch.metadata.block_ordinal = 9;
+  dispatch.metadata.command_index = block.dispatch_command.header.command_index;
+  dispatch.metadata.export_ordinal = 4;
+  dispatch.target.completion_signal = &completion_signal;
+  dispatch.target.record = &record;
+
+  iree_hal_amdgpu_aql_block_processor_timestamp_t processor = {};
+  processor.base =
+      MakeBaseProcessor(&ring, packet_headers, packet_setups, kernarg_blocks);
+  processor.dispatches.values = &dispatch;
+  processor.dispatches.count = 1;
+  processor.harvest.kernel_args = &harvest_kernel_args;
+  processor.harvest.packet = &packets[5];
+  processor.harvest.kernarg_ptr = kernarg_blocks[1].data;
+  processor.harvest.packet_control = iree_hal_amdgpu_aql_packet_control_barrier(
+      IREE_HSA_FENCE_SCOPE_AGENT, IREE_HSA_FENCE_SCOPE_SYSTEM);
+  processor.harvest.completion_signal.handle = 0x1234;
+
+  iree_hal_amdgpu_aql_block_processor_timestamp_result_t result;
+  IREE_ASSERT_OK(iree_hal_amdgpu_aql_block_processor_timestamp_invoke(
+      &processor, &block.header, &result));
+
+  EXPECT_EQ(result.base.packets.emitted, 1u);
+  EXPECT_EQ(result.dispatches.count, 1u);
+  EXPECT_EQ(packets[4].dispatch.completion_signal.handle,
+            (uint64_t)(uintptr_t)&completion_signal);
+  EXPECT_EQ(record.header.record_length, sizeof(record));
+  EXPECT_EQ(record.header.version, IREE_HAL_AMDGPU_TIMESTAMP_RECORD_VERSION_0);
+  EXPECT_EQ(record.header.type, IREE_HAL_AMDGPU_TIMESTAMP_RECORD_TYPE_DISPATCH);
+  EXPECT_EQ(record.header.record_ordinal, 5u);
+  EXPECT_EQ(record.command_buffer_id, 0xBEEFull);
+  EXPECT_EQ(record.executable_id, 0xFEEDull);
+  EXPECT_EQ(record.block_ordinal, 9u);
+  EXPECT_EQ(record.command_index, block.dispatch_command.header.command_index);
+  EXPECT_EQ(record.export_ordinal, 4u);
+  EXPECT_EQ(record.flags, IREE_HAL_AMDGPU_DISPATCH_TIMESTAMP_RECORD_FLAG_NONE);
+  EXPECT_EQ(record.ticks.start_tick, 0u);
+  EXPECT_EQ(record.ticks.end_tick, 0u);
+
+  const auto* harvest_args = reinterpret_cast<
+      const iree_hal_amdgpu_dispatch_timestamp_harvest_args_t*>(
+      kernarg_blocks[1].data);
+  ASSERT_EQ(harvest_args->source_count, 1u);
+  const iree_hal_amdgpu_dispatch_timestamp_harvest_source_t* source =
+      harvest_args->sources;
+  ASSERT_NE(source, nullptr);
+  EXPECT_EQ(source[0].completion_signal, &completion_signal);
+  EXPECT_EQ(source[0].ticks, &record.ticks);
+  EXPECT_EQ(packets[5].dispatch.completion_signal.handle, 0x1234u);
+  EXPECT_EQ(AqlHeaderType(result.harvest.header),
+            IREE_HSA_PACKET_TYPE_KERNEL_DISPATCH);
+  EXPECT_EQ(result.harvest.setup, harvest_kernel_args.setup);
+}
+
+TEST(AqlBlockProcessorTimestampTest,
+     DispatchListInitializesSidecarsFromSummaries) {
+  iree_hal_amdgpu_aql_command_buffer_dispatch_summary_t summaries[2] = {};
+  summaries[0].next = &summaries[1];
+  summaries[0].packets.first_ordinal = 0;
+  summaries[0].packets.dispatch_ordinal = 0;
+  summaries[0].metadata.executable_id = 0xA0;
+  summaries[0].metadata.command_index = 12;
+  summaries[0].metadata.export_ordinal = 2;
+  summaries[0].metadata.dispatch_flags =
+      IREE_HAL_AMDGPU_COMMAND_BUFFER_DISPATCH_FLAG_NONE;
+  summaries[1].packets.first_ordinal = 1;
+  summaries[1].packets.dispatch_ordinal = 2;
+  summaries[1].metadata.executable_id = 0xB0;
+  summaries[1].metadata.command_index = 13;
+  summaries[1].metadata.export_ordinal = 4;
+  summaries[1].metadata.dispatch_flags =
+      IREE_HAL_AMDGPU_COMMAND_BUFFER_DISPATCH_FLAG_INDIRECT_PARAMETERS;
+
+  iree_hal_amdgpu_aql_block_processor_timestamp_dispatch_t dispatches[2] = {};
+  iree_amd_signal_t completion_signals[2] = {};
+  iree_hal_amdgpu_dispatch_timestamp_record_t records[2] = {};
+  const iree_hal_amdgpu_aql_block_processor_timestamp_dispatch_list_params_t
+      params = {
+          .summaries =
+              {
+                  .first = summaries,
+                  .count = 2,
+              },
+          .metadata =
+              {
+                  .command_buffer_id = 0xCAFE,
+                  .block_ordinal = 5,
+                  .first_record_ordinal = 9,
+              },
+          .storage =
+              {
+                  .dispatches = dispatches,
+                  .completion_signals = completion_signals,
+                  .records = records,
+              },
+      };
+
+  iree_hal_amdgpu_aql_block_processor_timestamp_dispatch_list_t list;
+  IREE_ASSERT_OK(
+      iree_hal_amdgpu_aql_block_processor_timestamp_dispatch_list_initialize(
+          &params, &list));
+
+  ASSERT_EQ(list.values, dispatches);
+  ASSERT_EQ(list.count, 2u);
+  EXPECT_EQ(dispatches[0].ordinals.packet_ordinal, 0u);
+  EXPECT_EQ(dispatches[0].ordinals.record_ordinal, 9u);
+  EXPECT_EQ(dispatches[0].metadata.command_buffer_id, 0xCAFEull);
+  EXPECT_EQ(dispatches[0].metadata.executable_id, 0xA0ull);
+  EXPECT_EQ(dispatches[0].metadata.block_ordinal, 5u);
+  EXPECT_EQ(dispatches[0].metadata.command_index, 12u);
+  EXPECT_EQ(dispatches[0].metadata.export_ordinal, 2u);
+  EXPECT_EQ(dispatches[0].metadata.flags,
+            IREE_HAL_AMDGPU_DISPATCH_TIMESTAMP_RECORD_FLAG_NONE);
+  EXPECT_EQ(dispatches[0].target.completion_signal, &completion_signals[0]);
+  EXPECT_EQ(dispatches[0].target.record, &records[0]);
+
+  EXPECT_EQ(dispatches[1].ordinals.packet_ordinal, 2u);
+  EXPECT_EQ(dispatches[1].ordinals.record_ordinal, 10u);
+  EXPECT_EQ(dispatches[1].metadata.command_buffer_id, 0xCAFEull);
+  EXPECT_EQ(dispatches[1].metadata.executable_id, 0xB0ull);
+  EXPECT_EQ(dispatches[1].metadata.block_ordinal, 5u);
+  EXPECT_EQ(dispatches[1].metadata.command_index, 13u);
+  EXPECT_EQ(dispatches[1].metadata.export_ordinal, 4u);
+  EXPECT_EQ(dispatches[1].metadata.flags,
+            IREE_HAL_AMDGPU_DISPATCH_TIMESTAMP_RECORD_FLAG_INDIRECT_PARAMETERS);
+  EXPECT_EQ(dispatches[1].target.completion_signal, &completion_signals[1]);
+  EXPECT_EQ(dispatches[1].target.record, &records[1]);
+}
+
+TEST(AqlBlockProcessorTimestampTest, RejectsPartialCommandBufferTimestampPlan) {
+  DirectDispatchBlock block = MakeDirectDispatchBlock();
+  alignas(64) iree_hal_amdgpu_aql_packet_t packets[8] = {};
+  iree_hal_amdgpu_aql_ring_t ring = {};
+  ring.base = packets;
+  ring.mask = IREE_ARRAYSIZE(packets) - 1u;
+  iree_hal_amdgpu_pm4_ib_slot_t pm4_ib_slots[8] = {};
+  uint16_t packet_headers[1] = {};
+  uint16_t packet_setups[1] = {};
+  iree_hal_amdgpu_kernarg_block_t kernarg_blocks[1] = {};
+
+  iree_hal_amdgpu_aql_block_processor_timestamp_t processor = {};
+  processor.base =
+      MakeBaseProcessor(&ring, packet_headers, packet_setups, kernarg_blocks);
+  processor.command_buffer.packets.start.packet = &packets[2];
+  processor.command_buffer.packets.start.pm4_ib_slot = &pm4_ib_slots[2];
+
+  iree_hal_amdgpu_aql_block_processor_timestamp_result_t result;
+  IREE_EXPECT_STATUS_IS(IREE_STATUS_INVALID_ARGUMENT,
+                        iree_hal_amdgpu_aql_block_processor_timestamp_invoke(
+                            &processor, &block.header, &result));
+  EXPECT_EQ(packets[4].dispatch.kernel_object, 0u);
+}
+
+TEST(AqlBlockProcessorTimestampTest, RejectsOutOfRangeDispatchPacketOrdinal) {
+  DirectDispatchBlock block = MakeDirectDispatchBlock();
+  alignas(64) iree_hal_amdgpu_aql_packet_t packets[8] = {};
+  iree_hal_amdgpu_aql_ring_t ring = {};
+  ring.base = packets;
+  ring.mask = IREE_ARRAYSIZE(packets) - 1u;
+  uint16_t packet_headers[1] = {};
+  uint16_t packet_setups[1] = {};
+  iree_hal_amdgpu_kernarg_block_t kernarg_blocks[2] = {};
+  iree_amd_signal_t completion_signal = {};
+  iree_hal_amdgpu_dispatch_timestamp_record_t record = {};
+  iree_hal_amdgpu_device_kernel_args_t harvest_kernel_args =
+      MakeHarvestKernelArgs();
+
+  iree_hal_amdgpu_aql_block_processor_timestamp_dispatch_t dispatch = {};
+  dispatch.ordinals.packet_ordinal = 1;
+  dispatch.target.completion_signal = &completion_signal;
+  dispatch.target.record = &record;
+
+  iree_hal_amdgpu_aql_block_processor_timestamp_t processor = {};
+  processor.base =
+      MakeBaseProcessor(&ring, packet_headers, packet_setups, kernarg_blocks);
+  processor.dispatches.values = &dispatch;
+  processor.dispatches.count = 1;
+  processor.harvest.kernel_args = &harvest_kernel_args;
+  processor.harvest.packet = &packets[5];
+  processor.harvest.kernarg_ptr = kernarg_blocks[1].data;
+
+  iree_hal_amdgpu_aql_block_processor_timestamp_result_t result;
+  IREE_EXPECT_STATUS_IS(IREE_STATUS_INVALID_ARGUMENT,
+                        iree_hal_amdgpu_aql_block_processor_timestamp_invoke(
+                            &processor, &block.header, &result));
+}
+
+}  // namespace
+}  // namespace iree::hal::amdgpu
diff --git a/runtime/src/iree/hal/drivers/amdgpu/aql_command_buffer.c b/runtime/src/iree/hal/drivers/amdgpu/aql_command_buffer.c
new file mode 100644
index 0000000..aa9cbdd
--- /dev/null
+++ b/runtime/src/iree/hal/drivers/amdgpu/aql_command_buffer.c
@@ -0,0 +1,2724 @@
+// Copyright 2026 The IREE Authors
+//
+// Licensed under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+
+#include "iree/hal/drivers/amdgpu/aql_command_buffer.h"
+
+#include <string.h>
+
+#include "iree/base/alignment.h"
+#include "iree/hal/drivers/amdgpu/abi/queue.h"
+#include "iree/hal/drivers/amdgpu/aql_command_buffer_profile.h"
+#include "iree/hal/drivers/amdgpu/buffer.h"
+#include "iree/hal/drivers/amdgpu/device/blit.h"
+#include "iree/hal/drivers/amdgpu/executable.h"
+#include "iree/hal/drivers/amdgpu/transient_buffer.h"
+#include "iree/hal/drivers/amdgpu/util/kernarg_ring.h"
+#include "iree/hal/utils/resource_set.h"
+
+//===----------------------------------------------------------------------===//
+// iree_hal_amdgpu_aql_command_buffer_t
+//===----------------------------------------------------------------------===//
+
+enum {
+  IREE_HAL_AMDGPU_AQL_COMMAND_BUFFER_STATIC_BUFFER_PAGE_CAPACITY_LOG2 = 9,
+  IREE_HAL_AMDGPU_AQL_COMMAND_BUFFER_STATIC_BUFFER_PAGE_CAPACITY = 512,
+  IREE_HAL_AMDGPU_AQL_COMMAND_BUFFER_STATIC_BUFFER_PAGE_MASK =
+      IREE_HAL_AMDGPU_AQL_COMMAND_BUFFER_STATIC_BUFFER_PAGE_CAPACITY - 1,
+  IREE_HAL_AMDGPU_AQL_COMMAND_BUFFER_RODATA_SEGMENT_PAGE_CAPACITY_LOG2 = 9,
+  IREE_HAL_AMDGPU_AQL_COMMAND_BUFFER_RODATA_SEGMENT_PAGE_CAPACITY = 512,
+  IREE_HAL_AMDGPU_AQL_COMMAND_BUFFER_RODATA_SEGMENT_PAGE_MASK =
+      IREE_HAL_AMDGPU_AQL_COMMAND_BUFFER_RODATA_SEGMENT_PAGE_CAPACITY - 1,
+};
+
+typedef enum iree_hal_amdgpu_aql_command_buffer_rodata_segment_flag_bits_e {
+  IREE_HAL_AMDGPU_AQL_COMMAND_BUFFER_RODATA_SEGMENT_FLAG_NONE = 0u,
+  IREE_HAL_AMDGPU_AQL_COMMAND_BUFFER_RODATA_SEGMENT_FLAG_PREPUBLISHED_KERNARGS =
+      1u << 0,
+} iree_hal_amdgpu_aql_command_buffer_rodata_segment_flag_bits_t;
+
+typedef uint32_t iree_hal_amdgpu_aql_command_buffer_rodata_segment_flags_t;
+
+typedef enum iree_hal_amdgpu_aql_command_buffer_recording_state_e {
+  IREE_HAL_AMDGPU_AQL_COMMAND_BUFFER_RECORDING_STATE_INITIAL = 0,
+  IREE_HAL_AMDGPU_AQL_COMMAND_BUFFER_RECORDING_STATE_RECORDING = 1,
+  IREE_HAL_AMDGPU_AQL_COMMAND_BUFFER_RECORDING_STATE_FINALIZED = 2,
+  IREE_HAL_AMDGPU_AQL_COMMAND_BUFFER_RECORDING_STATE_FAILED = 3,
+} iree_hal_amdgpu_aql_command_buffer_recording_state_t;
+
+typedef struct iree_hal_amdgpu_aql_command_buffer_static_buffer_page_t {
+  // Next page in ordinal order.
+  struct iree_hal_amdgpu_aql_command_buffer_static_buffer_page_t* next;
+  // Number of valid entries in |buffers|.
+  uint32_t count;
+  // Reserved bits for future page metadata.
+  uint32_t reserved0;
+  // Fixed-capacity direct buffer table.
+  iree_hal_buffer_t*
+      buffers[IREE_HAL_AMDGPU_AQL_COMMAND_BUFFER_STATIC_BUFFER_PAGE_CAPACITY];
+} iree_hal_amdgpu_aql_command_buffer_static_buffer_page_t;
+static_assert((IREE_HAL_AMDGPU_AQL_COMMAND_BUFFER_STATIC_BUFFER_PAGE_CAPACITY &
+               IREE_HAL_AMDGPU_AQL_COMMAND_BUFFER_STATIC_BUFFER_PAGE_MASK) == 0,
+              "static buffer page capacity must be a power of two");
+static_assert(
+    sizeof(iree_hal_amdgpu_aql_command_buffer_static_buffer_page_t) <=
+        IREE_HAL_AMDGPU_AQL_PROGRAM_DEFAULT_BLOCK_SIZE,
+    "static buffer page should fit in a default command-buffer block");
+
+typedef struct iree_hal_amdgpu_aql_command_buffer_rodata_segment_t {
+  // Command-buffer-owned immutable payload bytes.
+  uint8_t* data;
+  // Prepublished kernarg metadata used only when the segment flag is set.
+  struct {
+    // Dispatch command whose payload reference is patched during finalization.
+    iree_hal_amdgpu_command_buffer_dispatch_command_t* dispatch_command;
+  } prepublished;
+  // Byte length of |data|.
+  uint32_t length;
+  // Required alignment for the device pointer when materialized.
+  uint32_t alignment;
+  // Segment flags from
+  // iree_hal_amdgpu_aql_command_buffer_rodata_segment_flag_bits_t.
+  iree_hal_amdgpu_aql_command_buffer_rodata_segment_flags_t flags;
+  // Reserved bytes that must be zero.
+  uint32_t reserved0;
+} iree_hal_amdgpu_aql_command_buffer_rodata_segment_t;
+
+typedef struct iree_hal_amdgpu_aql_command_buffer_rodata_segment_page_t {
+  // Next page in ordinal order.
+  struct iree_hal_amdgpu_aql_command_buffer_rodata_segment_page_t* next;
+  // Number of valid entries in |segments|.
+  uint32_t count;
+  // Reserved bits for future page metadata.
+  uint32_t reserved0;
+  // Fixed-capacity rodata segment descriptors.
+  iree_hal_amdgpu_aql_command_buffer_rodata_segment_t
+      segments[IREE_HAL_AMDGPU_AQL_COMMAND_BUFFER_RODATA_SEGMENT_PAGE_CAPACITY];
+} iree_hal_amdgpu_aql_command_buffer_rodata_segment_page_t;
+static_assert((IREE_HAL_AMDGPU_AQL_COMMAND_BUFFER_RODATA_SEGMENT_PAGE_CAPACITY &
+               IREE_HAL_AMDGPU_AQL_COMMAND_BUFFER_RODATA_SEGMENT_PAGE_MASK) ==
+                  0,
+              "rodata segment page capacity must be a power of two");
+static_assert(
+    sizeof(iree_hal_amdgpu_aql_command_buffer_rodata_segment_page_t) <=
+        IREE_HAL_AMDGPU_AQL_PROGRAM_DEFAULT_BLOCK_SIZE,
+    "rodata segment page should fit in a default command-buffer block");
+
+typedef struct iree_hal_amdgpu_aql_command_buffer_dispatch_summary_block_t {
+  // Next block summary in recording order.
+  struct iree_hal_amdgpu_aql_command_buffer_dispatch_summary_block_t* next;
+  // Recorded command-buffer block this summary describes.
+  const iree_hal_amdgpu_command_buffer_block_header_t* header;
+  // Retained dispatch summaries for this block.
+  struct {
+    // First dispatch summary in command order.
+    iree_hal_amdgpu_aql_command_buffer_dispatch_summary_t* first;
+    // Final dispatch summary in command order.
+    iree_hal_amdgpu_aql_command_buffer_dispatch_summary_t* last;
+    // Number of dispatch summaries in this block.
+    uint32_t count;
+  } dispatch;
+} iree_hal_amdgpu_aql_command_buffer_dispatch_summary_block_t;
+
+typedef struct iree_hal_amdgpu_aql_command_buffer_t {
+  // Base HAL command-buffer resource.
+  iree_hal_command_buffer_t base;
+  // Host allocator used to allocate the command-buffer object.
+  iree_allocator_t host_allocator;
+  // Borrowed device allocator used during recording finalization.
+  iree_hal_allocator_t* device_allocator;
+  // Borrowed block pools used for command-buffer-owned storage.
+  struct {
+    // Block pool used for durable command-buffer program blocks.
+    iree_arena_block_pool_t* program;
+    // Block pool used for retained HAL resource sets.
+    iree_arena_block_pool_t* resource_set;
+  } block_pools;
+  // Physical device ordinal selected from the command buffer's queue affinity.
+  uint32_t device_ordinal;
+  // One-shot lifecycle state enforced even when generic HAL validation is off.
+  iree_hal_amdgpu_aql_command_buffer_recording_state_t recording_state;
+  // Arena owning recording-lifetime static buffer pages, rodata pages, and
+  // rodata payload bytes referenced by finalized program command records.
+  iree_arena_allocator_t recording_arena;
+  // Resource set retaining direct buffers and executables when not unretained.
+  iree_hal_resource_set_t* resource_set;
+  // Direct buffer ordinal table captured while recording.
+  struct {
+    // First static buffer page in ordinal order.
+    iree_hal_amdgpu_aql_command_buffer_static_buffer_page_t* first_page;
+    // Last static buffer page in ordinal order.
+    iree_hal_amdgpu_aql_command_buffer_static_buffer_page_t* current_page;
+    // Total direct buffer ordinals assigned.
+    uint32_t count;
+    // Reserved bytes for stable layout.
+    uint32_t reserved0;
+  } static_buffers;
+  // Device-visible storage containing prepublished static dispatch kernargs.
+  struct {
+    // Cold-path storage strategy selected during command-buffer creation.
+    iree_hal_amdgpu_aql_prepublished_kernarg_storage_t storage;
+    // Recording-time materialization plan for immutable kernarg templates.
+    struct {
+      // Number of prepublished kernarg templates recorded.
+      iree_host_size_t count;
+      // Minimum byte length required before base-alignment slack.
+      iree_host_size_t payload_length;
+      // Maximum device pointer alignment required by any template.
+      uint32_t max_alignment;
+    } templates;
+    // Materialized device-visible kernarg template allocation.
+    struct {
+      // Retained buffer containing all prepublished kernarg templates.
+      iree_hal_buffer_t* buffer;
+      // Device pointer to the first byte of |buffer|.
+      uint8_t* device_base;
+      // Allocated byte length of |buffer|.
+      iree_device_size_t byte_length;
+    } materialized;
+  } prepublished_kernargs;
+  // Immutable payload ordinal table captured while recording.
+  struct {
+    // First segment page in ordinal order.
+    iree_hal_amdgpu_aql_command_buffer_rodata_segment_page_t* first_page;
+    // Last segment page in ordinal order.
+    iree_hal_amdgpu_aql_command_buffer_rodata_segment_page_t* current_page;
+    // Total segment descriptors assigned ordinals.
+    uint32_t segment_count;
+  } rodata;
+  // Command-buffer profile metadata retained for profiling-enabled recording.
+  struct {
+    // Borrowed logical-device profiling metadata registry.
+    iree_hal_amdgpu_profile_metadata_registry_t* metadata;
+    // Producer-local profile command-buffer id, or 0 when profile metadata is
+    // not retained for this command buffer.
+    uint64_t id;
+  } profile;
+  // Recording-time sidecars retained for profiling and timestamp planning.
+  struct {
+    // Block-level dispatch summary list.
+    struct {
+      // First block with retained dispatch summaries.
+      iree_hal_amdgpu_aql_command_buffer_dispatch_summary_block_t* first;
+      // Current recording block with retained dispatch summaries.
+      iree_hal_amdgpu_aql_command_buffer_dispatch_summary_block_t* current;
+    } block;
+    // Total retained dispatch summaries across all blocks.
+    uint32_t count;
+  } dispatch_summaries;
+  // Builder used only during begin/end recording.
+  iree_hal_amdgpu_aql_program_builder_t builder;
+  // Program produced by end() and consumed by queue execution.
+  iree_hal_amdgpu_aql_program_t program;
+} iree_hal_amdgpu_aql_command_buffer_t;
+
+static const iree_hal_command_buffer_vtable_t
+    iree_hal_amdgpu_aql_command_buffer_vtable;
+
+static iree_hal_amdgpu_aql_command_buffer_t*
+iree_hal_amdgpu_aql_command_buffer_cast(iree_hal_command_buffer_t* base_value) {
+  IREE_HAL_ASSERT_TYPE(base_value, &iree_hal_amdgpu_aql_command_buffer_vtable);
+  return (iree_hal_amdgpu_aql_command_buffer_t*)base_value;
+}
+
+static bool iree_hal_amdgpu_aql_command_buffer_retains_resources(
+    const iree_hal_amdgpu_aql_command_buffer_t* command_buffer) {
+  return !iree_all_bits_set(command_buffer->base.mode,
+                            IREE_HAL_COMMAND_BUFFER_MODE_UNRETAINED);
+}
+
+static bool iree_hal_amdgpu_aql_command_buffer_retains_profile_metadata(
+    const iree_hal_amdgpu_aql_command_buffer_t* command_buffer) {
+  return iree_all_bits_set(
+      command_buffer->base.mode,
+      IREE_HAL_COMMAND_BUFFER_MODE_RETAIN_PROFILE_METADATA);
+}
+
+static bool iree_hal_amdgpu_aql_command_buffer_retains_dispatch_summaries(
+    const iree_hal_amdgpu_aql_command_buffer_t* command_buffer) {
+  return iree_any_bit_set(
+      command_buffer->base.mode,
+      IREE_HAL_COMMAND_BUFFER_MODE_RETAIN_PROFILE_METADATA |
+          IREE_HAL_COMMAND_BUFFER_MODE_RETAIN_DISPATCH_METADATA);
+}
+
+static bool iree_hal_amdgpu_aql_command_buffer_validates(
+    const iree_hal_amdgpu_aql_command_buffer_t* command_buffer) {
+#if IREE_HAL_COMMAND_BUFFER_VALIDATION_ENABLE
+  return !iree_any_bit_set(command_buffer->base.mode,
+                           IREE_HAL_COMMAND_BUFFER_MODE_UNVALIDATED);
+#else
+  (void)command_buffer;
+  return false;
+#endif  // IREE_HAL_COMMAND_BUFFER_VALIDATION_ENABLE
+}
+
+static bool iree_hal_amdgpu_aql_command_buffer_prepublish_enabled(
+    const iree_hal_amdgpu_aql_command_buffer_t* command_buffer) {
+  return command_buffer->prepublished_kernargs.storage.strategy !=
+         IREE_HAL_AMDGPU_AQL_PREPUBLISHED_KERNARG_STORAGE_STRATEGY_DISABLED;
+}
+
+static void iree_hal_amdgpu_aql_command_buffer_reset_resources(
+    iree_hal_amdgpu_aql_command_buffer_t* command_buffer) {
+  iree_hal_resource_set_free(command_buffer->resource_set);
+  command_buffer->resource_set = NULL;
+  command_buffer->static_buffers.first_page = NULL;
+  command_buffer->static_buffers.current_page = NULL;
+  command_buffer->static_buffers.count = 0;
+  iree_hal_buffer_release(
+      command_buffer->prepublished_kernargs.materialized.buffer);
+  command_buffer->prepublished_kernargs.templates.count = 0;
+  command_buffer->prepublished_kernargs.templates.payload_length = 0;
+  command_buffer->prepublished_kernargs.templates.max_alignment = 1;
+  command_buffer->prepublished_kernargs.materialized.buffer = NULL;
+  command_buffer->prepublished_kernargs.materialized.device_base = NULL;
+  command_buffer->prepublished_kernargs.materialized.byte_length = 0;
+  iree_arena_reset(&command_buffer->recording_arena);
+  command_buffer->rodata.first_page = NULL;
+  command_buffer->rodata.current_page = NULL;
+  command_buffer->rodata.segment_count = 0;
+  command_buffer->dispatch_summaries.block.first = NULL;
+  command_buffer->dispatch_summaries.block.current = NULL;
+  command_buffer->dispatch_summaries.count = 0;
+}
+
+static void iree_hal_amdgpu_aql_command_buffer_discard_recording(
+    iree_hal_amdgpu_aql_command_buffer_t* command_buffer) {
+  iree_hal_amdgpu_aql_program_release(&command_buffer->program);
+  iree_hal_amdgpu_aql_program_builder_deinitialize(&command_buffer->builder);
+  iree_hal_amdgpu_aql_program_builder_initialize(
+      command_buffer->block_pools.program, &command_buffer->builder);
+  iree_hal_amdgpu_aql_command_buffer_reset_resources(command_buffer);
+}
+
+static iree_status_t iree_hal_amdgpu_aql_command_buffer_ensure_resource_set(
+    iree_hal_amdgpu_aql_command_buffer_t* command_buffer) {
+  if (!iree_hal_amdgpu_aql_command_buffer_retains_resources(command_buffer) ||
+      command_buffer->resource_set) {
+    return iree_ok_status();
+  }
+  IREE_TRACE_ZONE_BEGIN(z0);
+  iree_status_t status = iree_hal_resource_set_allocate(
+      command_buffer->block_pools.resource_set, &command_buffer->resource_set);
+  IREE_TRACE_ZONE_END(z0);
+  return status;
+}
+
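+// Resolves a direct buffer ordinal assigned during recording to its retained
+// buffer by walking pages (ordinal >> PAGE_CAPACITY_LOG2) and indexing the
+// in-page slot (ordinal & PAGE_MASK). Returns NULL for out-of-range ordinals.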
+static iree_hal_buffer_t*
+iree_hal_amdgpu_aql_command_buffer_static_buffer_for_ordinal(
+    iree_hal_amdgpu_aql_command_buffer_t* command_buffer, uint32_t ordinal) {
+  if (IREE_UNLIKELY(ordinal >= command_buffer->static_buffers.count)) {
+    return NULL;
+  }
+
+  uint32_t page_ordinal =
+      ordinal >>
+      IREE_HAL_AMDGPU_AQL_COMMAND_BUFFER_STATIC_BUFFER_PAGE_CAPACITY_LOG2;
+  const uint32_t buffer_ordinal =
+      ordinal & IREE_HAL_AMDGPU_AQL_COMMAND_BUFFER_STATIC_BUFFER_PAGE_MASK;
+  iree_hal_amdgpu_aql_command_buffer_static_buffer_page_t* page =
+      command_buffer->static_buffers.first_page;
+  while (page_ordinal > 0 && page) {
+    page = page->next;
+    --page_ordinal;
+  }
+  if (IREE_UNLIKELY(!page || buffer_ordinal >= page->count)) return NULL;
+  return page->buffers[buffer_ordinal];
+}
+
+static iree_status_t
+iree_hal_amdgpu_aql_command_buffer_ensure_static_buffer_page(
+    iree_hal_amdgpu_aql_command_buffer_t* command_buffer,
+    iree_hal_amdgpu_aql_command_buffer_static_buffer_page_t** out_page) {
+  *out_page = NULL;
+  if (IREE_UNLIKELY(command_buffer->static_buffers.count == UINT32_MAX)) {
+    return iree_make_status(IREE_STATUS_OUT_OF_RANGE,
+                            "command-buffer static buffer table overflow");
+  }
+  iree_hal_amdgpu_aql_command_buffer_static_buffer_page_t* page =
+      command_buffer->static_buffers.current_page;
+  if (page &&
+      page->count <
+          IREE_HAL_AMDGPU_AQL_COMMAND_BUFFER_STATIC_BUFFER_PAGE_CAPACITY) {
+    *out_page = page;
+    return iree_ok_status();
+  }
+
+  IREE_TRACE_ZONE_BEGIN(z0);
+  IREE_TRACE_ZONE_APPEND_VALUE_I64(z0, command_buffer->static_buffers.count);
+  IREE_RETURN_AND_END_ZONE_IF_ERROR(
+      z0, iree_arena_allocate(&command_buffer->recording_arena, sizeof(*page),
+                              (void**)&page));
+  memset(page, 0, sizeof(*page));
+  if (command_buffer->static_buffers.current_page) {
+    command_buffer->static_buffers.current_page->next = page;
+  } else {
+    command_buffer->static_buffers.first_page = page;
+  }
+  command_buffer->static_buffers.current_page = page;
+  *out_page = page;
+  IREE_TRACE_ZONE_END(z0);
+  return iree_ok_status();
+}
+
+static iree_status_t iree_hal_amdgpu_aql_command_buffer_allocate_static_buffer(
+    iree_hal_amdgpu_aql_command_buffer_t* command_buffer,
+    iree_hal_buffer_t* buffer, uint32_t* out_ordinal) {
+  *out_ordinal = 0;
+  iree_hal_amdgpu_aql_command_buffer_static_buffer_page_t* page = NULL;
+  IREE_RETURN_IF_ERROR(
+      iree_hal_amdgpu_aql_command_buffer_ensure_static_buffer_page(
+          command_buffer, &page));
+  *out_ordinal = command_buffer->static_buffers.count;
+  ++command_buffer->static_buffers.count;
+  page->buffers[page->count++] = buffer;
+  return iree_ok_status();
+}
+
+static iree_status_t iree_hal_amdgpu_aql_command_buffer_record_static_buffer(
+    iree_hal_amdgpu_aql_command_buffer_t* command_buffer,
+    iree_hal_buffer_t* buffer, uint32_t* out_ordinal) {
+  IREE_RETURN_IF_ERROR(
+      iree_hal_amdgpu_aql_command_buffer_ensure_resource_set(command_buffer));
+  if (command_buffer->resource_set) {
+    IREE_RETURN_IF_ERROR(iree_hal_resource_set_insert(
+        command_buffer->resource_set, /*count=*/1, &buffer));
+  }
+  return iree_hal_amdgpu_aql_command_buffer_allocate_static_buffer(
+      command_buffer, buffer, out_ordinal);
+}
+
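+// Resolves a rodata segment ordinal to its descriptor using the same
+// page-walk-plus-masked-index scheme as static buffers. Returns NULL for
+// out-of-range ordinals.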
+static iree_hal_amdgpu_aql_command_buffer_rodata_segment_t*
+iree_hal_amdgpu_aql_command_buffer_rodata_segment_for_ordinal(
+    iree_hal_amdgpu_aql_command_buffer_t* command_buffer, uint64_t ordinal) {
+  if (IREE_UNLIKELY(ordinal >= command_buffer->rodata.segment_count)) {
+    return NULL;
+  }
+
+  uint64_t page_ordinal =
+      ordinal >>
+      IREE_HAL_AMDGPU_AQL_COMMAND_BUFFER_RODATA_SEGMENT_PAGE_CAPACITY_LOG2;
+  const uint32_t segment_ordinal =
+      (uint32_t)(ordinal &
+                 IREE_HAL_AMDGPU_AQL_COMMAND_BUFFER_RODATA_SEGMENT_PAGE_MASK);
+  iree_hal_amdgpu_aql_command_buffer_rodata_segment_page_t* page =
+      command_buffer->rodata.first_page;
+  while (page_ordinal > 0 && page) {
+    page = page->next;
+    --page_ordinal;
+  }
+  if (IREE_UNLIKELY(!page || segment_ordinal >= page->count)) return NULL;
+  return &page->segments[segment_ordinal];
+}
+
+static iree_status_t iree_hal_amdgpu_aql_command_buffer_append_rodata_segment(
+    iree_hal_amdgpu_aql_command_buffer_t* command_buffer, uint8_t* data,
+    iree_host_size_t byte_length, uint32_t alignment,
+    iree_hal_amdgpu_aql_command_buffer_rodata_segment_flags_t flags,
+    uint64_t* out_rodata_ordinal,
+    iree_hal_amdgpu_aql_command_buffer_rodata_segment_t** out_segment) {
+  *out_rodata_ordinal = 0;
+  if (out_segment) *out_segment = NULL;
+  if (IREE_UNLIKELY(byte_length > UINT32_MAX)) {
+    return iree_make_status(IREE_STATUS_RESOURCE_EXHAUSTED,
+                            "command-buffer rodata segment too large");
+  }
+  if (IREE_UNLIKELY(!alignment || !iree_host_size_is_power_of_two(alignment))) {
+    return iree_make_status(
+        IREE_STATUS_INVALID_ARGUMENT,
+        "command-buffer rodata segment alignment must be a non-zero power of "
+        "two");
+  }
+  if (IREE_UNLIKELY(command_buffer->rodata.segment_count == UINT32_MAX)) {
+    return iree_make_status(IREE_STATUS_RESOURCE_EXHAUSTED,
+                            "command-buffer rodata segment count overflow");
+  }
+
+  iree_hal_amdgpu_aql_command_buffer_rodata_segment_page_t* page =
+      command_buffer->rodata.current_page;
+  if (!page ||
+      page->count ==
+          IREE_HAL_AMDGPU_AQL_COMMAND_BUFFER_RODATA_SEGMENT_PAGE_CAPACITY) {
+    IREE_TRACE_ZONE_BEGIN(z0);
+    IREE_RETURN_AND_END_ZONE_IF_ERROR(
+        z0, iree_arena_allocate(&command_buffer->recording_arena, sizeof(*page),
+                                (void**)&page));
+    memset(page, 0, sizeof(*page));
+    if (command_buffer->rodata.current_page) {
+      command_buffer->rodata.current_page->next = page;
+    } else {
+      command_buffer->rodata.first_page = page;
+    }
+    command_buffer->rodata.current_page = page;
+    IREE_TRACE_ZONE_END(z0);
+  }
+
+  const uint32_t ordinal = command_buffer->rodata.segment_count++;
+  iree_hal_amdgpu_aql_command_buffer_rodata_segment_t* segment =
+      &page->segments[page->count++];
+  *segment = (iree_hal_amdgpu_aql_command_buffer_rodata_segment_t){
+      .data = data,
+      .length = (uint32_t)byte_length,
+      .alignment = alignment,
+      .flags = flags,
+  };
+  *out_rodata_ordinal = ordinal;
+  if (out_segment) *out_segment = segment;
+  return iree_ok_status();
+}
+
+static iree_status_t iree_hal_amdgpu_aql_command_buffer_allocate_rodata_segment(
+    iree_hal_amdgpu_aql_command_buffer_t* command_buffer,
+    iree_host_size_t byte_length, uint32_t alignment,
+    iree_hal_amdgpu_aql_command_buffer_rodata_segment_flags_t flags,
+    uint8_t** out_data, uint64_t* out_rodata_ordinal,
+    iree_hal_amdgpu_aql_command_buffer_rodata_segment_t** out_segment) {
+  *out_data = NULL;
+  *out_rodata_ordinal = 0;
+  if (out_segment) *out_segment = NULL;
+  uint8_t* rodata = NULL;
+  IREE_RETURN_IF_ERROR(iree_arena_allocate_aligned(
+      &command_buffer->recording_arena,
+      iree_max((iree_host_size_t)1, byte_length), alignment, (void**)&rodata));
+  iree_status_t status =
+      iree_hal_amdgpu_aql_command_buffer_append_rodata_segment(
+          command_buffer, rodata, byte_length, alignment, flags,
+          out_rodata_ordinal, out_segment);
+  if (iree_status_is_ok(status)) {
+    *out_data = rodata;
+  }
+  return status;
+}
+
+static iree_status_t iree_hal_amdgpu_aql_command_buffer_record_rodata(
+    iree_hal_amdgpu_aql_command_buffer_t* command_buffer,
+    const void* source_buffer, iree_host_size_t source_offset,
+    iree_host_size_t source_length, uint64_t* out_rodata_ordinal) {
+  *out_rodata_ordinal = 0;
+  iree_host_size_t source_end = 0;
+  if (IREE_UNLIKELY(!iree_host_size_checked_add(source_offset, source_length,
+                                                &source_end))) {
+    return iree_make_status(
+        IREE_STATUS_OUT_OF_RANGE,
+        "command-buffer update source span overflows host size "
+        "(offset=%" PRIhsz ", length=%" PRIhsz ")",
+        source_offset, source_length);
+  }
+  // Only the overflow check is needed; the computed end offset is unused.
+  (void)source_end;
+
+  uint8_t* rodata = NULL;
+  IREE_RETURN_IF_ERROR(
+      iree_hal_amdgpu_aql_command_buffer_allocate_rodata_segment(
+          command_buffer, source_length, /*alignment=*/1,
+          IREE_HAL_AMDGPU_AQL_COMMAND_BUFFER_RODATA_SEGMENT_FLAG_NONE, &rodata,
+          out_rodata_ordinal, /*out_segment=*/NULL));
+  if (source_length > 0) {
+    memcpy(rodata, (const uint8_t*)source_buffer + source_offset,
+           source_length);
+  }
+  return iree_ok_status();
+}
+
+static bool iree_hal_amdgpu_aql_command_buffer_rodata_is_prepublished_kernarg(
+    const iree_hal_amdgpu_aql_command_buffer_rodata_segment_t* segment) {
+  return iree_all_bits_set(
+      segment->flags,
+      IREE_HAL_AMDGPU_AQL_COMMAND_BUFFER_RODATA_SEGMENT_FLAG_PREPUBLISHED_KERNARGS);
+}
+
+static iree_status_t
+iree_hal_amdgpu_aql_command_buffer_append_prepublished_kernarg_template(
+    iree_hal_amdgpu_aql_command_buffer_t* command_buffer,
+    const iree_hal_amdgpu_aql_command_buffer_rodata_segment_t* segment) {
+  iree_host_size_t aligned_payload_length = 0;
+  if (IREE_UNLIKELY(!iree_host_size_checked_align(
+          command_buffer->prepublished_kernargs.templates.payload_length,
+          segment->alignment, &aligned_payload_length))) {
+    return iree_make_status(
+        IREE_STATUS_OUT_OF_RANGE,
+        "prepublished command-buffer kernarg offset overflow");
+  }
+  command_buffer->prepublished_kernargs.templates.payload_length =
+      aligned_payload_length;
+  if (IREE_UNLIKELY(!iree_host_size_checked_add(
+          command_buffer->prepublished_kernargs.templates.payload_length,
+          iree_max((iree_host_size_t)1, (iree_host_size_t)segment->length),
+          &command_buffer->prepublished_kernargs.templates.payload_length))) {
+    return iree_make_status(
+        IREE_STATUS_OUT_OF_RANGE,
+        "prepublished command-buffer kernarg storage overflow");
+  }
+  command_buffer->prepublished_kernargs.templates.max_alignment =
+      iree_max(command_buffer->prepublished_kernargs.templates.max_alignment,
+               segment->alignment);
+  ++command_buffer->prepublished_kernargs.templates.count;
+  return iree_ok_status();
+}
+
+static iree_status_t
+iree_hal_amdgpu_aql_command_buffer_copy_prepublished_kernargs(
+    iree_hal_amdgpu_aql_command_buffer_t* command_buffer,
+    iree_hal_buffer_mapping_t* mapping, uint8_t* device_base) {
+  uint8_t* host_base = mapping->contents.data;
+  const uintptr_t device_base_address = (uintptr_t)device_base;
+  iree_host_size_t payload_offset = 0;
+  for (iree_hal_amdgpu_aql_command_buffer_rodata_segment_page_t* page =
+           command_buffer->rodata.first_page;
+       page; page = page->next) {
+    for (uint32_t i = 0; i < page->count; ++i) {
+      iree_hal_amdgpu_aql_command_buffer_rodata_segment_t* segment =
+          &page->segments[i];
+      if (!iree_hal_amdgpu_aql_command_buffer_rodata_is_prepublished_kernarg(
+              segment)) {
+        continue;
+      }
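+      // Align relative to the device address (not the host mapping) so the
+      // pointer consumed by dispatch packets satisfies the segment alignment.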
+      const uintptr_t unaligned_address = device_base_address + payload_offset;
+      const uintptr_t aligned_address =
+          (unaligned_address + segment->alignment - 1) &
+          ~((uintptr_t)segment->alignment - 1);
+      payload_offset =
+          (iree_host_size_t)(aligned_address - device_base_address);
+      if (IREE_UNLIKELY(payload_offset > UINT32_MAX)) {
+        return iree_make_status(
+            IREE_STATUS_OUT_OF_RANGE,
+            "prepublished command-buffer kernarg offset exceeds uint32_t");
+      }
+      if (IREE_UNLIKELY(!segment->prepublished.dispatch_command)) {
+        return iree_make_status(
+            IREE_STATUS_FAILED_PRECONDITION,
+            "prepublished command-buffer kernarg has no dispatch command");
+      }
+      if (segment->length > 0) {
+        memcpy(host_base + payload_offset, segment->data, segment->length);
+      }
+      segment->prepublished.dispatch_command->payload_reference =
+          (uint32_t)payload_offset;
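+      // Advance by at least one byte so zero-length templates still receive
+      // distinct payload references.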
+      payload_offset +=
+          iree_max((iree_host_size_t)1, (iree_host_size_t)segment->length);
+    }
+  }
+  return iree_ok_status();
+}
+
+static iree_status_t
+iree_hal_amdgpu_aql_command_buffer_verify_prepublished_kernarg_storage(
+    const iree_hal_amdgpu_aql_command_buffer_t* command_buffer,
+    iree_hal_memory_type_t required_type, iree_hal_buffer_t* buffer) {
+  const iree_hal_memory_type_t actual_type =
+      iree_hal_buffer_memory_type(buffer);
+  if (IREE_LIKELY(iree_all_bits_set(actual_type, required_type))) {
+    return iree_ok_status();
+  }
+#if IREE_STATUS_MODE
+  iree_bitfield_string_temp_t required_temp;
+  iree_bitfield_string_temp_t actual_temp;
+  const iree_string_view_t required_string =
+      iree_hal_memory_type_format(required_type, &required_temp);
+  const iree_string_view_t actual_string =
+      iree_hal_memory_type_format(actual_type, &actual_temp);
+  return iree_make_status(
+      IREE_STATUS_FAILED_PRECONDITION,
+      "prepublished command-buffer kernarg strategy %u requires "
+      "memory_type=%.*s but allocation returned memory_type=%.*s",
+      command_buffer->prepublished_kernargs.storage.strategy,
+      (int)required_string.size, required_string.data, (int)actual_string.size,
+      actual_string.data);
+#else
+  return iree_make_status(
+      IREE_STATUS_FAILED_PRECONDITION,
+      "prepublished command-buffer kernarg allocation returned incompatible "
+      "memory type");
+#endif  // IREE_STATUS_MODE
+}
+
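+// Allocates one device-visible buffer holding all prepublished kernarg
+// templates, copies each recorded template payload into it, and patches the
+// owning dispatch commands with their final payload offsets.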
+static iree_status_t
+iree_hal_amdgpu_aql_command_buffer_materialize_prepublished_kernargs(
+    iree_hal_amdgpu_aql_command_buffer_t* command_buffer) {
+  const iree_host_size_t template_count =
+      command_buffer->prepublished_kernargs.templates.count;
+  if (template_count == 0) {
+    return iree_ok_status();
+  }
+  const iree_host_size_t payload_length =
+      command_buffer->prepublished_kernargs.templates.payload_length;
+  const uint32_t max_alignment =
+      command_buffer->prepublished_kernargs.templates.max_alignment;
+  if (IREE_UNLIKELY(!iree_hal_amdgpu_aql_command_buffer_prepublish_enabled(
+          command_buffer))) {
+    return iree_make_status(
+        IREE_STATUS_FAILED_PRECONDITION,
+        "command buffer recorded prepublished kernargs without a storage "
+        "strategy");
+  }
+  if (IREE_UNLIKELY(payload_length > UINT32_MAX)) {
+    return iree_make_status(
+        IREE_STATUS_OUT_OF_RANGE,
+        "prepublished command-buffer kernarg payload length %" PRIhsz
+        " exceeds uint32_t reference max",
+        payload_length);
+  }
+  if (IREE_UNLIKELY(payload_length > (iree_host_size_t)IREE_DEVICE_SIZE_MAX)) {
+    return iree_make_status(
+        IREE_STATUS_OUT_OF_RANGE,
+        "prepublished command-buffer kernarg storage length %" PRIhsz
+        " exceeds device size max %" PRIdsz,
+        payload_length, IREE_DEVICE_SIZE_MAX);
+  }
+
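+  // Pad the allocation by max_alignment-1 slack bytes so any base pointer the
+  // allocator returns can be aligned up without exceeding the allocation.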
+  iree_host_size_t allocation_length = 0;
+  if (IREE_UNLIKELY(
+          !iree_host_size_checked_add(payload_length, max_alignment - 1,
+                                      &allocation_length) ||
+          allocation_length > (iree_host_size_t)IREE_DEVICE_SIZE_MAX)) {
+    return iree_make_status(
+        IREE_STATUS_OUT_OF_RANGE,
+        "prepublished command-buffer kernarg allocation length overflow");
+  }
+
+  IREE_TRACE_ZONE_BEGIN(z0);
+  IREE_TRACE_ZONE_APPEND_VALUE_I64(z0, allocation_length);
+  IREE_TRACE_ZONE_APPEND_VALUE_I64(z0, template_count);
+
+  iree_hal_buffer_params_t params =
+      command_buffer->prepublished_kernargs.storage.buffer_params;
+  params.queue_affinity = command_buffer->base.queue_affinity;
+
+  iree_hal_buffer_t* template_buffer = NULL;
+  iree_status_t status = iree_hal_allocator_allocate_buffer(
+      command_buffer->device_allocator, params,
+      (iree_device_size_t)allocation_length, &template_buffer);
+  iree_hal_buffer_mapping_t mapping;
+  memset(&mapping, 0, sizeof(mapping));
+  uint8_t* device_base = NULL;
+  if (iree_status_is_ok(status)) {
+    device_base =
+        (uint8_t*)iree_hal_amdgpu_buffer_device_pointer(template_buffer);
+    if (IREE_UNLIKELY(!device_base)) {
+      status = iree_make_status(
+          IREE_STATUS_INVALID_ARGUMENT,
+          "prepublished command-buffer kernarg buffer must be backed by an "
+          "AMDGPU allocation");
+    }
+  }
+  if (iree_status_is_ok(status)) {
+    status =
+        iree_hal_amdgpu_aql_command_buffer_verify_prepublished_kernarg_storage(
+            command_buffer, params.type, template_buffer);
+  }
+  if (iree_status_is_ok(status)) {
+    status = iree_hal_buffer_map_range(
+        template_buffer, IREE_HAL_MAPPING_MODE_SCOPED,
+        IREE_HAL_MEMORY_ACCESS_DISCARD_WRITE, /*byte_offset=*/0,
+        (iree_device_size_t)allocation_length, &mapping);
+  }
+  if (iree_status_is_ok(status)) {
+    memset(mapping.contents.data, 0, allocation_length);
+    status = iree_hal_amdgpu_aql_command_buffer_copy_prepublished_kernargs(
+        command_buffer, &mapping, device_base);
+  }
+  if (mapping.buffer) {
+    status = iree_status_join(status, iree_hal_buffer_unmap_range(&mapping));
+  }
+  if (iree_status_is_ok(status)) {
+    command_buffer->prepublished_kernargs.materialized.buffer = template_buffer;
+    command_buffer->prepublished_kernargs.materialized.device_base =
+        device_base;
+    command_buffer->prepublished_kernargs.materialized.byte_length =
+        (iree_device_size_t)allocation_length;
+  } else {
+    iree_hal_buffer_release(template_buffer);
+  }
+
+  IREE_TRACE_ZONE_END(z0);
+  return status;
+}
+
+//===----------------------------------------------------------------------===//
+// Lifecycle
+//===----------------------------------------------------------------------===//
+
+static void iree_hal_amdgpu_aql_command_buffer_destroy(
+    iree_hal_command_buffer_t* base_command_buffer);
+
+iree_status_t iree_hal_amdgpu_aql_command_buffer_create(
+    iree_hal_allocator_t* device_allocator, iree_hal_command_buffer_mode_t mode,
+    iree_hal_command_category_t command_categories,
+    iree_hal_queue_affinity_t queue_affinity, iree_host_size_t binding_capacity,
+    iree_host_size_t device_ordinal,
+    iree_hal_amdgpu_aql_prepublished_kernarg_storage_t
+        prepublished_kernarg_storage,
+    iree_hal_amdgpu_profile_metadata_registry_t* profile_metadata,
+    iree_arena_block_pool_t* program_block_pool,
+    iree_arena_block_pool_t* resource_set_block_pool,
+    iree_allocator_t host_allocator,
+    iree_hal_command_buffer_t** out_command_buffer) {
+  IREE_ASSERT_ARGUMENT(device_allocator);
+  IREE_ASSERT_ARGUMENT(out_command_buffer);
+  *out_command_buffer = NULL;
+
+  if (iree_any_bit_set(mode,
+                       IREE_HAL_COMMAND_BUFFER_MODE_ALLOW_INLINE_EXECUTION) &&
+      !iree_all_bits_set(mode, IREE_HAL_COMMAND_BUFFER_MODE_ONE_SHOT)) {
+    return iree_make_status(IREE_STATUS_INVALID_ARGUMENT,
+                            "ALLOW_INLINE_EXECUTION requires ONE_SHOT mode");
+  }
+  if (IREE_UNLIKELY(!program_block_pool)) {
+    return iree_make_status(IREE_STATUS_INVALID_ARGUMENT,
+                            "command-buffer program block pool is required");
+  }
+  if (IREE_UNLIKELY(!resource_set_block_pool)) {
+    return iree_make_status(IREE_STATUS_INVALID_ARGUMENT,
+                            "command-buffer resource set block pool is "
+                            "required");
+  }
+  const bool retain_profile_metadata = iree_all_bits_set(
+      mode, IREE_HAL_COMMAND_BUFFER_MODE_RETAIN_PROFILE_METADATA);
+  if (IREE_UNLIKELY(retain_profile_metadata && !profile_metadata)) {
+    return iree_make_status(IREE_STATUS_INVALID_ARGUMENT,
+                            "command-buffer profile metadata is required");
+  }
+  if (IREE_UNLIKELY(device_ordinal > UINT32_MAX)) {
+    return iree_make_status(IREE_STATUS_OUT_OF_RANGE,
+                            "command-buffer device ordinal %" PRIhsz
+                            " exceeds uint32_t storage",
+                            device_ordinal);
+  }
+  switch (prepublished_kernarg_storage.strategy) {
+    case IREE_HAL_AMDGPU_AQL_PREPUBLISHED_KERNARG_STORAGE_STRATEGY_DISABLED:
+    case IREE_HAL_AMDGPU_AQL_PREPUBLISHED_KERNARG_STORAGE_STRATEGY_DEVICE_FINE_HOST_COHERENT:
+      break;
+    default:
+      return iree_make_status(
+          IREE_STATUS_INVALID_ARGUMENT,
+          "unsupported prepublished command-buffer kernarg storage strategy %u",
+          prepublished_kernarg_storage.strategy);
+  }
+
+  IREE_TRACE_ZONE_BEGIN(z0);
+
+  iree_host_size_t total_size = 0;
+  iree_host_size_t validation_state_offset = 0;
+  IREE_RETURN_AND_END_ZONE_IF_ERROR(
+      z0, IREE_STRUCT_LAYOUT(
+              sizeof(iree_hal_amdgpu_aql_command_buffer_t), &total_size,
+              IREE_STRUCT_FIELD(iree_hal_command_buffer_validation_state_size(
+                                    mode, binding_capacity),
+                                uint8_t, &validation_state_offset)));
+
+  iree_hal_amdgpu_aql_command_buffer_t* command_buffer = NULL;
+  IREE_RETURN_AND_END_ZONE_IF_ERROR(
+      z0, iree_allocator_malloc(host_allocator, total_size,
+                                (void**)&command_buffer));
+  memset(command_buffer, 0, sizeof(*command_buffer));
+  iree_hal_command_buffer_initialize(
+      device_allocator, mode, command_categories, queue_affinity,
+      binding_capacity, (uint8_t*)command_buffer + validation_state_offset,
+      &iree_hal_amdgpu_aql_command_buffer_vtable, &command_buffer->base);
+  command_buffer->host_allocator = host_allocator;
+  command_buffer->device_allocator = device_allocator;
+  command_buffer->block_pools.program = program_block_pool;
+  command_buffer->block_pools.resource_set = resource_set_block_pool;
+  command_buffer->profile.metadata = profile_metadata;
+  command_buffer->device_ordinal = (uint32_t)device_ordinal;
+  command_buffer->prepublished_kernargs.storage = prepublished_kernarg_storage;
+  command_buffer->prepublished_kernargs.templates.max_alignment = 1;
+  iree_arena_initialize(program_block_pool, &command_buffer->recording_arena);
+  iree_hal_amdgpu_aql_program_builder_initialize(program_block_pool,
+                                                 &command_buffer->builder);
+
+  iree_status_t status = iree_ok_status();
+  if (retain_profile_metadata) {
+    status = iree_hal_amdgpu_profile_metadata_register_command_buffer(
+        profile_metadata, mode, command_categories, queue_affinity,
+        device_ordinal, &command_buffer->profile.id);
+  }
+  if (iree_status_is_ok(status)) {
+    *out_command_buffer = &command_buffer->base;
+  } else {
+    iree_hal_amdgpu_aql_command_buffer_destroy(&command_buffer->base);
+  }
+  IREE_TRACE_ZONE_END(z0);
+  return status;
+}
+
+static void iree_hal_amdgpu_aql_command_buffer_destroy(
+    iree_hal_command_buffer_t* base_command_buffer) {
+  iree_hal_amdgpu_aql_command_buffer_t* command_buffer =
+      iree_hal_amdgpu_aql_command_buffer_cast(base_command_buffer);
+  iree_allocator_t host_allocator = command_buffer->host_allocator;
+  IREE_TRACE_ZONE_BEGIN(z0);
+
+  iree_hal_amdgpu_aql_program_release(&command_buffer->program);
+  iree_hal_amdgpu_aql_program_builder_deinitialize(&command_buffer->builder);
+  iree_hal_amdgpu_aql_command_buffer_reset_resources(command_buffer);
+  iree_arena_deinitialize(&command_buffer->recording_arena);
+  iree_allocator_free(host_allocator, command_buffer);
+
+  IREE_TRACE_ZONE_END(z0);
+}
+
+bool iree_hal_amdgpu_aql_command_buffer_isa(
+    iree_hal_command_buffer_t* command_buffer) {
+  return iree_hal_resource_is(&command_buffer->resource,
+                              &iree_hal_amdgpu_aql_command_buffer_vtable);
+}
+
+const iree_hal_amdgpu_aql_program_t* iree_hal_amdgpu_aql_command_buffer_program(
+    iree_hal_command_buffer_t* base_command_buffer) {
+  iree_hal_amdgpu_aql_command_buffer_t* command_buffer =
+      iree_hal_amdgpu_aql_command_buffer_cast(base_command_buffer);
+  return &command_buffer->program;
+}
+
+iree_host_size_t iree_hal_amdgpu_aql_command_buffer_device_ordinal(
+    iree_hal_command_buffer_t* base_command_buffer) {
+  iree_hal_amdgpu_aql_command_buffer_t* command_buffer =
+      iree_hal_amdgpu_aql_command_buffer_cast(base_command_buffer);
+  return command_buffer->device_ordinal;
+}
+
+uint64_t iree_hal_amdgpu_aql_command_buffer_profile_id(
+    iree_hal_command_buffer_t* base_command_buffer) {
+  iree_hal_amdgpu_aql_command_buffer_t* command_buffer =
+      iree_hal_amdgpu_aql_command_buffer_cast(base_command_buffer);
+  return command_buffer->profile.id;
+}
+
+const iree_hal_amdgpu_aql_command_buffer_dispatch_summary_t*
+iree_hal_amdgpu_aql_command_buffer_dispatch_summaries(
+    iree_hal_command_buffer_t* base_command_buffer,
+    const iree_hal_amdgpu_command_buffer_block_header_t* block,
+    uint32_t* out_count) {
+  IREE_ASSERT_ARGUMENT(out_count);
+  *out_count = 0;
+  iree_hal_amdgpu_aql_command_buffer_t* command_buffer =
+      iree_hal_amdgpu_aql_command_buffer_cast(base_command_buffer);
+  for (const iree_hal_amdgpu_aql_command_buffer_dispatch_summary_block_t*
+           summary_block = command_buffer->dispatch_summaries.block.first;
+       summary_block; summary_block = summary_block->next) {
+    if (summary_block->header == block) {
+      *out_count = summary_block->dispatch.count;
+      return summary_block->dispatch.first;
+    }
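+    // Summary blocks are linked in recording order; once past the requested
+    // block ordinal no summaries were retained for that block.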
+    if (summary_block->header->block_ordinal > block->block_ordinal) {
+      break;
+    }
+  }
+  return NULL;
+}
+
+iree_hal_buffer_t* iree_hal_amdgpu_aql_command_buffer_static_buffer(
+    iree_hal_command_buffer_t* base_command_buffer, uint32_t ordinal) {
+  iree_hal_amdgpu_aql_command_buffer_t* command_buffer =
+      iree_hal_amdgpu_aql_command_buffer_cast(base_command_buffer);
+  return iree_hal_amdgpu_aql_command_buffer_static_buffer_for_ordinal(
+      command_buffer, ordinal);
+}
+
+const uint8_t* iree_hal_amdgpu_aql_command_buffer_rodata(
+    iree_hal_command_buffer_t* base_command_buffer, uint64_t ordinal,
+    uint32_t length) {
+  iree_hal_amdgpu_aql_command_buffer_t* command_buffer =
+      iree_hal_amdgpu_aql_command_buffer_cast(base_command_buffer);
+  const iree_hal_amdgpu_aql_command_buffer_rodata_segment_t* segment =
+      iree_hal_amdgpu_aql_command_buffer_rodata_segment_for_ordinal(
+          command_buffer, ordinal);
+  if (IREE_UNLIKELY(!segment)) {
+    return NULL;
+  }
+  return length == segment->length ? segment->data : NULL;
+}
+
+void* iree_hal_amdgpu_aql_command_buffer_prepublished_kernarg(
+    iree_hal_command_buffer_t* base_command_buffer, uint32_t byte_offset,
+    uint32_t length) {
+  iree_hal_amdgpu_aql_command_buffer_t* command_buffer =
+      iree_hal_amdgpu_aql_command_buffer_cast(base_command_buffer);
+  if (IREE_UNLIKELY(
+          !command_buffer->prepublished_kernargs.materialized.buffer)) {
+    return NULL;
+  }
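+  // Treat zero-length queries as one byte so the range check below still
+  // rejects offsets at or past the end of the materialized storage.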
+  const iree_device_size_t required_length =
+      iree_max((iree_device_size_t)1, (iree_device_size_t)length);
+  iree_device_size_t end_offset = 0;
+  if (IREE_UNLIKELY(
+          !iree_device_size_checked_add((iree_device_size_t)byte_offset,
+                                        required_length, &end_offset) ||
+          end_offset >
+              command_buffer->prepublished_kernargs.materialized.byte_length)) {
+    return NULL;
+  }
+  return command_buffer->prepublished_kernargs.materialized.device_base +
+         byte_offset;
+}
+
+//===----------------------------------------------------------------------===//
+// Recording Session
+//===----------------------------------------------------------------------===//
+
+static iree_status_t iree_hal_amdgpu_aql_command_buffer_begin(
+    iree_hal_command_buffer_t* base_command_buffer) {
+  iree_hal_amdgpu_aql_command_buffer_t* command_buffer =
+      iree_hal_amdgpu_aql_command_buffer_cast(base_command_buffer);
+  switch (command_buffer->recording_state) {
+    case IREE_HAL_AMDGPU_AQL_COMMAND_BUFFER_RECORDING_STATE_INITIAL:
+      break;
+    case IREE_HAL_AMDGPU_AQL_COMMAND_BUFFER_RECORDING_STATE_RECORDING:
+      return iree_make_status(IREE_STATUS_FAILED_PRECONDITION,
+                              "command buffer is already in a recording state");
+    case IREE_HAL_AMDGPU_AQL_COMMAND_BUFFER_RECORDING_STATE_FINALIZED:
+      return iree_make_status(IREE_STATUS_FAILED_PRECONDITION,
+                              "command buffer has already been recorded; "
+                              "re-recording command buffers is not allowed");
+    case IREE_HAL_AMDGPU_AQL_COMMAND_BUFFER_RECORDING_STATE_FAILED:
+      return iree_make_status(
+          IREE_STATUS_FAILED_PRECONDITION,
+          "command buffer recording failed and cannot be reused");
+    default:
+      return iree_make_status(IREE_STATUS_INTERNAL,
+                              "invalid command-buffer recording state %d",
+                              (int)command_buffer->recording_state);
+  }
+  IREE_RETURN_IF_ERROR(
+      iree_hal_amdgpu_aql_program_builder_begin(&command_buffer->builder));
+  command_buffer->recording_state =
+      IREE_HAL_AMDGPU_AQL_COMMAND_BUFFER_RECORDING_STATE_RECORDING;
+  return iree_ok_status();
+}
+
+static iree_status_t iree_hal_amdgpu_aql_command_buffer_end(
+    iree_hal_command_buffer_t* base_command_buffer) {
+  iree_hal_amdgpu_aql_command_buffer_t* command_buffer =
+      iree_hal_amdgpu_aql_command_buffer_cast(base_command_buffer);
+  if (IREE_UNLIKELY(
+          command_buffer->recording_state !=
+          IREE_HAL_AMDGPU_AQL_COMMAND_BUFFER_RECORDING_STATE_RECORDING)) {
+    return iree_make_status(IREE_STATUS_FAILED_PRECONDITION,
+                            "command buffer is not in a recording state");
+  }
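+  // Finalize in dependency order: seal the program, materialize prepublished
+  // kernarg storage (patching dispatch payload references), and register
+  // profile metadata before freezing retained resources.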
+  iree_status_t status = iree_hal_amdgpu_aql_program_builder_end(
+      &command_buffer->builder, &command_buffer->program);
+  if (iree_status_is_ok(status)) {
+    status =
+        iree_hal_amdgpu_aql_command_buffer_materialize_prepublished_kernargs(
+            command_buffer);
+  }
+  if (iree_status_is_ok(status) &&
+      iree_hal_amdgpu_aql_command_buffer_retains_profile_metadata(
+          command_buffer)) {
+    status = iree_hal_amdgpu_aql_command_buffer_register_profile_operations(
+        command_buffer->profile.metadata, command_buffer->profile.id,
+        &command_buffer->program, command_buffer->host_allocator);
+  }
+  if (iree_status_is_ok(status)) {
+    iree_hal_resource_set_freeze(command_buffer->resource_set);
+    command_buffer->recording_state =
+        IREE_HAL_AMDGPU_AQL_COMMAND_BUFFER_RECORDING_STATE_FINALIZED;
+  } else {
+    iree_hal_amdgpu_aql_command_buffer_discard_recording(command_buffer);
+    command_buffer->recording_state =
+        IREE_HAL_AMDGPU_AQL_COMMAND_BUFFER_RECORDING_STATE_FAILED;
+  }
+  return status;
+}
+
+//===----------------------------------------------------------------------===//
+// Debug Groups
+//===----------------------------------------------------------------------===//
+
+static iree_status_t iree_hal_amdgpu_aql_command_buffer_begin_debug_group(
+    iree_hal_command_buffer_t* base_command_buffer, iree_string_view_t label,
+    iree_hal_label_color_t label_color,
+    const iree_hal_label_location_t* location) {
+  return iree_ok_status();
+}
+
+static iree_status_t iree_hal_amdgpu_aql_command_buffer_end_debug_group(
+    iree_hal_command_buffer_t* base_command_buffer) {
+  return iree_ok_status();
+}
+
+//===----------------------------------------------------------------------===//
+// Barriers and Events
+//===----------------------------------------------------------------------===//
+
+static iree_hsa_fence_scope_t
+iree_hal_amdgpu_aql_command_buffer_access_scope_fence_scope(
+    iree_hal_access_scope_t access_scope) {
+  if (access_scope == 0) return IREE_HSA_FENCE_SCOPE_NONE;
+  // Resolve HAL memory visibility to HSA fence scope while recording so replay
+  // can consume compact command flags without re-inspecting barrier operands.
+  // Same-agent device producer/consumer edges use AGENT; host/system-visible
+  // edges use SYSTEM; execution-only barriers carry no acquire/release scope.
+  const iree_hal_access_scope_t system_scopes =
+      IREE_HAL_ACCESS_SCOPE_HOST_READ | IREE_HAL_ACCESS_SCOPE_HOST_WRITE |
+      IREE_HAL_ACCESS_SCOPE_MEMORY_READ | IREE_HAL_ACCESS_SCOPE_MEMORY_WRITE;
+  return iree_any_bit_set(access_scope, system_scopes)
+             ? IREE_HSA_FENCE_SCOPE_SYSTEM
+             : IREE_HSA_FENCE_SCOPE_AGENT;
+}
+
+static void iree_hal_amdgpu_aql_command_buffer_accumulate_barrier_scopes(
+    iree_hal_access_scope_t source_scope, iree_hal_access_scope_t target_scope,
+    iree_hsa_fence_scope_t* release_scope,
+    iree_hsa_fence_scope_t* acquire_scope) {
+  const iree_hsa_fence_scope_t source_fence_scope =
+      iree_hal_amdgpu_aql_command_buffer_access_scope_fence_scope(source_scope);
+  const iree_hsa_fence_scope_t target_fence_scope =
+      iree_hal_amdgpu_aql_command_buffer_access_scope_fence_scope(target_scope);
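+  // Producer (source) scopes widen the release fence and consumer (target)
+  // scopes widen the acquire fence; taking the max keeps mixed barriers
+  // conservative.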
+  const iree_hsa_fence_scope_t fence_scope =
+      source_fence_scope > target_fence_scope ? source_fence_scope
+                                              : target_fence_scope;
+  if (source_scope != 0 && fence_scope > *release_scope) {
+    *release_scope = fence_scope;
+  }
+  if (target_scope != 0 && fence_scope > *acquire_scope) {
+    *acquire_scope = fence_scope;
+  }
+}
+
+static void iree_hal_amdgpu_aql_command_buffer_resolve_barrier_scopes(
+    iree_hal_execution_stage_t source_stage_mask,
+    iree_hal_execution_stage_t target_stage_mask,
+    iree_host_size_t memory_barrier_count,
+    const iree_hal_memory_barrier_t* memory_barriers,
+    iree_host_size_t buffer_barrier_count,
+    const iree_hal_buffer_barrier_t* buffer_barriers,
+    iree_hsa_fence_scope_t* out_acquire_scope,
+    iree_hsa_fence_scope_t* out_release_scope) {
+  iree_hsa_fence_scope_t acquire_scope = IREE_HSA_FENCE_SCOPE_NONE;
+  iree_hsa_fence_scope_t release_scope = IREE_HSA_FENCE_SCOPE_NONE;
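+  // Host participation on either side forces SYSTEM scope: device work must
+  // acquire to observe host writes and release so its writes become visible
+  // to the host.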
+  if (iree_any_bit_set(source_stage_mask, IREE_HAL_EXECUTION_STAGE_HOST)) {
+    acquire_scope = IREE_HSA_FENCE_SCOPE_SYSTEM;
+  }
+  if (iree_any_bit_set(target_stage_mask, IREE_HAL_EXECUTION_STAGE_HOST)) {
+    release_scope = IREE_HSA_FENCE_SCOPE_SYSTEM;
+  }
+  for (iree_host_size_t i = 0; i < memory_barrier_count; ++i) {
+    iree_hal_amdgpu_aql_command_buffer_accumulate_barrier_scopes(
+        memory_barriers[i].source_scope, memory_barriers[i].target_scope,
+        &release_scope, &acquire_scope);
+  }
+  for (iree_host_size_t i = 0; i < buffer_barrier_count; ++i) {
+    iree_hal_amdgpu_aql_command_buffer_accumulate_barrier_scopes(
+        buffer_barriers[i].source_scope, buffer_barriers[i].target_scope,
+        &release_scope, &acquire_scope);
+  }
+  *out_acquire_scope = acquire_scope;
+  *out_release_scope = release_scope;
+}
+
+static iree_status_t iree_hal_amdgpu_aql_command_buffer_execution_barrier(
+    iree_hal_command_buffer_t* base_command_buffer,
+    iree_hal_execution_stage_t source_stage_mask,
+    iree_hal_execution_stage_t target_stage_mask,
+    iree_hal_execution_barrier_flags_t flags,
+    iree_host_size_t memory_barrier_count,
+    const iree_hal_memory_barrier_t* memory_barriers,
+    iree_host_size_t buffer_barrier_count,
+    const iree_hal_buffer_barrier_t* buffer_barriers) {
+  if (IREE_UNLIKELY(flags != IREE_HAL_EXECUTION_BARRIER_FLAG_NONE)) {
+    return iree_make_status(IREE_STATUS_UNIMPLEMENTED,
+                            "unsupported execution barrier flags");
+  }
+
+  iree_hal_amdgpu_aql_command_buffer_t* command_buffer =
+      iree_hal_amdgpu_aql_command_buffer_cast(base_command_buffer);
+
+  iree_hsa_fence_scope_t acquire_scope = IREE_HSA_FENCE_SCOPE_NONE;
+  iree_hsa_fence_scope_t release_scope = IREE_HSA_FENCE_SCOPE_NONE;
+  iree_hal_amdgpu_aql_command_buffer_resolve_barrier_scopes(
+      source_stage_mask, target_stage_mask, memory_barrier_count,
+      memory_barriers, buffer_barrier_count, buffer_barriers, &acquire_scope,
+      &release_scope);
+
+  iree_hal_amdgpu_command_buffer_command_header_t* header = NULL;
+  IREE_RETURN_IF_ERROR(iree_hal_amdgpu_aql_program_builder_append_command(
+      &command_buffer->builder, IREE_HAL_AMDGPU_COMMAND_BUFFER_OPCODE_BARRIER,
+      IREE_HAL_AMDGPU_COMMAND_BUFFER_COMMAND_FLAG_NONE,
+      sizeof(iree_hal_amdgpu_command_buffer_barrier_command_t),
+      /*binding_source_count=*/0, /*aql_packet_count=*/0,
+      /*kernarg_length=*/0, &header, /*out_binding_sources=*/NULL));
+
+  iree_hal_amdgpu_command_buffer_barrier_command_t* barrier =
+      (iree_hal_amdgpu_command_buffer_barrier_command_t*)header;
+  barrier->acquire_scope = (uint8_t)acquire_scope;
+  barrier->release_scope = (uint8_t)release_scope;
+  barrier->barrier_flags = (uint16_t)flags;
+  iree_hal_amdgpu_aql_program_builder_set_pending_barrier_scopes(
+      &command_buffer->builder, barrier->acquire_scope, barrier->release_scope);
+  return iree_ok_status();
+}
+
+static iree_status_t iree_hal_amdgpu_aql_command_buffer_signal_event(
+    iree_hal_command_buffer_t* base_command_buffer, iree_hal_event_t* event,
+    iree_hal_execution_stage_t source_stage_mask) {
+  return iree_make_status(IREE_STATUS_UNIMPLEMENTED,
+                          "AMDGPU command-buffer events not implemented");
+}
+
+static iree_status_t iree_hal_amdgpu_aql_command_buffer_reset_event(
+    iree_hal_command_buffer_t* base_command_buffer, iree_hal_event_t* event,
+    iree_hal_execution_stage_t source_stage_mask) {
+  return iree_make_status(IREE_STATUS_UNIMPLEMENTED,
+                          "AMDGPU command-buffer events not implemented");
+}
+
+static iree_status_t iree_hal_amdgpu_aql_command_buffer_wait_events(
+    iree_hal_command_buffer_t* base_command_buffer,
+    iree_host_size_t event_count, const iree_hal_event_t** events,
+    iree_hal_execution_stage_t source_stage_mask,
+    iree_hal_execution_stage_t target_stage_mask,
+    iree_host_size_t memory_barrier_count,
+    const iree_hal_memory_barrier_t* memory_barriers,
+    iree_host_size_t buffer_barrier_count,
+    const iree_hal_buffer_barrier_t* buffer_barriers) {
+  return iree_make_status(IREE_STATUS_UNIMPLEMENTED,
+                          "AMDGPU command-buffer events not implemented");
+}
+
+//===----------------------------------------------------------------------===//
+// Buffer Reference Recording
+//===----------------------------------------------------------------------===//
+
+static bool
+iree_hal_amdgpu_aql_command_buffer_allows_staged_transient_buffer_refs(
+    const iree_hal_amdgpu_aql_command_buffer_t* command_buffer) {
+  // One-shot command buffers are recorded for a single queued execution and may
+  // capture transient backing staged by a preceding queue_alloca before the
+  // user-visible alloca signal is published. Reusable command buffers require
+  // committed backing because the captured pointer can be replayed later.
+  return iree_all_bits_set(command_buffer->base.mode,
+                           IREE_HAL_COMMAND_BUFFER_MODE_ONE_SHOT);
+}
+
+static iree_status_t
+iree_hal_amdgpu_aql_command_buffer_resolve_static_buffer_ref(
+    const iree_hal_amdgpu_aql_command_buffer_t* command_buffer,
+    const iree_hal_buffer_ref_t* buffer_ref, uint64_t* out_device_pointer) {
+  *out_device_pointer = 0;
+  iree_hal_buffer_t* allocated_buffer =
+      iree_hal_buffer_allocated_buffer(buffer_ref->buffer);
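+  // Transient buffers resolve through their backing allocation; one-shot
+  // recordings may use staged backing while reusable recordings require
+  // committed backing.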
+  if (iree_hal_amdgpu_transient_buffer_isa(allocated_buffer)) {
+    iree_hal_buffer_t* backing_buffer = NULL;
+    if (iree_hal_amdgpu_aql_command_buffer_allows_staged_transient_buffer_refs(
+            command_buffer)) {
+      backing_buffer =
+          iree_hal_amdgpu_transient_buffer_backing_buffer(allocated_buffer);
+      if (IREE_UNLIKELY(!backing_buffer)) {
+        return iree_make_status(
+            IREE_STATUS_FAILED_PRECONDITION,
+            "one-shot command-buffer buffer reference has no staged AMDGPU "
+            "backing");
+      }
+    } else {
+      IREE_RETURN_IF_ERROR(
+          iree_hal_amdgpu_transient_buffer_resolve_committed_backing(
+              allocated_buffer, &backing_buffer));
+    }
+    allocated_buffer = iree_hal_buffer_allocated_buffer(backing_buffer);
+  }
+  void* device_ptr = iree_hal_amdgpu_buffer_device_pointer(allocated_buffer);
+  if (IREE_UNLIKELY(!device_ptr)) {
+    return iree_make_status(
+        IREE_STATUS_INVALID_ARGUMENT,
+        "command-buffer buffer reference must be backed by an AMDGPU "
+        "allocation");
+  }
+  iree_device_size_t device_offset = 0;
+  if (IREE_UNLIKELY(!iree_device_size_checked_add(
+          iree_hal_buffer_byte_offset(buffer_ref->buffer), buffer_ref->offset,
+          &device_offset))) {
+    return iree_make_status(
+        IREE_STATUS_OUT_OF_RANGE,
+        "command-buffer buffer reference device pointer offset overflows "
+        "device size");
+  }
+  if (IREE_UNLIKELY(device_offset > UINTPTR_MAX)) {
+    return iree_make_status(
+        IREE_STATUS_OUT_OF_RANGE,
+        "command-buffer buffer reference device pointer offset exceeds host "
+        "pointer size");
+  }
+  *out_device_pointer =
+      (uint64_t)((uintptr_t)device_ptr + (uintptr_t)device_offset);
+  return iree_ok_status();
+}
+
+static iree_status_t iree_hal_amdgpu_aql_command_buffer_record_buffer_ref(
+    iree_hal_amdgpu_aql_command_buffer_t* command_buffer,
+    iree_hal_buffer_ref_t buffer_ref,
+    iree_hal_amdgpu_command_buffer_binding_kind_t* out_kind,
+    uint32_t* out_ordinal, uint64_t* out_offset, uint64_t* out_length) {
+  *out_kind = IREE_HAL_AMDGPU_COMMAND_BUFFER_BINDING_KIND_INVALID;
+  *out_ordinal = 0;
+  *out_offset = 0;
+  *out_length = 0;
+
+  if (!buffer_ref.buffer) {
+    if (IREE_UNLIKELY(buffer_ref.buffer_slot == UINT32_MAX)) {
+      return iree_make_status(
+          IREE_STATUS_OUT_OF_RANGE,
+          "indirect command-buffer buffer slot %u exceeds binding count "
+          "storage",
+          buffer_ref.buffer_slot);
+    }
+    command_buffer->base.binding_count = iree_max(
+        command_buffer->base.binding_count, buffer_ref.buffer_slot + 1);
+    *out_kind = IREE_HAL_AMDGPU_COMMAND_BUFFER_BINDING_KIND_DYNAMIC;
+    *out_ordinal = buffer_ref.buffer_slot;
+    *out_offset = buffer_ref.offset;
+    *out_length = buffer_ref.length;
+    return iree_ok_status();
+  }
+
+  iree_device_size_t resolved_offset = 0;
+  iree_device_size_t resolved_length = 0;
+  IREE_RETURN_IF_ERROR(iree_hal_buffer_calculate_range(
+      /*base_offset=*/0, iree_hal_buffer_byte_length(buffer_ref.buffer),
+      buffer_ref.offset, buffer_ref.length, &resolved_offset,
+      &resolved_length));
+
+  uint32_t ordinal = 0;
+  IREE_RETURN_IF_ERROR(iree_hal_amdgpu_aql_command_buffer_record_static_buffer(
+      command_buffer, buffer_ref.buffer, &ordinal));
+
+  *out_kind = IREE_HAL_AMDGPU_COMMAND_BUFFER_BINDING_KIND_STATIC;
+  *out_ordinal = ordinal;
+  *out_offset = resolved_offset;
+  *out_length = resolved_length;
+  return iree_ok_status();
+}
+
+//===----------------------------------------------------------------------===//
+// Dispatch Recording
+//===----------------------------------------------------------------------===//
+
+static bool iree_hal_amdgpu_dispatch_config_has_workgroup_size_override(
+    const iree_hal_dispatch_config_t config) {
+  return config.workgroup_size[0] || config.workgroup_size[1] ||
+         config.workgroup_size[2];
+}
+
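+// Rejects dispatch flags the AQL path cannot honor: indirect arguments are
+// not implemented yet and anything outside the supported set is an error.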
+static iree_status_t iree_hal_amdgpu_aql_command_buffer_check_dispatch_flags(
+    iree_hal_dispatch_flags_t flags) {
+  if (iree_hal_dispatch_uses_indirect_arguments(flags)) {
+    return iree_make_status(
+        IREE_STATUS_UNIMPLEMENTED,
+        "indirect dispatch arguments are not supported by AMDGPU command "
+        "buffers yet");
+  }
+  const iree_hal_dispatch_flags_t supported_flags =
+      IREE_HAL_DISPATCH_FLAG_DYNAMIC_INDIRECT_PARAMETERS |
+      IREE_HAL_DISPATCH_FLAG_STATIC_INDIRECT_PARAMETERS |
+      IREE_HAL_DISPATCH_FLAG_CUSTOM_DIRECT_ARGUMENTS |
+      IREE_HAL_DISPATCH_FLAG_ALLOW_INLINE_EXECUTION;
+  if (IREE_UNLIKELY(iree_any_bit_set(flags, ~supported_flags))) {
+    return iree_make_status(IREE_STATUS_INVALID_ARGUMENT,
+                            "unsupported dispatch flags: 0x%" PRIx64, flags);
+  }
+  return iree_ok_status();
+}
+
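+// Selects the kernel args used by a dispatch: the descriptor args verbatim,
+// or a copy in |override_kernel_args| with the workgroup size replaced when
+// the dispatch config provides an override.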
+static void iree_hal_amdgpu_aql_command_buffer_select_dispatch_kernel_args(
+    const iree_hal_amdgpu_executable_dispatch_descriptor_t* descriptor,
+    const iree_hal_dispatch_config_t config,
+    iree_hal_amdgpu_device_kernel_args_t* override_kernel_args,
+    const iree_hal_amdgpu_device_kernel_args_t** out_kernel_args) {
+  *out_kernel_args = &descriptor->kernel_args;
+  if (!iree_hal_amdgpu_dispatch_config_has_workgroup_size_override(config)) {
+    return;
+  }
+
+  *override_kernel_args = descriptor->kernel_args;
+  for (iree_host_size_t i = 0; i < 3; ++i) {
+    override_kernel_args->workgroup_size[i] =
+        (uint16_t)config.workgroup_size[i];
+  }
+
+  *out_kernel_args = override_kernel_args;
+}
+
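+// Validates workgroup-size overrides, grid dimensions, and dynamic workgroup
+// local memory against the descriptor limits. Grid checks are skipped for
+// indirect-parameter dispatches whose workgroup counts are only known at
+// execution time.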
+static iree_status_t iree_hal_amdgpu_aql_command_buffer_validate_dispatch_shape(
+    const iree_hal_amdgpu_executable_dispatch_descriptor_t* descriptor,
+    const iree_hal_dispatch_config_t config, iree_hal_dispatch_flags_t flags) {
+  const bool uses_indirect_parameters =
+      iree_hal_dispatch_uses_indirect_parameters(flags);
+  if (iree_hal_amdgpu_dispatch_config_has_workgroup_size_override(config)) {
+    for (iree_host_size_t i = 0; i < 3; ++i) {
+      if (IREE_UNLIKELY(!config.workgroup_size[i])) {
+        return iree_make_status(
+            IREE_STATUS_INVALID_ARGUMENT,
+            "dispatch workgroup size override must specify all dimensions");
+      }
+      if (IREE_UNLIKELY(config.workgroup_size[i] > UINT16_MAX)) {
+        return iree_make_status(
+            IREE_STATUS_OUT_OF_RANGE,
+            "dispatch workgroup size override dimension %" PRIhsz
+            " value %u exceeds %u",
+            i, config.workgroup_size[i], UINT16_MAX);
+      }
+      if (!uses_indirect_parameters) {
+        const uint64_t grid_size =
+            (uint64_t)config.workgroup_count[i] * config.workgroup_size[i];
+        if (IREE_UNLIKELY(grid_size > UINT32_MAX)) {
+          return iree_make_status(
+              IREE_STATUS_OUT_OF_RANGE,
+              "dispatch grid dimension %" PRIhsz
+              " overflows uint32_t (workgroup_count=%u, workgroup_size=%u)",
+              i, config.workgroup_count[i], config.workgroup_size[i]);
+        }
+      }
+    }
+  } else if (!uses_indirect_parameters) {
+    for (iree_host_size_t i = 0; i < 3; ++i) {
+      if (IREE_UNLIKELY(config.workgroup_count[i] >
+                        descriptor->max_workgroup_count[i])) {
+        return iree_make_status(
+            IREE_STATUS_OUT_OF_RANGE,
+            "dispatch workgroup count dimension %" PRIhsz
+            " exceeds the kernel maximum (workgroup_count=%u, "
+            "max_workgroup_count=%u)",
+            i, config.workgroup_count[i],
+            descriptor->max_workgroup_count[i]);
+      }
+    }
+  }
+  if (IREE_UNLIKELY(config.dynamic_workgroup_local_memory >
+                    descriptor->max_dynamic_workgroup_local_memory)) {
+    return iree_make_status(IREE_STATUS_OUT_OF_RANGE,
+                            "dispatch group segment size exceeds the kernel "
+                            "maximum (static=%u, dynamic=%u)",
+                            descriptor->kernel_args.group_segment_size,
+                            config.dynamic_workgroup_local_memory);
+  }
+  return iree_ok_status();
+}
+
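+// Returns true when a one-shot transient buffer reference has no staged
+// backing at record time and must be carried as a deferred static-buffer
+// binding source instead of a resolved device pointer.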
+static bool iree_hal_amdgpu_aql_command_buffer_should_defer_static_buffer_ref(
+    const iree_hal_amdgpu_aql_command_buffer_t* command_buffer,
+    const iree_hal_buffer_ref_t* buffer_ref) {
+  if (!iree_hal_amdgpu_aql_command_buffer_allows_staged_transient_buffer_refs(
+          command_buffer)) {
+    return false;
+  }
+  iree_hal_buffer_t* allocated_buffer =
+      iree_hal_buffer_allocated_buffer(buffer_ref->buffer);
+  return iree_hal_amdgpu_transient_buffer_isa(allocated_buffer) &&
+         !iree_hal_amdgpu_transient_buffer_backing_buffer(allocated_buffer);
+}
+
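+// Validates all dispatch bindings up front: grows the dynamic binding count,
+// range-checks each reference, resolves non-deferred static references, and
+// retains referenced buffers in the resource set. Failures are annotated
+// with the offending binding index.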
+static iree_status_t
+iree_hal_amdgpu_aql_command_buffer_prepare_dispatch_binding_sources(
+    iree_hal_amdgpu_aql_command_buffer_t* command_buffer,
+    const iree_hal_buffer_ref_list_t bindings) {
+  IREE_RETURN_IF_ERROR(
+      iree_hal_amdgpu_aql_command_buffer_ensure_resource_set(command_buffer));
+
+  iree_host_size_t binding_count = command_buffer->base.binding_count;
+  iree_status_t status = iree_ok_status();
+  iree_host_size_t failed_index = 0;
+  for (iree_host_size_t i = 0; i < bindings.count && iree_status_is_ok(status);
+       ++i) {
+    failed_index = i;
+    const iree_hal_buffer_ref_t* binding = &bindings.values[i];
+    if (!binding->buffer) {
+      if (IREE_UNLIKELY(binding->buffer_slot == UINT32_MAX)) {
+        status = iree_make_status(
+            IREE_STATUS_OUT_OF_RANGE,
+            "indirect command-buffer dispatch binding slot %u exceeds binding "
+            "count storage",
+            binding->buffer_slot);
+      } else {
+        binding_count = iree_max(binding_count, binding->buffer_slot + 1);
+      }
+      continue;
+    }
+
+    iree_device_size_t unused_offset = 0;
+    iree_device_size_t unused_length = 0;
+    status = iree_hal_buffer_calculate_range(
+        /*base_offset=*/0, iree_hal_buffer_byte_length(binding->buffer),
+        binding->offset, binding->length, &unused_offset, &unused_length);
+    if (iree_status_is_ok(status) &&
+        !iree_hal_amdgpu_aql_command_buffer_should_defer_static_buffer_ref(
+            command_buffer, binding)) {
+      uint64_t unused_device_pointer = 0;
+      status = iree_hal_amdgpu_aql_command_buffer_resolve_static_buffer_ref(
+          command_buffer, binding, &unused_device_pointer);
+    }
+    if (iree_status_is_ok(status) && command_buffer->resource_set) {
+      status = iree_hal_resource_set_insert(command_buffer->resource_set,
+                                            /*count=*/1, &binding->buffer);
+    }
+  }
+  if (iree_status_is_ok(status)) {
+    command_buffer->base.binding_count = (uint32_t)binding_count;
+  } else {
+    status =
+        iree_status_annotate_f(status, "binding[%" PRIhsz "]", failed_index);
+  }
+  return status;
+}
+
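+// Records a binding source for a transient buffer whose backing is not yet
+// staged: the buffer is tracked as a static buffer and the source carries
+// its ordinal and resolved offset for late resolution.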
+static iree_status_t
+iree_hal_amdgpu_aql_command_buffer_record_deferred_dispatch_binding_source(
+    iree_hal_amdgpu_aql_command_buffer_t* command_buffer,
+    const iree_hal_buffer_ref_t* binding,
+    iree_hal_amdgpu_command_buffer_binding_source_t* binding_source) {
+  iree_device_size_t resolved_offset = 0;
+  iree_device_size_t resolved_length = 0;
+  IREE_RETURN_IF_ERROR(iree_hal_buffer_calculate_range(
+      /*base_offset=*/0, iree_hal_buffer_byte_length(binding->buffer),
+      binding->offset, binding->length, &resolved_offset, &resolved_length));
+  (void)resolved_length;
+
+  uint32_t static_buffer_ordinal = 0;
+  IREE_RETURN_IF_ERROR(
+      iree_hal_amdgpu_aql_command_buffer_allocate_static_buffer(
+          command_buffer, binding->buffer, &static_buffer_ordinal));
+  binding_source->offset_or_pointer = resolved_offset;
+  binding_source->slot = static_buffer_ordinal;
+  binding_source->flags =
+      IREE_HAL_AMDGPU_COMMAND_BUFFER_BINDING_SOURCE_FLAG_STATIC_BUFFER;
+  return iree_ok_status();
+}
+
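+// Writes one binding source per dispatch binding: a dynamic binding-table
+// slot, a deferred static buffer, or a fully resolved device pointer.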
+static iree_status_t
+iree_hal_amdgpu_aql_command_buffer_write_dispatch_binding_sources(
+    iree_hal_amdgpu_aql_command_buffer_t* command_buffer,
+    const iree_hal_buffer_ref_list_t bindings,
+    iree_hal_amdgpu_command_buffer_binding_source_t* binding_sources) {
+  iree_status_t status = iree_ok_status();
+  for (iree_host_size_t i = 0; i < bindings.count && iree_status_is_ok(status);
+       ++i) {
+    const iree_hal_buffer_ref_t* binding = &bindings.values[i];
+    iree_hal_amdgpu_command_buffer_binding_source_t* binding_source =
+        &binding_sources[i];
+    binding_source->target_binding_ordinal = (uint16_t)i;
+    if (!binding->buffer) {
+      binding_source->offset_or_pointer = binding->offset;
+      binding_source->slot = binding->buffer_slot;
+      binding_source->flags =
+          IREE_HAL_AMDGPU_COMMAND_BUFFER_BINDING_SOURCE_FLAG_DYNAMIC;
+      continue;
+    }
+
+    if (iree_hal_amdgpu_aql_command_buffer_should_defer_static_buffer_ref(
+            command_buffer, binding)) {
+      status =
+          iree_hal_amdgpu_aql_command_buffer_record_deferred_dispatch_binding_source(
+              command_buffer, binding, binding_source);
+      continue;
+    }
+
+    status = iree_hal_amdgpu_aql_command_buffer_resolve_static_buffer_ref(
+        command_buffer, binding, &binding_source->offset_or_pointer);
+    if (iree_status_is_ok(status)) {
+      binding_source->slot = 0;
+      binding_source->flags =
+          IREE_HAL_AMDGPU_COMMAND_BUFFER_BINDING_SOURCE_FLAG_NONE;
+    }
+  }
+  return status;
+}
+
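+// Patched-template path: only dynamic bindings need binding sources as all
+// static bindings were baked into the kernarg template at record time.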
+static void
+iree_hal_amdgpu_aql_command_buffer_write_dynamic_dispatch_binding_sources(
+    iree_hal_buffer_ref_list_t bindings,
+    iree_hal_amdgpu_command_buffer_binding_source_t* binding_sources) {
+  iree_host_size_t source_index = 0;
+  for (iree_host_size_t i = 0; i < bindings.count; ++i) {
+    const iree_hal_buffer_ref_t* binding = &bindings.values[i];
+    if (binding->buffer) continue;
+
+    iree_hal_amdgpu_command_buffer_binding_source_t* binding_source =
+        &binding_sources[source_index++];
+    binding_source->offset_or_pointer = binding->offset;
+    binding_source->slot = binding->buffer_slot;
+    binding_source->flags =
+        IREE_HAL_AMDGPU_COMMAND_BUFFER_BINDING_SOURCE_FLAG_DYNAMIC;
+    binding_source->target_binding_ordinal = (uint16_t)i;
+  }
+}
+
+static iree_status_t
+iree_hal_amdgpu_aql_command_buffer_check_indirect_workgroup_count_ref(
+    iree_hal_buffer_ref_t buffer_ref) {
+  const iree_device_size_t workgroup_count_length = sizeof(uint32_t[3]);
+  if (IREE_UNLIKELY((buffer_ref.offset % sizeof(uint32_t)) != 0)) {
+    return iree_make_status(
+        IREE_STATUS_INVALID_ARGUMENT,
+        "indirect workgroup count offset must be 4-byte aligned");
+  }
+  if (IREE_UNLIKELY(buffer_ref.length != IREE_HAL_WHOLE_BUFFER &&
+                    buffer_ref.length < workgroup_count_length)) {
+    return iree_make_status(
+        IREE_STATUS_OUT_OF_RANGE,
+        "indirect workgroup count buffer must contain at least uint32_t[3]");
+  }
+  return iree_ok_status();
+}
+
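+// Records the binding source for an indirect workgroup count reference,
+// checking the uint32_t[3] shape and (for directly referenced buffers) the
+// memory type, usage, access, and range before resolving it as a dynamic
+// slot, deferred static buffer, or absolute device pointer.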
+static iree_status_t
+iree_hal_amdgpu_aql_command_buffer_write_indirect_parameter_source(
+    iree_hal_amdgpu_aql_command_buffer_t* command_buffer,
+    iree_hal_buffer_ref_t buffer_ref,
+    iree_hal_amdgpu_command_buffer_binding_source_t* binding_source) {
+  memset(binding_source, 0, sizeof(*binding_source));
+  IREE_RETURN_IF_ERROR(
+      iree_hal_amdgpu_aql_command_buffer_check_indirect_workgroup_count_ref(
+          buffer_ref));
+
+  if (!buffer_ref.buffer) {
+    if (IREE_UNLIKELY(buffer_ref.buffer_slot == UINT32_MAX)) {
+      return iree_make_status(
+          IREE_STATUS_OUT_OF_RANGE,
+          "indirect workgroup count binding slot %u exceeds binding count "
+          "storage",
+          buffer_ref.buffer_slot);
+    }
+    command_buffer->base.binding_count = iree_max(
+        command_buffer->base.binding_count, buffer_ref.buffer_slot + 1);
+    binding_source->offset_or_pointer = buffer_ref.offset;
+    binding_source->slot = buffer_ref.buffer_slot;
+    binding_source->flags =
+        IREE_HAL_AMDGPU_COMMAND_BUFFER_BINDING_SOURCE_FLAG_DYNAMIC |
+        IREE_HAL_AMDGPU_COMMAND_BUFFER_BINDING_SOURCE_FLAG_INDIRECT_PARAMETERS;
+    return iree_ok_status();
+  }
+
+  IREE_RETURN_IF_ERROR(iree_hal_buffer_validate_memory_type(
+      iree_hal_buffer_memory_type(buffer_ref.buffer),
+      IREE_HAL_MEMORY_TYPE_DEVICE_VISIBLE));
+  IREE_RETURN_IF_ERROR(iree_hal_buffer_validate_usage(
+      iree_hal_buffer_allowed_usage(buffer_ref.buffer),
+      IREE_HAL_BUFFER_USAGE_DISPATCH_INDIRECT_PARAMETERS));
+  IREE_RETURN_IF_ERROR(iree_hal_buffer_validate_access(
+      iree_hal_buffer_allowed_access(buffer_ref.buffer),
+      IREE_HAL_MEMORY_ACCESS_READ));
+  IREE_RETURN_IF_ERROR(iree_hal_buffer_validate_range(
+      buffer_ref.buffer, buffer_ref.offset, sizeof(uint32_t[3])));
+  IREE_RETURN_IF_ERROR(
+      iree_hal_amdgpu_aql_command_buffer_ensure_resource_set(command_buffer));
+  if (command_buffer->resource_set) {
+    IREE_RETURN_IF_ERROR(iree_hal_resource_set_insert(
+        command_buffer->resource_set, /*count=*/1, &buffer_ref.buffer));
+  }
+
+  if (iree_hal_amdgpu_aql_command_buffer_should_defer_static_buffer_ref(
+          command_buffer, &buffer_ref)) {
+    IREE_RETURN_IF_ERROR(
+        iree_hal_amdgpu_aql_command_buffer_record_deferred_dispatch_binding_source(
+            command_buffer, &buffer_ref, binding_source));
+    binding_source->flags |=
+        IREE_HAL_AMDGPU_COMMAND_BUFFER_BINDING_SOURCE_FLAG_INDIRECT_PARAMETERS;
+    return iree_ok_status();
+  }
+
+  IREE_RETURN_IF_ERROR(
+      iree_hal_amdgpu_aql_command_buffer_resolve_static_buffer_ref(
+          command_buffer, &buffer_ref, &binding_source->offset_or_pointer));
+  binding_source->slot = 0;
+  binding_source->flags =
+      IREE_HAL_AMDGPU_COMMAND_BUFFER_BINDING_SOURCE_FLAG_INDIRECT_PARAMETERS;
+  return iree_ok_status();
+}
+
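+// Converts a byte length to the 8-byte qword count used by the compact
+// command encoding, failing when the padded length cannot be represented in
+// uint16_t qword storage.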
+static iree_status_t iree_hal_amdgpu_aql_command_buffer_qword_length(
+    iree_host_size_t byte_length, const char* label, uint16_t* out_qwords,
+    iree_host_size_t* out_padded_length) {
+  if (IREE_UNLIKELY(byte_length > IREE_HOST_SIZE_MAX - 7)) {
+    return iree_make_status(IREE_STATUS_OUT_OF_RANGE,
+                            "%s byte length %" PRIhsz
+                            " overflows when padded to 8-byte alignment",
+                            label, byte_length);
+  }
+  const iree_host_size_t padded_length = iree_host_align(byte_length, 8);
+  const iree_host_size_t qword_length = padded_length / 8;
+  if (IREE_UNLIKELY(qword_length > UINT16_MAX)) {
+    return iree_make_status(IREE_STATUS_OUT_OF_RANGE,
+                            "%s byte length %" PRIhsz
+                            " exceeds uint16_t qword storage",
+                            label, byte_length);
+  }
+  *out_qwords = (uint16_t)qword_length;
+  if (out_padded_length) *out_padded_length = padded_length;
+  return iree_ok_status();
+}
+
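+// Populates the implicit kernel arguments (block counts, group sizes, grid
+// dims, and dynamic LDS size); printf and hostcall buffers are currently
+// left NULL on this path.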
+static void iree_hal_amdgpu_aql_command_buffer_write_implicit_args(
+    const iree_hal_amdgpu_device_kernel_args_t* kernel_args,
+    const iree_hal_dispatch_config_t config,
+    iree_amdgpu_kernel_implicit_args_t* implicit_args) {
+  implicit_args->block_count[0] = config.workgroup_count[0];
+  implicit_args->block_count[1] = config.workgroup_count[1];
+  implicit_args->block_count[2] = config.workgroup_count[2];
+  implicit_args->group_size[0] = kernel_args->workgroup_size[0];
+  implicit_args->group_size[1] = kernel_args->workgroup_size[1];
+  implicit_args->group_size[2] = kernel_args->workgroup_size[2];
+  implicit_args->grid_dims = 3;
+  implicit_args->printf_buffer = NULL;
+  implicit_args->hostcall_buffer = NULL;
+  implicit_args->dynamic_lds_size = config.dynamic_workgroup_local_memory;
+}
+
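+// Writes the kernarg tail (constants and any implicit arguments) for the
+// selected strategy; HAL binding pointers precede the tail and are written
+// by the caller or patched at replay time.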
+static iree_status_t iree_hal_amdgpu_aql_command_buffer_write_dispatch_tail(
+    const iree_hal_amdgpu_device_kernel_args_t* kernel_args,
+    const iree_hal_amdgpu_device_dispatch_kernarg_layout_t* layout,
+    const iree_hal_dispatch_config_t config, iree_const_byte_span_t constants,
+    iree_hal_amdgpu_command_buffer_kernarg_strategy_t kernarg_strategy,
+    uint8_t* tail_payload) {
+  switch (kernarg_strategy) {
+    case IREE_HAL_AMDGPU_COMMAND_BUFFER_KERNARG_STRATEGY_HAL:
+    case IREE_HAL_AMDGPU_COMMAND_BUFFER_KERNARG_STRATEGY_DYNAMIC_BINDINGS: {
+      const iree_host_size_t binding_bytes =
+          (iree_host_size_t)kernel_args->binding_count * sizeof(uint64_t);
+      if (constants.data_length > 0) {
+        memcpy(tail_payload, constants.data, constants.data_length);
+      }
+      if (layout->has_implicit_args) {
+        iree_amdgpu_kernel_implicit_args_t* implicit_args =
+            (iree_amdgpu_kernel_implicit_args_t*)(tail_payload +
+                                                  layout->implicit_args_offset -
+                                                  binding_bytes);
+        iree_hal_amdgpu_aql_command_buffer_write_implicit_args(
+            kernel_args, config, implicit_args);
+      }
+      return iree_ok_status();
+    }
+    case IREE_HAL_AMDGPU_COMMAND_BUFFER_KERNARG_STRATEGY_CUSTOM_DIRECT:
+      if (constants.data_length > 0) {
+        memcpy(tail_payload, constants.data, constants.data_length);
+      }
+      return iree_ok_status();
+    case IREE_HAL_AMDGPU_COMMAND_BUFFER_KERNARG_STRATEGY_INDIRECT:
+      return iree_make_status(
+          IREE_STATUS_UNIMPLEMENTED,
+          "indirect dispatch arguments are not supported by AMDGPU command "
+          "buffers yet");
+    default:
+      return iree_make_status(IREE_STATUS_INVALID_ARGUMENT,
+                              "unsupported command-buffer kernarg strategy %u",
+                              kernarg_strategy);
+  }
+}
+
+static uint16_t
+iree_hal_amdgpu_aql_command_buffer_count_dynamic_dispatch_bindings(
+    iree_hal_buffer_ref_list_t bindings) {
+  uint16_t count = 0;
+  for (iree_host_size_t i = 0; i < bindings.count; ++i) {
+    if (!bindings.values[i].buffer) ++count;
+  }
+  return count;
+}
+
+typedef uint32_t iree_hal_amdgpu_aql_command_buffer_kernarg_template_flags_t;
+enum iree_hal_amdgpu_aql_command_buffer_kernarg_template_flag_bits_t {
+  IREE_HAL_AMDGPU_AQL_COMMAND_BUFFER_KERNARG_TEMPLATE_FLAG_NONE = 0u,
+  IREE_HAL_AMDGPU_AQL_COMMAND_BUFFER_KERNARG_TEMPLATE_FLAG_ALLOW_DYNAMIC_BINDINGS =
+      1u << 0,
+};
+
+static iree_status_t
+iree_hal_amdgpu_aql_command_buffer_write_dispatch_kernarg_template(
+    iree_hal_amdgpu_aql_command_buffer_t* command_buffer,
+    const iree_hal_amdgpu_device_kernel_args_t* kernel_args,
+    const iree_hal_amdgpu_device_dispatch_kernarg_layout_t* layout,
+    const iree_hal_dispatch_config_t config, iree_const_byte_span_t constants,
+    iree_hal_buffer_ref_list_t bindings,
+    iree_hal_amdgpu_command_buffer_kernarg_strategy_t kernarg_strategy,
+    iree_hal_amdgpu_aql_command_buffer_kernarg_template_flags_t flags,
+    uint8_t* kernarg_data) {
+  switch (kernarg_strategy) {
+    case IREE_HAL_AMDGPU_COMMAND_BUFFER_KERNARG_STRATEGY_HAL: {
+      uint64_t* binding_dst = (uint64_t*)kernarg_data;
+      for (iree_host_size_t i = 0; i < bindings.count; ++i) {
+        if (!bindings.values[i].buffer) {
+          if (IREE_UNLIKELY(!iree_any_bit_set(
+                  flags,
+                  IREE_HAL_AMDGPU_AQL_COMMAND_BUFFER_KERNARG_TEMPLATE_FLAG_ALLOW_DYNAMIC_BINDINGS))) {
+            return iree_make_status(
+                IREE_STATUS_INVALID_ARGUMENT,
+                "prepublished command-buffer kernarg template cannot contain "
+                "dynamic bindings");
+          }
+          binding_dst[i] = 0;
+          continue;
+        }
+        IREE_RETURN_IF_ERROR(
+            iree_hal_amdgpu_aql_command_buffer_resolve_static_buffer_ref(
+                command_buffer, &bindings.values[i], &binding_dst[i]));
+      }
+      const iree_host_size_t binding_bytes =
+          (iree_host_size_t)kernel_args->binding_count * sizeof(uint64_t);
+      return iree_hal_amdgpu_aql_command_buffer_write_dispatch_tail(
+          kernel_args, layout, config, constants, kernarg_strategy,
+          kernarg_data + binding_bytes);
+    }
+    case IREE_HAL_AMDGPU_COMMAND_BUFFER_KERNARG_STRATEGY_CUSTOM_DIRECT:
+      return iree_hal_amdgpu_aql_command_buffer_write_dispatch_tail(
+          kernel_args, layout, config, constants, kernarg_strategy,
+          kernarg_data);
+    default:
+      return iree_make_status(
+          IREE_STATUS_INVALID_ARGUMENT,
+          "unsupported command-buffer kernarg template strategy %u",
+          kernarg_strategy);
+  }
+}
+
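+// Materializes a fully resolved kernarg block into rodata at record time so
+// replay can reference it directly with no queue-time kernarg reservation or
+// patching.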
+static iree_status_t
+iree_hal_amdgpu_aql_command_buffer_record_prepublished_dispatch_kernargs(
+    iree_hal_amdgpu_aql_command_buffer_t* command_buffer,
+    iree_hal_amdgpu_command_buffer_dispatch_command_t* dispatch_command,
+    const iree_hal_amdgpu_device_kernel_args_t* kernel_args,
+    const iree_hal_amdgpu_device_dispatch_kernarg_layout_t* layout,
+    const iree_hal_dispatch_config_t config, iree_const_byte_span_t constants,
+    iree_hal_buffer_ref_list_t bindings,
+    iree_hal_amdgpu_command_buffer_kernarg_strategy_t kernarg_strategy,
+    iree_host_size_t kernarg_padded_length, uint64_t* out_rodata_ordinal) {
+  *out_rodata_ordinal = 0;
+  uint8_t* kernarg_data = NULL;
+  iree_hal_amdgpu_aql_command_buffer_rodata_segment_t* segment = NULL;
+  IREE_RETURN_IF_ERROR(iree_hal_amdgpu_aql_command_buffer_allocate_rodata_segment(
+      command_buffer, kernarg_padded_length, kernel_args->kernarg_alignment,
+      IREE_HAL_AMDGPU_AQL_COMMAND_BUFFER_RODATA_SEGMENT_FLAG_PREPUBLISHED_KERNARGS,
+      &kernarg_data, out_rodata_ordinal, &segment));
+  segment->prepublished.dispatch_command = dispatch_command;
+  memset(kernarg_data, 0, iree_max((iree_host_size_t)1, kernarg_padded_length));
+
+  iree_status_t status =
+      iree_hal_amdgpu_aql_command_buffer_write_dispatch_kernarg_template(
+          command_buffer, kernel_args, layout, config, constants, bindings,
+          kernarg_strategy,
+          IREE_HAL_AMDGPU_AQL_COMMAND_BUFFER_KERNARG_TEMPLATE_FLAG_NONE,
+          kernarg_data);
+  if (iree_status_is_ok(status)) {
+    status =
+        iree_hal_amdgpu_aql_command_buffer_append_prepublished_kernarg_template(
+            command_buffer, segment);
+  }
+  return status;
+}
+
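+// Materializes a kernarg template into rodata with static bindings resolved
+// and dynamic binding qwords zeroed; replay reuses the template and patches
+// only the dynamic slots.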
+static iree_status_t
+iree_hal_amdgpu_aql_command_buffer_record_patched_dispatch_kernargs(
+    iree_hal_amdgpu_aql_command_buffer_t* command_buffer,
+    const iree_hal_amdgpu_device_kernel_args_t* kernel_args,
+    const iree_hal_amdgpu_device_dispatch_kernarg_layout_t* layout,
+    const iree_hal_dispatch_config_t config, iree_const_byte_span_t constants,
+    iree_hal_buffer_ref_list_t bindings,
+    iree_hal_amdgpu_command_buffer_kernarg_strategy_t kernarg_strategy,
+    iree_host_size_t kernarg_padded_length, uint64_t* out_rodata_ordinal) {
+  *out_rodata_ordinal = 0;
+  uint8_t* kernarg_data = NULL;
+  IREE_RETURN_IF_ERROR(
+      iree_hal_amdgpu_aql_command_buffer_allocate_rodata_segment(
+          command_buffer, kernarg_padded_length, kernel_args->kernarg_alignment,
+          IREE_HAL_AMDGPU_AQL_COMMAND_BUFFER_RODATA_SEGMENT_FLAG_NONE,
+          &kernarg_data, out_rodata_ordinal, /*out_segment=*/NULL));
+  memset(kernarg_data, 0, iree_max((iree_host_size_t)1, kernarg_padded_length));
+  return iree_hal_amdgpu_aql_command_buffer_write_dispatch_kernarg_template(
+      command_buffer, kernel_args, layout, config, constants, bindings,
+      kernarg_strategy,
+      IREE_HAL_AMDGPU_AQL_COMMAND_BUFFER_KERNARG_TEMPLATE_FLAG_ALLOW_DYNAMIC_BINDINGS,
+      kernarg_data);
+}
+
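+// Dispatch recording is split into two phases: a plan that resolves the
+// export descriptor, effective kernel args, and kernarg strategy, and a
+// layout that sizes the command record, binding sources, AQL packets, and
+// kernarg storage before anything is appended to the program.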
+typedef enum iree_hal_amdgpu_aql_dispatch_plan_flag_bits_e {
+  IREE_HAL_AMDGPU_AQL_DISPATCH_PLAN_FLAG_NONE = 0u,
+  IREE_HAL_AMDGPU_AQL_DISPATCH_PLAN_FLAG_CUSTOM_DIRECT_ARGUMENTS = 1u << 0,
+  IREE_HAL_AMDGPU_AQL_DISPATCH_PLAN_FLAG_INDIRECT_PARAMETERS = 1u << 1,
+  IREE_HAL_AMDGPU_AQL_DISPATCH_PLAN_FLAG_DYNAMIC_BINDINGS = 1u << 2,
+} iree_hal_amdgpu_aql_dispatch_plan_flag_bits_t;
+
+typedef uint32_t iree_hal_amdgpu_aql_dispatch_plan_flags_t;
+
+typedef enum iree_hal_amdgpu_aql_dispatch_layout_flag_bits_e {
+  IREE_HAL_AMDGPU_AQL_DISPATCH_LAYOUT_FLAG_NONE = 0u,
+  IREE_HAL_AMDGPU_AQL_DISPATCH_LAYOUT_FLAG_PREPUBLISH_KERNARGS = 1u << 0,
+  IREE_HAL_AMDGPU_AQL_DISPATCH_LAYOUT_FLAG_PATCH_KERNARG_TEMPLATE = 1u << 1,
+} iree_hal_amdgpu_aql_dispatch_layout_flag_bits_t;
+
+typedef uint32_t iree_hal_amdgpu_aql_dispatch_layout_flags_t;
+
+typedef struct iree_hal_amdgpu_aql_dispatch_inputs_t {
+  // Executable containing the requested export.
+  iree_hal_executable_t* executable;
+
+  // Export ordinal within |executable|.
+  iree_hal_executable_export_ordinal_t export_ordinal;
+
+  // HAL dispatch configuration.
+  iree_hal_dispatch_config_t config;
+
+  // Borrowed constant bytes passed to the dispatch.
+  iree_const_byte_span_t constants;
+
+  // Borrowed binding list passed to the dispatch.
+  iree_hal_buffer_ref_list_t bindings;
+
+  // HAL dispatch flags.
+  iree_hal_dispatch_flags_t flags;
+} iree_hal_amdgpu_aql_dispatch_inputs_t;
+
+typedef struct iree_hal_amdgpu_aql_dispatch_plan_t {
+  // Descriptor resolved for the selected physical device/export pair.
+  const iree_hal_amdgpu_executable_dispatch_descriptor_t* descriptor;
+
+  // Workgroup-size override storage used when |kernel_args| points here.
+  iree_hal_amdgpu_device_kernel_args_t override_kernel_args;
+
+  // Kernel argument descriptor selected from |descriptor|.
+  const iree_hal_amdgpu_device_kernel_args_t* kernel_args;
+
+  // Kernarg layout selected for HAL or custom-direct arguments.
+  const iree_hal_amdgpu_device_dispatch_kernarg_layout_t* layout;
+
+  // Number of kernarg blocks required by the selected descriptor path.
+  uint32_t kernarg_block_count;
+
+  // Binding-source plan for this dispatch.
+  struct {
+    // Number of dynamic binding table sources used by this dispatch.
+    uint16_t dynamic_count;
+  } bindings;
+
+  // Command-buffer kernarg strategy used for this dispatch.
+  iree_hal_amdgpu_command_buffer_kernarg_strategy_t kernarg_strategy;
+
+  // Plan flags from iree_hal_amdgpu_aql_dispatch_plan_flag_bits_t.
+  iree_hal_amdgpu_aql_dispatch_plan_flags_t flags;
+} iree_hal_amdgpu_aql_dispatch_plan_t;
+
+static bool iree_hal_amdgpu_aql_dispatch_plan_uses_custom_direct_arguments(
+    const iree_hal_amdgpu_aql_dispatch_plan_t* plan) {
+  return iree_any_bit_set(
+      plan->flags,
+      IREE_HAL_AMDGPU_AQL_DISPATCH_PLAN_FLAG_CUSTOM_DIRECT_ARGUMENTS);
+}
+
+static bool iree_hal_amdgpu_aql_dispatch_plan_uses_indirect_parameters(
+    const iree_hal_amdgpu_aql_dispatch_plan_t* plan) {
+  return iree_any_bit_set(
+      plan->flags, IREE_HAL_AMDGPU_AQL_DISPATCH_PLAN_FLAG_INDIRECT_PARAMETERS);
+}
+
+static bool iree_hal_amdgpu_aql_dispatch_plan_has_dynamic_bindings(
+    const iree_hal_amdgpu_aql_dispatch_plan_t* plan) {
+  return iree_any_bit_set(
+      plan->flags, IREE_HAL_AMDGPU_AQL_DISPATCH_PLAN_FLAG_DYNAMIC_BINDINGS);
+}
+
+typedef struct iree_hal_amdgpu_aql_dispatch_layout_t {
+  // Command record and binding-source layout in the AQL program.
+  struct {
+    // Byte length of the command record allocation.
+    iree_host_size_t byte_length;
+
+    // Number of binding-source records following the command.
+    uint16_t binding_source_count;
+
+    // Worst-case AQL packets required by replay.
+    uint32_t aql_packet_count;
+  } command;
+
+  // Queue-time kernarg allocation requirements.
+  struct {
+    // Total kernarg qword length.
+    uint16_t total_length_qwords;
+
+    // Tail payload qword length stored after the command record.
+    uint16_t tail_length_qwords;
+
+    // Implicit-argument qword offset, or UINT16_MAX when absent.
+    uint16_t implicit_args_offset_qwords;
+
+    // Total kernarg length padded to qword alignment.
+    iree_host_size_t total_padded_length;
+
+    // Tail payload length padded to qword alignment.
+    iree_host_size_t tail_padded_length;
+
+    // Queue-time kernarg block bytes reserved by replay.
+    uint32_t queue_block_length;
+  } kernarg;
+
+  // Layout flags from iree_hal_amdgpu_aql_dispatch_layout_flag_bits_t.
+  iree_hal_amdgpu_aql_dispatch_layout_flags_t flags;
+} iree_hal_amdgpu_aql_dispatch_layout_t;
+
+static bool iree_hal_amdgpu_aql_dispatch_layout_prepublishes_kernargs(
+    const iree_hal_amdgpu_aql_dispatch_layout_t* layout) {
+  return iree_any_bit_set(
+      layout->flags,
+      IREE_HAL_AMDGPU_AQL_DISPATCH_LAYOUT_FLAG_PREPUBLISH_KERNARGS);
+}
+
+static bool iree_hal_amdgpu_aql_dispatch_layout_patches_kernarg_template(
+    const iree_hal_amdgpu_aql_dispatch_layout_t* layout) {
+  return iree_any_bit_set(
+      layout->flags,
+      IREE_HAL_AMDGPU_AQL_DISPATCH_LAYOUT_FLAG_PATCH_KERNARG_TEMPLATE);
+}
+
+static bool iree_hal_amdgpu_aql_dispatch_layout_uses_kernarg_template(
+    const iree_hal_amdgpu_aql_dispatch_layout_t* layout) {
+  return iree_any_bit_set(
+      layout->flags,
+      IREE_HAL_AMDGPU_AQL_DISPATCH_LAYOUT_FLAG_PREPUBLISH_KERNARGS |
+          IREE_HAL_AMDGPU_AQL_DISPATCH_LAYOUT_FLAG_PATCH_KERNARG_TEMPLATE);
+}
+
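+// Resolves the dispatch plan: the export descriptor for this physical
+// device, the effective kernel args, the kernarg strategy, and validated
+// constants and bindings.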
+static iree_status_t iree_hal_amdgpu_aql_command_buffer_prepare_dispatch_plan(
+    iree_hal_amdgpu_aql_command_buffer_t* command_buffer,
+    const iree_hal_amdgpu_aql_dispatch_inputs_t* inputs, bool validates,
+    iree_hal_amdgpu_aql_dispatch_plan_t* out_plan) {
+  *out_plan = (iree_hal_amdgpu_aql_dispatch_plan_t){0};
+  IREE_RETURN_IF_ERROR(
+      iree_hal_amdgpu_aql_command_buffer_check_dispatch_flags(inputs->flags));
+  IREE_RETURN_IF_ERROR(
+      iree_hal_amdgpu_executable_lookup_dispatch_descriptor_for_device(
+          inputs->executable, inputs->export_ordinal,
+          command_buffer->device_ordinal, &out_plan->descriptor));
+
+  if (validates) {
+    IREE_RETURN_IF_ERROR(
+        iree_hal_amdgpu_aql_command_buffer_validate_dispatch_shape(
+            out_plan->descriptor, inputs->config, inputs->flags));
+  }
+
+  iree_hal_amdgpu_aql_command_buffer_select_dispatch_kernel_args(
+      out_plan->descriptor, inputs->config, &out_plan->override_kernel_args,
+      &out_plan->kernel_args);
+
+  if (iree_any_bit_set(inputs->flags,
+                       IREE_HAL_DISPATCH_FLAG_CUSTOM_DIRECT_ARGUMENTS)) {
+    out_plan->flags |=
+        IREE_HAL_AMDGPU_AQL_DISPATCH_PLAN_FLAG_CUSTOM_DIRECT_ARGUMENTS;
+  }
+  if (iree_hal_dispatch_uses_indirect_parameters(inputs->flags)) {
+    out_plan->flags |=
+        IREE_HAL_AMDGPU_AQL_DISPATCH_PLAN_FLAG_INDIRECT_PARAMETERS;
+  }
+  out_plan->kernarg_strategy =
+      IREE_HAL_AMDGPU_COMMAND_BUFFER_KERNARG_STRATEGY_HAL;
+  if (IREE_UNLIKELY(inputs->constants.data_length > 0 &&
+                    !inputs->constants.data)) {
+    return iree_make_status(
+        IREE_STATUS_INVALID_ARGUMENT,
+        "dispatch constant data must be non-null when length is non-zero");
+  }
+
+  if (iree_hal_amdgpu_aql_dispatch_plan_uses_custom_direct_arguments(
+          out_plan)) {
+    if (IREE_UNLIKELY(inputs->constants.data_length !=
+                      out_plan->descriptor->kernel_args.kernarg_size)) {
+      return iree_make_status(
+          IREE_STATUS_INVALID_ARGUMENT,
+          "custom dispatch argument length mismatch; expected %u but got "
+          "%" PRIhsz,
+          out_plan->descriptor->kernel_args.kernarg_size,
+          inputs->constants.data_length);
+    }
+    out_plan->layout = &out_plan->descriptor->custom_kernarg_layout;
+    out_plan->kernarg_block_count =
+        iree_max(1u, out_plan->descriptor->custom_kernarg_block_count);
+    out_plan->kernarg_strategy =
+        IREE_HAL_AMDGPU_COMMAND_BUFFER_KERNARG_STRATEGY_CUSTOM_DIRECT;
+    return iree_ok_status();
+  }
+
+  const iree_host_size_t expected_constant_length =
+      (iree_host_size_t)out_plan->descriptor->kernel_args.constant_count *
+      sizeof(uint32_t);
+  if (IREE_UNLIKELY(inputs->constants.data_length !=
+                    expected_constant_length)) {
+    return iree_make_status(
+        IREE_STATUS_INVALID_ARGUMENT,
+        "dispatch constant count mismatch; expected %u but got %" PRIhsz,
+        (uint32_t)out_plan->descriptor->kernel_args.constant_count,
+        inputs->constants.data_length / sizeof(uint32_t));
+  }
+  if (IREE_UNLIKELY(inputs->bindings.count !=
+                    out_plan->descriptor->kernel_args.binding_count)) {
+    return iree_make_status(
+        IREE_STATUS_INVALID_ARGUMENT,
+        "dispatch binding count mismatch; expected %u but got %" PRIhsz,
+        (uint32_t)out_plan->descriptor->kernel_args.binding_count,
+        inputs->bindings.count);
+  }
+  if (IREE_UNLIKELY(inputs->bindings.count > 0 && !inputs->bindings.values)) {
+    return iree_make_status(
+        IREE_STATUS_INVALID_ARGUMENT,
+        "dispatch bindings must be non-null when count is non-zero");
+  }
+  out_plan->bindings.dynamic_count =
+      iree_hal_amdgpu_aql_command_buffer_count_dynamic_dispatch_bindings(
+          inputs->bindings);
+  if (out_plan->bindings.dynamic_count != 0) {
+    out_plan->flags |= IREE_HAL_AMDGPU_AQL_DISPATCH_PLAN_FLAG_DYNAMIC_BINDINGS;
+  }
+  if (out_plan->bindings.dynamic_count != 0 &&
+      !iree_hal_amdgpu_aql_dispatch_plan_uses_indirect_parameters(out_plan) &&
+      out_plan->bindings.dynamic_count == inputs->bindings.count) {
+    out_plan->kernarg_strategy =
+        IREE_HAL_AMDGPU_COMMAND_BUFFER_KERNARG_STRATEGY_DYNAMIC_BINDINGS;
+  }
+  IREE_RETURN_IF_ERROR(
+      iree_hal_amdgpu_aql_command_buffer_prepare_dispatch_binding_sources(
+          command_buffer, inputs->bindings));
+  out_plan->layout = &out_plan->descriptor->hal_kernarg_layout;
+  out_plan->kernarg_block_count =
+      iree_max(1u, out_plan->descriptor->hal_kernarg_block_count);
+  return iree_ok_status();
+}
+
+static iree_status_t iree_hal_amdgpu_aql_command_buffer_retain_dispatch_inputs(
+    iree_hal_amdgpu_aql_command_buffer_t* command_buffer,
+    const iree_hal_amdgpu_aql_dispatch_inputs_t* inputs) {
+  IREE_RETURN_IF_ERROR(
+      iree_hal_amdgpu_aql_command_buffer_ensure_resource_set(command_buffer));
+  if (!command_buffer->resource_set) return iree_ok_status();
+  return iree_hal_resource_set_insert(command_buffer->resource_set,
+                                      /*count=*/1, &inputs->executable);
+}
+
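+// Computes the record and kernarg layout for a planned dispatch, including
+// whether kernargs are prepublished, template-patched, or written inline
+// with the command.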
+static iree_status_t
+iree_hal_amdgpu_aql_command_buffer_calculate_dispatch_layout(
+    const iree_hal_amdgpu_aql_command_buffer_t* command_buffer,
+    const iree_hal_amdgpu_aql_dispatch_inputs_t* inputs,
+    const iree_hal_amdgpu_aql_dispatch_plan_t* plan,
+    iree_hal_amdgpu_aql_dispatch_layout_t* out_layout) {
+  *out_layout = (iree_hal_amdgpu_aql_dispatch_layout_t){0};
+
+  const iree_host_size_t binding_bytes =
+      iree_hal_amdgpu_aql_dispatch_plan_uses_custom_direct_arguments(plan)
+          ? 0
+          : (iree_host_size_t)plan->kernel_args->binding_count *
+                sizeof(uint64_t);
+  const iree_host_size_t tail_byte_length =
+      plan->layout->total_kernarg_size - binding_bytes;
+  IREE_RETURN_IF_ERROR(iree_hal_amdgpu_aql_command_buffer_qword_length(
+      tail_byte_length, "dispatch tail payload",
+      &out_layout->kernarg.tail_length_qwords,
+      &out_layout->kernarg.tail_padded_length));
+  IREE_RETURN_IF_ERROR(iree_hal_amdgpu_aql_command_buffer_qword_length(
+      plan->layout->total_kernarg_size, "dispatch kernarg",
+      &out_layout->kernarg.total_length_qwords,
+      &out_layout->kernarg.total_padded_length));
+  out_layout->kernarg.implicit_args_offset_qwords =
+      plan->layout->has_implicit_args
+          ? (uint16_t)(plan->layout->implicit_args_offset / 8)
+          : UINT16_MAX;
+  if (IREE_UNLIKELY(plan->kernarg_block_count >
+                    UINT32_MAX / sizeof(iree_hal_amdgpu_kernarg_block_t))) {
+    return iree_make_status(
+        IREE_STATUS_OUT_OF_RANGE,
+        "dispatch kernargs require too many kernarg blocks (%" PRIu32 ")",
+        plan->kernarg_block_count);
+  }
+
+  const uint32_t dispatch_kernarg_block_length =
+      plan->kernarg_block_count * sizeof(iree_hal_amdgpu_kernarg_block_t);
+  const bool uses_indirect_parameters =
+      iree_hal_amdgpu_aql_dispatch_plan_uses_indirect_parameters(plan);
+  const bool uses_custom_direct_arguments =
+      iree_hal_amdgpu_aql_dispatch_plan_uses_custom_direct_arguments(plan);
+  const uint32_t patch_kernarg_block_length =
+      uses_indirect_parameters ? sizeof(iree_hal_amdgpu_kernarg_block_t) : 0;
+  const uint32_t kernarg_block_length =
+      patch_kernarg_block_length + dispatch_kernarg_block_length;
+  // Prepublication is a reusable-command-buffer strategy for immutable
+  // kernargs. It materializes static kernargs once at end() so replay avoids
+  // queue-time kernarg reservation, binding patching, and block growth.
+  if (iree_hal_amdgpu_aql_command_buffer_prepublish_enabled(command_buffer) &&
+      !iree_all_bits_set(command_buffer->base.mode,
+                         IREE_HAL_COMMAND_BUFFER_MODE_ONE_SHOT) &&
+      !uses_indirect_parameters &&
+      !iree_hal_amdgpu_aql_dispatch_plan_has_dynamic_bindings(plan)) {
+    out_layout->flags |=
+        IREE_HAL_AMDGPU_AQL_DISPATCH_LAYOUT_FLAG_PREPUBLISH_KERNARGS;
+  }
+  // Mixed static/dynamic reusable dispatches keep an immutable host template
+  // and patch only the dynamic binding qwords at replay time. All-dynamic
+  // dispatches stay on the compact inline form but use a dynamic-only replay
+  // strategy so packet processing does not branch over impossible static
+  // binding source cases.
+  if (!iree_all_bits_set(command_buffer->base.mode,
+                         IREE_HAL_COMMAND_BUFFER_MODE_ONE_SHOT) &&
+      !uses_indirect_parameters && !uses_custom_direct_arguments &&
+      iree_hal_amdgpu_aql_dispatch_plan_has_dynamic_bindings(plan) &&
+      plan->bindings.dynamic_count < inputs->bindings.count) {
+    out_layout->flags |=
+        IREE_HAL_AMDGPU_AQL_DISPATCH_LAYOUT_FLAG_PATCH_KERNARG_TEMPLATE;
+  }
+  const bool prepublishes_kernargs =
+      iree_hal_amdgpu_aql_dispatch_layout_prepublishes_kernargs(out_layout);
+  const bool uses_kernarg_template =
+      iree_hal_amdgpu_aql_dispatch_layout_uses_kernarg_template(out_layout);
+  const bool patches_kernarg_template =
+      iree_hal_amdgpu_aql_dispatch_layout_patches_kernarg_template(out_layout);
+  out_layout->kernarg.queue_block_length =
+      prepublishes_kernargs ? 0 : kernarg_block_length;
+
+  IREE_RETURN_IF_ERROR(IREE_STRUCT_LAYOUT(
+      sizeof(iree_hal_amdgpu_command_buffer_dispatch_command_t),
+      &out_layout->command.byte_length,
+      IREE_STRUCT_FIELD(
+          uses_kernarg_template ? 0 : out_layout->kernarg.tail_padded_length,
+          uint8_t, NULL)));
+  out_layout->command.binding_source_count =
+      prepublishes_kernargs
+          ? 0
+          : (patches_kernarg_template
+                 ? plan->bindings.dynamic_count
+                 : (uint16_t)((uses_custom_direct_arguments
+                                   ? 0
+                                   : inputs->bindings.count) +
+                              (uses_indirect_parameters ? 1 : 0)));
+  out_layout->command.aql_packet_count = uses_indirect_parameters ? 2 : 1;
+  return iree_ok_status();
+}
+
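+// Fills the dispatch command record from the plan and layout; grid sizes
+// are zeroed for indirect dispatches and derived from the workgroup count
+// otherwise.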
+static void iree_hal_amdgpu_aql_command_buffer_initialize_dispatch_command(
+    iree_hal_amdgpu_aql_command_buffer_t* command_buffer,
+    const iree_hal_amdgpu_aql_dispatch_inputs_t* inputs,
+    const iree_hal_amdgpu_aql_dispatch_plan_t* plan,
+    const iree_hal_amdgpu_aql_dispatch_layout_t* layout,
+    uint64_t kernarg_template_reference,
+    iree_hal_amdgpu_command_buffer_command_header_t* header,
+    iree_hal_amdgpu_command_buffer_binding_source_t* binding_sources) {
+  iree_hal_amdgpu_command_buffer_dispatch_command_t* dispatch_command =
+      (iree_hal_amdgpu_command_buffer_dispatch_command_t*)header;
+  dispatch_command->kernel_object = plan->kernel_args->kernel_object;
+  dispatch_command->binding_source_offset =
+      binding_sources
+          ? (uint32_t)((uint8_t*)binding_sources -
+                       (uint8_t*)command_buffer->builder.current_block.header)
+          : 0;
+  const bool prepublishes_kernargs =
+      iree_hal_amdgpu_aql_dispatch_layout_prepublishes_kernargs(layout);
+  const bool patches_kernarg_template =
+      iree_hal_amdgpu_aql_dispatch_layout_patches_kernarg_template(layout);
+  const bool uses_kernarg_template =
+      prepublishes_kernargs || patches_kernarg_template;
+  dispatch_command->payload_reference =
+      uses_kernarg_template
+          ? (uint32_t)kernarg_template_reference
+          : sizeof(iree_hal_amdgpu_command_buffer_dispatch_command_t);
+  dispatch_command->binding_count = (uint16_t)inputs->bindings.count;
+  dispatch_command->kernarg_length_qwords = layout->kernarg.total_length_qwords;
+  dispatch_command->payload.tail_length_qwords =
+      uses_kernarg_template ? 0 : layout->kernarg.tail_length_qwords;
+  if (patches_kernarg_template) {
+    dispatch_command->payload.patch_source_count = plan->bindings.dynamic_count;
+  }
+  if (prepublishes_kernargs) {
+    dispatch_command->kernarg_strategy =
+        IREE_HAL_AMDGPU_COMMAND_BUFFER_KERNARG_STRATEGY_PREPUBLISHED;
+  } else if (patches_kernarg_template) {
+    dispatch_command->kernarg_strategy =
+        IREE_HAL_AMDGPU_COMMAND_BUFFER_KERNARG_STRATEGY_PATCHED_TEMPLATE;
+  } else {
+    dispatch_command->kernarg_strategy = (uint8_t)plan->kernarg_strategy;
+  }
+  const bool uses_indirect_parameters =
+      iree_hal_amdgpu_aql_dispatch_plan_uses_indirect_parameters(plan);
+  dispatch_command->dispatch_flags =
+      uses_indirect_parameters
+          ? IREE_HAL_AMDGPU_COMMAND_BUFFER_DISPATCH_FLAG_INDIRECT_PARAMETERS
+          : IREE_HAL_AMDGPU_COMMAND_BUFFER_DISPATCH_FLAG_NONE;
+  dispatch_command->setup = plan->kernel_args->setup;
+  dispatch_command->export_ordinal = inputs->export_ordinal;
+  dispatch_command->workgroup_size[0] = plan->kernel_args->workgroup_size[0];
+  dispatch_command->workgroup_size[1] = plan->kernel_args->workgroup_size[1];
+  dispatch_command->workgroup_size[2] = plan->kernel_args->workgroup_size[2];
+  dispatch_command->implicit_args_offset_qwords =
+      layout->kernarg.implicit_args_offset_qwords;
+  dispatch_command->grid_size[0] =
+      uses_indirect_parameters ? 0
+                               : inputs->config.workgroup_count[0] *
+                                     plan->kernel_args->workgroup_size[0];
+  dispatch_command->grid_size[1] =
+      uses_indirect_parameters ? 0
+                               : inputs->config.workgroup_count[1] *
+                                     plan->kernel_args->workgroup_size[1];
+  dispatch_command->grid_size[2] =
+      uses_indirect_parameters ? 0
+                               : inputs->config.workgroup_count[2] *
+                                     plan->kernel_args->workgroup_size[2];
+  dispatch_command->private_segment_size =
+      plan->kernel_args->private_segment_size;
+  dispatch_command->group_segment_size =
+      plan->kernel_args->group_segment_size +
+      inputs->config.dynamic_workgroup_local_memory;
+  dispatch_command->executable_id =
+      iree_hal_amdgpu_executable_profile_id(inputs->executable);
+}
+
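+// Dispatch summaries are chained per command-buffer block so that consumers
+// (such as the profiling path) can map AQL packet ordinals back to the
+// dispatch metadata that produced them.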
+static iree_status_t
+iree_hal_amdgpu_aql_command_buffer_ensure_current_dispatch_summary_block(
+    iree_hal_amdgpu_aql_command_buffer_t* command_buffer,
+    const iree_hal_amdgpu_command_buffer_block_header_t* block,
+    iree_hal_amdgpu_aql_command_buffer_dispatch_summary_block_t**
+        out_summary_block) {
+  *out_summary_block = command_buffer->dispatch_summaries.block.current;
+  if (*out_summary_block && (*out_summary_block)->header == block) {
+    return iree_ok_status();
+  }
+
+  iree_hal_amdgpu_aql_command_buffer_dispatch_summary_block_t* summary_block =
+      NULL;
+  IREE_RETURN_IF_ERROR(iree_arena_allocate(&command_buffer->recording_arena,
+                                           sizeof(*summary_block),
+                                           (void**)&summary_block));
+  memset(summary_block, 0, sizeof(*summary_block));
+  summary_block->header = block;
+  if (command_buffer->dispatch_summaries.block.current) {
+    command_buffer->dispatch_summaries.block.current->next = summary_block;
+  } else {
+    command_buffer->dispatch_summaries.block.first = summary_block;
+  }
+  command_buffer->dispatch_summaries.block.current = summary_block;
+  *out_summary_block = summary_block;
+  return iree_ok_status();
+}
+
+static iree_status_t iree_hal_amdgpu_aql_command_buffer_record_dispatch_summary(
+    iree_hal_amdgpu_aql_command_buffer_t* command_buffer,
+    const iree_hal_amdgpu_command_buffer_dispatch_command_t* dispatch_command,
+    const iree_hal_amdgpu_aql_dispatch_layout_t* layout) {
+  if (!iree_hal_amdgpu_aql_command_buffer_retains_dispatch_summaries(
+          command_buffer)) {
+    return iree_ok_status();
+  }
+
+  iree_hal_amdgpu_aql_command_buffer_dispatch_summary_block_t* summary_block =
+      NULL;
+  IREE_RETURN_IF_ERROR(
+      iree_hal_amdgpu_aql_command_buffer_ensure_current_dispatch_summary_block(
+          command_buffer, command_buffer->builder.current_block.header,
+          &summary_block));
+
+  iree_hal_amdgpu_aql_command_buffer_dispatch_summary_t* summary = NULL;
+  IREE_RETURN_IF_ERROR(iree_arena_allocate(&command_buffer->recording_arena,
+                                           sizeof(*summary), (void**)&summary));
+  memset(summary, 0, sizeof(*summary));
+  const uint32_t first_packet_ordinal =
+      command_buffer->builder.current_block.aql_packet_count -
+      layout->command.aql_packet_count;
+  summary->packets.first_ordinal = first_packet_ordinal;
+  const bool uses_indirect_parameters = iree_any_bit_set(
+      dispatch_command->dispatch_flags,
+      IREE_HAL_AMDGPU_COMMAND_BUFFER_DISPATCH_FLAG_INDIRECT_PARAMETERS);
+  summary->packets.dispatch_ordinal =
+      first_packet_ordinal + (uses_indirect_parameters ? 1u : 0u);
+  summary->metadata.executable_id = dispatch_command->executable_id;
+  summary->metadata.command_index = dispatch_command->header.command_index;
+  summary->metadata.export_ordinal = dispatch_command->export_ordinal;
+  summary->metadata.dispatch_flags = dispatch_command->dispatch_flags;
+  for (iree_host_size_t i = 0; i < IREE_ARRAYSIZE(summary->workgroup.size);
+       ++i) {
+    summary->workgroup.size[i] = dispatch_command->workgroup_size[i];
+    if (!uses_indirect_parameters && dispatch_command->workgroup_size[i] != 0) {
+      summary->workgroup.count[i] =
+          dispatch_command->grid_size[i] / dispatch_command->workgroup_size[i];
+    }
+  }
+
+  if (summary_block->dispatch.last) {
+    summary_block->dispatch.last->next = summary;
+  } else {
+    summary_block->dispatch.first = summary;
+  }
+  summary_block->dispatch.last = summary;
+  ++summary_block->dispatch.count;
+  ++command_buffer->dispatch_summaries.count;
+  return iree_ok_status();
+}
+
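+// Appends the dispatch command record to the program, recording any
+// prepublished or patched kernarg template as rodata when the layout
+// requests one.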
+static iree_status_t iree_hal_amdgpu_aql_command_buffer_append_dispatch_command(
+    iree_hal_amdgpu_aql_command_buffer_t* command_buffer,
+    const iree_hal_amdgpu_aql_dispatch_inputs_t* inputs,
+    const iree_hal_amdgpu_aql_dispatch_plan_t* plan,
+    const iree_hal_amdgpu_aql_dispatch_layout_t* layout,
+    iree_hal_amdgpu_command_buffer_dispatch_command_t** out_dispatch_command,
+    iree_hal_amdgpu_command_buffer_binding_source_t** out_binding_sources) {
+  *out_dispatch_command = NULL;
+  *out_binding_sources = NULL;
+
+  iree_hal_amdgpu_command_buffer_command_header_t* header = NULL;
+  const bool uses_indirect_parameters =
+      iree_hal_amdgpu_aql_dispatch_plan_uses_indirect_parameters(plan);
+  const uint8_t command_flags =
+      uses_indirect_parameters
+          ? IREE_HAL_AMDGPU_COMMAND_BUFFER_COMMAND_FLAG_HAS_BARRIER
+          : IREE_HAL_AMDGPU_COMMAND_BUFFER_COMMAND_FLAG_NONE;
+  IREE_RETURN_IF_ERROR(iree_hal_amdgpu_aql_program_builder_append_command(
+      &command_buffer->builder, IREE_HAL_AMDGPU_COMMAND_BUFFER_OPCODE_DISPATCH,
+      command_flags, layout->command.byte_length,
+      layout->command.binding_source_count, layout->command.aql_packet_count,
+      layout->kernarg.queue_block_length, &header, out_binding_sources));
+
+  uint64_t kernarg_template_reference = 0;
+  if (iree_hal_amdgpu_aql_dispatch_layout_prepublishes_kernargs(layout)) {
+    IREE_RETURN_IF_ERROR(
+        iree_hal_amdgpu_aql_command_buffer_record_prepublished_dispatch_kernargs(
+            command_buffer,
+            (iree_hal_amdgpu_command_buffer_dispatch_command_t*)header,
+            plan->kernel_args, plan->layout, inputs->config, inputs->constants,
+            inputs->bindings, plan->kernarg_strategy,
+            layout->kernarg.total_padded_length, &kernarg_template_reference));
+    if (IREE_UNLIKELY(kernarg_template_reference > UINT32_MAX)) {
+      return iree_make_status(IREE_STATUS_OUT_OF_RANGE,
+                              "prepublished command-buffer kernarg rodata "
+                              "ordinal exceeds uint32_t");
+    }
+  } else if (iree_hal_amdgpu_aql_dispatch_layout_patches_kernarg_template(
+                 layout)) {
+    IREE_RETURN_IF_ERROR(
+        iree_hal_amdgpu_aql_command_buffer_record_patched_dispatch_kernargs(
+            command_buffer, plan->kernel_args, plan->layout, inputs->config,
+            inputs->constants, inputs->bindings, plan->kernarg_strategy,
+            layout->kernarg.total_padded_length, &kernarg_template_reference));
+    if (IREE_UNLIKELY(kernarg_template_reference > UINT32_MAX)) {
+      return iree_make_status(IREE_STATUS_OUT_OF_RANGE,
+                              "patched command-buffer kernarg template rodata "
+                              "ordinal exceeds uint32_t");
+    }
+  }
+
+  iree_hal_amdgpu_aql_command_buffer_initialize_dispatch_command(
+      command_buffer, inputs, plan, layout, kernarg_template_reference, header,
+      *out_binding_sources);
+  if (uses_indirect_parameters) {
+    ++command_buffer->builder.current_block.indirect_dispatch_count;
+  }
+  *out_dispatch_command =
+      (iree_hal_amdgpu_command_buffer_dispatch_command_t*)header;
+  return iree_ok_status();
+}
+
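+// Writes the dispatch payload once the command record is in place: binding
+// sources, any indirect parameter source, and the inline kernarg tail when
+// no kernarg template is used.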
+static iree_status_t iree_hal_amdgpu_aql_command_buffer_write_dispatch_payload(
+    iree_hal_amdgpu_aql_command_buffer_t* command_buffer,
+    const iree_hal_amdgpu_aql_dispatch_inputs_t* inputs,
+    const iree_hal_amdgpu_aql_dispatch_plan_t* plan,
+    const iree_hal_amdgpu_aql_dispatch_layout_t* layout,
+    iree_hal_amdgpu_command_buffer_dispatch_command_t* dispatch_command,
+    iree_hal_amdgpu_command_buffer_binding_source_t* binding_sources) {
+  const bool uses_custom_direct_arguments =
+      iree_hal_amdgpu_aql_dispatch_plan_uses_custom_direct_arguments(plan);
+  if (binding_sources && !uses_custom_direct_arguments) {
+    if (iree_hal_amdgpu_aql_dispatch_layout_patches_kernarg_template(layout)) {
+      iree_hal_amdgpu_aql_command_buffer_write_dynamic_dispatch_binding_sources(
+          inputs->bindings, binding_sources);
+    } else {
+      IREE_RETURN_IF_ERROR(
+          iree_hal_amdgpu_aql_command_buffer_write_dispatch_binding_sources(
+              command_buffer, inputs->bindings, binding_sources));
+    }
+  }
+  if (iree_hal_amdgpu_aql_dispatch_plan_uses_indirect_parameters(plan)) {
+    iree_hal_amdgpu_command_buffer_binding_source_t* parameter_source =
+        binding_sources +
+        (uses_custom_direct_arguments ? 0 : inputs->bindings.count);
+    IREE_RETURN_IF_ERROR(
+        iree_hal_amdgpu_aql_command_buffer_write_indirect_parameter_source(
+            command_buffer, inputs->config.workgroup_count_ref,
+            parameter_source));
+  }
+  if (iree_hal_amdgpu_aql_dispatch_layout_uses_kernarg_template(layout)) {
+    return iree_ok_status();
+  }
+
+  uint8_t* tail_payload =
+      (uint8_t*)dispatch_command + dispatch_command->payload_reference;
+  return iree_hal_amdgpu_aql_command_buffer_write_dispatch_tail(
+      plan->kernel_args, plan->layout, inputs->config, inputs->constants,
+      plan->kernarg_strategy, tail_payload);
+}
+
+//===----------------------------------------------------------------------===//
+// Buffer Commands
+//===----------------------------------------------------------------------===//
+
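+// Memory advice hints are accepted but currently ignored by the AQL path.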
+static iree_status_t iree_hal_amdgpu_aql_command_buffer_advise_buffer(
+    iree_hal_command_buffer_t* base_command_buffer,
+    iree_hal_buffer_ref_t buffer_ref, iree_hal_memory_advise_flags_t flags,
+    uint64_t arg0, uint64_t arg1) {
+  return iree_ok_status();
+}
+
+static iree_status_t iree_hal_amdgpu_aql_command_buffer_fill_buffer(
+    iree_hal_command_buffer_t* base_command_buffer,
+    iree_hal_buffer_ref_t target_ref, const void* pattern,
+    iree_host_size_t pattern_length, iree_hal_fill_flags_t flags) {
+  if (IREE_UNLIKELY(!pattern)) {
+    return iree_make_status(IREE_STATUS_INVALID_ARGUMENT,
+                            "fill pattern must be non-null");
+  }
+  if (IREE_UNLIKELY(pattern_length != 1 && pattern_length != 2 &&
+                    pattern_length != 4)) {
+    return iree_make_status(
+        IREE_STATUS_INVALID_ARGUMENT,
+        "fill patterns must be 1, 2, or 4 bytes (got %" PRIhsz ")",
+        pattern_length);
+  }
+  if (IREE_UNLIKELY(flags != IREE_HAL_FILL_FLAG_NONE)) {
+    return iree_make_status(IREE_STATUS_INVALID_ARGUMENT,
+                            "unsupported fill flags: 0x%" PRIx64, flags);
+  }
+
+  iree_hal_amdgpu_aql_command_buffer_t* command_buffer =
+      iree_hal_amdgpu_aql_command_buffer_cast(base_command_buffer);
+  iree_hal_amdgpu_command_buffer_binding_kind_t target_kind =
+      IREE_HAL_AMDGPU_COMMAND_BUFFER_BINDING_KIND_INVALID;
+  uint32_t target_ordinal = 0;
+  uint64_t target_offset = 0;
+  uint64_t length = 0;
+  IREE_RETURN_IF_ERROR(iree_hal_amdgpu_aql_command_buffer_record_buffer_ref(
+      command_buffer, target_ref, &target_kind, &target_ordinal, &target_offset,
+      &length));
+
+  uint64_t pattern_bits = 0;
+  memcpy(&pattern_bits, pattern, pattern_length);
+  iree_hal_amdgpu_command_buffer_command_header_t* header = NULL;
+  IREE_RETURN_IF_ERROR(iree_hal_amdgpu_aql_program_builder_append_command(
+      &command_buffer->builder, IREE_HAL_AMDGPU_COMMAND_BUFFER_OPCODE_FILL,
+      IREE_HAL_AMDGPU_COMMAND_BUFFER_COMMAND_FLAG_NONE,
+      sizeof(iree_hal_amdgpu_command_buffer_fill_command_t),
+      /*binding_source_count=*/0, /*aql_packet_count=*/1,
+      sizeof(iree_hal_amdgpu_kernarg_block_t), &header,
+      /*out_binding_sources=*/NULL));
+
+  iree_hal_amdgpu_command_buffer_fill_command_t* fill_command =
+      (iree_hal_amdgpu_command_buffer_fill_command_t*)header;
+  fill_command->target_offset = target_offset;
+  fill_command->length = length;
+  fill_command->pattern = pattern_bits;
+  fill_command->target_ordinal = target_ordinal;
+  fill_command->target_kind = target_kind;
+  fill_command->pattern_length = (uint8_t)pattern_length;
+  return iree_ok_status();
+}
+
+static iree_status_t iree_hal_amdgpu_aql_command_buffer_update_buffer(
+    iree_hal_command_buffer_t* base_command_buffer, const void* source_buffer,
+    iree_host_size_t source_offset, iree_hal_buffer_ref_t target_ref,
+    iree_hal_update_flags_t flags) {
+  if (IREE_UNLIKELY(!source_buffer)) {
+    return iree_make_status(IREE_STATUS_INVALID_ARGUMENT,
+                            "update source buffer must be non-null");
+  }
+  if (IREE_UNLIKELY(target_ref.length >
+                    IREE_HAL_COMMAND_BUFFER_MAX_UPDATE_SIZE)) {
+    return iree_make_status(IREE_STATUS_INVALID_ARGUMENT,
+                            "command-buffer update length %" PRIdsz
+                            " exceeds maximum update size %" PRIdsz,
+                            target_ref.length,
+                            IREE_HAL_COMMAND_BUFFER_MAX_UPDATE_SIZE);
+  }
+  if (IREE_UNLIKELY(flags != IREE_HAL_UPDATE_FLAG_NONE)) {
+    return iree_make_status(IREE_STATUS_INVALID_ARGUMENT,
+                            "unsupported update flags: 0x%" PRIx64, flags);
+  }
+
+  iree_hal_amdgpu_aql_command_buffer_t* command_buffer =
+      iree_hal_amdgpu_aql_command_buffer_cast(base_command_buffer);
+  iree_hal_amdgpu_command_buffer_binding_kind_t target_kind =
+      IREE_HAL_AMDGPU_COMMAND_BUFFER_BINDING_KIND_INVALID;
+  uint32_t target_ordinal = 0;
+  uint64_t target_offset = 0;
+  uint64_t target_length = 0;
+  IREE_RETURN_IF_ERROR(iree_hal_amdgpu_aql_command_buffer_record_buffer_ref(
+      command_buffer, target_ref, &target_kind, &target_ordinal, &target_offset,
+      &target_length));
+
+  uint64_t rodata_ordinal = 0;
+  const iree_host_size_t source_length = (iree_host_size_t)target_length;
+  IREE_RETURN_IF_ERROR(iree_hal_amdgpu_aql_command_buffer_record_rodata(
+      command_buffer, source_buffer, source_offset, source_length,
+      &rodata_ordinal));
+
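+  // The update source bytes are staged directly in the replay kernarg payload
+  // at a fixed offset, so the kernarg reservation must cover that offset plus
+  // the payload rounded up to whole kernarg blocks.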
+  const iree_host_size_t source_payload_offset =
+      IREE_HAL_AMDGPU_DEVICE_BUFFER_COPY_STAGED_SOURCE_OFFSET;
+  iree_host_size_t kernarg_length = 0;
+  if (IREE_UNLIKELY(!iree_host_size_checked_add(
+          source_payload_offset, source_length, &kernarg_length))) {
+    return iree_make_status(
+        IREE_STATUS_OUT_OF_RANGE,
+        "command-buffer update staging payload overflows host size "
+        "(offset=%" PRIhsz ", source_length=%" PRIhsz ")",
+        source_payload_offset, source_length);
+  }
+  const iree_host_size_t kernarg_block_count = iree_host_size_ceil_div(
+      kernarg_length, sizeof(iree_hal_amdgpu_kernarg_block_t));
+  iree_host_size_t kernarg_block_length = 0;
+  IREE_RETURN_IF_ERROR(IREE_STRUCT_LAYOUT(
+      0, &kernarg_block_length,
+      IREE_STRUCT_FIELD(kernarg_block_count, iree_hal_amdgpu_kernarg_block_t,
+                        NULL)));
+  if (IREE_UNLIKELY(kernarg_block_length > UINT32_MAX)) {
+    return iree_make_status(
+        IREE_STATUS_OUT_OF_RANGE,
+        "command-buffer update staging payload kernarg length %" PRIhsz
+        " exceeds the 32-bit command ABI limit",
+        kernarg_block_length);
+  }
+
+  iree_hal_amdgpu_command_buffer_command_header_t* header = NULL;
+  IREE_RETURN_IF_ERROR(iree_hal_amdgpu_aql_program_builder_append_command(
+      &command_buffer->builder, IREE_HAL_AMDGPU_COMMAND_BUFFER_OPCODE_UPDATE,
+      IREE_HAL_AMDGPU_COMMAND_BUFFER_COMMAND_FLAG_NONE,
+      sizeof(iree_hal_amdgpu_command_buffer_update_command_t),
+      /*binding_source_count=*/0, /*aql_packet_count=*/1,
+      (uint32_t)kernarg_block_length, &header,
+      /*out_binding_sources=*/NULL));
+
+  iree_hal_amdgpu_command_buffer_update_command_t* update_command =
+      (iree_hal_amdgpu_command_buffer_update_command_t*)header;
+  update_command->rodata_ordinal = rodata_ordinal;
+  update_command->target_offset = target_offset;
+  update_command->length = (uint32_t)target_length;
+  update_command->target_ordinal = target_ordinal;
+  update_command->target_kind = target_kind;
+  return iree_ok_status();
+}
+
+static iree_status_t iree_hal_amdgpu_aql_command_buffer_copy_buffer(
+    iree_hal_command_buffer_t* base_command_buffer,
+    iree_hal_buffer_ref_t source_ref, iree_hal_buffer_ref_t target_ref,
+    iree_hal_copy_flags_t flags) {
+  iree_hal_amdgpu_aql_command_buffer_t* command_buffer =
+      iree_hal_amdgpu_aql_command_buffer_cast(base_command_buffer);
+  if (IREE_UNLIKELY(flags != IREE_HAL_COPY_FLAG_NONE)) {
+    return iree_make_status(IREE_STATUS_INVALID_ARGUMENT,
+                            "unsupported copy flags: 0x%" PRIx64, flags);
+  }
+  iree_hal_amdgpu_command_buffer_binding_kind_t source_kind =
+      IREE_HAL_AMDGPU_COMMAND_BUFFER_BINDING_KIND_INVALID;
+  uint32_t source_ordinal = 0;
+  uint64_t source_offset = 0;
+  uint64_t source_length = 0;
+  IREE_RETURN_IF_ERROR(iree_hal_amdgpu_aql_command_buffer_record_buffer_ref(
+      command_buffer, source_ref, &source_kind, &source_ordinal, &source_offset,
+      &source_length));
+
+  iree_hal_amdgpu_command_buffer_binding_kind_t target_kind =
+      IREE_HAL_AMDGPU_COMMAND_BUFFER_BINDING_KIND_INVALID;
+  uint32_t target_ordinal = 0;
+  uint64_t target_offset = 0;
+  uint64_t target_length = 0;
+  IREE_RETURN_IF_ERROR(iree_hal_amdgpu_aql_command_buffer_record_buffer_ref(
+      command_buffer, target_ref, &target_kind, &target_ordinal, &target_offset,
+      &target_length));
+
+  if (IREE_UNLIKELY(source_length != target_length)) {
+    return iree_make_status(IREE_STATUS_INVALID_ARGUMENT,
+                            "copy source and target spans must have matching "
+                            "lengths (source_length=%" PRIu64
+                            ", target_length=%" PRIu64 ")",
+                            source_length, target_length);
+  }
+  iree_hal_amdgpu_command_buffer_command_header_t* header = NULL;
+  IREE_RETURN_IF_ERROR(iree_hal_amdgpu_aql_program_builder_append_command(
+      &command_buffer->builder, IREE_HAL_AMDGPU_COMMAND_BUFFER_OPCODE_COPY,
+      IREE_HAL_AMDGPU_COMMAND_BUFFER_COMMAND_FLAG_NONE,
+      sizeof(iree_hal_amdgpu_command_buffer_copy_command_t),
+      /*binding_source_count=*/0, /*aql_packet_count=*/1,
+      sizeof(iree_hal_amdgpu_kernarg_block_t), &header,
+      /*out_binding_sources=*/NULL));
+
+  iree_hal_amdgpu_command_buffer_copy_command_t* copy_command =
+      (iree_hal_amdgpu_command_buffer_copy_command_t*)header;
+  copy_command->length = source_length;
+  copy_command->source_offset = source_offset;
+  copy_command->target_offset = target_offset;
+  copy_command->source_ordinal = source_ordinal;
+  copy_command->target_ordinal = target_ordinal;
+  copy_command->source_kind = source_kind;
+  copy_command->target_kind = target_kind;
+  return iree_ok_status();
+}
+
+static iree_status_t iree_hal_amdgpu_aql_command_buffer_collective(
+    iree_hal_command_buffer_t* base_command_buffer, iree_hal_channel_t* channel,
+    iree_hal_collective_op_t op, uint32_t param, iree_hal_buffer_ref_t send_ref,
+    iree_hal_buffer_ref_t recv_ref, iree_device_size_t element_count) {
+  return iree_make_status(IREE_STATUS_UNIMPLEMENTED,
+                          "AMDGPU collectives not implemented");
+}
+
+static iree_status_t iree_hal_amdgpu_aql_command_buffer_dispatch(
+    iree_hal_command_buffer_t* base_command_buffer,
+    iree_hal_executable_t* executable,
+    iree_hal_executable_export_ordinal_t export_ordinal,
+    const iree_hal_dispatch_config_t config, iree_const_byte_span_t constants,
+    iree_hal_buffer_ref_list_t bindings, iree_hal_dispatch_flags_t flags) {
+  iree_hal_amdgpu_aql_command_buffer_t* command_buffer =
+      iree_hal_amdgpu_aql_command_buffer_cast(base_command_buffer);
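+  // Recording proceeds in fixed phases: plan the dispatch, retain the HAL
+  // resources it references, compute the packet/kernarg layout, append the
+  // command record, and finally write the payload and dispatch summary.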
+  const iree_hal_amdgpu_aql_dispatch_inputs_t inputs = {
+      .executable = executable,
+      .export_ordinal = export_ordinal,
+      .config = config,
+      .constants = constants,
+      .bindings = bindings,
+      .flags = flags,
+  };
+  const bool validates =
+      iree_hal_amdgpu_aql_command_buffer_validates(command_buffer);
+  iree_hal_amdgpu_aql_dispatch_plan_t plan;
+  IREE_RETURN_IF_ERROR(iree_hal_amdgpu_aql_command_buffer_prepare_dispatch_plan(
+      command_buffer, &inputs, validates, &plan));
+  IREE_RETURN_IF_ERROR(
+      iree_hal_amdgpu_aql_command_buffer_retain_dispatch_inputs(command_buffer,
+                                                                &inputs));
+
+  iree_hal_amdgpu_aql_dispatch_layout_t layout;
+  IREE_RETURN_IF_ERROR(
+      iree_hal_amdgpu_aql_command_buffer_calculate_dispatch_layout(
+          command_buffer, &inputs, &plan, &layout));
+
+  iree_hal_amdgpu_command_buffer_dispatch_command_t* dispatch_command = NULL;
+  iree_hal_amdgpu_command_buffer_binding_source_t* binding_sources = NULL;
+  IREE_RETURN_IF_ERROR(
+      iree_hal_amdgpu_aql_command_buffer_append_dispatch_command(
+          command_buffer, &inputs, &plan, &layout, &dispatch_command,
+          &binding_sources));
+  IREE_RETURN_IF_ERROR(
+      iree_hal_amdgpu_aql_command_buffer_write_dispatch_payload(
+          command_buffer, &inputs, &plan, &layout, dispatch_command,
+          binding_sources));
+  return iree_hal_amdgpu_aql_command_buffer_record_dispatch_summary(
+      command_buffer, dispatch_command, &layout);
+}
+
+//===----------------------------------------------------------------------===//
+// Vtable
+//===----------------------------------------------------------------------===//
+
+static const iree_hal_command_buffer_vtable_t
+    iree_hal_amdgpu_aql_command_buffer_vtable = {
+        .destroy = iree_hal_amdgpu_aql_command_buffer_destroy,
+        .begin = iree_hal_amdgpu_aql_command_buffer_begin,
+        .end = iree_hal_amdgpu_aql_command_buffer_end,
+        .begin_debug_group =
+            iree_hal_amdgpu_aql_command_buffer_begin_debug_group,
+        .end_debug_group = iree_hal_amdgpu_aql_command_buffer_end_debug_group,
+        .execution_barrier =
+            iree_hal_amdgpu_aql_command_buffer_execution_barrier,
+        .signal_event = iree_hal_amdgpu_aql_command_buffer_signal_event,
+        .reset_event = iree_hal_amdgpu_aql_command_buffer_reset_event,
+        .wait_events = iree_hal_amdgpu_aql_command_buffer_wait_events,
+        .advise_buffer = iree_hal_amdgpu_aql_command_buffer_advise_buffer,
+        .fill_buffer = iree_hal_amdgpu_aql_command_buffer_fill_buffer,
+        .update_buffer = iree_hal_amdgpu_aql_command_buffer_update_buffer,
+        .copy_buffer = iree_hal_amdgpu_aql_command_buffer_copy_buffer,
+        .collective = iree_hal_amdgpu_aql_command_buffer_collective,
+        .dispatch = iree_hal_amdgpu_aql_command_buffer_dispatch,
+};
diff --git a/runtime/src/iree/hal/drivers/amdgpu/aql_command_buffer.h b/runtime/src/iree/hal/drivers/amdgpu/aql_command_buffer.h
new file mode 100644
index 0000000..1cec5e8
--- /dev/null
+++ b/runtime/src/iree/hal/drivers/amdgpu/aql_command_buffer.h
@@ -0,0 +1,119 @@
+// Copyright 2026 The IREE Authors
+//
+// Licensed under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+
+#ifndef IREE_HAL_DRIVERS_AMDGPU_AQL_COMMAND_BUFFER_H_
+#define IREE_HAL_DRIVERS_AMDGPU_AQL_COMMAND_BUFFER_H_
+
+#include "iree/base/api.h"
+#include "iree/base/internal/arena.h"
+#include "iree/hal/api.h"
+#include "iree/hal/drivers/amdgpu/aql_prepublished_kernarg_storage.h"
+#include "iree/hal/drivers/amdgpu/aql_program_builder.h"
+#include "iree/hal/drivers/amdgpu/profile_metadata.h"
+
+#ifdef __cplusplus
+extern "C" {
+#endif  // __cplusplus
+
+// Creates a host-recorded AQL command buffer.
+//
+// The command buffer borrows |program_block_pool| for durable AQL program
+// storage and recording scratch memory. It borrows |resource_set_block_pool|
+// for retained HAL resource sets. Both pools must outlive all command buffers
+// created from them.
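+//
+// A minimal one-shot recording sketch (illustrative; error handling elided;
+// assumes the pools and allocators already exist):
+//   iree_hal_command_buffer_t* command_buffer = NULL;
+//   IREE_RETURN_IF_ERROR(iree_hal_amdgpu_aql_command_buffer_create(
+//       device_allocator, IREE_HAL_COMMAND_BUFFER_MODE_ONE_SHOT,
+//       IREE_HAL_COMMAND_CATEGORY_ANY, IREE_HAL_QUEUE_AFFINITY_ANY,
+//       /*binding_capacity=*/0, /*device_ordinal=*/0,
+//       iree_hal_amdgpu_aql_prepublished_kernarg_storage_disabled(),
+//       /*profile_metadata=*/NULL, program_block_pool,
+//       resource_set_block_pool, host_allocator, &command_buffer));
+//   IREE_RETURN_IF_ERROR(iree_hal_command_buffer_begin(command_buffer));
+//   // ... record fills/copies/dispatches ...
+//   IREE_RETURN_IF_ERROR(iree_hal_command_buffer_end(command_buffer));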
+iree_status_t iree_hal_amdgpu_aql_command_buffer_create(
+    iree_hal_allocator_t* device_allocator, iree_hal_command_buffer_mode_t mode,
+    iree_hal_command_category_t command_categories,
+    iree_hal_queue_affinity_t queue_affinity, iree_host_size_t binding_capacity,
+    iree_host_size_t device_ordinal,
+    iree_hal_amdgpu_aql_prepublished_kernarg_storage_t
+        prepublished_kernarg_storage,
+    iree_hal_amdgpu_profile_metadata_registry_t* profile_metadata,
+    iree_arena_block_pool_t* program_block_pool,
+    iree_arena_block_pool_t* resource_set_block_pool,
+    iree_allocator_t host_allocator,
+    iree_hal_command_buffer_t** out_command_buffer);
+
+// Returns true if |command_buffer| is an AMDGPU AQL command buffer.
+bool iree_hal_amdgpu_aql_command_buffer_isa(
+    iree_hal_command_buffer_t* command_buffer);
+
+// Retained command-buffer metadata for one dispatch operation.
+typedef struct iree_hal_amdgpu_aql_command_buffer_dispatch_summary_t {
+  // Next retained dispatch summary in the same command-buffer block.
+  const struct iree_hal_amdgpu_aql_command_buffer_dispatch_summary_t* next;
+  // Payload packet ordinals produced by command-buffer replay.
+  struct {
+    // First payload packet ordinal emitted for this command.
+    uint32_t first_ordinal;
+    // Payload packet ordinal of the dispatch whose completion signal represents
+    // this command.
+    uint32_t dispatch_ordinal;
+  } packets;
+  // Correlation metadata used by profiling, timestamps, and diagnostics.
+  struct {
+    // Session-local profile executable id used for event attribution.
+    uint64_t executable_id;
+    // Program-global command index used for profiling/source attribution.
+    uint32_t command_index;
+    // Executable export ordinal used for profiling and diagnostics.
+    uint32_t export_ordinal;
+    // Dispatch flags from iree_hal_amdgpu_command_buffer_dispatch_flag_bits_t.
+    uint8_t dispatch_flags;
+    // Reserved bytes that must be zero.
+    uint8_t reserved0[3];
+  } metadata;
+  // Dispatch launch dimensions used for event metadata.
+  struct {
+    // Static workgroup counts. Zero for indirect dispatches whose counts are
+    // read at device execution time.
+    uint32_t count[3];
+    // AQL dispatch packet workgroup size fields.
+    uint16_t size[3];
+  } workgroup;
+} iree_hal_amdgpu_aql_command_buffer_dispatch_summary_t;
+
+// Returns the immutable program produced by end().
+const iree_hal_amdgpu_aql_program_t* iree_hal_amdgpu_aql_command_buffer_program(
+    iree_hal_command_buffer_t* command_buffer);
+
+// Returns the physical device ordinal this command buffer was recorded for.
+iree_host_size_t iree_hal_amdgpu_aql_command_buffer_device_ordinal(
+    iree_hal_command_buffer_t* command_buffer);
+
+// Returns the producer-local profile command-buffer id, or 0 when recording
+// did not retain command-buffer profile metadata.
+uint64_t iree_hal_amdgpu_aql_command_buffer_profile_id(
+    iree_hal_command_buffer_t* command_buffer);
+
+// Returns retained dispatch summaries for |block|, or NULL when no
+// summaries were retained. |out_count| receives the number of linked summaries.
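+//
+// Illustrative traversal (summaries form a NULL-terminated singly linked
+// list):
+//   uint32_t count = 0;
+//   for (const iree_hal_amdgpu_aql_command_buffer_dispatch_summary_t*
+//            summary = iree_hal_amdgpu_aql_command_buffer_dispatch_summaries(
+//                command_buffer, block, &count);
+//        summary != NULL; summary = summary->next) {
+//     // inspect summary->metadata / summary->workgroup ...
+//   }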
+const iree_hal_amdgpu_aql_command_buffer_dispatch_summary_t*
+iree_hal_amdgpu_aql_command_buffer_dispatch_summaries(
+    iree_hal_command_buffer_t* command_buffer,
+    const iree_hal_amdgpu_command_buffer_block_header_t* block,
+    uint32_t* out_count);
+
+// Returns a direct buffer recorded in the command-buffer static binding table.
+iree_hal_buffer_t* iree_hal_amdgpu_aql_command_buffer_static_buffer(
+    iree_hal_command_buffer_t* command_buffer, uint32_t ordinal);
+
+// Returns command-buffer-owned rodata referenced by |command_buffer|.
+const uint8_t* iree_hal_amdgpu_aql_command_buffer_rodata(
+    iree_hal_command_buffer_t* command_buffer, uint64_t ordinal,
+    uint32_t length);
+
+// Returns command-buffer-owned device-visible prepublished kernargs at
+// |byte_offset| within the finalized prepublished kernarg storage.
+void* iree_hal_amdgpu_aql_command_buffer_prepublished_kernarg(
+    iree_hal_command_buffer_t* command_buffer, uint32_t byte_offset,
+    uint32_t length);
+
+#ifdef __cplusplus
+}  // extern "C"
+#endif  // __cplusplus
+
+#endif  // IREE_HAL_DRIVERS_AMDGPU_AQL_COMMAND_BUFFER_H_
diff --git a/runtime/src/iree/hal/drivers/amdgpu/aql_command_buffer_profile.c b/runtime/src/iree/hal/drivers/amdgpu/aql_command_buffer_profile.c
new file mode 100644
index 0000000..1ecf3fb
--- /dev/null
+++ b/runtime/src/iree/hal/drivers/amdgpu/aql_command_buffer_profile.c
@@ -0,0 +1,279 @@
+// Copyright 2026 The IREE Authors
+//
+// Licensed under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+
+#include "iree/hal/drivers/amdgpu/aql_command_buffer_profile.h"
+
+#include "iree/base/alignment.h"
+
+static iree_hal_profile_command_operation_flags_t
+iree_hal_amdgpu_aql_command_buffer_profile_binding_kind_flags(
+    iree_hal_amdgpu_command_buffer_binding_kind_t kind) {
+  switch (kind) {
+    case IREE_HAL_AMDGPU_COMMAND_BUFFER_BINDING_KIND_STATIC:
+      return IREE_HAL_PROFILE_COMMAND_OPERATION_FLAG_STATIC_BINDINGS;
+    case IREE_HAL_AMDGPU_COMMAND_BUFFER_BINDING_KIND_DYNAMIC:
+      return IREE_HAL_PROFILE_COMMAND_OPERATION_FLAG_DYNAMIC_BINDINGS;
+    default:
+      return IREE_HAL_PROFILE_COMMAND_OPERATION_FLAG_NONE;
+  }
+}
+
+static iree_hal_profile_command_operation_type_t
+iree_hal_amdgpu_aql_command_buffer_profile_operation_type(uint8_t opcode) {
+  switch (opcode) {
+    case IREE_HAL_AMDGPU_COMMAND_BUFFER_OPCODE_BARRIER:
+      return IREE_HAL_PROFILE_COMMAND_OPERATION_TYPE_BARRIER;
+    case IREE_HAL_AMDGPU_COMMAND_BUFFER_OPCODE_DISPATCH:
+      return IREE_HAL_PROFILE_COMMAND_OPERATION_TYPE_DISPATCH;
+    case IREE_HAL_AMDGPU_COMMAND_BUFFER_OPCODE_FILL:
+      return IREE_HAL_PROFILE_COMMAND_OPERATION_TYPE_FILL;
+    case IREE_HAL_AMDGPU_COMMAND_BUFFER_OPCODE_COPY:
+      return IREE_HAL_PROFILE_COMMAND_OPERATION_TYPE_COPY;
+    case IREE_HAL_AMDGPU_COMMAND_BUFFER_OPCODE_UPDATE:
+      return IREE_HAL_PROFILE_COMMAND_OPERATION_TYPE_UPDATE;
+    case IREE_HAL_AMDGPU_COMMAND_BUFFER_OPCODE_PROFILE_MARKER:
+      return IREE_HAL_PROFILE_COMMAND_OPERATION_TYPE_PROFILE_MARKER;
+    case IREE_HAL_AMDGPU_COMMAND_BUFFER_OPCODE_BRANCH:
+      return IREE_HAL_PROFILE_COMMAND_OPERATION_TYPE_BRANCH;
+    case IREE_HAL_AMDGPU_COMMAND_BUFFER_OPCODE_COND_BRANCH:
+      return IREE_HAL_PROFILE_COMMAND_OPERATION_TYPE_COND_BRANCH;
+    case IREE_HAL_AMDGPU_COMMAND_BUFFER_OPCODE_RETURN:
+      return IREE_HAL_PROFILE_COMMAND_OPERATION_TYPE_RETURN;
+    default:
+      return IREE_HAL_PROFILE_COMMAND_OPERATION_TYPE_NONE;
+  }
+}
+
+static iree_hal_profile_command_operation_flags_t
+iree_hal_amdgpu_aql_command_buffer_profile_dispatch_binding_flags(
+    const iree_hal_amdgpu_command_buffer_block_header_t* block,
+    const iree_hal_amdgpu_command_buffer_dispatch_command_t* dispatch_command) {
+  if (dispatch_command->binding_count == 0) {
+    return IREE_HAL_PROFILE_COMMAND_OPERATION_FLAG_NONE;
+  }
+
+  iree_hal_profile_command_operation_flags_t flags =
+      IREE_HAL_PROFILE_COMMAND_OPERATION_FLAG_NONE;
+  switch (dispatch_command->kernarg_strategy) {
+    case IREE_HAL_AMDGPU_COMMAND_BUFFER_KERNARG_STRATEGY_PREPUBLISHED:
+      return IREE_HAL_PROFILE_COMMAND_OPERATION_FLAG_STATIC_BINDINGS;
+    case IREE_HAL_AMDGPU_COMMAND_BUFFER_KERNARG_STRATEGY_PATCHED_TEMPLATE:
+      if (dispatch_command->payload.patch_source_count != 0) {
+        flags |= IREE_HAL_PROFILE_COMMAND_OPERATION_FLAG_DYNAMIC_BINDINGS;
+      }
+      if (dispatch_command->payload.patch_source_count <
+          dispatch_command->binding_count) {
+        flags |= IREE_HAL_PROFILE_COMMAND_OPERATION_FLAG_STATIC_BINDINGS;
+      }
+      return flags;
+    default:
+      break;
+  }
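+  // A zero binding source offset means no per-binding source records were
+  // retained for this dispatch; all bindings were resolved statically at
+  // record time.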
+  if (dispatch_command->binding_source_offset == 0) {
+    return IREE_HAL_PROFILE_COMMAND_OPERATION_FLAG_STATIC_BINDINGS;
+  }
+
+  const uint8_t* block_base = (const uint8_t*)block;
+  const uint32_t binding_source_offset =
+      dispatch_command->binding_source_offset;
+  const iree_hal_amdgpu_command_buffer_binding_source_t* binding_sources =
+      (const iree_hal_amdgpu_command_buffer_binding_source_t*)(block_base +
+                                                               binding_source_offset);
+  for (uint16_t binding_ordinal = 0;
+       binding_ordinal < dispatch_command->binding_count; ++binding_ordinal) {
+    flags |= iree_any_bit_set(
+                 binding_sources[binding_ordinal].flags,
+                 IREE_HAL_AMDGPU_COMMAND_BUFFER_BINDING_SOURCE_FLAG_DYNAMIC)
+                 ? IREE_HAL_PROFILE_COMMAND_OPERATION_FLAG_DYNAMIC_BINDINGS
+                 : IREE_HAL_PROFILE_COMMAND_OPERATION_FLAG_STATIC_BINDINGS;
+  }
+  return flags;
+}
+
+static void iree_hal_amdgpu_aql_command_buffer_initialize_profile_operation(
+    uint64_t command_buffer_id,
+    const iree_hal_amdgpu_command_buffer_block_header_t* block,
+    uint32_t block_command_ordinal,
+    const iree_hal_amdgpu_command_buffer_command_header_t* command,
+    iree_hal_profile_command_operation_record_t* out_record) {
+  iree_hal_profile_command_operation_record_t record =
+      iree_hal_profile_command_operation_record_default();
+  record.type = iree_hal_amdgpu_aql_command_buffer_profile_operation_type(
+      command->opcode);
+  record.command_buffer_id = command_buffer_id;
+  record.command_index = command->command_index;
+  record.flags |= IREE_HAL_PROFILE_COMMAND_OPERATION_FLAG_BLOCK_STRUCTURE;
+  record.block_ordinal = block->block_ordinal;
+  record.block_command_ordinal = block_command_ordinal;
+  if (iree_any_bit_set(
+          command->flags,
+          IREE_HAL_AMDGPU_COMMAND_BUFFER_COMMAND_FLAG_HAS_BARRIER)) {
+    record.flags |= IREE_HAL_PROFILE_COMMAND_OPERATION_FLAG_EXECUTION_BARRIER;
+  }
+
+  switch (command->opcode) {
+    case IREE_HAL_AMDGPU_COMMAND_BUFFER_OPCODE_BARRIER:
+      record.flags |= IREE_HAL_PROFILE_COMMAND_OPERATION_FLAG_EXECUTION_BARRIER;
+      break;
+    case IREE_HAL_AMDGPU_COMMAND_BUFFER_OPCODE_DISPATCH: {
+      const iree_hal_amdgpu_command_buffer_dispatch_command_t*
+          dispatch_command =
+              (const iree_hal_amdgpu_command_buffer_dispatch_command_t*)command;
+      record.flags |=
+          iree_hal_amdgpu_aql_command_buffer_profile_dispatch_binding_flags(
+              block, dispatch_command);
+      if (iree_any_bit_set(
+              dispatch_command->dispatch_flags,
+              IREE_HAL_AMDGPU_COMMAND_BUFFER_DISPATCH_FLAG_INDIRECT_PARAMETERS)) {
+        record.flags |=
+            IREE_HAL_PROFILE_COMMAND_OPERATION_FLAG_INDIRECT_PARAMETERS;
+      }
+      if (dispatch_command->kernarg_strategy ==
+          IREE_HAL_AMDGPU_COMMAND_BUFFER_KERNARG_STRATEGY_PREPUBLISHED) {
+        record.flags |=
+            IREE_HAL_PROFILE_COMMAND_OPERATION_FLAG_PREPUBLISHED_ARGUMENTS;
+      }
+      record.executable_id = dispatch_command->executable_id;
+      record.export_ordinal = dispatch_command->export_ordinal;
+      record.binding_count = dispatch_command->binding_count;
+      record.workgroup_size[0] = dispatch_command->workgroup_size[0];
+      record.workgroup_size[1] = dispatch_command->workgroup_size[1];
+      record.workgroup_size[2] = dispatch_command->workgroup_size[2];
+      if (!iree_any_bit_set(
+              dispatch_command->dispatch_flags,
+              IREE_HAL_AMDGPU_COMMAND_BUFFER_DISPATCH_FLAG_INDIRECT_PARAMETERS)) {
+        for (iree_host_size_t dimension_ordinal = 0;
+             dimension_ordinal < IREE_ARRAYSIZE(record.workgroup_count);
+             ++dimension_ordinal) {
+          record.workgroup_count[dimension_ordinal] =
+              dispatch_command->workgroup_size[dimension_ordinal] == 0
+                  ? 0
+                  : dispatch_command->grid_size[dimension_ordinal] /
+                        dispatch_command->workgroup_size[dimension_ordinal];
+        }
+      }
+      break;
+    }
+    case IREE_HAL_AMDGPU_COMMAND_BUFFER_OPCODE_FILL: {
+      const iree_hal_amdgpu_command_buffer_fill_command_t* fill_command =
+          (const iree_hal_amdgpu_command_buffer_fill_command_t*)command;
+      record.flags |=
+          iree_hal_amdgpu_aql_command_buffer_profile_binding_kind_flags(
+              fill_command->target_kind);
+      record.target_offset = fill_command->target_offset;
+      record.length = fill_command->length;
+      record.target_ordinal = fill_command->target_ordinal;
+      break;
+    }
+    case IREE_HAL_AMDGPU_COMMAND_BUFFER_OPCODE_COPY: {
+      const iree_hal_amdgpu_command_buffer_copy_command_t* copy_command =
+          (const iree_hal_amdgpu_command_buffer_copy_command_t*)command;
+      record.flags |=
+          iree_hal_amdgpu_aql_command_buffer_profile_binding_kind_flags(
+              copy_command->source_kind);
+      record.flags |=
+          iree_hal_amdgpu_aql_command_buffer_profile_binding_kind_flags(
+              copy_command->target_kind);
+      record.source_offset = copy_command->source_offset;
+      record.target_offset = copy_command->target_offset;
+      record.length = copy_command->length;
+      record.source_ordinal = copy_command->source_ordinal;
+      record.target_ordinal = copy_command->target_ordinal;
+      break;
+    }
+    case IREE_HAL_AMDGPU_COMMAND_BUFFER_OPCODE_UPDATE: {
+      const iree_hal_amdgpu_command_buffer_update_command_t* update_command =
+          (const iree_hal_amdgpu_command_buffer_update_command_t*)command;
+      record.flags |=
+          iree_hal_amdgpu_aql_command_buffer_profile_binding_kind_flags(
+              update_command->target_kind);
+      record.target_offset = update_command->target_offset;
+      record.length = update_command->length;
+      record.target_ordinal = update_command->target_ordinal;
+      break;
+    }
+    case IREE_HAL_AMDGPU_COMMAND_BUFFER_OPCODE_BRANCH: {
+      const iree_hal_amdgpu_command_buffer_branch_command_t* branch_command =
+          (const iree_hal_amdgpu_command_buffer_branch_command_t*)command;
+      record.flags |= IREE_HAL_PROFILE_COMMAND_OPERATION_FLAG_CONTROL_FLOW;
+      record.target_block_ordinal = branch_command->target_block_ordinal;
+      break;
+    }
+    case IREE_HAL_AMDGPU_COMMAND_BUFFER_OPCODE_COND_BRANCH: {
+      const iree_hal_amdgpu_command_buffer_cond_branch_command_t*
+          branch_command =
+              (const iree_hal_amdgpu_command_buffer_cond_branch_command_t*)
+                  command;
+      record.flags |= IREE_HAL_PROFILE_COMMAND_OPERATION_FLAG_CONTROL_FLOW;
+      record.target_block_ordinal = branch_command->true_block_ordinal;
+      record.alternate_block_ordinal = branch_command->false_block_ordinal;
+      break;
+    }
+    case IREE_HAL_AMDGPU_COMMAND_BUFFER_OPCODE_RETURN:
+      record.flags |= IREE_HAL_PROFILE_COMMAND_OPERATION_FLAG_CONTROL_FLOW;
+      break;
+    default:
+      break;
+  }
+  *out_record = record;
+}
+
+iree_status_t iree_hal_amdgpu_aql_command_buffer_register_profile_operations(
+    iree_hal_amdgpu_profile_metadata_registry_t* profile_metadata,
+    uint64_t command_buffer_id, const iree_hal_amdgpu_aql_program_t* program,
+    iree_allocator_t host_allocator) {
+  if (program->command_count == 0) {
+    return iree_ok_status();
+  }
+
+  iree_host_size_t byte_length = 0;
+  IREE_RETURN_IF_ERROR(IREE_STRUCT_LAYOUT(
+      0, &byte_length,
+      IREE_STRUCT_FIELD(program->command_count,
+                        iree_hal_profile_command_operation_record_t, NULL)));
+
+  IREE_TRACE_ZONE_BEGIN(z0);
+  IREE_TRACE_ZONE_APPEND_VALUE_I64(z0, program->command_count);
+
+  iree_hal_profile_command_operation_record_t* records = NULL;
+  iree_status_t status =
+      iree_allocator_malloc(host_allocator, byte_length, (void**)&records);
+
+  iree_host_size_t record_count = 0;
+  if (iree_status_is_ok(status)) {
+    for (iree_hal_amdgpu_command_buffer_block_header_t* block =
+             program->first_block;
+         block; block = iree_hal_amdgpu_aql_program_block_next(
+                    program->block_pool, block)) {
+      const iree_hal_amdgpu_command_buffer_command_header_t* command =
+          iree_hal_amdgpu_command_buffer_block_commands_const(block);
+      for (uint16_t command_ordinal = 0;
+           command_ordinal < block->command_count &&
+           record_count < program->command_count;
+           ++command_ordinal) {
+        iree_hal_amdgpu_aql_command_buffer_initialize_profile_operation(
+            command_buffer_id, block, command_ordinal, command,
+            &records[record_count++]);
+        command = iree_hal_amdgpu_command_buffer_command_next_const(command);
+      }
+    }
+  }
+  if (iree_status_is_ok(status) && record_count != program->command_count) {
+    status =
+        iree_make_status(IREE_STATUS_INTERNAL,
+                         "profile command-operation count mismatch: expected "
+                         "%u but got %" PRIhsz,
+                         program->command_count, record_count);
+  }
+  if (iree_status_is_ok(status)) {
+    status = iree_hal_amdgpu_profile_metadata_register_command_operations(
+        profile_metadata, record_count, records);
+  }
+
+  iree_allocator_free(host_allocator, records);
+  IREE_TRACE_ZONE_END(z0);
+  return status;
+}
diff --git a/runtime/src/iree/hal/drivers/amdgpu/aql_command_buffer_profile.h b/runtime/src/iree/hal/drivers/amdgpu/aql_command_buffer_profile.h
new file mode 100644
index 0000000..4210fbe
--- /dev/null
+++ b/runtime/src/iree/hal/drivers/amdgpu/aql_command_buffer_profile.h
@@ -0,0 +1,34 @@
+// Copyright 2026 The IREE Authors
+//
+// Licensed under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+
+#ifndef IREE_HAL_DRIVERS_AMDGPU_AQL_COMMAND_BUFFER_PROFILE_H_
+#define IREE_HAL_DRIVERS_AMDGPU_AQL_COMMAND_BUFFER_PROFILE_H_
+
+#include "iree/base/api.h"
+#include "iree/hal/drivers/amdgpu/aql_program_builder.h"
+#include "iree/hal/drivers/amdgpu/profile_metadata.h"
+
+#ifdef __cplusplus
+extern "C" {
+#endif  // __cplusplus
+
+// Synthesizes command-operation profile metadata for every command in a
+// finalized AQL command-buffer program and registers it with
+// |profile_metadata|.
+//
+// This is a cold command-buffer finalization path. It performs one temporary
+// host allocation sized to |program->command_count| and does not run during
+// queue submission or replay.
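+//
+// Illustrative call site (finalization time; names assumed in scope):
+//   IREE_RETURN_IF_ERROR(
+//       iree_hal_amdgpu_aql_command_buffer_register_profile_operations(
+//           profile_metadata, command_buffer_id, program, host_allocator));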
+iree_status_t iree_hal_amdgpu_aql_command_buffer_register_profile_operations(
+    iree_hal_amdgpu_profile_metadata_registry_t* profile_metadata,
+    uint64_t command_buffer_id, const iree_hal_amdgpu_aql_program_t* program,
+    iree_allocator_t host_allocator);
+
+#ifdef __cplusplus
+}  // extern "C"
+#endif  // __cplusplus
+
+#endif  // IREE_HAL_DRIVERS_AMDGPU_AQL_COMMAND_BUFFER_PROFILE_H_
diff --git a/runtime/src/iree/hal/drivers/amdgpu/aql_command_buffer_test.cc b/runtime/src/iree/hal/drivers/amdgpu/aql_command_buffer_test.cc
new file mode 100644
index 0000000..d423d74
--- /dev/null
+++ b/runtime/src/iree/hal/drivers/amdgpu/aql_command_buffer_test.cc
@@ -0,0 +1,312 @@
+// Copyright 2026 The IREE Authors
+//
+// Licensed under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+
+#include "iree/hal/drivers/amdgpu/aql_command_buffer.h"
+
+#include <array>
+#include <cstring>
+#include <memory>
+
+#include "iree/hal/drivers/amdgpu/abi/queue.h"
+#include "iree/testing/gtest.h"
+#include "iree/testing/status_matchers.h"
+
+namespace iree::hal::amdgpu {
+namespace {
+
+struct CommandBufferDeleter {
+  void operator()(iree_hal_command_buffer_t* command_buffer) const {
+    iree_hal_command_buffer_release(command_buffer);
+  }
+};
+
+using CommandBufferPtr =
+    std::unique_ptr<iree_hal_command_buffer_t, CommandBufferDeleter>;
+
+class AqlCommandBufferTest : public ::testing::Test {
+ protected:
+  void SetUp() override {
+    IREE_ASSERT_OK(iree_hal_allocator_create_heap(
+        iree_make_cstring_view("aql_command_buffer_test"),
+        iree_allocator_system(), iree_allocator_system(), &device_allocator_));
+    iree_hal_amdgpu_profile_metadata_initialize(iree_allocator_system(),
+                                                &profile_metadata_);
+    IREE_ASSERT_OK(iree_hal_amdgpu_aql_program_block_pool_initialize(
+        block_size_, iree_allocator_system(), &block_pool_));
+  }
+
+  void TearDown() override {
+    iree_arena_block_pool_deinitialize(&block_pool_);
+    iree_hal_amdgpu_profile_metadata_deinitialize(&profile_metadata_);
+    iree_hal_allocator_release(device_allocator_);
+  }
+
+  CommandBufferPtr CreateCommandBufferWithMode(
+      iree_hal_command_buffer_mode_t mode,
+      iree_host_size_t binding_capacity = 0) {
+    return CreateCommandBufferWithProfileMetadata(mode, &profile_metadata_,
+                                                  binding_capacity);
+  }
+
+  CommandBufferPtr CreateCommandBufferWithProfileMetadata(
+      iree_hal_command_buffer_mode_t mode,
+      iree_hal_amdgpu_profile_metadata_registry_t* profile_metadata,
+      iree_host_size_t binding_capacity = 0) {
+    iree_hal_command_buffer_t* command_buffer = nullptr;
+    IREE_EXPECT_OK(iree_hal_amdgpu_aql_command_buffer_create(
+        device_allocator_, mode, IREE_HAL_COMMAND_CATEGORY_ANY,
+        IREE_HAL_QUEUE_AFFINITY_ANY, binding_capacity, /*device_ordinal=*/0,
+        iree_hal_amdgpu_aql_prepublished_kernarg_storage_disabled(),
+        profile_metadata, &block_pool_, &block_pool_, iree_allocator_system(),
+        &command_buffer));
+    return CommandBufferPtr(command_buffer);
+  }
+
+  CommandBufferPtr CreateCommandBuffer(iree_host_size_t binding_capacity = 0) {
+    return CreateCommandBufferWithMode(IREE_HAL_COMMAND_BUFFER_MODE_ONE_SHOT,
+                                       binding_capacity);
+  }
+
+  const iree_hal_amdgpu_profile_metadata_registry_t& profile_metadata() const {
+    return profile_metadata_;
+  }
+
+ private:
+  // Test allocator borrowed by command buffers for validation.
+  iree_hal_allocator_t* device_allocator_ = nullptr;
+  // Fixed block size used by command-buffer tests.
+  iree_host_size_t block_size_ = 256;
+  // Program and resource-set block pool borrowed by test command buffers.
+  iree_arena_block_pool_t block_pool_;
+  // Profile metadata registry borrowed by test command buffers.
+  iree_hal_amdgpu_profile_metadata_registry_t profile_metadata_;
+};
+
+TEST_F(AqlCommandBufferTest, UnrecordedCommandBufferHasNoProgram) {
+  CommandBufferPtr command_buffer = CreateCommandBuffer();
+  ASSERT_NE(command_buffer, nullptr);
+
+  EXPECT_TRUE(iree_hal_amdgpu_aql_command_buffer_isa(command_buffer.get()));
+  const iree_hal_amdgpu_aql_program_t* program =
+      iree_hal_amdgpu_aql_command_buffer_program(command_buffer.get());
+  EXPECT_EQ(program->first_block, nullptr);
+}
+
+TEST_F(AqlCommandBufferTest, EmptyRecordingHasReturnTerminator) {
+  CommandBufferPtr command_buffer = CreateCommandBuffer();
+  ASSERT_NE(command_buffer, nullptr);
+
+  IREE_ASSERT_OK(iree_hal_command_buffer_begin(command_buffer.get()));
+  IREE_ASSERT_OK(iree_hal_command_buffer_end(command_buffer.get()));
+
+  const iree_hal_amdgpu_aql_program_t* program =
+      iree_hal_amdgpu_aql_command_buffer_program(command_buffer.get());
+  ASSERT_NE(program->first_block, nullptr);
+  EXPECT_EQ(program->block_count, 1u);
+  EXPECT_EQ(program->command_count, 1u);
+
+  const iree_hal_amdgpu_command_buffer_command_header_t* command =
+      iree_hal_amdgpu_command_buffer_block_commands_const(program->first_block);
+  EXPECT_EQ(command->opcode, IREE_HAL_AMDGPU_COMMAND_BUFFER_OPCODE_RETURN);
+}
+
+TEST_F(AqlCommandBufferTest, UnvalidatedCommandBufferCannotBeginTwice) {
+  CommandBufferPtr command_buffer =
+      CreateCommandBufferWithMode(IREE_HAL_COMMAND_BUFFER_MODE_ONE_SHOT |
+                                  IREE_HAL_COMMAND_BUFFER_MODE_UNVALIDATED);
+  ASSERT_NE(command_buffer, nullptr);
+
+  IREE_ASSERT_OK(iree_hal_command_buffer_begin(command_buffer.get()));
+  IREE_EXPECT_STATUS_IS(IREE_STATUS_FAILED_PRECONDITION,
+                        iree_hal_command_buffer_begin(command_buffer.get()));
+  IREE_ASSERT_OK(iree_hal_command_buffer_end(command_buffer.get()));
+}
+
+TEST_F(AqlCommandBufferTest, UnvalidatedCommandBufferCannotRerecord) {
+  CommandBufferPtr command_buffer =
+      CreateCommandBufferWithMode(IREE_HAL_COMMAND_BUFFER_MODE_ONE_SHOT |
+                                  IREE_HAL_COMMAND_BUFFER_MODE_UNVALIDATED);
+  ASSERT_NE(command_buffer, nullptr);
+
+  IREE_ASSERT_OK(iree_hal_command_buffer_begin(command_buffer.get()));
+  IREE_ASSERT_OK(iree_hal_command_buffer_end(command_buffer.get()));
+
+  const iree_hal_amdgpu_aql_program_t* program =
+      iree_hal_amdgpu_aql_command_buffer_program(command_buffer.get());
+  ASSERT_NE(program->first_block, nullptr);
+  const iree_hal_amdgpu_command_buffer_block_header_t* first_block =
+      program->first_block;
+  const uint32_t command_count = program->command_count;
+  const iree_host_size_t profile_operation_count =
+      profile_metadata().command_operation_record_count;
+  EXPECT_EQ(profile_operation_count, 0u);
+
+  IREE_EXPECT_STATUS_IS(IREE_STATUS_FAILED_PRECONDITION,
+                        iree_hal_command_buffer_begin(command_buffer.get()));
+
+  program = iree_hal_amdgpu_aql_command_buffer_program(command_buffer.get());
+  EXPECT_EQ(first_block, program->first_block);
+  EXPECT_EQ(command_count, program->command_count);
+  EXPECT_EQ(profile_operation_count,
+            profile_metadata().command_operation_record_count);
+}
+
+TEST_F(AqlCommandBufferTest, RetainedProfileMetadataRegistersOperations) {
+  CommandBufferPtr command_buffer = CreateCommandBufferWithMode(
+      IREE_HAL_COMMAND_BUFFER_MODE_ONE_SHOT |
+      IREE_HAL_COMMAND_BUFFER_MODE_RETAIN_PROFILE_METADATA);
+  ASSERT_NE(command_buffer, nullptr);
+
+  IREE_ASSERT_OK(iree_hal_command_buffer_begin(command_buffer.get()));
+  IREE_ASSERT_OK(iree_hal_command_buffer_end(command_buffer.get()));
+
+  const iree_hal_amdgpu_aql_program_t* program =
+      iree_hal_amdgpu_aql_command_buffer_program(command_buffer.get());
+  ASSERT_NE(program->first_block, nullptr);
+  EXPECT_EQ(program->command_count,
+            profile_metadata().command_operation_record_count);
+}
+
+TEST_F(AqlCommandBufferTest, RetainedDispatchMetadataDoesNotRequireProfile) {
+  CommandBufferPtr command_buffer = CreateCommandBufferWithProfileMetadata(
+      IREE_HAL_COMMAND_BUFFER_MODE_ONE_SHOT |
+          IREE_HAL_COMMAND_BUFFER_MODE_RETAIN_DISPATCH_METADATA,
+      /*profile_metadata=*/nullptr);
+  ASSERT_NE(command_buffer, nullptr);
+  EXPECT_EQ(iree_hal_amdgpu_aql_command_buffer_profile_id(command_buffer.get()),
+            0u);
+
+  IREE_ASSERT_OK(iree_hal_command_buffer_begin(command_buffer.get()));
+  IREE_ASSERT_OK(iree_hal_command_buffer_end(command_buffer.get()));
+  EXPECT_EQ(profile_metadata().command_buffer_record_count, 0u);
+  EXPECT_EQ(profile_metadata().command_operation_record_count, 0u);
+}
+
+TEST_F(AqlCommandBufferTest, BarrierOnlyRecordingHasBarrierAndReturn) {
+  CommandBufferPtr command_buffer = CreateCommandBuffer();
+  ASSERT_NE(command_buffer, nullptr);
+
+  IREE_ASSERT_OK(iree_hal_command_buffer_begin(command_buffer.get()));
+  IREE_ASSERT_OK(iree_hal_command_buffer_execution_barrier(
+      command_buffer.get(), IREE_HAL_EXECUTION_STAGE_COMMAND_ISSUE,
+      IREE_HAL_EXECUTION_STAGE_COMMAND_RETIRE,
+      IREE_HAL_EXECUTION_BARRIER_FLAG_NONE,
+      /*memory_barrier_count=*/0, /*memory_barriers=*/nullptr,
+      /*buffer_barrier_count=*/0, /*buffer_barriers=*/nullptr));
+  IREE_ASSERT_OK(iree_hal_command_buffer_end(command_buffer.get()));
+
+  const iree_hal_amdgpu_aql_program_t* program =
+      iree_hal_amdgpu_aql_command_buffer_program(command_buffer.get());
+  ASSERT_NE(program->first_block, nullptr);
+  EXPECT_EQ(program->block_count, 1u);
+  EXPECT_EQ(program->command_count, 2u);
+
+  const iree_hal_amdgpu_command_buffer_command_header_t* barrier_command =
+      iree_hal_amdgpu_command_buffer_block_commands_const(program->first_block);
+  EXPECT_EQ(barrier_command->opcode,
+            IREE_HAL_AMDGPU_COMMAND_BUFFER_OPCODE_BARRIER);
+  const auto* barrier =
+      reinterpret_cast<const iree_hal_amdgpu_command_buffer_barrier_command_t*>(
+          barrier_command);
+  EXPECT_EQ(barrier->acquire_scope, IREE_HSA_FENCE_SCOPE_NONE);
+  EXPECT_EQ(barrier->release_scope, IREE_HSA_FENCE_SCOPE_NONE);
+  const iree_hal_amdgpu_command_buffer_command_header_t* return_command =
+      iree_hal_amdgpu_command_buffer_command_next_const(barrier_command);
+  EXPECT_EQ(return_command->opcode,
+            IREE_HAL_AMDGPU_COMMAND_BUFFER_OPCODE_RETURN);
+}
+
+TEST_F(AqlCommandBufferTest, MemoryBarrierRecordingPreservesFenceScopes) {
+  CommandBufferPtr command_buffer = CreateCommandBuffer();
+  ASSERT_NE(command_buffer, nullptr);
+
+  const iree_hal_memory_barrier_t memory_barrier = {
+      .source_scope = IREE_HAL_ACCESS_SCOPE_DISPATCH_WRITE,
+      .target_scope = IREE_HAL_ACCESS_SCOPE_DISPATCH_READ,
+  };
+  IREE_ASSERT_OK(iree_hal_command_buffer_begin(command_buffer.get()));
+  IREE_ASSERT_OK(iree_hal_command_buffer_execution_barrier(
+      command_buffer.get(), IREE_HAL_EXECUTION_STAGE_DISPATCH,
+      IREE_HAL_EXECUTION_STAGE_DISPATCH, IREE_HAL_EXECUTION_BARRIER_FLAG_NONE,
+      /*memory_barrier_count=*/1, &memory_barrier,
+      /*buffer_barrier_count=*/0, /*buffer_barriers=*/nullptr));
+  IREE_ASSERT_OK(iree_hal_command_buffer_end(command_buffer.get()));
+
+  const iree_hal_amdgpu_aql_program_t* program =
+      iree_hal_amdgpu_aql_command_buffer_program(command_buffer.get());
+  ASSERT_NE(program->first_block, nullptr);
+  const iree_hal_amdgpu_command_buffer_command_header_t* barrier_command =
+      iree_hal_amdgpu_command_buffer_block_commands_const(program->first_block);
+  ASSERT_EQ(barrier_command->opcode,
+            IREE_HAL_AMDGPU_COMMAND_BUFFER_OPCODE_BARRIER);
+  const auto* barrier =
+      reinterpret_cast<const iree_hal_amdgpu_command_buffer_barrier_command_t*>(
+          barrier_command);
+  EXPECT_EQ(barrier->acquire_scope, IREE_HSA_FENCE_SCOPE_AGENT);
+  EXPECT_EQ(barrier->release_scope, IREE_HSA_FENCE_SCOPE_AGENT);
+}
+
+TEST_F(AqlCommandBufferTest, UpdatePayloadsUseStableRodataOrdinals) {
+  CommandBufferPtr command_buffer = CreateCommandBuffer(/*binding_capacity=*/1);
+  ASSERT_NE(command_buffer, nullptr);
+
+  std::array<uint8_t, 300> source_bytes0;
+  for (size_t i = 0; i < source_bytes0.size(); ++i) {
+    source_bytes0[i] = (uint8_t)i;
+  }
+  std::array<uint8_t, 19> source_bytes1;
+  for (size_t i = 0; i < source_bytes1.size(); ++i) {
+    source_bytes1[i] = (uint8_t)(0xE0u + i);
+  }
+
+  iree_hal_buffer_ref_t target_ref0 = {0};
+  target_ref0.buffer_slot = 0;
+  target_ref0.offset = 0;
+  target_ref0.length = source_bytes0.size();
+  iree_hal_buffer_ref_t target_ref1 = {0};
+  target_ref1.buffer_slot = 0;
+  target_ref1.offset = source_bytes0.size();
+  target_ref1.length = 7;
+
+  IREE_ASSERT_OK(iree_hal_command_buffer_begin(command_buffer.get()));
+  IREE_ASSERT_OK(iree_hal_command_buffer_update_buffer(
+      command_buffer.get(), source_bytes0.data(), /*source_offset=*/0,
+      target_ref0, IREE_HAL_UPDATE_FLAG_NONE));
+  IREE_ASSERT_OK(iree_hal_command_buffer_update_buffer(
+      command_buffer.get(), source_bytes1.data(), /*source_offset=*/5,
+      target_ref1, IREE_HAL_UPDATE_FLAG_NONE));
+  IREE_ASSERT_OK(iree_hal_command_buffer_end(command_buffer.get()));
+
+  const iree_hal_amdgpu_aql_program_t* program =
+      iree_hal_amdgpu_aql_command_buffer_program(command_buffer.get());
+  ASSERT_NE(program->first_block, nullptr);
+  ASSERT_EQ(program->command_count, 3u);
+
+  const iree_hal_amdgpu_command_buffer_command_header_t* command0 =
+      iree_hal_amdgpu_command_buffer_block_commands_const(program->first_block);
+  ASSERT_EQ(command0->opcode, IREE_HAL_AMDGPU_COMMAND_BUFFER_OPCODE_UPDATE);
+  const iree_hal_amdgpu_command_buffer_update_command_t* update0 =
+      (const iree_hal_amdgpu_command_buffer_update_command_t*)command0;
+  const uint8_t* rodata0 = iree_hal_amdgpu_aql_command_buffer_rodata(
+      command_buffer.get(), update0->rodata_ordinal, update0->length);
+  ASSERT_NE(rodata0, nullptr);
+  EXPECT_EQ(0,
+            std::memcmp(rodata0, source_bytes0.data(), source_bytes0.size()));
+
+  const iree_hal_amdgpu_command_buffer_command_header_t* command1 =
+      iree_hal_amdgpu_command_buffer_command_next_const(command0);
+  ASSERT_EQ(command1->opcode, IREE_HAL_AMDGPU_COMMAND_BUFFER_OPCODE_UPDATE);
+  const iree_hal_amdgpu_command_buffer_update_command_t* update1 =
+      (const iree_hal_amdgpu_command_buffer_update_command_t*)command1;
+  const uint8_t* rodata1 = iree_hal_amdgpu_aql_command_buffer_rodata(
+      command_buffer.get(), update1->rodata_ordinal, update1->length);
+  ASSERT_NE(rodata1, nullptr);
+  EXPECT_EQ(0, std::memcmp(rodata1, source_bytes1.data() + 5,
+                           (size_t)target_ref1.length));
+}
+
+}  // namespace
+}  // namespace iree::hal::amdgpu
diff --git a/runtime/src/iree/hal/drivers/amdgpu/aql_prepublished_kernarg_storage.h b/runtime/src/iree/hal/drivers/amdgpu/aql_prepublished_kernarg_storage.h
new file mode 100644
index 0000000..8cb3bfe
--- /dev/null
+++ b/runtime/src/iree/hal/drivers/amdgpu/aql_prepublished_kernarg_storage.h
@@ -0,0 +1,60 @@
+// Copyright 2026 The IREE Authors
+//
+// Licensed under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+
+#ifndef IREE_HAL_DRIVERS_AMDGPU_AQL_PREPUBLISHED_KERNARG_STORAGE_H_
+#define IREE_HAL_DRIVERS_AMDGPU_AQL_PREPUBLISHED_KERNARG_STORAGE_H_
+
+#include "iree/base/api.h"
+#include "iree/hal/api.h"
+
+#ifdef __cplusplus
+extern "C" {
+#endif  // __cplusplus
+
+// Strategy used to materialize reusable command-buffer kernarg templates.
+typedef enum iree_hal_amdgpu_aql_prepublished_kernarg_storage_strategy_e {
+  IREE_HAL_AMDGPU_AQL_PREPUBLISHED_KERNARG_STORAGE_STRATEGY_DISABLED = 0,
+  // Device-local fine-grained memory that is CPU-visible and host-coherent.
+  IREE_HAL_AMDGPU_AQL_PREPUBLISHED_KERNARG_STORAGE_STRATEGY_DEVICE_FINE_HOST_COHERENT =
+      1,
+} iree_hal_amdgpu_aql_prepublished_kernarg_storage_strategy_t;
+
+// Storage strategy for finalized reusable command-buffer kernarg templates.
+typedef struct iree_hal_amdgpu_aql_prepublished_kernarg_storage_t {
+  // Selected backing strategy.
+  iree_hal_amdgpu_aql_prepublished_kernarg_storage_strategy_t strategy;
+  // HAL allocation parameters used for materialized kernarg storage.
+  iree_hal_buffer_params_t buffer_params;
+} iree_hal_amdgpu_aql_prepublished_kernarg_storage_t;
+
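+// Returns storage parameters that disable prepublished kernarg templates.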
+static inline iree_hal_amdgpu_aql_prepublished_kernarg_storage_t
+iree_hal_amdgpu_aql_prepublished_kernarg_storage_disabled(void) {
+  iree_hal_amdgpu_aql_prepublished_kernarg_storage_t storage = {
+      IREE_HAL_AMDGPU_AQL_PREPUBLISHED_KERNARG_STORAGE_STRATEGY_DISABLED};
+  return storage;
+}
+
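+// Returns storage parameters selecting device-local memory that is also
+// host-visible and host-coherent so finalized kernarg templates can be
+// patched from the CPU without explicit cache maintenance.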
+static inline iree_hal_amdgpu_aql_prepublished_kernarg_storage_t
+iree_hal_amdgpu_aql_prepublished_kernarg_storage_device_fine_host_coherent(
+    void) {
+  iree_hal_amdgpu_aql_prepublished_kernarg_storage_t storage =
+      iree_hal_amdgpu_aql_prepublished_kernarg_storage_disabled();
+  storage.strategy =
+      IREE_HAL_AMDGPU_AQL_PREPUBLISHED_KERNARG_STORAGE_STRATEGY_DEVICE_FINE_HOST_COHERENT;
+  storage.buffer_params.type = IREE_HAL_MEMORY_TYPE_DEVICE_LOCAL |
+                               IREE_HAL_MEMORY_TYPE_HOST_VISIBLE |
+                               IREE_HAL_MEMORY_TYPE_HOST_COHERENT;
+  storage.buffer_params.access = IREE_HAL_MEMORY_ACCESS_ALL;
+  storage.buffer_params.usage = IREE_HAL_BUFFER_USAGE_DISPATCH_UNIFORM_READ |
+                                IREE_HAL_BUFFER_USAGE_MAPPING;
+  return storage;
+}
+
+#ifdef __cplusplus
+}  // extern "C"
+#endif  // __cplusplus
+
+#endif  // IREE_HAL_DRIVERS_AMDGPU_AQL_PREPUBLISHED_KERNARG_STORAGE_H_
diff --git a/runtime/src/iree/hal/drivers/amdgpu/aql_program_builder.c b/runtime/src/iree/hal/drivers/amdgpu/aql_program_builder.c
new file mode 100644
index 0000000..2239167
--- /dev/null
+++ b/runtime/src/iree/hal/drivers/amdgpu/aql_program_builder.c
@@ -0,0 +1,647 @@
+// Copyright 2026 The IREE Authors
+//
+// Licensed under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+
+#include "iree/hal/drivers/amdgpu/aql_program_builder.h"
+
+#include <string.h>
+
+#include "iree/base/alignment.h"
+#include "iree/hal/drivers/amdgpu/abi/queue.h"
+
+//===----------------------------------------------------------------------===//
+// Block Pool Utilities
+//===----------------------------------------------------------------------===//
+
+iree_status_t iree_hal_amdgpu_aql_program_block_pool_initialize(
+    iree_host_size_t block_size, iree_allocator_t host_allocator,
+    iree_arena_block_pool_t* out_block_pool) {
+  IREE_TRACE_ZONE_BEGIN(z0);
+  IREE_TRACE_ZONE_APPEND_VALUE_I64(z0, block_size);
+  if (IREE_UNLIKELY(!iree_host_size_is_power_of_two(block_size) ||
+                    block_size < IREE_HAL_AMDGPU_AQL_PROGRAM_MIN_BLOCK_SIZE)) {
+    IREE_RETURN_AND_END_ZONE_IF_ERROR(
+        z0, iree_make_status(IREE_STATUS_INVALID_ARGUMENT,
+                             "command-buffer block size must be a "
+                             "power-of-two >= %u bytes",
+                             IREE_HAL_AMDGPU_AQL_PROGRAM_MIN_BLOCK_SIZE));
+  }
+  if (IREE_UNLIKELY(block_size > UINT32_MAX)) {
+    IREE_RETURN_AND_END_ZONE_IF_ERROR(
+        z0, iree_make_status(
+                IREE_STATUS_INVALID_ARGUMENT,
+                "command-buffer block size must fit in the block ABI"));
+  }
+
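+  // Pad the requested size with the arena block trailer so each pooled block
+  // still provides |block_size| usable bytes of program storage.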
+  iree_host_size_t total_block_size = 0;
+  if (IREE_UNLIKELY(!iree_host_size_checked_add(
+          block_size, sizeof(iree_arena_block_t), &total_block_size))) {
+    IREE_RETURN_AND_END_ZONE_IF_ERROR(
+        z0, iree_make_status(IREE_STATUS_RESOURCE_EXHAUSTED,
+                             "command-buffer block size overflow"));
+  }
+
+  iree_arena_block_pool_initialize(total_block_size, host_allocator,
+                                   out_block_pool);
+  IREE_TRACE_ZONE_END(z0);
+  return iree_ok_status();
+}
+
+//===----------------------------------------------------------------------===//
+// Arena Block Helpers
+//===----------------------------------------------------------------------===//
+
+static iree_arena_block_t* iree_hal_amdgpu_aql_program_arena_block(
+    iree_arena_block_pool_t* block_pool,
+    const iree_hal_amdgpu_command_buffer_block_header_t* block) {
+  return iree_arena_block_trailer(block_pool, (void*)block);
+}
+
+static iree_hal_amdgpu_command_buffer_block_header_t*
+iree_hal_amdgpu_aql_program_block_from_arena(
+    iree_arena_block_pool_t* block_pool, iree_arena_block_t* arena_block) {
+  return arena_block ? (iree_hal_amdgpu_command_buffer_block_header_t*)
+                           iree_arena_block_ptr(block_pool, arena_block)
+                     : NULL;
+}
+
+iree_hal_amdgpu_command_buffer_block_header_t*
+iree_hal_amdgpu_aql_program_block_next(
+    iree_arena_block_pool_t* block_pool,
+    const iree_hal_amdgpu_command_buffer_block_header_t* block) {
+  iree_arena_block_t* arena_block =
+      iree_hal_amdgpu_aql_program_arena_block(block_pool, block);
+  return iree_hal_amdgpu_aql_program_block_from_arena(block_pool,
+                                                      arena_block->next);
+}
+
+//===----------------------------------------------------------------------===//
+// Recording Output
+//===----------------------------------------------------------------------===//
+
+void iree_hal_amdgpu_aql_program_release(
+    iree_hal_amdgpu_aql_program_t* program) {
+  if (!program->first_block) {
+    return;
+  }
+  IREE_TRACE_ZONE_BEGIN(z0);
+
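+  // Walk to the tail block so the entire chain can be returned to the pool in
+  // a single release.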
+  iree_arena_block_t* first_arena_block =
+      iree_hal_amdgpu_aql_program_arena_block(program->block_pool,
+                                              program->first_block);
+  iree_arena_block_t* last_arena_block = first_arena_block;
+  while (last_arena_block->next) {
+    last_arena_block = last_arena_block->next;
+  }
+
+  iree_arena_block_pool_release(program->block_pool, first_arena_block,
+                                last_arena_block);
+  memset(program, 0, sizeof(*program));
+  IREE_TRACE_ZONE_END(z0);
+}
+
+//===----------------------------------------------------------------------===//
+// Lifecycle
+//===----------------------------------------------------------------------===//
+
+void iree_hal_amdgpu_aql_program_builder_initialize(
+    iree_arena_block_pool_t* block_pool,
+    iree_hal_amdgpu_aql_program_builder_t* out_builder) {
+  memset(out_builder, 0, sizeof(*out_builder));
+  out_builder->block_pool = block_pool;
+}
+
+void iree_hal_amdgpu_aql_program_builder_deinitialize(
+    iree_hal_amdgpu_aql_program_builder_t* builder) {
+  if (!builder->block_pool) {
+    return;
+  }
+  IREE_TRACE_ZONE_BEGIN(z0);
+
+  if (builder->current_block.header) {
+    iree_arena_block_t* arena_block = iree_hal_amdgpu_aql_program_arena_block(
+        builder->block_pool, builder->current_block.header);
+    arena_block->next = NULL;
+    iree_arena_block_pool_release(builder->block_pool, arena_block,
+                                  arena_block);
+    builder->current_block.header = NULL;
+  }
+
+  if (builder->first_block) {
+    iree_hal_amdgpu_aql_program_t program = {
+        .block_pool = builder->block_pool,
+        .first_block = builder->first_block,
+    };
+    iree_hal_amdgpu_aql_program_release(&program);
+  }
+
+  memset(builder, 0, sizeof(*builder));
+  IREE_TRACE_ZONE_END(z0);
+}
+
+//===----------------------------------------------------------------------===//
+// Block Management
+//===----------------------------------------------------------------------===//
+
+static iree_status_t iree_hal_amdgpu_aql_program_builder_begin_block(
+    iree_hal_amdgpu_aql_program_builder_t* builder) {
+  if (IREE_UNLIKELY(builder->block_count == UINT32_MAX)) {
+    return iree_make_status(IREE_STATUS_RESOURCE_EXHAUSTED,
+                            "command-buffer block count overflow");
+  }
+
+  iree_arena_block_t* arena_block = NULL;
+  void* block_data = NULL;
+  IREE_TRACE_ZONE_BEGIN(z0);
+  IREE_RETURN_AND_END_ZONE_IF_ERROR(
+      z0, iree_arena_block_pool_acquire(builder->block_pool, &arena_block,
+                                        &block_data));
+
+  arena_block->next = NULL;
+  iree_hal_amdgpu_command_buffer_block_header_t* block =
+      (iree_hal_amdgpu_command_buffer_block_header_t*)block_data;
+  memset(block, 0, sizeof(*block));
+  block->magic = IREE_HAL_AMDGPU_COMMAND_BUFFER_BLOCK_MAGIC;
+  block->version = IREE_HAL_AMDGPU_COMMAND_BUFFER_BLOCK_VERSION_0;
+  block->header_length = sizeof(*block);
+  block->block_ordinal = builder->block_count;
+  block->block_length = (uint32_t)builder->block_pool->usable_block_size;
+  block->command_offset = sizeof(*block);
+  block->rodata_offset = block->block_length;
+
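+  // Commands grow upward from the block header while binding sources and
+  // rodata are packed downward from the block tail; recording fails when the
+  // two cursors would cross.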
+  builder->current_block.header = block;
+  builder->current_block.command_cursor = (uint8_t*)block + sizeof(*block);
+  builder->current_block.binding_source_cursor =
+      (uint8_t*)block + builder->block_pool->usable_block_size;
+  builder->current_block.command_count = 0;
+  builder->current_block.binding_source_count = 0;
+  builder->current_block.dispatch_count = 0;
+  builder->current_block.indirect_dispatch_count = 0;
+  builder->current_block.profile_marker_count = 0;
+  builder->current_block.aql_packet_count = 0;
+  builder->current_block.kernarg_length = 0;
+  builder->current_block.initial_barrier_packet_count = 0;
+  builder->current_block.pending_barrier_acquire_scope =
+      IREE_HSA_FENCE_SCOPE_NONE;
+  builder->current_block.flags = IREE_HAL_AMDGPU_AQL_PROGRAM_BUILDER_FLAG_NONE;
+  ++builder->block_count;
+  IREE_TRACE_ZONE_END(z0);
+  return iree_ok_status();
+}
+
+static void iree_hal_amdgpu_aql_program_builder_finalize_block(
+    iree_hal_amdgpu_aql_program_builder_t* builder) {
+  iree_hal_amdgpu_command_buffer_block_header_t* block =
+      builder->current_block.header;
+  block->command_length = (uint32_t)(builder->current_block.command_cursor -
+                                     ((uint8_t*)block + block->command_offset));
+  block->binding_source_offset =
+      (uint32_t)(builder->current_block.binding_source_cursor -
+                 (uint8_t*)block);
+  block->command_count = builder->current_block.command_count;
+  block->binding_source_count = builder->current_block.binding_source_count;
+  block->aql_packet_count = builder->current_block.aql_packet_count;
+  block->kernarg_length = builder->current_block.kernarg_length;
+  block->dispatch_count = builder->current_block.dispatch_count;
+  block->indirect_dispatch_count =
+      builder->current_block.indirect_dispatch_count;
+  block->profile_marker_count = builder->current_block.profile_marker_count;
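+  // If no barrier-carrying command was recorded the entire block forms one
+  // unordered span, so the initial span covers every packet in the block.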
+  block->initial_barrier_packet_count =
+      iree_any_bit_set(
+          builder->current_block.flags,
+          IREE_HAL_AMDGPU_AQL_PROGRAM_BUILDER_FLAG_HAS_INITIAL_BARRIER_PACKET)
+          ? builder->current_block.initial_barrier_packet_count
+          : block->aql_packet_count;
+
+  if (block->aql_packet_count > builder->max_block_aql_packet_count) {
+    builder->max_block_aql_packet_count = block->aql_packet_count;
+  }
+  if (block->kernarg_length > builder->max_block_kernarg_length) {
+    builder->max_block_kernarg_length = block->kernarg_length;
+  }
+
+  iree_arena_block_t* arena_block =
+      iree_hal_amdgpu_aql_program_arena_block(builder->block_pool, block);
+  if (builder->last_block) {
+    iree_arena_block_t* last_arena_block =
+        iree_hal_amdgpu_aql_program_arena_block(builder->block_pool,
+                                                builder->last_block);
+    last_arena_block->next = arena_block;
+  } else {
+    builder->first_block = block;
+  }
+  builder->last_block = block;
+
+  builder->current_block.header = NULL;
+  builder->current_block.command_cursor = NULL;
+  builder->current_block.binding_source_cursor = NULL;
+}
+
+static iree_host_size_t iree_hal_amdgpu_aql_program_builder_remaining(
+    const iree_hal_amdgpu_aql_program_builder_t* builder) {
+  return (iree_host_size_t)(builder->current_block.binding_source_cursor -
+                            builder->current_block.command_cursor);
+}
+
+static iree_status_t iree_hal_amdgpu_aql_program_builder_append_terminator(
+    iree_hal_amdgpu_aql_program_builder_t* builder, uint8_t opcode,
+    uint32_t target_block_ordinal) {
+  if (IREE_UNLIKELY(builder->current_block.command_count == UINT16_MAX ||
+                    builder->command_count == UINT32_MAX)) {
+    return iree_make_status(IREE_STATUS_RESOURCE_EXHAUSTED,
+                            "command-buffer command count overflow");
+  }
+
+  iree_host_size_t command_length = 0;
+  if (opcode == IREE_HAL_AMDGPU_COMMAND_BUFFER_OPCODE_BRANCH) {
+    command_length = sizeof(iree_hal_amdgpu_command_buffer_branch_command_t);
+  } else if (opcode == IREE_HAL_AMDGPU_COMMAND_BUFFER_OPCODE_RETURN) {
+    command_length = sizeof(iree_hal_amdgpu_command_buffer_return_command_t);
+  } else {
+    return iree_make_status(IREE_STATUS_INVALID_ARGUMENT,
+                            "unsupported command-buffer terminator opcode %u",
+                            opcode);
+  }
+
+  if (IREE_UNLIKELY(iree_hal_amdgpu_aql_program_builder_remaining(builder) <
+                    command_length)) {
+    return iree_make_status(IREE_STATUS_RESOURCE_EXHAUSTED,
+                            "command-buffer block has no space remaining "
+                            "for a terminator");
+  }
+
+  iree_hal_amdgpu_command_buffer_command_header_t* header =
+      (iree_hal_amdgpu_command_buffer_command_header_t*)
+          builder->current_block.command_cursor;
+  memset(header, 0, command_length);
+  header->opcode = opcode;
+  header->length_qwords =
+      (uint16_t)(command_length /
+                 IREE_HAL_AMDGPU_COMMAND_BUFFER_RECORD_ALIGNMENT);
+  header->command_index = builder->command_count;
+
+  if (opcode == IREE_HAL_AMDGPU_COMMAND_BUFFER_OPCODE_BRANCH) {
+    iree_hal_amdgpu_command_buffer_branch_command_t* branch_command =
+        (iree_hal_amdgpu_command_buffer_branch_command_t*)header;
+    branch_command->target_block_ordinal = target_block_ordinal;
+  }
+  builder->current_block.header->terminator_opcode = opcode;
+  builder->current_block.header->terminator_target_block_ordinal =
+      opcode == IREE_HAL_AMDGPU_COMMAND_BUFFER_OPCODE_BRANCH
+          ? target_block_ordinal
+          : 0;
+
+  builder->current_block.command_cursor += command_length;
+  ++builder->command_count;
+  ++builder->current_block.command_count;
+  return iree_ok_status();
+}
+
+static iree_host_size_t iree_hal_amdgpu_aql_program_terminator_reserve(void) {
+  return sizeof(iree_hal_amdgpu_command_buffer_branch_command_t);
+}
+
+static iree_status_t iree_hal_amdgpu_aql_program_builder_split_block(
+    iree_hal_amdgpu_aql_program_builder_t* builder) {
+  if (IREE_UNLIKELY(builder->block_count == UINT32_MAX)) {
+    return iree_make_status(IREE_STATUS_RESOURCE_EXHAUSTED,
+                            "command-buffer block count overflow");
+  }
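+  // Carry any pending execution barrier into the new block so it still
+  // applies to the next payload command recorded after the split.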
+  const iree_hal_amdgpu_aql_program_builder_flags_t carried_flags =
+      builder->current_block.flags &
+      IREE_HAL_AMDGPU_AQL_PROGRAM_BUILDER_FLAG_HAS_PENDING_EXECUTION_BARRIER;
+  const uint8_t carried_acquire_scope =
+      carried_flags ? builder->current_block.pending_barrier_acquire_scope
+                    : IREE_HSA_FENCE_SCOPE_NONE;
+  IREE_RETURN_IF_ERROR(iree_hal_amdgpu_aql_program_builder_append_terminator(
+      builder, IREE_HAL_AMDGPU_COMMAND_BUFFER_OPCODE_BRANCH,
+      builder->block_count));
+  iree_hal_amdgpu_aql_program_builder_finalize_block(builder);
+  IREE_RETURN_IF_ERROR(
+      iree_hal_amdgpu_aql_program_builder_begin_block(builder));
+  builder->current_block.flags |= carried_flags;
+  builder->current_block.pending_barrier_acquire_scope = carried_acquire_scope;
+  return iree_ok_status();
+}
+
+static iree_status_t iree_hal_amdgpu_aql_program_builder_validate_command(
+    const iree_hal_amdgpu_aql_program_builder_t* builder, uint8_t opcode,
+    iree_host_size_t command_length, uint16_t binding_source_count,
+    iree_host_size_t* out_required_length,
+    iree_host_size_t* out_binding_source_length) {
+  if (IREE_UNLIKELY(!builder->current_block.header)) {
+    return iree_make_status(IREE_STATUS_FAILED_PRECONDITION,
+                            "command-buffer builder is not recording");
+  }
+  if (IREE_UNLIKELY(opcode == IREE_HAL_AMDGPU_COMMAND_BUFFER_OPCODE_INVALID ||
+                    opcode == IREE_HAL_AMDGPU_COMMAND_BUFFER_OPCODE_BRANCH ||
+                    opcode ==
+                        IREE_HAL_AMDGPU_COMMAND_BUFFER_OPCODE_COND_BRANCH ||
+                    opcode == IREE_HAL_AMDGPU_COMMAND_BUFFER_OPCODE_RETURN)) {
+    return iree_make_status(IREE_STATUS_INVALID_ARGUMENT,
+                            "opcode %u cannot be appended as a work command",
+                            opcode);
+  }
+  if (IREE_UNLIKELY(
+          command_length <
+              sizeof(iree_hal_amdgpu_command_buffer_command_header_t) ||
+          !iree_host_size_has_alignment(
+              command_length,
+              IREE_HAL_AMDGPU_COMMAND_BUFFER_RECORD_ALIGNMENT) ||
+          command_length / IREE_HAL_AMDGPU_COMMAND_BUFFER_RECORD_ALIGNMENT >
+              UINT16_MAX)) {
+    return iree_make_status(IREE_STATUS_INVALID_ARGUMENT,
+                            "command length must be qword-aligned and fit in "
+                            "uint16 qword units");
+  }
+
+  iree_host_size_t binding_source_length = 0;
+  IREE_RETURN_IF_ERROR(IREE_STRUCT_LAYOUT(
+      0, &binding_source_length,
+      IREE_STRUCT_FIELD(binding_source_count,
+                        iree_hal_amdgpu_command_buffer_binding_source_t,
+                        NULL)));
+
+  iree_host_size_t required_length = 0;
+  if (IREE_UNLIKELY(!iree_host_size_checked_add(
+          command_length, binding_source_length, &required_length))) {
+    return iree_make_status(IREE_STATUS_RESOURCE_EXHAUSTED,
+                            "command record size overflow");
+  }
+  *out_required_length = required_length;
+  *out_binding_source_length = binding_source_length;
+  return iree_ok_status();
+}
+
+static bool iree_hal_amdgpu_aql_program_command_fits_empty_block(
+    const iree_hal_amdgpu_aql_program_builder_t* builder,
+    iree_host_size_t required_length) {
+  const iree_host_size_t empty_available =
+      builder->block_pool->usable_block_size -
+      sizeof(iree_hal_amdgpu_command_buffer_block_header_t);
+  iree_host_size_t required_with_terminator = 0;
+  if (!iree_host_size_checked_add(
+          required_length, iree_hal_amdgpu_aql_program_terminator_reserve(),
+          &required_with_terminator)) {
+    return false;
+  }
+  return required_with_terminator <= empty_available;
+}
+
+static bool iree_hal_amdgpu_aql_program_command_fits_current_block(
+    const iree_hal_amdgpu_aql_program_builder_t* builder,
+    uint16_t binding_source_count, uint32_t aql_packet_count,
+    uint32_t kernarg_length) {
+  if (builder->current_block.command_count > UINT16_MAX - 2) return false;
+  if (binding_source_count >
+      UINT16_MAX - builder->current_block.binding_source_count) {
+    return false;
+  }
+  if (aql_packet_count > UINT32_MAX - builder->current_block.aql_packet_count) {
+    return false;
+  }
+  if (kernarg_length > UINT32_MAX - builder->current_block.kernarg_length) {
+    return false;
+  }
+  return true;
+}
+
+//===----------------------------------------------------------------------===//
+// Recording
+//===----------------------------------------------------------------------===//
+
+iree_status_t iree_hal_amdgpu_aql_program_builder_begin(
+    iree_hal_amdgpu_aql_program_builder_t* builder) {
+  if (IREE_UNLIKELY(!builder->block_pool)) {
+    return iree_make_status(IREE_STATUS_INVALID_ARGUMENT,
+                            "command-buffer builder requires a block pool");
+  }
+  if (IREE_UNLIKELY(!iree_host_size_is_power_of_two(
+                        builder->block_pool->usable_block_size) ||
+                    builder->block_pool->usable_block_size <
+                        IREE_HAL_AMDGPU_AQL_PROGRAM_MIN_BLOCK_SIZE)) {
+    return iree_make_status(IREE_STATUS_INVALID_ARGUMENT,
+                            "command-buffer block pool must have power-of-two "
+                            "usable blocks >= %u bytes",
+                            IREE_HAL_AMDGPU_AQL_PROGRAM_MIN_BLOCK_SIZE);
+  }
+  if (IREE_UNLIKELY(builder->block_pool->usable_block_size > UINT32_MAX)) {
+    return iree_make_status(
+        IREE_STATUS_INVALID_ARGUMENT,
+        "command-buffer block pool usable size must fit in the block ABI");
+  }
+  if (IREE_UNLIKELY(builder->current_block.header || builder->first_block)) {
+    return iree_make_status(IREE_STATUS_FAILED_PRECONDITION,
+                            "command-buffer builder already has a recording");
+  }
+  builder->last_payload_command = NULL;
+  return iree_hal_amdgpu_aql_program_builder_begin_block(builder);
+}
+
+iree_status_t iree_hal_amdgpu_aql_program_builder_end(
+    iree_hal_amdgpu_aql_program_builder_t* builder,
+    iree_hal_amdgpu_aql_program_t* out_program) {
+  if (IREE_UNLIKELY(!out_program)) {
+    return iree_make_status(IREE_STATUS_INVALID_ARGUMENT,
+                            "command-buffer program output is required");
+  }
+  memset(out_program, 0, sizeof(*out_program));
+
+  if (IREE_UNLIKELY(!builder->current_block.header)) {
+    return iree_make_status(IREE_STATUS_FAILED_PRECONDITION,
+                            "command-buffer builder is not recording");
+  }
+
+  iree_status_t status = iree_hal_amdgpu_aql_program_builder_append_terminator(
+      builder, IREE_HAL_AMDGPU_COMMAND_BUFFER_OPCODE_RETURN,
+      /*target_block_ordinal=*/0);
+  if (iree_status_is_ok(status)) {
+    iree_hal_amdgpu_aql_program_builder_finalize_block(builder);
+    *out_program = (iree_hal_amdgpu_aql_program_t){
+        .block_pool = builder->block_pool,
+        .first_block = builder->first_block,
+        .block_count = builder->block_count,
+        .command_count = builder->command_count,
+        .max_block_aql_packet_count = builder->max_block_aql_packet_count,
+        .max_block_kernarg_length = builder->max_block_kernarg_length,
+    };
+    builder->first_block = NULL;
+    builder->last_block = NULL;
+    builder->last_payload_command = NULL;
+    builder->block_count = 0;
+    builder->command_count = 0;
+    builder->max_block_aql_packet_count = 0;
+    builder->max_block_kernarg_length = 0;
+  }
+  return status;
+}
+
+iree_status_t iree_hal_amdgpu_aql_program_builder_append_command(
+    iree_hal_amdgpu_aql_program_builder_t* builder, uint8_t opcode,
+    uint8_t flags, iree_host_size_t command_length,
+    uint16_t binding_source_count, uint32_t aql_packet_count,
+    uint32_t kernarg_length,
+    iree_hal_amdgpu_command_buffer_command_header_t** out_command,
+    iree_hal_amdgpu_command_buffer_binding_source_t** out_binding_sources) {
+  if (IREE_UNLIKELY(!out_command ||
+                    (binding_source_count > 0 && !out_binding_sources))) {
+    return iree_make_status(IREE_STATUS_INVALID_ARGUMENT,
+                            "command output pointers are required");
+  }
+  *out_command = NULL;
+  if (out_binding_sources) *out_binding_sources = NULL;
+
+  if (IREE_UNLIKELY(builder->command_count == UINT32_MAX)) {
+    return iree_make_status(IREE_STATUS_RESOURCE_EXHAUSTED,
+                            "command-buffer command count overflow");
+  }
+
+  iree_host_size_t required_length = 0;
+  iree_host_size_t binding_source_length = 0;
+  IREE_RETURN_IF_ERROR(iree_hal_amdgpu_aql_program_builder_validate_command(
+      builder, opcode, command_length, binding_source_count, &required_length,
+      &binding_source_length));
+  if (IREE_UNLIKELY(!iree_hal_amdgpu_aql_program_command_fits_empty_block(
+          builder, required_length))) {
+    return iree_make_status(IREE_STATUS_RESOURCE_EXHAUSTED,
+                            "command record and binding sources cannot fit in "
+                            "one command-buffer block");
+  }
+
+  iree_host_size_t required_with_terminator = 0;
+  if (IREE_UNLIKELY(!iree_host_size_checked_add(
+          required_length, iree_hal_amdgpu_aql_program_terminator_reserve(),
+          &required_with_terminator))) {
+    return iree_make_status(IREE_STATUS_RESOURCE_EXHAUSTED,
+                            "command record size overflow");
+  }
+  if (iree_hal_amdgpu_aql_program_builder_remaining(builder) <
+          required_with_terminator ||
+      !iree_hal_amdgpu_aql_program_command_fits_current_block(
+          builder, binding_source_count, aql_packet_count, kernarg_length)) {
+    IREE_RETURN_IF_ERROR(
+        iree_hal_amdgpu_aql_program_builder_split_block(builder));
+    if (IREE_UNLIKELY(builder->command_count == UINT32_MAX)) {
+      return iree_make_status(IREE_STATUS_RESOURCE_EXHAUSTED,
+                              "command-buffer command count overflow");
+    }
+  }
+
+  iree_hal_amdgpu_command_buffer_command_header_t* command =
+      (iree_hal_amdgpu_command_buffer_command_header_t*)
+          builder->current_block.command_cursor;
+  memset(command, 0, command_length);
+  builder->current_block.command_cursor += command_length;
+
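+  // Binding source records grow backward from the end of the block toward
+  // the forward-growing command records.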
+  iree_hal_amdgpu_command_buffer_binding_source_t* binding_sources = NULL;
+  if (binding_source_length > 0) {
+    builder->current_block.binding_source_cursor -= binding_source_length;
+    binding_sources = (iree_hal_amdgpu_command_buffer_binding_source_t*)
+                          builder->current_block.binding_source_cursor;
+    memset(binding_sources, 0, binding_source_length);
+  }
+
+  uint8_t command_flags = flags;
+  if (kernarg_length != 0) {
+    command_flags |=
+        IREE_HAL_AMDGPU_COMMAND_BUFFER_COMMAND_FLAG_USES_QUEUE_KERNARGS;
+  }
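+  // A leading barrier packet is only required when the command emits real AQL
+  // payload and either an execution barrier is pending or the caller forced
+  // one via the HAS_BARRIER flag.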
+  const bool command_has_payload = aql_packet_count != 0;
+  const bool has_pending_execution_barrier = iree_any_bit_set(
+      builder->current_block.flags,
+      IREE_HAL_AMDGPU_AQL_PROGRAM_BUILDER_FLAG_HAS_PENDING_EXECUTION_BARRIER);
+  const bool command_forces_barrier = iree_any_bit_set(
+      command_flags, IREE_HAL_AMDGPU_COMMAND_BUFFER_COMMAND_FLAG_HAS_BARRIER);
+  const bool command_has_barrier_packet =
+      command_has_payload &&
+      (has_pending_execution_barrier || command_forces_barrier);
+
+  command->opcode = opcode;
+  if (command_has_barrier_packet) {
+    uint8_t acquire_scope = IREE_HSA_FENCE_SCOPE_NONE;
+    uint8_t release_scope = IREE_HSA_FENCE_SCOPE_NONE;
+    if (has_pending_execution_barrier) {
+      acquire_scope = builder->current_block.pending_barrier_acquire_scope;
+    }
+    if (command_forces_barrier) {
+      if (acquire_scope < IREE_HSA_FENCE_SCOPE_AGENT) {
+        acquire_scope = IREE_HSA_FENCE_SCOPE_AGENT;
+      }
+      release_scope = IREE_HSA_FENCE_SCOPE_AGENT;
+    }
+    command_flags |= IREE_HAL_AMDGPU_COMMAND_BUFFER_COMMAND_FLAG_HAS_BARRIER;
+    command_flags =
+        iree_hal_amdgpu_command_buffer_command_flags_set_fence_scopes(
+            command_flags, acquire_scope, release_scope);
+  }
+  command->flags = command_flags;
+  command->length_qwords =
+      (uint16_t)(command_length /
+                 IREE_HAL_AMDGPU_COMMAND_BUFFER_RECORD_ALIGNMENT);
+  command->command_index = builder->command_count;
+
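+  // Capture where the first barrier-carrying packet lands: the packets before
+  // it plus the barrier packet itself form the block's initial unordered span.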
+  if (command_has_barrier_packet &&
+      !iree_any_bit_set(
+          builder->current_block.flags,
+          IREE_HAL_AMDGPU_AQL_PROGRAM_BUILDER_FLAG_HAS_INITIAL_BARRIER_PACKET)) {
+    builder->current_block.initial_barrier_packet_count =
+        builder->current_block.aql_packet_count + 1;
+    builder->current_block.flags |=
+        IREE_HAL_AMDGPU_AQL_PROGRAM_BUILDER_FLAG_HAS_INITIAL_BARRIER_PACKET;
+  }
+  if (opcode == IREE_HAL_AMDGPU_COMMAND_BUFFER_OPCODE_BARRIER) {
+    builder->current_block.flags |=
+        IREE_HAL_AMDGPU_AQL_PROGRAM_BUILDER_FLAG_HAS_PENDING_EXECUTION_BARRIER;
+  } else if (command_has_payload) {
+    builder->current_block.flags &=
+        ~IREE_HAL_AMDGPU_AQL_PROGRAM_BUILDER_FLAG_HAS_PENDING_EXECUTION_BARRIER;
+    builder->current_block.pending_barrier_acquire_scope =
+        IREE_HSA_FENCE_SCOPE_NONE;
+  }
+  if (command_has_payload) {
+    builder->last_payload_command = command;
+  }
+
+  ++builder->command_count;
+  ++builder->current_block.command_count;
+  if (opcode == IREE_HAL_AMDGPU_COMMAND_BUFFER_OPCODE_DISPATCH) {
+    ++builder->current_block.dispatch_count;
+  } else if (opcode == IREE_HAL_AMDGPU_COMMAND_BUFFER_OPCODE_PROFILE_MARKER) {
+    ++builder->current_block.profile_marker_count;
+  }
+  builder->current_block.binding_source_count += binding_source_count;
+  builder->current_block.aql_packet_count += aql_packet_count;
+  builder->current_block.kernarg_length += kernarg_length;
+
+  *out_command = command;
+  if (out_binding_sources) *out_binding_sources = binding_sources;
+  return iree_ok_status();
+}
+
+void iree_hal_amdgpu_aql_program_builder_set_pending_barrier_scopes(
+    iree_hal_amdgpu_aql_program_builder_t* builder, uint8_t acquire_scope,
+    uint8_t release_scope) {
+  if (!iree_any_bit_set(
+          builder->current_block.flags,
+          IREE_HAL_AMDGPU_AQL_PROGRAM_BUILDER_FLAG_HAS_PENDING_EXECUTION_BARRIER)) {
+    return;
+  }
+  if (builder->last_payload_command) {
+    const uint8_t current_acquire_scope =
+        iree_hal_amdgpu_command_buffer_command_flags_acquire_scope(
+            builder->last_payload_command->flags);
+    const uint8_t current_release_scope =
+        iree_hal_amdgpu_command_buffer_command_flags_release_scope(
+            builder->last_payload_command->flags);
+    if (release_scope > current_release_scope) {
+      builder->last_payload_command->flags =
+          iree_hal_amdgpu_command_buffer_command_flags_set_fence_scopes(
+              builder->last_payload_command->flags, current_acquire_scope,
+              release_scope);
+    }
+  }
+  if (acquire_scope > builder->current_block.pending_barrier_acquire_scope) {
+    builder->current_block.pending_barrier_acquire_scope = acquire_scope;
+  }
+}
diff --git a/runtime/src/iree/hal/drivers/amdgpu/aql_program_builder.h b/runtime/src/iree/hal/drivers/amdgpu/aql_program_builder.h
new file mode 100644
index 0000000..e685845
--- /dev/null
+++ b/runtime/src/iree/hal/drivers/amdgpu/aql_program_builder.h
@@ -0,0 +1,176 @@
+// Copyright 2026 The IREE Authors
+//
+// Licensed under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+
+#ifndef IREE_HAL_DRIVERS_AMDGPU_AQL_PROGRAM_BUILDER_H_
+#define IREE_HAL_DRIVERS_AMDGPU_AQL_PROGRAM_BUILDER_H_
+
+#include "iree/base/api.h"
+#include "iree/base/internal/arena.h"
+#include "iree/hal/drivers/amdgpu/abi/command_buffer.h"
+
+#ifdef __cplusplus
+extern "C" {
+#endif  // __cplusplus
+
+//===----------------------------------------------------------------------===//
+// Block Pool Utilities
+//===----------------------------------------------------------------------===//
+
+enum {
+  // Default usable bytes per command-buffer block.
+  IREE_HAL_AMDGPU_AQL_PROGRAM_DEFAULT_BLOCK_SIZE = 128 * 1024,
+  // Minimum usable bytes per command-buffer block.
+  IREE_HAL_AMDGPU_AQL_PROGRAM_MIN_BLOCK_SIZE = 256,
+};
+
+// Initializes |out_block_pool| so each acquired block has exactly |block_size|
+// usable bytes. |block_size| must be a non-zero power of two.
+iree_status_t iree_hal_amdgpu_aql_program_block_pool_initialize(
+    iree_host_size_t block_size, iree_allocator_t host_allocator,
+    iree_arena_block_pool_t* out_block_pool);
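+//
+// For example (an illustrative sketch; the allocator choice is up to the
+// caller):
+//   iree_arena_block_pool_t block_pool;
+//   IREE_RETURN_IF_ERROR(iree_hal_amdgpu_aql_program_block_pool_initialize(
+//       IREE_HAL_AMDGPU_AQL_PROGRAM_DEFAULT_BLOCK_SIZE,
+//       iree_allocator_system(), &block_pool));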
+
+//===----------------------------------------------------------------------===//
+// Recording Output
+//===----------------------------------------------------------------------===//
+
+typedef struct iree_hal_amdgpu_aql_program_t {
+  // Block pool that owns all blocks in this program.
+  iree_arena_block_pool_t* block_pool;
+  // First finalized block in program order.
+  iree_hal_amdgpu_command_buffer_block_header_t* first_block;
+  // Number of finalized blocks in the program.
+  uint32_t block_count;
+  // Number of command records in the program, including terminators.
+  uint32_t command_count;
+  // Worst-case AQL packet count across all blocks.
+  uint32_t max_block_aql_packet_count;
+  // Worst-case kernarg byte count across all blocks.
+  uint32_t max_block_kernarg_length;
+} iree_hal_amdgpu_aql_program_t;
+
+// Releases all blocks in |program| back to its block pool.
+void iree_hal_amdgpu_aql_program_release(
+    iree_hal_amdgpu_aql_program_t* program);
+
+// Returns the block following |block| in program order.
+iree_hal_amdgpu_command_buffer_block_header_t*
+iree_hal_amdgpu_aql_program_block_next(
+    iree_arena_block_pool_t* block_pool,
+    const iree_hal_amdgpu_command_buffer_block_header_t* block);
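+//
+// Illustrative walk over a program's blocks (error handling elided; the
+// processing step is a placeholder):
+//   iree_hal_amdgpu_command_buffer_block_header_t* block =
+//       program.first_block;
+//   while (block) {
+//     // ... process the commands recorded in |block| ...
+//     block = iree_hal_amdgpu_aql_program_block_next(program.block_pool,
+//                                                    block);
+//   }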
+
+//===----------------------------------------------------------------------===//
+// Program Builder
+//===----------------------------------------------------------------------===//
+
+typedef uint32_t iree_hal_amdgpu_aql_program_builder_flags_t;
+enum iree_hal_amdgpu_aql_program_builder_flag_bits_t {
+  IREE_HAL_AMDGPU_AQL_PROGRAM_BUILDER_FLAG_NONE = 0u,
+  // The current block has recorded its initial barrier packet count.
+  IREE_HAL_AMDGPU_AQL_PROGRAM_BUILDER_FLAG_HAS_INITIAL_BARRIER_PACKET = 1u << 0,
+  // An execution barrier command must apply to the next payload command,
+  // possibly in a later split block.
+  IREE_HAL_AMDGPU_AQL_PROGRAM_BUILDER_FLAG_HAS_PENDING_EXECUTION_BARRIER =
+      1u << 1,
+};
+
+typedef struct iree_hal_amdgpu_aql_program_builder_t {
+  // Block pool used to acquire fixed-capacity recording blocks.
+  iree_arena_block_pool_t* block_pool;
+  // First finalized block in program order.
+  iree_hal_amdgpu_command_buffer_block_header_t* first_block;
+  // Last finalized block in program order.
+  iree_hal_amdgpu_command_buffer_block_header_t* last_block;
+  // Last payload command recorded in program order, if any. Execution barriers
+  // patch producer-side release scopes into this command at recording time.
+  iree_hal_amdgpu_command_buffer_command_header_t* last_payload_command;
+  // Number of finalized and current blocks.
+  uint32_t block_count;
+  // Number of command records emitted into finalized and current blocks.
+  uint32_t command_count;
+  // State for the block currently being recorded.
+  struct {
+    // Current block being recorded.
+    iree_hal_amdgpu_command_buffer_block_header_t* header;
+    // Forward cursor used to append command records.
+    uint8_t* command_cursor;
+    // Backward cursor used to append binding source records.
+    uint8_t* binding_source_cursor;
+    // Number of command records emitted into the block.
+    uint16_t command_count;
+    // Number of binding source records emitted into the block.
+    uint16_t binding_source_count;
+    // Number of dispatch command records emitted into the block.
+    uint16_t dispatch_count;
+    // Number of indirect dispatch command records emitted into the block.
+    uint16_t indirect_dispatch_count;
+    // Number of profile marker command records emitted into the block.
+    uint16_t profile_marker_count;
+    // Worst-case AQL packet count emitted by the block.
+    uint32_t aql_packet_count;
+    // Worst-case kernarg byte count emitted by the block.
+    uint32_t kernarg_length;
+    // Number of leading AQL payload packets in the block's initial unordered
+    // span, including the first packet with a barrier edge.
+    uint32_t initial_barrier_packet_count;
+    // Acquire fence scope carried by a pending execution barrier.
+    uint8_t pending_barrier_acquire_scope;
+    // State flags from iree_hal_amdgpu_aql_program_builder_flag_bits_t.
+    iree_hal_amdgpu_aql_program_builder_flags_t flags;
+  } current_block;
+  // Worst-case AQL packet count across finalized blocks.
+  uint32_t max_block_aql_packet_count;
+  // Worst-case kernarg byte count across finalized blocks.
+  uint32_t max_block_kernarg_length;
+} iree_hal_amdgpu_aql_program_builder_t;
+
+// Initializes |out_builder|. No blocks are acquired until begin().
+void iree_hal_amdgpu_aql_program_builder_initialize(
+    iree_arena_block_pool_t* block_pool,
+    iree_hal_amdgpu_aql_program_builder_t* out_builder);
+
+// Deinitializes |builder| and releases any blocks not transferred by end().
+void iree_hal_amdgpu_aql_program_builder_deinitialize(
+    iree_hal_amdgpu_aql_program_builder_t* builder);
+
+// Begins a recording session by acquiring the first block.
+iree_status_t iree_hal_amdgpu_aql_program_builder_begin(
+    iree_hal_amdgpu_aql_program_builder_t* builder);
+
+// Finalizes recording with a return terminator and transfers blocks to
+// |out_program|.
+iree_status_t iree_hal_amdgpu_aql_program_builder_end(
+    iree_hal_amdgpu_aql_program_builder_t* builder,
+    iree_hal_amdgpu_aql_program_t* out_program);
+
+// Appends a command record and optional binding source records.
+//
+// |command_length| must be qword-aligned and include the common command header.
+// |aql_packet_count| and |kernarg_length| are worst-case replay resource
+// requirements contributed by this command. When the current block cannot fit
+// the command plus reserved terminator space the builder automatically ends
+// the block with a branch terminator and begins a new one.
+iree_status_t iree_hal_amdgpu_aql_program_builder_append_command(
+    iree_hal_amdgpu_aql_program_builder_t* builder, uint8_t opcode,
+    uint8_t flags, iree_host_size_t command_length,
+    uint16_t binding_source_count, uint32_t aql_packet_count,
+    uint32_t kernarg_length,
+    iree_hal_amdgpu_command_buffer_command_header_t** out_command,
+    iree_hal_amdgpu_command_buffer_binding_source_t** out_binding_sources);
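+//
+// Typical recording lifecycle (an illustrative sketch mirroring the unit
+// tests, not normative; the dispatch arguments are placeholders and error
+// handling is elided):
+//   iree_hal_amdgpu_aql_program_builder_t builder;
+//   iree_hal_amdgpu_aql_program_builder_initialize(&block_pool, &builder);
+//   iree_hal_amdgpu_aql_program_builder_begin(&builder);
+//   iree_hal_amdgpu_command_buffer_command_header_t* command = NULL;
+//   iree_hal_amdgpu_aql_program_builder_append_command(
+//       &builder, IREE_HAL_AMDGPU_COMMAND_BUFFER_OPCODE_DISPATCH,
+//       IREE_HAL_AMDGPU_COMMAND_BUFFER_COMMAND_FLAG_NONE,
+//       sizeof(iree_hal_amdgpu_command_buffer_dispatch_command_t),
+//       /*binding_source_count=*/0, /*aql_packet_count=*/1,
+//       /*kernarg_length=*/0, &command, /*out_binding_sources=*/NULL);
+//   iree_hal_amdgpu_aql_program_t program;
+//   iree_hal_amdgpu_aql_program_builder_end(&builder, &program);
+//   iree_hal_amdgpu_aql_program_builder_deinitialize(&builder);
+//   // ... replay |program| ...
+//   iree_hal_amdgpu_aql_program_release(&program);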
+
+// Sets the fence scopes for the pending execution barrier most recently
+// appended to |builder|. The release scope patches the previous payload command
+// while the acquire scope is carried until the next payload command. Multiple
+// adjacent barriers are coalesced by taking the maximum encoded HSA fence scope
+// for each side.
+void iree_hal_amdgpu_aql_program_builder_set_pending_barrier_scopes(
+    iree_hal_amdgpu_aql_program_builder_t* builder, uint8_t acquire_scope,
+    uint8_t release_scope);
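+//
+// Recording sketch mirroring the unit tests (not normative; error handling
+// elided):
+//   // Append an execution barrier command, then attach its fence scopes.
+//   iree_hal_amdgpu_command_buffer_command_header_t* barrier = NULL;
+//   iree_hal_amdgpu_aql_program_builder_append_command(
+//       &builder, IREE_HAL_AMDGPU_COMMAND_BUFFER_OPCODE_BARRIER,
+//       IREE_HAL_AMDGPU_COMMAND_BUFFER_COMMAND_FLAG_NONE,
+//       sizeof(iree_hal_amdgpu_command_buffer_barrier_command_t),
+//       /*binding_source_count=*/0, /*aql_packet_count=*/0,
+//       /*kernarg_length=*/0, &barrier, /*out_binding_sources=*/NULL);
+//   iree_hal_amdgpu_aql_program_builder_set_pending_barrier_scopes(
+//       &builder, IREE_HSA_FENCE_SCOPE_SYSTEM, IREE_HSA_FENCE_SCOPE_AGENT);
+//   // The AGENT release scope patches the previous payload command; the
+//   // SYSTEM acquire scope applies to the next payload command appended.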
+
+#ifdef __cplusplus
+}  // extern "C"
+#endif  // __cplusplus
+
+#endif  // IREE_HAL_DRIVERS_AMDGPU_AQL_PROGRAM_BUILDER_H_
diff --git a/runtime/src/iree/hal/drivers/amdgpu/aql_program_builder_test.cc b/runtime/src/iree/hal/drivers/amdgpu/aql_program_builder_test.cc
new file mode 100644
index 0000000..0a7d8af
--- /dev/null
+++ b/runtime/src/iree/hal/drivers/amdgpu/aql_program_builder_test.cc
@@ -0,0 +1,466 @@
+// Copyright 2026 The IREE Authors
+//
+// Licensed under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+
+#include "iree/hal/drivers/amdgpu/aql_program_builder.h"
+
+#include <cstddef>
+#include <cstdint>
+
+#include "iree/hal/drivers/amdgpu/abi/queue.h"
+#include "iree/hal/drivers/amdgpu/aql_program_validation.h"
+#include "iree/testing/gtest.h"
+#include "iree/testing/status_matchers.h"
+
+namespace iree::hal::amdgpu {
+namespace {
+
+class AqlProgramBuilderTest : public ::testing::Test {
+ protected:
+  void SetUp() override {
+    IREE_ASSERT_OK(iree_hal_amdgpu_aql_program_block_pool_initialize(
+        block_size_, iree_allocator_system(), &block_pool_));
+  }
+
+  void TearDown() override { iree_arena_block_pool_deinitialize(&block_pool_); }
+
+  iree_arena_block_pool_t* block_pool() { return &block_pool_; }
+
+ private:
+  iree_host_size_t block_size_ = 256;
+  iree_arena_block_pool_t block_pool_;
+};
+
+static const iree_hal_amdgpu_command_buffer_command_header_t* FirstCommand(
+    const iree_hal_amdgpu_command_buffer_block_header_t* block) {
+  return iree_hal_amdgpu_command_buffer_block_commands_const(block);
+}
+
+static const iree_hal_amdgpu_command_buffer_command_header_t* LastCommand(
+    const iree_hal_amdgpu_command_buffer_block_header_t* block) {
+  const iree_hal_amdgpu_command_buffer_command_header_t* command =
+      FirstCommand(block);
+  for (uint16_t i = 1; i < block->command_count; ++i) {
+    command = iree_hal_amdgpu_command_buffer_command_next_const(command);
+  }
+  return command;
+}
+
+static iree_status_t AppendBarriers(
+    iree_hal_amdgpu_aql_program_builder_t* builder, int count) {
+  for (int i = 0; i < count; ++i) {
+    iree_hal_amdgpu_command_buffer_command_header_t* barrier = nullptr;
+    IREE_RETURN_IF_ERROR(iree_hal_amdgpu_aql_program_builder_append_command(
+        builder, IREE_HAL_AMDGPU_COMMAND_BUFFER_OPCODE_BARRIER,
+        IREE_HAL_AMDGPU_COMMAND_BUFFER_COMMAND_FLAG_NONE,
+        sizeof(iree_hal_amdgpu_command_buffer_barrier_command_t),
+        /*binding_source_count=*/0, /*aql_packet_count=*/0,
+        /*kernarg_length=*/0, &barrier, /*out_binding_sources=*/nullptr));
+  }
+  return iree_ok_status();
+}
+
+TEST(CommandBufferAbiTest, CoreRecordSizes) {
+  EXPECT_EQ(sizeof(iree_hal_amdgpu_command_buffer_block_header_t), 64u);
+  EXPECT_EQ(sizeof(iree_hal_amdgpu_command_buffer_command_header_t), 8u);
+  EXPECT_EQ(sizeof(iree_hal_amdgpu_command_buffer_binding_source_t), 16u);
+  EXPECT_EQ(sizeof(iree_hal_amdgpu_command_buffer_barrier_command_t), 16u);
+  EXPECT_EQ(sizeof(iree_hal_amdgpu_command_buffer_dispatch_command_t), 80u);
+  EXPECT_EQ(sizeof(iree_hal_amdgpu_command_buffer_fill_command_t), 40u);
+  EXPECT_EQ(sizeof(iree_hal_amdgpu_command_buffer_copy_command_t), 48u);
+  EXPECT_EQ(sizeof(iree_hal_amdgpu_command_buffer_update_command_t), 40u);
+  EXPECT_EQ(sizeof(iree_hal_amdgpu_command_buffer_branch_command_t), 16u);
+  EXPECT_EQ(sizeof(iree_hal_amdgpu_command_buffer_cond_branch_command_t), 24u);
+  EXPECT_EQ(sizeof(iree_hal_amdgpu_command_buffer_return_command_t), 8u);
+}
+
+TEST(CommandBufferAbiTest, BlockPoolRejectsNonPowerOfTwoBlockSize) {
+  iree_arena_block_pool_t block_pool;
+  IREE_EXPECT_STATUS_IS(
+      IREE_STATUS_INVALID_ARGUMENT,
+      iree_hal_amdgpu_aql_program_block_pool_initialize(
+          /*block_size=*/384, iree_allocator_system(), &block_pool));
+}
+
+TEST(CommandBufferAbiTest, BuilderRejectsOversizedUsableBlockSize) {
+  iree_arena_block_pool_t block_pool;
+  iree_arena_block_pool_initialize(
+      (iree_host_size_t)UINT32_MAX + 1 + sizeof(iree_arena_block_t),
+      iree_allocator_system(), &block_pool);
+
+  iree_hal_amdgpu_aql_program_builder_t builder;
+  iree_hal_amdgpu_aql_program_builder_initialize(&block_pool, &builder);
+  IREE_EXPECT_STATUS_IS(IREE_STATUS_INVALID_ARGUMENT,
+                        iree_hal_amdgpu_aql_program_builder_begin(&builder));
+  iree_hal_amdgpu_aql_program_builder_deinitialize(&builder);
+  iree_arena_block_pool_deinitialize(&block_pool);
+}
+
+TEST_F(AqlProgramBuilderTest, EmptyProgramRecordsReturnBlock) {
+  iree_hal_amdgpu_aql_program_builder_t builder;
+  iree_hal_amdgpu_aql_program_builder_initialize(block_pool(), &builder);
+  IREE_ASSERT_OK(iree_hal_amdgpu_aql_program_builder_begin(&builder));
+
+  iree_hal_amdgpu_aql_program_t program = {};
+  IREE_ASSERT_OK(iree_hal_amdgpu_aql_program_builder_end(&builder, &program));
+  iree_hal_amdgpu_aql_program_builder_deinitialize(&builder);
+
+  ASSERT_NE(program.first_block, nullptr);
+  EXPECT_EQ(program.block_count, 1u);
+  EXPECT_EQ(program.command_count, 1u);
+  EXPECT_EQ(program.max_block_aql_packet_count, 0u);
+  EXPECT_EQ(program.max_block_kernarg_length, 0u);
+
+  const iree_hal_amdgpu_command_buffer_block_header_t* block =
+      program.first_block;
+  EXPECT_EQ(block->magic, IREE_HAL_AMDGPU_COMMAND_BUFFER_BLOCK_MAGIC);
+  EXPECT_EQ(block->version, IREE_HAL_AMDGPU_COMMAND_BUFFER_BLOCK_VERSION_0);
+  EXPECT_EQ(block->header_length, 64u);
+  EXPECT_EQ(block->block_ordinal, 0u);
+  EXPECT_EQ(block->block_length, block_pool()->usable_block_size);
+  EXPECT_EQ(block->command_offset, 64u);
+  EXPECT_EQ(block->command_count, 1u);
+  EXPECT_EQ(block->binding_source_count, 0u);
+  EXPECT_EQ(block->aql_packet_count, 0u);
+  EXPECT_EQ(block->kernarg_length, 0u);
+  EXPECT_EQ(block->terminator_opcode,
+            IREE_HAL_AMDGPU_COMMAND_BUFFER_OPCODE_RETURN);
+  EXPECT_EQ(block->terminator_target_block_ordinal, 0u);
+
+  const iree_hal_amdgpu_command_buffer_command_header_t* command =
+      FirstCommand(block);
+  EXPECT_EQ(command->opcode, IREE_HAL_AMDGPU_COMMAND_BUFFER_OPCODE_RETURN);
+  EXPECT_EQ(iree_hal_amdgpu_command_buffer_command_length(command),
+            sizeof(iree_hal_amdgpu_command_buffer_return_command_t));
+
+  IREE_EXPECT_OK(iree_hal_amdgpu_aql_program_validate_metadata_only(&program));
+
+  iree_hal_amdgpu_aql_program_release(&program);
+}
+
+TEST_F(AqlProgramBuilderTest, AppendsCommandAndBindingSources) {
+  iree_hal_amdgpu_aql_program_builder_t builder;
+  iree_hal_amdgpu_aql_program_builder_initialize(block_pool(), &builder);
+  IREE_ASSERT_OK(iree_hal_amdgpu_aql_program_builder_begin(&builder));
+
+  iree_hal_amdgpu_command_buffer_command_header_t* command = nullptr;
+  iree_hal_amdgpu_command_buffer_binding_source_t* binding_sources = nullptr;
+  IREE_ASSERT_OK(iree_hal_amdgpu_aql_program_builder_append_command(
+      &builder, IREE_HAL_AMDGPU_COMMAND_BUFFER_OPCODE_DISPATCH,
+      IREE_HAL_AMDGPU_COMMAND_BUFFER_COMMAND_FLAG_HAS_BARRIER,
+      sizeof(iree_hal_amdgpu_command_buffer_dispatch_command_t),
+      /*binding_source_count=*/2, /*aql_packet_count=*/1,
+      /*kernarg_length=*/128, &command, &binding_sources));
+
+  ASSERT_NE(command, nullptr);
+  ASSERT_NE(binding_sources, nullptr);
+  EXPECT_EQ(command->opcode, IREE_HAL_AMDGPU_COMMAND_BUFFER_OPCODE_DISPATCH);
+  EXPECT_TRUE(iree_all_bits_set(
+      command->flags,
+      IREE_HAL_AMDGPU_COMMAND_BUFFER_COMMAND_FLAG_HAS_BARRIER |
+          IREE_HAL_AMDGPU_COMMAND_BUFFER_COMMAND_FLAG_USES_QUEUE_KERNARGS));
+  EXPECT_EQ(command->command_index, 0u);
+
+  binding_sources[0].flags =
+      IREE_HAL_AMDGPU_COMMAND_BUFFER_BINDING_SOURCE_FLAG_DYNAMIC;
+  binding_sources[0].slot = 3;
+  binding_sources[0].offset_or_pointer = 64;
+  binding_sources[1].flags =
+      IREE_HAL_AMDGPU_COMMAND_BUFFER_BINDING_SOURCE_FLAG_NONE;
+  binding_sources[1].slot = 0;
+  binding_sources[1].offset_or_pointer = 0x12345678u;
+
+  iree_hal_amdgpu_aql_program_t program = {};
+  IREE_ASSERT_OK(iree_hal_amdgpu_aql_program_builder_end(&builder, &program));
+  iree_hal_amdgpu_aql_program_builder_deinitialize(&builder);
+
+  const iree_hal_amdgpu_command_buffer_block_header_t* block =
+      program.first_block;
+  EXPECT_EQ(block->command_count, 2u);
+  EXPECT_EQ(block->binding_source_count, 2u);
+  EXPECT_EQ(block->dispatch_count, 1u);
+  EXPECT_EQ(block->indirect_dispatch_count, 0u);
+  EXPECT_EQ(block->profile_marker_count, 0u);
+  EXPECT_EQ(block->aql_packet_count, 1u);
+  EXPECT_EQ(block->kernarg_length, 128u);
+  EXPECT_EQ(block->terminator_opcode,
+            IREE_HAL_AMDGPU_COMMAND_BUFFER_OPCODE_RETURN);
+  EXPECT_EQ(block->terminator_target_block_ordinal, 0u);
+  EXPECT_EQ(program.max_block_aql_packet_count, 1u);
+  EXPECT_EQ(program.max_block_kernarg_length, 128u);
+
+  const iree_hal_amdgpu_command_buffer_binding_source_t* block_binding_sources =
+      iree_hal_amdgpu_command_buffer_block_binding_sources_const(block);
+  EXPECT_EQ(block_binding_sources[0].flags,
+            IREE_HAL_AMDGPU_COMMAND_BUFFER_BINDING_SOURCE_FLAG_DYNAMIC);
+  EXPECT_EQ(block_binding_sources[0].slot, 3u);
+  EXPECT_EQ(block_binding_sources[0].offset_or_pointer, 64u);
+  EXPECT_EQ(block_binding_sources[1].flags,
+            IREE_HAL_AMDGPU_COMMAND_BUFFER_BINDING_SOURCE_FLAG_NONE);
+  EXPECT_EQ(block_binding_sources[1].slot, 0u);
+  EXPECT_EQ(block_binding_sources[1].offset_or_pointer, 0x12345678u);
+
+  const iree_hal_amdgpu_command_buffer_command_header_t* return_command =
+      LastCommand(block);
+  EXPECT_EQ(return_command->opcode,
+            IREE_HAL_AMDGPU_COMMAND_BUFFER_OPCODE_RETURN);
+
+  iree_hal_amdgpu_aql_program_release(&program);
+}
+
+TEST_F(AqlProgramBuilderTest, PatchesBarrierScopesAtRecordingTime) {
+  iree_hal_amdgpu_aql_program_builder_t builder;
+  iree_hal_amdgpu_aql_program_builder_initialize(block_pool(), &builder);
+  IREE_ASSERT_OK(iree_hal_amdgpu_aql_program_builder_begin(&builder));
+
+  iree_hal_amdgpu_command_buffer_command_header_t* first_dispatch = nullptr;
+  IREE_ASSERT_OK(iree_hal_amdgpu_aql_program_builder_append_command(
+      &builder, IREE_HAL_AMDGPU_COMMAND_BUFFER_OPCODE_DISPATCH,
+      IREE_HAL_AMDGPU_COMMAND_BUFFER_COMMAND_FLAG_NONE,
+      sizeof(iree_hal_amdgpu_command_buffer_dispatch_command_t),
+      /*binding_source_count=*/0, /*aql_packet_count=*/1,
+      /*kernarg_length=*/0, &first_dispatch,
+      /*out_binding_sources=*/nullptr));
+
+  iree_hal_amdgpu_command_buffer_command_header_t* barrier = nullptr;
+  IREE_ASSERT_OK(iree_hal_amdgpu_aql_program_builder_append_command(
+      &builder, IREE_HAL_AMDGPU_COMMAND_BUFFER_OPCODE_BARRIER,
+      IREE_HAL_AMDGPU_COMMAND_BUFFER_COMMAND_FLAG_NONE,
+      sizeof(iree_hal_amdgpu_command_buffer_barrier_command_t),
+      /*binding_source_count=*/0, /*aql_packet_count=*/0,
+      /*kernarg_length=*/0, &barrier, /*out_binding_sources=*/nullptr));
+  iree_hal_amdgpu_aql_program_builder_set_pending_barrier_scopes(
+      &builder, IREE_HSA_FENCE_SCOPE_SYSTEM, IREE_HSA_FENCE_SCOPE_AGENT);
+
+  iree_hal_amdgpu_command_buffer_command_header_t* second_dispatch = nullptr;
+  IREE_ASSERT_OK(iree_hal_amdgpu_aql_program_builder_append_command(
+      &builder, IREE_HAL_AMDGPU_COMMAND_BUFFER_OPCODE_DISPATCH,
+      IREE_HAL_AMDGPU_COMMAND_BUFFER_COMMAND_FLAG_NONE,
+      sizeof(iree_hal_amdgpu_command_buffer_dispatch_command_t),
+      /*binding_source_count=*/0, /*aql_packet_count=*/1,
+      /*kernarg_length=*/0, &second_dispatch,
+      /*out_binding_sources=*/nullptr));
+
+  iree_hal_amdgpu_aql_program_t program = {};
+  IREE_ASSERT_OK(iree_hal_amdgpu_aql_program_builder_end(&builder, &program));
+  iree_hal_amdgpu_aql_program_builder_deinitialize(&builder);
+
+  EXPECT_FALSE(iree_any_bit_set(
+      first_dispatch->flags,
+      IREE_HAL_AMDGPU_COMMAND_BUFFER_COMMAND_FLAG_HAS_BARRIER));
+  EXPECT_EQ(iree_hal_amdgpu_command_buffer_command_flags_acquire_scope(
+                first_dispatch->flags),
+            IREE_HSA_FENCE_SCOPE_NONE);
+  EXPECT_EQ(iree_hal_amdgpu_command_buffer_command_flags_release_scope(
+                first_dispatch->flags),
+            IREE_HSA_FENCE_SCOPE_AGENT);
+
+  EXPECT_TRUE(iree_any_bit_set(
+      second_dispatch->flags,
+      IREE_HAL_AMDGPU_COMMAND_BUFFER_COMMAND_FLAG_HAS_BARRIER));
+  EXPECT_EQ(iree_hal_amdgpu_command_buffer_command_flags_acquire_scope(
+                second_dispatch->flags),
+            IREE_HSA_FENCE_SCOPE_SYSTEM);
+  EXPECT_EQ(iree_hal_amdgpu_command_buffer_command_flags_release_scope(
+                second_dispatch->flags),
+            IREE_HSA_FENCE_SCOPE_NONE);
+
+  iree_hal_amdgpu_aql_program_release(&program);
+}
+
+TEST_F(AqlProgramBuilderTest, ForcedBarrierKeepsPendingAcquireScope) {
+  iree_hal_amdgpu_aql_program_builder_t builder;
+  iree_hal_amdgpu_aql_program_builder_initialize(block_pool(), &builder);
+  IREE_ASSERT_OK(iree_hal_amdgpu_aql_program_builder_begin(&builder));
+
+  iree_hal_amdgpu_command_buffer_command_header_t* barrier = nullptr;
+  IREE_ASSERT_OK(iree_hal_amdgpu_aql_program_builder_append_command(
+      &builder, IREE_HAL_AMDGPU_COMMAND_BUFFER_OPCODE_BARRIER,
+      IREE_HAL_AMDGPU_COMMAND_BUFFER_COMMAND_FLAG_NONE,
+      sizeof(iree_hal_amdgpu_command_buffer_barrier_command_t),
+      /*binding_source_count=*/0, /*aql_packet_count=*/0,
+      /*kernarg_length=*/0, &barrier, /*out_binding_sources=*/nullptr));
+  iree_hal_amdgpu_aql_program_builder_set_pending_barrier_scopes(
+      &builder, IREE_HSA_FENCE_SCOPE_SYSTEM, IREE_HSA_FENCE_SCOPE_NONE);
+
+  iree_hal_amdgpu_command_buffer_command_header_t* dispatch = nullptr;
+  IREE_ASSERT_OK(iree_hal_amdgpu_aql_program_builder_append_command(
+      &builder, IREE_HAL_AMDGPU_COMMAND_BUFFER_OPCODE_DISPATCH,
+      IREE_HAL_AMDGPU_COMMAND_BUFFER_COMMAND_FLAG_HAS_BARRIER,
+      sizeof(iree_hal_amdgpu_command_buffer_dispatch_command_t),
+      /*binding_source_count=*/0, /*aql_packet_count=*/1,
+      /*kernarg_length=*/0, &dispatch, /*out_binding_sources=*/nullptr));
+
+  iree_hal_amdgpu_aql_program_t program = {};
+  IREE_ASSERT_OK(iree_hal_amdgpu_aql_program_builder_end(&builder, &program));
+  iree_hal_amdgpu_aql_program_builder_deinitialize(&builder);
+
+  EXPECT_TRUE(iree_any_bit_set(
+      dispatch->flags,
+      IREE_HAL_AMDGPU_COMMAND_BUFFER_COMMAND_FLAG_HAS_BARRIER));
+  EXPECT_EQ(iree_hal_amdgpu_command_buffer_command_flags_acquire_scope(
+                dispatch->flags),
+            IREE_HSA_FENCE_SCOPE_SYSTEM);
+  EXPECT_EQ(iree_hal_amdgpu_command_buffer_command_flags_release_scope(
+                dispatch->flags),
+            IREE_HSA_FENCE_SCOPE_AGENT);
+
+  iree_hal_amdgpu_aql_program_release(&program);
+}
+
+TEST_F(AqlProgramBuilderTest, SplitsBlocksWithBranchTerminator) {
+  iree_hal_amdgpu_aql_program_builder_t builder;
+  iree_hal_amdgpu_aql_program_builder_initialize(block_pool(), &builder);
+  IREE_ASSERT_OK(iree_hal_amdgpu_aql_program_builder_begin(&builder));
+
+  for (int i = 0; i < 4; ++i) {
+    iree_hal_amdgpu_command_buffer_command_header_t* command = nullptr;
+    IREE_ASSERT_OK(iree_hal_amdgpu_aql_program_builder_append_command(
+        &builder, IREE_HAL_AMDGPU_COMMAND_BUFFER_OPCODE_DISPATCH,
+        IREE_HAL_AMDGPU_COMMAND_BUFFER_COMMAND_FLAG_NONE,
+        sizeof(iree_hal_amdgpu_command_buffer_dispatch_command_t),
+        /*binding_source_count=*/0, /*aql_packet_count=*/1,
+        /*kernarg_length=*/32, &command, /*out_binding_sources=*/nullptr));
+    EXPECT_NE(command, nullptr);
+  }
+
+  iree_hal_amdgpu_aql_program_t program = {};
+  IREE_ASSERT_OK(iree_hal_amdgpu_aql_program_builder_end(&builder, &program));
+  iree_hal_amdgpu_aql_program_builder_deinitialize(&builder);
+
+  ASSERT_GE(program.block_count, 2u);
+  const iree_hal_amdgpu_command_buffer_block_header_t* first_block =
+      program.first_block;
+  const iree_hal_amdgpu_command_buffer_command_header_t* first_terminator =
+      LastCommand(first_block);
+  ASSERT_EQ(first_terminator->opcode,
+            IREE_HAL_AMDGPU_COMMAND_BUFFER_OPCODE_BRANCH);
+  EXPECT_EQ(first_block->terminator_opcode,
+            IREE_HAL_AMDGPU_COMMAND_BUFFER_OPCODE_BRANCH);
+  EXPECT_EQ(first_block->terminator_target_block_ordinal, 1u);
+  const auto* branch =
+      reinterpret_cast<const iree_hal_amdgpu_command_buffer_branch_command_t*>(
+          first_terminator);
+  EXPECT_EQ(branch->target_block_ordinal, 1u);
+
+  const iree_hal_amdgpu_command_buffer_block_header_t* second_block =
+      iree_hal_amdgpu_aql_program_block_next(block_pool(), first_block);
+  ASSERT_NE(second_block, nullptr);
+  EXPECT_EQ(second_block->block_ordinal, 1u);
+  EXPECT_EQ(second_block->terminator_opcode,
+            IREE_HAL_AMDGPU_COMMAND_BUFFER_OPCODE_RETURN);
+  EXPECT_EQ(second_block->terminator_target_block_ordinal, 0u);
+
+  iree_hal_amdgpu_aql_program_release(&program);
+}
+
+TEST_F(AqlProgramBuilderTest, ValidatesSplitMetadataOnlyProgram) {
+  iree_hal_amdgpu_aql_program_builder_t builder;
+  iree_hal_amdgpu_aql_program_builder_initialize(block_pool(), &builder);
+  IREE_ASSERT_OK(iree_hal_amdgpu_aql_program_builder_begin(&builder));
+
+  IREE_ASSERT_OK(AppendBarriers(&builder, 12));
+
+  iree_hal_amdgpu_aql_program_t program = {};
+  IREE_ASSERT_OK(iree_hal_amdgpu_aql_program_builder_end(&builder, &program));
+  iree_hal_amdgpu_aql_program_builder_deinitialize(&builder);
+
+  ASSERT_GE(program.block_count, 2u);
+  ASSERT_EQ(program.max_block_aql_packet_count, 0u);
+  IREE_EXPECT_OK(iree_hal_amdgpu_aql_program_validate_metadata_only(&program));
+
+  iree_hal_amdgpu_aql_program_release(&program);
+}
+
+TEST_F(AqlProgramBuilderTest, MetadataOnlyValidationRejectsPayloadBlocks) {
+  iree_hal_amdgpu_aql_program_builder_t builder;
+  iree_hal_amdgpu_aql_program_builder_initialize(block_pool(), &builder);
+  IREE_ASSERT_OK(iree_hal_amdgpu_aql_program_builder_begin(&builder));
+
+  iree_hal_amdgpu_command_buffer_command_header_t* dispatch = nullptr;
+  IREE_ASSERT_OK(iree_hal_amdgpu_aql_program_builder_append_command(
+      &builder, IREE_HAL_AMDGPU_COMMAND_BUFFER_OPCODE_DISPATCH,
+      IREE_HAL_AMDGPU_COMMAND_BUFFER_COMMAND_FLAG_NONE,
+      sizeof(iree_hal_amdgpu_command_buffer_dispatch_command_t),
+      /*binding_source_count=*/0, /*aql_packet_count=*/1,
+      /*kernarg_length=*/0, &dispatch, /*out_binding_sources=*/nullptr));
+
+  iree_hal_amdgpu_aql_program_t program = {};
+  IREE_ASSERT_OK(iree_hal_amdgpu_aql_program_builder_end(&builder, &program));
+  iree_hal_amdgpu_aql_program_builder_deinitialize(&builder);
+
+  IREE_EXPECT_STATUS_IS(
+      IREE_STATUS_INVALID_ARGUMENT,
+      iree_hal_amdgpu_aql_program_validate_metadata_only(&program));
+
+  iree_hal_amdgpu_aql_program_release(&program);
+}
+
+TEST_F(AqlProgramBuilderTest, MetadataOnlyValidationRejectsProfileMarkers) {
+  iree_hal_amdgpu_aql_program_builder_t builder;
+  iree_hal_amdgpu_aql_program_builder_initialize(block_pool(), &builder);
+  IREE_ASSERT_OK(iree_hal_amdgpu_aql_program_builder_begin(&builder));
+
+  iree_hal_amdgpu_command_buffer_command_header_t* profile_marker = nullptr;
+  IREE_ASSERT_OK(iree_hal_amdgpu_aql_program_builder_append_command(
+      &builder, IREE_HAL_AMDGPU_COMMAND_BUFFER_OPCODE_PROFILE_MARKER,
+      IREE_HAL_AMDGPU_COMMAND_BUFFER_COMMAND_FLAG_NONE,
+      sizeof(iree_hal_amdgpu_command_buffer_command_header_t),
+      /*binding_source_count=*/0, /*aql_packet_count=*/0,
+      /*kernarg_length=*/0, &profile_marker,
+      /*out_binding_sources=*/nullptr));
+
+  iree_hal_amdgpu_aql_program_t program = {};
+  IREE_ASSERT_OK(iree_hal_amdgpu_aql_program_builder_end(&builder, &program));
+  iree_hal_amdgpu_aql_program_builder_deinitialize(&builder);
+
+  IREE_EXPECT_STATUS_IS(
+      IREE_STATUS_UNIMPLEMENTED,
+      iree_hal_amdgpu_aql_program_validate_metadata_only(&program));
+
+  iree_hal_amdgpu_aql_program_release(&program);
+}
+
+TEST_F(AqlProgramBuilderTest,
+       MetadataOnlyValidationRejectsTerminatorMetadataMismatch) {
+  iree_hal_amdgpu_aql_program_builder_t builder;
+  iree_hal_amdgpu_aql_program_builder_initialize(block_pool(), &builder);
+  IREE_ASSERT_OK(iree_hal_amdgpu_aql_program_builder_begin(&builder));
+
+  IREE_ASSERT_OK(AppendBarriers(&builder, 12));
+
+  iree_hal_amdgpu_aql_program_t program = {};
+  IREE_ASSERT_OK(iree_hal_amdgpu_aql_program_builder_end(&builder, &program));
+  iree_hal_amdgpu_aql_program_builder_deinitialize(&builder);
+
+  ASSERT_GE(program.block_count, 2u);
+  program.first_block->terminator_target_block_ordinal = 2;
+  IREE_EXPECT_STATUS_IS(
+      IREE_STATUS_INVALID_ARGUMENT,
+      iree_hal_amdgpu_aql_program_validate_metadata_only(&program));
+
+  iree_hal_amdgpu_aql_program_release(&program);
+}
+
+TEST_F(AqlProgramBuilderTest, RejectsOversizedCommand) {
+  iree_hal_amdgpu_aql_program_builder_t builder;
+  iree_hal_amdgpu_aql_program_builder_initialize(block_pool(), &builder);
+  IREE_ASSERT_OK(iree_hal_amdgpu_aql_program_builder_begin(&builder));
+
+  iree_hal_amdgpu_command_buffer_command_header_t* command = nullptr;
+  IREE_EXPECT_STATUS_IS(
+      IREE_STATUS_RESOURCE_EXHAUSTED,
+      iree_hal_amdgpu_aql_program_builder_append_command(
+          &builder, IREE_HAL_AMDGPU_COMMAND_BUFFER_OPCODE_DISPATCH,
+          IREE_HAL_AMDGPU_COMMAND_BUFFER_COMMAND_FLAG_NONE,
+          block_pool()->usable_block_size, /*binding_source_count=*/0,
+          /*aql_packet_count=*/1, /*kernarg_length=*/0, &command,
+          /*out_binding_sources=*/nullptr));
+
+  iree_hal_amdgpu_aql_program_builder_deinitialize(&builder);
+}
+
+}  // namespace
+}  // namespace iree::hal::amdgpu
diff --git a/runtime/src/iree/hal/drivers/amdgpu/aql_program_validation.c b/runtime/src/iree/hal/drivers/amdgpu/aql_program_validation.c
new file mode 100644
index 0000000..1e1a857
--- /dev/null
+++ b/runtime/src/iree/hal/drivers/amdgpu/aql_program_validation.c
@@ -0,0 +1,176 @@
+// Copyright 2026 The IREE Authors
+//
+// Licensed under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+
+#include "iree/hal/drivers/amdgpu/aql_program_validation.h"
+
+iree_status_t iree_hal_amdgpu_aql_program_validate_block_terminator(
+    const iree_hal_amdgpu_command_buffer_block_header_t* block) {
+  switch (block->terminator_opcode) {
+    case IREE_HAL_AMDGPU_COMMAND_BUFFER_OPCODE_BRANCH:
+    case IREE_HAL_AMDGPU_COMMAND_BUFFER_OPCODE_RETURN:
+      return iree_ok_status();
+    case IREE_HAL_AMDGPU_COMMAND_BUFFER_OPCODE_COND_BRANCH:
+      return iree_make_status(
+          IREE_STATUS_UNIMPLEMENTED,
+          "conditional AQL command-buffer branch replay not yet wired");
+    default:
+      return iree_make_status(IREE_STATUS_INVALID_ARGUMENT,
+                              "AQL command-buffer block %" PRIu32
+                              " has an invalid or missing terminator",
+                              block->block_ordinal);
+  }
+}
+
+iree_status_t iree_hal_amdgpu_aql_program_next_linear_block(
+    const iree_hal_amdgpu_aql_program_t* program,
+    const iree_hal_amdgpu_command_buffer_block_header_t* block,
+    uint32_t target_block_ordinal,
+    const iree_hal_amdgpu_command_buffer_block_header_t** out_next_block) {
+  *out_next_block = NULL;
+  const iree_hal_amdgpu_command_buffer_block_header_t* next_block =
+      iree_hal_amdgpu_aql_program_block_next(program->block_pool, block);
+  if (IREE_UNLIKELY(!next_block ||
+                    next_block->block_ordinal != target_block_ordinal)) {
+    return iree_make_status(
+        IREE_STATUS_UNIMPLEMENTED,
+        "non-linear AQL command-buffer branch replay not yet wired");
+  }
+  *out_next_block = next_block;
+  return iree_ok_status();
+}
+
+static iree_status_t
+iree_hal_amdgpu_aql_program_validate_metadata_block_commands(
+    const iree_hal_amdgpu_command_buffer_block_header_t* block) {
+  const iree_hal_amdgpu_command_buffer_command_header_t* command =
+      iree_hal_amdgpu_command_buffer_block_commands_const(block);
+  iree_status_t status = iree_ok_status();
+  bool reached_terminator = false;
+  for (uint16_t i = 0; i < block->command_count && iree_status_is_ok(status) &&
+                       !reached_terminator;
+       ++i) {
+    const bool is_final_command = i + 1 == block->command_count;
+    switch (command->opcode) {
+      case IREE_HAL_AMDGPU_COMMAND_BUFFER_OPCODE_BARRIER:
+        break;
+      case IREE_HAL_AMDGPU_COMMAND_BUFFER_OPCODE_BRANCH: {
+        const iree_hal_amdgpu_command_buffer_branch_command_t* branch_command =
+            (const iree_hal_amdgpu_command_buffer_branch_command_t*)command;
+        if (IREE_UNLIKELY(!is_final_command)) {
+          status = iree_make_status(IREE_STATUS_INVALID_ARGUMENT,
+                                    "AQL command-buffer block %" PRIu32
+                                    " has a branch before the final command",
+                                    block->block_ordinal);
+          break;
+        }
+        if (IREE_UNLIKELY(block->terminator_opcode !=
+                              IREE_HAL_AMDGPU_COMMAND_BUFFER_OPCODE_BRANCH ||
+                          branch_command->target_block_ordinal !=
+                              block->terminator_target_block_ordinal)) {
+          status =
+              iree_make_status(IREE_STATUS_INVALID_ARGUMENT,
+                               "AQL command-buffer block %" PRIu32
+                               " has mismatched branch terminator metadata",
+                               block->block_ordinal);
+          break;
+        }
+        reached_terminator = true;
+        break;
+      }
+      case IREE_HAL_AMDGPU_COMMAND_BUFFER_OPCODE_RETURN:
+        if (IREE_UNLIKELY(!is_final_command)) {
+          status = iree_make_status(IREE_STATUS_INVALID_ARGUMENT,
+                                    "AQL command-buffer block %" PRIu32
+                                    " has a return before the final command",
+                                    block->block_ordinal);
+          break;
+        }
+        if (IREE_UNLIKELY(block->terminator_opcode !=
+                          IREE_HAL_AMDGPU_COMMAND_BUFFER_OPCODE_RETURN)) {
+          status =
+              iree_make_status(IREE_STATUS_INVALID_ARGUMENT,
+                               "AQL command-buffer block %" PRIu32
+                               " has mismatched return terminator metadata",
+                               block->block_ordinal);
+          break;
+        }
+        reached_terminator = true;
+        break;
+      case IREE_HAL_AMDGPU_COMMAND_BUFFER_OPCODE_DISPATCH:
+      case IREE_HAL_AMDGPU_COMMAND_BUFFER_OPCODE_FILL:
+      case IREE_HAL_AMDGPU_COMMAND_BUFFER_OPCODE_COPY:
+      case IREE_HAL_AMDGPU_COMMAND_BUFFER_OPCODE_UPDATE:
+      case IREE_HAL_AMDGPU_COMMAND_BUFFER_OPCODE_PROFILE_MARKER:
+      case IREE_HAL_AMDGPU_COMMAND_BUFFER_OPCODE_COND_BRANCH:
+        status = iree_make_status(
+            IREE_STATUS_UNIMPLEMENTED,
+            "metadata-only replay of AQL command-buffer opcode %u not yet "
+            "wired",
+            command->opcode);
+        break;
+      default:
+        status = iree_make_status(IREE_STATUS_INVALID_ARGUMENT,
+                                  "malformed AQL command-buffer opcode %u",
+                                  command->opcode);
+        break;
+    }
+    if (iree_status_is_ok(status) && !reached_terminator) {
+      command = iree_hal_amdgpu_command_buffer_command_next_const(command);
+    }
+  }
+  if (iree_status_is_ok(status) && !reached_terminator) {
+    status = iree_make_status(IREE_STATUS_INVALID_ARGUMENT,
+                              "AQL command-buffer block %" PRIu32
+                              " has no terminator",
+                              block->block_ordinal);
+  }
+  return status;
+}
+
+iree_status_t iree_hal_amdgpu_aql_program_validate_metadata_only(
+    const iree_hal_amdgpu_aql_program_t* program) {
+  if (IREE_UNLIKELY(!program->first_block)) {
+    return iree_make_status(IREE_STATUS_INVALID_ARGUMENT,
+                            "AQL command-buffer program has no blocks");
+  }
+
+  const iree_hal_amdgpu_command_buffer_block_header_t* block =
+      program->first_block;
+  bool reached_return = false;
+  iree_status_t status = iree_ok_status();
+  while (iree_status_is_ok(status) && !reached_return && block) {
+    if (IREE_UNLIKELY(block->aql_packet_count != 0)) {
+      status =
+          iree_make_status(IREE_STATUS_INVALID_ARGUMENT,
+                           "metadata-only AQL command-buffer block %" PRIu32
+                           " declares %" PRIu32 " AQL packets",
+                           block->block_ordinal, block->aql_packet_count);
+      break;
+    }
+    status = iree_hal_amdgpu_aql_program_validate_block_terminator(block);
+    if (!iree_status_is_ok(status)) break;
+    status =
+        iree_hal_amdgpu_aql_program_validate_metadata_block_commands(block);
+    if (!iree_status_is_ok(status)) break;
+
+    switch (block->terminator_opcode) {
+      case IREE_HAL_AMDGPU_COMMAND_BUFFER_OPCODE_BRANCH:
+        status = iree_hal_amdgpu_aql_program_next_linear_block(
+            program, block, block->terminator_target_block_ordinal, &block);
+        break;
+      case IREE_HAL_AMDGPU_COMMAND_BUFFER_OPCODE_RETURN:
+        reached_return = true;
+        break;
+      default:
+        IREE_ASSERT_UNREACHABLE("block terminator was already validated");
+        break;
+    }
+  }
+  if (iree_status_is_ok(status) && !reached_return) {
+    status = iree_make_status(IREE_STATUS_INVALID_ARGUMENT,
+                              "AQL command-buffer program has no return");
+  }
+  return status;
+}
diff --git a/runtime/src/iree/hal/drivers/amdgpu/aql_program_validation.h b/runtime/src/iree/hal/drivers/amdgpu/aql_program_validation.h
new file mode 100644
index 0000000..4964a82
--- /dev/null
+++ b/runtime/src/iree/hal/drivers/amdgpu/aql_program_validation.h
@@ -0,0 +1,37 @@
+// Copyright 2026 The IREE Authors
+//
+// Licensed under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+
+#ifndef IREE_HAL_DRIVERS_AMDGPU_AQL_PROGRAM_VALIDATION_H_
+#define IREE_HAL_DRIVERS_AMDGPU_AQL_PROGRAM_VALIDATION_H_
+
+#include "iree/base/api.h"
+#include "iree/hal/drivers/amdgpu/aql_program_builder.h"
+
+#ifdef __cplusplus
+extern "C" {
+#endif  // __cplusplus
+
+// Validates that |block| has a replayable branch or return terminator.
+iree_status_t iree_hal_amdgpu_aql_program_validate_block_terminator(
+    const iree_hal_amdgpu_command_buffer_block_header_t* block);
+
+// Resolves the next block for linear branch replay, the only branch form
+// currently supported.
+iree_status_t iree_hal_amdgpu_aql_program_next_linear_block(
+    const iree_hal_amdgpu_aql_program_t* program,
+    const iree_hal_amdgpu_command_buffer_block_header_t* block,
+    uint32_t target_block_ordinal,
+    const iree_hal_amdgpu_command_buffer_block_header_t** out_next_block);
+
+// Validates that |program| can be replayed without invoking AQL block
+// processors.
+iree_status_t iree_hal_amdgpu_aql_program_validate_metadata_only(
+    const iree_hal_amdgpu_aql_program_t* program);
+
+#ifdef __cplusplus
+}  // extern "C"
+#endif  // __cplusplus
+
+#endif  // IREE_HAL_DRIVERS_AMDGPU_AQL_PROGRAM_VALIDATION_H_
diff --git a/runtime/src/iree/hal/drivers/amdgpu/buffer.c b/runtime/src/iree/hal/drivers/amdgpu/buffer.c
index 0233842..6af2a78 100644
--- a/runtime/src/iree/hal/drivers/amdgpu/buffer.c
+++ b/runtime/src/iree/hal/drivers/amdgpu/buffer.c
@@ -1,4 +1,4 @@
-// Copyright 2025 The IREE Authors
+// Copyright 2026 The IREE Authors
 //
 // Licensed under the Apache License v2.0 with LLVM Exceptions.
 // See https://llvm.org/LICENSE.txt for license information.
@@ -6,88 +6,380 @@
 
 #include "iree/hal/drivers/amdgpu/buffer.h"
 
+#include "iree/hal/drivers/amdgpu/logical_device.h"
+#include "iree/hal/drivers/amdgpu/transient_buffer.h"
+
 //===----------------------------------------------------------------------===//
-// iree_hal_amdgpu_external_buffer_t
+// iree_hal_amdgpu_buffer_t
 //===----------------------------------------------------------------------===//
 
-typedef struct iree_hal_amdgpu_external_buffer_t {
+// A buffer backed by an HSA memory pool allocation.
+// Device-local buffers are usually coarse-grained and unmappable; explicit
+// host-visible/host-local buffers use fine-grained pools and can be mapped.
+struct iree_hal_amdgpu_buffer_t {
+  // Base HAL buffer resource returned to callers.
   iree_hal_buffer_t base;
+
+  // Host allocator used to free unpooled wrapper storage.
   iree_allocator_t host_allocator;
+
+  // Pool this wrapper returns to when its final reference is released.
+  iree_hal_amdgpu_buffer_pool_t* pool;
+
+  // Next wrapper in either the pool return stack or acquire-side cache.
+  iree_hal_amdgpu_buffer_t* pool_next;
+
+  // Unowned libhsa handle for freeing the allocation on destroy.
+  const iree_hal_amdgpu_libhsa_t* libhsa;
+
+  // HSA-allocated pointer. Accessible from both host and device when allocated
+  // from a fine-grained pool, or device-only from a coarse-grained pool.
+  void* host_ptr;
+
+  // Optional callback for provider/pool-owned buffer storage.
+  // When present the callback owns release of |host_ptr| and any backing pool
+  // bookkeeping. When null this buffer frees |host_ptr| directly with HSA.
   iree_hal_buffer_release_callback_t release_callback;
-  uint64_t device_ptr;
-} iree_hal_amdgpu_external_buffer_t;
 
-static const iree_hal_buffer_vtable_t iree_hal_amdgpu_external_buffer_vtable;
+  // Session-local profiling allocation id for direct allocator buffers.
+  uint64_t profile_allocation_id;
 
-static iree_hal_amdgpu_external_buffer_t* iree_hal_amdgpu_external_buffer_cast(
+  // Profiling session id owning |profile_allocation_id|.
+  uint64_t profile_session_id;
+
+  // Producer-defined memory pool id used for profiling events.
+  uint64_t profile_pool_id;
+
+  // Physical device ordinal used for profiling allocation/free events.
+  uint32_t profile_physical_device_ordinal;
+
+  // Byte alignment used for profiling allocation/free events.
+  iree_device_size_t profile_alignment;
+};
+
+static const iree_hal_buffer_vtable_t iree_hal_amdgpu_buffer_vtable;
+
+static iree_hal_amdgpu_buffer_t* iree_hal_amdgpu_buffer_cast(
     iree_hal_buffer_t* base_value) {
-  IREE_HAL_ASSERT_TYPE(base_value, &iree_hal_amdgpu_external_buffer_vtable);
-  return (iree_hal_amdgpu_external_buffer_t*)base_value;
+  IREE_HAL_ASSERT_TYPE(base_value, &iree_hal_amdgpu_buffer_vtable);
+  return (iree_hal_amdgpu_buffer_t*)base_value;
 }
 
-static const iree_hal_amdgpu_external_buffer_t*
-iree_hal_amdgpu_external_buffer_const_cast(
-    const iree_hal_buffer_t* base_value) {
-  IREE_HAL_ASSERT_TYPE(base_value, &iree_hal_amdgpu_external_buffer_vtable);
-  return (const iree_hal_amdgpu_external_buffer_t*)base_value;
+//===----------------------------------------------------------------------===//
+// iree_hal_amdgpu_buffer_pool_t
+//===----------------------------------------------------------------------===//
+
+static iree_host_size_t iree_hal_amdgpu_buffer_pool_slot_size(void) {
+  return iree_host_align(sizeof(iree_hal_amdgpu_buffer_t),
+                         iree_alignof(iree_hal_amdgpu_buffer_t));
 }
 
-iree_status_t iree_hal_amdgpu_external_buffer_wrap(
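+// Grows |pool| by one host block carved into wrapper slots.
+// Must be called with pool->mutex held: new slots are pushed directly onto the
+// acquire-side cache.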
+static iree_status_t iree_hal_amdgpu_buffer_pool_grow_locked(
+    iree_hal_amdgpu_buffer_pool_t* pool) {
+  IREE_TRACE_ZONE_BEGIN(z0);
+
+  const iree_host_size_t slot_size = iree_hal_amdgpu_buffer_pool_slot_size();
+  const iree_host_size_t slot_count =
+      pool->block_pool->usable_block_size / slot_size;
+  IREE_TRACE_ZONE_APPEND_VALUE_I64(z0, (int64_t)slot_count);
+
+  iree_arena_block_t* block = NULL;
+  void* block_ptr = NULL;
+  IREE_RETURN_AND_END_ZONE_IF_ERROR(
+      z0, iree_arena_block_pool_acquire(pool->block_pool, &block, &block_ptr));
+
+  if (pool->block_tail) {
+    pool->block_tail->next = block;
+  } else {
+    pool->block_head = block;
+  }
+  pool->block_tail = block;
+
+  uint8_t* slot_ptr = (uint8_t*)block_ptr;
+  for (iree_host_size_t i = 0; i < slot_count; ++i) {
+    iree_hal_amdgpu_buffer_t* buffer = (iree_hal_amdgpu_buffer_t*)slot_ptr;
+    buffer->pool = pool;
+    buffer->pool_next = pool->acquire_head;
+    pool->acquire_head = buffer;
+    slot_ptr += slot_size;
+  }
+
+  IREE_TRACE_ZONE_END(z0);
+  return iree_ok_status();
+}
+
+iree_status_t iree_hal_amdgpu_buffer_pool_initialize(
+    iree_arena_block_pool_t* block_pool,
+    iree_hal_amdgpu_buffer_pool_t* out_pool) {
+  IREE_ASSERT_ARGUMENT(block_pool);
+  IREE_ASSERT_ARGUMENT(out_pool);
+  IREE_TRACE_ZONE_BEGIN(z0);
+
+  memset(out_pool, 0, sizeof(*out_pool));
+  const iree_host_size_t slot_size = iree_hal_amdgpu_buffer_pool_slot_size();
+  if (IREE_UNLIKELY(block_pool->usable_block_size < slot_size)) {
+    IREE_TRACE_ZONE_END(z0);
+    return iree_make_status(IREE_STATUS_OUT_OF_RANGE,
+                            "AMDGPU buffer pool block usable size %" PRIhsz
+                            " is smaller than wrapper slot size %" PRIhsz,
+                            block_pool->usable_block_size, slot_size);
+  }
+
+  out_pool->block_pool = block_pool;
+  iree_atomic_store(&out_pool->return_head, 0, iree_memory_order_relaxed);
+  iree_slim_mutex_initialize(&out_pool->mutex);
+#if !defined(NDEBUG)
+  iree_atomic_store(&out_pool->live_count, 0, iree_memory_order_relaxed);
+#endif  // !defined(NDEBUG)
+
+  IREE_TRACE_ZONE_END(z0);
+  return iree_ok_status();
+}
+
+void iree_hal_amdgpu_buffer_pool_deinitialize(
+    iree_hal_amdgpu_buffer_pool_t* pool) {
+  if (!pool || !pool->block_pool) return;
+  IREE_TRACE_ZONE_BEGIN(z0);
+
+#if !defined(NDEBUG)
+  const int32_t live_count =
+      iree_atomic_load(&pool->live_count, iree_memory_order_acquire);
+  IREE_ASSERT(live_count == 0,
+              "deinitializing AMDGPU buffer pool with %d live wrappers",
+              live_count);
+#endif  // !defined(NDEBUG)
+
+  iree_atomic_store(&pool->return_head, 0, iree_memory_order_relaxed);
+  pool->acquire_head = NULL;
+  if (pool->block_head) {
+    iree_arena_block_pool_release(pool->block_pool, pool->block_head,
+                                  pool->block_tail);
+  }
+  pool->block_head = NULL;
+  pool->block_tail = NULL;
+  iree_slim_mutex_deinitialize(&pool->mutex);
+  pool->block_pool = NULL;
+
+  IREE_TRACE_ZONE_END(z0);
+}
+
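+// Acquires a pooled wrapper slot. Acquire order: (1) pop the mutex-guarded
+// cache, (2) claim the lock-free return stack and migrate it into the cache,
+// (3) grow by carving a new block from the block pool.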
+static iree_status_t iree_hal_amdgpu_buffer_pool_acquire(
+    iree_hal_amdgpu_buffer_pool_t* pool,
+    iree_hal_amdgpu_buffer_t** out_buffer) {
+  *out_buffer = NULL;
+
+  iree_slim_mutex_lock(&pool->mutex);
+
+  iree_status_t status = iree_ok_status();
+  iree_hal_amdgpu_buffer_t* buffer = pool->acquire_head;
+  if (buffer) {
+    pool->acquire_head = buffer->pool_next;
+  } else {
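+    // Cache empty: atomically claim the entire return stack, keep the first
+    // entry for this acquire, and cache the remainder.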
+    buffer = (iree_hal_amdgpu_buffer_t*)iree_atomic_exchange(
+        &pool->return_head, 0, iree_memory_order_acquire);
+    if (buffer) {
+      pool->acquire_head = buffer->pool_next;
+    } else {
+      status = iree_hal_amdgpu_buffer_pool_grow_locked(pool);
+      if (iree_status_is_ok(status)) {
+        buffer = pool->acquire_head;
+        pool->acquire_head = buffer->pool_next;
+      }
+    }
+  }
+
+  iree_slim_mutex_unlock(&pool->mutex);
+
+  if (iree_status_is_ok(status)) {
+    buffer->pool_next = NULL;
+#if !defined(NDEBUG)
+    iree_atomic_fetch_add(&pool->live_count, 1, iree_memory_order_acq_rel);
+#endif  // !defined(NDEBUG)
+    *out_buffer = buffer;
+  }
+  return status;
+}
+
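+// Returns a wrapper to |pool| from any thread. The push uses a lock-free
+// compare-and-swap onto the return stack so buffer destruction never contends
+// with the acquire-side mutex.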
+static void iree_hal_amdgpu_buffer_pool_release(
+    iree_hal_amdgpu_buffer_pool_t* pool, iree_hal_amdgpu_buffer_t* buffer) {
+#if !defined(NDEBUG)
+  const int32_t old_live_count =
+      iree_atomic_fetch_sub(&pool->live_count, 1, iree_memory_order_acq_rel);
+  IREE_ASSERT(old_live_count > 0,
+              "releasing AMDGPU buffer wrapper with no live wrapper count");
+#endif  // !defined(NDEBUG)
+
+  intptr_t expected = 0;
+  do {
+    expected = iree_atomic_load(&pool->return_head, iree_memory_order_relaxed);
+    buffer->pool_next = (iree_hal_amdgpu_buffer_t*)expected;
+  } while (!iree_atomic_compare_exchange_weak(
+      &pool->return_head, &expected, (intptr_t)buffer,
+      iree_memory_order_release, iree_memory_order_relaxed));
+}
+
+//===----------------------------------------------------------------------===//
+// iree_hal_amdgpu_buffer_t
+//===----------------------------------------------------------------------===//
+
+void* iree_hal_amdgpu_buffer_device_pointer(iree_hal_buffer_t* base_buffer) {
+  if (!iree_hal_resource_is((const iree_hal_resource_t*)base_buffer,
+                            &iree_hal_amdgpu_buffer_vtable)) {
+    if (iree_hal_amdgpu_transient_buffer_isa(base_buffer)) {
+      iree_hal_buffer_t* backing_buffer =
+          iree_hal_amdgpu_transient_buffer_backing_buffer(base_buffer);
+      if (!backing_buffer) return NULL;
+      return iree_hal_amdgpu_buffer_device_pointer(backing_buffer);
+    }
+    return NULL;
+  }
+  return ((iree_hal_amdgpu_buffer_t*)base_buffer)->host_ptr;
+}
+
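+// Initializes |out_buffer| in place. Shared by the heap-allocated
+// (iree_hal_amdgpu_buffer_create) and pooled
+// (iree_hal_amdgpu_buffer_create_pooled) creation paths.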
+static void iree_hal_amdgpu_buffer_initialize(
+    const iree_hal_amdgpu_libhsa_t* libhsa,
     iree_hal_buffer_placement_t placement, iree_hal_memory_type_t memory_type,
     iree_hal_memory_access_t allowed_access,
     iree_hal_buffer_usage_t allowed_usage, iree_device_size_t allocation_size,
-    iree_device_size_t byte_offset, iree_device_size_t byte_length,
-    uint64_t device_ptr, iree_hal_buffer_release_callback_t release_callback,
-    iree_allocator_t host_allocator, iree_hal_buffer_t** out_buffer) {
-  IREE_ASSERT_ARGUMENT(device_ptr);
-  IREE_ASSERT_ARGUMENT(out_buffer);
-  *out_buffer = NULL;
-  IREE_TRACE_ZONE_BEGIN(z0);
+    iree_device_size_t byte_length, void* host_ptr,
+    iree_hal_buffer_release_callback_t release_callback,
+    iree_hal_amdgpu_buffer_pool_t* pool, iree_allocator_t host_allocator,
+    iree_hal_amdgpu_buffer_t* out_buffer) {
+  iree_hal_buffer_initialize(placement, &out_buffer->base, allocation_size,
+                             /*byte_offset=*/0, byte_length, memory_type,
+                             allowed_access, allowed_usage,
+                             &iree_hal_amdgpu_buffer_vtable, &out_buffer->base);
+  out_buffer->host_allocator = host_allocator;
+  out_buffer->pool = pool;
+  out_buffer->pool_next = NULL;
+  out_buffer->libhsa = libhsa;
+  out_buffer->host_ptr = host_ptr;
+  out_buffer->release_callback = release_callback;
+  out_buffer->profile_allocation_id = 0;
+  out_buffer->profile_session_id = 0;
+  out_buffer->profile_pool_id = 0;
+  out_buffer->profile_physical_device_ordinal = UINT32_MAX;
+  out_buffer->profile_alignment = 0;
+}
 
-  iree_hal_amdgpu_external_buffer_t* buffer = NULL;
+iree_status_t iree_hal_amdgpu_buffer_create(
+    const iree_hal_amdgpu_libhsa_t* libhsa,
+    iree_hal_buffer_placement_t placement, iree_hal_memory_type_t memory_type,
+    iree_hal_memory_access_t allowed_access,
+    iree_hal_buffer_usage_t allowed_usage, iree_device_size_t allocation_size,
+    iree_device_size_t byte_length, void* host_ptr,
+    iree_hal_buffer_release_callback_t release_callback,
+    iree_allocator_t host_allocator, iree_hal_buffer_t** out_buffer) {
+  IREE_ASSERT_ARGUMENT(out_buffer);
+  IREE_TRACE_ZONE_BEGIN(z0);
+  *out_buffer = NULL;
+
+  iree_hal_amdgpu_buffer_t* buffer = NULL;
   IREE_RETURN_AND_END_ZONE_IF_ERROR(
       z0,
       iree_allocator_malloc(host_allocator, sizeof(*buffer), (void**)&buffer));
-  iree_hal_buffer_initialize(
-      placement, &buffer->base, allocation_size, byte_offset, byte_length,
-      memory_type, allowed_access, allowed_usage,
-      &iree_hal_amdgpu_external_buffer_vtable, &buffer->base);
-  buffer->host_allocator = host_allocator;
-  buffer->release_callback = release_callback;
-  buffer->device_ptr = device_ptr;
+  iree_hal_amdgpu_buffer_initialize(
+      libhsa, placement, memory_type, allowed_access, allowed_usage,
+      allocation_size, byte_length, host_ptr, release_callback, /*pool=*/NULL,
+      host_allocator, buffer);
 
   *out_buffer = &buffer->base;
   IREE_TRACE_ZONE_END(z0);
   return iree_ok_status();
 }
 
-static void iree_hal_amdgpu_external_buffer_destroy(
-    iree_hal_buffer_t* base_buffer) {
-  iree_hal_amdgpu_external_buffer_t* buffer =
-      iree_hal_amdgpu_external_buffer_cast(base_buffer);
+iree_status_t iree_hal_amdgpu_buffer_create_pooled(
+    const iree_hal_amdgpu_libhsa_t* libhsa,
+    iree_hal_buffer_placement_t placement, iree_hal_memory_type_t memory_type,
+    iree_hal_memory_access_t allowed_access,
+    iree_hal_buffer_usage_t allowed_usage, iree_device_size_t allocation_size,
+    iree_device_size_t byte_length, void* host_ptr,
+    iree_hal_buffer_release_callback_t release_callback,
+    iree_hal_amdgpu_buffer_pool_t* pool, iree_allocator_t host_allocator,
+    iree_hal_buffer_t** out_buffer) {
+  IREE_ASSERT_ARGUMENT(pool);
+  IREE_ASSERT_ARGUMENT(out_buffer);
+  IREE_TRACE_ZONE_BEGIN(z0);
+  *out_buffer = NULL;
+
+  iree_hal_amdgpu_buffer_t* buffer = NULL;
+  IREE_RETURN_AND_END_ZONE_IF_ERROR(
+      z0, iree_hal_amdgpu_buffer_pool_acquire(pool, &buffer));
+  iree_hal_amdgpu_buffer_initialize(
+      libhsa, placement, memory_type, allowed_access, allowed_usage,
+      allocation_size, byte_length, host_ptr, release_callback, pool,
+      host_allocator, buffer);
+
+  *out_buffer = &buffer->base;
+  IREE_TRACE_ZONE_END(z0);
+  return iree_ok_status();
+}
+
+void iree_hal_amdgpu_buffer_set_profile_allocation(
+    iree_hal_buffer_t* base_buffer, uint64_t session_id, uint64_t allocation_id,
+    uint64_t pool_id, uint32_t physical_device_ordinal,
+    iree_device_size_t alignment) {
+  iree_hal_amdgpu_buffer_t* buffer = iree_hal_amdgpu_buffer_cast(base_buffer);
+  buffer->profile_allocation_id = allocation_id;
+  buffer->profile_session_id = session_id;
+  buffer->profile_pool_id = pool_id;
+  buffer->profile_physical_device_ordinal = physical_device_ordinal;
+  buffer->profile_alignment = alignment;
+}
+
+static void iree_hal_amdgpu_buffer_destroy(iree_hal_buffer_t* base_buffer) {
+  iree_hal_amdgpu_buffer_t* buffer = iree_hal_amdgpu_buffer_cast(base_buffer);
   iree_allocator_t host_allocator = buffer->host_allocator;
+  iree_hal_amdgpu_buffer_pool_t* pool = buffer->pool;
   IREE_TRACE_ZONE_BEGIN(z0);
 
-  // Optionally call a release callback when the buffer is destroyed. Not all
-  // implementations may require this but it's cheap and provides additional
-  // flexibility.
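+  // Emit a BUFFER_FREE profiling event mirroring the allocation event that was
+  // recorded when the buffer was tagged with a profile allocation id.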
+  if (buffer->profile_allocation_id != 0 && base_buffer->placement.device) {
+    iree_hal_profile_memory_event_t event =
+        iree_hal_profile_memory_event_default();
+    event.type = IREE_HAL_PROFILE_MEMORY_EVENT_TYPE_BUFFER_FREE;
+    event.allocation_id = buffer->profile_allocation_id;
+    event.pool_id = buffer->profile_pool_id;
+    event.backing_id = (uint64_t)(uintptr_t)buffer->host_ptr;
+    event.physical_device_ordinal = buffer->profile_physical_device_ordinal;
+    event.memory_type = base_buffer->memory_type;
+    event.buffer_usage = base_buffer->allowed_usage;
+    event.length = base_buffer->allocation_size;
+    event.alignment = buffer->profile_alignment;
+    iree_hal_amdgpu_logical_device_record_profile_memory_event_for_session(
+        base_buffer->placement.device, buffer->profile_session_id, &event);
+  }
+
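+  // Release the backing memory: a release callback owns provider/pool-managed
+  // storage; otherwise the wrapper owns the raw HSA allocation and frees it.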
   if (buffer->release_callback.fn) {
     buffer->release_callback.fn(buffer->release_callback.user_data,
                                 base_buffer);
+  } else if (buffer->host_ptr) {
+    iree_hal_amdgpu_hsa_cleanup_assert_success(
+        iree_hsa_amd_memory_pool_free_raw(buffer->libhsa, buffer->host_ptr));
   }
 
-  iree_allocator_free(host_allocator, buffer);
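+  // Scrub wrapper state so stale pointers cannot leak through pooled reuse.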
+  buffer->libhsa = NULL;
+  buffer->host_ptr = NULL;
+  buffer->release_callback = iree_hal_buffer_release_callback_null();
+  buffer->profile_allocation_id = 0;
+  buffer->profile_session_id = 0;
+  buffer->profile_pool_id = 0;
+  buffer->profile_physical_device_ordinal = UINT32_MAX;
+  buffer->profile_alignment = 0;
+  if (pool) {
+    iree_hal_amdgpu_buffer_pool_release(pool, buffer);
+  } else {
+    iree_allocator_free(host_allocator, buffer);
+  }
 
   IREE_TRACE_ZONE_END(z0);
 }
 
-static iree_status_t iree_hal_amdgpu_external_buffer_map_range(
+static iree_status_t iree_hal_amdgpu_buffer_map_range(
     iree_hal_buffer_t* base_buffer, iree_hal_mapping_mode_t mapping_mode,
     iree_hal_memory_access_t memory_access,
     iree_device_size_t local_byte_offset, iree_device_size_t local_byte_length,
     iree_hal_buffer_mapping_t* mapping) {
-  iree_hal_amdgpu_external_buffer_t* buffer =
-      iree_hal_amdgpu_external_buffer_cast(base_buffer);
+  iree_hal_amdgpu_buffer_t* buffer = iree_hal_amdgpu_buffer_cast(base_buffer);
 
   IREE_RETURN_IF_ERROR(iree_hal_buffer_validate_memory_type(
       iree_hal_buffer_memory_type(base_buffer),
@@ -98,242 +390,39 @@
           ? IREE_HAL_BUFFER_USAGE_MAPPING_PERSISTENT
           : IREE_HAL_BUFFER_USAGE_MAPPING_SCOPED));
 
+  // Host-visible AMDGPU HSA allocations are directly host-accessible.
   mapping->contents = iree_make_byte_span(
-      (IREE_AMDGPU_DEVICE_PTR void*)buffer->device_ptr, local_byte_length);
+      (uint8_t*)buffer->host_ptr + local_byte_offset, local_byte_length);
 
   return iree_ok_status();
 }
 
-static iree_status_t iree_hal_amdgpu_external_buffer_unmap_range(
+static iree_status_t iree_hal_amdgpu_buffer_unmap_range(
     iree_hal_buffer_t* base_buffer, iree_device_size_t local_byte_offset,
     iree_device_size_t local_byte_length, iree_hal_buffer_mapping_t* mapping) {
-  // Nothing to do today (though maybe we may want to flush?).
+  // Nothing to do: host-visible AMDGPU allocations are currently coherent.
   return iree_ok_status();
 }
 
-static iree_status_t iree_hal_amdgpu_external_buffer_invalidate_range(
+static iree_status_t iree_hal_amdgpu_buffer_invalidate_range(
     iree_hal_buffer_t* base_buffer, iree_device_size_t local_byte_offset,
     iree_device_size_t local_byte_length) {
-  // TODO(benvanik): anything we need to do to invalidate?
+  // Nothing to do: host-visible AMDGPU allocations are currently coherent.
   return iree_ok_status();
 }
 
-static iree_status_t iree_hal_amdgpu_external_buffer_flush_range(
+static iree_status_t iree_hal_amdgpu_buffer_flush_range(
     iree_hal_buffer_t* base_buffer, iree_device_size_t local_byte_offset,
     iree_device_size_t local_byte_length) {
-  // TODO(benvanik): anything we need to do to flush?
+  // Nothing to do: host-visible AMDGPU allocations are currently coherent.
   return iree_ok_status();
 }
 
-static const iree_hal_buffer_vtable_t iree_hal_amdgpu_external_buffer_vtable = {
+static const iree_hal_buffer_vtable_t iree_hal_amdgpu_buffer_vtable = {
     .recycle = iree_hal_buffer_recycle,
-    .destroy = iree_hal_amdgpu_external_buffer_destroy,
-    .map_range = iree_hal_amdgpu_external_buffer_map_range,
-    .unmap_range = iree_hal_amdgpu_external_buffer_unmap_range,
-    .invalidate_range = iree_hal_amdgpu_external_buffer_invalidate_range,
-    .flush_range = iree_hal_amdgpu_external_buffer_flush_range,
+    .destroy = iree_hal_amdgpu_buffer_destroy,
+    .map_range = iree_hal_amdgpu_buffer_map_range,
+    .unmap_range = iree_hal_amdgpu_buffer_unmap_range,
+    .invalidate_range = iree_hal_amdgpu_buffer_invalidate_range,
+    .flush_range = iree_hal_amdgpu_buffer_flush_range,
 };
-
-//===----------------------------------------------------------------------===//
-// iree_hal_amdgpu_transient_buffer_t
-//===----------------------------------------------------------------------===//
-
-static const iree_hal_buffer_vtable_t iree_hal_amdgpu_transient_buffer_vtable;
-
-void iree_hal_amdgpu_transient_buffer_initialize(
-    iree_hal_buffer_placement_t placement,
-    iree_hal_amdgpu_device_allocation_handle_t* handle,
-    iree_hal_buffer_release_callback_t release_callback,
-    iree_hal_amdgpu_transient_buffer_t* out_buffer) {
-  IREE_ASSERT_ARGUMENT(handle);
-  IREE_ASSERT_ARGUMENT(out_buffer);
-
-  memset(out_buffer, 0, sizeof(*out_buffer));
-
-  // Transient buffers *are* asynchronous buffers.
-  placement.flags |= IREE_HAL_BUFFER_PLACEMENT_FLAG_ASYNCHRONOUS;
-
-  iree_hal_buffer_initialize(
-      placement, &out_buffer->base,
-      /*allocation_size=*/0,
-      /*byte_offset=*/0, /*byte_length=*/0, /*memory_type=*/0,
-      /*allowed_access=*/0, /*allowed_usage=*/0,
-      &iree_hal_amdgpu_transient_buffer_vtable, &out_buffer->base);
-  out_buffer->handle = handle;
-  out_buffer->release_callback = release_callback;
-
-  // NOTE: transient buffers start with 0 references as they are pooled.
-  iree_atomic_ref_count_init_value(&out_buffer->base.resource.ref_count, 0);
-}
-
-void iree_hal_amdgpu_transient_buffer_deinitialize(
-    iree_hal_amdgpu_transient_buffer_t* buffer) {
-  // No-op.
-}
-
-static iree_hal_amdgpu_transient_buffer_t*
-iree_hal_amdgpu_transient_buffer_cast(iree_hal_buffer_t* base_value) {
-  IREE_HAL_ASSERT_TYPE(base_value, &iree_hal_amdgpu_transient_buffer_vtable);
-  return (iree_hal_amdgpu_transient_buffer_t*)base_value;
-}
-
-static const iree_hal_amdgpu_transient_buffer_t*
-iree_hal_amdgpu_transient_buffer_const_cast(
-    const iree_hal_buffer_t* base_value) {
-  IREE_HAL_ASSERT_TYPE(base_value, &iree_hal_amdgpu_transient_buffer_vtable);
-  return (const iree_hal_amdgpu_transient_buffer_t*)base_value;
-}
-
-// Returns true if |buffer| is an iree_hal_amdgpu_transient_buffer_t.
-static bool iree_hal_amdgpu_transient_buffer_isa(iree_hal_buffer_t* buffer) {
-  return iree_hal_resource_is(buffer, &iree_hal_amdgpu_transient_buffer_vtable);
-}
-
-static void iree_hal_amdgpu_transient_buffer_recycle(
-    iree_hal_buffer_t* base_buffer) {
-  if (IREE_UNLIKELY(!base_buffer)) return;
-  iree_hal_amdgpu_transient_buffer_t* buffer =
-      iree_hal_amdgpu_transient_buffer_cast(base_buffer);
-  IREE_ASSERT(buffer->release_callback.fn);
-  if (buffer->release_callback.fn) {
-    buffer->release_callback.fn(buffer->release_callback.user_data,
-                                base_buffer);
-  }
-}
-
-void iree_hal_amdgpu_transient_buffer_reset(
-    iree_hal_amdgpu_transient_buffer_t* buffer, iree_hal_buffer_params_t params,
-    iree_device_size_t allocation_size, iree_device_size_t byte_offset,
-    iree_device_size_t byte_length) {
-  buffer->base.memory_type = params.type;
-  buffer->base.allowed_access = params.access;
-  buffer->base.allowed_usage = params.usage;
-  buffer->base.allocation_size = allocation_size;
-  buffer->base.byte_offset = byte_offset;
-  buffer->base.byte_length = byte_length;
-  buffer->base.placement.queue_affinity = params.queue_affinity
-                                              ? params.queue_affinity
-                                              : IREE_HAL_QUEUE_AFFINITY_ANY;
-}
-
-// Returns the device pointer from the allocation handle if it is currently
-// allocated. This may read from device memory and is only valid when the buffer
-// is allocated and the completion signals have propagated (ordering the
-// writes).
-static iree_status_t iree_hal_amdgpu_transient_buffer_ptr(
-    iree_hal_amdgpu_transient_buffer_t* buffer, void** out_ptr) {
-  static_assert(sizeof(void*) == sizeof(iree_atomic_uint64_t),
-                "only 64-bit pointers are supported");
-  IREE_AMDGPU_DEVICE_PTR void* ptr = (void*)iree_atomic_load(
-      (iree_atomic_uint64_t*)&buffer->handle->ptr, iree_memory_order_acquire);
-  if (!ptr) {
-    *out_ptr = NULL;
-    return iree_make_status(
-        IREE_STATUS_FAILED_PRECONDITION,
-        "transient buffer has no backing allocation at this time; host usage "
-        "is only valid once an alloca has signaled completion and prior to "
-        "enqueuing any dealloca");
-  }
-  *out_ptr = ptr;
-  return iree_ok_status();
-}
-
-static iree_status_t iree_hal_amdgpu_transient_buffer_map_range(
-    iree_hal_buffer_t* base_buffer, iree_hal_mapping_mode_t mapping_mode,
-    iree_hal_memory_access_t memory_access,
-    iree_device_size_t local_byte_offset, iree_device_size_t local_byte_length,
-    iree_hal_buffer_mapping_t* mapping) {
-  iree_hal_amdgpu_transient_buffer_t* buffer =
-      iree_hal_amdgpu_transient_buffer_cast(base_buffer);
-
-  IREE_RETURN_IF_ERROR(iree_hal_buffer_validate_memory_type(
-      iree_hal_buffer_memory_type(base_buffer),
-      IREE_HAL_MEMORY_TYPE_HOST_VISIBLE));
-  IREE_RETURN_IF_ERROR(iree_hal_buffer_validate_usage(
-      iree_hal_buffer_allowed_usage(base_buffer),
-      mapping_mode == IREE_HAL_MAPPING_MODE_PERSISTENT
-          ? IREE_HAL_BUFFER_USAGE_MAPPING_PERSISTENT
-          : IREE_HAL_BUFFER_USAGE_MAPPING_SCOPED));
-
-  // Resolve device pointer (if available).
-  IREE_AMDGPU_DEVICE_PTR void* ptr = NULL;
-  IREE_RETURN_IF_ERROR(iree_hal_amdgpu_transient_buffer_ptr(buffer, &ptr));
-
-  mapping->contents = iree_make_byte_span(ptr, local_byte_length);
-
-  return iree_ok_status();
-}
-
-static iree_status_t iree_hal_amdgpu_transient_buffer_unmap_range(
-    iree_hal_buffer_t* base_buffer, iree_device_size_t local_byte_offset,
-    iree_device_size_t local_byte_length, iree_hal_buffer_mapping_t* mapping) {
-  // Nothing to do today (though maybe we may want to flush?).
-  return iree_ok_status();
-}
-
-static iree_status_t iree_hal_amdgpu_transient_buffer_invalidate_range(
-    iree_hal_buffer_t* base_buffer, iree_device_size_t local_byte_offset,
-    iree_device_size_t local_byte_length) {
-  // TODO(benvanik): anything we need to do to invalidate?
-  return iree_ok_status();
-}
-
-static iree_status_t iree_hal_amdgpu_transient_buffer_flush_range(
-    iree_hal_buffer_t* base_buffer, iree_device_size_t local_byte_offset,
-    iree_device_size_t local_byte_length) {
-  // TODO(benvanik): anything we need to do to flush?
-  return iree_ok_status();
-}
-
-static const iree_hal_buffer_vtable_t iree_hal_amdgpu_transient_buffer_vtable =
-    {
-        .recycle = iree_hal_amdgpu_transient_buffer_recycle,
-        .destroy = iree_hal_amdgpu_transient_buffer_recycle,
-        .map_range = iree_hal_amdgpu_transient_buffer_map_range,
-        .unmap_range = iree_hal_amdgpu_transient_buffer_unmap_range,
-        .invalidate_range = iree_hal_amdgpu_transient_buffer_invalidate_range,
-        .flush_range = iree_hal_amdgpu_transient_buffer_flush_range,
-};
-
-//===----------------------------------------------------------------------===//
-// Buffer Resolution
-//===----------------------------------------------------------------------===//
-
-iree_status_t iree_hal_amdgpu_resolve_buffer(
-    iree_hal_buffer_t* base_buffer,
-    iree_hal_amdgpu_device_buffer_type_t* out_type, uint64_t* out_bits) {
-  if (iree_hal_resource_is(base_buffer,
-                           &iree_hal_amdgpu_transient_buffer_vtable)) {
-    *out_type = IREE_HAL_AMDGPU_DEVICE_BUFFER_TYPE_HANDLE;
-    *out_bits =
-        (uint64_t)((iree_hal_amdgpu_transient_buffer_t*)base_buffer)->handle;
-    return iree_ok_status();
-  } else if (iree_hal_resource_is(base_buffer,
-                                  &iree_hal_amdgpu_external_buffer_vtable)) {
-    *out_type = IREE_HAL_AMDGPU_DEVICE_BUFFER_TYPE_PTR;
-    *out_bits =
-        (uint64_t)((iree_hal_amdgpu_external_buffer_t*)base_buffer)->device_ptr;
-    return iree_ok_status();
-  } else {
-    return iree_make_status(
-        IREE_STATUS_FAILED_PRECONDITION,
-        "unsupported buffer type; expected something known to the AMDGPU HAL");
-  }
-}
-
-iree_status_t iree_hal_amdgpu_resolve_transient_buffer(
-    iree_hal_buffer_t* base_buffer,
-    iree_hal_amdgpu_device_allocation_handle_t** out_handle) {
-  if (!iree_hal_resource_is(base_buffer,
-                            &iree_hal_amdgpu_transient_buffer_vtable)) {
-    return iree_make_status(
-        IREE_STATUS_INVALID_ARGUMENT,
-        "provided buffer is not a transient allocation; only buffers allocated "
-        "with iree_hal_device_queue_alloca can be deallocated using "
-        "iree_hal_device_queue_dealloca");
-  }
-  iree_hal_amdgpu_transient_buffer_t* buffer =
-      (iree_hal_amdgpu_transient_buffer_t*)base_buffer;
-  *out_handle = buffer->handle;
-  return iree_ok_status();
-}
diff --git a/runtime/src/iree/hal/drivers/amdgpu/buffer.h b/runtime/src/iree/hal/drivers/amdgpu/buffer.h
index 102cbaa..bb6aa33 100644
--- a/runtime/src/iree/hal/drivers/amdgpu/buffer.h
+++ b/runtime/src/iree/hal/drivers/amdgpu/buffer.h
@@ -1,4 +1,4 @@
-// Copyright 2025 The IREE Authors
+// Copyright 2026 The IREE Authors
 //
 // Licensed under the Apache License v2.0 with LLVM Exceptions.
 // See https://llvm.org/LICENSE.txt for license information.
@@ -8,79 +8,119 @@
 #define IREE_HAL_DRIVERS_AMDGPU_BUFFER_H_
 
 #include "iree/base/api.h"
+#include "iree/base/internal/arena.h"
+#include "iree/base/internal/atomics.h"
+#include "iree/base/threading/mutex.h"
 #include "iree/hal/api.h"
-#include "iree/hal/drivers/amdgpu/device/buffer.h"
+#include "iree/hal/drivers/amdgpu/util/libhsa.h"
 
 #ifdef __cplusplus
 extern "C" {
 #endif  // __cplusplus
 
 //===----------------------------------------------------------------------===//
-// iree_hal_amdgpu_external_buffer_t
+// iree_hal_amdgpu_buffer_t
 //===----------------------------------------------------------------------===//
 
-// Wraps an external device-accessible |device_ptr| allocation in an
-// iree_hal_buffer_t. The |release_callback| will be issued upon release of the
-// last buffer reference.
-iree_status_t iree_hal_amdgpu_external_buffer_wrap(
+typedef struct iree_hal_amdgpu_buffer_t iree_hal_amdgpu_buffer_t;
+
+// Per-physical-device pool of materialized AMDGPU HAL buffer wrappers.
+//
+// The pool only owns the host-side iree_hal_amdgpu_buffer_t storage. Backing
+// HSA memory ownership remains with each buffer: either its release callback
+// or, when no callback is set, a direct HSA free on destroy.
+typedef struct iree_hal_amdgpu_buffer_pool_t {
+  // Per-physical-device host block pool used for cold wrapper-block growth.
+  iree_arena_block_pool_t* block_pool;
+
+  // Head of the lock-free return stack pushed by AMDGPU buffer destroy.
+  iree_atomic_intptr_t return_head;
+
+  // Serializes acquire-side cache pops, return-stack migration, and growth.
+  iree_slim_mutex_t mutex;
+
+  // Head of the mutex-protected acquire-side cache.
+  iree_hal_amdgpu_buffer_t* acquire_head;
+
+  // First host block owned by this pool.
+  iree_arena_block_t* block_head;
+
+  // Last host block owned by this pool.
+  iree_arena_block_t* block_tail;
+
+#if !defined(NDEBUG)
+  // Number of wrappers currently retained by users or in-flight operations.
+  iree_atomic_int32_t live_count;
+#endif  // !defined(NDEBUG)
+} iree_hal_amdgpu_buffer_pool_t;
+
+// Initializes a per-physical-device materialized buffer wrapper pool.
+//
+// No wrapper memory is allocated until the first acquire. Wrapper storage grows
+// in blocks borrowed from |block_pool| and returned during deinitialization.
+iree_status_t iree_hal_amdgpu_buffer_pool_initialize(
+    iree_arena_block_pool_t* block_pool,
+    iree_hal_amdgpu_buffer_pool_t* out_pool);
+
+// Deinitializes the pool and releases all cold-grown wrapper blocks.
+//
+// All buffers allocated from the pool must have been released before this is
+// called. Violating that lifetime contract is a device teardown/use-after-free
+// bug and is checked in debug builds.
+void iree_hal_amdgpu_buffer_pool_deinitialize(
+    iree_hal_amdgpu_buffer_pool_t* pool);
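+
+// Example lifetime (a sketch; |host_block_pool| is assumed to be a
+// caller-owned iree_arena_block_pool_t):
+//   iree_hal_amdgpu_buffer_pool_t pool;
+//   IREE_RETURN_IF_ERROR(
+//       iree_hal_amdgpu_buffer_pool_initialize(&host_block_pool, &pool));
+//   // ... iree_hal_amdgpu_buffer_create_pooled(..., &pool, ...) ...
+//   iree_hal_amdgpu_buffer_pool_deinitialize(&pool);  // all buffers released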
+
+// Wraps an HSA memory pool allocation in an iree_hal_buffer_t.
+// If |release_callback| is null the buffer owns the HSA allocation and frees
+// it directly on destroy. Otherwise the callback owns teardown/release of the
+// wrapped memory and any associated pool bookkeeping.
+//
+// |allocation_size| is the full size of the HSA allocation and may be larger
+// than the logical |byte_length| exposed through the HAL buffer.
+iree_status_t iree_hal_amdgpu_buffer_create(
+    const iree_hal_amdgpu_libhsa_t* libhsa,
     iree_hal_buffer_placement_t placement, iree_hal_memory_type_t memory_type,
     iree_hal_memory_access_t allowed_access,
     iree_hal_buffer_usage_t allowed_usage, iree_device_size_t allocation_size,
-    iree_device_size_t byte_offset, iree_device_size_t byte_length,
-    uint64_t device_ptr, iree_hal_buffer_release_callback_t release_callback,
+    iree_device_size_t byte_length, void* host_ptr,
+    iree_hal_buffer_release_callback_t release_callback,
     iree_allocator_t host_allocator, iree_hal_buffer_t** out_buffer);
 
-//===----------------------------------------------------------------------===//
-// iree_hal_amdgpu_transient_buffer_t
-//===----------------------------------------------------------------------===//
-
-typedef struct iree_hal_amdgpu_transient_buffer_t {
-  iree_hal_buffer_t base;  // must be at 0
-
-  // Device-side allocation handle in a memory pool accessible to all agents.
-  // This may reside in host local memory.
-  IREE_AMDGPU_DEVICE_PTR iree_hal_amdgpu_device_allocation_handle_t* handle;
-
-  // Release callback that handles deallocation.
-  iree_hal_buffer_release_callback_t release_callback;
-} iree_hal_amdgpu_transient_buffer_t;
-
-// Initializes a transient buffer in-place with a 0 ref count.
-// The owning pool must increment the ref count to 1 before returning the
-// buffer to users.
-void iree_hal_amdgpu_transient_buffer_initialize(
-    iree_hal_buffer_placement_t placement,
-    iree_hal_amdgpu_device_allocation_handle_t* handle,
+// Wraps an HSA memory pool allocation in a pooled iree_hal_buffer_t wrapper.
+//
+// The returned buffer has the same memory ownership semantics as
+// iree_hal_amdgpu_buffer_create(), but its host-side wrapper storage is
+// returned to |pool| instead of |host_allocator| when the final reference is
+// released. |pool| must outlive all buffers allocated from it.
+iree_status_t iree_hal_amdgpu_buffer_create_pooled(
+    const iree_hal_amdgpu_libhsa_t* libhsa,
+    iree_hal_buffer_placement_t placement, iree_hal_memory_type_t memory_type,
+    iree_hal_memory_access_t allowed_access,
+    iree_hal_buffer_usage_t allowed_usage, iree_device_size_t allocation_size,
+    iree_device_size_t byte_length, void* host_ptr,
     iree_hal_buffer_release_callback_t release_callback,
-    iree_hal_amdgpu_transient_buffer_t* out_buffer);
+    iree_hal_amdgpu_buffer_pool_t* pool, iree_allocator_t host_allocator,
+    iree_hal_buffer_t** out_buffer);
 
-// Deinitializes a transient buffer in-place assuming it has a 0 ref count.
-void iree_hal_amdgpu_transient_buffer_deinitialize(
-    iree_hal_amdgpu_transient_buffer_t* buffer);
+// Tags |buffer| with a profiling allocation id.
+//
+// Only direct synchronous allocator buffers should be tagged. Pool/materialized
+// and queue_alloca transient buffers have their own pool/reservation event
+// streams and must not be double-counted as standalone buffer allocations.
+void iree_hal_amdgpu_buffer_set_profile_allocation(
+    iree_hal_buffer_t* buffer, uint64_t session_id, uint64_t allocation_id,
+    uint64_t pool_id, uint32_t physical_device_ordinal,
+    iree_device_size_t alignment);
 
-// Resets |buffer| to the given parameters as if it had just been allocated.
-void iree_hal_amdgpu_transient_buffer_reset(
-    iree_hal_amdgpu_transient_buffer_t* buffer, iree_hal_buffer_params_t params,
-    iree_device_size_t allocation_size, iree_device_size_t byte_offset,
-    iree_device_size_t byte_length);
-
-//===----------------------------------------------------------------------===//
-// Buffer Resolution
-//===----------------------------------------------------------------------===//
-
-// Resolves a HAL buffer to a device-side type and pointer/handle.
-// Returns success if the buffer is of a type that can be used toll-free on any
-// device but does not verify the memory referenced is accessible to any
-// particular device.
-iree_status_t iree_hal_amdgpu_resolve_buffer(
-    iree_hal_buffer_t* buffer, iree_hal_amdgpu_device_buffer_type_t* out_type,
-    uint64_t* out_bits);
-
-// Resolves a HAL buffer that is required to be a transient buffer allocated via
-// iree_hal_device_queue_alloca. Fails if the buffer is any other type.
-iree_status_t iree_hal_amdgpu_resolve_transient_buffer(
-    iree_hal_buffer_t* buffer,
-    iree_hal_amdgpu_device_allocation_handle_t** out_handle);
+// Returns the HSA-allocated base pointer for the given |buffer|, or NULL if
+// |buffer| is not an AMDGPU buffer or is a transient buffer with no backing
+// allocation. HSA uses unified virtual addressing so the returned pointer
+// addresses the same memory from both host and GPU.
+//
+// The returned pointer spans the entire allocation; callers must add
+// iree_hal_buffer_byte_offset() and any binding offset when computing kernarg
+// binding addresses. |buffer| must be the allocated buffer (not a subspan);
+// use iree_hal_buffer_allocated_buffer() to unwrap first.
+void* iree_hal_amdgpu_buffer_device_pointer(iree_hal_buffer_t* buffer);
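+//
+// Example (a sketch; |binding| is a hypothetical struct carrying a HAL buffer
+// and an additional byte offset):
+//   iree_hal_buffer_t* allocated =
+//       iree_hal_buffer_allocated_buffer(binding.buffer);
+//   uint8_t* base =
+//       (uint8_t*)iree_hal_amdgpu_buffer_device_pointer(allocated);
+//   void* kernarg_ptr =
+//       base + iree_hal_buffer_byte_offset(binding.buffer) + binding.offset;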
 
 #ifdef __cplusplus
 }  // extern "C"
diff --git a/runtime/src/iree/hal/drivers/amdgpu/buffer_pool.c b/runtime/src/iree/hal/drivers/amdgpu/buffer_pool.c
deleted file mode 100644
index 8417340..0000000
--- a/runtime/src/iree/hal/drivers/amdgpu/buffer_pool.c
+++ /dev/null
@@ -1,437 +0,0 @@
-// Copyright 2025 The IREE Authors
-//
-// Licensed under the Apache License v2.0 with LLVM Exceptions.
-// See https://llvm.org/LICENSE.txt for license information.
-// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
-
-#include "iree/hal/drivers/amdgpu/buffer_pool.h"
-
-#include "iree/hal/drivers/amdgpu/buffer.h"
-#include "iree/hal/drivers/amdgpu/device/buffer.h"
-#include "iree/hal/drivers/amdgpu/util/topology.h"
-
-static void iree_hal_amdgpu_buffer_pool_link_free_block(
-    iree_hal_amdgpu_buffer_pool_t* buffer_pool,
-    iree_hal_amdgpu_buffer_pool_block_t* block);
-
-//===----------------------------------------------------------------------===//
-// iree_hal_amdgpu_buffer_pool_block_t
-//===----------------------------------------------------------------------===//
-
-// A block of allocated buffers. Manages both host heap memory and
-// device-visible memory for the device-side library resources.
-//
-// Thread-safe; each block has its own lock for free list management.
-typedef struct iree_hal_amdgpu_buffer_pool_block_t {
-  // Pool that owns this block.
-  iree_hal_amdgpu_buffer_pool_t* buffer_pool;
-  // Previous block in the pool block linked list.
-  iree_hal_amdgpu_buffer_pool_block_t* prev_block;
-  // Next block in the pool block linked list.
-  iree_hal_amdgpu_buffer_pool_block_t* next_block;
-  // Next block in the pool block linked list with free entries.
-  iree_hal_amdgpu_buffer_pool_block_t* next_free;
-  // Capacity of the block in buffers.
-  iree_host_size_t capacity;
-  // Device memory base pointer used for
-  // `iree_hal_amdgpu_device_allocation_handle_t`.
-  IREE_AMDGPU_DEVICE_PTR uint8_t* device_allocation_ptr;
-  // Mutex guarding the mutable block fields.
-  iree_slim_mutex_t mutex;
-  // Count of free buffers in the block stored in the free_list.
-  iree_host_size_t free_count IREE_GUARDED_BY(mutex);
-  // Free buffers that are available for use.
-  iree_hal_amdgpu_transient_buffer_t* free_list[/*capacity*/] IREE_GUARDED_BY(
-      mutex);
-  // Trailing list of iree_hal_amdgpu_transient_buffer_t[capacity].
-} iree_hal_amdgpu_buffer_pool_block_t;
-
-static void iree_hal_amdgpu_buffer_pool_block_free(
-    iree_hal_amdgpu_buffer_pool_block_t* block);
-static void iree_hal_amdgpu_buffer_pool_block_recycle(
-    void* user_data, iree_hal_buffer_t* base_buffer);
-
-// Allocates a block of |capacity| buffers on host and device.
-static iree_status_t iree_hal_amdgpu_buffer_pool_block_allocate(
-    iree_hal_amdgpu_buffer_pool_t* buffer_pool, iree_host_size_t capacity,
-    iree_hal_amdgpu_buffer_pool_block_t** out_block) {
-  IREE_ASSERT_ARGUMENT(out_block);
-  IREE_TRACE_ZONE_BEGIN(z0);
-  *out_block = NULL;
-
-  // Allocate and initialize host memory.
-  iree_hal_amdgpu_buffer_pool_block_t* block = NULL;
-  const iree_host_size_t free_list_size =
-      capacity * sizeof(block->free_list[0]);
-  const iree_host_size_t total_block_size =
-      sizeof(*block) + free_list_size +
-      capacity * sizeof(iree_hal_amdgpu_transient_buffer_t);
-  IREE_RETURN_AND_END_ZONE_IF_ERROR(
-      z0, iree_allocator_malloc(buffer_pool->host_allocator, total_block_size,
-                                (void**)&block));
-  block->buffer_pool = buffer_pool;
-  block->prev_block = NULL;
-  block->next_block = NULL;
-  block->next_free = NULL;
-  block->capacity = capacity;
-  block->device_allocation_ptr = NULL;
-  iree_slim_mutex_initialize(&block->mutex);
-
-  // Allocate device memory from the device memory pool.
-  const iree_host_size_t total_device_size =
-      capacity * sizeof(iree_hal_amdgpu_device_allocation_handle_t);
-  iree_status_t status = iree_hsa_amd_memory_pool_allocate(
-      IREE_LIBHSA(buffer_pool->libhsa), buffer_pool->memory_pool,
-      total_device_size, HSA_AMD_MEMORY_POOL_STANDARD_FLAG,
-      (void**)&block->device_allocation_ptr);
-
-  // Make the allocation visible to all devices.
-  if (iree_status_is_ok(status)) {
-    status = iree_hsa_amd_agents_allow_access(
-        IREE_LIBHSA(buffer_pool->libhsa),
-        buffer_pool->topology->all_agent_count,
-        buffer_pool->topology->all_agents, /*flags=*/NULL,
-        block->device_allocation_ptr);
-  }
-
-  // Initialize each host buffer and build the free list.
-  if (iree_status_is_ok(status)) {
-    iree_hal_amdgpu_transient_buffer_t* base_host_ptr =
-        (iree_hal_amdgpu_transient_buffer_t*)((uint8_t*)block + sizeof(*block) +
-                                              free_list_size);
-    iree_hal_amdgpu_device_allocation_handle_t* base_device_ptr =
-        (iree_hal_amdgpu_device_allocation_handle_t*)
-            block->device_allocation_ptr;
-    block->free_count = capacity;
-    iree_hal_buffer_release_callback_t release_callback = {
-        .fn = iree_hal_amdgpu_buffer_pool_block_recycle,
-        .user_data = block,
-    };
-    for (iree_host_size_t i = 0; i < capacity; ++i) {
-      iree_hal_amdgpu_transient_buffer_t* buffer = &base_host_ptr[i];
-      iree_hal_amdgpu_device_allocation_handle_t* handle = &base_device_ptr[i];
-      iree_hal_amdgpu_transient_buffer_initialize(
-          buffer_pool->placement, handle, release_callback, buffer);
-      block->free_list[i] = buffer;
-    }
-  }
-
-  if (iree_status_is_ok(status)) {
-    *out_block = block;
-  } else {
-    iree_hal_amdgpu_buffer_pool_block_free(block);
-  }
-  IREE_TRACE_ZONE_END(z0);
-  return status;
-}
-
-// Frees a |block| of buffers and its device memory.
-static void iree_hal_amdgpu_buffer_pool_block_free(
-    iree_hal_amdgpu_buffer_pool_block_t* block) {
-  IREE_ASSERT_ARGUMENT(block);
-  IREE_TRACE_ZONE_BEGIN(z0);
-
-  iree_slim_mutex_lock(&block->mutex);
-  IREE_ASSERT_EQ(block->free_count, block->capacity);
-  iree_slim_mutex_unlock(&block->mutex);
-
-  // Deinitialize all host buffers. They are allocated as part of the block
-  // and only need to be cleaned up.
-  iree_hal_amdgpu_transient_buffer_t* base_host_ptr =
-      (iree_hal_amdgpu_transient_buffer_t*)((uint8_t*)block +
-                                            block->capacity *
-                                                sizeof(block->free_list[0]));
-  for (iree_host_size_t i = 0; i < block->capacity; ++i) {
-    iree_hal_amdgpu_transient_buffer_t* buffer = &base_host_ptr[i];
-    iree_hal_amdgpu_transient_buffer_deinitialize(buffer);
-  }
-
-  // Deallocate device memory.
-  if (block->device_allocation_ptr) {
-    IREE_IGNORE_ERROR(iree_hsa_amd_memory_pool_free(
-        IREE_LIBHSA(block->buffer_pool->libhsa), block->device_allocation_ptr));
-    block->device_allocation_ptr = NULL;
-  }
-
-  // Frees the block and its embedded storage.
-  iree_slim_mutex_deinitialize(&block->mutex);
-  iree_allocator_free(block->buffer_pool->host_allocator, block);
-
-  IREE_TRACE_ZONE_END(z0);
-}
-
-// Recycles a buffer after it has no remaining uses.
-static void iree_hal_amdgpu_buffer_pool_block_recycle(
-    void* user_data, iree_hal_buffer_t* base_buffer) {
-  iree_hal_amdgpu_buffer_pool_block_t* block =
-      (iree_hal_amdgpu_buffer_pool_block_t*)user_data;
-  iree_hal_amdgpu_transient_buffer_t* buffer =
-      (iree_hal_amdgpu_transient_buffer_t*)base_buffer;
-
-  // Buffer should have zero references before being recycled.
-  IREE_ASSERT_REF_COUNT_ZERO(&base_buffer->resource.ref_count);
-
-  // Add to the block free list.
-  iree_slim_mutex_lock(&block->mutex);
-
-  const bool full_to_free = block->free_count == 0;
-  block->free_list[block->free_count++] = buffer;
-
-  iree_slim_mutex_unlock(&block->mutex);
-
-  // If the block has gone from 0 to >0 free entries then link it back into the
-  // pool free list for use. Note that we can only do this on the transition
-  // from full to free as otherwise the block is already in the free list.
-  //
-  // NOTE: this happens outside of the per-block lock as the pool will hold its
-  // lock over the free list while acquiring a new entry. This may lead to
-  // (safe) races where an acquire checks the free list while we are updating
-  // the block above but before we update the free list but that's rare and
-  // bounded (there may be one extra block in the pool).
-  if (full_to_free) {
-    iree_hal_amdgpu_buffer_pool_link_free_block(block->buffer_pool, block);
-  }
-}
-
-//===----------------------------------------------------------------------===//
-// iree_hal_amdgpu_buffer_pool_t
-//===----------------------------------------------------------------------===//
-
-iree_status_t iree_hal_amdgpu_buffer_pool_initialize(
-    const iree_hal_amdgpu_libhsa_t* libhsa,
-    const iree_hal_amdgpu_topology_t* topology,
-    iree_hal_buffer_placement_t placement, iree_host_size_t block_capacity,
-    iree_allocator_t host_allocator, hsa_amd_memory_pool_t memory_pool,
-    iree_hal_amdgpu_buffer_pool_t* out_buffer_pool) {
-  IREE_ASSERT_ARGUMENT(libhsa);
-  IREE_ASSERT_ARGUMENT(topology);
-  IREE_ASSERT_ARGUMENT(out_buffer_pool);
-  IREE_TRACE_ZONE_BEGIN(z0);
-
-  out_buffer_pool->libhsa = libhsa;
-  out_buffer_pool->topology = topology;
-  out_buffer_pool->placement = placement;
-  out_buffer_pool->host_allocator = host_allocator;
-  out_buffer_pool->memory_pool = memory_pool;
-
-  // TODO(benvanik): set a pool handle we can use to route requests across the
-  // host service.
-  out_buffer_pool->pool = 0;
-
-  iree_slim_mutex_initialize(&out_buffer_pool->mutex);
-  out_buffer_pool->list_head = NULL;
-  out_buffer_pool->free_head = NULL;
-
-  // Query the memory pool for its allocation granularity.
-  // This is not the minimum allocation size
-  // (HSA_AMD_MEMORY_POOL_INFO_RUNTIME_ALLOC_GRANULE) but the recommended size
-  // to prevent internal fragmentation. We will always make allocations of this
-  // size and adjust the block capacity to match.
-  size_t alloc_rec_granule = 0;
-  IREE_RETURN_AND_END_ZONE_IF_ERROR(
-      z0,
-      iree_hsa_amd_memory_pool_get_info(
-          IREE_LIBHSA(libhsa), memory_pool,
-          HSA_AMD_MEMORY_POOL_INFO_RUNTIME_ALLOC_REC_GRANULE,
-          &alloc_rec_granule),
-      "querying HSA_AMD_MEMORY_POOL_INFO_RUNTIME_ALLOC_REC_GRANULE to "
-      "determine block capacity");
-  IREE_TRACE_ZONE_APPEND_VALUE_I64(z0, alloc_rec_granule);
-
-  // Allocate aligned to the recommended allocation granularity.
-  // We'll always be some multiple of the recommended size so we waste no device
-  // space.
-  const iree_host_size_t min_capacity_per_allocation = iree_host_size_ceil_div(
-      alloc_rec_granule, sizeof(iree_hal_amdgpu_device_allocation_handle_t));
-  const iree_host_size_t capacity_per_allocation =
-      iree_host_size_ceil_div(block_capacity, min_capacity_per_allocation) *
-      min_capacity_per_allocation;
-  out_buffer_pool->block_capacity = capacity_per_allocation;
-
-  IREE_TRACE_ZONE_END(z0);
-  return iree_ok_status();
-}
-
-void iree_hal_amdgpu_buffer_pool_deinitialize(
-    iree_hal_amdgpu_buffer_pool_t* buffer_pool) {
-  IREE_ASSERT_ARGUMENT(buffer_pool);
-  IREE_TRACE_ZONE_BEGIN(z0);
-
-  iree_slim_mutex_lock(&buffer_pool->mutex);
-  iree_hal_amdgpu_buffer_pool_block_t* block = buffer_pool->list_head;
-  while (block != NULL) {
-    iree_hal_amdgpu_buffer_pool_block_t* next_block = block->next_block;
-    IREE_ASSERT_EQ(block->free_count, block->capacity);
-    iree_hal_amdgpu_buffer_pool_block_free(block);
-    block = next_block;
-  }
-  buffer_pool->list_head = NULL;
-  buffer_pool->free_head = NULL;
-  iree_slim_mutex_unlock(&buffer_pool->mutex);
-
-  iree_slim_mutex_deinitialize(&buffer_pool->mutex);
-
-  IREE_TRACE_ZONE_END(z0);
-}
-
-// Grows the |buffer_pool| by one block.
-// Requires the pool lock be held.
-static iree_status_t iree_hal_amdgpu_buffer_pool_grow(
-    iree_hal_amdgpu_buffer_pool_t* buffer_pool) {
-  IREE_ASSERT_ARGUMENT(buffer_pool);
-  IREE_TRACE_ZONE_BEGIN(z0);
-
-  // Allocate the new block and its resources.
-  iree_hal_amdgpu_buffer_pool_block_t* block = NULL;
-  IREE_RETURN_AND_END_ZONE_IF_ERROR(
-      z0, iree_hal_amdgpu_buffer_pool_block_allocate(
-              buffer_pool, buffer_pool->block_capacity, &block));
-
-  // Link the block into the allocated list and the free list.
-  block->prev_block = NULL;
-  block->next_block = buffer_pool->list_head;
-  if (block->next_block) {
-    block->next_block->prev_block = block;
-  }
-  buffer_pool->list_head = block;
-  block->next_free = buffer_pool->free_head;
-  buffer_pool->free_head = block;
-
-  IREE_TRACE_ZONE_END(z0);
-  return iree_ok_status();
-}
-
-iree_status_t iree_hal_amdgpu_buffer_pool_preallocate(
-    iree_hal_amdgpu_buffer_pool_t* buffer_pool, iree_host_size_t count) {
-  IREE_ASSERT_ARGUMENT(buffer_pool);
-  IREE_TRACE_ZONE_BEGIN(z0);
-  IREE_TRACE_ZONE_APPEND_VALUE_I64(z0, count);
-
-  iree_status_t status = iree_ok_status();
-  if (count > 0) {
-    const iree_host_size_t block_count =
-        iree_host_size_ceil_div(count, buffer_pool->block_capacity);
-    for (iree_host_size_t i = 0; iree_status_is_ok(status) && i < block_count;
-         ++i) {
-      status = iree_hal_amdgpu_buffer_pool_grow(buffer_pool);
-    }
-  }
-
-  IREE_TRACE_ZONE_END(z0);
-  return status;
-}
-
-iree_status_t iree_hal_amdgpu_buffer_pool_acquire(
-    iree_hal_amdgpu_buffer_pool_t* buffer_pool, iree_hal_buffer_params_t params,
-    iree_device_size_t allocation_size, iree_hal_buffer_t** out_buffer,
-    iree_hal_amdgpu_device_allocation_handle_t** out_handle) {
-  IREE_ASSERT_ARGUMENT(buffer_pool);
-  IREE_ASSERT_ARGUMENT(out_buffer);
-  IREE_ASSERT_ARGUMENT(out_handle);
-  IREE_TRACE_ZONE_BEGIN(z0);
-  *out_buffer = NULL;
-  *out_handle = NULL;
-
-  iree_slim_mutex_lock(&buffer_pool->mutex);
-
-  // If there are no blocks with free buffers allocate a new one.
-  iree_status_t status = iree_ok_status();
-  if (buffer_pool->free_head == NULL) {
-    // TODO(benvanik): do this outside of the lock? This allocates device
-    // resources. We could have an exclusive growth lock that does not block
-    // recycling.
-    status = iree_hal_amdgpu_buffer_pool_grow(buffer_pool);
-  }
-
-  // Get the next free buffer and possibly maintain the free list.
-  iree_hal_amdgpu_transient_buffer_t* buffer = NULL;
-  if (iree_status_is_ok(status)) {
-    // Pop the last free buffer from the block.
-    iree_hal_amdgpu_buffer_pool_block_t* block = buffer_pool->free_head;
-    buffer = block->free_list[block->free_count - 1];
-    block->free_list[block->free_count - 1] = NULL;
-    --block->free_count;
-
-    // If there are no more free buffers in the block remove it from the
-    // free list.
-    if (block->free_count == 0) {
-      buffer_pool->free_head = block->next_free;
-      block->next_free = NULL;
-    }
-  }
-
-  iree_slim_mutex_unlock(&buffer_pool->mutex);
-
-  if (iree_status_is_ok(status)) {
-    // Reset buffer properties to those requested.
-    iree_hal_amdgpu_transient_buffer_reset(buffer, params, allocation_size,
-                                           /*byte_offset=*/0,
-                                           /*byte_length=*/allocation_size);
-
-    // Return with a 1 ref count as if we had allocated it.
-    iree_atomic_ref_count_inc(&buffer->base.resource.ref_count);
-    *out_buffer = &buffer->base;
-    *out_handle = buffer->handle;
-  }
-  IREE_TRACE_ZONE_END(z0);
-  return status;
-}
-
-// Links |block| into the |buffer_pool| free list.
-// Must not already be in the list.
-// The block is inserted at the head to try to have new acquisitions reuse it
-// before any others and keep the utilization high.
-static void iree_hal_amdgpu_buffer_pool_link_free_block(
-    iree_hal_amdgpu_buffer_pool_t* buffer_pool,
-    iree_hal_amdgpu_buffer_pool_block_t* block) {
-  iree_slim_mutex_lock(&buffer_pool->mutex);
-  block->next_free = buffer_pool->free_head;
-  buffer_pool->free_head = block;
-  iree_slim_mutex_unlock(&buffer_pool->mutex);
-}
-
-void iree_hal_amdgpu_buffer_pool_trim(
-    iree_hal_amdgpu_buffer_pool_t* buffer_pool) {
-  IREE_ASSERT_ARGUMENT(buffer_pool);
-  IREE_TRACE_ZONE_BEGIN(z0);
-
-  // Walk each block in the free list. If all buffers are free then drop it.
-  iree_slim_mutex_lock(&buffer_pool->mutex);
-  iree_hal_amdgpu_buffer_pool_block_t* prev_block = NULL;
-  iree_hal_amdgpu_buffer_pool_block_t* block = buffer_pool->free_head;
-  while (block != NULL) {
-    iree_hal_amdgpu_buffer_pool_block_t* next_block = block->next_free;
-    if (block->free_count != block->capacity) {
-      // One or more buffers in use - cannot free the block.
-      prev_block = block;
-      block = next_block;
-      continue;
-    }
-
-    // Unlink the block from the free list.
-    if (prev_block != NULL) {
-      prev_block->next_free = next_block;
-    } else {
-      buffer_pool->free_head = next_block;
-    }
-
-    // Unlink the block from the main list.
-    if (block->prev_block != NULL) {
-      block->prev_block->next_block = block->next_block;
-    } else {
-      buffer_pool->list_head = block->next_block;
-    }
-    if (block->next_block != NULL) {
-      block->next_block->prev_block = block->prev_block;
-    }
-
-    // Free the block and its resources.
-    iree_hal_amdgpu_buffer_pool_block_free(block);
-
-    block = next_block;
-  }
-
-  iree_slim_mutex_unlock(&buffer_pool->mutex);
-
-  IREE_TRACE_ZONE_END(z0);
-}
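
For reviewers skimming the free-list logic above: every block lives on a doubly-linked list of all blocks, and any block with at least one free buffer also sits on a singly-linked free list threaded through `next_free`. Acquisitions pop from the head block and recycling relinks at the head so warm blocks get reused first; trim only reclaims blocks whose `free_count` equals their capacity. A minimal standalone sketch of that discipline (simplified illustrative types, not the driver's):

```c
typedef struct block_t {
  struct block_t* next_free;  // links blocks that still have free buffers
  int free_count;             // valid entries in free_list[0..free_count)
  int capacity;
  void* free_list[8];         // inline storage; the driver sizes this per pool
} block_t;

// Pop one buffer from the first block with any free buffers; mirrors the
// acquire path above (growth is assumed to have ensured a non-NULL head).
static void* pool_acquire(block_t** free_head) {
  block_t* block = *free_head;
  void* buffer = block->free_list[--block->free_count];
  if (block->free_count == 0) {
    *free_head = block->next_free;  // fully allocated; drop from free list
    block->next_free = NULL;
  }
  return buffer;
}

// Return one buffer; on the 0->1 transition the block rejoins the free list
// at the head (mirrors iree_hal_amdgpu_buffer_pool_link_free_block).
static void pool_recycle(block_t** free_head, block_t* block, void* buffer) {
  block->free_list[block->free_count++] = buffer;
  if (block->free_count == 1) {
    block->next_free = *free_head;
    *free_head = block;
  }
}
```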
diff --git a/runtime/src/iree/hal/drivers/amdgpu/buffer_pool.h b/runtime/src/iree/hal/drivers/amdgpu/buffer_pool.h
deleted file mode 100644
index 7f8dd1c..0000000
--- a/runtime/src/iree/hal/drivers/amdgpu/buffer_pool.h
+++ /dev/null
@@ -1,116 +0,0 @@
-// Copyright 2025 The IREE Authors
-//
-// Licensed under the Apache License v2.0 with LLVM Exceptions.
-// See https://llvm.org/LICENSE.txt for license information.
-// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
-
-#ifndef IREE_HAL_DRIVERS_AMDGPU_BUFFER_POOL_H_
-#define IREE_HAL_DRIVERS_AMDGPU_BUFFER_POOL_H_
-
-#include "iree/base/api.h"
-#include "iree/base/threading/mutex.h"
-#include "iree/hal/api.h"
-#include "iree/hal/drivers/amdgpu/util/libhsa.h"
-
-#ifdef __cplusplus
-extern "C" {
-#endif  // __cplusplus
-
-typedef struct iree_hal_amdgpu_device_allocation_handle_t
-    iree_hal_amdgpu_device_allocation_handle_t;
-
-typedef struct iree_hal_amdgpu_buffer_pool_block_t
-    iree_hal_amdgpu_buffer_pool_block_t;
-typedef struct iree_hal_amdgpu_topology_t iree_hal_amdgpu_topology_t;
-
-//===----------------------------------------------------------------------===//
-// iree_hal_amdgpu_buffer_pool_t
-//===----------------------------------------------------------------------===//
-
-// Default buffer count per block in the pool.
-// Larger is better to reduce the number of device memory allocations but we
-// don't want to have too high of a fixed overhead. Most programs only have a
-// few dozen live buffers at a time but some with heavy async behavior may have
-// many more for outstanding async allocations.
-#define IREE_HAL_AMDGPU_BUFFER_POOL_DEFAULT_BLOCK_CAPACITY (2 * 1024)
-
-// A pool of transient buffers and their corresponding device handles.
-// Buffers are allocated in blocks to reduce the number of device allocations
-// we make (as some devices/drivers may have limits). Blocks are allocated
-// on-demand and contain a fixed-size set of HAL buffers allocated inline.
-//
-// Thread-safe; multiple host threads may share the same pool.
-typedef struct iree_hal_amdgpu_buffer_pool_t {
-  // Unowned libhsa handle. Must be retained by the owner.
-  const iree_hal_amdgpu_libhsa_t* libhsa;
-  // Topology with all CPU and GPU agents. Buffers must be visible to all.
-  const iree_hal_amdgpu_topology_t* topology;
-
-  // Placement of all buffers allocated from this pool. This is the logical HAL
-  // device and queues that map to the physical device the pool is for.
-  iree_hal_buffer_placement_t placement;
-
-  // Allocator used for host allocations.
-  iree_allocator_t host_allocator;
-  // Device memory pool for device allocations.
-  hsa_amd_memory_pool_t memory_pool;
-
-  // Unused/opaque pool ID.
-  // TODO(benvanik): make this link back to this pool somehow.
-  uint64_t pool;
-
-  // Capacity of each block in buffers.
-  // Most likely IREE_HAL_AMDGPU_BUFFER_POOL_DEFAULT_BLOCK_CAPACITY (rounded up
-  // to the recommended allocation granularity).
-  iree_host_size_t block_capacity;
-
-  // Guards pool resources during acquisition.
-  iree_slim_mutex_t mutex;
-  // A doubly-linked list of all allocated blocks.
-  iree_hal_amdgpu_buffer_pool_block_t* list_head IREE_GUARDED_BY(mutex);
-  // A singly-linked list of blocks that have one or more free buffers.
-  iree_hal_amdgpu_buffer_pool_block_t* free_head IREE_GUARDED_BY(mutex);
-} iree_hal_amdgpu_buffer_pool_t;
-
-// Initializes |out_buffer_pool| for use. Performs no allocation.
-// Buffers will be usable on all GPU devices in |topology|. |placement| is used
-// as the base placement for all buffers but exact queues will be assigned when
-// buffers are acquired. Device-accessible allocation handle storage will be
-// allocated from |memory_pool| as needed (not the actual buffers - just
-// iree_hal_amdgpu_device_allocation_handle_t).
-iree_status_t iree_hal_amdgpu_buffer_pool_initialize(
-    const iree_hal_amdgpu_libhsa_t* libhsa,
-    const iree_hal_amdgpu_topology_t* topology,
-    iree_hal_buffer_placement_t placement, iree_host_size_t block_capacity,
-    iree_allocator_t host_allocator, hsa_amd_memory_pool_t memory_pool,
-    iree_hal_amdgpu_buffer_pool_t* out_buffer_pool);
-
-// Deinitializes |buffer_pool| and releases underlying memory.
-// All buffers created from the pool must have been released back to it.
-void iree_hal_amdgpu_buffer_pool_deinitialize(
-    iree_hal_amdgpu_buffer_pool_t* buffer_pool);
-
-// Preallocates |count| buffer handles and adds them to the pool free list.
-iree_status_t iree_hal_amdgpu_buffer_pool_preallocate(
-    iree_hal_amdgpu_buffer_pool_t* buffer_pool, iree_host_size_t count);
-
-// Acquires an |allocation_size| buffer from the pool with the given |params|.
-// The pool must remain live until the returned |out_buffer| has been fully
-// recycled (ref count 0). The returned |out_handle| is the device-side handle
-// owned by the buffer and is provided to callers as a convenience; it can be
-// accessed from the buffer in the future using
-// iree_hal_amdgpu_resolve_transient_buffer.
-iree_status_t iree_hal_amdgpu_buffer_pool_acquire(
-    iree_hal_amdgpu_buffer_pool_t* buffer_pool, iree_hal_buffer_params_t params,
-    iree_device_size_t allocation_size, iree_hal_buffer_t** out_buffer,
-    iree_hal_amdgpu_device_allocation_handle_t** out_handle);
-
-// Trims all blocks that have no allocated buffers.
-void iree_hal_amdgpu_buffer_pool_trim(
-    iree_hal_amdgpu_buffer_pool_t* buffer_pool);
-
-#ifdef __cplusplus
-}  // extern "C"
-#endif  // __cplusplus
-
-#endif  // IREE_HAL_DRIVERS_AMDGPU_BUFFER_POOL_H_
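
Putting the declarations above together, a minimal caller-side sketch (the `example_pool_roundtrip` wrapper is hypothetical; the unit tests in the next file exercise the same flow with real assertions):

```c
static iree_status_t example_pool_roundtrip(
    const iree_hal_amdgpu_libhsa_t* libhsa,
    const iree_hal_amdgpu_topology_t* topology,
    iree_hal_buffer_placement_t placement, iree_hal_buffer_params_t params,
    hsa_amd_memory_pool_t memory_pool, iree_allocator_t host_allocator) {
  iree_hal_amdgpu_buffer_pool_t pool;
  IREE_RETURN_IF_ERROR(iree_hal_amdgpu_buffer_pool_initialize(
      libhsa, topology, placement,
      IREE_HAL_AMDGPU_BUFFER_POOL_DEFAULT_BLOCK_CAPACITY, host_allocator,
      memory_pool, &pool));

  iree_hal_buffer_t* buffer = NULL;
  iree_hal_amdgpu_device_allocation_handle_t* handle = NULL;
  iree_status_t status = iree_hal_amdgpu_buffer_pool_acquire(
      &pool, params, /*allocation_size=*/4096, &buffer, &handle);

  if (iree_status_is_ok(status)) {
    // ... queue-ordered alloca/dealloca would target |handle| here ...
    iree_hal_buffer_release(buffer);  // last release recycles into the pool
  }

  iree_hal_amdgpu_buffer_pool_trim(&pool);  // drop fully-idle blocks
  iree_hal_amdgpu_buffer_pool_deinitialize(&pool);
  return status;
}
```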
diff --git a/runtime/src/iree/hal/drivers/amdgpu/buffer_pool_test.cc b/runtime/src/iree/hal/drivers/amdgpu/buffer_pool_test.cc
deleted file mode 100644
index 9c4ee55..0000000
--- a/runtime/src/iree/hal/drivers/amdgpu/buffer_pool_test.cc
+++ /dev/null
@@ -1,278 +0,0 @@
-// Copyright 2025 The IREE Authors
-//
-// Licensed under the Apache License v2.0 with LLVM Exceptions.
-// See https://llvm.org/LICENSE.txt for license information.
-// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
-
-#include "iree/hal/drivers/amdgpu/buffer_pool.h"
-
-#include <vector>
-
-#include "iree/base/api.h"
-#include "iree/hal/drivers/amdgpu/buffer.h"
-#include "iree/hal/drivers/amdgpu/device/buffer.h"
-#include "iree/hal/drivers/amdgpu/util/topology.h"
-#include "iree/hal/drivers/amdgpu/util/vmem.h"
-#include "iree/testing/gtest.h"
-#include "iree/testing/status_matchers.h"
-
-namespace iree::hal::amdgpu {
-namespace {
-
-using iree::testing::status::StatusIs;
-
-struct BufferPoolTest : public ::testing::Test {
-  static iree_allocator_t host_allocator;
-  static iree_hal_amdgpu_libhsa_t libhsa;
-  static iree_hal_amdgpu_topology_t topology;
-  static hsa_amd_memory_pool_t cpu_memory_pool;
-
-  static void SetUpTestSuite() {
-    IREE_TRACE_SCOPE();
-    host_allocator = iree_allocator_system();
-    iree_status_t status = iree_hal_amdgpu_libhsa_initialize(
-        IREE_HAL_AMDGPU_LIBHSA_FLAG_NONE, iree_string_view_list_empty(),
-        host_allocator, &libhsa);
-    if (!iree_status_is_ok(status)) {
-      iree_status_fprint(stderr, status);
-      iree_status_ignore(status);
-      GTEST_SKIP() << "HSA not available, skipping tests";
-    }
-    IREE_ASSERT_OK(
-        iree_hal_amdgpu_topology_initialize_with_defaults(&libhsa, &topology));
-    if (topology.gpu_agent_count == 0) {
-      GTEST_SKIP() << "no GPU devices available, skipping tests";
-    }
-
-    hsa_agent_t cpu_agent = topology.cpu_agents[0];
-    IREE_ASSERT_OK(iree_hal_amdgpu_find_fine_global_memory_pool(
-        &libhsa, cpu_agent, &cpu_memory_pool));
-  }
-
-  static void TearDownTestSuite() {
-    IREE_TRACE_SCOPE();
-    iree_hal_amdgpu_topology_deinitialize(&topology);
-    iree_hal_amdgpu_libhsa_deinitialize(&libhsa);
-  }
-};
-iree_allocator_t BufferPoolTest::host_allocator;
-iree_hal_amdgpu_libhsa_t BufferPoolTest::libhsa;
-iree_hal_amdgpu_topology_t BufferPoolTest::topology;
-hsa_amd_memory_pool_t BufferPoolTest::cpu_memory_pool;
-
-// Tests that a pool can be initialized/deinitialized successfully.
-// Note that pools do not allocate anything on initialization so this should
-// never allocate.
-TEST_F(BufferPoolTest, Lifetime) {
-  IREE_TRACE_SCOPE();
-
-  iree_hal_buffer_placement_t placement = {
-      /*.device=*/NULL,  // not available in test
-      /*.queue_affinity=*/IREE_HAL_QUEUE_AFFINITY_ANY,
-      /*.flags=*/IREE_HAL_BUFFER_PLACEMENT_FLAG_ASYNCHRONOUS,
-  };
-  iree_hal_amdgpu_buffer_pool_t buffer_pool = {0};
-  IREE_ASSERT_OK(iree_hal_amdgpu_buffer_pool_initialize(
-      &libhsa, &topology, placement,
-      IREE_HAL_AMDGPU_BUFFER_POOL_DEFAULT_BLOCK_CAPACITY, host_allocator,
-      cpu_memory_pool, &buffer_pool));
-
-  // No-op since nothing has been allocated.
-  iree_hal_amdgpu_buffer_pool_trim(&buffer_pool);
-
-  iree_hal_amdgpu_buffer_pool_deinitialize(&buffer_pool);
-}
-
-// Tests a pool that has preallocation requests.
-// We make a few requests interleaved with trims and then rely on
-// deinitialization to free the remaining resources to ensure there are no
-// leaks.
-TEST_F(BufferPoolTest, LifetimePreallocate) {
-  IREE_TRACE_SCOPE();
-
-  iree_hal_buffer_placement_t placement = {
-      /*.device=*/NULL,  // not available in test
-      /*.queue_affinity=*/IREE_HAL_QUEUE_AFFINITY_ANY,
-      /*.flags=*/IREE_HAL_BUFFER_PLACEMENT_FLAG_ASYNCHRONOUS,
-  };
-  iree_hal_amdgpu_buffer_pool_t buffer_pool = {0};
-  IREE_ASSERT_OK(iree_hal_amdgpu_buffer_pool_initialize(
-      &libhsa, &topology, placement,
-      /*block_capacity=*/32, host_allocator, cpu_memory_pool, &buffer_pool));
-
-  // No-op since nothing has been allocated yet.
-  iree_hal_amdgpu_buffer_pool_trim(&buffer_pool);
-
-  // No-op preallocation (can happen if we blindly pass options/flags of 0).
-  IREE_ASSERT_OK(iree_hal_amdgpu_buffer_pool_preallocate(&buffer_pool, 0));
-
-  // Preallocate one block.
-  IREE_ASSERT_OK(iree_hal_amdgpu_buffer_pool_preallocate(
-      &buffer_pool, buffer_pool.block_capacity));
-
-  // Trim the entire block (nothing is used).
-  iree_hal_amdgpu_buffer_pool_trim(&buffer_pool);
-
-  // Preallocate two blocks.
-  IREE_ASSERT_OK(iree_hal_amdgpu_buffer_pool_preallocate(
-      &buffer_pool, buffer_pool.block_capacity + 1));
-
-  // Preallocate one more block (1 buffer ceildiv capacity = 1 block).
-  IREE_ASSERT_OK(iree_hal_amdgpu_buffer_pool_preallocate(&buffer_pool, 1));
-
-  // Deinitialize with remaining preallocated blocks to test cleanup.
-  iree_hal_amdgpu_buffer_pool_deinitialize(&buffer_pool);
-}
-
-// Tests acquiring and releasing a buffer handle from the pool.
-TEST_F(BufferPoolTest, AcquireRelease) {
-  IREE_TRACE_SCOPE();
-
-  iree_hal_buffer_placement_t placement = {
-      /*.device=*/(iree_hal_device_t*)0xF00Du,  // not available in test
-      /*.queue_affinity=*/1ull,
-      /*.flags=*/IREE_HAL_BUFFER_PLACEMENT_FLAG_ASYNCHRONOUS,
-  };
-  iree_hal_amdgpu_buffer_pool_t buffer_pool = {0};
-  IREE_ASSERT_OK(iree_hal_amdgpu_buffer_pool_initialize(
-      &libhsa, &topology, placement,
-      /*block_capacity=*/32, host_allocator, cpu_memory_pool, &buffer_pool));
-
-  iree_hal_buffer_params_t buffer_params = {
-      /*.usage=*/IREE_HAL_BUFFER_USAGE_DEFAULT |
-          IREE_HAL_BUFFER_USAGE_MAPPING_PERSISTENT,
-      /*.access=*/IREE_HAL_MEMORY_ACCESS_ALL,
-      /*.type=*/IREE_HAL_MEMORY_TYPE_OPTIMAL_FOR_DEVICE |
-          IREE_HAL_MEMORY_TYPE_HOST_VISIBLE,
-      /*.queue_affinity=*/placement.queue_affinity,
-      /*.min_alignment=*/0,
-  };
-  iree_device_size_t requested_size = 127u;
-
-  // Handle is just for convenience and is stored within the buffer as well.
-  iree_hal_buffer_t* buffer = NULL;
-  iree_hal_amdgpu_device_allocation_handle_t* handle = NULL;
-  IREE_ASSERT_OK(iree_hal_amdgpu_buffer_pool_acquire(
-      &buffer_pool, buffer_params, /*allocation_size=*/requested_size, &buffer,
-      &handle));
-  ASSERT_NE(buffer, nullptr);
-  ASSERT_NE(handle, nullptr);
-  EXPECT_GE(iree_hal_buffer_allocation_size(buffer), requested_size);
-  EXPECT_EQ(iree_hal_buffer_allocation_placement(buffer).device,
-            placement.device);
-  EXPECT_EQ(iree_hal_buffer_allocation_placement(buffer).queue_affinity,
-            placement.queue_affinity);
-  EXPECT_EQ(iree_hal_buffer_allocation_placement(buffer).flags,
-            placement.flags);
-  EXPECT_EQ(iree_hal_buffer_byte_length(buffer), requested_size);
-
-  // Handle should have no physical pointer since nothing has been allocated.
-  EXPECT_EQ(handle->ptr, nullptr);
-
-  // Ensure the buffer resolves to the handle.
-  iree_hal_amdgpu_device_buffer_type_t type = 0;
-  uint64_t bits = 0;
-  IREE_ASSERT_OK(iree_hal_amdgpu_resolve_buffer(buffer, &type, &bits));
-  EXPECT_EQ(type, IREE_HAL_AMDGPU_DEVICE_BUFFER_TYPE_HANDLE);
-  EXPECT_EQ(bits, (uint64_t)handle);
-
-  // Same as above but for when we're confident we're dealing with a transient
-  // buffer (here we are).
-  iree_hal_amdgpu_device_allocation_handle_t* queried_handle = NULL;
-  IREE_ASSERT_OK(
-      iree_hal_amdgpu_resolve_transient_buffer(buffer, &queried_handle));
-  EXPECT_EQ(handle, queried_handle);
-
-  // Since the buffer is not actually allocated any attempt to map (even though
-  // we requested it) should fail.
-  iree_hal_buffer_mapping_t mapping = {};
-  EXPECT_THAT(
-      Status(iree_hal_buffer_map_range(buffer, IREE_HAL_MAPPING_MODE_PERSISTENT,
-                                       IREE_HAL_MEMORY_ACCESS_READ, 0,
-                                       IREE_HAL_WHOLE_BUFFER, &mapping)),
-      StatusIs(StatusCode::kFailedPrecondition));
-
-  // Release the buffer back to the pool - we're the last reference and it
-  // should be recycled.
-  iree_hal_buffer_release(buffer);
-
-  iree_hal_amdgpu_buffer_pool_deinitialize(&buffer_pool);
-}
-
-// Explicitly tests pool growth by acquiring an entire block's worth of buffers
-// plus one. We then release all the buffers that should be in the first block
-// and trim while the second block's buffer is still outstanding to ensure that
-// block is not reclaimed out from under it.
-TEST_F(BufferPoolTest, Growth) {
-  IREE_TRACE_SCOPE();
-
-  iree_hal_buffer_placement_t placement = {
-      /*.device=*/NULL,  // not available in test
-      /*.queue_affinity=*/IREE_HAL_QUEUE_AFFINITY_ANY,
-      /*.flags=*/IREE_HAL_BUFFER_PLACEMENT_FLAG_ASYNCHRONOUS,
-  };
-  iree_hal_amdgpu_buffer_pool_t buffer_pool = {0};
-  IREE_ASSERT_OK(iree_hal_amdgpu_buffer_pool_initialize(
-      &libhsa, &topology, placement, /*block_capacity=*/32, host_allocator,
-      cpu_memory_pool, &buffer_pool));
-  // NOTE: the capacity may be larger than requested due to alignment.
-  const iree_host_size_t block_capacity = buffer_pool.block_capacity;
-
-  iree_hal_buffer_params_t buffer_params = {
-      /*.usage=*/IREE_HAL_BUFFER_USAGE_DEFAULT,
-      /*.access=*/IREE_HAL_MEMORY_ACCESS_ALL,
-      /*.type=*/IREE_HAL_MEMORY_TYPE_OPTIMAL_FOR_DEVICE,
-      /*.queue_affinity=*/placement.queue_affinity,
-      /*.min_alignment=*/0,
-  };
-  iree_device_size_t requested_size = 128u;
-  std::vector<iree_hal_buffer_t*> buffers(block_capacity);
-  std::vector<iree_hal_amdgpu_device_allocation_handle_t*> handles(
-      block_capacity);
-
-  // Preallocate the first block (just to put more load on that path).
-  IREE_ASSERT_OK(
-      iree_hal_amdgpu_buffer_pool_preallocate(&buffer_pool, block_capacity));
-
-  // Allocate enough to consume the entire first block.
-  for (iree_host_size_t i = 0; i < block_capacity; ++i) {
-    IREE_ASSERT_OK(iree_hal_amdgpu_buffer_pool_acquire(
-        &buffer_pool, buffer_params, /*allocation_size=*/requested_size,
-        &buffers[i], &handles[i]));
-    EXPECT_EQ(handles[i]->ptr, nullptr);
-  }
-
-  // Allocate +1 to trigger growth and acquire the next block.
-  iree_hal_buffer_t* growth_buffer = NULL;
-  iree_hal_amdgpu_device_allocation_handle_t* growth_handle = NULL;
-  IREE_ASSERT_OK(iree_hal_amdgpu_buffer_pool_acquire(
-      &buffer_pool, buffer_params, /*allocation_size=*/requested_size,
-      &growth_buffer, &growth_handle));
-  ASSERT_NE(growth_buffer, nullptr);
-  ASSERT_NE(growth_handle, nullptr);
-  EXPECT_EQ(growth_handle->ptr, nullptr);
-
-  // Recycle all the buffers from the first block. After this it should have no
-  // outstanding buffers allocated from it and be a candidate for trimming.
-  for (iree_host_size_t i = 0; i < block_capacity; ++i) {
-    iree_hal_buffer_release(buffers[i]);
-  }
-
-  // Ensure the growth buffer is still valid (should be, as we shouldn't have
-  // deallocated anything).
-  EXPECT_EQ(growth_handle->ptr, nullptr);
-
-  // Trim to drop the unused first block.
-  iree_hal_amdgpu_buffer_pool_trim(&buffer_pool);
-
-  // Check that we didn't drop the growth buffer that's in the second block.
-  EXPECT_EQ(growth_handle->ptr, nullptr);
-
-  // Release the last buffer and let deinitialization clean up the second block.
-  iree_hal_buffer_release(growth_buffer);
-
-  iree_hal_amdgpu_buffer_pool_deinitialize(&buffer_pool);
-}
-
-}  // namespace
-}  // namespace iree::hal::amdgpu
diff --git a/runtime/src/iree/hal/drivers/amdgpu/channel.c b/runtime/src/iree/hal/drivers/amdgpu/channel.c
deleted file mode 100644
index a4cdf82..0000000
--- a/runtime/src/iree/hal/drivers/amdgpu/channel.c
+++ /dev/null
@@ -1,133 +0,0 @@
-// Copyright 2025 The IREE Authors
-//
-// Licensed under the Apache License v2.0 with LLVM Exceptions.
-// See https://llvm.org/LICENSE.txt for license information.
-// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
-
-#include "iree/hal/drivers/amdgpu/channel.h"
-
-//===----------------------------------------------------------------------===//
-// iree_hal_amdgpu_channel_t
-//===----------------------------------------------------------------------===//
-
-typedef struct iree_hal_amdgpu_channel_t {
-  iree_hal_resource_t resource;
-  iree_allocator_t host_allocator;
-
-  // Parent channel this was split from, if any.
-  // This is only used to keep the parent channel live for as long as there are
-  // any split channels live (including transitive splits).
-  iree_hal_channel_t* parent_channel;
-} iree_hal_amdgpu_channel_t;
-
-static const iree_hal_channel_vtable_t iree_hal_amdgpu_channel_vtable;
-
-static iree_hal_amdgpu_channel_t* iree_hal_amdgpu_channel_cast(
-    iree_hal_channel_t* base_value) {
-  IREE_HAL_ASSERT_TYPE(base_value, &iree_hal_amdgpu_channel_vtable);
-  return (iree_hal_amdgpu_channel_t*)base_value;
-}
-
-static const iree_hal_amdgpu_channel_t* iree_hal_amdgpu_channel_const_cast(
-    const iree_hal_channel_t* base_value) {
-  IREE_HAL_ASSERT_TYPE(base_value, &iree_hal_amdgpu_channel_vtable);
-  return (const iree_hal_amdgpu_channel_t*)base_value;
-}
-
-iree_status_t iree_hal_amdgpu_channel_create(iree_hal_channel_params_t params,
-                                             iree_allocator_t host_allocator,
-                                             iree_hal_channel_t** out_channel) {
-  IREE_ASSERT_ARGUMENT(out_channel);
-  *out_channel = NULL;
-  IREE_TRACE_ZONE_BEGIN(z0);
-
-  iree_hal_amdgpu_channel_t* channel = NULL;
-  IREE_RETURN_AND_END_ZONE_IF_ERROR(
-      z0, iree_allocator_malloc(host_allocator, sizeof(*channel),
-                                (void**)&channel));
-  iree_hal_resource_initialize(&iree_hal_amdgpu_channel_vtable,
-                               &channel->resource);
-  channel->host_allocator = host_allocator;
-
-  // TODO(benvanik): implement channel setup using params. Note that the id is
-  // not retained and must be copied locally if needed beyond this function
-  // call.
-  iree_status_t status = iree_make_status(
-      IREE_STATUS_UNIMPLEMENTED, "collective channels not implemented");
-
-  if (iree_status_is_ok(status)) {
-    *out_channel = (iree_hal_channel_t*)channel;
-  } else {
-    iree_hal_channel_release((iree_hal_channel_t*)channel);
-  }
-  IREE_TRACE_ZONE_END(z0);
-  return status;
-}
-
-static void iree_hal_amdgpu_channel_destroy(iree_hal_channel_t* base_channel) {
-  iree_hal_amdgpu_channel_t* channel =
-      iree_hal_amdgpu_channel_cast(base_channel);
-  IREE_TRACE_ZONE_BEGIN(z0);
-
-  iree_allocator_t host_allocator = channel->host_allocator;
-
-  // TODO(benvanik): destroy any implementation resources.
-
-  iree_hal_channel_release(channel->parent_channel);
-  iree_allocator_free(host_allocator, channel);
-
-  IREE_TRACE_ZONE_END(z0);
-}
-
-static iree_status_t iree_hal_amdgpu_channel_split(
-    iree_hal_channel_t* base_channel, int32_t color, int32_t key,
-    iree_hal_channel_flags_t flags, iree_hal_channel_t** out_split_channel) {
-  iree_hal_amdgpu_channel_t* channel =
-      iree_hal_amdgpu_channel_cast(base_channel);
-
-  // TODO(benvanik): split the channel and get any native resources required.
-  iree_status_t status = iree_make_status(IREE_STATUS_UNIMPLEMENTED,
-                                          "channel splitting not implemented");
-
-  // Wrap the split channel resources in a new HAL channel.
-  iree_hal_amdgpu_channel_t* split_channel = NULL;
-  if (iree_status_is_ok(status)) {
-    status =
-        iree_allocator_malloc(channel->host_allocator, sizeof(*split_channel),
-                              (void**)&split_channel);
-  }
-  if (iree_status_is_ok(status)) {
-    iree_hal_resource_initialize(&iree_hal_amdgpu_channel_vtable,
-                                 &split_channel->resource);
-    split_channel->host_allocator = channel->host_allocator;
-    split_channel->parent_channel = base_channel;
-    iree_hal_channel_retain(base_channel);
-
-    // TODO(benvanik): transfer ownership of the implementation resources.
-  }
-
-  if (iree_status_is_ok(status)) {
-    *out_split_channel = (iree_hal_channel_t*)split_channel;
-  } else {
-    iree_hal_channel_release((iree_hal_channel_t*)split_channel);
-  }
-  return status;
-}
-
-static void iree_hal_amdgpu_channel_query_rank_and_count(
-    const iree_hal_channel_t* base_channel, int32_t* out_rank,
-    int32_t* out_count) {
-  const iree_hal_amdgpu_channel_t* channel =
-      iree_hal_amdgpu_channel_const_cast(base_channel);
-
-  // TODO(benvanik): query the rank and count from the implementation or cache
-  // them locally to avoid overheads (this may be called frequently).
-  (void)channel;
-  *out_rank = 0;
-  *out_count = 0;
-}
-
-static const iree_hal_channel_vtable_t iree_hal_amdgpu_channel_vtable = {
-    .destroy = iree_hal_amdgpu_channel_destroy,
-    .split = iree_hal_amdgpu_channel_split,
-    .query_rank_and_count = iree_hal_amdgpu_channel_query_rank_and_count,
-};
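
One contract worth calling out in the stub above: a split channel retains its parent, so every ancestor stays alive until the last descendant is released, regardless of release order. A hedged sketch of the intended flow once setup/split are implemented (both currently return UNIMPLEMENTED; `example_channel_split` is illustrative):

```c
static iree_status_t example_channel_split(iree_hal_channel_params_t params,
                                           iree_allocator_t host_allocator) {
  iree_hal_channel_t* root = NULL;
  IREE_RETURN_IF_ERROR(
      iree_hal_amdgpu_channel_create(params, host_allocator, &root));

  iree_hal_channel_t* split = NULL;
  iree_status_t status = iree_hal_channel_split(
      root, /*color=*/0, /*key=*/0, IREE_HAL_CHANNEL_FLAG_NONE, &split);

  // On success |root| has two references: ours and split->parent_channel.
  // Releasing ours here is safe; the split keeps its parent live.
  iree_hal_channel_release(root);
  if (iree_status_is_ok(status)) {
    // ... use |split| for collectives on the sub-group ...
    iree_hal_channel_release(split);  // destroy releases parent_channel last
  }
  return status;
}
```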
diff --git a/runtime/src/iree/hal/drivers/amdgpu/channel.h b/runtime/src/iree/hal/drivers/amdgpu/channel.h
deleted file mode 100644
index a094ee5..0000000
--- a/runtime/src/iree/hal/drivers/amdgpu/channel.h
+++ /dev/null
@@ -1,22 +0,0 @@
-// Copyright 2025 The IREE Authors
-//
-// Licensed under the Apache License v2.0 with LLVM Exceptions.
-// See https://llvm.org/LICENSE.txt for license information.
-// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
-
-#ifndef IREE_HAL_DRIVERS_AMDGPU_CHANNEL_H_
-#define IREE_HAL_DRIVERS_AMDGPU_CHANNEL_H_
-
-#include "iree/base/api.h"
-#include "iree/hal/api.h"
-
-//===----------------------------------------------------------------------===//
-// iree_hal_amdgpu_channel_t
-//===----------------------------------------------------------------------===//
-
-// Creates an AMDGPU HAL collective channel using the given |params|.
-iree_status_t iree_hal_amdgpu_channel_create(iree_hal_channel_params_t params,
-                                             iree_allocator_t host_allocator,
-                                             iree_hal_channel_t** out_channel);
-
-#endif  // IREE_HAL_DRIVERS_AMDGPU_CHANNEL_H_
diff --git a/runtime/src/iree/hal/drivers/amdgpu/command_buffer.c b/runtime/src/iree/hal/drivers/amdgpu/command_buffer.c
deleted file mode 100644
index 320af6d..0000000
--- a/runtime/src/iree/hal/drivers/amdgpu/command_buffer.c
+++ /dev/null
@@ -1,1873 +0,0 @@
-// Copyright 2025 The IREE Authors
-//
-// Licensed under the Apache License v2.0 with LLVM Exceptions.
-// See https://llvm.org/LICENSE.txt for license information.
-// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
-
-#include "iree/hal/drivers/amdgpu/command_buffer.h"
-
-#include "iree/hal/drivers/amdgpu/buffer.h"
-#include "iree/hal/drivers/amdgpu/device/blit.h"
-#include "iree/hal/drivers/amdgpu/executable.h"
-#include "iree/hal/drivers/amdgpu/util/block_pool.h"
-#include "iree/hal/utils/resource_set.h"
-
-//===----------------------------------------------------------------------===//
-// Utilities
-//===----------------------------------------------------------------------===//
-
-// Populates the device-side |out_ref| variant of a HAL buffer |ref|.
-// This performs the HAL translation from an allocated buffer to the device
-// buffer so that no access to HAL data structures is needed on device
-// (iree_hal_buffer_t, etc).
-static iree_status_t iree_hal_amdgpu_translate_device_buffer_ref(
-    iree_hal_buffer_ref_t ref, iree_hal_amdgpu_device_buffer_ref_t* out_ref) {
-  // Slot references are resolved when the command buffer is issued.
-  // Record the range to translate from the buffer provided with the submission.
-  if (!ref.buffer) {
-    out_ref->offset = ref.offset;
-    out_ref->type = IREE_HAL_AMDGPU_DEVICE_BUFFER_TYPE_SLOT;
-    out_ref->length = ref.length;
-    out_ref->value.slot = ref.buffer_slot;
-    return iree_ok_status();
-  }
-
-  // Resolve buffers to their device type and value bits (which is
-  // type-specific). Note that we use the allocated buffer and not any wrapper
-  // that may be around it (subspans, etc).
-  iree_hal_amdgpu_device_buffer_type_t type = 0;
-  IREE_RETURN_IF_ERROR(iree_hal_amdgpu_resolve_buffer(
-      iree_hal_buffer_allocated_buffer(ref.buffer), &type,
-      &out_ref->value.bits));
-  out_ref->type = type;
-
-  // Translate range from the allocated buffer base along with the reference
-  // offset. This avoids the device needing to read the HAL buffer when issuing.
-  out_ref->offset = iree_hal_buffer_byte_offset(ref.buffer) + ref.offset;
-  if (ref.length == IREE_HAL_WHOLE_BUFFER) {
-    out_ref->length = iree_hal_buffer_byte_length(ref.buffer);
-  } else {
-    out_ref->length = ref.length;
-  }
-
-  return iree_ok_status();
-}
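
For orientation, the two translation paths above produce device-side refs shaped like the following (values illustrative):

```c
// Slot reference (ref.buffer == NULL): resolved at issue time against the
// binding table provided with the submission.
//   in:  { buffer: NULL, buffer_slot: 3, offset: 64, length: 256 }
//   out: { type: ..._BUFFER_TYPE_SLOT, value.slot: 3, offset: 64, length: 256 }
//
// Bound buffer: resolved now against the allocated buffer so the device never
// needs to touch iree_hal_buffer_t. A subspan at byte offset 128 folds in:
//   in:  { buffer: subspan(+128), offset: 64, length: IREE_HAL_WHOLE_BUFFER }
//   out: { type/value from iree_hal_amdgpu_resolve_buffer,
//          offset: 128 + 64 = 192, length: iree_hal_buffer_byte_length(...) }
```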
-
-//===----------------------------------------------------------------------===//
-// iree_hal_amdgpu_query_id_scratch_t
-//===----------------------------------------------------------------------===//
-
-// Scratch space for acquiring query IDs that are associated with commands.
-// Query IDs are dependent on HAL commands and not AQL packets so we can share
-// them for all devices.
-typedef struct iree_hal_amdgpu_query_id_scratch_t {
-  // Next block-relative query ID when tracing commands.
-  // Also represents the maximum ID required.
-  iree_hal_amdgpu_device_command_query_id_t next;
-  // Total number of commands query ID slots assigned.
-  // Not all commands have query IDs and for those that don't the invalid value
-  // IREE_HAL_AMDGPU_TRACE_EXECUTION_QUERY_ID_INVALID is assigned.
-  uint32_t count;
-  // Scratch space for recording query IDs.
-  // Reused across blocks and copied out per-block as they each end recording.
-  iree_hal_amdgpu_device_command_query_id_t values[/*command_capacity*/];
-} iree_hal_amdgpu_query_id_scratch_t;
-
-// Resets the scratch state to empty.
-// Query ID values are not reset and should be assumed uninitialized.
-static void iree_hal_amdgpu_query_id_scratch_reset(
-    iree_hal_amdgpu_query_id_scratch_t* scratch) {
-  if (!scratch) return;
-  scratch->count = 0;
-  scratch->next.control_id = 0;
-  scratch->next.dispatch_id = 0;
-}
-
-// Assigns query IDs for the given command based on its type and increments the
-// command count. If the command does not require query IDs an empty slot is
-// assigned. Requires that the caller does bounds checking.
-static void iree_hal_amdgpu_assign_cmd_query_ids(
-    iree_hal_amdgpu_query_id_scratch_t* scratch,
-    iree_hal_amdgpu_device_cmd_type_t type) {
-  if (!scratch) return;
-#if !(IREE_TRACING_FEATURES & IREE_TRACING_FEATURE_INSTRUMENTATION_DEVICE)
-  return;  // tracing disabled
-#endif     // !IREE_TRACING_FEATURE_INSTRUMENTATION_DEVICE
-  enum {
-    NONE = 0,
-    CONTROL = 1 << 0,
-    DISPATCH = 1 << 1,
-    ALL = CONTROL | DISPATCH,
-  };
-  static const uint8_t required_ids[IREE_HAL_AMDGPU_DEVICE_CMD_MAX + 1] = {
-      [IREE_HAL_AMDGPU_DEVICE_CMD_DEBUG_GROUP_BEGIN] = ALL,
-      [IREE_HAL_AMDGPU_DEVICE_CMD_DEBUG_GROUP_END] = ALL,
-      [IREE_HAL_AMDGPU_DEVICE_CMD_BARRIER] = NONE,
-      [IREE_HAL_AMDGPU_DEVICE_CMD_SIGNAL_EVENT] = NONE,
-      [IREE_HAL_AMDGPU_DEVICE_CMD_RESET_EVENT] = NONE,
-      [IREE_HAL_AMDGPU_DEVICE_CMD_WAIT_EVENTS] = NONE,
-      [IREE_HAL_AMDGPU_DEVICE_CMD_FILL_BUFFER] = DISPATCH,
-      [IREE_HAL_AMDGPU_DEVICE_CMD_COPY_BUFFER] = DISPATCH,
-      [IREE_HAL_AMDGPU_DEVICE_CMD_DISPATCH] = DISPATCH,
-      [IREE_HAL_AMDGPU_DEVICE_CMD_DISPATCH_INDIRECT_DYNAMIC] = DISPATCH,
-      [IREE_HAL_AMDGPU_DEVICE_CMD_BRANCH] = NONE,
-      [IREE_HAL_AMDGPU_DEVICE_CMD_RETURN] = NONE,
-  };
-  iree_hal_amdgpu_device_command_query_id_t query_id = {
-      .control_id = IREE_HAL_AMDGPU_TRACE_EXECUTION_QUERY_ID_INVALID,
-      .dispatch_id = IREE_HAL_AMDGPU_TRACE_EXECUTION_QUERY_ID_INVALID,
-  };
-  if (required_ids[type] & CONTROL) {
-    query_id.control_id = scratch->next.control_id++;
-  }
-  if (required_ids[type] & DISPATCH) {
-    query_id.dispatch_id = scratch->next.dispatch_id++;
-  }
-  scratch->values[scratch->count++] = query_id;
-}
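
To make the table above concrete, a worked assignment for a short command sequence (assumes device instrumentation is compiled in):

```c
// Record order -> assigned query IDs (INVALID = ..._QUERY_ID_INVALID):
//   #  command            control_id  dispatch_id
//   0  DEBUG_GROUP_BEGIN  0           0        (ALL)
//   1  COPY_BUFFER        INVALID     1        (DISPATCH)
//   2  DISPATCH           INVALID     2        (DISPATCH)
//   3  BARRIER            INVALID     INVALID  (NONE)
//   4  DEBUG_GROUP_END    1           3        (ALL)
// Afterwards scratch->count == 5 with next.control_id == 2 and
// next.dispatch_id == 4, which end_block copies out as the block's
// max_control_query_count / max_dispatch_query_count.
```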
-
-//===----------------------------------------------------------------------===//
-// iree_hal_amdgpu_command_encoder_t
-//===----------------------------------------------------------------------===//
-
-// Per-device state managed by a command block encoder.
-typedef struct iree_hal_amdgpu_device_encoder_state_t {
-  // Arena of device-side memory used for metadata (blocks, query lists, etc)
-  // reused across blocks.
-  iree_hal_amdgpu_block_arena_t metadata_arena;
-  // Arena of device-side memory used for embedded data (inline update buffers,
-  // bindings, constants, etc) reused across blocks.
-  iree_hal_amdgpu_block_arena_t storage_arena;
-
-  // Command block storage.
-  // Commands are fixed-size and we allocate one device memory block per CFG
-  // block we encode. This wastes space and would require a compaction step
-  // during finalization if we wanted to fix it up.
-  struct {
-    // Device block pool used for allocating command storage.
-    iree_hal_amdgpu_block_pool_t* pool;
-    // Head of the command block linked list.
-    // This is the first encoded block.
-    iree_hal_amdgpu_block_t* head;
-    // Tail of the command block linked list.
-    // This is the last (completed) encoded block.
-    iree_hal_amdgpu_block_t* tail;
-    // Current command block in device memory storing the encoded commands.
-    // Reset each CFG block.
-    iree_hal_amdgpu_block_t* current;
-  } cmd_block;
-} iree_hal_amdgpu_device_encoder_state_t;
-
-// Manages command buffer block encoding for multiple devices.
-// Information common to each device is shared (such as command counts and query
-// IDs) while command and embedded data is stored per-device.
-//
-// Offsets and sizes here are limited both by the options provided during
-// command buffer construction (usually derived from user options or device
-// limits) and the maximum values allowed by the device-side scheduler. Most
-// command packets use uint16_t offsets to fit under the fixed packet size limit
-// but this happens to be in-line with what devices are practically limited to
-// by design.
-typedef struct iree_hal_amdgpu_command_encoder_t {
-  // Whether a block is open for encoding (between a begin/end).
-  uint64_t in_block : 1;
-  // An execution barrier has been requested and the next command should have
-  // its barrier bit set.
-  uint64_t barrier_pending : 1;
-
-  // Total number of devices being recorded. Used to size per-device state.
-  iree_host_size_t device_count;
-
-  // Total number of blocks recorded (including the current block, if any).
-  iree_host_size_t block_count;
-
-  // Scratch host memory used for accumulating data for finalization.
-  // Only reset when recording completes.
-  iree_arena_allocator_t* host_arena;
-
-  // Query ID scratch space for per-command query ID allocation.
-  // Shared across all devices and copied to per-device block metadata at the
-  // end of encoding each block.
-  iree_hal_amdgpu_query_id_scratch_t* query_ids;
-
-  // Maximum number of commands that can be encoded into a block (including
-  // terminator). Blocks are split automatically if they exceed the capacity.
-  uint32_t command_capacity;
-  // Current count of commands encoded into the block.
-  // The encoder ensures that there is always room under the capacity for a
-  // terminator command.
-  uint16_t command_count;
-
-  // Maximum number of AQL packets that can be used in a block.
-  // This is checked against the target queue limits during command buffer
-  // submission.
-  uint32_t max_aql_packet_capacity;
-  // Peak number of AQL packets required by any block encoded so far.
-  // Depending on dynamic state when the command buffer is issued fewer packets
-  // may be required and this is used to ensure HSA queue capacity is available
-  // prior to issuing a command block.
-  uint16_t peak_aql_packet_count;
-  // Offset of the next AQL packet to be allocated.
-  uint16_t aql_packet_offset;
-
-  // Maximum kernarg capacity in bytes that can be used within a single block.
-  uint32_t max_kernarg_capacity;
-  // Peak kernarg size in bytes required by any block encoded so far.
-  uint16_t peak_kernarg_size;
-  // Offset of the next kernarg pointer to be allocated.
-  // This may not be aligned (1 byte alignment) and needs to be aligned by the
-  // allocator requesting space.
-  uint16_t kernarg_offset;
-
-  // Scratch space for the last command appended on each device. This is
-  // returned to callers appending commands so they can do per-device updates.
-  // Pointers are invalidated on block reset or when the next command is
-  // appended.
-  IREE_AMDGPU_DEVICE_PTR iree_hal_amdgpu_device_cmd_header_t**
-      device_cmds /*[device_count]*/;
-
-  // Active command buffer block encoder for each device if the command buffer
-  // is recording. Encoders are acquired from the small block pool when
-  // recording begins and discarded upon finalization.
-  iree_hal_amdgpu_device_encoder_state_t device_state[/*device_count*/];
-} iree_hal_amdgpu_command_encoder_t;
-
-// Initializes a command encoder with limits based on the provided |options|
-// that can encode into |device_count| devices. The encoder will be allocated
-// from |host_arena| and iree_hal_amdgpu_command_encoder_deinitialize must be
-// called prior to resetting the arena.
-static iree_status_t iree_hal_amdgpu_command_encoder_initialize(
-    const iree_hal_amdgpu_command_buffer_options_t* options,
-    iree_host_size_t device_count, iree_arena_allocator_t* host_arena,
-    iree_hal_amdgpu_command_encoder_t** out_encoder) {
-  IREE_ASSERT_ARGUMENT(out_encoder);
-  IREE_TRACE_ZONE_BEGIN(z0);
-  *out_encoder = NULL;
-
-  // Allocate the dynamically-sized encoder from the arena.
-  iree_hal_amdgpu_command_encoder_t* encoder = NULL;
-  const iree_host_size_t encoder_size =
-      sizeof(*encoder) + device_count * sizeof(encoder->device_state[0]);
-  IREE_RETURN_AND_END_ZONE_IF_ERROR(
-      z0, iree_arena_allocate(host_arena, encoder_size, (void**)&encoder));
-
-  // NOTE: arena allocations may be uninitialized.
-  memset(encoder, 0, sizeof(*encoder));
-  encoder->device_count = device_count;
-  encoder->host_arena = host_arena;
-
-  // Capacities are defined as options or are derived from available pools.
-  // A larger pool will allow more commands to be encoded in a single block.
-  // We require contiguous allocations for the commands (today); if we want to
-  // save space we can perform a compaction during finalization (as commands
-  // don't contain pointers into the command memory).
-  encoder->command_capacity =
-      iree_host_size_floor_div(options->device_block_pools[0]->large.block_size,
-                               IREE_HAL_AMDGPU_DEVICE_CMD_SIZE);
-  encoder->max_aql_packet_capacity = options->block_aql_packet_count;
-  // Kernarg offsets are tracked as uint16_t so cap the per-block capacity to
-  // what they can address.
-  encoder->max_kernarg_capacity = UINT16_MAX;
-
-// Allocate query IDs from the host arena.
-// The query ID list is uninitialized but that's ok: we don't guarantee any
-// queries beyond the current count are initialized.
-#if IREE_TRACING_FEATURES & IREE_TRACING_FEATURE_INSTRUMENTATION_DEVICE
-  // TODO(benvanik): support options->recording_flags bit for tracing control
-  // per-command-buffer. If not set we can disable it here and save some scratch
-  // space and device uploads.
-  IREE_RETURN_AND_END_ZONE_IF_ERROR(
-      z0, iree_arena_allocate(host_arena,
-                              sizeof(*encoder->query_ids) +
-                                  encoder->command_capacity *
-                                      sizeof(encoder->query_ids->values[0]),
-                              (void**)&encoder->query_ids));
-  encoder->query_ids->next.control_id = 0;
-  encoder->query_ids->next.dispatch_id = 0;
-  encoder->query_ids->count = 0;
-#endif  // IREE_TRACING_FEATURE_INSTRUMENTATION_DEVICE
-
-  // The device command scratch list is used for returning the commands
-  // allocated on each device. This isn't great but since we have an arena it's
-  // cheap.
-  IREE_RETURN_AND_END_ZONE_IF_ERROR(
-      z0, iree_arena_allocate(host_arena,
-                              device_count * sizeof(encoder->device_cmds[0]),
-                              (void**)&encoder->device_cmds));
-
-  // Initialize per-device encoder state.
-  for (iree_host_size_t i = 0; i < device_count; ++i) {
-    iree_hal_amdgpu_device_encoder_state_t* device_state =
-        &encoder->device_state[i];
-    memset(device_state, 0, sizeof(*device_state));
-    iree_hal_amdgpu_block_arena_initialize(
-        &options->device_block_pools[i]->small, &device_state->metadata_arena);
-    iree_hal_amdgpu_block_arena_initialize(
-        &options->device_block_pools[i]->large, &device_state->storage_arena);
-    // TODO(benvanik): choose the pool based on the expected command buffer
-    // size. We don't know this today but could add a hint flag for "small" that
-    // let us avoid taking a large block for 1 command.
-    device_state->cmd_block.pool = &options->device_block_pools[i]->large;
-  }
-
-  *out_encoder = encoder;
-  IREE_TRACE_ZONE_END(z0);
-  return iree_ok_status();
-}
-
-// Deinitializes an |encoder| and releases any resources regardless of whether a
-// block is in progress (such as when cleaning up from a failure state).
-static void iree_hal_amdgpu_command_encoder_deinitialize(
-    iree_hal_amdgpu_command_encoder_t* encoder) {
-  if (!encoder) return;
-  IREE_TRACE_ZONE_BEGIN(z0);
-
-  // Release all resources that the encoder may still hold on to.
-  // In successful recordings these will have been taken by the parent command
-  // buffer.
-  for (iree_host_size_t i = 0; i < encoder->device_count; ++i) {
-    iree_hal_amdgpu_device_encoder_state_t* device_state =
-        &encoder->device_state[i];
-    if (device_state->cmd_block.head) {
-      iree_hal_amdgpu_block_pool_release_list(device_state->cmd_block.pool,
-                                              device_state->cmd_block.head);
-    }
-    if (device_state->cmd_block.current) {
-      iree_hal_amdgpu_block_pool_release(device_state->cmd_block.pool,
-                                         device_state->cmd_block.current);
-    }
-    iree_hal_amdgpu_block_arena_deinitialize(&device_state->storage_arena);
-    iree_hal_amdgpu_block_arena_deinitialize(&device_state->metadata_arena);
-  }
-
-  IREE_TRACE_ZONE_END(z0);
-}
-
-// Begins a new command buffer block and returns its block ordinal.
-// The encoder must be in the default state with any prior block closed.
-//
-// The block ordinal will remain stable during recording and can be used when
-// emitting branch operations as if it were a label. The final pointer of the
-// block in device memory will not be available until recording has completed.
-static iree_status_t iree_hal_amdgpu_command_encoder_begin_block(
-    iree_hal_amdgpu_command_encoder_t* encoder, uint32_t* out_block_ordinal) {
-  IREE_ASSERT_ARGUMENT(encoder);
-  IREE_ASSERT(out_block_ordinal);
-  *out_block_ordinal = 0;
-  if (IREE_UNLIKELY(encoder->in_block)) {
-    return iree_make_status(IREE_STATUS_FAILED_PRECONDITION,
-                            "a block is already being recorded and must be "
-                            "completed prior to recording another");
-  }
-  IREE_TRACE_ZONE_BEGIN(z0);
-
-  // Allocate initial empty command blocks on each device.
-  for (iree_host_size_t i = 0; i < encoder->device_count; ++i) {
-    iree_hal_amdgpu_device_encoder_state_t* device_state =
-        &encoder->device_state[i];
-    IREE_RETURN_AND_END_ZONE_IF_ERROR(
-        z0,
-        iree_hal_amdgpu_block_pool_acquire(device_state->cmd_block.pool,
-                                           &device_state->cmd_block.current));
-  }
-
-  // Begin block with implicit barrier at entry.
-  encoder->in_block = 1;
-  encoder->barrier_pending = 1;
-  ++encoder->block_count;
-  *out_block_ordinal = (uint32_t)encoder->block_count;
-
-  IREE_TRACE_ZONE_END(z0);
-  return iree_ok_status();
-}
-
-// Ends the current command buffer block and finalizes its device-side
-// resources. Upon return the encoder will be in the default idle state.
-// The caller must have inserted a terminator block prior to ending in order
-// for the command buffer to be valid.
-static iree_status_t iree_hal_amdgpu_command_encoder_end_block(
-    iree_hal_amdgpu_command_encoder_t* encoder) {
-  IREE_ASSERT_ARGUMENT(encoder);
-  if (IREE_UNLIKELY(!encoder->in_block)) {
-    return iree_make_status(IREE_STATUS_FAILED_PRECONDITION,
-                            "no block is actively being recorded");
-  }
-  IREE_TRACE_ZONE_BEGIN(z0);
-
-  // Per-block metadata contains the query map sized based on the number of
-  // commands recorded.
-  const iree_hal_amdgpu_query_id_scratch_t* query_ids = encoder->query_ids;
-  const iree_host_size_t total_metadata_size =
-      sizeof(iree_hal_amdgpu_device_command_block_t) +
-      (query_ids ? query_ids->count *
-                       sizeof(iree_hal_amdgpu_device_command_query_id_t)
-                 : 0);
-
-  // TODO(benvanik): sort the commands by type? Would help us use larger
-  // workgroup sizes (but may need padding); today if we used >1 workgroup size
-  // and any command in the workgroup was different we'd end up in unhappy land.
-
-  // Upload the block metadata to each device.
-  for (iree_host_size_t i = 0; i < encoder->device_count; ++i) {
-    iree_hal_amdgpu_device_encoder_state_t* device_state =
-        &encoder->device_state[i];
-
-    // Allocate block metadata from the block arena.
-    IREE_AMDGPU_DEVICE_PTR iree_hal_amdgpu_device_command_block_t*
-        block_metadata = NULL;
-    IREE_RETURN_AND_END_ZONE_IF_ERROR(
-        z0, iree_hal_amdgpu_block_arena_allocate(
-                &device_state->metadata_arena, total_metadata_size,
-                (void**)&block_metadata));
-
-    // Move the current block storage to the end of the block linked list.
-    iree_hal_amdgpu_block_t* cmd_block = device_state->cmd_block.current;
-    device_state->cmd_block.current = NULL;
-    cmd_block->next = NULL;
-    if (device_state->cmd_block.tail) {
-      device_state->cmd_block.tail->next = cmd_block;
-    } else {
-      device_state->cmd_block.head = cmd_block;
-    }
-    device_state->cmd_block.tail = cmd_block;
-
-    // Store the command buffer block metadata pointer in the device storage
-    // block. This is free space that we can reuse without needing to keep
-    // track of a growing list of blocks. Consider the list of block pool blocks
-    // as an iovec and the metadata is attached to one of the entries.
-    cmd_block->user_data[0] = (uint64_t)block_metadata;
-
-    // The metadata references the command block.
-    block_metadata->max_packet_count = encoder->aql_packet_offset;
-    block_metadata->command_count = encoder->command_count;
-    block_metadata->commands =
-        (const iree_hal_amdgpu_device_cmd_t*)cmd_block->ptr;
-
-    // Each device gets a copy of the query IDs in its local memory.
-    if (query_ids) {
-      block_metadata->query_map.max_control_query_count =
-          query_ids->next.control_id;
-      block_metadata->query_map.max_dispatch_query_count =
-          query_ids->next.dispatch_id;
-      iree_memcpy_stream_dst(block_metadata->query_map.query_ids,
-                             query_ids->values,
-                             query_ids->count * sizeof(query_ids->values[0]));
-    }
-  }
-
-  // Reset encoder state for the next block.
-  encoder->command_count = 0;
-  encoder->peak_aql_packet_count =
-      iree_max(encoder->peak_aql_packet_count, encoder->aql_packet_offset);
-  encoder->aql_packet_offset = 0;
-  encoder->peak_kernarg_size =
-      iree_max(encoder->peak_kernarg_size, encoder->kernarg_offset);
-  encoder->kernarg_offset = 0;
-  iree_hal_amdgpu_query_id_scratch_reset(encoder->query_ids);
-
-  encoder->in_block = 0;
-
-  IREE_TRACE_ZONE_END(z0);
-  return iree_ok_status();
-}
-
-static iree_status_t iree_hal_amdgpu_command_encoder_append_cmd(
-    iree_hal_amdgpu_command_encoder_t* encoder,
-    iree_hal_amdgpu_device_cmd_type_t type,
-    iree_hal_amdgpu_device_cmd_flags_t flags, uint16_t aql_packet_count,
-    uint16_t kernarg_size, uint16_t kernarg_alignment,
-    void* IREE_AMDGPU_DEVICE_PTR** out_device_cmds,
-    uint32_t* out_kernarg_offset);
-
-// Splits the current block in two by inserting a branch command.
-// The block must have at least one command slot available so the branch can be
-// inserted. All active state is carried over to the new block and since the
-// split is internal no new block label is needed (the only incoming edge is
-// from the current block that is being split).
-static iree_status_t iree_hal_amdgpu_command_encoder_split_block(
-    iree_hal_amdgpu_command_encoder_t* encoder) {
-  IREE_ASSERT_ARGUMENT(encoder);
-  if (IREE_UNLIKELY(!encoder->in_block)) {
-    return iree_make_status(IREE_STATUS_FAILED_PRECONDITION,
-                            "no block is actively being recorded");
-  }
-  IREE_TRACE_ZONE_BEGIN(z0);
-
-  // Insert a branch to the next block (that we will be beginning below).
-  const uint32_t target_block = encoder->block_count + 1;
-  iree_hal_amdgpu_device_cmd_branch_t** device_cmds = NULL;
-  uint32_t kernarg_offset = 0;
-  IREE_RETURN_AND_END_ZONE_IF_ERROR(
-      z0, iree_hal_amdgpu_command_encoder_append_cmd(
-              encoder, IREE_HAL_AMDGPU_DEVICE_CMD_BRANCH,
-              IREE_HAL_AMDGPU_DEVICE_CMD_FLAG_QUEUE_AWAIT_BARRIER,
-              IREE_HAL_AMDGPU_DEVICE_CMD_BRANCH_AQL_PACKET_COUNT,
-              IREE_HAL_AMDGPU_DEVICE_CMD_BRANCH_KERNARG_SIZE,
-              IREE_HAL_AMDGPU_DEVICE_CMD_BRANCH_KERNARG_ALIGNMENT,
-              (void***)&device_cmds, &kernarg_offset));
-  for (iree_host_size_t i = 0; i < encoder->device_count; ++i) {
-    device_cmds[i]->kernarg_offset = kernarg_offset;
-    device_cmds[i]->target_block = target_block;
-  }
-
-  // End the current block now that we've terminated it with the branch.
-  IREE_RETURN_AND_END_ZONE_IF_ERROR(
-      z0, iree_hal_amdgpu_command_encoder_end_block(encoder));
-
-  // Begin the new block. We don't care about the ordinal here as we already
-  // calculated it above and no one will be able to branch to it.
-  uint32_t block_ordinal = 0;
-  IREE_RETURN_AND_END_ZONE_IF_ERROR(
-      z0, iree_hal_amdgpu_command_encoder_begin_block(encoder, &block_ordinal));
-
-  IREE_TRACE_ZONE_END(z0);
-  return iree_ok_status();
-}
-
-// Appends a command to the current block.
-// The command header will be initialized based on the current encoding state
-// (such as whether a barrier is required) and the accounting of kernarg storage
-// will be calculated. Callers must populate any additional command-specific
-// information or flags per-device in the returned |out_device_cmds| list.
-//
-// If the current block has exceeded its capacity in the target device memory
-// block it will be split and joined with a branch.
-//
-// NOTE: the command data returned should be treated as write-only as its memory
-// lives on device and may be uncached. If we really wanted to optimize things
-// we'd probably be producing commands in host memory and doing a non-temporal
-// memcpy but that likely adds latency in the case of smaller command buffers.
-// This way we are always writing a bulk of the data while we go and
-// finalization per-device is just a quick fixup.
-static iree_status_t iree_hal_amdgpu_command_encoder_append_cmd(
-    iree_hal_amdgpu_command_encoder_t* encoder,
-    iree_hal_amdgpu_device_cmd_type_t type,
-    iree_hal_amdgpu_device_cmd_flags_t flags, uint16_t aql_packet_count,
-    uint16_t kernarg_size, uint16_t kernarg_alignment,
-    void* IREE_AMDGPU_DEVICE_PTR** out_device_cmds,
-    uint32_t* out_kernarg_offset) {
-  IREE_ASSERT_ARGUMENT(!kernarg_size || out_kernarg_offset);
-
-  if (IREE_UNLIKELY(!encoder->in_block)) {
-    return iree_make_status(IREE_STATUS_FAILED_PRECONDITION,
-                            "no block is actively being recorded");
-  }
-
-  // Split the block when we have only one command remaining unless the command
-  // being recorded is a terminator. We should always have one command available
-  // for use (start with 1, check here that there's always 1, and split if
-  // there's 0).
-  //
-  // Though rare it's possible the queue may not have enough space for the
-  // requested AQL packets - usually command capacity is conservative enough to
-  // be under any device limits but we do not want to fail once scheduled (as
-  // that's a device loss).
-  const bool is_terminator = type == IREE_HAL_AMDGPU_DEVICE_CMD_BRANCH ||
-                             type == IREE_HAL_AMDGPU_DEVICE_CMD_RETURN;
-  const bool exceeded_command_capacity =
-      !is_terminator && encoder->command_count + 1 >= encoder->command_capacity;
-  const bool exceeded_aql_packet_capacity =
-      encoder->aql_packet_offset + aql_packet_count >
-      encoder->max_aql_packet_capacity;
-  const bool exceeded_kernarg_capacity =
-      kernarg_size > 0 &&
-      iree_host_align(encoder->kernarg_offset, kernarg_alignment) +
-              kernarg_size >
-          encoder->max_kernarg_capacity;
-  if (exceeded_command_capacity || exceeded_aql_packet_capacity ||
-      exceeded_kernarg_capacity) {
-    IREE_RETURN_IF_ERROR(iree_hal_amdgpu_command_encoder_split_block(encoder));
-  }
-
-  // --- WARNING ------------------------------------------------------------ //
-  // The split above may have reset encoder state; any state loaded from it
-  // before the split must be reloaded prior to further use.
-  // --- WARNING ------------------------------------------------------------ //
-
-  // We update a scratch header and copy it to each device at the end.
-  iree_hal_amdgpu_device_cmd_header_t cmd = {
-      .type = type,
-  };
-
-  // Acquire the next command from the device memory block.
-  uint16_t command_offset = encoder->command_count;
-  ++encoder->command_count;
-
-  // If a barrier is pending (a prior command requires exclusive execution) then
-  // consume it. Since commands are processed in order all following commands
-  // recorded will also wait.
-  cmd.flags = flags;
-  if (encoder->barrier_pending) {
-    cmd.flags |= IREE_HAL_AMDGPU_DEVICE_CMD_FLAG_QUEUE_AWAIT_BARRIER;
-    encoder->barrier_pending = 0;
-  }
-
-  // Consume the requested packets from the AQL packet queue.
-  // Note that these are relative values as actual queue offsets are calculated
-  // when the commands are issued.
-  cmd.packet_offset = encoder->aql_packet_offset;
-  encoder->aql_packet_offset += aql_packet_count;
-
-  // Kernarg storage is committed when the commands are executed; here we are
-  // just tracking how much space is required and the offset of each requested
-  // block in memory. Note that kernargs have an alignment that we need to
-  // preserve.
-  if (kernarg_size > 0) {
-    iree_host_size_t kernarg_offset =
-        iree_host_align(encoder->kernarg_offset, kernarg_alignment);
-    encoder->kernarg_offset = (uint16_t)(kernarg_offset + kernarg_size);
-    *out_kernarg_offset = (uint16_t)kernarg_offset;
-  }
-
-  // Assign tracing query IDs for the command (if needed).
-  iree_hal_amdgpu_assign_cmd_query_ids(encoder->query_ids, cmd.type);
-
-  // TODO(benvanik): find a way to avoid needing to do this here - if we could
-  // instead pass back the command header and make callers responsible they
-  // could do this themselves in the loop they are likely already performing.
-  // For now this keeps things contained and simpler in cases where there is no
-  // per-device command information.
-  for (iree_host_size_t i = 0; i < encoder->device_count; ++i) {
-    IREE_AMDGPU_DEVICE_PTR iree_hal_amdgpu_device_cmd_t* block_cmds =
-        (iree_hal_amdgpu_device_cmd_t*)encoder->device_state[i]
-            .cmd_block.current->ptr;
-    IREE_AMDGPU_DEVICE_PTR iree_hal_amdgpu_device_cmd_header_t* device_cmd =
-        (iree_hal_amdgpu_device_cmd_header_t*)&block_cmds[command_offset];
-    iree_memcpy_stream_dst(device_cmd, &cmd, sizeof(cmd));
-    encoder->device_cmds[i] = device_cmd;
-  }
-
-  *out_device_cmds = (void**)encoder->device_cmds;
-  return iree_ok_status();
-}
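
A worked example of the capacity handling above (illustrative numbers; real capacities derive from the block pool sizes):

```c
// With command_capacity == 4 one slot is always held in reserve for a
// terminator. Recording five dispatches D0..D4:
//   D0, D1, D2 -> command_count advances to 1, 2, 3
//   D3         -> 3 + 1 >= 4, so split_block() first appends a BRANCH as the
//                 block's 4th command, ends block N, begins block N+1, and D3
//                 lands as the new block's first command
//   D4         -> the new block's second command
// Result: block N = [D0, D1, D2, BRANCH -> N+1]; block N+1 = [D3, D4, ...].
// The same split path triggers when AQL packet or kernarg capacity would be
// exceeded so no block can overflow its device-side limits when issued.
```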
-
-// Emplaces a data buffer into the block storage on |device_index| and returns
-// the pointer in device memory. The returned pointer is aligned to
-// iree_hal_amdgpu_max_align_t bytes.
-static iree_status_t iree_hal_amdgpu_command_encoder_emplace_data(
-    iree_hal_amdgpu_command_encoder_t* encoder, iree_host_size_t device_index,
-    iree_const_byte_span_t data, IREE_AMDGPU_DEVICE_PTR void** out_ptr) {
-  iree_hal_amdgpu_device_encoder_state_t* device_state =
-      &encoder->device_state[device_index];
-
-  IREE_RETURN_IF_ERROR(iree_hal_amdgpu_block_arena_allocate(
-      &device_state->storage_arena, data.data_length, (void**)out_ptr));
-
-  // Stream data to device memory (we will never read it on the host).
-  iree_memcpy_stream_dst(*out_ptr, data.data, data.data_length);
-
-  return iree_ok_status();
-}
-
-// Emplaces two data buffers into the block storage on |device_index| and
-// returns the pointer to each in device memory. This is more efficient than
-// emplacing both individually. Alignment is performed on each data span to
-// iree_hal_amdgpu_max_align_t bytes.
-static iree_status_t iree_hal_amdgpu_command_encoder_emplace_data_concat(
-    iree_hal_amdgpu_command_encoder_t* encoder, iree_host_size_t device_index,
-    iree_const_byte_span_t data0, iree_const_byte_span_t data1,
-    IREE_AMDGPU_DEVICE_PTR void** out_ptr0,
-    IREE_AMDGPU_DEVICE_PTR void** out_ptr1) {
-  iree_hal_amdgpu_device_encoder_state_t* device_state =
-      &encoder->device_state[device_index];
-
-  // Allocate combined block with alignment on each pointer.
-  // The allocation will start at the alignment and the total size will be
-  // aligned so we just need to ensure our internal data1 pointer is aligned.
-  const iree_host_size_t data1_offset =
-      iree_host_align(data0.data_length, iree_hal_amdgpu_max_align_t);
-  const iree_host_size_t total_size = data1_offset + data1.data_length;
-  uint8_t* base_ptr = NULL;
-  IREE_RETURN_IF_ERROR(iree_hal_amdgpu_block_arena_allocate(
-      &device_state->storage_arena, total_size, (void**)&base_ptr));
-  *out_ptr0 = base_ptr;
-  *out_ptr1 = base_ptr + data1_offset;
-
-  // Stream data to device memory (we will never read it on the host).
-  iree_memcpy_stream_dst(*out_ptr0, data0.data, data0.data_length);
-  iree_memcpy_stream_dst(*out_ptr1, data1.data, data1.data_length);
-
-  return iree_ok_status();
-}
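
A quick worked example of the concat layout above, assuming a max alignment of 64 bytes purely for illustration:

```c
//   data0.data_length = 100 -> data1_offset = iree_host_align(100, 64) = 128
//   data1.data_length = 40  -> total_size   = 128 + 40 = 168
// One arena allocation serves both spans:
//   *out_ptr0 = base_ptr        (arena allocations start aligned)
//   *out_ptr1 = base_ptr + 128  (re-aligned to the max alignment)
// versus two separate allocations that would each pay arena bookkeeping.
```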
-
-//===----------------------------------------------------------------------===//
-// iree_hal_amdgpu_command_buffer_options_t
-//===----------------------------------------------------------------------===//
-
-#define IREE_HAL_AMDGPU_COMMAND_BUFFER_MIN_HOST_SMALL_BLOCK_SIZE (4 * 1024 - 32)
-#define IREE_HAL_AMDGPU_COMMAND_BUFFER_MIN_HOST_LARGE_BLOCK_SIZE (32 * 1024)
-#define IREE_HAL_AMDGPU_COMMAND_BUFFER_MIN_METADATA_BLOCK_SIZE (8 * 1024 - 32)
-#define IREE_HAL_AMDGPU_COMMAND_BUFFER_MIN_STORAGE_BLOCK_SIZE (64 * 1024)
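-
-// NOTE: the "- 32" slack on the host minimums keeps a power-of-two pool block
-// (e.g. 4KB) eligible even after the pool reserves a small header for its
-// internal accounting (see the usable_block_size checks in
-// iree_hal_amdgpu_command_buffer_options_verify). The 32B figure is an
-// allowance, not a contract.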
-
-void iree_hal_amdgpu_command_buffer_options_initialize(
-    iree_hal_allocator_t* device_allocator, iree_hal_command_buffer_mode_t mode,
-    iree_hal_command_category_t command_categories,
-    iree_hal_queue_affinity_t queue_affinity, iree_host_size_t binding_capacity,
-    iree_hal_amdgpu_command_buffer_options_t* out_options) {
-  memset(out_options, 0, sizeof(*out_options));
-
-  out_options->device_allocator = device_allocator;
-  out_options->mode = mode;
-  out_options->command_categories = command_categories;
-  out_options->queue_affinity = queue_affinity;
-  out_options->binding_capacity = binding_capacity;
-
-  out_options->recording_flags =
-      IREE_HAL_AMDGPU_COMMAND_BUFFER_RECORDING_FLAG_NONE;
-
-  out_options->block_aql_packet_count =
-      IREE_HAL_AMDGPU_COMMAND_BUFFER_MAX_BLOCK_AQL_PACKET_COUNT;
-}
-
-iree_status_t iree_hal_amdgpu_command_buffer_options_verify(
-    const iree_hal_amdgpu_command_buffer_options_t* options) {
-  // Verify the block AQL packet count is within the supported range.
-  if (options->block_aql_packet_count <
-          IREE_HAL_AMDGPU_COMMAND_BUFFER_MIN_BLOCK_AQL_PACKET_COUNT ||
-      options->block_aql_packet_count >
-          IREE_HAL_AMDGPU_COMMAND_BUFFER_MAX_BLOCK_AQL_PACKET_COUNT) {
-    return iree_make_status(
-        IREE_STATUS_INVALID_ARGUMENT,
-        "block_aql_packet_count must be between %d and %d "
-        "but %" PRIhsz " was requested",
-        IREE_HAL_AMDGPU_COMMAND_BUFFER_MIN_BLOCK_AQL_PACKET_COUNT,
-        IREE_HAL_AMDGPU_COMMAND_BUFFER_MAX_BLOCK_AQL_PACKET_COUNT,
-        options->block_aql_packet_count);
-  }
-
-  // Verify that at least one device is targeted. We could probably handle
-  // recording for no devices and use that as a baseline performance metric but
-  // today zero devices would definitely be unintended by users.
-  int device_count =
-      iree_hal_amdgpu_device_affinity_count(options->device_affinity);
-  if (device_count == 0) {
-    return iree_make_status(IREE_STATUS_INVALID_ARGUMENT,
-                            "at least one physical device must be specified");
-  }
-
-  // Host block pools use a small amount of the block space for internal
-  // accounting and we must ensure that there's enough usable for our minimum
-  // contiguous allocations. This may mean they have to be slightly oversized
-  // if we end up needing a power-of-two allocation.
-  if (options->host_block_pools->small.usable_block_size <
-      IREE_HAL_AMDGPU_COMMAND_BUFFER_MIN_HOST_SMALL_BLOCK_SIZE) {
-    return iree_make_status(
-        IREE_STATUS_INVALID_ARGUMENT,
-        "host small block pool must have at least %d usable bytes per block "
-        "but only has %" PRIhsz,
-        IREE_HAL_AMDGPU_COMMAND_BUFFER_MIN_HOST_SMALL_BLOCK_SIZE,
-        options->host_block_pools->small.usable_block_size);
-  }
-  if (options->host_block_pools->large.usable_block_size <
-      IREE_HAL_AMDGPU_COMMAND_BUFFER_MIN_HOST_LARGE_BLOCK_SIZE) {
-    return iree_make_status(
-        IREE_STATUS_INVALID_ARGUMENT,
-        "host small block pool must have at least %d usable bytes per block "
-        "but only has %" PRIhsz,
-        IREE_HAL_AMDGPU_COMMAND_BUFFER_MIN_HOST_LARGE_BLOCK_SIZE,
-        options->host_block_pools->large.usable_block_size);
-  }
-
-  // Verify that the device-side pools will fit our data structures that we
-  // require to be contiguous.
-  for (int i = 0; i < device_count; ++i) {
-    if (options->device_block_pools[i]->small.block_size <
-        IREE_HAL_AMDGPU_COMMAND_BUFFER_MIN_METADATA_BLOCK_SIZE) {
-      return iree_make_status(
-          IREE_STATUS_INVALID_ARGUMENT,
-          "device[%d] small block pool must have at least "
-          "%d bytes per block but only has %" PRIdsz,
-          i, IREE_HAL_AMDGPU_COMMAND_BUFFER_MIN_METADATA_BLOCK_SIZE,
-          options->device_block_pools[i]->small.block_size);
-    }
-    if (options->device_block_pools[i]->large.block_size <
-        IREE_HAL_AMDGPU_COMMAND_BUFFER_MIN_STORAGE_BLOCK_SIZE) {
-      return iree_make_status(
-          IREE_STATUS_INVALID_ARGUMENT,
-          "device[%d] large block pool must have at least "
-          "%d bytes per block but only has %" PRIdsz,
-          i, IREE_HAL_AMDGPU_COMMAND_BUFFER_MIN_STORAGE_BLOCK_SIZE,
-          options->device_block_pools[i]->large.block_size);
-    }
-  }
-
-  return iree_ok_status();
-}
-
-//===----------------------------------------------------------------------===//
-// iree_hal_amdgpu_command_buffer_t
-//===----------------------------------------------------------------------===//
-
-// A host-side HAL command buffer wrapping a device-side handle.
-// Each command buffer can be recorded for multiple physical devices where all
-// data will reside in device-local memory. Expensive validation, encoding
-// logic, and resource tracking are amortized across all devices the command
-// buffer is recorded for.
-//
-// Once finalized a command buffer owns per-device references to a list of
-// memory blocks holding command buffer metadata such as the device-side
-// iree_hal_amdgpu_device_command_buffer_t and its CFG blocks and list of
-// memory blocks holding any embedded data used by the command buffer such as
-// bindings/constants and inline buffer uploads.
-//
-// During recording iree_hal_amdgpu_command_encoder_t is used to track the
-// current block, total counts required for allocating final device resources,
-// and per-device arenas. Once recording completes the encoder is discarded.
-//
-// A goal was to ensure that recording a command buffer for a single device
-// does not pay for the multi-device support: in most cases no additional
-// memory is allocated during recording or at rest beyond what a single-device
-// implementation would have needed, and beyond a predictable 1-trip loop there
-// is no meaningful additional per-command overhead.
-//
-// Attention is paid to recording performance even though we expect most large
-// command buffers to be recorded at initialization time and reused. Though
-// IREE-compiled programs can easily reuse command buffers, applications
-// directly using the HAL may not be able to, and we don't want one-shot
-// command buffers to be super inefficient. A one-shot command buffer will
-// never achieve the latency of a CUDA stream-like API but 99% of
-// application-level operations are copies that either could benefit from
-// batching or could be done unbatched at the queue level. The HAL as a whole
-// focuses more on the expensive parts of the program than the cheap/garbage-y
-// ones: it's better to have deterministically excellent performance on
-// critical workloads than low latency in naively authored applications. We do
-// upload block metadata and embedded storage data to devices as we record the
-// command buffer so that after recording completes we can launch the command
-// buffer immediately, but we cannot time-travel to begin executing while the
-// command buffer is still recording.
-//
-// One thing we could support is using SDMA to upload the command buffer
-// contents asynchronously with the recording. There are some notes sprinkled
-// around about supporting relocatable command buffers by storing only relative
-// offsets in all data. If command buffers _were_ relocatable we could record on
-// the host and broadcast to all devices with SDMA, or record to a single device
-// and use P2P SDMA to replicate it.
-//
-// Thread-compatible during recording and thread-safe once finalized. Multiple
-// threads are allowed to submit the command buffer for execution concurrently.
-typedef struct iree_hal_amdgpu_command_buffer_t {
-  iree_hal_command_buffer_t base;
-  iree_allocator_t host_allocator;
-
-  // Bitmap of physical devices the command buffer has been prepared for.
-  // Not all devices may have a copy of the command buffer. Population count
-  // denotes the number of devices.
-  iree_hal_amdgpu_device_affinity_t device_affinity;
-
-  // Maximum kernarg capacity required to execute any block in the command
-  // buffer, in bytes.
-  iree_host_size_t max_kernarg_capacity;
-
-  // State used only during recording.
-  struct {
-    // Arena used during recording. Reset after recording completes.
-    // Blocks are allocated from the host pool.
-    iree_arena_allocator_t host_arena;
-
-    // The last executable referenced by a dispatch. Since a large majority of
-    // the time there is a single executable we fast-path this to avoid
-    // thrashing the resource set LRU cache and growing the resource set
-    // allocation/cleanup time.
-    iree_hal_executable_t* last_executable;
-
-    // Active command buffer block encoder for all devices if the command buffer
-    // is recording. Encoders are acquired from the small block pool when
-    // recording begins and discarded upon finalization.
-    iree_hal_amdgpu_command_encoder_t* encoder;
-  } recording_state;
-
-  // Retains references to any HAL resources used by the command buffer.
-  // Only one reference is needed regardless of the number of devices the
-  // command buffer has been instantiated on.
-  iree_hal_resource_set_t* resource_set;
-
-  // Compacted list of device-side copies of the command buffer.
-  // Only those device ordinals specified by the device_affinity bitmap are
-  // present. A device affinity of 0b110 would lead to two device command
-  // buffers in the list at [0] and [1].
-  struct {
-    // Device-side command buffer descriptor used to launch execution.
-    // Only allocated after recording has ended. Lives within the metadata pool.
-    IREE_AMDGPU_DEVICE_PTR iree_hal_amdgpu_device_command_buffer_t* handle;
-    // Pool used for the device-side command blocks.
-    iree_hal_amdgpu_block_pool_t* cmd_block_pool;
-    // List of blocks containing commands allocated from the device block pool.
-    iree_hal_amdgpu_block_t* cmd_block_head;
-    // Pool used for the device-side metadata like the command buffer and CFG
-    // blocks that is usually smaller than the storage pool meant for bulk data.
-    iree_hal_amdgpu_block_pool_t* metadata_pool;
-    // List of blocks allocated from the device block pool storing the
-    // device-side command buffer metadata.
-    iree_hal_amdgpu_block_t* metadata_head;
-    // Device block pool used for command buffer storage. Must be large enough
-    // to contain any single dynamic allocation (64KB for updates) and
-    // determines the maximum number of commands per command buffer block.
-    iree_hal_amdgpu_block_pool_t* storage_pool;
-    // List of blocks allocated from the device block pool storing the
-    // device-side command buffer embedded data.
-    iree_hal_amdgpu_block_t* storage_head;
-  } device_state[/*popcnt(device_affinity)*/];
-} iree_hal_amdgpu_command_buffer_t;
-
-static const iree_hal_command_buffer_vtable_t
-    iree_hal_amdgpu_command_buffer_vtable;
-
-static iree_hal_amdgpu_command_buffer_t* iree_hal_amdgpu_command_buffer_cast(
-    iree_hal_command_buffer_t* base_value) {
-  IREE_HAL_ASSERT_TYPE(base_value, &iree_hal_amdgpu_command_buffer_vtable);
-  return (iree_hal_amdgpu_command_buffer_t*)base_value;
-}
-
-iree_status_t iree_hal_amdgpu_command_buffer_create(
-    const iree_hal_amdgpu_command_buffer_options_t* options,
-    iree_allocator_t host_allocator,
-    iree_hal_command_buffer_t** out_command_buffer) {
-  IREE_ASSERT_ARGUMENT(options);
-  IREE_ASSERT_ARGUMENT(out_command_buffer);
-  IREE_TRACE_ZONE_BEGIN(z0);
-  *out_command_buffer = NULL;
-
-  // Verify that the provided options are supported across all target devices.
-  IREE_RETURN_AND_END_ZONE_IF_ERROR(
-      z0, iree_hal_amdgpu_command_buffer_options_verify(options));
-
-  // Allocate command buffer storage as:
-  // { command buffer, device_state[], validation_state }
-  iree_hal_amdgpu_command_buffer_t* command_buffer = NULL;
-  const iree_host_size_t validation_state_size =
-      iree_hal_command_buffer_validation_state_size(options->mode,
-                                                    options->binding_capacity);
-  const iree_host_size_t device_count =
-      iree_hal_amdgpu_device_affinity_count(options->device_affinity);
-  const iree_host_size_t total_size =
-      sizeof(*command_buffer) +
-      device_count * sizeof(command_buffer->device_state[0]) +
-      validation_state_size;
-  IREE_RETURN_AND_END_ZONE_IF_ERROR(
-      z0, iree_allocator_malloc(host_allocator, total_size,
-                                (void**)&command_buffer));
-  iree_hal_command_buffer_initialize(
-      options->device_allocator, options->mode, options->command_categories,
-      options->queue_affinity, options->binding_capacity,
-      (uint8_t*)command_buffer + total_size - validation_state_size,
-      &iree_hal_amdgpu_command_buffer_vtable, &command_buffer->base);
-  command_buffer->host_allocator = host_allocator;
-  command_buffer->device_affinity = options->device_affinity;
-  for (iree_host_size_t i = 0; i < device_count; ++i) {
-    command_buffer->device_state[i].cmd_block_pool =
-        &options->device_block_pools[i]->large;
-    command_buffer->device_state[i].metadata_pool =
-        &options->device_block_pools[i]->small;
-    command_buffer->device_state[i].storage_pool =
-        &options->device_block_pools[i]->large;
-  }
-
-  // Setup initial recording state while we have all of the options available.
-  // If we supported re-recording we'd need to save this information so that we
-  // can reinitialize the state in _begin.
-  iree_arena_initialize(&options->host_block_pools->large,
-                        &command_buffer->recording_state.host_arena);
-  iree_status_t status = iree_hal_amdgpu_command_encoder_initialize(
-      options, device_count, &command_buffer->recording_state.host_arena,
-      &command_buffer->recording_state.encoder);
-
-  // Allocate resource set from a host block pool.
-  // We expect to have a small number of resources but this may not be the case:
-  // we may have 1 executable if every binding is indirect or 1000 buffers if
-  // the user is silly and passed in 1000 buffers. We err on the side of not
-  // wasting so much memory right now with the intent that everyone should be
-  // running with reusable command buffers if they care. If we did go larger we
-  // would want to repack/trim the resource set when freezing (if not one-shot
-  // where it doesn't matter). The risk with large allocs is that a user with
-  // 10000 reusable command buffers will eat all that memory forever.
-  if (iree_status_is_ok(status) &&
-      !iree_all_bits_set(options->mode,
-                         IREE_HAL_COMMAND_BUFFER_MODE_UNRETAINED)) {
-    status = iree_hal_resource_set_allocate(&options->host_block_pools->small,
-                                            &command_buffer->resource_set);
-  }
-
-  if (iree_status_is_ok(status)) {
-    *out_command_buffer = &command_buffer->base;
-  } else {
-    iree_hal_command_buffer_release(&command_buffer->base);
-  }
-  IREE_TRACE_ZONE_END(z0);
-  return status;
-}
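-
-// For illustration only - a minimal single-device recording sequence using
-// the APIs in this file (pool/allocator setup elided; variables other than
-// the options are assumed to exist):
-//
-//   iree_hal_amdgpu_command_buffer_options_t options;
-//   iree_hal_amdgpu_command_buffer_options_initialize(
-//       device_allocator, IREE_HAL_COMMAND_BUFFER_MODE_ONE_SHOT,
-//       IREE_HAL_COMMAND_CATEGORY_ANY, queue_affinity,
-//       /*binding_capacity=*/0, &options);
-//   // ... assign options.device_affinity, options.host_block_pools, and
-//   // options.device_block_pools ...
-//   iree_hal_command_buffer_t* command_buffer = NULL;
-//   IREE_RETURN_IF_ERROR(iree_hal_amdgpu_command_buffer_create(
-//       &options, host_allocator, &command_buffer));
-//   IREE_RETURN_IF_ERROR(iree_hal_command_buffer_begin(command_buffer));
-//   // ... record fill/copy/dispatch commands ...
-//   IREE_RETURN_IF_ERROR(iree_hal_command_buffer_end(command_buffer));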
-
-static void iree_hal_amdgpu_command_buffer_destroy(
-    iree_hal_command_buffer_t* base_command_buffer) {
-  iree_hal_amdgpu_command_buffer_t* command_buffer =
-      iree_hal_amdgpu_command_buffer_cast(base_command_buffer);
-  iree_allocator_t host_allocator = command_buffer->host_allocator;
-  IREE_TRACE_ZONE_BEGIN(z0);
-
-  // Release all device resources (if we completed recording).
-  const iree_host_size_t device_count =
-      iree_hal_amdgpu_device_affinity_count(command_buffer->device_affinity);
-  for (iree_host_size_t i = 0; i < device_count; ++i) {
-    iree_hal_amdgpu_block_pool_release_list(
-        command_buffer->device_state[i].cmd_block_pool,
-        command_buffer->device_state[i].cmd_block_head);
-    iree_hal_amdgpu_block_pool_release_list(
-        command_buffer->device_state[i].metadata_pool,
-        command_buffer->device_state[i].metadata_head);
-    iree_hal_amdgpu_block_pool_release_list(
-        command_buffer->device_state[i].storage_pool,
-        command_buffer->device_state[i].storage_head);
-  }
-
-  // Recording state should have been cleaned up but may not be if there was a
-  // failure during recording.
-  iree_hal_amdgpu_command_encoder_deinitialize(
-      command_buffer->recording_state.encoder);
-  iree_arena_deinitialize(&command_buffer->recording_state.host_arena);
-
-  // Release all resources the command buffer is retaining (buffers,
-  // executables, etc) and return the resource set memory back to the host pool.
-  iree_hal_resource_set_free(command_buffer->resource_set);
-
-  iree_allocator_free(host_allocator, command_buffer);
-
-  IREE_TRACE_ZONE_END(z0);
-}
-
-bool iree_hal_amdgpu_command_buffer_isa(
-    iree_hal_command_buffer_t* command_buffer) {
-  return iree_hal_resource_is(&command_buffer->resource,
-                              &iree_hal_amdgpu_command_buffer_vtable);
-}
-
-iree_status_t iree_hal_amdgpu_command_buffer_query_execution_state(
-    iree_hal_command_buffer_t* base_command_buffer,
-    iree_host_size_t device_ordinal,
-    IREE_AMDGPU_DEVICE_PTR iree_hal_amdgpu_device_command_buffer_t**
-        out_device_command_buffer,
-    iree_host_size_t* out_max_kernarg_capacity) {
-  iree_hal_amdgpu_command_buffer_t* command_buffer =
-      iree_hal_amdgpu_command_buffer_cast(base_command_buffer);
-
-  const int device_count =
-      iree_hal_amdgpu_device_affinity_count(command_buffer->device_affinity);
-  if (IREE_UNLIKELY(device_ordinal >= device_count)) {
-    return iree_make_status(
-        IREE_STATUS_OUT_OF_RANGE,
-        "device ordinal %" PRIhsz
-        " out of range; command buffer was allocated for %d devices",
-        device_ordinal, device_count);
-  }
-
-  iree_hal_amdgpu_device_command_buffer_t* device_command_buffer =
-      command_buffer->device_state[device_ordinal].handle;
-  if (IREE_UNLIKELY(device_command_buffer == NULL)) {
-    return iree_make_status(
-        IREE_STATUS_FAILED_PRECONDITION,
-        "command buffer was not recorded for device ordinal %" PRIhsz
-        "; queue affinity provided during construction must include all "
-        "devices the command buffer may be executed on",
-        device_ordinal);
-  }
-
-  *out_device_command_buffer = device_command_buffer;
-  *out_max_kernarg_capacity = command_buffer->max_kernarg_capacity;
-
-  return iree_ok_status();
-}
-
-static iree_status_t iree_hal_amdgpu_command_buffer_begin(
-    iree_hal_command_buffer_t* base_command_buffer) {
-  iree_hal_amdgpu_command_buffer_t* command_buffer =
-      iree_hal_amdgpu_command_buffer_cast(base_command_buffer);
-
-  // Today we only support a single recording. If we want to allow re-recording
-  // (we don't) we'd need to preserve enough information to initialize the
-  // encoder again.
-  iree_hal_amdgpu_command_encoder_t* encoder =
-      command_buffer->recording_state.encoder;
-  if (IREE_UNLIKELY(!encoder)) {
-    return iree_make_status(IREE_STATUS_FAILED_PRECONDITION,
-                            "command buffers can only be recorded once");
-  }
-
-  // Begin a new block.
-  uint32_t block_ordinal = 0;
-  IREE_RETURN_IF_ERROR(
-      iree_hal_amdgpu_command_encoder_begin_block(encoder, &block_ordinal));
-  IREE_ASSERT_EQ(block_ordinal, 0, "must start at block 0");
-
-  return iree_ok_status();
-}
-
-// Finalizes command buffer recording by uploading metadata structures and
-// possibly compacting command storage. Must be called after all blocks have
-// completed recording.
-static iree_status_t iree_hal_amdgpu_command_buffer_finalize(
-    iree_hal_amdgpu_command_buffer_t* command_buffer) {
-  IREE_ASSERT(command_buffer);
-  IREE_TRACE_ZONE_BEGIN(z0);
-
-  iree_hal_amdgpu_command_encoder_t* encoder =
-      command_buffer->recording_state.encoder;
-  if (IREE_UNLIKELY(encoder->in_block)) {
-    IREE_TRACE_ZONE_END(z0);
-    return iree_make_status(IREE_STATUS_FAILED_PRECONDITION,
-                            "last block was not ended prior to finalizing; "
-                            "all recording must have completed");
-  }
-
-  // Capture host-side metadata used for submissions.
-  command_buffer->max_kernarg_capacity = encoder->peak_kernarg_size;
-
-  // Capture device-side metadata used for scheduling.
-  for (iree_host_size_t i = 0; i < encoder->device_count; ++i) {
-    iree_hal_amdgpu_device_encoder_state_t* device_state =
-        &encoder->device_state[i];
-
-    // Allocate the command buffer wrapper now that we know the final block
-    // count.
-    //
-    // TODO(benvanik): this _may_ allocate a new block just for the metadata and
-    // that's unfortunate: we should do something better to avoid 1% utilization
-    // of a new block.
-    IREE_AMDGPU_DEVICE_PTR iree_hal_amdgpu_device_command_buffer_t* handle =
-        NULL;
-    const iree_host_size_t handle_size =
-        sizeof(*handle) + encoder->block_count * sizeof(handle->blocks[0]);
-    IREE_RETURN_AND_END_ZONE_IF_ERROR(
-        z0, iree_hal_amdgpu_block_arena_allocate(
-                &device_state->metadata_arena, handle_size, (void**)&handle));
-    handle->max_kernarg_capacity = encoder->peak_kernarg_size;
-    handle->block_count = encoder->block_count;
-    command_buffer->device_state[i].handle = handle;
-
-    // Take ownership of the command block linked list and associated metadata.
-    // After this point the command buffer will be responsible for cleaning up
-    // the block storage when needed.
-    iree_hal_amdgpu_block_t* cmd_head = device_state->cmd_block.head;
-    device_state->cmd_block.head = NULL;
-    device_state->cmd_block.tail = NULL;
-    IREE_ASSERT(!device_state->cmd_block.current);
-    command_buffer->device_state[i].cmd_block_head = cmd_head;
-    iree_hal_amdgpu_block_t* cmd_block = cmd_head;
-    for (iree_host_size_t j = 0; j < encoder->block_count;
-         ++j, cmd_block = cmd_block->next) {
-      handle->blocks[j] =
-          (iree_hal_amdgpu_device_command_block_t*)cmd_block->user_data[0];
-    }
-    command_buffer->device_state[i].metadata_head =
-        iree_hal_amdgpu_block_arena_release_blocks(
-            &device_state->metadata_arena);
-
-    // Take ownership of the embedded data storage linked list.
-    command_buffer->device_state[i].storage_head =
-        iree_hal_amdgpu_block_arena_release_blocks(
-            &device_state->storage_arena);
-  }
-
-  IREE_TRACE_ZONE_END(z0);
-  return iree_ok_status();
-}
-
-static iree_status_t iree_hal_amdgpu_command_buffer_end(
-    iree_hal_command_buffer_t* base_command_buffer) {
-  iree_hal_amdgpu_command_buffer_t* command_buffer =
-      iree_hal_amdgpu_command_buffer_cast(base_command_buffer);
-
-  iree_hal_amdgpu_command_encoder_t* encoder =
-      command_buffer->recording_state.encoder;
-
-  // Return from the command buffer.
-  // A barrier is implicit and this will also take care of any pending barrier
-  // bit if one is set. This _should_ always fit in the current block without
-  // needing a split as we reserve one command slot for terminators while
-  // recording.
-  iree_hal_amdgpu_device_cmd_return_t** device_cmds = NULL;
-  uint32_t kernarg_offset = 0;
-  IREE_RETURN_IF_ERROR(iree_hal_amdgpu_command_encoder_append_cmd(
-      encoder, IREE_HAL_AMDGPU_DEVICE_CMD_RETURN,
-      IREE_HAL_AMDGPU_DEVICE_CMD_FLAG_QUEUE_AWAIT_BARRIER,
-      IREE_HAL_AMDGPU_DEVICE_CMD_RETURN_AQL_PACKET_COUNT,
-      IREE_HAL_AMDGPU_DEVICE_CMD_RETURN_KERNARG_SIZE,
-      IREE_HAL_AMDGPU_DEVICE_CMD_RETURN_KERNARG_ALIGNMENT,
-      (void***)&device_cmds, &kernarg_offset));
-  for (iree_host_size_t i = 0; i < encoder->device_count; ++i) {
-    device_cmds[i]->kernarg_offset = kernarg_offset;
-  }
-
-  // Flush the current block and reset the encoder.
-  IREE_RETURN_IF_ERROR(iree_hal_amdgpu_command_encoder_end_block(encoder));
-
-  // Finalize the command buffer and upload metadata to all devices.
-  IREE_RETURN_IF_ERROR(iree_hal_amdgpu_command_buffer_finalize(command_buffer));
-
-  // Deallocate the block encoders by resetting the arena they were allocated
-  // from and cleaning up other host-only recording state.
-  iree_hal_amdgpu_command_encoder_deinitialize(encoder);
-  iree_arena_deinitialize(&command_buffer->recording_state.host_arena);
-  memset(&command_buffer->recording_state, 0,
-         sizeof(command_buffer->recording_state));
-
-  // Freeze the resource set as all resources have been added and it is now
-  // immutable. It can be nested within another resource set if it needs
-  // extension.
-  iree_hal_resource_set_freeze(command_buffer->resource_set);
-
-  return iree_ok_status();
-}
-
-static iree_status_t iree_hal_amdgpu_command_buffer_begin_debug_group(
-    iree_hal_command_buffer_t* base_command_buffer, iree_string_view_t label,
-    iree_hal_label_color_t label_color,
-    const iree_hal_label_location_t* location) {
-  // TODO(benvanik): if we route these to tooling/profilers we'll want to enable
-  // them with some flag bits.
-#if !(IREE_TRACING_FEATURES & IREE_TRACING_FEATURE_INSTRUMENTATION_DEVICE)
-  // Tracing disabled - this wouldn't go anywhere.
-  return iree_ok_status();
-#endif  // !IREE_TRACING_FEATURE_INSTRUMENTATION_DEVICE
-
-  iree_hal_amdgpu_command_buffer_t* command_buffer =
-      iree_hal_amdgpu_command_buffer_cast(base_command_buffer);
-
-  // TODO(benvanik): intern label and source location.
-  // Tracy requires that the pointers remain valid until the termination of
-  // the program _or_ that they be dynamically allocated for single use. The
-  // issue is that with reusable command buffers we know neither the lifetime
-  // of the input we were given nor how many times the command buffer will be
-  // used.
-  //
-  // One solution would be a big hash map. That adds overhead during recording
-  // but is free from then on, and for the majority of our command buffers that
-  // are reusable it will be in the noise. The set of unique source locations
-  // and labels should be fairly small (we don't allow dynamic values) and
-  // these debug groups aren't super common today.
-  const char* label_literal = "todo_label_interning";
-  static iree_hal_amdgpu_trace_src_loc_t dummy_src_loc = {
-      .name = NULL,
-      .function = NULL,
-      .file = NULL,
-      .line = 0,
-      .color = 0,
-  };
-  iree_hal_amdgpu_trace_src_loc_ptr_t src_loc =
-      (iree_hal_amdgpu_trace_src_loc_ptr_t)&dummy_src_loc;
-
-  iree_hal_amdgpu_command_encoder_t* encoder =
-      command_buffer->recording_state.encoder;
-
-  // Begin scope with full barrier so that all prior work completes before any
-  // nested commands begin execution. This needs us to set the flag as if an
-  // execution barrier had been scheduled so that the next command recorded
-  // waits.
-  iree_hal_amdgpu_device_cmd_debug_group_begin_t** device_cmds = NULL;
-  IREE_RETURN_IF_ERROR(iree_hal_amdgpu_command_encoder_append_cmd(
-      encoder, IREE_HAL_AMDGPU_DEVICE_CMD_DEBUG_GROUP_BEGIN,
-      IREE_HAL_AMDGPU_DEVICE_CMD_FLAG_QUEUE_AWAIT_BARRIER,
-      IREE_HAL_AMDGPU_DEVICE_CMD_DEBUG_GROUP_BEGIN_AQL_PACKET_COUNT,
-      /*kernarg_size=*/0, /*kernarg_alignment=*/0, (void***)&device_cmds,
-      NULL));
-
-  // Loop over each device this command buffer is being broadcast to and apply
-  // the per-device information.
-  for (iree_host_size_t i = 0; i < encoder->device_count; ++i) {
-    iree_hal_amdgpu_device_cmd_debug_group_begin_t* cmd = device_cmds[i];
-    cmd->src_loc = src_loc;
-    cmd->label_literal = (uint64_t)label_literal;
-    cmd->label_literal_length = label.size;
-    cmd->color = label_color.value;
-  }
-
-  return iree_ok_status();
-}
-
-static iree_status_t iree_hal_amdgpu_command_buffer_end_debug_group(
-    iree_hal_command_buffer_t* base_command_buffer) {
-#if !(IREE_TRACING_FEATURES & IREE_TRACING_FEATURE_INSTRUMENTATION_DEVICE)
-  // Tracing disabled - this wouldn't go anywhere.
-  return iree_ok_status();
-#endif  // !IREE_TRACING_FEATURE_INSTRUMENTATION_DEVICE
-  iree_hal_amdgpu_command_buffer_t* command_buffer =
-      iree_hal_amdgpu_command_buffer_cast(base_command_buffer);
-
-  iree_hal_amdgpu_command_encoder_t* encoder =
-      command_buffer->recording_state.encoder;
-
-  // End scope with full barrier so that any nested commands complete before
-  // capturing.
-  iree_hal_amdgpu_device_cmd_debug_group_end_t** device_cmds = NULL;
-  IREE_RETURN_IF_ERROR(iree_hal_amdgpu_command_encoder_append_cmd(
-      encoder, IREE_HAL_AMDGPU_DEVICE_CMD_DEBUG_GROUP_END,
-      IREE_HAL_AMDGPU_DEVICE_CMD_FLAG_QUEUE_AWAIT_BARRIER,
-      IREE_HAL_AMDGPU_DEVICE_CMD_DEBUG_GROUP_END_AQL_PACKET_COUNT,
-      /*kernarg_size=*/0, /*kernarg_alignment=*/0, (void***)&device_cmds,
-      NULL));
-
-  return iree_ok_status();
-}
-
-static iree_status_t iree_hal_amdgpu_command_buffer_execution_barrier(
-    iree_hal_command_buffer_t* base_command_buffer,
-    iree_hal_execution_stage_t source_stage_mask,
-    iree_hal_execution_stage_t target_stage_mask,
-    iree_hal_execution_barrier_flags_t flags,
-    iree_host_size_t memory_barrier_count,
-    const iree_hal_memory_barrier_t* memory_barriers,
-    iree_host_size_t buffer_barrier_count,
-    const iree_hal_buffer_barrier_t* buffer_barriers) {
-  iree_hal_amdgpu_command_buffer_t* command_buffer =
-      iree_hal_amdgpu_command_buffer_cast(base_command_buffer);
-
-  iree_hal_amdgpu_command_encoder_t* encoder =
-      command_buffer->recording_state.encoder;
-
-  // Force the next command to have a barrier flag on it and wait until all
-  // prior commands have completed.
-  encoder->barrier_pending = 1;
-
-  // TODO(benvanik): emit a barrier command if needed.
-  // I'm not sure one is needed today but we have it encoded for cases where we
-  // can't carry barrier bits across (between blocks/etc).
-  //
-  // iree_hal_amdgpu_device_cmd_barrier_t** device_cmds = NULL;
-  // IREE_RETURN_IF_ERROR(iree_hal_amdgpu_command_encoder_append_cmd(
-  //     command_buffer->encoder, IREE_HAL_AMDGPU_DEVICE_CMD_BARRIER,
-  //     IREE_HAL_AMDGPU_DEVICE_CMD_FLAG_QUEUE_AWAIT_BARRIER,
-  //     IREE_HAL_AMDGPU_DEVICE_CMD_BARRIER_AQL_PACKET_COUNT,
-  //     /*kernarg_size=*/0, /*kernarg_alignment=*/0, (void***)&device_cmds,
-  //     NULL));
-  //
-  // The case where this would be most helpful is on fork/join of multiple
-  // commands. If we want to have an acquire, multiple dispatches, and release
-  // we'd need the barrier packets to anchor those scope operations on.
-  // The command buffer API doesn't let us predict subsequent commands on
-  // barriers and without fixup we don't know when it's worth adding the extra
-  // packet (which introduces additional issue latency). This is a place where
-  // the command buffer API could be improved around fork/join subgraphs.
-
-  return iree_ok_status();
-}
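-
-// Illustrative sketch (describing the intended effect above, not additional
-// behavior): the pending barrier bit makes the next recorded command await
-// all prior commands, e.g.
-//   iree_hal_command_buffer_execution_barrier(cb, ...);  // barrier_pending=1
-//   iree_hal_command_buffer_dispatch(cb, ...);
-//   // ^ recorded with IREE_HAL_AMDGPU_DEVICE_CMD_FLAG_QUEUE_AWAIT_BARRIER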
-
-static iree_status_t iree_hal_amdgpu_command_buffer_signal_event(
-    iree_hal_command_buffer_t* base_command_buffer, iree_hal_event_t* event,
-    iree_hal_execution_stage_t source_stage_mask) {
-  iree_hal_amdgpu_command_buffer_t* command_buffer =
-      iree_hal_amdgpu_command_buffer_cast(base_command_buffer);
-
-  // TODO(benvanik): WIP API and may change; signals the given event allowing
-  // waiters to proceed.
-  (void)command_buffer;
-  iree_status_t status =
-      iree_make_status(IREE_STATUS_UNIMPLEMENTED, "events not implemented");
-
-  return status;
-}
-
-static iree_status_t iree_hal_amdgpu_command_buffer_reset_event(
-    iree_hal_command_buffer_t* base_command_buffer, iree_hal_event_t* event,
-    iree_hal_execution_stage_t source_stage_mask) {
-  iree_hal_amdgpu_command_buffer_t* command_buffer =
-      iree_hal_amdgpu_command_buffer_cast(base_command_buffer);
-
-  // TODO(benvanik): WIP API and may change; resets the given event to
-  // unsignaled.
-  (void)command_buffer;
-  iree_status_t status =
-      iree_make_status(IREE_STATUS_UNIMPLEMENTED, "events not implemented");
-
-  return status;
-}
-
-static iree_status_t iree_hal_amdgpu_command_buffer_wait_events(
-    iree_hal_command_buffer_t* base_command_buffer,
-    iree_host_size_t event_count, const iree_hal_event_t** events,
-    iree_hal_execution_stage_t source_stage_mask,
-    iree_hal_execution_stage_t target_stage_mask,
-    iree_host_size_t memory_barrier_count,
-    const iree_hal_memory_barrier_t* memory_barriers,
-    iree_host_size_t buffer_barrier_count,
-    const iree_hal_buffer_barrier_t* buffer_barriers) {
-  iree_hal_amdgpu_command_buffer_t* command_buffer =
-      iree_hal_amdgpu_command_buffer_cast(base_command_buffer);
-
-  // TODO(benvanik): WIP API and may change; waits on the list of events and
-  // enacts the specified set of barriers. Implementations without fine-grained
-  // tracking can treat this as an execution_barrier and ignore the
-  // memory/buffer barriers provided.
-  (void)command_buffer;
-  iree_status_t status =
-      iree_make_status(IREE_STATUS_UNIMPLEMENTED, "events not implemented");
-
-  return status;
-}
-
-static iree_status_t iree_hal_amdgpu_command_buffer_advise_buffer(
-    iree_hal_command_buffer_t* base_command_buffer,
-    iree_hal_buffer_ref_t buffer_ref, iree_hal_memory_advise_flags_t flags,
-    uint64_t arg0, uint64_t arg1) {
-  iree_hal_amdgpu_command_buffer_t* command_buffer =
-      iree_hal_amdgpu_command_buffer_cast(base_command_buffer);
-
-  // TODO(benvanik): WIP API and may change; this is likely to become an
-  // madvise-like command that can be used to control prefetching and other
-  // cache behavior. The current discard behavior is a hint that the buffer
-  // contents will never be used again and that if they are in a cache they need
-  // not be written back to global memory.
-  (void)command_buffer;
-
-  return iree_ok_status();
-}
-
-static iree_status_t iree_hal_amdgpu_command_buffer_fill_buffer(
-    iree_hal_command_buffer_t* base_command_buffer,
-    iree_hal_buffer_ref_t target_ref, const void* pattern,
-    iree_host_size_t pattern_length, iree_hal_fill_flags_t flags) {
-  iree_hal_amdgpu_command_buffer_t* command_buffer =
-      iree_hal_amdgpu_command_buffer_cast(base_command_buffer);
-
-  // Translate the target buffer to its device-side representation.
-  // Translation is currently device-independent.
-  iree_hal_amdgpu_device_buffer_ref_t resolved_target_ref = {0};
-  IREE_RETURN_IF_ERROR(iree_hal_amdgpu_translate_device_buffer_ref(
-                           target_ref, &resolved_target_ref),
-                       "resolving target_ref");
-
-  // Retain buffers (if specified) for the lifetime of the command buffer.
-  if (target_ref.buffer) {
-    iree_hal_resource_t* resources[] = {
-        (iree_hal_resource_t*)target_ref.buffer,
-    };
-    IREE_RETURN_IF_ERROR(iree_hal_resource_set_insert(
-        command_buffer->resource_set, IREE_ARRAYSIZE(resources), resources));
-  }
-
-  iree_hal_amdgpu_command_encoder_t* encoder =
-      command_buffer->recording_state.encoder;
-
-  // Append command to all devices.
-  iree_hal_amdgpu_device_cmd_fill_buffer_t** device_cmds = NULL;
-  uint32_t kernarg_offset = 0;
-  IREE_RETURN_IF_ERROR(iree_hal_amdgpu_command_encoder_append_cmd(
-      encoder, IREE_HAL_AMDGPU_DEVICE_CMD_FILL_BUFFER,
-      IREE_HAL_AMDGPU_DEVICE_CMD_FLAG_NONE,
-      IREE_HAL_AMDGPU_DEVICE_CMD_FILL_BUFFER_AQL_PACKET_COUNT,
-      IREE_HAL_AMDGPU_DEVICE_BUFFER_FILL_KERNARG_SIZE,
-      IREE_HAL_AMDGPU_DEVICE_BUFFER_FILL_KERNARG_ALIGNMENT,
-      (void***)&device_cmds, &kernarg_offset));
-
-  // Loop over each device this command buffer is being broadcast to and apply
-  // the per-device information.
-  for (iree_host_size_t i = 0; i < encoder->device_count; ++i) {
-    iree_hal_amdgpu_device_cmd_fill_buffer_t* cmd = device_cmds[i];
-    cmd->kernarg_offset = kernarg_offset;
-    cmd->target_ref = resolved_target_ref;
-    memcpy(&cmd->pattern, pattern, pattern_length);
-    cmd->pattern_length = pattern_length;
-  }
-
-  return iree_ok_status();
-}
-
-static iree_status_t iree_hal_amdgpu_command_buffer_update_buffer(
-    iree_hal_command_buffer_t* base_command_buffer, const void* source_buffer,
-    iree_host_size_t source_offset, iree_hal_buffer_ref_t target_ref,
-    iree_hal_update_flags_t flags) {
-  iree_hal_amdgpu_command_buffer_t* command_buffer =
-      iree_hal_amdgpu_command_buffer_cast(base_command_buffer);
-
-  // Translate the target buffer to its device-side representation.
-  // Translation is currently device-independent.
-  iree_hal_amdgpu_device_buffer_ref_t resolved_target_ref = {0};
-  IREE_RETURN_IF_ERROR(iree_hal_amdgpu_translate_device_buffer_ref(
-                           target_ref, &resolved_target_ref),
-                       "resolving target_ref");
-
-  // If the size exceeds the data block capacity we cannot allocate it - we
-  // could support oversized allocations but command buffers are designed to
-  // only contain small allocations and the device block pools should have a
-  // large enough size (we verify on construction).
-  const iree_device_size_t aligned_size =
-      iree_device_align(target_ref.length, iree_hal_amdgpu_max_align_t);
-  if (IREE_UNLIKELY(aligned_size >
-                    IREE_HAL_AMDGPU_COMMAND_BUFFER_MIN_STORAGE_BLOCK_SIZE)) {
-    return iree_make_status(
-        IREE_STATUS_FAILED_PRECONDITION,
-        "aligned size of embedded command buffer data %" PRIdsz
-        " exceeds minimum block pool storage capacity of %d per block",
-        aligned_size, IREE_HAL_AMDGPU_COMMAND_BUFFER_MIN_STORAGE_BLOCK_SIZE);
-  }
-
-  // Retain buffers (if specified) for the lifetime of the command buffer.
-  if (target_ref.buffer) {
-    iree_hal_resource_t* resources[] = {
-        (iree_hal_resource_t*)target_ref.buffer,
-    };
-    IREE_RETURN_IF_ERROR(iree_hal_resource_set_insert(
-        command_buffer->resource_set, IREE_ARRAYSIZE(resources), resources));
-  }
-
-  iree_hal_amdgpu_command_encoder_t* encoder =
-      command_buffer->recording_state.encoder;
-
-  // Append command to all devices.
-  iree_hal_amdgpu_device_cmd_copy_buffer_t** device_cmds = NULL;
-  uint32_t kernarg_offset = 0;
-  IREE_RETURN_IF_ERROR(iree_hal_amdgpu_command_encoder_append_cmd(
-      encoder, IREE_HAL_AMDGPU_DEVICE_CMD_COPY_BUFFER,
-      IREE_HAL_AMDGPU_DEVICE_CMD_FLAG_NONE,
-      IREE_HAL_AMDGPU_DEVICE_CMD_COPY_BUFFER_AQL_PACKET_COUNT,
-      IREE_HAL_AMDGPU_DEVICE_BUFFER_COPY_KERNARG_SIZE,
-      IREE_HAL_AMDGPU_DEVICE_BUFFER_COPY_KERNARG_ALIGNMENT,
-      (void***)&device_cmds, &kernarg_offset));
-
-  // Loop over each device this command buffer is being broadcast to and apply
-  // the per-device information.
-  for (iree_host_size_t i = 0; i < encoder->device_count; ++i) {
-    iree_hal_amdgpu_device_cmd_copy_buffer_t* cmd = device_cmds[i];
-    cmd->kernarg_offset = kernarg_offset;
-    cmd->target_ref = resolved_target_ref;
-
-    // Capture and embed the source data per-device.
-    //
-    // TODO(benvanik): hoist embedding to lead device if requested by
-    // IREE_HAL_AMDGPU_COMMAND_BUFFER_RECORDING_FLAG_DATA_ON_LEAD_PHYSICAL_DEVICE.
-    // That would reduce total memory consumption in multi-device cases at the
-    // cost of additional cross-device traffic when the command buffer is issued
-    // on other devices. I don't think we want that, but may be worthwhile in
-    // cases where most workloads are on the lead device, the lead device is in
-    // an independent power island, or the user _really_ cares about memory
-    // consumption of a lot of command buffers.
-    cmd->source_ref = (iree_hal_amdgpu_device_buffer_ref_t){
-        .type = IREE_HAL_AMDGPU_DEVICE_BUFFER_TYPE_PTR,
-        .offset = 0,
-        .length = resolved_target_ref.length,
-        .value.bits = 0,
-    };
-    IREE_RETURN_IF_ERROR(
-        iree_hal_amdgpu_command_encoder_emplace_data(
-            encoder, i,
-            iree_make_const_byte_span(
-                (const uint8_t*)source_buffer + source_offset,
-                resolved_target_ref.length),
-            &cmd->source_ref.value.ptr),
-        "embedding update data of %" PRIu64 "B",
-        (uint64_t)resolved_target_ref.length);
-  }
-
-  return iree_ok_status();
-}
-
-static iree_status_t iree_hal_amdgpu_command_buffer_copy_buffer(
-    iree_hal_command_buffer_t* base_command_buffer,
-    iree_hal_buffer_ref_t source_ref, iree_hal_buffer_ref_t target_ref,
-    iree_hal_copy_flags_t flags) {
-  iree_hal_amdgpu_command_buffer_t* command_buffer =
-      iree_hal_amdgpu_command_buffer_cast(base_command_buffer);
-
-  // Translate the buffers to their device-side representations. Translation is
-  // currently device-independent.
-  iree_hal_amdgpu_device_buffer_ref_t resolved_source_ref = {0};
-  IREE_RETURN_IF_ERROR(iree_hal_amdgpu_translate_device_buffer_ref(
-                           source_ref, &resolved_source_ref),
-                       "resolving source_ref");
-  iree_hal_amdgpu_device_buffer_ref_t resolved_target_ref = {0};
-  IREE_RETURN_IF_ERROR(iree_hal_amdgpu_translate_device_buffer_ref(
-                           target_ref, &resolved_target_ref),
-                       "resolving target_ref");
-
-  // Retain buffers (if specified) for the lifetime of the command buffer.
-  if (source_ref.buffer || target_ref.buffer) {
-    iree_hal_resource_t* resources[] = {
-        (iree_hal_resource_t*)source_ref.buffer,
-        (iree_hal_resource_t*)target_ref.buffer,
-    };
-    IREE_RETURN_IF_ERROR(iree_hal_resource_set_insert(
-        command_buffer->resource_set, IREE_ARRAYSIZE(resources), resources));
-  }
-
-  iree_hal_amdgpu_command_encoder_t* encoder =
-      command_buffer->recording_state.encoder;
-
-  // Append command to all devices.
-  iree_hal_amdgpu_device_cmd_copy_buffer_t** device_cmds = NULL;
-  uint32_t kernarg_offset = 0;
-  IREE_RETURN_IF_ERROR(iree_hal_amdgpu_command_encoder_append_cmd(
-      encoder, IREE_HAL_AMDGPU_DEVICE_CMD_COPY_BUFFER,
-      IREE_HAL_AMDGPU_DEVICE_CMD_FLAG_NONE,
-      IREE_HAL_AMDGPU_DEVICE_CMD_COPY_BUFFER_AQL_PACKET_COUNT,
-      IREE_HAL_AMDGPU_DEVICE_BUFFER_COPY_KERNARG_SIZE,
-      IREE_HAL_AMDGPU_DEVICE_BUFFER_COPY_KERNARG_ALIGNMENT,
-      (void***)&device_cmds, &kernarg_offset));
-
-  // Loop over each device this command buffer is being broadcast to and apply
-  // the per-device information.
-  for (iree_host_size_t i = 0; i < encoder->device_count; ++i) {
-    iree_hal_amdgpu_device_cmd_copy_buffer_t* cmd = device_cmds[i];
-    cmd->kernarg_offset = kernarg_offset;
-    cmd->source_ref = resolved_source_ref;
-    cmd->target_ref = resolved_target_ref;
-  }
-
-  return iree_ok_status();
-}
-
-static iree_status_t iree_hal_amdgpu_command_buffer_collective(
-    iree_hal_command_buffer_t* base_command_buffer, iree_hal_channel_t* channel,
-    iree_hal_collective_op_t op, uint32_t param, iree_hal_buffer_ref_t send_ref,
-    iree_hal_buffer_ref_t recv_ref, iree_device_size_t element_count) {
-  iree_hal_amdgpu_command_buffer_t* command_buffer =
-      iree_hal_amdgpu_command_buffer_cast(base_command_buffer);
-
-  // TODO(benvanik): perform the collective operation defined by op. See the
-  // headers for more information. The channel is fixed for a particular
-  // recording but note that either buffer may be a reference to a binding table
-  // slot in which case it will be provided during submission to a queue.
-  (void)command_buffer;
-  iree_status_t status = iree_make_status(IREE_STATUS_UNIMPLEMENTED,
-                                          "collectives not implemented");
-
-  return status;
-}
-
-// Records a direct or indirect dispatch on each physical device required.
-// If |is_indirect| is true the |workgroup_count| must contain a valid workgroup
-// count buffer reference and otherwise the static workgroup count dimensions
-// will be used.
-static iree_status_t iree_hal_amdgpu_command_buffer_dispatch(
-    iree_hal_command_buffer_t* base_command_buffer,
-    iree_hal_executable_t* executable,
-    iree_hal_executable_export_ordinal_t export_ordinal,
-    const iree_hal_dispatch_config_t config, iree_const_byte_span_t constants,
-    iree_hal_buffer_ref_list_t bindings, iree_hal_dispatch_flags_t flags) {
-  iree_hal_amdgpu_command_buffer_t* command_buffer =
-      iree_hal_amdgpu_command_buffer_cast(base_command_buffer);
-
-  if (iree_hal_dispatch_uses_custom_arguments(flags)) {
-    return iree_make_status(
-        IREE_STATUS_UNIMPLEMENTED,
-        "direct/indirect arguments are not supported in AMDGPU yet");
-  }
-
-  // Configure dispatch flags controlling how the scheduler interprets the
-  // command when issuing. We use one code path for static dispatches where
-  // the workgroup count is available either directly in the recorded command or
-  // via an indirect workgroup count buffer that is known static at the time the
-  // command buffer is issued. Dynamic indirect workgroup counts that may be
-  // populated by a prior dispatch/transfer in the same command buffer require
-  // some additional work.
-  iree_hal_amdgpu_device_dispatch_flags_t cmd_flags =
-      IREE_HAL_AMDGPU_DEVICE_DISPATCH_FLAG_NONE;
-  if (iree_any_bit_set(flags,
-                       IREE_HAL_DISPATCH_FLAG_STATIC_INDIRECT_PARAMETERS)) {
-    cmd_flags |= IREE_HAL_AMDGPU_DEVICE_DISPATCH_FLAG_INDIRECT_STATIC;
-  } else if (iree_hal_dispatch_uses_indirect_parameters(flags)) {
-    // Not static: the workgroup count may be produced by earlier commands in
-    // the same command buffer and must be re-read when the dispatch issues.
-    cmd_flags |= IREE_HAL_AMDGPU_DEVICE_DISPATCH_FLAG_INDIRECT_DYNAMIC;
-  }
-
-  // Resolve the workgroup count buffer to the thin buffer ref used by the
-  // device-side command buffer.
-  iree_hal_amdgpu_device_workgroup_count_t workgroup_count;
-  if (iree_hal_dispatch_uses_indirect_parameters(flags)) {
-    iree_hal_amdgpu_device_buffer_ref_t resolved_workgroup_count_ref = {0};
-    IREE_RETURN_IF_ERROR(
-        iree_hal_amdgpu_translate_device_buffer_ref(
-            config.workgroup_count_ref, &resolved_workgroup_count_ref),
-        "resolving indirect workgroups_ref");
-    if (IREE_UNLIKELY(resolved_workgroup_count_ref.length <
-                      sizeof(uint32_t[3]))) {
-      return iree_make_status(
-          IREE_STATUS_INVALID_ARGUMENT,
-          "indirect workgroup count buffer reference must be at least "
-          "uint32_t[3] (12B) but it resolved to %" PRIu64 "B",
-          (uint64_t)resolved_workgroup_count_ref.length);
-    }
-    workgroup_count = (iree_hal_amdgpu_device_workgroup_count_t){
-        .ref =
-            {
-                .type = resolved_workgroup_count_ref.type,
-                .offset = resolved_workgroup_count_ref.offset,
-                .value.bits = resolved_workgroup_count_ref.value.bits,
-            },
-    };
-  } else {
-    workgroup_count = (iree_hal_amdgpu_device_workgroup_count_t){
-        .dims =
-            {
-                config.workgroup_count[0],
-                config.workgroup_count[1],
-                config.workgroup_count[2],
-            },
-    };
-  }
-
-  // Lookup the kernel metadata in host memory so we don't risk crossing a bus.
-  // Validation is amortized across all target devices as all devices will have
-  // the same export (just different kernel object pointers).
-  const iree_hal_amdgpu_device_kernel_args_t* host_kernel_args = NULL;
-  IREE_RETURN_IF_ERROR(iree_hal_amdgpu_executable_lookup_kernel_args_for_host(
-      executable, export_ordinal, &host_kernel_args));
-  if (IREE_UNLIKELY(constants.data_length !=
-                    host_kernel_args->constant_count * sizeof(uint32_t))) {
-    return iree_make_status(IREE_STATUS_INVALID_ARGUMENT,
-                            "dispatch requires %" PRIhsz
-                            "B of constants but %" PRIhsz "B was provided",
-                            host_kernel_args->constant_count * sizeof(uint32_t),
-                            constants.data_length);
-  } else if (IREE_UNLIKELY(bindings.count != host_kernel_args->binding_count)) {
-    return iree_make_status(IREE_STATUS_INVALID_ARGUMENT,
-                            "dispatch requires %u bindings but %" PRIhsz
-                            " were provided",
-                            host_kernel_args->binding_count, bindings.count);
-  }
-
-  // Constant and binding capacity should have been checked when the executable
-  // was loaded. These limits ensure we don't blow the host stack while
-  // recording.
-  IREE_ASSERT_LE(
-      constants.data_length,
-      IREE_HAL_AMDGPU_MAX_DISPATCH_CONSTANT_COUNT * sizeof(uint32_t));
-  IREE_ASSERT_LE(bindings.count, IREE_HAL_AMDGPU_MAX_DISPATCH_BINDING_COUNT);
-
-  // If the kernarg size declared by the kernel is larger than we expect, it
-  // (likely) means there are implicit kernargs expected following the explicit
-  // kernargs we provide for bindings and constants. These are aligned to 8
-  // bytes and follow any padding after the explicit args. Technically the
-  // compiler will truncate the size to only those fields required but doing
-  // that logic during command issue is more expensive than just eating the
-  // extra few dozen bytes of kernarg space that's mostly filled with zeros, so
-  // we go to the max size we allow. This handling isn't great but does allow
-  // us to run HIP kernels that follow our ABI even if they use device library
-  // code that often pulls in these implicit args.
-  //
-  // See iree_amdgpu_kernel_implicit_args_t for more information on implicit
-  // args. Since most are constant we don't actually produce them here and leave
-  // that to the issue phase. Note that the implicit args contain the workgroup
-  // count and when indirect we need to update the implicit args along with the
-  // dispatch packet because they store redundant (to us/HIP) information. Yuck.
-  const uint16_t explicit_kernarg_size =
-      bindings.count * sizeof(uint64_t) +
-      iree_host_align(constants.data_length, 8);
-  const uint16_t implicit_kernarg_size =
-      host_kernel_args->kernarg_size - explicit_kernarg_size;
-  cmd_flags |= implicit_kernarg_size > 0
-                   ? IREE_HAL_AMDGPU_DEVICE_DISPATCH_FLAG_IMPLICIT_ARGS
-                   : 0;
-  const uint16_t kernarg_size =
-      implicit_kernarg_size > 0
-          ? explicit_kernarg_size +
-                iree_max(implicit_kernarg_size,
-                         IREE_AMDGPU_KERNEL_IMPLICIT_ARGS_SIZE)
-          : host_kernel_args->kernarg_size;
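-
-  // Worked example (values assumed for illustration): 3 bindings and 8B of
-  // constants give explicit_kernarg_size = 3*8 + align(8, 8) = 32B. A kernel
-  // declaring kernarg_size = 256B then has implicit_kernarg_size = 224B, the
-  // IMPLICIT_ARGS flag is set, and the dispatch reserves
-  //   32 + max(224, IREE_AMDGPU_KERNEL_IMPLICIT_ARGS_SIZE)
-  // bytes of kernarg space.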
-
-  // Resolve all bindings ahead of the per-device work.
-  // Today we do not have per-device resolution but could in the future if we
-  // wanted to support replication of buffers across devices such that they had
-  // unique resolved addresses.
-  //
-  // TODO(benvanik): add support for lead-physical-device mode:
-  // IREE_HAL_AMDGPU_COMMAND_BUFFER_RECORDING_FLAG_DATA_ON_LEAD_PHYSICAL_DEVICE.
-  // We'd produce this directly into the lead device memory (+ constants)
-  // outside of the loop and reference it instead of copying it to each device.
-  // A bulk of the command buffer memory is in constants/bindings so if there
-  // was ever a need for lead-device placement it'd be for this information. I'm
-  // not sure there's a need, though, so it's kept simple/time-efficient today.
-  iree_hal_amdgpu_device_buffer_ref_t* resolved_bindings =
-      (iree_hal_amdgpu_device_buffer_ref_t*)iree_alloca(
-          bindings.count * sizeof(iree_hal_amdgpu_device_buffer_ref_t));
-  for (iree_host_size_t i = 0; i < bindings.count; ++i) {
-    IREE_RETURN_IF_ERROR(
-        iree_hal_amdgpu_translate_device_buffer_ref(
-            bindings.values[i],
-            (iree_hal_amdgpu_device_buffer_ref_t*)&resolved_bindings[i]),
-        "resolving binding[%" PRIhsz "]", i);
-  }
-
-  // Insert indirect workgroup count buffers into the resource set.
-  if (iree_hal_dispatch_uses_indirect_parameters(flags)) {
-    IREE_RETURN_IF_ERROR(iree_hal_resource_set_insert(
-        command_buffer->resource_set, 1, &config.workgroup_count_ref.buffer));
-  }
-
-  // Insert all bindings into the resource set.
-  // We do this once regardless of how many devices we broadcast to.
-  IREE_RETURN_IF_ERROR(iree_hal_resource_set_insert_strided(
-      command_buffer->resource_set, bindings.count, bindings.values,
-      offsetof(iree_hal_buffer_ref_t, buffer), sizeof(iree_hal_buffer_ref_t)));
-
-  // Insert executable into the resource set.
-  // This will keep all device copies alive (including those for devices we
-  // aren't recording for - that's ok).
-  //
-  // NOTE: this is 99.9% of the time the same executable and to avoid thrashing
-  // the resource set LRU cache we special case checking for this exact case.
-  // This is a micro-optimization that mostly shows up in either smaller
-  // resource sets or faster cleanup time (fewer redundant references to the
-  // same executable).
-  if (executable != command_buffer->recording_state.last_executable) {
-    IREE_RETURN_IF_ERROR(iree_hal_resource_set_insert(
-        command_buffer->resource_set, 1, &executable));
-    command_buffer->recording_state.last_executable = executable;
-  }
-
-  iree_hal_amdgpu_command_encoder_t* encoder =
-      command_buffer->recording_state.encoder;
-
-  // Append command based on what type of dispatch it is.
-  iree_hal_amdgpu_device_cmd_dispatch_t** device_cmds = NULL;
-  uint32_t kernarg_offset = 0;
-  IREE_RETURN_IF_ERROR(iree_hal_amdgpu_command_encoder_append_cmd(
-      encoder,
-      iree_all_bits_set(cmd_flags,
-                        IREE_HAL_AMDGPU_DEVICE_DISPATCH_FLAG_INDIRECT_DYNAMIC)
-          ? IREE_HAL_AMDGPU_DEVICE_CMD_DISPATCH_INDIRECT_DYNAMIC
-          : IREE_HAL_AMDGPU_DEVICE_CMD_DISPATCH,
-      IREE_HAL_AMDGPU_DEVICE_CMD_FLAG_NONE,
-      IREE_HAL_AMDGPU_DEVICE_CMD_DISPATCH_AQL_PACKET_COUNT(cmd_flags),
-      kernarg_size, host_kernel_args->kernarg_alignment, (void***)&device_cmds,
-      &kernarg_offset));
-
-  // Loop over each device this command buffer is being broadcast to and apply
-  // the per-device information.
-  IREE_HAL_AMDGPU_FOR_PHYSICAL_DEVICE(command_buffer->device_affinity) {
-    iree_hal_amdgpu_device_cmd_dispatch_t* cmd = device_cmds[device_index];
-    cmd->kernarg_offset = kernarg_offset;
-
-    // Lookup the per-device kernel information. Each device has its own
-    // loaded kernel object pointer.
-    const iree_hal_amdgpu_device_kernel_args_t* device_kernel_args = NULL;
-    IREE_RETURN_IF_ERROR(
-        iree_hal_amdgpu_executable_lookup_kernel_args_for_device(
-            executable, export_ordinal, device_ordinal, &device_kernel_args));
-
-    cmd->config.flags = cmd_flags;
-    cmd->config.kernel_args = device_kernel_args;
-    memcpy(&cmd->config.workgroup_count, &workgroup_count,
-           sizeof(cmd->config.workgroup_count));
-
-    // Copy resolved bindings and constants into the embedded data block.
-    // Since this will be going across the PCI bus we do it as a memcpy in hopes
-    // of getting decent transaction overhead.
-    //
-    // The dispatch command does not require that the bindings and constants be
-    // adjacent in memory but we do so here to reduce overheads by only having
-    // one block growth check and range of memory being touched instead of two.
-    // In the future if we want to do things like share embedded data across
-    // devices to reduce memory consumption we could still shard the commands
-    // (as we have to for unique kernel ptrs) but could share these.
-    IREE_RETURN_IF_ERROR(iree_hal_amdgpu_command_encoder_emplace_data_concat(
-        encoder, device_index,
-        iree_make_const_byte_span(
-            resolved_bindings, bindings.count * sizeof(resolved_bindings[0])),
-        constants, (void**)&cmd->bindings, (void**)&cmd->constants));
-  }
-
-  return iree_ok_status();
-}
-
-static const iree_hal_command_buffer_vtable_t
-    iree_hal_amdgpu_command_buffer_vtable = {
-        .destroy = iree_hal_amdgpu_command_buffer_destroy,
-        .begin = iree_hal_amdgpu_command_buffer_begin,
-        .end = iree_hal_amdgpu_command_buffer_end,
-        .begin_debug_group = iree_hal_amdgpu_command_buffer_begin_debug_group,
-        .end_debug_group = iree_hal_amdgpu_command_buffer_end_debug_group,
-        .execution_barrier = iree_hal_amdgpu_command_buffer_execution_barrier,
-        .signal_event = iree_hal_amdgpu_command_buffer_signal_event,
-        .reset_event = iree_hal_amdgpu_command_buffer_reset_event,
-        .wait_events = iree_hal_amdgpu_command_buffer_wait_events,
-        .advise_buffer = iree_hal_amdgpu_command_buffer_advise_buffer,
-        .fill_buffer = iree_hal_amdgpu_command_buffer_fill_buffer,
-        .update_buffer = iree_hal_amdgpu_command_buffer_update_buffer,
-        .copy_buffer = iree_hal_amdgpu_command_buffer_copy_buffer,
-        .collective = iree_hal_amdgpu_command_buffer_collective,
-        .dispatch = iree_hal_amdgpu_command_buffer_dispatch,
-};
diff --git a/runtime/src/iree/hal/drivers/amdgpu/command_buffer.h b/runtime/src/iree/hal/drivers/amdgpu/command_buffer.h
deleted file mode 100644
index 31a2b01..0000000
--- a/runtime/src/iree/hal/drivers/amdgpu/command_buffer.h
+++ /dev/null
@@ -1,159 +0,0 @@
-// Copyright 2025 The IREE Authors
-//
-// Licensed under the Apache License v2.0 with LLVM Exceptions.
-// See https://llvm.org/LICENSE.txt for license information.
-// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
-
-#ifndef IREE_HAL_DRIVERS_AMDGPU_COMMAND_BUFFER_H_
-#define IREE_HAL_DRIVERS_AMDGPU_COMMAND_BUFFER_H_
-
-#include "iree/base/api.h"
-#include "iree/base/internal/arena.h"
-#include "iree/hal/api.h"
-#include "iree/hal/drivers/amdgpu/device/command_buffer.h"
-#include "iree/hal/drivers/amdgpu/util/affinity.h"
-
-typedef struct iree_hal_amdgpu_block_pools_t iree_hal_amdgpu_block_pools_t;
-
-//===----------------------------------------------------------------------===//
-// iree_hal_amdgpu_command_buffer_options_t
-//===----------------------------------------------------------------------===//
-
-// Determines where and how command buffers are recorded.
-typedef enum iree_hal_amdgpu_command_buffer_recording_flags_t {
-  IREE_HAL_AMDGPU_COMMAND_BUFFER_RECORDING_FLAG_NONE = 0u,
-
-  // TODO(benvanik): support lead-physical-device storage. This would need the
-  // block pool on the lead device to make its blocks accessible to all devices
-  // - today the block pool is device-local only. Produced data is immutable and
-  // PCIe atomics/coherency is not required across devices.
-  //
-  // Allocate embedded data on the lead physical device instead of on each
-  // device the command buffer is recorded for. This reduces overall memory
-  // consumption and recording time at the cost of cross-device transfers.
-  // IREE_HAL_AMDGPU_COMMAND_BUFFER_RECORDING_FLAG_DATA_ON_LEAD_PHYSICAL_DEVICE
-  // = 1u << 0,
-
-  // TODO(benvanik): support compaction. This would require changing the command
-  // buffer to use relative offsets for embedded data and a data table for
-  // indirecting so that we can move around base pointers. A fixup would be
-  // possible as well by launching a kernel that rebased the embedded pointers
-  // (though trickier). For now we assume the block pool block size is a big
-  // enough lever and most programs only use a handful of command buffers so
-  // the waste per command buffer is minimal (compared to a single layer weight
-  // in an ML model).
-  //
-  // Compacts the command buffer when recording ends by reallocating it to the
-  // precise size required and reuploads it to each device. This will return any
-  // block pool blocks back to their respective pool for reuse and ensure
-  // there's no unused device memory - the cost is extra host time to do the
-  // reallocation/copies.
-  // IREE_HAL_AMDGPU_COMMAND_BUFFER_RECORDING_FLAG_COMPACT_ON_FINALIZE
-  // = 1u << 1,
-} iree_hal_amdgpu_command_buffer_recording_flags_t;
-
-// TODO(benvanik): move this someplace common.
-//
-// Block pools for host memory blocks of various sizes.
-typedef struct iree_hal_amdgpu_host_block_pools_t {
-  // Used for small allocations of around 1-4KB.
-  iree_arena_block_pool_t small;
-  // Used for large page-sized allocations of 32-64kB.
-  iree_arena_block_pool_t large;
-} iree_hal_amdgpu_host_block_pools_t;
-
-// Minimum number of AQL packets in a single command buffer block.
-// Any fewer and it's not guaranteed a command buffer can complete execution.
-#define IREE_HAL_AMDGPU_COMMAND_BUFFER_MIN_BLOCK_AQL_PACKET_COUNT (16)
-
-// Maximum number of AQL packets in a single command buffer block.
-// This is currently limited by the `uint16_t packet_offset` in
-// iree_hal_amdgpu_device_cmd_header_t.
-//
-// TODO(benvanik): currently we also limit this by tracy's outstanding GPU event
-// limit. If we made our own timeline (which we really need to for concurrency)
-// then we could eliminate this artificial limit.
-#define IREE_HAL_AMDGPU_COMMAND_BUFFER_MAX_BLOCK_AQL_PACKET_COUNT            \
-  IREE_AMDGPU_MIN(IREE_HAL_AMDGPU_DEVICE_QUERY_RINGBUFFER_CAPACITY,          \
-                  (1u << sizeof(((iree_hal_amdgpu_device_cmd_header_t*)NULL) \
-                                    ->packet_offset) *                       \
-                             8))
-
-// Recording options for a command buffer.
-// Referenced data structures such as block pools must remain live for the
-// lifetime of the command buffer but the options struct and its storage (such
-// as the device block pool list) need not.
-typedef struct iree_hal_amdgpu_command_buffer_options_t {
-  iree_hal_allocator_t* device_allocator;
-  iree_hal_command_buffer_mode_t mode;
-  iree_hal_command_category_t command_categories;
-  iree_hal_queue_affinity_t queue_affinity;
-  iree_host_size_t binding_capacity;
-
-  // Controls recording behavior (placement, optimization, debugging, etc).
-  iree_hal_amdgpu_command_buffer_recording_flags_t recording_flags;
-
-  // Maximum number of AQL packets the command buffer is allowed to issue at
-  // a time. Must be at or under the HSA queue capacity of any execution queue
-  // the command buffer will be scheduled on. The command buffer may decide to
-  // use fewer packets.
-  iree_host_size_t block_aql_packet_count;
-
-  // Block pools for host-only (heap) memory blocks of various sizes.
-  iree_hal_amdgpu_host_block_pools_t* host_block_pools;
-
-  // Bitmap of physical devices that the command buffer will be recorded for.
-  // The command buffer can only be issued on these devices.
-  iree_hal_amdgpu_device_affinity_t device_affinity;
-
-  // Compact list of physical device block pools corresponding to the bits set
-  // in the device_affinity bitmap. A device affinity of 0b110 would lead to two
-  // device block pools in the list at [0] and [1].
-  //
-  // These pools should be allocated from coarse-grained memory as once we
-  // record command buffers we will never change them again and do not need any
-  // synchronization.
-  iree_hal_amdgpu_block_pools_t* const* device_block_pools /*[device_count]*/;
-} iree_hal_amdgpu_command_buffer_options_t;
-
-// Initializes |out_options| to its default values.
-void iree_hal_amdgpu_command_buffer_options_initialize(
-    iree_hal_allocator_t* device_allocator, iree_hal_command_buffer_mode_t mode,
-    iree_hal_command_category_t command_categories,
-    iree_hal_queue_affinity_t queue_affinity, iree_host_size_t binding_capacity,
-    iree_hal_amdgpu_command_buffer_options_t* out_options);
-
-// Verifies command buffer options to ensure they meet the requirements of the
-// devices the command buffer will be scheduled on.
-iree_status_t iree_hal_amdgpu_command_buffer_options_verify(
-    const iree_hal_amdgpu_command_buffer_options_t* options);
-
-//===----------------------------------------------------------------------===//
-// iree_hal_amdgpu_command_buffer_t
-//===----------------------------------------------------------------------===//
-
-// Creates an AMDGPU command buffer with the given |options| controlling how
-// it is recorded and prepared for execution.
-//
-// Referenced data structures in the options such as block pools must remain
-// live for the lifetime of the command buffer.
-iree_status_t iree_hal_amdgpu_command_buffer_create(
-    const iree_hal_amdgpu_command_buffer_options_t* options,
-    iree_allocator_t host_allocator,
-    iree_hal_command_buffer_t** out_command_buffer);
-
-// Returns true if |command_buffer| is an AMDGPU command buffer.
-bool iree_hal_amdgpu_command_buffer_isa(
-    iree_hal_command_buffer_t* command_buffer);
-
-// Queries the device-side command buffer representation for the GPU device
-// agent with |device_ordinal| in the system topology.
-// |out_max_kernarg_capacity| will be set to the minimum required kernarg
-// reservation used by any block in the command buffer.
-iree_status_t iree_hal_amdgpu_command_buffer_query_execution_state(
-    iree_hal_command_buffer_t* command_buffer, iree_host_size_t device_ordinal,
-    IREE_AMDGPU_DEVICE_PTR iree_hal_amdgpu_device_command_buffer_t**
-        out_device_command_buffer,
-    iree_host_size_t* out_max_kernarg_capacity);
-
-#endif  // IREE_HAL_DRIVERS_AMDGPU_COMMAND_BUFFER_H_
diff --git a/runtime/src/iree/hal/drivers/amdgpu/cts/backends.cc b/runtime/src/iree/hal/drivers/amdgpu/cts/backends.cc
index 1b05253..b2690b3 100644
--- a/runtime/src/iree/hal/drivers/amdgpu/cts/backends.cc
+++ b/runtime/src/iree/hal/drivers/amdgpu/cts/backends.cc
@@ -18,7 +18,7 @@
   iree_status_t status = iree_hal_amdgpu_driver_module_register(
       iree_hal_driver_registry_default());
   if (iree_status_is_already_exists(status)) {
-    (void)iree_status_consume_code(status);
+    iree_status_free(status);
     status = iree_ok_status();
   }
 
diff --git a/runtime/src/iree/hal/drivers/amdgpu/device/BUILD.bazel b/runtime/src/iree/hal/drivers/amdgpu/device/BUILD.bazel
index 92588af..2cc367a 100644
--- a/runtime/src/iree/hal/drivers/amdgpu/device/BUILD.bazel
+++ b/runtime/src/iree/hal/drivers/amdgpu/device/BUILD.bazel
@@ -4,9 +4,7 @@
 # See https://llvm.org/LICENSE.txt for license information.
 # SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
 
-load("//build_tools/bazel:build_defs.oss.bzl", "iree_runtime_cc_library")
-load("//build_tools/bazel:iree_amdgpu_binary.bzl", "iree_amdgpu_binary")
-load("//build_tools/embed_data:build_defs.bzl", "iree_c_embed_data")
+load("//build_tools/bazel:build_defs.oss.bzl", "iree_runtime_cc_library", "iree_runtime_cc_test")
 
 package(
     default_visibility = ["//visibility:public"],
@@ -14,19 +12,35 @@
     licenses = ["notice"],  # Apache 2.0
 )
 
-#===------------------------------------------------------------------------===#
-# Common sources
-#===------------------------------------------------------------------------===#
+DEVICE_SRCS = [
+    "blit.c",
+    "dispatch.c",
+    "timestamp.c",
+]
 
-BITCODE_SRCS = glob([
-    "*.c",
-    "support/*.c",
-])
+DEVICE_HDRS = [
+    "blit.h",
+    "dispatch.h",
+    "kernel_tables.h",
+    "kernels.h",
+    "support/common.h",
+    "support/kernel.h",
+    "support/queue.h",
+    "support/signal.h",
+    "timestamp.h",
+]
 
-BITCODE_HDRS = glob([
-    "*.h",
-    "support/*.h",
-])
+filegroup(
+    name = "bitcode_srcs",
+    srcs = DEVICE_SRCS,
+    visibility = ["//runtime/src/iree/hal/drivers/amdgpu/device/binaries:__pkg__"],
+)
+
+filegroup(
+    name = "bitcode_hdrs",
+    srcs = DEVICE_HDRS,
+    visibility = ["//runtime/src/iree/hal/drivers/amdgpu/device/binaries:__pkg__"],
+)
 
 #===------------------------------------------------------------------------===#
 # Exported Headers
@@ -34,41 +48,91 @@
 
 iree_runtime_cc_library(
     name = "headers",
-    hdrs = BITCODE_HDRS,
-)
-
-#===------------------------------------------------------------------------===#
-# Architecture-specific Binaries
-#===------------------------------------------------------------------------===#
-# NOTE: the naming here matches what HSA_ISA_INFO_NAME returns so that we can
-# match them at runtime without having to load and reflect each code object.
-
-# TODO(benvanik): when TheRock stabilizes its naming convention we'll want to
-# copy that and make it configurable. See:
-# https://github.com/ROCm/TheRock/blob/main/cmake/therock_amdgpu_targets.cmake
-# Matching their family naming scheme would allow us to directly source from
-# their command line arguments. How best to map this to bazel I don't know, so
-# for now we include a hand-picked set that people using bazel request.
-
-iree_amdgpu_binary(
-    name = "amdgcn-amd-amdhsa--gfx1100",
-    srcs = BITCODE_SRCS,
-    arch = "gfx1100",
-    internal_hdrs = BITCODE_HDRS,
-    target = "amdgcn-amd-amdhsa",
-)
-
-#===------------------------------------------------------------------------===#
-# Embedded Binary Table
-#===------------------------------------------------------------------------===#
-
-iree_c_embed_data(
-    name = "binaries",
-    srcs = [
-        ":amdgcn-amd-amdhsa--gfx1100.so",
+    hdrs = DEVICE_HDRS,
+    deps = [
+        "//runtime/src/iree/base",
+        "//runtime/src/iree/base/internal",
+        "//runtime/src/iree/base/threading:thread",
+        "//runtime/src/iree/hal/drivers/amdgpu/abi",
     ],
-    c_file_output = "binaries.c",
-    flatten = True,
-    h_file_output = "binaries.h",
-    identifier = "iree_hal_amdgpu_device_binaries",
+)
+
+iree_runtime_cc_library(
+    name = "blit",
+    srcs = [
+        "blit.c",
+    ],
+    deps = [
+        ":headers",
+    ],
+)
+
+iree_runtime_cc_library(
+    name = "dispatch",
+    srcs = [
+        "dispatch.c",
+    ],
+    deps = [
+        ":headers",
+    ],
+)
+
+iree_runtime_cc_library(
+    name = "timestamp",
+    srcs = [
+        "timestamp.c",
+    ],
+    deps = [
+        ":dispatch",
+        ":headers",
+    ],
+)
+
+iree_runtime_cc_test(
+    name = "blit_test",
+    srcs = ["blit_test.cc"],
+    group = "iree-hal-drivers-amdgpu-tests",
+    tags = [
+        "driver=amdgpu",
+        "nodocker",
+    ],
+    deps = [
+        ":blit",
+        ":headers",
+        "//runtime/src/iree/testing:gtest",
+        "//runtime/src/iree/testing:gtest_main",
+    ],
+)
+
+iree_runtime_cc_test(
+    name = "dispatch_test",
+    srcs = ["dispatch_test.cc"],
+    group = "iree-hal-drivers-amdgpu-tests",
+    tags = [
+        "driver=amdgpu",
+        "nodocker",
+    ],
+    deps = [
+        ":dispatch",
+        ":headers",
+        "//runtime/src/iree/testing:gtest",
+        "//runtime/src/iree/testing:gtest_main",
+    ],
+)
+
+iree_runtime_cc_test(
+    name = "timestamp_test",
+    srcs = ["timestamp_test.cc"],
+    group = "iree-hal-drivers-amdgpu-tests",
+    tags = [
+        "driver=amdgpu",
+        "nodocker",
+    ],
+    deps = [
+        ":headers",
+        ":timestamp",
+        "//runtime/src/iree/hal/drivers/amdgpu/abi",
+        "//runtime/src/iree/testing:gtest",
+        "//runtime/src/iree/testing:gtest_main",
+    ],
 )
diff --git a/runtime/src/iree/hal/drivers/amdgpu/device/CMakeLists.txt b/runtime/src/iree/hal/drivers/amdgpu/device/CMakeLists.txt
index 98fe713..0c07ddb 100644
--- a/runtime/src/iree/hal/drivers/amdgpu/device/CMakeLists.txt
+++ b/runtime/src/iree/hal/drivers/amdgpu/device/CMakeLists.txt
@@ -1,99 +1,147 @@
-# Copyright 2025 The IREE Authors
-#
-# Licensed under the Apache License v2.0 with LLVM Exceptions.
-# See https://llvm.org/LICENSE.txt for license information.
-# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+################################################################################
+# Autogenerated by build_tools/bazel_to_cmake/bazel_to_cmake.py from           #
+# runtime/src/iree/hal/drivers/amdgpu/device/BUILD.bazel                       #
+#                                                                              #
+# Use iree_cmake_extra_content from iree/build_defs.oss.bzl to add arbitrary   #
+# CMake-only content.                                                          #
+#                                                                              #
+# To disable autogeneration for this file entirely, delete this header.        #
+################################################################################
 
-#===------------------------------------------------------------------------===#
-# Common Sources
-#===------------------------------------------------------------------------===#
+iree_add_all_subdirs()
 
-set(_BITCODE_SRCS
-  "blit.c"
-  "buffer.c"
-  "command_buffer.c"
-  "host_client.c"
-  "semaphore.c"
-  "tracing.c"
+add_custom_command(OUTPUT bitcode_srcs.stamp
+    COMMAND ${CMAKE_COMMAND} -E touch bitcode_srcs.stamp
+  DEPENDS
+    "blit.c"
+    "dispatch.c"
+    "timestamp.c"
 )
 
-set(_BITCODE_HDRS
-  "blit.h"
-  "buffer.h"
-  "command_buffer.h"
-  "kernel_tables.h"
-  "kernels.h"
-  "host_client.h"
-  "semaphore.h"
-  "tracing.h"
-  "support/common.h"
-  "support/kernel_args.h"
-  "support/mutex.h"
-  "support/queue.h"
-  "support/signal.h"
+add_custom_target(bitcode_srcs
+    DEPENDS bitcode_srcs.stamp
 )
 
-#===------------------------------------------------------------------------===#
-# Exported Headers
-#===------------------------------------------------------------------------===#
+add_custom_command(OUTPUT bitcode_hdrs.stamp
+    COMMAND ${CMAKE_COMMAND} -E touch bitcode_hdrs.stamp
+  DEPENDS
+    "blit.h"
+    "dispatch.h"
+    "kernel_tables.h"
+    "kernels.h"
+    "support/common.h"
+    "support/kernel.h"
+    "support/queue.h"
+    "support/signal.h"
+    "timestamp.h"
+)
+
+add_custom_target(bitcode_hdrs
+    DEPENDS bitcode_hdrs.stamp
+)
 
 iree_cc_library(
   NAME
     headers
   HDRS
-    "${_BITCODE_HDRS}"
+    "blit.h"
+    "dispatch.h"
+    "kernel_tables.h"
+    "kernels.h"
+    "support/common.h"
+    "support/kernel.h"
+    "support/queue.h"
+    "support/signal.h"
+    "timestamp.h"
+  DEPS
+    iree::base
+    iree::base::internal
+    iree::base::threading::thread
+    iree::hal::drivers::amdgpu::abi
   PUBLIC
 )
 
-#===------------------------------------------------------------------------===#
-# Architecture-specific Binaries
-#===------------------------------------------------------------------------===#
-# NOTE: the naming here matches what HSA_ISA_INFO_NAME returns so that we can
-# match them at runtime without having to load and reflect each code object.
-
-# TODO(benvanik): when TheRock stabilizes its naming convention we'll want to
-# copy that and make it configurable. See:
-# https://github.com/ROCm/TheRock/blob/main/cmake/therock_amdgpu_targets.cmake
-# Matching their family naming scheme would allow us to directly source from
-# their command line arguments.
-
-set(IREE_HAL_AMDGPU_DEVICE_LIBRARY_TARGETS
-    "gfx942;gfx1100"
-    CACHE STRING
-    "Bundled device library architectures included in the runtime binary.")
-
-set(_ARCH_BINARIES)
-foreach(_ARCH ${IREE_HAL_AMDGPU_DEVICE_LIBRARY_TARGETS})
-  iree_amdgpu_binary(
-    NAME
-      amdgcn-amd-amdhsa--${_ARCH}
-    TARGET
-      amdgcn-amd-amdhsa
-    ARCH
-      ${_ARCH}
-    SRCS
-      "${_BITCODE_SRCS}"
-    INTERNAL_HDRS
-      "${_BITCODE_HDRS}"
-  )
-  list(APPEND _ARCH_BINARIES "amdgcn-amd-amdhsa--${_ARCH}.so")
-endforeach()
-
-#===------------------------------------------------------------------------===#
-# Embedded Binary Table
-#===------------------------------------------------------------------------===#
-
-iree_c_embed_data(
+iree_cc_library(
   NAME
-    binaries
+    blit
   SRCS
-    "${_ARCH_BINARIES}"
-  C_FILE_OUTPUT
-    "binaries.c"
-  H_FILE_OUTPUT
-    "binaries.h"
-  IDENTIFIER
-    "iree_hal_amdgpu_device_binaries"
-  FLATTEN
+    "blit.c"
+  DEPS
+    ::headers
   PUBLIC
 )
+
+iree_cc_library(
+  NAME
+    dispatch
+  SRCS
+    "dispatch.c"
+  DEPS
+    ::headers
+  PUBLIC
+)
+
+iree_cc_library(
+  NAME
+    timestamp
+  SRCS
+    "timestamp.c"
+  DEPS
+    ::dispatch
+    ::headers
+  PUBLIC
+)
+
+iree_cc_test(
+  NAME
+    blit_test
+  SRCS
+    "blit_test.cc"
+  DEPS
+    ::blit
+    ::headers
+    iree::testing::gtest
+    iree::testing::gtest_main
+  LABELS
+    "driver=amdgpu"
+    "nodocker"
+  GROUP
+    "iree-hal-drivers-amdgpu-tests"
+)
+
+iree_cc_test(
+  NAME
+    dispatch_test
+  SRCS
+    "dispatch_test.cc"
+  DEPS
+    ::dispatch
+    ::headers
+    iree::testing::gtest
+    iree::testing::gtest_main
+  LABELS
+    "driver=amdgpu"
+    "nodocker"
+  GROUP
+    "iree-hal-drivers-amdgpu-tests"
+)
+
+iree_cc_test(
+  NAME
+    timestamp_test
+  SRCS
+    "timestamp_test.cc"
+  DEPS
+    ::headers
+    ::timestamp
+    iree::hal::drivers::amdgpu::abi
+    iree::testing::gtest
+    iree::testing::gtest_main
+  LABELS
+    "driver=amdgpu"
+    "nodocker"
+  GROUP
+    "iree-hal-drivers-amdgpu-tests"
+)
+
+### BAZEL_TO_CMAKE_PRESERVES_ALL_CONTENT_BELOW_THIS_LINE ###
diff --git a/runtime/src/iree/hal/drivers/amdgpu/device/binaries/BUILD.bazel b/runtime/src/iree/hal/drivers/amdgpu/device/binaries/BUILD.bazel
new file mode 100644
index 0000000..c7c716c
--- /dev/null
+++ b/runtime/src/iree/hal/drivers/amdgpu/device/binaries/BUILD.bazel
@@ -0,0 +1,37 @@
+# Copyright 2026 The IREE Authors
+#
+# Licensed under the Apache License v2.0 with LLVM Exceptions.
+# See https://llvm.org/LICENSE.txt for license information.
+# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+
+load(":targets.bzl", "iree_hal_amdgpu_device_binaries")
+
+package(
+    default_visibility = ["//visibility:public"],
+    features = ["layering_check"],
+    licenses = ["notice"],
+)
+
+BITCODE_SRCS = [
+    "//runtime/src/iree/hal/drivers/amdgpu/device:bitcode_srcs",
+]
+
+ABI_HDRS = [
+    "//runtime/src/iree/hal/drivers/amdgpu/abi:command_buffer.h",
+    "//runtime/src/iree/hal/drivers/amdgpu/abi:common.h",
+    "//runtime/src/iree/hal/drivers/amdgpu/abi:kernel_args.h",
+    "//runtime/src/iree/hal/drivers/amdgpu/abi:profile.h",
+    "//runtime/src/iree/hal/drivers/amdgpu/abi:queue.h",
+    "//runtime/src/iree/hal/drivers/amdgpu/abi:signal.h",
+    "//runtime/src/iree/hal/drivers/amdgpu/abi:timestamp.h",
+]
+
+BITCODE_HDRS = [
+    "//runtime/src/iree/hal/drivers/amdgpu/device:bitcode_hdrs",
+] + ABI_HDRS
+
+iree_hal_amdgpu_device_binaries(
+    name = "toc",
+    srcs = BITCODE_SRCS,
+    internal_hdrs = BITCODE_HDRS,
+)
diff --git a/runtime/src/iree/hal/drivers/amdgpu/device/binaries/CMakeLists.txt b/runtime/src/iree/hal/drivers/amdgpu/device/binaries/CMakeLists.txt
new file mode 100644
index 0000000..b4a541e
--- /dev/null
+++ b/runtime/src/iree/hal/drivers/amdgpu/device/binaries/CMakeLists.txt
@@ -0,0 +1,119 @@
+# Copyright 2026 The IREE Authors
+#
+# Licensed under the Apache License v2.0 with LLVM Exceptions.
+# See https://llvm.org/LICENSE.txt for license information.
+# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+
+# Builtin AMDGPU device libraries embedded into the runtime.
+
+set(IREE_HAL_AMDGPU_DEVICE_LIBRARY_TARGETS
+    "all"
+    CACHE STRING
+    "AMDGPU device library targets, LLVM generic ISA targets, or TheRock-style target families to embed.")
+
+include("${CMAKE_CURRENT_LIST_DIR}/target_map.cmake")
+
+function(_iree_hal_amdgpu_device_target_family_var out_var family)
+  string(MAKE_C_IDENTIFIER "${family}" _family_identifier)
+  set(${out_var} "_IREE_HAL_AMDGPU_DEVICE_TARGET_FAMILY_${_family_identifier}" PARENT_SCOPE)
+endfunction()
+
+function(_iree_hal_amdgpu_device_target_code_object out_var target)
+  set(${out_var} "${_IREE_HAL_AMDGPU_DEVICE_TARGET_CODE_OBJECT_${target}}" PARENT_SCOPE)
+endfunction()
+
+function(_iree_hal_amdgpu_expand_device_library_targets out_targets)
+  set(_expanded_targets)
+  foreach(_selection ${IREE_HAL_AMDGPU_DEVICE_LIBRARY_TARGETS})
+    if("${_selection}" IN_LIST _IREE_HAL_AMDGPU_DEVICE_CODE_OBJECT_TARGETS)
+      list(APPEND _expanded_targets "${_selection}")
+    elseif("${_selection}" IN_LIST _IREE_HAL_AMDGPU_DEVICE_TARGETS)
+      _iree_hal_amdgpu_device_target_code_object(_code_object_target "${_selection}")
+      list(APPEND _expanded_targets "${_code_object_target}")
+    elseif("${_selection}" IN_LIST _IREE_HAL_AMDGPU_DEVICE_TARGET_FAMILIES)
+      _iree_hal_amdgpu_device_target_family_var(_family_var "${_selection}")
+      foreach(_exact_target ${${_family_var}})
+        _iree_hal_amdgpu_device_target_code_object(_code_object_target "${_exact_target}")
+        list(APPEND _expanded_targets "${_code_object_target}")
+      endforeach()
+    else()
+      set(_available_selections
+        ${_IREE_HAL_AMDGPU_DEVICE_TARGETS}
+        ${_IREE_HAL_AMDGPU_DEVICE_CODE_OBJECT_TARGETS}
+        ${_IREE_HAL_AMDGPU_DEVICE_TARGET_FAMILIES}
+      )
+      list(REMOVE_DUPLICATES _available_selections)
+      list(SORT _available_selections)
+      string(JOIN " " _available_pretty ${_available_selections})
+      message(FATAL_ERROR
+        "Unknown AMDGPU device library target or family '${_selection}'. "
+        "Available: ${_available_pretty}"
+      )
+    endif()
+  endforeach()
+  list(REMOVE_DUPLICATES _expanded_targets)
+  set(${out_targets} "${_expanded_targets}" PARENT_SCOPE)
+endfunction()
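+
+# For example, with the checked-in target map the selection
+# "gfx1100;gfx94X-all" expands to the code-object targets
+# "gfx11-generic;gfx9-4-generic".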
+
+set(_BITCODE_SRCS
+  "../blit.c"
+  "../dispatch.c"
+  "../timestamp.c"
+)
+
+set(_ABI_HDRS
+  "${CMAKE_CURRENT_SOURCE_DIR}/../../abi/command_buffer.h"
+  "${CMAKE_CURRENT_SOURCE_DIR}/../../abi/common.h"
+  "${CMAKE_CURRENT_SOURCE_DIR}/../../abi/kernel_args.h"
+  "${CMAKE_CURRENT_SOURCE_DIR}/../../abi/profile.h"
+  "${CMAKE_CURRENT_SOURCE_DIR}/../../abi/queue.h"
+  "${CMAKE_CURRENT_SOURCE_DIR}/../../abi/signal.h"
+  "${CMAKE_CURRENT_SOURCE_DIR}/../../abi/timestamp.h"
+)
+
+set(_BITCODE_HDRS
+  "${CMAKE_CURRENT_SOURCE_DIR}/../blit.h"
+  "${CMAKE_CURRENT_SOURCE_DIR}/../dispatch.h"
+  "${CMAKE_CURRENT_SOURCE_DIR}/../kernel_tables.h"
+  "${CMAKE_CURRENT_SOURCE_DIR}/../kernels.h"
+  "${CMAKE_CURRENT_SOURCE_DIR}/../support/common.h"
+  "${CMAKE_CURRENT_SOURCE_DIR}/../support/kernel.h"
+  "${CMAKE_CURRENT_SOURCE_DIR}/../support/queue.h"
+  "${CMAKE_CURRENT_SOURCE_DIR}/../support/signal.h"
+  "${CMAKE_CURRENT_SOURCE_DIR}/../timestamp.h"
+  ${_ABI_HDRS}
+)
+
+_iree_hal_amdgpu_expand_device_library_targets(_DEVICE_LIBRARY_TARGETS)
+
+set(_ARCH_BINARIES)
+foreach(_ARCH ${_DEVICE_LIBRARY_TARGETS})
+  iree_amdgpu_binary(
+    NAME
+      amdgcn-amd-amdhsa--${_ARCH}
+    TARGET
+      amdgcn-amd-amdhsa
+    ARCH
+      ${_ARCH}
+    SRCS
+      "${_BITCODE_SRCS}"
+    INTERNAL_HDRS
+      "${_BITCODE_HDRS}"
+  )
+  list(APPEND _ARCH_BINARIES "amdgcn-amd-amdhsa--${_ARCH}.so")
+endforeach()
+
+iree_c_embed_data(
+  NAME
+    toc
+  SRCS
+    "${_ARCH_BINARIES}"
+  C_FILE_OUTPUT
+    "toc.c"
+  H_FILE_OUTPUT
+    "toc.h"
+  IDENTIFIER
+    "iree_hal_amdgpu_device_binaries"
+  FLATTEN
+  PUBLIC
+)
diff --git a/runtime/src/iree/hal/drivers/amdgpu/device/binaries/README.md b/runtime/src/iree/hal/drivers/amdgpu/device/binaries/README.md
new file mode 100644
index 0000000..e12cb00
--- /dev/null
+++ b/runtime/src/iree/hal/drivers/amdgpu/device/binaries/README.md
@@ -0,0 +1,70 @@
+# AMDGPU Device Binaries
+
+This package builds the small device-side AMDGPU support library embedded into
+the HAL runtime. These code objects contain blit kernels and runtime utility
+kernels; they are not math libraries, so the build intentionally targets LLVM
+generic ISA processors wherever LLVM documents a compatible generic target.
+
+The public selector vocabulary mirrors ROCm/TheRock so users can request the
+same target families in both builds. The source of truth for that family naming
+is TheRock's
+[`cmake/therock_amdgpu_targets.cmake`](https://github.com/ROCm/TheRock/blob/main/cmake/therock_amdgpu_targets.cmake).
+The source of truth for generic ISA coverage is LLVM's AMDGPU generic processor
+documentation and tablegen data under `third_party/llvm-project/llvm/`.
+
+## Support Mechanism
+
+The single checked-in target map lives in
+`build_tools/scripts/amdgpu_target_map.py`. It records:
+
+- exact HSA ISA architecture suffixes, such as `gfx1100`;
+- the code object target to compile for each exact architecture, such as
+  `gfx11-generic`;
+- TheRock-style selectors, such as `gfx110X-all`, `dgpu-all`, and `igpu-all`.
+
+Running the script emits small generated fragments:
+
+- `target_map.bzl`, loaded by `targets.bzl` for Bazel selector expansion;
+- `target_map.cmake`, included by `CMakeLists.txt` for CMake selector expansion;
+- `runtime/src/iree/hal/drivers/amdgpu/util/target_id_map.inl`, included by
+  `target_id.c` so runtime ISA lookup uses the same exact-to-code-object map as
+  the build.
+
+The generated files are checked in. Pre-commit runs
+`python build_tools/scripts/amdgpu_target_map.py --check` so CI catches drift
+between the Python map and the generated fragments.
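+
+For illustration, the entries in the Python map look roughly like the
+following (abridged; the script itself is the source of truth):
+
+```python
+EXACT_TARGET_CODE_OBJECTS = {
+    "gfx942": "gfx9-4-generic",
+    "gfx1100": "gfx11-generic",
+    # ... one entry per supported exact architecture ...
+}
+
+TARGET_FAMILIES = {
+    "gfx110X-all": ["gfx1100", "gfx1101", "gfx1102", "gfx1103"],
+    "gfx110X-dgpu": ["gfx1100", "gfx1101", "gfx1102"],
+    # ... one entry per TheRock-style selector ...
+}
+```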
+
+## Current Generic-Family Audit
+
+The current map intentionally includes generic code-object coverage for the
+modern families LLVM documents and the ROCm/TheRock selector vocabulary names:
+
+| Selector family | Exact targets | Code-object target |
+| --- | --- | --- |
+| `gfx9-4` CDNA | `gfx940`, `gfx941`, `gfx942`, `gfx950` | `gfx9-4-generic` |
+| `gfx11` RDNA/APU | `gfx1100`, `gfx1101`, `gfx1102`, `gfx1103`, `gfx1150`, `gfx1151`, `gfx1152`, `gfx1153`, `gfx1170`, `gfx1171`, `gfx1172` | `gfx11-generic` |
+| `gfx12` RDNA | `gfx1200`, `gfx1201` | `gfx12-generic` |
+| `gfx12.5` RDNA | `gfx1250`, `gfx1251` | `gfx12-5-generic` |
+
+Targets outside this table should fail selection loudly until LLVM documents
+their code-object compatibility and the embedded support library has been
+compiled for the required exact or generic processor.
+
+## Adding An Architecture
+
+When a new AMDGPU architecture is supported:
+
+1. Check TheRock's `therock_amdgpu_targets.cmake` for the selector/family
+   vocabulary and product-family membership.
+2. Check LLVM's AMDGPU generic processor docs/tablegen files for the matching
+   generic ISA target. If LLVM does not document generic coverage for the new
+   exact ISA, keep the code object exact until that support exists.
+3. Update `EXACT_TARGET_CODE_OBJECTS` and `TARGET_FAMILIES` in
+   `build_tools/scripts/amdgpu_target_map.py`.
+4. Run `python build_tools/scripts/amdgpu_target_map.py`.
+5. Run `buildifier`, `clang-format`, the target-map pre-commit check, and the
+   focused AMDGPU device-library build/test targets.
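+
+For example, wiring up a hypothetical new `gfx1234` with documented
+`gfx12-generic` coverage would add `"gfx1234": "gfx12-generic"` to
+`EXACT_TARGET_CODE_OBJECTS` and extend the relevant `TARGET_FAMILIES`
+entries before regenerating.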
+
+Do not hand-edit `target_map.bzl`, `target_map.cmake`, or `target_id_map.inl`.
+Do not add a parallel table in `device_library.c`; the runtime loader consumes
+`target_id` helpers directly.
diff --git a/runtime/src/iree/hal/drivers/amdgpu/device/binaries/target_map.bzl b/runtime/src/iree/hal/drivers/amdgpu/device/binaries/target_map.bzl
new file mode 100644
index 0000000..820c0d3
--- /dev/null
+++ b/runtime/src/iree/hal/drivers/amdgpu/device/binaries/target_map.bzl
@@ -0,0 +1,259 @@
+# Generated by build_tools/scripts/amdgpu_target_map.py.
+# Do not edit directly; edit the map in that script and regenerate.
+# Output: runtime/src/iree/hal/drivers/amdgpu/device/binaries/target_map.bzl
+
+IREE_HAL_AMDGPU_DEVICE_LIBRARY_DEFAULT_TARGETS = [
+    "all",
+]
+
+IREE_HAL_AMDGPU_DEVICE_LIBRARY_EXACT_TARGETS = [
+    "gfx900",
+    "gfx902",
+    "gfx904",
+    "gfx90c",
+    "gfx906",
+    "gfx908",
+    "gfx909",
+    "gfx90a",
+    "gfx940",
+    "gfx941",
+    "gfx942",
+    "gfx950",
+    "gfx1010",
+    "gfx1011",
+    "gfx1012",
+    "gfx1013",
+    "gfx1030",
+    "gfx1031",
+    "gfx1032",
+    "gfx1033",
+    "gfx1034",
+    "gfx1035",
+    "gfx1036",
+    "gfx1100",
+    "gfx1101",
+    "gfx1102",
+    "gfx1103",
+    "gfx1150",
+    "gfx1151",
+    "gfx1152",
+    "gfx1153",
+    "gfx1170",
+    "gfx1171",
+    "gfx1172",
+    "gfx1200",
+    "gfx1201",
+    "gfx1250",
+    "gfx1251",
+]
+
+IREE_HAL_AMDGPU_DEVICE_LIBRARY_CODE_OBJECT_TARGETS = [
+    "gfx9-generic",
+    "gfx908",
+    "gfx90a",
+    "gfx9-4-generic",
+    "gfx10-1-generic",
+    "gfx10-3-generic",
+    "gfx11-generic",
+    "gfx12-generic",
+    "gfx12-5-generic",
+]
+
+IREE_HAL_AMDGPU_DEVICE_LIBRARY_EXACT_TARGET_CODE_OBJECTS = {
+    "gfx900": "gfx9-generic",
+    "gfx902": "gfx9-generic",
+    "gfx904": "gfx9-generic",
+    "gfx90c": "gfx9-generic",
+    "gfx906": "gfx9-generic",
+    "gfx908": "gfx908",
+    "gfx909": "gfx9-generic",
+    "gfx90a": "gfx90a",
+    "gfx940": "gfx9-4-generic",
+    "gfx941": "gfx9-4-generic",
+    "gfx942": "gfx9-4-generic",
+    "gfx950": "gfx9-4-generic",
+    "gfx1010": "gfx10-1-generic",
+    "gfx1011": "gfx10-1-generic",
+    "gfx1012": "gfx10-1-generic",
+    "gfx1013": "gfx10-1-generic",
+    "gfx1030": "gfx10-3-generic",
+    "gfx1031": "gfx10-3-generic",
+    "gfx1032": "gfx10-3-generic",
+    "gfx1033": "gfx10-3-generic",
+    "gfx1034": "gfx10-3-generic",
+    "gfx1035": "gfx10-3-generic",
+    "gfx1036": "gfx10-3-generic",
+    "gfx1100": "gfx11-generic",
+    "gfx1101": "gfx11-generic",
+    "gfx1102": "gfx11-generic",
+    "gfx1103": "gfx11-generic",
+    "gfx1150": "gfx11-generic",
+    "gfx1151": "gfx11-generic",
+    "gfx1152": "gfx11-generic",
+    "gfx1153": "gfx11-generic",
+    "gfx1170": "gfx11-generic",
+    "gfx1171": "gfx11-generic",
+    "gfx1172": "gfx11-generic",
+    "gfx1200": "gfx12-generic",
+    "gfx1201": "gfx12-generic",
+    "gfx1250": "gfx12-5-generic",
+    "gfx1251": "gfx12-5-generic",
+}
+
+IREE_HAL_AMDGPU_DEVICE_LIBRARY_TARGET_FAMILY_NAMES = [
+    "all",
+    "dcgpu-all",
+    "dgpu-all",
+    "gfx900-dgpu",
+    "gfx906-dgpu",
+    "gfx908-dcgpu",
+    "gfx90a-dcgpu",
+    "gfx90c-igpu",
+    "gfx94X-all",
+    "gfx94X-dcgpu",
+    "gfx950-all",
+    "gfx950-dcgpu",
+    "gfx101X-all",
+    "gfx101X-dgpu",
+    "gfx103X-all",
+    "gfx103X-dgpu",
+    "gfx103X-igpu",
+    "gfx110X-all",
+    "gfx110X-dgpu",
+    "gfx110X-igpu",
+    "gfx115X-all",
+    "gfx115X-igpu",
+    "gfx117X-all",
+    "gfx120X-all",
+    "gfx125X-all",
+    "igpu-all",
+]
+
+IREE_HAL_AMDGPU_DEVICE_LIBRARY_TARGET_FAMILIES = {
+    "all": IREE_HAL_AMDGPU_DEVICE_LIBRARY_EXACT_TARGETS,
+    "dcgpu-all": [
+        "gfx908",
+        "gfx90a",
+        "gfx940",
+        "gfx941",
+        "gfx942",
+        "gfx950",
+    ],
+    "dgpu-all": [
+        "gfx900",
+        "gfx902",
+        "gfx904",
+        "gfx906",
+        "gfx909",
+        "gfx1010",
+        "gfx1011",
+        "gfx1012",
+        "gfx1013",
+        "gfx1030",
+        "gfx1031",
+        "gfx1032",
+        "gfx1034",
+        "gfx1100",
+        "gfx1101",
+        "gfx1102",
+        "gfx1200",
+        "gfx1201",
+    ],
+    "gfx900-dgpu": ["gfx900"],
+    "gfx906-dgpu": ["gfx906"],
+    "gfx908-dcgpu": ["gfx908"],
+    "gfx90a-dcgpu": ["gfx90a"],
+    "gfx90c-igpu": ["gfx90c"],
+    "gfx94X-all": [
+        "gfx940",
+        "gfx941",
+        "gfx942",
+    ],
+    "gfx94X-dcgpu": [
+        "gfx940",
+        "gfx941",
+        "gfx942",
+    ],
+    "gfx950-all": ["gfx950"],
+    "gfx950-dcgpu": ["gfx950"],
+    "gfx101X-all": [
+        "gfx1010",
+        "gfx1011",
+        "gfx1012",
+        "gfx1013",
+    ],
+    "gfx101X-dgpu": [
+        "gfx1010",
+        "gfx1011",
+        "gfx1012",
+        "gfx1013",
+    ],
+    "gfx103X-all": [
+        "gfx1030",
+        "gfx1031",
+        "gfx1032",
+        "gfx1033",
+        "gfx1034",
+        "gfx1035",
+        "gfx1036",
+    ],
+    "gfx103X-dgpu": [
+        "gfx1030",
+        "gfx1031",
+        "gfx1032",
+        "gfx1034",
+    ],
+    "gfx103X-igpu": [
+        "gfx1033",
+        "gfx1035",
+        "gfx1036",
+    ],
+    "gfx110X-all": [
+        "gfx1100",
+        "gfx1101",
+        "gfx1102",
+        "gfx1103",
+    ],
+    "gfx110X-dgpu": [
+        "gfx1100",
+        "gfx1101",
+        "gfx1102",
+    ],
+    "gfx110X-igpu": ["gfx1103"],
+    "gfx115X-all": [
+        "gfx1150",
+        "gfx1151",
+        "gfx1152",
+        "gfx1153",
+    ],
+    "gfx115X-igpu": [
+        "gfx1150",
+        "gfx1151",
+        "gfx1152",
+        "gfx1153",
+    ],
+    "gfx117X-all": [
+        "gfx1170",
+        "gfx1171",
+        "gfx1172",
+    ],
+    "gfx120X-all": [
+        "gfx1200",
+        "gfx1201",
+    ],
+    "gfx125X-all": [
+        "gfx1250",
+        "gfx1251",
+    ],
+    "igpu-all": [
+        "gfx90c",
+        "gfx1033",
+        "gfx1035",
+        "gfx1036",
+        "gfx1103",
+        "gfx1150",
+        "gfx1151",
+        "gfx1152",
+        "gfx1153",
+    ],
+}
diff --git a/runtime/src/iree/hal/drivers/amdgpu/device/binaries/target_map.cmake b/runtime/src/iree/hal/drivers/amdgpu/device/binaries/target_map.cmake
new file mode 100644
index 0000000..42cb5ef
--- /dev/null
+++ b/runtime/src/iree/hal/drivers/amdgpu/device/binaries/target_map.cmake
@@ -0,0 +1,269 @@
+# Generated by build_tools/scripts/amdgpu_target_map.py.
+# Do not edit directly; edit the map in that script and regenerate.
+# Output: runtime/src/iree/hal/drivers/amdgpu/device/binaries/target_map.cmake
+
+set(_IREE_HAL_AMDGPU_DEVICE_TARGETS
+  "gfx900"
+  "gfx902"
+  "gfx904"
+  "gfx90c"
+  "gfx906"
+  "gfx908"
+  "gfx909"
+  "gfx90a"
+  "gfx940"
+  "gfx941"
+  "gfx942"
+  "gfx950"
+  "gfx1010"
+  "gfx1011"
+  "gfx1012"
+  "gfx1013"
+  "gfx1030"
+  "gfx1031"
+  "gfx1032"
+  "gfx1033"
+  "gfx1034"
+  "gfx1035"
+  "gfx1036"
+  "gfx1100"
+  "gfx1101"
+  "gfx1102"
+  "gfx1103"
+  "gfx1150"
+  "gfx1151"
+  "gfx1152"
+  "gfx1153"
+  "gfx1170"
+  "gfx1171"
+  "gfx1172"
+  "gfx1200"
+  "gfx1201"
+  "gfx1250"
+  "gfx1251"
+)
+
+set(_IREE_HAL_AMDGPU_DEVICE_CODE_OBJECT_TARGETS
+  "gfx9-generic"
+  "gfx908"
+  "gfx90a"
+  "gfx9-4-generic"
+  "gfx10-1-generic"
+  "gfx10-3-generic"
+  "gfx11-generic"
+  "gfx12-generic"
+  "gfx12-5-generic"
+)
+
+set(_IREE_HAL_AMDGPU_DEVICE_TARGET_CODE_OBJECT_gfx900 "gfx9-generic")
+set(_IREE_HAL_AMDGPU_DEVICE_TARGET_CODE_OBJECT_gfx902 "gfx9-generic")
+set(_IREE_HAL_AMDGPU_DEVICE_TARGET_CODE_OBJECT_gfx904 "gfx9-generic")
+set(_IREE_HAL_AMDGPU_DEVICE_TARGET_CODE_OBJECT_gfx90c "gfx9-generic")
+set(_IREE_HAL_AMDGPU_DEVICE_TARGET_CODE_OBJECT_gfx906 "gfx9-generic")
+set(_IREE_HAL_AMDGPU_DEVICE_TARGET_CODE_OBJECT_gfx908 "gfx908")
+set(_IREE_HAL_AMDGPU_DEVICE_TARGET_CODE_OBJECT_gfx909 "gfx9-generic")
+set(_IREE_HAL_AMDGPU_DEVICE_TARGET_CODE_OBJECT_gfx90a "gfx90a")
+set(_IREE_HAL_AMDGPU_DEVICE_TARGET_CODE_OBJECT_gfx940 "gfx9-4-generic")
+set(_IREE_HAL_AMDGPU_DEVICE_TARGET_CODE_OBJECT_gfx941 "gfx9-4-generic")
+set(_IREE_HAL_AMDGPU_DEVICE_TARGET_CODE_OBJECT_gfx942 "gfx9-4-generic")
+set(_IREE_HAL_AMDGPU_DEVICE_TARGET_CODE_OBJECT_gfx950 "gfx9-4-generic")
+set(_IREE_HAL_AMDGPU_DEVICE_TARGET_CODE_OBJECT_gfx1010 "gfx10-1-generic")
+set(_IREE_HAL_AMDGPU_DEVICE_TARGET_CODE_OBJECT_gfx1011 "gfx10-1-generic")
+set(_IREE_HAL_AMDGPU_DEVICE_TARGET_CODE_OBJECT_gfx1012 "gfx10-1-generic")
+set(_IREE_HAL_AMDGPU_DEVICE_TARGET_CODE_OBJECT_gfx1013 "gfx10-1-generic")
+set(_IREE_HAL_AMDGPU_DEVICE_TARGET_CODE_OBJECT_gfx1030 "gfx10-3-generic")
+set(_IREE_HAL_AMDGPU_DEVICE_TARGET_CODE_OBJECT_gfx1031 "gfx10-3-generic")
+set(_IREE_HAL_AMDGPU_DEVICE_TARGET_CODE_OBJECT_gfx1032 "gfx10-3-generic")
+set(_IREE_HAL_AMDGPU_DEVICE_TARGET_CODE_OBJECT_gfx1033 "gfx10-3-generic")
+set(_IREE_HAL_AMDGPU_DEVICE_TARGET_CODE_OBJECT_gfx1034 "gfx10-3-generic")
+set(_IREE_HAL_AMDGPU_DEVICE_TARGET_CODE_OBJECT_gfx1035 "gfx10-3-generic")
+set(_IREE_HAL_AMDGPU_DEVICE_TARGET_CODE_OBJECT_gfx1036 "gfx10-3-generic")
+set(_IREE_HAL_AMDGPU_DEVICE_TARGET_CODE_OBJECT_gfx1100 "gfx11-generic")
+set(_IREE_HAL_AMDGPU_DEVICE_TARGET_CODE_OBJECT_gfx1101 "gfx11-generic")
+set(_IREE_HAL_AMDGPU_DEVICE_TARGET_CODE_OBJECT_gfx1102 "gfx11-generic")
+set(_IREE_HAL_AMDGPU_DEVICE_TARGET_CODE_OBJECT_gfx1103 "gfx11-generic")
+set(_IREE_HAL_AMDGPU_DEVICE_TARGET_CODE_OBJECT_gfx1150 "gfx11-generic")
+set(_IREE_HAL_AMDGPU_DEVICE_TARGET_CODE_OBJECT_gfx1151 "gfx11-generic")
+set(_IREE_HAL_AMDGPU_DEVICE_TARGET_CODE_OBJECT_gfx1152 "gfx11-generic")
+set(_IREE_HAL_AMDGPU_DEVICE_TARGET_CODE_OBJECT_gfx1153 "gfx11-generic")
+set(_IREE_HAL_AMDGPU_DEVICE_TARGET_CODE_OBJECT_gfx1170 "gfx11-generic")
+set(_IREE_HAL_AMDGPU_DEVICE_TARGET_CODE_OBJECT_gfx1171 "gfx11-generic")
+set(_IREE_HAL_AMDGPU_DEVICE_TARGET_CODE_OBJECT_gfx1172 "gfx11-generic")
+set(_IREE_HAL_AMDGPU_DEVICE_TARGET_CODE_OBJECT_gfx1200 "gfx12-generic")
+set(_IREE_HAL_AMDGPU_DEVICE_TARGET_CODE_OBJECT_gfx1201 "gfx12-generic")
+set(_IREE_HAL_AMDGPU_DEVICE_TARGET_CODE_OBJECT_gfx1250 "gfx12-5-generic")
+set(_IREE_HAL_AMDGPU_DEVICE_TARGET_CODE_OBJECT_gfx1251 "gfx12-5-generic")
+
+set(_IREE_HAL_AMDGPU_DEVICE_TARGET_FAMILIES
+  "all"
+  "dcgpu-all"
+  "dgpu-all"
+  "gfx900-dgpu"
+  "gfx906-dgpu"
+  "gfx908-dcgpu"
+  "gfx90a-dcgpu"
+  "gfx90c-igpu"
+  "gfx94X-all"
+  "gfx94X-dcgpu"
+  "gfx950-all"
+  "gfx950-dcgpu"
+  "gfx101X-all"
+  "gfx101X-dgpu"
+  "gfx103X-all"
+  "gfx103X-dgpu"
+  "gfx103X-igpu"
+  "gfx110X-all"
+  "gfx110X-dgpu"
+  "gfx110X-igpu"
+  "gfx115X-all"
+  "gfx115X-igpu"
+  "gfx117X-all"
+  "gfx120X-all"
+  "gfx125X-all"
+  "igpu-all"
+)
+
+set(_IREE_HAL_AMDGPU_DEVICE_TARGET_FAMILY_all
+  ${_IREE_HAL_AMDGPU_DEVICE_TARGETS}
+)
+set(_IREE_HAL_AMDGPU_DEVICE_TARGET_FAMILY_dcgpu_all
+  "gfx908"
+  "gfx90a"
+  "gfx940"
+  "gfx941"
+  "gfx942"
+  "gfx950"
+)
+set(_IREE_HAL_AMDGPU_DEVICE_TARGET_FAMILY_dgpu_all
+  "gfx900"
+  "gfx902"
+  "gfx904"
+  "gfx906"
+  "gfx909"
+  "gfx1010"
+  "gfx1011"
+  "gfx1012"
+  "gfx1013"
+  "gfx1030"
+  "gfx1031"
+  "gfx1032"
+  "gfx1034"
+  "gfx1100"
+  "gfx1101"
+  "gfx1102"
+  "gfx1200"
+  "gfx1201"
+)
+set(_IREE_HAL_AMDGPU_DEVICE_TARGET_FAMILY_gfx900_dgpu
+  "gfx900"
+)
+set(_IREE_HAL_AMDGPU_DEVICE_TARGET_FAMILY_gfx906_dgpu
+  "gfx906"
+)
+set(_IREE_HAL_AMDGPU_DEVICE_TARGET_FAMILY_gfx908_dcgpu
+  "gfx908"
+)
+set(_IREE_HAL_AMDGPU_DEVICE_TARGET_FAMILY_gfx90a_dcgpu
+  "gfx90a"
+)
+set(_IREE_HAL_AMDGPU_DEVICE_TARGET_FAMILY_gfx90c_igpu
+  "gfx90c"
+)
+set(_IREE_HAL_AMDGPU_DEVICE_TARGET_FAMILY_gfx94X_all
+  "gfx940"
+  "gfx941"
+  "gfx942"
+)
+set(_IREE_HAL_AMDGPU_DEVICE_TARGET_FAMILY_gfx94X_dcgpu
+  "gfx940"
+  "gfx941"
+  "gfx942"
+)
+set(_IREE_HAL_AMDGPU_DEVICE_TARGET_FAMILY_gfx950_all
+  "gfx950"
+)
+set(_IREE_HAL_AMDGPU_DEVICE_TARGET_FAMILY_gfx950_dcgpu
+  "gfx950"
+)
+set(_IREE_HAL_AMDGPU_DEVICE_TARGET_FAMILY_gfx101X_all
+  "gfx1010"
+  "gfx1011"
+  "gfx1012"
+  "gfx1013"
+)
+set(_IREE_HAL_AMDGPU_DEVICE_TARGET_FAMILY_gfx101X_dgpu
+  "gfx1010"
+  "gfx1011"
+  "gfx1012"
+  "gfx1013"
+)
+set(_IREE_HAL_AMDGPU_DEVICE_TARGET_FAMILY_gfx103X_all
+  "gfx1030"
+  "gfx1031"
+  "gfx1032"
+  "gfx1033"
+  "gfx1034"
+  "gfx1035"
+  "gfx1036"
+)
+set(_IREE_HAL_AMDGPU_DEVICE_TARGET_FAMILY_gfx103X_dgpu
+  "gfx1030"
+  "gfx1031"
+  "gfx1032"
+  "gfx1034"
+)
+set(_IREE_HAL_AMDGPU_DEVICE_TARGET_FAMILY_gfx103X_igpu
+  "gfx1033"
+  "gfx1035"
+  "gfx1036"
+)
+set(_IREE_HAL_AMDGPU_DEVICE_TARGET_FAMILY_gfx110X_all
+  "gfx1100"
+  "gfx1101"
+  "gfx1102"
+  "gfx1103"
+)
+set(_IREE_HAL_AMDGPU_DEVICE_TARGET_FAMILY_gfx110X_dgpu
+  "gfx1100"
+  "gfx1101"
+  "gfx1102"
+)
+set(_IREE_HAL_AMDGPU_DEVICE_TARGET_FAMILY_gfx110X_igpu
+  "gfx1103"
+)
+set(_IREE_HAL_AMDGPU_DEVICE_TARGET_FAMILY_gfx115X_all
+  "gfx1150"
+  "gfx1151"
+  "gfx1152"
+  "gfx1153"
+)
+set(_IREE_HAL_AMDGPU_DEVICE_TARGET_FAMILY_gfx115X_igpu
+  "gfx1150"
+  "gfx1151"
+  "gfx1152"
+  "gfx1153"
+)
+set(_IREE_HAL_AMDGPU_DEVICE_TARGET_FAMILY_gfx117X_all
+  "gfx1170"
+  "gfx1171"
+  "gfx1172"
+)
+set(_IREE_HAL_AMDGPU_DEVICE_TARGET_FAMILY_gfx120X_all
+  "gfx1200"
+  "gfx1201"
+)
+set(_IREE_HAL_AMDGPU_DEVICE_TARGET_FAMILY_gfx125X_all
+  "gfx1250"
+  "gfx1251"
+)
+set(_IREE_HAL_AMDGPU_DEVICE_TARGET_FAMILY_igpu_all
+  "gfx90c"
+  "gfx1033"
+  "gfx1035"
+  "gfx1036"
+  "gfx1103"
+  "gfx1150"
+  "gfx1151"
+  "gfx1152"
+  "gfx1153"
+)
diff --git a/runtime/src/iree/hal/drivers/amdgpu/device/binaries/targets.bzl b/runtime/src/iree/hal/drivers/amdgpu/device/binaries/targets.bzl
new file mode 100644
index 0000000..4e9d860
--- /dev/null
+++ b/runtime/src/iree/hal/drivers/amdgpu/device/binaries/targets.bzl
@@ -0,0 +1,140 @@
+# Copyright 2026 The IREE Authors
+#
+# Licensed under the Apache License v2.0 with LLVM Exceptions.
+# See https://llvm.org/LICENSE.txt for license information.
+# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+
+"""AMDGPU device library target selection."""
+
+load("@bazel_skylib//lib:selects.bzl", "selects")
+load("@bazel_skylib//rules:common_settings.bzl", "BuildSettingInfo")
+load("//build_tools/bazel:iree_amdgpu_binary.bzl", "iree_amdgpu_binary")
+load("//build_tools/embed_data:build_defs.bzl", "iree_c_embed_data")
+load(
+    ":target_map.bzl",
+    "IREE_HAL_AMDGPU_DEVICE_LIBRARY_CODE_OBJECT_TARGETS",
+    "IREE_HAL_AMDGPU_DEVICE_LIBRARY_DEFAULT_TARGETS",
+    "IREE_HAL_AMDGPU_DEVICE_LIBRARY_EXACT_TARGETS",
+    "IREE_HAL_AMDGPU_DEVICE_LIBRARY_EXACT_TARGET_CODE_OBJECTS",
+    "IREE_HAL_AMDGPU_DEVICE_LIBRARY_TARGET_FAMILIES",
+    "IREE_HAL_AMDGPU_DEVICE_LIBRARY_TARGET_FAMILY_NAMES",
+)
+
+def _append_unique(values, new_values):
+    for value in new_values:
+        if value not in values:
+            values.append(value)
+
+def _valid_selectors():
+    selectors = []
+    _append_unique(selectors, IREE_HAL_AMDGPU_DEVICE_LIBRARY_EXACT_TARGETS)
+    _append_unique(selectors, IREE_HAL_AMDGPU_DEVICE_LIBRARY_CODE_OBJECT_TARGETS)
+    _append_unique(selectors, IREE_HAL_AMDGPU_DEVICE_LIBRARY_TARGET_FAMILY_NAMES)
+    return selectors
+
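+# Expands exact ISA targets and family selectors into the deduplicated list of
+# code-object targets to compile; e.g. ["gfx1100", "gfx94X-all"] expands to
+# ["gfx11-generic", "gfx9-4-generic"].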
+def iree_hal_amdgpu_expand_device_library_targets(targets):
+    expanded_targets = []
+    for target in targets:
+        if target in IREE_HAL_AMDGPU_DEVICE_LIBRARY_CODE_OBJECT_TARGETS:
+            _append_unique(expanded_targets, [target])
+        elif target in IREE_HAL_AMDGPU_DEVICE_LIBRARY_EXACT_TARGETS:
+            _append_unique(
+                expanded_targets,
+                [IREE_HAL_AMDGPU_DEVICE_LIBRARY_EXACT_TARGET_CODE_OBJECTS[target]],
+            )
+        elif target in IREE_HAL_AMDGPU_DEVICE_LIBRARY_TARGET_FAMILIES:
+            for exact_target in IREE_HAL_AMDGPU_DEVICE_LIBRARY_TARGET_FAMILIES[target]:
+                _append_unique(
+                    expanded_targets,
+                    [IREE_HAL_AMDGPU_DEVICE_LIBRARY_EXACT_TARGET_CODE_OBJECTS[exact_target]],
+                )
+        else:
+            fail("Unknown AMDGPU device library target or family '%s'. Available: %s" % (
+                target,
+                ", ".join(_valid_selectors()),
+            ))
+    return expanded_targets
+
+def _target_label_fragment(target):
+    return target.replace("-", "_").replace(".", "_")
+
+def _selectors_for_code_object_target(code_object_target):
+    selectors = [code_object_target]
+    for exact_target in IREE_HAL_AMDGPU_DEVICE_LIBRARY_EXACT_TARGETS:
+        if IREE_HAL_AMDGPU_DEVICE_LIBRARY_EXACT_TARGET_CODE_OBJECTS[exact_target] == code_object_target:
+            _append_unique(selectors, [exact_target])
+    for family in IREE_HAL_AMDGPU_DEVICE_LIBRARY_TARGET_FAMILY_NAMES:
+        if code_object_target in iree_hal_amdgpu_expand_device_library_targets([family]):
+            _append_unique(selectors, [family])
+    return selectors
+
+def _device_library_targets_flag_impl(ctx):
+    valid_selectors = _valid_selectors()
+    invalid_selectors = [
+        selector
+        for selector in ctx.build_setting_value
+        if selector not in valid_selectors
+    ]
+    if invalid_selectors:
+        fail("Unknown AMDGPU device library target selector(s) [{}]. Available: {}".format(
+            ", ".join(invalid_selectors),
+            ", ".join(valid_selectors),
+        ))
+    return BuildSettingInfo(value = ctx.build_setting_value)
+
+_device_library_targets_flag = rule(
+    implementation = _device_library_targets_flag_impl,
+    build_setting = config.string_list(flag = True),
+)
+
+def iree_hal_amdgpu_device_binaries(
+        name,
+        srcs,
+        internal_hdrs,
+        target_selections = None,
+        target = "amdgcn-amd-amdhsa"):
+    if target_selections == None:
+        target_selections = IREE_HAL_AMDGPU_DEVICE_LIBRARY_DEFAULT_TARGETS
+
+    _device_library_targets_flag(
+        name = "targets",
+        build_setting_default = target_selections,
+    )
+
+    for selector in _valid_selectors():
+        native.config_setting(
+            name = "%s_selected" % (_target_label_fragment(selector),),
+            flag_values = {
+                ":targets": selector,
+            },
+        )
+
+    binary_srcs = []
+    for code_object_target in IREE_HAL_AMDGPU_DEVICE_LIBRARY_CODE_OBJECT_TARGETS:
+        binary_name = "%s--%s" % (target, code_object_target)
+        iree_amdgpu_binary(
+            name = binary_name,
+            srcs = srcs,
+            arch = code_object_target,
+            internal_hdrs = internal_hdrs,
+            target = target,
+        )
+        selects.config_setting_group(
+            name = "%s_requested" % (_target_label_fragment(code_object_target),),
+            match_any = [
+                ":%s_selected" % (_target_label_fragment(selector),)
+                for selector in _selectors_for_code_object_target(code_object_target)
+            ],
+        )
+        binary_srcs += select({
+            ":%s_requested" % (_target_label_fragment(code_object_target),): [":%s.so" % (binary_name,)],
+            "//conditions:default": [],
+        })
+    iree_c_embed_data(
+        name = name,
+        srcs = binary_srcs,
+        c_file_output = "toc.c",
+        flatten = True,
+        h_file_output = "toc.h",
+        identifier = "iree_hal_amdgpu_device_binaries",
+    )
diff --git a/runtime/src/iree/hal/drivers/amdgpu/device/blit.c b/runtime/src/iree/hal/drivers/amdgpu/device/blit.c
index 6d5b12f..905e5bc 100644
--- a/runtime/src/iree/hal/drivers/amdgpu/device/blit.c
+++ b/runtime/src/iree/hal/drivers/amdgpu/device/blit.c
@@ -6,200 +6,324 @@
 
 #include "iree/hal/drivers/amdgpu/device/blit.h"
 
-//===----------------------------------------------------------------------===//
-// Buffer transfer operation utilities
-//===----------------------------------------------------------------------===//
-
-// Reserves the next packet in the queue and returns its packet_id.
-// If tracing is enabled |out_completion_signal| will be populated with the
-// signal that must be attached to the operation.
-static uint64_t iree_hal_amdgpu_device_blit_reserve(
-    const iree_hal_amdgpu_device_buffer_transfer_context_t* IREE_AMDGPU_RESTRICT
-        context,
-    iree_hal_amdgpu_trace_execution_zone_type_t zone_type,
-    iree_hsa_signal_t* IREE_AMDGPU_RESTRICT out_completion_signal) {
-#if IREE_HAL_AMDGPU_HAS_TRACING_FEATURE( \
-    IREE_HAL_AMDGPU_TRACING_FEATURE_DEVICE_EXECUTION)
-  if (context->trace_buffer) {
-    iree_hal_amdgpu_trace_execution_query_id_t execution_query_id =
-        iree_hal_amdgpu_device_query_ringbuffer_acquire(
-            &context->trace_buffer->query_ringbuffer);
-    *out_completion_signal =
-        iree_hal_amdgpu_device_trace_execution_zone_dispatch(
-            context->trace_buffer, zone_type, 0, execution_query_id);
-  } else {
-    *out_completion_signal = iree_hsa_signal_null();
-  }
-#endif  // IREE_HAL_AMDGPU_TRACING_FEATURE_DEVICE_EXECUTION
-
-  // Reserve the next packet in the queue.
-  const uint64_t packet_id = iree_hsa_queue_add_write_index(
-      &context->queue, 1u, iree_amdgpu_memory_order_relaxed);
-  while (packet_id - iree_hsa_queue_load_read_index(
-                         &context->queue, iree_amdgpu_memory_order_acquire) >=
-         context->queue.size) {
-    iree_amdgpu_yield();  // spinning
-  }
-
-  return packet_id;
-}
-
-// Commits a reserved transfer packet.
-// The header will be updated and the target queue doorbell will be signaled.
-static void iree_hal_amdgpu_device_blit_commit(
-    const iree_hal_amdgpu_device_buffer_transfer_context_t* IREE_AMDGPU_RESTRICT
-        context,
-    uint64_t packet_id,
-    iree_hsa_kernel_dispatch_packet_t* IREE_AMDGPU_RESTRICT packet,
-    iree_hsa_signal_t completion_signal) {
-  // Chain completion.
-  packet->completion_signal = completion_signal;
-
-  // Populate the header and release the packet to the queue.
-  uint16_t header = IREE_HSA_PACKET_TYPE_KERNEL_DISPATCH
-                    << IREE_HSA_PACKET_HEADER_TYPE;
-
-  // TODO(benvanik): need to pull in barrier/scope overrides from command buffer
-  // execution context flags. They should override the barrier bit and the
-  // scopes to be on SYSTEM regardless of what we choose here.
-
-  // NOTE: we don't need a barrier bit as the caller is expecting it to run
-  // concurrently if needed.
-  header |= 0 << IREE_HSA_PACKET_HEADER_BARRIER;
-
-#if IREE_HAL_AMDGPU_HAS_TRACING_FEATURE( \
-    IREE_HAL_AMDGPU_TRACING_FEATURE_DEVICE_EXECUTION)
-  if (context->trace_buffer) {
-    // Force a barrier bit if we are tracing execution. This ensures that we get
-    // exclusive timing for the operation.
-    header |= 1 << IREE_HSA_PACKET_HEADER_BARRIER;
-  }
-#endif  // IREE_HAL_AMDGPU_TRACING_FEATURE_DEVICE_EXECUTION
-
-  // TODO(benvanik): scope to agent if the pointer is local, or maybe none in
-  // cases where surrounding barriers performed the cache management.
-  header |= IREE_HSA_FENCE_SCOPE_SYSTEM
-            << IREE_HSA_PACKET_HEADER_SCACQUIRE_FENCE_SCOPE;
-  header |= IREE_HSA_FENCE_SCOPE_SYSTEM
-            << IREE_HSA_PACKET_HEADER_SCRELEASE_FENCE_SCOPE;
-
-  const uint32_t header_setup = header | (uint32_t)(packet->setup << 16);
-  iree_amdgpu_scoped_atomic_store(
-      (iree_amdgpu_scoped_atomic_uint32_t*)packet, header_setup,
-      iree_amdgpu_memory_order_release, iree_amdgpu_memory_scope_system);
-
-  // Signal the queue doorbell indicating the packet has been updated.
-  iree_hsa_signal_store(context->queue.doorbell_signal, packet_id,
-                        iree_amdgpu_memory_order_relaxed);
-}
+#include "iree/hal/drivers/amdgpu/device/support/common.h"
+#include "iree/hal/drivers/amdgpu/device/support/kernel.h"
 
 //===----------------------------------------------------------------------===//
 // Blit kernel utilities
 //===----------------------------------------------------------------------===//
 
 // 2 uint64_t values totaling 16 bytes.
-typedef uint32_t iree_amdgpu_uint64x2_t __attribute__((vector_size(16)));
+typedef uint64_t iree_amdgpu_uint64x2_t __attribute__((vector_size(16)));
+// Unaligned view of a 16-byte vector. The __packed__ attribute on the
+// enclosing struct propagates to the |value| member, so a dereference of a
+// iree_amdgpu_unaligned_uint64x2_t* generates unaligned loads/stores instead
+// of the 16-byte-aligned form the compiler would otherwise assume. Used by the
+// unaligned block kernels to vectorize copies/fills when pointers and/or
+// length are not 16-byte aligned.
+typedef struct IREE_AMDGPU_ATTRIBUTE_PACKED {
+  iree_amdgpu_uint64x2_t value;
+} iree_amdgpu_unaligned_uint64x2_t;
 
-static inline size_t iree_hal_amdgpu_blit_linear_id(void) {
-  const size_t id_x = iree_hal_amdgpu_device_group_id_x() *
-                          IREE_HAL_AMDGPU_BLIT_WORKGROUP_SIZE_X +
-                      iree_hal_amdgpu_device_local_id_x();
-  const size_t id_y = iree_hal_amdgpu_device_group_id_y() *
-                          IREE_HAL_AMDGPU_BLIT_WORKGROUP_SIZE_Y +
-                      iree_hal_amdgpu_device_local_id_y();
+// 128 bytes is enough to amortize launch overhead over at least one full
+// vector store per lane at wave32; anything smaller stays on the scalar byte
+// path. The threshold is deliberately independent of the per-block unroll
+// count (IREE_HAL_AMDGPU_*_BLOCK_COUNT) so that benchmark-driven tuning of the
+// unroll factor does not change the selection boundary between scalar and
+// vector paths.
+#define IREE_HAL_AMDGPU_BLIT_UNALIGNED_MIN_BYTES 128
+
+#define IREE_HAL_AMDGPU_FILL_BLOCK_ELEMENT_SIZE sizeof(iree_amdgpu_uint64x2_t)
+#define IREE_HAL_AMDGPU_FILL_BLOCK_COUNT 4
+#define IREE_HAL_AMDGPU_FILL_BLOCK_UNALIGNED_MIN_SIZE \
+  IREE_HAL_AMDGPU_BLIT_UNALIGNED_MIN_BYTES
+
+#define IREE_HAL_AMDGPU_FILL_BLOCK_X4_ELEMENT_SIZE sizeof(uint32_t)
+#define IREE_HAL_AMDGPU_FILL_BLOCK_X4_COUNT 16
+
+#define IREE_HAL_AMDGPU_COPY_BLOCK_ELEMENT_SIZE sizeof(iree_amdgpu_uint64x2_t)
+#define IREE_HAL_AMDGPU_COPY_BLOCK_COUNT 1
+#define IREE_HAL_AMDGPU_COPY_BLOCK_UNALIGNED_MIN_SIZE \
+  IREE_HAL_AMDGPU_BLIT_UNALIGNED_MIN_BYTES
+
+#define IREE_HAL_AMDGPU_COPY_BLOCK_X8_ELEMENT_SIZE sizeof(uint64_t)
+#define IREE_HAL_AMDGPU_COPY_BLOCK_X8_COUNT 8
+
+#define IREE_HAL_AMDGPU_COPY_BLOCK_X4_ELEMENT_SIZE sizeof(uint32_t)
+#define IREE_HAL_AMDGPU_COPY_BLOCK_X4_COUNT 16
+
+#define IREE_HAL_AMDGPU_BLIT_WORKGROUPS_PER_COMPUTE_UNIT 4
+
+void iree_hal_amdgpu_device_buffer_transfer_context_initialize(
+    const iree_hal_amdgpu_device_kernels_t* kernels,
+    uint32_t compute_unit_count, uint32_t wavefront_size,
+    iree_hal_amdgpu_device_buffer_transfer_context_t* out_context) {
+  // Preconditions (validated by the caller; see physical_device.c):
+  //   compute_unit_count > 0
+  //   wavefront_size in {32, 64}
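+  //
+  // e.g. a 96-CU gfx1100 (as used in the blit tuning notes below) yields
+  // max_workgroup_count = 96 * 4 = 384, well under the UINT32_MAX clamp.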
+  const uint64_t max_workgroup_count =
+      (uint64_t)compute_unit_count *
+      IREE_HAL_AMDGPU_BLIT_WORKGROUPS_PER_COMPUTE_UNIT;
+  *out_context = (iree_hal_amdgpu_device_buffer_transfer_context_t){
+      .kernels = kernels,
+      .wavefront_size = (uint16_t)wavefront_size,
+      .workgroup_size_x = (uint16_t)wavefront_size,
+      .max_workgroup_count = max_workgroup_count > UINT32_MAX
+                                 ? UINT32_MAX
+                                 : (uint32_t)max_workgroup_count,
+  };
+}
+
+#if defined(IREE_AMDGPU_TARGET_DEVICE)
+
+static inline IREE_AMDGPU_ATTRIBUTE_ALWAYS_INLINE uint64_t
+iree_hal_amdgpu_blit_linear_id(void) {
+  const uint64_t id_x = iree_hal_amdgpu_device_global_id_x();
+  const uint64_t id_y = iree_hal_amdgpu_device_global_id_y();
   return id_y * iree_amdgcn_dispatch_ptr()->grid_size[0] + id_x;
 }
 
+static inline IREE_AMDGPU_ATTRIBUTE_ALWAYS_INLINE uint64_t
+iree_hal_amdgpu_blit_grid_size(void) {
+  const iree_hsa_kernel_dispatch_packet_t* IREE_AMDGPU_RESTRICT dispatch_ptr =
+      iree_amdgcn_dispatch_ptr();
+  return (uint64_t)dispatch_ptr->grid_size[0] * dispatch_ptr->grid_size[1];
+}
+
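+// Advances |element_offset| by |element_stride| for the next grid-stride
+// iteration. Returns false if the addition would overflow uint64_t so that
+// loops terminate instead of wrapping for offsets near UINT64_MAX.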
+static inline IREE_AMDGPU_ATTRIBUTE_ALWAYS_INLINE bool
+iree_hal_amdgpu_blit_advance(uint64_t* element_offset,
+                             const uint64_t element_stride) {
+  if (IREE_AMDGPU_UNLIKELY(*element_offset > UINT64_MAX - element_stride)) {
+    return false;
+  }
+  *element_offset += element_stride;
+  return true;
+}
+
+// Returns the byte at |byte_offset| of a repeating fill pattern.
+// Precondition: |pattern| has been extended to a full 8-byte repetition of
+// the original 1/2/4/8-byte pattern (see
+// iree_hal_amdgpu_device_extend_pattern_x8). The mask |byte_offset & 7u| works
+// for any pattern_length that divides 8 — which includes all valid
+// pattern_lengths (1, 2, 4, 8). Callers use this only on tail bytes whose
+// offset within the buffer is already a multiple of the original
+// pattern_length, so |byte_offset| lines up with the pattern phase.
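+//
+// Example: a 2-byte pattern 0x0201 extends to 0x0201020102010201, which is
+// the byte sequence 01 02 01 02 ... in memory; byte_offset 5 returns 0x02.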
+static inline IREE_AMDGPU_ATTRIBUTE_ALWAYS_INLINE uint8_t
+iree_hal_amdgpu_blit_pattern_byte(const uint64_t pattern,
+                                  const uint64_t byte_offset) {
+  return (uint8_t)(pattern >> ((byte_offset & 7u) * 8u));
+}
+
+#endif  // IREE_AMDGPU_TARGET_DEVICE
+
 //===----------------------------------------------------------------------===//
 // iree_hal_amdgpu_device_buffer_fill_*
 //===----------------------------------------------------------------------===//
+//
+// IREE_HAL_AMDGPU_FILL_BLOCK_COUNT tuning sweep on gfx1100 (RDNA3, wave32,
+// 96 CUs, GDDR6). Bandwidth in GB/s; off is buffer target offset alignment,
+// pat is fill pattern_length. All rows here use the fill_block_x16 kernel.
+// Build: --compilation_mode=opt --copt=-O3 --copt=-march=native
+//        --copt=-flto=thin --linkopt=-flto=thin.
+//
+//                          QueueFill (single)       QueueFillBatch20
+//   length   off pat    cnt=1 cnt=2 cnt=4 cnt=8   cnt=1 cnt=2 cnt=4 cnt=8
+//   --------------------------------------------------------------------
+//    64KiB    0   4      2.8   2.8   2.7   3.0     5.7   6.3   6.5   8.2
+//     2MiB    0   2     91.4  99.9 101.4 100.1   207.3 199.3 215.0 212.3
+//     2MiB    2   2     91.3 101.6 100.0  98.7   204.5 210.0 210.0 198.9
+//     1GiB    0   4    629.7 270.4 652.1 177.3   638.1 273.2 657.9 179.0
+//
+//   Geomean over bandwidth-relevant rows (64KiB..1GiB across alignments):
+//     QueueFill       : cnt=1: 26.8, cnt=2: 24.8, cnt=4: 25.8, cnt=8: 23.8
+//     QueueFillBatch20: cnt=4: 55.8, cnt=1: 51.9, cnt=8: 50.9, cnt=2: 47.5
+//
+// cnt=4 is the tuned value for gfx1100. Unlike copy, fill shows a bathtub:
+// cnt=1 and cnt=4 both hit ~640 GB/s at 1GiB (~2/3 of the GDDR6 ceiling)
+// while cnt=2 collapses to ~270 GB/s and cnt=8 collapses to ~180 GB/s. The
+// cnt=8 cliff is the same VGPR-occupancy story as copy — `#pragma unroll 8`
+// over 16-byte writes burns enough extra VGPRs to gate occupancy. The cnt=2
+// trough is not fully understood and may be a compiler instruction
+// scheduling / VGPR packing artifact on this toolchain version — it is
+// reproducible but the mechanism has not been isolated here.
+//
+// cnt=4 wins the batched benchmark at every bandwidth-relevant size and
+// ties cnt=1 at the huge-size ceiling, which is why it is the chosen
+// default. The QueueFill (single-op) geomean shows cnt=1 marginally ahead
+// of cnt=4 because several single-op 2MiB rows have enough measurement
+// variance to push cnt=4 into a temporary 74 GB/s dip that the stable
+// batched run does not reproduce; re-running the benchmark with more
+// iterations would likely tighten this. Small batched sizes <16KiB prefer
+// cnt=8 by ~25%, but absolute throughput there is <1 GB/s and does not
+// move real workloads. Sizes <64KiB are omitted here — they are
+// measurement-variance-dominated.
+//
+// CDNA (MI300+, 512 VGPRs max per wave) may move the cnt=8 cliff further out
+// and should be re-swept before changing.
+
+#if defined(IREE_AMDGPU_TARGET_DEVICE)
 
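+// Scalar byte fill: one element per work item per grid-stride step. Selected
+// when alignment rules out the vector kernels and |length| is below
+// IREE_HAL_AMDGPU_FILL_BLOCK_UNALIGNED_MIN_SIZE.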
 IREE_AMDGPU_ATTRIBUTE_KERNEL void iree_hal_amdgpu_device_buffer_fill_x1(
     uint8_t* IREE_AMDGPU_RESTRICT target_ptr, const uint64_t element_length,
     const uint8_t pattern) {
-  const size_t element_offset = iree_hal_amdgpu_blit_linear_id();
-  if (IREE_AMDGPU_LIKELY(element_offset < element_length)) {
-    // Slowest possible copy; benchmarks required to iterate on better impls.
+  const uint64_t element_stride = iree_hal_amdgpu_blit_grid_size();
+  for (uint64_t element_offset = iree_hal_amdgpu_blit_linear_id();
+       element_offset < element_length;) {
     target_ptr[element_offset] = pattern;
+    if (!iree_hal_amdgpu_blit_advance(&element_offset, element_stride)) break;
   }
 }
 
 IREE_AMDGPU_ATTRIBUTE_KERNEL void iree_hal_amdgpu_device_buffer_fill_x2(
     uint16_t* IREE_AMDGPU_RESTRICT target_ptr, const uint64_t element_length,
     const uint16_t pattern) {
-  const size_t element_offset = iree_hal_amdgpu_blit_linear_id();
-  if (IREE_AMDGPU_LIKELY(element_offset < element_length)) {
-    // Slowest possible fill; benchmarks required to iterate on better impls.
+  const uint64_t element_stride = iree_hal_amdgpu_blit_grid_size();
+  for (uint64_t element_offset = iree_hal_amdgpu_blit_linear_id();
+       element_offset < element_length;) {
     target_ptr[element_offset] = pattern;
+    if (!iree_hal_amdgpu_blit_advance(&element_offset, element_stride)) break;
   }
 }
 
 IREE_AMDGPU_ATTRIBUTE_KERNEL void iree_hal_amdgpu_device_buffer_fill_x4(
     uint32_t* IREE_AMDGPU_RESTRICT target_ptr, const uint64_t element_length,
     const uint32_t pattern) {
-  const size_t element_offset = iree_hal_amdgpu_blit_linear_id();
-  if (IREE_AMDGPU_LIKELY(element_offset < element_length)) {
-    // Slowest possible fill; benchmarks required to iterate on better impls.
-    target_ptr[element_offset] = pattern;
+  const uint64_t block_stride = iree_hal_amdgpu_blit_grid_size();
+  for (uint64_t block_id = iree_hal_amdgpu_blit_linear_id();;) {
+    const uint64_t element_offset =
+        block_id * IREE_HAL_AMDGPU_FILL_BLOCK_X4_COUNT;
+    if (IREE_AMDGPU_UNLIKELY(element_offset >= element_length)) return;
+    const uint64_t element_count = IREE_AMDGPU_MIN(
+        IREE_HAL_AMDGPU_FILL_BLOCK_X4_COUNT, element_length - element_offset);
+    if (IREE_AMDGPU_LIKELY(element_count ==
+                           IREE_HAL_AMDGPU_FILL_BLOCK_X4_COUNT)) {
+#pragma unroll
+      for (size_t i = 0; i < IREE_HAL_AMDGPU_FILL_BLOCK_X4_COUNT; ++i) {
+        target_ptr[element_offset + i] = pattern;
+      }
+    } else {
+      for (size_t i = 0; i < element_count; ++i) {
+        target_ptr[element_offset + i] = pattern;
+      }
+    }
+    if (!iree_hal_amdgpu_blit_advance(&block_id, block_stride)) return;
   }
 }
 
 IREE_AMDGPU_ATTRIBUTE_KERNEL void iree_hal_amdgpu_device_buffer_fill_x8(
     uint64_t* IREE_AMDGPU_RESTRICT target_ptr, const uint64_t element_length,
     const uint64_t pattern) {
-  const size_t element_offset = iree_hal_amdgpu_blit_linear_id();
-  if (IREE_AMDGPU_LIKELY(element_offset < element_length)) {
-    // Slowest possible fill; benchmarks required to iterate on better impls.
+  const uint64_t element_stride = iree_hal_amdgpu_blit_grid_size();
+  for (uint64_t element_offset = iree_hal_amdgpu_blit_linear_id();
+       element_offset < element_length;) {
     target_ptr[element_offset] = pattern;
+    if (!iree_hal_amdgpu_blit_advance(&element_offset, element_stride)) break;
   }
 }
 
-#define IREE_HAL_AMDGPU_FILL_BLOCK_ELEMENT_SIZE sizeof(iree_amdgpu_uint64x2_t)
-#define IREE_HAL_AMDGPU_FILL_BLOCK_COUNT 8
-#define IREE_HAL_AMDGPU_FILL_BLOCK_SIZE \
-  (IREE_HAL_AMDGPU_FILL_BLOCK_ELEMENT_SIZE * IREE_HAL_AMDGPU_FILL_BLOCK_COUNT)
-
-// Fills a block of up to IREE_HAL_AMDGPU_FILL_BLOCK_COUNT 16-byte elements with
-// a fixed pattern. Requires an alignment of 16-bytes on both the |target_ptr|
-// and |length|.
+// Fills blocks of up to IREE_HAL_AMDGPU_FILL_BLOCK_COUNT 16-byte elements with
+// a fixed pattern. Requires an alignment of 16-bytes on both |target_ptr| and
+// |length|.
 IREE_AMDGPU_ATTRIBUTE_KERNEL void iree_hal_amdgpu_device_buffer_fill_block_x16(
     iree_amdgpu_uint64x2_t* IREE_AMDGPU_RESTRICT target_ptr,
     const uint64_t element_length, const uint64_t pattern) {
-  const size_t block_id = iree_hal_amdgpu_blit_linear_id();
-  const size_t element_offset = block_id * IREE_HAL_AMDGPU_FILL_BLOCK_COUNT;
-  if (IREE_AMDGPU_UNLIKELY(element_offset >= element_length)) return;
-  iree_amdgpu_uint64x2_t pattern_x16 = {pattern, pattern};
-  const size_t element_count =
-      IREE_AMDGPU_MIN(IREE_HAL_AMDGPU_FILL_BLOCK_COUNT,
-                      element_length - element_offset) /
-      sizeof(pattern_x16);
-  if (IREE_AMDGPU_LIKELY(element_count == IREE_HAL_AMDGPU_FILL_BLOCK_COUNT)) {
+  const uint64_t block_stride = iree_hal_amdgpu_blit_grid_size();
+  const iree_amdgpu_uint64x2_t pattern_x16 = {pattern, pattern};
+  for (uint64_t block_id = iree_hal_amdgpu_blit_linear_id();;) {
+    const uint64_t element_offset = block_id * IREE_HAL_AMDGPU_FILL_BLOCK_COUNT;
+    if (IREE_AMDGPU_UNLIKELY(element_offset >= element_length)) return;
+    const uint64_t element_count = IREE_AMDGPU_MIN(
+        IREE_HAL_AMDGPU_FILL_BLOCK_COUNT, element_length - element_offset);
+    if (IREE_AMDGPU_LIKELY(element_count == IREE_HAL_AMDGPU_FILL_BLOCK_COUNT)) {
 #pragma unroll
-    for (int i = 0; i < IREE_HAL_AMDGPU_FILL_BLOCK_COUNT; ++i) {
-      target_ptr[element_offset + i] = pattern_x16;
+      for (size_t i = 0; i < IREE_HAL_AMDGPU_FILL_BLOCK_COUNT; ++i) {
+        target_ptr[element_offset + i] = pattern_x16;
+      }
+    } else {
+      for (size_t i = 0; i < element_count; ++i) {
+        target_ptr[element_offset + i] = pattern_x16;
+      }
     }
-  } else {
-    for (int i = 0; i < element_count; ++i) {
-      target_ptr[element_offset + i] = pattern_x16;
-    }
+    if (!iree_hal_amdgpu_blit_advance(&block_id, block_stride)) return;
   }
 }
 
+// Fills a byte-granular region using 16-byte vector stores where possible.
+// NOTE: |element_length| is a byte length here, not a count of 16-byte
+// elements as in fill_block_x16. The kernargs struct field name is shared
+// across all fill variants so the host-side emplace path doesn't need a
+// second struct; this kernel reinterprets that field as a byte length and
+// derives the vector/tail split internally.
+IREE_AMDGPU_ATTRIBUTE_KERNEL void
+iree_hal_amdgpu_device_buffer_fill_block_unaligned_x16(
+    iree_amdgpu_unaligned_uint64x2_t* IREE_AMDGPU_RESTRICT target_ptr,
+    const uint64_t element_length, const uint64_t pattern) {
+  const uint64_t full_element_count =
+      element_length / IREE_HAL_AMDGPU_FILL_BLOCK_ELEMENT_SIZE;
+  const uint64_t vector_block_count = IREE_AMDGPU_CEIL_DIV(
+      full_element_count, IREE_HAL_AMDGPU_FILL_BLOCK_COUNT);
+  const uint64_t tail_offset =
+      full_element_count * IREE_HAL_AMDGPU_FILL_BLOCK_ELEMENT_SIZE;
+  const uint64_t tail_length = element_length - tail_offset;
+  const uint64_t block_stride = iree_hal_amdgpu_blit_grid_size();
+  const iree_amdgpu_uint64x2_t pattern_x16 = {pattern, pattern};
+  for (uint64_t block_id = iree_hal_amdgpu_blit_linear_id();;) {
+    if (IREE_AMDGPU_UNLIKELY(vector_block_count == 0)) {
+      if (block_id == 0) {
+        uint8_t* tail_ptr = (uint8_t*)target_ptr;
+        for (uint64_t i = 0; i < tail_length; ++i) {
+          tail_ptr[i] = iree_hal_amdgpu_blit_pattern_byte(pattern, i);
+        }
+      }
+      return;
+    }
+    if (IREE_AMDGPU_UNLIKELY(block_id >= vector_block_count)) return;
+    const uint64_t element_offset = block_id * IREE_HAL_AMDGPU_FILL_BLOCK_COUNT;
+    const uint64_t element_count = IREE_AMDGPU_MIN(
+        IREE_HAL_AMDGPU_FILL_BLOCK_COUNT, full_element_count - element_offset);
+    if (IREE_AMDGPU_LIKELY(element_count == IREE_HAL_AMDGPU_FILL_BLOCK_COUNT)) {
+#pragma unroll
+      for (size_t i = 0; i < IREE_HAL_AMDGPU_FILL_BLOCK_COUNT; ++i) {
+        target_ptr[element_offset + i].value = pattern_x16;
+      }
+    } else {
+      for (size_t i = 0; i < element_count; ++i) {
+        target_ptr[element_offset + i].value = pattern_x16;
+      }
+    }
+    if (IREE_AMDGPU_UNLIKELY(tail_length &&
+                             block_id + 1 == vector_block_count)) {
+      uint8_t* tail_ptr = (uint8_t*)target_ptr + tail_offset;
+      for (uint64_t i = 0; i < tail_length; ++i) {
+        tail_ptr[i] = iree_hal_amdgpu_blit_pattern_byte(pattern, i);
+      }
+    }
+    if (!iree_hal_amdgpu_blit_advance(&block_id, block_stride)) return;
+  }
+}
+
+#endif  // IREE_AMDGPU_TARGET_DEVICE
+
 // Returns the bytes of |pattern| of length |pattern_length| splatted to
 // an 8-byte value.
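+//
+// Examples: (0xAB, 1) -> 0xABABABABABABABAB and
+//           (0x1234, 2) -> 0x1234123412341234.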
 static uint64_t iree_hal_amdgpu_device_extend_pattern_x8(
     const uint64_t pattern, const uint8_t pattern_length) {
   switch (pattern_length) {
-    case 1:
-      return ((uint64_t)pattern << 56) | ((uint64_t)pattern << 48) |
-             ((uint64_t)pattern << 40) | ((uint64_t)pattern << 32) |
-             ((uint64_t)pattern << 24) | ((uint64_t)pattern << 16) |
-             ((uint64_t)pattern << 8) | pattern;
-    case 2:
-      return ((uint64_t)pattern << 48) | ((uint64_t)pattern << 32) |
-             ((uint64_t)pattern << 16) | pattern;
-    case 4:
-      return ((uint64_t)pattern << 32) | pattern;
+    case 1: {
+      const uint64_t pattern_x1 = pattern & 0xFFu;
+      return (pattern_x1 << 56) | (pattern_x1 << 48) | (pattern_x1 << 40) |
+             (pattern_x1 << 32) | (pattern_x1 << 24) | (pattern_x1 << 16) |
+             (pattern_x1 << 8) | pattern_x1;
+    }
+    case 2: {
+      const uint64_t pattern_x2 = pattern & 0xFFFFu;
+      return (pattern_x2 << 48) | (pattern_x2 << 32) | (pattern_x2 << 16) |
+             pattern_x2;
+    }
+    case 4: {
+      const uint64_t pattern_x4 = pattern & 0xFFFFFFFFu;
+      return (pattern_x4 << 32) | pattern_x4;
+    }
     case 8:
       return pattern;
     default:
@@ -207,280 +331,439 @@
   }
 }
 
-iree_hsa_kernel_dispatch_packet_t* IREE_AMDGPU_RESTRICT
-iree_hal_amdgpu_device_buffer_fill_emplace_reserve(
+// Returns the number of blocks needed to cover |length| bytes when the vector
+// kernels process |elements_per_block| elements of |element_size| bytes each.
+// Callers of this helper gate on length >= UNALIGNED_MIN_SIZE (128), so
+// full_element_count is guaranteed to be non-zero here.
+static uint64_t iree_hal_amdgpu_blit_unaligned_block_count(
+    const uint64_t length, const uint64_t element_size,
+    const uint64_t elements_per_block) {
+  const uint64_t full_element_count = length / element_size;
+  return IREE_AMDGPU_CEIL_DIV(full_element_count, elements_per_block);
+}
+
+// Computes a bounded 2D dispatch grid for |block_count| logical blocks without
+// exceeding the 32-bit packet grid dimensions. Kernels use grid-stride loops,
+// so large transfers cap resident work to the context's launch metadata.
+// Zero-block launches use a 1x1 no-op dispatch, one-row launches use
+// [work_item_count, 1], and larger launches choose the smallest X that keeps Y
+// in-range, minimizing overshoot from the final partially-filled row.
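+//
+// Worked example (assuming the context cap does not bind): block_count 2^33
+// gives grid_size_x = 3 and grid_size_y = 2863311531; 3 * 2863311531 =
+// 2^33 + 1, so exactly one trailing invocation is wasted.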
+static bool iree_hal_amdgpu_blit_calculate_grid_size(
+    const iree_hal_amdgpu_device_buffer_transfer_context_t* context,
+    const uint64_t block_count, uint32_t* out_grid_size_x,
+    uint32_t* out_grid_size_y) {
+  const uint64_t max_work_item_count =
+      (uint64_t)context->max_workgroup_count * context->workgroup_size_x;
+  uint64_t work_item_count = IREE_AMDGPU_MIN(block_count, max_work_item_count);
+  if (IREE_AMDGPU_UNLIKELY(work_item_count == 0)) {
+    *out_grid_size_x = 1;
+    *out_grid_size_y = 1;
+    return true;
+  }
+  if (IREE_AMDGPU_LIKELY(work_item_count <= UINT32_MAX)) {
+    *out_grid_size_x = (uint32_t)work_item_count;
+    *out_grid_size_y = 1;
+    return true;
+  }
+  const uint64_t grid_size_x = 1 + ((work_item_count - 1) / UINT32_MAX);
+  if (IREE_AMDGPU_UNLIKELY(grid_size_x > UINT32_MAX)) {
+    return false;
+  }
+  const uint64_t grid_size_y = 1 + ((work_item_count - 1) / grid_size_x);
+  *out_grid_size_x = (uint32_t)grid_size_x;
+  *out_grid_size_y = (uint32_t)grid_size_y;
+  return true;
+}
+
+static void iree_hal_amdgpu_blit_emplace_dispatch(
     const iree_hal_amdgpu_device_buffer_transfer_context_t* IREE_AMDGPU_RESTRICT
         context,
-    void* target_ptr, const uint64_t length, uint64_t pattern,
-    const uint8_t pattern_length, uint64_t* IREE_AMDGPU_RESTRICT kernarg_ptr,
-    const uint64_t packet_id) {
-  IREE_AMDGPU_TRACE_BUFFER_SCOPE(context->trace_buffer);
-  IREE_AMDGPU_TRACE_ZONE_BEGIN(z0);
+    const iree_hal_amdgpu_device_kernel_args_t* IREE_AMDGPU_RESTRICT
+        kernel_args,
+    iree_hsa_kernel_dispatch_packet_t* IREE_AMDGPU_RESTRICT dispatch_packet,
+    const uint32_t grid_size_x, const uint32_t grid_size_y, void* kernarg_ptr) {
+  dispatch_packet->setup = kernel_args->setup;
+  dispatch_packet->workgroup_size[0] = context->workgroup_size_x;
+  dispatch_packet->workgroup_size[1] = 1;
+  dispatch_packet->workgroup_size[2] = 1;
+  dispatch_packet->reserved0 = 0;
+  dispatch_packet->grid_size[0] = grid_size_x;
+  dispatch_packet->grid_size[1] = grid_size_y;
+  dispatch_packet->grid_size[2] = 1;
+  dispatch_packet->private_segment_size = kernel_args->private_segment_size;
+  dispatch_packet->group_segment_size = kernel_args->group_segment_size;
+  dispatch_packet->kernel_object = kernel_args->kernel_object;
+  dispatch_packet->kernarg_address = kernarg_ptr;
+  dispatch_packet->reserved2 = 0;
+  dispatch_packet->completion_signal = iree_hsa_signal_null();
+}
 
+bool iree_hal_amdgpu_device_buffer_fill_emplace(
+    const iree_hal_amdgpu_device_buffer_transfer_context_t* IREE_AMDGPU_RESTRICT
+        context,
+    iree_hsa_kernel_dispatch_packet_t* IREE_AMDGPU_RESTRICT dispatch_packet,
+    void* target_ptr, uint64_t length, uint64_t pattern, uint8_t pattern_length,
+    void* IREE_AMDGPU_RESTRICT kernarg_ptr) {
   // Select the kernel for the fill operation.
   const iree_hal_amdgpu_device_kernel_args_t* IREE_AMDGPU_RESTRICT kernel_args =
       NULL;
-  size_t element_size = 1;
-  size_t block_size = 1;
-  if (iree_amdgpu_has_alignment((size_t)target_ptr,
+  uint64_t element_size = 1;
+  uint64_t block_size = 1;
+  bool uses_byte_length = false;
+  if (IREE_AMDGPU_UNLIKELY(pattern_length != 1 && pattern_length != 2 &&
+                           pattern_length != 4 && pattern_length != 8)) {
+    return false;
+  }
+  if (IREE_AMDGPU_UNLIKELY(
+          !iree_amdgpu_has_alignment((uintptr_t)target_ptr, pattern_length) ||
+          !iree_amdgpu_has_alignment(length, pattern_length))) {
+    return false;
+  }
+  if (iree_amdgpu_has_alignment((uintptr_t)target_ptr,
                                 IREE_HAL_AMDGPU_FILL_BLOCK_ELEMENT_SIZE) &&
       iree_amdgpu_has_alignment(length,
                                 IREE_HAL_AMDGPU_FILL_BLOCK_ELEMENT_SIZE)) {
-    IREE_AMDGPU_TRACE_ZONE_APPEND_TEXT_LITERAL(z0, "fill_block_x16");
     pattern = iree_hal_amdgpu_device_extend_pattern_x8(pattern, pattern_length);
     kernel_args =
         &context->kernels->iree_hal_amdgpu_device_buffer_fill_block_x16;
     element_size = IREE_HAL_AMDGPU_FILL_BLOCK_ELEMENT_SIZE;
     block_size = IREE_HAL_AMDGPU_FILL_BLOCK_COUNT;
+  } else if (pattern_length <= 4 &&
+             iree_amdgpu_has_alignment(
+                 (uintptr_t)target_ptr,
+                 IREE_HAL_AMDGPU_FILL_BLOCK_X4_ELEMENT_SIZE) &&
+             iree_amdgpu_has_alignment(
+                 length, IREE_HAL_AMDGPU_FILL_BLOCK_X4_ELEMENT_SIZE)) {
+    pattern = iree_hal_amdgpu_device_extend_pattern_x8(pattern, pattern_length);
+    kernel_args = &context->kernels->iree_hal_amdgpu_device_buffer_fill_x4;
+    element_size = IREE_HAL_AMDGPU_FILL_BLOCK_X4_ELEMENT_SIZE;
+    block_size = IREE_HAL_AMDGPU_FILL_BLOCK_X4_COUNT;
+  } else if (length >= IREE_HAL_AMDGPU_FILL_BLOCK_UNALIGNED_MIN_SIZE) {
+    pattern = iree_hal_amdgpu_device_extend_pattern_x8(pattern, pattern_length);
+    kernel_args = &context->kernels
+                       ->iree_hal_amdgpu_device_buffer_fill_block_unaligned_x16;
+    element_size = 1;
+    block_size = 1;
+    uses_byte_length = true;
   } else {
     switch (pattern_length) {
       case 1:
-        IREE_AMDGPU_TRACE_ZONE_APPEND_TEXT_LITERAL(z0, "fill_x1");
         kernel_args = &context->kernels->iree_hal_amdgpu_device_buffer_fill_x1;
         break;
       case 2:
-        IREE_AMDGPU_TRACE_ZONE_APPEND_TEXT_LITERAL(z0, "fill_x2");
         kernel_args = &context->kernels->iree_hal_amdgpu_device_buffer_fill_x2;
+        element_size = 2;
         break;
       case 4:
-        IREE_AMDGPU_TRACE_ZONE_APPEND_TEXT_LITERAL(z0, "fill_x4");
         kernel_args = &context->kernels->iree_hal_amdgpu_device_buffer_fill_x4;
+        element_size = 4;
         break;
       case 8:
-        IREE_AMDGPU_TRACE_ZONE_APPEND_TEXT_LITERAL(z0, "fill_x8");
         kernel_args = &context->kernels->iree_hal_amdgpu_device_buffer_fill_x8;
+        element_size = 8;
         break;
+      default:
+        return false;
     }
-    element_size = pattern_length;
     block_size = 1;
   }
 
+  const uint64_t element_count = length / element_size;
+  const uint64_t block_count =
+      uses_byte_length ? iree_hal_amdgpu_blit_unaligned_block_count(
+                             length, IREE_HAL_AMDGPU_FILL_BLOCK_ELEMENT_SIZE,
+                             IREE_HAL_AMDGPU_FILL_BLOCK_COUNT)
+                       : IREE_AMDGPU_CEIL_DIV(element_count, block_size);
+  uint32_t grid_size_x = 0;
+  uint32_t grid_size_y = 0;
+  if (IREE_AMDGPU_UNLIKELY(!iree_hal_amdgpu_blit_calculate_grid_size(
+          context, block_count, &grid_size_x, &grid_size_y))) {
+    return false;
+  }
+
   // Update kernargs (same API for all kernels).
-  const size_t element_count = length / element_size;
   iree_hal_amdgpu_device_buffer_fill_kernargs_t* kernargs =
       (iree_hal_amdgpu_device_buffer_fill_kernargs_t*)kernarg_ptr;
   kernargs->target_ptr = target_ptr;
   kernargs->element_length = element_count;
   kernargs->pattern = pattern;
 
-  // To support fills with more than UINT_MAX elements (uint32_t grid_size)
-  // we divide the problem into chunks as needed. We keep the innermost chunk
-  // size small as if we do [X,Y,1] we're likely to overshoot and don't want to
-  // have too many wasted invocations.
-  const size_t block_count = IREE_AMDGPU_CEIL_DIV(element_count, block_size);
-  uint32_t grid_size_x = 1;
-  uint32_t grid_size_y = 1;
-  if (IREE_AMDGPU_LIKELY(block_count <= 0xFFFFFFFFu)) {
-    grid_size_x = (uint32_t)block_count;
-  } else {
-    grid_size_x = 256;
-    grid_size_y = (uint32_t)IREE_AMDGPU_CEIL_DIV(block_count, grid_size_x);
-  }
-
-  // Populate the packet.
-  const uint64_t queue_mask = context->queue.size - 1;  // power of two
-  iree_hsa_kernel_dispatch_packet_t* IREE_AMDGPU_RESTRICT dispatch_packet =
-      context->queue.base_address + (packet_id & queue_mask) * 64;
-  dispatch_packet->setup = kernel_args->setup;
-  dispatch_packet->workgroup_size[0] = kernel_args->workgroup_size[0];
-  dispatch_packet->workgroup_size[1] = kernel_args->workgroup_size[1];
-  dispatch_packet->workgroup_size[2] = kernel_args->workgroup_size[2];
-  dispatch_packet->reserved0 = 0;
-  dispatch_packet->grid_size[0] = grid_size_x;
-  dispatch_packet->grid_size[1] = grid_size_y;
-  dispatch_packet->grid_size[2] = 1;
-  dispatch_packet->private_segment_size = kernel_args->private_segment_size;
-  dispatch_packet->group_segment_size = kernel_args->group_segment_size;
-  dispatch_packet->kernel_object = kernel_args->kernel_object;
-  dispatch_packet->kernarg_address = kernarg_ptr;
-  dispatch_packet->reserved2 = 0;
-
-  IREE_AMDGPU_TRACE_ZONE_END(z0);
-  return dispatch_packet;
-}
-
-void iree_hal_amdgpu_device_buffer_fill_enqueue(
-    const iree_hal_amdgpu_device_buffer_transfer_context_t* IREE_AMDGPU_RESTRICT
-        context,
-    void* target_ptr, const uint64_t length, const uint64_t pattern,
-    const uint8_t pattern_length, uint64_t* IREE_AMDGPU_RESTRICT kernarg_ptr) {
-  IREE_AMDGPU_TRACE_BUFFER_SCOPE(context->trace_buffer);
-  IREE_AMDGPU_TRACE_ZONE_BEGIN(z0);
-  IREE_AMDGPU_TRACE_ZONE_APPEND_VALUE_I64(z0, length);
-
-  // Reserve and begin populating the operation packet.
-  // When tracing is enabled capture the timing signal.
-  iree_hsa_signal_t completion_signal = iree_hsa_signal_null();
-  const uint64_t packet_id = iree_hal_amdgpu_device_blit_reserve(
-      context, IREE_HAL_AMDGPU_TRACE_EXECUTION_ZONE_TYPE_FILL,
-      &completion_signal);
-
-  // Emplace the dispatch packet into the queue.
-  // Note that until the packet is issued the queue will stall.
-  iree_hsa_kernel_dispatch_packet_t* IREE_AMDGPU_RESTRICT packet =
-      iree_hal_amdgpu_device_buffer_fill_emplace_reserve(
-          context, target_ptr, length, pattern, pattern_length, kernarg_ptr,
-          packet_id);
-
-  // Issues the buffer operation packet by configuring its header and signaling
-  // the queue doorbell.
-  iree_hal_amdgpu_device_blit_commit(context, packet_id, packet,
-                                     completion_signal);
-
-  IREE_AMDGPU_TRACE_ZONE_END(z0);
+  iree_hal_amdgpu_blit_emplace_dispatch(context, kernel_args, dispatch_packet,
+                                        grid_size_x, grid_size_y, kernarg_ptr);
+  return true;
 }
 
 //===----------------------------------------------------------------------===//
 // iree_hal_amdgpu_device_buffer_copy_*
 //===----------------------------------------------------------------------===//
+//
+// IREE_HAL_AMDGPU_COPY_BLOCK_COUNT tuning sweep on gfx1100 (RDNA3, wave32,
+// 96 CUs, GDDR6). Bandwidth in GB/s; src/tgt are buffer offset alignments.
+// Build: --compilation_mode=opt --copt=-O3 --copt=-march=native
+//        --copt=-flto=thin --linkopt=-flto=thin.
+//
+//                          QueueCopy (single)       QueueCopyBatch20
+//   length   src tgt    cnt=1 cnt=2 cnt=4 cnt=8   cnt=1 cnt=2 cnt=4 cnt=8
+//   --------------------------------------------------------------------
+//    64KiB    0   0      2.7   2.5   2.7   3.0     5.7   5.4   7.1   8.1
+//     2MiB    0   0     91.6  86.1  85.2  86.5   177.7 176.0 146.2 151.7
+//     2MiB    1   2     91.7  89.7  82.6  80.6   176.7 161.0 138.4 134.7
+//     1GiB    0   0    299.5 161.8 118.8 103.7   303.9 163.0 119.2 104.0
+//
+//   Geomean over bandwidth-relevant rows (64KiB..1GiB across alignments):
+//     QueueCopy       : cnt=1: 28.1, cnt=2: 23.5, cnt=4: 23.0, cnt=8: 22.7
+//     QueueCopyBatch20: cnt=1: 50.4, cnt=2: 43.3, cnt=4: 41.3, cnt=8: 41.2
+//
+// The 3x cliff at 1GiB matches the expected VGPR occupancy drop from the
+// `#pragma unroll 8` body: every in-flight 16-byte copy needs ~4 VGPRs, so
+// cnt=8 burns ~32 extra VGPRs, cutting max waves-per-SIMD from ~16 to ~5 on
+// gfx1100 (256 VGPRs max per wave) and starving the latency-hiding budget for
+// bandwidth-bound transfers.
+//
+// Small batched transfers (<16KiB) prefer cnt=8 by ~15-20% because the work
+// is launch-overhead-bound and each wave doing more reduces dispatch cost,
+// but absolute throughput there is <1 GB/s and doesn't move real workloads.
+// Sizes <64KiB are omitted here — they are measurement-variance-dominated.
+//
+// cnt=1 is the tuned value for gfx1100. CDNA (MI300+, 512 VGPRs max per wave)
+// may prefer a larger value and should be re-swept before changing.
+
+#if defined(IREE_AMDGPU_TARGET_DEVICE)
 
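+// Scalar byte copy: one element per work item per grid-stride step. Selected
+// when alignment rules out the vector kernels and |length| is below
+// IREE_HAL_AMDGPU_COPY_BLOCK_UNALIGNED_MIN_SIZE.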
 IREE_AMDGPU_ATTRIBUTE_KERNEL void iree_hal_amdgpu_device_buffer_copy_x1(
     const uint8_t* IREE_AMDGPU_RESTRICT source_ptr,
     uint8_t* IREE_AMDGPU_RESTRICT target_ptr, const uint64_t element_length) {
-  const size_t element_offset = iree_hal_amdgpu_blit_linear_id();
-  if (IREE_AMDGPU_LIKELY(element_offset < element_length)) {
-    // Slowest possible copy; benchmarks required to iterate on better impls.
+  const uint64_t element_stride = iree_hal_amdgpu_blit_grid_size();
+  for (uint64_t element_offset = iree_hal_amdgpu_blit_linear_id();
+       element_offset < element_length;) {
     target_ptr[element_offset] = source_ptr[element_offset];
+    if (!iree_hal_amdgpu_blit_advance(&element_offset, element_stride)) break;
   }
 }
 
-#define IREE_HAL_AMDGPU_COPY_BLOCK_ELEMENT_SIZE sizeof(iree_amdgpu_uint64x2_t)
-#define IREE_HAL_AMDGPU_COPY_BLOCK_COUNT 8
-#define IREE_HAL_AMDGPU_COPY_BLOCK_SIZE \
-  (IREE_HAL_AMDGPU_COPY_BLOCK_ELEMENT_SIZE * IREE_HAL_AMDGPU_COPY_BLOCK_COUNT)
+// Copies blocks of up to IREE_HAL_AMDGPU_COPY_BLOCK_X4_COUNT 4-byte elements
+// from |source_ptr| to |target_ptr|. Requires an alignment of 4-bytes on all of
+// |source_ptr|, |target_ptr|, and |length|.
+IREE_AMDGPU_ATTRIBUTE_KERNEL void iree_hal_amdgpu_device_buffer_copy_block_x4(
+    const uint32_t* IREE_AMDGPU_RESTRICT source_ptr,
+    uint32_t* IREE_AMDGPU_RESTRICT target_ptr, const uint64_t element_length) {
+  const uint64_t block_stride = iree_hal_amdgpu_blit_grid_size();
+  for (uint64_t block_id = iree_hal_amdgpu_blit_linear_id();;) {
+    const uint64_t element_offset =
+        block_id * IREE_HAL_AMDGPU_COPY_BLOCK_X4_COUNT;
+    if (IREE_AMDGPU_UNLIKELY(element_offset >= element_length)) return;
+    const uint64_t element_count = IREE_AMDGPU_MIN(
+        IREE_HAL_AMDGPU_COPY_BLOCK_X4_COUNT, element_length - element_offset);
+    if (IREE_AMDGPU_LIKELY(element_count ==
+                           IREE_HAL_AMDGPU_COPY_BLOCK_X4_COUNT)) {
+#pragma unroll
+      for (size_t i = 0; i < IREE_HAL_AMDGPU_COPY_BLOCK_X4_COUNT; ++i) {
+        target_ptr[element_offset + i] = source_ptr[element_offset + i];
+      }
+    } else {
+      for (size_t i = 0; i < element_count; ++i) {
+        target_ptr[element_offset + i] = source_ptr[element_offset + i];
+      }
+    }
+    if (!iree_hal_amdgpu_blit_advance(&block_id, block_stride)) return;
+  }
+}
 
-// Copies a block of up to IREE_HAL_AMDGPU_COPY_BLOCK_COUNT 16-byte elements
+// Copies blocks of up to IREE_HAL_AMDGPU_COPY_BLOCK_X8_COUNT 8-byte elements
+// from |source_ptr| to |target_ptr|. Requires an alignment of 8-bytes on all of
+// |source_ptr|, |target_ptr|, and |length|.
+IREE_AMDGPU_ATTRIBUTE_KERNEL void iree_hal_amdgpu_device_buffer_copy_block_x8(
+    const uint64_t* IREE_AMDGPU_RESTRICT source_ptr,
+    uint64_t* IREE_AMDGPU_RESTRICT target_ptr, const uint64_t element_length) {
+  const uint64_t block_stride = iree_hal_amdgpu_blit_grid_size();
+  for (uint64_t block_id = iree_hal_amdgpu_blit_linear_id();;) {
+    const uint64_t element_offset =
+        block_id * IREE_HAL_AMDGPU_COPY_BLOCK_X8_COUNT;
+    if (IREE_AMDGPU_UNLIKELY(element_offset >= element_length)) return;
+    const uint64_t element_count = IREE_AMDGPU_MIN(
+        IREE_HAL_AMDGPU_COPY_BLOCK_X8_COUNT, element_length - element_offset);
+    if (IREE_AMDGPU_LIKELY(element_count ==
+                           IREE_HAL_AMDGPU_COPY_BLOCK_X8_COUNT)) {
+#pragma unroll
+      for (size_t i = 0; i < IREE_HAL_AMDGPU_COPY_BLOCK_X8_COUNT; ++i) {
+        target_ptr[element_offset + i] = source_ptr[element_offset + i];
+      }
+    } else {
+      for (size_t i = 0; i < element_count; ++i) {
+        target_ptr[element_offset + i] = source_ptr[element_offset + i];
+      }
+    }
+    if (!iree_hal_amdgpu_blit_advance(&block_id, block_stride)) return;
+  }
+}
+
+// Copies blocks of up to IREE_HAL_AMDGPU_COPY_BLOCK_COUNT 16-byte elements
 // from |source_ptr| to |target_ptr|. Requires an alignment of 16-bytes on all
 // of |source_ptr|, |target_ptr|, and |length|.
-//
-// Dispatched on a 2D grid with up to UINT32_MAX blocks on X.
 IREE_AMDGPU_ATTRIBUTE_KERNEL void iree_hal_amdgpu_device_buffer_copy_block_x16(
     const iree_amdgpu_uint64x2_t* IREE_AMDGPU_RESTRICT source_ptr,
     iree_amdgpu_uint64x2_t* IREE_AMDGPU_RESTRICT target_ptr,
     const uint64_t element_length) {
-  const size_t block_id = iree_hal_amdgpu_blit_linear_id();
-  const size_t element_offset = block_id * IREE_HAL_AMDGPU_COPY_BLOCK_COUNT;
-  if (IREE_AMDGPU_UNLIKELY(element_offset >= element_length)) return;
-  const size_t element_count = IREE_AMDGPU_MIN(IREE_HAL_AMDGPU_COPY_BLOCK_COUNT,
-                                               element_length - element_offset);
-  if (IREE_AMDGPU_LIKELY(element_count == IREE_HAL_AMDGPU_COPY_BLOCK_COUNT)) {
+  const uint64_t block_stride = iree_hal_amdgpu_blit_grid_size();
+  for (uint64_t block_id = iree_hal_amdgpu_blit_linear_id();;) {
+    const uint64_t element_offset = block_id * IREE_HAL_AMDGPU_COPY_BLOCK_COUNT;
+    if (IREE_AMDGPU_UNLIKELY(element_offset >= element_length)) return;
+    const uint64_t element_count = IREE_AMDGPU_MIN(
+        IREE_HAL_AMDGPU_COPY_BLOCK_COUNT, element_length - element_offset);
+    if (IREE_AMDGPU_LIKELY(element_count == IREE_HAL_AMDGPU_COPY_BLOCK_COUNT)) {
 #pragma unroll
-    for (int i = 0; i < IREE_HAL_AMDGPU_COPY_BLOCK_COUNT; ++i) {
-      target_ptr[element_offset + i] = source_ptr[element_offset + i];
+      for (size_t i = 0; i < IREE_HAL_AMDGPU_COPY_BLOCK_COUNT; ++i) {
+        target_ptr[element_offset + i] = source_ptr[element_offset + i];
+      }
+    } else {
+      for (size_t i = 0; i < element_count; ++i) {
+        target_ptr[element_offset + i] = source_ptr[element_offset + i];
+      }
     }
-  } else {
-    for (int i = 0; i < element_count; ++i) {
-      target_ptr[element_offset + i] = source_ptr[element_offset + i];
-    }
+    if (!iree_hal_amdgpu_blit_advance(&block_id, block_stride)) return;
   }
 }
 
-// TODO(benvanik): experiment with enqueuing SDMA somehow (may need to take a
-// DMA queue as well as the dispatch queue). Note that on some configurations
-// (InfinityFabric) blit kernels can be 2x faster than SDMA.
-iree_hsa_kernel_dispatch_packet_t* IREE_AMDGPU_RESTRICT
-iree_hal_amdgpu_device_buffer_copy_emplace_reserve(
+// Copies a byte-granular region using 16-byte vector loads/stores where
+// possible. See the note on fill_block_unaligned_x16 regarding the
+// |element_length| field sharing its name with the aligned variants despite
+// being a byte length here.
+IREE_AMDGPU_ATTRIBUTE_KERNEL void
+iree_hal_amdgpu_device_buffer_copy_block_unaligned_x16(
+    const iree_amdgpu_unaligned_uint64x2_t* IREE_AMDGPU_RESTRICT source_ptr,
+    iree_amdgpu_unaligned_uint64x2_t* IREE_AMDGPU_RESTRICT target_ptr,
+    const uint64_t element_length) {
+  const uint64_t full_element_count =
+      element_length / IREE_HAL_AMDGPU_COPY_BLOCK_ELEMENT_SIZE;
+  const uint64_t vector_block_count = IREE_AMDGPU_CEIL_DIV(
+      full_element_count, IREE_HAL_AMDGPU_COPY_BLOCK_COUNT);
+  const uint64_t tail_offset =
+      full_element_count * IREE_HAL_AMDGPU_COPY_BLOCK_ELEMENT_SIZE;
+  const uint64_t tail_length = element_length - tail_offset;
+  const uint64_t block_stride = iree_hal_amdgpu_blit_grid_size();
+  for (uint64_t block_id = iree_hal_amdgpu_blit_linear_id();;) {
+    if (IREE_AMDGPU_UNLIKELY(vector_block_count == 0)) {
+      if (block_id == 0) {
+        const uint8_t* source_tail_ptr = (const uint8_t*)source_ptr;
+        uint8_t* target_tail_ptr = (uint8_t*)target_ptr;
+        for (uint64_t i = 0; i < tail_length; ++i) {
+          target_tail_ptr[i] = source_tail_ptr[i];
+        }
+      }
+      return;
+    }
+    if (IREE_AMDGPU_UNLIKELY(block_id >= vector_block_count)) return;
+    const uint64_t element_offset = block_id * IREE_HAL_AMDGPU_COPY_BLOCK_COUNT;
+    const uint64_t element_count = IREE_AMDGPU_MIN(
+        IREE_HAL_AMDGPU_COPY_BLOCK_COUNT, full_element_count - element_offset);
+    if (IREE_AMDGPU_LIKELY(element_count == IREE_HAL_AMDGPU_COPY_BLOCK_COUNT)) {
+#pragma unroll
+      for (size_t i = 0; i < IREE_HAL_AMDGPU_COPY_BLOCK_COUNT; ++i) {
+        target_ptr[element_offset + i].value =
+            source_ptr[element_offset + i].value;
+      }
+    } else {
+      for (size_t i = 0; i < element_count; ++i) {
+        target_ptr[element_offset + i].value =
+            source_ptr[element_offset + i].value;
+      }
+    }
+    if (IREE_AMDGPU_UNLIKELY(tail_length &&
+                             block_id + 1 == vector_block_count)) {
+      const uint8_t* source_tail_ptr = (const uint8_t*)source_ptr + tail_offset;
+      uint8_t* target_tail_ptr = (uint8_t*)target_ptr + tail_offset;
+      for (uint64_t i = 0; i < tail_length; ++i) {
+        target_tail_ptr[i] = source_tail_ptr[i];
+      }
+    }
+    if (!iree_hal_amdgpu_blit_advance(&block_id, block_stride)) return;
+  }
+}
+
+#endif  // IREE_AMDGPU_TARGET_DEVICE
+
+// Copies are currently implemented as builtin blit kernel dispatches. SDMA
+// emission belongs in a queue-specific wrapper because it changes queue
+// ownership and packet reservation policy.
+bool iree_hal_amdgpu_device_buffer_copy_emplace(
     const iree_hal_amdgpu_device_buffer_transfer_context_t* IREE_AMDGPU_RESTRICT
         context,
-    const void* source_ptr, void* target_ptr, const uint64_t length,
-    uint64_t* IREE_AMDGPU_RESTRICT kernarg_ptr, const uint64_t packet_id) {
-  IREE_AMDGPU_TRACE_BUFFER_SCOPE(context->trace_buffer);
-  IREE_AMDGPU_TRACE_ZONE_BEGIN(z0);
-
+    iree_hsa_kernel_dispatch_packet_t* IREE_AMDGPU_RESTRICT dispatch_packet,
+    const void* source_ptr, void* target_ptr, uint64_t length,
+    void* IREE_AMDGPU_RESTRICT kernarg_ptr) {
   // Select the kernel for the copy operation.
-  // TODO(benvanik): switch kernel based on source/target/length alignment.
   const iree_hal_amdgpu_device_kernel_args_t* IREE_AMDGPU_RESTRICT kernel_args =
       NULL;
-  size_t element_size = 1;
-  size_t block_size = 1;
-  if (iree_amdgpu_has_alignment((size_t)source_ptr,
+  uint64_t element_size = 1;
+  uint64_t block_size = 1;
+  bool uses_byte_length = false;
+  if (iree_amdgpu_has_alignment((uintptr_t)source_ptr,
                                 IREE_HAL_AMDGPU_COPY_BLOCK_ELEMENT_SIZE) &&
-      iree_amdgpu_has_alignment((size_t)target_ptr,
+      iree_amdgpu_has_alignment((uintptr_t)target_ptr,
                                 IREE_HAL_AMDGPU_COPY_BLOCK_ELEMENT_SIZE) &&
       iree_amdgpu_has_alignment(length,
                                 IREE_HAL_AMDGPU_COPY_BLOCK_ELEMENT_SIZE)) {
-    IREE_AMDGPU_TRACE_ZONE_APPEND_TEXT_LITERAL(z0, "copy_block_x16");
     kernel_args =
         &context->kernels->iree_hal_amdgpu_device_buffer_copy_block_x16;
     element_size = IREE_HAL_AMDGPU_COPY_BLOCK_ELEMENT_SIZE;
     block_size = IREE_HAL_AMDGPU_COPY_BLOCK_COUNT;
+  } else if (iree_amdgpu_has_alignment(
+                 (uintptr_t)source_ptr,
+                 IREE_HAL_AMDGPU_COPY_BLOCK_X8_ELEMENT_SIZE) &&
+             iree_amdgpu_has_alignment(
+                 (uintptr_t)target_ptr,
+                 IREE_HAL_AMDGPU_COPY_BLOCK_X8_ELEMENT_SIZE) &&
+             iree_amdgpu_has_alignment(
+                 length, IREE_HAL_AMDGPU_COPY_BLOCK_X8_ELEMENT_SIZE)) {
+    kernel_args =
+        &context->kernels->iree_hal_amdgpu_device_buffer_copy_block_x8;
+    element_size = IREE_HAL_AMDGPU_COPY_BLOCK_X8_ELEMENT_SIZE;
+    block_size = IREE_HAL_AMDGPU_COPY_BLOCK_X8_COUNT;
+  } else if (iree_amdgpu_has_alignment(
+                 (uintptr_t)source_ptr,
+                 IREE_HAL_AMDGPU_COPY_BLOCK_X4_ELEMENT_SIZE) &&
+             iree_amdgpu_has_alignment(
+                 (uintptr_t)target_ptr,
+                 IREE_HAL_AMDGPU_COPY_BLOCK_X4_ELEMENT_SIZE) &&
+             iree_amdgpu_has_alignment(
+                 length, IREE_HAL_AMDGPU_COPY_BLOCK_X4_ELEMENT_SIZE)) {
+    kernel_args =
+        &context->kernels->iree_hal_amdgpu_device_buffer_copy_block_x4;
+    element_size = IREE_HAL_AMDGPU_COPY_BLOCK_X4_ELEMENT_SIZE;
+    block_size = IREE_HAL_AMDGPU_COPY_BLOCK_X4_COUNT;
+  } else if (length >= IREE_HAL_AMDGPU_COPY_BLOCK_UNALIGNED_MIN_SIZE) {
+    kernel_args = &context->kernels
+                       ->iree_hal_amdgpu_device_buffer_copy_block_unaligned_x16;
+    element_size = 1;
+    block_size = 1;
+    uses_byte_length = true;
   } else {
-    IREE_AMDGPU_TRACE_ZONE_APPEND_TEXT_LITERAL(z0, "copy_x1");
     kernel_args = &context->kernels->iree_hal_amdgpu_device_buffer_copy_x1;
     element_size = 1;
     block_size = 1;
   }
 
+  const uint64_t element_count = length / element_size;
+  const uint64_t block_count =
+      uses_byte_length ? iree_hal_amdgpu_blit_unaligned_block_count(
+                             length, IREE_HAL_AMDGPU_COPY_BLOCK_ELEMENT_SIZE,
+                             IREE_HAL_AMDGPU_COPY_BLOCK_COUNT)
+                       : IREE_AMDGPU_CEIL_DIV(element_count, block_size);
+  uint32_t grid_size_x = 0;
+  uint32_t grid_size_y = 0;
+  if (IREE_AMDGPU_UNLIKELY(!iree_hal_amdgpu_blit_calculate_grid_size(
+          context, block_count, &grid_size_x, &grid_size_y))) {
+    return false;
+  }
+
   // Update kernargs (same API for all kernels).
-  const size_t element_count = length / element_size;
   iree_hal_amdgpu_device_buffer_copy_kernargs_t* kernargs =
       (iree_hal_amdgpu_device_buffer_copy_kernargs_t*)kernarg_ptr;
   kernargs->source_ptr = source_ptr;
   kernargs->target_ptr = target_ptr;
   kernargs->element_length = element_count;
 
-  // To support transfers with more than UINT_MAX elements (uint32_t grid_size)
-  // we divide the problem into chunks as needed. We keep the innermost chunk
-  // size small as if we do [X,Y,1] we're likely to overshoot and don't want to
-  // have too many wasted invocations.
-  const size_t block_count = IREE_AMDGPU_CEIL_DIV(element_count, block_size);
-  uint32_t grid_size_x = 1;
-  uint32_t grid_size_y = 1;
-  if (IREE_AMDGPU_LIKELY(block_count <= 0xFFFFFFFFu)) {
-    grid_size_x = (uint32_t)block_count;
-  } else {
-    grid_size_x = 256;
-    grid_size_y = (uint32_t)IREE_AMDGPU_CEIL_DIV(block_count, grid_size_x);
-  }
-
-  // Populate the packet.
-  const uint64_t queue_mask = context->queue.size - 1;  // power of two
-  iree_hsa_kernel_dispatch_packet_t* IREE_AMDGPU_RESTRICT dispatch_packet =
-      context->queue.base_address + (packet_id & queue_mask) * 64;
-  dispatch_packet->setup = kernel_args->setup;
-  dispatch_packet->workgroup_size[0] = kernel_args->workgroup_size[0];
-  dispatch_packet->workgroup_size[1] = kernel_args->workgroup_size[1];
-  dispatch_packet->workgroup_size[2] = kernel_args->workgroup_size[2];
-  dispatch_packet->reserved0 = 0;
-  dispatch_packet->grid_size[0] = grid_size_x;
-  dispatch_packet->grid_size[1] = grid_size_y;
-  dispatch_packet->grid_size[2] = 1;
-  dispatch_packet->private_segment_size = kernel_args->private_segment_size;
-  dispatch_packet->group_segment_size = kernel_args->group_segment_size;
-  dispatch_packet->kernel_object = kernel_args->kernel_object;
-  dispatch_packet->kernarg_address = kernarg_ptr;
-  dispatch_packet->reserved2 = 0;
-
-  IREE_AMDGPU_TRACE_ZONE_END(z0);
-  return dispatch_packet;
-}
-
-void iree_hal_amdgpu_device_buffer_copy_enqueue(
-    const iree_hal_amdgpu_device_buffer_transfer_context_t* IREE_AMDGPU_RESTRICT
-        context,
-    const void* source_ptr, void* target_ptr, const uint64_t length,
-    uint64_t* IREE_AMDGPU_RESTRICT kernarg_ptr) {
-  IREE_AMDGPU_TRACE_BUFFER_SCOPE(context->trace_buffer);
-  IREE_AMDGPU_TRACE_ZONE_BEGIN(z0);
-  IREE_AMDGPU_TRACE_ZONE_APPEND_VALUE_I64(z0, length);
-
-  // Reserve and begin populating the operation packet.
-  // When tracing is enabled capture the timing signal.
-  iree_hsa_signal_t completion_signal = iree_hsa_signal_null();
-  const uint64_t packet_id = iree_hal_amdgpu_device_blit_reserve(
-      context, IREE_HAL_AMDGPU_TRACE_EXECUTION_ZONE_TYPE_COPY,
-      &completion_signal);
-
-  // Emplace the dispatch packet into the queue.
-  // Note that until the packet is issued the queue will stall.
-  iree_hsa_kernel_dispatch_packet_t* IREE_AMDGPU_RESTRICT packet =
-      iree_hal_amdgpu_device_buffer_copy_emplace_reserve(
-          context, source_ptr, target_ptr, length, kernarg_ptr, packet_id);
-
-  // Issues the buffer operation packet by configuring its header and signaling
-  // the queue doorbell.
-  iree_hal_amdgpu_device_blit_commit(context, packet_id, packet,
-                                     completion_signal);
-
-  IREE_AMDGPU_TRACE_ZONE_END(z0);
+  iree_hal_amdgpu_blit_emplace_dispatch(context, kernel_args, dispatch_packet,
+                                        grid_size_x, grid_size_y, kernarg_ptr);
+  return true;
 }
diff --git a/runtime/src/iree/hal/drivers/amdgpu/device/blit.h b/runtime/src/iree/hal/drivers/amdgpu/device/blit.h
index db6c6f5..6195d9d 100644
--- a/runtime/src/iree/hal/drivers/amdgpu/device/blit.h
+++ b/runtime/src/iree/hal/drivers/amdgpu/device/blit.h
@@ -7,25 +7,46 @@
 #ifndef IREE_HAL_DRIVERS_AMDGPU_DEVICE_BLIT_H_
 #define IREE_HAL_DRIVERS_AMDGPU_DEVICE_BLIT_H_
 
+#include "iree/hal/drivers/amdgpu/abi/queue.h"
 #include "iree/hal/drivers/amdgpu/device/kernels.h"
-#include "iree/hal/drivers/amdgpu/device/support/common.h"
-#include "iree/hal/drivers/amdgpu/device/support/queue.h"
-#include "iree/hal/drivers/amdgpu/device/tracing.h"
+
+#ifdef __cplusplus
+extern "C" {
+#endif  // __cplusplus
 
 //===----------------------------------------------------------------------===//
 // Blit Kernels
 //===----------------------------------------------------------------------===//
 
-// Context used when scheduling transfer commands.
+// Builtin transfer kernel table used when populating blit dispatch packets.
+// Queue reservation, packet header commit, completion-signal assignment, and
+// doorbell writes are handled by the caller's queue implementation.
 typedef struct iree_hal_amdgpu_device_buffer_transfer_context_t {
-  // Target queue that will execute the transfer operation.
-  iree_amd_cached_queue_t queue;
   // Handles to opaque kernel objects used to dispatch builtin kernels.
   const iree_hal_amdgpu_device_kernels_t* kernels;
-  // Optional trace buffer used when tracing infrastructure is available.
-  iree_hal_amdgpu_device_trace_buffer_t* trace_buffer;
+
+  // Device wavefront width used when choosing the builtin blit workgroup size.
+  // This is kept explicit so future wave32/wave64-specialized kernels can
+  // select variants without guessing from the loaded code object.
+  uint16_t wavefront_size;
+  // X-dimension workgroup size used for all builtin blit dispatches. Y/Z are
+  // always 1; the kernels are 1D along the global linear index.
+  uint16_t workgroup_size_x;
+
+  // Maximum number of blit workgroups to launch for one transfer. Kernels use
+  // grid-stride loops, so large transfers bound resident work and let each
+  // lane process multiple elements instead of launching one lane per element.
+  uint32_t max_workgroup_count;
 } iree_hal_amdgpu_device_buffer_transfer_context_t;
 
+// Initializes a builtin transfer context from device properties. The caller
+// must ensure |compute_unit_count| is non-zero and |wavefront_size| is one of
+// {32, 64}; see physical_device.c for the HSA-query-backed validation path.
+void iree_hal_amdgpu_device_buffer_transfer_context_initialize(
+    const iree_hal_amdgpu_device_kernels_t* kernels,
+    uint32_t compute_unit_count, uint32_t wavefront_size,
+    iree_hal_amdgpu_device_buffer_transfer_context_t* out_context);
+
 // Kernel arguments for the `iree_hal_amdgpu_device_buffer_fill_*` family.
 typedef struct iree_hal_amdgpu_device_buffer_fill_kernargs_t {
   void* target_ptr;
@@ -47,52 +68,45 @@
   sizeof(iree_hal_amdgpu_device_buffer_copy_kernargs_t)
 #define IREE_HAL_AMDGPU_DEVICE_BUFFER_COPY_KERNARG_ALIGNMENT \
   IREE_AMDGPU_ALIGNOF(iree_hal_amdgpu_device_buffer_copy_kernargs_t)
+// Alignment used for host-staged update payloads consumed by copy kernels.
+#define IREE_HAL_AMDGPU_DEVICE_BUFFER_COPY_STAGED_SOURCE_ALIGNMENT 16
+// Byte offset to a host-staged update payload following copy kernargs.
+#define IREE_HAL_AMDGPU_DEVICE_BUFFER_COPY_STAGED_SOURCE_OFFSET       \
+  ((IREE_HAL_AMDGPU_DEVICE_BUFFER_COPY_KERNARG_SIZE +                 \
+    IREE_HAL_AMDGPU_DEVICE_BUFFER_COPY_STAGED_SOURCE_ALIGNMENT - 1) & \
+   ~(IREE_HAL_AMDGPU_DEVICE_BUFFER_COPY_STAGED_SOURCE_ALIGNMENT - 1))
 
-#if defined(IREE_AMDGPU_TARGET_DEVICE)
-
-// Emplaces a fill dispatch packet in the target queue at the given index.
-// The queue doorbell will not be signaled.
+// Populates a builtin fill dispatch packet and its kernargs in already-reserved
+// storage. The caller owns packet header commit, completion signal assignment,
+// and queue doorbell signaling.
 //
-// NOTE: this only works with blits today. SDMA will require a different
-// signature.
-iree_hsa_kernel_dispatch_packet_t* IREE_AMDGPU_RESTRICT
-iree_hal_amdgpu_device_buffer_fill_emplace_reserve(
+// Returns false if |pattern_length| is unsupported, the target pointer/length
+// alignment is incompatible with that pattern width, or |length| cannot be
+// represented by the dispatch packet grid dimensions. On failure,
+// |dispatch_packet| and |kernarg_ptr| are left unmodified.
+bool iree_hal_amdgpu_device_buffer_fill_emplace(
     const iree_hal_amdgpu_device_buffer_transfer_context_t* IREE_AMDGPU_RESTRICT
         context,
-    void* target_ptr, const uint64_t length, const uint64_t pattern,
-    const uint8_t pattern_length, uint64_t* IREE_AMDGPU_RESTRICT kernarg_ptr,
-    const uint64_t packet_id);
+    iree_hsa_kernel_dispatch_packet_t* IREE_AMDGPU_RESTRICT dispatch_packet,
+    void* target_ptr, uint64_t length, uint64_t pattern, uint8_t pattern_length,
+    void* IREE_AMDGPU_RESTRICT kernarg_ptr);
 
-// Enqueues a fill dispatch packet in the target queue.
-// The packet will be acquired at the current write_index and the queue doorbell
-// will be signaled.
-void iree_hal_amdgpu_device_buffer_fill_enqueue(
-    const iree_hal_amdgpu_device_buffer_transfer_context_t* IREE_AMDGPU_RESTRICT
-        context,
-    void* target_ptr, const uint64_t length, const uint64_t pattern,
-    const uint8_t pattern_length, uint64_t* IREE_AMDGPU_RESTRICT kernarg_ptr);
-
-// Emplaces a copy dispatch packet in the target queue at the given index.
-// The queue doorbell will not be signaled.
+// Populates a builtin copy dispatch packet and its kernargs in already-reserved
+// storage. The caller owns packet header commit, completion signal assignment,
+// and queue doorbell signaling.
 //
-// NOTE: this only works with blits today. SDMA will require a different
-// signature.
-iree_hsa_kernel_dispatch_packet_t* IREE_AMDGPU_RESTRICT
-iree_hal_amdgpu_device_buffer_copy_emplace_reserve(
+// Returns false if |length| cannot be represented by the dispatch packet grid
+// dimensions. On failure, |dispatch_packet| and |kernarg_ptr| are left
+// unmodified.
+bool iree_hal_amdgpu_device_buffer_copy_emplace(
     const iree_hal_amdgpu_device_buffer_transfer_context_t* IREE_AMDGPU_RESTRICT
         context,
-    const void* source_ptr, void* target_ptr, const uint64_t length,
-    uint64_t* IREE_AMDGPU_RESTRICT kernarg_ptr, const uint64_t packet_id);
+    iree_hsa_kernel_dispatch_packet_t* IREE_AMDGPU_RESTRICT dispatch_packet,
+    const void* source_ptr, void* target_ptr, uint64_t length,
+    void* IREE_AMDGPU_RESTRICT kernarg_ptr);
 
-// Enqueues a copy dispatch packet in the target queue.
-// The packet will be acquired at the current write_index and the queue doorbell
-// will be signaled.
-void iree_hal_amdgpu_device_buffer_copy_enqueue(
-    const iree_hal_amdgpu_device_buffer_transfer_context_t* IREE_AMDGPU_RESTRICT
-        context,
-    const void* source_ptr, void* target_ptr, const uint64_t length,
-    uint64_t* IREE_AMDGPU_RESTRICT kernarg_ptr);
-
-#endif  // IREE_AMDGPU_TARGET_DEVICE
+#ifdef __cplusplus
+}  // extern "C"
+#endif  // __cplusplus
 
 #endif  // IREE_HAL_DRIVERS_AMDGPU_DEVICE_BLIT_H_
diff --git a/runtime/src/iree/hal/drivers/amdgpu/device/blit_test.cc b/runtime/src/iree/hal/drivers/amdgpu/device/blit_test.cc
new file mode 100644
index 0000000..1e1e391
--- /dev/null
+++ b/runtime/src/iree/hal/drivers/amdgpu/device/blit_test.cc
@@ -0,0 +1,568 @@
+// Copyright 2026 The IREE Authors
+//
+// Licensed under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+
+#include "iree/hal/drivers/amdgpu/device/blit.h"
+
+#include "iree/testing/gtest.h"
+
+namespace iree::hal::amdgpu {
+namespace {
+
+// Sentinels for kernel object pointers (we don't dispatch here and just need a
+// way to ensure the proper kernel is selected).
+constexpr uint64_t kFillX1KernelObject = 0xF11u;
+constexpr uint64_t kFillX2KernelObject = 0xF12u;
+constexpr uint64_t kFillX4KernelObject = 0xF14u;
+constexpr uint64_t kFillX8KernelObject = 0xF18u;
+constexpr uint64_t kFillBlockX16KernelObject = 0xF160u;
+constexpr uint64_t kFillBlockUnalignedX16KernelObject = 0xF161u;
+constexpr uint64_t kCopyX1KernelObject = 0xC11u;
+constexpr uint64_t kCopyBlockX4KernelObject = 0xC40u;
+constexpr uint64_t kCopyBlockX8KernelObject = 0xC80u;
+constexpr uint64_t kCopyBlockX16KernelObject = 0xC160u;
+constexpr uint64_t kCopyBlockUnalignedX16KernelObject = 0xC161u;
+
+static iree_hal_amdgpu_device_kernel_args_t MakeKernelArgs(
+    uint64_t kernel_object, uint16_t setup, uint16_t workgroup_size_x,
+    uint32_t private_segment_size, uint32_t group_segment_size) {
+  iree_hal_amdgpu_device_kernel_args_t kernel_args = {};
+  kernel_args.kernel_object = kernel_object;
+  kernel_args.kernarg_size = 24;
+  kernel_args.kernarg_alignment = 8;
+  kernel_args.setup = setup;
+  kernel_args.workgroup_size[0] = workgroup_size_x;
+  kernel_args.workgroup_size[1] = 1;
+  kernel_args.workgroup_size[2] = 1;
+  kernel_args.private_segment_size = private_segment_size;
+  kernel_args.group_segment_size = group_segment_size;
+  return kernel_args;
+}
+
+static iree_hal_amdgpu_device_kernels_t MakeKernels() {
+  iree_hal_amdgpu_device_kernels_t kernels = {};
+  kernels.iree_hal_amdgpu_device_buffer_fill_x1 =
+      MakeKernelArgs(kFillX1KernelObject, 1, 32, 4, 8);
+  kernels.iree_hal_amdgpu_device_buffer_fill_x2 =
+      MakeKernelArgs(kFillX2KernelObject, 2, 32, 5, 9);
+  kernels.iree_hal_amdgpu_device_buffer_fill_x4 =
+      MakeKernelArgs(kFillX4KernelObject, 3, 32, 6, 10);
+  kernels.iree_hal_amdgpu_device_buffer_fill_x8 =
+      MakeKernelArgs(kFillX8KernelObject, 4, 32, 7, 11);
+  kernels.iree_hal_amdgpu_device_buffer_fill_block_x16 =
+      MakeKernelArgs(kFillBlockX16KernelObject, 5, 32, 8, 12);
+  kernels.iree_hal_amdgpu_device_buffer_fill_block_unaligned_x16 =
+      MakeKernelArgs(kFillBlockUnalignedX16KernelObject, 6, 32, 9, 13);
+  kernels.iree_hal_amdgpu_device_buffer_copy_x1 =
+      MakeKernelArgs(kCopyX1KernelObject, 7, 32, 10, 14);
+  kernels.iree_hal_amdgpu_device_buffer_copy_block_x4 =
+      MakeKernelArgs(kCopyBlockX4KernelObject, 8, 32, 11, 15);
+  kernels.iree_hal_amdgpu_device_buffer_copy_block_x8 =
+      MakeKernelArgs(kCopyBlockX8KernelObject, 9, 32, 12, 16);
+  kernels.iree_hal_amdgpu_device_buffer_copy_block_x16 =
+      MakeKernelArgs(kCopyBlockX16KernelObject, 10, 32, 13, 17);
+  kernels.iree_hal_amdgpu_device_buffer_copy_block_unaligned_x16 =
+      MakeKernelArgs(kCopyBlockUnalignedX16KernelObject, 11, 32, 14, 18);
+  return kernels;
+}
+
+static iree_hal_amdgpu_device_buffer_transfer_context_t MakeContext(
+    const iree_hal_amdgpu_device_kernels_t* kernels,
+    uint32_t compute_unit_count = 4, uint32_t wavefront_size = 64) {
+  iree_hal_amdgpu_device_buffer_transfer_context_t context = {};
+  iree_hal_amdgpu_device_buffer_transfer_context_initialize(
+      kernels, compute_unit_count, wavefront_size, &context);
+  return context;
+}
+
+TEST(BlitTest, TransferContextInitializeUsesDeviceLaunchMetadata) {
+  const iree_hal_amdgpu_device_kernels_t kernels = MakeKernels();
+  const iree_hal_amdgpu_device_buffer_transfer_context_t context =
+      MakeContext(&kernels, /*compute_unit_count=*/7,
+                  /*wavefront_size=*/32);
+
+  EXPECT_EQ(context.kernels, &kernels);
+  EXPECT_EQ(context.wavefront_size, 32);
+  EXPECT_EQ(context.workgroup_size_x, 32);
+  EXPECT_EQ(context.max_workgroup_count, 28u);
+}
+
+TEST(BlitTest, TransferContextInitializeSaturatesLargeComputeUnitCount) {
+  const iree_hal_amdgpu_device_kernels_t kernels = MakeKernels();
+  const iree_hal_amdgpu_device_buffer_transfer_context_t context =
+      MakeContext(&kernels, /*compute_unit_count=*/UINT32_MAX,
+                  /*wavefront_size=*/64);
+
+  EXPECT_EQ(context.wavefront_size, 64);
+  EXPECT_EQ(context.workgroup_size_x, 64);
+  EXPECT_EQ(context.max_workgroup_count, UINT32_MAX);
+}
+
+TEST(BlitTest, FillEmplaceSelectsBlockFillForAlignedTransfer) {
+  const iree_hal_amdgpu_device_kernels_t kernels = MakeKernels();
+  const iree_hal_amdgpu_device_buffer_transfer_context_t context =
+      MakeContext(&kernels);
+  iree_hsa_kernel_dispatch_packet_t packet = {};
+  iree_hal_amdgpu_device_buffer_fill_kernargs_t kernargs = {};
+
+  ASSERT_TRUE(iree_hal_amdgpu_device_buffer_fill_emplace(
+      &context, &packet, (void*)0x2000, /*length=*/512,
+      /*pattern=*/0xABu,
+      /*pattern_length=*/1, &kernargs));
+
+  EXPECT_EQ(packet.setup, 5);
+  EXPECT_EQ(packet.workgroup_size[0], 64);
+  EXPECT_EQ(packet.workgroup_size[1], 1);
+  EXPECT_EQ(packet.workgroup_size[2], 1);
+  EXPECT_EQ(packet.grid_size[0], 8);
+  EXPECT_EQ(packet.grid_size[1], 1);
+  EXPECT_EQ(packet.grid_size[2], 1);
+  EXPECT_EQ(packet.private_segment_size, 8);
+  EXPECT_EQ(packet.group_segment_size, 12);
+  EXPECT_EQ(packet.kernel_object, kFillBlockX16KernelObject);
+  EXPECT_EQ(packet.kernarg_address, &kernargs);
+  EXPECT_EQ(packet.completion_signal.handle, iree_hsa_signal_null().handle);
+
+  EXPECT_EQ(kernargs.target_ptr, (void*)0x2000);
+  EXPECT_EQ(kernargs.element_length, 32u);
+  EXPECT_EQ(kernargs.pattern, 0xABABABABABABABABull);
+}
+
+TEST(BlitTest, FillEmplaceUsesNoopDispatchForZeroLengthTransfer) {
+  const iree_hal_amdgpu_device_kernels_t kernels = MakeKernels();
+  const iree_hal_amdgpu_device_buffer_transfer_context_t context =
+      MakeContext(&kernels);
+  iree_hsa_kernel_dispatch_packet_t packet = {};
+  iree_hal_amdgpu_device_buffer_fill_kernargs_t kernargs = {};
+
+  ASSERT_TRUE(iree_hal_amdgpu_device_buffer_fill_emplace(
+      &context, &packet, (void*)0x2000, /*length=*/0,
+      /*pattern=*/0xABu,
+      /*pattern_length=*/1, &kernargs));
+
+  EXPECT_EQ(packet.workgroup_size[0], 64);
+  EXPECT_EQ(packet.grid_size[0], 1);
+  EXPECT_EQ(packet.grid_size[1], 1);
+  EXPECT_EQ(packet.grid_size[2], 1);
+  EXPECT_EQ(packet.kernel_object, kFillBlockX16KernelObject);
+  EXPECT_EQ(kernargs.target_ptr, (void*)0x2000);
+  EXPECT_EQ(kernargs.element_length, 0u);
+}
+
+TEST(BlitTest, FillEmplaceSelectsScalarX1FillForUnalignedByteTransfer) {
+  const iree_hal_amdgpu_device_kernels_t kernels = MakeKernels();
+  const iree_hal_amdgpu_device_buffer_transfer_context_t context =
+      MakeContext(&kernels);
+  iree_hsa_kernel_dispatch_packet_t packet = {};
+  iree_hal_amdgpu_device_buffer_fill_kernargs_t kernargs = {};
+
+  ASSERT_TRUE(iree_hal_amdgpu_device_buffer_fill_emplace(
+      &context, &packet, (void*)0x2001, /*length=*/7,
+      /*pattern=*/0xABu,
+      /*pattern_length=*/1, &kernargs));
+
+  EXPECT_EQ(packet.setup, 1);
+  EXPECT_EQ(packet.grid_size[0], 7);
+  EXPECT_EQ(packet.grid_size[1], 1);
+  EXPECT_EQ(packet.kernel_object, kFillX1KernelObject);
+  EXPECT_EQ(kernargs.target_ptr, (void*)0x2001);
+  EXPECT_EQ(kernargs.element_length, 7u);
+  EXPECT_EQ(kernargs.pattern, 0xABu);
+}
+
+TEST(BlitTest, FillEmplaceSelectsUnalignedBlockFillAtVectorThreshold) {
+  const iree_hal_amdgpu_device_kernels_t kernels = MakeKernels();
+  const iree_hal_amdgpu_device_buffer_transfer_context_t context =
+      MakeContext(&kernels);
+  iree_hsa_kernel_dispatch_packet_t packet = {};
+  iree_hal_amdgpu_device_buffer_fill_kernargs_t kernargs = {};
+
+  ASSERT_TRUE(iree_hal_amdgpu_device_buffer_fill_emplace(
+      &context, &packet, (void*)0x2001, /*length=*/128,
+      /*pattern=*/0xABu,
+      /*pattern_length=*/1, &kernargs));
+
+  EXPECT_EQ(packet.setup, 6);
+  EXPECT_EQ(packet.grid_size[0], 2);
+  EXPECT_EQ(packet.grid_size[1], 1);
+  EXPECT_EQ(packet.kernel_object, kFillBlockUnalignedX16KernelObject);
+  EXPECT_EQ(kernargs.target_ptr, (void*)0x2001);
+  EXPECT_EQ(kernargs.element_length, 128u);
+  EXPECT_EQ(kernargs.pattern, 0xABABABABABABABABull);
+}
+
+TEST(BlitTest, FillEmplaceSelectsUnalignedBlockFillForLargeByteTransfer) {
+  const iree_hal_amdgpu_device_kernels_t kernels = MakeKernels();
+  const iree_hal_amdgpu_device_buffer_transfer_context_t context =
+      MakeContext(&kernels);
+  iree_hsa_kernel_dispatch_packet_t packet = {};
+  iree_hal_amdgpu_device_buffer_fill_kernargs_t kernargs = {};
+
+  ASSERT_TRUE(iree_hal_amdgpu_device_buffer_fill_emplace(
+      &context, &packet, (void*)0x2001, /*length=*/257,
+      /*pattern=*/0xABu,
+      /*pattern_length=*/1, &kernargs));
+
+  EXPECT_EQ(packet.setup, 6);
+  EXPECT_EQ(packet.grid_size[0], 4);
+  EXPECT_EQ(packet.grid_size[1], 1);
+  EXPECT_EQ(packet.kernel_object, kFillBlockUnalignedX16KernelObject);
+  EXPECT_EQ(kernargs.target_ptr, (void*)0x2001);
+  EXPECT_EQ(kernargs.element_length, 257u);
+  EXPECT_EQ(kernargs.pattern, 0xABABABABABABABABull);
+}
+
+TEST(BlitTest, FillEmplaceSelectsScalarX2FillForHalfwordTransfer) {
+  const iree_hal_amdgpu_device_kernels_t kernels = MakeKernels();
+  const iree_hal_amdgpu_device_buffer_transfer_context_t context =
+      MakeContext(&kernels);
+  iree_hsa_kernel_dispatch_packet_t packet = {};
+  iree_hal_amdgpu_device_buffer_fill_kernargs_t kernargs = {};
+
+  ASSERT_TRUE(iree_hal_amdgpu_device_buffer_fill_emplace(
+      &context, &packet, (void*)0x2002, /*length=*/10,
+      /*pattern=*/0xABCDu,
+      /*pattern_length=*/2, &kernargs));
+
+  EXPECT_EQ(packet.setup, 2);
+  EXPECT_EQ(packet.grid_size[0], 5);
+  EXPECT_EQ(packet.grid_size[1], 1);
+  EXPECT_EQ(packet.kernel_object, kFillX2KernelObject);
+  EXPECT_EQ(kernargs.target_ptr, (void*)0x2002);
+  EXPECT_EQ(kernargs.element_length, 5u);
+  EXPECT_EQ(kernargs.pattern, 0xABCDu);
+}
+
+TEST(BlitTest, FillEmplaceSelectsScalarX4FillForDwordTransfer) {
+  const iree_hal_amdgpu_device_kernels_t kernels = MakeKernels();
+  const iree_hal_amdgpu_device_buffer_transfer_context_t context =
+      MakeContext(&kernels);
+  iree_hsa_kernel_dispatch_packet_t packet = {};
+  iree_hal_amdgpu_device_buffer_fill_kernargs_t kernargs = {};
+
+  ASSERT_TRUE(iree_hal_amdgpu_device_buffer_fill_emplace(
+      &context, &packet, (void*)0x2004, /*length=*/20,
+      /*pattern=*/0xABCDEF01u,
+      /*pattern_length=*/4, &kernargs));
+
+  EXPECT_EQ(packet.setup, 3);
+  EXPECT_EQ(packet.grid_size[0], 5);
+  EXPECT_EQ(packet.grid_size[1], 1);
+  EXPECT_EQ(packet.kernel_object, kFillX4KernelObject);
+  EXPECT_EQ(kernargs.target_ptr, (void*)0x2004);
+  EXPECT_EQ(kernargs.element_length, 5u);
+  EXPECT_EQ(kernargs.pattern, 0xABCDEF01ABCDEF01ull);
+}
+
+TEST(BlitTest, FillEmplaceUsesDwordFillForSmallAlignedPatterns) {
+  const iree_hal_amdgpu_device_kernels_t kernels = MakeKernels();
+  const iree_hal_amdgpu_device_buffer_transfer_context_t context =
+      MakeContext(&kernels);
+  iree_hsa_kernel_dispatch_packet_t packet = {};
+  iree_hal_amdgpu_device_buffer_fill_kernargs_t kernargs = {};
+
+  ASSERT_TRUE(iree_hal_amdgpu_device_buffer_fill_emplace(
+      &context, &packet, (void*)0x2000, /*length=*/4,
+      /*pattern=*/0xABu,
+      /*pattern_length=*/1, &kernargs));
+
+  EXPECT_EQ(packet.setup, 3);
+  EXPECT_EQ(packet.grid_size[0], 1);
+  EXPECT_EQ(packet.grid_size[1], 1);
+  EXPECT_EQ(packet.kernel_object, kFillX4KernelObject);
+  EXPECT_EQ(kernargs.target_ptr, (void*)0x2000);
+  EXPECT_EQ(kernargs.element_length, 1u);
+  EXPECT_EQ(kernargs.pattern, 0xABABABABABABABABull);
+}
+
+TEST(BlitTest, FillEmplaceSelectsScalarX8FillForQwordTransfer) {
+  const iree_hal_amdgpu_device_kernels_t kernels = MakeKernels();
+  const iree_hal_amdgpu_device_buffer_transfer_context_t context =
+      MakeContext(&kernels);
+  iree_hsa_kernel_dispatch_packet_t packet = {};
+  iree_hal_amdgpu_device_buffer_fill_kernargs_t kernargs = {};
+
+  ASSERT_TRUE(iree_hal_amdgpu_device_buffer_fill_emplace(
+      &context, &packet, (void*)0x2008, /*length=*/24,
+      /*pattern=*/0x0123456789ABCDEFull,
+      /*pattern_length=*/8, &kernargs));
+
+  EXPECT_EQ(packet.setup, 4);
+  EXPECT_EQ(packet.grid_size[0], 3);
+  EXPECT_EQ(packet.grid_size[1], 1);
+  EXPECT_EQ(packet.kernel_object, kFillX8KernelObject);
+  EXPECT_EQ(kernargs.target_ptr, (void*)0x2008);
+  EXPECT_EQ(kernargs.element_length, 3u);
+  EXPECT_EQ(kernargs.pattern, 0x0123456789ABCDEFull);
+}
+
+TEST(BlitTest, FillEmplaceMasksPatternToDeclaredWidth) {
+  const iree_hal_amdgpu_device_kernels_t kernels = MakeKernels();
+  const iree_hal_amdgpu_device_buffer_transfer_context_t context =
+      MakeContext(&kernels);
+  iree_hsa_kernel_dispatch_packet_t packet = {};
+  iree_hal_amdgpu_device_buffer_fill_kernargs_t kernargs = {};
+
+  ASSERT_TRUE(iree_hal_amdgpu_device_buffer_fill_emplace(
+      &context, &packet, (void*)0x2000, /*length=*/512,
+      /*pattern=*/0x1ABu,
+      /*pattern_length=*/1, &kernargs));
+
+  EXPECT_EQ(packet.kernel_object, kFillBlockX16KernelObject);
+  EXPECT_EQ(kernargs.pattern, 0xABABABABABABABABull);
+}
+
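+// Failure paths must leave both the packet and the kernargs untouched: the
+// tests seed them with sentinel values so any partial write by a rejected
+// emplace shows up as a mutation.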
+TEST(BlitTest, FillEmplaceRejectsUnsupportedPatternLengthWithoutMutation) {
+  const iree_hal_amdgpu_device_kernels_t kernels = MakeKernels();
+  const iree_hal_amdgpu_device_buffer_transfer_context_t context =
+      MakeContext(&kernels);
+  iree_hsa_kernel_dispatch_packet_t packet = {};
+  packet.setup = 0x55AAu;
+  packet.kernel_object = 0xDEADCAFEu;
+  iree_hal_amdgpu_device_buffer_fill_kernargs_t kernargs = {
+      /*.target_ptr=*/(void*)0x1234,
+      /*.element_length=*/7,
+      /*.pattern=*/0x99,
+  };
+
+  EXPECT_FALSE(iree_hal_amdgpu_device_buffer_fill_emplace(
+      &context, &packet, (void*)0x2000, /*length=*/12,
+      /*pattern=*/0xABCDu,
+      /*pattern_length=*/3, &kernargs));
+
+  EXPECT_EQ(packet.setup, 0x55AAu);
+  EXPECT_EQ(packet.kernel_object, 0xDEADCAFEu);
+  EXPECT_EQ(kernargs.target_ptr, (void*)0x1234);
+  EXPECT_EQ(kernargs.element_length, 7u);
+  EXPECT_EQ(kernargs.pattern, 0x99u);
+}
+
+TEST(BlitTest, FillEmplaceRejectsMisalignedPatternTransferWithoutMutation) {
+  const iree_hal_amdgpu_device_kernels_t kernels = MakeKernels();
+  const iree_hal_amdgpu_device_buffer_transfer_context_t context =
+      MakeContext(&kernels);
+  iree_hsa_kernel_dispatch_packet_t packet = {};
+  packet.setup = 0x55AAu;
+  packet.kernel_object = 0xDEADCAFEu;
+  iree_hal_amdgpu_device_buffer_fill_kernargs_t kernargs = {
+      /*.target_ptr=*/(void*)0x1234,
+      /*.element_length=*/7,
+      /*.pattern=*/0x99,
+  };
+
+  EXPECT_FALSE(iree_hal_amdgpu_device_buffer_fill_emplace(
+      &context, &packet, (void*)0x2002, /*length=*/16,
+      /*pattern=*/0xABCDu,
+      /*pattern_length=*/4, &kernargs));
+
+  EXPECT_EQ(packet.setup, 0x55AAu);
+  EXPECT_EQ(packet.kernel_object, 0xDEADCAFEu);
+  EXPECT_EQ(kernargs.target_ptr, (void*)0x1234);
+  EXPECT_EQ(kernargs.element_length, 7u);
+  EXPECT_EQ(kernargs.pattern, 0x99u);
+}
+
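+// The copy tests mirror the fill selection: the common alignment of source,
+// target, and length picks the widest block copy, with the unaligned x16 block
+// copy taking over once an unaligned transfer reaches 128 bytes and a plain
+// byte copy handling anything smaller.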
+TEST(BlitTest, CopyEmplaceSelectsBlockCopyForAlignedTransfer) {
+  const iree_hal_amdgpu_device_kernels_t kernels = MakeKernels();
+  const iree_hal_amdgpu_device_buffer_transfer_context_t context =
+      MakeContext(&kernels);
+  iree_hsa_kernel_dispatch_packet_t packet = {};
+  iree_hal_amdgpu_device_buffer_copy_kernargs_t kernargs = {};
+
+  ASSERT_TRUE(iree_hal_amdgpu_device_buffer_copy_emplace(
+      &context, &packet, (const void*)0x4000, (void*)0x8000,
+      /*length=*/256, &kernargs));
+
+  EXPECT_EQ(packet.setup, 10);
+  EXPECT_EQ(packet.workgroup_size[0], 64);
+  EXPECT_EQ(packet.workgroup_size[1], 1);
+  EXPECT_EQ(packet.workgroup_size[2], 1);
+  EXPECT_EQ(packet.grid_size[0], 16);
+  EXPECT_EQ(packet.grid_size[1], 1);
+  EXPECT_EQ(packet.grid_size[2], 1);
+  EXPECT_EQ(packet.private_segment_size, 13);
+  EXPECT_EQ(packet.group_segment_size, 17);
+  EXPECT_EQ(packet.kernel_object, kCopyBlockX16KernelObject);
+  EXPECT_EQ(packet.kernarg_address, &kernargs);
+  EXPECT_EQ(packet.completion_signal.handle, iree_hsa_signal_null().handle);
+
+  EXPECT_EQ(kernargs.source_ptr, (const void*)0x4000);
+  EXPECT_EQ(kernargs.target_ptr, (void*)0x8000);
+  EXPECT_EQ(kernargs.element_length, 16u);
+}
+
+TEST(BlitTest, CopyEmplaceSelectsBlockX8ForQwordAlignedTransfer) {
+  const iree_hal_amdgpu_device_kernels_t kernels = MakeKernels();
+  const iree_hal_amdgpu_device_buffer_transfer_context_t context =
+      MakeContext(&kernels);
+  iree_hsa_kernel_dispatch_packet_t packet = {};
+  iree_hal_amdgpu_device_buffer_copy_kernargs_t kernargs = {};
+
+  ASSERT_TRUE(iree_hal_amdgpu_device_buffer_copy_emplace(
+      &context, &packet, (const void*)0x4008, (void*)0x8008,
+      /*length=*/72, &kernargs));
+
+  EXPECT_EQ(packet.setup, 9);
+  EXPECT_EQ(packet.grid_size[0], 2);
+  EXPECT_EQ(packet.grid_size[1], 1);
+  EXPECT_EQ(packet.grid_size[2], 1);
+  EXPECT_EQ(packet.private_segment_size, 12);
+  EXPECT_EQ(packet.group_segment_size, 16);
+  EXPECT_EQ(packet.kernel_object, kCopyBlockX8KernelObject);
+  EXPECT_EQ(kernargs.source_ptr, (const void*)0x4008);
+  EXPECT_EQ(kernargs.target_ptr, (void*)0x8008);
+  EXPECT_EQ(kernargs.element_length, 9u);
+}
+
+TEST(BlitTest, CopyEmplaceSelectsBlockX4ForDwordAlignedTransfer) {
+  const iree_hal_amdgpu_device_kernels_t kernels = MakeKernels();
+  const iree_hal_amdgpu_device_buffer_transfer_context_t context =
+      MakeContext(&kernels);
+  iree_hsa_kernel_dispatch_packet_t packet = {};
+  iree_hal_amdgpu_device_buffer_copy_kernargs_t kernargs = {};
+
+  ASSERT_TRUE(iree_hal_amdgpu_device_buffer_copy_emplace(
+      &context, &packet, (const void*)0x4004, (void*)0x8004,
+      /*length=*/68, &kernargs));
+
+  EXPECT_EQ(packet.setup, 8);
+  EXPECT_EQ(packet.grid_size[0], 2);
+  EXPECT_EQ(packet.grid_size[1], 1);
+  EXPECT_EQ(packet.grid_size[2], 1);
+  EXPECT_EQ(packet.private_segment_size, 11);
+  EXPECT_EQ(packet.group_segment_size, 15);
+  EXPECT_EQ(packet.kernel_object, kCopyBlockX4KernelObject);
+  EXPECT_EQ(kernargs.source_ptr, (const void*)0x4004);
+  EXPECT_EQ(kernargs.target_ptr, (void*)0x8004);
+  EXPECT_EQ(kernargs.element_length, 17u);
+}
+
+TEST(BlitTest, CopyEmplaceFallsBackToByteCopyForUnalignedTransfer) {
+  const iree_hal_amdgpu_device_kernels_t kernels = MakeKernels();
+  const iree_hal_amdgpu_device_buffer_transfer_context_t context =
+      MakeContext(&kernels);
+  iree_hsa_kernel_dispatch_packet_t packet = {};
+  iree_hal_amdgpu_device_buffer_copy_kernargs_t kernargs = {};
+
+  ASSERT_TRUE(iree_hal_amdgpu_device_buffer_copy_emplace(
+      &context, &packet, (const void*)0x4001, (void*)0x8002,
+      /*length=*/17, &kernargs));
+
+  EXPECT_EQ(packet.setup, 7);
+  EXPECT_EQ(packet.grid_size[0], 17);
+  EXPECT_EQ(packet.grid_size[1], 1);
+  EXPECT_EQ(packet.grid_size[2], 1);
+  EXPECT_EQ(packet.kernel_object, kCopyX1KernelObject);
+  EXPECT_EQ(kernargs.source_ptr, (const void*)0x4001);
+  EXPECT_EQ(kernargs.target_ptr, (void*)0x8002);
+  EXPECT_EQ(kernargs.element_length, 17u);
+}
+
+TEST(BlitTest, CopyEmplaceSelectsUnalignedBlockCopyAtVectorThreshold) {
+  const iree_hal_amdgpu_device_kernels_t kernels = MakeKernels();
+  const iree_hal_amdgpu_device_buffer_transfer_context_t context =
+      MakeContext(&kernels);
+  iree_hsa_kernel_dispatch_packet_t packet = {};
+  iree_hal_amdgpu_device_buffer_copy_kernargs_t kernargs = {};
+
+  ASSERT_TRUE(iree_hal_amdgpu_device_buffer_copy_emplace(
+      &context, &packet, (const void*)0x4001, (void*)0x8002,
+      /*length=*/128, &kernargs));
+
+  EXPECT_EQ(packet.setup, 11);
+  EXPECT_EQ(packet.grid_size[0], 8);
+  EXPECT_EQ(packet.grid_size[1], 1);
+  EXPECT_EQ(packet.grid_size[2], 1);
+  EXPECT_EQ(packet.kernel_object, kCopyBlockUnalignedX16KernelObject);
+  EXPECT_EQ(kernargs.source_ptr, (const void*)0x4001);
+  EXPECT_EQ(kernargs.target_ptr, (void*)0x8002);
+  EXPECT_EQ(kernargs.element_length, 128u);
+}
+
+TEST(BlitTest, CopyEmplaceSelectsUnalignedBlockCopyForLargeByteTransfer) {
+  const iree_hal_amdgpu_device_kernels_t kernels = MakeKernels();
+  const iree_hal_amdgpu_device_buffer_transfer_context_t context =
+      MakeContext(&kernels);
+  iree_hsa_kernel_dispatch_packet_t packet = {};
+  iree_hal_amdgpu_device_buffer_copy_kernargs_t kernargs = {};
+
+  ASSERT_TRUE(iree_hal_amdgpu_device_buffer_copy_emplace(
+      &context, &packet, (const void*)0x4001, (void*)0x8002,
+      /*length=*/257, &kernargs));
+
+  EXPECT_EQ(packet.setup, 11);
+  EXPECT_EQ(packet.grid_size[0], 16);
+  EXPECT_EQ(packet.grid_size[1], 1);
+  EXPECT_EQ(packet.grid_size[2], 1);
+  EXPECT_EQ(packet.private_segment_size, 14);
+  EXPECT_EQ(packet.group_segment_size, 18);
+  EXPECT_EQ(packet.kernel_object, kCopyBlockUnalignedX16KernelObject);
+  EXPECT_EQ(kernargs.source_ptr, (const void*)0x4001);
+  EXPECT_EQ(kernargs.target_ptr, (void*)0x8002);
+  EXPECT_EQ(kernargs.element_length, 257u);
+}
+
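+// Grid sizing for unbounded transfers: the grid is capped to the resident work
+// limit (max_workgroup_count workgroups of wavefront_size lanes, each
+// presumably striding over the remainder) and spills into the Y dimension once
+// even that exceeds the 32-bit X extent.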
+TEST(BlitTest, CopyEmplaceCapsWave32TransferGridToResidentWork) {
+  const iree_hal_amdgpu_device_kernels_t kernels = MakeKernels();
+  const iree_hal_amdgpu_device_buffer_transfer_context_t context =
+      MakeContext(&kernels, /*compute_unit_count=*/2,
+                  /*wavefront_size=*/32);
+  iree_hsa_kernel_dispatch_packet_t packet = {};
+  iree_hal_amdgpu_device_buffer_copy_kernargs_t kernargs = {};
+
+  ASSERT_TRUE(iree_hal_amdgpu_device_buffer_copy_emplace(
+      &context, &packet, (const void*)0x4001, (void*)0x8002,
+      /*length=*/UINT64_MAX, &kernargs));
+
+  EXPECT_EQ(packet.workgroup_size[0], 32);
+  EXPECT_EQ(packet.grid_size[0], 256);
+  EXPECT_EQ(packet.grid_size[1], 1);
+  EXPECT_EQ(packet.kernel_object, kCopyBlockUnalignedX16KernelObject);
+  EXPECT_EQ(kernargs.element_length, UINT64_MAX);
+}
+
+TEST(BlitTest, CopyEmplaceCapsLargeTransferGridToResidentWork) {
+  const iree_hal_amdgpu_device_kernels_t kernels = MakeKernels();
+  const iree_hal_amdgpu_device_buffer_transfer_context_t context =
+      MakeContext(&kernels, /*compute_unit_count=*/2,
+                  /*wavefront_size=*/64);
+  iree_hsa_kernel_dispatch_packet_t packet = {};
+  iree_hal_amdgpu_device_buffer_copy_kernargs_t kernargs = {};
+
+  ASSERT_TRUE(iree_hal_amdgpu_device_buffer_copy_emplace(
+      &context, &packet, (const void*)0x4001, (void*)0x8002,
+      /*length=*/UINT64_MAX, &kernargs));
+
+  EXPECT_EQ(packet.workgroup_size[0], 64);
+  EXPECT_EQ(packet.grid_size[0], 512);
+  EXPECT_EQ(packet.grid_size[1], 1);
+  EXPECT_EQ(packet.kernel_object, kCopyBlockUnalignedX16KernelObject);
+  EXPECT_EQ(kernargs.source_ptr, (const void*)0x4001);
+  EXPECT_EQ(kernargs.target_ptr, (void*)0x8002);
+  EXPECT_EQ(kernargs.element_length, UINT64_MAX);
+}
+
+TEST(BlitTest, CopyEmplaceUsesTwoDimensionalGridWhenResidentWorkExceedsXDim) {
+  const iree_hal_amdgpu_device_kernels_t kernels = MakeKernels();
+  const iree_hal_amdgpu_device_buffer_transfer_context_t context =
+      MakeContext(&kernels, /*compute_unit_count=*/UINT32_MAX,
+                  /*wavefront_size=*/64);
+  iree_hsa_kernel_dispatch_packet_t packet = {};
+  iree_hal_amdgpu_device_buffer_copy_kernargs_t kernargs = {};
+
+  ASSERT_TRUE(iree_hal_amdgpu_device_buffer_copy_emplace(
+      &context, &packet, (const void*)0x4001, (void*)0x8002,
+      /*length=*/UINT64_MAX, &kernargs));
+
+  EXPECT_EQ(packet.workgroup_size[0], 64);
+  EXPECT_EQ(packet.grid_size[0], 64);
+  EXPECT_EQ(packet.grid_size[1], UINT32_MAX);
+  EXPECT_EQ(packet.grid_size[2], 1);
+  EXPECT_EQ(packet.kernel_object, kCopyBlockUnalignedX16KernelObject);
+  EXPECT_EQ(kernargs.element_length, UINT64_MAX);
+}
+
+}  // namespace
+}  // namespace iree::hal::amdgpu
diff --git a/runtime/src/iree/hal/drivers/amdgpu/device/buffer.c b/runtime/src/iree/hal/drivers/amdgpu/device/buffer.c
deleted file mode 100644
index c0c21d7..0000000
--- a/runtime/src/iree/hal/drivers/amdgpu/device/buffer.c
+++ /dev/null
@@ -1,88 +0,0 @@
-// Copyright 2025 The IREE Authors
-//
-// Licensed under the Apache License v2.0 with LLVM Exceptions.
-// See https://llvm.org/LICENSE.txt for license information.
-// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
-
-#include "iree/hal/drivers/amdgpu/device/buffer.h"
-
-//===----------------------------------------------------------------------===//
-// iree_hal_amdgpu_device_buffer_ref_t
-//===----------------------------------------------------------------------===//
-
-// TODO(benvanik): simplify this for command buffers by pre-baking as much as we
-// can during the queue issue - we can at least dereference handles and add in
-// the offset for everything such that we only have to deal with the slot offset
-// and have less branchy code.
-void* iree_hal_amdgpu_device_buffer_ref_resolve(
-    iree_hal_amdgpu_device_buffer_ref_t buffer_ref,
-    IREE_AMDGPU_ALIGNAS(64)
-        const iree_hal_amdgpu_device_buffer_ref_t* IREE_AMDGPU_RESTRICT
-            binding_table) {
-  if (buffer_ref.type == IREE_HAL_AMDGPU_DEVICE_BUFFER_TYPE_SLOT) {
-    const iree_hal_amdgpu_device_buffer_ref_t binding =
-        binding_table[buffer_ref.value.slot];
-    const uint64_t offset = buffer_ref.offset + binding.offset;
-    const uint64_t length = buffer_ref.length == UINT64_MAX
-                                ? binding.length - buffer_ref.offset
-                                : buffer_ref.length;
-    buffer_ref = (iree_hal_amdgpu_device_buffer_ref_t){
-        .type = binding.type,
-        .offset = offset,
-        .length = length,
-        .value.bits = binding.value.bits,
-    };
-  }
-  if (buffer_ref.type == IREE_HAL_AMDGPU_DEVICE_BUFFER_TYPE_HANDLE) {
-    buffer_ref.value.ptr = buffer_ref.value.handle->ptr;
-  }
-  return buffer_ref.value.ptr
-             ? (uint8_t*)buffer_ref.value.ptr + buffer_ref.offset
-             : NULL;
-}
-
-void* iree_hal_amdgpu_device_workgroup_count_buffer_ref_resolve(
-    iree_hal_amdgpu_device_workgroup_count_buffer_ref_t buffer_ref,
-    IREE_AMDGPU_ALIGNAS(64)
-        const iree_hal_amdgpu_device_buffer_ref_t* IREE_AMDGPU_RESTRICT
-            binding_table) {
-  if (buffer_ref.type == IREE_HAL_AMDGPU_DEVICE_BUFFER_TYPE_SLOT) {
-    const iree_hal_amdgpu_device_buffer_ref_t binding =
-        binding_table[buffer_ref.value.slot];
-    const uint64_t offset = buffer_ref.offset + binding.offset;
-    buffer_ref = (iree_hal_amdgpu_device_workgroup_count_buffer_ref_t){
-        .type = binding.type,
-        .offset = offset,
-        .value.bits = binding.value.bits,
-    };
-  }
-  if (buffer_ref.type == IREE_HAL_AMDGPU_DEVICE_BUFFER_TYPE_HANDLE) {
-    buffer_ref.value.ptr = buffer_ref.value.handle->ptr;
-  }
-  return buffer_ref.value.ptr
-             ? (uint8_t*)buffer_ref.value.ptr + buffer_ref.offset
-             : NULL;
-}
-
-void* iree_hal_amdgpu_device_uint64_buffer_ref_resolve(
-    iree_hal_amdgpu_device_uint64_buffer_ref_t buffer_ref,
-    IREE_AMDGPU_ALIGNAS(64)
-        const iree_hal_amdgpu_device_buffer_ref_t* IREE_AMDGPU_RESTRICT
-            binding_table) {
-  if (buffer_ref.type == IREE_HAL_AMDGPU_DEVICE_BUFFER_TYPE_SLOT) {
-    const iree_hal_amdgpu_device_buffer_ref_t binding =
-        binding_table[buffer_ref.value.slot];
-    const uint64_t offset = buffer_ref.offset + binding.offset;
-    buffer_ref = (iree_hal_amdgpu_device_uint64_buffer_ref_t){
-        .type = binding.type,
-        .offset = offset,
-        .value.bits = binding.value.bits,
-    };
-  }
-  if (buffer_ref.type == IREE_HAL_AMDGPU_DEVICE_BUFFER_TYPE_HANDLE) {
-    buffer_ref.value.ptr = buffer_ref.value.handle->ptr;
-  }
-  return buffer_ref.value.ptr
-             ? (uint8_t*)buffer_ref.value.ptr + buffer_ref.offset
-             : NULL;
-}
diff --git a/runtime/src/iree/hal/drivers/amdgpu/device/buffer.h b/runtime/src/iree/hal/drivers/amdgpu/device/buffer.h
deleted file mode 100644
index a9e596e..0000000
--- a/runtime/src/iree/hal/drivers/amdgpu/device/buffer.h
+++ /dev/null
@@ -1,198 +0,0 @@
-// Copyright 2025 The IREE Authors
-//
-// Licensed under the Apache License v2.0 with LLVM Exceptions.
-// See https://llvm.org/LICENSE.txt for license information.
-// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
-
-#ifndef IREE_HAL_DRIVERS_AMDGPU_DEVICE_BUFFER_H_
-#define IREE_HAL_DRIVERS_AMDGPU_DEVICE_BUFFER_H_
-
-#include "iree/hal/drivers/amdgpu/device/support/common.h"
-
-typedef struct iree_hal_amdgpu_device_allocator_pool_t
-    iree_hal_amdgpu_device_allocator_pool_t;
-
-//===----------------------------------------------------------------------===//
-// iree_hal_amdgpu_device_allocation_handle_t
-//===----------------------------------------------------------------------===//
-
-// Fat allocation pool identifier used to allow both the host and the device to
-// route to their respective pool implementations without lookups.
-typedef struct iree_hal_amdgpu_device_allocation_pool_id_t {
-  // Device-side pool in the memory space of the device that owns the
-  // allocation. Note that this may not be the local device.
-  iree_hal_amdgpu_device_allocator_pool_t* device_pool;
-  // Opaque host-side pool token.
-  uint64_t host_pool;
-} iree_hal_amdgpu_device_allocation_pool_id_t;
-
-// A handle for a dynamically device-allocated pointer.
-// The owner of the handle is responsible for storing it in device-visible
-// memory and consistently passing it in buffer references with the
-// IREE_HAL_AMDGPU_DEVICE_BUFFER_TYPE_HANDLE type. The device will dereference
-// the handle to get the actual pointer before using it. Device-side allocs and
-// frees will update the pointer in queue-order. The handle contents are only
-// valid on the device between an alloca/dealloca pair and we assume the client
-// code is not going to do something invalid (free and then try to use the
-// handle).
-//
-// Though the on-device allocator is usually responsible for manipulating the
-// handle, there are cases where the host or a remote device may need to. For
-// example if the user has the last iree_hal_buffer_t reference and drops it
-// we'll need to enqueue a device-side deallocation to handle the cleanup. To
-// avoid extra round-trips we also optimize for host-side pool growth by
-// allowing the host to initialize the handle after it has grown a pool without
-// needing to requeue the device allocation.
-typedef struct iree_hal_amdgpu_device_allocation_handle_t {
-  // Allocated pointer, if any assigned.
-  void* ptr;
-  // Pool identifier the pointer resides in.
-  iree_hal_amdgpu_device_allocation_pool_id_t pool_id;
-  // Opaque data used by the allocator.
-  struct {
-    // TODO(benvanik): block the allocation resides in and other information
-    // the allocator needs to avoid lookups when deallocating.
-    int reserved;
-  } metadata;
-} iree_hal_amdgpu_device_allocation_handle_t;
-
-//===----------------------------------------------------------------------===//
-// iree_hal_amdgpu_device_buffer_ref_t
-//===----------------------------------------------------------------------===//
-
-// Identifies the type of a buffer reference and how it should be resolved.
-typedef uint8_t iree_hal_amdgpu_device_buffer_type_t;
-enum iree_hal_amdgpu_device_buffer_type_e {
-  // Reference is to an absolute device pointer that can be directly accessed.
-  IREE_HAL_AMDGPU_DEVICE_BUFFER_TYPE_PTR = 0u,
-  // Reference is to a queue-ordered allocation handle that is only valid at
-  // the time the buffer is committed. The handle will be valid for the lifetime
-  // of the logical buffer and any resources referencing it but the pointer must
-  // only be resolved between a corresponding alloca/dealloca.
-  IREE_HAL_AMDGPU_DEVICE_BUFFER_TYPE_HANDLE,
-  // Reference is to a slot in the binding table provided during execution.
-  // Only one indirection is allowed (table slots cannot reference other slots
-  // - yet).
-  IREE_HAL_AMDGPU_DEVICE_BUFFER_TYPE_SLOT,
-};
-
-// The ordinal of a slot in the binding table.
-typedef uint32_t iree_hal_amdgpu_device_buffer_ordinal_t;
-
-// Describes a subrange of a buffer that can be bound to a binding slot.
-typedef struct iree_hal_amdgpu_device_buffer_ref_t {
-  // Offset, in bytes, into the buffer that the binding starts at.
-  // This will be added to the offset specified on each usage of the slot.
-  uint64_t offset;
-  // Type of the buffer reference used to resolve the device pointer.
-  uint64_t type : 2;
-  // Length, in bytes, of the buffer that is available to the executable.
-  uint64_t length : 62;
-  union {
-    // IREE_HAL_AMDGPU_DEVICE_BUFFER_TYPE_PTR: device pointer.
-    void* ptr;
-    // IREE_HAL_AMDGPU_DEVICE_BUFFER_TYPE_HANDLE: queue-ordered allocation
-    // handle.
-    iree_hal_amdgpu_device_allocation_handle_t* handle;
-    // IREE_HAL_AMDGPU_DEVICE_BUFFER_TYPE_SLOT: binding table slot.
-    iree_hal_amdgpu_device_buffer_ordinal_t slot;
-    // Used for setting the value.
-    uint64_t bits;
-  } value;
-} iree_hal_amdgpu_device_buffer_ref_t;
-static_assert(sizeof(iree_hal_amdgpu_device_buffer_ref_t) == 24,
-              "binding table entries should be 8 byte aligned");
-
-// Describes a buffer binding that contains a uint32_t[3] XYZ workgroup count.
-// This is a size-optimized version of iree_hal_amdgpu_device_buffer_ref_t so
-// that it will fit in our tiny packets. We know the length is a constant 12 and
-// only need the offset, type, and value.
-typedef struct iree_hal_amdgpu_device_workgroup_count_buffer_ref_t {
-  // Type of the buffer reference used to resolve the device pointer.
-  uint64_t type : 2;  // iree_hal_amdgpu_device_buffer_type_t
-  // Offset, in bytes, into the buffer that the binding starts at.
-  // This will be added to the offset specified on each usage of the slot.
-  uint64_t offset : 62;
-  union {
-    // IREE_HAL_AMDGPU_DEVICE_BUFFER_TYPE_PTR: raw device pointer.
-    void* ptr;
-    // IREE_HAL_AMDGPU_DEVICE_BUFFER_TYPE_HANDLE: queue-ordered allocation
-    // handle.
-    iree_hal_amdgpu_device_allocation_handle_t* handle;
-    // IREE_HAL_AMDGPU_DEVICE_BUFFER_TYPE_SLOT: binding table slot.
-    iree_hal_amdgpu_device_buffer_ordinal_t slot;
-    // Used for setting the value.
-    uint64_t bits;
-  } value;
-} iree_hal_amdgpu_device_workgroup_count_buffer_ref_t;
-static_assert(sizeof(iree_hal_amdgpu_device_workgroup_count_buffer_ref_t) == 16,
-              "binding table entries should be 8 byte aligned and tiny");
-
-#define iree_hal_amdgpu_device_workgroup_count_buffer_ref_length(buffer_ref) \
-  (sizeof(uint32_t) * 3)
-
-// Describes a buffer binding that contains a single uint64_t value.
-// This is a size-optimized version of iree_hal_amdgpu_device_buffer_ref_t so
-// that it will fit in our tiny packets. We know the length is a constant 8 and
-// only need the offset, type, and value.
-typedef struct iree_hal_amdgpu_device_uint64_buffer_ref_t {
-  // Type of the buffer reference used to resolve the device pointer.
-  uint64_t type : 2;  // iree_hal_amdgpu_device_buffer_type_t
-  // Offset, in bytes, into the buffer that the binding starts at.
-  // This will be added to the offset specified on each usage of the slot.
-  uint64_t offset : 62;
-  union {
-    // IREE_HAL_AMDGPU_DEVICE_BUFFER_TYPE_PTR: raw device pointer.
-    void* ptr;
-    // IREE_HAL_AMDGPU_DEVICE_BUFFER_TYPE_HANDLE: queue-ordered allocation
-    // handle.
-    iree_hal_amdgpu_device_allocation_handle_t* handle;
-    // IREE_HAL_AMDGPU_DEVICE_BUFFER_TYPE_SLOT: binding table slot.
-    iree_hal_amdgpu_device_buffer_ordinal_t slot;
-    // Used for setting the value.
-    uint64_t bits;
-  } value;
-} iree_hal_amdgpu_device_uint64_buffer_ref_t;
-static_assert(sizeof(iree_hal_amdgpu_device_uint64_buffer_ref_t) == 16,
-              "binding table entries should be 8 byte aligned and tiny");
-
-#define iree_hal_amdgpu_device_uint64_buffer_ref_length(buffer_ref) \
-  sizeof(uint64_t)
-
-#if defined(IREE_AMDGPU_TARGET_DEVICE)
-
-// Resolves a buffer reference to an absolute device pointer.
-// Expects that the binding table is provided if needed and has sufficient
-// capacity for any slot that may be referenced. All queue-ordered allocations
-// that may be provided via allocation handles must be committed prior to
-// attempting to resolve them and must remain committed until all commands using
-// the returned device pointer have completed.
-void* iree_hal_amdgpu_device_buffer_ref_resolve(
-    iree_hal_amdgpu_device_buffer_ref_t buffer_ref,
-    IREE_AMDGPU_ALIGNAS(64)
-        const iree_hal_amdgpu_device_buffer_ref_t* IREE_AMDGPU_RESTRICT
-            binding_table);
-
-// Resolves a workgroup count buffer reference to an absolute device pointer.
-// This is equivalent to iree_hal_amdgpu_device_buffer_ref_resolve but for a
-// fixed-size uint32_t[3] value. The returned pointer should have 4-byte
-// alignment.
-void* iree_hal_amdgpu_device_workgroup_count_buffer_ref_resolve(
-    iree_hal_amdgpu_device_workgroup_count_buffer_ref_t buffer_ref,
-    IREE_AMDGPU_ALIGNAS(64)
-        const iree_hal_amdgpu_device_buffer_ref_t* IREE_AMDGPU_RESTRICT
-            binding_table);
-
-// Resolves a scalar uint64_t buffer reference to an absolute device pointer.
-// This is equivalent to iree_hal_amdgpu_device_buffer_ref_resolve but for a
-// fixed-size uint64_t value. The returned pointer should have 8-byte
-// alignment.
-void* iree_hal_amdgpu_device_uint64_buffer_ref_resolve(
-    iree_hal_amdgpu_device_uint64_buffer_ref_t buffer_ref,
-    IREE_AMDGPU_ALIGNAS(64)
-        const iree_hal_amdgpu_device_buffer_ref_t* IREE_AMDGPU_RESTRICT
-            binding_table);
-
-#endif  // IREE_AMDGPU_TARGET_DEVICE
-
-#endif  // IREE_HAL_DRIVERS_AMDGPU_DEVICE_BUFFER_H_
diff --git a/runtime/src/iree/hal/drivers/amdgpu/device/command_buffer.c b/runtime/src/iree/hal/drivers/amdgpu/device/command_buffer.c
deleted file mode 100644
index b350cad..0000000
--- a/runtime/src/iree/hal/drivers/amdgpu/device/command_buffer.c
+++ /dev/null
@@ -1,1254 +0,0 @@
-// Copyright 2025 The IREE Authors
-//
-// Licensed under the Apache License v2.0 with LLVM Exceptions.
-// See https://llvm.org/LICENSE.txt for license information.
-// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
-
-#include "iree/hal/drivers/amdgpu/device/command_buffer.h"
-
-// TODO(benvanik): bring in scheduler implementation. For now we define the
-// methods we use to produce valid builds.
-static void
-iree_hal_amdgpu_device_queue_scheduler_reschedule_from_execution_queue(
-    iree_hal_amdgpu_device_queue_scheduler_t* scheduler,
-    uint64_t scheduler_queue_entry) {}
-static void iree_hal_amdgpu_device_queue_scheduler_retire_from_execution_queue(
-    iree_hal_amdgpu_device_queue_scheduler_t* scheduler,
-    uint64_t scheduler_queue_entry) {}
-
-//===----------------------------------------------------------------------===//
-// Utilities
-//===----------------------------------------------------------------------===//
-
-// Returns the packet pointer in the execution queue with the given |packet_id|.
-static iree_hsa_kernel_dispatch_packet_t* IREE_AMDGPU_RESTRICT
-iree_hal_amdgpu_device_cmd_resolve_dispatch_packet(
-    iree_hal_amdgpu_device_execution_state_t* IREE_AMDGPU_RESTRICT state,
-    const uint64_t packet_id) {
-  const uint64_t queue_mask = state->execution_queue.size - 1;  // power of two
-  return state->execution_queue.base_address + (packet_id & queue_mask) * 64;
-}
-
-// Makes all bits of an AQL packet header except the type.
-// The caller must OR in the type before setting the packet.
-static inline IREE_AMDGPU_ATTRIBUTE_ALWAYS_INLINE uint16_t
-iree_hal_amdgpu_device_make_cmd_packet_header(
-    const iree_hal_amdgpu_device_cmd_header_t* IREE_AMDGPU_RESTRICT cmd_header,
-    iree_hal_amdgpu_device_execution_flags_t execution_flags) {
-  const bool force_barrier =
-      (execution_flags & IREE_HAL_AMDGPU_DEVICE_EXECUTION_FLAG_SERIALIZE) != 0;
-  const bool force_uncached =
-      (execution_flags & IREE_HAL_AMDGPU_DEVICE_EXECUTION_FLAG_UNCACHED) != 0;
-
-  // Translate command flags; they're mostly just bit-packed header bits.
-  const bool barrier =
-      force_barrier ||
-      ((cmd_header->flags &
-        IREE_HAL_AMDGPU_DEVICE_CMD_FLAG_QUEUE_AWAIT_BARRIER) != 0);
-  const iree_hsa_fence_scope_t scacquire_fence_scope =
-      force_uncached ? IREE_HSA_FENCE_SCOPE_SYSTEM
-                     : ((cmd_header->flags >>
-                         IREE_HAL_AMDGPU_DEVICE_CMD_FLAG_FENCE_ACQUIRE_BIT) &
-                        0x3);
-  const iree_hsa_fence_scope_t screlease_fence_scope =
-      force_uncached ? IREE_HSA_FENCE_SCOPE_SYSTEM
-                     : ((cmd_header->flags >>
-                         IREE_HAL_AMDGPU_DEVICE_CMD_FLAG_FENCE_RELEASE_BIT) &
-                        0x3);
-
-  // Form the header word.
-  return (barrier << IREE_HSA_PACKET_HEADER_BARRIER) |
-         (scacquire_fence_scope
-          << IREE_HSA_PACKET_HEADER_SCACQUIRE_FENCE_SCOPE) |
-         (screlease_fence_scope
-          << IREE_HSA_PACKET_HEADER_SCRELEASE_FENCE_SCOPE);
-}
-
-// Emplaces a lightweight barrier packet (no cache management, no-op wait)
-// and associates the optional |completion_signal|. The packet processor will
-// populate the timestamps on the signal after the packet has retired.
-static void iree_hal_amdgpu_device_cmd_emplace_barrier(
-    iree_hal_amdgpu_device_execution_state_t* IREE_AMDGPU_RESTRICT state,
-    const iree_hal_amdgpu_device_cmd_header_t* IREE_AMDGPU_RESTRICT cmd_header,
-    const uint64_t packet_id, iree_hsa_signal_t completion_signal) {
-  const uint64_t queue_mask = state->execution_queue.size - 1;  // power of two
-  iree_hsa_barrier_or_packet_t* IREE_AMDGPU_RESTRICT packet =
-      state->execution_queue.base_address + (packet_id & queue_mask) * 64;
-
-  // No signals to make this a no-op.
-  for (size_t i = 0; i < IREE_AMDGPU_ARRAYSIZE(packet->dep_signal); ++i) {
-    packet->dep_signal[i] = iree_hsa_signal_null();
-  }
-
-  // Chain the provided signal, which is likely a trace query.
-  packet->completion_signal = completion_signal;
-
-  // Form the header word.
-  // NOTE: uint16_t high is reserved0.
-  const uint32_t barrier_header =
-      iree_hal_amdgpu_device_make_cmd_packet_header(cmd_header, state->flags);
-
-  // Swap header to enable the packet.
-  iree_amdgpu_scoped_atomic_store(
-      (iree_amdgpu_scoped_atomic_uint32_t*)packet, barrier_header,
-      iree_amdgpu_memory_order_release, iree_amdgpu_memory_scope_device);
-}
-
-// Commits a CFG control packet.
-// These are assumed to run on a single thread.
-static void iree_hal_amdgpu_device_cmd_commit_cfg_packet(
-    iree_hal_amdgpu_device_execution_state_t* IREE_AMDGPU_RESTRICT state,
-    const iree_hal_amdgpu_device_cmd_header_t* IREE_AMDGPU_RESTRICT cmd_header,
-    const uint64_t packet_id,
-    const iree_hal_amdgpu_trace_execution_query_id_t execution_query_id,
-    const iree_hal_amdgpu_device_kernel_args_t* IREE_AMDGPU_RESTRICT
-        kernel_args,
-    void* IREE_AMDGPU_RESTRICT kernarg_ptr) {
-  // Emplace a packet in the execution queue but leave the header uninitialized.
-  const uint64_t queue_mask = state->execution_queue.size - 1;  // power of two
-  iree_hsa_kernel_dispatch_packet_t* IREE_AMDGPU_RESTRICT packet =
-      (iree_hsa_kernel_dispatch_packet_t*)state->execution_queue.base_address +
-      (packet_id & queue_mask);
-  packet->setup = kernel_args->setup;
-  packet->workgroup_size[0] = kernel_args->workgroup_size[0];
-  packet->workgroup_size[1] = kernel_args->workgroup_size[1];
-  packet->workgroup_size[2] = kernel_args->workgroup_size[2];
-  packet->reserved0 = 0;
-  packet->grid_size[0] = 1;
-  packet->grid_size[1] = 1;
-  packet->grid_size[2] = 1;
-  packet->private_segment_size = kernel_args->private_segment_size;
-  packet->group_segment_size = kernel_args->group_segment_size;
-  packet->kernel_object = kernel_args->kernel_object;
-  packet->kernarg_address = kernarg_ptr;
-  packet->reserved2 = 0;
-
-#if IREE_HAL_AMDGPU_HAS_TRACING_FEATURE( \
-    IREE_HAL_AMDGPU_TRACING_FEATURE_DEVICE_EXECUTION)
-  // Enqueue tracing event and get a query signal used for timing.
-  if (execution_query_id != IREE_HAL_AMDGPU_TRACE_EXECUTION_QUERY_ID_INVALID) {
-    packet->completion_signal =
-        iree_hal_amdgpu_device_trace_execution_zone_dispatch(
-            state->trace_buffer,
-            IREE_HAL_AMDGPU_TRACE_EXECUTION_ZONE_TYPE_INTERNAL,
-            kernel_args->trace_src_loc, execution_query_id);
-  } else {
-    packet->completion_signal = iree_hsa_signal_null();
-  }
-#else
-  packet->completion_signal = iree_hsa_signal_null();
-#endif  // IREE_HAL_AMDGPU_TRACING_FEATURE_DEVICE_EXECUTION
-
-  // Populate the header and release the packet to the queue.
-  const uint16_t header =
-      (IREE_HSA_PACKET_TYPE_KERNEL_DISPATCH << IREE_HSA_PACKET_HEADER_TYPE) |
-      iree_hal_amdgpu_device_make_cmd_packet_header(cmd_header, state->flags);
-  const uint32_t header_setup = header | (uint32_t)(packet->setup << 16);
-  iree_amdgpu_scoped_atomic_store(
-      (iree_amdgpu_scoped_atomic_uint32_t*)packet, header_setup,
-      iree_amdgpu_memory_order_release, iree_amdgpu_memory_scope_device);
-
-  iree_hsa_signal_store(state->execution_queue.doorbell_signal, packet_id,
-                        iree_amdgpu_memory_order_release);
-}
-
-// Updates a dispatch packet header and optional tracing query signal.
-// The returned packet will still have an INVALID type and that will need to be
-// OR'ed in by the caller.
-static void iree_hal_amdgpu_device_cmd_update_dispatch_packet(
-    iree_hal_amdgpu_device_execution_state_t* IREE_AMDGPU_RESTRICT state,
-    const iree_hal_amdgpu_device_cmd_header_t* IREE_AMDGPU_RESTRICT cmd_header,
-    iree_hsa_kernel_dispatch_packet_t* IREE_AMDGPU_RESTRICT packet,
-    const iree_hal_amdgpu_trace_execution_zone_type_t execution_zone_type,
-    const uint64_t export_loc,
-    const iree_hal_amdgpu_trace_execution_query_id_t execution_query_id) {
-#if IREE_HAL_AMDGPU_HAS_TRACING_FEATURE( \
-    IREE_HAL_AMDGPU_TRACING_FEATURE_DEVICE_EXECUTION)
-  // Enqueue tracing event and get a query signal used for timing.
-  if (execution_query_id != IREE_HAL_AMDGPU_TRACE_EXECUTION_QUERY_ID_INVALID) {
-    packet->completion_signal =
-        iree_hal_amdgpu_device_trace_execution_zone_dispatch(
-            state->trace_buffer, execution_zone_type, export_loc,
-            execution_query_id);
-  } else {
-    packet->completion_signal = iree_hsa_signal_null();
-  }
-#else
-  packet->completion_signal = iree_hsa_signal_null();
-#endif  // IREE_HAL_AMDGPU_TRACING_FEATURE_DEVICE_EXECUTION
-
-  // Populate the header and release the packet to the queue.
-  // NOTE: we don't assign the packet type yet - the commit needs to do that
-  // only when the packet has been fully formed.
-  packet->header =
-      (IREE_HSA_PACKET_TYPE_INVALID << IREE_HSA_PACKET_HEADER_TYPE) |
-      iree_hal_amdgpu_device_make_cmd_packet_header(cmd_header, state->flags);
-}
-
-// Commits a dispatch or transfer packet.
-static void iree_hal_amdgpu_device_cmd_commit_dispatch_packet(
-    iree_hal_amdgpu_device_execution_state_t* IREE_AMDGPU_RESTRICT state,
-    const iree_hal_amdgpu_device_cmd_header_t* IREE_AMDGPU_RESTRICT cmd_header,
-    iree_hsa_kernel_dispatch_packet_t* IREE_AMDGPU_RESTRICT packet,
-    const iree_hal_amdgpu_trace_execution_zone_type_t execution_zone_type,
-    const uint64_t export_loc,
-    const iree_hal_amdgpu_trace_execution_query_id_t execution_query_id) {
-  // Update the packet with the required information.
-  iree_hal_amdgpu_device_cmd_update_dispatch_packet(
-      state, cmd_header, packet, execution_zone_type, export_loc,
-      execution_query_id);
-
-  // Update the header from INVALID to KERNEL_DISPATCH so the packet processor
-  // can begin executing it.
-  const uint16_t header =
-      (packet->header &
-       ~(IREE_HSA_PACKET_TYPE_INVALID << IREE_HSA_PACKET_HEADER_TYPE)) |
-      (IREE_HSA_PACKET_TYPE_KERNEL_DISPATCH << IREE_HSA_PACKET_HEADER_TYPE);
-  const uint32_t header_setup = header | (uint32_t)(packet->setup << 16);
-  iree_amdgpu_scoped_atomic_store(
-      (iree_amdgpu_scoped_atomic_uint32_t*)packet, header_setup,
-      iree_amdgpu_memory_order_release, iree_amdgpu_memory_scope_device);
-}
-
-// Flushes all outstanding tracing queries from the current block.
-// If the caller is running on the execution queue it will not be included (as
-// its query has not yet been resolved).
-//
-// TODO(benvanik): support resolving the final terminator time.
-static void iree_hal_amdgpu_device_flush_execution_queries(
-    iree_hal_amdgpu_device_execution_state_t* IREE_AMDGPU_RESTRICT state) {
-#if IREE_HAL_AMDGPU_HAS_TRACING_FEATURE( \
-    IREE_HAL_AMDGPU_TRACING_FEATURE_DEVICE_EXECUTION)
-  iree_hal_amdgpu_trace_agent_time_range_t* IREE_AMDGPU_RESTRICT time_ranges =
-      iree_hal_amdgpu_device_trace_execution_zone_notify_batch(
-          state->trace_buffer, state->trace_block_query_base_id,
-          state->trace_block_query_count);
-  for (uint16_t i = 0; i < state->trace_block_query_count; ++i) {
-    iree_amd_signal_t* IREE_AMDGPU_RESTRICT signal =
-        (iree_amd_signal_t*)
-            iree_hal_amdgpu_device_query_ringbuffer_signal_for_id(
-                &state->trace_buffer->query_ringbuffer,
-                state->trace_block_query_base_id + i)
-                .handle;
-    time_ranges[i] = (iree_hal_amdgpu_trace_agent_time_range_t){
-        .begin = signal->start_ts,
-        .end = signal->end_ts,
-    };
-  }
-#endif  // IREE_HAL_AMDGPU_TRACING_FEATURE_DEVICE_EXECUTION
-}
-
-//===----------------------------------------------------------------------===//
-// Device-side Enqueuing
-//===----------------------------------------------------------------------===//
-
-static void iree_hal_amdgpu_device_cmd_issue(
-    iree_hal_amdgpu_device_execution_state_t* IREE_AMDGPU_RESTRICT state,
-    const iree_hal_amdgpu_device_command_block_t* IREE_AMDGPU_RESTRICT block,
-    uint32_t command_ordinal, uint64_t base_packet_id);
-
-static void iree_hal_amdgpu_device_command_buffer_enqueue_next_block_serial(
-    iree_hal_amdgpu_device_execution_state_t* IREE_AMDGPU_RESTRICT state) {
-  // Reserve space for all of the execution packets. We will populate them all
-  // in the loop below.
-  const uint64_t base_packet_id = iree_hsa_queue_add_write_index(
-      &state->execution_queue, state->block->max_packet_count,
-      iree_amdgpu_memory_order_relaxed);
-  while (base_packet_id -
-             iree_hsa_queue_load_read_index(&state->execution_queue,
-                                            iree_amdgpu_memory_order_acquire) >=
-         state->execution_queue.size) {
-    iree_amdgpu_yield();  // spinning
-  }
-
-  // Signal the execution queue doorbell immediately even though we haven't
-  // populated the packets yet: this kicks the packet processor awake to spin while
-  // we populate the packets. Since we write packets in order the span from this
-  // moment to the first packet execution should be as small as possible when
-  // hopping queues.
-  iree_hsa_signal_store(state->execution_queue.doorbell_signal,
-                        base_packet_id + state->block->max_packet_count,
-                        iree_amdgpu_memory_order_relaxed);
-
-  // Issue all packets to the execution queue.
-  for (uint32_t i = 0; i < state->block->command_count; ++i) {
-    iree_hal_amdgpu_device_cmd_issue(state, state->block, i, base_packet_id);
-  }
-}
-
-static void iree_hal_amdgpu_device_command_buffer_enqueue_next_block_parallel(
-    iree_hal_amdgpu_device_execution_state_t* IREE_AMDGPU_RESTRICT state) {
-  // Reserve the next packet in the control queue for the issue_block kernel.
-  const iree_amd_cached_queue_t* control_queue = state->control_queue;
-  const uint64_t control_packet_id = iree_hsa_queue_add_write_index(
-      control_queue, 1u, iree_amdgpu_memory_order_relaxed);
-  while (control_packet_id -
-             iree_hsa_queue_load_read_index(control_queue,
-                                            iree_amdgpu_memory_order_acquire) >=
-         control_queue->size) {
-    iree_amdgpu_yield();  // spinning
-  }
-
-  // Reserve space for all of the execution packets.
-  // We need to ensure we have this entire range for the issue_block kernel to
-  // populate prior to launching it.
-  //
-  // NOTE: we do this after we insert the control packet as the control and
-  // execution queues may be the same: we must issue the control packet prior to
-  // any execution packets that are reserved as INVALID and that will block the
-  // packet processor.
-  const uint64_t base_packet_id = iree_hsa_queue_add_write_index(
-      &state->execution_queue, state->block->max_packet_count,
-      iree_amdgpu_memory_order_relaxed);
-  while (base_packet_id -
-             iree_hsa_queue_load_read_index(&state->execution_queue,
-                                            iree_amdgpu_memory_order_acquire) >=
-         state->execution_queue.size) {
-    iree_amdgpu_yield();  // spinning
-  }
-
-  // Kernel arguments stored in the shared control kernarg storage. There should
-  // only be one control dispatch enqueued at a time.
-  uint64_t* kernarg_ptr = (uint64_t*)state->control_kernarg_storage;
-  kernarg_ptr[0] = (uint64_t)state;
-  kernarg_ptr[1] = (uint64_t)state->block;
-  kernarg_ptr[2] = base_packet_id;
-
-  // Construct the control packet.
-  // Note that the header is not written until the end so that the
-  // hardware command processor stalls until we're done writing.
-  const iree_hal_amdgpu_device_kernel_args_t control_args =
-      state->kernels->iree_hal_amdgpu_device_cmd_block_issue;
-  const uint64_t queue_mask = control_queue->size - 1;  // power of two
-  iree_hsa_kernel_dispatch_packet_t* control_packet =
-      (iree_hsa_kernel_dispatch_packet_t*)control_queue->base_address +
-      (control_packet_id & queue_mask);
-  control_packet->setup = control_args.setup;
-  control_packet->workgroup_size[0] = control_args.workgroup_size[0];
-  control_packet->workgroup_size[1] = control_args.workgroup_size[1];
-  control_packet->workgroup_size[2] = control_args.workgroup_size[2];
-  control_packet->reserved0 = 0;
-  control_packet->grid_size[0] = state->block->command_count;
-  control_packet->grid_size[1] = 1;
-  control_packet->grid_size[2] = 1;
-  control_packet->private_segment_size = control_args.private_segment_size;
-  control_packet->group_segment_size = control_args.group_segment_size;
-  control_packet->kernel_object = control_args.kernel_object;
-  control_packet->kernarg_address = kernarg_ptr;
-  control_packet->reserved2 = 0;
-  control_packet->completion_signal.handle = 0;
-
-  // Populate the header and release the packet to the queue.
-  uint16_t control_header = IREE_HSA_PACKET_TYPE_KERNEL_DISPATCH
-                            << IREE_HSA_PACKET_HEADER_TYPE;
-
-  // Force a barrier while performing the issue so that any shared resources we
-  // use will not have hazards.
-  control_header |= 1 << IREE_HSA_PACKET_HEADER_BARRIER;
-
-  // We scope to the agent as we should be scheduled and targeting execution
-  // queues on the same one.
-  control_header |= IREE_HSA_FENCE_SCOPE_AGENT
-                    << IREE_HSA_PACKET_HEADER_SCACQUIRE_FENCE_SCOPE;
-  control_header |= IREE_HSA_FENCE_SCOPE_AGENT
-                    << IREE_HSA_PACKET_HEADER_SCRELEASE_FENCE_SCOPE;
-
-  // Mark the control packet as ready to execute. The hardware command processor
-  // may begin executing it immediately after performing the atomic swap.
-  const uint32_t control_header_setup =
-      control_header | (uint32_t)(control_packet->setup << 16);
-  iree_amdgpu_scoped_atomic_store(
-      (iree_amdgpu_scoped_atomic_uint32_t*)control_packet, control_header_setup,
-      iree_amdgpu_memory_order_release, iree_amdgpu_memory_scope_device);
-
-  // Signal the queue doorbell indicating the packet has been updated.
-  iree_hsa_signal_store(control_queue->doorbell_signal, control_packet_id,
-                        iree_amdgpu_memory_order_relaxed);
-}
-
-void iree_hal_amdgpu_device_command_buffer_enqueue_next_block(
-    iree_hal_amdgpu_device_execution_state_t* IREE_AMDGPU_RESTRICT state) {
-#if IREE_HAL_AMDGPU_HAS_TRACING_FEATURE( \
-    IREE_HAL_AMDGPU_TRACING_FEATURE_DEVICE_CONTROL)
-  // Reserve a query ID range for the commands in the block.
-  // We take up to the maximum required for each tracing mode we may be in but
-  // the block may not use them all.
-  uint16_t query_count = 0;
-  if (state->flags & IREE_HAL_AMDGPU_DEVICE_EXECUTION_FLAG_TRACE_DISPATCH) {
-    query_count = state->block->query_map.max_dispatch_query_count;
-  } else if (state->flags &
-             IREE_HAL_AMDGPU_DEVICE_EXECUTION_FLAG_TRACE_CONTROL) {
-    query_count = state->block->query_map.max_control_query_count;
-  }
-  state->trace_block_query_count = query_count;
-  if (query_count > 0) {
-    state->trace_block_query_base_id =
-        iree_hal_amdgpu_device_query_ringbuffer_acquire_range(
-            &state->trace_buffer->query_ringbuffer, query_count);
-  }
-#endif  // IREE_HAL_AMDGPU_TRACING_FEATURE_DEVICE_CONTROL
-
-  // The execution request decides whether we issue serially here on the control
-  // queue or in parallel via a dispatch to the control queue. The dispatch has
-  // higher latency but greater throughput and is something we only want to use
-  // if that throughput is required (lots of commands).
-  if (state->flags & IREE_HAL_AMDGPU_DEVICE_EXECUTION_FLAG_ISSUE_SERIALLY) {
-    return iree_hal_amdgpu_device_command_buffer_enqueue_next_block_serial(
-        state);
-  } else {
-    return iree_hal_amdgpu_device_command_buffer_enqueue_next_block_parallel(
-        state);
-  }
-}
-
-//===----------------------------------------------------------------------===//
-// IREE_HAL_AMDGPU_DEVICE_CMD_DEBUG_GROUP_BEGIN
-//===----------------------------------------------------------------------===//
-
-static void iree_hal_amdgpu_device_cmd_debug_group_begin_issue(
-    iree_hal_amdgpu_device_execution_state_t* IREE_AMDGPU_RESTRICT state,
-    const iree_hal_amdgpu_device_command_block_t* IREE_AMDGPU_RESTRICT block,
-    const iree_hal_amdgpu_device_cmd_debug_group_begin_t* IREE_AMDGPU_RESTRICT
-        cmd,
-    const uint64_t packet_id,
-    const iree_hal_amdgpu_trace_execution_query_id_t execution_query_id) {
-  // If tracing is enabled then get the signal used to query timestamps.
-  iree_hsa_signal_t completion_signal = iree_hsa_signal_null();
-#if IREE_HAL_AMDGPU_HAS_TRACING_FEATURE( \
-    IREE_HAL_AMDGPU_TRACING_FEATURE_DEVICE_CONTROL)
-  if (execution_query_id != IREE_HAL_AMDGPU_TRACE_EXECUTION_QUERY_ID_INVALID) {
-    completion_signal = iree_hal_amdgpu_device_trace_execution_zone_begin(
-        state->trace_buffer, execution_query_id, cmd->src_loc);
-  }
-#endif  // IREE_HAL_AMDGPU_TRACING_FEATURE_DEVICE_CONTROL
-
-  // Emit a lightweight barrier packet (no cache management, no-op wait) to
-  // force the command buffer to execute as if we were capturing timing even if
-  // we aren't. This can be useful for native debugging tools and also lets us
-  // more easily detect the overhead of tracing.
-  return iree_hal_amdgpu_device_cmd_emplace_barrier(
-      state, &cmd->header, packet_id, completion_signal);
-}
-
-//===----------------------------------------------------------------------===//
-// IREE_HAL_AMDGPU_DEVICE_CMD_DEBUG_GROUP_END
-//===----------------------------------------------------------------------===//
-
-static void iree_hal_amdgpu_device_cmd_debug_group_end_issue(
-    iree_hal_amdgpu_device_execution_state_t* IREE_AMDGPU_RESTRICT state,
-    const iree_hal_amdgpu_device_command_block_t* IREE_AMDGPU_RESTRICT block,
-    const iree_hal_amdgpu_device_cmd_debug_group_end_t* IREE_AMDGPU_RESTRICT
-        cmd,
-    const uint64_t packet_id,
-    const iree_hal_amdgpu_trace_execution_query_id_t execution_query_id) {
-  // If tracing is enabled then get the signal used to query timestamps.
-  iree_hsa_signal_t completion_signal = iree_hsa_signal_null();
-#if IREE_HAL_AMDGPU_HAS_TRACING_FEATURE( \
-    IREE_HAL_AMDGPU_TRACING_FEATURE_DEVICE_CONTROL)
-  if (execution_query_id != IREE_HAL_AMDGPU_TRACE_EXECUTION_QUERY_ID_INVALID) {
-    completion_signal = iree_hal_amdgpu_device_trace_execution_zone_end(
-        state->trace_buffer, execution_query_id);
-  }
-#endif  // IREE_HAL_AMDGPU_TRACING_FEATURE_DEVICE_CONTROL
-
-  // Emit a lightweight barrier packet (no cache management, no-op wait) to
-  // force the command buffer to execute as if we were capturing timing even if
-  // we aren't. This can be useful for native debugging tools and also lets us
-  // more easily detect the overhead of tracing.
-  return iree_hal_amdgpu_device_cmd_emplace_barrier(
-      state, &cmd->header, packet_id, completion_signal);
-}
-
-//===----------------------------------------------------------------------===//
-// IREE_HAL_AMDGPU_DEVICE_CMD_BARRIER
-//===----------------------------------------------------------------------===//
-
-static void iree_hal_amdgpu_device_cmd_barrier_issue(
-    iree_hal_amdgpu_device_execution_state_t* IREE_AMDGPU_RESTRICT state,
-    const iree_hal_amdgpu_device_command_block_t* IREE_AMDGPU_RESTRICT block,
-    const iree_hal_amdgpu_device_cmd_barrier_t* IREE_AMDGPU_RESTRICT cmd,
-    const uint64_t packet_id,
-    const iree_hal_amdgpu_trace_execution_query_id_t execution_query_id) {
-  // TODO(benvanik): derive scope from command header.
-  return iree_hal_amdgpu_device_cmd_emplace_barrier(
-      state, &cmd->header, packet_id, iree_hsa_signal_null());
-}
-
-//===----------------------------------------------------------------------===//
-// IREE_HAL_AMDGPU_DEVICE_CMD_SIGNAL_EVENT
-//===----------------------------------------------------------------------===//
-
-static void iree_hal_amdgpu_device_cmd_signal_event_issue(
-    iree_hal_amdgpu_device_execution_state_t* IREE_AMDGPU_RESTRICT state,
-    const iree_hal_amdgpu_device_command_block_t* IREE_AMDGPU_RESTRICT block,
-    const iree_hal_amdgpu_device_cmd_signal_event_t* IREE_AMDGPU_RESTRICT cmd,
-    const uint64_t packet_id,
-    const iree_hal_amdgpu_trace_execution_query_id_t execution_query_id) {
-  // TODO(benvanik): HSA signal handling.
-}
-
-//===----------------------------------------------------------------------===//
-// IREE_HAL_AMDGPU_DEVICE_CMD_RESET_EVENT
-//===----------------------------------------------------------------------===//
-
-static void iree_hal_amdgpu_device_cmd_reset_event_issue(
-    iree_hal_amdgpu_device_execution_state_t* IREE_AMDGPU_RESTRICT state,
-    const iree_hal_amdgpu_device_command_block_t* IREE_AMDGPU_RESTRICT block,
-    const iree_hal_amdgpu_device_cmd_reset_event_t* IREE_AMDGPU_RESTRICT cmd,
-    const uint64_t packet_id,
-    const iree_hal_amdgpu_trace_execution_query_id_t execution_query_id) {
-  // TODO(benvanik): HSA signal handling.
-}
-
-//===----------------------------------------------------------------------===//
-// IREE_HAL_AMDGPU_DEVICE_CMD_WAIT_EVENTS
-//===----------------------------------------------------------------------===//
-
-static void iree_hal_amdgpu_device_cmd_wait_events_issue(
-    iree_hal_amdgpu_device_execution_state_t* IREE_AMDGPU_RESTRICT state,
-    const iree_hal_amdgpu_device_command_block_t* IREE_AMDGPU_RESTRICT block,
-    const iree_hal_amdgpu_device_cmd_wait_events_t* IREE_AMDGPU_RESTRICT cmd,
-    const uint64_t packet_id,
-    const iree_hal_amdgpu_trace_execution_query_id_t execution_query_id) {
-  // TODO(benvanik): HSA signal handling.
-}
-
-//===----------------------------------------------------------------------===//
-// IREE_HAL_AMDGPU_DEVICE_CMD_FILL_BUFFER
-//===----------------------------------------------------------------------===//
-
-static void iree_hal_amdgpu_device_cmd_fill_buffer_issue(
-    iree_hal_amdgpu_device_execution_state_t* IREE_AMDGPU_RESTRICT state,
-    const iree_hal_amdgpu_device_command_block_t* IREE_AMDGPU_RESTRICT block,
-    const iree_hal_amdgpu_device_cmd_fill_buffer_t* IREE_AMDGPU_RESTRICT cmd,
-    const uint64_t packet_id,
-    const iree_hal_amdgpu_trace_execution_query_id_t execution_query_id) {
-  // Resolve bindings.
-  void* target_ptr = iree_hal_amdgpu_device_buffer_ref_resolve(cmd->target_ref,
-                                                               state->bindings);
-  const uint64_t length = cmd->target_ref.length;
-
-  // Emplace a packet in the execution queue but leave the header uninitialized.
-  uint64_t* kernarg_ptr =
-      (uint64_t*)(state->execution_kernarg_storage + cmd->kernarg_offset);
-  iree_hsa_kernel_dispatch_packet_t* IREE_AMDGPU_RESTRICT packet =
-      iree_hal_amdgpu_device_buffer_fill_emplace_reserve(
-          &state->transfer_context, target_ptr, length, cmd->pattern,
-          cmd->pattern_length, kernarg_ptr, packet_id);
-
-  // Commit the packet.
-  return iree_hal_amdgpu_device_cmd_commit_dispatch_packet(
-      state, &cmd->header, packet,
-      IREE_HAL_AMDGPU_TRACE_EXECUTION_ZONE_TYPE_FILL, 0, execution_query_id);
-}
-
-//===----------------------------------------------------------------------===//
-// IREE_HAL_AMDGPU_DEVICE_CMD_COPY_BUFFER
-//===----------------------------------------------------------------------===//
-
-static void iree_hal_amdgpu_device_cmd_copy_buffer_issue(
-    iree_hal_amdgpu_device_execution_state_t* IREE_AMDGPU_RESTRICT state,
-    const iree_hal_amdgpu_device_command_block_t* IREE_AMDGPU_RESTRICT block,
-    const iree_hal_amdgpu_device_cmd_copy_buffer_t* IREE_AMDGPU_RESTRICT cmd,
-    const uint64_t packet_id,
-    const iree_hal_amdgpu_trace_execution_query_id_t execution_query_id) {
-  // Resolve bindings.
-  const void* source_ptr = iree_hal_amdgpu_device_buffer_ref_resolve(
-      cmd->source_ref, state->bindings);
-  void* target_ptr = iree_hal_amdgpu_device_buffer_ref_resolve(cmd->target_ref,
-                                                               state->bindings);
-  const uint64_t length = cmd->target_ref.length;
-
-  // Emplace a packet in the execution queue but leave the header uninitialized.
-  uint64_t* kernarg_ptr =
-      (uint64_t*)(state->execution_kernarg_storage + cmd->kernarg_offset);
-  iree_hsa_kernel_dispatch_packet_t* IREE_AMDGPU_RESTRICT packet =
-      iree_hal_amdgpu_device_buffer_copy_emplace_reserve(
-          &state->transfer_context, source_ptr, target_ptr, length, kernarg_ptr,
-          packet_id);
-
-  // Commit the packet.
-  return iree_hal_amdgpu_device_cmd_commit_dispatch_packet(
-      state, &cmd->header, packet,
-      IREE_HAL_AMDGPU_TRACE_EXECUTION_ZONE_TYPE_COPY, 0, execution_query_id);
-}
-
-//===----------------------------------------------------------------------===//
-// IREE_HAL_AMDGPU_DEVICE_CMD_DISPATCH
-//===----------------------------------------------------------------------===//
-
-static iree_hsa_kernel_dispatch_packet_t*
-iree_hal_amdgpu_device_cmd_dispatch_reserve(
-    iree_hal_amdgpu_device_execution_state_t* IREE_AMDGPU_RESTRICT state,
-    const iree_hal_amdgpu_device_command_block_t* IREE_AMDGPU_RESTRICT block,
-    const iree_hal_amdgpu_device_cmd_dispatch_t* IREE_AMDGPU_RESTRICT cmd,
-    uint8_t* IREE_AMDGPU_RESTRICT kernarg_base, const uint64_t packet_id) {
-  const iree_hal_amdgpu_device_kernel_args_t* dispatch_args =
-      cmd->config.kernel_args;
-
-  // Populate bindings and constants in the reserved kernarg storage.
-  for (uint16_t i = 0; i < dispatch_args->binding_count; ++i) {
-    ((uint64_t*)kernarg_base)[i] =
-        (uint64_t)iree_hal_amdgpu_device_buffer_ref_resolve(cmd->bindings[i],
-                                                            state->bindings);
-  }
-  uint8_t* kernarg_ptr =
-      kernarg_base + dispatch_args->binding_count * sizeof(void*);
-  iree_amdgpu_memcpy(kernarg_ptr, cmd->constants,
-                     dispatch_args->constant_count * sizeof(uint32_t));
-
-  // Construct the dispatch packet based on the template embedded in the command
-  // buffer. Note that the header is not written until the end so that the
-  // hardware command processor stalls until we're done writing.
-  iree_hsa_kernel_dispatch_packet_t* packet =
-      iree_hal_amdgpu_device_cmd_resolve_dispatch_packet(state, packet_id);
-  packet->setup = dispatch_args->setup;
-  packet->workgroup_size[0] = dispatch_args->workgroup_size[0];
-  packet->workgroup_size[1] = dispatch_args->workgroup_size[1];
-  packet->workgroup_size[2] = dispatch_args->workgroup_size[2];
-  packet->reserved0 = 0;
-  packet->private_segment_size = dispatch_args->private_segment_size;
-  packet->group_segment_size =
-      dispatch_args->group_segment_size + cmd->config.dynamic_lds_size;
-  packet->kernel_object = dispatch_args->kernel_object;
-  packet->kernarg_address = kernarg_ptr;
-  packet->reserved2 = 0;
-
-  // Resolve the workgroup count (if possible).
-  const uint32_t* workgroup_count_ptr = NULL;
-  if (cmd->config.flags &
-      IREE_HAL_AMDGPU_DEVICE_DISPATCH_FLAG_INDIRECT_STATIC) {
-    // Workgroup count is indirect but statically available and can be resolved
-    // during issue. This is the common case where the workgroup count is stored
-    // in a uniform buffer by the launcher and it allows us to avoid any
-    // additional dispatch overhead.
-    workgroup_count_ptr =
-        iree_hal_amdgpu_device_workgroup_count_buffer_ref_resolve(
-            cmd->config.workgroup_count.ref, state->bindings);
-  } else {
-    // Workgroup count is constant.
-    workgroup_count_ptr = cmd->config.workgroup_count.dims;
-  }
-  packet->grid_size[0] = workgroup_count_ptr[0] * packet->workgroup_size[0];
-  packet->grid_size[1] = workgroup_count_ptr[1] * packet->workgroup_size[1];
-  packet->grid_size[2] = workgroup_count_ptr[2] * packet->workgroup_size[2];
-
-  // If the dispatch requires implicit args then populate them now.
-  // Some of these are static and could be precomputed while others are
-  // dependent on where we are running. We only need to populate up to the
-  // implicit kernarg size (the remainder after our own explicit args).
-  //
-  // TODO(benvanik): once we have real kernels coming through see what we
-  // actually use. We can also check if it makes sense to try to be efficient or
-  // always splat in everything to keep things more uniform. The kernarg
-  // reservation being larger than required is sad but so is the need to check
-  // each field individually.
-  //
-  // Today we hardcode any kernel needing implicit args to attach an
-  // IREE_AMDGPU_KERNEL_IMPLICIT_ARGS_SIZE-byte suffix that we initialize
-  // unconditionally. Detecting what we need to populate is significantly more
-  // expensive than just always reserving the space. Probably. This entire
-  // design makes me sad as it's *never* required in our generated kernels and
-  // the best we can hope for is that we don't accidentally use any device
-  // library code that triggers the implicit args to be required.
-  if (cmd->config.flags & IREE_HAL_AMDGPU_DEVICE_DISPATCH_FLAG_IMPLICIT_ARGS) {
-    kernarg_ptr +=
-        iree_amdgpu_align(dispatch_args->constant_count * sizeof(uint32_t), 8);
-    iree_amdgpu_kernel_implicit_args_t* IREE_AMDGPU_RESTRICT implicit_args =
-        (iree_amdgpu_kernel_implicit_args_t*)kernarg_ptr;
-    // This information is redundant with the dispatch packet and it's sad that
-    // it is required by HIP.
-    implicit_args->block_count[0] = workgroup_count_ptr[0];
-    implicit_args->block_count[1] = workgroup_count_ptr[1];
-    implicit_args->block_count[2] = workgroup_count_ptr[2];
-    implicit_args->group_size[0] = dispatch_args->workgroup_size[0];
-    implicit_args->group_size[1] = dispatch_args->workgroup_size[1];
-    implicit_args->group_size[2] = dispatch_args->workgroup_size[2];
-    // Hardcoded to 0 in HIP.
-    implicit_args->remainder[0] = 0;
-    implicit_args->remainder[1] = 0;
-    implicit_args->remainder[2] = 0;
-    // Hardcoded to 0 in HIP.
-    implicit_args->global_offset[0] = 0;
-    implicit_args->global_offset[1] = 0;
-    implicit_args->global_offset[2] = 0;
-    // Hardcoded to 3 in HIP.
-    implicit_args->grid_dims = 3;
-    // TODO(benvanik): support printf_buffer (and maybe hostcall_buffer).
-    // Today we set them to NULL so we get a segfault if a kernel happens to
-    // use them.
-    implicit_args->printf_buffer = NULL;
-    implicit_args->hostcall_buffer = NULL;
-    // We don't currently use dynamic LDS but may allow HIP kernels to do so.
-    implicit_args->dynamic_lds_size = cmd->config.dynamic_lds_size;
-  }
-
-  // NOTE: we return the packet without having updated the header. The caller
-  // is responsible for calling iree_hal_amdgpu_device_cmd_dispatch_mark_ready
-  // when it is ready for the hardware command processor to pick up the packet.
-  return packet;
-}
-
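-// For orientation, the kernarg layout produced by the reserve above is (sketch
-// derived from the code; offsets only, not authoritative struct definitions):
-//   [0, binding_count * 8)        : resolved binding pointers
-//   [.., + constant_count * 4)    : dispatch constants
-//   [8-byte aligned end of above) : optional implicit args suffix of
-//                                   IREE_AMDGPU_KERNEL_IMPLICIT_ARGS_SIZE bytes
-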
-static void iree_hal_amdgpu_device_cmd_dispatch_issue(
-    iree_hal_amdgpu_device_execution_state_t* IREE_AMDGPU_RESTRICT state,
-    const iree_hal_amdgpu_device_command_block_t* IREE_AMDGPU_RESTRICT block,
-    const iree_hal_amdgpu_device_cmd_dispatch_t* IREE_AMDGPU_RESTRICT cmd,
-    const uint64_t packet_id,
-    const iree_hal_amdgpu_trace_execution_query_id_t execution_query_id) {
-  // Enqueue the dispatch packet but do not mark it as ready yet.
-  uint8_t* kernarg_ptr = state->execution_kernarg_storage + cmd->kernarg_offset;
-  iree_hsa_kernel_dispatch_packet_t* IREE_AMDGPU_RESTRICT packet =
-      iree_hal_amdgpu_device_cmd_dispatch_reserve(state, block, cmd,
-                                                  kernarg_ptr, packet_id);
-
-  // Mark the dispatch packet as ready and allow the hardware command processor
-  // to pick it up.
-  return iree_hal_amdgpu_device_cmd_commit_dispatch_packet(
-      state, &cmd->header, packet,
-      IREE_HAL_AMDGPU_TRACE_EXECUTION_ZONE_TYPE_DISPATCH,
-      cmd->config.kernel_args->trace_src_loc, execution_query_id);
-}
-
-//===----------------------------------------------------------------------===//
-// IREE_HAL_AMDGPU_DEVICE_CMD_DISPATCH_INDIRECT_DYNAMIC
-//===----------------------------------------------------------------------===//
-
-IREE_AMDGPU_ATTRIBUTE_KERNEL IREE_AMDGPU_ATTRIBUTE_SINGLE_WORK_ITEM void
-iree_hal_amdgpu_device_cmd_dispatch_update(
-    const iree_hal_amdgpu_device_cmd_dispatch_t* IREE_AMDGPU_RESTRICT cmd,
-    const uint32_t* IREE_AMDGPU_RESTRICT workgroups_ptr,
-    iree_hsa_kernel_dispatch_packet_t* IREE_AMDGPU_RESTRICT packet) {
-  // Read the uint32_t[3] workgroup count buffer and update the packet in-place.
-  packet->grid_size[0] = workgroups_ptr[0] * packet->workgroup_size[0];
-  packet->grid_size[1] = workgroups_ptr[1] * packet->workgroup_size[1];
-  packet->grid_size[2] = workgroups_ptr[2] * packet->workgroup_size[2];
-
-  // If the kernel has implicit args then update those as well.
-  // This results from an unfortunate design decision and is only needed for
-  // compatibility with HIP kernels that use the device library routines that
-  // fetch implicit args instead of using the builtins that reference registers.
-  if (cmd->config.flags & IREE_HAL_AMDGPU_DEVICE_DISPATCH_FLAG_IMPLICIT_ARGS) {
-    const iree_hal_amdgpu_device_kernel_args_t* dispatch_args =
-        cmd->config.kernel_args;
-    uint8_t* implicit_args_ptr =
-        (uint8_t*)packet->kernarg_address +
-        dispatch_args->binding_count * sizeof(void*) +
-        iree_amdgpu_align(dispatch_args->constant_count * sizeof(uint32_t), 8);
-    iree_amdgpu_kernel_implicit_args_t* IREE_AMDGPU_RESTRICT implicit_args =
-        (iree_amdgpu_kernel_implicit_args_t*)implicit_args_ptr;
-    implicit_args->block_count[0] = workgroups_ptr[0];
-    implicit_args->block_count[1] = workgroups_ptr[1];
-    implicit_args->block_count[2] = workgroups_ptr[2];
-  }
-
-  // Now that the packet (and maybe kernargs) have been updated we can mark it
-  // as ready so that the hardware command processor can take it. Since the
-  // execution queue has already had its doorbell updated we don't need to do
-  // that - it _should_ be spinning on the packet.
-  const uint16_t header =
-      (packet->header &
-       ~(IREE_HSA_PACKET_TYPE_INVALID << IREE_HSA_PACKET_HEADER_TYPE)) |
-      (IREE_HSA_PACKET_TYPE_KERNEL_DISPATCH << IREE_HSA_PACKET_HEADER_TYPE);
-  const uint32_t header_setup = header | ((uint32_t)packet->setup << 16);
-  iree_amdgpu_scoped_atomic_store(
-      (iree_amdgpu_scoped_atomic_uint32_t*)packet, header_setup,
-      iree_amdgpu_memory_order_release, iree_amdgpu_memory_scope_device);
-}
-
-static void iree_hal_amdgpu_device_cmd_dispatch_indirect_dynamic_issue(
-    iree_hal_amdgpu_device_execution_state_t* IREE_AMDGPU_RESTRICT state,
-    const iree_hal_amdgpu_device_command_block_t* IREE_AMDGPU_RESTRICT block,
-    const iree_hal_amdgpu_device_cmd_dispatch_t* IREE_AMDGPU_RESTRICT cmd,
-    const uint64_t packet_id,
-    const iree_hal_amdgpu_trace_execution_query_id_t execution_query_id) {
-  const uint64_t update_id = packet_id;
-  const uint64_t dispatch_id = update_id + 1;
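-
-  // Queue layout produced by this command (sketch): the packet at update_id
-  // runs the fixup kernel (iree_hal_amdgpu_device_cmd_dispatch_update) while
-  // the packet at dispatch_id is left INVALID until that fixup publishes it
-  // with the resolved workgroup count.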
-
-  // Enqueue the dispatch packet but do not mark it as ready yet.
-  // We do this first so that if the workgroup count update dispatch begins
-  // executing while we're still running we want it to have valid data to
-  // manipulate.
-  uint8_t* IREE_AMDGPU_RESTRICT dispatch_kernarg_ptr =
-      state->execution_kernarg_storage + cmd->kernarg_offset +
-      IREE_HAL_AMDGPU_DEVICE_WORKGROUP_COUNT_UPDATE_KERNARG_SIZE;
-  iree_hsa_kernel_dispatch_packet_t* dispatch_packet =
-      iree_hal_amdgpu_device_cmd_dispatch_reserve(
-          state, block, cmd, dispatch_kernarg_ptr, dispatch_id);
-
-  // Update the dispatch packet with the required scheduling information.
-  // It is not yet committed and will not have its type set until the indirect
-  // update executes.
-  iree_hal_amdgpu_device_cmd_update_dispatch_packet(
-      state, &cmd->header, dispatch_packet,
-      IREE_HAL_AMDGPU_TRACE_EXECUTION_ZONE_TYPE_DISPATCH_INDIRECT,
-      cmd->config.kernel_args->trace_src_loc, execution_query_id);
-
-  // Workgroup count is dynamic and must be resolved just prior to executing
-  // the dispatch. There's no native AQL dispatch behavior to enable this so
-  // we have to emulate it by enqueuing a builtin that performs the
-  // indirection and overwrites the packet memory directly.
-  uint64_t* IREE_AMDGPU_RESTRICT update_kernarg_ptr =
-      (uint64_t*)(state->execution_kernarg_storage + cmd->kernarg_offset);
-  update_kernarg_ptr[0] = (uint64_t)cmd;
-  update_kernarg_ptr[1] =
-      (uint64_t)iree_hal_amdgpu_device_workgroup_count_buffer_ref_resolve(
-          cmd->config.workgroup_count.ref, state->bindings);
-  update_kernarg_ptr[2] = (uint64_t)dispatch_packet;
-
-  // Construct the update packet.
-  // Note that the header is not written until the end so that the
-  // hardware command processor stalls until we're done writing.
-  const iree_hal_amdgpu_device_kernel_args_t update_args =
-      state->kernels->iree_hal_amdgpu_device_cmd_dispatch_update;
-  iree_hsa_kernel_dispatch_packet_t* update_packet =
-      iree_hal_amdgpu_device_cmd_resolve_dispatch_packet(state, packet_id);
-  update_packet->setup = update_args.setup;
-  update_packet->workgroup_size[0] = update_args.workgroup_size[0];
-  update_packet->workgroup_size[1] = update_args.workgroup_size[1];
-  update_packet->workgroup_size[2] = update_args.workgroup_size[2];
-  update_packet->reserved0 = 0;
-  update_packet->grid_size[0] = 1;
-  update_packet->grid_size[1] = 1;
-  update_packet->grid_size[2] = 1;
-  update_packet->private_segment_size = update_args.private_segment_size;
-  update_packet->group_segment_size = update_args.group_segment_size;
-  update_packet->kernel_object = update_args.kernel_object;
-  update_packet->kernarg_address = update_kernarg_ptr;
-  update_packet->reserved2 = 0;
-
-  // Mark the update packet as ready to execute. The hardware command processor
-  // may begin executing it immediately after performing the atomic swap.
-  //
-  // NOTE: the following dispatch packet is still marked INVALID and is only
-  // changed after the update dispatch completes. The hardware command processor
-  // should process the update (as we change it from INVALID here) and then
-  // block before reading the contents of the dispatch packet.
-  return iree_hal_amdgpu_device_cmd_commit_dispatch_packet(
-      state, &cmd->header, update_packet,
-      IREE_HAL_AMDGPU_TRACE_EXECUTION_ZONE_TYPE_INTERNAL, 0,
-      IREE_HAL_AMDGPU_TRACE_EXECUTION_QUERY_ID_INVALID);
-}
-
-//===----------------------------------------------------------------------===//
-// IREE_HAL_AMDGPU_DEVICE_CMD_BRANCH
-//===----------------------------------------------------------------------===//
-
-// Enqueues the command buffer rescheduling on the scheduler queue.
-// The next tick will move the parent queue entry to the ready list and attempt
-// to schedule the next block as set on the state.
-IREE_AMDGPU_ATTRIBUTE_KERNEL IREE_AMDGPU_ATTRIBUTE_SINGLE_WORK_ITEM void
-iree_hal_amdgpu_device_cmd_branch(
-    iree_hal_amdgpu_device_execution_state_t* IREE_AMDGPU_RESTRICT state,
-    const iree_hal_amdgpu_device_command_block_t* IREE_AMDGPU_RESTRICT
-        next_block) {
-  // Flush trace zones if any were used.
-  // Note that this won't include this kernel as it is still running.
-  // TODO(benvanik): find a way to include the time for the terminator.
-  iree_hal_amdgpu_device_flush_execution_queries(state);
-
-  // Set the target block of the branch.
-  // When the execution state is rescheduled it will resume at the new block.
-  state->block = next_block;
-
-  // TODO(benvanik): evaluate or make a mode bit to control whether command
-  // buffers yield for rescheduling or if they directly enqueue the issue block.
-  // Rescheduling would allow for better QoS as older but newly-runnable entries
-  // would be allowed to execute ahead of the continuation point. The downside
-  // of rescheduling is that we'll have at least one hop to do the scheduler
-  // tick and then one more to do the block issue.
-  //
-  // NOTE: the rescheduling may happen immediately and we cannot use any
-  // execution state.
-  const bool direct_issue_block = true;
-  if (direct_issue_block) {
-    // Enqueue the command buffer issue on the control queue.
-    // It'll continue executing at the state block set above.
-    return iree_hal_amdgpu_device_command_buffer_enqueue_next_block(state);
-  } else {
-    // Enqueue the parent queue scheduler tick.
-    // It will move the queue entry to the ready list and may immediately begin
-    // issuing the next block.
-    return iree_hal_amdgpu_device_queue_scheduler_reschedule_from_execution_queue(
-        state->scheduler, state->scheduler_queue_entry);
-  }
-}
-
-static void iree_hal_amdgpu_device_cmd_branch_issue(
-    iree_hal_amdgpu_device_execution_state_t* IREE_AMDGPU_RESTRICT state,
-    const iree_hal_amdgpu_device_command_block_t* IREE_AMDGPU_RESTRICT block,
-    const iree_hal_amdgpu_device_cmd_branch_t* IREE_AMDGPU_RESTRICT cmd,
-    const uint64_t packet_id,
-    const iree_hal_amdgpu_trace_execution_query_id_t execution_query_id) {
-  // Direct branches are like tail calls and can simply begin issuing the
-  // following block. The kernargs are stored in state->control_kernarg_storage
-  // so that the issue_block can completely overwrite the values.
-  // Command buffer issue has already bumped the write_index and all we need to
-  // do is populate the packet.
-  //
-  // NOTE: we implicitly assume
-  // IREE_HAL_AMDGPU_DEVICE_CMD_FLAG_QUEUE_AWAIT_BARRIER but need not do so
-  // (technically) when continuing within the same command buffer. Performing a
-  // barrier is a more conservative operation and may mask compiler/command
-  // buffer construction issues with the more strict execution model but in
-  // practice is unlikely to have an appreciable effect on latency.
-
-  // Pass target block to the branch op.
-  // For other CFG commands (like conditional branches) we pass the information
-  // required to evaluate the condition and calculate the target block.
-  uint64_t* kernarg_ptr =
-      (uint64_t*)(state->execution_kernarg_storage + cmd->kernarg_offset);
-  kernarg_ptr[0] = (uint64_t)state;
-  kernarg_ptr[1] = (uint64_t)&state->command_buffer->blocks[cmd->target_block];
-
-  // Emplace and ready the CFG packet.
-  return iree_hal_amdgpu_device_cmd_commit_cfg_packet(
-      state, &cmd->header, packet_id, execution_query_id,
-      &state->kernels->iree_hal_amdgpu_device_cmd_branch, kernarg_ptr);
-}
-
-//===----------------------------------------------------------------------===//
-// IREE_HAL_AMDGPU_DEVICE_CMD_COND_BRANCH
-//===----------------------------------------------------------------------===//
-
-static inline bool iree_hal_amdgpu_device_evaluate_cond(
-    uint64_t lhs, iree_hal_amdgpu_device_cmd_cond_t cond, uint64_t rhs) {
-  switch (cond) {
-    default:
-    case IREE_HAL_AMDGPU_DEVICE_CMD_COND_EQ:
-      return (int64_t)lhs == (int64_t)rhs;
-    case IREE_HAL_AMDGPU_DEVICE_CMD_COND_NE:
-      return (int64_t)lhs != (int64_t)rhs;
-    case IREE_HAL_AMDGPU_DEVICE_CMD_COND_SLT:
-      return (int64_t)lhs < (int64_t)rhs;
-    case IREE_HAL_AMDGPU_DEVICE_CMD_COND_SLE:
-      return (int64_t)lhs <= (int64_t)rhs;
-    case IREE_HAL_AMDGPU_DEVICE_CMD_COND_SGT:
-      return (int64_t)lhs > (int64_t)rhs;
-    case IREE_HAL_AMDGPU_DEVICE_CMD_COND_SGE:
-      return (int64_t)lhs >= (int64_t)rhs;
-    case IREE_HAL_AMDGPU_DEVICE_CMD_COND_ULT:
-      return (uint64_t)lhs < (uint64_t)rhs;
-    case IREE_HAL_AMDGPU_DEVICE_CMD_COND_ULE:
-      return (uint64_t)lhs <= (uint64_t)rhs;
-    case IREE_HAL_AMDGPU_DEVICE_CMD_COND_UGT:
-      return (uint64_t)lhs > (uint64_t)rhs;
-    case IREE_HAL_AMDGPU_DEVICE_CMD_COND_UGE:
-      return (uint64_t)lhs >= (uint64_t)rhs;
-  }
-}
-
-// Enqueues the command buffer rescheduling on the scheduler queue.
-// The next tick will move the parent queue entry to the ready list and attempt
-// to schedule the next block as set on the state. The block selected is based
-// on the provided dynamic and static values:
-//   next_block = *ref_ptr <cond> value ? true_block : false_block
-IREE_AMDGPU_ATTRIBUTE_KERNEL IREE_AMDGPU_ATTRIBUTE_SINGLE_WORK_ITEM void
-iree_hal_amdgpu_device_cmd_cond_branch(
-    iree_hal_amdgpu_device_execution_state_t* IREE_AMDGPU_RESTRICT state,
-    const uint64_t* IREE_AMDGPU_RESTRICT ref_ptr,
-    iree_hal_amdgpu_device_cmd_cond_t cond, uint64_t value,
-    const iree_hal_amdgpu_device_command_block_t* IREE_AMDGPU_RESTRICT
-        true_block,
-    const iree_hal_amdgpu_device_command_block_t* IREE_AMDGPU_RESTRICT
-        false_block) {
-  // Flush trace zones if any were used.
-  // Note that this won't include this kernel as it is still running.
-  // TODO(benvanik): find a way to include the time for the terminator.
-  iree_hal_amdgpu_device_flush_execution_queries(state);
-
-  // Evaluate condition.
-  //
-  // TODO(benvanik): evaluate whether we should be doing this like a dynamic
-  // indirect dispatch: since the AMD LLVMGPU backend (/hardware) is pretty
-  // garbage at function calls having this kernel may cause a pretty extreme
-  // amount of bloat as the compiler desperately tries to inline every function
-  // to avoid dealing with the 👻 behavior of 👋 functions and stacks 👋. We
-  // could instead always issue an unconditional branch but have the update
-  // dispatch overwrite the kernarg for the `next_block` with the result of the
-  // condition. This is a lot clearer and more flexible, though, so hopefully
-  // that's not required. It's probably required. Sigh.
-  const iree_hal_amdgpu_device_command_block_t* IREE_AMDGPU_RESTRICT
-      next_block = iree_hal_amdgpu_device_evaluate_cond(ref_ptr[0], cond, value)
-                       ? true_block
-                       : false_block;
-
-  // Set the target block of the branch.
-  // When the execution state is rescheduled it will resume at the new block.
-  state->block = next_block;
-
-  // TODO(benvanik): evaluate or make a mode bit to control whether command
-  // buffers yield for rescheduling or if they directly enqueue the issue block.
-  // Rescheduling would allow for better QoS as older but newly-runnable entries
-  // would be allowed to execute ahead of the continuation point. The downside
-  // of rescheduling is that we'll have at least one hop to do the scheduler
-  // tick and then one more to do the block issue.
-  //
-  // NOTE: the rescheduling may happen immediately and we cannot use any
-  // execution state.
-  const bool direct_issue_block = true;
-  if (direct_issue_block) {
-    // Enqueue the command buffer issue on the control queue.
-    // It'll continue executing at the state block set above.
-    return iree_hal_amdgpu_device_command_buffer_enqueue_next_block(state);
-  } else {
-    // Enqueue the parent queue scheduler tick.
-    // It will move the queue entry to the ready list and may immediately begin
-    // issuing the next block.
-    return iree_hal_amdgpu_device_queue_scheduler_reschedule_from_execution_queue(
-        state->scheduler, state->scheduler_queue_entry);
-  }
-}
-
-static void iree_hal_amdgpu_device_cmd_cond_branch_issue(
-    iree_hal_amdgpu_device_execution_state_t* IREE_AMDGPU_RESTRICT state,
-    const iree_hal_amdgpu_device_command_block_t* IREE_AMDGPU_RESTRICT block,
-    const iree_hal_amdgpu_device_cmd_cond_branch_t* IREE_AMDGPU_RESTRICT cmd,
-    const uint64_t packet_id,
-    const iree_hal_amdgpu_trace_execution_query_id_t execution_query_id) {
-  // Direct branches are like tail calls and can simply begin issuing the
-  // following block. The kernargs are stored in state->control_kernarg_storage
-  // so that the issue_block can completely overwrite the values.
-  // Command buffer issue has already bumped the write_index and all we need to
-  // do is populate the packet.
-  //
-  // NOTE: we implicitly assume
-  // IREE_HAL_AMDGPU_DEVICE_CMD_FLAG_QUEUE_AWAIT_BARRIER but need not do so
-  // (technically) when continuing within the same command buffer. Performing a
-  // barrier is a more conservative operation and may mask compiler/command
-  // buffer construction issues with the more strict execution model but in
-  // practice is unlikely to have an appreciable effect on latency.
-
-  // Pass the condition, its inputs, and the true/false blocks to the kernel.
-  // Which block is taken is evaluated at the time the command is executed, but
-  // since block pointers are immutable we can resolve both now to make the
-  // kernel simpler.
-  uint64_t* kernarg_ptr =
-      (uint64_t*)(state->execution_kernarg_storage + cmd->kernarg_offset);
-  kernarg_ptr[0] = (uint64_t)state;
-  kernarg_ptr[1] = (uint64_t)iree_hal_amdgpu_device_uint64_buffer_ref_resolve(
-      cmd->ref, state->bindings);
-  kernarg_ptr[2] = cmd->cond;
-  kernarg_ptr[3] = cmd->value;
-  kernarg_ptr[4] = (uint64_t)&state->command_buffer->blocks[cmd->true_block];
-  kernarg_ptr[5] = (uint64_t)&state->command_buffer->blocks[cmd->false_block];
-
-  // Emplace and ready the CFG packet.
-  return iree_hal_amdgpu_device_cmd_commit_cfg_packet(
-      state, &cmd->header, packet_id, execution_query_id,
-      &state->kernels->iree_hal_amdgpu_device_cmd_cond_branch, kernarg_ptr);
-}
-
-//===----------------------------------------------------------------------===//
-// IREE_HAL_AMDGPU_DEVICE_CMD_RETURN
-//===----------------------------------------------------------------------===//
-
-// Enqueues the command buffer retirement on the parent scheduler.
-// The execution state may be deallocated immediately.
-IREE_AMDGPU_ATTRIBUTE_KERNEL IREE_AMDGPU_ATTRIBUTE_SINGLE_WORK_ITEM void
-iree_hal_amdgpu_device_cmd_return(
-    iree_hal_amdgpu_device_execution_state_t* IREE_AMDGPU_RESTRICT state) {
-  // Flush trace zones if any were used.
-  // Note that this won't include this kernel as it is still running.
-  // TODO(benvanik): find a way to include the time for the terminator.
-  iree_hal_amdgpu_device_flush_execution_queries(state);
-
-  // Clear block to indicate execution has completed.
-  //
-  // TODO(benvanik): does this need to be atomic release?
-  state->block = NULL;
-
-  // Enqueue the parent queue scheduler tick.
-  // It will clean up the command buffer execution state and resume
-  // processing queue entries.
-  //
-  // NOTE: the retire may immediately reclaim the execution state and we cannot
-  // do anything else with it.
-  return iree_hal_amdgpu_device_queue_scheduler_retire_from_execution_queue(
-      state->scheduler, state->scheduler_queue_entry);
-}
-
-static void iree_hal_amdgpu_device_cmd_return_issue(
-    iree_hal_amdgpu_device_execution_state_t* IREE_AMDGPU_RESTRICT state,
-    const iree_hal_amdgpu_device_command_block_t* IREE_AMDGPU_RESTRICT block,
-    const iree_hal_amdgpu_device_cmd_return_t* IREE_AMDGPU_RESTRICT cmd,
-    const uint64_t packet_id,
-    const iree_hal_amdgpu_trace_execution_query_id_t execution_query_id) {
-  // TODO(benvanik): handle call stacks when nesting command buffers. For now a
-  // return is always going back to the queue scheduler and can be enqueued as
-  // such.
-
-  // Pass just the state; returns always go back to the scheduler.
-  uint64_t* kernarg_ptr =
-      (uint64_t*)(state->execution_kernarg_storage + cmd->kernarg_offset);
-  kernarg_ptr[0] = (uint64_t)state;
-
-  // Emplace and ready the CFG packet.
-  return iree_hal_amdgpu_device_cmd_commit_cfg_packet(
-      state, &cmd->header, packet_id, execution_query_id,
-      &state->kernels->iree_hal_amdgpu_device_cmd_return, kernarg_ptr);
-}
-
-//===----------------------------------------------------------------------===//
-// Command Issue
-//===----------------------------------------------------------------------===//
-
-// Issues a block of commands in parallel.
-// Each work item processes a single command. Each command in the block contains
-// a relative offset into the queue where AQL packets should be placed and must
-// fill all packets that were declared when the command buffer was recorded
-// (even if they are no-oped).
-//
-// This relies on the AQL queue mechanics defined in section 2.8.3 of the HSA
-// System Architecture Specification. The parent enqueuing this kernel reserves
-// sufficient queue space for all AQL packets and bumps the write_index to the
-// end of the block. Each command processed combines the base queue index
-// provided with the per-command relative offset and performs the required queue
-// masking to get the final packet pointer. Packets are written by populating
-// all kernel arguments (if any), populating the packet fields, and finally
-// atomically changing the packet type from INVALID to (likely) KERNEL_DISPATCH.
-// Even though the write_index of the queue was bumped to the end the queue
-// processor is required to block on the first packet it finds with an INVALID
-// type and as such we don't require ordering guarantees on the packet
-// population. It's of course better if the first packet completes population
-// first so that the queue processor can launch it, and that will often be the
-// case given that HSA mandates that workgroups with lower indices are
-// scheduled to resources before those with higher ones.
-//
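-// Concretely, resolving a command's packet slot amounts to (sketch; assumes
-// 64-byte AQL packets and a power-of-two queue size, names illustrative):
-//   const uint64_t packet_id = base_packet_id + cmd->header.packet_offset;
-//   iree_hsa_kernel_dispatch_packet_t* packet =
-//       queue_base + (packet_id & (queue_size - 1));
-//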
-// Of course, the spec _could_ be wrong and not match reality. At this point I'm
-// _sure_ it doesn't match on many closely related aspects. Let's see what we
-// can get away with for the next short while. I immediately recognize I will
-// regret both making this implementation decision and writing these words.
-// Future me: deal ;) Workarounds are to issue all from a single thread or issue
-// with chunked reservations per thread. We could also do an atomic MAX on the
-// write index to keep the window between nearly-populated packets (those who
-// have bumped the write index but not yet changed from INVALID) and
-// fully-populated packets (with INVALID transitioned to a packet type) as small
-// as possible, however PCIe atomics don't support MAX and though we _shouldn't_
-// need that here (device->device) it does close off some design space we may
-// need to explore to get around _other_ wonkiness in various hardware (I'm
-// looking at you, HDP ಠ_ಠ).
-IREE_AMDGPU_ATTRIBUTE_KERNEL void iree_hal_amdgpu_device_cmd_block_issue(
-    iree_hal_amdgpu_device_execution_state_t* IREE_AMDGPU_RESTRICT state,
-    const iree_hal_amdgpu_device_command_block_t* IREE_AMDGPU_RESTRICT block,
-    const uint64_t base_packet_id) {
-  // Each invocation handles a single command in the block.
-  const uint32_t command_ordinal = iree_hal_amdgpu_device_global_id_x();
-  if (command_ordinal >= block->command_count) return;
-  return iree_hal_amdgpu_device_cmd_issue(state, block, command_ordinal,
-                                          base_packet_id);  // ideally tail call
-}
-
-// Issues a single command packet in a block.
-//
-// NOTE: this should be tail-called to avoid non-trivial stack management
-// overhead (as the AMD LLVMGPU backend is very poor at function calls).
-// Failing to tail-call in a way LLVM can recognize can easily double binary
-// size. Unfortunately today this often doesn't happen due to ABI
-// mismatch (the caller is device_kernel, the target is cdecl). We should find a
-// way to make these line up so the caller can turn into a jump instead of
-// having to deal with 👻 function calls 👻.
-static void iree_hal_amdgpu_device_cmd_issue(
-    iree_hal_amdgpu_device_execution_state_t* IREE_AMDGPU_RESTRICT state,
-    const iree_hal_amdgpu_device_command_block_t* IREE_AMDGPU_RESTRICT block,
-    uint32_t command_ordinal, uint64_t base_packet_id) {
-  // When device control or dispatch tracing is enabled we need to pass a query
-  // signal with any work we do. Prior to the block starting execution we
-  // acquire a range for all commands on the scheduler queue and store it in
-  // state->trace_block_query_base_id. Here we then take that base ID and add a
-  // relative offset that was precomputed when the command buffer was recorded.
-  // This allows us to support sparse/partial queries and still issue in
-  // parallel while respecting the required query ordering.
-  //
-  // There's probably a much simpler way of doing this - not needing all this
-  // branching per command or the precomputed query map would be nice.
-  iree_hal_amdgpu_trace_execution_query_id_t execution_query_id =
-      IREE_HAL_AMDGPU_TRACE_EXECUTION_QUERY_ID_INVALID;
-#if IREE_HAL_AMDGPU_HAS_TRACING_FEATURE( \
-    IREE_HAL_AMDGPU_TRACING_FEATURE_DEVICE_CONTROL)
-  const iree_hal_amdgpu_device_command_query_id_t command_query_id =
-      block->query_map.query_ids[command_ordinal];
-  if ((state->flags & IREE_HAL_AMDGPU_DEVICE_EXECUTION_FLAG_TRACE_DISPATCH) &&
-      command_query_id.dispatch_id !=
-          IREE_HAL_AMDGPU_TRACE_EXECUTION_QUERY_ID_INVALID) {
-    execution_query_id = iree_hal_amdgpu_device_query_ringbuffer_query_id(
-        &state->trace_buffer->query_ringbuffer,
-        state->trace_block_query_base_id + command_query_id.dispatch_id);
-  } else if ((state->flags &
-              IREE_HAL_AMDGPU_DEVICE_EXECUTION_FLAG_TRACE_CONTROL) &&
-             command_query_id.control_id !=
-                 IREE_HAL_AMDGPU_TRACE_EXECUTION_QUERY_ID_INVALID) {
-    execution_query_id = iree_hal_amdgpu_device_query_ringbuffer_query_id(
-        &state->trace_buffer->query_ringbuffer,
-        state->trace_block_query_base_id + command_query_id.control_id);
-  }
-#endif  // IREE_HAL_AMDGPU_TRACING_FEATURE_DEVICE_CONTROL
-
-  // Tail-call into the command handler.
-  const iree_hal_amdgpu_device_cmd_t* IREE_AMDGPU_RESTRICT cmd =
-      &block->commands[command_ordinal];
-  const uint64_t packet_id = base_packet_id + cmd->header.packet_offset;
-  switch (cmd->header.type) {
-    default:
-      return;  // no-op
-    case IREE_HAL_AMDGPU_DEVICE_CMD_DEBUG_GROUP_BEGIN:
-      return iree_hal_amdgpu_device_cmd_debug_group_begin_issue(
-          state, block,
-          (const iree_hal_amdgpu_device_cmd_debug_group_begin_t*)cmd, packet_id,
-          execution_query_id);
-    case IREE_HAL_AMDGPU_DEVICE_CMD_DEBUG_GROUP_END:
-      return iree_hal_amdgpu_device_cmd_debug_group_end_issue(
-          state, block,
-          (const iree_hal_amdgpu_device_cmd_debug_group_end_t*)cmd, packet_id,
-          execution_query_id);
-    case IREE_HAL_AMDGPU_DEVICE_CMD_BARRIER:
-      return iree_hal_amdgpu_device_cmd_barrier_issue(
-          state, block, (const iree_hal_amdgpu_device_cmd_barrier_t*)cmd,
-          packet_id, execution_query_id);
-    case IREE_HAL_AMDGPU_DEVICE_CMD_SIGNAL_EVENT:
-      return iree_hal_amdgpu_device_cmd_signal_event_issue(
-          state, block, (const iree_hal_amdgpu_device_cmd_signal_event_t*)cmd,
-          packet_id, execution_query_id);
-    case IREE_HAL_AMDGPU_DEVICE_CMD_RESET_EVENT:
-      return iree_hal_amdgpu_device_cmd_reset_event_issue(
-          state, block, (const iree_hal_amdgpu_device_cmd_reset_event_t*)cmd,
-          packet_id, execution_query_id);
-    case IREE_HAL_AMDGPU_DEVICE_CMD_WAIT_EVENTS:
-      return iree_hal_amdgpu_device_cmd_wait_events_issue(
-          state, block, (const iree_hal_amdgpu_device_cmd_wait_events_t*)cmd,
-          packet_id, execution_query_id);
-    case IREE_HAL_AMDGPU_DEVICE_CMD_FILL_BUFFER:
-      return iree_hal_amdgpu_device_cmd_fill_buffer_issue(
-          state, block, (const iree_hal_amdgpu_device_cmd_fill_buffer_t*)cmd,
-          packet_id, execution_query_id);
-    case IREE_HAL_AMDGPU_DEVICE_CMD_COPY_BUFFER:
-      return iree_hal_amdgpu_device_cmd_copy_buffer_issue(
-          state, block, (const iree_hal_amdgpu_device_cmd_copy_buffer_t*)cmd,
-          packet_id, execution_query_id);
-    case IREE_HAL_AMDGPU_DEVICE_CMD_DISPATCH:
-      return iree_hal_amdgpu_device_cmd_dispatch_issue(
-          state, block, (const iree_hal_amdgpu_device_cmd_dispatch_t*)cmd,
-          packet_id, execution_query_id);
-    case IREE_HAL_AMDGPU_DEVICE_CMD_DISPATCH_INDIRECT_DYNAMIC:
-      return iree_hal_amdgpu_device_cmd_dispatch_indirect_dynamic_issue(
-          state, block, (const iree_hal_amdgpu_device_cmd_dispatch_t*)cmd,
-          packet_id, execution_query_id);
-    case IREE_HAL_AMDGPU_DEVICE_CMD_BRANCH:
-      return iree_hal_amdgpu_device_cmd_branch_issue(
-          state, block, (const iree_hal_amdgpu_device_cmd_branch_t*)cmd,
-          packet_id, execution_query_id);
-    case IREE_HAL_AMDGPU_DEVICE_CMD_COND_BRANCH:
-      return iree_hal_amdgpu_device_cmd_cond_branch_issue(
-          state, block, (const iree_hal_amdgpu_device_cmd_cond_branch_t*)cmd,
-          packet_id, execution_query_id);
-    case IREE_HAL_AMDGPU_DEVICE_CMD_RETURN:
-      return iree_hal_amdgpu_device_cmd_return_issue(
-          state, block, (const iree_hal_amdgpu_device_cmd_return_t*)cmd,
-          packet_id, execution_query_id);
-  }
-  // NOTE: we need the above switch to end in tail calls in all cases. It
-  // doesn't today. But it should. If the stars align and we can make that
-  // happen it eliminates the (currently) only requirement for expanded shared
-  // memory faults in the entire library. At last check it also would halve the
-  // binary size: the AMD LLVMGPU backend inlines this function in any kernel
-  // that performs an issue.
-}
diff --git a/runtime/src/iree/hal/drivers/amdgpu/device/command_buffer.h b/runtime/src/iree/hal/drivers/amdgpu/device/command_buffer.h
deleted file mode 100644
index 4d3c212..0000000
--- a/runtime/src/iree/hal/drivers/amdgpu/device/command_buffer.h
+++ /dev/null
@@ -1,994 +0,0 @@
-// Copyright 2025 The IREE Authors
-//
-// Licensed under the Apache License v2.0 with LLVM Exceptions.
-// See https://llvm.org/LICENSE.txt for license information.
-// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
-
-#ifndef IREE_HAL_DRIVERS_AMDGPU_DEVICE_COMMAND_BUFFER_H_
-#define IREE_HAL_DRIVERS_AMDGPU_DEVICE_COMMAND_BUFFER_H_
-
-#include "iree/hal/drivers/amdgpu/device/blit.h"
-#include "iree/hal/drivers/amdgpu/device/buffer.h"
-#include "iree/hal/drivers/amdgpu/device/kernels.h"
-#include "iree/hal/drivers/amdgpu/device/support/common.h"
-#include "iree/hal/drivers/amdgpu/device/support/queue.h"
-#include "iree/hal/drivers/amdgpu/device/tracing.h"
-
-typedef struct iree_hal_amdgpu_device_queue_scheduler_t
-    iree_hal_amdgpu_device_queue_scheduler_t;
-
-//===----------------------------------------------------------------------===//
-// Device-side Command Buffer
-//===----------------------------------------------------------------------===//
-//
-// Command buffers are represented by a host-side wrapper that implements the
-// IREE HAL API and a device-side data structure holding the recorded contents.
-// All information required to execute a command buffer lives on the device and
-// a command buffer can be submitted from the device without host involvement.
-// Command buffer data structures are immutable once constructed and can be
-// executed concurrently and repeatedly based on the same recording because
-// mutable execution state is stored separately as part of the issuing queue
-// operation. Though the device-side recorded command buffer closely follows the
-// HAL command buffer API it does not need to match 1:1. The initial
-// implementation bakes out a lot of information as part of recording but leaves
-// AQL packet construction to issue-time; future iterations could blend the two
-// approaches by preconstructing packets to copy when profitable.
-//
-// The recorded command buffer is partitioned into one or more command blocks.
-// Each block represents a yieldable point in the execution where the command
-// buffer scheduler is allowed to suspend processing. Segmenting allows for
-// basic control flow to be implemented within a command buffer by skipping,
-// branching, or looping over blocks and also enables execution when hardware
-// queues may not have capacity for the entire command buffer. Conceptually
-// command buffers are like coroutines/fibers in that any number may be
-// simultaneously executing the same program on the same hardware resources with
-// independent states.
-//
-// +----------------------------------+
-// | iree_hal_amdgpu_command_buffer_t | (host)
-// +-----------------------v----------+
-//                         |   +-----------------------------------------+
-//                (device) +---> iree_hal_amdgpu_device_command_buffer_t |
-//                             +------------------v----------------------+
-//                                                |
-//      +------------+------------+------------+--+------+--+------------+
-//      |            |            |            |            |            |
-// +----v----+  +----v----+  +----v----+  +----v----+  +----v----+  +----v----+
-// |  block  |..|  block  |..|  block  |..|  block  |..|  block  |..|  block  |
-// +----v----+  +---------+  +---------+  +---------+  +---------+  +---------+
-//      |
-//      |    +------------------------------+
-//      +----> command entry[]              | fixed-length struct array
-//      |    +------------------------------+
-//      +----> embedded command data...     | variable length packed buffer
-//           +------------------------------+
-//
-// Each block contains one or more commands encoded in fixed-length entries.
-// This allows commands to be indexed by ordinal within the block such that
-// command processing can be parallelized. An extra buffer is used to embed any
-// variable-length data the commands require in read-only memory such as update
-// source buffers, dispatch constants, and dispatch binding references.
-// Execution-invariant information is stored in the command and any
-// execution-dependent information is stored as either deltas/relative values or
-// bits that can be used to derive the information when the command is issued.
-//
-// Execution starts with the first block and progresses through blocks based
-// on control commands. Blocks are translated to AQL packets in parallel via
-// the command buffer block issue kernel. Each command may translate to one or
-// more AQL packets and space is reserved for the maximum potential AQL packets
-// that are required when the block is launched. Execution uses a state
-// structure that resides on device and is valid for the full duration of the
-// command buffer execution. Every concurrently executing instance of a command
-// buffer has its own state referencing its own kernel arguments buffer. Note
-// that any optional AQL packet not used must still be set to a valid no-op
-// packet in order for the command processor to progress through the AQL queue.
-//
-// Processing behavior:
-//   1. Initialize iree_hal_amdgpu_device_execution_state_t:
-//     a. Allocate execution state from the queue ringbuffer
-//     b. Assign the target hardware AQL queue to receive packets
-//     c. Reserve kernel arguments buffer with max size used by any block
-//     d. Copy binding table into the state (if present)
-//     e. Assign the first command buffer block as the entry block
-//   2. Enqueue iree_hal_amdgpu_device_cmd_block_issue:
-//     a. Reserve queue space for all command AQL packets
-//     b. Enqueue command processor kernel for the next block with barrier bit
-//   3. Command processor, parallelized over each command in a block:
-//     a. Assign/copy kernel arguments to scratch buffer (if needed)
-//     b. Construct AQL packet(s) for the command
-//     c. Change from type INVALID to the real type
-//   4. Repeat 2 and 3 until all blocks completed
-//   5. Enqueue top-level queue scheduler upon completion
-//   6. Deinitialize execution state (release resources)
-//
-//=----------------------------------------------------------------------------=
-//
-// Command buffer scheduling is always performed on the scheduler queue.
-// Execution of the commands is allowed to target another queue dedicated to
-// execution. When using multiple queues it's possible for the hardware to begin
-// executing the initial commands while the rest of the commands are still being
-// issued. This also allows the thread-compatible tracing logic to operate in
-// single-threaded mode with the scheduler queue being the only one producing
-// the synchronous "CPU" trace events while the execution queue produces the
-// asynchronous "GPU" trace events.
-//
-//              +=========+  +-------------+                      +--------+
-// scheduler q: | execute |->| issue block |...                ...| retire |
-//              +=========+  +|-|-|-|-|-|-|+                      +--------+
-//                            \ \ \ \ \ \ \                       ^
-//                             v v v v v v v                     /
-//                             +-----+-----+-----+-----+-----+--|--+
-// execution q:                | cmd | cmd | cmd | cmd | cmd | ret |
-//                             +-----+-----+-----+-----+-----+-----+
-//
-// The additional scheduler/execution queue hops between the command buffer
-// execution request, each block, and the retire are insignificant compared to
-// the actual execution time _when command buffers are large_ and it allows us
-// to use queue priorities to ensure that scheduling runs ASAP even if the
-// execution queue is heavily utilized. It also allows us to have one scheduler
-// target multiple execution queues for concurrent command buffer processing or
-// multiple schedulers target a single execution queue to ensure it is always
-// utilized. It is not ideal for many small command buffers: there are currently
-// some naive attempts at mitigating the latency (such as serially issuing
-// small blocks) but more work would be required to optimize for that case.
-// CUDA stream-like APIs are not the target here (as they aren't good) and we
-// align more with the graph-based approaches of modern CUDA (though that too
-// has issues). Latency on tiny command buffers (1-8 commands) is really the
-// major flaw of this implementation. I'll probably end up having to solve that
-// in the future with more special-casing or scheduler-level carve-outs. Hooray.
-//
-//=----------------------------------------------------------------------------=
-//
-// Command buffers are recorded with a forward progress guarantee ensuring that
-// once issued they will complete even if no other work can be executed on the
-// same queue. Events used within the command buffer have a signal-before-wait
-// requirement when used on the same queue.
-//
-// Dispatches have their kernel arguments packed while their packets are
-// constructed and enqueued. Some arguments are fixed (constants, directly
-// referenced buffers) and copied directly from the command data buffer while
-// others may be substituted with per-invocation state (indirectly referenced
-// buffers from a binding table).
-//
-// Though most AQL packets are written once during their initial enqueuing some
-// commands such as indirect dispatches require updating the packets after they
-// have been placed in the target queue. Indirect dispatch parameters may either
-// be declared static and captured at the start of command buffer processing
-// or dynamic until immediately prior to when the particular dispatch is
-// executed. Static parameters are preferred as the command scheduler can
-// enqueue the dispatch packet by dereferencing the workgroups buffer while
-// constructing the AQL packet. Dynamic parameters require dispatching a special
-// fixup kernel immediately prior to the actual dispatch that does the
-// indirection and updates the following packet in the queue. The AQL queue
-// processing model is exploited by having the actual dispatch packet encoded as
-// INVALID and thus halting the hardware command processor and the fixup
-// dispatch is what switches it to a valid KERNEL_DISPATCH type.
-//
-// Disclaimer: that's how it's _supposed_ to work. I've discovered that all is
-// not what it declares to be in AQL-land. Static indirect dispatch is always
-// possible but the dynamic indirect dispatch with packet fixup may need some
-// big hammers to work around: in the worst case we put them in their own blocks
-// and treat indirect dispatches as branches that route back through the
-// scheduler to ensure each block has all indirect parameters available when the
-// block is issued... 🤞
-//
-//=----------------------------------------------------------------------------=
-//
-// AQL agents launch packets in order but may complete them in any order.
-// The two mechanisms of controlling the launch timeline are the barrier bit and
-// barrier packets. When set on a packet the barrier bit indicates that all
-// prior work on the queue must complete before the packet can be launched and
-// that behavior matches our HAL execution barrier. Barrier packets can be used
-// to set up dependencies via HSA signals roughly matching our HAL events.
-//
-// When a command buffer is recorded we use the execution barrier commands to
-// set the barrier bit on recorded packets and in many cases end up with no
-// additional barrier packets:
-//  +------------+
-//  | barrier    |      (no aql packet needed)
-//  +------------+
-//  | dispatch A |  --> AQL dispatch w/ barrier = true (await all prior)
-//  +------------+
-//  | barrier    |      (no aql packet needed)
-//  +------------+
-//  | dispatch B |  --> AQL dispatch w/ barrier = true (await dispatch A)
-//  +------------+
-//
-// In cases of concurrency a nop packet is needed to allow multiple dispatches
-// to launch without blocking. The complication is that at the time we get the
-// execution barrier command we don't know how many commands will follow before
-// the next barrier. To support single-pass recording we do some tricks with
-// moving packets in order to insert barrier packets as required:
-//  +------------+
-//  | dispatch A |  --> dispatch w/ barrier = true (await all prior)
-//  +------------+
-//  | dispatch B |  --> dispatch w/ barrier = false (execute concurrently)
-//  +------------+
-//
-// Fence acquire/release behavior is supported on nop barrier packets
-// allowing for commands on either side to potentially avoid setting the
-// behavior themselves. For example in serialized cases without the barrier
-// packets the dispatches would need to acquire/release:
-//  +------------+
-//  | dispatch A |  --> acquire (as needed), release AGENT
-//  +------------+
-//  | dispatch B |  --> acquire AGENT, release (as needed)
-//  +------------+
-// While the barrier packet can allow this to be avoided:
-//  +------------+
-//  | barrier    |  --> acquire (as needed), release AGENT
-//  +------------+
-//  | dispatch A |  --> acquire/release NONE
-//  +------------+
-//  | dispatch B |  --> acquire/release NONE
-//  +------------+
-//  | barrier    |  --> release (as needed)
-//  +------------+
-// The recording logic is more complex than desired but by figuring it out at
-// record-time the command buffer logic running here on device is kept much
-// more straightforward.
-//
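-// As a sketch, a serialized dispatch recorded with
-// IREE_HAL_AMDGPU_DEVICE_CMD_FLAG_QUEUE_AWAIT_BARRIER and agent-scope fences
-// would end up with an AQL header roughly like the following (bit positions
-// per the HSA packet header; the IREE_HSA_* names are assumed to mirror the
-// HSA spec):
-//   uint16_t header =
-//       (IREE_HSA_PACKET_TYPE_KERNEL_DISPATCH << IREE_HSA_PACKET_HEADER_TYPE) |
-//       (1u << IREE_HSA_PACKET_HEADER_BARRIER) |
-//       (IREE_HSA_FENCE_SCOPE_AGENT
-//        << IREE_HSA_PACKET_HEADER_SCACQUIRE_FENCE_SCOPE) |
-//       (IREE_HSA_FENCE_SCOPE_AGENT
-//        << IREE_HSA_PACKET_HEADER_SCRELEASE_FENCE_SCOPE);
-//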
-// TODO(benvanik): define how events map to packets.
-
-//===----------------------------------------------------------------------===//
-// iree_hal_amdgpu_device_cmd_t
-//===----------------------------------------------------------------------===//
-
-// Defines the recorded command type.
-// Note that commands may expand to zero or more AQL packets in the target
-// execution queue as they may be routed to other queues or require multiple
-// packets to complete.
-typedef uint8_t iree_hal_amdgpu_device_cmd_type_t;
-enum iree_hal_amdgpu_device_cmd_type_e {
-  // iree_hal_amdgpu_device_cmd_debug_group_begin_t
-  IREE_HAL_AMDGPU_DEVICE_CMD_DEBUG_GROUP_BEGIN = 0u,
-  // iree_hal_amdgpu_device_cmd_debug_group_end_t
-  IREE_HAL_AMDGPU_DEVICE_CMD_DEBUG_GROUP_END,
-  // iree_hal_amdgpu_device_cmd_barrier_t
-  IREE_HAL_AMDGPU_DEVICE_CMD_BARRIER,
-  // iree_hal_amdgpu_device_cmd_signal_event_t
-  IREE_HAL_AMDGPU_DEVICE_CMD_SIGNAL_EVENT,
-  // iree_hal_amdgpu_device_cmd_reset_event_t
-  IREE_HAL_AMDGPU_DEVICE_CMD_RESET_EVENT,
-  // iree_hal_amdgpu_device_cmd_wait_events_t
-  IREE_HAL_AMDGPU_DEVICE_CMD_WAIT_EVENTS,
-  // iree_hal_amdgpu_device_cmd_fill_buffer_t
-  IREE_HAL_AMDGPU_DEVICE_CMD_FILL_BUFFER,
-  // iree_hal_amdgpu_device_cmd_copy_buffer_t
-  IREE_HAL_AMDGPU_DEVICE_CMD_COPY_BUFFER,
-  // iree_hal_amdgpu_device_cmd_dispatch_t
-  IREE_HAL_AMDGPU_DEVICE_CMD_DISPATCH,
-  // iree_hal_amdgpu_device_cmd_dispatch_t
-  IREE_HAL_AMDGPU_DEVICE_CMD_DISPATCH_INDIRECT_DYNAMIC,
-  // iree_hal_amdgpu_device_cmd_branch_t
-  IREE_HAL_AMDGPU_DEVICE_CMD_BRANCH,
-  // iree_hal_amdgpu_device_cmd_cond_branch_t
-  IREE_HAL_AMDGPU_DEVICE_CMD_COND_BRANCH,
-  // iree_hal_amdgpu_device_cmd_return_t
-  IREE_HAL_AMDGPU_DEVICE_CMD_RETURN,
-  // TODO(benvanik): trace flush block for intra-block query/sampling resets.
-  // Today we assume command blocks fit within the query pool size.
-  IREE_HAL_AMDGPU_DEVICE_CMD_MAX = IREE_HAL_AMDGPU_DEVICE_CMD_RETURN,
-};
-
-enum {
-  IREE_HAL_AMDGPU_DEVICE_CMD_FLAG_FENCE_ACQUIRE_BIT = 1,  // occupies bits 1-2
-  IREE_HAL_AMDGPU_DEVICE_CMD_FLAG_FENCE_RELEASE_BIT = 3,  // occupies bits 3-4
-};
-
-// Flags controlling command processing behavior.
-typedef uint8_t iree_hal_amdgpu_device_cmd_flags_t;
-enum iree_hal_amdgpu_device_cmd_flag_bits_t {
-  IREE_HAL_AMDGPU_DEVICE_CMD_FLAG_NONE = 0u,
-
-  // Sets the barrier bit in the first AQL packet of the command in order to
-  // force a wait on all prior packets to complete before processing the command
-  // packets. This is much lighter weight than barriers and signals for the
-  // common case of straight-line execution.
-  IREE_HAL_AMDGPU_DEVICE_CMD_FLAG_QUEUE_AWAIT_BARRIER = 1u << 0,
-
-  // Sets HSA_FENCE_SCOPE_AGENT on the AQL packet acquire scope.
-  // This invalidates the I/K/L1 caches.
-  IREE_HAL_AMDGPU_DEVICE_CMD_FLAG_FENCE_ACQUIRE_AGENT =
-      IREE_HSA_FENCE_SCOPE_AGENT
-      << IREE_HAL_AMDGPU_DEVICE_CMD_FLAG_FENCE_ACQUIRE_BIT,
-  // Sets HSA_FENCE_SCOPE_SYSTEM on the AQL packet acquire scope.
-  // This invalidates the L1/L2 caches and flushes L2.
-  IREE_HAL_AMDGPU_DEVICE_CMD_FLAG_FENCE_ACQUIRE_SYSTEM =
-      IREE_HSA_FENCE_SCOPE_SYSTEM
-      << IREE_HAL_AMDGPU_DEVICE_CMD_FLAG_FENCE_ACQUIRE_BIT,
-
-  // Sets HSA_FENCE_SCOPE_AGENT on the AQL packet release scope.
-  // This flushes the I/K/L1 caches.
-  IREE_HAL_AMDGPU_DEVICE_CMD_FLAG_FENCE_RELEASE_AGENT =
-      IREE_HSA_FENCE_SCOPE_AGENT
-      << IREE_HAL_AMDGPU_DEVICE_CMD_FLAG_FENCE_RELEASE_BIT,
-  // Sets HSA_FENCE_SCOPE_SYSTEM on the AQL packet release scope.
-  // This flushes the L1/L2 caches.
-  IREE_HAL_AMDGPU_DEVICE_CMD_FLAG_FENCE_RELEASE_SYSTEM =
-      IREE_HSA_FENCE_SCOPE_SYSTEM
-      << IREE_HAL_AMDGPU_DEVICE_CMD_FLAG_FENCE_RELEASE_BIT,
-};
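-
-// Example (illustrative): a command that must await all prior packets and
-// make its writes visible at agent scope would record:
-//   flags = IREE_HAL_AMDGPU_DEVICE_CMD_FLAG_QUEUE_AWAIT_BARRIER |
-//           IREE_HAL_AMDGPU_DEVICE_CMD_FLAG_FENCE_RELEASE_AGENT;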
-
-// Commands are fixed-size to allow for indexing into an array of commands.
-// Additional variable-length data is stored out-of-band of the command struct.
-#define IREE_HAL_AMDGPU_DEVICE_CMD_SIZE 64
-
-// Header at the start of every command used to control command processing.
-typedef struct iree_hal_amdgpu_device_cmd_header_t {
-  // Command type indicating the parent structure.
-  iree_hal_amdgpu_device_cmd_type_t type;
-  // Flags controlling command processing behavior.
-  iree_hal_amdgpu_device_cmd_flags_t flags;
-  // Offset into the queue where AQL packets for the command should be placed.
-  // If more than one packet is required they are stored contiguously from the
-  // base offset.
-  uint16_t packet_offset;
-} iree_hal_amdgpu_device_cmd_header_t;
-static_assert(sizeof(iree_hal_amdgpu_device_cmd_header_t) == 4,
-              "header should be small as it's embedded in every command");
-
-// Pushes a new debug group to the stack.
-// All trace zones emitted between this and the corresponding
-// iree_hal_amdgpu_device_cmd_debug_group_end_t command will be nested within.
-//
-// NOTE: the pointers used in the command are in the host address space. This is
-// wonky, but the host trace buffer translation checks first to see if the
-// address is in the expected range of device pointers and otherwise passes it
-// right through.
-//
-// Recorded by:
-//  iree_hal_command_buffer_begin_debug_group
-typedef struct IREE_AMDGPU_ALIGNAS(64)
-    iree_hal_amdgpu_device_cmd_debug_group_begin_t {
-  iree_hal_amdgpu_device_cmd_header_t header;
-  // Source location pointer, if available. May be in the host address space.
-  iree_hal_amdgpu_trace_src_loc_ptr_t src_loc;
-  // Label for the group.
-  // Value must be a pointer to a process-lifetime string literal. The host-side
-  // command buffer recorder should perform interning if required.
-  uint64_t label_literal;
-  // Length of the label_literal in characters.
-  uint32_t label_literal_length;
-  // Color of the group. 0 indicates unspecified/default.
-  iree_hal_amdgpu_trace_color_t color;
-} iree_hal_amdgpu_device_cmd_debug_group_begin_t;
-#if IREE_HAL_AMDGPU_HAS_TRACING_FEATURE( \
-    IREE_HAL_AMDGPU_TRACING_FEATURE_DEVICE_CONTROL)
-#define IREE_HAL_AMDGPU_DEVICE_CMD_DEBUG_GROUP_BEGIN_AQL_PACKET_COUNT 1
-#else
-#define IREE_HAL_AMDGPU_DEVICE_CMD_DEBUG_GROUP_BEGIN_AQL_PACKET_COUNT 0
-#endif  // IREE_HAL_AMDGPU_TRACING_FEATURE_DEVICE_CONTROL
-
-// Pops the current debug group from the stack.
-//
-// Recorded by:
-//  iree_hal_command_buffer_end_debug_group
-typedef struct IREE_AMDGPU_ALIGNAS(64)
-    iree_hal_amdgpu_device_cmd_debug_group_end_t {
-  iree_hal_amdgpu_device_cmd_header_t header;
-} iree_hal_amdgpu_device_cmd_debug_group_end_t;
-#if IREE_HAL_AMDGPU_HAS_TRACING_FEATURE( \
-    IREE_HAL_AMDGPU_TRACING_FEATURE_DEVICE_CONTROL)
-#define IREE_HAL_AMDGPU_DEVICE_CMD_DEBUG_GROUP_END_AQL_PACKET_COUNT 1
-#else
-#define IREE_HAL_AMDGPU_DEVICE_CMD_DEBUG_GROUP_END_AQL_PACKET_COUNT 0
-#endif  // IREE_HAL_AMDGPU_TRACING_FEATURE_DEVICE_CONTROL
-
-// Performs a full queue barrier causing subsequent commands to block until all
-// prior commands have completed. This is effectively a no-op packet that just
-// has the IREE_HAL_AMDGPU_DEVICE_CMD_FLAG_QUEUE_AWAIT_BARRIER bit set but could
-// be used to perform coarse synchronization (acquire/release agent/system,
-// etc).
-//
-// Recorded by:
-//  iree_hal_command_buffer_execution_barrier (sometimes)
-typedef struct IREE_AMDGPU_ALIGNAS(64) iree_hal_amdgpu_device_cmd_barrier_t {
-  iree_hal_amdgpu_device_cmd_header_t header;
-} iree_hal_amdgpu_device_cmd_barrier_t;
-#define IREE_HAL_AMDGPU_DEVICE_CMD_BARRIER_AQL_PACKET_COUNT 1
-
-// TODO(benvanik): rework events so that they can be reused. We really should
-// have an events table-like thing or something that allows capture at time of
-// issue (if we even want to allow events to be used across command buffers).
-// Today events are similar to Vulkan ones which don't support concurrent issue
-// and that limits us here.
-//
-// Storing an ordinal to the event table would let us bulk allocate them as part
-// of the execution state. Recording would need to track the unique set of
-// events used in order to determine the capacity. We could make it be declared
-// similar to the binding table capacity and swap to recording with ordinals but
-// that makes it more difficult for users to compose. Recording could also only
-// support events created from the command buffer during recording
-// (iree_hal_command_buffer_acquire_event, etc) and that could also be used to
-// verify lifetime and invalid cross-command-buffer usage. The event handle
-// could just be an integer all the way into the compiler.
-//
-// For now the event-based code below uses an opaque value that we can
-// substitute with whatever we come up with.
-typedef uint32_t iree_hal_amdgpu_device_event_ordinal_t;
-
-// Signals event after prior commands complete.
-// The AQL signal will be decremented from a value of 1 to 0 to allow AQL
-// dependencies to be satisfied directly.
-//
-// Recorded by:
-//  iree_hal_command_buffer_signal_event
-typedef struct IREE_AMDGPU_ALIGNAS(64)
-    iree_hal_amdgpu_device_cmd_signal_event_t {
-  iree_hal_amdgpu_device_cmd_header_t header;
-  iree_hal_amdgpu_device_event_ordinal_t event;
-} iree_hal_amdgpu_device_cmd_signal_event_t;
-#define IREE_HAL_AMDGPU_DEVICE_CMD_SIGNAL_EVENT_AQL_PACKET_COUNT 1
-
-// Resets event to unsignaled after prior commands complete.
-// The AQL signal will be set to a value of 1.
-//
-// Recorded by:
-//  iree_hal_command_buffer_reset_event
-typedef struct IREE_AMDGPU_ALIGNAS(64)
-    iree_hal_amdgpu_device_cmd_reset_event_t {
-  iree_hal_amdgpu_device_cmd_header_t header;
-  iree_hal_amdgpu_device_event_ordinal_t event;
-} iree_hal_amdgpu_device_cmd_reset_event_t;
-#define IREE_HAL_AMDGPU_DEVICE_CMD_RESET_EVENT_AQL_PACKET_COUNT 1
-
-// Number of events that can be stored inline in a
-// iree_hal_amdgpu_device_cmd_wait_events_t command. This matches the capacity
-// of an AQL barrier-AND packet and allows us to avoid additional
-// storage/indirection in the common case of waiting on one or two events.
-#define IREE_HAL_AMDGPU_DEVICE_CMD_WAIT_EVENT_INLINE_CAPACITY 5
-
-// Waits for the given events to be signaled before proceeding.
-// All events must reach a value of 0. May be decomposed into multiple barrier
-// packets if the event count exceeds the capacity of the barrier-and packet.
-//
-// Recorded by:
-//  iree_hal_command_buffer_wait_events
-typedef struct IREE_AMDGPU_ALIGNAS(64)
-    iree_hal_amdgpu_device_cmd_wait_events_t {
-  iree_hal_amdgpu_device_cmd_header_t header;
-  // Number of events being waited upon.
-  uint32_t event_count;
-  union {
-    // Inlined events if event_count does not exceed
-    // IREE_HAL_AMDGPU_DEVICE_CMD_WAIT_EVENT_INLINE_CAPACITY.
-    iree_hal_amdgpu_device_event_ordinal_t
-        events[IREE_HAL_AMDGPU_DEVICE_CMD_WAIT_EVENT_INLINE_CAPACITY];
-    // Externally stored events if event_count is greater than
-    // IREE_HAL_AMDGPU_DEVICE_CMD_WAIT_EVENT_INLINE_CAPACITY.
-    iree_hal_amdgpu_device_event_ordinal_t* events_ptr;
-  };
-} iree_hal_amdgpu_device_cmd_wait_events_t;
-#define IREE_HAL_AMDGPU_DEVICE_CMD_WAIT_EVENTS_PER_AQL_PACKET 5
-#define IREE_HAL_AMDGPU_DEVICE_CMD_WAIT_EVENTS_AQL_PACKET_COUNT(event_count) \
-  IREE_AMDGPU_CEIL_DIV((event_count),                                        \
-                       IREE_HAL_AMDGPU_DEVICE_CMD_WAIT_EVENTS_PER_AQL_PACKET)
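-
-// Example (illustrative): a wait on 7 events decomposes into
-// IREE_AMDGPU_CEIL_DIV(7, 5) = 2 barrier-AND packets (5 events + 2 events).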
-
-// Fills a buffer with a repeating pattern.
-// Performed via a blit kernel.
-//
-// Recorded by:
-//  iree_hal_command_buffer_fill_buffer
-typedef struct IREE_AMDGPU_ALIGNAS(64)
-    iree_hal_amdgpu_device_cmd_fill_buffer_t {
-  iree_hal_amdgpu_device_cmd_header_t header;
-  // Block-relative kernel arguments address.
-  uint32_t kernarg_offset;
-  // Target buffer to fill.
-  iree_hal_amdgpu_device_buffer_ref_t target_ref;
-  // 1 to 8 pattern bytes, little endian.
-  uint64_t pattern;
-  // Length in bytes of the pattern.
-  uint8_t pattern_length;
-} iree_hal_amdgpu_device_cmd_fill_buffer_t;
-#define IREE_HAL_AMDGPU_DEVICE_CMD_FILL_BUFFER_AQL_PACKET_COUNT 1
-
-// Copies between buffers.
-// Performed via a blit kernel. May be implementable with SDMA but it is
-// currently unverified.
-//
-// Recorded by:
-//  iree_hal_command_buffer_copy_buffer
-typedef struct IREE_AMDGPU_ALIGNAS(64)
-    iree_hal_amdgpu_device_cmd_copy_buffer_t {
-  iree_hal_amdgpu_device_cmd_header_t header;
-  // Block-relative kernel arguments address.
-  uint32_t kernarg_offset;
-  // Copy source.
-  iree_hal_amdgpu_device_buffer_ref_t source_ref;
-  // Copy target.
-  iree_hal_amdgpu_device_buffer_ref_t target_ref;
-} iree_hal_amdgpu_device_cmd_copy_buffer_t;
-#define IREE_HAL_AMDGPU_DEVICE_CMD_COPY_BUFFER_AQL_PACKET_COUNT 1
-
-// Bitfield specifying flags controlling a dispatch operation.
-typedef uint16_t iree_hal_amdgpu_device_dispatch_flags_t;
-enum iree_hal_amdgpu_device_dispatch_flag_bits_t {
-  IREE_HAL_AMDGPU_DEVICE_DISPATCH_FLAG_NONE = 0,
-  // Dispatch requires the iree_amdgpu_kernel_implicit_args_t kernargs to be
-  // appended (with 8-byte alignment) to the kernargs.
-  IREE_HAL_AMDGPU_DEVICE_DISPATCH_FLAG_IMPLICIT_ARGS = 1u << 0,
-  // Dispatch uses an indirect workgroup count that is constant and available
-  // prior to command buffer execution. The command processor will read the
-  // workgroup count and embed it directly in the AQL kernel dispatch packet.
-  IREE_HAL_AMDGPU_DEVICE_DISPATCH_FLAG_INDIRECT_STATIC = 1u << 1,
-  // Dispatch uses an indirect workgroup count that is dynamic and may change up
-  // to the exact moment the dispatch is issued. The command processor will
-  // enqueue a kernel that performs the indirection and updates the kernel
-  // dispatch packet with the value before allowing the hardware queue to
-  // continue.
-  IREE_HAL_AMDGPU_DEVICE_DISPATCH_FLAG_INDIRECT_DYNAMIC = 1u << 2,
-};
-
-// Either directly embedded workgroup count XYZ dimensions or a thin buffer ref
-// pointing to a buffer containing the `uint32_t dims[3]` count.
-typedef union iree_hal_amdgpu_device_workgroup_count_t {
-  // XYZ dimensions of the grid, in workgroups. Must be greater than 0.
-  // If the grid has fewer than 3 dimensions the unused ones must be 1.
-  // Unused if the dispatch is indirect and instead the workgroups buffer
-  // reference in the parameters is used.
-  uint32_t dims[3];
-  // Optional buffer containing the workgroup count.
-  // Processing is controlled by the
-  // IREE_HAL_AMDGPU_DEVICE_DISPATCH_FLAG_INDIRECT_* flags.
-  iree_hal_amdgpu_device_workgroup_count_buffer_ref_t ref;
-} iree_hal_amdgpu_device_workgroup_count_t;
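-
-// Example (illustrative): a direct dispatch over a 64x1x1 grid records
-// dims = {64, 1, 1}; an indirect dispatch records ref instead and sets one of
-// the IREE_HAL_AMDGPU_DEVICE_DISPATCH_FLAG_INDIRECT_* flags.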
-
-// AQL/HAL dispatch parameters as recorded.
-// Some parameters may be overwritten as the packet is enqueued or during
-// execution (such as for indirect dispatches).
-typedef struct iree_hal_amdgpu_device_cmd_dispatch_config_t {
-  // Dispatch control flags.
-  iree_hal_amdgpu_device_dispatch_flags_t flags;
-  uint16_t reserved0;
-  // sharedMemBytes from the original dispatch. Added to the group_segment_size
-  // during packet production.
-  uint32_t dynamic_lds_size;
-  // Kernel arguments used to dispatch the kernel.
-  const iree_hal_amdgpu_device_kernel_args_t* kernel_args;
-  // Direct or indirect workgroup count based on the flags.
-  iree_hal_amdgpu_device_workgroup_count_t workgroup_count;
-} iree_hal_amdgpu_device_cmd_dispatch_config_t;
-static_assert(
-    sizeof(iree_hal_amdgpu_device_cmd_dispatch_config_t) == 32,
-    "dispatch packet template is inlined into cmd structs and must be small");
-#define IREE_HAL_AMDGPU_DEVICE_WORKGROUP_COUNT_UPDATE_KERNARG_SIZE \
-  (3 * sizeof(uint64_t))
-
-// Dispatches (directly or indirectly) a kernel.
-// All information required to build the AQL packet is stored within the command
-// such that it can be enqueued without additional indirection.
-//
-// Bindings and constants used by the dispatch are stored in an external data
-// segment and may not be adjacent to each other. It's possible for multiple
-// dispatches to share the same bindings and constants. Translation into
-// kernargs happens when the command is issued on device.
-//
-// Recorded by:
-//  iree_hal_command_buffer_dispatch
-typedef struct IREE_AMDGPU_ALIGNAS(64) iree_hal_amdgpu_device_cmd_dispatch_t {
-  iree_hal_amdgpu_device_cmd_header_t header;
-  // Block-relative kernel arguments address.
-  // This will be added to the per-execution base kernel arguments address
-  // during packet production.
-  //
-  // If the IREE_HAL_AMDGPU_DEVICE_DISPATCH_FLAG_INDIRECT_DYNAMIC bit is set
-  // then this will include an additional
-  // IREE_HAL_AMDGPU_DEVICE_WORKGROUP_COUNT_UPDATE_KERNARG_SIZE prefix that is
-  // used for dispatching the
-  // `iree_hal_amdgpu_device_command_buffer_workgroup_count_update` builtin
-  // kernel.
-  uint32_t kernarg_offset;
-  // AQL packet template and dispatch parameters.
-  iree_hal_amdgpu_device_cmd_dispatch_config_t config;
-  // References describing how binding pointers are passed to the kernel.
-  // References may be direct device pointers, allocations, or slots in the
-  // binding table included as part of the execution request.
-  const iree_hal_amdgpu_device_buffer_ref_t* bindings /*[binding_count]*/;
-  // Dispatch constants passed to the kernel.
-  const uint32_t* constants /*[constant_count]*/;
-} iree_hal_amdgpu_device_cmd_dispatch_t;
-#define IREE_HAL_AMDGPU_DEVICE_CMD_DISPATCH_DIRECT_AQL_PACKET_COUNT 1
-#define IREE_HAL_AMDGPU_DEVICE_CMD_DISPATCH_INDIRECT_STATIC_AQL_PACKET_COUNT 1
-#define IREE_HAL_AMDGPU_DEVICE_CMD_DISPATCH_INDIRECT_DYNAMIC_AQL_PACKET_COUNT 2
-#define IREE_HAL_AMDGPU_DEVICE_CMD_DISPATCH_AQL_PACKET_COUNT(dispatch_flags)         \
-  (((dispatch_flags) &                                                               \
-    IREE_HAL_AMDGPU_DEVICE_DISPATCH_FLAG_INDIRECT_STATIC) != 0)                      \
-      ? IREE_HAL_AMDGPU_DEVICE_CMD_DISPATCH_INDIRECT_STATIC_AQL_PACKET_COUNT         \
-      : ((((dispatch_flags) &                                                        \
-           IREE_HAL_AMDGPU_DEVICE_DISPATCH_FLAG_INDIRECT_DYNAMIC) != 0)              \
-             ? IREE_HAL_AMDGPU_DEVICE_CMD_DISPATCH_INDIRECT_DYNAMIC_AQL_PACKET_COUNT \
-             : IREE_HAL_AMDGPU_DEVICE_CMD_DISPATCH_DIRECT_AQL_PACKET_COUNT)
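-
-// Example (illustrative): a dispatch with
-// IREE_HAL_AMDGPU_DEVICE_DISPATCH_FLAG_INDIRECT_DYNAMIC occupies 2 packets
-// (the workgroup-count update builtin followed by the patched dispatch) while
-// direct and INDIRECT_STATIC dispatches occupy 1.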
-
-// TODO(benvanik): better specify control flow; maybe conditional support.
-// The current implementation is a placeholder for more sophisticated control
-// flow both within a command buffer (branching) and across command buffers
-// (calls). Calls will require nesting execution state and we may need to
-// preallocate that (a primary command buffer keeping track of the max nesting
-// depth).
-
-// Unconditionally branches from the current block to a new block within the
-// same command buffer.
-typedef struct IREE_AMDGPU_ALIGNAS(64) iree_hal_amdgpu_device_cmd_branch_t {
-  iree_hal_amdgpu_device_cmd_header_t header;
-  // Block-relative kernel arguments address.
-  uint32_t kernarg_offset;
-  // Block ordinal within the parent command buffer where execution will
-  // continue. The block pointer can be retrieved from the command buffer
-  // blocks list.
-  uint32_t target_block;
-} iree_hal_amdgpu_device_cmd_branch_t;
-#define IREE_HAL_AMDGPU_DEVICE_CMD_BRANCH_AQL_PACKET_COUNT 1
-#define IREE_HAL_AMDGPU_DEVICE_CMD_BRANCH_KERNARG_SIZE (2 * sizeof(uint64_t))
-#define IREE_HAL_AMDGPU_DEVICE_CMD_BRANCH_KERNARG_ALIGNMENT 8
-
-// Specifies a `*ref <cond> value` operation.
-typedef uint8_t iree_hal_amdgpu_device_cmd_cond_t;
-enum iree_hal_amdgpu_device_cmd_cond_e {
-  IREE_HAL_AMDGPU_DEVICE_CMD_COND_EQ = 0u,  // *ref == value
-  IREE_HAL_AMDGPU_DEVICE_CMD_COND_NE,       // *ref != value
-  IREE_HAL_AMDGPU_DEVICE_CMD_COND_SLT,      // *ref < value (signed)
-  IREE_HAL_AMDGPU_DEVICE_CMD_COND_SLE,      // *ref <= value (signed)
-  IREE_HAL_AMDGPU_DEVICE_CMD_COND_SGT,      // *ref > value (signed)
-  IREE_HAL_AMDGPU_DEVICE_CMD_COND_SGE,      // *ref >= value (signed)
-  IREE_HAL_AMDGPU_DEVICE_CMD_COND_ULT,      // *ref < value (unsigned)
-  IREE_HAL_AMDGPU_DEVICE_CMD_COND_ULE,      // *ref <= value (unsigned)
-  IREE_HAL_AMDGPU_DEVICE_CMD_COND_UGT,      // *ref > value (unsigned)
-  IREE_HAL_AMDGPU_DEVICE_CMD_COND_UGE,      // *ref >= value (unsigned)
-};
-
-// Conditionally branches from the current block to a new block within the
-// same command buffer based on the specified condition evaluated at the time
-// the command is executed.
-//
-// It is assumed that during recording any true_block==false_block conditions
-// have been turned into unconditional iree_hal_amdgpu_device_cmd_branch_t
-// commands.
-typedef struct IREE_AMDGPU_ALIGNAS(64)
-    iree_hal_amdgpu_device_cmd_cond_branch_t {
-  iree_hal_amdgpu_device_cmd_header_t header;
-  // Block-relative kernel arguments address.
-  uint32_t kernarg_offset;
-  // Buffer containing the uint64_t value used for the condition comparison.
-  iree_hal_amdgpu_device_uint64_buffer_ref_t ref;
-  // Conditional operation (`*ref cond value`).
-  iree_hal_amdgpu_device_cmd_cond_t cond;
-  uint8_t reserved[7];
-  // Value compared against using cond. Represented as signless bits; the
-  // interpretation of both `*ref` and this value is determined by `cond`.
-  uint64_t value;
-  // Block ordinal within the parent command buffer where execution will
-  // continue when the condition evaluates to true. The block pointer can be
-  // retrieved from the command buffer blocks list.
-  uint32_t true_block;
-  // Block ordinal within the parent command buffer where execution will
-  // continue when the condition evaluates to false. The block pointer can be
-  // retrieved from the command buffer blocks list.
-  uint32_t false_block;
-} iree_hal_amdgpu_device_cmd_cond_branch_t;
-#define IREE_HAL_AMDGPU_DEVICE_CMD_COND_BRANCH_AQL_PACKET_COUNT 1
-#define IREE_HAL_AMDGPU_DEVICE_CMD_COND_BRANCH_KERNARG_SIZE \
-  (6 * sizeof(uint64_t))
-#define IREE_HAL_AMDGPU_DEVICE_CMD_COND_BRANCH_KERNARG_ALIGNMENT 8
-
-// Returns from processing a command buffer by launching the scheduler.
-//
-// TODO(benvanik): differentiate return to scheduler from return to caller
-// command buffer. Today this always assumes the scheduler is going to be the
-// target.
-typedef struct IREE_AMDGPU_ALIGNAS(64) iree_hal_amdgpu_device_cmd_return_t {
-  iree_hal_amdgpu_device_cmd_header_t header;
-  // Block-relative kernel arguments address.
-  uint32_t kernarg_offset;
-} iree_hal_amdgpu_device_cmd_return_t;
-#define IREE_HAL_AMDGPU_DEVICE_CMD_RETURN_AQL_PACKET_COUNT 1
-#define IREE_HAL_AMDGPU_DEVICE_CMD_RETURN_KERNARG_SIZE (1 * sizeof(uint64_t))
-#define IREE_HAL_AMDGPU_DEVICE_CMD_RETURN_KERNARG_ALIGNMENT 8
-
-static_assert(sizeof(iree_hal_amdgpu_device_cmd_debug_group_begin_t) <=
-                  IREE_HAL_AMDGPU_DEVICE_CMD_SIZE,
-              "commands must fit within the fixed command size");
-static_assert(sizeof(iree_hal_amdgpu_device_cmd_debug_group_end_t) <=
-                  IREE_HAL_AMDGPU_DEVICE_CMD_SIZE,
-              "commands must fit within the fixed command size");
-static_assert(sizeof(iree_hal_amdgpu_device_cmd_barrier_t) <=
-                  IREE_HAL_AMDGPU_DEVICE_CMD_SIZE,
-              "commands must fit within the fixed command size");
-static_assert(sizeof(iree_hal_amdgpu_device_cmd_signal_event_t) <=
-                  IREE_HAL_AMDGPU_DEVICE_CMD_SIZE,
-              "commands must fit within the fixed command size");
-static_assert(sizeof(iree_hal_amdgpu_device_cmd_reset_event_t) <=
-                  IREE_HAL_AMDGPU_DEVICE_CMD_SIZE,
-              "commands must fit within the fixed command size");
-static_assert(sizeof(iree_hal_amdgpu_device_cmd_wait_events_t) <=
-                  IREE_HAL_AMDGPU_DEVICE_CMD_SIZE,
-              "commands must fit within the fixed command size");
-static_assert(sizeof(iree_hal_amdgpu_device_cmd_fill_buffer_t) <=
-                  IREE_HAL_AMDGPU_DEVICE_CMD_SIZE,
-              "commands must fit within the fixed command size");
-static_assert(sizeof(iree_hal_amdgpu_device_cmd_copy_buffer_t) <=
-                  IREE_HAL_AMDGPU_DEVICE_CMD_SIZE,
-              "commands must fit within the fixed command size");
-static_assert(sizeof(iree_hal_amdgpu_device_cmd_dispatch_t) <=
-                  IREE_HAL_AMDGPU_DEVICE_CMD_SIZE,
-              "commands must fit within the fixed command size");
-static_assert(sizeof(iree_hal_amdgpu_device_cmd_branch_t) <=
-                  IREE_HAL_AMDGPU_DEVICE_CMD_SIZE,
-              "commands must fit within the fixed command size");
-static_assert(sizeof(iree_hal_amdgpu_device_cmd_cond_branch_t) <=
-                  IREE_HAL_AMDGPU_DEVICE_CMD_SIZE,
-              "commands must fit within the fixed command size");
-static_assert(sizeof(iree_hal_amdgpu_device_cmd_return_t) <=
-                  IREE_HAL_AMDGPU_DEVICE_CMD_SIZE,
-              "commands must fit within the fixed command size");
-
-// A command describing an operation that may translate to zero or more AQL
-// packets.
-typedef struct IREE_AMDGPU_ALIGNAS(64) iree_hal_amdgpu_device_cmd_t {
-  union {
-    iree_hal_amdgpu_device_cmd_header_t header;
-    iree_hal_amdgpu_device_cmd_debug_group_begin_t debug_group_begin;
-    iree_hal_amdgpu_device_cmd_debug_group_end_t debug_group_end;
-    iree_hal_amdgpu_device_cmd_barrier_t barrier;
-    iree_hal_amdgpu_device_cmd_signal_event_t signal_event;
-    iree_hal_amdgpu_device_cmd_reset_event_t reset_event;
-    iree_hal_amdgpu_device_cmd_wait_events_t wait_events;
-    iree_hal_amdgpu_device_cmd_fill_buffer_t fill_buffer;
-    iree_hal_amdgpu_device_cmd_copy_buffer_t copy_buffer;
-    iree_hal_amdgpu_device_cmd_dispatch_t dispatch;
-    iree_hal_amdgpu_device_cmd_branch_t branch;
-    iree_hal_amdgpu_device_cmd_cond_branch_t cond_branch;
-    iree_hal_amdgpu_device_cmd_return_t ret /*urn*/;
-  };
-} iree_hal_amdgpu_device_cmd_t;
-static_assert(sizeof(iree_hal_amdgpu_device_cmd_t) <=
-                  IREE_HAL_AMDGPU_DEVICE_CMD_SIZE,
-              "commands must fit within the fixed command size");
-
-//===----------------------------------------------------------------------===//
-// iree_hal_amdgpu_device_command_buffer_t
-//===----------------------------------------------------------------------===//
-
-// Tracing query IDs used by a single command depending on tracing mode.
-// These IDs are relative to the command block they are referenced from and
-// added to whatever query ringbuffer base ID is used.
-//
-// Query IDs of 0xFFFF (IREE_HAL_AMDGPU_TRACE_EXECUTION_QUERY_ID_INVALID)
-// indicate that a particular command does not use a query ID.
-typedef struct iree_hal_amdgpu_device_command_query_id_t {
-  // Query ID used for the command when the control flag is set:
-  // IREE_HAL_AMDGPU_DEVICE_EXECUTION_FLAG_TRACE_CONTROL
-  iree_hal_amdgpu_trace_execution_query_id_t control_id;
-  // Query ID used for the command when the control+dispatch flag is set:
-  // IREE_HAL_AMDGPU_DEVICE_EXECUTION_FLAG_TRACE_DISPATCH
-  iree_hal_amdgpu_trace_execution_query_id_t dispatch_id;
-} iree_hal_amdgpu_device_command_query_id_t;
-static_assert(sizeof(iree_hal_amdgpu_device_command_query_id_t) == 4,
-              "query IDs interleaved/packed");
-
-// Information required to allocate and map commands to query IDs used with
-// tracing/profiling. The counts control how many unique query signals are
-// allocated from the query ringbuffer when issuing the block. The embedded ID
-// map is from each command to a relative query ID based on the ringbuffer's
-// returned base ID. Query IDs must not be reused within the same command block
-// as they are only captured during a flush.
-typedef struct iree_hal_amdgpu_device_command_query_map_t {
-  // Maximum number of queries used when in control mode:
-  // IREE_HAL_AMDGPU_DEVICE_EXECUTION_FLAG_TRACE_CONTROL
-  uint16_t max_control_query_count;
-  // Maximum number of queries used when in control+dispatch mode:
-  // IREE_HAL_AMDGPU_DEVICE_EXECUTION_FLAG_TRACE_DISPATCH
-  uint16_t max_dispatch_query_count;
-  uint32_t reserved;  // may be uninitialized
-  // One query ID entry per command when profiling/tracing is enabled.
-  // Each entry contains the query ID to use in control-only mode and the one to
-  // use in control+dispatch mode.
-  iree_hal_amdgpu_device_command_query_id_t query_ids[/*command_count*/];
-} iree_hal_amdgpu_device_command_query_map_t;
-
-// A block of commands within a command buffer.
-// Each block represents one or more commands that should be issued to target
-// AQL queues as part of a single parallelized issue in a single contiguous
-// span.
-//
-// Blocks are immutable once recorded and a block may be executed multiple
-// times concurrently or serially with pipelining. Blocks are replicated per
-// device such that any embedded device-local pointers are always valid for any
-// queue the block is issued on. Any pointers that reference per-execution
-// state (such as kernel argument buffers) are encoded as relative offsets to be
-// added to whatever base pointer is reserved for the execution.
-//
-// Blocks are stored in a read-only memory region.
-typedef struct IREE_AMDGPU_ALIGNAS(64) iree_hal_amdgpu_device_command_block_t {
-  // Maximum number of AQL packets that the block will enqueue during a single
-  // execution. Fewer packets may be used but they will still be populated with
-  // valid no-op AQL packets to ensure forward progress by the packet processor.
-  uint32_t max_packet_count;
-  // Total number of commands in the block.
-  uint32_t command_count;
-  // Aligned storage for fixed-length command structures.
-  const iree_hal_amdgpu_device_cmd_t* commands;
-  // Tracing/profiling query map for commands in the block.
-  iree_hal_amdgpu_device_command_query_map_t query_map;  // tail array
-} iree_hal_amdgpu_device_command_block_t;
-
-// A program consisting of one or more blocks of commands and control flow
-// between them. Command buffers are immutable once recorded and retained in
-// device local memory. A command buffer may be enqueued multiple times
-// concurrently or in sequence as any state needed is stored separately in
-// iree_hal_amdgpu_device_execution_state_t.
-//
-// Execution of a command buffer starts at block[0] and continues based on
-// control flow commands at the tail of each block. Blocks may direct execution
-// within the same command buffer or transfer control to other command buffers
-// by nesting. Upon completion a return command at the tail of a block will
-// return to the caller.
-//
-// Command buffers are stored in a read-only memory region.
-typedef struct IREE_AMDGPU_ALIGNAS(64) iree_hal_amdgpu_device_command_buffer_t {
-  // Minimum required kernel argument buffer capacity to execute all blocks.
-  // Only one block executes at a time and the storage will be reused.
-  uint32_t max_kernarg_capacity;
-  // Total number of blocks in the command buffer.
-  uint32_t block_count;
-  // A list of all blocks with block[0] being the entry point.
-  // Commands reference blocks by ordinal in this list.
-  const iree_hal_amdgpu_device_command_block_t* blocks[];  // tail array
-} iree_hal_amdgpu_device_command_buffer_t;
-
-//===----------------------------------------------------------------------===//
-// iree_hal_amdgpu_device_execution_state_t
-//===----------------------------------------------------------------------===//
-
-#define IREE_HAL_AMDGPU_DEVICE_EXECUTION_ISSUE_BLOCK_KERNARG_SIZE \
-  (3 * sizeof(uint64_t))
-#define IREE_HAL_AMDGPU_DEVICE_EXECUTION_CONTROL_KERNARG_SIZE \
-  IREE_AMDGPU_MAX(8 * sizeof(uint64_t),                       \
-                  IREE_HAL_AMDGPU_DEVICE_EXECUTION_ISSUE_BLOCK_KERNARG_SIZE)
-
-// Controls command buffer execution behavior.
-typedef uint8_t iree_hal_amdgpu_device_execution_flags_t;
-enum iree_hal_amdgpu_device_execution_flag_bits_e {
-  IREE_HAL_AMDGPU_DEVICE_EXECUTION_FLAG_NONE = 0u,
-  // Issues work on the execution queue serially from the control queue.
-  // This reduces execution latency but decreases throughput as a single
-  // thread populates all packets instead of doing so in parallel.
-  IREE_HAL_AMDGPU_DEVICE_EXECUTION_FLAG_ISSUE_SERIALLY = (1u << 0),
-  // Forces every command executed to have the AQL barrier bit set. This
-  // serializes execution such that only one command can execute at a time. When
-  // debugging dispatch exceptions or data corruption this can be used to ensure
-  // only one dispatch at a time is executing on the device.
-  IREE_HAL_AMDGPU_DEVICE_EXECUTION_FLAG_SERIALIZE = (1u << 1),
-  // Forces cache invalidations/flushes between every command. This can be used
-  // when stepping to ensure the host and device can see changes made on either
-  // immediately.
-  IREE_HAL_AMDGPU_DEVICE_EXECUTION_FLAG_UNCACHED = (1u << 2),
-  // Enables tracing of command buffer control logic and instrumentation.
-  // Implicit zones such as the total command buffer execution time, each
-  // scheduling stage, and other events will be produced. Explicit zones created
-  // via HAL command buffer debug APIs will be included.
-  //
-  // TODO(benvanik): find a way to avoid serializing here.
-  IREE_HAL_AMDGPU_DEVICE_EXECUTION_FLAG_TRACE_CONTROL =
-      (1u << 6) | IREE_HAL_AMDGPU_DEVICE_EXECUTION_FLAG_SERIALIZE,
-  // Enables tracing of every dispatch (or DMA) command.
-  // Timings are captured by the hardware and stored on a per-command query
-  // signal. Forces all commands to be executed serially so that trace zones
-  // remain perfectly nested and timing does not have any interference from
-  // other concurrently executing commands. Note that total latency is expected
-  // to increase due to the lack of concurrency.
-  IREE_HAL_AMDGPU_DEVICE_EXECUTION_FLAG_TRACE_DISPATCH =
-      (1u << 7) | IREE_HAL_AMDGPU_DEVICE_EXECUTION_FLAG_SERIALIZE |
-      IREE_HAL_AMDGPU_DEVICE_EXECUTION_FLAG_TRACE_CONTROL,
-};
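-
-// Example (illustrative): IREE_HAL_AMDGPU_DEVICE_EXECUTION_FLAG_TRACE_DISPATCH
-// evaluates to (1u << 7) | (1u << 6) | (1u << 1) = 0xC2 as it implies both
-// TRACE_CONTROL and SERIALIZE.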
-
-// Transient state used during the execution of a command buffer.
-// Command buffers are executed like coroutines by having the command processor
-// issue a sequence of commands before tail-enqueuing further processing or a
-// return back to the top-level scheduler.
-//
-// Execution state is stored in mutable global memory so that the scheduler can
-// manipulate it. Though immutable command buffer storage can be shared across
-// all agents, the execution state is initialized for the agent it executes on.
-typedef struct IREE_AMDGPU_ALIGNAS(64)
-    iree_hal_amdgpu_device_execution_state_t {
-  // Flags controlling execution behavior.
-  iree_hal_amdgpu_device_execution_flags_t flags;
-
-  uint8_t reserved0[7];
-
-  // Command buffer being executed.
-  const iree_hal_amdgpu_device_command_buffer_t* command_buffer;
-
-  // Block that should be executed when the state is scheduled.
-  // Updated by control flow operations and set to NULL when returning.
-  const iree_hal_amdgpu_device_command_block_t* block;
-
-  // Scheduler that is managing the execution state lifetime.
-  // When the command buffer completes it will be scheduled to handle cleanup
-  // and resuming queue processing.
-  iree_hal_amdgpu_device_queue_scheduler_t* scheduler;
-
-  // Parent scheduler queue entry that this execution state is associated with.
-  // Used to notify the scheduler when execution completes.
-  uint64_t scheduler_queue_entry;
-
-  // Queue used for control operations. This _may_ be the execution queue but
-  // is usually independent to allow for command issue to overlap with
-  // execution.
-  iree_amd_cached_queue_t* control_queue;
-
-  union {
-    // State used for transfer operations done as part of the command buffer.
-    iree_hal_amdgpu_device_buffer_transfer_context_t transfer_context;
-    // NOTE: must match iree_hal_amdgpu_device_buffer_transfer_context_t.
-    // This lets us toll-free share the transfer_state with the command buffer
-    // itself.
-    struct {
-      // Queue used for command execution.
-      // This may differ from the top-level scheduling queue.
-      iree_amd_cached_queue_t execution_queue;
-      // Handles to opaque kernel objects used to dispatch builtin kernels.
-      const iree_hal_amdgpu_device_kernels_t* kernels;
-      // Optional trace buffer used when tracing infrastructure is available.
-      iree_hal_amdgpu_device_trace_buffer_t* trace_buffer;
-    };
-  };
-
-  // Storage with space for control kernel arguments.
-  // Populated with the required set of args at initialization and reused
-  // without change as we store control changes in the state struct.
-  // Must be at least IREE_HAL_AMDGPU_DEVICE_EXECUTION_CONTROL_KERNARG_SIZE
-  // bytes.
-  IREE_AMDGPU_ALIGNAS(8) uint8_t* control_kernarg_storage;
-
-  // Reserved storage for kernel arguments of at least the size specified by the
-  // command buffer max_kernarg_capacity. Only one block can be executed
-  // at a time and storage is reused. Note that storage is uninitialized and
-  // must be fully specified by the command processor.
-  IREE_AMDGPU_ALIGNAS(8) uint8_t* execution_kernarg_storage;
-
-  // Last acquired base query ringbuffer index.
-  // Used for all commands in the current block and reset after each block.
-  uint64_t trace_block_query_base_id;
-  // Total number of queries allocated from the ringbuffer for the last block.
-  uint16_t trace_block_query_count;
-
-  // TODO(benvanik): stack for remembering resume blocks when returning from
-  // nested command buffers. For now we don't have calls so it's not needed.
-  // uint32_t block_stack[...];
-
-  // Binding table used to resolve indirect binding references.
-  // Contains enough elements to satisfy all slots referenced by
-  // iree_hal_amdgpu_device_buffer_ref_t in the command buffer.
-  //
-  // The enqueuing agent populates this and must ensure that all bindings stay
-  // live until the command buffer completes executing by attaching a resource
-  // set.
-  //
-  // Note that bindings here will not reference slots (though maybe we could
-  // support that in the future for silly aliasing tricks).
-  IREE_AMDGPU_ALIGNAS(64)
-  iree_hal_amdgpu_device_buffer_ref_t bindings[];  // tail array
-} iree_hal_amdgpu_device_execution_state_t;
-
-//===----------------------------------------------------------------------===//
-// Device-side Enqueuing
-//===----------------------------------------------------------------------===//
-
-#if defined(IREE_AMDGPU_TARGET_DEVICE)
-
-// Launches a command buffer with the given initialized execution state.
-// The command buffer will begin execution at the block specified in the state
-// and continue (possibly rescheduling itself) until a return command is
-// reached.
-//
-// Forward progress is only guaranteed so long as the hardware scheduling queue
-// is not blocked (such as by waiting on the completion signal). Upon completion
-// the command buffer return command will enqueue the scheduler so that it can
-// clean up the execution state and resume processing the queue.
-void iree_hal_amdgpu_device_command_buffer_enqueue_next_block(
-    iree_hal_amdgpu_device_execution_state_t* IREE_AMDGPU_RESTRICT state);
-
-#endif  // IREE_AMDGPU_TARGET_DEVICE
-
-#endif  // IREE_HAL_DRIVERS_AMDGPU_DEVICE_COMMAND_BUFFER_H_
diff --git a/runtime/src/iree/hal/drivers/amdgpu/device/dispatch.c b/runtime/src/iree/hal/drivers/amdgpu/device/dispatch.c
new file mode 100644
index 0000000..78bb09796
--- /dev/null
+++ b/runtime/src/iree/hal/drivers/amdgpu/device/dispatch.c
@@ -0,0 +1,165 @@
+// Copyright 2026 The IREE Authors
+//
+// Licensed under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+
+#include "iree/hal/drivers/amdgpu/device/dispatch.h"
+
+#include "iree/hal/drivers/amdgpu/device/support/kernel.h"
+
+//===----------------------------------------------------------------------===//
+// Dispatch packet emission
+//===----------------------------------------------------------------------===//
+
+void iree_hal_amdgpu_device_dispatch_emplace_packet(
+    const iree_hal_amdgpu_device_kernel_args_t* IREE_AMDGPU_RESTRICT
+        kernel_args,
+    const uint32_t workgroup_count[3], uint32_t dynamic_workgroup_local_memory,
+    iree_hsa_kernel_dispatch_packet_t* IREE_AMDGPU_RESTRICT dispatch_packet,
+    void* IREE_AMDGPU_RESTRICT kernarg_ptr) {
+  dispatch_packet->setup = kernel_args->setup;
+  dispatch_packet->workgroup_size[0] = kernel_args->workgroup_size[0];
+  dispatch_packet->workgroup_size[1] = kernel_args->workgroup_size[1];
+  dispatch_packet->workgroup_size[2] = kernel_args->workgroup_size[2];
+  dispatch_packet->reserved0 = 0;
+  dispatch_packet->grid_size[0] =
+      workgroup_count[0] * kernel_args->workgroup_size[0];
+  dispatch_packet->grid_size[1] =
+      workgroup_count[1] * kernel_args->workgroup_size[1];
+  dispatch_packet->grid_size[2] =
+      workgroup_count[2] * kernel_args->workgroup_size[2];
+  dispatch_packet->private_segment_size = kernel_args->private_segment_size;
+  dispatch_packet->group_segment_size =
+      kernel_args->group_segment_size + dynamic_workgroup_local_memory;
+  dispatch_packet->kernel_object = kernel_args->kernel_object;
+  dispatch_packet->kernarg_address = kernarg_ptr;
+  dispatch_packet->reserved2 = 0;
+  dispatch_packet->completion_signal = iree_hsa_signal_null();
+}
+
+//===----------------------------------------------------------------------===//
+// Dispatch kernarg emission
+//===----------------------------------------------------------------------===//
+
+static void iree_hal_amdgpu_device_dispatch_emplace_implicit_args(
+    const iree_hal_amdgpu_device_kernel_args_t* IREE_AMDGPU_RESTRICT
+        kernel_args,
+    const uint32_t workgroup_count[3], uint32_t dynamic_workgroup_local_memory,
+    const iree_hal_amdgpu_device_dispatch_kernarg_layout_t* IREE_AMDGPU_RESTRICT
+        layout,
+    void* IREE_AMDGPU_RESTRICT kernarg_ptr) {
+  if (!layout->has_implicit_args) return;
+
+  iree_amdgpu_kernel_implicit_args_t* IREE_AMDGPU_RESTRICT implicit_args =
+      (iree_amdgpu_kernel_implicit_args_t*)((uint8_t*)kernarg_ptr +
+                                            layout->implicit_args_offset);
+  iree_amdgpu_memset(implicit_args, 0, IREE_AMDGPU_KERNEL_IMPLICIT_ARGS_SIZE);
+  implicit_args->block_count[0] = workgroup_count[0];
+  implicit_args->block_count[1] = workgroup_count[1];
+  implicit_args->block_count[2] = workgroup_count[2];
+  implicit_args->group_size[0] = kernel_args->workgroup_size[0];
+  implicit_args->group_size[1] = kernel_args->workgroup_size[1];
+  implicit_args->group_size[2] = kernel_args->workgroup_size[2];
+  implicit_args->grid_dims = 3;
+  implicit_args->printf_buffer = NULL;
+  implicit_args->hostcall_buffer = NULL;
+  implicit_args->dynamic_lds_size = dynamic_workgroup_local_memory;
+}
+
+void iree_hal_amdgpu_device_dispatch_emplace_hal_kernargs(
+    const iree_hal_amdgpu_device_kernel_args_t* IREE_AMDGPU_RESTRICT
+        kernel_args,
+    const uint32_t workgroup_count[3], uint32_t dynamic_workgroup_local_memory,
+    const iree_hal_amdgpu_device_dispatch_kernarg_layout_t* IREE_AMDGPU_RESTRICT
+        layout,
+    const uint64_t* IREE_AMDGPU_RESTRICT binding_ptrs,
+    const uint32_t* IREE_AMDGPU_RESTRICT constants,
+    void* IREE_AMDGPU_RESTRICT kernarg_ptr) {
+  iree_amdgpu_memset(kernarg_ptr, 0, layout->total_kernarg_size);
+
+  const size_t binding_bytes =
+      (size_t)kernel_args->binding_count * sizeof(uint64_t);
+  const size_t constant_bytes =
+      (size_t)kernel_args->constant_count * sizeof(uint32_t);
+  if (binding_bytes > 0) {
+    iree_amdgpu_memcpy(kernarg_ptr, binding_ptrs, binding_bytes);
+  }
+  if (constant_bytes > 0) {
+    iree_amdgpu_memcpy((uint8_t*)kernarg_ptr + binding_bytes, constants,
+                       constant_bytes);
+  }
+
+  iree_hal_amdgpu_device_dispatch_emplace_implicit_args(
+      kernel_args, workgroup_count, dynamic_workgroup_local_memory, layout,
+      kernarg_ptr);
+}
+
+void iree_hal_amdgpu_device_dispatch_emplace_custom_kernargs(
+    const iree_hal_amdgpu_device_dispatch_kernarg_layout_t* IREE_AMDGPU_RESTRICT
+        layout,
+    const void* IREE_AMDGPU_RESTRICT custom_kernarg_ptr,
+    void* IREE_AMDGPU_RESTRICT kernarg_ptr) {
+  if (layout->total_kernarg_size > 0) {
+    iree_amdgpu_memcpy(kernarg_ptr, custom_kernarg_ptr,
+                       layout->total_kernarg_size);
+  }
+}
+
+//===----------------------------------------------------------------------===//
+// Indirect dispatch parameter patching
+//===----------------------------------------------------------------------===//
+
+void iree_hal_amdgpu_device_dispatch_emplace_indirect_params_patch(
+    const iree_hal_amdgpu_device_kernel_args_t* IREE_AMDGPU_RESTRICT
+        patch_kernel_args,
+    const uint32_t* IREE_AMDGPU_RESTRICT workgroup_count,
+    iree_hsa_kernel_dispatch_packet_t* IREE_AMDGPU_RESTRICT dispatch_packet,
+    uint16_t dispatch_header, uint16_t dispatch_setup,
+    iree_amdgpu_kernel_implicit_args_t* IREE_AMDGPU_RESTRICT implicit_args,
+    iree_hsa_kernel_dispatch_packet_t* IREE_AMDGPU_RESTRICT patch_packet,
+    void* IREE_AMDGPU_RESTRICT kernarg_ptr) {
+  iree_hal_amdgpu_device_dispatch_patch_indirect_params_args_t*
+      IREE_AMDGPU_RESTRICT kernargs =
+          (iree_hal_amdgpu_device_dispatch_patch_indirect_params_args_t*)
+              kernarg_ptr;
+  kernargs->workgroup_count = workgroup_count;
+  kernargs->dispatch_packet = dispatch_packet;
+  kernargs->implicit_args = implicit_args;
+  kernargs->dispatch_header_setup =
+      (uint32_t)dispatch_header | ((uint32_t)dispatch_setup << 16);
+
+  const uint32_t patch_workgroup_count[3] = {1, 1, 1};
+  iree_hal_amdgpu_device_dispatch_emplace_packet(
+      patch_kernel_args, patch_workgroup_count,
+      /*dynamic_workgroup_local_memory=*/0, patch_packet, kernarg_ptr);
+}
+
+#if defined(IREE_AMDGPU_TARGET_DEVICE)
+
+IREE_AMDGPU_ATTRIBUTE_KERNEL void
+iree_hal_amdgpu_device_dispatch_patch_indirect_params(
+    const uint32_t* IREE_AMDGPU_RESTRICT workgroup_count,
+    iree_hsa_kernel_dispatch_packet_t* IREE_AMDGPU_RESTRICT dispatch_packet,
+    iree_amdgpu_kernel_implicit_args_t* IREE_AMDGPU_RESTRICT implicit_args,
+    uint32_t dispatch_header_setup) {
+  dispatch_packet->grid_size[0] =
+      workgroup_count[0] * dispatch_packet->workgroup_size[0];
+  dispatch_packet->grid_size[1] =
+      workgroup_count[1] * dispatch_packet->workgroup_size[1];
+  dispatch_packet->grid_size[2] =
+      workgroup_count[2] * dispatch_packet->workgroup_size[2];
+
+  if (implicit_args) {
+    implicit_args->block_count[0] = workgroup_count[0];
+    implicit_args->block_count[1] = workgroup_count[1];
+    implicit_args->block_count[2] = workgroup_count[2];
+  }
+
+  iree_amdgpu_scoped_atomic_store(
+      (iree_amdgpu_scoped_atomic_uint32_t*)dispatch_packet,
+      dispatch_header_setup, iree_amdgpu_memory_order_release,
+      iree_amdgpu_memory_scope_system);
+}
+
+#endif  // IREE_AMDGPU_TARGET_DEVICE
diff --git a/runtime/src/iree/hal/drivers/amdgpu/device/dispatch.h b/runtime/src/iree/hal/drivers/amdgpu/device/dispatch.h
new file mode 100644
index 0000000..e0b5862
--- /dev/null
+++ b/runtime/src/iree/hal/drivers/amdgpu/device/dispatch.h
@@ -0,0 +1,211 @@
+// Copyright 2026 The IREE Authors
+//
+// Licensed under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+
+#ifndef IREE_HAL_DRIVERS_AMDGPU_DEVICE_DISPATCH_H_
+#define IREE_HAL_DRIVERS_AMDGPU_DEVICE_DISPATCH_H_
+
+#include "iree/hal/drivers/amdgpu/abi/kernel_args.h"
+#include "iree/hal/drivers/amdgpu/abi/queue.h"
+#include "iree/hal/drivers/amdgpu/device/support/common.h"
+
+#ifdef __cplusplus
+extern "C" {
+#endif  // __cplusplus
+
+//===----------------------------------------------------------------------===//
+// Dispatch Kernarg Layout
+//===----------------------------------------------------------------------===//
+
+// Device-visible kernarg byte layout for one dispatch.
+//
+// This is intentionally a prevalidated data contract instead of a status-
+// producing API: once device-side replay is emitting packets there is no sane
+// way to report malformed ABI metadata or recover from partially-written
+// packet/kernarg storage. Host recording/submission code must validate kernel
+// metadata and user-provided arguments before passing a layout here.
+typedef struct iree_hal_amdgpu_device_dispatch_kernarg_layout_t {
+  // Size in bytes of the explicitly provided dispatch arguments.
+  size_t explicit_kernarg_size;
+  // Offset in bytes of the implicit HIP/OpenCL suffix, if present.
+  size_t implicit_args_offset;
+  // Total kernarg reservation size in bytes required for this dispatch.
+  size_t total_kernarg_size;
+  // True if a HIP/OpenCL implicit args suffix is appended at
+  // |implicit_args_offset| and must be populated during emplace.
+  bool has_implicit_args;
+} iree_hal_amdgpu_device_dispatch_kernarg_layout_t;
+
+// Returns the HAL ABI kernarg layout for |kernel_args|.
+//
+// Explicit args are laid out as:
+//   uint64_t bindings[kernel_args->binding_count]
+//   uint32_t constants[kernel_args->constant_count]
+//   zero padding to 8-byte alignment
+//
+// If kernel metadata declares more bytes than those explicit args, a
+// HIP/OpenCL implicit-args suffix is appended at the aligned explicit size and
+// the reservation is extended to cover at least
+// IREE_AMDGPU_KERNEL_IMPLICIT_ARGS_SIZE bytes of suffix storage.
+//
+// Callers that treat a kernarg_size smaller than the explicit HAL ABI size as
+// malformed for the current executable must validate kernel_args->kernarg_size
+// before calling this.
+static inline iree_hal_amdgpu_device_dispatch_kernarg_layout_t
+iree_hal_amdgpu_device_dispatch_make_hal_kernarg_layout(
+    const iree_hal_amdgpu_device_kernel_args_t* kernel_args) {
+  const size_t binding_bytes =
+      (size_t)kernel_args->binding_count * sizeof(uint64_t);
+  const size_t constant_bytes = iree_amdgpu_align(
+      (size_t)kernel_args->constant_count * sizeof(uint32_t), 8);
+  const size_t explicit_kernarg_size = binding_bytes + constant_bytes;
+  const bool has_implicit_args =
+      (size_t)kernel_args->kernarg_size > explicit_kernarg_size;
+  const size_t total_kernarg_size =
+      has_implicit_args
+          ? IREE_AMDGPU_MAX(
+                (size_t)kernel_args->kernarg_size,
+                explicit_kernarg_size + IREE_AMDGPU_KERNEL_IMPLICIT_ARGS_SIZE)
+          : explicit_kernarg_size;
+  return (iree_hal_amdgpu_device_dispatch_kernarg_layout_t){
+      .explicit_kernarg_size = explicit_kernarg_size,
+      .implicit_args_offset = explicit_kernarg_size,
+      .total_kernarg_size = total_kernarg_size,
+      .has_implicit_args = has_implicit_args,
+  };
+}
+
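+// Example (illustrative): binding_count=2 and constant_count=3 give
+// explicit_kernarg_size = 2*8 + align(3*4, 8) = 32 bytes; a declared
+// kernarg_size of 40 places the implicit-args suffix at offset 32 and grows
+// the reservation to 32 + IREE_AMDGPU_KERNEL_IMPLICIT_ARGS_SIZE bytes.
+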
+// Returns a custom-direct-argument layout for a raw kernarg blob of
+// |kernarg_size| bytes.
+//
+// The caller owns all packing and padding in the raw argument blob. No implicit
+// suffix is synthesized in this mode.
+static inline iree_hal_amdgpu_device_dispatch_kernarg_layout_t
+iree_hal_amdgpu_device_dispatch_make_custom_kernarg_layout(
+    size_t kernarg_size) {
+  return (iree_hal_amdgpu_device_dispatch_kernarg_layout_t){
+      .explicit_kernarg_size = kernarg_size,
+      .implicit_args_offset = kernarg_size,
+      .total_kernarg_size = kernarg_size,
+      .has_implicit_args = false,
+  };
+}
+
+//===----------------------------------------------------------------------===//
+// Dispatch Packet/Kernarg Emission
+//===----------------------------------------------------------------------===//
+
+// Kernel arguments for the builtin indirect-parameter patch dispatch.
+typedef struct iree_hal_amdgpu_device_dispatch_patch_indirect_params_args_t {
+  // Device pointer to a uint32_t[3] workgroup-count parameter buffer.
+  const uint32_t* workgroup_count;
+  // Device pointer to the AQL dispatch packet to publish after patching.
+  iree_hsa_kernel_dispatch_packet_t* dispatch_packet;
+  // Optional device pointer to the dispatch's implicit args suffix.
+  iree_amdgpu_kernel_implicit_args_t* implicit_args;
+  // Final 32-bit header/setup word to publish with a release store.
+  uint32_t dispatch_header_setup;
+} iree_hal_amdgpu_device_dispatch_patch_indirect_params_args_t;
+IREE_AMDGPU_STATIC_ASSERT(
+    sizeof(iree_hal_amdgpu_device_dispatch_patch_indirect_params_args_t) == 32,
+    "indirect dispatch patch args must match the kernel ABI");
+
+// Populates a kernel dispatch packet body in already-reserved storage.
+//
+// The caller owns packet header commit, completion-signal assignment, and
+// doorbell writes. Zero workgroup counts are preserved verbatim and produce a
+// valid zero-grid dispatch packet.
+//
+// Preconditions:
+//   - |kernel_args|, |workgroup_count|, |dispatch_packet|, and |kernarg_ptr|
+//     are non-NULL.
+//   - |kernel_args->workgroup_size| and
+//     |kernel_args->group_segment_size + dynamic_workgroup_local_memory| are
+//     valid for the target kernel.
+//   - Each grid dimension product
+//     |workgroup_count[i] * kernel_args->workgroup_size[i]| fits in uint32_t.
+//   - |kernarg_ptr| satisfies |kernel_args->kernarg_alignment|.
+void iree_hal_amdgpu_device_dispatch_emplace_packet(
+    const iree_hal_amdgpu_device_kernel_args_t* IREE_AMDGPU_RESTRICT
+        kernel_args,
+    const uint32_t workgroup_count[3], uint32_t dynamic_workgroup_local_memory,
+    iree_hsa_kernel_dispatch_packet_t* IREE_AMDGPU_RESTRICT dispatch_packet,
+    void* IREE_AMDGPU_RESTRICT kernarg_ptr);
+
+// Populates HAL ABI kernargs in already-reserved storage.
+//
+// |binding_ptrs| must provide |kernel_args->binding_count| device pointers as
+// raw 64-bit values. |constants| must provide
+// |kernel_args->constant_count * sizeof(uint32_t)| bytes. Either pointer may be
+// NULL when its corresponding count is zero.
+//
+// Preconditions:
+//   - |kernel_args|, |workgroup_count|, |layout|, and |kernarg_ptr| are
+//     non-NULL.
+//   - |layout| was derived from |kernel_args| using
+//     iree_hal_amdgpu_device_dispatch_make_hal_kernarg_layout.
+//   - |kernarg_ptr| points to at least |layout->total_kernarg_size| bytes of
+//     writable storage.
+void iree_hal_amdgpu_device_dispatch_emplace_hal_kernargs(
+    const iree_hal_amdgpu_device_kernel_args_t* IREE_AMDGPU_RESTRICT
+        kernel_args,
+    const uint32_t workgroup_count[3], uint32_t dynamic_workgroup_local_memory,
+    const iree_hal_amdgpu_device_dispatch_kernarg_layout_t* IREE_AMDGPU_RESTRICT
+        layout,
+    const uint64_t* IREE_AMDGPU_RESTRICT binding_ptrs,
+    const uint32_t* IREE_AMDGPU_RESTRICT constants,
+    void* IREE_AMDGPU_RESTRICT kernarg_ptr);
+
+// Populates custom direct kernargs in already-reserved storage.
+//
+// |custom_kernarg_ptr| must provide |layout->total_kernarg_size| bytes in the
+// final kernel ABI shape expected by the target kernel.
+//
+// Preconditions:
+//   - |layout| and |kernarg_ptr| are non-NULL.
+//   - |layout| was derived with
+//     iree_hal_amdgpu_device_dispatch_make_custom_kernarg_layout.
+//   - |custom_kernarg_ptr| is non-NULL when |layout->total_kernarg_size| > 0.
+void iree_hal_amdgpu_device_dispatch_emplace_custom_kernargs(
+    const iree_hal_amdgpu_device_dispatch_kernarg_layout_t* IREE_AMDGPU_RESTRICT
+        layout,
+    const void* IREE_AMDGPU_RESTRICT custom_kernarg_ptr,
+    void* IREE_AMDGPU_RESTRICT kernarg_ptr);
+
+// Populates the builtin patch dispatch that updates an indirect-parameter
+// dispatch packet and then publishes its header.
+//
+// The target dispatch packet must already contain every non-header field. The
+// patch dispatch reads |workgroup_count| on device, updates the target packet's
+// grid-size fields and optional implicit args, then atomically publishes the
+// provided final dispatch header/setup word.
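+//
+// Example (illustrative): an INDIRECT_DYNAMIC dispatch occupies two
+// consecutive queue slots: the single work-item patch dispatch followed by
+// the target dispatch, whose header remains invalid until the patch kernel's
+// release store publishes |dispatch_header_setup|.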
+void iree_hal_amdgpu_device_dispatch_emplace_indirect_params_patch(
+    const iree_hal_amdgpu_device_kernel_args_t* IREE_AMDGPU_RESTRICT
+        patch_kernel_args,
+    const uint32_t* IREE_AMDGPU_RESTRICT workgroup_count,
+    iree_hsa_kernel_dispatch_packet_t* IREE_AMDGPU_RESTRICT dispatch_packet,
+    uint16_t dispatch_header, uint16_t dispatch_setup,
+    iree_amdgpu_kernel_implicit_args_t* IREE_AMDGPU_RESTRICT implicit_args,
+    iree_hsa_kernel_dispatch_packet_t* IREE_AMDGPU_RESTRICT patch_packet,
+    void* IREE_AMDGPU_RESTRICT kernarg_ptr);
+
+#if defined(IREE_AMDGPU_TARGET_DEVICE)
+
+// Device builtin that patches a following dispatch packet from indirect
+// workgroup-count parameters. Launched as a single work-item dispatch.
+IREE_AMDGPU_ATTRIBUTE_KERNEL void
+iree_hal_amdgpu_device_dispatch_patch_indirect_params(
+    const uint32_t* IREE_AMDGPU_RESTRICT workgroup_count,
+    iree_hsa_kernel_dispatch_packet_t* IREE_AMDGPU_RESTRICT dispatch_packet,
+    iree_amdgpu_kernel_implicit_args_t* IREE_AMDGPU_RESTRICT implicit_args,
+    uint32_t dispatch_header_setup);
+
+#endif  // IREE_AMDGPU_TARGET_DEVICE
+
+#ifdef __cplusplus
+}  // extern "C"
+#endif  // __cplusplus
+
+#endif  // IREE_HAL_DRIVERS_AMDGPU_DEVICE_DISPATCH_H_
diff --git a/runtime/src/iree/hal/drivers/amdgpu/device/dispatch_test.cc b/runtime/src/iree/hal/drivers/amdgpu/device/dispatch_test.cc
new file mode 100644
index 0000000..273fe49
--- /dev/null
+++ b/runtime/src/iree/hal/drivers/amdgpu/device/dispatch_test.cc
@@ -0,0 +1,182 @@
+// Copyright 2026 The IREE Authors
+//
+// Licensed under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+
+#include "iree/hal/drivers/amdgpu/device/dispatch.h"
+
+#include <array>
+#include <cstddef>
+#include <cstdint>
+#include <cstring>
+
+#include "iree/testing/gtest.h"
+
+namespace iree::hal::amdgpu {
+namespace {
+
+static iree_hal_amdgpu_device_kernel_args_t MakeKernelArgs(
+    uint64_t kernel_object, uint16_t kernarg_size, uint16_t kernarg_alignment,
+    uint16_t binding_count, uint16_t constant_count) {
+  iree_hal_amdgpu_device_kernel_args_t kernel_args = {};
+  kernel_args.kernel_object = kernel_object;
+  kernel_args.setup = 3;
+  kernel_args.workgroup_size[0] = 4;
+  kernel_args.workgroup_size[1] = 5;
+  kernel_args.workgroup_size[2] = 6;
+  kernel_args.private_segment_size = 7;
+  kernel_args.group_segment_size = 8;
+  kernel_args.kernarg_size = kernarg_size;
+  kernel_args.kernarg_alignment = kernarg_alignment;
+  kernel_args.constant_count = constant_count;
+  kernel_args.binding_count = binding_count;
+  return kernel_args;
+}
+
+TEST(DispatchTest, MakeHalKernargLayoutWithoutImplicitArgs) {
+  iree_hal_amdgpu_device_kernel_args_t kernel_args =
+      MakeKernelArgs(/*kernel_object=*/0x1234u, /*kernarg_size=*/32,
+                     /*kernarg_alignment=*/16, /*binding_count=*/2,
+                     /*constant_count=*/3);
+
+  iree_hal_amdgpu_device_dispatch_kernarg_layout_t layout =
+      iree_hal_amdgpu_device_dispatch_make_hal_kernarg_layout(&kernel_args);
+
+  EXPECT_EQ(layout.explicit_kernarg_size, 32u);
+  EXPECT_EQ(layout.implicit_args_offset, 32u);
+  EXPECT_EQ(layout.total_kernarg_size, 32u);
+  EXPECT_FALSE(layout.has_implicit_args);
+}
+
+TEST(DispatchTest, MakeHalKernargLayoutWithImplicitArgs) {
+  iree_hal_amdgpu_device_kernel_args_t kernel_args =
+      MakeKernelArgs(/*kernel_object=*/0x1234u, /*kernarg_size=*/40,
+                     /*kernarg_alignment=*/16, /*binding_count=*/2,
+                     /*constant_count=*/3);
+
+  iree_hal_amdgpu_device_dispatch_kernarg_layout_t layout =
+      iree_hal_amdgpu_device_dispatch_make_hal_kernarg_layout(&kernel_args);
+
+  EXPECT_EQ(layout.explicit_kernarg_size, 32u);
+  EXPECT_EQ(layout.implicit_args_offset, 32u);
+  EXPECT_EQ(layout.total_kernarg_size,
+            32u + IREE_AMDGPU_KERNEL_IMPLICIT_ARGS_SIZE);
+  EXPECT_TRUE(layout.has_implicit_args);
+}
+
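+// Note: in both cases above the explicit segment is sized from the binding
+// and constant counts (2 * 8 + 3 * 4 = 28 bytes, aligned up to 32); implicit
+// args appear to be flagged whenever the kernel's declared kernarg_size
+// exceeds that explicit segment.
+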
+TEST(DispatchTest, EmplacePacketPreservesZeroWorkgroupCounts) {
+  iree_hal_amdgpu_device_kernel_args_t kernel_args =
+      MakeKernelArgs(/*kernel_object=*/0xBEEFu, /*kernarg_size=*/0,
+                     /*kernarg_alignment=*/16, /*binding_count=*/0,
+                     /*constant_count=*/0);
+  iree_hsa_kernel_dispatch_packet_t packet = {};
+  packet.header = 0xFFFFu;
+  alignas(16) std::array<uint8_t, 64> kernargs = {};
+  const uint32_t workgroup_count[3] = {0, 2, 0};
+
+  iree_hal_amdgpu_device_dispatch_emplace_packet(
+      &kernel_args, workgroup_count,
+      /*dynamic_workgroup_local_memory=*/9, &packet, kernargs.data());
+
+  EXPECT_EQ(packet.header, 0xFFFFu);
+  EXPECT_EQ(packet.setup, 3u);
+  EXPECT_EQ(packet.workgroup_size[0], 4u);
+  EXPECT_EQ(packet.workgroup_size[1], 5u);
+  EXPECT_EQ(packet.workgroup_size[2], 6u);
+  EXPECT_EQ(packet.reserved0, 0u);
+  EXPECT_EQ(packet.grid_size[0], 0u);
+  EXPECT_EQ(packet.grid_size[1], 10u);
+  EXPECT_EQ(packet.grid_size[2], 0u);
+  EXPECT_EQ(packet.private_segment_size, 7u);
+  EXPECT_EQ(packet.group_segment_size, 17u);
+  EXPECT_EQ(packet.kernel_object, 0xBEEFu);
+  EXPECT_EQ(packet.kernarg_address, kernargs.data());
+  EXPECT_EQ(packet.reserved2, 0u);
+  EXPECT_EQ(packet.completion_signal.handle, iree_hsa_signal_null().handle);
+}
+
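+// Expected HAL kernarg layout for the dispatch below: bindings (2 x uint64_t)
+// at offset 0, constants (3 x uint32_t) at offset 16, zero padding through
+// offset 31, then the 16-byte-aligned implicit args block at offset 32.
+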
+TEST(DispatchTest, EmplaceHalKernargsWritesBindingsConstantsAndImplicitArgs) {
+  iree_hal_amdgpu_device_kernel_args_t kernel_args = MakeKernelArgs(
+      /*kernel_object=*/0x1234u,
+      /*kernarg_size=*/32 + IREE_AMDGPU_KERNEL_IMPLICIT_ARGS_SIZE,
+      /*kernarg_alignment=*/16, /*binding_count=*/2,
+      /*constant_count=*/3);
+  iree_hal_amdgpu_device_dispatch_kernarg_layout_t layout =
+      iree_hal_amdgpu_device_dispatch_make_hal_kernarg_layout(&kernel_args);
+  const uint32_t workgroup_count[3] = {7, 8, 9};
+  const uint64_t bindings[2] = {0x1111222233334444ull, 0x5555666677778888ull};
+  const uint32_t constants[3] = {0xAAu, 0xBBu, 0xCCu};
+  alignas(16) std::array<uint8_t, 256> kernargs = {};
+  kernargs.fill(0xFD);
+
+  iree_hal_amdgpu_device_dispatch_emplace_hal_kernargs(
+      &kernel_args, workgroup_count,
+      /*dynamic_workgroup_local_memory=*/13, &layout, bindings, constants,
+      kernargs.data());
+
+  const uint64_t* binding_words =
+      reinterpret_cast<const uint64_t*>(kernargs.data());
+  EXPECT_EQ(binding_words[0], 0x1111222233334444ull);
+  EXPECT_EQ(binding_words[1], 0x5555666677778888ull);
+
+  const uint32_t* constant_words =
+      reinterpret_cast<const uint32_t*>(kernargs.data() + 16);
+  EXPECT_EQ(constant_words[0], 0xAAu);
+  EXPECT_EQ(constant_words[1], 0xBBu);
+  EXPECT_EQ(constant_words[2], 0xCCu);
+  EXPECT_EQ(kernargs[28], 0u);
+  EXPECT_EQ(kernargs[29], 0u);
+  EXPECT_EQ(kernargs[30], 0u);
+  EXPECT_EQ(kernargs[31], 0u);
+
+  const auto* implicit_args =
+      reinterpret_cast<const iree_amdgpu_kernel_implicit_args_t*>(
+          kernargs.data() + layout.implicit_args_offset);
+  EXPECT_EQ(implicit_args->block_count[0], 7u);
+  EXPECT_EQ(implicit_args->block_count[1], 8u);
+  EXPECT_EQ(implicit_args->block_count[2], 9u);
+  EXPECT_EQ(implicit_args->group_size[0], 4u);
+  EXPECT_EQ(implicit_args->group_size[1], 5u);
+  EXPECT_EQ(implicit_args->group_size[2], 6u);
+  EXPECT_EQ(implicit_args->remainder[0], 0u);
+  EXPECT_EQ(implicit_args->remainder[1], 0u);
+  EXPECT_EQ(implicit_args->remainder[2], 0u);
+  EXPECT_EQ(implicit_args->reserved0, 0u);
+  EXPECT_EQ(implicit_args->reserved1, 0u);
+  EXPECT_EQ(implicit_args->global_offset[0], 0u);
+  EXPECT_EQ(implicit_args->global_offset[1], 0u);
+  EXPECT_EQ(implicit_args->global_offset[2], 0u);
+  EXPECT_EQ(implicit_args->grid_dims, 3u);
+  EXPECT_EQ(implicit_args->printf_buffer, nullptr);
+  EXPECT_EQ(implicit_args->hostcall_buffer, nullptr);
+  EXPECT_EQ(implicit_args->deprecated_multigrid_sync_arg, 0u);
+  EXPECT_EQ(implicit_args->unused_heap_v1, 0u);
+  EXPECT_EQ(implicit_args->unused_default_queue, 0u);
+  EXPECT_EQ(implicit_args->unused_completion_action, 0u);
+  EXPECT_EQ(implicit_args->dynamic_lds_size, 13u);
+}
+
+TEST(DispatchTest, EmplaceCustomKernargsCopiesRawBlob) {
+  iree_hal_amdgpu_device_dispatch_kernarg_layout_t layout =
+      iree_hal_amdgpu_device_dispatch_make_custom_kernarg_layout(20);
+  const std::array<uint8_t, 20> custom_kernargs = {
+      0x00, 0x01, 0x02, 0x03, 0x04, 0x05, 0x06, 0x07, 0x08, 0x09,
+      0x0A, 0x0B, 0x0C, 0x0D, 0x0E, 0x0F, 0x10, 0x11, 0x12, 0x13,
+  };
+  alignas(16) std::array<uint8_t, 32> kernargs = {};
+  kernargs.fill(0xFD);
+
+  iree_hal_amdgpu_device_dispatch_emplace_custom_kernargs(
+      &layout, custom_kernargs.data(), kernargs.data());
+
+  EXPECT_EQ(std::memcmp(kernargs.data(), custom_kernargs.data(),
+                        custom_kernargs.size()),
+            0);
+  for (size_t i = custom_kernargs.size(); i < kernargs.size(); ++i) {
+    EXPECT_EQ(kernargs[i], 0xFD);
+  }
+}
+
+}  // namespace
+}  // namespace iree::hal::amdgpu
diff --git a/runtime/src/iree/hal/drivers/amdgpu/device/host_client.c b/runtime/src/iree/hal/drivers/amdgpu/device/host_client.c
deleted file mode 100644
index 78d7d3e..0000000
--- a/runtime/src/iree/hal/drivers/amdgpu/device/host_client.c
+++ /dev/null
@@ -1,107 +0,0 @@
-// Copyright 2025 The IREE Authors
-//
-// Licensed under the Apache License v2.0 with LLVM Exceptions.
-// See https://llvm.org/LICENSE.txt for license information.
-// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
-
-#include "iree/hal/drivers/amdgpu/device/host_client.h"
-
-//===----------------------------------------------------------------------===//
-// Device-side Enqueuing
-//===----------------------------------------------------------------------===//
-
-void iree_hal_amdgpu_device_host_client_post(
-    const iree_hal_amdgpu_device_host_client_t* IREE_AMDGPU_RESTRICT client,
-    uint16_t type, uint64_t return_address, uint64_t arg0, uint64_t arg1,
-    uint64_t arg2, uint64_t arg3, iree_hsa_signal_t completion_signal) {
-  // Reserve a packet write index and wait for it to become available in cases
-  // where the queue is exhausted.
-  const uint64_t packet_id = iree_hsa_queue_add_write_index(
-      &client->service_queue, 1u, iree_amdgpu_memory_order_relaxed);
-  while (packet_id -
-             iree_hsa_queue_load_read_index(&client->service_queue,
-                                            iree_amdgpu_memory_order_acquire) >=
-         client->service_queue.size) {
-    iree_amdgpu_yield();  // spinning
-  }
-  const uint64_t queue_mask = client->service_queue.size - 1;  // power of two
-  iree_hsa_agent_dispatch_packet_t* IREE_AMDGPU_RESTRICT agent_packet =
-      client->service_queue.base_address + (packet_id & queue_mask) * 64;
-
-  // Populate all of the packet besides the header.
-  // NOTE: we could use the reserved fields if we wanted at the risk of the
-  // packets not being inspectable by queue interception tooling.
-  agent_packet->reserved0 = 0;
-  agent_packet->return_address = (void*)return_address;
-  agent_packet->arg[0] = arg0;
-  agent_packet->arg[1] = arg1;
-  agent_packet->arg[2] = arg2;
-  agent_packet->arg[3] = arg3;
-  agent_packet->reserved2 = 0;
-  agent_packet->completion_signal = completion_signal;
-
-  // Populate the header and release the packet to the queue.
-  uint16_t header = IREE_HSA_PACKET_TYPE_AGENT_DISPATCH
-                    << IREE_HSA_PACKET_HEADER_TYPE;
-
-  // NOTE: we shouldn't need a barrier bit as posts should technically be
-  // executed back-to-back. If a particular post type supports concurrent or
-  // out-of-order execution then it _may_ do so unless the bit is set.
-  header |= 1 << IREE_HSA_PACKET_HEADER_BARRIER;
-
-  // Posts are unidirectional and take device agent resources and make them
-  // available to the host. We may be able to get away with an scacquire of
-  // IREE_HSA_FENCE_SCOPE_AGENT here but conservatively use
-  // IREE_HSA_FENCE_SCOPE_SYSTEM so that if any resources happen to have been
-  // touched on other agents (such as when executing multi-device work as part
-  // of a command buffer collective operation) the host can see all of that.
-  // It certainly is not optimal to do, though.
-  header |= IREE_HSA_FENCE_SCOPE_SYSTEM
-            << IREE_HSA_PACKET_HEADER_SCACQUIRE_FENCE_SCOPE;
-  header |= IREE_HSA_FENCE_SCOPE_SYSTEM
-            << IREE_HSA_PACKET_HEADER_SCRELEASE_FENCE_SCOPE;
-
-  const uint32_t header_type = header | (uint32_t)(type << 16);
-  iree_amdgpu_scoped_atomic_store(
-      (iree_amdgpu_scoped_atomic_uint32_t*)agent_packet, header_type,
-      iree_amdgpu_memory_order_release, iree_amdgpu_memory_scope_system);
-
-  // Signal the queue doorbell.
-  // This will store the packet_id to the doorbell signal (though in MULTI mode
-  // it's ignored) and in the case of the host agent trigger a hardware
-  // interrupt via the event mailbox pointer on the signal. If the host is doing
-  // a kernel wait via the HSA APIs it should be woken pretty quickly.
-  // https://sourcegraph.com/github.com/ROCm/rocMLIR/-/blob/external/llvm-project/amd/device-libs/ockl/src/hsaqs.cl?L69
-  iree_hsa_signal_store(client->service_queue.doorbell_signal, packet_id,
-                        iree_amdgpu_memory_order_relaxed);
-}
-
-void iree_hal_amdgpu_device_host_client_post_signal(
-    const iree_hal_amdgpu_device_host_client_t* IREE_AMDGPU_RESTRICT client,
-    uint64_t external_semaphore, uint64_t payload) {
-  IREE_AMDGPU_TRACE_BUFFER_SCOPE(host->trace_buffer);
-  IREE_AMDGPU_TRACE_ZONE_BEGIN(z0);
-
-  iree_hal_amdgpu_device_host_client_post(
-      client, IREE_HAL_AMDGPU_DEVICE_HOST_CALL_POST_SIGNAL,
-      /*return_address=*/0, external_semaphore, payload,
-      /*unused=*/0,
-      /*unused=*/0, iree_hsa_signal_null());
-
-  IREE_AMDGPU_TRACE_ZONE_END(z0);
-}
-
-void iree_hal_amdgpu_device_host_client_post_release(
-    const iree_hal_amdgpu_device_host_client_t* IREE_AMDGPU_RESTRICT client,
-    uint64_t resource0, uint64_t resource1, uint64_t resource2,
-    uint64_t resource3, iree_hsa_signal_t completion_signal) {
-  IREE_AMDGPU_TRACE_BUFFER_SCOPE(host->trace_buffer);
-  IREE_AMDGPU_TRACE_ZONE_BEGIN(z0);
-
-  iree_hal_amdgpu_device_host_client_post(
-      client, IREE_HAL_AMDGPU_DEVICE_HOST_CALL_POST_RELEASE,
-      /*return_address=*/0, resource0, resource1, resource2, resource3,
-      completion_signal);
-
-  IREE_AMDGPU_TRACE_ZONE_END(z0);
-}
diff --git a/runtime/src/iree/hal/drivers/amdgpu/device/host_client.h b/runtime/src/iree/hal/drivers/amdgpu/device/host_client.h
deleted file mode 100644
index 1a0abd0..0000000
--- a/runtime/src/iree/hal/drivers/amdgpu/device/host_client.h
+++ /dev/null
@@ -1,115 +0,0 @@
-// Copyright 2025 The IREE Authors
-//
-// Licensed under the Apache License v2.0 with LLVM Exceptions.
-// See https://llvm.org/LICENSE.txt for license information.
-// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
-
-#ifndef IREE_HAL_DRIVERS_AMDGPU_DEVICE_HOST_CLIENT_H_
-#define IREE_HAL_DRIVERS_AMDGPU_DEVICE_HOST_CLIENT_H_
-
-#include "iree/hal/drivers/amdgpu/device/support/common.h"
-#include "iree/hal/drivers/amdgpu/device/support/queue.h"
-#include "iree/hal/drivers/amdgpu/device/tracing.h"
-
-//===----------------------------------------------------------------------===//
-// iree_hal_amdgpu_device_host_client_t
-//===----------------------------------------------------------------------===//
-
-typedef uint16_t iree_hal_amdgpu_device_host_call_t;
-enum iree_hal_amdgpu_device_host_call_e {
-  // Host will notify any registered listeners of the semaphore signal.
-  // The semaphore provided is a host handle to a generic HAL semaphore and may
-  // be of any device in the system - not just AMDGPU semaphores.
-  //
-  // Signature:
-  //   arg0: iree_hal_semaphore_t* semaphore
-  //   arg1: uint64_t payload
-  //   arg2: unused
-  //   arg3: unused
-  //   return_address: unused
-  //   completion_signal: unused
-  IREE_HAL_AMDGPU_DEVICE_HOST_CALL_POST_SIGNAL = 0u,
-
-  // Host will call iree_hal_resource_release on each non-NULL resource pointer.
-  // This is effectively a transfer operation indicating that the device will no
-  // longer be using the resources.
-  //
-  // It's strongly recommended that iree_hal_resource_set_t is used where
-  // appropriate so that the number of packets required to release a set of
-  // resources can be kept small. The 4 available here is just enough for the
-  // common case of submissions like execute that are a wait semaphore, the
-  // command buffer, the binding table resource set, and the signal semaphore.
-  //
-  // TODO(benvanik): evaluate a version that takes a ringbuffer of uint64_t
-  // pointers and make this a drain request instead. Then we can enqueue as many
-  // as we want and kick the host to drain as it is able.
-  //
-  // Signature:
-  //   arg0: iree_hal_resource_t* resource0
-  //   arg1: iree_hal_resource_t* resource1
-  //   arg2: iree_hal_resource_t* resource2
-  //   arg3: iree_hal_resource_t* resource3
-  //   return_address: unused
-  //   completion_signal: optional, signaled when the release has completed
-  IREE_HAL_AMDGPU_DEVICE_HOST_CALL_POST_RELEASE,
-};
-
-// Represents the host runtime thread that is managing host interrupts.
-// One or more schedulers may share a single host queue. Any host calls that
-// need to identify the scheduler or scheduler-related resources must pass those
-// as arguments.
-typedef struct iree_hal_amdgpu_device_host_client_t {
-  // Host soft-queue processing device requests. May be servicing requests from
-  // multiple device agents. Cached inline in device memory but references a
-  // ringbuffer in host memory.
-  iree_amd_cached_queue_t service_queue;
-  // Optional trace buffer used when tracing infrastructure is available.
-  iree_hal_amdgpu_device_trace_buffer_t* trace_buffer;
-} iree_hal_amdgpu_device_host_client_t;
-
-//===----------------------------------------------------------------------===//
-// Device-side Enqueuing
-//===----------------------------------------------------------------------===//
-
-#if defined(IREE_AMDGPU_TARGET_DEVICE)
-
-// Enqueues a unidirection host agent packet ("post").
-// Since this is device->host only operation this only uses an acquire scope
-// from the agent and releases to the entire system so the host agent can
-// observe changes. The completion signal is optional and may be
-// `iree_hsa_signal_null()`.
-//
-// NOTE: the barrier bit is set but the host processing is (today) synchronous
-// with respect to other packets and generally only executes in FIFO order with
-// respect to what each packet may affect anyway. We could tweak this in the
-// future e.g. posts to flush a ringbuffer don't need to block and can be
-// eagerly processed. Maybe. For non-post operations we'd rely on queue barrier
-// packets.
-//
-// NOTE: we currently use the agent dispatch packet fields as intended (mostly)
-// so that tooling that intercepts them can work. We don't have to, though, and
-// could even have custom vendor packets instead to get the most bytes out of
-// the channel.
-void iree_hal_amdgpu_device_host_client_post(
-    const iree_hal_amdgpu_device_host_client_t* IREE_AMDGPU_RESTRICT client,
-    uint16_t type, uint64_t return_address, uint64_t arg0, uint64_t arg1,
-    uint64_t arg2, uint64_t arg3, iree_hsa_signal_t completion_signal);
-
-// Posts a semaphore signal notification to the host.
-// This is only needed for external semaphores that are managed by the host.
-void iree_hal_amdgpu_device_host_client_post_signal(
-    const iree_hal_amdgpu_device_host_client_t* IREE_AMDGPU_RESTRICT client,
-    uint64_t semaphore, uint64_t payload);
-
-// Posts a multi-resource release request to the host.
-// The host will call iree_hal_resource_release on each non-NULL resource
-// pointer provided. The optional |completion_signal| will be signaled when the
-// release has completed.
-void iree_hal_amdgpu_device_host_client_post_release(
-    const iree_hal_amdgpu_device_host_client_t* IREE_AMDGPU_RESTRICT client,
-    uint64_t resource0, uint64_t resource1, uint64_t resource2,
-    uint64_t resource3, iree_hsa_signal_t completion_signal);
-
-#endif  // IREE_AMDGPU_TARGET_DEVICE
-
-#endif  // IREE_HAL_DRIVERS_AMDGPU_DEVICE_HOST_CLIENT_H_
diff --git a/runtime/src/iree/hal/drivers/amdgpu/device/kernel_tables.h b/runtime/src/iree/hal/drivers/amdgpu/device/kernel_tables.h
index a1bd30f..8d3f378 100644
--- a/runtime/src/iree/hal/drivers/amdgpu/device/kernel_tables.h
+++ b/runtime/src/iree/hal/drivers/amdgpu/device/kernel_tables.h
@@ -8,6 +8,9 @@
 // Blits (blit.h)
 //===----------------------------------------------------------------------===//
 
+// Conservative metadata defaults used when loading builtin kernel descriptors.
+// Transfer packet emission overrides the X dimension from runtime wavefront
+// metadata so the same code object table can support wave32 and wave64 devices.
 #define IREE_HAL_AMDGPU_BLIT_WORKGROUP_SIZE_X 32
 #define IREE_HAL_AMDGPU_BLIT_WORKGROUP_SIZE_Y 1
 #define IREE_HAL_AMDGPU_BLIT_WORKGROUP_SIZE_Z 1
@@ -32,75 +35,43 @@
                               IREE_HAL_AMDGPU_BLIT_WORKGROUP_SIZE_X,
                               IREE_HAL_AMDGPU_BLIT_WORKGROUP_SIZE_Y,
                               IREE_HAL_AMDGPU_BLIT_WORKGROUP_SIZE_Z)
+IREE_HAL_AMDGPU_DEVICE_KERNEL(
+    iree_hal_amdgpu_device_buffer_fill_block_unaligned_x16,
+    IREE_HAL_AMDGPU_BLIT_WORKGROUP_SIZE_X,
+    IREE_HAL_AMDGPU_BLIT_WORKGROUP_SIZE_Y,
+    IREE_HAL_AMDGPU_BLIT_WORKGROUP_SIZE_Z)
 IREE_HAL_AMDGPU_DEVICE_KERNEL(iree_hal_amdgpu_device_buffer_copy_x1,
                               IREE_HAL_AMDGPU_BLIT_WORKGROUP_SIZE_X,
                               IREE_HAL_AMDGPU_BLIT_WORKGROUP_SIZE_Y,
                               IREE_HAL_AMDGPU_BLIT_WORKGROUP_SIZE_Z)
+IREE_HAL_AMDGPU_DEVICE_KERNEL(iree_hal_amdgpu_device_buffer_copy_block_x4,
+                              IREE_HAL_AMDGPU_BLIT_WORKGROUP_SIZE_X,
+                              IREE_HAL_AMDGPU_BLIT_WORKGROUP_SIZE_Y,
+                              IREE_HAL_AMDGPU_BLIT_WORKGROUP_SIZE_Z)
+IREE_HAL_AMDGPU_DEVICE_KERNEL(iree_hal_amdgpu_device_buffer_copy_block_x8,
+                              IREE_HAL_AMDGPU_BLIT_WORKGROUP_SIZE_X,
+                              IREE_HAL_AMDGPU_BLIT_WORKGROUP_SIZE_Y,
+                              IREE_HAL_AMDGPU_BLIT_WORKGROUP_SIZE_Z)
 IREE_HAL_AMDGPU_DEVICE_KERNEL(iree_hal_amdgpu_device_buffer_copy_block_x16,
                               IREE_HAL_AMDGPU_BLIT_WORKGROUP_SIZE_X,
                               IREE_HAL_AMDGPU_BLIT_WORKGROUP_SIZE_Y,
                               IREE_HAL_AMDGPU_BLIT_WORKGROUP_SIZE_Z)
+IREE_HAL_AMDGPU_DEVICE_KERNEL(
+    iree_hal_amdgpu_device_buffer_copy_block_unaligned_x16,
+    IREE_HAL_AMDGPU_BLIT_WORKGROUP_SIZE_X,
+    IREE_HAL_AMDGPU_BLIT_WORKGROUP_SIZE_Y,
+    IREE_HAL_AMDGPU_BLIT_WORKGROUP_SIZE_Z)
 
 //===----------------------------------------------------------------------===//
-// Command buffers (command_buffer.h)
+// Dispatch helpers (dispatch.h)
 //===----------------------------------------------------------------------===//
 
-// TODO(benvanik): evaluate the optimal size for issue workgroup size.
-// Lower sizes (ideally 1) are the most reliable on current hardware that does
-// not allow for divergent threads _and_ the assumption that we have a mix of
-// commands that causes each thread to diverge, but that's a guess. We may find
-// that since 90+% of packets are dispatches we're mostly running the same code
-// paths per command and can benefit from thread-level parallelism.
-#define IREE_HAL_AMDGPU_CMD_ISSUE_WORKGROUP_SIZE_X 32
-#define IREE_HAL_AMDGPU_CMD_ISSUE_WORKGROUP_SIZE_Y 1
-#define IREE_HAL_AMDGPU_CMD_ISSUE_WORKGROUP_SIZE_Z 1
-
-#define IREE_HAL_AMDGPU_CMD_CONTROL_WORKGROUP_SIZE_X 1
-#define IREE_HAL_AMDGPU_CMD_CONTROL_WORKGROUP_SIZE_Y 1
-#define IREE_HAL_AMDGPU_CMD_CONTROL_WORKGROUP_SIZE_Z 1
-
-// NOTE: these workgroup sizes are guesses and need to be changed.
-IREE_HAL_AMDGPU_DEVICE_KERNEL(iree_hal_amdgpu_device_cmd_block_issue,
-                              IREE_HAL_AMDGPU_CMD_ISSUE_WORKGROUP_SIZE_X,
-                              IREE_HAL_AMDGPU_CMD_ISSUE_WORKGROUP_SIZE_Y,
-                              IREE_HAL_AMDGPU_CMD_ISSUE_WORKGROUP_SIZE_Z)
-IREE_HAL_AMDGPU_DEVICE_KERNEL(iree_hal_amdgpu_device_cmd_dispatch_update,
-                              IREE_HAL_AMDGPU_CMD_CONTROL_WORKGROUP_SIZE_X,
-                              IREE_HAL_AMDGPU_CMD_CONTROL_WORKGROUP_SIZE_Y,
-                              IREE_HAL_AMDGPU_CMD_CONTROL_WORKGROUP_SIZE_Z)
-IREE_HAL_AMDGPU_DEVICE_KERNEL(iree_hal_amdgpu_device_cmd_branch,
-                              IREE_HAL_AMDGPU_CMD_CONTROL_WORKGROUP_SIZE_X,
-                              IREE_HAL_AMDGPU_CMD_CONTROL_WORKGROUP_SIZE_Y,
-                              IREE_HAL_AMDGPU_CMD_CONTROL_WORKGROUP_SIZE_Z)
-IREE_HAL_AMDGPU_DEVICE_KERNEL(iree_hal_amdgpu_device_cmd_cond_branch,
-                              IREE_HAL_AMDGPU_CMD_CONTROL_WORKGROUP_SIZE_X,
-                              IREE_HAL_AMDGPU_CMD_CONTROL_WORKGROUP_SIZE_Y,
-                              IREE_HAL_AMDGPU_CMD_CONTROL_WORKGROUP_SIZE_Z)
-IREE_HAL_AMDGPU_DEVICE_KERNEL(iree_hal_amdgpu_device_cmd_return,
-                              IREE_HAL_AMDGPU_CMD_CONTROL_WORKGROUP_SIZE_X,
-                              IREE_HAL_AMDGPU_CMD_CONTROL_WORKGROUP_SIZE_Y,
-                              IREE_HAL_AMDGPU_CMD_CONTROL_WORKGROUP_SIZE_Z)
+IREE_HAL_AMDGPU_DEVICE_KERNEL(
+    iree_hal_amdgpu_device_dispatch_patch_indirect_params, 1, 1, 1)
 
 //===----------------------------------------------------------------------===//
-// Scheduling (scheduler.h)
+// Timestamp helpers (timestamp.h)
 //===----------------------------------------------------------------------===//
 
-// NOTE: these workgroup sizes are guesses and need to be changed.
-#if 0
-IREE_HAL_AMDGPU_DEVICE_KERNEL(iree_hal_amdgpu_device_queue_scheduler_initialize,
-                              1, 1, 1)
-IREE_HAL_AMDGPU_DEVICE_KERNEL(iree_hal_amdgpu_device_queue_scheduler_tick, 1, 1,
-                              1)
-IREE_HAL_AMDGPU_DEVICE_KERNEL(iree_hal_amdgpu_device_queue_retire_entry, 1, 1,
-                              1)
-#endif
-
-//===----------------------------------------------------------------------===//
-// Tracing (tracing.h)
-//===----------------------------------------------------------------------===//
-
-#if 0
-// NOTE: these workgroup sizes are guesses and need to be changed.
-IREE_HAL_AMDGPU_DEVICE_KERNEL(iree_hal_amdgpu_device_trace_buffer_initialize,
-                              32, 1, 1)
-#endif
+IREE_HAL_AMDGPU_DEVICE_KERNEL(
+    iree_hal_amdgpu_device_timestamp_harvest_dispatch_records, 32, 1, 1)
diff --git a/runtime/src/iree/hal/drivers/amdgpu/device/kernels.h b/runtime/src/iree/hal/drivers/amdgpu/device/kernels.h
index e57cc80..8f1509c 100644
--- a/runtime/src/iree/hal/drivers/amdgpu/device/kernels.h
+++ b/runtime/src/iree/hal/drivers/amdgpu/device/kernels.h
@@ -7,8 +7,8 @@
 #ifndef IREE_HAL_DRIVERS_AMDGPU_DEVICE_KERNELS_H_
 #define IREE_HAL_DRIVERS_AMDGPU_DEVICE_KERNELS_H_
 
+#include "iree/hal/drivers/amdgpu/abi/kernel_args.h"
 #include "iree/hal/drivers/amdgpu/device/support/common.h"
-#include "iree/hal/drivers/amdgpu/device/support/kernel_args.h"
 
 //===----------------------------------------------------------------------===//
 // iree_hal_amdgpu_device_kernels_t
diff --git a/runtime/src/iree/hal/drivers/amdgpu/device/semaphore.c b/runtime/src/iree/hal/drivers/amdgpu/device/semaphore.c
deleted file mode 100644
index ab6027b..0000000
--- a/runtime/src/iree/hal/drivers/amdgpu/device/semaphore.c
+++ /dev/null
@@ -1,13 +0,0 @@
-// Copyright 2025 The IREE Authors
-//
-// Licensed under the Apache License v2.0 with LLVM Exceptions.
-// See https://llvm.org/LICENSE.txt for license information.
-// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
-
-#include "iree/hal/drivers/amdgpu/device/semaphore.h"
-
-//===----------------------------------------------------------------------===//
-// iree_hal_amdgpu_device_semaphore_t
-//===----------------------------------------------------------------------===//
-
-// TODO(benvanik): implement device semaphore logic.
diff --git a/runtime/src/iree/hal/drivers/amdgpu/device/semaphore.h b/runtime/src/iree/hal/drivers/amdgpu/device/semaphore.h
deleted file mode 100644
index e92f658..0000000
--- a/runtime/src/iree/hal/drivers/amdgpu/device/semaphore.h
+++ /dev/null
@@ -1,97 +0,0 @@
-// Copyright 2025 The IREE Authors
-//
-// Licensed under the Apache License v2.0 with LLVM Exceptions.
-// See https://llvm.org/LICENSE.txt for license information.
-// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
-
-#ifndef IREE_HAL_DRIVERS_AMDGPU_DEVICE_SEMAPHORE_H_
-#define IREE_HAL_DRIVERS_AMDGPU_DEVICE_SEMAPHORE_H_
-
-#include "iree/hal/drivers/amdgpu/device/support/common.h"
-#include "iree/hal/drivers/amdgpu/device/support/mutex.h"
-#include "iree/hal/drivers/amdgpu/device/support/signal.h"
-
-typedef struct iree_hal_amdgpu_device_semaphore_t
-    iree_hal_amdgpu_device_semaphore_t;
-
-//===----------------------------------------------------------------------===//
-// iree_hal_amdgpu_device_semaphore_t
-//===----------------------------------------------------------------------===//
-
-// A semaphore with an intrusive linked list of targets to wake.
-// Semaphores are an HSA signal plus tracking of waiters to allow direct wakes.
-// For semaphores only ever used on devices we avoid host interrupts and only
-// use our own wakes.
-//
-// Targets are registered with the minimum value that must be reached before
-// they can wake and notifications will wake all that are satisfied. The linked
-// list uses storage within the waiting targets.
-//
-// This optimizes for being able to wake multiple waiters when the signal is
-// notified and tries to allow the waiters to do the completion state polling.
-// This enables external wakes to be handled the same as ones from the wake list
-// and for us to not need an exhaustive wake list per semaphore per queue entry.
-// Instead, we only need to track what unique semaphores are being waited on
-// and ensure each waiter is in the list once with the minimum payload that
-// would cause them to wake. Since we may have long pipelines of work on a small
-// number of semaphores this optimizes for "wake and dequeue next" instead of
-// needing to walk the entire pipeline graph.
-//
-// When a waiter wants to register for notification they insert themselves into
-// the list. The insertion will fail if the last notified value is >= the
-// requested minimum value (as no signal will ever be made for it) and callers
-// can immediately continue processing. This allows the waiters to treat
-// insertions as the polling operation instead of having to check multiple
-// times. If the waiter is already in the wake list at a larger value then the
-// entry is removed and reinserted in the appropriate order. This is relatively
-// rare and handled on a slow path.
-//
-// Must be allocated in device/host shared memory if it will ever be used on the
-// host - and most may be. A outer wrapper iree_hal_semaphore_t owns the memory
-// and manages its lifetime and can pool this device-side block for reuse.
-//
-// Thread-safe. May be accessed from both host and device concurrently.
-// Zero initialization compatible.
-//
-// TODO(benvanik): make doubly-linked? insertion scan from the tail may be best
-// as usually we enqueue operations in order (1->2->3).
-typedef struct iree_hal_amdgpu_device_semaphore_t {
-  // HSA signal in device-visible memory.
-  // This may be a ROCR DefaultSignal (busy-wait) or InterruptSignal (event)
-  // based on the semaphore type.
-  IREE_AMDGPU_DEVICE_PTR iree_amd_signal_t* signal;
-
-  // A pointer back to the owning host iree_hal_amdgpu_*_semaphore_t.
-  // Used when asking the host to manipulate the semaphore.
-  uint64_t host_semaphore;
-
-  // TODO(benvanik): implement device-side semaphore data for both host and
-  // device queue modes.
-} iree_hal_amdgpu_device_semaphore_t;
-
-// A list of semaphores and the payload the semaphore is expected to reach or
-// be signaled to depending on the operation.
-typedef struct iree_hal_amdgpu_device_semaphore_list_t {
-  uint16_t count;
-  uint16_t reserved0;
-  uint32_t reserved1;  // could store wait state tracking
-  struct {
-    iree_hal_amdgpu_device_semaphore_t* semaphore;
-    uint64_t payload;
-  } entries[];
-} iree_hal_amdgpu_device_semaphore_list_t;
-
-// Returns the total size in bytes of a semaphore list storing |count| entries.
-static inline size_t iree_hal_amdgpu_device_semaphore_list_size(
-    uint16_t count) {
-  iree_hal_amdgpu_device_semaphore_list_t* list = NULL;
-  return sizeof(*list) + count * sizeof(list->entries[0]);
-}
-
-#if defined(IREE_AMDGPU_TARGET_DEVICE)
-
-// TODO(benvanik): implement device semaphore logic.
-
-#endif  // IREE_AMDGPU_TARGET_DEVICE
-
-#endif  // IREE_HAL_DRIVERS_AMDGPU_DEVICE_SEMAPHORE_H_
diff --git a/runtime/src/iree/hal/drivers/amdgpu/device/support/common.h b/runtime/src/iree/hal/drivers/amdgpu/device/support/common.h
index db6c5a5..26a318e 100644
--- a/runtime/src/iree/hal/drivers/amdgpu/device/support/common.h
+++ b/runtime/src/iree/hal/drivers/amdgpu/device/support/common.h
@@ -14,66 +14,21 @@
 #ifndef IREE_HAL_DRIVERS_AMDGPU_DEVICE_SUPPORT_COMMON_H_
 #define IREE_HAL_DRIVERS_AMDGPU_DEVICE_SUPPORT_COMMON_H_
 
+#include "iree/hal/drivers/amdgpu/abi/common.h"  // IWYU pragma: export
+
+#if !defined(IREE_AMDGPU_TARGET_DEVICE)
+#include "iree/base/internal/atomics.h"
+#include "iree/base/threading/thread.h"
+#endif  // !IREE_AMDGPU_TARGET_DEVICE
+
 //===----------------------------------------------------------------------===//
 // Compiler Configuration
 //===----------------------------------------------------------------------===//
 
-#if defined(__AMDGPU__)
-#define IREE_AMDGPU_TARGET_DEVICE 1
-#else
-#define IREE_AMDGPU_TARGET_HOST 1
-#endif  // __AMDGPU__
-
 #if defined(IREE_AMDGPU_TARGET_DEVICE)
 
-typedef char int8_t;
-typedef unsigned char uint8_t;
-typedef short int16_t;
-typedef unsigned short uint16_t;
-typedef int int32_t;
-typedef unsigned int uint32_t;
-typedef long int64_t;
-typedef unsigned long uint64_t;
-
-typedef int64_t ssize_t;
-typedef uint64_t size_t;
-typedef int64_t intptr_t;
-typedef uint64_t uintptr_t;
-
-#define UINT64_MAX 0xFFFFFFFFFFFFFFFFull
-
-#define NULL ((void*)0)
-
-#else
-
-// NOTE: minimal support for including headers in host code is provided to make
-// sharing enums/structures possible; no code is expected to compile.
-
-#include <stddef.h>
-#include <stdint.h>
-
-#include "iree/base/internal/atomics.h"
-#include "iree/base/threading/thread.h"
-#include "third_party/hsa-runtime-headers/include/hsa/hsa.h"  // IWYU pragma: export
-
-#endif  // IREE_AMDGPU_TARGET_DEVICE
-
-//===----------------------------------------------------------------------===//
-// Attributes
-//===----------------------------------------------------------------------===//
-
-#if defined(IREE_AMDGPU_TARGET_DEVICE)
-
-#define IREE_AMDGPU_RESTRICT __restrict__
-#define IREE_AMDGPU_ALIGNAS(x) __attribute__((aligned(x)))
-#define IREE_AMDGPU_ALIGNOF(x) __alignof__(x)
-
-#define IREE_AMDGPU_ATTRIBUTE_ALWAYS_INLINE __attribute__((always_inline))
 #define IREE_AMDGPU_ATTRIBUTE_SINGLE_WORK_ITEM
-#define IREE_AMDGPU_ATTRIBUTE_PACKED __attribute__((__packed__))
-
 #define IREE_AMDGPU_ATTRIBUTE_MUSTTAIL [[clang::musttail]]
-
 #define IREE_AMDGPU_ATTRIBUTE_KERNEL \
   [[clang::amdgpu_kernel, gnu::visibility("protected")]]
 
@@ -84,13 +39,7 @@
 
 #else
 
-#define IREE_AMDGPU_RESTRICT IREE_RESTRICT
-#define IREE_AMDGPU_ALIGNAS(x) iree_alignas(x)
-#define IREE_AMDGPU_ALIGNOF(x) iree_alignof(x)
-
-#define IREE_AMDGPU_ATTRIBUTE_ALWAYS_INLINE IREE_ATTRIBUTE_ALWAYS_INLINE
 #define IREE_AMDGPU_ATTRIBUTE_SINGLE_WORK_ITEM
-#define IREE_AMDGPU_ATTRIBUTE_PACKED IREE_ATTRIBUTE_PACKED
 
 #define IREE_AMDGPU_LIKELY(x) IREE_LIKELY(x)
 #define IREE_AMDGPU_UNLIKELY(x) IREE_UNLIKELY(x)
@@ -134,8 +83,6 @@
 
 #if defined(IREE_AMDGPU_TARGET_DEVICE)
 
-#define IREE_AMDGPU_OFFSETOF(type, field) __builtin_offsetof(type, field)
-
 // Returns the number of leading zeros in a 64-bit bitfield.
 // Returns -1 if no bits are set.
 // Commonly used in HIP as `__lastbit_u32_u64`.
@@ -149,8 +96,6 @@
 
 #else
 
-#define IREE_AMDGPU_OFFSETOF(type, field) offsetof(type, field)
-
 #define IREE_AMDGPU_LASTBIT_U64(v) \
   ((v) == 0 ? -1 : iree_math_count_trailing_zeros_u64(v))
 
@@ -298,11 +243,6 @@
 // Timing
 //===----------------------------------------------------------------------===//
 
-// Tick in the agent domain.
-// This can be converted to the system domain for correlation across agents and
-// the host with hsa_amd_profiling_convert_tick_to_system_domain.
-typedef uint64_t iree_amdgpu_device_tick_t;
-
 #if defined(IREE_AMDGPU_TARGET_DEVICE)
 
 // Returns a tick in the agent domain.
diff --git a/runtime/src/iree/hal/drivers/amdgpu/device/support/kernel.h b/runtime/src/iree/hal/drivers/amdgpu/device/support/kernel.h
new file mode 100644
index 0000000..98ee096
--- /dev/null
+++ b/runtime/src/iree/hal/drivers/amdgpu/device/support/kernel.h
@@ -0,0 +1,159 @@
+// Copyright 2026 The IREE Authors
+//
+// Licensed under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+// Device-side kernel execution geometry helpers built on the AMDGPU dispatch
+// packet ABI. For kernel argument and packet layouts see abi/kernel_args.h and
+// abi/queue.h (exported below).
+//
+// These helpers intentionally avoid depending on device queue mutation APIs.
+// Builtin kernels and tracing paths can include this header to inspect the
+// currently executing dispatch without inheriting the old device-side enqueue
+// surface.
+
+#ifndef IREE_HAL_DRIVERS_AMDGPU_DEVICE_SUPPORT_KERNEL_H_
+#define IREE_HAL_DRIVERS_AMDGPU_DEVICE_SUPPORT_KERNEL_H_
+
+#include "iree/hal/drivers/amdgpu/abi/kernel_args.h"  // IWYU pragma: export
+#include "iree/hal/drivers/amdgpu/abi/queue.h"        // IWYU pragma: export
+#include "iree/hal/drivers/amdgpu/device/support/common.h"
+
+//===----------------------------------------------------------------------===//
+// OpenCL/HIP Dispatch ABI
+//===----------------------------------------------------------------------===//
+// These come from llvm-project/amd/device-libs/ockl/src/workitem.cl (the ockl
+// functions) and llvm-project/clang/lib/CodeGen/CGBuiltin.cpp (e.g.
+// EmitAMDGPUWorkGroupSize). Using either risks pulling in the entire
+// iree_amdgpu_kernel_implicit_args_t struct, which we don't want to have to
+// populate. We also don't need it: we aren't requiring OpenCL compatibility
+// and have no need for the extra features the implicit args provide (like
+// workgroup offsets and device-side enqueue - that's our job).
+
+#if defined(IREE_AMDGPU_TARGET_DEVICE)
+
+// Returns the pointer to the iree_hsa_kernel_dispatch_packet_t being executed.
+#define iree_amdgcn_dispatch_ptr()                                 \
+  ((const iree_hsa_kernel_dispatch_packet_t* IREE_AMDGPU_RESTRICT) \
+       __builtin_amdgcn_dispatch_ptr())
+
+// __ockl_get_global_id(0) / get_global_id_x using OLD_ABI.
+static inline IREE_AMDGPU_ATTRIBUTE_ALWAYS_INLINE size_t
+iree_hal_amdgpu_device_global_id_x(void) {
+  const uint32_t local_id = __builtin_amdgcn_workitem_id_x();
+  const uint32_t group_id = __builtin_amdgcn_workgroup_id_x();
+  const uint32_t group_size = iree_amdgcn_dispatch_ptr()->workgroup_size[0];
+  return group_id * group_size + local_id;
+}
+static inline IREE_AMDGPU_ATTRIBUTE_ALWAYS_INLINE size_t
+iree_hal_amdgpu_device_global_id_y(void) {
+  const uint32_t local_id = __builtin_amdgcn_workitem_id_y();
+  const uint32_t group_id = __builtin_amdgcn_workgroup_id_y();
+  const uint32_t group_size = iree_amdgcn_dispatch_ptr()->workgroup_size[1];
+  return group_id * group_size + local_id;
+}
+static inline IREE_AMDGPU_ATTRIBUTE_ALWAYS_INLINE size_t
+iree_hal_amdgpu_device_global_id_z(void) {
+  const uint32_t local_id = __builtin_amdgcn_workitem_id_z();
+  const uint32_t group_id = __builtin_amdgcn_workgroup_id_z();
+  const uint32_t group_size = iree_amdgcn_dispatch_ptr()->workgroup_size[2];
+  return group_id * group_size + local_id;
+}
+
+// __ockl_get_group_id(0)
+#define iree_hal_amdgpu_device_group_id_x() __builtin_amdgcn_workgroup_id_x()
+#define iree_hal_amdgpu_device_group_id_y() __builtin_amdgcn_workgroup_id_y()
+#define iree_hal_amdgpu_device_group_id_z() __builtin_amdgcn_workgroup_id_z()
+
+// __ockl_get_num_groups(0)
+static inline IREE_AMDGPU_ATTRIBUTE_ALWAYS_INLINE size_t
+iree_hal_amdgpu_device_group_count_x(void) {
+  const uint32_t grid_size = iree_amdgcn_dispatch_ptr()->grid_size[0];
+  const uint32_t group_size = iree_amdgcn_dispatch_ptr()->workgroup_size[0];
+  const uint32_t group_count = grid_size / group_size;
+  return group_count + (grid_size > group_count * group_size);
+}
+static inline IREE_AMDGPU_ATTRIBUTE_ALWAYS_INLINE size_t
+iree_hal_amdgpu_device_group_count_y(void) {
+  const uint32_t grid_size = iree_amdgcn_dispatch_ptr()->grid_size[1];
+  const uint32_t group_size = iree_amdgcn_dispatch_ptr()->workgroup_size[1];
+  const uint32_t group_count = grid_size / group_size;
+  return group_count + (grid_size > group_count * group_size);
+}
+static inline IREE_AMDGPU_ATTRIBUTE_ALWAYS_INLINE size_t
+iree_hal_amdgpu_device_group_count_z(void) {
+  const uint32_t grid_size = iree_amdgcn_dispatch_ptr()->grid_size[2];
+  const uint32_t group_size = iree_amdgcn_dispatch_ptr()->workgroup_size[2];
+  const uint32_t group_count = grid_size / group_size;
+  return group_count + (grid_size > group_count * group_size);
+}
+
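+// Example: grid_size = 100 with workgroup_size = 32 gives 100 / 32 = 3 full
+// groups plus a partial group (100 > 3 * 32), so the helpers above return 4.
+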
+// __ockl_get_local_id(0)
+#define iree_hal_amdgpu_device_local_id_x() __builtin_amdgcn_workitem_id_x()
+#define iree_hal_amdgpu_device_local_id_y() __builtin_amdgcn_workitem_id_y()
+#define iree_hal_amdgpu_device_local_id_z() __builtin_amdgcn_workitem_id_z()
+
+// __ockl_get_local_size(0) / get_local_size_x
+static inline IREE_AMDGPU_ATTRIBUTE_ALWAYS_INLINE size_t
+iree_hal_amdgpu_device_workgroup_size_x(void) {
+  const uint32_t group_id = __builtin_amdgcn_workgroup_id_x();
+  const uint32_t group_size = iree_amdgcn_dispatch_ptr()->workgroup_size[0];
+  const uint32_t grid_size = iree_amdgcn_dispatch_ptr()->grid_size[0];
+  const uint32_t remainder = grid_size - group_id * group_size;
+  return IREE_AMDGPU_MIN(remainder, group_size);
+}
+static inline IREE_AMDGPU_ATTRIBUTE_ALWAYS_INLINE size_t
+iree_hal_amdgpu_device_workgroup_size_y(void) {
+  const uint32_t group_id = __builtin_amdgcn_workgroup_id_y();
+  const uint32_t group_size = iree_amdgcn_dispatch_ptr()->workgroup_size[1];
+  const uint32_t grid_size = iree_amdgcn_dispatch_ptr()->grid_size[1];
+  const uint32_t remainder = grid_size - group_id * group_size;
+  return IREE_AMDGPU_MIN(remainder, group_size);
+}
+static inline IREE_AMDGPU_ATTRIBUTE_ALWAYS_INLINE size_t
+iree_hal_amdgpu_device_workgroup_size_z(void) {
+  const uint32_t group_id = __builtin_amdgcn_workgroup_id_z();
+  const uint32_t group_size = iree_amdgcn_dispatch_ptr()->workgroup_size[2];
+  const uint32_t grid_size = iree_amdgcn_dispatch_ptr()->grid_size[2];
+  const uint32_t remainder = grid_size - group_id * group_size;
+  return IREE_AMDGPU_MIN(remainder, group_size);
+}
+
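+// Example: with grid_size = 100 and workgroup_size = 32 the final group
+// (group_id 3) computes remainder = 100 - 3 * 32 = 4 and reports a local
+// size of 4, while all earlier groups report the full 32.
+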
+static inline IREE_AMDGPU_ATTRIBUTE_ALWAYS_INLINE size_t
+iree_hal_amdgpu_device_global_linear_id_1d(void) {
+  return iree_hal_amdgpu_device_group_id_x() *
+             iree_amdgcn_dispatch_ptr()->workgroup_size[0] +
+         iree_hal_amdgpu_device_local_id_x();
+}
+
+static inline IREE_AMDGPU_ATTRIBUTE_ALWAYS_INLINE size_t
+iree_hal_amdgpu_device_global_linear_id_2d(void) {
+  const size_t id_x = iree_hal_amdgpu_device_group_id_x() *
+                          iree_amdgcn_dispatch_ptr()->workgroup_size[0] +
+                      iree_hal_amdgpu_device_local_id_x();
+  const size_t id_y = iree_hal_amdgpu_device_group_id_y() *
+                          iree_amdgcn_dispatch_ptr()->workgroup_size[1] +
+                      iree_hal_amdgpu_device_local_id_y();
+  return id_y * iree_amdgcn_dispatch_ptr()->grid_size[0] + id_x;
+}
+
+static inline IREE_AMDGPU_ATTRIBUTE_ALWAYS_INLINE size_t
+iree_hal_amdgpu_device_global_linear_id_3d(void) {
+  const size_t id_x = iree_hal_amdgpu_device_group_id_x() *
+                          iree_amdgcn_dispatch_ptr()->workgroup_size[0] +
+                      iree_hal_amdgpu_device_local_id_x();
+  const size_t id_y = iree_hal_amdgpu_device_group_id_y() *
+                          iree_amdgcn_dispatch_ptr()->workgroup_size[1] +
+                      iree_hal_amdgpu_device_local_id_y();
+  const size_t id_z = iree_hal_amdgpu_device_group_id_z() *
+                          iree_amdgcn_dispatch_ptr()->workgroup_size[2] +
+                      iree_hal_amdgpu_device_local_id_z();
+  return (id_z * iree_amdgcn_dispatch_ptr()->grid_size[1] + id_y) *
+             iree_amdgcn_dispatch_ptr()->grid_size[0] +
+         id_x;
+}
+
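+// The linear ids flatten in row-major order: with grid_size = {W, H, D} the
+// 2D form computes id_y * W + id_x and the 3D form computes
+// (id_z * H + id_y) * W + id_x.
+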
+#endif  // IREE_AMDGPU_TARGET_DEVICE
+
+#endif  // IREE_HAL_DRIVERS_AMDGPU_DEVICE_SUPPORT_KERNEL_H_
diff --git a/runtime/src/iree/hal/drivers/amdgpu/device/support/kernel_args.h b/runtime/src/iree/hal/drivers/amdgpu/device/support/kernel_args.h
deleted file mode 100644
index f150a2a..0000000
--- a/runtime/src/iree/hal/drivers/amdgpu/device/support/kernel_args.h
+++ /dev/null
@@ -1,56 +0,0 @@
-// Copyright 2025 The IREE Authors
-//
-// Licensed under the Apache License v2.0 with LLVM Exceptions.
-// See https://llvm.org/LICENSE.txt for license information.
-// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
-
-#ifndef IREE_HAL_DRIVERS_AMDGPU_DEVICE_SUPPORT_KERNEL_ARGS_H_
-#define IREE_HAL_DRIVERS_AMDGPU_DEVICE_SUPPORT_KERNEL_ARGS_H_
-
-#include "iree/hal/drivers/amdgpu/device/support/common.h"
-
-//===----------------------------------------------------------------------===//
-// iree_hal_amdgpu_device_kernel_args_t
-//===----------------------------------------------------------------------===//
-
-// Kernel arguments used for fixed-size kernels.
-// This must match what the kernel was compiled to support.
-typedef struct iree_hal_amdgpu_device_kernel_args_t {
-  // Opaque handle to the kernel object to execute.
-  uint64_t kernel_object;
-  // Dispatch setup parameters. Used to configure kernel dispatch parameters
-  // such as the number of dimensions in the grid. The parameters are
-  // described by hsa_kernel_dispatch_packet_setup_t.
-  uint16_t setup;
-  // XYZ dimensions of work-group, in work-items. Must be greater than 0.
-  // If the grid has fewer than 3 dimensions the unused must be 1.
-  uint16_t workgroup_size[3];
-  // Size in bytes of private memory allocation request (per work-item).
-  uint32_t private_segment_size;
-  // Size in bytes of group memory allocation request (per work-group). Must
-  // not be less than the sum of the group memory used by the kernel (and the
-  // functions it calls directly or indirectly) and the dynamically allocated
-  // group segment variables.
-  uint32_t group_segment_size;
-  // Size of kernarg segment memory that is required to hold the values of the
-  // kernel arguments, in bytes. Must be a multiple of 16.
-  uint16_t kernarg_size;
-  // Alignment (in bytes) of the buffer used to pass arguments to the kernel,
-  // which is the maximum of 16 and the maximum alignment of any of the kernel
-  // arguments.
-  uint16_t kernarg_alignment;
-  // Allocated source location in host memory. Inaccessible and only here to
-  // feed back to the host for trace processing.
-  uint64_t trace_src_loc;
-  // Total number of 4-byte constants used by the dispatch (if a HAL dispatch).
-  uint16_t constant_count;
-  // Total number of bindings used by the dispatch (if a HAL dispatch).
-  uint16_t binding_count;
-  uint32_t reserved;
-} iree_hal_amdgpu_device_kernel_args_t;
-static_assert(
-    sizeof(iree_hal_amdgpu_device_kernel_args_t) <= 64,
-    "keep hot kernel arg structure in as few cache lines as possible; every "
-    "dispatch issued must access this information and it is likely uncached");
-
-#endif  // IREE_HAL_DRIVERS_AMDGPU_DEVICE_SUPPORT_KERNEL_ARGS_H_
diff --git a/runtime/src/iree/hal/drivers/amdgpu/device/support/mutex.h b/runtime/src/iree/hal/drivers/amdgpu/device/support/mutex.h
deleted file mode 100644
index 5973d64..0000000
--- a/runtime/src/iree/hal/drivers/amdgpu/device/support/mutex.h
+++ /dev/null
@@ -1,68 +0,0 @@
-// Copyright 2025 The IREE Authors
-//
-// Licensed under the Apache License v2.0 with LLVM Exceptions.
-// See https://llvm.org/LICENSE.txt for license information.
-// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
-
-#ifndef IREE_HAL_DRIVERS_AMDGPU_DEVICE_SUPPORT_MUTEX_H_
-#define IREE_HAL_DRIVERS_AMDGPU_DEVICE_SUPPORT_MUTEX_H_
-
-#include "iree/hal/drivers/amdgpu/device/support/common.h"
-
-//===----------------------------------------------------------------------===//
-// iree_hal_amdgpu_device_mutex_t
-//===----------------------------------------------------------------------===//
-
-// Device spin-lock mutex.
-// This can run on the host as well but is optimized for device usage. Spinning
-// on the host is a bad idea. Spinning on the device is _also_ a bad idea, but
-// does have its uses.
-//
-// Note that because atomics are not guaranteed to work off-agent this is only
-// to be used for intra-agent exclusion such as when multiple queues on the
-// same agent are sharing a data structure.
-//
-// Reference: https://rigtorp.se/spinlock/
-typedef iree_amdgpu_scoped_atomic_uint32_t iree_hal_amdgpu_device_mutex_t;
-
-#define IREE_HAL_AMDGPU_DEVICE_MUTEX_UNLOCKED 0u
-#define IREE_HAL_AMDGPU_DEVICE_MUTEX_LOCKED 1u
-
-// Initializes a mutex to the unlocked state.
-static inline void iree_hal_amdgpu_device_mutex_initialize(
-    iree_hal_amdgpu_device_mutex_t* IREE_AMDGPU_RESTRICT out_mutex) {
-  uint32_t initial_value = IREE_HAL_AMDGPU_DEVICE_MUTEX_UNLOCKED;
-  IREE_AMDGPU_SCOPED_ATOMIC_INIT(out_mutex, initial_value);
-}
-
-// Spins until a lock on the mutex is acquired.
-static inline void iree_hal_amdgpu_device_mutex_lock(
-    iree_hal_amdgpu_device_mutex_t* IREE_AMDGPU_RESTRICT mutex) {
-  for (;;) {
-    // Optimistically assume the lock is free on the first try.
-    uint32_t prev = IREE_HAL_AMDGPU_DEVICE_MUTEX_UNLOCKED;
-    if (iree_amdgpu_scoped_atomic_compare_exchange_strong(
-            mutex, &prev, IREE_HAL_AMDGPU_DEVICE_MUTEX_LOCKED,
-            iree_amdgpu_memory_order_acquire, iree_amdgpu_memory_order_acquire,
-            iree_amdgpu_memory_scope_system)) {
-      return;
-    }
-    // Wait for lock to be released without generating cache misses.
-    while (iree_amdgpu_scoped_atomic_load(mutex,
-                                          iree_amdgpu_memory_order_relaxed,
-                                          iree_amdgpu_memory_scope_system)) {
-      // Yield for a bit to give the other thread a chance to unlock.
-      iree_amdgpu_yield();
-    }
-  }
-}
-
-// Unlocks a mutex. Must be called with the lock held by the caller.
-static inline void iree_hal_amdgpu_device_mutex_unlock(
-    iree_hal_amdgpu_device_mutex_t* IREE_AMDGPU_RESTRICT mutex) {
-  iree_amdgpu_scoped_atomic_store(mutex, IREE_HAL_AMDGPU_DEVICE_MUTEX_UNLOCKED,
-                                  iree_amdgpu_memory_order_release,
-                                  iree_amdgpu_memory_scope_system);
-}
-
-#endif  // IREE_HAL_DRIVERS_AMDGPU_DEVICE_SUPPORT_MUTEX_H_
diff --git a/runtime/src/iree/hal/drivers/amdgpu/device/support/queue.h b/runtime/src/iree/hal/drivers/amdgpu/device/support/queue.h
index 6a4ad4b..bcc5165 100644
--- a/runtime/src/iree/hal/drivers/amdgpu/device/support/queue.h
+++ b/runtime/src/iree/hal/drivers/amdgpu/device/support/queue.h
@@ -4,10 +4,9 @@
 // See https://llvm.org/LICENSE.txt for license information.
 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
 //
-// NOTE: these structs/enums are taken from the HSA spec, the hsa.h and
-// hsa_ext_amd.h headers, and the LLVM AMDGPU device library headers.
-// We define them locally as the HSA headers cannot be directly used in
-// bare-metal C and the device libraries are only available in a fork of LLM.
+// Device-side queue manipulation functions and optimized queue access built on
+// top of the ABI queue types. Kernel dispatch/work-item helpers live in
+// device/support/kernel.h, which is exported below for compatibility.
 //
 // Sources:
 // https://hsafoundation.com/wp-content/uploads/2021/02/HSA-SysArch-1.2.pdf
@@ -17,120 +16,10 @@
 #ifndef IREE_HAL_DRIVERS_AMDGPU_DEVICE_SUPPORT_QUEUE_H_
 #define IREE_HAL_DRIVERS_AMDGPU_DEVICE_SUPPORT_QUEUE_H_
 
-#include "iree/hal/drivers/amdgpu/device/support/common.h"
+#include "iree/hal/drivers/amdgpu/device/support/kernel.h"  // IWYU pragma: export
 #include "iree/hal/drivers/amdgpu/device/support/signal.h"
 
 //===----------------------------------------------------------------------===//
-// HSA/AMDGPU AQL Queue
-//===----------------------------------------------------------------------===//
-
-typedef enum {
-  // Queue supports multiple producers.
-  IREE_HSA_QUEUE_TYPE_MULTI = 0,
-  // Queue only supports a single producer.
-  IREE_HSA_QUEUE_TYPE_SINGLE = 1,
-} iree_hsa_queue_type_t;
-
-// NOTE: this is not our struct and we cannot change it.
-typedef struct iree_hsa_queue_t {
-  // Queue type.
-  iree_hsa_queue_type_t type;
-
-  // Queue features mask. This is a bit-field of iree_hsa_queue_feature_t
-  // values. Applications should ignore any unknown set bits.
-  uint32_t features;
-
-  // Packet storage. Must be accessible on any agents that may operate on it and
-  // aligned to at least 64 (the size of an AQL packet).
-  void* base_address;
-
-  // Signal object used by the application to indicate the ID of a packet that
-  // is ready to be processed. The HSA runtime or hardware packet processor
-  // manages the doorbell signal. If the application tries to replace or destroy
-  // this signal the behavior is undefined.
-  //
-  // If type is HSA_QUEUE_TYPE_SINGLE the doorbell signal value must be
-  // updated in a monotonically increasing fashion. If type is
-  // HSA_QUEUE_TYPE_MULTI the doorbell signal value can be updated with any
-  // value and the act of writing a differing value is enough to wake the
-  // processor. On AMD GPUs today it is reportedly not any more efficient to
-  // use SINGLE queues as the packet processor handles both the same way.
-  iree_hsa_signal_t doorbell_signal;
-
-  // Maximum number of packets the queue can hold. Must be a power of 2.
-  uint32_t size;
-
-  uint32_t reserved1;  // must be 0
-
-  // Queue identifier, which is unique over the lifetime of the application even
-  // if the queue is reallocated.
-  uint64_t id;
-} iree_hsa_queue_t;
-
-#define IREE_AMD_HSA_BITS_CREATE_ENUM_ENTRIES(name, shift, width) \
-  name##_SHIFT = (shift), name##_WIDTH = (width),                 \
-  name = (((1 << (width)) - 1) << (shift))
-#define IREE_AMD_HSA_BITS_SET(dst, mask, val) \
-  dst &= (~(1 << mask##_SHIFT) & ~mask);      \
-  dst |= (((val) << mask##_SHIFT) & mask)
-
-enum iree_amd_queue_properties_t {
-  IREE_AMD_HSA_BITS_CREATE_ENUM_ENTRIES(
-      IREE_AMD_QUEUE_PROPERTIES_ENABLE_TRAP_HANDLER, 0, 1),
-  // All devices we care about are 64-bit.
-  IREE_AMD_HSA_BITS_CREATE_ENUM_ENTRIES(IREE_AMD_QUEUE_PROPERTIES_IS_PTR64, 1,
-                                        1),
-  IREE_AMD_HSA_BITS_CREATE_ENUM_ENTRIES(
-      IREE_AMD_QUEUE_PROPERTIES_ENABLE_TRAP_HANDLER_DEBUG_SGPRS, 2, 1),
-  // Timestamps will be stored on signals (start_ts/end_ts).
-  IREE_AMD_HSA_BITS_CREATE_ENUM_ENTRIES(
-      IREE_AMD_QUEUE_PROPERTIES_ENABLE_PROFILING, 3, 1),
-  IREE_AMD_HSA_BITS_CREATE_ENUM_ENTRIES(
-      IREE_AMD_QUEUE_PROPERTIES_USE_SCRATCH_ONCE, 4, 1),
-  IREE_AMD_HSA_BITS_CREATE_ENUM_ENTRIES(IREE_AMD_QUEUE_PROPERTIES_RESERVED1, 5,
-                                        27)
-};
-typedef uint32_t iree_amd_queue_properties32_t;
-
-// An AQL packet queue.
-// We generally treat these as opaque except for if we need to read queue
-// properties to check modes - otherwise we just treat any queue handle as
-// an iree_hsa_queue_t.
-typedef struct IREE_AMDGPU_ALIGNAS(64) iree_amd_queue_t {
-  iree_hsa_queue_t hsa_queue;
-  uint32_t caps;
-  uint32_t reserved1[3];
-  volatile uint64_t write_dispatch_id;
-  uint32_t group_segment_aperture_base_hi;
-  uint32_t private_segment_aperture_base_hi;
-  uint32_t max_cu_id;
-  uint32_t max_wave_id;
-  volatile uint64_t max_legacy_doorbell_dispatch_id_plus_1;
-  volatile uint32_t legacy_doorbell_lock;
-  uint32_t reserved2[9];
-  volatile uint64_t read_dispatch_id;
-  uint32_t read_dispatch_id_field_base_byte_offset;
-  uint32_t compute_tmpring_size;
-  uint32_t scratch_resource_descriptor[4];
-  uint64_t scratch_backing_memory_location;
-  uint64_t scratch_backing_memory_byte_size;
-  uint32_t scratch_workitem_byte_size;
-  iree_amd_queue_properties32_t queue_properties;
-  volatile uint64_t scratch_last_used_index; /* async-reclaim */
-  iree_hsa_signal_t queue_inactive_signal;
-  uint32_t reserved4[2];
-  volatile uint64_t alt_scratch_last_used_index; /* async-reclaim */
-  uint64_t alt_scratch_backing_memory_location;  /* async-reclaim */
-  uint64_t alt_scratch_backing_memory_byte_size; /* async-reclaim */
-  uint32_t alt_scratch_dispatch_limit_x;         /* async-reclaim */
-  uint32_t alt_scratch_dispatch_limit_y;         /* async-reclaim */
-  uint32_t alt_scratch_dispatch_limit_z;         /* async-reclaim */
-  uint32_t alt_scratch_wave64_lane_byte_size;    /* async-reclaim */
-  uint32_t alt_compute_tmpring_size;             /* async-reclaim */
-  uint32_t reserved5;
-} iree_amd_queue_t;
-
-//===----------------------------------------------------------------------===//
 // Cached HSA/AMD Queue
 //===----------------------------------------------------------------------===//
 
@@ -171,7 +60,7 @@
   iree_amd_queue_t* queue;
 } iree_amd_cached_queue_t;
 
-// Returns an HSA queue reference with the important
+// Returns a cached queue with the hot fields hoisted from the full HSA queue.
 static inline iree_amd_cached_queue_t iree_amd_make_cached_queue(
     iree_hsa_queue_t* queue) {
   iree_amd_cached_queue_t result = {
@@ -191,701 +80,6 @@
 }
 
 //===----------------------------------------------------------------------===//
-// HSA/AMDGPU AQL Packets
-//===----------------------------------------------------------------------===//
-
-typedef enum {
-  // Handled entirely by the packet processor and will vary agent to agent.
-  IREE_HSA_PACKET_TYPE_VENDOR_SPECIFIC = 0,
-  // Invalid packet (not yet populated) that will stall the packet processor.
-  IREE_HSA_PACKET_TYPE_INVALID = 1,
-  // iree_hsa_kernel_dispatch_packet_t
-  IREE_HSA_PACKET_TYPE_KERNEL_DISPATCH = 2,
-  // iree_hsa_barrier_and_packet_t
-  IREE_HSA_PACKET_TYPE_BARRIER_AND = 3,
-  // iree_hsa_agent_dispatch_packet_t
-  IREE_HSA_PACKET_TYPE_AGENT_DISPATCH = 4,
-  // iree_hsa_barrier_or_packet_t
-  IREE_HSA_PACKET_TYPE_BARRIER_OR = 5,
-} iree_hsa_packet_type_t;
-
-// Bit offsets within the header word of various values.
-// We have to perform the bit manipulation ourselves because OpenCL has no
-// bitfields. Crazy.
-//
-// If we did have bitfields the struct would look like:
-// typedef struct {
-//   uint16_t type : 8;
-//   uint16_t barrier : 1;
-//   uint16_t scacquire_fence_scope : 2;
-//   uint16_t screlease_fence_scope : 2;
-//   uint16_t reserved : 3;  // must be 0
-// } iree_hsa_packet_header_t;
-//
-// Since the smallest atomic width is 32 bits and this header is 16 bits, any
-// operation updating the header must include the subsequent 16 bits of the
-// packet (e.g. the setup field for kernel dispatches).
-//
-// See spec 2.9.1 and child entries for the full details.
-typedef enum {
-  // Determines the packet type as processed by the packet processor.
-  // The header is the same for all packets but all other following contents may
-  // change.
-  IREE_HSA_PACKET_HEADER_TYPE = 0,
-  // If set then processing of the packet will only begin when all preceding
-  // packets are complete. There is no implicit fence defined as part of the
-  // barrier and an acquire fence scope must still be specified if any is
-  // required.
-  IREE_HSA_PACKET_HEADER_BARRIER = 8,
-  // A packet memory acquire fence ensures any subsequent global segment or
-  // image loads by any unit of execution that belongs to a dispatch that has
-  // not yet entered the active phase on any queue of the same agent, sees any
-  // data previously released at the scopes specified by the packet acquire
-  // fence.
-  //
-  // Behavior:
-  //   IREE_HSA_FENCE_SCOPE_NONE:
-  //     No fence is applied and the packet relies on an earlier acquire fence
-  //     performed on the agent or acquire fences within the operation (e.g. by
-  //     the kernel).
-  //   IREE_HSA_FENCE_SCOPE_AGENT:
-  //     The acquire fence is applied with agent scope for the global segment.
-  //   IREE_HSA_FENCE_SCOPE_SYSTEM:
-  //     The acquire fence is applied across both agent and system scope for the
-  //     global segment.
-  IREE_HSA_PACKET_HEADER_SCACQUIRE_FENCE_SCOPE = 9,
-  // A packet memory release fence makes any global segment or image data that
-  // was stored by any unit of execution that belonged to a dispatch that has
-  // completed the active phase on any queue of the same agent visible in all
-  // the scopes specified by the packet release fence.
-  //
-  // Behavior:
-  //   IREE_HSA_FENCE_SCOPE_NONE:
-  //     No fence is applied and the packet relies on a later release fence
-  //     performed on the agent or release fences within the operation (e.g. by
-  //     the kernel).
-  //   IREE_HSA_FENCE_SCOPE_AGENT:
-  //     The release fence is applied with agent scope for the global segment.
-  //   IREE_HSA_FENCE_SCOPE_SYSTEM:
-  //     The release fence is applied across both agent and system scope for the
-  //     global segment.
-  IREE_HSA_PACKET_HEADER_SCRELEASE_FENCE_SCOPE = 11,
-} iree_hsa_packet_header_t;
-
-// Width in bits of the sub-fields in iree_hsa_packet_header_t.
-typedef enum {
-  IREE_HSA_PACKET_HEADER_WIDTH_TYPE = 8,
-  IREE_HSA_PACKET_HEADER_WIDTH_BARRIER = 1,
-  IREE_HSA_PACKET_HEADER_WIDTH_SCACQUIRE_FENCE_SCOPE = 2,
-  IREE_HSA_PACKET_HEADER_WIDTH_SCRELEASE_FENCE_SCOPE = 2,
-} iree_hsa_packet_header_width_t;
-
-// Forms a 16-bit AQL packet header.
-#define iree_hsa_make_packet_header(type, is_barrier, scacquire_fence_scope,   \
-                                    screlease_fence_scope)                     \
-  (((type) << IREE_HSA_PACKET_HEADER_TYPE) |                                   \
-   ((is_barrier ? 1 : 0) << IREE_HSA_PACKET_HEADER_BARRIER) |                  \
-   ((scacquire_fence_scope) << IREE_HSA_PACKET_HEADER_SCACQUIRE_FENCE_SCOPE) | \
-   ((screlease_fence_scope) << IREE_HSA_PACKET_HEADER_SCRELEASE_FENCE_SCOPE))
-
-typedef enum {
-  // No cache management occurs.
-  IREE_HSA_FENCE_SCOPE_NONE = 0,
-  // Invalidates I, K and L1 caches. Changes will be available to any queue on
-  // the same agent but may not be available on any other agent.
-  IREE_HSA_FENCE_SCOPE_AGENT = 1,
-  // Invalidates L1, L2 and flushes L2 caches. Changes will be available on all
-  // agents in the system after the fence completes.
-  IREE_HSA_FENCE_SCOPE_SYSTEM = 2,
-} iree_hsa_fence_scope_t;
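
As a concrete illustration of the bit layout and the 32-bit commit rule above,
a dispatch packet header would be formed and published roughly like this (a
sketch only; `dispatch_packet` is an assumed pointer into reserved queue
storage):

```c
// Form the 16-bit header: kernel dispatch, no barrier bit, agent-scope fences.
const uint16_t header = iree_hsa_make_packet_header(
    IREE_HSA_PACKET_TYPE_KERNEL_DISPATCH, /*is_barrier=*/false,
    IREE_HSA_FENCE_SCOPE_AGENT, IREE_HSA_FENCE_SCOPE_AGENT);
// The header is 16 bits but the smallest atomic is 32 bits: publish it
// together with the adjacent 16-bit setup word (grid dimension count) as one
// 32-bit release store so the packet processor never observes a torn header.
const uint32_t header_setup = (uint32_t)header | ((uint32_t)3 << 16);
__atomic_store_n((uint32_t*)&dispatch_packet->header, header_setup,
                 __ATOMIC_RELEASE);
```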
-
-// Kernel dispatch (2.9.6 in the spec).
-//
-// Pseudo-code:
-//   for (uint32_t z = 0; z < grid_size[2] / workgroup_size[2]; ++z) {
-//     for (uint32_t y = 0; y < grid_size[1] / workgroup_size[1]; ++y) {
-//       for (uint32_t x = 0; x < grid_size[0] / workgroup_size[0]; ++x) {
-//         kernel_object(*kernarg_address);
-//       }
-//     }
-//   }
-//   iree_hsa_signal_subtract(completion_signal, 1);
-//
-// The acquire fence is applied at the end of the launch phase just before the
-// packet enters the active phase. The release fence is applied at the start of
-// the completion phase of the packet.
-typedef struct iree_hsa_kernel_dispatch_packet_t {
-  // AQL packet header. See iree_hsa_packet_header_t for details.
-  uint16_t header;
-  // Number of grid dimensions (1, 2, or 3 - we always use 3).
-  uint16_t setup;
-  // Work-group size in work-items.
-  uint16_t workgroup_size[3];
-  uint16_t reserved0;  // must be 0
-  // Grid size in work-items.
-  uint32_t grid_size[3];
-  // Total size in bytes of the per-work-item memory.
-  uint32_t private_segment_size;
-  // Total size in bytes of the per-work-group memory.
-  uint32_t group_segment_size;
-  // Kernel object (function) handle as returned from a query on the symbol
-  // of HSA_EXECUTABLE_SYMBOL_INFO_KERNEL_OBJECT.
-  uint64_t kernel_object;
-  // Kernel arguments as required by the function.
-  // Must be 16-byte aligned and live until the dispatch has completed.
-  void* kernarg_address;
-  uint64_t reserved2;  // must be 0
-  iree_hsa_signal_t completion_signal;
-  // Optional signal indicating completion of all work-groups.
-} iree_hsa_kernel_dispatch_packet_t;
-
-// Agent dispatch (2.9.7 in the spec).
-//
-// Pseudo-code:
-//   *return_address = fns[type](arg[0], arg[1], arg[2], arg[3]);
-//   iree_hsa_signal_subtract(completion_signal, 1);
-//
-// The acquire fence is applied at the end of the launch phase just before the
-// packet enters the active phase. The release fence is applied at the start of
-// the completion phase of the packet.
-typedef struct iree_hsa_agent_dispatch_packet_t {
-  // AQL packet header. See iree_hsa_packet_header_t for details.
-  uint16_t header;
-  // Agent-defined type (discriminator).
-  uint16_t type;
-  uint32_t reserved0;  // must be 0
-  // Pointer to store the return value(s) in with the contents and layout
-  // defined by the type.
-  void* return_address;
-  // Arguments to the dispatch as defined by the type.
-  uint64_t arg[4];
-  uint64_t reserved2;  // must be 0
-  // Optional signal indicating completion of the dispatch.
-  iree_hsa_signal_t completion_signal;
-} iree_hsa_agent_dispatch_packet_t;
-
-// Barrier-AND (2.9.8 in the spec).
-// Waits until all dep_signals reach the value 0 at the same time and then
-// decrements the completion_signal. Ignores any 0 (null) signals.
-//
-// Pseudo-code:
-//   do {
-//     bool any_unsatisfied = false;
-//     for (int i = 0; i < 5; ++i) {
-//       if (iree_hsa_signal_load(dep_signal[i]) != 0) any_unsatisfied = true;
-//     }
-//     if (!any_unsatisfied) break;
-//     iree_amdgpu_yield();
-//   } while(true);
-//   iree_hsa_signal_subtract(completion_signal, 1);
-//
-// The acquire fence is processed first in the completion phase of the packet
-// after the barrier condition has been met. The release fence is processed
-// after the acquire fence in the completion phase.
-typedef struct iree_hsa_barrier_and_packet_t {
-  // AQL packet header. See iree_hsa_packet_header_t for details.
-  uint16_t header;
-  uint16_t reserved0;  // must be 0
-  uint32_t reserved1;  // must be 0
-  // Handles for dependent signaling objects to be evaluated by the packet
-  // processor. Any 0 (null) handles are ignored.
-  iree_hsa_signal_t dep_signal[5];
-  uint64_t reserved2;
-  // Signal to decrement when all dep_signals are satisfied.
-  iree_hsa_signal_t completion_signal;
-} iree_hsa_barrier_and_packet_t;
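
A sketch of the gating pattern this enables, assuming `wait_signal` was
initialized to the number of producer packets (say 3) and used as the
`completion_signal` of each:

```c
// Sketch only: the barrier releases once all producers have decremented
// wait_signal to 0. Unused dep_signal entries stay null and are ignored.
iree_hsa_barrier_and_packet_t barrier = {0};
barrier.dep_signal[0] = wait_signal;
barrier.completion_signal = iree_hsa_signal_null();  // nothing to notify
const uint16_t header = iree_hsa_make_packet_header(
    IREE_HSA_PACKET_TYPE_BARRIER_AND, /*is_barrier=*/true,
    IREE_HSA_FENCE_SCOPE_NONE, IREE_HSA_FENCE_SCOPE_NONE);
// Copy the packet body into reserved queue storage and publish the header
// last (as a 32-bit store that includes the reserved0 word, per above).
```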
-
-// Barrier-OR (2.9.9 in the spec).
-// Waits until any one dep_signal reaches the value 0 and then decrements the
-// completion_signal. Ignores any 0 (null) signals.
-//
-// Pseudo-code:
-//   do {
-//     bool any_satisfied = false;
-//     for (int i = 0; i < 5; ++i) {
-//       if (iree_hsa_signal_load(dep_signal[i]) == 0) any_satisfied = true;
-//     }
-//     if (any_satisfied) break;
-//     iree_amdgpu_yield();
-//   } while(true);
-//   iree_hsa_signal_subtract(completion_signal, 1);
-//
-// The acquire fence is processed first in the completion phase of the packet
-// after the barrier condition has been met. The release fence is processed
-// after the acquire fence in the completion phase.
-typedef struct iree_hsa_barrier_or_packet_t {
-  // AQL packet header. See iree_hsa_packet_header_t for details.
-  uint16_t header;
-  uint16_t reserved0;  // must be 0
-  uint32_t reserved1;  // must be 0
-  // Handles for dependent signaling objects to be evaluated by the packet
-  // processor. Any 0 (null) handles are ignored.
-  iree_hsa_signal_t dep_signal[5];
-  uint64_t reserved2;  // must be 0
-  // Signal to decrement when any dep_signal is satisfied.
-  iree_hsa_signal_t completion_signal;
-} iree_hsa_barrier_or_packet_t;
-
-typedef enum {
-  // iree_hsa_amd_barrier_value_packet_t
-  IREE_HSA_AMD_PACKET_TYPE_BARRIER_VALUE = 2,
-} iree_hsa_amd_packet_type_t;
-typedef uint8_t iree_hsa_amd_packet_type8_t;
-
-// Prefix of AMD-specific vendor packets.
-typedef struct iree_hsa_amd_packet_header_t {
-  // AQL packet header. See iree_hsa_packet_header_t for details.
-  uint16_t header;
-  // Secondary type indicating which AMD-specific packet this is.
-  iree_hsa_amd_packet_type8_t AmdFormat;
-  uint8_t reserved;  // must be 0
-} iree_hsa_amd_vendor_packet_header_t;
-
-// Barrier value extension.
-// Halts packet processing and waits for `(signal_value & mask) cond value` to
-// be satisfied before decrementing the completion_signal.
-//
-// Pseudo-code:
-//   do {
-//     if (iree_hsa_evaluate_signal_condition(
-//         /*condition=*/cond,
-//         /*current_value=*/(iree_hsa_signal_load(signal) & mask),
-//         /*desired_value=*/value)) {
-//       break;
-//     }
-//     iree_amdgpu_yield();
-//   } while(true);
-//   iree_hsa_signal_subtract(completion_signal, 1);
-//
-// The acquire fence is processed first in the completion phase of the packet
-// after the barrier condition has been met. The release fence is processed
-// after the acquire fence in the completion phase.
-typedef struct iree_hsa_amd_barrier_value_packet_t {
-  // AMD vendor-specific packet header.
-  iree_hsa_amd_vendor_packet_header_t header;
-  uint32_t reserved0;  // must be 0
-  // Dependent signal object. A 0 (null) signal will be treated as satisfied.
-  iree_hsa_signal_t signal;
-  // Value to compare the signal against (no mask applied).
-  iree_hsa_signal_value_t value;
-  // Bitmask applied to the current signal value.
-  iree_hsa_signal_value_t mask;
-  // Comparison operation.
-  iree_hsa_signal_condition32_t cond;
-  uint32_t reserved1;  // must be 0
-  uint64_t reserved2;  // must be 0
-  uint64_t reserved3;  // must be 0
-  // Signal to decrement when the wait condition is satisfied.
-  iree_hsa_signal_t completion_signal;
-} iree_hsa_amd_barrier_value_packet_t;
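
For example, a queue-frontier style wait for a semaphore signal to reach at
least some epoch could be encoded as follows (sketch only; `semaphore_signal`
and `target_epoch` are assumed):

```c
iree_hsa_amd_barrier_value_packet_t packet = {0};
packet.header.AmdFormat = IREE_HSA_AMD_PACKET_TYPE_BARRIER_VALUE;
packet.signal = semaphore_signal;
packet.value = target_epoch;
packet.mask = ~(iree_hsa_signal_value_t)0;  // compare the full signal value
packet.cond = IREE_HSA_SIGNAL_CONDITION_GTE;
packet.completion_signal = iree_hsa_signal_null();
// header.header is published last with type
// IREE_HSA_PACKET_TYPE_VENDOR_SPECIFIC and the barrier bit set.
```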
-
-//===----------------------------------------------------------------------===//
-// iree_amdgpu_kernel_implicit_args_t
-//===----------------------------------------------------------------------===//
-
-// Implicit kernel arguments passed to OpenCL/HIP kernels that use them.
-// Not all kernels require this and the metadata needs to be checked to detect
-// its use (or if the total kernargs size is > what we think it should be).
-// Layout-wise explicit args always start at offset 0 and implicit args follow
-// those with 8-byte alignment.
-//
-// The metadata will contain exact fields and offsets and most driver code will
-// carefully walk to detect, align, pad, and write each field:
-// OpenCL/HIP: (`amd::KernelParameterDescriptor`...)
-// https://github.com/ROCm/clr/blob/5da72f9d524420c43fe3eee44b11ac875d884e0f/rocclr/device/rocm/rocvirtual.cpp#L3197
-//
-// This complex construction was required once upon a time. The LLVM code
-// producing the kernargs layout and metadata handles these cases much more
-// simply by only ever truncating the implicit args at the last used field:
-// https://github.com/llvm/llvm-project/blob/7f1b465c6ae476e59dc90652d58fc648932d23b1/llvm/lib/Target/AMDGPU/AMDGPUHSAMetadataStreamer.cpp#L389
-//
-// Then at some point in time someone was like "meh, who cares about optimizing"
-// and decided to include all of them always 🤦:
-// https://github.com/llvm/llvm-project/blob/7f1b465c6ae476e59dc90652d58fc648932d23b1/llvm/lib/Target/AMDGPU/AMDGPUSubtarget.cpp#L299
-//
-// What this means in practice is that if any implicit arg is used then all will
-// be included and declared in the metadata even if only one is actually read by
-// the kernel -- there's no way for us to know. In the ideal case none of them
-// are read and the kernel function gets the `amdgpu-no-implicitarg-ptr` attr
-// so that all of them can be skipped. Otherwise we reserve the 256 bytes and
-// just splat them all in. This at least keeps our code simple relative to all
-// the implementations that enumerate the metadata and write args one at a time.
-// We really should try to force `amdgpu-no-implicitarg-ptr` when we generate
-// code, though.
-//
-// For our bare-metal C runtime device code we have total freedom and don't use
-// any OpenCL/HIP-related things that would emit the implicit args.
-typedef struct IREE_AMDGPU_ALIGNAS(8) iree_amdgpu_kernel_implicit_args_t {
-  // Grid dispatch workgroup count.
-  // Some languages, such as OpenCL, support a last workgroup in each
-  // dimension being partial. This count only includes the non-partial
-  // workgroup count. This is not the same as the value in the AQL dispatch
-  // packet, which has the grid size in workitems.
-  //
-  // Represented in metadata as:
-  //   hidden_block_count_x
-  //   hidden_block_count_y
-  //   hidden_block_count_z
-  uint32_t block_count[3];  // + 0/4/8
-
-  // Grid dispatch workgroup size.
-  // This size only applies to the non-partial workgroups. This is the same
-  // value as the AQL dispatch packet workgroup size.
-  //
-  // Represented in metadata as:
-  //   hidden_group_size_x
-  //   hidden_group_size_y
-  //   hidden_group_size_z
-  uint16_t group_size[3];  // + 12/14/16
-
-  // Grid dispatch workgroup size of the partial workgroup, if it exists.
-  // Any dimension that does not exist must be 0. Only used by OpenCL and may
-  // be left 0.
-  //
-  // Represented in metadata as:
-  //   hidden_remainder_x
-  //   hidden_remainder_y
-  //   hidden_remainder_z
-  uint16_t remainder[3];  // + 18/20/22
-
-  uint64_t reserved0;  // + 24 hidden_tool_correlation_id
-  uint64_t reserved1;  // + 32
-
-  // OpenCL grid dispatch global offset.
-  // Always 0 in HIP but still required as the device library functions for
-  // grid locations are shared with OpenCL and unconditionally factor it in.
-  //
-  // Hardcoded to 0 in HIP:
-  // https://github.com/ROCm/clr/blob/a2550e0a9ecaa8f371cb14d08904c51874c37cbe/hipamd/src/hip_module.cpp#L348
-  //
-  // Represented in metadata as:
-  //   hidden_global_offset_x
-  //   hidden_global_offset_y
-  //   hidden_global_offset_z
-  uint64_t global_offset[3];  // + 40/48/56
-
-  // Grid dispatch dimensionality. This is the same value as the AQL
-  // dispatch packet dimensionality. Must be a value between 1 and 3.
-  //
-  // Hardcoded to 3 in HIP:
-  // https://github.com/ROCm/clr/blob/a2550e0a9ecaa8f371cb14d08904c51874c37cbe/hipamd/src/hip_module.cpp#L349
-  //
-  // Represented in metadata as:
-  //   hidden_grid_dims
-  uint16_t grid_dims;  // + 64
-
-  // Fixed-size buffer for `-mprintf-kind=buffered` support.
-  // By default LLVM uses `hostcall` but that's a mess and we avoid it.
-  // `__printf_alloc` in the device library is used to grab this pointer, the
-  // header DWORDs are manipulated, and the contents are written to the buffer.
-  //
-  // struct {
-  //   atomic_uint32_t offset;
-  //   uint32_t size;
-  //   uint8_t data[size];
-  // } printf_buffer_t;
-  //
-  // One of many disappointing parts of this scheme is that constant string
-  // values are interned, MD5 hashed, and stored *externally* in the amdhsa data
-  // blob. In order to print with any constant format string this data blob
-  // needs to be parsed, retained, and referenced every time a printf packet is
-  // processed. It would have been significantly better to embed the table in
-  // the ELF as a global constant instead as then we could reference it on both
-  // host and device and not need to parse the amdhsa blob.
-  //
-  // The contents of the data buffer are best defined by the janky parser code:
-  // https://github.com/ROCm/clr/blob/a2550e0a9ecaa8f371cb14d08904c51874c37cbe/rocclr/device/rocm/rocprintf.cpp#L454
-  // Each printf consists of a control DWORD followed by 8-byte aligned
-  // contents. Effectively:
-  // struct {
-  //   uint32_t is_stderr : 1;       // else stdout
-  //   uint32_t constant : 1;        // constant format string code path
-  //   uint32_t size_in_bytes : 30;  // (including this header)
-  //   uint64_t data[size_in_bytes / 8];
-  // } printf_packet_t;
-  //
-  // To construct the full format data buffer if constant == 1:
-  //  data[0] contains the lower 64-bits of the MD5 hash of the string followed
-  //  by size_in_bytes-12 arguments. The data buffer needs to be expanded into
-  //  an 8-byte aligned NUL-terminated string with the corresponding hash
-  //  followed by the arguments verbatim. Once reconstituted the subsequent
-  //  logic is the same.
-  //
-  // The data buffer is an 8-byte aligned NUL-terminated string followed by
-  // the argument data. E.g. `hi! %s` would be encoded as `hi! %s` 0x00 0x??
-  // (with the last byte being padding to an 8-byte boundary). The reference
-  // code for formatting the string lives in the CLR:
-  // https://github.com/ROCm/clr/blob/a2550e0a9ecaa8f371cb14d08904c51874c37cbe/rocclr/device/devhcprintf.cpp#L168
-  // Note that the documentation is incorrect about there being a version prefix
-  // and it expects the first uint64_t to contain the format string bytes.
-  //
-  // Note that in another disappointing display of rube-goldbergian development
-  // this implementation for some reason uses uint64_t for its data elements
-  // but never aligns it - meaning that consumer code must use unaligned loads
-  // in order to read the data. The CLR just copies it out each time. One could
-  // think that was for streaming (release the buffer contents early back to
-  // dispatches) but since they fully halt the world and synchronize after every
-  // dispatch containing a print none of that matters and it's just poor
-  // engineering.
-  //
-  // The compiler emits strings in the delimited form of
-  // `"0:0:<format_string_hash>,<actual_format_string>"`. Note that the first
-  // two values should always be 0 and are delimited by `:` while the MD5 hash
-  // is delimited from the format string itself by `,`. There's some special
-  // handling in the CLR for `:` being in the format string because whoever
-  // wrote it did a find from the end instead of a prefix consume - there's
-  // special handling of \72 (`:`) and other weird things that I'm not sure is
-  // needed. Example from LLVM: `"0:0:8addc4c0362218ac,Hello World!:\n"`.
-  //
-  // The hash is the lower 64 bits of the MD5 hash in hex but we don't care as
-  // it's just a semi-unique value we use to lookup the string formats. On load
-  // we sort and do a binary search instead of creating an std::map for every
-  // single print invocation like the CLR does. Just... wow.
-  //
-  // Handling the contents is also overtly complicated and poorly documented:
-  // https://github.com/ROCm/clr/blob/a2550e0a9ecaa8f371cb14d08904c51874c37cbe/rocclr/device/devhcprintf.cpp#L168
-  //
-  // See:
-  // https://github.com/ROCm/llvm-project/commit/631c965483e03355cdc1dba578e787b259c4d79d
-  // https://github.com/ROCm/llvm-project/blob/997363823fcc5ccc7b0cc572aad05ba08714bf5f/amd/device-libs/ockl/src/cprintf.cl#L17
-  // https://github.com/ROCm/clr/blob/a2550e0a9ecaa8f371cb14d08904c51874c37cbe/rocclr/device/rocm/rocprintf.cpp#L393
-  //
-  // Note that having a printf in a kernel causes the kernel to dispatch
-  // synchronously :facepalm:. We can't do the same and would need to emit
-  // flush packets (or something) into the control queue. What a mess.
-  // https://github.com/ROCm/clr/blob/a2550e0a9ecaa8f371cb14d08904c51874c37cbe/rocclr/device/rocm/rocvirtual.cpp#L3644
-  // https://github.com/ROCm/clr/blob/a2550e0a9ecaa8f371cb14d08904c51874c37cbe/rocclr/device/rocm/rocprintf.cpp#L428-L429
-  //
-  // Represented in metadata as:
-  //   hidden_printf_buffer
-  void* printf_buffer;  // + 72
-
-  // Used for ASAN, printf, and more modern device memory allocations.
-  // It's bizarre and only "documented" in code and I really hope we don't have
-  // to touch it. Note that due to some LLVM bug sometimes this will be included
-  // in the offset table for a kernel even if it is not used (the
-  // `amdgpu-no-hostcall-ptr` attribute is set). At this point I'm quite sure no
-  // one has ever actually inspected the files produced by the LLVM backend.
-  //
-  // Represented in metadata as:
-  //   hidden_hostcall_buffer
-  void* hostcall_buffer;  // + 80
-
-  // Multi-grid support was deprecated in ROCM 5.x and should never appear in
-  // any program we generate ourselves or care about running.
-  //
-  // Represented in metadata as:
-  //   hidden_multigrid_sync_arg
-  uint64_t deprecated_multigrid_sync_arg;
-
-  // Device memory heap pointer for device malloc/free.
-  // We don't support kernels using this as it requires too much goo for little
-  // payoff. The kernels we run shouldn't be malloc/freeing internally. If they
-  // do we will need to implement the heap API via hostcalls and other silly
-  // things that add a tremendous amount of complexity.
-  //
-  // See:
-  // https://github.com/ROCm/llvm-project/blob/97753eeaa4c79c2db2dcd9f37b7989596a8d4f15/amd/device-libs/ockl/src/dm.cl#L192
-  //
-  // Represented in metadata as:
-  //   hidden_heap_v1
-  uint64_t unused_heap_v1;
-
-  // AQL queue handles are only used by OpenCL device-side enqueue and we do not
-  // support that. We could, probably, by passing in our execution queue but
-  // since HIP has never supported it the use case doesn't exist. If we wanted
-  // to support device-enqueue we'd do it in a structured fashion instead of
-  // letting kernels splat right into the AQL queue.
-  //
-  // See:
-  // https://github.com/ROCm/llvm-project/blob/97753eeaa4c79c2db2dcd9f37b7989596a8d4f15/amd/device-libs/opencl/src/devenq/enqueue.cl#L310
-  //
-  // Represented in metadata as:
-  //   hidden_default_queue
-  uint64_t unused_default_queue;
-
-  // Completion actions were (I believe) an attempt at dynamic parallelism and
-  // HIP has never supported them. Device-side enqueue in OpenCL uses this but
-  // we don't support those kernels.
-  //
-  // See:
-  // https://github.com/ROCm/llvm-project/blob/97753eeaa4c79c2db2dcd9f37b7989596a8d4f15/amd/device-libs/opencl/src/devenq/enqueue.cl#L311
-  //
-  // Represented in metadata as:
-  //   hidden_completion_action
-  uint64_t unused_completion_action;
-
-  // The value of the sharedMemBytes parameter to the dispatch indicating how
-  // much dynamic shared memory was reserved for the kernel. This may be larger
-  // than the requested amount. The total group_segment_size for a dispatch is
-  // the static LDS requirement of the kernel plus this value.
-  //
-  // Represented in metadata as:
-  //   hidden_dynamic_lds_size
-  uint32_t dynamic_lds_size;
-
-  uint8_t reserved[68];
-
-  // Only used by GFX8, which we don't support.
-  //
-  // Represented in metadata as:
-  //   hidden_private_base
-  uint32_t deprecated_private_base;
-
-  // Only used by GFX8, which we don't support.
-  //
-  // Represented in metadata as:
-  //   hidden_shared_base
-  uint32_t deprecated_shared_base;
-
-  // AQL queue the dispatch is running on.
-  // Only used by pre-GFX9 devices, which we don't support.
-  //
-  // Represented in metadata as:
-  //   hidden_queue_ptr;
-  iree_hsa_queue_t* deprecated_queue_ptr;
-} iree_amdgpu_kernel_implicit_args_t;
-
-#define IREE_AMDGPU_KERNEL_IMPLICIT_ARGS_SIZE               \
-  (IREE_AMDGPU_OFFSETOF(iree_amdgpu_kernel_implicit_args_t, \
-                        dynamic_lds_size) +                 \
-   sizeof(((iree_amdgpu_kernel_implicit_args_t*)NULL)->dynamic_lds_size))
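
A minimal sketch of the detection rule mentioned above, where `explicit_size`
is the explicit kernarg size we compute ourselves and `kernarg_segment_size`
comes from the kernel metadata (the helper name is hypothetical):

```c
// If the metadata reports a kernarg segment larger than our explicit args the
// kernel was compiled expecting the trailing implicit arg block; reserve it
// (8-byte aligned after the explicit args) and splat all fields.
static inline bool kernel_uses_implicit_args(size_t explicit_size,
                                             size_t kernarg_segment_size) {
  return kernarg_segment_size > iree_amdgpu_align(explicit_size, 8);
}
```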
-
-//===----------------------------------------------------------------------===//
-// OpenCL/HIP Dispatch ABI
-//===----------------------------------------------------------------------===//
-// These come from llvm-project/amd/device-libs/ockl/src/workitem.cl (the ockl
-// functions) and llvm-project/clang/lib/CodeGen/CGBuiltin.cpp (e.g.
-// EmitAMDGPUWorkGroupSize). Using either risks pulling in the entire
-// iree_amdgpu_kernel_implicit_args_t struct and we don't want to have to
-// populate that. We also don't need it: we aren't providing OpenCL
-// compatibility and have no need for the extra features the implicit args
-// provide (like the global offset and device-side enqueue - that's our job).
-
-#if defined(IREE_AMDGPU_TARGET_DEVICE)
-
-// Returns the pointer to the iree_hsa_kernel_dispatch_packet_t being executed.
-#define iree_amdgcn_dispatch_ptr()                                 \
-  ((const iree_hsa_kernel_dispatch_packet_t* IREE_AMDGPU_RESTRICT) \
-       __builtin_amdgcn_dispatch_ptr())
-
-// __ockl_get_global_id(0) / get_global_id_x using OLD_ABI
-static inline IREE_AMDGPU_ATTRIBUTE_ALWAYS_INLINE size_t
-iree_hal_amdgpu_device_global_id_x(void) {
-  const uint32_t local_id = __builtin_amdgcn_workitem_id_x();
-  const uint32_t group_id = __builtin_amdgcn_workgroup_id_x();
-  const uint32_t group_size = iree_amdgcn_dispatch_ptr()->workgroup_size[0];
-  return (group_id * group_size + local_id);
-}
-static inline IREE_AMDGPU_ATTRIBUTE_ALWAYS_INLINE size_t
-iree_hal_amdgpu_device_global_id_y(void) {
-  const uint32_t local_id = __builtin_amdgcn_workitem_id_y();
-  const uint32_t group_id = __builtin_amdgcn_workgroup_id_y();
-  const uint32_t group_size = iree_amdgcn_dispatch_ptr()->workgroup_size[1];
-  return (group_id * group_size + local_id);
-}
-static inline IREE_AMDGPU_ATTRIBUTE_ALWAYS_INLINE size_t
-iree_hal_amdgpu_device_global_id_z(void) {
-  const uint32_t local_id = __builtin_amdgcn_workitem_id_z();
-  const uint32_t group_id = __builtin_amdgcn_workgroup_id_z();
-  const uint32_t group_size = iree_amdgcn_dispatch_ptr()->workgroup_size[2];
-  return (group_id * group_size + local_id);
-}
-
-// __ockl_get_group_id(0)
-#define iree_hal_amdgpu_device_group_id_x() __builtin_amdgcn_workgroup_id_x()
-#define iree_hal_amdgpu_device_group_id_y() __builtin_amdgcn_workgroup_id_y()
-#define iree_hal_amdgpu_device_group_id_z() __builtin_amdgcn_workgroup_id_z()
-
-// __ockl_get_num_groups(0)
-static inline IREE_AMDGPU_ATTRIBUTE_ALWAYS_INLINE size_t
-iree_hal_amdgpu_device_group_count_x(void) {
-  const uint32_t grid_size = iree_amdgcn_dispatch_ptr()->grid_size[0];
-  const uint32_t group_size = iree_amdgcn_dispatch_ptr()->workgroup_size[0];
-  const uint32_t q = grid_size / group_size;
-  return q + (grid_size > q * group_size);
-}
-static inline IREE_AMDGPU_ATTRIBUTE_ALWAYS_INLINE size_t
-iree_hal_amdgpu_device_group_count_y(void) {
-  const uint32_t grid_size = iree_amdgcn_dispatch_ptr()->grid_size[1];
-  const uint32_t group_size = iree_amdgcn_dispatch_ptr()->workgroup_size[1];
-  const uint32_t q = grid_size / group_size;
-  return q + (grid_size > q * group_size);
-}
-static inline IREE_AMDGPU_ATTRIBUTE_ALWAYS_INLINE size_t
-iree_hal_amdgpu_device_group_count_z(void) {
-  const uint32_t grid_size = iree_amdgcn_dispatch_ptr()->grid_size[2];
-  const uint32_t group_size = iree_amdgcn_dispatch_ptr()->workgroup_size[2];
-  const uint32_t q = grid_size / group_size;
-  return q + (grid_size > q * group_size);
-}
-
-// __ockl_get_local_id(0)
-#define iree_hal_amdgpu_device_local_id_x() __builtin_amdgcn_workitem_id_x()
-#define iree_hal_amdgpu_device_local_id_y() __builtin_amdgcn_workitem_id_y()
-#define iree_hal_amdgpu_device_local_id_z() __builtin_amdgcn_workitem_id_z()
-
-// __ockl_get_local_size(0) / get_local_size_x
-static inline IREE_AMDGPU_ATTRIBUTE_ALWAYS_INLINE size_t
-iree_hal_amdgpu_device_workgroup_size_x(void) {
-  const uint32_t group_id = __builtin_amdgcn_workgroup_id_x();
-  const uint32_t group_size = iree_amdgcn_dispatch_ptr()->workgroup_size[0];
-  const uint32_t grid_size = iree_amdgcn_dispatch_ptr()->grid_size[0];
-  const uint32_t r = grid_size - group_id * group_size;
-  return (r < group_size) ? r : group_size;
-}
-static inline IREE_AMDGPU_ATTRIBUTE_ALWAYS_INLINE size_t
-iree_hal_amdgpu_device_workgroup_size_y(void) {
-  const uint32_t group_id = __builtin_amdgcn_workgroup_id_y();
-  const uint32_t group_size = iree_amdgcn_dispatch_ptr()->workgroup_size[1];
-  const uint32_t grid_size = iree_amdgcn_dispatch_ptr()->grid_size[1];
-  const uint32_t r = grid_size - group_id * group_size;
-  return (r < group_size) ? r : group_size;
-}
-static inline IREE_AMDGPU_ATTRIBUTE_ALWAYS_INLINE size_t
-iree_hal_amdgpu_device_workgroup_size_z(void) {
-  const uint32_t group_id = __builtin_amdgcn_workgroup_id_z();
-  const uint32_t group_size = iree_amdgcn_dispatch_ptr()->workgroup_size[2];
-  const uint32_t grid_size = iree_amdgcn_dispatch_ptr()->grid_size[2];
-  const uint32_t r = grid_size - group_id * group_size;
-  return (r < group_size) ? r : group_size;
-}
-
-static inline IREE_AMDGPU_ATTRIBUTE_ALWAYS_INLINE size_t
-iree_hal_amdgpu_device_global_linear_id_1d(void) {
-  return iree_hal_amdgpu_device_group_id_x() *
-             iree_amdgcn_dispatch_ptr()->workgroup_size[0] +
-         iree_hal_amdgpu_device_local_id_x();
-}
-
-static inline IREE_AMDGPU_ATTRIBUTE_ALWAYS_INLINE size_t
-iree_hal_amdgpu_device_global_linear_id_2d(void) {
-  const size_t id_x = iree_hal_amdgpu_device_group_id_x() *
-                          iree_amdgcn_dispatch_ptr()->workgroup_size[0] +
-                      iree_hal_amdgpu_device_local_id_x();
-  const size_t id_y = iree_hal_amdgpu_device_group_id_y() *
-                          iree_amdgcn_dispatch_ptr()->workgroup_size[1] +
-                      iree_hal_amdgpu_device_local_id_y();
-  return id_y * iree_amdgcn_dispatch_ptr()->grid_size[0] + id_x;
-}
-
-static inline IREE_AMDGPU_ATTRIBUTE_ALWAYS_INLINE size_t
-iree_hal_amdgpu_device_global_linear_id_3d(void) {
-  const size_t id_x = iree_hal_amdgpu_device_group_id_x() *
-                          iree_amdgcn_dispatch_ptr()->workgroup_size[0] +
-                      iree_hal_amdgpu_device_local_id_x();
-  const size_t id_y = iree_hal_amdgpu_device_group_id_y() *
-                          iree_amdgcn_dispatch_ptr()->workgroup_size[1] +
-                      iree_hal_amdgpu_device_local_id_y();
-  const size_t id_z = iree_hal_amdgpu_device_group_id_z() *
-                          iree_amdgcn_dispatch_ptr()->workgroup_size[2] +
-                      iree_hal_amdgpu_device_local_id_z();
-  return (id_z * iree_amdgcn_dispatch_ptr()->grid_size[1] + id_y) *
-             iree_amdgcn_dispatch_ptr()->grid_size[0] +
-         id_x;
-}
-
-#endif  // IREE_AMDGPU_TARGET_DEVICE
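
For reference, these helpers compose into the usual bounds-checked kernel
shape: grids are padded up to full workgroups, so each invocation checks its
flattened global ID against the logical element count (the kernel itself is a
hypothetical example):

```c
#if defined(IREE_AMDGPU_TARGET_DEVICE)
// Hypothetical elementwise kernel using the helpers above.
IREE_AMDGPU_ATTRIBUTE_KERNEL void example_fill_u32(
    uint32_t* IREE_AMDGPU_RESTRICT out, uint32_t value, uint32_t count) {
  const size_t i = iree_hal_amdgpu_device_global_linear_id_1d();
  if (i >= count) return;  // padding invocation past the logical range
  out[i] = value;
}
#endif  // IREE_AMDGPU_TARGET_DEVICE
```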
-
-//===----------------------------------------------------------------------===//
 // Device Library Functions
 //===----------------------------------------------------------------------===//
 // These are cloned from llvm-project/amd/device-libs/ockl/src/hsaqs.cl so that
diff --git a/runtime/src/iree/hal/drivers/amdgpu/device/support/signal.h b/runtime/src/iree/hal/drivers/amdgpu/device/support/signal.h
index 993a9d5..28fb943 100644
--- a/runtime/src/iree/hal/drivers/amdgpu/device/support/signal.h
+++ b/runtime/src/iree/hal/drivers/amdgpu/device/support/signal.h
@@ -4,10 +4,9 @@
 // See https://llvm.org/LICENSE.txt for license information.
 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
 //
-// NOTE: these structs/enums are taken from the HSA spec, the hsa.h and
-// hsa_ext_amd.h headers, and the LLVM AMDGPU device library headers.
-// We define them locally as the HSA headers cannot be directly used in
-// bare-metal C and the device libraries are only available in a fork of LLVM.
+// Device-side signal manipulation functions built on top of the ABI types.
+// For the signal struct layout and type definitions see abi/signal.h (exported
+// below).
 //
 // Sources:
 // https://hsafoundation.com/wp-content/uploads/2021/02/HSA-SysArch-1.2.pdf
@@ -17,168 +16,10 @@
 #ifndef IREE_HAL_DRIVERS_AMDGPU_DEVICE_SUPPORT_SIGNAL_H_
 #define IREE_HAL_DRIVERS_AMDGPU_DEVICE_SUPPORT_SIGNAL_H_
 
+#include "iree/hal/drivers/amdgpu/abi/signal.h"  // IWYU pragma: export
 #include "iree/hal/drivers/amdgpu/device/support/common.h"
 
 //===----------------------------------------------------------------------===//
-// HSA/AMDGPU Signal
-//===----------------------------------------------------------------------===//
-
-#if defined(IREE_AMDGPU_TARGET_DEVICE)
-
-// "Opaque" reference to an iree_amd_signal_t*.
-// A value of 0 indicates a no-op signal (waits will succeed immediately and
-// completions will no-op).
-typedef struct iree_hsa_signal_t {
-  uint64_t handle;
-} iree_hsa_signal_t;
-
-#else
-
-typedef hsa_signal_t iree_hsa_signal_t;
-
-#endif  // IREE_AMDGPU_TARGET_DEVICE
-
-// No-op signal that will immediately succeed when waited on and be ignored when
-// signaling.
-#define iree_hsa_signal_null() (iree_hsa_signal_t){0}
-
-// Returns true if the given signal is null.
-#define iree_hsa_signal_is_null(signal) ((signal).handle == 0)
-
-// Value of a signal.
-// The interpretation of this is dependent on the operation consuming it.
-// With barrier value packets it's user-defined and can be any value.
-// With barrier-and/barrier-or and dispatch packets it acts as a semaphore where
-// a 0 value indicates set and a non-zero value indicates unset. For example,
-// if 3 operations are required to complete before another can proceed it should
-// be set to 3, included as the completion_signal for the 3 operations, and
-// used as the dependent signal in a barrier. As each operation completes it
-// will decrement the value and when it reaches 0 the barrier will succeed and
-// allow the dependent operation to execute.
-typedef int64_t iree_hsa_signal_value_t;
-
-// AMD signal kind.
-enum iree_amd_signal_kind_t {
-  // Unassigned (not seen).
-  IREE_AMD_SIGNAL_KIND_INVALID = 0,
-  // User-defined signal that supports all signal operations.
-  IREE_AMD_SIGNAL_KIND_USER = 1,
-  // Agent-defined doorbell (usually the queue's doorbell_signal field).
-  // Only writes are permitted from any agent other than the origin and for our
-  // purposes that means no writes ever. Soft queues created by the user must
-  // use IREE_AMD_SIGNAL_KIND_USER as this is reserved for hardware.
-  IREE_AMD_SIGNAL_KIND_DOORBELL = -1,
-};
-// AMD signal kind.
-typedef int64_t iree_amd_signal_kind64_t;
-
-// AMDGPU signal implementation.
-// This is an implementation detail from the perspective of the HSA spec but a
-// stable interface to the current generations of hardware implementing HSA.
-// Signals are just locations in memory and have no special behavior other than
-// how they are initialized. For our purposes there are two types: USER and
-// DOORBELL.
-//
-// Signal values depend on the producer/consumer operations. See
-// `iree_hsa_signal_value_t` for more information.
-//
-// Doorbell signals are firmware/hardware-specific and must only be written to
-// by the host and other agents (that means no waiting either, as that's a
-// read). Only the hardware queues as allocated by the HSA implementation should
-// set these.
-//
-// User signals as presented to the hardware via `iree_amd_signal_t` are like
-// futexes: allocating memory accessible to a set of agents and populating it
-// is enough to create and use the signal and (so long as it's not used
-// afterward) deleting it is just freeing the memory. Special behavior only
-// comes with host interaction: using any host HSA API (`hsa_signal_store_*`,
-// `hsa_signal_wait_*`, etc) is only possible with signals allocated via either
-// `hsa_signal_create` or `hsa_amd_signal_create` as those functions cast to an
-// internal ROCR `Signal` interface. If the signal will only ever be used by our
-// device code, the hardware queues, or our own host code not using the HSA APIs
-// then we don't need to use signals created by HSA. When we do need to interact
-// with the APIs the signals are implemented by two types: busy-wait and
-// interrupt (as implemented in ROCR by `BusyWaitSignal` and `InterruptSignal`).
-// Busy-wait signals are like a futex and _mostly_ exist entirely in user-mode.
-// Interrupt are the same but with an additional platform event handle so that
-// `hsaKmtWaitOnEvent` and other kernel-level waits can be performed. For such
-// signals the platform event as returned by `hsaKmtCreateEvent` is stored in
-// the `event_mailbox_ptr` and the value to post is `event_id`. I suspect in
-// modern implementations that could be removed as they could be implemented
-// with a futex when in-process and then the full platform handles would be
-// reserved for IPC.
-//
-// Timestamps on the signal are set by the agent processing the operation.
-// `start_ts` is set when the packet enters the active phase and `end_ts` is set
-// when it completes. These timestamps are in agent-specific ticks and need to
-// be translated into system-scope by scaling by relative frequencies of the
-// system and the particular agent by
-// `hsa_amd_profiling_convert_tick_to_system_domain` that handles the scaling.
-// At its core that method occasionally queries the base timestamps and
-// frequencies of the agents (as they may change over time) and the
-// resynchronization accounts for drift. In order to resolve timestamps fully
-// on-device we do the same thing by polling `AMDKFD_IOC_GET_CLOCK_COUNTERS`
-// and providing it to the device runtime. Every time the clocks are resynced
-// there's the potential for a discontinuity/backwards rolling timestamps so
-// we try to only do it per-submission to at least keep all of the times within
-// relatively aligned even if the entire submission may have drifted from the
-// system by the end. Note that because work can happen out-of-order the
-// timestamps on a set of signals may be out-of-order with respect to the system
-// time once resolved and anything using the timestamps needs to handle that or
-// unset the CONCURRENT execution flag on the queue.
-typedef struct IREE_AMDGPU_ALIGNAS(64) iree_amd_signal_t {
-  iree_amd_signal_kind64_t kind;
-  union {
-    volatile iree_hsa_signal_value_t value;
-    volatile uint64_t* hardware_doorbell_ptr;
-  };
-  uint64_t event_mailbox_ptr;
-  uint32_t event_id;
-  uint32_t reserved1;
-  iree_amdgpu_device_tick_t start_ts;
-  iree_amdgpu_device_tick_t end_ts;
-  struct iree_amd_queue_s* queue_ptr;
-  uint32_t reserved3[2];
-} iree_amd_signal_t;
-
-// Wait condition operation.
-typedef uint32_t iree_hsa_signal_condition32_t;
-typedef enum {
-  // The two operands are equal.
-  IREE_HSA_SIGNAL_CONDITION_EQ = 0,
-  // The two operands are not equal.
-  IREE_HSA_SIGNAL_CONDITION_NE = 1,
-  // The first operand is less than the second operand.
-  IREE_HSA_SIGNAL_CONDITION_LT = 2,
-  // The first operand is greater than or equal to the second operand.
-  IREE_HSA_SIGNAL_CONDITION_GTE = 3
-} iree_hsa_signal_condition_t;
-
-//===----------------------------------------------------------------------===//
-// HSA Signal Utilities
-//===----------------------------------------------------------------------===//
-
-// Returns true if the given |current_signal| value matches the expected
-// |desired_value| as defined by |condition|.
-static IREE_AMDGPU_ATTRIBUTE_ALWAYS_INLINE inline bool
-iree_hsa_evaluate_signal_condition(iree_hsa_signal_condition32_t condition,
-                                   iree_hsa_signal_value_t current_value,
-                                   iree_hsa_signal_value_t desired_value) {
-  switch (condition) {
-    default:
-    case IREE_HSA_SIGNAL_CONDITION_EQ:
-      return current_value == desired_value;
-    case IREE_HSA_SIGNAL_CONDITION_NE:
-      return current_value != desired_value;
-    case IREE_HSA_SIGNAL_CONDITION_LT:
-      return current_value < desired_value;
-    case IREE_HSA_SIGNAL_CONDITION_GTE:
-      return current_value >= desired_value;
-  }
-}
-
-//===----------------------------------------------------------------------===//
 // Device Library Functions
 //===----------------------------------------------------------------------===//
 // These are cloned from llvm-project/amd/device-libs/ockl/src/hsaqs.cl so that
diff --git a/runtime/src/iree/hal/drivers/amdgpu/device/timestamp.c b/runtime/src/iree/hal/drivers/amdgpu/device/timestamp.c
new file mode 100644
index 0000000..4873ea7
--- /dev/null
+++ b/runtime/src/iree/hal/drivers/amdgpu/device/timestamp.c
@@ -0,0 +1,58 @@
+// Copyright 2026 The IREE Authors
+//
+// Licensed under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+
+#include "iree/hal/drivers/amdgpu/device/timestamp.h"
+
+#include "iree/hal/drivers/amdgpu/device/support/kernel.h"
+
+//===----------------------------------------------------------------------===//
+// Dispatch timestamp harvest
+//===----------------------------------------------------------------------===//
+
+iree_hal_amdgpu_dispatch_timestamp_harvest_source_t*
+iree_hal_amdgpu_device_timestamp_emplace_dispatch_harvest(
+    const iree_hal_amdgpu_device_kernel_args_t* IREE_AMDGPU_RESTRICT
+        harvest_kernel_args,
+    uint32_t source_count,
+    iree_hsa_kernel_dispatch_packet_t* IREE_AMDGPU_RESTRICT dispatch_packet,
+    void* IREE_AMDGPU_RESTRICT kernarg_ptr) {
+  iree_hal_amdgpu_dispatch_timestamp_harvest_args_t* IREE_AMDGPU_RESTRICT
+      kernargs =
+          (iree_hal_amdgpu_dispatch_timestamp_harvest_args_t*)kernarg_ptr;
+  iree_hal_amdgpu_dispatch_timestamp_harvest_source_t* sources =
+      iree_hal_amdgpu_device_timestamp_dispatch_harvest_sources(kernarg_ptr);
+  kernargs->sources = sources;
+  kernargs->source_count = source_count;
+  kernargs->reserved0 = 0;
+
+  const uint32_t harvest_workgroup_size =
+      harvest_kernel_args->workgroup_size[0];
+  const uint32_t harvest_workgroup_count[3] = {
+      (uint32_t)IREE_AMDGPU_CEIL_DIV(source_count, harvest_workgroup_size), 1,
+      1};
+  iree_hal_amdgpu_device_dispatch_emplace_packet(
+      harvest_kernel_args, harvest_workgroup_count,
+      /*dynamic_workgroup_local_memory=*/0, dispatch_packet, kernarg_ptr);
+  return sources;
+}
+
+#if defined(IREE_AMDGPU_TARGET_DEVICE)
+
+IREE_AMDGPU_ATTRIBUTE_KERNEL void
+iree_hal_amdgpu_device_timestamp_harvest_dispatch_records(
+    const iree_hal_amdgpu_dispatch_timestamp_harvest_source_t*
+        IREE_AMDGPU_RESTRICT sources,
+    uint32_t source_count) {
+  const size_t source_index = iree_hal_amdgpu_device_global_linear_id_1d();
+  if (source_index >= source_count) return;
+
+  const iree_hal_amdgpu_dispatch_timestamp_harvest_source_t source =
+      sources[source_index];
+  source.ticks->start_tick = source.completion_signal->start_ts;
+  source.ticks->end_tick = source.completion_signal->end_ts;
+}
+
+#endif  // IREE_AMDGPU_TARGET_DEVICE
diff --git a/runtime/src/iree/hal/drivers/amdgpu/device/timestamp.h b/runtime/src/iree/hal/drivers/amdgpu/device/timestamp.h
new file mode 100644
index 0000000..7fb894b
--- /dev/null
+++ b/runtime/src/iree/hal/drivers/amdgpu/device/timestamp.h
@@ -0,0 +1,72 @@
+// Copyright 2026 The IREE Authors
+//
+// Licensed under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+
+#ifndef IREE_HAL_DRIVERS_AMDGPU_DEVICE_TIMESTAMP_H_
+#define IREE_HAL_DRIVERS_AMDGPU_DEVICE_TIMESTAMP_H_
+
+#include "iree/hal/drivers/amdgpu/abi/timestamp.h"
+#include "iree/hal/drivers/amdgpu/device/dispatch.h"
+
+#ifdef __cplusplus
+extern "C" {
+#endif  // __cplusplus
+
+// Returns the byte offset of the harvest source table after the kernel args.
+static inline size_t
+iree_hal_amdgpu_device_timestamp_dispatch_harvest_source_offset(void) {
+  return iree_amdgpu_align(
+      sizeof(iree_hal_amdgpu_dispatch_timestamp_harvest_args_t),
+      IREE_AMDGPU_ALIGNOF(iree_hal_amdgpu_dispatch_timestamp_harvest_source_t));
+}
+
+// Returns the kernarg byte length required for |source_count| harvest sources.
+static inline size_t
+iree_hal_amdgpu_device_timestamp_dispatch_harvest_kernarg_length(
+    uint32_t source_count) {
+  return iree_hal_amdgpu_device_timestamp_dispatch_harvest_source_offset() +
+         (size_t)source_count *
+             sizeof(iree_hal_amdgpu_dispatch_timestamp_harvest_source_t);
+}
+
+// Returns the harvest source table embedded in |kernarg_ptr|.
+static inline iree_hal_amdgpu_dispatch_timestamp_harvest_source_t*
+iree_hal_amdgpu_device_timestamp_dispatch_harvest_sources(
+    void* IREE_AMDGPU_RESTRICT kernarg_ptr) {
+  uint8_t* source_ptr =
+      (uint8_t*)kernarg_ptr +
+      iree_hal_amdgpu_device_timestamp_dispatch_harvest_source_offset();
+  return (iree_hal_amdgpu_dispatch_timestamp_harvest_source_t*)source_ptr;
+}
+
+// Populates a builtin dispatch packet and kernargs that harvest timestamps from
+// dispatch completion signals into fixed binary timestamp records.
+//
+// |dispatch_packet| and |kernarg_ptr| must point to reserved queue storage.
+// The caller owns completion-signal assignment and header commit.
+iree_hal_amdgpu_dispatch_timestamp_harvest_source_t*
+iree_hal_amdgpu_device_timestamp_emplace_dispatch_harvest(
+    const iree_hal_amdgpu_device_kernel_args_t* IREE_AMDGPU_RESTRICT
+        harvest_kernel_args,
+    uint32_t source_count,
+    iree_hsa_kernel_dispatch_packet_t* IREE_AMDGPU_RESTRICT dispatch_packet,
+    void* IREE_AMDGPU_RESTRICT kernarg_ptr);
+
+#if defined(IREE_AMDGPU_TARGET_DEVICE)
+
+// Device builtin that copies per-dispatch CP timestamps into timestamp records.
+IREE_AMDGPU_ATTRIBUTE_KERNEL void
+iree_hal_amdgpu_device_timestamp_harvest_dispatch_records(
+    const iree_hal_amdgpu_dispatch_timestamp_harvest_source_t*
+        IREE_AMDGPU_RESTRICT sources,
+    uint32_t source_count);
+
+#endif  // IREE_AMDGPU_TARGET_DEVICE
+
+#ifdef __cplusplus
+}  // extern "C"
+#endif  // __cplusplus
+
+#endif  // IREE_HAL_DRIVERS_AMDGPU_DEVICE_TIMESTAMP_H_
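
Sketched usage from a device-side producer, under the contract stated in the
header (queue reservation, completion-signal assignment, and header commit
remain the caller's job; `packet`, `kernargs`, `tracked_signals`, and
`records` are assumed):

```c
// Illustrative only: emit one harvest dispatch covering `count` sources.
// `kernargs` must point at reserved storage of at least
// iree_hal_amdgpu_device_timestamp_dispatch_harvest_kernarg_length(count).
iree_hal_amdgpu_dispatch_timestamp_harvest_source_t* sources =
    iree_hal_amdgpu_device_timestamp_emplace_dispatch_harvest(
        harvest_kernel_args, count, packet, kernargs);
for (uint32_t i = 0; i < count; ++i) {
  sources[i].completion_signal = tracked_signals[i];  // signal to harvest
  sources[i].ticks = &records[i].ticks;               // record destination
}
// The caller then assigns packet->completion_signal (if any) and commits the
// packet header to make it visible to the packet processor.
```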
diff --git a/runtime/src/iree/hal/drivers/amdgpu/device/timestamp_test.cc b/runtime/src/iree/hal/drivers/amdgpu/device/timestamp_test.cc
new file mode 100644
index 0000000..0370c4d
--- /dev/null
+++ b/runtime/src/iree/hal/drivers/amdgpu/device/timestamp_test.cc
@@ -0,0 +1,119 @@
+// Copyright 2026 The IREE Authors
+//
+// Licensed under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+
+#include "iree/hal/drivers/amdgpu/device/timestamp.h"
+
+#include <array>
+#include <cstddef>
+#include <cstdint>
+
+#include "iree/hal/drivers/amdgpu/abi/profile.h"
+#include "iree/testing/gtest.h"
+
+namespace iree::hal::amdgpu {
+namespace {
+
+static iree_hal_amdgpu_device_kernel_args_t MakeHarvestKernelArgs() {
+  iree_hal_amdgpu_device_kernel_args_t kernel_args = {};
+  kernel_args.kernel_object = 0x12345678ull;
+  kernel_args.setup = 2;
+  kernel_args.workgroup_size[0] = 32;
+  kernel_args.workgroup_size[1] = 1;
+  kernel_args.workgroup_size[2] = 1;
+  kernel_args.kernarg_alignment = 16;
+  return kernel_args;
+}
+
+TEST(TimestampTest, AbiRecordLayoutIsFixed) {
+  EXPECT_EQ(sizeof(iree_hal_amdgpu_timestamp_range_t), 16u);
+  EXPECT_EQ(sizeof(iree_hal_amdgpu_timestamp_record_header_t), 16u);
+  EXPECT_EQ(sizeof(iree_hal_amdgpu_command_buffer_timestamp_record_t), 48u);
+  EXPECT_EQ(sizeof(iree_hal_amdgpu_dispatch_timestamp_record_t), 64u);
+  EXPECT_EQ(offsetof(iree_hal_amdgpu_command_buffer_timestamp_record_t, ticks),
+            32u);
+  EXPECT_EQ(offsetof(iree_hal_amdgpu_dispatch_timestamp_record_t, ticks), 48u);
+}
+
+TEST(TimestampTest, MakesRecordHeader) {
+  iree_hal_amdgpu_timestamp_record_header_t header = {};
+  header.record_length = sizeof(iree_hal_amdgpu_dispatch_timestamp_record_t);
+  header.version = IREE_HAL_AMDGPU_TIMESTAMP_RECORD_VERSION_0;
+  header.type = IREE_HAL_AMDGPU_TIMESTAMP_RECORD_TYPE_DISPATCH;
+  header.record_ordinal = 7;
+
+  EXPECT_EQ(header.record_length,
+            sizeof(iree_hal_amdgpu_dispatch_timestamp_record_t));
+  EXPECT_EQ(header.version, IREE_HAL_AMDGPU_TIMESTAMP_RECORD_VERSION_0);
+  EXPECT_EQ(header.type, IREE_HAL_AMDGPU_TIMESTAMP_RECORD_TYPE_DISPATCH);
+  EXPECT_EQ(header.record_ordinal, 7u);
+  EXPECT_EQ(header.reserved0, 0u);
+}
+
+TEST(TimestampTest, ComputesHarvestKernargLayout) {
+  EXPECT_EQ(iree_hal_amdgpu_device_timestamp_dispatch_harvest_source_offset(),
+            16u);
+  EXPECT_EQ(iree_hal_amdgpu_device_timestamp_dispatch_harvest_kernarg_length(0),
+            16u);
+  EXPECT_EQ(iree_hal_amdgpu_device_timestamp_dispatch_harvest_kernarg_length(3),
+            64u);
+}
+
+TEST(TimestampTest, ProfileDispatchHarvestUsesTimestampRangeTarget) {
+  EXPECT_EQ(sizeof(iree_hal_amdgpu_profile_dispatch_harvest_source_t),
+            sizeof(iree_hal_amdgpu_dispatch_timestamp_harvest_source_t));
+  EXPECT_EQ(sizeof(iree_hal_amdgpu_profile_dispatch_harvest_args_t),
+            sizeof(iree_hal_amdgpu_dispatch_timestamp_harvest_args_t));
+
+  iree_hal_amdgpu_profile_dispatch_event_t event = {};
+  iree_hal_amdgpu_timestamp_range_t* ticks =
+      iree_hal_amdgpu_profile_dispatch_event_ticks(&event);
+  EXPECT_EQ(reinterpret_cast<uintptr_t>(ticks),
+            reinterpret_cast<uintptr_t>(&event.start_tick));
+
+  ticks->start_tick = 11;
+  ticks->end_tick = 22;
+  EXPECT_EQ(event.start_tick, 11u);
+  EXPECT_EQ(event.end_tick, 22u);
+}
+
+TEST(TimestampTest, EmplacesDispatchHarvestPacketAndKernargs) {
+  iree_hal_amdgpu_device_kernel_args_t kernel_args = MakeHarvestKernelArgs();
+  iree_hsa_kernel_dispatch_packet_t packet = {};
+  packet.header = 0xFFFFu;
+  alignas(16) std::array<uint8_t, 256> kernargs = {};
+  const uint32_t source_count = 65;
+
+  iree_hal_amdgpu_dispatch_timestamp_harvest_source_t* sources =
+      iree_hal_amdgpu_device_timestamp_emplace_dispatch_harvest(
+          &kernel_args, source_count, &packet, kernargs.data());
+  const auto* args = reinterpret_cast<
+      const iree_hal_amdgpu_dispatch_timestamp_harvest_args_t*>(
+      kernargs.data());
+
+  EXPECT_EQ(args->sources, sources);
+  EXPECT_EQ(args->source_count, source_count);
+  EXPECT_EQ(args->reserved0, 0u);
+  EXPECT_EQ(
+      sources,
+      reinterpret_cast<iree_hal_amdgpu_dispatch_timestamp_harvest_source_t*>(
+          kernargs.data() +
+          iree_hal_amdgpu_device_timestamp_dispatch_harvest_source_offset()));
+
+  EXPECT_EQ(packet.header, 0xFFFFu);
+  EXPECT_EQ(packet.setup, 2u);
+  EXPECT_EQ(packet.workgroup_size[0], 32u);
+  EXPECT_EQ(packet.workgroup_size[1], 1u);
+  EXPECT_EQ(packet.workgroup_size[2], 1u);
+  EXPECT_EQ(packet.grid_size[0], 96u);
+  EXPECT_EQ(packet.grid_size[1], 1u);
+  EXPECT_EQ(packet.grid_size[2], 1u);
+  EXPECT_EQ(packet.kernel_object, 0x12345678ull);
+  EXPECT_EQ(packet.kernarg_address, kernargs.data());
+  EXPECT_EQ(packet.completion_signal.handle, iree_hsa_signal_null().handle);
+}
+
+}  // namespace
+}  // namespace iree::hal::amdgpu
diff --git a/runtime/src/iree/hal/drivers/amdgpu/device/tracing.c b/runtime/src/iree/hal/drivers/amdgpu/device/tracing.c
deleted file mode 100644
index 427ff75..0000000
--- a/runtime/src/iree/hal/drivers/amdgpu/device/tracing.c
+++ /dev/null
@@ -1,547 +0,0 @@
-// Copyright 2025 The IREE Authors
-//
-// Licensed under the Apache License v2.0 with LLVM Exceptions.
-// See https://llvm.org/LICENSE.txt for license information.
-// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
-
-#include "iree/hal/drivers/amdgpu/device/tracing.h"
-
-// NOTE: this header in clang only declares the builtins for va_list-related
-// things - if it becomes an issue we can easily inline them here.
-#include <stdarg.h>
-
-#include "iree/hal/drivers/amdgpu/device/support/queue.h"
-
-//===----------------------------------------------------------------------===//
-// iree_hal_amdgpu_device_query_ringbuffer_t
-//===----------------------------------------------------------------------===//
-
-iree_hal_amdgpu_trace_execution_query_id_t
-iree_hal_amdgpu_device_query_ringbuffer_acquire(
-    iree_hal_amdgpu_device_query_ringbuffer_t* IREE_AMDGPU_RESTRICT
-        ringbuffer) {
-  // Slice off a single value and return it mapped into a query ID.
-  return (ringbuffer->write_index++) &
-         (IREE_AMDGPU_ARRAYSIZE(ringbuffer->signals) - 1);
-}
-
-uint64_t iree_hal_amdgpu_device_query_ringbuffer_acquire_range(
-    iree_hal_amdgpu_device_query_ringbuffer_t* IREE_AMDGPU_RESTRICT ringbuffer,
-    uint16_t count) {
-  // Slice off another chunk.
-  uint64_t base_index = ringbuffer->write_index;
-  ringbuffer->write_index += count;
-  return base_index;
-}
-
-void iree_hal_amdgpu_device_query_ringbuffer_release_range(
-    iree_hal_amdgpu_device_query_ringbuffer_t* IREE_AMDGPU_RESTRICT ringbuffer,
-    uint16_t count) {
-  // Reset all returned signals.
-  for (uint32_t i = ringbuffer->read_index; i < ringbuffer->read_index + count;
-       ++i) {
-    iree_amd_signal_t* signal =
-        &ringbuffer
-             ->signals[i & (IREE_AMDGPU_ARRAYSIZE(ringbuffer->signals) - 1)];
-    signal->value = 1;
-    signal->start_ts = 0;
-    signal->end_ts = 0;
-  }
-  ringbuffer->read_index += count;
-}
-
-//===----------------------------------------------------------------------===//
-// iree_hal_amdgpu_device_trace_buffer_t
-//===----------------------------------------------------------------------===//
-
-#if IREE_HAL_AMDGPU_TRACING_FEATURES
-
-// Initializes the signals in a trace buffer's query ringbuffer.
-// The ringbuffer memory must be zero-initialized by the allocator.
-// Must be called prior to acquiring any signals.
-IREE_AMDGPU_ATTRIBUTE_KERNEL void
-iree_hal_amdgpu_device_trace_buffer_initialize(
-    iree_hal_amdgpu_device_trace_buffer_t* IREE_AMDGPU_RESTRICT trace_buffer) {
-  const size_t i = iree_hal_amdgpu_device_global_id_x();
-
-  // Initialize device trace buffer memory (once).
-  if (i == 0) {
-    trace_buffer->read_commit_offset = 0;
-    trace_buffer->write_reserve_offset = 0;
-    trace_buffer->write_commit_offset = 0;
-    trace_buffer->query_ringbuffer.read_index = 0;
-    trace_buffer->query_ringbuffer.write_index = 0;
-  }
-
-  // Initialize signal.
-  iree_amd_signal_t* signal = &trace_buffer->query_ringbuffer.signals[i];
-  signal->kind = IREE_AMD_SIGNAL_KIND_USER;
-  signal->value = 1;
-  signal->event_mailbox_ptr = 0;
-  signal->event_id = 0;
-  signal->reserved1 = 0;
-  signal->start_ts = 0;
-  signal->end_ts = 0;
-  signal->queue_ptr = 0;
-  signal->reserved3[0] = 0;
-  signal->reserved3[1] = 0;
-}
-
-// Reserves |length| bytes from the trace buffer and returns a pointer to it.
-// Callers must populate the entire packet prior to calling
-// iree_hal_amdgpu_device_trace_commit_range. Multiple reservations can be made
-// between commits to batch the commit logic (which usually involves a host
-// interrupt to flush the ringbuffer).
-static inline void* iree_hal_amdgpu_device_trace_reserve_range(
-    iree_hal_amdgpu_device_trace_buffer_t* IREE_AMDGPU_RESTRICT trace_buffer,
-    size_t length) {
-  // Reserve a range of the requested size from the current reservation offset.
-  // NOTE: this is only modified on device and on the agent associated with the
-  // scheduler that's calling this and as such only has to be at device scope.
-  uint64_t write_offset = iree_amdgpu_scoped_atomic_fetch_add(
-      &trace_buffer->write_reserve_offset, length,
-      iree_amdgpu_memory_order_relaxed, iree_amdgpu_memory_scope_device);
-
-  // Spin until there's capacity in the ringbuffer. We need to wait until the
-  // host catches up to our last flush.
-  // WARNING: this may lock up forever if we really spill the ring.
-  // TODO(benvanik): find a way to fail here, or throw an interrupt.
-  // We could use a signal instead of an atomic but there's no good way to park
-  // from the current pc.
-  while (write_offset + length -
-          iree_amdgpu_scoped_atomic_load(&trace_buffer->read_commit_offset,
-                                         iree_amdgpu_memory_order_acquire,
-                                         iree_amdgpu_memory_scope_system) >
-      trace_buffer->ringbuffer_capacity) {
-    iree_amdgpu_yield();
-  }
-
-  // Calculate base address of the packet within the ringbuffer. Note that it
-  // may extend off the end of the base allocation but so long as the length is
-  // in bounds it'll be accessing the physical memory through the subsequent
-  // virtual address mapping.
-  void* packet_ptr =
-      (uint8_t*)trace_buffer->ringbuffer_base +
-      (write_offset & iree_hal_amdgpu_device_trace_buffer_mask(trace_buffer));
-
-  return packet_ptr;
-}
-
-bool iree_hal_amdgpu_device_trace_commit_range(
-    iree_hal_amdgpu_device_trace_buffer_t* IREE_AMDGPU_RESTRICT trace_buffer) {
-  // Bump the commit offset as seen by the host to the reserve offset at the
-  // start of this call. The host may immediately begin reading from its last
-  // read_commit_offset up to the new write_commit_offset and we cannot
-  // overwrite any of that range until the read_commit_offset has been bumped by
-  // the host.
-  uint64_t last_reserve_offset = iree_amdgpu_scoped_atomic_load(
-      &trace_buffer->write_reserve_offset, iree_amdgpu_memory_order_acquire,
-      iree_amdgpu_memory_scope_device);
-  uint64_t last_commit_offset = iree_amdgpu_scoped_atomic_exchange(
-      &trace_buffer->write_commit_offset, last_reserve_offset,
-      iree_amdgpu_memory_order_release, iree_amdgpu_memory_scope_system);
-
-  // If the last commit offset matches the last reserve offset then there were
-  // no pending writes to commit and the caller does not need to notify the
-  // host.
-  return last_reserve_offset != last_commit_offset;
-}
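-
-// An illustrative reserve/populate/commit sequence (the packet fields are
-// placeholders; multiple reservations may be batched before the commit):
-//   iree_hal_amdgpu_trace_zone_begin_t* packet =
-//       iree_hal_amdgpu_device_trace_reserve_range(trace_buffer,
-//                                                  sizeof(*packet));
-//   packet->event_type = IREE_HAL_AMDGPU_TRACE_EVENT_ZONE_BEGIN;
-//   packet->timestamp = iree_amdgpu_device_timestamp();
-//   packet->src_loc = src_loc;
-//   if (iree_hal_amdgpu_device_trace_commit_range(trace_buffer)) {
-//     // Pending events were committed: interrupt the host to flush them.
-//   }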
-
-#else
-
-// Dummy to appease the loader.
-// We could find some macro tricks to remove this entirely but we want to allow
-// host and device builds to differ in their settings.
-IREE_AMDGPU_ATTRIBUTE_KERNEL void
-iree_hal_amdgpu_device_trace_buffer_initialize(
-    iree_hal_amdgpu_device_trace_buffer_t* IREE_AMDGPU_RESTRICT trace_buffer) {}
-
-bool iree_hal_amdgpu_device_trace_commit_range(
-    iree_hal_amdgpu_device_trace_buffer_t* IREE_AMDGPU_RESTRICT trace_buffer) {
-  return false;
-}
-
-#endif  // IREE_HAL_AMDGPU_TRACING_FEATURES
-
-//===----------------------------------------------------------------------===//
-// IREE_HAL_AMDGPU_TRACING_FEATURE_INSTRUMENTATION
-//===----------------------------------------------------------------------===//
-
-#if IREE_HAL_AMDGPU_HAS_TRACING_FEATURE( \
-    IREE_HAL_AMDGPU_TRACING_FEATURE_INSTRUMENTATION)
-
-iree_hal_amdgpu_zone_id_t iree_hal_amdgpu_device_trace_zone_begin(
-    iree_hal_amdgpu_device_trace_buffer_t* IREE_AMDGPU_RESTRICT trace_buffer,
-    iree_hal_amdgpu_trace_src_loc_ptr_t src_loc) {
-  iree_hal_amdgpu_trace_zone_begin_t* packet =
-      iree_hal_amdgpu_device_trace_reserve_range(
-          trace_buffer, sizeof(iree_hal_amdgpu_trace_zone_begin_t));
-  packet->event_type = IREE_HAL_AMDGPU_TRACE_EVENT_ZONE_BEGIN;
-  packet->timestamp = iree_amdgpu_device_timestamp();
-  packet->src_loc = src_loc;
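-  // Zone IDs are not tracked on the device; return a fixed non-zero ID.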
-  return 1;
-}
-
-void iree_hal_amdgpu_device_trace_zone_end(
-    iree_hal_amdgpu_device_trace_buffer_t* IREE_AMDGPU_RESTRICT trace_buffer) {
-  iree_hal_amdgpu_trace_zone_end_t* packet =
-      iree_hal_amdgpu_device_trace_reserve_range(
-          trace_buffer, sizeof(iree_hal_amdgpu_trace_zone_end_t));
-  packet->event_type = IREE_HAL_AMDGPU_TRACE_EVENT_ZONE_END;
-  packet->timestamp = iree_amdgpu_device_timestamp();
-}
-
-void iree_hal_amdgpu_device_trace_zone_append_value_i64(
-    iree_hal_amdgpu_device_trace_buffer_t* IREE_AMDGPU_RESTRICT trace_buffer,
-    int64_t value) {
-  iree_hal_amdgpu_trace_zone_value_i64_t* packet =
-      iree_hal_amdgpu_device_trace_reserve_range(
-          trace_buffer, sizeof(iree_hal_amdgpu_trace_zone_value_i64_t));
-  packet->event_type = IREE_HAL_AMDGPU_TRACE_EVENT_ZONE_VALUE_I64;
-  packet->value = value;
-}
-
-void iree_hal_amdgpu_device_trace_zone_append_text_literal(
-    iree_hal_amdgpu_device_trace_buffer_t* IREE_AMDGPU_RESTRICT trace_buffer,
-    iree_hal_amdgpu_trace_string_literal_ptr_t value_literal) {
-  iree_hal_amdgpu_trace_zone_value_text_literal_t* packet =
-      iree_hal_amdgpu_device_trace_reserve_range(
-          trace_buffer,
-          sizeof(iree_hal_amdgpu_trace_zone_value_text_literal_t));
-  packet->event_type = IREE_HAL_AMDGPU_TRACE_EVENT_ZONE_VALUE_TEXT_LITERAL;
-  packet->value = value_literal;
-}
-
-void iree_hal_amdgpu_device_trace_zone_append_text_dynamic(
-    iree_hal_amdgpu_device_trace_buffer_t* IREE_AMDGPU_RESTRICT trace_buffer,
-    const char* IREE_AMDGPU_RESTRICT value, size_t value_length) {
-  const size_t total_size =
-      sizeof(iree_hal_amdgpu_trace_zone_value_text_dynamic_t) + value_length;
-  iree_hal_amdgpu_trace_zone_value_text_dynamic_t* packet =
-      iree_hal_amdgpu_device_trace_reserve_range(trace_buffer, total_size);
-  packet->event_type = IREE_HAL_AMDGPU_TRACE_EVENT_ZONE_VALUE_TEXT_DYNAMIC;
-  packet->length = (uint32_t)value_length;
-  iree_amdgpu_memcpy(packet->value, value, value_length);
-}
-
-void iree_hal_amdgpu_device_trace_plot_configure(
-    iree_hal_amdgpu_device_trace_buffer_t* IREE_AMDGPU_RESTRICT trace_buffer,
-    iree_hal_amdgpu_trace_string_literal_ptr_t name_literal,
-    iree_hal_amdgpu_trace_plot_type_t type,
-    iree_hal_amdgpu_trace_plot_flags_t flags,
-    iree_hal_amdgpu_trace_color_t color) {
-  iree_hal_amdgpu_trace_plot_config_t* packet =
-      iree_hal_amdgpu_device_trace_reserve_range(
-          trace_buffer, sizeof(iree_hal_amdgpu_trace_plot_config_t));
-  packet->event_type = IREE_HAL_AMDGPU_TRACE_EVENT_PLOT_CONFIG;
-  packet->plot_type = type;
-  packet->plot_flags = flags;
-  packet->color = color;
-  packet->name = name_literal;
-}
-
-void iree_hal_amdgpu_device_trace_plot_value_i64(
-    iree_hal_amdgpu_device_trace_buffer_t* IREE_AMDGPU_RESTRICT trace_buffer,
-    iree_hal_amdgpu_trace_string_literal_ptr_t name_literal, int64_t value) {
-  iree_hal_amdgpu_trace_plot_value_i64_t* packet =
-      iree_hal_amdgpu_device_trace_reserve_range(
-          trace_buffer, sizeof(iree_hal_amdgpu_trace_plot_value_i64_t));
-  packet->event_type = IREE_HAL_AMDGPU_TRACE_EVENT_PLOT_VALUE_I64;
-  packet->plot_name = name_literal;
-  packet->timestamp = iree_amdgpu_device_timestamp();
-  packet->value = value;
-}
-
-#endif  // IREE_HAL_AMDGPU_TRACING_FEATURE_INSTRUMENTATION
-
-//===----------------------------------------------------------------------===//
-// IREE_HAL_AMDGPU_TRACING_FEATURE_DEVICE_CONTROL
-//===----------------------------------------------------------------------===//
-
-#if IREE_HAL_AMDGPU_HAS_TRACING_FEATURE( \
-    IREE_HAL_AMDGPU_TRACING_FEATURE_DEVICE_CONTROL)
-
-iree_hsa_signal_t iree_hal_amdgpu_device_trace_execution_zone_begin(
-    iree_hal_amdgpu_device_trace_buffer_t* IREE_AMDGPU_RESTRICT trace_buffer,
-    iree_hal_amdgpu_trace_execution_query_id_t execution_query_id,
-    iree_hal_amdgpu_trace_src_loc_ptr_t src_loc) {
-  iree_hal_amdgpu_trace_execution_zone_begin_t* packet =
-      iree_hal_amdgpu_device_trace_reserve_range(
-          trace_buffer, sizeof(iree_hal_amdgpu_trace_execution_zone_begin_t));
-  packet->event_type = IREE_HAL_AMDGPU_TRACE_EVENT_EXECUTION_ZONE_BEGIN;
-  packet->executor_id = trace_buffer->executor_id;
-  packet->execution_query_id = execution_query_id;
-  packet->issue_timestamp = iree_amdgpu_device_timestamp();
-  packet->src_loc = src_loc;
-  return iree_hal_amdgpu_device_query_ringbuffer_signal_for_id(
-      &trace_buffer->query_ringbuffer, execution_query_id);
-}
-
-iree_hsa_signal_t iree_hal_amdgpu_device_trace_execution_zone_end(
-    iree_hal_amdgpu_device_trace_buffer_t* IREE_AMDGPU_RESTRICT trace_buffer,
-    iree_hal_amdgpu_trace_execution_query_id_t execution_query_id) {
-  iree_hal_amdgpu_trace_execution_zone_end_t* packet =
-      iree_hal_amdgpu_device_trace_reserve_range(
-          trace_buffer, sizeof(iree_hal_amdgpu_trace_execution_zone_end_t));
-  packet->event_type = IREE_HAL_AMDGPU_TRACE_EVENT_EXECUTION_ZONE_END;
-  packet->executor_id = trace_buffer->executor_id;
-  packet->execution_query_id = execution_query_id;
-  packet->issue_timestamp = iree_amdgpu_device_timestamp();
-  return iree_hal_amdgpu_device_query_ringbuffer_signal_for_id(
-      &trace_buffer->query_ringbuffer, execution_query_id);
-}
-
-iree_hal_amdgpu_trace_agent_time_range_t* IREE_AMDGPU_RESTRICT
-iree_hal_amdgpu_device_trace_execution_zone_notify_batch(
-    iree_hal_amdgpu_device_trace_buffer_t* IREE_AMDGPU_RESTRICT trace_buffer,
-    iree_hal_amdgpu_trace_execution_query_id_t execution_query_id_base,
-    uint16_t execution_query_count) {
-  iree_hal_amdgpu_trace_execution_zone_notify_batch_t* packet =
-      iree_hal_amdgpu_device_trace_reserve_range(
-          trace_buffer,
-          sizeof(iree_hal_amdgpu_trace_execution_zone_notify_batch_t) +
-              execution_query_count *
-                  sizeof(iree_hal_amdgpu_trace_agent_time_range_t));
-  packet->event_type = IREE_HAL_AMDGPU_TRACE_EVENT_EXECUTION_ZONE_NOTIFY_BATCH;
-  packet->executor_id = trace_buffer->executor_id;
-  packet->execution_query_id_base = execution_query_id_base;
-  packet->execution_query_count = execution_query_count;
-  return &packet->execution_time_ranges[0];
-}
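-
-// An illustrative timestamp feedback sketch (assumes the caller has resolved
-// |count| completed queries beginning at |query_id_base|):
-//   iree_hal_amdgpu_trace_agent_time_range_t* ranges =
-//       iree_hal_amdgpu_device_trace_execution_zone_notify_batch(
-//           trace_buffer, query_id_base, count);
-//   for (uint16_t i = 0; i < count; ++i) {
-//     const iree_amd_signal_t* signal =
-//         &trace_buffer->query_ringbuffer.signals[query_id_base + i];
-//     ranges[i].begin = signal->start_ts;
-//     ranges[i].end = signal->end_ts;
-//   }
-//   iree_hal_amdgpu_device_query_ringbuffer_release_range(
-//       &trace_buffer->query_ringbuffer, count);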
-
-#endif  // IREE_HAL_AMDGPU_TRACING_FEATURE_DEVICE_CONTROL
-
-//===----------------------------------------------------------------------===//
-// IREE_HAL_AMDGPU_TRACING_FEATURE_DEVICE_EXECUTION
-//===----------------------------------------------------------------------===//
-
-#if IREE_HAL_AMDGPU_HAS_TRACING_FEATURE( \
-    IREE_HAL_AMDGPU_TRACING_FEATURE_DEVICE_EXECUTION)
-
-iree_hsa_signal_t iree_hal_amdgpu_device_trace_execution_zone_dispatch(
-    iree_hal_amdgpu_device_trace_buffer_t* IREE_AMDGPU_RESTRICT trace_buffer,
-    iree_hal_amdgpu_trace_execution_zone_type_t zone_type, uint64_t export_loc,
-    iree_hal_amdgpu_trace_execution_query_id_t execution_query_id) {
-  iree_hal_amdgpu_trace_execution_zone_dispatch_t* packet =
-      iree_hal_amdgpu_device_trace_reserve_range(
-          trace_buffer,
-          sizeof(iree_hal_amdgpu_trace_execution_zone_dispatch_t));
-  packet->event_type = IREE_HAL_AMDGPU_TRACE_EVENT_EXECUTION_ZONE_DISPATCH;
-  packet->zone_type = zone_type;
-  packet->executor_id = trace_buffer->executor_id;
-  packet->execution_query_id = execution_query_id;
-  packet->export_loc = export_loc;
-  packet->issue_timestamp = iree_amdgpu_device_timestamp();
-  return iree_hal_amdgpu_device_query_ringbuffer_signal_for_id(
-      &trace_buffer->query_ringbuffer, execution_query_id);
-}
-
-#endif  // IREE_HAL_AMDGPU_TRACING_FEATURE_DEVICE_EXECUTION
-
-//===----------------------------------------------------------------------===//
-// IREE_HAL_AMDGPU_TRACING_FEATURE_ALLOCATION_TRACKING
-//===----------------------------------------------------------------------===//
-
-#if IREE_HAL_AMDGPU_HAS_TRACING_FEATURE( \
-    IREE_HAL_AMDGPU_TRACING_FEATURE_ALLOCATION_TRACKING)
-
-void iree_hal_amdgpu_device_trace_memory_alloc(
-    iree_hal_amdgpu_device_trace_buffer_t* IREE_AMDGPU_RESTRICT trace_buffer,
-    iree_hal_amdgpu_trace_string_literal_ptr_t name_literal, uint64_t ptr,
-    uint64_t size) {
-  iree_hal_amdgpu_trace_memory_alloc_t* packet =
-      iree_hal_amdgpu_device_trace_reserve_range(
-          trace_buffer, sizeof(iree_hal_amdgpu_trace_memory_alloc_t));
-  packet->event_type = IREE_HAL_AMDGPU_TRACE_EVENT_MEMORY_ALLOC;
-  packet->pool = name_literal;
-  packet->timestamp = iree_amdgpu_device_timestamp();
-  packet->ptr = ptr;
-  packet->size = size;
-}
-
-void iree_hal_amdgpu_device_trace_memory_free(
-    iree_hal_amdgpu_device_trace_buffer_t* IREE_AMDGPU_RESTRICT trace_buffer,
-    iree_hal_amdgpu_trace_string_literal_ptr_t name_literal, uint64_t ptr) {
-  iree_hal_amdgpu_trace_memory_free_t* packet =
-      iree_hal_amdgpu_device_trace_reserve_range(
-          trace_buffer, sizeof(iree_hal_amdgpu_trace_memory_free_t));
-  packet->event_type = IREE_HAL_AMDGPU_TRACE_EVENT_MEMORY_FREE;
-  packet->timestamp = iree_amdgpu_device_timestamp();
-  packet->ptr = ptr;
-}
-
-#endif  // IREE_HAL_AMDGPU_TRACING_FEATURE_ALLOCATION_TRACKING
-
-//===----------------------------------------------------------------------===//
-// IREE_HAL_AMDGPU_TRACING_FEATURE_LOG_MESSAGES
-//===----------------------------------------------------------------------===//
-
-#if IREE_HAL_AMDGPU_HAS_TRACING_FEATURE( \
-    IREE_HAL_AMDGPU_TRACING_FEATURE_LOG_MESSAGES)
-
-void iree_hal_amdgpu_device_trace_message_literal(
-    iree_hal_amdgpu_device_trace_buffer_t* IREE_AMDGPU_RESTRICT trace_buffer,
-    iree_hal_amdgpu_trace_color_t color,
-    iree_hal_amdgpu_trace_string_literal_ptr_t value_literal) {
-  iree_hal_amdgpu_trace_message_literal_t* packet =
-      iree_hal_amdgpu_device_trace_reserve_range(
-          trace_buffer, sizeof(iree_hal_amdgpu_trace_message_literal_t));
-  packet->event_type = IREE_HAL_AMDGPU_TRACE_EVENT_MESSAGE_LITERAL;
-  packet->timestamp = iree_amdgpu_device_timestamp();
-  packet->value = value_literal;
-}
-
-void iree_hal_amdgpu_device_trace_message_dynamic(
-    iree_hal_amdgpu_device_trace_buffer_t* IREE_AMDGPU_RESTRICT trace_buffer,
-    iree_hal_amdgpu_trace_color_t color, const char* IREE_AMDGPU_RESTRICT value,
-    size_t value_length) {
-  const size_t total_size =
-      sizeof(iree_hal_amdgpu_trace_message_dynamic_t) + value_length;
-  iree_hal_amdgpu_trace_message_dynamic_t* packet =
-      iree_hal_amdgpu_device_trace_reserve_range(trace_buffer, total_size);
-  packet->event_type = IREE_HAL_AMDGPU_TRACE_EVENT_MESSAGE_DYNAMIC;
-  packet->length = (uint32_t)value_length;
-  packet->timestamp = iree_amdgpu_device_timestamp();
-  iree_amdgpu_memcpy(packet->value, value, value_length);
-}
-
-// WARNING: this implementation is for debugging purposes only and is likely
-// both unsafe and incorrect.
-static int iree_hal_amdgpu_itoa_uint64(char* IREE_AMDGPU_RESTRICT buffer,
-                                       uint64_t value, int base,
-                                       char base_hex_char) {
-  int length = 0;
-  uint64_t d = 1;
-  while (value / d >= base) d *= base;
-  while (d != 0) {
-    int digit = value / d;
-    value %= d;
-    d /= base;
-    if (length > 0 || digit > 0 || d == 0) {
-      if (buffer) {
-        buffer[length] = digit + (digit < 10 ? '0' : base_hex_char - 10);
-      }
-      ++length;
-    }
-  }
-  return length;
-}
-
-// WARNING: this implementation is for debugging purposes only and is likely
-// both unsafe and incorrect.
-static int iree_hal_amdgpu_vsprintf(char* IREE_AMDGPU_RESTRICT buffer,
-                                    char const* IREE_AMDGPU_RESTRICT format,
-                                    va_list vlist) {
-  int length = 0;
-  while (format[0] != '\0') {
-    int c = format[0];
-    ++format;
-    if (c != '%') {
-      if (buffer) buffer[length] = c;
-      ++length;
-      continue;
-    }
-    c = format[0];
-    ++format;
-    bool is_long = false;
-    if (c == 'l') {
-      c = format[0];
-      ++format;
-      is_long = true;
-    }
-    switch (c) {
-      case 0: {
-        return length;
-      }
-      case '%': {
-        if (buffer) buffer[length] = '%';
-        ++length;
-        break;
-      }
-      case 'c': {
-        const char value = va_arg(vlist, int);
-        if (buffer) buffer[length] = value;
-        ++length;
-        break;
-      }
-      case 's': {
-        const char* str = va_arg(vlist, const char*);
-        if (buffer) {
-          for (; *str; ++str, ++length) {
-            buffer[length] = *str;
-          }
-        } else {
-          for (; *str; ++str, ++length);
-        }
-        break;
-      }
-      case 'u': {
-        const uint64_t value = is_long ? va_arg(vlist, unsigned long int)
-                                       : va_arg(vlist, unsigned int);
-        length += iree_hal_amdgpu_itoa_uint64(buffer ? &buffer[length] : NULL,
-                                              value, 10, 0);
-        break;
-      }
-      case 'x':
-      case 'X': {
-        const uint64_t value = is_long ? va_arg(vlist, unsigned long int)
-                                       : va_arg(vlist, unsigned int);
-        length += iree_hal_amdgpu_itoa_uint64(buffer ? &buffer[length] : NULL,
-                                              value, 16, c == 'X' ? 'A' : 'a');
-        break;
-      }
-      case 'p': {
-        if (buffer) {
-          buffer[length + 0] = '0';
-          buffer[length + 1] = 'x';
-        }
-        length += 2;
-        const uint64_t value = va_arg(vlist, unsigned long int);
-        length += iree_hal_amdgpu_itoa_uint64(buffer ? &buffer[length] : NULL,
-                                              value, 16, 'A');
-        break;
-      }
-      default: {
-        // Unhandled.
-        if (buffer) buffer[length] = '?';
-        ++length;
-        break;
-      }
-    }
-  }
-  return length;
-}
-
-void iree_hal_amdgpu_device_trace_debug(
-    iree_hal_amdgpu_device_trace_buffer_t* IREE_AMDGPU_RESTRICT trace_buffer,
-    const char* IREE_AMDGPU_RESTRICT format, ...) {
-  // NOTE: we capture the timestamp before doing our string processing - not
-  // that this should be used for performance work but having the message as
-  // close in time to the originating site is always nice.
-  const iree_hal_amdgpu_trace_agent_timestamp_t timestamp =
-      iree_amdgpu_device_timestamp();
-
-  va_list vlist0, vlist1;
-  va_start(vlist0, format);
-  va_start(vlist1, format);
-
-  // Determine total length of the formatted message in characters.
-  int required_length = iree_hal_amdgpu_vsprintf(NULL, format, vlist0);
-
-  // Reserve trace buffer space for the message contents.
-  const size_t total_size =
-      sizeof(iree_hal_amdgpu_trace_message_dynamic_t) + required_length;
-  iree_hal_amdgpu_trace_message_dynamic_t* packet =
-      iree_hal_amdgpu_device_trace_reserve_range(trace_buffer, total_size);
-  packet->event_type = IREE_HAL_AMDGPU_TRACE_EVENT_MESSAGE_DYNAMIC;
-  packet->length = (uint32_t)required_length;
-  packet->timestamp = timestamp;
-
-  // Print the message into the reserved space.
-  iree_hal_amdgpu_vsprintf(packet->value, format, vlist1);
-
-  va_end(vlist0);
-  va_end(vlist1);
-}
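-
-// Example (illustrative; |ordinal| and |packet| are placeholder variables):
-//   iree_hal_amdgpu_device_trace_debug(
-//       trace_buffer, "dispatch %u retired (packet %p)", ordinal, packet);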
-
-#endif  // IREE_HAL_AMDGPU_TRACING_FEATURE_LOG_MESSAGES
diff --git a/runtime/src/iree/hal/drivers/amdgpu/device/tracing.h b/runtime/src/iree/hal/drivers/amdgpu/device/tracing.h
deleted file mode 100644
index 25bb35d..0000000
--- a/runtime/src/iree/hal/drivers/amdgpu/device/tracing.h
+++ /dev/null
@@ -1,1020 +0,0 @@
-// Copyright 2025 The IREE Authors
-//
-// Licensed under the Apache License v2.0 with LLVM Exceptions.
-// See https://llvm.org/LICENSE.txt for license information.
-// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
-
-#ifndef IREE_HAL_DRIVERS_AMDGPU_DEVICE_TRACING_H_
-#define IREE_HAL_DRIVERS_AMDGPU_DEVICE_TRACING_H_
-
-#include "iree/hal/drivers/amdgpu/device/support/common.h"
-#include "iree/hal/drivers/amdgpu/device/support/signal.h"
-
-//===----------------------------------------------------------------------===//
-// IREE_HAL_AMDGPU_TRACING_FEATURE_* Flags and Options
-//===----------------------------------------------------------------------===//
-
-// Enables IREE_AMDGPU_TRACE_* macros for instrumented tracing.
-#define IREE_HAL_AMDGPU_TRACING_FEATURE_INSTRUMENTATION (1 << 0)
-
-// Enables instrumentation of command buffer control (dispatches, DMA, etc).
-// This can have significant code size and runtime overhead and should only be
-// used when specifically tracing device-side execution.
-#define IREE_HAL_AMDGPU_TRACING_FEATURE_DEVICE_CONTROL (1 << 1)
-
-// Enables instrumentation of command buffer execution (dispatches, DMA, etc).
-// This can have significant code size and runtime overhead and should only be
-// used when specifically tracing device-side execution.
-#define IREE_HAL_AMDGPU_TRACING_FEATURE_DEVICE_EXECUTION \
-  ((1 << 2) | IREE_HAL_AMDGPU_TRACING_FEATURE_DEVICE_CONTROL)
-
-// Tracks all device allocations.
-#define IREE_HAL_AMDGPU_TRACING_FEATURE_ALLOCATION_TRACKING (1 << 3)
-
-// Forwards log messages to traces, which will be visible under "Messages" in
-// the Tracy UI.
-#define IREE_HAL_AMDGPU_TRACING_FEATURE_LOG_MESSAGES (1 << 4)
-
-// Enables the IREE_HAL_AMDGPU_DBG print macros. May massively increase binary
-// size and decrease performance.
-#define IREE_HAL_AMDGPU_TRACING_FEATURE_DEBUG_MESSAGES \
-  ((1 << 5) | IREE_HAL_AMDGPU_TRACING_FEATURE_LOG_MESSAGES)
-
-// TODO(benvanik): expose as a friendly option matching the host mode.
-// For now we need the compilation to match and there are extra flags required
-// for that.
-#if 0
-#define IREE_HAL_AMDGPU_TRACING_FEATURES                 \
-  (IREE_HAL_AMDGPU_TRACING_FEATURE_INSTRUMENTATION |     \
-   IREE_HAL_AMDGPU_TRACING_FEATURE_DEVICE_CONTROL |      \
-   IREE_HAL_AMDGPU_TRACING_FEATURE_DEVICE_EXECUTION |    \
-   IREE_HAL_AMDGPU_TRACING_FEATURE_ALLOCATION_TRACKING | \
-   IREE_HAL_AMDGPU_TRACING_FEATURE_LOG_MESSAGES |        \
-   IREE_HAL_AMDGPU_TRACING_FEATURE_DEBUG_MESSAGES)
-#else
-#define IREE_HAL_AMDGPU_TRACING_FEATURES 0
-#endif
-
-// Tests whether one or more tracing features have been enabled in the build.
-//
-// Example:
-//  #if IREE_HAL_AMDGPU_HAS_TRACING_FEATURE( \
-//     IREE_HAL_AMDGPU_TRACING_FEATURE_LOG_MESSAGES)
-//  <<code that should only run when LOG_MESSAGES is enabled>>
-//  #endif  // IREE_HAL_AMDGPU_TRACING_FEATURE_LOG_MESSAGES
-#define IREE_HAL_AMDGPU_HAS_TRACING_FEATURE(feature_bits) \
-  IREE_AMDGPU_ALL_BITS_SET(IREE_HAL_AMDGPU_TRACING_FEATURES, (feature_bits))
-
-//===----------------------------------------------------------------------===//
-// Tracing Buffer Definitions
-//===----------------------------------------------------------------------===//
-
-// A timestamp in the domain of the agent who owns the buffer the trace event
-// is recorded in. Each agent may have differing times that need to be converted
-// into the system domain on the host.
-typedef uint64_t iree_hal_amdgpu_trace_agent_timestamp_t;
-
-// A time range bounded by two timestamps.
-typedef struct iree_hal_amdgpu_trace_agent_time_range_t {
-  iree_hal_amdgpu_trace_agent_timestamp_t begin;
-  iree_hal_amdgpu_trace_agent_timestamp_t end;
-} iree_hal_amdgpu_trace_agent_time_range_t;
-
-// A process-unique ID assigned to an agent that executes execution zones.
-// In Tracy this is a GPU context.
-typedef uint8_t iree_hal_amdgpu_trace_executor_id_t;
-
-// An outstanding execution zone query ID.
-// As execution events are issued an ID is reserved and at some point in the
-// future after execution has completed the ID is used to match up the acquired
-// timing information for the event. Being 16-bit, we have a limited number of
-// outstanding IDs, but as we scope them per trace buffer we should be ok.
-typedef uint16_t iree_hal_amdgpu_trace_execution_query_id_t;
-
-// Indicates that a query ID is not used.
-#define IREE_HAL_AMDGPU_TRACE_EXECUTION_QUERY_ID_INVALID \
-  ((iree_hal_amdgpu_trace_execution_query_id_t)0xFFFF)
-
-// A 0xAABBGGRR color used when presenting messages and zones in a tracing UI.
-// 0x0 can (usually) be used to indicate "default". Alpha may be ignored but
-// should be 0xFF in most cases.
-typedef uint32_t iree_hal_amdgpu_trace_color_t;
-
-// A pointer that lives within the read-only data segment of the device library
-// code object.
-//
-// The target memory may not be accessible (or may be slow) as the code object
-// is loaded onto the GPU agent. The host tracing infrastructure creates a
-// shadow copy of the code in host memory and adjusts the address from the
-// device into that shadow such that it can access it locally.
-//
-// A special case is when the pointer lies outside of the loaded code object
-// range: when translating, the host passes any such pointers through
-// unmodified so that host pointers can be round-tripped. Calling a pointer an
-// iree_hal_amdgpu_trace_rodata_ptr_t thus just means the host must try to
-// translate it before dereferencing it, not that it is strictly in code
-// object memory.
-//
-// Example translation:
-//  uint64_t code_base = 0;
-//  err = loader.hsa_ven_amd_loader_loaded_code_object_get_info(
-//      loaded_code_object,
-//      HSA_VEN_AMD_LOADER_LOADED_CODE_OBJECT_INFO_LOAD_BASE, &code_base);
-//  const uint8_t* code_shadow = malloc(...); // + copy
-//  iree_hal_amdgpu_trace_rodata_ptr_t device_ptr = ...;
-//  const uint8_t* host_ptr = (const uint8_t*)(device_ptr - code_base +
-//                                             (uint64_t)code_shadow);
-typedef uint64_t iree_hal_amdgpu_trace_rodata_ptr_t;
-
-// A NUL-terminated string literal stored in the code object data segment.
-// This must be translated into the code object shadow copy prior to using on
-// the host.
-typedef iree_hal_amdgpu_trace_rodata_ptr_t
-    iree_hal_amdgpu_trace_string_literal_ptr_t;
-
-// Static information about a trace zone source location.
-// Tracy and other tools require the source location and its contained strings
-// to have process lifetime. Since the code object rodata segment they are
-// stored in will be unloaded as HSA is shut down we create a shadow copy that
-// we can persist in host memory until the process exits so that tracing tools
-// can access it.
-//
-// NOTE: this matches the Tracy expected source location structure exactly so
-// that we can pass it unmodified. Tracy uses the pointer of the source location
-// for several lookup tables.
-typedef struct iree_hal_amdgpu_trace_src_loc_t {
-  const char* name;
-  const char* function;
-  const char* file;
-  uint32_t line;
-  iree_hal_amdgpu_trace_color_t color;
-} iree_hal_amdgpu_trace_src_loc_t;
-
-// A iree_hal_amdgpu_trace_rodata_ptr_t that specifically references a static
-// iree_hal_amdgpu_trace_src_loc_t structure in the rodata segment. Note that
-// any pointers nested within the target src_loc are also in the rodata segment.
-typedef iree_hal_amdgpu_trace_rodata_ptr_t iree_hal_amdgpu_trace_src_loc_ptr_t;
-
-// Matches Tracy's PlotFormatType enum.
-typedef uint8_t iree_hal_amdgpu_trace_plot_type_t;
-enum iree_hal_amdgpu_trace_plot_type_e {
-  // Values will be displayed as plain numbers.
-  IREE_HAL_AMDGPU_TRACE_PLOT_TYPE_NUMBER = 0,
-  // Treats the values as memory sizes. Will display kilobytes, megabytes, etc.
-  IREE_HAL_AMDGPU_TRACE_PLOT_TYPE_MEMORY = 1,
-  // Values will be displayed as percentage with value 100 being equal to 100%.
-  IREE_HAL_AMDGPU_TRACE_PLOT_TYPE_PERCENTAGE = 2,
-};
-
-// Controls plot display and accumulation behavior.
-typedef uint8_t iree_hal_amdgpu_trace_plot_flags_t;
-enum iree_hal_amdgpu_trace_plot_flag_bits_e {
-  // Plot has discrete steps instead of being interpolated/smooth.
-  IREE_HAL_AMDGPU_TRACE_PLOT_FLAG_DISCRETE = 1u << 0,
-  // Plot has its display area filled with a solid color.
-  IREE_HAL_AMDGPU_TRACE_PLOT_FLAG_FILL = 1u << 1,
-};
-
-// Event type used to interpret the remainder of the event data.
-typedef uint8_t iree_hal_amdgpu_trace_event_type_t;
-enum {
-  IREE_HAL_AMDGPU_TRACE_EVENT_ZONE_BEGIN = 0u,
-  IREE_HAL_AMDGPU_TRACE_EVENT_ZONE_END,
-  IREE_HAL_AMDGPU_TRACE_EVENT_ZONE_VALUE_I64,
-  IREE_HAL_AMDGPU_TRACE_EVENT_ZONE_VALUE_TEXT_LITERAL,
-  IREE_HAL_AMDGPU_TRACE_EVENT_ZONE_VALUE_TEXT_DYNAMIC,
-  IREE_HAL_AMDGPU_TRACE_EVENT_EXECUTION_ZONE_BEGIN,
-  IREE_HAL_AMDGPU_TRACE_EVENT_EXECUTION_ZONE_END,
-  IREE_HAL_AMDGPU_TRACE_EVENT_EXECUTION_ZONE_DISPATCH,
-  IREE_HAL_AMDGPU_TRACE_EVENT_EXECUTION_ZONE_NOTIFY_BATCH,
-  IREE_HAL_AMDGPU_TRACE_EVENT_MEMORY_ALLOC,
-  IREE_HAL_AMDGPU_TRACE_EVENT_MEMORY_FREE,
-  IREE_HAL_AMDGPU_TRACE_EVENT_MESSAGE_LITERAL,
-  IREE_HAL_AMDGPU_TRACE_EVENT_MESSAGE_DYNAMIC,
-  IREE_HAL_AMDGPU_TRACE_EVENT_PLOT_CONFIG,
-  IREE_HAL_AMDGPU_TRACE_EVENT_PLOT_VALUE_I64,
-};
-
-// Begins a trace zone and pushes it onto the zone stack.
-typedef struct IREE_AMDGPU_ALIGNAS(8) iree_hal_amdgpu_trace_zone_begin_t {
-  // IREE_HAL_AMDGPU_TRACE_EVENT_ZONE_BEGIN
-  iree_hal_amdgpu_trace_event_type_t event_type;
-  uint8_t reserved[7];  // may be uninitialized
-  // Timestamp the zone begins at.
-  iree_hal_amdgpu_trace_agent_timestamp_t timestamp;
-  // Source location of the zone being entered.
-  iree_hal_amdgpu_trace_src_loc_ptr_t src_loc;
-} iree_hal_amdgpu_trace_zone_begin_t;
-
-// Ends the trace zone on the top of the zone stack.
-typedef struct IREE_AMDGPU_ALIGNAS(8) iree_hal_amdgpu_trace_zone_end_t {
-  // IREE_HAL_AMDGPU_TRACE_EVENT_ZONE_END
-  iree_hal_amdgpu_trace_event_type_t event_type;
-  uint8_t reserved[7];  // may be uninitialized
-  // Timestamp the zone ends at.
-  iree_hal_amdgpu_trace_agent_timestamp_t timestamp;
-} iree_hal_amdgpu_trace_zone_end_t;
-
-// Appends an i64 value to the zone on the top of the zone stack.
-typedef struct IREE_AMDGPU_ALIGNAS(8) iree_hal_amdgpu_trace_zone_value_i64_t {
-  // IREE_HAL_AMDGPU_TRACE_EVENT_ZONE_VALUE_I64
-  iree_hal_amdgpu_trace_event_type_t event_type;
-  uint8_t reserved[7];  // may be uninitialized
-  // Payload value attached to the zone.
-  uint64_t value;
-} iree_hal_amdgpu_trace_zone_value_i64_t;
-
-// Appends a string value to the zone on the top of the zone stack.
-typedef struct IREE_AMDGPU_ALIGNAS(8)
-    iree_hal_amdgpu_trace_zone_value_text_literal_t {
-  // IREE_HAL_AMDGPU_TRACE_EVENT_ZONE_VALUE_TEXT_LITERAL
-  iree_hal_amdgpu_trace_event_type_t event_type;
-  uint8_t reserved[3];
-  // Payload value attached to the zone.
-  // NUL terminated. Must be stored in the code object data segment.
-  iree_hal_amdgpu_trace_string_literal_ptr_t value;
-} iree_hal_amdgpu_trace_zone_value_text_literal_t;
-
-// Appends a string value to the zone on the top of the zone stack.
-// The contents are embedded in the trace buffer to support dynamically
-// generated values.
-typedef struct IREE_AMDGPU_ALIGNAS(8)
-    iree_hal_amdgpu_trace_zone_value_text_dynamic_t {
-  // IREE_HAL_AMDGPU_TRACE_EVENT_ZONE_VALUE_TEXT_DYNAMIC
-  iree_hal_amdgpu_trace_event_type_t event_type;
-  uint8_t reserved[3];
-  // Length of the value in characters.
-  uint32_t length;
-  // Payload value attached to the zone. Not NUL terminated.
-  char value[/*length*/];
-} iree_hal_amdgpu_trace_zone_value_text_dynamic_t;
-
-// Begins a device execution zone.
-// This captures the timestamp the zone is issued as well as a query_id used to
-// correlate a future update of the timing when available.
-typedef struct IREE_AMDGPU_ALIGNAS(8)
-    iree_hal_amdgpu_trace_execution_zone_begin_t {
-  // IREE_HAL_AMDGPU_TRACE_EVENT_EXECUTION_ZONE_BEGIN
-  iree_hal_amdgpu_trace_event_type_t event_type;
-  // Execution trace ID used to distinguish different execution units. This is
-  // assigned on the host when the execution context is configured.
-  iree_hal_amdgpu_trace_executor_id_t executor_id;
-  uint8_t reserved0[2];  // may be uninitialized
-  // A query ID used to feed back the timestamp once the execution has
-  // completed.
-  iree_hal_amdgpu_trace_execution_query_id_t execution_query_id;
-  uint8_t reserved1[2];  // may be uninitialized
-  // Timestamp the zone begin was issued at.
-  // Note that this need not be in order with any other timestamps.
-  iree_hal_amdgpu_trace_agent_timestamp_t issue_timestamp;
-  // Source location of the zone being entered.
-  iree_hal_amdgpu_trace_src_loc_ptr_t src_loc;
-} iree_hal_amdgpu_trace_execution_zone_begin_t;
-
-// Ends a device execution zone.
-typedef struct IREE_AMDGPU_ALIGNAS(8)
-    iree_hal_amdgpu_trace_execution_zone_end_t {
-  // IREE_HAL_AMDGPU_TRACE_EVENT_EXECUTION_ZONE_END
-  iree_hal_amdgpu_trace_event_type_t event_type;
-  // Execution trace ID used to distinguish different execution units. This is
-  // assigned on the host when the execution context is configured.
-  iree_hal_amdgpu_trace_executor_id_t executor_id;
-  uint8_t reserved0[2];  // may be uninitialized
-  // A query ID used to feed back the timestamp once the execution has
-  // completed.
-  iree_hal_amdgpu_trace_execution_query_id_t execution_query_id;
-  uint8_t reserved1[2];  // may be uninitialized
-  // Timestamp the zone end was issued at.
-  // Note that this need not be in order with any other timestamps.
-  iree_hal_amdgpu_trace_agent_timestamp_t issue_timestamp;
-} iree_hal_amdgpu_trace_execution_zone_end_t;
-
-// Defines the type of an execution zone dispatch.
-typedef uint8_t iree_hal_amdgpu_trace_execution_zone_type_t;
-enum iree_hal_amdgpu_trace_execution_zone_type_e {
-  // Indicates an executable export dispatch (kernel launch).
-  // The export_loc will be populated with the value defined by the host.
-  IREE_HAL_AMDGPU_TRACE_EXECUTION_ZONE_TYPE_DISPATCH = 0u,
-  // Indicates an indirect executable export dispatch.
-  // The export_loc will be populated with the value defined by the host.
-  // The total time will span both the indirect preparation dispatch (if
-  // required) and the dispatch itself.
-  IREE_HAL_AMDGPU_TRACE_EXECUTION_ZONE_TYPE_DISPATCH_INDIRECT = 1u,
-  // Indicates a DMA copy operation. export_loc may be 0.
-  IREE_HAL_AMDGPU_TRACE_EXECUTION_ZONE_TYPE_COPY,
-  // Indicates a DMA fill operation. export_loc may be 0.
-  IREE_HAL_AMDGPU_TRACE_EXECUTION_ZONE_TYPE_FILL,
-  // Indicates an internal bookkeeping dispatch. export_loc may be 0.
-  IREE_HAL_AMDGPU_TRACE_EXECUTION_ZONE_TYPE_INTERNAL,
-};
-
-// Represents a leaf device execution zone.
-// This is the same as emitting an execution zone begin and end pair but has
-// less overhead for the common cases of leaf zones (dispatches, etc).
-typedef struct IREE_AMDGPU_ALIGNAS(8)
-    iree_hal_amdgpu_trace_execution_zone_dispatch_t {
-  // IREE_HAL_AMDGPU_TRACE_EVENT_EXECUTION_ZONE_DISPATCH
-  iree_hal_amdgpu_trace_event_type_t event_type;
-  // Defines what kind of dispatch operation this is and how the export_loc is
-  // interpreted (if it is at all).
-  iree_hal_amdgpu_trace_execution_zone_type_t zone_type;
-  // Execution trace ID used to distinguish different execution units. This is
-  // assigned on the host when the execution context is configured.
-  iree_hal_amdgpu_trace_executor_id_t executor_id;
-  uint8_t reserved0[1];  // may be uninitialized
-  // A query ID used to feed back the timestamp once the execution has
-  // completed. In indirect dispatches with multiple device dispatches this is
-  // used only for the primary dispatch.
-  // Note that we have space for another ID but allocating those is annoying.
-  iree_hal_amdgpu_trace_execution_query_id_t execution_query_id;
-  uint8_t reserved1[2];
-  // A reference to the interned export source location in host memory.
-  // The host queries this information and preserves it with process lifetime
-  // so that we can quickly look it up when feeding it to the trace sink.
-  // In other approaches we'd have to allocate the source location each time we
-  // recorded an event to it or deal with tracking such information on the
-  // device.
-  uint64_t export_loc;
-  // Timestamp the dispatch was issued at.
-  // Note that this need not be in order with any other timestamps.
-  // To save on space we only record the instantaneous timestamp of the issue
-  // and apply a fixed duration. We do the issues in parallel and the timings
-  // would be messy anyway. A base timestamp is enough to calculate latency from
-  // issue to execution.
-  //
-  // NOTE: this relies on the current tracy behavior of using the information
-  // only to hint at where in the global timeline the issue occurred. If it was
-  // actually trying to assign issues to zones then we'd likely have a problem
-  // and need to either serialize all issues when tracing or do some funny math.
-  iree_hal_amdgpu_trace_agent_timestamp_t issue_timestamp;
-} iree_hal_amdgpu_trace_execution_zone_dispatch_t;
-
-// Notifies the trace sink of a batch of completed queries.
-// All queries must have contiguous IDs starting at the specified base ID.
-typedef struct IREE_AMDGPU_ALIGNAS(8)
-    iree_hal_amdgpu_trace_execution_zone_notify_batch_t {
-  // IREE_HAL_AMDGPU_TRACE_EVENT_EXECUTION_ZONE_NOTIFY_BATCH
-  iree_hal_amdgpu_trace_event_type_t event_type;
-  // Execution trace ID used to distinguish different execution units. This is
-  // assigned on the host when the execution context is configured.
-  iree_hal_amdgpu_trace_executor_id_t executor_id;
-  uint8_t reserved0[2];  // may be uninitialized
-  // The base query ID for all queries in the batch.
-  iree_hal_amdgpu_trace_execution_query_id_t execution_query_id_base;
-  // The total number of queries in the batch.
-  uint16_t execution_query_count;
-  // Timestamp ranges of the queries as executed.
-  iree_hal_amdgpu_trace_agent_time_range_t
-      execution_time_ranges[/*execution_query_count*/];
-} iree_hal_amdgpu_trace_execution_zone_notify_batch_t;
-
-// Records the allocation of a block of memory from a named pool.
-typedef struct IREE_AMDGPU_ALIGNAS(8) iree_hal_amdgpu_trace_memory_alloc_t {
-  // IREE_HAL_AMDGPU_TRACE_EVENT_MEMORY_ALLOC
-  iree_hal_amdgpu_trace_event_type_t event_type;
-  // TODO(benvanik): try to see if we can get the memory pool name in 7 bytes -
-  // if so we can shrink the packet to 24 bytes.
-  uint8_t reserved[7];
-  // Pool name used as both a title for the pool and the unique ID for
-  // correlating alloc/free events.
-  iree_hal_amdgpu_trace_string_literal_ptr_t pool;
-  // Timestamp the allocation was made.
-  iree_hal_amdgpu_trace_agent_timestamp_t timestamp;
-  // Pointer in whatever memory space the pool defines.
-  uint64_t ptr;
-  // Size of the allocation in bytes.
-  uint64_t size;
-} iree_hal_amdgpu_trace_memory_alloc_t;
-
-// Records the freeing of a block of memory from a named pool.
-typedef struct IREE_AMDGPU_ALIGNAS(8) iree_hal_amdgpu_trace_memory_free_t {
-  // IREE_HAL_AMDGPU_TRACE_EVENT_MEMORY_FREE
-  iree_hal_amdgpu_trace_event_type_t event_type;
-  uint8_t reserved[7];
-  // Pool name used as both a title for the pool and the unique ID for
-  // correlating alloc/free events.
-  iree_hal_amdgpu_trace_string_literal_ptr_t pool;
-  // Timestamp the allocation was freed.
-  iree_hal_amdgpu_trace_agent_timestamp_t timestamp;
-  // Pointer in whatever memory space the pool defines. Must have previously
-  // been used in a memory allocation event.
-  uint64_t ptr;
-} iree_hal_amdgpu_trace_memory_free_t;
-
-// Logs a string message.
-typedef struct IREE_AMDGPU_ALIGNAS(8) iree_hal_amdgpu_trace_message_literal_t {
-  // IREE_HAL_AMDGPU_TRACE_EVENT_MESSAGE_LITERAL
-  iree_hal_amdgpu_trace_event_type_t event_type;
-  // TODO(benvanik): try to see if we can get the literal in 7 bytes - if so
-  // we can shrink the packet to 16 bytes.
-  uint8_t reserved[7];
-  // Timestamp the message was emitted.
-  iree_hal_amdgpu_trace_agent_timestamp_t timestamp;
-  // Message payload. NUL terminated and stored in the code object data
-  // segment.
-  iree_hal_amdgpu_trace_string_literal_ptr_t value;
-} iree_hal_amdgpu_trace_message_literal_t;
-
-// Logs a string message.
-// The contents are embedded in the trace buffer to support dynamically
-// generated values.
-typedef struct IREE_AMDGPU_ALIGNAS(8) iree_hal_amdgpu_trace_message_dynamic_t {
-  // IREE_HAL_AMDGPU_TRACE_EVENT_MESSAGE_DYNAMIC
-  iree_hal_amdgpu_trace_event_type_t event_type;
-  uint8_t reserved[3];
-  // Length of the value in characters.
-  uint32_t length;
-  // Timestamp the message was emitted.
-  iree_hal_amdgpu_trace_agent_timestamp_t timestamp;
-  // Message payload. Not NUL terminated.
-  char value[/*length*/];
-} iree_hal_amdgpu_trace_message_dynamic_t;
-
-// Defines a plot.
-// This must be called prior to
-typedef struct IREE_AMDGPU_ALIGNAS(8) iree_hal_amdgpu_trace_plot_config_t {
-  // IREE_HAL_AMDGPU_TRACE_EVENT_PLOT_CONFIG
-  iree_hal_amdgpu_trace_event_type_t event_type;
-  // Defines the plot type.
-  iree_hal_amdgpu_trace_plot_type_t plot_type;
-  // Controls plot display and accumulation behavior.
-  iree_hal_amdgpu_trace_plot_flags_t plot_flags;
-  // Base color of the plot (line and fill will be derived from this).
-  iree_hal_amdgpu_trace_color_t color;
-  // Plot name displayed as a title.
-  // The pointer value is used as a key for future plot data.
-  iree_hal_amdgpu_trace_string_literal_ptr_t name;
-} iree_hal_amdgpu_trace_plot_config_t;
-
-// Records an i64 plot value change.
-typedef struct IREE_AMDGPU_ALIGNAS(8) iree_hal_amdgpu_trace_plot_value_i64_t {
-  // IREE_HAL_AMDGPU_TRACE_EVENT_PLOT_VALUE_I64
-  iree_hal_amdgpu_trace_event_type_t event_type;
-  // TODO(benvanik): try to see if we can get the plot name in 7 bytes - if so
-  // we can shrink the packet to 24 bytes.
-  uint8_t reserved[7];
-  // Uniqued name of the plot as used during configuration.
-  iree_hal_amdgpu_trace_string_literal_ptr_t plot_name;
-  // Time the plot value was emitted.
-  iree_hal_amdgpu_trace_agent_timestamp_t timestamp;
-  // Plot value as interpreted by the plot type.
-  int64_t value;
-} iree_hal_amdgpu_trace_plot_value_i64_t;
-
-//===----------------------------------------------------------------------===//
-// iree_hal_amdgpu_device_query_ringbuffer_t
-//===----------------------------------------------------------------------===//
-
-// Total number of query signals allocated to a trace buffer ringbuffer.
-// This preallocates the signals as part of the trace buffer structure.
-//
-// We generally trade off some fixed device memory consumption by allocating a
-// large pool instead of trying to handle cases of exhaustion. This could be
-// lowered, but when you have a few hundred GB of GPU memory 4MB is a drop in
-// a very large bucket.
-//
-// Due to tracy behavior we have to reserve query indices for begin/end of
-// dispatches even though we only need one signal. We use the upper bit to
-// differentiate when reporting the signals to tracy.
-//
-// Must be a power-of-two.
-#define IREE_HAL_AMDGPU_DEVICE_QUERY_RINGBUFFER_CAPACITY (0x10000u >> 1)
-
-// A ringbuffer of device-only query signals that can be acquired/released in
-// large blocks. Query signals are not full HSA signals and cannot be used on
-// the host - there's no backing mailbox/doorbell for raising interrupts and
-// attempting to cast them to hsa_signal_t (on the host) will fail.
-//
-// This exploits the fact that signals are just iree_amd_signal_t structs in
-// memory from the perspective of the device - only the host cares if they are
-// wrapped in ROCR/HSA types.
-//
-// Signals are maintained at rest with a value of 1. Those acquiring can change
-// this value after acquiring them if needed.
-//
-// Thread-compatible; only the thread owning the trace buffer will acquire and
-// release from the ringbuffer so there's no need to make it safer.
-typedef struct iree_hal_amdgpu_device_query_ringbuffer_t {
-  uint64_t read_index;
-  uint64_t write_index;
-  iree_amd_signal_t signals[IREE_HAL_AMDGPU_DEVICE_QUERY_RINGBUFFER_CAPACITY];
-} iree_hal_amdgpu_device_query_ringbuffer_t;
-
-#if defined(IREE_AMDGPU_TARGET_DEVICE)
-
-// Acquires a signal from the ringbuffer and returns the execution query_id for
-// it. Callers must use iree_hal_amdgpu_device_query_ringbuffer_signal_for_id to
-// get the signal handle that can be provided in packets.
-iree_hal_amdgpu_trace_execution_query_id_t
-iree_hal_amdgpu_device_query_ringbuffer_acquire(
-    iree_hal_amdgpu_device_query_ringbuffer_t* IREE_AMDGPU_RESTRICT ringbuffer);
-
-// Acquires |count| signals from the ringbuffer and returns the base index in
-// the absolute ringbuffer domain. Callers must use
-// iree_hal_amdgpu_device_query_ringbuffer_signal to get the signal handle that
-// can be provided in packets.
-uint64_t iree_hal_amdgpu_device_query_ringbuffer_acquire_range(
-    iree_hal_amdgpu_device_query_ringbuffer_t* IREE_AMDGPU_RESTRICT ringbuffer,
-    uint16_t count);
-
-// Releases the oldest acquired batch of |count| signals back to the ringbuffer.
-// The signals may immediately be overwritten/reused and must have no
-// outstanding references by either the caller or the hardware queues.
-void iree_hal_amdgpu_device_query_ringbuffer_release_range(
-    iree_hal_amdgpu_device_query_ringbuffer_t* IREE_AMDGPU_RESTRICT ringbuffer,
-    uint16_t count);
-
-// Returns the tracing ID used for the signal at the given absolute ringbuffer
-// index.
-static inline IREE_AMDGPU_ATTRIBUTE_ALWAYS_INLINE
-    iree_hal_amdgpu_trace_execution_query_id_t
-    iree_hal_amdgpu_device_query_ringbuffer_query_id(
-        const iree_hal_amdgpu_device_query_ringbuffer_t* IREE_AMDGPU_RESTRICT
-            ringbuffer,
-        uint64_t index) {
-  return (
-      iree_hal_amdgpu_trace_execution_query_id_t)(index &
-                                                  (IREE_AMDGPU_ARRAYSIZE(
-                                                       ringbuffer->signals) -
-                                                   1));
-}
-
-// Returns a device-only HSA signal handle for the query signal at the given
-// absolute ringbuffer index.
-static inline IREE_AMDGPU_ATTRIBUTE_ALWAYS_INLINE iree_hsa_signal_t
-iree_hal_amdgpu_device_query_ringbuffer_signal_for_id(
-    const iree_hal_amdgpu_device_query_ringbuffer_t* IREE_AMDGPU_RESTRICT
-        ringbuffer,
-    iree_hal_amdgpu_trace_execution_query_id_t query_id) {
-  return (iree_hsa_signal_t){(uint64_t)&ringbuffer->signals[query_id]};
-}
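-
-// An illustrative acquire/use/release sequence (AQL packet wiring elided):
-//   const uint64_t base_index =
-//       iree_hal_amdgpu_device_query_ringbuffer_acquire_range(ringbuffer, 2);
-//   const iree_hal_amdgpu_trace_execution_query_id_t query_id =
-//       iree_hal_amdgpu_device_query_ringbuffer_query_id(ringbuffer,
-//                                                        base_index);
-//   iree_hsa_signal_t completion_signal =
-//       iree_hal_amdgpu_device_query_ringbuffer_signal_for_id(ringbuffer,
-//                                                             query_id);
-//   // ... attach |completion_signal| to a packet to capture start/end ...
-//   iree_hal_amdgpu_device_query_ringbuffer_release_range(ringbuffer, 2);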
-
-#endif  // IREE_AMDGPU_TARGET_DEVICE
-
-//===----------------------------------------------------------------------===//
-// iree_hal_amdgpu_device_trace_buffer_t
-//===----------------------------------------------------------------------===//
-
-#if IREE_HAL_AMDGPU_TRACING_FEATURES
-
-// Single-producer/single-consumer ringbuffer with mapping tricks.
-// Trace events are emitted by the scheduler in batches by having the scheduler
-// mark the start of a reservation, populating that with as many events as it
-// wants, and then committing at the end of its written range. The host side is
-// responsible for processing the range defined from the last read commit offset
-// to the last write commit offset it receives when running.
-//
-// Writes must check for overflow by ensuring that there is sufficient capacity
-// for their reservation. An example write sequence:
-// while (write_reserve_offset + requested_size - read_commit_offset >=
-//        capacity) {
-//   iree_amdgpu_yield();
-// }
-// memcpy(ringbuffer_base + write_reserve_offset, contents, requested_size);
-// write_reserve_offset += requested_size;
-//
-// This presents as a ringbuffer that does not need any special wrapping logic
-// for the base offsets used when copying memory. It follows the approach of
-// virtual memory mapping the buffer multiple times, documented in:
-// https://github.com/google/wuffs/blob/main/script/mmap-ring-buffer.c
-// We use SVM to allocate the physical memory of the ringbuffer and then stitch
-// together 3 virtual memory ranges in one contiguous virtual allocation that
-// alias the physical allocation. By treating the middle range as the base
-// buffer pointer we are then able to freely dereference both before and after
-// the base pointer by up to the ringbuffer size in length.
-//   physical: <ringbuffer size> --+------+------+
-//                                 v      v      v
-//                        virtual: [prev] [base] [next]
-//                                        ^
-//                                        +-- base_ptr
-//
-// Because of the mapping trick we have a maximum outstanding ringbuffer size
-// equal to the ringbuffer capacity (modulo alignment requirements). We flush
-// after each major phase of work (when the scheduler goes idle, a command
-// buffer block completes execution, etc) and need a minimum capacity enough to
-// store all of the data produced during those phases. Command buffers are the
-// riskiest and with a Tracy-imposed uint16_t signal query ringbuffer we have to
-// chunk those anyway and can ensure we have enough space for the
-// iree_hal_amdgpu_trace_execution_zone_dispatch_t and corresponding
-// iree_hal_amdgpu_trace_execution_zone_notify_batch_t packets (+ margin for the
-// scheduler).
-//
-// Thread-compatible; single-producer/single-consumer. The scheduler that owns
-// the trace buffer is the "thread" and is the only one allowed to write to it.
-// The paired host command processor is the only one allowed to read from it
-// when marshaling to the native host tracing APIs. This allows us to use
-// relatively simple data structures and commit at fixed intervals: we reserve
-// from the write index while producing data and only commit to make it host
-// visible at reasonable flush points.
-//
-// Each trace buffer represents a single thread of execution (the scheduler)
-// plus one or more additional device executors (the hardware queues executing
-// commands) so we don't need to track thread IDs or other TLS information on
-// events. The host processing the events will assign the appropriate tracing
-// IDs when transcribing the events.
-//
-// Pointers embedded within the trace events are treated as either opaque (such
-// as the pointer used during an allocation event) or as special read-only data
-// segment pointers. See iree_hal_amdgpu_trace_rodata_ptr_t for more
-// information. The host must translate these pointers before dereferencing them
-// or providing them to the trace sink that will (for example to get source
-// location name strings).
-typedef struct iree_hal_amdgpu_device_trace_buffer_t {
-  // Base ringbuffer pointer used for all relative addressing.
-  // Pointers must always be within the range of
-  // (ringbuffer_base-ringbuffer_capacity, ringbuffer_base+ringbuffer_capacity).
-  uint8_t* IREE_AMDGPU_RESTRICT ringbuffer_base;
-  // Total size in bytes of the trace data ringbuffer.
-  // Note that this is the size of the underlying physical allocation and the
-  // virtual range is 3x that. Must be a power of two.
-  uint64_t ringbuffer_capacity;
-  // Process-unique executor ID used for tracing command execution.
-  // NOTE: this assumes only one executor per trace buffer; we may want to
-  // support multiple and make users pass them in for cases where a single
-  // command buffer may execute on multiple hardware execution queues.
-  iree_hal_amdgpu_trace_executor_id_t executor_id;
-  uint8_t reserved0[7];  // may be uninitialized
-  // iree_hal_amdgpu_trace_buffer_t* used to route flush requests back to the
-  // owning host resource.
-  uint64_t host_trace_buffer;
-  // Used by the host to indicate where it has completed reading to. The host
-  // should atomically bump read_commit_offset when it has completed reading a
-  // chunk from the ringbuffer. If the capacity is reached then the device may
-  // spin until the host has caught up.
-  // Absolute offset - must be masked with ringbuffer_mask.
-  IREE_AMDGPU_ALIGNAS(iree_amdgpu_destructive_interference_size)
-  iree_amdgpu_scoped_atomic_uint64_t read_commit_offset;
-  // Exclusively used by the scheduler to mark the start of its current
-  // reservation. This is assigned after each flush and only advanced with each
-  // trace event recorded.
-  // Absolute offset - must be masked with ringbuffer_mask.
-  IREE_AMDGPU_ALIGNAS(iree_amdgpu_destructive_interference_size)
-  iree_amdgpu_scoped_atomic_uint64_t write_reserve_offset;
-  // Used by both the host and scheduler to track the current committed write
-  // range. Always <= the write_reserve_offset (with == indicating that there
-  // are no pending events).
-  // Absolute offset - must be masked with ringbuffer_mask.
-  IREE_AMDGPU_ALIGNAS(iree_amdgpu_destructive_interference_size)
-  iree_amdgpu_scoped_atomic_uint64_t write_commit_offset;
-  // Ringbuffer of device-only signals that can be used to get timestamps from
-  // the packet processor. Only the scheduler that owns the trace buffer is
-  // allowed to acquire/release from the ringbuffer and it is sized to fit only
-  // a single command buffer block worth of operations.
-  iree_hal_amdgpu_device_query_ringbuffer_t query_ringbuffer;
-} iree_hal_amdgpu_device_trace_buffer_t;
-
-// Mask used to wrap an absolute ringbuffer offset into a base pointer offset.
-#define iree_hal_amdgpu_device_trace_buffer_mask(trace_buffer) \
-  ((trace_buffer)->ringbuffer_capacity - 1)
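-
-// An illustrative host-side consumption sketch (host atomics elided;
-// iree_hal_amdgpu_trace_event_size is a hypothetical helper returning the
-// size in bytes of the event at the given address):
-//   uint64_t read_offset = trace_buffer->read_commit_offset;
-//   const uint64_t write_offset = trace_buffer->write_commit_offset;
-//   while (read_offset < write_offset) {
-//     const uint8_t* event_ptr =
-//         trace_buffer->ringbuffer_base +
-//         (read_offset &
-//          iree_hal_amdgpu_device_trace_buffer_mask(trace_buffer));
-//     read_offset += iree_hal_amdgpu_trace_event_size(event_ptr);
-//   }
-//   trace_buffer->read_commit_offset = read_offset;  // unblocks device writers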
-
-#else
-
-typedef struct iree_hal_amdgpu_device_trace_buffer_t {
-  int reserved;
-} iree_hal_amdgpu_device_trace_buffer_t;
-
-#define IREE_HAL_AMDGPU_TRACE_BUFFER_KERNARG_SIZE 0
-
-#endif  // IREE_HAL_AMDGPU_TRACING_FEATURES
-
-// Control kernargs used when launching the trace buffer kernels.
-typedef struct iree_hal_amdgpu_device_trace_buffer_kernargs_t {
-  IREE_AMDGPU_DEVICE_PTR iree_hal_amdgpu_device_trace_buffer_t* trace_buffer;
-} iree_hal_amdgpu_device_trace_buffer_kernargs_t;
-
-//===----------------------------------------------------------------------===//
-// Tracing Macros/Support
-//===----------------------------------------------------------------------===//
-
-typedef uint32_t iree_hal_amdgpu_zone_id_t;
-
-#if defined(IREE_AMDGPU_TARGET_DEVICE)
-
-// Colors used for messages based on the level provided to the macro.
-enum {
-  IREE_AMDGPU_TRACE_MESSAGE_LEVEL_ERROR = 0xFF0000u,
-  IREE_AMDGPU_TRACE_MESSAGE_LEVEL_WARNING = 0xFFFF00u,
-  IREE_AMDGPU_TRACE_MESSAGE_LEVEL_INFO = 0xFFFFFFu,
-  IREE_AMDGPU_TRACE_MESSAGE_LEVEL_VERBOSE = 0xC0C0C0u,
-  IREE_AMDGPU_TRACE_MESSAGE_LEVEL_DEBUG = 0x00FF00u,
-};
-
-#if IREE_HAL_AMDGPU_TRACING_FEATURES
-
-#define IREE_AMDGPU_TRACE_BUFFER_SCOPE(trace_buffer)                  \
-  iree_hal_amdgpu_device_trace_buffer_t* __iree_amdgpu_trace_buffer = \
-      (trace_buffer)
-#define IREE_AMDGPU_TRACE_BUFFER() (__iree_amdgpu_trace_buffer)
-
-#else
-
-#define IREE_AMDGPU_TRACE_BUFFER_SCOPE(...)
-#define IREE_AMDGPU_TRACE_BUFFER() NULL
-
-#endif  // IREE_HAL_AMDGPU_TRACING_FEATURES
-
-#if IREE_HAL_AMDGPU_HAS_TRACING_FEATURE( \
-    IREE_HAL_AMDGPU_TRACING_FEATURE_INSTRUMENTATION)
-
-#define IREE_AMDGPU_TRACE_CONCAT_(x, y) IREE_AMDGPU_TRACE_CONCAT_INDIRECT_(x, y)
-#define IREE_AMDGPU_TRACE_CONCAT_INDIRECT_(x, y) x##y
-
-// Begins a new zone with the parent function name.
-#define IREE_AMDGPU_TRACE_ZONE_BEGIN(zone_id) \
-  IREE_AMDGPU_TRACE_ZONE_BEGIN_NAMED(zone_id, NULL)
-
-// Begins a new zone with the given compile-time |name_literal|.
-// The literal must be static const and will be embedded in the trace buffer by
-// reference.
-#define IREE_AMDGPU_TRACE_ZONE_BEGIN_NAMED(zone_id, name_literal) \
-  IREE_AMDGPU_TRACE_ZONE_BEGIN_NAMED_COLORED(zone_id, name_literal, 0)
-
-// Begins a new zone with the given compile-time |name_literal| and color.
-// The literal must be static const and will be embedded in the trace buffer by
-// reference.
-#define IREE_AMDGPU_TRACE_ZONE_BEGIN_NAMED_COLORED(zone_id, name_literal,      \
-                                                   color)                      \
-  static const iree_hal_amdgpu_trace_src_loc_t IREE_AMDGPU_TRACE_CONCAT_(      \
-      __iree_amdgpu_trace_src_loc, __LINE__) = {                               \
-      (name_literal), __FUNCTION__, __FILE__, (uint32_t)__LINE__, (color),     \
-  };                                                                           \
-  iree_hal_amdgpu_zone_id_t zone_id = iree_hal_amdgpu_device_trace_zone_begin( \
-      IREE_AMDGPU_TRACE_BUFFER(),                                              \
-      (iree_hal_amdgpu_trace_src_loc_ptr_t) &                                  \
-          IREE_AMDGPU_TRACE_CONCAT_(__iree_amdgpu_trace_src_loc, __LINE__));
-
-// Ends the current zone. Must be passed the |zone_id| from the _BEGIN.
-#define IREE_AMDGPU_TRACE_ZONE_END(zone_id) \
-  iree_hal_amdgpu_device_trace_zone_end(IREE_AMDGPU_TRACE_BUFFER())
-
-// Appends an int64_t value to the parent zone. May be called multiple times.
-#define IREE_AMDGPU_TRACE_ZONE_APPEND_VALUE_I64(zone_id, value) \
-  (void)(zone_id);                                              \
-  iree_hal_amdgpu_device_trace_zone_append_value_i64(           \
-      IREE_AMDGPU_TRACE_BUFFER(), (int64_t)(value))
-
-// Appends a string literal value to the parent zone. May be called multiple
-// times. The provided NUL-terminated C string will be referenced directly.
-#define IREE_AMDGPU_TRACE_ZONE_APPEND_TEXT_LITERAL(zone_id, value_literal) \
-  (void)(zone_id);                                                         \
-  iree_hal_amdgpu_device_trace_zone_append_text_literal(                   \
-      IREE_AMDGPU_TRACE_BUFFER(),                                          \
-      (iree_hal_amdgpu_trace_string_literal_ptr_t)(value_literal))
-
-// Appends a string value to the parent zone. May be called multiple times.
-// The provided NUL-terminated C string or string view will be copied into the
-// trace buffer.
-#define IREE_AMDGPU_TRACE_ZONE_APPEND_TEXT_DYNAMIC(...)                     \
-  IREE_AMDGPU_TRACE_IMPL_GET_VARIADIC_(                                     \
-      (__VA_ARGS__, IREE_AMDGPU_TRACE_ZONE_APPEND_TEXT_STRING_VIEW_DYNAMIC, \
-       IREE_AMDGPU_TRACE_ZONE_APPEND_TEXT_CSTRING_DYNAMIC))                 \
-  (__VA_ARGS__)
-#define IREE_AMDGPU_TRACE_ZONE_APPEND_TEXT_CSTRING_DYNAMIC(zone_id, value)   \
-  IREE_AMDGPU_TRACE_ZONE_APPEND_TEXT_STRING_VIEW_DYNAMIC((zone_id), (value), \
-                                                         sizeof(value))
-#define IREE_AMDGPU_TRACE_ZONE_APPEND_TEXT_STRING_VIEW_DYNAMIC(zone_id, value, \
-                                                               value_length)   \
-  (void)(zone_id);                                                             \
-  iree_hal_amdgpu_device_trace_zone_append_text_dynamic(                       \
-      IREE_AMDGPU_TRACE_BUFFER(), (value), (value_length))
-
-// Configures the named plot with iree_hal_amdgpu_trace_plot_type_t data and
-// iree_hal_amdgpu_trace_plot_flags_t controlling the display.
-#define IREE_AMDGPU_TRACE_PLOT_CONFIGURE(name_literal, type, flags, color) \
-  iree_hal_amdgpu_device_trace_plot_configure(                             \
-      IREE_AMDGPU_TRACE_BUFFER(),                                          \
-      (iree_hal_amdgpu_trace_string_literal_ptr_t)(name_literal), (type),  \
-      (flags), (color))
-// Plots a value in the named plot group as an int64_t.
-#define IREE_AMDGPU_TRACE_PLOT_VALUE_I64(name_literal, value) \
-  iree_hal_amdgpu_device_trace_plot_value_i64(                \
-      IREE_AMDGPU_TRACE_BUFFER(),                             \
-      (iree_hal_amdgpu_trace_string_literal_ptr_t)(name_literal), (value))
-
-// Utilities:
-#define IREE_AMDGPU_TRACE_IMPL_GET_VARIADIC_HELPER_(_1, _2, _3, NAME, ...) NAME
-#define IREE_AMDGPU_TRACE_IMPL_GET_VARIADIC_(args) \
-  IREE_AMDGPU_TRACE_IMPL_GET_VARIADIC_HELPER_ args
-
-#else
-
-#define IREE_AMDGPU_TRACE_ZONE_BEGIN(zone_id) \
-  iree_hal_amdgpu_zone_id_t zone_id = 0;      \
-  (void)zone_id;
-#define IREE_AMDGPU_TRACE_ZONE_BEGIN_NAMED(zone_id, name_literal) \
-  IREE_AMDGPU_TRACE_ZONE_BEGIN(zone_id)
-#define IREE_AMDGPU_TRACE_ZONE_BEGIN_NAMED_COLORED(zone_id, name_literal, \
-                                                   color)                 \
-  IREE_AMDGPU_TRACE_ZONE_BEGIN(zone_id)
-#define IREE_AMDGPU_TRACE_ZONE_END(zone_id) (void)(zone_id)
-
-#define IREE_AMDGPU_TRACE_ZONE_APPEND_VALUE_I64(zone_id, value)
-#define IREE_AMDGPU_TRACE_ZONE_APPEND_TEXT_LITERAL(zone_id, value_literal)
-#define IREE_AMDGPU_TRACE_ZONE_APPEND_TEXT_DYNAMIC(zone_id, ...)
-
-#define IREE_AMDGPU_TRACE_PLOT_CONFIGURE(name_literal, type, flags, color)
-#define IREE_AMDGPU_TRACE_PLOT_VALUE_I64(name_literal, value)
-
-#endif  // IREE_HAL_AMDGPU_TRACING_FEATURE_INSTRUMENTATION
-
-#if IREE_HAL_AMDGPU_HAS_TRACING_FEATURE( \
-    IREE_HAL_AMDGPU_TRACING_FEATURE_ALLOCATION_TRACKING)
-
-// Traces a new memory allocation in a named memory pool.
-// Reallocations must be recorded as an
-// IREE_AMDGPU_TRACE_ALLOC_NAMED/IREE_AMDGPU_TRACE_FREE_NAMED pair.
-#define IREE_AMDGPU_TRACE_ALLOC_NAMED(name_literal, ptr, size)    \
-  iree_hal_amdgpu_device_trace_memory_alloc(                      \
-      IREE_AMDGPU_TRACE_BUFFER(),                                 \
-      (iree_hal_amdgpu_trace_string_literal_ptr_t)(name_literal), \
-      (uint64_t)(ptr), (size))
-
-// Traces a free of an existing allocation traced with
-// IREE_AMDGPU_TRACE_ALLOC_NAMED.
-#define IREE_AMDGPU_TRACE_FREE_NAMED(name_literal, ptr)           \
-  iree_hal_amdgpu_device_trace_memory_free(                       \
-      IREE_AMDGPU_TRACE_BUFFER(),                                 \
-      (iree_hal_amdgpu_trace_string_literal_ptr_t)(name_literal), \
-      (uint64_t)(ptr))
-
-#else
-
-#define IREE_AMDGPU_TRACE_ALLOC_NAMED(name_literal, ptr, size)
-#define IREE_AMDGPU_TRACE_FREE_NAMED(name_literal, ptr)
-
-#endif  // IREE_HAL_AMDGPU_TRACING_FEATURE_ALLOCATION_TRACKING
-
-#if IREE_HAL_AMDGPU_HAS_TRACING_FEATURE( \
-    IREE_HAL_AMDGPU_TRACING_FEATURE_LOG_MESSAGES)
-
-// Logs a message at the given logging level to the trace.
-// The message text must be a compile-time string literal.
-#define IREE_AMDGPU_TRACE_MESSAGE_LITERAL(level, value_literal) \
-  IREE_AMDGPU_TRACE_MESSAGE_LITERAL_COLORED(                    \
-      IREE_AMDGPU_TRACE_MESSAGE_LEVEL_##level, (value_literal))
-
-// Logs a message with the given color to the trace.
-// Standard colors are defined as IREE_AMDGPU_TRACE_MESSAGE_LEVEL_* values.
-// The message text must be a compile-time string literal.
-#define IREE_AMDGPU_TRACE_MESSAGE_LITERAL_COLORED(color, value_literal) \
-  iree_hal_amdgpu_device_trace_message_literal(                         \
-      IREE_AMDGPU_TRACE_BUFFER(), (color),                              \
-      (iree_hal_amdgpu_trace_string_literal_ptr_t)(value_literal))
-
-// Logs a dynamically-allocated message at the given logging level to the trace.
-// The string |value| will be copied into the trace buffer.
-#define IREE_AMDGPU_TRACE_MESSAGE_DYNAMIC(level, value, value_length) \
-  IREE_AMDGPU_TRACE_MESSAGE_DYNAMIC_COLORED(                          \
-      IREE_AMDGPU_TRACE_MESSAGE_LEVEL_##level, (value), (value_length))
-
-// Logs a dynamically-allocated message with the given color to the trace.
-// Standard colors are defined as IREE_AMDGPU_TRACE_MESSAGE_LEVEL_* values.
-// The string |value| will be copied into the trace buffer.
-#define IREE_AMDGPU_TRACE_MESSAGE_DYNAMIC_COLORED(color, value, value_length) \
-  iree_hal_amdgpu_device_trace_message_dynamic(                               \
-      IREE_AMDGPU_TRACE_BUFFER(), (color), (value), (value_length))
-
-#else
-
-#define IREE_AMDGPU_TRACE_MESSAGE_LITERAL(level, value_literal)
-#define IREE_AMDGPU_TRACE_MESSAGE_LITERAL_COLORED(color, value_literal)
-#define IREE_AMDGPU_TRACE_MESSAGE_DYNAMIC(level, value, value_length)
-#define IREE_AMDGPU_TRACE_MESSAGE_DYNAMIC_COLORED(color, value, value_length)
-
-#endif  // IREE_HAL_AMDGPU_TRACING_FEATURE_LOG_MESSAGES
-
-#if IREE_HAL_AMDGPU_HAS_TRACING_FEATURE( \
-    IREE_HAL_AMDGPU_TRACING_FEATURE_DEBUG_MESSAGES)
-
-// Logs a message formatted with an extremely basic sprintf-like function.
-// Supported format specifiers:
-//   %% (escape for `%`)
-//   %c (single `char`)
-//   %s (NUL-terminated string)
-//   %u/%lu (uint32_t/uint64_t in base 10)
-//   %x/%lx (uint32_t/uint64_t in base 16)
-//   %p (pointer)
-#define IREE_AMDGPU_DBG(format, ...)                                       \
-  iree_hal_amdgpu_device_trace_debug(IREE_AMDGPU_TRACE_BUFFER(), (format), \
-                                     __VA_ARGS__)
-
-#else
-
-#define IREE_AMDGPU_DBG(format, ...)
-
-#endif  // IREE_HAL_AMDGPU_TRACING_FEATURE_DEBUG_MESSAGES
-
-#endif  // IREE_AMDGPU_TARGET_DEVICE
-
-//===----------------------------------------------------------------------===//
-// Device-side API
-//===----------------------------------------------------------------------===//
-
-#if defined(IREE_AMDGPU_TARGET_DEVICE)
-
-// Commits the current write reservation to the ringbuffer so that the host can
-// begin reading it. Callers must notify the host that new data is available via
-// a host interrupt if this returns true.
-bool iree_hal_amdgpu_device_trace_commit_range(
-    iree_hal_amdgpu_device_trace_buffer_t* IREE_AMDGPU_RESTRICT trace_buffer);
-
-#if IREE_HAL_AMDGPU_HAS_TRACING_FEATURE( \
-    IREE_HAL_AMDGPU_TRACING_FEATURE_INSTRUMENTATION)
-
-iree_hal_amdgpu_zone_id_t iree_hal_amdgpu_device_trace_zone_begin(
-    iree_hal_amdgpu_device_trace_buffer_t* IREE_AMDGPU_RESTRICT trace_buffer,
-    iree_hal_amdgpu_trace_src_loc_ptr_t src_loc);
-void iree_hal_amdgpu_device_trace_zone_end(
-    iree_hal_amdgpu_device_trace_buffer_t* IREE_AMDGPU_RESTRICT trace_buffer);
-
-void iree_hal_amdgpu_device_trace_zone_append_value_i64(
-    iree_hal_amdgpu_device_trace_buffer_t* IREE_AMDGPU_RESTRICT trace_buffer,
-    int64_t value);
-void iree_hal_amdgpu_device_trace_zone_append_text_literal(
-    iree_hal_amdgpu_device_trace_buffer_t* IREE_AMDGPU_RESTRICT trace_buffer,
-    iree_hal_amdgpu_trace_string_literal_ptr_t value_literal);
-void iree_hal_amdgpu_device_trace_zone_append_text_dynamic(
-    iree_hal_amdgpu_device_trace_buffer_t* IREE_AMDGPU_RESTRICT trace_buffer,
-    const char* IREE_AMDGPU_RESTRICT value, size_t value_length);
-
-void iree_hal_amdgpu_device_trace_plot_configure(
-    iree_hal_amdgpu_device_trace_buffer_t* IREE_AMDGPU_RESTRICT trace_buffer,
-    iree_hal_amdgpu_trace_string_literal_ptr_t name_literal,
-    iree_hal_amdgpu_trace_plot_type_t type,
-    iree_hal_amdgpu_trace_plot_flags_t flags,
-    iree_hal_amdgpu_trace_color_t color);
-void iree_hal_amdgpu_device_trace_plot_value_i64(
-    iree_hal_amdgpu_device_trace_buffer_t* IREE_AMDGPU_RESTRICT trace_buffer,
-    iree_hal_amdgpu_trace_string_literal_ptr_t name_literal, int64_t value);
-
-#endif  // IREE_HAL_AMDGPU_TRACING_FEATURE_INSTRUMENTATION
-
-#if IREE_HAL_AMDGPU_HAS_TRACING_FEATURE( \
-    IREE_HAL_AMDGPU_TRACING_FEATURE_DEVICE_CONTROL)
-
-iree_hsa_signal_t iree_hal_amdgpu_device_trace_execution_zone_begin(
-    iree_hal_amdgpu_device_trace_buffer_t* IREE_AMDGPU_RESTRICT trace_buffer,
-    iree_hal_amdgpu_trace_execution_query_id_t execution_query_id,
-    iree_hal_amdgpu_trace_src_loc_ptr_t src_loc);
-iree_hsa_signal_t iree_hal_amdgpu_device_trace_execution_zone_end(
-    iree_hal_amdgpu_device_trace_buffer_t* IREE_AMDGPU_RESTRICT trace_buffer,
-    iree_hal_amdgpu_trace_execution_query_id_t execution_query_id);
-void iree_hal_amdgpu_device_trace_execution_zone_notify(
-    iree_hal_amdgpu_device_trace_buffer_t* IREE_AMDGPU_RESTRICT trace_buffer,
-    iree_hal_amdgpu_trace_execution_query_id_t execution_query_id,
-    iree_hal_amdgpu_trace_agent_time_range_t time_range);
-iree_hal_amdgpu_trace_agent_time_range_t* IREE_AMDGPU_RESTRICT
-iree_hal_amdgpu_device_trace_execution_zone_notify_batch(
-    iree_hal_amdgpu_device_trace_buffer_t* IREE_AMDGPU_RESTRICT trace_buffer,
-    iree_hal_amdgpu_trace_execution_query_id_t execution_query_id_base,
-    uint16_t execution_query_count);
-
-#endif  // IREE_HAL_AMDGPU_TRACING_FEATURE_DEVICE_CONTROL
-
-#if IREE_HAL_AMDGPU_HAS_TRACING_FEATURE( \
-    IREE_HAL_AMDGPU_TRACING_FEATURE_DEVICE_EXECUTION)
-
-iree_hsa_signal_t iree_hal_amdgpu_device_trace_execution_zone_dispatch(
-    iree_hal_amdgpu_device_trace_buffer_t* IREE_AMDGPU_RESTRICT trace_buffer,
-    iree_hal_amdgpu_trace_execution_zone_type_t zone_type, uint64_t export_loc,
-    iree_hal_amdgpu_trace_execution_query_id_t execution_query_id);
-
-#endif  // IREE_HAL_AMDGPU_TRACING_FEATURE_DEVICE_EXECUTION
-
-#if IREE_HAL_AMDGPU_HAS_TRACING_FEATURE( \
-    IREE_HAL_AMDGPU_TRACING_FEATURE_ALLOCATION_TRACKING)
-
-void iree_hal_amdgpu_device_trace_memory_alloc(
-    iree_hal_amdgpu_device_trace_buffer_t* IREE_AMDGPU_RESTRICT trace_buffer,
-    iree_hal_amdgpu_trace_string_literal_ptr_t name_literal, uint64_t ptr,
-    uint64_t size);
-void iree_hal_amdgpu_device_trace_memory_free(
-    iree_hal_amdgpu_device_trace_buffer_t* IREE_AMDGPU_RESTRICT trace_buffer,
-    iree_hal_amdgpu_trace_string_literal_ptr_t name_literal, uint64_t ptr);
-
-#endif  // IREE_HAL_AMDGPU_TRACING_FEATURE_ALLOCATION_TRACKING
-
-#if IREE_HAL_AMDGPU_HAS_TRACING_FEATURE( \
-    IREE_HAL_AMDGPU_TRACING_FEATURE_LOG_MESSAGES)
-
-void iree_hal_amdgpu_device_trace_message_literal(
-    iree_hal_amdgpu_device_trace_buffer_t* IREE_AMDGPU_RESTRICT trace_buffer,
-    iree_hal_amdgpu_trace_color_t color,
-    iree_hal_amdgpu_trace_string_literal_ptr_t value_literal);
-void iree_hal_amdgpu_device_trace_message_dynamic(
-    iree_hal_amdgpu_device_trace_buffer_t* IREE_AMDGPU_RESTRICT trace_buffer,
-    iree_hal_amdgpu_trace_color_t color, const char* IREE_AMDGPU_RESTRICT value,
-    size_t value_length);
-
-#endif  // IREE_HAL_AMDGPU_TRACING_FEATURE_LOG_MESSAGES
-
-#if IREE_HAL_AMDGPU_HAS_TRACING_FEATURE( \
-    IREE_HAL_AMDGPU_TRACING_FEATURE_DEBUG_MESSAGES)
-
-void iree_hal_amdgpu_device_trace_debug(
-    iree_hal_amdgpu_device_trace_buffer_t* IREE_AMDGPU_RESTRICT trace_buffer,
-    const char* IREE_AMDGPU_RESTRICT format, ...);
-
-#endif  // IREE_HAL_AMDGPU_TRACING_FEATURE_DEBUG_MESSAGES
-
-#endif  // IREE_AMDGPU_TARGET_DEVICE
-
-#endif  // IREE_HAL_DRIVERS_AMDGPU_DEVICE_TRACING_H_
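
For orientation, a minimal sketch of how the device-side tracing macros declared in this header compose (`my_process_block` and its arguments are illustrative names, not code from this diff):

```c
// Minimal sketch of device-library instrumentation; `my_process_block` and
// its arguments are illustrative names, not code from this diff.
static void my_process_block(
    iree_hal_amdgpu_device_trace_buffer_t* trace_buffer,
    uint32_t command_count) {
  // Establish the implicit trace buffer referenced by the macros below.
  IREE_AMDGPU_TRACE_BUFFER_SCOPE(trace_buffer);
  IREE_AMDGPU_TRACE_ZONE_BEGIN_NAMED(z0, "my_process_block");
  IREE_AMDGPU_TRACE_ZONE_APPEND_VALUE_I64(z0, command_count);
  IREE_AMDGPU_DBG("processing %u commands", command_count);
  // ... real work would go here ...
  IREE_AMDGPU_TRACE_ZONE_END(z0);
}
```

When the corresponding tracing features are compiled out, each macro expands to nothing or a `(void)` cast, so production device code pays no instrumentation cost. Producers flush records with `iree_hal_amdgpu_device_trace_commit_range()` and raise a host interrupt when it returns true so the host can drain the ringbuffer.
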
diff --git a/runtime/src/iree/hal/drivers/amdgpu/device_queue.c b/runtime/src/iree/hal/drivers/amdgpu/device_queue.c
deleted file mode 100644
index 8f320c9..0000000
--- a/runtime/src/iree/hal/drivers/amdgpu/device_queue.c
+++ /dev/null
@@ -1,48 +0,0 @@
-// Copyright 2025 The IREE Authors
-//
-// Licensed under the Apache License v2.0 with LLVM Exceptions.
-// See https://llvm.org/LICENSE.txt for license information.
-// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
-
-#include "iree/hal/drivers/amdgpu/device_queue.h"
-
-//===----------------------------------------------------------------------===//
-// iree_hal_amdgpu_device_queue_t
-//===----------------------------------------------------------------------===//
-
-typedef struct iree_hal_amdgpu_device_queue_t {
-  iree_hal_amdgpu_virtual_queue_t base;
-} iree_hal_amdgpu_device_queue_t;
-
-iree_host_size_t iree_hal_amdgpu_device_queue_calculate_size(
-    const iree_hal_amdgpu_queue_options_t* options) {
-  IREE_ASSERT_EQ(options->placement, IREE_HAL_AMDGPU_QUEUE_PLACEMENT_DEVICE);
-  return sizeof(iree_hal_amdgpu_device_queue_t);
-}
-
-iree_status_t iree_hal_amdgpu_device_queue_initialize(
-    iree_hal_amdgpu_system_t* system, iree_hal_amdgpu_queue_options_t options,
-    hsa_agent_t device_agent, iree_host_size_t device_ordinal,
-    iree_hal_amdgpu_host_service_t* host_service,
-    iree_arena_block_pool_t* host_block_pool,
-    iree_hal_amdgpu_block_allocators_t* block_allocators,
-    iree_hal_amdgpu_buffer_pool_t* buffer_pool,
-    iree_hal_amdgpu_error_callback_t error_callback,
-    hsa_signal_t initialization_signal, iree_allocator_t host_allocator,
-    iree_hal_amdgpu_virtual_queue_t* out_queue) {
-  IREE_ASSERT_ARGUMENT(system);
-  IREE_ASSERT_EQ(options.placement, IREE_HAL_AMDGPU_QUEUE_PLACEMENT_DEVICE);
-  IREE_ASSERT_ARGUMENT(host_service);
-  IREE_ASSERT_ARGUMENT(host_block_pool);
-  IREE_ASSERT_ARGUMENT(block_allocators);
-  IREE_ASSERT_ARGUMENT(buffer_pool);
-  IREE_ASSERT_ARGUMENT(out_queue);
-  IREE_TRACE_ZONE_BEGIN(z0);
-  IREE_TRACE_ZONE_APPEND_VALUE_I64(z0, device_ordinal);
-
-  iree_status_t status = iree_make_status(
-      IREE_STATUS_UNIMPLEMENTED, "device-side queuing not yet implemented");
-
-  IREE_TRACE_ZONE_END(z0);
-  return status;
-}
diff --git a/runtime/src/iree/hal/drivers/amdgpu/device_queue.h b/runtime/src/iree/hal/drivers/amdgpu/device_queue.h
deleted file mode 100644
index e92b147..0000000
--- a/runtime/src/iree/hal/drivers/amdgpu/device_queue.h
+++ /dev/null
@@ -1,41 +0,0 @@
-// Copyright 2025 The IREE Authors
-//
-// Licensed under the Apache License v2.0 with LLVM Exceptions.
-// See https://llvm.org/LICENSE.txt for license information.
-// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
-
-#ifndef IREE_HAL_DRIVERS_AMDGPU_DEVICE_QUEUE_H_
-#define IREE_HAL_DRIVERS_AMDGPU_DEVICE_QUEUE_H_
-
-#include "iree/hal/drivers/amdgpu/util/error_callback.h"
-#include "iree/hal/drivers/amdgpu/virtual_queue.h"
-
-typedef struct iree_arena_block_pool_t iree_arena_block_pool_t;
-typedef struct iree_hal_amdgpu_block_allocators_t
-    iree_hal_amdgpu_block_allocators_t;
-typedef struct iree_hal_amdgpu_buffer_pool_t iree_hal_amdgpu_buffer_pool_t;
-typedef struct iree_hal_amdgpu_host_service_t iree_hal_amdgpu_host_service_t;
-typedef struct iree_hal_amdgpu_system_t iree_hal_amdgpu_system_t;
-
-//===----------------------------------------------------------------------===//
-// iree_hal_amdgpu_device_queue_t
-//===----------------------------------------------------------------------===//
-
-// Calculates the size in bytes of the storage required for a queue
-// implementation based on the provided |options|.
-iree_host_size_t iree_hal_amdgpu_device_queue_calculate_size(
-    const iree_hal_amdgpu_queue_options_t* options);
-
-// Initializes |out_queue| in-place based on |options|.
-iree_status_t iree_hal_amdgpu_device_queue_initialize(
-    iree_hal_amdgpu_system_t* system, iree_hal_amdgpu_queue_options_t options,
-    hsa_agent_t device_agent, iree_host_size_t device_ordinal,
-    iree_hal_amdgpu_host_service_t* host_service,
-    iree_arena_block_pool_t* host_block_pool,
-    iree_hal_amdgpu_block_allocators_t* block_allocators,
-    iree_hal_amdgpu_buffer_pool_t* buffer_pool,
-    iree_hal_amdgpu_error_callback_t error_callback,
-    hsa_signal_t initialization_signal, iree_allocator_t host_allocator,
-    iree_hal_amdgpu_virtual_queue_t* out_queue);
-
-#endif  // IREE_HAL_DRIVERS_AMDGPU_DEVICE_QUEUE_H_
diff --git a/runtime/src/iree/hal/drivers/amdgpu/driver.c b/runtime/src/iree/hal/drivers/amdgpu/driver.c
index 0b90b0a..837ab3e 100644
--- a/runtime/src/iree/hal/drivers/amdgpu/driver.c
+++ b/runtime/src/iree/hal/drivers/amdgpu/driver.c
@@ -6,7 +6,9 @@
 
 #include "iree/hal/drivers/amdgpu/driver.h"
 
+#include "iree/base/internal/debugging.h"
 #include "iree/hal/drivers/amdgpu/api.h"
+#include "iree/hal/drivers/amdgpu/logical_device.h"
 
 //===----------------------------------------------------------------------===//
 // iree_hal_amdgpu_driver_options_t
@@ -32,13 +34,17 @@
 IREE_API_EXPORT iree_status_t iree_hal_amdgpu_driver_options_parse(
     iree_hal_amdgpu_driver_options_t* options, iree_string_pair_list_t params) {
   IREE_ASSERT_ARGUMENT(options);
-  if (!params.count) return iree_ok_status();  // no-op
+  if (!params.count) return iree_ok_status();
   IREE_TRACE_ZONE_BEGIN(z0);
 
-  // TODO(benvanik): parameters.
+  const iree_string_pair_t* first_param = &params.pairs[0];
+  iree_status_t status = iree_make_status(
+      IREE_STATUS_INVALID_ARGUMENT,
+      "AMDGPU driver options do not support key/value parameter '%.*s'",
+      (int)first_param->key.size, first_param->key.data);
 
   IREE_TRACE_ZONE_END(z0);
-  return iree_ok_status();
+  return status;
 }
 
 static iree_status_t iree_hal_amdgpu_driver_options_verify(
@@ -46,11 +52,32 @@
   IREE_ASSERT_ARGUMENT(options);
   IREE_TRACE_ZONE_BEGIN(z0);
 
-  // TODO(benvanik): verify that the parameters are within expected ranges and
-  // any requested features are supported.
+  iree_status_t status =
+      iree_hal_amdgpu_logical_device_options_verify_supported_features(
+          &options->default_device_options);
+  if (iree_status_is_ok(status) && options->libhsa_search_paths.count &&
+      !options->libhsa_search_paths.values) {
+    status =
+        iree_make_status(IREE_STATUS_INVALID_ARGUMENT,
+                         "AMDGPU libhsa search path list has count %" PRIhsz
+                         " but no value storage",
+                         options->libhsa_search_paths.count);
+  }
+  for (iree_host_size_t i = 0;
+       i < options->libhsa_search_paths.count && iree_status_is_ok(status);
+       ++i) {
+    const iree_string_view_t search_path =
+        options->libhsa_search_paths.values[i];
+    if (search_path.size && !search_path.data) {
+      status = iree_make_status(IREE_STATUS_INVALID_ARGUMENT,
+                                "AMDGPU libhsa search path %" PRIhsz
+                                " has a nonzero length and no storage",
+                                i);
+    }
+  }
 
   IREE_TRACE_ZONE_END(z0);
-  return iree_ok_status();
+  return status;
 }
 
 //===----------------------------------------------------------------------===//
@@ -200,15 +227,20 @@
 //===----------------------------------------------------------------------===//
 
 typedef struct iree_hal_amdgpu_driver_t {
+  // HAL resource header.
   iree_hal_resource_t resource;
+  // Host allocator used for driver-owned allocations.
   iree_allocator_t host_allocator;
 
+  // Stable driver identifier stored inline after this struct.
   iree_string_view_t identifier;
+  // Driver options with retained string views pointing into trailing storage.
   iree_hal_amdgpu_driver_options_t options;
 
+  // Retained HSA runtime handle used for enumeration and device creation.
   iree_hal_amdgpu_libhsa_t libhsa;
 
-  // + trailing identifier string storage
+  // + trailing libhsa_search_paths table, identifier, and search path strings.
 } iree_hal_amdgpu_driver_t;
 
 static const iree_hal_driver_vtable_t iree_hal_amdgpu_driver_vtable;
@@ -245,26 +277,62 @@
   *out_driver = NULL;
   IREE_TRACE_ZONE_BEGIN(z0);
 
-  // TODO(benvanik): verify options; this may be moved after any libraries are
-  // loaded so the verification can use underlying implementation queries.
+  // Reject unsupported public options before loading HSA or touching vendor
+  // state. Device-dependent verification happens during logical-device create.
   IREE_RETURN_AND_END_ZONE_IF_ERROR(
       z0, iree_hal_amdgpu_driver_options_verify(options));
 
+  iree_host_size_t search_path_storage_size = 0;
+  for (iree_host_size_t i = 0; i < options->libhsa_search_paths.count; ++i) {
+    if (IREE_UNLIKELY(!iree_host_size_checked_add(
+            search_path_storage_size,
+            options->libhsa_search_paths.values[i].size,
+            &search_path_storage_size))) {
+      IREE_RETURN_AND_END_ZONE_IF_ERROR(
+          z0, iree_make_status(IREE_STATUS_OUT_OF_RANGE,
+                               "AMDGPU libhsa search path storage overflow"));
+    }
+  }
+
+  iree_host_size_t search_paths_offset = 0;
+  iree_host_size_t identifier_offset = 0;
+  iree_host_size_t search_path_storage_offset = 0;
+  iree_host_size_t total_size = 0;
+  IREE_RETURN_AND_END_ZONE_IF_ERROR(
+      z0, IREE_STRUCT_LAYOUT(
+              iree_sizeof_struct(iree_hal_amdgpu_driver_t), &total_size,
+              IREE_STRUCT_FIELD(options->libhsa_search_paths.count,
+                                iree_string_view_t, &search_paths_offset),
+              IREE_STRUCT_FIELD(identifier.size, char, &identifier_offset),
+              IREE_STRUCT_FIELD(search_path_storage_size, char,
+                                &search_path_storage_offset)));
+
   iree_hal_amdgpu_driver_t* driver = NULL;
-  iree_host_size_t total_size = sizeof(*driver) + identifier.size;
   IREE_RETURN_AND_END_ZONE_IF_ERROR(
       z0, iree_allocator_malloc(host_allocator, total_size, (void**)&driver));
+  memset(driver, 0, total_size);
   iree_hal_resource_initialize(&iree_hal_amdgpu_driver_vtable,
                                &driver->resource);
   driver->host_allocator = host_allocator;
-  iree_string_view_append_to_buffer(
-      identifier, &driver->identifier,
-      (char*)driver + total_size - identifier.size);
-
-  // TODO(benvanik): if there are any string fields then they will need to be
-  // retained as well (similar to the identifier they can be tagged on to the
-  // end of the driver struct).
+  iree_string_view_append_to_buffer(identifier, &driver->identifier,
+                                    (char*)driver + identifier_offset);
   memcpy(&driver->options, options, sizeof(*options));
+  if (options->libhsa_search_paths.count) {
+    iree_string_view_t* search_paths =
+        (iree_string_view_t*)((uint8_t*)driver + search_paths_offset);
+    char* search_path_storage = (char*)driver + search_path_storage_offset;
+    for (iree_host_size_t i = 0; i < options->libhsa_search_paths.count; ++i) {
+      const iree_string_view_t source_path =
+          options->libhsa_search_paths.values[i];
+      iree_string_view_append_to_buffer(source_path, &search_paths[i],
+                                        search_path_storage);
+      search_path_storage += source_path.size;
+    }
+    driver->options.libhsa_search_paths = (iree_string_view_list_t){
+        .count = options->libhsa_search_paths.count, .values = search_paths};
+  } else {
+    driver->options.libhsa_search_paths = iree_string_view_list_empty();
+  }
 
   // Load HSA. The HSA runtime shared library may already be loaded and we
   // retain a copy during creation to ensure it doesn't get unloaded.
@@ -301,13 +369,17 @@
   iree_hal_amdgpu_driver_t* driver = iree_hal_amdgpu_driver_cast(base_driver);
   *out_device_info_count = 0;
   *out_device_infos = NULL;
+  IREE_TRACE_ZONE_BEGIN(z0);
 
   // Query available devices based on the default configuration.
   iree_hal_amdgpu_topology_t topology;
-  IREE_RETURN_IF_ERROR(iree_hal_amdgpu_topology_initialize_with_defaults(
-      &driver->libhsa, &topology));
+  IREE_RETURN_AND_END_ZONE_IF_ERROR(
+      z0, iree_hal_amdgpu_topology_initialize_with_defaults(&driver->libhsa,
+                                                            &topology));
   if (topology.gpu_agent_count == 0) {
-    return iree_ok_status();  // no devices
+    iree_hal_amdgpu_topology_deinitialize(&topology);
+    IREE_TRACE_ZONE_END(z0);
+    return iree_ok_status();
   }
 
   // Run the string builder in size calculation mode.
@@ -348,6 +420,7 @@
   }
 
   iree_string_builder_deinitialize(&builder);
+  iree_hal_amdgpu_topology_deinitialize(&topology);
 
   if (iree_status_is_ok(status)) {
     *out_device_info_count = device_info_count;
@@ -355,6 +428,7 @@
   } else {
     iree_allocator_free(host_allocator, device_infos);
   }
+  IREE_TRACE_ZONE_END(z0);
   return status;
 }
 
@@ -384,6 +458,8 @@
     const iree_hal_device_create_params_t* create_params,
     iree_allocator_t host_allocator, iree_hal_device_t** out_device) {
   iree_hal_amdgpu_driver_t* driver = iree_hal_amdgpu_driver_cast(base_driver);
+  IREE_TRACE_ZONE_BEGIN(z0);
+  IREE_TRACE_ZONE_APPEND_VALUE_I64(z0, device_id);
 
   // Use the provided params to overwrite the default options.
   // The format of the params is implementation-defined. The params strings can
@@ -391,11 +467,18 @@
   // access them during the create call below.
   iree_hal_amdgpu_logical_device_options_t options =
       driver->options.default_device_options;
-  IREE_RETURN_IF_ERROR(iree_hal_amdgpu_logical_device_options_parse(
-      &options, (iree_string_pair_list_t){
-                    .count = param_count,
-                    .pairs = params,
-                }));
+  IREE_RETURN_AND_END_ZONE_IF_ERROR(
+      z0, iree_hal_amdgpu_logical_device_options_parse(
+              &options, (iree_string_pair_list_t){
+                            .count = param_count,
+                            .pairs = params,
+                        }));
+
+  // ROCR lazily allocates global singleton state during agent enumeration,
+  // memory pool queries, and ISA iteration. These allocations are never freed
+  // (intentional ROCR design). Suppress the resulting LSAN reports for the
+  // entire device creation sequence.
+  IREE_LEAK_CHECK_DISABLE_PUSH();
 
   // Initialize the topology based on the device ID.
   // The ID is a bitfield of device ordinals defined by ROCR_VISIBLE_DEVICES.
@@ -414,6 +497,9 @@
   }
 
   iree_hal_amdgpu_topology_deinitialize(&topology);
+
+  IREE_LEAK_CHECK_DISABLE_POP();
+  IREE_TRACE_ZONE_END(z0);
   return status;
 }
 
@@ -424,6 +510,8 @@
     const iree_hal_device_create_params_t* create_params,
     iree_allocator_t host_allocator, iree_hal_device_t** out_device) {
   iree_hal_amdgpu_driver_t* driver = iree_hal_amdgpu_driver_cast(base_driver);
+  IREE_TRACE_ZONE_BEGIN(z0);
+  IREE_TRACE_ZONE_APPEND_TEXT(z0, device_path.data, device_path.size);
 
   // Use the provided params to overwrite the default options.
   // The format of the params is implementation-defined. The params strings can
@@ -431,17 +519,25 @@
   // access them during the create call below.
   iree_hal_amdgpu_logical_device_options_t options =
       driver->options.default_device_options;
-  IREE_RETURN_IF_ERROR(iree_hal_amdgpu_logical_device_options_parse(
-      &options, (iree_string_pair_list_t){
-                    .count = param_count,
-                    .pairs = params,
-                }));
+  IREE_RETURN_AND_END_ZONE_IF_ERROR(
+      z0, iree_hal_amdgpu_logical_device_options_parse(
+              &options, (iree_string_pair_list_t){
+                            .count = param_count,
+                            .pairs = params,
+                        }));
 
   // Load HSA. HSA may already be loaded and we retain a copy during creation to
   // ensure it doesn't get unloaded.
   iree_hal_amdgpu_libhsa_t libhsa;
-  IREE_RETURN_IF_ERROR(iree_hal_amdgpu_driver_load_libhsa(
-      &driver->options, host_allocator, &libhsa));
+  IREE_RETURN_AND_END_ZONE_IF_ERROR(
+      z0, iree_hal_amdgpu_driver_load_libhsa(&driver->options, host_allocator,
+                                             &libhsa));
+
+  // ROCR lazily allocates global singleton state during agent enumeration,
+  // memory pool queries, and ISA iteration. These allocations are never freed
+  // (intentional ROCR design). Suppress the resulting LSAN reports for the
+  // entire device creation sequence.
+  IREE_LEAK_CHECK_DISABLE_PUSH();
 
   // Initialize the topology with the given path. It may indicate multiple
   // devices and use different schemes to determine which devices are included.
@@ -458,6 +554,9 @@
 
   iree_hal_amdgpu_topology_deinitialize(&topology);
   iree_hal_amdgpu_libhsa_deinitialize(&libhsa);
+
+  IREE_LEAK_CHECK_DISABLE_POP();
+  IREE_TRACE_ZONE_END(z0);
   return status;
 }
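
The creation path above replaces the old single-string append with a computed multi-field layout: the driver now deep-copies the identifier and every `libhsa_search_paths` entry into trailing storage sized via `IREE_STRUCT_LAYOUT`. A hedged caller-side sketch (the helper name and `/opt/rocm/lib` path are placeholders, not part of this diff):

```c
// Hypothetical helper showing the retained-options contract; the search path
// literal is a placeholder.
static iree_status_t create_amdgpu_driver_with_search_path(
    iree_hal_driver_t** out_driver) {
  iree_hal_amdgpu_driver_options_t options;
  iree_hal_amdgpu_driver_options_initialize(&options);
  const iree_string_view_t search_paths[] = {IREE_SV("/opt/rocm/lib")};
  // Count/storage consistency is verified before HSA is ever loaded.
  options.libhsa_search_paths = (iree_string_view_list_t){
      .count = IREE_ARRAYSIZE(search_paths),
      .values = search_paths,
  };
  // The driver copies each path into its own trailing storage, so these
  // stack-scoped views need not outlive the create call.
  return iree_hal_amdgpu_driver_create(IREE_SV("amdgpu"), &options,
                                       iree_allocator_system(), out_driver);
}
```
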
 
diff --git a/runtime/src/iree/hal/drivers/amdgpu/driver_options_test.cc b/runtime/src/iree/hal/drivers/amdgpu/driver_options_test.cc
new file mode 100644
index 0000000..9ae7cf8
--- /dev/null
+++ b/runtime/src/iree/hal/drivers/amdgpu/driver_options_test.cc
@@ -0,0 +1,133 @@
+// Copyright 2026 The IREE Authors
+//
+// Licensed under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+
+#include <cstdint>
+
+#include "iree/hal/drivers/amdgpu/api.h"
+#include "iree/testing/gtest.h"
+#include "iree/testing/status_matchers.h"
+
+namespace iree::hal::amdgpu {
+namespace {
+
+TEST(AmdgpuDriverOptionsTest, DriverParamsAreRejectedUntilDefined) {
+  iree_hal_amdgpu_driver_options_t options;
+  iree_hal_amdgpu_driver_options_initialize(&options);
+  const iree_string_pair_t pair = iree_make_cstring_pair("unknown", "value");
+
+  IREE_EXPECT_STATUS_IS(
+      IREE_STATUS_INVALID_ARGUMENT,
+      iree_hal_amdgpu_driver_options_parse(&options, (iree_string_pair_list_t){
+                                                         .count = 1,
+                                                         .pairs = &pair,
+                                                     }));
+}
+
+TEST(AmdgpuDriverOptionsTest, LogicalDeviceParamsAreRejectedUntilDefined) {
+  iree_hal_amdgpu_logical_device_options_t options;
+  iree_hal_amdgpu_logical_device_options_initialize(&options);
+  const iree_string_pair_t pair = iree_make_cstring_pair("unknown", "value");
+
+  IREE_EXPECT_STATUS_IS(IREE_STATUS_INVALID_ARGUMENT,
+                        iree_hal_amdgpu_logical_device_options_parse(
+                            &options, (iree_string_pair_list_t){
+                                          .count = 1,
+                                          .pairs = &pair,
+                                      }));
+}
+
+TEST(AmdgpuDriverOptionsTest, RejectsMissingSearchPathStorageBeforeLoadingHsa) {
+  iree_hal_amdgpu_driver_options_t options;
+  iree_hal_amdgpu_driver_options_initialize(&options);
+  options.libhsa_search_paths = (iree_string_view_list_t){
+      .count = 1,
+      .values = NULL,
+  };
+
+  iree_hal_driver_t* driver = NULL;
+  IREE_EXPECT_STATUS_IS(
+      IREE_STATUS_INVALID_ARGUMENT,
+      iree_hal_amdgpu_driver_create(IREE_SV("amdgpu"), &options,
+                                    iree_allocator_system(), &driver));
+  iree_hal_driver_release(driver);
+}
+
+TEST(AmdgpuDriverOptionsTest, RejectsMissingSearchPathDataBeforeLoadingHsa) {
+  iree_hal_amdgpu_driver_options_t options;
+  iree_hal_amdgpu_driver_options_initialize(&options);
+  const iree_string_view_t search_path = iree_make_string_view(NULL, 1);
+  options.libhsa_search_paths = (iree_string_view_list_t){
+      .count = 1,
+      .values = &search_path,
+  };
+
+  iree_hal_driver_t* driver = NULL;
+  IREE_EXPECT_STATUS_IS(
+      IREE_STATUS_INVALID_ARGUMENT,
+      iree_hal_amdgpu_driver_create(IREE_SV("amdgpu"), &options,
+                                    iree_allocator_system(), &driver));
+  iree_hal_driver_release(driver);
+}
+
+static iree_status_t CreateDriverWithDefaultDeviceOptions(
+    const iree_hal_amdgpu_logical_device_options_t* device_options) {
+  iree_hal_amdgpu_driver_options_t options;
+  iree_hal_amdgpu_driver_options_initialize(&options);
+  options.default_device_options = *device_options;
+  iree_hal_driver_t* driver = NULL;
+  iree_status_t status = iree_hal_amdgpu_driver_create(
+      IREE_SV("amdgpu"), &options, iree_allocator_system(), &driver);
+  iree_hal_driver_release(driver);
+  return status;
+}
+
+TEST(AmdgpuDriverOptionsTest, RejectsDeviceQueuePlacementBeforeLoadingHsa) {
+  iree_hal_amdgpu_logical_device_options_t options;
+  iree_hal_amdgpu_logical_device_options_initialize(&options);
+  options.queue_placement = IREE_HAL_AMDGPU_QUEUE_PLACEMENT_DEVICE;
+
+  IREE_EXPECT_STATUS_IS(IREE_STATUS_UNIMPLEMENTED,
+                        CreateDriverWithDefaultDeviceOptions(&options));
+}
+
+TEST(AmdgpuDriverOptionsTest, RejectsInvalidQueuePlacementBeforeLoadingHsa) {
+  iree_hal_amdgpu_logical_device_options_t options;
+  iree_hal_amdgpu_logical_device_options_initialize(&options);
+  options.queue_placement = (iree_hal_amdgpu_queue_placement_t)99;
+
+  IREE_EXPECT_STATUS_IS(IREE_STATUS_INVALID_ARGUMENT,
+                        CreateDriverWithDefaultDeviceOptions(&options));
+}
+
+TEST(AmdgpuDriverOptionsTest, RejectsExclusiveExecutionBeforeLoadingHsa) {
+  iree_hal_amdgpu_logical_device_options_t options;
+  iree_hal_amdgpu_logical_device_options_initialize(&options);
+  options.exclusive_execution = 1;
+
+  IREE_EXPECT_STATUS_IS(IREE_STATUS_UNIMPLEMENTED,
+                        CreateDriverWithDefaultDeviceOptions(&options));
+}
+
+TEST(AmdgpuDriverOptionsTest, RejectsNegativeActiveWaitBeforeLoadingHsa) {
+  iree_hal_amdgpu_logical_device_options_t options;
+  iree_hal_amdgpu_logical_device_options_initialize(&options);
+  options.wait_active_for_ns = -1;
+
+  IREE_EXPECT_STATUS_IS(IREE_STATUS_OUT_OF_RANGE,
+                        CreateDriverWithDefaultDeviceOptions(&options));
+}
+
+TEST(AmdgpuDriverOptionsTest, RejectsActiveWaitBeforeLoadingHsa) {
+  iree_hal_amdgpu_logical_device_options_t options;
+  iree_hal_amdgpu_logical_device_options_initialize(&options);
+  options.wait_active_for_ns = 1;
+
+  IREE_EXPECT_STATUS_IS(IREE_STATUS_UNIMPLEMENTED,
+                        CreateDriverWithDefaultDeviceOptions(&options));
+}
+
+}  // namespace
+}  // namespace iree::hal::amdgpu
diff --git a/runtime/src/iree/hal/drivers/amdgpu/event.c b/runtime/src/iree/hal/drivers/amdgpu/event.c
deleted file mode 100644
index 47193d4..0000000
--- a/runtime/src/iree/hal/drivers/amdgpu/event.c
+++ /dev/null
@@ -1,64 +0,0 @@
-// Copyright 2025 The IREE Authors
-//
-// Licensed under the Apache License v2.0 with LLVM Exceptions.
-// See https://llvm.org/LICENSE.txt for license information.
-// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
-
-#include "iree/hal/drivers/amdgpu/event.h"
-
-//===----------------------------------------------------------------------===//
-// iree_hal_amdgpu_event_t
-//===----------------------------------------------------------------------===//
-
-typedef struct iree_hal_amdgpu_event_t {
-  iree_hal_resource_t resource;
-  iree_allocator_t host_allocator;
-} iree_hal_amdgpu_event_t;
-
-static const iree_hal_event_vtable_t iree_hal_amdgpu_event_vtable;
-
-static iree_hal_amdgpu_event_t* iree_hal_amdgpu_event_cast(
-    iree_hal_event_t* base_value) {
-  IREE_HAL_ASSERT_TYPE(base_value, &iree_hal_amdgpu_event_vtable);
-  return (iree_hal_amdgpu_event_t*)base_value;
-}
-
-iree_status_t iree_hal_amdgpu_event_create(
-    iree_hal_queue_affinity_t queue_affinity, iree_hal_event_flags_t flags,
-    iree_allocator_t host_allocator, iree_hal_event_t** out_event) {
-  IREE_ASSERT_ARGUMENT(out_event);
-  *out_event = NULL;
-  IREE_TRACE_ZONE_BEGIN(z0);
-
-  iree_hal_amdgpu_event_t* event = NULL;
-  IREE_RETURN_AND_END_ZONE_IF_ERROR(
-      z0,
-      iree_allocator_malloc(host_allocator, sizeof(*event), (void**)&event));
-  iree_hal_resource_initialize(&iree_hal_amdgpu_event_vtable, &event->resource);
-  event->host_allocator = host_allocator;
-
-  // TODO(benvanik): WIP API; this is a no-op today.
-  iree_status_t status = iree_ok_status();
-
-  if (iree_status_is_ok(status)) {
-    *out_event = (iree_hal_event_t*)event;
-  } else {
-    iree_hal_event_release((iree_hal_event_t*)event);
-  }
-  IREE_TRACE_ZONE_END(z0);
-  return status;
-}
-
-static void iree_hal_amdgpu_event_destroy(iree_hal_event_t* base_event) {
-  iree_hal_amdgpu_event_t* event = iree_hal_amdgpu_event_cast(base_event);
-  iree_allocator_t host_allocator = event->host_allocator;
-  IREE_TRACE_ZONE_BEGIN(z0);
-
-  iree_allocator_free(host_allocator, event);
-
-  IREE_TRACE_ZONE_END(z0);
-}
-
-static const iree_hal_event_vtable_t iree_hal_amdgpu_event_vtable = {
-    .destroy = iree_hal_amdgpu_event_destroy,
-};
diff --git a/runtime/src/iree/hal/drivers/amdgpu/event.h b/runtime/src/iree/hal/drivers/amdgpu/event.h
deleted file mode 100644
index ee70e0b..0000000
--- a/runtime/src/iree/hal/drivers/amdgpu/event.h
+++ /dev/null
@@ -1,22 +0,0 @@
-// Copyright 2025 The IREE Authors
-//
-// Licensed under the Apache License v2.0 with LLVM Exceptions.
-// See https://llvm.org/LICENSE.txt for license information.
-// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
-
-#ifndef IREE_HAL_DRIVERS_AMDGPU_EVENT_H_
-#define IREE_HAL_DRIVERS_AMDGPU_EVENT_H_
-
-#include "iree/base/api.h"
-#include "iree/hal/api.h"
-
-//===----------------------------------------------------------------------===//
-// iree_hal_amdgpu_event_t
-//===----------------------------------------------------------------------===//
-
-// WIP API and may change. Mostly ignored for now.
-iree_status_t iree_hal_amdgpu_event_create(
-    iree_hal_queue_affinity_t queue_affinity, iree_hal_event_flags_t flags,
-    iree_allocator_t host_allocator, iree_hal_event_t** out_event);
-
-#endif  // IREE_HAL_DRIVERS_AMDGPU_EVENT_H_
diff --git a/runtime/src/iree/hal/drivers/amdgpu/executable.c b/runtime/src/iree/hal/drivers/amdgpu/executable.c
index c6308dc..d056869 100644
--- a/runtime/src/iree/hal/drivers/amdgpu/executable.c
+++ b/runtime/src/iree/hal/drivers/amdgpu/executable.c
@@ -7,8 +7,14 @@
 #include "iree/hal/drivers/amdgpu/executable.h"
 
 #include "iree/base/internal/debugging.h"
+#include "iree/hal/drivers/amdgpu/queue_affinity.h"
+#include "iree/hal/drivers/amdgpu/util/code_object_target.h"
+#include "iree/hal/drivers/amdgpu/util/hsaco_metadata.h"
+#include "iree/hal/drivers/amdgpu/util/kernarg_ring.h"
+#include "iree/hal/drivers/amdgpu/util/target_id.h"
 #include "iree/hal/drivers/amdgpu/util/topology.h"
 #include "iree/hal/drivers/amdgpu/util/vmem.h"
+#include "iree/hal/utils/elf_format.h"
 #include "iree/hal/utils/executable_debug_info.h"
 #include "iree/hal/utils/executable_header.h"
 
@@ -17,25 +23,28 @@
 #include "iree/schemas/amdgpu_executable_def_reader.h"
 #include "iree/schemas/amdgpu_executable_def_verifier.h"
 
-// TODO(benvanik): replace with include when device-side tracing imported.
-// #include "iree/hal/drivers/amdgpu/device/tracing.h"
-typedef uint32_t iree_hal_amdgpu_trace_color_t;
-typedef struct iree_hal_amdgpu_trace_src_loc_t {
-  const char* name;
-  const char* function;
-  const char* file;
-  uint32_t line;
-  iree_hal_amdgpu_trace_color_t color;
-} iree_hal_amdgpu_trace_src_loc_t;
-
 //===----------------------------------------------------------------------===//
 // ISA Support
 //===----------------------------------------------------------------------===//
 
 typedef struct iree_hal_amdgpu_agent_available_isas_t {
+  // Number of valid entries in |values|.
   iree_host_size_t count;
+  // Fixed-capacity ISA list populated by HSA iteration callbacks.
   hsa_isa_t values[32];
 } iree_hal_amdgpu_agent_available_isas_t;
+
+typedef struct iree_hal_amdgpu_agent_isa_target_t {
+  // HSA ISA handle the target identity was queried from.
+  hsa_isa_t isa;
+  // NUL-terminated HSA ISA name storage.
+  char name_buffer[64 + /*NUL*/ 1];
+  // Borrowed view into |name_buffer| excluding the NUL terminator.
+  iree_string_view_t name;
+  // Parsed target identity borrowing processor text from |name_buffer|.
+  iree_hal_amdgpu_target_id_t target_id;
+} iree_hal_amdgpu_agent_isa_target_t;
+
 static hsa_status_t iree_hal_amdgpu_iterate_agent_isa(hsa_isa_t isa,
                                                       void* user_data) {
   iree_hal_amdgpu_agent_available_isas_t* isas =
@@ -47,29 +56,96 @@
   return HSA_STATUS_SUCCESS;
 }
 
-static iree_status_t iree_hal_amdgpu_verify_isas_equal(
-    const iree_hal_amdgpu_libhsa_t* libhsa, hsa_isa_t isa_a, hsa_isa_t isa_b) {
-  uint32_t name_length_a = 0;
+static iree_status_t iree_hal_amdgpu_query_agent_available_isas(
+    const iree_hal_amdgpu_libhsa_t* libhsa, hsa_agent_t device_agent,
+    iree_hal_amdgpu_agent_available_isas_t* out_available_isas) {
+  memset(out_available_isas, 0, sizeof(*out_available_isas));
+  return iree_hsa_agent_iterate_isas(IREE_LIBHSA(libhsa), device_agent,
+                                     iree_hal_amdgpu_iterate_agent_isa,
+                                     out_available_isas);
+}
+
+static iree_status_t iree_hal_amdgpu_query_agent_isa_target(
+    const iree_hal_amdgpu_libhsa_t* libhsa, hsa_isa_t isa,
+    iree_hal_amdgpu_agent_isa_target_t* out_isa_target) {
+  memset(out_isa_target, 0, sizeof(*out_isa_target));
+  out_isa_target->isa = isa;
+
+  uint32_t name_length = 0;
   IREE_RETURN_IF_ERROR(iree_hsa_isa_get_info_alt(
-      IREE_LIBHSA(libhsa), isa_a, HSA_ISA_INFO_NAME_LENGTH, &name_length_a));
-  uint32_t name_length_b = 0;
-  IREE_RETURN_IF_ERROR(iree_hsa_isa_get_info_alt(
-      IREE_LIBHSA(libhsa), isa_a, HSA_ISA_INFO_NAME_LENGTH, &name_length_b));
-  char name_a[64 + /*NUL*/ 1] = {0};
-  char name_b[64 + /*NUL*/ 1] = {0};
-  if (name_length_a > IREE_ARRAYSIZE(name_a) ||
-      name_length_b > IREE_ARRAYSIZE(name_b)) {
+      IREE_LIBHSA(libhsa), isa, HSA_ISA_INFO_NAME_LENGTH, &name_length));
+  if (name_length == 0 ||
+      name_length > IREE_ARRAYSIZE(out_isa_target->name_buffer)) {
     return iree_make_status(IREE_STATUS_RESOURCE_EXHAUSTED,
-                            "ISA name too long");
+                            "ISA name invalid (empty or too long: %u)",
+                            name_length);
   }
-  IREE_RETURN_IF_ERROR(iree_hsa_isa_get_info_alt(IREE_LIBHSA(libhsa), isa_a,
-                                                 HSA_ISA_INFO_NAME, &name_a));
-  IREE_RETURN_IF_ERROR(iree_hsa_isa_get_info_alt(IREE_LIBHSA(libhsa), isa_a,
-                                                 HSA_ISA_INFO_NAME, &name_b));
-  if (name_length_a != name_length_b ||
-      memcmp(name_a, name_b, name_length_a) != 0) {
+
+  IREE_RETURN_IF_ERROR(iree_hsa_isa_get_info_alt(IREE_LIBHSA(libhsa), isa,
+                                                 HSA_ISA_INFO_NAME,
+                                                 out_isa_target->name_buffer));
+  out_isa_target->name = iree_make_string_view(out_isa_target->name_buffer,
+                                               name_length - /*NUL*/ 1);
+  return iree_hal_amdgpu_target_id_parse_hsa_isa_name(
+      out_isa_target->name, &out_isa_target->target_id);
+}
+
+static iree_status_t iree_hal_amdgpu_verify_isas_equal(
+    const iree_hal_amdgpu_libhsa_t* libhsa, iree_host_size_t agent_a_ordinal,
+    hsa_isa_t isa_a, iree_host_size_t agent_b_ordinal, hsa_isa_t isa_b,
+    iree_host_size_t isa_ordinal) {
+  iree_hal_amdgpu_agent_isa_target_t target_a;
+  IREE_RETURN_IF_ERROR(
+      iree_hal_amdgpu_query_agent_isa_target(libhsa, isa_a, &target_a));
+  iree_hal_amdgpu_agent_isa_target_t target_b;
+  IREE_RETURN_IF_ERROR(
+      iree_hal_amdgpu_query_agent_isa_target(libhsa, isa_b, &target_b));
+
+  iree_hal_amdgpu_target_compatibility_t mismatch =
+      IREE_HAL_AMDGPU_TARGET_COMPATIBILITY_COMPATIBLE;
+  if (target_a.target_id.kind != target_b.target_id.kind ||
+      target_a.target_id.version.major != target_b.target_id.version.major ||
+      target_a.target_id.version.minor != target_b.target_id.version.minor ||
+      target_a.target_id.version.stepping !=
+          target_b.target_id.version.stepping) {
+    mismatch |= IREE_HAL_AMDGPU_TARGET_COMPATIBILITY_MISMATCH_PROCESSOR;
+  }
+  if ((target_a.target_id.kind == IREE_HAL_AMDGPU_TARGET_KIND_GENERIC ||
+       target_b.target_id.kind == IREE_HAL_AMDGPU_TARGET_KIND_GENERIC) &&
+      !iree_string_view_equal(target_a.target_id.processor,
+                              target_b.target_id.processor)) {
+    mismatch |= IREE_HAL_AMDGPU_TARGET_COMPATIBILITY_MISMATCH_GENERIC_FAMILY;
+  }
+  if (target_a.target_id.generic_version !=
+      target_b.target_id.generic_version) {
+    mismatch |= IREE_HAL_AMDGPU_TARGET_COMPATIBILITY_MISMATCH_GENERIC_VERSION;
+  }
+  if (target_a.target_id.sramecc != target_b.target_id.sramecc) {
+    mismatch |= IREE_HAL_AMDGPU_TARGET_COMPATIBILITY_MISMATCH_SRAMECC;
+  }
+  if (target_a.target_id.xnack != target_b.target_id.xnack) {
+    mismatch |= IREE_HAL_AMDGPU_TARGET_COMPATIBILITY_MISMATCH_XNACK;
+  }
+  if (mismatch != IREE_HAL_AMDGPU_TARGET_COMPATIBILITY_COMPATIBLE) {
+    char target_a_string[128] = {0};
+    IREE_RETURN_IF_ERROR(iree_hal_amdgpu_target_id_format(
+        &target_a.target_id, sizeof(target_a_string), target_a_string,
+        /*out_buffer_length=*/NULL));
+    char target_b_string[128] = {0};
+    IREE_RETURN_IF_ERROR(iree_hal_amdgpu_target_id_format(
+        &target_b.target_id, sizeof(target_b_string), target_b_string,
+        /*out_buffer_length=*/NULL));
+    char mismatch_reasons[128] = {0};
+    IREE_RETURN_IF_ERROR(iree_hal_amdgpu_target_compatibility_format(
+        mismatch, sizeof(mismatch_reasons), mismatch_reasons,
+        /*out_buffer_length=*/NULL));
     return iree_make_status(IREE_STATUS_FAILED_PRECONDITION,
-                            "ISAs do not match: `%s` != `%s`", name_a, name_b);
+                            "GPU agent[%" PRIhsz "] ISA[%" PRIhsz
+                            "] target `%s` does not match GPU agent[%" PRIhsz
+                            "] ISA[%" PRIhsz "] target `%s` (mismatched %s)",
+                            agent_a_ordinal, isa_ordinal, target_a_string,
+                            agent_b_ordinal, isa_ordinal, target_b_string,
+                            mismatch_reasons);
   }
   return iree_ok_status();
 }
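
The per-field mismatch reporting above decomposes into the target-id utilities, which callers can reuse directly. A hedged sketch answering a simple yes/no compatibility question (`check_format_against_isa` is an illustrative name; the parse flags mirror those used in the format-support query later in this file):

```c
// Hedged sketch reusing the target-id helpers; `check_format_against_isa` is
// an illustrative name, not part of this diff. Parses a user-facing format
// string (e.g. `gfx1100` or `amdgcn-amd-amdhsa--gfx1100`) and tests it
// against a target identity previously queried from an agent ISA.
static iree_status_t check_format_against_isa(
    iree_string_view_t format,
    const iree_hal_amdgpu_target_id_t* isa_target_id, bool* out_compatible) {
  *out_compatible = false;
  iree_hal_amdgpu_target_id_t format_target_id;
  IREE_RETURN_IF_ERROR(iree_hal_amdgpu_target_id_parse(
      format,
      IREE_HAL_AMDGPU_TARGET_ID_PARSE_FLAG_ALLOW_HSA_PREFIX |
          IREE_HAL_AMDGPU_TARGET_ID_PARSE_FLAG_ALLOW_ARCH_ONLY |
          IREE_HAL_AMDGPU_TARGET_ID_PARSE_FLAG_ALLOW_FEATURE_SUFFIXES,
      &format_target_id));
  *out_compatible =
      iree_hal_amdgpu_target_id_check_compatible(
          &format_target_id, isa_target_id) ==
      IREE_HAL_AMDGPU_TARGET_COMPATIBILITY_COMPATIBLE;
  return iree_ok_status();
}
```
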
@@ -85,21 +161,17 @@
   // Query all available ISAs supported by the first GPU agent.
   // We'll use this to compare with all other GPU agents.
   iree_hal_amdgpu_agent_available_isas_t expected_isas;
-  memset(&expected_isas, 0, sizeof(expected_isas));
   IREE_RETURN_AND_END_ZONE_IF_ERROR(
-      z0, iree_hsa_agent_iterate_isas(
-              IREE_LIBHSA(libhsa), topology->gpu_agents[0],
-              iree_hal_amdgpu_iterate_agent_isa, &expected_isas));
+      z0, iree_hal_amdgpu_query_agent_available_isas(
+              libhsa, topology->gpu_agents[0], &expected_isas));
 
   // For all subsequent GPU agents ensure their ISAs match.
   for (iree_host_size_t i = 1; i < topology->gpu_agent_count; ++i) {
     // Get ISAs supported by this agent.
     iree_hal_amdgpu_agent_available_isas_t available_isas;
-    memset(&available_isas, 0, sizeof(available_isas));
     IREE_RETURN_AND_END_ZONE_IF_ERROR(
-        z0, iree_hsa_agent_iterate_isas(
-                IREE_LIBHSA(libhsa), topology->gpu_agents[i],
-                iree_hal_amdgpu_iterate_agent_isa, &available_isas));
+        z0, iree_hal_amdgpu_query_agent_available_isas(
+                libhsa, topology->gpu_agents[i], &available_isas));
 
     // Ensure ISAs match.
     // We could be less strict here and require only one matching ISA that we
@@ -116,8 +188,10 @@
     }
     for (iree_host_size_t j = 0; j < expected_isas.count; ++j) {
       IREE_RETURN_AND_END_ZONE_IF_ERROR(
-          z0, iree_hal_amdgpu_verify_isas_equal(libhsa, expected_isas.values[j],
-                                                available_isas.values[j]));
+          z0, iree_hal_amdgpu_verify_isas_equal(
+                  libhsa, /*agent_a_ordinal=*/0, expected_isas.values[j],
+                  /*agent_b_ordinal=*/i, available_isas.values[j],
+                  /*isa_ordinal=*/j));
     }
   }
 
@@ -132,48 +206,34 @@
   *out_supported = false;
   if (out_isa) out_isa->handle = 0;
 
-  // Strip hsa-* prefix.
-  if (!iree_string_view_starts_with(
-          format, iree_make_cstring_view("amdgcn-amd-amdhsa-"))) {
-    // Not HSA-like.
-    *out_supported = false;
+  if (!iree_string_view_starts_with(format, IREE_SV("gfx")) &&
+      !iree_string_view_starts_with(format, IREE_SV("amdgcn-amd-amdhsa--"))) {
     return iree_ok_status();
   }
 
+  iree_hal_amdgpu_target_id_t format_target_id;
+  IREE_RETURN_IF_ERROR(iree_hal_amdgpu_target_id_parse(
+      format,
+      IREE_HAL_AMDGPU_TARGET_ID_PARSE_FLAG_ALLOW_HSA_PREFIX |
+          IREE_HAL_AMDGPU_TARGET_ID_PARSE_FLAG_ALLOW_ARCH_ONLY |
+          IREE_HAL_AMDGPU_TARGET_ID_PARSE_FLAG_ALLOW_FEATURE_SUFFIXES,
+      &format_target_id));
+
   // Query all available ISAs supported by any GPU agent.
   // This list is ordered by descending priority.
   iree_hal_amdgpu_agent_available_isas_t available_isas;
-  memset(&available_isas, 0, sizeof(available_isas));
-  IREE_RETURN_IF_ERROR(iree_hsa_agent_iterate_isas(
-      IREE_LIBHSA(libhsa), device_agent, iree_hal_amdgpu_iterate_agent_isa,
-      &available_isas));
+  IREE_RETURN_IF_ERROR(iree_hal_amdgpu_query_agent_available_isas(
+      libhsa, device_agent, &available_isas));
 
   for (iree_host_size_t i = 0; i < available_isas.count; ++i) {
-    // Get the ISA name - it'll be something like `amdgcn-amd-amdhsa--gfx1100`
-    // for some reason. Note that the docs for HSA_ISA_INFO_NAME_LENGTH say it
-    // doesn't include the NUL terminator but it definitely does - for our
-    // proper string view usage we have to subtract one. HSA sets length+1 to
-    // NUL so we must ensure we have sufficient space.
-    hsa_isa_t isa = available_isas.values[i];
-    char isa_name_buffer[64 + /*NUL*/ 1];
-    uint32_t isa_name_length = 0;
-    IREE_RETURN_IF_ERROR(iree_hsa_isa_get_info_alt(
-        IREE_LIBHSA(libhsa), isa, HSA_ISA_INFO_NAME_LENGTH, &isa_name_length));
-    if (isa_name_length == 0 ||
-        isa_name_length > IREE_ARRAYSIZE(isa_name_buffer)) {
-      return iree_make_status(IREE_STATUS_RESOURCE_EXHAUSTED,
-                              "ISA name invalid (empty or too long: %u)",
-                              isa_name_length);
-    }
-    IREE_RETURN_IF_ERROR(iree_hsa_isa_get_info_alt(
-        IREE_LIBHSA(libhsa), isa, HSA_ISA_INFO_NAME, isa_name_buffer));
-    iree_string_view_t isa_name =
-        iree_make_string_view(isa_name_buffer, isa_name_length - /*NUL*/ 1);
-
-    // Compare exactly.
-    if (iree_string_view_equal(format, isa_name)) {
+    iree_hal_amdgpu_agent_isa_target_t isa_target;
+    IREE_RETURN_IF_ERROR(iree_hal_amdgpu_query_agent_isa_target(
+        libhsa, available_isas.values[i], &isa_target));
+    if (iree_hal_amdgpu_target_id_check_compatible(&format_target_id,
+                                                   &isa_target.target_id) ==
+        IREE_HAL_AMDGPU_TARGET_COMPATIBILITY_COMPATIBLE) {
       *out_supported = true;
-      if (out_isa) *out_isa = isa;
+      if (out_isa) *out_isa = isa_target.isa;
       return iree_ok_status();
     }
   }
@@ -189,9 +249,9 @@
 //===----------------------------------------------------------------------===//
 
 typedef struct iree_hal_amdgpu_device_limits_t {
-  // HSA_ISA_INFO_WORKGROUP_MAX_SIZE
+  // Maximum total workgroup size from HSA_ISA_INFO_WORKGROUP_MAX_SIZE.
   uint32_t max_workgroup_size;
-  // HSA_ISA_INFO_WORKGROUP_MAX_DIM
+  // Maximum workgroup size per dimension from HSA_ISA_INFO_WORKGROUP_MAX_DIM.
   uint16_t max_workgroup_size_per_dim[3];
 } iree_hal_amdgpu_device_limits_t;
 static iree_status_t iree_hal_amdgpu_query_device_limits(
@@ -213,46 +273,129 @@
   return iree_ok_status();
 }
 
+static bool iree_hal_amdgpu_executable_data_is_wrapped_flatbuffer(
+    iree_const_byte_span_t executable_data) {
+  if (executable_data.data_length != 0 &&
+      executable_data.data_length < sizeof(uint32_t)) {
+    return false;
+  }
+  iree_const_byte_span_t identifier_data =
+      iree_make_const_byte_span(executable_data.data, sizeof(uint32_t));
+  if (iree_const_byte_span_is_empty(identifier_data)) {
+    return false;
+  }
+  return memcmp(identifier_data.data,
+                iree_hal_amdgpu_ExecutableDef_file_identifier,
+                identifier_data.data_length) == 0;
+}
+
+static iree_status_t iree_hal_amdgpu_executable_get_single_module_image(
+    iree_hal_amdgpu_ModuleDef_vec_t module_defs,
+    iree_const_byte_span_t* out_code_object_data) {
+  *out_code_object_data = iree_const_byte_span_empty();
+
+  // Today we require a single module. We could support multiple and link them
+  // together by loading code objects in the order specified. This could be
+  // useful if we ever made our own fat binaries or wanted to reuse shared ELFs
+  // across multiple executables by having them reference the same ranges in a
+  // larger file.
+  const iree_host_size_t module_count =
+      iree_hal_amdgpu_ModuleDef_vec_len(module_defs);
+  if (module_count != 1) {
+    return iree_make_status(IREE_STATUS_UNIMPLEMENTED,
+                            "only a single ModuleDef per ExecutableDef is "
+                            "supported; executable declares %" PRIhsz
+                            " modules",
+                            module_count);
+  }
+  iree_hal_amdgpu_ModuleDef_table_t module_def =
+      iree_hal_amdgpu_ModuleDef_vec_at(module_defs, 0);
+  if (!module_def) {
+    return iree_make_status(IREE_STATUS_INVALID_ARGUMENT, "module is NULL");
+  }
+  flatbuffers_string_t image = iree_hal_amdgpu_ModuleDef_image_get(module_def);
+  if (!image) {
+    return iree_make_status(IREE_STATUS_INVALID_ARGUMENT,
+                            "module image is empty");
+  }
+  const iree_host_size_t image_size = flatbuffers_string_len(image);
+  if (image_size == 0) {
+    return iree_make_status(IREE_STATUS_INVALID_ARGUMENT,
+                            "module image is empty");
+  }
+  *out_code_object_data =
+      iree_make_const_byte_span((const uint8_t*)image, image_size);
+  return iree_ok_status();
+}
+
+static iree_status_t iree_hal_amdgpu_executable_format_from_target_id(
+    const iree_hal_amdgpu_target_id_t* target_id,
+    iree_host_size_t executable_format_capacity, char* executable_format) {
+  return iree_hal_amdgpu_target_id_format(target_id, executable_format_capacity,
+                                          executable_format,
+                                          /*out_buffer_length=*/NULL);
+}
+
 iree_status_t iree_hal_amdgpu_executable_infer_format(
     iree_const_byte_span_t executable_data,
     iree_host_size_t executable_format_capacity, char* executable_format,
-    iree_host_size_t* out_inferred_size) {
+    iree_allocator_t host_allocator, iree_host_size_t* out_inferred_size) {
+  (void)host_allocator;
+  const bool is_wrapped_flatbuffer =
+      iree_hal_amdgpu_executable_data_is_wrapped_flatbuffer(executable_data);
+
   // Read the header prefix (with unsafe inference if size is unknown).
   const bool unsafe_infer_size = (executable_data.data_length == 0);
   iree_const_byte_span_t flatbuffer_data = iree_const_byte_span_empty();
-  IREE_RETURN_IF_ERROR(iree_hal_read_executable_flatbuffer_header(
-      executable_data, unsafe_infer_size,
-      iree_hal_amdgpu_ExecutableDef_file_identifier, &flatbuffer_data));
+  if (is_wrapped_flatbuffer) {
+    IREE_RETURN_IF_ERROR(iree_hal_read_executable_flatbuffer_header(
+        executable_data, unsafe_infer_size,
+        iree_hal_amdgpu_ExecutableDef_file_identifier, &flatbuffer_data));
 
-  // Verify the flatbuffer structure.
-  if (!iree_hal_amdgpu_ExecutableDef_verify_as_root(
-          flatbuffer_data.data, flatbuffer_data.data_length)) {
-    return iree_make_status(IREE_STATUS_INVALID_ARGUMENT,
-                            "failed to verify executable flatbuffer structure");
+    // Verify the flatbuffer structure.
+    const int verify_ret = iree_hal_amdgpu_ExecutableDef_verify_as_root(
+        flatbuffer_data.data, flatbuffer_data.data_length);
+    if (verify_ret != flatcc_verify_ok) {
+      return iree_make_status(IREE_STATUS_INVALID_ARGUMENT,
+                              "flatbuffer verification failed: %s",
+                              flatcc_verify_error_string(verify_ret));
+    }
+
+    iree_hal_amdgpu_ExecutableDef_table_t executable_def =
+        iree_hal_amdgpu_ExecutableDef_as_root(flatbuffer_data.data);
+    iree_const_byte_span_t code_object_data = iree_const_byte_span_empty();
+    IREE_RETURN_IF_ERROR(iree_hal_amdgpu_executable_get_single_module_image(
+        iree_hal_amdgpu_ExecutableDef_modules_get(executable_def),
+        &code_object_data));
+
+    iree_hal_amdgpu_target_id_t code_object_target_id;
+    IREE_RETURN_IF_ERROR(iree_hal_amdgpu_code_object_target_id_from_elf(
+        code_object_data, &code_object_target_id));
+    IREE_RETURN_IF_ERROR(iree_hal_amdgpu_executable_format_from_target_id(
+        &code_object_target_id, executable_format_capacity, executable_format));
+
+    // Return the total size (header + flatbuffer).
+    *out_inferred_size =
+        sizeof(iree_flatbuffer_file_header_t) + flatbuffer_data.data_length;
+    return iree_ok_status();
+  } else {
+    iree_const_byte_span_t hsaco_data = executable_data;
+    if (unsafe_infer_size) {
+      iree_host_size_t hsaco_size = 0;
+      IREE_RETURN_IF_ERROR(iree_hal_elf_calculate_size(hsaco_data, &hsaco_size),
+                           "calculating raw HSACO ELF size");
+      hsaco_data = iree_make_const_byte_span(executable_data.data, hsaco_size);
+    }
+
+    iree_hal_amdgpu_target_id_t code_object_target_id;
+    IREE_RETURN_IF_ERROR(iree_hal_amdgpu_code_object_target_id_from_elf(
+        hsaco_data, &code_object_target_id));
+    IREE_RETURN_IF_ERROR(iree_hal_amdgpu_executable_format_from_target_id(
+        &code_object_target_id, executable_format_capacity, executable_format));
+
+    *out_inferred_size = hsaco_data.data_length;
+    return iree_ok_status();
   }
-
-  // Get the ISA name from the flatbuffer.
-  iree_hal_amdgpu_ExecutableDef_table_t executable_def =
-      iree_hal_amdgpu_ExecutableDef_as_root(flatbuffer_data.data);
-  flatbuffers_string_t isa =
-      iree_hal_amdgpu_ExecutableDef_isa_get(executable_def);
-  if (!isa) {
-    return iree_make_status(IREE_STATUS_INVALID_ARGUMENT,
-                            "executable missing target_arch");
-  }
-
-  // Write the format string (ISA name).
-  const iree_host_size_t isa_length = flatbuffers_string_len(isa);
-  if (isa_length >= executable_format_capacity) {
-    return iree_make_status(IREE_STATUS_OUT_OF_RANGE,
-                            "executable format buffer too small");
-  }
-  memcpy(executable_format, isa, isa_length + /*NUL*/ 1);
-
-  // Return the total size (header + flatbuffer).
-  *out_inferred_size =
-      sizeof(iree_flatbuffer_file_header_t) + flatbuffer_data.data_length;
-  return iree_ok_status();
 }
 
 // Verifies the structure of the flatbuffer.
@@ -327,13 +470,13 @@
             limits->max_workgroup_size_per_dim[1],
             limits->max_workgroup_size_per_dim[2]);
       }
-      const uint32_t total_workgroup_size =
-          workgroup_size->x * workgroup_size->y * workgroup_size->z;
+      const uint64_t total_workgroup_size =
+          (uint64_t)workgroup_size->x * workgroup_size->y * workgroup_size->z;
       if (total_workgroup_size > limits->max_workgroup_size) {
         return iree_make_status(
             IREE_STATUS_INVALID_ARGUMENT,
-            "exports[%" PRIhsz
-            "] workgroup size total %u exceeds device maximum %u",
+            "exports[%" PRIhsz "] workgroup size total %" PRIu64
+            " exceeds device maximum %u",
             i, total_workgroup_size, limits->max_workgroup_size);
       }
     } else {
@@ -373,31 +516,128 @@
 // Executable Loading
 //===----------------------------------------------------------------------===//
 
-// Loads an executable ELF from memory for all agents in |topology| and stores
-// the frozen executable in |out_handle|.
-static iree_status_t iree_hal_amdgpu_executable_load_modules(
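+// Returns true if |physical_device_ordinal| is present in the
+// |physical_device_mask| bitmask.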
+static bool iree_hal_amdgpu_physical_device_mask_contains(
+    uint64_t physical_device_mask, iree_host_size_t physical_device_ordinal) {
+  return physical_device_ordinal < IREE_HAL_MAX_QUEUES &&
+         iree_all_bits_set(physical_device_mask,
+                           ((uint64_t)1) << physical_device_ordinal);
+}
+
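+// Resolves |requested_affinity| against the queue-affinity domain derived
+// from |topology| and returns the set of physical GPU devices the executable
+// must be loaded on.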
+static iree_status_t iree_hal_amdgpu_executable_select_physical_devices(
+    const iree_hal_amdgpu_topology_t* topology,
+    iree_hal_queue_affinity_t requested_affinity,
+    iree_hal_amdgpu_queue_affinity_physical_device_set_t*
+        out_physical_devices) {
+  memset(out_physical_devices, 0, sizeof(*out_physical_devices));
+
+  iree_hal_amdgpu_queue_affinity_domain_t queue_affinity_domain = {
+      .supported_affinity = 0,
+      .physical_device_count = topology->gpu_agent_count,
+      .queue_count_per_physical_device = topology->gpu_agent_queue_count,
+  };
+
+  for (iree_host_size_t physical_device_ordinal = 0;
+       physical_device_ordinal < topology->gpu_agent_count;
+       ++physical_device_ordinal) {
+    iree_hal_queue_affinity_t physical_device_affinity = 0;
+    IREE_RETURN_IF_ERROR(iree_hal_amdgpu_queue_affinity_for_physical_device(
+        queue_affinity_domain, physical_device_ordinal,
+        &physical_device_affinity));
+    iree_hal_queue_affinity_or_into(queue_affinity_domain.supported_affinity,
+                                    physical_device_affinity);
+  }
+
+  return iree_hal_amdgpu_queue_affinity_select_physical_devices(
+      queue_affinity_domain, requested_affinity, out_physical_devices);
+}
+
+static iree_status_t iree_hal_amdgpu_executable_format_target_id_for_message(
+    const iree_hal_amdgpu_target_id_t* target_id,
+    iree_host_size_t buffer_capacity, char* buffer) {
+  return iree_hal_amdgpu_target_id_format(target_id, buffer_capacity, buffer,
+                                          /*out_buffer_length=*/NULL);
+}
+
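+// Checks |code_object_target_id| against every ISA reported by |device_agent|
+// and fails with INCOMPATIBLE (including the first mismatch reason) if none
+// are compatible.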
+static iree_status_t iree_hal_amdgpu_executable_preflight_agent_code_object(
+    const iree_hal_amdgpu_libhsa_t* libhsa, hsa_agent_t device_agent,
+    iree_host_size_t physical_device_ordinal,
+    const iree_hal_amdgpu_target_id_t* code_object_target_id) {
+  iree_hal_amdgpu_agent_available_isas_t available_isas;
+  IREE_RETURN_IF_ERROR(iree_hal_amdgpu_query_agent_available_isas(
+      libhsa, device_agent, &available_isas));
+  if (available_isas.count == 0) {
+    return iree_make_status(IREE_STATUS_INCOMPATIBLE,
+                            "GPU agent[%" PRIhsz
+                            "] reports no AMDGPU ISA targets",
+                            physical_device_ordinal);
+  }
+
+  iree_hal_amdgpu_target_compatibility_t first_mismatch =
+      IREE_HAL_AMDGPU_TARGET_COMPATIBILITY_COMPATIBLE;
+  char first_agent_target[128] = {0};
+  for (iree_host_size_t i = 0; i < available_isas.count; ++i) {
+    iree_hal_amdgpu_agent_isa_target_t isa_target;
+    IREE_RETURN_IF_ERROR(iree_hal_amdgpu_query_agent_isa_target(
+        libhsa, available_isas.values[i], &isa_target));
+    const iree_hal_amdgpu_target_compatibility_t compatibility =
+        iree_hal_amdgpu_target_id_check_compatible(code_object_target_id,
+                                                   &isa_target.target_id);
+    if (compatibility == IREE_HAL_AMDGPU_TARGET_COMPATIBILITY_COMPATIBLE) {
+      return iree_ok_status();
+    }
+    if (first_mismatch == IREE_HAL_AMDGPU_TARGET_COMPATIBILITY_COMPATIBLE) {
+      first_mismatch = compatibility;
+      IREE_RETURN_IF_ERROR(
+          iree_hal_amdgpu_executable_format_target_id_for_message(
+              &isa_target.target_id, sizeof(first_agent_target),
+              first_agent_target));
+    }
+  }
+
+  char code_object_target[128] = {0};
+  IREE_RETURN_IF_ERROR(iree_hal_amdgpu_executable_format_target_id_for_message(
+      code_object_target_id, sizeof(code_object_target), code_object_target));
+  char mismatch_reasons[128] = {0};
+  IREE_RETURN_IF_ERROR(iree_hal_amdgpu_target_compatibility_format(
+      first_mismatch, sizeof(mismatch_reasons), mismatch_reasons,
+      /*out_buffer_length=*/NULL));
+  return iree_make_status(
+      IREE_STATUS_INCOMPATIBLE,
+      "AMDGPU code object target `%s` is not compatible with GPU agent[%" PRIhsz
+      "] target `%s` (mismatched %s)",
+      code_object_target, physical_device_ordinal, first_agent_target,
+      mismatch_reasons);
+}
+
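+// Preflights |code_object_target_id| against all physical devices selected in
+// |physical_devices| so incompatible targets fail before the HSA load.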
+static iree_status_t iree_hal_amdgpu_executable_preflight_code_object(
     const iree_hal_amdgpu_libhsa_t* libhsa,
     const iree_hal_amdgpu_topology_t* topology,
+    const iree_hal_amdgpu_queue_affinity_physical_device_set_t*
+        physical_devices,
+    const iree_hal_amdgpu_target_id_t* code_object_target_id) {
+  for (iree_host_size_t i = 0; i < topology->gpu_agent_count; ++i) {
+    if (!iree_hal_amdgpu_physical_device_mask_contains(
+            physical_devices->physical_device_mask, i)) {
+      continue;
+    }
+    IREE_RETURN_IF_ERROR(iree_hal_amdgpu_executable_preflight_agent_code_object(
+        libhsa, topology->gpu_agents[i], i, code_object_target_id));
+  }
+  return iree_ok_status();
+}
+
+// Loads an executable ELF from memory for selected agents in |topology| and
+// stores the frozen executable in |out_handle|.
+static iree_status_t iree_hal_amdgpu_executable_load_module(
+    const iree_hal_amdgpu_libhsa_t* libhsa,
+    const iree_hal_amdgpu_topology_t* topology,
+    const iree_hal_amdgpu_queue_affinity_physical_device_set_t*
+        physical_devices,
     const iree_hal_executable_params_t* executable_params,
-    iree_hal_amdgpu_ModuleDef_vec_t module_defs, hsa_executable_t* out_handle) {
+    iree_const_byte_span_t code_object_data, hsa_executable_t* out_handle) {
   IREE_TRACE_ZONE_BEGIN(z0);
   *out_handle = (hsa_executable_t){0};
 
-  // Today we require a single module.
-  // We could support multiple and link them together by loading their code
-  // objects in the order specified. This could be useful if we ever made our
-  // own fat binaries or wanted to reuse shared ELFs across multiple executables
-  // by having them reference the same ranges in a larger file.
-  if (iree_hal_amdgpu_ModuleDef_vec_len(module_defs) != 1) {
-    IREE_RETURN_AND_END_ZONE_IF_ERROR(
-        z0, iree_make_status(IREE_STATUS_UNIMPLEMENTED,
-                             "only a single ModuleDef per ExecutableDef is "
-                             "supported; executable declares %zu modules",
-                             iree_hal_amdgpu_ModuleDef_vec_len(module_defs)));
-  }
-  flatbuffers_string_t image = iree_hal_amdgpu_ModuleDef_image_get(
-      iree_hal_amdgpu_ModuleDef_vec_at(module_defs, 0));
-
   // TODO(#18877): support executable constants in HSA executables.
   // We currently don't support executable constants but we could by way of
   // global symbols. We should be using externs for the constants and then
@@ -415,12 +655,20 @@
   // lacking. These may have only been used for HSAIL anyway.
   const char* options = NULL;
 
+  iree_hal_amdgpu_target_id_t code_object_target_id;
+  IREE_RETURN_AND_END_ZONE_IF_ERROR(
+      z0, iree_hal_amdgpu_code_object_target_id_from_elf(
+              code_object_data, &code_object_target_id));
+  IREE_RETURN_AND_END_ZONE_IF_ERROR(
+      z0, iree_hal_amdgpu_executable_preflight_code_object(
+              libhsa, topology, physical_devices, &code_object_target_id));
+
   // Bind a code object reader to the memory sourced from our rodata.
   hsa_code_object_reader_t code_object_reader;
   IREE_RETURN_AND_END_ZONE_IF_ERROR(
       z0, iree_hsa_code_object_reader_create_from_memory(
-              IREE_LIBHSA(libhsa), image, flatbuffers_string_len(image),
-              &code_object_reader));
+              IREE_LIBHSA(libhsa), (const char*)code_object_data.data,
+              code_object_data.data_length, &code_object_reader));
 
   // Create the executable that will hold all of the loaded code objects.
   // TODO(benvanik): pass profile/rounding mode from queried info.
@@ -429,9 +677,13 @@
       IREE_LIBHSA(libhsa), HSA_PROFILE_FULL,
       HSA_DEFAULT_FLOAT_ROUNDING_MODE_DEFAULT, options, &handle);
 
-  // Load the code object for each agent.
+  // Load the code object for each selected agent.
   for (iree_host_size_t i = 0;
        iree_status_is_ok(status) && i < topology->gpu_agent_count; ++i) {
+    if (!iree_hal_amdgpu_physical_device_mask_contains(
+            physical_devices->physical_device_mask, i)) {
+      continue;
+    }
     status = iree_hsa_executable_load_agent_code_object(
         IREE_LIBHSA(libhsa), handle, topology->gpu_agents[i],
         code_object_reader, options, NULL);
@@ -444,54 +696,160 @@
   }
 
   // Release the reader now that the executable has been fully loaded.
-  IREE_IGNORE_ERROR(iree_hsa_code_object_reader_destroy(IREE_LIBHSA(libhsa),
-                                                        code_object_reader));
+  status =
+      iree_status_join(status, iree_hsa_code_object_reader_destroy(
+                                   IREE_LIBHSA(libhsa), code_object_reader));
 
   if (iree_status_is_ok(status)) {
     *out_handle = handle;
   } else if (handle.handle) {
-    IREE_IGNORE_ERROR(iree_hsa_executable_destroy(IREE_LIBHSA(libhsa), handle));
+    status = iree_status_join(
+        status, iree_hsa_executable_destroy(IREE_LIBHSA(libhsa), handle));
   }
   IREE_TRACE_ZONE_END(z0);
   return status;
 }
 
+typedef struct iree_hal_amdgpu_executable_find_loaded_code_object_state_t {
+  // Borrowed HSA API table used for loader extension queries.
+  const iree_hal_amdgpu_libhsa_t* libhsa;
+  // HSA agent whose loaded code object is being searched.
+  hsa_agent_t agent;
+  // Loaded code object matching |agent| when found.
+  hsa_loaded_code_object_t loaded_code_object;
+} iree_hal_amdgpu_executable_find_loaded_code_object_state_t;
+
+static hsa_status_t iree_hal_amdgpu_executable_iterate_loaded_code_object(
+    hsa_executable_t executable, hsa_loaded_code_object_t loaded_code_object,
+    void* user_data) {
+  iree_hal_amdgpu_executable_find_loaded_code_object_state_t* find_state =
+      (iree_hal_amdgpu_executable_find_loaded_code_object_state_t*)user_data;
+  hsa_agent_t agent = {0};
+  hsa_status_t hsa_status =
+      find_state->libhsa->amd_loader
+          .hsa_ven_amd_loader_loaded_code_object_get_info(
+              loaded_code_object,
+              HSA_VEN_AMD_LOADER_LOADED_CODE_OBJECT_INFO_AGENT, &agent);
+  if (hsa_status != HSA_STATUS_SUCCESS) return hsa_status;
+  if (agent.handle == find_state->agent.handle) {
+    find_state->loaded_code_object = loaded_code_object;
+    return HSA_STATUS_INFO_BREAK;
+  }
+  return HSA_STATUS_SUCCESS;
+}
+
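+// Finds the code object in |executable| loaded for |agent|, failing with
+// NOT_FOUND if the agent has none.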
+static iree_status_t iree_hal_amdgpu_executable_find_loaded_code_object(
+    const iree_hal_amdgpu_libhsa_t* libhsa, hsa_executable_t executable,
+    hsa_agent_t agent, hsa_loaded_code_object_t* out_loaded_code_object) {
+  *out_loaded_code_object = (hsa_loaded_code_object_t){0};
+  iree_hal_amdgpu_executable_find_loaded_code_object_state_t find_state = {
+      .libhsa = libhsa,
+      .agent = agent,
+      .loaded_code_object = {0},
+  };
+  hsa_status_t hsa_status =
+      libhsa->amd_loader
+          .hsa_ven_amd_loader_executable_iterate_loaded_code_objects(
+              executable, iree_hal_amdgpu_executable_iterate_loaded_code_object,
+              &find_state);
+  if (hsa_status == HSA_STATUS_SUCCESS) {
+    return iree_make_status(IREE_STATUS_NOT_FOUND,
+                            "no loaded code object found for agent");
+  }
+  if (hsa_status != HSA_STATUS_INFO_BREAK) {
+    return iree_status_from_hsa_status(
+        __FILE__, __LINE__, hsa_status,
+        "hsa_ven_amd_loader_executable_iterate_loaded_code_objects",
+        "iterating loaded executable code objects");
+  }
+  *out_loaded_code_object = find_state.loaded_code_object;
+  return iree_ok_status();
+}
+
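+// Queries the loader for the load delta and size of the code object loaded
+// onto |device_agent| and records them in |out_load_info| for profiling.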
+static iree_status_t
+iree_hal_amdgpu_executable_populate_profile_code_object_load_info(
+    const iree_hal_amdgpu_libhsa_t* libhsa, hsa_executable_t executable,
+    uint32_t physical_device_ordinal, hsa_agent_t device_agent,
+    iree_hal_amdgpu_profile_code_object_load_info_t* out_load_info) {
+  memset(out_load_info, 0, sizeof(*out_load_info));
+  out_load_info->physical_device_ordinal = physical_device_ordinal;
+
+  hsa_loaded_code_object_t loaded_code_object = {0};
+  IREE_RETURN_IF_ERROR(iree_hal_amdgpu_executable_find_loaded_code_object(
+      libhsa, executable, device_agent, &loaded_code_object));
+
+  hsa_status_t hsa_status =
+      libhsa->amd_loader.hsa_ven_amd_loader_loaded_code_object_get_info(
+          loaded_code_object,
+          HSA_VEN_AMD_LOADER_LOADED_CODE_OBJECT_INFO_LOAD_DELTA,
+          &out_load_info->load_delta);
+  if (hsa_status == HSA_STATUS_SUCCESS) {
+    hsa_status =
+        libhsa->amd_loader.hsa_ven_amd_loader_loaded_code_object_get_info(
+            loaded_code_object,
+            HSA_VEN_AMD_LOADER_LOADED_CODE_OBJECT_INFO_LOAD_SIZE,
+            &out_load_info->load_size);
+  }
+  return iree_status_from_hsa_status(
+      __FILE__, __LINE__, hsa_status,
+      "hsa_ven_amd_loader_loaded_code_object_get_info",
+      "querying loaded executable code-object profile metadata");
+}
+
+#define IREE_HAL_AMDGPU_MAX_STACK_SYMBOL_NAME_LENGTH \
+  ((iree_host_size_t)(4 * 1024))
+
+static iree_status_t iree_hal_amdgpu_executable_get_symbol_by_cstring(
+    const iree_hal_amdgpu_libhsa_t* libhsa, hsa_executable_t executable,
+    const char* symbol_name, hsa_agent_t device_agent,
+    hsa_executable_symbol_t* out_symbol) {
+  // NOTE: AMDGPU kernel symbols must include the `.kd` suffix.
+  return iree_hsa_executable_get_symbol_by_name(
+      IREE_LIBHSA(libhsa), executable, symbol_name, &device_agent, out_symbol);
+}
+
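+// Resolves a length-delimited |symbol_name| by copying it into a bounded
+// NUL-terminated stack buffer for the C-string HSA symbol API.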
+static iree_status_t iree_hal_amdgpu_executable_get_raw_hsaco_symbol_by_name(
+    const iree_hal_amdgpu_libhsa_t* libhsa, hsa_executable_t executable,
+    iree_string_view_t symbol_name, hsa_agent_t device_agent,
+    hsa_executable_symbol_t* out_symbol) {
+  if (iree_string_view_is_empty(symbol_name)) {
+    return iree_make_status(IREE_STATUS_INVALID_ARGUMENT,
+                            "executable kernel symbol name is empty");
+  }
+  if (symbol_name.size > IREE_HAL_AMDGPU_MAX_STACK_SYMBOL_NAME_LENGTH) {
+    return iree_make_status(
+        IREE_STATUS_OUT_OF_RANGE,
+        "executable kernel symbol name `%.*s` exceeds maximum length %" PRIhsz,
+        (int)symbol_name.size, symbol_name.data,
+        IREE_HAL_AMDGPU_MAX_STACK_SYMBOL_NAME_LENGTH);
+  }
+
+  // AMDGPU code-object MessagePack strings are length-delimited and not
+  // NUL-terminated while the HSA symbol lookup API expects C strings. Copy
+  // into a NUL-terminated stack buffer only at the API boundary so ROCR can
+  // use its internal symbol map.
+  char* symbol_name_storage = (char*)iree_alloca(symbol_name.size + 1);
+  memcpy(symbol_name_storage, symbol_name.data, symbol_name.size);
+  symbol_name_storage[symbol_name.size] = 0;
+  return iree_hal_amdgpu_executable_get_symbol_by_cstring(
+      libhsa, executable, symbol_name_storage, device_agent, out_symbol);
+}
+
 // Resolves the uniform kernel arguments that are the same on all GPU device
 // agents in the topology (since we assume all are the same device type).
 // All fields besides `kernel_object` will have valid values.
-static iree_status_t iree_hal_amdgpu_executable_resolve_kernel_args(
-    const iree_hal_amdgpu_libhsa_t* libhsa, hsa_executable_t executable,
-    iree_hal_amdgpu_ExportDef_table_t export_def,
-    const iree_hal_amdgpu_trace_src_loc_t* export_loc,
-    hsa_agent_t any_device_agent,
+static iree_status_t iree_hal_amdgpu_executable_resolve_kernel_args_from_symbol(
+    const iree_hal_amdgpu_libhsa_t* libhsa, hsa_executable_symbol_t symbol,
+    const uint32_t workgroup_size[3], uint16_t constant_count,
+    uint16_t binding_count,
     iree_hal_amdgpu_device_kernel_args_t* out_kernel_args) {
   IREE_ASSERT_ARGUMENT(out_kernel_args);
   IREE_TRACE_ZONE_BEGIN(z0);
 
-  const char* symbol_name =
-      iree_hal_amdgpu_ExportDef_symbol_name_get(export_def);
-  IREE_TRACE_ZONE_APPEND_TEXT(z0, symbol_name);
-
-  // Lookup the symbol on any device. All devices today must be the same so the
-  // parameters will match (except the kernel_object pointer).
-  //
-  // NOTE: must include `.kd` suffix.
-  hsa_executable_symbol_t symbol = {0};
-  IREE_RETURN_AND_END_ZONE_IF_ERROR(
-      z0, iree_hsa_executable_get_symbol_by_name(IREE_LIBHSA(libhsa),
-                                                 executable, symbol_name,
-                                                 &any_device_agent, &symbol));
-
   // All of our kernels assume 3 dimensions.
   out_kernel_args->setup = 3 << HSA_KERNEL_DISPATCH_PACKET_SETUP_DIMENSIONS;
 
-  // TODO(benvanik): embed this as a custom section or attributes that we could
-  // somehow query? For now we need the flatbuffer.
-  const iree_hal_amdgpu_Dims_struct_t workgroup_size =
-      iree_hal_amdgpu_ExportDef_workgroup_size_get(export_def);
-  out_kernel_args->workgroup_size[0] = workgroup_size->x;
-  out_kernel_args->workgroup_size[1] = workgroup_size->y;
-  out_kernel_args->workgroup_size[2] = workgroup_size->z;
+  out_kernel_args->workgroup_size[0] = workgroup_size[0];
+  out_kernel_args->workgroup_size[1] = workgroup_size[1];
+  out_kernel_args->workgroup_size[2] = workgroup_size[2];
 
   // NOTE: the object pointer is per-device and we populate that when uploading
   // device tables.
@@ -520,109 +878,25 @@
               HSA_EXECUTABLE_SYMBOL_INFO_KERNEL_KERNARG_SEGMENT_ALIGNMENT,
               &out_kernel_args->kernarg_alignment));
 
-  iree_hal_amdgpu_BindingBits_vec_t binding_bits =
-      iree_hal_amdgpu_ExportDef_binding_flags_get(export_def);
-  out_kernel_args->binding_count =
-      (uint16_t)iree_hal_amdgpu_BindingBits_vec_len(binding_bits);
-  out_kernel_args->constant_count =
-      (uint16_t)iree_hal_amdgpu_ExportDef_constant_count_get(export_def);
-
-  // Interned debugging info for the lifetime of the process. This is required
-  // so tracing tools can access the values while flushing when the process
-  // exits. If no debugging info was available or it's not enabled in the build
-  // this will be 0/NULL.
-  out_kernel_args->trace_src_loc = (uint64_t)export_loc;
+  out_kernel_args->binding_count = binding_count;
+  out_kernel_args->constant_count = constant_count;
 
   IREE_TRACE_ZONE_END(z0);
   return iree_ok_status();
 }
 
-// Allocates (and leaks) a table of source locations for each of |export_defs|.
-// The returned table matches 1:1 and will persist for the lifetime of the
-// process.
-static iree_status_t iree_hal_amdgpu_executable_intern_trace_locs(
-    iree_hal_amdgpu_ExportDef_vec_t export_defs,
-    iree_allocator_t host_allocator,
-    iree_hal_amdgpu_trace_src_loc_t** out_export_locs) {
-  IREE_TRACE_ZONE_BEGIN(z0);
-
-  const iree_host_size_t export_count =
-      iree_hal_amdgpu_ExportDef_vec_len(export_defs);
-  IREE_TRACE_ZONE_APPEND_VALUE_I64(z0, export_count);
-
-  // Sum up the total storage required for all information.
-  iree_host_size_t total_size =
-      export_count * sizeof(iree_hal_amdgpu_trace_src_loc_t);
-  for (iree_host_size_t i = 0; i < export_count; ++i) {
-    iree_hal_amdgpu_ExportDef_table_t export_def =
-        iree_hal_amdgpu_ExportDef_vec_at(export_defs, i);
-    iree_hal_debug_ExportDef_table_t debug_def =
-        iree_hal_amdgpu_ExportDef_debug_info_get(export_def);
-    total_size +=
-        flatbuffers_string_len(iree_hal_debug_ExportDef_name_get(debug_def)) +
-        1;
-    iree_hal_debug_FileLineLocDef_table_t loc_def =
-        iree_hal_debug_ExportDef_location_get(debug_def);
-    if (loc_def) {
-      total_size += flatbuffers_string_len(
-                        iree_hal_debug_FileLineLocDef_filename_get(loc_def)) +
-                    1;
-    }
-  }
-
-  // Allocate persistent storage.
-  iree_hal_amdgpu_trace_src_loc_t* export_locs = NULL;
-  IREE_LEAK_CHECK_DISABLE_PUSH();
-  export_locs = (iree_hal_amdgpu_trace_src_loc_t*)malloc(total_size);
-  IREE_LEAK_CHECK_DISABLE_POP();
-  char* char_buffer = (char*)&export_locs[export_count];
-
-  // Populate table and fill the buffer. The only pointers used are those
-  // pointing into the persistent allocation.
-  for (iree_host_size_t i = 0; i < export_count; ++i) {
-    iree_hal_amdgpu_trace_src_loc_t* export_loc = &export_locs[i];
-    export_loc->name = NULL;  // not needed
-
-    iree_hal_amdgpu_ExportDef_table_t export_def =
-        iree_hal_amdgpu_ExportDef_vec_at(export_defs, i);
-    iree_hal_debug_ExportDef_table_t debug_def =
-        iree_hal_amdgpu_ExportDef_debug_info_get(export_def);
-
-    flatbuffers_string_t function =
-        iree_hal_debug_ExportDef_name_get(debug_def);
-    iree_host_size_t function_len = flatbuffers_string_len(function);
-    memcpy(char_buffer, function, function_len);
-    export_loc->function = char_buffer;
-    char_buffer += function_len + 1;
-
-    iree_hal_debug_FileLineLocDef_table_t loc_def =
-        iree_hal_debug_ExportDef_location_get(debug_def);
-    if (loc_def) {
-      flatbuffers_string_t file =
-          iree_hal_debug_FileLineLocDef_filename_get(loc_def);
-      iree_host_size_t file_len = flatbuffers_string_len(file);
-      memcpy(char_buffer, file, file_len);
-      export_loc->file = char_buffer;
-      char_buffer += file_len + 1;
-      export_loc->line = iree_hal_debug_FileLineLocDef_line_get(loc_def);
-    }
-
-    // We could do something clever here to ensure consistent colors, like
-    // hashing based on name.
-    export_loc->color = 0;
-  }
-
-  *out_export_locs = export_locs;
-  IREE_TRACE_ZONE_END(z0);
-  return iree_ok_status();
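+// Queries the agent-specific kernel object handle for |symbol|.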
+static iree_status_t iree_hal_amdgpu_executable_resolve_kernel_object(
+    const iree_hal_amdgpu_libhsa_t* libhsa, hsa_executable_symbol_t symbol,
+    uint64_t* out_kernel_object) {
+  return iree_hsa_executable_symbol_get_info(
+      IREE_LIBHSA(libhsa), symbol, HSA_EXECUTABLE_SYMBOL_INFO_KERNEL_OBJECT,
+      out_kernel_object);
 }
 
 // Uploads the provided kernel table to |device_agent| and returns the pointer.
-// |host_kernel_args| will have its `kernel_object` fields mutated during the
-// upload.
-static iree_status_t iree_hal_amdgpu_executable_upload_kernel_table(
-    const iree_hal_amdgpu_libhsa_t* libhsa, hsa_executable_t executable,
-    iree_hal_amdgpu_ExportDef_vec_t export_defs, iree_host_size_t kernel_count,
+// |host_kernel_args| must already have device-specific `kernel_object` fields.
+static iree_status_t iree_hal_amdgpu_executable_upload_resolved_kernel_table(
+    const iree_hal_amdgpu_libhsa_t* libhsa, iree_host_size_t kernel_count,
     iree_hal_amdgpu_device_kernel_args_t* host_kernel_args,
     hsa_agent_t device_agent,
     IREE_AMDGPU_DEVICE_PTR const iree_hal_amdgpu_device_kernel_args_t**
@@ -630,29 +904,9 @@
   IREE_TRACE_ZONE_BEGIN(z0);
   *out_device_kernel_args = NULL;
 
-  // Upload copies of kernel arguments for each device.
-  // We reuse the host storage we already allocated to make it possible to
-  // memcpy the entire table in one go from host memory.
-  // Resolve all kernel object pointers for the device agent.
-  for (iree_host_size_t kernel_ordinal = 0; kernel_ordinal < kernel_count;
-       ++kernel_ordinal) {
-    iree_hal_amdgpu_ExportDef_table_t export_def =
-        iree_hal_amdgpu_ExportDef_vec_at(export_defs, kernel_ordinal);
-    const char* symbol_name =
-        iree_hal_amdgpu_ExportDef_symbol_name_get(export_def);
-
-    // NOTE: must include `.kd` suffix.
-    hsa_executable_symbol_t symbol = {0};
-    IREE_RETURN_AND_END_ZONE_IF_ERROR(z0,
-                                      iree_hsa_executable_get_symbol_by_name(
-                                          IREE_LIBHSA(libhsa), executable,
-                                          symbol_name, &device_agent, &symbol),
-                                      "resolving `%s`", symbol_name);
-    IREE_RETURN_AND_END_ZONE_IF_ERROR(
-        z0, iree_hsa_executable_symbol_get_info(
-                IREE_LIBHSA(libhsa), symbol,
-                HSA_EXECUTABLE_SYMBOL_INFO_KERNEL_OBJECT,
-                &host_kernel_args[kernel_ordinal].kernel_object));
+  if (kernel_count == 0) {
+    IREE_TRACE_ZONE_END(z0);
+    return iree_ok_status();
   }
 
   // Find a memory pool on the agent where we can upload the table.
@@ -680,19 +934,154 @@
   if (iree_status_is_ok(status)) {
     *out_device_kernel_args = device_kernel_args;
   } else {
-    IREE_IGNORE_ERROR(
+    status = iree_status_join(
+        status,
         iree_hsa_amd_memory_pool_free(IREE_LIBHSA(libhsa), device_kernel_args));
   }
   IREE_TRACE_ZONE_END(z0);
   return status;
 }
 
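+// Resolves per-agent kernel objects for each flatbuffer ExportDef into
+// |host_kernel_args| and uploads the completed table to |device_agent|.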
+static iree_status_t iree_hal_amdgpu_executable_upload_flatbuffer_kernel_table(
+    const iree_hal_amdgpu_libhsa_t* libhsa, hsa_executable_t executable,
+    iree_hal_amdgpu_ExportDef_vec_t export_defs,
+    iree_hal_amdgpu_device_kernel_args_t* host_kernel_args,
+    hsa_agent_t device_agent,
+    IREE_AMDGPU_DEVICE_PTR const iree_hal_amdgpu_device_kernel_args_t**
+        out_device_kernel_args) {
+  const iree_host_size_t kernel_count =
+      iree_hal_amdgpu_ExportDef_vec_len(export_defs);
+  for (iree_host_size_t kernel_ordinal = 0; kernel_ordinal < kernel_count;
+       ++kernel_ordinal) {
+    iree_hal_amdgpu_ExportDef_table_t export_def =
+        iree_hal_amdgpu_ExportDef_vec_at(export_defs, kernel_ordinal);
+    flatbuffers_string_t symbol_name =
+        iree_hal_amdgpu_ExportDef_symbol_name_get(export_def);
+    iree_string_view_t symbol_name_view =
+        symbol_name ? iree_make_string_view(symbol_name,
+                                            flatbuffers_string_len(symbol_name))
+                    : iree_string_view_empty();
+    if (iree_string_view_is_empty(symbol_name_view)) {
+      return iree_make_status(IREE_STATUS_INVALID_ARGUMENT,
+                              "executable kernel symbol name is empty");
+    }
+    hsa_executable_symbol_t symbol = {0};
+    IREE_RETURN_IF_ERROR(
+        iree_hal_amdgpu_executable_get_symbol_by_cstring(
+            libhsa, executable, symbol_name, device_agent, &symbol),
+        "resolving `%.*s`", (int)symbol_name_view.size, symbol_name_view.data);
+    IREE_RETURN_IF_ERROR(iree_hal_amdgpu_executable_resolve_kernel_object(
+        libhsa, symbol, &host_kernel_args[kernel_ordinal].kernel_object));
+  }
+  return iree_hal_amdgpu_executable_upload_resolved_kernel_table(
+      libhsa, kernel_count, host_kernel_args, device_agent,
+      out_device_kernel_args);
+}
+
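+// Resolves per-agent kernel objects for each raw HSACO metadata kernel into
+// |host_kernel_args| and uploads the completed table to |device_agent|.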
+static iree_status_t iree_hal_amdgpu_executable_upload_raw_hsaco_kernel_table(
+    const iree_hal_amdgpu_libhsa_t* libhsa, hsa_executable_t executable,
+    const iree_hal_amdgpu_hsaco_metadata_t* hsaco_metadata,
+    iree_hal_amdgpu_device_kernel_args_t* host_kernel_args,
+    hsa_agent_t device_agent,
+    IREE_AMDGPU_DEVICE_PTR const iree_hal_amdgpu_device_kernel_args_t**
+        out_device_kernel_args) {
+  for (iree_host_size_t kernel_ordinal = 0;
+       kernel_ordinal < hsaco_metadata->kernel_count; ++kernel_ordinal) {
+    iree_string_view_t symbol_name =
+        hsaco_metadata->kernels[kernel_ordinal].symbol_name;
+    hsa_executable_symbol_t symbol = {0};
+    IREE_RETURN_IF_ERROR(
+        iree_hal_amdgpu_executable_get_raw_hsaco_symbol_by_name(
+            libhsa, executable, symbol_name, device_agent, &symbol),
+        "resolving `%.*s`", (int)symbol_name.size, symbol_name.data);
+    IREE_RETURN_IF_ERROR(iree_hal_amdgpu_executable_resolve_kernel_object(
+        libhsa, symbol, &host_kernel_args[kernel_ordinal].kernel_object));
+  }
+  return iree_hal_amdgpu_executable_upload_resolved_kernel_table(
+      libhsa, hsaco_metadata->kernel_count, host_kernel_args, device_agent,
+      out_device_kernel_args);
+}
+
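+// Calculates how many kernarg ring blocks |layout| requires, reserving at
+// least one block even for empty layouts.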
+static iree_status_t iree_hal_amdgpu_executable_calculate_kernarg_block_count(
+    const iree_hal_amdgpu_device_dispatch_kernarg_layout_t* layout,
+    uint32_t* out_kernarg_block_count) {
+  iree_host_size_t kernarg_block_count = iree_host_size_ceil_div(
+      layout->total_kernarg_size, sizeof(iree_hal_amdgpu_kernarg_block_t));
+  if (kernarg_block_count == 0) {
+    kernarg_block_count = 1;
+  }
+  if (IREE_UNLIKELY(kernarg_block_count > UINT32_MAX)) {
+    return iree_make_status(
+        IREE_STATUS_OUT_OF_RANGE,
+        "dispatch kernargs require too many blocks (%" PRIhsz ", max=%u)",
+        kernarg_block_count, UINT32_MAX);
+  }
+  *out_kernarg_block_count = (uint32_t)kernarg_block_count;
+  return iree_ok_status();
+}
+
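+// Builds a host-side dispatch descriptor from |kernel_args|: validates the
+// kernarg alignment against the kernarg ring block alignment and precomputes
+// kernarg block counts and dispatch limits.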
+static iree_status_t iree_hal_amdgpu_executable_initialize_dispatch_descriptor(
+    const iree_hal_amdgpu_device_kernel_args_t* kernel_args,
+    iree_hal_amdgpu_executable_dispatch_descriptor_t* out_descriptor) {
+  memset(out_descriptor, 0, sizeof(*out_descriptor));
+
+  if (IREE_UNLIKELY(
+          !iree_host_size_is_power_of_two(kernel_args->kernarg_alignment))) {
+    return iree_make_status(
+        IREE_STATUS_FAILED_PRECONDITION,
+        "executable kernel kernarg alignment must be a power of two (got %u)",
+        kernel_args->kernarg_alignment);
+  }
+  if (IREE_UNLIKELY(kernel_args->kernarg_alignment >
+                    iree_alignof(iree_hal_amdgpu_kernarg_block_t))) {
+    return iree_make_status(
+        IREE_STATUS_FAILED_PRECONDITION,
+        "executable kernel kernarg alignment %u exceeds queue kernarg ring "
+        "alignment %" PRIhsz,
+        kernel_args->kernarg_alignment,
+        (iree_host_size_t)iree_alignof(iree_hal_amdgpu_kernarg_block_t));
+  }
+
+  out_descriptor->kernel_args = *kernel_args;
+  // Kernel metadata reports the bytes actually consumed by the compiled
+  // kernel. That may be smaller than the HAL ABI explicit argument footprint
+  // when optimization removes unused bindings/constants; we still reserve and
+  // populate the public HAL ABI bytes for dispatch.
+  out_descriptor->hal_kernarg_layout =
+      iree_hal_amdgpu_device_dispatch_make_hal_kernarg_layout(kernel_args);
+  IREE_RETURN_IF_ERROR(iree_hal_amdgpu_executable_calculate_kernarg_block_count(
+      &out_descriptor->hal_kernarg_layout,
+      &out_descriptor->hal_kernarg_block_count));
+
+  out_descriptor->custom_kernarg_layout =
+      iree_hal_amdgpu_device_dispatch_make_custom_kernarg_layout(
+          kernel_args->kernarg_size);
+  IREE_RETURN_IF_ERROR(iree_hal_amdgpu_executable_calculate_kernarg_block_count(
+      &out_descriptor->custom_kernarg_layout,
+      &out_descriptor->custom_kernarg_block_count));
+
+  for (iree_host_size_t i = 0; i < 3; ++i) {
+    if (IREE_UNLIKELY(kernel_args->workgroup_size[i] == 0)) {
+      return iree_make_status(
+          IREE_STATUS_FAILED_PRECONDITION,
+          "executable kernel workgroup size dimension %" PRIhsz " is zero", i);
+    }
+    out_descriptor->max_workgroup_count[i] =
+        UINT32_MAX / kernel_args->workgroup_size[i];
+  }
+  out_descriptor->max_dynamic_workgroup_local_memory =
+      UINT32_MAX - kernel_args->group_segment_size;
+  return iree_ok_status();
+}
+
 //===----------------------------------------------------------------------===//
 // iree_hal_amdgpu_executable_t
 //===----------------------------------------------------------------------===//
 
 typedef struct iree_hal_amdgpu_executable_t {
+  // HAL executable resource header.
   iree_hal_resource_t resource;
+  // Host allocator used for executable-owned metadata tables.
   iree_allocator_t host_allocator;
 
   // Unowned HSA API handle. Must remain valid for the lifetime of the
@@ -702,18 +1091,39 @@
   // Loaded HSA executable with a code object for each device.
   hsa_executable_t handle;
 
+  // Producer-local profile executable id assigned at creation.
+  uint64_t profile_id;
+  // Stable content hash for the exact loaded HSACO/code-object bytes.
+  uint64_t profile_code_object_hash[2];
+
   // Total number of exports in the executable.
   iree_host_size_t kernel_count;
+  // Host-resident reflection information for each export.
+  iree_hal_executable_export_info_t* export_infos /*[kernel_count]*/;
+  // Prefix-sum offsets into |export_parameters| for each export plus a
+  // sentinel.
+  iree_host_size_t* export_parameter_offsets /*[kernel_count + 1]*/;
+  // Host-resident parameter reflection records for all exports.
+  iree_hal_executable_export_parameter_t* export_parameters;
   // Table of kernel args stored in host memory. We have them local so that
   // host-side command buffer recording doesn't need to access device memory.
   // The kernel object specified in each is invalid as it's agent-specific.
   iree_hal_amdgpu_device_kernel_args_t* host_kernel_args /*[kernel_count]*/;
+  // Host-resident dispatch descriptors stored as [device_count][kernel_count].
+  iree_hal_amdgpu_executable_dispatch_descriptor_t*
+      host_dispatch_descriptors /*[device_count * kernel_count]*/;
 
-  // Total number of GPU devices in the system that the executable kernel arg
-  // table has been uploaded to.
+  // Queue affinity the executable was loaded for, after normalization.
+  iree_hal_queue_affinity_t queue_affinity;
+  // Bitmask of physical GPU device ordinals with loaded code objects.
+  uint64_t loaded_physical_device_mask;
+  // Number of loaded physical GPU devices in |loaded_physical_device_mask|.
+  iree_host_size_t loaded_physical_device_count;
+  // Total number of GPU devices in the topology used for per-device tables.
   iree_host_size_t device_count;
   // Table of kernel args stored in device memory, one copy per device.
-  // Each device has an entire `kernel_count` set of args.
+  // Selected devices have an entire `kernel_count` set of args; unselected
+  // devices remain NULL so lookups against them fail.
   IREE_AMDGPU_DEVICE_PTR const iree_hal_amdgpu_device_kernel_args_t*
       device_kernel_args[/*device_count*/];
 } iree_hal_amdgpu_executable_t;
@@ -732,31 +1142,923 @@
   return (const iree_hal_amdgpu_executable_t*)base_value;
 }
 
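+// Returns a string view over the flatbuffer string |value| or the empty view
+// if the field is absent.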
+static iree_string_view_t iree_hal_amdgpu_executable_flatbuffer_string_view(
+    flatbuffers_string_t value) {
+  return value ? iree_make_string_view(value, flatbuffers_string_len(value))
+               : iree_string_view_empty();
+}
+
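+// Returns the reflection name for an export: the debug name when available,
+// otherwise the kernel symbol name with any `.kd` suffix stripped.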
+static iree_string_view_t iree_hal_amdgpu_executable_export_reflection_name(
+    iree_hal_amdgpu_ExportDef_table_t export_def) {
+  iree_hal_debug_ExportDef_table_t debug_def =
+      iree_hal_amdgpu_ExportDef_debug_info_get(export_def);
+  if (debug_def) {
+    iree_string_view_t debug_name =
+        iree_hal_amdgpu_executable_flatbuffer_string_view(
+            iree_hal_debug_ExportDef_name_get(debug_def));
+    if (!iree_string_view_is_empty(debug_name)) return debug_name;
+  }
+
+  iree_string_view_t symbol_name =
+      iree_hal_amdgpu_executable_flatbuffer_string_view(
+          iree_hal_amdgpu_ExportDef_symbol_name_get(export_def));
+  return iree_string_view_strip_suffix(symbol_name, IREE_SV(".kd"));
+}
+
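+// Allocates the executable plus all trailing tables (export reflection info,
+// name storage, kernel args, and dispatch descriptors) in one host allocation
+// and wires up the interior pointers.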
+static iree_status_t iree_hal_amdgpu_executable_allocate(
+    const iree_hal_amdgpu_libhsa_t* libhsa,
+    const iree_hal_amdgpu_topology_t* topology,
+    const iree_hal_amdgpu_queue_affinity_physical_device_set_t*
+        physical_devices,
+    iree_host_size_t export_count, iree_host_size_t export_name_storage_size,
+    iree_host_size_t export_parameter_count,
+    iree_host_size_t export_parameter_name_storage_size,
+    iree_allocator_t host_allocator, char** out_export_name_storage,
+    char** out_export_parameter_name_storage,
+    iree_hal_amdgpu_executable_t** out_executable) {
+  *out_export_name_storage = NULL;
+  *out_export_parameter_name_storage = NULL;
+  *out_executable = NULL;
+
+  iree_host_size_t dispatch_descriptor_count = 0;
+  if (!iree_host_size_checked_mul(topology->gpu_agent_count, export_count,
+                                  &dispatch_descriptor_count)) {
+    return iree_make_status(IREE_STATUS_OUT_OF_RANGE,
+                            "dispatch descriptor table size overflow");
+  }
+
+  iree_host_size_t export_parameter_offset_count = 0;
+  if (!iree_host_size_checked_add(export_count, 1,
+                                  &export_parameter_offset_count)) {
+    return iree_make_status(IREE_STATUS_OUT_OF_RANGE,
+                            "export parameter offset table size overflow");
+  }
+
+  iree_host_size_t total_size = 0;
+  iree_host_size_t export_infos_offset = 0;
+  iree_host_size_t export_name_storage_offset = 0;
+  iree_host_size_t export_parameter_offsets_offset = 0;
+  iree_host_size_t export_parameters_offset = 0;
+  iree_host_size_t export_parameter_name_storage_offset = 0;
+  iree_host_size_t host_kernel_args_offset = 0;
+  iree_host_size_t host_dispatch_descriptors_offset = 0;
+  IREE_RETURN_IF_ERROR(IREE_STRUCT_LAYOUT(
+      sizeof(iree_hal_amdgpu_executable_t), &total_size,
+      IREE_STRUCT_FIELD_FAM(
+          topology->gpu_agent_count,
+          IREE_AMDGPU_DEVICE_PTR const iree_hal_amdgpu_device_kernel_args_t*),
+      IREE_STRUCT_FIELD(export_count, iree_hal_executable_export_info_t,
+                        &export_infos_offset),
+      IREE_STRUCT_FIELD(export_name_storage_size, char,
+                        &export_name_storage_offset),
+      IREE_STRUCT_FIELD(export_parameter_offset_count, iree_host_size_t,
+                        &export_parameter_offsets_offset),
+      IREE_STRUCT_FIELD(export_parameter_count,
+                        iree_hal_executable_export_parameter_t,
+                        &export_parameters_offset),
+      IREE_STRUCT_FIELD(export_parameter_name_storage_size, char,
+                        &export_parameter_name_storage_offset),
+      IREE_STRUCT_FIELD(export_count, iree_hal_amdgpu_device_kernel_args_t,
+                        &host_kernel_args_offset),
+      IREE_STRUCT_FIELD(dispatch_descriptor_count,
+                        iree_hal_amdgpu_executable_dispatch_descriptor_t,
+                        &host_dispatch_descriptors_offset)));
+
+  iree_hal_amdgpu_executable_t* executable = NULL;
+  IREE_RETURN_IF_ERROR(
+      iree_allocator_malloc(host_allocator, total_size, (void**)&executable));
+  memset(executable, 0, total_size);
+  iree_hal_resource_initialize(&iree_hal_amdgpu_executable_vtable,
+                               &executable->resource);
+  executable->host_allocator = host_allocator;
+  executable->libhsa = libhsa;
+  executable->kernel_count = export_count;
+  uint8_t* executable_storage = (uint8_t*)executable;
+  executable->export_infos =
+      (iree_hal_executable_export_info_t*)(executable_storage +
+                                           export_infos_offset);
+  executable->export_parameter_offsets =
+      (iree_host_size_t*)(executable_storage + export_parameter_offsets_offset);
+  executable->export_parameters =
+      export_parameter_count
+          ? (iree_hal_executable_export_parameter_t*)(executable_storage +
+                                                      export_parameters_offset)
+          : NULL;
+  executable->host_kernel_args =
+      (iree_hal_amdgpu_device_kernel_args_t*)(executable_storage +
+                                              host_kernel_args_offset);
+  executable->host_dispatch_descriptors =
+      (iree_hal_amdgpu_executable_dispatch_descriptor_t*)(executable_storage +
+                                                          host_dispatch_descriptors_offset);
+  executable->queue_affinity = physical_devices->queue_affinity;
+  executable->loaded_physical_device_mask =
+      physical_devices->physical_device_mask;
+  executable->loaded_physical_device_count =
+      physical_devices->physical_device_count;
+  executable->device_count = topology->gpu_agent_count;
+
+  *out_export_name_storage =
+      (char*)executable_storage + export_name_storage_offset;
+  *out_export_parameter_name_storage =
+      (char*)executable_storage + export_parameter_name_storage_offset;
+  *out_executable = executable;
+  return iree_ok_status();
+}
+
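+// Resets the agent-specific `kernel_object` fields in the host kernel args
+// table; only the per-device uploaded tables hold valid object pointers.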
+static void iree_hal_amdgpu_executable_invalidate_host_kernel_objects(
+    iree_hal_amdgpu_executable_t* executable) {
+  if (!executable) return;
+  for (iree_host_size_t kernel_ordinal = 0;
+       kernel_ordinal < executable->kernel_count; ++kernel_ordinal) {
+    executable->host_kernel_args[kernel_ordinal].kernel_object = 0;
+  }
+}
+
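+// Initializes the [device_ordinal][kernel] row of the dispatch descriptor
+// table from the resolved host kernel args.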
+static iree_status_t
+iree_hal_amdgpu_executable_initialize_dispatch_descriptors_for_device(
+    iree_hal_amdgpu_executable_t* executable, iree_host_size_t device_ordinal) {
+  for (iree_host_size_t kernel_ordinal = 0;
+       kernel_ordinal < executable->kernel_count; ++kernel_ordinal) {
+    const iree_host_size_t descriptor_ordinal =
+        device_ordinal * executable->kernel_count + kernel_ordinal;
+    IREE_RETURN_IF_ERROR(
+        iree_hal_amdgpu_executable_initialize_dispatch_descriptor(
+            &executable->host_kernel_args[kernel_ordinal],
+            &executable->host_dispatch_descriptors[descriptor_ordinal]),
+        "initializing dispatch descriptor for device %" PRIhsz
+        " export %" PRIhsz,
+        device_ordinal, kernel_ordinal);
+  }
+  return iree_ok_status();
+}
+
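+// Gathers loader load info for every loaded physical device and registers the
+// code-object bytes and load ranges with the profile metadata registry.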
+static iree_status_t iree_hal_amdgpu_executable_register_profile_artifacts(
+    const iree_hal_amdgpu_libhsa_t* libhsa,
+    const iree_hal_amdgpu_topology_t* topology,
+    iree_hal_amdgpu_profile_metadata_registry_t* profile_metadata,
+    iree_const_byte_span_t code_object_data,
+    iree_hal_amdgpu_executable_t* executable) {
+  IREE_TRACE_ZONE_BEGIN(z0);
+
+  iree_host_size_t load_info_storage_size = 0;
+  IREE_RETURN_AND_END_ZONE_IF_ERROR(
+      z0, IREE_STRUCT_LAYOUT(
+              0, &load_info_storage_size,
+              IREE_STRUCT_FIELD(executable->loaded_physical_device_count,
+                                iree_hal_amdgpu_profile_code_object_load_info_t,
+                                NULL)));
+
+  iree_hal_amdgpu_profile_code_object_load_info_t* load_infos = NULL;
+  IREE_RETURN_AND_END_ZONE_IF_ERROR(
+      z0, iree_allocator_malloc(executable->host_allocator,
+                                load_info_storage_size, (void**)&load_infos));
+
+  iree_status_t status = iree_ok_status();
+  iree_host_size_t load_info_ordinal = 0;
+  for (iree_host_size_t device_ordinal = 0;
+       device_ordinal < topology->gpu_agent_count && iree_status_is_ok(status);
+       ++device_ordinal) {
+    if (!iree_hal_amdgpu_physical_device_mask_contains(
+            executable->loaded_physical_device_mask, device_ordinal)) {
+      continue;
+    }
+    if (IREE_UNLIKELY(device_ordinal > UINT32_MAX)) {
+      status = iree_make_status(
+          IREE_STATUS_OUT_OF_RANGE,
+          "profile executable physical device ordinal exceeds uint32_t");
+    } else {
+      status =
+          iree_hal_amdgpu_executable_populate_profile_code_object_load_info(
+              libhsa, executable->handle, (uint32_t)device_ordinal,
+              topology->gpu_agents[device_ordinal],
+              &load_infos[load_info_ordinal]);
+      if (iree_status_is_ok(status)) ++load_info_ordinal;
+    }
+  }
+  if (iree_status_is_ok(status)) {
+    status = iree_hal_amdgpu_profile_metadata_register_executable_artifacts(
+        profile_metadata, executable->profile_id, code_object_data,
+        executable->profile_code_object_hash,
+        executable->loaded_physical_device_count, load_infos);
+  }
+
+  iree_allocator_free(executable->host_allocator, load_infos);
+  IREE_TRACE_ZONE_END(z0);
+  return status;
+}
+
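+// Registers the executable's reflection tables with the profile metadata
+// registry (assigning |profile_id|) and then its code-object artifacts.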
+static iree_status_t iree_hal_amdgpu_executable_register_profile_metadata(
+    const iree_hal_amdgpu_libhsa_t* libhsa,
+    const iree_hal_amdgpu_topology_t* topology,
+    iree_hal_amdgpu_profile_metadata_registry_t* profile_metadata,
+    iree_const_byte_span_t code_object_data,
+    iree_hal_amdgpu_executable_t* executable) {
+  IREE_RETURN_IF_ERROR(iree_hal_amdgpu_profile_metadata_register_executable(
+      profile_metadata, executable->kernel_count, executable->export_infos,
+      executable->export_parameter_offsets,
+      executable->profile_code_object_hash, executable->host_kernel_args,
+      &executable->profile_id));
+
+  // Executable trace profiling may not begin until after executable
+  // preparation. Preserve the exact code-object bytes and loader load ranges
+  // while |code_object_data| is still in scope so a later ATT capture can
+  // always emit a self-contained profile.
+  return iree_hal_amdgpu_executable_register_profile_artifacts(
+      libhsa, topology, profile_metadata, code_object_data, executable);
+}
+
+static iree_status_t
+iree_hal_amdgpu_executable_validate_export_parameter_requirements(
+    iree_hal_amdgpu_ExportDef_table_t export_def,
+    iree_string_view_t symbol_name,
+    const iree_hal_amdgpu_hsaco_metadata_export_parameter_requirements_t*
+        requirements) {
+  // The flatbuffer owns the HAL ABI counts for wrapped executables. HSACO
+  // metadata may omit arguments that LLVM optimized away, but it must not
+  // require more visible arguments than the flatbuffer layout can supply.
+  const uint32_t expected_constant_count =
+      iree_hal_amdgpu_ExportDef_constant_count_get(export_def);
+  if (requirements->constant_count > expected_constant_count) {
+    return iree_make_status(
+        IREE_STATUS_INVALID_ARGUMENT,
+        "HSACO metadata for export `%.*s` declares %u reflected constants but "
+        "ExecutableDef only declares %u",
+        (int)symbol_name.size, symbol_name.data,
+        (uint32_t)requirements->constant_count, expected_constant_count);
+  }
+
+  iree_hal_amdgpu_BindingBits_vec_t binding_flags =
+      iree_hal_amdgpu_ExportDef_binding_flags_get(export_def);
+  const iree_host_size_t expected_binding_count =
+      iree_hal_amdgpu_BindingBits_vec_len(binding_flags);
+  if (requirements->binding_count > expected_binding_count) {
+    return iree_make_status(
+        IREE_STATUS_INVALID_ARGUMENT,
+        "HSACO metadata for export `%.*s` declares %u reflected bindings but "
+        "ExecutableDef only declares %" PRIhsz,
+        (int)symbol_name.size, symbol_name.data,
+        (uint32_t)requirements->binding_count, expected_binding_count);
+  }
+  return iree_ok_status();
+}
+
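+// Sums the reflection storage (export names, parameter records, and parameter
+// name bytes) required across all flatbuffer exports, validating the HSACO
+// metadata against the flatbuffer ABI counts as it goes.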
+static iree_status_t iree_hal_amdgpu_executable_calculate_reflection_storage(
+    iree_hal_amdgpu_ExportDef_vec_t export_defs,
+    const iree_hal_amdgpu_hsaco_metadata_t* hsaco_metadata,
+    iree_host_size_t* out_export_name_storage_size,
+    iree_host_size_t* out_export_parameter_count,
+    iree_host_size_t* out_export_parameter_name_storage_size) {
+  iree_host_size_t export_name_storage_size = 0;
+  iree_host_size_t export_parameter_count = 0;
+  iree_host_size_t export_parameter_name_storage_size = 0;
+  const iree_host_size_t export_count =
+      iree_hal_amdgpu_ExportDef_vec_len(export_defs);
+  for (iree_host_size_t i = 0; i < export_count; ++i) {
+    iree_hal_amdgpu_ExportDef_table_t export_def =
+        iree_hal_amdgpu_ExportDef_vec_at(export_defs, i);
+    iree_string_view_t name =
+        iree_hal_amdgpu_executable_export_reflection_name(export_def);
+    if (!iree_host_size_checked_add(export_name_storage_size, name.size,
+                                    &export_name_storage_size)) {
+      return iree_make_status(IREE_STATUS_OUT_OF_RANGE,
+                              "export name storage size overflow");
+    }
+
+    iree_string_view_t symbol_name =
+        iree_hal_amdgpu_executable_flatbuffer_string_view(
+            iree_hal_amdgpu_ExportDef_symbol_name_get(export_def));
+    const iree_hal_amdgpu_hsaco_metadata_kernel_t* kernel = NULL;
+    IREE_RETURN_IF_ERROR(iree_hal_amdgpu_hsaco_metadata_find_kernel_by_symbol(
+                             hsaco_metadata, symbol_name, &kernel),
+                         "looking up HSACO metadata for export `%.*s`",
+                         (int)symbol_name.size, symbol_name.data);
+
+    iree_hal_amdgpu_hsaco_metadata_export_parameter_requirements_t requirements;
+    IREE_RETURN_IF_ERROR(
+        iree_hal_amdgpu_hsaco_metadata_calculate_default_export_parameter_requirements(
+            kernel, &requirements),
+        "projecting HSACO parameters for export `%.*s`", (int)symbol_name.size,
+        symbol_name.data);
+    IREE_RETURN_IF_ERROR(
+        iree_hal_amdgpu_executable_validate_export_parameter_requirements(
+            export_def, symbol_name, &requirements));
+
+    if (!iree_host_size_checked_add(export_parameter_count,
+                                    requirements.parameter_count,
+                                    &export_parameter_count) ||
+        !iree_host_size_checked_add(export_parameter_name_storage_size,
+                                    requirements.name_storage_size,
+                                    &export_parameter_name_storage_size)) {
+      return iree_make_status(
+          IREE_STATUS_OUT_OF_RANGE,
+          "export parameter reflection storage size overflow");
+    }
+  }
+  *out_export_name_storage_size = export_name_storage_size;
+  *out_export_parameter_count = export_parameter_count;
+  *out_export_parameter_name_storage_size = export_parameter_name_storage_size;
+  return iree_ok_status();
+}
+
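+// Verifies a raw HSACO kernel's `.reqd_workgroup_size` (when declared)
+// against the device workgroup limits.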
+static iree_status_t iree_hal_amdgpu_executable_verify_raw_hsaco_kernel(
+    const iree_hal_amdgpu_hsaco_metadata_kernel_t* kernel,
+    const iree_hal_amdgpu_device_limits_t* limits) {
+  if (!kernel->has_required_workgroup_size) {
+    return iree_ok_status();
+  }
+
+  const uint32_t* workgroup_size = kernel->required_workgroup_size;
+  if (workgroup_size[0] > limits->max_workgroup_size_per_dim[0] ||
+      workgroup_size[1] > limits->max_workgroup_size_per_dim[1] ||
+      workgroup_size[2] > limits->max_workgroup_size_per_dim[2]) {
+    return iree_make_status(
+        IREE_STATUS_INVALID_ARGUMENT,
+        "raw HSACO kernel `%.*s` workgroup size dims %ux%ux%u exceed device "
+        "maximum %ux%ux%u",
+        (int)kernel->symbol_name.size, kernel->symbol_name.data,
+        workgroup_size[0], workgroup_size[1], workgroup_size[2],
+        limits->max_workgroup_size_per_dim[0],
+        limits->max_workgroup_size_per_dim[1],
+        limits->max_workgroup_size_per_dim[2]);
+  }
+  const uint64_t total_workgroup_size =
+      (uint64_t)workgroup_size[0] * workgroup_size[1] * workgroup_size[2];
+  if (total_workgroup_size > limits->max_workgroup_size) {
+    return iree_make_status(
+        IREE_STATUS_INVALID_ARGUMENT,
+        "raw HSACO kernel `%.*s` workgroup size total %" PRIu64
+        " exceeds device "
+        "maximum %u",
+        (int)kernel->symbol_name.size, kernel->symbol_name.data,
+        total_workgroup_size, limits->max_workgroup_size);
+  }
+  return iree_ok_status();
+}
+
+static void iree_hal_amdgpu_executable_raw_hsaco_workgroup_size(
+    const iree_hal_amdgpu_hsaco_metadata_kernel_t* kernel,
+    uint32_t out_workgroup_size[3]) {
+  if (kernel->has_required_workgroup_size) {
+    out_workgroup_size[0] = kernel->required_workgroup_size[0];
+    out_workgroup_size[1] = kernel->required_workgroup_size[1];
+    out_workgroup_size[2] = kernel->required_workgroup_size[2];
+  } else {
+    // Raw HSACO without `.reqd_workgroup_size` is represented as a dynamic
+    // workgroup-size export with 1x1x1 minimum granularity. The actual launch
+    // geometry must come from the dispatch config.
+    out_workgroup_size[0] = 1;
+    out_workgroup_size[1] = 1;
+    out_workgroup_size[2] = 1;
+  }
+}
+
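+// Sums reflection storage for raw HSACO kernels and verifies each kernel's
+// declared workgroup size against the device limits; requires the metadata to
+// declare `amdhsa.target`.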
+static iree_status_t
+iree_hal_amdgpu_executable_calculate_raw_hsaco_reflection_storage(
+    const iree_hal_amdgpu_hsaco_metadata_t* hsaco_metadata,
+    const iree_hal_amdgpu_device_limits_t* limits,
+    iree_host_size_t* out_export_name_storage_size,
+    iree_host_size_t* out_export_parameter_count,
+    iree_host_size_t* out_export_parameter_name_storage_size) {
+  if (iree_string_view_is_empty(hsaco_metadata->target)) {
+    return iree_make_status(
+        IREE_STATUS_INVALID_ARGUMENT,
+        "raw HSACO metadata is missing `amdhsa.target`; direct loading "
+        "requires the code object to declare its target ISA");
+  }
+
+  iree_host_size_t export_name_storage_size = 0;
+  iree_host_size_t export_parameter_count = 0;
+  iree_host_size_t export_parameter_name_storage_size = 0;
+  for (iree_host_size_t i = 0; i < hsaco_metadata->kernel_count; ++i) {
+    const iree_hal_amdgpu_hsaco_metadata_kernel_t* kernel =
+        &hsaco_metadata->kernels[i];
+    IREE_RETURN_IF_ERROR(
+        iree_hal_amdgpu_executable_verify_raw_hsaco_kernel(kernel, limits));
+    if (!iree_host_size_checked_add(export_name_storage_size,
+                                    kernel->reflection_name.size,
+                                    &export_name_storage_size)) {
+      return iree_make_status(IREE_STATUS_OUT_OF_RANGE,
+                              "export name storage size overflow");
+    }
+
+    iree_hal_amdgpu_hsaco_metadata_export_parameter_requirements_t requirements;
+    IREE_RETURN_IF_ERROR(
+        iree_hal_amdgpu_hsaco_metadata_calculate_default_export_parameter_requirements(
+            kernel, &requirements),
+        "projecting HSACO parameters for raw kernel `%.*s`",
+        (int)kernel->symbol_name.size, kernel->symbol_name.data);
+    if (!iree_host_size_checked_add(export_parameter_count,
+                                    requirements.parameter_count,
+                                    &export_parameter_count) ||
+        !iree_host_size_checked_add(export_parameter_name_storage_size,
+                                    requirements.name_storage_size,
+                                    &export_parameter_name_storage_size)) {
+      return iree_make_status(
+          IREE_STATUS_OUT_OF_RANGE,
+          "export parameter reflection storage size overflow");
+    }
+  }
+  *out_export_name_storage_size = export_name_storage_size;
+  *out_export_parameter_count = export_parameter_count;
+  *out_export_parameter_name_storage_size = export_parameter_name_storage_size;
+  return iree_ok_status();
+}
+
+static iree_status_t iree_hal_amdgpu_executable_initialize_export_infos(
+    iree_hal_amdgpu_ExportDef_vec_t export_defs,
+    const iree_hal_amdgpu_hsaco_metadata_t* hsaco_metadata,
+    iree_hal_executable_export_info_t* export_infos,
+    iree_host_size_t* export_parameter_offsets,
+    iree_hal_executable_export_parameter_t* export_parameters,
+    char* export_name_storage, char* export_parameter_name_storage) {
+  iree_host_size_t export_parameter_offset = 0;
+  const iree_host_size_t export_count =
+      iree_hal_amdgpu_ExportDef_vec_len(export_defs);
+  for (iree_host_size_t i = 0; i < export_count; ++i) {
+    iree_hal_amdgpu_ExportDef_table_t export_def =
+        iree_hal_amdgpu_ExportDef_vec_at(export_defs, i);
+    iree_hal_executable_export_info_t* info = &export_infos[i];
+    export_parameter_offsets[i] = export_parameter_offset;
+
+    iree_string_view_t symbol_name =
+        iree_hal_amdgpu_executable_flatbuffer_string_view(
+            iree_hal_amdgpu_ExportDef_symbol_name_get(export_def));
+    const iree_hal_amdgpu_hsaco_metadata_kernel_t* kernel = NULL;
+    IREE_RETURN_IF_ERROR(iree_hal_amdgpu_hsaco_metadata_find_kernel_by_symbol(
+                             hsaco_metadata, symbol_name, &kernel),
+                         "looking up HSACO metadata for export `%.*s`",
+                         (int)symbol_name.size, symbol_name.data);
+
+    iree_hal_amdgpu_hsaco_metadata_export_parameter_requirements_t requirements;
+    IREE_RETURN_IF_ERROR(
+        iree_hal_amdgpu_hsaco_metadata_calculate_default_export_parameter_requirements(
+            kernel, &requirements),
+        "projecting HSACO parameters for export `%.*s`", (int)symbol_name.size,
+        symbol_name.data);
+    IREE_RETURN_IF_ERROR(
+        iree_hal_amdgpu_executable_validate_export_parameter_requirements(
+            export_def, symbol_name, &requirements));
+
+    iree_string_view_t name =
+        iree_hal_amdgpu_executable_export_reflection_name(export_def);
+    if (!iree_string_view_is_empty(name)) {
+      memcpy(export_name_storage, name.data, name.size);
+    }
+
+    memset(info, 0, sizeof(*info));
+    info->name = iree_make_string_view(export_name_storage, name.size);
+    info->flags = IREE_HAL_EXECUTABLE_EXPORT_FLAG_NONE;
+    // Preserve the flatbuffer ABI counts even when the HSACO metadata has
+    // dropped arguments the compiler optimized away as unused.
+    info->constant_count =
+        (uint16_t)iree_hal_amdgpu_ExportDef_constant_count_get(export_def);
+    iree_hal_amdgpu_BindingBits_vec_t binding_flags =
+        iree_hal_amdgpu_ExportDef_binding_flags_get(export_def);
+    info->binding_count =
+        (uint16_t)iree_hal_amdgpu_BindingBits_vec_len(binding_flags);
+    info->parameter_count = requirements.parameter_count;
+    const iree_hal_amdgpu_Dims_struct_t workgroup_size =
+        iree_hal_amdgpu_ExportDef_workgroup_size_get(export_def);
+    info->workgroup_size[0] = workgroup_size->x;
+    info->workgroup_size[1] = workgroup_size->y;
+    info->workgroup_size[2] = workgroup_size->z;
+
+    iree_hal_executable_export_parameter_t* export_parameter_base =
+        requirements.parameter_count
+            ? &export_parameters[export_parameter_offset]
+            : NULL;
+    char* export_parameter_name_base =
+        requirements.name_storage_size ? export_parameter_name_storage : NULL;
+    IREE_RETURN_IF_ERROR(
+        iree_hal_amdgpu_hsaco_metadata_populate_default_export_parameters(
+            kernel, requirements.parameter_count, export_parameter_base,
+            requirements.name_storage_size, export_parameter_name_base),
+        "populating reflected parameters for export `%.*s`",
+        (int)symbol_name.size, symbol_name.data);
+
+    export_name_storage += name.size;
+    export_parameter_offset += requirements.parameter_count;
+    export_parameter_name_storage += requirements.name_storage_size;
+  }
+  export_parameter_offsets[export_count] = export_parameter_offset;
+  return iree_ok_status();
+}
+
+static iree_status_t
+iree_hal_amdgpu_executable_initialize_raw_hsaco_export_infos(
+    const iree_hal_amdgpu_hsaco_metadata_t* hsaco_metadata,
+    iree_hal_executable_export_info_t* export_infos,
+    iree_host_size_t* export_parameter_offsets,
+    iree_hal_executable_export_parameter_t* export_parameters,
+    char* export_name_storage, char* export_parameter_name_storage) {
+  iree_host_size_t export_parameter_offset = 0;
+  for (iree_host_size_t i = 0; i < hsaco_metadata->kernel_count; ++i) {
+    const iree_hal_amdgpu_hsaco_metadata_kernel_t* kernel =
+        &hsaco_metadata->kernels[i];
+    iree_hal_executable_export_info_t* info = &export_infos[i];
+    export_parameter_offsets[i] = export_parameter_offset;
+
+    iree_hal_amdgpu_hsaco_metadata_export_parameter_requirements_t requirements;
+    IREE_RETURN_IF_ERROR(
+        iree_hal_amdgpu_hsaco_metadata_calculate_default_export_parameter_requirements(
+            kernel, &requirements),
+        "projecting HSACO parameters for raw kernel `%.*s`",
+        (int)kernel->symbol_name.size, kernel->symbol_name.data);
+
+    iree_string_view_t name = kernel->reflection_name;
+    if (!iree_string_view_is_empty(name)) {
+      memcpy(export_name_storage, name.data, name.size);
+    }
+
+    memset(info, 0, sizeof(*info));
+    info->name = iree_make_string_view(export_name_storage, name.size);
+    info->flags = kernel->has_required_workgroup_size
+                      ? IREE_HAL_EXECUTABLE_EXPORT_FLAG_NONE
+                      : IREE_HAL_EXECUTABLE_EXPORT_FLAG_WORKGROUP_SIZE_DYNAMIC;
+    info->constant_count = requirements.constant_count;
+    info->binding_count = requirements.binding_count;
+    info->parameter_count = requirements.parameter_count;
+    iree_hal_amdgpu_executable_raw_hsaco_workgroup_size(kernel,
+                                                        info->workgroup_size);
+
+    iree_hal_executable_export_parameter_t* export_parameter_base =
+        requirements.parameter_count
+            ? &export_parameters[export_parameter_offset]
+            : NULL;
+    char* export_parameter_name_base =
+        requirements.name_storage_size ? export_parameter_name_storage : NULL;
+    IREE_RETURN_IF_ERROR(
+        iree_hal_amdgpu_hsaco_metadata_populate_default_export_parameters(
+            kernel, requirements.parameter_count, export_parameter_base,
+            requirements.name_storage_size, export_parameter_name_base),
+        "populating reflected parameters for raw kernel `%.*s`",
+        (int)kernel->symbol_name.size, kernel->symbol_name.data);
+
+    export_name_storage += name.size;
+    export_parameter_offset += requirements.parameter_count;
+    export_parameter_name_storage += requirements.name_storage_size;
+  }
+  export_parameter_offsets[hsaco_metadata->kernel_count] =
+      export_parameter_offset;
+  return iree_ok_status();
+}
+
+static iree_status_t iree_hal_amdgpu_executable_resolve_flatbuffer_kernel_args(
+    const iree_hal_amdgpu_libhsa_t* libhsa, hsa_executable_t executable,
+    iree_hal_amdgpu_ExportDef_vec_t export_defs, hsa_agent_t any_device_agent,
+    iree_hal_amdgpu_device_kernel_args_t* host_kernel_args) {
+  const iree_host_size_t kernel_count =
+      iree_hal_amdgpu_ExportDef_vec_len(export_defs);
+  for (iree_host_size_t kernel_ordinal = 0; kernel_ordinal < kernel_count;
+       ++kernel_ordinal) {
+    iree_hal_amdgpu_ExportDef_table_t export_def =
+        iree_hal_amdgpu_ExportDef_vec_at(export_defs, kernel_ordinal);
+    flatbuffers_string_t symbol_name =
+        iree_hal_amdgpu_ExportDef_symbol_name_get(export_def);
+    iree_string_view_t symbol_name_view =
+        iree_hal_amdgpu_executable_flatbuffer_string_view(symbol_name);
+    if (iree_string_view_is_empty(symbol_name_view)) {
+      return iree_make_status(IREE_STATUS_INVALID_ARGUMENT,
+                              "executable kernel symbol name is empty");
+    }
+
+    hsa_executable_symbol_t symbol = {0};
+    IREE_RETURN_IF_ERROR(
+        iree_hal_amdgpu_executable_get_symbol_by_cstring(
+            libhsa, executable, symbol_name, any_device_agent, &symbol),
+        "looking up HSA symbol for export `%.*s`", (int)symbol_name_view.size,
+        symbol_name_view.data);
+
+    const iree_hal_amdgpu_Dims_struct_t flatbuffer_workgroup_size =
+        iree_hal_amdgpu_ExportDef_workgroup_size_get(export_def);
+    const uint32_t workgroup_size[3] = {
+        flatbuffer_workgroup_size->x,
+        flatbuffer_workgroup_size->y,
+        flatbuffer_workgroup_size->z,
+    };
+    const uint16_t constant_count =
+        (uint16_t)iree_hal_amdgpu_ExportDef_constant_count_get(export_def);
+    iree_hal_amdgpu_BindingBits_vec_t binding_bits =
+        iree_hal_amdgpu_ExportDef_binding_flags_get(export_def);
+    const uint16_t binding_count =
+        (uint16_t)iree_hal_amdgpu_BindingBits_vec_len(binding_bits);
+    IREE_RETURN_IF_ERROR(
+        iree_hal_amdgpu_executable_resolve_kernel_args_from_symbol(
+            libhsa, symbol, workgroup_size, constant_count, binding_count,
+            &host_kernel_args[kernel_ordinal]),
+        "resolving kernel args for `%.*s`", (int)symbol_name_view.size,
+        symbol_name_view.data);
+  }
+  return iree_ok_status();
+}
+
+static iree_status_t iree_hal_amdgpu_executable_resolve_raw_hsaco_kernel_args(
+    const iree_hal_amdgpu_libhsa_t* libhsa, hsa_executable_t executable,
+    const iree_hal_amdgpu_hsaco_metadata_t* hsaco_metadata,
+    hsa_agent_t any_device_agent,
+    iree_hal_amdgpu_device_kernel_args_t* host_kernel_args) {
+  for (iree_host_size_t kernel_ordinal = 0;
+       kernel_ordinal < hsaco_metadata->kernel_count; ++kernel_ordinal) {
+    const iree_hal_amdgpu_hsaco_metadata_kernel_t* kernel =
+        &hsaco_metadata->kernels[kernel_ordinal];
+    iree_string_view_t symbol_name = kernel->symbol_name;
+
+    hsa_executable_symbol_t symbol = {0};
+    IREE_RETURN_IF_ERROR(
+        iree_hal_amdgpu_executable_get_raw_hsaco_symbol_by_name(
+            libhsa, executable, symbol_name, any_device_agent, &symbol),
+        "looking up HSA symbol for raw kernel `%.*s`", (int)symbol_name.size,
+        symbol_name.data);
+
+    iree_hal_amdgpu_hsaco_metadata_export_parameter_requirements_t requirements;
+    IREE_RETURN_IF_ERROR(
+        iree_hal_amdgpu_hsaco_metadata_calculate_default_export_parameter_requirements(
+            kernel, &requirements),
+        "projecting HSACO parameters for raw kernel `%.*s`",
+        (int)symbol_name.size, symbol_name.data);
+
+    uint32_t workgroup_size[3] = {0};
+    iree_hal_amdgpu_executable_raw_hsaco_workgroup_size(kernel, workgroup_size);
+    IREE_RETURN_IF_ERROR(
+        iree_hal_amdgpu_executable_resolve_kernel_args_from_symbol(
+            libhsa, symbol, workgroup_size, requirements.constant_count,
+            requirements.binding_count, &host_kernel_args[kernel_ordinal]),
+        "resolving kernel args for raw kernel `%.*s`", (int)symbol_name.size,
+        symbol_name.data);
+  }
+  return iree_ok_status();
+}
+
+static iree_status_t iree_hal_amdgpu_executable_create_from_flatbuffer(
+    const iree_hal_amdgpu_libhsa_t* libhsa,
+    const iree_hal_amdgpu_topology_t* topology,
+    const iree_hal_amdgpu_queue_affinity_physical_device_set_t*
+        physical_devices,
+    const iree_hal_executable_params_t* executable_params,
+    const iree_hal_amdgpu_device_limits_t* limits, hsa_agent_t any_device_agent,
+    iree_hal_amdgpu_profile_metadata_registry_t* profile_metadata,
+    iree_allocator_t host_allocator, iree_hal_executable_t** out_executable) {
+  *out_executable = NULL;
+
+  iree_const_byte_span_t executable_flatbuffer = iree_const_byte_span_empty();
+  iree_hal_amdgpu_ExecutableDef_table_t executable_def = 0;
+  iree_hal_amdgpu_ExportDef_vec_t export_defs = 0;
+  iree_const_byte_span_t code_object_data = iree_const_byte_span_empty();
+  iree_host_size_t export_count = 0;
+  iree_status_t status = iree_hal_read_executable_flatbuffer_header(
+      executable_params->executable_data, /*unsafe_infer_size=*/false,
+      iree_hal_amdgpu_ExecutableDef_file_identifier, &executable_flatbuffer);
+  if (iree_status_is_ok(status)) {
+    status = iree_hal_amdgpu_executable_flatbuffer_verify(executable_flatbuffer,
+                                                          limits);
+  }
+  if (iree_status_is_ok(status)) {
+    executable_def =
+        iree_hal_amdgpu_ExecutableDef_as_root(executable_flatbuffer.data);
+    export_defs = iree_hal_amdgpu_ExecutableDef_exports_get(executable_def);
+    export_count = iree_hal_amdgpu_ExportDef_vec_len(export_defs);
+    iree_hal_amdgpu_ModuleDef_vec_t module_defs =
+        iree_hal_amdgpu_ExecutableDef_modules_get(executable_def);
+    status = iree_hal_amdgpu_executable_get_single_module_image(
+        module_defs, &code_object_data);
+  }
+
+  iree_hal_amdgpu_hsaco_metadata_t hsaco_metadata = {0};
+  if (iree_status_is_ok(status)) {
+    status = iree_hal_amdgpu_hsaco_metadata_initialize_from_elf(
+        code_object_data, host_allocator, &hsaco_metadata);
+  }
+
+  iree_host_size_t export_name_storage_size = 0;
+  iree_host_size_t export_parameter_count = 0;
+  iree_host_size_t export_parameter_name_storage_size = 0;
+  if (iree_status_is_ok(status)) {
+    status = iree_hal_amdgpu_executable_calculate_reflection_storage(
+        export_defs, &hsaco_metadata, &export_name_storage_size,
+        &export_parameter_count, &export_parameter_name_storage_size);
+  }
+
+  iree_hal_amdgpu_executable_t* executable = NULL;
+  char* export_name_storage = NULL;
+  char* export_parameter_name_storage = NULL;
+  if (iree_status_is_ok(status)) {
+    status = iree_hal_amdgpu_executable_allocate(
+        libhsa, topology, physical_devices, export_count,
+        export_name_storage_size, export_parameter_count,
+        export_parameter_name_storage_size, host_allocator,
+        &export_name_storage, &export_parameter_name_storage, &executable);
+  }
+  if (iree_status_is_ok(status)) {
+    status = iree_hal_amdgpu_executable_initialize_export_infos(
+        export_defs, &hsaco_metadata, executable->export_infos,
+        executable->export_parameter_offsets, executable->export_parameters,
+        export_name_storage, export_parameter_name_storage);
+  }
+  if (iree_status_is_ok(status)) {
+    iree_hal_amdgpu_profile_metadata_hash_code_object(
+        code_object_data, executable->profile_code_object_hash);
+  }
+
+  // Publish any embedded source files to the tracing infrastructure.
+  if (iree_status_is_ok(status)) {
+    iree_hal_debug_publish_source_files(
+        iree_hal_amdgpu_ExecutableDef_source_files_get(executable_def));
+  }
+
+  // Load executable and register it with selected GPU agents.
+  if (iree_status_is_ok(status)) {
+    status = iree_hal_amdgpu_executable_load_module(
+        libhsa, topology, physical_devices, executable_params, code_object_data,
+        &executable->handle);
+  }
+
+  // Resolve kernel args for each export.
+  // These parameters should be the same for all devices as we require that
+  // all devices share the same ISA. The only field that differs is the
+  // kernel_object pointer, which we handle per-device during table upload.
+  if (iree_status_is_ok(status)) {
+    status = iree_hal_amdgpu_executable_resolve_flatbuffer_kernel_args(
+        libhsa, executable->handle, export_defs, any_device_agent,
+        executable->host_kernel_args);
+  }
+
+  // Upload copies of kernel arguments for each device.
+  // We reuse the host storage we already allocated to make it possible to
+  // memcpy the entire table in one go from host memory.
+  for (iree_host_size_t device_ordinal = 0;
+       iree_status_is_ok(status) && device_ordinal < executable->device_count;
+       ++device_ordinal) {
+    if (!iree_hal_amdgpu_physical_device_mask_contains(
+            executable->loaded_physical_device_mask, device_ordinal)) {
+      continue;
+    }
+    status = iree_hal_amdgpu_executable_upload_flatbuffer_kernel_table(
+        libhsa, executable->handle, export_defs, executable->host_kernel_args,
+        topology->gpu_agents[device_ordinal],
+        &executable->device_kernel_args[device_ordinal]);
+    if (iree_status_is_ok(status)) {
+      status =
+          iree_hal_amdgpu_executable_initialize_dispatch_descriptors_for_device(
+              executable, device_ordinal);
+    }
+  }
+
+  if (iree_status_is_ok(status)) {
+    status = iree_hal_amdgpu_executable_register_profile_metadata(
+        libhsa, topology, profile_metadata, code_object_data, executable);
+  }
+
+  // Invalidate the kernel object pointer in all host args so that we don't
+  // accidentally use it instead of the device-specific one.
+  if (iree_status_is_ok(status)) {
+    iree_hal_amdgpu_executable_invalidate_host_kernel_objects(executable);
+  }
+
+  iree_hal_amdgpu_hsaco_metadata_deinitialize(&hsaco_metadata);
+
+  if (iree_status_is_ok(status)) {
+    *out_executable = (iree_hal_executable_t*)executable;
+  } else if (executable) {
+    iree_hal_executable_destroy((iree_hal_executable_t*)executable);
+  }
+  return status;
+}
+
+static iree_status_t iree_hal_amdgpu_executable_create_from_raw_hsaco(
+    const iree_hal_amdgpu_libhsa_t* libhsa,
+    const iree_hal_amdgpu_topology_t* topology,
+    const iree_hal_amdgpu_queue_affinity_physical_device_set_t*
+        physical_devices,
+    const iree_hal_executable_params_t* executable_params,
+    const iree_hal_amdgpu_device_limits_t* limits, hsa_agent_t any_device_agent,
+    iree_hal_amdgpu_profile_metadata_registry_t* profile_metadata,
+    iree_allocator_t host_allocator, iree_hal_executable_t** out_executable) {
+  *out_executable = NULL;
+
+  iree_const_byte_span_t code_object_data = executable_params->executable_data;
+  iree_hal_amdgpu_hsaco_metadata_t hsaco_metadata = {0};
+  iree_status_t status = iree_hal_amdgpu_hsaco_metadata_initialize_from_elf(
+      code_object_data, host_allocator, &hsaco_metadata);
+
+  iree_host_size_t export_name_storage_size = 0;
+  iree_host_size_t export_parameter_count = 0;
+  iree_host_size_t export_parameter_name_storage_size = 0;
+  if (iree_status_is_ok(status)) {
+    status = iree_hal_amdgpu_executable_calculate_raw_hsaco_reflection_storage(
+        &hsaco_metadata, limits, &export_name_storage_size,
+        &export_parameter_count, &export_parameter_name_storage_size);
+  }
+
+  iree_hal_amdgpu_executable_t* executable = NULL;
+  char* export_name_storage = NULL;
+  char* export_parameter_name_storage = NULL;
+  if (iree_status_is_ok(status)) {
+    status = iree_hal_amdgpu_executable_allocate(
+        libhsa, topology, physical_devices, hsaco_metadata.kernel_count,
+        export_name_storage_size, export_parameter_count,
+        export_parameter_name_storage_size, host_allocator,
+        &export_name_storage, &export_parameter_name_storage, &executable);
+  }
+  if (iree_status_is_ok(status)) {
+    status = iree_hal_amdgpu_executable_initialize_raw_hsaco_export_infos(
+        &hsaco_metadata, executable->export_infos,
+        executable->export_parameter_offsets, executable->export_parameters,
+        export_name_storage, export_parameter_name_storage);
+  }
+  if (iree_status_is_ok(status)) {
+    iree_hal_amdgpu_profile_metadata_hash_code_object(
+        code_object_data, executable->profile_code_object_hash);
+  }
+
+  // Load executable and register it with selected GPU agents.
+  if (iree_status_is_ok(status)) {
+    status = iree_hal_amdgpu_executable_load_module(
+        libhsa, topology, physical_devices, executable_params, code_object_data,
+        &executable->handle);
+  }
+
+  // Resolve kernel args for each export.
+  // These parameters should be the same for all devices as we require that
+  // all devices share the same ISA. The only field that differs is the
+  // kernel_object pointer, which we handle per-device during table upload.
+  if (iree_status_is_ok(status)) {
+    status = iree_hal_amdgpu_executable_resolve_raw_hsaco_kernel_args(
+        libhsa, executable->handle, &hsaco_metadata, any_device_agent,
+        executable->host_kernel_args);
+  }
+
+  // Upload copies of kernel arguments for each device.
+  // We reuse the host storage we already allocated to make it possible to
+  // memcpy the entire table in one go from host memory.
+  for (iree_host_size_t device_ordinal = 0;
+       iree_status_is_ok(status) && device_ordinal < executable->device_count;
+       ++device_ordinal) {
+    if (!iree_hal_amdgpu_physical_device_mask_contains(
+            executable->loaded_physical_device_mask, device_ordinal)) {
+      continue;
+    }
+    status = iree_hal_amdgpu_executable_upload_raw_hsaco_kernel_table(
+        libhsa, executable->handle, &hsaco_metadata,
+        executable->host_kernel_args, topology->gpu_agents[device_ordinal],
+        &executable->device_kernel_args[device_ordinal]);
+    if (iree_status_is_ok(status)) {
+      status =
+          iree_hal_amdgpu_executable_initialize_dispatch_descriptors_for_device(
+              executable, device_ordinal);
+    }
+  }
+
+  if (iree_status_is_ok(status)) {
+    status = iree_hal_amdgpu_executable_register_profile_metadata(
+        libhsa, topology, profile_metadata, code_object_data, executable);
+  }
+
+  // Invalidate the kernel object pointer in all host args so that we don't
+  // accidentally use it instead of the device-specific one.
+  if (iree_status_is_ok(status)) {
+    iree_hal_amdgpu_executable_invalidate_host_kernel_objects(executable);
+  }
+
+  iree_hal_amdgpu_hsaco_metadata_deinitialize(&hsaco_metadata);
+
+  if (iree_status_is_ok(status)) {
+    *out_executable = (iree_hal_executable_t*)executable;
+  } else if (executable) {
+    iree_hal_executable_destroy((iree_hal_executable_t*)executable);
+  }
+  return status;
+}
+
 iree_status_t iree_hal_amdgpu_executable_create(
     const iree_hal_amdgpu_libhsa_t* libhsa,
     const iree_hal_amdgpu_topology_t* topology,
     const iree_hal_executable_params_t* executable_params,
+    iree_hal_amdgpu_profile_metadata_registry_t* profile_metadata,
     iree_allocator_t host_allocator, iree_hal_executable_t** out_executable) {
   IREE_ASSERT_ARGUMENT(executable_params);
+  IREE_ASSERT_ARGUMENT(profile_metadata);
   IREE_ASSERT_ARGUMENT(out_executable);
   *out_executable = NULL;
   IREE_TRACE_ZONE_BEGIN(z0);
 
-  // TODO(benvanik): use executable_params->queue_affinity instead of the raw
-  // topology - the affinity will tell us exactly which physical devices we need
-  // to load the executable on. We have to map from queue affinity to GPU agent
-  // and don't have a utility for that accessible here yet.
-
-  // Pick a device to be our template for device queries. All devices in the
-  // topology are expected to be the same. This should have been checked
-  // earlier but we do it here in case the user is bypassing that code.
-  IREE_ASSERT_GE(topology->gpu_agent_count, 1);
   if (IREE_UNLIKELY(topology->gpu_agent_count == 0)) {
     IREE_RETURN_AND_END_ZONE_IF_ERROR(
         z0, iree_make_status(IREE_STATUS_INVALID_ARGUMENT,
                              "topology must have at least one GPU device"));
   }
-  hsa_agent_t any_device_agent = topology->gpu_agents[0];
+
+  // Resolve the executable queue affinity to the physical devices that need
+  // code-object loads and per-device kernel tables.
+  iree_hal_amdgpu_queue_affinity_physical_device_set_t physical_devices;
+  IREE_RETURN_AND_END_ZONE_IF_ERROR(
+      z0, iree_hal_amdgpu_executable_select_physical_devices(
+              topology, executable_params->queue_affinity, &physical_devices));
+
+  // Pick one of the selected devices to use as the template for device
+  // queries. All devices in the topology are expected to be identical; that
+  // should have been verified earlier, but we query again here in case the
+  // caller is bypassing that code.
+  hsa_agent_t any_device_agent =
+      topology->gpu_agents[physical_devices.first_physical_device_ordinal];
 
   // Check that the executable is supported and get the ISA it matches.
   bool supported = false;
@@ -774,126 +2076,37 @@
                              executable_params->executable_format.data));
   }
 
-  // Verify the flatbuffer is valid.
-  // Doing this first ensures we don't need to check the structure of the
-  // flatbuffer during loading (though things like optional fields still need to
-  // be checked!).
   iree_hal_amdgpu_device_limits_t limits = {0};
   IREE_RETURN_AND_END_ZONE_IF_ERROR(
       z0, iree_hal_amdgpu_query_device_limits(libhsa, any_device_agent, isa,
                                               &limits));
 
-  // Read and strip the flatbuffer header prefix.
-  iree_const_byte_span_t executable_flatbuffer = iree_const_byte_span_empty();
-  IREE_RETURN_AND_END_ZONE_IF_ERROR(
-      z0, iree_hal_read_executable_flatbuffer_header(
-              executable_params->executable_data, /*unsafe_infer_size=*/false,
-              iree_hal_amdgpu_ExecutableDef_file_identifier,
-              &executable_flatbuffer));
-
-  IREE_RETURN_AND_END_ZONE_IF_ERROR(
-      z0, iree_hal_amdgpu_executable_flatbuffer_verify(executable_flatbuffer,
-                                                       &limits));
-
-  // Dereference the flatbuffer.
-  iree_hal_amdgpu_ExecutableDef_table_t executable_def =
-      iree_hal_amdgpu_ExecutableDef_as_root(executable_flatbuffer.data);
-  iree_hal_amdgpu_ExportDef_vec_t export_defs =
-      iree_hal_amdgpu_ExecutableDef_exports_get(executable_def);
-  const iree_host_size_t export_count =
-      iree_hal_amdgpu_ExportDef_vec_len(export_defs);
-
-  // Allocate storage for the executable and its associated data structures.
-  iree_hal_amdgpu_executable_t* executable = NULL;
-  const iree_host_size_t total_size =
-      sizeof(*executable) +
-      export_count * sizeof(executable->host_kernel_args[0]) +
-      topology->gpu_agent_count * sizeof(executable->device_kernel_args[0]);
-  IREE_RETURN_AND_END_ZONE_IF_ERROR(
-      z0,
-      iree_allocator_malloc(host_allocator, total_size, (void**)&executable));
-  iree_hal_resource_initialize(&iree_hal_amdgpu_executable_vtable,
-                               &executable->resource);
-  executable->host_allocator = host_allocator;
-  executable->libhsa = libhsa;
-  executable->kernel_count = export_count;
-  executable->host_kernel_args =
-      (iree_hal_amdgpu_device_kernel_args_t*)(((uint8_t*)executable) +
-                                              sizeof(*executable));
-  executable->device_count = topology->gpu_agent_count;
-
-  // Publish any embedded source files to the tracing infrastructure.
-  iree_hal_debug_publish_source_files(
-      iree_hal_amdgpu_ExecutableDef_source_files_get(executable_def));
-
   iree_status_t status = iree_ok_status();
-
-  // Intern source locations for all exported functions. These will persist for
-  // the lifetime of the process and be passed to tooling as if they were in a
-  // rodata segment.
-  iree_hal_amdgpu_trace_src_loc_t* export_locs = NULL;
-#if IREE_TRACING_FEATURES & IREE_TRACING_FEATURE_INSTRUMENTATION_DEVICE
-  if (iree_status_is_ok(status)) {
-    status = iree_hal_amdgpu_executable_intern_trace_locs(
-        export_defs, host_allocator, &export_locs);
-  }
-#endif  // IREE_TRACING_FEATURES & IREE_TRACING_FEATURE_INSTRUMENTATION_DEVICE
-
-  // Load executable and register it with all GPU agents.
-  if (iree_status_is_ok(status)) {
-    status = iree_hal_amdgpu_executable_load_modules(
-        libhsa, topology, executable_params,
-        iree_hal_amdgpu_ExecutableDef_modules_get(executable_def),
-        &executable->handle);
-  }
-
-  // Resolve kernel args for each export.
-  // These parameters should be the same for all devices as we require all
-  // devices have the same ISA. The only thing that will differ is the
-  // kernel_object pointer and we handle that per-device during table upload.
-  for (iree_host_size_t kernel_ordinal = 0;
-       iree_status_is_ok(status) && kernel_ordinal < executable->kernel_count;
-       ++kernel_ordinal) {
-    iree_hal_amdgpu_ExportDef_table_t export_def =
-        iree_hal_amdgpu_ExportDef_vec_at(export_defs, kernel_ordinal);
-    const iree_hal_amdgpu_trace_src_loc_t* export_loc =
-        export_locs ? &export_locs[kernel_ordinal] : NULL;
-    status = iree_status_annotate_f(
-        iree_hal_amdgpu_executable_resolve_kernel_args(
-            libhsa, executable->handle, export_def, export_loc,
-            any_device_agent, &executable->host_kernel_args[kernel_ordinal]),
-        "resolving kernel args for `%s`",
-        iree_hal_amdgpu_ExportDef_symbol_name_get(export_def));
-  }
-
-  // Upload copies of kernel arguments for each device.
-  // We reuse the host storage we already allocated to make it possible to
-  // memcpy the entire table in one go from host memory.
-  for (iree_host_size_t device_ordinal = 0;
-       iree_status_is_ok(status) && device_ordinal < executable->device_count;
-       ++device_ordinal) {
-    status = iree_hal_amdgpu_executable_upload_kernel_table(
-        libhsa, executable->handle, export_defs, executable->kernel_count,
-        executable->host_kernel_args, topology->gpu_agents[device_ordinal],
-        &executable->device_kernel_args[device_ordinal]);
-  }
-
-  // Invalidate the kernel object pointer in all host args so that we don't
-  // accidentally use it instead of the device-specific one.
-  for (iree_host_size_t kernel_ordinal = 0;
-       kernel_ordinal < executable->kernel_count; ++kernel_ordinal) {
-    executable->host_kernel_args[kernel_ordinal].kernel_object = 0;
-  }
-
-  if (iree_status_is_ok(status)) {
-    *out_executable = (iree_hal_executable_t*)executable;
+  if (iree_hal_amdgpu_executable_data_is_wrapped_flatbuffer(
+          executable_params->executable_data)) {
+    status = iree_hal_amdgpu_executable_create_from_flatbuffer(
+        libhsa, topology, &physical_devices, executable_params, &limits,
+        any_device_agent, profile_metadata, host_allocator, out_executable);
   } else {
-    iree_hal_executable_destroy((iree_hal_executable_t*)executable);
+    status = iree_hal_amdgpu_executable_create_from_raw_hsaco(
+        libhsa, topology, &physical_devices, executable_params, &limits,
+        any_device_agent, profile_metadata, host_allocator, out_executable);
+  }
+  if (!iree_status_is_ok(status)) {
+    iree_hal_executable_release(*out_executable);
+    *out_executable = NULL;
   }
   IREE_TRACE_ZONE_END(z0);
   return status;
 }
 
+uint64_t iree_hal_amdgpu_executable_profile_id(
+    iree_hal_executable_t* base_executable) {
+  iree_hal_amdgpu_executable_t* executable =
+      iree_hal_amdgpu_executable_cast(base_executable);
+  return executable->profile_id;
+}
+
 static void iree_hal_amdgpu_executable_destroy(
     iree_hal_executable_t* base_executable) {
   iree_hal_amdgpu_executable_t* executable =
@@ -905,14 +2118,14 @@
        device_ordinal < executable->device_count; ++device_ordinal) {
     void* kernel_args = (void*)executable->device_kernel_args[device_ordinal];
     if (kernel_args) {
-      IREE_IGNORE_ERROR(iree_hsa_amd_memory_pool_free(
-          IREE_LIBHSA(executable->libhsa), kernel_args));
+      iree_hal_amdgpu_hsa_cleanup_assert_success(
+          iree_hsa_amd_memory_pool_free_raw(executable->libhsa, kernel_args));
     }
   }
 
   if (executable->handle.handle) {
-    IREE_IGNORE_ERROR(iree_hsa_executable_destroy(
-        IREE_LIBHSA(executable->libhsa), executable->handle));
+    iree_hal_amdgpu_hsa_cleanup_assert_success(iree_hsa_executable_destroy_raw(
+        executable->libhsa, executable->handle));
   }
 
   iree_allocator_free(host_allocator, executable);
@@ -957,9 +2170,15 @@
   } else if (IREE_UNLIKELY(device_ordinal >= executable->device_count)) {
     return iree_make_status(IREE_STATUS_OUT_OF_RANGE,
                             "device ordinal %" PRIhsz
-                            " out of range; executable was loaded on %" PRIhsz
-                            " devices",
+                            " out of range; executable topology has %" PRIhsz
+                            " physical devices",
                             device_ordinal, executable->device_count);
+  } else if (IREE_UNLIKELY(!iree_hal_amdgpu_physical_device_mask_contains(
+                 executable->loaded_physical_device_mask, device_ordinal))) {
+    return iree_make_status(IREE_STATUS_FAILED_PRECONDITION,
+                            "device ordinal %" PRIhsz
+                            " is not in executable queue affinity 0x%" PRIx64,
+                            device_ordinal, executable->queue_affinity);
   }
 
   *out_kernel_args =
@@ -968,13 +2187,45 @@
   return iree_ok_status();
 }
 
+iree_status_t iree_hal_amdgpu_executable_lookup_dispatch_descriptor_for_device(
+    iree_hal_executable_t* base_executable,
+    iree_hal_executable_export_ordinal_t export_ordinal,
+    iree_host_size_t device_ordinal,
+    const iree_hal_amdgpu_executable_dispatch_descriptor_t** out_descriptor) {
+  const iree_hal_amdgpu_executable_t* executable =
+      iree_hal_amdgpu_executable_const_cast(base_executable);
+  *out_descriptor = NULL;
+
+  if (IREE_UNLIKELY(export_ordinal >= executable->kernel_count)) {
+    return iree_make_status(IREE_STATUS_OUT_OF_RANGE,
+                            "export ordinal %" PRIu32
+                            " out of range; executable has %" PRIhsz " exports",
+                            export_ordinal, executable->kernel_count);
+  } else if (IREE_UNLIKELY(device_ordinal >= executable->device_count)) {
+    return iree_make_status(IREE_STATUS_OUT_OF_RANGE,
+                            "device ordinal %" PRIhsz
+                            " out of range; executable topology has %" PRIhsz
+                            " physical devices",
+                            device_ordinal, executable->device_count);
+  } else if (IREE_UNLIKELY(!iree_hal_amdgpu_physical_device_mask_contains(
+                 executable->loaded_physical_device_mask, device_ordinal))) {
+    return iree_make_status(IREE_STATUS_FAILED_PRECONDITION,
+                            "device ordinal %" PRIhsz
+                            " is not in executable queue affinity 0x%" PRIx64,
+                            device_ordinal, executable->queue_affinity);
+  }
+
+  const iree_host_size_t descriptor_ordinal =
+      device_ordinal * executable->kernel_count + export_ordinal;
+  *out_descriptor = &executable->host_dispatch_descriptors[descriptor_ordinal];
+  return iree_ok_status();
+}
+
 static iree_host_size_t iree_hal_amdgpu_executable_export_count(
     iree_hal_executable_t* base_executable) {
   iree_hal_amdgpu_executable_t* executable =
       iree_hal_amdgpu_executable_cast(base_executable);
-  // TODO(amdgpu): return the total number of exports in the executable.
-  (void)executable;
-  return 0;
+  return executable->kernel_count;
 }
 
 static iree_status_t iree_hal_amdgpu_executable_export_info(
@@ -983,10 +2234,15 @@
     iree_hal_executable_export_info_t* out_info) {
   iree_hal_amdgpu_executable_t* executable =
       iree_hal_amdgpu_executable_cast(base_executable);
-  (void)executable;
-  // TODO(amdgpu): return export information from kernel metadata.
-  return iree_make_status(IREE_STATUS_UNIMPLEMENTED,
-                          "reflection not implemented");
+  memset(out_info, 0, sizeof(*out_info));
+  if (IREE_UNLIKELY(export_ordinal >= executable->kernel_count)) {
+    return iree_make_status(IREE_STATUS_OUT_OF_RANGE,
+                            "export ordinal %" PRIu32
+                            " out of range; executable has %" PRIhsz " exports",
+                            export_ordinal, executable->kernel_count);
+  }
+  *out_info = executable->export_infos[export_ordinal];
+  return iree_ok_status();
 }
 
 static iree_status_t iree_hal_amdgpu_executable_export_parameters(
@@ -994,12 +2250,26 @@
     iree_hal_executable_export_ordinal_t export_ordinal,
     iree_host_size_t capacity,
     iree_hal_executable_export_parameter_t* out_parameters) {
+  IREE_ASSERT_ARGUMENT(out_parameters || capacity == 0);
   iree_hal_amdgpu_executable_t* executable =
       iree_hal_amdgpu_executable_cast(base_executable);
-  (void)executable;
-  // TODO(amdgpu): return export parameter information from kernel metadata.
-  return iree_make_status(IREE_STATUS_UNIMPLEMENTED,
-                          "parameter reflection not implemented");
+  if (IREE_UNLIKELY(export_ordinal >= executable->kernel_count)) {
+    return iree_make_status(IREE_STATUS_OUT_OF_RANGE,
+                            "export ordinal %" PRIu32
+                            " out of range; executable has %" PRIhsz " exports",
+                            export_ordinal, executable->kernel_count);
+  }
+  const iree_host_size_t parameter_begin =
+      executable->export_parameter_offsets[export_ordinal];
+  const iree_host_size_t parameter_end =
+      executable->export_parameter_offsets[export_ordinal + 1];
+  const iree_host_size_t parameter_count = parameter_end - parameter_begin;
+  const iree_host_size_t copy_count = iree_min(capacity, parameter_count);
+  if (copy_count > 0) {
+    memcpy(out_parameters, &executable->export_parameters[parameter_begin],
+           copy_count * sizeof(out_parameters[0]));
+  }
+  return iree_ok_status();
 }
 
 static iree_status_t iree_hal_amdgpu_executable_lookup_export_by_name(
@@ -1007,10 +2277,16 @@
     iree_hal_executable_export_ordinal_t* out_export_ordinal) {
   iree_hal_amdgpu_executable_t* executable =
       iree_hal_amdgpu_executable_cast(base_executable);
-  (void)executable;
-  // TODO(amdgpu): lookup the export ordinal by name.
-  return iree_make_status(IREE_STATUS_UNIMPLEMENTED,
-                          "reflection not implemented");
+  for (iree_host_size_t i = 0; i < executable->kernel_count; ++i) {
+    iree_string_view_t export_name = executable->export_infos[i].name;
+    if (iree_string_view_equal(export_name, name)) {
+      *out_export_ordinal = (iree_hal_executable_export_ordinal_t)i;
+      return iree_ok_status();
+    }
+  }
+  return iree_make_status(IREE_STATUS_NOT_FOUND,
+                          "export '%.*s' not found in executable",
+                          (int)name.size, name.data);
 }
 
 static const iree_hal_executable_vtable_t iree_hal_amdgpu_executable_vtable = {
diff --git a/runtime/src/iree/hal/drivers/amdgpu/executable.h b/runtime/src/iree/hal/drivers/amdgpu/executable.h
index a954f12..e9a38a0 100644
--- a/runtime/src/iree/hal/drivers/amdgpu/executable.h
+++ b/runtime/src/iree/hal/drivers/amdgpu/executable.h
@@ -9,11 +9,17 @@
 
 #include "iree/base/api.h"
 #include "iree/hal/api.h"
-#include "iree/hal/drivers/amdgpu/device/support/kernel_args.h"
+#include "iree/hal/drivers/amdgpu/abi/kernel_args.h"
+#include "iree/hal/drivers/amdgpu/device/dispatch.h"
+#include "iree/hal/drivers/amdgpu/profile_metadata.h"
 #include "iree/hal/drivers/amdgpu/util/libhsa.h"
 
 typedef struct iree_hal_amdgpu_topology_t iree_hal_amdgpu_topology_t;
 
+#ifdef __cplusplus
+extern "C" {
+#endif  // __cplusplus
+
 //===----------------------------------------------------------------------===//
 // ISA Support
 //===----------------------------------------------------------------------===//
@@ -27,13 +33,14 @@
     const iree_hal_amdgpu_libhsa_t* libhsa,
     const iree_hal_amdgpu_topology_t* topology);
 
-// Returns whether the canonical IREE HAL executable |format| is supported by
-// all GPU devices in |topology|. Some devices may support multiple ISAs.
+// Returns whether the IREE HAL executable |format| is supported by all GPU
+// devices in |topology|. Some devices may support multiple ISAs.
 //
-// To avoid creating yet another naming scheme we directly use the ISA names
-// reported by HSA, e.g. `amdgcn-amd-amdhsa--gfx1100`. It's not pretty, but it
-// is precise for this particular HAL and lets us avoid any potential runtime
-// changes if LLVM<->HSA naming changes with new code object versions.
+// Supports AMDGPU target IDs in both compiler spelling (`gfx1100`,
+// `gfx942:xnack-`) and the canonical ISA names reported by HSA
+// (`amdgcn-amd-amdhsa--gfx1100`). Matching uses structured target-ID
+// compatibility so generic code-object targets and explicit feature modes can
+// be checked without relying on string equality.
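+//
+// For example (an illustrative sketch of the target-ID compatibility rules):
+// a request for plain `gfx942` is expected to match an agent reporting
+// `amdgcn-amd-amdhsa--gfx942:sramecc+:xnack-`, while an explicit
+// `gfx942:xnack+` request would not match an `xnack-` agent.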
 //
 // Optionally |out_isa| can be used to get the agent ISA for the given format.
 // Note that this will be from the first device but should match all other
@@ -54,14 +61,41 @@
 // This is limited by the field size in iree_hal_amdgpu_device_kernel_args_t.
 #define IREE_HAL_AMDGPU_MAX_DISPATCH_CONSTANT_COUNT UINT16_MAX
 
+// Host-resident dispatch metadata precomputed for one executable export on one
+// physical device.
+//
+// Descriptors are immutable after executable creation and remain valid for the
+// lifetime of the executable. They intentionally duplicate the device-visible
+// kernel argument table in ordinary host memory so queue submission does not
+// read per-dispatch metadata from memory allocated for GPU visibility.
+typedef struct iree_hal_amdgpu_executable_dispatch_descriptor_t {
+  // Device-specific kernel arguments with a valid kernel_object for dispatch.
+  iree_hal_amdgpu_device_kernel_args_t kernel_args;
+  // HAL ABI kernarg layout derived from |kernel_args|.
+  iree_hal_amdgpu_device_dispatch_kernarg_layout_t hal_kernarg_layout;
+  // Custom direct-argument kernarg layout derived from |kernel_args|.
+  iree_hal_amdgpu_device_dispatch_kernarg_layout_t custom_kernarg_layout;
+  // Queue kernarg-ring block count for HAL ABI dispatches.
+  uint32_t hal_kernarg_block_count;
+  // Queue kernarg-ring block count for custom direct-argument dispatches.
+  uint32_t custom_kernarg_block_count;
+  // Maximum static workgroup count accepted for each dimension.
+  uint32_t max_workgroup_count[3];
+  // Maximum dynamic group-memory byte count accepted for this export.
+  uint32_t max_dynamic_workgroup_local_memory;
+} iree_hal_amdgpu_executable_dispatch_descriptor_t;
+
 // Infers the format of the executable and calculates its total size.
 // If executable_data.data_length is 0 attempts to infer size from the data.
-// Returns the canonical format string and total size of the executable data.
-// The format will be the ISA name like "amdgcn-amd-amdhsa--gfx1100".
+// Returns the canonical target-ID format string and total size of the
+// executable data.
+//
+// Wrapped AMDGPU flatbuffers infer the target ID from the embedded ELF image
+// instead of trusting the flatbuffer metadata target label.
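+//
+// Illustrative usage (a sketch; executable_test.cc exercises the real call):
+//   char format[64] = {0};
+//   iree_host_size_t total_size = 0;
+//   IREE_RETURN_IF_ERROR(iree_hal_amdgpu_executable_infer_format(
+//       executable_data, sizeof(format), format, host_allocator,
+//       &total_size));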
 iree_status_t iree_hal_amdgpu_executable_infer_format(
     iree_const_byte_span_t executable_data,
     iree_host_size_t executable_format_capacity, char* executable_format,
-    iree_host_size_t* out_inferred_size);
+    iree_allocator_t host_allocator, iree_host_size_t* out_inferred_size);
 
 // Creates an AMDGPU executable from a binary in memory. Each executable may
 // contain multiple entry points and be composed of several modules presented to
@@ -70,12 +104,22 @@
 //
 // |libhsa| and |topology| are captured by-reference and must remain valid for
 // the lifetime of the cache.
+//
+// Exact code-object image bytes and loader load ranges are retained in profile
+// metadata for offline trace/disassembly workflows. A trace profiling session
+// may begin after an executable has already been prepared, so this cold-path
+// metadata is retained unconditionally instead of being gated on an active
+// profiling session.
 iree_status_t iree_hal_amdgpu_executable_create(
     const iree_hal_amdgpu_libhsa_t* libhsa,
     const iree_hal_amdgpu_topology_t* topology,
     const iree_hal_executable_params_t* executable_params,
+    iree_hal_amdgpu_profile_metadata_registry_t* profile_metadata,
     iree_allocator_t host_allocator, iree_hal_executable_t** out_executable);
 
+// Returns the producer-local profile executable id assigned at creation.
+uint64_t iree_hal_amdgpu_executable_profile_id(
+    iree_hal_executable_t* executable);
+
 // Returns metadata about an exported kernel function in host memory.
 // The returned pointers will remain valid for the lifetime of the executable.
 // The returned kernel_object field is undefined in the returned args as there
@@ -89,11 +133,30 @@
 // Returns metadata about an exported kernel function in device memory.
 // Kernel arguments are specific to the physical device specified by
 // |device_ordinal| in the topology and cannot be used on any other device. The
-// returned pointers will remain valid for the lifetime of the executable.
+// lookup fails if the executable queue affinity did not include
+// |device_ordinal| at load time. The returned pointers will remain valid for
+// the lifetime of the executable.
 iree_status_t iree_hal_amdgpu_executable_lookup_kernel_args_for_device(
     iree_hal_executable_t* executable,
     iree_hal_executable_export_ordinal_t export_ordinal,
     iree_host_size_t device_ordinal,
     const iree_hal_amdgpu_device_kernel_args_t** out_kernel_args);
 
+// Returns host-resident dispatch metadata for an exported kernel function on a
+// physical device.
+//
+// The returned descriptor is specific to |device_ordinal| because the kernel
+// object embedded in the dispatch packet is per device. The lookup fails if the
+// executable queue affinity did not include |device_ordinal| at load time. The
+// pointer remains valid for the lifetime of the executable.
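+//
+// Minimal usage sketch (error propagation and dispatch packet construction
+// elided):
+//   const iree_hal_amdgpu_executable_dispatch_descriptor_t* descriptor = NULL;
+//   IREE_RETURN_IF_ERROR(
+//       iree_hal_amdgpu_executable_lookup_dispatch_descriptor_for_device(
+//           executable, export_ordinal, device_ordinal, &descriptor));
+//   // descriptor->kernel_args now carries the device-specific kernel_object.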
+iree_status_t iree_hal_amdgpu_executable_lookup_dispatch_descriptor_for_device(
+    iree_hal_executable_t* executable,
+    iree_hal_executable_export_ordinal_t export_ordinal,
+    iree_host_size_t device_ordinal,
+    const iree_hal_amdgpu_executable_dispatch_descriptor_t** out_descriptor);
+
+#ifdef __cplusplus
+}  // extern "C"
+#endif  // __cplusplus
+
 #endif  // IREE_HAL_DRIVERS_AMDGPU_EXECUTABLE_H_
diff --git a/runtime/src/iree/hal/drivers/amdgpu/executable_cache.c b/runtime/src/iree/hal/drivers/amdgpu/executable_cache.c
index bbaa992..9479c10 100644
--- a/runtime/src/iree/hal/drivers/amdgpu/executable_cache.c
+++ b/runtime/src/iree/hal/drivers/amdgpu/executable_cache.c
@@ -14,10 +14,16 @@
 //===----------------------------------------------------------------------===//
 
 typedef struct iree_hal_amdgpu_executable_cache_t {
+  // HAL resource header.
   iree_hal_resource_t resource;
+  // Host allocator used for cache lifetime.
   iree_allocator_t host_allocator;
+  // Borrowed HSA API table used for executable loading.
   const iree_hal_amdgpu_libhsa_t* libhsa;
+  // Borrowed topology describing the physical devices to load onto.
   const iree_hal_amdgpu_topology_t* topology;
+  // Borrowed logical-device profiling metadata registry.
+  iree_hal_amdgpu_profile_metadata_registry_t* profile_metadata;
 } iree_hal_amdgpu_executable_cache_t;
 
 static const iree_hal_executable_cache_vtable_t
@@ -31,8 +37,9 @@
 
 iree_status_t iree_hal_amdgpu_executable_cache_create(
     const iree_hal_amdgpu_libhsa_t* libhsa,
-    const iree_hal_amdgpu_topology_t* topology, iree_string_view_t identifier,
-    iree_allocator_t host_allocator,
+    const iree_hal_amdgpu_topology_t* topology,
+    iree_hal_amdgpu_profile_metadata_registry_t* profile_metadata,
+    iree_string_view_t identifier, iree_allocator_t host_allocator,
     iree_hal_executable_cache_t** out_executable_cache) {
   IREE_ASSERT_ARGUMENT(out_executable_cache);
   IREE_TRACE_ZONE_BEGIN(z0);
@@ -56,6 +63,7 @@
   executable_cache->host_allocator = host_allocator;
   executable_cache->libhsa = libhsa;
   executable_cache->topology = topology;
+  executable_cache->profile_metadata = profile_metadata;
 
   *out_executable_cache = (iree_hal_executable_cache_t*)executable_cache;
   IREE_TRACE_ZONE_END(z0);
@@ -80,10 +88,12 @@
     iree_const_byte_span_t executable_data,
     iree_host_size_t executable_format_capacity, char* executable_format,
     iree_host_size_t* out_inferred_size) {
+  iree_hal_amdgpu_executable_cache_t* executable_cache =
+      iree_hal_amdgpu_executable_cache_cast(base_executable_cache);
   IREE_TRACE_ZONE_BEGIN(z0);
   iree_status_t status = iree_hal_amdgpu_executable_infer_format(
       executable_data, executable_format_capacity, executable_format,
-      out_inferred_size);
+      executable_cache->host_allocator, out_inferred_size);
   IREE_TRACE_ZONE_END(z0);
   return status;
 }
@@ -95,9 +105,15 @@
   iree_hal_amdgpu_executable_cache_t* executable_cache =
       iree_hal_amdgpu_executable_cache_cast(base_executable_cache);
   bool is_supported = false;
-  IREE_IGNORE_ERROR(iree_hal_amdgpu_executable_format_supported(
+  iree_status_t status = iree_hal_amdgpu_executable_format_supported(
       executable_cache->libhsa, executable_cache->topology->gpu_agents[0],
-      executable_format, &is_supported, /*out_isa=*/NULL));
+      executable_format, &is_supported, /*out_isa=*/NULL);
+  if (!iree_status_is_ok(status)) {
+    // The HAL cache predicate has no status channel; query failures mean the
+    // format cannot be prepared by this cache.
+    iree_status_free(status);
+    return false;
+  }
   return is_supported;
 }
 
@@ -109,7 +125,8 @@
       iree_hal_amdgpu_executable_cache_cast(base_executable_cache);
   return iree_hal_amdgpu_executable_create(
       executable_cache->libhsa, executable_cache->topology, executable_params,
-      executable_cache->host_allocator, out_executable);
+      executable_cache->profile_metadata, executable_cache->host_allocator,
+      out_executable);
 }
 
 static const iree_hal_executable_cache_vtable_t
diff --git a/runtime/src/iree/hal/drivers/amdgpu/executable_cache.h b/runtime/src/iree/hal/drivers/amdgpu/executable_cache.h
index 5137a2a..d5d88f1 100644
--- a/runtime/src/iree/hal/drivers/amdgpu/executable_cache.h
+++ b/runtime/src/iree/hal/drivers/amdgpu/executable_cache.h
@@ -9,6 +9,7 @@
 
 #include "iree/base/api.h"
 #include "iree/hal/api.h"
+#include "iree/hal/drivers/amdgpu/profile_metadata.h"
 #include "iree/hal/drivers/amdgpu/util/libhsa.h"
 
 typedef struct iree_hal_amdgpu_topology_t iree_hal_amdgpu_topology_t;
@@ -24,10 +25,14 @@
 //
 // |libhsa| and |topology| are captured by-reference and must remain valid for
 // the lifetime of the cache.
+//
+// Exact code-object image bytes and loader load ranges are retained in profile
+// metadata for every prepared executable.
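+//
+// Illustrative creation call (identifier and allocator choices are
+// arbitrary):
+//   iree_hal_executable_cache_t* cache = NULL;
+//   IREE_RETURN_IF_ERROR(iree_hal_amdgpu_executable_cache_create(
+//       libhsa, topology, profile_metadata, IREE_SV("default"),
+//       host_allocator, &cache));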
 iree_status_t iree_hal_amdgpu_executable_cache_create(
     const iree_hal_amdgpu_libhsa_t* libhsa,
-    const iree_hal_amdgpu_topology_t* topology, iree_string_view_t identifier,
-    iree_allocator_t host_allocator,
+    const iree_hal_amdgpu_topology_t* topology,
+    iree_hal_amdgpu_profile_metadata_registry_t* profile_metadata,
+    iree_string_view_t identifier, iree_allocator_t host_allocator,
     iree_hal_executable_cache_t** out_executable_cache);
 
 #endif  // IREE_HAL_DRIVERS_AMDGPU_EXECUTABLE_CACHE_H_
diff --git a/runtime/src/iree/hal/drivers/amdgpu/executable_test.cc b/runtime/src/iree/hal/drivers/amdgpu/executable_test.cc
new file mode 100644
index 0000000..ca59526
--- /dev/null
+++ b/runtime/src/iree/hal/drivers/amdgpu/executable_test.cc
@@ -0,0 +1,185 @@
+// Copyright 2026 The IREE Authors
+//
+// Licensed under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+
+#include "iree/hal/drivers/amdgpu/executable.h"
+
+#include <array>
+#include <cstring>
+#include <string>
+#include <vector>
+
+#include "iree/base/alignment.h"
+#include "iree/base/api.h"
+#include "iree/base/internal/flatcc/building.h"
+#include "iree/schemas/amdgpu_executable_def_builder.h"
+#include "iree/testing/gtest.h"
+#include "iree/testing/status_matchers.h"
+
+namespace iree::hal::amdgpu {
+namespace {
+
+// ELF identification and AMDGPU-specific header values; the comments give the
+// corresponding names from LLVM's BinaryFormat/ELF.h.
+static constexpr uint8_t kElfClass64 = 2;            // ELFCLASS64
+static constexpr uint8_t kElfData2Lsb = 1;           // ELFDATA2LSB
+static constexpr uint8_t kElfVersionCurrent = 1;     // EV_CURRENT
+static constexpr uint8_t kElfOsAbiAmdgpuHsa = 64;    // ELFOSABI_AMDGPU_HSA
+static constexpr uint8_t kElfAbiVersionV5 = 3;       // ELFABIVERSION_AMDGPU_HSA_V5
+static constexpr uint16_t kElfMachineAmdgpu = 224;   // EM_AMDGPU
+static constexpr uint32_t kElfMachineGfx942 = 0x04c;  // EF_AMDGPU_MACH_AMDGCN_GFX942
+static constexpr uint32_t kElfFeatureXnackOffV4 = 0x200;   // EF_AMDGPU_FEATURE_XNACK_OFF_V4
+static constexpr uint32_t kElfFeatureSrameccOnV4 = 0xc00;  // EF_AMDGPU_FEATURE_SRAMECC_ON_V4
+
+static std::array<uint8_t, 64> MakeElf64AmdgpuHsa(uint8_t abi_version,
+                                                  uint16_t machine,
+                                                  uint32_t e_flags) {
+  std::array<uint8_t, 64> elf = {};
+  elf[0] = 0x7f;
+  elf[1] = 'E';
+  elf[2] = 'L';
+  elf[3] = 'F';
+  elf[4] = kElfClass64;
+  elf[5] = kElfData2Lsb;
+  elf[6] = kElfVersionCurrent;
+  elf[7] = kElfOsAbiAmdgpuHsa;
+  elf[8] = abi_version;
+  // ELF64 header field offsets: e_machine at 18, e_version at 20, e_flags at
+  // 48, e_ehsize at 52.
+  iree_unaligned_store_le_u16((uint16_t*)&elf[18], machine);
+  iree_unaligned_store_le_u32((uint32_t*)&elf[20], kElfVersionCurrent);
+  iree_unaligned_store_le_u32((uint32_t*)&elf[48], e_flags);
+  iree_unaligned_store_le_u16((uint16_t*)&elf[52], (uint16_t)elf.size());
+  return elf;
+}
+
+static iree_status_t MakeWrappedAmdgpuExecutable(
+    iree_string_view_t metadata_target_id, iree_const_byte_span_t code_object,
+    std::vector<uint8_t>* out_executable_data) {
+  IREE_ASSERT_ARGUMENT(out_executable_data);
+  out_executable_data->clear();
+
+  flatbuffers_builder_t builder;
+  if (IREE_UNLIKELY(flatcc_builder_init(&builder) != 0)) {
+    return iree_make_status(IREE_STATUS_RESOURCE_EXHAUSTED,
+                            "failed to initialize flatbuffer builder");
+  }
+
+  iree_status_t status = iree_ok_status();
+  if (IREE_UNLIKELY(flatbuffers_failed(
+          iree_hal_amdgpu_ExecutableDef_start_as_root(&builder)))) {
+    status = iree_make_status(IREE_STATUS_RESOURCE_EXHAUSTED,
+                              "failed to start AMDGPU executable flatbuffer");
+  }
+
+  flatbuffers_string_ref_t isa_ref = 0;
+  flatbuffers_string_ref_t image_ref = 0;
+  iree_hal_amdgpu_ModuleDef_ref_t module_ref = 0;
+  iree_hal_amdgpu_ModuleDef_vec_ref_t modules_ref = 0;
+  if (iree_status_is_ok(status)) {
+    isa_ref = flatbuffers_string_create(&builder, metadata_target_id.data,
+                                        metadata_target_id.size);
+    image_ref = flatbuffers_string_create(
+        &builder, (const char*)code_object.data, code_object.data_length);
+    if (IREE_UNLIKELY(!isa_ref || !image_ref)) {
+      status = iree_make_status(IREE_STATUS_RESOURCE_EXHAUSTED,
+                                "failed to create AMDGPU executable strings");
+    }
+  }
+  if (iree_status_is_ok(status)) {
+    module_ref = iree_hal_amdgpu_ModuleDef_create(&builder, image_ref);
+    modules_ref = iree_hal_amdgpu_ModuleDef_vec_create(&builder, &module_ref,
+                                                       /*len=*/1);
+    if (IREE_UNLIKELY(!module_ref || !modules_ref)) {
+      status =
+          iree_make_status(IREE_STATUS_RESOURCE_EXHAUSTED,
+                           "failed to create AMDGPU executable module vector");
+    }
+  }
+  if (iree_status_is_ok(status)) {
+    if (IREE_UNLIKELY(
+            flatbuffers_failed(
+                iree_hal_amdgpu_ExecutableDef_isa_add(&builder, isa_ref)) ||
+            flatbuffers_failed(iree_hal_amdgpu_ExecutableDef_modules_add(
+                &builder, modules_ref)))) {
+      status =
+          iree_make_status(IREE_STATUS_RESOURCE_EXHAUSTED,
+                           "failed to populate AMDGPU executable flatbuffer");
+    }
+  }
+  if (iree_status_is_ok(status) &&
+      IREE_UNLIKELY(!iree_hal_amdgpu_ExecutableDef_end_as_root(&builder))) {
+    status = iree_make_status(IREE_STATUS_RESOURCE_EXHAUSTED,
+                              "failed to finish AMDGPU executable flatbuffer");
+  }
+
+  size_t flatbuffer_size = 0;
+  void* flatbuffer_data = NULL;
+  if (iree_status_is_ok(status)) {
+    flatbuffer_data =
+        flatcc_builder_finalize_aligned_buffer(&builder, &flatbuffer_size);
+    if (IREE_UNLIKELY(!flatbuffer_data || flatbuffer_size == 0)) {
+      status =
+          iree_make_status(IREE_STATUS_RESOURCE_EXHAUSTED,
+                           "failed to finalize AMDGPU executable flatbuffer");
+    }
+  }
+
+  if (iree_status_is_ok(status)) {
+    iree_flatbuffer_file_header_t header = {};
+    memcpy(&header.magic, iree_hal_amdgpu_ExecutableDef_file_identifier,
+           sizeof(header.magic));
+    header.version = 0;
+    header.content_size = flatbuffer_size;
+
+    out_executable_data->resize(sizeof(header) + flatbuffer_size);
+    memcpy(out_executable_data->data(), &header, sizeof(header));
+    memcpy(out_executable_data->data() + sizeof(header), flatbuffer_data,
+           flatbuffer_size);
+  }
+
+  flatcc_builder_aligned_free(flatbuffer_data);
+  flatcc_builder_clear(&builder);
+  return status;
+}
+
+static std::string InferExecutableFormat(iree_const_byte_span_t executable_data,
+                                         iree_host_size_t* out_inferred_size) {
+  char executable_format[64] = {};
+  IREE_CHECK_OK(iree_hal_amdgpu_executable_infer_format(
+      executable_data, sizeof(executable_format), executable_format,
+      iree_allocator_system(), out_inferred_size));
+  return std::string(executable_format);
+}
+
+TEST(ExecutableTest, InfersRawHsacoTargetIdFromElfFlags) {
+  const auto elf = MakeElf64AmdgpuHsa(
+      kElfAbiVersionV5, kElfMachineAmdgpu,
+      kElfMachineGfx942 | kElfFeatureSrameccOnV4 | kElfFeatureXnackOffV4);
+
+  iree_host_size_t inferred_size = 0;
+  EXPECT_EQ(
+      InferExecutableFormat(iree_make_const_byte_span(elf.data(), elf.size()),
+                            &inferred_size),
+      "gfx942:sramecc+:xnack-");
+  EXPECT_EQ(inferred_size, elf.size());
+}
+
+TEST(ExecutableTest, InfersWrappedFlatbufferTargetIdFromEmbeddedElf) {
+  const auto elf = MakeElf64AmdgpuHsa(
+      kElfAbiVersionV5, kElfMachineAmdgpu,
+      kElfMachineGfx942 | kElfFeatureSrameccOnV4 | kElfFeatureXnackOffV4);
+  std::vector<uint8_t> executable_data;
+  IREE_ASSERT_OK(MakeWrappedAmdgpuExecutable(
+      IREE_SV("gfx1100"), iree_make_const_byte_span(elf.data(), elf.size()),
+      &executable_data));
+
+  iree_host_size_t inferred_size = 0;
+  EXPECT_EQ(
+      InferExecutableFormat(iree_make_const_byte_span(executable_data.data(),
+                                                      executable_data.size()),
+                            &inferred_size),
+      "gfx942:sramecc+:xnack-");
+  EXPECT_EQ(inferred_size, executable_data.size());
+}
+
+}  // namespace
+}  // namespace iree::hal::amdgpu
diff --git a/runtime/src/iree/hal/drivers/amdgpu/host_queue.c b/runtime/src/iree/hal/drivers/amdgpu/host_queue.c
index df50368..cc7d21d 100644
--- a/runtime/src/iree/hal/drivers/amdgpu/host_queue.c
+++ b/runtime/src/iree/hal/drivers/amdgpu/host_queue.c
@@ -1,4 +1,4 @@
-// Copyright 2025 The IREE Authors
+// Copyright 2026 The IREE Authors
 //
 // Licensed under the Apache License v2.0 with LLVM Exceptions.
 // See https://llvm.org/LICENSE.txt for license information.
@@ -6,256 +6,1303 @@
 
 #include "iree/hal/drivers/amdgpu/host_queue.h"
 
-//===----------------------------------------------------------------------===//
-// iree_hal_amdgpu_host_queue_t
-//===----------------------------------------------------------------------===//
+#include <stdio.h>
+#include <string.h>
+
+#include "iree/async/frontier_tracker.h"
+#include "iree/async/notification.h"
+#include "iree/base/threading/thread.h"
+#include "iree/hal/drivers/amdgpu/host_queue_blit.h"
+#include "iree/hal/drivers/amdgpu/host_queue_command_buffer.h"
+#include "iree/hal/drivers/amdgpu/host_queue_command_buffer_scratch.h"
+#include "iree/hal/drivers/amdgpu/host_queue_dispatch.h"
+#include "iree/hal/drivers/amdgpu/host_queue_file.h"
+#include "iree/hal/drivers/amdgpu/host_queue_host_call.h"
+#include "iree/hal/drivers/amdgpu/host_queue_memory.h"
+#include "iree/hal/drivers/amdgpu/host_queue_pending.h"
+#include "iree/hal/drivers/amdgpu/host_queue_policy.h"
+#include "iree/hal/drivers/amdgpu/host_queue_profile.h"
+#include "iree/hal/drivers/amdgpu/host_queue_profile_events.h"
+#include "iree/hal/drivers/amdgpu/host_queue_submission.h"
+#include "iree/hal/drivers/amdgpu/host_queue_waits.h"
+#include "iree/hal/drivers/amdgpu/semaphore.h"
+#include "iree/hal/drivers/amdgpu/transient_buffer.h"
+#include "iree/hal/drivers/amdgpu/util/pm4_emitter.h"
+#include "iree/hal/utils/resource_set.h"
 
 static const iree_hal_amdgpu_virtual_queue_vtable_t
     iree_hal_amdgpu_host_queue_vtable;
 
-typedef struct iree_hal_amdgpu_host_queue_t {
-  iree_hal_amdgpu_virtual_queue_t base;
-
-  // Optional callback issued when an asynchronous queue error occurs.
-  iree_hal_amdgpu_error_callback_t error_callback;
-} iree_hal_amdgpu_host_queue_t;
-
-static iree_hal_amdgpu_host_queue_t* iree_hal_amdgpu_host_queue_cast(
-    iree_hal_amdgpu_virtual_queue_t* virtual_queue) {
-  IREE_ASSERT_ARGUMENT(virtual_queue);
-  IREE_ASSERT_EQ(virtual_queue->vtable, &iree_hal_amdgpu_host_queue_vtable);
-  return (iree_hal_amdgpu_host_queue_t*)virtual_queue;
+static iree_status_t iree_hal_amdgpu_host_queue_allocate_pm4_ib_slots(
+    const iree_hal_amdgpu_libhsa_t* libhsa, hsa_agent_t gpu_agent,
+    hsa_amd_memory_pool_t pm4_ib_pool, uint32_t aql_queue_capacity,
+    iree_hal_amdgpu_host_queue_t* out_queue) {
+  IREE_TRACE_ZONE_BEGIN(z0);
+  IREE_TRACE_ZONE_APPEND_VALUE_I64(z0, aql_queue_capacity);
+  iree_host_size_t pm4_ib_size = 0;
+  IREE_RETURN_AND_END_ZONE_IF_ERROR(
+      z0, IREE_STRUCT_LAYOUT(
+              0, &pm4_ib_size,
+              IREE_STRUCT_FIELD(aql_queue_capacity,
+                                iree_hal_amdgpu_pm4_ib_slot_t, NULL)));
+  if (IREE_UNLIKELY(!pm4_ib_pool.handle)) {
+    IREE_RETURN_AND_END_ZONE_IF_ERROR(
+        z0, iree_make_status(IREE_STATUS_INVALID_ARGUMENT,
+                             "PM4 IB memory pool is required"));
+  }
+  IREE_TRACE_ZONE_APPEND_VALUE_I64(z0, pm4_ib_size);
+  iree_hal_amdgpu_pm4_ib_slot_t* pm4_ib_slots = NULL;
+  IREE_RETURN_AND_END_ZONE_IF_ERROR(
+      z0, iree_hsa_amd_memory_pool_allocate(
+              IREE_LIBHSA(libhsa), pm4_ib_pool, pm4_ib_size,
+              HSA_AMD_MEMORY_POOL_EXECUTABLE_FLAG, (void**)&pm4_ib_slots));
+  iree_status_t status = iree_hsa_amd_agents_allow_access(
+      IREE_LIBHSA(libhsa), /*num_agents=*/1, &gpu_agent, /*flags=*/NULL,
+      pm4_ib_slots);
+  if (iree_status_is_ok(status)) {
+    memset(pm4_ib_slots, 0, pm4_ib_size);
+    out_queue->pm4_ib_slots = pm4_ib_slots;
+  } else {
+    status = iree_status_join(status, iree_hsa_amd_memory_pool_free(
+                                          IREE_LIBHSA(libhsa), pm4_ib_slots));
+  }
+  IREE_TRACE_ZONE_END(z0);
+  return status;
 }
 
-iree_host_size_t iree_hal_amdgpu_host_queue_calculate_size(
-    const iree_hal_amdgpu_queue_options_t* options) {
-  IREE_ASSERT_EQ(options->placement, IREE_HAL_AMDGPU_QUEUE_PLACEMENT_HOST);
-  // TODO(benvanik): factor in dynamic sizes (execution queue count, etc).
-  return sizeof(iree_hal_amdgpu_host_queue_t);
+static void iree_hal_amdgpu_host_queue_reclaim_retired(
+    iree_hal_amdgpu_reclaim_entry_t* entry, uint64_t epoch, void* user_data) {
+  (void)epoch;
+  iree_hal_amdgpu_host_queue_t* queue =
+      (iree_hal_amdgpu_host_queue_t*)user_data;
+  iree_hal_amdgpu_profile_dispatch_event_reservation_t reservation = {
+      .first_event_position = entry->profile_event_first_position,
+      .event_count = entry->profile_event_count,
+  };
+  iree_hal_amdgpu_host_queue_retire_profile_dispatch_events(queue, reservation);
+  iree_hal_amdgpu_profile_queue_device_event_reservation_t
+      queue_device_reservation = {
+          .first_event_position = entry->queue_device_event_first_position,
+          .event_count = entry->queue_device_event_count,
+      };
+  iree_hal_amdgpu_host_queue_retire_profile_queue_device_events(
+      queue, queue_device_reservation);
+}
+
+static void iree_hal_amdgpu_host_queue_reclaim_queue_owned_positions(
+    iree_hal_amdgpu_host_queue_t* queue,
+    iree_hal_amdgpu_reclaim_positions_t reclaim_positions) {
+  if (reclaim_positions.kernarg_write_position > 0) {
+    iree_hal_amdgpu_kernarg_ring_reclaim(
+        &queue->kernarg_ring, reclaim_positions.kernarg_write_position);
+  }
+  if (reclaim_positions.queue_upload_write_position > 0) {
+    IREE_ASSERT(queue->queue_upload_ring.base,
+                "queue upload bytes retired without an initialized upload "
+                "ring");
+    iree_hal_amdgpu_queue_upload_ring_reclaim(
+        &queue->queue_upload_ring,
+        reclaim_positions.queue_upload_write_position);
+  }
+}
+
+//===----------------------------------------------------------------------===//
+// Initialization / deinitialization
+//===----------------------------------------------------------------------===//
+
+void iree_hal_amdgpu_host_queue_enqueue_post_drain_action(
+    iree_hal_amdgpu_host_queue_t* queue,
+    iree_hal_amdgpu_host_queue_post_drain_action_t* action,
+    iree_hal_amdgpu_host_queue_post_drain_fn_t fn, void* user_data) {
+  action->next = NULL;
+  action->fn = fn;
+  action->user_data = user_data;
+
+  iree_slim_mutex_lock(&queue->locks.post_drain_mutex);
+  if (queue->post_drain.tail) {
+    queue->post_drain.tail->next = action;
+  } else {
+    queue->post_drain.head = action;
+  }
+  queue->post_drain.tail = action;
+  iree_slim_mutex_unlock(&queue->locks.post_drain_mutex);
+}
+
+static void iree_hal_amdgpu_host_queue_run_post_drain_actions(
+    iree_hal_amdgpu_host_queue_t* queue) {
+  iree_slim_mutex_lock(&queue->locks.post_drain_mutex);
+  iree_hal_amdgpu_host_queue_post_drain_action_t* action =
+      queue->post_drain.head;
+  queue->post_drain.head = NULL;
+  queue->post_drain.tail = NULL;
+  iree_slim_mutex_unlock(&queue->locks.post_drain_mutex);
+
+  while (action) {
+    iree_hal_amdgpu_host_queue_post_drain_action_t* next_action = action->next;
+    action->next = NULL;
+    action->fn(action->user_data);
+    action = next_action;
+  }
+}
+
+// Drains completed notification entries, reclaims queue-owned ring space
+// (kernargs and uploads), and runs any queued post-drain actions. If the GPU
+// queue has faulted (error_status is set), fails all pending entries instead
+// of draining normally.
+static iree_host_size_t iree_hal_amdgpu_host_queue_drain_completions(
+    iree_hal_amdgpu_host_queue_t* queue) {
+  // Check for GPU queue error (set by the HSA error callback on another
+  // thread). If the queue has faulted, no further epochs will advance;
+  // fail all pending entries so waiters get the actual GPU error instead
+  // of hanging or timing out.
+  iree_status_t error = (iree_status_t)iree_atomic_load(
+      &queue->error_status, iree_memory_order_acquire);
+  const uint64_t previous_epoch = (uint64_t)iree_atomic_load(
+      &queue->notification_ring.epoch.last_drained, iree_memory_order_relaxed);
+  iree_hal_amdgpu_reclaim_positions_t reclaim_positions = {0};
+  iree_host_size_t count = 0;
+  if (IREE_UNLIKELY(error)) {
+    count = iree_hal_amdgpu_notification_ring_fail_all_reclaim_positions(
+        &queue->notification_ring, error, &reclaim_positions);
+    iree_hal_amdgpu_host_queue_clear_profile_events(queue);
+    iree_async_frontier_tracker_fail_axis(
+        queue->frontier_tracker, queue->axis,
+        iree_status_from_code(iree_status_code(error)));
+  } else {
+    count = iree_hal_amdgpu_notification_ring_drain_reclaim_positions(
+        &queue->notification_ring,
+        /*fallback_frontier=*/NULL, iree_hal_amdgpu_host_queue_reclaim_retired,
+        queue, &reclaim_positions);
+    const uint64_t current_epoch =
+        (uint64_t)iree_atomic_load(&queue->notification_ring.epoch.last_drained,
+                                   iree_memory_order_acquire);
+    if (current_epoch > previous_epoch) {
+      iree_async_frontier_tracker_advance(queue->frontier_tracker, queue->axis,
+                                          current_epoch);
+    }
+  }
+  iree_hal_amdgpu_host_queue_reclaim_queue_owned_positions(queue,
+                                                           reclaim_positions);
+  iree_hal_amdgpu_host_queue_run_post_drain_actions(queue);
+  return count;
+}
+
+static bool iree_hal_amdgpu_host_queue_has_error(
+    iree_hal_amdgpu_host_queue_t* queue) {
+  return iree_atomic_load(&queue->error_status, iree_memory_order_acquire) != 0;
+}
+
+static bool iree_hal_amdgpu_host_queue_store_error(
+    iree_hal_amdgpu_host_queue_t* queue, iree_status_t error) {
+  intptr_t expected = 0;
+  if (iree_atomic_compare_exchange_strong(
+          &queue->error_status, &expected, (intptr_t)error,
+          iree_memory_order_release, iree_memory_order_acquire)) {
+    return true;
+  }
+  iree_status_free(error);
+  return false;
+}
+
+static void iree_hal_amdgpu_host_queue_request_completion_thread_stop(
+    iree_hal_amdgpu_host_queue_t* queue) {
+  if (queue->completion.stop_signal.handle) {
+    iree_hsa_signal_store_screlease(IREE_LIBHSA(queue->libhsa),
+                                    queue->completion.stop_signal, 1);
+  }
+}
+
+static hsa_signal_value_t iree_hal_amdgpu_host_queue_last_drained_signal_value(
+    iree_hal_amdgpu_host_queue_t* queue) {
+  const uint64_t last_drained_epoch = (uint64_t)iree_atomic_load(
+      &queue->notification_ring.epoch.last_drained, iree_memory_order_acquire);
+  return (hsa_signal_value_t)(IREE_HAL_AMDGPU_EPOCH_INITIAL_VALUE -
+                              last_drained_epoch);
+}
+
+// Completion thread entry point. Blocks in HSA until either the queue epoch
+// signal changes or teardown/error raises the stop signal. Epoch wakeups
+// drain normally; stop/error wakeups perform one final drain/fail before exit.
+static int iree_hal_amdgpu_host_queue_completion_thread_main(void* entry_arg) {
+  {
+    IREE_TRACE_ZONE_BEGIN_NAMED(
+        z0, "iree_hal_amdgpu_host_queue_completion_thread_start");
+    IREE_TRACE_ZONE_END(z0);
+  }
+  iree_hal_amdgpu_host_queue_t* queue =
+      (iree_hal_amdgpu_host_queue_t*)entry_arg;
+
+  enum {
+    IREE_HAL_AMDGPU_COMPLETION_WAIT_EPOCH_SIGNAL = 0,
+    IREE_HAL_AMDGPU_COMPLETION_WAIT_STOP_SIGNAL = 1,
+    IREE_HAL_AMDGPU_COMPLETION_WAIT_SIGNAL_COUNT = 2,
+  };
+
+  hsa_signal_t epoch_signal =
+      iree_hal_amdgpu_notification_ring_epoch_signal(&queue->notification_ring);
+  hsa_signal_t stop_signal = queue->completion.stop_signal;
+  hsa_signal_value_t last_epoch_value =
+      iree_hal_amdgpu_host_queue_last_drained_signal_value(queue);
+
+  bool keep_running = true;
+  while (keep_running) {
+    hsa_signal_t signals[IREE_HAL_AMDGPU_COMPLETION_WAIT_SIGNAL_COUNT] = {
+        epoch_signal,
+        stop_signal,
+    };
+    hsa_signal_condition_t
+        conditions[IREE_HAL_AMDGPU_COMPLETION_WAIT_SIGNAL_COUNT] = {
+            HSA_SIGNAL_CONDITION_NE,
+            HSA_SIGNAL_CONDITION_NE,
+        };
+    hsa_signal_value_t values[IREE_HAL_AMDGPU_COMPLETION_WAIT_SIGNAL_COUNT] = {
+        last_epoch_value,
+        0,
+    };
+    const uint32_t signal_index = iree_hsa_amd_signal_wait_any(
+        IREE_LIBHSA(queue->libhsa),
+        IREE_HAL_AMDGPU_COMPLETION_WAIT_SIGNAL_COUNT, signals, conditions,
+        values, UINT64_MAX, HSA_WAIT_STATE_BLOCKED,
+        /*satisfying_value=*/NULL);
+
+    {
+      IREE_TRACE_ZONE_BEGIN_NAMED(
+          z0, "iree_hal_amdgpu_host_queue_completion_thread_pump");
+      IREE_TRACE_ZONE_APPEND_VALUE_I64(z0, signal_index);
+
+      if (signal_index == IREE_HAL_AMDGPU_COMPLETION_WAIT_EPOCH_SIGNAL) {
+        iree_hal_amdgpu_host_queue_drain_completions(queue);
+        // Arm the next wait from the epoch we actually drained, not from a raw
+        // HSA signal load. A GPU completion can race with the drain and update
+        // the signal after drain() sampled it; observing that newer value here
+        // would mark an undrained epoch as already seen and could sleep forever
+        // with a user semaphore still pending.
+        last_epoch_value =
+            iree_hal_amdgpu_host_queue_last_drained_signal_value(queue);
+      }
+
+      if (signal_index == IREE_HAL_AMDGPU_COMPLETION_WAIT_STOP_SIGNAL ||
+          iree_hal_amdgpu_host_queue_has_error(queue)) {
+        iree_hal_amdgpu_host_queue_drain_completions(queue);
+        keep_running = false;
+      } else if (IREE_UNLIKELY(signal_index >=
+                               IREE_HAL_AMDGPU_COMPLETION_WAIT_SIGNAL_COUNT)) {
+        iree_status_t error = iree_make_status(
+            IREE_STATUS_INTERNAL,
+            "hsa_amd_signal_wait_any returned invalid signal index %u",
+            signal_index);
+        iree_hal_amdgpu_host_queue_store_error(queue, error);
+        iree_hal_amdgpu_host_queue_drain_completions(queue);
+        keep_running = false;
+      }
+
+      IREE_TRACE_ZONE_END(z0);
+    }
+  }
+
+  {
+    IREE_TRACE_ZONE_BEGIN_NAMED(
+        z0, "iree_hal_amdgpu_host_queue_completion_thread_exit");
+    IREE_TRACE_ZONE_END(z0);
+  }
+  return 0;
+}
+
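
The re-arm comment inside the pump above guards against a classic lost-wakeup. A condensed sketch of the safe pattern, with `last_drained_epoch`/`signal_value_for_epoch`/`drain_completions` as hypothetical stand-ins for the notification-ring state the driver actually uses:

```c
#include <hsa/hsa.h>
#include <stdint.h>

// Hypothetical stand-ins for the queue's notification-ring state.
extern uint64_t last_drained_epoch(void);  // epoch the drain has consumed
extern hsa_signal_value_t signal_value_for_epoch(uint64_t epoch);
extern void drain_completions(void);       // retires entries, advances epoch

static void completion_loop(hsa_signal_t epoch_signal) {
  hsa_signal_value_t armed = signal_value_for_epoch(last_drained_epoch());
  for (;;) {
    // Wakes when the signal no longer equals |armed| (a newer epoch retired).
    hsa_signal_wait_scacquire(epoch_signal, HSA_SIGNAL_CONDITION_NE, armed,
                              /*timeout_hint=*/UINT64_MAX,
                              HSA_WAIT_STATE_BLOCKED);
    drain_completions();
    // Re-arm from the epoch the drain consumed, never from a fresh signal
    // load: a completion racing the drain could publish a newer value that a
    // raw load would treat as already seen, sleeping past a pending epoch.
    armed = signal_value_for_epoch(last_drained_epoch());
  }
}
```
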
+// HSA queue error callback. Called by the HSA runtime (on an internal thread)
+// when the queue encounters an unrecoverable error (page fault, invalid AQL
+// packet, ECC error). Stores the error atomically on the queue so the
+// completion thread can fail pending semaphores with the actual GPU error.
+static void iree_hal_amdgpu_host_queue_error_callback(hsa_status_t status,
+                                                      hsa_queue_t* source,
+                                                      void* data) {
+  iree_hal_amdgpu_host_queue_t* queue = (iree_hal_amdgpu_host_queue_t*)data;
+
+  // Convert the HSA error to an IREE status with diagnostic information.
+  iree_status_t error = iree_status_from_hsa_status(
+      __FILE__, __LINE__, status, "hsa_queue_error_callback",
+      "GPU queue encountered an unrecoverable error");
+
+  // First-error-wins: store the error with release semantics so the status
+  // payload (heap-allocated string, backtrace) is visible to any thread that
+  // loads with acquire. If another error already won the race, free ours.
+  if (iree_hal_amdgpu_host_queue_store_error(queue, error)) {
+    iree_hal_amdgpu_host_queue_request_completion_thread_stop(queue);
+  }
 }
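
The first-error-wins protocol shared by the callback and `iree_hal_amdgpu_host_queue_store_error` is small enough to show standalone. A sketch using C11 atomics in place of the `iree_atomic_*` wrappers; the driver variant additionally wakes the completion thread on a winning store so pending semaphores fail with the real GPU error:

```c
#include <stdatomic.h>
#include <stdbool.h>
#include <stdint.h>

// One error slot per queue: 0 means healthy, otherwise an owned status
// pointer. The release on success publishes the status payload (message,
// backtrace) to any thread that later loads with acquire.
typedef struct {
  _Atomic intptr_t error_status;
} fault_slot_t;

// Returns true if |error| was installed and ownership transferred; on false
// another error already won the race and the caller must free |error|.
static bool fault_slot_try_store(fault_slot_t* slot, intptr_t error) {
  intptr_t expected = 0;
  return atomic_compare_exchange_strong_explicit(
      &slot->error_status, &expected, error, memory_order_release,
      memory_order_acquire);
}
```
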
 
 iree_status_t iree_hal_amdgpu_host_queue_initialize(
-    iree_hal_amdgpu_system_t* system, iree_hal_amdgpu_queue_options_t options,
-    hsa_agent_t device_agent, iree_host_size_t device_ordinal,
-    iree_hal_amdgpu_host_service_t* host_service,
-    iree_arena_block_pool_t* host_block_pool,
-    iree_hal_amdgpu_block_allocators_t* block_allocators,
-    iree_hal_amdgpu_buffer_pool_t* buffer_pool,
-    iree_hal_amdgpu_error_callback_t error_callback,
-    hsa_signal_t initialization_signal, iree_allocator_t host_allocator,
-    iree_hal_amdgpu_virtual_queue_t* out_queue) {
-  IREE_ASSERT_ARGUMENT(system);
-  IREE_ASSERT_EQ(options.placement, IREE_HAL_AMDGPU_QUEUE_PLACEMENT_HOST);
-  IREE_ASSERT_ARGUMENT(host_service);
-  IREE_ASSERT_ARGUMENT(host_block_pool);
-  IREE_ASSERT_ARGUMENT(block_allocators);
-  IREE_ASSERT_ARGUMENT(buffer_pool);
+    const iree_hal_amdgpu_libhsa_t* libhsa, iree_hal_device_t* logical_device,
+    iree_async_proactor_t* proactor, hsa_agent_t gpu_agent,
+    const iree_hal_amdgpu_kernarg_ring_memory_t* kernarg_memory,
+    hsa_amd_memory_pool_t pm4_ib_pool,
+    iree_async_frontier_tracker_t* frontier_tracker, iree_async_axis_t axis,
+    iree_hal_queue_affinity_t queue_affinity,
+    iree_thread_affinity_t completion_thread_affinity,
+    iree_hal_amdgpu_wait_barrier_strategy_t wait_barrier_strategy,
+    iree_hal_amdgpu_vendor_packet_capability_flags_t vendor_packet_capabilities,
+    iree_hal_amdgpu_epoch_signal_table_t* epoch_table,
+    iree_arena_block_pool_t* block_pool,
+    iree_hal_amdgpu_block_pool_t* profiling_signal_block_pool,
+    const iree_hal_amdgpu_device_buffer_transfer_context_t* transfer_context,
+    const iree_hal_pool_set_t* default_pool_set, iree_hal_pool_t* default_pool,
+    iree_hal_amdgpu_transient_buffer_pool_t* transient_buffer_pool,
+    iree_hal_amdgpu_staging_pool_t* staging_pool,
+    iree_host_size_t device_ordinal, uint32_t aql_queue_capacity,
+    uint32_t notification_capacity, uint32_t kernarg_capacity_in_blocks,
+    uint32_t upload_capacity, iree_allocator_t host_allocator,
+    iree_hal_amdgpu_host_queue_t* out_queue) {
+  IREE_ASSERT_ARGUMENT(libhsa);
+  IREE_ASSERT_ARGUMENT(logical_device);
+  IREE_ASSERT_ARGUMENT(proactor);
+  IREE_ASSERT_ARGUMENT(kernarg_memory);
+  IREE_ASSERT_ARGUMENT(frontier_tracker);
+  IREE_ASSERT_ARGUMENT(epoch_table);
+  IREE_ASSERT_ARGUMENT(block_pool);
+  IREE_ASSERT_ARGUMENT(profiling_signal_block_pool);
+  IREE_ASSERT_ARGUMENT(transfer_context);
+  IREE_ASSERT_ARGUMENT(default_pool_set);
+  IREE_ASSERT_ARGUMENT(default_pool);
+  IREE_ASSERT_ARGUMENT(transient_buffer_pool);
   IREE_ASSERT_ARGUMENT(out_queue);
+
+  if (!iree_host_size_is_power_of_two(aql_queue_capacity) ||
+      !iree_host_size_is_power_of_two(notification_capacity) ||
+      !iree_host_size_is_power_of_two(kernarg_capacity_in_blocks) ||
+      (upload_capacity != 0 &&
+       !iree_host_size_is_power_of_two(upload_capacity))) {
+    return iree_make_status(IREE_STATUS_INVALID_ARGUMENT,
+                            "all enabled capacities must be powers of two");
+  }
+  if (kernarg_capacity_in_blocks / 2u < aql_queue_capacity) {
+    return iree_make_status(
+        IREE_STATUS_INVALID_ARGUMENT,
+        "kernarg ring capacity must be at least 2x the AQL ring capacity "
+        "to cover one tail-padding gap at wrap (got kernarg_blocks=%u, "
+        "aql_packets=%u)",
+        kernarg_capacity_in_blocks, aql_queue_capacity);
+  }
+
   IREE_TRACE_ZONE_BEGIN(z0);
-  IREE_TRACE_ZONE_APPEND_VALUE_I64(z0, device_ordinal);
 
-  iree_hal_amdgpu_host_queue_t* queue =
-      (iree_hal_amdgpu_host_queue_t*)out_queue;
-  queue->base.vtable = &iree_hal_amdgpu_host_queue_vtable;
-  queue->error_callback = error_callback;
+  memset(out_queue, 0, sizeof(*out_queue));
+  out_queue->base.vtable = &iree_hal_amdgpu_host_queue_vtable;
+  out_queue->libhsa = libhsa;
+  out_queue->logical_device = logical_device;
+  out_queue->proactor = proactor;
+  out_queue->frontier_tracker = frontier_tracker;
+  out_queue->host_allocator = host_allocator;
 
-  // TODO(benvanik): implement the host queue.
-  iree_status_t status = iree_make_status(
-      IREE_STATUS_UNIMPLEMENTED, "host-side queuing not yet implemented");
+  // Submission pipeline state.
+  iree_slim_mutex_initialize(&out_queue->locks.submission_mutex);
+  iree_slim_mutex_initialize(&out_queue->locks.post_drain_mutex);
+  iree_slim_mutex_initialize(&out_queue->profiling.event_mutex);
+  out_queue->profiling.signals.block_pool = profiling_signal_block_pool;
+  out_queue->axis = axis;
+  out_queue->wait_barrier_strategy = wait_barrier_strategy;
+  out_queue->vendor_packet_capabilities = vendor_packet_capabilities;
+  out_queue->queue_affinity = queue_affinity;
+  out_queue->last_signal.semaphore = NULL;
+  out_queue->last_signal.epoch = 0;
+  out_queue->block_pool = block_pool;
+  out_queue->can_publish_frontier = true;
+  out_queue->transfer_context = transfer_context;
+  out_queue->default_pool_set = default_pool_set;
+  out_queue->default_pool = default_pool;
+  out_queue->transient_buffer_pool = transient_buffer_pool;
+  out_queue->staging_pool = staging_pool;
+  out_queue->device_ordinal = device_ordinal;
+  out_queue->pending_head = NULL;
+  iree_async_frontier_initialize(iree_hal_amdgpu_host_queue_frontier(out_queue),
+                                 /*entry_count=*/0);
+
+  // The optional tracker semaphore is an iree_async_semaphore_t bridge for
+  // CPU-side wait integration. The queue's GPU-visible HSA epoch signal is
+  // created by the notification ring below and registered in the epoch table.
+  iree_status_t status = iree_async_frontier_tracker_register_axis(
+      frontier_tracker, axis, /*semaphore=*/NULL);
+
+  // Create the host-only stop signal before the hardware queue so the HSA error
+  // callback always has a valid signal to wake if queue creation races with an
+  // asynchronous fault.
+  if (iree_status_is_ok(status)) {
+    status = iree_hsa_amd_signal_create(
+        IREE_LIBHSA(libhsa), /*initial_value=*/0,
+        /*num_consumers=*/0, /*consumers=*/NULL, /*attributes=*/0,
+        &out_queue->completion.stop_signal);
+  }
+
+  // Create the HSA hardware AQL queue.
+  //
+  // HSA_QUEUE_TYPE_MULTI is required (not just an optimization). Once command
+  // buffers start performing device-side enqueue, the CP itself becomes a
+  // concurrent producer alongside the host submission path, so the queue must
+  // permit multiple concurrent producers. The host-side reserve already uses
+  // an atomic fetch_add on the write index, which is well-defined only on
+  // MULTI queues.
+  hsa_queue_t* hardware_queue = NULL;
+  if (iree_status_is_ok(status)) {
+    status = iree_hsa_queue_create(
+        IREE_LIBHSA(libhsa), gpu_agent, aql_queue_capacity,
+        HSA_QUEUE_TYPE_MULTI, iree_hal_amdgpu_host_queue_error_callback,
+        /*data=*/out_queue,
+        /*private_segment_size=*/UINT32_MAX,
+        /*group_segment_size=*/UINT32_MAX, &hardware_queue);
+  }
+
+  // Initialize the AQL ring from the hardware queue.
+  if (iree_status_is_ok(status)) {
+    out_queue->hardware_queue = hardware_queue;
+    iree_hal_amdgpu_aql_ring_initialize((iree_amd_queue_t*)hardware_queue,
+                                        &out_queue->aql_ring);
+  }
+
+  // Initialize the kernarg ring from the selected HSA memory pool.
+  if (iree_status_is_ok(status)) {
+    status = iree_hal_amdgpu_kernarg_ring_initialize(libhsa, kernarg_memory,
+                                                     kernarg_capacity_in_blocks,
+                                                     &out_queue->kernarg_ring);
+  }
+
+  // Initialize the optional queue-control upload ring from the same
+  // host-visible memory policy as queue-owned kernargs. A zero capacity keeps
+  // future device-side fixup storage opt-in and avoids charging every queue for
+  // an unused allocation.
+  if (iree_status_is_ok(status) && upload_capacity != 0) {
+    const iree_hal_amdgpu_queue_upload_ring_memory_t upload_memory = {
+        .memory_pool = kernarg_memory->memory_pool,
+        .access_agents = kernarg_memory->access_agents,
+        .access_agent_count = kernarg_memory->access_agent_count,
+        .publication = kernarg_memory->publication,
+    };
+    status = iree_hal_amdgpu_queue_upload_ring_initialize(
+        libhsa, &upload_memory, upload_capacity, &out_queue->queue_upload_ring);
+  }
+
+  // Initialize the optional PM4 IB slot buffer. Capability-driven allocation
+  // keeps dynamic PM4 storage available on CDNA queues that use BARRIER_VALUE
+  // for waits but still support AQL PM4-IB snippets for other features. The
+  // buffer is indexed by AQL packet id and inherits AQL ring
+  // backpressure/reuse; there is no separate PM4 producer or reclaim position.
+  if (iree_status_is_ok(status) &&
+      (vendor_packet_capabilities &
+       IREE_HAL_AMDGPU_VENDOR_PACKET_CAPABILITY_AQL_PM4_IB)) {
+    status = iree_hal_amdgpu_host_queue_allocate_pm4_ib_slots(
+        libhsa, gpu_agent, pm4_ib_pool, aql_queue_capacity, out_queue);
+  }
+
+  // Initialize the notification ring (creates epoch signal + entry buffer).
+  if (iree_status_is_ok(status)) {
+    status = iree_hal_amdgpu_notification_ring_initialize(
+        libhsa, block_pool, notification_capacity, host_allocator,
+        &out_queue->notification_ring);
+  }
+
+  // Register this queue's epoch signal in the shared table for cross-queue
+  // barrier emission lookups. Must happen after notification ring init (which
+  // creates the epoch signal) and before any submissions.
+  if (iree_status_is_ok(status)) {
+    iree_hal_amdgpu_epoch_signal_table_register(
+        epoch_table, iree_async_axis_device_index(axis),
+        iree_async_axis_queue_index(axis),
+        iree_hal_amdgpu_notification_ring_epoch_signal(
+            &out_queue->notification_ring));
+    out_queue->epoch_table = epoch_table;
+  }
+
+  if (iree_status_is_ok(status)) {
+    iree_thread_create_params_t thread_params;
+    memset(&thread_params, 0, sizeof(thread_params));
+    char thread_name[32] = {0};
+    snprintf(thread_name, IREE_ARRAYSIZE(thread_name),
+             "iree-hal-amdgpu-l0p%uq%u-complete",
+             (unsigned)iree_async_axis_device_index(axis),
+             (unsigned)iree_async_axis_queue_index(axis));
+    thread_params.name = iree_make_cstring_view(thread_name);
+    thread_params.initial_affinity = completion_thread_affinity;
+    status = iree_thread_create(
+        iree_hal_amdgpu_host_queue_completion_thread_main, out_queue,
+        thread_params, host_allocator, &out_queue->completion.thread);
+  }
+  if (!iree_status_is_ok(status)) {
+    iree_hal_amdgpu_host_queue_deinitialize(out_queue);
+  }
 
   IREE_TRACE_ZONE_END(z0);
   return status;
 }
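
The MULTI-queue rationale in `iree_hal_amdgpu_host_queue_initialize` boils down to the reservation protocol below. An illustrative single-packet producer using the raw HSA API rather than the driver's batched AQL ring helpers; the spin-wait stands in for real signal-based backpressure:

```c
#include <hsa/hsa.h>
#include <stdint.h>
#include <string.h>

// Submits one kernel dispatch to an HSA_QUEUE_TYPE_MULTI queue. |packet| must
// carry an HSA_PACKET_TYPE_INVALID header; the real header is published last.
static void submit_one_dispatch(hsa_queue_t* queue,
                                const hsa_kernel_dispatch_packet_t* packet) {
  // Claim a unique packet id: an atomic fetch-add on the write index, which
  // is only well-defined with multiple producers on MULTI queues.
  const uint64_t packet_id = hsa_queue_add_write_index_screlease(queue, 1);

  // Backpressure: wait until the CP has retired the slot we wrapped onto.
  while (packet_id - hsa_queue_load_read_index_scacquire(queue) >=
         queue->size) {
    // Spin; the driver blocks on completion signals instead.
  }

  hsa_kernel_dispatch_packet_t* slot =
      (hsa_kernel_dispatch_packet_t*)queue->base_address +
      (packet_id & (queue->size - 1));
  memcpy(slot, packet, sizeof(*slot));  // body first (header still INVALID)

  // Publish the header last with release semantics so the CP never observes
  // a half-written packet. Fence-scope bits omitted for brevity.
  const uint16_t header =
      (uint16_t)(HSA_PACKET_TYPE_KERNEL_DISPATCH << HSA_PACKET_HEADER_TYPE);
  __atomic_store_n(&slot->header, header, __ATOMIC_RELEASE);

  // Ring the doorbell so the packet processor advances to |packet_id|.
  hsa_signal_store_screlease(queue->doorbell_signal,
                             (hsa_signal_value_t)packet_id);
}
```
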
 
-static void iree_hal_amdgpu_host_queue_deinitialize(
-    iree_hal_amdgpu_virtual_queue_t* virtual_queue) {
-  iree_hal_amdgpu_host_queue_t* queue =
-      iree_hal_amdgpu_host_queue_cast(virtual_queue);
+void iree_hal_amdgpu_host_queue_deinitialize(
+    iree_hal_amdgpu_host_queue_t* queue) {
+  IREE_ASSERT_ARGUMENT(queue);
   IREE_TRACE_ZONE_BEGIN(z0);
 
-  (void)queue;
+  iree_slim_mutex_lock(&queue->locks.submission_mutex);
+  queue->is_shutting_down = true;
+  iree_slim_mutex_unlock(&queue->locks.submission_mutex);
+
+  if (queue->completion.thread) {
+    iree_hal_amdgpu_host_queue_request_completion_thread_stop(queue);
+    // The queue holds the only reference to the thread, so releasing it also
+    // joins the thread before returning.
+    iree_thread_release(queue->completion.thread);
+    queue->completion.thread = NULL;
+  }
+  }
+
+  // Destroy the hardware queue before the remaining host-side resources so the
+  // HSA runtime cannot race a late error callback against signal teardown.
+  if (queue->hardware_queue) {
+    iree_hal_amdgpu_hsa_cleanup_assert_success(
+        iree_hsa_queue_destroy_raw(queue->libhsa, queue->hardware_queue));
+    queue->hardware_queue = NULL;
+  }
+
+  // Capacity-parked pending ops are retried by post-drain callbacks. Flush
+  // those callbacks under shutdown first so they observe cancellation and fail
+  // through their normal path instead of having their callback storage
+  // destroyed out from under them.
+  iree_hal_amdgpu_host_queue_run_post_drain_actions(queue);
+
+  // Cancel all pending (deferred) operations. Their signal semaphores are
+  // failed with CANCELLED so downstream waiters don't hang.
+  if (queue->pending_head) {
+    iree_hal_amdgpu_host_queue_cancel_pending(queue, IREE_STATUS_CANCELLED,
+                                              "queue shutting down");
+  }
+
+  // Process any remaining notification entries before destroying resources.
+  // If the GPU faulted, fail all pending entries so waiters get the actual
+  // error. Otherwise drain normally (entries completed but not yet processed).
+  iree_status_t error = (iree_status_t)iree_atomic_load(
+      &queue->error_status, iree_memory_order_acquire);
+  iree_hal_amdgpu_reclaim_positions_t reclaim_positions = {0};
+  if (!iree_status_is_ok(error)) {
+    iree_hal_amdgpu_notification_ring_fail_all_reclaim_positions(
+        &queue->notification_ring, error, &reclaim_positions);
+    iree_hal_amdgpu_host_queue_clear_profile_events(queue);
+    iree_status_free(error);
+  } else {
+    iree_hal_amdgpu_notification_ring_drain_reclaim_positions(
+        &queue->notification_ring,
+        /*fallback_frontier=*/NULL, iree_hal_amdgpu_host_queue_reclaim_retired,
+        queue, &reclaim_positions);
+  }
+  iree_hal_amdgpu_host_queue_reclaim_queue_owned_positions(queue,
+                                                           reclaim_positions);
+  iree_hal_amdgpu_host_queue_run_post_drain_actions(queue);
+
+  // Deregister from the epoch signal table before destroying the notification
+  // ring (which owns the epoch signal). Guarded by epoch_table != NULL to
+  // handle partial initialization (init failed before registration).
+  if (queue->epoch_table) {
+    iree_hal_amdgpu_epoch_signal_table_deregister(
+        queue->epoch_table, iree_async_axis_device_index(queue->axis),
+        iree_async_axis_queue_index(queue->axis));
+    queue->epoch_table = NULL;
+  }
+
+  if (queue->frontier_tracker) {
+    iree_async_frontier_tracker_retire_axis(
+        queue->frontier_tracker, queue->axis,
+        iree_status_from_code(IREE_STATUS_CANCELLED));
+    queue->frontier_tracker = NULL;
+    queue->axis = 0;
+  }
+
+  iree_hal_amdgpu_notification_ring_deinitialize(&queue->notification_ring);
+
+  if (queue->queue_upload_ring.base) {
+    iree_hal_amdgpu_queue_upload_ring_deinitialize(queue->libhsa,
+                                                   &queue->queue_upload_ring);
+  }
+
+  iree_hal_amdgpu_kernarg_ring_deinitialize(queue->libhsa,
+                                            &queue->kernarg_ring);
+
+  if (queue->pm4_ib_slots) {
+    iree_hal_amdgpu_hsa_cleanup_assert_success(
+        iree_hsa_amd_memory_pool_free_raw(queue->libhsa, queue->pm4_ib_slots));
+    queue->pm4_ib_slots = NULL;
+  }
+
+  iree_hal_amdgpu_host_queue_deallocate_profiling_completion_signals(queue);
+  iree_hal_amdgpu_host_queue_deallocate_profile_events(queue);
+
+  if (queue->command_buffer_scratch) {
+    iree_allocator_free(queue->host_allocator, queue->command_buffer_scratch);
+    queue->command_buffer_scratch = NULL;
+  }
+
+  if (queue->completion.stop_signal.handle) {
+    iree_hal_amdgpu_hsa_cleanup_assert_success(iree_hsa_signal_destroy_raw(
+        queue->libhsa, queue->completion.stop_signal));
+    queue->completion.stop_signal.handle = 0;
+  }
+
+  iree_slim_mutex_deinitialize(&queue->locks.post_drain_mutex);
+  iree_slim_mutex_deinitialize(&queue->profiling.event_mutex);
+  iree_slim_mutex_deinitialize(&queue->locks.submission_mutex);
 
   IREE_TRACE_ZONE_END(z0);
 }
 
+iree_status_t iree_hal_amdgpu_host_queue_set_hsa_profiling_enabled(
+    iree_hal_amdgpu_host_queue_t* queue, bool enabled) {
+  IREE_ASSERT_ARGUMENT(queue);
+  IREE_TRACE_ZONE_BEGIN(z0);
+  IREE_TRACE_ZONE_APPEND_VALUE_I64(z0, enabled ? 1 : 0);
+
+  if (enabled) {
+    IREE_RETURN_AND_END_ZONE_IF_ERROR(
+        z0, iree_hal_amdgpu_host_queue_ensure_profile_event_storage(queue));
+    IREE_RETURN_AND_END_ZONE_IF_ERROR(
+        z0,
+        iree_hal_amdgpu_host_queue_ensure_profiling_completion_signals(queue));
+    iree_hal_amdgpu_host_queue_clear_profile_events(queue);
+  }
+
+  iree_status_t status = iree_hsa_amd_profiling_set_profiler_enabled(
+      IREE_LIBHSA(queue->libhsa), queue->hardware_queue, enabled ? 1 : 0);
+  if (iree_status_is_ok(status)) {
+    queue->profiling.hsa_queue_timestamps_enabled = enabled ? 1 : 0;
+  }
+
+  IREE_TRACE_ZONE_END(z0);
+  return status;
+}
+
 static void iree_hal_amdgpu_host_queue_trim(
-    iree_hal_amdgpu_virtual_queue_t* virtual_queue) {
-  iree_hal_amdgpu_host_queue_t* queue =
-      iree_hal_amdgpu_host_queue_cast(virtual_queue);
-  IREE_TRACE_ZONE_BEGIN(z0);
+    iree_hal_amdgpu_virtual_queue_t* base_queue) {}
 
-  (void)queue;
+//===----------------------------------------------------------------------===//
+// Queue operations
+//===----------------------------------------------------------------------===//
 
-  IREE_TRACE_ZONE_END(z0);
+typedef struct iree_hal_amdgpu_host_queue_op_submission_t {
+  // Queue whose submission_mutex is held between begin/end.
+  iree_hal_amdgpu_host_queue_t* queue;
+
+  // Wait resolution computed while holding submission_mutex.
+  iree_hal_amdgpu_wait_resolution_t resolution;
+
+  // Deferred operation captured while holding submission_mutex, if any.
+  iree_hal_amdgpu_pending_op_t* deferred_op;
+
+  // Number of input waits. Capacity retries only need post-drain resubmission
+  // when no semantic waits are available to naturally re-enter the queue.
+  iree_host_size_t wait_semaphore_count;
+
+  // Whether the direct submit helper found enough queue capacity.
+  bool ready;
+
+  // Whether |deferred_op| should retry on the completion thread after drain.
+  bool wait_for_capacity;
+} iree_hal_amdgpu_host_queue_op_submission_t;
+
+// Begins one direct/deferred queue operation attempt. The caller must pair this
+// with iree_hal_amdgpu_host_queue_op_submission_end exactly once.
+static inline void iree_hal_amdgpu_host_queue_op_submission_begin(
+    iree_hal_amdgpu_host_queue_t* queue,
+    const iree_hal_semaphore_list_t wait_semaphore_list,
+    iree_hal_amdgpu_host_queue_op_submission_t* out_submission) {
+  out_submission->queue = queue;
+  out_submission->deferred_op = NULL;
+  out_submission->wait_semaphore_count = wait_semaphore_list.count;
+  out_submission->ready = true;
+  out_submission->wait_for_capacity = false;
+
+  iree_slim_mutex_lock(&queue->locks.submission_mutex);
+  iree_hal_amdgpu_host_queue_resolve_waits(queue, wait_semaphore_list,
+                                           &out_submission->resolution);
 }
 
-static iree_status_t iree_hal_amdgpu_host_queue_alloca(
-    iree_hal_amdgpu_virtual_queue_t* virtual_queue,
-    const iree_hal_semaphore_list_t wait_semaphore_list,
-    const iree_hal_semaphore_list_t signal_semaphore_list,
-    iree_hal_pool_t* pool, iree_hal_buffer_params_t params,
-    iree_device_size_t allocation_size, iree_hal_alloca_flags_t flags,
-    iree_hal_buffer_t** IREE_RESTRICT out_buffer) {
-  iree_hal_amdgpu_host_queue_t* queue =
-      iree_hal_amdgpu_host_queue_cast(virtual_queue);
-  IREE_TRACE_ZONE_BEGIN(z0);
+// Marks a captured pending op as retrying after completion-thread drain because
+// direct submission ran out of queue capacity.
+static inline void iree_hal_amdgpu_host_queue_op_submission_defer_for_capacity(
+    iree_hal_amdgpu_host_queue_op_submission_t* submission) {
+  submission->wait_for_capacity = submission->wait_semaphore_count == 0;
+}
 
-  iree_status_t status =
-      iree_make_status(IREE_STATUS_UNIMPLEMENTED, "queue_alloca");
-  (void)queue;
+// Ends one direct/deferred queue operation attempt by releasing
+// submission_mutex and starting any captured pending op outside the lock.
+static inline iree_status_t iree_hal_amdgpu_host_queue_op_submission_end(
+    iree_hal_amdgpu_host_queue_op_submission_t* submission,
+    iree_status_t status) {
+  iree_slim_mutex_unlock(&submission->queue->locks.submission_mutex);
 
-  IREE_TRACE_ZONE_END(z0);
+  if (iree_status_is_ok(status) && submission->deferred_op) {
+    status = iree_hal_amdgpu_pending_op_start(submission->deferred_op,
+                                              submission->wait_for_capacity);
+  }
   return status;
 }
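
Every entry point below (execute, alloca, dealloca, fill, copy, update) instantiates the same begin/try/defer/end shape. A distilled sketch, with hypothetical `submit_op`/`defer_op` standing in for the per-operation helpers declared in host_queue_submission.h and host_queue_pending.h:

```c
// Hypothetical per-op helpers; the real ones carry the operation payload.
extern iree_status_t submit_op(iree_hal_amdgpu_host_queue_t* queue,
                               iree_hal_amdgpu_wait_resolution_t* resolution,
                               bool* out_ready);
extern iree_status_t defer_op(iree_hal_amdgpu_host_queue_t* queue,
                              iree_hal_amdgpu_pending_op_t** out_op);

static iree_status_t queue_op(
    iree_hal_amdgpu_host_queue_t* queue,
    const iree_hal_semaphore_list_t wait_semaphore_list) {
  iree_hal_amdgpu_host_queue_op_submission_t submission;
  iree_hal_amdgpu_host_queue_op_submission_begin(queue, wait_semaphore_list,
                                                 &submission);
  iree_status_t status = iree_ok_status();
  if (submission.resolution.needs_deferral) {
    // Unsatisfied semantic waits: park until a semaphore re-enters the op.
    status = defer_op(queue, &submission.deferred_op);
  } else {
    status = submit_op(queue, &submission.resolution, &submission.ready);
    if (iree_status_is_ok(status) && !submission.ready) {
      // Out of AQL/kernarg capacity: park and retry after the next drain.
      status = defer_op(queue, &submission.deferred_op);
      iree_hal_amdgpu_host_queue_op_submission_defer_for_capacity(&submission);
    }
  }
  // Unlocks and, outside the lock, starts any captured pending op.
  return iree_hal_amdgpu_host_queue_op_submission_end(&submission, status);
}
```
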
 
-static iree_status_t iree_hal_amdgpu_host_queue_dealloca(
-    iree_hal_amdgpu_virtual_queue_t* virtual_queue,
-    const iree_hal_semaphore_list_t wait_semaphore_list,
-    const iree_hal_semaphore_list_t signal_semaphore_list,
-    iree_hal_buffer_t* buffer, iree_hal_dealloca_flags_t flags) {
-  iree_hal_amdgpu_host_queue_t* queue =
-      iree_hal_amdgpu_host_queue_cast(virtual_queue);
-  IREE_TRACE_ZONE_BEGIN(z0);
+static iree_status_t iree_hal_amdgpu_host_queue_signal_empty_barrier(
+    iree_hal_amdgpu_host_queue_t* queue,
+    const iree_hal_semaphore_list_t signal_semaphore_list) {
+  iree_slim_mutex_lock(&queue->locks.submission_mutex);
+  iree_status_t status = iree_ok_status();
+  if (IREE_UNLIKELY(queue->is_shutting_down)) {
+    status = iree_make_status(IREE_STATUS_CANCELLED, "queue shutting down");
+  }
+  iree_slim_mutex_unlock(&queue->locks.submission_mutex);
 
-  iree_status_t status =
-      iree_make_status(IREE_STATUS_UNIMPLEMENTED, "queue_dealloca");
-  (void)queue;
-
-  IREE_TRACE_ZONE_END(z0);
-  return status;
-}
-
-static iree_status_t iree_hal_amdgpu_host_queue_fill(
-    iree_hal_amdgpu_virtual_queue_t* virtual_queue,
-    const iree_hal_semaphore_list_t wait_semaphore_list,
-    const iree_hal_semaphore_list_t signal_semaphore_list,
-    iree_hal_buffer_t* target_buffer, iree_device_size_t target_offset,
-    iree_device_size_t length, uint64_t pattern_bits,
-    iree_host_size_t pattern_length, iree_hal_fill_flags_t flags) {
-  iree_hal_amdgpu_host_queue_t* queue =
-      iree_hal_amdgpu_host_queue_cast(virtual_queue);
-  IREE_TRACE_ZONE_BEGIN(z0);
-
-  iree_status_t status =
-      iree_make_status(IREE_STATUS_UNIMPLEMENTED, "queue_fill");
-  (void)queue;
-
-  IREE_TRACE_ZONE_END(z0);
-  return status;
-}
-
-static iree_status_t iree_hal_amdgpu_host_queue_update(
-    iree_hal_amdgpu_virtual_queue_t* virtual_queue,
-    const iree_hal_semaphore_list_t wait_semaphore_list,
-    const iree_hal_semaphore_list_t signal_semaphore_list,
-    const void* source_buffer, iree_host_size_t source_offset,
-    iree_hal_buffer_t* target_buffer, iree_device_size_t target_offset,
-    iree_device_size_t length, iree_hal_update_flags_t flags) {
-  iree_hal_amdgpu_host_queue_t* queue =
-      iree_hal_amdgpu_host_queue_cast(virtual_queue);
-  IREE_TRACE_ZONE_BEGIN(z0);
-
-  iree_status_t status =
-      iree_make_status(IREE_STATUS_UNIMPLEMENTED, "queue_update");
-  (void)queue;
-
-  IREE_TRACE_ZONE_END(z0);
-  return status;
-}
-
-static iree_status_t iree_hal_amdgpu_host_queue_copy(
-    iree_hal_amdgpu_virtual_queue_t* virtual_queue,
-    const iree_hal_semaphore_list_t wait_semaphore_list,
-    const iree_hal_semaphore_list_t signal_semaphore_list,
-    iree_hal_buffer_t* source_buffer, iree_device_size_t source_offset,
-    iree_hal_buffer_t* target_buffer, iree_device_size_t target_offset,
-    iree_device_size_t length, iree_hal_copy_flags_t flags) {
-  iree_hal_amdgpu_host_queue_t* queue =
-      iree_hal_amdgpu_host_queue_cast(virtual_queue);
-  IREE_TRACE_ZONE_BEGIN(z0);
-
-  iree_status_t status =
-      iree_make_status(IREE_STATUS_UNIMPLEMENTED, "queue_copy");
-  (void)queue;
-
-  IREE_TRACE_ZONE_END(z0);
-  return status;
-}
-
-static iree_status_t iree_hal_amdgpu_host_queue_read(
-    iree_hal_amdgpu_virtual_queue_t* virtual_queue,
-    const iree_hal_semaphore_list_t wait_semaphore_list,
-    const iree_hal_semaphore_list_t signal_semaphore_list,
-    iree_hal_file_t* source_file, uint64_t source_offset,
-    iree_hal_buffer_t* target_buffer, iree_device_size_t target_offset,
-    iree_device_size_t length, iree_hal_read_flags_t flags) {
-  iree_hal_amdgpu_host_queue_t* queue =
-      iree_hal_amdgpu_host_queue_cast(virtual_queue);
-  IREE_TRACE_ZONE_BEGIN(z0);
-
-  iree_status_t status =
-      iree_make_status(IREE_STATUS_UNIMPLEMENTED, "queue_read");
-  (void)queue;
-
-  IREE_TRACE_ZONE_END(z0);
-  return status;
-}
-
-static iree_status_t iree_hal_amdgpu_host_queue_write(
-    iree_hal_amdgpu_virtual_queue_t* virtual_queue,
-    const iree_hal_semaphore_list_t wait_semaphore_list,
-    const iree_hal_semaphore_list_t signal_semaphore_list,
-    iree_hal_buffer_t* source_buffer, iree_device_size_t source_offset,
-    iree_hal_file_t* target_file, uint64_t target_offset,
-    iree_device_size_t length, iree_hal_write_flags_t flags) {
-  iree_hal_amdgpu_host_queue_t* queue =
-      iree_hal_amdgpu_host_queue_cast(virtual_queue);
-  IREE_TRACE_ZONE_BEGIN(z0);
-
-  iree_status_t status =
-      iree_make_status(IREE_STATUS_UNIMPLEMENTED, "queue_write");
-  (void)queue;
-
-  IREE_TRACE_ZONE_END(z0);
+  if (iree_status_is_ok(status)) {
+    // Signal outside submission_mutex: semaphore signaling dispatches satisfied
+    // timepoints, and those callbacks may submit additional queue work.
+    status = iree_hal_semaphore_list_signal(signal_semaphore_list,
+                                            /*frontier=*/NULL);
+  }
   return status;
 }
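
The "signal outside submission_mutex" comment above is load-bearing: timepoint callbacks are allowed to submit more queue work. A minimal sketch of the self-deadlock that signaling under the lock would invite, using a plain pthread mutex for illustration only:

```c
#include <pthread.h>
#include <stddef.h>

static pthread_mutex_t submission_mutex = PTHREAD_MUTEX_INITIALIZER;

// A satisfied-timepoint callback commonly submits follow-up work, which
// re-enters the queue and takes the same (non-recursive) mutex.
static void on_timepoint_reached(void* user_data) {
  (void)user_data;
  pthread_mutex_lock(&submission_mutex);  // re-enter the submission path...
  pthread_mutex_unlock(&submission_mutex);
}

static void broken_barrier_signal(void) {
  pthread_mutex_lock(&submission_mutex);
  // Signaling here fires callbacks inline: the re-acquire above deadlocks.
  on_timepoint_reached(NULL);
  pthread_mutex_unlock(&submission_mutex);
}
```
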
 
 static iree_status_t iree_hal_amdgpu_host_queue_execute(
-    iree_hal_amdgpu_virtual_queue_t* virtual_queue,
+    iree_hal_amdgpu_virtual_queue_t* base_queue,
     const iree_hal_semaphore_list_t wait_semaphore_list,
     const iree_hal_semaphore_list_t signal_semaphore_list,
     iree_hal_command_buffer_t* command_buffer,
     iree_hal_buffer_binding_table_t binding_table,
     iree_hal_execute_flags_t flags) {
   iree_hal_amdgpu_host_queue_t* queue =
-      iree_hal_amdgpu_host_queue_cast(virtual_queue);
-  IREE_TRACE_ZONE_BEGIN(z0);
+      (iree_hal_amdgpu_host_queue_t*)base_queue;
 
-  iree_status_t status =
-      iree_make_status(IREE_STATUS_UNIMPLEMENTED, "queue_execute");
-  (void)queue;
+  IREE_RETURN_IF_ERROR(
+      iree_hal_amdgpu_host_queue_validate_execute_flags(flags));
 
-  IREE_TRACE_ZONE_END(z0);
+  if (!command_buffer && wait_semaphore_list.count == 0) {
+    if (IREE_UNLIKELY(binding_table.count != 0)) {
+      return iree_make_status(
+          IREE_STATUS_INVALID_ARGUMENT,
+          "barrier-only queue_execute must not provide a binding table "
+          "(count=%" PRIhsz ")",
+          binding_table.count);
+    }
+    return iree_hal_amdgpu_host_queue_signal_empty_barrier(
+        queue, signal_semaphore_list);
+  }
+
+  iree_hal_amdgpu_host_queue_op_submission_t submission;
+  iree_hal_amdgpu_host_queue_op_submission_begin(queue, wait_semaphore_list,
+                                                 &submission);
+  iree_status_t status = iree_ok_status();
+  if (submission.resolution.needs_deferral) {
+    status = iree_hal_amdgpu_host_queue_defer_execute(
+        queue, &wait_semaphore_list, &signal_semaphore_list, command_buffer,
+        binding_table, flags, &submission.deferred_op);
+  } else if (!command_buffer) {
+    if (IREE_UNLIKELY(binding_table.count != 0)) {
+      status = iree_make_status(
+          IREE_STATUS_INVALID_ARGUMENT,
+          "barrier-only queue_execute must not provide a binding table "
+          "(count=%" PRIhsz ")",
+          binding_table.count);
+    } else {
+      uint64_t submission_id = 0;
+      iree_hal_amdgpu_host_queue_profile_event_info_t profile_event_info = {
+          .type = IREE_HAL_PROFILE_QUEUE_EVENT_TYPE_BARRIER,
+          .operation_count = 0,
+      };
+      status = iree_hal_amdgpu_host_queue_try_submit_barrier(
+          queue, &submission.resolution, signal_semaphore_list,
+          (iree_hal_amdgpu_reclaim_action_t){0},
+          /*operation_resources=*/NULL,
+          /*operation_resource_count=*/0, &profile_event_info,
+          iree_hal_amdgpu_host_queue_post_commit_callback_null(),
+          /*resource_set=*/NULL,
+          IREE_HAL_AMDGPU_HOST_QUEUE_SUBMISSION_FLAG_RETAIN_RESOURCES,
+          &submission.ready, &submission_id);
+      if (iree_status_is_ok(status) && submission.ready) {
+        profile_event_info.submission_id = submission_id;
+        iree_hal_amdgpu_host_queue_record_profile_queue_event(
+            queue, &submission.resolution, signal_semaphore_list,
+            &profile_event_info);
+      }
+      if (iree_status_is_ok(status) && !submission.ready) {
+        status = iree_hal_amdgpu_host_queue_defer_execute(
+            queue, &wait_semaphore_list, &signal_semaphore_list,
+            /*command_buffer=*/NULL, iree_hal_buffer_binding_table_empty(),
+            flags, &submission.deferred_op);
+        iree_hal_amdgpu_host_queue_op_submission_defer_for_capacity(
+            &submission);
+      }
+    }
+  } else {
+    iree_hal_resource_set_t* binding_resource_set = NULL;
+    status = iree_hal_amdgpu_host_queue_submit_command_buffer(
+        queue, &submission.resolution, signal_semaphore_list, command_buffer,
+        binding_table, flags, &binding_resource_set, &submission.ready);
+    if (iree_status_is_ok(status) && !submission.ready) {
+      iree_hal_resource_set_free(binding_resource_set);
+      status = iree_hal_amdgpu_host_queue_defer_execute(
+          queue, &wait_semaphore_list, &signal_semaphore_list, command_buffer,
+          binding_table, flags, &submission.deferred_op);
+      iree_hal_amdgpu_host_queue_op_submission_defer_for_capacity(&submission);
+    } else if (!iree_status_is_ok(status)) {
+      iree_hal_resource_set_free(binding_resource_set);
+    }
+  }
+  return iree_hal_amdgpu_host_queue_op_submission_end(&submission, status);
+}
+
+static iree_status_t iree_hal_amdgpu_host_queue_alloca(
+    iree_hal_amdgpu_virtual_queue_t* base_queue,
+    const iree_hal_semaphore_list_t wait_semaphore_list,
+    const iree_hal_semaphore_list_t signal_semaphore_list,
+    iree_hal_pool_t* pool, iree_hal_buffer_params_t params,
+    iree_device_size_t allocation_size, iree_hal_alloca_flags_t flags,
+    iree_hal_buffer_t** IREE_RESTRICT out_buffer) {
+  IREE_ASSERT_ARGUMENT(out_buffer);
+  *out_buffer = NULL;
+
+  iree_hal_amdgpu_host_queue_t* queue =
+      (iree_hal_amdgpu_host_queue_t*)base_queue;
+
+  iree_hal_pool_t* allocation_pool = NULL;
+  iree_hal_buffer_t* buffer = NULL;
+  IREE_RETURN_IF_ERROR(iree_hal_amdgpu_host_queue_prepare_alloca_wrapper(
+      queue, pool, &params, allocation_size, flags, &allocation_pool, &buffer));
+  // Always ask the pool to surface waitable death-frontier candidates so the
+  // queue can distinguish true pool pressure from a dependency the caller did
+  // not authorize. The HAL alloca flag is checked before consuming any
+  // OK_NEEDS_WAIT reservation. Disallow growth while submission_mutex is held;
+  // growable pools report that as a cold retry instead of calling into their
+  // slab provider on the serialized queue path.
+  const iree_hal_pool_reserve_flags_t reserve_flags =
+      IREE_HAL_POOL_RESERVE_FLAG_ALLOW_WAIT_FRONTIER |
+      IREE_HAL_POOL_RESERVE_FLAG_DISALLOW_GROWTH;
+
+  iree_hal_amdgpu_host_queue_op_submission_t submission;
+  iree_hal_amdgpu_host_queue_op_submission_begin(queue, wait_semaphore_list,
+                                                 &submission);
+  iree_status_t status = iree_ok_status();
+  iree_hal_amdgpu_pending_op_t* memory_wait_op = NULL;
+  if (submission.resolution.needs_deferral) {
+    status = iree_hal_amdgpu_host_queue_defer_alloca(
+        queue, &wait_semaphore_list, &signal_semaphore_list, allocation_pool,
+        params, allocation_size, flags, reserve_flags, buffer,
+        &submission.deferred_op);
+  } else {
+    status = iree_hal_amdgpu_host_queue_submit_alloca(
+        queue, &submission.resolution, signal_semaphore_list, allocation_pool,
+        params, allocation_size, flags, reserve_flags, buffer,
+        IREE_HAL_AMDGPU_HOST_QUEUE_SUBMISSION_FLAG_RETAIN_RESOURCES,
+        /*pending_op=*/NULL, &memory_wait_op, &submission.ready);
+    if (iree_status_is_ok(status) && !submission.ready && !memory_wait_op) {
+      status = iree_hal_amdgpu_host_queue_defer_alloca(
+          queue, &wait_semaphore_list, &signal_semaphore_list, allocation_pool,
+          params, allocation_size, flags, reserve_flags, buffer,
+          &submission.deferred_op);
+      iree_hal_amdgpu_host_queue_op_submission_defer_for_capacity(&submission);
+    }
+  }
+  status = iree_hal_amdgpu_host_queue_op_submission_end(&submission, status);
+  if (iree_status_is_ok(status) && memory_wait_op) {
+    iree_hal_amdgpu_pending_op_enqueue_alloca_memory_wait(memory_wait_op);
+  }
+
+  if (iree_status_is_ok(status)) {
+    *out_buffer = buffer;
+  } else {
+    iree_hal_buffer_release(buffer);
+  }
   return status;
 }
 
-static iree_status_t iree_hal_amdgpu_host_queue_flush(
-    iree_hal_amdgpu_virtual_queue_t* virtual_queue) {
+static iree_status_t iree_hal_amdgpu_host_queue_dealloca(
+    iree_hal_amdgpu_virtual_queue_t* base_queue,
+    const iree_hal_semaphore_list_t wait_semaphore_list,
+    const iree_hal_semaphore_list_t signal_semaphore_list,
+    iree_hal_buffer_t* buffer, iree_hal_dealloca_flags_t flags) {
   iree_hal_amdgpu_host_queue_t* queue =
-      iree_hal_amdgpu_host_queue_cast(virtual_queue);
-  IREE_TRACE_ZONE_BEGIN(z0);
+      (iree_hal_amdgpu_host_queue_t*)base_queue;
 
-  iree_status_t status =
-      iree_make_status(IREE_STATUS_UNIMPLEMENTED, "queue_flush");
-  (void)queue;
+  if (IREE_UNLIKELY(
+          iree_any_bit_set(flags, ~(IREE_HAL_DEALLOCA_FLAG_NONE |
+                                    IREE_HAL_DEALLOCA_FLAG_PREFER_ORIGIN)))) {
+    return iree_make_status(IREE_STATUS_INVALID_ARGUMENT,
+                            "unsupported dealloca flags: 0x%" PRIx64, flags);
+  }
 
-  IREE_TRACE_ZONE_END(z0);
+  // iree_hal_device_queue_dealloca() applies PREFER_ORIGIN before vtable
+  // dispatch by rewriting the device and queue affinity from the buffer's
+  // allocation placement. Transient wrappers created by queue_alloca carry this
+  // queue's one-bit affinity in that placement, so this host-queue path can use
+  // |base_queue| directly.
+  if (!iree_hal_amdgpu_transient_buffer_isa(buffer)) {
+    return iree_hal_amdgpu_host_queue_execute(
+        base_queue, wait_semaphore_list, signal_semaphore_list,
+        /*command_buffer=*/NULL, iree_hal_buffer_binding_table_empty(),
+        IREE_HAL_EXECUTE_FLAG_NONE);
+  }
+
+  if (IREE_UNLIKELY(!iree_hal_amdgpu_transient_buffer_begin_dealloca(buffer))) {
+    return iree_make_status(
+        IREE_STATUS_FAILED_PRECONDITION,
+        "transient buffer has already been queued for deallocation");
+  }
+
+  iree_hal_amdgpu_host_queue_op_submission_t submission;
+  iree_hal_amdgpu_host_queue_op_submission_begin(queue, wait_semaphore_list,
+                                                 &submission);
+  iree_status_t status = iree_ok_status();
+  if (submission.resolution.needs_deferral) {
+    status = iree_hal_amdgpu_host_queue_defer_dealloca(
+        queue, &wait_semaphore_list, &signal_semaphore_list, buffer,
+        &submission.deferred_op);
+  } else {
+    status = iree_hal_amdgpu_host_queue_submit_dealloca(
+        queue, &submission.resolution, signal_semaphore_list, buffer,
+        IREE_HAL_AMDGPU_HOST_QUEUE_SUBMISSION_FLAG_RETAIN_RESOURCES,
+        &submission.ready);
+    if (iree_status_is_ok(status) && !submission.ready) {
+      status = iree_hal_amdgpu_host_queue_defer_dealloca(
+          queue, &wait_semaphore_list, &signal_semaphore_list, buffer,
+          &submission.deferred_op);
+      iree_hal_amdgpu_host_queue_op_submission_defer_for_capacity(&submission);
+    }
+  }
+  status = iree_hal_amdgpu_host_queue_op_submission_end(&submission, status);
+  if (!iree_status_is_ok(status)) {
+    iree_hal_amdgpu_transient_buffer_abort_dealloca(buffer);
+  }
   return status;
 }
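
`begin_dealloca`/`abort_dealloca` bracket a one-shot ownership handoff: exactly one caller may queue the deallocation, and a failed submission hands the buffer back. One plausible shape of that guard, sketched with C11 atomics; the real transient-buffer implementation may track more state:

```c
#include <stdatomic.h>
#include <stdbool.h>

// One-shot guard: the first guard_begin_dealloca() owns queueing the
// deallocation; a failed submission aborts so the buffer can be retried or
// released normally instead of stranding in a half-deallocated state.
typedef struct {
  atomic_bool dealloca_pending;
} transient_guard_t;

static bool guard_begin_dealloca(transient_guard_t* guard) {
  bool expected = false;
  return atomic_compare_exchange_strong(&guard->dealloca_pending, &expected,
                                        true);
}

static void guard_abort_dealloca(transient_guard_t* guard) {
  atomic_store(&guard->dealloca_pending, false);
}
```
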
 
+// Queue fill entry point. Resolves waits under submission_mutex and captures a
+// pending operation only when waits or submission capacity require deferral.
+static iree_status_t iree_hal_amdgpu_host_queue_fill(
+    iree_hal_amdgpu_virtual_queue_t* base_queue,
+    const iree_hal_semaphore_list_t wait_semaphore_list,
+    const iree_hal_semaphore_list_t signal_semaphore_list,
+    iree_hal_buffer_t* target_buffer, iree_device_size_t target_offset,
+    iree_device_size_t length, uint64_t pattern_bits,
+    iree_host_size_t pattern_length, iree_hal_fill_flags_t flags) {
+  iree_hal_amdgpu_host_queue_t* queue =
+      (iree_hal_amdgpu_host_queue_t*)base_queue;
+
+  iree_hal_amdgpu_host_queue_op_submission_t submission;
+  iree_hal_amdgpu_host_queue_op_submission_begin(queue, wait_semaphore_list,
+                                                 &submission);
+  iree_status_t status = iree_ok_status();
+  if (submission.resolution.needs_deferral) {
+    status = iree_hal_amdgpu_host_queue_defer_fill(
+        queue, &wait_semaphore_list, &signal_semaphore_list, target_buffer,
+        target_offset, length, pattern_bits, pattern_length, flags,
+        &submission.deferred_op);
+  } else {
+    status = iree_hal_amdgpu_host_queue_submit_fill(
+        queue, &submission.resolution, signal_semaphore_list, target_buffer,
+        target_offset, length, pattern_bits, pattern_length, flags,
+        IREE_HAL_AMDGPU_HOST_QUEUE_SUBMISSION_FLAG_RETAIN_RESOURCES,
+        &submission.ready);
+    if (iree_status_is_ok(status) && !submission.ready) {
+      status = iree_hal_amdgpu_host_queue_defer_fill(
+          queue, &wait_semaphore_list, &signal_semaphore_list, target_buffer,
+          target_offset, length, pattern_bits, pattern_length, flags,
+          &submission.deferred_op);
+      iree_hal_amdgpu_host_queue_op_submission_defer_for_capacity(&submission);
+    }
+  }
+  return iree_hal_amdgpu_host_queue_op_submission_end(&submission, status);
+}
+
+iree_status_t iree_hal_amdgpu_host_queue_copy_buffer(
+    iree_hal_amdgpu_host_queue_t* queue,
+    const iree_hal_semaphore_list_t wait_semaphore_list,
+    const iree_hal_semaphore_list_t signal_semaphore_list,
+    iree_hal_buffer_t* source_buffer, iree_device_size_t source_offset,
+    iree_hal_buffer_t* target_buffer, iree_device_size_t target_offset,
+    iree_device_size_t length, iree_hal_copy_flags_t flags,
+    iree_hal_profile_queue_event_type_t profile_event_type) {
+  iree_hal_amdgpu_host_queue_op_submission_t submission;
+  iree_hal_amdgpu_host_queue_op_submission_begin(queue, wait_semaphore_list,
+                                                 &submission);
+  iree_status_t status = iree_ok_status();
+  if (submission.resolution.needs_deferral) {
+    status = iree_hal_amdgpu_host_queue_defer_copy(
+        queue, &wait_semaphore_list, &signal_semaphore_list, source_buffer,
+        source_offset, target_buffer, target_offset, length, flags,
+        profile_event_type, &submission.deferred_op);
+  } else {
+    status = iree_hal_amdgpu_host_queue_submit_copy(
+        queue, &submission.resolution, signal_semaphore_list, source_buffer,
+        source_offset, target_buffer, target_offset, length, flags,
+        profile_event_type,
+        IREE_HAL_AMDGPU_HOST_QUEUE_SUBMISSION_FLAG_RETAIN_RESOURCES,
+        &submission.ready);
+    if (iree_status_is_ok(status) && !submission.ready) {
+      status = iree_hal_amdgpu_host_queue_defer_copy(
+          queue, &wait_semaphore_list, &signal_semaphore_list, source_buffer,
+          source_offset, target_buffer, target_offset, length, flags,
+          profile_event_type, &submission.deferred_op);
+      iree_hal_amdgpu_host_queue_op_submission_defer_for_capacity(&submission);
+    }
+  }
+  return iree_hal_amdgpu_host_queue_op_submission_end(&submission, status);
+}
+
+// Queue copy entry point. The shared copy path is also used by file read/write
+// staging so all copy-shaped operations use the same wait/backpressure path.
+static iree_status_t iree_hal_amdgpu_host_queue_copy(
+    iree_hal_amdgpu_virtual_queue_t* base_queue,
+    const iree_hal_semaphore_list_t wait_semaphore_list,
+    const iree_hal_semaphore_list_t signal_semaphore_list,
+    iree_hal_buffer_t* source_buffer, iree_device_size_t source_offset,
+    iree_hal_buffer_t* target_buffer, iree_device_size_t target_offset,
+    iree_device_size_t length, iree_hal_copy_flags_t flags) {
+  return iree_hal_amdgpu_host_queue_copy_buffer(
+      (iree_hal_amdgpu_host_queue_t*)base_queue, wait_semaphore_list,
+      signal_semaphore_list, source_buffer, source_offset, target_buffer,
+      target_offset, length, flags, IREE_HAL_PROFILE_QUEUE_EVENT_TYPE_COPY);
+}
+
+// Queue update entry point. Immediate updates copy into queue-owned kernarg
+// memory; deferred updates copy into the pending-op arena.
+static iree_status_t iree_hal_amdgpu_host_queue_update(
+    iree_hal_amdgpu_virtual_queue_t* base_queue,
+    const iree_hal_semaphore_list_t wait_semaphore_list,
+    const iree_hal_semaphore_list_t signal_semaphore_list,
+    const void* source_buffer, iree_host_size_t source_offset,
+    iree_hal_buffer_t* target_buffer, iree_device_size_t target_offset,
+    iree_device_size_t length, iree_hal_update_flags_t flags) {
+  iree_hal_amdgpu_host_queue_t* queue =
+      (iree_hal_amdgpu_host_queue_t*)base_queue;
+
+  iree_hal_amdgpu_host_queue_op_submission_t submission;
+  iree_hal_amdgpu_host_queue_op_submission_begin(queue, wait_semaphore_list,
+                                                 &submission);
+  iree_status_t status = iree_ok_status();
+  if (submission.resolution.needs_deferral) {
+    status = iree_hal_amdgpu_host_queue_defer_update(
+        queue, &wait_semaphore_list, &signal_semaphore_list, source_buffer,
+        source_offset, target_buffer, target_offset, length, flags,
+        &submission.deferred_op);
+  } else {
+    status = iree_hal_amdgpu_host_queue_submit_update(
+        queue, &submission.resolution, signal_semaphore_list, source_buffer,
+        source_offset, target_buffer, target_offset, length, flags,
+        IREE_HAL_AMDGPU_HOST_QUEUE_SUBMISSION_FLAG_RETAIN_RESOURCES,
+        &submission.ready);
+    if (iree_status_is_ok(status) && !submission.ready) {
+      status = iree_hal_amdgpu_host_queue_defer_update(
+          queue, &wait_semaphore_list, &signal_semaphore_list, source_buffer,
+          source_offset, target_buffer, target_offset, length, flags,
+          &submission.deferred_op);
+      iree_hal_amdgpu_host_queue_op_submission_defer_for_capacity(&submission);
+    }
+  }
+  return iree_hal_amdgpu_host_queue_op_submission_end(&submission, status);
+}
+
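+// Returns true when a direct dispatch provably does no work: no indirect
+// parameters and an all-zero workgroup count.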
+static bool iree_hal_amdgpu_host_queue_is_noop_dispatch(
+    const iree_hal_dispatch_config_t config, iree_hal_dispatch_flags_t flags) {
+  return !iree_hal_dispatch_uses_indirect_parameters(flags) &&
+         (config.workgroup_count[0] | config.workgroup_count[1] |
+          config.workgroup_count[2]) == 0;
+}
+
+// Queue dispatch entry point. Empty direct dispatches route through the barrier
+// path so they still signal semaphores and profile as dispatch submissions.
+static iree_status_t iree_hal_amdgpu_host_queue_dispatch(
+    iree_hal_amdgpu_virtual_queue_t* base_queue,
+    const iree_hal_semaphore_list_t wait_semaphore_list,
+    const iree_hal_semaphore_list_t signal_semaphore_list,
+    iree_hal_executable_t* executable,
+    iree_hal_executable_export_ordinal_t export_ordinal,
+    const iree_hal_dispatch_config_t config, iree_const_byte_span_t constants,
+    const iree_hal_buffer_ref_list_t bindings,
+    iree_hal_dispatch_flags_t flags) {
+  iree_hal_amdgpu_host_queue_t* queue =
+      (iree_hal_amdgpu_host_queue_t*)base_queue;
+  const bool is_noop_dispatch =
+      iree_hal_amdgpu_host_queue_is_noop_dispatch(config, flags);
+
+  iree_hal_amdgpu_host_queue_op_submission_t submission;
+  iree_hal_amdgpu_host_queue_op_submission_begin(queue, wait_semaphore_list,
+                                                 &submission);
+  iree_status_t status = iree_ok_status();
+  if (submission.resolution.needs_deferral) {
+    if (is_noop_dispatch) {
+      status = iree_hal_amdgpu_host_queue_defer_execute(
+          queue, &wait_semaphore_list, &signal_semaphore_list,
+          /*command_buffer=*/NULL, iree_hal_buffer_binding_table_empty(),
+          IREE_HAL_EXECUTE_FLAG_NONE, &submission.deferred_op);
+    } else {
+      status = iree_hal_amdgpu_host_queue_defer_dispatch(
+          queue, &wait_semaphore_list, &signal_semaphore_list, executable,
+          export_ordinal, config, constants, bindings, flags,
+          &submission.deferred_op);
+    }
+  } else if (is_noop_dispatch) {
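+    // Submit a signal-only barrier in place of the empty dispatch so waiters
+    // still release, and record a zero-operation dispatch profile event once
+    // the barrier commits.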
+    uint64_t submission_id = 0;
+    iree_hal_amdgpu_host_queue_profile_event_info_t profile_event_info = {
+        .type = IREE_HAL_PROFILE_QUEUE_EVENT_TYPE_DISPATCH,
+        .operation_count = 0,
+    };
+    status = iree_hal_amdgpu_host_queue_try_submit_barrier(
+        queue, &submission.resolution, signal_semaphore_list,
+        (iree_hal_amdgpu_reclaim_action_t){0},
+        /*operation_resources=*/NULL,
+        /*operation_resource_count=*/0, &profile_event_info,
+        iree_hal_amdgpu_host_queue_post_commit_callback_null(),
+        /*resource_set=*/NULL,
+        IREE_HAL_AMDGPU_HOST_QUEUE_SUBMISSION_FLAG_RETAIN_RESOURCES,
+        &submission.ready, &submission_id);
+    if (iree_status_is_ok(status) && submission.ready) {
+      profile_event_info.submission_id = submission_id;
+      iree_hal_amdgpu_host_queue_record_profile_queue_event(
+          queue, &submission.resolution, signal_semaphore_list,
+          &profile_event_info);
+    }
+    if (iree_status_is_ok(status) && !submission.ready) {
+      status = iree_hal_amdgpu_host_queue_defer_execute(
+          queue, &wait_semaphore_list, &signal_semaphore_list,
+          /*command_buffer=*/NULL, iree_hal_buffer_binding_table_empty(),
+          IREE_HAL_EXECUTE_FLAG_NONE, &submission.deferred_op);
+      iree_hal_amdgpu_host_queue_op_submission_defer_for_capacity(&submission);
+    }
+  } else {
+    status = iree_hal_amdgpu_host_queue_submit_dispatch(
+        queue, &submission.resolution, signal_semaphore_list, executable,
+        export_ordinal, config, constants, bindings, flags,
+        IREE_HAL_AMDGPU_HOST_QUEUE_SUBMISSION_FLAG_RETAIN_RESOURCES,
+        &submission.ready);
+    if (iree_status_is_ok(status) && !submission.ready) {
+      status = iree_hal_amdgpu_host_queue_defer_dispatch(
+          queue, &wait_semaphore_list, &signal_semaphore_list, executable,
+          export_ordinal, config, constants, bindings, flags,
+          &submission.deferred_op);
+      iree_hal_amdgpu_host_queue_op_submission_defer_for_capacity(&submission);
+    }
+  }
+  return iree_hal_amdgpu_host_queue_op_submission_end(&submission, status);
+}
+
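+// Queue read entry point. Forwards to the file read path, which stages
+// non-mappable transfers through the shared copy submission flow.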
+static iree_status_t iree_hal_amdgpu_host_queue_read(
+    iree_hal_amdgpu_virtual_queue_t* base_queue,
+    const iree_hal_semaphore_list_t wait_semaphore_list,
+    const iree_hal_semaphore_list_t signal_semaphore_list,
+    iree_hal_file_t* source_file, uint64_t source_offset,
+    iree_hal_buffer_t* target_buffer, iree_device_size_t target_offset,
+    iree_device_size_t length, iree_hal_read_flags_t flags) {
+  return iree_hal_amdgpu_host_queue_read_file(
+      base_queue, wait_semaphore_list, signal_semaphore_list, source_file,
+      source_offset, target_buffer, target_offset, length, flags);
+}
+
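+// Queue write entry point. Forwards to the file write path, which stages
+// non-mappable transfers through the shared copy submission flow.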
+static iree_status_t iree_hal_amdgpu_host_queue_write(
+    iree_hal_amdgpu_virtual_queue_t* base_queue,
+    const iree_hal_semaphore_list_t wait_semaphore_list,
+    const iree_hal_semaphore_list_t signal_semaphore_list,
+    iree_hal_buffer_t* source_buffer, iree_device_size_t source_offset,
+    iree_hal_file_t* target_file, uint64_t target_offset,
+    iree_device_size_t length, iree_hal_write_flags_t flags) {
+  return iree_hal_amdgpu_host_queue_write_file(
+      base_queue, wait_semaphore_list, signal_semaphore_list, source_buffer,
+      source_offset, target_file, target_offset, length, flags);
+}
+
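+// Driver-internal host action path; the status and resource ownership
+// contracts are documented on the declaration in host_queue.h.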
+iree_status_t iree_hal_amdgpu_host_queue_enqueue_host_action(
+    iree_hal_amdgpu_host_queue_t* queue,
+    const iree_hal_semaphore_list_t wait_semaphore_list,
+    iree_hal_amdgpu_reclaim_action_t action,
+    iree_hal_resource_t* const* operation_resources,
+    iree_host_size_t operation_resource_count) {
+  if (IREE_UNLIKELY(!action.fn)) {
+    return iree_make_status(IREE_STATUS_INVALID_ARGUMENT,
+                            "host action callback must be non-null");
+  }
+  if (IREE_UNLIKELY(operation_resource_count > 0 && !operation_resources)) {
+    return iree_make_status(IREE_STATUS_INVALID_ARGUMENT,
+                            "host action resources must be non-null");
+  }
+
+  iree_hal_amdgpu_host_queue_op_submission_t submission;
+  iree_hal_amdgpu_host_queue_op_submission_begin(queue, wait_semaphore_list,
+                                                 &submission);
+  // Host actions execute on CPU threads and must observe device-produced
+  // host-visible memory even when a semaphore edge itself is device-local.
+  submission.resolution.inline_acquire_scope =
+      iree_hal_amdgpu_host_queue_max_fence_scope(
+          submission.resolution.inline_acquire_scope,
+          IREE_HSA_FENCE_SCOPE_SYSTEM);
+  submission.resolution.barrier_acquire_scope =
+      iree_hal_amdgpu_host_queue_max_fence_scope(
+          submission.resolution.barrier_acquire_scope,
+          IREE_HSA_FENCE_SCOPE_SYSTEM);
+  iree_status_t status = iree_ok_status();
+  if (submission.resolution.needs_deferral) {
+    status = iree_hal_amdgpu_host_queue_defer_host_action(
+        queue, &wait_semaphore_list, action, operation_resources,
+        operation_resource_count, &submission.deferred_op);
+  } else {
+    status = iree_hal_amdgpu_host_queue_try_submit_barrier(
+        queue, &submission.resolution, iree_hal_semaphore_list_empty(), action,
+        operation_resources, operation_resource_count,
+        /*profile_event_info=*/NULL,
+        iree_hal_amdgpu_host_queue_post_commit_callback_null(),
+        /*resource_set=*/NULL,
+        IREE_HAL_AMDGPU_HOST_QUEUE_SUBMISSION_FLAG_RETAIN_RESOURCES,
+        &submission.ready, /*out_submission_id=*/NULL);
+    if (iree_status_is_ok(status) && !submission.ready) {
+      status = iree_hal_amdgpu_host_queue_defer_host_action(
+          queue, &wait_semaphore_list, action, operation_resources,
+          operation_resource_count, &submission.deferred_op);
+      iree_hal_amdgpu_host_queue_op_submission_defer_for_capacity(&submission);
+    }
+  }
+  return iree_hal_amdgpu_host_queue_op_submission_end(&submission, status);
+}
+
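+// Queue host call entry point. Validates the call up front and then follows
+// the same resolve/submit/defer pattern as the other queue operations.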
+static iree_status_t iree_hal_amdgpu_host_queue_host_call(
+    iree_hal_amdgpu_virtual_queue_t* base_queue,
+    const iree_hal_semaphore_list_t wait_semaphore_list,
+    const iree_hal_semaphore_list_t signal_semaphore_list,
+    iree_hal_host_call_t call, const uint64_t args[4],
+    iree_hal_host_call_flags_t flags) {
+  IREE_RETURN_IF_ERROR(
+      iree_hal_amdgpu_host_queue_validate_host_call(call, args, flags));
+
+  iree_hal_amdgpu_host_queue_t* queue =
+      (iree_hal_amdgpu_host_queue_t*)base_queue;
+
+  iree_hal_amdgpu_host_queue_op_submission_t submission;
+  iree_hal_amdgpu_host_queue_op_submission_begin(queue, wait_semaphore_list,
+                                                 &submission);
+  iree_status_t status = iree_ok_status();
+  if (submission.resolution.needs_deferral) {
+    status = iree_hal_amdgpu_host_queue_defer_host_call(
+        queue, &wait_semaphore_list, &signal_semaphore_list, call, args, flags,
+        &submission.deferred_op);
+  } else {
+    status = iree_hal_amdgpu_host_queue_submit_host_call(
+        queue, &submission.resolution, signal_semaphore_list, call, args, flags,
+        &submission.ready);
+    if (iree_status_is_ok(status) && !submission.ready) {
+      status = iree_hal_amdgpu_host_queue_defer_host_call(
+          queue, &wait_semaphore_list, &signal_semaphore_list, call, args,
+          flags, &submission.deferred_op);
+      iree_hal_amdgpu_host_queue_op_submission_defer_for_capacity(&submission);
+    }
+  }
+  return iree_hal_amdgpu_host_queue_op_submission_end(&submission, status);
+}
+
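+// Queue flush entry point. Submissions publish AQL packets inline under
+// submission_mutex so there is no batched work to flush.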
+static iree_status_t iree_hal_amdgpu_host_queue_flush(
+    iree_hal_amdgpu_virtual_queue_t* base_queue) {
+  return iree_ok_status();
+}
+
+//===----------------------------------------------------------------------===//
+// Virtual queue vtable
+//===----------------------------------------------------------------------===//
+
+static void iree_hal_amdgpu_host_queue_deinitialize_vtable(
+    iree_hal_amdgpu_virtual_queue_t* base_queue) {
+  iree_hal_amdgpu_host_queue_deinitialize(
+      (iree_hal_amdgpu_host_queue_t*)base_queue);
+}
+
 static const iree_hal_amdgpu_virtual_queue_vtable_t
     iree_hal_amdgpu_host_queue_vtable = {
-        .deinitialize = iree_hal_amdgpu_host_queue_deinitialize,
+        .deinitialize = iree_hal_amdgpu_host_queue_deinitialize_vtable,
         .trim = iree_hal_amdgpu_host_queue_trim,
         .alloca = iree_hal_amdgpu_host_queue_alloca,
         .dealloca = iree_hal_amdgpu_host_queue_dealloca,
@@ -264,6 +1311,8 @@
         .copy = iree_hal_amdgpu_host_queue_copy,
         .read = iree_hal_amdgpu_host_queue_read,
         .write = iree_hal_amdgpu_host_queue_write,
+        .host_call = iree_hal_amdgpu_host_queue_host_call,
+        .dispatch = iree_hal_amdgpu_host_queue_dispatch,
         .execute = iree_hal_amdgpu_host_queue_execute,
         .flush = iree_hal_amdgpu_host_queue_flush,
 };
diff --git a/runtime/src/iree/hal/drivers/amdgpu/host_queue.h b/runtime/src/iree/hal/drivers/amdgpu/host_queue.h
index 92d0402..69d7052 100644
--- a/runtime/src/iree/hal/drivers/amdgpu/host_queue.h
+++ b/runtime/src/iree/hal/drivers/amdgpu/host_queue.h
@@ -1,4 +1,4 @@
-// Copyright 2025 The IREE Authors
+// Copyright 2026 The IREE Authors
 //
 // Licensed under the Apache License v2.0 with LLVM Exceptions.
 // See https://llvm.org/LICENSE.txt for license information.
@@ -7,35 +7,641 @@
 #ifndef IREE_HAL_DRIVERS_AMDGPU_HOST_QUEUE_H_
 #define IREE_HAL_DRIVERS_AMDGPU_HOST_QUEUE_H_
 
-#include "iree/hal/drivers/amdgpu/util/error_callback.h"
+#include "iree/async/frontier.h"
+#include "iree/async/proactor.h"
+#include "iree/async/semaphore.h"
+#include "iree/base/api.h"
+#include "iree/base/internal/arena.h"
+#include "iree/base/threading/thread.h"
+#include "iree/hal/api.h"
+#include "iree/hal/drivers/amdgpu/abi/profile.h"
+#include "iree/hal/drivers/amdgpu/abi/signal.h"
+#include "iree/hal/drivers/amdgpu/device/blit.h"
+#include "iree/hal/drivers/amdgpu/util/aql_ring.h"
+#include "iree/hal/drivers/amdgpu/util/block_pool.h"
+#include "iree/hal/drivers/amdgpu/util/epoch_signal_table.h"
+#include "iree/hal/drivers/amdgpu/util/kernarg_ring.h"
+#include "iree/hal/drivers/amdgpu/util/libhsa.h"
+#include "iree/hal/drivers/amdgpu/util/notification_ring.h"
+#include "iree/hal/drivers/amdgpu/util/pm4_capabilities.h"
+#include "iree/hal/drivers/amdgpu/util/queue_upload_ring.h"
 #include "iree/hal/drivers/amdgpu/virtual_queue.h"
+#include "iree/hal/pool.h"
+#include "iree/hal/profile_schema.h"
+#include "iree/hal/profile_sink.h"
 
-typedef struct iree_arena_block_pool_t iree_arena_block_pool_t;
-typedef struct iree_hal_amdgpu_block_allocators_t
-    iree_hal_amdgpu_block_allocators_t;
-typedef struct iree_hal_amdgpu_buffer_pool_t iree_hal_amdgpu_buffer_pool_t;
-typedef struct iree_hal_amdgpu_host_service_t iree_hal_amdgpu_host_service_t;
-typedef struct iree_hal_amdgpu_system_t iree_hal_amdgpu_system_t;
+#ifdef __cplusplus
+extern "C" {
+#endif  // __cplusplus
+
+typedef struct iree_hal_amdgpu_pending_op_t iree_hal_amdgpu_pending_op_t;
+typedef struct iree_hal_amdgpu_pm4_ib_slot_t iree_hal_amdgpu_pm4_ib_slot_t;
+typedef struct iree_hal_amdgpu_host_queue_command_buffer_scratch_t
+    iree_hal_amdgpu_host_queue_command_buffer_scratch_t;
+typedef struct iree_hal_amdgpu_profile_counter_sample_slot_t
+    iree_hal_amdgpu_profile_counter_sample_slot_t;
+typedef struct iree_hal_amdgpu_profile_counter_range_slot_t
+    iree_hal_amdgpu_profile_counter_range_slot_t;
+typedef struct iree_hal_amdgpu_profile_counter_session_t
+    iree_hal_amdgpu_profile_counter_session_t;
+typedef struct iree_hal_amdgpu_profile_trace_session_t
+    iree_hal_amdgpu_profile_trace_session_t;
+typedef struct iree_hal_amdgpu_profile_trace_slot_t
+    iree_hal_amdgpu_profile_trace_slot_t;
+typedef struct iree_hal_amdgpu_staging_pool_t iree_hal_amdgpu_staging_pool_t;
+typedef struct iree_hal_amdgpu_transient_buffer_pool_t
+    iree_hal_amdgpu_transient_buffer_pool_t;
+typedef struct iree_async_frontier_tracker_t iree_async_frontier_tracker_t;
+
+// Queue-local reservation of dispatch profiling event records.
+typedef struct iree_hal_amdgpu_profile_dispatch_event_reservation_t {
+  // Logical ring position of the first reserved dispatch event.
+  uint64_t first_event_position;
+  // Number of reserved dispatch events.
+  uint32_t event_count;
+  // Reserved padding.
+  uint32_t reserved0;
+} iree_hal_amdgpu_profile_dispatch_event_reservation_t;
+
+// Queue-local reservation of device-timestamped queue operation records.
+typedef struct iree_hal_amdgpu_profile_queue_device_event_reservation_t {
+  // Logical ring position of the first reserved queue device event.
+  uint64_t first_event_position;
+  // Number of reserved queue device events.
+  uint32_t event_count;
+  // Reserved padding.
+  uint32_t reserved0;
+} iree_hal_amdgpu_profile_queue_device_event_reservation_t;
+
+typedef struct iree_hal_amdgpu_host_queue_post_drain_action_t
+    iree_hal_amdgpu_host_queue_post_drain_action_t;
+
+// Callback run by the completion thread after notification-ring drain has
+// published completed entries and reclaimed queue-owned ring state.
+typedef void(IREE_API_PTR* iree_hal_amdgpu_host_queue_post_drain_fn_t)(
+    void* user_data);
+
+// Intrusive completion-thread continuation queued by pre-signal reclaim
+// actions.
+//
+// Pre-signal actions run while notification-ring drain is still publishing a
+// completion entry. Work that may submit additional AQL packets must instead
+// queue one of these actions so it runs after drain has released all completed
+// notification/kernarg state.
+struct iree_hal_amdgpu_host_queue_post_drain_action_t {
+  // Next action in the queue-owned pending list.
+  iree_hal_amdgpu_host_queue_post_drain_action_t* next;
+  // Callback invoked exactly once after the action is dequeued.
+  iree_hal_amdgpu_host_queue_post_drain_fn_t fn;
+  // User data passed to |fn|.
+  void* user_data;
+};
 
 //===----------------------------------------------------------------------===//
 // iree_hal_amdgpu_host_queue_t
 //===----------------------------------------------------------------------===//
 
-// Calculates the size in bytes of the storage required for a queue
-// implementation based on the provided |options|.
-iree_host_size_t iree_hal_amdgpu_host_queue_calculate_size(
-    const iree_hal_amdgpu_queue_options_t* options);
+// Maximum number of frontier entries the queue's accumulated frontier can
+// track. Each entry is one (axis, epoch) pair representing a causal
+// dependency on another queue or device. 64 entries cover rack-scale
+// systems (8 machines x 8 GPUs x 4 queues = 256 theoretical axes, though a
+// single queue typically waits only on its collective peers, around 8-16
+// axes).
+// Overflow is handled gracefully (frontier merge returns false, wait elision
+// degrades but correctness is preserved).
+//
+// Transition snapshots serialize this frontier verbatim, so the queue's
+// frontier capacity is tied to the notification ring's snapshot entry limit.
+#define IREE_HAL_AMDGPU_QUEUE_FRONTIER_CAPACITY \
+  IREE_HAL_AMDGPU_MAX_FRONTIER_SNAPSHOT_ENTRY_COUNT
 
-// Initializes |out_queue| in-place based on |options|.
+IREE_ASYNC_FIXED_FRONTIER_TYPE(iree_hal_amdgpu_host_queue_frontier_t,
+                               IREE_HAL_AMDGPU_QUEUE_FRONTIER_CAPACITY);
+
+// Maximum number of direct buffer bindings accepted by queue_dispatch.
+//
+// Command buffers support large binding tables through their own lifetime
+// tracking path. Direct dispatch keeps the submission path bounded and uses
+// queue-local scratch storage under submission_mutex.
+#define IREE_HAL_AMDGPU_HOST_QUEUE_DISPATCH_SCRATCH_BINDING_CAPACITY 256u
+
+// Maximum number of operation resources retained by one direct queue_dispatch:
+// the executable, one optional indirect-parameter buffer, plus one resource per
+// direct buffer binding.
+#define IREE_HAL_AMDGPU_HOST_QUEUE_DISPATCH_SCRATCH_RESOURCE_CAPACITY \
+  (2u + IREE_HAL_AMDGPU_HOST_QUEUE_DISPATCH_SCRATCH_BINDING_CAPACITY)
+
+// Host-driven queue with per-queue epoch signal and wait-backed
+// notification ring. Embeds iree_hal_amdgpu_virtual_queue_t at offset 0.
+//
+// The epoch signal (owned by the notification ring) is a single hsa_signal_t
+// set as completion_signal on each submission's last AQL packet. The CP
+// decrements it by 1 on completion. The notification ring maps epochs to
+// semaphore signals that the queue's completion thread drains when the epoch
+// advances.
+//
+// All queue operations enter through the virtual_queue vtable. There are no
+// public methods beyond initialize/deinitialize.
+typedef struct iree_hal_amdgpu_host_queue_t {
+  // Virtual queue vtable at offset 0.
+  iree_hal_amdgpu_virtual_queue_t base;
+
+  // HSA API handle for queue operations. Not retained.
+  const iree_hal_amdgpu_libhsa_t* libhsa;
+  // Logical device owning this queue. Not retained.
+  iree_hal_device_t* logical_device;
+  // Proactor used to arm async semaphore/timepoint waits. Borrowed from the
+  // logical device.
+  iree_async_proactor_t* proactor;
+  // Shared frontier tracker for this queue's axis. Borrowed from the logical
+  // device.
+  iree_async_frontier_tracker_t* frontier_tracker;
+  // Allocator used for host-side queue resources.
+  iree_allocator_t host_allocator;
+
+  // Sticky error status from the HSA queue error callback. Non-zero indicates
+  // an unrecoverable GPU fault (page fault, invalid packet, ECC error).
+  // First-error-wins CAS from the HSA runtime thread; acquire-loaded by the
+  // completion thread to fail pending semaphores instead of signaling.
+  // Owned by the queue (freed in deinit).
+  iree_atomic_intptr_t error_status;
+
+  // Hardware AQL queue created via hsa_queue_create. Owned by this queue.
+  hsa_queue_t* hardware_queue;
+
+  // Cached AQL ring state for zero-indirection packet submission.
+  // Initialized from hardware_queue at init time.
+  iree_hal_amdgpu_aql_ring_t aql_ring;
+
+  // Per-queue kernarg bump allocator backed by HSA kernarg-init memory.
+  iree_hal_amdgpu_kernarg_ring_t kernarg_ring;
+
+  // Per-queue upload ring for device-visible control records.
+  // Submission paths reserve from this only when they have queue-ordered
+  // metadata such as device-side fixup inputs.
+  iree_hal_amdgpu_queue_upload_ring_t queue_upload_ring;
+
+  // Optional per-AQL-slot PM4 IB buffer used by PM4-backed wait, transfer, and
+  // profiling snippets. This is not an independent scheduling ring: each slot
+  // is indexed by the matching AQL packet id and inherits the AQL ring's
+  // lifetime/backpressure.
+  iree_hal_amdgpu_pm4_ib_slot_t* pm4_ib_slots;
+
+  // Epoch-driven notification ring mapping submission completions to
+  // semaphore signals. The completion thread drains this ring.
+  iree_hal_amdgpu_notification_ring_t notification_ring;
+
+  // Completion-thread state for queue epoch drain and teardown/error wakeups.
+  struct {
+    // Host thread blocked on the queue epoch signal and draining completed
+    // notification-ring entries.
+    iree_thread_t* thread;
+    // HSA signal used to wake the completion thread during teardown or after
+    // an unrecoverable HSA queue error. Value 0 means the thread should
+    // continue waiting for completions; any other value requests exit after a
+    // final drain.
+    hsa_signal_t stop_signal;
+  } completion;
+
+  //--- Submission pipeline state -------------------------------------------//
+  //
+  // Threading model: the queue has three execution contexts.
+  //
+  //   Submission (any thread, serialized by submission_mutex):
+  //     AQL slot reservation, packet fill, notification ring push, frontier
+  //     snapshot push, queue frontier mutation, last_signal update, kernarg
+  //     allocation. Multiple threads may submit to the same queue; the mutex
+  //     serializes them. Independent queues do not synchronize.
+  //
+  //   Completion thread (single queue-owned host thread):
+  //     Waits on the notification ring epoch signal with
+  //     hsa_amd_signal_wait_any, drains completed entries, checks error_status,
+  //     and reclaims kernargs. Reads the notification ring (SPSC consumer) and
+  //     the atomic error_status. Never writes to submission-path fields.
+  //
+  //   HSA error callback (HSA runtime thread):
+  //     Writes error_status via atomic CAS. Signals
+  //     completion.stop_signal so the completion thread wakes and fails
+  //     outstanding notifications.
+  //
+  // Wait-resolution fast-path contract:
+  //   - Same-queue signal-before-wait is elided directly from the semaphore's
+  //     last_signal cache when the cached producer axis matches queue->axis
+  //     under this strategy's current all-BARRIER AQL policy.
+  //   - Local cross-queue waits use one producer epoch barrier when the
+  //     semaphore cache marks that producer frontier as exact and this queue's
+  //     frontier does not already dominate that producer axis/epoch.
+  //   - The full semaphore-frontier mutex/copy path is reserved for unresolved
+  //     waits whose cached producer frontier is not exact (for example, true
+  //     multi-producer fan-in) or for conservative fallback after
+  //     cache/frontier overflow.
+  //   - Wait-before-signal, remote/non-queue-domain axes, and queue teardown
+  //     use software deferral.
+  //
+  // Signal-commit fast-path contract:
+  //   - Each successful AQL submission advances this queue's epoch, merges this
+  //     queue axis into queue->frontier, reserves one queue-private reclaim
+  //     slot, pushes one notification-ring entry per user-visible signal
+  //     semaphore, and records enough signal metadata for completion drain.
+  //     Zero-signal submissions still consume one queue epoch and reclaim slot
+  //     so kernel resources retire through the same mechanism.
+  //   - Public/multi-producer semaphores publish queue->frontier under the
+  //     semaphore mutex so later waits can prove transitive dependencies.
+  //   - Private single-producer AMDGPU stream semaphores skip that mutex/copy
+  //     path and publish only the producer queue axis/epoch/value to the
+  //     seqlock-protected last_signal cache. Waiting on that producer epoch is
+  //     sufficient because all transitive waits are encoded before the
+  //     producer queue epoch can complete.
+
+  // Queue-local locks. Keep the 4-byte slim mutexes packed before pointer-sized
+  // continuation state.
+  struct {
+    // Serializes the submission path. All queue operations (dispatch, copy,
+    // fill, execute, etc.) acquire this before touching submission state and
+    // release after signal commit. The proactor thread does not acquire this.
+    iree_slim_mutex_t submission_mutex;
+    // Serializes the post-drain continuation list.
+    iree_slim_mutex_t post_drain_mutex;
+  } locks;
+
+  // Post-drain continuation queue for work that cannot run while notification
+  // drain is still publishing or reclaiming ring state.
+  struct {
+    // First queued post-drain continuation.
+    iree_hal_amdgpu_host_queue_post_drain_action_t* head;
+    // Tail pointer for appending post-drain continuations.
+    iree_hal_amdgpu_host_queue_post_drain_action_t* tail;
+  } post_drain;
+
+  // Queue-local scratch used by queue_dispatch under submission_mutex.
+  struct {
+    // Operation resources copied into the notification reclaim entry.
+    iree_hal_resource_t* operation_resources
+        [IREE_HAL_AMDGPU_HOST_QUEUE_DISPATCH_SCRATCH_RESOURCE_CAPACITY];
+    // Resolved device pointers written into final dispatch kernargs.
+    uint64_t binding_ptrs
+        [IREE_HAL_AMDGPU_HOST_QUEUE_DISPATCH_SCRATCH_BINDING_CAPACITY];
+  } dispatch_scratch;
+
+  // Lazily allocated queue_execute scratch storage. Kept out of the host queue
+  // object so direct-dispatch hot state does not carry command-buffer sideband
+  // arrays.
+  iree_hal_amdgpu_host_queue_command_buffer_scratch_t* command_buffer_scratch;
+
+  // Set under submission_mutex when queue teardown begins. Deferred ops whose
+  // waits race to completion after this point are failed with CANCELLED instead
+  // of issuing new AQL packets.
+  bool is_shutting_down;
+
+  // Profiling data-family state for this queue. Mutated only by device
+  // profiling begin/end while the profiling API's idle-device precondition is
+  // held.
+  struct {
+    // True when ROCR should populate dispatch completion signal timestamps.
+    uint32_t hsa_queue_timestamps_enabled : 1;
+    // True when host-side queue operation events should be recorded.
+    uint32_t queue_events_enabled : 1;
+    // True when device-timestamped queue operation events should be recorded.
+    uint32_t queue_device_events_enabled : 1;
+    // True when selected dispatches may receive profile packet augmentation.
+    uint32_t dispatch_profiling_enabled : 1;
+    // Serializes profile event ring mutation and flush.
+    iree_slim_mutex_t event_mutex;
+    // Raw completion-signal storage paired with dispatch event slots.
+    struct {
+      // Borrowed fine-grained GPU-agent block pool backing raw signal storage.
+      iree_hal_amdgpu_block_pool_t* block_pool;
+      // Host-side table of queue-owned GPU-agent raw signal blocks.
+      iree_hal_amdgpu_block_t** blocks;
+      // Number of entries in |blocks|.
+      uint32_t block_count;
+      // Number of iree_amd_signal_t records in each block.
+      uint32_t signals_per_block;
+    } signals;
+    // Shared device-visible allocation backing queue-local event rings.
+    struct {
+      // Allocation base returned by HSA memory pool allocation.
+      void* base;
+      // Byte length of |base|.
+      iree_host_size_t size;
+    } event_allocation;
+    // Device-visible dispatch event ring waiting for sink flush.
+    struct {
+      // Dispatch event record storage in the shared event allocation.
+      iree_hal_amdgpu_profile_dispatch_event_t* values;
+      // Power-of-two capacity of |values| in records.
+      uint32_t capacity;
+      // Capacity minus one, for mapping logical positions to physical slots.
+      uint32_t mask;
+      // Logical ring position of the next event to write to the sink.
+      uint64_t read_position;
+      // Logical ring position one past the last event ready to write.
+      uint64_t ready_position;
+      // Logical ring position one past the last reserved event.
+      uint64_t write_position;
+      // Next queue-local dispatch event id assigned during submission.
+      uint64_t next_event_id;
+    } dispatch_events;
+    // Device-visible queue operation event ring waiting for sink flush.
+    struct {
+      // Queue device event record storage in the shared event allocation.
+      iree_hal_amdgpu_profile_queue_device_event_t* values;
+      // Power-of-two capacity of |values| in records.
+      uint32_t capacity;
+      // Capacity minus one, for mapping logical positions to physical slots.
+      uint32_t mask;
+      // Logical ring position of the next event to write to the sink.
+      uint64_t read_position;
+      // Logical ring position one past the last event ready to write.
+      uint64_t ready_position;
+      // Logical ring position one past the last reserved event.
+      uint64_t write_position;
+      // Next queue-local queue-device event id assigned during submission.
+      uint64_t next_event_id;
+    } queue_device_events;
+    // Queue-local hardware counter profile resources.
+    struct {
+      // Borrowed hardware counter session active for this queue, or NULL.
+      iree_hal_amdgpu_profile_counter_session_t* session;
+      // Number of selected counter sets in |session|.
+      uint32_t set_count;
+      // Dispatch-attributed counter sample storage.
+      struct {
+        // Host-side slot table pairing dispatch event slots with aqlprofile
+        // handles.
+        iree_hal_amdgpu_profile_counter_sample_slot_t* slots;
+      } dispatch_samples;
+      // Queue-range counter sample storage.
+      struct {
+        // Host-side slot table pairing range banks with aqlprofile handles.
+        iree_hal_amdgpu_profile_counter_range_slot_t* slots;
+        // Device-visible timing records for each range bank.
+        uint64_t* ticks;
+        // Byte length of |ticks|.
+        iree_host_size_t tick_storage_size;
+        // Bank currently capturing queue work.
+        uint32_t active_bank;
+        // Number of reusable range banks in |slots| and |ticks|.
+        uint32_t bank_count;
+        // True when a range bank has been started and must be stopped.
+        bool is_active;
+      } ranges;
+    } counters;
+    // Queue-local executable trace profile resources.
+    struct {
+      // Borrowed executable trace session active for this queue, or NULL.
+      iree_hal_amdgpu_profile_trace_session_t* session;
+      // Host-side slot table pairing dispatch event slots with ATT handles.
+      iree_hal_amdgpu_profile_trace_slot_t* slots;
+    } traces;
+  } profiling;
+
+  // False once this queue's accumulated frontier overflows while merging waited
+  // axes. After that, the frontier remains a safe lower bound for resolving
+  // this queue's own waits, but it is no longer a conservative summary that can
+  // be published to public/multi-producer signal semaphores. Those signal
+  // commits therefore clear last_signal, skip semaphore-frontier merges, and
+  // stop pushing transition snapshots, forcing downstream not-yet-complete
+  // waits onto the software path instead of under-barriering.
+  bool can_publish_frontier;
+
+  // This queue's axis in the causal graph. Constructed from the system's
+  // session epoch + machine index and this queue's device/queue ordinals.
+  // Used to identify this queue in frontier entries and epoch signal lookups.
+  // Immutable after initialization.
+  iree_async_axis_t axis;
+
+  // Device-side wait strategy selected once from the GPU ISA at initialization.
+  iree_hal_amdgpu_wait_barrier_strategy_t wait_barrier_strategy;
+
+  // AMD vendor-packet capabilities selected from the GPU ISA.
+  iree_hal_amdgpu_vendor_packet_capability_flags_t vendor_packet_capabilities;
+
+  // One-bit logical queue affinity identifying this queue in HAL buffer
+  // placements. queue_alloca uses this as the transient wrapper's origin so
+  // PREFER_ORIGIN dealloca routes back to the same queue.
+  iree_hal_queue_affinity_t queue_affinity;
+
+  // Shared epoch signal table for cross-queue barrier emission (tier 2 wait
+  // resolution). Maps (device_index, queue_index) to hsa_signal_t for each
+  // queue's epoch signal. Used to look up peer queues' epoch signals when
+  // emitting AQL barrier-value packets for multi-axis dependencies (e.g.,
+  // TP collective joins needing barriers on 7 peer queues).
+  //
+  // Borrowed from the device/system — valid for the lifetime of the queue.
+  // This queue's own epoch signal is registered at init and deregistered at
+  // deinit. Read-only during normal operation.
+  iree_hal_amdgpu_epoch_signal_table_t* epoch_table;
+
+  // Last semaphore pushed to the notification ring and its epoch. Used to
+  // detect semaphore transitions for frontier snapshot recording: when a
+  // push targets a different semaphore than last_signal.semaphore, the
+  // signal commit path writes a frontier snapshot at last_signal.epoch
+  // before starting the new span.
+  //
+  // Protected by submission_mutex (submission-context-only).
+  //
+  // ABA safety: the semaphore pointer is only compared for identity (not
+  // dereferenced) during transition detection. ABA can occur if a semaphore
+  // is released and a new one is allocated at the same address between two
+  // submissions. This is benign:
+  //   - The old semaphore's notification entries must have been drained
+  //     before release (notification ring lifetime contract), so no
+  //     undrained entries for the old semaphore remain.
+  //   - A missed transition causes the new semaphore's entries to be
+  //     coalesced with a span that has no pending entries — the drain
+  //     produces the correct signal for the new semaphore.
+  //   - The frontier snapshot at the end of the coalesced span (when the
+  //     next transition occurs, or the fallback frontier at drain end)
+  //     captures the queue's accumulated frontier, which is an upper bound
+  //     on the actual causal context. Over-attribution (conservative), never
+  //     under-attribution (unsafe).
+  struct {
+    // Most recent semaphore pushed to the notification ring.
+    iree_async_semaphore_t* semaphore;
+    // Queue epoch associated with the most recent semaphore push.
+    uint64_t epoch;
+    // True when the current same-semaphore span requires a frontier snapshot
+    // if the next signal targets a different semaphore.
+    bool needs_frontier_snapshot;
+    // Reserved padding for stable layout.
+    uint8_t reserved[7];
+  } last_signal;
+
+  // Block pool for arena-allocating deferred operations. NUMA-pinned to the
+  // physical device. Borrowed from the physical device; valid for the
+  // lifetime of the queue.
+  iree_arena_block_pool_t* block_pool;
+
+  // Ordinal of this queue's physical device within the topology. Used to look
+  // up device-specific kernel_args from executables via
+  // iree_hal_amdgpu_executable_lookup_kernel_args_for_device.
+  iree_host_size_t device_ordinal;
+
+  // Builtin blit kernel table for this queue's physical device. Borrowed from
+  // the physical device and immutable for the queue's lifetime.
+  const iree_hal_amdgpu_device_buffer_transfer_context_t* transfer_context;
+
+  // Borrowed default pool set for this queue's physical device.
+  const iree_hal_pool_set_t* default_pool_set;
+
+  // Borrowed TLSF default pool for this queue's physical device.
+  iree_hal_pool_t* default_pool;
+
+  // Borrowed transient wrapper pool for queue_alloca results.
+  iree_hal_amdgpu_transient_buffer_pool_t* transient_buffer_pool;
+
+  // Borrowed fixed-size staging pool used by queue_read/queue_write for
+  // non-mappable file transfers.
+  iree_hal_amdgpu_staging_pool_t* staging_pool;
+
+  // Intrusive singly-linked list of pending (deferred) operations. Used for
+  // cleanup on shutdown and GPU fault propagation. Operations add themselves
+  // on deferral and remove themselves on issue/fail/cancel. Protected by
+  // submission_mutex.
+  iree_hal_amdgpu_pending_op_t* pending_head;
+
+  // Accumulated frontier. Advances on each AQL submission: the queue's own
+  // axis entry is set to the current epoch, and cross-queue wait dependencies
+  // are merged in. Used for:
+  //   - Queue-order wait elision (tier 1): queue->frontier dominates the wait
+  //     semaphore's frontier → no additional barrier packet is needed.
+  //   - Submission-time causal merge: merged into signal semaphores' frontiers
+  //     at AQL submission time so same-queue and already-dominated cross-queue
+  //     waits can resolve before GPU completion under the current all-barrier
+  //     AQL queue policy.
+  //   - Frontier snapshot recording: snapshotted to the notification ring's
+  //     frontier byte ring at semaphore transitions.
+  //
+  // Fixed-capacity storage for the accumulated frontier.
+  iree_hal_amdgpu_host_queue_frontier_t frontier;
+} iree_hal_amdgpu_host_queue_t;
+
+// Returns a pointer to the queue's accumulated frontier. The returned pointer
+// is layout-compatible with iree_async_frontier_t and valid for all frontier
+// APIs (compare, merge, etc.). Valid for the lifetime of the queue.
+static inline iree_async_frontier_t* iree_hal_amdgpu_host_queue_frontier(
+    iree_hal_amdgpu_host_queue_t* queue) {
+  return iree_async_fixed_frontier_as_frontier(&queue->frontier);
+}
+
+// Returns a const pointer to the queue's accumulated frontier.
+static inline const iree_async_frontier_t*
+iree_hal_amdgpu_host_queue_const_frontier(
+    const iree_hal_amdgpu_host_queue_t* queue) {
+  return iree_async_fixed_frontier_as_const_frontier(&queue->frontier);
+}
+
+// Submits a buffer-copy payload through the queue with the requested queue
+// profiling event type.
+iree_status_t iree_hal_amdgpu_host_queue_copy_buffer(
+    iree_hal_amdgpu_host_queue_t* queue,
+    const iree_hal_semaphore_list_t wait_semaphore_list,
+    const iree_hal_semaphore_list_t signal_semaphore_list,
+    iree_hal_buffer_t* source_buffer, iree_device_size_t source_offset,
+    iree_hal_buffer_t* target_buffer, iree_device_size_t target_offset,
+    iree_device_size_t length, iree_hal_copy_flags_t flags,
+    iree_hal_profile_queue_event_type_t profile_event_type);
+
+// Enqueues a driver-owned host action ordered after |wait_semaphore_list|.
+// |action| uses the reclaim-action status ownership contract: OK means the
+// ordering barrier completed, while non-OK is a borrowed queue/device failure
+// status that must be cloned before any async propagation.
+// |operation_resources| are retained before this returns and released after the
+// action has executed or failed; callers keep ownership of their references.
+iree_status_t iree_hal_amdgpu_host_queue_enqueue_host_action(
+    iree_hal_amdgpu_host_queue_t* queue,
+    const iree_hal_semaphore_list_t wait_semaphore_list,
+    iree_hal_amdgpu_reclaim_action_t action,
+    iree_hal_resource_t* const* operation_resources,
+    iree_host_size_t operation_resource_count);
+
+// Enqueues |action| to run on the queue completion thread after the current or
+// next notification-ring drain has fully published completed entries. The
+// action storage must remain valid until |action->fn| is invoked.
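+//
+// Illustrative use (|my_continuation| and |op| are hypothetical):
+//   iree_hal_amdgpu_host_queue_enqueue_post_drain_action(
+//       queue, &op->post_drain_action, my_continuation, /*user_data=*/op);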
+void iree_hal_amdgpu_host_queue_enqueue_post_drain_action(
+    iree_hal_amdgpu_host_queue_t* queue,
+    iree_hal_amdgpu_host_queue_post_drain_action_t* action,
+    iree_hal_amdgpu_host_queue_post_drain_fn_t fn, void* user_data);
+
+// Initializes a host queue in caller-provided memory.
+// The caller must allocate at least sizeof(iree_hal_amdgpu_host_queue_t).
+//
+// Creates an HSA hardware queue on |gpu_agent|, initializes the AQL ring from
+// it, allocates a kernarg ring from |kernarg_memory|, creates the epoch signal
+// and notification ring, and starts the completion thread.
+//
+// |axis| is this queue's identity in the causal graph, constructed by the
+// caller from the system's session/machine identifiers and this queue's
+// device/queue ordinals via iree_async_axis_make_queue().
+//
+// |epoch_table| is the shared epoch signal table for cross-queue barrier
+// emission. This queue registers its epoch signal in the table at init and
+// deregisters at deinit. The table must outlive the queue.
+//
+// |completion_thread_affinity| pins the completion thread near the host CPU
+// agent associated with the GPU. The platform may ignore the request, but on
+// NUMA-aware systems this keeps blocked-wait wakeups and notification-ring
+// drains close to the GPU's nearest CPU node.
+//
+// |aql_queue_capacity| is the power-of-two hardware AQL queue size in packets.
+// |notification_capacity| is the power-of-two notification ring size.
+// |kernarg_capacity_in_blocks| is the power-of-two kernarg ring size in
+// 64-byte blocks, at least 2x |aql_queue_capacity| to cover one tail-padding
+// gap at wrap. Submission admission proves space in both the AQL and kernarg
+// rings before publishing packets.
+// |upload_capacity| is the byte capacity of the device-visible control upload
+// ring used for queue-ordered submission metadata. Zero disables the optional
+// upload ring; non-zero values must be powers of two.
+//
+// |vendor_packet_capabilities| describes the AQL/PM4 vendor-packet support
+// selected from the physical device ISA. Queues allocate dynamic PM4 IB slots
+// when AQL_PM4_IB is available so BARRIER_VALUE-based CDNA queues can still use
+// PM4 snippets for profiling or tiny operations.
+//
+// |profiling_signal_block_pool| provides fine-grained GPU-agent memory used for
+// raw iree_amd_signal_t records. The host initializes these records once when
+// timestamp profiling begins; packets only use them for CP-written profiling
+// timestamps and never for host HSA waits or interrupts.
 iree_status_t iree_hal_amdgpu_host_queue_initialize(
-    iree_hal_amdgpu_system_t* system, iree_hal_amdgpu_queue_options_t options,
-    hsa_agent_t device_agent, iree_host_size_t device_ordinal,
-    iree_hal_amdgpu_host_service_t* host_service,
-    iree_arena_block_pool_t* host_block_pool,
-    iree_hal_amdgpu_block_allocators_t* block_allocators,
-    iree_hal_amdgpu_buffer_pool_t* buffer_pool,
-    iree_hal_amdgpu_error_callback_t error_callback,
-    hsa_signal_t initialization_signal, iree_allocator_t host_allocator,
-    iree_hal_amdgpu_virtual_queue_t* out_queue);
+    const iree_hal_amdgpu_libhsa_t* libhsa, iree_hal_device_t* logical_device,
+    iree_async_proactor_t* proactor, hsa_agent_t gpu_agent,
+    const iree_hal_amdgpu_kernarg_ring_memory_t* kernarg_memory,
+    hsa_amd_memory_pool_t pm4_ib_pool,
+    iree_async_frontier_tracker_t* frontier_tracker, iree_async_axis_t axis,
+    iree_hal_queue_affinity_t queue_affinity,
+    iree_thread_affinity_t completion_thread_affinity,
+    iree_hal_amdgpu_wait_barrier_strategy_t wait_barrier_strategy,
+    iree_hal_amdgpu_vendor_packet_capability_flags_t vendor_packet_capabilities,
+    iree_hal_amdgpu_epoch_signal_table_t* epoch_table,
+    iree_arena_block_pool_t* block_pool,
+    iree_hal_amdgpu_block_pool_t* profiling_signal_block_pool,
+    const iree_hal_amdgpu_device_buffer_transfer_context_t* transfer_context,
+    const iree_hal_pool_set_t* default_pool_set, iree_hal_pool_t* default_pool,
+    iree_hal_amdgpu_transient_buffer_pool_t* transient_buffer_pool,
+    iree_hal_amdgpu_staging_pool_t* staging_pool,
+    iree_host_size_t device_ordinal, uint32_t aql_queue_capacity,
+    uint32_t notification_capacity, uint32_t kernarg_capacity_in_blocks,
+    uint32_t upload_capacity, iree_allocator_t host_allocator,
+    iree_hal_amdgpu_host_queue_t* out_queue);
+
+// Deinitializes the queue. Destroys all owned resources and stops the
+// completion thread.
+//
+// All in-flight work must have completed and been drained before calling.
+// The caller must ensure no concurrent access to the queue during deinit.
+void iree_hal_amdgpu_host_queue_deinitialize(
+    iree_hal_amdgpu_host_queue_t* queue);
+
+// Enables or disables HSA dispatch timestamp population for this queue.
+//
+// This toggles the ROCR queue profiler bit. It is a cold profiling-session
+// operation and must only be called while the device is idle, matching the HAL
+// profiling API contract.
+iree_status_t iree_hal_amdgpu_host_queue_set_hsa_profiling_enabled(
+    iree_hal_amdgpu_host_queue_t* queue, bool enabled);
+
+#ifdef __cplusplus
+}  // extern "C"
+#endif  // __cplusplus
 
 #endif  // IREE_HAL_DRIVERS_AMDGPU_HOST_QUEUE_H_
diff --git a/runtime/src/iree/hal/drivers/amdgpu/host_queue_blit.c b/runtime/src/iree/hal/drivers/amdgpu/host_queue_blit.c
new file mode 100644
index 0000000..65aedb2
--- /dev/null
+++ b/runtime/src/iree/hal/drivers/amdgpu/host_queue_blit.c
@@ -0,0 +1,846 @@
+// Copyright 2026 The IREE Authors
+//
+// Licensed under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+
+#include "iree/hal/drivers/amdgpu/host_queue_blit.h"
+
+#include <string.h>
+
+#include "iree/base/alignment.h"
+#include "iree/hal/drivers/amdgpu/buffer.h"
+#include "iree/hal/drivers/amdgpu/device/blit.h"
+#include "iree/hal/drivers/amdgpu/host_queue_profile.h"
+#include "iree/hal/drivers/amdgpu/util/pm4_emitter.h"
+
+static_assert(IREE_HAL_AMDGPU_DEVICE_BUFFER_FILL_KERNARG_SIZE <=
+                  sizeof(iree_hal_amdgpu_kernarg_block_t),
+              "fill kernargs must fit in one kernarg ring block");
+
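+// Builds the single-operation profile event info used by blit-backed fill,
+// copy, and update submissions.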
+static iree_hal_amdgpu_host_queue_profile_event_info_t
+iree_hal_amdgpu_host_queue_make_blit_profile_event_info(
+    iree_hal_profile_queue_event_type_t type, uint64_t payload_length) {
+  iree_hal_amdgpu_host_queue_profile_event_info_t info = {
+      .type = type,
+      .payload_length = payload_length,
+      .operation_count = 1,
+  };
+  return info;
+}
+
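+// Stamps the committed |submission_id| on |info| and records the queue
+// profile event for a blit operation that has already been submitted.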
+static void iree_hal_amdgpu_host_queue_record_submitted_blit_profile_event(
+    iree_hal_amdgpu_host_queue_t* queue,
+    const iree_hal_amdgpu_wait_resolution_t* resolution,
+    const iree_hal_semaphore_list_t signal_semaphore_list,
+    uint64_t submission_id,
+    iree_hal_amdgpu_host_queue_profile_event_info_t* info) {
+  info->submission_id = submission_id;
+  iree_hal_amdgpu_host_queue_record_profile_queue_event(
+      queue, resolution, signal_semaphore_list, info);
+}
+
+// PM4 WRITE_DATA payload for tiny queue_fill/queue_update operations.
+typedef struct iree_hal_amdgpu_host_queue_pm4_write_data_t {
+  // Device-visible target pointer written by PM4 WRITE_DATA.
+  void* target_device_ptr;
+  // Immediate value written to |target_device_ptr|.
+  uint64_t value;
+  // Byte length written from |value|; currently 4 or 8.
+  uint8_t length;
+} iree_hal_amdgpu_host_queue_pm4_write_data_t;
+
+// PM4 COPY_DATA payload for tiny queue_copy operations.
+typedef struct iree_hal_amdgpu_host_queue_pm4_copy_data_t {
+  // Device-visible source pointer read by PM4 COPY_DATA.
+  const void* source_device_ptr;
+  // Device-visible target pointer written by PM4 COPY_DATA.
+  void* target_device_ptr;
+  // Byte length copied from source to target; currently 4 or 8.
+  uint8_t length;
+} iree_hal_amdgpu_host_queue_pm4_copy_data_t;
+
+// Validates a queue_fill target and resolves the target device pointer.
+static iree_status_t iree_hal_amdgpu_host_queue_prepare_fill_target(
+    iree_hal_buffer_t* target_buffer, iree_device_size_t target_offset,
+    iree_device_size_t length, iree_host_size_t pattern_length,
+    iree_hal_fill_flags_t flags, uint8_t** out_target_device_ptr) {
+  *out_target_device_ptr = NULL;
+
+  if (IREE_UNLIKELY(!target_buffer)) {
+    return iree_make_status(IREE_STATUS_INVALID_ARGUMENT,
+                            "target buffer must be non-null");
+  }
+  IREE_RETURN_IF_ERROR(iree_hal_buffer_validate_usage(
+      iree_hal_buffer_allowed_usage(target_buffer),
+      IREE_HAL_BUFFER_USAGE_TRANSFER_TARGET));
+  IREE_RETURN_IF_ERROR(iree_hal_buffer_validate_access(
+      iree_hal_buffer_allowed_access(target_buffer),
+      IREE_HAL_MEMORY_ACCESS_WRITE));
+  IREE_RETURN_IF_ERROR(
+      iree_hal_buffer_validate_range(target_buffer, target_offset, length));
+
+  if (IREE_UNLIKELY(pattern_length != 1 && pattern_length != 2 &&
+                    pattern_length != 4)) {
+    return iree_make_status(
+        IREE_STATUS_INVALID_ARGUMENT,
+        "fill patterns must be 1, 2, or 4 bytes (got %" PRIhsz ")",
+        pattern_length);
+  }
+  if (IREE_UNLIKELY(flags != IREE_HAL_FILL_FLAG_NONE)) {
+    return iree_make_status(IREE_STATUS_INVALID_ARGUMENT,
+                            "unsupported fill flags: 0x%" PRIx64, flags);
+  }
+
+  iree_hal_buffer_t* allocated_target_buffer =
+      iree_hal_buffer_allocated_buffer(target_buffer);
+  uint8_t* target_device_ptr =
+      (uint8_t*)iree_hal_amdgpu_buffer_device_pointer(allocated_target_buffer);
+  if (IREE_UNLIKELY(!target_device_ptr)) {
+    return iree_make_status(
+        IREE_STATUS_INVALID_ARGUMENT,
+        "target buffer must be backed by an AMDGPU allocation");
+  }
+  target_device_ptr +=
+      iree_hal_buffer_byte_offset(target_buffer) + target_offset;
+
+  *out_target_device_ptr = target_device_ptr;
+  return iree_ok_status();
+}
+
+// Replicates the low |pattern_length| bytes of |pattern_bits| across a full
+// 64-bit value.
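+// For example, pattern_length=2 with pattern_bits=0xABCD yields
+// 0xABCDABCDABCDABCD.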
+static uint64_t iree_hal_amdgpu_host_queue_extend_fill_pattern_x8(
+    uint64_t pattern_bits, iree_host_size_t pattern_length) {
+  switch (pattern_length) {
+    case 1: {
+      const uint64_t pattern = pattern_bits & 0xFFu;
+      return pattern * 0x0101010101010101ull;
+    }
+    case 2: {
+      const uint64_t pattern = pattern_bits & 0xFFFFu;
+      return pattern | (pattern << 16) | (pattern << 32) | (pattern << 48);
+    }
+    default: {
+      const uint64_t pattern = pattern_bits & 0xFFFFFFFFull;
+      return pattern | (pattern << 32);
+    }
+  }
+}
+
+static bool iree_hal_amdgpu_host_queue_can_use_pm4_write_data(
+    const iree_hal_amdgpu_host_queue_t* queue, const void* target_device_ptr,
+    iree_host_size_t length) {
+  return queue->pm4_ib_slots &&
+         iree_hal_amdgpu_vendor_packet_capabilities_support_pm4_memory_write_data(
+             queue->vendor_packet_capabilities) &&
+         (length == 4 || length == 8) &&
+         iree_host_ptr_has_alignment(target_device_ptr, sizeof(uint32_t));
+}
+
+static bool iree_hal_amdgpu_host_queue_prepare_pm4_fill_write_data(
+    const iree_hal_amdgpu_host_queue_t* queue, void* target_device_ptr,
+    iree_device_size_t length, uint64_t pattern_bits,
+    iree_host_size_t pattern_length,
+    iree_hal_amdgpu_host_queue_pm4_write_data_t* out_write_data) {
+  if (length != 4 && length != 8) {
+    return false;
+  }
+  if (!iree_hal_amdgpu_host_queue_can_use_pm4_write_data(
+          queue, target_device_ptr, (iree_host_size_t)length)) {
+    return false;
+  }
+  if (!iree_host_ptr_has_alignment(target_device_ptr, pattern_length) ||
+      !iree_device_size_has_alignment(length, pattern_length)) {
+    return false;
+  }
+
+  out_write_data->target_device_ptr = target_device_ptr;
+  out_write_data->value = iree_hal_amdgpu_host_queue_extend_fill_pattern_x8(
+      pattern_bits, pattern_length);
+  out_write_data->length = (uint8_t)length;
+  return true;
+}
+
+static bool iree_hal_amdgpu_host_queue_prepare_pm4_update_write_data(
+    const iree_hal_amdgpu_host_queue_t* queue, const uint8_t* source_bytes,
+    iree_host_size_t source_length, void* target_device_ptr,
+    iree_hal_amdgpu_host_queue_pm4_write_data_t* out_write_data) {
+  if (!iree_hal_amdgpu_host_queue_can_use_pm4_write_data(
+          queue, target_device_ptr, source_length)) {
+    return false;
+  }
+
+  out_write_data->target_device_ptr = target_device_ptr;
+  out_write_data->value = 0;
+  out_write_data->length = (uint8_t)source_length;
+  memcpy(&out_write_data->value, source_bytes, source_length);
+  return true;
+}
+
+static iree_status_t iree_hal_amdgpu_host_queue_submit_pm4_write_data(
+    iree_hal_amdgpu_host_queue_t* queue,
+    const iree_hal_amdgpu_wait_resolution_t* resolution,
+    const iree_hal_semaphore_list_t signal_semaphore_list,
+    iree_hal_buffer_t* target_buffer,
+    const iree_hal_amdgpu_host_queue_pm4_write_data_t* write_data,
+    const iree_hal_amdgpu_host_queue_profile_event_info_t* profile_event_info,
+    iree_hal_amdgpu_host_queue_submission_flags_t submission_flags,
+    bool* out_ready, uint64_t* out_submission_id) {
+  if (out_submission_id) *out_submission_id = 0;
+  iree_hal_resource_t* operation_resources[1] = {
+      (iree_hal_resource_t*)target_buffer,
+  };
+
+  iree_hal_amdgpu_host_queue_pm4_ib_submission_t submission;
+  IREE_RETURN_IF_ERROR(iree_hal_amdgpu_host_queue_try_begin_pm4_ib_submission(
+      queue, resolution, signal_semaphore_list,
+      IREE_ARRAYSIZE(operation_resources), profile_event_info, out_ready,
+      &submission));
+  if (!*out_ready) return iree_ok_status();
+
+  if (write_data->length == 4) {
+    uint32_t value = 0;
+    memcpy(&value, &write_data->value, sizeof(value));
+    submission.ib_dword_count = iree_hal_amdgpu_pm4_emit_write_data32(
+        submission.pm4_ib_slot, write_data->target_device_ptr, value);
+  } else {
+    submission.ib_dword_count = iree_hal_amdgpu_pm4_emit_write_data64(
+        submission.pm4_ib_slot, write_data->target_device_ptr,
+        write_data->value);
+  }
+  const uint64_t submission_epoch =
+      iree_hal_amdgpu_host_queue_finish_pm4_ib_submission(
+          queue, resolution, signal_semaphore_list, operation_resources,
+          IREE_ARRAYSIZE(operation_resources), profile_event_info,
+          submission_flags, &submission);
+  if (out_submission_id) *out_submission_id = submission_epoch;
+  return iree_ok_status();
+}
+
+static bool iree_hal_amdgpu_host_queue_prepare_pm4_copy_data(
+    const iree_hal_amdgpu_host_queue_t* queue, const void* source_device_ptr,
+    void* target_device_ptr, iree_device_size_t length,
+    iree_hal_amdgpu_host_queue_pm4_copy_data_t* out_copy_data) {
+  if (!queue->pm4_ib_slots ||
+      !iree_hal_amdgpu_vendor_packet_capabilities_support_pm4_memory_copy_data(
+          queue->vendor_packet_capabilities)) {
+    return false;
+  }
+  switch (length) {
+    case 4:
+      if (!iree_host_ptr_has_alignment(source_device_ptr, sizeof(uint32_t)) ||
+          !iree_host_ptr_has_alignment(target_device_ptr, sizeof(uint32_t))) {
+        return false;
+      }
+      break;
+    case 8:
+      if (!iree_host_ptr_has_alignment(source_device_ptr, sizeof(uint64_t)) ||
+          !iree_host_ptr_has_alignment(target_device_ptr, sizeof(uint64_t))) {
+        return false;
+      }
+      break;
+    default:
+      return false;
+  }
+
+  out_copy_data->source_device_ptr = source_device_ptr;
+  out_copy_data->target_device_ptr = target_device_ptr;
+  out_copy_data->length = (uint8_t)length;
+  return true;
+}
+
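+// Submits a single PM4 COPY_DATA IB performing a 4- or 8-byte
+// device-to-device copy using the shared begin/finish IB protocol.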
+static iree_status_t iree_hal_amdgpu_host_queue_submit_pm4_copy_data(
+    iree_hal_amdgpu_host_queue_t* queue,
+    const iree_hal_amdgpu_wait_resolution_t* resolution,
+    const iree_hal_semaphore_list_t signal_semaphore_list,
+    iree_hal_buffer_t* source_buffer, iree_hal_buffer_t* target_buffer,
+    const iree_hal_amdgpu_host_queue_pm4_copy_data_t* copy_data,
+    const iree_hal_amdgpu_host_queue_profile_event_info_t* profile_event_info,
+    iree_hal_amdgpu_host_queue_submission_flags_t submission_flags,
+    bool* out_ready, uint64_t* out_submission_id) {
+  if (out_submission_id) *out_submission_id = 0;
+  iree_hal_resource_t* operation_resources[2] = {
+      (iree_hal_resource_t*)source_buffer,
+      (iree_hal_resource_t*)target_buffer,
+  };
+
+  iree_hal_amdgpu_host_queue_pm4_ib_submission_t submission;
+  IREE_RETURN_IF_ERROR(iree_hal_amdgpu_host_queue_try_begin_pm4_ib_submission(
+      queue, resolution, signal_semaphore_list,
+      IREE_ARRAYSIZE(operation_resources), profile_event_info, out_ready,
+      &submission));
+  if (!*out_ready) return iree_ok_status();
+
+  if (copy_data->length == 4) {
+    submission.ib_dword_count = iree_hal_amdgpu_pm4_emit_copy_data32(
+        submission.pm4_ib_slot, copy_data->source_device_ptr,
+        copy_data->target_device_ptr);
+  } else {
+    submission.ib_dword_count = iree_hal_amdgpu_pm4_emit_copy_data64(
+        submission.pm4_ib_slot, copy_data->source_device_ptr,
+        copy_data->target_device_ptr);
+  }
+  const uint64_t submission_epoch =
+      iree_hal_amdgpu_host_queue_finish_pm4_ib_submission(
+          queue, resolution, signal_semaphore_list, operation_resources,
+          IREE_ARRAYSIZE(operation_resources), profile_event_info,
+          submission_flags, &submission);
+  if (out_submission_id) *out_submission_id = submission_epoch;
+  return iree_ok_status();
+}
+
+// Prepares a fill dispatch packet and kernargs in stack-local storage without
+// touching queue rings. All user-input validation must happen before this so
+// that the caller only reserves AQL slots once the packet shape is known to
+// be valid.
+static iree_status_t iree_hal_amdgpu_host_queue_prepare_fill_dispatch(
+    const iree_hal_amdgpu_host_queue_t* queue, uint8_t* target_device_ptr,
+    iree_device_size_t length, uint64_t pattern_bits,
+    iree_host_size_t pattern_length,
+    iree_hsa_kernel_dispatch_packet_t* out_dispatch_packet,
+    iree_hal_amdgpu_device_buffer_fill_kernargs_t* out_kernargs) {
+  iree_hsa_kernel_dispatch_packet_t dispatch_packet;
+  memset(&dispatch_packet, 0, sizeof(dispatch_packet));
+  iree_hal_amdgpu_device_buffer_fill_kernargs_t kernargs;
+  memset(&kernargs, 0, sizeof(kernargs));
+  if (IREE_UNLIKELY(!iree_hal_amdgpu_device_buffer_fill_emplace(
+          queue->transfer_context, &dispatch_packet, target_device_ptr, length,
+          pattern_bits, (uint8_t)pattern_length, &kernargs))) {
+    return iree_make_status(IREE_STATUS_INVALID_ARGUMENT,
+                            "unsupported fill dispatch shape (length=%" PRIdsz
+                            ", pattern_length=%" PRIhsz ")",
+                            length, pattern_length);
+  }
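+  // Drop any stack-local kernarg pointer the emplace helper may have
+  // recorded; the submission path patches in the reserved kernarg ring
+  // address.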
+  dispatch_packet.kernarg_address = NULL;
+
+  *out_dispatch_packet = dispatch_packet;
+  *out_kernargs = kernargs;
+  return iree_ok_status();
+}
+
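+// Fills of exactly 4 or 8 aligned bytes take the PM4 WRITE_DATA fast path
+// when the queue supports vendor PM4 packets; all other shapes fall back to
+// the builtin fill blit kernel.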
+iree_status_t iree_hal_amdgpu_host_queue_submit_fill(
+    iree_hal_amdgpu_host_queue_t* queue,
+    const iree_hal_amdgpu_wait_resolution_t* resolution,
+    const iree_hal_semaphore_list_t signal_semaphore_list,
+    iree_hal_buffer_t* target_buffer, iree_device_size_t target_offset,
+    iree_device_size_t length, uint64_t pattern_bits,
+    iree_host_size_t pattern_length, iree_hal_fill_flags_t flags,
+    iree_hal_amdgpu_host_queue_submission_flags_t submission_flags,
+    bool* out_ready) {
+  IREE_ASSERT_ARGUMENT(out_ready);
+  *out_ready = false;
+  if (IREE_UNLIKELY(queue->is_shutting_down)) {
+    return iree_make_status(IREE_STATUS_CANCELLED, "queue shutting down");
+  }
+
+  uint8_t* target_device_ptr = NULL;
+  IREE_RETURN_IF_ERROR(iree_hal_amdgpu_host_queue_prepare_fill_target(
+      target_buffer, target_offset, length, pattern_length, flags,
+      &target_device_ptr));
+
+  iree_hal_amdgpu_host_queue_pm4_write_data_t pm4_write_data;
+  if (iree_hal_amdgpu_host_queue_prepare_pm4_fill_write_data(
+          queue, target_device_ptr, length, pattern_bits, pattern_length,
+          &pm4_write_data)) {
+    iree_hal_amdgpu_host_queue_profile_event_info_t profile_event_info =
+        iree_hal_amdgpu_host_queue_make_blit_profile_event_info(
+            IREE_HAL_PROFILE_QUEUE_EVENT_TYPE_FILL, length);
+    uint64_t submission_id = 0;
+    iree_status_t status = iree_hal_amdgpu_host_queue_submit_pm4_write_data(
+        queue, resolution, signal_semaphore_list, target_buffer,
+        &pm4_write_data, &profile_event_info, submission_flags, out_ready,
+        &submission_id);
+    if (iree_status_is_ok(status) && *out_ready) {
+      iree_hal_amdgpu_host_queue_record_submitted_blit_profile_event(
+          queue, resolution, signal_semaphore_list, submission_id,
+          &profile_event_info);
+    }
+    return status;
+  }
+
+  iree_hsa_kernel_dispatch_packet_t dispatch_packet;
+  iree_hal_amdgpu_device_buffer_fill_kernargs_t kernargs;
+  IREE_RETURN_IF_ERROR(iree_hal_amdgpu_host_queue_prepare_fill_dispatch(
+      queue, target_device_ptr, length, pattern_bits, pattern_length,
+      &dispatch_packet, &kernargs));
+
+  iree_hal_resource_t* operation_resources[1] = {
+      (iree_hal_resource_t*)target_buffer,
+  };
+  uint64_t submission_id = 0;
+  iree_hal_amdgpu_host_queue_profile_event_info_t profile_event_info =
+      iree_hal_amdgpu_host_queue_make_blit_profile_event_info(
+          IREE_HAL_PROFILE_QUEUE_EVENT_TYPE_FILL, length);
+  iree_status_t status = iree_hal_amdgpu_host_queue_submit_dispatch_packet(
+      queue, resolution, signal_semaphore_list, &dispatch_packet, &kernargs,
+      sizeof(kernargs), operation_resources,
+      IREE_ARRAYSIZE(operation_resources), &profile_event_info,
+      submission_flags, out_ready, &submission_id);
+  if (iree_status_is_ok(status) && *out_ready) {
+    iree_hal_amdgpu_host_queue_record_submitted_blit_profile_event(
+        queue, resolution, signal_semaphore_list, submission_id,
+        &profile_event_info);
+  }
+  return status;
+}
+
+static_assert(IREE_HAL_AMDGPU_DEVICE_BUFFER_COPY_KERNARG_SIZE <=
+                  sizeof(iree_hal_amdgpu_kernarg_block_t),
+              "copy kernargs must fit in one kernarg ring block");
+
+// Validates a queue_copy request and resolves device pointers. Overlapping
+// ranges within the same buffer are rejected here because both PM4 COPY_DATA
+// and the builtin copy kernels implement memcpy semantics, not memmove.
+static iree_status_t iree_hal_amdgpu_host_queue_prepare_copy_ranges(
+    iree_hal_buffer_t* source_buffer, iree_device_size_t source_offset,
+    iree_hal_buffer_t* target_buffer, iree_device_size_t target_offset,
+    iree_device_size_t length, iree_hal_copy_flags_t flags,
+    const uint8_t** out_source_device_ptr, uint8_t** out_target_device_ptr) {
+  *out_source_device_ptr = NULL;
+  *out_target_device_ptr = NULL;
+
+  if (IREE_UNLIKELY(!source_buffer)) {
+    return iree_make_status(IREE_STATUS_INVALID_ARGUMENT,
+                            "source buffer must be non-null");
+  }
+  if (IREE_UNLIKELY(!target_buffer)) {
+    return iree_make_status(IREE_STATUS_INVALID_ARGUMENT,
+                            "target buffer must be non-null");
+  }
+  IREE_RETURN_IF_ERROR(iree_hal_buffer_validate_usage(
+      iree_hal_buffer_allowed_usage(source_buffer),
+      IREE_HAL_BUFFER_USAGE_TRANSFER_SOURCE));
+  IREE_RETURN_IF_ERROR(iree_hal_buffer_validate_access(
+      iree_hal_buffer_allowed_access(source_buffer),
+      IREE_HAL_MEMORY_ACCESS_READ));
+  IREE_RETURN_IF_ERROR(
+      iree_hal_buffer_validate_range(source_buffer, source_offset, length));
+  IREE_RETURN_IF_ERROR(iree_hal_buffer_validate_usage(
+      iree_hal_buffer_allowed_usage(target_buffer),
+      IREE_HAL_BUFFER_USAGE_TRANSFER_TARGET));
+  IREE_RETURN_IF_ERROR(iree_hal_buffer_validate_access(
+      iree_hal_buffer_allowed_access(target_buffer),
+      IREE_HAL_MEMORY_ACCESS_WRITE));
+  IREE_RETURN_IF_ERROR(
+      iree_hal_buffer_validate_range(target_buffer, target_offset, length));
+
+  if (IREE_UNLIKELY(flags != IREE_HAL_COPY_FLAG_NONE)) {
+    return iree_make_status(IREE_STATUS_INVALID_ARGUMENT,
+                            "unsupported copy flags: 0x%" PRIx64, flags);
+  }
+  if (IREE_UNLIKELY(iree_hal_buffer_test_overlap(source_buffer, source_offset,
+                                                 length, target_buffer,
+                                                 target_offset, length) !=
+                    IREE_HAL_BUFFER_OVERLAP_DISJOINT)) {
+    return iree_make_status(
+        IREE_STATUS_INVALID_ARGUMENT,
+        "source and target ranges must not overlap within the same buffer");
+  }
+
+  iree_hal_buffer_t* allocated_source_buffer =
+      iree_hal_buffer_allocated_buffer(source_buffer);
+  const uint8_t* source_device_ptr =
+      (const uint8_t*)iree_hal_amdgpu_buffer_device_pointer(
+          allocated_source_buffer);
+  if (IREE_UNLIKELY(!source_device_ptr)) {
+    return iree_make_status(
+        IREE_STATUS_INVALID_ARGUMENT,
+        "source buffer must be backed by an AMDGPU allocation");
+  }
+  source_device_ptr +=
+      iree_hal_buffer_byte_offset(source_buffer) + source_offset;
+
+  iree_hal_buffer_t* allocated_target_buffer =
+      iree_hal_buffer_allocated_buffer(target_buffer);
+  uint8_t* target_device_ptr =
+      (uint8_t*)iree_hal_amdgpu_buffer_device_pointer(allocated_target_buffer);
+  if (IREE_UNLIKELY(!target_device_ptr)) {
+    return iree_make_status(
+        IREE_STATUS_INVALID_ARGUMENT,
+        "target buffer must be backed by an AMDGPU allocation");
+  }
+  target_device_ptr +=
+      iree_hal_buffer_byte_offset(target_buffer) + target_offset;
+
+  *out_source_device_ptr = source_device_ptr;
+  *out_target_device_ptr = target_device_ptr;
+  return iree_ok_status();
+}
+
+// Prepares a copy dispatch packet and kernargs in stack-local storage without
+// touching queue rings. All user-input validation must happen before this so
+// that the caller only reserves AQL slots once the packet shape is known to
+// be valid.
+static iree_status_t iree_hal_amdgpu_host_queue_prepare_copy_dispatch(
+    const iree_hal_amdgpu_host_queue_t* queue, const uint8_t* source_device_ptr,
+    iree_device_size_t source_offset, uint8_t* target_device_ptr,
+    iree_device_size_t target_offset, iree_device_size_t length,
+    iree_hsa_kernel_dispatch_packet_t* out_dispatch_packet,
+    iree_hal_amdgpu_device_buffer_copy_kernargs_t* out_kernargs) {
+  iree_hsa_kernel_dispatch_packet_t dispatch_packet;
+  memset(&dispatch_packet, 0, sizeof(dispatch_packet));
+  iree_hal_amdgpu_device_buffer_copy_kernargs_t kernargs;
+  memset(&kernargs, 0, sizeof(kernargs));
+  if (IREE_UNLIKELY(!iree_hal_amdgpu_device_buffer_copy_emplace(
+          queue->transfer_context, &dispatch_packet, source_device_ptr,
+          target_device_ptr, length, &kernargs))) {
+    return iree_make_status(
+        IREE_STATUS_INVALID_ARGUMENT,
+        "unsupported copy dispatch shape (source_offset=%" PRIdsz
+        ", target_offset=%" PRIdsz ", length=%" PRIdsz ")",
+        source_offset, target_offset, length);
+  }
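+  // As with fills, reset any stack-local kernarg pointer so the submission
+  // path patches in the reserved kernarg ring address.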
+  dispatch_packet.kernarg_address = NULL;
+
+  *out_dispatch_packet = dispatch_packet;
+  *out_kernargs = kernargs;
+  return iree_ok_status();
+}
+
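+// Copies of exactly 4 or 8 aligned bytes take the PM4 COPY_DATA fast path
+// when the queue supports vendor PM4 packets; all other shapes fall back to
+// the builtin copy blit kernel.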
+iree_status_t iree_hal_amdgpu_host_queue_submit_copy(
+    iree_hal_amdgpu_host_queue_t* queue,
+    const iree_hal_amdgpu_wait_resolution_t* resolution,
+    const iree_hal_semaphore_list_t signal_semaphore_list,
+    iree_hal_buffer_t* source_buffer, iree_device_size_t source_offset,
+    iree_hal_buffer_t* target_buffer, iree_device_size_t target_offset,
+    iree_device_size_t length, iree_hal_copy_flags_t flags,
+    iree_hal_profile_queue_event_type_t profile_event_type,
+    iree_hal_amdgpu_host_queue_submission_flags_t submission_flags,
+    bool* out_ready) {
+  IREE_ASSERT_ARGUMENT(out_ready);
+  *out_ready = false;
+  if (IREE_UNLIKELY(queue->is_shutting_down)) {
+    return iree_make_status(IREE_STATUS_CANCELLED, "queue shutting down");
+  }
+
+  const uint8_t* source_device_ptr = NULL;
+  uint8_t* target_device_ptr = NULL;
+  IREE_RETURN_IF_ERROR(iree_hal_amdgpu_host_queue_prepare_copy_ranges(
+      source_buffer, source_offset, target_buffer, target_offset, length, flags,
+      &source_device_ptr, &target_device_ptr));
+
+  iree_hal_amdgpu_host_queue_pm4_copy_data_t pm4_copy_data;
+  if (iree_hal_amdgpu_host_queue_prepare_pm4_copy_data(
+          queue, source_device_ptr, target_device_ptr, length,
+          &pm4_copy_data)) {
+    iree_hal_amdgpu_host_queue_profile_event_info_t profile_event_info =
+        iree_hal_amdgpu_host_queue_make_blit_profile_event_info(
+            profile_event_type, length);
+    uint64_t submission_id = 0;
+    iree_status_t status = iree_hal_amdgpu_host_queue_submit_pm4_copy_data(
+        queue, resolution, signal_semaphore_list, source_buffer, target_buffer,
+        &pm4_copy_data, &profile_event_info, submission_flags, out_ready,
+        &submission_id);
+    if (iree_status_is_ok(status) && *out_ready) {
+      iree_hal_amdgpu_host_queue_record_submitted_blit_profile_event(
+          queue, resolution, signal_semaphore_list, submission_id,
+          &profile_event_info);
+    }
+    return status;
+  }
+
+  iree_hsa_kernel_dispatch_packet_t dispatch_packet;
+  iree_hal_amdgpu_device_buffer_copy_kernargs_t kernargs;
+  IREE_RETURN_IF_ERROR(iree_hal_amdgpu_host_queue_prepare_copy_dispatch(
+      queue, source_device_ptr, source_offset, target_device_ptr, target_offset,
+      length, &dispatch_packet, &kernargs));
+
+  iree_hal_resource_t* operation_resources[2] = {
+      (iree_hal_resource_t*)source_buffer,
+      (iree_hal_resource_t*)target_buffer,
+  };
+  uint64_t submission_id = 0;
+  iree_hal_amdgpu_host_queue_profile_event_info_t profile_event_info =
+      iree_hal_amdgpu_host_queue_make_blit_profile_event_info(
+          profile_event_type, length);
+  iree_status_t status = iree_hal_amdgpu_host_queue_submit_dispatch_packet(
+      queue, resolution, signal_semaphore_list, &dispatch_packet, &kernargs,
+      sizeof(kernargs), operation_resources,
+      IREE_ARRAYSIZE(operation_resources), &profile_event_info,
+      submission_flags, out_ready, &submission_id);
+  if (iree_status_is_ok(status) && *out_ready) {
+    iree_hal_amdgpu_host_queue_record_submitted_blit_profile_event(
+        queue, resolution, signal_semaphore_list, submission_id,
+        &profile_event_info);
+  }
+  return status;
+}
+
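+// Unlike plain copies this variant always dispatches the builtin copy kernel
+// so the pre-signal action and explicit fence scopes can hook the dispatch
+// completion path.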
+iree_status_t iree_hal_amdgpu_host_queue_submit_copy_with_action(
+    iree_hal_amdgpu_host_queue_t* queue,
+    const iree_hal_amdgpu_wait_resolution_t* resolution,
+    const iree_hal_semaphore_list_t signal_semaphore_list,
+    iree_hal_buffer_t* source_buffer, iree_device_size_t source_offset,
+    iree_hal_buffer_t* target_buffer, iree_device_size_t target_offset,
+    iree_device_size_t length, iree_hal_copy_flags_t flags,
+    iree_hsa_fence_scope_t minimum_acquire_scope,
+    iree_hsa_fence_scope_t minimum_release_scope,
+    iree_hal_amdgpu_reclaim_action_t pre_signal_action,
+    iree_hal_resource_t* const* extra_operation_resources,
+    iree_host_size_t extra_operation_resource_count,
+    iree_hal_amdgpu_host_queue_submission_flags_t submission_flags,
+    bool* out_ready) {
+  IREE_ASSERT_ARGUMENT(out_ready);
+  *out_ready = false;
+  if (IREE_UNLIKELY(queue->is_shutting_down)) {
+    return iree_make_status(IREE_STATUS_CANCELLED, "queue shutting down");
+  }
+  if (IREE_UNLIKELY(extra_operation_resource_count > 0 &&
+                    !extra_operation_resources)) {
+    return iree_make_status(IREE_STATUS_INVALID_ARGUMENT,
+                            "extra operation resources must be non-null");
+  }
+
+  const uint8_t* source_device_ptr = NULL;
+  uint8_t* target_device_ptr = NULL;
+  IREE_RETURN_IF_ERROR(iree_hal_amdgpu_host_queue_prepare_copy_ranges(
+      source_buffer, source_offset, target_buffer, target_offset, length, flags,
+      &source_device_ptr, &target_device_ptr));
+
+  iree_hsa_kernel_dispatch_packet_t dispatch_packet;
+  iree_hal_amdgpu_device_buffer_copy_kernargs_t kernargs;
+  IREE_RETURN_IF_ERROR(iree_hal_amdgpu_host_queue_prepare_copy_dispatch(
+      queue, source_device_ptr, source_offset, target_device_ptr, target_offset,
+      length, &dispatch_packet, &kernargs));
+
+  iree_host_size_t operation_resource_count = 0;
+  if (!iree_host_size_checked_add(2, extra_operation_resource_count,
+                                  &operation_resource_count)) {
+    return iree_make_status(IREE_STATUS_OUT_OF_RANGE,
+                            "copy operation resource count overflows");
+  }
+  iree_host_size_t operation_resources_size = 0;
+  if (!iree_host_size_checked_mul(operation_resource_count,
+                                  sizeof(iree_hal_resource_t*),
+                                  &operation_resources_size)) {
+    return iree_make_status(IREE_STATUS_OUT_OF_RANGE,
+                            "copy operation resource table size overflows");
+  }
+  iree_hal_resource_t** operation_resources =
+      (iree_hal_resource_t**)iree_alloca(operation_resources_size);
+  operation_resources[0] = (iree_hal_resource_t*)source_buffer;
+  operation_resources[1] = (iree_hal_resource_t*)target_buffer;
+  for (iree_host_size_t i = 0; i < extra_operation_resource_count; ++i) {
+    operation_resources[2 + i] = extra_operation_resources[i];
+  }
+
+  iree_hal_amdgpu_host_queue_dispatch_submission_t submission;
+  iree_hal_amdgpu_host_queue_profile_event_info_t profile_event_info =
+      iree_hal_amdgpu_host_queue_make_blit_profile_event_info(
+          IREE_HAL_PROFILE_QUEUE_EVENT_TYPE_COPY, length);
+  IREE_RETURN_IF_ERROR(iree_hal_amdgpu_host_queue_try_begin_dispatch_submission(
+      queue, resolution, signal_semaphore_list, operation_resource_count,
+      /*kernarg_block_count=*/1,
+      (iree_hal_amdgpu_profile_dispatch_event_reservation_t){0},
+      &profile_event_info, out_ready, &submission));
+  if (!*out_ready) return iree_ok_status();
+
+  memcpy(submission.kernel.kernargs.blocks->data, &kernargs, sizeof(kernargs));
+  submission.dispatch_setup =
+      iree_hal_amdgpu_host_queue_write_dispatch_packet_body(
+          &submission.dispatch_slot->dispatch, &dispatch_packet,
+          submission.kernel.kernargs.blocks->data,
+          submission.dispatch_completion_signal);
+  submission.minimum_acquire_scope = minimum_acquire_scope;
+  submission.minimum_release_scope = minimum_release_scope;
+  submission.kernel.pre_signal_action = pre_signal_action;
+  const uint64_t submission_id =
+      iree_hal_amdgpu_host_queue_finish_dispatch_submission(
+          queue, resolution, signal_semaphore_list, operation_resources,
+          operation_resource_count, &profile_event_info, submission_flags,
+          &submission);
+  iree_hal_amdgpu_host_queue_record_submitted_blit_profile_event(
+      queue, resolution, signal_semaphore_list, submission_id,
+      &profile_event_info);
+  return iree_ok_status();
+}
+
+// Validates a queue_update request and resolves the source host span and
+// target device pointer. The caller captures the resolved source span either
+// into the pending-op arena or directly into the queue-owned kernarg ring.
+iree_status_t iree_hal_amdgpu_host_queue_prepare_update_copy(
+    iree_hal_buffer_t* target_buffer, iree_device_size_t target_offset,
+    const void* source_buffer, iree_host_size_t source_offset,
+    iree_device_size_t length, iree_hal_update_flags_t flags,
+    const uint8_t** out_source_bytes, iree_host_size_t* out_source_length,
+    uint8_t** out_target_device_ptr) {
+  *out_source_bytes = NULL;
+  *out_source_length = 0;
+  *out_target_device_ptr = NULL;
+
+  if (IREE_UNLIKELY(!source_buffer)) {
+    return iree_make_status(IREE_STATUS_INVALID_ARGUMENT,
+                            "source buffer must be non-null");
+  }
+  if (IREE_UNLIKELY(!target_buffer)) {
+    return iree_make_status(IREE_STATUS_INVALID_ARGUMENT,
+                            "target buffer must be non-null");
+  }
+
+  IREE_RETURN_IF_ERROR(iree_hal_buffer_validate_usage(
+      iree_hal_buffer_allowed_usage(target_buffer),
+      IREE_HAL_BUFFER_USAGE_TRANSFER_TARGET));
+  IREE_RETURN_IF_ERROR(iree_hal_buffer_validate_access(
+      iree_hal_buffer_allowed_access(target_buffer),
+      IREE_HAL_MEMORY_ACCESS_WRITE));
+  IREE_RETURN_IF_ERROR(
+      iree_hal_buffer_validate_range(target_buffer, target_offset, length));
+
+  if (IREE_UNLIKELY(flags != IREE_HAL_UPDATE_FLAG_NONE)) {
+    return iree_make_status(IREE_STATUS_INVALID_ARGUMENT,
+                            "unsupported update flags: 0x%" PRIx64, flags);
+  }
+  if (IREE_UNLIKELY(length > IREE_HOST_SIZE_MAX)) {
+    return iree_make_status(IREE_STATUS_OUT_OF_RANGE,
+                            "update length %" PRIdsz
+                            " exceeds host addressable size %" PRIhsz,
+                            length, IREE_HOST_SIZE_MAX);
+  }
+  const iree_host_size_t source_length = (iree_host_size_t)length;
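+  // Only the overflow check is needed here; the computed end offset itself
+  // is unused.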
+  iree_host_size_t source_end = 0;
+  if (IREE_UNLIKELY(!iree_host_size_checked_add(source_offset, source_length,
+                                                &source_end))) {
+    return iree_make_status(
+        IREE_STATUS_OUT_OF_RANGE,
+        "update source span overflows host size (offset=%" PRIhsz
+        ", length=%" PRIhsz ")",
+        source_offset, source_length);
+  }
+  (void)source_end;
+
+  iree_hal_buffer_t* allocated_target_buffer =
+      iree_hal_buffer_allocated_buffer(target_buffer);
+  uint8_t* target_device_ptr =
+      (uint8_t*)iree_hal_amdgpu_buffer_device_pointer(allocated_target_buffer);
+  if (IREE_UNLIKELY(!target_device_ptr)) {
+    return iree_make_status(
+        IREE_STATUS_INVALID_ARGUMENT,
+        "target buffer must be backed by an AMDGPU allocation");
+  }
+  target_device_ptr +=
+      iree_hal_buffer_byte_offset(target_buffer) + target_offset;
+
+  *out_source_bytes = (const uint8_t*)source_buffer + source_offset;
+  *out_source_length = source_length;
+  *out_target_device_ptr = target_device_ptr;
+  return iree_ok_status();
+}
+
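+// Updates of exactly 4 or 8 aligned bytes take the PM4 WRITE_DATA fast path;
+// larger updates stage the source bytes into the kernarg ring alongside the
+// copy kernargs and dispatch the builtin copy kernel.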
+iree_status_t iree_hal_amdgpu_host_queue_submit_update(
+    iree_hal_amdgpu_host_queue_t* queue,
+    const iree_hal_amdgpu_wait_resolution_t* resolution,
+    const iree_hal_semaphore_list_t signal_semaphore_list,
+    const void* source_buffer, iree_host_size_t source_offset,
+    iree_hal_buffer_t* target_buffer, iree_device_size_t target_offset,
+    iree_device_size_t length, iree_hal_update_flags_t flags,
+    iree_hal_amdgpu_host_queue_submission_flags_t submission_flags,
+    bool* out_ready) {
+  IREE_ASSERT_ARGUMENT(out_ready);
+  *out_ready = false;
+  if (IREE_UNLIKELY(queue->is_shutting_down)) {
+    return iree_make_status(IREE_STATUS_CANCELLED, "queue shutting down");
+  }
+
+  const uint8_t* source_bytes = NULL;
+  iree_host_size_t source_length = 0;
+  uint8_t* target_device_ptr = NULL;
+  IREE_RETURN_IF_ERROR(iree_hal_amdgpu_host_queue_prepare_update_copy(
+      target_buffer, target_offset, source_buffer, source_offset, length, flags,
+      &source_bytes, &source_length, &target_device_ptr));
+
+  iree_hal_amdgpu_host_queue_pm4_write_data_t pm4_write_data;
+  if (iree_hal_amdgpu_host_queue_prepare_pm4_update_write_data(
+          queue, source_bytes, source_length, target_device_ptr,
+          &pm4_write_data)) {
+    iree_hal_amdgpu_host_queue_profile_event_info_t profile_event_info =
+        iree_hal_amdgpu_host_queue_make_blit_profile_event_info(
+            IREE_HAL_PROFILE_QUEUE_EVENT_TYPE_UPDATE, length);
+    uint64_t submission_id = 0;
+    iree_status_t status = iree_hal_amdgpu_host_queue_submit_pm4_write_data(
+        queue, resolution, signal_semaphore_list, target_buffer,
+        &pm4_write_data, &profile_event_info, submission_flags, out_ready,
+        &submission_id);
+    if (iree_status_is_ok(status) && *out_ready) {
+      iree_hal_amdgpu_host_queue_record_submitted_blit_profile_event(
+          queue, resolution, signal_semaphore_list, submission_id,
+          &profile_event_info);
+    }
+    return status;
+  }
+
+  const iree_host_size_t source_payload_offset =
+      IREE_HAL_AMDGPU_DEVICE_BUFFER_COPY_STAGED_SOURCE_OFFSET;
+  iree_host_size_t kernarg_length = 0;
+  if (IREE_UNLIKELY(!iree_host_size_checked_add(
+          source_payload_offset, source_length, &kernarg_length))) {
+    return iree_make_status(
+        IREE_STATUS_OUT_OF_RANGE,
+        "update staging payload overflows host size (offset=%" PRIhsz
+        ", source_length=%" PRIhsz ")",
+        source_payload_offset, source_length);
+  }
+  const iree_host_size_t kernarg_block_count = iree_host_size_ceil_div(
+      kernarg_length, sizeof(iree_hal_amdgpu_kernarg_block_t));
+  if (IREE_UNLIKELY(kernarg_block_count > UINT32_MAX)) {
+    return iree_make_status(
+        IREE_STATUS_OUT_OF_RANGE,
+        "update staging payload requires too many kernarg blocks (%" PRIhsz
+        ", max=%u)",
+        kernarg_block_count, UINT32_MAX);
+  }
+
+  iree_hsa_kernel_dispatch_packet_t dispatch_packet;
+  memset(&dispatch_packet, 0, sizeof(dispatch_packet));
+  iree_hal_amdgpu_device_buffer_copy_kernargs_t kernargs;
+  memset(&kernargs, 0, sizeof(kernargs));
+  // The eventual staged source pointer is aligned to
+  // IREE_HAL_AMDGPU_DEVICE_BUFFER_COPY_STAGED_SOURCE_ALIGNMENT by
+  // construction. Use a synthetic aligned pointer for pre-reservation
+  // packet-shape selection, then patch source_ptr to the real ring address
+  // after allocation succeeds.
+  if (IREE_UNLIKELY(!iree_hal_amdgpu_device_buffer_copy_emplace(
+          queue->transfer_context, &dispatch_packet,
+          (const void*)(uintptr_t)
+              IREE_HAL_AMDGPU_DEVICE_BUFFER_COPY_STAGED_SOURCE_ALIGNMENT,
+          target_device_ptr, length, &kernargs))) {
+    return iree_make_status(
+        IREE_STATUS_INVALID_ARGUMENT,
+        "unsupported update dispatch shape (target_offset=%" PRIdsz
+        ", length=%" PRIdsz ", source_payload_alignment=%d)",
+        target_offset, length,
+        IREE_HAL_AMDGPU_DEVICE_BUFFER_COPY_STAGED_SOURCE_ALIGNMENT);
+  }
+
+  iree_hal_amdgpu_host_queue_dispatch_submission_t submission;
+  iree_hal_amdgpu_host_queue_profile_event_info_t profile_event_info =
+      iree_hal_amdgpu_host_queue_make_blit_profile_event_info(
+          IREE_HAL_PROFILE_QUEUE_EVENT_TYPE_UPDATE, length);
+  IREE_RETURN_IF_ERROR(iree_hal_amdgpu_host_queue_try_begin_dispatch_submission(
+      queue, resolution, signal_semaphore_list,
+      /*operation_resource_count=*/1, (uint32_t)kernarg_block_count,
+      (iree_hal_amdgpu_profile_dispatch_event_reservation_t){0},
+      &profile_event_info, out_ready, &submission));
+  if (!*out_ready) return iree_ok_status();
+
+  uint8_t* staged_source_bytes =
+      (uint8_t*)submission.kernel.kernargs.blocks + source_payload_offset;
+  memcpy(submission.kernel.kernargs.blocks->data, &kernargs, sizeof(kernargs));
+  ((iree_hal_amdgpu_device_buffer_copy_kernargs_t*)
+       submission.kernel.kernargs.blocks->data)
+      ->source_ptr = staged_source_bytes;
+  memcpy(staged_source_bytes, source_bytes, source_length);
+  submission.dispatch_setup =
+      iree_hal_amdgpu_host_queue_write_dispatch_packet_body(
+          &submission.dispatch_slot->dispatch, &dispatch_packet,
+          submission.kernel.kernargs.blocks->data,
+          submission.dispatch_completion_signal);
+
+  iree_hal_resource_t* operation_resources[1] = {
+      (iree_hal_resource_t*)target_buffer,
+  };
+  const uint64_t submission_id =
+      iree_hal_amdgpu_host_queue_finish_dispatch_submission(
+          queue, resolution, signal_semaphore_list, operation_resources,
+          IREE_ARRAYSIZE(operation_resources), &profile_event_info,
+          submission_flags, &submission);
+  iree_hal_amdgpu_host_queue_record_submitted_blit_profile_event(
+      queue, resolution, signal_semaphore_list, submission_id,
+      &profile_event_info);
+  return iree_ok_status();
+}
diff --git a/runtime/src/iree/hal/drivers/amdgpu/host_queue_blit.h b/runtime/src/iree/hal/drivers/amdgpu/host_queue_blit.h
new file mode 100644
index 0000000..8f942db
--- /dev/null
+++ b/runtime/src/iree/hal/drivers/amdgpu/host_queue_blit.h
@@ -0,0 +1,81 @@
+// Copyright 2026 The IREE Authors
+//
+// Licensed under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+
+#ifndef IREE_HAL_DRIVERS_AMDGPU_HOST_QUEUE_BLIT_H_
+#define IREE_HAL_DRIVERS_AMDGPU_HOST_QUEUE_BLIT_H_
+
+#include "iree/hal/drivers/amdgpu/host_queue_submission.h"
+
+#ifdef __cplusplus
+extern "C" {
+#endif  // __cplusplus
+
+// Validates a queue_update request and resolves the source host span and
+// target device pointer. The caller captures the resolved source span either
+// into the pending-op arena or directly into the queue-owned kernarg ring.
+iree_status_t iree_hal_amdgpu_host_queue_prepare_update_copy(
+    iree_hal_buffer_t* target_buffer, iree_device_size_t target_offset,
+    const void* source_buffer, iree_host_size_t source_offset,
+    iree_device_size_t length, iree_hal_update_flags_t flags,
+    const uint8_t** out_source_bytes, iree_host_size_t* out_source_length,
+    uint8_t** out_target_device_ptr);
+
+// Emits a fill submission, preferring the PM4 WRITE_DATA fast path and
+// falling back to the builtin fill blit kernel. Caller must hold
+// submission_mutex.
+iree_status_t iree_hal_amdgpu_host_queue_submit_fill(
+    iree_hal_amdgpu_host_queue_t* queue,
+    const iree_hal_amdgpu_wait_resolution_t* resolution,
+    const iree_hal_semaphore_list_t signal_semaphore_list,
+    iree_hal_buffer_t* target_buffer, iree_device_size_t target_offset,
+    iree_device_size_t length, uint64_t pattern_bits,
+    iree_host_size_t pattern_length, iree_hal_fill_flags_t flags,
+    iree_hal_amdgpu_host_queue_submission_flags_t submission_flags,
+    bool* out_ready);
+
+// Emits a copy submission, preferring the PM4 COPY_DATA fast path and
+// falling back to the builtin copy blit kernel. Caller must hold
+// submission_mutex.
+iree_status_t iree_hal_amdgpu_host_queue_submit_copy(
+    iree_hal_amdgpu_host_queue_t* queue,
+    const iree_hal_amdgpu_wait_resolution_t* resolution,
+    const iree_hal_semaphore_list_t signal_semaphore_list,
+    iree_hal_buffer_t* source_buffer, iree_device_size_t source_offset,
+    iree_hal_buffer_t* target_buffer, iree_device_size_t target_offset,
+    iree_device_size_t length, iree_hal_copy_flags_t flags,
+    iree_hal_profile_queue_event_type_t profile_event_type,
+    iree_hal_amdgpu_host_queue_submission_flags_t submission_flags,
+    bool* out_ready);
+
+// Emits a copy blit kernel submission with explicit fence scopes, a
+// pre-signal reclaim action, and extra retained resources. Caller must hold
+// submission_mutex.
+iree_status_t iree_hal_amdgpu_host_queue_submit_copy_with_action(
+    iree_hal_amdgpu_host_queue_t* queue,
+    const iree_hal_amdgpu_wait_resolution_t* resolution,
+    const iree_hal_semaphore_list_t signal_semaphore_list,
+    iree_hal_buffer_t* source_buffer, iree_device_size_t source_offset,
+    iree_hal_buffer_t* target_buffer, iree_device_size_t target_offset,
+    iree_device_size_t length, iree_hal_copy_flags_t flags,
+    iree_hsa_fence_scope_t minimum_acquire_scope,
+    iree_hsa_fence_scope_t minimum_release_scope,
+    iree_hal_amdgpu_reclaim_action_t pre_signal_action,
+    iree_hal_resource_t* const* extra_operation_resources,
+    iree_host_size_t extra_operation_resource_count,
+    iree_hal_amdgpu_host_queue_submission_flags_t submission_flags,
+    bool* out_ready);
+
+// Emits an update submission, preferring the PM4 WRITE_DATA fast path and
+// otherwise staging the source bytes into the kernarg ring for the builtin
+// copy kernel. Caller must hold submission_mutex.
+iree_status_t iree_hal_amdgpu_host_queue_submit_update(
+    iree_hal_amdgpu_host_queue_t* queue,
+    const iree_hal_amdgpu_wait_resolution_t* resolution,
+    const iree_hal_semaphore_list_t signal_semaphore_list,
+    const void* source_buffer, iree_host_size_t source_offset,
+    iree_hal_buffer_t* target_buffer, iree_device_size_t target_offset,
+    iree_device_size_t length, iree_hal_update_flags_t flags,
+    iree_hal_amdgpu_host_queue_submission_flags_t submission_flags,
+    bool* out_ready);
+
+#ifdef __cplusplus
+}  // extern "C"
+#endif  // __cplusplus
+
+#endif  // IREE_HAL_DRIVERS_AMDGPU_HOST_QUEUE_BLIT_H_
diff --git a/runtime/src/iree/hal/drivers/amdgpu/host_queue_command_buffer.c b/runtime/src/iree/hal/drivers/amdgpu/host_queue_command_buffer.c
new file mode 100644
index 0000000..704e536
--- /dev/null
+++ b/runtime/src/iree/hal/drivers/amdgpu/host_queue_command_buffer.c
@@ -0,0 +1,163 @@
+// Copyright 2026 The IREE Authors
+//
+// Licensed under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+
+#include "iree/hal/drivers/amdgpu/host_queue_command_buffer.h"
+
+#include "iree/hal/drivers/amdgpu/aql_command_buffer.h"
+#include "iree/hal/drivers/amdgpu/aql_program_validation.h"
+#include "iree/hal/drivers/amdgpu/host_queue_command_buffer_block.h"
+#include "iree/hal/drivers/amdgpu/host_queue_command_buffer_replay.h"
+#include "iree/hal/utils/resource_set.h"
+
+iree_status_t iree_hal_amdgpu_host_queue_validate_execute_flags(
+    iree_hal_execute_flags_t flags) {
+  const iree_hal_execute_flags_t supported_flags =
+      IREE_HAL_EXECUTE_FLAG_BORROW_BINDING_TABLE_LIFETIME;
+  if (IREE_UNLIKELY(iree_any_bit_set(flags, ~supported_flags))) {
+    return iree_make_status(IREE_STATUS_INVALID_ARGUMENT,
+                            "unsupported execute flags: 0x%" PRIx64, flags);
+  }
+  return iree_ok_status();
+}
+
+iree_status_t iree_hal_amdgpu_host_queue_create_binding_table_resource_set(
+    iree_hal_amdgpu_host_queue_t* queue,
+    iree_hal_command_buffer_t* command_buffer,
+    iree_hal_buffer_binding_table_t binding_table,
+    iree_hal_execute_flags_t execute_flags,
+    iree_hal_resource_set_t** out_resource_set) {
+  *out_resource_set = NULL;
+  IREE_RETURN_IF_ERROR(
+      iree_hal_amdgpu_host_queue_validate_execute_flags(execute_flags));
+  if (!command_buffer || command_buffer->binding_count == 0) {
+    return iree_ok_status();
+  }
+  if (IREE_UNLIKELY(binding_table.count == 0)) {
+    return iree_make_status(IREE_STATUS_INVALID_ARGUMENT,
+                            "indirect command buffer requires at least %u "
+                            "bindings but no binding table was provided",
+                            command_buffer->binding_count);
+  }
+  if (IREE_UNLIKELY(binding_table.count < command_buffer->binding_count)) {
+    return iree_make_status(
+        IREE_STATUS_OUT_OF_RANGE,
+        "indirect command buffer requires at least %u bindings but only "
+        "%" PRIhsz " were provided",
+        command_buffer->binding_count, binding_table.count);
+  }
+  if (IREE_UNLIKELY(!binding_table.bindings)) {
+    return iree_make_status(IREE_STATUS_INVALID_ARGUMENT,
+                            "indirect command buffer binding table storage is "
+                            "NULL for %" PRIhsz " bindings",
+                            binding_table.count);
+  }
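+  // Callers that assert ownership of the binding lifetimes skip retention
+  // entirely.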
+  if (iree_any_bit_set(execute_flags,
+                       IREE_HAL_EXECUTE_FLAG_BORROW_BINDING_TABLE_LIFETIME)) {
+    return iree_ok_status();
+  }
+
+  IREE_TRACE_ZONE_BEGIN(z0);
+  IREE_TRACE_ZONE_APPEND_VALUE_I64(z0, command_buffer->binding_count);
+  iree_hal_resource_set_t* resource_set = NULL;
+  iree_status_t status =
+      iree_hal_resource_set_allocate(queue->block_pool, &resource_set);
+  if (iree_status_is_ok(status)) {
+    status = iree_hal_resource_set_insert_strided(
+        resource_set, command_buffer->binding_count, binding_table.bindings,
+        offsetof(iree_hal_buffer_binding_t, buffer),
+        sizeof(iree_hal_buffer_binding_t));
+  }
+  if (iree_status_is_ok(status)) {
+    iree_hal_resource_set_freeze(resource_set);
+    *out_resource_set = resource_set;
+  } else {
+    iree_hal_resource_set_free(resource_set);
+  }
+  IREE_TRACE_ZONE_END(z0);
+  return status;
+}
+
+iree_status_t iree_hal_amdgpu_host_queue_submit_command_buffer(
+    iree_hal_amdgpu_host_queue_t* queue,
+    const iree_hal_amdgpu_wait_resolution_t* resolution,
+    const iree_hal_semaphore_list_t signal_semaphore_list,
+    iree_hal_command_buffer_t* command_buffer,
+    iree_hal_buffer_binding_table_t binding_table,
+    iree_hal_execute_flags_t execute_flags,
+    iree_hal_resource_set_t** inout_binding_resource_set, bool* out_ready) {
+  IREE_ASSERT_ARGUMENT(queue);
+  IREE_ASSERT_ARGUMENT(resolution);
+  IREE_ASSERT_ARGUMENT(out_ready);
+  *out_ready = false;
+
+  if (IREE_UNLIKELY(queue->is_shutting_down)) {
+    return iree_make_status(IREE_STATUS_CANCELLED, "queue shutting down");
+  }
+  IREE_RETURN_IF_ERROR(
+      iree_hal_amdgpu_host_queue_validate_execute_flags(execute_flags));
+  if (IREE_UNLIKELY(!command_buffer)) {
+    return iree_make_status(IREE_STATUS_INVALID_ARGUMENT,
+                            "command buffer is required");
+  }
+  if (IREE_UNLIKELY(!iree_hal_amdgpu_aql_command_buffer_isa(command_buffer))) {
+    return iree_make_status(IREE_STATUS_INVALID_ARGUMENT,
+                            "command buffer is not an AMDGPU AQL command "
+                            "buffer");
+  }
+  const iree_host_size_t command_buffer_device_ordinal =
+      iree_hal_amdgpu_aql_command_buffer_device_ordinal(command_buffer);
+  if (IREE_UNLIKELY(command_buffer_device_ordinal != queue->device_ordinal)) {
+    return iree_make_status(
+        IREE_STATUS_INVALID_ARGUMENT,
+        "command buffer recorded for physical device %" PRIhsz
+        " cannot execute on physical device %" PRIhsz,
+        command_buffer_device_ordinal, queue->device_ordinal);
+  }
+
+  const iree_hal_amdgpu_aql_program_t* program =
+      iree_hal_amdgpu_aql_command_buffer_program(command_buffer);
+  if (IREE_UNLIKELY(!program->first_block)) {
+    return iree_make_status(IREE_STATUS_FAILED_PRECONDITION,
+                            "command buffer has not been finalized");
+  }
+
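+  // Single-block programs with a known AQL packet bound can be written
+  // directly into the queue ring; multi-block programs (and programs missing
+  // packet-count metadata) go through the replay processor instead.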
+  const bool requires_replay =
+      program->max_block_aql_packet_count == 0 || program->block_count != 1;
+  if (requires_replay && program->max_block_aql_packet_count == 0) {
+    IREE_RETURN_IF_ERROR(
+        iree_hal_amdgpu_aql_program_validate_metadata_only(program));
+  }
+  if (requires_replay) {
+    iree_status_t status =
+        iree_hal_amdgpu_command_buffer_replay_start_under_lock(
+            queue, resolution, signal_semaphore_list, command_buffer,
+            binding_table, execute_flags, inout_binding_resource_set);
+    if (iree_status_is_ok(status)) *out_ready = true;
+    return status;
+  }
+  if (!*inout_binding_resource_set) {
+    IREE_RETURN_IF_ERROR(
+        iree_hal_amdgpu_host_queue_create_binding_table_resource_set(
+            queue, command_buffer, binding_table, execute_flags,
+            inout_binding_resource_set));
+  }
+  iree_hal_resource_t* command_buffer_resource =
+      (iree_hal_resource_t*)command_buffer;
+  bool ready = false;
+  iree_status_t status = iree_hal_amdgpu_host_queue_submit_command_buffer_block(
+      queue, resolution, signal_semaphore_list, command_buffer, binding_table,
+      /*binding_ptrs=*/NULL, program->first_block, inout_binding_resource_set,
+      (iree_hal_amdgpu_reclaim_action_t){0}, &command_buffer_resource,
+      /*operation_resource_count=*/1,
+      IREE_HAL_AMDGPU_HOST_QUEUE_SUBMISSION_FLAG_RETAIN_RESOURCES, &ready);
+  if (!iree_status_is_ok(status)) {
+    // The resource set may be NULL when binding lifetimes are borrowed or
+    // the command buffer has no indirect bindings.
+    if (*inout_binding_resource_set) {
+      iree_hal_resource_set_free(*inout_binding_resource_set);
+      *inout_binding_resource_set = NULL;
+    }
+  } else {
+    *out_ready = ready;
+  }
+  return status;
+}
diff --git a/runtime/src/iree/hal/drivers/amdgpu/host_queue_command_buffer.h b/runtime/src/iree/hal/drivers/amdgpu/host_queue_command_buffer.h
new file mode 100644
index 0000000..927400f
--- /dev/null
+++ b/runtime/src/iree/hal/drivers/amdgpu/host_queue_command_buffer.h
@@ -0,0 +1,45 @@
+// Copyright 2026 The IREE Authors
+//
+// Licensed under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+
+#ifndef IREE_HAL_DRIVERS_AMDGPU_HOST_QUEUE_COMMAND_BUFFER_H_
+#define IREE_HAL_DRIVERS_AMDGPU_HOST_QUEUE_COMMAND_BUFFER_H_
+
+#include "iree/hal/drivers/amdgpu/abi/command_buffer.h"
+#include "iree/hal/drivers/amdgpu/host_queue_submission.h"
+
+#ifdef __cplusplus
+extern "C" {
+#endif  // __cplusplus
+
+// Validates queue_execute flags supported by the AMDGPU host queue.
+iree_status_t iree_hal_amdgpu_host_queue_validate_execute_flags(
+    iree_hal_execute_flags_t flags);
+
+// Creates a resource set retaining the binding table prefix required by
+// |command_buffer| unless |execute_flags| explicitly borrows buffer lifetimes.
+iree_status_t iree_hal_amdgpu_host_queue_create_binding_table_resource_set(
+    iree_hal_amdgpu_host_queue_t* queue,
+    iree_hal_command_buffer_t* command_buffer,
+    iree_hal_buffer_binding_table_t binding_table,
+    iree_hal_execute_flags_t execute_flags,
+    iree_hal_resource_set_t** out_resource_set);
+
+// Replays an AMDGPU AQL command buffer program onto the host queue.
+// Caller must hold submission_mutex.
+iree_status_t iree_hal_amdgpu_host_queue_submit_command_buffer(
+    iree_hal_amdgpu_host_queue_t* queue,
+    const iree_hal_amdgpu_wait_resolution_t* resolution,
+    const iree_hal_semaphore_list_t signal_semaphore_list,
+    iree_hal_command_buffer_t* command_buffer,
+    iree_hal_buffer_binding_table_t binding_table,
+    iree_hal_execute_flags_t execute_flags,
+    iree_hal_resource_set_t** inout_binding_resource_set, bool* out_ready);
+
+#ifdef __cplusplus
+}  // extern "C"
+#endif  // __cplusplus
+
+#endif  // IREE_HAL_DRIVERS_AMDGPU_HOST_QUEUE_COMMAND_BUFFER_H_
diff --git a/runtime/src/iree/hal/drivers/amdgpu/host_queue_command_buffer_block.c b/runtime/src/iree/hal/drivers/amdgpu/host_queue_command_buffer_block.c
new file mode 100644
index 0000000..c12d465
--- /dev/null
+++ b/runtime/src/iree/hal/drivers/amdgpu/host_queue_command_buffer_block.c
@@ -0,0 +1,816 @@
+// Copyright 2026 The IREE Authors
+//
+// Licensed under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+
+#include "iree/hal/drivers/amdgpu/host_queue_command_buffer_block.h"
+
+#include <string.h>
+
+#include "iree/hal/drivers/amdgpu/aql_block_processor.h"
+#include "iree/hal/drivers/amdgpu/aql_block_processor_profile.h"
+#include "iree/hal/drivers/amdgpu/aql_command_buffer.h"
+#include "iree/hal/drivers/amdgpu/buffer.h"
+#include "iree/hal/drivers/amdgpu/device/timestamp.h"
+#include "iree/hal/drivers/amdgpu/host_queue_command_buffer_packet.h"
+#include "iree/hal/drivers/amdgpu/host_queue_command_buffer_profile.h"
+#include "iree/hal/drivers/amdgpu/host_queue_command_buffer_scratch.h"
+#include "iree/hal/drivers/amdgpu/host_queue_policy.h"
+#include "iree/hal/drivers/amdgpu/host_queue_profile.h"
+#include "iree/hal/drivers/amdgpu/host_queue_timestamp.h"
+#include "iree/hal/drivers/amdgpu/profile_counters.h"
+#include "iree/hal/drivers/amdgpu/profile_traces.h"
+#include "iree/hal/drivers/amdgpu/util/aql_emitter.h"
+#include "iree/hal/utils/resource_set.h"
+
+static iree_status_t iree_hal_amdgpu_host_queue_ensure_command_buffer_scratch(
+    iree_hal_amdgpu_host_queue_t* queue) {
+  if (queue->command_buffer_scratch) return iree_ok_status();
+  iree_hal_amdgpu_host_queue_command_buffer_scratch_t* scratch = NULL;
+  IREE_RETURN_IF_ERROR(iree_allocator_malloc(
+      queue->host_allocator, sizeof(*scratch), (void**)&scratch));
+  memset(scratch, 0, sizeof(*scratch));
+  queue->command_buffer_scratch = scratch;
+  return iree_ok_status();
+}
+
+static iree_status_t iree_hal_amdgpu_host_queue_resolve_binding_base_ptr(
+    const iree_hal_buffer_binding_t* binding, uint64_t* out_binding_ptr) {
+  *out_binding_ptr = 0;
+  if (IREE_UNLIKELY(!binding->buffer)) {
+    return iree_make_status(IREE_STATUS_INVALID_ARGUMENT,
+                            "dispatch binding table entry is NULL");
+  }
+  IREE_RETURN_IF_ERROR(iree_hal_buffer_validate_memory_type(
+      iree_hal_buffer_memory_type(binding->buffer),
+      IREE_HAL_MEMORY_TYPE_DEVICE_VISIBLE));
+  const iree_device_size_t binding_length =
+      binding->length == IREE_HAL_WHOLE_BUFFER ? 0 : binding->length;
+  IREE_RETURN_IF_ERROR(iree_hal_buffer_validate_range(
+      binding->buffer, binding->offset, binding_length));
+
+  iree_hal_buffer_t* allocated_buffer =
+      iree_hal_buffer_allocated_buffer(binding->buffer);
+  void* device_ptr = iree_hal_amdgpu_buffer_device_pointer(allocated_buffer);
+  if (IREE_UNLIKELY(!device_ptr)) {
+    return iree_make_status(
+        IREE_STATUS_INVALID_ARGUMENT,
+        "dispatch binding table entry must be backed by an AMDGPU allocation");
+  }
+  iree_device_size_t device_offset = 0;
+  if (IREE_UNLIKELY(!iree_device_size_checked_add(
+          iree_hal_buffer_byte_offset(binding->buffer), binding->offset,
+          &device_offset))) {
+    return iree_make_status(
+        IREE_STATUS_OUT_OF_RANGE,
+        "dispatch binding table device pointer offset overflows device size");
+  }
+  if (IREE_UNLIKELY(device_offset > UINTPTR_MAX)) {
+    return iree_make_status(
+        IREE_STATUS_OUT_OF_RANGE,
+        "dispatch binding table device pointer offset exceeds host pointer "
+        "size");
+  }
+  *out_binding_ptr = (uint64_t)((uintptr_t)device_ptr + device_offset);
+  return iree_ok_status();
+}
+
+static iree_status_t
+iree_hal_amdgpu_host_queue_prepare_command_buffer_binding_ptrs(
+    iree_hal_amdgpu_host_queue_t* queue,
+    iree_hal_command_buffer_t* command_buffer,
+    iree_hal_buffer_binding_table_t binding_table,
+    iree_arena_allocator_t* overflow_arena, const uint64_t** out_binding_ptrs) {
+  *out_binding_ptrs = NULL;
+  const uint32_t binding_count = command_buffer->binding_count;
+  if (binding_count == 0) return iree_ok_status();
+  uint64_t* binding_ptrs = NULL;
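+  // Small tables reuse the queue-owned scratch block; larger tables spill
+  // into the per-submission overflow arena.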
+  if (binding_count <=
+      IREE_HAL_AMDGPU_HOST_QUEUE_COMMAND_BUFFER_BINDING_SCRATCH_CAPACITY) {
+    IREE_RETURN_IF_ERROR(
+        iree_hal_amdgpu_host_queue_ensure_command_buffer_scratch(queue));
+    binding_ptrs = queue->command_buffer_scratch->bindings.ptrs;
+  } else {
+    iree_host_size_t binding_ptr_bytes = 0;
+    IREE_RETURN_IF_ERROR(
+        IREE_STRUCT_LAYOUT(0, &binding_ptr_bytes,
+                           IREE_STRUCT_FIELD(binding_count, uint64_t, NULL)));
+    IREE_TRACE_ZONE_BEGIN(z0);
+    IREE_TRACE_ZONE_APPEND_VALUE_I64(z0, binding_count);
+    IREE_RETURN_AND_END_ZONE_IF_ERROR(
+        z0, iree_arena_allocate(overflow_arena, binding_ptr_bytes,
+                                (void**)&binding_ptrs));
+    IREE_TRACE_ZONE_END(z0);
+  }
+
+  iree_status_t status =
+      iree_hal_amdgpu_host_queue_resolve_command_buffer_binding_ptrs(
+          command_buffer, binding_table, binding_ptrs);
+  if (iree_status_is_ok(status)) {
+    *out_binding_ptrs = binding_ptrs;
+  }
+  return status;
+}
+
+iree_status_t iree_hal_amdgpu_host_queue_resolve_command_buffer_binding_ptrs(
+    iree_hal_command_buffer_t* command_buffer,
+    iree_hal_buffer_binding_table_t binding_table, uint64_t* out_binding_ptrs) {
+  const uint32_t binding_count = command_buffer->binding_count;
+  if (binding_count == 0) return iree_ok_status();
+  if (IREE_UNLIKELY(binding_table.count < binding_count)) {
+    return iree_make_status(IREE_STATUS_OUT_OF_RANGE,
+                            "queue_execute binding table count %" PRIhsz
+                            " is smaller than command-buffer binding count %u",
+                            binding_table.count, binding_count);
+  }
+  if (IREE_UNLIKELY(!binding_table.bindings)) {
+    return iree_make_status(
+        IREE_STATUS_INVALID_ARGUMENT,
+        "queue_execute binding table storage is NULL for %u bindings",
+        binding_count);
+  }
+
+  iree_status_t status = iree_ok_status();
+  for (uint32_t i = 0; i < binding_count && iree_status_is_ok(status); ++i) {
+    status = iree_hal_amdgpu_host_queue_resolve_binding_base_ptr(
+        &binding_table.bindings[i], &out_binding_ptrs[i]);
+    if (!iree_status_is_ok(status)) {
+      status = iree_status_annotate_f(status, "binding_table[%" PRIu32 "]", i);
+    }
+  }
+  return status;
+}
+
+static iree_status_t
+iree_hal_amdgpu_host_queue_prepare_command_buffer_packet_metadata(
+    iree_hal_amdgpu_host_queue_t* queue, uint32_t packet_count,
+    iree_arena_allocator_t* scratch_arena, uint16_t** out_packet_headers,
+    uint16_t** out_packet_setups) {
+  *out_packet_headers = NULL;
+  *out_packet_setups = NULL;
+
+  uint16_t* packet_headers = NULL;
+  uint16_t* packet_setups = NULL;
+  if (packet_count <=
+      IREE_HAL_AMDGPU_HOST_QUEUE_COMMAND_BUFFER_PACKET_SCRATCH_CAPACITY) {
+    IREE_RETURN_IF_ERROR(
+        iree_hal_amdgpu_host_queue_ensure_command_buffer_scratch(queue));
+    packet_headers = queue->command_buffer_scratch->packets.headers;
+    packet_setups = queue->command_buffer_scratch->packets.setups;
+  } else {
+    iree_host_size_t packet_metadata_bytes = 0;
+    IREE_RETURN_IF_ERROR(
+        IREE_STRUCT_LAYOUT(0, &packet_metadata_bytes,
+                           IREE_STRUCT_FIELD(packet_count, uint16_t, NULL)));
+    IREE_TRACE_ZONE_BEGIN(z0);
+    IREE_TRACE_ZONE_APPEND_VALUE_I64(z0, packet_count);
+    IREE_RETURN_AND_END_ZONE_IF_ERROR(
+        z0, iree_arena_allocate(scratch_arena, packet_metadata_bytes,
+                                (void**)&packet_headers));
+    IREE_RETURN_AND_END_ZONE_IF_ERROR(
+        z0, iree_arena_allocate(scratch_arena, packet_metadata_bytes,
+                                (void**)&packet_setups));
+    IREE_TRACE_ZONE_END(z0);
+  }
+
+  // The block processors overwrite every reserved metadata slot before
+  // reporting success; avoid pre-zeroing this hot submission span.
+  *out_packet_headers = packet_headers;
+  *out_packet_setups = packet_setups;
+  return iree_ok_status();
+}
+
+static bool iree_hal_amdgpu_host_queue_command_buffer_packet_has_barrier(
+    const iree_hal_amdgpu_wait_resolution_t* resolution, uint32_t packet_index,
+    iree_hal_amdgpu_host_queue_command_buffer_packet_flags_t packet_flags) {
+  return iree_any_bit_set(
+             packet_flags,
+             IREE_HAL_AMDGPU_HOST_QUEUE_COMMAND_BUFFER_PACKET_FLAG_EXECUTION_BARRIER |
+                 IREE_HAL_AMDGPU_HOST_QUEUE_COMMAND_BUFFER_PACKET_FLAG_FINAL) ||
+         (packet_index == 0 &&
+          (resolution->barrier_count > 0 ||
+           resolution->inline_acquire_scope != IREE_HSA_FENCE_SCOPE_NONE));
+}
+
+static iree_hsa_fence_scope_t
+iree_hal_amdgpu_host_queue_command_buffer_block_payload_acquire_scope(
+    iree_hal_amdgpu_host_queue_t* queue,
+    const iree_hal_amdgpu_command_buffer_block_header_t* block) {
+  if (block->kernarg_length == 0) return IREE_HSA_FENCE_SCOPE_NONE;
+  return iree_hal_amdgpu_host_queue_kernarg_acquire_scope(
+      queue, IREE_HSA_FENCE_SCOPE_NONE);
+}
+
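+// Returns how many leading block packets must carry the payload acquire
+// scope: zero when no acquire is needed, one when a barrier already fronts
+// the block, and otherwise the block's own initial barrier packet count.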
+static uint32_t
+iree_hal_amdgpu_host_queue_command_buffer_block_payload_acquire_packet_count(
+    const iree_hal_amdgpu_wait_resolution_t* resolution,
+    const iree_hal_amdgpu_command_buffer_block_header_t* block,
+    uint32_t packet_index_base, iree_hsa_fence_scope_t payload_acquire_scope) {
+  if (payload_acquire_scope == IREE_HSA_FENCE_SCOPE_NONE) return 0;
+  if (block->aql_packet_count == 0) return 0;
+  if (iree_hal_amdgpu_host_queue_command_buffer_packet_has_barrier(
+          resolution, packet_index_base,
+          block->aql_packet_count == 1
+              ? IREE_HAL_AMDGPU_HOST_QUEUE_COMMAND_BUFFER_PACKET_FLAG_FINAL
+              : IREE_HAL_AMDGPU_HOST_QUEUE_COMMAND_BUFFER_PACKET_FLAG_NONE)) {
+    return 1;
+  }
+  return block->initial_barrier_packet_count;
+}
+
+static uint32_t iree_hal_amdgpu_host_queue_aql_packet_header_field(
+    uint16_t header, uint32_t field_shift, uint32_t field_width) {
+  return (header >> field_shift) & ((1u << field_width) - 1u);
+}
+
+static iree_hsa_packet_type_t iree_hal_amdgpu_host_queue_aql_packet_header_type(
+    uint16_t header) {
+  return (iree_hsa_packet_type_t)
+      iree_hal_amdgpu_host_queue_aql_packet_header_field(
+          header, IREE_HSA_PACKET_HEADER_TYPE,
+          IREE_HSA_PACKET_HEADER_WIDTH_TYPE);
+}
+
+static iree_status_t iree_hal_amdgpu_host_queue_write_command_buffer_block(
+    iree_hal_amdgpu_host_queue_t* queue,
+    const iree_hal_amdgpu_wait_resolution_t* resolution,
+    const iree_hal_semaphore_list_t signal_semaphore_list,
+    iree_hal_command_buffer_t* command_buffer,
+    iree_hal_buffer_binding_table_t binding_table,
+    const iree_hal_amdgpu_command_buffer_block_header_t* block,
+    const uint64_t* binding_ptrs, uint64_t first_payload_packet_id,
+    uint32_t packet_index_base, iree_hal_amdgpu_kernarg_block_t* kernarg_blocks,
+    uint16_t* packet_headers, uint16_t* packet_setups,
+    iree_hal_amdgpu_profile_dispatch_event_reservation_t profile_events,
+    uint32_t emitted_packet_count, uint32_t profile_counter_set_count,
+    uint32_t profile_trace_packet_count,
+    iree_hal_amdgpu_profile_dispatch_harvest_source_t* profile_harvest_sources,
+    iree_hal_amdgpu_aql_block_processor_profile_dispatch_list_t
+        profile_dispatches) {
+  const iree_hsa_fence_scope_t payload_acquire_scope =
+      iree_hal_amdgpu_host_queue_command_buffer_block_payload_acquire_scope(
+          queue, block);
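+  // Submissions with no profile events, counters, or traces take the leaner
+  // base block processor.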
+  const bool use_base_processor = profile_events.event_count == 0 &&
+                                  profile_counter_set_count == 0 &&
+                                  profile_trace_packet_count == 0;
+  if (use_base_processor) {
+    iree_hal_amdgpu_aql_block_processor_t processor;
+    const iree_hal_amdgpu_aql_block_processor_t processor_params = {
+        .transfer_context = queue->transfer_context,
+        .command_buffer = command_buffer,
+        .bindings =
+            {
+                .table = binding_table,
+                .ptrs = binding_ptrs,
+            },
+        .packets =
+            {
+                .ring = &queue->aql_ring,
+                .first_id = first_payload_packet_id,
+                .index_base = packet_index_base,
+                .count = emitted_packet_count,
+                .headers = packet_headers,
+                .setups = packet_setups,
+            },
+        .kernargs =
+            {
+                .blocks = kernarg_blocks,
+                .count = (uint32_t)iree_host_size_ceil_div(
+                    block->kernarg_length,
+                    sizeof(iree_hal_amdgpu_kernarg_block_t)),
+            },
+        .submission =
+            {
+                .wait_barrier_count = resolution->barrier_count,
+                .inline_acquire_scope = resolution->inline_acquire_scope,
+                .signal_release_scope =
+                    iree_hal_amdgpu_host_queue_signal_list_release_scope(
+                        queue, signal_semaphore_list),
+            },
+        .payload =
+            {
+                .acquire_scope = payload_acquire_scope,
+            },
+        .flags =
+            packet_index_base == 0
+                ? IREE_HAL_AMDGPU_AQL_BLOCK_PROCESSOR_FLAG_FINAL_PAYLOAD_PACKET
+                : IREE_HAL_AMDGPU_AQL_BLOCK_PROCESSOR_FLAG_NONE,
+    };
+    iree_hal_amdgpu_aql_block_processor_initialize(&processor_params,
+                                                   &processor);
+    iree_hal_amdgpu_aql_block_processor_result_t result;
+    iree_status_t status =
+        iree_hal_amdgpu_aql_block_processor_invoke(&processor, block, &result);
+    iree_hal_amdgpu_aql_block_processor_deinitialize(&processor);
+    return status;
+  }
+  // Per-dispatch counter and trace packets are emitted before the recorded
+  // payload packet they wrap. Do not let a submit-time barrier on logical
+  // packet 0 shrink the recorded payload acquire span when that first logical
+  // packet is profiling metadata instead of the recorded payload stream.
+  const uint32_t first_recorded_packet_index_base =
+      profile_counter_set_count == 0 && profile_trace_packet_count == 0
+          ? packet_index_base
+          : 1u;
+  const uint32_t payload_acquire_packet_count =
+      iree_hal_amdgpu_host_queue_command_buffer_block_payload_acquire_packet_count(
+          resolution, block, first_recorded_packet_index_base,
+          payload_acquire_scope);
+  iree_hal_amdgpu_aql_block_processor_profile_flags_t profile_flags =
+      IREE_HAL_AMDGPU_AQL_BLOCK_PROCESSOR_PROFILE_FLAG_NONE;
+  if (profile_events.event_count != 0) {
+    profile_flags |=
+        IREE_HAL_AMDGPU_AQL_BLOCK_PROCESSOR_PROFILE_FLAG_DISPATCH_PACKETS;
+  }
+  if (packet_index_base != 0) {
+    profile_flags |=
+        IREE_HAL_AMDGPU_AQL_BLOCK_PROCESSOR_PROFILE_FLAG_QUEUE_DEVICE_EVENT;
+  }
+  iree_hal_amdgpu_aql_block_processor_profile_t processor;
+  const iree_hal_amdgpu_aql_block_processor_profile_t processor_params = {
+      .queue = queue,
+      .command_buffer = command_buffer,
+      .block = block,
+      .submission =
+          {
+              .resolution = resolution,
+              .signal_semaphore_list = signal_semaphore_list,
+          },
+      .bindings =
+          {
+              .table = binding_table,
+              .ptrs = binding_ptrs,
+          },
+      .packets =
+          {
+              .first_payload_id = first_payload_packet_id,
+              .index_base = packet_index_base,
+              .count = emitted_packet_count,
+              .headers = packet_headers,
+              .setups = packet_setups,
+          },
+      .kernargs =
+          {
+              .blocks = kernarg_blocks,
+              .count = (uint32_t)iree_host_size_ceil_div(
+                  block->kernarg_length,
+                  sizeof(iree_hal_amdgpu_kernarg_block_t)),
+          },
+      .payload =
+          {
+              .acquire_scope = payload_acquire_scope,
+              .acquire_packet_count = payload_acquire_packet_count,
+          },
+      .profile =
+          {
+              .dispatches = profile_dispatches,
+              .dispatch_events = profile_events,
+              .harvest_sources = profile_harvest_sources,
+              .command_buffer_id =
+                  iree_hal_amdgpu_aql_command_buffer_profile_id(command_buffer),
+              .counter_set_count = profile_counter_set_count,
+              .trace_packet_count = profile_trace_packet_count,
+          },
+      .flags = profile_flags,
+  };
+  iree_hal_amdgpu_aql_block_processor_profile_initialize(&processor_params,
+                                                         &processor);
+  iree_hal_amdgpu_aql_block_processor_profile_result_t result;
+  iree_status_t status =
+      iree_hal_amdgpu_aql_block_processor_profile_invoke(&processor, &result);
+  iree_hal_amdgpu_aql_block_processor_profile_deinitialize(&processor);
+  return status;
+}
+
+typedef uint32_t
+    iree_hal_amdgpu_host_queue_command_buffer_profile_submission_flags_t;
+enum iree_hal_amdgpu_host_queue_command_buffer_profile_submission_flag_bits_t {
+  IREE_HAL_AMDGPU_HOST_QUEUE_COMMAND_BUFFER_PROFILE_SUBMISSION_FLAG_NONE = 0u,
+  // The non-profiling path needs a trailing barrier packet to own queue
+  // completion.
+  IREE_HAL_AMDGPU_HOST_QUEUE_COMMAND_BUFFER_PROFILE_SUBMISSION_FLAG_TRAILING_COMPLETION_PACKET =
+      1u << 0,
+};
+
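+// Profiling state threaded from submission setup through
+// iree_hal_amdgpu_host_queue_finish_command_buffer_block.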
+typedef struct iree_hal_amdgpu_host_queue_command_buffer_profile_submission_t {
+  // Reserved dispatch timestamp events for profiled commands in this block.
+  iree_hal_amdgpu_profile_dispatch_event_reservation_t dispatch_events;
+  // Reserved whole-block queue-device timestamp event for this execute.
+  iree_hal_amdgpu_profile_queue_device_event_reservation_t queue_device_events;
+  // Host queue event metadata shared by host and device queue-event records.
+  iree_hal_amdgpu_host_queue_profile_event_info_t queue_event_info;
+  // Optional harvest dispatch emitted after profiled dispatch payloads.
+  struct {
+    // Dispatch packet for harvesting dispatch timestamp records.
+    iree_hal_amdgpu_aql_packet_t* packet;
+    // Setup bits for |packet| when present.
+    uint16_t setup;
+  } harvest;
+  // Flags from
+  // iree_hal_amdgpu_host_queue_command_buffer_profile_submission_flag_bits_t.
+  iree_hal_amdgpu_host_queue_command_buffer_profile_submission_flags_t flags;
+} iree_hal_amdgpu_host_queue_command_buffer_profile_submission_t;
+
+// Publishes the non-profiling terminal barrier packet for a replayed
+// command-buffer block. Payload packets keep their recorded final-payload
+// barriers, but queue completion is signaled from this trailing packet so
+// software observes block completion only after the command processor (CP)
+// reaches the end of the replay span.
+static void iree_hal_amdgpu_host_queue_commit_command_buffer_completion_packet(
+    iree_hal_amdgpu_host_queue_t* queue,
+    const iree_hal_amdgpu_wait_resolution_t* resolution,
+    const iree_hal_semaphore_list_t signal_semaphore_list, uint64_t packet_id,
+    uint32_t packet_index) {
+  iree_hal_amdgpu_aql_packet_t* packet =
+      iree_hal_amdgpu_aql_ring_packet(&queue->aql_ring, packet_id);
+  const uint16_t header = iree_hal_amdgpu_aql_emit_nop(
+      &packet->barrier_and,
+      iree_hal_amdgpu_host_queue_command_buffer_packet_control(
+          queue, resolution, signal_semaphore_list, packet_index,
+          IREE_HSA_FENCE_SCOPE_NONE,
+          IREE_HAL_AMDGPU_HOST_QUEUE_COMMAND_BUFFER_PACKET_FLAG_FINAL),
+      iree_hal_amdgpu_notification_ring_epoch_signal(
+          &queue->notification_ring));
+  iree_hal_amdgpu_aql_ring_commit(packet, header, /*setup=*/0);
+}
+
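+// Publishes a prepared command-buffer block submission: emits the wait
+// prefix, commits the deferred payload packet headers, commits any harvest,
+// timestamp, and trailing completion packets, and rings the queue doorbell.
+// Returns the submission epoch assigned to the block.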
+static uint64_t iree_hal_amdgpu_host_queue_finish_command_buffer_block(
+    iree_hal_amdgpu_host_queue_t* queue,
+    const iree_hal_amdgpu_wait_resolution_t* resolution,
+    const iree_hal_semaphore_list_t signal_semaphore_list,
+    uint32_t emitted_packet_count,
+    iree_hal_resource_set_t** inout_binding_resource_set,
+    iree_hal_amdgpu_reclaim_action_t pre_signal_action,
+    iree_hal_resource_t* const* operation_resources,
+    iree_host_size_t operation_resource_count,
+    iree_hal_amdgpu_host_queue_submission_flags_t submission_flags,
+    iree_hal_amdgpu_host_queue_kernel_submission_t* submission,
+    const uint16_t* packet_headers, const uint16_t* packet_setups,
+    iree_hal_amdgpu_host_queue_command_buffer_profile_submission_t* profile) {
+  submission->pre_signal_action = pre_signal_action;
+  iree_hal_amdgpu_host_queue_emit_kernel_submission_prefix(queue, resolution,
+                                                           submission);
+  const uint64_t submission_epoch =
+      iree_hal_amdgpu_host_queue_finish_kernel_submission(
+          queue, resolution, signal_semaphore_list, operation_resources,
+          operation_resource_count, inout_binding_resource_set,
+          submission_flags, submission);
+
+  iree_hal_amdgpu_profile_queue_device_event_t* queue_device_event =
+      iree_hal_amdgpu_host_queue_initialize_profile_queue_device_event(
+          queue, profile->queue_device_events, &profile->queue_event_info);
+  if (queue_device_event) {
+    submission->reclaim_entry->queue_device_event_first_position =
+        profile->queue_device_events.first_event_position;
+    submission->reclaim_entry->queue_device_event_count =
+        profile->queue_device_events.event_count;
+    queue_device_event->submission_id = submission_epoch;
+  }
+
+  uint16_t profile_harvest_header = 0;
+  if (profile->dispatch_events.event_count != 0) {
+    submission->reclaim_entry->profile_event_first_position =
+        profile->dispatch_events.first_event_position;
+    submission->reclaim_entry->profile_event_count =
+        profile->dispatch_events.event_count;
+    for (uint32_t i = 0; i < profile->dispatch_events.event_count; ++i) {
+      iree_hal_amdgpu_profile_dispatch_event_t* event =
+          iree_hal_amdgpu_host_queue_profile_dispatch_event_at(
+              queue, profile->dispatch_events.first_event_position + i);
+      event->submission_id = submission_epoch;
+    }
+    profile->harvest.packet->dispatch.completion_signal =
+        queue_device_event ? iree_hsa_signal_null()
+                           : iree_hal_amdgpu_notification_ring_epoch_signal(
+                                 &queue->notification_ring);
+    const iree_hsa_fence_scope_t profile_harvest_acquire_scope =
+        iree_hal_amdgpu_host_queue_kernarg_acquire_scope(
+            queue, IREE_HSA_FENCE_SCOPE_AGENT);
+    profile_harvest_header = iree_hal_amdgpu_aql_make_header(
+        IREE_HSA_PACKET_TYPE_KERNEL_DISPATCH,
+        iree_hal_amdgpu_aql_packet_control_barrier(
+            iree_hal_amdgpu_host_queue_max_fence_scope(
+                profile_harvest_acquire_scope,
+                resolution->inline_acquire_scope),
+            IREE_HSA_FENCE_SCOPE_SYSTEM));
+  }
+
+  const uint32_t profile_queue_device_prefix_packet_count =
+      queue_device_event ? 1u : 0u;
+  const uint64_t first_payload_packet_id =
+      submission->first_packet_id + resolution->barrier_count +
+      profile_queue_device_prefix_packet_count;
+  iree_hal_amdgpu_host_queue_publish_submission_kernargs(queue, submission);
+  if (queue_device_event) {
+    const uint64_t start_packet_id =
+        submission->first_packet_id + resolution->barrier_count;
+    iree_hal_amdgpu_host_queue_commit_timestamp_start(
+        queue, start_packet_id,
+        iree_hal_amdgpu_host_queue_command_buffer_packet_control(
+            queue, resolution, signal_semaphore_list, /*packet_index=*/0,
+            IREE_HSA_FENCE_SCOPE_NONE,
+            IREE_HAL_AMDGPU_HOST_QUEUE_COMMAND_BUFFER_PACKET_FLAG_NONE),
+        &queue_device_event->start_tick);
+  }
+
+  for (uint32_t i = 0; i < emitted_packet_count; ++i) {
+    iree_hal_amdgpu_aql_packet_t* packet = iree_hal_amdgpu_aql_ring_packet(
+        &queue->aql_ring, first_payload_packet_id + i);
+    if (iree_hal_amdgpu_host_queue_aql_packet_header_type(packet_headers[i]) !=
+        IREE_HSA_PACKET_TYPE_INVALID) {
+      iree_hal_amdgpu_aql_ring_commit(packet, packet_headers[i],
+                                      packet_setups[i]);
+    }
+  }
+  if (profile->dispatch_events.event_count != 0) {
+    iree_hal_amdgpu_aql_ring_commit(profile->harvest.packet,
+                                    profile_harvest_header,
+                                    profile->harvest.setup);
+  }
+  if (iree_any_bit_set(
+          profile->flags,
+          IREE_HAL_AMDGPU_HOST_QUEUE_COMMAND_BUFFER_PROFILE_SUBMISSION_FLAG_TRAILING_COMPLETION_PACKET)) {
+    const uint32_t profile_harvest_packet_count =
+        profile->dispatch_events.event_count != 0 ? 1u : 0u;
+    const uint64_t completion_packet_id = first_payload_packet_id +
+                                          emitted_packet_count +
+                                          profile_harvest_packet_count;
+    const uint32_t completion_packet_index =
+        profile_queue_device_prefix_packet_count + emitted_packet_count +
+        profile_harvest_packet_count;
+    iree_hal_amdgpu_host_queue_commit_command_buffer_completion_packet(
+        queue, resolution, signal_semaphore_list, completion_packet_id,
+        completion_packet_index);
+  }
+  if (queue_device_event) {
+    const uint64_t end_packet_id =
+        first_payload_packet_id + emitted_packet_count +
+        (profile->dispatch_events.event_count != 0 ? 1u : 0u);
+    const uint32_t end_packet_index =
+        profile_queue_device_prefix_packet_count + emitted_packet_count +
+        (profile->dispatch_events.event_count != 0 ? 1u : 0u);
+    iree_hal_amdgpu_host_queue_commit_timestamp_end(
+        queue, end_packet_id,
+        iree_hal_amdgpu_host_queue_command_buffer_packet_control(
+            queue, resolution, signal_semaphore_list, end_packet_index,
+            IREE_HSA_FENCE_SCOPE_NONE,
+            IREE_HAL_AMDGPU_HOST_QUEUE_COMMAND_BUFFER_PACKET_FLAG_FINAL),
+        iree_hal_amdgpu_notification_ring_epoch_signal(
+            &queue->notification_ring),
+        &queue_device_event->end_tick);
+  }
+  iree_hal_amdgpu_aql_ring_doorbell(
+      &queue->aql_ring,
+      submission->first_packet_id + submission->packet_count - 1);
+  profile->queue_event_info.submission_id = submission_epoch;
+  memset(submission, 0, sizeof(*submission));
+  return submission_epoch;
+}
+
+iree_status_t iree_hal_amdgpu_host_queue_submit_command_buffer_block(
+    iree_hal_amdgpu_host_queue_t* queue,
+    const iree_hal_amdgpu_wait_resolution_t* resolution,
+    const iree_hal_semaphore_list_t signal_semaphore_list,
+    iree_hal_command_buffer_t* command_buffer,
+    iree_hal_buffer_binding_table_t binding_table, const uint64_t* binding_ptrs,
+    const iree_hal_amdgpu_command_buffer_block_header_t* block,
+    iree_hal_resource_set_t** inout_binding_resource_set,
+    iree_hal_amdgpu_reclaim_action_t pre_signal_action,
+    iree_hal_resource_t* const* operation_resources,
+    iree_host_size_t operation_resource_count,
+    iree_hal_amdgpu_host_queue_submission_flags_t submission_flags,
+    bool* out_ready) {
+  *out_ready = false;
+  const uint64_t command_buffer_id =
+      iree_hal_amdgpu_aql_command_buffer_profile_id(command_buffer);
+  const uint32_t kernarg_block_count = (uint32_t)iree_host_size_ceil_div(
+      block->kernarg_length, sizeof(iree_hal_amdgpu_kernarg_block_t));
+  iree_arena_allocator_t scratch_arena;
+  iree_arena_initialize(queue->block_pool, &scratch_arena);
+  iree_hal_amdgpu_aql_block_processor_profile_dispatch_list_t
+      profile_dispatches = {0};
+  iree_status_t status = iree_ok_status();
+  if (command_buffer_id != 0 && queue->profiling.dispatch_profiling_enabled) {
+    status =
+        iree_hal_amdgpu_host_queue_select_command_buffer_profile_dispatches(
+            queue, command_buffer, block, &scratch_arena, &profile_dispatches);
+    if (!iree_status_is_ok(status)) {
+      iree_arena_deinitialize(&scratch_arena);
+      return status;
+    }
+  }
+  const uint32_t profile_dispatch_event_count = profile_dispatches.count;
+  iree_hal_amdgpu_profile_dispatch_event_reservation_t profile_events = {0};
+  if (profile_dispatch_event_count != 0) {
+    status = iree_hal_amdgpu_host_queue_reserve_profile_dispatch_events(
+        queue, profile_dispatch_event_count, &profile_events);
+    if (!iree_status_is_ok(status)) {
+      iree_arena_deinitialize(&scratch_arena);
+      return status;
+    }
+  }
+  const uint32_t profile_counter_set_count =
+      profile_events.event_count != 0
+          ? iree_hal_amdgpu_host_queue_profile_counter_set_count(queue,
+                                                                 profile_events)
+          : 0u;
+  const uint32_t profile_counter_packet_count =
+      profile_events.event_count != 0
+          ? iree_hal_amdgpu_host_queue_profile_counter_packet_count(
+                queue, profile_events)
+          : 0u;
+  const uint32_t profile_trace_packet_count =
+      profile_events.event_count != 0
+          ? iree_hal_amdgpu_host_queue_profile_trace_packet_count(
+                queue, profile_events)
+          : 0u;
+  if (IREE_UNLIKELY(profile_trace_packet_count >
+                    UINT32_MAX - profile_counter_packet_count)) {
+    iree_hal_amdgpu_host_queue_cancel_profile_dispatch_events(queue,
+                                                              profile_events);
+    iree_arena_deinitialize(&scratch_arena);
+    return iree_make_status(
+        IREE_STATUS_OUT_OF_RANGE,
+        "profiled command-buffer block packet count overflow");
+  }
+  const uint32_t extra_profile_packet_count =
+      profile_counter_packet_count + profile_trace_packet_count;
+  if (IREE_UNLIKELY(block->aql_packet_count >
+                    UINT32_MAX - extra_profile_packet_count)) {
+    iree_hal_amdgpu_host_queue_cancel_profile_dispatch_events(queue,
+                                                              profile_events);
+    iree_arena_deinitialize(&scratch_arena);
+    return iree_make_status(
+        IREE_STATUS_OUT_OF_RANGE,
+        "profiled command-buffer block packet count overflow");
+  }
+  const uint32_t emitted_packet_count =
+      block->aql_packet_count + extra_profile_packet_count;
+  iree_hal_amdgpu_profile_queue_device_event_reservation_t
+      profile_queue_device_events = {0};
+  if (iree_hal_amdgpu_host_queue_should_profile_queue_device_events(queue)) {
+    iree_status_t reserve_status =
+        iree_hal_amdgpu_host_queue_reserve_profile_queue_device_events(
+            queue, /*event_count=*/1, &profile_queue_device_events);
+    if (!iree_status_is_ok(reserve_status)) {
+      iree_hal_amdgpu_host_queue_cancel_profile_dispatch_events(queue,
+                                                                profile_events);
+      iree_arena_deinitialize(&scratch_arena);
+      return reserve_status;
+    }
+  }
+  const uint32_t profile_harvest_packet_count =
+      profile_events.event_count != 0 ? 1u : 0u;
+  const uint32_t profile_queue_device_packet_count =
+      profile_queue_device_events.event_count != 0 ? 2u : 0u;
+  iree_hal_amdgpu_host_queue_command_buffer_profile_submission_flags_t
+      profile_submission_flags =
+          IREE_HAL_AMDGPU_HOST_QUEUE_COMMAND_BUFFER_PROFILE_SUBMISSION_FLAG_NONE;
+  if (profile_events.event_count == 0 &&
+      profile_queue_device_packet_count == 0) {
+    profile_submission_flags |=
+        IREE_HAL_AMDGPU_HOST_QUEUE_COMMAND_BUFFER_PROFILE_SUBMISSION_FLAG_TRAILING_COMPLETION_PACKET;
+  }
+  const uint32_t trailing_completion_packet_count =
+      iree_any_bit_set(
+          profile_submission_flags,
+          IREE_HAL_AMDGPU_HOST_QUEUE_COMMAND_BUFFER_PROFILE_SUBMISSION_FLAG_TRAILING_COMPLETION_PACKET)
+          ? 1u
+          : 0u;
+  if (IREE_UNLIKELY(emitted_packet_count >
+                        UINT32_MAX - profile_harvest_packet_count ||
+                    emitted_packet_count + profile_harvest_packet_count >
+                        UINT32_MAX - profile_queue_device_packet_count ||
+                    emitted_packet_count + profile_harvest_packet_count +
+                            profile_queue_device_packet_count >
+                        UINT32_MAX - trailing_completion_packet_count)) {
+    iree_hal_amdgpu_host_queue_cancel_profile_dispatch_events(queue,
+                                                              profile_events);
+    iree_hal_amdgpu_host_queue_cancel_profile_queue_device_events(
+        queue, profile_queue_device_events);
+    iree_arena_deinitialize(&scratch_arena);
+    return iree_make_status(
+        IREE_STATUS_OUT_OF_RANGE,
+        "profiled command-buffer block packet count overflow");
+  }
+  const uint32_t payload_packet_count =
+      emitted_packet_count + profile_harvest_packet_count +
+      profile_queue_device_packet_count + trailing_completion_packet_count;
+  const uint32_t profile_harvest_kernarg_block_count =
+      profile_events.event_count != 0
+          ? (uint32_t)iree_host_size_ceil_div(
+                iree_hal_amdgpu_device_timestamp_dispatch_harvest_kernarg_length(
+                    profile_events.event_count),
+                sizeof(iree_hal_amdgpu_kernarg_block_t))
+          : 0u;
+
+  const uint64_t* block_binding_ptrs = binding_ptrs;
+  if (!block_binding_ptrs) {
+    status = iree_hal_amdgpu_host_queue_prepare_command_buffer_binding_ptrs(
+        queue, command_buffer, binding_table, &scratch_arena,
+        &block_binding_ptrs);
+  }
+  if (iree_status_is_ok(status) && profile_counter_packet_count != 0) {
+    status = iree_hal_amdgpu_host_queue_prepare_profile_counter_samples(
+        queue, profile_events);
+  }
+  if (iree_status_is_ok(status) && profile_trace_packet_count != 0) {
+    status = iree_hal_amdgpu_host_queue_prepare_profile_traces(queue,
+                                                               profile_events);
+  }
+  if (iree_status_is_ok(status) && profile_trace_packet_count != 0) {
+    status =
+        iree_hal_amdgpu_host_queue_prepare_command_buffer_profile_trace_code_objects(
+            queue, profile_dispatches, profile_events);
+  }
+
+  iree_hal_amdgpu_host_queue_kernel_submission_t submission;
+  memset(&submission, 0, sizeof(submission));
+  if (iree_status_is_ok(status)) {
+    status = iree_hal_amdgpu_host_queue_try_begin_kernel_submission(
+        queue, resolution, signal_semaphore_list, operation_resource_count,
+        payload_packet_count,
+        kernarg_block_count + profile_harvest_kernarg_block_count, out_ready,
+        &submission);
+  }
+  if (!iree_status_is_ok(status) || !*out_ready) {
+    iree_hal_amdgpu_host_queue_cancel_profile_dispatch_events(queue,
+                                                              profile_events);
+    iree_hal_amdgpu_host_queue_cancel_profile_queue_device_events(
+        queue, profile_queue_device_events);
+  }
+  if (iree_status_is_ok(status) && *out_ready) {
+    iree_hal_amdgpu_aql_packet_t* profile_harvest_packet = NULL;
+    iree_hal_amdgpu_profile_dispatch_harvest_source_t* profile_harvest_sources =
+        NULL;
+    uint16_t profile_harvest_setup = 0;
+    const uint32_t profile_queue_device_prefix_packet_count =
+        profile_queue_device_events.event_count != 0 ? 1u : 0u;
+    const uint64_t first_payload_packet_id =
+        submission.first_packet_id + resolution->barrier_count +
+        profile_queue_device_prefix_packet_count;
+    if (profile_events.event_count != 0) {
+      profile_harvest_packet = iree_hal_amdgpu_aql_ring_packet(
+          &queue->aql_ring, first_payload_packet_id + emitted_packet_count);
+      profile_harvest_sources =
+          iree_hal_amdgpu_device_timestamp_emplace_dispatch_harvest(
+              &queue->transfer_context->kernels
+                   ->iree_hal_amdgpu_device_timestamp_harvest_dispatch_records,
+              profile_events.event_count, &profile_harvest_packet->dispatch,
+              submission.kernargs.blocks[kernarg_block_count].data);
+      profile_harvest_setup = profile_harvest_packet->dispatch.setup;
+    }
+    uint16_t* packet_headers = NULL;
+    uint16_t* packet_setups = NULL;
+    status = iree_hal_amdgpu_host_queue_prepare_command_buffer_packet_metadata(
+        queue, emitted_packet_count, &scratch_arena, &packet_headers,
+        &packet_setups);
+    if (iree_status_is_ok(status)) {
+      status = iree_hal_amdgpu_host_queue_write_command_buffer_block(
+          queue, resolution, signal_semaphore_list, command_buffer,
+          binding_table, block, block_binding_ptrs, first_payload_packet_id,
+          profile_queue_device_prefix_packet_count, submission.kernargs.blocks,
+          packet_headers, packet_setups, profile_events, emitted_packet_count,
+          profile_counter_set_count, profile_trace_packet_count,
+          profile_harvest_sources, profile_dispatches);
+    }
+    if (iree_status_is_ok(status)) {
+      iree_hal_amdgpu_host_queue_command_buffer_profile_submission_t
+          profile_submission = {
+              .dispatch_events = profile_events,
+              .queue_device_events = profile_queue_device_events,
+              .queue_event_info =
+                  {
+                      .type = IREE_HAL_PROFILE_QUEUE_EVENT_TYPE_EXECUTE,
+                      .command_buffer_id = command_buffer_id,
+                      .operation_count = block->command_count,
+                  },
+              .harvest =
+                  {
+                      .packet = profile_harvest_packet,
+                      .setup = profile_harvest_setup,
+                  },
+              .flags = profile_submission_flags,
+          };
+      iree_hal_amdgpu_host_queue_finish_command_buffer_block(
+          queue, resolution, signal_semaphore_list, emitted_packet_count,
+          inout_binding_resource_set, pre_signal_action, operation_resources,
+          operation_resource_count, submission_flags, &submission,
+          packet_headers, packet_setups, &profile_submission);
+      iree_hal_amdgpu_host_queue_record_profile_queue_event(
+          queue, resolution, signal_semaphore_list,
+          &profile_submission.queue_event_info);
+    } else {
+      iree_hal_amdgpu_host_queue_fail_kernel_submission(queue, &submission);
+      iree_hal_amdgpu_host_queue_cancel_profile_dispatch_events(queue,
+                                                                profile_events);
+      iree_hal_amdgpu_host_queue_cancel_profile_queue_device_events(
+          queue, profile_queue_device_events);
+    }
+  }
+  iree_arena_deinitialize(&scratch_arena);
+  return status;
+}
diff --git a/runtime/src/iree/hal/drivers/amdgpu/host_queue_command_buffer_block.h b/runtime/src/iree/hal/drivers/amdgpu/host_queue_command_buffer_block.h
new file mode 100644
index 0000000..5f005f1
--- /dev/null
+++ b/runtime/src/iree/hal/drivers/amdgpu/host_queue_command_buffer_block.h
@@ -0,0 +1,42 @@
+// Copyright 2026 The IREE Authors
+//
+// Licensed under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+
+#ifndef IREE_HAL_DRIVERS_AMDGPU_HOST_QUEUE_COMMAND_BUFFER_BLOCK_H_
+#define IREE_HAL_DRIVERS_AMDGPU_HOST_QUEUE_COMMAND_BUFFER_BLOCK_H_
+
+#include "iree/hal/drivers/amdgpu/host_queue_command_buffer.h"
+
+#ifdef __cplusplus
+extern "C" {
+#endif  // __cplusplus
+
+// Submits one finalized command-buffer block to the queue. Caller must hold
+// |queue->locks.submission_mutex|.
+iree_status_t iree_hal_amdgpu_host_queue_submit_command_buffer_block(
+    iree_hal_amdgpu_host_queue_t* queue,
+    const iree_hal_amdgpu_wait_resolution_t* resolution,
+    const iree_hal_semaphore_list_t signal_semaphore_list,
+    iree_hal_command_buffer_t* command_buffer,
+    iree_hal_buffer_binding_table_t binding_table, const uint64_t* binding_ptrs,
+    const iree_hal_amdgpu_command_buffer_block_header_t* block,
+    iree_hal_resource_set_t** inout_binding_resource_set,
+    iree_hal_amdgpu_reclaim_action_t pre_signal_action,
+    iree_hal_resource_t* const* operation_resources,
+    iree_host_size_t operation_resource_count,
+    iree_hal_amdgpu_host_queue_submission_flags_t submission_flags,
+    bool* out_ready);
+
+// Resolves queue_execute binding table entries into raw device base pointers
+// indexed by their original binding table slot.
+iree_status_t iree_hal_amdgpu_host_queue_resolve_command_buffer_binding_ptrs(
+    iree_hal_command_buffer_t* command_buffer,
+    iree_hal_buffer_binding_table_t binding_table, uint64_t* out_binding_ptrs);
+
+#ifdef __cplusplus
+}  // extern "C"
+#endif  // __cplusplus
+
+#endif  // IREE_HAL_DRIVERS_AMDGPU_HOST_QUEUE_COMMAND_BUFFER_BLOCK_H_
diff --git a/runtime/src/iree/hal/drivers/amdgpu/host_queue_command_buffer_packet.c b/runtime/src/iree/hal/drivers/amdgpu/host_queue_command_buffer_packet.c
new file mode 100644
index 0000000..5c3d69c
--- /dev/null
+++ b/runtime/src/iree/hal/drivers/amdgpu/host_queue_command_buffer_packet.c
@@ -0,0 +1,62 @@
+// Copyright 2026 The IREE Authors
+//
+// Licensed under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+
+#include "iree/hal/drivers/amdgpu/host_queue_command_buffer_packet.h"
+
+#include "iree/hal/drivers/amdgpu/host_queue_policy.h"
+
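+// Returns true when the replayed packet at |packet_index| must set the AQL
+// barrier bit: recorded execution barriers, the final packet, and the first
+// packet of a submission with wait or inline-acquire work ahead of it.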
+static bool iree_hal_amdgpu_host_queue_command_buffer_packet_has_barrier(
+    const iree_hal_amdgpu_wait_resolution_t* resolution, uint32_t packet_index,
+    iree_hal_amdgpu_host_queue_command_buffer_packet_flags_t packet_flags) {
+  return iree_any_bit_set(
+             packet_flags,
+             IREE_HAL_AMDGPU_HOST_QUEUE_COMMAND_BUFFER_PACKET_FLAG_EXECUTION_BARRIER) ||
+         iree_any_bit_set(
+             packet_flags,
+             IREE_HAL_AMDGPU_HOST_QUEUE_COMMAND_BUFFER_PACKET_FLAG_FINAL) ||
+         (packet_index == 0 && resolution->barrier_count > 0) ||
+         (packet_index == 0 &&
+          resolution->inline_acquire_scope != IREE_HSA_FENCE_SCOPE_NONE);
+}
+
+iree_hal_amdgpu_aql_packet_control_t
+iree_hal_amdgpu_host_queue_command_buffer_packet_control(
+    iree_hal_amdgpu_host_queue_t* queue,
+    const iree_hal_amdgpu_wait_resolution_t* resolution,
+    const iree_hal_semaphore_list_t signal_semaphore_list,
+    uint32_t packet_index, iree_hsa_fence_scope_t minimum_acquire_scope,
+    iree_hal_amdgpu_host_queue_command_buffer_packet_flags_t packet_flags) {
+  const bool has_barrier =
+      iree_hal_amdgpu_host_queue_command_buffer_packet_has_barrier(
+          resolution, packet_index, packet_flags);
+  const iree_hsa_fence_scope_t execution_acquire_scope =
+      iree_hal_amdgpu_host_queue_command_buffer_packet_flags_acquire_scope(
+          packet_flags);
+  const iree_hsa_fence_scope_t execution_release_scope =
+      iree_hal_amdgpu_host_queue_command_buffer_packet_flags_release_scope(
+          packet_flags);
+  const iree_hsa_fence_scope_t acquire_scope =
+      packet_index == 0
+          ? iree_hal_amdgpu_host_queue_max_fence_scope(
+                execution_acquire_scope, resolution->inline_acquire_scope)
+          : execution_acquire_scope;
+  const iree_hsa_fence_scope_t effective_acquire_scope =
+      iree_hal_amdgpu_host_queue_max_fence_scope(acquire_scope,
+                                                 minimum_acquire_scope);
+  iree_hsa_fence_scope_t release_scope = execution_release_scope;
+  if (iree_any_bit_set(
+          packet_flags,
+          IREE_HAL_AMDGPU_HOST_QUEUE_COMMAND_BUFFER_PACKET_FLAG_FINAL)) {
+    release_scope = iree_hal_amdgpu_host_queue_max_fence_scope(
+        release_scope, iree_hal_amdgpu_host_queue_signal_list_release_scope(
+                           queue, signal_semaphore_list));
+  }
+  return iree_hal_amdgpu_aql_packet_control(
+      has_barrier, effective_acquire_scope, release_scope);
+}
diff --git a/runtime/src/iree/hal/drivers/amdgpu/host_queue_command_buffer_packet.h b/runtime/src/iree/hal/drivers/amdgpu/host_queue_command_buffer_packet.h
new file mode 100644
index 0000000..54031b2
--- /dev/null
+++ b/runtime/src/iree/hal/drivers/amdgpu/host_queue_command_buffer_packet.h
@@ -0,0 +1,97 @@
+// Copyright 2026 The IREE Authors
+//
+// Licensed under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+
+#ifndef IREE_HAL_DRIVERS_AMDGPU_HOST_QUEUE_COMMAND_BUFFER_PACKET_H_
+#define IREE_HAL_DRIVERS_AMDGPU_HOST_QUEUE_COMMAND_BUFFER_PACKET_H_
+
+#include "iree/base/api.h"
+#include "iree/hal/api.h"
+#include "iree/hal/drivers/amdgpu/host_queue.h"
+#include "iree/hal/drivers/amdgpu/host_queue_submission.h"
+#include "iree/hal/drivers/amdgpu/util/aql_emitter.h"
+
+#ifdef __cplusplus
+extern "C" {
+#endif  // __cplusplus
+
+typedef uint32_t iree_hal_amdgpu_host_queue_command_buffer_packet_flags_t;
+enum iree_hal_amdgpu_host_queue_command_buffer_packet_flag_bits_t {
+  IREE_HAL_AMDGPU_HOST_QUEUE_COMMAND_BUFFER_PACKET_FLAG_NONE = 0u,
+  // Packet must participate in the command-buffer execution dependency chain.
+  IREE_HAL_AMDGPU_HOST_QUEUE_COMMAND_BUFFER_PACKET_FLAG_EXECUTION_BARRIER =
+      1u << 0,
+  // Packet owns queue completion and releases user-visible signal semaphores.
+  IREE_HAL_AMDGPU_HOST_QUEUE_COMMAND_BUFFER_PACKET_FLAG_FINAL = 1u << 1,
+  // First bit of the two-bit acquire fence scope field in packet flags.
+  IREE_HAL_AMDGPU_HOST_QUEUE_COMMAND_BUFFER_PACKET_ACQUIRE_SCOPE_SHIFT = 2,
+  // Bit mask of the two-bit acquire fence scope field in packet flags.
+  IREE_HAL_AMDGPU_HOST_QUEUE_COMMAND_BUFFER_PACKET_ACQUIRE_SCOPE_MASK = 0x0Cu,
+  // First bit of the two-bit release fence scope field in packet flags.
+  IREE_HAL_AMDGPU_HOST_QUEUE_COMMAND_BUFFER_PACKET_RELEASE_SCOPE_SHIFT = 4,
+  // Bit mask of the two-bit release fence scope field in packet flags.
+  IREE_HAL_AMDGPU_HOST_QUEUE_COMMAND_BUFFER_PACKET_RELEASE_SCOPE_MASK = 0x30u,
+};
+
+// Returns |flags| with its encoded fence scope fields replaced.
+static inline iree_hal_amdgpu_host_queue_command_buffer_packet_flags_t
+iree_hal_amdgpu_host_queue_command_buffer_packet_flags_set_fence_scopes(
+    iree_hal_amdgpu_host_queue_command_buffer_packet_flags_t flags,
+    iree_hsa_fence_scope_t acquire_scope,
+    iree_hsa_fence_scope_t release_scope) {
+  flags &=
+      ~(IREE_HAL_AMDGPU_HOST_QUEUE_COMMAND_BUFFER_PACKET_ACQUIRE_SCOPE_MASK |
+        IREE_HAL_AMDGPU_HOST_QUEUE_COMMAND_BUFFER_PACKET_RELEASE_SCOPE_MASK);
+  flags |=
+      ((uint32_t)acquire_scope & 0x3u)
+      << IREE_HAL_AMDGPU_HOST_QUEUE_COMMAND_BUFFER_PACKET_ACQUIRE_SCOPE_SHIFT;
+  flags |=
+      ((uint32_t)release_scope & 0x3u)
+      << IREE_HAL_AMDGPU_HOST_QUEUE_COMMAND_BUFFER_PACKET_RELEASE_SCOPE_SHIFT;
+  return flags;
+}
+
+// Returns one encoded fence scope field from packet flags.
+static inline iree_hsa_fence_scope_t
+iree_hal_amdgpu_host_queue_command_buffer_packet_flags_fence_scope(
+    iree_hal_amdgpu_host_queue_command_buffer_packet_flags_t flags,
+    uint32_t mask, uint32_t shift) {
+  return (iree_hsa_fence_scope_t)((flags & mask) >> shift);
+}
+
+// Returns the acquire fence scope encoded in packet flags.
+static inline iree_hsa_fence_scope_t
+iree_hal_amdgpu_host_queue_command_buffer_packet_flags_acquire_scope(
+    iree_hal_amdgpu_host_queue_command_buffer_packet_flags_t flags) {
+  return iree_hal_amdgpu_host_queue_command_buffer_packet_flags_fence_scope(
+      flags,
+      IREE_HAL_AMDGPU_HOST_QUEUE_COMMAND_BUFFER_PACKET_ACQUIRE_SCOPE_MASK,
+      IREE_HAL_AMDGPU_HOST_QUEUE_COMMAND_BUFFER_PACKET_ACQUIRE_SCOPE_SHIFT);
+}
+
+// Returns the release fence scope encoded in packet flags.
+static inline iree_hsa_fence_scope_t
+iree_hal_amdgpu_host_queue_command_buffer_packet_flags_release_scope(
+    iree_hal_amdgpu_host_queue_command_buffer_packet_flags_t flags) {
+  return iree_hal_amdgpu_host_queue_command_buffer_packet_flags_fence_scope(
+      flags,
+      IREE_HAL_AMDGPU_HOST_QUEUE_COMMAND_BUFFER_PACKET_RELEASE_SCOPE_MASK,
+      IREE_HAL_AMDGPU_HOST_QUEUE_COMMAND_BUFFER_PACKET_RELEASE_SCOPE_SHIFT);
+}
+
+// Computes AQL packet control for one replayed command-buffer packet.
+iree_hal_amdgpu_aql_packet_control_t
+iree_hal_amdgpu_host_queue_command_buffer_packet_control(
+    iree_hal_amdgpu_host_queue_t* queue,
+    const iree_hal_amdgpu_wait_resolution_t* resolution,
+    const iree_hal_semaphore_list_t signal_semaphore_list,
+    uint32_t packet_index, iree_hsa_fence_scope_t minimum_acquire_scope,
+    iree_hal_amdgpu_host_queue_command_buffer_packet_flags_t packet_flags);
+
+#ifdef __cplusplus
+}  // extern "C"
+#endif  // __cplusplus
+
+#endif  // IREE_HAL_DRIVERS_AMDGPU_HOST_QUEUE_COMMAND_BUFFER_PACKET_H_
diff --git a/runtime/src/iree/hal/drivers/amdgpu/host_queue_command_buffer_profile.c b/runtime/src/iree/hal/drivers/amdgpu/host_queue_command_buffer_profile.c
new file mode 100644
index 0000000..51e6c3c
--- /dev/null
+++ b/runtime/src/iree/hal/drivers/amdgpu/host_queue_command_buffer_profile.c
@@ -0,0 +1,229 @@
+// Copyright 2026 The IREE Authors
+//
+// Licensed under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+
+#include "iree/hal/drivers/amdgpu/host_queue_command_buffer_profile.h"
+
+#include <string.h>
+
+#include "iree/hal/drivers/amdgpu/logical_device.h"
+#include "iree/hal/drivers/amdgpu/profile_traces.h"
+
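+// Returns true when this queue records per-dispatch profile events.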
+static bool iree_hal_amdgpu_host_queue_profiles_command_buffer_dispatches(
+    const iree_hal_amdgpu_host_queue_t* queue) {
+  return queue->profiling.dispatch_profiling_enabled;
+}
+
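+// Returns true when the active capture filter matches every dispatch in the
+// command buffer at once so per-dispatch filter checks can be skipped.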
+static bool
+iree_hal_amdgpu_host_queue_should_profile_all_command_buffer_dispatches(
+    const iree_hal_amdgpu_host_queue_t* queue, uint64_t command_buffer_id) {
+  if (command_buffer_id == 0) return false;
+  if (!queue->profiling.hsa_queue_timestamps_enabled) return false;
+
+  iree_hal_amdgpu_logical_device_t* logical_device =
+      (iree_hal_amdgpu_logical_device_t*)queue->logical_device;
+  const iree_hal_profile_capture_filter_t* filter =
+      &logical_device->profiling.options.capture_filter;
+  if (iree_any_bit_set(
+          filter->flags,
+          IREE_HAL_PROFILE_CAPTURE_FILTER_FLAG_COMMAND_INDEX |
+              IREE_HAL_PROFILE_CAPTURE_FILTER_FLAG_EXECUTABLE_EXPORT_PATTERN)) {
+    return false;
+  }
+
+  const uint32_t physical_device_ordinal = queue->device_ordinal <= UINT32_MAX
+                                               ? (uint32_t)queue->device_ordinal
+                                               : UINT32_MAX;
+  const uint32_t queue_ordinal = iree_async_axis_queue_index(queue->axis);
+  return iree_hal_profile_capture_filter_matches_location(
+      filter, command_buffer_id, /*command_index=*/0, physical_device_ordinal,
+      queue_ordinal);
+}
+
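+// Returns true when the capture filter selects the single dispatch described
+// by |summary|.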
+static bool
+iree_hal_amdgpu_host_queue_should_profile_command_buffer_dispatch_summary(
+    const iree_hal_amdgpu_host_queue_t* queue, uint64_t command_buffer_id,
+    const iree_hal_amdgpu_aql_command_buffer_dispatch_summary_t* summary) {
+  if (command_buffer_id == 0) return false;
+  if (!queue->profiling.hsa_queue_timestamps_enabled) return false;
+  const uint32_t physical_device_ordinal = queue->device_ordinal <= UINT32_MAX
+                                               ? (uint32_t)queue->device_ordinal
+                                               : UINT32_MAX;
+  const uint32_t queue_ordinal = iree_async_axis_queue_index(queue->axis);
+  iree_hal_amdgpu_logical_device_t* logical_device =
+      (iree_hal_amdgpu_logical_device_t*)queue->logical_device;
+  return iree_hal_amdgpu_logical_device_should_profile_dispatch(
+      logical_device, summary->metadata.executable_id,
+      summary->metadata.export_ordinal, command_buffer_id,
+      summary->metadata.command_index, physical_device_ordinal, queue_ordinal);
+}
+
+iree_status_t
+iree_hal_amdgpu_host_queue_select_command_buffer_profile_dispatches(
+    const iree_hal_amdgpu_host_queue_t* queue,
+    iree_hal_command_buffer_t* command_buffer,
+    const iree_hal_amdgpu_command_buffer_block_header_t* block,
+    iree_arena_allocator_t* scratch_arena,
+    iree_hal_amdgpu_aql_block_processor_profile_dispatch_list_t*
+        out_dispatches) {
+  *out_dispatches =
+      (iree_hal_amdgpu_aql_block_processor_profile_dispatch_list_t){0};
+  const uint64_t command_buffer_id =
+      iree_hal_amdgpu_aql_command_buffer_profile_id(command_buffer);
+  if (command_buffer_id == 0) return iree_ok_status();
+  if (!queue->profiling.hsa_queue_timestamps_enabled) return iree_ok_status();
+  if (!iree_hal_amdgpu_host_queue_profiles_command_buffer_dispatches(queue)) {
+    return iree_ok_status();
+  }
+  const bool profile_all_dispatches =
+      iree_hal_amdgpu_host_queue_should_profile_all_command_buffer_dispatches(
+          queue, command_buffer_id);
+
+  uint32_t summary_count = 0;
+  const iree_hal_amdgpu_aql_command_buffer_dispatch_summary_t* summary =
+      iree_hal_amdgpu_aql_command_buffer_dispatch_summaries(
+          command_buffer, block, &summary_count);
+  if (IREE_UNLIKELY(summary_count != block->dispatch_count)) {
+    return iree_make_status(
+        IREE_STATUS_INTERNAL,
+        "retained dispatch summary count mismatch: expected %u but got %u",
+        block->dispatch_count, summary_count);
+  }
+  if (summary_count == 0) return iree_ok_status();
+
+  iree_host_size_t dispatch_storage_size = 0;
+  IREE_RETURN_IF_ERROR(IREE_STRUCT_LAYOUT(
+      0, &dispatch_storage_size,
+      IREE_STRUCT_FIELD(summary_count,
+                        iree_hal_amdgpu_aql_block_processor_profile_dispatch_t,
+                        NULL)));
+  iree_hal_amdgpu_aql_block_processor_profile_dispatch_t* dispatches = NULL;
+  IREE_RETURN_IF_ERROR(iree_arena_allocate(scratch_arena, dispatch_storage_size,
+                                           (void**)&dispatches));
+
+  uint32_t selected_count = 0;
+  for (uint32_t summary_ordinal = 0; summary_ordinal < summary_count;
+       ++summary_ordinal) {
+    if (IREE_UNLIKELY(!summary)) {
+      return iree_make_status(
+          IREE_STATUS_INTERNAL,
+          "retained dispatch summary list ended after %u of %u entries",
+          summary_ordinal, summary_count);
+    }
+    if (profile_all_dispatches ||
+        iree_hal_amdgpu_host_queue_should_profile_command_buffer_dispatch_summary(
+            queue, command_buffer_id, summary)) {
+      dispatches[selected_count++].summary = summary;
+    }
+    summary = summary->next;
+  }
+  out_dispatches->values = selected_count ? dispatches : NULL;
+  out_dispatches->count = selected_count;
+  return iree_ok_status();
+}
+
+static bool iree_hal_amdgpu_command_buffer_dispatch_summary_uses_indirect(
+    const iree_hal_amdgpu_aql_command_buffer_dispatch_summary_t* summary) {
+  return iree_any_bit_set(
+      summary->metadata.dispatch_flags,
+      IREE_HAL_AMDGPU_COMMAND_BUFFER_DISPATCH_FLAG_INDIRECT_PARAMETERS);
+}
+
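+// Resets one reserved dispatch event record and populates it from a retained
+// dispatch summary, preserving the queue-assigned event_id across the reset.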
+static void
+iree_hal_amdgpu_host_queue_initialize_command_buffer_dispatch_summary_event(
+    iree_hal_amdgpu_profile_dispatch_event_t* event, uint64_t command_buffer_id,
+    const iree_hal_amdgpu_aql_command_buffer_dispatch_summary_t* summary) {
+  const uint64_t event_id = event->event_id;
+  memset(event, 0, sizeof(*event));
+  event->record_length = sizeof(*event);
+  event->flags = IREE_HAL_AMDGPU_PROFILE_DISPATCH_EVENT_FLAG_COMMAND_BUFFER;
+  event->event_id = event_id;
+  event->command_buffer_id = command_buffer_id;
+  event->executable_id = summary->metadata.executable_id;
+  if (iree_hal_amdgpu_command_buffer_dispatch_summary_uses_indirect(summary)) {
+    event->flags |=
+        IREE_HAL_AMDGPU_PROFILE_DISPATCH_EVENT_FLAG_INDIRECT_PARAMETERS;
+  }
+  event->command_index = summary->metadata.command_index;
+  event->export_ordinal = summary->metadata.export_ordinal;
+  for (iree_host_size_t dimension_ordinal = 0;
+       dimension_ordinal < IREE_ARRAYSIZE(event->workgroup_size);
+       ++dimension_ordinal) {
+    event->workgroup_size[dimension_ordinal] =
+        summary->workgroup.size[dimension_ordinal];
+    if (!iree_hal_amdgpu_command_buffer_dispatch_summary_uses_indirect(
+            summary) &&
+        summary->workgroup.size[dimension_ordinal] != 0) {
+      event->workgroup_count[dimension_ordinal] =
+          summary->workgroup.count[dimension_ordinal];
+    }
+  }
+}
+
+void iree_hal_amdgpu_host_queue_record_command_buffer_profile_dispatch_source(
+    iree_hal_amdgpu_host_queue_t* queue, uint64_t command_buffer_id,
+    const iree_hal_amdgpu_aql_block_processor_profile_dispatch_t* dispatch,
+    iree_hal_amdgpu_profile_dispatch_event_reservation_t profile_events,
+    iree_hal_amdgpu_profile_dispatch_harvest_source_t* profile_harvest_sources,
+    uint32_t* inout_profile_event_index) {
+  const uint32_t profile_event_index = *inout_profile_event_index;
+  const uint64_t profile_event_position =
+      profile_events.first_event_position + profile_event_index;
+  iree_hal_amdgpu_profile_dispatch_event_t* event =
+      iree_hal_amdgpu_host_queue_profile_dispatch_event_at(
+          queue, profile_event_position);
+  iree_hal_amdgpu_host_queue_initialize_command_buffer_dispatch_summary_event(
+      event, command_buffer_id, dispatch->summary);
+  profile_harvest_sources[profile_event_index].completion_signal =
+      iree_hal_amdgpu_host_queue_profiling_completion_signal_ptr(
+          queue, profile_event_position);
+  profile_harvest_sources[profile_event_index].ticks =
+      iree_hal_amdgpu_profile_dispatch_event_ticks(event);
+  *inout_profile_event_index = profile_event_index + 1;
+}
+
+iree_status_t
+iree_hal_amdgpu_host_queue_prepare_command_buffer_profile_trace_code_objects(
+    iree_hal_amdgpu_host_queue_t* queue,
+    iree_hal_amdgpu_aql_block_processor_profile_dispatch_list_t dispatches,
+    iree_hal_amdgpu_profile_dispatch_event_reservation_t profile_events) {
+  if (profile_events.event_count == 0 || !queue->profiling.traces.session) {
+    return iree_ok_status();
+  }
+  if (IREE_UNLIKELY(dispatches.count != profile_events.event_count)) {
+    return iree_make_status(
+        IREE_STATUS_INTERNAL,
+        "profile command-buffer dispatch selection count %u does not match "
+        "reserved event count %u",
+        dispatches.count, profile_events.event_count);
+  }
+
+  iree_status_t status = iree_ok_status();
+  for (uint32_t dispatch_ordinal = 0;
+       dispatch_ordinal < dispatches.count && iree_status_is_ok(status);
+       ++dispatch_ordinal) {
+    const iree_hal_amdgpu_aql_block_processor_profile_dispatch_t* dispatch =
+        &dispatches.values[dispatch_ordinal];
+    if (IREE_UNLIKELY(!dispatch->summary)) {
+      status = iree_make_status(
+          IREE_STATUS_INTERNAL,
+          "profile command-buffer dispatch selection %u has no summary",
+          dispatch_ordinal);
+      break;
+    }
+    const uint64_t event_position =
+        profile_events.first_event_position + dispatch_ordinal;
+    status = iree_hal_amdgpu_host_queue_prepare_profile_trace_code_object(
+        queue, event_position, dispatch->summary->metadata.executable_id);
+  }
+  return status;
+}
diff --git a/runtime/src/iree/hal/drivers/amdgpu/host_queue_command_buffer_profile.h b/runtime/src/iree/hal/drivers/amdgpu/host_queue_command_buffer_profile.h
new file mode 100644
index 0000000..6c878dc
--- /dev/null
+++ b/runtime/src/iree/hal/drivers/amdgpu/host_queue_command_buffer_profile.h
@@ -0,0 +1,53 @@
+// Copyright 2026 The IREE Authors
+//
+// Licensed under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+
+#ifndef IREE_HAL_DRIVERS_AMDGPU_HOST_QUEUE_COMMAND_BUFFER_PROFILE_H_
+#define IREE_HAL_DRIVERS_AMDGPU_HOST_QUEUE_COMMAND_BUFFER_PROFILE_H_
+
+#include "iree/base/api.h"
+#include "iree/hal/api.h"
+#include "iree/hal/drivers/amdgpu/abi/command_buffer.h"
+#include "iree/hal/drivers/amdgpu/aql_block_processor_profile.h"
+#include "iree/hal/drivers/amdgpu/aql_command_buffer.h"
+#include "iree/hal/drivers/amdgpu/host_queue.h"
+#include "iree/hal/drivers/amdgpu/host_queue_profile_events.h"
+
+#ifdef __cplusplus
+extern "C" {
+#endif  // __cplusplus
+
+// Selects dispatch commands in |block| matched by the active capture filter.
+iree_status_t
+iree_hal_amdgpu_host_queue_select_command_buffer_profile_dispatches(
+    const iree_hal_amdgpu_host_queue_t* queue,
+    iree_hal_command_buffer_t* command_buffer,
+    const iree_hal_amdgpu_command_buffer_block_header_t* block,
+    iree_arena_allocator_t* scratch_arena,
+    iree_hal_amdgpu_aql_block_processor_profile_dispatch_list_t*
+        out_dispatches);
+
+// Records the queue-owned dispatch event and harvest source paired with one
+// command-buffer dispatch packet.
+void iree_hal_amdgpu_host_queue_record_command_buffer_profile_dispatch_source(
+    iree_hal_amdgpu_host_queue_t* queue, uint64_t command_buffer_id,
+    const iree_hal_amdgpu_aql_block_processor_profile_dispatch_t* dispatch,
+    iree_hal_amdgpu_profile_dispatch_event_reservation_t profile_events,
+    iree_hal_amdgpu_profile_dispatch_harvest_source_t* profile_harvest_sources,
+    uint32_t* inout_profile_event_index);
+
+// Prepares code-object side data needed by trace captures for the selected
+// dispatch events in |block|.
+iree_status_t
+iree_hal_amdgpu_host_queue_prepare_command_buffer_profile_trace_code_objects(
+    iree_hal_amdgpu_host_queue_t* queue,
+    iree_hal_amdgpu_aql_block_processor_profile_dispatch_list_t dispatches,
+    iree_hal_amdgpu_profile_dispatch_event_reservation_t profile_events);
+
+#ifdef __cplusplus
+}  // extern "C"
+#endif  // __cplusplus
+
+#endif  // IREE_HAL_DRIVERS_AMDGPU_HOST_QUEUE_COMMAND_BUFFER_PROFILE_H_
diff --git a/runtime/src/iree/hal/drivers/amdgpu/host_queue_command_buffer_replay.c b/runtime/src/iree/hal/drivers/amdgpu/host_queue_command_buffer_replay.c
new file mode 100644
index 0000000..5611ae4
--- /dev/null
+++ b/runtime/src/iree/hal/drivers/amdgpu/host_queue_command_buffer_replay.c
@@ -0,0 +1,422 @@
+// Copyright 2026 The IREE Authors
+//
+// Licensed under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+
+#include "iree/hal/drivers/amdgpu/host_queue_command_buffer_replay.h"
+
+#include <string.h>
+
+#include "iree/base/alignment.h"
+#include "iree/hal/drivers/amdgpu/aql_command_buffer.h"
+#include "iree/hal/drivers/amdgpu/aql_program_validation.h"
+#include "iree/hal/drivers/amdgpu/host_queue_command_buffer_block.h"
+#include "iree/hal/drivers/amdgpu/host_queue_command_buffer_packet.h"
+#include "iree/hal/drivers/amdgpu/host_queue_policy.h"
+#include "iree/hal/drivers/amdgpu/host_queue_profile.h"
+#include "iree/hal/drivers/amdgpu/host_queue_profile_events.h"
+#include "iree/hal/drivers/amdgpu/host_queue_timestamp.h"
+#include "iree/hal/drivers/amdgpu/util/aql_emitter.h"
+#include "iree/hal/utils/resource_set.h"
+
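+// Heap-allocated continuation tracking replay of one recorded command buffer
+// across multiple block submissions and post-drain resumptions.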
+typedef struct iree_hal_amdgpu_command_buffer_replay_t {
+  // Reference-counted replay continuation resource.
+  iree_hal_resource_t resource;
+  // Host queue borrowed for the replay lifetime.
+  iree_hal_amdgpu_host_queue_t* queue;
+  // Host allocator used for this continuation allocation.
+  iree_allocator_t host_allocator;
+  // Command buffer retained for the replay lifetime.
+  iree_hal_command_buffer_t* command_buffer;
+  // Final user-visible signal semaphore list retained for the replay lifetime.
+  iree_hal_semaphore_list_t signal_semaphore_list;
+  // Binding table snapshot used after queue_execute returns.
+  iree_hal_buffer_binding_table_t binding_table;
+  // Resolved binding table base pointers indexed by original binding slot.
+  const uint64_t* binding_ptrs;
+  // Resource set retaining binding-table buffers for the replay lifetime.
+  iree_hal_resource_set_t* binding_resource_set;
+  // Immutable recorded AQL program borrowed from |command_buffer|.
+  const iree_hal_amdgpu_aql_program_t* program;
+  // Next command-buffer block to evaluate or submit.
+  const iree_hal_amdgpu_command_buffer_block_header_t* current_block;
+  // Wait resolution that prefixes the next packet submission.
+  iree_hal_amdgpu_wait_resolution_t wait_resolution;
+  // Intrusive continuation used to retry replay after notification drain.
+  iree_hal_amdgpu_host_queue_post_drain_action_t post_drain_action;
+} iree_hal_amdgpu_command_buffer_replay_t;
+
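+// Clears the pending wait resolution once it has been attached to a
+// submission so subsequent blocks in the replay do not re-wait on it.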
+static void iree_hal_amdgpu_command_buffer_replay_consume_wait_resolution(
+    iree_hal_amdgpu_command_buffer_replay_t* replay) {
+  memset(&replay->wait_resolution, 0, sizeof(replay->wait_resolution));
+}
+
+static void iree_hal_amdgpu_command_buffer_replay_destroy(
+    iree_hal_resource_t* resource) {
+  iree_hal_amdgpu_command_buffer_replay_t* replay =
+      (iree_hal_amdgpu_command_buffer_replay_t*)resource;
+  IREE_TRACE_ZONE_BEGIN(z0);
+  iree_hal_resource_set_free(replay->binding_resource_set);
+  iree_hal_semaphore_list_release(replay->signal_semaphore_list);
+  iree_hal_command_buffer_release(replay->command_buffer);
+  iree_allocator_free(replay->host_allocator, replay);
+  IREE_TRACE_ZONE_END(z0);
+}
+
+static const iree_hal_resource_vtable_t
+    iree_hal_amdgpu_command_buffer_replay_vtable = {
+        .destroy = iree_hal_amdgpu_command_buffer_replay_destroy,
+};
+
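+// Allocates a replay continuation that takes ownership of the binding
+// resource set and snapshots the signal semaphore list, binding table, and
+// resolved binding pointers so replay can outlive the queue_execute call.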
+static iree_status_t iree_hal_amdgpu_command_buffer_replay_create(
+    iree_hal_amdgpu_host_queue_t* queue,
+    const iree_hal_amdgpu_wait_resolution_t* resolution,
+    const iree_hal_semaphore_list_t signal_semaphore_list,
+    iree_hal_command_buffer_t* command_buffer,
+    iree_hal_buffer_binding_table_t binding_table,
+    iree_hal_execute_flags_t execute_flags,
+    iree_hal_resource_set_t** inout_binding_resource_set,
+    iree_hal_amdgpu_command_buffer_replay_t** out_replay) {
+  *out_replay = NULL;
+
+  if (!*inout_binding_resource_set) {
+    IREE_RETURN_IF_ERROR(
+        iree_hal_amdgpu_host_queue_create_binding_table_resource_set(
+            queue, command_buffer, binding_table, execute_flags,
+            inout_binding_resource_set));
+  }
+
+  const iree_host_size_t signal_count = signal_semaphore_list.count;
+  const iree_host_size_t binding_count = command_buffer->binding_count;
+
+  iree_host_size_t semaphore_offset = 0;
+  iree_host_size_t payload_offset = 0;
+  iree_host_size_t binding_offset = 0;
+  iree_host_size_t binding_ptr_offset = 0;
+  iree_host_size_t total_size = 0;
+  IREE_RETURN_IF_ERROR(IREE_STRUCT_LAYOUT(
+      sizeof(iree_hal_amdgpu_command_buffer_replay_t), &total_size,
+      IREE_STRUCT_FIELD_ALIGNED(signal_count, iree_hal_semaphore_t*, 1,
+                                &semaphore_offset),
+      IREE_STRUCT_FIELD_ALIGNED(signal_count, uint64_t, 1, &payload_offset),
+      IREE_STRUCT_FIELD_ALIGNED(binding_count, iree_hal_buffer_binding_t, 1,
+                                &binding_offset),
+      IREE_STRUCT_FIELD_ALIGNED(binding_count, uint64_t, 1,
+                                &binding_ptr_offset)));
+
+  iree_hal_amdgpu_command_buffer_replay_t* replay = NULL;
+  IREE_RETURN_IF_ERROR(iree_allocator_malloc(queue->host_allocator, total_size,
+                                             (void**)&replay));
+  memset(replay, 0, total_size);
+  iree_hal_resource_initialize(&iree_hal_amdgpu_command_buffer_replay_vtable,
+                               &replay->resource);
+  replay->queue = queue;
+  replay->host_allocator = queue->host_allocator;
+  replay->command_buffer = command_buffer;
+  replay->binding_resource_set = *inout_binding_resource_set;
+  *inout_binding_resource_set = NULL;
+  replay->program = iree_hal_amdgpu_aql_command_buffer_program(command_buffer);
+  replay->current_block = replay->program->first_block;
+  replay->wait_resolution = *resolution;
+  iree_hal_command_buffer_retain(command_buffer);
+
+  uint8_t* storage = (uint8_t*)replay;
+  if (signal_count > 0) {
+    replay->signal_semaphore_list.count = signal_count;
+    replay->signal_semaphore_list.semaphores =
+        (iree_hal_semaphore_t**)(storage + semaphore_offset);
+    replay->signal_semaphore_list.payload_values =
+        (uint64_t*)(storage + payload_offset);
+    memcpy(replay->signal_semaphore_list.semaphores,
+           signal_semaphore_list.semaphores,
+           signal_count * sizeof(*signal_semaphore_list.semaphores));
+    memcpy(replay->signal_semaphore_list.payload_values,
+           signal_semaphore_list.payload_values,
+           signal_count * sizeof(*signal_semaphore_list.payload_values));
+  }
+  iree_hal_semaphore_list_retain(replay->signal_semaphore_list);
+  if (binding_count > 0) {
+    iree_hal_buffer_binding_t* binding_storage =
+        (iree_hal_buffer_binding_t*)(storage + binding_offset);
+    replay->binding_table.count = binding_count;
+    replay->binding_table.bindings = binding_storage;
+    memcpy(binding_storage, binding_table.bindings,
+           binding_count * sizeof(*binding_table.bindings));
+    uint64_t* binding_ptrs = (uint64_t*)(storage + binding_ptr_offset);
+    iree_status_t status =
+        iree_hal_amdgpu_host_queue_resolve_command_buffer_binding_ptrs(
+            command_buffer, replay->binding_table, binding_ptrs);
+    if (!iree_status_is_ok(status)) {
+      iree_hal_resource_release(&replay->resource);
+      return status;
+    }
+    replay->binding_ptrs = binding_ptrs;
+  }
+
+  *out_replay = replay;
+  return iree_ok_status();
+}
+
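+// Returns a clone of the queue's sticky error status, or CANCELLED if the
+// queue is shutting down, so the replay can fail its signal semaphores without
+// consuming the shared queue status.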
+static iree_status_t iree_hal_amdgpu_command_buffer_replay_clone_queue_error(
+    iree_hal_amdgpu_command_buffer_replay_t* replay) {
+  if (IREE_UNLIKELY(replay->queue->is_shutting_down)) {
+    return iree_make_status(IREE_STATUS_CANCELLED, "queue shutting down");
+  }
+  iree_status_t error = (iree_status_t)iree_atomic_load(
+      &replay->queue->error_status, iree_memory_order_acquire);
+  return iree_status_is_ok(error) ? iree_ok_status() : iree_status_clone(error);
+}
+
+static void iree_hal_amdgpu_command_buffer_replay_fail_signals(
+    iree_hal_amdgpu_command_buffer_replay_t* replay, iree_status_t status) {
+  if (iree_status_is_ok(status)) return;
+  if (iree_hal_semaphore_list_is_empty(replay->signal_semaphore_list)) {
+    iree_status_free(status);
+    return;
+  }
+  iree_hal_semaphore_list_fail(replay->signal_semaphore_list, status);
+}
+
+static iree_status_t iree_hal_amdgpu_command_buffer_replay_park(
+    iree_hal_amdgpu_command_buffer_replay_t* replay,
+    iree_hal_amdgpu_host_queue_post_drain_fn_t post_drain_fn) {
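+  // Keep the replay alive across the drain; the matching release happens in
+  // the post-drain callback after resumption completes.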
+  iree_hal_resource_retain(&replay->resource);
+  iree_hal_amdgpu_host_queue_enqueue_post_drain_action(
+      replay->queue, &replay->post_drain_action, post_drain_fn, replay);
+  return iree_ok_status();
+}
+
+static iree_status_t
+iree_hal_amdgpu_command_buffer_replay_submit_completion_packet(
+    iree_hal_amdgpu_command_buffer_replay_t* replay,
+    const iree_hal_amdgpu_wait_resolution_t* resolution, bool* out_ready) {
+  *out_ready = false;
+
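+  // Reserve the optional device-side profile event up front so a failed or
+  // deferred submission attempt can cancel the reservation without leaking it.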
+  iree_hal_amdgpu_host_queue_profile_event_info_t profile_event_info = {
+      .type = IREE_HAL_PROFILE_QUEUE_EVENT_TYPE_EXECUTE,
+      .command_buffer_id =
+          iree_hal_amdgpu_aql_command_buffer_profile_id(replay->command_buffer),
+      .operation_count = 0,
+  };
+  iree_hal_amdgpu_profile_queue_device_event_reservation_t
+      profile_queue_device_events = {0};
+  if (iree_hal_amdgpu_host_queue_should_profile_queue_device_events(
+          replay->queue)) {
+    IREE_RETURN_IF_ERROR(
+        iree_hal_amdgpu_host_queue_reserve_profile_queue_device_events(
+            replay->queue, /*event_count=*/1, &profile_queue_device_events));
+  }
+
+  iree_hal_amdgpu_host_queue_kernel_submission_t submission;
+  iree_status_t status = iree_hal_amdgpu_host_queue_try_begin_kernel_submission(
+      replay->queue, resolution, replay->signal_semaphore_list,
+      /*operation_resource_count=*/1, /*payload_packet_count=*/1,
+      /*kernarg_block_count=*/0, out_ready, &submission);
+  if (!iree_status_is_ok(status) || !*out_ready) {
+    iree_hal_amdgpu_host_queue_cancel_profile_queue_device_events(
+        replay->queue, profile_queue_device_events);
+  }
+  if (iree_status_is_ok(status) && *out_ready) {
+    iree_hal_amdgpu_host_queue_emit_kernel_submission_prefix(
+        replay->queue, resolution, &submission);
+    iree_hal_resource_t* replay_resource = &replay->resource;
+    const uint64_t submission_id =
+        iree_hal_amdgpu_host_queue_finish_kernel_submission(
+            replay->queue, resolution, replay->signal_semaphore_list,
+            &replay_resource, /*operation_resource_count=*/1,
+            /*inout_resource_set=*/NULL,
+            IREE_HAL_AMDGPU_HOST_QUEUE_SUBMISSION_FLAG_RETAIN_RESOURCES,
+            &submission);
+
+    iree_hal_amdgpu_profile_queue_device_event_t* queue_device_event =
+        iree_hal_amdgpu_host_queue_initialize_profile_queue_device_event(
+            replay->queue, profile_queue_device_events, &profile_event_info);
+    if (queue_device_event) {
+      submission.reclaim_entry->queue_device_event_first_position =
+          profile_queue_device_events.first_event_position;
+      submission.reclaim_entry->queue_device_event_count =
+          profile_queue_device_events.event_count;
+      queue_device_event->submission_id = submission_id;
+      const uint64_t timestamp_packet_id =
+          submission.first_packet_id + submission.packet_count - 1;
+      iree_hal_amdgpu_host_queue_commit_timestamp_range(
+          replay->queue, timestamp_packet_id,
+          iree_hal_amdgpu_host_queue_command_buffer_packet_control(
+              replay->queue, resolution, replay->signal_semaphore_list,
+              /*packet_index=*/0, IREE_HSA_FENCE_SCOPE_NONE,
+              IREE_HAL_AMDGPU_HOST_QUEUE_COMMAND_BUFFER_PACKET_FLAG_FINAL),
+          iree_hal_amdgpu_notification_ring_epoch_signal(
+              &replay->queue->notification_ring),
+          &queue_device_event->start_tick, &queue_device_event->end_tick);
+    } else {
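+      // No profiling consumer for this submission: emit a plain NOP barrier
+      // that still carries the acquire/release scopes and the notification
+      // epoch signal used for completion tracking.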
+      iree_hal_amdgpu_aql_packet_t* packet = iree_hal_amdgpu_aql_ring_packet(
+          &replay->queue->aql_ring,
+          submission.first_packet_id + submission.packet_count - 1);
+      const uint16_t header = iree_hal_amdgpu_aql_emit_nop(
+          &packet->barrier_and,
+          iree_hal_amdgpu_aql_packet_control_barrier(
+              resolution->inline_acquire_scope,
+              iree_hal_amdgpu_host_queue_signal_list_release_scope(
+                  replay->queue, replay->signal_semaphore_list)),
+          iree_hal_amdgpu_notification_ring_epoch_signal(
+              &replay->queue->notification_ring));
+      iree_hal_amdgpu_aql_ring_commit(packet, header, /*setup=*/0);
+    }
+    iree_hal_amdgpu_aql_ring_doorbell(
+        &replay->queue->aql_ring,
+        submission.first_packet_id + submission.packet_count - 1);
+    profile_event_info.submission_id = submission_id;
+    iree_hal_amdgpu_host_queue_record_profile_queue_event(
+        replay->queue, resolution, replay->signal_semaphore_list,
+        &profile_event_info);
+    memset(&submission, 0, sizeof(submission));
+    replay->current_block = NULL;
+    iree_hal_amdgpu_command_buffer_replay_consume_wait_resolution(replay);
+  }
+  return status;
+}
+
+static iree_status_t iree_hal_amdgpu_command_buffer_replay_resume_under_lock(
+    iree_hal_amdgpu_command_buffer_replay_t* replay,
+    iree_hal_amdgpu_host_queue_post_drain_fn_t post_drain_fn) {
+  IREE_RETURN_IF_ERROR(
+      iree_hal_amdgpu_command_buffer_replay_clone_queue_error(replay));
+
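+  // Walk the block program linearly: blocks with no AQL packets fall through
+  // to their terminator target, packet-bearing blocks are submitted, and when
+  // the queue cannot accept a submission the replay parks and resumes from the
+  // post-drain callback at the block it stopped on.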
+  iree_status_t status = iree_ok_status();
+  while (iree_status_is_ok(status) && replay->current_block) {
+    status = iree_hal_amdgpu_aql_program_validate_block_terminator(
+        replay->current_block);
+    if (!iree_status_is_ok(status)) break;
+    const uint8_t terminator_opcode = replay->current_block->terminator_opcode;
+
+    if (replay->current_block->aql_packet_count == 0) {
+      if (terminator_opcode == IREE_HAL_AMDGPU_COMMAND_BUFFER_OPCODE_RETURN) {
+        const iree_hal_amdgpu_wait_resolution_t* current_resolution =
+            &replay->wait_resolution;
+        bool ready = false;
+        status = iree_hal_amdgpu_command_buffer_replay_submit_completion_packet(
+            replay, current_resolution, &ready);
+        if (iree_status_is_ok(status) && !ready) {
+          status =
+              iree_hal_amdgpu_command_buffer_replay_park(replay, post_drain_fn);
+        }
+        break;
+      }
+
+      const iree_hal_amdgpu_command_buffer_block_header_t* next_block = NULL;
+      status = iree_hal_amdgpu_aql_program_next_linear_block(
+          replay->program, replay->current_block,
+          replay->current_block->terminator_target_block_ordinal, &next_block);
+      if (iree_status_is_ok(status)) {
+        replay->current_block = next_block;
+      }
+      continue;
+    }
+
+    const iree_hal_amdgpu_wait_resolution_t* current_resolution =
+        &replay->wait_resolution;
+
+    if (terminator_opcode == IREE_HAL_AMDGPU_COMMAND_BUFFER_OPCODE_RETURN) {
+      iree_hal_resource_t* replay_resource = &replay->resource;
+      bool ready = false;
+      status = iree_hal_amdgpu_host_queue_submit_command_buffer_block(
+          replay->queue, current_resolution, replay->signal_semaphore_list,
+          replay->command_buffer, replay->binding_table, replay->binding_ptrs,
+          replay->current_block, /*inout_binding_resource_set=*/NULL,
+          (iree_hal_amdgpu_reclaim_action_t){0}, &replay_resource,
+          /*operation_resource_count=*/1,
+          IREE_HAL_AMDGPU_HOST_QUEUE_SUBMISSION_FLAG_RETAIN_RESOURCES, &ready);
+      if (iree_status_is_ok(status) && ready) {
+        replay->current_block = NULL;
+        iree_hal_amdgpu_command_buffer_replay_consume_wait_resolution(replay);
+      } else if (iree_status_is_ok(status)) {
+        status =
+            iree_hal_amdgpu_command_buffer_replay_park(replay, post_drain_fn);
+      }
+      break;
+    }
+
+    const iree_hal_amdgpu_command_buffer_block_header_t* next_block = NULL;
+    status = iree_hal_amdgpu_aql_program_next_linear_block(
+        replay->program, replay->current_block,
+        replay->current_block->terminator_target_block_ordinal, &next_block);
+    if (!iree_status_is_ok(status)) break;
+
+    iree_hal_resource_t* replay_resource = &replay->resource;
+    bool ready = false;
+    status = iree_hal_amdgpu_host_queue_submit_command_buffer_block(
+        replay->queue, current_resolution, iree_hal_semaphore_list_empty(),
+        replay->command_buffer, replay->binding_table, replay->binding_ptrs,
+        replay->current_block, /*inout_binding_resource_set=*/NULL,
+        (iree_hal_amdgpu_reclaim_action_t){0}, &replay_resource,
+        /*operation_resource_count=*/1,
+        IREE_HAL_AMDGPU_HOST_QUEUE_SUBMISSION_FLAG_RETAIN_RESOURCES, &ready);
+    if (iree_status_is_ok(status) && ready) {
+      replay->current_block = next_block;
+      iree_hal_amdgpu_command_buffer_replay_consume_wait_resolution(replay);
+      continue;
+    } else if (iree_status_is_ok(status)) {
+      status =
+          iree_hal_amdgpu_command_buffer_replay_park(replay, post_drain_fn);
+    }
+    break;
+  }
+  return status;
+}
+
+static void iree_hal_amdgpu_command_buffer_replay_post_drain(void* user_data) {
+  iree_hal_amdgpu_command_buffer_replay_t* replay =
+      (iree_hal_amdgpu_command_buffer_replay_t*)user_data;
+  iree_slim_mutex_lock(&replay->queue->locks.submission_mutex);
+  iree_status_t status =
+      iree_hal_amdgpu_command_buffer_replay_resume_under_lock(
+          replay, iree_hal_amdgpu_command_buffer_replay_post_drain);
+  iree_slim_mutex_unlock(&replay->queue->locks.submission_mutex);
+  if (!iree_status_is_ok(status)) {
+    iree_hal_amdgpu_command_buffer_replay_fail_signals(replay, status);
+  }
+  iree_hal_resource_release(&replay->resource);
+}
+
+iree_status_t iree_hal_amdgpu_command_buffer_replay_start_under_lock(
+    iree_hal_amdgpu_host_queue_t* queue,
+    const iree_hal_amdgpu_wait_resolution_t* resolution,
+    const iree_hal_semaphore_list_t signal_semaphore_list,
+    iree_hal_command_buffer_t* command_buffer,
+    iree_hal_buffer_binding_table_t binding_table,
+    iree_hal_execute_flags_t execute_flags,
+    iree_hal_resource_set_t** inout_binding_resource_set) {
+  iree_hal_amdgpu_command_buffer_replay_t* replay = NULL;
+  iree_status_t status = iree_hal_amdgpu_command_buffer_replay_create(
+      queue, resolution, signal_semaphore_list, command_buffer, binding_table,
+      execute_flags, inout_binding_resource_set, &replay);
+  if (iree_status_is_ok(status)) {
+    status = iree_hal_amdgpu_command_buffer_replay_resume_under_lock(
+        replay, iree_hal_amdgpu_command_buffer_replay_post_drain);
+    iree_hal_resource_release(&replay->resource);
+  } else {
+    iree_hal_resource_set_free(*inout_binding_resource_set);
+    *inout_binding_resource_set = NULL;
+  }
+  return status;
+}
diff --git a/runtime/src/iree/hal/drivers/amdgpu/host_queue_command_buffer_replay.h b/runtime/src/iree/hal/drivers/amdgpu/host_queue_command_buffer_replay.h
new file mode 100644
index 0000000..7076272
--- /dev/null
+++ b/runtime/src/iree/hal/drivers/amdgpu/host_queue_command_buffer_replay.h
@@ -0,0 +1,31 @@
+// Copyright 2026 The IREE Authors
+//
+// Licensed under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+
+#ifndef IREE_HAL_DRIVERS_AMDGPU_HOST_QUEUE_COMMAND_BUFFER_REPLAY_H_
+#define IREE_HAL_DRIVERS_AMDGPU_HOST_QUEUE_COMMAND_BUFFER_REPLAY_H_
+
+#include "iree/hal/drivers/amdgpu/host_queue_command_buffer.h"
+
+#ifdef __cplusplus
+extern "C" {
+#endif  // __cplusplus
+
+// Starts multi-block command-buffer replay of |command_buffer| on |queue|.
+// Ownership of |*inout_binding_resource_set| transfers to the replay (or is
+// freed on failure) and the pointer is cleared. Caller must hold
+// |queue->locks.submission_mutex|.
+iree_status_t iree_hal_amdgpu_command_buffer_replay_start_under_lock(
+    iree_hal_amdgpu_host_queue_t* queue,
+    const iree_hal_amdgpu_wait_resolution_t* resolution,
+    const iree_hal_semaphore_list_t signal_semaphore_list,
+    iree_hal_command_buffer_t* command_buffer,
+    iree_hal_buffer_binding_table_t binding_table,
+    iree_hal_execute_flags_t execute_flags,
+    iree_hal_resource_set_t** inout_binding_resource_set);
+
+#ifdef __cplusplus
+}  // extern "C"
+#endif  // __cplusplus
+
+#endif  // IREE_HAL_DRIVERS_AMDGPU_HOST_QUEUE_COMMAND_BUFFER_REPLAY_H_
diff --git a/runtime/src/iree/hal/drivers/amdgpu/host_queue_command_buffer_scratch.h b/runtime/src/iree/hal/drivers/amdgpu/host_queue_command_buffer_scratch.h
new file mode 100644
index 0000000..f3c950f
--- /dev/null
+++ b/runtime/src/iree/hal/drivers/amdgpu/host_queue_command_buffer_scratch.h
@@ -0,0 +1,41 @@
+// Copyright 2026 The IREE Authors
+//
+// Licensed under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+
+#ifndef IREE_HAL_DRIVERS_AMDGPU_HOST_QUEUE_COMMAND_BUFFER_SCRATCH_H_
+#define IREE_HAL_DRIVERS_AMDGPU_HOST_QUEUE_COMMAND_BUFFER_SCRATCH_H_
+
+#include "iree/base/api.h"
+
+// Maximum number of queue_execute binding table entries cached as raw device
+// pointers under the submission mutex while replaying an AQL command buffer.
+// Larger binding tables use temporary arena storage for the current
+// submission.
+#define IREE_HAL_AMDGPU_HOST_QUEUE_COMMAND_BUFFER_BINDING_SCRATCH_CAPACITY 4096u
+
+// Maximum number of packet metadata entries cached under the submission mutex
+// while replaying an AQL command buffer. Larger packet-bearing blocks use a
+// temporary arena block for the current submission.
+#define IREE_HAL_AMDGPU_HOST_QUEUE_COMMAND_BUFFER_PACKET_SCRATCH_CAPACITY 512u
+
+// Lazily allocated host queue scratch used only by queue_execute.
+typedef struct iree_hal_amdgpu_host_queue_command_buffer_scratch_t {
+  // Resolved queue_execute binding-table device pointers.
+  struct {
+    // Raw device pointers indexed by queue_execute binding table slot.
+    uint64_t ptrs
+        [IREE_HAL_AMDGPU_HOST_QUEUE_COMMAND_BUFFER_BINDING_SCRATCH_CAPACITY];
+  } bindings;
+  // Packet sidebands populated by block processors before AQL publication.
+  struct {
+    // AQL packet header words indexed by emitted packet ordinal.
+    uint16_t headers
+        [IREE_HAL_AMDGPU_HOST_QUEUE_COMMAND_BUFFER_PACKET_SCRATCH_CAPACITY];
+    // AQL packet setup words indexed by emitted packet ordinal.
+    uint16_t setups
+        [IREE_HAL_AMDGPU_HOST_QUEUE_COMMAND_BUFFER_PACKET_SCRATCH_CAPACITY];
+  } packets;
+} iree_hal_amdgpu_host_queue_command_buffer_scratch_t;
+
+#endif  // IREE_HAL_DRIVERS_AMDGPU_HOST_QUEUE_COMMAND_BUFFER_SCRATCH_H_
diff --git a/runtime/src/iree/hal/drivers/amdgpu/host_queue_command_buffer_test.cc b/runtime/src/iree/hal/drivers/amdgpu/host_queue_command_buffer_test.cc
new file mode 100644
index 0000000..546b1ed
--- /dev/null
+++ b/runtime/src/iree/hal/drivers/amdgpu/host_queue_command_buffer_test.cc
@@ -0,0 +1,3867 @@
+// Copyright 2026 The IREE Authors
+//
+// Licensed under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+
+#include "iree/hal/drivers/amdgpu/host_queue_command_buffer.h"
+
+#include <algorithm>
+#include <cstdint>
+#include <string>
+#include <vector>
+
+#include "iree/hal/api.h"
+#include "iree/hal/cts/util/test_base.h"
+#include "iree/hal/drivers/amdgpu/aql_command_buffer.h"
+#include "iree/hal/drivers/amdgpu/executable.h"
+#include "iree/hal/drivers/amdgpu/host_queue.h"
+#include "iree/hal/drivers/amdgpu/host_queue_command_buffer_packet.h"
+#include "iree/hal/drivers/amdgpu/host_queue_profile_events.h"
+#include "iree/hal/drivers/amdgpu/logical_device.h"
+#include "iree/hal/drivers/amdgpu/physical_device.h"
+#include "iree/hal/drivers/amdgpu/queue_affinity.h"
+#include "iree/hal/drivers/amdgpu/util/aql_emitter.h"
+#include "iree/testing/gtest.h"
+#include "iree/testing/status_matchers.h"
+#include "runtime/src/iree/hal/drivers/amdgpu/cts/testdata_amdgpu.h"
+
+namespace iree::hal::amdgpu {
+namespace {
+
+using iree::hal::cts::Ref;
+
+class HostQueueCommandBufferTest : public ::testing::Test {
+ protected:
+  static void SetUpTestSuite() {
+    host_allocator_ = iree_allocator_system();
+    iree_status_t status = iree_hal_amdgpu_libhsa_initialize(
+        IREE_HAL_AMDGPU_LIBHSA_FLAG_NONE, iree_string_view_list_empty(),
+        host_allocator_, &libhsa_);
+    if (!iree_status_is_ok(status)) {
+      iree_status_fprint(stderr, status);
+      iree_status_free(status);
+      GTEST_SKIP() << "HSA not available, skipping tests";
+    }
+    IREE_ASSERT_OK(iree_hal_amdgpu_topology_initialize_with_defaults(
+        &libhsa_, &topology_));
+    if (topology_.gpu_agent_count == 0) {
+      GTEST_SKIP() << "no GPU devices available, skipping tests";
+    }
+  }
+
+  static void TearDownTestSuite() {
+    iree_hal_amdgpu_topology_deinitialize(&topology_);
+    iree_hal_amdgpu_libhsa_deinitialize(&libhsa_);
+  }
+
+  static iree_allocator_t host_allocator_;
+  static iree_hal_amdgpu_libhsa_t libhsa_;
+  static iree_hal_amdgpu_topology_t topology_;
+};
+
+iree_allocator_t HostQueueCommandBufferTest::host_allocator_;
+iree_hal_amdgpu_libhsa_t HostQueueCommandBufferTest::libhsa_;
+iree_hal_amdgpu_topology_t HostQueueCommandBufferTest::topology_;
+
+class TestLogicalDevice {
+ public:
+  ~TestLogicalDevice() {
+    iree_hal_device_release(base_device_);
+    iree_hal_device_group_release(device_group_);
+  }
+
+  iree_status_t Initialize(
+      const iree_hal_amdgpu_logical_device_options_t* options,
+      const iree_hal_amdgpu_libhsa_t* libhsa,
+      const iree_hal_amdgpu_topology_t* topology,
+      iree_allocator_t host_allocator) {
+    IREE_RETURN_IF_ERROR(create_context_.Initialize(host_allocator));
+    IREE_RETURN_IF_ERROR(iree_hal_amdgpu_logical_device_create(
+        IREE_SV("amdgpu"), options, libhsa, topology, create_context_.params(),
+        host_allocator, &base_device_));
+    return iree_hal_device_group_create_from_device(
+        base_device_, create_context_.frontier_tracker(), host_allocator,
+        &device_group_);
+  }
+
+  iree_hal_device_t* base_device() const { return base_device_; }
+
+  iree_hal_allocator_t* allocator() const {
+    return iree_hal_device_allocator(base_device_);
+  }
+
+  iree_hal_amdgpu_logical_device_t* logical_device() const {
+    return (iree_hal_amdgpu_logical_device_t*)base_device_;
+  }
+
+  iree_hal_amdgpu_host_queue_t* first_host_queue() const {
+    iree_hal_amdgpu_logical_device_t* logical_device = this->logical_device();
+    if (logical_device->physical_device_count == 0) return NULL;
+    iree_hal_amdgpu_physical_device_t* physical_device =
+        logical_device->physical_devices[0];
+    if (physical_device->host_queue_count == 0) return NULL;
+    return &physical_device->host_queues[0];
+  }
+
+ private:
+  // Creation context supplying the proactor pool and frontier tracker.
+  iree::hal::cts::DeviceCreateContext create_context_;
+
+  // Test-owned device reference released before the topology-owning group.
+  iree_hal_device_t* base_device_ = NULL;
+
+  // Device group that owns the topology assigned to |base_device_|.
+  iree_hal_device_group_t* device_group_ = NULL;
+};
+
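+// Builds a queue affinity mask selecting the queues of the physical device at
+// |physical_device_ordinal| within the logical device's affinity domain.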
+static iree_status_t QueueAffinityForPhysicalDevice(
+    const TestLogicalDevice& test_device,
+    iree_host_size_t physical_device_ordinal,
+    iree_hal_queue_affinity_t* out_queue_affinity) {
+  iree_hal_amdgpu_logical_device_t* logical_device =
+      test_device.logical_device();
+  const iree_hal_amdgpu_queue_affinity_domain_t domain = {
+      .supported_affinity = logical_device->queue_affinity_mask,
+      .physical_device_count = logical_device->physical_device_count,
+      .queue_count_per_physical_device =
+          logical_device->system->topology.gpu_agent_queue_count,
+  };
+  return iree_hal_amdgpu_queue_affinity_for_physical_device(
+      domain, physical_device_ordinal, out_queue_affinity);
+}
+
+static iree_status_t CreateHostVisibleTransferBuffer(
+    iree_hal_allocator_t* allocator, iree_device_size_t buffer_size,
+    iree_hal_buffer_t** out_buffer) {
+  iree_hal_buffer_params_t params = {0};
+  params.type =
+      IREE_HAL_MEMORY_TYPE_DEVICE_LOCAL | IREE_HAL_MEMORY_TYPE_HOST_VISIBLE;
+  params.usage = IREE_HAL_BUFFER_USAGE_TRANSFER | IREE_HAL_BUFFER_USAGE_MAPPING;
+  return iree_hal_allocator_allocate_buffer(allocator, params, buffer_size,
+                                            out_buffer);
+}
+
+static iree_status_t CreateHostVisibleDispatchBuffer(
+    iree_hal_allocator_t* allocator, iree_device_size_t buffer_size,
+    iree_hal_buffer_t** out_buffer) {
+  iree_hal_buffer_params_t params = {0};
+  params.type =
+      IREE_HAL_MEMORY_TYPE_DEVICE_LOCAL | IREE_HAL_MEMORY_TYPE_HOST_VISIBLE;
+  params.access = IREE_HAL_MEMORY_ACCESS_ALL;
+  params.usage = IREE_HAL_BUFFER_USAGE_DISPATCH_STORAGE |
+                 IREE_HAL_BUFFER_USAGE_TRANSFER | IREE_HAL_BUFFER_USAGE_MAPPING;
+  return iree_hal_allocator_allocate_buffer(allocator, params, buffer_size,
+                                            out_buffer);
+}
+
+static iree_status_t CreateHostVisibleIndirectParameterBuffer(
+    iree_hal_allocator_t* allocator, iree_device_size_t buffer_size,
+    iree_hal_buffer_t** out_buffer) {
+  iree_hal_buffer_params_t params = {0};
+  params.type =
+      IREE_HAL_MEMORY_TYPE_DEVICE_LOCAL | IREE_HAL_MEMORY_TYPE_HOST_VISIBLE;
+  params.access = IREE_HAL_MEMORY_ACCESS_ALL;
+  params.usage = IREE_HAL_BUFFER_USAGE_DISPATCH_INDIRECT_PARAMETERS |
+                 IREE_HAL_BUFFER_USAGE_TRANSFER | IREE_HAL_BUFFER_USAGE_MAPPING;
+  return iree_hal_allocator_allocate_buffer(allocator, params, buffer_size,
+                                            out_buffer);
+}
+
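+// Looks up an embedded AMDGPU CTS executable by |file_name| in the generated
+// testdata table of contents; returns an empty span when not found.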
+static iree_const_byte_span_t FindCtsExecutableData(
+    iree_string_view_t file_name) {
+  const iree_file_toc_t* toc = iree_cts_testdata_amdgpu_create();
+  for (iree_host_size_t i = 0; toc[i].name != nullptr; ++i) {
+    if (iree_string_view_equal(file_name,
+                               iree_make_cstring_view(toc[i].name))) {
+      return iree_make_const_byte_span(
+          reinterpret_cast<const uint8_t*>(toc[i].data), toc[i].size);
+    }
+  }
+  return iree_const_byte_span_empty();
+}
+
+static iree_status_t LoadCtsExecutable(
+    iree_hal_device_t* device, iree_string_view_t file_name,
+    iree_hal_executable_cache_t** out_executable_cache,
+    iree_hal_executable_t** out_executable) {
+  *out_executable_cache = NULL;
+  *out_executable = NULL;
+
+  iree_const_byte_span_t executable_data = FindCtsExecutableData(file_name);
+  if (IREE_UNLIKELY(executable_data.data_length == 0)) {
+    return iree_make_status(IREE_STATUS_NOT_FOUND,
+                            "AMDGPU CTS executable not found");
+  }
+
+  iree_hal_executable_cache_t* executable_cache = NULL;
+  iree_hal_executable_t* executable = NULL;
+  iree_status_t status = iree_hal_executable_cache_create(
+      device, iree_make_cstring_view("default"), &executable_cache);
+
+  char executable_format[128] = {0};
+  iree_host_size_t inferred_size = 0;
+  if (iree_status_is_ok(status)) {
+    status = iree_hal_executable_cache_infer_format(
+        executable_cache, IREE_HAL_EXECUTABLE_CACHING_MODE_ALIAS_PROVIDED_DATA,
+        executable_data, IREE_ARRAYSIZE(executable_format), executable_format,
+        &inferred_size);
+  }
+  (void)inferred_size;
+
+  if (iree_status_is_ok(status)) {
+    iree_hal_executable_params_t executable_params;
+    iree_hal_executable_params_initialize(&executable_params);
+    executable_params.caching_mode =
+        IREE_HAL_EXECUTABLE_CACHING_MODE_ALIAS_PROVIDED_DATA;
+    executable_params.executable_format =
+        iree_make_cstring_view(executable_format);
+    executable_params.executable_data = executable_data;
+    status = iree_hal_executable_cache_prepare_executable(
+        executable_cache, &executable_params, &executable);
+  }
+
+  if (iree_status_is_ok(status)) {
+    *out_executable_cache = executable_cache;
+    *out_executable = executable;
+  } else {
+    iree_hal_executable_release(executable);
+    iree_hal_executable_cache_release(executable_cache);
+  }
+  return status;
+}
+
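+// Queue-allocates a transient device-optimal transfer buffer, signaling
+// |signal_list| once the allocation is queue-ordered complete.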
+static iree_status_t QueueTransientTransferBuffer(
+    iree_hal_device_t* device, const iree_hal_semaphore_list_t signal_list,
+    iree_device_size_t buffer_size, iree_hal_buffer_t** out_buffer) {
+  iree_hal_buffer_params_t params = {0};
+  params.type = IREE_HAL_MEMORY_TYPE_OPTIMAL_FOR_DEVICE;
+  params.access = IREE_HAL_MEMORY_ACCESS_ALL;
+  params.usage = IREE_HAL_BUFFER_USAGE_TRANSFER;
+  return iree_hal_device_queue_alloca(device, IREE_HAL_QUEUE_AFFINITY_ANY,
+                                      iree_hal_semaphore_list_empty(),
+                                      signal_list,
+                                      /*pool=*/NULL, params, buffer_size,
+                                      IREE_HAL_ALLOCA_FLAG_NONE, out_buffer);
+}
+
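+// Enqueues a raw BARRIER_AND packet that blocks the hardware queue until
+// |blocker_signal| is satisfied, letting tests hold the queue busy.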
+static iree_status_t EnqueueRawBlockingBarrier(
+    iree_hal_amdgpu_host_queue_t* queue, hsa_signal_t blocker_signal) {
+  const uint64_t packet_id =
+      iree_hal_amdgpu_aql_ring_reserve(&queue->aql_ring, /*count=*/1);
+  iree_hal_amdgpu_aql_packet_t* packet =
+      iree_hal_amdgpu_aql_ring_packet(&queue->aql_ring, packet_id);
+  const hsa_signal_t dep_signals[1] = {blocker_signal};
+  const uint16_t header = iree_hal_amdgpu_aql_emit_barrier_and(
+      &packet->barrier_and, dep_signals, IREE_ARRAYSIZE(dep_signals),
+      iree_hal_amdgpu_aql_packet_control_barrier_system(),
+      iree_hsa_signal_null());
+  iree_hal_amdgpu_aql_ring_commit(packet, header, /*setup=*/0);
+  iree_hal_amdgpu_aql_ring_doorbell(&queue->aql_ring, packet_id);
+  return iree_ok_status();
+}
+
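+// Returns true if the host queue currently has a parked post-drain action.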
+static bool HostQueueHasPostDrainAction(iree_hal_amdgpu_host_queue_t* queue) {
+  iree_slim_mutex_lock(&queue->locks.post_drain_mutex);
+  const bool has_action = queue->post_drain.head != NULL;
+  iree_slim_mutex_unlock(&queue->locks.post_drain_mutex);
+  return has_action;
+}
+
+static iree_status_t CreateSemaphore(iree_hal_device_t* device,
+                                     iree_hal_semaphore_t** out_semaphore) {
+  return iree_hal_semaphore_create(
+      device, IREE_HAL_QUEUE_AFFINITY_ANY,
+      /*initial_value=*/0, IREE_HAL_SEMAPHORE_FLAG_DEFAULT, out_semaphore);
+}
+
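+// Fills a small host-visible buffer through the queue and waits for its signal
+// semaphore so at least one queue operation fully completes.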
+static iree_status_t SubmitProfiledQueueFill(TestLogicalDevice* test_device) {
+  Ref<iree_hal_buffer_t> target_buffer;
+  IREE_RETURN_IF_ERROR(CreateHostVisibleTransferBuffer(
+      test_device->allocator(), sizeof(uint32_t), target_buffer.out()));
+
+  Ref<iree_hal_semaphore_t> signal;
+  IREE_RETURN_IF_ERROR(
+      CreateSemaphore(test_device->base_device(), signal.out()));
+  uint64_t signal_value = 1;
+  iree_hal_semaphore_t* signal_ptr = signal.get();
+  const iree_hal_semaphore_list_t signal_list = {
+      /*count=*/1,
+      /*semaphores=*/&signal_ptr,
+      /*payload_values=*/&signal_value,
+  };
+  const uint32_t pattern = 0xA11CA7E5u;
+  IREE_RETURN_IF_ERROR(iree_hal_device_queue_fill(
+      test_device->base_device(), IREE_HAL_QUEUE_AFFINITY_ANY,
+      iree_hal_semaphore_list_empty(), signal_list, target_buffer,
+      /*target_offset=*/0, sizeof(pattern), &pattern, sizeof(pattern),
+      IREE_HAL_FILL_FLAG_NONE));
+  return iree_hal_semaphore_wait(signal, signal_value, iree_infinite_timeout(),
+                                 IREE_ASYNC_WAIT_FLAG_NONE);
+}
+
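+// Scoped profiling session helper: ends an active session on destruction so a
+// failing test does not leak a device profiling session.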
+class DeviceProfilingScope {
+ public:
+  explicit DeviceProfilingScope(iree_hal_device_t* device) : device_(device) {}
+
+  ~DeviceProfilingScope() {
+    if (is_active_) {
+      IREE_EXPECT_OK(iree_hal_device_profiling_end(device_));
+    }
+  }
+
+  iree_status_t Begin(iree_hal_device_profiling_data_families_t data_families,
+                      iree_hal_profile_sink_t* sink) {
+    iree_hal_device_profiling_options_t options = {0};
+    options.data_families = data_families;
+    options.sink = sink;
+    return Begin(&options);
+  }
+
+  iree_status_t Begin(const iree_hal_device_profiling_options_t* options) {
+    iree_status_t status = iree_hal_device_profiling_begin(device_, options);
+    if (iree_status_is_ok(status)) {
+      is_active_ = true;
+    }
+    return status;
+  }
+
+  iree_status_t End() {
+    if (!is_active_) return iree_ok_status();
+    is_active_ = false;
+    return iree_hal_device_profiling_end(device_);
+  }
+
+ private:
+  // Device whose profiling session is active.
+  iree_hal_device_t* device_ = nullptr;
+
+  // True when |device_| has an active profiling session owned by this scope.
+  bool is_active_ = false;
+};
+
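+// Test profile sink that counts and copies every chunk type emitted by the
+// profiling producers, validates per-record invariants, and can inject
+// failures at begin_session, write, and end_session.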
+struct CommandBufferProfileSink {
+  // HAL resource header for the profile sink.
+  iree_hal_resource_t resource;
+
+  // Number of session begin notifications observed.
+  int begin_count = 0;
+
+  // Number of session end notifications observed.
+  int end_count = 0;
+
+  // Number of device metadata chunks observed.
+  int device_metadata_count = 0;
+
+  // Number of queue metadata chunks observed.
+  int queue_metadata_count = 0;
+
+  // Number of executable metadata chunks observed.
+  int executable_metadata_count = 0;
+
+  // Number of executable export metadata chunks observed.
+  int executable_export_metadata_count = 0;
+
+  // Number of command-buffer metadata chunks observed.
+  int command_buffer_metadata_count = 0;
+
+  // Number of command-operation metadata chunks observed.
+  int command_operation_metadata_count = 0;
+
+  // Number of clock correlation chunks observed.
+  int clock_correlation_count = 0;
+
+  // Number of host queue event chunks observed.
+  int queue_event_count = 0;
+
+  // Number of device queue event chunks observed.
+  int queue_device_event_count = 0;
+
+  // Number of memory event chunks observed.
+  int memory_event_count = 0;
+
+  // Number of event relationship chunks observed.
+  int relationship_count = 0;
+
+  // Number of counter set metadata chunks observed.
+  int counter_set_metadata_count = 0;
+
+  // Number of counter metadata chunks observed.
+  int counter_metadata_count = 0;
+
+  // Number of counter sample chunks observed.
+  int counter_sample_count = 0;
+
+  // Number of chunks marked truncated by the producer.
+  int truncated_chunk_count = 0;
+
+  // Total dropped records reported by truncated chunks.
+  uint64_t dropped_record_count = 0;
+
+  // Dropped queue event records reported by QUEUE_EVENTS chunks.
+  uint64_t queue_event_dropped_record_count = 0;
+
+  // Dropped memory event records reported by MEMORY_EVENTS chunks.
+  uint64_t memory_event_dropped_record_count = 0;
+
+  // Device metadata records copied from DEVICES chunks.
+  std::vector<iree_hal_profile_device_record_t> device_records;
+
+  // Queue metadata records copied from QUEUES chunks.
+  std::vector<iree_hal_profile_queue_record_t> queue_records;
+
+  // Executable identifiers copied from EXECUTABLES chunks.
+  std::vector<uint64_t> executable_ids;
+
+  // Executable identifiers copied from EXECUTABLE_EXPORTS chunks.
+  std::vector<uint64_t> executable_export_ids;
+
+  // Command-buffer identifiers copied from COMMAND_BUFFERS chunks.
+  std::vector<uint64_t> command_buffer_ids;
+
+  // Command operations copied from COMMAND_OPERATIONS chunks.
+  std::vector<iree_hal_profile_command_operation_record_t> command_operations;
+
+  // Clock correlation records copied from CLOCK_CORRELATIONS chunks.
+  std::vector<iree_hal_profile_clock_correlation_record_t> clock_correlations;
+
+  // Host queue events copied from QUEUE_EVENTS chunks.
+  std::vector<iree_hal_profile_queue_event_t> queue_events;
+
+  // Device queue events copied from QUEUE_DEVICE_EVENTS chunks.
+  std::vector<iree_hal_profile_queue_device_event_t> queue_device_events;
+
+  // Memory events copied from MEMORY_EVENTS chunks.
+  std::vector<iree_hal_profile_memory_event_t> memory_events;
+
+  // Event relationships copied from EVENT_RELATIONSHIPS chunks.
+  std::vector<iree_hal_profile_event_relationship_record_t> event_relationships;
+
+  // Dispatch events copied from DISPATCH_EVENTS chunks.
+  std::vector<iree_hal_profile_dispatch_event_t> dispatch_events;
+
+  // Counter sample records copied from COUNTER_SAMPLES chunks.
+  std::vector<iree_hal_profile_counter_sample_record_t> counter_samples;
+
+  // Counter sample values copied from COUNTER_SAMPLES chunks.
+  std::vector<uint64_t> counter_sample_values;
+
+  // Counter set metadata records copied from COUNTER_SETS chunks.
+  std::vector<iree_hal_profile_counter_set_record_t> counter_set_records;
+
+  // Counter metadata records copied from COUNTERS chunks.
+  std::vector<iree_hal_profile_counter_record_t> counter_records;
+
+  // Physical device ordinals for entries in |dispatch_events|.
+  std::vector<uint32_t> dispatch_event_physical_device_ordinals;
+
+  // Session identifier observed at begin and expected on later callbacks.
+  uint64_t session_id = 0;
+
+  // True if the backend writes after ending the profiling session.
+  bool write_after_end = false;
+
+  // Status code returned from begin_session, or OK for success.
+  iree_status_code_t fail_begin_session_status_code = IREE_STATUS_OK;
+
+  // Content type whose write callback should fail, or empty when disabled.
+  iree_string_view_t fail_write_content_type = {nullptr, 0};
+
+  // Number of matching write callbacks that should fail.
+  int fail_write_remaining = 0;
+
+  // Status code returned from matching write callbacks.
+  iree_status_code_t fail_write_status_code = IREE_STATUS_OK;
+
+  // Expected session status code passed to end_session.
+  iree_status_code_t expected_end_session_status_code = IREE_STATUS_OK;
+
+  // Status code observed by the most recent end_session callback.
+  iree_status_code_t observed_end_session_status_code = IREE_STATUS_OK;
+
+  // Status code returned from end_session, or OK for success.
+  iree_status_code_t fail_end_session_status_code = IREE_STATUS_OK;
+};
+
+static CommandBufferProfileSink* CommandBufferProfileSinkCast(
+    iree_hal_profile_sink_t* sink) {
+  return reinterpret_cast<CommandBufferProfileSink*>(sink);
+}
+
+static void CommandBufferProfileSinkDestroy(iree_hal_profile_sink_t* sink) {
+  (void)sink;
+}
+
+static iree_status_t CommandBufferProfileSinkBeginSession(
+    iree_hal_profile_sink_t* sink,
+    const iree_hal_profile_chunk_metadata_t* metadata) {
+  CommandBufferProfileSink* test_sink = CommandBufferProfileSinkCast(sink);
+  EXPECT_EQ(0, test_sink->begin_count);
+  EXPECT_EQ(0, test_sink->end_count);
+  EXPECT_TRUE(iree_string_view_equal(metadata->content_type,
+                                     IREE_HAL_PROFILE_CONTENT_TYPE_SESSION));
+  ++test_sink->begin_count;
+  if (test_sink->fail_begin_session_status_code != IREE_STATUS_OK) {
+    return iree_make_status(test_sink->fail_begin_session_status_code,
+                            "injected profile sink begin_session failure");
+  }
+  test_sink->session_id = metadata->session_id;
+  return iree_ok_status();
+}
+
+static iree_status_t CommandBufferProfileSinkWrite(
+    iree_hal_profile_sink_t* sink,
+    const iree_hal_profile_chunk_metadata_t* metadata,
+    iree_host_size_t iovec_count, const iree_const_byte_span_t* iovecs) {
+  CommandBufferProfileSink* test_sink = CommandBufferProfileSinkCast(sink);
+  EXPECT_EQ(1, test_sink->begin_count);
+  EXPECT_EQ(0, test_sink->end_count);
+  if (test_sink->end_count != 0) test_sink->write_after_end = true;
+  EXPECT_EQ(test_sink->session_id, metadata->session_id);
+  if (test_sink->fail_write_remaining != 0 &&
+      iree_string_view_equal(metadata->content_type,
+                             test_sink->fail_write_content_type)) {
+    --test_sink->fail_write_remaining;
+    return iree_make_status(test_sink->fail_write_status_code,
+                            "injected profile sink write failure");
+  }
+  const bool is_truncated =
+      iree_any_bit_set(metadata->flags, IREE_HAL_PROFILE_CHUNK_FLAG_TRUNCATED);
+  if (is_truncated) {
+    ++test_sink->truncated_chunk_count;
+    test_sink->dropped_record_count += metadata->dropped_record_count;
+  }
+  if (iovec_count == 0) {
+    if (iree_string_view_equal(metadata->content_type,
+                               IREE_HAL_PROFILE_CONTENT_TYPE_QUEUE_EVENTS)) {
+      test_sink->queue_event_dropped_record_count +=
+          metadata->dropped_record_count;
+      ++test_sink->queue_event_count;
+    } else if (iree_string_view_equal(
+                   metadata->content_type,
+                   IREE_HAL_PROFILE_CONTENT_TYPE_MEMORY_EVENTS)) {
+      test_sink->memory_event_dropped_record_count +=
+          metadata->dropped_record_count;
+      ++test_sink->memory_event_count;
+    }
+    return iree_ok_status();
+  }
+  if (iovec_count != 1) {
+    return iree_make_status(IREE_STATUS_INVALID_ARGUMENT,
+                            "expected exactly one profile chunk iovec");
+  }
+
+  if (iree_string_view_equal(metadata->content_type,
+                             IREE_HAL_PROFILE_CONTENT_TYPE_DEVICES)) {
+    EXPECT_EQ(0u,
+              iovecs[0].data_length % sizeof(iree_hal_profile_device_record_t));
+    const auto* records =
+        reinterpret_cast<const iree_hal_profile_device_record_t*>(
+            iovecs[0].data);
+    const iree_host_size_t record_count =
+        iovecs[0].data_length / sizeof(iree_hal_profile_device_record_t);
+    EXPECT_GT(record_count, 0u);
+    for (iree_host_size_t i = 0; i < record_count; ++i) {
+      EXPECT_EQ(sizeof(iree_hal_profile_device_record_t),
+                records[i].record_length);
+      EXPECT_NE(UINT32_MAX, records[i].physical_device_ordinal);
+      EXPECT_GT(records[i].queue_count, 0u);
+    }
+    test_sink->device_records.insert(test_sink->device_records.end(), records,
+                                     records + record_count);
+    ++test_sink->device_metadata_count;
+  } else if (iree_string_view_equal(metadata->content_type,
+                                    IREE_HAL_PROFILE_CONTENT_TYPE_QUEUES)) {
+    EXPECT_EQ(0u,
+              iovecs[0].data_length % sizeof(iree_hal_profile_queue_record_t));
+    const auto* records =
+        reinterpret_cast<const iree_hal_profile_queue_record_t*>(
+            iovecs[0].data);
+    const iree_host_size_t record_count =
+        iovecs[0].data_length / sizeof(iree_hal_profile_queue_record_t);
+    EXPECT_GT(record_count, 0u);
+    for (iree_host_size_t i = 0; i < record_count; ++i) {
+      EXPECT_EQ(sizeof(iree_hal_profile_queue_record_t),
+                records[i].record_length);
+      EXPECT_NE(UINT32_MAX, records[i].physical_device_ordinal);
+      EXPECT_NE(UINT32_MAX, records[i].queue_ordinal);
+    }
+    test_sink->queue_records.insert(test_sink->queue_records.end(), records,
+                                    records + record_count);
+    ++test_sink->queue_metadata_count;
+  } else if (iree_string_view_equal(
+                 metadata->content_type,
+                 IREE_HAL_PROFILE_CONTENT_TYPE_EXECUTABLES)) {
+    EXPECT_EQ(0u, iovecs[0].data_length %
+                      sizeof(iree_hal_profile_executable_record_t));
+    const auto* records =
+        reinterpret_cast<const iree_hal_profile_executable_record_t*>(
+            iovecs[0].data);
+    const iree_host_size_t record_count =
+        iovecs[0].data_length / sizeof(iree_hal_profile_executable_record_t);
+    EXPECT_GT(record_count, 0u);
+    for (iree_host_size_t i = 0; i < record_count; ++i) {
+      EXPECT_EQ(sizeof(iree_hal_profile_executable_record_t),
+                records[i].record_length);
+      EXPECT_NE(0u, records[i].executable_id);
+      EXPECT_GT(records[i].export_count, 0u);
+      EXPECT_NE(0u, records[i].flags &
+                        IREE_HAL_PROFILE_EXECUTABLE_FLAG_CODE_OBJECT_HASH);
+      EXPECT_NE(
+          0u, records[i].code_object_hash[0] | records[i].code_object_hash[1]);
+      test_sink->executable_ids.push_back(records[i].executable_id);
+    }
+    ++test_sink->executable_metadata_count;
+  } else if (iree_string_view_equal(
+                 metadata->content_type,
+                 IREE_HAL_PROFILE_CONTENT_TYPE_EXECUTABLE_EXPORTS)) {
+    iree_host_size_t payload_offset = 0;
+    while (payload_offset < iovecs[0].data_length) {
+      if (iovecs[0].data_length - payload_offset <
+          sizeof(iree_hal_profile_executable_export_record_t)) {
+        return iree_make_status(IREE_STATUS_DATA_LOSS,
+                                "truncated executable export profile record");
+      }
+      iree_hal_profile_executable_export_record_t record;
+      memcpy(&record, iovecs[0].data + payload_offset, sizeof(record));
+      if (record.record_length < sizeof(record) ||
+          record.record_length > iovecs[0].data_length - payload_offset) {
+        return iree_make_status(IREE_STATUS_DATA_LOSS,
+                                "invalid executable export profile record");
+      }
+      EXPECT_NE(0u, record.executable_id);
+      EXPECT_NE(UINT32_MAX, record.export_ordinal);
+      EXPECT_NE(0u, record.flags &
+                        IREE_HAL_PROFILE_EXECUTABLE_EXPORT_FLAG_PIPELINE_HASH);
+      EXPECT_NE(0u, record.pipeline_hash[0] | record.pipeline_hash[1]);
+      EXPECT_EQ(record.name_length,
+                record.record_length - (uint32_t)sizeof(record));
+      test_sink->executable_export_ids.push_back(record.executable_id);
+      payload_offset += record.record_length;
+    }
+    ++test_sink->executable_export_metadata_count;
+  } else if (iree_string_view_equal(
+                 metadata->content_type,
+                 IREE_HAL_PROFILE_CONTENT_TYPE_COMMAND_BUFFERS)) {
+    EXPECT_EQ(0u, iovecs[0].data_length %
+                      sizeof(iree_hal_profile_command_buffer_record_t));
+    const auto* records =
+        reinterpret_cast<const iree_hal_profile_command_buffer_record_t*>(
+            iovecs[0].data);
+    const iree_host_size_t record_count =
+        iovecs[0].data_length /
+        sizeof(iree_hal_profile_command_buffer_record_t);
+    EXPECT_GT(record_count, 0u);
+    for (iree_host_size_t i = 0; i < record_count; ++i) {
+      EXPECT_EQ(sizeof(iree_hal_profile_command_buffer_record_t),
+                records[i].record_length);
+      EXPECT_NE(0u, records[i].command_buffer_id);
+      EXPECT_NE(UINT32_MAX, records[i].physical_device_ordinal);
+      test_sink->command_buffer_ids.push_back(records[i].command_buffer_id);
+    }
+    ++test_sink->command_buffer_metadata_count;
+  } else if (iree_string_view_equal(
+                 metadata->content_type,
+                 IREE_HAL_PROFILE_CONTENT_TYPE_COMMAND_OPERATIONS)) {
+    EXPECT_EQ(0u, iovecs[0].data_length %
+                      sizeof(iree_hal_profile_command_operation_record_t));
+    const auto* records =
+        reinterpret_cast<const iree_hal_profile_command_operation_record_t*>(
+            iovecs[0].data);
+    const iree_host_size_t record_count =
+        iovecs[0].data_length /
+        sizeof(iree_hal_profile_command_operation_record_t);
+    EXPECT_GT(record_count, 0u);
+    for (iree_host_size_t i = 0; i < record_count; ++i) {
+      EXPECT_EQ(sizeof(iree_hal_profile_command_operation_record_t),
+                records[i].record_length);
+      EXPECT_NE(IREE_HAL_PROFILE_COMMAND_OPERATION_TYPE_NONE, records[i].type);
+      EXPECT_NE(UINT32_MAX, records[i].command_index);
+      EXPECT_NE(0u, records[i].command_buffer_id);
+      EXPECT_TRUE(
+          iree_hal_profile_command_operation_has_block_structure(&records[i]));
+      EXPECT_NE(UINT32_MAX, records[i].block_ordinal);
+      EXPECT_NE(UINT32_MAX, records[i].block_command_ordinal);
+      if (records[i].type == IREE_HAL_PROFILE_COMMAND_OPERATION_TYPE_DISPATCH) {
+        EXPECT_NE(0u, records[i].executable_id);
+        EXPECT_NE(UINT32_MAX, records[i].export_ordinal);
+        EXPECT_NE(0u, records[i].binding_count);
+        EXPECT_NE(0u, records[i].workgroup_size[0]);
+      }
+    }
+    test_sink->command_operations.insert(test_sink->command_operations.end(),
+                                         records, records + record_count);
+    ++test_sink->command_operation_metadata_count;
+  } else if (iree_string_view_equal(
+                 metadata->content_type,
+                 IREE_HAL_PROFILE_CONTENT_TYPE_CLOCK_CORRELATIONS)) {
+    EXPECT_EQ(0u, iovecs[0].data_length %
+                      sizeof(iree_hal_profile_clock_correlation_record_t));
+    const auto* records =
+        reinterpret_cast<const iree_hal_profile_clock_correlation_record_t*>(
+            iovecs[0].data);
+    const iree_host_size_t record_count =
+        iovecs[0].data_length /
+        sizeof(iree_hal_profile_clock_correlation_record_t);
+    EXPECT_GT(record_count, 0u);
+    for (iree_host_size_t i = 0; i < record_count; ++i) {
+      EXPECT_EQ(sizeof(iree_hal_profile_clock_correlation_record_t),
+                records[i].record_length);
+      EXPECT_NE(UINT32_MAX, records[i].physical_device_ordinal);
+      EXPECT_NE(0u, records[i].sample_id);
+      EXPECT_TRUE(iree_all_bits_set(
+          records[i].flags,
+          IREE_HAL_PROFILE_CLOCK_CORRELATION_FLAG_DEVICE_TICK |
+              IREE_HAL_PROFILE_CLOCK_CORRELATION_FLAG_HOST_CPU_TIMESTAMP |
+              IREE_HAL_PROFILE_CLOCK_CORRELATION_FLAG_HOST_SYSTEM_TIMESTAMP |
+              IREE_HAL_PROFILE_CLOCK_CORRELATION_FLAG_HOST_TIME_BRACKET));
+      EXPECT_NE(0u, records[i].device_tick);
+      EXPECT_NE(0u, records[i].host_cpu_timestamp_ns);
+      EXPECT_NE(0u, records[i].host_system_timestamp);
+      EXPECT_NE(0u, records[i].host_system_frequency_hz);
+      EXPECT_LE(records[i].host_time_begin_ns, records[i].host_time_end_ns);
+    }
+    test_sink->clock_correlations.insert(test_sink->clock_correlations.end(),
+                                         records, records + record_count);
+    ++test_sink->clock_correlation_count;
+  } else if (iree_string_view_equal(
+                 metadata->content_type,
+                 IREE_HAL_PROFILE_CONTENT_TYPE_COUNTER_SETS)) {
+    iree_host_size_t payload_offset = 0;
+    while (payload_offset < iovecs[0].data_length) {
+      if (iovecs[0].data_length - payload_offset <
+          sizeof(iree_hal_profile_counter_set_record_t)) {
+        return iree_make_status(IREE_STATUS_DATA_LOSS,
+                                "truncated counter set profile record");
+      }
+      iree_hal_profile_counter_set_record_t record;
+      memcpy(&record, iovecs[0].data + payload_offset, sizeof(record));
+      if (record.record_length < sizeof(record) ||
+          record.record_length > iovecs[0].data_length - payload_offset) {
+        return iree_make_status(IREE_STATUS_DATA_LOSS,
+                                "invalid counter set profile record");
+      }
+      EXPECT_NE(0u, record.counter_set_id);
+      EXPECT_GT(record.counter_count, 0u);
+      EXPECT_GT(record.sample_value_count, 0u);
+      EXPECT_EQ(record.name_length,
+                record.record_length - (uint32_t)sizeof(record));
+      test_sink->counter_set_records.push_back(record);
+      payload_offset += record.record_length;
+    }
+    ++test_sink->counter_set_metadata_count;
+  } else if (iree_string_view_equal(metadata->content_type,
+                                    IREE_HAL_PROFILE_CONTENT_TYPE_COUNTERS)) {
+    iree_host_size_t payload_offset = 0;
+    while (payload_offset < iovecs[0].data_length) {
+      if (iovecs[0].data_length - payload_offset <
+          sizeof(iree_hal_profile_counter_record_t)) {
+        return iree_make_status(IREE_STATUS_DATA_LOSS,
+                                "truncated counter profile record");
+      }
+      iree_hal_profile_counter_record_t record;
+      memcpy(&record, iovecs[0].data + payload_offset, sizeof(record));
+      if (record.record_length < sizeof(record) ||
+          record.record_length > iovecs[0].data_length - payload_offset) {
+        return iree_make_status(IREE_STATUS_DATA_LOSS,
+                                "invalid counter profile record");
+      }
+      EXPECT_NE(0u, record.counter_set_id);
+      EXPECT_GT(record.sample_value_count, 0u);
+      const uint32_t string_length = record.block_name_length +
+                                     record.name_length +
+                                     record.description_length;
+      EXPECT_EQ(string_length, record.record_length - (uint32_t)sizeof(record));
+      test_sink->counter_records.push_back(record);
+      payload_offset += record.record_length;
+    }
+    ++test_sink->counter_metadata_count;
+  } else if (iree_string_view_equal(
+                 metadata->content_type,
+                 IREE_HAL_PROFILE_CONTENT_TYPE_DISPATCH_EVENTS)) {
+    EXPECT_NE(UINT32_MAX, metadata->physical_device_ordinal);
+    EXPECT_NE(UINT32_MAX, metadata->queue_ordinal);
+    EXPECT_EQ(
+        0u, iovecs[0].data_length % sizeof(iree_hal_profile_dispatch_event_t));
+    const auto* records =
+        reinterpret_cast<const iree_hal_profile_dispatch_event_t*>(
+            iovecs[0].data);
+    const iree_host_size_t record_count =
+        iovecs[0].data_length / sizeof(iree_hal_profile_dispatch_event_t);
+    EXPECT_GT(record_count, 0u);
+    test_sink->dispatch_events.insert(test_sink->dispatch_events.end(), records,
+                                      records + record_count);
+    test_sink->dispatch_event_physical_device_ordinals.insert(
+        test_sink->dispatch_event_physical_device_ordinals.end(), record_count,
+        metadata->physical_device_ordinal);
+  } else if (iree_string_view_equal(
+                 metadata->content_type,
+                 IREE_HAL_PROFILE_CONTENT_TYPE_QUEUE_EVENTS)) {
+    EXPECT_EQ(0u,
+              iovecs[0].data_length % sizeof(iree_hal_profile_queue_event_t));
+    const auto* records =
+        reinterpret_cast<const iree_hal_profile_queue_event_t*>(iovecs[0].data);
+    const iree_host_size_t record_count =
+        iovecs[0].data_length / sizeof(iree_hal_profile_queue_event_t);
+    EXPECT_GT(record_count, 0u);
+    for (iree_host_size_t i = 0; i < record_count; ++i) {
+      EXPECT_EQ(sizeof(iree_hal_profile_queue_event_t),
+                records[i].record_length);
+      EXPECT_NE(IREE_HAL_PROFILE_QUEUE_EVENT_TYPE_NONE, records[i].type);
+      EXPECT_NE(0u, records[i].event_id);
+      EXPECT_NE(0, records[i].host_time_ns);
+      EXPECT_NE(UINT32_MAX, records[i].physical_device_ordinal);
+      EXPECT_NE(UINT32_MAX, records[i].queue_ordinal);
+    }
+    test_sink->queue_events.insert(test_sink->queue_events.end(), records,
+                                   records + record_count);
+    test_sink->queue_event_dropped_record_count +=
+        metadata->dropped_record_count;
+    ++test_sink->queue_event_count;
+  } else if (iree_string_view_equal(
+                 metadata->content_type,
+                 IREE_HAL_PROFILE_CONTENT_TYPE_QUEUE_DEVICE_EVENTS)) {
+    EXPECT_NE(UINT32_MAX, metadata->physical_device_ordinal);
+    EXPECT_NE(UINT32_MAX, metadata->queue_ordinal);
+    EXPECT_EQ(0u, iovecs[0].data_length %
+                      sizeof(iree_hal_profile_queue_device_event_t));
+    const auto* records =
+        reinterpret_cast<const iree_hal_profile_queue_device_event_t*>(
+            iovecs[0].data);
+    const iree_host_size_t record_count =
+        iovecs[0].data_length / sizeof(iree_hal_profile_queue_device_event_t);
+    EXPECT_GT(record_count, 0u);
+    for (iree_host_size_t i = 0; i < record_count; ++i) {
+      EXPECT_EQ(sizeof(iree_hal_profile_queue_device_event_t),
+                records[i].record_length);
+      EXPECT_NE(IREE_HAL_PROFILE_QUEUE_EVENT_TYPE_NONE, records[i].type);
+      EXPECT_NE(0u, records[i].event_id);
+      EXPECT_NE(0u, records[i].submission_id);
+      EXPECT_NE(UINT32_MAX, records[i].physical_device_ordinal);
+      EXPECT_NE(UINT32_MAX, records[i].queue_ordinal);
+      EXPECT_NE(0u, records[i].start_tick);
+      EXPECT_NE(0u, records[i].end_tick);
+      EXPECT_GE(records[i].end_tick, records[i].start_tick);
+    }
+    test_sink->queue_device_events.insert(test_sink->queue_device_events.end(),
+                                          records, records + record_count);
+    ++test_sink->queue_device_event_count;
+  } else if (iree_string_view_equal(
+                 metadata->content_type,
+                 IREE_HAL_PROFILE_CONTENT_TYPE_MEMORY_EVENTS)) {
+    EXPECT_EQ(0u,
+              iovecs[0].data_length % sizeof(iree_hal_profile_memory_event_t));
+    const auto* records =
+        reinterpret_cast<const iree_hal_profile_memory_event_t*>(
+            iovecs[0].data);
+    const iree_host_size_t record_count =
+        iovecs[0].data_length / sizeof(iree_hal_profile_memory_event_t);
+    EXPECT_GT(record_count, 0u);
+    for (iree_host_size_t i = 0; i < record_count; ++i) {
+      EXPECT_EQ(sizeof(iree_hal_profile_memory_event_t),
+                records[i].record_length);
+      EXPECT_NE(IREE_HAL_PROFILE_MEMORY_EVENT_TYPE_NONE, records[i].type);
+      EXPECT_NE(0u, records[i].event_id);
+      EXPECT_NE(0, records[i].host_time_ns);
+      EXPECT_NE(0u, records[i].allocation_id);
+    }
+    test_sink->memory_events.insert(test_sink->memory_events.end(), records,
+                                    records + record_count);
+    test_sink->memory_event_dropped_record_count +=
+        metadata->dropped_record_count;
+    ++test_sink->memory_event_count;
+  } else if (iree_string_view_equal(
+                 metadata->content_type,
+                 IREE_HAL_PROFILE_CONTENT_TYPE_EVENT_RELATIONSHIPS)) {
+    EXPECT_NE(UINT32_MAX, metadata->physical_device_ordinal);
+    EXPECT_NE(UINT32_MAX, metadata->queue_ordinal);
+    EXPECT_EQ(0u, iovecs[0].data_length %
+                      sizeof(iree_hal_profile_event_relationship_record_t));
+    const auto* records =
+        reinterpret_cast<const iree_hal_profile_event_relationship_record_t*>(
+            iovecs[0].data);
+    const iree_host_size_t record_count =
+        iovecs[0].data_length /
+        sizeof(iree_hal_profile_event_relationship_record_t);
+    EXPECT_GT(record_count, 0u);
+    for (iree_host_size_t i = 0; i < record_count; ++i) {
+      EXPECT_EQ(sizeof(iree_hal_profile_event_relationship_record_t),
+                records[i].record_length);
+      EXPECT_NE(IREE_HAL_PROFILE_EVENT_RELATIONSHIP_TYPE_NONE, records[i].type);
+      EXPECT_NE(0u, records[i].relationship_id);
+      EXPECT_NE(IREE_HAL_PROFILE_EVENT_ENDPOINT_TYPE_NONE,
+                records[i].source_type);
+      EXPECT_NE(IREE_HAL_PROFILE_EVENT_ENDPOINT_TYPE_NONE,
+                records[i].target_type);
+      EXPECT_NE(UINT32_MAX, records[i].physical_device_ordinal);
+      EXPECT_NE(UINT32_MAX, records[i].queue_ordinal);
+      EXPECT_NE(0u, records[i].source_id);
+      EXPECT_NE(0u, records[i].target_id);
+    }
+    test_sink->event_relationships.insert(test_sink->event_relationships.end(),
+                                          records, records + record_count);
+    ++test_sink->relationship_count;
+  } else if (iree_string_view_equal(
+                 metadata->content_type,
+                 IREE_HAL_PROFILE_CONTENT_TYPE_COUNTER_SAMPLES)) {
+    iree_host_size_t payload_offset = 0;
+    while (payload_offset < iovecs[0].data_length) {
+      if (iovecs[0].data_length - payload_offset <
+          sizeof(iree_hal_profile_counter_sample_record_t)) {
+        return iree_make_status(IREE_STATUS_DATA_LOSS,
+                                "truncated counter sample profile record");
+      }
+      iree_hal_profile_counter_sample_record_t record;
+      memcpy(&record, iovecs[0].data + payload_offset, sizeof(record));
+      if (record.record_length < sizeof(record) ||
+          record.record_length > iovecs[0].data_length - payload_offset) {
+        return iree_make_status(IREE_STATUS_DATA_LOSS,
+                                "invalid counter sample profile record");
+      }
+      EXPECT_NE(0u, record.sample_id);
+      EXPECT_NE(0u, record.counter_set_id);
+      switch (record.scope) {
+        case IREE_HAL_PROFILE_COUNTER_SAMPLE_SCOPE_DISPATCH:
+          EXPECT_NE(0u, record.dispatch_event_id);
+          break;
+        case IREE_HAL_PROFILE_COUNTER_SAMPLE_SCOPE_DEVICE_TIME_RANGE:
+          EXPECT_EQ(0u, record.dispatch_event_id);
+          EXPECT_TRUE(iree_any_bit_set(
+              record.flags,
+              IREE_HAL_PROFILE_COUNTER_SAMPLE_FLAG_DEVICE_TICK_RANGE));
+          break;
+        default:
+          ADD_FAILURE() << "unexpected counter sample scope " << record.scope;
+          break;
+      }
+      EXPECT_GT(record.sample_value_count, 0u);
+      EXPECT_EQ(record.record_length,
+                sizeof(record) +
+                    record.sample_value_count * (uint32_t)sizeof(uint64_t));
+      const auto* values = reinterpret_cast<const uint64_t*>(
+          iovecs[0].data + payload_offset + sizeof(record));
+      test_sink->counter_sample_values.insert(
+          test_sink->counter_sample_values.end(), values,
+          values + record.sample_value_count);
+      test_sink->counter_samples.push_back(record);
+      payload_offset += record.record_length;
+    }
+    ++test_sink->counter_sample_count;
+  }
+
+  return iree_ok_status();
+}
+
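+// Ends the profile session: verifies exactly one begin/end pair, the session
+// chunk metadata, and the expected session status code, then returns an
+// injected failure if the test configured one.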
+static iree_status_t CommandBufferProfileSinkEndSession(
+    iree_hal_profile_sink_t* sink,
+    const iree_hal_profile_chunk_metadata_t* metadata,
+    iree_status_code_t session_status_code) {
+  CommandBufferProfileSink* test_sink = CommandBufferProfileSinkCast(sink);
+  EXPECT_EQ(1, test_sink->begin_count);
+  EXPECT_EQ(0, test_sink->end_count);
+  EXPECT_TRUE(iree_string_view_equal(metadata->content_type,
+                                     IREE_HAL_PROFILE_CONTENT_TYPE_SESSION));
+  EXPECT_EQ(test_sink->session_id, metadata->session_id);
+  EXPECT_EQ(test_sink->expected_end_session_status_code, session_status_code);
+  test_sink->observed_end_session_status_code = session_status_code;
+  test_sink->end_count = 1;
+  if (test_sink->fail_end_session_status_code != IREE_STATUS_OK) {
+    return iree_make_status(test_sink->fail_end_session_status_code,
+                            "injected profile sink end_session failure");
+  }
+  return iree_ok_status();
+}
+
+static const iree_hal_profile_sink_vtable_t kCommandBufferProfileSinkVTable = {
+    /*.destroy=*/CommandBufferProfileSinkDestroy,
+    /*.begin_session=*/CommandBufferProfileSinkBeginSession,
+    /*.write=*/CommandBufferProfileSinkWrite,
+    /*.end_session=*/CommandBufferProfileSinkEndSession,
+};
+
+static void CommandBufferProfileSinkInitialize(CommandBufferProfileSink* sink) {
+  iree_hal_resource_initialize(&kCommandBufferProfileSinkVTable,
+                               &sink->resource);
+}
+
+static iree_hal_profile_sink_t* CommandBufferProfileSinkAsBase(
+    CommandBufferProfileSink* sink) {
+  return reinterpret_cast<iree_hal_profile_sink_t*>(sink);
+}
+
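+// Verifies a fresh queue-event profiling session can begin and end cleanly.
+// Used after failure-injection tests to prove the device recovered and a
+// retry is possible.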
+static void ExpectQueueEventProfilingCanBeginAndEnd(
+    TestLogicalDevice* test_device) {
+  CommandBufferProfileSink sink = {};
+  CommandBufferProfileSinkInitialize(&sink);
+  DeviceProfilingScope profiling(test_device->base_device());
+  IREE_ASSERT_OK(profiling.Begin(IREE_HAL_DEVICE_PROFILING_DATA_QUEUE_EVENTS,
+                                 CommandBufferProfileSinkAsBase(&sink)));
+  IREE_ASSERT_OK(profiling.End());
+  EXPECT_EQ(1, sink.begin_count);
+  EXPECT_EQ(1, sink.end_count);
+}
+
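+// Verifies each dispatch event's [start_tick, end_tick] range falls within
+// the min/max device ticks captured by clock correlation records for the
+// same physical device.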
+static void ExpectDispatchEventsWithinClockCorrelationRange(
+    const CommandBufferProfileSink& sink) {
+  ASSERT_GE(sink.clock_correlations.size(), 2u);
+  ASSERT_EQ(sink.dispatch_events.size(),
+            sink.dispatch_event_physical_device_ordinals.size());
+  for (iree_host_size_t event_index = 0;
+       event_index < sink.dispatch_events.size(); ++event_index) {
+    const uint32_t physical_device_ordinal =
+        sink.dispatch_event_physical_device_ordinals[event_index];
+    uint64_t min_device_tick = UINT64_MAX;
+    uint64_t max_device_tick = 0;
+    for (const iree_hal_profile_clock_correlation_record_t& correlation :
+         sink.clock_correlations) {
+      if (correlation.physical_device_ordinal != physical_device_ordinal ||
+          !iree_any_bit_set(
+              correlation.flags,
+              IREE_HAL_PROFILE_CLOCK_CORRELATION_FLAG_DEVICE_TICK)) {
+        continue;
+      }
+      min_device_tick = std::min(min_device_tick, correlation.device_tick);
+      max_device_tick = std::max(max_device_tick, correlation.device_tick);
+    }
+    ASSERT_NE(UINT64_MAX, min_device_tick);
+    ASSERT_NE(0u, max_device_tick);
+    ASSERT_LT(min_device_tick, max_device_tick);
+    EXPECT_GE(sink.dispatch_events[event_index].start_tick, min_device_tick);
+    EXPECT_LE(sink.dispatch_events[event_index].end_tick, max_device_tick);
+  }
+}
+
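+// Returns the captured command operation record matching |command_buffer_id|
+// and |command_index|, or nullptr if none was received.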
+static const iree_hal_profile_command_operation_record_t* FindCommandOperation(
+    const CommandBufferProfileSink& sink, uint64_t command_buffer_id,
+    uint32_t command_index) {
+  for (const auto& operation : sink.command_operations) {
+    if (operation.command_buffer_id == command_buffer_id &&
+        operation.command_index == command_index) {
+      return &operation;
+    }
+  }
+  return nullptr;
+}
+
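+// Returns the first captured relationship record matching the given type and
+// source/target endpoints, or nullptr if none was received.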
+static const iree_hal_profile_event_relationship_record_t*
+FindEventRelationship(const CommandBufferProfileSink& sink,
+                      iree_hal_profile_event_relationship_type_t type,
+                      iree_hal_profile_event_endpoint_type_t source_type,
+                      uint64_t source_id,
+                      iree_hal_profile_event_endpoint_type_t target_type,
+                      uint64_t target_id) {
+  for (const auto& relationship : sink.event_relationships) {
+    if (relationship.type == type && relationship.source_type == source_type &&
+        relationship.source_id == source_id &&
+        relationship.target_type == target_type &&
+        relationship.target_id == target_id) {
+      return &relationship;
+    }
+  }
+  return nullptr;
+}
+
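+// Returns the single captured queue event of |type| (failing the test if
+// more than one matches), or nullptr if none was received.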
+static const iree_hal_profile_queue_event_t* FindUniqueQueueEvent(
+    const CommandBufferProfileSink& sink,
+    iree_hal_profile_queue_event_type_t type) {
+  const iree_hal_profile_queue_event_t* result = nullptr;
+  for (const auto& event : sink.queue_events) {
+    if (event.type != type) continue;
+    EXPECT_EQ(nullptr, result);
+    result = &event;
+  }
+  return result;
+}
+
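+// Counts captured queue events of |type|.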
+static iree_host_size_t CountQueueEvents(
+    const CommandBufferProfileSink& sink,
+    iree_hal_profile_queue_event_type_t type) {
+  iree_host_size_t count = 0;
+  for (const auto& event : sink.queue_events) {
+    if (event.type == type) ++count;
+  }
+  return count;
+}
+
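+// Sums |operation_count| across all captured queue events of |type|.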
+static uint32_t SumQueueEventOperationCounts(
+    const CommandBufferProfileSink& sink,
+    iree_hal_profile_queue_event_type_t type) {
+  uint32_t operation_count = 0;
+  for (const auto& event : sink.queue_events) {
+    if (event.type == type) {
+      operation_count += event.operation_count;
+    }
+  }
+  return operation_count;
+}
+
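+// Status predicates used to skip (rather than fail) tests on devices and
+// runtimes that lack the requested profiling capability.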
+static bool IsProfilingUnsupported(iree_status_t status) {
+  return iree_status_is_unimplemented(status) ||
+         iree_status_is_invalid_argument(status);
+}
+
+static bool IsHardwareCounterProfilingUnavailable(iree_status_t status) {
+  return IsProfilingUnsupported(status) || iree_status_is_not_found(status) ||
+         iree_status_is_failed_precondition(status);
+}
+
+static bool IsQueueDeviceProfilingUnavailable(iree_status_t status) {
+  return IsProfilingUnsupported(status) ||
+         iree_status_is_failed_precondition(status);
+}
+
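+// Begins a dispatch-event + counter-sample profiling session selecting a
+// single named counter set ("smoke") containing |counter_names|.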
+static iree_status_t BeginHardwareCounterProfiling(
+    DeviceProfilingScope* profiling, CommandBufferProfileSink* sink,
+    iree_host_size_t counter_name_count, iree_string_view_t* counter_names) {
+  iree_hal_profile_counter_set_selection_t counter_set = {
+      /*.flags=*/IREE_HAL_PROFILE_COUNTER_SET_SELECTION_FLAG_NONE,
+      /*.name=*/IREE_SV("smoke"),
+      /*.counter_name_count=*/counter_name_count,
+      /*.counter_names=*/counter_names,
+  };
+  iree_hal_device_profiling_options_t profiling_options = {0};
+  profiling_options.data_families =
+      IREE_HAL_DEVICE_PROFILING_DATA_DISPATCH_EVENTS |
+      IREE_HAL_DEVICE_PROFILING_DATA_COUNTER_SAMPLES;
+  profiling_options.sink = CommandBufferProfileSinkAsBase(sink);
+  profiling_options.counter_set_count = 1;
+  profiling_options.counter_sets = &counter_set;
+  return profiling->Begin(&profiling_options);
+}
+
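+// Begins a counter sampling session selecting only SQ_WAVES.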
+static iree_status_t BeginSqWavesProfiling(DeviceProfilingScope* profiling,
+                                           CommandBufferProfileSink* sink) {
+  iree_string_view_t counter_names[] = {
+      IREE_SV("SQ_WAVES"),
+  };
+  return BeginHardwareCounterProfiling(
+      profiling, sink, IREE_ARRAYSIZE(counter_names), counter_names);
+}
+
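+// Begins a session requesting device-time-range counter samples
+// (IREE_HAL_DEVICE_PROFILING_DATA_COUNTER_RANGES) for SQ_WAVES.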
+static iree_status_t BeginSqWavesCounterRangeProfiling(
+    DeviceProfilingScope* profiling, CommandBufferProfileSink* sink) {
+  iree_string_view_t counter_names[] = {
+      IREE_SV("SQ_WAVES"),
+  };
+  iree_hal_profile_counter_set_selection_t counter_set = {
+      /*.flags=*/IREE_HAL_PROFILE_COUNTER_SET_SELECTION_FLAG_NONE,
+      /*.name=*/IREE_SV("smoke"),
+      /*.counter_name_count=*/IREE_ARRAYSIZE(counter_names),
+      /*.counter_names=*/counter_names,
+  };
+  iree_hal_device_profiling_options_t profiling_options = {0};
+  profiling_options.data_families =
+      IREE_HAL_DEVICE_PROFILING_DATA_COUNTER_RANGES;
+  profiling_options.sink = CommandBufferProfileSinkAsBase(sink);
+  profiling_options.counter_set_count = 1;
+  profiling_options.counter_sets = &counter_set;
+  return profiling->Begin(&profiling_options);
+}
+
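+// Begins a session selecting SQ_WAVES alongside wave-width breakdown and
+// busy-cycle counters.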
+static iree_status_t BeginSqWaveWidthProfiling(DeviceProfilingScope* profiling,
+                                               CommandBufferProfileSink* sink) {
+  iree_string_view_t counter_names[] = {
+      IREE_SV("SQ_WAVES"),
+      IREE_SV("SQ_WAVES_32"),
+      IREE_SV("SQ_WAVES_64"),
+      IREE_SV("SQ_BUSY_CYCLES"),
+  };
+  return BeginHardwareCounterProfiling(
+      profiling, sink, IREE_ARRAYSIZE(counter_names), counter_names);
+}
+
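+// Selecting a single hardware counter should emit exactly one counter-set
+// metadata chunk and one counter metadata chunk (or skip if the device
+// cannot profile hardware counters).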
+TEST_F(HostQueueCommandBufferTest,
+       ExplicitHardwareCounterSelectionEmitsMetadataWhenAvailable) {
+  iree_hal_amdgpu_logical_device_options_t options;
+  iree_hal_amdgpu_logical_device_options_initialize(&options);
+  TestLogicalDevice test_device;
+  IREE_ASSERT_OK(
+      test_device.Initialize(&options, &libhsa_, &topology_, host_allocator_));
+
+  CommandBufferProfileSink sink = {};
+  CommandBufferProfileSinkInitialize(&sink);
+  DeviceProfilingScope profiling(test_device.base_device());
+  iree_status_t profiling_status = BeginSqWavesProfiling(&profiling, &sink);
+  if (IsQueueDeviceProfilingUnavailable(profiling_status)) {
+    iree_status_free(profiling_status);
+    GTEST_SKIP() << "AMDGPU hardware counter profiling unavailable";
+  }
+  IREE_ASSERT_OK(profiling_status);
+  IREE_ASSERT_OK(profiling.End());
+
+  EXPECT_EQ(1, sink.begin_count);
+  EXPECT_EQ(1, sink.end_count);
+  EXPECT_EQ(1, sink.counter_set_metadata_count);
+  EXPECT_EQ(1, sink.counter_metadata_count);
+}
+
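+// Selecting four counters should emit a layout where every counter-set
+// record describes four counters with contiguous sample value offsets and
+// the expected units.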
+TEST_F(HostQueueCommandBufferTest,
+       MultipleHardwareCounterSelectionEmitsLayoutWhenAvailable) {
+  iree_hal_amdgpu_logical_device_options_t options;
+  iree_hal_amdgpu_logical_device_options_initialize(&options);
+  TestLogicalDevice test_device;
+  IREE_ASSERT_OK(
+      test_device.Initialize(&options, &libhsa_, &topology_, host_allocator_));
+
+  CommandBufferProfileSink sink = {};
+  CommandBufferProfileSinkInitialize(&sink);
+  DeviceProfilingScope profiling(test_device.base_device());
+  iree_status_t profiling_status = BeginSqWaveWidthProfiling(&profiling, &sink);
+  if (IsHardwareCounterProfilingUnavailable(profiling_status)) {
+    iree_status_free(profiling_status);
+    GTEST_SKIP() << "AMDGPU hardware counter profiling unavailable";
+  }
+  IREE_ASSERT_OK(profiling_status);
+  IREE_ASSERT_OK(profiling.End());
+
+  EXPECT_EQ(1, sink.begin_count);
+  EXPECT_EQ(1, sink.end_count);
+  EXPECT_EQ(1, sink.counter_set_metadata_count);
+  EXPECT_EQ(1, sink.counter_metadata_count);
+  ASSERT_FALSE(sink.counter_set_records.empty());
+  ASSERT_EQ(sink.counter_set_records.size() * 4u, sink.counter_records.size());
+  const iree_hal_profile_counter_unit_t expected_units[] = {
+      IREE_HAL_PROFILE_COUNTER_UNIT_COUNT,
+      IREE_HAL_PROFILE_COUNTER_UNIT_COUNT,
+      IREE_HAL_PROFILE_COUNTER_UNIT_COUNT,
+      IREE_HAL_PROFILE_COUNTER_UNIT_CYCLES,
+  };
+  iree_host_size_t counter_record_index = 0;
+  for (const auto& counter_set_record : sink.counter_set_records) {
+    ASSERT_EQ(4u, counter_set_record.counter_count);
+    uint32_t sample_value_count = 0;
+    for (uint32_t i = 0; i < counter_set_record.counter_count; ++i) {
+      const auto& counter_record = sink.counter_records[counter_record_index++];
+      EXPECT_EQ(counter_set_record.counter_set_id,
+                counter_record.counter_set_id);
+      EXPECT_EQ(sample_value_count, counter_record.sample_value_offset);
+      EXPECT_GT(counter_record.sample_value_count, 0u);
+      EXPECT_EQ(expected_units[i], counter_record.unit);
+      sample_value_count += counter_record.sample_value_count;
+    }
+    EXPECT_EQ(sample_value_count, counter_set_record.sample_value_count);
+  }
+}
+
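+// Records a 1x1x1 dispatch of entry point 0 with constants {3, 10}; the CTS
+// test kernel computes output[i] = input[i] * 3 + 10 for these constants.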
+static iree_status_t AppendConstantsBindingsDispatch(
+    iree_hal_command_buffer_t* command_buffer,
+    iree_hal_executable_t* executable, iree_hal_buffer_ref_list_t bindings) {
+  const uint32_t constant_values[2] = {3, 10};
+  iree_const_byte_span_t constants =
+      iree_make_const_byte_span(constant_values, sizeof(constant_values));
+  return iree_hal_command_buffer_dispatch(
+      command_buffer, executable, /*entry_point=*/0,
+      iree_hal_make_static_dispatch_config(1, 1, 1), constants, bindings,
+      IREE_HAL_DISPATCH_FLAG_NONE);
+}
+
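+// Fixture owning an executable, its buffers, and a command buffer containing
+// two equivalent dispatches that differ only in their output binding.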
+struct TwoDispatchCommandBuffer {
+  ~TwoDispatchCommandBuffer() {
+    iree_hal_executable_release(executable);
+    iree_hal_executable_cache_release(executable_cache);
+  }
+
+  // Executable cache owning |executable|.
+  iree_hal_executable_cache_t* executable_cache = NULL;
+
+  // CTS executable containing the constants+bindings dispatch entry point.
+  iree_hal_executable_t* executable = NULL;
+
+  // Host-visible input buffer shared by both dispatches.
+  Ref<iree_hal_buffer_t> input_buffer;
+
+  // Host-visible output buffer written by command index 0.
+  Ref<iree_hal_buffer_t> output_buffer0;
+
+  // Host-visible output buffer written by command index 1.
+  Ref<iree_hal_buffer_t> output_buffer1;
+
+  // Command buffer containing two equivalent dispatch operations.
+  Ref<iree_hal_command_buffer_t> command_buffer;
+};
+
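+// Builds the two-dispatch fixture: loads the CTS executable, seeds the input
+// buffer with {1, 2, 3, 4}, zeroes both outputs, and records two dispatches
+// writing to |output_buffer0| and |output_buffer1| respectively.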
+static iree_status_t CreateTwoDispatchCommandBuffer(
+    TestLogicalDevice* test_device, TwoDispatchCommandBuffer* out_fixture,
+    iree_hal_command_buffer_mode_t mode =
+        IREE_HAL_COMMAND_BUFFER_MODE_ONE_SHOT |
+        IREE_HAL_COMMAND_BUFFER_MODE_RETAIN_PROFILE_METADATA) {
+  IREE_RETURN_IF_ERROR(LoadCtsExecutable(
+      test_device->base_device(),
+      iree_make_cstring_view("command_buffer_dispatch_constants_bindings_test."
+                             "bin"),
+      &out_fixture->executable_cache, &out_fixture->executable));
+
+  IREE_RETURN_IF_ERROR(CreateHostVisibleDispatchBuffer(
+      test_device->allocator(), /*buffer_size=*/4 * sizeof(uint32_t),
+      out_fixture->input_buffer.out()));
+  const uint32_t input_values[4] = {1, 2, 3, 4};
+  IREE_RETURN_IF_ERROR(
+      iree_hal_buffer_map_write(out_fixture->input_buffer, /*target_offset=*/0,
+                                input_values, sizeof(input_values)));
+
+  IREE_RETURN_IF_ERROR(CreateHostVisibleDispatchBuffer(
+      test_device->allocator(), /*buffer_size=*/4 * sizeof(uint32_t),
+      out_fixture->output_buffer0.out()));
+  IREE_RETURN_IF_ERROR(iree_hal_buffer_map_zero(
+      out_fixture->output_buffer0, /*offset=*/0, IREE_HAL_WHOLE_BUFFER));
+
+  IREE_RETURN_IF_ERROR(CreateHostVisibleDispatchBuffer(
+      test_device->allocator(), /*buffer_size=*/4 * sizeof(uint32_t),
+      out_fixture->output_buffer1.out()));
+  IREE_RETURN_IF_ERROR(iree_hal_buffer_map_zero(
+      out_fixture->output_buffer1, /*offset=*/0, IREE_HAL_WHOLE_BUFFER));
+
+  IREE_RETURN_IF_ERROR(iree_hal_command_buffer_create(
+      test_device->base_device(), mode, IREE_HAL_COMMAND_CATEGORY_DISPATCH,
+      IREE_HAL_QUEUE_AFFINITY_ANY, /*binding_capacity=*/0,
+      out_fixture->command_buffer.out()));
+  IREE_RETURN_IF_ERROR(
+      iree_hal_command_buffer_begin(out_fixture->command_buffer));
+  iree_hal_buffer_ref_t binding_refs0[2] = {
+      iree_hal_make_buffer_ref(
+          out_fixture->input_buffer, /*offset=*/0,
+          iree_hal_buffer_byte_length(out_fixture->input_buffer)),
+      iree_hal_make_buffer_ref(
+          out_fixture->output_buffer0, /*offset=*/0,
+          iree_hal_buffer_byte_length(out_fixture->output_buffer0)),
+  };
+  const iree_hal_buffer_ref_list_t bindings0 = {
+      /*count=*/IREE_ARRAYSIZE(binding_refs0),
+      /*values=*/binding_refs0,
+  };
+  IREE_RETURN_IF_ERROR(AppendConstantsBindingsDispatch(
+      out_fixture->command_buffer, out_fixture->executable, bindings0));
+  iree_hal_buffer_ref_t binding_refs1[2] = {
+      iree_hal_make_buffer_ref(
+          out_fixture->input_buffer, /*offset=*/0,
+          iree_hal_buffer_byte_length(out_fixture->input_buffer)),
+      iree_hal_make_buffer_ref(
+          out_fixture->output_buffer1, /*offset=*/0,
+          iree_hal_buffer_byte_length(out_fixture->output_buffer1)),
+  };
+  const iree_hal_buffer_ref_list_t bindings1 = {
+      /*count=*/IREE_ARRAYSIZE(binding_refs1),
+      /*values=*/binding_refs1,
+  };
+  IREE_RETURN_IF_ERROR(AppendConstantsBindingsDispatch(
+      out_fixture->command_buffer, out_fixture->executable, bindings1));
+  return iree_hal_command_buffer_end(out_fixture->command_buffer);
+}
+
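+// Both outputs should hold input[i] * 3 + 10: {1, 2, 3, 4} -> {13, 16, 19, 22}.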
+static void ExpectTwoDispatchOutputs(const TwoDispatchCommandBuffer& fixture) {
+  const uint32_t expected_values[4] = {13, 16, 19, 22};
+  uint32_t output_values0[4] = {0, 0, 0, 0};
+  IREE_ASSERT_OK(iree_hal_buffer_map_read(fixture.output_buffer0, /*offset=*/0,
+                                          output_values0,
+                                          sizeof(output_values0)));
+  EXPECT_EQ(0,
+            memcmp(output_values0, expected_values, sizeof(expected_values)));
+  uint32_t output_values1[4] = {0, 0, 0, 0};
+  IREE_ASSERT_OK(iree_hal_buffer_map_read(fixture.output_buffer1, /*offset=*/0,
+                                          output_values1,
+                                          sizeof(output_values1)));
+  EXPECT_EQ(0,
+            memcmp(output_values1, expected_values, sizeof(expected_values)));
+}
+
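+// With RETAIN_DISPATCH_METADATA each recorded dispatch should produce a
+// summary linking its AQL packet ordinals to its per-command metadata.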
+TEST_F(HostQueueCommandBufferTest, DispatchSummariesRetainPacketOrdinals) {
+  iree_hal_amdgpu_logical_device_options_t options;
+  iree_hal_amdgpu_logical_device_options_initialize(&options);
+  options.preallocate_pools = 0;
+
+  TestLogicalDevice test_device;
+  IREE_ASSERT_OK(
+      test_device.Initialize(&options, &libhsa_, &topology_, host_allocator_));
+
+  TwoDispatchCommandBuffer fixture;
+  IREE_ASSERT_OK(CreateTwoDispatchCommandBuffer(
+      &test_device, &fixture,
+      IREE_HAL_COMMAND_BUFFER_MODE_ONE_SHOT |
+          IREE_HAL_COMMAND_BUFFER_MODE_RETAIN_DISPATCH_METADATA));
+  EXPECT_EQ(
+      iree_hal_amdgpu_aql_command_buffer_profile_id(fixture.command_buffer),
+      0u);
+
+  const iree_hal_amdgpu_aql_program_t* program =
+      iree_hal_amdgpu_aql_command_buffer_program(fixture.command_buffer);
+  ASSERT_NE(program, nullptr);
+  ASSERT_NE(program->first_block, nullptr);
+  EXPECT_EQ(program->first_block->dispatch_count, 2u);
+  EXPECT_EQ(program->first_block->aql_packet_count, 2u);
+
+  uint32_t summary_count = 0;
+  const iree_hal_amdgpu_aql_command_buffer_dispatch_summary_t* summary =
+      iree_hal_amdgpu_aql_command_buffer_dispatch_summaries(
+          fixture.command_buffer, program->first_block, &summary_count);
+  ASSERT_NE(summary, nullptr);
+  EXPECT_EQ(summary_count, 2u);
+
+  EXPECT_EQ(summary->packets.first_ordinal, 0u);
+  EXPECT_EQ(summary->packets.dispatch_ordinal, 0u);
+  EXPECT_EQ(summary->metadata.command_index, 0u);
+  EXPECT_EQ(summary->metadata.export_ordinal, 0u);
+  EXPECT_EQ(summary->metadata.dispatch_flags,
+            IREE_HAL_AMDGPU_COMMAND_BUFFER_DISPATCH_FLAG_NONE);
+
+  ASSERT_NE(summary->next, nullptr);
+  summary = summary->next;
+  EXPECT_EQ(summary->packets.first_ordinal, 1u);
+  EXPECT_EQ(summary->packets.dispatch_ordinal, 1u);
+  EXPECT_EQ(summary->metadata.command_index, 1u);
+  EXPECT_EQ(summary->metadata.export_ordinal, 0u);
+  EXPECT_EQ(summary->metadata.dispatch_flags,
+            IREE_HAL_AMDGPU_COMMAND_BUFFER_DISPATCH_FLAG_NONE);
+  EXPECT_EQ(summary->next, nullptr);
+}
+
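+// When a wait resolves to an inline acquire scope only the first payload
+// packet should carry the barrier bit and acquire fence; later packets in
+// the same command buffer need no redundant barrier.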
+TEST_F(HostQueueCommandBufferTest,
+       PacketControlBarriersFirstPayloadPacketForInlineWait) {
+  iree_hal_amdgpu_logical_device_options_t options;
+  iree_hal_amdgpu_logical_device_options_initialize(&options);
+  options.preallocate_pools = 0;
+
+  TestLogicalDevice test_device;
+  IREE_ASSERT_OK(
+      test_device.Initialize(&options, &libhsa_, &topology_, host_allocator_));
+  iree_hal_amdgpu_host_queue_t* queue = test_device.first_host_queue();
+  ASSERT_NE(queue, nullptr);
+
+  iree_hal_amdgpu_wait_resolution_t resolution = {0};
+  resolution.inline_acquire_scope = IREE_HSA_FENCE_SCOPE_AGENT;
+  iree_hal_amdgpu_aql_packet_control_t control =
+      iree_hal_amdgpu_host_queue_command_buffer_packet_control(
+          queue, &resolution, iree_hal_semaphore_list_empty(),
+          /*packet_index=*/0, IREE_HSA_FENCE_SCOPE_NONE,
+          IREE_HAL_AMDGPU_HOST_QUEUE_COMMAND_BUFFER_PACKET_FLAG_NONE);
+  EXPECT_TRUE(control.has_barrier);
+  EXPECT_EQ(control.acquire_fence_scope, IREE_HSA_FENCE_SCOPE_AGENT);
+  EXPECT_EQ(control.release_fence_scope, IREE_HSA_FENCE_SCOPE_NONE);
+
+  control = iree_hal_amdgpu_host_queue_command_buffer_packet_control(
+      queue, &resolution, iree_hal_semaphore_list_empty(), /*packet_index=*/1,
+      IREE_HSA_FENCE_SCOPE_NONE,
+      IREE_HAL_AMDGPU_HOST_QUEUE_COMMAND_BUFFER_PACKET_FLAG_NONE);
+  EXPECT_FALSE(control.has_barrier);
+  EXPECT_EQ(control.acquire_fence_scope, IREE_HSA_FENCE_SCOPE_NONE);
+  EXPECT_EQ(control.release_fence_scope, IREE_HSA_FENCE_SCOPE_NONE);
+}
+
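+// The kernarg ring's publication mode should mirror the physical device's
+// recorded CPU-visible device-coarse memory capability, falling back to no
+// publication when the capability is unavailable.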
+TEST_F(HostQueueCommandBufferTest,
+       KernargRingUsesRecordedCpuVisibleCoarseCapability) {
+  iree_hal_amdgpu_logical_device_options_t options;
+  iree_hal_amdgpu_logical_device_options_initialize(&options);
+  options.preallocate_pools = 0;
+
+  TestLogicalDevice test_device;
+  IREE_ASSERT_OK(
+      test_device.Initialize(&options, &libhsa_, &topology_, host_allocator_));
+
+  iree_hal_amdgpu_logical_device_t* logical_device =
+      test_device.logical_device();
+  ASSERT_GT(logical_device->physical_device_count, 0u);
+  iree_hal_amdgpu_physical_device_t* physical_device =
+      logical_device->physical_devices[0];
+  iree_hal_amdgpu_host_queue_t* queue = test_device.first_host_queue();
+  ASSERT_NE(queue, nullptr);
+
+  const iree_hal_amdgpu_cpu_visible_device_coarse_memory_t* capability =
+      &physical_device->cpu_visible_device_coarse_memory;
+  const bool uses_cpu_visible_device_coarse = iree_any_bit_set(
+      capability->flags,
+      IREE_HAL_AMDGPU_CPU_VISIBLE_DEVICE_COARSE_MEMORY_FLAG_AVAILABLE);
+  if (uses_cpu_visible_device_coarse) {
+    EXPECT_EQ(queue->kernarg_ring.publication.mode,
+              capability->host_write_publication.mode);
+    EXPECT_EQ(queue->kernarg_ring.publication.hdp_mem_flush_control,
+              capability->host_write_publication.hdp_mem_flush_control);
+  } else {
+    EXPECT_EQ(queue->kernarg_ring.publication.mode,
+              IREE_HAL_AMDGPU_KERNARG_RING_PUBLICATION_MODE_NONE);
+  }
+}
+
+TEST_F(HostQueueCommandBufferTest,
+       PrepublishedKernargsUseRecordedDeviceFineStorage) {
+  iree_hal_amdgpu_logical_device_options_t options;
+  iree_hal_amdgpu_logical_device_options_initialize(&options);
+  options.preallocate_pools = 0;
+
+  TestLogicalDevice test_device;
+  IREE_ASSERT_OK(
+      test_device.Initialize(&options, &libhsa_, &topology_, host_allocator_));
+
+  iree_hal_amdgpu_logical_device_t* logical_device =
+      test_device.logical_device();
+  ASSERT_GT(logical_device->physical_device_count, 0u);
+  const iree_hal_amdgpu_aql_prepublished_kernarg_storage_t* storage =
+      &logical_device->physical_devices[0]->prepublished_kernarg_storage;
+
+  EXPECT_EQ(
+      storage->strategy,
+      IREE_HAL_AMDGPU_AQL_PREPUBLISHED_KERNARG_STORAGE_STRATEGY_DEVICE_FINE_HOST_COHERENT);
+  EXPECT_TRUE(iree_all_bits_set(storage->buffer_params.type,
+                                IREE_HAL_MEMORY_TYPE_DEVICE_LOCAL |
+                                    IREE_HAL_MEMORY_TYPE_HOST_VISIBLE |
+                                    IREE_HAL_MEMORY_TYPE_HOST_COHERENT));
+  EXPECT_TRUE(iree_all_bits_set(storage->buffer_params.access,
+                                IREE_HAL_MEMORY_ACCESS_ALL));
+  EXPECT_TRUE(iree_all_bits_set(storage->buffer_params.usage,
+                                IREE_HAL_BUFFER_USAGE_DISPATCH_UNIFORM_READ |
+                                    IREE_HAL_BUFFER_USAGE_MAPPING));
+}
+
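+// Reusable (default-mode) command buffers should prepublish kernargs at
+// record time so replay performs no per-submission kernarg writes; one-shot
+// command buffers should instead take the HAL kernarg-ring path.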
+TEST_F(HostQueueCommandBufferTest, DirectDispatchUsesPrepublishedKernargs) {
+  iree_hal_amdgpu_logical_device_options_t options;
+  iree_hal_amdgpu_logical_device_options_initialize(&options);
+  options.preallocate_pools = 0;
+
+  TestLogicalDevice test_device;
+  IREE_ASSERT_OK(
+      test_device.Initialize(&options, &libhsa_, &topology_, host_allocator_));
+
+  iree_hal_executable_cache_t* executable_cache = NULL;
+  iree_hal_executable_t* executable = NULL;
+  IREE_ASSERT_OK(LoadCtsExecutable(
+      test_device.base_device(),
+      iree_make_cstring_view("command_buffer_dispatch_constants_bindings_test."
+                             "bin"),
+      &executable_cache, &executable));
+
+  Ref<iree_hal_buffer_t> input_buffer;
+  IREE_ASSERT_OK(CreateHostVisibleDispatchBuffer(
+      test_device.allocator(), /*buffer_size=*/4 * sizeof(uint32_t),
+      input_buffer.out()));
+  const uint32_t input_values[4] = {1, 2, 3, 4};
+  IREE_ASSERT_OK(iree_hal_buffer_map_write(input_buffer, /*target_offset=*/0,
+                                           input_values, sizeof(input_values)));
+
+  Ref<iree_hal_buffer_t> output_buffer;
+  IREE_ASSERT_OK(CreateHostVisibleDispatchBuffer(
+      test_device.allocator(), /*buffer_size=*/4 * sizeof(uint32_t),
+      output_buffer.out()));
+  IREE_ASSERT_OK(iree_hal_buffer_map_zero(output_buffer, /*offset=*/0,
+                                          IREE_HAL_WHOLE_BUFFER));
+
+  iree_hal_buffer_ref_t binding_refs[2] = {
+      iree_hal_make_buffer_ref(input_buffer, /*offset=*/0,
+                               iree_hal_buffer_byte_length(input_buffer)),
+      iree_hal_make_buffer_ref(output_buffer, /*offset=*/0,
+                               iree_hal_buffer_byte_length(output_buffer)),
+  };
+  const iree_hal_buffer_ref_list_t bindings = {
+      /*count=*/IREE_ARRAYSIZE(binding_refs),
+      /*values=*/binding_refs,
+  };
+  const uint32_t constant_values[2] = {3, 10};
+  iree_const_byte_span_t constants =
+      iree_make_const_byte_span(constant_values, sizeof(constant_values));
+
+  Ref<iree_hal_command_buffer_t> command_buffer;
+  IREE_ASSERT_OK(iree_hal_command_buffer_create(
+      test_device.base_device(), IREE_HAL_COMMAND_BUFFER_MODE_DEFAULT,
+      IREE_HAL_COMMAND_CATEGORY_DISPATCH, IREE_HAL_QUEUE_AFFINITY_ANY,
+      /*binding_capacity=*/0, command_buffer.out()));
+  IREE_ASSERT_OK(iree_hal_command_buffer_begin(command_buffer));
+  IREE_ASSERT_OK(iree_hal_command_buffer_dispatch(
+      command_buffer, executable, /*entry_point=*/0,
+      iree_hal_make_static_dispatch_config(1, 1, 1), constants, bindings,
+      IREE_HAL_DISPATCH_FLAG_NONE));
+  IREE_ASSERT_OK(iree_hal_command_buffer_dispatch(
+      command_buffer, executable, /*entry_point=*/0,
+      iree_hal_make_static_dispatch_config(1, 1, 1), constants, bindings,
+      IREE_HAL_DISPATCH_FLAG_NONE));
+  IREE_ASSERT_OK(iree_hal_command_buffer_end(command_buffer));
+
+  const iree_hal_amdgpu_aql_program_t* program =
+      iree_hal_amdgpu_aql_command_buffer_program(command_buffer);
+  ASSERT_NE(program, nullptr);
+  ASSERT_NE(program->first_block, nullptr);
+  EXPECT_EQ(program->max_block_kernarg_length, 0u);
+  const iree_hal_amdgpu_command_buffer_command_header_t* command =
+      iree_hal_amdgpu_command_buffer_block_commands_const(program->first_block);
+  ASSERT_EQ(command->opcode, IREE_HAL_AMDGPU_COMMAND_BUFFER_OPCODE_DISPATCH);
+  const iree_hal_amdgpu_command_buffer_dispatch_command_t* dispatch_command =
+      (const iree_hal_amdgpu_command_buffer_dispatch_command_t*)command;
+  EXPECT_EQ(dispatch_command->kernarg_strategy,
+            IREE_HAL_AMDGPU_COMMAND_BUFFER_KERNARG_STRATEGY_PREPUBLISHED);
+  const uint32_t kernarg_length =
+      (uint32_t)dispatch_command->kernarg_length_qwords * 8u;
+  EXPECT_EQ(dispatch_command->payload_reference, 0u);
+  EXPECT_NE(
+      iree_hal_amdgpu_aql_command_buffer_prepublished_kernarg(
+          command_buffer, dispatch_command->payload_reference, kernarg_length),
+      nullptr);
+  const iree_hal_amdgpu_command_buffer_command_header_t* second_command =
+      (const iree_hal_amdgpu_command_buffer_command_header_t*)
+          ((const uint8_t*)command + command->length_qwords * 8u);
+  ASSERT_EQ(second_command->opcode,
+            IREE_HAL_AMDGPU_COMMAND_BUFFER_OPCODE_DISPATCH);
+  const iree_hal_amdgpu_command_buffer_dispatch_command_t*
+      second_dispatch_command =
+          (const iree_hal_amdgpu_command_buffer_dispatch_command_t*)
+              second_command;
+  EXPECT_EQ(second_dispatch_command->kernarg_strategy,
+            IREE_HAL_AMDGPU_COMMAND_BUFFER_KERNARG_STRATEGY_PREPUBLISHED);
+  EXPECT_GT(second_dispatch_command->payload_reference, 1u);
+  EXPECT_NE(iree_hal_amdgpu_aql_command_buffer_prepublished_kernarg(
+                command_buffer, second_dispatch_command->payload_reference,
+                (uint32_t)second_dispatch_command->kernarg_length_qwords * 8u),
+            nullptr);
+
+  Ref<iree_hal_semaphore_t> command_buffer_signal;
+  IREE_ASSERT_OK(
+      CreateSemaphore(test_device.base_device(), command_buffer_signal.out()));
+  uint64_t command_buffer_signal_value = 1;
+  iree_hal_semaphore_t* command_buffer_signal_ptr = command_buffer_signal.get();
+  const iree_hal_semaphore_list_t command_buffer_signal_list = {
+      /*count=*/1,
+      /*semaphores=*/&command_buffer_signal_ptr,
+      /*payload_values=*/&command_buffer_signal_value,
+  };
+  IREE_ASSERT_OK(iree_hal_device_queue_execute(
+      test_device.base_device(), IREE_HAL_QUEUE_AFFINITY_ANY,
+      iree_hal_semaphore_list_empty(), command_buffer_signal_list,
+      command_buffer, iree_hal_buffer_binding_table_empty(),
+      IREE_HAL_EXECUTE_FLAG_NONE));
+  IREE_ASSERT_OK(iree_hal_semaphore_wait(
+      command_buffer_signal, command_buffer_signal_value,
+      iree_infinite_timeout(), IREE_ASYNC_WAIT_FLAG_NONE));
+
+  uint32_t output_values[4] = {0, 0, 0, 0};
+  IREE_ASSERT_OK(iree_hal_buffer_map_read(
+      output_buffer, /*offset=*/0, output_values, sizeof(output_values)));
+  const uint32_t expected_values[4] = {13, 16, 19, 22};
+  EXPECT_EQ(0, memcmp(output_values, expected_values, sizeof(expected_values)));
+
+  IREE_ASSERT_OK(iree_hal_buffer_map_zero(output_buffer, /*offset=*/0,
+                                          IREE_HAL_WHOLE_BUFFER));
+
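+  // One-shot recording of the same dispatch should use the HAL kernarg
+  // strategy with block kernarg storage and no prepublished kernargs.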
+  Ref<iree_hal_command_buffer_t> one_shot_command_buffer;
+  IREE_ASSERT_OK(iree_hal_command_buffer_create(
+      test_device.base_device(), IREE_HAL_COMMAND_BUFFER_MODE_ONE_SHOT,
+      IREE_HAL_COMMAND_CATEGORY_DISPATCH, IREE_HAL_QUEUE_AFFINITY_ANY,
+      /*binding_capacity=*/0, one_shot_command_buffer.out()));
+  IREE_ASSERT_OK(iree_hal_command_buffer_begin(one_shot_command_buffer));
+  IREE_ASSERT_OK(iree_hal_command_buffer_dispatch(
+      one_shot_command_buffer, executable, /*entry_point=*/0,
+      iree_hal_make_static_dispatch_config(1, 1, 1), constants, bindings,
+      IREE_HAL_DISPATCH_FLAG_NONE));
+  IREE_ASSERT_OK(iree_hal_command_buffer_end(one_shot_command_buffer));
+
+  const iree_hal_amdgpu_aql_program_t* one_shot_program =
+      iree_hal_amdgpu_aql_command_buffer_program(one_shot_command_buffer);
+  ASSERT_NE(one_shot_program, nullptr);
+  ASSERT_NE(one_shot_program->first_block, nullptr);
+  EXPECT_GT(one_shot_program->max_block_kernarg_length, 0u);
+  const iree_hal_amdgpu_command_buffer_command_header_t* one_shot_command =
+      iree_hal_amdgpu_command_buffer_block_commands_const(
+          one_shot_program->first_block);
+  ASSERT_EQ(one_shot_command->opcode,
+            IREE_HAL_AMDGPU_COMMAND_BUFFER_OPCODE_DISPATCH);
+  const iree_hal_amdgpu_command_buffer_dispatch_command_t*
+      one_shot_dispatch_command =
+          (const iree_hal_amdgpu_command_buffer_dispatch_command_t*)
+              one_shot_command;
+  EXPECT_EQ(one_shot_dispatch_command->kernarg_strategy,
+            IREE_HAL_AMDGPU_COMMAND_BUFFER_KERNARG_STRATEGY_HAL);
+  EXPECT_EQ(
+      iree_hal_amdgpu_aql_command_buffer_prepublished_kernarg(
+          one_shot_command_buffer, one_shot_dispatch_command->payload_reference,
+          (uint32_t)one_shot_dispatch_command->kernarg_length_qwords * 8u),
+      nullptr);
+
+  command_buffer_signal_value = 2;
+  IREE_ASSERT_OK(iree_hal_device_queue_execute(
+      test_device.base_device(), IREE_HAL_QUEUE_AFFINITY_ANY,
+      iree_hal_semaphore_list_empty(), command_buffer_signal_list,
+      one_shot_command_buffer, iree_hal_buffer_binding_table_empty(),
+      IREE_HAL_EXECUTE_FLAG_NONE));
+  IREE_ASSERT_OK(iree_hal_semaphore_wait(
+      command_buffer_signal, command_buffer_signal_value,
+      iree_infinite_timeout(), IREE_ASYNC_WAIT_FLAG_NONE));
+
+  memset(output_values, 0, sizeof(output_values));
+  IREE_ASSERT_OK(iree_hal_buffer_map_read(
+      output_buffer, /*offset=*/0, output_values, sizeof(output_values)));
+  EXPECT_EQ(0, memcmp(output_values, expected_values, sizeof(expected_values)));
+
+  iree_hal_executable_release(executable);
+  iree_hal_executable_cache_release(executable_cache);
+}
+
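+// A dispatch mixing one static binding with one binding-table slot should be
+// recorded with the patched-template kernarg strategy: a kernarg template in
+// rodata plus patch sources resolved from the binding table at submission.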
+TEST_F(HostQueueCommandBufferTest,
+       MixedDynamicDispatchUsesPatchedKernargTemplate) {
+  iree_hal_amdgpu_logical_device_options_t options;
+  iree_hal_amdgpu_logical_device_options_initialize(&options);
+  options.preallocate_pools = 0;
+
+  TestLogicalDevice test_device;
+  IREE_ASSERT_OK(
+      test_device.Initialize(&options, &libhsa_, &topology_, host_allocator_));
+
+  iree_hal_executable_cache_t* executable_cache = NULL;
+  iree_hal_executable_t* executable = NULL;
+  IREE_ASSERT_OK(LoadCtsExecutable(
+      test_device.base_device(),
+      iree_make_cstring_view("command_buffer_dispatch_constants_bindings_test."
+                             "bin"),
+      &executable_cache, &executable));
+
+  Ref<iree_hal_buffer_t> input_buffer;
+  IREE_ASSERT_OK(CreateHostVisibleDispatchBuffer(
+      test_device.allocator(), /*buffer_size=*/4 * sizeof(uint32_t),
+      input_buffer.out()));
+  const uint32_t input_values[4] = {1, 2, 3, 4};
+  IREE_ASSERT_OK(iree_hal_buffer_map_write(input_buffer, /*target_offset=*/0,
+                                           input_values, sizeof(input_values)));
+
+  Ref<iree_hal_buffer_t> output_buffer;
+  IREE_ASSERT_OK(CreateHostVisibleDispatchBuffer(
+      test_device.allocator(), /*buffer_size=*/4 * sizeof(uint32_t),
+      output_buffer.out()));
+  IREE_ASSERT_OK(iree_hal_buffer_map_zero(output_buffer, /*offset=*/0,
+                                          IREE_HAL_WHOLE_BUFFER));
+
+  iree_hal_buffer_ref_t binding_refs[2] = {
+      iree_hal_make_buffer_ref(input_buffer, /*offset=*/0,
+                               iree_hal_buffer_byte_length(input_buffer)),
+      iree_hal_make_indirect_buffer_ref(
+          /*buffer_slot=*/3, /*offset=*/0,
+          iree_hal_buffer_byte_length(output_buffer)),
+  };
+  const iree_hal_buffer_ref_list_t dispatch_bindings = {
+      /*count=*/IREE_ARRAYSIZE(binding_refs),
+      /*values=*/binding_refs,
+  };
+  const uint32_t constant_values[2] = {3, 10};
+  iree_const_byte_span_t constants =
+      iree_make_const_byte_span(constant_values, sizeof(constant_values));
+
+  Ref<iree_hal_command_buffer_t> command_buffer;
+  IREE_ASSERT_OK(iree_hal_command_buffer_create(
+      test_device.base_device(),
+      IREE_HAL_COMMAND_BUFFER_MODE_DEFAULT |
+          IREE_HAL_COMMAND_BUFFER_MODE_RETAIN_PROFILE_METADATA,
+      IREE_HAL_COMMAND_CATEGORY_DISPATCH, IREE_HAL_QUEUE_AFFINITY_ANY,
+      /*binding_capacity=*/4, command_buffer.out()));
+  IREE_ASSERT_OK(iree_hal_command_buffer_begin(command_buffer));
+  IREE_ASSERT_OK(iree_hal_command_buffer_dispatch(
+      command_buffer, executable, /*entry_point=*/0,
+      iree_hal_make_static_dispatch_config(1, 1, 1), constants,
+      dispatch_bindings, IREE_HAL_DISPATCH_FLAG_NONE));
+  IREE_ASSERT_OK(iree_hal_command_buffer_end(command_buffer));
+
+  const iree_hal_amdgpu_aql_program_t* program =
+      iree_hal_amdgpu_aql_command_buffer_program(command_buffer);
+  ASSERT_NE(program, nullptr);
+  ASSERT_NE(program->first_block, nullptr);
+  EXPECT_GT(program->max_block_kernarg_length, 0u);
+  ASSERT_EQ(program->first_block->binding_source_count, 1u);
+  const iree_hal_amdgpu_command_buffer_binding_source_t* binding_source =
+      iree_hal_amdgpu_command_buffer_block_binding_sources_const(
+          program->first_block);
+  EXPECT_EQ(binding_source->flags,
+            IREE_HAL_AMDGPU_COMMAND_BUFFER_BINDING_SOURCE_FLAG_DYNAMIC);
+  EXPECT_EQ(binding_source->slot, 3u);
+  EXPECT_EQ(binding_source->target_binding_ordinal, 1u);
+
+  const iree_hal_amdgpu_command_buffer_command_header_t* command =
+      iree_hal_amdgpu_command_buffer_block_commands_const(program->first_block);
+  ASSERT_EQ(command->opcode, IREE_HAL_AMDGPU_COMMAND_BUFFER_OPCODE_DISPATCH);
+  const iree_hal_amdgpu_command_buffer_dispatch_command_t* dispatch_command =
+      (const iree_hal_amdgpu_command_buffer_dispatch_command_t*)command;
+  EXPECT_EQ(dispatch_command->kernarg_strategy,
+            IREE_HAL_AMDGPU_COMMAND_BUFFER_KERNARG_STRATEGY_PATCHED_TEMPLATE);
+  EXPECT_EQ(dispatch_command->payload.patch_source_count, 1u);
+  const iree_hal_amdgpu_profile_metadata_registry_t& profile_metadata =
+      test_device.logical_device()->profile_metadata;
+  ASSERT_EQ(profile_metadata.command_operation_record_count, 2u);
+  const iree_hal_profile_command_operation_record_t* dispatch_operation =
+      nullptr;
+  for (iree_host_size_t i = 0;
+       i < profile_metadata.command_operation_record_count; ++i) {
+    const iree_hal_profile_command_operation_record_t& operation =
+        profile_metadata.command_operation_records[i];
+    if (operation.type == IREE_HAL_PROFILE_COMMAND_OPERATION_TYPE_DISPATCH) {
+      dispatch_operation = &operation;
+      break;
+    }
+  }
+  ASSERT_NE(dispatch_operation, nullptr);
+  EXPECT_EQ(dispatch_operation->binding_count, 2u);
+  EXPECT_NE(dispatch_operation->flags &
+                IREE_HAL_PROFILE_COMMAND_OPERATION_FLAG_STATIC_BINDINGS,
+            0u);
+  EXPECT_NE(dispatch_operation->flags &
+                IREE_HAL_PROFILE_COMMAND_OPERATION_FLAG_DYNAMIC_BINDINGS,
+            0u);
+  const uint32_t kernarg_length =
+      (uint32_t)dispatch_command->kernarg_length_qwords * 8u;
+  EXPECT_NE(
+      iree_hal_amdgpu_aql_command_buffer_rodata(
+          command_buffer, dispatch_command->payload_reference, kernarg_length),
+      nullptr);
+
+  Ref<iree_hal_semaphore_t> command_buffer_signal;
+  IREE_ASSERT_OK(
+      CreateSemaphore(test_device.base_device(), command_buffer_signal.out()));
+  uint64_t command_buffer_signal_value = 1;
+  iree_hal_semaphore_t* command_buffer_signal_ptr = command_buffer_signal.get();
+  const iree_hal_semaphore_list_t command_buffer_signal_list = {
+      /*count=*/1,
+      /*semaphores=*/&command_buffer_signal_ptr,
+      /*payload_values=*/&command_buffer_signal_value,
+  };
+  iree_hal_buffer_binding_t bindings[4] = {
+      {
+          /*buffer=*/input_buffer.get(),
+          /*offset=*/0,
+          /*length=*/IREE_HAL_WHOLE_BUFFER,
+      },
+      {
+          /*buffer=*/input_buffer.get(),
+          /*offset=*/0,
+          /*length=*/IREE_HAL_WHOLE_BUFFER,
+      },
+      {
+          /*buffer=*/input_buffer.get(),
+          /*offset=*/0,
+          /*length=*/IREE_HAL_WHOLE_BUFFER,
+      },
+      {
+          /*buffer=*/output_buffer.get(),
+          /*offset=*/0,
+          /*length=*/IREE_HAL_WHOLE_BUFFER,
+      },
+  };
+  const iree_hal_buffer_binding_table_t binding_table = {
+      /*count=*/IREE_ARRAYSIZE(bindings),
+      /*bindings=*/bindings,
+  };
+  IREE_ASSERT_OK(iree_hal_device_queue_execute(
+      test_device.base_device(), IREE_HAL_QUEUE_AFFINITY_ANY,
+      iree_hal_semaphore_list_empty(), command_buffer_signal_list,
+      command_buffer, binding_table, IREE_HAL_EXECUTE_FLAG_NONE));
+  IREE_ASSERT_OK(iree_hal_semaphore_wait(
+      command_buffer_signal, command_buffer_signal_value,
+      iree_infinite_timeout(), IREE_ASYNC_WAIT_FLAG_NONE));
+
+  uint32_t output_values[4] = {0, 0, 0, 0};
+  IREE_ASSERT_OK(iree_hal_buffer_map_read(
+      output_buffer, /*offset=*/0, output_values, sizeof(output_values)));
+  const uint32_t expected_values[4] = {13, 16, 19, 22};
+  EXPECT_EQ(0, memcmp(output_values, expected_values, sizeof(expected_values)));
+
+  iree_hal_executable_release(executable);
+  iree_hal_executable_cache_release(executable_cache);
+}
+
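+// A dispatch whose bindings all come from binding-table slots should use the
+// dynamic-bindings kernarg strategy and record one binding source per slot
+// referenced.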
+TEST_F(HostQueueCommandBufferTest, DynamicDispatchUsesBindingTableSlots) {
+  iree_hal_amdgpu_logical_device_options_t options;
+  iree_hal_amdgpu_logical_device_options_initialize(&options);
+  options.preallocate_pools = 0;
+
+  TestLogicalDevice test_device;
+  IREE_ASSERT_OK(
+      test_device.Initialize(&options, &libhsa_, &topology_, host_allocator_));
+
+  iree_hal_executable_cache_t* executable_cache = NULL;
+  iree_hal_executable_t* executable = NULL;
+  IREE_ASSERT_OK(LoadCtsExecutable(
+      test_device.base_device(),
+      iree_make_cstring_view("command_buffer_dispatch_constants_bindings_test."
+                             "bin"),
+      &executable_cache, &executable));
+
+  Ref<iree_hal_buffer_t> input_buffer;
+  IREE_ASSERT_OK(CreateHostVisibleDispatchBuffer(
+      test_device.allocator(), /*buffer_size=*/4 * sizeof(uint32_t),
+      input_buffer.out()));
+  const uint32_t input_values[4] = {1, 2, 3, 4};
+  IREE_ASSERT_OK(iree_hal_buffer_map_write(input_buffer, /*target_offset=*/0,
+                                           input_values, sizeof(input_values)));
+
+  Ref<iree_hal_buffer_t> output_buffer;
+  IREE_ASSERT_OK(CreateHostVisibleDispatchBuffer(
+      test_device.allocator(), /*buffer_size=*/4 * sizeof(uint32_t),
+      output_buffer.out()));
+  IREE_ASSERT_OK(iree_hal_buffer_map_zero(output_buffer, /*offset=*/0,
+                                          IREE_HAL_WHOLE_BUFFER));
+
+  iree_hal_buffer_ref_t binding_refs[2] = {
+      iree_hal_make_indirect_buffer_ref(
+          /*buffer_slot=*/1, /*offset=*/0,
+          iree_hal_buffer_byte_length(input_buffer)),
+      iree_hal_make_indirect_buffer_ref(
+          /*buffer_slot=*/3, /*offset=*/0,
+          iree_hal_buffer_byte_length(output_buffer)),
+  };
+  const iree_hal_buffer_ref_list_t dispatch_bindings = {
+      /*count=*/IREE_ARRAYSIZE(binding_refs),
+      /*values=*/binding_refs,
+  };
+  const uint32_t constant_values[2] = {3, 10};
+  iree_const_byte_span_t constants =
+      iree_make_const_byte_span(constant_values, sizeof(constant_values));
+
+  Ref<iree_hal_command_buffer_t> command_buffer;
+  IREE_ASSERT_OK(iree_hal_command_buffer_create(
+      test_device.base_device(), IREE_HAL_COMMAND_BUFFER_MODE_DEFAULT,
+      IREE_HAL_COMMAND_CATEGORY_DISPATCH, IREE_HAL_QUEUE_AFFINITY_ANY,
+      /*binding_capacity=*/4, command_buffer.out()));
+  IREE_ASSERT_OK(iree_hal_command_buffer_begin(command_buffer));
+  IREE_ASSERT_OK(iree_hal_command_buffer_dispatch(
+      command_buffer, executable, /*entry_point=*/0,
+      iree_hal_make_static_dispatch_config(1, 1, 1), constants,
+      dispatch_bindings, IREE_HAL_DISPATCH_FLAG_NONE));
+  IREE_ASSERT_OK(iree_hal_command_buffer_end(command_buffer));
+
+  const iree_hal_amdgpu_aql_program_t* program =
+      iree_hal_amdgpu_aql_command_buffer_program(command_buffer);
+  ASSERT_NE(program, nullptr);
+  ASSERT_NE(program->first_block, nullptr);
+  ASSERT_EQ(program->first_block->binding_source_count, 2u);
+  const iree_hal_amdgpu_command_buffer_binding_source_t* binding_sources =
+      iree_hal_amdgpu_command_buffer_block_binding_sources_const(
+          program->first_block);
+  EXPECT_EQ(binding_sources[0].flags,
+            IREE_HAL_AMDGPU_COMMAND_BUFFER_BINDING_SOURCE_FLAG_DYNAMIC);
+  EXPECT_EQ(binding_sources[0].slot, 1u);
+  EXPECT_EQ(binding_sources[0].target_binding_ordinal, 0u);
+  EXPECT_EQ(binding_sources[1].flags,
+            IREE_HAL_AMDGPU_COMMAND_BUFFER_BINDING_SOURCE_FLAG_DYNAMIC);
+  EXPECT_EQ(binding_sources[1].slot, 3u);
+  EXPECT_EQ(binding_sources[1].target_binding_ordinal, 1u);
+
+  const iree_hal_amdgpu_command_buffer_command_header_t* command =
+      iree_hal_amdgpu_command_buffer_block_commands_const(program->first_block);
+  ASSERT_EQ(command->opcode, IREE_HAL_AMDGPU_COMMAND_BUFFER_OPCODE_DISPATCH);
+  const iree_hal_amdgpu_command_buffer_dispatch_command_t* dispatch_command =
+      (const iree_hal_amdgpu_command_buffer_dispatch_command_t*)command;
+  EXPECT_EQ(dispatch_command->kernarg_strategy,
+            IREE_HAL_AMDGPU_COMMAND_BUFFER_KERNARG_STRATEGY_DYNAMIC_BINDINGS);
+
+  Ref<iree_hal_semaphore_t> command_buffer_signal;
+  IREE_ASSERT_OK(
+      CreateSemaphore(test_device.base_device(), command_buffer_signal.out()));
+  uint64_t command_buffer_signal_value = 1;
+  iree_hal_semaphore_t* command_buffer_signal_ptr = command_buffer_signal.get();
+  const iree_hal_semaphore_list_t command_buffer_signal_list = {
+      /*count=*/1,
+      /*semaphores=*/&command_buffer_signal_ptr,
+      /*payload_values=*/&command_buffer_signal_value,
+  };
+  iree_hal_buffer_binding_t bindings[4] = {
+      {
+          /*buffer=*/input_buffer.get(),
+          /*offset=*/0,
+          /*length=*/IREE_HAL_WHOLE_BUFFER,
+      },
+      {
+          /*buffer=*/input_buffer.get(),
+          /*offset=*/0,
+          /*length=*/IREE_HAL_WHOLE_BUFFER,
+      },
+      {
+          /*buffer=*/input_buffer.get(),
+          /*offset=*/0,
+          /*length=*/IREE_HAL_WHOLE_BUFFER,
+      },
+      {
+          /*buffer=*/output_buffer.get(),
+          /*offset=*/0,
+          /*length=*/IREE_HAL_WHOLE_BUFFER,
+      },
+  };
+  const iree_hal_buffer_binding_table_t binding_table = {
+      /*count=*/IREE_ARRAYSIZE(bindings),
+      /*bindings=*/bindings,
+  };
+  IREE_ASSERT_OK(iree_hal_device_queue_execute(
+      test_device.base_device(), IREE_HAL_QUEUE_AFFINITY_ANY,
+      iree_hal_semaphore_list_empty(), command_buffer_signal_list,
+      command_buffer, binding_table, IREE_HAL_EXECUTE_FLAG_NONE));
+  IREE_ASSERT_OK(iree_hal_semaphore_wait(
+      command_buffer_signal, command_buffer_signal_value,
+      iree_infinite_timeout(), IREE_ASYNC_WAIT_FLAG_NONE));
+
+  uint32_t output_values[4] = {0, 0, 0, 0};
+  IREE_ASSERT_OK(iree_hal_buffer_map_read(
+      output_buffer, /*offset=*/0, output_values, sizeof(output_values)));
+  const uint32_t expected_values[4] = {13, 16, 19, 22};
+  EXPECT_EQ(0, memcmp(output_values, expected_values, sizeof(expected_values)));
+
+  iree_hal_executable_release(executable);
+  iree_hal_executable_cache_release(executable_cache);
+}
+
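+// A command buffer recorded against one physical device's queue affinity
+// must be rejected when submitted to another physical device's queues.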
+TEST_F(HostQueueCommandBufferTest,
+       CommandBufferRejectsCrossPhysicalDeviceQueue) {
+  if (topology_.gpu_agent_count < 2) {
+    GTEST_SKIP() << "fewer than two compatible GPU agents";
+    return;
+  }
+
+  iree_hal_amdgpu_logical_device_options_t options;
+  iree_hal_amdgpu_logical_device_options_initialize(&options);
+  options.preallocate_pools = 0;
+
+  TestLogicalDevice test_device;
+  IREE_ASSERT_OK(
+      test_device.Initialize(&options, &libhsa_, &topology_, host_allocator_));
+  ASSERT_GE(test_device.logical_device()->physical_device_count, 2u);
+
+  iree_hal_queue_affinity_t device0_affinity = 0;
+  IREE_ASSERT_OK(
+      QueueAffinityForPhysicalDevice(test_device, 0, &device0_affinity));
+  iree_hal_queue_affinity_t device1_affinity = 0;
+  IREE_ASSERT_OK(
+      QueueAffinityForPhysicalDevice(test_device, 1, &device1_affinity));
+
+  Ref<iree_hal_command_buffer_t> command_buffer;
+  IREE_ASSERT_OK(iree_hal_command_buffer_create(
+      test_device.base_device(), IREE_HAL_COMMAND_BUFFER_MODE_ONE_SHOT,
+      IREE_HAL_COMMAND_CATEGORY_DISPATCH, device0_affinity,
+      /*binding_capacity=*/0, command_buffer.out()));
+  IREE_ASSERT_OK(iree_hal_command_buffer_begin(command_buffer));
+  IREE_ASSERT_OK(iree_hal_command_buffer_end(command_buffer));
+  IREE_EXPECT_STATUS_IS(
+      IREE_STATUS_INVALID_ARGUMENT,
+      iree_hal_device_queue_execute(
+          test_device.base_device(), device1_affinity,
+          iree_hal_semaphore_list_empty(), iree_hal_semaphore_list_empty(),
+          command_buffer, iree_hal_buffer_binding_table_empty(),
+          IREE_HAL_EXECUTE_FLAG_NONE));
+}
+
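+// A queue-event session on a multi-device logical device should emit one
+// device record per physical device and one queue record per hardware queue.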
+TEST_F(HostQueueCommandBufferTest,
+       ProfilingMetadataCoversMultiplePhysicalDevices) {
+  if (topology_.gpu_agent_count < 2) {
+    GTEST_SKIP() << "fewer than two compatible GPU agents";
+    return;
+  }
+
+  iree_hal_amdgpu_logical_device_options_t options;
+  iree_hal_amdgpu_logical_device_options_initialize(&options);
+  options.preallocate_pools = 0;
+
+  TestLogicalDevice test_device;
+  IREE_ASSERT_OK(
+      test_device.Initialize(&options, &libhsa_, &topology_, host_allocator_));
+  iree_hal_amdgpu_logical_device_t* logical_device =
+      test_device.logical_device();
+  ASSERT_GE(logical_device->physical_device_count, 2u);
+  ASSERT_GT(logical_device->system->topology.gpu_agent_queue_count, 0u);
+
+  CommandBufferProfileSink sink = {};
+  CommandBufferProfileSinkInitialize(&sink);
+  DeviceProfilingScope profiling(test_device.base_device());
+  IREE_ASSERT_OK(profiling.Begin(IREE_HAL_DEVICE_PROFILING_DATA_QUEUE_EVENTS,
+                                 CommandBufferProfileSinkAsBase(&sink)));
+  IREE_ASSERT_OK(profiling.End());
+
+  const iree_host_size_t physical_device_count =
+      logical_device->physical_device_count;
+  const iree_host_size_t queue_count_per_physical_device =
+      logical_device->system->topology.gpu_agent_queue_count;
+  EXPECT_EQ(1, sink.device_metadata_count);
+  EXPECT_EQ(1, sink.queue_metadata_count);
+  ASSERT_EQ(physical_device_count, sink.device_records.size());
+  ASSERT_EQ(physical_device_count * queue_count_per_physical_device,
+            sink.queue_records.size());
+  EXPECT_TRUE(sink.clock_correlations.empty());
+
+  for (const auto& device_record : sink.device_records) {
+    EXPECT_LT(device_record.physical_device_ordinal, physical_device_count);
+    EXPECT_EQ(queue_count_per_physical_device, device_record.queue_count);
+  }
+  for (const auto& queue_record : sink.queue_records) {
+    EXPECT_LT(queue_record.physical_device_ordinal, physical_device_count);
+    EXPECT_LT(queue_record.queue_ordinal, queue_count_per_physical_device);
+  }
+}
+
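+// Beginning profiling without a sink is invalid; flush/end on a device with
+// no active session should still succeed.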
+TEST_F(HostQueueCommandBufferTest, SinklessProfilingBeginFails) {
+  iree_hal_amdgpu_logical_device_options_t options;
+  iree_hal_amdgpu_logical_device_options_initialize(&options);
+  options.preallocate_pools = 0;
+
+  TestLogicalDevice test_device;
+  IREE_ASSERT_OK(
+      test_device.Initialize(&options, &libhsa_, &topology_, host_allocator_));
+
+  iree_hal_device_profiling_options_t profiling_options = {0};
+  profiling_options.data_families =
+      IREE_HAL_DEVICE_PROFILING_DATA_DISPATCH_EVENTS;
+  IREE_ASSERT_STATUS_IS(IREE_STATUS_INVALID_ARGUMENT,
+                        iree_hal_device_profiling_begin(
+                            test_device.base_device(), &profiling_options));
+  IREE_EXPECT_OK(iree_hal_device_profiling_flush(test_device.base_device()));
+  IREE_EXPECT_OK(iree_hal_device_profiling_end(test_device.base_device()));
+}
+
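+// Only one profiling session may be active at a time: a nested begin must
+// fail without touching the nested sink, and the original session must still
+// end cleanly.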
+TEST_F(HostQueueCommandBufferTest, NestedProfilingBeginFails) {
+  iree_hal_amdgpu_logical_device_options_t options;
+  iree_hal_amdgpu_logical_device_options_initialize(&options);
+  options.preallocate_pools = 0;
+
+  TestLogicalDevice test_device;
+  IREE_ASSERT_OK(
+      test_device.Initialize(&options, &libhsa_, &topology_, host_allocator_));
+
+  CommandBufferProfileSink first_sink = {};
+  CommandBufferProfileSinkInitialize(&first_sink);
+  DeviceProfilingScope profiling(test_device.base_device());
+  IREE_ASSERT_OK(profiling.Begin(IREE_HAL_DEVICE_PROFILING_DATA_QUEUE_EVENTS,
+                                 CommandBufferProfileSinkAsBase(&first_sink)));
+
+  CommandBufferProfileSink nested_sink = {};
+  CommandBufferProfileSinkInitialize(&nested_sink);
+  iree_hal_device_profiling_options_t nested_options = {0};
+  nested_options.data_families = IREE_HAL_DEVICE_PROFILING_DATA_QUEUE_EVENTS;
+  nested_options.sink = CommandBufferProfileSinkAsBase(&nested_sink);
+  IREE_EXPECT_STATUS_IS(IREE_STATUS_FAILED_PRECONDITION,
+                        iree_hal_device_profiling_begin(
+                            test_device.base_device(), &nested_options));
+  EXPECT_EQ(0, nested_sink.begin_count);
+  EXPECT_EQ(0, nested_sink.end_count);
+
+  IREE_ASSERT_OK(profiling.End());
+  EXPECT_EQ(1, first_sink.begin_count);
+  EXPECT_EQ(1, first_sink.end_count);
+  ExpectQueueEventProfilingCanBeginAndEnd(&test_device);
+}
+
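+// A sink begin_session failure must abort the session without calling
+// end_session and leave the device ready for a fresh session.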
+TEST_F(HostQueueCommandBufferTest, ProfilingBeginSinkBeginFailureAllowsRetry) {
+  iree_hal_amdgpu_logical_device_options_t options;
+  iree_hal_amdgpu_logical_device_options_initialize(&options);
+  options.preallocate_pools = 0;
+
+  TestLogicalDevice test_device;
+  IREE_ASSERT_OK(
+      test_device.Initialize(&options, &libhsa_, &topology_, host_allocator_));
+
+  CommandBufferProfileSink sink = {};
+  CommandBufferProfileSinkInitialize(&sink);
+  sink.fail_begin_session_status_code = IREE_STATUS_RESOURCE_EXHAUSTED;
+  DeviceProfilingScope profiling(test_device.base_device());
+  IREE_EXPECT_STATUS_IS(
+      IREE_STATUS_RESOURCE_EXHAUSTED,
+      profiling.Begin(IREE_HAL_DEVICE_PROFILING_DATA_QUEUE_EVENTS,
+                      CommandBufferProfileSinkAsBase(&sink)));
+  EXPECT_EQ(1, sink.begin_count);
+  EXPECT_EQ(0, sink.end_count);
+
+  ExpectQueueEventProfilingCanBeginAndEnd(&test_device);
+}
+
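+// If the initial metadata write fails, the session must be ended with the
+// failing status so the sink can clean up, and the device must remain
+// retryable.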
+TEST_F(HostQueueCommandBufferTest,
+       ProfilingBeginMetadataWriteFailureEndsSessionAndAllowsRetry) {
+  iree_hal_amdgpu_logical_device_options_t options;
+  iree_hal_amdgpu_logical_device_options_initialize(&options);
+  options.preallocate_pools = 0;
+
+  TestLogicalDevice test_device;
+  IREE_ASSERT_OK(
+      test_device.Initialize(&options, &libhsa_, &topology_, host_allocator_));
+
+  CommandBufferProfileSink sink = {};
+  CommandBufferProfileSinkInitialize(&sink);
+  sink.fail_write_content_type = IREE_HAL_PROFILE_CONTENT_TYPE_DEVICES;
+  sink.fail_write_remaining = 1;
+  sink.fail_write_status_code = IREE_STATUS_DATA_LOSS;
+  sink.expected_end_session_status_code = IREE_STATUS_DATA_LOSS;
+  DeviceProfilingScope profiling(test_device.base_device());
+  IREE_EXPECT_STATUS_IS(
+      IREE_STATUS_DATA_LOSS,
+      profiling.Begin(IREE_HAL_DEVICE_PROFILING_DATA_QUEUE_EVENTS,
+                      CommandBufferProfileSinkAsBase(&sink)));
+  EXPECT_EQ(1, sink.begin_count);
+  EXPECT_EQ(1, sink.end_count);
+  EXPECT_EQ(IREE_STATUS_DATA_LOSS, sink.observed_end_session_status_code);
+  EXPECT_EQ(0, sink.device_metadata_count);
+
+  ExpectQueueEventProfilingCanBeginAndEnd(&test_device);
+}
+
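+// A failed flush write must not consume the buffered queue events; a retry
+// flush should deliver them intact.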
+TEST_F(HostQueueCommandBufferTest,
+       ProfilingFlushWriteFailurePreservesQueueEventsForRetry) {
+  iree_hal_amdgpu_logical_device_options_t options;
+  iree_hal_amdgpu_logical_device_options_initialize(&options);
+  options.preallocate_pools = 0;
+
+  TestLogicalDevice test_device;
+  IREE_ASSERT_OK(
+      test_device.Initialize(&options, &libhsa_, &topology_, host_allocator_));
+
+  CommandBufferProfileSink sink = {};
+  CommandBufferProfileSinkInitialize(&sink);
+  DeviceProfilingScope profiling(test_device.base_device());
+  IREE_ASSERT_OK(profiling.Begin(IREE_HAL_DEVICE_PROFILING_DATA_QUEUE_EVENTS,
+                                 CommandBufferProfileSinkAsBase(&sink)));
+
+  IREE_ASSERT_OK(SubmitProfiledQueueFill(&test_device));
+  sink.fail_write_content_type = IREE_HAL_PROFILE_CONTENT_TYPE_QUEUE_EVENTS;
+  sink.fail_write_remaining = 1;
+  sink.fail_write_status_code = IREE_STATUS_UNAVAILABLE;
+  IREE_EXPECT_STATUS_IS(
+      IREE_STATUS_UNAVAILABLE,
+      iree_hal_device_profiling_flush(test_device.base_device()));
+  EXPECT_EQ(0, sink.queue_event_count);
+  EXPECT_TRUE(sink.queue_events.empty());
+
+  IREE_ASSERT_OK(iree_hal_device_profiling_flush(test_device.base_device()));
+  IREE_ASSERT_OK(profiling.End());
+  EXPECT_EQ(1, sink.queue_event_count);
+  ASSERT_EQ(1u, sink.queue_events.size());
+  EXPECT_EQ(IREE_HAL_PROFILE_QUEUE_EVENT_TYPE_FILL, sink.queue_events[0].type);
+}
+
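+// Overfilling the queue-event ring by one record drops exactly one event and
+// the drop is surfaced through the dropped-record/truncated-chunk counters.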
+TEST_F(HostQueueCommandBufferTest,
+       ProfiledQueueEventsReportDroppedRecordsWhenRingFull) {
+  iree_hal_amdgpu_logical_device_options_t options;
+  iree_hal_amdgpu_logical_device_options_initialize(&options);
+  options.preallocate_pools = 0;
+
+  TestLogicalDevice test_device;
+  IREE_ASSERT_OK(
+      test_device.Initialize(&options, &libhsa_, &topology_, host_allocator_));
+
+  CommandBufferProfileSink sink = {};
+  CommandBufferProfileSinkInitialize(&sink);
+  DeviceProfilingScope profiling(test_device.base_device());
+  IREE_ASSERT_OK(profiling.Begin(IREE_HAL_DEVICE_PROFILING_DATA_QUEUE_EVENTS,
+                                 CommandBufferProfileSinkAsBase(&sink)));
+
+  const iree_host_size_t event_capacity =
+      test_device.logical_device()
+          ->profiling.event_streams.queue.stream.ring.capacity;
+  ASSERT_GT(event_capacity, 0u);
+  for (iree_host_size_t i = 0; i <= event_capacity; ++i) {
+    iree_hal_profile_queue_event_t event =
+        iree_hal_profile_queue_event_default();
+    event.type = IREE_HAL_PROFILE_QUEUE_EVENT_TYPE_FILL;
+    event.physical_device_ordinal = 0;
+    event.queue_ordinal = 0;
+    event.operation_count = 1;
+    iree_hal_amdgpu_logical_device_record_profile_queue_event(
+        test_device.base_device(), &event);
+  }
+
+  IREE_ASSERT_OK(profiling.End());
+  EXPECT_EQ(1, sink.queue_event_count);
+  EXPECT_EQ(event_capacity, sink.queue_events.size());
+  EXPECT_EQ(1u, sink.queue_event_dropped_record_count);
+  EXPECT_EQ(1u, sink.dropped_record_count);
+  EXPECT_EQ(1, sink.truncated_chunk_count);
+}
+
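+// Same drop accounting for the memory-event ring; the recorder additionally
+// reports per-event whether it accepted the record.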
+TEST_F(HostQueueCommandBufferTest,
+       ProfiledMemoryEventsReportDroppedRecordsWhenRingFull) {
+  iree_hal_amdgpu_logical_device_options_t options;
+  iree_hal_amdgpu_logical_device_options_initialize(&options);
+  options.preallocate_pools = 0;
+
+  TestLogicalDevice test_device;
+  IREE_ASSERT_OK(
+      test_device.Initialize(&options, &libhsa_, &topology_, host_allocator_));
+
+  CommandBufferProfileSink sink = {};
+  CommandBufferProfileSinkInitialize(&sink);
+  DeviceProfilingScope profiling(test_device.base_device());
+  IREE_ASSERT_OK(profiling.Begin(IREE_HAL_DEVICE_PROFILING_DATA_MEMORY_EVENTS,
+                                 CommandBufferProfileSinkAsBase(&sink)));
+
+  const iree_host_size_t event_capacity =
+      test_device.logical_device()
+          ->profiling.event_streams.memory.stream.ring.capacity;
+  ASSERT_GT(event_capacity, 0u);
+  iree_host_size_t recorded_count = 0;
+  for (iree_host_size_t i = 0; i <= event_capacity; ++i) {
+    iree_hal_profile_memory_event_t event =
+        iree_hal_profile_memory_event_default();
+    event.type = IREE_HAL_PROFILE_MEMORY_EVENT_TYPE_BUFFER_ALLOCATE;
+    event.allocation_id = i + 1;
+    event.physical_device_ordinal = 0;
+    event.queue_ordinal = 0;
+    event.length = sizeof(uint32_t);
+    if (iree_hal_amdgpu_logical_device_record_profile_memory_event(
+            test_device.base_device(), &event)) {
+      ++recorded_count;
+    }
+  }
+
+  IREE_ASSERT_OK(profiling.End());
+  EXPECT_EQ(event_capacity, recorded_count);
+  EXPECT_EQ(1, sink.memory_event_count);
+  EXPECT_EQ(event_capacity, sink.memory_events.size());
+  EXPECT_EQ(1u, sink.memory_event_dropped_record_count);
+  EXPECT_EQ(1u, sink.dropped_record_count);
+  EXPECT_EQ(1, sink.truncated_chunk_count);
+}
+
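+// A failed command-buffer metadata write during flush must not advance the
+// metadata cursor; the retry flush re-emits the metadata.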
+TEST_F(HostQueueCommandBufferTest,
+       ProfilingFlushMetadataWriteFailurePreservesCursorForRetry) {
+  iree_hal_amdgpu_logical_device_options_t options;
+  iree_hal_amdgpu_logical_device_options_initialize(&options);
+  options.preallocate_pools = 0;
+
+  TestLogicalDevice test_device;
+  IREE_ASSERT_OK(
+      test_device.Initialize(&options, &libhsa_, &topology_, host_allocator_));
+
+  CommandBufferProfileSink sink = {};
+  CommandBufferProfileSinkInitialize(&sink);
+  DeviceProfilingScope profiling(test_device.base_device());
+  IREE_ASSERT_OK(profiling.Begin(IREE_HAL_DEVICE_PROFILING_DATA_QUEUE_EVENTS,
+                                 CommandBufferProfileSinkAsBase(&sink)));
+
+  Ref<iree_hal_command_buffer_t> command_buffer;
+  IREE_ASSERT_OK(iree_hal_command_buffer_create(
+      test_device.base_device(),
+      IREE_HAL_COMMAND_BUFFER_MODE_RETAIN_PROFILE_METADATA,
+      IREE_HAL_COMMAND_CATEGORY_TRANSFER, IREE_HAL_QUEUE_AFFINITY_ANY,
+      /*binding_capacity=*/0, command_buffer.out()));
+  IREE_ASSERT_OK(iree_hal_command_buffer_begin(command_buffer));
+  IREE_ASSERT_OK(iree_hal_command_buffer_end(command_buffer));
+
+  Ref<iree_hal_semaphore_t> signal;
+  IREE_ASSERT_OK(CreateSemaphore(test_device.base_device(), signal.out()));
+  uint64_t signal_value = 1;
+  iree_hal_semaphore_t* signal_ptr = signal.get();
+  const iree_hal_semaphore_list_t signal_list = {
+      /*count=*/1,
+      /*semaphores=*/&signal_ptr,
+      /*payload_values=*/&signal_value,
+  };
+  IREE_ASSERT_OK(iree_hal_device_queue_execute(
+      test_device.base_device(), IREE_HAL_QUEUE_AFFINITY_ANY,
+      iree_hal_semaphore_list_empty(), signal_list, command_buffer,
+      iree_hal_buffer_binding_table_empty(), IREE_HAL_EXECUTE_FLAG_NONE));
+  IREE_ASSERT_OK(iree_hal_semaphore_wait(signal, signal_value,
+                                         iree_infinite_timeout(),
+                                         IREE_ASYNC_WAIT_FLAG_NONE));
+
+  sink.fail_write_content_type = IREE_HAL_PROFILE_CONTENT_TYPE_COMMAND_BUFFERS;
+  sink.fail_write_remaining = 1;
+  sink.fail_write_status_code = IREE_STATUS_UNAVAILABLE;
+  IREE_EXPECT_STATUS_IS(
+      IREE_STATUS_UNAVAILABLE,
+      iree_hal_device_profiling_flush(test_device.base_device()));
+  EXPECT_EQ(0, sink.command_buffer_metadata_count);
+  EXPECT_TRUE(sink.command_buffer_ids.empty());
+
+  IREE_ASSERT_OK(iree_hal_device_profiling_flush(test_device.base_device()));
+  IREE_ASSERT_OK(profiling.End());
+  EXPECT_EQ(1, sink.command_buffer_metadata_count);
+  ASSERT_EQ(1u, sink.command_buffer_ids.size());
+}
+
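+// A write failure while end drains pending events must be returned from end
+// and passed to the sink's end-session callback as the session status.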
+TEST_F(HostQueueCommandBufferTest,
+       ProfilingEndWriteFailureReportsSessionStatusAndAllowsRetry) {
+  iree_hal_amdgpu_logical_device_options_t options;
+  iree_hal_amdgpu_logical_device_options_initialize(&options);
+  options.preallocate_pools = 0;
+
+  TestLogicalDevice test_device;
+  IREE_ASSERT_OK(
+      test_device.Initialize(&options, &libhsa_, &topology_, host_allocator_));
+
+  CommandBufferProfileSink sink = {};
+  CommandBufferProfileSinkInitialize(&sink);
+  DeviceProfilingScope profiling(test_device.base_device());
+  IREE_ASSERT_OK(profiling.Begin(IREE_HAL_DEVICE_PROFILING_DATA_QUEUE_EVENTS,
+                                 CommandBufferProfileSinkAsBase(&sink)));
+
+  IREE_ASSERT_OK(SubmitProfiledQueueFill(&test_device));
+  sink.fail_write_content_type = IREE_HAL_PROFILE_CONTENT_TYPE_QUEUE_EVENTS;
+  sink.fail_write_remaining = 1;
+  sink.fail_write_status_code = IREE_STATUS_DATA_LOSS;
+  sink.expected_end_session_status_code = IREE_STATUS_DATA_LOSS;
+  IREE_EXPECT_STATUS_IS(IREE_STATUS_DATA_LOSS, profiling.End());
+  EXPECT_EQ(1, sink.end_count);
+  EXPECT_EQ(IREE_STATUS_DATA_LOSS, sink.observed_end_session_status_code);
+  EXPECT_EQ(0, sink.queue_event_count);
+
+  ExpectQueueEventProfilingCanBeginAndEnd(&test_device);
+}
+
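+// If the sink's end-session callback itself fails, end propagates that status
+// but still clears session state (the session status handed to the callback
+// remains OK) so a new session can begin.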
+TEST_F(HostQueueCommandBufferTest,
+       ProfilingEndSessionFailureClearsStateAndAllowsRetry) {
+  iree_hal_amdgpu_logical_device_options_t options;
+  iree_hal_amdgpu_logical_device_options_initialize(&options);
+  options.preallocate_pools = 0;
+
+  TestLogicalDevice test_device;
+  IREE_ASSERT_OK(
+      test_device.Initialize(&options, &libhsa_, &topology_, host_allocator_));
+
+  CommandBufferProfileSink sink = {};
+  CommandBufferProfileSinkInitialize(&sink);
+  sink.fail_end_session_status_code = IREE_STATUS_ABORTED;
+  DeviceProfilingScope profiling(test_device.base_device());
+  IREE_ASSERT_OK(profiling.Begin(IREE_HAL_DEVICE_PROFILING_DATA_QUEUE_EVENTS,
+                                 CommandBufferProfileSinkAsBase(&sink)));
+
+  IREE_EXPECT_STATUS_IS(IREE_STATUS_ABORTED, profiling.End());
+  EXPECT_EQ(1, sink.end_count);
+  EXPECT_EQ(IREE_STATUS_OK, sink.observed_end_session_status_code);
+
+  ExpectQueueEventProfilingCanBeginAndEnd(&test_device);
+}
+
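+// End-to-end dispatch profiling: a two-dispatch command buffer emits the full
+// metadata set, three command operations (two dispatches plus the block
+// return), and dispatch events that cross-reference that metadata.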
+TEST_F(HostQueueCommandBufferTest, CommandBufferDispatchesEmitProfileEvents) {
+  iree_hal_amdgpu_logical_device_options_t options;
+  iree_hal_amdgpu_logical_device_options_initialize(&options);
+  options.preallocate_pools = 0;
+
+  TestLogicalDevice test_device;
+  IREE_ASSERT_OK(
+      test_device.Initialize(&options, &libhsa_, &topology_, host_allocator_));
+
+  CommandBufferProfileSink sink = {};
+  CommandBufferProfileSinkInitialize(&sink);
+  DeviceProfilingScope profiling(test_device.base_device());
+  iree_status_t profiling_status =
+      profiling.Begin(IREE_HAL_DEVICE_PROFILING_DATA_DISPATCH_EVENTS,
+                      CommandBufferProfileSinkAsBase(&sink));
+  if (IsProfilingUnsupported(profiling_status)) {
+    iree_status_free(profiling_status);
+    GTEST_SKIP() << "device profiling data family unsupported by backend";
+  }
+  IREE_ASSERT_OK(profiling_status);
+
+  TwoDispatchCommandBuffer fixture;
+  IREE_ASSERT_OK(CreateTwoDispatchCommandBuffer(&test_device, &fixture));
+
+  Ref<iree_hal_semaphore_t> command_buffer_signal;
+  IREE_ASSERT_OK(
+      CreateSemaphore(test_device.base_device(), command_buffer_signal.out()));
+  uint64_t command_buffer_signal_value = 1;
+  iree_hal_semaphore_t* command_buffer_signal_ptr = command_buffer_signal.get();
+  const iree_hal_semaphore_list_t command_buffer_signal_list = {
+      /*count=*/1,
+      /*semaphores=*/&command_buffer_signal_ptr,
+      /*payload_values=*/&command_buffer_signal_value,
+  };
+  IREE_ASSERT_OK(iree_hal_device_queue_execute(
+      test_device.base_device(), IREE_HAL_QUEUE_AFFINITY_ANY,
+      iree_hal_semaphore_list_empty(), command_buffer_signal_list,
+      fixture.command_buffer, iree_hal_buffer_binding_table_empty(),
+      IREE_HAL_EXECUTE_FLAG_NONE));
+  IREE_ASSERT_OK(iree_hal_semaphore_wait(
+      command_buffer_signal, command_buffer_signal_value,
+      iree_infinite_timeout(), IREE_ASYNC_WAIT_FLAG_NONE));
+
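+  // Back-to-back flushes must both succeed before the session ends.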
+  IREE_ASSERT_OK(iree_hal_device_profiling_flush(test_device.base_device()));
+  IREE_ASSERT_OK(iree_hal_device_profiling_flush(test_device.base_device()));
+  IREE_ASSERT_OK(profiling.End());
+
+  ExpectTwoDispatchOutputs(fixture);
+
+  EXPECT_EQ(1, sink.begin_count);
+  EXPECT_EQ(1, sink.end_count);
+  EXPECT_EQ(1, sink.device_metadata_count);
+  EXPECT_EQ(1, sink.queue_metadata_count);
+  EXPECT_EQ(1, sink.executable_metadata_count);
+  EXPECT_EQ(1, sink.executable_export_metadata_count);
+  EXPECT_EQ(1, sink.command_buffer_metadata_count);
+  EXPECT_EQ(1, sink.command_operation_metadata_count);
+  EXPECT_GE(sink.clock_correlation_count, 2);
+  EXPECT_FALSE(sink.write_after_end);
+  ASSERT_EQ(3u, sink.command_operations.size());
+  uint32_t dispatch_operation_count = 0;
+  uint32_t return_operation_count = 0;
+  for (const auto& operation : sink.command_operations) {
+    EXPECT_EQ(sink.command_buffer_ids[0], operation.command_buffer_id);
+    switch (operation.type) {
+      case IREE_HAL_PROFILE_COMMAND_OPERATION_TYPE_DISPATCH:
+        EXPECT_TRUE(
+            iree_hal_profile_command_operation_has_block_structure(&operation));
+        EXPECT_EQ(dispatch_operation_count, operation.command_index);
+        EXPECT_NE(0u, operation.executable_id);
+        EXPECT_NE(
+            sink.executable_ids.end(),
+            std::find(sink.executable_ids.begin(), sink.executable_ids.end(),
+                      operation.executable_id));
+        EXPECT_EQ(0u, operation.export_ordinal);
+        EXPECT_EQ(2u, operation.binding_count);
+        EXPECT_EQ(1u, operation.workgroup_count[0]);
+        EXPECT_EQ(1u, operation.workgroup_count[1]);
+        EXPECT_EQ(1u, operation.workgroup_count[2]);
+        EXPECT_NE(0u, operation.workgroup_size[0]);
+        EXPECT_NE(0u,
+                  operation.flags &
+                      IREE_HAL_PROFILE_COMMAND_OPERATION_FLAG_STATIC_BINDINGS);
+        ++dispatch_operation_count;
+        break;
+      case IREE_HAL_PROFILE_COMMAND_OPERATION_TYPE_RETURN:
+        EXPECT_TRUE(
+            iree_hal_profile_command_operation_has_block_structure(&operation));
+        EXPECT_EQ(2u, operation.command_index);
+        EXPECT_NE(0u, operation.flags &
+                          IREE_HAL_PROFILE_COMMAND_OPERATION_FLAG_CONTROL_FLOW);
+        ++return_operation_count;
+        break;
+      default:
+        FAIL() << "unexpected command operation type " << operation.type;
+    }
+  }
+  EXPECT_EQ(2u, dispatch_operation_count);
+  EXPECT_EQ(1u, return_operation_count);
+  ASSERT_EQ(2u, sink.dispatch_events.size());
+  for (iree_host_size_t i = 0; i < sink.dispatch_events.size(); ++i) {
+    const iree_hal_profile_dispatch_event_t& event = sink.dispatch_events[i];
+    EXPECT_EQ(sizeof(iree_hal_profile_dispatch_event_t), event.record_length);
+    EXPECT_EQ(IREE_HAL_PROFILE_DISPATCH_EVENT_FLAG_COMMAND_BUFFER, event.flags);
+    EXPECT_NE(0u, event.event_id);
+    EXPECT_NE(0u, event.submission_id);
+    EXPECT_NE(0u, event.command_buffer_id);
+    EXPECT_NE(
+        sink.command_buffer_ids.end(),
+        std::find(sink.command_buffer_ids.begin(),
+                  sink.command_buffer_ids.end(), event.command_buffer_id));
+    EXPECT_NE(0u, event.executable_id);
+    EXPECT_NE(sink.executable_ids.end(),
+              std::find(sink.executable_ids.begin(), sink.executable_ids.end(),
+                        event.executable_id));
+    EXPECT_NE(sink.executable_export_ids.end(),
+              std::find(sink.executable_export_ids.begin(),
+                        sink.executable_export_ids.end(), event.executable_id));
+    EXPECT_EQ((uint32_t)i, event.command_index);
+    const iree_hal_profile_command_operation_record_t* operation =
+        FindCommandOperation(sink, event.command_buffer_id,
+                             event.command_index);
+    ASSERT_NE(nullptr, operation);
+    EXPECT_EQ(IREE_HAL_PROFILE_COMMAND_OPERATION_TYPE_DISPATCH,
+              operation->type);
+    EXPECT_EQ(event.executable_id, operation->executable_id);
+    EXPECT_EQ(event.export_ordinal, operation->export_ordinal);
+    EXPECT_EQ(0u, event.export_ordinal);
+    EXPECT_EQ(1u, event.workgroup_count[0]);
+    EXPECT_EQ(1u, event.workgroup_count[1]);
+    EXPECT_EQ(1u, event.workgroup_count[2]);
+    EXPECT_NE(0u, event.workgroup_size[0]);
+    EXPECT_NE(0u, event.start_tick);
+    EXPECT_NE(0u, event.end_tick);
+    EXPECT_GE(event.end_tick, event.start_tick);
+  }
+  ExpectDispatchEventsWithinClockCorrelationRange(sink);
+}
+
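+// Combining the queue, queue-device, and dispatch data families must yield one
+// host-side execute span, one matching device-side span, and relationship
+// records linking the submission to its dispatch and device events.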
+TEST_F(HostQueueCommandBufferTest,
+       CommandBufferExecuteEmitsQueueDeviceSpansAndRelationships) {
+  iree_hal_amdgpu_logical_device_options_t options;
+  iree_hal_amdgpu_logical_device_options_initialize(&options);
+  options.preallocate_pools = 0;
+
+  TestLogicalDevice test_device;
+  IREE_ASSERT_OK(
+      test_device.Initialize(&options, &libhsa_, &topology_, host_allocator_));
+
+  CommandBufferProfileSink sink = {};
+  CommandBufferProfileSinkInitialize(&sink);
+  DeviceProfilingScope profiling(test_device.base_device());
+  iree_status_t profiling_status =
+      profiling.Begin(IREE_HAL_DEVICE_PROFILING_DATA_QUEUE_EVENTS |
+                          IREE_HAL_DEVICE_PROFILING_DATA_DEVICE_QUEUE_EVENTS |
+                          IREE_HAL_DEVICE_PROFILING_DATA_DISPATCH_EVENTS,
+                      CommandBufferProfileSinkAsBase(&sink));
+  if (IsQueueDeviceProfilingUnavailable(profiling_status)) {
+    iree_status_free(profiling_status);
+    GTEST_SKIP() << "queue-device profiling data family unsupported by backend";
+  }
+  IREE_ASSERT_OK(profiling_status);
+
+  TwoDispatchCommandBuffer fixture;
+  IREE_ASSERT_OK(CreateTwoDispatchCommandBuffer(&test_device, &fixture));
+
+  Ref<iree_hal_semaphore_t> command_buffer_signal;
+  IREE_ASSERT_OK(
+      CreateSemaphore(test_device.base_device(), command_buffer_signal.out()));
+  uint64_t command_buffer_signal_value = 1;
+  iree_hal_semaphore_t* command_buffer_signal_ptr = command_buffer_signal.get();
+  const iree_hal_semaphore_list_t command_buffer_signal_list = {
+      /*count=*/1,
+      /*semaphores=*/&command_buffer_signal_ptr,
+      /*payload_values=*/&command_buffer_signal_value,
+  };
+  IREE_ASSERT_OK(iree_hal_device_queue_execute(
+      test_device.base_device(), IREE_HAL_QUEUE_AFFINITY_ANY,
+      iree_hal_semaphore_list_empty(), command_buffer_signal_list,
+      fixture.command_buffer, iree_hal_buffer_binding_table_empty(),
+      IREE_HAL_EXECUTE_FLAG_NONE));
+  IREE_ASSERT_OK(iree_hal_semaphore_wait(
+      command_buffer_signal, command_buffer_signal_value,
+      iree_infinite_timeout(), IREE_ASYNC_WAIT_FLAG_NONE));
+
+  IREE_ASSERT_OK(iree_hal_device_profiling_flush(test_device.base_device()));
+  IREE_ASSERT_OK(iree_hal_device_profiling_flush(test_device.base_device()));
+  IREE_ASSERT_OK(profiling.End());
+
+  ExpectTwoDispatchOutputs(fixture);
+
+  EXPECT_EQ(1, sink.begin_count);
+  EXPECT_EQ(1, sink.end_count);
+  EXPECT_FALSE(sink.write_after_end);
+  EXPECT_EQ(1, sink.queue_event_count);
+  EXPECT_EQ(1, sink.queue_device_event_count);
+  EXPECT_EQ(1, sink.relationship_count);
+  ASSERT_EQ(1u, sink.queue_events.size());
+  ASSERT_EQ(1u, sink.queue_device_events.size());
+  ASSERT_EQ(2u, sink.dispatch_events.size());
+  ASSERT_EQ(3u, sink.event_relationships.size());
+
+  const iree_hal_profile_queue_event_t& queue_event = sink.queue_events[0];
+  EXPECT_EQ(IREE_HAL_PROFILE_QUEUE_EVENT_TYPE_EXECUTE, queue_event.type);
+  EXPECT_EQ(IREE_HAL_PROFILE_QUEUE_DEPENDENCY_STRATEGY_NONE,
+            queue_event.dependency_strategy);
+  EXPECT_NE(0u, queue_event.submission_id);
+  EXPECT_NE(0u, queue_event.command_buffer_id);
+  EXPECT_EQ(0u, queue_event.wait_count);
+  EXPECT_EQ(1u, queue_event.signal_count);
+  EXPECT_EQ(0u, queue_event.barrier_count);
+  EXPECT_EQ(3u, queue_event.operation_count);
+  EXPECT_NE(
+      sink.command_buffer_ids.end(),
+      std::find(sink.command_buffer_ids.begin(), sink.command_buffer_ids.end(),
+                queue_event.command_buffer_id));
+
+  const iree_hal_profile_queue_device_event_t& device_event =
+      sink.queue_device_events[0];
+  EXPECT_EQ(IREE_HAL_PROFILE_QUEUE_EVENT_TYPE_EXECUTE, device_event.type);
+  EXPECT_EQ(queue_event.submission_id, device_event.submission_id);
+  EXPECT_EQ(queue_event.command_buffer_id, device_event.command_buffer_id);
+  EXPECT_EQ(queue_event.stream_id, device_event.stream_id);
+  EXPECT_EQ(queue_event.physical_device_ordinal,
+            device_event.physical_device_ordinal);
+  EXPECT_EQ(queue_event.queue_ordinal, device_event.queue_ordinal);
+  EXPECT_EQ(queue_event.operation_count, device_event.operation_count);
+
+  for (const auto& dispatch_event : sink.dispatch_events) {
+    EXPECT_EQ(queue_event.submission_id, dispatch_event.submission_id);
+    EXPECT_EQ(queue_event.command_buffer_id, dispatch_event.command_buffer_id);
+    EXPECT_NE(
+        nullptr,
+        FindEventRelationship(
+            sink,
+            IREE_HAL_PROFILE_EVENT_RELATIONSHIP_TYPE_QUEUE_SUBMISSION_DISPATCH,
+            IREE_HAL_PROFILE_EVENT_ENDPOINT_TYPE_QUEUE_SUBMISSION,
+            dispatch_event.submission_id,
+            IREE_HAL_PROFILE_EVENT_ENDPOINT_TYPE_DISPATCH_EVENT,
+            dispatch_event.event_id));
+  }
+  EXPECT_NE(
+      nullptr,
+      FindEventRelationship(
+          sink,
+          IREE_HAL_PROFILE_EVENT_RELATIONSHIP_TYPE_QUEUE_SUBMISSION_QUEUE_DEVICE_EVENT,
+          IREE_HAL_PROFILE_EVENT_ENDPOINT_TYPE_QUEUE_SUBMISSION,
+          device_event.submission_id,
+          IREE_HAL_PROFILE_EVENT_ENDPOINT_TYPE_QUEUE_DEVICE_EVENT,
+          device_event.event_id));
+  ExpectDispatchEventsWithinClockCorrelationRange(sink);
+}
+
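+// Hardware counter profiling (SQ_WAVES here) attaches one counter sample per
+// dispatch event, carrying the event's ids and device tick range.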
+TEST_F(HostQueueCommandBufferTest,
+       CommandBufferDispatchesEmitHardwareCounterSamples) {
+  iree_hal_amdgpu_logical_device_options_t options;
+  iree_hal_amdgpu_logical_device_options_initialize(&options);
+  options.preallocate_pools = 0;
+
+  TestLogicalDevice test_device;
+  IREE_ASSERT_OK(
+      test_device.Initialize(&options, &libhsa_, &topology_, host_allocator_));
+
+  CommandBufferProfileSink sink = {};
+  CommandBufferProfileSinkInitialize(&sink);
+  DeviceProfilingScope profiling(test_device.base_device());
+  iree_status_t profiling_status = BeginSqWavesProfiling(&profiling, &sink);
+  if (IsHardwareCounterProfilingUnavailable(profiling_status)) {
+    iree_status_free(profiling_status);
+    GTEST_SKIP() << "AMDGPU hardware counter profiling unavailable";
+  }
+  IREE_ASSERT_OK(profiling_status);
+
+  TwoDispatchCommandBuffer fixture;
+  IREE_ASSERT_OK(CreateTwoDispatchCommandBuffer(&test_device, &fixture));
+
+  Ref<iree_hal_semaphore_t> command_buffer_signal;
+  IREE_ASSERT_OK(
+      CreateSemaphore(test_device.base_device(), command_buffer_signal.out()));
+  uint64_t command_buffer_signal_value = 1;
+  iree_hal_semaphore_t* command_buffer_signal_ptr = command_buffer_signal.get();
+  const iree_hal_semaphore_list_t command_buffer_signal_list = {
+      /*count=*/1,
+      /*semaphores=*/&command_buffer_signal_ptr,
+      /*payload_values=*/&command_buffer_signal_value,
+  };
+  IREE_ASSERT_OK(iree_hal_device_queue_execute(
+      test_device.base_device(), IREE_HAL_QUEUE_AFFINITY_ANY,
+      iree_hal_semaphore_list_empty(), command_buffer_signal_list,
+      fixture.command_buffer, iree_hal_buffer_binding_table_empty(),
+      IREE_HAL_EXECUTE_FLAG_NONE));
+  IREE_ASSERT_OK(iree_hal_semaphore_wait(
+      command_buffer_signal, command_buffer_signal_value,
+      iree_infinite_timeout(), IREE_ASYNC_WAIT_FLAG_NONE));
+
+  IREE_ASSERT_OK(iree_hal_device_profiling_flush(test_device.base_device()));
+  IREE_ASSERT_OK(iree_hal_device_profiling_flush(test_device.base_device()));
+  IREE_ASSERT_OK(profiling.End());
+
+  ExpectTwoDispatchOutputs(fixture);
+
+  EXPECT_EQ(1, sink.begin_count);
+  EXPECT_EQ(1, sink.end_count);
+  EXPECT_EQ(1, sink.counter_set_metadata_count);
+  EXPECT_EQ(1, sink.counter_metadata_count);
+  EXPECT_GE(sink.counter_sample_count, 1);
+  ASSERT_EQ(2u, sink.dispatch_events.size());
+  ASSERT_EQ(sink.dispatch_events.size(), sink.counter_samples.size());
+  iree_host_size_t sample_value_count = 0;
+  for (iree_host_size_t i = 0; i < sink.counter_samples.size(); ++i) {
+    const iree_hal_profile_dispatch_event_t& event = sink.dispatch_events[i];
+    const iree_hal_profile_counter_sample_record_t& sample =
+        sink.counter_samples[i];
+    EXPECT_TRUE(iree_all_bits_set(
+        sample.flags,
+        IREE_HAL_PROFILE_COUNTER_SAMPLE_FLAG_DISPATCH_EVENT |
+            IREE_HAL_PROFILE_COUNTER_SAMPLE_FLAG_COMMAND_OPERATION |
+            IREE_HAL_PROFILE_COUNTER_SAMPLE_FLAG_DEVICE_TICK_RANGE));
+    EXPECT_EQ(IREE_HAL_PROFILE_COUNTER_SAMPLE_SCOPE_DISPATCH, sample.scope);
+    EXPECT_EQ(sample.dispatch_event_id, event.event_id);
+    EXPECT_EQ(sample.submission_id, event.submission_id);
+    EXPECT_EQ(sample.command_buffer_id, event.command_buffer_id);
+    EXPECT_EQ(sample.executable_id, event.executable_id);
+    EXPECT_EQ(sample.start_tick, event.start_tick);
+    EXPECT_EQ(sample.end_tick, event.end_tick);
+    EXPECT_EQ(sample.command_index, event.command_index);
+    EXPECT_EQ(sample.export_ordinal, event.export_ordinal);
+    sample_value_count += sample.sample_value_count;
+  }
+  ASSERT_EQ(sample_value_count, sink.counter_sample_values.size());
+  EXPECT_NE(sink.counter_sample_values.end(),
+            std::find_if(sink.counter_sample_values.begin(),
+                         sink.counter_sample_values.end(),
+                         [](uint64_t value) { return value != 0; }));
+}
+
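+// Device-time-range counter samples carry no dispatch/submission linkage and
+// are attributed to the dedicated profile queue, which here is the last host
+// queue ordinal on each physical device.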
+TEST_F(HostQueueCommandBufferTest,
+       CounterRangeSamplesUseProfileQueueWhenAvailable) {
+  iree_hal_amdgpu_logical_device_options_t options;
+  iree_hal_amdgpu_logical_device_options_initialize(&options);
+  options.preallocate_pools = 0;
+
+  TestLogicalDevice test_device;
+  IREE_ASSERT_OK(
+      test_device.Initialize(&options, &libhsa_, &topology_, host_allocator_));
+  iree_hal_amdgpu_logical_device_t* logical_device =
+      test_device.logical_device();
+  ASSERT_GT(logical_device->physical_device_count, 0u);
+  for (iree_host_size_t i = 0; i < logical_device->physical_device_count; ++i) {
+    ASSERT_GT(logical_device->physical_devices[i]->host_queue_count, 0u);
+  }
+
+  CommandBufferProfileSink sink = {};
+  CommandBufferProfileSinkInitialize(&sink);
+  DeviceProfilingScope profiling(test_device.base_device());
+  iree_status_t profiling_status =
+      BeginSqWavesCounterRangeProfiling(&profiling, &sink);
+  if (IsHardwareCounterProfilingUnavailable(profiling_status)) {
+    iree_status_free(profiling_status);
+    GTEST_SKIP() << "AMDGPU hardware counter range profiling unavailable";
+  }
+  IREE_ASSERT_OK(profiling_status);
+
+  IREE_ASSERT_OK(iree_hal_device_profiling_flush(test_device.base_device()));
+  IREE_ASSERT_OK(profiling.End());
+
+  EXPECT_EQ(1, sink.begin_count);
+  EXPECT_EQ(1, sink.end_count);
+  EXPECT_EQ(1, sink.counter_set_metadata_count);
+  EXPECT_EQ(1, sink.counter_metadata_count);
+  EXPECT_GE(sink.counter_sample_count, 1);
+  ASSERT_FALSE(sink.counter_samples.empty());
+  iree_host_size_t sample_value_count = 0;
+  for (const iree_hal_profile_counter_sample_record_t& sample :
+       sink.counter_samples) {
+    EXPECT_TRUE(iree_all_bits_set(
+        sample.flags, IREE_HAL_PROFILE_COUNTER_SAMPLE_FLAG_DEVICE_TICK_RANGE));
+    EXPECT_EQ(IREE_HAL_PROFILE_COUNTER_SAMPLE_SCOPE_DEVICE_TIME_RANGE,
+              sample.scope);
+    EXPECT_EQ(0u, sample.dispatch_event_id);
+    EXPECT_EQ(0u, sample.submission_id);
+    EXPECT_EQ(0u, sample.command_buffer_id);
+    EXPECT_EQ(0u, sample.executable_id);
+    EXPECT_EQ(UINT32_MAX, sample.command_index);
+    EXPECT_EQ(UINT32_MAX, sample.export_ordinal);
+    ASSERT_LT(sample.physical_device_ordinal,
+              logical_device->physical_device_count);
+    const iree_hal_amdgpu_physical_device_t* physical_device =
+        logical_device->physical_devices[sample.physical_device_ordinal];
+    const uint32_t expected_queue_ordinal =
+        (uint32_t)(physical_device->host_queue_count - 1);
+    EXPECT_EQ(sample.physical_device_ordinal, physical_device->device_ordinal);
+    EXPECT_EQ(expected_queue_ordinal, sample.queue_ordinal);
+    EXPECT_LT(sample.start_tick, sample.end_tick);
+    sample_value_count += sample.sample_value_count;
+  }
+  ASSERT_EQ(sample_value_count, sink.counter_sample_values.size());
+}
+
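+// Counter samples follow the same flush retry contract: a failed sample write
+// preserves the samples for the next flush.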
+TEST_F(HostQueueCommandBufferTest,
+       ProfilingFlushCounterSampleWriteFailurePreservesSamplesForRetry) {
+  iree_hal_amdgpu_logical_device_options_t options;
+  iree_hal_amdgpu_logical_device_options_initialize(&options);
+  options.preallocate_pools = 0;
+
+  TestLogicalDevice test_device;
+  IREE_ASSERT_OK(
+      test_device.Initialize(&options, &libhsa_, &topology_, host_allocator_));
+
+  CommandBufferProfileSink sink = {};
+  CommandBufferProfileSinkInitialize(&sink);
+  DeviceProfilingScope profiling(test_device.base_device());
+  iree_status_t profiling_status = BeginSqWavesProfiling(&profiling, &sink);
+  if (IsHardwareCounterProfilingUnavailable(profiling_status)) {
+    iree_status_free(profiling_status);
+    GTEST_SKIP() << "AMDGPU hardware counter profiling unavailable";
+  }
+  IREE_ASSERT_OK(profiling_status);
+
+  TwoDispatchCommandBuffer fixture;
+  IREE_ASSERT_OK(CreateTwoDispatchCommandBuffer(&test_device, &fixture));
+
+  Ref<iree_hal_semaphore_t> command_buffer_signal;
+  IREE_ASSERT_OK(
+      CreateSemaphore(test_device.base_device(), command_buffer_signal.out()));
+  uint64_t command_buffer_signal_value = 1;
+  iree_hal_semaphore_t* command_buffer_signal_ptr = command_buffer_signal.get();
+  const iree_hal_semaphore_list_t command_buffer_signal_list = {
+      /*count=*/1,
+      /*semaphores=*/&command_buffer_signal_ptr,
+      /*payload_values=*/&command_buffer_signal_value,
+  };
+  IREE_ASSERT_OK(iree_hal_device_queue_execute(
+      test_device.base_device(), IREE_HAL_QUEUE_AFFINITY_ANY,
+      iree_hal_semaphore_list_empty(), command_buffer_signal_list,
+      fixture.command_buffer, iree_hal_buffer_binding_table_empty(),
+      IREE_HAL_EXECUTE_FLAG_NONE));
+  IREE_ASSERT_OK(iree_hal_semaphore_wait(
+      command_buffer_signal, command_buffer_signal_value,
+      iree_infinite_timeout(), IREE_ASYNC_WAIT_FLAG_NONE));
+
+  sink.fail_write_content_type = IREE_HAL_PROFILE_CONTENT_TYPE_COUNTER_SAMPLES;
+  sink.fail_write_remaining = 1;
+  sink.fail_write_status_code = IREE_STATUS_UNAVAILABLE;
+  IREE_EXPECT_STATUS_IS(
+      IREE_STATUS_UNAVAILABLE,
+      iree_hal_device_profiling_flush(test_device.base_device()));
+  EXPECT_EQ(0, sink.counter_sample_count);
+  EXPECT_TRUE(sink.counter_samples.empty());
+
+  IREE_ASSERT_OK(iree_hal_device_profiling_flush(test_device.base_device()));
+  IREE_ASSERT_OK(profiling.End());
+
+  ExpectTwoDispatchOutputs(fixture);
+  EXPECT_GE(sink.counter_sample_count, 1);
+  ASSERT_EQ(2u, sink.counter_samples.size());
+  EXPECT_GE(sink.dispatch_events.size(), sink.counter_samples.size());
+}
+
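+// A command-index capture filter narrows dispatch events to the selected
+// index while the command buffer still executes in full.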
+TEST_F(HostQueueCommandBufferTest,
+       CommandBufferDispatchProfileFilterSelectsCommandIndex) {
+  iree_hal_amdgpu_logical_device_options_t options;
+  iree_hal_amdgpu_logical_device_options_initialize(&options);
+  options.preallocate_pools = 0;
+
+  TestLogicalDevice test_device;
+  IREE_ASSERT_OK(
+      test_device.Initialize(&options, &libhsa_, &topology_, host_allocator_));
+
+  CommandBufferProfileSink sink = {};
+  CommandBufferProfileSinkInitialize(&sink);
+  iree_hal_device_profiling_options_t profiling_options = {0};
+  profiling_options.data_families =
+      IREE_HAL_DEVICE_PROFILING_DATA_DISPATCH_EVENTS;
+  profiling_options.sink = CommandBufferProfileSinkAsBase(&sink);
+  profiling_options.capture_filter.flags =
+      IREE_HAL_PROFILE_CAPTURE_FILTER_FLAG_COMMAND_INDEX;
+  profiling_options.capture_filter.command_index = 1;
+  DeviceProfilingScope profiling(test_device.base_device());
+  iree_status_t profiling_status = profiling.Begin(&profiling_options);
+  if (IsProfilingUnsupported(profiling_status)) {
+    iree_status_free(profiling_status);
+    GTEST_SKIP() << "device profiling data family unsupported by backend";
+  }
+  IREE_ASSERT_OK(profiling_status);
+
+  TwoDispatchCommandBuffer fixture;
+  IREE_ASSERT_OK(CreateTwoDispatchCommandBuffer(&test_device, &fixture));
+
+  Ref<iree_hal_semaphore_t> command_buffer_signal;
+  IREE_ASSERT_OK(
+      CreateSemaphore(test_device.base_device(), command_buffer_signal.out()));
+  uint64_t command_buffer_signal_value = 1;
+  iree_hal_semaphore_t* command_buffer_signal_ptr = command_buffer_signal.get();
+  const iree_hal_semaphore_list_t command_buffer_signal_list = {
+      /*count=*/1,
+      /*semaphores=*/&command_buffer_signal_ptr,
+      /*payload_values=*/&command_buffer_signal_value,
+  };
+  IREE_ASSERT_OK(iree_hal_device_queue_execute(
+      test_device.base_device(), IREE_HAL_QUEUE_AFFINITY_ANY,
+      iree_hal_semaphore_list_empty(), command_buffer_signal_list,
+      fixture.command_buffer, iree_hal_buffer_binding_table_empty(),
+      IREE_HAL_EXECUTE_FLAG_NONE));
+  IREE_ASSERT_OK(iree_hal_semaphore_wait(
+      command_buffer_signal, command_buffer_signal_value,
+      iree_infinite_timeout(), IREE_ASYNC_WAIT_FLAG_NONE));
+
+  IREE_ASSERT_OK(iree_hal_device_profiling_flush(test_device.base_device()));
+  IREE_ASSERT_OK(profiling.End());
+
+  ExpectTwoDispatchOutputs(fixture);
+
+  ASSERT_EQ(1u, sink.dispatch_events.size());
+  const iree_hal_profile_dispatch_event_t& event = sink.dispatch_events[0];
+  EXPECT_EQ(IREE_HAL_PROFILE_DISPATCH_EVENT_FLAG_COMMAND_BUFFER, event.flags);
+  EXPECT_NE(0u, event.event_id);
+  EXPECT_NE(0u, event.submission_id);
+  EXPECT_NE(0u, event.command_buffer_id);
+  EXPECT_EQ(1u, event.command_index);
+  EXPECT_NE(0u, event.start_tick);
+  EXPECT_NE(0u, event.end_tick);
+  EXPECT_GE(event.end_tick, event.start_tick);
+}
+
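+// The executable-export pattern must be copied at begin so that the caller's
+// string can be mutated (or freed) while the session is live.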
+TEST_F(HostQueueCommandBufferTest,
+       DispatchProfileFilterCopiesExecutableExportPattern) {
+  iree_hal_amdgpu_logical_device_options_t options;
+  iree_hal_amdgpu_logical_device_options_initialize(&options);
+  options.preallocate_pools = 0;
+
+  TestLogicalDevice test_device;
+  IREE_ASSERT_OK(
+      test_device.Initialize(&options, &libhsa_, &topology_, host_allocator_));
+
+  CommandBufferProfileSink sink = {};
+  CommandBufferProfileSinkInitialize(&sink);
+  std::string export_pattern = "scale_*";
+  iree_hal_device_profiling_options_t profiling_options = {0};
+  profiling_options.data_families =
+      IREE_HAL_DEVICE_PROFILING_DATA_DISPATCH_EVENTS;
+  profiling_options.sink = CommandBufferProfileSinkAsBase(&sink);
+  profiling_options.capture_filter.flags =
+      IREE_HAL_PROFILE_CAPTURE_FILTER_FLAG_EXECUTABLE_EXPORT_PATTERN;
+  profiling_options.capture_filter.executable_export_pattern =
+      iree_make_string_view(export_pattern.data(), export_pattern.size());
+  DeviceProfilingScope profiling(test_device.base_device());
+  iree_status_t profiling_status = profiling.Begin(&profiling_options);
+  if (IsProfilingUnsupported(profiling_status)) {
+    iree_status_free(profiling_status);
+    GTEST_SKIP() << "device profiling data family unsupported by backend";
+  }
+  IREE_ASSERT_OK(profiling_status);
+
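+  // Clobber the caller-owned pattern string; matching below must still use
+  // the copy captured at begin.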
+  export_pattern.assign("nomatch");
+
+  iree_hal_executable_cache_t* executable_cache = NULL;
+  iree_hal_executable_t* executable = NULL;
+  IREE_ASSERT_OK(LoadCtsExecutable(
+      test_device.base_device(),
+      iree_make_cstring_view("command_buffer_dispatch_constants_bindings_test."
+                             "bin"),
+      &executable_cache, &executable));
+
+  const uint64_t executable_id =
+      iree_hal_amdgpu_executable_profile_id(executable);
+  EXPECT_TRUE(iree_hal_amdgpu_logical_device_should_profile_dispatch(
+      test_device.logical_device(), executable_id, /*export_ordinal=*/0,
+      /*command_buffer_id=*/0, /*command_index=*/0,
+      /*physical_device_ordinal=*/0, /*queue_ordinal=*/0));
+
+  IREE_ASSERT_OK(profiling.End());
+  iree_hal_executable_release(executable);
+  iree_hal_executable_cache_release(executable_cache);
+}
+
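+// Dispatch-event reservations are taken under the submission mutex: filling
+// the ring to capacity succeeds, one more reservation fails with
+// RESOURCE_EXHAUSTED, and cancellation returns the space.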
+TEST_F(HostQueueCommandBufferTest,
+       ProfiledDispatchReservationFailsWhenRingFull) {
+  iree_hal_amdgpu_logical_device_options_t options;
+  iree_hal_amdgpu_logical_device_options_initialize(&options);
+  options.preallocate_pools = 0;
+
+  TestLogicalDevice test_device;
+  IREE_ASSERT_OK(
+      test_device.Initialize(&options, &libhsa_, &topology_, host_allocator_));
+  iree_hal_amdgpu_host_queue_t* queue = test_device.first_host_queue();
+  ASSERT_NE(nullptr, queue);
+
+  CommandBufferProfileSink sink = {};
+  CommandBufferProfileSinkInitialize(&sink);
+  DeviceProfilingScope profiling(test_device.base_device());
+  iree_status_t profiling_status =
+      profiling.Begin(IREE_HAL_DEVICE_PROFILING_DATA_DISPATCH_EVENTS,
+                      CommandBufferProfileSinkAsBase(&sink));
+  if (IsProfilingUnsupported(profiling_status)) {
+    iree_status_free(profiling_status);
+    GTEST_SKIP() << "device profiling data family unsupported by backend";
+  }
+  IREE_ASSERT_OK(profiling_status);
+
+  iree_hal_amdgpu_profile_dispatch_event_reservation_t reservation = {0};
+  iree_hal_amdgpu_profile_dispatch_event_reservation_t exhausted_reservation = {
+      0};
+  const uint32_t dispatch_event_capacity =
+      iree_hal_amdgpu_host_queue_profile_dispatch_event_capacity(queue);
+  iree_slim_mutex_lock(&queue->locks.submission_mutex);
+  iree_status_t status =
+      iree_hal_amdgpu_host_queue_reserve_profile_dispatch_events(
+          queue, dispatch_event_capacity, &reservation);
+  if (iree_status_is_ok(status)) {
+    status = iree_hal_amdgpu_host_queue_reserve_profile_dispatch_events(
+        queue, 1, &exhausted_reservation);
+  }
+  iree_hal_amdgpu_host_queue_cancel_profile_dispatch_events(queue, reservation);
+  iree_slim_mutex_unlock(&queue->locks.submission_mutex);
+  IREE_ASSERT_STATUS_IS(IREE_STATUS_RESOURCE_EXHAUSTED, status);
+  EXPECT_EQ(dispatch_event_capacity, reservation.event_count);
+  EXPECT_EQ(0u, exhausted_reservation.event_count);
+
+  IREE_ASSERT_OK(profiling.End());
+}
+
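+// Same exhaustion contract for queue-device event reservations.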
+TEST_F(HostQueueCommandBufferTest,
+       ProfiledQueueDeviceReservationFailsWhenRingFull) {
+  iree_hal_amdgpu_logical_device_options_t options;
+  iree_hal_amdgpu_logical_device_options_initialize(&options);
+  options.preallocate_pools = 0;
+
+  TestLogicalDevice test_device;
+  IREE_ASSERT_OK(
+      test_device.Initialize(&options, &libhsa_, &topology_, host_allocator_));
+  iree_hal_amdgpu_host_queue_t* queue = test_device.first_host_queue();
+  ASSERT_NE(nullptr, queue);
+
+  CommandBufferProfileSink sink = {};
+  CommandBufferProfileSinkInitialize(&sink);
+  DeviceProfilingScope profiling(test_device.base_device());
+  iree_status_t profiling_status =
+      profiling.Begin(IREE_HAL_DEVICE_PROFILING_DATA_DEVICE_QUEUE_EVENTS,
+                      CommandBufferProfileSinkAsBase(&sink));
+  if (IsQueueDeviceProfilingUnavailable(profiling_status)) {
+    iree_status_free(profiling_status);
+    GTEST_SKIP() << "queue-device profiling data family unsupported by backend";
+  }
+  IREE_ASSERT_OK(profiling_status);
+
+  iree_hal_amdgpu_profile_queue_device_event_reservation_t reservation = {0};
+  iree_hal_amdgpu_profile_queue_device_event_reservation_t
+      exhausted_reservation = {0};
+  const uint32_t queue_device_event_capacity =
+      queue->profiling.queue_device_events.capacity;
+  ASSERT_GT(queue_device_event_capacity, 0u);
+  iree_slim_mutex_lock(&queue->locks.submission_mutex);
+  iree_status_t status =
+      iree_hal_amdgpu_host_queue_reserve_profile_queue_device_events(
+          queue, queue_device_event_capacity, &reservation);
+  if (iree_status_is_ok(status)) {
+    status = iree_hal_amdgpu_host_queue_reserve_profile_queue_device_events(
+        queue, 1, &exhausted_reservation);
+  }
+  iree_hal_amdgpu_host_queue_cancel_profile_queue_device_events(queue,
+                                                                reservation);
+  iree_slim_mutex_unlock(&queue->locks.submission_mutex);
+  IREE_ASSERT_STATUS_IS(IREE_STATUS_RESOURCE_EXHAUSTED, status);
+  EXPECT_EQ(queue_device_event_capacity, reservation.event_count);
+  EXPECT_EQ(0u, exhausted_reservation.event_count);
+
+  IREE_ASSERT_OK(profiling.End());
+}
+
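+// With more dispatches than AQL packet slots the replay must reuse slots;
+// every dispatch still gets a distinct profile event with valid ticks and the
+// computed outputs remain correct.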
+TEST_F(HostQueueCommandBufferTest,
+       ProfiledCommandBufferDispatchSignalsSurviveAqlSlotReuse) {
+  static constexpr uint32_t kAqlCapacity = 64;
+  static constexpr uint32_t kDispatchCount = kAqlCapacity + 32;
+
+  iree_hal_amdgpu_logical_device_options_t options;
+  iree_hal_amdgpu_logical_device_options_initialize(&options);
+  options.host_block_pools.command_buffer.usable_block_size =
+      IREE_HAL_AMDGPU_AQL_PROGRAM_MIN_BLOCK_SIZE;
+  options.host_queues.aql_capacity = kAqlCapacity;
+  options.host_queues.kernarg_capacity = 2 * kAqlCapacity;
+  options.preallocate_pools = 0;
+
+  TestLogicalDevice test_device;
+  IREE_ASSERT_OK(
+      test_device.Initialize(&options, &libhsa_, &topology_, host_allocator_));
+
+  CommandBufferProfileSink sink = {};
+  CommandBufferProfileSinkInitialize(&sink);
+  DeviceProfilingScope profiling(test_device.base_device());
+  iree_status_t profiling_status =
+      profiling.Begin(IREE_HAL_DEVICE_PROFILING_DATA_DISPATCH_EVENTS,
+                      CommandBufferProfileSinkAsBase(&sink));
+  if (IsProfilingUnsupported(profiling_status)) {
+    iree_status_free(profiling_status);
+    GTEST_SKIP() << "device profiling data family unsupported by backend";
+  }
+  IREE_ASSERT_OK(profiling_status);
+
+  iree_hal_executable_cache_t* executable_cache = NULL;
+  iree_hal_executable_t* executable = NULL;
+  IREE_ASSERT_OK(LoadCtsExecutable(
+      test_device.base_device(),
+      iree_make_cstring_view("command_buffer_dispatch_constants_bindings_test."
+                             "bin"),
+      &executable_cache, &executable));
+
+  Ref<iree_hal_buffer_t> input_buffer;
+  const uint32_t input_values[4] = {1, 2, 3, 4};
+  IREE_ASSERT_OK(CreateHostVisibleDispatchBuffer(
+      test_device.allocator(), /*buffer_size=*/4 * sizeof(uint32_t),
+      input_buffer.out()));
+  IREE_ASSERT_OK(iree_hal_buffer_map_write(input_buffer, /*target_offset=*/0,
+                                           input_values, sizeof(input_values)));
+
+  Ref<iree_hal_buffer_t> output_buffer;
+  IREE_ASSERT_OK(CreateHostVisibleDispatchBuffer(
+      test_device.allocator(), /*buffer_size=*/4 * sizeof(uint32_t),
+      output_buffer.out()));
+  IREE_ASSERT_OK(iree_hal_buffer_map_zero(output_buffer, /*offset=*/0,
+                                          IREE_HAL_WHOLE_BUFFER));
+
+  Ref<iree_hal_command_buffer_t> command_buffer;
+  IREE_ASSERT_OK(iree_hal_command_buffer_create(
+      test_device.base_device(),
+      IREE_HAL_COMMAND_BUFFER_MODE_ONE_SHOT |
+          IREE_HAL_COMMAND_BUFFER_MODE_RETAIN_PROFILE_METADATA,
+      IREE_HAL_COMMAND_CATEGORY_DISPATCH, IREE_HAL_QUEUE_AFFINITY_ANY,
+      /*binding_capacity=*/0, command_buffer.out()));
+  IREE_ASSERT_OK(iree_hal_command_buffer_begin(command_buffer));
+  iree_hal_buffer_ref_t binding_refs[2] = {
+      iree_hal_make_buffer_ref(input_buffer, /*offset=*/0,
+                               iree_hal_buffer_byte_length(input_buffer)),
+      iree_hal_make_buffer_ref(output_buffer, /*offset=*/0,
+                               iree_hal_buffer_byte_length(output_buffer)),
+  };
+  const iree_hal_buffer_ref_list_t bindings = {
+      /*count=*/IREE_ARRAYSIZE(binding_refs),
+      /*values=*/binding_refs,
+  };
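+  // Record more dispatches than the 64-slot AQL ring holds so replay has to
+  // reuse packet slots while dispatch profiling is active.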
+  for (uint32_t i = 0; i < kDispatchCount; ++i) {
+    IREE_ASSERT_OK(
+        AppendConstantsBindingsDispatch(command_buffer, executable, bindings));
+  }
+  IREE_ASSERT_OK(iree_hal_command_buffer_end(command_buffer));
+
+  Ref<iree_hal_semaphore_t> command_buffer_signal;
+  IREE_ASSERT_OK(
+      CreateSemaphore(test_device.base_device(), command_buffer_signal.out()));
+  uint64_t command_buffer_signal_value = 1;
+  iree_hal_semaphore_t* command_buffer_signal_ptr = command_buffer_signal.get();
+  const iree_hal_semaphore_list_t command_buffer_signal_list = {
+      /*count=*/1,
+      /*semaphores=*/&command_buffer_signal_ptr,
+      /*payload_values=*/&command_buffer_signal_value,
+  };
+  IREE_ASSERT_OK(iree_hal_device_queue_execute(
+      test_device.base_device(), IREE_HAL_QUEUE_AFFINITY_ANY,
+      iree_hal_semaphore_list_empty(), command_buffer_signal_list,
+      command_buffer, iree_hal_buffer_binding_table_empty(),
+      IREE_HAL_EXECUTE_FLAG_NONE));
+  IREE_ASSERT_OK(iree_hal_semaphore_wait(
+      command_buffer_signal, command_buffer_signal_value,
+      iree_infinite_timeout(), IREE_ASYNC_WAIT_FLAG_NONE));
+
+  IREE_ASSERT_OK(iree_hal_device_profiling_flush(test_device.base_device()));
+  IREE_ASSERT_OK(profiling.End());
+
+  const uint32_t expected_values[4] = {13, 16, 19, 22};
+  uint32_t output_values[4] = {0, 0, 0, 0};
+  IREE_ASSERT_OK(iree_hal_buffer_map_read(
+      output_buffer, /*offset=*/0, output_values, sizeof(output_values)));
+  EXPECT_EQ(0, memcmp(output_values, expected_values, sizeof(expected_values)));
+
+  EXPECT_EQ(1, sink.begin_count);
+  EXPECT_EQ(1, sink.end_count);
+  EXPECT_FALSE(sink.write_after_end);
+  ASSERT_EQ(kDispatchCount, sink.dispatch_events.size());
+  for (iree_host_size_t i = 0; i < sink.dispatch_events.size(); ++i) {
+    const iree_hal_profile_dispatch_event_t& event = sink.dispatch_events[i];
+    EXPECT_EQ(sizeof(iree_hal_profile_dispatch_event_t), event.record_length);
+    EXPECT_EQ(IREE_HAL_PROFILE_DISPATCH_EVENT_FLAG_COMMAND_BUFFER, event.flags);
+    EXPECT_NE(0u, event.event_id);
+    EXPECT_NE(0u, event.submission_id);
+    EXPECT_NE(0u, event.start_tick);
+    EXPECT_NE(0u, event.end_tick);
+    EXPECT_GE(event.end_tick, event.start_tick);
+  }
+
+  iree_hal_executable_release(executable);
+  iree_hal_executable_cache_release(executable_cache);
+}
+
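+// With a single notification slot and the hardware queue blocked, a
+// single-block replay must park as a post-drain action and resume to
+// completion once the blocker clears.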
+TEST_F(HostQueueCommandBufferTest,
+       SingleBlockCommandBufferParksAndResumesUnderNotificationPressure) {
+  static constexpr uint32_t kAqlCapacity = 64;
+  static constexpr uint32_t kNotificationCapacity = 1;
+  static constexpr uint32_t kKernargCapacity = 2 * kAqlCapacity;
+
+  iree_hal_amdgpu_logical_device_options_t options;
+  iree_hal_amdgpu_logical_device_options_initialize(&options);
+  options.host_block_pools.command_buffer.usable_block_size =
+      IREE_HAL_AMDGPU_AQL_PROGRAM_MIN_BLOCK_SIZE;
+  options.host_queues.aql_capacity = kAqlCapacity;
+  options.host_queues.notification_capacity = kNotificationCapacity;
+  options.host_queues.kernarg_capacity = kKernargCapacity;
+  options.preallocate_pools = 0;
+
+  TestLogicalDevice test_device;
+  IREE_ASSERT_OK(
+      test_device.Initialize(&options, &libhsa_, &topology_, host_allocator_));
+  iree_hal_amdgpu_host_queue_t* queue = test_device.first_host_queue();
+  ASSERT_NE(nullptr, queue);
+
+  Ref<iree_hal_buffer_t> pressure_buffer;
+  IREE_ASSERT_OK(CreateHostVisibleTransferBuffer(
+      test_device.allocator(), sizeof(uint32_t), pressure_buffer.out()));
+
+  Ref<iree_hal_buffer_t> target_buffer;
+  IREE_ASSERT_OK(CreateHostVisibleTransferBuffer(
+      test_device.allocator(), sizeof(uint32_t), target_buffer.out()));
+  IREE_ASSERT_OK(iree_hal_buffer_map_zero(target_buffer, /*offset=*/0,
+                                          IREE_HAL_WHOLE_BUFFER));
+
+  Ref<iree_hal_command_buffer_t> command_buffer;
+  IREE_ASSERT_OK(iree_hal_command_buffer_create(
+      test_device.base_device(), IREE_HAL_COMMAND_BUFFER_MODE_DEFAULT,
+      IREE_HAL_COMMAND_CATEGORY_TRANSFER, IREE_HAL_QUEUE_AFFINITY_ANY,
+      /*binding_capacity=*/1, command_buffer.out()));
+  IREE_ASSERT_OK(iree_hal_command_buffer_begin(command_buffer));
+  const uint32_t expected = 0xBD3A0001u;
+  IREE_ASSERT_OK(iree_hal_command_buffer_fill_buffer(
+      command_buffer,
+      iree_hal_make_indirect_buffer_ref(/*binding=*/0, /*offset=*/0,
+                                        sizeof(expected)),
+      &expected, sizeof(expected), IREE_HAL_FILL_FLAG_NONE));
+  IREE_ASSERT_OK(iree_hal_command_buffer_end(command_buffer));
+
+  const iree_hal_amdgpu_aql_program_t* program =
+      iree_hal_amdgpu_aql_command_buffer_program(command_buffer);
+  ASSERT_EQ(1u, program->block_count);
+  ASSERT_GT(program->max_block_aql_packet_count, 0u);
+
+  Ref<iree_hal_semaphore_t> pressure_signal;
+  IREE_ASSERT_OK(
+      CreateSemaphore(test_device.base_device(), pressure_signal.out()));
+  Ref<iree_hal_semaphore_t> command_buffer_signal;
+  IREE_ASSERT_OK(
+      CreateSemaphore(test_device.base_device(), command_buffer_signal.out()));
+
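+  // Block the hardware queue with a raw HSA barrier so the fill and execute
+  // below stack up while only a single notification slot is available.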
+  hsa_signal_t blocker_signal = iree_hsa_signal_null();
+  IREE_ASSERT_OK(iree_hsa_amd_signal_create(
+      IREE_LIBHSA(&libhsa_), /*initial_value=*/1, /*num_consumers=*/0,
+      /*consumers=*/NULL, /*attributes=*/0, &blocker_signal));
+  IREE_ASSERT_OK(EnqueueRawBlockingBarrier(queue, blocker_signal));
+
+  uint64_t pressure_signal_value = 1;
+  iree_hal_semaphore_t* pressure_signal_ptr = pressure_signal.get();
+  iree_hal_semaphore_list_t pressure_signal_list = {
+      /*count=*/1,
+      /*semaphores=*/&pressure_signal_ptr,
+      /*payload_values=*/&pressure_signal_value,
+  };
+  const uint32_t pressure_pattern = 0xABCD1234u;
+  iree_status_t status = iree_hal_device_queue_fill(
+      test_device.base_device(), IREE_HAL_QUEUE_AFFINITY_ANY,
+      iree_hal_semaphore_list_empty(), pressure_signal_list, pressure_buffer,
+      /*target_offset=*/0, sizeof(pressure_pattern), &pressure_pattern,
+      sizeof(pressure_pattern), IREE_HAL_FILL_FLAG_NONE);
+
+  uint64_t command_buffer_signal_value = 1;
+  iree_hal_semaphore_t* command_buffer_signal_ptr = command_buffer_signal.get();
+  iree_hal_semaphore_list_t command_buffer_signal_list = {
+      /*count=*/1,
+      /*semaphores=*/&command_buffer_signal_ptr,
+      /*payload_values=*/&command_buffer_signal_value,
+  };
+  iree_hal_buffer_binding_t binding = {
+      /*buffer=*/target_buffer.get(),
+      /*offset=*/0,
+      /*length=*/IREE_HAL_WHOLE_BUFFER,
+  };
+  const iree_hal_buffer_binding_table_t binding_table = {
+      /*count=*/1,
+      /*bindings=*/&binding,
+  };
+  if (iree_status_is_ok(status)) {
+    status = iree_hal_device_queue_execute(
+        test_device.base_device(), IREE_HAL_QUEUE_AFFINITY_ANY,
+        iree_hal_semaphore_list_empty(), command_buffer_signal_list,
+        command_buffer, binding_table, IREE_HAL_EXECUTE_FLAG_NONE);
+  }
+  const bool replay_parked =
+      iree_status_is_ok(status) && HostQueueHasPostDrainAction(queue);
+
+  iree_hsa_signal_store_screlease(IREE_LIBHSA(&libhsa_), blocker_signal, 0);
+
+  if (iree_status_is_ok(status)) {
+    status = iree_hal_semaphore_wait(
+        command_buffer_signal, command_buffer_signal_value,
+        iree_infinite_timeout(), IREE_ASYNC_WAIT_FLAG_NONE);
+  }
+  IREE_EXPECT_OK(
+      iree_hsa_signal_destroy(IREE_LIBHSA(&libhsa_), blocker_signal));
+
+  IREE_ASSERT_OK(status);
+  EXPECT_TRUE(replay_parked);
+
+  uint32_t actual = 0;
+  IREE_ASSERT_OK(iree_hal_buffer_map_read(target_buffer, /*offset=*/0, &actual,
+                                          sizeof(actual)));
+  EXPECT_EQ(expected, actual);
+}
+
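+// A metadata-only (empty) command buffer takes the same park/resume path and
+// must still emit its execute queue event and command-buffer metadata.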
+TEST_F(HostQueueCommandBufferTest,
+       MetadataOnlyCommandBufferParksAndResumesUnderNotificationPressure) {
+  static constexpr uint32_t kAqlCapacity = 64;
+  static constexpr uint32_t kNotificationCapacity = 1;
+  static constexpr uint32_t kKernargCapacity = 2 * kAqlCapacity;
+
+  iree_hal_amdgpu_logical_device_options_t options;
+  iree_hal_amdgpu_logical_device_options_initialize(&options);
+  options.host_block_pools.command_buffer.usable_block_size =
+      IREE_HAL_AMDGPU_AQL_PROGRAM_MIN_BLOCK_SIZE;
+  options.host_queues.aql_capacity = kAqlCapacity;
+  options.host_queues.notification_capacity = kNotificationCapacity;
+  options.host_queues.kernarg_capacity = kKernargCapacity;
+  options.preallocate_pools = 0;
+
+  TestLogicalDevice test_device;
+  IREE_ASSERT_OK(
+      test_device.Initialize(&options, &libhsa_, &topology_, host_allocator_));
+  iree_hal_amdgpu_host_queue_t* queue = test_device.first_host_queue();
+  ASSERT_NE(nullptr, queue);
+
+  CommandBufferProfileSink sink = {};
+  CommandBufferProfileSinkInitialize(&sink);
+  DeviceProfilingScope profiling(test_device.base_device());
+  IREE_ASSERT_OK(profiling.Begin(IREE_HAL_DEVICE_PROFILING_DATA_QUEUE_EVENTS,
+                                 CommandBufferProfileSinkAsBase(&sink)));
+
+  Ref<iree_hal_buffer_t> pressure_buffer;
+  IREE_ASSERT_OK(CreateHostVisibleTransferBuffer(
+      test_device.allocator(), sizeof(uint32_t), pressure_buffer.out()));
+
+  Ref<iree_hal_command_buffer_t> command_buffer;
+  IREE_ASSERT_OK(iree_hal_command_buffer_create(
+      test_device.base_device(),
+      IREE_HAL_COMMAND_BUFFER_MODE_RETAIN_PROFILE_METADATA,
+      IREE_HAL_COMMAND_CATEGORY_TRANSFER, IREE_HAL_QUEUE_AFFINITY_ANY,
+      /*binding_capacity=*/0, command_buffer.out()));
+  IREE_ASSERT_OK(iree_hal_command_buffer_begin(command_buffer));
+  IREE_ASSERT_OK(iree_hal_command_buffer_end(command_buffer));
+
+  const iree_hal_amdgpu_aql_program_t* program =
+      iree_hal_amdgpu_aql_command_buffer_program(command_buffer);
+  ASSERT_EQ(0u, program->max_block_aql_packet_count);
+
+  Ref<iree_hal_semaphore_t> pressure_signal;
+  IREE_ASSERT_OK(
+      CreateSemaphore(test_device.base_device(), pressure_signal.out()));
+  Ref<iree_hal_semaphore_t> command_buffer_signal;
+  IREE_ASSERT_OK(
+      CreateSemaphore(test_device.base_device(), command_buffer_signal.out()));
+
+  hsa_signal_t blocker_signal = iree_hsa_signal_null();
+  IREE_ASSERT_OK(iree_hsa_amd_signal_create(
+      IREE_LIBHSA(&libhsa_), /*initial_value=*/1, /*num_consumers=*/0,
+      /*consumers=*/NULL, /*attributes=*/0, &blocker_signal));
+  IREE_ASSERT_OK(EnqueueRawBlockingBarrier(queue, blocker_signal));
+
+  uint64_t pressure_signal_value = 1;
+  iree_hal_semaphore_t* pressure_signal_ptr = pressure_signal.get();
+  iree_hal_semaphore_list_t pressure_signal_list = {
+      /*count=*/1,
+      /*semaphores=*/&pressure_signal_ptr,
+      /*payload_values=*/&pressure_signal_value,
+  };
+  const uint32_t pressure_pattern = 0xABCD1234u;
+  iree_status_t status = iree_hal_device_queue_fill(
+      test_device.base_device(), IREE_HAL_QUEUE_AFFINITY_ANY,
+      iree_hal_semaphore_list_empty(), pressure_signal_list, pressure_buffer,
+      /*target_offset=*/0, sizeof(pressure_pattern), &pressure_pattern,
+      sizeof(pressure_pattern), IREE_HAL_FILL_FLAG_NONE);
+
+  uint64_t command_buffer_signal_value = 1;
+  iree_hal_semaphore_t* command_buffer_signal_ptr = command_buffer_signal.get();
+  iree_hal_semaphore_list_t command_buffer_signal_list = {
+      /*count=*/1,
+      /*semaphores=*/&command_buffer_signal_ptr,
+      /*payload_values=*/&command_buffer_signal_value,
+  };
+  if (iree_status_is_ok(status)) {
+    status = iree_hal_device_queue_execute(
+        test_device.base_device(), IREE_HAL_QUEUE_AFFINITY_ANY,
+        iree_hal_semaphore_list_empty(), command_buffer_signal_list,
+        command_buffer, iree_hal_buffer_binding_table_empty(),
+        IREE_HAL_EXECUTE_FLAG_NONE);
+  }
+  const bool replay_parked =
+      iree_status_is_ok(status) && HostQueueHasPostDrainAction(queue);
+
+  iree_hsa_signal_store_screlease(IREE_LIBHSA(&libhsa_), blocker_signal, 0);
+
+  if (iree_status_is_ok(status)) {
+    status = iree_hal_semaphore_wait(
+        command_buffer_signal, command_buffer_signal_value,
+        iree_infinite_timeout(), IREE_ASYNC_WAIT_FLAG_NONE);
+  }
+  IREE_EXPECT_OK(
+      iree_hsa_signal_destroy(IREE_LIBHSA(&libhsa_), blocker_signal));
+
+  IREE_ASSERT_OK(status);
+  IREE_ASSERT_OK(iree_hal_device_profiling_flush(test_device.base_device()));
+  IREE_ASSERT_OK(profiling.End());
+  EXPECT_TRUE(replay_parked);
+
+  EXPECT_EQ(1, sink.begin_count);
+  EXPECT_EQ(1, sink.end_count);
+  EXPECT_EQ(1, sink.command_buffer_metadata_count);
+  EXPECT_EQ(1, sink.queue_event_count);
+  ASSERT_EQ(1u, sink.command_buffer_ids.size());
+  EXPECT_EQ(1u, CountQueueEvents(sink, IREE_HAL_PROFILE_QUEUE_EVENT_TYPE_FILL));
+  EXPECT_EQ(1u,
+            CountQueueEvents(sink, IREE_HAL_PROFILE_QUEUE_EVENT_TYPE_EXECUTE));
+  const iree_hal_profile_queue_event_t* execute_event =
+      FindUniqueQueueEvent(sink, IREE_HAL_PROFILE_QUEUE_EVENT_TYPE_EXECUTE);
+  ASSERT_NE(nullptr, execute_event);
+  EXPECT_EQ(sink.command_buffer_ids[0], execute_event->command_buffer_id);
+  EXPECT_EQ(0u, execute_event->operation_count);
+  EXPECT_EQ(1u, execute_event->signal_count);
+  EXPECT_LE(sink.command_operations.size(), 1u);
+  for (const auto& operation : sink.command_operations) {
+    EXPECT_EQ(sink.command_buffer_ids[0], operation.command_buffer_id);
+    EXPECT_EQ(IREE_HAL_PROFILE_COMMAND_OPERATION_TYPE_RETURN, operation.type);
+  }
+}
+
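+// A queue-allocated transient buffer bound through the binding table must
+// remain valid for a deferred command buffer that parks behind a queued
+// dealloca.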
+TEST_F(HostQueueCommandBufferTest,
+       DeferredTransientBindingSurvivesQueuedDealloca) {
+  static constexpr uint32_t kAqlCapacity = 64;
+  static constexpr uint32_t kNotificationCapacity = 1;
+  static constexpr uint32_t kKernargCapacity = 2 * kAqlCapacity;
+
+  iree_hal_amdgpu_logical_device_options_t options;
+  iree_hal_amdgpu_logical_device_options_initialize(&options);
+  options.host_block_pools.command_buffer.usable_block_size =
+      IREE_HAL_AMDGPU_AQL_PROGRAM_MIN_BLOCK_SIZE;
+  options.host_queues.aql_capacity = kAqlCapacity;
+  options.host_queues.notification_capacity = kNotificationCapacity;
+  options.host_queues.kernarg_capacity = kKernargCapacity;
+  options.preallocate_pools = 0;
+
+  TestLogicalDevice test_device;
+  IREE_ASSERT_OK(
+      test_device.Initialize(&options, &libhsa_, &topology_, host_allocator_));
+  iree_hal_amdgpu_host_queue_t* queue = test_device.first_host_queue();
+  ASSERT_NE(queue, nullptr);
+
+  Ref<iree_hal_buffer_t> pressure_buffer;
+  IREE_ASSERT_OK(CreateHostVisibleTransferBuffer(
+      test_device.allocator(), sizeof(uint32_t), pressure_buffer.out()));
+
+  Ref<iree_hal_buffer_t> output_buffer;
+  IREE_ASSERT_OK(CreateHostVisibleTransferBuffer(
+      test_device.allocator(), sizeof(uint32_t), output_buffer.out()));
+  IREE_ASSERT_OK(iree_hal_buffer_map_zero(output_buffer, /*offset=*/0,
+                                          IREE_HAL_WHOLE_BUFFER));
+
+  Ref<iree_hal_semaphore_t> alloca_signal;
+  IREE_ASSERT_OK(
+      CreateSemaphore(test_device.base_device(), alloca_signal.out()));
+  uint64_t alloca_signal_value = 1;
+  iree_hal_semaphore_t* alloca_signal_ptr = alloca_signal.get();
+  iree_hal_semaphore_list_t alloca_signal_list = {
+      /*count=*/1,
+      /*semaphores=*/&alloca_signal_ptr,
+      /*payload_values=*/&alloca_signal_value,
+  };
+  iree_hal_buffer_t* transient_raw = NULL;
+  IREE_ASSERT_OK(QueueTransientTransferBuffer(
+      test_device.base_device(), alloca_signal_list, sizeof(uint32_t),
+      &transient_raw));
+  Ref<iree_hal_buffer_t> transient_buffer(transient_raw);
+  IREE_ASSERT_OK(iree_hal_semaphore_wait(alloca_signal, alloca_signal_value,
+                                         iree_infinite_timeout(),
+                                         IREE_ASYNC_WAIT_FLAG_NONE));
+
+  Ref<iree_hal_command_buffer_t> command_buffer;
+  IREE_ASSERT_OK(iree_hal_command_buffer_create(
+      test_device.base_device(), IREE_HAL_COMMAND_BUFFER_MODE_DEFAULT,
+      IREE_HAL_COMMAND_CATEGORY_TRANSFER, IREE_HAL_QUEUE_AFFINITY_ANY,
+      /*binding_capacity=*/2, command_buffer.out()));
+  IREE_ASSERT_OK(iree_hal_command_buffer_begin(command_buffer));
+  const uint32_t expected = 0xBD3A0002u;
+  IREE_ASSERT_OK(iree_hal_command_buffer_fill_buffer(
+      command_buffer,
+      iree_hal_make_indirect_buffer_ref(/*binding=*/0, /*offset=*/0,
+                                        sizeof(expected)),
+      &expected, sizeof(expected), IREE_HAL_FILL_FLAG_NONE));
+  IREE_ASSERT_OK(iree_hal_command_buffer_copy_buffer(
+      command_buffer,
+      iree_hal_make_indirect_buffer_ref(/*binding=*/0, /*offset=*/0,
+                                        sizeof(expected)),
+      iree_hal_make_indirect_buffer_ref(/*binding=*/1, /*offset=*/0,
+                                        sizeof(expected)),
+      IREE_HAL_COPY_FLAG_NONE));
+  IREE_ASSERT_OK(iree_hal_command_buffer_end(command_buffer));
+
+  Ref<iree_hal_semaphore_t> pressure_signal;
+  IREE_ASSERT_OK(
+      CreateSemaphore(test_device.base_device(), pressure_signal.out()));
+  Ref<iree_hal_semaphore_t> command_buffer_signal;
+  IREE_ASSERT_OK(
+      CreateSemaphore(test_device.base_device(), command_buffer_signal.out()));
+  Ref<iree_hal_semaphore_t> dealloca_signal;
+  IREE_ASSERT_OK(
+      CreateSemaphore(test_device.base_device(), dealloca_signal.out()));
+
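+  // Park the host queue behind a raw HSA barrier so the fill and execute
+  // below must queue up behind it instead of running immediately.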
+  hsa_signal_t blocker_signal = iree_hsa_signal_null();
+  IREE_ASSERT_OK(iree_hsa_amd_signal_create(
+      IREE_LIBHSA(&libhsa_), /*initial_value=*/1, /*num_consumers=*/0,
+      /*consumers=*/NULL, /*attributes=*/0, &blocker_signal));
+  IREE_ASSERT_OK(EnqueueRawBlockingBarrier(queue, blocker_signal));
+
+  uint64_t pressure_signal_value = 1;
+  iree_hal_semaphore_t* pressure_signal_ptr = pressure_signal.get();
+  iree_hal_semaphore_list_t pressure_signal_list = {
+      /*count=*/1,
+      /*semaphores=*/&pressure_signal_ptr,
+      /*payload_values=*/&pressure_signal_value,
+  };
+  const uint32_t pressure_pattern = 0xABCD1234u;
+  iree_status_t status = iree_hal_device_queue_fill(
+      test_device.base_device(), IREE_HAL_QUEUE_AFFINITY_ANY,
+      iree_hal_semaphore_list_empty(), pressure_signal_list, pressure_buffer,
+      /*target_offset=*/0, sizeof(pressure_pattern), &pressure_pattern,
+      sizeof(pressure_pattern), IREE_HAL_FILL_FLAG_NONE);
+
+  uint64_t command_buffer_signal_value = 1;
+  iree_hal_semaphore_t* command_buffer_signal_ptr = command_buffer_signal.get();
+  iree_hal_semaphore_list_t command_buffer_signal_list = {
+      /*count=*/1,
+      /*semaphores=*/&command_buffer_signal_ptr,
+      /*payload_values=*/&command_buffer_signal_value,
+  };
+  iree_hal_buffer_binding_t bindings[2] = {
+      {
+          /*buffer=*/transient_buffer.get(),
+          /*offset=*/0,
+          /*length=*/IREE_HAL_WHOLE_BUFFER,
+      },
+      {
+          /*buffer=*/output_buffer.get(),
+          /*offset=*/0,
+          /*length=*/IREE_HAL_WHOLE_BUFFER,
+      },
+  };
+  const iree_hal_buffer_binding_table_t binding_table = {
+      /*count=*/IREE_ARRAYSIZE(bindings),
+      /*bindings=*/bindings,
+  };
+  if (iree_status_is_ok(status)) {
+    status = iree_hal_device_queue_execute(
+        test_device.base_device(), IREE_HAL_QUEUE_AFFINITY_ANY,
+        iree_hal_semaphore_list_empty(), command_buffer_signal_list,
+        command_buffer, binding_table, IREE_HAL_EXECUTE_FLAG_NONE);
+  }
+  const bool replay_parked =
+      iree_status_is_ok(status) && HostQueueHasPostDrainAction(queue);
+
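+  // Queue the dealloca behind the command buffer signal while the queue is
+  // still parked: the deferred transient binding in the parked replay must
+  // keep the allocation alive until the execute retires.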
+  uint64_t dealloca_signal_value = 1;
+  iree_hal_semaphore_t* dealloca_signal_ptr = dealloca_signal.get();
+  iree_hal_semaphore_list_t dealloca_wait_list = {
+      /*count=*/1,
+      /*semaphores=*/&command_buffer_signal_ptr,
+      /*payload_values=*/&command_buffer_signal_value,
+  };
+  iree_hal_semaphore_list_t dealloca_signal_list = {
+      /*count=*/1,
+      /*semaphores=*/&dealloca_signal_ptr,
+      /*payload_values=*/&dealloca_signal_value,
+  };
+  if (iree_status_is_ok(status)) {
+    status = iree_hal_device_queue_dealloca(
+        test_device.base_device(), IREE_HAL_QUEUE_AFFINITY_ANY,
+        dealloca_wait_list, dealloca_signal_list, transient_buffer,
+        IREE_HAL_DEALLOCA_FLAG_NONE);
+  }
+
+  iree_hsa_signal_store_screlease(IREE_LIBHSA(&libhsa_), blocker_signal, 0);
+
+  if (iree_status_is_ok(status)) {
+    status = iree_hal_semaphore_wait(dealloca_signal, dealloca_signal_value,
+                                     iree_infinite_timeout(),
+                                     IREE_ASYNC_WAIT_FLAG_NONE);
+  }
+  IREE_EXPECT_OK(
+      iree_hsa_signal_destroy(IREE_LIBHSA(&libhsa_), blocker_signal));
+
+  IREE_ASSERT_OK(status);
+  EXPECT_TRUE(replay_parked);
+
+  uint32_t actual = 0;
+  IREE_ASSERT_OK(iree_hal_buffer_map_read(output_buffer, /*offset=*/0, &actual,
+                                          sizeof(actual)));
+  EXPECT_EQ(actual, expected);
+}
+
+TEST_F(HostQueueCommandBufferTest,
+       OneShotStaticTransientBindingRecordsBeforeAllocaCommit) {
+  iree_hal_amdgpu_logical_device_options_t options;
+  iree_hal_amdgpu_logical_device_options_initialize(&options);
+
+  TestLogicalDevice test_device;
+  IREE_ASSERT_OK(
+      test_device.Initialize(&options, &libhsa_, &topology_, host_allocator_));
+
+  Ref<iree_hal_buffer_t> output_buffer;
+  IREE_ASSERT_OK(CreateHostVisibleTransferBuffer(
+      test_device.allocator(), sizeof(uint32_t), output_buffer.out()));
+  IREE_ASSERT_OK(iree_hal_buffer_map_zero(output_buffer, /*offset=*/0,
+                                          IREE_HAL_WHOLE_BUFFER));
+
+  Ref<iree_hal_semaphore_t> alloca_signal;
+  IREE_ASSERT_OK(
+      CreateSemaphore(test_device.base_device(), alloca_signal.out()));
+  uint64_t alloca_signal_value = 1;
+  iree_hal_semaphore_t* alloca_signal_ptr = alloca_signal.get();
+  iree_hal_semaphore_list_t alloca_signal_list = {
+      /*count=*/1,
+      /*semaphores=*/&alloca_signal_ptr,
+      /*payload_values=*/&alloca_signal_value,
+  };
+  iree_hal_buffer_t* transient_raw = NULL;
+  IREE_ASSERT_OK(QueueTransientTransferBuffer(
+      test_device.base_device(), alloca_signal_list, sizeof(uint32_t),
+      &transient_raw));
+  Ref<iree_hal_buffer_t> transient_buffer(transient_raw);
+
+  Ref<iree_hal_command_buffer_t> command_buffer;
+  IREE_ASSERT_OK(iree_hal_command_buffer_create(
+      test_device.base_device(), IREE_HAL_COMMAND_BUFFER_MODE_ONE_SHOT,
+      IREE_HAL_COMMAND_CATEGORY_TRANSFER, IREE_HAL_QUEUE_AFFINITY_ANY,
+      /*binding_capacity=*/0, command_buffer.out()));
+  IREE_ASSERT_OK(iree_hal_command_buffer_begin(command_buffer));
+  const uint32_t expected = 0xBD3A0003u;
+  IREE_ASSERT_OK(iree_hal_command_buffer_fill_buffer(
+      command_buffer,
+      iree_hal_make_buffer_ref(transient_buffer.get(), /*offset=*/0,
+                               sizeof(expected)),
+      &expected, sizeof(expected), IREE_HAL_FILL_FLAG_NONE));
+  IREE_ASSERT_OK(iree_hal_command_buffer_copy_buffer(
+      command_buffer,
+      iree_hal_make_buffer_ref(transient_buffer.get(), /*offset=*/0,
+                               sizeof(expected)),
+      iree_hal_make_buffer_ref(output_buffer.get(), /*offset=*/0,
+                               sizeof(expected)),
+      IREE_HAL_COPY_FLAG_NONE));
+  IREE_ASSERT_OK(iree_hal_command_buffer_end(command_buffer));
+
+  Ref<iree_hal_semaphore_t> command_buffer_signal;
+  IREE_ASSERT_OK(
+      CreateSemaphore(test_device.base_device(), command_buffer_signal.out()));
+  uint64_t command_buffer_signal_value = 1;
+  iree_hal_semaphore_t* command_buffer_signal_ptr = command_buffer_signal.get();
+  iree_hal_semaphore_list_t command_buffer_signal_list = {
+      /*count=*/1,
+      /*semaphores=*/&command_buffer_signal_ptr,
+      /*payload_values=*/&command_buffer_signal_value,
+  };
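+  // The one-shot command buffer recorded direct references to the transient
+  // buffer before its alloca committed; gating execution on the alloca
+  // signal verifies that recording may precede the commit.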
+  IREE_ASSERT_OK(iree_hal_device_queue_execute(
+      test_device.base_device(), IREE_HAL_QUEUE_AFFINITY_ANY,
+      alloca_signal_list, command_buffer_signal_list, command_buffer,
+      iree_hal_buffer_binding_table_empty(), IREE_HAL_EXECUTE_FLAG_NONE));
+
+  Ref<iree_hal_semaphore_t> dealloca_signal;
+  IREE_ASSERT_OK(
+      CreateSemaphore(test_device.base_device(), dealloca_signal.out()));
+  uint64_t dealloca_signal_value = 1;
+  iree_hal_semaphore_t* dealloca_signal_ptr = dealloca_signal.get();
+  iree_hal_semaphore_list_t dealloca_signal_list = {
+      /*count=*/1,
+      /*semaphores=*/&dealloca_signal_ptr,
+      /*payload_values=*/&dealloca_signal_value,
+  };
+  IREE_ASSERT_OK(iree_hal_device_queue_dealloca(
+      test_device.base_device(), IREE_HAL_QUEUE_AFFINITY_ANY,
+      command_buffer_signal_list, dealloca_signal_list, transient_buffer,
+      IREE_HAL_DEALLOCA_FLAG_NONE));
+  IREE_ASSERT_OK(iree_hal_semaphore_wait(dealloca_signal, dealloca_signal_value,
+                                         iree_infinite_timeout(),
+                                         IREE_ASYNC_WAIT_FLAG_NONE));
+
+  uint32_t actual = 0;
+  IREE_ASSERT_OK(iree_hal_buffer_map_read(output_buffer, /*offset=*/0, &actual,
+                                          sizeof(actual)));
+  EXPECT_EQ(actual, expected);
+}
+
+TEST_F(HostQueueCommandBufferTest,
+       LargeCommandBufferParksAndResumesUnderNotificationPressure) {
+  static constexpr uint32_t kFillCount = 2048;
+  static constexpr uint32_t kAqlCapacity = 64;
+  static constexpr uint32_t kNotificationCapacity = 1;
+  static constexpr uint32_t kKernargCapacity = 2 * kAqlCapacity;
+
+  iree_hal_amdgpu_logical_device_options_t options;
+  iree_hal_amdgpu_logical_device_options_initialize(&options);
+  options.host_block_pools.command_buffer.usable_block_size =
+      IREE_HAL_AMDGPU_AQL_PROGRAM_MIN_BLOCK_SIZE;
+  options.host_queues.aql_capacity = kAqlCapacity;
+  options.host_queues.notification_capacity = kNotificationCapacity;
+  options.host_queues.kernarg_capacity = kKernargCapacity;
+  options.preallocate_pools = 0;
+
+  TestLogicalDevice test_device;
+  IREE_ASSERT_OK(
+      test_device.Initialize(&options, &libhsa_, &topology_, host_allocator_));
+  iree_hal_amdgpu_host_queue_t* queue = test_device.first_host_queue();
+  ASSERT_NE(queue, nullptr);
+
+  CommandBufferProfileSink sink = {};
+  CommandBufferProfileSinkInitialize(&sink);
+  DeviceProfilingScope profiling(test_device.base_device());
+  IREE_ASSERT_OK(profiling.Begin(IREE_HAL_DEVICE_PROFILING_DATA_QUEUE_EVENTS,
+                                 CommandBufferProfileSinkAsBase(&sink)));
+
+  Ref<iree_hal_buffer_t> pressure_buffer;
+  IREE_ASSERT_OK(CreateHostVisibleTransferBuffer(
+      test_device.allocator(), sizeof(uint32_t), pressure_buffer.out()));
+
+  const iree_device_size_t target_buffer_size = kFillCount * sizeof(uint32_t);
+  Ref<iree_hal_buffer_t> target_buffer;
+  IREE_ASSERT_OK(CreateHostVisibleTransferBuffer(
+      test_device.allocator(), target_buffer_size, target_buffer.out()));
+  IREE_ASSERT_OK(iree_hal_buffer_map_zero(target_buffer, /*offset=*/0,
+                                          IREE_HAL_WHOLE_BUFFER));
+
+  Ref<iree_hal_command_buffer_t> command_buffer;
+  IREE_ASSERT_OK(iree_hal_command_buffer_create(
+      test_device.base_device(),
+      IREE_HAL_COMMAND_BUFFER_MODE_RETAIN_PROFILE_METADATA,
+      IREE_HAL_COMMAND_CATEGORY_TRANSFER, IREE_HAL_QUEUE_AFFINITY_ANY,
+      /*binding_capacity=*/1, command_buffer.out()));
+  IREE_ASSERT_OK(iree_hal_command_buffer_begin(command_buffer));
+  std::vector<uint32_t> expected(kFillCount);
+  for (uint32_t i = 0; i < kFillCount; ++i) {
+    expected[i] = 0xBD3A0000u | i;
+    IREE_ASSERT_OK(iree_hal_command_buffer_fill_buffer(
+        command_buffer,
+        iree_hal_make_indirect_buffer_ref(/*binding=*/0, i * sizeof(uint32_t),
+                                          sizeof(uint32_t)),
+        &expected[i], sizeof(expected[i]), IREE_HAL_FILL_FLAG_NONE));
+  }
+  IREE_ASSERT_OK(iree_hal_command_buffer_end(command_buffer));
+
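+  // The recorded program must span multiple blocks, each small enough to fit
+  // the shrunken AQL ring, so replay is forced to park and resume.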
+  const iree_hal_amdgpu_aql_program_t* program =
+      iree_hal_amdgpu_aql_command_buffer_program(command_buffer);
+  ASSERT_GT(program->block_count, 1u);
+  ASSERT_GT(kFillCount, kAqlCapacity);
+  ASSERT_LE(program->max_block_aql_packet_count, kAqlCapacity);
+
+  Ref<iree_hal_semaphore_t> pressure_signal;
+  IREE_ASSERT_OK(
+      CreateSemaphore(test_device.base_device(), pressure_signal.out()));
+  Ref<iree_hal_semaphore_t> command_buffer_signal;
+  IREE_ASSERT_OK(
+      CreateSemaphore(test_device.base_device(), command_buffer_signal.out()));
+
+  hsa_signal_t blocker_signal = iree_hsa_signal_null();
+  IREE_ASSERT_OK(iree_hsa_amd_signal_create(
+      IREE_LIBHSA(&libhsa_), /*initial_value=*/1, /*num_consumers=*/0,
+      /*consumers=*/NULL, /*attributes=*/0, &blocker_signal));
+  IREE_ASSERT_OK(EnqueueRawBlockingBarrier(queue, blocker_signal));
+
+  uint64_t pressure_signal_value = 1;
+  iree_hal_semaphore_t* pressure_signal_ptr = pressure_signal.get();
+  iree_hal_semaphore_list_t pressure_signal_list = {
+      /*count=*/1,
+      /*semaphores=*/&pressure_signal_ptr,
+      /*payload_values=*/&pressure_signal_value,
+  };
+  const uint32_t pressure_pattern = 0xABCD1234u;
+  iree_status_t status = iree_hal_device_queue_fill(
+      test_device.base_device(), IREE_HAL_QUEUE_AFFINITY_ANY,
+      iree_hal_semaphore_list_empty(), pressure_signal_list, pressure_buffer,
+      /*target_offset=*/0, sizeof(pressure_pattern), &pressure_pattern,
+      sizeof(pressure_pattern), IREE_HAL_FILL_FLAG_NONE);
+
+  uint64_t command_buffer_signal_value = 1;
+  iree_hal_semaphore_t* command_buffer_signal_ptr = command_buffer_signal.get();
+  iree_hal_semaphore_list_t command_buffer_signal_list = {
+      /*count=*/1,
+      /*semaphores=*/&command_buffer_signal_ptr,
+      /*payload_values=*/&command_buffer_signal_value,
+  };
+  iree_hal_buffer_binding_t binding = {
+      /*buffer=*/target_buffer.get(),
+      /*offset=*/0,
+      /*length=*/IREE_HAL_WHOLE_BUFFER,
+  };
+  const iree_hal_buffer_binding_table_t binding_table = {
+      /*count=*/1,
+      /*bindings=*/&binding,
+  };
+  if (iree_status_is_ok(status)) {
+    status = iree_hal_device_queue_execute(
+        test_device.base_device(), IREE_HAL_QUEUE_AFFINITY_ANY,
+        iree_hal_semaphore_list_empty(), command_buffer_signal_list,
+        command_buffer, binding_table, IREE_HAL_EXECUTE_FLAG_NONE);
+  }
+  const bool replay_parked =
+      iree_status_is_ok(status) && HostQueueHasPostDrainAction(queue);
+
+  iree_hsa_signal_store_screlease(IREE_LIBHSA(&libhsa_), blocker_signal, 0);
+
+  if (iree_status_is_ok(status)) {
+    status = iree_hal_semaphore_wait(
+        command_buffer_signal, command_buffer_signal_value,
+        iree_infinite_timeout(), IREE_ASYNC_WAIT_FLAG_NONE);
+  }
+  IREE_EXPECT_OK(
+      iree_hsa_signal_destroy(IREE_LIBHSA(&libhsa_), blocker_signal));
+
+  IREE_ASSERT_OK(status);
+  IREE_ASSERT_OK(iree_hal_device_profiling_flush(test_device.base_device()));
+  IREE_ASSERT_OK(profiling.End());
+  EXPECT_TRUE(replay_parked);
+
+  EXPECT_EQ(1, sink.begin_count);
+  EXPECT_EQ(1, sink.end_count);
+  EXPECT_EQ(1, sink.command_buffer_metadata_count);
+  EXPECT_EQ(1, sink.queue_event_count);
+  ASSERT_EQ(1u, sink.command_buffer_ids.size());
+  EXPECT_EQ(1u, CountQueueEvents(sink, IREE_HAL_PROFILE_QUEUE_EVENT_TYPE_FILL));
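+  // Each replayed block produces one EXECUTE queue event; operation counts
+  // sum to the recorded command count and only one event carries the
+  // external signal.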
+  EXPECT_EQ(program->block_count,
+            CountQueueEvents(sink, IREE_HAL_PROFILE_QUEUE_EVENT_TYPE_EXECUTE));
+  EXPECT_EQ(program->command_count,
+            SumQueueEventOperationCounts(
+                sink, IREE_HAL_PROFILE_QUEUE_EVENT_TYPE_EXECUTE));
+  uint32_t execute_signal_count = 0;
+  for (const auto& event : sink.queue_events) {
+    if (event.type != IREE_HAL_PROFILE_QUEUE_EVENT_TYPE_EXECUTE) continue;
+    EXPECT_EQ(sink.command_buffer_ids[0], event.command_buffer_id);
+    execute_signal_count += event.signal_count;
+  }
+  EXPECT_EQ(1u, execute_signal_count);
+
+  std::vector<uint32_t> actual(kFillCount);
+  IREE_ASSERT_OK(iree_hal_buffer_map_read(target_buffer, /*offset=*/0,
+                                          actual.data(), target_buffer_size));
+  EXPECT_EQ(actual, expected);
+}
+
+}  // namespace
+}  // namespace iree::hal::amdgpu
diff --git a/runtime/src/iree/hal/drivers/amdgpu/host_queue_dispatch.c b/runtime/src/iree/hal/drivers/amdgpu/host_queue_dispatch.c
new file mode 100644
index 0000000..e922014
--- /dev/null
+++ b/runtime/src/iree/hal/drivers/amdgpu/host_queue_dispatch.c
@@ -0,0 +1,962 @@
+// Copyright 2026 The IREE Authors
+//
+// Licensed under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+
+#include "iree/hal/drivers/amdgpu/host_queue_dispatch.h"
+
+#include <string.h>
+
+#include "iree/hal/drivers/amdgpu/buffer.h"
+#include "iree/hal/drivers/amdgpu/device/dispatch.h"
+#include "iree/hal/drivers/amdgpu/device/timestamp.h"
+#include "iree/hal/drivers/amdgpu/executable.h"
+#include "iree/hal/drivers/amdgpu/host_queue_policy.h"
+#include "iree/hal/drivers/amdgpu/host_queue_profile.h"
+#include "iree/hal/drivers/amdgpu/host_queue_profile_events.h"
+#include "iree/hal/drivers/amdgpu/logical_device.h"
+#include "iree/hal/drivers/amdgpu/profile_counters.h"
+#include "iree/hal/drivers/amdgpu/profile_traces.h"
+#include "iree/hal/drivers/amdgpu/util/aql_emitter.h"
+
+typedef struct iree_hal_amdgpu_host_queue_dispatch_plan_t {
+  // Executable dispatch descriptor selected for the queue's physical device.
+  const iree_hal_amdgpu_executable_dispatch_descriptor_t* descriptor;
+  // Kernel arguments used for packet and kernarg emission.
+  const iree_hal_amdgpu_device_kernel_args_t* kernel_args;
+  // Storage for a workgroup-size override when dispatch config provides one.
+  iree_hal_amdgpu_device_kernel_args_t override_kernel_args;
+  // Device ABI layout describing the kernarg bytes to emit.
+  const iree_hal_amdgpu_device_dispatch_kernarg_layout_t* layout;
+  // Number of queue-owned kernarg blocks required for the dispatch.
+  uint32_t kernarg_block_count;
+  // Number of operation resources retained until dispatch completion.
+  iree_host_size_t operation_resource_count;
+  // True when workgroup counts are read from a device buffer before dispatch.
+  bool uses_indirect_parameters;
+} iree_hal_amdgpu_host_queue_dispatch_plan_t;
+
+static iree_status_t iree_hal_amdgpu_host_queue_validate_dispatch_flags(
+    iree_hal_dispatch_flags_t flags) {
+  if (iree_hal_dispatch_uses_indirect_arguments(flags)) {
+    return iree_make_status(
+        IREE_STATUS_UNIMPLEMENTED,
+        "indirect dispatch arguments are not supported by AMDGPU "
+        "queue_dispatch yet");
+  }
+
+  const iree_hal_dispatch_flags_t supported_flags =
+      IREE_HAL_DISPATCH_FLAG_DYNAMIC_INDIRECT_PARAMETERS |
+      IREE_HAL_DISPATCH_FLAG_STATIC_INDIRECT_PARAMETERS |
+      IREE_HAL_DISPATCH_FLAG_CUSTOM_DIRECT_ARGUMENTS |
+      IREE_HAL_DISPATCH_FLAG_ALLOW_INLINE_EXECUTION;
+  if (IREE_UNLIKELY(iree_any_bit_set(flags, ~supported_flags))) {
+    return iree_make_status(IREE_STATUS_INVALID_ARGUMENT,
+                            "unsupported dispatch flags: 0x%" PRIx64, flags);
+  }
+  return iree_ok_status();
+}
+
+static iree_status_t iree_hal_amdgpu_host_queue_lookup_dispatch_descriptor(
+    const iree_hal_amdgpu_host_queue_t* queue,
+    iree_hal_executable_t* executable,
+    iree_hal_executable_export_ordinal_t export_ordinal,
+    const iree_hal_amdgpu_executable_dispatch_descriptor_t** out_descriptor) {
+  return iree_hal_amdgpu_executable_lookup_dispatch_descriptor_for_device(
+      executable, export_ordinal, queue->device_ordinal, out_descriptor);
+}
+
+static bool iree_hal_amdgpu_dispatch_config_has_workgroup_size_override(
+    const iree_hal_dispatch_config_t config) {
+  return config.workgroup_size[0] || config.workgroup_size[1] ||
+         config.workgroup_size[2];
+}
+
+static iree_status_t iree_hal_amdgpu_host_queue_select_dispatch_kernel_args(
+    const iree_hal_amdgpu_executable_dispatch_descriptor_t* descriptor,
+    const iree_hal_dispatch_config_t config,
+    iree_hal_amdgpu_device_kernel_args_t* override_kernel_args,
+    const iree_hal_amdgpu_device_kernel_args_t** out_kernel_args) {
+  *out_kernel_args = &descriptor->kernel_args;
+  if (!iree_hal_amdgpu_dispatch_config_has_workgroup_size_override(config)) {
+    return iree_ok_status();
+  }
+
+  *override_kernel_args = descriptor->kernel_args;
+  for (iree_host_size_t i = 0; i < 3; ++i) {
+    if (IREE_UNLIKELY(!config.workgroup_size[i])) {
+      return iree_make_status(
+          IREE_STATUS_INVALID_ARGUMENT,
+          "dispatch workgroup size override must specify all dimensions");
+    }
+    if (IREE_UNLIKELY(config.workgroup_size[i] > UINT16_MAX)) {
+      return iree_make_status(
+          IREE_STATUS_OUT_OF_RANGE,
+          "dispatch workgroup size override dimension %" PRIhsz
+          " value %u exceeds %u",
+          i, config.workgroup_size[i], UINT16_MAX);
+    }
+    override_kernel_args->workgroup_size[i] =
+        (uint16_t)config.workgroup_size[i];
+  }
+
+  *out_kernel_args = override_kernel_args;
+  return iree_ok_status();
+}
+
+static iree_status_t iree_hal_amdgpu_host_queue_validate_dispatch_shape(
+    const iree_hal_amdgpu_executable_dispatch_descriptor_t* descriptor,
+    const iree_hal_amdgpu_device_kernel_args_t* kernel_args,
+    const iree_hal_dispatch_config_t config, iree_hal_dispatch_flags_t flags) {
+  const bool uses_indirect_parameters =
+      iree_hal_dispatch_uses_indirect_parameters(flags);
+  if (iree_hal_amdgpu_dispatch_config_has_workgroup_size_override(config)) {
+    for (iree_host_size_t i = 0; i < 3; ++i) {
+      const uint64_t grid_size =
+          (uint64_t)config.workgroup_count[i] * kernel_args->workgroup_size[i];
+      if (!uses_indirect_parameters && IREE_UNLIKELY(grid_size > UINT32_MAX)) {
+        return iree_make_status(
+            IREE_STATUS_OUT_OF_RANGE,
+            "dispatch grid dimension %" PRIhsz
+            " overflows uint32_t (workgroup_count=%u, workgroup_size=%u)",
+            i, config.workgroup_count[i], kernel_args->workgroup_size[i]);
+      }
+    }
+  } else if (!uses_indirect_parameters) {
+    for (iree_host_size_t i = 0; i < 3; ++i) {
+      if (IREE_UNLIKELY(config.workgroup_count[i] >
+                        descriptor->max_workgroup_count[i])) {
+        return iree_make_status(
+            IREE_STATUS_OUT_OF_RANGE,
+            "dispatch grid dimension %" PRIhsz
+            " exceeds the device maximum (workgroup_count=%u, max=%u)",
+            i, config.workgroup_count[i], descriptor->max_workgroup_count[i]);
+      }
+    }
+  }
+  if (IREE_UNLIKELY(config.dynamic_workgroup_local_memory >
+                    descriptor->max_dynamic_workgroup_local_memory)) {
+    return iree_make_status(IREE_STATUS_OUT_OF_RANGE,
+                            "dispatch group segment size exceeds the device "
+                            "limit (static=%u, dynamic=%u)",
+                            kernel_args->group_segment_size,
+                            config.dynamic_workgroup_local_memory);
+  }
+  return iree_ok_status();
+}
+
+static iree_status_t iree_hal_amdgpu_host_queue_validate_dispatch_binding(
+    const iree_hal_buffer_ref_t* binding) {
+  if (IREE_UNLIKELY(binding->reserved != 0 || binding->buffer_slot != 0 ||
+                    !binding->buffer)) {
+    return iree_make_status(
+        IREE_STATUS_INVALID_ARGUMENT,
+        "queue_dispatch bindings must be direct non-null buffer references");
+  }
+  IREE_RETURN_IF_ERROR(iree_hal_buffer_validate_memory_type(
+      iree_hal_buffer_memory_type(binding->buffer),
+      IREE_HAL_MEMORY_TYPE_DEVICE_VISIBLE));
+  IREE_RETURN_IF_ERROR(iree_hal_buffer_validate_usage(
+      iree_hal_buffer_allowed_usage(binding->buffer),
+      IREE_HAL_BUFFER_USAGE_DISPATCH_STORAGE));
+  IREE_RETURN_IF_ERROR(iree_hal_buffer_validate_access(
+      iree_hal_buffer_allowed_access(binding->buffer),
+      IREE_HAL_MEMORY_ACCESS_ANY));
+  return iree_hal_buffer_validate_range(binding->buffer, binding->offset,
+                                        binding->length);
+}
+
+static iree_status_t
+iree_hal_amdgpu_host_queue_validate_dispatch_indirect_parameters(
+    const iree_hal_buffer_ref_t* workgroup_count_ref) {
+  const iree_device_size_t workgroup_count_length = sizeof(uint32_t[3]);
+  if (IREE_UNLIKELY(workgroup_count_ref->reserved != 0 ||
+                    workgroup_count_ref->buffer_slot != 0 ||
+                    !workgroup_count_ref->buffer)) {
+    return iree_make_status(
+        IREE_STATUS_INVALID_ARGUMENT,
+        "queue_dispatch indirect workgroup parameters must use a direct "
+        "non-null buffer reference");
+  }
+  if (IREE_UNLIKELY((workgroup_count_ref->offset % sizeof(uint32_t)) != 0)) {
+    return iree_make_status(
+        IREE_STATUS_INVALID_ARGUMENT,
+        "queue_dispatch indirect workgroup parameter offset must be 4-byte "
+        "aligned");
+  }
+  if (IREE_UNLIKELY(workgroup_count_ref->length != IREE_HAL_WHOLE_BUFFER &&
+                    workgroup_count_ref->length < workgroup_count_length)) {
+    return iree_make_status(
+        IREE_STATUS_OUT_OF_RANGE,
+        "queue_dispatch indirect workgroup parameter buffer must contain at "
+        "least uint32_t[3]");
+  }
+  IREE_RETURN_IF_ERROR(iree_hal_buffer_validate_memory_type(
+      iree_hal_buffer_memory_type(workgroup_count_ref->buffer),
+      IREE_HAL_MEMORY_TYPE_DEVICE_VISIBLE));
+  IREE_RETURN_IF_ERROR(iree_hal_buffer_validate_usage(
+      iree_hal_buffer_allowed_usage(workgroup_count_ref->buffer),
+      IREE_HAL_BUFFER_USAGE_DISPATCH_INDIRECT_PARAMETERS));
+  IREE_RETURN_IF_ERROR(iree_hal_buffer_validate_access(
+      iree_hal_buffer_allowed_access(workgroup_count_ref->buffer),
+      IREE_HAL_MEMORY_ACCESS_READ));
+  return iree_hal_buffer_validate_range(workgroup_count_ref->buffer,
+                                        workgroup_count_ref->offset,
+                                        workgroup_count_length);
+}
+
+static iree_status_t iree_hal_amdgpu_host_queue_validate_dispatch_kernargs(
+    const iree_hal_amdgpu_host_queue_t* queue,
+    const iree_hal_amdgpu_executable_dispatch_descriptor_t* descriptor,
+    iree_const_byte_span_t constants, const iree_hal_buffer_ref_list_t bindings,
+    iree_hal_dispatch_flags_t flags,
+    const iree_hal_amdgpu_device_dispatch_kernarg_layout_t** out_layout,
+    uint32_t* out_kernarg_block_count,
+    iree_host_size_t* out_operation_resource_count) {
+  *out_layout = NULL;
+  *out_kernarg_block_count = 0;
+  *out_operation_resource_count = 0;
+
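+  // One resource slot is always reserved; direct bindings and the optional
+  // indirect parameters buffer each add one more.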
+  iree_host_size_t operation_resource_count = 1;
+  if (iree_any_bit_set(flags, IREE_HAL_DISPATCH_FLAG_CUSTOM_DIRECT_ARGUMENTS)) {
+    if (IREE_UNLIKELY(constants.data_length !=
+                      descriptor->kernel_args.kernarg_size)) {
+      return iree_make_status(
+          IREE_STATUS_INVALID_ARGUMENT,
+          "custom dispatch argument length mismatch; expected %u but got "
+          "%" PRIhsz,
+          descriptor->kernel_args.kernarg_size, constants.data_length);
+    }
+    if (IREE_UNLIKELY(constants.data_length > 0 && !constants.data)) {
+      return iree_make_status(
+          IREE_STATUS_INVALID_ARGUMENT,
+          "custom dispatch argument data must be non-null when length is "
+          "non-zero");
+    }
+    *out_layout = &descriptor->custom_kernarg_layout;
+    *out_kernarg_block_count =
+        iree_max(1u, descriptor->custom_kernarg_block_count);
+  } else {
+    if (IREE_UNLIKELY(constants.data_length > 0 && !constants.data)) {
+      return iree_make_status(
+          IREE_STATUS_INVALID_ARGUMENT,
+          "dispatch constant data must be non-null when length is non-zero");
+    }
+    const iree_host_size_t expected_constant_length =
+        (iree_host_size_t)descriptor->kernel_args.constant_count *
+        sizeof(uint32_t);
+    if (IREE_UNLIKELY(constants.data_length != expected_constant_length)) {
+      return iree_make_status(
+          IREE_STATUS_INVALID_ARGUMENT,
+          "dispatch constant count mismatch; expected %u but got %" PRIhsz,
+          (uint32_t)descriptor->kernel_args.constant_count,
+          constants.data_length / sizeof(uint32_t));
+    }
+    if (IREE_UNLIKELY(bindings.count !=
+                      descriptor->kernel_args.binding_count)) {
+      return iree_make_status(
+          IREE_STATUS_INVALID_ARGUMENT,
+          "dispatch binding count mismatch; expected %u but got %" PRIhsz,
+          (uint32_t)descriptor->kernel_args.binding_count, bindings.count);
+    }
+    if (IREE_UNLIKELY(
+            bindings.count >
+            IREE_HAL_AMDGPU_HOST_QUEUE_DISPATCH_SCRATCH_BINDING_CAPACITY)) {
+      return iree_make_status(
+          IREE_STATUS_OUT_OF_RANGE,
+          "queue_dispatch supports at most %u direct buffer bindings but got "
+          "%" PRIhsz,
+          IREE_HAL_AMDGPU_HOST_QUEUE_DISPATCH_SCRATCH_BINDING_CAPACITY,
+          bindings.count);
+    }
+    if (IREE_UNLIKELY(bindings.count > 0 && !bindings.values)) {
+      return iree_make_status(
+          IREE_STATUS_INVALID_ARGUMENT,
+          "dispatch bindings must be non-null when count is non-zero");
+    }
+    operation_resource_count = 1 + bindings.count;
+    *out_layout = &descriptor->hal_kernarg_layout;
+    *out_kernarg_block_count =
+        iree_max(1u, descriptor->hal_kernarg_block_count);
+  }
+
+  if (iree_hal_dispatch_uses_indirect_parameters(flags)) {
+    ++operation_resource_count;
+  }
+
+  if (IREE_UNLIKELY(*out_kernarg_block_count > queue->kernarg_ring.capacity)) {
+    return iree_make_status(IREE_STATUS_OUT_OF_RANGE,
+                            "dispatch kernargs require %u"
+                            " blocks but the queue kernarg ring capacity is %u",
+                            *out_kernarg_block_count,
+                            queue->kernarg_ring.capacity);
+  }
+
+  *out_operation_resource_count = operation_resource_count;
+  return iree_ok_status();
+}
+
+static iree_status_t iree_hal_amdgpu_host_queue_prepare_dispatch_plan(
+    const iree_hal_amdgpu_host_queue_t* queue,
+    iree_hal_executable_t* executable,
+    iree_hal_executable_export_ordinal_t export_ordinal,
+    const iree_hal_dispatch_config_t config, iree_const_byte_span_t constants,
+    const iree_hal_buffer_ref_list_t bindings, iree_hal_dispatch_flags_t flags,
+    iree_hal_amdgpu_host_queue_dispatch_plan_t* out_plan) {
+  memset(out_plan, 0, sizeof(*out_plan));
+  IREE_RETURN_IF_ERROR(
+      iree_hal_amdgpu_host_queue_validate_dispatch_flags(flags));
+
+  IREE_RETURN_IF_ERROR(iree_hal_amdgpu_host_queue_lookup_dispatch_descriptor(
+      queue, executable, export_ordinal, &out_plan->descriptor));
+  IREE_RETURN_IF_ERROR(iree_hal_amdgpu_host_queue_select_dispatch_kernel_args(
+      out_plan->descriptor, config, &out_plan->override_kernel_args,
+      &out_plan->kernel_args));
+  IREE_RETURN_IF_ERROR(iree_hal_amdgpu_host_queue_validate_dispatch_shape(
+      out_plan->descriptor, out_plan->kernel_args, config, flags));
+  out_plan->uses_indirect_parameters =
+      iree_hal_dispatch_uses_indirect_parameters(flags);
+
+  return iree_hal_amdgpu_host_queue_validate_dispatch_kernargs(
+      queue, out_plan->descriptor, constants, bindings, flags,
+      &out_plan->layout, &out_plan->kernarg_block_count,
+      &out_plan->operation_resource_count);
+}
+
+static iree_status_t iree_hal_amdgpu_host_queue_resolve_validated_binding_ptr(
+    const iree_hal_buffer_ref_t* binding, uint64_t* out_binding_ptr) {
+  *out_binding_ptr = 0;
+  iree_hal_buffer_t* allocated_buffer =
+      iree_hal_buffer_allocated_buffer(binding->buffer);
+  void* device_ptr = iree_hal_amdgpu_buffer_device_pointer(allocated_buffer);
+  if (IREE_UNLIKELY(!device_ptr)) {
+    return iree_make_status(
+        IREE_STATUS_INVALID_ARGUMENT,
+        "dispatch binding buffer must be backed by an AMDGPU allocation");
+  }
+
+  iree_device_size_t device_offset = 0;
+  if (IREE_UNLIKELY(!iree_device_size_checked_add(
+          iree_hal_buffer_byte_offset(binding->buffer), binding->offset,
+          &device_offset))) {
+    return iree_make_status(
+        IREE_STATUS_OUT_OF_RANGE,
+        "dispatch binding device pointer offset overflows device size");
+  }
+  if (IREE_UNLIKELY(device_offset > UINTPTR_MAX)) {
+    return iree_make_status(
+        IREE_STATUS_OUT_OF_RANGE,
+        "dispatch binding device pointer offset exceeds host pointer size");
+  }
+  *out_binding_ptr = (uint64_t)((uintptr_t)device_ptr + device_offset);
+  return iree_ok_status();
+}
+
+// Validates direct dispatch bindings and optionally fills the submit-time
+// arrays used to retain resources and resolve final kernarg pointers. The
+// arrays are caller-owned scratch; passing NULL skips the corresponding fill
+// and performs validation only.
+static iree_status_t iree_hal_amdgpu_host_queue_prepare_dispatch_bindings(
+    const iree_hal_buffer_ref_list_t bindings,
+    iree_hal_resource_t** operation_resources, uint64_t* binding_ptrs) {
+  iree_status_t status = iree_ok_status();
+  for (iree_host_size_t i = 0; i < bindings.count && iree_status_is_ok(status);
+       ++i) {
+    const iree_hal_buffer_ref_t* binding = &bindings.values[i];
+    status = iree_hal_amdgpu_host_queue_validate_dispatch_binding(binding);
+    if (iree_status_is_ok(status) && binding_ptrs) {
+      status = iree_hal_amdgpu_host_queue_resolve_validated_binding_ptr(
+          binding, &binding_ptrs[i]);
+    }
+    if (iree_status_is_ok(status) && operation_resources) {
+      operation_resources[i + 1] = (iree_hal_resource_t*)binding->buffer;
+    }
+    if (!iree_status_is_ok(status)) {
+      status = iree_status_annotate_f(status, "binding[%" PRIhsz "]", i);
+    }
+  }
+  return status;
+}
+
+static iree_status_t
+iree_hal_amdgpu_host_queue_prepare_dispatch_indirect_parameters(
+    const iree_hal_dispatch_config_t config,
+    iree_hal_resource_t** operation_resources,
+    iree_host_size_t operation_resource_index,
+    uint64_t* out_workgroup_count_ptr) {
+  *out_workgroup_count_ptr = 0;
+  IREE_RETURN_IF_ERROR(
+      iree_hal_amdgpu_host_queue_validate_dispatch_indirect_parameters(
+          &config.workgroup_count_ref));
+  IREE_RETURN_IF_ERROR(iree_hal_amdgpu_host_queue_resolve_validated_binding_ptr(
+      &config.workgroup_count_ref, out_workgroup_count_ptr));
+  operation_resources[operation_resource_index] =
+      (iree_hal_resource_t*)config.workgroup_count_ref.buffer;
+  return iree_ok_status();
+}
+
+static void iree_hal_amdgpu_host_queue_initialize_dispatch_event(
+    iree_hal_amdgpu_profile_dispatch_event_t* event,
+    const iree_hal_amdgpu_host_queue_dispatch_plan_t* plan,
+    iree_hal_executable_export_ordinal_t export_ordinal, uint64_t executable_id,
+    const iree_hal_dispatch_config_t config,
+    iree_hal_amdgpu_profile_dispatch_event_flags_t flags) {
+  const uint64_t event_id = event->event_id;
+  memset(event, 0, sizeof(*event));
+  event->record_length = sizeof(*event);
+  event->flags = flags;
+  event->event_id = event_id;
+  event->command_index = UINT32_MAX;
+  event->export_ordinal = export_ordinal;
+  event->executable_id = executable_id;
+  if (!iree_any_bit_set(
+          flags,
+          IREE_HAL_AMDGPU_PROFILE_DISPATCH_EVENT_FLAG_INDIRECT_PARAMETERS)) {
+    memcpy(event->workgroup_count, config.workgroup_count,
+           sizeof(event->workgroup_count));
+  }
+  event->workgroup_size[0] = plan->kernel_args->workgroup_size[0];
+  event->workgroup_size[1] = plan->kernel_args->workgroup_size[1];
+  event->workgroup_size[2] = plan->kernel_args->workgroup_size[2];
+}
+
+static bool iree_hal_amdgpu_host_queue_should_profile_dispatch(
+    iree_hal_amdgpu_host_queue_t* queue, uint64_t executable_id,
+    iree_hal_executable_export_ordinal_t export_ordinal) {
+  if (!queue->profiling.dispatch_profiling_enabled) return false;
+  if (!queue->profiling.hsa_queue_timestamps_enabled) return false;
+  iree_hal_amdgpu_logical_device_t* logical_device =
+      (iree_hal_amdgpu_logical_device_t*)queue->logical_device;
+  const uint32_t physical_device_ordinal = queue->device_ordinal <= UINT32_MAX
+                                               ? (uint32_t)queue->device_ordinal
+                                               : UINT32_MAX;
+  const uint32_t queue_ordinal = iree_async_axis_queue_index(queue->axis);
+  return iree_hal_amdgpu_logical_device_should_profile_dispatch(
+      logical_device, executable_id, export_ordinal,
+      /*command_buffer_id=*/0, /*command_index=*/UINT32_MAX,
+      physical_device_ordinal, queue_ordinal);
+}
+
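+// Builds the same plan as the submission paths so queue-front validation
+// stays in lockstep with what submit will emit.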
+iree_status_t iree_hal_amdgpu_host_queue_validate_dispatch(
+    const iree_hal_amdgpu_host_queue_t* queue,
+    iree_hal_executable_t* executable,
+    iree_hal_executable_export_ordinal_t export_ordinal,
+    const iree_hal_dispatch_config_t config, iree_const_byte_span_t constants,
+    const iree_hal_buffer_ref_list_t bindings, iree_hal_dispatch_flags_t flags,
+    iree_host_size_t* out_operation_resource_count) {
+  *out_operation_resource_count = 0;
+  iree_hal_amdgpu_host_queue_dispatch_plan_t plan;
+  iree_status_t status = iree_hal_amdgpu_host_queue_prepare_dispatch_plan(
+      queue, executable, export_ordinal, config, constants, bindings, flags,
+      &plan);
+  if (iree_status_is_ok(status) &&
+      !iree_any_bit_set(flags,
+                        IREE_HAL_DISPATCH_FLAG_CUSTOM_DIRECT_ARGUMENTS)) {
+    status = iree_hal_amdgpu_host_queue_prepare_dispatch_bindings(
+        bindings, /*operation_resources=*/NULL, /*binding_ptrs=*/NULL);
+  }
+  if (iree_status_is_ok(status) && plan.uses_indirect_parameters) {
+    status = iree_hal_amdgpu_host_queue_validate_dispatch_indirect_parameters(
+        &config.workgroup_count_ref);
+  }
+  if (iree_status_is_ok(status)) {
+    *out_operation_resource_count = plan.operation_resource_count;
+  }
+  return status;
+}
+
+static iree_status_t iree_hal_amdgpu_host_queue_submit_direct_dispatch(
+    iree_hal_amdgpu_host_queue_t* queue,
+    const iree_hal_amdgpu_wait_resolution_t* resolution,
+    const iree_hal_semaphore_list_t signal_semaphore_list,
+    const iree_hal_amdgpu_host_queue_dispatch_plan_t* plan,
+    iree_hal_executable_t* executable,
+    iree_hal_executable_export_ordinal_t export_ordinal,
+    const iree_hal_dispatch_config_t config, iree_const_byte_span_t constants,
+    const uint64_t* binding_ptrs,
+    iree_hal_resource_t* const* operation_resources,
+    bool uses_custom_direct_arguments,
+    iree_hal_amdgpu_host_queue_submission_flags_t submission_flags,
+    bool* out_ready) {
+  uint64_t executable_id = 0;
+  bool should_profile_dispatch = false;
+  if (queue->profiling.dispatch_profiling_enabled) {
+    executable_id = iree_hal_amdgpu_executable_profile_id(executable);
+    should_profile_dispatch =
+        iree_hal_amdgpu_host_queue_should_profile_dispatch(queue, executable_id,
+                                                           export_ordinal);
+  }
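+  // Profiling reservations are made up front and cancelled on any failure or
+  // parked submission so nothing leaks into the retry path.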
+  iree_hal_amdgpu_profile_dispatch_event_reservation_t profile_events = {0};
+  iree_status_t status = iree_ok_status();
+  if (should_profile_dispatch) {
+    status = iree_hal_amdgpu_host_queue_reserve_profile_dispatch_events(
+        queue, /*event_count=*/1, &profile_events);
+  }
+  if (iree_status_is_ok(status) && profile_events.event_count != 0) {
+    status = iree_hal_amdgpu_host_queue_prepare_profile_counter_samples(
+        queue, profile_events);
+  }
+  if (iree_status_is_ok(status) && profile_events.event_count != 0) {
+    status = iree_hal_amdgpu_host_queue_prepare_profile_traces(queue,
+                                                               profile_events);
+  }
+  if (iree_status_is_ok(status) && profile_events.event_count != 0) {
+    status = iree_hal_amdgpu_host_queue_prepare_profile_trace_code_object(
+        queue, profile_events.first_event_position, executable_id);
+  }
+  if (!iree_status_is_ok(status)) {
+    iree_hal_amdgpu_host_queue_cancel_profile_dispatch_events(queue,
+                                                              profile_events);
+    return status;
+  }
+  iree_hal_amdgpu_host_queue_profile_event_info_t profile_queue_event_info = {
+      .type = IREE_HAL_PROFILE_QUEUE_EVENT_TYPE_DISPATCH,
+      .operation_count = 1,
+  };
+  iree_hal_amdgpu_host_queue_dispatch_submission_t submission;
+  status = iree_hal_amdgpu_host_queue_try_begin_dispatch_submission(
+      queue, resolution, signal_semaphore_list, plan->operation_resource_count,
+      plan->kernarg_block_count, profile_events, &profile_queue_event_info,
+      out_ready, &submission);
+  if (!iree_status_is_ok(status) || !*out_ready) {
+    iree_hal_amdgpu_host_queue_cancel_profile_dispatch_events(queue,
+                                                              profile_events);
+    return status;
+  }
+
+  if (uses_custom_direct_arguments) {
+    iree_hal_amdgpu_device_dispatch_emplace_custom_kernargs(
+        plan->layout, constants.data, submission.kernel.kernargs.blocks->data);
+  } else {
+    iree_hal_amdgpu_device_dispatch_emplace_hal_kernargs(
+        plan->kernel_args, config.workgroup_count,
+        config.dynamic_workgroup_local_memory, plan->layout, binding_ptrs,
+        (const uint32_t*)constants.data,
+        submission.kernel.kernargs.blocks->data);
+  }
+  iree_hal_amdgpu_device_dispatch_emplace_packet(
+      plan->kernel_args, config.workgroup_count,
+      config.dynamic_workgroup_local_memory,
+      &submission.dispatch_slot->dispatch,
+      submission.kernel.kernargs.blocks->data);
+  submission.dispatch_slot->dispatch.completion_signal =
+      submission.dispatch_completion_signal;
+  submission.dispatch_setup = submission.dispatch_slot->dispatch.setup;
+  if (submission.profile_harvest_slot) {
+    iree_hal_amdgpu_profile_dispatch_event_t* event =
+        iree_hal_amdgpu_host_queue_profile_dispatch_event_at(
+            queue, profile_events.first_event_position);
+    iree_hal_amdgpu_host_queue_initialize_dispatch_event(
+        event, plan, export_ordinal, executable_id, config,
+        IREE_HAL_AMDGPU_PROFILE_DISPATCH_EVENT_FLAG_NONE);
+    iree_hal_amdgpu_profile_dispatch_harvest_source_t* sources =
+        iree_hal_amdgpu_device_timestamp_emplace_dispatch_harvest(
+            &queue->transfer_context->kernels
+                 ->iree_hal_amdgpu_device_timestamp_harvest_dispatch_records,
+            profile_events.event_count,
+            &submission.profile_harvest_slot->dispatch,
+            submission.profile_harvest_kernarg_blocks->data);
+    sources[0].completion_signal =
+        iree_hal_amdgpu_host_queue_profiling_completion_signal_ptr(
+            queue, profile_events.first_event_position);
+    sources[0].ticks = iree_hal_amdgpu_profile_dispatch_event_ticks(event);
+    submission.profile_harvest_setup =
+        submission.profile_harvest_slot->dispatch.setup;
+  }
+  const uint64_t submission_epoch =
+      iree_hal_amdgpu_host_queue_finish_dispatch_submission(
+          queue, resolution, signal_semaphore_list, operation_resources,
+          plan->operation_resource_count, &profile_queue_event_info,
+          submission_flags, &submission);
+  profile_queue_event_info.submission_id = submission_epoch;
+  iree_hal_amdgpu_host_queue_record_profile_queue_event(
+      queue, resolution, signal_semaphore_list, &profile_queue_event_info);
+  return iree_ok_status();
+}
+
+static iree_status_t iree_hal_amdgpu_host_queue_submit_indirect_dispatch(
+    iree_hal_amdgpu_host_queue_t* queue,
+    const iree_hal_amdgpu_wait_resolution_t* resolution,
+    const iree_hal_semaphore_list_t signal_semaphore_list,
+    const iree_hal_amdgpu_host_queue_dispatch_plan_t* plan,
+    iree_hal_executable_t* executable,
+    iree_hal_executable_export_ordinal_t export_ordinal,
+    const iree_hal_dispatch_config_t config, iree_const_byte_span_t constants,
+    const uint64_t* binding_ptrs, uint64_t workgroup_count_ptr,
+    iree_hal_resource_t* const* operation_resources,
+    bool uses_custom_direct_arguments,
+    iree_hal_amdgpu_host_queue_submission_flags_t submission_flags,
+    bool* out_ready) {
+  uint64_t executable_id = 0;
+  bool should_profile_dispatch = false;
+  if (queue->profiling.dispatch_profiling_enabled) {
+    executable_id = iree_hal_amdgpu_executable_profile_id(executable);
+    should_profile_dispatch =
+        iree_hal_amdgpu_host_queue_should_profile_dispatch(queue, executable_id,
+                                                           export_ordinal);
+  }
+  const uint32_t target_kernarg_block_count = plan->kernarg_block_count;
+  const uint32_t patch_kernarg_block_count = 1;
+  iree_hal_amdgpu_profile_dispatch_event_reservation_t profile_events = {0};
+  iree_status_t status = iree_ok_status();
+  if (should_profile_dispatch) {
+    status = iree_hal_amdgpu_host_queue_reserve_profile_dispatch_events(
+        queue, /*event_count=*/1, &profile_events);
+  }
+  if (iree_status_is_ok(status) && profile_events.event_count != 0) {
+    status = iree_hal_amdgpu_host_queue_prepare_profile_counter_samples(
+        queue, profile_events);
+  }
+  if (iree_status_is_ok(status) && profile_events.event_count != 0) {
+    status = iree_hal_amdgpu_host_queue_prepare_profile_traces(queue,
+                                                               profile_events);
+  }
+  if (iree_status_is_ok(status) && profile_events.event_count != 0) {
+    status = iree_hal_amdgpu_host_queue_prepare_profile_trace_code_object(
+        queue, profile_events.first_event_position, executable_id);
+  }
+  if (!iree_status_is_ok(status)) {
+    iree_hal_amdgpu_host_queue_cancel_profile_dispatch_events(queue,
+                                                              profile_events);
+    return status;
+  }
+  iree_hal_amdgpu_host_queue_profile_event_info_t profile_queue_event_info = {
+      .type = IREE_HAL_PROFILE_QUEUE_EVENT_TYPE_DISPATCH,
+      .operation_count = 1,
+  };
+  iree_hal_amdgpu_profile_queue_device_event_reservation_t
+      profile_queue_device_events = {0};
+  if (iree_hal_amdgpu_host_queue_should_profile_queue_device_events(queue)) {
+    status = iree_hal_amdgpu_host_queue_reserve_profile_queue_device_events(
+        queue, /*event_count=*/1, &profile_queue_device_events);
+  }
+  if (!iree_status_is_ok(status)) {
+    iree_hal_amdgpu_host_queue_cancel_profile_dispatch_events(queue,
+                                                              profile_events);
+    return status;
+  }
+  const bool profile_dispatch_packet = profile_events.event_count != 0;
+  const bool profile_queue_device_event =
+      profile_queue_device_events.event_count != 0;
+  uint32_t profile_counter_set_count = 0;
+  uint32_t profile_counter_packet_count = 0;
+  uint32_t profile_trace_packet_count = 0;
+  uint32_t profile_trace_start_packet_count = 0;
+  if (profile_dispatch_packet) {
+    profile_counter_set_count =
+        iree_hal_amdgpu_host_queue_profile_counter_set_count(queue,
+                                                             profile_events);
+    profile_counter_packet_count =
+        iree_hal_amdgpu_host_queue_profile_counter_packet_count(queue,
+                                                                profile_events);
+    profile_trace_packet_count =
+        iree_hal_amdgpu_host_queue_profile_trace_packet_count(queue,
+                                                              profile_events);
+    profile_trace_start_packet_count =
+        iree_hal_amdgpu_host_queue_profile_trace_start_packet_count(
+            queue, profile_events);
+  }
+  const uint32_t profile_trace_stop_packet_count =
+      profile_trace_packet_count - profile_trace_start_packet_count;
+  const uint32_t profile_queue_device_prefix_packet_count =
+      profile_queue_device_event ? 1u : 0u;
+  const uint32_t profile_queue_device_suffix_packet_count =
+      profile_queue_device_event ? 1u : 0u;
+  const uint32_t profile_queue_device_packet_count =
+      profile_queue_device_prefix_packet_count +
+      profile_queue_device_suffix_packet_count;
+  const uint32_t payload_packet_count =
+      profile_queue_device_packet_count + 2u + profile_counter_packet_count +
+      profile_trace_packet_count + (profile_dispatch_packet ? 1u : 0u);
+  const uint32_t profile_harvest_kernarg_block_count =
+      profile_dispatch_packet
+          ? (uint32_t)iree_host_size_ceil_div(
+                iree_hal_amdgpu_device_timestamp_dispatch_harvest_kernarg_length(
+                    profile_events.event_count),
+                sizeof(iree_hal_amdgpu_kernarg_block_t))
+          : 0u;
+  iree_hal_amdgpu_host_queue_kernel_submission_t submission;
+  status = iree_hal_amdgpu_host_queue_try_begin_kernel_submission(
+      queue, resolution, signal_semaphore_list, plan->operation_resource_count,
+      payload_packet_count,
+      patch_kernarg_block_count + target_kernarg_block_count +
+          profile_harvest_kernarg_block_count,
+      out_ready, &submission);
+  if (!iree_status_is_ok(status) || !*out_ready) {
+    iree_hal_amdgpu_host_queue_cancel_profile_dispatch_events(queue,
+                                                              profile_events);
+    iree_hal_amdgpu_host_queue_cancel_profile_queue_device_events(
+        queue, profile_queue_device_events);
+    return status;
+  }
+
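+  // Payload packet layout within the reserved ring span:
+  //   [wait barriers][queue-device start?][patch][counter starts]
+  //   [trace starts][dispatch][counter stops][trace stops]
+  //   [profile harvest?][queue-device end?]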
+  const uint64_t patch_packet_id = submission.first_packet_id +
+                                   resolution->barrier_count +
+                                   profile_queue_device_prefix_packet_count;
+  const uint64_t dispatch_packet_id = patch_packet_id + 1u +
+                                      profile_counter_set_count +
+                                      profile_trace_start_packet_count;
+  const uint64_t profile_harvest_packet_id =
+      submission.first_packet_id + resolution->barrier_count +
+      payload_packet_count - 1u - profile_queue_device_suffix_packet_count;
+  iree_hal_amdgpu_aql_packet_t* patch_packet =
+      iree_hal_amdgpu_aql_ring_packet(&queue->aql_ring, patch_packet_id);
+  iree_hal_amdgpu_aql_packet_t* dispatch_packet =
+      iree_hal_amdgpu_aql_ring_packet(&queue->aql_ring, dispatch_packet_id);
+  iree_hal_amdgpu_aql_packet_t* profile_harvest_packet = NULL;
+  if (profile_dispatch_packet) {
+    profile_harvest_packet = iree_hal_amdgpu_aql_ring_packet(
+        &queue->aql_ring, profile_harvest_packet_id);
+  }
+  iree_hal_amdgpu_kernarg_block_t* kernarg_blocks = submission.kernargs.blocks;
+  uint8_t* patch_kernarg_data = kernarg_blocks[0].data;
+  uint8_t* dispatch_kernarg_data = kernarg_blocks[1].data;
+  uint8_t* profile_harvest_kernarg_data = NULL;
+  if (profile_dispatch_packet) {
+    iree_hal_amdgpu_kernarg_block_t* profile_harvest_kernarg_blocks =
+        &kernarg_blocks[patch_kernarg_block_count + target_kernarg_block_count];
+    profile_harvest_kernarg_data = profile_harvest_kernarg_blocks->data;
+  }
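+  // Emit the dispatch with a zeroed workgroup count; the patch kernel
+  // emplaced below rewrites the packet from the device-side workgroup count
+  // buffer before publishing the dispatch header.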
+  const uint32_t placeholder_workgroup_count[3] = {0, 0, 0};
+  if (uses_custom_direct_arguments) {
+    iree_hal_amdgpu_device_dispatch_emplace_custom_kernargs(
+        plan->layout, constants.data, dispatch_kernarg_data);
+  } else {
+    iree_hal_amdgpu_device_dispatch_emplace_hal_kernargs(
+        plan->kernel_args, placeholder_workgroup_count,
+        config.dynamic_workgroup_local_memory, plan->layout, binding_ptrs,
+        (const uint32_t*)constants.data, dispatch_kernarg_data);
+  }
+  iree_hal_amdgpu_device_dispatch_emplace_packet(
+      plan->kernel_args, placeholder_workgroup_count,
+      config.dynamic_workgroup_local_memory, &dispatch_packet->dispatch,
+      dispatch_kernarg_data);
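+  // Completion routing: profiled dispatches complete on the per-event
+  // profiling signal so timestamps can be harvested; otherwise the
+  // queue-device suffix packet (when present) owns completion and the
+  // dispatch carries no signal; failing both, the notification ring epoch
+  // signal retires the submission.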
+  iree_hsa_signal_t dispatch_completion_signal =
+      profile_queue_device_event
+          ? iree_hsa_signal_null()
+          : iree_hal_amdgpu_notification_ring_epoch_signal(
+                &queue->notification_ring);
+  if (profile_dispatch_packet) {
+    dispatch_completion_signal =
+        iree_hal_amdgpu_host_queue_profiling_completion_signal(
+            queue, profile_events.first_event_position);
+  }
+  dispatch_packet->dispatch.completion_signal = dispatch_completion_signal;
+
+  iree_amdgpu_kernel_implicit_args_t* implicit_args =
+      plan->layout->has_implicit_args
+          ? (iree_amdgpu_kernel_implicit_args_t*)(dispatch_kernarg_data +
+                                                  plan->layout
+                                                      ->implicit_args_offset)
+          : NULL;
+  const uint16_t dispatch_setup = dispatch_packet->dispatch.setup;
+  const iree_hsa_fence_scope_t dispatch_acquire_scope =
+      iree_hal_amdgpu_host_queue_kernarg_acquire_scope(
+          queue, IREE_HSA_FENCE_SCOPE_AGENT);
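+  // When profiling packets follow, the dispatch releases at most at agent
+  // scope and the trailing packets carry the external release; otherwise the
+  // dispatch itself releases at the scope the signaled semaphores require.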
+  const iree_hal_amdgpu_aql_packet_control_t dispatch_packet_control =
+      (profile_dispatch_packet || profile_queue_device_event)
+          ? iree_hal_amdgpu_aql_packet_control_barrier(
+                iree_hal_amdgpu_host_queue_max_fence_scope(
+                    dispatch_acquire_scope, resolution->inline_acquire_scope),
+                profile_dispatch_packet ? IREE_HSA_FENCE_SCOPE_AGENT
+                                        : IREE_HSA_FENCE_SCOPE_NONE)
+          : iree_hal_amdgpu_aql_packet_control_barrier(
+                iree_hal_amdgpu_host_queue_max_fence_scope(
+                    dispatch_acquire_scope, resolution->inline_acquire_scope),
+                iree_hal_amdgpu_host_queue_signal_list_release_scope(
+                    queue, signal_semaphore_list));
+  const uint16_t dispatch_header = iree_hal_amdgpu_aql_make_header(
+      IREE_HSA_PACKET_TYPE_KERNEL_DISPATCH, dispatch_packet_control);
+  iree_hal_amdgpu_device_dispatch_emplace_indirect_params_patch(
+      &queue->transfer_context->kernels
+           ->iree_hal_amdgpu_device_dispatch_patch_indirect_params,
+      (const uint32_t*)(uintptr_t)workgroup_count_ptr,
+      &dispatch_packet->dispatch, dispatch_header, dispatch_setup,
+      implicit_args, &patch_packet->dispatch, patch_kernarg_data);
+  if (profile_dispatch_packet) {
+    iree_hal_amdgpu_profile_dispatch_event_t* event =
+        iree_hal_amdgpu_host_queue_profile_dispatch_event_at(
+            queue, profile_events.first_event_position);
+    iree_hal_amdgpu_host_queue_initialize_dispatch_event(
+        event, plan, export_ordinal, executable_id, config,
+        IREE_HAL_AMDGPU_PROFILE_DISPATCH_EVENT_FLAG_INDIRECT_PARAMETERS);
+    iree_hal_amdgpu_profile_dispatch_harvest_source_t* sources =
+        iree_hal_amdgpu_device_timestamp_emplace_dispatch_harvest(
+            &queue->transfer_context->kernels
+                 ->iree_hal_amdgpu_device_timestamp_harvest_dispatch_records,
+            profile_events.event_count, &profile_harvest_packet->dispatch,
+            profile_harvest_kernarg_data);
+    sources[0].completion_signal =
+        iree_hal_amdgpu_host_queue_profiling_completion_signal_ptr(
+            queue, profile_events.first_event_position);
+    sources[0].ticks = iree_hal_amdgpu_profile_dispatch_event_ticks(event);
+  }
+
+  iree_hal_amdgpu_host_queue_emit_kernel_submission_prefix(queue, resolution,
+                                                           &submission);
+  const uint64_t submission_epoch =
+      iree_hal_amdgpu_host_queue_finish_kernel_submission(
+          queue, resolution, signal_semaphore_list, operation_resources,
+          plan->operation_resource_count, /*inout_resource_set=*/NULL,
+          submission_flags, &submission);
+  profile_queue_event_info.submission_id = submission_epoch;
+  iree_hal_amdgpu_profile_queue_device_event_t* queue_device_event =
+      iree_hal_amdgpu_host_queue_initialize_profile_queue_device_event(
+          queue, profile_queue_device_events, &profile_queue_event_info);
+  if (queue_device_event) {
+    submission.reclaim_entry->queue_device_event_first_position =
+        profile_queue_device_events.first_event_position;
+    submission.reclaim_entry->queue_device_event_count =
+        profile_queue_device_events.event_count;
+    queue_device_event->submission_id = submission_epoch;
+  }
+  uint16_t profile_harvest_header = 0;
+  if (profile_dispatch_packet) {
+    submission.reclaim_entry->profile_event_first_position =
+        profile_events.first_event_position;
+    submission.reclaim_entry->profile_event_count = profile_events.event_count;
+    iree_hal_amdgpu_profile_dispatch_event_t* event =
+        iree_hal_amdgpu_host_queue_profile_dispatch_event_at(
+            queue, profile_events.first_event_position);
+    event->submission_id = submission_epoch;
+    profile_harvest_packet->dispatch.completion_signal =
+        queue_device_event ? iree_hsa_signal_null()
+                           : iree_hal_amdgpu_notification_ring_epoch_signal(
+                                 &queue->notification_ring);
+    const iree_hsa_fence_scope_t profile_harvest_acquire_scope =
+        iree_hal_amdgpu_host_queue_kernarg_acquire_scope(
+            queue, IREE_HSA_FENCE_SCOPE_AGENT);
+    profile_harvest_header = iree_hal_amdgpu_aql_make_header(
+        IREE_HSA_PACKET_TYPE_KERNEL_DISPATCH,
+        iree_hal_amdgpu_aql_packet_control_barrier(
+            iree_hal_amdgpu_host_queue_max_fence_scope(
+                profile_harvest_acquire_scope,
+                resolution->inline_acquire_scope),
+            IREE_HSA_FENCE_SCOPE_SYSTEM));
+  }
+  iree_hal_amdgpu_host_queue_publish_submission_kernargs(queue, &submission);
+  if (queue_device_event) {
+    iree_hal_amdgpu_host_queue_commit_queue_device_start_packet(
+        queue, resolution,
+        submission.first_packet_id + resolution->barrier_count,
+        queue_device_event);
+  }
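+  // Profiling packets occupy fixed slots around the dispatch: counter-start
+  // and trace-start/code-object packets directly follow the patch packet,
+  // while trace-stop and counter-read-stop packets directly follow the
+  // dispatch packet.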
+  if (profile_counter_set_count != 0) {
+    iree_hal_amdgpu_host_queue_commit_profile_counter_start_packets(
+        queue, profile_events.first_event_position, profile_counter_set_count,
+        patch_packet_id + 1u,
+        iree_hal_amdgpu_aql_packet_control_barrier(
+            iree_hal_amdgpu_host_queue_max_fence_scope(
+                IREE_HSA_FENCE_SCOPE_AGENT, resolution->inline_acquire_scope),
+            IREE_HSA_FENCE_SCOPE_AGENT));
+  }
+  if (profile_trace_packet_count != 0) {
+    iree_hal_amdgpu_host_queue_commit_profile_trace_start_packet(
+        queue, profile_events.first_event_position,
+        patch_packet_id + 1u + profile_counter_set_count,
+        iree_hal_amdgpu_aql_packet_control_barrier(
+            iree_hal_amdgpu_host_queue_max_fence_scope(
+                IREE_HSA_FENCE_SCOPE_AGENT, resolution->inline_acquire_scope),
+            IREE_HSA_FENCE_SCOPE_AGENT));
+    iree_hal_amdgpu_host_queue_commit_profile_trace_code_object_packet(
+        queue, profile_events.first_event_position,
+        patch_packet_id + 1u + profile_counter_set_count + 1u,
+        iree_hal_amdgpu_aql_packet_control_barrier(IREE_HSA_FENCE_SCOPE_AGENT,
+                                                   IREE_HSA_FENCE_SCOPE_AGENT));
+    iree_hal_amdgpu_host_queue_commit_profile_trace_stop_packet(
+        queue, profile_events.first_event_position, dispatch_packet_id + 1u,
+        iree_hal_amdgpu_aql_packet_control_barrier(IREE_HSA_FENCE_SCOPE_AGENT,
+                                                   IREE_HSA_FENCE_SCOPE_AGENT));
+  }
+  if (profile_counter_set_count != 0) {
+    iree_hal_amdgpu_host_queue_commit_profile_counter_read_stop_packets(
+        queue, profile_events.first_event_position, profile_counter_set_count,
+        dispatch_packet_id + 1u + profile_trace_stop_packet_count,
+        iree_hal_amdgpu_aql_packet_control_barrier(IREE_HSA_FENCE_SCOPE_AGENT,
+                                                   IREE_HSA_FENCE_SCOPE_AGENT));
+  }
+  if (profile_dispatch_packet) {
+    iree_hal_amdgpu_aql_ring_commit(profile_harvest_packet,
+                                    profile_harvest_header,
+                                    profile_harvest_packet->dispatch.setup);
+  }
+  if (queue_device_event) {
+    iree_hal_amdgpu_host_queue_commit_queue_device_end_packet(
+        queue, resolution, signal_semaphore_list,
+        submission.first_packet_id + submission.packet_count - 1,
+        queue_device_event);
+  }
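+  // Commit the patch packet last: once its header is visible the packet
+  // processor may run it, and it in turn publishes the dispatch packet.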
+  const uint16_t patch_setup = patch_packet->dispatch.setup;
+  const uint16_t patch_header = iree_hal_amdgpu_aql_make_header(
+      IREE_HSA_PACKET_TYPE_KERNEL_DISPATCH,
+      iree_hal_amdgpu_aql_packet_control_barrier(
+          iree_hal_amdgpu_host_queue_max_fence_scope(
+              IREE_HSA_FENCE_SCOPE_AGENT, resolution->inline_acquire_scope),
+          IREE_HSA_FENCE_SCOPE_AGENT));
+  iree_hal_amdgpu_aql_ring_commit(patch_packet, patch_header, patch_setup);
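+  // A single doorbell covers the entire contiguous packet range of the
+  // submission.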
+  iree_hal_amdgpu_aql_ring_doorbell(
+      &queue->aql_ring,
+      submission.first_packet_id + submission.packet_count - 1);
+  iree_hal_amdgpu_host_queue_record_profile_queue_event(
+      queue, resolution, signal_semaphore_list, &profile_queue_event_info);
+  return iree_ok_status();
+}
+
+iree_status_t iree_hal_amdgpu_host_queue_submit_dispatch(
+    iree_hal_amdgpu_host_queue_t* queue,
+    const iree_hal_amdgpu_wait_resolution_t* resolution,
+    const iree_hal_semaphore_list_t signal_semaphore_list,
+    iree_hal_executable_t* executable,
+    iree_hal_executable_export_ordinal_t export_ordinal,
+    const iree_hal_dispatch_config_t config, iree_const_byte_span_t constants,
+    const iree_hal_buffer_ref_list_t bindings, iree_hal_dispatch_flags_t flags,
+    iree_hal_amdgpu_host_queue_submission_flags_t submission_flags,
+    bool* out_ready) {
+  IREE_ASSERT_ARGUMENT(out_ready);
+  *out_ready = false;
+  if (IREE_UNLIKELY(queue->is_shutting_down)) {
+    return iree_make_status(IREE_STATUS_CANCELLED, "queue shutting down");
+  }
+  iree_hal_amdgpu_host_queue_dispatch_plan_t plan;
+  IREE_RETURN_IF_ERROR(iree_hal_amdgpu_host_queue_prepare_dispatch_plan(
+      queue, executable, export_ordinal, config, constants, bindings, flags,
+      &plan));
+
+  iree_hal_resource_t** operation_resources =
+      queue->dispatch_scratch.operation_resources;
+  uint64_t* binding_ptrs = queue->dispatch_scratch.binding_ptrs;
+
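+  // Slot 0 retains the executable for the lifetime of the submission; binding
+  // buffers follow unless the dispatch supplies custom direct arguments.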
+  const bool uses_custom_direct_arguments =
+      iree_any_bit_set(flags, IREE_HAL_DISPATCH_FLAG_CUSTOM_DIRECT_ARGUMENTS);
+  operation_resources[0] = (iree_hal_resource_t*)executable;
+
+  iree_status_t status = iree_ok_status();
+  if (!uses_custom_direct_arguments) {
+    status = iree_hal_amdgpu_host_queue_prepare_dispatch_bindings(
+        bindings, operation_resources, binding_ptrs);
+  }
+  uint64_t workgroup_count_ptr = 0;
+  if (iree_status_is_ok(status) && plan.uses_indirect_parameters) {
+    const iree_host_size_t resource_index =
+        uses_custom_direct_arguments ? 1 : 1 + bindings.count;
+    status = iree_hal_amdgpu_host_queue_prepare_dispatch_indirect_parameters(
+        config, operation_resources, resource_index, &workgroup_count_ptr);
+  }
+
+  if (iree_status_is_ok(status)) {
+    if (plan.uses_indirect_parameters) {
+      status = iree_hal_amdgpu_host_queue_submit_indirect_dispatch(
+          queue, resolution, signal_semaphore_list, &plan, executable,
+          export_ordinal, config, constants, binding_ptrs, workgroup_count_ptr,
+          operation_resources, uses_custom_direct_arguments, submission_flags,
+          out_ready);
+    } else {
+      status = iree_hal_amdgpu_host_queue_submit_direct_dispatch(
+          queue, resolution, signal_semaphore_list, &plan, executable,
+          export_ordinal, config, constants, binding_ptrs, operation_resources,
+          uses_custom_direct_arguments, submission_flags, out_ready);
+    }
+  }
+
+  return status;
+}
diff --git a/runtime/src/iree/hal/drivers/amdgpu/host_queue_dispatch.h b/runtime/src/iree/hal/drivers/amdgpu/host_queue_dispatch.h
new file mode 100644
index 0000000..672b20c
--- /dev/null
+++ b/runtime/src/iree/hal/drivers/amdgpu/host_queue_dispatch.h
@@ -0,0 +1,42 @@
+// Copyright 2026 The IREE Authors
+//
+// Licensed under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+
+#ifndef IREE_HAL_DRIVERS_AMDGPU_HOST_QUEUE_DISPATCH_H_
+#define IREE_HAL_DRIVERS_AMDGPU_HOST_QUEUE_DISPATCH_H_
+
+#include "iree/hal/drivers/amdgpu/host_queue_submission.h"
+
+#ifdef __cplusplus
+extern "C" {
+#endif  // __cplusplus
+
+// Validates a queue_dispatch request without requiring transient bindings to
+// have committed backing storage yet.
+iree_status_t iree_hal_amdgpu_host_queue_validate_dispatch(
+    const iree_hal_amdgpu_host_queue_t* queue,
+    iree_hal_executable_t* executable,
+    iree_hal_executable_export_ordinal_t export_ordinal,
+    const iree_hal_dispatch_config_t config, iree_const_byte_span_t constants,
+    const iree_hal_buffer_ref_list_t bindings, iree_hal_dispatch_flags_t flags,
+    iree_host_size_t* out_operation_resource_count);
+
+// Emits an executable kernel dispatch. Caller must hold submission_mutex.
+iree_status_t iree_hal_amdgpu_host_queue_submit_dispatch(
+    iree_hal_amdgpu_host_queue_t* queue,
+    const iree_hal_amdgpu_wait_resolution_t* resolution,
+    const iree_hal_semaphore_list_t signal_semaphore_list,
+    iree_hal_executable_t* executable,
+    iree_hal_executable_export_ordinal_t export_ordinal,
+    const iree_hal_dispatch_config_t config, iree_const_byte_span_t constants,
+    const iree_hal_buffer_ref_list_t bindings, iree_hal_dispatch_flags_t flags,
+    iree_hal_amdgpu_host_queue_submission_flags_t submission_flags,
+    bool* out_ready);
+
+#ifdef __cplusplus
+}  // extern "C"
+#endif  // __cplusplus
+
+#endif  // IREE_HAL_DRIVERS_AMDGPU_HOST_QUEUE_DISPATCH_H_
diff --git a/runtime/src/iree/hal/drivers/amdgpu/host_queue_file.c b/runtime/src/iree/hal/drivers/amdgpu/host_queue_file.c
new file mode 100644
index 0000000..31c6b75
--- /dev/null
+++ b/runtime/src/iree/hal/drivers/amdgpu/host_queue_file.c
@@ -0,0 +1,630 @@
+// Copyright 2026 The IREE Authors
+//
+// Licensed under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+
+#include "iree/hal/drivers/amdgpu/host_queue_file.h"
+
+#include <string.h>
+
+#include "iree/async/operations/file.h"
+#include "iree/hal/drivers/amdgpu/host_queue.h"
+#include "iree/hal/drivers/amdgpu/host_queue_profile.h"
+#include "iree/hal/drivers/amdgpu/host_queue_staging.h"
+#include "iree/hal/drivers/amdgpu/host_queue_submission.h"
+
+typedef enum iree_hal_amdgpu_file_action_kind_e {
+  IREE_HAL_AMDGPU_FILE_ACTION_READ,
+  IREE_HAL_AMDGPU_FILE_ACTION_WRITE,
+} iree_hal_amdgpu_file_action_kind_t;
+
+static iree_hal_profile_queue_event_type_t
+iree_hal_amdgpu_file_action_profile_event_type(
+    iree_hal_amdgpu_file_action_kind_t kind) {
+  return kind == IREE_HAL_AMDGPU_FILE_ACTION_READ
+             ? IREE_HAL_PROFILE_QUEUE_EVENT_TYPE_READ
+             : IREE_HAL_PROFILE_QUEUE_EVENT_TYPE_WRITE;
+}
+
+typedef struct iree_hal_amdgpu_file_action_state_t {
+  // Resource header retained by the queue reclaim entry and async completion.
+  iree_hal_resource_t resource;
+
+  // Host allocator used for this state and cloned semaphore-list storage.
+  iree_allocator_t host_allocator;
+
+  // Proactor used for async file I/O. Borrowed from the logical device.
+  iree_async_proactor_t* proactor;
+
+  // Queue used to publish final signal semaphores after async file I/O.
+  iree_hal_amdgpu_host_queue_t* queue;
+
+  // Logical device retained while async file I/O is pending so |queue| storage
+  // and its physical-device resources remain live.
+  iree_hal_device_t* logical_device;
+
+  // File being read or written. Retained while the action is pending.
+  iree_hal_file_t* file;
+
+  // Async file handle borrowed from |file|.
+  iree_async_file_t* async_file;
+
+  // Buffer being read into or written from. Retained while the action is
+  // pending.
+  iree_hal_buffer_t* buffer;
+
+  // File byte offset for the next transfer.
+  uint64_t file_offset;
+
+  // Buffer byte offset for the mapped range.
+  iree_device_size_t buffer_offset;
+
+  // Total requested transfer length.
+  iree_host_size_t requested_length;
+
+  // Number of wait semaphores supplied to the queue_read/write operation.
+  uint32_t profile_wait_count;
+
+  // Total bytes transferred by completed async file operations.
+  iree_host_size_t completed_length;
+
+  // Direction of the file action.
+  iree_hal_amdgpu_file_action_kind_t kind;
+
+  // Scoped mapping of |buffer| used by async file operations.
+  iree_hal_buffer_mapping_t mapping;
+
+  // Cloned signal list published by a final queue barrier after file I/O.
+  iree_hal_semaphore_list_t signal_semaphore_list;
+
+  // Completion-thread retry queued when the final signal barrier is blocked by
+  // temporary queue capacity pressure.
+  iree_hal_amdgpu_host_queue_post_drain_action_t signal_capacity_retry;
+
+  // Async read operation reused across partial completions.
+  iree_async_file_read_operation_t read_op;
+
+  // Async write operation reused across partial completions.
+  iree_async_file_write_operation_t write_op;
+} iree_hal_amdgpu_file_action_state_t;
+
+static void iree_hal_amdgpu_file_action_state_destroy(
+    iree_hal_resource_t* resource) {
+  iree_hal_amdgpu_file_action_state_t* state =
+      (iree_hal_amdgpu_file_action_state_t*)resource;
+  IREE_TRACE_ZONE_BEGIN(z0);
+  if (!iree_hal_semaphore_list_is_empty(state->signal_semaphore_list)) {
+    iree_hal_semaphore_list_free(state->signal_semaphore_list,
+                                 state->host_allocator);
+  }
+  iree_hal_buffer_release(state->buffer);
+  iree_hal_file_release(state->file);
+  iree_hal_device_release(state->logical_device);
+  iree_allocator_free(state->host_allocator, state);
+  IREE_TRACE_ZONE_END(z0);
+}
+
+static const iree_hal_resource_vtable_t
+    iree_hal_amdgpu_file_action_state_vtable = {
+        .destroy = iree_hal_amdgpu_file_action_state_destroy,
+};
+
+static iree_status_t iree_hal_amdgpu_host_queue_file_barrier(
+    iree_hal_amdgpu_virtual_queue_t* queue,
+    const iree_hal_semaphore_list_t wait_semaphore_list,
+    const iree_hal_semaphore_list_t signal_semaphore_list) {
+  return queue->vtable->execute(
+      queue, wait_semaphore_list, signal_semaphore_list,
+      /*command_buffer=*/NULL, iree_hal_buffer_binding_table_empty(),
+      IREE_HAL_EXECUTE_FLAG_NONE);
+}
+
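+// Casts a file offset to iree_device_size_t, verifying that both the offset
+// and offset+length are representable as device sizes.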
+static iree_status_t iree_hal_amdgpu_host_queue_cast_file_offset(
+    uint64_t file_offset, iree_device_size_t length,
+    iree_device_size_t* out_device_offset) {
+  *out_device_offset = 0;
+  if (IREE_UNLIKELY(file_offset > IREE_DEVICE_SIZE_MAX)) {
+    return iree_make_status(IREE_STATUS_OUT_OF_RANGE,
+                            "file offset %" PRIu64
+                            " exceeds device size max %" PRIdsz,
+                            file_offset, IREE_DEVICE_SIZE_MAX);
+  }
+  iree_device_size_t device_offset = (iree_device_size_t)file_offset;
+  iree_device_size_t device_end = 0;
+  if (IREE_UNLIKELY(
+          !iree_device_size_checked_add(device_offset, length, &device_end))) {
+    return iree_make_status(IREE_STATUS_OUT_OF_RANGE,
+                            "file range overflows device size (offset=%" PRIdsz
+                            ", length=%" PRIdsz ")",
+                            device_offset, length);
+  }
+  *out_device_offset = device_offset;
+  return iree_ok_status();
+}
+
+static iree_status_t iree_hal_amdgpu_host_queue_validate_file_range(
+    iree_hal_file_t* file, const char* operation_name, uint64_t file_offset,
+    iree_device_size_t length, iree_device_size_t* out_device_offset) {
+  IREE_RETURN_IF_ERROR(iree_hal_amdgpu_host_queue_cast_file_offset(
+      file_offset, length, out_device_offset));
+  const uint64_t file_length = iree_hal_file_length(file);
+  if (IREE_UNLIKELY(file_offset > file_length ||
+                    (uint64_t)length > file_length - file_offset)) {
+    return iree_make_status(IREE_STATUS_OUT_OF_RANGE,
+                            "%s range [%" PRIu64 ", %" PRIu64
+                            ") exceeds file length %" PRIu64,
+                            operation_name, file_offset,
+                            file_offset + (uint64_t)length, file_length);
+  }
+  return iree_ok_status();
+}
+
+static iree_status_t iree_hal_amdgpu_host_queue_validate_direct_file_buffer(
+    iree_hal_buffer_t* buffer, const char* operation_name,
+    iree_device_size_t buffer_offset, iree_device_size_t length) {
+  IREE_RETURN_IF_ERROR(
+      iree_hal_buffer_validate_range(buffer, buffer_offset, length));
+  if (IREE_UNLIKELY(length > (iree_device_size_t)IREE_HOST_SIZE_MAX)) {
+    return iree_make_status(IREE_STATUS_OUT_OF_RANGE,
+                            "%s length %" PRIdsz
+                            " exceeds host addressable size %" PRIhsz,
+                            operation_name, length, IREE_HOST_SIZE_MAX);
+  }
+  if (IREE_UNLIKELY(!iree_all_bits_set(iree_hal_buffer_memory_type(buffer),
+                                       IREE_HAL_MEMORY_TYPE_HOST_VISIBLE))) {
+    return iree_make_status(
+        IREE_STATUS_UNIMPLEMENTED,
+        "AMDGPU queue_%s for non-host-visible buffers requires bounded "
+        "chunked staging",
+        operation_name);
+  }
+  if (IREE_UNLIKELY(!iree_all_bits_set(iree_hal_buffer_allowed_usage(buffer),
+                                       IREE_HAL_BUFFER_USAGE_MAPPING_SCOPED))) {
+    return iree_make_status(
+        IREE_STATUS_UNIMPLEMENTED,
+        "AMDGPU queue_%s for non-mappable buffers requires bounded chunked "
+        "staging",
+        operation_name);
+  }
+  return iree_ok_status();
+}
+
+static bool iree_hal_amdgpu_host_queue_file_buffer_supports_direct_io(
+    iree_hal_buffer_t* buffer) {
+  return iree_all_bits_set(iree_hal_buffer_memory_type(buffer),
+                           IREE_HAL_MEMORY_TYPE_HOST_VISIBLE) &&
+         iree_all_bits_set(iree_hal_buffer_allowed_usage(buffer),
+                           IREE_HAL_BUFFER_USAGE_MAPPING_SCOPED);
+}
+
+static iree_status_t iree_hal_amdgpu_host_queue_validate_direct_file_handle(
+    iree_hal_file_t* file, const char* operation_name) {
+  if (IREE_UNLIKELY(!iree_hal_file_async_handle(file))) {
+    return iree_make_status(
+        IREE_STATUS_UNIMPLEMENTED,
+        "AMDGPU queue_%s for non-memory files requires a proactor-backed "
+        "async file handle",
+        operation_name);
+  }
+  return iree_ok_status();
+}
+
+static void iree_hal_amdgpu_file_action_fail_with_borrowed_status(
+    iree_hal_amdgpu_file_action_state_t* state, iree_status_t status) {
+  if (iree_hal_semaphore_list_is_empty(state->signal_semaphore_list)) {
+    return;
+  }
+  iree_hal_semaphore_list_fail(state->signal_semaphore_list,
+                               iree_status_clone(status));
+}
+
+static iree_status_t iree_hal_amdgpu_file_action_clone_queue_error(
+    iree_hal_amdgpu_file_action_state_t* state) {
+  iree_status_t error = (iree_status_t)iree_atomic_load(
+      &state->queue->error_status, iree_memory_order_acquire);
+  return iree_status_is_ok(error) ? iree_ok_status() : iree_status_clone(error);
+}
+
+static void iree_hal_amdgpu_file_action_signal_capacity_post_drain(
+    void* user_data);
+
+static iree_status_t iree_hal_amdgpu_file_action_submit_signal_barrier(
+    iree_hal_amdgpu_file_action_state_t* state) {
+  if (iree_hal_semaphore_list_is_empty(state->signal_semaphore_list)) {
+    return iree_ok_status();
+  }
+
+  IREE_RETURN_IF_ERROR(iree_hal_amdgpu_file_action_clone_queue_error(state));
+
+  iree_hal_amdgpu_wait_resolution_t resolution;
+  memset(&resolution, 0, sizeof(resolution));
+  resolution.inline_acquire_scope = IREE_HSA_FENCE_SCOPE_SYSTEM;
+  resolution.barrier_acquire_scope = IREE_HSA_FENCE_SCOPE_SYSTEM;
+
+  iree_slim_mutex_lock(&state->queue->locks.submission_mutex);
+  bool ready = false;
+  uint64_t submission_id = 0;
+  iree_hal_amdgpu_host_queue_profile_event_info_t profile_event_info = {
+      .type = iree_hal_amdgpu_file_action_profile_event_type(state->kind),
+      .payload_length = state->requested_length,
+      .operation_count = 1,
+  };
+  iree_status_t status = iree_hal_amdgpu_host_queue_try_submit_barrier(
+      state->queue, &resolution, state->signal_semaphore_list,
+      (iree_hal_amdgpu_reclaim_action_t){0},
+      /*operation_resources=*/NULL, /*operation_resource_count=*/0,
+      &profile_event_info,
+      iree_hal_amdgpu_host_queue_post_commit_callback_null(),
+      /*resource_set=*/NULL,
+      IREE_HAL_AMDGPU_HOST_QUEUE_SUBMISSION_FLAG_RETAIN_RESOURCES, &ready,
+      &submission_id);
+  if (iree_status_is_ok(status) && ready) {
+    iree_hal_amdgpu_wait_resolution_t profile_resolution = resolution;
+    profile_resolution.wait_count = state->profile_wait_count;
+    profile_event_info.submission_id = submission_id;
+    iree_hal_amdgpu_host_queue_record_profile_queue_event(
+        state->queue, &profile_resolution, state->signal_semaphore_list,
+        &profile_event_info);
+  }
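+  // A submission that is not ready hit queue packet capacity; retry the
+  // signal barrier from the completion thread once prior submissions drain.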
+  if (iree_status_is_ok(status) && !ready) {
+    iree_hal_resource_retain(&state->resource);
+    iree_hal_amdgpu_host_queue_enqueue_post_drain_action(
+        state->queue, &state->signal_capacity_retry,
+        iree_hal_amdgpu_file_action_signal_capacity_post_drain, state);
+  }
+  iree_slim_mutex_unlock(&state->queue->locks.submission_mutex);
+  return status;
+}
+
+static void iree_hal_amdgpu_file_action_signal_capacity_post_drain(
+    void* user_data) {
+  iree_hal_amdgpu_file_action_state_t* state =
+      (iree_hal_amdgpu_file_action_state_t*)user_data;
+  iree_status_t status =
+      iree_hal_amdgpu_file_action_submit_signal_barrier(state);
+  if (!iree_status_is_ok(status)) {
+    iree_hal_semaphore_list_fail(state->signal_semaphore_list, status);
+  }
+  iree_hal_resource_release(&state->resource);
+}
+
+static void iree_hal_amdgpu_file_action_complete(
+    iree_hal_amdgpu_file_action_state_t* state, iree_status_t status) {
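+  // Read contents arrived through a host mapping; flush non-coherent mappings
+  // so device-side consumers observe the new bytes.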
+  if (iree_status_is_ok(status) &&
+      state->kind == IREE_HAL_AMDGPU_FILE_ACTION_READ &&
+      !iree_all_bits_set(iree_hal_buffer_memory_type(state->mapping.buffer),
+                         IREE_HAL_MEMORY_TYPE_HOST_COHERENT)) {
+    status = iree_status_join(
+        status, iree_hal_buffer_mapping_flush_range(&state->mapping, 0,
+                                                    state->requested_length));
+  }
+  if (state->mapping.buffer) {
+    status =
+        iree_status_join(status, iree_hal_buffer_unmap_range(&state->mapping));
+  }
+
+  if (iree_status_is_ok(status)) {
+    status = iree_hal_amdgpu_file_action_submit_signal_barrier(state);
+  }
+  if (!iree_status_is_ok(status)) {
+    iree_hal_semaphore_list_fail(state->signal_semaphore_list, status);
+  }
+  iree_hal_resource_release(&state->resource);
+}
+
+static iree_status_t iree_hal_amdgpu_file_action_submit_next_read(
+    iree_hal_amdgpu_file_action_state_t* state);
+
+static iree_status_t iree_hal_amdgpu_file_action_submit_next_write(
+    iree_hal_amdgpu_file_action_state_t* state);
+
+static void iree_hal_amdgpu_file_action_read_complete(
+    void* user_data, iree_async_operation_t* base_operation,
+    iree_status_t status, iree_async_completion_flags_t flags) {
+  (void)base_operation;
+  (void)flags;
+  iree_hal_amdgpu_file_action_state_t* state =
+      (iree_hal_amdgpu_file_action_state_t*)user_data;
+
+  bool should_complete = true;
+  if (iree_status_is_ok(status) && state->read_op.bytes_read > 0) {
+    state->completed_length += state->read_op.bytes_read;
+    if (state->completed_length < state->requested_length) {
+      status = iree_hal_amdgpu_file_action_submit_next_read(state);
+      should_complete = !iree_status_is_ok(status);
+    }
+  } else if (iree_status_is_ok(status) &&
+             state->completed_length < state->requested_length) {
+    status = iree_make_status(IREE_STATUS_OUT_OF_RANGE,
+                              "short read: requested %" PRIhsz
+                              " bytes, got %" PRIhsz,
+                              state->requested_length, state->completed_length);
+  }
+
+  if (should_complete) {
+    iree_hal_amdgpu_file_action_complete(state, status);
+  }
+}
+
+static void iree_hal_amdgpu_file_action_write_complete(
+    void* user_data, iree_async_operation_t* base_operation,
+    iree_status_t status, iree_async_completion_flags_t flags) {
+  (void)base_operation;
+  (void)flags;
+  iree_hal_amdgpu_file_action_state_t* state =
+      (iree_hal_amdgpu_file_action_state_t*)user_data;
+
+  bool should_complete = true;
+  if (iree_status_is_ok(status) && state->write_op.bytes_written > 0) {
+    state->completed_length += state->write_op.bytes_written;
+    if (state->completed_length < state->requested_length) {
+      status = iree_hal_amdgpu_file_action_submit_next_write(state);
+      should_complete = !iree_status_is_ok(status);
+    }
+  } else if (iree_status_is_ok(status) &&
+             state->completed_length < state->requested_length) {
+    status = iree_make_status(IREE_STATUS_OUT_OF_RANGE,
+                              "short write: requested %" PRIhsz
+                              " bytes, wrote %" PRIhsz,
+                              state->requested_length, state->completed_length);
+  }
+
+  if (should_complete) {
+    iree_hal_amdgpu_file_action_complete(state, status);
+  }
+}
+
+static iree_status_t iree_hal_amdgpu_file_action_submit_next_read(
+    iree_hal_amdgpu_file_action_state_t* state) {
+  const iree_host_size_t remaining_length =
+      state->requested_length - state->completed_length;
+  iree_async_operation_zero(&state->read_op.base, sizeof(state->read_op));
+  iree_async_operation_initialize(
+      &state->read_op.base, IREE_ASYNC_OPERATION_TYPE_FILE_READ,
+      IREE_ASYNC_OPERATION_FLAG_NONE, iree_hal_amdgpu_file_action_read_complete,
+      state);
+  state->read_op.file = state->async_file;
+  state->read_op.offset = state->file_offset + state->completed_length;
+  state->read_op.buffer = iree_async_span_from_ptr(
+      state->mapping.contents.data + state->completed_length, remaining_length);
+  return iree_async_proactor_submit_one(state->proactor, &state->read_op.base);
+}
+
+static iree_status_t iree_hal_amdgpu_file_action_submit_next_write(
+    iree_hal_amdgpu_file_action_state_t* state) {
+  const iree_host_size_t remaining_length =
+      state->requested_length - state->completed_length;
+  iree_async_operation_zero(&state->write_op.base, sizeof(state->write_op));
+  iree_async_operation_initialize(
+      &state->write_op.base, IREE_ASYNC_OPERATION_TYPE_FILE_WRITE,
+      IREE_ASYNC_OPERATION_FLAG_NONE,
+      iree_hal_amdgpu_file_action_write_complete, state);
+  state->write_op.file = state->async_file;
+  state->write_op.offset = state->file_offset + state->completed_length;
+  state->write_op.buffer = iree_async_span_from_ptr(
+      state->mapping.contents.data + state->completed_length, remaining_length);
+  return iree_async_proactor_submit_one(state->proactor, &state->write_op.base);
+}
+
+static iree_status_t iree_hal_amdgpu_file_action_start_async(
+    iree_hal_amdgpu_file_action_state_t* state) {
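+  // Reads fully overwrite the mapped range and may discard its prior
+  // contents; writes only read from it.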
+  iree_hal_memory_access_t mapping_access = IREE_HAL_MEMORY_ACCESS_READ;
+  if (state->kind == IREE_HAL_AMDGPU_FILE_ACTION_READ) {
+    mapping_access = IREE_HAL_MEMORY_ACCESS_DISCARD_WRITE;
+  }
+  iree_status_t status = iree_hal_buffer_map_range(
+      state->buffer, IREE_HAL_MAPPING_MODE_SCOPED, mapping_access,
+      state->buffer_offset, state->requested_length, &state->mapping);
+
+  if (iree_status_is_ok(status) &&
+      state->kind == IREE_HAL_AMDGPU_FILE_ACTION_WRITE &&
+      !iree_all_bits_set(iree_hal_buffer_memory_type(state->mapping.buffer),
+                         IREE_HAL_MEMORY_TYPE_HOST_COHERENT)) {
+    status = iree_hal_buffer_mapping_invalidate_range(&state->mapping, 0,
+                                                      state->requested_length);
+  }
+
+  if (iree_status_is_ok(status)) {
+    iree_hal_resource_retain(&state->resource);
+    if (state->kind == IREE_HAL_AMDGPU_FILE_ACTION_READ) {
+      status = iree_hal_amdgpu_file_action_submit_next_read(state);
+    } else {
+      status = iree_hal_amdgpu_file_action_submit_next_write(state);
+    }
+    if (iree_status_is_ok(status)) {
+      return iree_ok_status();
+    }
+    iree_hal_resource_release(&state->resource);
+  }
+
+  if (state->mapping.buffer) {
+    status =
+        iree_status_join(status, iree_hal_buffer_unmap_range(&state->mapping));
+  }
+  return status;
+}
+
+static void iree_hal_amdgpu_file_action_execute(
+    iree_hal_amdgpu_reclaim_entry_t* entry, void* user_data,
+    iree_status_t status) {
+  (void)entry;
+  iree_hal_amdgpu_file_action_state_t* state =
+      (iree_hal_amdgpu_file_action_state_t*)user_data;
+
+  if (iree_status_is_ok(status)) {
+    status = iree_hal_amdgpu_file_action_start_async(state);
+    if (!iree_status_is_ok(status)) {
+      iree_hal_semaphore_list_fail(state->signal_semaphore_list, status);
+    }
+  } else {
+    iree_hal_amdgpu_file_action_fail_with_borrowed_status(state, status);
+  }
+}
+
+static iree_status_t iree_hal_amdgpu_file_action_state_create(
+    iree_hal_amdgpu_host_queue_t* queue,
+    iree_hal_amdgpu_file_action_kind_t kind, iree_hal_file_t* file,
+    uint64_t file_offset, iree_hal_buffer_t* buffer,
+    iree_device_size_t buffer_offset, iree_device_size_t length,
+    uint32_t profile_wait_count,
+    const iree_hal_semaphore_list_t signal_semaphore_list,
+    iree_hal_amdgpu_file_action_state_t** out_state) {
+  *out_state = NULL;
+  IREE_TRACE_ZONE_BEGIN(z0);
+  IREE_TRACE_ZONE_APPEND_VALUE_I64(z0, length);
+
+  iree_hal_amdgpu_file_action_state_t* state = NULL;
+  IREE_RETURN_AND_END_ZONE_IF_ERROR(
+      z0, iree_allocator_malloc(queue->host_allocator, sizeof(*state),
+                                (void**)&state));
+  memset(state, 0, sizeof(*state));
+  iree_hal_resource_initialize(&iree_hal_amdgpu_file_action_state_vtable,
+                               &state->resource);
+  state->host_allocator = queue->host_allocator;
+  state->proactor = queue->proactor;
+  state->queue = queue;
+  state->logical_device = queue->logical_device;
+  iree_hal_device_retain(state->logical_device);
+  state->file = file;
+  iree_hal_file_retain(state->file);
+  state->async_file = iree_hal_file_async_handle(file);
+  state->buffer = buffer;
+  iree_hal_buffer_retain(state->buffer);
+  state->file_offset = file_offset;
+  state->buffer_offset = buffer_offset;
+  state->requested_length = (iree_host_size_t)length;
+  state->profile_wait_count = profile_wait_count;
+  state->kind = kind;
+
+  iree_status_t status = iree_hal_semaphore_list_clone(
+      &signal_semaphore_list, state->host_allocator,
+      &state->signal_semaphore_list);
+  if (iree_status_is_ok(status)) {
+    *out_state = state;
+  } else {
+    iree_hal_resource_release(&state->resource);
+  }
+  IREE_TRACE_ZONE_END(z0);
+  return status;
+}
+
+static iree_status_t iree_hal_amdgpu_host_queue_submit_direct_file_action(
+    iree_hal_amdgpu_host_queue_t* queue,
+    const iree_hal_semaphore_list_t wait_semaphore_list,
+    const iree_hal_semaphore_list_t signal_semaphore_list,
+    iree_hal_amdgpu_file_action_kind_t kind, iree_hal_file_t* file,
+    uint64_t file_offset, iree_hal_buffer_t* buffer,
+    iree_device_size_t buffer_offset, iree_device_size_t length) {
+  iree_hal_amdgpu_file_action_state_t* state = NULL;
+  const uint32_t profile_wait_count =
+      iree_hal_amdgpu_host_queue_profile_semaphore_count(wait_semaphore_list);
+  IREE_RETURN_IF_ERROR(iree_hal_amdgpu_file_action_state_create(
+      queue, kind, file, file_offset, buffer, buffer_offset, length,
+      profile_wait_count, signal_semaphore_list, &state));
+
+  iree_hal_resource_t* resources[1] = {&state->resource};
+  iree_status_t status = iree_hal_amdgpu_host_queue_enqueue_host_action(
+      queue, wait_semaphore_list,
+      (iree_hal_amdgpu_reclaim_action_t){
+          .fn = iree_hal_amdgpu_file_action_execute,
+          .user_data = state,
+      },
+      resources, IREE_ARRAYSIZE(resources));
+  iree_hal_resource_release(&state->resource);
+  return status;
+}
+
+iree_status_t iree_hal_amdgpu_host_queue_read_file(
+    iree_hal_amdgpu_virtual_queue_t* queue,
+    const iree_hal_semaphore_list_t wait_semaphore_list,
+    const iree_hal_semaphore_list_t signal_semaphore_list,
+    iree_hal_file_t* source_file, uint64_t source_offset,
+    iree_hal_buffer_t* target_buffer, iree_device_size_t target_offset,
+    iree_device_size_t length, iree_hal_read_flags_t flags) {
+  IREE_RETURN_IF_ERROR(
+      iree_hal_file_validate_access(source_file, IREE_HAL_MEMORY_ACCESS_READ));
+  if (IREE_UNLIKELY(flags != IREE_HAL_READ_FLAG_NONE)) {
+    return iree_make_status(IREE_STATUS_INVALID_ARGUMENT,
+                            "unsupported read flags: 0x%" PRIx64, flags);
+  }
+  if (length == 0) {
+    return iree_hal_amdgpu_host_queue_file_barrier(queue, wait_semaphore_list,
+                                                   signal_semaphore_list);
+  }
+
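+  // Route the read: memory-backed files become queue-ordered copies,
+  // host-mappable targets use direct async file I/O, and everything else goes
+  // through bounded chunked staging.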
+  iree_hal_buffer_t* storage_buffer = iree_hal_file_storage_buffer(source_file);
+  iree_device_size_t source_device_offset = 0;
+  IREE_RETURN_IF_ERROR(iree_hal_amdgpu_host_queue_validate_file_range(
+      source_file, "read", source_offset, length, &source_device_offset));
+  if (!storage_buffer) {
+    IREE_RETURN_IF_ERROR(iree_hal_amdgpu_host_queue_validate_direct_file_handle(
+        source_file, "read"));
+    if (iree_hal_amdgpu_host_queue_file_buffer_supports_direct_io(
+            target_buffer)) {
+      IREE_RETURN_IF_ERROR(
+          iree_hal_amdgpu_host_queue_validate_direct_file_buffer(
+              target_buffer, "read", target_offset, length));
+      return iree_hal_amdgpu_host_queue_submit_direct_file_action(
+          (iree_hal_amdgpu_host_queue_t*)queue, wait_semaphore_list,
+          signal_semaphore_list, IREE_HAL_AMDGPU_FILE_ACTION_READ, source_file,
+          source_offset, target_buffer, target_offset, length);
+    }
+    return iree_hal_amdgpu_host_queue_submit_staged_read(
+        (iree_hal_amdgpu_host_queue_t*)queue, wait_semaphore_list,
+        signal_semaphore_list, source_file, source_offset, target_buffer,
+        target_offset, length);
+  }
+  return iree_hal_amdgpu_host_queue_copy_buffer(
+      (iree_hal_amdgpu_host_queue_t*)queue, wait_semaphore_list,
+      signal_semaphore_list, storage_buffer, source_device_offset,
+      target_buffer, target_offset, length, IREE_HAL_COPY_FLAG_NONE,
+      IREE_HAL_PROFILE_QUEUE_EVENT_TYPE_READ);
+}
+
+iree_status_t iree_hal_amdgpu_host_queue_write_file(
+    iree_hal_amdgpu_virtual_queue_t* queue,
+    const iree_hal_semaphore_list_t wait_semaphore_list,
+    const iree_hal_semaphore_list_t signal_semaphore_list,
+    iree_hal_buffer_t* source_buffer, iree_device_size_t source_offset,
+    iree_hal_file_t* target_file, uint64_t target_offset,
+    iree_device_size_t length, iree_hal_write_flags_t flags) {
+  IREE_RETURN_IF_ERROR(
+      iree_hal_file_validate_access(target_file, IREE_HAL_MEMORY_ACCESS_WRITE));
+  if (IREE_UNLIKELY(flags != IREE_HAL_WRITE_FLAG_NONE)) {
+    return iree_make_status(IREE_STATUS_INVALID_ARGUMENT,
+                            "unsupported write flags: 0x%" PRIx64, flags);
+  }
+  if (length == 0) {
+    return iree_hal_amdgpu_host_queue_file_barrier(queue, wait_semaphore_list,
+                                                   signal_semaphore_list);
+  }
+
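+  // Route the write symmetrically to queue_read: copies for memory-backed
+  // files, direct async file I/O for host-mappable sources, staging otherwise.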
+  iree_hal_buffer_t* storage_buffer = iree_hal_file_storage_buffer(target_file);
+  iree_device_size_t target_device_offset = 0;
+  IREE_RETURN_IF_ERROR(iree_hal_amdgpu_host_queue_validate_file_range(
+      target_file, "write", target_offset, length, &target_device_offset));
+  if (!storage_buffer) {
+    IREE_RETURN_IF_ERROR(iree_hal_amdgpu_host_queue_validate_direct_file_handle(
+        target_file, "write"));
+    if (iree_hal_amdgpu_host_queue_file_buffer_supports_direct_io(
+            source_buffer)) {
+      IREE_RETURN_IF_ERROR(
+          iree_hal_amdgpu_host_queue_validate_direct_file_buffer(
+              source_buffer, "write", source_offset, length));
+      return iree_hal_amdgpu_host_queue_submit_direct_file_action(
+          (iree_hal_amdgpu_host_queue_t*)queue, wait_semaphore_list,
+          signal_semaphore_list, IREE_HAL_AMDGPU_FILE_ACTION_WRITE, target_file,
+          target_offset, source_buffer, source_offset, length);
+    }
+    return iree_hal_amdgpu_host_queue_submit_staged_write(
+        (iree_hal_amdgpu_host_queue_t*)queue, wait_semaphore_list,
+        signal_semaphore_list, source_buffer, source_offset, target_file,
+        target_offset, length);
+  }
+  return iree_hal_amdgpu_host_queue_copy_buffer(
+      (iree_hal_amdgpu_host_queue_t*)queue, wait_semaphore_list,
+      signal_semaphore_list, source_buffer, source_offset, storage_buffer,
+      target_device_offset, length, IREE_HAL_COPY_FLAG_NONE,
+      IREE_HAL_PROFILE_QUEUE_EVENT_TYPE_WRITE);
+}
diff --git a/runtime/src/iree/hal/drivers/amdgpu/host_queue_file.h b/runtime/src/iree/hal/drivers/amdgpu/host_queue_file.h
new file mode 100644
index 0000000..d0d25e1
--- /dev/null
+++ b/runtime/src/iree/hal/drivers/amdgpu/host_queue_file.h
@@ -0,0 +1,40 @@
+// Copyright 2026 The IREE Authors
+//
+// Licensed under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+
+#ifndef IREE_HAL_DRIVERS_AMDGPU_HOST_QUEUE_FILE_H_
+#define IREE_HAL_DRIVERS_AMDGPU_HOST_QUEUE_FILE_H_
+
+#include "iree/hal/drivers/amdgpu/virtual_queue.h"
+
+#ifdef __cplusplus
+extern "C" {
+#endif  // __cplusplus
+
+// Implements queue_read for memory-backed files (as queue copies), directly
+// mappable target buffers (as async file I/O), and all other fd-backed files
+// (via bounded chunked staging).
+iree_status_t iree_hal_amdgpu_host_queue_read_file(
+    iree_hal_amdgpu_virtual_queue_t* queue,
+    const iree_hal_semaphore_list_t wait_semaphore_list,
+    const iree_hal_semaphore_list_t signal_semaphore_list,
+    iree_hal_file_t* source_file, uint64_t source_offset,
+    iree_hal_buffer_t* target_buffer, iree_device_size_t target_offset,
+    iree_device_size_t length, iree_hal_read_flags_t flags);
+
+// Implements queue_write for memory-backed files (as queue copies), directly
+// mappable source buffers (as async file I/O), and all other fd-backed files
+// (via bounded chunked staging).
+iree_status_t iree_hal_amdgpu_host_queue_write_file(
+    iree_hal_amdgpu_virtual_queue_t* queue,
+    const iree_hal_semaphore_list_t wait_semaphore_list,
+    const iree_hal_semaphore_list_t signal_semaphore_list,
+    iree_hal_buffer_t* source_buffer, iree_device_size_t source_offset,
+    iree_hal_file_t* target_file, uint64_t target_offset,
+    iree_device_size_t length, iree_hal_write_flags_t flags);
+
+#ifdef __cplusplus
+}  // extern "C"
+#endif  // __cplusplus
+
+#endif  // IREE_HAL_DRIVERS_AMDGPU_HOST_QUEUE_FILE_H_
diff --git a/runtime/src/iree/hal/drivers/amdgpu/host_queue_host_call.c b/runtime/src/iree/hal/drivers/amdgpu/host_queue_host_call.c
new file mode 100644
index 0000000..8d55edc
--- /dev/null
+++ b/runtime/src/iree/hal/drivers/amdgpu/host_queue_host_call.c
@@ -0,0 +1,252 @@
+// Copyright 2026 The IREE Authors
+//
+// Licensed under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+
+#include "iree/hal/drivers/amdgpu/host_queue_host_call.h"
+
+#include <string.h>
+
+#include "iree/hal/drivers/amdgpu/host_queue_policy.h"
+#include "iree/hal/drivers/amdgpu/host_queue_profile.h"
+
+typedef struct iree_hal_amdgpu_host_call_state_t {
+  // Resource header so existing reclaim cleanup owns this cold payload.
+  iree_hal_resource_t resource;
+
+  // Host allocator used for this state and cloned semaphore-list storage.
+  iree_allocator_t host_allocator;
+
+  // Device reported to the host-call callback. Borrowed from the queue.
+  iree_hal_device_t* device;
+
+  // Queue affinity reported to the host-call callback.
+  iree_hal_queue_affinity_t queue_affinity;
+
+  // User callback and user data captured at queue_host_call submission.
+  iree_hal_host_call_t call;
+
+  // User arguments copied at queue_host_call submission.
+  uint64_t args[4];
+
+  // Host-call flags captured at queue_host_call submission.
+  iree_hal_host_call_flags_t flags;
+
+  // Cloned signal list retained until the reclaim action runs.
+  iree_hal_semaphore_list_t signal_semaphore_list;
+} iree_hal_amdgpu_host_call_state_t;
+
+static void iree_hal_amdgpu_host_call_state_destroy(
+    iree_hal_resource_t* resource) {
+  iree_hal_amdgpu_host_call_state_t* state =
+      (iree_hal_amdgpu_host_call_state_t*)resource;
+  IREE_TRACE_ZONE_BEGIN(z0);
+  if (!iree_hal_semaphore_list_is_empty(state->signal_semaphore_list)) {
+    iree_hal_semaphore_list_free(state->signal_semaphore_list,
+                                 state->host_allocator);
+  }
+  iree_allocator_free(state->host_allocator, state);
+  IREE_TRACE_ZONE_END(z0);
+}
+
+static const iree_hal_resource_vtable_t iree_hal_amdgpu_host_call_state_vtable =
+    {
+        .destroy = iree_hal_amdgpu_host_call_state_destroy,
+};
+
+iree_status_t iree_hal_amdgpu_host_queue_validate_host_call(
+    iree_hal_host_call_t call, const uint64_t args[4],
+    iree_hal_host_call_flags_t flags) {
+  const iree_hal_host_call_flags_t known_flags =
+      IREE_HAL_HOST_CALL_FLAG_NON_BLOCKING |
+      IREE_HAL_HOST_CALL_FLAG_WAIT_ACTIVE | IREE_HAL_HOST_CALL_FLAG_RELAXED;
+  if (IREE_UNLIKELY(!call.fn)) {
+    return iree_make_status(IREE_STATUS_INVALID_ARGUMENT,
+                            "host_call callback must be non-null");
+  }
+  if (IREE_UNLIKELY(!args)) {
+    return iree_make_status(IREE_STATUS_INVALID_ARGUMENT,
+                            "host_call args must be non-null");
+  }
+  if (IREE_UNLIKELY(iree_any_bit_set(flags, ~known_flags))) {
+    return iree_make_status(IREE_STATUS_INVALID_ARGUMENT,
+                            "unsupported host_call flags: 0x%" PRIx64, flags);
+  }
+  return iree_ok_status();
+}
+
+static iree_status_t iree_hal_amdgpu_host_call_state_create(
+    iree_hal_amdgpu_host_queue_t* queue,
+    const iree_hal_semaphore_list_t signal_semaphore_list,
+    iree_hal_host_call_t call, const uint64_t args[4],
+    iree_hal_host_call_flags_t flags,
+    iree_hal_amdgpu_host_call_state_t** out_state) {
+  *out_state = NULL;
+  IREE_TRACE_ZONE_BEGIN(z0);
+
+  IREE_RETURN_AND_END_ZONE_IF_ERROR(
+      z0, iree_hal_amdgpu_host_queue_validate_host_call(call, args, flags));
+  iree_hal_amdgpu_host_call_state_t* state = NULL;
+  IREE_RETURN_AND_END_ZONE_IF_ERROR(
+      z0, iree_allocator_malloc(queue->host_allocator, sizeof(*state),
+                                (void**)&state));
+  memset(state, 0, sizeof(*state));
+  iree_hal_resource_initialize(&iree_hal_amdgpu_host_call_state_vtable,
+                               &state->resource);
+  state->host_allocator = queue->host_allocator;
+  state->device = queue->logical_device;
+  state->queue_affinity = queue->queue_affinity;
+  state->call = call;
+  memcpy(state->args, args, sizeof(state->args));
+  state->flags = flags;
+
+  iree_status_t status = iree_hal_semaphore_list_clone(
+      &signal_semaphore_list, state->host_allocator,
+      &state->signal_semaphore_list);
+  if (iree_status_is_ok(status)) {
+    *out_state = state;
+  } else {
+    iree_hal_resource_release(&state->resource);
+  }
+  IREE_TRACE_ZONE_END(z0);
+  return status;
+}
+
+static void iree_hal_amdgpu_host_call_fail_with_borrowed_status(
+    iree_hal_semaphore_list_t signal_semaphore_list, iree_status_t status) {
+  if (signal_semaphore_list.count == 0) {
+    return;
+  }
+  iree_hal_semaphore_list_fail(signal_semaphore_list,
+                               iree_status_clone(status));
+}
+
+static void iree_hal_amdgpu_host_call_signal_or_fail(
+    iree_hal_semaphore_list_t signal_semaphore_list) {
+  iree_status_t signal_status =
+      iree_hal_semaphore_list_signal(signal_semaphore_list, /*frontier=*/NULL);
+  if (!iree_status_is_ok(signal_status)) {
+    iree_hal_semaphore_list_fail(signal_semaphore_list, signal_status);
+  }
+}
+
+// Consumes a callback status whose result is intentionally unobservable by the
+// host-call API contract. NON_BLOCKING callbacks are fire-and-forget after the
+// queue has signaled, and DEFERRED callbacks transfer completion ownership to
+// the callback's cloned signal list.
+static void iree_hal_amdgpu_host_call_consume_unobservable_status(
+    iree_status_t status) {
+  iree_status_free(status);
+}
+
+static void iree_hal_amdgpu_host_call_execute(
+    iree_hal_amdgpu_reclaim_entry_t* entry, void* user_data,
+    iree_status_t status) {
+  (void)entry;
+  iree_hal_amdgpu_host_call_state_t* state =
+      (iree_hal_amdgpu_host_call_state_t*)user_data;
+
+  if (!iree_status_is_ok(status)) {
+    iree_hal_amdgpu_host_call_fail_with_borrowed_status(
+        state->signal_semaphore_list, status);
+    return;
+  }
+
+  const bool is_nonblocking =
+      iree_any_bit_set(state->flags, IREE_HAL_HOST_CALL_FLAG_NON_BLOCKING);
+  if (is_nonblocking) {
+    iree_status_t signal_status = iree_hal_semaphore_list_signal(
+        state->signal_semaphore_list, /*frontier=*/NULL);
+    if (!iree_status_is_ok(signal_status)) {
+      iree_hal_semaphore_list_fail(state->signal_semaphore_list, signal_status);
+      return;
+    }
+  }
+
+  iree_hal_host_call_context_t context = {
+      .device = state->device,
+      .queue_affinity = state->queue_affinity,
+      .signal_semaphore_list = is_nonblocking ? iree_hal_semaphore_list_empty()
+                                              : state->signal_semaphore_list,
+  };
+  iree_status_t call_status =
+      state->call.fn(state->call.user_data, state->args, &context);
+
+  if (is_nonblocking) {
+    iree_hal_amdgpu_host_call_consume_unobservable_status(call_status);
+  } else if (iree_status_is_deferred(call_status)) {
+    iree_hal_amdgpu_host_call_consume_unobservable_status(call_status);
+  } else if (iree_status_is_ok(call_status)) {
+    iree_hal_amdgpu_host_call_signal_or_fail(state->signal_semaphore_list);
+  } else {
+    iree_hal_semaphore_list_fail(state->signal_semaphore_list, call_status);
+  }
+}
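+
+// Sketch of a host-call callback (hypothetical user code; the signature is
+// inferred from the invocation above). For blocking calls returning OK makes
+// the driver signal the semaphores on return, while a deferred status
+// transfers that responsibility to the callback via |context|:
+//
+//  static iree_status_t my_host_call(void* user_data, const uint64_t args[4],
+//                                    iree_hal_host_call_context_t* context) {
+//    // ... consume args; stash context->signal_semaphore_list if deferring.
+//    return iree_ok_status();  // the driver signals on return
+//  }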
+
+iree_status_t iree_hal_amdgpu_host_queue_submit_host_call(
+    iree_hal_amdgpu_host_queue_t* queue,
+    const iree_hal_amdgpu_wait_resolution_t* resolution,
+    const iree_hal_semaphore_list_t signal_semaphore_list,
+    iree_hal_host_call_t call, const uint64_t args[4],
+    iree_hal_host_call_flags_t flags, bool* out_ready) {
+  IREE_ASSERT_ARGUMENT(out_ready);
+  *out_ready = false;
+  if (IREE_UNLIKELY(queue->is_shutting_down)) {
+    return iree_make_status(IREE_STATUS_CANCELLED, "queue shutting down");
+  }
+
+  iree_hal_amdgpu_host_call_state_t* state = NULL;
+  IREE_RETURN_IF_ERROR(iree_hal_amdgpu_host_call_state_create(
+      queue, signal_semaphore_list, call, args, flags, &state));
+
+  // Host callbacks must observe device-produced data by default, so the wait
+  // resolution acquires at system scope. RELAXED opts out when the callback is
+  // known not to consume device-produced host-visible data.
+  iree_hal_amdgpu_wait_resolution_t host_call_resolution = *resolution;
+  if (!iree_any_bit_set(flags, IREE_HAL_HOST_CALL_FLAG_RELAXED)) {
+    host_call_resolution.inline_acquire_scope =
+        iree_hal_amdgpu_host_queue_max_fence_scope(
+            host_call_resolution.inline_acquire_scope,
+            IREE_HSA_FENCE_SCOPE_SYSTEM);
+    host_call_resolution.barrier_acquire_scope =
+        iree_hal_amdgpu_host_queue_max_fence_scope(
+            host_call_resolution.barrier_acquire_scope,
+            IREE_HSA_FENCE_SCOPE_SYSTEM);
+  }
+
+  iree_hal_resource_t* operation_resources[1] = {
+      &state->resource,
+  };
+  iree_hal_amdgpu_host_queue_profile_event_info_t profile_event_info = {
+      .type = IREE_HAL_PROFILE_QUEUE_EVENT_TYPE_HOST_CALL,
+      .operation_count = 1,
+  };
+  iree_hal_amdgpu_host_queue_barrier_submission_t submission;
+  iree_status_t status =
+      iree_hal_amdgpu_host_queue_try_begin_barrier_submission(
+          queue, &host_call_resolution, iree_hal_semaphore_list_empty(),
+          IREE_ARRAYSIZE(operation_resources), &profile_event_info, out_ready,
+          &submission);
+  if (iree_status_is_ok(status) && *out_ready) {
+    const uint64_t submission_id =
+        iree_hal_amdgpu_host_queue_finish_barrier_submission(
+            queue, &host_call_resolution, iree_hal_semaphore_list_empty(),
+            (iree_hal_amdgpu_reclaim_action_t){
+                .fn = iree_hal_amdgpu_host_call_execute,
+                .user_data = state,
+            },
+            operation_resources, IREE_ARRAYSIZE(operation_resources),
+            &profile_event_info,
+            iree_hal_amdgpu_host_queue_post_commit_callback_null(),
+            /*resource_set=*/NULL,
+            IREE_HAL_AMDGPU_HOST_QUEUE_SUBMISSION_FLAG_NONE, &submission);
+    profile_event_info.submission_id = submission_id;
+    iree_hal_amdgpu_host_queue_record_profile_queue_event(
+        queue, &host_call_resolution, signal_semaphore_list,
+        &profile_event_info);
+  }
+  if (!iree_status_is_ok(status) || !*out_ready) {
+    iree_hal_resource_release(&state->resource);
+  }
+  return status;
+}
diff --git a/runtime/src/iree/hal/drivers/amdgpu/host_queue_host_call.h b/runtime/src/iree/hal/drivers/amdgpu/host_queue_host_call.h
new file mode 100644
index 0000000..20655b4
--- /dev/null
+++ b/runtime/src/iree/hal/drivers/amdgpu/host_queue_host_call.h
@@ -0,0 +1,33 @@
+// Copyright 2026 The IREE Authors
+//
+// Licensed under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+
+#ifndef IREE_HAL_DRIVERS_AMDGPU_HOST_QUEUE_HOST_CALL_H_
+#define IREE_HAL_DRIVERS_AMDGPU_HOST_QUEUE_HOST_CALL_H_
+
+#include "iree/hal/drivers/amdgpu/host_queue_submission.h"
+
+#ifdef __cplusplus
+extern "C" {
+#endif  // __cplusplus
+
+// Validates host-call parameters before capture or submission.
+iree_status_t iree_hal_amdgpu_host_queue_validate_host_call(
+    iree_hal_host_call_t call, const uint64_t args[4],
+    iree_hal_host_call_flags_t flags);
+
+// Emits a host-call barrier epoch. Caller must hold submission_mutex.
+iree_status_t iree_hal_amdgpu_host_queue_submit_host_call(
+    iree_hal_amdgpu_host_queue_t* queue,
+    const iree_hal_amdgpu_wait_resolution_t* resolution,
+    const iree_hal_semaphore_list_t signal_semaphore_list,
+    iree_hal_host_call_t call, const uint64_t args[4],
+    iree_hal_host_call_flags_t flags, bool* out_ready);
+
+#ifdef __cplusplus
+}  // extern "C"
+#endif  // __cplusplus
+
+#endif  // IREE_HAL_DRIVERS_AMDGPU_HOST_QUEUE_HOST_CALL_H_
diff --git a/runtime/src/iree/hal/drivers/amdgpu/host_queue_memory.c b/runtime/src/iree/hal/drivers/amdgpu/host_queue_memory.c
new file mode 100644
index 0000000..e616c97
--- /dev/null
+++ b/runtime/src/iree/hal/drivers/amdgpu/host_queue_memory.c
@@ -0,0 +1,702 @@
+// Copyright 2026 The IREE Authors
+//
+// Licensed under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+
+#include "iree/hal/drivers/amdgpu/host_queue_memory.h"
+
+#include <string.h>
+
+#include "iree/hal/drivers/amdgpu/host_queue_profile.h"
+#include "iree/hal/drivers/amdgpu/logical_device.h"
+#include "iree/hal/drivers/amdgpu/transient_buffer.h"
+
+static void iree_hal_amdgpu_host_queue_populate_memory_event_pool_stats(
+    iree_hal_pool_t* pool, iree_hal_profile_memory_event_t* event) {
+  if (!pool) return;
+  iree_hal_pool_stats_t stats;
+  iree_hal_pool_query_stats(pool, &stats);
+  event->flags |= IREE_HAL_PROFILE_MEMORY_EVENT_FLAG_POOL_STATS;
+  event->pool_bytes_reserved = stats.bytes_reserved;
+  event->pool_bytes_free = stats.bytes_free;
+  event->pool_bytes_committed = stats.bytes_committed;
+  event->pool_budget_limit = stats.budget_limit;
+  event->pool_reservation_count = stats.reservation_count;
+  event->pool_slab_count = stats.slab_count;
+}
+
+static uint64_t iree_hal_amdgpu_host_queue_memory_profile_allocation_id(
+    iree_hal_buffer_t* buffer) {
+  return iree_hal_amdgpu_transient_buffer_profile_allocation_id(buffer);
+}
+
+static uint64_t iree_hal_amdgpu_host_queue_memory_profile_session_id(
+    iree_hal_buffer_t* buffer) {
+  return iree_hal_amdgpu_transient_buffer_profile_session_id(buffer);
+}
+
+static void iree_hal_amdgpu_host_queue_record_memory_event(
+    iree_hal_amdgpu_host_queue_t* queue,
+    iree_hal_profile_memory_event_type_t type,
+    iree_hal_profile_memory_event_flags_t flags, uint32_t result,
+    iree_hal_pool_t* pool, iree_hal_buffer_params_t params,
+    iree_hal_buffer_t* buffer, const iree_hal_pool_reservation_t* reservation,
+    iree_device_size_t length, uint64_t submission_id,
+    uint32_t frontier_entry_count) {
+  if (!iree_hal_amdgpu_logical_device_should_record_profile_memory_events(
+          queue->logical_device)) {
+    return;
+  }
+
+  iree_hal_profile_memory_event_t event =
+      iree_hal_profile_memory_event_default();
+  event.type = type;
+  event.flags = flags;
+  event.result = result;
+  event.allocation_id =
+      iree_hal_amdgpu_host_queue_memory_profile_allocation_id(buffer);
+  event.pool_id = (uint64_t)(uintptr_t)pool;
+  event.submission_id = submission_id;
+  event.physical_device_ordinal =
+      iree_hal_amdgpu_host_queue_profile_device_ordinal(queue);
+  event.queue_ordinal = iree_hal_amdgpu_host_queue_profile_queue_ordinal(queue);
+  event.frontier_entry_count = frontier_entry_count;
+  event.memory_type = params.type;
+  event.buffer_usage = params.usage;
+  event.length = length;
+  event.alignment = params.min_alignment ? params.min_alignment : 1;
+  if (reservation) {
+    event.flags |= IREE_HAL_PROFILE_MEMORY_EVENT_FLAG_POOL_RESERVATION;
+    event.backing_id = reservation->block_handle;
+    event.offset = reservation->offset;
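+    // Alloca/dealloca events report the caller-requested length; other pool
+    // events report the length of the underlying reservation.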
+    if (type != IREE_HAL_PROFILE_MEMORY_EVENT_TYPE_QUEUE_ALLOCA &&
+        type != IREE_HAL_PROFILE_MEMORY_EVENT_TYPE_QUEUE_DEALLOCA) {
+      event.length = reservation->length;
+    }
+  }
+  iree_hal_amdgpu_host_queue_populate_memory_event_pool_stats(pool, &event);
+  iree_hal_amdgpu_logical_device_record_profile_memory_event_for_session(
+      queue->logical_device,
+      iree_hal_amdgpu_host_queue_memory_profile_session_id(buffer), &event);
+}
+
+typedef struct iree_hal_amdgpu_host_queue_release_reservation_state_t {
+  // Queue whose frontier owns the release.
+  iree_hal_amdgpu_host_queue_t* queue;
+  // Transient buffer whose reservation is released.
+  iree_hal_buffer_t* buffer;
+} iree_hal_amdgpu_host_queue_release_reservation_state_t;
+
+static void iree_hal_amdgpu_host_queue_commit_transient_buffer(
+    iree_hal_amdgpu_reclaim_entry_t* entry, void* user_data,
+    iree_status_t status) {
+  (void)entry;
+  if (!iree_status_is_ok(status)) return;
+  iree_hal_amdgpu_transient_buffer_commit((iree_hal_buffer_t*)user_data);
+}
+
+static void iree_hal_amdgpu_host_queue_decommit_transient_buffer(
+    iree_hal_amdgpu_reclaim_entry_t* entry, void* user_data,
+    iree_status_t status) {
+  (void)entry;
+  if (!iree_status_is_ok(status)) return;
+  iree_hal_amdgpu_transient_buffer_decommit((iree_hal_buffer_t*)user_data);
+}
+
+static void iree_hal_amdgpu_host_queue_release_transient_buffer_reservation(
+    void* user_data, const iree_async_frontier_t* queue_frontier,
+    uint64_t submission_id) {
+  iree_hal_amdgpu_host_queue_release_reservation_state_t* state =
+      (iree_hal_amdgpu_host_queue_release_reservation_state_t*)user_data;
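+  // Query the reservation before releasing it so the profile event can still
+  // describe the pool block that backed the buffer.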
+  iree_hal_pool_t* pool = NULL;
+  iree_hal_pool_reservation_t reservation;
+  const bool has_reservation =
+      iree_hal_amdgpu_transient_buffer_query_reservation(state->buffer, &pool,
+                                                         &reservation);
+  iree_hal_buffer_params_t params = {0};
+  iree_device_size_t length = 0;
+  if (has_reservation) {
+    params = (iree_hal_buffer_params_t){
+        .type = iree_hal_buffer_memory_type(state->buffer),
+        .access = iree_hal_buffer_allowed_access(state->buffer),
+        .usage = iree_hal_buffer_allowed_usage(state->buffer),
+    };
+    length = iree_hal_buffer_byte_length(state->buffer);
+  }
+  iree_hal_amdgpu_transient_buffer_release_reservation(state->buffer,
+                                                       queue_frontier);
+  if (has_reservation) {
+    iree_hal_amdgpu_host_queue_record_memory_event(
+        state->queue, IREE_HAL_PROFILE_MEMORY_EVENT_TYPE_POOL_RELEASE,
+        IREE_HAL_PROFILE_MEMORY_EVENT_FLAG_QUEUE_OPERATION, UINT32_MAX, pool,
+        params, state->buffer, &reservation, length, submission_id,
+        queue_frontier ? queue_frontier->entry_count : 0);
+  }
+}
+
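+// Replaces the IREE_HAL_MEMORY_TYPE_OPTIMAL placeholder in |params| with the
+// selected pool's concrete memory type so later validation compares real bits.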
+static void iree_hal_amdgpu_host_queue_apply_pool_optimal_memory_type(
+    const iree_hal_pool_capabilities_t* capabilities,
+    iree_hal_buffer_params_t* params) {
+  if (iree_any_bit_set(params->type, IREE_HAL_MEMORY_TYPE_OPTIMAL)) {
+    params->type &= ~IREE_HAL_MEMORY_TYPE_OPTIMAL;
+    params->type |= capabilities->memory_type;
+  }
+}
+
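+// Resolves the pool that will back a queue_alloca: an explicitly provided
+// pool always wins; otherwise the queue's default pool set selects one by
+// params and allocation size.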
+static iree_status_t iree_hal_amdgpu_host_queue_select_alloca_pool(
+    iree_hal_amdgpu_host_queue_t* queue, iree_hal_pool_t* explicit_pool,
+    iree_hal_buffer_params_t* params, iree_device_size_t allocation_size,
+    iree_hal_pool_t** out_allocation_pool,
+    iree_hal_pool_capabilities_t* out_capabilities) {
+  if (explicit_pool) {
+    iree_hal_pool_query_capabilities(explicit_pool, out_capabilities);
+    iree_hal_amdgpu_host_queue_apply_pool_optimal_memory_type(out_capabilities,
+                                                              params);
+    *out_allocation_pool = explicit_pool;
+    return iree_ok_status();
+  }
+
+  iree_hal_pool_t* selected_pool = iree_hal_pool_set_select(
+      queue->default_pool_set, *params, allocation_size);
+  if (IREE_UNLIKELY(!selected_pool)) {
+    return iree_make_status(
+        IREE_STATUS_NOT_FOUND,
+        "no default AMDGPU pool can satisfy queue_alloca of %" PRIdsz " bytes",
+        allocation_size);
+  }
+  iree_hal_pool_query_capabilities(selected_pool, out_capabilities);
+  iree_hal_amdgpu_host_queue_apply_pool_optimal_memory_type(out_capabilities,
+                                                            params);
+  *out_allocation_pool = selected_pool;
+  return iree_ok_status();
+}
+
+iree_status_t iree_hal_amdgpu_host_queue_prepare_alloca_wrapper(
+    iree_hal_amdgpu_host_queue_t* queue, iree_hal_pool_t* pool,
+    iree_hal_buffer_params_t* params, iree_device_size_t allocation_size,
+    iree_hal_alloca_flags_t flags, iree_hal_pool_t** out_allocation_pool,
+    iree_hal_buffer_t** out_buffer) {
+  IREE_ASSERT_ARGUMENT(queue);
+  IREE_ASSERT_ARGUMENT(params);
+  IREE_ASSERT_ARGUMENT(out_allocation_pool);
+  IREE_ASSERT_ARGUMENT(out_buffer);
+  *out_allocation_pool = NULL;
+  *out_buffer = NULL;
+
+  if (IREE_UNLIKELY(iree_any_bit_set(
+          flags, ~(IREE_HAL_ALLOCA_FLAG_NONE |
+                   IREE_HAL_ALLOCA_FLAG_INDETERMINATE_LIFETIME |
+                   IREE_HAL_ALLOCA_FLAG_ALLOW_POOL_WAIT_FRONTIER)))) {
+    return iree_make_status(IREE_STATUS_INVALID_ARGUMENT,
+                            "unsupported alloca flags: 0x%" PRIx64, flags);
+  }
+  if (IREE_UNLIKELY(allocation_size == 0)) {
+    return iree_make_status(IREE_STATUS_INVALID_ARGUMENT,
+                            "queue_alloca allocation_size must be non-zero");
+  }
+
+  iree_hal_buffer_params_canonicalize(params);
+  iree_hal_pool_capabilities_t capabilities;
+  iree_hal_pool_t* allocation_pool = NULL;
+  IREE_RETURN_IF_ERROR(iree_hal_amdgpu_host_queue_select_alloca_pool(
+      queue, pool, params, allocation_size, &allocation_pool, &capabilities));
+  if (IREE_UNLIKELY(
+          !iree_all_bits_set(capabilities.memory_type, params->type))) {
+    iree_bitfield_string_temp_t requested_type_string;
+    iree_bitfield_string_temp_t pool_type_string;
+    iree_string_view_t requested_type =
+        iree_hal_memory_type_format(params->type, &requested_type_string);
+    iree_string_view_t pool_type = iree_hal_memory_type_format(
+        capabilities.memory_type, &pool_type_string);
+    return iree_make_status(
+        IREE_STATUS_INVALID_ARGUMENT,
+        "allocation pool does not support requested memory type %.*s "
+        "(pool memory type %.*s)",
+        (int)requested_type.size, requested_type.data, (int)pool_type.size,
+        pool_type.data);
+  }
+  if (IREE_UNLIKELY(
+          !iree_all_bits_set(capabilities.supported_usage, params->usage))) {
+    iree_bitfield_string_temp_t requested_usage_string;
+    iree_bitfield_string_temp_t pool_usage_string;
+    iree_string_view_t requested_usage =
+        iree_hal_buffer_usage_format(params->usage, &requested_usage_string);
+    iree_string_view_t pool_usage = iree_hal_buffer_usage_format(
+        capabilities.supported_usage, &pool_usage_string);
+    return iree_make_status(
+        IREE_STATUS_INVALID_ARGUMENT,
+        "allocation pool does not support requested buffer usage %.*s "
+        "(pool usage %.*s)",
+        (int)requested_usage.size, requested_usage.data, (int)pool_usage.size,
+        pool_usage.data);
+  }
+  if (IREE_UNLIKELY(capabilities.max_allocation_size != 0 &&
+                    allocation_size > capabilities.max_allocation_size)) {
+    return iree_make_status(
+        IREE_STATUS_RESOURCE_EXHAUSTED,
+        "queue_alloca allocation_size %" PRIdsz
+        " exceeds allocation pool max_allocation_size %" PRIdsz,
+        allocation_size, capabilities.max_allocation_size);
+  }
+
+  iree_hal_buffer_placement_t placement = {
+      .device = queue->logical_device,
+      .queue_affinity = queue->queue_affinity,
+      .flags = IREE_HAL_BUFFER_PLACEMENT_FLAG_ASYNCHRONOUS,
+  };
+  if (iree_all_bits_set(flags, IREE_HAL_ALLOCA_FLAG_INDETERMINATE_LIFETIME)) {
+    placement.flags |= IREE_HAL_BUFFER_PLACEMENT_FLAG_INDETERMINATE_LIFETIME;
+  }
+
+  IREE_RETURN_IF_ERROR(iree_hal_amdgpu_transient_buffer_create(
+      placement, *params, allocation_size, allocation_size,
+      queue->transient_buffer_pool, out_buffer));
+  uint64_t session_id = 0;
+  const uint64_t allocation_id =
+      iree_hal_amdgpu_logical_device_allocate_profile_memory_allocation_id(
+          queue->logical_device, &session_id);
+  if (allocation_id != 0) {
+    iree_hal_amdgpu_transient_buffer_set_profile_allocation(
+        *out_buffer, session_id, allocation_id);
+  }
+  *out_allocation_pool = allocation_pool;
+  return iree_ok_status();
+}
+
+iree_status_t iree_hal_amdgpu_host_queue_acquire_alloca_reservation(
+    iree_hal_amdgpu_host_queue_t* queue,
+    const iree_hal_amdgpu_wait_resolution_t* resolution,
+    iree_hal_pool_t* allocation_pool, iree_hal_buffer_params_t params,
+    iree_device_size_t allocation_size, iree_hal_alloca_flags_t flags,
+    iree_hal_pool_reserve_flags_t reserve_flags, iree_hal_buffer_t* buffer,
+    iree_hal_amdgpu_alloca_reservation_t* out_reservation) {
+  IREE_ASSERT_ARGUMENT(out_reservation);
+  memset(out_reservation, 0, sizeof(*out_reservation));
+  out_reservation->readiness = IREE_HAL_AMDGPU_ALLOCA_RESERVATION_READY;
+  out_reservation->acquire_result = IREE_HAL_POOL_ACQUIRE_EXHAUSTED;
+  out_reservation->wait_resolution = *resolution;
+
+  iree_hal_amdgpu_fixed_frontier_t requester_frontier_storage;
+  const iree_async_frontier_t* requester_frontier =
+      iree_hal_amdgpu_host_queue_pool_requester_frontier(
+          queue, resolution, &requester_frontier_storage);
+
+  IREE_RETURN_IF_ERROR(iree_hal_pool_acquire_reservation(
+      allocation_pool, allocation_size,
+      params.min_alignment ? params.min_alignment : 1, requester_frontier,
+      reserve_flags, &out_reservation->reservation,
+      &out_reservation->acquire_info, &out_reservation->acquire_result));
+
+  iree_hal_profile_memory_event_flags_t reserve_event_flags =
+      IREE_HAL_PROFILE_MEMORY_EVENT_FLAG_QUEUE_OPERATION;
+  if (out_reservation->acquire_result == IREE_HAL_POOL_ACQUIRE_OK_NEEDS_WAIT) {
+    reserve_event_flags |= IREE_HAL_PROFILE_MEMORY_EVENT_FLAG_WAIT_FRONTIER;
+  } else if (out_reservation->acquire_result ==
+                 IREE_HAL_POOL_ACQUIRE_EXHAUSTED ||
+             out_reservation->acquire_result ==
+                 IREE_HAL_POOL_ACQUIRE_OVER_BUDGET) {
+    reserve_event_flags |= IREE_HAL_PROFILE_MEMORY_EVENT_FLAG_WAIT_NOTIFICATION;
+  }
+  iree_hal_amdgpu_host_queue_record_memory_event(
+      queue, IREE_HAL_PROFILE_MEMORY_EVENT_TYPE_POOL_RESERVE,
+      reserve_event_flags, out_reservation->acquire_result, allocation_pool,
+      params, buffer,
+      out_reservation->acquire_result == IREE_HAL_POOL_ACQUIRE_EXHAUSTED ||
+              out_reservation->acquire_result ==
+                  IREE_HAL_POOL_ACQUIRE_OVER_BUDGET
+          ? NULL
+          : &out_reservation->reservation,
+      allocation_size, /*submission_id=*/0,
+      out_reservation->acquire_info.wait_frontier
+          ? out_reservation->acquire_info.wait_frontier->entry_count
+          : 0);
+
+  switch (out_reservation->acquire_result) {
+    case IREE_HAL_POOL_ACQUIRE_OK:
+    case IREE_HAL_POOL_ACQUIRE_OK_FRESH:
+      return iree_ok_status();
+    case IREE_HAL_POOL_ACQUIRE_OK_NEEDS_WAIT:
+      if (!iree_all_bits_set(flags,
+                             IREE_HAL_ALLOCA_FLAG_ALLOW_POOL_WAIT_FRONTIER)) {
+        iree_hal_pool_release_reservation(
+            allocation_pool, &out_reservation->reservation,
+            out_reservation->acquire_info.wait_frontier);
+        iree_hal_amdgpu_host_queue_record_memory_event(
+            queue, IREE_HAL_PROFILE_MEMORY_EVENT_TYPE_POOL_RELEASE,
+            IREE_HAL_PROFILE_MEMORY_EVENT_FLAG_QUEUE_OPERATION |
+                IREE_HAL_PROFILE_MEMORY_EVENT_FLAG_WAIT_FRONTIER,
+            out_reservation->acquire_result, allocation_pool, params, buffer,
+            &out_reservation->reservation, allocation_size,
+            /*submission_id=*/0,
+            out_reservation->acquire_info.wait_frontier
+                ? out_reservation->acquire_info.wait_frontier->entry_count
+                : 0);
+        return iree_make_status(
+            IREE_STATUS_RESOURCE_EXHAUSTED,
+            "queue_alloca recycled pool memory requires "
+            "IREE_HAL_ALLOCA_FLAG_ALLOW_POOL_WAIT_FRONTIER");
+      }
+      // A waitable pool reservation is legal whenever the HAL alloca flag
+      // permits one. Appending device-side barriers is only one way to
+      // represent the wait; non-local, over-capacity, or forced-DEFER
+      // frontiers must instead route to the cold host-gated memory-readiness
+      // path.
+      if (iree_hal_amdgpu_host_queue_append_pool_wait_frontier_barriers(
+              queue, requester_frontier,
+              out_reservation->acquire_info.wait_frontier,
+              &out_reservation->wait_resolution)) {
+        out_reservation->readiness = IREE_HAL_AMDGPU_ALLOCA_RESERVATION_READY;
+      } else {
+        out_reservation->readiness =
+            IREE_HAL_AMDGPU_ALLOCA_RESERVATION_NEEDS_FRONTIER_WAIT;
+        iree_hal_amdgpu_host_queue_record_memory_event(
+            queue, IREE_HAL_PROFILE_MEMORY_EVENT_TYPE_POOL_WAIT,
+            IREE_HAL_PROFILE_MEMORY_EVENT_FLAG_QUEUE_OPERATION |
+                IREE_HAL_PROFILE_MEMORY_EVENT_FLAG_WAIT_FRONTIER,
+            out_reservation->acquire_result, allocation_pool, params, buffer,
+            &out_reservation->reservation, allocation_size,
+            /*submission_id=*/0,
+            out_reservation->acquire_info.wait_frontier
+                ? out_reservation->acquire_info.wait_frontier->entry_count
+                : 0);
+      }
+      return iree_ok_status();
+    case IREE_HAL_POOL_ACQUIRE_EXHAUSTED:
+    case IREE_HAL_POOL_ACQUIRE_OVER_BUDGET:
+      if (out_reservation->acquire_result == IREE_HAL_POOL_ACQUIRE_EXHAUSTED &&
+          iree_all_bits_set(out_reservation->acquire_info.flags,
+                            IREE_HAL_POOL_ACQUIRE_FLAG_GROWTH_REQUIRED)) {
+        out_reservation->readiness =
+            IREE_HAL_AMDGPU_ALLOCA_RESERVATION_NEEDS_POOL_GROWTH;
+      } else {
+        out_reservation->readiness =
+            IREE_HAL_AMDGPU_ALLOCA_RESERVATION_NEEDS_POOL_NOTIFICATION;
+      }
+      iree_hal_amdgpu_host_queue_record_memory_event(
+          queue, IREE_HAL_PROFILE_MEMORY_EVENT_TYPE_POOL_WAIT,
+          IREE_HAL_PROFILE_MEMORY_EVENT_FLAG_QUEUE_OPERATION |
+              IREE_HAL_PROFILE_MEMORY_EVENT_FLAG_WAIT_NOTIFICATION,
+          out_reservation->acquire_result, allocation_pool, params, buffer,
+          /*reservation=*/NULL, allocation_size, /*submission_id=*/0,
+          /*frontier_entry_count=*/0);
+      return iree_ok_status();
+    default:
+      return iree_make_status(IREE_STATUS_INTERNAL,
+                              "unrecognized pool acquire result %u",
+                              out_reservation->acquire_result);
+  }
+}
+
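+// Returns the wait frontier that must accompany the release of a reservation
+// acquired as OK_NEEDS_WAIT; all other acquire results release with NULL.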
+static const iree_async_frontier_t*
+iree_hal_amdgpu_alloca_reservation_failure_frontier(
+    const iree_hal_amdgpu_alloca_reservation_t* alloca_reservation) {
+  return alloca_reservation->acquire_result ==
+                 IREE_HAL_POOL_ACQUIRE_OK_NEEDS_WAIT
+             ? alloca_reservation->acquire_info.wait_frontier
+             : NULL;
+}
+
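+// Builds the queue-profile event descriptor for an alloca targeting |buffer|.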
+static iree_hal_amdgpu_host_queue_profile_event_info_t
+iree_hal_amdgpu_host_queue_alloca_profile_event_info(
+    iree_hal_buffer_t* buffer) {
+  return (iree_hal_amdgpu_host_queue_profile_event_info_t){
+      .type = IREE_HAL_PROFILE_QUEUE_EVENT_TYPE_ALLOCA,
+      .allocation_id =
+          iree_hal_amdgpu_host_queue_memory_profile_allocation_id(buffer),
+      .payload_length = iree_hal_buffer_byte_length(buffer),
+      .operation_count = 1,
+  };
+}
+
+iree_status_t iree_hal_amdgpu_host_queue_materialize_alloca_reservation(
+    iree_hal_amdgpu_host_queue_t* queue,
+    const iree_hal_amdgpu_alloca_reservation_t* alloca_reservation,
+    iree_hal_pool_t* allocation_pool, iree_hal_buffer_params_t params,
+    iree_hal_buffer_t* buffer,
+    iree_hal_amdgpu_alloca_materialization_t* out_materialization) {
+  IREE_ASSERT_ARGUMENT(out_materialization);
+  memset(out_materialization, 0, sizeof(*out_materialization));
+  out_materialization->reservation.acquire_result =
+      IREE_HAL_POOL_ACQUIRE_EXHAUSTED;
+
+  iree_hal_buffer_t* backing_buffer = NULL;
+  IREE_RETURN_IF_ERROR(iree_hal_pool_materialize_reservation(
+      allocation_pool, params, &alloca_reservation->reservation,
+      IREE_HAL_POOL_MATERIALIZE_FLAG_NONE, &backing_buffer));
+
+  out_materialization->reservation = *alloca_reservation;
+  out_materialization->backing_buffer = backing_buffer;
+  iree_hal_amdgpu_host_queue_record_memory_event(
+      queue, IREE_HAL_PROFILE_MEMORY_EVENT_TYPE_POOL_MATERIALIZE,
+      IREE_HAL_PROFILE_MEMORY_EVENT_FLAG_QUEUE_OPERATION,
+      alloca_reservation->acquire_result, allocation_pool, params, buffer,
+      &alloca_reservation->reservation, iree_hal_buffer_byte_length(buffer),
+      /*submission_id=*/0,
+      alloca_reservation->acquire_info.wait_frontier
+          ? alloca_reservation->acquire_info.wait_frontier->entry_count
+          : 0);
+  return iree_ok_status();
+}
+
+void iree_hal_amdgpu_host_queue_release_alloca_materialization(
+    iree_hal_pool_t* allocation_pool,
+    iree_hal_amdgpu_alloca_materialization_t* materialization) {
+  if (!materialization) return;
+  iree_hal_buffer_release(materialization->backing_buffer);
+  materialization->backing_buffer = NULL;
+  switch (materialization->reservation.acquire_result) {
+    case IREE_HAL_POOL_ACQUIRE_OK:
+    case IREE_HAL_POOL_ACQUIRE_OK_FRESH:
+    case IREE_HAL_POOL_ACQUIRE_OK_NEEDS_WAIT:
+      iree_hal_pool_release_reservation(
+          allocation_pool, &materialization->reservation.reservation,
+          iree_hal_amdgpu_alloca_reservation_failure_frontier(
+              &materialization->reservation));
+      materialization->reservation.acquire_result =
+          IREE_HAL_POOL_ACQUIRE_EXHAUSTED;
+      break;
+    default:
+      break;
+  }
+}
+
+iree_status_t iree_hal_amdgpu_host_queue_submit_alloca_materialization(
+    iree_hal_amdgpu_host_queue_t* queue,
+    iree_hal_amdgpu_alloca_materialization_t* materialization,
+    const iree_hal_semaphore_list_t signal_semaphore_list,
+    iree_hal_pool_t* allocation_pool, iree_hal_buffer_params_t params,
+    iree_hal_buffer_t* buffer,
+    iree_hal_amdgpu_host_queue_submission_flags_t submission_flags,
+    bool* out_ready) {
+  IREE_ASSERT_ARGUMENT(materialization);
+  IREE_ASSERT_ARGUMENT(out_ready);
+  *out_ready = false;
+
+  iree_hal_resource_t* operation_resources[1] = {
+      (iree_hal_resource_t*)buffer,
+  };
+  iree_hal_amdgpu_host_queue_profile_event_info_t profile_event_info =
+      iree_hal_amdgpu_host_queue_alloca_profile_event_info(buffer);
+  iree_hal_amdgpu_host_queue_barrier_submission_t submission;
+  iree_status_t status =
+      iree_hal_amdgpu_host_queue_try_begin_barrier_submission(
+          queue, &materialization->reservation.wait_resolution,
+          signal_semaphore_list, IREE_ARRAYSIZE(operation_resources),
+          &profile_event_info, out_ready, &submission);
+  if (!iree_status_is_ok(status) || !*out_ready) {
+    iree_hal_amdgpu_host_queue_record_memory_event(
+        queue, IREE_HAL_PROFILE_MEMORY_EVENT_TYPE_POOL_RELEASE,
+        IREE_HAL_PROFILE_MEMORY_EVENT_FLAG_QUEUE_OPERATION,
+        materialization->reservation.acquire_result, allocation_pool, params,
+        buffer, &materialization->reservation.reservation,
+        iree_hal_buffer_byte_length(buffer), /*submission_id=*/0,
+        materialization->reservation.acquire_info.wait_frontier
+            ? materialization->reservation.acquire_info.wait_frontier
+                  ->entry_count
+            : 0);
+    iree_hal_amdgpu_host_queue_release_alloca_materialization(allocation_pool,
+                                                              materialization);
+    return status;
+  }
+
+  iree_hal_amdgpu_transient_buffer_attach_reservation(
+      buffer, allocation_pool, &materialization->reservation.reservation);
+  iree_hal_amdgpu_transient_buffer_stage_backing(
+      buffer, materialization->backing_buffer);
+  materialization->backing_buffer = NULL;
+  const uint64_t submission_epoch =
+      iree_hal_amdgpu_host_queue_finish_barrier_submission(
+          queue, &materialization->reservation.wait_resolution,
+          signal_semaphore_list,
+          (iree_hal_amdgpu_reclaim_action_t){
+              .fn = iree_hal_amdgpu_host_queue_commit_transient_buffer,
+              .user_data = buffer,
+          },
+          operation_resources, IREE_ARRAYSIZE(operation_resources),
+          &profile_event_info,
+          iree_hal_amdgpu_host_queue_post_commit_callback_null(),
+          /*resource_set=*/NULL, submission_flags, &submission);
+  profile_event_info.submission_id = submission_epoch;
+  iree_hal_amdgpu_host_queue_record_memory_event(
+      queue, IREE_HAL_PROFILE_MEMORY_EVENT_TYPE_QUEUE_ALLOCA,
+      IREE_HAL_PROFILE_MEMORY_EVENT_FLAG_QUEUE_OPERATION,
+      materialization->reservation.acquire_result, allocation_pool, params,
+      buffer, &materialization->reservation.reservation,
+      iree_hal_buffer_byte_length(buffer), submission_epoch,
+      materialization->reservation.acquire_info.wait_frontier
+          ? materialization->reservation.acquire_info.wait_frontier->entry_count
+          : 0);
+  iree_hal_amdgpu_host_queue_record_profile_queue_event(
+      queue, &materialization->reservation.wait_resolution,
+      signal_semaphore_list, &profile_event_info);
+  materialization->reservation.acquire_result = IREE_HAL_POOL_ACQUIRE_EXHAUSTED;
+  return status;
+}
+
+iree_status_t iree_hal_amdgpu_host_queue_submit_alloca_reservation(
+    iree_hal_amdgpu_host_queue_t* queue,
+    const iree_hal_amdgpu_alloca_reservation_t* alloca_reservation,
+    const iree_hal_semaphore_list_t signal_semaphore_list,
+    iree_hal_pool_t* allocation_pool, iree_hal_buffer_params_t params,
+    iree_hal_buffer_t* buffer,
+    iree_hal_amdgpu_host_queue_submission_flags_t submission_flags,
+    bool* out_ready) {
+  IREE_ASSERT_ARGUMENT(out_ready);
+  *out_ready = false;
+  const iree_async_frontier_t* reservation_failure_frontier =
+      iree_hal_amdgpu_alloca_reservation_failure_frontier(alloca_reservation);
+
+  iree_hal_resource_t* operation_resources[1] = {
+      (iree_hal_resource_t*)buffer,
+  };
+  iree_hal_amdgpu_host_queue_profile_event_info_t profile_event_info =
+      iree_hal_amdgpu_host_queue_alloca_profile_event_info(buffer);
+  iree_hal_amdgpu_host_queue_barrier_submission_t submission;
+  iree_status_t status =
+      iree_hal_amdgpu_host_queue_try_begin_barrier_submission(
+          queue, &alloca_reservation->wait_resolution, signal_semaphore_list,
+          IREE_ARRAYSIZE(operation_resources), &profile_event_info, out_ready,
+          &submission);
+  if (!iree_status_is_ok(status) || !*out_ready) {
+    iree_hal_pool_release_reservation(allocation_pool,
+                                      &alloca_reservation->reservation,
+                                      reservation_failure_frontier);
+    iree_hal_amdgpu_host_queue_record_memory_event(
+        queue, IREE_HAL_PROFILE_MEMORY_EVENT_TYPE_POOL_RELEASE,
+        IREE_HAL_PROFILE_MEMORY_EVENT_FLAG_QUEUE_OPERATION,
+        alloca_reservation->acquire_result, allocation_pool, params, buffer,
+        &alloca_reservation->reservation, iree_hal_buffer_byte_length(buffer),
+        /*submission_id=*/0,
+        alloca_reservation->acquire_info.wait_frontier
+            ? alloca_reservation->acquire_info.wait_frontier->entry_count
+            : 0);
+    return status;
+  }
+
+  iree_hal_buffer_t* backing_buffer = NULL;
+  status = iree_hal_pool_materialize_reservation(
+      allocation_pool, params, &alloca_reservation->reservation,
+      IREE_HAL_POOL_MATERIALIZE_FLAG_NONE, &backing_buffer);
+  if (iree_status_is_ok(status)) {
+    iree_hal_amdgpu_transient_buffer_attach_reservation(
+        buffer, allocation_pool, &alloca_reservation->reservation);
+    iree_hal_amdgpu_transient_buffer_stage_backing(buffer, backing_buffer);
+    iree_hal_amdgpu_host_queue_record_memory_event(
+        queue, IREE_HAL_PROFILE_MEMORY_EVENT_TYPE_POOL_MATERIALIZE,
+        IREE_HAL_PROFILE_MEMORY_EVENT_FLAG_QUEUE_OPERATION,
+        alloca_reservation->acquire_result, allocation_pool, params, buffer,
+        &alloca_reservation->reservation, iree_hal_buffer_byte_length(buffer),
+        /*submission_id=*/0,
+        alloca_reservation->acquire_info.wait_frontier
+            ? alloca_reservation->acquire_info.wait_frontier->entry_count
+            : 0);
+    backing_buffer = NULL;
+  }
+  iree_hal_buffer_release(backing_buffer);
+
+  uint64_t submission_epoch = 0;
+  if (iree_status_is_ok(status)) {
+    submission_epoch = iree_hal_amdgpu_host_queue_finish_barrier_submission(
+        queue, &alloca_reservation->wait_resolution, signal_semaphore_list,
+        (iree_hal_amdgpu_reclaim_action_t){
+            .fn = iree_hal_amdgpu_host_queue_commit_transient_buffer,
+            .user_data = buffer,
+        },
+        operation_resources, IREE_ARRAYSIZE(operation_resources),
+        &profile_event_info,
+        iree_hal_amdgpu_host_queue_post_commit_callback_null(),
+        /*resource_set=*/NULL, submission_flags, &submission);
+    profile_event_info.submission_id = submission_epoch;
+    iree_hal_amdgpu_host_queue_record_memory_event(
+        queue, IREE_HAL_PROFILE_MEMORY_EVENT_TYPE_QUEUE_ALLOCA,
+        IREE_HAL_PROFILE_MEMORY_EVENT_FLAG_QUEUE_OPERATION,
+        alloca_reservation->acquire_result, allocation_pool, params, buffer,
+        &alloca_reservation->reservation, iree_hal_buffer_byte_length(buffer),
+        submission_epoch,
+        alloca_reservation->acquire_info.wait_frontier
+            ? alloca_reservation->acquire_info.wait_frontier->entry_count
+            : 0);
+    iree_hal_amdgpu_host_queue_record_profile_queue_event(
+        queue, &alloca_reservation->wait_resolution, signal_semaphore_list,
+        &profile_event_info);
+  }
+
+  if (!iree_status_is_ok(status)) {
+    iree_hal_amdgpu_host_queue_fail_barrier_submission(queue, &submission);
+    iree_hal_pool_release_reservation(allocation_pool,
+                                      &alloca_reservation->reservation,
+                                      reservation_failure_frontier);
+    iree_hal_amdgpu_host_queue_record_memory_event(
+        queue, IREE_HAL_PROFILE_MEMORY_EVENT_TYPE_POOL_RELEASE,
+        IREE_HAL_PROFILE_MEMORY_EVENT_FLAG_QUEUE_OPERATION,
+        alloca_reservation->acquire_result, allocation_pool, params, buffer,
+        &alloca_reservation->reservation, iree_hal_buffer_byte_length(buffer),
+        /*submission_id=*/0,
+        alloca_reservation->acquire_info.wait_frontier
+            ? alloca_reservation->acquire_info.wait_frontier->entry_count
+            : 0);
+  }
+  return status;
+}
+
+iree_status_t iree_hal_amdgpu_host_queue_submit_dealloca(
+    iree_hal_amdgpu_host_queue_t* queue,
+    const iree_hal_amdgpu_wait_resolution_t* resolution,
+    const iree_hal_semaphore_list_t signal_semaphore_list,
+    iree_hal_buffer_t* buffer,
+    iree_hal_amdgpu_host_queue_submission_flags_t submission_flags,
+    bool* out_ready) {
+  IREE_ASSERT_ARGUMENT(out_ready);
+  *out_ready = false;
+  iree_hal_resource_t* operation_resources[1] = {
+      (iree_hal_resource_t*)buffer,
+  };
+  iree_hal_amdgpu_host_queue_profile_event_info_t profile_event_info = {
+      .type = IREE_HAL_PROFILE_QUEUE_EVENT_TYPE_DEALLOCA,
+      .allocation_id =
+          iree_hal_amdgpu_host_queue_memory_profile_allocation_id(buffer),
+      .payload_length = iree_hal_buffer_byte_length(buffer),
+      .operation_count = 1,
+  };
+  iree_hal_amdgpu_host_queue_barrier_submission_t submission;
+  IREE_RETURN_IF_ERROR(iree_hal_amdgpu_host_queue_try_begin_barrier_submission(
+      queue, resolution, signal_semaphore_list,
+      IREE_ARRAYSIZE(operation_resources), &profile_event_info, out_ready,
+      &submission));
+  if (!*out_ready) return iree_ok_status();
+
+  iree_hal_amdgpu_host_queue_release_reservation_state_t release_state = {
+      .queue = queue,
+      .buffer = buffer,
+  };
+  const uint64_t submission_epoch =
+      iree_hal_amdgpu_host_queue_finish_barrier_submission(
+          queue, resolution, signal_semaphore_list,
+          (iree_hal_amdgpu_reclaim_action_t){
+              .fn = iree_hal_amdgpu_host_queue_decommit_transient_buffer,
+              .user_data = buffer,
+          },
+          operation_resources, IREE_ARRAYSIZE(operation_resources),
+          &profile_event_info,
+          (iree_hal_amdgpu_host_queue_post_commit_callback_t){
+              .fn =
+                  iree_hal_amdgpu_host_queue_release_transient_buffer_reservation,
+              .user_data = &release_state,
+          },
+          /*resource_set=*/NULL, submission_flags, &submission);
+  profile_event_info.submission_id = submission_epoch;
+  iree_hal_buffer_params_t params = {
+      .type = iree_hal_buffer_memory_type(buffer),
+      .access = iree_hal_buffer_allowed_access(buffer),
+      .usage = iree_hal_buffer_allowed_usage(buffer),
+  };
+  iree_hal_amdgpu_host_queue_record_memory_event(
+      queue, IREE_HAL_PROFILE_MEMORY_EVENT_TYPE_QUEUE_DEALLOCA,
+      IREE_HAL_PROFILE_MEMORY_EVENT_FLAG_QUEUE_OPERATION, UINT32_MAX,
+      /*pool=*/NULL, params, buffer, /*reservation=*/NULL,
+      iree_hal_buffer_byte_length(buffer), submission_epoch,
+      /*frontier_entry_count=*/0);
+  iree_hal_amdgpu_host_queue_record_profile_queue_event(
+      queue, resolution, signal_semaphore_list, &profile_event_info);
+  return iree_ok_status();
+}
diff --git a/runtime/src/iree/hal/drivers/amdgpu/host_queue_memory.h b/runtime/src/iree/hal/drivers/amdgpu/host_queue_memory.h
new file mode 100644
index 0000000..2f28d07
--- /dev/null
+++ b/runtime/src/iree/hal/drivers/amdgpu/host_queue_memory.h
@@ -0,0 +1,117 @@
+// Copyright 2026 The IREE Authors
+//
+// Licensed under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+
+#ifndef IREE_HAL_DRIVERS_AMDGPU_HOST_QUEUE_MEMORY_H_
+#define IREE_HAL_DRIVERS_AMDGPU_HOST_QUEUE_MEMORY_H_
+
+#include "iree/hal/drivers/amdgpu/host_queue_submission.h"
+
+#ifdef __cplusplus
+extern "C" {
+#endif  // __cplusplus
+
+typedef enum iree_hal_amdgpu_alloca_reservation_readiness_e {
+  // The reservation can be materialized and submitted immediately.
+  IREE_HAL_AMDGPU_ALLOCA_RESERVATION_READY = 0,
+  // The reservation must wait for a pool death frontier before materialization.
+  IREE_HAL_AMDGPU_ALLOCA_RESERVATION_NEEDS_FRONTIER_WAIT = 1,
+  // The pool needs cold backing growth before another queue-locked reservation
+  // attempt can succeed.
+  IREE_HAL_AMDGPU_ALLOCA_RESERVATION_NEEDS_POOL_GROWTH = 2,
+  // The pool is exhausted or over budget and needs a release notification
+  // retry.
+  IREE_HAL_AMDGPU_ALLOCA_RESERVATION_NEEDS_POOL_NOTIFICATION = 3,
+} iree_hal_amdgpu_alloca_reservation_readiness_t;
+
+typedef struct iree_hal_amdgpu_alloca_reservation_t {
+  // Scheduler action required before the reservation can be submitted.
+  iree_hal_amdgpu_alloca_reservation_readiness_t readiness;
+  // Pool acquisition result that produced |reservation|.
+  iree_hal_pool_acquire_result_t acquire_result;
+  // Pool-owned byte range reserved for this alloca operation.
+  iree_hal_pool_reservation_t reservation;
+  // Borrowed metadata returned with |reservation|.
+  iree_hal_pool_acquire_info_t acquire_info;
+  // Queue wait resolution to use when publishing the alloca signal.
+  iree_hal_amdgpu_wait_resolution_t wait_resolution;
+} iree_hal_amdgpu_alloca_reservation_t;
+
+typedef struct iree_hal_amdgpu_alloca_materialization_t {
+  // Ready pool reservation that produced |backing_buffer|.
+  iree_hal_amdgpu_alloca_reservation_t reservation;
+  // Pool-backed buffer wrapper to stage into the transient alloca buffer.
+  iree_hal_buffer_t* backing_buffer;
+} iree_hal_amdgpu_alloca_materialization_t;
+
+// Resolves the allocation pool, validates/canonicalizes the request, and
+// creates the transient wrapper returned from queue_alloca.
+iree_status_t iree_hal_amdgpu_host_queue_prepare_alloca_wrapper(
+    iree_hal_amdgpu_host_queue_t* queue, iree_hal_pool_t* pool,
+    iree_hal_buffer_params_t* params, iree_device_size_t allocation_size,
+    iree_hal_alloca_flags_t flags, iree_hal_pool_t** out_allocation_pool,
+    iree_hal_buffer_t** out_buffer);
+
+// Attempts to reserve bytes from |allocation_pool| and classifies the result
+// as immediate, death-frontier-waitable, or notification-retry-required.
+iree_status_t iree_hal_amdgpu_host_queue_acquire_alloca_reservation(
+    iree_hal_amdgpu_host_queue_t* queue,
+    const iree_hal_amdgpu_wait_resolution_t* resolution,
+    iree_hal_pool_t* allocation_pool, iree_hal_buffer_params_t params,
+    iree_device_size_t allocation_size, iree_hal_alloca_flags_t flags,
+    iree_hal_pool_reserve_flags_t reserve_flags, iree_hal_buffer_t* buffer,
+    iree_hal_amdgpu_alloca_reservation_t* out_reservation);
+
+// Materializes a ready reservation. Does not require submission_mutex.
+iree_status_t iree_hal_amdgpu_host_queue_materialize_alloca_reservation(
+    iree_hal_amdgpu_host_queue_t* queue,
+    const iree_hal_amdgpu_alloca_reservation_t* alloca_reservation,
+    iree_hal_pool_t* allocation_pool, iree_hal_buffer_params_t params,
+    iree_hal_buffer_t* buffer,
+    iree_hal_amdgpu_alloca_materialization_t* out_materialization);
+
+// Releases a materialized reservation that was not submitted.
+void iree_hal_amdgpu_host_queue_release_alloca_materialization(
+    iree_hal_pool_t* allocation_pool,
+    iree_hal_amdgpu_alloca_materialization_t* materialization);
+
+// Stages a materialized reservation on |buffer| and submits the queue barrier
+// that commits the transient buffer on completion. Caller must hold
+// submission_mutex.
+iree_status_t iree_hal_amdgpu_host_queue_submit_alloca_materialization(
+    iree_hal_amdgpu_host_queue_t* queue,
+    iree_hal_amdgpu_alloca_materialization_t* materialization,
+    const iree_hal_semaphore_list_t signal_semaphore_list,
+    iree_hal_pool_t* allocation_pool, iree_hal_buffer_params_t params,
+    iree_hal_buffer_t* buffer,
+    iree_hal_amdgpu_host_queue_submission_flags_t submission_flags,
+    bool* out_ready);
+
+// Materializes a ready reservation, stages it on |buffer|, and submits the
+// queue barrier that commits the transient buffer on completion. Caller must
+// hold submission_mutex.
+iree_status_t iree_hal_amdgpu_host_queue_submit_alloca_reservation(
+    iree_hal_amdgpu_host_queue_t* queue,
+    const iree_hal_amdgpu_alloca_reservation_t* alloca_reservation,
+    const iree_hal_semaphore_list_t signal_semaphore_list,
+    iree_hal_pool_t* allocation_pool, iree_hal_buffer_params_t params,
+    iree_hal_buffer_t* buffer,
+    iree_hal_amdgpu_host_queue_submission_flags_t submission_flags,
+    bool* out_ready);
+
+// Submits the queue barrier that decommits a transient buffer on completion.
+iree_status_t iree_hal_amdgpu_host_queue_submit_dealloca(
+    iree_hal_amdgpu_host_queue_t* queue,
+    const iree_hal_amdgpu_wait_resolution_t* resolution,
+    const iree_hal_semaphore_list_t signal_semaphore_list,
+    iree_hal_buffer_t* buffer,
+    iree_hal_amdgpu_host_queue_submission_flags_t submission_flags,
+    bool* out_ready);
+
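+// Illustrative hot-path call sequence (a sketch only; locking, the cold
+// readiness paths, and error cleanup are elided; see host_queue_memory.c):
+//
+//   iree_hal_pool_t* allocation_pool = NULL;
+//   iree_hal_buffer_t* buffer = NULL;
+//   IREE_RETURN_IF_ERROR(iree_hal_amdgpu_host_queue_prepare_alloca_wrapper(
+//       queue, /*pool=*/NULL, &params, size, flags, &allocation_pool,
+//       &buffer));
+//   iree_hal_amdgpu_alloca_reservation_t reservation;
+//   IREE_RETURN_IF_ERROR(iree_hal_amdgpu_host_queue_acquire_alloca_reservation(
+//       queue, resolution, allocation_pool, params, size, flags, reserve_flags,
+//       buffer, &reservation));
+//   if (reservation.readiness == IREE_HAL_AMDGPU_ALLOCA_RESERVATION_READY) {
+//     bool ready = false;  // Caller must hold submission_mutex for submit.
+//     IREE_RETURN_IF_ERROR(iree_hal_amdgpu_host_queue_submit_alloca_reservation(
+//         queue, &reservation, signal_semaphore_list, allocation_pool, params,
+//         buffer, submission_flags, &ready));
+//   }  // Other readiness states route to the deferred memory-wait paths.
+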
+#ifdef __cplusplus
+}  // extern "C"
+#endif  // __cplusplus
+
+#endif  // IREE_HAL_DRIVERS_AMDGPU_HOST_QUEUE_MEMORY_H_
diff --git a/runtime/src/iree/hal/drivers/amdgpu/host_queue_pending.c b/runtime/src/iree/hal/drivers/amdgpu/host_queue_pending.c
new file mode 100644
index 0000000..c86b56a
--- /dev/null
+++ b/runtime/src/iree/hal/drivers/amdgpu/host_queue_pending.c
@@ -0,0 +1,1391 @@
+// Copyright 2026 The IREE Authors
+//
+// Licensed under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+
+#include <string.h>
+
+#include "iree/async/frontier_tracker.h"
+#include "iree/async/notification.h"
+#include "iree/async/operations/scheduling.h"
+#include "iree/base/threading/notification.h"
+#include "iree/hal/drivers/amdgpu/host_queue_memory.h"
+#include "iree/hal/drivers/amdgpu/host_queue_pending_operation.h"
+#include "iree/hal/drivers/amdgpu/host_queue_waits.h"
+#include "iree/hal/drivers/amdgpu/semaphore.h"
+#include "iree/hal/drivers/amdgpu/transient_buffer.h"
+
+//===----------------------------------------------------------------------===//
+// Pending operations (deferred submission)
+//===----------------------------------------------------------------------===//
+//
+// LOCKING PROTOCOL
+//
+// submission_mutex protects all submission-path state: AQL ring reservation,
+// kernarg allocation, packet emission, commit_signals, frontier mutation,
+// notification ring push, and the pending list (link/unlink).
+//
+// The completion thread (drain, error check) does NOT acquire
+// submission_mutex. It reads the notification ring (SPSC consumer) and the
+// atomic error_status.
+//
+// Deferred operations use a two-phase protocol:
+//
+//   Phase 1 (under submission_mutex): resolve_waits, allocate pending_op,
+//     capture operation parameters, link to pending list.
+//
+//   Phase 2 (WITHOUT submission_mutex): register timepoints via enqueue_waits.
+//     Timepoint callbacks may fire synchronously during acquire_timepoint
+//     (when the semaphore value is already reached or the semaphore is already
+//     failed). The last callback to fire calls pending_op_issue or
+//     pending_op_fail, both of which acquire submission_mutex internally.
+//     This is safe because Phase 1 released the mutex before Phase 2 began.
+//
+// pending_op_issue: acquires submission_mutex to emit AQL packets, transfer
+//   retained resources to the reclaim ring, commit signals, and unlink.
+//
+// pending_op_fail: acquires submission_mutex to unlink. Semaphore failure
+//   and resource release happen outside the lock.
+//
+// pending_op_destroy_under_lock: for capture-time failures (arena allocation
+//   errors after pending_op_allocate). Caller already holds submission_mutex.
+//   Does NOT re-acquire; unlinks and cleans up directly.
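+//
+// Illustrative shape of the two-phase flow (a sketch; capture details elided):
+//
+//   // Phase 1: caller holds submission_mutex.
+//   pending_op_allocate + parameter capture + pending_op_link(op);
+//   ...submission_mutex released...
+//   // Phase 2: no mutex held.
+//   iree_hal_amdgpu_pending_op_enqueue_waits(op);
+//   // The final timepoint callback ends in pending_op_issue(op) or
+//   // pending_op_fail(op, status), each taking submission_mutex itself.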
+
+// Per-wait timepoint entry; one is arena-allocated per unsatisfied wait. The
+// timepoint callback decrements the operation's atomic wait counter; the last
+// callback to fire issues or fails the operation.
+struct iree_hal_amdgpu_wait_entry_t {
+  // Async semaphore timepoint registration owned by this wait entry.
+  iree_async_semaphore_timepoint_t timepoint;
+  // Pending operation whose wait_count is decremented by this callback.
+  iree_hal_amdgpu_pending_op_t* operation;
+  // Set to 1 after the callback's final access to this entry/op completes.
+  // Queue shutdown spins on this for callbacks that were already detached from
+  // the semaphore before cancel_timepoint() ran.
+  iree_atomic_int32_t callback_complete;
+};
+
+typedef enum iree_hal_amdgpu_alloca_memory_wait_kind_e {
+  // No active memory wait or held reservation.
+  IREE_HAL_AMDGPU_ALLOCA_MEMORY_WAIT_NONE = 0,
+  // Waiting for a copied pool death frontier while holding a reservation.
+  IREE_HAL_AMDGPU_ALLOCA_MEMORY_WAIT_FRONTIER = 1,
+  // Performing cold pool backing growth before retrying reservation.
+  IREE_HAL_AMDGPU_ALLOCA_MEMORY_WAIT_POOL_GROWTH = 2,
+  // Waiting for a pool release notification before retrying reservation.
+  IREE_HAL_AMDGPU_ALLOCA_MEMORY_WAIT_POOL_NOTIFICATION = 3,
+} iree_hal_amdgpu_alloca_memory_wait_kind_t;
+
+// Cold-path alloca memory-readiness wait. Allocated inside a pending op's arena
+// only after user semaphore waits have resolved and the pool cannot produce
+// immediately-usable bytes.
+struct iree_hal_amdgpu_alloca_memory_wait_t {
+  // Active wait source.
+  iree_hal_amdgpu_alloca_memory_wait_kind_t kind;
+
+  // Set to 1 after the callback's final access to this wait/op completes.
+  iree_atomic_int32_t callback_complete;
+
+  // Held-reservation wait state blocked on a pool death frontier.
+  struct {
+    // Queue-owned reservation held while waiting for its death frontier.
+    iree_hal_pool_reservation_t reservation;
+
+    // Arena-owned copy of the pool-owned death frontier.
+    iree_async_frontier_t* wait_frontier;
+
+    // Tracker waiter storage for |wait_frontier|.
+    iree_async_frontier_waiter_t waiter;
+  } frontier;
+
+  // Cold pool backing-growth retry state for reservation attempts.
+  struct {
+    // Queue-order frontier snapshot used for cold reservation pre-growth.
+    iree_hal_amdgpu_fixed_frontier_t requester_frontier;
+
+    // Materialized pool reservation prepared before queue submission retry.
+    iree_hal_amdgpu_alloca_materialization_t materialization;
+  } pool_growth;
+
+  // Pool notification retry state for reservation attempts.
+  struct {
+    // Borrowed notification returned by the pool.
+    iree_async_notification_t* notification;
+
+    // Notification epoch observed before the reservation retry.
+    uint32_t wait_token;
+
+    // Whether the pre-submit observation scope is still held. Once submit
+    // returns, the submitted wait operation owns its own observation scope and
+    // this bridge scope is released.
+    bool pre_submit_observation_held;
+
+    // Wait operations rotated so a callback can arm a retry before returning.
+    iree_async_notification_wait_operation_t wait_ops[2];
+
+    // Index of the active wait operation in |wait_ops|.
+    uint8_t wait_slot;
+  } pool_notification;
+};
+
+static void iree_hal_amdgpu_pending_op_issue(iree_hal_amdgpu_pending_op_t* op);
+static void iree_hal_amdgpu_pending_op_capacity_post_drain(void* user_data);
+static void iree_hal_amdgpu_pending_op_fail(iree_hal_amdgpu_pending_op_t* op,
+                                            iree_status_t status);
+// Links a pending op into the queue's pending list. Caller must hold
+// submission_mutex.
+static void iree_hal_amdgpu_pending_op_link(iree_hal_amdgpu_pending_op_t* op) {
+  iree_hal_amdgpu_host_queue_t* queue = op->queue;
+  op->next = queue->pending_head;
+  op->prev_next = &queue->pending_head;
+  if (queue->pending_head) {
+    queue->pending_head->prev_next = &op->next;
+  }
+  queue->pending_head = op;
+}
+
+// Unlinks a pending op from the queue's pending list. Caller must hold
+// submission_mutex.
+static void iree_hal_amdgpu_pending_op_unlink(
+    iree_hal_amdgpu_pending_op_t* op) {
+  *op->prev_next = op->next;
+  if (op->next) {
+    op->next->prev_next = op->prev_next;
+  }
+  op->next = NULL;
+  op->prev_next = NULL;
+}
+
+// Retains a resource and appends it to the pending op's retained_resources
+// array. The caller must have allocated sufficient capacity in the array
+// via the max_resource_count parameter to pending_op_allocate.
+void iree_hal_amdgpu_pending_op_retain(iree_hal_amdgpu_pending_op_t* op,
+                                       iree_hal_resource_t* resource) {
+  if (IREE_LIKELY(resource)) {
+    iree_hal_resource_retain(resource);
+    op->retained_resources[op->retained_resource_count++] = resource;
+  }
+}
+
+// Releases all retained HAL resources in the flat array. Used on failure,
+// cancellation, and success paths where the submit helper retained the
+// resources it needs instead of consuming this pending op's refs.
+void iree_hal_amdgpu_pending_op_release_retained(
+    iree_hal_amdgpu_pending_op_t* op) {
+  for (uint16_t i = 0; i < op->retained_resource_count; ++i) {
+    iree_hal_resource_release(op->retained_resources[i]);
+  }
+  op->retained_resource_count = 0;
+}
+
+static void iree_hal_amdgpu_pending_op_release_execute_binding_resource_set(
+    iree_hal_amdgpu_pending_op_t* op) {
+  if (op->type == IREE_HAL_AMDGPU_PENDING_OP_EXECUTE) {
+    iree_hal_resource_set_free(op->execute.binding_resource_set);
+    op->execute.binding_resource_set = NULL;
+  }
+}
+
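+// Delivers |status| to a deferred host action exactly once and clears it so
+// later cleanup paths cannot invoke it again.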
+static void iree_hal_amdgpu_pending_op_fail_host_action(
+    iree_hal_amdgpu_pending_op_t* op, iree_status_t status) {
+  if (op->type != IREE_HAL_AMDGPU_PENDING_OP_HOST_ACTION ||
+      !op->host_action.action.fn) {
+    return;
+  }
+  op->host_action.action.fn(/*entry=*/NULL, op->host_action.action.user_data,
+                            status);
+  op->host_action.action.fn = NULL;
+  op->host_action.action.user_data = NULL;
+}
+
+// Releases any queue-owned alloca memory-readiness reservation. This runs only
+// on failure/cancellation paths, or when ownership never transferred into the
+// transient buffer.
+static void iree_hal_amdgpu_pending_op_release_alloca_memory_wait(
+    iree_hal_amdgpu_pending_op_t* op) {
+  if (op->type != IREE_HAL_AMDGPU_PENDING_OP_ALLOCA) return;
+  iree_hal_amdgpu_alloca_memory_wait_t* wait = op->alloca_op.memory_wait;
+  if (!wait) return;
+
+  switch (wait->kind) {
+    case IREE_HAL_AMDGPU_ALLOCA_MEMORY_WAIT_FRONTIER:
+      iree_hal_pool_release_reservation(op->alloca_op.pool,
+                                        &wait->frontier.reservation,
+                                        wait->frontier.wait_frontier);
+      wait->kind = IREE_HAL_AMDGPU_ALLOCA_MEMORY_WAIT_NONE;
+      break;
+    case IREE_HAL_AMDGPU_ALLOCA_MEMORY_WAIT_POOL_GROWTH:
+      iree_hal_amdgpu_host_queue_release_alloca_materialization(
+          op->alloca_op.pool, &wait->pool_growth.materialization);
+      wait->kind = IREE_HAL_AMDGPU_ALLOCA_MEMORY_WAIT_NONE;
+      break;
+    case IREE_HAL_AMDGPU_ALLOCA_MEMORY_WAIT_POOL_NOTIFICATION:
+      wait->kind = IREE_HAL_AMDGPU_ALLOCA_MEMORY_WAIT_NONE;
+      break;
+    case IREE_HAL_AMDGPU_ALLOCA_MEMORY_WAIT_NONE:
+      break;
+  }
+}
+
+// Clears the queued marker for a deferred dealloca that never published a
+// completion epoch. Successful deallocas transfer ownership to the reclaim ring
+// and must not call this.
+static void iree_hal_amdgpu_pending_op_abort_unsubmitted_dealloca(
+    iree_hal_amdgpu_pending_op_t* op) {
+  if (op->type != IREE_HAL_AMDGPU_PENDING_OP_DEALLOCA) return;
+  iree_hal_amdgpu_transient_buffer_abort_dealloca(op->dealloca.buffer);
+}
+
+static bool iree_hal_amdgpu_alloca_memory_wait_callback_is_complete(
+    void* user_data) {
+  iree_hal_amdgpu_alloca_memory_wait_t* wait =
+      (iree_hal_amdgpu_alloca_memory_wait_t*)user_data;
+  return iree_atomic_load(&wait->callback_complete,
+                          iree_memory_order_acquire) != 0;
+}
+
+static void iree_hal_amdgpu_alloca_memory_wait_publish_callback_complete(
+    iree_hal_amdgpu_pending_op_t* op) {
+  iree_hal_amdgpu_alloca_memory_wait_t* wait = op->alloca_op.memory_wait;
+  iree_atomic_store(&wait->callback_complete, 1, iree_memory_order_release);
+  iree_notification_post(&op->callback_notification, IREE_ALL_WAITERS);
+}
+
+// Publishes a prepared memory-readiness wait as ARMING. The release store on
+// lifecycle_state makes the initialized sidecar fields visible to the callback
+// or cancellation path that observes the state transition.
+static void iree_hal_amdgpu_pending_op_begin_alloca_memory_wait_arming(
+    iree_hal_amdgpu_pending_op_t* op,
+    iree_hal_amdgpu_alloca_memory_wait_t* wait,
+    iree_hal_amdgpu_alloca_memory_wait_kind_t kind) {
+  wait->kind = kind;
+  iree_atomic_store(&wait->callback_complete, 1, iree_memory_order_relaxed);
+  iree_atomic_store(&op->lifecycle_state,
+                    IREE_HAL_AMDGPU_PENDING_OP_LIFECYCLE_ARMING_MEMORY_WAIT,
+                    iree_memory_order_release);
+}
+
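+// Lazily arena-allocates the op's alloca memory-wait sidecar on first use.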
+static iree_status_t iree_hal_amdgpu_pending_op_ensure_alloca_memory_wait(
+    iree_hal_amdgpu_pending_op_t* op,
+    iree_hal_amdgpu_alloca_memory_wait_t** out_wait) {
+  iree_hal_amdgpu_alloca_memory_wait_t* wait = op->alloca_op.memory_wait;
+  if (!wait) {
+    IREE_TRACE_ZONE_BEGIN(z0);
+    IREE_RETURN_AND_END_ZONE_IF_ERROR(
+        z0, iree_arena_allocate(&op->arena, sizeof(*wait), (void**)&wait));
+    memset(wait, 0, sizeof(*wait));
+    iree_atomic_store(&wait->callback_complete, 1, iree_memory_order_relaxed);
+    op->alloca_op.memory_wait = wait;
+    IREE_TRACE_ZONE_END(z0);
+  }
+  *out_wait = wait;
+  return iree_ok_status();
+}
+
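+// Copies the pool's death frontier into the op's arena, captures the held
+// reservation, and publishes the memory wait as ARMING.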
+static iree_status_t iree_hal_amdgpu_pending_op_prepare_alloca_frontier_wait(
+    iree_hal_amdgpu_pending_op_t* op,
+    const iree_hal_amdgpu_alloca_reservation_t* alloca_reservation) {
+  const iree_async_frontier_t* wait_frontier =
+      alloca_reservation->acquire_info.wait_frontier;
+  if (IREE_UNLIKELY(!wait_frontier)) {
+    return iree_make_status(
+        IREE_STATUS_INTERNAL,
+        "queue_alloca waitable pool reservation did not provide a frontier");
+  }
+
+  iree_host_size_t wait_frontier_size = 0;
+  IREE_RETURN_IF_ERROR(iree_async_frontier_size(wait_frontier->entry_count,
+                                                &wait_frontier_size));
+  iree_hal_amdgpu_alloca_memory_wait_t* wait = NULL;
+  IREE_RETURN_IF_ERROR(
+      iree_hal_amdgpu_pending_op_ensure_alloca_memory_wait(op, &wait));
+  iree_async_frontier_t* wait_frontier_copy = NULL;
+  IREE_TRACE_ZONE_BEGIN(z0);
+  IREE_TRACE_ZONE_APPEND_VALUE_I64(z0, wait_frontier->entry_count);
+  IREE_RETURN_AND_END_ZONE_IF_ERROR(
+      z0, iree_arena_allocate(&op->arena, wait_frontier_size,
+                              (void**)&wait_frontier_copy));
+
+  memcpy(wait_frontier_copy, wait_frontier, wait_frontier_size);
+  wait->frontier.reservation = alloca_reservation->reservation;
+  wait->frontier.wait_frontier = wait_frontier_copy;
+  iree_hal_amdgpu_pending_op_begin_alloca_memory_wait_arming(
+      op, wait, IREE_HAL_AMDGPU_ALLOCA_MEMORY_WAIT_FRONTIER);
+  IREE_TRACE_ZONE_END(z0);
+  return iree_ok_status();
+}
+
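+// Snapshots the requester frontier for cold pool growth, resets the staged
+// materialization, and publishes the memory wait as ARMING.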
+static iree_status_t iree_hal_amdgpu_pending_op_prepare_alloca_pool_growth(
+    iree_hal_amdgpu_pending_op_t* op,
+    const iree_async_frontier_t* requester_frontier) {
+  iree_hal_amdgpu_alloca_memory_wait_t* wait = NULL;
+  IREE_RETURN_IF_ERROR(
+      iree_hal_amdgpu_pending_op_ensure_alloca_memory_wait(op, &wait));
+  iree_async_frontier_t* growth_frontier =
+      iree_hal_amdgpu_fixed_frontier_as_frontier(
+          &wait->pool_growth.requester_frontier);
+  iree_async_frontier_initialize(growth_frontier,
+                                 requester_frontier->entry_count);
+  memcpy(
+      growth_frontier->entries, requester_frontier->entries,
+      requester_frontier->entry_count * sizeof(requester_frontier->entries[0]));
+  memset(&wait->pool_growth.materialization, 0,
+         sizeof(wait->pool_growth.materialization));
+  wait->pool_growth.materialization.reservation.acquire_result =
+      IREE_HAL_POOL_ACQUIRE_EXHAUSTED;
+  iree_hal_amdgpu_pending_op_begin_alloca_memory_wait_arming(
+      op, wait, IREE_HAL_AMDGPU_ALLOCA_MEMORY_WAIT_POOL_GROWTH);
+  return iree_ok_status();
+}
+
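+// Records the borrowed pool notification and observed token, rotates to the
+// next wait-op slot, and publishes the memory wait as ARMING.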
+static iree_status_t
+iree_hal_amdgpu_pending_op_prepare_alloca_pool_notification_wait(
+    iree_hal_amdgpu_pending_op_t* op, iree_async_notification_t* notification,
+    uint32_t wait_token) {
+  iree_hal_amdgpu_alloca_memory_wait_t* wait = NULL;
+  IREE_RETURN_IF_ERROR(
+      iree_hal_amdgpu_pending_op_ensure_alloca_memory_wait(op, &wait));
+  wait->pool_notification.notification = notification;
+  wait->pool_notification.wait_token = wait_token;
+  wait->pool_notification.pre_submit_observation_held = true;
+  wait->pool_notification.wait_slot =
+      (uint8_t)((wait->pool_notification.wait_slot + 1u) & 1u);
+  iree_hal_amdgpu_pending_op_begin_alloca_memory_wait_arming(
+      op, wait, IREE_HAL_AMDGPU_ALLOCA_MEMORY_WAIT_POOL_NOTIFICATION);
+  return iree_ok_status();
+}
+
+static void iree_hal_amdgpu_alloca_pool_notification_end_observe(
+    iree_hal_amdgpu_alloca_memory_wait_t* wait) {
+  if (wait->pool_notification.pre_submit_observation_held) {
+    wait->pool_notification.pre_submit_observation_held = false;
+    iree_async_notification_end_observe(wait->pool_notification.notification);
+  }
+}
+
+// Cancels any active alloca memory-readiness wait before destroying the op.
+static void iree_hal_amdgpu_pending_op_cancel_alloca_memory_wait(
+    iree_hal_amdgpu_pending_op_t* op) {
+  if (op->type != IREE_HAL_AMDGPU_PENDING_OP_ALLOCA) return;
+  iree_hal_amdgpu_alloca_memory_wait_t* wait = op->alloca_op.memory_wait;
+  if (!wait) return;
+
+  switch (wait->kind) {
+    case IREE_HAL_AMDGPU_ALLOCA_MEMORY_WAIT_FRONTIER: {
+      const bool cancelled = iree_async_frontier_tracker_cancel_wait(
+          op->queue->frontier_tracker, &wait->frontier.waiter);
+      if (!cancelled) {
+        iree_notification_await(
+            &op->callback_notification,
+            iree_hal_amdgpu_alloca_memory_wait_callback_is_complete, wait,
+            iree_infinite_timeout());
+      }
+      break;
+    }
+    case IREE_HAL_AMDGPU_ALLOCA_MEMORY_WAIT_POOL_NOTIFICATION: {
+      iree_hal_amdgpu_alloca_pool_notification_end_observe(wait);
+      // Shutdown is allowed to prod the pool notification: it is a broad wake,
+      // but prevents teardown from depending on a future dealloca. The callback
+      // observes the CANCELLING lifecycle state and only publishes completion.
+      iree_async_notification_signal(wait->pool_notification.notification,
+                                     INT32_MAX);
+      iree_notification_await(
+          &op->callback_notification,
+          iree_hal_amdgpu_alloca_memory_wait_callback_is_complete, wait,
+          iree_infinite_timeout());
+      break;
+    }
+    case IREE_HAL_AMDGPU_ALLOCA_MEMORY_WAIT_POOL_GROWTH:
+      iree_hal_amdgpu_host_queue_release_alloca_materialization(
+          op->alloca_op.pool, &wait->pool_growth.materialization);
+      wait->kind = IREE_HAL_AMDGPU_ALLOCA_MEMORY_WAIT_NONE;
+      break;
+    case IREE_HAL_AMDGPU_ALLOCA_MEMORY_WAIT_NONE:
+      break;
+  }
+}
+
+static bool iree_hal_amdgpu_wait_entry_callback_is_complete(void* user_data) {
+  iree_hal_amdgpu_wait_entry_t* entry =
+      (iree_hal_amdgpu_wait_entry_t*)user_data;
+  return iree_atomic_load(&entry->callback_complete,
+                          iree_memory_order_acquire) != 0;
+}
+
+static void iree_hal_amdgpu_wait_entry_publish_callback_complete(
+    iree_hal_amdgpu_wait_entry_t* entry) {
+  iree_atomic_store(&entry->callback_complete, 1, iree_memory_order_release);
+  iree_notification_post(&entry->operation->callback_notification,
+                         IREE_ALL_WAITERS);
+}
+
+static bool iree_hal_amdgpu_pending_op_wait_callbacks_are_complete(
+    void* user_data) {
+  iree_hal_amdgpu_pending_op_t* op = (iree_hal_amdgpu_pending_op_t*)user_data;
+  for (iree_host_size_t i = 0; i < op->wait_semaphore_list.count; ++i) {
+    iree_hal_amdgpu_wait_entry_t* entry = &op->wait_entries[i];
+    if (!iree_hal_amdgpu_wait_entry_callback_is_complete(entry)) {
+      return false;
+    }
+  }
+  return true;
+}
+
+// Records the first asynchronous wait failure. Takes ownership of |status|,
+// storing it for the completion owner or dropping it if another failure won.
+static void iree_hal_amdgpu_pending_op_record_error_status(
+    iree_hal_amdgpu_pending_op_t* op, iree_status_t status) {
+  if (iree_status_is_ok(status)) return;
+  intptr_t expected = 0;
+  if (!iree_atomic_compare_exchange_strong(
+          &op->error_status, &expected, (intptr_t)status,
+          iree_memory_order_acq_rel, iree_memory_order_relaxed)) {
+    iree_status_free(status);
+  }
+}
+
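+// Attempts the PENDING -> COMPLETING lifecycle transition. Returns true if
+// this caller won the race and therefore owns completion of the op.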
+static bool iree_hal_amdgpu_pending_op_mark_waits_resolved(
+    iree_hal_amdgpu_pending_op_t* op) {
+  int32_t expected_state = IREE_HAL_AMDGPU_PENDING_OP_LIFECYCLE_PENDING;
+  return iree_atomic_compare_exchange_strong(
+      &op->lifecycle_state, &expected_state,
+      IREE_HAL_AMDGPU_PENDING_OP_LIFECYCLE_COMPLETING,
+      iree_memory_order_acq_rel, iree_memory_order_acquire);
+}
+
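+// Completion-owner path: waits until every timepoint callback has finished
+// touching arena-owned state, then fails the op with the first recorded error
+// or issues it.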
+static void iree_hal_amdgpu_pending_op_complete_resolved_waits(
+    iree_hal_amdgpu_pending_op_t* op) {
+  iree_notification_await(
+      &op->callback_notification,
+      iree_hal_amdgpu_pending_op_wait_callbacks_are_complete, op,
+      iree_infinite_timeout());
+  iree_status_t error = (iree_status_t)iree_atomic_exchange(
+      &op->error_status, 0, iree_memory_order_acquire);
+  if (!iree_status_is_ok(error)) {
+    iree_hal_amdgpu_pending_op_fail(op, error);
+  } else {
+    iree_hal_amdgpu_pending_op_issue(op);
+  }
+}
+
+// Destroys a pending operation that failed during capture (arena allocation
+// error after pending_op_allocate but before enqueue_waits). Caller MUST hold
+// submission_mutex; the op is linked to the pending list by allocate and
+// needs the mutex for unlinking.
+//
+// Unlike pending_op_fail (which acquires submission_mutex internally), this
+// function assumes the caller already holds it. This is necessary because the
+// capture phase runs under the mutex (Phase 1 of the two-phase protocol).
+void iree_hal_amdgpu_pending_op_destroy_under_lock(
+    iree_hal_amdgpu_pending_op_t* op, iree_status_t status) {
+  iree_hal_amdgpu_pending_op_fail_host_action(op, status);
+  // Fail signal semaphores so downstream waiters get the error.
+  iree_hal_semaphore_list_fail(op->signal_semaphore_list, status);
+  iree_hal_amdgpu_pending_op_abort_unsubmitted_dealloca(op);
+  // Release any queue-owned memory reservation before releasing op resources.
+  iree_hal_amdgpu_pending_op_release_alloca_memory_wait(op);
+  iree_hal_amdgpu_pending_op_release_execute_binding_resource_set(op);
+  // Release all retained resources (signal semaphores + op resources).
+  iree_hal_amdgpu_pending_op_release_retained(op);
+  // Release wait semaphores (separately retained by the clone).
+  iree_hal_semaphore_list_release(op->wait_semaphore_list);
+  // Unlink from the pending list (caller holds submission_mutex).
+  iree_hal_amdgpu_pending_op_unlink(op);
+  // Tear down callback wake state before returning arena blocks to the pool.
+  iree_notification_deinitialize(&op->callback_notification);
+  // Return arena blocks to the pool.
+  iree_arena_deinitialize(&op->arena);
+}
+
+// Timepoint callback fired when a wait semaphore reaches its target value or
+// fails. The last resolved wait claims completion, waits for all callbacks to
+// finish touching arena-owned entries, and then issues or fails the operation.
+static void iree_hal_amdgpu_wait_entry_resolved(
+    void* user_data, iree_async_semaphore_timepoint_t* timepoint,
+    iree_status_t status) {
+  iree_hal_amdgpu_wait_entry_t* entry =
+      (iree_hal_amdgpu_wait_entry_t*)user_data;
+  iree_hal_amdgpu_pending_op_t* op = entry->operation;
+
+  iree_hal_amdgpu_pending_op_record_error_status(op, status);
+
+  int32_t previous_count =
+      iree_atomic_fetch_sub(&op->wait_count, 1, iree_memory_order_acq_rel);
+  bool owns_completion = false;
+  if (previous_count == 1) {
+    owns_completion = iree_hal_amdgpu_pending_op_mark_waits_resolved(op);
+  }
+
+  iree_hal_amdgpu_wait_entry_publish_callback_complete(entry);
+  if (owns_completion) {
+    iree_hal_amdgpu_pending_op_complete_resolved_waits(op);
+  }
+}
+
+// Registers timepoints for all waits in the operation's wait semaphore list.
+// Sets wait_count and registers one timepoint per wait; callbacks may fire
+// synchronously during registration. When all waits are satisfied (the last
+// callback fires), the operation is issued or failed.
+//
+// The wait_semaphore_list on the op (cloned into the arena by allocate)
+// retains all semaphores for the lifetime of the op. Wait entries do not
+// independently retain semaphores.
+static iree_status_t iree_hal_amdgpu_pending_op_enqueue_waits(
+    iree_hal_amdgpu_pending_op_t* op) {
+  iree_hal_semaphore_list_t wait_semaphores = op->wait_semaphore_list;
+  if (wait_semaphores.count == 0) {
+    iree_hal_amdgpu_pending_op_issue(op);
+    return iree_ok_status();
+  }
+  IREE_TRACE_ZONE_BEGIN(z0);
+  IREE_TRACE_ZONE_APPEND_VALUE_I64(z0, wait_semaphores.count);
+
+  iree_host_size_t wait_entry_bytes = 0;
+  iree_status_t status =
+      IREE_STRUCT_LAYOUT(0, &wait_entry_bytes,
+                         IREE_STRUCT_FIELD(wait_semaphores.count,
+                                           iree_hal_amdgpu_wait_entry_t, NULL));
+  if (!iree_status_is_ok(status)) {
+    iree_hal_amdgpu_pending_op_fail(op, status);
+    IREE_TRACE_ZONE_END(z0);
+    return iree_ok_status();
+  }
+  status = iree_arena_allocate(&op->arena, wait_entry_bytes,
+                               (void**)&op->wait_entries);
+  if (!iree_status_is_ok(status)) {
+    iree_hal_amdgpu_pending_op_fail(op, status);
+    IREE_TRACE_ZONE_END(z0);
+    return iree_ok_status();
+  }
+  memset(op->wait_entries, 0, wait_entry_bytes);
+  // Entries that are never registered receive no callbacks, so they start in
+  // the complete state. Each registration below marks its entry incomplete
+  // until its callback exits.
+  for (iree_host_size_t i = 0; i < wait_semaphores.count; ++i) {
+    iree_atomic_store(&op->wait_entries[i].callback_complete, 1,
+                      iree_memory_order_relaxed);
+  }
+
+  // Set wait_count before registering any timepoints. A timepoint callback
+  // may fire synchronously during acquire_timepoint.
+  iree_atomic_store(&op->wait_count, (int32_t)wait_semaphores.count,
+                    iree_memory_order_release);
+
+  for (iree_host_size_t i = 0; i < wait_semaphores.count; ++i) {
+    iree_hal_amdgpu_wait_entry_t* entry = &op->wait_entries[i];
+    entry->operation = op;
+    iree_atomic_store(&entry->callback_complete, 0, iree_memory_order_relaxed);
+    entry->timepoint.callback = iree_hal_amdgpu_wait_entry_resolved;
+    entry->timepoint.user_data = entry;
+    status = iree_async_semaphore_acquire_timepoint(
+        (iree_async_semaphore_t*)wait_semaphores.semaphores[i],
+        wait_semaphores.payload_values[i], &entry->timepoint);
+
+    if (!iree_status_is_ok(status)) {
+      // Registration failed at index i. Timepoints 0..i-1 are already
+      // registered and their callbacks will fire asynchronously; we cannot
+      // destroy the op here. Record the error and subtract the unregistered
+      // count so the existing callbacks drain and destroy the op.
+      iree_hal_amdgpu_pending_op_record_error_status(op, status);
+      int32_t unregistered = (int32_t)(wait_semaphores.count - i);
+      iree_atomic_store(&entry->callback_complete, 1,
+                        iree_memory_order_release);
+      int32_t previous_count = iree_atomic_fetch_sub(
+          &op->wait_count, unregistered, iree_memory_order_acq_rel);
+      if (previous_count == unregistered) {
+        if (iree_hal_amdgpu_pending_op_mark_waits_resolved(op)) {
+          iree_hal_amdgpu_pending_op_complete_resolved_waits(op);
+        }
+      }
+      IREE_TRACE_ZONE_END(z0);
+      return iree_ok_status();
+    }
+  }
+
+  IREE_TRACE_ZONE_END(z0);
+  return iree_ok_status();
+}
+
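+// Resolves a cold alloca memory-readiness wait (any thread, possibly
+// synchronously during arming). An ARMING op defers issue/fail to the arming
+// thread; a PENDING op is issued or failed here; an op already claimed
+// elsewhere only records the error. Always publishes callback completion so
+// cancellation can proceed.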
+static void iree_hal_amdgpu_alloca_memory_wait_resolved(
+    iree_hal_amdgpu_pending_op_t* op, iree_status_t status) {
+  iree_hal_amdgpu_alloca_memory_wait_t* wait = op->alloca_op.memory_wait;
+  if (iree_status_is_ok(status) &&
+      wait->kind == IREE_HAL_AMDGPU_ALLOCA_MEMORY_WAIT_POOL_NOTIFICATION) {
+    wait->kind = IREE_HAL_AMDGPU_ALLOCA_MEMORY_WAIT_NONE;
+  }
+
+  int32_t expected_state =
+      IREE_HAL_AMDGPU_PENDING_OP_LIFECYCLE_ARMING_MEMORY_WAIT;
+  if (iree_atomic_compare_exchange_strong(
+          &op->lifecycle_state, &expected_state,
+          IREE_HAL_AMDGPU_PENDING_OP_LIFECYCLE_COMPLETING,
+          iree_memory_order_acq_rel, iree_memory_order_acquire)) {
+    iree_hal_amdgpu_pending_op_record_error_status(op, status);
+    iree_hal_amdgpu_alloca_memory_wait_publish_callback_complete(op);
+    return;
+  }
+
+  expected_state = IREE_HAL_AMDGPU_PENDING_OP_LIFECYCLE_PENDING;
+  if (iree_atomic_compare_exchange_strong(
+          &op->lifecycle_state, &expected_state,
+          IREE_HAL_AMDGPU_PENDING_OP_LIFECYCLE_COMPLETING,
+          iree_memory_order_acq_rel, iree_memory_order_acquire)) {
+    iree_hal_amdgpu_alloca_memory_wait_publish_callback_complete(op);
+    if (!iree_status_is_ok(status)) {
+      iree_hal_amdgpu_pending_op_fail(op, status);
+    } else {
+      iree_hal_amdgpu_pending_op_issue(op);
+    }
+    return;
+  }
+
+  iree_hal_amdgpu_pending_op_record_error_status(op, status);
+  iree_hal_amdgpu_alloca_memory_wait_publish_callback_complete(op);
+}
+
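+// Frontier-tracker wait callback: forwards to the shared memory-wait
+// resolution.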
+static void iree_hal_amdgpu_alloca_frontier_wait_resolved(
+    void* user_data, iree_status_t status) {
+  iree_hal_amdgpu_alloca_memory_wait_resolved(
+      (iree_hal_amdgpu_pending_op_t*)user_data, status);
+}
+
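+// Proactor notification-wait completion callback: forwards to the shared
+// memory-wait resolution.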
+static void iree_hal_amdgpu_alloca_pool_notification_wait_resolved(
+    void* user_data, iree_async_operation_t* operation, iree_status_t status,
+    iree_async_completion_flags_t flags) {
+  (void)operation;
+  (void)flags;
+  iree_hal_amdgpu_alloca_memory_wait_resolved(
+      (iree_hal_amdgpu_pending_op_t*)user_data, status);
+}
+
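+// Completes the arming handshake after a memory-wait registration attempt.
+// On success publishes PENDING so cancellation can claim the op again; if a
+// synchronous callback already moved the op to COMPLETING, consumes the
+// recorded error and issues or fails here. On registration failure claims
+// completion and fails the op, or records the error if a callback won the
+// race.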
+static void iree_hal_amdgpu_pending_op_finish_alloca_memory_wait_enqueue(
+    iree_hal_amdgpu_pending_op_t* op, iree_status_t status) {
+  int32_t expected_state =
+      IREE_HAL_AMDGPU_PENDING_OP_LIFECYCLE_ARMING_MEMORY_WAIT;
+  if (iree_status_is_ok(status)) {
+    if (iree_atomic_compare_exchange_strong(
+            &op->lifecycle_state, &expected_state,
+            IREE_HAL_AMDGPU_PENDING_OP_LIFECYCLE_PENDING,
+            iree_memory_order_acq_rel, iree_memory_order_acquire)) {
+      return;
+    }
+    if (expected_state == IREE_HAL_AMDGPU_PENDING_OP_LIFECYCLE_COMPLETING) {
+      iree_status_t error = (iree_status_t)iree_atomic_exchange(
+          &op->error_status, 0, iree_memory_order_acquire);
+      if (!iree_status_is_ok(error)) {
+        iree_hal_amdgpu_pending_op_fail(op, error);
+      } else {
+        iree_hal_amdgpu_pending_op_issue(op);
+      }
+    }
+    return;
+  }
+
+  if (iree_atomic_compare_exchange_strong(
+          &op->lifecycle_state, &expected_state,
+          IREE_HAL_AMDGPU_PENDING_OP_LIFECYCLE_COMPLETING,
+          iree_memory_order_acq_rel, iree_memory_order_acquire)) {
+    iree_hal_amdgpu_pending_op_fail(op, status);
+    return;
+  }
+  iree_hal_amdgpu_pending_op_record_error_status(op, status);
+}
+
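+// Arms a frontier-tracker wait on the reservation's wait frontier. The
+// callback may fire synchronously; finish_alloca_memory_wait_enqueue resolves
+// that race.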
+static void iree_hal_amdgpu_pending_op_enqueue_alloca_frontier_wait(
+    iree_hal_amdgpu_pending_op_t* op) {
+  iree_hal_amdgpu_host_queue_t* queue = op->queue;
+  iree_hal_amdgpu_alloca_memory_wait_t* wait = op->alloca_op.memory_wait;
+  iree_atomic_store(&wait->callback_complete, 0, iree_memory_order_relaxed);
+  iree_status_t status = iree_async_frontier_tracker_wait(
+      queue->frontier_tracker, wait->frontier.wait_frontier,
+      iree_hal_amdgpu_alloca_frontier_wait_resolved, op,
+      &wait->frontier.waiter);
+  iree_hal_amdgpu_pending_op_finish_alloca_memory_wait_enqueue(op, status);
+}
+
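+// Arms a proactor notification wait using the wait token captured while the
+// exhausted pool was being observed; the observation ends once the wait
+// operation has been submitted.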
+static void iree_hal_amdgpu_pending_op_enqueue_alloca_pool_notification_wait(
+    iree_hal_amdgpu_pending_op_t* op) {
+  iree_hal_amdgpu_alloca_memory_wait_t* wait = op->alloca_op.memory_wait;
+  iree_async_notification_wait_operation_t* wait_op =
+      &wait->pool_notification.wait_ops[wait->pool_notification.wait_slot];
+  iree_async_operation_zero(&wait_op->base, sizeof(*wait_op));
+  iree_async_operation_initialize(
+      &wait_op->base, IREE_ASYNC_OPERATION_TYPE_NOTIFICATION_WAIT,
+      IREE_ASYNC_OPERATION_FLAG_NONE,
+      iree_hal_amdgpu_alloca_pool_notification_wait_resolved, op);
+  wait_op->notification = wait->pool_notification.notification;
+  wait_op->wait_flags = IREE_ASYNC_NOTIFICATION_WAIT_FLAG_USE_WAIT_TOKEN;
+  wait_op->wait_token = wait->pool_notification.wait_token;
+
+  iree_atomic_store(&wait->callback_complete, 0, iree_memory_order_relaxed);
+  iree_status_t status =
+      iree_async_proactor_submit_one(op->queue->proactor, &wait_op->base);
+  iree_hal_amdgpu_alloca_pool_notification_end_observe(wait);
+  iree_hal_amdgpu_pending_op_finish_alloca_memory_wait_enqueue(op, status);
+}
+
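+// Retries the pool acquisition with growth allowed. A ready reservation is
+// materialized into the growth sidecar for later submission; a reservation
+// that still needs a frontier wait is released (with its wait frontier) so
+// the next issue attempt re-acquires. Exhaustion after growth is a hard
+// RESOURCE_EXHAUSTED error.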
+static iree_status_t iree_hal_amdgpu_pending_op_grow_alloca_pool(
+    iree_hal_amdgpu_pending_op_t* op) {
+  iree_hal_amdgpu_alloca_memory_wait_t* wait = op->alloca_op.memory_wait;
+  const iree_async_frontier_t* requester_frontier =
+      iree_hal_amdgpu_fixed_frontier_as_frontier(
+          &wait->pool_growth.requester_frontier);
+  const iree_hal_pool_reserve_flags_t reserve_flags =
+      op->alloca_op.reserve_flags & ~IREE_HAL_POOL_RESERVE_FLAG_DISALLOW_GROWTH;
+
+  iree_hal_pool_reservation_t reservation;
+  iree_hal_pool_acquire_info_t acquire_info;
+  iree_hal_pool_acquire_result_t acquire_result =
+      IREE_HAL_POOL_ACQUIRE_EXHAUSTED;
+  IREE_RETURN_IF_ERROR(iree_hal_pool_acquire_reservation(
+      op->alloca_op.pool, op->alloca_op.allocation_size,
+      op->alloca_op.params.min_alignment ? op->alloca_op.params.min_alignment
+                                         : 1,
+      requester_frontier, reserve_flags, &reservation, &acquire_info,
+      &acquire_result));
+
+  switch (acquire_result) {
+    case IREE_HAL_POOL_ACQUIRE_OK:
+    case IREE_HAL_POOL_ACQUIRE_OK_FRESH: {
+      iree_hal_amdgpu_alloca_reservation_t alloca_reservation = {
+          .readiness = IREE_HAL_AMDGPU_ALLOCA_RESERVATION_READY,
+          .acquire_result = acquire_result,
+          .reservation = reservation,
+          .acquire_info = acquire_info,
+      };
+      iree_status_t status =
+          iree_hal_amdgpu_host_queue_materialize_alloca_reservation(
+              op->queue, &alloca_reservation, op->alloca_op.pool,
+              op->alloca_op.params, op->alloca_op.buffer,
+              &wait->pool_growth.materialization);
+      if (!iree_status_is_ok(status)) {
+        iree_hal_pool_release_reservation(op->alloca_op.pool, &reservation,
+                                          /*death_frontier=*/NULL);
+      }
+      return status;
+    }
+    case IREE_HAL_POOL_ACQUIRE_OK_NEEDS_WAIT:
+      iree_hal_pool_release_reservation(op->alloca_op.pool, &reservation,
+                                        acquire_info.wait_frontier);
+      wait->kind = IREE_HAL_AMDGPU_ALLOCA_MEMORY_WAIT_NONE;
+      return iree_ok_status();
+    case IREE_HAL_POOL_ACQUIRE_EXHAUSTED:
+    case IREE_HAL_POOL_ACQUIRE_OVER_BUDGET:
+      return iree_make_status(
+          IREE_STATUS_RESOURCE_EXHAUSTED,
+          "queue_alloca cold pool growth did not produce a reservation "
+          "(result=%u)",
+          acquire_result);
+    default:
+      return iree_make_status(IREE_STATUS_INTERNAL,
+                              "unrecognized pool acquire result %u",
+                              acquire_result);
+  }
+}
+
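+// Performs pool growth synchronously on the enqueuing thread and feeds the
+// result through the shared memory-wait resolution and arming handshake.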
+static void iree_hal_amdgpu_pending_op_enqueue_alloca_pool_growth(
+    iree_hal_amdgpu_pending_op_t* op) {
+  iree_hal_amdgpu_alloca_memory_wait_t* wait = op->alloca_op.memory_wait;
+  iree_atomic_store(&wait->callback_complete, 0, iree_memory_order_relaxed);
+  iree_status_t status = iree_hal_amdgpu_pending_op_grow_alloca_pool(op);
+  iree_hal_amdgpu_alloca_memory_wait_resolved(op, status);
+  iree_hal_amdgpu_pending_op_finish_alloca_memory_wait_enqueue(
+      op, iree_ok_status());
+}
+
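+// Dispatches the prepared memory-readiness wait by kind. NONE means the
+// sidecar was never prepared and is an internal error.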
+void iree_hal_amdgpu_pending_op_enqueue_alloca_memory_wait(
+    iree_hal_amdgpu_pending_op_t* op) {
+  iree_hal_amdgpu_alloca_memory_wait_t* wait = op->alloca_op.memory_wait;
+  switch (wait->kind) {
+    case IREE_HAL_AMDGPU_ALLOCA_MEMORY_WAIT_FRONTIER:
+      iree_hal_amdgpu_pending_op_enqueue_alloca_frontier_wait(op);
+      break;
+    case IREE_HAL_AMDGPU_ALLOCA_MEMORY_WAIT_POOL_GROWTH:
+      iree_hal_amdgpu_pending_op_enqueue_alloca_pool_growth(op);
+      break;
+    case IREE_HAL_AMDGPU_ALLOCA_MEMORY_WAIT_POOL_NOTIFICATION:
+      iree_hal_amdgpu_pending_op_enqueue_alloca_pool_notification_wait(op);
+      break;
+    case IREE_HAL_AMDGPU_ALLOCA_MEMORY_WAIT_NONE:
+      iree_hal_amdgpu_pending_op_fail(
+          op, iree_make_status(IREE_STATUS_INTERNAL,
+                               "pending alloca has no memory wait to enqueue"));
+      break;
+  }
+}
+
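+// Post-drain action body: re-issues the parked operation once the queue has
+// drained.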
+static void iree_hal_amdgpu_pending_op_capacity_post_drain(void* user_data) {
+  iree_hal_amdgpu_pending_op_issue((iree_hal_amdgpu_pending_op_t*)user_data);
+}
+
+// Parks the operation until the queue drains and submission capacity frees
+// up. COMPLETING prevents cancellation from claiming an op that already owns
+// its own completion; the post-drain action re-enters issue, which checks
+// for shutdown itself.
+static void iree_hal_amdgpu_pending_op_enqueue_capacity_retry(
+    iree_hal_amdgpu_pending_op_t* op) {
+  iree_atomic_store(&op->lifecycle_state,
+                    IREE_HAL_AMDGPU_PENDING_OP_LIFECYCLE_COMPLETING,
+                    iree_memory_order_release);
+  iree_hal_amdgpu_host_queue_enqueue_post_drain_action(
+      op->queue, &op->capacity_retry,
+      iree_hal_amdgpu_pending_op_capacity_post_drain, op);
+}
+
+iree_status_t iree_hal_amdgpu_pending_op_start(iree_hal_amdgpu_pending_op_t* op,
+                                               bool wait_for_capacity) {
+  if (wait_for_capacity) {
+    iree_hal_amdgpu_pending_op_enqueue_capacity_retry(op);
+    return iree_ok_status();
+  }
+  return iree_hal_amdgpu_pending_op_enqueue_waits(op);
+}
+
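+// Returns a clone of the queue's sticky error status (or OK); the queue
+// retains ownership of the original.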
+static iree_status_t iree_hal_amdgpu_host_queue_clone_error_status(
+    iree_hal_amdgpu_host_queue_t* queue) {
+  iree_status_t error = (iree_status_t)iree_atomic_load(
+      &queue->error_status, iree_memory_order_acquire);
+  return iree_status_is_ok(error) ? iree_ok_status() : iree_status_clone(error);
+}
+
+// Allocates and initializes a pending operation from a fresh arena.
+// Clones the wait semaphore list. Allocates the retained_resources array
+// with |max_resource_count| capacity and populates the first entries with
+// signal semaphores (retained). The signal_semaphore_list.semaphores pointer
+// aliases into retained_resources so that commit_signals and
+// semaphore_list_fail can use it directly.
+//
+// The caller must push operation-specific resources into retained_resources
+// (via the returned op) before calling enqueue_waits.
+//
+// On failure, the arena is cleaned up and *out_op is set to NULL.
+iree_status_t iree_hal_amdgpu_pending_op_allocate(
+    iree_hal_amdgpu_host_queue_t* queue,
+    const iree_hal_semaphore_list_t* wait_semaphore_list,
+    const iree_hal_semaphore_list_t* signal_semaphore_list,
+    iree_hal_amdgpu_pending_op_type_t type, uint16_t max_resource_count,
+    iree_hal_amdgpu_pending_op_t** out_op) {
+  IREE_ASSERT_ARGUMENT(out_op);
+  *out_op = NULL;
+  IREE_TRACE_ZONE_BEGIN(z0);
+  IREE_TRACE_ZONE_APPEND_VALUE_I64(z0, type);
+  IREE_TRACE_ZONE_APPEND_VALUE_I64(z0, max_resource_count);
+
+  iree_arena_allocator_t arena;
+  iree_arena_initialize(queue->block_pool, &arena);
+
+  iree_hal_amdgpu_pending_op_t* op = NULL;
+  iree_status_t status = iree_arena_allocate(&arena, sizeof(*op), (void**)&op);
+  if (!iree_status_is_ok(status)) {
+    iree_arena_deinitialize(&arena);
+    IREE_TRACE_ZONE_END(z0);
+    return status;
+  }
+
+  memset(op, 0, sizeof(*op));
+  memcpy(&op->arena, &arena, sizeof(arena));
+  op->queue = queue;
+  op->type = type;
+  iree_atomic_store(&op->wait_count, 0, iree_memory_order_relaxed);
+  iree_atomic_store(&op->lifecycle_state,
+                    IREE_HAL_AMDGPU_PENDING_OP_LIFECYCLE_PENDING,
+                    iree_memory_order_relaxed);
+  iree_atomic_store(&op->error_status, 0, iree_memory_order_relaxed);
+  iree_notification_initialize(&op->callback_notification);
+
+  iree_allocator_t arena_allocator = iree_arena_allocator(&op->arena);
+
+  // Clone the wait semaphore list (retains each wait semaphore).
+  status = iree_hal_semaphore_list_clone(wait_semaphore_list, arena_allocator,
+                                         &op->wait_semaphore_list);
+
+  // Allocate the retained resources array and the signal payload values.
+  if (iree_status_is_ok(status) && max_resource_count > 0) {
+    iree_host_size_t retained_resource_size = 0;
+    status = IREE_STRUCT_LAYOUT(
+        0, &retained_resource_size,
+        IREE_STRUCT_FIELD(max_resource_count, iree_hal_resource_t*, NULL));
+    if (iree_status_is_ok(status)) {
+      status = iree_arena_allocate(&op->arena, retained_resource_size,
+                                   (void**)&op->retained_resources);
+    }
+  }
+  uint64_t* signal_payload_values = NULL;
+  if (iree_status_is_ok(status) && signal_semaphore_list->count > 0) {
+    iree_host_size_t signal_payload_size = 0;
+    status = IREE_STRUCT_LAYOUT(
+        0, &signal_payload_size,
+        IREE_STRUCT_FIELD(signal_semaphore_list->count, uint64_t, NULL));
+    if (iree_status_is_ok(status)) {
+      status = iree_arena_allocate(&op->arena, signal_payload_size,
+                                   (void**)&signal_payload_values);
+    }
+  }
+
+  if (iree_status_is_ok(status)) {
+    // Signal semaphores occupy the first entries of retained_resources.
+    // The signal_semaphore_list.semaphores pointer aliases this region.
+    for (iree_host_size_t i = 0; i < signal_semaphore_list->count; ++i) {
+      op->retained_resources[i] =
+          (iree_hal_resource_t*)signal_semaphore_list->semaphores[i];
+      iree_hal_resource_retain(op->retained_resources[i]);
+      signal_payload_values[i] = signal_semaphore_list->payload_values[i];
+    }
+    op->retained_resource_count = (uint16_t)signal_semaphore_list->count;
+    op->signal_semaphore_list.count = signal_semaphore_list->count;
+    op->signal_semaphore_list.semaphores =
+        (iree_hal_semaphore_t**)op->retained_resources;
+    op->signal_semaphore_list.payload_values = signal_payload_values;
+
+    iree_hal_amdgpu_pending_op_link(op);
+    *out_op = op;
+  } else {
+    iree_hal_semaphore_list_release(op->wait_semaphore_list);
+    iree_notification_deinitialize(&op->callback_notification);
+    iree_arena_deinitialize(&op->arena);
+  }
+  IREE_TRACE_ZONE_END(z0);
+  return status;
+}
+
+// Issues a deferred operation after all waits are satisfied. All waits are
+// tier 0 (timeline_value >= waited_value); the GPU work producing those
+// values has completed. No barriers are needed.
+//
+// Called from the last wait_entry callback (any thread). Acquires
+// submission_mutex to emit AQL packets and commit signals.
+static void iree_hal_amdgpu_pending_op_issue(iree_hal_amdgpu_pending_op_t* op) {
+  iree_hal_amdgpu_host_queue_t* queue = op->queue;
+  iree_slim_mutex_lock(&queue->locks.submission_mutex);
+
+  iree_status_t status = iree_ok_status();
+  iree_hal_amdgpu_pending_op_payload_issue_t issue = {
+      .ready = true,
+      .memory_wait_op = NULL,
+  };
+  if (queue->is_shutting_down) {
+    status = iree_make_status(IREE_STATUS_CANCELLED, "queue shutting down");
+  } else {
+    status = iree_hal_amdgpu_host_queue_clone_error_status(queue);
+  }
+  if (iree_status_is_ok(status)) {
+    // All waits are tier 0; emit operation packets with no dependency
+    // barriers.
+    iree_hal_amdgpu_wait_resolution_t resolution;
+    resolution.barrier_count = 0;
+    resolution.needs_deferral = false;
+    memset(resolution.reserved, 0, sizeof(resolution.reserved));
+    resolution.wait_count = op->wait_semaphore_list.count > UINT32_MAX
+                                ? UINT32_MAX
+                                : (uint32_t)op->wait_semaphore_list.count;
+    resolution.profile_event_flags =
+        IREE_HAL_PROFILE_QUEUE_EVENT_FLAG_SOFTWARE_DEFERRED;
+    resolution.inline_acquire_scope = op->wait_semaphore_list.count > 0
+                                          ? IREE_HSA_FENCE_SCOPE_SYSTEM
+                                          : IREE_HSA_FENCE_SCOPE_NONE;
+    resolution.barrier_acquire_scope = IREE_HSA_FENCE_SCOPE_NONE;
+    status = iree_hal_amdgpu_pending_op_issue_payload(op, &resolution, &issue);
+    if (iree_status_is_ok(status) && issue.memory_wait_op) {
+      iree_slim_mutex_unlock(&queue->locks.submission_mutex);
+      iree_hal_amdgpu_pending_op_enqueue_alloca_memory_wait(
+          issue.memory_wait_op);
+      return;
+    }
+  }
+
+  if (iree_status_is_ok(status) && !issue.ready) {
+    iree_hal_amdgpu_pending_op_enqueue_capacity_retry(op);
+    iree_slim_mutex_unlock(&queue->locks.submission_mutex);
+    return;
+  }
+
+  if (!iree_status_is_ok(status)) {
+    iree_hal_amdgpu_pending_op_fail_host_action(op, status);
+    iree_hal_semaphore_list_fail(op->signal_semaphore_list, status);
+    iree_hal_amdgpu_pending_op_abort_unsubmitted_dealloca(op);
+    iree_hal_amdgpu_pending_op_release_alloca_memory_wait(op);
+    iree_hal_amdgpu_pending_op_release_execute_binding_resource_set(op);
+    iree_hal_amdgpu_pending_op_release_retained(op);
+  }
+
+  // Clean up the pending op. The wait semaphore list is released (the clone
+  // holds separate retains). Remaining retained_resources entries were either
+  // transferred to the reclaim path by a successful issue or released by the
+  // failure path above.
+  iree_hal_semaphore_list_release(op->wait_semaphore_list);
+  iree_hal_amdgpu_pending_op_unlink(op);
+  iree_notification_deinitialize(&op->callback_notification);
+  iree_arena_deinitialize(&op->arena);
+
+  iree_slim_mutex_unlock(&queue->locks.submission_mutex);
+}
+
+// Fails a deferred operation. Propagates the error to all signal semaphores
+// so downstream waiters receive the failure instead of hanging. Takes
+// ownership of |status|.
+static void iree_hal_amdgpu_pending_op_fail(iree_hal_amdgpu_pending_op_t* op,
+                                            iree_status_t status) {
+  iree_hal_amdgpu_host_queue_t* queue = op->queue;
+  iree_hal_amdgpu_pending_op_fail_host_action(op, status);
+  // Fail signal semaphores (records error, does not release our retains).
+  iree_hal_semaphore_list_fail(op->signal_semaphore_list, status);
+  iree_hal_amdgpu_pending_op_abort_unsubmitted_dealloca(op);
+  // Release any queue-owned memory reservation before releasing op resources.
+  iree_hal_amdgpu_pending_op_release_alloca_memory_wait(op);
+  iree_hal_amdgpu_pending_op_release_execute_binding_resource_set(op);
+  // Release all retained resources (signal semaphores + op resources).
+  iree_hal_amdgpu_pending_op_release_retained(op);
+  // Release wait semaphores (separately retained by the clone).
+  iree_hal_semaphore_list_release(op->wait_semaphore_list);
+  iree_slim_mutex_lock(&queue->locks.submission_mutex);
+  iree_hal_amdgpu_pending_op_unlink(op);
+  iree_slim_mutex_unlock(&queue->locks.submission_mutex);
+  iree_notification_deinitialize(&op->callback_notification);
+  iree_arena_deinitialize(&op->arena);
+}
+
+// Cancels all pending operations on a queue with the given failure details.
+// Creates a status only for operations that do not already carry a wait error.
+// Called during deinitialize or on unrecoverable GPU fault.
+// Caller must ensure no concurrent submissions (shutdown path).
+void iree_hal_amdgpu_host_queue_cancel_pending(
+    iree_hal_amdgpu_host_queue_t* queue, iree_status_code_t status_code,
+    const char* status_message) {
+  iree_slim_mutex_lock(&queue->locks.submission_mutex);
+  queue->is_shutting_down = true;
+  iree_slim_mutex_unlock(&queue->locks.submission_mutex);
+
+  for (;;) {
+    iree_hal_amdgpu_pending_op_t* op = NULL;
+    iree_slim_mutex_lock(&queue->locks.submission_mutex);
+    for (iree_hal_amdgpu_pending_op_t* candidate = queue->pending_head;
+         candidate != NULL; candidate = candidate->next) {
+      int32_t expected_state = IREE_HAL_AMDGPU_PENDING_OP_LIFECYCLE_PENDING;
+      if (iree_atomic_compare_exchange_strong(
+              &candidate->lifecycle_state, &expected_state,
+              IREE_HAL_AMDGPU_PENDING_OP_LIFECYCLE_CANCELLING,
+              iree_memory_order_acq_rel, iree_memory_order_acquire)) {
+        iree_hal_amdgpu_pending_op_unlink(candidate);
+        op = candidate;
+        break;
+      }
+    }
+    bool has_pending_ops = queue->pending_head != NULL;
+    iree_slim_mutex_unlock(&queue->locks.submission_mutex);
+
+    if (op == NULL) {
+      if (!has_pending_ops) break;
+      iree_thread_yield();
+      continue;
+    }
+
+    for (iree_host_size_t i = 0; i < op->wait_semaphore_list.count; ++i) {
+      iree_hal_amdgpu_wait_entry_t* entry = &op->wait_entries[i];
+      if (iree_hal_amdgpu_wait_entry_callback_is_complete(entry)) continue;
+      if (iree_async_semaphore_cancel_timepoint(entry->timepoint.semaphore,
+                                                &entry->timepoint)) {
+        continue;
+      }
+      iree_notification_await(&op->callback_notification,
+                              iree_hal_amdgpu_wait_entry_callback_is_complete,
+                              entry, iree_infinite_timeout());
+    }
+    iree_hal_amdgpu_pending_op_cancel_alloca_memory_wait(op);
+
+    iree_status_t op_status = (iree_status_t)iree_atomic_exchange(
+        &op->error_status, 0, iree_memory_order_acquire);
+    if (iree_status_is_ok(op_status)) {
+      op_status = iree_make_status(status_code, "%s", status_message);
+    }
+    iree_hal_semaphore_list_fail(op->signal_semaphore_list, op_status);
+    iree_hal_amdgpu_pending_op_abort_unsubmitted_dealloca(op);
+    iree_hal_amdgpu_pending_op_release_alloca_memory_wait(op);
+    iree_hal_amdgpu_pending_op_release_execute_binding_resource_set(op);
+    iree_hal_amdgpu_pending_op_release_retained(op);
+    iree_hal_semaphore_list_release(op->wait_semaphore_list);
+    iree_notification_deinitialize(&op->callback_notification);
+    iree_arena_deinitialize(&op->arena);
+  }
+}
+
+//===----------------------------------------------------------------------===//
+// Alloca memory-readiness waits
+//===----------------------------------------------------------------------===//
+
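+// Submits an alloca whose frontier wait was prepared on a previous issue
+// attempt: the held reservation and wait frontier are repackaged as a
+// NEEDS_WAIT reservation and the sidecar is cleared so it is not re-armed.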
+static iree_status_t
+iree_hal_amdgpu_host_queue_submit_alloca_held_frontier_wait(
+    iree_hal_amdgpu_host_queue_t* queue,
+    const iree_hal_amdgpu_wait_resolution_t* resolution,
+    const iree_hal_semaphore_list_t signal_semaphore_list,
+    iree_hal_pool_t* allocation_pool, iree_hal_buffer_params_t params,
+    iree_hal_buffer_t* buffer,
+    iree_hal_amdgpu_host_queue_submission_flags_t submission_flags,
+    iree_hal_amdgpu_alloca_memory_wait_t* memory_wait, bool* out_ready) {
+  iree_hal_amdgpu_alloca_reservation_t alloca_reservation = {
+      .readiness = IREE_HAL_AMDGPU_ALLOCA_RESERVATION_READY,
+      .acquire_result = IREE_HAL_POOL_ACQUIRE_OK_NEEDS_WAIT,
+      .reservation = memory_wait->frontier.reservation,
+      .acquire_info =
+          {
+              .wait_frontier = memory_wait->frontier.wait_frontier,
+          },
+      .wait_resolution = *resolution,
+  };
+  memory_wait->kind = IREE_HAL_AMDGPU_ALLOCA_MEMORY_WAIT_NONE;
+  return iree_hal_amdgpu_host_queue_submit_alloca_reservation(
+      queue, &alloca_reservation, signal_semaphore_list, allocation_pool,
+      params, buffer, submission_flags, out_ready);
+}
+
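+// Submits an alloca whose pool-growth materialization completed on a previous
+// issue attempt: the materialized reservation absorbs the new wait resolution
+// and the sidecar is cleared so it is not re-armed.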
+static iree_status_t
+iree_hal_amdgpu_host_queue_submit_alloca_held_growth_materialization(
+    iree_hal_amdgpu_host_queue_t* queue,
+    const iree_hal_amdgpu_wait_resolution_t* resolution,
+    const iree_hal_semaphore_list_t signal_semaphore_list,
+    iree_hal_pool_t* allocation_pool, iree_hal_buffer_params_t params,
+    iree_hal_buffer_t* buffer,
+    iree_hal_amdgpu_host_queue_submission_flags_t submission_flags,
+    iree_hal_amdgpu_alloca_memory_wait_t* memory_wait, bool* out_ready) {
+  iree_hal_amdgpu_alloca_reservation_t alloca_reservation = {
+      .readiness = IREE_HAL_AMDGPU_ALLOCA_RESERVATION_READY,
+      .acquire_result =
+          memory_wait->pool_growth.materialization.reservation.acquire_result,
+      .reservation =
+          memory_wait->pool_growth.materialization.reservation.reservation,
+      .acquire_info =
+          memory_wait->pool_growth.materialization.reservation.acquire_info,
+      .wait_resolution = *resolution,
+  };
+  memory_wait->pool_growth.materialization.reservation = alloca_reservation;
+  memory_wait->kind = IREE_HAL_AMDGPU_ALLOCA_MEMORY_WAIT_NONE;
+  return iree_hal_amdgpu_host_queue_submit_alloca_materialization(
+      queue, &memory_wait->pool_growth.materialization, signal_semaphore_list,
+      allocation_pool, params, buffer, submission_flags, out_ready);
+}
+
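+// Returns the operation that will own a prepared memory wait. Reuses
+// |pending_op| when the alloca was already deferred; otherwise captures a
+// fresh deferred alloca with no user waits (those have already resolved).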
+static iree_status_t iree_hal_amdgpu_host_queue_get_alloca_memory_wait_op(
+    iree_hal_amdgpu_host_queue_t* queue,
+    const iree_hal_semaphore_list_t signal_semaphore_list,
+    iree_hal_pool_t* allocation_pool, iree_hal_buffer_params_t params,
+    iree_device_size_t allocation_size, iree_hal_alloca_flags_t flags,
+    iree_hal_pool_reserve_flags_t reserve_flags, iree_hal_buffer_t* buffer,
+    iree_hal_amdgpu_pending_op_t* pending_op,
+    iree_hal_amdgpu_pending_op_t** out_memory_wait_op) {
+  *out_memory_wait_op = NULL;
+  if (pending_op) {
+    *out_memory_wait_op = pending_op;
+    return iree_ok_status();
+  }
+
+  iree_hal_semaphore_list_t empty_wait_list = iree_hal_semaphore_list_empty();
+  IREE_RETURN_IF_ERROR(iree_hal_amdgpu_host_queue_defer_alloca(
+      queue, &empty_wait_list, &signal_semaphore_list, allocation_pool, params,
+      allocation_size, flags, reserve_flags, buffer, out_memory_wait_op));
+  return iree_ok_status();
+}
+
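+// Defers an alloca until its reservation's wait frontier passes. On failure
+// the reservation is released back to the pool and any freshly captured op is
+// destroyed.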
+static iree_status_t iree_hal_amdgpu_host_queue_defer_alloca_frontier_wait(
+    iree_hal_amdgpu_host_queue_t* queue,
+    const iree_hal_semaphore_list_t signal_semaphore_list,
+    iree_hal_pool_t* allocation_pool, iree_hal_buffer_params_t params,
+    iree_device_size_t allocation_size, iree_hal_alloca_flags_t flags,
+    iree_hal_pool_reserve_flags_t reserve_flags, iree_hal_buffer_t* buffer,
+    const iree_hal_amdgpu_alloca_reservation_t* alloca_reservation,
+    iree_hal_amdgpu_pending_op_t* pending_op,
+    iree_hal_amdgpu_pending_op_t** out_memory_wait_op) {
+  iree_hal_amdgpu_pending_op_t* memory_wait_op = pending_op;
+  iree_status_t status = iree_hal_amdgpu_host_queue_get_alloca_memory_wait_op(
+      queue, signal_semaphore_list, allocation_pool, params, allocation_size,
+      flags, reserve_flags, buffer, pending_op, &memory_wait_op);
+  if (iree_status_is_ok(status)) {
+    status = iree_hal_amdgpu_pending_op_prepare_alloca_frontier_wait(
+        memory_wait_op, alloca_reservation);
+  }
+  if (iree_status_is_ok(status)) {
+    *out_memory_wait_op = memory_wait_op;
+  } else {
+    iree_hal_pool_release_reservation(
+        allocation_pool, &alloca_reservation->reservation,
+        alloca_reservation->acquire_info.wait_frontier);
+    if (!pending_op && memory_wait_op) {
+      iree_hal_amdgpu_pending_op_destroy_under_lock(memory_wait_op,
+                                                    iree_status_clone(status));
+    }
+  }
+
+  return status;
+}
+
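+// Defers an alloca until cold pool growth can run. The requester frontier is
+// captured now so the growth acquisition can use it later.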
+static iree_status_t iree_hal_amdgpu_host_queue_defer_alloca_pool_growth(
+    iree_hal_amdgpu_host_queue_t* queue,
+    const iree_hal_amdgpu_wait_resolution_t* resolution,
+    const iree_hal_semaphore_list_t signal_semaphore_list,
+    iree_hal_pool_t* allocation_pool, iree_hal_buffer_params_t params,
+    iree_device_size_t allocation_size, iree_hal_alloca_flags_t flags,
+    iree_hal_pool_reserve_flags_t reserve_flags, iree_hal_buffer_t* buffer,
+    iree_hal_amdgpu_pending_op_t* pending_op,
+    iree_hal_amdgpu_pending_op_t** out_memory_wait_op) {
+  iree_hal_amdgpu_pending_op_t* memory_wait_op = pending_op;
+  iree_status_t status = iree_hal_amdgpu_host_queue_get_alloca_memory_wait_op(
+      queue, signal_semaphore_list, allocation_pool, params, allocation_size,
+      flags, reserve_flags, buffer, pending_op, &memory_wait_op);
+  if (iree_status_is_ok(status)) {
+    iree_hal_amdgpu_fixed_frontier_t requester_frontier_storage;
+    const iree_async_frontier_t* requester_frontier =
+        iree_hal_amdgpu_host_queue_pool_requester_frontier(
+            queue, resolution, &requester_frontier_storage);
+    status = iree_hal_amdgpu_pending_op_prepare_alloca_pool_growth(
+        memory_wait_op, requester_frontier);
+  }
+  if (iree_status_is_ok(status)) {
+    *out_memory_wait_op = memory_wait_op;
+  } else if (!pending_op && memory_wait_op) {
+    iree_hal_amdgpu_pending_op_destroy_under_lock(memory_wait_op,
+                                                  iree_status_clone(status));
+  }
+  return status;
+}
+
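+// Handles an exhausted pool: begins observing the pool notification, retries
+// the acquisition, and falls back to a deferred notification wait only if the
+// retry still needs one. The observation is released unless ownership
+// transfers to the prepared wait.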
+static iree_status_t
+iree_hal_amdgpu_host_queue_defer_alloca_pool_notification_wait(
+    iree_hal_amdgpu_host_queue_t* queue,
+    const iree_hal_amdgpu_wait_resolution_t* resolution,
+    const iree_hal_semaphore_list_t signal_semaphore_list,
+    iree_hal_pool_t* allocation_pool, iree_hal_buffer_params_t params,
+    iree_device_size_t allocation_size, iree_hal_alloca_flags_t flags,
+    iree_hal_pool_reserve_flags_t reserve_flags, iree_hal_buffer_t* buffer,
+    iree_hal_amdgpu_host_queue_submission_flags_t submission_flags,
+    iree_hal_amdgpu_pending_op_t* pending_op,
+    iree_hal_amdgpu_pending_op_t** out_memory_wait_op, bool* out_ready) {
+  iree_async_notification_t* notification =
+      iree_hal_pool_notification(allocation_pool);
+  if (IREE_UNLIKELY(!notification)) {
+    return iree_make_status(
+        IREE_STATUS_INTERNAL,
+        "queue_alloca exhausted pool did not provide a notification");
+  }
+
+  const uint32_t wait_token =
+      iree_async_notification_begin_observe(notification);
+  iree_hal_amdgpu_alloca_reservation_t alloca_reservation;
+  iree_status_t status = iree_hal_amdgpu_host_queue_acquire_alloca_reservation(
+      queue, resolution, allocation_pool, params, allocation_size, flags,
+      reserve_flags, buffer, &alloca_reservation);
+
+  bool observation_transferred = false;
+  if (iree_status_is_ok(status)) {
+    switch (alloca_reservation.readiness) {
+      case IREE_HAL_AMDGPU_ALLOCA_RESERVATION_READY:
+        status = iree_hal_amdgpu_host_queue_submit_alloca_reservation(
+            queue, &alloca_reservation, signal_semaphore_list, allocation_pool,
+            params, buffer, submission_flags, out_ready);
+        break;
+      case IREE_HAL_AMDGPU_ALLOCA_RESERVATION_NEEDS_FRONTIER_WAIT:
+        status = iree_hal_amdgpu_host_queue_defer_alloca_frontier_wait(
+            queue, signal_semaphore_list, allocation_pool, params,
+            allocation_size, flags, reserve_flags, buffer, &alloca_reservation,
+            pending_op, out_memory_wait_op);
+        break;
+      case IREE_HAL_AMDGPU_ALLOCA_RESERVATION_NEEDS_POOL_GROWTH:
+        status = iree_hal_amdgpu_host_queue_defer_alloca_pool_growth(
+            queue, resolution, signal_semaphore_list, allocation_pool, params,
+            allocation_size, flags, reserve_flags, buffer, pending_op,
+            out_memory_wait_op);
+        break;
+      case IREE_HAL_AMDGPU_ALLOCA_RESERVATION_NEEDS_POOL_NOTIFICATION:
+        break;
+      default:
+        status =
+            iree_make_status(IREE_STATUS_INTERNAL,
+                             "unrecognized alloca reservation readiness %u",
+                             alloca_reservation.readiness);
+        break;
+    }
+  }
+
+  iree_hal_amdgpu_pending_op_t* memory_wait_op = pending_op;
+  if (iree_status_is_ok(status) &&
+      alloca_reservation.readiness ==
+          IREE_HAL_AMDGPU_ALLOCA_RESERVATION_NEEDS_POOL_NOTIFICATION) {
+    status = iree_hal_amdgpu_host_queue_get_alloca_memory_wait_op(
+        queue, signal_semaphore_list, allocation_pool, params, allocation_size,
+        flags, reserve_flags, buffer, pending_op, &memory_wait_op);
+    if (iree_status_is_ok(status)) {
+      status = iree_hal_amdgpu_pending_op_prepare_alloca_pool_notification_wait(
+          memory_wait_op, notification, wait_token);
+      observation_transferred = iree_status_is_ok(status);
+    }
+    if (iree_status_is_ok(status)) {
+      *out_memory_wait_op = memory_wait_op;
+    } else if (!pending_op && memory_wait_op) {
+      iree_hal_amdgpu_pending_op_destroy_under_lock(memory_wait_op,
+                                                    iree_status_clone(status));
+    }
+  }
+  if (!observation_transferred) {
+    iree_async_notification_end_observe(notification);
+  }
+  return status;
+}
+
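+// Fast paths below resubmit a memory-wait sidecar prepared on a previous
+// issue attempt; otherwise a fresh reservation is acquired and dispatched by
+// readiness.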
+iree_status_t iree_hal_amdgpu_host_queue_submit_alloca(
+    iree_hal_amdgpu_host_queue_t* queue,
+    const iree_hal_amdgpu_wait_resolution_t* resolution,
+    const iree_hal_semaphore_list_t signal_semaphore_list,
+    iree_hal_pool_t* allocation_pool, iree_hal_buffer_params_t params,
+    iree_device_size_t allocation_size, iree_hal_alloca_flags_t flags,
+    iree_hal_pool_reserve_flags_t reserve_flags, iree_hal_buffer_t* buffer,
+    iree_hal_amdgpu_host_queue_submission_flags_t submission_flags,
+    iree_hal_amdgpu_pending_op_t* pending_op,
+    iree_hal_amdgpu_pending_op_t** out_memory_wait_op, bool* out_ready) {
+  IREE_ASSERT_ARGUMENT(out_ready);
+  *out_ready = false;
+  *out_memory_wait_op = NULL;
+  if (IREE_UNLIKELY(queue->is_shutting_down)) {
+    return iree_make_status(IREE_STATUS_CANCELLED, "queue shutting down");
+  }
+
+  iree_hal_amdgpu_alloca_memory_wait_t* memory_wait =
+      pending_op ? pending_op->alloca_op.memory_wait : NULL;
+  if (memory_wait &&
+      memory_wait->kind == IREE_HAL_AMDGPU_ALLOCA_MEMORY_WAIT_FRONTIER) {
+    return iree_hal_amdgpu_host_queue_submit_alloca_held_frontier_wait(
+        queue, resolution, signal_semaphore_list, allocation_pool, params,
+        buffer, submission_flags, memory_wait, out_ready);
+  }
+  if (memory_wait &&
+      memory_wait->kind == IREE_HAL_AMDGPU_ALLOCA_MEMORY_WAIT_POOL_GROWTH &&
+      (memory_wait->pool_growth.materialization.reservation.acquire_result ==
+           IREE_HAL_POOL_ACQUIRE_OK ||
+       memory_wait->pool_growth.materialization.reservation.acquire_result ==
+           IREE_HAL_POOL_ACQUIRE_OK_FRESH)) {
+    return iree_hal_amdgpu_host_queue_submit_alloca_held_growth_materialization(
+        queue, resolution, signal_semaphore_list, allocation_pool, params,
+        buffer, submission_flags, memory_wait, out_ready);
+  }
+
+  iree_hal_amdgpu_alloca_reservation_t alloca_reservation;
+  IREE_RETURN_IF_ERROR(iree_hal_amdgpu_host_queue_acquire_alloca_reservation(
+      queue, resolution, allocation_pool, params, allocation_size, flags,
+      reserve_flags, buffer, &alloca_reservation));
+  switch (alloca_reservation.readiness) {
+    case IREE_HAL_AMDGPU_ALLOCA_RESERVATION_READY:
+      return iree_hal_amdgpu_host_queue_submit_alloca_reservation(
+          queue, &alloca_reservation, signal_semaphore_list, allocation_pool,
+          params, buffer, submission_flags, out_ready);
+    case IREE_HAL_AMDGPU_ALLOCA_RESERVATION_NEEDS_FRONTIER_WAIT:
+      return iree_hal_amdgpu_host_queue_defer_alloca_frontier_wait(
+          queue, signal_semaphore_list, allocation_pool, params,
+          allocation_size, flags, reserve_flags, buffer, &alloca_reservation,
+          pending_op, out_memory_wait_op);
+    case IREE_HAL_AMDGPU_ALLOCA_RESERVATION_NEEDS_POOL_GROWTH:
+      return iree_hal_amdgpu_host_queue_defer_alloca_pool_growth(
+          queue, resolution, signal_semaphore_list, allocation_pool, params,
+          allocation_size, flags, reserve_flags, buffer, pending_op,
+          out_memory_wait_op);
+    case IREE_HAL_AMDGPU_ALLOCA_RESERVATION_NEEDS_POOL_NOTIFICATION:
+      return iree_hal_amdgpu_host_queue_defer_alloca_pool_notification_wait(
+          queue, resolution, signal_semaphore_list, allocation_pool, params,
+          allocation_size, flags, reserve_flags, buffer, submission_flags,
+          pending_op, out_memory_wait_op, out_ready);
+    default:
+      return iree_make_status(IREE_STATUS_INTERNAL,
+                              "unrecognized alloca reservation readiness %u",
+                              alloca_reservation.readiness);
+  }
+}
diff --git a/runtime/src/iree/hal/drivers/amdgpu/host_queue_pending.h b/runtime/src/iree/hal/drivers/amdgpu/host_queue_pending.h
new file mode 100644
index 0000000..0d0b50d
--- /dev/null
+++ b/runtime/src/iree/hal/drivers/amdgpu/host_queue_pending.h
@@ -0,0 +1,143 @@
+// Copyright 2026 The IREE Authors
+//
+// Licensed under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+
+#ifndef IREE_HAL_DRIVERS_AMDGPU_HOST_QUEUE_PENDING_H_
+#define IREE_HAL_DRIVERS_AMDGPU_HOST_QUEUE_PENDING_H_
+
+#include "iree/hal/drivers/amdgpu/host_queue.h"
+#include "iree/hal/drivers/amdgpu/host_queue_submission.h"
+
+#ifdef __cplusplus
+extern "C" {
+#endif  // __cplusplus
+
+// Starts wait registration or capacity retry for a captured operation.
+iree_status_t iree_hal_amdgpu_pending_op_start(iree_hal_amdgpu_pending_op_t* op,
+                                               bool wait_for_capacity);
+
+// Enqueues the cold alloca memory-readiness wait prepared on |op|.
+void iree_hal_amdgpu_pending_op_enqueue_alloca_memory_wait(
+    iree_hal_amdgpu_pending_op_t* op);
+
+// Cancels all pending queue operations during shutdown or fatal queue failure.
+void iree_hal_amdgpu_host_queue_cancel_pending(
+    iree_hal_amdgpu_host_queue_t* queue, iree_status_code_t status_code,
+    const char* status_message);
+
+// Captures a queue_alloca operation for later issue. Caller must hold
+// queue->locks.submission_mutex.
+iree_status_t iree_hal_amdgpu_host_queue_defer_alloca(
+    iree_hal_amdgpu_host_queue_t* queue,
+    const iree_hal_semaphore_list_t* wait_semaphore_list,
+    const iree_hal_semaphore_list_t* signal_semaphore_list,
+    iree_hal_pool_t* pool, iree_hal_buffer_params_t params,
+    iree_device_size_t allocation_size, iree_hal_alloca_flags_t flags,
+    iree_hal_pool_reserve_flags_t reserve_flags, iree_hal_buffer_t* buffer,
+    iree_hal_amdgpu_pending_op_t** out_op);
+
+// Submits an alloca operation after wait resolution. Caller must hold
+// queue->locks.submission_mutex. If memory readiness must wait,
+// |out_memory_wait_op| receives the operation that owns the prepared wait
+// sidecar.
+iree_status_t iree_hal_amdgpu_host_queue_submit_alloca(
+    iree_hal_amdgpu_host_queue_t* queue,
+    const iree_hal_amdgpu_wait_resolution_t* resolution,
+    const iree_hal_semaphore_list_t signal_semaphore_list,
+    iree_hal_pool_t* pool, iree_hal_buffer_params_t params,
+    iree_device_size_t allocation_size, iree_hal_alloca_flags_t flags,
+    iree_hal_pool_reserve_flags_t reserve_flags, iree_hal_buffer_t* buffer,
+    iree_hal_amdgpu_host_queue_submission_flags_t submission_flags,
+    iree_hal_amdgpu_pending_op_t* pending_op,
+    iree_hal_amdgpu_pending_op_t** out_memory_wait_op, bool* out_ready);
+
+// Captures a queue_dealloca operation for later issue. Caller must hold
+// queue->locks.submission_mutex.
+iree_status_t iree_hal_amdgpu_host_queue_defer_dealloca(
+    iree_hal_amdgpu_host_queue_t* queue,
+    const iree_hal_semaphore_list_t* wait_semaphore_list,
+    const iree_hal_semaphore_list_t* signal_semaphore_list,
+    iree_hal_buffer_t* buffer, iree_hal_amdgpu_pending_op_t** out_op);
+
+// Captures a queue_fill operation for later issue. Caller must hold
+// queue->locks.submission_mutex.
+iree_status_t iree_hal_amdgpu_host_queue_defer_fill(
+    iree_hal_amdgpu_host_queue_t* queue,
+    const iree_hal_semaphore_list_t* wait_semaphore_list,
+    const iree_hal_semaphore_list_t* signal_semaphore_list,
+    iree_hal_buffer_t* target_buffer, iree_device_size_t target_offset,
+    iree_device_size_t length, uint64_t pattern_bits,
+    iree_host_size_t pattern_length, iree_hal_fill_flags_t flags,
+    iree_hal_amdgpu_pending_op_t** out_op);
+
+// Captures a queue_copy/read/write operation for later issue. Caller must hold
+// queue->locks.submission_mutex.
+iree_status_t iree_hal_amdgpu_host_queue_defer_copy(
+    iree_hal_amdgpu_host_queue_t* queue,
+    const iree_hal_semaphore_list_t* wait_semaphore_list,
+    const iree_hal_semaphore_list_t* signal_semaphore_list,
+    iree_hal_buffer_t* source_buffer, iree_device_size_t source_offset,
+    iree_hal_buffer_t* target_buffer, iree_device_size_t target_offset,
+    iree_device_size_t length, iree_hal_copy_flags_t flags,
+    iree_hal_profile_queue_event_type_t profile_event_type,
+    iree_hal_amdgpu_pending_op_t** out_op);
+
+// Captures a queue_update operation for later issue. Caller must hold
+// queue->locks.submission_mutex.
+iree_status_t iree_hal_amdgpu_host_queue_defer_update(
+    iree_hal_amdgpu_host_queue_t* queue,
+    const iree_hal_semaphore_list_t* wait_semaphore_list,
+    const iree_hal_semaphore_list_t* signal_semaphore_list,
+    const void* source_buffer, iree_host_size_t source_offset,
+    iree_hal_buffer_t* target_buffer, iree_device_size_t target_offset,
+    iree_device_size_t length, iree_hal_update_flags_t flags,
+    iree_hal_amdgpu_pending_op_t** out_op);
+
+// Captures a queue_execute operation for later issue. Caller must hold
+// queue->locks.submission_mutex.
+iree_status_t iree_hal_amdgpu_host_queue_defer_execute(
+    iree_hal_amdgpu_host_queue_t* queue,
+    const iree_hal_semaphore_list_t* wait_semaphore_list,
+    const iree_hal_semaphore_list_t* signal_semaphore_list,
+    iree_hal_command_buffer_t* command_buffer,
+    iree_hal_buffer_binding_table_t binding_table,
+    iree_hal_execute_flags_t flags, iree_hal_amdgpu_pending_op_t** out_op);
+
+// Captures a queue_dispatch operation for later issue. Caller must hold
+// queue->locks.submission_mutex.
+iree_status_t iree_hal_amdgpu_host_queue_defer_dispatch(
+    iree_hal_amdgpu_host_queue_t* queue,
+    const iree_hal_semaphore_list_t* wait_semaphore_list,
+    const iree_hal_semaphore_list_t* signal_semaphore_list,
+    iree_hal_executable_t* executable,
+    iree_hal_executable_export_ordinal_t export_ordinal,
+    const iree_hal_dispatch_config_t config, iree_const_byte_span_t constants,
+    const iree_hal_buffer_ref_list_t bindings, iree_hal_dispatch_flags_t flags,
+    iree_hal_amdgpu_pending_op_t** out_op);
+
+// Captures a driver host action for later issue. Caller must hold
+// queue->locks.submission_mutex.
+iree_status_t iree_hal_amdgpu_host_queue_defer_host_action(
+    iree_hal_amdgpu_host_queue_t* queue,
+    const iree_hal_semaphore_list_t* wait_semaphore_list,
+    iree_hal_amdgpu_reclaim_action_t action,
+    iree_hal_resource_t* const* operation_resources,
+    iree_host_size_t operation_resource_count,
+    iree_hal_amdgpu_pending_op_t** out_op);
+
+// Captures a queue_host_call operation for later issue. Caller must hold
+// queue->locks.submission_mutex.
+iree_status_t iree_hal_amdgpu_host_queue_defer_host_call(
+    iree_hal_amdgpu_host_queue_t* queue,
+    const iree_hal_semaphore_list_t* wait_semaphore_list,
+    const iree_hal_semaphore_list_t* signal_semaphore_list,
+    iree_hal_host_call_t call, const uint64_t args[4],
+    iree_hal_host_call_flags_t flags, iree_hal_amdgpu_pending_op_t** out_op);
+
+#ifdef __cplusplus
+}  // extern "C"
+#endif  // __cplusplus
+
+#endif  // IREE_HAL_DRIVERS_AMDGPU_HOST_QUEUE_PENDING_H_
diff --git a/runtime/src/iree/hal/drivers/amdgpu/host_queue_pending_operation.h b/runtime/src/iree/hal/drivers/amdgpu/host_queue_pending_operation.h
new file mode 100644
index 0000000..f2bc4ea
--- /dev/null
+++ b/runtime/src/iree/hal/drivers/amdgpu/host_queue_pending_operation.h
@@ -0,0 +1,262 @@
+// Copyright 2026 The IREE Authors
+//
+// Licensed under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+
+#ifndef IREE_HAL_DRIVERS_AMDGPU_HOST_QUEUE_PENDING_OPERATION_H_
+#define IREE_HAL_DRIVERS_AMDGPU_HOST_QUEUE_PENDING_OPERATION_H_
+
+#include "iree/base/threading/notification.h"
+#include "iree/hal/drivers/amdgpu/host_queue_pending.h"
+#include "iree/hal/utils/resource_set.h"
+
+#ifdef __cplusplus
+extern "C" {
+#endif  // __cplusplus
+
+typedef struct iree_hal_amdgpu_alloca_memory_wait_t
+    iree_hal_amdgpu_alloca_memory_wait_t;
+typedef struct iree_hal_amdgpu_wait_entry_t iree_hal_amdgpu_wait_entry_t;
+
+// Operation types corresponding to virtual queue vtable entries. Each type has
+// a per-operation capture struct in the pending_op_t union.
+typedef enum iree_hal_amdgpu_pending_op_type_e {
+  IREE_HAL_AMDGPU_PENDING_OP_FILL,
+  IREE_HAL_AMDGPU_PENDING_OP_COPY,
+  IREE_HAL_AMDGPU_PENDING_OP_UPDATE,
+  IREE_HAL_AMDGPU_PENDING_OP_DISPATCH,
+  IREE_HAL_AMDGPU_PENDING_OP_EXECUTE,
+  IREE_HAL_AMDGPU_PENDING_OP_ALLOCA,
+  IREE_HAL_AMDGPU_PENDING_OP_DEALLOCA,
+  IREE_HAL_AMDGPU_PENDING_OP_HOST_CALL,
+  IREE_HAL_AMDGPU_PENDING_OP_HOST_ACTION,
+} iree_hal_amdgpu_pending_op_type_t;
+
+// Completion ownership for a deferred operation.
+typedef enum iree_hal_amdgpu_pending_op_lifecycle_e {
+  // Waiting callbacks may still resolve the op.
+  IREE_HAL_AMDGPU_PENDING_OP_LIFECYCLE_PENDING = 0,
+  // Queue shutdown claimed cancellation ownership.
+  IREE_HAL_AMDGPU_PENDING_OP_LIFECYCLE_CANCELLING = 1,
+  // The last wait callback claimed completion ownership.
+  IREE_HAL_AMDGPU_PENDING_OP_LIFECYCLE_COMPLETING = 2,
+  // The issuing thread is registering a cold alloca memory-readiness wait.
+  // Cancellation only claims PENDING ops; the arming thread publishes PENDING
+  // once registration succeeds, or observes COMPLETING if the wait callback
+  // fired synchronously during registration.
+  IREE_HAL_AMDGPU_PENDING_OP_LIFECYCLE_ARMING_MEMORY_WAIT = 3,
+} iree_hal_amdgpu_pending_op_lifecycle_t;
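+
+// Lifecycle transitions (informational summary, not exhaustive; ownership is
+// typically claimed via compare-exchange in the pending-operation code):
+//   PENDING            -> COMPLETING  last wait callback claims completion
+//   PENDING            -> CANCELLING  queue shutdown claims cancellation
+//   ARMING_MEMORY_WAIT -> PENDING     arming thread publishes after a
+//                                     successful wait registration
+//   ARMING_MEMORY_WAIT -> COMPLETING  synchronous callback or registration
+//                                     failure claims completion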
+
+// A deferred queue operation waiting for its waits to become satisfiable.
+// Arena-allocated from the queue's block pool. All variable-size captured data
+// lives in the arena alongside this struct.
+struct iree_hal_amdgpu_pending_op_t {
+  // Arena backing this operation and all captured data.
+  iree_arena_allocator_t arena;
+
+  // Owning queue used to emit work when all waits are satisfied.
+  iree_hal_amdgpu_host_queue_t* queue;
+
+  // Next operation in the queue's pending list.
+  iree_hal_amdgpu_pending_op_t* next;
+
+  // Back-pointer to the previous link field for O(1) unlink.
+  iree_hal_amdgpu_pending_op_t** prev_next;
+
+  // Completion-thread retry queued when submission capacity is unavailable.
+  iree_hal_amdgpu_host_queue_post_drain_action_t capacity_retry;
+
+  // Number of outstanding wait timepoints.
+  iree_atomic_int32_t wait_count;
+
+  // Completion/cancellation owner.
+  iree_atomic_int32_t lifecycle_state;
+
+  // First error from a failed wait. CAS from 0; the winner owns the status.
+  iree_atomic_intptr_t error_status;
+
+  // Wakes cancellation when a detached wait callback finishes touching the op.
+  iree_notification_t callback_notification;
+
+  // Arena-owned clone of the wait semaphore list.
+  iree_hal_semaphore_list_t wait_semaphore_list;
+
+  // Arena-owned clone of signal payload values; semaphores alias the first
+  // entries of retained_resources.
+  iree_hal_semaphore_list_t signal_semaphore_list;
+
+  // Wait entries registered with the wait semaphores.
+  iree_hal_amdgpu_wait_entry_t* wait_entries;
+
+  // Flat array of all retained HAL resources.
+  iree_hal_resource_t** retained_resources;
+
+  // Number of entries currently owned in |retained_resources|.
+  uint16_t retained_resource_count;
+
+  // Operation payload selector.
+  iree_hal_amdgpu_pending_op_type_t type;
+
+  union {
+    // Captured queue_fill payload.
+    struct {
+      // Target buffer retained until the deferred fill operation issues.
+      iree_hal_buffer_t* target_buffer;
+      // Target byte offset captured from queue_fill.
+      iree_device_size_t target_offset;
+      // Number of bytes filled by this queue operation.
+      iree_device_size_t length;
+      // Fill pattern bits captured in the low bytes.
+      uint64_t pattern_bits;
+      // Fill pattern length in bytes.
+      iree_host_size_t pattern_length;
+      // HAL fill flags captured from queue_fill.
+      iree_hal_fill_flags_t flags;
+    } fill;
+
+    // Captured queue_copy/read/write payload.
+    struct {
+      // Source buffer retained until the deferred copy operation issues.
+      iree_hal_buffer_t* source_buffer;
+      // Source byte offset captured from queue_copy/read/write.
+      iree_device_size_t source_offset;
+      // Target buffer retained until the deferred copy operation issues.
+      iree_hal_buffer_t* target_buffer;
+      // Target byte offset captured from queue_copy/read/write.
+      iree_device_size_t target_offset;
+      // Number of bytes copied by this queue operation.
+      iree_device_size_t length;
+      // HAL copy flags captured from queue_copy/read/write.
+      iree_hal_copy_flags_t flags;
+      // Queue profiling event type used when the copy submission issues.
+      iree_hal_profile_queue_event_type_t profile_event_type;
+    } copy;
+
+    // Captured queue_update payload.
+    struct {
+      // Source data is copied into the arena.
+      const void* source_data;
+      // Target buffer retained until the deferred update operation issues.
+      iree_hal_buffer_t* target_buffer;
+      // Target byte offset captured from queue_update.
+      iree_device_size_t target_offset;
+      // Number of bytes copied from |source_data|.
+      iree_device_size_t length;
+      // HAL update flags captured from queue_update.
+      iree_hal_update_flags_t flags;
+    } update;
+
+    // Captured queue_dispatch payload.
+    struct {
+      // Executable retained until the deferred dispatch operation issues.
+      iree_hal_executable_t* executable;
+      // Export ordinal captured from queue_dispatch.
+      iree_hal_executable_export_ordinal_t export_ordinal;
+      // Dispatch workgroup configuration captured from queue_dispatch.
+      iree_hal_dispatch_config_t config;
+      // Arena-owned copy of dispatch constants.
+      iree_const_byte_span_t constants;
+      // Arena-owned copy of dispatch buffer references.
+      iree_hal_buffer_ref_list_t bindings;
+      // HAL dispatch flags captured from queue_dispatch.
+      iree_hal_dispatch_flags_t flags;
+    } dispatch;
+
+    // Captured queue_execute payload.
+    struct {
+      // Command buffer retained until the deferred execute operation issues.
+      iree_hal_command_buffer_t* command_buffer;
+      // Arena-owned copy of the binding table prefix used by command_buffer.
+      iree_hal_buffer_binding_table_t binding_table;
+      // Binding resources captured until the deferred execute operation issues.
+      iree_hal_resource_set_t* binding_resource_set;
+      // HAL execute flags captured from queue_execute.
+      iree_hal_execute_flags_t flags;
+    } execute;
+
+    // Captured queue_alloca payload.
+    struct {
+      // Borrowed pool resolved during queue_alloca capture.
+      iree_hal_pool_t* pool;
+      // Buffer parameters captured from queue_alloca.
+      iree_hal_buffer_params_t params;
+      // Requested allocation size in bytes.
+      iree_device_size_t allocation_size;
+      // HAL allocation flags captured from queue_alloca.
+      iree_hal_alloca_flags_t flags;
+      // Pool reservation flags used when probing the selected pool.
+      iree_hal_pool_reserve_flags_t reserve_flags;
+      // Transient buffer returned to the caller and committed on success.
+      iree_hal_buffer_t* buffer;
+      // Cold memory-readiness sidecar allocated only after user waits resolve.
+      iree_hal_amdgpu_alloca_memory_wait_t* memory_wait;
+    } alloca_op;
+
+    // Captured queue_dealloca payload.
+    struct {
+      // Transient buffer retained until the deferred dealloca operation issues.
+      iree_hal_buffer_t* buffer;
+    } dealloca;
+
+    // Captured queue_host_call payload.
+    struct {
+      // Host callback and user data captured from queue_host_call.
+      iree_hal_host_call_t call;
+      // Host call arguments copied by value.
+      uint64_t args[4];
+      // HAL host-call flags captured from queue_host_call.
+      iree_hal_host_call_flags_t flags;
+    } host_call;
+
+    // Captured driver host-action payload.
+    struct {
+      // Driver-owned completion-thread action ordered by queue semaphores.
+      iree_hal_amdgpu_reclaim_action_t action;
+    } host_action;
+  };
+};
+
+// Result of trying to issue an operation payload under the queue submission
+// lock.
+typedef struct iree_hal_amdgpu_pending_op_payload_issue_t {
+  // Whether queue admission found enough capacity for this payload.
+  bool ready;
+
+  // Pending alloca operation that owns a prepared cold memory-readiness
+  // wait; populated only when an alloca payload parks instead of completing.
+  iree_hal_amdgpu_pending_op_t* memory_wait_op;
+} iree_hal_amdgpu_pending_op_payload_issue_t;
+
+// Allocates and links a pending operation. Caller must hold
+// queue->locks.submission_mutex.
+iree_status_t iree_hal_amdgpu_pending_op_allocate(
+    iree_hal_amdgpu_host_queue_t* queue,
+    const iree_hal_semaphore_list_t* wait_semaphore_list,
+    const iree_hal_semaphore_list_t* signal_semaphore_list,
+    iree_hal_amdgpu_pending_op_type_t type, uint16_t max_resource_count,
+    iree_hal_amdgpu_pending_op_t** out_op);
+
+// Retains |resource| in |op|'s preallocated retained resource table.
+void iree_hal_amdgpu_pending_op_retain(iree_hal_amdgpu_pending_op_t* op,
+                                       iree_hal_resource_t* resource);
+
+// Releases all resources retained by |op|.
+void iree_hal_amdgpu_pending_op_release_retained(
+    iree_hal_amdgpu_pending_op_t* op);
+
+// Destroys an operation that failed during capture, taking ownership of
+// |status|. Caller must hold queue->locks.submission_mutex.
+void iree_hal_amdgpu_pending_op_destroy_under_lock(
+    iree_hal_amdgpu_pending_op_t* op, iree_status_t status);
+
+// Issues the operation-family payload. Caller must hold
+// queue->locks.submission_mutex.
+iree_status_t iree_hal_amdgpu_pending_op_issue_payload(
+    iree_hal_amdgpu_pending_op_t* op,
+    const iree_hal_amdgpu_wait_resolution_t* resolution,
+    iree_hal_amdgpu_pending_op_payload_issue_t* issue);
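+
+// Example (minimal sketch; |resolution| is assumed to come from the queue's
+// wait resolution machinery): once an op's waits resolve, the queue issues
+// the payload under the submission lock and leaves the op pending when
+// admission lacks capacity:
+//
+//   iree_hal_amdgpu_pending_op_payload_issue_t issue = {0};
+//   iree_status_t status =
+//       iree_hal_amdgpu_pending_op_issue_payload(op, &resolution, &issue);
+//   if (iree_status_is_ok(status) && !issue.ready) {
+//     // Insufficient queue capacity: keep the op pending and retry after
+//     // the next drain.
+//   }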
+
+#ifdef __cplusplus
+}  // extern "C"
+#endif  // __cplusplus
+
+#endif  // IREE_HAL_DRIVERS_AMDGPU_HOST_QUEUE_PENDING_OPERATION_H_
diff --git a/runtime/src/iree/hal/drivers/amdgpu/host_queue_pending_payload.c b/runtime/src/iree/hal/drivers/amdgpu/host_queue_pending_payload.c
new file mode 100644
index 0000000..c3d91bb
--- /dev/null
+++ b/runtime/src/iree/hal/drivers/amdgpu/host_queue_pending_payload.c
@@ -0,0 +1,541 @@
+// Copyright 2026 The IREE Authors
+//
+// Licensed under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+
+#include <string.h>
+
+#include "iree/hal/drivers/amdgpu/host_queue_blit.h"
+#include "iree/hal/drivers/amdgpu/host_queue_command_buffer.h"
+#include "iree/hal/drivers/amdgpu/host_queue_dispatch.h"
+#include "iree/hal/drivers/amdgpu/host_queue_host_call.h"
+#include "iree/hal/drivers/amdgpu/host_queue_memory.h"
+#include "iree/hal/drivers/amdgpu/host_queue_pending_operation.h"
+
+//===----------------------------------------------------------------------===//
+// Pending operation payload issue
+//===----------------------------------------------------------------------===//
+
+static iree_status_t iree_hal_amdgpu_pending_op_issue_fill(
+    iree_hal_amdgpu_pending_op_t* op,
+    const iree_hal_amdgpu_wait_resolution_t* resolution,
+    iree_hal_amdgpu_pending_op_payload_issue_t* issue) {
+  iree_status_t status = iree_hal_amdgpu_host_queue_submit_fill(
+      op->queue, resolution, op->signal_semaphore_list, op->fill.target_buffer,
+      op->fill.target_offset, op->fill.length, op->fill.pattern_bits,
+      op->fill.pattern_length, op->fill.flags,
+      IREE_HAL_AMDGPU_HOST_QUEUE_SUBMISSION_FLAG_NONE, &issue->ready);
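+  // Admission succeeded: the queue's reclaim machinery now owns the retained
+  // references, so zero the count to keep op teardown from re-releasing them.
+  // The same handoff applies in the other issue_* paths below.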
+  if (iree_status_is_ok(status) && issue->ready) {
+    op->retained_resource_count = 0;
+  }
+  return status;
+}
+
+static iree_status_t iree_hal_amdgpu_pending_op_issue_copy(
+    iree_hal_amdgpu_pending_op_t* op,
+    const iree_hal_amdgpu_wait_resolution_t* resolution,
+    iree_hal_amdgpu_pending_op_payload_issue_t* issue) {
+  iree_status_t status = iree_hal_amdgpu_host_queue_submit_copy(
+      op->queue, resolution, op->signal_semaphore_list, op->copy.source_buffer,
+      op->copy.source_offset, op->copy.target_buffer, op->copy.target_offset,
+      op->copy.length, op->copy.flags, op->copy.profile_event_type,
+      IREE_HAL_AMDGPU_HOST_QUEUE_SUBMISSION_FLAG_NONE, &issue->ready);
+  if (iree_status_is_ok(status) && issue->ready) {
+    op->retained_resource_count = 0;
+  }
+  return status;
+}
+
+static iree_status_t iree_hal_amdgpu_pending_op_issue_update(
+    iree_hal_amdgpu_pending_op_t* op,
+    const iree_hal_amdgpu_wait_resolution_t* resolution,
+    iree_hal_amdgpu_pending_op_payload_issue_t* issue) {
+  iree_status_t status = iree_hal_amdgpu_host_queue_submit_update(
+      op->queue, resolution, op->signal_semaphore_list, op->update.source_data,
+      /*source_offset=*/0, op->update.target_buffer, op->update.target_offset,
+      op->update.length, op->update.flags,
+      IREE_HAL_AMDGPU_HOST_QUEUE_SUBMISSION_FLAG_NONE, &issue->ready);
+  if (iree_status_is_ok(status) && issue->ready) {
+    op->retained_resource_count = 0;
+  }
+  return status;
+}
+
+static iree_status_t iree_hal_amdgpu_pending_op_issue_dispatch(
+    iree_hal_amdgpu_pending_op_t* op,
+    const iree_hal_amdgpu_wait_resolution_t* resolution,
+    iree_hal_amdgpu_pending_op_payload_issue_t* issue) {
+  iree_status_t status = iree_hal_amdgpu_host_queue_submit_dispatch(
+      op->queue, resolution, op->signal_semaphore_list, op->dispatch.executable,
+      op->dispatch.export_ordinal, op->dispatch.config, op->dispatch.constants,
+      op->dispatch.bindings, op->dispatch.flags,
+      IREE_HAL_AMDGPU_HOST_QUEUE_SUBMISSION_FLAG_NONE, &issue->ready);
+  if (iree_status_is_ok(status) && issue->ready) {
+    op->retained_resource_count = 0;
+  }
+  return status;
+}
+
+static iree_status_t iree_hal_amdgpu_pending_op_issue_execute(
+    iree_hal_amdgpu_pending_op_t* op,
+    const iree_hal_amdgpu_wait_resolution_t* resolution,
+    iree_hal_amdgpu_pending_op_payload_issue_t* issue) {
+  if (op->execute.command_buffer) {
+    iree_status_t status = iree_hal_amdgpu_host_queue_submit_command_buffer(
+        op->queue, resolution, op->signal_semaphore_list,
+        op->execute.command_buffer, op->execute.binding_table,
+        op->execute.flags, &op->execute.binding_resource_set, &issue->ready);
+    if (iree_status_is_ok(status) && issue->ready) {
+      iree_hal_amdgpu_pending_op_release_retained(op);
+    }
+    return status;
+  }
+
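+  // Barrier-only queue_execute: no command buffer, so submit a plain barrier
+  // that orders the wait resolution against the signal semaphores.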
+  iree_status_t status = iree_hal_amdgpu_host_queue_try_submit_barrier(
+      op->queue, resolution, op->signal_semaphore_list,
+      (iree_hal_amdgpu_reclaim_action_t){0},
+      /*operation_resources=*/NULL,
+      /*operation_resource_count=*/0,
+      /*profile_event_info=*/NULL,
+      iree_hal_amdgpu_host_queue_post_commit_callback_null(),
+      /*resource_set=*/NULL, IREE_HAL_AMDGPU_HOST_QUEUE_SUBMISSION_FLAG_NONE,
+      &issue->ready,
+      /*out_submission_id=*/NULL);
+  if (iree_status_is_ok(status) && issue->ready) {
+    op->retained_resource_count = 0;
+  }
+  return status;
+}
+
+static iree_status_t iree_hal_amdgpu_pending_op_issue_alloca(
+    iree_hal_amdgpu_pending_op_t* op,
+    const iree_hal_amdgpu_wait_resolution_t* resolution,
+    iree_hal_amdgpu_pending_op_payload_issue_t* issue) {
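+  // submit_alloca may park |op| behind a cold memory-readiness wait; when it
+  // does it reports the parked op back through |issue->memory_wait_op|.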
+  iree_status_t status = iree_hal_amdgpu_host_queue_submit_alloca(
+      op->queue, resolution, op->signal_semaphore_list, op->alloca_op.pool,
+      op->alloca_op.params, op->alloca_op.allocation_size, op->alloca_op.flags,
+      op->alloca_op.reserve_flags, op->alloca_op.buffer,
+      IREE_HAL_AMDGPU_HOST_QUEUE_SUBMISSION_FLAG_NONE, op,
+      &issue->memory_wait_op, &issue->ready);
+  if (iree_status_is_ok(status) && issue->ready) {
+    op->retained_resource_count = 0;
+  }
+  return status;
+}
+
+static iree_status_t iree_hal_amdgpu_pending_op_issue_dealloca(
+    iree_hal_amdgpu_pending_op_t* op,
+    const iree_hal_amdgpu_wait_resolution_t* resolution,
+    iree_hal_amdgpu_pending_op_payload_issue_t* issue) {
+  iree_status_t status = iree_hal_amdgpu_host_queue_submit_dealloca(
+      op->queue, resolution, op->signal_semaphore_list, op->dealloca.buffer,
+      IREE_HAL_AMDGPU_HOST_QUEUE_SUBMISSION_FLAG_NONE, &issue->ready);
+  if (iree_status_is_ok(status) && issue->ready) {
+    op->retained_resource_count = 0;
+  }
+  return status;
+}
+
+static iree_status_t iree_hal_amdgpu_pending_op_issue_host_call(
+    iree_hal_amdgpu_pending_op_t* op,
+    const iree_hal_amdgpu_wait_resolution_t* resolution,
+    iree_hal_amdgpu_pending_op_payload_issue_t* issue) {
+  iree_status_t status = iree_hal_amdgpu_host_queue_submit_host_call(
+      op->queue, resolution, op->signal_semaphore_list, op->host_call.call,
+      op->host_call.args, op->host_call.flags, &issue->ready);
+  if (iree_status_is_ok(status) && issue->ready) {
+    iree_hal_amdgpu_pending_op_release_retained(op);
+  }
+  return status;
+}
+
+static iree_status_t iree_hal_amdgpu_pending_op_issue_host_action(
+    iree_hal_amdgpu_pending_op_t* op,
+    const iree_hal_amdgpu_wait_resolution_t* resolution,
+    iree_hal_amdgpu_pending_op_payload_issue_t* issue) {
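+  // Host actions signal no HAL semaphores; the action and the op's retained
+  // resources ride the barrier and are reclaimed on completion.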
+  iree_status_t status = iree_hal_amdgpu_host_queue_try_submit_barrier(
+      op->queue, resolution, iree_hal_semaphore_list_empty(),
+      op->host_action.action, op->retained_resources,
+      op->retained_resource_count,
+      /*profile_event_info=*/NULL,
+      iree_hal_amdgpu_host_queue_post_commit_callback_null(),
+      /*resource_set=*/NULL, IREE_HAL_AMDGPU_HOST_QUEUE_SUBMISSION_FLAG_NONE,
+      &issue->ready,
+      /*out_submission_id=*/NULL);
+  if (iree_status_is_ok(status) && issue->ready) {
+    op->retained_resource_count = 0;
+  }
+  return status;
+}
+
+iree_status_t iree_hal_amdgpu_pending_op_issue_payload(
+    iree_hal_amdgpu_pending_op_t* op,
+    const iree_hal_amdgpu_wait_resolution_t* resolution,
+    iree_hal_amdgpu_pending_op_payload_issue_t* issue) {
+  switch (op->type) {
+    case IREE_HAL_AMDGPU_PENDING_OP_FILL:
+      return iree_hal_amdgpu_pending_op_issue_fill(op, resolution, issue);
+    case IREE_HAL_AMDGPU_PENDING_OP_COPY:
+      return iree_hal_amdgpu_pending_op_issue_copy(op, resolution, issue);
+    case IREE_HAL_AMDGPU_PENDING_OP_UPDATE:
+      return iree_hal_amdgpu_pending_op_issue_update(op, resolution, issue);
+    case IREE_HAL_AMDGPU_PENDING_OP_DISPATCH:
+      return iree_hal_amdgpu_pending_op_issue_dispatch(op, resolution, issue);
+    case IREE_HAL_AMDGPU_PENDING_OP_EXECUTE:
+      return iree_hal_amdgpu_pending_op_issue_execute(op, resolution, issue);
+    case IREE_HAL_AMDGPU_PENDING_OP_ALLOCA:
+      return iree_hal_amdgpu_pending_op_issue_alloca(op, resolution, issue);
+    case IREE_HAL_AMDGPU_PENDING_OP_DEALLOCA:
+      return iree_hal_amdgpu_pending_op_issue_dealloca(op, resolution, issue);
+    case IREE_HAL_AMDGPU_PENDING_OP_HOST_CALL:
+      return iree_hal_amdgpu_pending_op_issue_host_call(op, resolution, issue);
+    case IREE_HAL_AMDGPU_PENDING_OP_HOST_ACTION:
+      return iree_hal_amdgpu_pending_op_issue_host_action(op, resolution,
+                                                          issue);
+    default:
+      return iree_make_status(IREE_STATUS_INTERNAL,
+                              "unrecognized pending op type %d", op->type);
+  }
+}
+
+//===----------------------------------------------------------------------===//
+// Pending operation payload capture
+//===----------------------------------------------------------------------===//
+
+iree_status_t iree_hal_amdgpu_host_queue_defer_alloca(
+    iree_hal_amdgpu_host_queue_t* queue,
+    const iree_hal_semaphore_list_t* wait_semaphore_list,
+    const iree_hal_semaphore_list_t* signal_semaphore_list,
+    iree_hal_pool_t* pool, iree_hal_buffer_params_t params,
+    iree_device_size_t allocation_size, iree_hal_alloca_flags_t flags,
+    iree_hal_pool_reserve_flags_t reserve_flags, iree_hal_buffer_t* buffer,
+    iree_hal_amdgpu_pending_op_t** out_op) {
+  uint16_t max_resources = 0;
+  IREE_RETURN_IF_ERROR(iree_hal_amdgpu_host_queue_count_reclaim_resources(
+      signal_semaphore_list->count,
+      /*operation_resource_count=*/1, &max_resources));
+  iree_hal_amdgpu_pending_op_t* op = NULL;
+  IREE_RETURN_IF_ERROR(iree_hal_amdgpu_pending_op_allocate(
+      queue, wait_semaphore_list, signal_semaphore_list,
+      IREE_HAL_AMDGPU_PENDING_OP_ALLOCA, max_resources, &op));
+  iree_hal_amdgpu_pending_op_retain(op, (iree_hal_resource_t*)buffer);
+  op->alloca_op.pool = pool;
+  op->alloca_op.params = params;
+  op->alloca_op.allocation_size = allocation_size;
+  op->alloca_op.flags = flags;
+  op->alloca_op.reserve_flags = reserve_flags;
+  op->alloca_op.buffer = buffer;
+  *out_op = op;
+  return iree_ok_status();
+}
+
+iree_status_t iree_hal_amdgpu_host_queue_defer_dealloca(
+    iree_hal_amdgpu_host_queue_t* queue,
+    const iree_hal_semaphore_list_t* wait_semaphore_list,
+    const iree_hal_semaphore_list_t* signal_semaphore_list,
+    iree_hal_buffer_t* buffer, iree_hal_amdgpu_pending_op_t** out_op) {
+  uint16_t max_resources = 0;
+  IREE_RETURN_IF_ERROR(iree_hal_amdgpu_host_queue_count_reclaim_resources(
+      signal_semaphore_list->count,
+      /*operation_resource_count=*/1, &max_resources));
+  iree_hal_amdgpu_pending_op_t* op = NULL;
+  IREE_RETURN_IF_ERROR(iree_hal_amdgpu_pending_op_allocate(
+      queue, wait_semaphore_list, signal_semaphore_list,
+      IREE_HAL_AMDGPU_PENDING_OP_DEALLOCA, max_resources, &op));
+  iree_hal_amdgpu_pending_op_retain(op, (iree_hal_resource_t*)buffer);
+  op->dealloca.buffer = buffer;
+  *out_op = op;
+  return iree_ok_status();
+}
+
+iree_status_t iree_hal_amdgpu_host_queue_defer_fill(
+    iree_hal_amdgpu_host_queue_t* queue,
+    const iree_hal_semaphore_list_t* wait_semaphore_list,
+    const iree_hal_semaphore_list_t* signal_semaphore_list,
+    iree_hal_buffer_t* target_buffer, iree_device_size_t target_offset,
+    iree_device_size_t length, uint64_t pattern_bits,
+    iree_host_size_t pattern_length, iree_hal_fill_flags_t flags,
+    iree_hal_amdgpu_pending_op_t** out_op) {
+  uint16_t max_resources = 0;
+  IREE_RETURN_IF_ERROR(iree_hal_amdgpu_host_queue_count_reclaim_resources(
+      signal_semaphore_list->count,
+      /*operation_resource_count=*/1, &max_resources));
+  iree_hal_amdgpu_pending_op_t* op = NULL;
+  IREE_RETURN_IF_ERROR(iree_hal_amdgpu_pending_op_allocate(
+      queue, wait_semaphore_list, signal_semaphore_list,
+      IREE_HAL_AMDGPU_PENDING_OP_FILL, max_resources, &op));
+  iree_hal_amdgpu_pending_op_retain(op, (iree_hal_resource_t*)target_buffer);
+  op->fill.target_buffer = target_buffer;
+  op->fill.target_offset = target_offset;
+  op->fill.length = length;
+  op->fill.pattern_bits = pattern_bits;
+  op->fill.pattern_length = pattern_length;
+  op->fill.flags = flags;
+  *out_op = op;
+  return iree_ok_status();
+}
+
+iree_status_t iree_hal_amdgpu_host_queue_defer_copy(
+    iree_hal_amdgpu_host_queue_t* queue,
+    const iree_hal_semaphore_list_t* wait_semaphore_list,
+    const iree_hal_semaphore_list_t* signal_semaphore_list,
+    iree_hal_buffer_t* source_buffer, iree_device_size_t source_offset,
+    iree_hal_buffer_t* target_buffer, iree_device_size_t target_offset,
+    iree_device_size_t length, iree_hal_copy_flags_t flags,
+    iree_hal_profile_queue_event_type_t profile_event_type,
+    iree_hal_amdgpu_pending_op_t** out_op) {
+  uint16_t max_resources = 0;
+  IREE_RETURN_IF_ERROR(iree_hal_amdgpu_host_queue_count_reclaim_resources(
+      signal_semaphore_list->count,
+      /*operation_resource_count=*/2, &max_resources));
+  iree_hal_amdgpu_pending_op_t* op = NULL;
+  IREE_RETURN_IF_ERROR(iree_hal_amdgpu_pending_op_allocate(
+      queue, wait_semaphore_list, signal_semaphore_list,
+      IREE_HAL_AMDGPU_PENDING_OP_COPY, max_resources, &op));
+  iree_hal_amdgpu_pending_op_retain(op, (iree_hal_resource_t*)source_buffer);
+  iree_hal_amdgpu_pending_op_retain(op, (iree_hal_resource_t*)target_buffer);
+  op->copy.source_buffer = source_buffer;
+  op->copy.source_offset = source_offset;
+  op->copy.target_buffer = target_buffer;
+  op->copy.target_offset = target_offset;
+  op->copy.length = length;
+  op->copy.flags = flags;
+  op->copy.profile_event_type = profile_event_type;
+  *out_op = op;
+  return iree_ok_status();
+}
+
+iree_status_t iree_hal_amdgpu_host_queue_defer_update(
+    iree_hal_amdgpu_host_queue_t* queue,
+    const iree_hal_semaphore_list_t* wait_semaphore_list,
+    const iree_hal_semaphore_list_t* signal_semaphore_list,
+    const void* source_buffer, iree_host_size_t source_offset,
+    iree_hal_buffer_t* target_buffer, iree_device_size_t target_offset,
+    iree_device_size_t length, iree_hal_update_flags_t flags,
+    iree_hal_amdgpu_pending_op_t** out_op) {
+  const uint8_t* source_bytes = NULL;
+  iree_host_size_t source_length = 0;
+  uint8_t* target_device_ptr = NULL;
+  IREE_RETURN_IF_ERROR(iree_hal_amdgpu_host_queue_prepare_update_copy(
+      target_buffer, target_offset, source_buffer, source_offset, length, flags,
+      &source_bytes, &source_length, &target_device_ptr));
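+  // Only validation and the resolved source bytes matter at capture time;
+  // the deferred path re-resolves the target when the update issues.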
+  (void)target_device_ptr;
+
+  uint16_t max_resources = 0;
+  IREE_RETURN_IF_ERROR(iree_hal_amdgpu_host_queue_count_reclaim_resources(
+      signal_semaphore_list->count,
+      /*operation_resource_count=*/1, &max_resources));
+  iree_hal_amdgpu_pending_op_t* op = NULL;
+  IREE_RETURN_IF_ERROR(iree_hal_amdgpu_pending_op_allocate(
+      queue, wait_semaphore_list, signal_semaphore_list,
+      IREE_HAL_AMDGPU_PENDING_OP_UPDATE, max_resources, &op));
+  iree_hal_amdgpu_pending_op_retain(op, (iree_hal_resource_t*)target_buffer);
+
+  void* source_copy = NULL;
+  iree_status_t status =
+      iree_arena_allocate(&op->arena, source_length, &source_copy);
+  if (!iree_status_is_ok(status)) {
+    iree_hal_amdgpu_pending_op_destroy_under_lock(op,
+                                                  iree_status_clone(status));
+    return status;
+  }
+  memcpy(source_copy, source_bytes, source_length);
+  op->update.source_data = source_copy;
+  op->update.target_buffer = target_buffer;
+  op->update.target_offset = target_offset;
+  op->update.length = length;
+  op->update.flags = flags;
+  *out_op = op;
+  return iree_ok_status();
+}
+
+iree_status_t iree_hal_amdgpu_host_queue_defer_execute(
+    iree_hal_amdgpu_host_queue_t* queue,
+    const iree_hal_semaphore_list_t* wait_semaphore_list,
+    const iree_hal_semaphore_list_t* signal_semaphore_list,
+    iree_hal_command_buffer_t* command_buffer,
+    iree_hal_buffer_binding_table_t binding_table,
+    iree_hal_execute_flags_t flags, iree_hal_amdgpu_pending_op_t** out_op) {
+  IREE_RETURN_IF_ERROR(
+      iree_hal_amdgpu_host_queue_validate_execute_flags(flags));
+  if (IREE_UNLIKELY(!command_buffer && binding_table.count != 0)) {
+    return iree_make_status(
+        IREE_STATUS_INVALID_ARGUMENT,
+        "barrier-only queue_execute must not provide a binding table "
+        "(count=%" PRIhsz ")",
+        binding_table.count);
+  }
+  const iree_host_size_t binding_count =
+      command_buffer ? command_buffer->binding_count : 0;
+  if (command_buffer && command_buffer->binding_count == 0) {
+    binding_table = iree_hal_buffer_binding_table_empty();
+  }
+
+  iree_hal_resource_set_t* binding_resource_set = NULL;
+  IREE_RETURN_IF_ERROR(
+      iree_hal_amdgpu_host_queue_create_binding_table_resource_set(
+          queue, command_buffer, binding_table, flags, &binding_resource_set));
+
+  const iree_host_size_t operation_resource_count = command_buffer ? 1 : 0;
+  uint16_t max_resources = 0;
+  iree_hal_amdgpu_pending_op_t* op = NULL;
+  iree_status_t status = iree_hal_amdgpu_host_queue_count_reclaim_resources(
+      signal_semaphore_list->count, operation_resource_count, &max_resources);
+  if (iree_status_is_ok(status)) {
+    status = iree_hal_amdgpu_pending_op_allocate(
+        queue, wait_semaphore_list, signal_semaphore_list,
+        IREE_HAL_AMDGPU_PENDING_OP_EXECUTE, max_resources, &op);
+  }
+  if (iree_status_is_ok(status)) {
+    iree_hal_amdgpu_pending_op_retain(op, (iree_hal_resource_t*)command_buffer);
+    op->execute.command_buffer = command_buffer;
+    op->execute.binding_resource_set = binding_resource_set;
+    binding_resource_set = NULL;
+    op->execute.flags = flags;
+  }
+
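+  // Snapshot the binding table prefix into the op arena; the resource set
+  // created above keeps the referenced buffers live until the execute issues.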
+  if (iree_status_is_ok(status) && binding_count > 0) {
+    iree_hal_buffer_binding_t* bindings_copy = NULL;
+    iree_host_size_t binding_table_size = 0;
+    status = IREE_STRUCT_LAYOUT(
+        0, &binding_table_size,
+        IREE_STRUCT_FIELD(binding_count, iree_hal_buffer_binding_t, NULL));
+    if (iree_status_is_ok(status)) {
+      status = iree_arena_allocate(&op->arena, binding_table_size,
+                                   (void**)&bindings_copy);
+    }
+    if (iree_status_is_ok(status)) {
+      memcpy(bindings_copy, binding_table.bindings, binding_table_size);
+      op->execute.binding_table.count = binding_count;
+      op->execute.binding_table.bindings = bindings_copy;
+    }
+  }
+
+  if (iree_status_is_ok(status)) {
+    *out_op = op;
+  } else {
+    // NULL once ownership has transferred to the pending op above.
+    if (binding_resource_set) {
+      iree_hal_resource_set_free(binding_resource_set);
+    }
+    if (op) {
+      iree_hal_amdgpu_pending_op_destroy_under_lock(op,
+                                                    iree_status_clone(status));
+    }
+  }
+  return status;
+}
+
+iree_status_t iree_hal_amdgpu_host_queue_defer_dispatch(
+    iree_hal_amdgpu_host_queue_t* queue,
+    const iree_hal_semaphore_list_t* wait_semaphore_list,
+    const iree_hal_semaphore_list_t* signal_semaphore_list,
+    iree_hal_executable_t* executable,
+    iree_hal_executable_export_ordinal_t export_ordinal,
+    const iree_hal_dispatch_config_t config, iree_const_byte_span_t constants,
+    const iree_hal_buffer_ref_list_t bindings, iree_hal_dispatch_flags_t flags,
+    iree_hal_amdgpu_pending_op_t** out_op) {
+  iree_host_size_t operation_resource_count = 0;
+  IREE_RETURN_IF_ERROR(iree_hal_amdgpu_host_queue_validate_dispatch(
+      queue, executable, export_ordinal, config, constants, bindings, flags,
+      &operation_resource_count));
+  uint16_t max_resources = 0;
+  IREE_RETURN_IF_ERROR(iree_hal_amdgpu_host_queue_count_reclaim_resources(
+      signal_semaphore_list->count, operation_resource_count, &max_resources));
+  iree_hal_amdgpu_pending_op_t* op = NULL;
+  IREE_RETURN_IF_ERROR(iree_hal_amdgpu_pending_op_allocate(
+      queue, wait_semaphore_list, signal_semaphore_list,
+      IREE_HAL_AMDGPU_PENDING_OP_DISPATCH, max_resources, &op));
+  iree_hal_amdgpu_pending_op_retain(op, (iree_hal_resource_t*)executable);
+  op->dispatch.executable = executable;
+  op->dispatch.export_ordinal = export_ordinal;
+  op->dispatch.config = config;
+  op->dispatch.flags = flags;
+
+  iree_status_t status = iree_ok_status();
+  if (constants.data_length > 0) {
+    void* constants_copy = NULL;
+    status =
+        iree_arena_allocate(&op->arena, constants.data_length, &constants_copy);
+    if (iree_status_is_ok(status)) {
+      memcpy(constants_copy, constants.data, constants.data_length);
+      op->dispatch.constants.data = (const uint8_t*)constants_copy;
+      op->dispatch.constants.data_length = constants.data_length;
+    }
+  }
+
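+  // Custom direct arguments carry no HAL buffer refs; otherwise copy the
+  // refs into the arena and retain each referenced buffer until issue time.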
+  if (iree_status_is_ok(status) && bindings.count > 0 &&
+      !iree_any_bit_set(flags,
+                        IREE_HAL_DISPATCH_FLAG_CUSTOM_DIRECT_ARGUMENTS)) {
+    iree_hal_buffer_ref_t* bindings_copy = NULL;
+    iree_host_size_t binding_ref_size = 0;
+    status = IREE_STRUCT_LAYOUT(
+        0, &binding_ref_size,
+        IREE_STRUCT_FIELD(bindings.count, iree_hal_buffer_ref_t, NULL));
+    if (iree_status_is_ok(status)) {
+      status = iree_arena_allocate(&op->arena, binding_ref_size,
+                                   (void**)&bindings_copy);
+    }
+    if (iree_status_is_ok(status)) {
+      memcpy(bindings_copy, bindings.values, binding_ref_size);
+      for (iree_host_size_t i = 0; i < bindings.count; ++i) {
+        iree_hal_amdgpu_pending_op_retain(
+            op, (iree_hal_resource_t*)bindings_copy[i].buffer);
+      }
+      op->dispatch.bindings.count = bindings.count;
+      op->dispatch.bindings.values = bindings_copy;
+    }
+  }
+
+  if (iree_status_is_ok(status)) {
+    *out_op = op;
+  } else {
+    iree_hal_amdgpu_pending_op_destroy_under_lock(op,
+                                                  iree_status_clone(status));
+  }
+  return status;
+}
+
+iree_status_t iree_hal_amdgpu_host_queue_defer_host_action(
+    iree_hal_amdgpu_host_queue_t* queue,
+    const iree_hal_semaphore_list_t* wait_semaphore_list,
+    iree_hal_amdgpu_reclaim_action_t action,
+    iree_hal_resource_t* const* operation_resources,
+    iree_host_size_t operation_resource_count,
+    iree_hal_amdgpu_pending_op_t** out_op) {
+  uint16_t max_resources = 0;
+  IREE_RETURN_IF_ERROR(iree_hal_amdgpu_host_queue_count_reclaim_resources(
+      /*signal_semaphore_count=*/0,
+      /*operation_resource_count=*/operation_resource_count, &max_resources));
+  iree_hal_amdgpu_pending_op_t* op = NULL;
+  const iree_hal_semaphore_list_t empty_signal_list =
+      iree_hal_semaphore_list_empty();
+  IREE_RETURN_IF_ERROR(iree_hal_amdgpu_pending_op_allocate(
+      queue, wait_semaphore_list, &empty_signal_list,
+      IREE_HAL_AMDGPU_PENDING_OP_HOST_ACTION, max_resources, &op));
+  for (iree_host_size_t i = 0; i < operation_resource_count; ++i) {
+    iree_hal_amdgpu_pending_op_retain(op, operation_resources[i]);
+  }
+  op->host_action.action = action;
+  *out_op = op;
+  return iree_ok_status();
+}
+
+iree_status_t iree_hal_amdgpu_host_queue_defer_host_call(
+    iree_hal_amdgpu_host_queue_t* queue,
+    const iree_hal_semaphore_list_t* wait_semaphore_list,
+    const iree_hal_semaphore_list_t* signal_semaphore_list,
+    iree_hal_host_call_t call, const uint64_t args[4],
+    iree_hal_host_call_flags_t flags, iree_hal_amdgpu_pending_op_t** out_op) {
+  uint16_t max_resources = 0;
+  IREE_RETURN_IF_ERROR(iree_hal_amdgpu_host_queue_count_reclaim_resources(
+      signal_semaphore_list->count,
+      /*operation_resource_count=*/0, &max_resources));
+  iree_hal_amdgpu_pending_op_t* op = NULL;
+  IREE_RETURN_IF_ERROR(iree_hal_amdgpu_pending_op_allocate(
+      queue, wait_semaphore_list, signal_semaphore_list,
+      IREE_HAL_AMDGPU_PENDING_OP_HOST_CALL, max_resources, &op));
+  op->host_call.call = call;
+  memcpy(op->host_call.args, args, sizeof(op->host_call.args));
+  op->host_call.flags = flags;
+  *out_op = op;
+  return iree_ok_status();
+}
diff --git a/runtime/src/iree/hal/drivers/amdgpu/host_queue_pending_test.cc b/runtime/src/iree/hal/drivers/amdgpu/host_queue_pending_test.cc
new file mode 100644
index 0000000..e037e5f
--- /dev/null
+++ b/runtime/src/iree/hal/drivers/amdgpu/host_queue_pending_test.cc
@@ -0,0 +1,788 @@
+// Copyright 2026 The IREE Authors
+//
+// Licensed under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+
+#include "iree/hal/drivers/amdgpu/host_queue_pending.h"
+
+#include <cstdint>
+#include <cstring>
+
+#include "iree/async/frontier.h"
+#include "iree/base/internal/atomics.h"
+#include "iree/base/threading/notification.h"
+#include "iree/hal/api.h"
+#include "iree/hal/cts/util/test_base.h"
+#include "iree/hal/drivers/amdgpu/host_queue.h"
+#include "iree/hal/drivers/amdgpu/logical_device.h"
+#include "iree/hal/drivers/amdgpu/physical_device.h"
+#include "iree/hal/drivers/amdgpu/util/aql_emitter.h"
+#include "iree/hal/memory/fixed_block_pool.h"
+#include "iree/hal/memory/tlsf_pool.h"
+#include "iree/testing/gtest.h"
+#include "iree/testing/status_matchers.h"
+
+namespace iree::hal::amdgpu {
+namespace {
+
+using iree::hal::cts::Ref;
+
+constexpr iree_hal_queue_affinity_t kQueueAffinity0 =
+    ((iree_hal_queue_affinity_t)1ull) << 0;
+
+class HostQueuePendingTest : public ::testing::Test {
+ protected:
+  static void SetUpTestSuite() {
+    host_allocator_ = iree_allocator_system();
+    iree_status_t status = iree_hal_amdgpu_libhsa_initialize(
+        IREE_HAL_AMDGPU_LIBHSA_FLAG_NONE, iree_string_view_list_empty(),
+        host_allocator_, &libhsa_);
+    if (!iree_status_is_ok(status)) {
+      iree_status_fprint(stderr, status);
+      iree_status_free(status);
+      GTEST_SKIP() << "HSA not available, skipping tests";
+    }
+    IREE_ASSERT_OK(iree_hal_amdgpu_topology_initialize_with_defaults(
+        &libhsa_, &topology_));
+    if (topology_.gpu_agent_count == 0) {
+      GTEST_SKIP() << "no GPU devices available, skipping tests";
+    }
+  }
+
+  static void TearDownTestSuite() {
+    iree_hal_amdgpu_topology_deinitialize(&topology_);
+    iree_hal_amdgpu_libhsa_deinitialize(&libhsa_);
+  }
+
+  static iree_allocator_t host_allocator_;
+  static iree_hal_amdgpu_libhsa_t libhsa_;
+  static iree_hal_amdgpu_topology_t topology_;
+};
+
+iree_allocator_t HostQueuePendingTest::host_allocator_;
+iree_hal_amdgpu_libhsa_t HostQueuePendingTest::libhsa_;
+iree_hal_amdgpu_topology_t HostQueuePendingTest::topology_;
+
+class TestLogicalDevice {
+ public:
+  ~TestLogicalDevice() {
+    iree_hal_device_release(base_device_);
+    iree_hal_device_group_release(device_group_);
+  }
+
+  iree_status_t Initialize(
+      const iree_hal_amdgpu_logical_device_options_t* options,
+      const iree_hal_amdgpu_libhsa_t* libhsa,
+      const iree_hal_amdgpu_topology_t* topology,
+      iree_allocator_t host_allocator) {
+    IREE_RETURN_IF_ERROR(create_context_.Initialize(host_allocator));
+    IREE_RETURN_IF_ERROR(iree_hal_amdgpu_logical_device_create(
+        IREE_SV("amdgpu"), options, libhsa, topology, create_context_.params(),
+        host_allocator, &base_device_));
+    return iree_hal_device_group_create_from_device(
+        base_device_, create_context_.frontier_tracker(), host_allocator,
+        &device_group_);
+  }
+
+  iree_hal_device_t* base_device() const { return base_device_; }
+
+  iree_hal_allocator_t* allocator() const {
+    return iree_hal_device_allocator(base_device_);
+  }
+
+  iree_hal_amdgpu_logical_device_t* logical_device() const {
+    return (iree_hal_amdgpu_logical_device_t*)base_device_;
+  }
+
+  iree_hal_amdgpu_host_queue_t* first_host_queue() const {
+    iree_hal_amdgpu_logical_device_t* logical_device = this->logical_device();
+    if (logical_device->physical_device_count == 0) return NULL;
+    iree_hal_amdgpu_physical_device_t* physical_device =
+        logical_device->physical_devices[0];
+    if (physical_device->host_queue_count == 0) return NULL;
+    return &physical_device->host_queues[0];
+  }
+
+ private:
+  // Creation context supplying the proactor pool and frontier tracker.
+  iree::hal::cts::DeviceCreateContext create_context_;
+
+  // Test-owned device reference released before the topology-owning group.
+  iree_hal_device_t* base_device_ = NULL;
+
+  // Device group that owns the topology assigned to |base_device_|.
+  iree_hal_device_group_t* device_group_ = NULL;
+};
+
+static iree_hal_buffer_params_t MakeTransientBufferParams() {
+  iree_hal_buffer_params_t params = {0};
+  params.type = IREE_HAL_MEMORY_TYPE_OPTIMAL_FOR_DEVICE;
+  params.access = IREE_HAL_MEMORY_ACCESS_ALL;
+  params.usage =
+      IREE_HAL_BUFFER_USAGE_TRANSFER | IREE_HAL_BUFFER_USAGE_DISPATCH_STORAGE;
+  return params;
+}
+
+static iree_hal_buffer_params_t MakeHostLocalMappedTransientBufferParams(
+    iree_hal_memory_type_t extra_memory_type) {
+  iree_hal_buffer_params_t params = {0};
+  params.type = IREE_HAL_MEMORY_TYPE_HOST_LOCAL |
+                IREE_HAL_MEMORY_TYPE_DEVICE_VISIBLE | extra_memory_type;
+  params.access = IREE_HAL_MEMORY_ACCESS_ALL;
+  params.usage = IREE_HAL_BUFFER_USAGE_TRANSFER |
+                 IREE_HAL_BUFFER_USAGE_DISPATCH_STORAGE |
+                 IREE_HAL_BUFFER_USAGE_MAPPING;
+  return params;
+}
+
+static iree_status_t CreateHostVisibleTransferBuffer(
+    iree_hal_allocator_t* allocator, iree_device_size_t buffer_size,
+    iree_hal_buffer_t** out_buffer) {
+  iree_hal_buffer_params_t params = {0};
+  params.type =
+      IREE_HAL_MEMORY_TYPE_DEVICE_LOCAL | IREE_HAL_MEMORY_TYPE_HOST_VISIBLE;
+  params.usage = IREE_HAL_BUFFER_USAGE_TRANSFER | IREE_HAL_BUFFER_USAGE_MAPPING;
+  return iree_hal_allocator_allocate_buffer(allocator, params, buffer_size,
+                                            out_buffer);
+}
+
+static iree_status_t CreateSemaphore(iree_hal_device_t* device,
+                                     iree_hal_semaphore_t** out_semaphore) {
+  return iree_hal_semaphore_create(
+      device, IREE_HAL_QUEUE_AFFINITY_ANY,
+      /*initial_value=*/0, IREE_HAL_SEMAPHORE_FLAG_DEFAULT, out_semaphore);
+}
+
+static iree_hal_semaphore_list_t MakeSemaphoreList(
+    iree_hal_semaphore_t** semaphore, uint64_t* payload_value) {
+  return iree_hal_semaphore_list_t{
+      /*count=*/1,
+      /*semaphores=*/semaphore,
+      /*payload_values=*/payload_value,
+  };
+}
+
+static void RunDefaultPoolServesHostLocalMappedAlloca(
+    const iree_hal_amdgpu_libhsa_t* libhsa,
+    const iree_hal_amdgpu_topology_t* topology, iree_allocator_t host_allocator,
+    iree_hal_memory_type_t extra_memory_type) {
+  iree_hal_amdgpu_logical_device_options_t options;
+  iree_hal_amdgpu_logical_device_options_initialize(&options);
+  options.preallocate_pools = 0;
+
+  TestLogicalDevice test_device;
+  IREE_ASSERT_OK(
+      test_device.Initialize(&options, libhsa, topology, host_allocator));
+
+  Ref<iree_hal_semaphore_t> alloca_signal;
+  IREE_ASSERT_OK(
+      CreateSemaphore(test_device.base_device(), alloca_signal.out()));
+  uint64_t alloca_signal_value = 1;
+  iree_hal_semaphore_t* alloca_signal_ptr = alloca_signal.get();
+  const iree_hal_semaphore_list_t alloca_signal_list =
+      MakeSemaphoreList(&alloca_signal_ptr, &alloca_signal_value);
+
+  iree_hal_buffer_t* buffer = NULL;
+  IREE_ASSERT_OK(iree_hal_device_queue_alloca(
+      test_device.base_device(), kQueueAffinity0,
+      iree_hal_semaphore_list_empty(), alloca_signal_list, /*pool=*/NULL,
+      MakeHostLocalMappedTransientBufferParams(extra_memory_type),
+      /*allocation_size=*/8, IREE_HAL_ALLOCA_FLAG_NONE, &buffer));
+  ASSERT_NE(buffer, nullptr);
+  IREE_ASSERT_OK(iree_hal_semaphore_wait(alloca_signal, alloca_signal_value,
+                                         iree_infinite_timeout(),
+                                         IREE_ASYNC_WAIT_FLAG_NONE));
+
+  EXPECT_TRUE(iree_all_bits_set(
+      iree_hal_buffer_memory_type(buffer),
+      IREE_HAL_MEMORY_TYPE_HOST_LOCAL | IREE_HAL_MEMORY_TYPE_DEVICE_VISIBLE));
+  EXPECT_TRUE(iree_all_bits_set(iree_hal_buffer_allowed_usage(buffer),
+                                IREE_HAL_BUFFER_USAGE_MAPPING_SCOPED));
+
+  iree_hal_buffer_mapping_t mapping;
+  IREE_ASSERT_OK(iree_hal_buffer_map_range(
+      buffer, IREE_HAL_MAPPING_MODE_SCOPED, IREE_HAL_MEMORY_ACCESS_WRITE,
+      /*byte_offset=*/0, /*byte_length=*/8, &mapping));
+  memset(mapping.contents.data, 0, 8);
+  iree_hal_buffer_unmap_range(&mapping);
+
+  Ref<iree_hal_semaphore_t> dealloca_signal;
+  IREE_ASSERT_OK(
+      CreateSemaphore(test_device.base_device(), dealloca_signal.out()));
+  uint64_t dealloca_signal_value = 1;
+  iree_hal_semaphore_t* dealloca_signal_ptr = dealloca_signal.get();
+  const iree_hal_semaphore_list_t dealloca_signal_list =
+      MakeSemaphoreList(&dealloca_signal_ptr, &dealloca_signal_value);
+  IREE_ASSERT_OK(iree_hal_device_queue_dealloca(
+      test_device.base_device(), kQueueAffinity0,
+      iree_hal_semaphore_list_empty(), dealloca_signal_list, buffer,
+      IREE_HAL_DEALLOCA_FLAG_NONE));
+  IREE_ASSERT_OK(iree_hal_semaphore_wait(dealloca_signal, dealloca_signal_value,
+                                         iree_infinite_timeout(),
+                                         IREE_ASYNC_WAIT_FLAG_NONE));
+  iree_hal_buffer_release(buffer);
+}
+
+TEST_F(HostQueuePendingTest, DefaultPoolServesHostLocalMappedAlloca) {
+  RunDefaultPoolServesHostLocalMappedAlloca(
+      &libhsa_, &topology_, host_allocator_, IREE_HAL_MEMORY_TYPE_NONE);
+}
+
+TEST_F(HostQueuePendingTest, DefaultPoolServesOptimalHostLocalMappedAlloca) {
+  RunDefaultPoolServesHostLocalMappedAlloca(
+      &libhsa_, &topology_, host_allocator_, IREE_HAL_MEMORY_TYPE_OPTIMAL);
+}
+
+static bool HostQueueHasPendingOps(iree_hal_amdgpu_host_queue_t* queue) {
+  iree_slim_mutex_lock(&queue->locks.submission_mutex);
+  const bool has_pending_ops = queue->pending_head != NULL;
+  iree_slim_mutex_unlock(&queue->locks.submission_mutex);
+  return has_pending_ops;
+}
+
+static bool HostQueueHasPostDrainAction(iree_hal_amdgpu_host_queue_t* queue) {
+  iree_slim_mutex_lock(&queue->locks.post_drain_mutex);
+  const bool has_action = queue->post_drain.head != NULL;
+  iree_slim_mutex_unlock(&queue->locks.post_drain_mutex);
+  return has_action;
+}
+
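+// Reserves and commits a raw BARRIER_AND packet gated on |blocker_signal| so
+// that everything enqueued afterwards stalls until the signal is released.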
+static iree_status_t EnqueueRawBlockingBarrier(
+    iree_hal_amdgpu_host_queue_t* queue, hsa_signal_t blocker_signal) {
+  const uint64_t packet_id =
+      iree_hal_amdgpu_aql_ring_reserve(&queue->aql_ring, /*count=*/1);
+  iree_hal_amdgpu_aql_packet_t* packet =
+      iree_hal_amdgpu_aql_ring_packet(&queue->aql_ring, packet_id);
+  const hsa_signal_t dep_signals[1] = {blocker_signal};
+  const uint16_t header = iree_hal_amdgpu_aql_emit_barrier_and(
+      &packet->barrier_and, dep_signals, IREE_ARRAYSIZE(dep_signals),
+      iree_hal_amdgpu_aql_packet_control_barrier_system(),
+      iree_hsa_signal_null());
+  iree_hal_amdgpu_aql_ring_commit(packet, header, /*setup=*/0);
+  iree_hal_amdgpu_aql_ring_doorbell(&queue->aql_ring, packet_id);
+  return iree_ok_status();
+}
+
+static iree_status_t CreateExplicitFixedBlockPool(iree_hal_device_t* device,
+                                                  iree_device_size_t block_size,
+                                                  iree_hal_pool_t** out_pool) {
+  iree_hal_queue_pool_backend_t backend = {0};
+  IREE_RETURN_IF_ERROR(iree_hal_device_query_queue_pool_backend(
+      device, kQueueAffinity0, &backend));
+  if (!backend.slab_provider || !backend.notification) {
+    return iree_make_status(
+        IREE_STATUS_FAILED_PRECONDITION,
+        "queue pool backend query returned an incomplete backend bundle");
+  }
+  iree_hal_fixed_block_pool_options_t options = {};
+  options.block_allocator_options.block_size = block_size;
+  options.block_allocator_options.block_count = 1;
+  options.block_allocator_options.frontier_capacity = 2;
+  return iree_hal_fixed_block_pool_create(
+      options, backend.slab_provider, backend.notification,
+      iree_hal_pool_epoch_query_null(), iree_allocator_system(), out_pool);
+}
+
+static iree_status_t CreateExplicitTlsfPool(iree_hal_device_t* device,
+                                            iree_device_size_t slab_size,
+                                            iree_hal_pool_t** out_pool) {
+  iree_hal_queue_pool_backend_t backend = {0};
+  IREE_RETURN_IF_ERROR(iree_hal_device_query_queue_pool_backend(
+      device, kQueueAffinity0, &backend));
+  if (!backend.slab_provider || !backend.notification) {
+    return iree_make_status(
+        IREE_STATUS_FAILED_PRECONDITION,
+        "queue pool backend query returned an incomplete backend bundle");
+  }
+  iree_hal_tlsf_pool_options_t options = {};
+  options.tlsf_options.range_length = slab_size;
+  options.tlsf_options.alignment = 16;
+  options.tlsf_options.initial_block_capacity = 16;
+  options.tlsf_options.frontier_capacity = 2;
+  return iree_hal_tlsf_pool_create(
+      options, backend.slab_provider, backend.notification,
+      iree_hal_pool_epoch_query_null(), iree_allocator_system(), out_pool);
+}
+
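+// Acquires a fresh fixed-block reservation and releases it pinned to a death
+// frontier on |death_axis| so subsequent acquires must wait on that frontier.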
+static iree_status_t SeedWaitableFixedBlockReservation(
+    iree_hal_pool_t* pool, iree_device_size_t allocation_size,
+    iree_async_axis_t death_axis) {
+  iree_hal_pool_reservation_t reservation;
+  iree_hal_pool_acquire_info_t acquire_info;
+  iree_hal_pool_acquire_result_t acquire_result;
+  IREE_RETURN_IF_ERROR(iree_hal_pool_acquire_reservation(
+      pool, allocation_size, /*alignment=*/1, /*requester_frontier=*/NULL,
+      IREE_HAL_POOL_RESERVE_FLAG_NONE, &reservation, &acquire_info,
+      &acquire_result));
+  if (acquire_result != IREE_HAL_POOL_ACQUIRE_OK &&
+      acquire_result != IREE_HAL_POOL_ACQUIRE_OK_FRESH) {
+    return iree_make_status(IREE_STATUS_FAILED_PRECONDITION,
+                            "expected fresh fixed-block reservation");
+  }
+
+  iree_async_single_frontier_t death_frontier;
+  iree_async_single_frontier_initialize(&death_frontier, death_axis, 1);
+  iree_hal_pool_release_reservation(
+      pool, &reservation,
+      iree_async_single_frontier_as_const_frontier(&death_frontier));
+  return iree_ok_status();
+}
+
+typedef struct HostActionState {
+  // Notification posted after the action callback records its result.
+  iree_notification_t notification;
+
+  // Number of times the action callback has run.
+  iree_atomic_int32_t call_count;
+
+  // Last callback status code.
+  iree_atomic_int32_t status_code;
+
+  // Whether the callback received a reclaim entry.
+  iree_atomic_int32_t had_entry;
+} HostActionState;
+
+static void HostActionStateInitialize(HostActionState* state) {
+  iree_notification_initialize(&state->notification);
+  iree_atomic_store(&state->call_count, 0, iree_memory_order_relaxed);
+  iree_atomic_store(&state->status_code, IREE_STATUS_UNKNOWN,
+                    iree_memory_order_relaxed);
+  iree_atomic_store(&state->had_entry, 0, iree_memory_order_relaxed);
+}
+
+static void HostActionStateDeinitialize(HostActionState* state) {
+  iree_notification_deinitialize(&state->notification);
+}
+
+static bool HostActionStateWasCalled(void* user_data) {
+  HostActionState* state = (HostActionState*)user_data;
+  return iree_atomic_load(&state->call_count, iree_memory_order_acquire) != 0;
+}
+
+static void RecordHostAction(iree_hal_amdgpu_reclaim_entry_t* entry,
+                             void* user_data, iree_status_t status) {
+  HostActionState* state = (HostActionState*)user_data;
+  iree_atomic_store(&state->had_entry, entry ? 1 : 0,
+                    iree_memory_order_release);
+  iree_atomic_store(&state->status_code, iree_status_code(status),
+                    iree_memory_order_release);
+  iree_atomic_fetch_add(&state->call_count, 1, iree_memory_order_acq_rel);
+  iree_notification_post(&state->notification, IREE_ALL_WAITERS);
+}
+
+static int32_t HostActionCallCount(HostActionState* state) {
+  return iree_atomic_load(&state->call_count, iree_memory_order_acquire);
+}
+
+static iree_status_code_t HostActionStatusCode(HostActionState* state) {
+  return (iree_status_code_t)iree_atomic_load(&state->status_code,
+                                              iree_memory_order_acquire);
+}
+
+static bool HostActionHadEntry(HostActionState* state) {
+  return iree_atomic_load(&state->had_entry, iree_memory_order_acquire) != 0;
+}
+
+TEST_F(HostQueuePendingTest,
+       DeferredHostActionFailureRunsSynchronousCallbackOnce) {
+  iree_hal_amdgpu_logical_device_options_t options;
+  iree_hal_amdgpu_logical_device_options_initialize(&options);
+  options.preallocate_pools = 0;
+
+  TestLogicalDevice test_device;
+  IREE_ASSERT_OK(
+      test_device.Initialize(&options, &libhsa_, &topology_, host_allocator_));
+  iree_hal_amdgpu_host_queue_t* queue = test_device.first_host_queue();
+  ASSERT_NE(queue, nullptr);
+
+  Ref<iree_hal_semaphore_t> wait_semaphore;
+  IREE_ASSERT_OK(
+      CreateSemaphore(test_device.base_device(), wait_semaphore.out()));
+  iree_hal_semaphore_fail(
+      wait_semaphore,
+      iree_make_status(IREE_STATUS_CANCELLED, "test wait failed"));
+  uint64_t wait_value = 1;
+  iree_hal_semaphore_t* wait_semaphore_ptr = wait_semaphore.get();
+  const iree_hal_semaphore_list_t wait_list =
+      MakeSemaphoreList(&wait_semaphore_ptr, &wait_value);
+
+  HostActionState action_state;
+  HostActionStateInitialize(&action_state);
+  IREE_ASSERT_OK(iree_hal_amdgpu_host_queue_enqueue_host_action(
+      queue, wait_list,
+      iree_hal_amdgpu_reclaim_action_t{
+          .fn = RecordHostAction,
+          .user_data = &action_state,
+      },
+      /*operation_resources=*/NULL, /*operation_resource_count=*/0));
+
+  EXPECT_EQ(HostActionCallCount(&action_state), 1);
+  EXPECT_EQ(HostActionStatusCode(&action_state), IREE_STATUS_CANCELLED);
+  EXPECT_FALSE(HostActionHadEntry(&action_state));
+  EXPECT_FALSE(HostQueueHasPendingOps(queue));
+
+  HostActionStateDeinitialize(&action_state);
+}
+
+TEST_F(HostQueuePendingTest, CancelPendingFillFailsSignalSemaphore) {
+  iree_hal_amdgpu_logical_device_options_t options;
+  iree_hal_amdgpu_logical_device_options_initialize(&options);
+  options.preallocate_pools = 0;
+
+  TestLogicalDevice test_device;
+  IREE_ASSERT_OK(
+      test_device.Initialize(&options, &libhsa_, &topology_, host_allocator_));
+  iree_hal_amdgpu_host_queue_t* queue = test_device.first_host_queue();
+  ASSERT_NE(queue, nullptr);
+
+  Ref<iree_hal_buffer_t> target_buffer;
+  IREE_ASSERT_OK(CreateHostVisibleTransferBuffer(
+      test_device.allocator(), sizeof(uint32_t), target_buffer.out()));
+
+  Ref<iree_hal_semaphore_t> wait_semaphore;
+  IREE_ASSERT_OK(
+      CreateSemaphore(test_device.base_device(), wait_semaphore.out()));
+  uint64_t wait_value = 1;
+  iree_hal_semaphore_t* wait_semaphore_ptr = wait_semaphore.get();
+  const iree_hal_semaphore_list_t wait_list =
+      MakeSemaphoreList(&wait_semaphore_ptr, &wait_value);
+
+  Ref<iree_hal_semaphore_t> signal_semaphore;
+  IREE_ASSERT_OK(
+      CreateSemaphore(test_device.base_device(), signal_semaphore.out()));
+  uint64_t signal_value = 1;
+  iree_hal_semaphore_t* signal_semaphore_ptr = signal_semaphore.get();
+  const iree_hal_semaphore_list_t signal_list =
+      MakeSemaphoreList(&signal_semaphore_ptr, &signal_value);
+
+  const uint32_t pattern = 0xCACE1100u;
+  IREE_ASSERT_OK(iree_hal_device_queue_fill(
+      test_device.base_device(), IREE_HAL_QUEUE_AFFINITY_ANY, wait_list,
+      signal_list, target_buffer, /*target_offset=*/0, sizeof(pattern),
+      &pattern, sizeof(pattern), IREE_HAL_FILL_FLAG_NONE));
+  ASSERT_TRUE(HostQueueHasPendingOps(queue));
+
+  iree_hal_amdgpu_host_queue_cancel_pending(queue, IREE_STATUS_CANCELLED,
+                                            "test cancellation");
+  EXPECT_FALSE(HostQueueHasPendingOps(queue));
+  IREE_EXPECT_STATUS_IS(IREE_STATUS_CANCELLED,
+                        iree_hal_semaphore_wait(signal_semaphore, signal_value,
+                                                iree_infinite_timeout(),
+                                                IREE_ASYNC_WAIT_FLAG_NONE));
+  IREE_EXPECT_OK(
+      iree_hal_semaphore_signal(wait_semaphore, wait_value, /*frontier=*/NULL));
+}
+
+TEST_F(HostQueuePendingTest, CapacityParkedHostActionRetriesAfterPostDrain) {
+  static constexpr uint32_t kAqlCapacity = 64;
+  static constexpr uint32_t kNotificationCapacity = 1;
+  static constexpr uint32_t kKernargCapacity = 2 * kAqlCapacity;
+
+  iree_hal_amdgpu_logical_device_options_t options;
+  iree_hal_amdgpu_logical_device_options_initialize(&options);
+  options.host_queues.aql_capacity = kAqlCapacity;
+  options.host_queues.notification_capacity = kNotificationCapacity;
+  options.host_queues.kernarg_capacity = kKernargCapacity;
+  options.preallocate_pools = 0;
+
+  TestLogicalDevice test_device;
+  IREE_ASSERT_OK(
+      test_device.Initialize(&options, &libhsa_, &topology_, host_allocator_));
+  iree_hal_amdgpu_host_queue_t* queue = test_device.first_host_queue();
+  ASSERT_NE(queue, nullptr);
+
+  Ref<iree_hal_buffer_t> pressure_buffer;
+  IREE_ASSERT_OK(CreateHostVisibleTransferBuffer(
+      test_device.allocator(), sizeof(uint32_t), pressure_buffer.out()));
+
+  hsa_signal_t blocker_signal = iree_hsa_signal_null();
+  IREE_ASSERT_OK(iree_hsa_amd_signal_create(
+      IREE_LIBHSA(&libhsa_), /*initial_value=*/1, /*num_consumers=*/0,
+      /*consumers=*/NULL, /*attributes=*/0, &blocker_signal));
+  IREE_ASSERT_OK(EnqueueRawBlockingBarrier(queue, blocker_signal));
+
+  Ref<iree_hal_semaphore_t> pressure_signal;
+  IREE_ASSERT_OK(
+      CreateSemaphore(test_device.base_device(), pressure_signal.out()));
+  uint64_t pressure_signal_value = 1;
+  iree_hal_semaphore_t* pressure_signal_ptr = pressure_signal.get();
+  const iree_hal_semaphore_list_t pressure_signal_list =
+      MakeSemaphoreList(&pressure_signal_ptr, &pressure_signal_value);
+  const uint32_t pressure_pattern = 0xABCD1234u;
+  iree_status_t status = iree_hal_device_queue_fill(
+      test_device.base_device(), IREE_HAL_QUEUE_AFFINITY_ANY,
+      iree_hal_semaphore_list_empty(), pressure_signal_list, pressure_buffer,
+      /*target_offset=*/0, sizeof(pressure_pattern), &pressure_pattern,
+      sizeof(pressure_pattern), IREE_HAL_FILL_FLAG_NONE);
+
+  HostActionState action_state;
+  HostActionStateInitialize(&action_state);
+  if (iree_status_is_ok(status)) {
+    status = iree_hal_amdgpu_host_queue_enqueue_host_action(
+        queue, iree_hal_semaphore_list_empty(),
+        iree_hal_amdgpu_reclaim_action_t{
+            .fn = RecordHostAction,
+            .user_data = &action_state,
+        },
+        /*operation_resources=*/NULL, /*operation_resource_count=*/0);
+  }
+  const bool retry_parked =
+      iree_status_is_ok(status) && HostQueueHasPostDrainAction(queue);
+
+  iree_hsa_signal_store_screlease(IREE_LIBHSA(&libhsa_), blocker_signal, 0);
+
+  if (iree_status_is_ok(status)) {
+    status = iree_hal_semaphore_wait(pressure_signal, pressure_signal_value,
+                                     iree_infinite_timeout(),
+                                     IREE_ASYNC_WAIT_FLAG_NONE);
+  }
+  if (iree_status_is_ok(status)) {
+    ASSERT_TRUE(iree_notification_await(&action_state.notification,
+                                        HostActionStateWasCalled, &action_state,
+                                        iree_infinite_timeout()));
+  }
+  IREE_EXPECT_OK(
+      iree_hsa_signal_destroy(IREE_LIBHSA(&libhsa_), blocker_signal));
+
+  IREE_ASSERT_OK(status);
+  EXPECT_TRUE(retry_parked);
+  EXPECT_EQ(HostActionCallCount(&action_state), 1);
+  EXPECT_EQ(HostActionStatusCode(&action_state), IREE_STATUS_OK);
+  EXPECT_TRUE(HostActionHadEntry(&action_state));
+
+  HostActionStateDeinitialize(&action_state);
+}
+
+TEST_F(HostQueuePendingTest, QueueAllocaRejectsWaitableReservationWithoutFlag) {
+  const iree_device_size_t allocation_size = 4096;
+
+  iree_hal_amdgpu_logical_device_options_t options;
+  iree_hal_amdgpu_logical_device_options_initialize(&options);
+  options.preallocate_pools = 0;
+
+  TestLogicalDevice test_device;
+  IREE_ASSERT_OK(
+      test_device.Initialize(&options, &libhsa_, &topology_, host_allocator_));
+  iree_hal_amdgpu_host_queue_t* queue = test_device.first_host_queue();
+  ASSERT_NE(queue, nullptr);
+
+  Ref<iree_hal_pool_t> pool;
+  IREE_ASSERT_OK(CreateExplicitFixedBlockPool(test_device.base_device(),
+                                              allocation_size, pool.out()));
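+  // A synthetic axis that no real queue advances keeps the seeded
+  // reservation's death frontier unresolved.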
+  const iree_async_axis_t death_axis =
+      iree_async_axis_make_queue(0xFE, 0xFE, 0xFE, 0xFE);
+  IREE_ASSERT_OK(
+      SeedWaitableFixedBlockReservation(pool, allocation_size, death_axis));
+
+  Ref<iree_hal_semaphore_t> alloca_signal;
+  IREE_ASSERT_OK(
+      CreateSemaphore(test_device.base_device(), alloca_signal.out()));
+  uint64_t alloca_signal_value = 1;
+  iree_hal_semaphore_t* alloca_signal_ptr = alloca_signal.get();
+  const iree_hal_semaphore_list_t alloca_signal_list =
+      MakeSemaphoreList(&alloca_signal_ptr, &alloca_signal_value);
+
+  iree_hal_buffer_t* buffer = NULL;
+  IREE_EXPECT_STATUS_IS(IREE_STATUS_RESOURCE_EXHAUSTED,
+                        iree_hal_device_queue_alloca(
+                            test_device.base_device(), kQueueAffinity0,
+                            iree_hal_semaphore_list_empty(), alloca_signal_list,
+                            pool, MakeTransientBufferParams(), allocation_size,
+                            IREE_HAL_ALLOCA_FLAG_NONE, &buffer));
+  EXPECT_EQ(buffer, nullptr);
+  EXPECT_FALSE(HostQueueHasPendingOps(queue));
+}
+
+TEST_F(HostQueuePendingTest, QueueAllocaTlsfGrowthRetriesThroughColdPath) {
+  const iree_device_size_t allocation_size = 4096;
+
+  iree_hal_amdgpu_logical_device_options_t options;
+  iree_hal_amdgpu_logical_device_options_initialize(&options);
+  options.preallocate_pools = 0;
+
+  TestLogicalDevice test_device;
+  IREE_ASSERT_OK(
+      test_device.Initialize(&options, &libhsa_, &topology_, host_allocator_));
+  iree_hal_amdgpu_host_queue_t* queue = test_device.first_host_queue();
+  ASSERT_NE(queue, nullptr);
+
+  Ref<iree_hal_pool_t> pool;
+  IREE_ASSERT_OK(CreateExplicitTlsfPool(test_device.base_device(),
+                                        allocation_size, pool.out()));
+
+  Ref<iree_hal_semaphore_t> alloca0_signal;
+  IREE_ASSERT_OK(
+      CreateSemaphore(test_device.base_device(), alloca0_signal.out()));
+  uint64_t alloca0_signal_value = 1;
+  iree_hal_semaphore_t* alloca0_signal_ptr = alloca0_signal.get();
+  const iree_hal_semaphore_list_t alloca0_signal_list =
+      MakeSemaphoreList(&alloca0_signal_ptr, &alloca0_signal_value);
+
+  iree_hal_buffer_t* buffer0 = NULL;
+  IREE_ASSERT_OK(iree_hal_device_queue_alloca(
+      test_device.base_device(), kQueueAffinity0,
+      iree_hal_semaphore_list_empty(), alloca0_signal_list, pool,
+      MakeTransientBufferParams(), allocation_size, IREE_HAL_ALLOCA_FLAG_NONE,
+      &buffer0));
+  ASSERT_NE(buffer0, nullptr);
+  IREE_ASSERT_OK(iree_hal_semaphore_wait(alloca0_signal, alloca0_signal_value,
+                                         iree_infinite_timeout(),
+                                         IREE_ASYNC_WAIT_FLAG_NONE));
+
+  Ref<iree_hal_semaphore_t> alloca1_signal;
+  IREE_ASSERT_OK(
+      CreateSemaphore(test_device.base_device(), alloca1_signal.out()));
+  uint64_t alloca1_signal_value = 1;
+  iree_hal_semaphore_t* alloca1_signal_ptr = alloca1_signal.get();
+  const iree_hal_semaphore_list_t alloca1_signal_list =
+      MakeSemaphoreList(&alloca1_signal_ptr, &alloca1_signal_value);
+
+  iree_hal_buffer_t* buffer1 = NULL;
+  IREE_ASSERT_OK(iree_hal_device_queue_alloca(
+      test_device.base_device(), kQueueAffinity0,
+      iree_hal_semaphore_list_empty(), alloca1_signal_list, pool,
+      MakeTransientBufferParams(), allocation_size, IREE_HAL_ALLOCA_FLAG_NONE,
+      &buffer1));
+  ASSERT_NE(buffer1, nullptr);
+  IREE_ASSERT_OK(iree_hal_semaphore_wait(alloca1_signal, alloca1_signal_value,
+                                         iree_infinite_timeout(),
+                                         IREE_ASYNC_WAIT_FLAG_NONE));
+  EXPECT_FALSE(HostQueueHasPendingOps(queue));
+
+  iree_hal_pool_stats_t stats;
+  iree_hal_pool_query_stats(pool, &stats);
+  EXPECT_GE(stats.slab_count, 2u);
+  EXPECT_GE(stats.exhausted_count, 1u);
+
+  iree_hal_buffer_release(buffer1);
+  iree_hal_buffer_release(buffer0);
+}
+
+TEST_F(HostQueuePendingTest, CancelPendingAllocaFrontierWait) {
+  const iree_device_size_t allocation_size = 4096;
+
+  iree_hal_amdgpu_logical_device_options_t options;
+  iree_hal_amdgpu_logical_device_options_initialize(&options);
+  options.preallocate_pools = 0;
+
+  TestLogicalDevice test_device;
+  IREE_ASSERT_OK(
+      test_device.Initialize(&options, &libhsa_, &topology_, host_allocator_));
+  iree_hal_amdgpu_host_queue_t* queue = test_device.first_host_queue();
+  ASSERT_NE(queue, nullptr);
+
+  Ref<iree_hal_pool_t> pool;
+  IREE_ASSERT_OK(CreateExplicitFixedBlockPool(test_device.base_device(),
+                                              allocation_size, pool.out()));
+  const iree_async_axis_t death_axis = queue->axis;
+  IREE_ASSERT_OK(
+      SeedWaitableFixedBlockReservation(pool, allocation_size, death_axis));
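+  // Force the defer strategy so the frontier wait parks the alloca as a
+  // host-side pending op that cancel_pending can observe.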
+  queue->wait_barrier_strategy = IREE_HAL_AMDGPU_WAIT_BARRIER_STRATEGY_DEFER;
+
+  Ref<iree_hal_semaphore_t> alloca_signal;
+  IREE_ASSERT_OK(
+      CreateSemaphore(test_device.base_device(), alloca_signal.out()));
+  uint64_t alloca_signal_value = 1;
+  iree_hal_semaphore_t* alloca_signal_ptr = alloca_signal.get();
+  const iree_hal_semaphore_list_t alloca_signal_list =
+      MakeSemaphoreList(&alloca_signal_ptr, &alloca_signal_value);
+
+  iree_hal_buffer_t* buffer = NULL;
+  IREE_ASSERT_OK(iree_hal_device_queue_alloca(
+      test_device.base_device(), kQueueAffinity0,
+      iree_hal_semaphore_list_empty(), alloca_signal_list, pool,
+      MakeTransientBufferParams(), allocation_size,
+      IREE_HAL_ALLOCA_FLAG_ALLOW_POOL_WAIT_FRONTIER, &buffer));
+  ASSERT_NE(buffer, nullptr);
+  EXPECT_FALSE(iree_hal_semaphore_list_poll(alloca_signal_list));
+  ASSERT_TRUE(HostQueueHasPendingOps(queue));
+
+  iree_hal_amdgpu_host_queue_cancel_pending(queue, IREE_STATUS_CANCELLED,
+                                            "test cancellation");
+  EXPECT_FALSE(HostQueueHasPendingOps(queue));
+  IREE_EXPECT_STATUS_IS(
+      IREE_STATUS_CANCELLED,
+      iree_hal_semaphore_wait(alloca_signal, alloca_signal_value,
+                              iree_infinite_timeout(),
+                              IREE_ASYNC_WAIT_FLAG_NONE));
+
+  iree_hal_pool_stats_t stats;
+  iree_hal_pool_query_stats(pool, &stats);
+  EXPECT_EQ(stats.reservation_count, 0u);
+  iree_hal_buffer_release(buffer);
+}
+
+TEST_F(HostQueuePendingTest, CancelPendingAllocaPoolNotificationWait) {
+  const iree_device_size_t allocation_size = 4096;
+
+  iree_hal_amdgpu_logical_device_options_t options;
+  iree_hal_amdgpu_logical_device_options_initialize(&options);
+  options.preallocate_pools = 0;
+
+  TestLogicalDevice test_device;
+  IREE_ASSERT_OK(
+      test_device.Initialize(&options, &libhsa_, &topology_, host_allocator_));
+  iree_hal_amdgpu_host_queue_t* queue = test_device.first_host_queue();
+  ASSERT_NE(queue, nullptr);
+
+  Ref<iree_hal_pool_t> pool;
+  IREE_ASSERT_OK(CreateExplicitFixedBlockPool(test_device.base_device(),
+                                              allocation_size, pool.out()));
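+  // The first alloca exhausts the pool; the second parks waiting on a pool
+  // notification and is only resolved by the cancellation below.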
+
+  Ref<iree_hal_semaphore_t> alloca0_signal;
+  IREE_ASSERT_OK(
+      CreateSemaphore(test_device.base_device(), alloca0_signal.out()));
+  uint64_t alloca0_signal_value = 1;
+  iree_hal_semaphore_t* alloca0_signal_ptr = alloca0_signal.get();
+  const iree_hal_semaphore_list_t alloca0_signal_list =
+      MakeSemaphoreList(&alloca0_signal_ptr, &alloca0_signal_value);
+
+  iree_hal_buffer_t* buffer0 = NULL;
+  IREE_ASSERT_OK(iree_hal_device_queue_alloca(
+      test_device.base_device(), kQueueAffinity0,
+      iree_hal_semaphore_list_empty(), alloca0_signal_list, pool,
+      MakeTransientBufferParams(), allocation_size, IREE_HAL_ALLOCA_FLAG_NONE,
+      &buffer0));
+  ASSERT_NE(buffer0, nullptr);
+  IREE_ASSERT_OK(iree_hal_semaphore_wait(alloca0_signal, alloca0_signal_value,
+                                         iree_infinite_timeout(),
+                                         IREE_ASYNC_WAIT_FLAG_NONE));
+
+  Ref<iree_hal_semaphore_t> alloca1_signal;
+  IREE_ASSERT_OK(
+      CreateSemaphore(test_device.base_device(), alloca1_signal.out()));
+  uint64_t alloca1_signal_value = 1;
+  iree_hal_semaphore_t* alloca1_signal_ptr = alloca1_signal.get();
+  const iree_hal_semaphore_list_t alloca1_signal_list =
+      MakeSemaphoreList(&alloca1_signal_ptr, &alloca1_signal_value);
+
+  iree_hal_buffer_t* buffer1 = NULL;
+  IREE_ASSERT_OK(iree_hal_device_queue_alloca(
+      test_device.base_device(), kQueueAffinity0,
+      iree_hal_semaphore_list_empty(), alloca1_signal_list, pool,
+      MakeTransientBufferParams(), allocation_size, IREE_HAL_ALLOCA_FLAG_NONE,
+      &buffer1));
+  ASSERT_NE(buffer1, nullptr);
+  EXPECT_FALSE(iree_hal_semaphore_list_poll(alloca1_signal_list));
+  ASSERT_TRUE(HostQueueHasPendingOps(queue));
+
+  iree_hal_pool_stats_t stats;
+  iree_hal_pool_query_stats(pool, &stats);
+  EXPECT_GE(stats.exhausted_count, 1u);
+
+  iree_hal_amdgpu_host_queue_cancel_pending(queue, IREE_STATUS_CANCELLED,
+                                            "test cancellation");
+  EXPECT_FALSE(HostQueueHasPendingOps(queue));
+  IREE_EXPECT_STATUS_IS(
+      IREE_STATUS_CANCELLED,
+      iree_hal_semaphore_wait(alloca1_signal, alloca1_signal_value,
+                              iree_infinite_timeout(),
+                              IREE_ASYNC_WAIT_FLAG_NONE));
+
+  iree_hal_buffer_release(buffer1);
+  iree_hal_buffer_release(buffer0);
+}
+
+}  // namespace
+}  // namespace iree::hal::amdgpu
diff --git a/runtime/src/iree/hal/drivers/amdgpu/host_queue_policy.c b/runtime/src/iree/hal/drivers/amdgpu/host_queue_policy.c
new file mode 100644
index 0000000..91102f2
--- /dev/null
+++ b/runtime/src/iree/hal/drivers/amdgpu/host_queue_policy.c
@@ -0,0 +1,93 @@
+// Copyright 2026 The IREE Authors
+//
+// Licensed under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+
+#include "iree/hal/drivers/amdgpu/host_queue_policy.h"
+
+#include "iree/hal/drivers/amdgpu/logical_device.h"
+#include "iree/hal/drivers/amdgpu/queue_affinity.h"
+#include "iree/hal/drivers/amdgpu/semaphore.h"
+#include "iree/hal/drivers/amdgpu/system.h"
+
+// Returns true when |queue_affinity| names only queues on |queue|'s physical
+// HSA agent. HSA AGENT scope is not a logical-device-wide guarantee: a
+// multi-GPU logical device still needs SYSTEM scope for cross-physical-agent
+// synchronization.
+static bool iree_hal_amdgpu_host_queue_affinity_is_same_agent(
+    const iree_hal_amdgpu_host_queue_t* queue,
+    iree_hal_queue_affinity_t queue_affinity) {
+  const iree_hal_amdgpu_logical_device_t* logical_device =
+      (const iree_hal_amdgpu_logical_device_t*)queue->logical_device;
+  const iree_hal_amdgpu_queue_affinity_domain_t domain = {
+      .supported_affinity = logical_device->queue_affinity_mask,
+      .physical_device_count = logical_device->physical_device_count,
+      .queue_count_per_physical_device =
+          logical_device->system->topology.gpu_agent_queue_count,
+  };
+  return iree_hal_amdgpu_queue_affinity_is_physical_device_local(
+      domain, queue_affinity, queue->device_ordinal);
+}
+
+// Returns true when a semaphore edge can be represented with HSA AGENT scope
+// on |queue|. Public/exportable/host-visible semaphores and cross-agent
+// affinities use SYSTEM scope.
+static bool iree_hal_amdgpu_host_queue_can_use_agent_scope(
+    const iree_hal_amdgpu_host_queue_t* queue,
+    iree_hal_semaphore_t* semaphore) {
+  const iree_hal_amdgpu_logical_device_t* logical_device =
+      (const iree_hal_amdgpu_logical_device_t*)queue->logical_device;
+  if (!iree_hal_amdgpu_semaphore_is_local(semaphore, logical_device)) {
+    return false;
+  }
+
+  const iree_hal_semaphore_flags_t flags =
+      iree_hal_amdgpu_semaphore_flags(semaphore);
+  if (!iree_all_bits_set(flags, IREE_HAL_SEMAPHORE_FLAG_DEVICE_LOCAL)) {
+    return false;
+  }
+  const iree_hal_semaphore_flags_t public_flags =
+      IREE_HAL_SEMAPHORE_FLAG_HOST_INTERRUPT |
+      IREE_HAL_SEMAPHORE_FLAG_EXPORTABLE |
+      IREE_HAL_SEMAPHORE_FLAG_EXPORTABLE_TIMEPOINTS;
+  if (iree_any_bit_set(flags, public_flags)) return false;
+
+  return iree_hal_amdgpu_host_queue_affinity_is_same_agent(
+      queue, iree_hal_amdgpu_semaphore_queue_affinity(semaphore));
+}
+
+iree_hsa_fence_scope_t iree_hal_amdgpu_host_queue_wait_acquire_scope(
+    const iree_hal_amdgpu_host_queue_t* queue,
+    iree_hal_semaphore_t* semaphore) {
+  return iree_hal_amdgpu_host_queue_can_use_agent_scope(queue, semaphore)
+             ? IREE_HSA_FENCE_SCOPE_AGENT
+             : IREE_HSA_FENCE_SCOPE_SYSTEM;
+}
+
+iree_hsa_fence_scope_t iree_hal_amdgpu_host_queue_axis_acquire_scope(
+    const iree_hal_amdgpu_host_queue_t* queue, iree_async_axis_t axis) {
+  return iree_async_axis_device_index(axis) == queue->device_ordinal
+             ? IREE_HSA_FENCE_SCOPE_AGENT
+             : IREE_HSA_FENCE_SCOPE_SYSTEM;
+}
+
+iree_hsa_fence_scope_t iree_hal_amdgpu_host_queue_signal_release_scope(
+    const iree_hal_amdgpu_host_queue_t* queue,
+    iree_hal_semaphore_t* semaphore) {
+  return iree_hal_amdgpu_host_queue_can_use_agent_scope(queue, semaphore)
+             ? IREE_HSA_FENCE_SCOPE_AGENT
+             : IREE_HSA_FENCE_SCOPE_SYSTEM;
+}
+
+iree_hsa_fence_scope_t iree_hal_amdgpu_host_queue_signal_list_release_scope(
+    const iree_hal_amdgpu_host_queue_t* queue,
+    iree_hal_semaphore_list_t semaphores) {
+  iree_hsa_fence_scope_t release_scope = IREE_HSA_FENCE_SCOPE_AGENT;
+  for (iree_host_size_t i = 0; i < semaphores.count; ++i) {
+    release_scope = iree_hal_amdgpu_host_queue_max_fence_scope(
+        release_scope, iree_hal_amdgpu_host_queue_signal_release_scope(
+                           queue, semaphores.semaphores[i]));
+  }
+  return release_scope;
+}
diff --git a/runtime/src/iree/hal/drivers/amdgpu/host_queue_policy.h b/runtime/src/iree/hal/drivers/amdgpu/host_queue_policy.h
new file mode 100644
index 0000000..3d036bf
--- /dev/null
+++ b/runtime/src/iree/hal/drivers/amdgpu/host_queue_policy.h
@@ -0,0 +1,49 @@
+// Copyright 2026 The IREE Authors
+//
+// Licensed under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+
+#ifndef IREE_HAL_DRIVERS_AMDGPU_HOST_QUEUE_POLICY_H_
+#define IREE_HAL_DRIVERS_AMDGPU_HOST_QUEUE_POLICY_H_
+
+#include "iree/hal/drivers/amdgpu/host_queue.h"
+
+#ifdef __cplusplus
+extern "C" {
+#endif  // __cplusplus
+
+// Returns the stronger of two HSA fence scopes.
+static inline iree_hsa_fence_scope_t iree_hal_amdgpu_host_queue_max_fence_scope(
+    iree_hsa_fence_scope_t lhs, iree_hsa_fence_scope_t rhs) {
+  return lhs > rhs ? lhs : rhs;
+}
+
+// Returns the acquire scope required when |queue| consumes a wait edge on
+// |semaphore|. This is derived from the semaphore's visibility contract only;
+// operation buffers/bindings are intentionally not inspected.
+iree_hsa_fence_scope_t iree_hal_amdgpu_host_queue_wait_acquire_scope(
+    const iree_hal_amdgpu_host_queue_t* queue, iree_hal_semaphore_t* semaphore);
+
+// Returns the acquire scope required when |queue| waits on a producer
+// frontier axis. This is used for pool/death-frontier waits where there is no
+// semaphore edge to classify.
+iree_hsa_fence_scope_t iree_hal_amdgpu_host_queue_axis_acquire_scope(
+    const iree_hal_amdgpu_host_queue_t* queue, iree_async_axis_t axis);
+
+// Returns the release scope required when |queue| publishes a signal edge on
+// |semaphore|. This is derived from the semaphore's visibility contract only;
+// operation buffers/bindings are intentionally not inspected.
+iree_hsa_fence_scope_t iree_hal_amdgpu_host_queue_signal_release_scope(
+    const iree_hal_amdgpu_host_queue_t* queue, iree_hal_semaphore_t* semaphore);
+
+// Returns the release scope required for all signal semaphores in |semaphores|.
+iree_hsa_fence_scope_t iree_hal_amdgpu_host_queue_signal_list_release_scope(
+    const iree_hal_amdgpu_host_queue_t* queue,
+    iree_hal_semaphore_list_t semaphores);
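+
+// Illustrative shape of a call site (a sketch, not code in this file; |queue|,
+// |wait_list|, and |signal_list| stand for the caller's own values): a
+// submission derives one acquire scope across its wait edges and one release
+// scope for its signal list, widening to SYSTEM when any single edge
+// requires it:
+//
+//   iree_hsa_fence_scope_t acquire = IREE_HSA_FENCE_SCOPE_AGENT;
+//   for (iree_host_size_t i = 0; i < wait_list.count; ++i) {
+//     acquire = iree_hal_amdgpu_host_queue_max_fence_scope(
+//         acquire, iree_hal_amdgpu_host_queue_wait_acquire_scope(
+//                      queue, wait_list.semaphores[i]));
+//   }
+//   iree_hsa_fence_scope_t release =
+//       iree_hal_amdgpu_host_queue_signal_list_release_scope(queue,
+//                                                            signal_list);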
+
+#ifdef __cplusplus
+}  // extern "C"
+#endif  // __cplusplus
+
+#endif  // IREE_HAL_DRIVERS_AMDGPU_HOST_QUEUE_POLICY_H_
diff --git a/runtime/src/iree/hal/drivers/amdgpu/host_queue_profile.c b/runtime/src/iree/hal/drivers/amdgpu/host_queue_profile.c
new file mode 100644
index 0000000..b2545d7
--- /dev/null
+++ b/runtime/src/iree/hal/drivers/amdgpu/host_queue_profile.c
@@ -0,0 +1,117 @@
+// Copyright 2026 The IREE Authors
+//
+// Licensed under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+
+#include "iree/hal/drivers/amdgpu/host_queue_profile.h"
+
+#include "iree/hal/drivers/amdgpu/host_queue_profile_events.h"
+#include "iree/hal/drivers/amdgpu/logical_device.h"
+
+uint32_t iree_hal_amdgpu_host_queue_profile_device_ordinal(
+    const iree_hal_amdgpu_host_queue_t* queue) {
+  return queue->device_ordinal <= UINT32_MAX ? (uint32_t)queue->device_ordinal
+                                             : UINT32_MAX;
+}
+
+uint32_t iree_hal_amdgpu_host_queue_profile_queue_ordinal(
+    const iree_hal_amdgpu_host_queue_t* queue) {
+  return iree_async_axis_queue_index(queue->axis);
+}
+
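+// Stream ids pack the profiling device ordinal into the high 32 bits and the
+// queue ordinal into the low 32 bits so a single u64 uniquely names a queue
+// within the profiling session.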
+uint64_t iree_hal_amdgpu_host_queue_profile_stream_id(
+    const iree_hal_amdgpu_host_queue_t* queue) {
+  const uint32_t physical_device_ordinal =
+      iree_hal_amdgpu_host_queue_profile_device_ordinal(queue);
+  const uint32_t queue_ordinal =
+      iree_hal_amdgpu_host_queue_profile_queue_ordinal(queue);
+  return ((uint64_t)physical_device_ordinal << 32) | (uint64_t)queue_ordinal;
+}
+
+uint32_t iree_hal_amdgpu_host_queue_profile_semaphore_count(
+    const iree_hal_semaphore_list_t semaphore_list) {
+  return semaphore_list.count > UINT32_MAX ? UINT32_MAX
+                                           : (uint32_t)semaphore_list.count;
+}
+
+void iree_hal_amdgpu_host_queue_set_profile_flags(
+    iree_hal_amdgpu_host_queue_t* queue,
+    iree_hal_amdgpu_host_queue_profile_flags_t flags) {
+  queue->profiling.queue_events_enabled = iree_any_bit_set(
+      flags, IREE_HAL_AMDGPU_HOST_QUEUE_PROFILE_FLAG_QUEUE_EVENTS);
+  queue->profiling.queue_device_events_enabled = iree_any_bit_set(
+      flags, IREE_HAL_AMDGPU_HOST_QUEUE_PROFILE_FLAG_QUEUE_DEVICE_EVENTS);
+  queue->profiling.dispatch_profiling_enabled = iree_any_bit_set(
+      flags, IREE_HAL_AMDGPU_HOST_QUEUE_PROFILE_FLAG_DISPATCHES);
+}
+
+iree_hal_amdgpu_profile_queue_device_event_t*
+iree_hal_amdgpu_host_queue_initialize_profile_queue_device_event(
+    iree_hal_amdgpu_host_queue_t* queue,
+    iree_hal_amdgpu_profile_queue_device_event_reservation_t reservation,
+    const iree_hal_amdgpu_host_queue_profile_event_info_t* info) {
+  if (reservation.event_count == 0) return NULL;
+  IREE_ASSERT(info != NULL,
+              "queue device event reservation requires profile event info");
+  iree_hal_amdgpu_profile_queue_device_event_t* event =
+      iree_hal_amdgpu_host_queue_profile_queue_device_event_at(
+          queue, reservation.first_event_position);
+  event->type = info->type;
+  event->flags = info->flags;
+  event->command_buffer_id = info->command_buffer_id;
+  event->allocation_id = info->allocation_id;
+  event->payload_length = info->payload_length;
+  event->operation_count = info->operation_count;
+  event->start_tick = 0;
+  event->end_tick = 0;
+  return event;
+}
+
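+// Classifies how an operation's wait edges were satisfied, checked in
+// priority order: software deferral dominates, then the no-wait case, then
+// device barrier packets, then inline AQL packet dependencies.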
+static iree_hal_profile_queue_dependency_strategy_t
+iree_hal_amdgpu_host_queue_profile_dependency_strategy(
+    const iree_hal_amdgpu_wait_resolution_t* resolution) {
+  if (resolution->needs_deferral) {
+    return IREE_HAL_PROFILE_QUEUE_DEPENDENCY_STRATEGY_SOFTWARE_DEFER;
+  }
+  if (iree_any_bit_set(resolution->profile_event_flags,
+                       IREE_HAL_PROFILE_QUEUE_EVENT_FLAG_SOFTWARE_DEFERRED)) {
+    return IREE_HAL_PROFILE_QUEUE_DEPENDENCY_STRATEGY_SOFTWARE_DEFER;
+  }
+  if (resolution->wait_count == 0) {
+    return IREE_HAL_PROFILE_QUEUE_DEPENDENCY_STRATEGY_NONE;
+  }
+  if (resolution->barrier_count != 0) {
+    return IREE_HAL_PROFILE_QUEUE_DEPENDENCY_STRATEGY_DEVICE_BARRIER;
+  }
+  return IREE_HAL_PROFILE_QUEUE_DEPENDENCY_STRATEGY_INLINE;
+}
+
+void iree_hal_amdgpu_host_queue_record_profile_queue_event(
+    iree_hal_amdgpu_host_queue_t* queue,
+    const iree_hal_amdgpu_wait_resolution_t* resolution,
+    const iree_hal_semaphore_list_t signal_semaphore_list,
+    const iree_hal_amdgpu_host_queue_profile_event_info_t* info) {
+  if (!queue->profiling.queue_events_enabled) return;
+
+  iree_hal_profile_queue_event_t event = iree_hal_profile_queue_event_default();
+  event.type = info->type;
+  event.flags = info->flags | resolution->profile_event_flags;
+  event.dependency_strategy =
+      iree_hal_amdgpu_host_queue_profile_dependency_strategy(resolution);
+  event.submission_id = info->submission_id;
+  event.command_buffer_id = info->command_buffer_id;
+  event.allocation_id = info->allocation_id;
+  event.stream_id = iree_hal_amdgpu_host_queue_profile_stream_id(queue);
+  event.physical_device_ordinal =
+      iree_hal_amdgpu_host_queue_profile_device_ordinal(queue);
+  event.queue_ordinal = iree_hal_amdgpu_host_queue_profile_queue_ordinal(queue);
+  event.wait_count = resolution->wait_count;
+  event.signal_count =
+      iree_hal_amdgpu_host_queue_profile_semaphore_count(signal_semaphore_list);
+  event.barrier_count = resolution->barrier_count;
+  event.operation_count = info->operation_count;
+  event.payload_length = info->payload_length;
+  iree_hal_amdgpu_logical_device_record_profile_queue_event(
+      queue->logical_device, &event);
+}
diff --git a/runtime/src/iree/hal/drivers/amdgpu/host_queue_profile.h b/runtime/src/iree/hal/drivers/amdgpu/host_queue_profile.h
new file mode 100644
index 0000000..be19d99
--- /dev/null
+++ b/runtime/src/iree/hal/drivers/amdgpu/host_queue_profile.h
@@ -0,0 +1,88 @@
+// Copyright 2026 The IREE Authors
+//
+// Licensed under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+
+#ifndef IREE_HAL_DRIVERS_AMDGPU_HOST_QUEUE_PROFILE_H_
+#define IREE_HAL_DRIVERS_AMDGPU_HOST_QUEUE_PROFILE_H_
+
+#include "iree/hal/drivers/amdgpu/host_queue_waits.h"
+
+#ifdef __cplusplus
+extern "C" {
+#endif  // __cplusplus
+
+typedef uint32_t iree_hal_amdgpu_host_queue_profile_flags_t;
+enum iree_hal_amdgpu_host_queue_profile_flag_bits_t {
+  IREE_HAL_AMDGPU_HOST_QUEUE_PROFILE_FLAG_NONE = 0u,
+  // Host-timestamped queue operation events should be recorded.
+  IREE_HAL_AMDGPU_HOST_QUEUE_PROFILE_FLAG_QUEUE_EVENTS = 1u << 0,
+  // Device-timestamped queue operation events should be recorded.
+  IREE_HAL_AMDGPU_HOST_QUEUE_PROFILE_FLAG_QUEUE_DEVICE_EVENTS = 1u << 1,
+  // Per-dispatch profiling augmentation may be applied to selected dispatches.
+  IREE_HAL_AMDGPU_HOST_QUEUE_PROFILE_FLAG_DISPATCHES = 1u << 2,
+};
+
+// Additional details for one queue operation profile event.
+typedef struct iree_hal_amdgpu_host_queue_profile_event_info_t {
+  // Type of queue operation represented by the event.
+  iree_hal_profile_queue_event_type_t type;
+  // Flags describing queue operation properties.
+  iree_hal_profile_queue_event_flags_t flags;
+  // Queue submission epoch assigned by the operation.
+  uint64_t submission_id;
+  // Session-local command-buffer identifier, or 0 when not applicable.
+  uint64_t command_buffer_id;
+  // Producer-defined allocation identifier, or 0 when not applicable.
+  uint64_t allocation_id;
+  // Type-specific payload byte length, or 0 when not applicable.
+  uint64_t payload_length;
+  // Number of encoded payload operations represented by this event.
+  uint32_t operation_count;
+} iree_hal_amdgpu_host_queue_profile_event_info_t;
+
+// Returns the session-local profiling ordinal for |queue|'s physical device.
+uint32_t iree_hal_amdgpu_host_queue_profile_device_ordinal(
+    const iree_hal_amdgpu_host_queue_t* queue);
+
+// Returns the session-local profiling ordinal for |queue| within its device.
+uint32_t iree_hal_amdgpu_host_queue_profile_queue_ordinal(
+    const iree_hal_amdgpu_host_queue_t* queue);
+
+// Returns the stream id used by queue metadata and queue/dispatch events.
+uint64_t iree_hal_amdgpu_host_queue_profile_stream_id(
+    const iree_hal_amdgpu_host_queue_t* queue);
+
+// Returns |semaphore_list.count| saturated to the queue-event field width.
+uint32_t iree_hal_amdgpu_host_queue_profile_semaphore_count(
+    const iree_hal_semaphore_list_t semaphore_list);
+
+// Sets queue-local profile recording flags for an active session.
+void iree_hal_amdgpu_host_queue_set_profile_flags(
+    iree_hal_amdgpu_host_queue_t* queue,
+    iree_hal_amdgpu_host_queue_profile_flags_t flags);
+
+// Initializes one reserved device-timestamped queue operation event.
+iree_hal_amdgpu_profile_queue_device_event_t*
+iree_hal_amdgpu_host_queue_initialize_profile_queue_device_event(
+    iree_hal_amdgpu_host_queue_t* queue,
+    iree_hal_amdgpu_profile_queue_device_event_reservation_t reservation,
+    const iree_hal_amdgpu_host_queue_profile_event_info_t* info);
+
+// Records one queue operation event when queue profiling is enabled.
+//
+// This performs a cheap queue-local enabled check before preparing the full
+// event. The sink is never called here; the logical device buffers events for
+// profiling_flush/end.
+void iree_hal_amdgpu_host_queue_record_profile_queue_event(
+    iree_hal_amdgpu_host_queue_t* queue,
+    const iree_hal_amdgpu_wait_resolution_t* resolution,
+    const iree_hal_semaphore_list_t signal_semaphore_list,
+    const iree_hal_amdgpu_host_queue_profile_event_info_t* info);
+
+#ifdef __cplusplus
+}  // extern "C"
+#endif  // __cplusplus
+
+#endif  // IREE_HAL_DRIVERS_AMDGPU_HOST_QUEUE_PROFILE_H_
diff --git a/runtime/src/iree/hal/drivers/amdgpu/host_queue_profile_events.c b/runtime/src/iree/hal/drivers/amdgpu/host_queue_profile_events.c
new file mode 100644
index 0000000..1167481
--- /dev/null
+++ b/runtime/src/iree/hal/drivers/amdgpu/host_queue_profile_events.c
@@ -0,0 +1,706 @@
+// Copyright 2026 The IREE Authors
+//
+// Licensed under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+
+#include "iree/hal/drivers/amdgpu/host_queue_profile_events.h"
+
+#include <string.h>
+
+#include "iree/hal/drivers/amdgpu/host_queue_profile.h"
+#include "iree/hal/drivers/amdgpu/profile_counters.h"
+#include "iree/hal/drivers/amdgpu/profile_traces.h"
+
+static_assert(sizeof(iree_hal_amdgpu_profile_dispatch_event_t) ==
+                  sizeof(iree_hal_profile_dispatch_event_t),
+              "AMDGPU dispatch events must convert without layout growth");
+static_assert(sizeof(iree_hal_amdgpu_profile_queue_device_event_t) ==
+                  sizeof(iree_hal_profile_queue_device_event_t),
+              "AMDGPU queue device events must convert without layout growth");
+static_assert(
+    (uint32_t)IREE_HAL_AMDGPU_PROFILE_DISPATCH_EVENT_FLAG_COMMAND_BUFFER ==
+        (uint32_t)IREE_HAL_PROFILE_DISPATCH_EVENT_FLAG_COMMAND_BUFFER,
+    "AMDGPU command-buffer dispatch event flag must match HAL");
+static_assert(
+    (uint32_t)IREE_HAL_AMDGPU_PROFILE_DISPATCH_EVENT_FLAG_INDIRECT_PARAMETERS ==
+        (uint32_t)IREE_HAL_PROFILE_DISPATCH_EVENT_FLAG_INDIRECT_PARAMETERS,
+    "AMDGPU indirect dispatch event flag must match HAL");
+
+// Maximum dispatch events buffered per queue between profiling flushes.
+#define IREE_HAL_AMDGPU_HOST_QUEUE_PROFILE_DISPATCH_EVENT_CAPACITY (64 * 1024)
+
+// Maximum queue device events buffered per queue between profiling flushes.
+#define IREE_HAL_AMDGPU_HOST_QUEUE_PROFILE_QUEUE_DEVICE_EVENT_CAPACITY \
+  (64 * 1024)
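+
+// Both capacities must be powers of two: ring positions are free-running
+// counters and slot indices are computed by masking with capacity-1.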
+
+static void iree_hal_amdgpu_host_queue_initialize_profiling_signal(
+    iree_amd_signal_t* signal) {
+  memset(signal, 0, sizeof(*signal));
+  signal->kind = IREE_AMD_SIGNAL_KIND_USER;
+  // Profiling completion signals are never waited on. Keep the value at
+  // all-bits-set so packet completion decrements never require host/device
+  // reset traffic; consumers read start_ts/end_ts after queue ordering proves
+  // the profiled packet completed.
+  signal->value = (iree_hsa_signal_value_t)-1;
+}
+
+static iree_status_t
+iree_hal_amdgpu_host_queue_allocate_profiling_completion_signals(
+    iree_hal_amdgpu_block_pool_t* signal_block_pool, uint32_t signal_count,
+    iree_allocator_t host_allocator, iree_hal_amdgpu_host_queue_t* out_queue) {
+  IREE_TRACE_ZONE_BEGIN(z0);
+  IREE_TRACE_ZONE_APPEND_VALUE_I64(z0, signal_count);
+
+  if (IREE_UNLIKELY(signal_block_pool->block_size < sizeof(iree_amd_signal_t) ||
+                    signal_block_pool->block_size % sizeof(iree_amd_signal_t) !=
+                        0)) {
+    IREE_RETURN_AND_END_ZONE_IF_ERROR(
+        z0, iree_make_status(IREE_STATUS_INVALID_ARGUMENT,
+                             "profiling signal block size %" PRIdsz
+                             " must hold whole iree_amd_signal_t records",
+                             signal_block_pool->block_size));
+  }
+  const iree_host_size_t signals_per_block =
+      signal_block_pool->block_size / sizeof(iree_amd_signal_t);
+  if (IREE_UNLIKELY(signals_per_block == 0 || signals_per_block > UINT32_MAX)) {
+    IREE_RETURN_AND_END_ZONE_IF_ERROR(
+        z0, iree_make_status(IREE_STATUS_OUT_OF_RANGE,
+                             "profiling signal block size %" PRIu64
+                             " cannot hold a valid signal count",
+                             (uint64_t)signal_block_pool->block_size));
+  }
+  const iree_host_size_t signal_block_count =
+      iree_host_size_ceil_div(signal_count, signals_per_block);
+  IREE_TRACE_ZONE_APPEND_VALUE_I64(z0, signal_block_count);
+
+  iree_host_size_t signal_block_table_size = 0;
+  IREE_RETURN_AND_END_ZONE_IF_ERROR(
+      z0,
+      IREE_STRUCT_LAYOUT(0, &signal_block_table_size,
+                         IREE_STRUCT_FIELD(signal_block_count,
+                                           iree_hal_amdgpu_block_t*, NULL)));
+  iree_hal_amdgpu_block_t** signal_blocks = NULL;
+  iree_status_t status = iree_allocator_malloc(
+      host_allocator, signal_block_table_size, (void**)&signal_blocks);
+  iree_host_size_t acquired_block_count = 0;
+  for (iree_host_size_t block_index = 0;
+       block_index < signal_block_count && iree_status_is_ok(status);
+       ++block_index) {
+    iree_hal_amdgpu_block_t* block = NULL;
+    status = iree_hal_amdgpu_block_pool_acquire(signal_block_pool, &block);
+    if (iree_status_is_ok(status)) {
+      signal_blocks[block_index] = block;
+      ++acquired_block_count;
+      if (IREE_UNLIKELY(
+              (uintptr_t)block->ptr % iree_alignof(iree_amd_signal_t) != 0)) {
+        status = iree_make_status(
+            IREE_STATUS_FAILED_PRECONDITION,
+            "profiling signal block is not aligned to %" PRIhsz " bytes",
+            (iree_host_size_t)iree_alignof(iree_amd_signal_t));
+      } else {
+        for (iree_host_size_t signal_index = 0;
+             signal_index < signals_per_block; ++signal_index) {
+          uint8_t* signal_ptr =
+              (uint8_t*)block->ptr + signal_index * sizeof(iree_amd_signal_t);
+          iree_hal_amdgpu_host_queue_initialize_profiling_signal(
+              (iree_amd_signal_t*)signal_ptr);
+        }
+      }
+    }
+  }
+
+  if (iree_status_is_ok(status)) {
+    out_queue->profiling.signals.block_pool = signal_block_pool;
+    out_queue->profiling.signals.blocks = signal_blocks;
+    out_queue->profiling.signals.block_count = (uint32_t)signal_block_count;
+    out_queue->profiling.signals.signals_per_block =
+        (uint32_t)signals_per_block;
+  } else {
+    for (iree_host_size_t i = 0; i < acquired_block_count; ++i) {
+      iree_hal_amdgpu_block_pool_release(signal_block_pool, signal_blocks[i]);
+    }
+    iree_allocator_free(host_allocator, signal_blocks);
+  }
+
+  IREE_TRACE_ZONE_END(z0);
+  return status;
+}
+
+iree_status_t iree_hal_amdgpu_host_queue_ensure_profiling_completion_signals(
+    iree_hal_amdgpu_host_queue_t* queue) {
+  if (queue->profiling.signals.blocks) return iree_ok_status();
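+  // One signal is paired with each dispatch event ring slot, so event storage
+  // must have been ensured first to provide a non-zero capacity here.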
+  return iree_hal_amdgpu_host_queue_allocate_profiling_completion_signals(
+      queue->profiling.signals.block_pool,
+      queue->profiling.dispatch_events.capacity, queue->host_allocator, queue);
+}
+
+void iree_hal_amdgpu_host_queue_deallocate_profiling_completion_signals(
+    iree_hal_amdgpu_host_queue_t* queue) {
+  if (!queue->profiling.signals.blocks) {
+    return;
+  }
+  IREE_TRACE_ZONE_BEGIN(z0);
+
+  for (uint32_t i = 0; i < queue->profiling.signals.block_count; ++i) {
+    iree_hal_amdgpu_block_pool_release(queue->profiling.signals.block_pool,
+                                       queue->profiling.signals.blocks[i]);
+  }
+  iree_allocator_free(queue->host_allocator, queue->profiling.signals.blocks);
+  queue->profiling.signals.block_pool = NULL;
+  queue->profiling.signals.blocks = NULL;
+  queue->profiling.signals.block_count = 0;
+  queue->profiling.signals.signals_per_block = 0;
+
+  IREE_TRACE_ZONE_END(z0);
+}
+
+iree_status_t iree_hal_amdgpu_host_queue_ensure_profile_event_storage(
+    iree_hal_amdgpu_host_queue_t* queue) {
+  if (queue->profiling.event_allocation.base) return iree_ok_status();
+
+  IREE_TRACE_ZONE_BEGIN(z0);
+  const uint32_t dispatch_event_capacity =
+      IREE_HAL_AMDGPU_HOST_QUEUE_PROFILE_DISPATCH_EVENT_CAPACITY;
+  const uint32_t queue_device_event_capacity =
+      IREE_HAL_AMDGPU_HOST_QUEUE_PROFILE_QUEUE_DEVICE_EVENT_CAPACITY;
+  IREE_TRACE_ZONE_APPEND_VALUE_I64(z0, dispatch_event_capacity);
+  IREE_TRACE_ZONE_APPEND_VALUE_I64(z0, queue_device_event_capacity);
+
+  iree_host_size_t dispatch_events_offset = 0;
+  iree_host_size_t queue_device_events_offset = 0;
+  iree_host_size_t total_size = 0;
+  IREE_RETURN_AND_END_ZONE_IF_ERROR(
+      z0, IREE_STRUCT_LAYOUT(
+              0, &total_size,
+              IREE_STRUCT_FIELD(dispatch_event_capacity,
+                                iree_hal_amdgpu_profile_dispatch_event_t,
+                                &dispatch_events_offset),
+              IREE_STRUCT_FIELD(queue_device_event_capacity,
+                                iree_hal_amdgpu_profile_queue_device_event_t,
+                                &queue_device_events_offset)));
+  void* event_storage = NULL;
+  IREE_RETURN_AND_END_ZONE_IF_ERROR(
+      z0,
+      iree_hsa_amd_memory_pool_allocate(
+          IREE_LIBHSA(queue->libhsa),
+          queue->profiling.signals.block_pool->memory_pool, total_size,
+          HSA_AMD_MEMORY_POOL_STANDARD_FLAG, &event_storage),
+      "allocating profile event rings of %" PRIhsz " bytes", total_size);
+  memset(event_storage, 0, total_size);
+
+  iree_slim_mutex_lock(&queue->profiling.event_mutex);
+  queue->profiling.event_allocation.base = event_storage;
+  queue->profiling.event_allocation.size = total_size;
+  queue->profiling.dispatch_events.values =
+      (iree_hal_amdgpu_profile_dispatch_event_t*)((uint8_t*)event_storage +
+                                                  dispatch_events_offset);
+  queue->profiling.dispatch_events.capacity = dispatch_event_capacity;
+  queue->profiling.dispatch_events.mask = dispatch_event_capacity - 1;
+  queue->profiling.dispatch_events.read_position = 0;
+  queue->profiling.dispatch_events.ready_position = 0;
+  queue->profiling.dispatch_events.write_position = 0;
+  queue->profiling.dispatch_events.next_event_id = 1;
+  queue->profiling.queue_device_events.values =
+      (iree_hal_amdgpu_profile_queue_device_event_t*)
+          ((uint8_t*)event_storage + queue_device_events_offset);
+  queue->profiling.queue_device_events.capacity = queue_device_event_capacity;
+  queue->profiling.queue_device_events.mask = queue_device_event_capacity - 1;
+  queue->profiling.queue_device_events.read_position = 0;
+  queue->profiling.queue_device_events.ready_position = 0;
+  queue->profiling.queue_device_events.write_position = 0;
+  queue->profiling.queue_device_events.next_event_id = 1;
+  iree_slim_mutex_unlock(&queue->profiling.event_mutex);
+
+  IREE_TRACE_ZONE_END(z0);
+  return iree_ok_status();
+}
+
+void iree_hal_amdgpu_host_queue_clear_profile_events(
+    iree_hal_amdgpu_host_queue_t* queue) {
+  iree_slim_mutex_lock(&queue->profiling.event_mutex);
+  queue->profiling.dispatch_events.read_position = 0;
+  queue->profiling.dispatch_events.ready_position = 0;
+  queue->profiling.dispatch_events.write_position = 0;
+  queue->profiling.dispatch_events.next_event_id = 1;
+  queue->profiling.queue_device_events.read_position = 0;
+  queue->profiling.queue_device_events.ready_position = 0;
+  queue->profiling.queue_device_events.write_position = 0;
+  queue->profiling.queue_device_events.next_event_id = 1;
+  if (queue->profiling.event_allocation.base) {
+    memset(queue->profiling.event_allocation.base, 0,
+           queue->profiling.event_allocation.size);
+  }
+  iree_slim_mutex_unlock(&queue->profiling.event_mutex);
+}
+
+void iree_hal_amdgpu_host_queue_deallocate_profile_events(
+    iree_hal_amdgpu_host_queue_t* queue) {
+  if (!queue->profiling.event_allocation.base) return;
+  IREE_TRACE_ZONE_BEGIN(z0);
+  iree_hal_amdgpu_hsa_cleanup_assert_success(iree_hsa_amd_memory_pool_free_raw(
+      queue->libhsa, queue->profiling.event_allocation.base));
+  queue->profiling.event_allocation.base = NULL;
+  queue->profiling.event_allocation.size = 0;
+  queue->profiling.dispatch_events.values = NULL;
+  queue->profiling.dispatch_events.capacity = 0;
+  queue->profiling.dispatch_events.mask = 0;
+  queue->profiling.dispatch_events.read_position = 0;
+  queue->profiling.dispatch_events.ready_position = 0;
+  queue->profiling.dispatch_events.write_position = 0;
+  queue->profiling.dispatch_events.next_event_id = 0;
+  queue->profiling.queue_device_events.values = NULL;
+  queue->profiling.queue_device_events.capacity = 0;
+  queue->profiling.queue_device_events.mask = 0;
+  queue->profiling.queue_device_events.read_position = 0;
+  queue->profiling.queue_device_events.ready_position = 0;
+  queue->profiling.queue_device_events.write_position = 0;
+  queue->profiling.queue_device_events.next_event_id = 0;
+  IREE_TRACE_ZONE_END(z0);
+}
+
+iree_status_t iree_hal_amdgpu_host_queue_reserve_profile_dispatch_events(
+    iree_hal_amdgpu_host_queue_t* queue, uint32_t event_count,
+    iree_hal_amdgpu_profile_dispatch_event_reservation_t* out_reservation) {
+  *out_reservation = (iree_hal_amdgpu_profile_dispatch_event_reservation_t){0};
+  if (event_count == 0 || !queue->profiling.hsa_queue_timestamps_enabled ||
+      !queue->profiling.dispatch_events.values) {
+    return iree_ok_status();
+  }
+  if (IREE_UNLIKELY(event_count > queue->profiling.dispatch_events.capacity)) {
+    return iree_make_status(IREE_STATUS_RESOURCE_EXHAUSTED,
+                            "dispatch profiling reservation of %" PRIu32
+                            " events exceeds queue capacity %" PRIu32,
+                            event_count,
+                            queue->profiling.dispatch_events.capacity);
+  }
+
+  bool is_exhausted = false;
+  uint64_t exhausted_available_count = 0;
+  uint64_t exhausted_ready_count = 0;
+  iree_slim_mutex_lock(&queue->profiling.event_mutex);
+  const uint64_t read_position = queue->profiling.dispatch_events.read_position;
+  const uint64_t ready_position =
+      queue->profiling.dispatch_events.ready_position;
+  const uint64_t write_position =
+      queue->profiling.dispatch_events.write_position;
+  const uint64_t occupied_count = write_position - read_position;
+  const uint64_t available_count =
+      queue->profiling.dispatch_events.capacity - occupied_count;
+  if (event_count <= available_count) {
+    out_reservation->first_event_position = write_position;
+    out_reservation->event_count = event_count;
+    queue->profiling.dispatch_events.write_position =
+        write_position + event_count;
+    for (uint32_t i = 0; i < event_count; ++i) {
+      iree_hal_amdgpu_profile_dispatch_event_t* event =
+          iree_hal_amdgpu_host_queue_profile_dispatch_event_at(
+              queue, write_position + i);
+      memset(event, 0, sizeof(*event));
+      event->record_length = sizeof(*event);
+      event->event_id = queue->profiling.dispatch_events.next_event_id++;
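+      // UINT32_MAX distinguishes not-yet-populated index fields from a valid
+      // ordinal of 0.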
+      event->command_index = UINT32_MAX;
+      event->export_ordinal = UINT32_MAX;
+    }
+  } else {
+    is_exhausted = true;
+    exhausted_available_count = available_count;
+    exhausted_ready_count = ready_position - read_position;
+  }
+  iree_slim_mutex_unlock(&queue->profiling.event_mutex);
+  if (IREE_UNLIKELY(is_exhausted)) {
+    return iree_make_status(
+        IREE_STATUS_RESOURCE_EXHAUSTED,
+        "dispatch profiling event ring exhausted: requested %" PRIu32
+        " events, available %" PRIu64 ", ready %" PRIu64 ", capacity %" PRIu32,
+        event_count, exhausted_available_count, exhausted_ready_count,
+        queue->profiling.dispatch_events.capacity);
+  }
+  return iree_ok_status();
+}
+
+void iree_hal_amdgpu_host_queue_cancel_profile_dispatch_events(
+    iree_hal_amdgpu_host_queue_t* queue,
+    iree_hal_amdgpu_profile_dispatch_event_reservation_t reservation) {
+  if (!reservation.event_count) return;
+  iree_slim_mutex_lock(&queue->profiling.event_mutex);
+  queue->profiling.dispatch_events.write_position =
+      reservation.first_event_position;
+  queue->profiling.dispatch_events.next_event_id -= reservation.event_count;
+  iree_slim_mutex_unlock(&queue->profiling.event_mutex);
+}
+
+iree_hal_amdgpu_profile_dispatch_event_t*
+iree_hal_amdgpu_host_queue_profile_dispatch_event_at(
+    const iree_hal_amdgpu_host_queue_t* queue, uint64_t event_position) {
+  const uint32_t event_index =
+      (uint32_t)(event_position & queue->profiling.dispatch_events.mask);
+  return &queue->profiling.dispatch_events.values[event_index];
+}
+
+void iree_hal_amdgpu_host_queue_retire_profile_dispatch_events(
+    iree_hal_amdgpu_host_queue_t* queue,
+    iree_hal_amdgpu_profile_dispatch_event_reservation_t reservation) {
+  if (!reservation.event_count) return;
+  iree_slim_mutex_lock(&queue->profiling.event_mutex);
+  queue->profiling.dispatch_events.ready_position =
+      reservation.first_event_position + reservation.event_count;
+  iree_slim_mutex_unlock(&queue->profiling.event_mutex);
+}
+
+bool iree_hal_amdgpu_host_queue_should_profile_queue_device_events(
+    const iree_hal_amdgpu_host_queue_t* queue) {
+  return queue->profiling.queue_device_events_enabled &&
+         queue->profiling.queue_device_events.values;
+}
+
+iree_status_t iree_hal_amdgpu_host_queue_reserve_profile_queue_device_events(
+    iree_hal_amdgpu_host_queue_t* queue, uint32_t event_count,
+    iree_hal_amdgpu_profile_queue_device_event_reservation_t* out_reservation) {
+  *out_reservation =
+      (iree_hal_amdgpu_profile_queue_device_event_reservation_t){0};
+  if (event_count == 0 ||
+      !iree_hal_amdgpu_host_queue_should_profile_queue_device_events(queue)) {
+    return iree_ok_status();
+  }
+  if (IREE_UNLIKELY(!queue->pm4_ib_slots)) {
+    return iree_make_status(
+        IREE_STATUS_FAILED_PRECONDITION,
+        "queue device profiling requires queue-local PM4 IB slots");
+  }
+  if (IREE_UNLIKELY(
+          !iree_hal_amdgpu_vendor_packet_capabilities_support_timestamp_range(
+              queue->vendor_packet_capabilities))) {
+    return iree_make_status(
+        IREE_STATUS_FAILED_PRECONDITION,
+        "queue device profiling requires PM4 timestamp range support");
+  }
+  if (IREE_UNLIKELY(event_count >
+                    queue->profiling.queue_device_events.capacity)) {
+    return iree_make_status(IREE_STATUS_RESOURCE_EXHAUSTED,
+                            "queue device profiling reservation of %" PRIu32
+                            " events exceeds queue capacity %" PRIu32,
+                            event_count,
+                            queue->profiling.queue_device_events.capacity);
+  }
+
+  bool is_exhausted = false;
+  uint64_t exhausted_available_count = 0;
+  uint64_t exhausted_ready_count = 0;
+  iree_slim_mutex_lock(&queue->profiling.event_mutex);
+  const uint64_t read_position =
+      queue->profiling.queue_device_events.read_position;
+  const uint64_t ready_position =
+      queue->profiling.queue_device_events.ready_position;
+  const uint64_t write_position =
+      queue->profiling.queue_device_events.write_position;
+  const uint64_t occupied_count = write_position - read_position;
+  const uint64_t available_count =
+      queue->profiling.queue_device_events.capacity - occupied_count;
+  if (event_count <= available_count) {
+    out_reservation->first_event_position = write_position;
+    out_reservation->event_count = event_count;
+    queue->profiling.queue_device_events.write_position =
+        write_position + event_count;
+    for (uint32_t i = 0; i < event_count; ++i) {
+      iree_hal_amdgpu_profile_queue_device_event_t* event =
+          iree_hal_amdgpu_host_queue_profile_queue_device_event_at(
+              queue, write_position + i);
+      memset(event, 0, sizeof(*event));
+      event->record_length = sizeof(*event);
+      event->event_id = queue->profiling.queue_device_events.next_event_id++;
+      event->physical_device_ordinal =
+          iree_hal_amdgpu_host_queue_profile_device_ordinal(queue);
+      event->queue_ordinal =
+          iree_hal_amdgpu_host_queue_profile_queue_ordinal(queue);
+      event->stream_id = iree_hal_amdgpu_host_queue_profile_stream_id(queue);
+    }
+  } else {
+    is_exhausted = true;
+    exhausted_available_count = available_count;
+    exhausted_ready_count = ready_position - read_position;
+  }
+  iree_slim_mutex_unlock(&queue->profiling.event_mutex);
+  if (IREE_UNLIKELY(is_exhausted)) {
+    return iree_make_status(
+        IREE_STATUS_RESOURCE_EXHAUSTED,
+        "queue device profiling event ring exhausted: requested %" PRIu32
+        " events, available %" PRIu64 ", ready %" PRIu64 ", capacity %" PRIu32,
+        event_count, exhausted_available_count, exhausted_ready_count,
+        queue->profiling.queue_device_events.capacity);
+  }
+  return iree_ok_status();
+}
+
+void iree_hal_amdgpu_host_queue_cancel_profile_queue_device_events(
+    iree_hal_amdgpu_host_queue_t* queue,
+    iree_hal_amdgpu_profile_queue_device_event_reservation_t reservation) {
+  if (!reservation.event_count) return;
+  iree_slim_mutex_lock(&queue->profiling.event_mutex);
+  queue->profiling.queue_device_events.write_position =
+      reservation.first_event_position;
+  queue->profiling.queue_device_events.next_event_id -=
+      reservation.event_count;
+  iree_slim_mutex_unlock(&queue->profiling.event_mutex);
+}
+
+iree_hal_amdgpu_profile_queue_device_event_t*
+iree_hal_amdgpu_host_queue_profile_queue_device_event_at(
+    const iree_hal_amdgpu_host_queue_t* queue, uint64_t event_position) {
+  const uint32_t event_index =
+      (uint32_t)(event_position & queue->profiling.queue_device_events.mask);
+  return &queue->profiling.queue_device_events.values[event_index];
+}
+
+void iree_hal_amdgpu_host_queue_retire_profile_queue_device_events(
+    iree_hal_amdgpu_host_queue_t* queue,
+    iree_hal_amdgpu_profile_queue_device_event_reservation_t reservation) {
+  if (!reservation.event_count) return;
+  iree_slim_mutex_lock(&queue->profiling.event_mutex);
+  queue->profiling.queue_device_events.ready_position =
+      reservation.first_event_position + reservation.event_count;
+  iree_slim_mutex_unlock(&queue->profiling.event_mutex);
+}
+
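+// Copies |event_count| ready dispatch events starting at |read_position| into
+// freshly allocated host storage. Caller must hold profiling.event_mutex.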
+static iree_status_t iree_hal_amdgpu_host_queue_copy_dispatch_events(
+    iree_hal_amdgpu_host_queue_t* queue, uint64_t read_position,
+    iree_host_size_t event_count, iree_host_size_t* out_storage_size,
+    iree_hal_profile_dispatch_event_t** out_events) {
+  *out_storage_size = 0;
+  *out_events = NULL;
+  if (event_count == 0) return iree_ok_status();
+
+  IREE_RETURN_IF_ERROR(IREE_STRUCT_LAYOUT(
+      0, out_storage_size,
+      IREE_STRUCT_FIELD(event_count, iree_hal_profile_dispatch_event_t, NULL)));
+  IREE_RETURN_IF_ERROR(iree_allocator_malloc(
+      queue->host_allocator, *out_storage_size, (void**)out_events));
+  for (iree_host_size_t i = 0; i < event_count; ++i) {
+    const iree_hal_amdgpu_profile_dispatch_event_t* source =
+        iree_hal_amdgpu_host_queue_profile_dispatch_event_at(queue,
+                                                             read_position + i);
+    iree_hal_profile_dispatch_event_t* target = &(*out_events)[i];
+    target->record_length = source->record_length;
+    target->flags = source->flags;
+    target->event_id = source->event_id;
+    target->submission_id = source->submission_id;
+    target->command_buffer_id = source->command_buffer_id;
+    target->executable_id = source->executable_id;
+    target->command_index = source->command_index;
+    target->export_ordinal = source->export_ordinal;
+    memcpy(target->workgroup_count, source->workgroup_count,
+           sizeof(target->workgroup_count));
+    memcpy(target->workgroup_size, source->workgroup_size,
+           sizeof(target->workgroup_size));
+    target->start_tick = source->start_tick;
+    target->end_tick = source->end_tick;
+  }
+  return iree_ok_status();
+}
+
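+// Copies |event_count| ready queue device events starting at |read_position|
+// into freshly allocated host storage. Caller must hold profiling.event_mutex.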
+static iree_status_t iree_hal_amdgpu_host_queue_copy_queue_device_events(
+    iree_hal_amdgpu_host_queue_t* queue, uint64_t read_position,
+    iree_host_size_t event_count, iree_host_size_t* out_storage_size,
+    iree_hal_profile_queue_device_event_t** out_events) {
+  *out_storage_size = 0;
+  *out_events = NULL;
+  if (event_count == 0) return iree_ok_status();
+
+  IREE_RETURN_IF_ERROR(IREE_STRUCT_LAYOUT(
+      0, out_storage_size,
+      IREE_STRUCT_FIELD(event_count, iree_hal_profile_queue_device_event_t,
+                        NULL)));
+  IREE_RETURN_IF_ERROR(iree_allocator_malloc(
+      queue->host_allocator, *out_storage_size, (void**)out_events));
+  for (iree_host_size_t i = 0; i < event_count; ++i) {
+    const iree_hal_amdgpu_profile_queue_device_event_t* source =
+        iree_hal_amdgpu_host_queue_profile_queue_device_event_at(
+            queue, read_position + i);
+    memcpy(&(*out_events)[i], source, sizeof((*out_events)[i]));
+  }
+  return iree_ok_status();
+}
+
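+// Builds submission->event relationship records over the copied event
+// snapshots; events with a submission_id of 0 are skipped. Runs without the
+// event mutex as it only touches the snapshots.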
+static iree_status_t iree_hal_amdgpu_host_queue_build_event_relationships(
+    iree_hal_amdgpu_host_queue_t* queue,
+    const iree_hal_profile_dispatch_event_t* dispatch_events,
+    iree_host_size_t dispatch_event_count,
+    const iree_hal_profile_queue_device_event_t* queue_device_events,
+    iree_host_size_t queue_device_event_count, uint64_t stream_id,
+    uint32_t physical_device_ordinal, uint32_t queue_ordinal,
+    iree_host_size_t* out_storage_size,
+    iree_hal_profile_event_relationship_record_t** out_relationships) {
+  *out_storage_size = 0;
+  *out_relationships = NULL;
+
+  const iree_host_size_t max_relationship_count =
+      dispatch_event_count + queue_device_event_count;
+  if (max_relationship_count == 0) return iree_ok_status();
+
+  IREE_RETURN_IF_ERROR(IREE_STRUCT_LAYOUT(
+      0, out_storage_size,
+      IREE_STRUCT_FIELD(max_relationship_count,
+                        iree_hal_profile_event_relationship_record_t, NULL)));
+  IREE_RETURN_IF_ERROR(iree_allocator_malloc(
+      queue->host_allocator, *out_storage_size, (void**)out_relationships));
+
+  iree_host_size_t relationship_count = 0;
+  for (iree_host_size_t i = 0; i < dispatch_event_count; ++i) {
+    const iree_hal_profile_dispatch_event_t* event = &dispatch_events[i];
+    if (event->submission_id == 0) continue;
+    iree_hal_profile_event_relationship_record_t* relationship =
+        &(*out_relationships)[relationship_count++];
+    *relationship = iree_hal_profile_event_relationship_record_default();
+    relationship->type =
+        IREE_HAL_PROFILE_EVENT_RELATIONSHIP_TYPE_QUEUE_SUBMISSION_DISPATCH;
+    relationship->relationship_id = relationship_count;
+    relationship->source_type =
+        IREE_HAL_PROFILE_EVENT_ENDPOINT_TYPE_QUEUE_SUBMISSION;
+    relationship->target_type =
+        IREE_HAL_PROFILE_EVENT_ENDPOINT_TYPE_DISPATCH_EVENT;
+    relationship->physical_device_ordinal = physical_device_ordinal;
+    relationship->queue_ordinal = queue_ordinal;
+    relationship->stream_id = stream_id;
+    relationship->source_id = event->submission_id;
+    relationship->target_id = event->event_id;
+  }
+
+  for (iree_host_size_t i = 0; i < queue_device_event_count; ++i) {
+    const iree_hal_profile_queue_device_event_t* event =
+        &queue_device_events[i];
+    if (event->submission_id == 0) continue;
+    iree_hal_profile_event_relationship_record_t* relationship =
+        &(*out_relationships)[relationship_count++];
+    *relationship = iree_hal_profile_event_relationship_record_default();
+    relationship->type =
+        IREE_HAL_PROFILE_EVENT_RELATIONSHIP_TYPE_QUEUE_SUBMISSION_QUEUE_DEVICE_EVENT;
+    relationship->relationship_id = relationship_count;
+    relationship->source_type =
+        IREE_HAL_PROFILE_EVENT_ENDPOINT_TYPE_QUEUE_SUBMISSION;
+    relationship->target_type =
+        IREE_HAL_PROFILE_EVENT_ENDPOINT_TYPE_QUEUE_DEVICE_EVENT;
+    relationship->physical_device_ordinal = physical_device_ordinal;
+    relationship->queue_ordinal = queue_ordinal;
+    relationship->stream_id = stream_id;
+    relationship->source_id = event->submission_id;
+    relationship->target_id = event->event_id;
+  }
+
+  *out_storage_size =
+      relationship_count * sizeof(iree_hal_profile_event_relationship_record_t);
+  return iree_ok_status();
+}
+
+iree_status_t iree_hal_amdgpu_host_queue_write_profile_events(
+    iree_hal_amdgpu_host_queue_t* queue, iree_hal_profile_sink_t* sink,
+    uint64_t session_id) {
+  if (!sink) return iree_ok_status();
+  if (IREE_UNLIKELY(queue->device_ordinal > UINT32_MAX)) {
+    return iree_make_status(IREE_STATUS_OUT_OF_RANGE,
+                            "profile event physical device ordinal out of "
+                            "range: %" PRIhsz,
+                            queue->device_ordinal);
+  }
+
+  IREE_TRACE_ZONE_BEGIN(z0);
+
+  iree_hal_profile_dispatch_event_t* dispatch_events = NULL;
+  iree_host_size_t dispatch_event_storage_size = 0;
+  iree_hal_profile_queue_device_event_t* queue_device_events = NULL;
+  iree_host_size_t queue_device_event_storage_size = 0;
+  iree_hal_profile_event_relationship_record_t* relationships = NULL;
+  iree_host_size_t relationship_storage_size = 0;
+  iree_host_size_t dispatch_event_count = 0;
+  iree_host_size_t queue_device_event_count = 0;
+
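+  // Snapshot ready events under the lock; sink writes happen unlocked and the
+  // read cursors only advance after all writes succeed so a failed flush can
+  // be retried without losing events.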
+  iree_slim_mutex_lock(&queue->profiling.event_mutex);
+  const uint64_t dispatch_read_position =
+      queue->profiling.dispatch_events.read_position;
+  const uint64_t dispatch_ready_position =
+      queue->profiling.dispatch_events.ready_position;
+  const uint64_t queue_device_read_position =
+      queue->profiling.queue_device_events.read_position;
+  const uint64_t queue_device_ready_position =
+      queue->profiling.queue_device_events.ready_position;
+  dispatch_event_count =
+      (iree_host_size_t)(dispatch_ready_position - dispatch_read_position);
+  queue_device_event_count = (iree_host_size_t)(queue_device_ready_position -
+                                                queue_device_read_position);
+  iree_status_t status = iree_hal_amdgpu_host_queue_copy_dispatch_events(
+      queue, dispatch_read_position, dispatch_event_count,
+      &dispatch_event_storage_size, &dispatch_events);
+  if (iree_status_is_ok(status)) {
+    status = iree_hal_amdgpu_host_queue_copy_queue_device_events(
+        queue, queue_device_read_position, queue_device_event_count,
+        &queue_device_event_storage_size, &queue_device_events);
+  }
+  iree_slim_mutex_unlock(&queue->profiling.event_mutex);
+
+  const bool has_events =
+      dispatch_event_count != 0 || queue_device_event_count != 0;
+  if (iree_status_is_ok(status) && has_events) {
+    iree_hal_profile_chunk_metadata_t metadata =
+        iree_hal_profile_chunk_metadata_default();
+    const uint32_t physical_device_ordinal = (uint32_t)queue->device_ordinal;
+    const uint32_t queue_ordinal = iree_async_axis_queue_index(queue->axis);
+    metadata.content_type = IREE_HAL_PROFILE_CONTENT_TYPE_DISPATCH_EVENTS;
+    metadata.name = iree_make_cstring_view("amdgpu.dispatch");
+    metadata.session_id = session_id;
+    metadata.stream_id =
+        ((uint64_t)physical_device_ordinal << 32) | (uint64_t)queue_ordinal;
+    metadata.physical_device_ordinal = physical_device_ordinal;
+    metadata.queue_ordinal = queue_ordinal;
+
+    status = iree_hal_amdgpu_host_queue_build_event_relationships(
+        queue, dispatch_events, dispatch_event_count, queue_device_events,
+        queue_device_event_count, metadata.stream_id, physical_device_ordinal,
+        queue_ordinal, &relationship_storage_size, &relationships);
+
+    if (iree_status_is_ok(status) && dispatch_event_count != 0) {
+      iree_const_byte_span_t iovec = iree_make_const_byte_span(
+          dispatch_events, dispatch_event_storage_size);
+      status = iree_hal_profile_sink_write(sink, &metadata, 1, &iovec);
+    }
+    if (iree_status_is_ok(status) && queue_device_event_count != 0) {
+      metadata.content_type = IREE_HAL_PROFILE_CONTENT_TYPE_QUEUE_DEVICE_EVENTS;
+      metadata.name = iree_make_cstring_view("amdgpu.queue_device");
+      iree_const_byte_span_t iovec = iree_make_const_byte_span(
+          queue_device_events, queue_device_event_storage_size);
+      status = iree_hal_profile_sink_write(sink, &metadata, 1, &iovec);
+    }
+    if (iree_status_is_ok(status) && relationship_storage_size != 0) {
+      metadata.content_type = IREE_HAL_PROFILE_CONTENT_TYPE_EVENT_RELATIONSHIPS;
+      metadata.name = iree_make_cstring_view("amdgpu.relationships");
+      iree_const_byte_span_t relationship_iovec =
+          iree_make_const_byte_span(relationships, relationship_storage_size);
+      status =
+          iree_hal_profile_sink_write(sink, &metadata, 1, &relationship_iovec);
+    }
+    if (iree_status_is_ok(status)) {
+      status = iree_hal_amdgpu_host_queue_write_profile_counter_samples(
+          queue, sink, session_id, dispatch_read_position, dispatch_event_count,
+          dispatch_events);
+    }
+    if (iree_status_is_ok(status)) {
+      status = iree_hal_amdgpu_host_queue_write_profile_traces(
+          queue, sink, session_id, dispatch_read_position, dispatch_event_count,
+          dispatch_events);
+    }
+  }
+
+  if (iree_status_is_ok(status) && has_events) {
+    iree_hal_amdgpu_host_queue_release_profile_trace_slots(
+        queue, dispatch_read_position, dispatch_event_count);
+    iree_slim_mutex_lock(&queue->profiling.event_mutex);
+    queue->profiling.dispatch_events.read_position =
+        dispatch_read_position + dispatch_event_count;
+    queue->profiling.queue_device_events.read_position =
+        queue_device_read_position + queue_device_event_count;
+    iree_slim_mutex_unlock(&queue->profiling.event_mutex);
+  }
+
+  iree_allocator_free(queue->host_allocator, dispatch_events);
+  iree_allocator_free(queue->host_allocator, queue_device_events);
+  iree_allocator_free(queue->host_allocator, relationships);
+  IREE_TRACE_ZONE_END(z0);
+  return status;
+}
diff --git a/runtime/src/iree/hal/drivers/amdgpu/host_queue_profile_events.h b/runtime/src/iree/hal/drivers/amdgpu/host_queue_profile_events.h
new file mode 100644
index 0000000..9f7268a
--- /dev/null
+++ b/runtime/src/iree/hal/drivers/amdgpu/host_queue_profile_events.h
@@ -0,0 +1,158 @@
+// Copyright 2026 The IREE Authors
+//
+// Licensed under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+
+#ifndef IREE_HAL_DRIVERS_AMDGPU_HOST_QUEUE_PROFILE_EVENTS_H_
+#define IREE_HAL_DRIVERS_AMDGPU_HOST_QUEUE_PROFILE_EVENTS_H_
+
+#include "iree/hal/drivers/amdgpu/host_queue.h"
+
+#ifdef __cplusplus
+extern "C" {
+#endif  // __cplusplus
+
+// Allocates queue-local device-visible event rings used by dispatch and
+// queue-device timestamp profiling.
+iree_status_t iree_hal_amdgpu_host_queue_ensure_profile_event_storage(
+    iree_hal_amdgpu_host_queue_t* queue);
+
+// Clears all event-ring cursors and records while preserving allocated storage.
+void iree_hal_amdgpu_host_queue_clear_profile_events(
+    iree_hal_amdgpu_host_queue_t* queue);
+
+// Releases queue-local profile event ring storage.
+void iree_hal_amdgpu_host_queue_deallocate_profile_events(
+    iree_hal_amdgpu_host_queue_t* queue);
+
+// Allocates queue-local completion signals paired with dispatch event slots.
+iree_status_t iree_hal_amdgpu_host_queue_ensure_profiling_completion_signals(
+    iree_hal_amdgpu_host_queue_t* queue);
+
+// Releases queue-local profiling completion signals.
+void iree_hal_amdgpu_host_queue_deallocate_profiling_completion_signals(
+    iree_hal_amdgpu_host_queue_t* queue);
+
+// Returns the configured dispatch event ring capacity.
+static inline uint32_t
+iree_hal_amdgpu_host_queue_profile_dispatch_event_capacity(
+    const iree_hal_amdgpu_host_queue_t* queue) {
+  return queue->profiling.dispatch_events.capacity;
+}
+
+// Returns the dispatch event ring slot index for |event_position|.
+static inline uint32_t iree_hal_amdgpu_host_queue_profile_dispatch_event_index(
+    const iree_hal_amdgpu_host_queue_t* queue, uint64_t event_position) {
+  return (uint32_t)(event_position & queue->profiling.dispatch_events.mask);
+}
+
+// Returns the raw profiling completion signal paired with |event_position|'s
+// dispatch event ring slot. The returned pointer references queue-owned
+// iree_amd_signal_t storage, not a ROCR-created HSA signal, and must never be
+// passed to host signal APIs except as an AQL packet completion_signal handle.
+// Valid only while HSA queue timestamp profiling is enabled.
+static inline iree_amd_signal_t*
+iree_hal_amdgpu_host_queue_profiling_completion_signal_ptr(
+    const iree_hal_amdgpu_host_queue_t* queue, uint64_t event_position) {
+  const uint32_t signal_index =
+      iree_hal_amdgpu_host_queue_profile_dispatch_event_index(queue,
+                                                              event_position);
+  const uint32_t block_index =
+      signal_index / queue->profiling.signals.signals_per_block;
+  const uint32_t block_signal_index =
+      signal_index - block_index * queue->profiling.signals.signals_per_block;
+  uint8_t* block_ptr =
+      (uint8_t*)queue->profiling.signals.blocks[block_index]->ptr;
+  iree_amd_signal_t* signal =
+      (iree_amd_signal_t*)(block_ptr +
+                           block_signal_index * sizeof(iree_amd_signal_t));
+  return signal;
+}
+
+// Returns the raw profiling completion signal handle paired with
+// |event_position|'s dispatch event ring slot.
+static inline iree_hsa_signal_t
+iree_hal_amdgpu_host_queue_profiling_completion_signal(
+    const iree_hal_amdgpu_host_queue_t* queue, uint64_t event_position) {
+  iree_amd_signal_t* signal =
+      iree_hal_amdgpu_host_queue_profiling_completion_signal_ptr(
+          queue, event_position);
+  return (iree_hsa_signal_t){.handle = (uint64_t)(uintptr_t)signal};
+}
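+
+// A minimal usage sketch (illustrative; the AQL dispatch packet type is
+// assumed from the HSA spec, not defined by this driver):
+//   hsa_kernel_dispatch_packet_t* packet = ...;  // reserved queue slot
+//   packet->completion_signal.handle =
+//       iree_hal_amdgpu_host_queue_profiling_completion_signal(
+//           queue, event_position).handle;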
+
+// Reserves queue-local dispatch profile event records.
+//
+// Caller must hold submission_mutex. If the ring cannot hold |event_count|
+// records, the function fails with RESOURCE_EXHAUSTED. Dispatch events form a
+// precise execution timeline, so AMDGPU does not drop them under pressure.
+// Callers that want to keep long captures exact must drain with
+// iree_hal_device_profiling_flush before the ring fills. See the lifecycle
+// sketch below.
+iree_status_t iree_hal_amdgpu_host_queue_reserve_profile_dispatch_events(
+    iree_hal_amdgpu_host_queue_t* queue, uint32_t event_count,
+    iree_hal_amdgpu_profile_dispatch_event_reservation_t* out_reservation);
+
+// Cancels a tail reservation before its packets have been published.
+//
+// Caller must hold submission_mutex. Only valid for the most recent successful
+// reservation on a path that is failing before AQL publication.
+void iree_hal_amdgpu_host_queue_cancel_profile_dispatch_events(
+    iree_hal_amdgpu_host_queue_t* queue,
+    iree_hal_amdgpu_profile_dispatch_event_reservation_t reservation);
+
+// Returns the dispatch event record at |event_position|.
+iree_hal_amdgpu_profile_dispatch_event_t*
+iree_hal_amdgpu_host_queue_profile_dispatch_event_at(
+    const iree_hal_amdgpu_host_queue_t* queue, uint64_t event_position);
+
+// Marks a completed event reservation ready for sink flush.
+void iree_hal_amdgpu_host_queue_retire_profile_dispatch_events(
+    iree_hal_amdgpu_host_queue_t* queue,
+    iree_hal_amdgpu_profile_dispatch_event_reservation_t reservation);
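+
+// A typical reservation lifecycle (illustrative; error handling elided), with
+// the caller holding submission_mutex:
+//   iree_hal_amdgpu_profile_dispatch_event_reservation_t reservation;
+//   IREE_RETURN_IF_ERROR(
+//       iree_hal_amdgpu_host_queue_reserve_profile_dispatch_events(
+//           queue, dispatch_count, &reservation));
+//   // ... populate the reserved records and publish the AQL packets ...
+// Once the reserved dispatches complete, the completion path marks them ready
+// for sink flush with
+// iree_hal_amdgpu_host_queue_retire_profile_dispatch_events; a path failing
+// before AQL publication instead rolls the tail reservation back with
+// iree_hal_amdgpu_host_queue_cancel_profile_dispatch_events.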
+
+// Returns true when queue device timestamp packets should be emitted.
+bool iree_hal_amdgpu_host_queue_should_profile_queue_device_events(
+    const iree_hal_amdgpu_host_queue_t* queue);
+
+// Reserves queue-local device-timestamped queue operation records.
+//
+// Caller must hold submission_mutex. If the ring cannot hold |event_count|
+// records, the function fails with RESOURCE_EXHAUSTED. Queue-device events
+// form a precise execution timeline, so AMDGPU does not drop them under
+// pressure. The returned records live in device-visible memory so PM4 packets
+// can write timestamp fields directly.
+iree_status_t iree_hal_amdgpu_host_queue_reserve_profile_queue_device_events(
+    iree_hal_amdgpu_host_queue_t* queue, uint32_t event_count,
+    iree_hal_amdgpu_profile_queue_device_event_reservation_t* out_reservation);
+
+// Cancels a tail queue-device-event reservation before AQL publication.
+//
+// Caller must hold submission_mutex. Only valid for the most recent successful
+// reservation on a path that is failing before AQL publication.
+void iree_hal_amdgpu_host_queue_cancel_profile_queue_device_events(
+    iree_hal_amdgpu_host_queue_t* queue,
+    iree_hal_amdgpu_profile_queue_device_event_reservation_t reservation);
+
+// Returns the queue device event record at |event_position|.
+iree_hal_amdgpu_profile_queue_device_event_t*
+iree_hal_amdgpu_host_queue_profile_queue_device_event_at(
+    const iree_hal_amdgpu_host_queue_t* queue, uint64_t event_position);
+
+// Marks a completed queue-device-event reservation ready for sink flush.
+void iree_hal_amdgpu_host_queue_retire_profile_queue_device_events(
+    iree_hal_amdgpu_host_queue_t* queue,
+    iree_hal_amdgpu_profile_queue_device_event_reservation_t reservation);
+
+// Writes and clears buffered device-side profile events for this queue.
+//
+// Sink writes are cold profiling API operations and may block. The submission
+// and completion paths only append to the queue-local batch.
+iree_status_t iree_hal_amdgpu_host_queue_write_profile_events(
+    iree_hal_amdgpu_host_queue_t* queue, iree_hal_profile_sink_t* sink,
+    uint64_t session_id);
+
+#ifdef __cplusplus
+}  // extern "C"
+#endif  // __cplusplus
+
+#endif  // IREE_HAL_DRIVERS_AMDGPU_HOST_QUEUE_PROFILE_EVENTS_H_
diff --git a/runtime/src/iree/hal/drivers/amdgpu/host_queue_staging.c b/runtime/src/iree/hal/drivers/amdgpu/host_queue_staging.c
new file mode 100644
index 0000000..ae42f5e
--- /dev/null
+++ b/runtime/src/iree/hal/drivers/amdgpu/host_queue_staging.c
@@ -0,0 +1,1321 @@
+// Copyright 2026 The IREE Authors
+//
+// Licensed under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+
+#include "iree/hal/drivers/amdgpu/host_queue_staging.h"
+
+#include <string.h>
+
+#include "iree/async/operations/file.h"
+#include "iree/hal/drivers/amdgpu/access_policy.h"
+#include "iree/hal/drivers/amdgpu/buffer.h"
+#include "iree/hal/drivers/amdgpu/host_queue.h"
+#include "iree/hal/drivers/amdgpu/host_queue_blit.h"
+#include "iree/hal/drivers/amdgpu/host_queue_profile.h"
+#include "iree/hal/drivers/amdgpu/host_queue_submission.h"
+#include "iree/hal/drivers/amdgpu/slab_provider.h"
+
+typedef struct iree_hal_amdgpu_staging_slot_t {
+  // Slot ordinal in the physical-device staging pool.
+  uint32_t ordinal;
+  // Byte offset of the slot inside the staging buffer.
+  iree_device_size_t buffer_offset;
+  // Host-accessible bytes for file I/O.
+  iree_byte_span_t host_span;
+  // HAL buffer wrapping the complete staging allocation. Borrowed from the
+  // pool; queue copy submissions retain it while the GPU owns the slot.
+  iree_hal_buffer_t* buffer;
+} iree_hal_amdgpu_staging_slot_t;
+
+typedef struct iree_hal_amdgpu_staging_allocation_t {
+  // Borrowed HSA API table used to free |allocation_base|.
+  const iree_hal_amdgpu_libhsa_t* libhsa;
+  // Host allocator used to allocate this release state.
+  iree_allocator_t host_allocator;
+  // Original pointer returned by hsa_amd_memory_pool_allocate.
+  void* allocation_base;
+} iree_hal_amdgpu_staging_allocation_t;
+
+typedef uint32_t iree_hal_amdgpu_staging_pool_waiter_flags_t;
+enum iree_hal_amdgpu_staging_pool_waiter_flag_bits_e {
+  IREE_HAL_AMDGPU_STAGING_POOL_WAITER_FLAG_NONE = 0u,
+  IREE_HAL_AMDGPU_STAGING_POOL_WAITER_FLAG_QUEUED = 1u << 0,
+};
+
+// Callback invoked when a staging slot may be available.
+typedef void(IREE_API_PTR* iree_hal_amdgpu_staging_pool_waiter_fn_t)(
+    void* user_data);
+
+// Intrusive waiter used by transfers that cannot acquire a staging slot.
+struct iree_hal_amdgpu_staging_pool_waiter_t {
+  // Next waiter in the pool-owned FIFO list.
+  iree_hal_amdgpu_staging_pool_waiter_t* next;
+  // Callback invoked after the waiter is removed from the FIFO list.
+  iree_hal_amdgpu_staging_pool_waiter_fn_t fn;
+  // User data passed to |fn|.
+  void* user_data;
+  // Slot reserved for this waiter when it is dequeued.
+  iree_hal_amdgpu_staging_slot_t slot;
+  // Waiter lifecycle flags from iree_hal_amdgpu_staging_pool_waiter_flags_t.
+  iree_hal_amdgpu_staging_pool_waiter_flags_t flags;
+};
+
+typedef enum iree_hal_amdgpu_staging_pool_wait_result_e {
+  // The waiter was newly queued and will receive a future callback.
+  IREE_HAL_AMDGPU_STAGING_POOL_WAIT_QUEUED = 0,
+  // The waiter was already queued by an earlier pump attempt.
+  IREE_HAL_AMDGPU_STAGING_POOL_WAIT_ALREADY_QUEUED = 1,
+  // A slot became available before the waiter was queued; retry acquire.
+  IREE_HAL_AMDGPU_STAGING_POOL_WAIT_RETRY = 2,
+} iree_hal_amdgpu_staging_pool_wait_result_t;
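+
+// Acquisition protocol (see iree_hal_amdgpu_staging_transfer_pump below for
+// the real consumer):
+//   1. Try iree_hal_amdgpu_staging_pool_try_acquire; on success use the slot.
+//   2. Otherwise call iree_hal_amdgpu_staging_pool_queue_waiter. RETRY means
+//      a slot was released between the two calls, so loop back to step 1.
+//   3. On QUEUED/ALREADY_QUEUED return; a released slot is parked directly in
+//      waiter->slot and reclaimed from the callback via
+//      iree_hal_amdgpu_staging_pool_take_waiter_slot.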
+
+void iree_hal_amdgpu_staging_pool_options_initialize(
+    iree_hal_amdgpu_staging_pool_options_t* out_options) {
+  IREE_ASSERT_ARGUMENT(out_options);
+  memset(out_options, 0, sizeof(*out_options));
+  out_options->slot_size = IREE_HAL_AMDGPU_STAGING_SLOT_SIZE_DEFAULT;
+  out_options->slot_count = IREE_HAL_AMDGPU_STAGING_SLOT_COUNT_DEFAULT;
+}
+
+iree_status_t iree_hal_amdgpu_staging_pool_options_verify(
+    const iree_hal_amdgpu_staging_pool_options_t* options) {
+  IREE_ASSERT_ARGUMENT(options);
+  if (options->slot_size == 0 ||
+      !iree_host_size_is_power_of_two(options->slot_size)) {
+    return iree_make_status(
+        IREE_STATUS_OUT_OF_RANGE,
+        "staging slot size must be a non-zero power of two (got %" PRIhsz ")",
+        options->slot_size);
+  }
+  if (options->slot_size < IREE_HAL_AMDGPU_STAGING_SLOT_ALIGNMENT) {
+    return iree_make_status(
+        IREE_STATUS_OUT_OF_RANGE,
+        "staging slot size must be at least %" PRIhsz
+        " bytes to preserve slot "
+        "alignment (got %" PRIhsz ")",
+        (iree_host_size_t)IREE_HAL_AMDGPU_STAGING_SLOT_ALIGNMENT,
+        options->slot_size);
+  }
+  if (options->slot_count == 0 ||
+      !iree_host_size_is_power_of_two(options->slot_count)) {
+    return iree_make_status(
+        IREE_STATUS_OUT_OF_RANGE,
+        "staging slot count must be a non-zero power of two (got %u)",
+        options->slot_count);
+  }
+  iree_host_size_t total_size = 0;
+  if (!iree_host_size_checked_mul(options->slot_size, options->slot_count,
+                                  &total_size) ||
+      total_size > (iree_host_size_t)IREE_DEVICE_SIZE_MAX) {
+    return iree_make_status(IREE_STATUS_OUT_OF_RANGE,
+                            "staging pool size overflows (slot_size=%" PRIhsz
+                            ", slot_count=%u)",
+                            options->slot_size, options->slot_count);
+  }
+  return iree_ok_status();
+}
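+
+// A minimal configuration sketch (values illustrative, not tuned):
+//   iree_hal_amdgpu_staging_pool_options_t options;
+//   iree_hal_amdgpu_staging_pool_options_initialize(&options);
+//   options.slot_size = 2 * 1024 * 1024;  // power of two >= slot alignment
+//   options.slot_count = 8;               // power of two
+//   IREE_RETURN_IF_ERROR(
+//       iree_hal_amdgpu_staging_pool_options_verify(&options));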
+
+static iree_status_t iree_hal_amdgpu_staging_pool_resolve_access_agents(
+    const iree_hal_amdgpu_topology_t* topology,
+    iree_hal_queue_affinity_t queue_affinity_mask,
+    iree_hal_amdgpu_access_agent_list_t* out_agent_list) {
+  const iree_hal_amdgpu_queue_affinity_domain_t domain = {
+      .supported_affinity = queue_affinity_mask,
+      .physical_device_count = topology->gpu_agent_count,
+      .queue_count_per_physical_device = topology->gpu_agent_queue_count,
+  };
+  return iree_hal_amdgpu_access_agent_list_resolve(
+      topology, domain, queue_affinity_mask, out_agent_list);
+}
+
+static void iree_hal_amdgpu_staging_allocation_release(
+    void* user_data, iree_hal_buffer_t* buffer) {
+  (void)buffer;
+  iree_hal_amdgpu_staging_allocation_t* allocation =
+      (iree_hal_amdgpu_staging_allocation_t*)user_data;
+  IREE_TRACE_ZONE_BEGIN(z0);
+  if (allocation->allocation_base) {
+    iree_hal_amdgpu_hsa_cleanup_assert_success(
+        iree_hsa_amd_memory_pool_free_raw(allocation->libhsa,
+                                          allocation->allocation_base));
+  }
+  iree_allocator_free(allocation->host_allocator, allocation);
+  IREE_TRACE_ZONE_END(z0);
+}
+
+static iree_hal_amdgpu_staging_pool_waiter_t*
+iree_hal_amdgpu_staging_pool_pop_waiter_locked(
+    iree_hal_amdgpu_staging_pool_t* pool,
+    const iree_hal_amdgpu_staging_slot_t* slot) {
+  iree_hal_amdgpu_staging_pool_waiter_t* waiter = pool->waiter_head;
+  if (waiter) {
+    pool->waiter_head = waiter->next;
+    if (!pool->waiter_head) {
+      pool->waiter_tail = NULL;
+    }
+    waiter->next = NULL;
+    waiter->slot = *slot;
+    waiter->flags &= ~IREE_HAL_AMDGPU_STAGING_POOL_WAITER_FLAG_QUEUED;
+  }
+  return waiter;
+}
+
+static bool iree_hal_amdgpu_staging_pool_try_acquire(
+    iree_hal_amdgpu_staging_pool_t* pool,
+    iree_hal_amdgpu_staging_slot_t* out_slot) {
+  bool did_acquire = false;
+  iree_slim_mutex_lock(&pool->mutex);
+  if (pool->available_count > 0) {
+    const uint32_t slot_ordinal =
+        pool->free_slots[pool->free_read++ & pool->slot_mask];
+    --pool->available_count;
+    out_slot->ordinal = slot_ordinal;
+    out_slot->buffer_offset =
+        (iree_device_size_t)slot_ordinal * pool->slot_size;
+    out_slot->host_span = iree_make_byte_span(
+        pool->host_base + (iree_host_size_t)out_slot->buffer_offset,
+        pool->slot_size);
+    out_slot->buffer = pool->buffer;
+    did_acquire = true;
+  }
+  iree_slim_mutex_unlock(&pool->mutex);
+  return did_acquire;
+}
+
+static bool iree_hal_amdgpu_staging_pool_take_waiter_slot(
+    iree_hal_amdgpu_staging_pool_t* pool,
+    iree_hal_amdgpu_staging_pool_waiter_t* waiter,
+    iree_hal_amdgpu_staging_slot_t* out_slot) {
+  bool did_take = false;
+  iree_slim_mutex_lock(&pool->mutex);
+  if (waiter->slot.buffer) {
+    *out_slot = waiter->slot;
+    memset(&waiter->slot, 0, sizeof(waiter->slot));
+    did_take = true;
+  }
+  iree_slim_mutex_unlock(&pool->mutex);
+  return did_take;
+}
+
+static iree_hal_amdgpu_staging_pool_wait_result_t
+iree_hal_amdgpu_staging_pool_queue_waiter(
+    iree_hal_amdgpu_staging_pool_t* pool,
+    iree_hal_amdgpu_staging_pool_waiter_t* waiter,
+    iree_hal_amdgpu_staging_pool_waiter_fn_t fn, void* user_data) {
+  iree_hal_amdgpu_staging_pool_wait_result_t result =
+      IREE_HAL_AMDGPU_STAGING_POOL_WAIT_QUEUED;
+  iree_slim_mutex_lock(&pool->mutex);
+  if (iree_any_bit_set(waiter->flags,
+                       IREE_HAL_AMDGPU_STAGING_POOL_WAITER_FLAG_QUEUED)) {
+    result = IREE_HAL_AMDGPU_STAGING_POOL_WAIT_ALREADY_QUEUED;
+  } else if (pool->available_count > 0) {
+    result = IREE_HAL_AMDGPU_STAGING_POOL_WAIT_RETRY;
+  } else {
+    waiter->next = NULL;
+    waiter->fn = fn;
+    waiter->user_data = user_data;
+    waiter->flags |= IREE_HAL_AMDGPU_STAGING_POOL_WAITER_FLAG_QUEUED;
+    if (pool->waiter_tail) {
+      pool->waiter_tail->next = waiter;
+    } else {
+      pool->waiter_head = waiter;
+    }
+    pool->waiter_tail = waiter;
+  }
+  iree_slim_mutex_unlock(&pool->mutex);
+  return result;
+}
+
+static bool iree_hal_amdgpu_staging_pool_cancel_waiter(
+    iree_hal_amdgpu_staging_pool_t* pool,
+    iree_hal_amdgpu_staging_pool_waiter_t* waiter) {
+  bool did_cancel = false;
+  iree_slim_mutex_lock(&pool->mutex);
+  iree_hal_amdgpu_staging_pool_waiter_t* previous = NULL;
+  for (iree_hal_amdgpu_staging_pool_waiter_t* current = pool->waiter_head;
+       current != NULL; current = current->next) {
+    if (current == waiter) {
+      if (previous) {
+        previous->next = current->next;
+      } else {
+        pool->waiter_head = current->next;
+      }
+      if (pool->waiter_tail == current) {
+        pool->waiter_tail = previous;
+      }
+      waiter->next = NULL;
+      waiter->flags &= ~IREE_HAL_AMDGPU_STAGING_POOL_WAITER_FLAG_QUEUED;
+      did_cancel = true;
+      break;
+    }
+    previous = current;
+  }
+  iree_slim_mutex_unlock(&pool->mutex);
+  return did_cancel;
+}
+
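+// Returns |slot_ordinal| to the pool. If a transfer is waiting the slot is
+// handed directly to the oldest waiter instead of passing through the free
+// ring; this keeps slot handoff FIFO and avoids a race where a concurrently
+// pumping transfer could steal the slot before the woken waiter runs. The
+// waiter callback is invoked outside the pool mutex.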
+static void iree_hal_amdgpu_staging_pool_release(
+    iree_hal_amdgpu_staging_pool_t* pool, uint32_t slot_ordinal) {
+  iree_hal_amdgpu_staging_slot_t slot = {
+      .ordinal = slot_ordinal,
+      .buffer_offset = (iree_device_size_t)slot_ordinal * pool->slot_size,
+      .host_span = iree_make_byte_span(
+          pool->host_base + (iree_host_size_t)slot_ordinal * pool->slot_size,
+          pool->slot_size),
+      .buffer = pool->buffer,
+  };
+  iree_hal_amdgpu_staging_pool_waiter_t* waiter = NULL;
+  iree_slim_mutex_lock(&pool->mutex);
+  waiter = iree_hal_amdgpu_staging_pool_pop_waiter_locked(pool, &slot);
+  if (!waiter) {
+    pool->free_slots[pool->free_write++ & pool->slot_mask] = slot_ordinal;
+    ++pool->available_count;
+  }
+  iree_slim_mutex_unlock(&pool->mutex);
+  if (waiter) {
+    waiter->fn(waiter->user_data);
+  }
+}
+
+iree_status_t iree_hal_amdgpu_staging_pool_initialize(
+    iree_hal_device_t* logical_device, const iree_hal_amdgpu_libhsa_t* libhsa,
+    const iree_hal_amdgpu_topology_t* topology,
+    const iree_hal_amdgpu_host_memory_pools_t* host_memory_pools,
+    iree_hal_queue_affinity_t queue_affinity_mask,
+    const iree_hal_amdgpu_staging_pool_options_t* options,
+    iree_allocator_t host_allocator, iree_hal_amdgpu_staging_pool_t* out_pool) {
+  IREE_ASSERT_ARGUMENT(logical_device);
+  IREE_ASSERT_ARGUMENT(libhsa);
+  IREE_ASSERT_ARGUMENT(topology);
+  IREE_ASSERT_ARGUMENT(host_memory_pools);
+  IREE_ASSERT_ARGUMENT(options);
+  IREE_ASSERT_ARGUMENT(out_pool);
+  IREE_TRACE_ZONE_BEGIN(z0);
+  IREE_TRACE_ZONE_APPEND_VALUE_I64(z0, options->slot_size);
+  IREE_TRACE_ZONE_APPEND_VALUE_I64(z0, options->slot_count);
+
+  IREE_RETURN_AND_END_ZONE_IF_ERROR(
+      z0, iree_hal_amdgpu_staging_pool_options_verify(options));
+
+  memset(out_pool, 0, sizeof(*out_pool));
+  out_pool->host_allocator = host_allocator;
+  out_pool->slot_size = options->slot_size;
+  out_pool->slot_count = options->slot_count;
+  out_pool->slot_mask = options->slot_count - 1u;
+  iree_slim_mutex_initialize(&out_pool->mutex);
+
+  hsa_amd_memory_pool_t memory_pool = host_memory_pools->coarse_pool;
+  if (options->force_fine_host_memory || !memory_pool.handle) {
+    memory_pool = host_memory_pools->fine_pool;
+  }
+  iree_status_t status = iree_ok_status();
+  if (IREE_UNLIKELY(!memory_pool.handle)) {
+    status = iree_make_status(IREE_STATUS_FAILED_PRECONDITION,
+                              "AMDGPU staging requires a host memory pool");
+  }
+
+  iree_hal_amdgpu_slab_provider_memory_pool_properties_t memory_pool_properties;
+  if (iree_status_is_ok(status)) {
+    status = iree_hal_amdgpu_slab_provider_query_memory_pool_properties(
+        libhsa, memory_pool, &memory_pool_properties);
+  }
+
+  iree_host_size_t total_size = 0;
+  if (iree_status_is_ok(status) &&
+      !iree_host_size_checked_mul(options->slot_size, options->slot_count,
+                                  &total_size)) {
+    status = iree_make_status(IREE_STATUS_OUT_OF_RANGE,
+                              "staging pool size overflows (slot_size=%" PRIhsz
+                              ", slot_count=%u)",
+                              options->slot_size, options->slot_count);
+  }
+
+  iree_host_size_t free_slots_size = 0;
+  if (iree_status_is_ok(status) &&
+      !iree_host_size_checked_mul(options->slot_count, sizeof(uint32_t),
+                                  &free_slots_size)) {
+    status = iree_make_status(IREE_STATUS_OUT_OF_RANGE,
+                              "staging free-slot table size overflows");
+  }
+
+  iree_host_size_t allocation_size = total_size;
+  if (iree_status_is_ok(status) && memory_pool_properties.allocation_alignment <
+                                       IREE_HAL_AMDGPU_STAGING_SLOT_ALIGNMENT) {
+    if (!iree_host_size_checked_add(total_size,
+                                    IREE_HAL_AMDGPU_STAGING_SLOT_ALIGNMENT - 1,
+                                    &allocation_size)) {
+      status = iree_make_status(IREE_STATUS_OUT_OF_RANGE,
+                                "staging aligned allocation size overflows");
+    }
+  }
+
+  uint32_t* free_slots = NULL;
+  if (iree_status_is_ok(status)) {
+    status = iree_allocator_malloc(host_allocator, free_slots_size,
+                                   (void**)&free_slots);
+  }
+
+  void* allocation_base = NULL;
+  if (iree_status_is_ok(status)) {
+    status = iree_hsa_amd_memory_pool_allocate(
+        IREE_LIBHSA(libhsa), memory_pool, allocation_size,
+        HSA_AMD_MEMORY_POOL_STANDARD_FLAG, &allocation_base);
+  }
+  if (iree_status_is_ok(status)) {
+    iree_hal_amdgpu_access_agent_list_t access_agents;
+    status = iree_hal_amdgpu_staging_pool_resolve_access_agents(
+        topology, queue_affinity_mask, &access_agents);
+    if (iree_status_is_ok(status)) {
+      status = iree_hal_amdgpu_access_allow_agent_list(libhsa, &access_agents,
+                                                       allocation_base);
+    }
+  }
+
+  void* host_ptr = NULL;
+  if (iree_status_is_ok(status)) {
+    const uintptr_t allocation_begin = (uintptr_t)allocation_base;
+    const uintptr_t allocation_end = allocation_begin + allocation_size;
+    iree_host_size_t aligned_host_base = 0;
+    if (!iree_host_size_checked_align((iree_host_size_t)allocation_begin,
+                                      IREE_HAL_AMDGPU_STAGING_SLOT_ALIGNMENT,
+                                      &aligned_host_base)) {
+      status = iree_make_status(
+          IREE_STATUS_OUT_OF_RANGE,
+          "HSA staging allocation base overflowed while aligning to %" PRIhsz
+          " bytes (base=%p)",
+          (iree_host_size_t)IREE_HAL_AMDGPU_STAGING_SLOT_ALIGNMENT,
+          allocation_base);
+    }
+    if (iree_status_is_ok(status)) {
+      const uintptr_t aligned_base = (uintptr_t)aligned_host_base;
+      const uintptr_t aligned_end = aligned_base + total_size;
+      if (allocation_end < allocation_begin ||
+          aligned_base < allocation_begin || aligned_end < aligned_base ||
+          aligned_end > allocation_end ||
+          !iree_host_ptr_has_alignment(
+              (void*)aligned_base, IREE_HAL_AMDGPU_STAGING_SLOT_ALIGNMENT)) {
+        status = iree_make_status(
+            IREE_STATUS_INTERNAL,
+            "HSA staging allocation could not satisfy %" PRIhsz
+            "-byte alignment (base=%p, allocation_size=%" PRIhsz
+            ", total_size=%" PRIhsz ")",
+            (iree_host_size_t)IREE_HAL_AMDGPU_STAGING_SLOT_ALIGNMENT,
+            allocation_base, allocation_size, total_size);
+      } else {
+        host_ptr = (void*)aligned_base;
+      }
+    }
+  }
+
+  iree_hal_amdgpu_staging_allocation_t* release_state = NULL;
+  if (iree_status_is_ok(status)) {
+    status = iree_allocator_malloc(host_allocator, sizeof(*release_state),
+                                   (void**)&release_state);
+  }
+  if (iree_status_is_ok(status)) {
+    release_state->libhsa = libhsa;
+    release_state->host_allocator = host_allocator;
+    release_state->allocation_base = allocation_base;
+  }
+
+  iree_hal_buffer_t* buffer = NULL;
+  if (iree_status_is_ok(status)) {
+    const iree_hal_buffer_placement_t placement = {
+        .device = logical_device,
+        .queue_affinity = queue_affinity_mask,
+        .flags = IREE_HAL_BUFFER_PLACEMENT_FLAG_NONE,
+    };
+    status = iree_hal_amdgpu_buffer_create(
+        libhsa, placement,
+        IREE_HAL_MEMORY_TYPE_HOST_LOCAL | IREE_HAL_MEMORY_TYPE_DEVICE_VISIBLE,
+        IREE_HAL_MEMORY_ACCESS_ALL, IREE_HAL_BUFFER_USAGE_TRANSFER,
+        (iree_device_size_t)total_size, (iree_device_size_t)total_size,
+        host_ptr,
+        (iree_hal_buffer_release_callback_t){
+            .fn = iree_hal_amdgpu_staging_allocation_release,
+            .user_data = release_state,
+        },
+        host_allocator, &buffer);
+  }
+
+  if (iree_status_is_ok(status)) {
+    for (uint32_t i = 0; i < options->slot_count; ++i) {
+      free_slots[i] = i;
+    }
+    out_pool->buffer = buffer;
+    out_pool->host_base = (uint8_t*)host_ptr;
+    out_pool->available_count = options->slot_count;
+    out_pool->free_slots = free_slots;
+    out_pool->free_write = options->slot_count;
+    release_state = NULL;
+    allocation_base = NULL;
+  } else {
+    iree_hal_buffer_release(buffer);
+    if (allocation_base) {
+      status = iree_status_join(
+          status,
+          iree_hsa_amd_memory_pool_free(IREE_LIBHSA(libhsa), allocation_base));
+    }
+    iree_allocator_free(host_allocator, release_state);
+    iree_allocator_free(host_allocator, free_slots);
+    iree_slim_mutex_deinitialize(&out_pool->mutex);
+    memset(out_pool, 0, sizeof(*out_pool));
+  }
+  IREE_TRACE_ZONE_END(z0);
+  return status;
+}
+
+void iree_hal_amdgpu_staging_pool_deinitialize(
+    iree_hal_amdgpu_staging_pool_t* pool) {
+  IREE_ASSERT_ARGUMENT(pool);
+  if (!pool->buffer && !pool->free_slots && pool->slot_count == 0) {
+    return;
+  }
+  IREE_TRACE_ZONE_BEGIN(z0);
+  iree_hal_buffer_release(pool->buffer);
+  iree_allocator_free(pool->host_allocator, pool->free_slots);
+  iree_slim_mutex_deinitialize(&pool->mutex);
+  memset(pool, 0, sizeof(*pool));
+  IREE_TRACE_ZONE_END(z0);
+}
+
+typedef enum iree_hal_amdgpu_staging_transfer_kind_e {
+  // File data flows from the proactor into staging and then a GPU copy
+  // writes the target buffer.
+  IREE_HAL_AMDGPU_STAGING_TRANSFER_READ = 0,
+  // A GPU copy writes staging and then file data flows from staging back to
+  // the proactor.
+  IREE_HAL_AMDGPU_STAGING_TRANSFER_WRITE = 1,
+} iree_hal_amdgpu_staging_transfer_kind_t;
+
+static iree_hal_profile_queue_event_type_t
+iree_hal_amdgpu_staging_transfer_profile_event_type(
+    iree_hal_amdgpu_staging_transfer_kind_t kind) {
+  return kind == IREE_HAL_AMDGPU_STAGING_TRANSFER_READ
+             ? IREE_HAL_PROFILE_QUEUE_EVENT_TYPE_READ
+             : IREE_HAL_PROFILE_QUEUE_EVENT_TYPE_WRITE;
+}
+
+typedef uint32_t iree_hal_amdgpu_staging_transfer_flags_t;
+enum iree_hal_amdgpu_staging_transfer_flag_bits_e {
+  IREE_HAL_AMDGPU_STAGING_TRANSFER_FLAG_NONE = 0u,
+  IREE_HAL_AMDGPU_STAGING_TRANSFER_FLAG_FINISHING = 1u << 0,
+};
+
+typedef enum iree_hal_amdgpu_staging_chunk_state_e {
+  // Chunk is available for a new file subrange.
+  IREE_HAL_AMDGPU_STAGING_CHUNK_IDLE = 0,
+  // Chunk has an in-flight async file read into its staging slot.
+  IREE_HAL_AMDGPU_STAGING_CHUNK_READING = 1,
+  // Chunk has an in-flight GPU copy from staging to the user buffer.
+  IREE_HAL_AMDGPU_STAGING_CHUNK_COPYING_TO_DEVICE = 2,
+  // Chunk has an in-flight GPU copy from the user buffer to staging.
+  IREE_HAL_AMDGPU_STAGING_CHUNK_COPYING_TO_HOST = 3,
+  // Chunk has an in-flight async file write from its staging slot.
+  IREE_HAL_AMDGPU_STAGING_CHUNK_WRITING = 4,
+} iree_hal_amdgpu_staging_chunk_state_t;
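+
+// Intended chunk lifecycle (a failure at any stage records the transfer
+// failure and returns the chunk and its staging slot to IDLE):
+//   read:  IDLE -> READING -> COPYING_TO_DEVICE -> IDLE
+//   write: IDLE -> COPYING_TO_HOST -> WRITING -> IDLE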
+
+typedef struct iree_hal_amdgpu_staging_transfer_t
+    iree_hal_amdgpu_staging_transfer_t;
+
+typedef struct iree_hal_amdgpu_staging_chunk_t {
+  // Owning transfer.
+  iree_hal_amdgpu_staging_transfer_t* transfer;
+  // Current lifecycle state.
+  iree_hal_amdgpu_staging_chunk_state_t state;
+  // Staging slot owned by this chunk while |state| is not IDLE.
+  iree_hal_amdgpu_staging_slot_t slot;
+  // Byte offset from the transfer start.
+  iree_device_size_t transfer_offset;
+  // Byte length assigned to this chunk.
+  iree_host_size_t length;
+  // Bytes completed by the current partial file operation.
+  iree_host_size_t file_progress;
+  // Owned status captured by the GPU copy pre-signal action for post-drain use.
+  iree_status_t copy_status;
+  // Post-drain continuation queued by the GPU copy pre-signal action.
+  iree_hal_amdgpu_host_queue_post_drain_action_t post_drain_action;
+  // Async read operation storage.
+  iree_async_file_read_operation_t read_op;
+  // Async write operation storage.
+  iree_async_file_write_operation_t write_op;
+} iree_hal_amdgpu_staging_chunk_t;
+
+struct iree_hal_amdgpu_staging_transfer_t {
+  // Resource header retained by host actions, async file callbacks, and GPU
+  // copy reclaim entries.
+  iree_hal_resource_t resource;
+  // Host allocator used for this transfer and cloned semaphore-list storage.
+  iree_allocator_t host_allocator;
+  // Serializes transfer counters and terminal status ownership.
+  iree_slim_mutex_t mutex;
+  // Queue used for internal GPU copies and final user signal publication.
+  iree_hal_amdgpu_host_queue_t* queue;
+  // Physical-device staging pool used by this transfer.
+  iree_hal_amdgpu_staging_pool_t* pool;
+  // Logical device retained while asynchronous transfer work is pending.
+  iree_hal_device_t* logical_device;
+  // File being read or written.
+  iree_hal_file_t* file;
+  // Async file handle borrowed from |file|.
+  iree_async_file_t* async_file;
+  // User buffer being copied to or from.
+  iree_hal_buffer_t* buffer;
+  // File byte offset for the first requested byte.
+  uint64_t file_offset;
+  // User buffer byte offset for the first requested byte.
+  iree_device_size_t buffer_offset;
+  // Total requested transfer length.
+  iree_device_size_t requested_length;
+  // Number of bytes assigned to chunks.
+  iree_device_size_t submitted_length;
+  // Number of bytes fully transferred through all stages.
+  iree_device_size_t completed_length;
+  // Number of chunks currently owning a staging slot or in-flight operation.
+  uint32_t active_chunk_count;
+  // Number of chunk records in |chunks|.
+  uint32_t chunk_count;
+  // Number of wait semaphores supplied to the queue_read/write operation.
+  uint32_t profile_wait_count;
+  // Direction of this transfer.
+  iree_hal_amdgpu_staging_transfer_kind_t kind;
+  // Transfer lifecycle flags from iree_hal_amdgpu_staging_transfer_flags_t.
+  iree_hal_amdgpu_staging_transfer_flags_t flags;
+  // Owned first failure status, or OK if no failure has occurred.
+  iree_status_t failure_status;
+  // Waiter queued when all staging slots are temporarily unavailable.
+  iree_hal_amdgpu_staging_pool_waiter_t slot_waiter;
+  // Completion-thread retry queued when the final signal barrier is blocked by
+  // temporary queue capacity pressure.
+  iree_hal_amdgpu_host_queue_post_drain_action_t signal_capacity_retry;
+  // Cloned signal list published after the transfer completes.
+  iree_hal_semaphore_list_t signal_semaphore_list;
+  // Chunk records used to pipeline file I/O and GPU copies.
+  iree_hal_amdgpu_staging_chunk_t* chunks;
+};
+
+static void iree_hal_amdgpu_staging_transfer_pump(
+    iree_hal_amdgpu_staging_transfer_t* transfer);
+
+static void iree_hal_amdgpu_staging_copy_post_drain(void* user_data);
+static void iree_hal_amdgpu_staging_copy_capacity_post_drain(void* user_data);
+static void iree_hal_amdgpu_staging_signal_capacity_post_drain(void* user_data);
+
+static void iree_hal_amdgpu_staging_transfer_destroy(
+    iree_hal_resource_t* resource) {
+  iree_hal_amdgpu_staging_transfer_t* transfer =
+      (iree_hal_amdgpu_staging_transfer_t*)resource;
+  IREE_TRACE_ZONE_BEGIN(z0);
+  if (!iree_hal_semaphore_list_is_empty(transfer->signal_semaphore_list)) {
+    iree_hal_semaphore_list_free(transfer->signal_semaphore_list,
+                                 transfer->host_allocator);
+  }
+  iree_hal_buffer_release(transfer->buffer);
+  iree_hal_file_release(transfer->file);
+  iree_hal_device_release(transfer->logical_device);
+  iree_slim_mutex_deinitialize(&transfer->mutex);
+  iree_allocator_free(transfer->host_allocator, transfer);
+  IREE_TRACE_ZONE_END(z0);
+}
+
+static const iree_hal_resource_vtable_t
+    iree_hal_amdgpu_staging_transfer_vtable = {
+        .destroy = iree_hal_amdgpu_staging_transfer_destroy,
+};
+
+static iree_status_t iree_hal_amdgpu_staging_transfer_clone_queue_error(
+    iree_hal_amdgpu_staging_transfer_t* transfer) {
+  iree_status_t error = (iree_status_t)iree_atomic_load(
+      &transfer->queue->error_status, iree_memory_order_acquire);
+  return iree_status_is_ok(error) ? iree_ok_status() : iree_status_clone(error);
+}
+
+static void iree_hal_amdgpu_staging_transfer_record_failure(
+    iree_hal_amdgpu_staging_transfer_t* transfer, iree_status_t status) {
+  if (iree_status_is_ok(status)) return;
+  iree_slim_mutex_lock(&transfer->mutex);
+  if (iree_status_is_ok(transfer->failure_status)) {
+    transfer->failure_status = status;
+    status = iree_ok_status();
+  }
+  iree_slim_mutex_unlock(&transfer->mutex);
+  iree_status_free(status);
+}
+
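+// Submits the queue barrier that publishes |transfer|'s signal semaphore list
+// once all prior queue-ordered work completes. If the queue is temporarily
+// out of submission capacity the barrier is retried from the completion
+// thread via |signal_capacity_retry|.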
+static iree_status_t iree_hal_amdgpu_staging_transfer_submit_signal_barrier(
+    iree_hal_amdgpu_staging_transfer_t* transfer) {
+  if (iree_hal_semaphore_list_is_empty(transfer->signal_semaphore_list)) {
+    return iree_ok_status();
+  }
+
+  IREE_RETURN_IF_ERROR(
+      iree_hal_amdgpu_staging_transfer_clone_queue_error(transfer));
+
+  iree_hal_amdgpu_wait_resolution_t resolution;
+  memset(&resolution, 0, sizeof(resolution));
+  resolution.inline_acquire_scope = IREE_HSA_FENCE_SCOPE_SYSTEM;
+  resolution.barrier_acquire_scope = IREE_HSA_FENCE_SCOPE_SYSTEM;
+
+  iree_slim_mutex_lock(&transfer->queue->locks.submission_mutex);
+  bool ready = false;
+  uint64_t submission_id = 0;
+  iree_hal_amdgpu_host_queue_profile_event_info_t profile_event_info = {
+      .type =
+          iree_hal_amdgpu_staging_transfer_profile_event_type(transfer->kind),
+      .payload_length = transfer->requested_length,
+      .operation_count = 1,
+  };
+  iree_status_t status = iree_hal_amdgpu_host_queue_try_submit_barrier(
+      transfer->queue, &resolution, transfer->signal_semaphore_list,
+      (iree_hal_amdgpu_reclaim_action_t){0},
+      /*operation_resources=*/NULL, /*operation_resource_count=*/0,
+      &profile_event_info,
+      iree_hal_amdgpu_host_queue_post_commit_callback_null(),
+      /*resource_set=*/NULL,
+      IREE_HAL_AMDGPU_HOST_QUEUE_SUBMISSION_FLAG_RETAIN_RESOURCES, &ready,
+      &submission_id);
+  if (iree_status_is_ok(status) && ready) {
+    iree_hal_amdgpu_wait_resolution_t profile_resolution = resolution;
+    profile_resolution.wait_count = transfer->profile_wait_count;
+    profile_event_info.submission_id = submission_id;
+    iree_hal_amdgpu_host_queue_record_profile_queue_event(
+        transfer->queue, &profile_resolution, transfer->signal_semaphore_list,
+        &profile_event_info);
+  }
+  if (iree_status_is_ok(status) && !ready) {
+    iree_hal_resource_retain(&transfer->resource);
+    iree_hal_amdgpu_host_queue_enqueue_post_drain_action(
+        transfer->queue, &transfer->signal_capacity_retry,
+        iree_hal_amdgpu_staging_signal_capacity_post_drain, transfer);
+  }
+  iree_slim_mutex_unlock(&transfer->queue->locks.submission_mutex);
+  return status;
+}
+
+static void iree_hal_amdgpu_staging_transfer_fail_signals(
+    iree_hal_amdgpu_staging_transfer_t* transfer, iree_status_t status) {
+  if (iree_status_is_ok(status)) return;
+  if (iree_hal_semaphore_list_is_empty(transfer->signal_semaphore_list)) {
+    iree_status_free(status);
+    return;
+  }
+  iree_hal_semaphore_list_fail(transfer->signal_semaphore_list, status);
+}
+
+static void iree_hal_amdgpu_staging_transfer_fail_signals_with_borrowed_status(
+    iree_hal_amdgpu_staging_transfer_t* transfer, iree_status_t status) {
+  if (iree_status_is_ok(status) ||
+      iree_hal_semaphore_list_is_empty(transfer->signal_semaphore_list)) {
+    return;
+  }
+  iree_hal_semaphore_list_fail(transfer->signal_semaphore_list,
+                               iree_status_clone(status));
+}
+
+static void iree_hal_amdgpu_staging_transfer_complete(
+    iree_hal_amdgpu_staging_transfer_t* transfer, iree_status_t status) {
+  if (iree_status_is_ok(status)) {
+    status = iree_hal_amdgpu_staging_transfer_submit_signal_barrier(transfer);
+  }
+  iree_hal_amdgpu_staging_transfer_fail_signals(transfer, status);
+  iree_hal_resource_release(&transfer->resource);
+}
+
+static void iree_hal_amdgpu_staging_signal_capacity_post_drain(
+    void* user_data) {
+  iree_hal_amdgpu_staging_transfer_complete(
+      (iree_hal_amdgpu_staging_transfer_t*)user_data, iree_ok_status());
+}
+
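+// Completes the transfer exactly once: the first caller to observe no active
+// chunks and either full completion or a recorded failure takes the FINISHING
+// flag, cancels any queued slot waiter, and publishes (or fails) the signal
+// semaphores.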
+static void iree_hal_amdgpu_staging_transfer_try_finish(
+    iree_hal_amdgpu_staging_transfer_t* transfer) {
+  bool should_complete = false;
+  bool should_release_waiter_ref = false;
+  iree_status_t status = iree_ok_status();
+
+  iree_slim_mutex_lock(&transfer->mutex);
+  const bool has_failure = !iree_status_is_ok(transfer->failure_status);
+  const bool is_complete =
+      transfer->completed_length == transfer->requested_length;
+  if (!iree_any_bit_set(transfer->flags,
+                        IREE_HAL_AMDGPU_STAGING_TRANSFER_FLAG_FINISHING) &&
+      transfer->active_chunk_count == 0 && (has_failure || is_complete)) {
+    transfer->flags |= IREE_HAL_AMDGPU_STAGING_TRANSFER_FLAG_FINISHING;
+    status = transfer->failure_status;
+    transfer->failure_status = iree_ok_status();
+    should_complete = true;
+  }
+  iree_slim_mutex_unlock(&transfer->mutex);
+
+  if (should_complete && iree_hal_amdgpu_staging_pool_cancel_waiter(
+                             transfer->pool, &transfer->slot_waiter)) {
+    should_release_waiter_ref = true;
+  }
+  if (should_release_waiter_ref) {
+    iree_hal_resource_release(&transfer->resource);
+  }
+  if (should_complete) {
+    iree_hal_amdgpu_staging_transfer_complete(transfer, status);
+  }
+}
+
+static void iree_hal_amdgpu_staging_chunk_return_slot(
+    iree_hal_amdgpu_staging_chunk_t* chunk) {
+  iree_hal_amdgpu_staging_pool_t* pool = chunk->transfer->pool;
+  const uint32_t slot_ordinal = chunk->slot.ordinal;
+  memset(&chunk->slot, 0, sizeof(chunk->slot));
+  iree_hal_amdgpu_staging_pool_release(pool, slot_ordinal);
+}
+
+static void iree_hal_amdgpu_staging_chunk_finish(
+    iree_hal_amdgpu_staging_chunk_t* chunk, bool did_transfer_bytes) {
+  iree_hal_amdgpu_staging_transfer_t* transfer = chunk->transfer;
+  iree_slim_mutex_lock(&transfer->mutex);
+  if (did_transfer_bytes) {
+    transfer->completed_length += chunk->length;
+  }
+  chunk->state = IREE_HAL_AMDGPU_STAGING_CHUNK_IDLE;
+  chunk->length = 0;
+  chunk->file_progress = 0;
+  --transfer->active_chunk_count;
+  iree_slim_mutex_unlock(&transfer->mutex);
+  iree_hal_amdgpu_staging_chunk_return_slot(chunk);
+  iree_hal_amdgpu_staging_transfer_pump(transfer);
+  iree_hal_amdgpu_staging_transfer_try_finish(transfer);
+}
+
+static void iree_hal_amdgpu_staging_chunk_fail(
+    iree_hal_amdgpu_staging_chunk_t* chunk, iree_status_t status) {
+  iree_hal_amdgpu_staging_transfer_record_failure(chunk->transfer, status);
+  iree_hal_amdgpu_staging_chunk_finish(chunk, /*did_transfer_bytes=*/false);
+}
+
+static iree_status_t iree_hal_amdgpu_staging_chunk_submit_read(
+    iree_hal_amdgpu_staging_chunk_t* chunk);
+
+static iree_status_t iree_hal_amdgpu_staging_chunk_submit_write(
+    iree_hal_amdgpu_staging_chunk_t* chunk);
+
+static void iree_hal_amdgpu_staging_copy_pre_signal(
+    iree_hal_amdgpu_reclaim_entry_t* entry, void* user_data,
+    iree_status_t status) {
+  (void)entry;
+  iree_hal_amdgpu_staging_chunk_t* chunk =
+      (iree_hal_amdgpu_staging_chunk_t*)user_data;
+  chunk->copy_status =
+      iree_status_is_ok(status) ? iree_ok_status() : iree_status_clone(status);
+  iree_hal_resource_retain(&chunk->transfer->resource);
+  iree_hal_amdgpu_host_queue_enqueue_post_drain_action(
+      chunk->transfer->queue, &chunk->post_drain_action,
+      iree_hal_amdgpu_staging_copy_post_drain, chunk);
+}
+
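+// Submits the GPU copy between the chunk's staging slot and the user buffer.
+// For reads the host just wrote the staging slot, so the copy must acquire at
+// system scope; for writes the host will read the staging slot afterwards, so
+// the copy must release at system scope.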
+static iree_status_t iree_hal_amdgpu_staging_chunk_submit_copy(
+    iree_hal_amdgpu_staging_chunk_t* chunk) {
+  iree_hal_amdgpu_staging_transfer_t* transfer = chunk->transfer;
+  IREE_RETURN_IF_ERROR(
+      iree_hal_amdgpu_staging_transfer_clone_queue_error(transfer));
+
+  iree_hal_amdgpu_wait_resolution_t resolution;
+  memset(&resolution, 0, sizeof(resolution));
+  resolution.inline_acquire_scope = IREE_HSA_FENCE_SCOPE_NONE;
+  resolution.barrier_acquire_scope = IREE_HSA_FENCE_SCOPE_NONE;
+
+  iree_hal_buffer_t* source_buffer = NULL;
+  iree_device_size_t source_offset = 0;
+  iree_hal_buffer_t* target_buffer = NULL;
+  iree_device_size_t target_offset = 0;
+  iree_hsa_fence_scope_t minimum_acquire_scope = IREE_HSA_FENCE_SCOPE_NONE;
+  iree_hsa_fence_scope_t minimum_release_scope = IREE_HSA_FENCE_SCOPE_NONE;
+  if (transfer->kind == IREE_HAL_AMDGPU_STAGING_TRANSFER_READ) {
+    source_buffer = chunk->slot.buffer;
+    source_offset = chunk->slot.buffer_offset;
+    target_buffer = transfer->buffer;
+    target_offset = transfer->buffer_offset + chunk->transfer_offset;
+    minimum_acquire_scope = IREE_HSA_FENCE_SCOPE_SYSTEM;
+    chunk->state = IREE_HAL_AMDGPU_STAGING_CHUNK_COPYING_TO_DEVICE;
+  } else {
+    source_buffer = transfer->buffer;
+    source_offset = transfer->buffer_offset + chunk->transfer_offset;
+    target_buffer = chunk->slot.buffer;
+    target_offset = chunk->slot.buffer_offset;
+    minimum_release_scope = IREE_HSA_FENCE_SCOPE_SYSTEM;
+    chunk->state = IREE_HAL_AMDGPU_STAGING_CHUNK_COPYING_TO_HOST;
+  }
+
+  iree_hal_resource_t* extra_resources[1] = {&transfer->resource};
+  iree_slim_mutex_lock(&transfer->queue->locks.submission_mutex);
+  bool ready = false;
+  iree_status_t status = iree_hal_amdgpu_host_queue_submit_copy_with_action(
+      transfer->queue, &resolution, iree_hal_semaphore_list_empty(),
+      source_buffer, source_offset, target_buffer, target_offset, chunk->length,
+      IREE_HAL_COPY_FLAG_NONE, minimum_acquire_scope, minimum_release_scope,
+      (iree_hal_amdgpu_reclaim_action_t){
+          .fn = iree_hal_amdgpu_staging_copy_pre_signal,
+          .user_data = chunk,
+      },
+      extra_resources, IREE_ARRAYSIZE(extra_resources),
+      IREE_HAL_AMDGPU_HOST_QUEUE_SUBMISSION_FLAG_RETAIN_RESOURCES, &ready);
+  if (iree_status_is_ok(status) && !ready) {
+    iree_hal_resource_retain(&transfer->resource);
+    iree_hal_amdgpu_host_queue_enqueue_post_drain_action(
+        transfer->queue, &chunk->post_drain_action,
+        iree_hal_amdgpu_staging_copy_capacity_post_drain, chunk);
+  }
+  iree_slim_mutex_unlock(&transfer->queue->locks.submission_mutex);
+  return status;
+}
+
+static void iree_hal_amdgpu_staging_read_complete(
+    void* user_data, iree_async_operation_t* base_operation,
+    iree_status_t status, iree_async_completion_flags_t flags) {
+  (void)base_operation;
+  (void)flags;
+  iree_hal_amdgpu_staging_chunk_t* chunk =
+      (iree_hal_amdgpu_staging_chunk_t*)user_data;
+
+  if (iree_status_is_ok(status) && chunk->read_op.bytes_read > 0) {
+    chunk->file_progress += chunk->read_op.bytes_read;
+    if (chunk->file_progress < chunk->length) {
+      status = iree_hal_amdgpu_staging_chunk_submit_read(chunk);
+      if (iree_status_is_ok(status)) {
+        iree_hal_resource_release(&chunk->transfer->resource);
+        return;
+      }
+    }
+  } else if (iree_status_is_ok(status) &&
+             chunk->file_progress < chunk->length) {
+    status = iree_make_status(IREE_STATUS_OUT_OF_RANGE,
+                              "short read: requested %" PRIhsz
+                              " bytes, got %" PRIhsz,
+                              chunk->length, chunk->file_progress);
+  }
+
+  if (iree_status_is_ok(status)) {
+    status = iree_hal_amdgpu_staging_chunk_submit_copy(chunk);
+  }
+  if (!iree_status_is_ok(status)) {
+    iree_hal_amdgpu_staging_chunk_fail(chunk, status);
+  }
+  iree_hal_resource_release(&chunk->transfer->resource);
+}
+
+static void iree_hal_amdgpu_staging_write_complete(
+    void* user_data, iree_async_operation_t* base_operation,
+    iree_status_t status, iree_async_completion_flags_t flags) {
+  (void)base_operation;
+  (void)flags;
+  iree_hal_amdgpu_staging_chunk_t* chunk =
+      (iree_hal_amdgpu_staging_chunk_t*)user_data;
+
+  if (iree_status_is_ok(status) && chunk->write_op.bytes_written > 0) {
+    chunk->file_progress += chunk->write_op.bytes_written;
+    if (chunk->file_progress < chunk->length) {
+      status = iree_hal_amdgpu_staging_chunk_submit_write(chunk);
+      if (iree_status_is_ok(status)) {
+        iree_hal_resource_release(&chunk->transfer->resource);
+        return;
+      }
+    }
+  } else if (iree_status_is_ok(status) &&
+             chunk->file_progress < chunk->length) {
+    status = iree_make_status(IREE_STATUS_OUT_OF_RANGE,
+                              "short write: requested %" PRIhsz
+                              " bytes, wrote %" PRIhsz,
+                              chunk->length, chunk->file_progress);
+  }
+
+  if (!iree_status_is_ok(status)) {
+    iree_hal_amdgpu_staging_chunk_fail(chunk, status);
+  } else {
+    iree_hal_amdgpu_staging_chunk_finish(chunk, /*did_transfer_bytes=*/true);
+  }
+  iree_hal_resource_release(&chunk->transfer->resource);
+}
+
+static iree_status_t iree_hal_amdgpu_staging_chunk_submit_read(
+    iree_hal_amdgpu_staging_chunk_t* chunk) {
+  iree_hal_amdgpu_staging_transfer_t* transfer = chunk->transfer;
+  IREE_RETURN_IF_ERROR(
+      iree_hal_amdgpu_staging_transfer_clone_queue_error(transfer));
+
+  iree_async_operation_zero(&chunk->read_op.base, sizeof(chunk->read_op));
+  iree_async_operation_initialize(&chunk->read_op.base,
+                                  IREE_ASYNC_OPERATION_TYPE_FILE_READ,
+                                  IREE_ASYNC_OPERATION_FLAG_NONE,
+                                  iree_hal_amdgpu_staging_read_complete, chunk);
+  chunk->read_op.file = transfer->async_file;
+  chunk->read_op.offset =
+      transfer->file_offset + chunk->transfer_offset + chunk->file_progress;
+  chunk->read_op.buffer = iree_async_span_from_ptr(
+      chunk->slot.host_span.data + chunk->file_progress,
+      chunk->length - chunk->file_progress);
+  iree_hal_resource_retain(&transfer->resource);
+  iree_status_t status = iree_async_proactor_submit_one(
+      transfer->queue->proactor, &chunk->read_op.base);
+  if (!iree_status_is_ok(status)) {
+    iree_hal_resource_release(&transfer->resource);
+  }
+  return status;
+}
+
+static iree_status_t iree_hal_amdgpu_staging_chunk_submit_write(
+    iree_hal_amdgpu_staging_chunk_t* chunk) {
+  iree_hal_amdgpu_staging_transfer_t* transfer = chunk->transfer;
+  IREE_RETURN_IF_ERROR(
+      iree_hal_amdgpu_staging_transfer_clone_queue_error(transfer));
+
+  iree_async_operation_zero(&chunk->write_op.base, sizeof(chunk->write_op));
+  iree_async_operation_initialize(
+      &chunk->write_op.base, IREE_ASYNC_OPERATION_TYPE_FILE_WRITE,
+      IREE_ASYNC_OPERATION_FLAG_NONE, iree_hal_amdgpu_staging_write_complete,
+      chunk);
+  chunk->write_op.file = transfer->async_file;
+  chunk->write_op.offset =
+      transfer->file_offset + chunk->transfer_offset + chunk->file_progress;
+  chunk->write_op.buffer = iree_async_span_from_ptr(
+      chunk->slot.host_span.data + chunk->file_progress,
+      chunk->length - chunk->file_progress);
+  iree_hal_resource_retain(&transfer->resource);
+  iree_status_t status = iree_async_proactor_submit_one(
+      transfer->queue->proactor, &chunk->write_op.base);
+  if (!iree_status_is_ok(status)) {
+    iree_hal_resource_release(&transfer->resource);
+  }
+  return status;
+}
+
+static void iree_hal_amdgpu_staging_copy_post_drain(void* user_data) {
+  iree_hal_amdgpu_staging_chunk_t* chunk =
+      (iree_hal_amdgpu_staging_chunk_t*)user_data;
+  iree_status_t status = chunk->copy_status;
+  chunk->copy_status = iree_ok_status();
+
+  if (iree_status_is_ok(status) &&
+      chunk->transfer->kind == IREE_HAL_AMDGPU_STAGING_TRANSFER_WRITE) {
+    status = iree_hal_amdgpu_staging_chunk_submit_write(chunk);
+    if (iree_status_is_ok(status)) {
+      iree_hal_resource_release(&chunk->transfer->resource);
+      return;
+    }
+  }
+
+  if (!iree_status_is_ok(status)) {
+    iree_hal_amdgpu_staging_chunk_fail(chunk, status);
+  } else {
+    iree_hal_amdgpu_staging_chunk_finish(chunk, /*did_transfer_bytes=*/true);
+  }
+  iree_hal_resource_release(&chunk->transfer->resource);
+}
+
+static void iree_hal_amdgpu_staging_copy_capacity_post_drain(void* user_data) {
+  iree_hal_amdgpu_staging_chunk_t* chunk =
+      (iree_hal_amdgpu_staging_chunk_t*)user_data;
+  iree_status_t status = iree_hal_amdgpu_staging_chunk_submit_copy(chunk);
+  if (!iree_status_is_ok(status)) {
+    iree_hal_amdgpu_staging_chunk_fail(chunk, status);
+  }
+  iree_hal_resource_release(&chunk->transfer->resource);
+}
+
+static void iree_hal_amdgpu_staging_transfer_slot_available(void* user_data) {
+  iree_hal_amdgpu_staging_transfer_t* transfer =
+      (iree_hal_amdgpu_staging_transfer_t*)user_data;
+  iree_hal_amdgpu_staging_transfer_pump(transfer);
+  iree_hal_amdgpu_staging_transfer_try_finish(transfer);
+  iree_hal_resource_release(&transfer->resource);
+}
+
+static iree_hal_amdgpu_staging_chunk_t*
+iree_hal_amdgpu_staging_transfer_find_idle_chunk(
+    iree_hal_amdgpu_staging_transfer_t* transfer) {
+  for (uint32_t i = 0; i < transfer->chunk_count; ++i) {
+    if (transfer->chunks[i].state == IREE_HAL_AMDGPU_STAGING_CHUNK_IDLE) {
+      return &transfer->chunks[i];
+    }
+  }
+  return NULL;
+}
+
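+// Drives the transfer pipeline: acquires staging slots (queuing a pool waiter
+// when none are free), carves the next unsubmitted byte range into an idle
+// chunk, and kicks off the chunk's first stage (an async file read for read
+// transfers, a GPU copy for write transfers). Loops until the transfer is
+// finishing, has failed, is fully submitted, or runs out of slots or idle
+// chunks.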
+static void iree_hal_amdgpu_staging_transfer_pump(
+    iree_hal_amdgpu_staging_transfer_t* transfer) {
+  for (;;) {
+    iree_hal_amdgpu_staging_slot_t slot;
+    memset(&slot, 0, sizeof(slot));
+    const bool has_waiter_slot = iree_hal_amdgpu_staging_pool_take_waiter_slot(
+        transfer->pool, &transfer->slot_waiter, &slot);
+
+    iree_slim_mutex_lock(&transfer->mutex);
+    const bool can_submit_more =
+        !iree_any_bit_set(transfer->flags,
+                          IREE_HAL_AMDGPU_STAGING_TRANSFER_FLAG_FINISHING) &&
+        iree_status_is_ok(transfer->failure_status) &&
+        transfer->submitted_length < transfer->requested_length;
+    iree_slim_mutex_unlock(&transfer->mutex);
+    if (!can_submit_more) {
+      if (has_waiter_slot) {
+        iree_hal_amdgpu_staging_pool_release(transfer->pool, slot.ordinal);
+      }
+      return;
+    }
+
+    if (!has_waiter_slot &&
+        !iree_hal_amdgpu_staging_pool_try_acquire(transfer->pool, &slot)) {
+      iree_hal_amdgpu_staging_pool_wait_result_t wait_result =
+          iree_hal_amdgpu_staging_pool_queue_waiter(
+              transfer->pool, &transfer->slot_waiter,
+              iree_hal_amdgpu_staging_transfer_slot_available, transfer);
+      if (wait_result == IREE_HAL_AMDGPU_STAGING_POOL_WAIT_QUEUED) {
+        iree_hal_resource_retain(&transfer->resource);
+      }
+      if (wait_result != IREE_HAL_AMDGPU_STAGING_POOL_WAIT_RETRY) {
+        return;
+      }
+      continue;
+    }
+
+    iree_hal_amdgpu_staging_chunk_t* chunk = NULL;
+    iree_device_size_t chunk_offset = 0;
+    iree_host_size_t chunk_length = 0;
+    bool should_release_slot = false;
+    iree_slim_mutex_lock(&transfer->mutex);
+    const bool has_failure = !iree_status_is_ok(transfer->failure_status);
+    const bool has_more_bytes =
+        transfer->submitted_length < transfer->requested_length;
+    if (!iree_any_bit_set(transfer->flags,
+                          IREE_HAL_AMDGPU_STAGING_TRANSFER_FLAG_FINISHING) &&
+        !has_failure && has_more_bytes) {
+      chunk = iree_hal_amdgpu_staging_transfer_find_idle_chunk(transfer);
+    }
+    if (chunk) {
+      const iree_device_size_t remaining_length =
+          transfer->requested_length - transfer->submitted_length;
+      chunk_length = (iree_host_size_t)iree_min(
+          (iree_device_size_t)transfer->pool->slot_size, remaining_length);
+      chunk_offset = transfer->submitted_length;
+      chunk->state = transfer->kind == IREE_HAL_AMDGPU_STAGING_TRANSFER_READ
+                         ? IREE_HAL_AMDGPU_STAGING_CHUNK_READING
+                         : IREE_HAL_AMDGPU_STAGING_CHUNK_COPYING_TO_HOST;
+      chunk->slot = slot;
+      chunk->transfer_offset = chunk_offset;
+      chunk->length = chunk_length;
+      chunk->file_progress = 0;
+      chunk->copy_status = iree_ok_status();
+      transfer->submitted_length += chunk_length;
+      ++transfer->active_chunk_count;
+    } else {
+      should_release_slot = true;
+    }
+    iree_slim_mutex_unlock(&transfer->mutex);
+
+    if (should_release_slot) {
+      iree_hal_amdgpu_staging_pool_release(transfer->pool, slot.ordinal);
+      return;
+    }
+
+    iree_status_t status =
+        transfer->kind == IREE_HAL_AMDGPU_STAGING_TRANSFER_READ
+            ? iree_hal_amdgpu_staging_chunk_submit_read(chunk)
+            : iree_hal_amdgpu_staging_chunk_submit_copy(chunk);
+    if (!iree_status_is_ok(status)) {
+      iree_hal_amdgpu_staging_chunk_fail(chunk, status);
+      return;
+    }
+  }
+}
+
+static iree_status_t iree_hal_amdgpu_staging_transfer_start(
+    iree_hal_amdgpu_staging_transfer_t* transfer) {
+  IREE_RETURN_IF_ERROR(
+      iree_hal_amdgpu_staging_transfer_clone_queue_error(transfer));
+  // The transfer buffer may be a queue_alloca result whose backing is only
+  // staged when the operation is submitted. Validate the device pointer here,
+  // after the host action's wait set has been satisfied.
+  iree_hal_buffer_t* allocated_buffer =
+      iree_hal_buffer_allocated_buffer(transfer->buffer);
+  if (IREE_UNLIKELY(!iree_hal_amdgpu_buffer_device_pointer(allocated_buffer))) {
+    return iree_make_status(
+        IREE_STATUS_FAILED_PRECONDITION,
+        "staged AMDGPU file transfer buffer was not backed by an AMDGPU "
+        "allocation after queue waits completed");
+  }
+  // The host action's reclaim entry owns |transfer| only until this callback
+  // returns. Keep a transfer-owned self reference across the async file/GPU
+  // copy pipeline; the terminal completion path releases it.
+  iree_hal_resource_retain(&transfer->resource);
+  iree_hal_amdgpu_staging_transfer_pump(transfer);
+  iree_hal_amdgpu_staging_transfer_try_finish(transfer);
+  return iree_ok_status();
+}
+
+static void iree_hal_amdgpu_staging_transfer_execute(
+    iree_hal_amdgpu_reclaim_entry_t* entry, void* user_data,
+    iree_status_t status) {
+  (void)entry;
+  iree_hal_amdgpu_staging_transfer_t* transfer =
+      (iree_hal_amdgpu_staging_transfer_t*)user_data;
+
+  if (!iree_status_is_ok(status)) {
+    iree_hal_amdgpu_staging_transfer_fail_signals_with_borrowed_status(transfer,
+                                                                       status);
+    return;
+  }
+
+  iree_status_t start_status = iree_hal_amdgpu_staging_transfer_start(transfer);
+  if (!iree_status_is_ok(start_status)) {
+    iree_hal_amdgpu_staging_transfer_fail_signals(transfer, start_status);
+  }
+}
+
+static iree_status_t iree_hal_amdgpu_staging_transfer_validate_buffer(
+    iree_hal_amdgpu_staging_transfer_kind_t kind, iree_hal_buffer_t* buffer,
+    iree_device_size_t buffer_offset, iree_device_size_t length) {
+  IREE_RETURN_IF_ERROR(
+      iree_hal_buffer_validate_range(buffer, buffer_offset, length));
+  IREE_RETURN_IF_ERROR(iree_hal_buffer_validate_usage(
+      iree_hal_buffer_allowed_usage(buffer),
+      kind == IREE_HAL_AMDGPU_STAGING_TRANSFER_READ
+          ? IREE_HAL_BUFFER_USAGE_TRANSFER_TARGET
+          : IREE_HAL_BUFFER_USAGE_TRANSFER_SOURCE));
+  IREE_RETURN_IF_ERROR(iree_hal_buffer_validate_access(
+      iree_hal_buffer_allowed_access(buffer),
+      kind == IREE_HAL_AMDGPU_STAGING_TRANSFER_READ
+          ? IREE_HAL_MEMORY_ACCESS_WRITE
+          : IREE_HAL_MEMORY_ACCESS_READ));
+  return iree_ok_status();
+}
+
+static iree_status_t iree_hal_amdgpu_staging_transfer_create(
+    iree_hal_amdgpu_host_queue_t* queue,
+    iree_hal_amdgpu_staging_transfer_kind_t kind, iree_hal_file_t* file,
+    uint64_t file_offset, iree_hal_buffer_t* buffer,
+    iree_device_size_t buffer_offset, iree_device_size_t length,
+    uint32_t profile_wait_count,
+    const iree_hal_semaphore_list_t signal_semaphore_list,
+    iree_hal_amdgpu_staging_transfer_t** out_transfer) {
+  *out_transfer = NULL;
+  if (IREE_UNLIKELY(!queue->staging_pool || !queue->staging_pool->buffer)) {
+    return iree_make_status(
+        IREE_STATUS_FAILED_PRECONDITION,
+        "AMDGPU queue file staging pool is not initialized");
+  }
+  if (IREE_UNLIKELY(!iree_hal_file_async_handle(file))) {
+    return iree_make_status(
+        IREE_STATUS_UNIMPLEMENTED,
+        "AMDGPU staged queue file transfers require a proactor-backed async "
+        "file handle");
+  }
+  IREE_RETURN_IF_ERROR(iree_hal_amdgpu_staging_transfer_validate_buffer(
+      kind, buffer, buffer_offset, length));
+
+  iree_host_size_t chunks_size = 0;
+  if (!iree_host_size_checked_mul(queue->staging_pool->slot_count,
+                                  sizeof(iree_hal_amdgpu_staging_chunk_t),
+                                  &chunks_size)) {
+    return iree_make_status(IREE_STATUS_OUT_OF_RANGE,
+                            "staging transfer chunk table size overflows");
+  }
+  iree_host_size_t total_size = 0;
+  if (!iree_host_size_checked_add(sizeof(iree_hal_amdgpu_staging_transfer_t),
+                                  chunks_size, &total_size)) {
+    return iree_make_status(IREE_STATUS_OUT_OF_RANGE,
+                            "staging transfer allocation size overflows");
+  }
+
+  IREE_TRACE_ZONE_BEGIN(z0);
+  IREE_TRACE_ZONE_APPEND_VALUE_I64(z0, length);
+  IREE_TRACE_ZONE_APPEND_VALUE_I64(z0, queue->staging_pool->slot_count);
+  iree_hal_amdgpu_staging_transfer_t* transfer = NULL;
+  IREE_RETURN_AND_END_ZONE_IF_ERROR(
+      z0, iree_allocator_malloc(queue->host_allocator, total_size,
+                                (void**)&transfer));
+  memset(transfer, 0, total_size);
+  iree_hal_resource_initialize(&iree_hal_amdgpu_staging_transfer_vtable,
+                               &transfer->resource);
+  transfer->host_allocator = queue->host_allocator;
+  iree_slim_mutex_initialize(&transfer->mutex);
+  transfer->queue = queue;
+  transfer->pool = queue->staging_pool;
+  transfer->logical_device = queue->logical_device;
+  iree_hal_device_retain(transfer->logical_device);
+  transfer->file = file;
+  iree_hal_file_retain(transfer->file);
+  transfer->async_file = iree_hal_file_async_handle(file);
+  transfer->buffer = buffer;
+  iree_hal_buffer_retain(transfer->buffer);
+  transfer->file_offset = file_offset;
+  transfer->buffer_offset = buffer_offset;
+  transfer->requested_length = length;
+  transfer->chunk_count = queue->staging_pool->slot_count;
+  transfer->profile_wait_count = profile_wait_count;
+  transfer->kind = kind;
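+  // The chunk table lives in the same allocation, immediately after the
+  // transfer struct (covered by the overflow-checked total_size above).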
+  transfer->chunks = (iree_hal_amdgpu_staging_chunk_t*)(transfer + 1);
+  for (uint32_t i = 0; i < transfer->chunk_count; ++i) {
+    transfer->chunks[i].transfer = transfer;
+  }
+
+  iree_status_t status = iree_hal_semaphore_list_clone(
+      &signal_semaphore_list, transfer->host_allocator,
+      &transfer->signal_semaphore_list);
+  if (iree_status_is_ok(status)) {
+    *out_transfer = transfer;
+  } else {
+    iree_hal_resource_release(&transfer->resource);
+  }
+  IREE_TRACE_ZONE_END(z0);
+  return status;
+}
+
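+// Creates the staged transfer and parks it on the host queue as a host action
+// that runs once |wait_semaphore_list| resolves. The enqueued action retains
+// the transfer through its resource list, so the local reference is dropped
+// immediately after enqueue.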
+static iree_status_t iree_hal_amdgpu_host_queue_submit_staged_transfer(
+    iree_hal_amdgpu_host_queue_t* queue,
+    const iree_hal_semaphore_list_t wait_semaphore_list,
+    const iree_hal_semaphore_list_t signal_semaphore_list,
+    iree_hal_amdgpu_staging_transfer_kind_t kind, iree_hal_file_t* file,
+    uint64_t file_offset, iree_hal_buffer_t* buffer,
+    iree_device_size_t buffer_offset, iree_device_size_t length) {
+  iree_hal_amdgpu_staging_transfer_t* transfer = NULL;
+  const uint32_t profile_wait_count =
+      iree_hal_amdgpu_host_queue_profile_semaphore_count(wait_semaphore_list);
+  IREE_RETURN_IF_ERROR(iree_hal_amdgpu_staging_transfer_create(
+      queue, kind, file, file_offset, buffer, buffer_offset, length,
+      profile_wait_count, signal_semaphore_list, &transfer));
+
+  iree_hal_resource_t* resources[1] = {&transfer->resource};
+  iree_status_t status = iree_hal_amdgpu_host_queue_enqueue_host_action(
+      queue, wait_semaphore_list,
+      (iree_hal_amdgpu_reclaim_action_t){
+          .fn = iree_hal_amdgpu_staging_transfer_execute,
+          .user_data = transfer,
+      },
+      resources, IREE_ARRAYSIZE(resources));
+  iree_hal_resource_release(&transfer->resource);
+  return status;
+}
+
+iree_status_t iree_hal_amdgpu_host_queue_submit_staged_read(
+    iree_hal_amdgpu_host_queue_t* queue,
+    const iree_hal_semaphore_list_t wait_semaphore_list,
+    const iree_hal_semaphore_list_t signal_semaphore_list,
+    iree_hal_file_t* source_file, uint64_t source_offset,
+    iree_hal_buffer_t* target_buffer, iree_device_size_t target_offset,
+    iree_device_size_t length) {
+  return iree_hal_amdgpu_host_queue_submit_staged_transfer(
+      queue, wait_semaphore_list, signal_semaphore_list,
+      IREE_HAL_AMDGPU_STAGING_TRANSFER_READ, source_file, source_offset,
+      target_buffer, target_offset, length);
+}
+
+iree_status_t iree_hal_amdgpu_host_queue_submit_staged_write(
+    iree_hal_amdgpu_host_queue_t* queue,
+    const iree_hal_semaphore_list_t wait_semaphore_list,
+    const iree_hal_semaphore_list_t signal_semaphore_list,
+    iree_hal_buffer_t* source_buffer, iree_device_size_t source_offset,
+    iree_hal_file_t* target_file, uint64_t target_offset,
+    iree_device_size_t length) {
+  return iree_hal_amdgpu_host_queue_submit_staged_transfer(
+      queue, wait_semaphore_list, signal_semaphore_list,
+      IREE_HAL_AMDGPU_STAGING_TRANSFER_WRITE, target_file, target_offset,
+      source_buffer, source_offset, length);
+}
diff --git a/runtime/src/iree/hal/drivers/amdgpu/host_queue_staging.h b/runtime/src/iree/hal/drivers/amdgpu/host_queue_staging.h
new file mode 100644
index 0000000..9b876ef
--- /dev/null
+++ b/runtime/src/iree/hal/drivers/amdgpu/host_queue_staging.h
@@ -0,0 +1,135 @@
+// Copyright 2026 The IREE Authors
+//
+// Licensed under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+
+#ifndef IREE_HAL_DRIVERS_AMDGPU_HOST_QUEUE_STAGING_H_
+#define IREE_HAL_DRIVERS_AMDGPU_HOST_QUEUE_STAGING_H_
+
+#include "iree/base/api.h"
+#include "iree/base/threading/mutex.h"
+#include "iree/hal/api.h"
+#include "iree/hal/drivers/amdgpu/system.h"
+#include "iree/hal/drivers/amdgpu/util/libhsa.h"
+
+#ifdef __cplusplus
+extern "C" {
+#endif  // __cplusplus
+
+// Default byte length of one file staging slot.
+#define IREE_HAL_AMDGPU_STAGING_SLOT_SIZE_DEFAULT (16 * 1024 * 1024)
+
+// Required byte alignment for the staging allocation base and every slot.
+#define IREE_HAL_AMDGPU_STAGING_SLOT_ALIGNMENT (2 * 1024 * 1024)
+
+// Default number of file staging slots per physical device.
+#define IREE_HAL_AMDGPU_STAGING_SLOT_COUNT_DEFAULT 4u
+
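+// With the defaults above a staging pool reserves slot_count * slot_size =
+// 4 * 16 MiB = 64 MiB per physical device; transfers larger than one slot
+// are chunked and reuse slots as earlier chunks complete.
+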
+typedef struct iree_hal_amdgpu_host_queue_t iree_hal_amdgpu_host_queue_t;
+typedef struct iree_hal_amdgpu_staging_pool_waiter_t
+    iree_hal_amdgpu_staging_pool_waiter_t;
+
+// Options controlling the per-physical-device queue_read/queue_write staging
+// pool.
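+//
+// A minimal configuration sketch (illustrative values, not the defaults):
+//   iree_hal_amdgpu_staging_pool_options_t options;
+//   iree_hal_amdgpu_staging_pool_options_initialize(&options);
+//   options.slot_size = 4 * 1024 * 1024;  // power of two >= slot alignment
+//   options.slot_count = 8;               // non-zero power of two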
+typedef struct iree_hal_amdgpu_staging_pool_options_t {
+  // Byte length of each staging slot. Must be a non-zero power of two and at
+  // least IREE_HAL_AMDGPU_STAGING_SLOT_ALIGNMENT so every slot begins on a
+  // large-page-friendly boundary.
+  iree_host_size_t slot_size;
+  // Number of staging slots. Must be non-zero and a power of two.
+  uint32_t slot_count;
+  // Forces the staging allocation to use the fine-grained host pool instead of
+  // the preferred coarse-grained host pool.
+  uint32_t force_fine_host_memory : 1;
+  // Reserved for future staging allocation policy bits.
+  uint32_t reserved : 31;
+} iree_hal_amdgpu_staging_pool_options_t;
+
+// Fixed-size host/device-visible staging pool shared by one physical device.
+typedef struct iree_hal_amdgpu_staging_pool_t {
+  // Host allocator used for the free-slot ring.
+  iree_allocator_t host_allocator;
+  // HAL buffer wrapping the whole staging allocation.
+  iree_hal_buffer_t* buffer;
+  // Host pointer to the first staging byte.
+  uint8_t* host_base;
+  // Byte length of each staging slot.
+  iree_host_size_t slot_size;
+  // Number of staging slots in |buffer|.
+  uint32_t slot_count;
+  // Mask used to wrap free-slot ring indices.
+  uint32_t slot_mask;
+  // Serializes free-slot and waiter FIFO state.
+  iree_slim_mutex_t mutex;
+  // Number of entries currently available in |free_slots|.
+  uint32_t available_count;
+  // Read index into |free_slots|.
+  uint32_t free_read;
+  // Write index into |free_slots|.
+  uint32_t free_write;
+  // Ring of available slot ordinals.
+  uint32_t* free_slots;
+  // First waiter blocked on slot availability.
+  iree_hal_amdgpu_staging_pool_waiter_t* waiter_head;
+  // Last waiter blocked on slot availability.
+  iree_hal_amdgpu_staging_pool_waiter_t* waiter_tail;
+} iree_hal_amdgpu_staging_pool_t;
+
+// Initializes |out_options| to its default values.
+void iree_hal_amdgpu_staging_pool_options_initialize(
+    iree_hal_amdgpu_staging_pool_options_t* out_options);
+
+// Verifies |options| for use by a staging pool.
+iree_status_t iree_hal_amdgpu_staging_pool_options_verify(
+    const iree_hal_amdgpu_staging_pool_options_t* options);
+
+// Initializes a fixed-size staging pool in caller-owned storage.
+iree_status_t iree_hal_amdgpu_staging_pool_initialize(
+    iree_hal_device_t* logical_device, const iree_hal_amdgpu_libhsa_t* libhsa,
+    const iree_hal_amdgpu_topology_t* topology,
+    const iree_hal_amdgpu_host_memory_pools_t* host_memory_pools,
+    iree_hal_queue_affinity_t queue_affinity_mask,
+    const iree_hal_amdgpu_staging_pool_options_t* options,
+    iree_allocator_t host_allocator, iree_hal_amdgpu_staging_pool_t* out_pool);
+
+// Deinitializes |pool| and releases its fixed staging allocation.
+void iree_hal_amdgpu_staging_pool_deinitialize(
+    iree_hal_amdgpu_staging_pool_t* pool);
+
+// Submits a chunked fd-backed queue_read through the staging pool.
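+// The transfer is split into slot-sized chunks that cycle through the staging
+// pool; a short or failed read fails the signal semaphores with the error
+// instead of signaling success.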
+iree_status_t iree_hal_amdgpu_host_queue_submit_staged_read(
+    iree_hal_amdgpu_host_queue_t* queue,
+    const iree_hal_semaphore_list_t wait_semaphore_list,
+    const iree_hal_semaphore_list_t signal_semaphore_list,
+    iree_hal_file_t* source_file, uint64_t source_offset,
+    iree_hal_buffer_t* target_buffer, iree_device_size_t target_offset,
+    iree_device_size_t length);
+
+// Submits a chunked fd-backed queue_write through the staging pool.
+iree_status_t iree_hal_amdgpu_host_queue_submit_staged_write(
+    iree_hal_amdgpu_host_queue_t* queue,
+    const iree_hal_semaphore_list_t wait_semaphore_list,
+    const iree_hal_semaphore_list_t signal_semaphore_list,
+    iree_hal_buffer_t* source_buffer, iree_device_size_t source_offset,
+    iree_hal_file_t* target_file, uint64_t target_offset,
+    iree_device_size_t length);
+
+#ifdef __cplusplus
+}  // extern "C"
+#endif  // __cplusplus
+
+#endif  // IREE_HAL_DRIVERS_AMDGPU_HOST_QUEUE_STAGING_H_
diff --git a/runtime/src/iree/hal/drivers/amdgpu/host_queue_staging_test.cc b/runtime/src/iree/hal/drivers/amdgpu/host_queue_staging_test.cc
new file mode 100644
index 0000000..98f6972
--- /dev/null
+++ b/runtime/src/iree/hal/drivers/amdgpu/host_queue_staging_test.cc
@@ -0,0 +1,709 @@
+// Copyright 2026 The IREE Authors
+//
+// Licensed under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+
+#include "iree/hal/drivers/amdgpu/host_queue_staging.h"
+
+#include <cerrno>
+#include <cstdint>
+#include <cstring>
+#include <fstream>
+#include <string>
+#include <vector>
+
+#include "iree/hal/api.h"
+#include "iree/hal/cts/util/test_base.h"
+#include "iree/hal/drivers/amdgpu/host_queue.h"
+#include "iree/hal/drivers/amdgpu/logical_device.h"
+#include "iree/hal/drivers/amdgpu/physical_device.h"
+#include "iree/hal/drivers/amdgpu/queue_affinity.h"
+#include "iree/hal/drivers/amdgpu/system.h"
+#include "iree/hal/drivers/amdgpu/util/aql_emitter.h"
+#include "iree/io/file_handle.h"
+#include "iree/testing/gtest.h"
+#include "iree/testing/status_matchers.h"
+
+#if IREE_FILE_IO_ENABLE
+#include <unistd.h>
+#endif  // IREE_FILE_IO_ENABLE
+
+namespace iree::hal::amdgpu {
+namespace {
+
+using iree::hal::cts::Ref;
+
+constexpr iree_hal_queue_affinity_t kQueueAffinity0 =
+    ((iree_hal_queue_affinity_t)1ull) << 0;
+
+constexpr iree_device_size_t kStagingSlotSize =
+    IREE_HAL_AMDGPU_STAGING_SLOT_ALIGNMENT;
+constexpr iree_device_size_t kMultiSlotTransferSize =
+    kStagingSlotSize * 2 + 4096;
+
+std::vector<uint8_t> MakePatternData(size_t size) {
+  std::vector<uint8_t> data(size);
+  for (size_t i = 0; i < size; ++i) {
+    data[i] = static_cast<uint8_t>((i * 131 + (i >> 7) * 17 + 0x5A) & 0xFF);
+  }
+  return data;
+}
+
+static iree_status_t CreateSemaphore(iree_hal_device_t* device,
+                                     iree_hal_semaphore_t** out_semaphore) {
+  return iree_hal_semaphore_create(
+      device, IREE_HAL_QUEUE_AFFINITY_ANY,
+      /*initial_value=*/0, IREE_HAL_SEMAPHORE_FLAG_DEFAULT, out_semaphore);
+}
+
+static iree_hal_semaphore_list_t MakeSemaphoreList(
+    iree_hal_semaphore_t** semaphore, uint64_t* payload_value) {
+  return iree_hal_semaphore_list_t{
+      /*count=*/1,
+      /*semaphores=*/semaphore,
+      /*payload_values=*/payload_value,
+  };
+}
+
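+// Returns true if the host queue has an action parked on its post-drain
+// list, i.e. a submission waiting for ring capacity to free up.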
+static bool HostQueueHasPostDrainAction(iree_hal_amdgpu_host_queue_t* queue) {
+  iree_slim_mutex_lock(&queue->locks.post_drain_mutex);
+  const bool has_action = queue->post_drain.head != NULL;
+  iree_slim_mutex_unlock(&queue->locks.post_drain_mutex);
+  return has_action;
+}
+
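+// Enqueues a barrier-AND packet that stalls the hardware queue until
+// |blocker_signal| drops to zero, letting tests create backpressure on the
+// small AQL ring.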
+static iree_status_t EnqueueRawBlockingBarrier(
+    iree_hal_amdgpu_host_queue_t* queue, hsa_signal_t blocker_signal) {
+  const uint64_t packet_id =
+      iree_hal_amdgpu_aql_ring_reserve(&queue->aql_ring, /*count=*/1);
+  iree_hal_amdgpu_aql_packet_t* packet =
+      iree_hal_amdgpu_aql_ring_packet(&queue->aql_ring, packet_id);
+  const hsa_signal_t dep_signals[1] = {blocker_signal};
+  const uint16_t header = iree_hal_amdgpu_aql_emit_barrier_and(
+      &packet->barrier_and, dep_signals, IREE_ARRAYSIZE(dep_signals),
+      iree_hal_amdgpu_aql_packet_control_barrier_system(),
+      iree_hsa_signal_null());
+  iree_hal_amdgpu_aql_ring_commit(packet, header, /*setup=*/0);
+  iree_hal_amdgpu_aql_ring_doorbell(&queue->aql_ring, packet_id);
+  return iree_ok_status();
+}
+
+static iree_status_t WriteAll(int fd, const uint8_t* data, size_t length) {
+#if IREE_FILE_IO_ENABLE
+  size_t total_written = 0;
+  while (total_written < length) {
+    const ssize_t written =
+        write(fd, data + total_written, length - total_written);
+    if (written < 0 && errno == EINTR) continue;
+    if (written < 0) {
+      return iree_make_status(iree_status_code_from_errno(errno),
+                              "write failed after %" PRIhsz " of %" PRIhsz
+                              " bytes: %s",
+                              total_written, length, strerror(errno));
+    }
+    if (written == 0) {
+      return iree_make_status(IREE_STATUS_UNAVAILABLE,
+                              "write made no progress after %" PRIhsz
+                              " of %" PRIhsz " bytes",
+                              total_written, length);
+    }
+    total_written += (size_t)written;
+  }
+  return iree_ok_status();
+#else
+  (void)fd;
+  (void)data;
+  (void)length;
+  return iree_make_status(IREE_STATUS_UNAVAILABLE, "file I/O is disabled");
+#endif  // IREE_FILE_IO_ENABLE
+}
+
+static void ExpectByteRangeRepeated(const std::vector<uint8_t>& data,
+                                    uint8_t pattern) {
+  for (size_t i = 0; i < data.size(); ++i) {
+    if (data[i] != pattern) {
+      ADD_FAILURE() << "byte mismatch at offset " << i << ": expected 0x"
+                    << std::hex << static_cast<int>(pattern) << ", got 0x"
+                    << static_cast<int>(data[i]);
+      return;
+    }
+  }
+}
+
+static void ExpectByteRangeMatches(const std::vector<uint8_t>& actual,
+                                   const std::vector<uint8_t>& expected) {
+  ASSERT_EQ(actual.size(), expected.size());
+  if (!actual.empty()) {
+    EXPECT_EQ(std::memcmp(actual.data(), expected.data(), actual.size()), 0);
+  }
+}
+
+class HostQueueStagingTest : public ::testing::Test {
+ protected:
+  static void SetUpTestSuite() {
+    host_allocator_ = iree_allocator_system();
+    iree_status_t status = iree_hal_amdgpu_libhsa_initialize(
+        IREE_HAL_AMDGPU_LIBHSA_FLAG_NONE, iree_string_view_list_empty(),
+        host_allocator_, &libhsa_);
+    if (!iree_status_is_ok(status)) {
+      iree_status_fprint(stderr, status);
+      iree_status_free(status);
+      GTEST_SKIP() << "HSA not available, skipping tests";
+    }
+    IREE_ASSERT_OK(iree_hal_amdgpu_topology_initialize_with_defaults(
+        &libhsa_, &topology_));
+    if (topology_.gpu_agent_count == 0) {
+      GTEST_SKIP() << "no GPU devices available, skipping tests";
+    }
+  }
+
+  static void TearDownTestSuite() {
+    iree_hal_amdgpu_topology_deinitialize(&topology_);
+    iree_hal_amdgpu_libhsa_deinitialize(&libhsa_);
+  }
+
+  void TearDown() override {
+#if IREE_FILE_IO_ENABLE
+    for (const auto& path : temp_paths_) {
+      unlink(path.c_str());
+    }
+#endif  // IREE_FILE_IO_ENABLE
+  }
+
+  class TestLogicalDevice {
+   public:
+    ~TestLogicalDevice() {
+      iree_hal_device_release(base_device_);
+      iree_hal_device_group_release(device_group_);
+    }
+
+    iree_status_t Initialize(
+        const iree_hal_amdgpu_logical_device_options_t* options,
+        const iree_hal_amdgpu_libhsa_t* libhsa,
+        const iree_hal_amdgpu_topology_t* topology,
+        iree_allocator_t host_allocator) {
+      IREE_RETURN_IF_ERROR(create_context_.Initialize(host_allocator));
+      IREE_RETURN_IF_ERROR(iree_hal_amdgpu_logical_device_create(
+          IREE_SV("amdgpu"), options, libhsa, topology,
+          create_context_.params(), host_allocator, &base_device_));
+      return iree_hal_device_group_create_from_device(
+          base_device_, create_context_.frontier_tracker(), host_allocator,
+          &device_group_);
+    }
+
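+    // Replaces the default file staging pool on the first physical device so
+    // tests can shrink it (e.g. to one slot) and exercise slot reuse paths.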
+    iree_status_t ReinitializeFileStagingPool(
+        const iree_hal_amdgpu_staging_pool_options_t* options,
+        iree_allocator_t host_allocator) {
+      iree_hal_amdgpu_logical_device_t* logical_device = this->logical_device();
+      iree_hal_amdgpu_physical_device_t* physical_device =
+          this->first_physical_device();
+      if (!physical_device) {
+        return iree_make_status(IREE_STATUS_FAILED_PRECONDITION,
+                                "test device has no physical devices");
+      }
+
+      iree_hal_queue_affinity_t queue_affinity_mask = 0;
+      const iree_hal_amdgpu_queue_affinity_domain_t domain = {
+          .supported_affinity = logical_device->queue_affinity_mask,
+          .physical_device_count = logical_device->physical_device_count,
+          .queue_count_per_physical_device =
+              physical_device->host_queue_capacity,
+      };
+      IREE_RETURN_IF_ERROR(iree_hal_amdgpu_queue_affinity_for_physical_device(
+          domain, physical_device->device_ordinal, &queue_affinity_mask));
+
+      iree_hal_amdgpu_staging_pool_deinitialize(
+          &physical_device->file_staging_pool);
+      return iree_hal_amdgpu_staging_pool_initialize(
+          base_device_, &logical_device->system->libhsa,
+          &logical_device->system->topology,
+          &physical_device->host_memory_pools, queue_affinity_mask, options,
+          host_allocator, &physical_device->file_staging_pool);
+    }
+
+    iree_hal_device_t* base_device() const { return base_device_; }
+
+    iree_hal_allocator_t* allocator() const {
+      return iree_hal_device_allocator(base_device_);
+    }
+
+    iree_hal_amdgpu_logical_device_t* logical_device() const {
+      return (iree_hal_amdgpu_logical_device_t*)base_device_;
+    }
+
+    iree_hal_amdgpu_physical_device_t* first_physical_device() const {
+      iree_hal_amdgpu_logical_device_t* logical_device = this->logical_device();
+      if (logical_device->physical_device_count == 0) return NULL;
+      return logical_device->physical_devices[0];
+    }
+
+    iree_hal_amdgpu_host_queue_t* first_host_queue() const {
+      iree_hal_amdgpu_physical_device_t* physical_device =
+          this->first_physical_device();
+      if (!physical_device || physical_device->host_queue_count == 0) {
+        return NULL;
+      }
+      return &physical_device->host_queues[0];
+    }
+
+   private:
+    // Creation context supplying the proactor pool and frontier tracker.
+    iree::hal::cts::DeviceCreateContext create_context_;
+
+    // Test-owned device reference released before the topology-owning group.
+    iree_hal_device_t* base_device_ = NULL;
+
+    // Device group that owns the topology assigned to |base_device_|.
+    iree_hal_device_group_t* device_group_ = NULL;
+  };
+
+  iree_hal_amdgpu_staging_pool_options_t OneSlotStagingOptions() {
+    iree_hal_amdgpu_staging_pool_options_t options;
+    iree_hal_amdgpu_staging_pool_options_initialize(&options);
+    options.slot_size = kStagingSlotSize;
+    options.slot_count = 1;
+    return options;
+  }
+
+  iree_status_t CreateTestDevice(
+      const iree_hal_amdgpu_logical_device_options_t* options,
+      TestLogicalDevice* out_device) {
+    IREE_RETURN_IF_ERROR(
+        out_device->Initialize(options, &libhsa_, &topology_, host_allocator_));
+    iree_hal_amdgpu_staging_pool_options_t staging_options =
+        OneSlotStagingOptions();
+    return out_device->ReinitializeFileStagingPool(&staging_options,
+                                                   host_allocator_);
+  }
+
+  iree_status_t CreateTempFileWithContents(const std::vector<uint8_t>& data,
+                                           std::string* out_path) {
+#if IREE_FILE_IO_ENABLE
+    *out_path = std::string();
+    char temp_path[] = "/tmp/iree_hal_amdgpu_staging_XXXXXX";
+    int fd = mkstemp(temp_path);
+    if (fd < 0) {
+      return iree_make_status(iree_status_code_from_errno(errno),
+                              "mkstemp failed: %s", strerror(errno));
+    }
+    temp_paths_.push_back(temp_path);
+    iree_status_t status = WriteAll(fd, data.data(), data.size());
+    if (close(fd) != 0) {
+      status = iree_status_join(
+          status, iree_make_status(iree_status_code_from_errno(errno),
+                                   "close failed: %s", strerror(errno)));
+    }
+    if (iree_status_is_ok(status)) {
+      *out_path = temp_path;
+    }
+    return status;
+#else
+    (void)data;
+    *out_path = std::string();
+    return iree_make_status(IREE_STATUS_UNAVAILABLE, "file I/O is disabled");
+#endif  // IREE_FILE_IO_ENABLE
+  }
+
+  iree_status_t CreatePatternedTempFile(size_t size, uint8_t pattern,
+                                        std::string* out_path) {
+    std::vector<uint8_t> data(size, pattern);
+    return CreateTempFileWithContents(data, out_path);
+  }
+
+  iree_status_t TruncateTempFile(const std::string& path, size_t length) {
+#if IREE_FILE_IO_ENABLE
+    if (truncate(path.c_str(), length) != 0) {
+      return iree_make_status(iree_status_code_from_errno(errno),
+                              "truncate failed: %s", strerror(errno));
+    }
+    return iree_ok_status();
+#else
+    (void)path;
+    (void)length;
+    return iree_make_status(IREE_STATUS_UNAVAILABLE, "file I/O is disabled");
+#endif  // IREE_FILE_IO_ENABLE
+  }
+
+  std::vector<uint8_t> ReadTempFileContents(const std::string& path,
+                                            size_t length) {
+    std::vector<uint8_t> data(length);
+    std::ifstream file(path, std::ios::binary);
+    EXPECT_TRUE(file.good());
+    if (file.good()) {
+      file.read(reinterpret_cast<char*>(data.data()),
+                static_cast<std::streamsize>(data.size()));
+      EXPECT_EQ(file.gcount(), static_cast<std::streamsize>(data.size()));
+    }
+    return data;
+  }
+
+  iree_status_t ImportFdFile(iree_hal_device_t* device, const std::string& path,
+                             iree_hal_memory_access_t access,
+                             iree_hal_file_t** out_file) {
+    iree_io_file_mode_t mode = IREE_IO_FILE_MODE_READ;
+    if (iree_all_bits_set(access, IREE_HAL_MEMORY_ACCESS_WRITE)) {
+      mode |= IREE_IO_FILE_MODE_WRITE;
+    }
+    iree_io_file_handle_t* handle = NULL;
+    IREE_RETURN_IF_ERROR(
+        iree_io_file_handle_open(mode, iree_make_cstring_view(path.c_str()),
+                                 iree_allocator_system(), &handle));
+    iree_status_t status = iree_hal_file_import(
+        device, IREE_HAL_QUEUE_AFFINITY_ANY, access, handle,
+        IREE_HAL_EXTERNAL_FILE_FLAG_NONE, out_file);
+    iree_io_file_handle_release(handle);
+    return status;
+  }
+
+  iree_status_t CreatePatternedDeviceBuffer(iree_hal_allocator_t* allocator,
+                                            iree_hal_device_t* device,
+                                            iree_device_size_t buffer_size,
+                                            uint8_t pattern,
+                                            iree_hal_buffer_t** out_buffer) {
+    iree_hal_buffer_params_t params = {0};
+    params.type = IREE_HAL_MEMORY_TYPE_OPTIMAL_FOR_DEVICE;
+    params.access = IREE_HAL_MEMORY_ACCESS_ALL;
+    params.usage =
+        IREE_HAL_BUFFER_USAGE_TRANSFER | IREE_HAL_BUFFER_USAGE_DISPATCH_STORAGE;
+    IREE_RETURN_IF_ERROR(iree_hal_allocator_allocate_buffer(
+        allocator, params, buffer_size, out_buffer));
+    return FillDeviceBufferRange(device, *out_buffer, /*offset=*/0, buffer_size,
+                                 pattern);
+  }
+
+  iree_status_t FillDeviceBufferRange(iree_hal_device_t* device,
+                                      iree_hal_buffer_t* buffer,
+                                      iree_device_size_t offset,
+                                      iree_device_size_t length,
+                                      uint8_t pattern) {
+    if (length == 0) return iree_ok_status();
+    Ref<iree_hal_semaphore_t> signal_semaphore;
+    IREE_RETURN_IF_ERROR(CreateSemaphore(device, signal_semaphore.out()));
+    uint64_t signal_value = 1;
+    iree_hal_semaphore_t* signal_semaphore_ptr = signal_semaphore.get();
+    iree_hal_semaphore_list_t signal_list =
+        MakeSemaphoreList(&signal_semaphore_ptr, &signal_value);
+    IREE_RETURN_IF_ERROR(iree_hal_device_queue_fill(
+        device, IREE_HAL_QUEUE_AFFINITY_ANY, iree_hal_semaphore_list_empty(),
+        signal_list, buffer, offset, length, &pattern, sizeof(pattern),
+        IREE_HAL_FILL_FLAG_NONE));
+    return iree_hal_semaphore_wait(signal_semaphore, signal_value,
+                                   iree_infinite_timeout(),
+                                   IREE_ASYNC_WAIT_FLAG_NONE);
+  }
+
+  iree_status_t QueueReadAndWait(iree_hal_device_t* device,
+                                 iree_hal_file_t* source_file,
+                                 uint64_t source_offset,
+                                 iree_hal_buffer_t* target_buffer,
+                                 iree_device_size_t target_offset,
+                                 iree_device_size_t length) {
+    Ref<iree_hal_semaphore_t> signal_semaphore;
+    IREE_RETURN_IF_ERROR(CreateSemaphore(device, signal_semaphore.out()));
+    uint64_t signal_value = 1;
+    iree_hal_semaphore_t* signal_semaphore_ptr = signal_semaphore.get();
+    iree_hal_semaphore_list_t signal_list =
+        MakeSemaphoreList(&signal_semaphore_ptr, &signal_value);
+    IREE_RETURN_IF_ERROR(iree_hal_device_queue_read(
+        device, IREE_HAL_QUEUE_AFFINITY_ANY, iree_hal_semaphore_list_empty(),
+        signal_list, source_file, source_offset, target_buffer, target_offset,
+        length, IREE_HAL_READ_FLAG_NONE));
+    return iree_hal_semaphore_wait(signal_semaphore, signal_value,
+                                   iree_infinite_timeout(),
+                                   IREE_ASYNC_WAIT_FLAG_NONE);
+  }
+
+  iree_status_t QueueWriteAndWait(iree_hal_device_t* device,
+                                  iree_hal_buffer_t* source_buffer,
+                                  iree_device_size_t source_offset,
+                                  iree_hal_file_t* target_file,
+                                  uint64_t target_offset,
+                                  iree_device_size_t length) {
+    Ref<iree_hal_semaphore_t> signal_semaphore;
+    IREE_RETURN_IF_ERROR(CreateSemaphore(device, signal_semaphore.out()));
+    uint64_t signal_value = 1;
+    iree_hal_semaphore_t* signal_semaphore_ptr = signal_semaphore.get();
+    iree_hal_semaphore_list_t signal_list =
+        MakeSemaphoreList(&signal_semaphore_ptr, &signal_value);
+    IREE_RETURN_IF_ERROR(iree_hal_device_queue_write(
+        device, IREE_HAL_QUEUE_AFFINITY_ANY, iree_hal_semaphore_list_empty(),
+        signal_list, source_buffer, source_offset, target_file, target_offset,
+        length, IREE_HAL_WRITE_FLAG_NONE));
+    return iree_hal_semaphore_wait(signal_semaphore, signal_value,
+                                   iree_infinite_timeout(),
+                                   IREE_ASYNC_WAIT_FLAG_NONE);
+  }
+
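+  // Reads device buffer contents back to the host by wrapping a host vector
+  // as a writable file and issuing a queue_write from the buffer into it.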
+  iree_status_t ReadBufferContents(iree_hal_device_t* device,
+                                   iree_hal_buffer_t* buffer,
+                                   iree_device_size_t offset,
+                                   iree_device_size_t length,
+                                   std::vector<uint8_t>* out_data) {
+    out_data->assign((size_t)length, 0);
+    iree_io_file_handle_t* handle = NULL;
+    IREE_RETURN_IF_ERROR(iree_io_file_handle_wrap_host_allocation(
+        IREE_IO_FILE_ACCESS_READ | IREE_IO_FILE_ACCESS_WRITE,
+        iree_make_byte_span(out_data->data(), out_data->size()),
+        iree_io_file_handle_release_callback_null(), iree_allocator_system(),
+        &handle));
+    Ref<iree_hal_file_t> file;
+    iree_status_t status = iree_hal_file_import(
+        device, IREE_HAL_QUEUE_AFFINITY_ANY, IREE_HAL_MEMORY_ACCESS_WRITE,
+        handle, IREE_HAL_EXTERNAL_FILE_FLAG_NONE, file.out());
+    iree_io_file_handle_release(handle);
+    if (iree_status_is_ok(status)) {
+      status = QueueWriteAndWait(device, buffer, offset, file,
+                                 /*target_offset=*/0, length);
+    }
+    return status;
+  }
+
+  static iree_allocator_t host_allocator_;
+  static iree_hal_amdgpu_libhsa_t libhsa_;
+  static iree_hal_amdgpu_topology_t topology_;
+
+  std::vector<std::string> temp_paths_;
+};
+
+iree_allocator_t HostQueueStagingTest::host_allocator_;
+iree_hal_amdgpu_libhsa_t HostQueueStagingTest::libhsa_;
+iree_hal_amdgpu_topology_t HostQueueStagingTest::topology_;
+
+#if IREE_FILE_IO_ENABLE
+
+TEST_F(HostQueueStagingTest, OneSlotLargeReadCompletesThroughSlotWaiter) {
+  iree_hal_amdgpu_logical_device_options_t options;
+  iree_hal_amdgpu_logical_device_options_initialize(&options);
+  options.preallocate_pools = 0;
+
+  TestLogicalDevice test_device;
+  IREE_ASSERT_OK(CreateTestDevice(&options, &test_device));
+  iree_hal_amdgpu_physical_device_t* physical_device =
+      test_device.first_physical_device();
+  ASSERT_NE(physical_device, nullptr);
+  ASSERT_EQ(physical_device->file_staging_pool.slot_count, 1u);
+
+  std::vector<uint8_t> file_data = MakePatternData(kMultiSlotTransferSize);
+  std::string path;
+  IREE_ASSERT_OK(CreateTempFileWithContents(file_data, &path));
+
+  Ref<iree_hal_file_t> file;
+  IREE_ASSERT_OK(ImportFdFile(test_device.base_device(), path,
+                              IREE_HAL_MEMORY_ACCESS_READ, file.out()));
+  Ref<iree_hal_buffer_t> buffer;
+  IREE_ASSERT_OK(CreatePatternedDeviceBuffer(
+      test_device.allocator(), test_device.base_device(),
+      kMultiSlotTransferSize, 0x00, buffer.out()));
+
+  IREE_ASSERT_OK(QueueReadAndWait(test_device.base_device(), file,
+                                  /*source_offset=*/0, buffer,
+                                  /*target_offset=*/0, kMultiSlotTransferSize));
+
+  std::vector<uint8_t> contents;
+  IREE_ASSERT_OK(ReadBufferContents(test_device.base_device(), buffer,
+                                    /*offset=*/0, kMultiSlotTransferSize,
+                                    &contents));
+  ExpectByteRangeMatches(contents, file_data);
+}
+
+TEST_F(HostQueueStagingTest, OneSlotLargeWriteCompletesThroughSlotWaiter) {
+  iree_hal_amdgpu_logical_device_options_t options;
+  iree_hal_amdgpu_logical_device_options_initialize(&options);
+  options.preallocate_pools = 0;
+
+  TestLogicalDevice test_device;
+  IREE_ASSERT_OK(CreateTestDevice(&options, &test_device));
+  iree_hal_amdgpu_physical_device_t* physical_device =
+      test_device.first_physical_device();
+  ASSERT_NE(physical_device, nullptr);
+  ASSERT_EQ(physical_device->file_staging_pool.slot_count, 1u);
+
+  std::string path;
+  IREE_ASSERT_OK(CreatePatternedTempFile(kMultiSlotTransferSize, 0x00, &path));
+
+  Ref<iree_hal_file_t> file;
+  IREE_ASSERT_OK(ImportFdFile(
+      test_device.base_device(), path,
+      IREE_HAL_MEMORY_ACCESS_READ | IREE_HAL_MEMORY_ACCESS_WRITE, file.out()));
+  Ref<iree_hal_buffer_t> buffer;
+  IREE_ASSERT_OK(CreatePatternedDeviceBuffer(
+      test_device.allocator(), test_device.base_device(),
+      kMultiSlotTransferSize, 0xA7, buffer.out()));
+
+  IREE_ASSERT_OK(QueueWriteAndWait(test_device.base_device(), buffer,
+                                   /*source_offset=*/0, file,
+                                   /*target_offset=*/0,
+                                   kMultiSlotTransferSize));
+
+  std::vector<uint8_t> contents =
+      ReadTempFileContents(path, kMultiSlotTransferSize);
+  ASSERT_EQ(contents.size(), kMultiSlotTransferSize);
+  ExpectByteRangeRepeated(contents, 0xA7);
+}
+
+TEST_F(HostQueueStagingTest, CapacityParkedStagedWriteRetriesAfterPostDrain) {
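+  // Wedge the AQL ring behind a host-controlled barrier, saturate the small
+  // ring, then issue a multi-slot staged write so it parks on the post-drain
+  // list; releasing the barrier must retry and complete the write.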
+  static constexpr uint32_t kAqlCapacity = 64;
+  static constexpr uint32_t kNotificationCapacity = 1;
+  static constexpr uint32_t kKernargCapacity = 2 * kAqlCapacity;
+
+  iree_hal_amdgpu_logical_device_options_t options;
+  iree_hal_amdgpu_logical_device_options_initialize(&options);
+  options.host_queues.aql_capacity = kAqlCapacity;
+  options.host_queues.notification_capacity = kNotificationCapacity;
+  options.host_queues.kernarg_capacity = kKernargCapacity;
+  options.preallocate_pools = 0;
+
+  TestLogicalDevice test_device;
+  IREE_ASSERT_OK(CreateTestDevice(&options, &test_device));
+  iree_hal_amdgpu_host_queue_t* queue = test_device.first_host_queue();
+  ASSERT_NE(queue, nullptr);
+
+  std::string path;
+  IREE_ASSERT_OK(CreatePatternedTempFile(kMultiSlotTransferSize, 0x00, &path));
+  Ref<iree_hal_file_t> file;
+  IREE_ASSERT_OK(ImportFdFile(
+      test_device.base_device(), path,
+      IREE_HAL_MEMORY_ACCESS_READ | IREE_HAL_MEMORY_ACCESS_WRITE, file.out()));
+
+  Ref<iree_hal_buffer_t> source_buffer;
+  IREE_ASSERT_OK(CreatePatternedDeviceBuffer(
+      test_device.allocator(), test_device.base_device(),
+      kMultiSlotTransferSize, 0x3C, source_buffer.out()));
+  Ref<iree_hal_buffer_t> pressure_buffer;
+  IREE_ASSERT_OK(CreatePatternedDeviceBuffer(
+      test_device.allocator(), test_device.base_device(), sizeof(uint32_t),
+      0x00, pressure_buffer.out()));
+
+  hsa_signal_t blocker_signal = iree_hsa_signal_null();
+  IREE_ASSERT_OK(iree_hsa_amd_signal_create(
+      IREE_LIBHSA(&libhsa_), /*initial_value=*/1, /*num_consumers=*/0,
+      /*consumers=*/NULL, /*attributes=*/0, &blocker_signal));
+  IREE_ASSERT_OK(EnqueueRawBlockingBarrier(queue, blocker_signal));
+
+  Ref<iree_hal_semaphore_t> pressure_signal;
+  IREE_ASSERT_OK(
+      CreateSemaphore(test_device.base_device(), pressure_signal.out()));
+  uint64_t pressure_signal_value = 1;
+  iree_hal_semaphore_t* pressure_signal_ptr = pressure_signal.get();
+  iree_hal_semaphore_list_t pressure_signal_list =
+      MakeSemaphoreList(&pressure_signal_ptr, &pressure_signal_value);
+  const uint32_t pressure_pattern = 0xABCD1234u;
+  iree_status_t status = iree_hal_device_queue_fill(
+      test_device.base_device(), kQueueAffinity0,
+      iree_hal_semaphore_list_empty(), pressure_signal_list, pressure_buffer,
+      /*target_offset=*/0, sizeof(pressure_pattern), &pressure_pattern,
+      sizeof(pressure_pattern), IREE_HAL_FILL_FLAG_NONE);
+
+  Ref<iree_hal_semaphore_t> write_signal;
+  if (iree_status_is_ok(status)) {
+    status = CreateSemaphore(test_device.base_device(), write_signal.out());
+  }
+  uint64_t write_signal_value = 1;
+  iree_hal_semaphore_t* write_signal_ptr = write_signal.get();
+  iree_hal_semaphore_list_t write_signal_list =
+      MakeSemaphoreList(&write_signal_ptr, &write_signal_value);
+  if (iree_status_is_ok(status)) {
+    status = iree_hal_device_queue_write(
+        test_device.base_device(), kQueueAffinity0,
+        iree_hal_semaphore_list_empty(), write_signal_list, source_buffer,
+        /*source_offset=*/0, file, /*target_offset=*/0, kMultiSlotTransferSize,
+        IREE_HAL_WRITE_FLAG_NONE);
+  }
+  const bool retry_parked =
+      iree_status_is_ok(status) && HostQueueHasPostDrainAction(queue);
+
+  iree_hsa_signal_store_screlease(IREE_LIBHSA(&libhsa_), blocker_signal, 0);
+
+  if (iree_status_is_ok(status)) {
+    status = iree_hal_semaphore_wait(write_signal, write_signal_value,
+                                     iree_infinite_timeout(),
+                                     IREE_ASYNC_WAIT_FLAG_NONE);
+  }
+  IREE_EXPECT_OK(
+      iree_hsa_signal_destroy(IREE_LIBHSA(&libhsa_), blocker_signal));
+
+  IREE_ASSERT_OK(status);
+  EXPECT_TRUE(retry_parked);
+
+  std::vector<uint8_t> contents =
+      ReadTempFileContents(path, kMultiSlotTransferSize);
+  ASSERT_EQ(contents.size(), kMultiSlotTransferSize);
+  ExpectByteRangeRepeated(contents, 0x3C);
+}
+
+TEST_F(HostQueueStagingTest, ShortReadFailsTerminalSignal) {
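+  // The read is submitted against the file's original length and the file is
+  // then truncated before the wait semaphore releases it; the resulting short
+  // read must fail the signal semaphore with OUT_OF_RANGE.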
+  iree_hal_amdgpu_logical_device_options_t options;
+  iree_hal_amdgpu_logical_device_options_initialize(&options);
+  options.preallocate_pools = 0;
+
+  TestLogicalDevice test_device;
+  IREE_ASSERT_OK(CreateTestDevice(&options, &test_device));
+
+  std::vector<uint8_t> file_data = MakePatternData(kStagingSlotSize);
+  std::string path;
+  IREE_ASSERT_OK(CreateTempFileWithContents(file_data, &path));
+
+  Ref<iree_hal_file_t> file;
+  IREE_ASSERT_OK(ImportFdFile(test_device.base_device(), path,
+                              IREE_HAL_MEMORY_ACCESS_READ, file.out()));
+  Ref<iree_hal_buffer_t> buffer;
+  IREE_ASSERT_OK(CreatePatternedDeviceBuffer(
+      test_device.allocator(), test_device.base_device(), kStagingSlotSize,
+      0x00, buffer.out()));
+
+  Ref<iree_hal_semaphore_t> wait_semaphore;
+  IREE_ASSERT_OK(
+      CreateSemaphore(test_device.base_device(), wait_semaphore.out()));
+  uint64_t wait_value = 1;
+  iree_hal_semaphore_t* wait_semaphore_ptr = wait_semaphore.get();
+  iree_hal_semaphore_list_t wait_list =
+      MakeSemaphoreList(&wait_semaphore_ptr, &wait_value);
+
+  Ref<iree_hal_semaphore_t> signal_semaphore;
+  IREE_ASSERT_OK(
+      CreateSemaphore(test_device.base_device(), signal_semaphore.out()));
+  uint64_t signal_value = 1;
+  iree_hal_semaphore_t* signal_semaphore_ptr = signal_semaphore.get();
+  iree_hal_semaphore_list_t signal_list =
+      MakeSemaphoreList(&signal_semaphore_ptr, &signal_value);
+
+  IREE_ASSERT_OK(iree_hal_device_queue_read(
+      test_device.base_device(), kQueueAffinity0, wait_list, signal_list, file,
+      /*source_offset=*/0, buffer, /*target_offset=*/0, kStagingSlotSize,
+      IREE_HAL_READ_FLAG_NONE));
+  IREE_ASSERT_OK(TruncateTempFile(path, /*length=*/0));
+  IREE_ASSERT_OK(
+      iree_hal_semaphore_signal(wait_semaphore, wait_value, /*frontier=*/NULL));
+
+  IREE_EXPECT_STATUS_IS(IREE_STATUS_OUT_OF_RANGE,
+                        iree_hal_semaphore_wait(signal_semaphore, signal_value,
+                                                iree_infinite_timeout(),
+                                                IREE_ASYNC_WAIT_FLAG_NONE));
+}
+
+#else
+
+TEST_F(HostQueueStagingTest, FileIoDisabled) {
+  GTEST_SKIP() << "file I/O is disabled";
+}
+
+#endif  // IREE_FILE_IO_ENABLE
+
+}  // namespace
+}  // namespace iree::hal::amdgpu
diff --git a/runtime/src/iree/hal/drivers/amdgpu/host_queue_submission.c b/runtime/src/iree/hal/drivers/amdgpu/host_queue_submission.c
new file mode 100644
index 0000000..2107680
--- /dev/null
+++ b/runtime/src/iree/hal/drivers/amdgpu/host_queue_submission.c
@@ -0,0 +1,1326 @@
+// Copyright 2026 The IREE Authors
+//
+// Licensed under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+
+#include "iree/hal/drivers/amdgpu/host_queue_submission.h"
+
+#include <string.h>
+
+#include "iree/hal/drivers/amdgpu/device/timestamp.h"
+#include "iree/hal/drivers/amdgpu/host_queue_policy.h"
+#include "iree/hal/drivers/amdgpu/host_queue_profile.h"
+#include "iree/hal/drivers/amdgpu/host_queue_profile_events.h"
+#include "iree/hal/drivers/amdgpu/profile_counters.h"
+#include "iree/hal/drivers/amdgpu/profile_traces.h"
+#include "iree/hal/drivers/amdgpu/semaphore.h"
+#include "iree/hal/drivers/amdgpu/util/aql_emitter.h"
+#include "iree/hal/drivers/amdgpu/util/pm4_emitter.h"
+
+// Returns true if |semaphore| has the strict private stream contract that lets
+// the signal path publish only a producer queue epoch instead of accumulating a
+// full multi-producer semaphore frontier.
+static bool iree_hal_amdgpu_host_queue_is_private_stream_signal(
+    const iree_hal_amdgpu_host_queue_t* queue,
+    iree_hal_semaphore_t* semaphore) {
+  return iree_hal_amdgpu_semaphore_has_private_stream_semantics(
+      semaphore,
+      (const iree_hal_amdgpu_logical_device_t*)queue->logical_device);
+}
+
+static uint64_t iree_hal_amdgpu_host_queue_last_drained_epoch(
+    const iree_hal_amdgpu_host_queue_t* queue) {
+  return (uint64_t)iree_atomic_load(
+      &queue->notification_ring.epoch.last_drained, iree_memory_order_acquire);
+}
+
+static bool iree_hal_amdgpu_host_queue_should_push_frontier_snapshot(
+    const iree_hal_amdgpu_host_queue_t* queue,
+    bool span_needs_frontier_snapshot, uint64_t span_epoch,
+    uint64_t last_drained_epoch) {
+  return queue->can_publish_frontier && span_needs_frontier_snapshot &&
+         span_epoch > last_drained_epoch;
+}
+
+// Returns a conservative upper bound on the number of frontier snapshots that
+// commit_signals will push for |signal_semaphore_list|.
+//
+// Caller must hold submission_mutex.
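+//
+// Counts one snapshot per transition away from a same-semaphore span: the
+// span carried over in queue->last_signal plus each run in the list except
+// the last, which remains pending as the queue's new last_signal span.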
+static iree_host_size_t iree_hal_amdgpu_host_queue_count_frontier_snapshots(
+    const iree_hal_amdgpu_host_queue_t* queue,
+    const iree_hal_semaphore_list_t signal_semaphore_list) {
+  const uint64_t last_drained_epoch =
+      iree_hal_amdgpu_host_queue_last_drained_epoch(queue);
+  iree_host_size_t snapshot_count = 0;
+  iree_async_semaphore_t* last_semaphore = queue->last_signal.semaphore;
+  uint64_t last_semaphore_epoch = queue->last_signal.epoch;
+  bool last_needs_frontier_snapshot =
+      queue->last_signal.needs_frontier_snapshot;
+  for (iree_host_size_t i = 0; i < signal_semaphore_list.count; ++i) {
+    iree_hal_semaphore_t* hal_semaphore = signal_semaphore_list.semaphores[i];
+    iree_async_semaphore_t* semaphore = (iree_async_semaphore_t*)hal_semaphore;
+    const bool needs_frontier_snapshot =
+        queue->can_publish_frontier &&
+        !iree_hal_amdgpu_host_queue_is_private_stream_signal(queue,
+                                                             hal_semaphore);
+    if (semaphore != last_semaphore) {
+      if (last_semaphore != NULL &&
+          iree_hal_amdgpu_host_queue_should_push_frontier_snapshot(
+              queue, last_needs_frontier_snapshot, last_semaphore_epoch,
+              last_drained_epoch)) {
+        ++snapshot_count;
+      }
+      last_semaphore = semaphore;
+      last_needs_frontier_snapshot = needs_frontier_snapshot;
+    }
+    // Any span that begins within this submission is necessarily still
+    // pending even if the span inherited from queue->last_signal had already
+    // drained, so treat its epoch as never-drained.
+    last_semaphore_epoch = UINT64_MAX;
+  }
+  return snapshot_count;
+}
+
+static void iree_hal_amdgpu_host_queue_push_frontier_snapshot_if_pending(
+    iree_hal_amdgpu_host_queue_t* queue,
+    const iree_async_frontier_t* queue_frontier) {
+  if (queue->last_signal.semaphore == NULL) return;
+  if (!iree_hal_amdgpu_host_queue_should_push_frontier_snapshot(
+          queue, queue->last_signal.needs_frontier_snapshot,
+          queue->last_signal.epoch,
+          iree_hal_amdgpu_host_queue_last_drained_epoch(queue))) {
+    return;
+  }
+  iree_hal_amdgpu_notification_ring_push_frontier_snapshot(
+      &queue->notification_ring, queue->last_signal.epoch, queue_frontier);
+}
+
+iree_status_t iree_hal_amdgpu_host_queue_count_reclaim_resources(
+    iree_host_size_t signal_semaphore_count,
+    iree_host_size_t operation_resource_count,
+    uint16_t* out_reclaim_resource_count) {
+  IREE_ASSERT_ARGUMENT(out_reclaim_resource_count);
+  if (signal_semaphore_count > UINT16_MAX ||
+      operation_resource_count > UINT16_MAX ||
+      signal_semaphore_count > UINT16_MAX - operation_resource_count) {
+    return iree_make_status(
+        IREE_STATUS_OUT_OF_RANGE,
+        "submission retains too many resources (signals=%" PRIhsz
+        ", operation_resources=%" PRIhsz ", max=%u)",
+        signal_semaphore_count, operation_resource_count, UINT16_MAX);
+  }
+  *out_reclaim_resource_count =
+      (uint16_t)(signal_semaphore_count + operation_resource_count);
+  return iree_ok_status();
+}
+
+// Writes |packet_count| no-op barrier packets into already-reserved AQL slots.
+// The caller controls doorbell timing so these packets can plug failure-path
+// reservations without leaving INVALID packets visible to the CP.
+static void iree_hal_amdgpu_host_queue_fill_noop_packets(
+    iree_hal_amdgpu_host_queue_t* queue, uint64_t first_packet_id,
+    uint32_t packet_count) {
+  for (uint32_t i = 0; i < packet_count; ++i) {
+    iree_hal_amdgpu_aql_packet_t* packet =
+        iree_hal_amdgpu_aql_ring_packet(&queue->aql_ring, first_packet_id + i);
+    uint16_t header = iree_hal_amdgpu_aql_emit_nop(
+        &packet->barrier_and,
+        iree_hal_amdgpu_aql_packet_control_barrier(IREE_HSA_FENCE_SCOPE_NONE,
+                                                   IREE_HSA_FENCE_SCOPE_NONE),
+        iree_hsa_signal_null());
+    iree_hal_amdgpu_aql_ring_commit(packet, header, /*setup=*/0);
+  }
+}
+
+// Emits |packet_count| no-op barrier packets and rings the doorbell. Used only
+// to plug already-reserved AQL slots on an internal failure path so the CP
+// never stalls on an INVALID header.
+static void iree_hal_amdgpu_host_queue_emit_noop_packets(
+    iree_hal_amdgpu_host_queue_t* queue, uint64_t first_packet_id,
+    uint32_t packet_count) {
+  IREE_ASSERT(packet_count > 0, "must plug at least one reserved AQL packet");
+  iree_hal_amdgpu_host_queue_fill_noop_packets(queue, first_packet_id,
+                                               packet_count);
+  iree_hal_amdgpu_aql_ring_doorbell(&queue->aql_ring,
+                                    first_packet_id + packet_count - 1);
+}
+
+// Publishes an internal completion epoch for a failed submission that already
+// reserved kernarg space. User-visible semaphores are not signaled, but the
+// normal notification drain can reclaim the kernarg ring in queue order.
+static void iree_hal_amdgpu_host_queue_emit_reclaim_noop_packets(
+    iree_hal_amdgpu_host_queue_t* queue,
+    iree_hal_amdgpu_reclaim_entry_t* reclaim_entry, uint64_t first_packet_id,
+    uint32_t packet_count, uint64_t kernarg_write_position,
+    uint64_t queue_upload_write_position) {
+  reclaim_entry->kernarg_write_position = kernarg_write_position;
+  reclaim_entry->queue_upload_write_position = queue_upload_write_position;
+  reclaim_entry->count = 0;
+  iree_hal_amdgpu_notification_ring_advance_epoch(&queue->notification_ring);
+  for (uint32_t i = 0; i < packet_count; ++i) {
+    iree_hal_amdgpu_aql_packet_t* packet =
+        iree_hal_amdgpu_aql_ring_packet(&queue->aql_ring, first_packet_id + i);
+    const bool is_final_packet = i + 1 == packet_count;
+    uint16_t header = iree_hal_amdgpu_aql_emit_nop(
+        &packet->barrier_and,
+        iree_hal_amdgpu_aql_packet_control_barrier(IREE_HSA_FENCE_SCOPE_NONE,
+                                                   IREE_HSA_FENCE_SCOPE_NONE),
+        is_final_packet ? iree_hal_amdgpu_notification_ring_epoch_signal(
+                              &queue->notification_ring)
+                        : iree_hsa_signal_null());
+    iree_hal_amdgpu_aql_ring_commit(packet, header, /*setup=*/0);
+  }
+  iree_hal_amdgpu_aql_ring_doorbell(&queue->aql_ring,
+                                    first_packet_id + packet_count - 1);
+}
+
+// Returns the packet control for the final dispatch packet in a submission.
+// Direct host-queue submissions keep BARRIER set so the queue epoch remains an
+// ordered prefix-completion clock. Dispatches carry at least AGENT acquire so
+// device-side packet execution observes host-populated kernargs.
+static iree_hal_amdgpu_aql_packet_control_t
+iree_hal_amdgpu_host_queue_final_dispatch_packet_control(
+    iree_hal_amdgpu_host_queue_t* queue,
+    const iree_hal_amdgpu_wait_resolution_t* resolution,
+    const iree_hal_semaphore_list_t signal_semaphore_list,
+    iree_hsa_fence_scope_t minimum_acquire_scope,
+    iree_hsa_fence_scope_t minimum_release_scope) {
+  return iree_hal_amdgpu_aql_packet_control_barrier(
+      iree_hal_amdgpu_host_queue_max_fence_scope(
+          iree_hal_amdgpu_host_queue_max_fence_scope(
+              IREE_HSA_FENCE_SCOPE_AGENT, resolution->inline_acquire_scope),
+          minimum_acquire_scope),
+      iree_hal_amdgpu_host_queue_max_fence_scope(
+          iree_hal_amdgpu_host_queue_signal_list_release_scope(
+              queue, signal_semaphore_list),
+          minimum_release_scope));
+}
+
+// Returns the packet control for a dispatch packet followed by a trailing
+// queue-completion packet. The dispatch no longer signals the queue/user epoch,
+// so it keeps only operation-local visibility requirements; the trailing packet
+// owns signal-list release visibility.
+static iree_hal_amdgpu_aql_packet_control_t
+iree_hal_amdgpu_host_queue_payload_dispatch_packet_control(
+    const iree_hal_amdgpu_wait_resolution_t* resolution,
+    iree_hsa_fence_scope_t minimum_acquire_scope,
+    iree_hsa_fence_scope_t minimum_release_scope) {
+  return iree_hal_amdgpu_aql_packet_control_barrier(
+      iree_hal_amdgpu_host_queue_max_fence_scope(
+          iree_hal_amdgpu_host_queue_max_fence_scope(
+              IREE_HSA_FENCE_SCOPE_AGENT, resolution->inline_acquire_scope),
+          minimum_acquire_scope),
+      minimum_release_scope);
+}
+
+// Returns the packet control for a final no-op/barrier completion packet.
+static iree_hal_amdgpu_aql_packet_control_t
+iree_hal_amdgpu_host_queue_final_barrier_packet_control(
+    iree_hal_amdgpu_host_queue_t* queue,
+    const iree_hal_amdgpu_wait_resolution_t* resolution,
+    const iree_hal_semaphore_list_t signal_semaphore_list) {
+  return iree_hal_amdgpu_aql_packet_control_barrier(
+      resolution->inline_acquire_scope,
+      iree_hal_amdgpu_host_queue_signal_list_release_scope(
+          queue, signal_semaphore_list));
+}
+
+// Returns the packet control for a final PM4-IB payload packet. PM4 IB payloads
+// are host-populated memory consumed by the CP, so they carry the same minimum
+// AGENT acquire as dispatch packets.
+static iree_hal_amdgpu_aql_packet_control_t
+iree_hal_amdgpu_host_queue_final_pm4_ib_packet_control(
+    iree_hal_amdgpu_host_queue_t* queue,
+    const iree_hal_amdgpu_wait_resolution_t* resolution,
+    const iree_hal_semaphore_list_t signal_semaphore_list) {
+  return iree_hal_amdgpu_aql_packet_control_barrier(
+      iree_hal_amdgpu_host_queue_max_fence_scope(
+          IREE_HSA_FENCE_SCOPE_AGENT, resolution->inline_acquire_scope),
+      iree_hal_amdgpu_host_queue_signal_list_release_scope(
+          queue, signal_semaphore_list));
+}
+
+// Returns the packet control for a non-final PM4-IB packet in a larger
+// submission. The final packet owns user-visible release and queue completion.
+static iree_hal_amdgpu_aql_packet_control_t
+iree_hal_amdgpu_host_queue_payload_pm4_ib_packet_control(
+    const iree_hal_amdgpu_wait_resolution_t* resolution) {
+  return iree_hal_amdgpu_aql_packet_control_barrier(
+      iree_hal_amdgpu_host_queue_max_fence_scope(
+          IREE_HSA_FENCE_SCOPE_AGENT, resolution->inline_acquire_scope),
+      IREE_HSA_FENCE_SCOPE_NONE);
+}
+
+void iree_hal_amdgpu_host_queue_commit_queue_device_start_packet(
+    iree_hal_amdgpu_host_queue_t* queue,
+    const iree_hal_amdgpu_wait_resolution_t* resolution, uint64_t packet_id,
+    iree_hal_amdgpu_profile_queue_device_event_t* queue_device_event) {
+  iree_hal_amdgpu_aql_packet_t* packet =
+      iree_hal_amdgpu_aql_ring_packet(&queue->aql_ring, packet_id);
+  iree_hal_amdgpu_pm4_ib_slot_t* pm4_ib_slot =
+      &queue->pm4_ib_slots[packet_id & queue->aql_ring.mask];
+  iree_hal_amdgpu_pm4_ib_builder_t builder;
+  iree_hal_amdgpu_pm4_ib_builder_initialize(pm4_ib_slot, &builder);
+  const bool did_emit =
+      iree_hal_amdgpu_pm4_ib_builder_emit_copy_timestamp_to_memory(
+          &builder, &queue_device_event->start_tick);
+  IREE_ASSERT(did_emit, "PM4 start timestamp must fit profiling IB slot");
+  (void)did_emit;
+  uint16_t setup = 0;
+  const uint16_t header = iree_hal_amdgpu_aql_emit_pm4_ib(
+      &packet->pm4_ib, pm4_ib_slot,
+      iree_hal_amdgpu_pm4_ib_builder_dword_count(&builder),
+      iree_hal_amdgpu_host_queue_payload_pm4_ib_packet_control(resolution),
+      iree_hsa_signal_null(), &setup);
+  iree_hal_amdgpu_aql_ring_commit(packet, header, setup);
+}
+
+void iree_hal_amdgpu_host_queue_commit_queue_device_end_packet(
+    iree_hal_amdgpu_host_queue_t* queue,
+    const iree_hal_amdgpu_wait_resolution_t* resolution,
+    const iree_hal_semaphore_list_t signal_semaphore_list, uint64_t packet_id,
+    iree_hal_amdgpu_profile_queue_device_event_t* queue_device_event) {
+  iree_hal_amdgpu_aql_packet_t* packet =
+      iree_hal_amdgpu_aql_ring_packet(&queue->aql_ring, packet_id);
+  iree_hal_amdgpu_pm4_ib_slot_t* pm4_ib_slot =
+      &queue->pm4_ib_slots[packet_id & queue->aql_ring.mask];
+  iree_hal_amdgpu_pm4_ib_builder_t builder;
+  iree_hal_amdgpu_pm4_ib_builder_initialize(pm4_ib_slot, &builder);
+  const bool did_emit =
+      iree_hal_amdgpu_pm4_ib_builder_emit_release_mem_timestamp_to_memory(
+          &builder, &queue_device_event->end_tick);
+  IREE_ASSERT(did_emit, "PM4 end timestamp must fit profiling IB slot");
+  (void)did_emit;
+  uint16_t setup = 0;
+  const uint16_t header = iree_hal_amdgpu_aql_emit_pm4_ib(
+      &packet->pm4_ib, pm4_ib_slot,
+      iree_hal_amdgpu_pm4_ib_builder_dword_count(&builder),
+      iree_hal_amdgpu_host_queue_final_pm4_ib_packet_control(
+          queue, resolution, signal_semaphore_list),
+      iree_hal_amdgpu_notification_ring_epoch_signal(&queue->notification_ring),
+      &setup);
+  iree_hal_amdgpu_aql_ring_commit(packet, header, setup);
+}
+
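+// Copies the template dispatch fields into |dispatch_packet| and patches in
+// the per-submission |kernarg_address| and |completion_signal|. The header
+// word is left to the caller so the packet only becomes valid once it is
+// committed to the ring.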
+uint16_t iree_hal_amdgpu_host_queue_write_dispatch_packet_body(
+    iree_hsa_kernel_dispatch_packet_t* IREE_RESTRICT dispatch_packet,
+    const iree_hsa_kernel_dispatch_packet_t* IREE_RESTRICT
+        dispatch_packet_template,
+    void* kernarg_address, iree_hsa_signal_t completion_signal) {
+  dispatch_packet->workgroup_size[0] =
+      dispatch_packet_template->workgroup_size[0];
+  dispatch_packet->workgroup_size[1] =
+      dispatch_packet_template->workgroup_size[1];
+  dispatch_packet->workgroup_size[2] =
+      dispatch_packet_template->workgroup_size[2];
+  dispatch_packet->reserved0 = dispatch_packet_template->reserved0;
+  dispatch_packet->grid_size[0] = dispatch_packet_template->grid_size[0];
+  dispatch_packet->grid_size[1] = dispatch_packet_template->grid_size[1];
+  dispatch_packet->grid_size[2] = dispatch_packet_template->grid_size[2];
+  dispatch_packet->private_segment_size =
+      dispatch_packet_template->private_segment_size;
+  dispatch_packet->group_segment_size =
+      dispatch_packet_template->group_segment_size;
+  dispatch_packet->kernel_object = dispatch_packet_template->kernel_object;
+  dispatch_packet->kernarg_address = kernarg_address;
+  dispatch_packet->reserved2 = dispatch_packet_template->reserved2;
+  dispatch_packet->completion_signal = completion_signal;
+  return dispatch_packet_template->setup;
+}
+
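+// Reserves everything a kernel submission needs up front: AQL packet slots,
+// kernarg ring blocks, and notification ring entries. Requests that exceed
+// total ring capacity fail outright; transient exhaustion returns OK with
+// *out_ready=false so the caller can park the submission and retry later.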
+iree_status_t iree_hal_amdgpu_host_queue_try_begin_kernel_submission(
+    iree_hal_amdgpu_host_queue_t* queue,
+    const iree_hal_amdgpu_wait_resolution_t* resolution,
+    const iree_hal_semaphore_list_t signal_semaphore_list,
+    iree_host_size_t operation_resource_count, uint32_t payload_packet_count,
+    uint32_t kernarg_block_count, bool* out_ready,
+    iree_hal_amdgpu_host_queue_kernel_submission_t* out_submission) {
+  IREE_ASSERT_ARGUMENT(queue);
+  IREE_ASSERT_ARGUMENT(resolution);
+  IREE_ASSERT_ARGUMENT(out_ready);
+  IREE_ASSERT_ARGUMENT(out_submission);
+  *out_ready = false;
+  memset(out_submission, 0, sizeof(*out_submission));
+
+  if (IREE_UNLIKELY(payload_packet_count == 0)) {
+    return iree_make_status(IREE_STATUS_INTERNAL,
+                            "kernel submission requires at least one payload "
+                            "packet");
+  }
+
+  const uint64_t packet_count =
+      (uint64_t)resolution->barrier_count + payload_packet_count;
+  const uint64_t aql_queue_capacity = (uint64_t)queue->aql_ring.mask + 1;
+  if (IREE_UNLIKELY(packet_count > aql_queue_capacity ||
+                    packet_count > UINT32_MAX)) {
+    return iree_make_status(
+        IREE_STATUS_OUT_OF_RANGE,
+        "kernel submission requires %" PRIu64
+        " AQL packets (%u barriers + %u payload packets) but queue capacity is "
+        "%" PRIu64,
+        packet_count, resolution->barrier_count, payload_packet_count,
+        aql_queue_capacity);
+  }
+  if (IREE_UNLIKELY(kernarg_block_count > queue->kernarg_ring.capacity)) {
+    return iree_make_status(
+        IREE_STATUS_OUT_OF_RANGE,
+        "kernel submission requires %u kernarg blocks but ring capacity is %u",
+        kernarg_block_count, queue->kernarg_ring.capacity);
+  }
+  if (IREE_UNLIKELY(signal_semaphore_list.count >
+                    queue->notification_ring.capacity)) {
+    return iree_make_status(IREE_STATUS_OUT_OF_RANGE,
+                            "kernel submission requires %" PRIhsz
+                            " notification entries but ring capacity is %u",
+                            signal_semaphore_list.count,
+                            queue->notification_ring.capacity);
+  }
+
+  IREE_RETURN_IF_ERROR(iree_hal_amdgpu_host_queue_count_reclaim_resources(
+      signal_semaphore_list.count, operation_resource_count,
+      &out_submission->reclaim_resource_count));
+
+  const iree_host_size_t frontier_snapshot_count =
+      iree_hal_amdgpu_host_queue_count_frontier_snapshots(
+          queue, signal_semaphore_list);
+  if (!iree_hal_amdgpu_notification_ring_can_reserve(
+          &queue->notification_ring, signal_semaphore_list.count,
+          frontier_snapshot_count)) {
+    return iree_ok_status();
+  }
+  if (kernarg_block_count > 0 &&
+      !iree_hal_amdgpu_kernarg_ring_can_allocate(&queue->kernarg_ring,
+                                                 kernarg_block_count)) {
+    return iree_ok_status();
+  }
+
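+  // Everything above was a non-mutating capacity probe. Reserving AQL packets
+  // below is the first side effect: once reserved, the slots must be
+  // published as real packets or backfilled with no-ops.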
+  uint64_t first_packet_id = 0;
+  if (!iree_hal_amdgpu_aql_ring_try_reserve(
+          &queue->aql_ring, (uint32_t)packet_count, &first_packet_id)) {
+    return iree_ok_status();
+  }
+
+  out_submission->reclaim_entry =
+      iree_hal_amdgpu_notification_ring_reclaim_entry(
+          &queue->notification_ring);
+  iree_status_t status = iree_hal_amdgpu_reclaim_entry_prepare(
+      out_submission->reclaim_entry, queue->block_pool,
+      out_submission->reclaim_resource_count,
+      &out_submission->reclaim_resources);
+  if (iree_status_is_ok(status)) {
+    out_submission->packet_count = (uint32_t)packet_count;
+    out_submission->first_packet_id = first_packet_id;
+    if (kernarg_block_count > 0) {
+      out_submission->kernargs.blocks = iree_hal_amdgpu_kernarg_ring_allocate(
+          &queue->kernarg_ring, kernarg_block_count,
+          &out_submission->kernargs.write_position);
+      if (IREE_UNLIKELY(!out_submission->kernargs.blocks)) {
+        iree_hal_amdgpu_host_queue_emit_noop_packets(
+            queue, out_submission->first_packet_id,
+            out_submission->packet_count);
+        iree_hal_amdgpu_reclaim_entry_release(out_submission->reclaim_entry,
+                                              queue->block_pool);
+        status = iree_make_status(IREE_STATUS_FAILED_PRECONDITION,
+                                  "kernarg ring allocation failed after AQL "
+                                  "reservation; queue sizing invariant was "
+                                  "violated");
+      }
+    } else {
+      out_submission->kernargs.write_position = (uint64_t)iree_atomic_load(
+          &queue->kernarg_ring.write_position, iree_memory_order_relaxed);
+    }
+  } else {
+    iree_hal_amdgpu_host_queue_emit_noop_packets(queue, first_packet_id,
+                                                 (uint32_t)packet_count);
+  }
+  if (iree_status_is_ok(status)) {
+    *out_ready = true;
+  } else {
+    memset(out_submission, 0, sizeof(*out_submission));
+  }
+  return status;
+}
+
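+// Returns the number of AQL packets required by a barrier-shaped submission.
+// With queue-device event profiling the completion packet is a dedicated PM4
+// timestamp packet appended after all wait barriers; otherwise the final wait
+// barrier doubles as the completion packet (or a single no-op packet
+// completes a submission with no waits).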
+static uint32_t iree_hal_amdgpu_host_queue_barrier_packet_count(
+    const iree_hal_amdgpu_wait_resolution_t* resolution,
+    bool profile_queue_device_event) {
+  if (profile_queue_device_event) {
+    return (uint32_t)resolution->barrier_count + 1u;
+  }
+  return resolution->barrier_count > 0 ? resolution->barrier_count : 1u;
+}
+
+iree_status_t iree_hal_amdgpu_host_queue_try_begin_barrier_submission(
+    iree_hal_amdgpu_host_queue_t* queue,
+    const iree_hal_amdgpu_wait_resolution_t* resolution,
+    const iree_hal_semaphore_list_t signal_semaphore_list,
+    iree_host_size_t operation_resource_count,
+    const iree_hal_amdgpu_host_queue_profile_event_info_t* profile_event_info,
+    bool* out_ready,
+    iree_hal_amdgpu_host_queue_barrier_submission_t* out_submission) {
+  IREE_ASSERT_ARGUMENT(queue);
+  IREE_ASSERT_ARGUMENT(resolution);
+  IREE_ASSERT_ARGUMENT(out_ready);
+  IREE_ASSERT_ARGUMENT(out_submission);
+  *out_ready = false;
+  memset(out_submission, 0, sizeof(*out_submission));
+
+  if (IREE_UNLIKELY(queue->is_shutting_down)) {
+    return iree_make_status(IREE_STATUS_CANCELLED, "queue shutting down");
+  }
+
+  const bool profile_queue_device_event =
+      profile_event_info &&
+      iree_hal_amdgpu_host_queue_should_profile_queue_device_events(queue);
+  const uint64_t packet_count = iree_hal_amdgpu_host_queue_barrier_packet_count(
+      resolution, profile_queue_device_event);
+  const uint64_t aql_queue_capacity = (uint64_t)queue->aql_ring.mask + 1;
+  if (IREE_UNLIKELY(packet_count > aql_queue_capacity ||
+                    packet_count > UINT32_MAX)) {
+    return iree_make_status(
+        IREE_STATUS_OUT_OF_RANGE,
+        "barrier submission requires %" PRIu64
+        " AQL packets (%u wait barriers) but queue capacity is %" PRIu64,
+        packet_count, resolution->barrier_count, aql_queue_capacity);
+  }
+  if (IREE_UNLIKELY(signal_semaphore_list.count >
+                    queue->notification_ring.capacity)) {
+    return iree_make_status(IREE_STATUS_OUT_OF_RANGE,
+                            "barrier submission requires %" PRIhsz
+                            " notification entries but ring capacity is %u",
+                            signal_semaphore_list.count,
+                            queue->notification_ring.capacity);
+  }
+
+  IREE_RETURN_IF_ERROR(iree_hal_amdgpu_host_queue_count_reclaim_resources(
+      signal_semaphore_list.count, operation_resource_count,
+      &out_submission->reclaim_resource_count));
+
+  iree_status_t status =
+      iree_hal_amdgpu_host_queue_reserve_profile_queue_device_events(
+          queue, profile_queue_device_event ? 1u : 0u,
+          &out_submission->profile_queue_device_events);
+  if (!iree_status_is_ok(status)) return status;
+
+  const iree_host_size_t frontier_snapshot_count =
+      iree_hal_amdgpu_host_queue_count_frontier_snapshots(
+          queue, signal_semaphore_list);
+  if (!iree_hal_amdgpu_notification_ring_can_reserve(
+          &queue->notification_ring, signal_semaphore_list.count,
+          frontier_snapshot_count)) {
+    iree_hal_amdgpu_host_queue_cancel_profile_queue_device_events(
+        queue, out_submission->profile_queue_device_events);
+    memset(out_submission, 0, sizeof(*out_submission));
+    return iree_ok_status();
+  }
+
+  uint64_t first_packet_id = 0;
+  if (!iree_hal_amdgpu_aql_ring_try_reserve(
+          &queue->aql_ring, (uint32_t)packet_count, &first_packet_id)) {
+    iree_hal_amdgpu_host_queue_cancel_profile_queue_device_events(
+        queue, out_submission->profile_queue_device_events);
+    memset(out_submission, 0, sizeof(*out_submission));
+    return iree_ok_status();
+  }
+
+  out_submission->reclaim_entry =
+      iree_hal_amdgpu_notification_ring_reclaim_entry(
+          &queue->notification_ring);
+  status = iree_hal_amdgpu_reclaim_entry_prepare(
+      out_submission->reclaim_entry, queue->block_pool,
+      out_submission->reclaim_resource_count,
+      &out_submission->reclaim_resources);
+  if (iree_status_is_ok(status)) {
+    out_submission->packet_count = (uint32_t)packet_count;
+    out_submission->first_packet_id = first_packet_id;
+    *out_ready = true;
+  } else {
+    iree_hal_amdgpu_host_queue_emit_noop_packets(queue, first_packet_id,
+                                                 (uint32_t)packet_count);
+    iree_hal_amdgpu_host_queue_cancel_profile_queue_device_events(
+        queue, out_submission->profile_queue_device_events);
+    memset(out_submission, 0, sizeof(*out_submission));
+  }
+  return status;
+}
+
+void iree_hal_amdgpu_host_queue_fail_kernel_submission(
+    iree_hal_amdgpu_host_queue_t* queue,
+    iree_hal_amdgpu_host_queue_kernel_submission_t* submission) {
+  iree_hal_amdgpu_host_queue_emit_reclaim_noop_packets(
+      queue, submission->reclaim_entry, submission->first_packet_id,
+      submission->packet_count, submission->kernargs.write_position,
+      submission->queue_upload.write_position);
+  memset(submission, 0, sizeof(*submission));
+}
+
+void iree_hal_amdgpu_host_queue_emit_kernel_submission_prefix(
+    iree_hal_amdgpu_host_queue_t* queue,
+    const iree_hal_amdgpu_wait_resolution_t* resolution,
+    const iree_hal_amdgpu_host_queue_kernel_submission_t* submission) {
+  iree_hal_amdgpu_host_queue_emit_barriers(queue, resolution,
+                                           submission->first_packet_id);
+}
+
+iree_status_t iree_hal_amdgpu_host_queue_try_begin_dispatch_submission(
+    iree_hal_amdgpu_host_queue_t* queue,
+    const iree_hal_amdgpu_wait_resolution_t* resolution,
+    const iree_hal_semaphore_list_t signal_semaphore_list,
+    iree_host_size_t operation_resource_count, uint32_t kernarg_block_count,
+    iree_hal_amdgpu_profile_dispatch_event_reservation_t profile_events,
+    const iree_hal_amdgpu_host_queue_profile_event_info_t*
+        profile_queue_event_info,
+    bool* out_ready,
+    iree_hal_amdgpu_host_queue_dispatch_submission_t* out_submission) {
+  IREE_ASSERT_ARGUMENT(queue);
+  IREE_ASSERT_ARGUMENT(resolution);
+  IREE_ASSERT_ARGUMENT(out_ready);
+  IREE_ASSERT_ARGUMENT(out_submission);
+  *out_ready = false;
+  memset(out_submission, 0, sizeof(*out_submission));
+
+  const bool use_profiling_completion_signal = profile_events.event_count != 0;
+  const bool profile_queue_device_event =
+      profile_queue_event_info &&
+      iree_hal_amdgpu_host_queue_should_profile_queue_device_events(queue);
+  iree_hal_amdgpu_profile_queue_device_event_reservation_t
+      profile_queue_device_events = {0};
+  if (profile_queue_device_event) {
+    IREE_RETURN_IF_ERROR(
+        iree_hal_amdgpu_host_queue_reserve_profile_queue_device_events(
+            queue, /*event_count=*/1, &profile_queue_device_events));
+  }
+  uint32_t profile_counter_set_count = 0;
+  uint32_t profile_counter_packet_count = 0;
+  uint32_t profile_trace_packet_count = 0;
+  uint32_t profile_trace_start_packet_count = 0;
+  if (use_profiling_completion_signal) {
+    profile_counter_set_count =
+        iree_hal_amdgpu_host_queue_profile_counter_set_count(queue,
+                                                             profile_events);
+    profile_counter_packet_count =
+        iree_hal_amdgpu_host_queue_profile_counter_packet_count(queue,
+                                                                profile_events);
+    profile_trace_packet_count =
+        iree_hal_amdgpu_host_queue_profile_trace_packet_count(queue,
+                                                              profile_events);
+    profile_trace_start_packet_count =
+        iree_hal_amdgpu_host_queue_profile_trace_start_packet_count(
+            queue, profile_events);
+  }
+  const uint32_t profile_queue_device_packet_count =
+      profile_queue_device_events.event_count != 0 ? 2u : 0u;
+  const uint32_t payload_packet_count =
+      1u + profile_counter_packet_count + profile_trace_packet_count +
+      (use_profiling_completion_signal ? 1u : 0u) +
+      profile_queue_device_packet_count;
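+  // Ring-order payload layout following the wait barriers; packets for
+  // disabled profiling features are simply absent:
+  //   [queue-device start][counter starts][trace start][trace code object]
+  //   [dispatch][trace stop][counter reads][harvest][queue-device end]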
+  const uint32_t profile_harvest_kernarg_block_count =
+      use_profiling_completion_signal
+          ? (uint32_t)iree_host_size_ceil_div(
+                iree_hal_amdgpu_device_timestamp_dispatch_harvest_kernarg_length(
+                    profile_events.event_count),
+                sizeof(iree_hal_amdgpu_kernarg_block_t))
+          : 0u;
+  iree_status_t status = iree_hal_amdgpu_host_queue_try_begin_kernel_submission(
+      queue, resolution, signal_semaphore_list, operation_resource_count,
+      payload_packet_count,
+      kernarg_block_count + profile_harvest_kernarg_block_count, out_ready,
+      &out_submission->kernel);
+  if (!iree_status_is_ok(status) || !*out_ready) {
+    iree_hal_amdgpu_host_queue_cancel_profile_queue_device_events(
+        queue, profile_queue_device_events);
+  }
+  if (iree_status_is_ok(status) && *out_ready) {
+    out_submission->profile_queue_device_events = profile_queue_device_events;
+    const uint32_t profile_queue_device_prefix_packet_count =
+        profile_queue_device_events.event_count != 0 ? 1u : 0u;
+    const uint32_t profile_queue_device_suffix_packet_count =
+        profile_queue_device_events.event_count != 0 ? 1u : 0u;
+    const uint64_t dispatch_packet_id =
+        out_submission->kernel.first_packet_id + resolution->barrier_count +
+        profile_queue_device_prefix_packet_count + profile_counter_set_count +
+        profile_trace_start_packet_count;
+    out_submission->dispatch_packet_id = dispatch_packet_id;
+    out_submission->dispatch_slot =
+        iree_hal_amdgpu_aql_ring_packet(&queue->aql_ring, dispatch_packet_id);
+    out_submission->dispatch_completion_signal =
+        profile_queue_device_events.event_count != 0
+            ? iree_hsa_signal_null()
+            : iree_hal_amdgpu_notification_ring_epoch_signal(
+                  &queue->notification_ring);
+    if (use_profiling_completion_signal) {
+      out_submission->profile_events = profile_events;
+      out_submission->profile_counter_set_count = profile_counter_set_count;
+      out_submission->profile_trace_start_packet_count =
+          profile_trace_start_packet_count;
+      out_submission->dispatch_completion_signal =
+          iree_hal_amdgpu_host_queue_profiling_completion_signal(
+              queue, profile_events.first_event_position);
+      out_submission->profile_harvest_slot = iree_hal_amdgpu_aql_ring_packet(
+          &queue->aql_ring, out_submission->kernel.first_packet_id +
+                                out_submission->kernel.packet_count - 1 -
+                                profile_queue_device_suffix_packet_count);
+      out_submission->profile_harvest_kernarg_blocks =
+          &out_submission->kernel.kernargs.blocks[kernarg_block_count];
+      out_submission->minimum_release_scope =
+          iree_hal_amdgpu_host_queue_max_fence_scope(
+              out_submission->minimum_release_scope,
+              IREE_HSA_FENCE_SCOPE_AGENT);
+    }
+  }
+  return status;
+}
+
+iree_status_t iree_hal_amdgpu_host_queue_try_begin_pm4_ib_submission(
+    iree_hal_amdgpu_host_queue_t* queue,
+    const iree_hal_amdgpu_wait_resolution_t* resolution,
+    const iree_hal_semaphore_list_t signal_semaphore_list,
+    iree_host_size_t operation_resource_count,
+    const iree_hal_amdgpu_host_queue_profile_event_info_t*
+        profile_queue_event_info,
+    bool* out_ready,
+    iree_hal_amdgpu_host_queue_pm4_ib_submission_t* out_submission) {
+  IREE_ASSERT_ARGUMENT(queue);
+  IREE_ASSERT_ARGUMENT(resolution);
+  IREE_ASSERT_ARGUMENT(out_ready);
+  IREE_ASSERT_ARGUMENT(out_submission);
+  *out_ready = false;
+  memset(out_submission, 0, sizeof(*out_submission));
+
+  if (IREE_UNLIKELY(!queue->pm4_ib_slots)) {
+    return iree_make_status(IREE_STATUS_FAILED_PRECONDITION,
+                            "PM4 IB slots are not available");
+  }
+
+  const bool profile_queue_device_event =
+      profile_queue_event_info &&
+      iree_hal_amdgpu_host_queue_should_profile_queue_device_events(queue);
+  iree_hal_amdgpu_profile_queue_device_event_reservation_t
+      profile_queue_device_events = {0};
+  IREE_RETURN_IF_ERROR(
+      iree_hal_amdgpu_host_queue_reserve_profile_queue_device_events(
+          queue, profile_queue_device_event ? 1u : 0u,
+          &profile_queue_device_events));
+  const uint32_t profile_queue_device_packet_count =
+      profile_queue_device_events.event_count != 0 ? 2u : 0u;
+  iree_status_t status = iree_hal_amdgpu_host_queue_try_begin_kernel_submission(
+      queue, resolution, signal_semaphore_list, operation_resource_count,
+      /*payload_packet_count=*/1 + profile_queue_device_packet_count,
+      /*kernarg_block_count=*/0, out_ready, &out_submission->kernel);
+  if (!iree_status_is_ok(status) || !*out_ready) {
+    iree_hal_amdgpu_host_queue_cancel_profile_queue_device_events(
+        queue, profile_queue_device_events);
+  }
+  if (iree_status_is_ok(status) && *out_ready) {
+    out_submission->profile_queue_device_events = profile_queue_device_events;
+    const uint32_t profile_queue_device_prefix_packet_count =
+        profile_queue_device_events.event_count != 0 ? 1u : 0u;
+    const uint64_t packet_id = out_submission->kernel.first_packet_id +
+                               resolution->barrier_count +
+                               profile_queue_device_prefix_packet_count;
+    out_submission->pm4_ib_packet_slot =
+        iree_hal_amdgpu_aql_ring_packet(&queue->aql_ring, packet_id);
+    out_submission->pm4_ib_slot =
+        &queue->pm4_ib_slots[packet_id & queue->aql_ring.mask];
+  }
+  return status;
+}
+
+// Commits the signal/frontier side of an AQL submission. Called after all
+// dispatch packet fields, kernargs, and any prefix barrier headers are written,
+// but before the final dispatch header is committed and the doorbell is rung.
+// The caller must reserve notification-ring space and prepare the next reclaim
+// entry before this call. Caller must hold submission_mutex.
+static uint64_t iree_hal_amdgpu_host_queue_commit_signals(
+    iree_hal_amdgpu_host_queue_t* queue,
+    const iree_hal_semaphore_list_t signal_semaphore_list) {
+  // Advance epoch and merge this queue's axis into the accumulated frontier.
+  uint64_t epoch = iree_hal_amdgpu_notification_ring_advance_epoch(
+      &queue->notification_ring);
+  iree_async_single_frontier_t self_frontier;
+  iree_async_single_frontier_initialize(&self_frontier, queue->axis, epoch);
+  if (IREE_UNLIKELY(!iree_async_frontier_merge(
+          iree_hal_amdgpu_host_queue_frontier(queue),
+          IREE_HAL_AMDGPU_QUEUE_FRONTIER_CAPACITY,
+          iree_async_single_frontier_as_const_frontier(&self_frontier)))) {
+    // The queue frontier was full of foreign axes and did not contain this
+    // queue's own axis. Collapse to the current self axis as a safe lower bound
+    // and permanently disable frontier publication so later waits defer instead
+    // of observing under-attributed dependencies.
+    iree_async_frontier_initialize(iree_hal_amdgpu_host_queue_frontier(queue),
+                                   /*entry_count=*/1);
+    queue->frontier.entries[0] = self_frontier.entries[0];
+    queue->can_publish_frontier = false;
+  }
+
+  const iree_async_frontier_t* queue_frontier =
+      iree_hal_amdgpu_host_queue_const_frontier(queue);
+
+  // A submission with no user-visible signal semaphores still consumes one
+  // queue-private epoch and reclaim entry. Leave last_signal unchanged so a
+  // later signaled submission can still flush the previous same-semaphore span;
+  // any intervening zero-signal epochs are conservatively included in that
+  // frontier snapshot.
+  if (signal_semaphore_list.count == 0) {
+    return epoch;
+  }
+
+  for (iree_host_size_t i = 0; i < signal_semaphore_list.count; ++i) {
+    iree_hal_semaphore_t* hal_semaphore = signal_semaphore_list.semaphores[i];
+    uint64_t value = signal_semaphore_list.payload_values[i];
+    iree_async_semaphore_t* async_semaphore =
+        (iree_async_semaphore_t*)hal_semaphore;
+    bool is_amdgpu_semaphore = iree_hal_amdgpu_semaphore_isa(hal_semaphore);
+    const bool is_private_stream_signal =
+        iree_hal_amdgpu_host_queue_is_private_stream_signal(queue,
+                                                            hal_semaphore);
+
+    // Detect semaphore transition for frontier snapshot recording.
+    if (async_semaphore != queue->last_signal.semaphore) {
+      iree_hal_amdgpu_host_queue_push_frontier_snapshot_if_pending(
+          queue, queue_frontier);
+      queue->last_signal.semaphore = async_semaphore;
+      queue->last_signal.needs_frontier_snapshot =
+          queue->can_publish_frontier && !is_private_stream_signal;
+    }
+
+    // Push notification entry for drain -> signal_untainted on completion.
+    const iree_hal_amdgpu_notification_entry_flags_t notification_flags =
+        (is_private_stream_signal || !queue->can_publish_frontier)
+            ? IREE_HAL_AMDGPU_NOTIFICATION_ENTRY_FLAG_OMIT_FRONTIER_SNAPSHOT
+            : IREE_HAL_AMDGPU_NOTIFICATION_ENTRY_FLAG_NONE;
+    iree_hal_amdgpu_notification_ring_push(&queue->notification_ring, epoch,
+                                           async_semaphore, value,
+                                           notification_flags);
+    queue->last_signal.epoch = epoch;
+
+    if (is_private_stream_signal) {
+      iree_hal_amdgpu_semaphore_publish_private_stream_signal(
+          hal_semaphore, queue->axis, epoch, value);
+      continue;
+    }
+
+    // Submission-time causal marker: merge queue's frontier into the
+    // semaphore's frontier so same-queue and already-dominated cross-queue
+    // waits can resolve before GPU completion under the current submission
+    // boundary barrier policy.
+    bool did_publish_frontier = queue->can_publish_frontier;
+    if (did_publish_frontier) {
+      if (is_amdgpu_semaphore) {
+        did_publish_frontier = iree_hal_amdgpu_semaphore_publish_signal(
+            hal_semaphore, queue->axis, queue_frontier, epoch, value);
+      } else {
+        did_publish_frontier = iree_async_semaphore_merge_frontier(
+            async_semaphore, queue_frontier);
+      }
+    }
+    if (!did_publish_frontier) {
+      // The semaphore's frontier storage overflowed, so its frontier is no
+      // longer a conservative summary of this signal's causal dependencies.
+      // Clear the last-signal cache to force future waits down the software
+      // deferral path instead of unsafely eliding or under-barriering them.
+      if (is_amdgpu_semaphore) {
+        iree_hal_amdgpu_semaphore_clear_last_signal(hal_semaphore);
+      }
+    }
+  }
+  return epoch;
+}
+
+uint64_t iree_hal_amdgpu_host_queue_finish_kernel_submission(
+    iree_hal_amdgpu_host_queue_t* queue,
+    const iree_hal_amdgpu_wait_resolution_t* resolution,
+    const iree_hal_semaphore_list_t signal_semaphore_list,
+    iree_hal_resource_t* const* operation_resources,
+    iree_host_size_t operation_resource_count,
+    iree_hal_resource_set_t** inout_resource_set,
+    iree_hal_amdgpu_host_queue_submission_flags_t submission_flags,
+    iree_hal_amdgpu_host_queue_kernel_submission_t* submission) {
+  const bool retain_submission_resources = iree_any_bit_set(
+      submission_flags,
+      IREE_HAL_AMDGPU_HOST_QUEUE_SUBMISSION_FLAG_RETAIN_RESOURCES);
+
+  for (iree_host_size_t i = 0; i < signal_semaphore_list.count; ++i) {
+    submission->reclaim_resources[i] =
+        (iree_hal_resource_t*)signal_semaphore_list.semaphores[i];
+    if (retain_submission_resources) {
+      iree_hal_resource_retain(submission->reclaim_resources[i]);
+    }
+  }
+  for (iree_host_size_t i = 0; i < operation_resource_count; ++i) {
+    iree_hal_resource_t* resource = operation_resources[i];
+    submission->reclaim_resources[signal_semaphore_list.count + i] = resource;
+    if (retain_submission_resources) {
+      iree_hal_resource_retain(resource);
+    }
+  }
+  submission->reclaim_entry->resource_set =
+      inout_resource_set ? *inout_resource_set : NULL;
+  if (inout_resource_set) {
+    *inout_resource_set = NULL;
+  }
+  submission->reclaim_entry->kernarg_write_position =
+      submission->kernargs.write_position;
+  submission->reclaim_entry->queue_upload_write_position =
+      submission->queue_upload.write_position;
+  submission->reclaim_entry->count = submission->reclaim_resource_count;
+  submission->reclaim_entry->pre_signal_action = submission->pre_signal_action;
+  iree_hal_amdgpu_host_queue_merge_barrier_axes(queue, resolution);
+  return iree_hal_amdgpu_host_queue_commit_signals(queue,
+                                                   signal_semaphore_list);
+}
+
+uint64_t iree_hal_amdgpu_host_queue_finish_dispatch_submission(
+    iree_hal_amdgpu_host_queue_t* queue,
+    const iree_hal_amdgpu_wait_resolution_t* resolution,
+    const iree_hal_semaphore_list_t signal_semaphore_list,
+    iree_hal_resource_t* const* operation_resources,
+    iree_host_size_t operation_resource_count,
+    const iree_hal_amdgpu_host_queue_profile_event_info_t*
+        profile_queue_event_info,
+    iree_hal_amdgpu_host_queue_submission_flags_t submission_flags,
+    iree_hal_amdgpu_host_queue_dispatch_submission_t* submission) {
+  iree_hal_amdgpu_host_queue_emit_kernel_submission_prefix(queue, resolution,
+                                                           &submission->kernel);
+  const uint64_t submission_epoch =
+      iree_hal_amdgpu_host_queue_finish_kernel_submission(
+          queue, resolution, signal_semaphore_list, operation_resources,
+          operation_resource_count, /*inout_resource_set=*/NULL,
+          submission_flags, &submission->kernel);
+
+  iree_hal_amdgpu_profile_queue_device_event_t* queue_device_event =
+      iree_hal_amdgpu_host_queue_initialize_profile_queue_device_event(
+          queue, submission->profile_queue_device_events,
+          profile_queue_event_info);
+  if (queue_device_event) {
+    submission->kernel.reclaim_entry->queue_device_event_first_position =
+        submission->profile_queue_device_events.first_event_position;
+    submission->kernel.reclaim_entry->queue_device_event_count =
+        submission->profile_queue_device_events.event_count;
+    queue_device_event->submission_id = submission_epoch;
+  }
+
+  uint16_t profile_harvest_header = 0;
+  const iree_hsa_fence_scope_t dispatch_minimum_acquire_scope =
+      submission->kernel.kernargs.blocks
+          ? iree_hal_amdgpu_host_queue_kernarg_acquire_scope(
+                queue, submission->minimum_acquire_scope)
+          : submission->minimum_acquire_scope;
+  iree_hal_amdgpu_aql_packet_control_t dispatch_packet_control =
+      iree_hal_amdgpu_host_queue_final_dispatch_packet_control(
+          queue, resolution, signal_semaphore_list,
+          dispatch_minimum_acquire_scope, submission->minimum_release_scope);
+  if (queue_device_event || submission->profile_harvest_slot) {
+    dispatch_packet_control =
+        iree_hal_amdgpu_host_queue_payload_dispatch_packet_control(
+            resolution, dispatch_minimum_acquire_scope,
+            submission->minimum_release_scope);
+  }
+  if (submission->profile_harvest_slot) {
+    submission->kernel.reclaim_entry->profile_event_first_position =
+        submission->profile_events.first_event_position;
+    submission->kernel.reclaim_entry->profile_event_count =
+        submission->profile_events.event_count;
+    for (uint32_t i = 0; i < submission->profile_events.event_count; ++i) {
+      iree_hal_amdgpu_profile_dispatch_event_t* event =
+          iree_hal_amdgpu_host_queue_profile_dispatch_event_at(
+              queue, submission->profile_events.first_event_position + i);
+      event->submission_id = submission_epoch;
+    }
+    submission->profile_harvest_slot->dispatch.completion_signal =
+        queue_device_event ? iree_hsa_signal_null()
+                           : iree_hal_amdgpu_notification_ring_epoch_signal(
+                                 &queue->notification_ring);
+    const iree_hsa_fence_scope_t profile_harvest_acquire_scope =
+        iree_hal_amdgpu_host_queue_kernarg_acquire_scope(
+            queue, IREE_HSA_FENCE_SCOPE_AGENT);
+    profile_harvest_header = iree_hal_amdgpu_aql_make_header(
+        IREE_HSA_PACKET_TYPE_KERNEL_DISPATCH,
+        queue_device_event
+            ? iree_hal_amdgpu_aql_packet_control_barrier(
+                  iree_hal_amdgpu_host_queue_max_fence_scope(
+                      profile_harvest_acquire_scope,
+                      resolution->inline_acquire_scope),
+                  IREE_HSA_FENCE_SCOPE_SYSTEM)
+            : iree_hal_amdgpu_host_queue_final_dispatch_packet_control(
+                  queue, resolution, signal_semaphore_list,
+                  profile_harvest_acquire_scope, IREE_HSA_FENCE_SCOPE_SYSTEM));
+  }
+  const uint16_t dispatch_header = iree_hal_amdgpu_aql_make_header(
+      IREE_HSA_PACKET_TYPE_KERNEL_DISPATCH, dispatch_packet_control);
+  const uint32_t profile_queue_device_prefix_packet_count =
+      queue_device_event ? 1u : 0u;
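+  // Publication order matters: kernargs are made visible first, packet bodies
+  // are committed in ring order around the dispatch, and the doorbell is rung
+  // only after the last header so the packet processor never runs ahead of
+  // valid packet state.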
+  iree_hal_amdgpu_host_queue_publish_submission_kernargs(queue,
+                                                         &submission->kernel);
+  if (queue_device_event) {
+    iree_hal_amdgpu_host_queue_commit_queue_device_start_packet(
+        queue, resolution,
+        submission->kernel.first_packet_id + resolution->barrier_count,
+        queue_device_event);
+  }
+  if (submission->profile_counter_set_count != 0) {
+    iree_hal_amdgpu_host_queue_commit_profile_counter_start_packets(
+        queue, submission->profile_events.first_event_position,
+        submission->profile_counter_set_count,
+        submission->kernel.first_packet_id + resolution->barrier_count +
+            profile_queue_device_prefix_packet_count,
+        iree_hal_amdgpu_aql_packet_control_barrier(
+            iree_hal_amdgpu_host_queue_max_fence_scope(
+                IREE_HSA_FENCE_SCOPE_AGENT, resolution->inline_acquire_scope),
+            IREE_HSA_FENCE_SCOPE_AGENT));
+  }
+  if (submission->profile_trace_start_packet_count != 0) {
+    iree_hal_amdgpu_host_queue_commit_profile_trace_start_packet(
+        queue, submission->profile_events.first_event_position,
+        submission->kernel.first_packet_id + resolution->barrier_count +
+            profile_queue_device_prefix_packet_count +
+            submission->profile_counter_set_count,
+        iree_hal_amdgpu_aql_packet_control_barrier(
+            iree_hal_amdgpu_host_queue_max_fence_scope(
+                IREE_HSA_FENCE_SCOPE_AGENT, resolution->inline_acquire_scope),
+            IREE_HSA_FENCE_SCOPE_AGENT));
+    iree_hal_amdgpu_host_queue_commit_profile_trace_code_object_packet(
+        queue, submission->profile_events.first_event_position,
+        submission->kernel.first_packet_id + resolution->barrier_count +
+            profile_queue_device_prefix_packet_count +
+            submission->profile_counter_set_count + 1u,
+        iree_hal_amdgpu_aql_packet_control_barrier(IREE_HSA_FENCE_SCOPE_AGENT,
+                                                   IREE_HSA_FENCE_SCOPE_AGENT));
+  }
+  iree_hal_amdgpu_aql_ring_commit(submission->dispatch_slot, dispatch_header,
+                                  submission->dispatch_setup);
+  if (submission->profile_trace_start_packet_count != 0) {
+    iree_hal_amdgpu_host_queue_commit_profile_trace_stop_packet(
+        queue, submission->profile_events.first_event_position,
+        submission->dispatch_packet_id + 1,
+        iree_hal_amdgpu_aql_packet_control_barrier(IREE_HSA_FENCE_SCOPE_AGENT,
+                                                   IREE_HSA_FENCE_SCOPE_AGENT));
+  }
+  if (submission->profile_counter_set_count != 0) {
+    const uint32_t profile_trace_stop_packet_count =
+        submission->profile_trace_start_packet_count != 0
+            ? submission->profile_events.event_count
+            : 0u;
+    iree_hal_amdgpu_host_queue_commit_profile_counter_read_stop_packets(
+        queue, submission->profile_events.first_event_position,
+        submission->profile_counter_set_count,
+        submission->dispatch_packet_id + 1 + profile_trace_stop_packet_count,
+        iree_hal_amdgpu_aql_packet_control_barrier(IREE_HSA_FENCE_SCOPE_AGENT,
+                                                   IREE_HSA_FENCE_SCOPE_AGENT));
+  }
+  if (submission->profile_harvest_slot) {
+    iree_hal_amdgpu_aql_ring_commit(submission->profile_harvest_slot,
+                                    profile_harvest_header,
+                                    submission->profile_harvest_setup);
+  }
+  if (queue_device_event) {
+    iree_hal_amdgpu_host_queue_commit_queue_device_end_packet(
+        queue, resolution, signal_semaphore_list,
+        submission->kernel.first_packet_id + submission->kernel.packet_count -
+            1,
+        queue_device_event);
+  }
+  iree_hal_amdgpu_aql_ring_doorbell(
+      &queue->aql_ring,
+      submission->kernel.first_packet_id + submission->kernel.packet_count - 1);
+  memset(submission, 0, sizeof(*submission));
+  return submission_epoch;
+}
+
+uint64_t iree_hal_amdgpu_host_queue_finish_pm4_ib_submission(
+    iree_hal_amdgpu_host_queue_t* queue,
+    const iree_hal_amdgpu_wait_resolution_t* resolution,
+    const iree_hal_semaphore_list_t signal_semaphore_list,
+    iree_hal_resource_t* const* operation_resources,
+    iree_host_size_t operation_resource_count,
+    const iree_hal_amdgpu_host_queue_profile_event_info_t*
+        profile_queue_event_info,
+    iree_hal_amdgpu_host_queue_submission_flags_t submission_flags,
+    iree_hal_amdgpu_host_queue_pm4_ib_submission_t* submission) {
+  iree_hal_amdgpu_host_queue_emit_kernel_submission_prefix(queue, resolution,
+                                                           &submission->kernel);
+  const uint64_t submission_epoch =
+      iree_hal_amdgpu_host_queue_finish_kernel_submission(
+          queue, resolution, signal_semaphore_list, operation_resources,
+          operation_resource_count, /*inout_resource_set=*/NULL,
+          submission_flags, &submission->kernel);
+
+  iree_hal_amdgpu_profile_queue_device_event_t* queue_device_event =
+      iree_hal_amdgpu_host_queue_initialize_profile_queue_device_event(
+          queue, submission->profile_queue_device_events,
+          profile_queue_event_info);
+  if (queue_device_event) {
+    submission->kernel.reclaim_entry->queue_device_event_first_position =
+        submission->profile_queue_device_events.first_event_position;
+    submission->kernel.reclaim_entry->queue_device_event_count =
+        submission->profile_queue_device_events.event_count;
+    queue_device_event->submission_id = submission_epoch;
+    iree_hal_amdgpu_host_queue_commit_queue_device_start_packet(
+        queue, resolution,
+        submission->kernel.first_packet_id + resolution->barrier_count,
+        queue_device_event);
+  }
+
+  uint16_t pm4_ib_setup = 0;
+  uint16_t pm4_ib_header = iree_hal_amdgpu_aql_emit_pm4_ib(
+      &submission->pm4_ib_packet_slot->pm4_ib, submission->pm4_ib_slot,
+      submission->ib_dword_count,
+      queue_device_event
+          ? iree_hal_amdgpu_host_queue_payload_pm4_ib_packet_control(resolution)
+          : iree_hal_amdgpu_host_queue_final_pm4_ib_packet_control(
+                queue, resolution, signal_semaphore_list),
+      queue_device_event ? iree_hsa_signal_null()
+                         : iree_hal_amdgpu_notification_ring_epoch_signal(
+                               &queue->notification_ring),
+      &pm4_ib_setup);
+  iree_hal_amdgpu_aql_ring_commit(submission->pm4_ib_packet_slot, pm4_ib_header,
+                                  pm4_ib_setup);
+  if (queue_device_event) {
+    iree_hal_amdgpu_host_queue_commit_queue_device_end_packet(
+        queue, resolution, signal_semaphore_list,
+        submission->kernel.first_packet_id + submission->kernel.packet_count -
+            1,
+        queue_device_event);
+  }
+  iree_hal_amdgpu_aql_ring_doorbell(
+      &queue->aql_ring,
+      submission->kernel.first_packet_id + submission->kernel.packet_count - 1);
+  memset(submission, 0, sizeof(*submission));
+  return submission_epoch;
+}
+
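+// Convenience path composing try_begin/finish for a single dispatch packet
+// sourcing up to one kernarg block from |kernargs| and requesting no
+// per-dispatch profiling events.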
+iree_status_t iree_hal_amdgpu_host_queue_submit_dispatch_packet(
+    iree_hal_amdgpu_host_queue_t* queue,
+    const iree_hal_amdgpu_wait_resolution_t* resolution,
+    const iree_hal_semaphore_list_t signal_semaphore_list,
+    const iree_hsa_kernel_dispatch_packet_t* dispatch_packet_template,
+    const void* kernargs, iree_host_size_t kernarg_length,
+    iree_hal_resource_t* const* operation_resources,
+    iree_host_size_t operation_resource_count,
+    const iree_hal_amdgpu_host_queue_profile_event_info_t*
+        profile_queue_event_info,
+    iree_hal_amdgpu_host_queue_submission_flags_t submission_flags,
+    bool* out_ready, uint64_t* out_submission_id) {
+  IREE_ASSERT_ARGUMENT(queue);
+  IREE_ASSERT_ARGUMENT(resolution);
+  IREE_ASSERT_ARGUMENT(dispatch_packet_template);
+  IREE_ASSERT_ARGUMENT(kernargs);
+  IREE_ASSERT_ARGUMENT(out_ready);
+  IREE_ASSERT_LE(kernarg_length, sizeof(iree_hal_amdgpu_kernarg_block_t));
+  *out_ready = false;
+  if (out_submission_id) *out_submission_id = 0;
+
+  iree_hal_amdgpu_host_queue_dispatch_submission_t submission;
+  IREE_RETURN_IF_ERROR(iree_hal_amdgpu_host_queue_try_begin_dispatch_submission(
+      queue, resolution, signal_semaphore_list, operation_resource_count,
+      /*kernarg_block_count=*/1,
+      (iree_hal_amdgpu_profile_dispatch_event_reservation_t){0},
+      profile_queue_event_info, out_ready, &submission));
+  if (!*out_ready) return iree_ok_status();
+
+  memcpy(submission.kernel.kernargs.blocks->data, kernargs, kernarg_length);
+  submission.dispatch_setup =
+      iree_hal_amdgpu_host_queue_write_dispatch_packet_body(
+          &submission.dispatch_slot->dispatch, dispatch_packet_template,
+          submission.kernel.kernargs.blocks->data,
+          submission.dispatch_completion_signal);
+  const uint64_t submission_epoch =
+      iree_hal_amdgpu_host_queue_finish_dispatch_submission(
+          queue, resolution, signal_semaphore_list, operation_resources,
+          operation_resource_count, profile_queue_event_info, submission_flags,
+          &submission);
+  if (out_submission_id) *out_submission_id = submission_epoch;
+  return iree_ok_status();
+}
+
+uint64_t iree_hal_amdgpu_host_queue_finish_barrier_submission(
+    iree_hal_amdgpu_host_queue_t* queue,
+    const iree_hal_amdgpu_wait_resolution_t* resolution,
+    const iree_hal_semaphore_list_t signal_semaphore_list,
+    iree_hal_amdgpu_reclaim_action_t pre_signal_action,
+    iree_hal_resource_t* const* operation_resources,
+    iree_host_size_t operation_resource_count,
+    const iree_hal_amdgpu_host_queue_profile_event_info_t* profile_event_info,
+    iree_hal_amdgpu_host_queue_post_commit_callback_t post_commit_callback,
+    iree_hal_resource_set_t* resource_set,
+    iree_hal_amdgpu_host_queue_submission_flags_t submission_flags,
+    iree_hal_amdgpu_host_queue_barrier_submission_t* submission) {
+  IREE_ASSERT_ARGUMENT(queue);
+  IREE_ASSERT_ARGUMENT(resolution);
+  IREE_ASSERT_ARGUMENT(submission);
+  const bool retain_submission_resources = iree_any_bit_set(
+      submission_flags,
+      IREE_HAL_AMDGPU_HOST_QUEUE_SUBMISSION_FLAG_RETAIN_RESOURCES);
+
+  const bool complete_with_queue_device_event =
+      submission->profile_queue_device_events.event_count != 0;
+  const bool complete_with_wait_barrier =
+      resolution->barrier_count > 0 && !complete_with_queue_device_event;
+  iree_hal_amdgpu_reclaim_entry_t* reclaim_entry = submission->reclaim_entry;
+  iree_hal_resource_t** reclaim_resources = submission->reclaim_resources;
+  reclaim_entry->pre_signal_action = pre_signal_action;
+  reclaim_entry->resource_set = resource_set;
+
+  const uint32_t aql_packet_count = submission->packet_count;
+  const uint64_t first_packet_id = submission->first_packet_id;
+
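+  // The last reserved packet doubles as the completion packet and takes one
+  // of three shapes: a PM4 timestamp IB when profiling queue-device events,
+  // the final wait barrier when any waits remain, or a plain no-op barrier.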
+  uint16_t completion_header = 0;
+  uint16_t completion_setup = 0;
+  iree_hal_amdgpu_aql_packet_t* completion_slot =
+      iree_hal_amdgpu_aql_ring_packet(&queue->aql_ring,
+                                      first_packet_id + aql_packet_count - 1);
+  if (complete_with_queue_device_event) {
+    iree_hal_amdgpu_host_queue_emit_barriers(queue, resolution,
+                                             first_packet_id);
+  } else if (complete_with_wait_barrier) {
+    for (uint8_t i = 0; i + 1 < resolution->barrier_count; ++i) {
+      iree_hal_amdgpu_aql_packet_t* barrier_packet =
+          iree_hal_amdgpu_aql_ring_packet(&queue->aql_ring,
+                                          first_packet_id + i);
+      uint16_t barrier_setup = 0;
+      uint16_t barrier_header =
+          iree_hal_amdgpu_host_queue_write_wait_barrier_packet_body(
+              queue, &resolution->barriers[i], first_packet_id + i,
+              iree_hsa_signal_null(), resolution->barrier_acquire_scope,
+              IREE_HSA_FENCE_SCOPE_NONE, barrier_packet, &barrier_setup);
+      iree_hal_amdgpu_aql_ring_commit(barrier_packet, barrier_header,
+                                      barrier_setup);
+    }
+  }
+  iree_hal_amdgpu_profile_queue_device_event_t* queue_device_event =
+      iree_hal_amdgpu_host_queue_initialize_profile_queue_device_event(
+          queue, submission->profile_queue_device_events, profile_event_info);
+
+  for (iree_host_size_t i = 0; i < signal_semaphore_list.count; ++i) {
+    reclaim_resources[i] =
+        (iree_hal_resource_t*)signal_semaphore_list.semaphores[i];
+    if (retain_submission_resources) {
+      iree_hal_resource_retain(reclaim_resources[i]);
+    }
+  }
+  for (iree_host_size_t i = 0; i < operation_resource_count; ++i) {
+    iree_hal_resource_t* resource = operation_resources[i];
+    reclaim_resources[signal_semaphore_list.count + i] = resource;
+    if (retain_submission_resources) {
+      iree_hal_resource_retain(resource);
+    }
+  }
+  reclaim_entry->kernarg_write_position = (uint64_t)iree_atomic_load(
+      &queue->kernarg_ring.write_position, iree_memory_order_relaxed);
+  reclaim_entry->count = submission->reclaim_resource_count;
+
+  iree_hal_amdgpu_host_queue_merge_barrier_axes(queue, resolution);
+  const uint64_t submission_epoch =
+      iree_hal_amdgpu_host_queue_commit_signals(queue, signal_semaphore_list);
+  if (queue_device_event) {
+    reclaim_entry->queue_device_event_first_position =
+        submission->profile_queue_device_events.first_event_position;
+    reclaim_entry->queue_device_event_count =
+        submission->profile_queue_device_events.event_count;
+    queue_device_event->submission_id = submission_epoch;
+  }
+  if (post_commit_callback.fn) {
+    post_commit_callback.fn(post_commit_callback.user_data,
+                            iree_hal_amdgpu_host_queue_const_frontier(queue),
+                            submission_epoch);
+  }
+  if (complete_with_queue_device_event) {
+    iree_hal_amdgpu_pm4_ib_slot_t* pm4_ib_slot =
+        &queue->pm4_ib_slots[(first_packet_id + aql_packet_count - 1) &
+                             queue->aql_ring.mask];
+    iree_hal_amdgpu_pm4_ib_builder_t builder;
+    iree_hal_amdgpu_pm4_ib_builder_initialize(pm4_ib_slot, &builder);
+    const bool did_emit =
+        iree_hal_amdgpu_pm4_ib_builder_emit_timestamp_range_to_memory(
+            &builder, &queue_device_event->start_tick,
+            &queue_device_event->end_tick);
+    IREE_ASSERT(did_emit, "PM4 timestamp range must fit profiling IB slot");
+    (void)did_emit;
+    completion_header = iree_hal_amdgpu_aql_emit_pm4_ib(
+        &completion_slot->pm4_ib, pm4_ib_slot,
+        iree_hal_amdgpu_pm4_ib_builder_dword_count(&builder),
+        iree_hal_amdgpu_host_queue_final_pm4_ib_packet_control(
+            queue, resolution, signal_semaphore_list),
+        iree_hal_amdgpu_notification_ring_epoch_signal(
+            &queue->notification_ring),
+        &completion_setup);
+  } else if (complete_with_wait_barrier) {
+    const iree_hsa_fence_scope_t release_scope =
+        iree_hal_amdgpu_host_queue_signal_list_release_scope(
+            queue, signal_semaphore_list);
+    completion_header =
+        iree_hal_amdgpu_host_queue_write_wait_barrier_packet_body(
+            queue, &resolution->barriers[resolution->barrier_count - 1],
+            first_packet_id + aql_packet_count - 1,
+            iree_hal_amdgpu_notification_ring_epoch_signal(
+                &queue->notification_ring),
+            resolution->barrier_acquire_scope, release_scope, completion_slot,
+            &completion_setup);
+  } else {
+    completion_header = iree_hal_amdgpu_aql_emit_nop(
+        &completion_slot->barrier_and,
+        iree_hal_amdgpu_host_queue_final_barrier_packet_control(
+            queue, resolution, signal_semaphore_list),
+        iree_hal_amdgpu_notification_ring_epoch_signal(
+            &queue->notification_ring));
+  }
+  iree_hal_amdgpu_aql_ring_commit(completion_slot, completion_header,
+                                  completion_setup);
+  iree_hal_amdgpu_aql_ring_doorbell(&queue->aql_ring,
+                                    first_packet_id + aql_packet_count - 1);
+  memset(submission, 0, sizeof(*submission));
+  return submission_epoch;
+}
+
+void iree_hal_amdgpu_host_queue_fail_barrier_submission(
+    iree_hal_amdgpu_host_queue_t* queue,
+    iree_hal_amdgpu_host_queue_barrier_submission_t* submission) {
+  iree_hal_amdgpu_host_queue_emit_noop_packets(
+      queue, submission->first_packet_id, submission->packet_count);
+  iree_hal_amdgpu_host_queue_cancel_profile_queue_device_events(
+      queue, submission->profile_queue_device_events);
+  iree_hal_amdgpu_reclaim_entry_release(submission->reclaim_entry,
+                                        queue->block_pool);
+  memset(submission, 0, sizeof(*submission));
+}
+
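+// Composes try_begin/finish for a barrier-shaped submission. Returns OK with
+// |*out_ready| = false when ring capacity is temporarily unavailable so that
+// the caller can retry after capacity is reclaimed.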
+iree_status_t iree_hal_amdgpu_host_queue_try_submit_barrier(
+    iree_hal_amdgpu_host_queue_t* queue,
+    const iree_hal_amdgpu_wait_resolution_t* resolution,
+    const iree_hal_semaphore_list_t signal_semaphore_list,
+    iree_hal_amdgpu_reclaim_action_t pre_signal_action,
+    iree_hal_resource_t* const* operation_resources,
+    iree_host_size_t operation_resource_count,
+    const iree_hal_amdgpu_host_queue_profile_event_info_t* profile_event_info,
+    iree_hal_amdgpu_host_queue_post_commit_callback_t post_commit_callback,
+    iree_hal_resource_set_t* resource_set,
+    iree_hal_amdgpu_host_queue_submission_flags_t submission_flags,
+    bool* out_ready, uint64_t* out_submission_id) {
+  IREE_ASSERT_ARGUMENT(out_ready);
+  *out_ready = false;
+  if (out_submission_id) *out_submission_id = 0;
+
+  iree_hal_amdgpu_host_queue_barrier_submission_t submission;
+  IREE_RETURN_IF_ERROR(iree_hal_amdgpu_host_queue_try_begin_barrier_submission(
+      queue, resolution, signal_semaphore_list, operation_resource_count,
+      profile_event_info, out_ready, &submission));
+  if (!*out_ready) return iree_ok_status();
+
+  const uint64_t submission_epoch =
+      iree_hal_amdgpu_host_queue_finish_barrier_submission(
+          queue, resolution, signal_semaphore_list, pre_signal_action,
+          operation_resources, operation_resource_count, profile_event_info,
+          post_commit_callback, resource_set, submission_flags, &submission);
+  if (out_submission_id) *out_submission_id = submission_epoch;
+  return iree_ok_status();
+}
diff --git a/runtime/src/iree/hal/drivers/amdgpu/host_queue_submission.h b/runtime/src/iree/hal/drivers/amdgpu/host_queue_submission.h
new file mode 100644
index 0000000..62c3959
--- /dev/null
+++ b/runtime/src/iree/hal/drivers/amdgpu/host_queue_submission.h
@@ -0,0 +1,392 @@
+// Copyright 2026 The IREE Authors
+//
+// Licensed under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+
+#ifndef IREE_HAL_DRIVERS_AMDGPU_HOST_QUEUE_SUBMISSION_H_
+#define IREE_HAL_DRIVERS_AMDGPU_HOST_QUEUE_SUBMISSION_H_
+
+#include "iree/hal/drivers/amdgpu/host_queue_policy.h"
+#include "iree/hal/drivers/amdgpu/host_queue_waits.h"
+#include "iree/hal/utils/resource_set.h"
+
+#ifdef __cplusplus
+extern "C" {
+#endif  // __cplusplus
+
+typedef struct iree_hal_amdgpu_host_queue_profile_event_info_t
+    iree_hal_amdgpu_host_queue_profile_event_info_t;
+
+typedef void(IREE_API_PTR* iree_hal_amdgpu_host_queue_post_commit_fn_t)(
+    void* user_data, const iree_async_frontier_t* queue_frontier,
+    uint64_t submission_id);
+
+// Optional callback invoked after queue frontier state has advanced and before
+// the completion packet is published.
+typedef struct iree_hal_amdgpu_host_queue_post_commit_callback_t {
+  // Function invoked with the queue frontier and submission id visible after
+  // commit.
+  iree_hal_amdgpu_host_queue_post_commit_fn_t fn;
+
+  // Opaque user data passed to |fn|.
+  void* user_data;
+} iree_hal_amdgpu_host_queue_post_commit_callback_t;
+
+// Returns a null post-commit callback.
+static inline iree_hal_amdgpu_host_queue_post_commit_callback_t
+iree_hal_amdgpu_host_queue_post_commit_callback_null(void) {
+  iree_hal_amdgpu_host_queue_post_commit_callback_t callback = {
+      .fn = NULL,
+      .user_data = NULL,
+  };
+  return callback;
+}
+
+// Flags controlling submission helper ownership transfers.
+typedef uint32_t iree_hal_amdgpu_host_queue_submission_flags_t;
+enum iree_hal_amdgpu_host_queue_submission_flag_bits_t {
+  IREE_HAL_AMDGPU_HOST_QUEUE_SUBMISSION_FLAG_NONE = 0u,
+  // Retains signal semaphores and operation resources into the reclaim entry.
+  // When omitted, the helper transfers one existing retain for each resource
+  // from the caller into the reclaim entry on success.
+  IREE_HAL_AMDGPU_HOST_QUEUE_SUBMISSION_FLAG_RETAIN_RESOURCES = 1u << 0,
+};
+
+// One in-flight kernel-shaped packet submission assembled under
+// submission_mutex. Owns the generic notification/reclaim, AQL reservation, and
+// kernarg reservation state shared by direct queue dispatches and
+// command-buffer replay.
+typedef struct iree_hal_amdgpu_host_queue_kernel_submission_t {
+  // Reclaim entry reserved from the notification ring for this submission.
+  iree_hal_amdgpu_reclaim_entry_t* reclaim_entry;
+  // Reclaim resource slots owned by |reclaim_entry|.
+  iree_hal_resource_t** reclaim_resources;
+  // Queue-owned kernarg reservation for this submission.
+  struct {
+    // First kernarg block reserved for this submission, or NULL when unused.
+    iree_hal_amdgpu_kernarg_block_t* blocks;
+    // Kernarg ring write position to reclaim after this submission completes.
+    uint64_t write_position;
+  } kernargs;
+  // First AQL packet id reserved for this submission.
+  uint64_t first_packet_id;
+  // Queue-control upload reservation for this submission.
+  struct {
+    // Upload ring write position to reclaim after this submission completes.
+    uint64_t write_position;
+  } queue_upload;
+  // Number of AQL packets reserved starting at |first_packet_id|.
+  uint32_t packet_count;
+  // Number of valid entries in |reclaim_resources|.
+  uint16_t reclaim_resource_count;
+  // Optional action executed before user signals are published when this
+  // submission completes.
+  iree_hal_amdgpu_reclaim_action_t pre_signal_action;
+} iree_hal_amdgpu_host_queue_kernel_submission_t;
+
+// One in-flight barrier-shaped packet submission assembled under
+// submission_mutex. Owns the generic notification/reclaim and AQL reservation
+// state for submissions that complete with a barrier or no-op packet and do not
+// require queue-owned kernarg storage.
+typedef struct iree_hal_amdgpu_host_queue_barrier_submission_t {
+  // Reclaim entry reserved from the notification ring for this submission.
+  iree_hal_amdgpu_reclaim_entry_t* reclaim_entry;
+  // Reclaim resource slots owned by |reclaim_entry|.
+  iree_hal_resource_t** reclaim_resources;
+  // Queue device profile event reservation for this submission.
+  iree_hal_amdgpu_profile_queue_device_event_reservation_t
+      profile_queue_device_events;
+  // First AQL packet id reserved for this submission.
+  uint64_t first_packet_id;
+  // Number of AQL packets reserved starting at |first_packet_id|.
+  uint32_t packet_count;
+  // Number of valid entries in |reclaim_resources|.
+  uint16_t reclaim_resource_count;
+} iree_hal_amdgpu_host_queue_barrier_submission_t;
+
+// One in-flight single-dispatch submission assembled under submission_mutex.
+// Operation implementations populate the dispatch packet and kernargs directly
+// while generic ownership and publication stay in |kernel|.
+typedef struct iree_hal_amdgpu_host_queue_dispatch_submission_t {
+  // Generic kernel-shaped submission state.
+  iree_hal_amdgpu_host_queue_kernel_submission_t kernel;
+  // Queue device profile event reservation for this submission.
+  iree_hal_amdgpu_profile_queue_device_event_reservation_t
+      profile_queue_device_events;
+  // Packet id of |dispatch_slot|.
+  uint64_t dispatch_packet_id;
+  // Uncommitted dispatch payload AQL slot.
+  iree_hal_amdgpu_aql_packet_t* dispatch_slot;
+  // Optional trailing harvest slot used when |dispatch_slot| completes with a
+  // profiling-owned signal instead of the queue epoch signal.
+  iree_hal_amdgpu_aql_packet_t* profile_harvest_slot;
+  // Queue-owned kernarg blocks reserved for |profile_harvest_slot|.
+  iree_hal_amdgpu_kernarg_block_t* profile_harvest_kernarg_blocks;
+  // Dispatch profile event reservation harvested by |profile_harvest_slot|.
+  iree_hal_amdgpu_profile_dispatch_event_reservation_t profile_events;
+  // Number of counter sets captured around |dispatch_slot|.
+  uint32_t profile_counter_set_count;
+  // Number of executable trace start packets before |dispatch_slot|.
+  uint32_t profile_trace_start_packet_count;
+  // Completion signal to write into |dispatch_slot|.
+  iree_hsa_signal_t dispatch_completion_signal;
+  // Setup bits published with |dispatch_slot|'s final header.
+  uint16_t dispatch_setup;
+  // Setup bits published with |profile_harvest_slot|'s final header.
+  uint16_t profile_harvest_setup;
+  // Minimum acquire fence scope required by operation-local data visibility.
+  iree_hsa_fence_scope_t minimum_acquire_scope;
+  // Minimum release fence scope required by operation-local data visibility.
+  iree_hsa_fence_scope_t minimum_release_scope;
+} iree_hal_amdgpu_host_queue_dispatch_submission_t;
+
+// One in-flight PM4-IB payload submission assembled under submission_mutex.
+// Operation implementations populate |pm4_ib_slot| directly while generic
+// ownership and publication stay in |kernel|.
+typedef struct iree_hal_amdgpu_host_queue_pm4_ib_submission_t {
+  // Generic payload-shaped submission state.
+  iree_hal_amdgpu_host_queue_kernel_submission_t kernel;
+  // Queue device profile event reservation for this submission.
+  iree_hal_amdgpu_profile_queue_device_event_reservation_t
+      profile_queue_device_events;
+  // Uncommitted PM4-IB payload AQL slot.
+  iree_hal_amdgpu_aql_packet_t* pm4_ib_packet_slot;
+  // Queue-owned PM4 IB storage referenced by |pm4_ib_packet_slot|.
+  iree_hal_amdgpu_pm4_ib_slot_t* pm4_ib_slot;
+  // Number of PM4 dwords populated in |pm4_ib_slot|.
+  uint32_t ib_dword_count;
+} iree_hal_amdgpu_host_queue_pm4_ib_submission_t;
+
+// Returns the number of retained resources required for a submission with
+// |signal_semaphore_count| user-visible signal semaphores and
+// |operation_resource_count| additional operation-owned resources.
+iree_status_t iree_hal_amdgpu_host_queue_count_reclaim_resources(
+    iree_host_size_t signal_semaphore_count,
+    iree_host_size_t operation_resource_count,
+    uint16_t* out_reclaim_resource_count);
+
+// Attempts to begin one kernel-shaped packet submission without waiting for
+// ring capacity. If temporary AQL/kernarg/notification capacity is
+// unavailable then |out_ready| is set to false, no queue state is mutated,
+// and OK is returned.
+// Any non-OK status is a real structural failure rather than retry state.
+//
+// Caller must hold submission_mutex.
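+//
+// Illustrative flow (simplified; the dispatch and PM4-IB wrappers in
+// host_queue_submission.c follow this shape):
+//   bool ready = false;
+//   iree_hal_amdgpu_host_queue_kernel_submission_t submission;
+//   IREE_RETURN_IF_ERROR(
+//       iree_hal_amdgpu_host_queue_try_begin_kernel_submission(
+//           queue, resolution, signal_semaphore_list,
+//           operation_resource_count, payload_packet_count,
+//           kernarg_block_count, &ready, &submission));
+//   if (!ready) return iree_ok_status();  // retry once capacity reclaims
+//   // ... write kernarg bytes and payload packet bodies ...
+//   iree_hal_amdgpu_host_queue_emit_kernel_submission_prefix(
+//       queue, resolution, &submission);
+//   iree_hal_amdgpu_host_queue_finish_kernel_submission(...);
+//   // ... commit payload packet headers and ring the doorbell; on failure
+//   // call iree_hal_amdgpu_host_queue_fail_kernel_submission instead ...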
+iree_status_t iree_hal_amdgpu_host_queue_try_begin_kernel_submission(
+    iree_hal_amdgpu_host_queue_t* queue,
+    const iree_hal_amdgpu_wait_resolution_t* resolution,
+    const iree_hal_semaphore_list_t signal_semaphore_list,
+    iree_host_size_t operation_resource_count, uint32_t payload_packet_count,
+    uint32_t kernarg_block_count, bool* out_ready,
+    iree_hal_amdgpu_host_queue_kernel_submission_t* out_submission);
+
+// Publishes host-populated queue-owned kernargs before committing packet
+// headers that reference them. Callers must have already written all kernarg
+// bytes for |submission|.
+static inline void iree_hal_amdgpu_host_queue_publish_submission_kernargs(
+    const iree_hal_amdgpu_host_queue_t* queue,
+    const iree_hal_amdgpu_host_queue_kernel_submission_t* submission) {
+  if (submission->kernargs.blocks) {
+    iree_hal_amdgpu_kernarg_ring_publish_host_writes(&queue->kernarg_ring);
+  }
+}
+
+// Returns the acquire scope required for device execution to observe
+// host-populated queue-owned kernargs. Device-local rings need SYSTEM acquire
+// here because publish_submission_kernargs() only drains host writes before
+// packet publication; shader-visible memory may otherwise retain stale
+// contents across ring-slot reuse.
+static inline iree_hsa_fence_scope_t
+iree_hal_amdgpu_host_queue_kernarg_acquire_scope(
+    const iree_hal_amdgpu_host_queue_t* queue,
+    iree_hsa_fence_scope_t minimum_acquire_scope) {
+  if (iree_hal_amdgpu_kernarg_ring_requires_host_write_publication(
+          &queue->kernarg_ring)) {
+    return iree_hal_amdgpu_host_queue_max_fence_scope(
+        minimum_acquire_scope, IREE_HSA_FENCE_SCOPE_SYSTEM);
+  }
+  return minimum_acquire_scope;
+}
+
+// Attempts to begin one barrier-shaped packet submission without waiting for
+// ring capacity. If temporary AQL/notification capacity is unavailable then
+// |out_ready| is set to false, no queue state is mutated, and OK is returned.
+// Any non-OK status is a structural failure rather than retry state.
+//
+// Caller must hold submission_mutex.
+iree_status_t iree_hal_amdgpu_host_queue_try_begin_barrier_submission(
+    iree_hal_amdgpu_host_queue_t* queue,
+    const iree_hal_amdgpu_wait_resolution_t* resolution,
+    const iree_hal_semaphore_list_t signal_semaphore_list,
+    iree_host_size_t operation_resource_count,
+    const iree_hal_amdgpu_host_queue_profile_event_info_t* profile_event_info,
+    bool* out_ready,
+    iree_hal_amdgpu_host_queue_barrier_submission_t* out_submission);
+
+// Publishes a barrier-shaped packet submission and returns its queue submission
+// epoch. Caller must hold submission_mutex.
+uint64_t iree_hal_amdgpu_host_queue_finish_barrier_submission(
+    iree_hal_amdgpu_host_queue_t* queue,
+    const iree_hal_amdgpu_wait_resolution_t* resolution,
+    const iree_hal_semaphore_list_t signal_semaphore_list,
+    iree_hal_amdgpu_reclaim_action_t pre_signal_action,
+    iree_hal_resource_t* const* operation_resources,
+    iree_host_size_t operation_resource_count,
+    const iree_hal_amdgpu_host_queue_profile_event_info_t* profile_event_info,
+    iree_hal_amdgpu_host_queue_post_commit_callback_t post_commit_callback,
+    iree_hal_resource_set_t* resource_set,
+    iree_hal_amdgpu_host_queue_submission_flags_t submission_flags,
+    iree_hal_amdgpu_host_queue_barrier_submission_t* submission);
+
+// Emits no-op packets for a barrier-shaped submission whose AQL slots were
+// reserved but whose payload could not be published. User signal semaphores are
+// not signaled.
+void iree_hal_amdgpu_host_queue_fail_barrier_submission(
+    iree_hal_amdgpu_host_queue_t* queue,
+    iree_hal_amdgpu_host_queue_barrier_submission_t* submission);
+
+// Publishes the software side of a kernel-shaped packet submission: transfers
+// operation resources and an optional resource set to the reclaim entry,
+// advances queue/frontier state, and records user-visible signal metadata.
+// Payload packet headers remain uncommitted when this returns. Caller must hold
+// submission_mutex.
+uint64_t iree_hal_amdgpu_host_queue_finish_kernel_submission(
+    iree_hal_amdgpu_host_queue_t* queue,
+    const iree_hal_amdgpu_wait_resolution_t* resolution,
+    const iree_hal_semaphore_list_t signal_semaphore_list,
+    iree_hal_resource_t* const* operation_resources,
+    iree_host_size_t operation_resource_count,
+    iree_hal_resource_set_t** inout_resource_set,
+    iree_hal_amdgpu_host_queue_submission_flags_t submission_flags,
+    iree_hal_amdgpu_host_queue_kernel_submission_t* submission);
+
+// Emits reclaim-only no-op packets for a kernel-shaped submission whose AQL and
+// kernarg slots were reserved but whose payload could not be published. User
+// signal semaphores are not signaled; only queue-private reclaim can advance.
+void iree_hal_amdgpu_host_queue_fail_kernel_submission(
+    iree_hal_amdgpu_host_queue_t* queue,
+    iree_hal_amdgpu_host_queue_kernel_submission_t* submission);
+
+// Publishes wait-barrier packets for a successful kernel-shaped submission.
+// Caller must have already populated payload packet bodies but not committed
+// payload packet headers.
+void iree_hal_amdgpu_host_queue_emit_kernel_submission_prefix(
+    iree_hal_amdgpu_host_queue_t* queue,
+    const iree_hal_amdgpu_wait_resolution_t* resolution,
+    const iree_hal_amdgpu_host_queue_kernel_submission_t* submission);
+
+// Commits a non-final queue-device start timestamp packet. The slot identified
+// by |packet_id| must already be reserved by the submission and must precede
+// the queue operation payload.
+void iree_hal_amdgpu_host_queue_commit_queue_device_start_packet(
+    iree_hal_amdgpu_host_queue_t* queue,
+    const iree_hal_amdgpu_wait_resolution_t* resolution, uint64_t packet_id,
+    iree_hal_amdgpu_profile_queue_device_event_t* queue_device_event);
+
+// Commits the final queue-device end timestamp packet and queue completion.
+// The slot identified by |packet_id| must already be reserved by the
+// submission and must follow the queue operation payload.
+void iree_hal_amdgpu_host_queue_commit_queue_device_end_packet(
+    iree_hal_amdgpu_host_queue_t* queue,
+    const iree_hal_amdgpu_wait_resolution_t* resolution,
+    const iree_hal_semaphore_list_t signal_semaphore_list, uint64_t packet_id,
+    iree_hal_amdgpu_profile_queue_device_event_t* queue_device_event);
+
+// Attempts to begin one kernel-dispatch submission without waiting for ring
+// capacity. Caller must hold submission_mutex.
+iree_status_t iree_hal_amdgpu_host_queue_try_begin_dispatch_submission(
+    iree_hal_amdgpu_host_queue_t* queue,
+    const iree_hal_amdgpu_wait_resolution_t* resolution,
+    const iree_hal_semaphore_list_t signal_semaphore_list,
+    iree_host_size_t operation_resource_count, uint32_t kernarg_block_count,
+    iree_hal_amdgpu_profile_dispatch_event_reservation_t profile_events,
+    const iree_hal_amdgpu_host_queue_profile_event_info_t*
+        profile_queue_event_info,
+    bool* out_ready,
+    iree_hal_amdgpu_host_queue_dispatch_submission_t* out_submission);
+
+// Attempts to begin one PM4-IB payload submission without waiting for ring
+// capacity. Caller must hold submission_mutex.
+iree_status_t iree_hal_amdgpu_host_queue_try_begin_pm4_ib_submission(
+    iree_hal_amdgpu_host_queue_t* queue,
+    const iree_hal_amdgpu_wait_resolution_t* resolution,
+    const iree_hal_semaphore_list_t signal_semaphore_list,
+    iree_host_size_t operation_resource_count,
+    const iree_hal_amdgpu_host_queue_profile_event_info_t*
+        profile_queue_event_info,
+    bool* out_ready,
+    iree_hal_amdgpu_host_queue_pm4_ib_submission_t* out_submission);
+
+// Writes one final dispatch packet body into an AQL slot in forward field
+// order and returns the setup bits that must be published with the header.
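+// (Likely rationale, not a contract: forward field order keeps stores
+// ascending, the pattern favored by write-combined device-visible memory.)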
+uint16_t iree_hal_amdgpu_host_queue_write_dispatch_packet_body(
+    iree_hsa_kernel_dispatch_packet_t* IREE_RESTRICT dispatch_packet,
+    const iree_hsa_kernel_dispatch_packet_t* IREE_RESTRICT
+        dispatch_packet_template,
+    void* kernarg_address, iree_hsa_signal_t completion_signal);
+
+// Finishes a submission by transferring retained resources to the reclaim
+// entry, publishing queue/semaphore frontier state, committing the final
+// dispatch header, ringing the doorbell, and returning the assigned queue
+// submission epoch. Caller must hold submission_mutex.
+uint64_t iree_hal_amdgpu_host_queue_finish_dispatch_submission(
+    iree_hal_amdgpu_host_queue_t* queue,
+    const iree_hal_amdgpu_wait_resolution_t* resolution,
+    const iree_hal_semaphore_list_t signal_semaphore_list,
+    iree_hal_resource_t* const* operation_resources,
+    iree_host_size_t operation_resource_count,
+    const iree_hal_amdgpu_host_queue_profile_event_info_t*
+        profile_queue_event_info,
+    iree_hal_amdgpu_host_queue_submission_flags_t submission_flags,
+    iree_hal_amdgpu_host_queue_dispatch_submission_t* submission);
+
+// Finishes a PM4-IB payload submission by transferring retained resources to
+// the reclaim entry, publishing queue/semaphore frontier state, committing the
+// final PM4-IB packet header, ringing the doorbell, and returning the assigned
+// queue submission epoch. Caller must hold submission_mutex.
+uint64_t iree_hal_amdgpu_host_queue_finish_pm4_ib_submission(
+    iree_hal_amdgpu_host_queue_t* queue,
+    const iree_hal_amdgpu_wait_resolution_t* resolution,
+    const iree_hal_semaphore_list_t signal_semaphore_list,
+    iree_hal_resource_t* const* operation_resources,
+    iree_host_size_t operation_resource_count,
+    const iree_hal_amdgpu_host_queue_profile_event_info_t*
+        profile_queue_event_info,
+    iree_hal_amdgpu_host_queue_submission_flags_t submission_flags,
+    iree_hal_amdgpu_host_queue_pm4_ib_submission_t* submission);
+
+// Emits one kernel-dispatch submission using an already-prepared packet shape
+// and kernargs blob. Caller must hold submission_mutex.
+iree_status_t iree_hal_amdgpu_host_queue_submit_dispatch_packet(
+    iree_hal_amdgpu_host_queue_t* queue,
+    const iree_hal_amdgpu_wait_resolution_t* resolution,
+    const iree_hal_semaphore_list_t signal_semaphore_list,
+    const iree_hsa_kernel_dispatch_packet_t* dispatch_packet_template,
+    const void* kernargs, iree_host_size_t kernarg_length,
+    iree_hal_resource_t* const* operation_resources,
+    iree_host_size_t operation_resource_count,
+    const iree_hal_amdgpu_host_queue_profile_event_info_t*
+        profile_queue_event_info,
+    iree_hal_amdgpu_host_queue_submission_flags_t submission_flags,
+    bool* out_ready, uint64_t* out_submission_id);
+
+// Attempts to submit a barrier-only operation without waiting for temporary
+// ring capacity. If capacity is unavailable then |out_ready| is set to false
+// and no queue state or resource ownership is mutated.
+// Caller must hold submission_mutex.
+iree_status_t iree_hal_amdgpu_host_queue_try_submit_barrier(
+    iree_hal_amdgpu_host_queue_t* queue,
+    const iree_hal_amdgpu_wait_resolution_t* resolution,
+    const iree_hal_semaphore_list_t signal_semaphore_list,
+    iree_hal_amdgpu_reclaim_action_t pre_signal_action,
+    iree_hal_resource_t* const* operation_resources,
+    iree_host_size_t operation_resource_count,
+    const iree_hal_amdgpu_host_queue_profile_event_info_t* profile_event_info,
+    iree_hal_amdgpu_host_queue_post_commit_callback_t post_commit_callback,
+    iree_hal_resource_set_t* resource_set,
+    iree_hal_amdgpu_host_queue_submission_flags_t submission_flags,
+    bool* out_ready, uint64_t* out_submission_id);
+
+#ifdef __cplusplus
+}  // extern "C"
+#endif  // __cplusplus
+
+#endif  // IREE_HAL_DRIVERS_AMDGPU_HOST_QUEUE_SUBMISSION_H_
diff --git a/runtime/src/iree/hal/drivers/amdgpu/host_queue_submission_test.cc b/runtime/src/iree/hal/drivers/amdgpu/host_queue_submission_test.cc
new file mode 100644
index 0000000..6560a93
--- /dev/null
+++ b/runtime/src/iree/hal/drivers/amdgpu/host_queue_submission_test.cc
@@ -0,0 +1,407 @@
+// Copyright 2026 The IREE Authors
+//
+// Licensed under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+
+#include "iree/hal/drivers/amdgpu/host_queue_submission.h"
+
+#include <cstdint>
+
+#include "iree/hal/api.h"
+#include "iree/hal/cts/util/test_base.h"
+#include "iree/hal/drivers/amdgpu/host_queue.h"
+#include "iree/hal/drivers/amdgpu/host_queue_profile.h"
+#include "iree/hal/drivers/amdgpu/host_queue_profile_events.h"
+#include "iree/hal/drivers/amdgpu/host_queue_waits.h"
+#include "iree/hal/drivers/amdgpu/logical_device.h"
+#include "iree/hal/drivers/amdgpu/physical_device.h"
+#include "iree/hal/drivers/amdgpu/util/pm4_capabilities.h"
+#include "iree/hal/drivers/amdgpu/util/pm4_emitter.h"
+#include "iree/testing/gtest.h"
+#include "iree/testing/status_matchers.h"
+
+namespace iree::hal::amdgpu {
+namespace {
+
+constexpr uint32_t kNoHarvestPacketOffset = UINT32_MAX;
+
+class HostQueueSubmissionTest : public ::testing::Test {
+ protected:
+  static void SetUpTestSuite() {
+    host_allocator_ = iree_allocator_system();
+    iree_status_t status = iree_hal_amdgpu_libhsa_initialize(
+        IREE_HAL_AMDGPU_LIBHSA_FLAG_NONE, iree_string_view_list_empty(),
+        host_allocator_, &libhsa_);
+    if (!iree_status_is_ok(status)) {
+      iree_status_fprint(stderr, status);
+      iree_status_free(status);
+      GTEST_SKIP() << "HSA not available, skipping tests";
+    }
+    IREE_ASSERT_OK(iree_hal_amdgpu_topology_initialize_with_defaults(
+        &libhsa_, &topology_));
+    if (topology_.gpu_agent_count == 0) {
+      GTEST_SKIP() << "no GPU devices available, skipping tests";
+    }
+  }
+
+  static void TearDownTestSuite() {
+    iree_hal_amdgpu_topology_deinitialize(&topology_);
+    iree_hal_amdgpu_libhsa_deinitialize(&libhsa_);
+  }
+
+  static iree_allocator_t host_allocator_;
+  static iree_hal_amdgpu_libhsa_t libhsa_;
+  static iree_hal_amdgpu_topology_t topology_;
+};
+
+iree_allocator_t HostQueueSubmissionTest::host_allocator_;
+iree_hal_amdgpu_libhsa_t HostQueueSubmissionTest::libhsa_;
+iree_hal_amdgpu_topology_t HostQueueSubmissionTest::topology_;
+
+class TestLogicalDevice {
+ public:
+  ~TestLogicalDevice() {
+    iree_hal_device_release(base_device_);
+    iree_hal_device_group_release(device_group_);
+  }
+
+  iree_status_t Initialize(
+      const iree_hal_amdgpu_logical_device_options_t* options,
+      const iree_hal_amdgpu_libhsa_t* libhsa,
+      const iree_hal_amdgpu_topology_t* topology,
+      iree_allocator_t host_allocator) {
+    IREE_RETURN_IF_ERROR(create_context_.Initialize(host_allocator));
+    IREE_RETURN_IF_ERROR(iree_hal_amdgpu_logical_device_create(
+        IREE_SV("amdgpu"), options, libhsa, topology, create_context_.params(),
+        host_allocator, &base_device_));
+    return iree_hal_device_group_create_from_device(
+        base_device_, create_context_.frontier_tracker(), host_allocator,
+        &device_group_);
+  }
+
+  iree_hal_amdgpu_host_queue_t* first_host_queue() const {
+    iree_hal_amdgpu_logical_device_t* logical_device =
+        (iree_hal_amdgpu_logical_device_t*)base_device_;
+    if (logical_device->physical_device_count == 0) return nullptr;
+    iree_hal_amdgpu_physical_device_t* physical_device =
+        logical_device->physical_devices[0];
+    if (physical_device->host_queue_count == 0) return nullptr;
+    return &physical_device->host_queues[0];
+  }
+
+ private:
+  // Creation context supplying the proactor pool and frontier tracker.
+  iree::hal::cts::DeviceCreateContext create_context_;
+
+  // Test-owned device reference released before the topology-owning group.
+  iree_hal_device_t* base_device_ = nullptr;
+
+  // Device group that owns the topology assigned to |base_device_|.
+  iree_hal_device_group_t* device_group_ = nullptr;
+};
+
+class HostQueueHsaProfilingScope {
+ public:
+  explicit HostQueueHsaProfilingScope(iree_hal_amdgpu_host_queue_t* queue)
+      : queue_(queue) {}
+
+  ~HostQueueHsaProfilingScope() {
+    if (is_enabled_) {
+      IREE_EXPECT_OK(
+          iree_hal_amdgpu_host_queue_set_hsa_profiling_enabled(queue_, false));
+    }
+  }
+
+  iree_status_t Enable() {
+    iree_status_t status =
+        iree_hal_amdgpu_host_queue_set_hsa_profiling_enabled(queue_, true);
+    if (iree_status_is_ok(status)) {
+      is_enabled_ = true;
+    }
+    return status;
+  }
+
+ private:
+  // Host queue whose HSA profiling mode is enabled for the current test.
+  iree_hal_amdgpu_host_queue_t* queue_;
+
+  // True once |queue_| HSA profiling has been enabled and must be disabled.
+  bool is_enabled_ = false;
+};
+
+typedef struct DispatchSubmissionPlanCase {
+  // Number of wait-barrier packets preceding the dispatch payload.
+  uint8_t barrier_count;
+  // Whether a dispatch profiling event is reserved for the submission.
+  bool reserve_dispatch_event;
+  // Whether a queue-device event is reserved for the submission.
+  bool reserve_queue_device_event;
+  // Expected total AQL packets reserved for the submission.
+  uint32_t expected_packet_count;
+  // Expected dispatch packet offset from the first reserved packet.
+  uint32_t expected_dispatch_packet_offset;
+  // Expected harvest packet offset, or kNoHarvestPacketOffset when absent.
+  uint32_t expected_harvest_packet_offset;
+  // True when the dispatch packet should signal queue completion directly.
+  bool expect_dispatch_completion_signal;
+} DispatchSubmissionPlanCase;
+
+static void ExpectDispatchSubmissionPlan(
+    iree_hal_amdgpu_host_queue_t* queue,
+    const DispatchSubmissionPlanCase& plan_case) {
+  iree_hal_amdgpu_wait_resolution_t resolution = {0};
+  resolution.barrier_count = plan_case.barrier_count;
+  const iree_hal_semaphore_list_t empty_signal_list = {0};
+  iree_hal_amdgpu_profile_dispatch_event_reservation_t profile_events = {0};
+  iree_hal_amdgpu_host_queue_profile_event_info_t profile_queue_event_info = {
+      .type = IREE_HAL_PROFILE_QUEUE_EVENT_TYPE_DISPATCH,
+      .operation_count = 1,
+  };
+  iree_hal_amdgpu_host_queue_set_profile_flags(
+      queue, plan_case.reserve_queue_device_event
+                 ? IREE_HAL_AMDGPU_HOST_QUEUE_PROFILE_FLAG_QUEUE_DEVICE_EVENTS
+                 : IREE_HAL_AMDGPU_HOST_QUEUE_PROFILE_FLAG_NONE);
+
+  bool is_ready = false;
+  iree_hal_amdgpu_host_queue_dispatch_submission_t submission = {};
+  iree_slim_mutex_lock(&queue->locks.submission_mutex);
+  iree_status_t status = iree_ok_status();
+  if (plan_case.reserve_dispatch_event) {
+    status = iree_hal_amdgpu_host_queue_reserve_profile_dispatch_events(
+        queue, /*event_count=*/1, &profile_events);
+  }
+  if (iree_status_is_ok(status)) {
+    status = iree_hal_amdgpu_host_queue_try_begin_dispatch_submission(
+        queue, &resolution, empty_signal_list,
+        /*operation_resource_count=*/0, /*kernarg_block_count=*/1,
+        profile_events, &profile_queue_event_info, &is_ready, &submission);
+  }
+  if (iree_status_is_ok(status) && is_ready) {
+    EXPECT_EQ(plan_case.expected_packet_count, submission.kernel.packet_count);
+    EXPECT_EQ(submission.kernel.first_packet_id +
+                  plan_case.expected_dispatch_packet_offset,
+              submission.dispatch_packet_id);
+    EXPECT_EQ(iree_hal_amdgpu_aql_ring_packet(&queue->aql_ring,
+                                              submission.dispatch_packet_id),
+              submission.dispatch_slot);
+
+    const bool has_harvest_packet =
+        plan_case.expected_harvest_packet_offset != kNoHarvestPacketOffset;
+    if (has_harvest_packet) {
+      EXPECT_EQ(
+          iree_hal_amdgpu_aql_ring_packet(
+              &queue->aql_ring, submission.kernel.first_packet_id +
+                                    plan_case.expected_harvest_packet_offset),
+          submission.profile_harvest_slot);
+    } else {
+      EXPECT_EQ(nullptr, submission.profile_harvest_slot);
+    }
+
+    const bool has_dispatch_completion_signal =
+        submission.dispatch_completion_signal.handle != 0;
+    EXPECT_EQ(plan_case.expect_dispatch_completion_signal,
+              has_dispatch_completion_signal);
+
+    iree_hal_amdgpu_host_queue_cancel_profile_queue_device_events(
+        queue, submission.profile_queue_device_events);
+    iree_hal_amdgpu_host_queue_fail_kernel_submission(queue,
+                                                      &submission.kernel);
+  }
+  if (profile_events.event_count != 0) {
+    iree_hal_amdgpu_host_queue_cancel_profile_dispatch_events(queue,
+                                                              profile_events);
+  }
+  iree_slim_mutex_unlock(&queue->locks.submission_mutex);
+
+  IREE_EXPECT_OK(status);
+  EXPECT_TRUE(is_ready);
+}
+
+typedef struct Pm4IbSubmissionPlanCase {
+  // Number of wait-barrier packets preceding the PM4-IB payload.
+  uint8_t barrier_count;
+  // Whether a queue-device event is reserved for the submission.
+  bool reserve_queue_device_event;
+  // Expected total AQL packets reserved for the submission.
+  uint32_t expected_packet_count;
+  // Expected PM4-IB packet offset from the first reserved packet.
+  uint32_t expected_pm4_ib_packet_offset;
+} Pm4IbSubmissionPlanCase;
+
+static void ExpectPm4IbSubmissionPlan(
+    iree_hal_amdgpu_host_queue_t* queue,
+    const Pm4IbSubmissionPlanCase& plan_case) {
+  iree_hal_amdgpu_wait_resolution_t resolution = {0};
+  resolution.barrier_count = plan_case.barrier_count;
+  const iree_hal_semaphore_list_t empty_signal_list = {0};
+  iree_hal_amdgpu_host_queue_profile_event_info_t profile_queue_event_info = {
+      .type = IREE_HAL_PROFILE_QUEUE_EVENT_TYPE_UPDATE,
+      .operation_count = 1,
+  };
+  iree_hal_amdgpu_host_queue_set_profile_flags(
+      queue, plan_case.reserve_queue_device_event
+                 ? IREE_HAL_AMDGPU_HOST_QUEUE_PROFILE_FLAG_QUEUE_DEVICE_EVENTS
+                 : IREE_HAL_AMDGPU_HOST_QUEUE_PROFILE_FLAG_NONE);
+
+  bool is_ready = false;
+  iree_hal_amdgpu_host_queue_pm4_ib_submission_t submission = {};
+  iree_slim_mutex_lock(&queue->locks.submission_mutex);
+  iree_status_t status = iree_hal_amdgpu_host_queue_try_begin_pm4_ib_submission(
+      queue, &resolution, empty_signal_list,
+      /*operation_resource_count=*/0, &profile_queue_event_info, &is_ready,
+      &submission);
+  if (iree_status_is_ok(status) && is_ready) {
+    EXPECT_EQ(plan_case.expected_packet_count, submission.kernel.packet_count);
+    EXPECT_EQ(
+        iree_hal_amdgpu_aql_ring_packet(
+            &queue->aql_ring, submission.kernel.first_packet_id +
+                                  plan_case.expected_pm4_ib_packet_offset),
+        submission.pm4_ib_packet_slot);
+    EXPECT_EQ(&queue->pm4_ib_slots[(submission.kernel.first_packet_id +
+                                    plan_case.expected_pm4_ib_packet_offset) &
+                                   queue->aql_ring.mask],
+              submission.pm4_ib_slot);
+
+    iree_hal_amdgpu_host_queue_cancel_profile_queue_device_events(
+        queue, submission.profile_queue_device_events);
+    iree_hal_amdgpu_host_queue_fail_kernel_submission(queue,
+                                                      &submission.kernel);
+  }
+  iree_slim_mutex_unlock(&queue->locks.submission_mutex);
+
+  IREE_EXPECT_OK(status);
+  EXPECT_TRUE(is_ready);
+}
+
+static bool HostQueueSupportsQueueDeviceProfiling(
+    const iree_hal_amdgpu_host_queue_t* queue) {
+  return iree_hal_amdgpu_vendor_packet_capabilities_support_timestamp_range(
+      queue->vendor_packet_capabilities);
+}
+
+TEST_F(HostQueueSubmissionTest, DispatchPacketAccountingCombinations) {
+  iree_hal_amdgpu_logical_device_options_t options;
+  iree_hal_amdgpu_logical_device_options_initialize(&options);
+
+  TestLogicalDevice test_device;
+  IREE_ASSERT_OK(
+      test_device.Initialize(&options, &libhsa_, &topology_, host_allocator_));
+  iree_hal_amdgpu_host_queue_t* queue = test_device.first_host_queue();
+  ASSERT_NE(queue, nullptr);
+
+  HostQueueHsaProfilingScope profiling_scope(queue);
+  IREE_ASSERT_OK(profiling_scope.Enable());
+
+  const DispatchSubmissionPlanCase cases[] = {
+      {
+          /*barrier_count=*/0,
+          /*reserve_dispatch_event=*/false,
+          /*reserve_queue_device_event=*/false,
+          /*expected_packet_count=*/1,
+          /*expected_dispatch_packet_offset=*/0,
+          /*expected_harvest_packet_offset=*/kNoHarvestPacketOffset,
+          /*expect_dispatch_completion_signal=*/true,
+      },
+      {
+          /*barrier_count=*/2,
+          /*reserve_dispatch_event=*/false,
+          /*reserve_queue_device_event=*/false,
+          /*expected_packet_count=*/3,
+          /*expected_dispatch_packet_offset=*/2,
+          /*expected_harvest_packet_offset=*/kNoHarvestPacketOffset,
+          /*expect_dispatch_completion_signal=*/true,
+      },
+      {
+          /*barrier_count=*/0,
+          /*reserve_dispatch_event=*/false,
+          /*reserve_queue_device_event=*/true,
+          /*expected_packet_count=*/3,
+          /*expected_dispatch_packet_offset=*/1,
+          /*expected_harvest_packet_offset=*/kNoHarvestPacketOffset,
+          /*expect_dispatch_completion_signal=*/false,
+      },
+      {
+          /*barrier_count=*/0,
+          /*reserve_dispatch_event=*/true,
+          /*reserve_queue_device_event=*/false,
+          /*expected_packet_count=*/2,
+          /*expected_dispatch_packet_offset=*/0,
+          /*expected_harvest_packet_offset=*/1,
+          /*expect_dispatch_completion_signal=*/true,
+      },
+      {
+          /*barrier_count=*/2,
+          /*reserve_dispatch_event=*/true,
+          /*reserve_queue_device_event=*/false,
+          /*expected_packet_count=*/4,
+          /*expected_dispatch_packet_offset=*/2,
+          /*expected_harvest_packet_offset=*/3,
+          /*expect_dispatch_completion_signal=*/true,
+      },
+      {
+          /*barrier_count=*/0,
+          /*reserve_dispatch_event=*/true,
+          /*reserve_queue_device_event=*/true,
+          /*expected_packet_count=*/4,
+          /*expected_dispatch_packet_offset=*/1,
+          /*expected_harvest_packet_offset=*/2,
+          /*expect_dispatch_completion_signal=*/true,
+      },
+  };
+  for (const DispatchSubmissionPlanCase& plan_case : cases) {
+    if (plan_case.reserve_queue_device_event &&
+        !HostQueueSupportsQueueDeviceProfiling(queue)) {
+      // Skip only the cases that need queue-device profiling so that later
+      // cases that do not need it still run on unsupported hardware.
+      continue;
+    }
+    ExpectDispatchSubmissionPlan(queue, plan_case);
+  }
+}
+
+TEST_F(HostQueueSubmissionTest, Pm4IbPacketAccountingCombinations) {
+  iree_hal_amdgpu_logical_device_options_t options;
+  iree_hal_amdgpu_logical_device_options_initialize(&options);
+
+  TestLogicalDevice test_device;
+  IREE_ASSERT_OK(
+      test_device.Initialize(&options, &libhsa_, &topology_, host_allocator_));
+  iree_hal_amdgpu_host_queue_t* queue = test_device.first_host_queue();
+  ASSERT_NE(queue, nullptr);
+  if (!queue->pm4_ib_slots) {
+    GTEST_SKIP() << "PM4 IB slots are not available";
+  }
+
+  HostQueueHsaProfilingScope profiling_scope(queue);
+  IREE_ASSERT_OK(profiling_scope.Enable());
+
+  const Pm4IbSubmissionPlanCase cases[] = {
+      {
+          /*barrier_count=*/0,
+          /*reserve_queue_device_event=*/false,
+          /*expected_packet_count=*/1,
+          /*expected_pm4_ib_packet_offset=*/0,
+      },
+      {
+          /*barrier_count=*/2,
+          /*reserve_queue_device_event=*/false,
+          /*expected_packet_count=*/3,
+          /*expected_pm4_ib_packet_offset=*/2,
+      },
+      {
+          /*barrier_count=*/0,
+          /*reserve_queue_device_event=*/true,
+          /*expected_packet_count=*/3,
+          /*expected_pm4_ib_packet_offset=*/1,
+      },
+  };
+  for (const Pm4IbSubmissionPlanCase& plan_case : cases) {
+    if (plan_case.reserve_queue_device_event &&
+        !HostQueueSupportsQueueDeviceProfiling(queue)) {
+      // Skip only the cases that need queue-device profiling so that later
+      // cases that do not need it still run on unsupported hardware.
+      continue;
+    }
+    ExpectPm4IbSubmissionPlan(queue, plan_case);
+  }
+}
+
+}  // namespace
+}  // namespace iree::hal::amdgpu
diff --git a/runtime/src/iree/hal/drivers/amdgpu/host_queue_timestamp.c b/runtime/src/iree/hal/drivers/amdgpu/host_queue_timestamp.c
new file mode 100644
index 0000000..cb262f7
--- /dev/null
+++ b/runtime/src/iree/hal/drivers/amdgpu/host_queue_timestamp.c
@@ -0,0 +1,59 @@
+// Copyright 2026 The IREE Authors
+//
+// Licensed under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+
+#include "iree/hal/drivers/amdgpu/host_queue_timestamp.h"
+
+#include "iree/hal/drivers/amdgpu/util/aql_ring.h"
+#include "iree/hal/drivers/amdgpu/util/pm4_emitter.h"
+
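+// Maps a monotonically increasing AQL ring |packet_id| to its queue-owned
+// PM4 IB slot. The ring capacity is a power of two, so masking with
+// aql_ring.mask selects the backing slot index.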
+static iree_hal_amdgpu_pm4_ib_slot_t* iree_hal_amdgpu_host_queue_pm4_ib_slot(
+    iree_hal_amdgpu_host_queue_t* queue, uint64_t packet_id) {
+  return &queue->pm4_ib_slots[packet_id & queue->aql_ring.mask];
+}
+
+void iree_hal_amdgpu_host_queue_commit_timestamp_start(
+    iree_hal_amdgpu_host_queue_t* queue, uint64_t packet_id,
+    iree_hal_amdgpu_aql_packet_control_t packet_control, uint64_t* start_tick) {
+  iree_hal_amdgpu_aql_packet_t* packet =
+      iree_hal_amdgpu_aql_ring_packet(&queue->aql_ring, packet_id);
+  iree_hal_amdgpu_pm4_ib_slot_t* pm4_ib_slot =
+      iree_hal_amdgpu_host_queue_pm4_ib_slot(queue, packet_id);
+  uint16_t setup = 0;
+  const uint16_t header = iree_hal_amdgpu_aql_emit_timestamp_start(
+      &packet->pm4_ib, pm4_ib_slot, packet_control, start_tick, &setup);
+  iree_hal_amdgpu_aql_ring_commit(packet, header, setup);
+}
+
+void iree_hal_amdgpu_host_queue_commit_timestamp_end(
+    iree_hal_amdgpu_host_queue_t* queue, uint64_t packet_id,
+    iree_hal_amdgpu_aql_packet_control_t packet_control,
+    iree_hsa_signal_t completion_signal, uint64_t* end_tick) {
+  iree_hal_amdgpu_aql_packet_t* packet =
+      iree_hal_amdgpu_aql_ring_packet(&queue->aql_ring, packet_id);
+  iree_hal_amdgpu_pm4_ib_slot_t* pm4_ib_slot =
+      iree_hal_amdgpu_host_queue_pm4_ib_slot(queue, packet_id);
+  uint16_t setup = 0;
+  const uint16_t header = iree_hal_amdgpu_aql_emit_timestamp_end(
+      &packet->pm4_ib, pm4_ib_slot, packet_control, completion_signal, end_tick,
+      &setup);
+  iree_hal_amdgpu_aql_ring_commit(packet, header, setup);
+}
+
+void iree_hal_amdgpu_host_queue_commit_timestamp_range(
+    iree_hal_amdgpu_host_queue_t* queue, uint64_t packet_id,
+    iree_hal_amdgpu_aql_packet_control_t packet_control,
+    iree_hsa_signal_t completion_signal, uint64_t* start_tick,
+    uint64_t* end_tick) {
+  iree_hal_amdgpu_aql_packet_t* packet =
+      iree_hal_amdgpu_aql_ring_packet(&queue->aql_ring, packet_id);
+  iree_hal_amdgpu_pm4_ib_slot_t* pm4_ib_slot =
+      iree_hal_amdgpu_host_queue_pm4_ib_slot(queue, packet_id);
+  uint16_t setup = 0;
+  const uint16_t header = iree_hal_amdgpu_aql_emit_timestamp_range(
+      &packet->pm4_ib, pm4_ib_slot, packet_control, completion_signal,
+      start_tick, end_tick, &setup);
+  iree_hal_amdgpu_aql_ring_commit(packet, header, setup);
+}
diff --git a/runtime/src/iree/hal/drivers/amdgpu/host_queue_timestamp.h b/runtime/src/iree/hal/drivers/amdgpu/host_queue_timestamp.h
new file mode 100644
index 0000000..6280460
--- /dev/null
+++ b/runtime/src/iree/hal/drivers/amdgpu/host_queue_timestamp.h
@@ -0,0 +1,40 @@
+// Copyright 2026 The IREE Authors
+//
+// Licensed under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+
+#ifndef IREE_HAL_DRIVERS_AMDGPU_HOST_QUEUE_TIMESTAMP_H_
+#define IREE_HAL_DRIVERS_AMDGPU_HOST_QUEUE_TIMESTAMP_H_
+
+#include "iree/hal/drivers/amdgpu/host_queue.h"
+#include "iree/hal/drivers/amdgpu/util/aql_emitter.h"
+
+#ifdef __cplusplus
+extern "C" {
+#endif  // __cplusplus
+
+// Publishes a PM4 timestamp packet that writes a top-of-pipe tick.
+void iree_hal_amdgpu_host_queue_commit_timestamp_start(
+    iree_hal_amdgpu_host_queue_t* queue, uint64_t packet_id,
+    iree_hal_amdgpu_aql_packet_control_t packet_control, uint64_t* start_tick);
+
+// Publishes a PM4 timestamp packet that writes a bottom-of-pipe tick.
+void iree_hal_amdgpu_host_queue_commit_timestamp_end(
+    iree_hal_amdgpu_host_queue_t* queue, uint64_t packet_id,
+    iree_hal_amdgpu_aql_packet_control_t packet_control,
+    iree_hsa_signal_t completion_signal, uint64_t* end_tick);
+
+// Publishes one PM4 timestamp packet that writes both top-of-pipe and
+// bottom-of-pipe ticks.
+void iree_hal_amdgpu_host_queue_commit_timestamp_range(
+    iree_hal_amdgpu_host_queue_t* queue, uint64_t packet_id,
+    iree_hal_amdgpu_aql_packet_control_t packet_control,
+    iree_hsa_signal_t completion_signal, uint64_t* start_tick,
+    uint64_t* end_tick);
+
+#ifdef __cplusplus
+}  // extern "C"
+#endif  // __cplusplus
+
+#endif  // IREE_HAL_DRIVERS_AMDGPU_HOST_QUEUE_TIMESTAMP_H_
diff --git a/runtime/src/iree/hal/drivers/amdgpu/host_queue_waits.c b/runtime/src/iree/hal/drivers/amdgpu/host_queue_waits.c
new file mode 100644
index 0000000..5d0017a
--- /dev/null
+++ b/runtime/src/iree/hal/drivers/amdgpu/host_queue_waits.c
@@ -0,0 +1,389 @@
+// Copyright 2026 The IREE Authors
+//
+// Licensed under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+
+#include "iree/hal/drivers/amdgpu/host_queue_waits.h"
+
+#include <string.h>
+
+#include "iree/hal/drivers/amdgpu/host_queue_policy.h"
+#include "iree/hal/drivers/amdgpu/semaphore.h"
+#include "iree/hal/drivers/amdgpu/util/aql_emitter.h"
+#include "iree/hal/drivers/amdgpu/util/pm4_emitter.h"
+
+// Returns true if |frontier| contains |axis| at an epoch >= |target_epoch|.
+// Frontier entries are sorted by axis, so the scan can stop as soon as the
+// target axis has been passed.
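+// For example, a frontier {(axis 2, epoch 5), (axis 7, epoch 9)} dominates
+// (axis 2, epoch 4) but not (axis 2, epoch 6), and never an axis absent from
+// the frontier (such as axis 5).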
+static bool iree_hal_amdgpu_frontier_dominates_axis(
+    const iree_async_frontier_t* frontier, iree_async_axis_t axis,
+    uint64_t target_epoch) {
+  for (uint8_t i = 0; i < frontier->entry_count; ++i) {
+    const iree_async_frontier_entry_t* entry = &frontier->entries[i];
+    if (entry->axis < axis) continue;
+    return entry->axis == axis && entry->epoch >= target_epoch;
+  }
+  return false;
+}
+
+// Appends a tier-2 barrier to |resolution| while preserving ascending axis
+// order and deduplicating repeated axes across multiple waits. If the barrier
+// budget is exhausted, returns false so the caller can fall back to software
+// deferral instead of relying on a debug-only assert.
+//
+// Caller must hold submission_mutex.
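+//
+// For example, appending (axis 4, epoch 1) to the sorted set {(2, 6), (7, 9)}
+// inserts in the middle to give {(2, 6), (4, 1), (7, 9)}, while appending
+// (axis 2, epoch 8) raises the existing axis-2 entry to epoch 8 in place.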
+static bool iree_hal_amdgpu_host_queue_append_wait_barrier(
+    iree_hal_amdgpu_host_queue_t* queue,
+    iree_hal_amdgpu_wait_resolution_t* resolution, iree_async_axis_t axis,
+    hsa_signal_t epoch_signal, uint64_t target_epoch,
+    iree_hsa_fence_scope_t acquire_scope) {
+  if (queue->wait_barrier_strategy ==
+      IREE_HAL_AMDGPU_WAIT_BARRIER_STRATEGY_DEFER) {
+    return false;
+  }
+
+  resolution->barrier_acquire_scope =
+      iree_hal_amdgpu_host_queue_max_fence_scope(
+          resolution->barrier_acquire_scope, acquire_scope);
+
+  uint8_t insert_ordinal = 0;
+  while (insert_ordinal < resolution->barrier_count &&
+         resolution->barriers[insert_ordinal].axis < axis) {
+    ++insert_ordinal;
+  }
+
+  if (insert_ordinal < resolution->barrier_count &&
+      resolution->barriers[insert_ordinal].axis == axis) {
+    if (target_epoch > resolution->barriers[insert_ordinal].target_epoch) {
+      resolution->barriers[insert_ordinal].target_epoch = target_epoch;
+      resolution->barriers[insert_ordinal].epoch_signal = epoch_signal;
+    }
+    return true;
+  }
+
+  if (resolution->barrier_count >= IREE_HAL_AMDGPU_QUEUE_FRONTIER_CAPACITY) {
+    return false;
+  }
+
+  for (uint8_t i = resolution->barrier_count; i > insert_ordinal; --i) {
+    resolution->barriers[i] = resolution->barriers[i - 1];
+  }
+
+  iree_hal_amdgpu_wait_barrier_t* barrier =
+      &resolution->barriers[insert_ordinal];
+  barrier->axis = axis;
+  barrier->epoch_signal = epoch_signal;
+  barrier->target_epoch = target_epoch;
+  ++resolution->barrier_count;
+  return true;
+}
+
+// Resolves a single (semaphore, value) wait. Appends tier 2 barriers to
+// |resolution| if needed. Returns true if the wait is resolved (satisfied or
+// barriers appended), false if deferral is needed.
+//
+// Tier 0: timeline_value >= value -> already completed.
+// Tier 1a: signal submitted by this queue -> elide directly from last_signal
+//   (no semaphore-frontier mutex/copy); valid under the current all-barrier
+//   AQL submission policy.
+// Tier 1b: signal submitted by a producer epoch that exactly covers the
+//   semaphore frontier, and this queue already dominates that producer epoch
+//   -> elide directly from last_signal.
+// Tier 1c: signal submitted + queue frontier dominates -> no barrier needed.
+// Tier 2a: signal submitted by a local producer epoch that exactly covers the
+//   semaphore frontier -> append one barrier directly from last_signal.
+// Tier 2b: signal submitted + local queue axes from semaphore frontier ->
+//   barriers appended from the undominated frontier entries.
+// Tier 3: anything else -> deferral.
+//
+// Caller must hold submission_mutex.
+static bool iree_hal_amdgpu_host_queue_resolve_wait(
+    iree_hal_amdgpu_host_queue_t* queue, iree_hal_semaphore_t* semaphore,
+    uint64_t value, iree_hal_amdgpu_wait_resolution_t* resolution) {
+  iree_async_semaphore_t* async_semaphore = (iree_async_semaphore_t*)semaphore;
+  iree_hsa_fence_scope_t acquire_scope =
+      iree_hal_amdgpu_host_queue_wait_acquire_scope(queue, semaphore);
+
+  // A failed semaphore must take the software deferral path so the timepoint
+  // callback propagates the failure to this op's signal semaphores.
+  if (iree_async_semaphore_query_status(async_semaphore) != IREE_STATUS_OK) {
+    return false;
+  }
+
+  // Tier 0: already completed. Cheapest check (one atomic load).
+  uint64_t current_value = (uint64_t)iree_atomic_load(
+      &async_semaphore->timeline_value, iree_memory_order_acquire);
+  if (current_value >= value) {
+    resolution->inline_acquire_scope =
+        iree_hal_amdgpu_host_queue_max_fence_scope(
+            resolution->inline_acquire_scope, acquire_scope);
+    return true;
+  }
+
+  // Not completed. Must be an AMDGPU semaphore for device-side resolution.
+  if (!iree_hal_amdgpu_semaphore_isa(semaphore)) return false;
+
+  // Has the signal for |value| been submitted? The last_signal cache records
+  // the most recent signal's value. If it hasn't reached |value|, the signal
+  // hasn't been submitted yet (wait-before-signal) and the frontier does not
+  // reflect the signal's causal context - frontier dominance would be a
+  // false positive.
+  iree_hal_amdgpu_last_signal_flags_t signal_flags =
+      IREE_HAL_AMDGPU_LAST_SIGNAL_FLAG_NONE;
+  iree_async_axis_t signal_axis = 0;
+  uint64_t signal_epoch = 0;
+  uint64_t signal_value = 0;
+  if (!iree_hal_amdgpu_last_signal_load(
+          iree_hal_amdgpu_semaphore_last_signal(semaphore), &signal_flags,
+          &signal_axis, &signal_epoch, &signal_value) ||
+      signal_value < value) {
+    return false;
+  }
+
+  // Tier 1a: same-queue elision from the last_signal cache alone.
+  // HAL queues are not FIFO: user-visible order comes from semaphore edges.
+  // This shortcut is valid only because AMDGPU queue submissions represent
+  // inline waits on the first payload packet with AQL BARRIER set, so
+  // submission order under submission_mutex creates a single in-queue
+  // dependency chain. If that policy is relaxed for independent HIP streams,
+  // this branch must emit an explicit same-queue dependency edge instead of
+  // returning purely from producer axis identity.
+  if (signal_axis == queue->axis) {
+    resolution->inline_acquire_scope =
+        iree_hal_amdgpu_host_queue_max_fence_scope(
+            resolution->inline_acquire_scope, acquire_scope);
+    return true;
+  }
+
+  // Tier 1b/2a: when the semaphore cache says the producer queue's epoch
+  // exactly covers the unresolved semaphore frontier, resolve directly from
+  // that producer axis/epoch snapshot. This avoids the semaphore-frontier
+  // mutex/copy on common cross-queue handoffs while still refusing to guess
+  // on TP fan-in semaphores with independent producers.
+  if (signal_flags & IREE_HAL_AMDGPU_LAST_SIGNAL_FLAG_PRODUCER_FRONTIER_EXACT) {
+    if (iree_hal_amdgpu_frontier_dominates_axis(
+            iree_hal_amdgpu_host_queue_const_frontier(queue), signal_axis,
+            signal_epoch)) {
+      resolution->inline_acquire_scope =
+          iree_hal_amdgpu_host_queue_max_fence_scope(
+              resolution->inline_acquire_scope, acquire_scope);
+      return true;
+    }
+    hsa_signal_t peer_signal;
+    if (!iree_hal_amdgpu_epoch_signal_table_lookup(queue->epoch_table,
+                                                   signal_axis, &peer_signal)) {
+      return false;
+    }
+    return iree_hal_amdgpu_host_queue_append_wait_barrier(
+        queue, resolution, signal_axis, peer_signal, signal_epoch,
+        acquire_scope);
+  }
+
+  // Signal submitted, not completed. Copy the semaphore's frontier and find
+  // axes that our queue doesn't dominate.
+  iree_hal_amdgpu_fixed_frontier_t semaphore_frontier;
+  iree_async_semaphore_query_frontier(
+      async_semaphore,
+      iree_hal_amdgpu_fixed_frontier_as_frontier(&semaphore_frontier),
+      IREE_HAL_AMDGPU_QUEUE_FRONTIER_CAPACITY);
+
+  iree_async_frontier_entry_t
+      undominated[IREE_HAL_AMDGPU_QUEUE_FRONTIER_CAPACITY];
+  uint8_t undominated_count = iree_async_frontier_find_undominated(
+      iree_hal_amdgpu_host_queue_const_frontier(queue),
+      iree_hal_amdgpu_fixed_frontier_as_frontier(&semaphore_frontier),
+      IREE_HAL_AMDGPU_QUEUE_FRONTIER_CAPACITY, undominated);
+
+  // Tier 1c: all axes dominated -> no additional barrier needed.
+  if (undominated_count == 0) {
+    resolution->inline_acquire_scope =
+        iree_hal_amdgpu_host_queue_max_fence_scope(
+            resolution->inline_acquire_scope, acquire_scope);
+    return true;
+  }
+
+  // Tier 2b: look up each undominated axis in the epoch signal table.
+  // If any axis is not a local queue (remote, collective, host), defer.
+  for (uint8_t i = 0; i < undominated_count; ++i) {
+    hsa_signal_t peer_signal;
+    if (!iree_hal_amdgpu_epoch_signal_table_lookup(
+            queue->epoch_table, undominated[i].axis, &peer_signal)) {
+      return false;
+    }
+    if (!iree_hal_amdgpu_host_queue_append_wait_barrier(
+            queue, resolution, undominated[i].axis, peer_signal,
+            undominated[i].epoch, acquire_scope)) {
+      return false;
+    }
+  }
+  return true;
+}
+
+void iree_hal_amdgpu_host_queue_resolve_waits(
+    iree_hal_amdgpu_host_queue_t* queue,
+    const iree_hal_semaphore_list_t wait_semaphore_list,
+    iree_hal_amdgpu_wait_resolution_t* out_resolution) {
+  out_resolution->barrier_count = 0;
+  out_resolution->needs_deferral = false;
+  memset(out_resolution->reserved, 0, sizeof(out_resolution->reserved));
+  out_resolution->wait_count = wait_semaphore_list.count > UINT32_MAX
+                                   ? UINT32_MAX
+                                   : (uint32_t)wait_semaphore_list.count;
+  out_resolution->profile_event_flags = IREE_HAL_PROFILE_QUEUE_EVENT_FLAG_NONE;
+  out_resolution->inline_acquire_scope = IREE_HSA_FENCE_SCOPE_NONE;
+  out_resolution->barrier_acquire_scope = IREE_HSA_FENCE_SCOPE_NONE;
+
+  for (iree_host_size_t i = 0; i < wait_semaphore_list.count; ++i) {
+    if (!iree_hal_amdgpu_host_queue_resolve_wait(
+            queue, wait_semaphore_list.semaphores[i],
+            wait_semaphore_list.payload_values[i], out_resolution)) {
+      out_resolution->needs_deferral = true;
+      out_resolution->barrier_count = 0;
+      return;
+    }
+  }
+}
+
+uint16_t iree_hal_amdgpu_host_queue_write_wait_barrier_packet_body(
+    iree_hal_amdgpu_host_queue_t* queue,
+    const iree_hal_amdgpu_wait_barrier_t* barrier, uint64_t packet_id,
+    hsa_signal_t completion_signal, iree_hsa_fence_scope_t acquire_scope,
+    iree_hsa_fence_scope_t release_scope, iree_hal_amdgpu_aql_packet_t* packet,
+    uint16_t* out_setup) {
+  // The epoch signal starts at INITIAL_VALUE and is decremented by 1 per
+  // completion. A submission at one-based epoch N is complete when the queue's
+  // current epoch has reached N, so the barrier fires when:
+  //   signal_load(s) <= INITIAL_VALUE - target_epoch
+  // BARRIER_VALUE only supports LT, so encode <= as:
+  //   signal_load(s) < INITIAL_VALUE - target_epoch + 1
+  //
+  // Epochs are one-based by construction (see notification_ring_advance_epoch).
+  // Plugging target_epoch == 0 into the formula collapses to
+  // "signal < INITIAL_VALUE + 1", which is trivially true and would let a wait
+  // for "no submission yet" fire immediately. Reserving zero for that state
+  // keeps the wait formula safe without any special-case branches here.
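+  //
+  // Worked example: target_epoch == 3 yields the condition
+  //   signal < INITIAL_VALUE - 2.
+  // After the producer's third completion the signal has decremented to
+  // INITIAL_VALUE - 3 and the barrier releases; after only two completions it
+  // sits at INITIAL_VALUE - 2 and the barrier holds.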
+  iree_hsa_signal_value_t compare_value =
+      (iree_hsa_signal_value_t)(IREE_HAL_AMDGPU_EPOCH_INITIAL_VALUE -
+                                barrier->target_epoch + 1);
+
+  switch (queue->wait_barrier_strategy) {
+    case IREE_HAL_AMDGPU_WAIT_BARRIER_STRATEGY_AQL_BARRIER_VALUE:
+      return iree_hal_amdgpu_aql_emit_barrier_value(
+          &packet->barrier_value,
+          (iree_hsa_signal_t){.handle = barrier->epoch_signal.handle},
+          IREE_HSA_SIGNAL_CONDITION_LT, compare_value,
+          (iree_hsa_signal_value_t)INT64_MAX,
+          iree_hal_amdgpu_aql_packet_control_barrier(acquire_scope,
+                                                     release_scope),
+          completion_signal, out_setup);
+    case IREE_HAL_AMDGPU_WAIT_BARRIER_STRATEGY_PM4_WAIT_REG_MEM64: {
+      IREE_ASSERT(queue->pm4_ib_slots != NULL);
+      iree_hal_amdgpu_pm4_ib_slot_t* ib_slot =
+          &queue->pm4_ib_slots[packet_id & queue->aql_ring.mask];
+      uint32_t ib_dword_count = iree_hal_amdgpu_pm4_emit_wait_reg_mem64(
+          ib_slot, (iree_hsa_signal_t){.handle = barrier->epoch_signal.handle},
+          compare_value, (iree_hsa_signal_value_t)INT64_MAX);
+      return iree_hal_amdgpu_aql_emit_pm4_ib(
+          &packet->pm4_ib, ib_slot, ib_dword_count,
+          iree_hal_amdgpu_aql_packet_control_barrier(acquire_scope,
+                                                     release_scope),
+          completion_signal, out_setup);
+    }
+    case IREE_HAL_AMDGPU_WAIT_BARRIER_STRATEGY_DEFER:
+    default:
+      IREE_ASSERT(false,
+                  "resolved wait barriers require a device-side strategy");
+      *out_setup = 0;
+      return 0;
+  }
+}
+
+void iree_hal_amdgpu_host_queue_emit_barriers(
+    iree_hal_amdgpu_host_queue_t* queue,
+    const iree_hal_amdgpu_wait_resolution_t* resolution,
+    uint64_t first_packet_id) {
+  for (uint8_t i = 0; i < resolution->barrier_count; ++i) {
+    iree_hal_amdgpu_aql_packet_t* packet =
+        iree_hal_amdgpu_aql_ring_packet(&queue->aql_ring, first_packet_id + i);
+    uint16_t setup = 0;
+    uint16_t header = iree_hal_amdgpu_host_queue_write_wait_barrier_packet_body(
+        queue, &resolution->barriers[i], first_packet_id + i,
+        iree_hsa_signal_null(), resolution->barrier_acquire_scope,
+        IREE_HSA_FENCE_SCOPE_NONE, packet, &setup);
+    iree_hal_amdgpu_aql_ring_commit(packet, header, setup);
+  }
+}
+
+void iree_hal_amdgpu_host_queue_merge_barrier_axes(
+    iree_hal_amdgpu_host_queue_t* queue,
+    const iree_hal_amdgpu_wait_resolution_t* resolution) {
+  if (resolution->barrier_count == 0 || !queue->can_publish_frontier) return;
+  iree_hal_amdgpu_fixed_frontier_t barrier_frontier;
+  iree_async_frontier_initialize(
+      iree_hal_amdgpu_fixed_frontier_as_frontier(&barrier_frontier),
+      resolution->barrier_count);
+  for (uint8_t i = 0; i < resolution->barrier_count; ++i) {
+    barrier_frontier.entries[i].axis = resolution->barriers[i].axis;
+    barrier_frontier.entries[i].epoch = resolution->barriers[i].target_epoch;
+  }
+  if (!iree_async_frontier_merge(
+          iree_hal_amdgpu_host_queue_frontier(queue),
+          IREE_HAL_AMDGPU_QUEUE_FRONTIER_CAPACITY,
+          iree_hal_amdgpu_fixed_frontier_as_frontier(&barrier_frontier))) {
+    queue->can_publish_frontier = false;
+  }
+}
+
+const iree_async_frontier_t* iree_hal_amdgpu_host_queue_pool_requester_frontier(
+    iree_hal_amdgpu_host_queue_t* queue,
+    const iree_hal_amdgpu_wait_resolution_t* resolution,
+    iree_hal_amdgpu_fixed_frontier_t* storage) {
+  const iree_async_frontier_t* queue_frontier =
+      iree_hal_amdgpu_host_queue_const_frontier(queue);
+  if (resolution->barrier_count == 0) return queue_frontier;
+
+  memcpy(storage, &queue->frontier, sizeof(*storage));
+  iree_hal_amdgpu_fixed_frontier_t barrier_frontier;
+  iree_async_frontier_initialize(
+      iree_hal_amdgpu_fixed_frontier_as_frontier(&barrier_frontier),
+      resolution->barrier_count);
+  for (uint8_t i = 0; i < resolution->barrier_count; ++i) {
+    barrier_frontier.entries[i].axis = resolution->barriers[i].axis;
+    barrier_frontier.entries[i].epoch = resolution->barriers[i].target_epoch;
+  }
+  if (!iree_async_frontier_merge(
+          iree_hal_amdgpu_fixed_frontier_as_frontier(storage),
+          IREE_HAL_AMDGPU_QUEUE_FRONTIER_CAPACITY,
+          iree_hal_amdgpu_fixed_frontier_as_frontier(&barrier_frontier))) {
+    return queue_frontier;
+  }
+  return iree_hal_amdgpu_fixed_frontier_as_frontier(storage);
+}
+
+bool iree_hal_amdgpu_host_queue_append_pool_wait_frontier_barriers(
+    iree_hal_amdgpu_host_queue_t* queue,
+    const iree_async_frontier_t* requester_frontier,
+    const iree_async_frontier_t* wait_frontier,
+    iree_hal_amdgpu_wait_resolution_t* resolution) {
+  if (!wait_frontier) return false;
+  for (uint8_t i = 0; i < wait_frontier->entry_count; ++i) {
+    const iree_async_frontier_entry_t* entry = &wait_frontier->entries[i];
+    if (iree_hal_amdgpu_frontier_dominates_axis(requester_frontier, entry->axis,
+                                                entry->epoch)) {
+      continue;
+    }
+    hsa_signal_t peer_signal;
+    if (!iree_hal_amdgpu_epoch_signal_table_lookup(queue->epoch_table,
+                                                   entry->axis, &peer_signal)) {
+      return false;
+    }
+    if (!iree_hal_amdgpu_host_queue_append_wait_barrier(
+            queue, resolution, entry->axis, peer_signal, entry->epoch,
+            iree_hal_amdgpu_host_queue_axis_acquire_scope(queue,
+                                                          entry->axis))) {
+      return false;
+    }
+  }
+  return true;
+}
diff --git a/runtime/src/iree/hal/drivers/amdgpu/host_queue_waits.h b/runtime/src/iree/hal/drivers/amdgpu/host_queue_waits.h
new file mode 100644
index 0000000..746bd5b
--- /dev/null
+++ b/runtime/src/iree/hal/drivers/amdgpu/host_queue_waits.h
@@ -0,0 +1,113 @@
+// Copyright 2026 The IREE Authors
+//
+// Licensed under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+
+#ifndef IREE_HAL_DRIVERS_AMDGPU_HOST_QUEUE_WAITS_H_
+#define IREE_HAL_DRIVERS_AMDGPU_HOST_QUEUE_WAITS_H_
+
+#include "iree/hal/drivers/amdgpu/host_queue.h"
+
+#ifdef __cplusplus
+extern "C" {
+#endif  // __cplusplus
+
+// Stack-allocable frontier storage sized to the host queue frontier capacity.
+typedef iree_hal_amdgpu_host_queue_frontier_t iree_hal_amdgpu_fixed_frontier_t;
+
+static inline iree_async_frontier_t* iree_hal_amdgpu_fixed_frontier_as_frontier(
+    iree_hal_amdgpu_fixed_frontier_t* storage) {
+  return iree_async_fixed_frontier_as_frontier(storage);
+}
+
+//===----------------------------------------------------------------------===//
+// Wait resolution
+//===----------------------------------------------------------------------===//
+
+// A single device-side wait barrier emitted for each undominated local queue
+// axis that the current submission must wait on before its own packets run.
+typedef struct iree_hal_amdgpu_wait_barrier_t {
+  // Producer queue axis to wait on.
+  iree_async_axis_t axis;
+  // Producer queue epoch signal consumed by the barrier packet.
+  hsa_signal_t epoch_signal;
+  // Producer queue epoch that must be complete before the barrier releases.
+  uint64_t target_epoch;
+} iree_hal_amdgpu_wait_barrier_t;
+
+// Result of resolving a wait_semaphore_list. Either all waits are resolved
+// by the first |barrier_count| entries of |barriers|, or software deferral is
+// required.
+typedef struct iree_hal_amdgpu_wait_resolution_t {
+  // Number of valid device-side barriers in |barriers|.
+  uint8_t barrier_count;
+  // True if at least one wait requires software deferral.
+  bool needs_deferral;
+  // Padding reserved to keep the fence scopes aligned.
+  uint8_t reserved[2];
+  // Number of wait semaphore edges represented by this resolution.
+  uint32_t wait_count;
+  // Queue profiling flags describing how this resolution was reached.
+  iree_hal_profile_queue_event_flags_t profile_event_flags;
+  // Acquire scope required on the final operation packet for waits resolved
+  // without dedicated wait-barrier packets.
+  iree_hsa_fence_scope_t inline_acquire_scope;
+  // Acquire scope required on dedicated wait-barrier packets.
+  iree_hsa_fence_scope_t barrier_acquire_scope;
+  // Device-side wait barriers sorted by ascending producer axis.
+  iree_hal_amdgpu_wait_barrier_t
+      barriers[IREE_HAL_AMDGPU_QUEUE_FRONTIER_CAPACITY];
+} iree_hal_amdgpu_wait_resolution_t;
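+
+// For example (illustrative): resolving three waits where one has already
+// completed, one is dominated by the queue frontier, and one maps to a local
+// peer queue yields wait_count == 3, barrier_count == 1, and
+// needs_deferral == false; any wait that cannot be resolved instead yields
+// needs_deferral == true with barrier_count reset to 0.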
+
+// Resolves a wait_semaphore_list into device-side barriers or software
+// deferral. Caller must hold submission_mutex.
+void iree_hal_amdgpu_host_queue_resolve_waits(
+    iree_hal_amdgpu_host_queue_t* queue,
+    const iree_hal_semaphore_list_t wait_semaphore_list,
+    iree_hal_amdgpu_wait_resolution_t* out_resolution);
+
+// Writes one device-side wait barrier packet body and returns the header/setup
+// bits that will publish it. Caller must commit the packet header after this
+// returns. Caller must hold submission_mutex.
+uint16_t iree_hal_amdgpu_host_queue_write_wait_barrier_packet_body(
+    iree_hal_amdgpu_host_queue_t* queue,
+    const iree_hal_amdgpu_wait_barrier_t* barrier, uint64_t packet_id,
+    hsa_signal_t completion_signal, iree_hsa_fence_scope_t acquire_scope,
+    iree_hsa_fence_scope_t release_scope, iree_hal_amdgpu_aql_packet_t* packet,
+    uint16_t* out_setup);
+
+// Emits device-side wait barrier packets for a resolved wait list. Caller must
+// have reserved |resolution->barrier_count| consecutive AQL slots starting at
+// |first_packet_id|. Caller must hold submission_mutex.
+void iree_hal_amdgpu_host_queue_emit_barriers(
+    iree_hal_amdgpu_host_queue_t* queue,
+    const iree_hal_amdgpu_wait_resolution_t* resolution,
+    uint64_t first_packet_id);
+
+// Merges the resolved wait barrier axes into the queue's accumulated frontier
+// after successful submission publication. Caller must hold submission_mutex.
+void iree_hal_amdgpu_host_queue_merge_barrier_axes(
+    iree_hal_amdgpu_host_queue_t* queue,
+    const iree_hal_amdgpu_wait_resolution_t* resolution);
+
+// Returns the queue-order frontier to use for pool acquire_reservation() after
+// accounting for dependency barriers in |resolution|.
+const iree_async_frontier_t* iree_hal_amdgpu_host_queue_pool_requester_frontier(
+    iree_hal_amdgpu_host_queue_t* queue,
+    const iree_hal_amdgpu_wait_resolution_t* resolution,
+    iree_hal_amdgpu_fixed_frontier_t* storage);
+
+// Imports a pool-owned death frontier into the queue's AQL dependency list.
+// Entries already dominated by |requester_frontier| are skipped; remaining
+// local-queue axes become device-side wait barriers in |resolution|.
+bool iree_hal_amdgpu_host_queue_append_pool_wait_frontier_barriers(
+    iree_hal_amdgpu_host_queue_t* queue,
+    const iree_async_frontier_t* requester_frontier,
+    const iree_async_frontier_t* wait_frontier,
+    iree_hal_amdgpu_wait_resolution_t* resolution);
+
+#ifdef __cplusplus
+}  // extern "C"
+#endif  // __cplusplus
+
+#endif  // IREE_HAL_DRIVERS_AMDGPU_HOST_QUEUE_WAITS_H_
diff --git a/runtime/src/iree/hal/drivers/amdgpu/host_service.c b/runtime/src/iree/hal/drivers/amdgpu/host_service.c
deleted file mode 100644
index e393529..0000000
--- a/runtime/src/iree/hal/drivers/amdgpu/host_service.c
+++ /dev/null
@@ -1,620 +0,0 @@
-// Copyright 2025 The IREE Authors
-//
-// Licensed under the Apache License v2.0 with LLVM Exceptions.
-// See https://llvm.org/LICENSE.txt for license information.
-// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
-
-#include "iree/hal/drivers/amdgpu/host_service.h"
-
-#include "iree/hal/api.h"
-#include "iree/hal/drivers/amdgpu/device/host_client.h"
-
-// A signal payload that is unlikely to ever be hit during real execution.
-// This is used to indicate a default state that we want to detect changes on.
-#define IREE_HAL_AMDGPU_INVALID_SIGNAL_VALUE ((hsa_signal_value_t) - 2)
-
-static void iree_hal_amdgpu_host_service_fail(
-    iree_hal_amdgpu_host_service_t* service, iree_status_t status);
-
-//===----------------------------------------------------------------------===//
-// HSA_PACKET_TYPE_BARRIER_AND
-//===----------------------------------------------------------------------===//
-
-// Issues an `HSA_PACKET_TYPE_BARRIER_AND` on the host worker.
-// Signals completion when all dependencies are resolved.
-static iree_status_t iree_hal_amdgpu_host_service_issue_barrier_and(
-    iree_hal_amdgpu_host_service_t* service,
-    const iree_hal_amdgpu_libhsa_t* libhsa,
-    const hsa_barrier_and_packet_t* packet) {
-  IREE_TRACE_ZONE_BEGIN(z0);
-
-  // NOTE: this will wake if all signals ever passes through 0 - it's possible
-  // for it to be non-zero upon return if something else modifies it (but we
-  // should never be doing that).
-  //
-  // NOTE: hsa_amd_signal_wait_any has relaxed memory semantics and to have the
-  // proper acquire behavior we need to load the signal value ourselves.
-  static_assert(IREE_ARRAYSIZE(packet->dep_signal) == 5, "expecting 5 signals");
-  hsa_signal_condition_t conds[5] = {
-      HSA_SIGNAL_CONDITION_EQ, HSA_SIGNAL_CONDITION_EQ, HSA_SIGNAL_CONDITION_EQ,
-      HSA_SIGNAL_CONDITION_EQ, HSA_SIGNAL_CONDITION_EQ,
-  };
-  hsa_signal_value_t values[5] = {
-      0, 0, 0, 0, 0,
-  };
-  const uint32_t satisfying_index = iree_hsa_amd_signal_wait_all(
-      IREE_LIBHSA(libhsa), IREE_ARRAYSIZE(packet->dep_signal),
-      (hsa_signal_t*)packet->dep_signal, conds, values, UINT64_MAX,
-      HSA_WAIT_STATE_BLOCKED,
-      /*satisfying_values=*/NULL);
-  iree_status_t status = iree_ok_status();
-  if (IREE_LIKELY(satisfying_index != UINT32_MAX)) {
-    // NOTE: if all signals are null then wait_all will return index 0 as
-    // having satisfied the wait even though it's invalid. We assume here that
-    // there's at least one valid signal (otherwise why would the wait have
-    // been requested?) and load the first one found, using satisfying_index as
-    // a starting point so that in the common case it already points at a valid
-    // signal and the loop breaks immediately.
-    for (uint32_t i = satisfying_index; i < IREE_ARRAYSIZE(packet->dep_signal);
-         ++i) {
-      if (packet->dep_signal[i].handle) {
-        iree_hsa_signal_load_scacquire(IREE_LIBHSA(libhsa),
-                                       packet->dep_signal[i]);
-        break;
-      }
-    }
-  } else {
-    status = iree_make_status(IREE_STATUS_INTERNAL,
-                              "hsa_amd_signal_wait_all failed");
-  }
-
-  // TODO(benvanik): figure out the expected behavior of completion signals on
-  // failed barriers (it may be in the HSA spec). For now we don't signal and
-  // rely on a global device loss to alert the user.
-  if (iree_status_is_ok(status) && packet->completion_signal.handle != 0) {
-    iree_hsa_signal_subtract_screlease(IREE_LIBHSA(libhsa),
-                                       packet->completion_signal, 1);
-  }
-
-  IREE_TRACE_ZONE_END(z0);
-  return status;
-}
-
-//===----------------------------------------------------------------------===//
-// HSA_PACKET_TYPE_BARRIER_OR
-//===----------------------------------------------------------------------===//
-
-// Issues an `HSA_PACKET_TYPE_BARRIER_OR` on the host worker.
-// Signals completion when any dependency is resolved.
-static iree_status_t iree_hal_amdgpu_host_service_issue_barrier_or(
-    iree_hal_amdgpu_host_service_t* service,
-    const iree_hal_amdgpu_libhsa_t* libhsa,
-    const hsa_barrier_or_packet_t* packet) {
-  IREE_TRACE_ZONE_BEGIN(z0);
-
-  // NOTE: this will wake if the signal ever passes through 0 - it's possible
-  // for it to be non-zero upon return if something else modifies it (but we
-  // should never be doing that).
-  //
-  // NOTE: hsa_amd_signal_wait_any has relaxed memory semantics and to have the
-  // proper acquire behavior we need to load the signal value ourselves.
-  static_assert(IREE_ARRAYSIZE(packet->dep_signal) == 5, "expecting 5 signals");
-  hsa_signal_condition_t conds[5] = {
-      HSA_SIGNAL_CONDITION_EQ, HSA_SIGNAL_CONDITION_EQ, HSA_SIGNAL_CONDITION_EQ,
-      HSA_SIGNAL_CONDITION_EQ, HSA_SIGNAL_CONDITION_EQ,
-  };
-  hsa_signal_value_t values[5] = {
-      0, 0, 0, 0, 0,
-  };
-  hsa_signal_value_t satisfying_value = 0;
-  const uint32_t satisfying_index = iree_hsa_amd_signal_wait_any(
-      IREE_LIBHSA(libhsa), IREE_ARRAYSIZE(packet->dep_signal),
-      (hsa_signal_t*)packet->dep_signal, conds, values, UINT64_MAX,
-      HSA_WAIT_STATE_BLOCKED, &satisfying_value);
-  iree_status_t status = iree_ok_status();
-  if (IREE_LIKELY(satisfying_index != UINT32_MAX)) {
-    iree_hsa_signal_load_scacquire(IREE_LIBHSA(libhsa),
-                                   packet->dep_signal[satisfying_index]);
-  } else {
-    status = iree_make_status(IREE_STATUS_INTERNAL,
-                              "hsa_amd_signal_wait_any failed");
-  }
-
-  // TODO(benvanik): figure out the expected behavior of completion signals on
-  // failed barriers (it may be in the HSA spec). For now we don't signal and
-  // rely on a global device loss to alert the user.
-  if (iree_status_is_ok(status) && packet->completion_signal.handle != 0) {
-    iree_hsa_signal_subtract_screlease(IREE_LIBHSA(libhsa),
-                                       packet->completion_signal, 1);
-  }
-
-  IREE_TRACE_ZONE_END(z0);
-  return status;
-}
-
-//===----------------------------------------------------------------------===//
-// HSA_AMD_PACKET_TYPE_BARRIER_VALUE
-//===----------------------------------------------------------------------===//
-
-// Returns true if `|current_value| |condition| |desired_value|` is true.
-static bool iree_hsa_condition_is_met(hsa_signal_condition32_t condition,
-                                      hsa_signal_value_t current_value,
-                                      hsa_signal_value_t desired_value) {
-  switch (condition) {
-    default:
-    case HSA_SIGNAL_CONDITION_EQ:
-      return current_value == desired_value;
-    case HSA_SIGNAL_CONDITION_NE:
-      return current_value != desired_value;
-    case HSA_SIGNAL_CONDITION_LT:
-      return current_value < desired_value;
-    case HSA_SIGNAL_CONDITION_GTE:
-      return current_value >= desired_value;
-  }
-}
-
-// Issues an `HSA_AMD_PACKET_TYPE_BARRIER_VALUE` on the host worker.
-// Signals completion when the value condition is met.
-static iree_status_t iree_hal_amdgpu_host_service_issue_barrier_value(
-    iree_hal_amdgpu_host_service_t* service,
-    const iree_hal_amdgpu_libhsa_t* libhsa,
-    const hsa_amd_barrier_value_packet_t* packet) {
-  IREE_TRACE_ZONE_BEGIN(z0);
-
-  // TODO(benvanik): propagate failures according to spec?
-
-  // HSA signal wait doesn't take a mask so if the mask is used we have to
-  // emulate it. Hopefully most cases are not using a mask.
-  if (IREE_LIKELY(packet->mask == UINT64_MAX)) {
-    // NOTE: this will wake if the signal ever meets the condition - it's
-    // possible for it to be unsatisfied upon return if something else modifies
-    // it (but we should never be doing that).
-    iree_hsa_signal_wait_scacquire(IREE_LIBHSA(libhsa), packet->signal,
-                                   packet->cond, packet->value, UINT64_MAX,
-                                   HSA_WAIT_STATE_BLOCKED);
-  } else {
-    // Emulate a wait that takes a mask. This will wake each time the value
-    // changes until the condition is met.
-    IREE_TRACE_ZONE_APPEND_TEXT(z0, "emulated mask");
-    hsa_signal_value_t value =
-        iree_hsa_signal_load_scacquire(IREE_LIBHSA(libhsa), packet->signal);
-    while (!iree_hsa_condition_is_met(packet->cond, value & packet->mask,
-                                      packet->value)) {
-      value = iree_hsa_signal_wait_scacquire(
-          IREE_LIBHSA(libhsa), packet->signal, HSA_SIGNAL_CONDITION_NE, value,
-          UINT64_MAX, HSA_WAIT_STATE_BLOCKED);
-    }
-  }
-
-  if (packet->completion_signal.handle != 0) {
-    iree_hsa_signal_subtract_screlease(IREE_LIBHSA(libhsa),
-                                       packet->completion_signal, 1);
-  }
-
-  IREE_TRACE_ZONE_END(z0);
-  return iree_ok_status();
-}
-
-//===----------------------------------------------------------------------===//
-// HSA_PACKET_TYPE_AGENT_DISPATCH
-//===----------------------------------------------------------------------===//
-
-// IREE_HAL_AMDGPU_DEVICE_HOST_CALL_POST_SIGNAL
-static iree_status_t iree_hal_amdgpu_host_post_signal(
-    iree_hal_amdgpu_host_service_t* service, iree_hal_semaphore_t* semaphore,
-    uint64_t payload) {
-  IREE_TRACE_ZONE_BEGIN(z0);
-
-  // Notify the (likely) external semaphore of its new value. It may make
-  // platform calls or do other bookkeeping.
-  iree_status_t status =
-      iree_hal_semaphore_signal(semaphore, payload, /*frontier=*/NULL);
-
-  IREE_TRACE_ZONE_END(z0);
-  return status;
-}
-
-// IREE_HAL_AMDGPU_DEVICE_HOST_CALL_POST_RELEASE
-static iree_status_t iree_hal_amdgpu_host_post_release(
-    iree_hal_amdgpu_host_service_t* service, iree_host_size_t resource_count,
-    iree_hal_resource_t* resources[]) {
-  IREE_TRACE_ZONE_BEGIN(z0);
-
-  // Release each resource. Some entries may be NULL.
-  for (iree_host_size_t i = 0; i < resource_count; ++i) {
-    iree_hal_resource_release(resources[i]);
-  }
-
-  IREE_TRACE_ZONE_END(z0);
-  return iree_ok_status();
-}
-
-// Issues an `HSA_PACKET_TYPE_AGENT_DISPATCH` on the host service.
-// Signals completion if the requested operation completes synchronously;
-// otherwise completion may be signaled asynchronously after this function
-// returns.
-//
-// If the operation is asynchronous then service->outstanding_signal must be
-// incremented before returning and decremented when the operation completes,
-// along with the packet completion signal.
-static iree_status_t iree_hal_amdgpu_host_service_issue_agent_dispatch(
-    iree_hal_amdgpu_host_service_t* service,
-    const iree_hal_amdgpu_libhsa_t* libhsa,
-    const hsa_agent_dispatch_packet_t* packet) {
-  IREE_TRACE_ZONE_BEGIN(z0);
-
-  iree_status_t status = iree_ok_status();
-  switch (packet->type) {
-    case IREE_HAL_AMDGPU_DEVICE_HOST_CALL_POST_SIGNAL:
-      status = iree_hal_amdgpu_host_post_signal(
-          service, (iree_hal_semaphore_t*)packet->arg[0], packet->arg[1]);
-      break;
-    case IREE_HAL_AMDGPU_DEVICE_HOST_CALL_POST_RELEASE:
-      status = iree_hal_amdgpu_host_post_release(
-          service, IREE_ARRAYSIZE(packet->arg),
-          (iree_hal_resource_t**)packet->arg);
-      break;
-    default:
-      status = iree_make_status(IREE_STATUS_INVALID_ARGUMENT, "unknown type %u",
-                                packet->type);
-      break;
-  }
-
-  if (iree_status_is_ok(status) && packet->completion_signal.handle != 0) {
-    iree_hsa_signal_subtract_screlease(IREE_LIBHSA(libhsa),
-                                       packet->completion_signal, 1);
-  }
-
-  IREE_TRACE_ZONE_END(z0);
-  return status;
-}
-
-//===----------------------------------------------------------------------===//
-// iree_hal_amdgpu_host_service_t
-//===----------------------------------------------------------------------===//
-
-static iree_status_t iree_hal_amdgpu_host_service_barrier(
-    iree_hal_amdgpu_host_service_t* service,
-    const iree_hal_amdgpu_libhsa_t* libhsa) {
-  IREE_TRACE_ZONE_BEGIN(z0);
-
-  // Wait until all asynchronous operations complete (signal reaches <= 0).
-  // Note that this may fail and return a non-zero value for various reasons.
-  hsa_signal_value_t value = iree_hsa_signal_wait_scacquire(
-      IREE_LIBHSA(libhsa), service->outstanding_signal, HSA_SIGNAL_CONDITION_LT,
-      1ull, UINT64_MAX, HSA_WAIT_STATE_BLOCKED);
-  IREE_TRACE_ZONE_APPEND_VALUE_I64(z0, value);
-
-  // TODO(benvanik): a way to get errors back? we could reserve the high bit and
-  // make all values with it set interpret as a status. A high bit set would
-  // make the hsa_signal_value_t negative and allow the wait to be satisfied.
-  iree_status_t status =
-      value == 0ull ? iree_ok_status()
-                    : iree_make_status(
-                          IREE_STATUS_ABORTED,
-                          "asynchronous work failed and the queue is invalid");
-
-  IREE_TRACE_ZONE_END(z0);
-  return status;
-}
-
-static int iree_hal_amdgpu_host_service_main(void* entry_arg) {
-  iree_hal_amdgpu_host_service_t* service =
-      (iree_hal_amdgpu_host_service_t*)entry_arg;
-  const iree_hal_amdgpu_libhsa_t* libhsa = service->libhsa;
-
-  // Main loop.
-  const uint64_t queue_mask = service->queue->size - 1;
-  uint64_t last_packet_id = IREE_HAL_AMDGPU_INVALID_SIGNAL_VALUE;
-  uint64_t read_index = 0;
-  while (true) {
-    // Since we are in MULTI mode we just check that the packet ID changed but
-    // don't trust it as an indication of what we should process as it may be
-    // set out of order from multiple producers.
-    const uint64_t new_packet_id = (uint64_t)iree_hsa_signal_wait_scacquire(
-        IREE_LIBHSA(libhsa), service->doorbell, HSA_SIGNAL_CONDITION_NE,
-        last_packet_id, UINT64_MAX, HSA_WAIT_STATE_BLOCKED);
-    last_packet_id = new_packet_id;
-    if (new_packet_id == UINT64_MAX) {
-      // Exit signal.
-      break;
-    }
-
-    // Drain all packets. Note that this may block and we may get new packets
-    // enqueued while processing.
-    while (read_index != iree_hsa_queue_load_write_index_scacquire(
-                             IREE_LIBHSA(libhsa), service->queue)) {
-      IREE_TRACE_ZONE_BEGIN_NAMED(
-          z_packet, "iree_hal_amdgpu_host_service_process_packet");
-      IREE_TRACE_ZONE_APPEND_VALUE_I64(z_packet, read_index);
-
-      // Reference packet in queue memory. Note that we want to get it out of
-      // there ASAP to free up the space in the queue.
-      // NOTE: we cast to an agent packet here but don't yet know the type.
-      // We only use the struct to parse the header bits common to all packets.
-      hsa_agent_dispatch_packet_t* packet_ptr =
-          (hsa_agent_dispatch_packet_t*)service->queue->base_address +
-          (read_index & queue_mask);
-
-      // Spin until the packet is populated.
-      // In AQL it's valid to bump the write index prior to the packet header
-      // being updated and the queue must stall until it's no longer INVALID.
-      //
-      // Note that because this is expected to be cycles away we don't yield and
-      // risk an OS context switch. If we hit cases where the write index and
-      // packet header stores are split in time we'll want to do something
-      // smarter like a backoff.
-      uint32_t packet_type = HSA_PACKET_TYPE_INVALID;
-      do {
-        const uint32_t packet_header = iree_atomic_load(
-            (iree_atomic_uint32_t*)packet_ptr, iree_memory_order_acquire);
-        packet_type = (packet_header >> HSA_PACKET_HEADER_TYPE) &
-                      ((1 << HSA_PACKET_HEADER_WIDTH_TYPE) - 1);
-      } while (packet_type == HSA_PACKET_TYPE_INVALID);
-
-      // Copy the packet locally and swap the packet back to INVALID so that
-      // producers can overwrite it immediately. By bumping the read index
-      // producers will be able to reserve it if they are waiting for capacity
-      // to be available.
-      uint8_t packet_data[64];
-      memcpy(packet_data, packet_ptr, sizeof(packet_data));
-      iree_atomic_store((iree_atomic_uint32_t*)packet_ptr,
-                        (HSA_PACKET_TYPE_INVALID << HSA_PACKET_HEADER_TYPE),
-                        iree_memory_order_relaxed);
-      iree_hsa_queue_store_read_index_screlease(IREE_LIBHSA(libhsa),
-                                                service->queue, ++read_index);
-
-      // If the packet has a barrier bit set then we need to block until all
-      // prior queue operations have completed. Most of our operations are
-      // synchronous but it's possible to have async operations outstanding and
-      // we need to wait for them.
-      const uint16_t packet_header = *(const uint16_t*)packet_data;
-      if (packet_header & (1u << HSA_PACKET_HEADER_BARRIER)) {
-        iree_status_t barrier_status =
-            iree_hal_amdgpu_host_service_barrier(service, libhsa);
-        if (!iree_status_is_ok(barrier_status)) {
-          IREE_TRACE_ZONE_APPEND_TEXT(
-              z_packet,
-              iree_status_code_string(iree_status_code(barrier_status)));
-          iree_hal_amdgpu_host_service_fail(service, barrier_status);
-          break;
-        }
-      }
-
-      // Switch on packet type and issue.
-      iree_status_t status = iree_ok_status();
-      switch (packet_type) {
-        case HSA_PACKET_TYPE_BARRIER_AND:
-          status = iree_hal_amdgpu_host_service_issue_barrier_and(
-              service, libhsa, (const hsa_barrier_and_packet_t*)packet_data);
-          break;
-        case HSA_PACKET_TYPE_BARRIER_OR:
-          status = iree_hal_amdgpu_host_service_issue_barrier_or(
-              service, libhsa, (const hsa_barrier_or_packet_t*)packet_data);
-          break;
-        case HSA_PACKET_TYPE_VENDOR_SPECIFIC: {
-          const hsa_amd_vendor_packet_header_t vendor_header =
-              *(const hsa_amd_vendor_packet_header_t*)packet_data;
-          switch (vendor_header.AmdFormat) {
-            case HSA_AMD_PACKET_TYPE_BARRIER_VALUE:
-              status = iree_hal_amdgpu_host_service_issue_barrier_value(
-                  service, libhsa,
-                  (const hsa_amd_barrier_value_packet_t*)packet_data);
-              break;
-            default:
-              status = iree_make_status(IREE_STATUS_INTERNAL,
-                                        "invalid vendor packet type %u",
-                                        vendor_header.AmdFormat);
-              break;
-          }
-          break;
-        }
-        case HSA_PACKET_TYPE_AGENT_DISPATCH:
-          status = iree_hal_amdgpu_host_service_issue_agent_dispatch(
-              service, libhsa, (const hsa_agent_dispatch_packet_t*)packet_data);
-          break;
-        default:
-          status = iree_make_status(IREE_STATUS_INTERNAL,
-                                    "invalid packet type %u", packet_type);
-          break;
-      }
-      if (!iree_status_is_ok(status)) {
-        IREE_TRACE_ZONE_APPEND_TEXT(
-            z_packet, iree_status_code_string(iree_status_code(status)));
-        iree_hal_amdgpu_host_service_fail(service, status);
-      }
-
-      IREE_TRACE_ZONE_END(z_packet);
-    }
-  }
-
-  // Wait for any outstanding asynchronous operations to complete.
-  // This ensures that we don't free memory that may be in use by them.
-  // Note that only this worker is allowed to wait on the signal so we have to
-  // do it here.
-  IREE_IGNORE_ERROR(iree_hal_amdgpu_host_service_barrier(service, libhsa));
-
-  return 0;
-}
-
-iree_status_t iree_hal_amdgpu_host_service_initialize(
-    const iree_hal_amdgpu_libhsa_t* libhsa, iree_host_size_t host_ordinal,
-    hsa_agent_t host_agent, hsa_region_t host_fine_region,
-    iree_host_size_t device_ordinal,
-    iree_hal_amdgpu_error_callback_t error_callback,
-    iree_allocator_t host_allocator,
-    iree_hal_amdgpu_host_service_t* out_service) {
-  IREE_ASSERT_ARGUMENT(out_service);
-  IREE_TRACE_ZONE_BEGIN(z0);
-
-  memset(out_service, 0, sizeof(*out_service));
-  out_service->libhsa = libhsa;
-  out_service->error_callback = error_callback;
-  out_service->failure_code = IREE_ATOMIC_VAR_INIT(0);
-
-  // NUMA node.
-  uint32_t host_agent_node = 0;
-  IREE_RETURN_AND_END_ZONE_IF_ERROR(
-      z0, iree_hsa_agent_get_info(IREE_LIBHSA(libhsa), host_agent,
-                                  HSA_AGENT_INFO_NODE, &host_agent_node));
-
-  // Pin the thread to the NUMA node specified.
-  // We don't care which core but do want it to be one of those associated with
-  // the devices this worker is servicing.
-  iree_thread_affinity_t thread_affinity = {0};
-  iree_thread_affinity_set_group_any(host_agent_node, &thread_affinity);
-
-  // Create a semaphore for tracking outstanding asynchronous operations. It's
-  // marked as only being "consumed" by the host agent (waited on). Other agents
-  // and threads can signal it.
-  iree_status_t status = iree_hsa_amd_signal_create(
-      IREE_LIBHSA(libhsa), 0ull, 1, &host_agent,
-      /*attributes=*/0, &out_service->outstanding_signal);
-
-  // Create the doorbell for the soft queue. It's marked as only being
-  // "consumed" by the host agent (waited on). Other agents can signal it.
-  if (iree_status_is_ok(status)) {
-    status = iree_hsa_amd_signal_create(
-        IREE_LIBHSA(libhsa), IREE_HAL_AMDGPU_INVALID_SIGNAL_VALUE, 1,
-        &host_agent,
-        /*attributes=*/0, &out_service->doorbell);
-  }
-
-  // Create and allocate the soft queue.
-  // The queue size cannot be changed after creation so we pick something large
-  // enough to reasonably satisfy all requests. Must be a power of two. We
-  // allocate the queue in the host pool closest to the devices that will
-  // produce for it and where this service will consume it.
-  if (iree_status_is_ok(status)) {
-    status = iree_hsa_soft_queue_create(
-        IREE_LIBHSA(libhsa), host_fine_region,
-        IREE_HAL_AMDGPU_HOST_SERVICE_QUEUE_CAPACITY, HSA_QUEUE_TYPE_MULTI,
-        HSA_QUEUE_FEATURE_AGENT_DISPATCH, out_service->doorbell,
-        &out_service->queue);
-  }
-
-  // Create the worker thread for handling device library requests.
-  // The worker may start immediately and use the queue/doorbell.
-  if (iree_status_is_ok(status)) {
-    char thread_name[32];
-    iree_snprintf(thread_name, IREE_ARRAYSIZE(thread_name),
-                  "iree-amdgpu-host-%" PRIhsz "-%" PRIhsz, host_ordinal,
-                  device_ordinal);
-    const iree_thread_create_params_t thread_params = {
-        .name = iree_make_cstring_view(thread_name),
-        .stack_size = 0,  // default
-        .create_suspended = false,
-        .priority_class = IREE_THREAD_PRIORITY_CLASS_HIGH,
-        .initial_affinity = thread_affinity,
-    };
-    status =
-        iree_thread_create(iree_hal_amdgpu_host_service_main, out_service,
-                           thread_params, host_allocator, &out_service->thread);
-  }
-
-  if (!iree_status_is_ok(status)) {
-    iree_hal_amdgpu_host_service_deinitialize(out_service);
-  }
-  IREE_TRACE_ZONE_END(z0);
-  return status;
-}
-
-void iree_hal_amdgpu_host_service_deinitialize(
-    iree_hal_amdgpu_host_service_t* service) {
-  IREE_ASSERT_ARGUMENT(service);
-  IREE_TRACE_ZONE_BEGIN(z0);
-
-  const iree_hal_amdgpu_libhsa_t* libhsa = service->libhsa;
-
-  // Mark the queue as inactive. This is likely a no-op for our soft queue from
-  // the API perspective but can help tooling see our intent.
-  if (service->queue) {
-    IREE_IGNORE_ERROR(
-        iree_hsa_queue_inactivate(IREE_LIBHSA(libhsa), service->queue));
-  }
-
-  // Signal the doorbell to the termination value. This should wake the worker
-  // and have it exit.
-  if (service->doorbell.handle) {
-    iree_hsa_signal_store_screlease(IREE_LIBHSA(libhsa), service->doorbell,
-                                    UINT64_MAX);
-  }
-
-  // Join thread after it has shut down.
-  if (service->thread) {
-    iree_thread_join(service->thread);
-    iree_thread_release(service->thread);
-    service->thread = NULL;
-  }
-
-  // Tear down HSA resources.
-  if (service->queue) {
-    IREE_IGNORE_ERROR(
-        iree_hsa_queue_destroy(IREE_LIBHSA(libhsa), service->queue));
-  }
-  if (service->doorbell.handle) {
-    IREE_IGNORE_ERROR(
-        iree_hsa_signal_destroy(IREE_LIBHSA(libhsa), service->doorbell));
-  }
-  if (service->outstanding_signal.handle) {
-    IREE_IGNORE_ERROR(iree_hsa_signal_destroy(IREE_LIBHSA(libhsa),
-                                              service->outstanding_signal));
-  }
-
-  IREE_TRACE_ZONE_END(z0);
-}
-
-static void iree_hal_amdgpu_host_service_fail(
-    iree_hal_amdgpu_host_service_t* service, iree_status_t status) {
-  IREE_ASSERT_ARGUMENT(service);
-  IREE_ASSERT(!iree_status_is_ok(status));
-  IREE_TRACE_ZONE_BEGIN(z0);
-  IREE_TRACE_ZONE_APPEND_TEXT(
-      z0, iree_status_code_string(iree_status_code(status)));
-
-  // Try to set our local status - we only preserve the first failure so only
-  // do this if we are going from a valid status to a failed one.
-  uint64_t old_status_code = 0;
-  uint64_t new_status_code = (uint64_t)iree_status_code(status);
-  const bool first_failure = iree_atomic_compare_exchange_strong(
-      &service->failure_code, &old_status_code, new_status_code,
-      iree_memory_order_acq_rel,
-      iree_memory_order_relaxed /* old_status is unused */);
-  if (first_failure && service->error_callback.fn) {
-    // Notify user-provided function; ownership of the status is transferred to
-    // the callee.
-    service->error_callback.fn(service->error_callback.user_data, status);
-  } else {
-    // No callback registered or a callback was already issued; drop the error.
-    IREE_IGNORE_ERROR(status);
-  }
-
-  // Force the worker to exit (soon).
-  if (service->doorbell.handle) {
-    iree_hsa_signal_store_screlease(IREE_LIBHSA(service->libhsa),
-                                    service->doorbell, UINT64_MAX);
-  }
-
-  IREE_TRACE_ZONE_END(z0);
-}
-
-void iree_hal_amdgpu_host_service_notify_completion(
-    iree_hal_amdgpu_host_async_token_t token, iree_status_t status) {
-  IREE_ASSERT_ARGUMENT(token.service);
-  IREE_TRACE_ZONE_BEGIN(z0);
-
-  // TODO(benvanik): we could track queue depth here or use the token to track
-  // the async operation across threads.
-
-  // If the async operation failed we need to set the sticky failure status.
-  if (IREE_UNLIKELY(!iree_status_is_ok(status))) {
-    // Enter the failure state; ownership of the status transferred.
-    iree_hal_amdgpu_host_service_fail(token.service, status);
-  }
-
-  iree_hal_amdgpu_host_service_t* service = token.service;
-  iree_hsa_signal_subtract_screlease(IREE_LIBHSA(service->libhsa),
-                                     service->outstanding_signal, 1);
-
-  IREE_TRACE_ZONE_END(z0);
-}
diff --git a/runtime/src/iree/hal/drivers/amdgpu/host_service.h b/runtime/src/iree/hal/drivers/amdgpu/host_service.h
deleted file mode 100644
index 5fa1461..0000000
--- a/runtime/src/iree/hal/drivers/amdgpu/host_service.h
+++ /dev/null
@@ -1,109 +0,0 @@
-// Copyright 2025 The IREE Authors
-//
-// Licensed under the Apache License v2.0 with LLVM Exceptions.
-// See https://llvm.org/LICENSE.txt for license information.
-// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
-
-#ifndef IREE_HAL_DRIVERS_AMDGPU_HOST_SERVICE_H_
-#define IREE_HAL_DRIVERS_AMDGPU_HOST_SERVICE_H_
-
-#include "iree/base/api.h"
-#include "iree/base/internal/atomics.h"
-#include "iree/base/threading/thread.h"
-#include "iree/hal/drivers/amdgpu/util/error_callback.h"
-#include "iree/hal/drivers/amdgpu/util/libhsa.h"
-
-#ifdef __cplusplus
-extern "C" {
-#endif  // __cplusplus
-
-//===----------------------------------------------------------------------===//
-// iree_hal_amdgpu_host_service_t
-//===----------------------------------------------------------------------===//
-
-// Capacity in entries of the host service queue.
-// This should not need to be too large as each queue is only able to have ~64
-// outstanding operations but each operation may require several host calls.
-// Effectively this is the maximum queue count sharing a single service * the
-// maximum concurrency on that queue * the maximum pipeline depth of any
-// pipeline that may use host services.
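-// As an illustrative (hypothetical) sizing: 16 queues sharing one service *
-// 16 concurrent operations per queue * 4 host calls per pipeline = 1024
-// entries, the default below.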
-#define IREE_HAL_AMDGPU_HOST_SERVICE_QUEUE_CAPACITY (1 * 1024)
-
-// A host service managing requests from one or more device queues.
-// Multiple physical devices or queues on a single physical device can share the
-// same host service. Multiple service workers can be used to reduce latency on
-// high core-count systems or to locate workers closer to the devices they
-// manage on NUMA systems.
-//
-// Thread-safe.
-typedef struct iree_hal_amdgpu_host_service_t {
-  // HSA library handle. Unowned.
-  const iree_hal_amdgpu_libhsa_t* libhsa;
-
-  // Optional callback issued when the failure status is first set.
-  iree_hal_amdgpu_error_callback_t error_callback;
-
-  // If the service has received a fatal error from the device it will be stored
-  // here as a status code to prevent duplicate error callbacks.
-  iree_atomic_uint64_t failure_code;
-
-  // OS handle to the worker thread.
-  iree_thread_t* thread;
-
-  // A semaphore signal used to indicate the number of outstanding asynchronous
-  // operations. 0 in the idle state, incremented for each new asynchronous
-  // operation, and decremented when an asynchronous operation completes.
-  // Used to implement the packet barrier bit.
-  hsa_signal_t outstanding_signal;
-
-  // HSA soft queue for incoming requests from devices.
-  hsa_queue_t* queue;
-  // HSA doorbell indicating the queue has been updated.
-  hsa_signal_t doorbell;
-} iree_hal_amdgpu_host_service_t;
-
-// Initializes the service state and launches the worker thread.
-// |libhsa| must remain valid for the lifetime of the service.
-//
-// The |host_ordinal| and |device_ordinal| are used for naming the service
-// worker thread. The worker thread will be pinned to the CPU |host_agent|
-// affinity and have its underlying HSA queue allocated from |host_fine_region|.
-//
-// An optional |error_callback| can be provided to receive notification of the
-// service entering the failure state. The callback may be issued from driver
-// threads and must not re-enter the host service API or make any stateful HSA
-// calls.
-//
-// TODO(benvanik): change device_ordinal to some other disambiguator if we
-// decide to share host workers across devices. Today it is only used for
-// thread/trace naming.
-iree_status_t iree_hal_amdgpu_host_service_initialize(
-    const iree_hal_amdgpu_libhsa_t* libhsa, iree_host_size_t host_ordinal,
-    hsa_agent_t host_agent, hsa_region_t host_fine_region,
-    iree_host_size_t device_ordinal,
-    iree_hal_amdgpu_error_callback_t error_callback,
-    iree_allocator_t host_allocator,
-    iree_hal_amdgpu_host_service_t* out_service);
-
-// Deinitializes the service and terminates the worker thread.
-void iree_hal_amdgpu_host_service_deinitialize(
-    iree_hal_amdgpu_host_service_t* service);
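-
-// A minimal lifecycle sketch (illustrative only; |libhsa| and the agent/region
-// values come from topology queries as in the tests):
-//   iree_hal_amdgpu_host_service_t service;
-//   IREE_RETURN_IF_ERROR(iree_hal_amdgpu_host_service_initialize(
-//       libhsa, /*host_ordinal=*/0, host_agent, host_fine_region,
-//       /*device_ordinal=*/0, error_callback, host_allocator, &service));
-//   // ... device agents enqueue packets and ring the doorbell ...
-//   iree_hal_amdgpu_host_service_deinitialize(&service);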
-
-// An asynchronous host operation token.
-typedef struct iree_hal_amdgpu_host_async_token_t {
-  iree_hal_amdgpu_host_service_t* service;
-} iree_hal_amdgpu_host_async_token_t;
-
-// Notifies the service from which |async_token| originated that an
-// asynchronous operation has completed. May be called from any thread. Must
-// only be called once per asynchronous operation. The worker may be
-// deallocated immediately after it exits so this should almost always be a
-// tail call. If the operation failed a |status| can be provided and will be
-// consumed by the call.
-void iree_hal_amdgpu_host_service_notify_completion(
-    iree_hal_amdgpu_host_async_token_t async_token, iree_status_t status);
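-
-// A sketch of asynchronous completion (illustrative; the callback name and
-// plumbing are hypothetical):
-//   static void my_async_op_complete(iree_hal_amdgpu_host_async_token_t token,
-//                                    iree_status_t status) {
-//     // ... release any operation-local resources ...
-//     // Must be last: the service may be torn down once this returns.
-//     iree_hal_amdgpu_host_service_notify_completion(token, status);
-//   }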
-
-#ifdef __cplusplus
-}  // extern "C"
-#endif  // __cplusplus
-
-#endif  // IREE_HAL_DRIVERS_AMDGPU_HOST_SERVICE_H_
diff --git a/runtime/src/iree/hal/drivers/amdgpu/host_service_test.cc b/runtime/src/iree/hal/drivers/amdgpu/host_service_test.cc
deleted file mode 100644
index 1f3d67d..0000000
--- a/runtime/src/iree/hal/drivers/amdgpu/host_service_test.cc
+++ /dev/null
@@ -1,508 +0,0 @@
-// Copyright 2025 The IREE Authors
-//
-// Licensed under the Apache License v2.0 with LLVM Exceptions.
-// See https://llvm.org/LICENSE.txt for license information.
-// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
-
-#include "iree/hal/drivers/amdgpu/host_service.h"
-
-#include <thread>
-
-#include "iree/base/api.h"
-#include "iree/hal/api.h"
-#include "iree/hal/drivers/amdgpu/device/host_client.h"
-#include "iree/hal/drivers/amdgpu/util/topology.h"
-#include "iree/testing/gtest.h"
-#include "iree/testing/status_matchers.h"
-
-namespace iree::hal::amdgpu {
-namespace {
-
-using iree::testing::status::StatusIs;
-
-// Returns an error callback that will store the failing status into
-// |status_bind| when invoked.
-static iree_hal_amdgpu_error_callback_t MakeErrorCallback(
-    iree_atomic_intptr_t* status_bind) {
-  iree_hal_amdgpu_error_callback_t callback;
-  callback.fn = +[](void* user_data, iree_status_t status) {
-    IREE_TRACE_SCOPE();
-    iree_atomic_store((iree_atomic_intptr_t*)user_data, (intptr_t)status,
-                      iree_memory_order_seq_cst);
-  };
-  callback.user_data = (void*)status_bind;
-  return callback;
-}
-
-// Returns the first fine-grained global region of the |host_agent|.
-static hsa_region_t GetHostGlobalFineRegion(
-    const iree_hal_amdgpu_libhsa_t* libhsa, hsa_agent_t host_agent) {
-  IREE_TRACE_SCOPE();
-  typedef struct iree_hal_amdgpu_hsa_region_list_t {
-    iree_host_size_t count;
-    hsa_region_t values[32];
-  } iree_hal_amdgpu_hsa_region_list_t;
-  iree_hal_amdgpu_hsa_region_list_t all_regions = {
-      .count = 0,
-  };
-  IREE_CHECK_OK(iree_hsa_agent_iterate_regions(
-      IREE_LIBHSA(libhsa), host_agent,
-      +[](hsa_region_t region, void* user_data) -> hsa_status_t {
-        auto* pool_list = (iree_hal_amdgpu_hsa_region_list_t*)user_data;
-        if (pool_list->count >= IREE_ARRAYSIZE(pool_list->values)) {
-          return HSA_STATUS_ERROR_OUT_OF_RESOURCES;
-        }
-        pool_list->values[pool_list->count++] = region;
-        return HSA_STATUS_SUCCESS;
-      },
-      &all_regions));
-  for (iree_host_size_t i = 0; i < all_regions.count; ++i) {
-    hsa_region_t region = all_regions.values[i];
-    hsa_region_segment_t segment = (hsa_region_segment_t)0;
-    IREE_CHECK_OK(iree_hsa_region_get_info(IREE_LIBHSA(libhsa), region,
-                                           HSA_REGION_INFO_SEGMENT, &segment));
-    if (segment != HSA_REGION_SEGMENT_GLOBAL) continue;
-    bool alloc_allowed = false;
-    IREE_CHECK_OK(iree_hsa_region_get_info(
-        IREE_LIBHSA(libhsa), region, HSA_REGION_INFO_RUNTIME_ALLOC_ALLOWED,
-        &alloc_allowed));
-    if (!alloc_allowed) continue;
-    hsa_region_global_flag_t global_flag = (hsa_region_global_flag_t)0;
-    IREE_CHECK_OK(iree_hsa_region_get_info(IREE_LIBHSA(libhsa), region,
-                                           HSA_REGION_INFO_GLOBAL_FLAGS,
-                                           &global_flag));
-    if (global_flag & HSA_REGION_GLOBAL_FLAG_FINE_GRAINED) {
-      return region;
-    }
-  }
-  return {0};
-}
-
-struct HsaSignal {
-  HsaSignal(const iree_hal_amdgpu_libhsa_t* libhsa,
-            hsa_signal_value_t initial_value = 0)
-      : libhsa(libhsa) {
-    IREE_CHECK_OK(
-        iree_hsa_amd_signal_create(IREE_LIBHSA(libhsa), initial_value, 0, NULL,
-                                   (hsa_amd_signal_attribute_t)0, &value));
-  }
-  ~HsaSignal() {
-    if (value.handle) {
-      iree_hsa_signal_destroy(IREE_LIBHSA(libhsa), value);
-    }
-  }
-  operator hsa_signal_t() const noexcept { return value; }
-  const iree_hal_amdgpu_libhsa_t* libhsa = nullptr;
-  hsa_signal_t value = {0};
-};
-
-// Enqueues a HSA_PACKET_TYPE_BARRIER_AND or HSA_PACKET_TYPE_BARRIER_OR
-// packet to the service queue.
-void EnqueueBarrier(iree_hal_amdgpu_host_service_t* service,
-                    hsa_packet_type_t packet_type,
-                    std::array<hsa_signal_t, 5> dep_signals,
-                    hsa_signal_t completion_signal) {
-  IREE_TRACE_SCOPE();
-
-  const uint64_t packet_id = iree_hsa_queue_add_write_index_relaxed(
-      IREE_LIBHSA(service->libhsa), service->queue, 1u);
-  while (packet_id - iree_hsa_queue_load_read_index_scacquire(
-                         IREE_LIBHSA(service->libhsa), service->queue) >=
-         service->queue->size) {
-    iree_thread_yield();  // spinning
-  }
-  const uint64_t queue_mask = service->queue->size - 1;  // power of two
-  hsa_barrier_or_packet_t* packet =
-      (hsa_barrier_or_packet_t*)((uint8_t*)service->queue->base_address +
-                                 (packet_id & queue_mask) * 64);
-
-  packet->reserved1 = 0;
-  memcpy(&packet->dep_signal[0], &dep_signals[0], sizeof(packet->dep_signal));
-  packet->reserved2 = 0;
-  packet->completion_signal = completion_signal;
-
-  // NOTE: high uint16_t is reserved0.
-  uint32_t header = packet_type << HSA_PACKET_HEADER_TYPE;
-  header |= 1 << HSA_PACKET_HEADER_BARRIER;
-  header |= HSA_FENCE_SCOPE_SYSTEM << HSA_PACKET_HEADER_SCACQUIRE_FENCE_SCOPE;
-  header |= HSA_FENCE_SCOPE_SYSTEM << HSA_PACKET_HEADER_SCRELEASE_FENCE_SCOPE;
-  iree_atomic_store((iree_atomic_uint32_t*)packet, header,
-                    iree_memory_order_release);
-
-  iree_hsa_signal_store_relaxed(IREE_LIBHSA(service->libhsa), service->doorbell,
-                                packet_id);
-}
-
-// Enqueues a HSA_AMD_PACKET_TYPE_BARRIER_VALUE packet to the service queue.
-void EnqueueBarrierValue(iree_hal_amdgpu_host_service_t* service,
-                         hsa_signal_t dep_signal,
-                         hsa_signal_condition32_t condition,
-                         hsa_signal_value_t condition_value,
-                         hsa_signal_value_t condition_mask,
-                         hsa_signal_t completion_signal) {
-  IREE_TRACE_SCOPE();
-
-  const uint64_t packet_id = iree_hsa_queue_add_write_index_relaxed(
-      IREE_LIBHSA(service->libhsa), service->queue, 1u);
-  while (packet_id - iree_hsa_queue_load_read_index_scacquire(
-                         IREE_LIBHSA(service->libhsa), service->queue) >=
-         service->queue->size) {
-    iree_thread_yield();  // spinning
-  }
-  const uint64_t queue_mask = service->queue->size - 1;  // power of two
-  hsa_amd_barrier_value_packet_t* packet =
-      (hsa_amd_barrier_value_packet_t*)((uint8_t*)service->queue->base_address +
-                                        (packet_id & queue_mask) * 64);
-
-  packet->reserved0 = 0;
-  packet->signal = dep_signal;
-  packet->value = condition_value;
-  packet->mask = condition_mask;
-  packet->cond = condition;
-  packet->reserved1 = 0;
-  packet->reserved2 = 0;
-  packet->reserved3 = 0;
-  packet->completion_signal = completion_signal;
-
-  union {
-    hsa_amd_vendor_packet_header_t vendor_header;
-    uint32_t vendor_header_bits;
-  };
-  vendor_header.header = HSA_PACKET_TYPE_VENDOR_SPECIFIC
-                         << HSA_PACKET_HEADER_TYPE;
-  vendor_header.header |= 1 << HSA_PACKET_HEADER_BARRIER;
-  vendor_header.header |= HSA_FENCE_SCOPE_SYSTEM
-                          << HSA_PACKET_HEADER_SCACQUIRE_FENCE_SCOPE;
-  vendor_header.header |= HSA_FENCE_SCOPE_SYSTEM
-                          << HSA_PACKET_HEADER_SCRELEASE_FENCE_SCOPE;
-  vendor_header.AmdFormat = HSA_AMD_PACKET_TYPE_BARRIER_VALUE;
-  vendor_header.reserved = 0;
-  iree_atomic_store((iree_atomic_uint32_t*)packet, vendor_header_bits,
-                    iree_memory_order_release);
-
-  iree_hsa_signal_store_relaxed(IREE_LIBHSA(service->libhsa), service->doorbell,
-                                packet_id);
-}
-
-// Enqueues a unidirectional agent packet to the service queue.
-void EnqueuePost(iree_hal_amdgpu_host_service_t* service, uint16_t type,
-                 uint64_t return_address, uint64_t arg0, uint64_t arg1,
-                 uint64_t arg2, uint64_t arg3, hsa_signal_t completion_signal) {
-  IREE_TRACE_SCOPE();
-
-  const uint64_t packet_id = iree_hsa_queue_add_write_index_relaxed(
-      IREE_LIBHSA(service->libhsa), service->queue, 1u);
-  while (packet_id - iree_hsa_queue_load_read_index_scacquire(
-                         IREE_LIBHSA(service->libhsa), service->queue) >=
-         service->queue->size) {
-    iree_thread_yield();  // spinning
-  }
-  const uint64_t queue_mask = service->queue->size - 1;  // power of two
-  hsa_agent_dispatch_packet_t* agent_packet =
-      (hsa_agent_dispatch_packet_t*)((uint8_t*)service->queue->base_address +
-                                     (packet_id & queue_mask) * 64);
-
-  agent_packet->reserved0 = 0;
-  agent_packet->return_address = (void*)return_address;
-  agent_packet->arg[0] = arg0;
-  agent_packet->arg[1] = arg1;
-  agent_packet->arg[2] = arg2;
-  agent_packet->arg[3] = arg3;
-  agent_packet->reserved2 = 0;
-  agent_packet->completion_signal = completion_signal;
-
-  uint16_t header = HSA_PACKET_TYPE_AGENT_DISPATCH << HSA_PACKET_HEADER_TYPE;
-  header |= 1 << HSA_PACKET_HEADER_BARRIER;
-  header |= HSA_FENCE_SCOPE_SYSTEM << HSA_PACKET_HEADER_SCACQUIRE_FENCE_SCOPE;
-  header |= HSA_FENCE_SCOPE_SYSTEM << HSA_PACKET_HEADER_SCRELEASE_FENCE_SCOPE;
-  const uint32_t header_type = header | ((uint32_t)type << 16);
-  iree_atomic_store((iree_atomic_uint32_t*)agent_packet, header_type,
-                    iree_memory_order_release);
-
-  iree_hsa_signal_store_relaxed(IREE_LIBHSA(service->libhsa), service->doorbell,
-                                packet_id);
-}
-
-struct HostServiceTest : public ::testing::Test {
-  static iree_allocator_t host_allocator;
-  static iree_hal_amdgpu_libhsa_t libhsa;
-  static iree_hal_amdgpu_topology_t topology;
-  static hsa_region_t host_fine_region;
-
-  static void SetUpTestSuite() {
-    IREE_TRACE_SCOPE();
-    host_allocator = iree_allocator_system();
-    iree_status_t status = iree_hal_amdgpu_libhsa_initialize(
-        IREE_HAL_AMDGPU_LIBHSA_FLAG_NONE, iree_string_view_list_empty(),
-        host_allocator, &libhsa);
-    if (!iree_status_is_ok(status)) {
-      iree_status_fprint(stderr, status);
-      iree_status_ignore(status);
-      GTEST_SKIP() << "HSA not available, skipping tests";
-    }
-    IREE_ASSERT_OK(
-        iree_hal_amdgpu_topology_initialize_with_defaults(&libhsa, &topology));
-    if (topology.gpu_agent_count == 0) {
-      GTEST_SKIP() << "no GPU devices available, skipping tests";
-    }
-    host_fine_region = GetHostGlobalFineRegion(&libhsa, topology.cpu_agents[0]);
-  }
-
-  static void TearDownTestSuite() {
-    IREE_TRACE_SCOPE();
-    iree_hal_amdgpu_topology_deinitialize(&topology);
-    iree_hal_amdgpu_libhsa_deinitialize(&libhsa);
-  }
-};
-iree_allocator_t HostServiceTest::host_allocator;
-iree_hal_amdgpu_libhsa_t HostServiceTest::libhsa;
-iree_hal_amdgpu_topology_t HostServiceTest::topology;
-hsa_region_t HostServiceTest::host_fine_region;
-
-// Tests that the host service can be initialized/deinitialized immediately.
-TEST_F(HostServiceTest, Lifetime) {
-  IREE_TRACE_SCOPE();
-
-  iree_atomic_intptr_t service_status = IREE_ATOMIC_VAR_INIT(0);
-  iree_hal_amdgpu_host_service_t service = {0};
-  IREE_ASSERT_OK(iree_hal_amdgpu_host_service_initialize(
-      &libhsa, /*host_ordinal=*/0, topology.cpu_agents[0], host_fine_region,
-      /*device_ordinal=*/0, MakeErrorCallback(&service_status), host_allocator,
-      &service));
-
-  iree_hal_amdgpu_host_service_deinitialize(&service);
-  IREE_EXPECT_OK((iree_status_t)iree_atomic_load(&service_status,
-                                                 iree_memory_order_seq_cst));
-}
-
-// Tests handling of the HSA_PACKET_TYPE_BARRIER_AND packet type.
-TEST_F(HostServiceTest, BarrierAnd) {
-  IREE_TRACE_SCOPE();
-
-  iree_atomic_intptr_t service_status = IREE_ATOMIC_VAR_INIT(0);
-  iree_hal_amdgpu_host_service_t service = {0};
-  IREE_ASSERT_OK(iree_hal_amdgpu_host_service_initialize(
-      &libhsa, /*host_ordinal=*/0, topology.cpu_agents[0], host_fine_region,
-      /*device_ordinal=*/0, MakeErrorCallback(&service_status), host_allocator,
-      &service));
-
-  HsaSignal signal0(&libhsa, /*initial_value=*/1);
-  HsaSignal signal1(&libhsa, /*initial_value=*/1);
-
-  HsaSignal completion_signal(&libhsa, /*initial_value=*/1);
-  EnqueueBarrier(&service, HSA_PACKET_TYPE_BARRIER_AND, {signal0, signal1},
-                 completion_signal);
-
-  std::thread thread0([&]() {
-    std::this_thread::sleep_for(std::chrono::milliseconds(25));
-    iree_hsa_signal_subtract_screlease(IREE_LIBHSA(&libhsa), signal0, 1);
-  });
-  std::thread thread1([&]() {
-    std::this_thread::sleep_for(std::chrono::milliseconds(25));
-    iree_hsa_signal_subtract_screlease(IREE_LIBHSA(&libhsa), signal1, 1);
-  });
-
-  EXPECT_EQ(
-      0, iree_hsa_signal_wait_scacquire(IREE_LIBHSA(&libhsa), completion_signal,
-                                        HSA_SIGNAL_CONDITION_EQ, 0, UINT64_MAX,
-                                        HSA_WAIT_STATE_BLOCKED));
-
-  thread0.join();
-  thread1.join();
-
-  iree_hal_amdgpu_host_service_deinitialize(&service);
-  IREE_EXPECT_OK((iree_status_t)iree_atomic_load(&service_status,
-                                                 iree_memory_order_seq_cst));
-}
-
-// Tests handling of the HSA_PACKET_TYPE_BARRIER_OR packet type.
-TEST_F(HostServiceTest, BarrierOr) {
-  IREE_TRACE_SCOPE();
-
-  iree_atomic_intptr_t service_status = IREE_ATOMIC_VAR_INIT(0);
-  iree_hal_amdgpu_host_service_t service = {0};
-  IREE_ASSERT_OK(iree_hal_amdgpu_host_service_initialize(
-      &libhsa, /*host_ordinal=*/0, topology.cpu_agents[0], host_fine_region,
-      /*device_ordinal=*/0, MakeErrorCallback(&service_status), host_allocator,
-      &service));
-
-  HsaSignal signal0(&libhsa, /*initial_value=*/1);
-  HsaSignal signal1(&libhsa, /*initial_value=*/1);
-
-  HsaSignal completion_signal(&libhsa, /*initial_value=*/1);
-  EnqueueBarrier(&service, HSA_PACKET_TYPE_BARRIER_OR, {signal0, signal1},
-                 completion_signal);
-
-  // NOTE: we only resolve one signal; the other is never resolved so we can
-  // test the OR behavior.
-  std::thread thread0([&]() {
-    std::this_thread::sleep_for(std::chrono::milliseconds(25));
-    iree_hsa_signal_subtract_screlease(IREE_LIBHSA(&libhsa), signal0, 1);
-  });
-
-  EXPECT_EQ(
-      0, iree_hsa_signal_wait_scacquire(IREE_LIBHSA(&libhsa), completion_signal,
-                                        HSA_SIGNAL_CONDITION_EQ, 0, UINT64_MAX,
-                                        HSA_WAIT_STATE_BLOCKED));
-
-  thread0.join();
-
-  iree_hal_amdgpu_host_service_deinitialize(&service);
-  IREE_EXPECT_OK((iree_status_t)iree_atomic_load(&service_status,
-                                                 iree_memory_order_seq_cst));
-}
-
-// Tests handling of the HSA_AMD_PACKET_TYPE_BARRIER_VALUE packet type.
-TEST_F(HostServiceTest, BarrierValue) {
-  IREE_TRACE_SCOPE();
-
-  iree_atomic_intptr_t service_status = IREE_ATOMIC_VAR_INIT(0);
-  iree_hal_amdgpu_host_service_t service = {0};
-  IREE_ASSERT_OK(iree_hal_amdgpu_host_service_initialize(
-      &libhsa, /*host_ordinal=*/0, topology.cpu_agents[0], host_fine_region,
-      /*device_ordinal=*/0, MakeErrorCallback(&service_status), host_allocator,
-      &service));
-
-  HsaSignal signal(&libhsa, /*initial_value=*/0);
-
-  HsaSignal completion_signal(&libhsa, /*initial_value=*/1);
-  EnqueueBarrierValue(&service, signal, HSA_SIGNAL_CONDITION_GTE,
-                      /*condition_value=*/10, /*condition_mask=*/UINT64_MAX,
-                      completion_signal);
-
-  std::thread thread([&]() {
-    std::this_thread::sleep_for(std::chrono::milliseconds(25));
-    iree_hsa_signal_add_screlease(IREE_LIBHSA(&libhsa), signal, 11);
-  });
-
-  EXPECT_EQ(
-      0, iree_hsa_signal_wait_scacquire(IREE_LIBHSA(&libhsa), completion_signal,
-                                        HSA_SIGNAL_CONDITION_EQ, 0, UINT64_MAX,
-                                        HSA_WAIT_STATE_BLOCKED));
-
-  thread.join();
-
-  iree_hal_amdgpu_host_service_deinitialize(&service);
-  IREE_EXPECT_OK((iree_status_t)iree_atomic_load(&service_status,
-                                                 iree_memory_order_seq_cst));
-}
-
-// Tests that service worker errors are propagated to the error callback.
-TEST_F(HostServiceTest, FailureState) {
-  IREE_TRACE_SCOPE();
-
-  iree_atomic_intptr_t service_status = IREE_ATOMIC_VAR_INIT(0);
-  iree_hal_amdgpu_host_service_t service = {0};
-  IREE_ASSERT_OK(iree_hal_amdgpu_host_service_initialize(
-      &libhsa, /*host_ordinal=*/0, topology.cpu_agents[0], host_fine_region,
-      /*device_ordinal=*/0, MakeErrorCallback(&service_status), host_allocator,
-      &service));
-
-  HsaSignal completion_signal(&libhsa, /*initial_value=*/1);
-  EnqueuePost(&service, /*type=*/UINT16_MAX, 0, 0, 0, 0, 0, completion_signal);
-
-  // NOTE: it's not required that the service ever signal completion - it is
-  // allowed to stop processing packets immediately upon failure if it desires.
-  // We delay a bit to give the worker time to process the packet and hope that
-  // it reaches its error state. Real usage of the service should never touch
-  // the internal data structures.
-  for (int i = 0; i < 1000; ++i) {
-    if (iree_atomic_load(&service_status, iree_memory_order_seq_cst) != 0) {
-      break;
-    }
-    std::this_thread::sleep_for(std::chrono::milliseconds(5));
-  }
-
-  iree_hal_amdgpu_host_service_deinitialize(&service);
-  EXPECT_THAT((iree_status_t)iree_atomic_load(&service_status,
-                                              iree_memory_order_seq_cst),
-              StatusIs(StatusCode::kInvalidArgument));
-}
-
-typedef struct iree_hal_test_resource_t {
-  iree_hal_resource_t resource;
-  iree_allocator_t host_allocator;
-  iree_atomic_uint32_t* live_count;
-} iree_hal_test_resource_t;
-typedef struct iree_hal_test_resource_vtable_t {
-  void(IREE_API_PTR* destroy)(iree_hal_test_resource_t* resource);
-} iree_hal_test_resource_vtable_t;
-IREE_HAL_ASSERT_VTABLE_LAYOUT(iree_hal_test_resource_vtable_t);
-static const iree_hal_test_resource_vtable_t iree_hal_test_resource_vtable = {
-    /*.destroy=*/+[](iree_hal_test_resource_t* resource) {
-      iree_hal_test_resource_t* test_resource =
-          (iree_hal_test_resource_t*)resource;
-      iree_allocator_t host_allocator = test_resource->host_allocator;
-      iree_atomic_fetch_sub(test_resource->live_count, 1u,
-                            iree_memory_order_seq_cst);
-      iree_allocator_free(host_allocator, test_resource);
-    },
-};
-static iree_status_t iree_hal_test_resource_create(
-    iree_atomic_uint32_t* live_count, iree_allocator_t host_allocator,
-    iree_hal_resource_t** out_resource) {
-  iree_hal_test_resource_t* test_resource = NULL;
-  IREE_RETURN_IF_ERROR(iree_allocator_malloc(
-      host_allocator, sizeof(*test_resource), (void**)&test_resource));
-  iree_hal_resource_initialize(&iree_hal_test_resource_vtable,
-                               &test_resource->resource);
-  test_resource->host_allocator = host_allocator;
-  test_resource->live_count = live_count;
-  iree_atomic_fetch_add(test_resource->live_count, 1u,
-                        iree_memory_order_seq_cst);
-  *out_resource = (iree_hal_resource_t*)test_resource;
-  return iree_ok_status();
-}
-
-// Tests IREE_HAL_AMDGPU_DEVICE_HOST_CALL_POST_RELEASE agent dispatch.
-// This is primarily a leak test as the only thing that will release the
-// resources is the service.
-TEST_F(HostServiceTest, PostRelease) {
-  IREE_TRACE_SCOPE();
-
-  iree_atomic_intptr_t service_status = IREE_ATOMIC_VAR_INIT(0);
-  iree_hal_amdgpu_host_service_t service = {0};
-  IREE_ASSERT_OK(iree_hal_amdgpu_host_service_initialize(
-      &libhsa, /*host_ordinal=*/0, topology.cpu_agents[0], host_fine_region,
-      /*device_ordinal=*/0, MakeErrorCallback(&service_status), host_allocator,
-      &service));
-
-  // Create some resources to test with. Each will +1 the live_count.
-  iree_atomic_uint32_t live_count = IREE_ATOMIC_VAR_INIT(0);
-  iree_hal_resource_t* resources[5] = {NULL};
-  for (iree_host_size_t i = 0; i < IREE_ARRAYSIZE(resources); ++i) {
-    IREE_ASSERT_OK(iree_hal_test_resource_create(&live_count, host_allocator,
-                                                 &resources[i]));
-  }
-
-  // Release all resources in two batches.
-  // Batch 0: only resource[0] in arg2. This tests that NULL values are ignored.
-  // Batch 1: remaining 4 resources[1]/[2]/[3]/[4].
-  HsaSignal completion_signal(&libhsa, /*initial_value=*/2);
-  EnqueuePost(&service, IREE_HAL_AMDGPU_DEVICE_HOST_CALL_POST_RELEASE,
-              /*return_address=*/0ull, /*arg0=*/0ull,
-              /*arg1=*/0ull, /*arg2=*/(uint64_t)resources[0], /*arg3=*/0ull,
-              completion_signal);
-  EnqueuePost(&service, IREE_HAL_AMDGPU_DEVICE_HOST_CALL_POST_RELEASE,
-              /*return_address=*/0ull, /*arg0=*/(uint64_t)resources[1],
-              /*arg1=*/(uint64_t)resources[2], /*arg2=*/(uint64_t)resources[3],
-              /*arg3=*/(uint64_t)resources[4], completion_signal);
-
-  // Await releases to complete.
-  EXPECT_EQ(
-      0, iree_hsa_signal_wait_scacquire(IREE_LIBHSA(&libhsa), completion_signal,
-                                        HSA_SIGNAL_CONDITION_EQ, 0, UINT64_MAX,
-                                        HSA_WAIT_STATE_BLOCKED));
-
-  // All resources should have been released.
-  EXPECT_EQ(0, iree_atomic_load(&live_count, iree_memory_order_seq_cst));
-
-  iree_hal_amdgpu_host_service_deinitialize(&service);
-  IREE_EXPECT_OK((iree_status_t)iree_atomic_load(&service_status,
-                                                 iree_memory_order_seq_cst));
-}
-
-// TODO(benvanik): async iree_hal_amdgpu_host_service_notify_completion when
-// there is a command using it. Today all are unidirectional post-only.
-
-}  // namespace
-}  // namespace iree::hal::amdgpu
diff --git a/runtime/src/iree/hal/drivers/amdgpu/logical_device.c b/runtime/src/iree/hal/drivers/amdgpu/logical_device.c
index c1d4be8..5accac0 100644
--- a/runtime/src/iree/hal/drivers/amdgpu/logical_device.c
+++ b/runtime/src/iree/hal/drivers/amdgpu/logical_device.c
@@ -9,38 +9,75 @@
 #include "iree/async/frontier.h"
 #include "iree/async/frontier_tracker.h"
 #include "iree/async/util/proactor_pool.h"
+#include "iree/hal/drivers/amdgpu/abi/signal.h"
 #include "iree/hal/drivers/amdgpu/allocator.h"
 #include "iree/hal/drivers/amdgpu/api.h"
-#include "iree/hal/drivers/amdgpu/channel.h"
-#include "iree/hal/drivers/amdgpu/command_buffer.h"
-#include "iree/hal/drivers/amdgpu/event.h"
+#include "iree/hal/drivers/amdgpu/aql_command_buffer.h"
+#include "iree/hal/drivers/amdgpu/aql_program_builder.h"
 #include "iree/hal/drivers/amdgpu/executable.h"
 #include "iree/hal/drivers/amdgpu/executable_cache.h"
+#include "iree/hal/drivers/amdgpu/host_queue_profile.h"
+#include "iree/hal/drivers/amdgpu/host_queue_profile_events.h"
 #include "iree/hal/drivers/amdgpu/physical_device.h"
+#include "iree/hal/drivers/amdgpu/profile_counters.h"
+#include "iree/hal/drivers/amdgpu/profile_device_metrics.h"
+#include "iree/hal/drivers/amdgpu/profile_traces.h"
+#include "iree/hal/drivers/amdgpu/queue_affinity.h"
 #include "iree/hal/drivers/amdgpu/semaphore.h"
 #include "iree/hal/drivers/amdgpu/system.h"
-#include "iree/hal/drivers/amdgpu/util/affinity.h"
+#include "iree/hal/drivers/amdgpu/util/epoch_signal_table.h"
+#include "iree/hal/drivers/amdgpu/util/kfd.h"
+#include "iree/hal/drivers/amdgpu/util/notification_ring.h"
 #include "iree/hal/drivers/amdgpu/util/topology.h"
-#include "iree/hal/drivers/amdgpu/virtual_queue.h"
+#include "iree/hal/drivers/amdgpu/util/vmem.h"
 #include "iree/hal/utils/file_registry.h"
-#include "iree/hal/utils/file_transfer.h"
 
 //===----------------------------------------------------------------------===//
 // Utilities
 //===----------------------------------------------------------------------===//
 
-static iree_hal_amdgpu_device_affinity_t
-iree_hal_amdgpu_device_affinity_from_queue_affinity(
-    iree_hal_queue_affinity_t queue_affinity,
-    iree_host_size_t per_device_queue_count) {
-  iree_hal_amdgpu_device_affinity_t device_affinity = 0;
-  IREE_HAL_FOR_QUEUE_AFFINITY(queue_affinity) {
-    const int physical_device_ordinal = queue_ordinal / per_device_queue_count;
-    iree_hal_amdgpu_device_affinity_or_into(device_affinity,
-                                            1ull << physical_device_ordinal);
-  }
-  return device_affinity;
+static iree_hal_amdgpu_queue_affinity_domain_t
+iree_hal_amdgpu_logical_device_queue_affinity_domain(
+    const iree_hal_amdgpu_logical_device_t* logical_device) {
+  return (iree_hal_amdgpu_queue_affinity_domain_t){
+      .supported_affinity = logical_device->queue_affinity_mask,
+      .physical_device_count = logical_device->physical_device_count,
+      .queue_count_per_physical_device =
+          logical_device->system->topology.gpu_agent_queue_count,
+  };
 }
+
+// Returns the queue for a flattened logical queue ordinal.
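+// For example (illustrative, assuming a fully populated queue affinity mask
+// and 4 queues per physical device): flattened ordinal 5 resolves to physical
+// device 1, host queue 1.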
+static iree_status_t iree_hal_amdgpu_logical_device_queue_from_ordinal(
+    iree_hal_amdgpu_logical_device_t* logical_device,
+    iree_host_size_t queue_ordinal,
+    iree_hal_amdgpu_virtual_queue_t** out_queue) {
+  IREE_ASSERT_ARGUMENT(logical_device);
+  IREE_ASSERT_ARGUMENT(out_queue);
+  *out_queue = NULL;
+
+  iree_hal_amdgpu_queue_affinity_resolved_t resolved;
+  IREE_RETURN_IF_ERROR(iree_hal_amdgpu_queue_affinity_resolve_ordinal(
+      iree_hal_amdgpu_logical_device_queue_affinity_domain(logical_device),
+      queue_ordinal, &resolved));
+
+  iree_hal_amdgpu_physical_device_t* physical_device =
+      logical_device->physical_devices[resolved.physical_device_ordinal];
+  if (IREE_UNLIKELY(resolved.physical_queue_ordinal >=
+                    physical_device->host_queue_count)) {
+    return iree_make_status(IREE_STATUS_INTERNAL,
+                            "queue affinity ordinal %" PRIhsz
+                            " maps to invalid host queue ordinal "
+                            "%" PRIhsz " on physical device %" PRIhsz,
+                            queue_ordinal, resolved.physical_queue_ordinal,
+                            resolved.physical_device_ordinal);
+  }
+
+  *out_queue =
+      &physical_device->host_queues[resolved.physical_queue_ordinal].base;
+  return iree_ok_status();
+}
+
 //===----------------------------------------------------------------------===//
 // iree_hal_amdgpu_logical_device_options_t
 //===----------------------------------------------------------------------===//
@@ -54,7 +91,7 @@
 #define IREE_HAL_AMDGPU_LOGICAL_DEVICE_MIN_SMALL_HOST_BLOCK_SIZE (4 * 1024)
 
 // Power-of-two size for the shared host large block pool in bytes.
-// Used for resource tracking and command buffer recording.
+// Used for resource tracking and other larger host-side transients.
 #define IREE_HAL_AMDGPU_LOGICAL_DEVICE_DEFAULT_LARGE_HOST_BLOCK_SIZE (64 * 1024)
 
 // Minimum size of a large host block (some structures require at least this
@@ -75,6 +112,8 @@
       IREE_HAL_AMDGPU_LOGICAL_DEVICE_DEFAULT_SMALL_HOST_BLOCK_SIZE;
   out_options->host_block_pools.large.block_size =
       IREE_HAL_AMDGPU_LOGICAL_DEVICE_DEFAULT_LARGE_HOST_BLOCK_SIZE;
+  out_options->host_block_pools.command_buffer.usable_block_size =
+      IREE_HAL_AMDGPU_AQL_PROGRAM_DEFAULT_BLOCK_SIZE;
 
   out_options->device_block_pools.small.block_size =
       IREE_HAL_AMDGPU_PHYSICAL_DEVICE_SMALL_DEVICE_BLOCK_SIZE_DEFAULT;
@@ -85,7 +124,22 @@
   out_options->device_block_pools.large.initial_capacity =
       IREE_HAL_AMDGPU_PHYSICAL_DEVICE_LARGE_DEVICE_BLOCK_INITIAL_CAPACITY_DEFAULT;
 
+  out_options->default_pool.range_length =
+      IREE_HAL_AMDGPU_PHYSICAL_DEVICE_DEFAULT_POOL_RANGE_LENGTH_DEFAULT;
+  out_options->default_pool.alignment =
+      IREE_HAL_AMDGPU_PHYSICAL_DEVICE_DEFAULT_POOL_ALIGNMENT_DEFAULT;
+  out_options->default_pool.frontier_capacity =
+      IREE_HAL_AMDGPU_PHYSICAL_DEVICE_DEFAULT_POOL_FRONTIER_CAPACITY_DEFAULT;
+
   out_options->queue_placement = IREE_HAL_AMDGPU_QUEUE_PLACEMENT_ANY;
+  out_options->host_queues.aql_capacity =
+      IREE_HAL_AMDGPU_PHYSICAL_DEVICE_DEFAULT_HOST_QUEUE_AQL_CAPACITY;
+  out_options->host_queues.notification_capacity =
+      IREE_HAL_AMDGPU_PHYSICAL_DEVICE_DEFAULT_HOST_QUEUE_NOTIFICATION_CAPACITY;
+  out_options->host_queues.kernarg_capacity =
+      IREE_HAL_AMDGPU_PHYSICAL_DEVICE_DEFAULT_HOST_QUEUE_KERNARG_CAPACITY;
+  out_options->host_queues.upload_capacity =
+      IREE_HAL_AMDGPU_PHYSICAL_DEVICE_DEFAULT_HOST_QUEUE_UPLOAD_CAPACITY;
 
   out_options->preallocate_pools = 1;
 }
@@ -94,12 +148,51 @@
     iree_hal_amdgpu_logical_device_options_t* options,
     iree_string_pair_list_t params) {
   IREE_ASSERT_ARGUMENT(options);
-  if (!params.count) return iree_ok_status();  // no-op
+  if (!params.count) return iree_ok_status();
   IREE_TRACE_ZONE_BEGIN(z0);
 
-  // TODO(benvanik): parameters.
+  const iree_string_pair_t* first_param = &params.pairs[0];
+  iree_status_t status = iree_make_status(
+      IREE_STATUS_INVALID_ARGUMENT,
+      "AMDGPU logical device options do not support key/value parameter '%.*s'",
+      (int)first_param->key.size, first_param->key.data);
 
   IREE_TRACE_ZONE_END(z0);
+  return status;
+}
+
+iree_status_t iree_hal_amdgpu_logical_device_options_verify_supported_features(
+    const iree_hal_amdgpu_logical_device_options_t* options) {
+  IREE_ASSERT_ARGUMENT(options);
+  switch (options->queue_placement) {
+    case IREE_HAL_AMDGPU_QUEUE_PLACEMENT_ANY:
+    case IREE_HAL_AMDGPU_QUEUE_PLACEMENT_HOST:
+      break;
+    case IREE_HAL_AMDGPU_QUEUE_PLACEMENT_DEVICE:
+      return iree_make_status(
+          IREE_STATUS_UNIMPLEMENTED,
+          "AMDGPU device queue placement is not implemented; use "
+          "queue_placement=any or queue_placement=host");
+    default:
+      return iree_make_status(IREE_STATUS_INVALID_ARGUMENT,
+                              "invalid AMDGPU queue placement value %u",
+                              (uint32_t)options->queue_placement);
+  }
+  if (options->exclusive_execution) {
+    return iree_make_status(IREE_STATUS_UNIMPLEMENTED,
+                            "AMDGPU exclusive_execution is not implemented");
+  }
+  if (options->wait_active_for_ns < 0) {
+    return iree_make_status(
+        IREE_STATUS_OUT_OF_RANGE,
+        "AMDGPU wait_active_for_ns must be non-negative (got %" PRId64 ")",
+        options->wait_active_for_ns);
+  }
+  if (options->wait_active_for_ns != 0) {
+    return iree_make_status(IREE_STATUS_UNIMPLEMENTED,
+                            "AMDGPU wait_active_for_ns is not implemented; "
+                            "use 0");
+  }
   return iree_ok_status();
 }
 
@@ -111,30 +204,97 @@
   IREE_ASSERT_ARGUMENT(topology);
   IREE_TRACE_ZONE_BEGIN(z0);
 
-  // TODO(benvanik): verify that the parameters are within expected ranges and
-  // any requested features are supported.
+  IREE_RETURN_AND_END_ZONE_IF_ERROR(
+      z0, iree_hal_amdgpu_logical_device_options_verify_supported_features(
+              options));
 
   if (options->host_block_pools.small.block_size <
           IREE_HAL_AMDGPU_LOGICAL_DEVICE_MIN_SMALL_HOST_BLOCK_SIZE ||
       !iree_host_size_is_power_of_two(
           options->host_block_pools.small.block_size)) {
-    return iree_make_status(
-        IREE_STATUS_OUT_OF_RANGE,
-        "small host block pool size invalid, expected a "
-        "power-of-two greater than %d and got %" PRIhsz,
-        IREE_HAL_AMDGPU_LOGICAL_DEVICE_MIN_SMALL_HOST_BLOCK_SIZE,
-        options->host_block_pools.small.block_size);
+    IREE_RETURN_AND_END_ZONE_IF_ERROR(
+        z0, iree_make_status(
+                IREE_STATUS_OUT_OF_RANGE,
+                "small host block pool size invalid, expected a "
+                "power-of-two of at least %d and got %" PRIhsz,
+                IREE_HAL_AMDGPU_LOGICAL_DEVICE_MIN_SMALL_HOST_BLOCK_SIZE,
+                options->host_block_pools.small.block_size));
   }
   if (options->host_block_pools.large.block_size <
           IREE_HAL_AMDGPU_LOGICAL_DEVICE_MIN_LARGE_HOST_BLOCK_SIZE ||
       !iree_host_size_is_power_of_two(
           options->host_block_pools.large.block_size)) {
-    return iree_make_status(
-        IREE_STATUS_OUT_OF_RANGE,
-        "large host block pool size invalid, expected a "
-        "power-of-two greater than %d and got %" PRIhsz,
-        IREE_HAL_AMDGPU_LOGICAL_DEVICE_MIN_LARGE_HOST_BLOCK_SIZE,
-        options->host_block_pools.large.block_size);
+    IREE_RETURN_AND_END_ZONE_IF_ERROR(
+        z0, iree_make_status(
+                IREE_STATUS_OUT_OF_RANGE,
+                "large host block pool size invalid, expected a "
+                "power-of-two of at least %d and got %" PRIhsz,
+                IREE_HAL_AMDGPU_LOGICAL_DEVICE_MIN_LARGE_HOST_BLOCK_SIZE,
+                options->host_block_pools.large.block_size));
+  }
+  if (options->host_block_pools.command_buffer.usable_block_size <
+          IREE_HAL_AMDGPU_AQL_PROGRAM_MIN_BLOCK_SIZE ||
+      options->host_block_pools.command_buffer.usable_block_size > UINT32_MAX ||
+      !iree_host_size_is_power_of_two(
+          options->host_block_pools.command_buffer.usable_block_size)) {
+    IREE_RETURN_AND_END_ZONE_IF_ERROR(
+        z0, iree_make_status(
+                IREE_STATUS_OUT_OF_RANGE,
+                "command-buffer host block pool usable size invalid, expected "
+                "a power-of-two between %u and %u and got %" PRIhsz,
+                IREE_HAL_AMDGPU_AQL_PROGRAM_MIN_BLOCK_SIZE, UINT32_MAX,
+                options->host_block_pools.command_buffer.usable_block_size));
+  }
+
+  if (topology->gpu_agent_queue_count > UINT8_MAX) {
+    IREE_RETURN_AND_END_ZONE_IF_ERROR(
+        z0, iree_make_status(IREE_STATUS_OUT_OF_RANGE,
+                             "gpu_agent_queue_count=%" PRIhsz
+                             " exceeds the queue-axis encoding limit (%u)",
+                             topology->gpu_agent_queue_count, UINT8_MAX));
+  }
+  iree_host_size_t total_queue_count = 0;
+  if (!iree_host_size_checked_mul(topology->gpu_agent_count,
+                                  topology->gpu_agent_queue_count,
+                                  &total_queue_count) ||
+      total_queue_count > IREE_HAL_MAX_QUEUES) {
+    IREE_RETURN_AND_END_ZONE_IF_ERROR(
+        z0,
+        iree_make_status(
+            IREE_STATUS_OUT_OF_RANGE,
+            "topology queue space does not fit in iree_hal_queue_affinity_t "
+            "(gpu_agent_count=%" PRIhsz ", gpu_agent_queue_count=%" PRIhsz
+            ", max_total_queues=%" PRIhsz ")",
+            topology->gpu_agent_count, topology->gpu_agent_queue_count,
+            (iree_host_size_t)IREE_HAL_MAX_QUEUES));
+  }
+  if (!iree_host_size_is_power_of_two(options->host_queues.aql_capacity) ||
+      !iree_host_size_is_power_of_two(
+          options->host_queues.notification_capacity) ||
+      !iree_host_size_is_power_of_two(options->host_queues.kernarg_capacity) ||
+      (options->host_queues.upload_capacity != 0 &&
+       !iree_host_size_is_power_of_two(options->host_queues.upload_capacity))) {
+    IREE_RETURN_AND_END_ZONE_IF_ERROR(
+        z0, iree_make_status(IREE_STATUS_OUT_OF_RANGE,
+                             "host queue AQL, notification, kernarg, and "
+                             "upload capacities must each be a power of two, "
+                             "with 0 allowed only to disable uploads "
+                             "(got aql=%u, notification=%u, kernarg_blocks=%u, "
+                             "upload_bytes=%u)",
+                             options->host_queues.aql_capacity,
+                             options->host_queues.notification_capacity,
+                             options->host_queues.kernarg_capacity,
+                             options->host_queues.upload_capacity));
+  }
+  if (options->host_queues.kernarg_capacity / 2u <
+      options->host_queues.aql_capacity) {
+    IREE_RETURN_AND_END_ZONE_IF_ERROR(
+        z0, iree_make_status(
+                IREE_STATUS_OUT_OF_RANGE,
+                "host queue kernarg capacity must be at least 2x the AQL queue "
+                "capacity (got kernarg_blocks=%u, aql_packets=%u)",
+                options->host_queues.kernarg_capacity,
+                options->host_queues.aql_capacity));
   }
 
   IREE_TRACE_ZONE_END(z0);
@@ -153,6 +313,887 @@
   return (iree_hal_amdgpu_logical_device_t*)base_value;
 }
 
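+// Returns true when any requested |data_families| bit requires HSA device
+// timestamps (device queue events, dispatch events, counter samples, and
+// executable traces).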
+static bool iree_hal_amdgpu_logical_device_profiling_needs_hsa_timestamps(
+    iree_hal_device_profiling_data_families_t data_families) {
+  return iree_any_bit_set(data_families,
+                          IREE_HAL_DEVICE_PROFILING_DATA_DEVICE_QUEUE_EVENTS |
+                              IREE_HAL_DEVICE_PROFILING_DATA_DISPATCH_EVENTS |
+                              IREE_HAL_DEVICE_PROFILING_DATA_COUNTER_SAMPLES |
+                              IREE_HAL_DEVICE_PROFILING_DATA_EXECUTABLE_TRACES);
+}
+
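+// Returns true when device/host clock correlation records must be emitted:
+// any timestamped data family above or counter ranges.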
+static bool iree_hal_amdgpu_logical_device_profiling_needs_clock_correlations(
+    iree_hal_device_profiling_data_families_t data_families) {
+  return iree_hal_amdgpu_logical_device_profiling_needs_hsa_timestamps(
+             data_families) ||
+         iree_any_bit_set(data_families,
+                          IREE_HAL_DEVICE_PROFILING_DATA_COUNTER_RANGES);
+}
+
+static iree_hal_device_profiling_data_families_t
+iree_hal_amdgpu_logical_device_lightweight_statistics_data_families(void) {
+  return IREE_HAL_DEVICE_PROFILING_DATA_EXECUTABLE_METADATA |
+         IREE_HAL_DEVICE_PROFILING_DATA_DEVICE_QUEUE_EVENTS |
+         IREE_HAL_DEVICE_PROFILING_DATA_DISPATCH_EVENTS;
+}
+
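+// Expands a lightweight-statistics request into an explicit data family set
+// (when none was specified) and clears the flag so downstream code only
+// consults data_families.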
+static iree_hal_device_profiling_options_t
+iree_hal_amdgpu_logical_device_resolve_profiling_options(
+    const iree_hal_device_profiling_options_t* options) {
+  iree_hal_device_profiling_options_t resolved_options = *options;
+  if (resolved_options.data_families == IREE_HAL_DEVICE_PROFILING_DATA_NONE &&
+      iree_hal_device_profiling_options_requests_lightweight_statistics(
+          options)) {
+    resolved_options.data_families =
+        iree_hal_amdgpu_logical_device_lightweight_statistics_data_families();
+  }
+  resolved_options.flags &=
+      ~IREE_HAL_DEVICE_PROFILING_FLAG_LIGHTWEIGHT_STATISTICS;
+  return resolved_options;
+}
+
+// Power-of-two capacity for logical-device memory lifecycle event buffering.
+#define IREE_HAL_AMDGPU_LOGICAL_DEVICE_PROFILE_MEMORY_EVENT_CAPACITY (64 * 1024)
+
+// Power-of-two capacity for logical-device queue operation event buffering.
+#define IREE_HAL_AMDGPU_LOGICAL_DEVICE_PROFILE_QUEUE_EVENT_CAPACITY (64 * 1024)
+
+static iree_hal_profile_chunk_metadata_t
+iree_hal_amdgpu_logical_device_profile_session_metadata(
+    iree_hal_amdgpu_logical_device_t* logical_device, uint64_t session_id) {
+  iree_hal_profile_chunk_metadata_t metadata =
+      iree_hal_profile_chunk_metadata_default();
+  metadata.content_type = IREE_HAL_PROFILE_CONTENT_TYPE_SESSION;
+  metadata.name = logical_device->identifier;
+  metadata.session_id = session_id;
+  return metadata;
+}
+
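+// Packs a (physical device, queue) pair into a 64-bit profile stream ID:
+// physical device ordinal in the high 32 bits, queue ordinal in the low 32.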
+static uint64_t iree_hal_amdgpu_logical_device_profile_queue_stream_id(
+    uint32_t physical_device_ordinal, uint32_t queue_ordinal) {
+  return ((uint64_t)physical_device_ordinal << 32) | (uint64_t)queue_ordinal;
+}
+
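+// Returns true when memory lifecycle events should be recorded: the data
+// family must be requested, a sink must be attached, and the event streams
+// must have memory storage allocated.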
+static bool iree_hal_amdgpu_logical_device_profile_memory_events_requested(
+    const iree_hal_amdgpu_logical_device_t* logical_device) {
+  return iree_hal_device_profiling_options_requests_data(
+             &logical_device->profiling.options,
+             IREE_HAL_DEVICE_PROFILING_DATA_MEMORY_EVENTS) &&
+         logical_device->profiling.options.sink &&
+         iree_hal_amdgpu_profile_event_streams_has_memory_storage(
+             &logical_device->profiling.event_streams);
+}
+
+bool iree_hal_amdgpu_logical_device_should_record_profile_memory_events(
+    iree_hal_device_t* base_device) {
+  iree_hal_amdgpu_logical_device_t* logical_device =
+      iree_hal_amdgpu_logical_device_cast(base_device);
+  return iree_hal_amdgpu_logical_device_profile_memory_events_requested(
+      logical_device);
+}
+
+static void iree_hal_amdgpu_logical_device_reset_profile_options(
+    iree_hal_amdgpu_logical_device_t* logical_device) {
+  iree_hal_device_profiling_options_storage_free(
+      logical_device->profiling.options_storage,
+      logical_device->host_allocator);
+  logical_device->profiling.options_storage = NULL;
+  logical_device->profiling.options = (iree_hal_device_profiling_options_t){0};
+}
+
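+// Returns true when the identified dispatch should be profiled: a
+// dispatch-level data family must be active and the capture filter must match
+// both the location and, when enabled, the executable export pattern.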
+bool iree_hal_amdgpu_logical_device_should_profile_dispatch(
+    iree_hal_amdgpu_logical_device_t* logical_device, uint64_t executable_id,
+    uint32_t export_ordinal, uint64_t command_buffer_id, uint32_t command_index,
+    uint32_t physical_device_ordinal, uint32_t queue_ordinal) {
+  if (!iree_any_bit_set(logical_device->profiling.options.data_families,
+                        IREE_HAL_DEVICE_PROFILING_DATA_DISPATCH_EVENTS |
+                            IREE_HAL_DEVICE_PROFILING_DATA_COUNTER_SAMPLES |
+                            IREE_HAL_DEVICE_PROFILING_DATA_EXECUTABLE_TRACES)) {
+    return false;
+  }
+
+  const iree_hal_profile_capture_filter_t* filter =
+      &logical_device->profiling.options.capture_filter;
+  if (!iree_hal_profile_capture_filter_matches_location(
+          filter, command_buffer_id, command_index, physical_device_ordinal,
+          queue_ordinal)) {
+    return false;
+  }
+  if (iree_any_bit_set(
+          filter->flags,
+          IREE_HAL_PROFILE_CAPTURE_FILTER_FLAG_EXECUTABLE_EXPORT_PATTERN)) {
+    return iree_hal_amdgpu_profile_metadata_export_matches(
+        &logical_device->profile_metadata, executable_id, export_ordinal,
+        filter->executable_export_pattern);
+  }
+  return true;
+}
+
+uint64_t iree_hal_amdgpu_logical_device_allocate_profile_memory_allocation_id(
+    iree_hal_device_t* base_device, uint64_t* out_session_id) {
+  iree_hal_amdgpu_logical_device_t* logical_device =
+      iree_hal_amdgpu_logical_device_cast(base_device);
+  *out_session_id = 0;
+  if (!iree_hal_amdgpu_logical_device_profile_memory_events_requested(
+          logical_device)) {
+    return 0;
+  }
+
+  return iree_hal_amdgpu_profile_event_streams_allocate_memory_allocation_id(
+      &logical_device->profiling.event_streams,
+      logical_device->profiling.session_id, out_session_id);
+}
+
+bool iree_hal_amdgpu_logical_device_record_profile_memory_event_for_session(
+    iree_hal_device_t* base_device, uint64_t session_id,
+    const iree_hal_profile_memory_event_t* event) {
+  iree_hal_amdgpu_logical_device_t* logical_device =
+      iree_hal_amdgpu_logical_device_cast(base_device);
+  if (!iree_hal_amdgpu_logical_device_profile_memory_events_requested(
+          logical_device)) {
+    return false;
+  }
+
+  return iree_hal_amdgpu_profile_event_streams_record_memory_event(
+      &logical_device->profiling.event_streams,
+      logical_device->profiling.session_id, session_id, event);
+}
+
+bool iree_hal_amdgpu_logical_device_record_profile_memory_event(
+    iree_hal_device_t* base_device,
+    const iree_hal_profile_memory_event_t* event) {
+  return iree_hal_amdgpu_logical_device_record_profile_memory_event_for_session(
+      base_device, /*session_id=*/0, event);
+}
+
+static bool iree_hal_amdgpu_logical_device_profile_queue_events_requested(
+    const iree_hal_amdgpu_logical_device_t* logical_device) {
+  return iree_hal_device_profiling_options_requests_data(
+             &logical_device->profiling.options,
+             IREE_HAL_DEVICE_PROFILING_DATA_QUEUE_EVENTS) &&
+         logical_device->profiling.options.sink &&
+         iree_hal_amdgpu_profile_event_streams_has_queue_storage(
+             &logical_device->profiling.event_streams);
+}
+
+void iree_hal_amdgpu_logical_device_record_profile_queue_event(
+    iree_hal_device_t* base_device,
+    const iree_hal_profile_queue_event_t* event) {
+  iree_hal_amdgpu_logical_device_t* logical_device =
+      iree_hal_amdgpu_logical_device_cast(base_device);
+  if (!iree_hal_amdgpu_logical_device_profile_queue_events_requested(
+          logical_device)) {
+    return;
+  }
+
+  iree_hal_amdgpu_profile_event_streams_record_queue_event(
+      &logical_device->profiling.event_streams, event);
+}
+
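+// Samples the device clock counters bracketed by host timestamps so profile
+// consumers can bound the error when correlating device ticks to host time.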
+static iree_status_t
+iree_hal_amdgpu_logical_device_sample_profile_clock_correlation(
+    iree_hal_amdgpu_logical_device_t* logical_device,
+    iree_hal_amdgpu_physical_device_t* physical_device,
+    iree_hal_profile_clock_correlation_record_t* out_record) {
+  if (IREE_UNLIKELY(physical_device->device_ordinal > UINT32_MAX)) {
+    return iree_make_status(
+        IREE_STATUS_OUT_OF_RANGE,
+        "profile clock correlation physical device ordinal out of range: "
+        "%" PRIhsz,
+        physical_device->device_ordinal);
+  }
+
+  iree_hal_amdgpu_device_clock_counters_t counters = {0};
+  const iree_time_t host_time_begin_ns = iree_time_now();
+  iree_status_t status = iree_hal_amdgpu_device_clock_source_sample(
+      &logical_device->system->device_clock_source, physical_device->driver_uid,
+      &counters);
+  const iree_time_t host_time_end_ns = iree_time_now();
+
+  if (iree_status_is_ok(status)) {
+    *out_record = iree_hal_profile_clock_correlation_record_default();
+    out_record->flags =
+        IREE_HAL_PROFILE_CLOCK_CORRELATION_FLAG_DEVICE_TICK |
+        IREE_HAL_PROFILE_CLOCK_CORRELATION_FLAG_HOST_CPU_TIMESTAMP |
+        IREE_HAL_PROFILE_CLOCK_CORRELATION_FLAG_HOST_SYSTEM_TIMESTAMP |
+        IREE_HAL_PROFILE_CLOCK_CORRELATION_FLAG_HOST_TIME_BRACKET;
+    out_record->physical_device_ordinal =
+        (uint32_t)physical_device->device_ordinal;
+    out_record->sample_id =
+        logical_device->profiling.next_clock_correlation_sample_id++;
+    out_record->device_tick = counters.device_clock_counter;
+    out_record->host_cpu_timestamp_ns = counters.host_cpu_timestamp_ns;
+    out_record->host_system_timestamp = counters.host_system_timestamp;
+    out_record->host_system_frequency_hz = counters.host_system_frequency_hz;
+    out_record->host_time_begin_ns = host_time_begin_ns;
+    out_record->host_time_end_ns = host_time_end_ns;
+  } else {
+    status = iree_status_annotate_f(
+        status,
+        "sampling profile clock correlation for physical_device_ordinal="
+        "%" PRIhsz " driver_uid=%" PRIu32,
+        physical_device->device_ordinal, physical_device->driver_uid);
+  }
+  return status;
+}
+
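+// Writes a DEVICES chunk describing each physical device (ordinal, queue
+// count, and optional UUID) to |sink|.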
+static iree_status_t iree_hal_amdgpu_logical_device_write_profile_devices(
+    iree_hal_amdgpu_logical_device_t* logical_device,
+    iree_hal_profile_sink_t* sink, uint64_t session_id) {
+  IREE_TRACE_ZONE_BEGIN(z0);
+
+  const iree_host_size_t record_count = logical_device->physical_device_count;
+  if (record_count == 0) {
+    IREE_TRACE_ZONE_END(z0);
+    return iree_make_status(
+        IREE_STATUS_INTERNAL,
+        "logical device has no physical devices (initialization incomplete)");
+  }
+
+  iree_host_size_t records_size = 0;
+  IREE_RETURN_AND_END_ZONE_IF_ERROR(
+      z0, IREE_STRUCT_LAYOUT(
+              0, &records_size,
+              IREE_STRUCT_FIELD(record_count, iree_hal_profile_device_record_t,
+                                NULL)));
+  iree_hal_profile_device_record_t* records = NULL;
+  IREE_RETURN_AND_END_ZONE_IF_ERROR(
+      z0, iree_allocator_malloc(logical_device->host_allocator, records_size,
+                                (void**)&records));
+
+  iree_status_t status = iree_ok_status();
+  for (iree_host_size_t i = 0; i < record_count && iree_status_is_ok(status);
+       ++i) {
+    iree_hal_amdgpu_physical_device_t* physical_device =
+        logical_device->physical_devices[i];
+    if (IREE_UNLIKELY(physical_device->device_ordinal > UINT32_MAX ||
+                      physical_device->host_queue_count > UINT32_MAX)) {
+      status = iree_make_status(
+          IREE_STATUS_OUT_OF_RANGE,
+          "profile device metadata ordinals out of range: device=%" PRIhsz
+          ", queue_count=%" PRIhsz,
+          physical_device->device_ordinal, physical_device->host_queue_count);
+      break;
+    }
+
+    records[i] = iree_hal_profile_device_record_default();
+    records[i].physical_device_ordinal =
+        (uint32_t)physical_device->device_ordinal;
+    records[i].queue_count = (uint32_t)physical_device->host_queue_count;
+    if (physical_device->has_physical_device_uuid) {
+      records[i].flags |= IREE_HAL_PROFILE_DEVICE_FLAG_PHYSICAL_DEVICE_UUID;
+      memcpy(records[i].physical_device_uuid,
+             physical_device->physical_device_uuid,
+             sizeof(records[i].physical_device_uuid));
+    }
+  }
+
+  if (iree_status_is_ok(status)) {
+    iree_hal_profile_chunk_metadata_t metadata =
+        iree_hal_profile_chunk_metadata_default();
+    metadata.content_type = IREE_HAL_PROFILE_CONTENT_TYPE_DEVICES;
+    metadata.name = logical_device->identifier;
+    metadata.session_id = session_id;
+    iree_const_byte_span_t iovec =
+        iree_make_const_byte_span(records, records_size);
+    status = iree_hal_profile_sink_write(sink, &metadata, 1, &iovec);
+  }
+
+  iree_allocator_free(logical_device->host_allocator, records);
+  IREE_TRACE_ZONE_END(z0);
+  return status;
+}
+
+static iree_status_t iree_hal_amdgpu_logical_device_write_profile_queues(
+    iree_hal_amdgpu_logical_device_t* logical_device,
+    iree_hal_profile_sink_t* sink, uint64_t session_id) {
+  IREE_TRACE_ZONE_BEGIN(z0);
+
+  iree_host_size_t record_count = 0;
+  for (iree_host_size_t i = 0; i < logical_device->physical_device_count; ++i) {
+    iree_hal_amdgpu_physical_device_t* physical_device =
+        logical_device->physical_devices[i];
+    if (IREE_UNLIKELY(!iree_host_size_checked_add(
+            record_count, physical_device->host_queue_count, &record_count))) {
+      IREE_TRACE_ZONE_END(z0);
+      return iree_make_status(IREE_STATUS_OUT_OF_RANGE,
+                              "profile queue metadata count overflow");
+    }
+  }
+  if (record_count == 0) {
+    IREE_TRACE_ZONE_END(z0);
+    return iree_make_status(
+        IREE_STATUS_INTERNAL,
+        "logical device has no host queues (initialization incomplete)");
+  }
+
+  iree_host_size_t records_size = 0;
+  IREE_RETURN_AND_END_ZONE_IF_ERROR(
+      z0, IREE_STRUCT_LAYOUT(
+              0, &records_size,
+              IREE_STRUCT_FIELD(record_count, iree_hal_profile_queue_record_t,
+                                NULL)));
+  iree_hal_profile_queue_record_t* records = NULL;
+  IREE_RETURN_AND_END_ZONE_IF_ERROR(
+      z0, iree_allocator_malloc(logical_device->host_allocator, records_size,
+                                (void**)&records));
+
+  iree_status_t status = iree_ok_status();
+  iree_host_size_t record_ordinal = 0;
+  for (iree_host_size_t i = 0;
+       i < logical_device->physical_device_count && iree_status_is_ok(status);
+       ++i) {
+    iree_hal_amdgpu_physical_device_t* physical_device =
+        logical_device->physical_devices[i];
+    if (IREE_UNLIKELY(physical_device->device_ordinal > UINT32_MAX)) {
+      status = iree_make_status(IREE_STATUS_OUT_OF_RANGE,
+                                "profile queue metadata physical device "
+                                "ordinal out of range: %" PRIhsz,
+                                physical_device->device_ordinal);
+      break;
+    }
+    const uint32_t physical_device_ordinal =
+        (uint32_t)physical_device->device_ordinal;
+    for (iree_host_size_t j = 0;
+         j < physical_device->host_queue_count && iree_status_is_ok(status);
+         ++j) {
+      if (IREE_UNLIKELY(j > UINT32_MAX)) {
+        status = iree_make_status(
+            IREE_STATUS_OUT_OF_RANGE,
+            "profile queue metadata queue ordinal out of range: %" PRIhsz, j);
+        break;
+      }
+      const uint32_t queue_ordinal = (uint32_t)j;
+      records[record_ordinal] = iree_hal_profile_queue_record_default();
+      records[record_ordinal].physical_device_ordinal = physical_device_ordinal;
+      records[record_ordinal].queue_ordinal = queue_ordinal;
+      records[record_ordinal].stream_id =
+          iree_hal_amdgpu_logical_device_profile_queue_stream_id(
+              physical_device_ordinal, queue_ordinal);
+      ++record_ordinal;
+    }
+  }
+
+  if (iree_status_is_ok(status)) {
+    iree_hal_profile_chunk_metadata_t metadata =
+        iree_hal_profile_chunk_metadata_default();
+    metadata.content_type = IREE_HAL_PROFILE_CONTENT_TYPE_QUEUES;
+    metadata.name = logical_device->identifier;
+    metadata.session_id = session_id;
+    iree_const_byte_span_t iovec =
+        iree_make_const_byte_span(records, records_size);
+    status = iree_hal_profile_sink_write(sink, &metadata, 1, &iovec);
+  }
+
+  iree_allocator_free(logical_device->host_allocator, records);
+  IREE_TRACE_ZONE_END(z0);
+  return status;
+}
+
+static iree_status_t
+iree_hal_amdgpu_logical_device_write_profile_clock_correlations(
+    iree_hal_amdgpu_logical_device_t* logical_device,
+    iree_hal_profile_sink_t* sink, uint64_t session_id) {
+  IREE_TRACE_ZONE_BEGIN(z0);
+
+  const iree_host_size_t record_count = logical_device->physical_device_count;
+  if (record_count == 0) {
+    IREE_TRACE_ZONE_END(z0);
+    return iree_make_status(
+        IREE_STATUS_INTERNAL,
+        "logical device has no physical devices (initialization incomplete)");
+  }
+
+  iree_host_size_t records_size = 0;
+  IREE_RETURN_AND_END_ZONE_IF_ERROR(
+      z0, IREE_STRUCT_LAYOUT(
+              0, &records_size,
+              IREE_STRUCT_FIELD(record_count,
+                                iree_hal_profile_clock_correlation_record_t,
+                                NULL)));
+  iree_hal_profile_clock_correlation_record_t* records = NULL;
+  IREE_RETURN_AND_END_ZONE_IF_ERROR(
+      z0, iree_allocator_malloc(logical_device->host_allocator, records_size,
+                                (void**)&records));
+
+  iree_status_t status = iree_ok_status();
+  for (iree_host_size_t i = 0; i < record_count && iree_status_is_ok(status);
+       ++i) {
+    status = iree_hal_amdgpu_logical_device_sample_profile_clock_correlation(
+        logical_device, logical_device->physical_devices[i], &records[i]);
+  }
+
+  if (iree_status_is_ok(status)) {
+    iree_hal_profile_chunk_metadata_t metadata =
+        iree_hal_profile_chunk_metadata_default();
+    metadata.content_type = IREE_HAL_PROFILE_CONTENT_TYPE_CLOCK_CORRELATIONS;
+    metadata.name = logical_device->identifier;
+    metadata.session_id = session_id;
+    iree_const_byte_span_t iovec =
+        iree_make_const_byte_span(records, records_size);
+    status = iree_hal_profile_sink_write(sink, &metadata, 1, &iovec);
+  }
+
+  iree_allocator_free(logical_device->host_allocator, records);
+  IREE_TRACE_ZONE_END(z0);
+  return status;
+}
+
+static bool iree_hal_amdgpu_logical_device_profile_needs_executable_artifacts(
+    iree_hal_device_profiling_data_families_t data_families) {
+  return iree_any_bit_set(data_families,
+                          IREE_HAL_DEVICE_PROFILING_DATA_EXECUTABLE_METADATA |
+                              IREE_HAL_DEVICE_PROFILING_DATA_EXECUTABLE_TRACES);
+}
+
+static iree_status_t iree_hal_amdgpu_logical_device_write_profile_metadata(
+    iree_hal_amdgpu_logical_device_t* logical_device,
+    iree_hal_profile_sink_t* sink, uint64_t session_id,
+    iree_hal_device_profiling_data_families_t data_families) {
+  const bool emit_executable_artifacts =
+      iree_hal_amdgpu_logical_device_profile_needs_executable_artifacts(
+          data_families);
+  IREE_RETURN_IF_ERROR(iree_hal_amdgpu_logical_device_write_profile_devices(
+      logical_device, sink, session_id));
+  IREE_RETURN_IF_ERROR(iree_hal_amdgpu_logical_device_write_profile_queues(
+      logical_device, sink, session_id));
+  IREE_RETURN_IF_ERROR(iree_hal_amdgpu_profile_metadata_write(
+      &logical_device->profile_metadata, sink, session_id,
+      logical_device->identifier, emit_executable_artifacts,
+      &logical_device->profiling.metadata_cursor));
+  if (iree_hal_amdgpu_logical_device_profiling_needs_clock_correlations(
+          data_families)) {
+    return iree_hal_amdgpu_logical_device_write_profile_clock_correlations(
+        logical_device, sink, session_id);
+  }
+  return iree_ok_status();
+}
+
+static iree_status_t iree_hal_amdgpu_logical_device_write_profile_events(
+    iree_hal_amdgpu_logical_device_t* logical_device,
+    iree_hal_profile_sink_t* sink, uint64_t session_id) {
+  IREE_TRACE_ZONE_BEGIN(z0);
+  iree_status_t status = iree_hal_amdgpu_profile_event_streams_write_queue(
+      &logical_device->profiling.event_streams, sink, session_id,
+      logical_device->host_allocator);
+  if (iree_status_is_ok(status)) {
+    status = iree_hal_amdgpu_profile_event_streams_write_memory(
+        &logical_device->profiling.event_streams, sink, session_id,
+        logical_device->host_allocator);
+  }
+  for (iree_host_size_t i = 0;
+       i < logical_device->physical_device_count && iree_status_is_ok(status);
+       ++i) {
+    iree_hal_amdgpu_physical_device_t* physical_device =
+        logical_device->physical_devices[i];
+    for (iree_host_size_t j = 0;
+         j < physical_device->host_queue_count && iree_status_is_ok(status);
+         ++j) {
+      status = iree_hal_amdgpu_host_queue_write_profile_events(
+          &physical_device->host_queues[j], sink, session_id);
+    }
+  }
+  IREE_TRACE_ZONE_END(z0);
+  return status;
+}
+
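+// Maps the requested HAL profiling data families onto per-queue profile
+// flags.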
+static iree_hal_amdgpu_host_queue_profile_flags_t
+iree_hal_amdgpu_logical_device_queue_profile_flags(
+    const iree_hal_device_profiling_options_t* options) {
+  iree_hal_amdgpu_host_queue_profile_flags_t flags =
+      IREE_HAL_AMDGPU_HOST_QUEUE_PROFILE_FLAG_NONE;
+  if (iree_hal_device_profiling_options_requests_data(
+          options, IREE_HAL_DEVICE_PROFILING_DATA_QUEUE_EVENTS)) {
+    flags |= IREE_HAL_AMDGPU_HOST_QUEUE_PROFILE_FLAG_QUEUE_EVENTS;
+  }
+  if (iree_hal_device_profiling_options_requests_data(
+          options, IREE_HAL_DEVICE_PROFILING_DATA_DEVICE_QUEUE_EVENTS)) {
+    flags |= IREE_HAL_AMDGPU_HOST_QUEUE_PROFILE_FLAG_QUEUE_DEVICE_EVENTS;
+  }
+  if (iree_any_bit_set(options->data_families,
+                       IREE_HAL_DEVICE_PROFILING_DATA_DISPATCH_EVENTS |
+                           IREE_HAL_DEVICE_PROFILING_DATA_COUNTER_SAMPLES |
+                           IREE_HAL_DEVICE_PROFILING_DATA_EXECUTABLE_TRACES)) {
+    flags |= IREE_HAL_AMDGPU_HOST_QUEUE_PROFILE_FLAG_DISPATCHES;
+  }
+  return flags;
+}
+
+static void iree_hal_amdgpu_logical_device_set_queue_profiling_enabled(
+    iree_hal_amdgpu_logical_device_t* logical_device,
+    iree_hal_amdgpu_host_queue_profile_flags_t flags) {
+  for (iree_host_size_t i = 0; i < logical_device->physical_device_count; ++i) {
+    iree_hal_amdgpu_physical_device_t* physical_device =
+        logical_device->physical_devices[i];
+    for (iree_host_size_t j = 0; j < physical_device->host_queue_count; ++j) {
+      iree_hal_amdgpu_host_queue_set_profile_flags(
+          &physical_device->host_queues[j], flags);
+    }
+  }
+}
+
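+// Toggles HSA profiling across all physical devices with best-effort
+// rollback: a failure while enabling disables the devices already changed and
+// a failure while disabling still attempts the remaining devices.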
+static iree_status_t iree_hal_amdgpu_logical_device_set_hsa_profiling_enabled(
+    iree_hal_amdgpu_logical_device_t* logical_device, bool enabled) {
+  IREE_TRACE_ZONE_BEGIN(z0);
+  IREE_TRACE_ZONE_APPEND_VALUE_I64(z0, enabled ? 1 : 0);
+
+  iree_status_t status = iree_ok_status();
+  iree_host_size_t changed_count = 0;
+  for (iree_host_size_t i = 0;
+       i < logical_device->physical_device_count && iree_status_is_ok(status);
+       ++i) {
+    status = iree_hal_amdgpu_physical_device_set_hsa_profiling_enabled(
+        logical_device->physical_devices[i], enabled);
+    if (iree_status_is_ok(status)) {
+      ++changed_count;
+    }
+  }
+
+  if (!iree_status_is_ok(status) && enabled) {
+    for (iree_host_size_t i = 0; i < changed_count; ++i) {
+      status = iree_status_join(
+          status, iree_hal_amdgpu_physical_device_set_hsa_profiling_enabled(
+                      logical_device->physical_devices[i], false));
+    }
+  } else if (!enabled) {
+    for (iree_host_size_t i = changed_count;
+         i < logical_device->physical_device_count; ++i) {
+      status = iree_status_join(
+          status, iree_hal_amdgpu_physical_device_set_hsa_profiling_enabled(
+                      logical_device->physical_devices[i], false));
+    }
+  }
+
+  IREE_TRACE_ZONE_END(z0);
+  return status;
+}
+
+// Returns true when |queue_ordinal| is the physical device's counter range
+// sampling queue.
+//
+// Queue affinity ANY resolves to queue 0 for ordinary submissions, so using
+// the final queue gives the sampler the best chance to run independently while
+// the default queue is saturated. When only one queue exists we fall back to
+// that queue and sampling is necessarily ordered behind user work.
+static bool iree_hal_amdgpu_logical_device_is_profile_counter_range_queue(
+    const iree_hal_amdgpu_physical_device_t* physical_device,
+    iree_host_size_t queue_ordinal) {
+  return queue_ordinal + 1 == physical_device->host_queue_count;
+}
+
+static iree_hal_amdgpu_host_queue_t*
+iree_hal_amdgpu_logical_device_select_profile_counter_range_queue(
+    iree_hal_amdgpu_physical_device_t* physical_device) {
+  if (physical_device->host_queue_count == 0) return NULL;
+  return &physical_device->host_queues[physical_device->host_queue_count - 1];
+}
+
+static iree_status_t
+iree_hal_amdgpu_logical_device_set_counter_profiling_enabled(
+    iree_hal_amdgpu_logical_device_t* logical_device,
+    iree_hal_amdgpu_profile_counter_session_t* counter_session, bool enabled) {
+  if (!iree_hal_amdgpu_profile_counter_session_is_active(counter_session)) {
+    return iree_ok_status();
+  }
+  IREE_TRACE_ZONE_BEGIN(z0);
+  IREE_TRACE_ZONE_APPEND_VALUE_I64(z0, enabled ? 1 : 0);
+
+  iree_status_t status = iree_ok_status();
+  const bool capture_dispatch_samples =
+      iree_hal_amdgpu_profile_counter_session_captures_dispatch_samples(
+          counter_session);
+  const bool capture_queue_ranges =
+      iree_hal_amdgpu_profile_counter_session_captures_queue_ranges(
+          counter_session);
+  for (iree_host_size_t i = 0;
+       i < logical_device->physical_device_count && iree_status_is_ok(status);
+       ++i) {
+    iree_hal_amdgpu_physical_device_t* physical_device =
+        logical_device->physical_devices[i];
+    for (iree_host_size_t j = 0;
+         j < physical_device->host_queue_count && iree_status_is_ok(status);
+         ++j) {
+      iree_hal_amdgpu_host_queue_t* queue = &physical_device->host_queues[j];
+      if (enabled) {
+        iree_hal_amdgpu_profile_counter_enable_flags_t flags =
+            IREE_HAL_AMDGPU_PROFILE_COUNTER_ENABLE_FLAG_NONE;
+        if (capture_dispatch_samples) {
+          flags |= IREE_HAL_AMDGPU_PROFILE_COUNTER_ENABLE_FLAG_DISPATCH_SAMPLES;
+        }
+        if (capture_queue_ranges &&
+            iree_hal_amdgpu_logical_device_is_profile_counter_range_queue(
+                physical_device, j)) {
+          flags |= IREE_HAL_AMDGPU_PROFILE_COUNTER_ENABLE_FLAG_QUEUE_RANGES;
+        }
+        status = iree_hal_amdgpu_host_queue_enable_profile_counters(
+            queue, counter_session, flags);
+      } else {
+        iree_hal_amdgpu_host_queue_disable_profile_counters(queue);
+      }
+    }
+  }
+
+  if (!iree_status_is_ok(status) && enabled) {
+    for (iree_host_size_t i = 0; i < logical_device->physical_device_count;
+         ++i) {
+      iree_hal_amdgpu_physical_device_t* physical_device =
+          logical_device->physical_devices[i];
+      for (iree_host_size_t j = 0; j < physical_device->host_queue_count; ++j) {
+        iree_hal_amdgpu_host_queue_disable_profile_counters(
+            &physical_device->host_queues[j]);
+      }
+    }
+  }
+
+  IREE_TRACE_ZONE_END(z0);
+  return status;
+}
+
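+// Enables or disables counter collection on all host queues. Queue-range
+// capture is restricted to each device's designated sampling queue and any
+// enable failure returns all queues to the disabled state.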
+static iree_status_t
+iree_hal_amdgpu_logical_device_start_profile_counter_ranges(
+    iree_hal_amdgpu_logical_device_t* logical_device,
+    iree_hal_amdgpu_profile_counter_session_t* counter_session) {
+  if (!iree_hal_amdgpu_profile_counter_session_captures_queue_ranges(
+          counter_session)) {
+    return iree_ok_status();
+  }
+  IREE_TRACE_ZONE_BEGIN(z0);
+
+  iree_status_t status = iree_ok_status();
+  iree_host_size_t started_device_count = 0;
+  for (iree_host_size_t i = 0;
+       i < logical_device->physical_device_count && iree_status_is_ok(status);
+       ++i) {
+    iree_hal_amdgpu_physical_device_t* physical_device =
+        logical_device->physical_devices[i];
+    if (IREE_UNLIKELY(physical_device->host_queue_count == 0)) {
+      status = iree_make_status(IREE_STATUS_INTERNAL,
+                                "physical device has no host "
+                                "queues (initialization incomplete)");
+    } else {
+      iree_hal_amdgpu_host_queue_t* queue =
+          iree_hal_amdgpu_logical_device_select_profile_counter_range_queue(
+              physical_device);
+      status = iree_hal_amdgpu_host_queue_start_profile_counter_ranges(queue);
+      if (iree_status_is_ok(status)) {
+        ++started_device_count;
+      }
+    }
+  }
+
+  if (!iree_status_is_ok(status)) {
+    for (iree_host_size_t i = 0; i < started_device_count; ++i) {
+      iree_hal_amdgpu_physical_device_t* physical_device =
+          logical_device->physical_devices[i];
+      iree_hal_amdgpu_host_queue_t* queue =
+          iree_hal_amdgpu_logical_device_select_profile_counter_range_queue(
+              physical_device);
+      status = iree_status_join(
+          status, iree_hal_amdgpu_host_queue_flush_profile_counter_ranges(
+                      queue, /*sink=*/NULL, /*session_id=*/0,
+                      IREE_HAL_AMDGPU_PROFILE_COUNTER_RANGE_FLUSH_FLAG_NONE));
+    }
+  }
+
+  IREE_TRACE_ZONE_END(z0);
+  return status;
+}
+
+static iree_status_t
+iree_hal_amdgpu_logical_device_flush_profile_counter_ranges(
+    iree_hal_amdgpu_logical_device_t* logical_device,
+    iree_hal_amdgpu_profile_counter_session_t* counter_session,
+    iree_hal_profile_sink_t* sink, uint64_t session_id,
+    iree_hal_amdgpu_profile_counter_range_flush_flags_t flags) {
+  if (!iree_hal_amdgpu_profile_counter_session_captures_queue_ranges(
+          counter_session)) {
+    return iree_ok_status();
+  }
+  IREE_TRACE_ZONE_BEGIN(z0);
+
+  iree_status_t status = iree_ok_status();
+  for (iree_host_size_t i = 0;
+       i < logical_device->physical_device_count && iree_status_is_ok(status);
+       ++i) {
+    iree_hal_amdgpu_physical_device_t* physical_device =
+        logical_device->physical_devices[i];
+    if (IREE_UNLIKELY(physical_device->host_queue_count == 0)) {
+      status = iree_make_status(IREE_STATUS_INTERNAL,
+                                "physical device has no host "
+                                "queues (initialization incomplete)");
+    } else {
+      iree_hal_amdgpu_host_queue_t* queue =
+          iree_hal_amdgpu_logical_device_select_profile_counter_range_queue(
+              physical_device);
+      status = iree_hal_amdgpu_host_queue_flush_profile_counter_ranges(
+          queue, sink, session_id, flags);
+    }
+  }
+
+  IREE_TRACE_ZONE_END(z0);
+  return status;
+}
+
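+// Enables or disables trace capture on all host queues, rolling back any
+// queues already enabled when a later enable fails.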
+static iree_status_t iree_hal_amdgpu_logical_device_set_trace_profiling_enabled(
+    iree_hal_amdgpu_logical_device_t* logical_device,
+    iree_hal_amdgpu_profile_trace_session_t* trace_session, bool enabled) {
+  if (!iree_hal_amdgpu_profile_trace_session_is_active(trace_session)) {
+    return iree_ok_status();
+  }
+  IREE_TRACE_ZONE_BEGIN(z0);
+  IREE_TRACE_ZONE_APPEND_VALUE_I64(z0, enabled ? 1 : 0);
+
+  iree_status_t status = iree_ok_status();
+  iree_host_size_t changed_queue_count = 0;
+  for (iree_host_size_t i = 0;
+       i < logical_device->physical_device_count && iree_status_is_ok(status);
+       ++i) {
+    iree_hal_amdgpu_physical_device_t* physical_device =
+        logical_device->physical_devices[i];
+    for (iree_host_size_t j = 0;
+         j < physical_device->host_queue_count && iree_status_is_ok(status);
+         ++j) {
+      iree_hal_amdgpu_host_queue_t* queue = &physical_device->host_queues[j];
+      if (enabled) {
+        status = iree_hal_amdgpu_host_queue_enable_profile_traces(
+            queue, trace_session);
+        if (iree_status_is_ok(status)) {
+          ++changed_queue_count;
+        }
+      } else {
+        iree_hal_amdgpu_host_queue_disable_profile_traces(queue);
+      }
+    }
+  }
+
+  if (!iree_status_is_ok(status) && enabled) {
+    for (iree_host_size_t i = 0, seen_queue_count = 0;
+         i < logical_device->physical_device_count &&
+         seen_queue_count < changed_queue_count;
+         ++i) {
+      iree_hal_amdgpu_physical_device_t* physical_device =
+          logical_device->physical_devices[i];
+      for (iree_host_size_t j = 0; j < physical_device->host_queue_count &&
+                                   seen_queue_count < changed_queue_count;
+           ++j, ++seen_queue_count) {
+        iree_hal_amdgpu_host_queue_disable_profile_traces(
+            &physical_device->host_queues[j]);
+      }
+    }
+  }
+
+  IREE_TRACE_ZONE_END(z0);
+  return status;
+}
+
+// Selects one host queue from |queue_affinity| after intersecting with this
+// logical device's supported queues. The current policy is deterministic
+// first-set-bit selection, which is enough to honor explicit HIP stream
+// affinities and keeps the CTS path stable. A multi-bit affinity therefore acts
+// as "any of these queues"; queue_flush handles multi-bit masks by iterating
+// all selected queues instead.
+static iree_status_t iree_hal_amdgpu_logical_device_select_host_queue(
+    iree_hal_amdgpu_logical_device_t* logical_device,
+    iree_hal_queue_affinity_t queue_affinity,
+    iree_hal_amdgpu_virtual_queue_t** out_queue) {
+  IREE_ASSERT_ARGUMENT(logical_device);
+  IREE_ASSERT_ARGUMENT(out_queue);
+  *out_queue = NULL;
+
+  iree_hal_amdgpu_queue_affinity_resolved_t resolved;
+  IREE_RETURN_IF_ERROR(iree_hal_amdgpu_queue_affinity_resolve(
+      iree_hal_amdgpu_logical_device_queue_affinity_domain(logical_device),
+      queue_affinity, &resolved));
+  return iree_hal_amdgpu_logical_device_queue_from_ordinal(
+      logical_device, resolved.queue_ordinal, out_queue);
+}
+
+// Selects the physical device backing |queue_affinity| for pool creation.
+//
+// Queue pools are scoped to one physical memory domain, but |queue_affinity|
+// still has the usual "any queue in this mask" meaning. This helper therefore
+// collapses multi-bit masks with the same deterministic first-set-bit policy as
+// host queue submission. In practice IREE_HAL_QUEUE_AFFINITY_ANY usually
+// selects queue 0 after intersecting with this device's supported queue mask.
+static iree_status_t
+iree_hal_amdgpu_logical_device_select_queue_pool_physical_device(
+    iree_hal_amdgpu_logical_device_t* logical_device,
+    iree_hal_queue_affinity_t queue_affinity,
+    iree_hal_amdgpu_physical_device_t** out_physical_device) {
+  IREE_ASSERT_ARGUMENT(logical_device);
+  IREE_ASSERT_ARGUMENT(out_physical_device);
+  *out_physical_device = NULL;
+
+  iree_hal_amdgpu_queue_affinity_resolved_t resolved;
+  IREE_RETURN_IF_ERROR(iree_hal_amdgpu_queue_affinity_resolve(
+      iree_hal_amdgpu_logical_device_queue_affinity_domain(logical_device),
+      queue_affinity, &resolved));
+  *out_physical_device =
+      logical_device->physical_devices[resolved.physical_device_ordinal];
+  return iree_ok_status();
+}
+
+// Normalizes command-buffer queue affinity to queues on one physical device and
+// returns the physical device ordinal whose executable kernel objects may be
+// baked into the recorded command stream.
+static iree_status_t
+iree_hal_amdgpu_logical_device_normalize_command_buffer_affinity(
+    iree_hal_amdgpu_logical_device_t* logical_device,
+    iree_hal_queue_affinity_t queue_affinity,
+    iree_hal_queue_affinity_t* out_queue_affinity,
+    iree_host_size_t* out_device_ordinal) {
+  *out_queue_affinity = 0;
+  *out_device_ordinal = 0;
+
+  return iree_hal_amdgpu_queue_affinity_normalize_for_physical_device(
+      iree_hal_amdgpu_logical_device_queue_affinity_domain(logical_device),
+      queue_affinity, out_queue_affinity, out_device_ordinal);
+}
+
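+// Pool epoch query callback: resolves |axis| to its HSA epoch signal and
+// reports whether the completed epoch has reached |epoch|. Epoch signals
+// count down from IREE_HAL_AMDGPU_EPOCH_INITIAL_VALUE as epochs retire.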
+static bool iree_hal_amdgpu_logical_device_query_pool_epoch(
+    void* user_data, iree_async_axis_t axis, uint64_t epoch) {
+  iree_hal_amdgpu_logical_device_t* logical_device =
+      (iree_hal_amdgpu_logical_device_t*)user_data;
+  hsa_signal_t epoch_signal = {0};
+  if (!iree_hal_amdgpu_epoch_signal_table_lookup(
+          logical_device->host_queue_epoch_table, axis, &epoch_signal)) {
+    return false;
+  }
+  iree_amd_signal_t* signal =
+      (iree_amd_signal_t*)(uintptr_t)epoch_signal.handle;
+  const iree_hsa_signal_value_t current_value = iree_atomic_load(
+      (iree_atomic_int64_t*)&signal->value, iree_memory_order_acquire);
+  if (IREE_UNLIKELY(current_value < 0 ||
+                    current_value > IREE_HAL_AMDGPU_EPOCH_INITIAL_VALUE)) {
+    return false;
+  }
+  const uint64_t current_epoch =
+      (uint64_t)IREE_HAL_AMDGPU_EPOCH_INITIAL_VALUE - (uint64_t)current_value;
+  return current_epoch >= epoch;
+}
+
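+// Tears down frontier tracking: deassigns per-physical-device frontiers,
+// releases the frontier tracker, and frees the epoch signal table.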
+static void iree_hal_amdgpu_logical_device_deassign_frontier(
+    iree_hal_amdgpu_logical_device_t* logical_device) {
+  IREE_TRACE_ZONE_BEGIN(z0);
+  for (iree_host_size_t i = 0; i < logical_device->physical_device_count; ++i) {
+    iree_hal_amdgpu_physical_device_deassign_frontier(
+        logical_device->physical_devices[i]);
+  }
+
+  iree_async_frontier_tracker_release(logical_device->frontier_tracker);
+  logical_device->frontier_tracker = NULL;
+  logical_device->axis = 0;
+  memset(&logical_device->topology_info, 0,
+         sizeof(logical_device->topology_info));
+
+  if (logical_device->host_queue_epoch_table) {
+    iree_allocator_free(logical_device->host_allocator,
+                        logical_device->host_queue_epoch_table);
+    logical_device->host_queue_epoch_table = NULL;
+  }
+  IREE_TRACE_ZONE_END(z0);
+}
+
 static void iree_hal_amdgpu_logical_device_error_handler(void* user_data,
                                                          iree_status_t status) {
   iree_hal_amdgpu_logical_device_t* logical_device =
@@ -173,13 +1214,191 @@
   if (!iree_atomic_compare_exchange_strong(
           &logical_device->failure_status, &current_value, (intptr_t)status,
           iree_memory_order_acq_rel, iree_memory_order_relaxed)) {
-    // Previous status was not OK; drop our new status.
-    IREE_IGNORE_ERROR(status);
+    // Previous status was not OK; the sticky slot owns only the first failure.
+    iree_status_free(status);
   }
 
   IREE_TRACE_ZONE_END(z0);
 }
 
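+// Translates logical-device options into the physical-device options
+// structure (block pools, default pool, host queue capacities, and pool
+// warmup behavior).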
+static void iree_hal_amdgpu_logical_device_translate_physical_options(
+    const iree_hal_amdgpu_logical_device_options_t* options,
+    const iree_hal_amdgpu_topology_t* topology,
+    iree_hal_amdgpu_physical_device_options_t* out_options) {
+  iree_hal_amdgpu_physical_device_options_initialize(out_options);
+  out_options->device_block_pools.small.block_size =
+      options->device_block_pools.small.block_size;
+  out_options->device_block_pools.small.initial_capacity =
+      options->device_block_pools.small.initial_capacity;
+  out_options->device_block_pools.large.block_size =
+      options->device_block_pools.large.block_size;
+  out_options->device_block_pools.large.initial_capacity =
+      options->device_block_pools.large.initial_capacity;
+  out_options->default_pool.range_length = options->default_pool.range_length;
+  out_options->default_pool.alignment = options->default_pool.alignment;
+  out_options->default_pool.frontier_capacity =
+      options->default_pool.frontier_capacity;
+  out_options->host_block_pool_initial_capacity =
+      options->preallocate_pools ? 16 : 0;
+  out_options->host_queue_count = topology->gpu_agent_queue_count;
+  out_options->host_queue_aql_capacity = options->host_queues.aql_capacity;
+  out_options->host_queue_notification_capacity =
+      options->host_queues.notification_capacity;
+  out_options->host_queue_kernarg_capacity =
+      options->host_queues.kernarg_capacity;
+  out_options->host_queue_upload_capacity =
+      options->host_queues.upload_capacity;
+  out_options->force_wait_barrier_defer = options->force_wait_barrier_defer;
+}
+
+static iree_status_t iree_hal_amdgpu_logical_device_verify_physical_options(
+    const iree_hal_amdgpu_physical_device_options_t* options,
+    const iree_hal_amdgpu_libhsa_t* libhsa,
+    const iree_hal_amdgpu_topology_t* topology) {
+  for (iree_host_size_t i = 0; i < topology->gpu_agent_count; ++i) {
+    hsa_agent_t gpu_agent = topology->gpu_agents[i];
+    hsa_agent_t cpu_agent = topology->cpu_agents[topology->gpu_cpu_map[i]];
+    IREE_RETURN_IF_ERROR(
+        iree_hal_amdgpu_physical_device_options_verify(options, libhsa,
+                                                       cpu_agent, gpu_agent),
+        "verifying GPU agent %" PRIhsz " meets required options", i);
+  }
+  return iree_ok_status();
+}
+
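+// Allocates the logical device with all trailing storage (physical device
+// pointer table, embedded per-device storage, and identifier string) in a
+// single host allocation and initializes the base structure.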
+static iree_status_t iree_hal_amdgpu_logical_device_allocate_storage(
+    iree_string_view_t identifier, const iree_hal_amdgpu_topology_t* topology,
+    iree_host_size_t physical_device_size, iree_allocator_t host_allocator,
+    iree_hal_amdgpu_logical_device_t** out_logical_device) {
+  *out_logical_device = NULL;
+
+  iree_hal_amdgpu_logical_device_t* logical_device = NULL;
+  iree_host_size_t physical_device_data_offset = 0;
+  iree_host_size_t identifier_offset = 0;
+  iree_host_size_t total_size = 0;
+  IREE_RETURN_IF_ERROR(IREE_STRUCT_LAYOUT(
+      sizeof(*logical_device), &total_size,
+      IREE_STRUCT_FIELD(topology->gpu_agent_count,
+                        iree_hal_amdgpu_physical_device_t*, NULL),
+      IREE_STRUCT_ARRAY_FIELD_ALIGNED(
+          topology->gpu_agent_count, physical_device_size, uint8_t,
+          iree_max_align_t, &physical_device_data_offset),
+      IREE_STRUCT_FIELD(identifier.size, char, &identifier_offset)));
+
+  const iree_hal_amdgpu_queue_affinity_domain_t queue_affinity_domain = {
+      .supported_affinity = IREE_HAL_QUEUE_AFFINITY_ANY,
+      .physical_device_count = topology->gpu_agent_count,
+      .queue_count_per_physical_device = topology->gpu_agent_queue_count,
+  };
+  iree_hal_queue_affinity_t logical_queue_affinity_mask = 0;
+  for (iree_host_size_t i = 0; i < topology->gpu_agent_count; ++i) {
+    iree_hal_queue_affinity_t physical_device_affinity = 0;
+    IREE_RETURN_IF_ERROR(iree_hal_amdgpu_queue_affinity_for_physical_device(
+        queue_affinity_domain, i, &physical_device_affinity));
+    iree_hal_queue_affinity_or_into(logical_queue_affinity_mask,
+                                    physical_device_affinity);
+  }
+
+  IREE_RETURN_IF_ERROR(iree_allocator_malloc(host_allocator, total_size,
+                                             (void**)&logical_device));
+  memset(logical_device, 0, total_size);
+  iree_hal_resource_initialize(&iree_hal_amdgpu_logical_device_vtable,
+                               &logical_device->resource);
+  iree_string_view_append_to_buffer(identifier, &logical_device->identifier,
+                                    (char*)logical_device + identifier_offset);
+  logical_device->host_allocator = host_allocator;
+  logical_device->failure_status = IREE_ATOMIC_VAR_INIT(0);
+  iree_atomic_store(&logical_device->epoch, 0, iree_memory_order_relaxed);
+  logical_device->next_profile_session_id = 1;
+  iree_hal_amdgpu_profile_metadata_initialize(
+      host_allocator, &logical_device->profile_metadata);
+  iree_hal_amdgpu_profile_event_streams_initialize(
+      &logical_device->profiling.event_streams);
+
+  // Setup physical device table first so failure cleanup has a valid table.
+  logical_device->physical_device_count = topology->gpu_agent_count;
+  logical_device->queue_affinity_mask = logical_queue_affinity_mask;
+  uint8_t* physical_device_base =
+      (uint8_t*)logical_device + physical_device_data_offset;
+  for (iree_host_size_t i = 0; i < logical_device->physical_device_count; ++i) {
+    logical_device->physical_devices[i] =
+        (iree_hal_amdgpu_physical_device_t*)physical_device_base;
+    physical_device_base += physical_device_size;
+  }
+
+  *out_logical_device = logical_device;
+  return iree_ok_status();
+}
+
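+// Initializes host-side resources: retains the proactor pool, initializes
+// the small/large/command-buffer host block pools, and acquires a proactor.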
+static iree_status_t iree_hal_amdgpu_logical_device_initialize_host_resources(
+    iree_hal_amdgpu_logical_device_t* logical_device,
+    const iree_hal_amdgpu_logical_device_options_t* options,
+    iree_async_proactor_pool_t* proactor_pool,
+    iree_allocator_t host_allocator) {
+  logical_device->proactor_pool = proactor_pool;
+  iree_async_proactor_pool_retain(logical_device->proactor_pool);
+
+  iree_arena_block_pool_initialize(options->host_block_pools.small.block_size,
+                                   host_allocator,
+                                   &logical_device->host_block_pools.small);
+  iree_arena_block_pool_initialize(options->host_block_pools.large.block_size,
+                                   host_allocator,
+                                   &logical_device->host_block_pools.large);
+  IREE_RETURN_IF_ERROR(iree_hal_amdgpu_aql_program_block_pool_initialize(
+      options->host_block_pools.command_buffer.usable_block_size,
+      host_allocator, &logical_device->host_block_pools.command_buffer));
+  return iree_async_proactor_pool_get(logical_device->proactor_pool, 0,
+                                      &logical_device->proactor);
+}
+
+static iree_status_t
+iree_hal_amdgpu_logical_device_initialize_system_and_allocator(
+    iree_hal_amdgpu_logical_device_t* logical_device,
+    const iree_hal_amdgpu_logical_device_options_t* options,
+    const iree_hal_amdgpu_libhsa_t* libhsa,
+    const iree_hal_amdgpu_topology_t* topology,
+    iree_allocator_t host_allocator) {
+  iree_hal_amdgpu_system_options_t system_options = {
+      .exclusive_execution = options->exclusive_execution,
+  };
+  IREE_RETURN_IF_ERROR(
+      iree_hal_amdgpu_system_allocate(libhsa, topology, system_options,
+                                      host_allocator, &logical_device->system));
+  return iree_hal_amdgpu_allocator_create(
+      logical_device, &logical_device->system->libhsa,
+      &logical_device->system->topology, host_allocator,
+      &logical_device->device_allocator);
+}
+
+static iree_status_t iree_hal_amdgpu_logical_device_initialize_physical_devices(
+    iree_hal_amdgpu_logical_device_t* logical_device,
+    const iree_hal_amdgpu_topology_t* topology,
+    const iree_hal_amdgpu_physical_device_options_t* options,
+    iree_allocator_t host_allocator) {
+  for (iree_host_size_t device_ordinal = 0;
+       device_ordinal < logical_device->physical_device_count;
+       ++device_ordinal) {
+    const iree_host_size_t host_ordinal = topology->gpu_cpu_map[device_ordinal];
+    IREE_RETURN_IF_ERROR(iree_hal_amdgpu_physical_device_initialize(
+        (iree_hal_device_t*)logical_device, logical_device->system, options,
+        logical_device->proactor, host_ordinal,
+        &logical_device->system->host_memory_pools[host_ordinal],
+        device_ordinal, host_allocator,
+        logical_device->physical_devices[device_ordinal]));
+  }
+  return iree_ok_status();
+}
+
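+// Preallocates a fixed number of blocks (16) in each host block pool.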
+static iree_status_t iree_hal_amdgpu_logical_device_warmup_host_pools(
+    iree_hal_amdgpu_logical_device_t* logical_device) {
+  IREE_RETURN_IF_ERROR(iree_arena_block_pool_preallocate(
+      &logical_device->host_block_pools.small, 16));
+  IREE_RETURN_IF_ERROR(iree_arena_block_pool_preallocate(
+      &logical_device->host_block_pools.large, 16));
+  return iree_arena_block_pool_preallocate(
+      &logical_device->host_block_pools.command_buffer, 16);
+}
+
 iree_status_t iree_hal_amdgpu_logical_device_create(
     iree_string_view_t identifier,
     const iree_hal_amdgpu_logical_device_options_t* options,
@@ -207,266 +1426,45 @@
       iree_hal_amdgpu_logical_device_options_verify(options, libhsa, topology),
       "verifying logical device options");
 
-  // Copy options relevant during construction.
-  //
-  // TODO(benvanik): maybe expose these on the public API? feels like too much
-  // churn for too little benefit - option parsing is still possible, though.
   iree_hal_amdgpu_physical_device_options_t physical_device_options = {0};
-  iree_hal_amdgpu_physical_device_options_initialize(&physical_device_options);
-  physical_device_options.device_block_pools.small.block_size =
-      options->device_block_pools.small.block_size;
-  physical_device_options.device_block_pools.small.initial_capacity =
-      options->device_block_pools.small.initial_capacity;
-  physical_device_options.device_block_pools.large.block_size =
-      options->device_block_pools.large.block_size;
-  physical_device_options.device_block_pools.large.initial_capacity =
-      options->device_block_pools.large.initial_capacity;
-  physical_device_options.host_block_pool_initial_capacity =
-      options->preallocate_pools ? 16 : 0;
-  physical_device_options.queue_count = topology->gpu_agent_queue_count;
-  if (options->trace_execution) {
-    physical_device_options.queue_options.flags |=
-        IREE_HAL_AMDGPU_QUEUE_FLAG_TRACE_EXECUTION;
-  }
-  if (options->exclusive_execution) {
-    physical_device_options.queue_options.mode |=
-        IREE_HAL_AMDGPU_QUEUE_SCHEDULING_MODE_EXCLUSIVE;
-  } else {
-    physical_device_options.queue_options.mode |=
-        IREE_HAL_AMDGPU_QUEUE_SCHEDULING_MODE_WORK_CONSERVING;
-  }
+  iree_hal_amdgpu_logical_device_translate_physical_options(
+      options, topology, &physical_device_options);
 
-  // Heterogeneous GPU agents may end up with different queue placement
-  // strategies. For any not explicitly specified as part of the options we
-  // infer the optimal placement and cache the result for subsequent use during
-  // initialization. Note that heterogeneous requires extra coordination during
-  // peer communication as though a CPU<->GPU may be able to communicate OK a
-  // GPU<->GPU pair may not be able to.
-  IREE_ASSERT_LE(topology->gpu_agent_count, 128);
-  iree_hal_amdgpu_queue_placement_t* gpu_agent_queue_placement =
-      (iree_hal_amdgpu_queue_placement_t*)iree_alloca(
-          topology->gpu_agent_count *
-          sizeof(iree_hal_amdgpu_queue_placement_t));
-  for (iree_host_size_t i = 0; i < topology->gpu_agent_count; ++i) {
-    switch (options->queue_placement) {
-      case IREE_HAL_AMDGPU_QUEUE_PLACEMENT_ANY: {
-        hsa_agent_t cpu_agent = topology->cpu_agents[topology->gpu_cpu_map[i]];
-        hsa_agent_t gpu_agent = topology->gpu_agents[i];
-        IREE_RETURN_AND_END_ZONE_IF_ERROR(
-            z0,
-            iree_hal_amdgpu_queue_infer_placement(
-                libhsa, cpu_agent, gpu_agent, &gpu_agent_queue_placement[i]),
-            "inferring optimal queue placement for GPU agent %" PRIhsz, i);
-        break;
-      }
-      default: {
-        gpu_agent_queue_placement[i] = options->queue_placement;
-        break;
-      }
-    }
-  }
-
-  // Verify all GPU agents meet the required physical device options.
-  // If they verify OK we are able to compute their total size (as each may
-  // differ) used to allocate the logical device that embeds their data
-  // structures.
-  iree_host_size_t total_physical_device_size = 0;
-  for (iree_host_size_t i = 0; i < topology->gpu_agent_count; ++i) {
-    hsa_agent_t gpu_agent = topology->gpu_agents[i];
-    hsa_agent_t cpu_agent = topology->cpu_agents[topology->gpu_cpu_map[i]];
-    physical_device_options.queue_options.placement =
-        gpu_agent_queue_placement[i];
-    IREE_RETURN_AND_END_ZONE_IF_ERROR(
-        z0,
-        iree_hal_amdgpu_physical_device_options_verify(
-            &physical_device_options, libhsa, cpu_agent, gpu_agent),
-        "verifying GPU agent %" PRIhsz " meets required options", i);
-    total_physical_device_size +=
-        iree_hal_amdgpu_physical_device_calculate_size(
-            &physical_device_options);
-  }
+  // Each embedded physical device has the same layout (and thus size) because
+  // all physical devices in one logical device share the same host-queue
+  // options; compute that common size once, then verify every GPU agent meets
+  // the required options.
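+  // Illustrative storage layout produced by the allocation helper below:
+  //   [logical_device][physical_devices[N] pointer table (aligned)]
+  //   [N * physical_device_size][identifier characters]
+  // which is why a single size computation covers every embedded device.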
+  const iree_host_size_t physical_device_size =
+      iree_hal_amdgpu_physical_device_calculate_size(&physical_device_options);
+  IREE_RETURN_AND_END_ZONE_IF_ERROR(
+      z0,
+      iree_hal_amdgpu_logical_device_verify_physical_options(
+          &physical_device_options, libhsa, topology),
+      "verifying physical device options");
 
   // Allocate the logical device and all nested physical device data structures.
   iree_hal_amdgpu_logical_device_t* logical_device = NULL;
-  const iree_host_size_t total_size =
-      sizeof(*logical_device) +
-      iree_host_align(sizeof(logical_device->physical_devices[0]) *
-                          topology->gpu_agent_count,
-                      iree_max_align_t) +
-      total_physical_device_size + identifier.size;
   IREE_RETURN_AND_END_ZONE_IF_ERROR(
-      z0, iree_allocator_malloc(host_allocator, total_size,
-                                (void**)&logical_device));
-  iree_hal_resource_initialize(&iree_hal_amdgpu_logical_device_vtable,
-                               &logical_device->resource);
-  iree_string_view_append_to_buffer(
-      identifier, &logical_device->identifier,
-      (char*)logical_device + total_size - identifier.size);
-  logical_device->host_allocator = host_allocator;
-  logical_device->failure_status = IREE_ATOMIC_VAR_INIT(0);
-
-  // Retain the proactor pool and acquire a proactor for this device.
-  logical_device->proactor_pool = create_params->proactor_pool;
-  iree_async_proactor_pool_retain(logical_device->proactor_pool);
-  logical_device->frontier_tracker = NULL;
-  logical_device->axis = 0;
-  iree_atomic_store(&logical_device->epoch, 0, iree_memory_order_relaxed);
-  iree_status_t status = iree_async_proactor_pool_get(
-      logical_device->proactor_pool, 0, &logical_device->proactor);
-
-  // Setup physical device table.
-  // This extra indirection is unfortunate but allows us to have dynamic queue
-  // counts based on options.
-  // We need to initialize this first so that any failure cleanup has a valid
-  // table.
-  logical_device->physical_device_count = topology->gpu_agent_count;
-  uint8_t* physical_device_base =
-      (uint8_t*)logical_device + sizeof(*logical_device) +
-      iree_host_align(sizeof(logical_device->physical_devices[0]) *
-                          topology->gpu_agent_count,
-                      iree_max_align_t);
-  for (iree_host_size_t i = 0, queue_index = 0;
-       i < logical_device->physical_device_count; ++i) {
-    logical_device->physical_devices[i] =
-        (iree_hal_amdgpu_physical_device_t*)physical_device_base;
-    physical_device_options.queue_options.placement =
-        gpu_agent_queue_placement[i];
-    physical_device_base += iree_hal_amdgpu_physical_device_calculate_size(
-        &physical_device_options);
-    for (iree_host_size_t j = 0; j < topology->gpu_agent_queue_count;
-         ++j, ++queue_index) {
-      iree_hal_queue_affinity_or_into(logical_device->queue_affinity_mask,
-                                      1ull << queue_index);
-    }
-  }
-
-  // Block pools used by subsequent data structures.
-  iree_arena_block_pool_initialize(options->host_block_pools.small.block_size,
-                                   host_allocator,
-                                   &logical_device->host_block_pools.small);
-  iree_arena_block_pool_initialize(options->host_block_pools.large.block_size,
-                                   host_allocator,
-                                   &logical_device->host_block_pools.large);
-
-  // Instantiate system container for agents used by the logical device. Loads
-  // fixed per-agent resources like the device library.
-  iree_hal_amdgpu_system_options_t system_options = {
-      .trace_execution = options->trace_execution,
-      .exclusive_execution = options->exclusive_execution,
-  };
+      z0, iree_hal_amdgpu_logical_device_allocate_storage(
+              identifier, topology, physical_device_size, host_allocator,
+              &logical_device));
+  iree_status_t status =
+      iree_hal_amdgpu_logical_device_initialize_host_resources(
+          logical_device, options, create_params->proactor_pool,
+          host_allocator);
   if (iree_status_is_ok(status)) {
-    status = iree_hal_amdgpu_system_allocate(libhsa, topology, system_options,
-                                             host_allocator,
-                                             &logical_device->system);
+    status = iree_hal_amdgpu_logical_device_initialize_system_and_allocator(
+        logical_device, options, libhsa, topology, host_allocator);
   }
-  iree_hal_amdgpu_system_t* system = logical_device->system;
-
-  // Signal used for asynchronous initialization.
-  // Incremented by each initialization dispatch issued on any agent and
-  // decremented as they complete. When this reaches 0 all physical devices have
-  // completed initialization.
-  hsa_signal_t initialization_signal = {0};
   if (iree_status_is_ok(status)) {
-    status = iree_hsa_amd_signal_create(
-        IREE_LIBHSA(&system->libhsa), 0ull, topology->all_agent_count,
-        topology->all_agents, 0, &initialization_signal);
-  }
-
-  // TODO(benvanik): pass device handles and pool configuration to the
-  // allocator. Some implementations may share allocators across multiple
-  // devices created from the same driver.
-  // if (iree_status_is_ok(status)) {
-  //   status = iree_hal_amdgpu_allocator_create(
-  //       host_allocator, &logical_device->device_allocator);
-  // }
-
-  // Initialize a pool for all internal semaphores across all agents.
-  if (iree_status_is_ok(status)) {
-    iree_hal_amdgpu_semaphore_options_t semaphore_options = {
-        .wait_active_for_ns = options->wait_active_for_ns,
-    };
-    status = iree_hal_amdgpu_semaphore_pool_initialize(
-        &system->libhsa, &system->topology,
-        IREE_HAL_AMDGPU_SEMAPHORE_POOL_DEFAULT_BLOCK_CAPACITY,
-        semaphore_options, IREE_HAL_SEMAPHORE_FLAG_DEFAULT, host_allocator,
-        system->host_memory_pools[0].fine_pool,
-        &logical_device->semaphore_pool);
-  }
-
-  // Initialize a pool for all transient buffer handles across all agents.
-  //
-  // TODO(benvanik): possibly make this per NUMA node/CPU agent. Devices will
-  // be accessing the allocation handles and we don't want that to make multiple
-  // hops if we can avoid it.
-  if (iree_status_is_ok(status)) {
-    const iree_hal_buffer_placement_t placement = {
-        .device = (iree_hal_device_t*)logical_device,
-        .queue_affinity = IREE_HAL_QUEUE_AFFINITY_ANY,
-    };
-    status = iree_hal_amdgpu_buffer_pool_initialize(
-        &system->libhsa, &system->topology, placement,
-        IREE_HAL_AMDGPU_BUFFER_POOL_DEFAULT_BLOCK_CAPACITY, host_allocator,
-        system->host_memory_pools[0].fine_pool, &logical_device->buffer_pool);
-  }
-
-  // Route asynchronous errors back to the logical device so we have a single
-  // place to hold on to the error and service it back to users.
-  iree_hal_amdgpu_error_callback_t error_callback = {
-      .fn = iree_hal_amdgpu_logical_device_error_handler,
-      .user_data = logical_device,
-  };
-
-  // Initialize physical devices for each GPU agent in the topology.
-  // Their order matches the original but each may represent more than one
-  // logical queue affinity bit.
-  if (iree_status_is_ok(status)) {
-    for (iree_host_size_t device_ordinal = 0;
-         device_ordinal < logical_device->physical_device_count;
-         ++device_ordinal) {
-      physical_device_options.queue_options.placement =
-          gpu_agent_queue_placement[device_ordinal];
-      const iree_host_size_t host_ordinal =
-          topology->gpu_cpu_map[device_ordinal];
-      status = iree_hal_amdgpu_physical_device_initialize(
-          system, &physical_device_options, host_ordinal,
-          &system->host_memory_pools[host_ordinal], device_ordinal,
-          &logical_device->buffer_pool, error_callback, initialization_signal,
-          host_allocator, logical_device->physical_devices[device_ordinal]);
-      if (!iree_status_is_ok(status)) break;
-    }
+    status = iree_hal_amdgpu_logical_device_initialize_physical_devices(
+        logical_device, topology, &physical_device_options, host_allocator);
   }
 
   // If requested then warm up pools that we expect to grow on first use of
   // the backend. The first use may need more than the warmup provides here
   // but that's ok - users can warm up further themselves if they want.
-  if (options->preallocate_pools) {
-    if (iree_status_is_ok(status)) {
-      status = iree_arena_block_pool_preallocate(
-          &logical_device->host_block_pools.small, 16);
-    }
-    if (iree_status_is_ok(status)) {
-      status = iree_arena_block_pool_preallocate(
-          &logical_device->host_block_pools.large, 16);
-    }
-    if (iree_status_is_ok(status)) {
-      status = iree_hal_amdgpu_semaphore_pool_preallocate(
-          &logical_device->semaphore_pool, 256);
-    }
-    if (iree_status_is_ok(status)) {
-      status = iree_hal_amdgpu_buffer_pool_preallocate(
-          &logical_device->buffer_pool, 256);
-    }
-  }
-
-  // Wait for all initialization that may still be in progress to complete.
-  // This ensures we don't tear down data structures that may still be in use
-  // on a device doing asynchronous initialization.
-  if (initialization_signal.handle) {
-    iree_hsa_signal_wait_scacquire(IREE_LIBHSA(libhsa), initialization_signal,
-                                   HSA_SIGNAL_CONDITION_LT, 1u, UINT64_MAX,
-                                   HSA_WAIT_STATE_BLOCKED);
-    IREE_IGNORE_ERROR(
-        iree_hsa_signal_destroy(IREE_LIBHSA(libhsa), initialization_signal));
+  if (iree_status_is_ok(status) && options->preallocate_pools) {
+    status = iree_hal_amdgpu_logical_device_warmup_host_pools(logical_device);
   }
 
   if (iree_status_is_ok(status)) {
@@ -485,28 +1483,64 @@
   iree_allocator_t host_allocator = iree_hal_device_host_allocator(base_device);
   IREE_TRACE_ZONE_BEGIN(z0);
 
+  iree_hal_amdgpu_profile_counter_session_t* counter_session =
+      logical_device->profiling.counter_session;
+  iree_hal_amdgpu_profile_trace_session_t* trace_session =
+      logical_device->profiling.trace_session;
+  if (trace_session) {
+    for (iree_host_size_t i = 0; i < logical_device->physical_device_count;
+         ++i) {
+      iree_hal_amdgpu_physical_device_t* physical_device =
+          logical_device->physical_devices[i];
+      for (iree_host_size_t j = 0; j < physical_device->host_queue_count; ++j) {
+        iree_hal_amdgpu_host_queue_disable_profile_traces(
+            &physical_device->host_queues[j]);
+      }
+    }
+    logical_device->profiling.trace_session = NULL;
+    iree_hal_amdgpu_profile_trace_session_free(trace_session);
+  }
+  if (counter_session) {
+    for (iree_host_size_t i = 0; i < logical_device->physical_device_count;
+         ++i) {
+      iree_hal_amdgpu_physical_device_t* physical_device =
+          logical_device->physical_devices[i];
+      for (iree_host_size_t j = 0; j < physical_device->host_queue_count; ++j) {
+        iree_hal_amdgpu_host_queue_disable_profile_counters(
+            &physical_device->host_queues[j]);
+      }
+    }
+    logical_device->profiling.counter_session = NULL;
+    iree_hal_amdgpu_profile_counter_session_free(counter_session);
+  }
+  iree_hal_amdgpu_logical_device_reset_profile_options(logical_device);
+  logical_device->profiling.session_id = 0;
+  iree_hal_amdgpu_profile_event_streams_deinitialize(
+      &logical_device->profiling.event_streams, logical_device->host_allocator);
+
+  iree_hal_amdgpu_logical_device_deassign_frontier(logical_device);
+
   // Devices may hold allocations and need to be cleaned up first.
   for (iree_host_size_t i = 0; i < logical_device->physical_device_count; ++i) {
     iree_hal_amdgpu_physical_device_deinitialize(
         logical_device->physical_devices[i]);
   }
 
-  // All buffers must have been returned to the HAL device by this point.
-  iree_hal_amdgpu_buffer_pool_deinitialize(&logical_device->buffer_pool);
-
-  // All semaphores must have been returned to the HAL device by this point.
-  iree_hal_amdgpu_semaphore_pool_deinitialize(&logical_device->semaphore_pool);
-
   iree_hal_allocator_release(logical_device->device_allocator);
   iree_hal_channel_provider_release(logical_device->channel_provider);
 
   // This may unload HSA; must come after all resources are released.
   iree_hal_amdgpu_system_free(logical_device->system);
 
+  iree_hal_amdgpu_profile_metadata_deinitialize(
+      &logical_device->profile_metadata);
+
   // Note that these may be used by other child data types and must be freed
   // last.
   iree_arena_block_pool_deinitialize(&logical_device->host_block_pools.small);
   iree_arena_block_pool_deinitialize(&logical_device->host_block_pools.large);
+  iree_arena_block_pool_deinitialize(
+      &logical_device->host_block_pools.command_buffer);
 
   iree_async_proactor_pool_release(logical_device->proactor_pool);
 
@@ -562,14 +1596,10 @@
   // Release pooled resources from each physical device. These may return items
   // back to the parent logical device pools.
   for (iree_host_size_t i = 0; i < logical_device->physical_device_count; ++i) {
-    iree_hal_amdgpu_physical_device_trim(logical_device->physical_devices[i]);
+    IREE_RETURN_IF_ERROR(iree_hal_amdgpu_physical_device_trim(
+        logical_device->physical_devices[i]));
   }
 
-  // Release pooled resources that aren't required for any currently live uses.
-  // May release device memory.
-  iree_hal_amdgpu_buffer_pool_trim(&logical_device->buffer_pool);
-  iree_hal_amdgpu_semaphore_pool_trim(&logical_device->semaphore_pool);
-
   // Trim the allocator pools, if any.
   IREE_RETURN_IF_ERROR(
       iree_hal_allocator_trim(logical_device->device_allocator));
@@ -577,6 +1607,7 @@
   // Trim host pools.
   iree_arena_block_pool_trim(&logical_device->host_block_pools.small);
   iree_arena_block_pool_trim(&logical_device->host_block_pools.large);
+  iree_arena_block_pool_trim(&logical_device->host_block_pools.command_buffer);
 
   return iree_ok_status();
 }
@@ -638,35 +1669,45 @@
       iree_hal_amdgpu_logical_device_cast(base_device);
   memset(out_capabilities, 0, sizeof(*out_capabilities));
 
-  // For single-GPU logical devices, query the first physical device.
-  // TODO(multi-gpu): for multi-GPU logical devices, aggregate capabilities from
-  // all physical devices (take intersection of supported features, lowest
-  // common denominator for limits, etc.).
   if (logical_device->physical_device_count == 0) {
     return iree_make_status(
         IREE_STATUS_INTERNAL,
         "logical device has no physical devices (initialization incomplete)");
   }
 
+  // A multi-GPU logical device is a composite HAL device. Generic HAL topology
+  // has only one node for it, so do not expose a physical-device-0 identity as
+  // though it represented the entire composite. Exact internal physical device
+  // identity is reported through AMDGPU profile/device metadata and queue
+  // affinity records.
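+  // e.g. a logical device spanning two HSA agents (such as the two GCDs of an
+  // MI250X-class package) reports no single UUID or agent handle; only
+  // single-agent logical devices expose those identity fields.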
+  const bool is_composite_device = logical_device->physical_device_count > 1;
   iree_hal_amdgpu_physical_device_t* physical_device =
       logical_device->physical_devices[0];
-  hsa_agent_t gpu_agent = physical_device->device_agent;
-  const iree_hal_amdgpu_libhsa_t* libhsa = &logical_device->system->libhsa;
 
-  // Query device UUID (32-byte from HSA, truncate to 16 for HAL).
-  char uuid_buffer[32];
-  memset(uuid_buffer, 0, sizeof(uuid_buffer));
-  IREE_RETURN_IF_ERROR(iree_hsa_agent_get_info(
-      IREE_LIBHSA(libhsa), gpu_agent, (hsa_agent_info_t)HSA_AMD_AGENT_INFO_UUID,
-      uuid_buffer));
-  memcpy(out_capabilities->physical_device_uuid, uuid_buffer, 16);
-  out_capabilities->has_physical_device_uuid = true;
+  memset(out_capabilities->physical_device_uuid, 0,
+         sizeof(out_capabilities->physical_device_uuid));
+  if (!is_composite_device && physical_device->has_physical_device_uuid) {
+    memcpy(out_capabilities->physical_device_uuid,
+           physical_device->physical_device_uuid,
+           sizeof(out_capabilities->physical_device_uuid));
+    out_capabilities->has_physical_device_uuid = true;
+  }
 
-  // Query NUMA node from HSA.
-  uint32_t numa_node;
-  IREE_RETURN_IF_ERROR(iree_hsa_agent_get_info(
-      IREE_LIBHSA(libhsa), gpu_agent, HSA_AGENT_INFO_NODE, &numa_node));
-  out_capabilities->numa_node = (uint8_t)numa_node;
+  // Report a NUMA affinity only when the composite has a single nearest host
+  // node that fits the generic HAL uint8_t representation. Mixed-NUMA
+  // composites intentionally leave the default 0 because generic topology
+  // cannot express one logical device spanning multiple CPU NUMA nodes.
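+  // e.g. a composite whose GPUs all map to host NUMA node 1 reports node 1;
+  // a composite split across nodes 0 and 1 keeps the default of 0.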
+  uint32_t host_numa_node = physical_device->host_numa_node;
+  bool has_representative_numa_node = host_numa_node <= UINT8_MAX;
+  for (iree_host_size_t i = 1; i < logical_device->physical_device_count &&
+                               has_representative_numa_node;
+       ++i) {
+    has_representative_numa_node =
+        logical_device->physical_devices[i]->host_numa_node == host_numa_node;
+  }
+  if (has_representative_numa_node) {
+    out_capabilities->numa_node = (uint8_t)host_numa_node;
+  }
 
   // External handle types (DMA-BUF support from system info).
   if (logical_device->system->info.dmabuf_supported) {
@@ -676,13 +1717,46 @@
         IREE_HAL_TOPOLOGY_HANDLE_TYPE_DMA_BUF;
   }
 
-  // Capability flags.
-  if (logical_device->system->info.svm_accessible_by_default) {
-    out_capabilities->flags |= IREE_HAL_DEVICE_CAPABILITY_UNIFIED_MEMORY;
+  // Memory-system capability flags are the intersection across the physical
+  // devices in this logical device. SVM/pageable-memory facts are distinct
+  // from peer-pool addressability; refine_topology_edge owns the latter.
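+  // For example, assuming the selector preserves the previous inline check, a
+  // composite where only one physical device has SVM accessible by default
+  // would not advertise IREE_HAL_DEVICE_CAPABILITY_UNIFIED_MEMORY.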
+  iree_hal_device_capability_bits_t memory_system_flags =
+      iree_hal_amdgpu_select_memory_system_device_capability_flags(
+          &physical_device->memory_system);
+  for (iree_host_size_t i = 1; i < logical_device->physical_device_count; ++i) {
+    memory_system_flags &=
+        iree_hal_amdgpu_select_memory_system_device_capability_flags(
+            &logical_device->physical_devices[i]->memory_system);
   }
+  out_capabilities->flags |= memory_system_flags;
 
-  // Driver handle (HSA agent handle for same-driver refinement).
-  out_capabilities->driver_device_handle = (uintptr_t)gpu_agent.handle;
+  // AMDGPU semaphores are native async timeline semaphores (not binary
+  // emulation).
+  out_capabilities->flags |= IREE_HAL_DEVICE_CAPABILITY_TIMELINE_SEMAPHORES;
+
+  // Fine-grained memory provides host coherency without explicit flushes.
+  // Coarse-grained memory requires fences, but the driver manages that
+  // transparently.
+  out_capabilities->flags |= IREE_HAL_DEVICE_CAPABILITY_HOST_COHERENT;
+
+  // All AMDGPU devices support device-scope atomics. System-scope atomics are
+  // supported on fine-grained memory when callers explicitly opt into
+  // host-visible placement.
+  out_capabilities->flags |= IREE_HAL_DEVICE_CAPABILITY_ATOMIC_SCOPE_DEVICE;
+  out_capabilities->flags |= IREE_HAL_DEVICE_CAPABILITY_ATOMIC_SCOPE_SYSTEM;
+
+  // All AMD GPUs support peer-to-peer DMA (through XGMI or PCIe). The actual
+  // access mode for a specific GPU pair is determined by
+  // refine_topology_edge — here we declare the capability in principle.
+  out_capabilities->flags |= IREE_HAL_DEVICE_CAPABILITY_P2P_COPY;
+
+  // Driver handle (HSA agent handle for same-driver refinement). Composite
+  // devices intentionally leave this unset: a single HSA agent handle would
+  // make generic topology alias detection treat a composite as one GPU.
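+  // Conversely, two single-agent logical devices created over the same HSA
+  // agent share this handle value and can be recognized as aliases.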
+  if (!is_composite_device) {
+    out_capabilities->driver_device_handle =
+        (uintptr_t)physical_device->device_agent.handle;
+  }
 
   return iree_ok_status();
 }
@@ -694,9 +1768,248 @@
   return &logical_device->topology_info;
 }
 
+// Maximum number of HSA memory-pool link hops we will stack-allocate link
+// info entries for.
+#define IREE_HAL_AMDGPU_MAX_TOPOLOGY_LINK_HOPS 16
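+// (Typical agent-to-agent paths report only a few hops; paths longer than
+// this bound fail with OUT_OF_RANGE below instead of overflowing the stack
+// buffer.)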
+
+typedef struct iree_hal_amdgpu_topology_edge_aggregate_t {
+  // Physical capability facts produced by cross-pair aggregation.
+  struct {
+    // Positive capabilities conservatively intersected across every pair.
+    iree_hal_topology_capability_t guaranteed;
+    // Requirement bits unioned across pairs because any pair can constrain use.
+    iree_hal_topology_capability_t required;
+  } physical_capabilities;
+  // Worst non-coherent read mode across all physical pairs.
+  iree_hal_topology_interop_mode_t noncoherent_read_mode;
+  // Worst non-coherent write mode across all physical pairs.
+  iree_hal_topology_interop_mode_t noncoherent_write_mode;
+  // Worst coherent read mode across all physical pairs.
+  iree_hal_topology_interop_mode_t coherent_read_mode;
+  // Worst coherent write mode across all physical pairs.
+  iree_hal_topology_interop_mode_t coherent_write_mode;
+  // Worst link class across all physical pairs.
+  iree_hal_topology_link_class_t link_class;
+  // Worst copy-cost class across all physical pairs.
+  uint8_t copy_cost;
+  // Worst latency class across all physical pairs.
+  uint8_t latency_class;
+  // Worst normalized NUMA distance across all physical pairs.
+  uint8_t numa_distance;
+} iree_hal_amdgpu_topology_edge_aggregate_t;
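+// Worked example of the worst-of aggregation (values illustrative): if pair
+// (GPU0,GPU2) reports an XGMI-class link with copy_cost 1 and pair
+// (GPU1,GPU2) reports a PCIe-class link with copy_cost 3, the composite edge
+// carries the PCIe link class and copy_cost 3, because the scheduler must
+// plan for the weaker path.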
+
+static iree_status_t iree_hal_amdgpu_query_physical_topology_edge(
+    const iree_hal_amdgpu_libhsa_t* libhsa,
+    const iree_hal_amdgpu_physical_device_t* source_physical_device,
+    const iree_hal_amdgpu_physical_device_t* destination_physical_device,
+    iree_hal_amdgpu_physical_topology_edge_t* out_physical_edge) {
+  hsa_agent_t source_agent = source_physical_device->device_agent;
+  hsa_agent_t destination_agent = destination_physical_device->device_agent;
+
+  // Find both memory pool types on the destination agent. Not all devices
+  // expose both pool types; missing pools are treated as NEVER_ALLOWED for that
+  // pool kind, but an agent with no global pool at all is not a usable topology
+  // node.
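+  // e.g. an agent exposing only a fine-grained global pool still yields a
+  // usable edge; its coarse access mode simply remains NEVER_ALLOWED.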
+  hsa_amd_memory_pool_t dst_coarse_pool = {0};
+  bool has_coarse_pool = iree_hal_amdgpu_try_find_coarse_global_memory_pool(
+      libhsa, destination_agent, &dst_coarse_pool);
+  hsa_amd_memory_pool_t dst_fine_pool = {0};
+  bool has_fine_pool = iree_hal_amdgpu_try_find_fine_global_memory_pool(
+      libhsa, destination_agent, &dst_fine_pool);
+  if (!has_coarse_pool && !has_fine_pool) {
+    return iree_make_status(
+        IREE_STATUS_UNAVAILABLE,
+        "destination agent has neither coarse nor fine global memory pool");
+  }
+
+  iree_hal_amdgpu_physical_topology_edge_selection_t selection = {
+      .memory_access =
+          {
+              .coarse = HSA_AMD_MEMORY_POOL_ACCESS_NEVER_ALLOWED,
+              .fine = HSA_AMD_MEMORY_POOL_ACCESS_NEVER_ALLOWED,
+          },
+  };
+  if (has_coarse_pool) {
+    IREE_RETURN_IF_ERROR(iree_hsa_amd_agent_memory_pool_get_info(
+        IREE_LIBHSA(libhsa), source_agent, dst_coarse_pool,
+        HSA_AMD_AGENT_MEMORY_POOL_INFO_ACCESS,
+        &selection.memory_access.coarse));
+  }
+  if (has_fine_pool) {
+    IREE_RETURN_IF_ERROR(iree_hsa_amd_agent_memory_pool_get_info(
+        IREE_LIBHSA(libhsa), source_agent, dst_fine_pool,
+        HSA_AMD_AGENT_MEMORY_POOL_INFO_ACCESS, &selection.memory_access.fine));
+  }
+
+  // Query link hop count and topology. The link topology describes the
+  // interconnect between agents and is the same regardless of pool granularity;
+  // use whichever pool is present, preferring coarse-grained memory.
+  hsa_amd_memory_pool_t link_query_pool =
+      has_coarse_pool ? dst_coarse_pool : dst_fine_pool;
+  uint32_t hop_count = 0;
+  IREE_RETURN_IF_ERROR(iree_hsa_amd_agent_memory_pool_get_info(
+      IREE_LIBHSA(libhsa), source_agent, link_query_pool,
+      HSA_AMD_AGENT_MEMORY_POOL_INFO_NUM_LINK_HOPS, &hop_count));
+  if (hop_count > IREE_HAL_AMDGPU_MAX_TOPOLOGY_LINK_HOPS) {
+    return iree_make_status(
+        IREE_STATUS_OUT_OF_RANGE,
+        "HSA reports %" PRIu32 " link hops between GPU agents (max %" PRIhsz
+        ")",
+        hop_count, (iree_host_size_t)IREE_HAL_AMDGPU_MAX_TOPOLOGY_LINK_HOPS);
+  }
+
+  hsa_amd_memory_pool_link_info_t
+      link_hops[IREE_HAL_AMDGPU_MAX_TOPOLOGY_LINK_HOPS];
+  memset(link_hops, 0, sizeof(link_hops[0]) * hop_count);
+  if (hop_count > 0) {
+    // The LINK_INFO query writes exactly hop_count entries into the caller's
+    // buffer with no separate size parameter.
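+    // This is why hop_count was validated against the stack buffer capacity
+    // above: a short buffer here would be written out of bounds.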
+    IREE_RETURN_IF_ERROR(iree_hsa_amd_agent_memory_pool_get_info(
+        IREE_LIBHSA(libhsa), source_agent, link_query_pool,
+        HSA_AMD_AGENT_MEMORY_POOL_INFO_LINK_INFO, link_hops));
+  }
+
+  selection.link.hops = link_hops;
+  selection.link.count = hop_count;
+  return iree_hal_amdgpu_select_physical_topology_edge(&selection,
+                                                       out_physical_edge);
+}
+
+static void iree_hal_amdgpu_topology_edge_aggregate_initialize(
+    iree_hal_topology_edge_t edge,
+    iree_hal_amdgpu_topology_edge_aggregate_t* out_aggregate) {
+  // Start physical facts at their best value so the aggregate can upgrade an
+  // imprecise base edge and then monotonically worsen with each pair; the
+  // NUMA distance instead seeds from the incoming base edge (see below).
+  // Per-pair DISALLOWED_BY_DEFAULT access remains copy-only until an
+  // allocation policy proves that direct access was explicitly granted.
+  out_aggregate->physical_capabilities.guaranteed =
+      IREE_HAL_TOPOLOGY_CAPABILITY_P2P_COPY |
+      IREE_HAL_TOPOLOGY_CAPABILITY_PEER_COHERENT |
+      IREE_HAL_TOPOLOGY_CAPABILITY_ATOMIC_DEVICE |
+      IREE_HAL_TOPOLOGY_CAPABILITY_ATOMIC_SYSTEM;
+  out_aggregate->physical_capabilities.required =
+      IREE_HAL_TOPOLOGY_CAPABILITY_NONE;
+  out_aggregate->noncoherent_read_mode = IREE_HAL_TOPOLOGY_INTEROP_MODE_NATIVE;
+  out_aggregate->noncoherent_write_mode = IREE_HAL_TOPOLOGY_INTEROP_MODE_NATIVE;
+  out_aggregate->coherent_read_mode = IREE_HAL_TOPOLOGY_INTEROP_MODE_NATIVE;
+  out_aggregate->coherent_write_mode = IREE_HAL_TOPOLOGY_INTEROP_MODE_NATIVE;
+  out_aggregate->link_class = IREE_HAL_TOPOLOGY_LINK_CLASS_SAME_DIE;
+  out_aggregate->copy_cost = 0;
+  out_aggregate->latency_class = 0;
+  out_aggregate->numa_distance = iree_hal_topology_edge_numa_distance(edge.lo);
+}
+
+static void iree_hal_amdgpu_topology_edge_aggregate_include(
+    const iree_hal_amdgpu_physical_topology_edge_t* physical_edge,
+    iree_hal_amdgpu_topology_edge_aggregate_t* aggregate) {
+  aggregate->physical_capabilities.guaranteed &=
+      physical_edge->capabilities.guaranteed;
+  aggregate->physical_capabilities.required |=
+      physical_edge->capabilities.required;
+
+  aggregate->noncoherent_read_mode = iree_max(
+      aggregate->noncoherent_read_mode, physical_edge->modes.noncoherent_read);
+  aggregate->noncoherent_write_mode =
+      iree_max(aggregate->noncoherent_write_mode,
+               physical_edge->modes.noncoherent_write);
+  aggregate->coherent_read_mode = iree_max(aggregate->coherent_read_mode,
+                                           physical_edge->modes.coherent_read);
+  aggregate->coherent_write_mode = iree_max(
+      aggregate->coherent_write_mode, physical_edge->modes.coherent_write);
+
+  if (physical_edge->link.link_class > aggregate->link_class) {
+    aggregate->link_class = physical_edge->link.link_class;
+  }
+  if (physical_edge->link.copy_cost > aggregate->copy_cost) {
+    aggregate->copy_cost = physical_edge->link.copy_cost;
+  }
+  if (physical_edge->link.latency_class > aggregate->latency_class) {
+    aggregate->latency_class = physical_edge->link.latency_class;
+  }
+  if (physical_edge->link.numa_distance > aggregate->numa_distance) {
+    aggregate->numa_distance = physical_edge->link.numa_distance;
+  }
+}
+
+static void iree_hal_amdgpu_topology_edge_apply_aggregate(
+    const iree_hal_amdgpu_topology_edge_aggregate_t* aggregate,
+    iree_hal_topology_edge_t* edge) {
+  edge->lo = iree_hal_topology_edge_set_buffer_read_mode_noncoherent(
+      edge->lo, aggregate->noncoherent_read_mode);
+  edge->lo = iree_hal_topology_edge_set_buffer_write_mode_noncoherent(
+      edge->lo, aggregate->noncoherent_write_mode);
+  edge->lo = iree_hal_topology_edge_set_buffer_read_mode_coherent(
+      edge->lo, aggregate->coherent_read_mode);
+  edge->lo = iree_hal_topology_edge_set_buffer_write_mode_coherent(
+      edge->lo, aggregate->coherent_write_mode);
+
+  edge->lo =
+      iree_hal_topology_edge_set_link_class(edge->lo, aggregate->link_class);
+  edge->lo =
+      iree_hal_topology_edge_set_copy_cost(edge->lo, aggregate->copy_cost);
+  edge->lo = iree_hal_topology_edge_set_latency_class(edge->lo,
+                                                      aggregate->latency_class);
+  edge->lo = iree_hal_topology_edge_set_numa_distance(edge->lo,
+                                                      aggregate->numa_distance);
+
+  iree_hal_topology_capability_t capabilities =
+      iree_hal_topology_edge_capability_flags(edge->lo);
+  const iree_hal_topology_capability_t physical_guaranteed_capability_mask =
+      IREE_HAL_TOPOLOGY_CAPABILITY_P2P_COPY |
+      IREE_HAL_TOPOLOGY_CAPABILITY_PEER_COHERENT |
+      IREE_HAL_TOPOLOGY_CAPABILITY_ATOMIC_DEVICE |
+      IREE_HAL_TOPOLOGY_CAPABILITY_ATOMIC_SYSTEM;
+  const iree_hal_topology_capability_t physical_required_capability_mask =
+      IREE_HAL_TOPOLOGY_CAPABILITY_PEER_ACCESS_REQUIRES_GRANT;
+  capabilities &= ~(physical_guaranteed_capability_mask |
+                    physical_required_capability_mask);
+  capabilities |= aggregate->physical_capabilities.guaranteed &
+                  physical_guaranteed_capability_mask;
+  capabilities |= aggregate->physical_capabilities.required &
+                  physical_required_capability_mask;
+  edge->lo =
+      iree_hal_topology_edge_set_capability_flags(edge->lo, capabilities);
+}
+
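+// Refines a generic HAL topology edge between two AMDGPU logical devices by
+// querying HSA pool access and link info for every physical source/dest pair
+// and folding the worst-case results back into the edge.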
 static iree_status_t iree_hal_amdgpu_logical_device_refine_topology_edge(
     iree_hal_device_t* src_device, iree_hal_device_t* dst_device,
     iree_hal_topology_edge_t* edge) {
+  iree_hal_amdgpu_logical_device_t* src_logical =
+      iree_hal_amdgpu_logical_device_cast(src_device);
+  iree_hal_amdgpu_logical_device_t* dst_logical =
+      iree_hal_amdgpu_logical_device_cast(dst_device);
+  const iree_hal_amdgpu_libhsa_t* libhsa = &src_logical->system->libhsa;
+  if (src_logical->physical_device_count == 0 ||
+      dst_logical->physical_device_count == 0) {
+    return iree_make_status(
+        IREE_STATUS_INTERNAL,
+        "cannot refine AMDGPU topology edge with an empty physical device set");
+  }
+
+  iree_hal_amdgpu_topology_edge_aggregate_t aggregate;
+  iree_hal_amdgpu_topology_edge_aggregate_initialize(*edge, &aggregate);
+
+  // A composite logical device has one generic HAL topology node but several
+  // physical HSA agents. The generic edge must be valid for any source/dest
+  // physical pair because the scheduler cannot encode a subset-specific edge.
+  for (iree_host_size_t source_index = 0;
+       source_index < src_logical->physical_device_count; ++source_index) {
+    const iree_hal_amdgpu_physical_device_t* source_physical_device =
+        src_logical->physical_devices[source_index];
+    for (iree_host_size_t destination_index = 0;
+         destination_index < dst_logical->physical_device_count;
+         ++destination_index) {
+      const iree_hal_amdgpu_physical_device_t* destination_physical_device =
+          dst_logical->physical_devices[destination_index];
+      iree_hal_amdgpu_physical_topology_edge_t physical_edge;
+      IREE_RETURN_IF_ERROR(iree_hal_amdgpu_query_physical_topology_edge(
+          libhsa, source_physical_device, destination_physical_device,
+          &physical_edge));
+      iree_hal_amdgpu_topology_edge_aggregate_include(&physical_edge,
+                                                      &aggregate);
+    }
+  }
+
+  iree_hal_amdgpu_topology_edge_apply_aggregate(&aggregate, edge);
   return iree_ok_status();
 }
 
@@ -705,53 +2018,61 @@
     const iree_hal_device_topology_info_t* topology_info) {
   iree_hal_amdgpu_logical_device_t* logical_device =
       iree_hal_amdgpu_logical_device_cast(base_device);
-  if (logical_device->frontier_tracker) {
-    iree_async_frontier_tracker_retire_axis(
-        logical_device->frontier_tracker, logical_device->axis,
-        iree_make_status(IREE_STATUS_CANCELLED,
-                         "AMDGPU logical device topology assignment reset"));
-    logical_device->frontier_tracker = NULL;
-    logical_device->axis = 0;
+  if (!topology_info) {
+    iree_hal_amdgpu_logical_device_deassign_frontier(logical_device);
+    return iree_ok_status();
   }
-  memset(&logical_device->topology_info, 0,
-         sizeof(logical_device->topology_info));
-  if (!topology_info) return iree_ok_status();
+  IREE_TRACE_ZONE_BEGIN(z0);
+  iree_hal_amdgpu_system_t* system = logical_device->system;
 
-  iree_async_frontier_tracker_t* frontier_tracker =
-      topology_info->frontier.tracker;
-  const iree_async_axis_t axis = topology_info->frontier.base_axis;
-  if (frontier_tracker && axis != 0) {
-    IREE_RETURN_IF_ERROR(iree_async_frontier_tracker_register_axis(
-        frontier_tracker, axis, /*semaphore=*/NULL));
+  const uint8_t device_count = (uint8_t)system->topology.gpu_agent_count;
+  const uint8_t queue_stride = (uint8_t)system->topology.gpu_agent_queue_count;
+  const iree_host_size_t table_size =
+      iree_hal_amdgpu_epoch_signal_table_size(device_count, queue_stride);
+  iree_status_t status =
+      iree_allocator_malloc(logical_device->host_allocator, table_size,
+                            (void**)&logical_device->host_queue_epoch_table);
+  if (iree_status_is_ok(status)) {
+    iree_hal_amdgpu_epoch_signal_table_initialize(
+        logical_device->host_queue_epoch_table,
+        iree_async_axis_session(topology_info->frontier.base_axis),
+        iree_async_axis_machine(topology_info->frontier.base_axis),
+        device_count, queue_stride);
   }
-  logical_device->topology_info = *topology_info;
-  logical_device->frontier_tracker = frontier_tracker;
-  logical_device->axis = axis;
-  iree_atomic_store(&logical_device->epoch, 0, iree_memory_order_relaxed);
-  return iree_ok_status();
+
+  for (iree_host_size_t device_ordinal = 0;
+       device_ordinal < logical_device->physical_device_count &&
+       iree_status_is_ok(status);
+       ++device_ordinal) {
+    const iree_host_size_t host_ordinal =
+        system->topology.gpu_cpu_map[device_ordinal];
+    status = iree_hal_amdgpu_physical_device_assign_frontier(
+        base_device, system, logical_device->proactor,
+        topology_info->frontier.tracker, topology_info->frontier.base_axis,
+        logical_device->host_queue_epoch_table,
+        &system->host_memory_pools[host_ordinal],
+        logical_device->host_allocator,
+        logical_device->physical_devices[device_ordinal]);
+  }
+
+  if (iree_status_is_ok(status)) {
+    logical_device->topology_info = *topology_info;
+    logical_device->frontier_tracker = topology_info->frontier.tracker;
+    logical_device->axis = topology_info->frontier.base_axis;
+    iree_async_frontier_tracker_retain(logical_device->frontier_tracker);
+  } else {
+    iree_hal_amdgpu_logical_device_deassign_frontier(logical_device);
+  }
+
+  IREE_TRACE_ZONE_END(z0);
+  return status;
 }
 
 static iree_status_t iree_hal_amdgpu_logical_device_create_channel(
     iree_hal_device_t* base_device, iree_hal_queue_affinity_t queue_affinity,
     iree_hal_channel_params_t params, iree_hal_channel_t** out_channel) {
-  iree_hal_amdgpu_logical_device_t* logical_device =
-      iree_hal_amdgpu_logical_device_cast(base_device);
-
-  // Mask the user-provided queue affinity to only those we have.
-  iree_hal_queue_affinity_and_into(queue_affinity,
-                                   logical_device->queue_affinity_mask);
-  if (iree_hal_queue_affinity_is_empty(queue_affinity)) {
-    return iree_make_status(IREE_STATUS_INVALID_ARGUMENT,
-                            "no valid queue affinity bits specified");
-  }
-
-  // TODO(benvanik): pass any additional resources required to create the
-  // channel. The device->channel_provider can be used to get default
-  // rank/count, exchange IDs, etc as needed.
-  (void)logical_device;
-
-  return iree_hal_amdgpu_channel_create(
-      params, iree_hal_device_host_allocator(base_device), out_channel);
+  return iree_make_status(IREE_STATUS_UNIMPLEMENTED,
+                          "AMDGPU collective channels not yet implemented");
 }
 
 static iree_status_t iree_hal_amdgpu_logical_device_create_command_buffer(
@@ -761,77 +2082,29 @@
     iree_hal_command_buffer_t** out_command_buffer) {
   iree_hal_amdgpu_logical_device_t* logical_device =
       iree_hal_amdgpu_logical_device_cast(base_device);
-
-  // Mask the user-provided queue affinity to only those we have.
-  iree_hal_queue_affinity_and_into(queue_affinity,
-                                   logical_device->queue_affinity_mask);
-  if (iree_hal_queue_affinity_is_empty(queue_affinity)) {
-    return iree_make_status(IREE_STATUS_INVALID_ARGUMENT,
-                            "no valid queue affinity bits specified");
-  }
-
-  // Determine the physical devices the command buffer will be uploaded to based
-  // on the queues that it is declared as being executable on. A single physical
-  // device may have multiple queues.
-  const iree_hal_amdgpu_device_affinity_t device_affinity =
-      iree_hal_amdgpu_device_affinity_from_queue_affinity(
-          queue_affinity,
-          logical_device->system->topology.gpu_agent_queue_count);
-  IREE_ASSERT_GT(iree_hal_amdgpu_device_affinity_count(device_affinity), 0,
-                 "must have at least one device");
-
-  iree_hal_amdgpu_command_buffer_options_t options;
-  iree_hal_amdgpu_command_buffer_options_initialize(
+  iree_hal_queue_affinity_t effective_queue_affinity = 0;
+  iree_host_size_t device_ordinal = 0;
+  IREE_RETURN_IF_ERROR(
+      iree_hal_amdgpu_logical_device_normalize_command_buffer_affinity(
+          logical_device, queue_affinity, &effective_queue_affinity,
+          &device_ordinal));
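+  // Unlike the multi-device recording path this replaces, an AQL command
+  // buffer targets exactly one physical device: the normalize helper resolves
+  // the affinity mask down to a single device ordinal.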
+  const iree_hal_amdgpu_physical_device_t* physical_device =
+      logical_device->physical_devices[device_ordinal];
+  return iree_hal_amdgpu_aql_command_buffer_create(
       iree_hal_device_allocator(base_device), mode, command_categories,
-      queue_affinity, binding_capacity, &options);
-
-  // TODO(benvanik): assign based on device options controlling behavior.
-  options.recording_flags = IREE_HAL_AMDGPU_COMMAND_BUFFER_RECORDING_FLAG_NONE;
-
-  // Host block pool is shared across all command buffers to hopefully allow us
-  // to reuse allocations when command buffers are short-lived.
-  options.host_block_pools = &logical_device->host_block_pools;
-
-  // Gather block pools for all of the physical devices - the command buffer
-  // doesn't care about the device and only needs to know how to allocate blocks
-  // for storing the command buffer data.
-  const int device_count =
-      iree_hal_amdgpu_device_affinity_count(device_affinity);
-  iree_hal_amdgpu_block_pools_t** device_block_pools =
-      iree_alloca(device_count * sizeof(iree_hal_amdgpu_block_pools_t*));
-  IREE_HAL_AMDGPU_FOR_PHYSICAL_DEVICE(device_affinity) {
-    iree_hal_amdgpu_physical_device_t* physical_device =
-        logical_device->physical_devices[device_ordinal];
-    device_block_pools[device_index] = &physical_device->coarse_block_pools;
-  }
-  options.device_affinity = device_affinity;
-  options.device_block_pools = device_block_pools;
-
-  return iree_hal_amdgpu_command_buffer_create(
-      &options, logical_device->host_allocator, out_command_buffer);
+      effective_queue_affinity, binding_capacity, device_ordinal,
+      physical_device->prepublished_kernarg_storage,
+      &logical_device->profile_metadata,
+      &logical_device->host_block_pools.command_buffer,
+      &logical_device->host_block_pools.small, logical_device->host_allocator,
+      out_command_buffer);
 }
 
 static iree_status_t iree_hal_amdgpu_logical_device_create_event(
     iree_hal_device_t* base_device, iree_hal_queue_affinity_t queue_affinity,
     iree_hal_event_flags_t flags, iree_hal_event_t** out_event) {
-  iree_hal_amdgpu_logical_device_t* logical_device =
-      iree_hal_amdgpu_logical_device_cast(base_device);
-
-  // Mask the user-provided queue affinity to only those we have.
-  iree_hal_queue_affinity_and_into(queue_affinity,
-                                   logical_device->queue_affinity_mask);
-  if (iree_hal_queue_affinity_is_empty(queue_affinity)) {
-    return iree_make_status(IREE_STATUS_INVALID_ARGUMENT,
-                            "no valid queue affinity bits specified");
-  }
-
-  // TODO(benvanik): pass any additional resources required to create the event.
-  // The implementation could pool events here.
-  (void)logical_device;
-
-  return iree_hal_amdgpu_event_create(
-      queue_affinity, flags, iree_hal_device_host_allocator(base_device),
-      out_event);
+  return iree_make_status(IREE_STATUS_UNIMPLEMENTED,
+                          "AMDGPU events not yet implemented");
 }
 
 static iree_status_t iree_hal_amdgpu_logical_device_create_executable_cache(
@@ -841,8 +2114,8 @@
       iree_hal_amdgpu_logical_device_cast(base_device);
   return iree_hal_amdgpu_executable_cache_create(
       &logical_device->system->libhsa, &logical_device->system->topology,
-      identifier, iree_hal_device_host_allocator(base_device),
-      out_executable_cache);
+      &logical_device->profile_metadata, identifier,
+      iree_hal_device_host_allocator(base_device), out_executable_cache);
 }
 
 static iree_status_t iree_hal_amdgpu_logical_device_import_file(
@@ -852,17 +2125,13 @@
   iree_hal_amdgpu_logical_device_t* logical_device =
       iree_hal_amdgpu_logical_device_cast(base_device);
 
-  // Mask the user-provided queue affinity to only those we have.
-  iree_hal_queue_affinity_and_into(queue_affinity,
-                                   logical_device->queue_affinity_mask);
-  if (iree_hal_queue_affinity_is_empty(queue_affinity)) {
-    return iree_make_status(IREE_STATUS_INVALID_ARGUMENT,
-                            "no valid queue affinity bits specified");
-  }
+  IREE_RETURN_IF_ERROR(iree_hal_amdgpu_queue_affinity_normalize(
+      logical_device->queue_affinity_mask, queue_affinity, &queue_affinity));
 
   return iree_hal_file_from_handle(
       iree_hal_device_allocator(base_device), queue_affinity, access, handle,
-      /*proactor=*/NULL, iree_hal_device_host_allocator(base_device), out_file);
+      logical_device->proactor, iree_hal_device_host_allocator(base_device),
+      out_file);
 }
 
 static iree_status_t iree_hal_amdgpu_logical_device_create_semaphore(
@@ -871,84 +2140,35 @@
     iree_hal_semaphore_t** out_semaphore) {
   iree_hal_amdgpu_logical_device_t* logical_device =
       iree_hal_amdgpu_logical_device_cast(base_device);
-
-  // TODO(benvanik): support exportable semaphores based on flags. We 99.9% of
-  // the time want our internal ones.
-
-  // Acquire a semaphore from the pool.
-  return iree_hal_amdgpu_semaphore_pool_acquire(
-      &logical_device->semaphore_pool, initial_value, flags, out_semaphore);
+  return iree_hal_amdgpu_semaphore_create(
+      logical_device, logical_device->proactor, queue_affinity, initial_value,
+      flags, logical_device->host_allocator, out_semaphore);
 }
 
 static iree_hal_semaphore_compatibility_t
 iree_hal_amdgpu_logical_device_query_semaphore_compatibility(
     iree_hal_device_t* base_device, iree_hal_semaphore_t* semaphore) {
-  iree_hal_semaphore_compatibility_t compatibility =
-      IREE_HAL_SEMAPHORE_COMPATIBILITY_NONE;
-  if (iree_hal_amdgpu_internal_semaphore_isa(semaphore)) {
-    // Internal semaphores need to be in a pool that allows access by all. We
-    // could fast path for semaphores created from this logical device and then
-    // fall back to querying for pool compatibility. Today all semaphores are
-    // created from HSA_AMD_MEMORY_POOL_INFO_ACCESSIBLE_BY_ALL pools and we just
-    // assume regardless of origin they are compatible.
-    compatibility = IREE_HAL_SEMAPHORE_COMPATIBILITY_ALL;
-  } else {
-    // TODO(benvanik): support external semaphores. We can wrap them and have
-    // the device library post back to the host to signal them in cases where we
-    // can't do so via memory operations.
-    compatibility = IREE_HAL_SEMAPHORE_COMPATIBILITY_NONE;
+  if (iree_hal_amdgpu_semaphore_isa(semaphore)) {
+    return IREE_HAL_SEMAPHORE_COMPATIBILITY_ALL;
   }
-  return compatibility;
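+  // Foreign semaphores remain usable for host-side waits/signals only;
+  // device queues can natively interact just with AMDGPU semaphores.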
+  return IREE_HAL_SEMAPHORE_COMPATIBILITY_HOST_ONLY;
 }
 
 static iree_status_t iree_hal_amdgpu_logical_device_query_queue_pool_backend(
     iree_hal_device_t* base_device, iree_hal_queue_affinity_t queue_affinity,
     iree_hal_queue_pool_backend_t* out_backend) {
-  (void)base_device;
-  (void)queue_affinity;
-  (void)out_backend;
-  return iree_make_status(IREE_STATUS_UNIMPLEMENTED,
-                          "AMDGPU queue pool backends are not implemented in "
-                          "this landing slice");
-}
-
-// Resolves a queue affinity to a particular device queue.
-// If the affinity specifies more than one queue we always go with the first one
-// set today (so 0b110 is the same as 0b010).
-//
-// In the future we could load balance (distribute independently executable
-// work) or chain (ensure dependent executable work ends up on the same queue
-// selected earlier). That gets tricky, though, if we start submitting work
-// across peer devices where we may not be able to (quickly) check such
-// dependencies on the host (or have to check them on peers). For now the
-// first-set will handle most cases the compiler generates beyond the
-// unspecified ANY affinity where everything will end up on device 0/queue 0.
-static iree_status_t iree_hal_amdgpu_logical_device_select_queue(
-    iree_hal_amdgpu_logical_device_t* logical_device,
-    iree_hal_queue_affinity_t queue_affinity,
-    iree_hal_amdgpu_virtual_queue_t** out_queue) {
-  // Mask the user-provided queue affinity to only those we have.
-  iree_hal_queue_affinity_and_into(queue_affinity,
-                                   logical_device->queue_affinity_mask);
-  if (iree_hal_queue_affinity_is_empty(queue_affinity)) {
-    return iree_make_status(IREE_STATUS_INVALID_ARGUMENT,
-                            "no valid queue affinity bits specified");
-  }
-
-  // Find the first set bit as our default policy.
-  const int logical_queue_ordinal =
-      iree_hal_queue_affinity_find_first_set(queue_affinity);
-
-  // Map queue ordinal to physical device ordinal and its local queue ordinal.
-  const iree_host_size_t per_queue_count =
-      logical_device->system->topology.gpu_agent_queue_count;
-  const iree_host_size_t physical_device_ordinal =
-      logical_queue_ordinal / per_queue_count;
-  const iree_host_size_t physical_queue_ordinal =
-      logical_queue_ordinal % per_queue_count;
-
-  *out_queue = logical_device->physical_devices[physical_device_ordinal]
-                   ->queues[physical_queue_ordinal];
+  iree_hal_amdgpu_logical_device_t* logical_device =
+      iree_hal_amdgpu_logical_device_cast(base_device);
+  iree_hal_amdgpu_physical_device_t* physical_device = NULL;
+  IREE_RETURN_IF_ERROR(
+      iree_hal_amdgpu_logical_device_select_queue_pool_physical_device(
+          logical_device, queue_affinity, &physical_device));
+  out_backend->slab_provider = physical_device->default_slab_provider;
+  out_backend->notification = physical_device->default_pool_notification;
+  out_backend->epoch_query = (iree_hal_pool_epoch_query_t){
+      .fn = iree_hal_amdgpu_logical_device_query_pool_epoch,
+      .user_data = logical_device,
+  };
   return iree_ok_status();
 }
 
@@ -962,7 +2182,7 @@
   iree_hal_amdgpu_logical_device_t* logical_device =
       iree_hal_amdgpu_logical_device_cast(base_device);
   iree_hal_amdgpu_virtual_queue_t* queue = NULL;
-  IREE_RETURN_IF_ERROR(iree_hal_amdgpu_logical_device_select_queue(
+  IREE_RETURN_IF_ERROR(iree_hal_amdgpu_logical_device_select_host_queue(
       logical_device, queue_affinity, &queue));
   return queue->vtable->alloca(queue, wait_semaphore_list,
                                signal_semaphore_list, pool, params,
@@ -977,7 +2197,7 @@
   iree_hal_amdgpu_logical_device_t* logical_device =
       iree_hal_amdgpu_logical_device_cast(base_device);
   iree_hal_amdgpu_virtual_queue_t* queue = NULL;
-  IREE_RETURN_IF_ERROR(iree_hal_amdgpu_logical_device_select_queue(
+  IREE_RETURN_IF_ERROR(iree_hal_amdgpu_logical_device_select_host_queue(
       logical_device, queue_affinity, &queue));
   return queue->vtable->dealloca(queue, wait_semaphore_list,
                                  signal_semaphore_list, buffer, flags);
@@ -990,25 +2210,31 @@
     iree_hal_buffer_t* target_buffer, iree_device_size_t target_offset,
     iree_device_size_t length, const void* pattern,
     iree_host_size_t pattern_length, iree_hal_fill_flags_t flags) {
+  // Match the HAL contract documented on iree_hal_command_buffer_fill_buffer
+  // (1/2/4-byte patterns only) so queue_fill and command_buffer_fill accept
+  // the same inputs across all backends. The device kernel itself supports an
+  // 8-byte pattern path via iree_hal_amdgpu_device_buffer_fill_x8, but we
+  // deliberately do not expose that here — callers writing 8-byte fills would
+  // then be portable only to amdgpu.
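+  // e.g. an 8-byte pattern made of two identical 4-byte halves can portably
+  // be issued as a 4-byte fill; truly distinct 8-byte patterns need a
+  // dispatch (or the amdgpu-only x8 path if it is ever exposed).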
+  if (IREE_UNLIKELY(pattern_length != 1 && pattern_length != 2 &&
+                    pattern_length != 4)) {
+    return iree_make_status(
+        IREE_STATUS_INVALID_ARGUMENT,
+        "fill patterns must be 1, 2, or 4 bytes (got %" PRIhsz ")",
+        pattern_length);
+  }
+  if (IREE_UNLIKELY(!pattern)) {
+    return iree_make_status(IREE_STATUS_INVALID_ARGUMENT,
+                            "fill pattern pointer is required");
+  }
+  uint64_t pattern_bits = 0;
+  memcpy(&pattern_bits, pattern, pattern_length);
+
   iree_hal_amdgpu_logical_device_t* logical_device =
       iree_hal_amdgpu_logical_device_cast(base_device);
   iree_hal_amdgpu_virtual_queue_t* queue = NULL;
-  IREE_RETURN_IF_ERROR(iree_hal_amdgpu_logical_device_select_queue(
+  IREE_RETURN_IF_ERROR(iree_hal_amdgpu_logical_device_select_host_queue(
       logical_device, queue_affinity, &queue));
-  uint64_t pattern_bits = 0;
-  switch (pattern_length) {
-    case 1:
-    case 2:
-    case 4:
-    case 8:
-      memcpy(&pattern_bits, pattern, pattern_length);
-      break;
-    default:
-      return iree_make_status(
-          IREE_STATUS_INVALID_ARGUMENT,
-          "pattern length must be 1, 2, 4, or 8 - got %" PRIhsz,
-          pattern_length);
-  }
   return queue->vtable->fill(queue, wait_semaphore_list, signal_semaphore_list,
                              target_buffer, target_offset, length, pattern_bits,
                              pattern_length, flags);
@@ -1024,7 +2250,7 @@
   iree_hal_amdgpu_logical_device_t* logical_device =
       iree_hal_amdgpu_logical_device_cast(base_device);
   iree_hal_amdgpu_virtual_queue_t* queue = NULL;
-  IREE_RETURN_IF_ERROR(iree_hal_amdgpu_logical_device_select_queue(
+  IREE_RETURN_IF_ERROR(iree_hal_amdgpu_logical_device_select_host_queue(
       logical_device, queue_affinity, &queue));
   return queue->vtable->update(
       queue, wait_semaphore_list, signal_semaphore_list, source_buffer,
@@ -1041,7 +2267,7 @@
   iree_hal_amdgpu_logical_device_t* logical_device =
       iree_hal_amdgpu_logical_device_cast(base_device);
   iree_hal_amdgpu_virtual_queue_t* queue = NULL;
-  IREE_RETURN_IF_ERROR(iree_hal_amdgpu_logical_device_select_queue(
+  IREE_RETURN_IF_ERROR(iree_hal_amdgpu_logical_device_select_host_queue(
       logical_device, queue_affinity, &queue));
   return queue->vtable->copy(queue, wait_semaphore_list, signal_semaphore_list,
                              source_buffer, source_offset, target_buffer,
@@ -1057,34 +2283,12 @@
     iree_device_size_t length, iree_hal_read_flags_t flags) {
   iree_hal_amdgpu_logical_device_t* logical_device =
       iree_hal_amdgpu_logical_device_cast(base_device);
-
-  // Route to optimized queue I/O if available.
   iree_hal_amdgpu_virtual_queue_t* queue = NULL;
-  IREE_RETURN_IF_ERROR(iree_hal_amdgpu_logical_device_select_queue(
+  IREE_RETURN_IF_ERROR(iree_hal_amdgpu_logical_device_select_host_queue(
       logical_device, queue_affinity, &queue));
-  if (queue->vtable->read) {
-    return queue->vtable->read(
-        queue, wait_semaphore_list, signal_semaphore_list, source_file,
-        source_offset, target_buffer, target_offset, length, flags);
-  }
-
-  // Fall back to inefficient emulated I/O.
-  // TODO(benvanik): when all queue implementations support native I/O we should
-  // drop the emulation (it's bad).
-  iree_hal_queue_affinity_and_into(queue_affinity,
-                                   logical_device->queue_affinity_mask);
-  if (iree_hal_queue_affinity_is_empty(queue_affinity)) {
-    return iree_make_status(IREE_STATUS_INVALID_ARGUMENT,
-                            "no valid queue affinity bits specified");
-  }
-  iree_hal_file_transfer_options_t options = {
-      .chunk_count = IREE_HAL_FILE_TRANSFER_CHUNK_COUNT_DEFAULT,
-      .chunk_size = IREE_HAL_FILE_TRANSFER_CHUNK_SIZE_DEFAULT,
-  };
-  return iree_hal_device_queue_read_streaming(
-      base_device, queue_affinity, wait_semaphore_list, signal_semaphore_list,
-      source_file, source_offset, target_buffer, target_offset, length, flags,
-      options);
+  return queue->vtable->read(queue, wait_semaphore_list, signal_semaphore_list,
+                             source_file, source_offset, target_buffer,
+                             target_offset, length, flags);
 }
 
 static iree_status_t iree_hal_amdgpu_logical_device_queue_write(
@@ -1096,34 +2300,46 @@
     iree_device_size_t length, iree_hal_write_flags_t flags) {
   iree_hal_amdgpu_logical_device_t* logical_device =
       iree_hal_amdgpu_logical_device_cast(base_device);
-
-  // Route to optimized queue I/O if available.
   iree_hal_amdgpu_virtual_queue_t* queue = NULL;
-  IREE_RETURN_IF_ERROR(iree_hal_amdgpu_logical_device_select_queue(
+  IREE_RETURN_IF_ERROR(iree_hal_amdgpu_logical_device_select_host_queue(
       logical_device, queue_affinity, &queue));
-  if (queue->vtable->read) {
-    return queue->vtable->write(
-        queue, wait_semaphore_list, signal_semaphore_list, source_buffer,
-        source_offset, target_file, target_offset, length, flags);
-  }
+  return queue->vtable->write(queue, wait_semaphore_list, signal_semaphore_list,
+                              source_buffer, source_offset, target_file,
+                              target_offset, length, flags);
+}
 
-  // Fall back to inefficient emulated I/O.
-  // TODO(benvanik): when all queue implementations support native I/O we should
-  // drop the emulation (it's bad).
-  iree_hal_queue_affinity_and_into(queue_affinity,
-                                   logical_device->queue_affinity_mask);
-  if (iree_hal_queue_affinity_is_empty(queue_affinity)) {
-    return iree_make_status(IREE_STATUS_INVALID_ARGUMENT,
-                            "no valid queue affinity bits specified");
-  }
-  iree_hal_file_transfer_options_t options = {
-      .chunk_count = IREE_HAL_FILE_TRANSFER_CHUNK_COUNT_DEFAULT,
-      .chunk_size = IREE_HAL_FILE_TRANSFER_CHUNK_SIZE_DEFAULT,
-  };
-  return iree_hal_device_queue_write_streaming(
-      base_device, queue_affinity, wait_semaphore_list, signal_semaphore_list,
-      source_buffer, source_offset, target_file, target_offset, length, flags,
-      options);
+static iree_status_t iree_hal_amdgpu_logical_device_queue_host_call(
+    iree_hal_device_t* base_device, iree_hal_queue_affinity_t queue_affinity,
+    const iree_hal_semaphore_list_t wait_semaphore_list,
+    const iree_hal_semaphore_list_t signal_semaphore_list,
+    iree_hal_host_call_t call, const uint64_t args[4],
+    iree_hal_host_call_flags_t flags) {
+  iree_hal_amdgpu_logical_device_t* logical_device =
+      iree_hal_amdgpu_logical_device_cast(base_device);
+  iree_hal_amdgpu_virtual_queue_t* queue = NULL;
+  IREE_RETURN_IF_ERROR(iree_hal_amdgpu_logical_device_select_host_queue(
+      logical_device, queue_affinity, &queue));
+  return queue->vtable->host_call(queue, wait_semaphore_list,
+                                  signal_semaphore_list, call, args, flags);
+}
+
+static iree_status_t iree_hal_amdgpu_logical_device_queue_dispatch(
+    iree_hal_device_t* base_device, iree_hal_queue_affinity_t queue_affinity,
+    const iree_hal_semaphore_list_t wait_semaphore_list,
+    const iree_hal_semaphore_list_t signal_semaphore_list,
+    iree_hal_executable_t* executable,
+    iree_hal_executable_export_ordinal_t export_ordinal,
+    const iree_hal_dispatch_config_t config, iree_const_byte_span_t constants,
+    const iree_hal_buffer_ref_list_t bindings,
+    iree_hal_dispatch_flags_t flags) {
+  iree_hal_amdgpu_logical_device_t* logical_device =
+      iree_hal_amdgpu_logical_device_cast(base_device);
+  iree_hal_amdgpu_virtual_queue_t* queue = NULL;
+  IREE_RETURN_IF_ERROR(iree_hal_amdgpu_logical_device_select_host_queue(
+      logical_device, queue_affinity, &queue));
+  return queue->vtable->dispatch(
+      queue, wait_semaphore_list, signal_semaphore_list, executable,
+      export_ordinal, config, constants, bindings, flags);
 }
 
 static iree_status_t iree_hal_amdgpu_logical_device_queue_execute(
@@ -1136,7 +2352,7 @@
   iree_hal_amdgpu_logical_device_t* logical_device =
       iree_hal_amdgpu_logical_device_cast(base_device);
   iree_hal_amdgpu_virtual_queue_t* queue = NULL;
-  IREE_RETURN_IF_ERROR(iree_hal_amdgpu_logical_device_select_queue(
+  IREE_RETURN_IF_ERROR(iree_hal_amdgpu_logical_device_select_host_queue(
       logical_device, queue_affinity, &queue));
   return queue->vtable->execute(queue, wait_semaphore_list,
                                 signal_semaphore_list, command_buffer,
@@ -1147,17 +2363,224 @@
     iree_hal_device_t* base_device, iree_hal_queue_affinity_t queue_affinity) {
   iree_hal_amdgpu_logical_device_t* logical_device =
       iree_hal_amdgpu_logical_device_cast(base_device);
-  iree_hal_amdgpu_virtual_queue_t* queue = NULL;
-  IREE_RETURN_IF_ERROR(iree_hal_amdgpu_logical_device_select_queue(
-      logical_device, queue_affinity, &queue));
-  return queue->vtable->flush(queue);
+  IREE_RETURN_IF_ERROR(iree_hal_amdgpu_queue_affinity_normalize(
+      logical_device->queue_affinity_mask, queue_affinity, &queue_affinity));
+
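+  // Flush fans out to every queue selected by the normalized affinity mask;
+  // unlike the submission paths above this is a broadcast, not a single-queue
+  // selection.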
+  IREE_HAL_FOR_QUEUE_AFFINITY(queue_affinity) {
+    iree_hal_amdgpu_virtual_queue_t* queue = NULL;
+    IREE_RETURN_IF_ERROR(iree_hal_amdgpu_logical_device_queue_from_ordinal(
+        logical_device, queue_ordinal, &queue));
+    IREE_RETURN_IF_ERROR(queue->vtable->flush(queue));
+  }
+  return iree_ok_status();
+}
+
+static iree_status_t
+iree_hal_amdgpu_logical_device_verify_queue_device_profiling_supported(
+    iree_hal_amdgpu_logical_device_t* logical_device) {
+  for (iree_host_size_t i = 0; i < logical_device->physical_device_count; ++i) {
+    iree_hal_amdgpu_physical_device_t* physical_device =
+        logical_device->physical_devices[i];
+    if (iree_hal_amdgpu_vendor_packet_capabilities_support_timestamp_range(
+            physical_device->vendor_packet_capabilities)) {
+      continue;
+    }
+    return iree_make_status(
+        IREE_STATUS_FAILED_PRECONDITION,
+        "AMDGPU queue operation profiling requires PM4 timestamp range "
+        "support on physical device %" PRIhsz,
+        physical_device->device_ordinal);
+  }
+  return iree_ok_status();
 }
 
 static iree_status_t iree_hal_amdgpu_logical_device_profiling_begin(
     iree_hal_device_t* base_device,
     const iree_hal_device_profiling_options_t* options) {
-  return iree_make_status(IREE_STATUS_UNIMPLEMENTED,
-                          "AMDGPU HAL-native profiling is not implemented");
+  iree_hal_amdgpu_logical_device_t* logical_device =
+      iree_hal_amdgpu_logical_device_cast(base_device);
+  iree_hal_device_profiling_options_t resolved_options =
+      iree_hal_amdgpu_logical_device_resolve_profiling_options(options);
+
+  if (iree_hal_device_profiling_options_requests_data(
+          &resolved_options,
+          IREE_HAL_DEVICE_PROFILING_DATA_HOST_EXECUTION_EVENTS)) {
+    return iree_make_status(
+        IREE_STATUS_UNIMPLEMENTED,
+        "AMDGPU profiling does not produce host execution events");
+  }
+  if (resolved_options.data_families == IREE_HAL_DEVICE_PROFILING_DATA_NONE) {
+    return iree_ok_status();
+  }
+  if (!logical_device->frontier_tracker) {
+    return iree_make_status(
+        IREE_STATUS_FAILED_PRECONDITION,
+        "AMDGPU profiling requires an assigned device topology");
+  }
+  if (logical_device->profiling.options.data_families !=
+      IREE_HAL_DEVICE_PROFILING_DATA_NONE) {
+    return iree_make_status(IREE_STATUS_FAILED_PRECONDITION,
+                            "cannot nest AMDGPU profile captures");
+  }
+  if (iree_hal_device_profiling_options_requests_data(
+          &resolved_options,
+          IREE_HAL_DEVICE_PROFILING_DATA_DEVICE_QUEUE_EVENTS)) {
+    IREE_RETURN_IF_ERROR(
+        iree_hal_amdgpu_logical_device_verify_queue_device_profiling_supported(
+            logical_device));
+  }
+
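+  // Producers are enabled step by step below; each bool records whether its
+  // step succeeded so the failure path can unwind exactly what was started.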
+  bool sink_session_begun = false;
+  bool hsa_profiling_enabled = false;
+  bool counter_profiling_enabled = false;
+  bool counter_ranges_started = false;
+  bool trace_profiling_enabled = false;
+  iree_hal_device_profiling_options_t session_options = {0};
+  iree_hal_device_profiling_options_storage_t* options_storage = NULL;
+  iree_hal_amdgpu_profile_counter_session_t* counter_session = NULL;
+  iree_hal_amdgpu_profile_trace_session_t* trace_session = NULL;
+  iree_hal_amdgpu_profile_device_metrics_session_t* device_metrics_session =
+      NULL;
+  iree_status_t status = iree_hal_device_profiling_options_clone(
+      &resolved_options, logical_device->host_allocator, &session_options,
+      &options_storage);
+  iree_hal_profile_sink_t* sink = session_options.sink;
+  uint64_t session_id = 0;
+  iree_hal_profile_chunk_metadata_t metadata = {0};
+  if (iree_status_is_ok(status)) {
+    session_id = logical_device->next_profile_session_id++;
+    metadata = iree_hal_amdgpu_logical_device_profile_session_metadata(
+        logical_device, session_id);
+    logical_device->profiling.next_clock_correlation_sample_id = 1;
+    memset(&logical_device->profiling.metadata_cursor, 0,
+           sizeof(logical_device->profiling.metadata_cursor));
+  }
+  if (iree_status_is_ok(status)) {
+    status = iree_hal_amdgpu_profile_counter_session_allocate(
+        logical_device, &session_options, logical_device->host_allocator,
+        &counter_session);
+  }
+  if (iree_status_is_ok(status)) {
+    status = iree_hal_amdgpu_profile_trace_session_allocate(
+        logical_device, &session_options, logical_device->host_allocator,
+        &trace_session);
+  }
+  if (iree_status_is_ok(status)) {
+    status = iree_hal_amdgpu_profile_device_metrics_session_allocate(
+        logical_device, &session_options, logical_device->host_allocator,
+        &device_metrics_session);
+  }
+  if (iree_status_is_ok(status)) {
+    status = iree_hal_profile_sink_begin_session(sink, &metadata);
+    sink_session_begun = iree_status_is_ok(status);
+  }
+  if (iree_status_is_ok(status) &&
+      iree_hal_device_profiling_options_requests_data(
+          &session_options, IREE_HAL_DEVICE_PROFILING_DATA_QUEUE_EVENTS)) {
+    status = iree_hal_amdgpu_profile_event_streams_ensure_queue_storage(
+        &logical_device->profiling.event_streams,
+        IREE_HAL_AMDGPU_LOGICAL_DEVICE_PROFILE_QUEUE_EVENT_CAPACITY,
+        logical_device->host_allocator);
+    if (iree_status_is_ok(status)) {
+      iree_hal_amdgpu_profile_event_streams_clear_queue(
+          &logical_device->profiling.event_streams);
+    }
+  }
+  if (iree_status_is_ok(status) &&
+      iree_hal_device_profiling_options_requests_data(
+          &session_options, IREE_HAL_DEVICE_PROFILING_DATA_MEMORY_EVENTS)) {
+    status = iree_hal_amdgpu_profile_event_streams_ensure_memory_storage(
+        &logical_device->profiling.event_streams,
+        IREE_HAL_AMDGPU_LOGICAL_DEVICE_PROFILE_MEMORY_EVENT_CAPACITY,
+        logical_device->host_allocator);
+    if (iree_status_is_ok(status)) {
+      iree_hal_amdgpu_profile_event_streams_clear_memory(
+          &logical_device->profiling.event_streams);
+    }
+  }
+  if (iree_status_is_ok(status)) {
+    status = iree_hal_amdgpu_logical_device_write_profile_metadata(
+        logical_device, sink, session_id, session_options.data_families);
+  }
+  if (iree_status_is_ok(status)) {
+    status = iree_hal_amdgpu_profile_counter_session_write_metadata(
+        counter_session, sink, session_id, logical_device->identifier);
+  }
+  if (iree_status_is_ok(status)) {
+    status = iree_hal_amdgpu_profile_device_metrics_session_write_metadata(
+        device_metrics_session, sink, session_id, logical_device->identifier);
+  }
+  if (iree_status_is_ok(status) &&
+      iree_hal_amdgpu_logical_device_profiling_needs_hsa_timestamps(
+          session_options.data_families)) {
+    status = iree_hal_amdgpu_logical_device_set_hsa_profiling_enabled(
+        logical_device, true);
+    hsa_profiling_enabled = iree_status_is_ok(status);
+  }
+  if (iree_status_is_ok(status)) {
+    status = iree_hal_amdgpu_logical_device_set_counter_profiling_enabled(
+        logical_device, counter_session, true);
+    counter_profiling_enabled = iree_status_is_ok(status);
+  }
+  if (iree_status_is_ok(status)) {
+    status = iree_hal_amdgpu_logical_device_start_profile_counter_ranges(
+        logical_device, counter_session);
+    counter_ranges_started = iree_status_is_ok(status);
+  }
+  if (iree_status_is_ok(status)) {
+    status = iree_hal_amdgpu_logical_device_set_trace_profiling_enabled(
+        logical_device, trace_session, true);
+    trace_profiling_enabled = iree_status_is_ok(status);
+  }
+
+  if (iree_status_is_ok(status)) {
+    logical_device->profiling.options = session_options;
+    logical_device->profiling.options_storage = options_storage;
+    logical_device->profiling.session_id = session_id;
+    logical_device->profiling.counter_session = counter_session;
+    logical_device->profiling.trace_session = trace_session;
+    logical_device->profiling.device_metrics_session = device_metrics_session;
+    iree_hal_amdgpu_logical_device_set_queue_profiling_enabled(
+        logical_device,
+        iree_hal_amdgpu_logical_device_queue_profile_flags(&session_options));
+  } else {
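+    // Tear down in reverse order of enablement. iree_status_join preserves
+    // the original failure while accumulating any errors raised during
+    // cleanup.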
+    if (trace_profiling_enabled) {
+      status = iree_status_join(
+          status, iree_hal_amdgpu_logical_device_set_trace_profiling_enabled(
+                      logical_device, trace_session, false));
+    }
+    if (counter_ranges_started) {
+      status = iree_status_join(
+          status,
+          iree_hal_amdgpu_logical_device_flush_profile_counter_ranges(
+              logical_device, counter_session, /*sink=*/NULL, /*session_id=*/0,
+              IREE_HAL_AMDGPU_PROFILE_COUNTER_RANGE_FLUSH_FLAG_NONE));
+    }
+    if (counter_profiling_enabled) {
+      status = iree_status_join(
+          status, iree_hal_amdgpu_logical_device_set_counter_profiling_enabled(
+                      logical_device, counter_session, false));
+    }
+    if (hsa_profiling_enabled) {
+      status = iree_status_join(
+          status, iree_hal_amdgpu_logical_device_set_hsa_profiling_enabled(
+                      logical_device, false));
+    }
+    if (sink_session_begun) {
+      status = iree_status_join(
+          status, iree_hal_profile_sink_end_session(sink, &metadata,
+                                                    iree_status_code(status)));
+    }
+    logical_device->profiling.next_clock_correlation_sample_id = 0;
+    memset(&logical_device->profiling.metadata_cursor, 0,
+           sizeof(logical_device->profiling.metadata_cursor));
+    iree_hal_device_profiling_options_storage_free(
+        options_storage, logical_device->host_allocator);
+    iree_hal_amdgpu_profile_counter_session_free(counter_session);
+    iree_hal_amdgpu_profile_trace_session_free(trace_session);
+    iree_hal_amdgpu_profile_device_metrics_session_free(device_metrics_session);
+  }
+  return status;
 }
 
 static iree_status_t iree_hal_amdgpu_logical_device_profiling_flush(
@@ -1165,11 +2588,35 @@
   iree_hal_amdgpu_logical_device_t* logical_device =
       iree_hal_amdgpu_logical_device_cast(base_device);
 
-  // TODO(benvanik): figure out if there's any AMD tooling calls we can make.
-  (void)logical_device;
-  iree_status_t status = iree_ok_status();
-
-  return status;
+  const iree_hal_device_profiling_options_t* options =
+      &logical_device->profiling.options;
+  if (options->data_families == IREE_HAL_DEVICE_PROFILING_DATA_NONE) {
+    return iree_ok_status();
+  }
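+  // Mid-session flush: drain buffered counter ranges, metadata, events, and
+  // clock correlations, restarting counter ranges (FLUSH_FLAG_RESTART) so the
+  // session can keep capturing.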
+  iree_hal_profile_sink_t* sink = options->sink;
+  const bool emit_executable_artifacts =
+      iree_hal_amdgpu_logical_device_profile_needs_executable_artifacts(
+          options->data_families);
+  IREE_RETURN_IF_ERROR(
+      iree_hal_amdgpu_logical_device_flush_profile_counter_ranges(
+          logical_device, logical_device->profiling.counter_session, sink,
+          logical_device->profiling.session_id,
+          IREE_HAL_AMDGPU_PROFILE_COUNTER_RANGE_FLUSH_FLAG_RESTART));
+  IREE_RETURN_IF_ERROR(iree_hal_amdgpu_profile_metadata_write(
+      &logical_device->profile_metadata, sink,
+      logical_device->profiling.session_id, logical_device->identifier,
+      emit_executable_artifacts, &logical_device->profiling.metadata_cursor));
+  IREE_RETURN_IF_ERROR(iree_hal_amdgpu_logical_device_write_profile_events(
+      logical_device, sink, logical_device->profiling.session_id));
+  if (iree_hal_amdgpu_logical_device_profiling_needs_clock_correlations(
+          options->data_families)) {
+    IREE_RETURN_IF_ERROR(
+        iree_hal_amdgpu_logical_device_write_profile_clock_correlations(
+            logical_device, sink, logical_device->profiling.session_id));
+  }
+  return iree_hal_amdgpu_profile_device_metrics_session_sample_and_write(
+      logical_device->profiling.device_metrics_session, sink,
+      logical_device->profiling.session_id, logical_device->identifier);
 }
 
 static iree_status_t iree_hal_amdgpu_logical_device_profiling_end(
@@ -1177,26 +2624,97 @@
   iree_hal_amdgpu_logical_device_t* logical_device =
       iree_hal_amdgpu_logical_device_cast(base_device);
 
-  // TODO(benvanik): figure out if there's any AMD tooling calls we can make.
-  (void)logical_device;
   iree_status_t status = iree_ok_status();
+  const iree_hal_device_profiling_data_families_t data_families =
+      logical_device->profiling.options.data_families;
+  if (data_families == IREE_HAL_DEVICE_PROFILING_DATA_NONE) {
+    return iree_ok_status();
+  }
 
+  iree_hal_profile_sink_t* sink = logical_device->profiling.options.sink;
+  iree_hal_amdgpu_profile_counter_session_t* counter_session =
+      logical_device->profiling.counter_session;
+  iree_hal_amdgpu_profile_trace_session_t* trace_session =
+      logical_device->profiling.trace_session;
+  iree_hal_amdgpu_profile_device_metrics_session_t* device_metrics_session =
+      logical_device->profiling.device_metrics_session;
+  const uint64_t session_id = logical_device->profiling.session_id;
+  iree_hal_profile_chunk_metadata_t metadata =
+      iree_hal_amdgpu_logical_device_profile_session_metadata(logical_device,
+                                                              session_id);
+  const bool emit_executable_artifacts =
+      iree_hal_amdgpu_logical_device_profile_needs_executable_artifacts(
+          data_families);
+
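+  // Final drain mirrors profiling_flush minus the counter-range restart; the
+  // producers are then disabled and the sink session closed even on error.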
+  status = iree_hal_amdgpu_logical_device_flush_profile_counter_ranges(
+      logical_device, counter_session, sink, session_id,
+      IREE_HAL_AMDGPU_PROFILE_COUNTER_RANGE_FLUSH_FLAG_NONE);
+  if (iree_status_is_ok(status)) {
+    status = iree_hal_amdgpu_profile_metadata_write(
+        &logical_device->profile_metadata, sink, session_id,
+        logical_device->identifier, emit_executable_artifacts,
+        &logical_device->profiling.metadata_cursor);
+  }
+  if (iree_status_is_ok(status)) {
+    status = iree_hal_amdgpu_logical_device_write_profile_events(
+        logical_device, sink, session_id);
+  }
+  if (iree_status_is_ok(status) &&
+      iree_hal_amdgpu_logical_device_profiling_needs_clock_correlations(
+          data_families)) {
+    status = iree_hal_amdgpu_logical_device_write_profile_clock_correlations(
+        logical_device, sink, session_id);
+  }
+  if (iree_status_is_ok(status)) {
+    status = iree_hal_amdgpu_profile_device_metrics_session_sample_and_write(
+        device_metrics_session, sink, session_id, logical_device->identifier);
+  }
+  status = iree_status_join(
+      status, iree_hal_amdgpu_logical_device_set_trace_profiling_enabled(
+                  logical_device, trace_session, false));
+  status = iree_status_join(
+      status, iree_hal_amdgpu_logical_device_set_counter_profiling_enabled(
+                  logical_device, counter_session, false));
+  if (iree_hal_amdgpu_logical_device_profiling_needs_hsa_timestamps(
+          data_families)) {
+    status = iree_status_join(
+        status, iree_hal_amdgpu_logical_device_set_hsa_profiling_enabled(
+                    logical_device, false));
+  }
+  status =
+      iree_status_join(status, iree_hal_profile_sink_end_session(
+                                   sink, &metadata, iree_status_code(status)));
+
+  iree_hal_amdgpu_logical_device_reset_profile_options(logical_device);
+  logical_device->profiling.session_id = 0;
+  logical_device->profiling.next_clock_correlation_sample_id = 0;
+  memset(&logical_device->profiling.metadata_cursor, 0,
+         sizeof(logical_device->profiling.metadata_cursor));
+  logical_device->profiling.counter_session = NULL;
+  logical_device->profiling.trace_session = NULL;
+  logical_device->profiling.device_metrics_session = NULL;
+  iree_hal_amdgpu_logical_device_set_queue_profiling_enabled(
+      logical_device, IREE_HAL_AMDGPU_HOST_QUEUE_PROFILE_FLAG_NONE);
+  iree_hal_amdgpu_profile_counter_session_free(counter_session);
+  iree_hal_amdgpu_profile_trace_session_free(trace_session);
+  iree_hal_amdgpu_profile_device_metrics_session_free(device_metrics_session);
   return status;
 }
 
 static iree_status_t iree_hal_amdgpu_logical_device_external_capture_begin(
     iree_hal_device_t* base_device,
     const iree_hal_device_external_capture_options_t* options) {
-  return iree_make_status(
-      IREE_STATUS_UNIMPLEMENTED,
-      "AMDGPU external capture provider '%.*s' is not implemented",
-      (int)options->provider.size, options->provider.data);
+  (void)base_device;
+  (void)options;
+  return iree_make_status(IREE_STATUS_UNIMPLEMENTED,
+                          "AMDGPU external capture not implemented");
 }
 
 static iree_status_t iree_hal_amdgpu_logical_device_external_capture_end(
     iree_hal_device_t* base_device) {
+  (void)base_device;
   return iree_make_status(IREE_STATUS_UNIMPLEMENTED,
-                          "AMDGPU external capture is not implemented");
+                          "AMDGPU external capture not implemented");
 }
 
 static const iree_hal_device_vtable_t iree_hal_amdgpu_logical_device_vtable = {
@@ -1231,6 +2749,8 @@
     .queue_copy = iree_hal_amdgpu_logical_device_queue_copy,
     .queue_read = iree_hal_amdgpu_logical_device_queue_read,
     .queue_write = iree_hal_amdgpu_logical_device_queue_write,
+    .queue_host_call = iree_hal_amdgpu_logical_device_queue_host_call,
+    .queue_dispatch = iree_hal_amdgpu_logical_device_queue_dispatch,
     .queue_execute = iree_hal_amdgpu_logical_device_queue_execute,
     .queue_flush = iree_hal_amdgpu_logical_device_queue_flush,
     .profiling_begin = iree_hal_amdgpu_logical_device_profiling_begin,
diff --git a/runtime/src/iree/hal/drivers/amdgpu/logical_device.h b/runtime/src/iree/hal/drivers/amdgpu/logical_device.h
index 4ae3454..5ecb711 100644
--- a/runtime/src/iree/hal/drivers/amdgpu/logical_device.h
+++ b/runtime/src/iree/hal/drivers/amdgpu/logical_device.h
@@ -9,20 +9,46 @@
 
 #include "iree/async/frontier.h"
 #include "iree/base/api.h"
+#include "iree/base/internal/arena.h"
 #include "iree/hal/api.h"
 #include "iree/hal/drivers/amdgpu/api.h"
-#include "iree/hal/drivers/amdgpu/buffer_pool.h"
-#include "iree/hal/drivers/amdgpu/command_buffer.h"
-#include "iree/hal/drivers/amdgpu/semaphore_pool.h"
+#include "iree/hal/drivers/amdgpu/profile_events.h"
+#include "iree/hal/drivers/amdgpu/profile_metadata.h"
 #include "iree/hal/drivers/amdgpu/util/libhsa.h"
 
 typedef struct iree_async_proactor_pool_t iree_async_proactor_pool_t;
 typedef struct iree_async_proactor_t iree_async_proactor_t;
 typedef struct iree_hal_amdgpu_physical_device_t
     iree_hal_amdgpu_physical_device_t;
+typedef struct iree_hal_amdgpu_epoch_signal_table_t
+    iree_hal_amdgpu_epoch_signal_table_t;
+typedef struct iree_hal_amdgpu_profile_counter_session_t
+    iree_hal_amdgpu_profile_counter_session_t;
+typedef struct iree_hal_amdgpu_profile_device_metrics_session_t
+    iree_hal_amdgpu_profile_device_metrics_session_t;
+typedef struct iree_hal_amdgpu_profile_trace_session_t
+    iree_hal_amdgpu_profile_trace_session_t;
 typedef struct iree_hal_amdgpu_system_t iree_hal_amdgpu_system_t;
 typedef struct iree_hal_amdgpu_topology_t iree_hal_amdgpu_topology_t;
 
+#ifdef __cplusplus
+extern "C" {
+#endif  // __cplusplus
+
+//===----------------------------------------------------------------------===//
+// iree_hal_amdgpu_host_block_pools_t
+//===----------------------------------------------------------------------===//
+
+// Block pools for host memory blocks of various sizes.
+typedef struct iree_hal_amdgpu_host_block_pools_t {
+  // Used for small allocations of around 1-4KB.
+  iree_arena_block_pool_t small;
+  // Used for large page-sized allocations of 32-64KB.
+  iree_arena_block_pool_t large;
+  // Used for durable command-buffer recording blocks.
+  iree_arena_block_pool_t command_buffer;
+} iree_hal_amdgpu_host_block_pools_t;
+
 //===----------------------------------------------------------------------===//
 // iree_hal_amdgpu_logical_device_t
 //===----------------------------------------------------------------------===//
@@ -36,7 +62,9 @@
 // implementation does not currently handle taking the minimum capabilities and
 // limits.
 typedef struct iree_hal_amdgpu_logical_device_t {
+  // HAL resource header.
   iree_hal_resource_t resource;
+  // Host allocator used for logical-device-owned host allocations.
   iree_allocator_t host_allocator;
 
   // Proactor pool retained from create_params; provides async I/O proactors.
@@ -44,16 +72,24 @@
   // Proactor borrowed from the pool for this device's async operations.
   iree_async_proactor_t* proactor;
 
-  // Shared frontier tracker for cross-device causal ordering.
-  // Borrowed from the session — valid as long as the session is alive.
-  // NULL if frontier-based fast paths are not enabled.
+  // Shared frontier tracker for cross-device causal ordering. Retained after
+  // topology assignment and released during logical device destruction.
   iree_async_frontier_tracker_t* frontier_tracker;
 
-  // This device's axis and monotonic epoch counter for frontier tracking.
-  // AMDGPU currently has no submit path — these are plumbing-only.
+  // This device's topology-assigned base axis.
   iree_async_axis_t axis;
+
+  // Logical-device epoch counter for frontier tracking.
   iree_atomic_int64_t epoch;
 
+  // Next process-local profile session identifier allocated by this device.
+  uint64_t next_profile_session_id;
+
+  // Durable profiling metadata registered by cold executable/command-buffer
+  // construction paths.
+  iree_hal_amdgpu_profile_metadata_registry_t profile_metadata;
+
+  // Stable device identifier string stored inline after this struct.
   iree_string_view_t identifier;
 
   // Block pools for host memory blocks of various sizes.
@@ -64,6 +100,11 @@
   // the agents available in HSA that are represented as physical devices.
   iree_hal_amdgpu_system_t* system;
 
+  // Shared epoch-signal table used by all host queues on this logical device
+  // for local cross-queue barrier emission. Owned by the logical device and
+  // deregistered by each host queue before this table is freed.
+  iree_hal_amdgpu_epoch_signal_table_t* host_queue_epoch_table;
+
   // Mask indicating which queue affinities are valid.
   iree_hal_queue_affinity_t queue_affinity_mask;
 
@@ -73,19 +114,35 @@
   // Optional provider used for creating/configuring collective channels.
   iree_hal_channel_provider_t* channel_provider;
 
-  // Growable pool of HAL semaphores and their matching device allocations.
-  // Semaphores can be used on any CPU and GPU agent in the system.
-  iree_hal_amdgpu_semaphore_pool_t semaphore_pool;
-
-  // Growable pool of transient buffers and their matching device handles.
-  // Allocation handles can be used on any CPU and GPU agent in the system.
-  iree_hal_amdgpu_buffer_pool_t buffer_pool;
-
   // Sticky logical device-global error flag.
   // Asynchronous errors from subsystems get routed back to this as our "device
   // loss" trigger.
   iree_atomic_intptr_t failure_status;
 
+  // Active profiling session state. Mutated only by the HAL profiling
+  // begin/end API while its idle-device precondition is held.
+  struct {
+    // Owned profiling options for the active session.
+    iree_hal_device_profiling_options_t options;
+    // Opaque storage backing borrowed pointers in |options|, or NULL.
+    iree_hal_device_profiling_options_storage_t* options_storage;
+    // Process-local profiling session identifier assigned at begin.
+    uint64_t session_id;
+    // Next session-local clock-correlation sample identifier.
+    uint64_t next_clock_correlation_sample_id;
+    // Cursor tracking metadata side-table chunks emitted in this session.
+    iree_hal_amdgpu_profile_metadata_cursor_t metadata_cursor;
+    // Hardware counter session active for selected dispatches, or NULL.
+    iree_hal_amdgpu_profile_counter_session_t* counter_session;
+    // Executable trace session active for selected dispatches, or NULL.
+    iree_hal_amdgpu_profile_trace_session_t* trace_session;
+    // Device metrics session sampled on profiling flush/end, or NULL.
+    iree_hal_amdgpu_profile_device_metrics_session_t* device_metrics_session;
+    // Host-side memory and queue profiling event streams.
+    iree_hal_amdgpu_profile_event_streams_t event_streams;
+  } profiling;
+
+  // Topology metadata assigned by the device group after construction.
   iree_hal_device_topology_info_t topology_info;
 
   // Count of physical devices.
@@ -116,4 +173,65 @@
     const iree_hal_device_create_params_t* create_params,
     iree_allocator_t host_allocator, iree_hal_device_t** out_device);
 
+// Verifies option feature knobs that are independent of HSA/topology queries.
+// Driver creation calls this before loading HSA so unsupported default-device
+// options fail without touching ROCR process-global state. Full logical device
+// verification still happens after topology discovery.
+iree_status_t iree_hal_amdgpu_logical_device_options_verify_supported_features(
+    const iree_hal_amdgpu_logical_device_options_t* options);
+
+// Returns true when memory lifecycle profiling records would be retained.
+//
+// Producers may use this to avoid preparing profiling payloads on hot paths.
+// The answer is only meaningful while the device profiling begin/end idle
+// precondition is held by the caller.
+bool iree_hal_amdgpu_logical_device_should_record_profile_memory_events(
+    iree_hal_device_t* base_device);
+
+// Returns true when the active profile capture should emit heavy dispatch
+// artifacts for the given executable export and queue location.
+bool iree_hal_amdgpu_logical_device_should_profile_dispatch(
+    iree_hal_amdgpu_logical_device_t* logical_device, uint64_t executable_id,
+    uint32_t export_ordinal, uint64_t command_buffer_id, uint32_t command_index,
+    uint32_t physical_device_ordinal, uint32_t queue_ordinal);
+
+// Returns a session-local allocation id, or 0 when memory profiling is off.
+//
+// |out_session_id| receives the active profiling session id owning the returned
+// allocation id. Callers that may release after a later profiling session
+// begins must pass the id back to the session-filtered record helper.
+uint64_t iree_hal_amdgpu_logical_device_allocate_profile_memory_allocation_id(
+    iree_hal_device_t* base_device, uint64_t* out_session_id);
+
+// Records one memory lifecycle event into the active profiling stream.
+//
+// This never calls the sink directly. Events are buffered in host memory and
+// emitted by profiling_flush/end, making this safe for submission and pool
+// paths that must not block on file or tool I/O. The stream is an aggregate
+// lossy signal: when its fixed ring is full the event is dropped and the next
+// emitted chunk reports TRUNCATED with a dropped-record count.
+bool iree_hal_amdgpu_logical_device_record_profile_memory_event(
+    iree_hal_device_t* base_device,
+    const iree_hal_profile_memory_event_t* event);
+
+// Records one memory lifecycle event only if |session_id| is still active.
+bool iree_hal_amdgpu_logical_device_record_profile_memory_event_for_session(
+    iree_hal_device_t* base_device, uint64_t session_id,
+    const iree_hal_profile_memory_event_t* event);
+
+// Records one queue operation event into the active profiling stream.
+//
+// This never calls the sink directly. Events are buffered in host memory and
+// emitted by profiling_flush/end, making this safe for submission paths that
+// must not block on file or tool I/O. The stream is an aggregate lossy signal:
+// when its fixed ring is full the event is dropped and the next emitted chunk
+// reports TRUNCATED with a dropped-record count.
+void iree_hal_amdgpu_logical_device_record_profile_queue_event(
+    iree_hal_device_t* base_device,
+    const iree_hal_profile_queue_event_t* event);
+
+#ifdef __cplusplus
+}  // extern "C"
+#endif  // __cplusplus
+
 #endif  // IREE_HAL_DRIVERS_AMDGPU_LOGICAL_DEVICE_H_
diff --git a/runtime/src/iree/hal/drivers/amdgpu/physical_device.c b/runtime/src/iree/hal/drivers/amdgpu/physical_device.c
index 843f3fa..683c009 100644
--- a/runtime/src/iree/hal/drivers/amdgpu/physical_device.c
+++ b/runtime/src/iree/hal/drivers/amdgpu/physical_device.c
@@ -6,11 +6,19 @@
 
 #include "iree/hal/drivers/amdgpu/physical_device.h"
 
-#include "iree/hal/drivers/amdgpu/device_queue.h"
-#include "iree/hal/drivers/amdgpu/host_queue.h"
+#include <stdio.h>
+
+#include "iree/async/frontier_tracker.h"
+#include "iree/async/notification.h"
+#include "iree/hal/drivers/amdgpu/abi/signal.h"
+#include "iree/hal/drivers/amdgpu/queue_affinity.h"
+#include "iree/hal/drivers/amdgpu/slab_provider.h"
 #include "iree/hal/drivers/amdgpu/system.h"
+#include "iree/hal/drivers/amdgpu/util/epoch_signal_table.h"
 #include "iree/hal/drivers/amdgpu/util/topology.h"
 #include "iree/hal/drivers/amdgpu/util/vmem.h"
+#include "iree/hal/memory/passthrough_pool.h"
+#include "iree/hal/memory/tlsf_pool.h"
 
 //===----------------------------------------------------------------------===//
 // iree_hal_amdgpu_physical_device_options_t
@@ -25,6 +33,212 @@
 #define IREE_HAL_AMDGPU_PHYSICAL_DEVICE_FINE_BLOCK_POOL_SMALL_PAGE_SIZE 128
 #define IREE_HAL_AMDGPU_PHYSICAL_DEVICE_FINE_BLOCK_POOL_LARGE_PAGE_SIZE 4096
 
+// Catch-all priority for direct allocations in the default pool set.
+#define IREE_HAL_AMDGPU_PHYSICAL_DEVICE_DEFAULT_POOL_PRIORITY_OVERSIZED 0
+
+// Preferred priority for pooled allocations in the default pool set.
+#define IREE_HAL_AMDGPU_PHYSICAL_DEVICE_DEFAULT_POOL_PRIORITY_TLSF 10
+
+typedef struct iree_hal_amdgpu_agent_first_isa_t {
+  // Number of ISAs seen during iteration.
+  uint32_t count;
+  // First ISA returned by the HSA agent iterator.
+  hsa_isa_t value;
+} iree_hal_amdgpu_agent_first_isa_t;
+
+static hsa_status_t iree_hal_amdgpu_record_first_isa(hsa_isa_t isa,
+                                                     void* user_data) {
+  iree_hal_amdgpu_agent_first_isa_t* first_isa =
+      (iree_hal_amdgpu_agent_first_isa_t*)user_data;
+  if (first_isa->count++ == 0) {
+    first_isa->value = isa;
+  }
+  return HSA_STATUS_SUCCESS;
+}
+
+static bool iree_hal_amdgpu_parse_hex_digit(char c, uint32_t* out_value) {
+  if (c >= '0' && c <= '9') {
+    *out_value = (uint32_t)(c - '0');
+    return true;
+  } else if (c >= 'a' && c <= 'f') {
+    *out_value = (uint32_t)(c - 'a' + 10);
+    return true;
+  } else if (c >= 'A' && c <= 'F') {
+    *out_value = (uint32_t)(c - 'A' + 10);
+    return true;
+  }
+  return false;
+}
+
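+// Queries the first HSA ISA reported by |agent| and parses it into a target
+// ID. The processor name is copied into |target_id_processor_storage| so the
+// returned views remain valid after the transient ISA name buffer goes out of
+// scope.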
+static iree_status_t iree_hal_amdgpu_query_agent_target_id(
+    const iree_hal_amdgpu_libhsa_t* libhsa, hsa_agent_t agent,
+    iree_host_size_t target_id_processor_capacity,
+    char* target_id_processor_storage,
+    iree_hal_amdgpu_target_id_t* out_target_id) {
+  IREE_ASSERT_ARGUMENT(libhsa);
+  IREE_ASSERT_ARGUMENT(target_id_processor_storage);
+  IREE_ASSERT_ARGUMENT(out_target_id);
+
+  iree_hal_amdgpu_agent_first_isa_t first_isa;
+  memset(&first_isa, 0, sizeof(first_isa));
+  IREE_RETURN_IF_ERROR(iree_hsa_agent_iterate_isas(
+      IREE_LIBHSA(libhsa), agent, iree_hal_amdgpu_record_first_isa,
+      &first_isa));
+  if (first_isa.count == 0) {
+    return iree_make_status(IREE_STATUS_NOT_FOUND,
+                            "GPU agent has no reported HSA ISA");
+  }
+
+  char isa_name_buffer[128] = {0};
+  uint32_t isa_name_length = 0;
+  IREE_RETURN_IF_ERROR(
+      iree_hsa_isa_get_info_alt(IREE_LIBHSA(libhsa), first_isa.value,
+                                HSA_ISA_INFO_NAME_LENGTH, &isa_name_length));
+  if (isa_name_length == 0 ||
+      isa_name_length > IREE_ARRAYSIZE(isa_name_buffer)) {
+    return iree_make_status(IREE_STATUS_OUT_OF_RANGE,
+                            "ISA name length invalid: %u", isa_name_length);
+  }
+  IREE_RETURN_IF_ERROR(
+      iree_hsa_isa_get_info_alt(IREE_LIBHSA(libhsa), first_isa.value,
+                                HSA_ISA_INFO_NAME, isa_name_buffer));
+  iree_hal_amdgpu_target_id_t target_id;
+  IREE_RETURN_IF_ERROR(iree_hal_amdgpu_target_id_parse_hsa_isa_name(
+      iree_make_string_view(isa_name_buffer, isa_name_length - /*NUL*/ 1),
+      &target_id));
+  if (target_id.processor.size >= target_id_processor_capacity) {
+    return iree_make_status(IREE_STATUS_OUT_OF_RANGE,
+                            "target ID processor storage too small");
+  }
+  memcpy(target_id_processor_storage, target_id.processor.data,
+         target_id.processor.size);
+  target_id_processor_storage[target_id.processor.size] = 0;
+  target_id.processor = iree_make_string_view(target_id_processor_storage,
+                                              target_id.processor.size);
+  *out_target_id = target_id;
+  return iree_ok_status();
+}
+
+static iree_status_t iree_hal_amdgpu_query_agent_uuid(
+    const iree_hal_amdgpu_libhsa_t* libhsa, hsa_agent_t agent,
+    uint8_t out_uuid[16], bool* out_has_uuid) {
+  IREE_ASSERT_ARGUMENT(libhsa);
+  IREE_ASSERT_ARGUMENT(out_uuid);
+  IREE_ASSERT_ARGUMENT(out_has_uuid);
+
+  memset(out_uuid, 0, 16);
+  *out_has_uuid = false;
+
+  // HSA returns a prefixed ASCII string such as "GPU-4939e1d93d24ff77".
+  // Unsupported devices may return fallback strings such as "GPU-XX"; those
+  // are valid HSA responses but not stable identifiers for profile records.
+  char uuid_string[64] = {0};
+  IREE_RETURN_IF_ERROR(iree_hsa_agent_get_info(
+      IREE_LIBHSA(libhsa), agent, (hsa_agent_info_t)HSA_AMD_AGENT_INFO_UUID,
+      uuid_string));
+
+  iree_string_view_t uuid_hex = iree_make_cstring_view(uuid_string);
+  if (!iree_string_view_consume_prefix(&uuid_hex, IREE_SV("GPU-")) &&
+      !iree_string_view_consume_prefix(&uuid_hex, IREE_SV("CPU-")) &&
+      !iree_string_view_consume_prefix(&uuid_hex, IREE_SV("DSP-")) &&
+      !iree_string_view_consume_prefix(&uuid_hex, IREE_SV("AIE-"))) {
+    return iree_ok_status();
+  }
+
+  iree_host_size_t parsed_length = 0;
+  if (uuid_hex.size == 16) {
+    parsed_length = 8;
+  } else if (uuid_hex.size == 32) {
+    parsed_length = 16;
+  } else {
+    return iree_ok_status();
+  }
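+  // Prevalidate that the identifier is pure hex; anything else (such as the
+  // "GPU-XX" fallback noted above) is treated as "no UUID" rather than an
+  // error.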
+  for (iree_host_size_t i = 0; i < uuid_hex.size; ++i) {
+    uint32_t value = 0;
+    if (!iree_hal_amdgpu_parse_hex_digit(uuid_hex.data[i], &value)) {
+      return iree_ok_status();
+    }
+  }
+
+  if (!iree_string_view_parse_hex_bytes(uuid_hex, parsed_length, out_uuid)) {
+    return iree_make_status(
+        IREE_STATUS_INTERNAL,
+        "HSA device UUID was prevalidated but failed to parse: %.*s",
+        (int)uuid_hex.size, uuid_hex.data);
+  }
+  *out_has_uuid = true;
+  return iree_ok_status();
+}
+
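+// Queries the PCI domain and packed BDFID of |agent|, decoding bus[15:8],
+// device[7:3], and function[2:0] into the physical device identity fields.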
+static iree_status_t iree_hal_amdgpu_query_agent_pci_identity(
+    const iree_hal_amdgpu_libhsa_t* libhsa, hsa_agent_t agent,
+    iree_hal_amdgpu_physical_device_t* out_physical_device) {
+  uint32_t pci_domain = 0;
+  IREE_RETURN_IF_ERROR(iree_hsa_agent_get_info(
+      IREE_LIBHSA(libhsa), agent, (hsa_agent_info_t)HSA_AMD_AGENT_INFO_DOMAIN,
+      &pci_domain));
+  uint32_t bdfid = 0;
+  IREE_RETURN_IF_ERROR(iree_hsa_agent_get_info(
+      IREE_LIBHSA(libhsa), agent, (hsa_agent_info_t)HSA_AMD_AGENT_INFO_BDFID,
+      &bdfid));
+
+  out_physical_device->pci_domain = pci_domain;
+  out_physical_device->pci_bus = (bdfid >> 8) & 0xFFu;
+  out_physical_device->pci_device = (bdfid >> 3) & 0x1Fu;
+  out_physical_device->pci_function = bdfid & 0x7u;
+  out_physical_device->has_pci_identity = 1u;
+  return iree_ok_status();
+}
+
+static bool iree_hal_amdgpu_physical_device_query_pool_epoch(
+    void* user_data, iree_async_axis_t axis, uint64_t epoch) {
+  iree_hal_amdgpu_epoch_signal_table_t* epoch_signal_table =
+      (iree_hal_amdgpu_epoch_signal_table_t*)user_data;
+  hsa_signal_t epoch_signal = {0};
+  if (!iree_hal_amdgpu_epoch_signal_table_lookup(epoch_signal_table, axis,
+                                                 &epoch_signal)) {
+    return false;
+  }
+  iree_amd_signal_t* signal =
+      (iree_amd_signal_t*)(uintptr_t)epoch_signal.handle;
+  const iree_hsa_signal_value_t current_value = iree_atomic_load(
+      (iree_atomic_int64_t*)&signal->value, iree_memory_order_acquire);
+  if (IREE_UNLIKELY(current_value < 0 ||
+                    current_value > IREE_HAL_AMDGPU_EPOCH_INITIAL_VALUE)) {
+    return false;
+  }
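+  // Epoch signals count down from IREE_HAL_AMDGPU_EPOCH_INITIAL_VALUE as
+  // epochs retire, so the current epoch is the distance from the initial
+  // value.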
+  const uint64_t current_epoch =
+      (uint64_t)IREE_HAL_AMDGPU_EPOCH_INITIAL_VALUE - (uint64_t)current_value;
+  return current_epoch >= epoch;
+}
+
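+// Formats a stable per-pool trace name of the form
+// "iree-hal-amdgpu-l0p<device_ordinal>-<pool_name>" into |buffer|, clamping
+// the returned view on truncation and returning an empty view on failure.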
+static iree_string_view_t iree_hal_amdgpu_format_pool_trace_name(
+    char* buffer, iree_host_size_t buffer_capacity, const char* pool_name,
+    iree_host_size_t device_ordinal) {
+  if (IREE_UNLIKELY(buffer_capacity == 0)) return iree_string_view_empty();
+  const int name_length =
+      snprintf(buffer, buffer_capacity, "iree-hal-amdgpu-l0p%" PRIhsz "-%s",
+               device_ordinal, pool_name);
+  if (IREE_UNLIKELY(name_length < 0)) return iree_string_view_empty();
+  iree_host_size_t safe_length = (iree_host_size_t)name_length;
+  if (safe_length >= buffer_capacity) safe_length = buffer_capacity - 1;
+  return iree_make_string_view(buffer, safe_length);
+}
+
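+// Buffer usage advertised by the mappable default pool: transfer and dispatch
+// plus all sharing modes and both scoped and persistent host mapping.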
+static iree_hal_buffer_usage_t
+iree_hal_amdgpu_physical_device_mappable_pool_supported_usage(void) {
+  const iree_hal_buffer_usage_t sharing_usage =
+      IREE_HAL_BUFFER_USAGE_SHARING_REPLICATE |
+      IREE_HAL_BUFFER_USAGE_SHARING_CONCURRENT |
+      IREE_HAL_BUFFER_USAGE_SHARING_IMMUTABLE;
+  return IREE_HAL_BUFFER_USAGE_TRANSFER | IREE_HAL_BUFFER_USAGE_DISPATCH |
+         sharing_usage | IREE_HAL_BUFFER_USAGE_MAPPING_SCOPED |
+         IREE_HAL_BUFFER_USAGE_MAPPING_PERSISTENT |
+         IREE_HAL_BUFFER_USAGE_MAPPING_OPTIONAL |
+         IREE_HAL_BUFFER_USAGE_MAPPING_ACCESS_RANDOM |
+         IREE_HAL_BUFFER_USAGE_MAPPING_ACCESS_SEQUENTIAL_WRITE;
+}
+
 void iree_hal_amdgpu_physical_device_options_initialize(
     iree_hal_amdgpu_physical_device_options_t* out_options) {
   IREE_ASSERT_ARGUMENT(out_options);
@@ -32,8 +246,8 @@
 
   out_options->device_block_pools.small.block_size =
       IREE_HAL_AMDGPU_PHYSICAL_DEVICE_SMALL_DEVICE_BLOCK_SIZE_DEFAULT;
-  out_options->device_block_pools.large.min_blocks_per_allocation =
-      IREE_HAL_AMDGPU_PHYSICAL_DEVICE_LARGE_DEVICE_BLOCKS_PER_ALLOCATION_DEFAULT;
+  out_options->device_block_pools.small.min_blocks_per_allocation =
+      IREE_HAL_AMDGPU_PHYSICAL_DEVICE_SMALL_DEVICE_BLOCKS_PER_ALLOCATION_DEFAULT;
   out_options->device_block_pools.small.initial_capacity =
       IREE_HAL_AMDGPU_PHYSICAL_DEVICE_SMALL_DEVICE_BLOCK_INITIAL_CAPACITY_DEFAULT;
 
@@ -47,9 +261,25 @@
   out_options->host_block_pool_size =
       IREE_HAL_AMDGPU_PHYSICAL_DEVICE_HOST_BLOCK_SIZE_DEFAULT;
 
-  out_options->queue_count =
+  out_options->host_queue_count =
       IREE_HAL_AMDGPU_PHYSICAL_DEVICE_DEFAULT_QUEUE_COUNT;
-  iree_hal_amdgpu_queue_options_initialize(&out_options->queue_options);
+  out_options->host_queue_aql_capacity =
+      IREE_HAL_AMDGPU_PHYSICAL_DEVICE_DEFAULT_HOST_QUEUE_AQL_CAPACITY;
+  out_options->host_queue_notification_capacity =
+      IREE_HAL_AMDGPU_PHYSICAL_DEVICE_DEFAULT_HOST_QUEUE_NOTIFICATION_CAPACITY;
+  out_options->host_queue_kernarg_capacity =
+      IREE_HAL_AMDGPU_PHYSICAL_DEVICE_DEFAULT_HOST_QUEUE_KERNARG_CAPACITY;
+  out_options->host_queue_upload_capacity =
+      IREE_HAL_AMDGPU_PHYSICAL_DEVICE_DEFAULT_HOST_QUEUE_UPLOAD_CAPACITY;
+
+  out_options->default_pool.range_length =
+      IREE_HAL_AMDGPU_PHYSICAL_DEVICE_DEFAULT_POOL_RANGE_LENGTH_DEFAULT;
+  out_options->default_pool.alignment =
+      IREE_HAL_AMDGPU_PHYSICAL_DEVICE_DEFAULT_POOL_ALIGNMENT_DEFAULT;
+  out_options->default_pool.frontier_capacity =
+      IREE_HAL_AMDGPU_PHYSICAL_DEVICE_DEFAULT_POOL_FRONTIER_CAPACITY_DEFAULT;
+
+  iree_hal_amdgpu_staging_pool_options_initialize(&out_options->file_staging);
 }
 
 iree_status_t iree_hal_amdgpu_physical_device_options_verify(
@@ -79,38 +309,70 @@
         IREE_STATUS_OUT_OF_RANGE,
         "large device block pool size invalid, expected a "
         "power-of-two greater than %d and got %" PRIhsz,
-        IREE_HAL_AMDGPU_PHYSICAL_DEVICE_LARGE_DEVICE_BLOCK_SIZE_DEFAULT,
+        IREE_HAL_AMDGPU_PHYSICAL_DEVICE_MIN_LARGE_DEVICE_BLOCK_SIZE,
         options->device_block_pools.large.block_size);
   }
 
-  // Verify each queue - if used - is valid.
-  if (options->queue_count > 0 && options->queue_count <= 64) {
-    return iree_make_status(IREE_STATUS_OUT_OF_RANGE,
-                            "a physical device may only have 1-64 HAL queues");
-  }
-  IREE_RETURN_IF_ERROR(
-      iree_hal_amdgpu_queue_options_verify(&options->queue_options, libhsa,
-                                           cpu_agent, gpu_agent),
-      "verifying queue options");
-
-  // Verify that the total hardware queues required is less than the total
-  // available on the device. Some of those queues may be in use at the time we
-  // are running and so we may fail to allocate them all even if this reports
-  // OK.
-  uint32_t max_queue_count = 0;
-  IREE_RETURN_IF_ERROR(
-      iree_hsa_agent_get_info(IREE_LIBHSA(libhsa), gpu_agent,
-                              HSA_AGENT_INFO_QUEUES_MAX, &max_queue_count),
-      "querying HSA_AGENT_INFO_QUEUES_MAX");
-  const iree_host_size_t device_queue_count =
-      options->queue_count * options->queue_options.execution_queue_count;
-  if (device_queue_count > max_queue_count) {
+  if (options->host_queue_count == 0 || options->host_queue_count > UINT8_MAX) {
     return iree_make_status(
         IREE_STATUS_OUT_OF_RANGE,
-        "maximum hardware queue count exceeded; device reports %u available "
-        "queues (at maximum) and %" PRIhsz " were requested",
-        max_queue_count, device_queue_count);
+        "host queue count must be in [1, %u] to fit the queue-axis encoding "
+        "(got %" PRIhsz ")",
+        UINT8_MAX, options->host_queue_count);
   }
+  if (!iree_host_size_is_power_of_two(options->host_queue_aql_capacity) ||
+      !iree_host_size_is_power_of_two(
+          options->host_queue_notification_capacity) ||
+      !iree_host_size_is_power_of_two(options->host_queue_kernarg_capacity) ||
+      (options->host_queue_upload_capacity != 0 &&
+       !iree_host_size_is_power_of_two(options->host_queue_upload_capacity))) {
+    return iree_make_status(
+        IREE_STATUS_OUT_OF_RANGE,
+        "host queue AQL, notification, kernarg, and upload capacities must all "
+        "be powers of two, with zero allowed for disabled upload capacity (got "
+        "aql=%u, notification=%u, kernarg_blocks=%u, upload_bytes=%u)",
+        options->host_queue_aql_capacity,
+        options->host_queue_notification_capacity,
+        options->host_queue_kernarg_capacity,
+        options->host_queue_upload_capacity);
+  }
+  if (options->host_queue_kernarg_capacity / 2u <
+      options->host_queue_aql_capacity) {
+    return iree_make_status(
+        IREE_STATUS_OUT_OF_RANGE,
+        "host queue kernarg capacity must be at least 2x the AQL queue "
+        "capacity to cover one tail-padding gap at wrap (got "
+        "kernarg_blocks=%u, "
+        "aql_packets=%u)",
+        options->host_queue_kernarg_capacity, options->host_queue_aql_capacity);
+  }
+
+  if (options->default_pool.range_length == 0) {
+    return iree_make_status(IREE_STATUS_OUT_OF_RANGE,
+                            "default pool range_length must be non-zero");
+  }
+  if (options->default_pool.alignment < IREE_HAL_MEMORY_TLSF_MIN_ALIGNMENT ||
+      !iree_device_size_is_power_of_two(options->default_pool.alignment)) {
+    return iree_make_status(
+        IREE_STATUS_OUT_OF_RANGE,
+        "default pool alignment must be a power of two >= %" PRIu64
+        " (got %" PRIu64 ")",
+        (uint64_t)IREE_HAL_MEMORY_TLSF_MIN_ALIGNMENT,
+        (uint64_t)options->default_pool.alignment);
+  }
+  if (options->default_pool.range_length < options->default_pool.alignment) {
+    return iree_make_status(IREE_STATUS_OUT_OF_RANGE,
+                            "default pool range_length (%" PRIu64
+                            ") must be >= alignment (%" PRIu64 ")",
+                            (uint64_t)options->default_pool.range_length,
+                            (uint64_t)options->default_pool.alignment);
+  }
+  if (options->default_pool.frontier_capacity == 0) {
+    return iree_make_status(IREE_STATUS_OUT_OF_RANGE,
+                            "default pool frontier_capacity must be non-zero");
+  }
+  IREE_RETURN_IF_ERROR(
+      iree_hal_amdgpu_staging_pool_options_verify(&options->file_staging));
 
   return iree_ok_status();
 }
@@ -119,208 +381,846 @@
 // iree_hal_amdgpu_physical_device_t
 //===----------------------------------------------------------------------===//
 
-// Calculates the size in bytes of the storage required for a queue
-// implementation based on the provided |options|.
-static iree_host_size_t iree_hal_amdgpu_queue_calculate_size(
-    const iree_hal_amdgpu_queue_options_t* options) {
-  switch (options->placement) {
-    case IREE_HAL_AMDGPU_QUEUE_PLACEMENT_HOST:
-      return iree_hal_amdgpu_host_queue_calculate_size(options);
-    case IREE_HAL_AMDGPU_QUEUE_PLACEMENT_DEVICE:
-      return iree_hal_amdgpu_device_queue_calculate_size(options);
-    default:
-      IREE_ASSERT_UNREACHABLE("queue placement should have resolved earlier");
-      return 0;  // invalid handled elsewhere
-  }
-}
-
-// Initializes |out_queue| in-place based on |options|.
-static iree_status_t iree_hal_amdgpu_queue_initialize(
-    iree_hal_amdgpu_system_t* system, iree_hal_amdgpu_queue_options_t options,
-    hsa_agent_t device_agent, iree_host_size_t device_ordinal,
-    iree_hal_amdgpu_host_service_t* host_service,
-    iree_arena_block_pool_t* host_block_pool,
-    iree_hal_amdgpu_block_allocators_t* block_allocators,
-    iree_hal_amdgpu_buffer_pool_t* buffer_pool,
-    iree_hal_amdgpu_error_callback_t error_callback,
-    hsa_signal_t initialization_signal, iree_allocator_t host_allocator,
-    iree_hal_amdgpu_virtual_queue_t* out_queue) {
-  // Note that today these mostly take the same arguments but we may need to
-  // provide more for device queues (such as peer information).
-  switch (options.placement) {
-    case IREE_HAL_AMDGPU_QUEUE_PLACEMENT_HOST:
-      return iree_hal_amdgpu_host_queue_initialize(
-          system, options, device_agent, device_ordinal, host_service,
-          host_block_pool, block_allocators, buffer_pool, error_callback,
-          initialization_signal, host_allocator, out_queue);
-    case IREE_HAL_AMDGPU_QUEUE_PLACEMENT_DEVICE:
-      return iree_hal_amdgpu_device_queue_initialize(
-          system, options, device_agent, device_ordinal, host_service,
-          host_block_pool, block_allocators, buffer_pool, error_callback,
-          initialization_signal, host_allocator, out_queue);
-    default:
-      return iree_make_status(IREE_STATUS_INVALID_ARGUMENT,
-                              "invalid queue placement %d",
-                              (int)options.placement);
-  }
-}
-
 iree_host_size_t iree_hal_amdgpu_physical_device_calculate_size(
     const iree_hal_amdgpu_physical_device_options_t* options) {
-  const iree_host_size_t base_size =
+  IREE_ASSERT_ARGUMENT(options);
+  return iree_host_align(
       sizeof(iree_hal_amdgpu_physical_device_t) +
-      iree_host_align(
-          options->queue_count * sizeof(iree_hal_amdgpu_virtual_queue_t*),
-          iree_max_align_t);
-  const iree_host_size_t queue_size =
-      iree_hal_amdgpu_queue_calculate_size(&options->queue_options);
-  return iree_host_align(base_size + options->queue_count * queue_size,
-                         iree_max_align_t);
+          sizeof(iree_hal_amdgpu_host_queue_t) * options->host_queue_count,
+      iree_max_align_t);
 }
 
-iree_status_t iree_hal_amdgpu_physical_device_initialize(
+static iree_status_t iree_hal_amdgpu_physical_device_initialize_identity(
     iree_hal_amdgpu_system_t* system,
     const iree_hal_amdgpu_physical_device_options_t* options,
     iree_host_size_t host_ordinal,
     const iree_hal_amdgpu_host_memory_pools_t* host_memory_pools,
-    iree_host_size_t device_ordinal, iree_hal_amdgpu_buffer_pool_t* buffer_pool,
-    iree_hal_amdgpu_error_callback_t error_callback,
-    hsa_signal_t initialization_signal, iree_allocator_t host_allocator,
+    iree_host_size_t device_ordinal,
     iree_hal_amdgpu_physical_device_t* out_physical_device) {
-  IREE_ASSERT_ARGUMENT(out_physical_device);
-  IREE_TRACE_ZONE_BEGIN(z0);
-
   iree_hal_amdgpu_libhsa_t* libhsa = &system->libhsa;
-
-  hsa_agent_t host_agent = system->topology.cpu_agents[host_ordinal];
   hsa_agent_t device_agent = system->topology.gpu_agents[device_ordinal];
+  hsa_agent_t host_agent = system->topology.cpu_agents[host_ordinal];
 
-  // Zeroing allows for deinitialization to happen midway through initialization
-  // if something fails.
+  // Zeroing allows deinitialization to run after any partial initialization
+  // failure below.
   memset(out_physical_device, 0, sizeof(*out_physical_device));
-
   out_physical_device->device_agent = device_agent;
   out_physical_device->device_ordinal = device_ordinal;
-  out_physical_device->queue_count = options->queue_count;
+  out_physical_device->host_memory_pools = *host_memory_pools;
+  out_physical_device->host_queue_capacity = options->host_queue_count;
+  out_physical_device->host_queue_aql_capacity =
+      options->host_queue_aql_capacity;
+  out_physical_device->host_queue_notification_capacity =
+      options->host_queue_notification_capacity;
+  out_physical_device->host_queue_kernarg_capacity =
+      options->host_queue_kernarg_capacity;
+  out_physical_device->host_queue_upload_capacity =
+      options->host_queue_upload_capacity;
 
-  // Setup queue pointers before anything else so that when we deinitialize they
-  // are valid.
-  const iree_host_size_t total_queue_size =
-      iree_hal_amdgpu_queue_calculate_size(&options->queue_options);
-  uint8_t* queue_base_ptr =
-      (uint8_t*)out_physical_device + sizeof(*out_physical_device) +
-      iree_host_align(
-          options->queue_count * sizeof(out_physical_device->queues[0]),
-          iree_max_align_t);
-  for (iree_host_size_t i = 0; i < options->queue_count; ++i) {
-    out_physical_device->queues[i] =
-        (iree_hal_amdgpu_virtual_queue_t*)(queue_base_ptr +
-                                           i * total_queue_size);
-  }
+  IREE_RETURN_IF_ERROR(
+      iree_hsa_agent_get_info(IREE_LIBHSA(libhsa), device_agent,
+                              (hsa_agent_info_t)HSA_AMD_AGENT_INFO_DRIVER_UID,
+                              &out_physical_device->driver_uid));
+  IREE_RETURN_IF_ERROR(iree_hal_amdgpu_query_agent_pci_identity(
+      libhsa, device_agent, out_physical_device));
+  bool has_physical_device_uuid = false;
+  IREE_RETURN_IF_ERROR(iree_hal_amdgpu_query_agent_uuid(
+      libhsa, device_agent, out_physical_device->physical_device_uuid,
+      &has_physical_device_uuid));
+  out_physical_device->has_physical_device_uuid =
+      has_physical_device_uuid ? 1u : 0u;
+  uint32_t host_numa_node = UINT32_MAX;
+  IREE_RETURN_IF_ERROR(iree_hsa_agent_get_info(
+      IREE_LIBHSA(libhsa), host_agent, HSA_AGENT_INFO_NODE, &host_numa_node));
+  out_physical_device->host_numa_node = host_numa_node;
+  return iree_ok_status();
+}
 
-  // Find the device pools used for blocks.
-  hsa_amd_memory_pool_t coarse_block_memory_pool = {0};
-  IREE_RETURN_AND_END_ZONE_IF_ERROR(
-      z0, iree_hal_amdgpu_find_coarse_global_memory_pool(
-              libhsa, device_agent, &coarse_block_memory_pool));
-  hsa_amd_memory_pool_t fine_block_memory_pool = {0};
-  IREE_RETURN_AND_END_ZONE_IF_ERROR(
-      z0, iree_hal_amdgpu_find_fine_global_memory_pool(
-              libhsa, device_agent, &fine_block_memory_pool));
-
-  // Initialize the per-device block pool.
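+// Initializes the fine-grained host block pool and the transient/materialized
+// buffer pools that suballocate from it.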
+static iree_status_t iree_hal_amdgpu_physical_device_initialize_host_pools(
+    const iree_hal_amdgpu_physical_device_options_t* options,
+    iree_allocator_t host_allocator,
+    iree_hal_amdgpu_physical_device_t* out_physical_device) {
   // This should be pinned to the host NUMA node associated with the devices but
   // today we rely on the OS to migrate pages as needed.
   iree_arena_block_pool_initialize(options->host_block_pool_size,
                                    host_allocator,
                                    &out_physical_device->fine_host_block_pool);
-  if (options->host_block_pool_initial_capacity) {
-    IREE_RETURN_AND_END_ZONE_IF_ERROR(
-        z0, iree_arena_block_pool_preallocate(
-                &out_physical_device->fine_host_block_pool,
-                options->host_block_pool_initial_capacity));
+  IREE_RETURN_IF_ERROR(iree_hal_amdgpu_transient_buffer_pool_initialize(
+      &out_physical_device->fine_host_block_pool,
+      &out_physical_device->transient_buffer_pool));
+  return iree_hal_amdgpu_buffer_pool_initialize(
+      &out_physical_device->fine_host_block_pool,
+      &out_physical_device->materialized_buffer_pool);
+}
+
+static iree_status_t iree_hal_amdgpu_physical_device_query_global_memory_pools(
+    iree_hal_amdgpu_libhsa_t* libhsa, hsa_agent_t device_agent,
+    hsa_amd_memory_pool_t* out_coarse_block_memory_pool,
+    hsa_amd_memory_pool_t* out_fine_block_memory_pool) {
+  iree_status_t status = iree_hal_amdgpu_find_coarse_global_memory_pool(
+      libhsa, device_agent, out_coarse_block_memory_pool);
+  if (iree_status_is_ok(status)) {
+    status = iree_hal_amdgpu_find_fine_global_memory_pool(
+        libhsa, device_agent, out_fine_block_memory_pool);
+  }
+  if (!iree_status_is_ok(status)) {
+    status = iree_status_annotate(
+        status, IREE_SV("AMDGPU physical device requires coarse and fine "
+                        "device-local global memory pools"));
+  }
+  return status;
+}
+
+typedef struct iree_hal_amdgpu_physical_device_kernarg_ring_memory_t {
+  // Descriptor consumed by host queue initialization.
+  iree_hal_amdgpu_kernarg_ring_memory_t descriptor;
+  // Host fallback access-agent list referenced by |descriptor|.
+  hsa_agent_t host_access_agents[1];
+} iree_hal_amdgpu_physical_device_kernarg_ring_memory_t;
+
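+// Queries |agent|'s access to |memory_pool|, leaving |out_access| as
+// never-allowed if the query fails.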
+static iree_status_t iree_hal_amdgpu_physical_device_query_memory_pool_access(
+    const iree_hal_amdgpu_libhsa_t* libhsa, hsa_agent_t agent,
+    hsa_amd_memory_pool_t memory_pool,
+    hsa_amd_memory_pool_access_t* out_access) {
+  *out_access = HSA_AMD_MEMORY_POOL_ACCESS_NEVER_ALLOWED;
+  return iree_hsa_amd_agent_memory_pool_get_info(
+      IREE_LIBHSA(libhsa), agent, memory_pool,
+      HSA_AMD_AGENT_MEMORY_POOL_INFO_ACCESS, out_access);
+}
+
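+// Backs the kernarg ring with the host-side kernarg pool, granting the device
+// agent access.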
+static void iree_hal_amdgpu_physical_device_use_host_kernarg_memory(
+    const iree_hal_amdgpu_host_memory_pools_t* host_memory_pools,
+    hsa_agent_t device_agent,
+    iree_hal_amdgpu_physical_device_kernarg_ring_memory_t* out_memory) {
+  memset(out_memory, 0, sizeof(*out_memory));
+  out_memory->host_access_agents[0] = device_agent;
+  out_memory->descriptor = (iree_hal_amdgpu_kernarg_ring_memory_t){
+      .memory_pool = host_memory_pools->kernarg_pool,
+      .access_agents = out_memory->host_access_agents,
+      .access_agent_count = 1,
+  };
+}
+
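+// Backs the kernarg ring with CPU-visible device coarse memory using the
+// capability's access agents and host-write publication mode.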
+static void iree_hal_amdgpu_physical_device_use_cpu_visible_kernarg_memory(
+    const iree_hal_amdgpu_cpu_visible_device_coarse_memory_t* capability,
+    iree_hal_amdgpu_physical_device_kernarg_ring_memory_t* out_memory) {
+  memset(out_memory, 0, sizeof(*out_memory));
+  out_memory->descriptor = (iree_hal_amdgpu_kernarg_ring_memory_t){
+      .memory_pool = capability->memory_pool,
+      .access_agents = capability->access_agents,
+      .access_agent_count = capability->access_agent_count,
+      .publication = capability->host_write_publication,
+  };
+}
+
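+// Queries the HDP flush registers for |device_agent|; returns zeroed registers
+// when the query fails so callers treat host-write publication as unsupported.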
+static hsa_amd_hdp_flush_t
+iree_hal_amdgpu_physical_device_query_hdp_flush_registers(
+    const iree_hal_amdgpu_libhsa_t* libhsa, hsa_agent_t device_agent) {
+  hsa_amd_hdp_flush_t hdp_flush = {0};
+  const hsa_status_t hsa_status = iree_hsa_agent_get_info_raw(
+      libhsa, device_agent, (hsa_agent_info_t)HSA_AMD_AGENT_INFO_HDP_FLUSH,
+      &hdp_flush);
+  if (hsa_status != HSA_STATUS_SUCCESS) {
+    memset(&hdp_flush, 0, sizeof(hdp_flush));
+  }
+  return hdp_flush;
+}
+
+static iree_status_t
+iree_hal_amdgpu_physical_device_query_svm_direct_host_access(
+    const iree_hal_amdgpu_libhsa_t* libhsa, hsa_agent_t device_agent,
+    bool* out_direct_host_access) {
+  *out_direct_host_access = false;
+  return iree_hsa_agent_get_info(
+      IREE_LIBHSA(libhsa), device_agent,
+      (hsa_agent_info_t)HSA_AMD_AGENT_INFO_SVM_DIRECT_HOST_ACCESS,
+      out_direct_host_access);
+}
+
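+// Detects CPU-visible coarse-grained device memory usable for host-write
+// publication. |out_memory| is left zeroed (capability unavailable) when any
+// gating check fails: missing pool, no CPU agents, unsupported publication,
+// disallowed gfx IP, or missing HDP flush registers.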
+static iree_status_t
+iree_hal_amdgpu_physical_device_initialize_cpu_visible_device_coarse_memory(
+    const iree_hal_amdgpu_libhsa_t* libhsa, hsa_agent_t device_agent,
+    hsa_amd_memory_pool_t device_coarse_memory_pool,
+    iree_hal_amdgpu_gfxip_version_t gfxip_version,
+    const iree_hal_amdgpu_topology_t* topology,
+    iree_hal_amdgpu_cpu_visible_device_coarse_memory_t* out_memory) {
+  memset(out_memory, 0, sizeof(*out_memory));
+  if (!device_coarse_memory_pool.handle || topology->cpu_agent_count == 0) {
+    return iree_ok_status();
+  }
+  if (IREE_UNLIKELY(topology->cpu_agent_count >
+                    IREE_HAL_AMDGPU_MAX_CPU_AGENT)) {
+    return iree_make_status(
+        IREE_STATUS_OUT_OF_RANGE,
+        "AMDGPU topology has %" PRIhsz
+        " CPU agents but CPU-visible coarse memory tracks at most %d",
+        topology->cpu_agent_count, IREE_HAL_AMDGPU_MAX_CPU_AGENT);
+  }
+  if (!iree_hal_amdgpu_kernarg_ring_supports_host_write_publication()) {
+    return iree_ok_status();
+  }
+  if (!iree_hal_amdgpu_gfxip_allows_hdp_kernarg_publication(gfxip_version)) {
+    return iree_ok_status();
   }
 
-  // Create block pools and allocators used to back device-side resources.
-  // Shared amongst all queues on the device.
+  const hsa_amd_hdp_flush_t hdp_flush =
+      iree_hal_amdgpu_physical_device_query_hdp_flush_registers(libhsa,
+                                                                device_agent);
+  if (!hdp_flush.HDP_MEM_FLUSH_CNTL || !hdp_flush.HDP_REG_FLUSH_CNTL) {
+    return iree_ok_status();
+  }
+
+  hsa_amd_memory_pool_access_t cpu_access[IREE_HAL_AMDGPU_MAX_CPU_AGENT] = {0};
+  for (iree_host_size_t i = 0; i < topology->cpu_agent_count; ++i) {
+    IREE_RETURN_IF_ERROR(
+        iree_hal_amdgpu_physical_device_query_memory_pool_access(
+            libhsa, topology->cpu_agents[i], device_coarse_memory_pool,
+            &cpu_access[i]));
+  }
+
+  const iree_hal_amdgpu_cpu_visible_device_coarse_memory_selection_t
+      selection = {
+      .device_agent = device_agent,
+      .memory_pool = device_coarse_memory_pool,
+      .gfxip_version = gfxip_version,
+      .cpu =
+          {
+              .agents = topology->cpu_agents,
+              .access = cpu_access,
+              .count = topology->cpu_agent_count,
+          },
+      .hdp =
+          {
+              .registers = hdp_flush,
+          },
+      .flags =
+          IREE_HAL_AMDGPU_CPU_VISIBLE_DEVICE_COARSE_MEMORY_SELECTION_FLAG_HOST_WRITE_PUBLICATION_SUPPORTED,
+  };
+  return iree_hal_amdgpu_select_cpu_visible_device_coarse_memory(&selection,
+                                                                 out_memory);
+}
+
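+// Derives conservative memory-system capability flags from system SVM support
+// and device-local memory pool facts.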
+static iree_status_t
+iree_hal_amdgpu_physical_device_initialize_memory_system_capabilities(
+    const iree_hal_amdgpu_libhsa_t* libhsa,
+    const iree_hal_amdgpu_system_info_t* system_info, hsa_agent_t device_agent,
+    hsa_amd_memory_pool_t fine_block_memory_pool,
+    const iree_hal_amdgpu_cpu_visible_device_coarse_memory_t*
+        cpu_visible_device_coarse_memory,
+    iree_hal_amdgpu_memory_system_capabilities_t* out_capabilities) {
+  bool svm_direct_host_access = false;
+  IREE_RETURN_IF_ERROR(
+      iree_hal_amdgpu_physical_device_query_svm_direct_host_access(
+          libhsa, device_agent, &svm_direct_host_access));
+
+  const iree_hal_amdgpu_memory_system_capabilities_selection_t selection = {
+      .svm =
+          {
+              .supported = system_info->svm.supported,
+              .accessible_by_default = system_info->svm.accessible_by_default,
+              .xnack_enabled = system_info->svm.xnack_enabled,
+              .direct_host_access = svm_direct_host_access ? 1u : 0u,
+          },
+      .device_local =
+          {
+              .fine_memory_pool = fine_block_memory_pool,
+              .coarse_cpu_visible_memory = cpu_visible_device_coarse_memory,
+          },
+  };
+  iree_hal_amdgpu_select_memory_system_capabilities(&selection,
+                                                    out_capabilities);
+  return iree_ok_status();
+}
+
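+// Prefers CPU-visible device-local coarse memory for the kernarg ring and
+// falls back to host kernarg memory when the capability is unavailable.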
+static void iree_hal_amdgpu_physical_device_select_kernarg_ring_memory(
+    const iree_hal_amdgpu_physical_device_t* physical_device,
+    const iree_hal_amdgpu_host_memory_pools_t* host_memory_pools,
+    iree_hal_amdgpu_physical_device_kernarg_ring_memory_t* out_memory) {
+  iree_hal_amdgpu_physical_device_use_host_kernarg_memory(
+      host_memory_pools, physical_device->device_agent, out_memory);
+  if (!iree_hal_amdgpu_cpu_visible_device_coarse_memory_is_available(
+          &physical_device->cpu_visible_device_coarse_memory)) {
+    return;
+  }
+  iree_hal_amdgpu_physical_device_use_cpu_visible_kernarg_memory(
+      &physical_device->cpu_visible_device_coarse_memory, out_memory);
+}
+
+static iree_status_t iree_hal_amdgpu_physical_device_initialize_block_pool(
+    iree_hal_amdgpu_libhsa_t* libhsa,
+    iree_hal_amdgpu_block_pool_options_t pool_options, hsa_agent_t device_agent,
+    hsa_amd_memory_pool_t memory_pool, const char* trace_name_prefix,
+    iree_host_size_t device_ordinal, iree_allocator_t host_allocator,
+    iree_hal_amdgpu_block_pool_t* out_block_pool) {
+  char trace_name[64] = {0};
+  pool_options.trace_name = iree_hal_amdgpu_format_pool_trace_name(
+      trace_name, IREE_ARRAYSIZE(trace_name), trace_name_prefix,
+      device_ordinal);
+  return iree_hal_amdgpu_block_pool_initialize(libhsa, pool_options,
+                                               device_agent, memory_pool,
+                                               host_allocator, out_block_pool);
+}
+
+static iree_status_t
+iree_hal_amdgpu_physical_device_initialize_device_block_pools_and_allocators(
+    iree_hal_amdgpu_libhsa_t* libhsa,
+    const iree_hal_amdgpu_physical_device_options_t* options,
+    hsa_agent_t device_agent, iree_host_size_t device_ordinal,
+    hsa_amd_memory_pool_t coarse_block_memory_pool,
+    hsa_amd_memory_pool_t fine_block_memory_pool,
+    iree_allocator_t host_allocator,
+    iree_hal_amdgpu_physical_device_t* out_physical_device) {
+  IREE_RETURN_IF_ERROR(iree_hal_amdgpu_physical_device_initialize_block_pool(
+      libhsa, options->device_block_pools.small, device_agent,
+      coarse_block_memory_pool, "coarse-small-block", device_ordinal,
+      host_allocator, &out_physical_device->coarse_block_pools.small));
+  IREE_RETURN_IF_ERROR(iree_hal_amdgpu_physical_device_initialize_block_pool(
+      libhsa, options->device_block_pools.large, device_agent,
+      coarse_block_memory_pool, "coarse-large-block", device_ordinal,
+      host_allocator, &out_physical_device->coarse_block_pools.large));
+  IREE_RETURN_IF_ERROR(iree_hal_amdgpu_physical_device_initialize_block_pool(
+      libhsa, options->device_block_pools.small, device_agent,
+      fine_block_memory_pool, "fine-small-block", device_ordinal,
+      host_allocator, &out_physical_device->fine_block_pools.small));
+  IREE_RETURN_IF_ERROR(iree_hal_amdgpu_physical_device_initialize_block_pool(
+      libhsa, options->device_block_pools.large, device_agent,
+      fine_block_memory_pool, "fine-large-block", device_ordinal,
+      host_allocator, &out_physical_device->fine_block_pools.large));
+
+  IREE_RETURN_IF_ERROR(iree_hal_amdgpu_block_allocator_initialize(
+      &out_physical_device->coarse_block_pools.small,
+      IREE_HAL_AMDGPU_PHYSICAL_DEVICE_COARSE_BLOCK_POOL_SMALL_PAGE_SIZE,
+      &out_physical_device->coarse_block_allocators.small));
+  IREE_RETURN_IF_ERROR(iree_hal_amdgpu_block_allocator_initialize(
+      &out_physical_device->coarse_block_pools.large,
+      IREE_HAL_AMDGPU_PHYSICAL_DEVICE_COARSE_BLOCK_POOL_LARGE_PAGE_SIZE,
+      &out_physical_device->coarse_block_allocators.large));
+  IREE_RETURN_IF_ERROR(iree_hal_amdgpu_block_allocator_initialize(
+      &out_physical_device->fine_block_pools.small,
+      IREE_HAL_AMDGPU_PHYSICAL_DEVICE_FINE_BLOCK_POOL_SMALL_PAGE_SIZE,
+      &out_physical_device->fine_block_allocators.small));
+  return iree_hal_amdgpu_block_allocator_initialize(
+      &out_physical_device->fine_block_pools.large,
+      IREE_HAL_AMDGPU_PHYSICAL_DEVICE_FINE_BLOCK_POOL_LARGE_PAGE_SIZE,
+      &out_physical_device->fine_block_allocators.large);
+}
+
+static iree_status_t iree_hal_amdgpu_physical_device_preallocate_host_pool(
+    const iree_hal_amdgpu_physical_device_options_t* options,
+    iree_hal_amdgpu_physical_device_t* out_physical_device) {
+  if (!options->host_block_pool_initial_capacity) return iree_ok_status();
+  return iree_arena_block_pool_preallocate(
+      &out_physical_device->fine_host_block_pool,
+      options->host_block_pool_initial_capacity);
+}
+
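+// Creates the default pool notification and the device/host slab providers,
+// validating requested alignments against HSA pool properties and deriving
+// the TLSF pool options.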
+static iree_status_t
+iree_hal_amdgpu_physical_device_initialize_default_pool_resources(
+    iree_hal_device_t* logical_device, iree_hal_amdgpu_system_t* system,
+    const iree_hal_amdgpu_physical_device_options_t* options,
+    iree_async_proactor_t* proactor, iree_host_size_t device_ordinal,
+    hsa_amd_memory_pool_t coarse_block_memory_pool,
+    iree_hal_queue_affinity_t queue_affinity_mask,
+    iree_allocator_t host_allocator,
+    iree_hal_amdgpu_physical_device_t* out_physical_device) {
+  iree_hal_amdgpu_libhsa_t* libhsa = &system->libhsa;
+
+  IREE_RETURN_IF_ERROR(iree_async_notification_create(
+      proactor, IREE_ASYNC_NOTIFICATION_FLAG_NONE,
+      &out_physical_device->default_pool_notification));
+
+  char trace_name[64] = {0};
+  iree_string_view_t slab_trace_name = iree_hal_amdgpu_format_pool_trace_name(
+      trace_name, IREE_ARRAYSIZE(trace_name), "default-slab", device_ordinal);
+  iree_hal_amdgpu_slab_provider_memory_pool_properties_t properties;
+  IREE_RETURN_IF_ERROR(
+      iree_hal_amdgpu_slab_provider_query_memory_pool_properties(
+          libhsa, coarse_block_memory_pool, &properties));
+  const iree_hal_amdgpu_slab_provider_options_t default_slab_options = {
+      .memory_pool = coarse_block_memory_pool,
+      .memory_type = properties.memory_type,
+      .supported_usage = properties.supported_usage,
+  };
+  IREE_RETURN_IF_ERROR(iree_hal_amdgpu_slab_provider_create(
+      logical_device, libhsa, &system->topology, default_slab_options,
+      device_ordinal, queue_affinity_mask,
+      &out_physical_device->materialized_buffer_pool, slab_trace_name,
+      host_allocator, &out_physical_device->default_slab_provider));
+
+  if (properties.allocation_alignment < options->default_pool.alignment) {
+    return iree_make_status(
+        IREE_STATUS_FAILED_PRECONDITION,
+        "default pool alignment %" PRIu64
+        " exceeds HSA memory pool allocation alignment %" PRIu64,
+        (uint64_t)options->default_pool.alignment,
+        (uint64_t)properties.allocation_alignment);
+  }
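+  // Round the requested pool range up to the HSA allocation granule.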
+  iree_device_size_t range_length = options->default_pool.range_length;
+  if (!iree_device_size_checked_align(
+          range_length, properties.allocation_granule, &range_length)) {
+    return iree_make_status(
+        IREE_STATUS_OUT_OF_RANGE,
+        "default pool range_length %" PRIu64
+        " overflows while aligning to HSA allocation granule %" PRIu64,
+        (uint64_t)options->default_pool.range_length,
+        (uint64_t)properties.allocation_granule);
+  }
+  out_physical_device->default_pool_options = (iree_hal_tlsf_pool_options_t){
+      .tlsf_options =
+          {
+              .range_length = range_length,
+              .alignment = options->default_pool.alignment,
+              .frontier_capacity = options->default_pool.frontier_capacity,
+          },
+      .budget_limit = 0,
+  };
+
+  iree_hal_amdgpu_slab_provider_memory_pool_properties_t host_properties;
+  IREE_RETURN_IF_ERROR(
+      iree_hal_amdgpu_slab_provider_query_memory_pool_properties(
+          libhsa, out_physical_device->host_memory_pools.fine_pool,
+          &host_properties));
+  if (host_properties.allocation_alignment < options->default_pool.alignment) {
+    return iree_make_status(
+        IREE_STATUS_FAILED_PRECONDITION,
+        "host queue allocation pool alignment %" PRIu64
+        " exceeds HSA host memory pool allocation alignment %" PRIu64,
+        (uint64_t)options->default_pool.alignment,
+        (uint64_t)host_properties.allocation_alignment);
+  }
+  char host_trace_name[64] = {0};
+  iree_string_view_t host_slab_trace_name =
+      iree_hal_amdgpu_format_pool_trace_name(host_trace_name,
+                                             IREE_ARRAYSIZE(host_trace_name),
+                                             "host-slab", device_ordinal);
+  const iree_hal_amdgpu_slab_provider_options_t host_slab_options = {
+      .memory_pool = out_physical_device->host_memory_pools.fine_pool,
+      .memory_type =
+          IREE_HAL_MEMORY_TYPE_HOST_LOCAL | IREE_HAL_MEMORY_TYPE_DEVICE_VISIBLE,
+      .supported_usage =
+          iree_hal_amdgpu_physical_device_mappable_pool_supported_usage(),
+  };
+  IREE_RETURN_IF_ERROR(iree_hal_amdgpu_slab_provider_create(
+      logical_device, libhsa, &system->topology, host_slab_options,
+      device_ordinal, queue_affinity_mask,
+      &out_physical_device->materialized_buffer_pool, host_slab_trace_name,
+      host_allocator, &out_physical_device->default_host_slab_provider));
+  return iree_ok_status();
+}
+
+static iree_status_t iree_hal_amdgpu_physical_device_initialize_staging(
+    iree_hal_device_t* logical_device, iree_hal_amdgpu_system_t* system,
+    const iree_hal_amdgpu_physical_device_options_t* options,
+    const iree_hal_amdgpu_host_memory_pools_t* host_memory_pools,
+    iree_hal_queue_affinity_t queue_affinity_mask,
+    iree_allocator_t host_allocator,
+    iree_hal_amdgpu_physical_device_t* out_physical_device) {
+  return iree_hal_amdgpu_staging_pool_initialize(
+      logical_device, &system->libhsa, &system->topology, host_memory_pools,
+      queue_affinity_mask, &options->file_staging, host_allocator,
+      &out_physical_device->file_staging_pool);
+}
+
+static iree_status_t iree_hal_amdgpu_physical_device_initialize_signal_pool(
+    iree_hal_amdgpu_libhsa_t* libhsa, iree_allocator_t host_allocator,
+    iree_hal_amdgpu_physical_device_t* out_physical_device) {
+  return iree_hal_amdgpu_host_signal_pool_initialize(
+      libhsa,
+      /*initial_capacity=*/IREE_HAL_AMDGPU_HOST_SIGNAL_POOL_BATCH_SIZE_DEFAULT,
+      /*batch_size=*/0, host_allocator, &out_physical_device->host_signal_pool);
+}
+
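+// Populates the per-agent builtin kernel table and initializes the buffer
+// transfer (blit) context with validated compute unit and wavefront geometry.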
+static iree_status_t
+iree_hal_amdgpu_physical_device_initialize_device_library_and_blit_context(
+    iree_hal_amdgpu_system_t* system, hsa_agent_t device_agent,
+    iree_host_size_t device_ordinal,
+    iree_hal_amdgpu_physical_device_t* out_physical_device) {
+  iree_hal_amdgpu_libhsa_t* libhsa = &system->libhsa;
+  IREE_RETURN_IF_ERROR(iree_hal_amdgpu_device_library_populate_agent_kernels(
+      &system->device_library, device_agent,
+      &out_physical_device->device_kernels));
+
+  uint32_t compute_unit_count = 0;
+  IREE_RETURN_IF_ERROR(iree_hsa_agent_get_info(
+      IREE_LIBHSA(libhsa), device_agent,
+      (hsa_agent_info_t)HSA_AMD_AGENT_INFO_COMPUTE_UNIT_COUNT,
+      &compute_unit_count));
+  uint32_t wavefront_size = 0;
+  IREE_RETURN_IF_ERROR(
+      iree_hsa_agent_get_info(IREE_LIBHSA(libhsa), device_agent,
+                              HSA_AGENT_INFO_WAVEFRONT_SIZE, &wavefront_size));
+
+  // Validate launch metadata before passing it to the blit context. A broken
+  // HSA bring-up that returns garbage here must fail loudly with a clear
+  // message rather than letting the blit path silently dispatch with wrong
+  // geometry.
+  if (compute_unit_count == 0) {
+    return iree_make_status(IREE_STATUS_FAILED_PRECONDITION,
+                            "HSA reported 0 compute units for device agent "
+                            "ordinal %" PRIhsz,
+                            device_ordinal);
+  }
+  if (wavefront_size != 32 && wavefront_size != 64) {
+    return iree_make_status(
+        IREE_STATUS_FAILED_PRECONDITION,
+        "HSA reported unsupported wavefront size %u for device agent ordinal "
+        "%" PRIhsz " (expected 32 or 64)",
+        wavefront_size, device_ordinal);
+  }
+  iree_hal_amdgpu_device_buffer_transfer_context_initialize(
+      &out_physical_device->device_kernels, compute_unit_count, wavefront_size,
+      &out_physical_device->buffer_transfer_context);
+  return iree_ok_status();
+}
+
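+// Queries the agent target ID and selects the vendor packet capabilities and
+// cross-queue wait barrier strategy appropriate for its gfx IP version.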
+static iree_status_t
+iree_hal_amdgpu_physical_device_initialize_vendor_packet_strategy(
+    iree_hal_amdgpu_libhsa_t* libhsa,
+    const iree_hal_amdgpu_physical_device_options_t* options,
+    hsa_agent_t device_agent,
+    iree_hal_amdgpu_physical_device_t* out_physical_device) {
+  iree_hal_amdgpu_target_id_t target_id;
+  IREE_RETURN_IF_ERROR(iree_hal_amdgpu_query_agent_target_id(
+      libhsa, device_agent,
+      sizeof(out_physical_device->isa.target_id_processor),
+      out_physical_device->isa.target_id_processor, &target_id));
+  iree_hal_amdgpu_gfxip_version_t gfxip_version = target_id.version;
+
+  iree_hal_amdgpu_vendor_packet_capability_flags_t vendor_packet_capabilities =
+      iree_hal_amdgpu_select_vendor_packet_capabilities(gfxip_version);
+  iree_hal_amdgpu_wait_barrier_strategy_t wait_barrier_strategy =
+      IREE_HAL_AMDGPU_WAIT_BARRIER_STRATEGY_DEFER;
+  if (!options->force_wait_barrier_defer) {
+    wait_barrier_strategy = iree_hal_amdgpu_select_wait_barrier_strategy(
+        vendor_packet_capabilities);
+  }
+  out_physical_device->isa.target_id = target_id;
+  out_physical_device->vendor_packet_capabilities = vendor_packet_capabilities;
+  out_physical_device->wait_barrier_strategy = wait_barrier_strategy;
+  return iree_ok_status();
+}
+
+iree_status_t iree_hal_amdgpu_physical_device_initialize(
+    iree_hal_device_t* logical_device, iree_hal_amdgpu_system_t* system,
+    const iree_hal_amdgpu_physical_device_options_t* options,
+    iree_async_proactor_t* proactor, iree_host_size_t host_ordinal,
+    const iree_hal_amdgpu_host_memory_pools_t* host_memory_pools,
+    iree_host_size_t device_ordinal, iree_allocator_t host_allocator,
+    iree_hal_amdgpu_physical_device_t* out_physical_device) {
+  IREE_ASSERT_ARGUMENT(logical_device);
+  IREE_ASSERT_ARGUMENT(system);
+  IREE_ASSERT_ARGUMENT(options);
+  IREE_ASSERT_ARGUMENT(proactor);
+  IREE_ASSERT_ARGUMENT(host_memory_pools);
+  IREE_ASSERT_ARGUMENT(out_physical_device);
+  IREE_TRACE_ZONE_BEGIN(z0);
+
+  iree_hal_amdgpu_libhsa_t* libhsa = &system->libhsa;
+  hsa_agent_t device_agent = system->topology.gpu_agents[device_ordinal];
+
+  iree_status_t status = iree_hal_amdgpu_physical_device_initialize_identity(
+      system, options, host_ordinal, host_memory_pools, device_ordinal,
+      out_physical_device);
+  if (iree_status_is_ok(status)) {
+    status = iree_hal_amdgpu_physical_device_initialize_host_pools(
+        options, host_allocator, out_physical_device);
+  }
+
+  // Find the device memory pools and create block pools/allocators.
+  hsa_amd_memory_pool_t coarse_block_memory_pool = {0};
+  hsa_amd_memory_pool_t fine_block_memory_pool = {0};
+  if (iree_status_is_ok(status)) {
+    status = iree_hal_amdgpu_physical_device_query_global_memory_pools(
+        libhsa, device_agent, &coarse_block_memory_pool,
+        &fine_block_memory_pool);
+  }
+  if (iree_status_is_ok(status)) {
+    out_physical_device->prepublished_kernarg_storage =
+        iree_hal_amdgpu_select_prepublished_kernarg_storage(
+            fine_block_memory_pool);
+  }
+  if (iree_status_is_ok(status)) {
+    status = iree_hal_amdgpu_physical_device_preallocate_host_pool(
+        options, out_physical_device);
+  }
+  if (iree_status_is_ok(status)) {
+    status =
+        iree_hal_amdgpu_physical_device_initialize_device_block_pools_and_allocators(
+            libhsa, options, device_agent, device_ordinal,
+            coarse_block_memory_pool, fine_block_memory_pool, host_allocator,
+            out_physical_device);
+  }
+
+  // Create the default queue-allocation slab provider over the device
+  // coarse-grained HSA pool and derive the TLSF pool policy that takes effect
+  // once topology assignment provides an epoch query.
+  iree_hal_queue_affinity_t queue_affinity_mask = 0;
+  const iree_hal_amdgpu_queue_affinity_domain_t queue_affinity_domain = {
+      .supported_affinity = IREE_HAL_QUEUE_AFFINITY_ANY,
+      .physical_device_count = system->topology.gpu_agent_count,
+      .queue_count_per_physical_device = options->host_queue_count,
+  };
+  if (iree_status_is_ok(status)) {
+    status = iree_hal_amdgpu_queue_affinity_for_physical_device(
+        queue_affinity_domain, device_ordinal, &queue_affinity_mask);
+  }
+  if (iree_status_is_ok(status)) {
+    status = iree_hal_amdgpu_physical_device_initialize_default_pool_resources(
+        logical_device, system, options, proactor, device_ordinal,
+        coarse_block_memory_pool, queue_affinity_mask, host_allocator,
+        out_physical_device);
+  }
+
+  if (iree_status_is_ok(status)) {
+    status = iree_hal_amdgpu_physical_device_initialize_staging(
+        logical_device, system, options, host_memory_pools, queue_affinity_mask,
+        host_allocator, out_physical_device);
+  }
+  if (iree_status_is_ok(status)) {
+    status = iree_hal_amdgpu_physical_device_initialize_signal_pool(
+        libhsa, host_allocator, out_physical_device);
+  }
+  if (iree_status_is_ok(status)) {
+    status =
+        iree_hal_amdgpu_physical_device_initialize_device_library_and_blit_context(
+            system, device_agent, device_ordinal, out_physical_device);
+  }
+  if (iree_status_is_ok(status)) {
+    status = iree_hal_amdgpu_physical_device_initialize_vendor_packet_strategy(
+        libhsa, options, device_agent, out_physical_device);
+  }
+  if (iree_status_is_ok(status)) {
+    status =
+        iree_hal_amdgpu_physical_device_initialize_cpu_visible_device_coarse_memory(
+            libhsa, device_agent, coarse_block_memory_pool,
+            out_physical_device->isa.target_id.version, &system->topology,
+            &out_physical_device->cpu_visible_device_coarse_memory);
+  }
+  if (iree_status_is_ok(status)) {
+    status =
+        iree_hal_amdgpu_physical_device_initialize_memory_system_capabilities(
+            libhsa, &system->info, device_agent, fine_block_memory_pool,
+            &out_physical_device->cpu_visible_device_coarse_memory,
+            &out_physical_device->memory_system);
+  }
+
+  if (!iree_status_is_ok(status)) {
+    iree_hal_amdgpu_physical_device_deinitialize(out_physical_device);
+  }
+
+  IREE_TRACE_ZONE_END(z0);
+  return status;
+}
+
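+// Creates a TLSF suballocating pool plus a passthrough pool for oversized
+// requests, both backed by |slab_provider| and the device's default pool
+// notification.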
+static iree_status_t iree_hal_amdgpu_physical_device_create_pool_pair(
+    iree_hal_amdgpu_physical_device_t* physical_device,
+    iree_hal_amdgpu_epoch_signal_table_t* epoch_signal_table,
+    iree_hal_slab_provider_t* slab_provider,
+    iree_hal_tlsf_pool_options_t pool_options, const char* pool_name,
+    const char* oversized_pool_name, iree_allocator_t host_allocator,
+    iree_hal_pool_t** out_pool, iree_hal_pool_t** out_oversized_pool) {
+  char pool_trace_name[64] = {0};
+  pool_options.trace_name = iree_hal_amdgpu_format_pool_trace_name(
+      pool_trace_name, IREE_ARRAYSIZE(pool_trace_name), pool_name,
+      physical_device->device_ordinal);
+  iree_status_t status = iree_hal_tlsf_pool_create(
+      pool_options, slab_provider, physical_device->default_pool_notification,
+      (iree_hal_pool_epoch_query_t){
+          .fn = iree_hal_amdgpu_physical_device_query_pool_epoch,
+          .user_data = epoch_signal_table,
+      },
+      host_allocator, out_pool);
+
+  char oversized_pool_trace_name[64] = {0};
+  if (iree_status_is_ok(status)) {
+    iree_hal_passthrough_pool_options_t oversized_pool_options = {
+        .trace_name = iree_hal_amdgpu_format_pool_trace_name(
+            oversized_pool_trace_name,
+            IREE_ARRAYSIZE(oversized_pool_trace_name), oversized_pool_name,
+            physical_device->device_ordinal),
+    };
+    status = iree_hal_passthrough_pool_create(
+        oversized_pool_options, slab_provider,
+        physical_device->default_pool_notification, host_allocator,
+        out_oversized_pool);
+  }
+  return status;
+}
+
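+// Creates the device-local and host-visible default pool pairs and registers
+// all four pools in the device's default pool set.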
+static iree_status_t iree_hal_amdgpu_physical_device_create_default_pools(
+    iree_hal_amdgpu_physical_device_t* physical_device,
+    iree_hal_amdgpu_epoch_signal_table_t* epoch_signal_table,
+    iree_allocator_t host_allocator) {
+  IREE_RETURN_IF_ERROR(iree_hal_pool_set_initialize(
+      /*initial_capacity=*/4, host_allocator,
+      &physical_device->default_pool_set));
+
+  iree_status_t status = iree_hal_amdgpu_physical_device_create_pool_pair(
+      physical_device, epoch_signal_table,
+      physical_device->default_slab_provider,
+      physical_device->default_pool_options, "tlsf", "oversized",
+      host_allocator, &physical_device->default_pool,
+      &physical_device->default_oversized_pool);
+  iree_hal_tlsf_pool_options_t host_pool_options =
+      physical_device->default_pool_options;
+  host_pool_options.tlsf_options.range_length =
+      IREE_HAL_AMDGPU_PHYSICAL_DEVICE_HOST_POOL_RANGE_LENGTH_DEFAULT;
+  if (iree_status_is_ok(status)) {
+    status = iree_hal_amdgpu_physical_device_create_pool_pair(
+        physical_device, epoch_signal_table,
+        physical_device->default_host_slab_provider, host_pool_options,
+        "host-tlsf", "host-oversized", host_allocator,
+        &physical_device->default_host_pool,
+        &physical_device->default_host_oversized_pool);
+  }
+  if (iree_status_is_ok(status)) {
+    status = iree_hal_pool_set_register(
+        &physical_device->default_pool_set,
+        IREE_HAL_AMDGPU_PHYSICAL_DEVICE_DEFAULT_POOL_PRIORITY_OVERSIZED,
+        physical_device->default_oversized_pool);
+  }
+  if (iree_status_is_ok(status)) {
+    status = iree_hal_pool_set_register(
+        &physical_device->default_pool_set,
+        IREE_HAL_AMDGPU_PHYSICAL_DEVICE_DEFAULT_POOL_PRIORITY_OVERSIZED,
+        physical_device->default_host_oversized_pool);
+  }
+  if (iree_status_is_ok(status)) {
+    status = iree_hal_pool_set_register(
+        &physical_device->default_pool_set,
+        IREE_HAL_AMDGPU_PHYSICAL_DEVICE_DEFAULT_POOL_PRIORITY_TLSF,
+        physical_device->default_pool);
+  }
+  if (iree_status_is_ok(status)) {
+    status = iree_hal_pool_set_register(
+        &physical_device->default_pool_set,
+        IREE_HAL_AMDGPU_PHYSICAL_DEVICE_DEFAULT_POOL_PRIORITY_TLSF,
+        physical_device->default_host_pool);
+  }
+  return status;
+}
+
+iree_status_t iree_hal_amdgpu_physical_device_assign_frontier(
+    iree_hal_device_t* logical_device, iree_hal_amdgpu_system_t* system,
+    iree_async_proactor_t* proactor,
+    iree_async_frontier_tracker_t* frontier_tracker,
+    iree_async_axis_t base_axis,
+    iree_hal_amdgpu_epoch_signal_table_t* epoch_signal_table,
+    const iree_hal_amdgpu_host_memory_pools_t* host_memory_pools,
+    iree_allocator_t host_allocator,
+    iree_hal_amdgpu_physical_device_t* physical_device) {
+  IREE_TRACE_ZONE_BEGIN(z0);
+
+  iree_hal_amdgpu_libhsa_t* libhsa = &system->libhsa;
+  const uint8_t session_epoch = iree_async_axis_session(base_axis);
+  const uint8_t machine_index = iree_async_axis_machine(base_axis);
+  iree_status_t status = iree_hal_amdgpu_physical_device_create_default_pools(
+      physical_device, epoch_signal_table, host_allocator);
+  const iree_hal_amdgpu_queue_affinity_domain_t queue_affinity_domain = {
+      .supported_affinity = IREE_HAL_QUEUE_AFFINITY_ANY,
+      .physical_device_count = system->topology.gpu_agent_count,
+      .queue_count_per_physical_device = physical_device->host_queue_capacity,
+  };
+  iree_hal_amdgpu_physical_device_kernarg_ring_memory_t kernarg_ring_memory;
+  iree_hal_amdgpu_physical_device_select_kernarg_ring_memory(
+      physical_device, host_memory_pools, &kernarg_ring_memory);
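+  // Initialize queues one at a time; |host_queue_count| tracks how many are
+  // live so a partial failure only deinitializes initialized queues.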
+  for (iree_host_size_t queue_ordinal = 0;
+       queue_ordinal < physical_device->host_queue_capacity &&
+       iree_status_is_ok(status);
+       ++queue_ordinal) {
+    const iree_host_size_t logical_queue_ordinal =
+        physical_device->device_ordinal * physical_device->host_queue_capacity +
+        queue_ordinal;
+    iree_hal_amdgpu_queue_affinity_resolved_t resolved;
+    status = iree_hal_amdgpu_queue_affinity_resolve_ordinal(
+        queue_affinity_domain, logical_queue_ordinal, &resolved);
+    if (!iree_status_is_ok(status)) break;
+    iree_async_axis_t queue_axis = iree_async_axis_make_queue(
+        session_epoch, machine_index, (uint8_t)physical_device->device_ordinal,
+        (uint8_t)queue_ordinal);
+    iree_thread_affinity_t completion_thread_affinity;
+    iree_thread_affinity_set_group_any(physical_device->host_numa_node,
+                                       &completion_thread_affinity);
+    status = iree_hal_amdgpu_host_queue_initialize(
+        libhsa, logical_device, proactor, physical_device->device_agent,
+        &kernarg_ring_memory.descriptor, host_memory_pools->fine_pool,
+        frontier_tracker, queue_axis, resolved.queue_affinity,
+        completion_thread_affinity, physical_device->wait_barrier_strategy,
+        physical_device->vendor_packet_capabilities, epoch_signal_table,
+        &physical_device->fine_host_block_pool,
+        &physical_device->fine_block_pools.small,
+        &physical_device->buffer_transfer_context,
+        &physical_device->default_pool_set, physical_device->default_pool,
+        &physical_device->transient_buffer_pool,
+        &physical_device->file_staging_pool, physical_device->device_ordinal,
+        physical_device->host_queue_aql_capacity,
+        physical_device->host_queue_notification_capacity,
+        physical_device->host_queue_kernarg_capacity,
+        physical_device->host_queue_upload_capacity, host_allocator,
+        &physical_device->host_queues[queue_ordinal]);
+    if (iree_status_is_ok(status)) {
+      physical_device->host_queue_count = queue_ordinal + 1;
+    }
+  }
+
+  if (!iree_status_is_ok(status)) {
+    iree_hal_amdgpu_physical_device_deassign_frontier(physical_device);
+  }
+
+  IREE_TRACE_ZONE_END(z0);
+  return status;
+}
+
+void iree_hal_amdgpu_physical_device_deassign_frontier(
+    iree_hal_amdgpu_physical_device_t* physical_device) {
+  IREE_TRACE_ZONE_BEGIN(z0);
+  for (iree_host_size_t i = 0; i < physical_device->host_queue_count; ++i) {
+    iree_hal_amdgpu_host_queue_deinitialize(&physical_device->host_queues[i]);
+  }
+  physical_device->host_queue_count = 0;
+  if (physical_device->default_pool_set.entries) {
+    iree_hal_pool_set_deinitialize(&physical_device->default_pool_set);
+  }
+  iree_hal_pool_release(physical_device->default_host_oversized_pool);
+  physical_device->default_host_oversized_pool = NULL;
+  iree_hal_pool_release(physical_device->default_host_pool);
+  physical_device->default_host_pool = NULL;
+  iree_hal_pool_release(physical_device->default_oversized_pool);
+  physical_device->default_oversized_pool = NULL;
+  iree_hal_pool_release(physical_device->default_pool);
+  physical_device->default_pool = NULL;
+  IREE_TRACE_ZONE_END(z0);
+}
+
+iree_status_t iree_hal_amdgpu_physical_device_set_hsa_profiling_enabled(
+    iree_hal_amdgpu_physical_device_t* physical_device, bool enabled) {
+  IREE_ASSERT_ARGUMENT(physical_device);
+  IREE_TRACE_ZONE_BEGIN(z0);
+  IREE_TRACE_ZONE_APPEND_VALUE_I64(z0, enabled ? 1 : 0);
+
   iree_status_t status = iree_ok_status();
-  if (iree_status_is_ok(status)) {
-    status = iree_hal_amdgpu_block_pool_initialize(
-        libhsa, options->device_block_pools.small, device_agent,
-        coarse_block_memory_pool, host_allocator,
-        &out_physical_device->coarse_block_pools.small);
-  }
-  if (iree_status_is_ok(status)) {
-    status = iree_hal_amdgpu_block_pool_initialize(
-        libhsa, options->device_block_pools.large, device_agent,
-        coarse_block_memory_pool, host_allocator,
-        &out_physical_device->coarse_block_pools.large);
-  }
-  if (iree_status_is_ok(status)) {
-    status = iree_hal_amdgpu_block_pool_initialize(
-        libhsa, options->device_block_pools.small, device_agent,
-        fine_block_memory_pool, host_allocator,
-        &out_physical_device->fine_block_pools.small);
-  }
-  if (iree_status_is_ok(status)) {
-    status = iree_hal_amdgpu_block_pool_initialize(
-        libhsa, options->device_block_pools.large, device_agent,
-        fine_block_memory_pool, host_allocator,
-        &out_physical_device->fine_block_pools.large);
-  }
-  // TODO(benvanik): expose block allocator min_page_size values as options.
-  if (iree_status_is_ok(status)) {
-    status = iree_hal_amdgpu_block_allocator_initialize(
-        &out_physical_device->coarse_block_pools.small,
-        IREE_HAL_AMDGPU_PHYSICAL_DEVICE_COARSE_BLOCK_POOL_SMALL_PAGE_SIZE,
-        &out_physical_device->coarse_block_allocators.small);
-  }
-  if (iree_status_is_ok(status)) {
-    status = iree_hal_amdgpu_block_allocator_initialize(
-        &out_physical_device->coarse_block_pools.large,
-        IREE_HAL_AMDGPU_PHYSICAL_DEVICE_COARSE_BLOCK_POOL_LARGE_PAGE_SIZE,
-        &out_physical_device->coarse_block_allocators.large);
-  }
-  if (iree_status_is_ok(status)) {
-    status = iree_hal_amdgpu_block_allocator_initialize(
-        &out_physical_device->fine_block_pools.small,
-        IREE_HAL_AMDGPU_PHYSICAL_DEVICE_FINE_BLOCK_POOL_SMALL_PAGE_SIZE,
-        &out_physical_device->fine_block_allocators.small);
-  }
-  if (iree_status_is_ok(status)) {
-    status = iree_hal_amdgpu_block_allocator_initialize(
-        &out_physical_device->fine_block_pools.large,
-        IREE_HAL_AMDGPU_PHYSICAL_DEVICE_FINE_BLOCK_POOL_LARGE_PAGE_SIZE,
-        &out_physical_device->fine_block_allocators.large);
-  }
-
-  // Create the host worker thread that will handle scheduler requests.
-  // Each queue on this physical device will share the same worker today but we
-  // could change that if we become host-bound. In general we should not be
-  // using the host during our latency-critical operations but it's possible if
-  // memory pool growth/trims take awhile that we end up serializing multiple
-  // device queues.
-  if (iree_status_is_ok(status)) {
-    status = iree_hal_amdgpu_host_service_initialize(
-        libhsa, host_ordinal, host_agent, host_memory_pools->fine_region,
-        device_ordinal, error_callback, host_allocator,
-        &out_physical_device->host_service);
-  }
-
-  // Initialize each queue and its device-side scheduler.
-  // Note that initialization may happen asynchronously with the
-  // initialization_signal. Queues may require host services to initialize and
-  // this must happen after all other physical device state has completed
-  // initialization.
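+  // Count successfully toggled queues so an enable failure can be rolled back
+  // and a disable failure can still attempt the remaining queues.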
+  iree_host_size_t changed_count = 0;
   for (iree_host_size_t i = 0;
-       iree_status_is_ok(status) && i < options->queue_count; ++i) {
-    status = iree_hal_amdgpu_queue_initialize(
-        system, options->queue_options, device_agent, device_ordinal,
-        &out_physical_device->host_service,
-        &out_physical_device->fine_host_block_pool,
-        &out_physical_device->fine_block_allocators, buffer_pool,
-        error_callback, initialization_signal, host_allocator,
-        out_physical_device->queues[i]);
+       i < physical_device->host_queue_count && iree_status_is_ok(status);
+       ++i) {
+    status = iree_hal_amdgpu_host_queue_set_hsa_profiling_enabled(
+        &physical_device->host_queues[i], enabled);
+    if (iree_status_is_ok(status)) {
+      ++changed_count;
+    }
+  }
+
+  if (!iree_status_is_ok(status) && enabled) {
+    for (iree_host_size_t i = 0; i < changed_count; ++i) {
+      status = iree_status_join(
+          status, iree_hal_amdgpu_host_queue_set_hsa_profiling_enabled(
+                      &physical_device->host_queues[i], false));
+    }
+  } else if (!enabled) {
+    for (iree_host_size_t i = changed_count;
+         i < physical_device->host_queue_count; ++i) {
+      status = iree_status_join(
+          status, iree_hal_amdgpu_host_queue_set_hsa_profiling_enabled(
+                      &physical_device->host_queues[i], false));
+    }
   }
 
   IREE_TRACE_ZONE_END(z0);
@@ -332,24 +1232,25 @@
   IREE_ASSERT_ARGUMENT(physical_device);
   IREE_TRACE_ZONE_BEGIN(z0);
 
-  // Deinitialize all queues and their device-side schedulers before releasing
-  // any resources that may be used by them (such as the host worker).
-  for (iree_host_size_t i = 0; i < physical_device->queue_count; ++i) {
-    iree_hal_amdgpu_virtual_queue_t* queue = physical_device->queues[i];
-    if (queue) {
-      queue->vtable->deinitialize(queue);
-    }
-  }
+  iree_hal_amdgpu_physical_device_deassign_frontier(physical_device);
 
-  // Deinitialize the host service only after all queues have fully terminated.
-  iree_hal_amdgpu_host_service_deinitialize(&physical_device->host_service);
+  iree_hal_amdgpu_host_signal_pool_deinitialize(
+      &physical_device->host_signal_pool);
 
-  // Note that the host service may be using allocations from the host block
-  // pool and that this must happen after it has fully terminated.
+  iree_hal_slab_provider_release(physical_device->default_slab_provider);
+  iree_hal_slab_provider_release(physical_device->default_host_slab_provider);
+  iree_async_notification_release(physical_device->default_pool_notification);
+
+  iree_hal_amdgpu_staging_pool_deinitialize(
+      &physical_device->file_staging_pool);
+
+  iree_hal_amdgpu_transient_buffer_pool_deinitialize(
+      &physical_device->transient_buffer_pool);
+  iree_hal_amdgpu_buffer_pool_deinitialize(
+      &physical_device->materialized_buffer_pool);
+
   iree_arena_block_pool_deinitialize(&physical_device->fine_host_block_pool);
 
-  // Note that other per-device data structures may be using blocks until they
-  // are deinitialized and this must be deinitialized last.
   iree_hal_amdgpu_block_allocator_deinitialize(
       &physical_device->coarse_block_allocators.small);
   iree_hal_amdgpu_block_allocator_deinitialize(
@@ -372,15 +1273,14 @@
   IREE_TRACE_ZONE_END(z0);
 }
 
-void iree_hal_amdgpu_physical_device_trim(
+iree_status_t iree_hal_amdgpu_physical_device_trim(
     iree_hal_amdgpu_physical_device_t* physical_device) {
   IREE_ASSERT_ARGUMENT(physical_device);
   IREE_TRACE_ZONE_BEGIN(z0);
 
-  // Trim queues first to release resources back to block pools.
-  for (iree_host_size_t i = 0; i < physical_device->queue_count; ++i) {
-    iree_hal_amdgpu_virtual_queue_t* queue = physical_device->queues[i];
-    queue->vtable->trim(queue);
+  for (iree_host_size_t i = 0; i < physical_device->host_queue_count; ++i) {
+    physical_device->host_queues[i].base.vtable->trim(
+        &physical_device->host_queues[i].base);
   }
 
   iree_hal_amdgpu_block_pool_trim(&physical_device->coarse_block_pools.small);
@@ -390,5 +1290,21 @@
 
   iree_arena_block_pool_trim(&physical_device->fine_host_block_pool);
 
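+  // Default pools are created during frontier assignment and may not exist
+  // yet; trim only those that do.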
+  iree_status_t status = iree_ok_status();
+  if (physical_device->default_pool) {
+    status = iree_hal_pool_trim(physical_device->default_pool);
+  }
+  if (iree_status_is_ok(status) && physical_device->default_oversized_pool) {
+    status = iree_hal_pool_trim(physical_device->default_oversized_pool);
+  }
+  if (iree_status_is_ok(status) && physical_device->default_host_pool) {
+    status = iree_hal_pool_trim(physical_device->default_host_pool);
+  }
+  if (iree_status_is_ok(status) &&
+      physical_device->default_host_oversized_pool) {
+    status = iree_hal_pool_trim(physical_device->default_host_oversized_pool);
+  }
+
   IREE_TRACE_ZONE_END(z0);
+  return status;
 }
diff --git a/runtime/src/iree/hal/drivers/amdgpu/physical_device.h b/runtime/src/iree/hal/drivers/amdgpu/physical_device.h
index 548193f..867f799 100644
--- a/runtime/src/iree/hal/drivers/amdgpu/physical_device.h
+++ b/runtime/src/iree/hal/drivers/amdgpu/physical_device.h
@@ -9,13 +9,21 @@
 
 #include "iree/base/api.h"
 #include "iree/base/internal/arena.h"
-#include "iree/hal/drivers/amdgpu/host_service.h"
+#include "iree/hal/drivers/amdgpu/buffer.h"
+#include "iree/hal/drivers/amdgpu/host_queue.h"
+#include "iree/hal/drivers/amdgpu/host_queue_staging.h"
+#include "iree/hal/drivers/amdgpu/physical_device_capabilities.h"
 #include "iree/hal/drivers/amdgpu/system.h"
+#include "iree/hal/drivers/amdgpu/transient_buffer.h"
 #include "iree/hal/drivers/amdgpu/util/block_pool.h"
 #include "iree/hal/drivers/amdgpu/util/libhsa.h"
-#include "iree/hal/drivers/amdgpu/virtual_queue.h"
+#include "iree/hal/drivers/amdgpu/util/signal_pool.h"
+#include "iree/hal/drivers/amdgpu/util/target_id.h"
+#include "iree/hal/memory/slab_provider.h"
+#include "iree/hal/memory/tlsf_pool.h"
+#include "iree/hal/pool.h"
+#include "iree/hal/pool_set.h"
 
-typedef struct iree_hal_amdgpu_buffer_pool_t iree_hal_amdgpu_buffer_pool_t;
 typedef struct iree_hal_amdgpu_host_memory_pools_t
     iree_hal_amdgpu_host_memory_pools_t;
 
@@ -61,8 +69,38 @@
 // not be large.
 #define IREE_HAL_AMDGPU_PHYSICAL_DEVICE_HOST_BLOCK_SIZE_DEFAULT (8 * 1024)
 
+// Logical byte length for the default per-device queue-allocation pool.
+#define IREE_HAL_AMDGPU_PHYSICAL_DEVICE_DEFAULT_POOL_RANGE_LENGTH_DEFAULT \
+  (64 * 1024 * 1024)
+
+// Logical byte length for host-visible default queue-allocation pool slabs.
+#define IREE_HAL_AMDGPU_PHYSICAL_DEVICE_HOST_POOL_RANGE_LENGTH_DEFAULT \
+  (64 * 1024)
+
+// Minimum byte alignment for default-pool suballocations.
+#define IREE_HAL_AMDGPU_PHYSICAL_DEVICE_DEFAULT_POOL_ALIGNMENT_DEFAULT 256
+
+// Maximum death-frontier entries stored per free default-pool block.
+#define IREE_HAL_AMDGPU_PHYSICAL_DEVICE_DEFAULT_POOL_FRONTIER_CAPACITY_DEFAULT \
+  IREE_HAL_MEMORY_TLSF_DEFAULT_FRONTIER_CAPACITY
+
 // Total number of HAL queues on the physical device.
-#define IREE_HAL_AMDGPU_PHYSICAL_DEVICE_DEFAULT_QUEUE_COUNT (1)
+#define IREE_HAL_AMDGPU_PHYSICAL_DEVICE_DEFAULT_QUEUE_COUNT \
+  IREE_HAL_AMDGPU_DEFAULT_GPU_AGENT_QUEUE_COUNT
+
+// Default per-queue hardware AQL ring capacity in packets.
+#define IREE_HAL_AMDGPU_PHYSICAL_DEVICE_DEFAULT_HOST_QUEUE_AQL_CAPACITY \
+  IREE_HAL_AMDGPU_DEFAULT_EXECUTION_QUEUE_CAPACITY
+
+// Default per-queue completion/reclaim ring capacity in epochs and hot entries.
+#define IREE_HAL_AMDGPU_PHYSICAL_DEVICE_DEFAULT_HOST_QUEUE_NOTIFICATION_CAPACITY \
+  IREE_HAL_AMDGPU_DEFAULT_NOTIFICATION_CAPACITY
+
+// Default per-queue kernarg ring capacity in 64-byte blocks.
+#define IREE_HAL_AMDGPU_PHYSICAL_DEVICE_DEFAULT_HOST_QUEUE_KERNARG_CAPACITY \
+  ((uint32_t)(IREE_HAL_AMDGPU_DEFAULT_KERNARG_RINGBUFFER_CAPACITY /         \
+              sizeof(iree_hal_amdgpu_kernarg_block_t)))
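+
+// Default per-queue device-visible upload ring capacity in bytes; zero
+// disables the optional upload ring.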
+#define IREE_HAL_AMDGPU_PHYSICAL_DEVICE_DEFAULT_HOST_QUEUE_UPLOAD_CAPACITY 0
 
 // Options controlling how a physical device is initialized.
 typedef struct iree_hal_amdgpu_physical_device_options_t {
@@ -85,13 +123,36 @@
   // Initial block count preallocated for the host block pool.
   iree_host_size_t host_block_pool_initial_capacity;
 
-  // Total number of HAL queues on the physical device.
-  iree_host_size_t queue_count;
-  // Options used to initialize each queue.
-  // Currently we assume queues are homogeneous but we may want to expose
-  // bucketed types (e.g. host-side or device-side queues) to allow for tuning
-  // each independently.
-  iree_hal_amdgpu_queue_options_t queue_options;
+  // Number of host queues created for this physical device.
+  iree_host_size_t host_queue_count;
+  // Per-host-queue HSA AQL ring capacity in packets.
+  uint32_t host_queue_aql_capacity;
+  // Per-host-queue completion/reclaim ring capacity.
+  uint32_t host_queue_notification_capacity;
+  // Per-host-queue kernarg ring capacity in 64-byte blocks.
+  uint32_t host_queue_kernarg_capacity;
+  // Per-host-queue device-visible control upload ring capacity in bytes. Zero
+  // disables the optional upload ring.
+  uint32_t host_queue_upload_capacity;
+
+  // Default queue-allocation pool policy.
+  struct {
+    // Logical byte length of the default TLSF pool range.
+    iree_device_size_t range_length;
+
+    // Minimum byte alignment for every default-pool reservation.
+    iree_device_size_t alignment;
+
+    // Maximum death-frontier entry count stored per free TLSF block.
+    uint8_t frontier_capacity;
+  } default_pool;
+
+  // Fixed-size queue_read/queue_write staging policy.
+  iree_hal_amdgpu_staging_pool_options_t file_staging;
+
+  // Forces cross-queue wait barriers to use software deferral instead of the
+  // optimal device-side strategy for the GPU ISA.
+  uint32_t force_wait_barrier_defer : 1;
 } iree_hal_amdgpu_physical_device_options_t;
 
 // Initializes |out_options| to its default values.
@@ -115,6 +176,41 @@
   hsa_agent_t device_agent;
   // Ordinal of the GPU agent within the topology.
   iree_host_size_t device_ordinal;
+  // HSA driver identifier used when querying per-device clock counters.
+  uint32_t driver_uid;
+  // PCI domain from HSA_AMD_AGENT_INFO_DOMAIN.
+  uint32_t pci_domain;
+  // PCI bus decoded from HSA_AMD_AGENT_INFO_BDFID.
+  uint32_t pci_bus;
+  // PCI device decoded from HSA_AMD_AGENT_INFO_BDFID.
+  uint32_t pci_device;
+  // PCI function decoded from HSA_AMD_AGENT_INFO_BDFID.
+  uint32_t pci_function;
+  // True when the PCI identity fields contain HSA-provided values.
+  uint32_t has_pci_identity : 1;
+  // HSA ISA identity selected for this GPU agent.
+  struct {
+    // Storage backing |target_id.processor|.
+    char target_id_processor[64];
+    // Parsed target identity, including XNACK/SRAMECC support and mode.
+    iree_hal_amdgpu_target_id_t target_id;
+  } isa;
+  // Stable physical device UUID bytes reported by HSA when available.
+  uint8_t physical_device_uuid[16];
+  // True when |physical_device_uuid| contains a stable HSA device identifier.
+  uint32_t has_physical_device_uuid : 1;
+  // NUMA node of the CPU agent nearest to |device_agent|.
+  uint32_t host_numa_node;
+  // Host memory pools for the CPU agent nearest to |device_agent|.
+  iree_hal_amdgpu_host_memory_pools_t host_memory_pools;
+  // Cold memory-system facts used to derive conservative topology flags.
+  iree_hal_amdgpu_memory_system_capabilities_t memory_system;
+  // CPU-visible coarse-grained device-memory capability for this GPU.
+  iree_hal_amdgpu_cpu_visible_device_coarse_memory_t
+      cpu_visible_device_coarse_memory;
+  // Prepublished command-buffer kernarg storage capability for this GPU.
+  iree_hal_amdgpu_aql_prepublished_kernarg_storage_t
+      prepublished_kernarg_storage;
 
   // Fine-grained block pools for device memory blocks of various sizes.
   iree_hal_amdgpu_block_pools_t fine_block_pools;
@@ -132,16 +228,62 @@
   // the common case will be that the blocks are touched by the same device.
   iree_arena_block_pool_t fine_host_block_pool;
 
-  // Host-side service worker for supporting device library requests.
-  // Today we have one per physical device but could share them or even have
-  // one per queue.
-  iree_hal_amdgpu_host_service_t host_service;
+  // Per-device pool of user-visible queue_alloca transient buffer wrappers.
+  iree_hal_amdgpu_transient_buffer_pool_t transient_buffer_pool;
 
-  // HAL queues with associated device-side schedulers.
-  iree_host_size_t queue_count;
-  iree_hal_amdgpu_virtual_queue_t* queues[/*queue_count*/];
+  // Per-device pool of materialized slab-backed HAL buffer view wrappers.
+  iree_hal_amdgpu_buffer_pool_t materialized_buffer_pool;
 
-  // + queue storage; queues may be of mixed types and have different sizes
+  // Pool of HSA signals for host-waited semaphores and proactor integration.
+  iree_hal_amdgpu_host_signal_pool_t host_signal_pool;
+
+  // Default queue-allocation pool notification for this physical device.
+  iree_async_notification_t* default_pool_notification;
+  // Slab provider backing default and caller-created pools for this domain.
+  iree_hal_slab_provider_t* default_slab_provider;
+  // Host-local slab provider for mappable queue allocation transients.
+  iree_hal_slab_provider_t* default_host_slab_provider;
+  // TLSF options derived from device options and HSA memory-pool properties.
+  iree_hal_tlsf_pool_options_t default_pool_options;
+  // Routes default queue allocations to the best compatible memory pool.
+  iree_hal_pool_set_t default_pool_set;
+  // Frontier-aware suballocating pool used up to the TLSF slab length.
+  iree_hal_pool_t* default_pool;
+  // Direct per-allocation pool used for requests larger than one TLSF slab.
+  iree_hal_pool_t* default_oversized_pool;
+  // Frontier-aware suballocating pool for host-visible queue allocations.
+  iree_hal_pool_t* default_host_pool;
+  // Direct host-visible pool used for requests larger than one host TLSF slab.
+  iree_hal_pool_t* default_host_oversized_pool;
+
+  // Fixed-size staging pool for non-mappable queue_read/queue_write transfers.
+  iree_hal_amdgpu_staging_pool_t file_staging_pool;
+
+  // Builtin kernel table for this GPU agent.
+  iree_hal_amdgpu_device_kernels_t device_kernels;
+  // Host/device-neutral transfer context that points into |device_kernels|.
+  iree_hal_amdgpu_device_buffer_transfer_context_t buffer_transfer_context;
+
+  // Total number of host queue slots allocated in |host_queues|.
+  iree_host_size_t host_queue_capacity;
+  // Per-host-queue HSA AQL ring capacity in packets.
+  uint32_t host_queue_aql_capacity;
+  // Per-host-queue completion/reclaim ring capacity.
+  uint32_t host_queue_notification_capacity;
+  // Per-host-queue kernarg ring capacity in 64-byte blocks.
+  uint32_t host_queue_kernarg_capacity;
+  // Per-host-queue device-visible control upload ring capacity in bytes. Zero
+  // disables the optional upload ring.
+  uint32_t host_queue_upload_capacity;
+  // AMD vendor-packet capabilities selected from this GPU agent's ISA.
+  iree_hal_amdgpu_vendor_packet_capability_flags_t vendor_packet_capabilities;
+  // Hardware strategy selected for cross-queue epoch waits on this GPU agent.
+  iree_hal_amdgpu_wait_barrier_strategy_t wait_barrier_strategy;
+
+  // Number of live host queues initialized in |host_queues|.
+  iree_host_size_t host_queue_count;
+  // One or more host queues mapped to HSA queues on this physical device.
+  iree_hal_amdgpu_host_queue_t host_queues[/*host_queue_count*/];
 } iree_hal_amdgpu_physical_device_t;
 
 // Returns the aligned heap size in bytes required to store the physical device
@@ -149,37 +291,50 @@
 iree_host_size_t iree_hal_amdgpu_physical_device_calculate_size(
     const iree_hal_amdgpu_physical_device_options_t* options);
 
-// Initializes a physical device with one or more HAL queues.
+// Initializes a physical device.
 // Requires that the |options| have been verified.
 //
-// |initialization_signal| will be incremented as asynchronous initialization
-// operations are enqueued and decremented as they complete. Callers must wait
-// for the completion signal to reach 0 prior to deinitializing the device even
-// if initialization fails.
-//
-// NOTE: if initialization fails callers must call
-// iree_hal_amdgpu_physical_device_deinitialize after |initialization_signal| is
-// reached.
-//
 // |out_physical_device| must reference at least
 // iree_hal_amdgpu_physical_device_calculate_size of valid host memory.
 iree_status_t iree_hal_amdgpu_physical_device_initialize(
-    iree_hal_amdgpu_system_t* system,
+    iree_hal_device_t* logical_device, iree_hal_amdgpu_system_t* system,
     const iree_hal_amdgpu_physical_device_options_t* options,
-    iree_host_size_t host_ordinal,
+    iree_async_proactor_t* proactor, iree_host_size_t host_ordinal,
     const iree_hal_amdgpu_host_memory_pools_t* host_memory_pools,
-    iree_host_size_t device_ordinal, iree_hal_amdgpu_buffer_pool_t* buffer_pool,
-    iree_hal_amdgpu_error_callback_t error_callback,
-    hsa_signal_t initialization_signal, iree_allocator_t host_allocator,
+    iree_host_size_t device_ordinal, iree_allocator_t host_allocator,
     iree_hal_amdgpu_physical_device_t* out_physical_device);
 
+// Binds and initializes this physical device's host queues after the logical
+// device has been assigned a topology/frontier.
+iree_status_t iree_hal_amdgpu_physical_device_assign_frontier(
+    iree_hal_device_t* logical_device, iree_hal_amdgpu_system_t* system,
+    iree_async_proactor_t* proactor,
+    iree_async_frontier_tracker_t* frontier_tracker,
+    iree_async_axis_t base_axis,
+    iree_hal_amdgpu_epoch_signal_table_t* epoch_signal_table,
+    const iree_hal_amdgpu_host_memory_pools_t* host_memory_pools,
+    iree_allocator_t host_allocator,
+    iree_hal_amdgpu_physical_device_t* physical_device);
+
+// Deinitializes any host queues initialized by assign_frontier.
+void iree_hal_amdgpu_physical_device_deassign_frontier(
+    iree_hal_amdgpu_physical_device_t* physical_device);
+
+// Enables or disables HSA dispatch timestamp population on all live queues.
+//
+// On enable failure, queues successfully enabled by this call are disabled
+// before the status is returned. On disable failure, disabling is still
+// attempted on every queue and any failures are joined into the status.
+iree_status_t iree_hal_amdgpu_physical_device_set_hsa_profiling_enabled(
+    iree_hal_amdgpu_physical_device_t* physical_device, bool enabled);
+
 // Deinitializes a physical device and deallocates all device-specific
 // resources.
 void iree_hal_amdgpu_physical_device_deinitialize(
     iree_hal_amdgpu_physical_device_t* physical_device);
 
 // Releases any unused pooled resources.
-void iree_hal_amdgpu_physical_device_trim(
+iree_status_t iree_hal_amdgpu_physical_device_trim(
     iree_hal_amdgpu_physical_device_t* physical_device);
 
 #endif  // IREE_HAL_DRIVERS_AMDGPU_PHYSICAL_DEVICE_H_
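
(Reviewer aside, not part of the diff: the initialize/assign_frontier split
above implies a two-phase bring-up. The sketch below shows the intended call
order using only the declarations in this header; the argument values are
placeholders and error cleanup is elided, so treat it as illustrative rather
than code from the tree.)

```c
// Hypothetical bring-up sequence; all handles/options come from the caller.
iree_host_size_t size =
    iree_hal_amdgpu_physical_device_calculate_size(&options);
iree_hal_amdgpu_physical_device_t* physical_device = NULL;
IREE_RETURN_IF_ERROR(iree_allocator_malloc(host_allocator, size,
                                           (void**)&physical_device));

// Phase 1: agent-local setup (pools, kernels, capability selection).
IREE_RETURN_IF_ERROR(iree_hal_amdgpu_physical_device_initialize(
    logical_device, system, &options, proactor, host_ordinal,
    host_memory_pools, device_ordinal, host_allocator, physical_device));

// Phase 2: bind host queues once the logical device has been assigned a
// topology/frontier.
IREE_RETURN_IF_ERROR(iree_hal_amdgpu_physical_device_assign_frontier(
    logical_device, system, proactor, frontier_tracker, base_axis,
    epoch_signal_table, host_memory_pools, host_allocator, physical_device));

// ... steady-state use ...

// Teardown mirrors bring-up in reverse order.
iree_hal_amdgpu_physical_device_deassign_frontier(physical_device);
iree_hal_amdgpu_physical_device_deinitialize(physical_device);
iree_allocator_free(host_allocator, physical_device);
```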
diff --git a/runtime/src/iree/hal/drivers/amdgpu/physical_device_capabilities.c b/runtime/src/iree/hal/drivers/amdgpu/physical_device_capabilities.c
new file mode 100644
index 0000000..e528715
--- /dev/null
+++ b/runtime/src/iree/hal/drivers/amdgpu/physical_device_capabilities.c
@@ -0,0 +1,546 @@
+// Copyright 2026 The IREE Authors
+//
+// Licensed under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+
+#include "iree/hal/drivers/amdgpu/physical_device_capabilities.h"
+
+#include <stdint.h>
+#include <string.h>
+
+// Inclusive unsigned 16-bit range used for gfx IP table matching.
+typedef struct iree_hal_amdgpu_uint16_range_t {
+  // Inclusive lower bound.
+  uint16_t min;
+  // Inclusive upper bound.
+  uint16_t max;
+} iree_hal_amdgpu_uint16_range_t;
+
+// Gfx IP version range matched by a capability table row.
+typedef struct iree_hal_amdgpu_gfxip_version_range_t {
+  // Accepted major version range.
+  iree_hal_amdgpu_uint16_range_t major;
+  // Accepted minor version range.
+  iree_hal_amdgpu_uint16_range_t minor;
+  // Accepted stepping range.
+  iree_hal_amdgpu_uint16_range_t stepping;
+} iree_hal_amdgpu_gfxip_version_range_t;
+
+// AMD vendor-packet capability table row.
+typedef struct iree_hal_amdgpu_vendor_packet_capability_row_t {
+  // Gfx IP version range matched by this row.
+  iree_hal_amdgpu_gfxip_version_range_t version;
+  // Vendor-packet and PM4 packet-family capabilities enabled by this row.
+  iree_hal_amdgpu_vendor_packet_capability_flags_t capabilities;
+} iree_hal_amdgpu_vendor_packet_capability_row_t;
+
+enum {
+  // Packet families validated on the local gfx1100 bring-up system.
+  IREE_HAL_AMDGPU_VENDOR_PACKET_CAPABILITY_GFX1100_VALIDATED =
+      IREE_HAL_AMDGPU_VENDOR_PACKET_CAPABILITY_AQL_PM4_IB |
+      IREE_HAL_AMDGPU_VENDOR_PACKET_CAPABILITY_WAIT_REG_MEM64 |
+      IREE_HAL_AMDGPU_VENDOR_PACKET_CAPABILITY_PM4_WRITE_DATA_MEMORY |
+      IREE_HAL_AMDGPU_VENDOR_PACKET_CAPABILITY_PM4_COPY_DATA_MEMORY |
+      IREE_HAL_AMDGPU_VENDOR_PACKET_CAPABILITY_PM4_COPY_TIMESTAMP |
+      IREE_HAL_AMDGPU_VENDOR_PACKET_CAPABILITY_PM4_RELEASE_MEM_TIMESTAMP |
+      IREE_HAL_AMDGPU_VENDOR_PACKET_CAPABILITY_PM4_EVENT_WRITE |
+      IREE_HAL_AMDGPU_VENDOR_PACKET_CAPABILITY_PM4_SET_SH_REG |
+      IREE_HAL_AMDGPU_VENDOR_PACKET_CAPABILITY_PM4_SET_UCONFIG_REG |
+      IREE_HAL_AMDGPU_VENDOR_PACKET_CAPABILITY_PM4_REGISTER_READBACK |
+      IREE_HAL_AMDGPU_VENDOR_PACKET_CAPABILITY_PM4_PERFCOUNTER_READBACK |
+      IREE_HAL_AMDGPU_VENDOR_PACKET_CAPABILITY_PM4_IMMEDIATE_WRITE,
+};
+
+static bool iree_hal_amdgpu_uint16_range_contains(
+    iree_hal_amdgpu_uint16_range_t range, uint16_t value) {
+  return value >= range.min && value <= range.max;
+}
+
+static bool iree_hal_amdgpu_gfxip_version_range_contains(
+    iree_hal_amdgpu_gfxip_version_range_t range,
+    iree_hal_amdgpu_gfxip_version_t version) {
+  return iree_hal_amdgpu_uint16_range_contains(range.major, version.major) &&
+         iree_hal_amdgpu_uint16_range_contains(range.minor, version.minor) &&
+         iree_hal_amdgpu_uint16_range_contains(range.stepping,
+                                               version.stepping);
+}
+
+bool iree_hal_amdgpu_cpu_visible_device_coarse_memory_is_available(
+    const iree_hal_amdgpu_cpu_visible_device_coarse_memory_t* memory) {
+  return iree_any_bit_set(
+      memory->flags,
+      IREE_HAL_AMDGPU_CPU_VISIBLE_DEVICE_COARSE_MEMORY_FLAG_AVAILABLE);
+}
+
+bool iree_hal_amdgpu_memory_pool_access_is_valid(
+    hsa_amd_memory_pool_access_t access) {
+  switch (access) {
+    case HSA_AMD_MEMORY_POOL_ACCESS_NEVER_ALLOWED:
+    case HSA_AMD_MEMORY_POOL_ACCESS_ALLOWED_BY_DEFAULT:
+    case HSA_AMD_MEMORY_POOL_ACCESS_DISALLOWED_BY_DEFAULT:
+      return true;
+    default:
+      return false;
+  }
+}
+
+iree_hal_topology_interop_mode_t
+iree_hal_amdgpu_memory_pool_access_topology_mode(
+    hsa_amd_memory_pool_access_t access) {
+  IREE_ASSERT(iree_hal_amdgpu_memory_pool_access_is_valid(access),
+              "invalid HSA memory-pool access mode");
+  switch (access) {
+    case HSA_AMD_MEMORY_POOL_ACCESS_ALLOWED_BY_DEFAULT:
+      return IREE_HAL_TOPOLOGY_INTEROP_MODE_NATIVE;
+    case HSA_AMD_MEMORY_POOL_ACCESS_DISALLOWED_BY_DEFAULT:
+      return IREE_HAL_TOPOLOGY_INTEROP_MODE_COPY;
+    case HSA_AMD_MEMORY_POOL_ACCESS_NEVER_ALLOWED:
+    default:
+      return IREE_HAL_TOPOLOGY_INTEROP_MODE_COPY;
+  }
+}
+
+iree_hal_topology_capability_t
+iree_hal_amdgpu_memory_pool_access_topology_capabilities(
+    hsa_amd_memory_pool_access_t access) {
+  IREE_ASSERT(iree_hal_amdgpu_memory_pool_access_is_valid(access),
+              "invalid HSA memory-pool access mode");
+  if (access == HSA_AMD_MEMORY_POOL_ACCESS_DISALLOWED_BY_DEFAULT) {
+    return IREE_HAL_TOPOLOGY_CAPABILITY_PEER_ACCESS_REQUIRES_GRANT;
+  }
+  return IREE_HAL_TOPOLOGY_CAPABILITY_NONE;
+}
+
+// Maps an HSA link type to a HAL topology link class.
+//
+// For multi-hop links, callers should take the worst/highest class.
+static iree_hal_topology_link_class_t iree_hal_amdgpu_link_type_to_link_class(
+    hsa_amd_link_info_type_t link_type) {
+  switch (link_type) {
+    case HSA_AMD_LINK_INFO_TYPE_XGMI:
+      return IREE_HAL_TOPOLOGY_LINK_CLASS_NVLINK_IF;
+    case HSA_AMD_LINK_INFO_TYPE_PCIE:
+      return IREE_HAL_TOPOLOGY_LINK_CLASS_PCIE_SAME_ROOT;
+    case HSA_AMD_LINK_INFO_TYPE_QPI:
+    case HSA_AMD_LINK_INFO_TYPE_HYPERTRANSPORT:
+      // Cross-socket interconnects: treat as cross-root PCIe.
+      return IREE_HAL_TOPOLOGY_LINK_CLASS_PCIE_CROSS_ROOT;
+    case HSA_AMD_LINK_INFO_TYPE_INFINBAND:
+      return IREE_HAL_TOPOLOGY_LINK_CLASS_FABRIC;
+    default:
+      return IREE_HAL_TOPOLOGY_LINK_CLASS_OTHER;
+  }
+}
+
+static iree_hal_amdgpu_physical_topology_link_flags_t
+iree_hal_amdgpu_link_type_to_physical_topology_link_flags(
+    hsa_amd_link_info_type_t link_type) {
+  switch (link_type) {
+    case HSA_AMD_LINK_INFO_TYPE_PCIE:
+      return IREE_HAL_AMDGPU_PHYSICAL_TOPOLOGY_LINK_FLAG_PCIE;
+    case HSA_AMD_LINK_INFO_TYPE_XGMI:
+      return IREE_HAL_AMDGPU_PHYSICAL_TOPOLOGY_LINK_FLAG_XGMI;
+    case HSA_AMD_LINK_INFO_TYPE_HYPERTRANSPORT:
+      return IREE_HAL_AMDGPU_PHYSICAL_TOPOLOGY_LINK_FLAG_HYPERTRANSPORT;
+    case HSA_AMD_LINK_INFO_TYPE_QPI:
+      return IREE_HAL_AMDGPU_PHYSICAL_TOPOLOGY_LINK_FLAG_QPI;
+    case HSA_AMD_LINK_INFO_TYPE_INFINBAND:
+      return IREE_HAL_AMDGPU_PHYSICAL_TOPOLOGY_LINK_FLAG_INFINIBAND;
+    default:
+      return IREE_HAL_AMDGPU_PHYSICAL_TOPOLOGY_LINK_FLAG_OTHER;
+  }
+}
+
+static void iree_hal_amdgpu_topology_costs_from_link_class(
+    iree_hal_topology_link_class_t link_class, uint8_t* out_copy_cost,
+    uint8_t* out_latency_class) {
+  switch (link_class) {
+    case IREE_HAL_TOPOLOGY_LINK_CLASS_SAME_DIE:
+      *out_copy_cost = 0;
+      *out_latency_class = 0;
+      break;
+    case IREE_HAL_TOPOLOGY_LINK_CLASS_NVLINK_IF:
+      *out_copy_cost = 3;
+      *out_latency_class = 3;
+      break;
+    case IREE_HAL_TOPOLOGY_LINK_CLASS_PCIE_SAME_ROOT:
+      *out_copy_cost = 7;
+      *out_latency_class = 7;
+      break;
+    case IREE_HAL_TOPOLOGY_LINK_CLASS_PCIE_CROSS_ROOT:
+      *out_copy_cost = 9;
+      *out_latency_class = 9;
+      break;
+    case IREE_HAL_TOPOLOGY_LINK_CLASS_HOST_STAGED:
+      *out_copy_cost = 13;
+      *out_latency_class = 11;
+      break;
+    case IREE_HAL_TOPOLOGY_LINK_CLASS_FABRIC:
+      *out_copy_cost = 15;
+      *out_latency_class = 14;
+      break;
+    case IREE_HAL_TOPOLOGY_LINK_CLASS_ISOLATED:
+      *out_copy_cost = 15;
+      *out_latency_class = 15;
+      break;
+    default:
+      *out_copy_cost = 11;
+      *out_latency_class = 10;
+      break;
+  }
+}
+
+static uint8_t iree_hal_amdgpu_topology_scale_hsa_numa_distance(
+    uint32_t hsa_numa_distance) {
+  if (hsa_numa_distance == 0) return 0;
+  uint32_t scaled = hsa_numa_distance > 10 ? (hsa_numa_distance - 10) / 2 : 0;
+  return (uint8_t)iree_min(scaled, 15u);
+}
+
+static iree_status_t iree_hal_amdgpu_validate_physical_topology_edge_access(
+    hsa_amd_memory_pool_access_t access, const char* pool_kind) {
+  if (IREE_LIKELY(iree_hal_amdgpu_memory_pool_access_is_valid(access))) {
+    return iree_ok_status();
+  }
+  return iree_make_status(IREE_STATUS_OUT_OF_RANGE,
+                          "HSA reported unknown %s memory pool access mode %u",
+                          pool_kind, (uint32_t)access);
+}
+
+static void iree_hal_amdgpu_physical_topology_edge_initialize(
+    iree_hal_amdgpu_physical_topology_edge_t* out_edge) {
+  memset(out_edge, 0, sizeof(*out_edge));
+  out_edge->memory_access.coarse = HSA_AMD_MEMORY_POOL_ACCESS_NEVER_ALLOWED;
+  out_edge->memory_access.fine = HSA_AMD_MEMORY_POOL_ACCESS_NEVER_ALLOWED;
+  out_edge->coherency.all_hops_coherent = 1;
+  out_edge->atomics.all_hops_32bit = 1;
+  out_edge->atomics.all_hops_64bit = 1;
+  out_edge->link.link_class = IREE_HAL_TOPOLOGY_LINK_CLASS_SAME_DIE;
+  out_edge->modes.noncoherent_read = IREE_HAL_TOPOLOGY_INTEROP_MODE_COPY;
+  out_edge->modes.noncoherent_write = IREE_HAL_TOPOLOGY_INTEROP_MODE_COPY;
+  out_edge->modes.coherent_read = IREE_HAL_TOPOLOGY_INTEROP_MODE_COPY;
+  out_edge->modes.coherent_write = IREE_HAL_TOPOLOGY_INTEROP_MODE_COPY;
+}
+
+static iree_hal_topology_capability_t
+iree_hal_amdgpu_physical_topology_guaranteed_capabilities(
+    const iree_hal_amdgpu_physical_topology_edge_t* edge) {
+  iree_hal_topology_capability_t capabilities =
+      IREE_HAL_TOPOLOGY_CAPABILITY_NONE;
+  if (!edge->memory_access.coarse_accessible &&
+      !edge->memory_access.fine_accessible) {
+    return capabilities;
+  }
+  capabilities |= IREE_HAL_TOPOLOGY_CAPABILITY_P2P_COPY;
+  if (edge->coherency.all_hops_coherent) {
+    capabilities |= IREE_HAL_TOPOLOGY_CAPABILITY_PEER_COHERENT;
+  }
+  if (edge->atomics.all_hops_32bit) {
+    capabilities |= IREE_HAL_TOPOLOGY_CAPABILITY_ATOMIC_DEVICE;
+  }
+  if (edge->atomics.all_hops_64bit) {
+    capabilities |= IREE_HAL_TOPOLOGY_CAPABILITY_ATOMIC_SYSTEM;
+  }
+  return capabilities;
+}
+
+static iree_hal_topology_capability_t
+iree_hal_amdgpu_physical_topology_required_capabilities(
+    const iree_hal_amdgpu_physical_topology_edge_t* edge) {
+  iree_hal_topology_capability_t capabilities =
+      IREE_HAL_TOPOLOGY_CAPABILITY_NONE;
+  capabilities |= iree_hal_amdgpu_memory_pool_access_topology_capabilities(
+      edge->memory_access.coarse);
+  capabilities |= iree_hal_amdgpu_memory_pool_access_topology_capabilities(
+      edge->memory_access.fine);
+  return capabilities;
+}
+
+iree_status_t iree_hal_amdgpu_select_physical_topology_edge(
+    const iree_hal_amdgpu_physical_topology_edge_selection_t* selection,
+    iree_hal_amdgpu_physical_topology_edge_t* out_edge) {
+  IREE_ASSERT_ARGUMENT(selection);
+  IREE_ASSERT_ARGUMENT(out_edge);
+  iree_hal_amdgpu_physical_topology_edge_initialize(out_edge);
+
+  IREE_RETURN_IF_ERROR(iree_hal_amdgpu_validate_physical_topology_edge_access(
+      selection->memory_access.coarse, "coarse"));
+  IREE_RETURN_IF_ERROR(iree_hal_amdgpu_validate_physical_topology_edge_access(
+      selection->memory_access.fine, "fine"));
+  if (IREE_UNLIKELY(selection->link.count && !selection->link.hops)) {
+    return iree_make_status(
+        IREE_STATUS_INVALID_ARGUMENT,
+        "AMDGPU physical topology edge selection requires link hops when "
+        "link count is nonzero");
+  }
+
+  out_edge->memory_access.coarse = selection->memory_access.coarse;
+  out_edge->memory_access.fine = selection->memory_access.fine;
+  out_edge->memory_access.coarse_accessible =
+      selection->memory_access.coarse !=
+      HSA_AMD_MEMORY_POOL_ACCESS_NEVER_ALLOWED;
+  out_edge->memory_access.fine_accessible =
+      selection->memory_access.fine != HSA_AMD_MEMORY_POOL_ACCESS_NEVER_ALLOWED;
+
+  for (iree_host_size_t i = 0; i < selection->link.count; ++i) {
+    const hsa_amd_memory_pool_link_info_t* link_hop = &selection->link.hops[i];
+    iree_hal_topology_link_class_t link_class =
+        iree_hal_amdgpu_link_type_to_link_class(link_hop->link_type);
+    if (link_class > out_edge->link.link_class) {
+      out_edge->link.link_class = link_class;
+    }
+    out_edge->link.flags |=
+        iree_hal_amdgpu_link_type_to_physical_topology_link_flags(
+            link_hop->link_type);
+    uint8_t numa_distance = iree_hal_amdgpu_topology_scale_hsa_numa_distance(
+        link_hop->numa_distance);
+    if (numa_distance > out_edge->link.numa_distance) {
+      out_edge->link.numa_distance = numa_distance;
+    }
+    if (!link_hop->coherent_support) {
+      out_edge->coherency.all_hops_coherent = 0;
+    }
+    if (!link_hop->atomic_support_32bit) {
+      out_edge->atomics.all_hops_32bit = 0;
+    }
+    if (!link_hop->atomic_support_64bit) {
+      out_edge->atomics.all_hops_64bit = 0;
+    }
+  }
+
+  if (!out_edge->memory_access.coarse_accessible &&
+      !out_edge->memory_access.fine_accessible) {
+    out_edge->link.link_class = IREE_HAL_TOPOLOGY_LINK_CLASS_HOST_STAGED;
+    out_edge->coherency.all_hops_coherent = 0;
+    out_edge->atomics.all_hops_32bit = 0;
+    out_edge->atomics.all_hops_64bit = 0;
+  }
+
+  iree_hal_amdgpu_topology_costs_from_link_class(out_edge->link.link_class,
+                                                 &out_edge->link.copy_cost,
+                                                 &out_edge->link.latency_class);
+  out_edge->capabilities.guaranteed =
+      iree_hal_amdgpu_physical_topology_guaranteed_capabilities(out_edge);
+  out_edge->capabilities.required =
+      iree_hal_amdgpu_physical_topology_required_capabilities(out_edge);
+  out_edge->modes.noncoherent_read =
+      iree_hal_amdgpu_memory_pool_access_topology_mode(
+          out_edge->memory_access.coarse);
+  out_edge->modes.noncoherent_write = out_edge->modes.noncoherent_read;
+  out_edge->modes.coherent_read =
+      iree_hal_amdgpu_memory_pool_access_topology_mode(
+          out_edge->memory_access.fine);
+  out_edge->modes.coherent_write = out_edge->modes.coherent_read;
+  return iree_ok_status();
+}
+
+static bool iree_hal_amdgpu_gfxip_is_pre_gfx908(
+    iree_hal_amdgpu_gfxip_version_t version) {
+  return version.major < 9 ||
+         (version.major == 9 && version.minor == 0 && version.stepping < 8);
+}
+
+static bool iree_hal_amdgpu_gfxip_is_gfx101x(
+    iree_hal_amdgpu_gfxip_version_t version) {
+  return version.major == 10 && (version.minor == 0 || version.minor == 1);
+}
+
+bool iree_hal_amdgpu_gfxip_allows_hdp_kernarg_publication(
+    iree_hal_amdgpu_gfxip_version_t version) {
+  // Matches the HDP workaround eligibility in CLR's setKernelArgImpl. Devices
+  // outside this set stay on host kernarg memory unless we add a first-class
+  // readback publication mode.
+  return !iree_hal_amdgpu_gfxip_is_pre_gfx908(version) &&
+         !iree_hal_amdgpu_gfxip_is_gfx101x(version);
+}
+
+iree_status_t iree_hal_amdgpu_select_cpu_visible_device_coarse_memory(
+    const iree_hal_amdgpu_cpu_visible_device_coarse_memory_selection_t*
+        selection,
+    iree_hal_amdgpu_cpu_visible_device_coarse_memory_t* out_memory) {
+  IREE_ASSERT_ARGUMENT(selection);
+  IREE_ASSERT_ARGUMENT(out_memory);
+  memset(out_memory, 0, sizeof(*out_memory));
+
+  if (!selection->memory_pool.handle || selection->cpu.count == 0) {
+    return iree_ok_status();
+  }
+  if (IREE_UNLIKELY(selection->cpu.count > IREE_HAL_AMDGPU_MAX_CPU_AGENT)) {
+    return iree_make_status(
+        IREE_STATUS_OUT_OF_RANGE,
+        "AMDGPU topology has %" PRIhsz
+        " CPU agents but CPU-visible coarse memory tracks at most %d",
+        selection->cpu.count, IREE_HAL_AMDGPU_MAX_CPU_AGENT);
+  }
+  if (!iree_any_bit_set(
+          selection->flags,
+          IREE_HAL_AMDGPU_CPU_VISIBLE_DEVICE_COARSE_MEMORY_SELECTION_FLAG_HOST_WRITE_PUBLICATION_SUPPORTED)) {
+    return iree_ok_status();
+  }
+  if (!iree_hal_amdgpu_gfxip_allows_hdp_kernarg_publication(
+          selection->gfxip_version)) {
+    return iree_ok_status();
+  }
+  if (!selection->hdp.registers.HDP_MEM_FLUSH_CNTL ||
+      !selection->hdp.registers.HDP_REG_FLUSH_CNTL) {
+    return iree_ok_status();
+  }
+  if (IREE_UNLIKELY(!selection->cpu.agents || !selection->cpu.access)) {
+    return iree_make_status(
+        IREE_STATUS_INVALID_ARGUMENT,
+        "CPU-visible device-coarse memory selection requires CPU agents and "
+        "access modes");
+  }
+
+  for (iree_host_size_t i = 0; i < selection->cpu.count; ++i) {
+    const hsa_amd_memory_pool_access_t access = selection->cpu.access[i];
+    if (IREE_UNLIKELY(!iree_hal_amdgpu_memory_pool_access_is_valid(access))) {
+      return iree_make_status(IREE_STATUS_OUT_OF_RANGE,
+                              "HSA reported unknown memory pool access mode %u",
+                              (uint32_t)access);
+    }
+    if (access == HSA_AMD_MEMORY_POOL_ACCESS_NEVER_ALLOWED) {
+      return iree_ok_status();
+    }
+  }
+
+  iree_host_size_t access_agent_count = 0;
+  for (iree_host_size_t i = 0; i < selection->cpu.count; ++i) {
+    out_memory->access_agents[access_agent_count++] = selection->cpu.agents[i];
+  }
+  out_memory->access_agents[access_agent_count++] = selection->device_agent;
+  out_memory->memory_pool = selection->memory_pool;
+  out_memory->access_agent_count = access_agent_count;
+  out_memory->host_write_publication =
+      (iree_hal_amdgpu_kernarg_ring_publication_t){
+          .mode = IREE_HAL_AMDGPU_KERNARG_RING_PUBLICATION_MODE_HDP_FLUSH,
+          .hdp_mem_flush_control = selection->hdp.registers.HDP_MEM_FLUSH_CNTL,
+      };
+  out_memory->flags =
+      IREE_HAL_AMDGPU_CPU_VISIBLE_DEVICE_COARSE_MEMORY_FLAG_AVAILABLE |
+      IREE_HAL_AMDGPU_CPU_VISIBLE_DEVICE_COARSE_MEMORY_FLAG_HDP_FLUSH;
+  return iree_ok_status();
+}
+
+void iree_hal_amdgpu_select_memory_system_capabilities(
+    const iree_hal_amdgpu_memory_system_capabilities_selection_t* selection,
+    iree_hal_amdgpu_memory_system_capabilities_t* out_capabilities) {
+  IREE_ASSERT_ARGUMENT(selection);
+  IREE_ASSERT_ARGUMENT(out_capabilities);
+  memset(out_capabilities, 0, sizeof(*out_capabilities));
+
+  out_capabilities->svm.supported = selection->svm.supported ? 1u : 0u;
+  out_capabilities->svm.accessible_by_default =
+      selection->svm.accessible_by_default ? 1u : 0u;
+  out_capabilities->svm.xnack_enabled = selection->svm.xnack_enabled ? 1u : 0u;
+  out_capabilities->svm.direct_host_access =
+      selection->svm.direct_host_access ? 1u : 0u;
+  out_capabilities->device_local.fine_host_visible =
+      selection->device_local.fine_memory_pool.handle ? 1u : 0u;
+  out_capabilities->device_local.coarse_cpu_visible =
+      selection->device_local.coarse_cpu_visible_memory &&
+              iree_hal_amdgpu_cpu_visible_device_coarse_memory_is_available(
+                  selection->device_local.coarse_cpu_visible_memory)
+          ? 1u
+          : 0u;
+}
+
+iree_hal_device_capability_bits_t
+iree_hal_amdgpu_select_memory_system_device_capability_flags(
+    const iree_hal_amdgpu_memory_system_capabilities_t* capabilities) {
+  IREE_ASSERT_ARGUMENT(capabilities);
+  iree_hal_device_capability_bits_t flags = IREE_HAL_DEVICE_CAPABILITY_NONE;
+  if (capabilities->svm.supported) {
+    flags |= IREE_HAL_DEVICE_CAPABILITY_SHARED_VIRTUAL_ADDRESS;
+  }
+  if (capabilities->svm.accessible_by_default) {
+    flags |= IREE_HAL_DEVICE_CAPABILITY_UNIFIED_MEMORY;
+  }
+  return flags;
+}
+
+bool iree_hal_amdgpu_memory_system_requires_svm_access_attributes(
+    const iree_hal_amdgpu_memory_system_capabilities_t* capabilities) {
+  IREE_ASSERT_ARGUMENT(capabilities);
+  return capabilities->svm.supported &&
+         !capabilities->svm.accessible_by_default;
+}
+
+iree_hal_amdgpu_aql_prepublished_kernarg_storage_t
+iree_hal_amdgpu_select_prepublished_kernarg_storage(
+    hsa_amd_memory_pool_t fine_block_memory_pool) {
+  if (!fine_block_memory_pool.handle) {
+    return iree_hal_amdgpu_aql_prepublished_kernarg_storage_disabled();
+  }
+  return iree_hal_amdgpu_aql_prepublished_kernarg_storage_device_fine_host_coherent();
+}
+
+iree_hal_amdgpu_vendor_packet_capability_flags_t
+iree_hal_amdgpu_select_vendor_packet_capabilities(
+    iree_hal_amdgpu_gfxip_version_t version) {
+  // The CDNA BARRIER_VALUE rows match CLR's barrier_value_packet_ gate:
+  // gfx9.0.10 or gfx9.[minor >= 4].[stepping 0..2].
+  //
+  // The gfx1100 row is the currently validated RDNA3 PM4 path. Nearby gfx10,
+  // gfx11, and gfx12 parts stay on the base AQL path until each packet-family
+  // contract has hardware evidence or an explicit probe.
+  static const iree_hal_amdgpu_vendor_packet_capability_row_t kRows[] = {
+      {
+          .version =
+              {
+                  .major = {9, 9},
+                  .minor = {0, 0},
+                  .stepping = {10, 10},
+              },
+          .capabilities =
+              IREE_HAL_AMDGPU_VENDOR_PACKET_CAPABILITY_AQL_PM4_IB |
+              IREE_HAL_AMDGPU_VENDOR_PACKET_CAPABILITY_AQL_BARRIER_VALUE,
+      },
+      {
+          .version =
+              {
+                  .major = {9, 9},
+                  .minor = {4, UINT16_MAX},
+                  .stepping = {0, 2},
+              },
+          .capabilities =
+              IREE_HAL_AMDGPU_VENDOR_PACKET_CAPABILITY_AQL_PM4_IB |
+              IREE_HAL_AMDGPU_VENDOR_PACKET_CAPABILITY_AQL_BARRIER_VALUE,
+      },
+      {
+          .version =
+              {
+                  .major = {11, 11},
+                  .minor = {0, 0},
+                  .stepping = {0, 0},
+              },
+          .capabilities =
+              IREE_HAL_AMDGPU_VENDOR_PACKET_CAPABILITY_GFX1100_VALIDATED,
+      },
+  };
+
+  iree_hal_amdgpu_vendor_packet_capability_flags_t capabilities = 0;
+  for (iree_host_size_t i = 0; i < IREE_ARRAYSIZE(kRows); ++i) {
+    if (iree_hal_amdgpu_gfxip_version_range_contains(kRows[i].version,
+                                                     version)) {
+      capabilities |= kRows[i].capabilities;
+    }
+  }
+  return capabilities;
+}
+
+iree_hal_amdgpu_wait_barrier_strategy_t
+iree_hal_amdgpu_select_wait_barrier_strategy(
+    iree_hal_amdgpu_vendor_packet_capability_flags_t
+        vendor_packet_capabilities) {
+  if (vendor_packet_capabilities &
+      IREE_HAL_AMDGPU_VENDOR_PACKET_CAPABILITY_AQL_BARRIER_VALUE) {
+    return IREE_HAL_AMDGPU_WAIT_BARRIER_STRATEGY_AQL_BARRIER_VALUE;
+  }
+  if (vendor_packet_capabilities &
+      IREE_HAL_AMDGPU_VENDOR_PACKET_CAPABILITY_WAIT_REG_MEM64) {
+    return IREE_HAL_AMDGPU_WAIT_BARRIER_STRATEGY_PM4_WAIT_REG_MEM64;
+  }
+  return IREE_HAL_AMDGPU_WAIT_BARRIER_STRATEGY_DEFER;
+}
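
(Reviewer aside, not part of the diff: the capability table above is a pure
lookup over inclusive (major, minor, stepping) ranges, and the NUMA scaling
earlier in this file folds HSA-reported distances onto a 0-15 class via
`(d - 10) / 2`. The standalone program below illustrates both mechanics with
local stub types; it mirrors the logic above instead of linking against the
driver.)

```c
#include <stdint.h>
#include <stdio.h>

// Local stand-ins for the driver types (illustration only).
typedef struct { uint16_t min, max; } u16_range_t;
typedef struct { u16_range_t major, minor, stepping; } gfxip_range_t;
typedef struct { uint16_t major, minor, stepping; } gfxip_version_t;

static int range_contains(u16_range_t r, uint16_t v) {
  return v >= r.min && v <= r.max;
}

static int gfxip_matches(gfxip_range_t r, gfxip_version_t v) {
  return range_contains(r.major, v.major) &&
         range_contains(r.minor, v.minor) &&
         range_contains(r.stepping, v.stepping);
}

// Same scaling as iree_hal_amdgpu_topology_scale_hsa_numa_distance:
// distances at or below 10 map to 0 (local); everything else is folded
// onto a 0-15 class.
static uint8_t scale_numa_distance(uint32_t d) {
  uint32_t scaled = d > 10 ? (d - 10) / 2 : 0;
  return (uint8_t)(scaled < 15 ? scaled : 15);
}

int main(void) {
  // The gfx9.[minor >= 4].[stepping 0..2] row from the table above.
  gfxip_range_t cdna_row = {{9, 9}, {4, UINT16_MAX}, {0, 2}};
  gfxip_version_t gfx942 = {9, 4, 2};
  printf("gfx942 matches CDNA row: %d\n", gfxip_matches(cdna_row, gfx942));
  // Matches the test expectations below: 16 -> 3, 28 -> 9.
  printf("numa 16 -> %u, numa 28 -> %u\n",
         (unsigned)scale_numa_distance(16), (unsigned)scale_numa_distance(28));
  return 0;
}
```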
diff --git a/runtime/src/iree/hal/drivers/amdgpu/physical_device_capabilities.h b/runtime/src/iree/hal/drivers/amdgpu/physical_device_capabilities.h
new file mode 100644
index 0000000..40951bf
--- /dev/null
+++ b/runtime/src/iree/hal/drivers/amdgpu/physical_device_capabilities.h
@@ -0,0 +1,305 @@
+// Copyright 2026 The IREE Authors
+//
+// Licensed under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+
+#ifndef IREE_HAL_DRIVERS_AMDGPU_PHYSICAL_DEVICE_CAPABILITIES_H_
+#define IREE_HAL_DRIVERS_AMDGPU_PHYSICAL_DEVICE_CAPABILITIES_H_
+
+#include "iree/base/api.h"
+#include "iree/hal/device.h"
+#include "iree/hal/drivers/amdgpu/aql_prepublished_kernarg_storage.h"
+#include "iree/hal/drivers/amdgpu/util/kernarg_ring.h"
+#include "iree/hal/drivers/amdgpu/util/libhsa.h"
+#include "iree/hal/drivers/amdgpu/util/pm4_capabilities.h"
+#include "iree/hal/drivers/amdgpu/util/target_id.h"
+#include "iree/hal/drivers/amdgpu/util/topology.h"
+
+#ifdef __cplusplus
+extern "C" {
+#endif  // __cplusplus
+
+typedef enum iree_hal_amdgpu_cpu_visible_device_coarse_memory_flag_bits_e {
+  IREE_HAL_AMDGPU_CPU_VISIBLE_DEVICE_COARSE_MEMORY_FLAG_NONE = 0u,
+  // All CPU agents can access the GPU coarse-grained memory pool and the
+  // driver knows how to publish CPU writes before GPU consumption.
+  IREE_HAL_AMDGPU_CPU_VISIBLE_DEVICE_COARSE_MEMORY_FLAG_AVAILABLE = 1u << 0,
+  // CPU writes require an HDP flush before the GPU consumes the memory.
+  IREE_HAL_AMDGPU_CPU_VISIBLE_DEVICE_COARSE_MEMORY_FLAG_HDP_FLUSH = 1u << 1,
+} iree_hal_amdgpu_cpu_visible_device_coarse_memory_flag_bits_t;
+
+typedef uint32_t iree_hal_amdgpu_cpu_visible_device_coarse_memory_flags_t;
+
+// Physical-device capability for CPU-visible GPU coarse-grained memory.
+typedef struct iree_hal_amdgpu_cpu_visible_device_coarse_memory_t {
+  // GPU coarse-grained HSA memory pool CPU agents can access.
+  hsa_amd_memory_pool_t memory_pool;
+  // Agents granted access for allocations that use |memory_pool|.
+  hsa_agent_t access_agents[IREE_HAL_AMDGPU_MAX_CPU_AGENT + 1];
+  // Number of valid entries in |access_agents|.
+  iree_host_size_t access_agent_count;
+  // Publication required after CPU writes and before GPU consumption.
+  iree_hal_amdgpu_kernarg_ring_publication_t host_write_publication;
+  // Capability flags from
+  // iree_hal_amdgpu_cpu_visible_device_coarse_memory_flag_bits_t.
+  iree_hal_amdgpu_cpu_visible_device_coarse_memory_flags_t flags;
+} iree_hal_amdgpu_cpu_visible_device_coarse_memory_t;
+
+typedef enum iree_hal_amdgpu_cpu_visible_device_coarse_memory_selection_flag_bits_e {
+  IREE_HAL_AMDGPU_CPU_VISIBLE_DEVICE_COARSE_MEMORY_SELECTION_FLAG_NONE = 0u,
+  // Host writes can be published for CPU-visible device-coarse memory.
+  IREE_HAL_AMDGPU_CPU_VISIBLE_DEVICE_COARSE_MEMORY_SELECTION_FLAG_HOST_WRITE_PUBLICATION_SUPPORTED =
+      1u << 0,
+} iree_hal_amdgpu_cpu_visible_device_coarse_memory_selection_flag_bits_t;
+
+typedef uint32_t
+    iree_hal_amdgpu_cpu_visible_device_coarse_memory_selection_flags_t;
+
+// Queried facts used to select CPU-visible device-coarse memory capability.
+typedef struct iree_hal_amdgpu_cpu_visible_device_coarse_memory_selection_t {
+  // GPU agent that owns |memory_pool|.
+  hsa_agent_t device_agent;
+  // GPU coarse-grained memory pool being considered.
+  hsa_amd_memory_pool_t memory_pool;
+  // Parsed gfx IP version for HDP publication eligibility.
+  iree_hal_amdgpu_gfxip_version_t gfxip_version;
+  // CPU agents and their access to |memory_pool|.
+  struct {
+    // CPU agents that may write the memory.
+    const hsa_agent_t* agents;
+    // Per-CPU-agent access mode for |memory_pool|.
+    const hsa_amd_memory_pool_access_t* access;
+    // Number of entries in |agents| and |access|.
+    iree_host_size_t count;
+  } cpu;
+  // HDP publication registers reported by HSA.
+  struct {
+    // Raw HSA HDP flush register descriptor.
+    hsa_amd_hdp_flush_t registers;
+  } hdp;
+  // Selection flags from
+  // iree_hal_amdgpu_cpu_visible_device_coarse_memory_selection_flag_bits_t.
+  iree_hal_amdgpu_cpu_visible_device_coarse_memory_selection_flags_t flags;
+} iree_hal_amdgpu_cpu_visible_device_coarse_memory_selection_t;
+
+// Returns true if CPU-visible device-coarse memory is available.
+bool iree_hal_amdgpu_cpu_visible_device_coarse_memory_is_available(
+    const iree_hal_amdgpu_cpu_visible_device_coarse_memory_t* memory);
+
+// Returns true if |access| is a known HSA memory-pool access mode.
+bool iree_hal_amdgpu_memory_pool_access_is_valid(
+    hsa_amd_memory_pool_access_t access);
+
+// Maps an HSA memory-pool access mode to the safe default topology buffer mode.
+iree_hal_topology_interop_mode_t
+iree_hal_amdgpu_memory_pool_access_topology_mode(
+    hsa_amd_memory_pool_access_t access);
+
+// Maps an HSA memory-pool access mode to additional topology capabilities.
+iree_hal_topology_capability_t
+iree_hal_amdgpu_memory_pool_access_topology_capabilities(
+    hsa_amd_memory_pool_access_t access);
+
+typedef enum iree_hal_amdgpu_physical_topology_link_flag_bits_e {
+  IREE_HAL_AMDGPU_PHYSICAL_TOPOLOGY_LINK_FLAG_NONE = 0u,
+  // At least one HSA-reported hop uses PCIe.
+  IREE_HAL_AMDGPU_PHYSICAL_TOPOLOGY_LINK_FLAG_PCIE = 1u << 0,
+  // At least one HSA-reported hop uses xGMI.
+  IREE_HAL_AMDGPU_PHYSICAL_TOPOLOGY_LINK_FLAG_XGMI = 1u << 1,
+  // At least one HSA-reported hop uses HyperTransport.
+  IREE_HAL_AMDGPU_PHYSICAL_TOPOLOGY_LINK_FLAG_HYPERTRANSPORT = 1u << 2,
+  // At least one HSA-reported hop uses QPI.
+  IREE_HAL_AMDGPU_PHYSICAL_TOPOLOGY_LINK_FLAG_QPI = 1u << 3,
+  // At least one HSA-reported hop uses InfiniBand.
+  IREE_HAL_AMDGPU_PHYSICAL_TOPOLOGY_LINK_FLAG_INFINIBAND = 1u << 4,
+  // At least one HSA-reported hop uses an unknown link type.
+  IREE_HAL_AMDGPU_PHYSICAL_TOPOLOGY_LINK_FLAG_OTHER = 1u << 5,
+} iree_hal_amdgpu_physical_topology_link_flag_bits_t;
+
+typedef uint32_t iree_hal_amdgpu_physical_topology_link_flags_t;
+
+// Physical source->destination topology edge selected from already-queried HSA
+// memory-pool access and link-hop facts.
+typedef struct iree_hal_amdgpu_physical_topology_edge_t {
+  // Source-agent access to the destination memory pools.
+  struct {
+    // Source-agent access to the destination coarse-grained memory pool.
+    hsa_amd_memory_pool_access_t coarse;
+    // Source-agent access to the destination fine-grained memory pool.
+    hsa_amd_memory_pool_access_t fine;
+    // True when |coarse| permits some direct device access.
+    uint32_t coarse_accessible : 1;
+    // True when |fine| permits some direct device access.
+    uint32_t fine_accessible : 1;
+  } memory_access;
+
+  // HSA link-hop facts collapsed into strategy-friendly topology values.
+  struct {
+    // Worst physical link class across HSA-reported link hops.
+    iree_hal_topology_link_class_t link_class;
+    // Conservative copy-cost class derived from |link_class|.
+    uint8_t copy_cost;
+    // Conservative latency class derived from |link_class|.
+    uint8_t latency_class;
+    // Worst normalized NUMA distance reported by HSA link hops.
+    uint8_t numa_distance;
+    // Link flags from iree_hal_amdgpu_physical_topology_link_flag_bits_t.
+    iree_hal_amdgpu_physical_topology_link_flags_t flags;
+  } link;
+
+  // Link coherency facts.
+  struct {
+    // True when every HSA-reported link hop supports coherent transactions.
+    uint32_t all_hops_coherent : 1;
+  } coherency;
+
+  // Link atomic-transaction facts.
+  struct {
+    // True when every HSA-reported link hop supports 32-bit atomics.
+    uint32_t all_hops_32bit : 1;
+    // True when every HSA-reported link hop supports 64-bit atomics.
+    uint32_t all_hops_64bit : 1;
+  } atomics;
+
+  // Generic HAL topology capabilities implied by the physical edge.
+  struct {
+    // Positive capabilities guaranteed by this physical pair.
+    iree_hal_topology_capability_t guaranteed;
+    // Requirement bits imposed by this physical pair.
+    iree_hal_topology_capability_t required;
+  } capabilities;
+
+  // Generic HAL buffer interop modes implied by memory-pool access.
+  struct {
+    // Noncoherent read mode derived from coarse-grained pool access.
+    iree_hal_topology_interop_mode_t noncoherent_read;
+    // Noncoherent write mode derived from coarse-grained pool access.
+    iree_hal_topology_interop_mode_t noncoherent_write;
+    // Coherent read mode derived from fine-grained pool access.
+    iree_hal_topology_interop_mode_t coherent_read;
+    // Coherent write mode derived from fine-grained pool access.
+    iree_hal_topology_interop_mode_t coherent_write;
+  } modes;
+} iree_hal_amdgpu_physical_topology_edge_t;
+
+// Already-queried HSA facts used to select a physical topology edge.
+typedef struct iree_hal_amdgpu_physical_topology_edge_selection_t {
+  // Source-agent access to the destination memory pools.
+  struct {
+    // Source-agent access to the destination coarse-grained memory pool.
+    hsa_amd_memory_pool_access_t coarse;
+    // Source-agent access to the destination fine-grained memory pool.
+    hsa_amd_memory_pool_access_t fine;
+  } memory_access;
+
+  // HSA link-hop facts for the source->destination memory path.
+  struct {
+    // HSA-reported link-hop records.
+    const hsa_amd_memory_pool_link_info_t* hops;
+    // Number of entries in |hops|.
+    iree_host_size_t count;
+  } link;
+} iree_hal_amdgpu_physical_topology_edge_selection_t;
+
+// Selects a physical topology edge from already-queried HSA facts.
+iree_status_t iree_hal_amdgpu_select_physical_topology_edge(
+    const iree_hal_amdgpu_physical_topology_edge_selection_t* selection,
+    iree_hal_amdgpu_physical_topology_edge_t* out_edge);
+
+// Returns true if the gfx IP family permits HDP kernarg publication.
+bool iree_hal_amdgpu_gfxip_allows_hdp_kernarg_publication(
+    iree_hal_amdgpu_gfxip_version_t version);
+
+// Selects CPU-visible device-coarse memory from already-queried topology facts.
+iree_status_t iree_hal_amdgpu_select_cpu_visible_device_coarse_memory(
+    const iree_hal_amdgpu_cpu_visible_device_coarse_memory_selection_t*
+        selection,
+    iree_hal_amdgpu_cpu_visible_device_coarse_memory_t* out_memory);
+
+// AMDGPU memory-system facts used to derive conservative HAL topology flags.
+typedef struct iree_hal_amdgpu_memory_system_capabilities_t {
+  // HSA SVM/HMM process and agent facts.
+  struct {
+    // HSA SVM attribute and prefetch APIs are available.
+    uint32_t supported : 1;
+    // System allocations are accessible by GPU agents without per-range grants.
+    uint32_t accessible_by_default : 1;
+    // The process is bound to XNACK-enabled execution.
+    uint32_t xnack_enabled : 1;
+    // The host can directly access SVM pages resident in this GPU's local memory.
+    uint32_t direct_host_access : 1;
+  } svm;
+
+  // Device-local memory placement facts.
+  struct {
+    // A host-coherent fine-grained device memory pool is available.
+    uint32_t fine_host_visible : 1;
+    // A CPU-visible coarse-grained device memory pool is usable by the driver.
+    uint32_t coarse_cpu_visible : 1;
+  } device_local;
+} iree_hal_amdgpu_memory_system_capabilities_t;
+
+// Already-queried facts used to select memory-system capabilities.
+typedef struct iree_hal_amdgpu_memory_system_capabilities_selection_t {
+  // HSA SVM/HMM process and agent facts.
+  struct {
+    // Whether HSA SVM APIs are available in this process.
+    uint32_t supported : 1;
+    // Whether pageable/system memory is GPU-accessible without SVM attributes.
+    uint32_t accessible_by_default : 1;
+    // Whether the process is bound to XNACK-enabled execution.
+    uint32_t xnack_enabled : 1;
+    // Whether this GPU reports direct host access to resident SVM pages.
+    uint32_t direct_host_access : 1;
+  } svm;
+
+  // Device-local memory placement facts.
+  struct {
+    // Fine-grained global memory pool considered for host-visible device data.
+    hsa_amd_memory_pool_t fine_memory_pool;
+    // Selected CPU-visible coarse-grained device-memory capability.
+    const iree_hal_amdgpu_cpu_visible_device_coarse_memory_t*
+        coarse_cpu_visible_memory;
+  } device_local;
+} iree_hal_amdgpu_memory_system_capabilities_selection_t;
+
+// Selects memory-system capabilities from already-queried facts.
+void iree_hal_amdgpu_select_memory_system_capabilities(
+    const iree_hal_amdgpu_memory_system_capabilities_selection_t* selection,
+    iree_hal_amdgpu_memory_system_capabilities_t* out_capabilities);
+
+// Returns HAL device capability flags implied by AMDGPU memory-system facts.
+iree_hal_device_capability_bits_t
+iree_hal_amdgpu_select_memory_system_device_capability_flags(
+    const iree_hal_amdgpu_memory_system_capabilities_t* capabilities);
+
+// Returns true when SVM ranges require explicit HSA access attributes before a
+// GPU can safely access them.
+bool iree_hal_amdgpu_memory_system_requires_svm_access_attributes(
+    const iree_hal_amdgpu_memory_system_capabilities_t* capabilities);
+
+// Selects command-buffer prepublished kernarg storage from queried memory
+// pools.
+iree_hal_amdgpu_aql_prepublished_kernarg_storage_t
+iree_hal_amdgpu_select_prepublished_kernarg_storage(
+    hsa_amd_memory_pool_t fine_block_memory_pool);
+
+// Selects AMD vendor AQL packet and PM4 packet-family capabilities from the
+// parsed gfx IP version.
+iree_hal_amdgpu_vendor_packet_capability_flags_t
+iree_hal_amdgpu_select_vendor_packet_capabilities(
+    iree_hal_amdgpu_gfxip_version_t version);
+
+// Selects the cross-queue wait strategy from already-selected vendor packet
+// capabilities.
+iree_hal_amdgpu_wait_barrier_strategy_t
+iree_hal_amdgpu_select_wait_barrier_strategy(
+    iree_hal_amdgpu_vendor_packet_capability_flags_t
+        vendor_packet_capabilities);
+
+#ifdef __cplusplus
+}  // extern "C"
+#endif  // __cplusplus
+
+#endif  // IREE_HAL_DRIVERS_AMDGPU_PHYSICAL_DEVICE_CAPABILITIES_H_
diff --git a/runtime/src/iree/hal/drivers/amdgpu/physical_device_capabilities_test.cc b/runtime/src/iree/hal/drivers/amdgpu/physical_device_capabilities_test.cc
new file mode 100644
index 0000000..9d55ab7
--- /dev/null
+++ b/runtime/src/iree/hal/drivers/amdgpu/physical_device_capabilities_test.cc
@@ -0,0 +1,595 @@
+// Copyright 2026 The IREE Authors
+//
+// Licensed under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+
+#include "iree/hal/drivers/amdgpu/physical_device_capabilities.h"
+
+#include <array>
+
+#include "iree/testing/gtest.h"
+#include "iree/testing/status_matchers.h"
+
+namespace iree::hal::amdgpu {
+namespace {
+
+static hsa_agent_t Agent(uint64_t handle) {
+  hsa_agent_t agent = {};
+  agent.handle = handle;
+  return agent;
+}
+
+static hsa_amd_memory_pool_t MemoryPool(uint64_t handle) {
+  hsa_amd_memory_pool_t memory_pool = {};
+  memory_pool.handle = handle;
+  return memory_pool;
+}
+
+static hsa_amd_hdp_flush_t HdpFlush(uintptr_t mem_flush_control,
+                                    uintptr_t register_flush_control) {
+  hsa_amd_hdp_flush_t hdp_flush = {};
+  hdp_flush.HDP_MEM_FLUSH_CNTL = reinterpret_cast<uint32_t*>(mem_flush_control);
+  hdp_flush.HDP_REG_FLUSH_CNTL =
+      reinterpret_cast<uint32_t*>(register_flush_control);
+  return hdp_flush;
+}
+
+static iree_hal_amdgpu_gfxip_version_t GfxIp(uint16_t major, uint16_t minor,
+                                             uint16_t stepping) {
+  iree_hal_amdgpu_gfxip_version_t version = {};
+  version.major = major;
+  version.minor = minor;
+  version.stepping = stepping;
+  return version;
+}
+
+static hsa_amd_memory_pool_link_info_t LinkInfo(
+    hsa_amd_link_info_type_t link_type) {
+  hsa_amd_memory_pool_link_info_t link_info = {};
+  link_info.link_type = link_type;
+  link_info.atomic_support_32bit = true;
+  link_info.atomic_support_64bit = true;
+  link_info.coherent_support = true;
+  return link_info;
+}
+
+class PhysicalDeviceCapabilitiesTest : public ::testing::Test {
+ protected:
+  iree_hal_amdgpu_cpu_visible_device_coarse_memory_selection_t
+  MakeCoarseMemorySelection() {
+    iree_hal_amdgpu_cpu_visible_device_coarse_memory_selection_t selection = {};
+    selection.device_agent = Agent(10);
+    selection.memory_pool = MemoryPool(20);
+    selection.gfxip_version = GfxIp(11, 0, 0);
+    selection.cpu.agents = cpu_agents_.data();
+    selection.cpu.access = cpu_access_.data();
+    selection.cpu.count = cpu_agents_.size();
+    selection.hdp.registers = HdpFlush(0xCAFE, 0xBEEF);
+    selection.flags =
+        IREE_HAL_AMDGPU_CPU_VISIBLE_DEVICE_COARSE_MEMORY_SELECTION_FLAG_HOST_WRITE_PUBLICATION_SUPPORTED;
+    return selection;
+  }
+
+  iree_hal_amdgpu_memory_system_capabilities_selection_t
+  MakeMemorySystemSelection() {
+    iree_hal_amdgpu_memory_system_capabilities_selection_t selection = {};
+    selection.svm.supported = 1;
+    selection.svm.accessible_by_default = 0;
+    selection.svm.xnack_enabled = 0;
+    selection.svm.direct_host_access = 0;
+    selection.device_local.fine_memory_pool = MemoryPool(30);
+    selection.device_local.coarse_cpu_visible_memory = nullptr;
+    return selection;
+  }
+
+  iree_hal_amdgpu_physical_topology_edge_selection_t MakeTopologyEdgeSelection(
+      const hsa_amd_memory_pool_link_info_t* link_hops,
+      iree_host_size_t link_hop_count) {
+    iree_hal_amdgpu_physical_topology_edge_selection_t selection = {};
+    selection.memory_access.coarse =
+        HSA_AMD_MEMORY_POOL_ACCESS_ALLOWED_BY_DEFAULT;
+    selection.memory_access.fine =
+        HSA_AMD_MEMORY_POOL_ACCESS_ALLOWED_BY_DEFAULT;
+    selection.link.hops = link_hops;
+    selection.link.count = link_hop_count;
+    return selection;
+  }
+
+  std::array<hsa_agent_t, 2> cpu_agents_ = {Agent(1), Agent(2)};
+  std::array<hsa_amd_memory_pool_access_t, 2> cpu_access_ = {
+      HSA_AMD_MEMORY_POOL_ACCESS_ALLOWED_BY_DEFAULT,
+      HSA_AMD_MEMORY_POOL_ACCESS_DISALLOWED_BY_DEFAULT};
+};
+
+TEST_F(PhysicalDeviceCapabilitiesTest, SelectsAvailableCoarseMemory) {
+  iree_hal_amdgpu_cpu_visible_device_coarse_memory_selection_t selection =
+      MakeCoarseMemorySelection();
+  iree_hal_amdgpu_cpu_visible_device_coarse_memory_t capability;
+  IREE_ASSERT_OK(iree_hal_amdgpu_select_cpu_visible_device_coarse_memory(
+      &selection, &capability));
+
+  EXPECT_TRUE(iree_hal_amdgpu_cpu_visible_device_coarse_memory_is_available(
+      &capability));
+  EXPECT_EQ(capability.memory_pool.handle, selection.memory_pool.handle);
+  ASSERT_EQ(capability.access_agent_count, 3u);
+  EXPECT_EQ(capability.access_agents[0].handle, cpu_agents_[0].handle);
+  EXPECT_EQ(capability.access_agents[1].handle, cpu_agents_[1].handle);
+  EXPECT_EQ(capability.access_agents[2].handle, selection.device_agent.handle);
+  EXPECT_EQ(capability.host_write_publication.mode,
+            IREE_HAL_AMDGPU_KERNARG_RING_PUBLICATION_MODE_HDP_FLUSH);
+  EXPECT_EQ(capability.host_write_publication.hdp_mem_flush_control,
+            selection.hdp.registers.HDP_MEM_FLUSH_CNTL);
+  EXPECT_TRUE(iree_all_bits_set(
+      capability.flags,
+      IREE_HAL_AMDGPU_CPU_VISIBLE_DEVICE_COARSE_MEMORY_FLAG_AVAILABLE |
+          IREE_HAL_AMDGPU_CPU_VISIBLE_DEVICE_COARSE_MEMORY_FLAG_HDP_FLUSH));
+}
+
+TEST_F(PhysicalDeviceCapabilitiesTest, EmptyInputsDisableCoarseMemory) {
+  iree_hal_amdgpu_cpu_visible_device_coarse_memory_selection_t selection =
+      MakeCoarseMemorySelection();
+  iree_hal_amdgpu_cpu_visible_device_coarse_memory_t capability;
+
+  selection.memory_pool = MemoryPool(0);
+  IREE_ASSERT_OK(iree_hal_amdgpu_select_cpu_visible_device_coarse_memory(
+      &selection, &capability));
+  EXPECT_FALSE(iree_hal_amdgpu_cpu_visible_device_coarse_memory_is_available(
+      &capability));
+
+  selection = MakeCoarseMemorySelection();
+  selection.cpu.count = 0;
+  IREE_ASSERT_OK(iree_hal_amdgpu_select_cpu_visible_device_coarse_memory(
+      &selection, &capability));
+  EXPECT_FALSE(iree_hal_amdgpu_cpu_visible_device_coarse_memory_is_available(
+      &capability));
+}
+
+TEST_F(PhysicalDeviceCapabilitiesTest, PublicationGatesDisableCoarseMemory) {
+  iree_hal_amdgpu_cpu_visible_device_coarse_memory_selection_t selection =
+      MakeCoarseMemorySelection();
+  iree_hal_amdgpu_cpu_visible_device_coarse_memory_t capability;
+
+  selection.flags =
+      IREE_HAL_AMDGPU_CPU_VISIBLE_DEVICE_COARSE_MEMORY_SELECTION_FLAG_NONE;
+  IREE_ASSERT_OK(iree_hal_amdgpu_select_cpu_visible_device_coarse_memory(
+      &selection, &capability));
+  EXPECT_FALSE(iree_hal_amdgpu_cpu_visible_device_coarse_memory_is_available(
+      &capability));
+
+  selection = MakeCoarseMemorySelection();
+  selection.hdp.registers.HDP_MEM_FLUSH_CNTL = 0;
+  IREE_ASSERT_OK(iree_hal_amdgpu_select_cpu_visible_device_coarse_memory(
+      &selection, &capability));
+  EXPECT_FALSE(iree_hal_amdgpu_cpu_visible_device_coarse_memory_is_available(
+      &capability));
+
+  selection = MakeCoarseMemorySelection();
+  selection.hdp.registers.HDP_REG_FLUSH_CNTL = 0;
+  IREE_ASSERT_OK(iree_hal_amdgpu_select_cpu_visible_device_coarse_memory(
+      &selection, &capability));
+  EXPECT_FALSE(iree_hal_amdgpu_cpu_visible_device_coarse_memory_is_available(
+      &capability));
+}
+
+TEST_F(PhysicalDeviceCapabilitiesTest, GfxIpGatesHdpPublication) {
+  EXPECT_FALSE(
+      iree_hal_amdgpu_gfxip_allows_hdp_kernarg_publication(GfxIp(9, 0, 7)));
+  EXPECT_TRUE(
+      iree_hal_amdgpu_gfxip_allows_hdp_kernarg_publication(GfxIp(9, 0, 8)));
+  EXPECT_FALSE(
+      iree_hal_amdgpu_gfxip_allows_hdp_kernarg_publication(GfxIp(10, 0, 0)));
+  EXPECT_FALSE(
+      iree_hal_amdgpu_gfxip_allows_hdp_kernarg_publication(GfxIp(10, 1, 0)));
+  EXPECT_TRUE(
+      iree_hal_amdgpu_gfxip_allows_hdp_kernarg_publication(GfxIp(10, 3, 0)));
+  EXPECT_TRUE(
+      iree_hal_amdgpu_gfxip_allows_hdp_kernarg_publication(GfxIp(11, 0, 0)));
+}
+
+TEST_F(PhysicalDeviceCapabilitiesTest, UnsupportedGfxIpDisablesCoarseMemory) {
+  iree_hal_amdgpu_cpu_visible_device_coarse_memory_selection_t selection =
+      MakeCoarseMemorySelection();
+  selection.gfxip_version = GfxIp(10, 1, 0);
+  iree_hal_amdgpu_cpu_visible_device_coarse_memory_t capability;
+  IREE_ASSERT_OK(iree_hal_amdgpu_select_cpu_visible_device_coarse_memory(
+      &selection, &capability));
+  EXPECT_FALSE(iree_hal_amdgpu_cpu_visible_device_coarse_memory_is_available(
+      &capability));
+}
+
+TEST_F(PhysicalDeviceCapabilitiesTest, CpuAccessGatesCoarseMemory) {
+  iree_hal_amdgpu_cpu_visible_device_coarse_memory_selection_t selection =
+      MakeCoarseMemorySelection();
+  iree_hal_amdgpu_cpu_visible_device_coarse_memory_t capability;
+
+  cpu_access_[1] = HSA_AMD_MEMORY_POOL_ACCESS_NEVER_ALLOWED;
+  IREE_ASSERT_OK(iree_hal_amdgpu_select_cpu_visible_device_coarse_memory(
+      &selection, &capability));
+  EXPECT_FALSE(iree_hal_amdgpu_cpu_visible_device_coarse_memory_is_available(
+      &capability));
+
+  cpu_access_[1] = (hsa_amd_memory_pool_access_t)99;
+  IREE_EXPECT_STATUS_IS(IREE_STATUS_OUT_OF_RANGE,
+                        iree_hal_amdgpu_select_cpu_visible_device_coarse_memory(
+                            &selection, &capability));
+}
+
+TEST_F(PhysicalDeviceCapabilitiesTest,
+       MemoryPoolAccessMapsToSafeTopologyModes) {
+  EXPECT_TRUE(iree_hal_amdgpu_memory_pool_access_is_valid(
+      HSA_AMD_MEMORY_POOL_ACCESS_NEVER_ALLOWED));
+  EXPECT_EQ(iree_hal_amdgpu_memory_pool_access_topology_mode(
+                HSA_AMD_MEMORY_POOL_ACCESS_NEVER_ALLOWED),
+            IREE_HAL_TOPOLOGY_INTEROP_MODE_COPY);
+  EXPECT_EQ(iree_hal_amdgpu_memory_pool_access_topology_capabilities(
+                HSA_AMD_MEMORY_POOL_ACCESS_NEVER_ALLOWED),
+            IREE_HAL_TOPOLOGY_CAPABILITY_NONE);
+
+  EXPECT_TRUE(iree_hal_amdgpu_memory_pool_access_is_valid(
+      HSA_AMD_MEMORY_POOL_ACCESS_ALLOWED_BY_DEFAULT));
+  EXPECT_EQ(iree_hal_amdgpu_memory_pool_access_topology_mode(
+                HSA_AMD_MEMORY_POOL_ACCESS_ALLOWED_BY_DEFAULT),
+            IREE_HAL_TOPOLOGY_INTEROP_MODE_NATIVE);
+  EXPECT_EQ(iree_hal_amdgpu_memory_pool_access_topology_capabilities(
+                HSA_AMD_MEMORY_POOL_ACCESS_ALLOWED_BY_DEFAULT),
+            IREE_HAL_TOPOLOGY_CAPABILITY_NONE);
+
+  EXPECT_TRUE(iree_hal_amdgpu_memory_pool_access_is_valid(
+      HSA_AMD_MEMORY_POOL_ACCESS_DISALLOWED_BY_DEFAULT));
+  EXPECT_EQ(iree_hal_amdgpu_memory_pool_access_topology_mode(
+                HSA_AMD_MEMORY_POOL_ACCESS_DISALLOWED_BY_DEFAULT),
+            IREE_HAL_TOPOLOGY_INTEROP_MODE_COPY);
+  EXPECT_EQ(iree_hal_amdgpu_memory_pool_access_topology_capabilities(
+                HSA_AMD_MEMORY_POOL_ACCESS_DISALLOWED_BY_DEFAULT),
+            IREE_HAL_TOPOLOGY_CAPABILITY_PEER_ACCESS_REQUIRES_GRANT);
+
+  EXPECT_FALSE(iree_hal_amdgpu_memory_pool_access_is_valid(
+      (hsa_amd_memory_pool_access_t)99));
+}
+
+TEST_F(PhysicalDeviceCapabilitiesTest, SelectsXgmiPhysicalTopologyEdge) {
+  std::array<hsa_amd_memory_pool_link_info_t, 1> link_hops = {
+      LinkInfo(HSA_AMD_LINK_INFO_TYPE_XGMI)};
+  link_hops[0].numa_distance = 16;
+
+  iree_hal_amdgpu_physical_topology_edge_selection_t selection =
+      MakeTopologyEdgeSelection(link_hops.data(), link_hops.size());
+  iree_hal_amdgpu_physical_topology_edge_t edge;
+  IREE_ASSERT_OK(
+      iree_hal_amdgpu_select_physical_topology_edge(&selection, &edge));
+
+  EXPECT_EQ(edge.memory_access.coarse,
+            HSA_AMD_MEMORY_POOL_ACCESS_ALLOWED_BY_DEFAULT);
+  EXPECT_EQ(edge.memory_access.fine,
+            HSA_AMD_MEMORY_POOL_ACCESS_ALLOWED_BY_DEFAULT);
+  EXPECT_TRUE(edge.memory_access.coarse_accessible);
+  EXPECT_TRUE(edge.memory_access.fine_accessible);
+  EXPECT_TRUE(edge.coherency.all_hops_coherent);
+  EXPECT_TRUE(edge.atomics.all_hops_32bit);
+  EXPECT_TRUE(edge.atomics.all_hops_64bit);
+  EXPECT_TRUE(iree_any_bit_set(
+      edge.link.flags, IREE_HAL_AMDGPU_PHYSICAL_TOPOLOGY_LINK_FLAG_XGMI));
+  EXPECT_EQ(edge.link.link_class, IREE_HAL_TOPOLOGY_LINK_CLASS_NVLINK_IF);
+  EXPECT_EQ(edge.link.copy_cost, 3);
+  EXPECT_EQ(edge.link.latency_class, 3);
+  EXPECT_EQ(edge.link.numa_distance, 3);
+  EXPECT_TRUE(
+      iree_all_bits_set(edge.capabilities.guaranteed,
+                        IREE_HAL_TOPOLOGY_CAPABILITY_P2P_COPY |
+                            IREE_HAL_TOPOLOGY_CAPABILITY_PEER_COHERENT |
+                            IREE_HAL_TOPOLOGY_CAPABILITY_ATOMIC_DEVICE |
+                            IREE_HAL_TOPOLOGY_CAPABILITY_ATOMIC_SYSTEM));
+  EXPECT_EQ(edge.capabilities.required, IREE_HAL_TOPOLOGY_CAPABILITY_NONE);
+  EXPECT_EQ(edge.modes.noncoherent_read, IREE_HAL_TOPOLOGY_INTEROP_MODE_NATIVE);
+  EXPECT_EQ(edge.modes.coherent_read, IREE_HAL_TOPOLOGY_INTEROP_MODE_NATIVE);
+}
+
+TEST_F(PhysicalDeviceCapabilitiesTest,
+       SelectsWorstMultiHopPhysicalTopologyEdge) {
+  std::array<hsa_amd_memory_pool_link_info_t, 2> link_hops = {
+      LinkInfo(HSA_AMD_LINK_INFO_TYPE_XGMI),
+      LinkInfo(HSA_AMD_LINK_INFO_TYPE_HYPERTRANSPORT)};
+  link_hops[0].numa_distance = 12;
+  link_hops[1].numa_distance = 28;
+  link_hops[1].atomic_support_32bit = false;
+  link_hops[1].coherent_support = false;
+
+  iree_hal_amdgpu_physical_topology_edge_selection_t selection =
+      MakeTopologyEdgeSelection(link_hops.data(), link_hops.size());
+  iree_hal_amdgpu_physical_topology_edge_t edge;
+  IREE_ASSERT_OK(
+      iree_hal_amdgpu_select_physical_topology_edge(&selection, &edge));
+
+  EXPECT_FALSE(edge.coherency.all_hops_coherent);
+  EXPECT_FALSE(edge.atomics.all_hops_32bit);
+  EXPECT_TRUE(edge.atomics.all_hops_64bit);
+  EXPECT_TRUE(iree_all_bits_set(
+      edge.link.flags,
+      IREE_HAL_AMDGPU_PHYSICAL_TOPOLOGY_LINK_FLAG_XGMI |
+          IREE_HAL_AMDGPU_PHYSICAL_TOPOLOGY_LINK_FLAG_HYPERTRANSPORT));
+  EXPECT_EQ(edge.link.link_class, IREE_HAL_TOPOLOGY_LINK_CLASS_PCIE_CROSS_ROOT);
+  EXPECT_EQ(edge.link.copy_cost, 9);
+  EXPECT_EQ(edge.link.latency_class, 9);
+  EXPECT_EQ(edge.link.numa_distance, 9);
+}
+
+TEST_F(PhysicalDeviceCapabilitiesTest,
+       SelectsPciePhysicalTopologyEdgeWithoutSystemAtomics) {
+  std::array<hsa_amd_memory_pool_link_info_t, 1> link_hops = {
+      LinkInfo(HSA_AMD_LINK_INFO_TYPE_PCIE)};
+  link_hops[0].atomic_support_64bit = false;
+  link_hops[0].coherent_support = false;
+
+  iree_hal_amdgpu_physical_topology_edge_selection_t selection =
+      MakeTopologyEdgeSelection(link_hops.data(), link_hops.size());
+  selection.memory_access.fine = HSA_AMD_MEMORY_POOL_ACCESS_NEVER_ALLOWED;
+  iree_hal_amdgpu_physical_topology_edge_t edge;
+  IREE_ASSERT_OK(
+      iree_hal_amdgpu_select_physical_topology_edge(&selection, &edge));
+
+  EXPECT_TRUE(edge.memory_access.coarse_accessible);
+  EXPECT_FALSE(edge.memory_access.fine_accessible);
+  EXPECT_FALSE(edge.coherency.all_hops_coherent);
+  EXPECT_TRUE(edge.atomics.all_hops_32bit);
+  EXPECT_FALSE(edge.atomics.all_hops_64bit);
+  EXPECT_TRUE(iree_any_bit_set(
+      edge.link.flags, IREE_HAL_AMDGPU_PHYSICAL_TOPOLOGY_LINK_FLAG_PCIE));
+  EXPECT_EQ(edge.link.link_class, IREE_HAL_TOPOLOGY_LINK_CLASS_PCIE_SAME_ROOT);
+  EXPECT_EQ(edge.link.copy_cost, 7);
+  EXPECT_EQ(edge.link.latency_class, 7);
+  EXPECT_TRUE(iree_any_bit_set(edge.capabilities.guaranteed,
+                               IREE_HAL_TOPOLOGY_CAPABILITY_P2P_COPY));
+  EXPECT_FALSE(iree_any_bit_set(edge.capabilities.guaranteed,
+                                IREE_HAL_TOPOLOGY_CAPABILITY_PEER_COHERENT));
+  EXPECT_TRUE(iree_any_bit_set(edge.capabilities.guaranteed,
+                               IREE_HAL_TOPOLOGY_CAPABILITY_ATOMIC_DEVICE));
+  EXPECT_FALSE(iree_any_bit_set(edge.capabilities.guaranteed,
+                                IREE_HAL_TOPOLOGY_CAPABILITY_ATOMIC_SYSTEM));
+  EXPECT_EQ(edge.modes.noncoherent_read, IREE_HAL_TOPOLOGY_INTEROP_MODE_NATIVE);
+  EXPECT_EQ(edge.modes.coherent_read, IREE_HAL_TOPOLOGY_INTEROP_MODE_COPY);
+}
+
+TEST_F(PhysicalDeviceCapabilitiesTest,
+       GrantablePhysicalTopologyEdgeRequiresGrant) {
+  std::array<hsa_amd_memory_pool_link_info_t, 1> link_hops = {
+      LinkInfo(HSA_AMD_LINK_INFO_TYPE_PCIE)};
+  iree_hal_amdgpu_physical_topology_edge_selection_t selection =
+      MakeTopologyEdgeSelection(link_hops.data(), link_hops.size());
+  selection.memory_access.coarse =
+      HSA_AMD_MEMORY_POOL_ACCESS_DISALLOWED_BY_DEFAULT;
+  selection.memory_access.fine =
+      HSA_AMD_MEMORY_POOL_ACCESS_DISALLOWED_BY_DEFAULT;
+
+  iree_hal_amdgpu_physical_topology_edge_t edge;
+  IREE_ASSERT_OK(
+      iree_hal_amdgpu_select_physical_topology_edge(&selection, &edge));
+
+  EXPECT_TRUE(edge.memory_access.coarse_accessible);
+  EXPECT_TRUE(edge.memory_access.fine_accessible);
+  EXPECT_TRUE(iree_any_bit_set(edge.capabilities.guaranteed,
+                               IREE_HAL_TOPOLOGY_CAPABILITY_P2P_COPY));
+  EXPECT_TRUE(iree_any_bit_set(
+      edge.capabilities.required,
+      IREE_HAL_TOPOLOGY_CAPABILITY_PEER_ACCESS_REQUIRES_GRANT));
+  EXPECT_EQ(edge.modes.noncoherent_read, IREE_HAL_TOPOLOGY_INTEROP_MODE_COPY);
+  EXPECT_EQ(edge.modes.coherent_read, IREE_HAL_TOPOLOGY_INTEROP_MODE_COPY);
+}
+
+TEST_F(PhysicalDeviceCapabilitiesTest,
+       NeverAllowedPhysicalTopologyEdgeIsHostStaged) {
+  std::array<hsa_amd_memory_pool_link_info_t, 1> link_hops = {
+      LinkInfo(HSA_AMD_LINK_INFO_TYPE_XGMI)};
+  iree_hal_amdgpu_physical_topology_edge_selection_t selection =
+      MakeTopologyEdgeSelection(link_hops.data(), link_hops.size());
+  selection.memory_access.coarse = HSA_AMD_MEMORY_POOL_ACCESS_NEVER_ALLOWED;
+  selection.memory_access.fine = HSA_AMD_MEMORY_POOL_ACCESS_NEVER_ALLOWED;
+
+  iree_hal_amdgpu_physical_topology_edge_t edge;
+  IREE_ASSERT_OK(
+      iree_hal_amdgpu_select_physical_topology_edge(&selection, &edge));
+
+  EXPECT_FALSE(edge.memory_access.coarse_accessible);
+  EXPECT_FALSE(edge.memory_access.fine_accessible);
+  EXPECT_FALSE(edge.coherency.all_hops_coherent);
+  EXPECT_FALSE(edge.atomics.all_hops_32bit);
+  EXPECT_FALSE(edge.atomics.all_hops_64bit);
+  EXPECT_EQ(edge.link.link_class, IREE_HAL_TOPOLOGY_LINK_CLASS_HOST_STAGED);
+  EXPECT_EQ(edge.link.copy_cost, 13);
+  EXPECT_EQ(edge.link.latency_class, 11);
+  EXPECT_EQ(edge.capabilities.guaranteed, IREE_HAL_TOPOLOGY_CAPABILITY_NONE);
+}
+
+TEST_F(PhysicalDeviceCapabilitiesTest,
+       InvalidPhysicalTopologyEdgeInputsFailLoud) {
+  iree_hal_amdgpu_physical_topology_edge_selection_t selection =
+      MakeTopologyEdgeSelection(nullptr, 1);
+  iree_hal_amdgpu_physical_topology_edge_t edge;
+  IREE_EXPECT_STATUS_IS(
+      IREE_STATUS_INVALID_ARGUMENT,
+      iree_hal_amdgpu_select_physical_topology_edge(&selection, &edge));
+
+  selection = MakeTopologyEdgeSelection(nullptr, 0);
+  selection.memory_access.coarse = (hsa_amd_memory_pool_access_t)99;
+  IREE_EXPECT_STATUS_IS(
+      IREE_STATUS_OUT_OF_RANGE,
+      iree_hal_amdgpu_select_physical_topology_edge(&selection, &edge));
+}
+
+TEST_F(PhysicalDeviceCapabilitiesTest, CpuAccessInputsAreRequiredWhenNeeded) {
+  iree_hal_amdgpu_cpu_visible_device_coarse_memory_selection_t selection =
+      MakeCoarseMemorySelection();
+  selection.cpu.agents = nullptr;
+  iree_hal_amdgpu_cpu_visible_device_coarse_memory_t capability;
+  IREE_EXPECT_STATUS_IS(IREE_STATUS_INVALID_ARGUMENT,
+                        iree_hal_amdgpu_select_cpu_visible_device_coarse_memory(
+                            &selection, &capability));
+
+  selection = MakeCoarseMemorySelection();
+  selection.cpu.access = nullptr;
+  IREE_EXPECT_STATUS_IS(IREE_STATUS_INVALID_ARGUMENT,
+                        iree_hal_amdgpu_select_cpu_visible_device_coarse_memory(
+                            &selection, &capability));
+}
+
+TEST_F(PhysicalDeviceCapabilitiesTest, TooManyCpuAgentsFails) {
+  iree_hal_amdgpu_cpu_visible_device_coarse_memory_selection_t selection =
+      MakeCoarseMemorySelection();
+  selection.cpu.count = IREE_HAL_AMDGPU_MAX_CPU_AGENT + 1;
+  iree_hal_amdgpu_cpu_visible_device_coarse_memory_t capability;
+  IREE_EXPECT_STATUS_IS(IREE_STATUS_OUT_OF_RANGE,
+                        iree_hal_amdgpu_select_cpu_visible_device_coarse_memory(
+                            &selection, &capability));
+}
+
+TEST_F(PhysicalDeviceCapabilitiesTest, SvmDefaultAccessDoesNotImplyPeerFlags) {
+  iree_hal_amdgpu_memory_system_capabilities_selection_t selection =
+      MakeMemorySystemSelection();
+  selection.svm.accessible_by_default = 1;
+  selection.svm.xnack_enabled = 1;
+
+  iree_hal_amdgpu_memory_system_capabilities_t capability;
+  iree_hal_amdgpu_select_memory_system_capabilities(&selection, &capability);
+
+  EXPECT_TRUE(capability.svm.supported);
+  EXPECT_TRUE(capability.svm.accessible_by_default);
+  EXPECT_TRUE(capability.svm.xnack_enabled);
+  EXPECT_FALSE(capability.svm.direct_host_access);
+  EXPECT_TRUE(capability.device_local.fine_host_visible);
+  EXPECT_FALSE(capability.device_local.coarse_cpu_visible);
+
+  iree_hal_device_capability_bits_t flags =
+      iree_hal_amdgpu_select_memory_system_device_capability_flags(&capability);
+  EXPECT_TRUE(flags & IREE_HAL_DEVICE_CAPABILITY_SHARED_VIRTUAL_ADDRESS);
+  EXPECT_TRUE(flags & IREE_HAL_DEVICE_CAPABILITY_UNIFIED_MEMORY);
+  EXPECT_FALSE(flags & IREE_HAL_DEVICE_CAPABILITY_PEER_ADDRESSABLE);
+  EXPECT_FALSE(flags & IREE_HAL_DEVICE_CAPABILITY_PEER_COHERENT);
+  EXPECT_FALSE(iree_hal_amdgpu_memory_system_requires_svm_access_attributes(
+      &capability));
+}
+
+TEST_F(PhysicalDeviceCapabilitiesTest,
+       LargeBarDoesNotImplyPageableSvmDefaultAccess) {
+  iree_hal_amdgpu_cpu_visible_device_coarse_memory_t coarse_memory = {};
+  coarse_memory.memory_pool = MemoryPool(40);
+  coarse_memory.flags =
+      IREE_HAL_AMDGPU_CPU_VISIBLE_DEVICE_COARSE_MEMORY_FLAG_AVAILABLE |
+      IREE_HAL_AMDGPU_CPU_VISIBLE_DEVICE_COARSE_MEMORY_FLAG_HDP_FLUSH;
+
+  iree_hal_amdgpu_memory_system_capabilities_selection_t selection =
+      MakeMemorySystemSelection();
+  selection.svm.direct_host_access = 1;
+  selection.device_local.coarse_cpu_visible_memory = &coarse_memory;
+
+  iree_hal_amdgpu_memory_system_capabilities_t capability;
+  iree_hal_amdgpu_select_memory_system_capabilities(&selection, &capability);
+
+  EXPECT_TRUE(capability.svm.supported);
+  EXPECT_FALSE(capability.svm.accessible_by_default);
+  EXPECT_FALSE(capability.svm.xnack_enabled);
+  EXPECT_TRUE(capability.svm.direct_host_access);
+  EXPECT_TRUE(capability.device_local.fine_host_visible);
+  EXPECT_TRUE(capability.device_local.coarse_cpu_visible);
+
+  iree_hal_device_capability_bits_t flags =
+      iree_hal_amdgpu_select_memory_system_device_capability_flags(&capability);
+  EXPECT_TRUE(flags & IREE_HAL_DEVICE_CAPABILITY_SHARED_VIRTUAL_ADDRESS);
+  EXPECT_FALSE(flags & IREE_HAL_DEVICE_CAPABILITY_UNIFIED_MEMORY);
+  EXPECT_FALSE(flags & IREE_HAL_DEVICE_CAPABILITY_PEER_ADDRESSABLE);
+  EXPECT_FALSE(flags & IREE_HAL_DEVICE_CAPABILITY_PEER_COHERENT);
+  EXPECT_TRUE(iree_hal_amdgpu_memory_system_requires_svm_access_attributes(
+      &capability));
+}
+
+TEST_F(PhysicalDeviceCapabilitiesTest, SelectsPrepublishedKernargStorage) {
+  iree_hal_amdgpu_aql_prepublished_kernarg_storage_t storage =
+      iree_hal_amdgpu_select_prepublished_kernarg_storage(MemoryPool(0));
+  EXPECT_EQ(storage.strategy,
+            IREE_HAL_AMDGPU_AQL_PREPUBLISHED_KERNARG_STORAGE_STRATEGY_DISABLED);
+
+  storage = iree_hal_amdgpu_select_prepublished_kernarg_storage(MemoryPool(42));
+  EXPECT_EQ(
+      storage.strategy,
+      IREE_HAL_AMDGPU_AQL_PREPUBLISHED_KERNARG_STORAGE_STRATEGY_DEVICE_FINE_HOST_COHERENT);
+  EXPECT_TRUE(iree_all_bits_set(storage.buffer_params.type,
+                                IREE_HAL_MEMORY_TYPE_DEVICE_LOCAL |
+                                    IREE_HAL_MEMORY_TYPE_HOST_VISIBLE |
+                                    IREE_HAL_MEMORY_TYPE_HOST_COHERENT));
+}
+
+TEST_F(PhysicalDeviceCapabilitiesTest, SelectsCdnaBarrierValueCapabilities) {
+  iree_hal_amdgpu_vendor_packet_capability_flags_t capabilities =
+      iree_hal_amdgpu_select_vendor_packet_capabilities(GfxIp(9, 0, 10));
+  EXPECT_TRUE(iree_all_bits_set(
+      capabilities,
+      IREE_HAL_AMDGPU_VENDOR_PACKET_CAPABILITY_AQL_PM4_IB |
+          IREE_HAL_AMDGPU_VENDOR_PACKET_CAPABILITY_AQL_BARRIER_VALUE));
+  EXPECT_FALSE(iree_any_bit_set(
+      capabilities, IREE_HAL_AMDGPU_VENDOR_PACKET_CAPABILITY_WAIT_REG_MEM64));
+
+  capabilities =
+      iree_hal_amdgpu_select_vendor_packet_capabilities(GfxIp(9, 4, 2));
+  EXPECT_TRUE(iree_all_bits_set(
+      capabilities,
+      IREE_HAL_AMDGPU_VENDOR_PACKET_CAPABILITY_AQL_PM4_IB |
+          IREE_HAL_AMDGPU_VENDOR_PACKET_CAPABILITY_AQL_BARRIER_VALUE));
+
+  capabilities =
+      iree_hal_amdgpu_select_vendor_packet_capabilities(GfxIp(9, 5, 2));
+  EXPECT_TRUE(iree_all_bits_set(
+      capabilities,
+      IREE_HAL_AMDGPU_VENDOR_PACKET_CAPABILITY_AQL_PM4_IB |
+          IREE_HAL_AMDGPU_VENDOR_PACKET_CAPABILITY_AQL_BARRIER_VALUE));
+
+  capabilities =
+      iree_hal_amdgpu_select_vendor_packet_capabilities(GfxIp(9, 4, 3));
+  EXPECT_EQ(capabilities, 0u);
+
+  capabilities =
+      iree_hal_amdgpu_select_vendor_packet_capabilities(GfxIp(9, 5, 3));
+  EXPECT_EQ(capabilities, 0u);
+}
+
+TEST_F(PhysicalDeviceCapabilitiesTest, SelectsValidatedGfx1100Capabilities) {
+  iree_hal_amdgpu_vendor_packet_capability_flags_t capabilities =
+      iree_hal_amdgpu_select_vendor_packet_capabilities(GfxIp(11, 0, 0));
+  EXPECT_TRUE(iree_all_bits_set(
+      capabilities,
+      IREE_HAL_AMDGPU_VENDOR_PACKET_CAPABILITY_AQL_PM4_IB |
+          IREE_HAL_AMDGPU_VENDOR_PACKET_CAPABILITY_WAIT_REG_MEM64 |
+          IREE_HAL_AMDGPU_VENDOR_PACKET_CAPABILITY_PM4_WRITE_DATA_MEMORY |
+          IREE_HAL_AMDGPU_VENDOR_PACKET_CAPABILITY_PM4_COPY_DATA_MEMORY |
+          IREE_HAL_AMDGPU_VENDOR_PACKET_CAPABILITY_PM4_COPY_TIMESTAMP |
+          IREE_HAL_AMDGPU_VENDOR_PACKET_CAPABILITY_PM4_RELEASE_MEM_TIMESTAMP |
+          IREE_HAL_AMDGPU_VENDOR_PACKET_CAPABILITY_PM4_EVENT_WRITE |
+          IREE_HAL_AMDGPU_VENDOR_PACKET_CAPABILITY_PM4_SET_SH_REG |
+          IREE_HAL_AMDGPU_VENDOR_PACKET_CAPABILITY_PM4_SET_UCONFIG_REG |
+          IREE_HAL_AMDGPU_VENDOR_PACKET_CAPABILITY_PM4_REGISTER_READBACK |
+          IREE_HAL_AMDGPU_VENDOR_PACKET_CAPABILITY_PM4_PERFCOUNTER_READBACK |
+          IREE_HAL_AMDGPU_VENDOR_PACKET_CAPABILITY_PM4_IMMEDIATE_WRITE));
+}
+
+TEST_F(PhysicalDeviceCapabilitiesTest,
+       LeavesUnvalidatedGfxFamiliesOnBaseAqlPath) {
+  EXPECT_EQ(iree_hal_amdgpu_select_vendor_packet_capabilities(GfxIp(10, 3, 0)),
+            0u);
+  EXPECT_EQ(iree_hal_amdgpu_select_vendor_packet_capabilities(GfxIp(11, 0, 1)),
+            0u);
+  EXPECT_EQ(iree_hal_amdgpu_select_vendor_packet_capabilities(GfxIp(12, 0, 0)),
+            0u);
+}
+
+TEST_F(PhysicalDeviceCapabilitiesTest, SelectsWaitBarrierStrategy) {
+  EXPECT_EQ(iree_hal_amdgpu_select_wait_barrier_strategy(
+                IREE_HAL_AMDGPU_VENDOR_PACKET_CAPABILITY_AQL_BARRIER_VALUE |
+                IREE_HAL_AMDGPU_VENDOR_PACKET_CAPABILITY_WAIT_REG_MEM64),
+            IREE_HAL_AMDGPU_WAIT_BARRIER_STRATEGY_AQL_BARRIER_VALUE);
+  EXPECT_EQ(iree_hal_amdgpu_select_wait_barrier_strategy(
+                IREE_HAL_AMDGPU_VENDOR_PACKET_CAPABILITY_WAIT_REG_MEM64),
+            IREE_HAL_AMDGPU_WAIT_BARRIER_STRATEGY_PM4_WAIT_REG_MEM64);
+  EXPECT_EQ(iree_hal_amdgpu_select_wait_barrier_strategy(0),
+            IREE_HAL_AMDGPU_WAIT_BARRIER_STRATEGY_DEFER);
+}
+
+}  // namespace
+}  // namespace iree::hal::amdgpu
diff --git a/runtime/src/iree/hal/drivers/amdgpu/profile_aqlprofile.c b/runtime/src/iree/hal/drivers/amdgpu/profile_aqlprofile.c
new file mode 100644
index 0000000..f539303
--- /dev/null
+++ b/runtime/src/iree/hal/drivers/amdgpu/profile_aqlprofile.c
@@ -0,0 +1,183 @@
+// Copyright 2026 The IREE Authors
+//
+// Licensed under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+
+#include "iree/hal/drivers/amdgpu/profile_aqlprofile.h"
+
+#include <string.h>
+
+//===----------------------------------------------------------------------===//
+// Raw HSA helpers
+//===----------------------------------------------------------------------===//
+
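+// These helpers report raw hsa_status_t codes, as the aqlprofile callback
+// signatures require, instead of wrapped iree_status_t; each dispatches to
+// the statically linked HSA entry point or through the loaded API table.
+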
+static hsa_status_t iree_hal_amdgpu_profile_hsa_memory_pool_allocate(
+    const iree_hal_amdgpu_libhsa_t* libhsa, hsa_amd_memory_pool_t memory_pool,
+    size_t size, uint32_t flags, void** ptr) {
+#if IREE_HAL_AMDGPU_LIBHSA_STATIC
+  (void)libhsa;
+  return hsa_amd_memory_pool_allocate(memory_pool, size, flags, ptr);
+#else
+  return libhsa->hsa_amd_memory_pool_allocate(memory_pool, size, flags, ptr);
+#endif  // IREE_HAL_AMDGPU_LIBHSA_STATIC
+}
+
+static hsa_status_t iree_hal_amdgpu_profile_hsa_memory_pool_free(
+    const iree_hal_amdgpu_libhsa_t* libhsa, void* ptr) {
+#if IREE_HAL_AMDGPU_LIBHSA_STATIC
+  (void)libhsa;
+  return hsa_amd_memory_pool_free(ptr);
+#else
+  return libhsa->hsa_amd_memory_pool_free(ptr);
+#endif  // IREE_HAL_AMDGPU_LIBHSA_STATIC
+}
+
+static hsa_status_t iree_hal_amdgpu_profile_hsa_agents_allow_access(
+    const iree_hal_amdgpu_libhsa_t* libhsa, uint32_t num_agents,
+    const hsa_agent_t* agents, const uint32_t* flags, const void* ptr) {
+#if IREE_HAL_AMDGPU_LIBHSA_STATIC
+  (void)libhsa;
+  return hsa_amd_agents_allow_access(num_agents, agents, flags, ptr);
+#else
+  return libhsa->hsa_amd_agents_allow_access(num_agents, agents, flags, ptr);
+#endif  // IREE_HAL_AMDGPU_LIBHSA_STATIC
+}
+
+static hsa_status_t iree_hal_amdgpu_profile_hsa_memory_copy(
+    const iree_hal_amdgpu_libhsa_t* libhsa, void* target, const void* source,
+    size_t size) {
+#if IREE_HAL_AMDGPU_LIBHSA_STATIC
+  (void)libhsa;
+  return hsa_memory_copy(target, source, size);
+#else
+  return libhsa->hsa_memory_copy(target, source, size);
+#endif  // IREE_HAL_AMDGPU_LIBHSA_STATIC
+}
+
+//===----------------------------------------------------------------------===//
+// aqlprofile support
+//===----------------------------------------------------------------------===//
+
+iree_status_t iree_hal_amdgpu_profile_aqlprofile_register_agent(
+    const iree_hal_amdgpu_libhsa_t* libhsa,
+    const iree_hal_amdgpu_libaqlprofile_t* libaqlprofile,
+    hsa_agent_t device_agent,
+    iree_hal_amdgpu_aqlprofile_agent_handle_t* out_agent_handle) {
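+  // Gather the agent properties aqlprofile needs up front: gfxip name,
+  // XCC/shader-engine/CU topology, NUMA domain, and PCI location.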
+  char agent_name[64] = {0};
+  iree_hal_amdgpu_aqlprofile_agent_info_v1_t agent_info;
+  memset(&agent_info, 0, sizeof(agent_info));
+  IREE_RETURN_IF_ERROR(iree_hsa_agent_get_info(
+      IREE_LIBHSA(libhsa), device_agent, HSA_AGENT_INFO_NAME, agent_name));
+  IREE_RETURN_IF_ERROR(iree_hsa_agent_get_info(
+      IREE_LIBHSA(libhsa), device_agent,
+      (hsa_agent_info_t)HSA_AMD_AGENT_INFO_NUM_XCC, &agent_info.xcc_num));
+  IREE_RETURN_IF_ERROR(iree_hsa_agent_get_info(
+      IREE_LIBHSA(libhsa), device_agent,
+      (hsa_agent_info_t)HSA_AMD_AGENT_INFO_NUM_SHADER_ENGINES,
+      &agent_info.se_num));
+  IREE_RETURN_IF_ERROR(iree_hsa_agent_get_info(
+      IREE_LIBHSA(libhsa), device_agent,
+      (hsa_agent_info_t)HSA_AMD_AGENT_INFO_COMPUTE_UNIT_COUNT,
+      &agent_info.cu_num));
+  IREE_RETURN_IF_ERROR(iree_hsa_agent_get_info(
+      IREE_LIBHSA(libhsa), device_agent,
+      (hsa_agent_info_t)HSA_AMD_AGENT_INFO_NUM_SHADER_ARRAYS_PER_SE,
+      &agent_info.shader_arrays_per_se));
+  IREE_RETURN_IF_ERROR(iree_hsa_agent_get_info(
+      IREE_LIBHSA(libhsa), device_agent,
+      (hsa_agent_info_t)HSA_AMD_AGENT_INFO_DOMAIN, &agent_info.domain));
+  IREE_RETURN_IF_ERROR(iree_hsa_agent_get_info(
+      IREE_LIBHSA(libhsa), device_agent,
+      (hsa_agent_info_t)HSA_AMD_AGENT_INFO_BDFID, &agent_info.location_id));
+  agent_info.agent_gfxip = agent_name;
+  IREE_RETURN_IF_AQLPROFILE_ERROR(
+      libaqlprofile,
+      libaqlprofile->aqlprofile_register_agent_info(
+          out_agent_handle, &agent_info,
+          IREE_HAL_AMDGPU_AQLPROFILE_AGENT_VERSION_V1),
+      "registering AMDGPU profiling agent");
+  return iree_ok_status();
+}
+
+hsa_status_t iree_hal_amdgpu_profile_aqlprofile_memory_alloc(
+    void** ptr, uint64_t size,
+    iree_hal_amdgpu_aqlprofile_buffer_desc_flags_t flags, void* user_data) {
+  iree_hal_amdgpu_profile_aqlprofile_memory_context_t* context =
+      (iree_hal_amdgpu_profile_aqlprofile_memory_context_t*)user_data;
+  *ptr = NULL;
+  if (size == 0) return HSA_STATUS_SUCCESS;
+
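+  // Route the allocation by requested access: host-visible buffers come from
+  // the host coarse pool (or the uncached kernarg pool when hinted), while
+  // device-only trace output comes from the device coarse pool.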
+  hsa_amd_memory_pool_t memory_pool = {0};
+  const bool should_clear = flags.host_access;
+  const bool should_allow_device_access = flags.device_access;
+  const bool should_allocate_executable =
+      flags.host_access && flags.device_access &&
+      (flags.memory_hint ==
+           IREE_HAL_AMDGPU_AQLPROFILE_MEMORY_HINT_DEVICE_NONCOHERENT ||
+       flags.memory_hint == IREE_HAL_AMDGPU_AQLPROFILE_MEMORY_HINT_NONE);
+  if (flags.host_access) {
+    memory_pool = context->host_memory_pools->coarse_pool;
+    if (flags.memory_hint ==
+        IREE_HAL_AMDGPU_AQLPROFILE_MEMORY_HINT_DEVICE_UNCACHED) {
+      memory_pool = context->host_memory_pools->kernarg_pool;
+    }
+  } else if (flags.device_access) {
+    memory_pool = context->device_coarse_pool;
+  } else {
+    return HSA_STATUS_ERROR_INVALID_ARGUMENT;
+  }
+  if (!memory_pool.handle) return HSA_STATUS_ERROR_INVALID_ALLOCATION;
+
+  hsa_status_t status = iree_hal_amdgpu_profile_hsa_memory_pool_allocate(
+      context->libhsa, memory_pool, (size_t)size,
+      should_allocate_executable ? HSA_AMD_MEMORY_POOL_EXECUTABLE_FLAG : 0,
+      ptr);
+  if (status != HSA_STATUS_SUCCESS) return status;
+  if (should_clear) memset(*ptr, 0, (size_t)size);
+
+  if (should_allow_device_access) {
+    status = iree_hal_amdgpu_profile_hsa_agents_allow_access(
+        context->libhsa, /*num_agents=*/1, &context->device_agent,
+        /*flags=*/NULL, *ptr);
+    if (status != HSA_STATUS_SUCCESS) {
+      iree_hal_amdgpu_profile_hsa_memory_pool_free(context->libhsa, *ptr);
+      *ptr = NULL;
+    }
+  }
+  return status;
+}
+
+void iree_hal_amdgpu_profile_aqlprofile_memory_dealloc(void* ptr,
+                                                       void* user_data) {
+  if (!ptr) return;
+  iree_hal_amdgpu_profile_aqlprofile_memory_context_t* context =
+      (iree_hal_amdgpu_profile_aqlprofile_memory_context_t*)user_data;
+  iree_hal_amdgpu_profile_hsa_memory_pool_free(context->libhsa, ptr);
+}
+
+hsa_status_t iree_hal_amdgpu_profile_aqlprofile_memory_copy(void* target,
+                                                            const void* source,
+                                                            size_t size,
+                                                            void* user_data) {
+  if (size == 0) return HSA_STATUS_SUCCESS;
+  iree_hal_amdgpu_profile_aqlprofile_memory_context_t* context =
+      (iree_hal_amdgpu_profile_aqlprofile_memory_context_t*)user_data;
+  return iree_hal_amdgpu_profile_hsa_memory_copy(context->libhsa, target,
+                                                 source, size);
+}
+
+void iree_hal_amdgpu_profile_aqlprofile_emplace_pm4_ib_packet(
+    const iree_hsa_amd_aql_pm4_ib_packet_t* source_packet,
+    iree_hal_amdgpu_aql_packet_t* packet,
+    iree_hal_amdgpu_aql_packet_control_t packet_control,
+    iree_hsa_signal_t completion_signal, uint16_t* out_header,
+    uint16_t* out_setup) {
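+  // Copy everything except the leading header/setup dword: AQL packets become
+  // live as soon as the header is written, so the caller publishes the
+  // returned header/setup last, once all packet bodies are populated.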
+  memcpy((uint8_t*)&packet->pm4_ib + sizeof(uint32_t),
+         (const uint8_t*)source_packet + sizeof(uint32_t),
+         sizeof(*source_packet) - sizeof(uint32_t));
+  packet->pm4_ib.completion_signal = completion_signal;
+  *out_setup = IREE_HSA_AMD_AQL_FORMAT_PM4_IB;
+  *out_header = iree_hal_amdgpu_aql_make_header(
+      IREE_HSA_PACKET_TYPE_VENDOR_SPECIFIC, packet_control);
+}
diff --git a/runtime/src/iree/hal/drivers/amdgpu/profile_aqlprofile.h b/runtime/src/iree/hal/drivers/amdgpu/profile_aqlprofile.h
new file mode 100644
index 0000000..b84b948
--- /dev/null
+++ b/runtime/src/iree/hal/drivers/amdgpu/profile_aqlprofile.h
@@ -0,0 +1,69 @@
+// Copyright 2026 The IREE Authors
+//
+// Licensed under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+
+#ifndef IREE_HAL_DRIVERS_AMDGPU_PROFILE_AQLPROFILE_H_
+#define IREE_HAL_DRIVERS_AMDGPU_PROFILE_AQLPROFILE_H_
+
+#include "iree/hal/drivers/amdgpu/system.h"
+#include "iree/hal/drivers/amdgpu/util/aql_emitter.h"
+#include "iree/hal/drivers/amdgpu/util/aql_ring.h"
+#include "iree/hal/drivers/amdgpu/util/libaqlprofile.h"
+
+#ifdef __cplusplus
+extern "C" {
+#endif  // __cplusplus
+
+// Callback context used by aqlprofile memory allocation and copy hooks.
+typedef struct iree_hal_amdgpu_profile_aqlprofile_memory_context_t {
+  // HSA API table used by the raw hsa_status_t callback helpers.
+  const iree_hal_amdgpu_libhsa_t* libhsa;
+  // GPU agent that must be able to access allocated aqlprofile memory.
+  hsa_agent_t device_agent;
+  // Host memory pools nearest to |device_agent|.
+  const iree_hal_amdgpu_host_memory_pools_t* host_memory_pools;
+  // Coarse-grained memory pool owned by |device_agent| for device-only trace
+  // output buffers.
+  hsa_amd_memory_pool_t device_coarse_pool;
+} iree_hal_amdgpu_profile_aqlprofile_memory_context_t;
+
+// Registers |device_agent| with aqlprofile and returns an opaque agent handle.
+iree_status_t iree_hal_amdgpu_profile_aqlprofile_register_agent(
+    const iree_hal_amdgpu_libhsa_t* libhsa,
+    const iree_hal_amdgpu_libaqlprofile_t* libaqlprofile,
+    hsa_agent_t device_agent,
+    iree_hal_amdgpu_aqlprofile_agent_handle_t* out_agent_handle);
+
+// Allocates memory requested by aqlprofile packet generation.
+hsa_status_t iree_hal_amdgpu_profile_aqlprofile_memory_alloc(
+    void** ptr, uint64_t size,
+    iree_hal_amdgpu_aqlprofile_buffer_desc_flags_t flags, void* user_data);
+
+// Releases memory previously allocated by
+// iree_hal_amdgpu_profile_aqlprofile_memory_alloc.
+void iree_hal_amdgpu_profile_aqlprofile_memory_dealloc(void* ptr,
+                                                       void* user_data);
+
+// Copies memory for aqlprofile packet generation.
+hsa_status_t iree_hal_amdgpu_profile_aqlprofile_memory_copy(void* target,
+                                                            const void* source,
+                                                            size_t size,
+                                                            void* user_data);
+
+// Copies aqlprofile's PM4-IB AQL packet template into |packet| and returns the
+// header/setup dwords the caller must publish after all packet bodies are
+// populated.
+void iree_hal_amdgpu_profile_aqlprofile_emplace_pm4_ib_packet(
+    const iree_hsa_amd_aql_pm4_ib_packet_t* source_packet,
+    iree_hal_amdgpu_aql_packet_t* packet,
+    iree_hal_amdgpu_aql_packet_control_t packet_control,
+    iree_hsa_signal_t completion_signal, uint16_t* out_header,
+    uint16_t* out_setup);
+
+#ifdef __cplusplus
+}  // extern "C"
+#endif  // __cplusplus
+
+#endif  // IREE_HAL_DRIVERS_AMDGPU_PROFILE_AQLPROFILE_H_
diff --git a/runtime/src/iree/hal/drivers/amdgpu/profile_counters.c b/runtime/src/iree/hal/drivers/amdgpu/profile_counters.c
new file mode 100644
index 0000000..99e4b96
--- /dev/null
+++ b/runtime/src/iree/hal/drivers/amdgpu/profile_counters.c
@@ -0,0 +1,2069 @@
+// Copyright 2026 The IREE Authors
+//
+// Licensed under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+
+#include "iree/hal/drivers/amdgpu/profile_counters.h"
+
+#include <string.h>
+
+#include "iree/hal/drivers/amdgpu/host_queue.h"
+#include "iree/hal/drivers/amdgpu/host_queue_profile.h"
+#include "iree/hal/drivers/amdgpu/host_queue_profile_events.h"
+#include "iree/hal/drivers/amdgpu/host_queue_timestamp.h"
+#include "iree/hal/drivers/amdgpu/logical_device.h"
+#include "iree/hal/drivers/amdgpu/physical_device.h"
+#include "iree/hal/drivers/amdgpu/profile_aqlprofile.h"
+#include "iree/hal/drivers/amdgpu/system.h"
+#include "iree/hal/drivers/amdgpu/util/libaqlprofile.h"
+#include "iree/hal/drivers/amdgpu/util/signal_pool.h"
+
+//===----------------------------------------------------------------------===//
+// Counter support tables
+//===----------------------------------------------------------------------===//
+
+static iree_string_view_t iree_hal_amdgpu_profile_counter_set_name(void) {
+  return IREE_SV("amdgpu.pmc");
+}
+
+enum { iree_hal_amdgpu_profile_counter_packets_per_set = 3u };
+enum { iree_hal_amdgpu_profile_counter_event_unsupported = UINT32_MAX };
+enum { iree_hal_amdgpu_profile_counter_range_bank_count = 2u };
+
+typedef uint32_t iree_hal_amdgpu_profile_counter_session_flags_t;
+enum iree_hal_amdgpu_profile_counter_session_flag_bits_t {
+  IREE_HAL_AMDGPU_PROFILE_COUNTER_SESSION_FLAG_NONE = 0u,
+  // Captures operation-attributed counter samples around selected dispatches.
+  IREE_HAL_AMDGPU_PROFILE_COUNTER_SESSION_FLAG_DISPATCH_SAMPLES = 1u << 0,
+  // Captures queue-level counter ranges over profiling begin/flush/end spans.
+  IREE_HAL_AMDGPU_PROFILE_COUNTER_SESSION_FLAG_QUEUE_RANGES = 1u << 1,
+};
+
+// Static description of one raw hardware counter we can request from
+// aqlprofile by name.
+typedef struct iree_hal_amdgpu_profile_counter_descriptor_t {
+  // User-visible counter name accepted in profiling options and metadata.
+  iree_string_view_t name;
+  // User-visible hardware block name emitted in metadata.
+  iree_string_view_t block_name;
+  // User-visible description emitted in metadata.
+  iree_string_view_t description;
+  // Display unit for values returned by this counter.
+  iree_hal_profile_counter_unit_t unit;
+  // aqlprofile hardware block identifier for this counter.
+  iree_hal_amdgpu_aqlprofile_block_name_t block_name_id;
+  // aqlprofile event id for gfx9 devices, or unsupported.
+  uint32_t gfx9_event_id;
+  // aqlprofile event id for gfx10 devices, or unsupported.
+  uint32_t gfx10_event_id;
+  // aqlprofile event id for gfx11 devices, or unsupported.
+  uint32_t gfx11_event_id;
+  // aqlprofile event id for gfx12 devices, or unsupported.
+  uint32_t gfx12_event_id;
+} iree_hal_amdgpu_profile_counter_descriptor_t;
+
+#define IREE_HAL_AMDGPU_PROFILE_COUNTER_SV(value) {(value), sizeof(value) - 1}
+
+static const iree_hal_amdgpu_profile_counter_descriptor_t
+    iree_hal_amdgpu_profile_counter_descriptors[] = {
+        {
+            .name = IREE_HAL_AMDGPU_PROFILE_COUNTER_SV("SQ_WAVES"),
+            .block_name = IREE_HAL_AMDGPU_PROFILE_COUNTER_SV("SQ"),
+            .description = IREE_HAL_AMDGPU_PROFILE_COUNTER_SV(
+                "Raw SQ_WAVES values returned by aqlprofile."),
+            .unit = IREE_HAL_PROFILE_COUNTER_UNIT_COUNT,
+            .block_name_id = IREE_HAL_AMDGPU_AQLPROFILE_BLOCK_NAME_SQ,
+            .gfx9_event_id = 4,
+            .gfx10_event_id = 4,
+            .gfx11_event_id = 4,
+            .gfx12_event_id = 4,
+        },
+        {
+            .name = IREE_HAL_AMDGPU_PROFILE_COUNTER_SV("SQ_WAVES_32"),
+            .block_name = IREE_HAL_AMDGPU_PROFILE_COUNTER_SV("SQ"),
+            .description = IREE_HAL_AMDGPU_PROFILE_COUNTER_SV(
+                "Raw SQ_WAVES_32 values returned by aqlprofile."),
+            .unit = IREE_HAL_PROFILE_COUNTER_UNIT_COUNT,
+            .block_name_id = IREE_HAL_AMDGPU_AQLPROFILE_BLOCK_NAME_SQ,
+            .gfx9_event_id = iree_hal_amdgpu_profile_counter_event_unsupported,
+            .gfx10_event_id = 5,
+            .gfx11_event_id = 5,
+            .gfx12_event_id = 5,
+        },
+        {
+            .name = IREE_HAL_AMDGPU_PROFILE_COUNTER_SV("SQ_WAVES_64"),
+            .block_name = IREE_HAL_AMDGPU_PROFILE_COUNTER_SV("SQ"),
+            .description = IREE_HAL_AMDGPU_PROFILE_COUNTER_SV(
+                "Raw SQ_WAVES_64 values returned by aqlprofile."),
+            .unit = IREE_HAL_PROFILE_COUNTER_UNIT_COUNT,
+            .block_name_id = IREE_HAL_AMDGPU_AQLPROFILE_BLOCK_NAME_SQ,
+            .gfx9_event_id = iree_hal_amdgpu_profile_counter_event_unsupported,
+            .gfx10_event_id = 6,
+            .gfx11_event_id = 6,
+            .gfx12_event_id = 6,
+        },
+        {
+            .name = IREE_HAL_AMDGPU_PROFILE_COUNTER_SV("SQ_BUSY_CYCLES"),
+            .block_name = IREE_HAL_AMDGPU_PROFILE_COUNTER_SV("SQ"),
+            .description = IREE_HAL_AMDGPU_PROFILE_COUNTER_SV(
+                "Clock cycles with active waves in a shader engine."),
+            .unit = IREE_HAL_PROFILE_COUNTER_UNIT_CYCLES,
+            .block_name_id = IREE_HAL_AMDGPU_AQLPROFILE_BLOCK_NAME_SQ,
+            .gfx9_event_id = 3,
+            .gfx10_event_id = 3,
+            .gfx11_event_id = 3,
+            .gfx12_event_id = 3,
+        },
+};
+
+#undef IREE_HAL_AMDGPU_PROFILE_COUNTER_SV
+
+// One resolved raw counter within a selected counter set.
+typedef struct iree_hal_amdgpu_profile_counter_t {
+  // Static counter descriptor used for metadata and event matching.
+  const iree_hal_amdgpu_profile_counter_descriptor_t* descriptor;
+  // Resolved aqlprofile hardware event request for the physical device.
+  iree_hal_amdgpu_aqlprofile_pmc_event_t event;
+  // Counter ordinal within its owning counter set.
+  uint32_t counter_ordinal;
+  // First uint64_t slot occupied by this counter in emitted sample records.
+  uint32_t sample_value_offset;
+  // Number of uint64_t slots occupied by this counter in emitted samples.
+  uint32_t sample_value_count;
+} iree_hal_amdgpu_profile_counter_t;
+
+// One resolved counter set for one physical device.
+typedef struct iree_hal_amdgpu_profile_counter_set_t {
+  // Session-local counter set id referenced by counter and sample records.
+  uint64_t counter_set_id;
+  // Session-local physical device ordinal owning this counter set.
+  uint32_t physical_device_ordinal;
+  // Number of counters in |counters|.
+  uint32_t counter_count;
+  // Number of uint64_t values in each sample for this counter set.
+  uint32_t sample_value_count;
+  // Registered aqlprofile agent used when creating per-use sample handles.
+  iree_hal_amdgpu_aqlprofile_agent_handle_t agent;
+  // Contiguous aqlprofile event requests for all counters in this set.
+  iree_hal_amdgpu_aqlprofile_pmc_event_t* events;
+  // Resolved counters with emitted sample-value slices.
+  iree_hal_amdgpu_profile_counter_t* counters;
+  // Session-owned human-readable counter set name.
+  iree_string_view_t name;
+} iree_hal_amdgpu_profile_counter_set_t;
+
+// Per-use mutable aqlprofile capture packet state.
+typedef struct iree_hal_amdgpu_profile_counter_packet_set_t {
+  // Callback context retained for the lifetime of |handle|.
+  iree_hal_amdgpu_profile_aqlprofile_memory_context_t memory_context;
+  // aqlprofile handle owning PM4 programs and output storage for this slot.
+  iree_hal_amdgpu_aqlprofile_handle_t handle;
+  // AQL PM4-IB packet templates referencing |handle|'s immutable PM4 programs.
+  iree_hal_amdgpu_aqlprofile_pmc_aql_packets_t packets;
+} iree_hal_amdgpu_profile_counter_packet_set_t;
+
+// Per-queue/per-event-ring-slot mutable aqlprofile capture state.
+struct iree_hal_amdgpu_profile_counter_sample_slot_t {
+  // Reusable aqlprofile packet set for this event-ring slot.
+  iree_hal_amdgpu_profile_counter_packet_set_t packet_set;
+  // Producer-local sample id assigned when this slot is reserved for a
+  // dispatch.
+  uint64_t sample_id;
+};
+
+// Per-queue/per-range-bank mutable aqlprofile capture state.
+struct iree_hal_amdgpu_profile_counter_range_slot_t {
+  // Reusable aqlprofile packet set for this range bank.
+  iree_hal_amdgpu_profile_counter_packet_set_t packet_set;
+  // Producer-local sample id assigned when this range bank starts.
+  uint64_t sample_id;
+};
+
+// Device-visible tick range associated with one queue counter range bank.
+typedef struct iree_hal_amdgpu_profile_counter_range_ticks_t {
+  // Device timestamp captured after the range counter starts.
+  uint64_t start_tick;
+  // Device timestamp captured before the range counter is read and stopped.
+  uint64_t end_tick;
+} iree_hal_amdgpu_profile_counter_range_ticks_t;
+
+// Logical-device profiling session for selected hardware counters.
+struct iree_hal_amdgpu_profile_counter_session_t {
+  // Host allocator used for session and queue slot storage.
+  iree_allocator_t host_allocator;
+  // Counter capture variants requested by profiling options.
+  iree_hal_amdgpu_profile_counter_session_flags_t flags;
+  // Borrowed HSA API table from the logical device.
+  const iree_hal_amdgpu_libhsa_t* libhsa;
+  // Dynamically loaded aqlprofile SDK.
+  iree_hal_amdgpu_libaqlprofile_t libaqlprofile;
+  // Number of physical devices in |agent_handles|.
+  iree_host_size_t physical_device_count;
+  // Number of requested counter sets per physical device.
+  uint32_t counter_set_count;
+  // Registered aqlprofile agent handles indexed by physical device ordinal.
+  iree_hal_amdgpu_aqlprofile_agent_handle_t* agent_handles;
+  // Resolved counter sets indexed by physical device then counter-set ordinal.
+  iree_hal_amdgpu_profile_counter_set_t* counter_sets;
+  // Next nonzero producer-local counter sample id.
+  iree_atomic_int64_t next_sample_id;
+};
+
+// Callback context used to count decoded aqlprofile values.
+typedef struct iree_hal_amdgpu_profile_counter_count_context_t {
+  // Counter set whose per-counter sample widths are being discovered.
+  iree_hal_amdgpu_profile_counter_set_t* counter_set;
+} iree_hal_amdgpu_profile_counter_count_context_t;
+
+// Callback context used to copy decoded aqlprofile values.
+typedef struct iree_hal_amdgpu_profile_counter_collect_context_t {
+  // Counter set describing the emitted sample-value layout.
+  const iree_hal_amdgpu_profile_counter_set_t* counter_set;
+  // Destination value vector.
+  uint64_t* values;
+  // Per-counter value counts already written for the current sample.
+  uint32_t* counter_value_counts;
+} iree_hal_amdgpu_profile_counter_collect_context_t;
+
+static const iree_hal_amdgpu_profile_counter_descriptor_t*
+iree_hal_amdgpu_profile_counter_find_descriptor(iree_string_view_t name) {
+  for (iree_host_size_t i = 0;
+       i < IREE_ARRAYSIZE(iree_hal_amdgpu_profile_counter_descriptors); ++i) {
+    const iree_hal_amdgpu_profile_counter_descriptor_t* descriptor =
+        &iree_hal_amdgpu_profile_counter_descriptors[i];
+    if (iree_string_view_equal(name, descriptor->name)) return descriptor;
+  }
+  return NULL;
+}
+
+static iree_status_t iree_hal_amdgpu_profile_counter_resolve_event(
+    const iree_hal_amdgpu_profile_counter_descriptor_t* descriptor,
+    iree_hal_amdgpu_gfxip_version_t gfxip_version,
+    iree_hal_amdgpu_aqlprofile_pmc_event_t* out_event) {
+  uint32_t event_id = iree_hal_amdgpu_profile_counter_event_unsupported;
+  switch (gfxip_version.major) {
+    case 9:
+      event_id = descriptor->gfx9_event_id;
+      break;
+    case 10:
+      event_id = descriptor->gfx10_event_id;
+      break;
+    case 11:
+      event_id = descriptor->gfx11_event_id;
+      break;
+    case 12:
+      event_id = descriptor->gfx12_event_id;
+      break;
+    default:
+      break;
+  }
+  if (IREE_UNLIKELY(event_id ==
+                    iree_hal_amdgpu_profile_counter_event_unsupported)) {
+    return iree_make_status(
+        IREE_STATUS_UNIMPLEMENTED,
+        "AMDGPU counter '%.*s' is not mapped for gfx%u.%u.%u",
+        (int)descriptor->name.size, descriptor->name.data, gfxip_version.major,
+        gfxip_version.minor, gfxip_version.stepping);
+  }
+
+  *out_event = (iree_hal_amdgpu_aqlprofile_pmc_event_t){
+      .block_index = 0,
+      .event_id = event_id,
+      .block_name = descriptor->block_name_id,
+  };
+  return iree_ok_status();
+}
+
+static bool iree_hal_amdgpu_profile_counter_events_equal(
+    iree_hal_amdgpu_aqlprofile_pmc_event_t lhs,
+    iree_hal_amdgpu_aqlprofile_pmc_event_t rhs) {
+  return lhs.block_index == rhs.block_index && lhs.event_id == rhs.event_id &&
+         lhs.flags.raw == rhs.flags.raw && lhs.block_name == rhs.block_name;
+}
+
+static bool iree_hal_amdgpu_profile_counter_find_index_by_event(
+    const iree_hal_amdgpu_profile_counter_set_t* counter_set,
+    iree_hal_amdgpu_aqlprofile_pmc_event_t event,
+    uint32_t* out_counter_ordinal) {
+  for (uint32_t i = 0; i < counter_set->counter_count; ++i) {
+    const iree_hal_amdgpu_profile_counter_t* counter =
+        &counter_set->counters[i];
+    if (iree_hal_amdgpu_profile_counter_events_equal(counter->event, event)) {
+      *out_counter_ordinal = i;
+      return true;
+    }
+  }
+  return false;
+}
+
+static hsa_status_t iree_hal_amdgpu_profile_counter_count_callback(
+    iree_hal_amdgpu_aqlprofile_pmc_event_t event, uint64_t counter_id,
+    uint64_t counter_value, void* user_data) {
+  (void)counter_id;
+  (void)counter_value;
+  iree_hal_amdgpu_profile_counter_count_context_t* context =
+      (iree_hal_amdgpu_profile_counter_count_context_t*)user_data;
+  uint32_t counter_ordinal = 0;
+  if (!iree_hal_amdgpu_profile_counter_find_index_by_event(
+          context->counter_set, event, &counter_ordinal)) {
+    return HSA_STATUS_ERROR_INVALID_ARGUMENT;
+  }
+  ++context->counter_set->counters[counter_ordinal].sample_value_count;
+  return HSA_STATUS_SUCCESS;
+}
+
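+// Copies one decoded counter value into that counter's slice of the sample
+// value vector; a counter may produce multiple values (typically one per
+// hardware block instance), hence the per-counter offset/count bookkeeping.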
+static hsa_status_t iree_hal_amdgpu_profile_counter_collect_callback(
+    iree_hal_amdgpu_aqlprofile_pmc_event_t event, uint64_t counter_id,
+    uint64_t counter_value, void* user_data) {
+  (void)counter_id;
+  iree_hal_amdgpu_profile_counter_collect_context_t* context =
+      (iree_hal_amdgpu_profile_counter_collect_context_t*)user_data;
+  uint32_t counter_ordinal = 0;
+  if (!iree_hal_amdgpu_profile_counter_find_index_by_event(
+          context->counter_set, event, &counter_ordinal)) {
+    return HSA_STATUS_ERROR_INVALID_ARGUMENT;
+  }
+  const iree_hal_amdgpu_profile_counter_t* counter =
+      &context->counter_set->counters[counter_ordinal];
+  uint32_t* counter_value_count =
+      &context->counter_value_counts[counter_ordinal];
+  if (*counter_value_count >= counter->sample_value_count) {
+    return HSA_STATUS_ERROR_OUT_OF_RESOURCES;
+  }
+  context->values[counter->sample_value_offset + *counter_value_count] =
+      counter_value;
+  ++*counter_value_count;
+  return HSA_STATUS_SUCCESS;
+}
+
+//===----------------------------------------------------------------------===//
+// iree_hal_amdgpu_profile_counter_session_t
+//===----------------------------------------------------------------------===//
+
+static iree_status_t iree_hal_amdgpu_profile_counter_initialize_selection(
+    const iree_hal_profile_counter_set_selection_t* selection,
+    iree_hal_amdgpu_gfxip_version_t gfxip_version,
+    iree_hal_amdgpu_profile_counter_set_t* counter_set) {
+  if (IREE_UNLIKELY(selection->flags !=
+                    IREE_HAL_PROFILE_COUNTER_SET_SELECTION_FLAG_NONE)) {
+    return iree_make_status(IREE_STATUS_INVALID_ARGUMENT,
+                            "unsupported AMDGPU counter set flags 0x%x",
+                            selection->flags);
+  }
+  if (IREE_UNLIKELY(selection->counter_name_count == 0)) {
+    return iree_make_status(
+        IREE_STATUS_INVALID_ARGUMENT,
+        "AMDGPU counter profiling requires at least one counter per set");
+  }
+  if (IREE_UNLIKELY(selection->counter_name_count > UINT32_MAX)) {
+    return iree_make_status(IREE_STATUS_OUT_OF_RANGE,
+                            "AMDGPU counter count exceeds uint32_t");
+  }
+
+  for (iree_host_size_t i = 0; i < selection->counter_name_count; ++i) {
+    const iree_hal_amdgpu_profile_counter_descriptor_t* descriptor =
+        iree_hal_amdgpu_profile_counter_find_descriptor(
+            selection->counter_names[i]);
+    if (IREE_UNLIKELY(!descriptor)) {
+      return iree_make_status(
+          IREE_STATUS_UNIMPLEMENTED,
+          "unsupported AMDGPU counter '%.*s'; supported counters: "
+          "SQ_WAVES, SQ_WAVES_32, SQ_WAVES_64, SQ_BUSY_CYCLES",
+          (int)selection->counter_names[i].size,
+          selection->counter_names[i].data);
+    }
+    for (iree_host_size_t j = 0; j < i; ++j) {
+      if (iree_string_view_equal(selection->counter_names[i],
+                                 selection->counter_names[j])) {
+        return iree_make_status(
+            IREE_STATUS_INVALID_ARGUMENT,
+            "duplicate AMDGPU counter '%.*s' in one counter set",
+            (int)selection->counter_names[i].size,
+            selection->counter_names[i].data);
+      }
+    }
+
+    const uint32_t counter_ordinal = (uint32_t)i;
+    iree_hal_amdgpu_aqlprofile_pmc_event_t event = {0};
+    IREE_RETURN_IF_ERROR(iree_hal_amdgpu_profile_counter_resolve_event(
+        descriptor, gfxip_version, &event));
+    counter_set->events[i] = event;
+    counter_set->counters[i] = (iree_hal_amdgpu_profile_counter_t){
+        .descriptor = descriptor,
+        .event = event,
+        .counter_ordinal = counter_ordinal,
+    };
+  }
+  counter_set->counter_count = (uint32_t)selection->counter_name_count;
+  return iree_ok_status();
+}
+
+static iree_status_t iree_hal_amdgpu_profile_counter_count_values(
+    const iree_hal_amdgpu_libaqlprofile_t* libaqlprofile,
+    iree_hal_amdgpu_aqlprofile_handle_t handle,
+    iree_hal_amdgpu_profile_counter_set_t* counter_set) {
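+  // Dry-run the decode iterator to learn how many values each counter emits;
+  // the prefix sum below assigns each counter a contiguous slice of the
+  // emitted sample record.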
+  for (uint32_t i = 0; i < counter_set->counter_count; ++i) {
+    counter_set->counters[i].sample_value_count = 0;
+    counter_set->counters[i].sample_value_offset = 0;
+  }
+
+  iree_hal_amdgpu_profile_counter_count_context_t context = {
+      .counter_set = counter_set,
+  };
+  IREE_RETURN_IF_AQLPROFILE_ERROR(
+      libaqlprofile,
+      libaqlprofile->aqlprofile_pmc_iterate_data(
+          handle, iree_hal_amdgpu_profile_counter_count_callback, &context),
+      "iterating AMDGPU counter metadata");
+
+  iree_host_size_t sample_value_count = 0;
+  for (uint32_t i = 0; i < counter_set->counter_count; ++i) {
+    iree_hal_amdgpu_profile_counter_t* counter = &counter_set->counters[i];
+    if (IREE_UNLIKELY(counter->sample_value_count == 0)) {
+      return iree_make_status(
+          IREE_STATUS_FAILED_PRECONDITION,
+          "aqlprofile reported zero values for selected AMDGPU counter '%.*s'",
+          (int)counter->descriptor->name.size, counter->descriptor->name.data);
+    }
+    if (IREE_UNLIKELY(sample_value_count > UINT32_MAX)) {
+      return iree_make_status(IREE_STATUS_OUT_OF_RANGE,
+                              "AMDGPU counter sample value count overflow");
+    }
+    counter->sample_value_offset = (uint32_t)sample_value_count;
+    if (IREE_UNLIKELY(!iree_host_size_checked_add(sample_value_count,
+                                                  counter->sample_value_count,
+                                                  &sample_value_count))) {
+      return iree_make_status(IREE_STATUS_OUT_OF_RANGE,
+                              "AMDGPU counter sample value count overflow");
+    }
+  }
+  if (IREE_UNLIKELY(sample_value_count > UINT32_MAX)) {
+    return iree_make_status(IREE_STATUS_OUT_OF_RANGE,
+                            "AMDGPU counter sample value count overflow");
+  }
+  counter_set->sample_value_count = (uint32_t)sample_value_count;
+  return iree_ok_status();
+}
+
+static iree_status_t iree_hal_amdgpu_profile_counter_create_packets(
+    const iree_hal_amdgpu_libaqlprofile_t* libaqlprofile,
+    const iree_hal_amdgpu_profile_counter_set_t* counter_set,
+    const iree_hal_amdgpu_profile_aqlprofile_memory_context_t* memory_context,
+    iree_hal_amdgpu_aqlprofile_handle_t* out_handle,
+    iree_hal_amdgpu_aqlprofile_pmc_aql_packets_t* out_packets) {
+  iree_hal_amdgpu_aqlprofile_pmc_profile_t profile = {
+      .agent = counter_set->agent,
+      .events = counter_set->events,
+      .event_count = counter_set->counter_count,
+  };
+  IREE_RETURN_IF_AQLPROFILE_ERROR(
+      libaqlprofile,
+      libaqlprofile->aqlprofile_pmc_create_packets(
+          out_handle, out_packets, profile,
+          iree_hal_amdgpu_profile_aqlprofile_memory_alloc,
+          iree_hal_amdgpu_profile_aqlprofile_memory_dealloc,
+          iree_hal_amdgpu_profile_aqlprofile_memory_copy,
+          (void*)memory_context),
+      "creating AMDGPU counter PM4 packets");
+  return iree_ok_status();
+}
+
+static void iree_hal_amdgpu_profile_counter_destroy_packets(
+    const iree_hal_amdgpu_libaqlprofile_t* libaqlprofile,
+    iree_hal_amdgpu_aqlprofile_handle_t* handle) {
+  if (!handle->handle) return;
+  libaqlprofile->aqlprofile_pmc_delete_packets(*handle);
+  handle->handle = 0;
+}
+
+static void iree_hal_amdgpu_profile_counter_destroy_packet_set(
+    const iree_hal_amdgpu_libaqlprofile_t* libaqlprofile,
+    iree_hal_amdgpu_profile_counter_packet_set_t* packet_set) {
+  iree_hal_amdgpu_profile_counter_destroy_packets(libaqlprofile,
+                                                  &packet_set->handle);
+  memset(&packet_set->packets, 0, sizeof(packet_set->packets));
+}
+
+static iree_status_t iree_hal_amdgpu_profile_counter_initialize_set(
+    const iree_hal_device_profiling_options_t* options,
+    iree_hal_amdgpu_physical_device_t* physical_device,
+    iree_host_size_t selection_ordinal,
+    const iree_hal_amdgpu_libaqlprofile_t* libaqlprofile,
+    iree_hal_amdgpu_profile_counter_session_t* session,
+    iree_hal_amdgpu_profile_counter_t** inout_counter_storage,
+    iree_hal_amdgpu_aqlprofile_pmc_event_t** inout_event_storage,
+    char** inout_string_storage,
+    iree_hal_amdgpu_profile_counter_set_t* out_counter_set) {
+  const iree_hal_profile_counter_set_selection_t* selection =
+      &options->counter_sets[selection_ordinal];
+  iree_string_view_t set_name = selection->name;
+  if (iree_string_view_is_empty(set_name)) {
+    set_name = iree_hal_amdgpu_profile_counter_set_name();
+  }
+  char* string_storage = *inout_string_storage;
+  memcpy(string_storage, set_name.data, set_name.size);
+  *inout_string_storage = string_storage + set_name.size;
+
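+  // Pack a session-unique counter set id: physical device ordinal in the high
+  // 32 bits, 1-based selection ordinal in the low bits (kept nonzero).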
+  *out_counter_set = (iree_hal_amdgpu_profile_counter_set_t){
+      .counter_set_id =
+          ((uint64_t)(uint32_t)physical_device->device_ordinal << 32) |
+          (uint64_t)(selection_ordinal + 1),
+      .physical_device_ordinal = (uint32_t)physical_device->device_ordinal,
+      .counter_count = 0,
+      .sample_value_count = 0,
+      .agent = session->agent_handles[physical_device->device_ordinal],
+      .events = *inout_event_storage,
+      .counters = *inout_counter_storage,
+      .name = iree_make_string_view(string_storage, set_name.size),
+  };
+  IREE_RETURN_IF_ERROR(iree_hal_amdgpu_profile_counter_initialize_selection(
+      selection, physical_device->isa.target_id.version, out_counter_set));
+  *inout_counter_storage += out_counter_set->counter_count;
+  *inout_event_storage += out_counter_set->counter_count;
+
+  for (uint32_t i = 0; i < out_counter_set->counter_count; ++i) {
+    const iree_hal_amdgpu_profile_counter_t* counter =
+        &out_counter_set->counters[i];
+    bool valid = false;
+    IREE_RETURN_IF_AQLPROFILE_ERROR(
+        libaqlprofile,
+        libaqlprofile->aqlprofile_validate_pmc_event(
+            session->agent_handles[physical_device->device_ordinal],
+            &counter->event, &valid),
+        "validating AMDGPU counter event");
+    if (IREE_UNLIKELY(!valid)) {
+      return iree_make_status(
+          IREE_STATUS_FAILED_PRECONDITION,
+          "AMDGPU counter '%.*s' is not valid on physical device %" PRIhsz,
+          (int)counter->descriptor->name.size, counter->descriptor->name.data,
+          physical_device->device_ordinal);
+    }
+  }
+
+  iree_hal_amdgpu_profile_aqlprofile_memory_context_t memory_context = {
+      .libhsa = session->libhsa,
+      .device_agent = physical_device->device_agent,
+      .host_memory_pools = &physical_device->host_memory_pools,
+      .device_coarse_pool =
+          physical_device->coarse_block_pools.large.memory_pool,
+  };
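+  // Build a throwaway packet set solely to discover the sample value layout;
+  // the per-use packet sets are created separately.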
+  iree_hal_amdgpu_aqlprofile_handle_t metadata_handle = {0};
+  iree_hal_amdgpu_aqlprofile_pmc_aql_packets_t metadata_packets;
+  memset(&metadata_packets, 0, sizeof(metadata_packets));
+  iree_status_t status = iree_hal_amdgpu_profile_counter_create_packets(
+      libaqlprofile, out_counter_set, &memory_context, &metadata_handle,
+      &metadata_packets);
+  if (iree_status_is_ok(status)) {
+    status = iree_hal_amdgpu_profile_counter_count_values(
+        libaqlprofile, metadata_handle, out_counter_set);
+  }
+  iree_hal_amdgpu_profile_counter_destroy_packets(libaqlprofile,
+                                                  &metadata_handle);
+  return status;
+}
+
+iree_status_t iree_hal_amdgpu_profile_counter_session_allocate(
+    iree_hal_amdgpu_logical_device_t* logical_device,
+    const iree_hal_device_profiling_options_t* options,
+    iree_allocator_t host_allocator,
+    iree_hal_amdgpu_profile_counter_session_t** out_session) {
+  IREE_TRACE_ZONE_BEGIN(z0);
+  *out_session = NULL;
+
+  const bool capture_dispatch_samples =
+      iree_hal_device_profiling_options_requests_counter_samples(options);
+  const bool capture_queue_ranges =
+      iree_hal_device_profiling_options_requests_counter_ranges(options);
+  if (!capture_dispatch_samples && !capture_queue_ranges) {
+    IREE_TRACE_ZONE_END(z0);
+    return iree_ok_status();
+  }
+  if (IREE_UNLIKELY(!options->sink)) {
+    IREE_TRACE_ZONE_END(z0);
+    return iree_make_status(
+        IREE_STATUS_INVALID_ARGUMENT,
+        "AMDGPU hardware counter profiling requires a profile sink");
+  }
+  if (IREE_UNLIKELY(options->counter_set_count > UINT32_MAX)) {
+    IREE_TRACE_ZONE_END(z0);
+    return iree_make_status(IREE_STATUS_OUT_OF_RANGE,
+                            "AMDGPU counter set count exceeds uint32_t");
+  }
+  if (IREE_UNLIKELY(logical_device->physical_device_count > UINT32_MAX)) {
+    IREE_TRACE_ZONE_END(z0);
+    return iree_make_status(IREE_STATUS_OUT_OF_RANGE,
+                            "AMDGPU physical device count exceeds uint32_t");
+  }
+
+  iree_host_size_t counter_set_total_count = 0;
+  if (IREE_UNLIKELY(!iree_host_size_checked_mul(
+          logical_device->physical_device_count, options->counter_set_count,
+          &counter_set_total_count))) {
+    IREE_TRACE_ZONE_END(z0);
+    return iree_make_status(IREE_STATUS_OUT_OF_RANGE,
+                            "AMDGPU counter set count overflow");
+  }
+  iree_host_size_t counter_total_count = 0;
+  iree_host_size_t string_storage_length = 0;
+  for (iree_host_size_t i = 0; i < options->counter_set_count; ++i) {
+    if (IREE_UNLIKELY(options->counter_sets[i].counter_name_count == 0)) {
+      IREE_TRACE_ZONE_END(z0);
+      return iree_make_status(
+          IREE_STATUS_INVALID_ARGUMENT,
+          "AMDGPU counter profiling requires at least one counter per set");
+    }
+    iree_host_size_t replicated_counter_count = 0;
+    if (IREE_UNLIKELY(!iree_host_size_checked_mul(
+                          logical_device->physical_device_count,
+                          options->counter_sets[i].counter_name_count,
+                          &replicated_counter_count) ||
+                      !iree_host_size_checked_add(counter_total_count,
+                                                  replicated_counter_count,
+                                                  &counter_total_count))) {
+      IREE_TRACE_ZONE_END(z0);
+      return iree_make_status(IREE_STATUS_OUT_OF_RANGE,
+                              "AMDGPU counter count overflow");
+    }
+
+    iree_string_view_t set_name = options->counter_sets[i].name;
+    if (iree_string_view_is_empty(set_name)) {
+      set_name = iree_hal_amdgpu_profile_counter_set_name();
+    }
+    iree_host_size_t replicated_length = 0;
+    if (IREE_UNLIKELY(
+            !iree_host_size_checked_mul(logical_device->physical_device_count,
+                                        set_name.size, &replicated_length) ||
+            !iree_host_size_checked_add(string_storage_length,
+                                        replicated_length,
+                                        &string_storage_length))) {
+      IREE_TRACE_ZONE_END(z0);
+      return iree_make_status(IREE_STATUS_OUT_OF_RANGE,
+                              "AMDGPU counter set name storage overflow");
+    }
+  }
+
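+  // Compute one flat allocation holding the session struct plus all trailing
+  // arrays (agent handles, counter sets, counters, events, and name storage)
+  // so teardown is a single free.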
+  iree_host_size_t agent_handles_offset = 0;
+  iree_host_size_t counter_sets_offset = 0;
+  iree_host_size_t counters_offset = 0;
+  iree_host_size_t events_offset = 0;
+  iree_host_size_t string_storage_offset = 0;
+  iree_host_size_t session_size = 0;
+  IREE_RETURN_AND_END_ZONE_IF_ERROR(
+      z0, IREE_STRUCT_LAYOUT(
+              sizeof(iree_hal_amdgpu_profile_counter_session_t), &session_size,
+              IREE_STRUCT_FIELD(logical_device->physical_device_count,
+                                iree_hal_amdgpu_aqlprofile_agent_handle_t,
+                                &agent_handles_offset),
+              IREE_STRUCT_FIELD(counter_set_total_count,
+                                iree_hal_amdgpu_profile_counter_set_t,
+                                &counter_sets_offset),
+              IREE_STRUCT_FIELD(counter_total_count,
+                                iree_hal_amdgpu_profile_counter_t,
+                                &counters_offset),
+              IREE_STRUCT_FIELD(counter_total_count,
+                                iree_hal_amdgpu_aqlprofile_pmc_event_t,
+                                &events_offset),
+              IREE_STRUCT_FIELD(string_storage_length, char,
+                                &string_storage_offset)));
+
+  iree_hal_amdgpu_profile_counter_session_t* session = NULL;
+  IREE_RETURN_AND_END_ZONE_IF_ERROR(
+      z0,
+      iree_allocator_malloc(host_allocator, session_size, (void**)&session));
+  memset(session, 0, session_size);
+  session->host_allocator = host_allocator;
+  session->flags = IREE_HAL_AMDGPU_PROFILE_COUNTER_SESSION_FLAG_NONE;
+  if (capture_dispatch_samples) {
+    session->flags |=
+        IREE_HAL_AMDGPU_PROFILE_COUNTER_SESSION_FLAG_DISPATCH_SAMPLES;
+  }
+  if (capture_queue_ranges) {
+    session->flags |= IREE_HAL_AMDGPU_PROFILE_COUNTER_SESSION_FLAG_QUEUE_RANGES;
+  }
+  session->libhsa = &logical_device->system->libhsa;
+  session->physical_device_count = logical_device->physical_device_count;
+  session->counter_set_count = (uint32_t)options->counter_set_count;
+  session->agent_handles =
+      (iree_hal_amdgpu_aqlprofile_agent_handle_t*)((uint8_t*)session +
+                                                   agent_handles_offset);
+  session->counter_sets =
+      (iree_hal_amdgpu_profile_counter_set_t*)((uint8_t*)session +
+                                               counter_sets_offset);
+  iree_hal_amdgpu_profile_counter_t* counter_storage =
+      (iree_hal_amdgpu_profile_counter_t*)((uint8_t*)session + counters_offset);
+  iree_hal_amdgpu_aqlprofile_pmc_event_t* event_storage =
+      (iree_hal_amdgpu_aqlprofile_pmc_event_t*)((uint8_t*)session +
+                                                events_offset);
+  iree_atomic_store(&session->next_sample_id, 1, iree_memory_order_relaxed);
+
+  iree_status_t status = iree_hal_amdgpu_libaqlprofile_initialize(
+      session->libhsa, iree_string_view_list_empty(), host_allocator,
+      &session->libaqlprofile);
+  for (iree_host_size_t i = 0;
+       i < logical_device->physical_device_count && iree_status_is_ok(status);
+       ++i) {
+    iree_hal_amdgpu_physical_device_t* physical_device =
+        logical_device->physical_devices[i];
+    status = iree_hal_amdgpu_profile_aqlprofile_register_agent(
+        session->libhsa, &session->libaqlprofile, physical_device->device_agent,
+        &session->agent_handles[physical_device->device_ordinal]);
+  }
+  char* string_storage = (char*)session + string_storage_offset;
+  for (iree_host_size_t i = 0;
+       i < logical_device->physical_device_count && iree_status_is_ok(status);
+       ++i) {
+    iree_hal_amdgpu_physical_device_t* physical_device =
+        logical_device->physical_devices[i];
+    for (iree_host_size_t j = 0;
+         j < options->counter_set_count && iree_status_is_ok(status); ++j) {
+      const iree_host_size_t counter_set_index =
+          physical_device->device_ordinal * options->counter_set_count + j;
+      status = iree_hal_amdgpu_profile_counter_initialize_set(
+          options, physical_device, j, &session->libaqlprofile, session,
+          &counter_storage, &event_storage, &string_storage,
+          &session->counter_sets[counter_set_index]);
+    }
+  }
+
+  if (iree_status_is_ok(status)) {
+    *out_session = session;
+  } else {
+    iree_hal_amdgpu_libaqlprofile_deinitialize(&session->libaqlprofile);
+    iree_allocator_free(host_allocator, session);
+  }
+
+  IREE_TRACE_ZONE_END(z0);
+  return status;
+}
+
+void iree_hal_amdgpu_profile_counter_session_free(
+    iree_hal_amdgpu_profile_counter_session_t* session) {
+  if (!session) return;
+  IREE_TRACE_ZONE_BEGIN(z0);
+  iree_allocator_t host_allocator = session->host_allocator;
+  iree_hal_amdgpu_libaqlprofile_deinitialize(&session->libaqlprofile);
+  iree_allocator_free(host_allocator, session);
+  IREE_TRACE_ZONE_END(z0);
+}
+
+bool iree_hal_amdgpu_profile_counter_session_is_active(
+    const iree_hal_amdgpu_profile_counter_session_t* session) {
+  return session &&
+         session->flags != IREE_HAL_AMDGPU_PROFILE_COUNTER_SESSION_FLAG_NONE &&
+         session->counter_set_count != 0;
+}
+
+bool iree_hal_amdgpu_profile_counter_session_captures_dispatch_samples(
+    const iree_hal_amdgpu_profile_counter_session_t* session) {
+  return iree_hal_amdgpu_profile_counter_session_is_active(session) &&
+         iree_any_bit_set(
+             session->flags,
+             IREE_HAL_AMDGPU_PROFILE_COUNTER_SESSION_FLAG_DISPATCH_SAMPLES);
+}
+
+bool iree_hal_amdgpu_profile_counter_session_captures_queue_ranges(
+    const iree_hal_amdgpu_profile_counter_session_t* session) {
+  return iree_hal_amdgpu_profile_counter_session_is_active(session) &&
+         iree_any_bit_set(
+             session->flags,
+             IREE_HAL_AMDGPU_PROFILE_COUNTER_SESSION_FLAG_QUEUE_RANGES);
+}
+
+static iree_status_t iree_hal_amdgpu_profile_counter_set_record_size(
+    const iree_hal_amdgpu_profile_counter_set_t* counter_set,
+    iree_host_size_t* out_record_size) {
+  return IREE_STRUCT_LAYOUT(
+      0, out_record_size,
+      IREE_STRUCT_FIELD(1, iree_hal_profile_counter_set_record_t, NULL),
+      IREE_STRUCT_FIELD(counter_set->name.size, char, NULL));
+}
+
+static iree_status_t iree_hal_amdgpu_profile_counter_record_size(
+    const iree_hal_amdgpu_profile_counter_t* counter,
+    iree_host_size_t* out_record_size) {
+  const iree_hal_amdgpu_profile_counter_descriptor_t* descriptor =
+      counter->descriptor;
+  return IREE_STRUCT_LAYOUT(
+      0, out_record_size,
+      IREE_STRUCT_FIELD(1, iree_hal_profile_counter_record_t, NULL),
+      IREE_STRUCT_FIELD(descriptor->block_name.size + descriptor->name.size +
+                            descriptor->description.size,
+                        char, NULL));
+}
+
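+// Sums the packed metadata record sizes across all (device, counter set)
+// pairs so each metadata chunk can be allocated as a single block.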
+static iree_status_t iree_hal_amdgpu_profile_counter_metadata_size(
+    const iree_hal_amdgpu_profile_counter_session_t* session,
+    iree_host_size_t* out_counter_set_size,
+    iree_host_size_t* out_counter_size) {
+  iree_host_size_t counter_set_size = 0;
+  iree_host_size_t counter_size = 0;
+  const iree_host_size_t counter_set_total_count =
+      session->physical_device_count * session->counter_set_count;
+  for (iree_host_size_t i = 0; i < counter_set_total_count; ++i) {
+    const iree_hal_amdgpu_profile_counter_set_t* counter_set =
+        &session->counter_sets[i];
+    iree_host_size_t record_size = 0;
+    IREE_RETURN_IF_ERROR(iree_hal_amdgpu_profile_counter_set_record_size(
+        counter_set, &record_size));
+    if (IREE_UNLIKELY(!iree_host_size_checked_add(counter_set_size, record_size,
+                                                  &counter_set_size))) {
+      return iree_make_status(IREE_STATUS_OUT_OF_RANGE,
+                              "AMDGPU counter set metadata overflow");
+    }
+
+    for (uint32_t j = 0; j < counter_set->counter_count; ++j) {
+      IREE_RETURN_IF_ERROR(iree_hal_amdgpu_profile_counter_record_size(
+          &counter_set->counters[j], &record_size));
+      if (IREE_UNLIKELY(!iree_host_size_checked_add(counter_size, record_size,
+                                                    &counter_size))) {
+        return iree_make_status(IREE_STATUS_OUT_OF_RANGE,
+                                "AMDGPU counter metadata overflow");
+      }
+    }
+  }
+  *out_counter_set_size = counter_set_size;
+  *out_counter_size = counter_size;
+  return iree_ok_status();
+}
+
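+// Packs one counter set metadata record at *inout_storage_ptr and advances
+// the cursor past the record.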
+static iree_status_t iree_hal_amdgpu_profile_counter_pack_counter_set_record(
+    const iree_hal_amdgpu_profile_counter_set_t* counter_set,
+    uint8_t** inout_storage_ptr) {
+  iree_host_size_t record_size = 0;
+  IREE_RETURN_IF_ERROR(iree_hal_amdgpu_profile_counter_set_record_size(
+      counter_set, &record_size));
+  if (IREE_UNLIKELY(record_size > UINT32_MAX)) {
+    return iree_make_status(
+        IREE_STATUS_OUT_OF_RANGE,
+        "AMDGPU counter set metadata record exceeds uint32_t");
+  }
+
+  iree_hal_profile_counter_set_record_t record =
+      iree_hal_profile_counter_set_record_default();
+  record.record_length = (uint32_t)record_size;
+  record.counter_set_id = counter_set->counter_set_id;
+  record.physical_device_ordinal = counter_set->physical_device_ordinal;
+  record.counter_count = counter_set->counter_count;
+  record.sample_value_count = counter_set->sample_value_count;
+  record.name_length = (uint32_t)counter_set->name.size;
+
+  uint8_t* storage_ptr = *inout_storage_ptr;
+  memcpy(storage_ptr, &record, sizeof(record));
+  memcpy(storage_ptr + sizeof(record), counter_set->name.data,
+         counter_set->name.size);
+  *inout_storage_ptr = storage_ptr + record_size;
+  return iree_ok_status();
+}
+
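+// Packs one counter metadata record (header plus block name, counter name,
+// and description strings) and advances the cursor past the record.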
+static iree_status_t iree_hal_amdgpu_profile_counter_pack_counter_record(
+    const iree_hal_amdgpu_profile_counter_set_t* counter_set,
+    const iree_hal_amdgpu_profile_counter_t* counter,
+    uint8_t** inout_storage_ptr) {
+  iree_host_size_t record_size = 0;
+  IREE_RETURN_IF_ERROR(
+      iree_hal_amdgpu_profile_counter_record_size(counter, &record_size));
+  if (IREE_UNLIKELY(record_size > UINT32_MAX)) {
+    return iree_make_status(IREE_STATUS_OUT_OF_RANGE,
+                            "AMDGPU counter metadata record exceeds uint32_t");
+  }
+
+  const iree_hal_amdgpu_profile_counter_descriptor_t* descriptor =
+      counter->descriptor;
+  iree_hal_profile_counter_record_t record =
+      iree_hal_profile_counter_record_default();
+  record.record_length = (uint32_t)record_size;
+  record.flags = IREE_HAL_PROFILE_COUNTER_FLAG_RAW;
+  record.unit = descriptor->unit;
+  record.physical_device_ordinal = counter_set->physical_device_ordinal;
+  record.counter_set_id = counter_set->counter_set_id;
+  record.counter_ordinal = counter->counter_ordinal;
+  record.sample_value_offset = counter->sample_value_offset;
+  record.sample_value_count = counter->sample_value_count;
+  record.block_name_length = (uint32_t)descriptor->block_name.size;
+  record.name_length = (uint32_t)descriptor->name.size;
+  record.description_length = (uint32_t)descriptor->description.size;
+
+  uint8_t* storage_ptr = *inout_storage_ptr;
+  memcpy(storage_ptr, &record, sizeof(record));
+  uint8_t* string_ptr = storage_ptr + sizeof(record);
+  memcpy(string_ptr, descriptor->block_name.data, descriptor->block_name.size);
+  string_ptr += descriptor->block_name.size;
+  memcpy(string_ptr, descriptor->name.data, descriptor->name.size);
+  string_ptr += descriptor->name.size;
+  memcpy(string_ptr, descriptor->description.data,
+         descriptor->description.size);
+  *inout_storage_ptr = storage_ptr + record_size;
+  return iree_ok_status();
+}
+
+static iree_status_t iree_hal_amdgpu_profile_counter_write_metadata_chunk(
+    iree_hal_profile_sink_t* sink, uint64_t session_id,
+    iree_string_view_t stream_name, iree_string_view_t content_type,
+    const uint8_t* storage, iree_host_size_t storage_size) {
+  iree_hal_profile_chunk_metadata_t metadata =
+      iree_hal_profile_chunk_metadata_default();
+  metadata.content_type = content_type;
+  metadata.name = stream_name;
+  metadata.session_id = session_id;
+  iree_const_byte_span_t iovec =
+      iree_make_const_byte_span(storage, storage_size);
+  return iree_hal_profile_sink_write(sink, &metadata, 1, &iovec);
+}
+
+iree_status_t iree_hal_amdgpu_profile_counter_session_write_metadata(
+    const iree_hal_amdgpu_profile_counter_session_t* session,
+    iree_hal_profile_sink_t* sink, uint64_t session_id,
+    iree_string_view_t stream_name) {
+  if (!iree_hal_amdgpu_profile_counter_session_is_active(session)) {
+    return iree_ok_status();
+  }
+  IREE_TRACE_ZONE_BEGIN(z0);
+
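+  // Size and pack both metadata streams into host staging buffers so each
+  // content type can be written to the sink as a single chunk.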
+  iree_host_size_t counter_set_storage_size = 0;
+  iree_host_size_t counter_storage_size = 0;
+  IREE_RETURN_AND_END_ZONE_IF_ERROR(
+      z0, iree_hal_amdgpu_profile_counter_metadata_size(
+              session, &counter_set_storage_size, &counter_storage_size));
+
+  uint8_t* counter_set_storage = NULL;
+  IREE_RETURN_AND_END_ZONE_IF_ERROR(
+      z0,
+      iree_allocator_malloc(session->host_allocator, counter_set_storage_size,
+                            (void**)&counter_set_storage));
+  uint8_t* counter_storage = NULL;
+  iree_status_t status = iree_allocator_malloc(
+      session->host_allocator, counter_storage_size, (void**)&counter_storage);
+
+  const iree_host_size_t counter_set_total_count =
+      session->physical_device_count * session->counter_set_count;
+  uint8_t* counter_set_ptr = counter_set_storage;
+  uint8_t* counter_ptr = counter_storage;
+  for (iree_host_size_t i = 0;
+       i < counter_set_total_count && iree_status_is_ok(status); ++i) {
+    const iree_hal_amdgpu_profile_counter_set_t* counter_set =
+        &session->counter_sets[i];
+    status = iree_hal_amdgpu_profile_counter_pack_counter_set_record(
+        counter_set, &counter_set_ptr);
+    for (uint32_t j = 0;
+         j < counter_set->counter_count && iree_status_is_ok(status); ++j) {
+      status = iree_hal_amdgpu_profile_counter_pack_counter_record(
+          counter_set, &counter_set->counters[j], &counter_ptr);
+    }
+  }
+
+  if (iree_status_is_ok(status)) {
+    status = iree_hal_amdgpu_profile_counter_write_metadata_chunk(
+        sink, session_id, stream_name,
+        IREE_HAL_PROFILE_CONTENT_TYPE_COUNTER_SETS, counter_set_storage,
+        counter_set_storage_size);
+  }
+  if (iree_status_is_ok(status)) {
+    status = iree_hal_amdgpu_profile_counter_write_metadata_chunk(
+        sink, session_id, stream_name, IREE_HAL_PROFILE_CONTENT_TYPE_COUNTERS,
+        counter_storage, counter_storage_size);
+  }
+
+  iree_allocator_free(session->host_allocator, counter_storage);
+  iree_allocator_free(session->host_allocator, counter_set_storage);
+  IREE_TRACE_ZONE_END(z0);
+  return status;
+}
+
+//===----------------------------------------------------------------------===//
+// Queue-local counter sample slots
+//===----------------------------------------------------------------------===//
+
+static const iree_hal_amdgpu_profile_counter_set_t*
+iree_hal_amdgpu_profile_counter_session_counter_set(
+    const iree_hal_amdgpu_profile_counter_session_t* session,
+    iree_host_size_t device_ordinal, uint32_t counter_set_ordinal) {
+  if (!session || counter_set_ordinal >= session->counter_set_count ||
+      device_ordinal >= session->physical_device_count) {
+    return NULL;
+  }
+  return &session->counter_sets[device_ordinal * session->counter_set_count +
+                                counter_set_ordinal];
+}
+
+static const iree_hal_amdgpu_profile_counter_set_t*
+iree_hal_amdgpu_host_queue_profile_counter_set(
+    const iree_hal_amdgpu_host_queue_t* queue, uint32_t counter_set_ordinal) {
+  const iree_hal_amdgpu_profile_counter_session_t* session =
+      queue->profiling.counters.session;
+  return iree_hal_amdgpu_profile_counter_session_counter_set(
+      session, queue->device_ordinal, counter_set_ordinal);
+}
+
+static iree_hal_amdgpu_physical_device_t*
+iree_hal_amdgpu_host_queue_profile_counter_physical_device(
+    const iree_hal_amdgpu_host_queue_t* queue) {
+  iree_hal_amdgpu_logical_device_t* logical_device =
+      (iree_hal_amdgpu_logical_device_t*)queue->logical_device;
+  return logical_device->physical_devices[queue->device_ordinal];
+}
+
+static void iree_hal_amdgpu_host_queue_initialize_counter_packet_set(
+    const iree_hal_amdgpu_host_queue_t* queue,
+    iree_hal_amdgpu_profile_counter_session_t* session,
+    iree_hal_amdgpu_profile_counter_packet_set_t* packet_set) {
+  iree_hal_amdgpu_physical_device_t* physical_device =
+      iree_hal_amdgpu_host_queue_profile_counter_physical_device(queue);
+  packet_set->memory_context =
+      (iree_hal_amdgpu_profile_aqlprofile_memory_context_t){
+          .libhsa = session->libhsa,
+          .device_agent = physical_device->device_agent,
+          .host_memory_pools = &physical_device->host_memory_pools,
+          .device_coarse_pool =
+              physical_device->coarse_block_pools.large.memory_pool,
+      };
+}
+
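+// Lazily creates the aqlprofile packet set for a sample slot on first use;
+// the packets are cached on the slot and reused for subsequent samples.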
+static iree_status_t iree_hal_amdgpu_host_queue_ensure_counter_packet_set(
+    iree_hal_amdgpu_host_queue_t* queue,
+    iree_hal_amdgpu_profile_counter_session_t* session,
+    const iree_hal_amdgpu_profile_counter_set_t* counter_set,
+    iree_hal_amdgpu_profile_counter_packet_set_t* packet_set) {
+  if (packet_set->handle.handle) return iree_ok_status();
+  iree_hal_amdgpu_host_queue_initialize_counter_packet_set(queue, session,
+                                                           packet_set);
+  return iree_hal_amdgpu_profile_counter_create_packets(
+      &session->libaqlprofile, counter_set, &packet_set->memory_context,
+      &packet_set->handle, &packet_set->packets);
+}
+
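+// Maps (event position, counter set ordinal) to a sample slot; slots are
+// laid out per dispatch event with one slot per counter set.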
+static iree_hal_amdgpu_profile_counter_sample_slot_t*
+iree_hal_amdgpu_host_queue_profile_counter_slot(
+    iree_hal_amdgpu_host_queue_t* queue, uint64_t event_position,
+    uint32_t counter_set_ordinal) {
+  const uint32_t event_index =
+      iree_hal_amdgpu_host_queue_profile_dispatch_event_index(queue,
+                                                              event_position);
+  const iree_host_size_t slot_index =
+      (iree_host_size_t)event_index * queue->profiling.counters.set_count +
+      counter_set_ordinal;
+  return &queue->profiling.counters.dispatch_samples.slots[slot_index];
+}
+
+iree_status_t iree_hal_amdgpu_host_queue_enable_profile_counters(
+    iree_hal_amdgpu_host_queue_t* queue,
+    iree_hal_amdgpu_profile_counter_session_t* session,
+    iree_hal_amdgpu_profile_counter_enable_flags_t flags) {
+  if (!iree_hal_amdgpu_profile_counter_session_is_active(session)) {
+    return iree_ok_status();
+  }
+  const bool enable_dispatch_samples =
+      iree_any_bit_set(
+          flags,
+          IREE_HAL_AMDGPU_PROFILE_COUNTER_ENABLE_FLAG_DISPATCH_SAMPLES) &&
+      iree_hal_amdgpu_profile_counter_session_captures_dispatch_samples(
+          session);
+  const bool enable_queue_ranges =
+      iree_any_bit_set(
+          flags, IREE_HAL_AMDGPU_PROFILE_COUNTER_ENABLE_FLAG_QUEUE_RANGES) &&
+      iree_hal_amdgpu_profile_counter_session_captures_queue_ranges(session);
+  if (!enable_dispatch_samples && !enable_queue_ranges) {
+    return iree_ok_status();
+  }
+  IREE_TRACE_ZONE_BEGIN(z0);
+
+  if (IREE_UNLIKELY(!iree_any_bit_set(
+          queue->vendor_packet_capabilities,
+          IREE_HAL_AMDGPU_VENDOR_PACKET_CAPABILITY_AQL_PM4_IB))) {
+    IREE_TRACE_ZONE_END(z0);
+    return iree_make_status(
+        IREE_STATUS_FAILED_PRECONDITION,
+        "AMDGPU counter profiling requires AQL PM4-IB packet support");
+  }
+  if (IREE_UNLIKELY(queue->profiling.counters.session)) {
+    IREE_TRACE_ZONE_END(z0);
+    return iree_make_status(IREE_STATUS_FAILED_PRECONDITION,
+                            "AMDGPU counter profiling is already enabled");
+  }
+
+  iree_hal_amdgpu_profile_counter_sample_slot_t* dispatch_sample_slots = NULL;
+  if (enable_dispatch_samples) {
+    const uint32_t dispatch_event_capacity =
+        iree_hal_amdgpu_host_queue_profile_dispatch_event_capacity(queue);
+    if (IREE_UNLIKELY(!dispatch_event_capacity)) {
+      IREE_TRACE_ZONE_END(z0);
+      return iree_make_status(
+          IREE_STATUS_FAILED_PRECONDITION,
+          "AMDGPU counter profiling requires dispatch event storage");
+    }
+
+    iree_host_size_t slot_count = 0;
+    if (IREE_UNLIKELY(!iree_host_size_checked_mul(dispatch_event_capacity,
+                                                  session->counter_set_count,
+                                                  &slot_count))) {
+      IREE_TRACE_ZONE_END(z0);
+      return iree_make_status(IREE_STATUS_OUT_OF_RANGE,
+                              "AMDGPU counter sample slot count overflow");
+    }
+    iree_host_size_t slot_storage_size = 0;
+    IREE_RETURN_AND_END_ZONE_IF_ERROR(
+        z0, IREE_STRUCT_LAYOUT(
+                0, &slot_storage_size,
+                IREE_STRUCT_FIELD(slot_count,
+                                  iree_hal_amdgpu_profile_counter_sample_slot_t,
+                                  NULL)));
+    IREE_RETURN_AND_END_ZONE_IF_ERROR(
+        z0, iree_allocator_malloc(queue->host_allocator, slot_storage_size,
+                                  (void**)&dispatch_sample_slots));
+    memset(dispatch_sample_slots, 0, slot_storage_size);
+  }
+
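+  // Queue ranges use multiple slot banks so one bank can be flushed and
+  // collected while counting restarts on the next (see the flush path below).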
+  iree_hal_amdgpu_profile_counter_range_slot_t* range_slots = NULL;
+  uint64_t* range_ticks = NULL;
+  iree_host_size_t range_tick_storage_size = 0;
+  if (enable_queue_ranges) {
+    iree_host_size_t range_slot_count = 0;
+    if (IREE_UNLIKELY(!iree_host_size_checked_mul(
+            iree_hal_amdgpu_profile_counter_range_bank_count,
+            session->counter_set_count, &range_slot_count))) {
+      iree_allocator_free(queue->host_allocator, dispatch_sample_slots);
+      IREE_TRACE_ZONE_END(z0);
+      return iree_make_status(IREE_STATUS_OUT_OF_RANGE,
+                              "AMDGPU counter range slot count overflow");
+    }
+    iree_host_size_t range_slot_storage_size = 0;
+    iree_status_t status = IREE_STRUCT_LAYOUT(
+        0, &range_slot_storage_size,
+        IREE_STRUCT_FIELD(range_slot_count,
+                          iree_hal_amdgpu_profile_counter_range_slot_t, NULL));
+    if (iree_status_is_ok(status)) {
+      status = iree_allocator_malloc(
+          queue->host_allocator, range_slot_storage_size, (void**)&range_slots);
+    }
+    if (iree_status_is_ok(status)) {
+      memset(range_slots, 0, range_slot_storage_size);
+      status = IREE_STRUCT_LAYOUT(
+          0, &range_tick_storage_size,
+          IREE_STRUCT_FIELD(iree_hal_amdgpu_profile_counter_range_bank_count,
+                            iree_hal_amdgpu_profile_counter_range_ticks_t,
+                            NULL));
+    }
+    if (iree_status_is_ok(status)) {
+      status = iree_hsa_amd_memory_pool_allocate(
+          IREE_LIBHSA(queue->libhsa),
+          queue->profiling.signals.block_pool->memory_pool,
+          range_tick_storage_size, HSA_AMD_MEMORY_POOL_STANDARD_FLAG,
+          (void**)&range_ticks);
+    }
+    if (iree_status_is_ok(status)) {
+      memset(range_ticks, 0, range_tick_storage_size);
+      for (iree_host_size_t i = 0;
+           i < range_slot_count && iree_status_is_ok(status); ++i) {
+        const uint32_t counter_set_ordinal =
+            (uint32_t)(i % session->counter_set_count);
+        const iree_hal_amdgpu_profile_counter_set_t* counter_set =
+            iree_hal_amdgpu_profile_counter_session_counter_set(
+                session, queue->device_ordinal, counter_set_ordinal);
+        if (IREE_UNLIKELY(!counter_set)) {
+          status = iree_make_status(IREE_STATUS_FAILED_PRECONDITION,
+                                    "AMDGPU counter set is not available");
+        } else {
+          iree_hal_amdgpu_host_queue_initialize_counter_packet_set(
+              queue, session, &range_slots[i].packet_set);
+          status = iree_hal_amdgpu_profile_counter_create_packets(
+              &session->libaqlprofile, counter_set,
+              &range_slots[i].packet_set.memory_context,
+              &range_slots[i].packet_set.handle,
+              &range_slots[i].packet_set.packets);
+        }
+      }
+      if (!iree_status_is_ok(status)) {
+        for (iree_host_size_t i = 0; i < range_slot_count; ++i) {
+          iree_hal_amdgpu_profile_counter_destroy_packet_set(
+              &session->libaqlprofile, &range_slots[i].packet_set);
+        }
+      }
+    } else {
+      status = iree_status_annotate(
+          status, IREE_SV("allocating AMDGPU counter range storage"));
+    }
+    if (!iree_status_is_ok(status)) {
+      if (range_ticks) {
+        status = iree_status_join(
+            status, iree_hsa_amd_memory_pool_free(IREE_LIBHSA(queue->libhsa),
+                                                  range_ticks));
+      }
+      iree_allocator_free(queue->host_allocator, range_slots);
+      iree_allocator_free(queue->host_allocator, dispatch_sample_slots);
+      IREE_TRACE_ZONE_END(z0);
+      return status;
+    }
+  }
+
+  queue->profiling.counters.session = session;
+  queue->profiling.counters.set_count = session->counter_set_count;
+  queue->profiling.counters.dispatch_samples.slots = dispatch_sample_slots;
+  queue->profiling.counters.ranges.slots = range_slots;
+  queue->profiling.counters.ranges.ticks = range_ticks;
+  queue->profiling.counters.ranges.tick_storage_size = range_tick_storage_size;
+  queue->profiling.counters.ranges.active_bank = 0;
+  queue->profiling.counters.ranges.bank_count =
+      range_slots ? iree_hal_amdgpu_profile_counter_range_bank_count : 0;
+  queue->profiling.counters.ranges.is_active = false;
+
+  IREE_TRACE_ZONE_END(z0);
+  return iree_ok_status();
+}
+
+void iree_hal_amdgpu_host_queue_disable_profile_counters(
+    iree_hal_amdgpu_host_queue_t* queue) {
+  if (!queue->profiling.counters.session) return;
+  IREE_TRACE_ZONE_BEGIN(z0);
+
+  iree_hal_amdgpu_profile_counter_session_t* session =
+      queue->profiling.counters.session;
+  if (queue->profiling.counters.dispatch_samples.slots) {
+    const iree_host_size_t slot_count =
+        (iree_host_size_t)
+            iree_hal_amdgpu_host_queue_profile_dispatch_event_capacity(queue) *
+        queue->profiling.counters.set_count;
+    for (iree_host_size_t i = 0; i < slot_count; ++i) {
+      iree_hal_amdgpu_profile_counter_destroy_packet_set(
+          &session->libaqlprofile,
+          &queue->profiling.counters.dispatch_samples.slots[i].packet_set);
+    }
+    iree_allocator_free(queue->host_allocator,
+                        queue->profiling.counters.dispatch_samples.slots);
+  }
+  if (queue->profiling.counters.ranges.slots) {
+    const iree_host_size_t slot_count =
+        (iree_host_size_t)queue->profiling.counters.ranges.bank_count *
+        queue->profiling.counters.set_count;
+    for (iree_host_size_t i = 0; i < slot_count; ++i) {
+      iree_hal_amdgpu_profile_counter_destroy_packet_set(
+          &session->libaqlprofile,
+          &queue->profiling.counters.ranges.slots[i].packet_set);
+    }
+    iree_allocator_free(queue->host_allocator,
+                        queue->profiling.counters.ranges.slots);
+  }
+  if (queue->profiling.counters.ranges.ticks) {
+    iree_hal_amdgpu_hsa_cleanup_assert_success(
+        iree_hsa_amd_memory_pool_free_raw(
+            queue->libhsa, queue->profiling.counters.ranges.ticks));
+  }
+  queue->profiling.counters.session = NULL;
+  queue->profiling.counters.set_count = 0;
+  queue->profiling.counters.dispatch_samples.slots = NULL;
+  queue->profiling.counters.ranges.slots = NULL;
+  queue->profiling.counters.ranges.ticks = NULL;
+  queue->profiling.counters.ranges.tick_storage_size = 0;
+  queue->profiling.counters.ranges.active_bank = 0;
+  queue->profiling.counters.ranges.bank_count = 0;
+  queue->profiling.counters.ranges.is_active = false;
+
+  IREE_TRACE_ZONE_END(z0);
+}
+
+uint32_t iree_hal_amdgpu_host_queue_profile_counter_packet_count(
+    const iree_hal_amdgpu_host_queue_t* queue,
+    iree_hal_amdgpu_profile_dispatch_event_reservation_t reservation) {
+  if (!reservation.event_count ||
+      !queue->profiling.counters.dispatch_samples.slots) {
+    return 0;
+  }
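+  // NOTE: no overflow check; reservations are assumed to be bounded by the
+  // dispatch event ring capacity, keeping this product within uint32_t.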
+  return reservation.event_count * queue->profiling.counters.set_count *
+         iree_hal_amdgpu_profile_counter_packets_per_set;
+}
+
+uint32_t iree_hal_amdgpu_host_queue_profile_counter_set_count(
+    const iree_hal_amdgpu_host_queue_t* queue,
+    iree_hal_amdgpu_profile_dispatch_event_reservation_t reservation) {
+  if (!reservation.event_count ||
+      !queue->profiling.counters.dispatch_samples.slots) {
+    return 0;
+  }
+  return queue->profiling.counters.set_count;
+}
+
+iree_status_t iree_hal_amdgpu_host_queue_prepare_profile_counter_samples(
+    iree_hal_amdgpu_host_queue_t* queue,
+    iree_hal_amdgpu_profile_dispatch_event_reservation_t reservation) {
+  iree_hal_amdgpu_profile_counter_session_t* session =
+      queue->profiling.counters.session;
+  if (!reservation.event_count ||
+      !queue->profiling.counters.dispatch_samples.slots) {
+    return iree_ok_status();
+  }
+
+  for (uint32_t event_ordinal = 0; event_ordinal < reservation.event_count;
+       ++event_ordinal) {
+    const uint64_t event_position =
+        reservation.first_event_position + event_ordinal;
+    for (uint32_t counter_set_ordinal = 0;
+         counter_set_ordinal < queue->profiling.counters.set_count;
+         ++counter_set_ordinal) {
+      const iree_hal_amdgpu_profile_counter_set_t* counter_set =
+          iree_hal_amdgpu_host_queue_profile_counter_set(queue,
+                                                         counter_set_ordinal);
+      if (IREE_UNLIKELY(!counter_set)) {
+        return iree_make_status(IREE_STATUS_FAILED_PRECONDITION,
+                                "AMDGPU counter set is not available");
+      }
+      iree_hal_amdgpu_profile_counter_sample_slot_t* slot =
+          iree_hal_amdgpu_host_queue_profile_counter_slot(queue, event_position,
+                                                          counter_set_ordinal);
+      IREE_RETURN_IF_ERROR(iree_hal_amdgpu_host_queue_ensure_counter_packet_set(
+          queue, session, counter_set, &slot->packet_set));
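+      // Sample id 0 is treated as the "unprepared" sentinel at flush time;
+      // the session's id counter is expected to never yield 0.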
+      slot->sample_id = (uint64_t)iree_atomic_fetch_add(
+          &session->next_sample_id, 1, iree_memory_order_relaxed);
+    }
+  }
+  return iree_ok_status();
+}
+
+static void iree_hal_amdgpu_host_queue_emplace_profile_counter_packet(
+    const iree_hsa_amd_aql_pm4_ib_packet_t* source_packet,
+    iree_hal_amdgpu_aql_packet_t* packet,
+    iree_hal_amdgpu_aql_packet_control_t packet_control,
+    iree_hsa_signal_t completion_signal, uint16_t* out_header,
+    uint16_t* out_setup) {
+  iree_hal_amdgpu_profile_aqlprofile_emplace_pm4_ib_packet(
+      source_packet, packet, packet_control, completion_signal, out_header,
+      out_setup);
+}
+
+static void iree_hal_amdgpu_host_queue_emplace_profile_counter_packet_at(
+    iree_hal_amdgpu_host_queue_t* queue,
+    const iree_hsa_amd_aql_pm4_ib_packet_t* source_packet,
+    uint64_t first_packet_id, uint32_t packet_index,
+    iree_hal_amdgpu_aql_packet_control_t packet_control,
+    uint16_t* packet_headers, uint16_t* packet_setups) {
+  iree_hal_amdgpu_aql_packet_t* packet = iree_hal_amdgpu_aql_ring_packet(
+      &queue->aql_ring, first_packet_id + packet_index);
+  iree_hal_amdgpu_host_queue_emplace_profile_counter_packet(
+      source_packet, packet, packet_control, iree_hsa_signal_null(),
+      &packet_headers[packet_index], &packet_setups[packet_index]);
+}
+
+void iree_hal_amdgpu_host_queue_emplace_profile_counter_start_packets(
+    iree_hal_amdgpu_host_queue_t* queue, uint64_t event_position,
+    uint32_t counter_set_count, uint64_t first_packet_id,
+    uint32_t first_packet_index,
+    iree_hal_amdgpu_aql_packet_control_t packet_control,
+    uint16_t* packet_headers, uint16_t* packet_setups) {
+  for (uint32_t i = 0; i < counter_set_count; ++i) {
+    iree_hal_amdgpu_profile_counter_sample_slot_t* slot =
+        iree_hal_amdgpu_host_queue_profile_counter_slot(queue, event_position,
+                                                        i);
+    iree_hal_amdgpu_host_queue_emplace_profile_counter_packet_at(
+        queue, &slot->packet_set.packets.start_packet, first_packet_id,
+        first_packet_index + i, packet_control, packet_headers, packet_setups);
+  }
+}
+
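+// Each counter set contributes a read packet followed by a stop packet,
+// hence the 2x stride on the packet index.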
+void iree_hal_amdgpu_host_queue_emplace_profile_counter_read_stop_packets(
+    iree_hal_amdgpu_host_queue_t* queue, uint64_t event_position,
+    uint32_t counter_set_count, uint64_t first_packet_id,
+    uint32_t first_packet_index,
+    iree_hal_amdgpu_aql_packet_control_t packet_control,
+    uint16_t* packet_headers, uint16_t* packet_setups) {
+  for (uint32_t i = 0; i < counter_set_count; ++i) {
+    iree_hal_amdgpu_profile_counter_sample_slot_t* slot =
+        iree_hal_amdgpu_host_queue_profile_counter_slot(queue, event_position,
+                                                        i);
+    iree_hal_amdgpu_host_queue_emplace_profile_counter_packet_at(
+        queue, &slot->packet_set.packets.read_packet, first_packet_id,
+        first_packet_index + i * 2u, packet_control, packet_headers,
+        packet_setups);
+    iree_hal_amdgpu_host_queue_emplace_profile_counter_packet_at(
+        queue, &slot->packet_set.packets.stop_packet, first_packet_id,
+        first_packet_index + i * 2u + 1u, packet_control, packet_headers,
+        packet_setups);
+  }
+}
+
+static void iree_hal_amdgpu_host_queue_commit_profile_counter_packet(
+    iree_hal_amdgpu_host_queue_t* queue,
+    const iree_hsa_amd_aql_pm4_ib_packet_t* source_packet, uint64_t packet_id,
+    iree_hal_amdgpu_aql_packet_control_t packet_control,
+    iree_hsa_signal_t completion_signal) {
+  iree_hal_amdgpu_aql_packet_t* packet =
+      iree_hal_amdgpu_aql_ring_packet(&queue->aql_ring, packet_id);
+  uint16_t header = 0;
+  uint16_t setup = 0;
+  iree_hal_amdgpu_host_queue_emplace_profile_counter_packet(
+      source_packet, packet, packet_control, completion_signal, &header,
+      &setup);
+  iree_hal_amdgpu_aql_ring_commit(packet, header, setup);
+}
+
+void iree_hal_amdgpu_host_queue_commit_profile_counter_start_packets(
+    iree_hal_amdgpu_host_queue_t* queue, uint64_t event_position,
+    uint32_t counter_set_count, uint64_t first_packet_id,
+    iree_hal_amdgpu_aql_packet_control_t packet_control) {
+  for (uint32_t i = 0; i < counter_set_count; ++i) {
+    iree_hal_amdgpu_profile_counter_sample_slot_t* slot =
+        iree_hal_amdgpu_host_queue_profile_counter_slot(queue, event_position,
+                                                        i);
+    iree_hal_amdgpu_host_queue_commit_profile_counter_packet(
+        queue, &slot->packet_set.packets.start_packet, first_packet_id + i,
+        packet_control, iree_hsa_signal_null());
+  }
+}
+
+void iree_hal_amdgpu_host_queue_commit_profile_counter_read_stop_packets(
+    iree_hal_amdgpu_host_queue_t* queue, uint64_t event_position,
+    uint32_t counter_set_count, uint64_t first_packet_id,
+    iree_hal_amdgpu_aql_packet_control_t packet_control) {
+  for (uint32_t i = 0; i < counter_set_count; ++i) {
+    iree_hal_amdgpu_profile_counter_sample_slot_t* slot =
+        iree_hal_amdgpu_host_queue_profile_counter_slot(queue, event_position,
+                                                        i);
+    iree_hal_amdgpu_host_queue_commit_profile_counter_packet(
+        queue, &slot->packet_set.packets.read_packet, first_packet_id + i * 2u,
+        packet_control, iree_hsa_signal_null());
+    iree_hal_amdgpu_host_queue_commit_profile_counter_packet(
+        queue, &slot->packet_set.packets.stop_packet,
+        first_packet_id + i * 2u + 1u, packet_control, iree_hsa_signal_null());
+  }
+}
+
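+// Range slots are laid out per bank with one slot per counter set.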
+static iree_hal_amdgpu_profile_counter_range_slot_t*
+iree_hal_amdgpu_host_queue_profile_counter_range_slot(
+    iree_hal_amdgpu_host_queue_t* queue, uint32_t bank,
+    uint32_t counter_set_ordinal) {
+  const iree_host_size_t slot_index =
+      (iree_host_size_t)bank * queue->profiling.counters.set_count +
+      counter_set_ordinal;
+  return &queue->profiling.counters.ranges.slots[slot_index];
+}
+
+static iree_hal_amdgpu_profile_counter_range_ticks_t*
+iree_hal_amdgpu_host_queue_profile_counter_range_ticks(
+    iree_hal_amdgpu_host_queue_t* queue, uint32_t bank) {
+  iree_hal_amdgpu_profile_counter_range_ticks_t* ticks =
+      (iree_hal_amdgpu_profile_counter_range_ticks_t*)
+          queue->profiling.counters.ranges.ticks;
+  return &ticks[bank];
+}
+
+static iree_status_t
+iree_hal_amdgpu_host_queue_profile_counter_range_start_packet_count(
+    const iree_hal_amdgpu_host_queue_t* queue, uint32_t* out_packet_count) {
+  if (IREE_UNLIKELY(queue->profiling.counters.set_count == UINT32_MAX)) {
+    return iree_make_status(IREE_STATUS_OUT_OF_RANGE,
+                            "AMDGPU counter range packet count overflow");
+  }
+  *out_packet_count = queue->profiling.counters.set_count + 1u;
+  return iree_ok_status();
+}
+
+static iree_status_t
+iree_hal_amdgpu_host_queue_profile_counter_range_flush_packet_count(
+    const iree_hal_amdgpu_host_queue_t* queue,
+    iree_hal_amdgpu_profile_counter_range_flush_flags_t flags,
+    uint32_t* out_packet_count) {
+  const uint32_t set_count = queue->profiling.counters.set_count;
+  if (IREE_UNLIKELY(set_count > (UINT32_MAX - 1u) / 2u)) {
+    return iree_make_status(IREE_STATUS_OUT_OF_RANGE,
+                            "AMDGPU counter range packet count overflow");
+  }
+  uint32_t packet_count = 1u + set_count * 2u;
+  if (iree_any_bit_set(
+          flags, IREE_HAL_AMDGPU_PROFILE_COUNTER_RANGE_FLUSH_FLAG_RESTART)) {
+    uint32_t start_packet_count = 0;
+    IREE_RETURN_IF_ERROR(
+        iree_hal_amdgpu_host_queue_profile_counter_range_start_packet_count(
+            queue, &start_packet_count));
+    if (IREE_UNLIKELY(packet_count > UINT32_MAX - start_packet_count)) {
+      return iree_make_status(IREE_STATUS_OUT_OF_RANGE,
+                              "AMDGPU counter range packet count overflow");
+    }
+    packet_count += start_packet_count;
+  }
+  *out_packet_count = packet_count;
+  return iree_ok_status();
+}
+
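+// Commits one start packet per counter set followed by a timestamp-start
+// packet (set_count + 1 total, matching the reserved packet count).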
+static void iree_hal_amdgpu_host_queue_commit_profile_counter_range_start(
+    iree_hal_amdgpu_host_queue_t* queue, uint32_t bank,
+    uint64_t first_packet_id) {
+  const iree_hal_amdgpu_aql_packet_control_t packet_control =
+      iree_hal_amdgpu_aql_packet_control_barrier(IREE_HSA_FENCE_SCOPE_AGENT,
+                                                 IREE_HSA_FENCE_SCOPE_AGENT);
+  for (uint32_t i = 0; i < queue->profiling.counters.set_count; ++i) {
+    iree_hal_amdgpu_profile_counter_range_slot_t* slot =
+        iree_hal_amdgpu_host_queue_profile_counter_range_slot(queue, bank, i);
+    slot->sample_id = (uint64_t)iree_atomic_fetch_add(
+        &queue->profiling.counters.session->next_sample_id, 1,
+        iree_memory_order_relaxed);
+    iree_hal_amdgpu_host_queue_commit_profile_counter_packet(
+        queue, &slot->packet_set.packets.start_packet, first_packet_id + i,
+        packet_control, iree_hsa_signal_null());
+  }
+  iree_hal_amdgpu_profile_counter_range_ticks_t* ticks =
+      iree_hal_amdgpu_host_queue_profile_counter_range_ticks(queue, bank);
+  ticks->start_tick = 0;
+  ticks->end_tick = 0;
+  iree_hal_amdgpu_host_queue_commit_timestamp_start(
+      queue, first_packet_id + queue->profiling.counters.set_count,
+      packet_control, &ticks->start_tick);
+}
+
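+// Commits a timestamp-end packet followed by a read/stop pair per counter
+// set (1 + 2 * set_count total); the final stop packet carries the
+// system-scope release fence and the completion signal so the host can
+// observe the collected values.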
+static void iree_hal_amdgpu_host_queue_commit_profile_counter_range_flush(
+    iree_hal_amdgpu_host_queue_t* queue, uint32_t bank,
+    iree_hsa_signal_t completion_signal, uint64_t first_packet_id) {
+  const iree_hal_amdgpu_aql_packet_control_t packet_control =
+      iree_hal_amdgpu_aql_packet_control_barrier(IREE_HSA_FENCE_SCOPE_AGENT,
+                                                 IREE_HSA_FENCE_SCOPE_AGENT);
+  const iree_hal_amdgpu_aql_packet_control_t final_packet_control =
+      iree_hal_amdgpu_aql_packet_control_barrier(IREE_HSA_FENCE_SCOPE_AGENT,
+                                                 IREE_HSA_FENCE_SCOPE_SYSTEM);
+  iree_hal_amdgpu_profile_counter_range_ticks_t* ticks =
+      iree_hal_amdgpu_host_queue_profile_counter_range_ticks(queue, bank);
+  uint64_t packet_id = first_packet_id;
+  iree_hal_amdgpu_host_queue_commit_timestamp_end(
+      queue, packet_id++, packet_control, iree_hsa_signal_null(),
+      &ticks->end_tick);
+  for (uint32_t i = 0; i < queue->profiling.counters.set_count; ++i) {
+    iree_hal_amdgpu_profile_counter_range_slot_t* slot =
+        iree_hal_amdgpu_host_queue_profile_counter_range_slot(queue, bank, i);
+    iree_hal_amdgpu_host_queue_commit_profile_counter_packet(
+        queue, &slot->packet_set.packets.read_packet, packet_id++,
+        packet_control, iree_hsa_signal_null());
+    const bool is_final_stop = i + 1u == queue->profiling.counters.set_count;
+    iree_hal_amdgpu_host_queue_commit_profile_counter_packet(
+        queue, &slot->packet_set.packets.stop_packet, packet_id++,
+        is_final_stop ? final_packet_control : packet_control,
+        is_final_stop ? completion_signal : iree_hsa_signal_null());
+  }
+}
+
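+// Blocks until the flush completion signal reaches zero, re-checking the
+// queue error status roughly once per millisecond so a failed queue cannot
+// hang the flush forever.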
+static iree_status_t iree_hal_amdgpu_host_queue_wait_profile_counter_range(
+    iree_hal_amdgpu_host_queue_t* queue, iree_hsa_signal_t signal) {
+  iree_hal_amdgpu_logical_device_t* logical_device =
+      (iree_hal_amdgpu_logical_device_t*)queue->logical_device;
+  uint64_t wait_timeout_hint =
+      logical_device->system->info.timestamp_frequency / 1000;
+  if (wait_timeout_hint == 0) wait_timeout_hint = 1;
+
+  for (;;) {
+    const hsa_signal_value_t signal_value = iree_hsa_signal_wait_scacquire(
+        IREE_LIBHSA(queue->libhsa), signal, HSA_SIGNAL_CONDITION_EQ, 0,
+        wait_timeout_hint, HSA_WAIT_STATE_BLOCKED);
+    if (signal_value == 0) return iree_ok_status();
+
+    iree_status_t queue_error = (iree_status_t)iree_atomic_load(
+        &queue->error_status, iree_memory_order_acquire);
+    if (IREE_UNLIKELY(!iree_status_is_ok(queue_error))) {
+      return iree_status_clone(queue_error);
+    }
+  }
+}
+
+iree_status_t iree_hal_amdgpu_host_queue_start_profile_counter_ranges(
+    iree_hal_amdgpu_host_queue_t* queue) {
+  if (!queue->profiling.counters.ranges.slots ||
+      queue->profiling.counters.ranges.is_active) {
+    return iree_ok_status();
+  }
+
+  uint32_t packet_count = 0;
+  IREE_RETURN_IF_ERROR(
+      iree_hal_amdgpu_host_queue_profile_counter_range_start_packet_count(
+          queue, &packet_count));
+  iree_slim_mutex_lock(&queue->locks.submission_mutex);
+  iree_status_t status = iree_ok_status();
+  if (IREE_UNLIKELY(queue->is_shutting_down)) {
+    status = iree_make_status(IREE_STATUS_CANCELLED, "queue shutting down");
+  }
+  if (iree_status_is_ok(status)) {
+    const uint32_t bank = queue->profiling.counters.ranges.active_bank;
+    const uint64_t first_packet_id =
+        iree_hal_amdgpu_aql_ring_reserve(&queue->aql_ring, packet_count);
+    iree_hal_amdgpu_host_queue_commit_profile_counter_range_start(
+        queue, bank, first_packet_id);
+    iree_hal_amdgpu_aql_ring_doorbell(&queue->aql_ring,
+                                      first_packet_id + packet_count - 1u);
+    queue->profiling.counters.ranges.is_active = true;
+  }
+  iree_slim_mutex_unlock(&queue->locks.submission_mutex);
+  return status;
+}
+
+static iree_status_t iree_hal_amdgpu_profile_counter_sample_record_size(
+    uint32_t sample_value_count, iree_host_size_t* out_record_size) {
+  return IREE_STRUCT_LAYOUT(
+      0, out_record_size,
+      IREE_STRUCT_FIELD(1, iree_hal_profile_counter_sample_record_t, NULL),
+      IREE_STRUCT_FIELD(sample_value_count, uint64_t, NULL));
+}
+
+static iree_status_t
+iree_hal_amdgpu_host_queue_profile_counter_sample_storage_size(
+    iree_hal_amdgpu_host_queue_t* queue, iree_host_size_t event_count,
+    iree_host_size_t* out_storage_size) {
+  iree_host_size_t per_event_storage_size = 0;
+  for (uint32_t counter_set_ordinal = 0;
+       counter_set_ordinal < queue->profiling.counters.set_count;
+       ++counter_set_ordinal) {
+    const iree_hal_amdgpu_profile_counter_set_t* counter_set =
+        iree_hal_amdgpu_host_queue_profile_counter_set(queue,
+                                                       counter_set_ordinal);
+    if (IREE_UNLIKELY(!counter_set)) {
+      return iree_make_status(IREE_STATUS_FAILED_PRECONDITION,
+                              "AMDGPU counter set is not available");
+    }
+    iree_host_size_t record_size = 0;
+    IREE_RETURN_IF_ERROR(iree_hal_amdgpu_profile_counter_sample_record_size(
+        counter_set->sample_value_count, &record_size));
+    if (IREE_UNLIKELY(!iree_host_size_checked_add(
+            per_event_storage_size, record_size, &per_event_storage_size))) {
+      return iree_make_status(IREE_STATUS_OUT_OF_RANGE,
+                              "AMDGPU counter sample storage overflow");
+    }
+  }
+  if (IREE_UNLIKELY(!iree_host_size_checked_mul(
+          per_event_storage_size, event_count, out_storage_size))) {
+    return iree_make_status(IREE_STATUS_OUT_OF_RANGE,
+                            "AMDGPU counter sample storage overflow");
+  }
+  return iree_ok_status();
+}
+
+static iree_status_t
+iree_hal_amdgpu_host_queue_profile_counter_max_counter_count(
+    iree_hal_amdgpu_host_queue_t* queue, uint32_t* out_counter_count) {
+  uint32_t max_counter_count = 0;
+  for (uint32_t counter_set_ordinal = 0;
+       counter_set_ordinal < queue->profiling.counters.set_count;
+       ++counter_set_ordinal) {
+    const iree_hal_amdgpu_profile_counter_set_t* counter_set =
+        iree_hal_amdgpu_host_queue_profile_counter_set(queue,
+                                                       counter_set_ordinal);
+    if (IREE_UNLIKELY(!counter_set)) {
+      return iree_make_status(IREE_STATUS_FAILED_PRECONDITION,
+                              "AMDGPU counter set is not available");
+    }
+    max_counter_count = iree_max(max_counter_count, counter_set->counter_count);
+  }
+  *out_counter_count = max_counter_count;
+  return iree_ok_status();
+}
+
+static void iree_hal_amdgpu_profile_counter_initialize_sample_record(
+    const iree_hal_amdgpu_host_queue_t* queue,
+    const iree_hal_profile_dispatch_event_t* event,
+    const iree_hal_amdgpu_profile_counter_set_t* counter_set,
+    const iree_hal_amdgpu_profile_counter_sample_slot_t* slot,
+    iree_hal_profile_counter_sample_record_t* out_record) {
+  *out_record = iree_hal_profile_counter_sample_record_default();
+  out_record->flags = IREE_HAL_PROFILE_COUNTER_SAMPLE_FLAG_DISPATCH_EVENT |
+                      IREE_HAL_PROFILE_COUNTER_SAMPLE_FLAG_DEVICE_TICK_RANGE;
+  if (iree_any_bit_set(event->flags,
+                       IREE_HAL_PROFILE_DISPATCH_EVENT_FLAG_COMMAND_BUFFER)) {
+    out_record->flags |= IREE_HAL_PROFILE_COUNTER_SAMPLE_FLAG_COMMAND_OPERATION;
+  }
+  out_record->scope = IREE_HAL_PROFILE_COUNTER_SAMPLE_SCOPE_DISPATCH;
+  out_record->sample_id = slot->sample_id;
+  out_record->counter_set_id = counter_set->counter_set_id;
+  out_record->dispatch_event_id = event->event_id;
+  out_record->submission_id = event->submission_id;
+  out_record->command_buffer_id = event->command_buffer_id;
+  out_record->executable_id = event->executable_id;
+  out_record->stream_id = iree_hal_amdgpu_host_queue_profile_stream_id(queue);
+  out_record->start_tick = event->start_tick;
+  out_record->end_tick = event->end_tick;
+  out_record->command_index = event->command_index;
+  out_record->export_ordinal = event->export_ordinal;
+  out_record->physical_device_ordinal =
+      iree_hal_amdgpu_host_queue_profile_device_ordinal(queue);
+  out_record->queue_ordinal =
+      iree_hal_amdgpu_host_queue_profile_queue_ordinal(queue);
+  out_record->sample_value_count = counter_set->sample_value_count;
+}
+
+static iree_status_t iree_hal_amdgpu_profile_counter_collect_packet_set_values(
+    iree_hal_amdgpu_profile_counter_session_t* session,
+    const iree_hal_amdgpu_profile_counter_set_t* counter_set,
+    iree_hal_amdgpu_profile_counter_packet_set_t* packet_set,
+    uint32_t* counter_value_counts, uint64_t* values) {
+  memset(counter_value_counts, 0,
+         counter_set->counter_count * sizeof(counter_value_counts[0]));
+  iree_hal_amdgpu_profile_counter_collect_context_t collect_context = {
+      .counter_set = counter_set,
+      .values = values,
+      .counter_value_counts = counter_value_counts,
+  };
+  iree_status_t status = iree_status_from_aqlprofile_status(
+      &session->libaqlprofile, __FILE__, __LINE__,
+      session->libaqlprofile.aqlprofile_pmc_iterate_data(
+          packet_set->handle, iree_hal_amdgpu_profile_counter_collect_callback,
+          &collect_context),
+      "aqlprofile_pmc_iterate_data", "iterating AMDGPU counter sample values");
+  for (uint32_t i = 0;
+       i < counter_set->counter_count && iree_status_is_ok(status); ++i) {
+    const iree_hal_amdgpu_profile_counter_t* counter =
+        &counter_set->counters[i];
+    if (counter_value_counts[i] != counter->sample_value_count) {
+      status = iree_make_status(
+          IREE_STATUS_FAILED_PRECONDITION,
+          "aqlprofile returned %u values for AMDGPU counter '%.*s' but "
+          "metadata declared %u",
+          counter_value_counts[i], (int)counter->descriptor->name.size,
+          counter->descriptor->name.data, counter->sample_value_count);
+    }
+  }
+  return status;
+}
+
+static iree_status_t iree_hal_amdgpu_host_queue_pack_profile_counter_sample(
+    iree_hal_amdgpu_host_queue_t* queue,
+    iree_hal_amdgpu_profile_counter_session_t* session, uint64_t event_position,
+    const iree_hal_profile_dispatch_event_t* event,
+    uint32_t counter_set_ordinal, uint32_t* counter_value_counts,
+    uint8_t* storage, iree_host_size_t* out_record_size) {
+  *out_record_size = 0;
+  const iree_hal_amdgpu_profile_counter_set_t* counter_set =
+      iree_hal_amdgpu_host_queue_profile_counter_set(queue,
+                                                     counter_set_ordinal);
+  if (IREE_UNLIKELY(!counter_set)) {
+    return iree_make_status(IREE_STATUS_FAILED_PRECONDITION,
+                            "AMDGPU counter set is not available");
+  }
+  iree_hal_amdgpu_profile_counter_sample_slot_t* slot =
+      iree_hal_amdgpu_host_queue_profile_counter_slot(queue, event_position,
+                                                      counter_set_ordinal);
+  if (IREE_UNLIKELY(!slot->packet_set.handle.handle || slot->sample_id == 0)) {
+    return iree_make_status(
+        IREE_STATUS_FAILED_PRECONDITION,
+        "AMDGPU counter sample slot was not prepared before flush");
+  }
+
+  iree_hal_profile_counter_sample_record_t sample_record;
+  iree_hal_amdgpu_profile_counter_initialize_sample_record(
+      queue, event, counter_set, slot, &sample_record);
+  iree_host_size_t record_size = 0;
+  IREE_RETURN_IF_ERROR(iree_hal_amdgpu_profile_counter_sample_record_size(
+      sample_record.sample_value_count, &record_size));
+  if (IREE_UNLIKELY(record_size > UINT32_MAX)) {
+    return iree_make_status(IREE_STATUS_OUT_OF_RANGE,
+                            "AMDGPU counter sample record exceeds uint32_t");
+  }
+  sample_record.record_length = (uint32_t)record_size;
+
+  memcpy(storage, &sample_record, sizeof(sample_record));
+  uint64_t* values = (uint64_t*)(void*)(storage + sizeof(sample_record));
+  IREE_RETURN_IF_ERROR(
+      iree_hal_amdgpu_profile_counter_collect_packet_set_values(
+          session, counter_set, &slot->packet_set, counter_value_counts,
+          values));
+  *out_record_size = record_size;
+  return iree_ok_status();
+}
+
+static iree_status_t iree_hal_amdgpu_host_queue_pack_profile_counter_samples(
+    iree_hal_amdgpu_host_queue_t* queue,
+    iree_hal_amdgpu_profile_counter_session_t* session,
+    uint64_t event_read_position, iree_host_size_t event_count,
+    const iree_hal_profile_dispatch_event_t* events,
+    uint32_t* counter_value_counts, uint8_t* storage) {
+  uint8_t* sample_ptr = storage;
+  iree_status_t status = iree_ok_status();
+  for (iree_host_size_t event_ordinal = 0;
+       event_ordinal < event_count && iree_status_is_ok(status);
+       ++event_ordinal) {
+    const uint64_t event_position = event_read_position + event_ordinal;
+    const iree_hal_profile_dispatch_event_t* event = &events[event_ordinal];
+    for (uint32_t counter_set_ordinal = 0;
+         counter_set_ordinal < queue->profiling.counters.set_count &&
+         iree_status_is_ok(status);
+         ++counter_set_ordinal) {
+      iree_host_size_t record_size = 0;
+      status = iree_hal_amdgpu_host_queue_pack_profile_counter_sample(
+          queue, session, event_position, event, counter_set_ordinal,
+          counter_value_counts, sample_ptr, &record_size);
+      if (iree_status_is_ok(status)) {
+        sample_ptr += record_size;
+      }
+    }
+  }
+  return status;
+}
+
+static iree_status_t iree_hal_amdgpu_host_queue_write_profile_counter_chunk(
+    const iree_hal_amdgpu_host_queue_t* queue, iree_hal_profile_sink_t* sink,
+    uint64_t session_id, iree_string_view_t chunk_name,
+    const uint8_t* sample_storage, iree_host_size_t sample_storage_size) {
+  iree_hal_profile_chunk_metadata_t metadata =
+      iree_hal_profile_chunk_metadata_default();
+  metadata.content_type = IREE_HAL_PROFILE_CONTENT_TYPE_COUNTER_SAMPLES;
+  metadata.name = chunk_name;
+  metadata.session_id = session_id;
+  metadata.stream_id = iree_hal_amdgpu_host_queue_profile_stream_id(queue);
+  metadata.physical_device_ordinal =
+      iree_hal_amdgpu_host_queue_profile_device_ordinal(queue);
+  metadata.queue_ordinal =
+      iree_hal_amdgpu_host_queue_profile_queue_ordinal(queue);
+  iree_const_byte_span_t iovec =
+      iree_make_const_byte_span(sample_storage, sample_storage_size);
+  return iree_hal_profile_sink_write(sink, &metadata, 1, &iovec);
+}
+
+iree_status_t iree_hal_amdgpu_host_queue_write_profile_counter_samples(
+    iree_hal_amdgpu_host_queue_t* queue, iree_hal_profile_sink_t* sink,
+    uint64_t session_id, uint64_t event_read_position,
+    iree_host_size_t event_count,
+    const iree_hal_profile_dispatch_event_t* events) {
+  if (!sink || !event_count ||
+      !queue->profiling.counters.dispatch_samples.slots) {
+    return iree_ok_status();
+  }
+  IREE_TRACE_ZONE_BEGIN(z0);
+
+  iree_hal_amdgpu_profile_counter_session_t* session =
+      queue->profiling.counters.session;
+  iree_host_size_t sample_storage_size = 0;
+  IREE_RETURN_AND_END_ZONE_IF_ERROR(
+      z0, iree_hal_amdgpu_host_queue_profile_counter_sample_storage_size(
+              queue, event_count, &sample_storage_size));
+
+  uint32_t max_counter_count = 0;
+  IREE_RETURN_AND_END_ZONE_IF_ERROR(
+      z0, iree_hal_amdgpu_host_queue_profile_counter_max_counter_count(
+              queue, &max_counter_count));
+  iree_host_size_t counter_value_counts_size = 0;
+  IREE_RETURN_AND_END_ZONE_IF_ERROR(
+      z0,
+      IREE_STRUCT_LAYOUT(0, &counter_value_counts_size,
+                         IREE_STRUCT_FIELD(max_counter_count, uint32_t, NULL)));
+  uint32_t* counter_value_counts = NULL;
+  IREE_RETURN_AND_END_ZONE_IF_ERROR(
+      z0,
+      iree_allocator_malloc(queue->host_allocator, counter_value_counts_size,
+                            (void**)&counter_value_counts));
+
+  uint8_t* sample_storage = NULL;
+  iree_status_t status = iree_allocator_malloc(
+      queue->host_allocator, sample_storage_size, (void**)&sample_storage);
+  if (iree_status_is_ok(status)) {
+    status = iree_hal_amdgpu_host_queue_pack_profile_counter_samples(
+        queue, session, event_read_position, event_count, events,
+        counter_value_counts, sample_storage);
+  }
+  if (iree_status_is_ok(status)) {
+    status = iree_hal_amdgpu_host_queue_write_profile_counter_chunk(
+        queue, sink, session_id,
+        iree_make_cstring_view("amdgpu.counter-samples"), sample_storage,
+        sample_storage_size);
+  }
+
+  iree_allocator_free(queue->host_allocator, sample_storage);
+  iree_allocator_free(queue->host_allocator, counter_value_counts);
+  IREE_TRACE_ZONE_END(z0);
+  return status;
+}
+
+static void iree_hal_amdgpu_profile_counter_initialize_range_record(
+    const iree_hal_amdgpu_host_queue_t* queue,
+    const iree_hal_amdgpu_profile_counter_set_t* counter_set,
+    const iree_hal_amdgpu_profile_counter_range_slot_t* slot,
+    const iree_hal_amdgpu_profile_counter_range_ticks_t* ticks,
+    iree_hal_profile_counter_sample_record_t* out_record) {
+  *out_record = iree_hal_profile_counter_sample_record_default();
+  out_record->flags = IREE_HAL_PROFILE_COUNTER_SAMPLE_FLAG_DEVICE_TICK_RANGE;
+  out_record->scope = IREE_HAL_PROFILE_COUNTER_SAMPLE_SCOPE_DEVICE_TIME_RANGE;
+  out_record->sample_id = slot->sample_id;
+  out_record->counter_set_id = counter_set->counter_set_id;
+  out_record->stream_id = iree_hal_amdgpu_host_queue_profile_stream_id(queue);
+  out_record->start_tick = ticks->start_tick;
+  out_record->end_tick = ticks->end_tick;
+  out_record->physical_device_ordinal =
+      iree_hal_amdgpu_host_queue_profile_device_ordinal(queue);
+  out_record->queue_ordinal =
+      iree_hal_amdgpu_host_queue_profile_queue_ordinal(queue);
+  out_record->sample_value_count = counter_set->sample_value_count;
+}
+
+static iree_status_t
+iree_hal_amdgpu_host_queue_pack_profile_counter_range_sample(
+    iree_hal_amdgpu_host_queue_t* queue,
+    iree_hal_amdgpu_profile_counter_session_t* session, uint32_t bank,
+    uint32_t counter_set_ordinal, uint32_t* counter_value_counts,
+    uint8_t* storage, iree_host_size_t* out_record_size) {
+  *out_record_size = 0;
+  const iree_hal_amdgpu_profile_counter_set_t* counter_set =
+      iree_hal_amdgpu_host_queue_profile_counter_set(queue,
+                                                     counter_set_ordinal);
+  if (IREE_UNLIKELY(!counter_set)) {
+    return iree_make_status(IREE_STATUS_FAILED_PRECONDITION,
+                            "AMDGPU counter set is not available");
+  }
+  iree_hal_amdgpu_profile_counter_range_slot_t* slot =
+      iree_hal_amdgpu_host_queue_profile_counter_range_slot(
+          queue, bank, counter_set_ordinal);
+  if (IREE_UNLIKELY(!slot->packet_set.handle.handle || slot->sample_id == 0)) {
+    return iree_make_status(
+        IREE_STATUS_FAILED_PRECONDITION,
+        "AMDGPU counter range slot was not started before flush");
+  }
+
+  iree_hal_profile_counter_sample_record_t sample_record;
+  iree_hal_amdgpu_profile_counter_initialize_range_record(
+      queue, counter_set, slot,
+      iree_hal_amdgpu_host_queue_profile_counter_range_ticks(queue, bank),
+      &sample_record);
+  iree_host_size_t record_size = 0;
+  IREE_RETURN_IF_ERROR(iree_hal_amdgpu_profile_counter_sample_record_size(
+      sample_record.sample_value_count, &record_size));
+  if (IREE_UNLIKELY(record_size > UINT32_MAX)) {
+    return iree_make_status(IREE_STATUS_OUT_OF_RANGE,
+                            "AMDGPU counter sample record exceeds uint32_t");
+  }
+  sample_record.record_length = (uint32_t)record_size;
+
+  memcpy(storage, &sample_record, sizeof(sample_record));
+  uint64_t* values = (uint64_t*)(void*)(storage + sizeof(sample_record));
+  IREE_RETURN_IF_ERROR(
+      iree_hal_amdgpu_profile_counter_collect_packet_set_values(
+          session, counter_set, &slot->packet_set, counter_value_counts,
+          values));
+  *out_record_size = record_size;
+  return iree_ok_status();
+}
+
+static iree_status_t
+iree_hal_amdgpu_host_queue_pack_profile_counter_range_samples(
+    iree_hal_amdgpu_host_queue_t* queue,
+    iree_hal_amdgpu_profile_counter_session_t* session, uint32_t bank,
+    uint32_t* counter_value_counts, uint8_t* storage) {
+  uint8_t* sample_ptr = storage;
+  iree_status_t status = iree_ok_status();
+  for (uint32_t counter_set_ordinal = 0;
+       counter_set_ordinal < queue->profiling.counters.set_count &&
+       iree_status_is_ok(status);
+       ++counter_set_ordinal) {
+    iree_host_size_t record_size = 0;
+    status = iree_hal_amdgpu_host_queue_pack_profile_counter_range_sample(
+        queue, session, bank, counter_set_ordinal, counter_value_counts,
+        sample_ptr, &record_size);
+    if (iree_status_is_ok(status)) {
+      sample_ptr += record_size;
+    }
+  }
+  return status;
+}
+
+static iree_status_t
+iree_hal_amdgpu_host_queue_write_profile_counter_range_samples(
+    iree_hal_amdgpu_host_queue_t* queue, iree_hal_profile_sink_t* sink,
+    uint64_t session_id, uint32_t bank) {
+  if (!sink || !queue->profiling.counters.ranges.slots) {
+    return iree_ok_status();
+  }
+
+  iree_hal_amdgpu_profile_counter_session_t* session =
+      queue->profiling.counters.session;
+  iree_host_size_t sample_storage_size = 0;
+  IREE_RETURN_IF_ERROR(
+      iree_hal_amdgpu_host_queue_profile_counter_sample_storage_size(
+          queue, /*event_count=*/1, &sample_storage_size));
+
+  uint32_t max_counter_count = 0;
+  IREE_RETURN_IF_ERROR(
+      iree_hal_amdgpu_host_queue_profile_counter_max_counter_count(
+          queue, &max_counter_count));
+  iree_host_size_t counter_value_counts_size = 0;
+  IREE_RETURN_IF_ERROR(
+      IREE_STRUCT_LAYOUT(0, &counter_value_counts_size,
+                         IREE_STRUCT_FIELD(max_counter_count, uint32_t, NULL)));
+  uint32_t* counter_value_counts = NULL;
+  IREE_RETURN_IF_ERROR(iree_allocator_malloc(queue->host_allocator,
+                                             counter_value_counts_size,
+                                             (void**)&counter_value_counts));
+
+  uint8_t* sample_storage = NULL;
+  iree_status_t status = iree_allocator_malloc(
+      queue->host_allocator, sample_storage_size, (void**)&sample_storage);
+  if (iree_status_is_ok(status)) {
+    status = iree_hal_amdgpu_host_queue_pack_profile_counter_range_samples(
+        queue, session, bank, counter_value_counts, sample_storage);
+  }
+  if (iree_status_is_ok(status)) {
+    status = iree_hal_amdgpu_host_queue_write_profile_counter_chunk(
+        queue, sink, session_id,
+        iree_make_cstring_view("amdgpu.counter-ranges"), sample_storage,
+        sample_storage_size);
+  }
+
+  iree_allocator_free(queue->host_allocator, sample_storage);
+  iree_allocator_free(queue->host_allocator, counter_value_counts);
+  return status;
+}
+
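+// Stops counting on the active bank, waits for the device to complete the
+// flush, and writes the collected samples; with the RESTART flag counting
+// resumes on the next bank before the stopped bank is collected.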
+iree_status_t iree_hal_amdgpu_host_queue_flush_profile_counter_ranges(
+    iree_hal_amdgpu_host_queue_t* queue, iree_hal_profile_sink_t* sink,
+    uint64_t session_id,
+    iree_hal_amdgpu_profile_counter_range_flush_flags_t flags) {
+  if (!queue->profiling.counters.ranges.slots) {
+    return iree_ok_status();
+  }
+
+  const bool should_restart = iree_any_bit_set(
+      flags, IREE_HAL_AMDGPU_PROFILE_COUNTER_RANGE_FLUSH_FLAG_RESTART);
+  if (!queue->profiling.counters.ranges.is_active) {
+    return should_restart
+               ? iree_hal_amdgpu_host_queue_start_profile_counter_ranges(queue)
+               : iree_ok_status();
+  }
+
+  uint32_t packet_count = 0;
+  IREE_RETURN_IF_ERROR(
+      iree_hal_amdgpu_host_queue_profile_counter_range_flush_packet_count(
+          queue, flags, &packet_count));
+
+  iree_hal_amdgpu_physical_device_t* physical_device =
+      iree_hal_amdgpu_host_queue_profile_counter_physical_device(queue);
+  hsa_signal_t completion_signal = {0};
+  IREE_RETURN_IF_ERROR(iree_hal_amdgpu_host_signal_pool_acquire(
+      &physical_device->host_signal_pool, /*initial_value=*/1,
+      &completion_signal));
+
+  uint32_t stopped_bank = 0;
+  iree_slim_mutex_lock(&queue->locks.submission_mutex);
+  iree_status_t status = iree_ok_status();
+  if (IREE_UNLIKELY(queue->is_shutting_down)) {
+    status = iree_make_status(IREE_STATUS_CANCELLED, "queue shutting down");
+  }
+  if (iree_status_is_ok(status)) {
+    stopped_bank = queue->profiling.counters.ranges.active_bank;
+    const uint64_t first_packet_id =
+        iree_hal_amdgpu_aql_ring_reserve(&queue->aql_ring, packet_count);
+    iree_hal_amdgpu_host_queue_commit_profile_counter_range_flush(
+        queue, stopped_bank, completion_signal, first_packet_id);
+
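+    // The flush commit occupies one packet plus a read/stop pair per counter
+    // set; restart packets (if requested) begin immediately after them.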
+    uint64_t next_packet_id =
+        first_packet_id + 1u + queue->profiling.counters.set_count * 2u;
+    if (should_restart) {
+      const uint32_t next_bank =
+          (stopped_bank + 1u) % queue->profiling.counters.ranges.bank_count;
+      iree_hal_amdgpu_host_queue_commit_profile_counter_range_start(
+          queue, next_bank, next_packet_id);
+      queue->profiling.counters.ranges.active_bank = next_bank;
+      queue->profiling.counters.ranges.is_active = true;
+    } else {
+      queue->profiling.counters.ranges.is_active = false;
+    }
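+    // A single doorbell publishing the last reserved packet id releases the
+    // whole reservation (flush plus any restart packets) to the queue.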
+    iree_hal_amdgpu_aql_ring_doorbell(&queue->aql_ring,
+                                      first_packet_id + packet_count - 1u);
+  }
+  iree_slim_mutex_unlock(&queue->locks.submission_mutex);
+
+  if (iree_status_is_ok(status)) {
+    status = iree_hal_amdgpu_host_queue_wait_profile_counter_range(
+        queue, completion_signal);
+  }
+  if (iree_status_is_ok(status)) {
+    status = iree_hal_amdgpu_host_queue_write_profile_counter_range_samples(
+        queue, sink, session_id, stopped_bank);
+  }
+  iree_hal_amdgpu_host_signal_pool_release(&physical_device->host_signal_pool,
+                                           completion_signal);
+  return status;
+}
diff --git a/runtime/src/iree/hal/drivers/amdgpu/profile_counters.h b/runtime/src/iree/hal/drivers/amdgpu/profile_counters.h
new file mode 100644
index 0000000..6c341f3
--- /dev/null
+++ b/runtime/src/iree/hal/drivers/amdgpu/profile_counters.h
@@ -0,0 +1,186 @@
+// Copyright 2026 The IREE Authors
+//
+// Licensed under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+
+#ifndef IREE_HAL_DRIVERS_AMDGPU_PROFILE_COUNTERS_H_
+#define IREE_HAL_DRIVERS_AMDGPU_PROFILE_COUNTERS_H_
+
+#include "iree/hal/device.h"
+#include "iree/hal/drivers/amdgpu/util/aql_emitter.h"
+#include "iree/hal/drivers/amdgpu/util/aql_ring.h"
+#include "iree/hal/profile_schema.h"
+#include "iree/hal/profile_sink.h"
+
+#ifdef __cplusplus
+extern "C" {
+#endif  // __cplusplus
+
+typedef struct iree_hal_amdgpu_host_queue_t iree_hal_amdgpu_host_queue_t;
+typedef struct iree_hal_amdgpu_logical_device_t
+    iree_hal_amdgpu_logical_device_t;
+typedef struct iree_hal_amdgpu_profile_counter_sample_slot_t
+    iree_hal_amdgpu_profile_counter_sample_slot_t;
+typedef struct iree_hal_amdgpu_profile_counter_range_slot_t
+    iree_hal_amdgpu_profile_counter_range_slot_t;
+typedef struct iree_hal_amdgpu_profile_counter_session_t
+    iree_hal_amdgpu_profile_counter_session_t;
+typedef struct iree_hal_amdgpu_profile_dispatch_event_reservation_t
+    iree_hal_amdgpu_profile_dispatch_event_reservation_t;
+
+// Flags selecting which counter resources to enable on a host queue.
+typedef uint32_t iree_hal_amdgpu_profile_counter_enable_flags_t;
+enum iree_hal_amdgpu_profile_counter_enable_flag_bits_t {
+  IREE_HAL_AMDGPU_PROFILE_COUNTER_ENABLE_FLAG_NONE = 0u,
+  // Enables dispatch-attributed counter sample resources.
+  IREE_HAL_AMDGPU_PROFILE_COUNTER_ENABLE_FLAG_DISPATCH_SAMPLES = 1u << 0,
+  // Enables queue-carried physical-device range counter resources.
+  IREE_HAL_AMDGPU_PROFILE_COUNTER_ENABLE_FLAG_QUEUE_RANGES = 1u << 1,
+};
+
+// Flags controlling queue-range counter flush behavior.
+typedef uint32_t iree_hal_amdgpu_profile_counter_range_flush_flags_t;
+enum iree_hal_amdgpu_profile_counter_range_flush_flag_bits_t {
+  IREE_HAL_AMDGPU_PROFILE_COUNTER_RANGE_FLUSH_FLAG_NONE = 0u,
+  // Starts a new range on the queue after stopping the current range.
+  IREE_HAL_AMDGPU_PROFILE_COUNTER_RANGE_FLUSH_FLAG_RESTART = 1u << 0,
+};
+
+// Allocates a hardware counter profiling session from |options|.
+//
+// The returned session is immutable after creation except for its monotonically
+// assigned sample identifiers. The logical-device profiling begin path owns the
+// session and publishes a borrowed pointer to queues while profiling is active.
+iree_status_t iree_hal_amdgpu_profile_counter_session_allocate(
+    iree_hal_amdgpu_logical_device_t* logical_device,
+    const iree_hal_device_profiling_options_t* options,
+    iree_allocator_t host_allocator,
+    iree_hal_amdgpu_profile_counter_session_t** out_session);
+
+// Frees |session| and releases its aqlprofile library reference.
+void iree_hal_amdgpu_profile_counter_session_free(
+    iree_hal_amdgpu_profile_counter_session_t* session);
+
+// Returns true when |session| contains counter sets to capture.
+bool iree_hal_amdgpu_profile_counter_session_is_active(
+    const iree_hal_amdgpu_profile_counter_session_t* session);
+
+// Returns true when |session| captures dispatch-attributed samples.
+bool iree_hal_amdgpu_profile_counter_session_captures_dispatch_samples(
+    const iree_hal_amdgpu_profile_counter_session_t* session);
+
+// Returns true when |session| captures queue-level counter ranges.
+bool iree_hal_amdgpu_profile_counter_session_captures_queue_ranges(
+    const iree_hal_amdgpu_profile_counter_session_t* session);
+
+// Writes counter-set and counter metadata chunks for |session|.
+iree_status_t iree_hal_amdgpu_profile_counter_session_write_metadata(
+    const iree_hal_amdgpu_profile_counter_session_t* session,
+    iree_hal_profile_sink_t* sink, uint64_t session_id,
+    iree_string_view_t stream_name);
+
+// Enables host-queue-carried counter sample storage for |queue|.
+//
+// |flags| selects which session capture resources are materialized on this
+// queue. Dispatch slots create aqlprofile handles lazily and retain them until
+// profiling is disabled so steady counter captures reuse packet/output storage
+// after the dispatch event cursor advances past each slot. Range resources are
+// pre-created because range flush/restart is cold but should not allocate while
+// the queue is stopped waiting for samples.
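+//
+// Illustrative call enabling both capture kinds (error handling elided):
+//   iree_hal_amdgpu_host_queue_enable_profile_counters(
+//       queue, session,
+//       IREE_HAL_AMDGPU_PROFILE_COUNTER_ENABLE_FLAG_DISPATCH_SAMPLES |
+//           IREE_HAL_AMDGPU_PROFILE_COUNTER_ENABLE_FLAG_QUEUE_RANGES);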
+iree_status_t iree_hal_amdgpu_host_queue_enable_profile_counters(
+    iree_hal_amdgpu_host_queue_t* queue,
+    iree_hal_amdgpu_profile_counter_session_t* session,
+    iree_hal_amdgpu_profile_counter_enable_flags_t flags);
+
+// Disables queue-local counter sample storage and deletes all slot handles.
+void iree_hal_amdgpu_host_queue_disable_profile_counters(
+    iree_hal_amdgpu_host_queue_t* queue);
+
+// Starts queue-carried physical-device counter ranges for |queue|.
+iree_status_t iree_hal_amdgpu_host_queue_start_profile_counter_ranges(
+    iree_hal_amdgpu_host_queue_t* queue);
+
+// Stops, optionally writes, and optionally restarts physical-device ranges.
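+//
+// Illustrative periodic flush that writes the stopped bank and immediately
+// restarts capture on the next bank:
+//   iree_hal_amdgpu_host_queue_flush_profile_counter_ranges(
+//       queue, sink, session_id,
+//       IREE_HAL_AMDGPU_PROFILE_COUNTER_RANGE_FLUSH_FLAG_RESTART);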
+iree_status_t iree_hal_amdgpu_host_queue_flush_profile_counter_ranges(
+    iree_hal_amdgpu_host_queue_t* queue, iree_hal_profile_sink_t* sink,
+    uint64_t session_id,
+    iree_hal_amdgpu_profile_counter_range_flush_flags_t flags);
+
+// Returns the number of additional AQL packets needed for |reservation|.
+uint32_t iree_hal_amdgpu_host_queue_profile_counter_packet_count(
+    const iree_hal_amdgpu_host_queue_t* queue,
+    iree_hal_amdgpu_profile_dispatch_event_reservation_t reservation);
+
+// Returns the number of counter sets captured for each profiled dispatch.
+uint32_t iree_hal_amdgpu_host_queue_profile_counter_set_count(
+    const iree_hal_amdgpu_host_queue_t* queue,
+    iree_hal_amdgpu_profile_dispatch_event_reservation_t reservation);
+
+// Prepares counter sample slots for |reservation|.
+//
+// Caller must hold queue->locks.submission_mutex and must call this only after
+// the dispatch profile events have been reserved. Handles are created lazily
+// per event-ring slot and then reused only after the dispatch event cursor has
+// advanced past the slot.
+iree_status_t iree_hal_amdgpu_host_queue_prepare_profile_counter_samples(
+    iree_hal_amdgpu_host_queue_t* queue,
+    iree_hal_amdgpu_profile_dispatch_event_reservation_t reservation);
+
+// Emplaces counter start packets beginning at |first_packet_index|.
+//
+// Packet bodies are populated and commit metadata is written into
+// |packet_headers|/|packet_setups|, but headers are not committed to the AQL
+// ring. Command-buffer replay uses this so all packet bodies can be populated
+// before publishing headers in order.
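+//
+// Sketch of the replay ordering (all bodies first, headers last):
+//   emplace counter start packets -> emplace dispatch packet bodies ->
+//   emplace counter read/stop packets -> publish all headers in id order.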
+void iree_hal_amdgpu_host_queue_emplace_profile_counter_start_packets(
+    iree_hal_amdgpu_host_queue_t* queue, uint64_t event_position,
+    uint32_t counter_set_count, uint64_t first_packet_id,
+    uint32_t first_packet_index,
+    iree_hal_amdgpu_aql_packet_control_t packet_control,
+    uint16_t* packet_headers, uint16_t* packet_setups);
+
+// Emplaces counter read/stop packet pairs beginning at |first_packet_index|.
+//
+// Packets are emitted in rocprof/aqlprofile order: read first, then stop.
+void iree_hal_amdgpu_host_queue_emplace_profile_counter_read_stop_packets(
+    iree_hal_amdgpu_host_queue_t* queue, uint64_t event_position,
+    uint32_t counter_set_count, uint64_t first_packet_id,
+    uint32_t first_packet_index,
+    iree_hal_amdgpu_aql_packet_control_t packet_control,
+    uint16_t* packet_headers, uint16_t* packet_setups);
+
+// Commits counter start packets beginning at |first_packet_id|.
+//
+// The caller owns all surrounding ordering and doorbell publication. Each
+// packet references the prepared sample slot for |event_position| and has no
+// completion signal; dispatch timestamp harvest remains the queue-visible
+// completion point for the profiled dispatch.
+void iree_hal_amdgpu_host_queue_commit_profile_counter_start_packets(
+    iree_hal_amdgpu_host_queue_t* queue, uint64_t event_position,
+    uint32_t counter_set_count, uint64_t first_packet_id,
+    iree_hal_amdgpu_aql_packet_control_t packet_control);
+
+// Commits counter read/stop packet pairs beginning at |first_packet_id|.
+//
+// Packets are emitted in rocprof/aqlprofile order: read first, then stop.
+void iree_hal_amdgpu_host_queue_commit_profile_counter_read_stop_packets(
+    iree_hal_amdgpu_host_queue_t* queue, uint64_t event_position,
+    uint32_t counter_set_count, uint64_t first_packet_id,
+    iree_hal_amdgpu_aql_packet_control_t packet_control);
+
+// Writes counter sample chunks for retired dispatch events in |events|.
+//
+// The caller must not advance the dispatch event read cursor until this
+// returns successfully: slot (and aqlprofile handle) reuse is gated on that
+// cursor, so advancing the cursor early could recycle a slot that is still
+// being read.
+iree_status_t iree_hal_amdgpu_host_queue_write_profile_counter_samples(
+    iree_hal_amdgpu_host_queue_t* queue, iree_hal_profile_sink_t* sink,
+    uint64_t session_id, uint64_t event_read_position,
+    iree_host_size_t event_count,
+    const iree_hal_profile_dispatch_event_t* events);
+
+#ifdef __cplusplus
+}  // extern "C"
+#endif  // __cplusplus
+
+#endif  // IREE_HAL_DRIVERS_AMDGPU_PROFILE_COUNTERS_H_
diff --git a/runtime/src/iree/hal/drivers/amdgpu/profile_device_metrics.c b/runtime/src/iree/hal/drivers/amdgpu/profile_device_metrics.c
new file mode 100644
index 0000000..5008051
--- /dev/null
+++ b/runtime/src/iree/hal/drivers/amdgpu/profile_device_metrics.c
@@ -0,0 +1,364 @@
+// Copyright 2026 The IREE Authors
+//
+// Licensed under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+
+#include "iree/hal/drivers/amdgpu/profile_device_metrics.h"
+
+#include <inttypes.h>
+#include <string.h>
+
+#include "iree/hal/api.h"
+#include "iree/hal/drivers/amdgpu/logical_device.h"
+#include "iree/hal/drivers/amdgpu/profile_device_metrics_source.h"
+
+//===----------------------------------------------------------------------===//
+// Device metric source session
+//===----------------------------------------------------------------------===//
+
+struct iree_hal_amdgpu_profile_device_metrics_session_t {
+  // Host allocator used for session storage.
+  iree_allocator_t host_allocator;
+  // Number of initialized entries in |sources|.
+  iree_host_size_t source_count;
+  // Per-physical-device metric sources.
+  iree_hal_amdgpu_profile_device_metric_source_t sources[];
+};
+
+//===----------------------------------------------------------------------===//
+// Sample builder
+//===----------------------------------------------------------------------===//
+
+static bool iree_hal_amdgpu_profile_device_metrics_value_is_present(
+    const iree_hal_amdgpu_profile_device_metric_sample_builder_t* builder,
+    uint64_t metric_id) {
+  for (uint32_t i = 0; i < builder->record.value_count; ++i) {
+    if (builder->values[i].metric_id == metric_id) return true;
+  }
+  return false;
+}
+
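+// Appends |value| for |metric_id| unless a value for that id is already
+// present. When the builder is full the sample is marked PARTIAL instead of
+// failing so flush-time sampling never errors out mid-write.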
+void iree_hal_amdgpu_profile_device_metric_sample_builder_append_u64(
+    iree_hal_amdgpu_profile_device_metric_sample_builder_t* builder,
+    uint64_t metric_id, uint64_t value) {
+  if (iree_hal_amdgpu_profile_device_metrics_value_is_present(builder,
+                                                              metric_id)) {
+    return;
+  }
+  if (builder->record.value_count >=
+      IREE_HAL_AMDGPU_PROFILE_DEVICE_METRICS_MAX_SAMPLE_VALUES) {
+    builder->record.flags |= IREE_HAL_PROFILE_DEVICE_METRIC_SAMPLE_FLAG_PARTIAL;
+    return;
+  }
+  builder->values[builder->record.value_count++] =
+      (iree_hal_profile_device_metric_value_t){
+          .metric_id = metric_id,
+          .value_bits = value,
+      };
+}
+
+void iree_hal_amdgpu_profile_device_metric_sample_builder_append_i64(
+    iree_hal_amdgpu_profile_device_metric_sample_builder_t* builder,
+    uint64_t metric_id, int64_t value) {
+  iree_hal_amdgpu_profile_device_metric_sample_builder_append_u64(
+      builder, metric_id, (uint64_t)value);
+}
+
+//===----------------------------------------------------------------------===//
+// Profile record emission
+//===----------------------------------------------------------------------===//
+
+static iree_status_t iree_hal_amdgpu_profile_device_metrics_write_chunk(
+    iree_hal_profile_sink_t* sink, uint64_t session_id,
+    iree_string_view_t stream_name, iree_string_view_t content_type,
+    const uint8_t* storage, iree_host_size_t storage_size) {
+  iree_hal_profile_chunk_metadata_t metadata =
+      iree_hal_profile_chunk_metadata_default();
+  metadata.content_type = content_type;
+  metadata.name = stream_name;
+  metadata.session_id = session_id;
+  iree_const_byte_span_t iovec =
+      iree_make_const_byte_span(storage, storage_size);
+  return iree_hal_profile_sink_write(sink, &metadata, 1, &iovec);
+}
+
+static iree_status_t iree_hal_amdgpu_profile_device_metrics_source_record_size(
+    const iree_hal_amdgpu_profile_device_metric_source_t* source,
+    iree_host_size_t* out_record_size) {
+  const iree_host_size_t name_length = strlen(source->metadata.name);
+  return IREE_STRUCT_LAYOUT(
+      0, out_record_size,
+      IREE_STRUCT_FIELD(1, iree_hal_profile_device_metric_source_record_t,
+                        NULL),
+      IREE_STRUCT_FIELD(name_length, char, NULL));
+}
+
+static iree_status_t
+iree_hal_amdgpu_profile_device_metrics_descriptor_record_size(
+    const iree_hal_profile_metric_descriptor_t* descriptor,
+    iree_host_size_t* out_record_size) {
+  return IREE_STRUCT_LAYOUT(
+      0, out_record_size,
+      IREE_STRUCT_FIELD(1, iree_hal_profile_device_metric_descriptor_record_t,
+                        NULL),
+      IREE_STRUCT_FIELD(descriptor->name.size + descriptor->description.size,
+                        char, NULL));
+}
+
+static iree_status_t iree_hal_amdgpu_profile_device_metrics_pack_source_record(
+    const iree_hal_amdgpu_profile_device_metric_source_t* source,
+    uint8_t* storage, iree_host_size_t storage_capacity,
+    iree_host_size_t* out_storage_size) {
+  iree_host_size_t record_size = 0;
+  IREE_RETURN_IF_ERROR(
+      iree_hal_amdgpu_profile_device_metrics_source_record_size(source,
+                                                                &record_size));
+  if (record_size > UINT32_MAX || record_size > storage_capacity) {
+    return iree_make_status(IREE_STATUS_OUT_OF_RANGE,
+                            "AMDGPU metric source record exceeds storage");
+  }
+
+  const iree_host_size_t name_length = strlen(source->metadata.name);
+  iree_hal_profile_device_metric_source_record_t record =
+      iree_hal_profile_device_metric_source_record_default();
+  record.record_length = (uint32_t)record_size;
+  record.source_id = source->metadata.id;
+  record.physical_device_ordinal = source->metadata.physical_device_ordinal;
+  record.device_class = IREE_HAL_PROFILE_DEVICE_CLASS_GPU;
+  record.source_kind = source->metadata.kind;
+  record.source_revision = source->metadata.revision;
+  record.metric_count = source->metrics.count;
+  record.name_length = (uint32_t)name_length;
+
+  memcpy(storage, &record, sizeof(record));
+  memcpy(storage + sizeof(record), source->metadata.name, name_length);
+  *out_storage_size = record_size;
+  return iree_ok_status();
+}
+
+static iree_status_t
+iree_hal_amdgpu_profile_device_metrics_pack_descriptor_record(
+    const iree_hal_amdgpu_profile_device_metric_source_t* source,
+    const iree_hal_profile_metric_descriptor_t* descriptor, uint8_t* storage,
+    iree_host_size_t storage_capacity, iree_host_size_t* out_storage_size) {
+  iree_host_size_t record_size = 0;
+  IREE_RETURN_IF_ERROR(
+      iree_hal_amdgpu_profile_device_metrics_descriptor_record_size(
+          descriptor, &record_size));
+  if (record_size > UINT32_MAX || record_size > storage_capacity) {
+    return iree_make_status(IREE_STATUS_OUT_OF_RANGE,
+                            "AMDGPU metric descriptor record exceeds storage");
+  }
+
+  iree_hal_profile_device_metric_descriptor_record_t record =
+      iree_hal_profile_device_metric_descriptor_record_default();
+  record.record_length = (uint32_t)record_size;
+  record.source_id = source->metadata.id;
+  record.metric_id = descriptor->metric_id;
+  record.unit = descriptor->unit;
+  record.value_kind = descriptor->value_kind;
+  record.semantic = descriptor->semantic;
+  record.plot_hint = descriptor->plot_hint;
+  record.name_length = (uint32_t)descriptor->name.size;
+  record.description_length = (uint32_t)descriptor->description.size;
+
+  memcpy(storage, &record, sizeof(record));
+  uint8_t* string_ptr = storage + sizeof(record);
+  memcpy(string_ptr, descriptor->name.data, descriptor->name.size);
+  string_ptr += descriptor->name.size;
+  memcpy(string_ptr, descriptor->description.data,
+         descriptor->description.size);
+  *out_storage_size = record_size;
+  return iree_ok_status();
+}
+
+static iree_status_t
+iree_hal_amdgpu_profile_device_metrics_write_source_metadata(
+    const iree_hal_amdgpu_profile_device_metric_source_t* source,
+    iree_hal_profile_sink_t* sink, uint64_t session_id,
+    iree_string_view_t stream_name) {
+  uint8_t storage[128] = {0};
+  iree_host_size_t storage_size = 0;
+  IREE_RETURN_IF_ERROR(
+      iree_hal_amdgpu_profile_device_metrics_pack_source_record(
+          source, storage, sizeof(storage), &storage_size));
+  return iree_hal_amdgpu_profile_device_metrics_write_chunk(
+      sink, session_id, stream_name,
+      IREE_HAL_PROFILE_CONTENT_TYPE_DEVICE_METRIC_SOURCES, storage,
+      storage_size);
+}
+
+static iree_status_t
+iree_hal_amdgpu_profile_device_metrics_write_descriptor_metadata(
+    const iree_hal_amdgpu_profile_device_metric_source_t* source,
+    iree_hal_profile_sink_t* sink, uint64_t session_id,
+    iree_string_view_t stream_name, uint64_t metric_id) {
+  const iree_hal_profile_metric_descriptor_t* descriptor =
+      iree_hal_profile_builtin_metric_descriptor_lookup(metric_id);
+  if (!descriptor) {
+    return iree_make_status(
+        IREE_STATUS_INTERNAL,
+        "AMDGPU metric id %" PRIu64 " has no built-in descriptor", metric_id);
+  }
+
+  uint8_t storage[256] = {0};
+  iree_host_size_t storage_size = 0;
+  IREE_RETURN_IF_ERROR(
+      iree_hal_amdgpu_profile_device_metrics_pack_descriptor_record(
+          source, descriptor, storage, sizeof(storage), &storage_size));
+  return iree_hal_amdgpu_profile_device_metrics_write_chunk(
+      sink, session_id, stream_name,
+      IREE_HAL_PROFILE_CONTENT_TYPE_DEVICE_METRIC_DESCRIPTORS, storage,
+      storage_size);
+}
+
+static iree_status_t iree_hal_amdgpu_profile_device_metrics_write_sample(
+    const iree_hal_amdgpu_profile_device_metric_sample_builder_t* builder,
+    iree_hal_profile_sink_t* sink, uint64_t session_id,
+    iree_string_view_t stream_name) {
+  iree_host_size_t storage_size = 0;
+  IREE_RETURN_IF_ERROR(IREE_STRUCT_LAYOUT(
+      0, &storage_size,
+      IREE_STRUCT_FIELD(1, iree_hal_profile_device_metric_sample_record_t,
+                        NULL),
+      IREE_STRUCT_FIELD(builder->record.value_count,
+                        iree_hal_profile_device_metric_value_t, NULL)));
+  if (IREE_UNLIKELY(storage_size > UINT32_MAX)) {
+    return iree_make_status(IREE_STATUS_OUT_OF_RANGE,
+                            "AMDGPU metric sample exceeds uint32_t");
+  }
+
+  iree_hal_profile_device_metric_sample_record_t record = builder->record;
+  record.record_length = (uint32_t)storage_size;
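+  // Emit the fixed-size header and the value array as two spans so no
+  // contiguous staging copy is required.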
+  iree_const_byte_span_t iovecs[2] = {
+      iree_make_const_byte_span((const uint8_t*)&record, sizeof(record)),
+      iree_make_const_byte_span(
+          (const uint8_t*)builder->values,
+          builder->record.value_count * sizeof(builder->values[0])),
+  };
+
+  iree_hal_profile_chunk_metadata_t metadata =
+      iree_hal_profile_chunk_metadata_default();
+  metadata.content_type = IREE_HAL_PROFILE_CONTENT_TYPE_DEVICE_METRIC_SAMPLES;
+  metadata.name = stream_name;
+  metadata.session_id = session_id;
+  metadata.stream_id = record.source_id;
+  return iree_hal_profile_sink_write(
+      sink, &metadata, builder->record.value_count ? 2 : 1, iovecs);
+}
+
+//===----------------------------------------------------------------------===//
+// Session lifecycle
+//===----------------------------------------------------------------------===//
+
+iree_status_t iree_hal_amdgpu_profile_device_metrics_session_allocate(
+    iree_hal_amdgpu_logical_device_t* logical_device,
+    const iree_hal_device_profiling_options_t* options,
+    iree_allocator_t host_allocator,
+    iree_hal_amdgpu_profile_device_metrics_session_t** out_session) {
+  *out_session = NULL;
+  if (!iree_hal_device_profiling_options_requests_device_metrics(options)) {
+    return iree_ok_status();
+  }
+  if (IREE_UNLIKELY(!options->sink)) {
+    return iree_make_status(
+        IREE_STATUS_INVALID_ARGUMENT,
+        "AMDGPU device metrics profiling requires a profile sink");
+  }
+  if (IREE_UNLIKELY(logical_device->physical_device_count > UINT32_MAX)) {
+    return iree_make_status(IREE_STATUS_OUT_OF_RANGE,
+                            "AMDGPU physical device count exceeds uint32_t");
+  }
+
+  iree_host_size_t session_size = 0;
+  IREE_RETURN_IF_ERROR(IREE_STRUCT_LAYOUT(
+      sizeof(iree_hal_amdgpu_profile_device_metrics_session_t), &session_size,
+      IREE_STRUCT_FIELD(logical_device->physical_device_count,
+                        iree_hal_amdgpu_profile_device_metric_source_t, NULL)));
+
+  iree_hal_amdgpu_profile_device_metrics_session_t* session = NULL;
+  IREE_RETURN_IF_ERROR(
+      iree_allocator_malloc(host_allocator, session_size, (void**)&session));
+  memset(session, 0, session_size);
+  session->host_allocator = host_allocator;
+
+  iree_status_t status = iree_ok_status();
+  for (iree_host_size_t i = 0;
+       i < logical_device->physical_device_count && iree_status_is_ok(status);
+       ++i) {
+    status = iree_hal_amdgpu_profile_device_metric_source_initialize(
+        logical_device->physical_devices[i], host_allocator,
+        &session->sources[i]);
+    if (iree_status_is_ok(status)) {
+      ++session->source_count;
+    }
+  }
+  if (iree_status_is_ok(status)) {
+    *out_session = session;
+  } else {
+    iree_hal_amdgpu_profile_device_metrics_session_free(session);
+  }
+  return status;
+}
+
+void iree_hal_amdgpu_profile_device_metrics_session_free(
+    iree_hal_amdgpu_profile_device_metrics_session_t* session) {
+  if (!session) return;
+  for (iree_host_size_t i = 0; i < session->source_count; ++i) {
+    iree_hal_amdgpu_profile_device_metric_source_t* source =
+        &session->sources[i];
+    iree_hal_amdgpu_profile_device_metric_source_deinitialize(source);
+  }
+  iree_allocator_free(session->host_allocator, session);
+}
+
+iree_status_t iree_hal_amdgpu_profile_device_metrics_session_write_metadata(
+    const iree_hal_amdgpu_profile_device_metrics_session_t* session,
+    iree_hal_profile_sink_t* sink, uint64_t session_id,
+    iree_string_view_t stream_name) {
+  if (!session) return iree_ok_status();
+  iree_status_t status = iree_ok_status();
+  for (iree_host_size_t i = 0;
+       i < session->source_count && iree_status_is_ok(status); ++i) {
+    const iree_hal_amdgpu_profile_device_metric_source_t* source =
+        &session->sources[i];
+    status = iree_hal_amdgpu_profile_device_metrics_write_source_metadata(
+        source, sink, session_id, stream_name);
+    for (iree_host_size_t j = 0;
+         j < source->metrics.count && iree_status_is_ok(status); ++j) {
+      status = iree_hal_amdgpu_profile_device_metrics_write_descriptor_metadata(
+          source, sink, session_id, stream_name, source->metrics.ids[j]);
+    }
+  }
+  return status;
+}
+
+iree_status_t iree_hal_amdgpu_profile_device_metrics_session_sample_and_write(
+    iree_hal_amdgpu_profile_device_metrics_session_t* session,
+    iree_hal_profile_sink_t* sink, uint64_t session_id,
+    iree_string_view_t stream_name) {
+  if (!session) return iree_ok_status();
+  iree_status_t status = iree_ok_status();
+  for (iree_host_size_t i = 0;
+       i < session->source_count && iree_status_is_ok(status); ++i) {
+    iree_hal_amdgpu_profile_device_metric_source_t* source =
+        &session->sources[i];
+    iree_hal_amdgpu_profile_device_metric_sample_builder_t builder;
+    memset(&builder, 0, sizeof(builder));
+    builder.record = iree_hal_profile_device_metric_sample_record_default();
+    builder.record.sample_id = source->sampling.next_sample_id++;
+    builder.record.source_id = source->metadata.id;
+    builder.record.physical_device_ordinal =
+        source->metadata.physical_device_ordinal;
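+    // Bracket the platform sample with host timestamps so consumers can bound
+    // each value's capture window.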
+    builder.record.host_time_begin_ns = iree_time_now();
+    status =
+        iree_hal_amdgpu_profile_device_metric_source_sample(source, &builder);
+    builder.record.host_time_end_ns = iree_time_now();
+    if (iree_status_is_ok(status)) {
+      status = iree_hal_amdgpu_profile_device_metrics_write_sample(
+          &builder, sink, session_id, stream_name);
+    }
+  }
+  return status;
+}
diff --git a/runtime/src/iree/hal/drivers/amdgpu/profile_device_metrics.h b/runtime/src/iree/hal/drivers/amdgpu/profile_device_metrics.h
new file mode 100644
index 0000000..2d4031d
--- /dev/null
+++ b/runtime/src/iree/hal/drivers/amdgpu/profile_device_metrics.h
@@ -0,0 +1,52 @@
+// Copyright 2026 The IREE Authors
+//
+// Licensed under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+
+#ifndef IREE_HAL_DRIVERS_AMDGPU_PROFILE_DEVICE_METRICS_H_
+#define IREE_HAL_DRIVERS_AMDGPU_PROFILE_DEVICE_METRICS_H_
+
+#include "iree/hal/api.h"
+
+#ifdef __cplusplus
+extern "C" {
+#endif  // __cplusplus
+
+typedef struct iree_hal_amdgpu_logical_device_t
+    iree_hal_amdgpu_logical_device_t;
+typedef struct iree_hal_amdgpu_profile_device_metrics_session_t
+    iree_hal_amdgpu_profile_device_metrics_session_t;
+
+// Allocates a flush-sampled AMDGPU device metrics session from |options|.
+//
+// The session owns only cold-path sampler state: source identities, platform
+// metric handles, and per-source sample ids. Queue submission/completion paths
+// never reference the session.
+iree_status_t iree_hal_amdgpu_profile_device_metrics_session_allocate(
+    iree_hal_amdgpu_logical_device_t* logical_device,
+    const iree_hal_device_profiling_options_t* options,
+    iree_allocator_t host_allocator,
+    iree_hal_amdgpu_profile_device_metrics_session_t** out_session);
+
+// Frees |session| and releases any platform metric handles it owns.
+void iree_hal_amdgpu_profile_device_metrics_session_free(
+    iree_hal_amdgpu_profile_device_metrics_session_t* session);
+
+// Writes source and descriptor metadata chunks for |session|.
+iree_status_t iree_hal_amdgpu_profile_device_metrics_session_write_metadata(
+    const iree_hal_amdgpu_profile_device_metrics_session_t* session,
+    iree_hal_profile_sink_t* sink, uint64_t session_id,
+    iree_string_view_t stream_name);
+
+// Samples all active device metric sources and writes sample chunks.
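+//
+// Illustrative flush-time call (the stream name here is only an example):
+//   iree_hal_amdgpu_profile_device_metrics_session_sample_and_write(
+//       session, sink, session_id, IREE_SV("amdgpu.device-metrics"));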
+iree_status_t iree_hal_amdgpu_profile_device_metrics_session_sample_and_write(
+    iree_hal_amdgpu_profile_device_metrics_session_t* session,
+    iree_hal_profile_sink_t* sink, uint64_t session_id,
+    iree_string_view_t stream_name);
+
+#ifdef __cplusplus
+}  // extern "C"
+#endif  // __cplusplus
+
+#endif  // IREE_HAL_DRIVERS_AMDGPU_PROFILE_DEVICE_METRICS_H_
diff --git a/runtime/src/iree/hal/drivers/amdgpu/profile_device_metrics_linux.c b/runtime/src/iree/hal/drivers/amdgpu/profile_device_metrics_linux.c
new file mode 100644
index 0000000..db4e16f
--- /dev/null
+++ b/runtime/src/iree/hal/drivers/amdgpu/profile_device_metrics_linux.c
@@ -0,0 +1,1026 @@
+// Copyright 2026 The IREE Authors
+//
+// Licensed under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+
+#include <inttypes.h>
+#include <string.h>
+
+#include "iree/hal/drivers/amdgpu/profile_device_metrics_source.h"
+
+#if defined(IREE_PLATFORM_LINUX)
+
+#include <dirent.h>
+#include <errno.h>
+#include <fcntl.h>
+#include <stdarg.h>
+#include <stddef.h>
+#include <stdio.h>
+#include <unistd.h>
+
+#include "iree/base/alignment.h"
+#include "iree/hal/drivers/amdgpu/physical_device.h"
+
+//===----------------------------------------------------------------------===//
+// Linux sysfs/gpu_metrics schema subset
+//===----------------------------------------------------------------------===//
+
+#define IREE_HAL_AMDGPU_PROFILE_DEVICE_METRICS_GPU_BUFFER_LENGTH 4096
+#define IREE_HAL_AMDGPU_PROFILE_DEVICE_METRICS_LINUX_SYSFS_MAX_PATH_LENGTH 256
+#define IREE_HAL_AMDGPU_PROFILE_DEVICE_METRICS_LINUX_SYSFS_SCALAR_FILE_CAPACITY \
+  10
+#define IREE_HAL_AMDGPU_PROFILE_DEVICE_METRICS_TEXT_BUFFER_LENGTH 64
+#define IREE_HAL_AMDGPU_PROFILE_DEVICE_METRICS_SOURCE_KIND_LINUX_SYSFS 1u
+
+typedef char iree_hal_amdgpu_profile_device_metrics_text_buffer_t
+    [IREE_HAL_AMDGPU_PROFILE_DEVICE_METRICS_TEXT_BUFFER_LENGTH];
+
+// Built-in metrics emitted by the AMDGPU sysfs sampler.
+static const uint64_t iree_hal_amdgpu_profile_device_metrics_metric_ids[] = {
+    IREE_HAL_PROFILE_BUILTIN_METRIC_ID_CLOCK_COMPUTE_CURRENT,
+    IREE_HAL_PROFILE_BUILTIN_METRIC_ID_CLOCK_MEMORY_CURRENT,
+    IREE_HAL_PROFILE_BUILTIN_METRIC_ID_TEMPERATURE_EDGE,
+    IREE_HAL_PROFILE_BUILTIN_METRIC_ID_TEMPERATURE_HOTSPOT,
+    IREE_HAL_PROFILE_BUILTIN_METRIC_ID_TEMPERATURE_MEMORY,
+    IREE_HAL_PROFILE_BUILTIN_METRIC_ID_POWER_SOCKET,
+    IREE_HAL_PROFILE_BUILTIN_METRIC_ID_ACTIVITY_COMPUTE,
+    IREE_HAL_PROFILE_BUILTIN_METRIC_ID_ACTIVITY_MEMORY,
+    IREE_HAL_PROFILE_BUILTIN_METRIC_ID_MEMORY_LOCAL_USED,
+    IREE_HAL_PROFILE_BUILTIN_METRIC_ID_MEMORY_LOCAL_TOTAL,
+    IREE_HAL_PROFILE_BUILTIN_METRIC_ID_THROTTLE_STATUS,
+};
+
+// Linux sysfs scalar metric slots discovered once at profiling begin.
+typedef uint32_t iree_hal_amdgpu_profile_device_metrics_linux_sysfs_slot_t;
+enum iree_hal_amdgpu_profile_device_metrics_linux_sysfs_slot_e {
+  IREE_HAL_AMDGPU_PROFILE_DEVICE_METRICS_LINUX_SYSFS_SLOT_GPU_BUSY_PERCENT = 0,
+  IREE_HAL_AMDGPU_PROFILE_DEVICE_METRICS_LINUX_SYSFS_SLOT_MEMORY_BUSY_PERCENT =
+      1,
+  IREE_HAL_AMDGPU_PROFILE_DEVICE_METRICS_LINUX_SYSFS_SLOT_MEMORY_LOCAL_USED = 2,
+  IREE_HAL_AMDGPU_PROFILE_DEVICE_METRICS_LINUX_SYSFS_SLOT_MEMORY_LOCAL_TOTAL =
+      3,
+  IREE_HAL_AMDGPU_PROFILE_DEVICE_METRICS_LINUX_SYSFS_SLOT_CLOCK_COMPUTE_CURRENT =
+      4,
+  IREE_HAL_AMDGPU_PROFILE_DEVICE_METRICS_LINUX_SYSFS_SLOT_CLOCK_MEMORY_CURRENT =
+      5,
+  IREE_HAL_AMDGPU_PROFILE_DEVICE_METRICS_LINUX_SYSFS_SLOT_TEMPERATURE_EDGE = 6,
+  IREE_HAL_AMDGPU_PROFILE_DEVICE_METRICS_LINUX_SYSFS_SLOT_TEMPERATURE_HOTSPOT =
+      7,
+  IREE_HAL_AMDGPU_PROFILE_DEVICE_METRICS_LINUX_SYSFS_SLOT_TEMPERATURE_MEMORY =
+      8,
+  IREE_HAL_AMDGPU_PROFILE_DEVICE_METRICS_LINUX_SYSFS_SLOT_POWER_SOCKET = 9,
+  IREE_HAL_AMDGPU_PROFILE_DEVICE_METRICS_LINUX_SYSFS_SLOT_COUNT =
+      IREE_HAL_AMDGPU_PROFILE_DEVICE_METRICS_LINUX_SYSFS_SCALAR_FILE_CAPACITY,
+};
+
+typedef struct iree_hal_amdgpu_profile_metrics_table_header_t {
+  // Total byte size of the returned table.
+  uint16_t structure_size;
+
+  // Major gpu_metrics layout family.
+  uint8_t format_revision;
+
+  // Minor gpu_metrics layout revision within |format_revision|.
+  uint8_t content_revision;
+} iree_hal_amdgpu_profile_metrics_table_header_t;
+
+// Naturally aligned subset matching Linux gpu_metrics_v1_1 through v1_3.
+typedef struct iree_hal_amdgpu_profile_gpu_metrics_v1_3_t {
+  // Shared metrics table header.
+  iree_hal_amdgpu_profile_metrics_table_header_t common_header;
+
+  // Edge temperature in degrees Celsius.
+  uint16_t temperature_edge;
+
+  // Hotspot temperature in degrees Celsius.
+  uint16_t temperature_hotspot;
+
+  // Memory temperature in degrees Celsius.
+  uint16_t temperature_mem;
+
+  // VR gfx temperature in degrees Celsius.
+  uint16_t temperature_vrgfx;
+
+  // VR SoC temperature in degrees Celsius.
+  uint16_t temperature_vrsoc;
+
+  // VR memory temperature in degrees Celsius.
+  uint16_t temperature_vrmem;
+
+  // Average gfx activity percentage.
+  uint16_t average_gfx_activity;
+
+  // Average memory-controller activity percentage.
+  uint16_t average_umc_activity;
+
+  // Average media activity percentage.
+  uint16_t average_mm_activity;
+
+  // Average socket power in Watts.
+  uint16_t average_socket_power;
+
+  // Producer-defined energy accumulator.
+  uint64_t energy_accumulator;
+
+  // Driver-attached timestamp in nanoseconds.
+  uint64_t system_clock_counter;
+
+  // Average gfx clock in MHz.
+  uint16_t average_gfxclk_frequency;
+
+  // Average SoC clock in MHz.
+  uint16_t average_socclk_frequency;
+
+  // Average memory clock in MHz.
+  uint16_t average_uclk_frequency;
+
+  // Average VCLK0 clock in MHz.
+  uint16_t average_vclk0_frequency;
+
+  // Average DCLK0 clock in MHz.
+  uint16_t average_dclk0_frequency;
+
+  // Average VCLK1 clock in MHz.
+  uint16_t average_vclk1_frequency;
+
+  // Average DCLK1 clock in MHz.
+  uint16_t average_dclk1_frequency;
+
+  // Current gfx clock in MHz.
+  uint16_t current_gfxclk;
+
+  // Current SoC clock in MHz.
+  uint16_t current_socclk;
+
+  // Current memory clock in MHz.
+  uint16_t current_uclk;
+
+  // Current VCLK0 clock in MHz.
+  uint16_t current_vclk0;
+
+  // Current DCLK0 clock in MHz.
+  uint16_t current_dclk0;
+
+  // Current VCLK1 clock in MHz.
+  uint16_t current_vclk1;
+
+  // Current DCLK1 clock in MHz.
+  uint16_t current_dclk1;
+
+  // ASIC-dependent throttle status bitfield.
+  uint32_t throttle_status;
+
+  // Current fan speed in RPM.
+  uint16_t current_fan_speed;
+
+  // Current PCIe link width.
+  uint16_t pcie_link_width;
+
+  // Current PCIe link speed in tenths of GT/s.
+  uint16_t pcie_link_speed;
+
+  // Padding matching the kernel ABI.
+  uint16_t padding;
+
+  // Accumulated gfx activity.
+  uint32_t gfx_activity_acc;
+
+  // Accumulated memory activity.
+  uint32_t mem_activity_acc;
+
+  // HBM instance temperatures in degrees Celsius.
+  uint16_t temperature_hbm[4];
+
+  // PMFW timestamp with 10ns resolution.
+  uint64_t firmware_timestamp;
+
+  // SoC voltage in mV.
+  uint16_t voltage_soc;
+
+  // Gfx voltage in mV.
+  uint16_t voltage_gfx;
+
+  // Memory voltage in mV.
+  uint16_t voltage_mem;
+
+  // Padding matching the kernel ABI.
+  uint16_t padding1;
+
+  // ASIC-independent throttle status bitfield.
+  uint64_t indep_throttle_status;
+} iree_hal_amdgpu_profile_gpu_metrics_v1_3_t;
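+
+// Note: the kernel reports the true table size in |structure_size| and older
+// format/content revisions return shorter tables, so parsing below is
+// length-guarded per field rather than switched on revision.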
+
+typedef struct iree_hal_amdgpu_profile_linux_sysfs_scalar_metric_t {
+  // Built-in metric id emitted for this scalar file.
+  uint64_t metric_id;
+
+  // Multiplier applied to the parsed unsigned sysfs value to normalize it to
+  // the metric's canonical unit (e.g. percent -> milli-percent).
+  uint64_t scale;
+} iree_hal_amdgpu_profile_linux_sysfs_scalar_metric_t;
+
+static iree_hal_amdgpu_profile_linux_sysfs_scalar_metric_t
+iree_hal_amdgpu_profile_device_metrics_linux_sysfs_scalar_metric(
+    iree_hal_amdgpu_profile_device_metrics_linux_sysfs_slot_t slot) {
+  switch (slot) {
+    case IREE_HAL_AMDGPU_PROFILE_DEVICE_METRICS_LINUX_SYSFS_SLOT_GPU_BUSY_PERCENT:
+      return (iree_hal_amdgpu_profile_linux_sysfs_scalar_metric_t){
+          .metric_id = IREE_HAL_PROFILE_BUILTIN_METRIC_ID_ACTIVITY_COMPUTE,
+          .scale = 1000u,
+      };
+    case IREE_HAL_AMDGPU_PROFILE_DEVICE_METRICS_LINUX_SYSFS_SLOT_MEMORY_BUSY_PERCENT:
+      return (iree_hal_amdgpu_profile_linux_sysfs_scalar_metric_t){
+          .metric_id = IREE_HAL_PROFILE_BUILTIN_METRIC_ID_ACTIVITY_MEMORY,
+          .scale = 1000u,
+      };
+    case IREE_HAL_AMDGPU_PROFILE_DEVICE_METRICS_LINUX_SYSFS_SLOT_MEMORY_LOCAL_USED:
+      return (iree_hal_amdgpu_profile_linux_sysfs_scalar_metric_t){
+          .metric_id = IREE_HAL_PROFILE_BUILTIN_METRIC_ID_MEMORY_LOCAL_USED,
+          .scale = 1u,
+      };
+    case IREE_HAL_AMDGPU_PROFILE_DEVICE_METRICS_LINUX_SYSFS_SLOT_MEMORY_LOCAL_TOTAL:
+      return (iree_hal_amdgpu_profile_linux_sysfs_scalar_metric_t){
+          .metric_id = IREE_HAL_PROFILE_BUILTIN_METRIC_ID_MEMORY_LOCAL_TOTAL,
+          .scale = 1u,
+      };
+    case IREE_HAL_AMDGPU_PROFILE_DEVICE_METRICS_LINUX_SYSFS_SLOT_CLOCK_COMPUTE_CURRENT:
+      return (iree_hal_amdgpu_profile_linux_sysfs_scalar_metric_t){
+          .metric_id = IREE_HAL_PROFILE_BUILTIN_METRIC_ID_CLOCK_COMPUTE_CURRENT,
+          .scale = 1u,
+      };
+    case IREE_HAL_AMDGPU_PROFILE_DEVICE_METRICS_LINUX_SYSFS_SLOT_CLOCK_MEMORY_CURRENT:
+      return (iree_hal_amdgpu_profile_linux_sysfs_scalar_metric_t){
+          .metric_id = IREE_HAL_PROFILE_BUILTIN_METRIC_ID_CLOCK_MEMORY_CURRENT,
+          .scale = 1u,
+      };
+    case IREE_HAL_AMDGPU_PROFILE_DEVICE_METRICS_LINUX_SYSFS_SLOT_TEMPERATURE_EDGE:
+      return (iree_hal_amdgpu_profile_linux_sysfs_scalar_metric_t){
+          .metric_id = IREE_HAL_PROFILE_BUILTIN_METRIC_ID_TEMPERATURE_EDGE,
+          .scale = 1u,
+      };
+    case IREE_HAL_AMDGPU_PROFILE_DEVICE_METRICS_LINUX_SYSFS_SLOT_TEMPERATURE_HOTSPOT:
+      return (iree_hal_amdgpu_profile_linux_sysfs_scalar_metric_t){
+          .metric_id = IREE_HAL_PROFILE_BUILTIN_METRIC_ID_TEMPERATURE_HOTSPOT,
+          .scale = 1u,
+      };
+    case IREE_HAL_AMDGPU_PROFILE_DEVICE_METRICS_LINUX_SYSFS_SLOT_TEMPERATURE_MEMORY:
+      return (iree_hal_amdgpu_profile_linux_sysfs_scalar_metric_t){
+          .metric_id = IREE_HAL_PROFILE_BUILTIN_METRIC_ID_TEMPERATURE_MEMORY,
+          .scale = 1u,
+      };
+    case IREE_HAL_AMDGPU_PROFILE_DEVICE_METRICS_LINUX_SYSFS_SLOT_POWER_SOCKET:
+      return (iree_hal_amdgpu_profile_linux_sysfs_scalar_metric_t){
+          .metric_id = IREE_HAL_PROFILE_BUILTIN_METRIC_ID_POWER_SOCKET,
+          .scale = 1u,
+      };
+    default:
+      // A zeroed metric_id marks the slot as having no scalar mapping.
+      return (iree_hal_amdgpu_profile_linux_sysfs_scalar_metric_t){0};
+  }
+}
+
+typedef struct iree_hal_amdgpu_profile_linux_sysfs_source_state_t {
+  // Discovered sysfs source identity.
+  struct {
+    // NUL-terminated sysfs PCI device path.
+    char device_path
+        [IREE_HAL_AMDGPU_PROFILE_DEVICE_METRICS_LINUX_SYSFS_MAX_PATH_LENGTH];
+
+    // Number of readable sysfs files discovered for this source.
+    uint32_t readable_file_count;
+  } discovery;
+
+  // Open sysfs file descriptors.
+  struct {
+    // Open gpu_metrics binary sysfs file descriptor, or -1.
+    int gpu_metrics;
+
+    // Open scalar sysfs file descriptors indexed by sysfs metric slot, or -1.
+    int scalars
+        [IREE_HAL_AMDGPU_PROFILE_DEVICE_METRICS_LINUX_SYSFS_SCALAR_FILE_CAPACITY];
+  } files;
+} iree_hal_amdgpu_profile_linux_sysfs_source_state_t;
+
+//===----------------------------------------------------------------------===//
+// Linux sysfs discovery and sampling
+//===----------------------------------------------------------------------===//
+
+static void iree_hal_amdgpu_profile_device_metrics_linux_sysfs_reset(
+    iree_hal_amdgpu_profile_linux_sysfs_source_state_t* state) {
+  memset(state, 0, sizeof(*state));
+  state->files.gpu_metrics = -1;
+  for (iree_host_size_t i = 0; i < IREE_ARRAYSIZE(state->files.scalars); ++i) {
+    state->files.scalars[i] = -1;
+  }
+}
+
+static bool iree_hal_amdgpu_profile_device_metrics_u16_is_valid(
+    uint16_t value) {
+  return value != UINT16_MAX;
+}
+
+static bool iree_hal_amdgpu_profile_device_metrics_u32_is_valid(
+    uint32_t value) {
+  return value != UINT32_MAX;
+}
+
+static bool iree_hal_amdgpu_profile_device_metrics_u64_is_valid(
+    uint64_t value) {
+  return value != UINT64_MAX;
+}
+
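+// Returns true when [field_offset, field_offset + field_size) lies within the
+// |storage_length| bytes actually read; used to guard field access against
+// truncated gpu_metrics tables from older revisions.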
+static bool iree_hal_amdgpu_profile_device_metrics_has_field(
+    iree_host_size_t storage_length, iree_host_size_t field_offset,
+    iree_host_size_t field_size) {
+  return field_offset <= storage_length &&
+         field_size <= storage_length - field_offset;
+}
+
+#define IREE_HAL_AMDGPU_PROFILE_DEVICE_METRICS_HAS_FIELD(storage_length, type, \
+                                                         field)                \
+  iree_hal_amdgpu_profile_device_metrics_has_field(                            \
+      (storage_length), offsetof(type, field),                                 \
+      sizeof(((const type*)0)->field))
+
+static void iree_hal_amdgpu_profile_device_metrics_append_u16_scaled(
+    iree_hal_amdgpu_profile_device_metric_sample_builder_t* builder,
+    uint64_t metric_id, uint16_t value, uint64_t scale) {
+  if (iree_hal_amdgpu_profile_device_metrics_u16_is_valid(value)) {
+    iree_hal_amdgpu_profile_device_metric_sample_builder_append_u64(
+        builder, metric_id, (uint64_t)value * scale);
+  }
+}
+
+static void iree_hal_amdgpu_profile_device_metrics_append_i16_scaled(
+    iree_hal_amdgpu_profile_device_metric_sample_builder_t* builder,
+    uint64_t metric_id, uint16_t value, int64_t scale) {
+  if (iree_hal_amdgpu_profile_device_metrics_u16_is_valid(value)) {
+    iree_hal_amdgpu_profile_device_metric_sample_builder_append_i64(
+        builder, metric_id, (int64_t)value * scale);
+  }
+}
+
+static iree_status_t iree_hal_amdgpu_profile_device_metrics_format_path(
+    char* buffer, iree_host_size_t buffer_capacity, const char* format, ...) {
+  va_list varargs;
+  va_start(varargs, format);
+  const int result = vsnprintf(buffer, buffer_capacity, format, varargs);
+  va_end(varargs);
+  if (result < 0 || (iree_host_size_t)result >= buffer_capacity) {
+    return iree_make_status(
+        IREE_STATUS_OUT_OF_RANGE,
+        "AMDGPU metric sysfs path exceeds %" PRIhsz " bytes", buffer_capacity);
+  }
+  return iree_ok_status();
+}
+
+static bool iree_hal_amdgpu_profile_device_metrics_is_optional_file_error(
+    int error_code) {
+  return error_code == ENOENT || error_code == ENOTDIR || error_code == ENODEV;
+}
+
+static bool iree_hal_amdgpu_profile_device_metrics_is_unavailable_read_error(
+    int error_code) {
+  return error_code == EBUSY || error_code == EAGAIN || error_code == ENODATA ||
+         error_code == ENODEV;
+}
+
+static void iree_hal_amdgpu_profile_device_metrics_close_file(
+    int* file_descriptor) {
+  if (*file_descriptor >= 0) {
+    close(*file_descriptor);
+    *file_descriptor = -1;
+  }
+}
+
+static iree_status_t iree_hal_amdgpu_profile_device_metrics_open_optional_file(
+    const char* path, int* out_file_descriptor) {
+  *out_file_descriptor = -1;
+  const int file_descriptor = open(path, O_RDONLY | O_CLOEXEC);
+  if (file_descriptor >= 0) {
+    *out_file_descriptor = file_descriptor;
+    return iree_ok_status();
+  }
+
+  const int error_code = errno;
+  if (iree_hal_amdgpu_profile_device_metrics_is_optional_file_error(
+          error_code)) {
+    return iree_ok_status();
+  }
+  return iree_make_status(iree_status_code_from_errno(error_code),
+                          "failed to open AMDGPU metric sysfs file %s: %s",
+                          path, strerror(error_code));
+}
+
+static iree_status_t iree_hal_amdgpu_profile_device_metrics_read_file(
+    int file_descriptor, const char* name, uint8_t* buffer,
+    iree_host_size_t buffer_capacity, bool* out_available,
+    iree_host_size_t* out_length) {
+  *out_available = false;
+  *out_length = 0;
+  if (file_descriptor < 0) return iree_ok_status();
+
+  ssize_t read_length = 0;
+  do {
+    read_length = pread(file_descriptor, buffer, buffer_capacity, 0);
+  } while (read_length < 0 && errno == EINTR);
+  if (read_length < 0) {
+    const int error_code = errno;
+    if (iree_hal_amdgpu_profile_device_metrics_is_unavailable_read_error(
+            error_code)) {
+      return iree_ok_status();
+    }
+    return iree_make_status(iree_status_code_from_errno(error_code),
+                            "failed to read AMDGPU metric sysfs file %s: %s",
+                            name, strerror(error_code));
+  }
+  if (read_length == 0) return iree_ok_status();
+
+  *out_available = true;
+  *out_length = (iree_host_size_t)read_length;
+  return iree_ok_status();
+}
+
+static iree_status_t iree_hal_amdgpu_profile_device_metrics_read_uint64_file(
+    int file_descriptor, const char* name, bool* out_available,
+    uint64_t* out_value) {
+  *out_available = false;
+  *out_value = 0;
+  iree_hal_amdgpu_profile_device_metrics_text_buffer_t buffer = {0};
+  iree_host_size_t length = 0;
+  IREE_RETURN_IF_ERROR(iree_hal_amdgpu_profile_device_metrics_read_file(
+      file_descriptor, name, buffer, sizeof(buffer) - 1, out_available,
+      &length));
+  if (!*out_available) return iree_ok_status();
+
+  buffer[length] = 0;
+  iree_string_view_t text =
+      iree_string_view_trim(iree_make_string_view((const char*)buffer, length));
+  if (!iree_string_view_atoi_uint64(text, out_value)) {
+    return iree_make_status(IREE_STATUS_DATA_LOSS,
+                            "failed to parse AMDGPU metric sysfs file %s as "
+                            "uint64: %.*s",
+                            name, (int)text.size, text.data);
+  }
+  return iree_ok_status();
+}
+
+static int* iree_hal_amdgpu_profile_device_metrics_scalar_file_descriptor(
+    iree_hal_amdgpu_profile_linux_sysfs_source_state_t* state,
+    iree_hal_amdgpu_profile_device_metrics_linux_sysfs_slot_t slot) {
+  return &state->files.scalars[slot];
+}
+
+static iree_status_t iree_hal_amdgpu_profile_device_metrics_open_device_file(
+    iree_hal_amdgpu_profile_linux_sysfs_source_state_t* state,
+    const char* file_name, int* out_file_descriptor) {
+  char
+      path[IREE_HAL_AMDGPU_PROFILE_DEVICE_METRICS_LINUX_SYSFS_MAX_PATH_LENGTH] =
+          {0};
+  IREE_RETURN_IF_ERROR(iree_hal_amdgpu_profile_device_metrics_format_path(
+      path, sizeof(path), "%s/%s", state->discovery.device_path, file_name));
+  IREE_RETURN_IF_ERROR(
+      iree_hal_amdgpu_profile_device_metrics_open_optional_file(
+          path, out_file_descriptor));
+  if (*out_file_descriptor >= 0) ++state->discovery.readable_file_count;
+  return iree_ok_status();
+}
+
+static iree_status_t iree_hal_amdgpu_profile_device_metrics_read_label(
+    const char* directory_path, const char* prefix, uint32_t index,
+    char* buffer, iree_host_size_t buffer_capacity, bool* out_available,
+    iree_string_view_t* out_label) {
+  *out_label = iree_string_view_empty();
+  char
+      path[IREE_HAL_AMDGPU_PROFILE_DEVICE_METRICS_LINUX_SYSFS_MAX_PATH_LENGTH] =
+          {0};
+  IREE_RETURN_IF_ERROR(iree_hal_amdgpu_profile_device_metrics_format_path(
+      path, sizeof(path), "%s/%s%" PRIu32 "_label", directory_path, prefix,
+      index));
+  int file_descriptor = -1;
+  IREE_RETURN_IF_ERROR(
+      iree_hal_amdgpu_profile_device_metrics_open_optional_file(
+          path, &file_descriptor));
+  iree_host_size_t length = 0;
+  iree_status_t status = iree_hal_amdgpu_profile_device_metrics_read_file(
+      file_descriptor, path, (uint8_t*)buffer, buffer_capacity - 1,
+      out_available, &length);
+  iree_hal_amdgpu_profile_device_metrics_close_file(&file_descriptor);
+  if (!iree_status_is_ok(status) || !*out_available) return status;
+
+  buffer[length] = 0;
+  *out_label =
+      iree_string_view_trim(iree_make_string_view((const char*)buffer, length));
+  return iree_ok_status();
+}
+
+static iree_status_t iree_hal_amdgpu_profile_device_metrics_open_hwmon_input(
+    iree_hal_amdgpu_profile_linux_sysfs_source_state_t* state,
+    iree_hal_amdgpu_profile_device_metrics_linux_sysfs_slot_t slot,
+    const char* directory_path, const char* prefix, uint32_t index,
+    const char* suffix) {
+  if (state->files.scalars[slot] >= 0) {
+    return iree_ok_status();
+  }
+  char
+      path[IREE_HAL_AMDGPU_PROFILE_DEVICE_METRICS_LINUX_SYSFS_MAX_PATH_LENGTH] =
+          {0};
+  IREE_RETURN_IF_ERROR(iree_hal_amdgpu_profile_device_metrics_format_path(
+      path, sizeof(path), "%s/%s%" PRIu32 "_%s", directory_path, prefix, index,
+      suffix));
+  IREE_RETURN_IF_ERROR(
+      iree_hal_amdgpu_profile_device_metrics_open_optional_file(
+          path, &state->files.scalars[slot]));
+  if (state->files.scalars[slot] >= 0) {
+    ++state->discovery.readable_file_count;
+  }
+  return iree_ok_status();
+}
+
+static iree_status_t iree_hal_amdgpu_profile_device_metrics_discover_hwmon_freq(
+    iree_hal_amdgpu_profile_linux_sysfs_source_state_t* state,
+    const char* directory_path) {
+  for (uint32_t i = 1; i <= 8; ++i) {
+    iree_hal_amdgpu_profile_device_metrics_text_buffer_t label_buffer;
+    bool label_available = false;
+    iree_string_view_t label = iree_string_view_empty();
+    IREE_RETURN_IF_ERROR(iree_hal_amdgpu_profile_device_metrics_read_label(
+        directory_path, "freq", i, label_buffer, sizeof(label_buffer),
+        &label_available, &label));
+    if (!label_available) continue;
+
+    if (iree_string_view_equal(label, IREE_SV("sclk")) ||
+        iree_string_view_equal(label, IREE_SV("gfxclk"))) {
+      IREE_RETURN_IF_ERROR(iree_hal_amdgpu_profile_device_metrics_open_hwmon_input(
+          state,
+          IREE_HAL_AMDGPU_PROFILE_DEVICE_METRICS_LINUX_SYSFS_SLOT_CLOCK_COMPUTE_CURRENT,
+          directory_path, "freq", i, "input"));
+    } else if (iree_string_view_equal(label, IREE_SV("mclk")) ||
+               iree_string_view_equal(label, IREE_SV("uclk"))) {
+      IREE_RETURN_IF_ERROR(iree_hal_amdgpu_profile_device_metrics_open_hwmon_input(
+          state,
+          IREE_HAL_AMDGPU_PROFILE_DEVICE_METRICS_LINUX_SYSFS_SLOT_CLOCK_MEMORY_CURRENT,
+          directory_path, "freq", i, "input"));
+    }
+  }
+  return iree_ok_status();
+}
+
+static iree_status_t
+iree_hal_amdgpu_profile_device_metrics_discover_hwmon_temperature(
+    iree_hal_amdgpu_profile_linux_sysfs_source_state_t* state,
+    const char* directory_path) {
+  for (uint32_t i = 1; i <= 16; ++i) {
+    iree_hal_amdgpu_profile_device_metrics_text_buffer_t label_buffer;
+    bool label_available = false;
+    iree_string_view_t label = iree_string_view_empty();
+    IREE_RETURN_IF_ERROR(iree_hal_amdgpu_profile_device_metrics_read_label(
+        directory_path, "temp", i, label_buffer, sizeof(label_buffer),
+        &label_available, &label));
+    if (!label_available) continue;
+
+    if (iree_string_view_equal(label, IREE_SV("edge"))) {
+      IREE_RETURN_IF_ERROR(iree_hal_amdgpu_profile_device_metrics_open_hwmon_input(
+          state,
+          IREE_HAL_AMDGPU_PROFILE_DEVICE_METRICS_LINUX_SYSFS_SLOT_TEMPERATURE_EDGE,
+          directory_path, "temp", i, "input"));
+    } else if (iree_string_view_equal(label, IREE_SV("junction")) ||
+               iree_string_view_equal(label, IREE_SV("hotspot"))) {
+      IREE_RETURN_IF_ERROR(iree_hal_amdgpu_profile_device_metrics_open_hwmon_input(
+          state,
+          IREE_HAL_AMDGPU_PROFILE_DEVICE_METRICS_LINUX_SYSFS_SLOT_TEMPERATURE_HOTSPOT,
+          directory_path, "temp", i, "input"));
+    } else if (iree_string_view_equal(label, IREE_SV("mem")) ||
+               iree_string_view_equal(label, IREE_SV("memory"))) {
+      IREE_RETURN_IF_ERROR(iree_hal_amdgpu_profile_device_metrics_open_hwmon_input(
+          state,
+          IREE_HAL_AMDGPU_PROFILE_DEVICE_METRICS_LINUX_SYSFS_SLOT_TEMPERATURE_MEMORY,
+          directory_path, "temp", i, "input"));
+    }
+  }
+  return iree_ok_status();
+}
+
+static iree_status_t
+iree_hal_amdgpu_profile_device_metrics_discover_hwmon_power(
+    iree_hal_amdgpu_profile_linux_sysfs_source_state_t* state,
+    const char* directory_path) {
+  return iree_hal_amdgpu_profile_device_metrics_open_hwmon_input(
+      state,
+      IREE_HAL_AMDGPU_PROFILE_DEVICE_METRICS_LINUX_SYSFS_SLOT_POWER_SOCKET,
+      directory_path, "power", 1, "average");
+}
+
+static bool iree_hal_amdgpu_profile_device_metrics_is_hwmon_dirent(
+    const char* name) {
+  if (strncmp(name, "hwmon", 5) != 0 || name[5] == '\0') return false;
+  for (const char* cursor = name + 5; *cursor; ++cursor) {
+    if (*cursor < '0' || *cursor > '9') return false;
+  }
+  return true;
+}
+
+static iree_status_t iree_hal_amdgpu_profile_device_metrics_discover_hwmon(
+    iree_hal_amdgpu_profile_linux_sysfs_source_state_t* state) {
+  char directory_path
+      [IREE_HAL_AMDGPU_PROFILE_DEVICE_METRICS_LINUX_SYSFS_MAX_PATH_LENGTH] = {
+          0};
+  IREE_RETURN_IF_ERROR(iree_hal_amdgpu_profile_device_metrics_format_path(
+      directory_path, sizeof(directory_path), "%s/hwmon",
+      state->discovery.device_path));
+
+  DIR* directory = opendir(directory_path);
+  if (!directory) {
+    const int error_code = errno;
+    if (iree_hal_amdgpu_profile_device_metrics_is_optional_file_error(
+            error_code)) {
+      return iree_ok_status();
+    }
+    return iree_make_status(iree_status_code_from_errno(error_code),
+                            "failed to open AMDGPU hwmon directory %s: %s",
+                            directory_path, strerror(error_code));
+  }
+
+  iree_status_t status = iree_ok_status();
+  while (iree_status_is_ok(status)) {
+    errno = 0;
+    struct dirent* entry = readdir(directory);
+    if (!entry) {
+      if (errno != 0) {
+        const int error_code = errno;
+        status = iree_make_status(iree_status_code_from_errno(error_code),
+                                  "failed to enumerate AMDGPU hwmon directory "
+                                  "%s: %s",
+                                  directory_path, strerror(error_code));
+      }
+      break;
+    }
+    if (!iree_hal_amdgpu_profile_device_metrics_is_hwmon_dirent(
+            entry->d_name)) {
+      continue;
+    }
+    char hwmon_path
+        [IREE_HAL_AMDGPU_PROFILE_DEVICE_METRICS_LINUX_SYSFS_MAX_PATH_LENGTH] = {
+            0};
+    status = iree_hal_amdgpu_profile_device_metrics_format_path(
+        hwmon_path, sizeof(hwmon_path), "%s/%s", directory_path, entry->d_name);
+    if (iree_status_is_ok(status)) {
+      status = iree_hal_amdgpu_profile_device_metrics_discover_hwmon_freq(
+          state, hwmon_path);
+    }
+    if (iree_status_is_ok(status)) {
+      status =
+          iree_hal_amdgpu_profile_device_metrics_discover_hwmon_temperature(
+              state, hwmon_path);
+    }
+    if (iree_status_is_ok(status)) {
+      status = iree_hal_amdgpu_profile_device_metrics_discover_hwmon_power(
+          state, hwmon_path);
+    }
+  }
+  closedir(directory);
+  return status;
+}
+
+static void iree_hal_amdgpu_profile_device_metrics_parse_gpu_metrics_v1_3(
+    const uint8_t* storage, iree_host_size_t storage_length,
+    iree_hal_amdgpu_profile_device_metric_sample_builder_t* builder) {
+  const iree_hal_amdgpu_profile_gpu_metrics_v1_3_t* metrics =
+      (const iree_hal_amdgpu_profile_gpu_metrics_v1_3_t*)(const void*)storage;
+
+  if (IREE_HAL_AMDGPU_PROFILE_DEVICE_METRICS_HAS_FIELD(
+          storage_length, iree_hal_amdgpu_profile_gpu_metrics_v1_3_t,
+          temperature_edge)) {
+    iree_hal_amdgpu_profile_device_metrics_append_i16_scaled(
+        builder, IREE_HAL_PROFILE_BUILTIN_METRIC_ID_TEMPERATURE_EDGE,
+        metrics->temperature_edge, 1000);
+  }
+  if (IREE_HAL_AMDGPU_PROFILE_DEVICE_METRICS_HAS_FIELD(
+          storage_length, iree_hal_amdgpu_profile_gpu_metrics_v1_3_t,
+          temperature_hotspot)) {
+    iree_hal_amdgpu_profile_device_metrics_append_i16_scaled(
+        builder, IREE_HAL_PROFILE_BUILTIN_METRIC_ID_TEMPERATURE_HOTSPOT,
+        metrics->temperature_hotspot, 1000);
+  }
+  if (IREE_HAL_AMDGPU_PROFILE_DEVICE_METRICS_HAS_FIELD(
+          storage_length, iree_hal_amdgpu_profile_gpu_metrics_v1_3_t,
+          temperature_mem)) {
+    iree_hal_amdgpu_profile_device_metrics_append_i16_scaled(
+        builder, IREE_HAL_PROFILE_BUILTIN_METRIC_ID_TEMPERATURE_MEMORY,
+        metrics->temperature_mem, 1000);
+  }
+  if (IREE_HAL_AMDGPU_PROFILE_DEVICE_METRICS_HAS_FIELD(
+          storage_length, iree_hal_amdgpu_profile_gpu_metrics_v1_3_t,
+          average_gfx_activity)) {
+    iree_hal_amdgpu_profile_device_metrics_append_u16_scaled(
+        builder, IREE_HAL_PROFILE_BUILTIN_METRIC_ID_ACTIVITY_COMPUTE,
+        metrics->average_gfx_activity, 1000u);
+  }
+  if (IREE_HAL_AMDGPU_PROFILE_DEVICE_METRICS_HAS_FIELD(
+          storage_length, iree_hal_amdgpu_profile_gpu_metrics_v1_3_t,
+          average_umc_activity)) {
+    iree_hal_amdgpu_profile_device_metrics_append_u16_scaled(
+        builder, IREE_HAL_PROFILE_BUILTIN_METRIC_ID_ACTIVITY_MEMORY,
+        metrics->average_umc_activity, 1000u);
+  }
+  if (IREE_HAL_AMDGPU_PROFILE_DEVICE_METRICS_HAS_FIELD(
+          storage_length, iree_hal_amdgpu_profile_gpu_metrics_v1_3_t,
+          average_socket_power)) {
+    iree_hal_amdgpu_profile_device_metrics_append_u16_scaled(
+        builder, IREE_HAL_PROFILE_BUILTIN_METRIC_ID_POWER_SOCKET,
+        metrics->average_socket_power, 1000000u);
+  }
+  if (IREE_HAL_AMDGPU_PROFILE_DEVICE_METRICS_HAS_FIELD(
+          storage_length, iree_hal_amdgpu_profile_gpu_metrics_v1_3_t,
+          current_gfxclk)) {
+    iree_hal_amdgpu_profile_device_metrics_append_u16_scaled(
+        builder, IREE_HAL_PROFILE_BUILTIN_METRIC_ID_CLOCK_COMPUTE_CURRENT,
+        metrics->current_gfxclk, 1000000u);
+  }
+  if (IREE_HAL_AMDGPU_PROFILE_DEVICE_METRICS_HAS_FIELD(
+          storage_length, iree_hal_amdgpu_profile_gpu_metrics_v1_3_t,
+          current_uclk)) {
+    iree_hal_amdgpu_profile_device_metrics_append_u16_scaled(
+        builder, IREE_HAL_PROFILE_BUILTIN_METRIC_ID_CLOCK_MEMORY_CURRENT,
+        metrics->current_uclk, 1000000u);
+  }
+  if (IREE_HAL_AMDGPU_PROFILE_DEVICE_METRICS_HAS_FIELD(
+          storage_length, iree_hal_amdgpu_profile_gpu_metrics_v1_3_t,
+          indep_throttle_status) &&
+      iree_hal_amdgpu_profile_device_metrics_u64_is_valid(
+          metrics->indep_throttle_status)) {
+    iree_hal_amdgpu_profile_device_metric_sample_builder_append_u64(
+        builder, IREE_HAL_PROFILE_BUILTIN_METRIC_ID_THROTTLE_STATUS,
+        metrics->indep_throttle_status);
+  } else if (IREE_HAL_AMDGPU_PROFILE_DEVICE_METRICS_HAS_FIELD(
+                 storage_length, iree_hal_amdgpu_profile_gpu_metrics_v1_3_t,
+                 throttle_status) &&
+             iree_hal_amdgpu_profile_device_metrics_u32_is_valid(
+                 metrics->throttle_status)) {
+    iree_hal_amdgpu_profile_device_metric_sample_builder_append_u64(
+        builder, IREE_HAL_PROFILE_BUILTIN_METRIC_ID_THROTTLE_STATUS,
+        metrics->throttle_status);
+  }
+}
+
+#undef IREE_HAL_AMDGPU_PROFILE_DEVICE_METRICS_HAS_FIELD
+
+static iree_status_t iree_hal_amdgpu_profile_device_metrics_parse_gpu_metrics(
+    const uint8_t* storage, iree_host_size_t storage_length,
+    iree_hal_amdgpu_profile_device_metric_sample_builder_t* builder) {
+  if (storage_length < sizeof(iree_hal_amdgpu_profile_metrics_table_header_t)) {
+    builder->record.flags |= IREE_HAL_PROFILE_DEVICE_METRIC_SAMPLE_FLAG_PARTIAL;
+    return iree_ok_status();
+  }
+  const iree_hal_amdgpu_profile_metrics_table_header_t* header =
+      (const iree_hal_amdgpu_profile_metrics_table_header_t*)(const void*)
+          storage;
+  if (header->structure_size <
+      sizeof(iree_hal_amdgpu_profile_metrics_table_header_t)) {
+    builder->record.flags |= IREE_HAL_PROFILE_DEVICE_METRIC_SAMPLE_FLAG_PARTIAL;
+    return iree_ok_status();
+  }
+
+  const iree_host_size_t table_length =
+      iree_min(storage_length, (iree_host_size_t)header->structure_size);
+  if (header->format_revision == 1 && header->content_revision >= 1 &&
+      header->content_revision <= 3) {
+    iree_hal_amdgpu_profile_device_metrics_parse_gpu_metrics_v1_3(
+        storage, table_length, builder);
+    return iree_ok_status();
+  }
+
+  builder->record.flags |= IREE_HAL_PROFILE_DEVICE_METRIC_SAMPLE_FLAG_PARTIAL;
+  return iree_ok_status();
+}
+
+static iree_status_t iree_hal_amdgpu_profile_device_metrics_probe_gpu_metrics(
+    iree_hal_amdgpu_profile_linux_sysfs_source_state_t* state,
+    uint32_t* out_source_revision) {
+  *out_source_revision = 0;
+  iree_alignas(8)
+      uint8_t storage[sizeof(iree_hal_amdgpu_profile_metrics_table_header_t)] =
+          {0};
+  bool available = false;
+  iree_host_size_t storage_length = 0;
+  IREE_RETURN_IF_ERROR(iree_hal_amdgpu_profile_device_metrics_read_file(
+      state->files.gpu_metrics, "gpu_metrics", storage, sizeof(storage),
+      &available, &storage_length));
+  if (!available ||
+      storage_length < sizeof(iree_hal_amdgpu_profile_metrics_table_header_t)) {
+    return iree_ok_status();
+  }
+
+  const iree_hal_amdgpu_profile_metrics_table_header_t* header =
+      (const iree_hal_amdgpu_profile_metrics_table_header_t*)(const void*)
+          storage;
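+  // Pack revisions as (format << 8) | content; e.g. v1.3 -> 0x0103.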
+  *out_source_revision =
+      ((uint32_t)header->format_revision << 8) | header->content_revision;
+  return iree_ok_status();
+}
+
+static iree_status_t iree_hal_amdgpu_profile_device_metrics_sample_gpu_metrics(
+    iree_hal_amdgpu_profile_linux_sysfs_source_state_t* state,
+    iree_hal_amdgpu_profile_device_metric_sample_builder_t* builder) {
+  iree_alignas(8) uint8_t
+      storage[IREE_HAL_AMDGPU_PROFILE_DEVICE_METRICS_GPU_BUFFER_LENGTH] = {0};
+  bool available = false;
+  iree_host_size_t storage_length = 0;
+  IREE_RETURN_IF_ERROR(iree_hal_amdgpu_profile_device_metrics_read_file(
+      state->files.gpu_metrics, "gpu_metrics", storage, sizeof(storage),
+      &available, &storage_length));
+  if (!available) {
+    if (state->files.gpu_metrics >= 0) {
+      builder->record.flags |=
+          IREE_HAL_PROFILE_DEVICE_METRIC_SAMPLE_FLAG_PARTIAL;
+    }
+    return iree_ok_status();
+  }
+  return iree_hal_amdgpu_profile_device_metrics_parse_gpu_metrics(
+      storage, storage_length, builder);
+}
+
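+// Reads each discovered sysfs scalar file (gpu_busy_percent,
+// mem_info_vram_used, etc) and appends its value scaled to the slot's metric
+// id; a file that stops being readable marks the sample PARTIAL instead of
+// failing the whole sample.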
+static iree_status_t iree_hal_amdgpu_profile_device_metrics_sample_scalars(
+    iree_hal_amdgpu_profile_linux_sysfs_source_state_t* state,
+    iree_hal_amdgpu_profile_device_metric_sample_builder_t* builder) {
+  for (iree_host_size_t i = 0;
+       i < IREE_HAL_AMDGPU_PROFILE_DEVICE_METRICS_LINUX_SYSFS_SLOT_COUNT; ++i) {
+    const int file_descriptor = state->files.scalars[i];
+    if (file_descriptor < 0) continue;
+
+    bool available = false;
+    uint64_t value = 0;
+    IREE_RETURN_IF_ERROR(
+        iree_hal_amdgpu_profile_device_metrics_read_uint64_file(
+            file_descriptor, "scalar metric", &available, &value));
+    if (!available) {
+      builder->record.flags |=
+          IREE_HAL_PROFILE_DEVICE_METRIC_SAMPLE_FLAG_PARTIAL;
+      continue;
+    }
+    const iree_hal_amdgpu_profile_linux_sysfs_scalar_metric_t scalar_metric =
+        iree_hal_amdgpu_profile_device_metrics_linux_sysfs_scalar_metric(
+            (iree_hal_amdgpu_profile_device_metrics_linux_sysfs_slot_t)i);
+    iree_hal_amdgpu_profile_device_metric_sample_builder_append_u64(
+        builder, scalar_metric.metric_id, value * scalar_metric.scale);
+  }
+  return iree_ok_status();
+}
+
+iree_status_t iree_hal_amdgpu_profile_device_metric_source_initialize(
+    iree_hal_amdgpu_physical_device_t* physical_device,
+    iree_allocator_t host_allocator,
+    iree_hal_amdgpu_profile_device_metric_source_t* out_source) {
+  memset(out_source, 0, sizeof(*out_source));
+  if (IREE_UNLIKELY(physical_device->device_ordinal > UINT32_MAX)) {
+    return iree_make_status(
+        IREE_STATUS_OUT_OF_RANGE,
+        "AMDGPU device metric source physical device ordinal exceeds uint32_t");
+  }
+  if (!physical_device->has_pci_identity) {
+    return iree_make_status(IREE_STATUS_FAILED_PRECONDITION,
+                            "AMDGPU device metrics require PCI identity for "
+                            "physical device %" PRIhsz,
+                            physical_device->device_ordinal);
+  }
+
+  iree_hal_amdgpu_profile_linux_sysfs_source_state_t* state = NULL;
+  IREE_RETURN_IF_ERROR(
+      iree_allocator_malloc(host_allocator, sizeof(*state), (void**)&state));
+  iree_hal_amdgpu_profile_device_metrics_linux_sysfs_reset(state);
+  out_source->platform.host_allocator = host_allocator;
+  out_source->platform.state = state;
+  out_source->metadata.id = (uint64_t)physical_device->device_ordinal + 1u;
+  out_source->metadata.kind =
+      IREE_HAL_AMDGPU_PROFILE_DEVICE_METRICS_SOURCE_KIND_LINUX_SYSFS;
+  out_source->metadata.physical_device_ordinal =
+      (uint32_t)physical_device->device_ordinal;
+  out_source->metrics.count =
+      IREE_ARRAYSIZE(iree_hal_amdgpu_profile_device_metrics_metric_ids);
+  out_source->metrics.ids = iree_hal_amdgpu_profile_device_metrics_metric_ids;
+  out_source->sampling.next_sample_id = 1;
+
+  iree_status_t status = iree_hal_amdgpu_profile_device_metrics_format_path(
+      state->discovery.device_path, sizeof(state->discovery.device_path),
+      "/sys/bus/pci/devices/%04" PRIx32 ":%02" PRIx32 ":%02" PRIx32 ".%" PRIu32,
+      physical_device->pci_domain, physical_device->pci_bus,
+      physical_device->pci_device, physical_device->pci_function);
+  if (iree_status_is_ok(status)) {
+    status = iree_hal_amdgpu_profile_device_metrics_format_path(
+        out_source->metadata.name, sizeof(out_source->metadata.name),
+        "amdgpu.sysfs.%04" PRIx32 ":%02" PRIx32 ":%02" PRIx32 ".%" PRIu32,
+        physical_device->pci_domain, physical_device->pci_bus,
+        physical_device->pci_device, physical_device->pci_function);
+  }
+  if (iree_status_is_ok(status)) {
+    status = iree_hal_amdgpu_profile_device_metrics_open_device_file(
+        state, "gpu_metrics", &state->files.gpu_metrics);
+  }
+  if (iree_status_is_ok(status)) {
+    status = iree_hal_amdgpu_profile_device_metrics_probe_gpu_metrics(
+        state, &out_source->metadata.revision);
+  }
+  if (iree_status_is_ok(status)) {
+    status = iree_hal_amdgpu_profile_device_metrics_open_device_file(
+        state, "gpu_busy_percent",
+        iree_hal_amdgpu_profile_device_metrics_scalar_file_descriptor(
+            state,
+            IREE_HAL_AMDGPU_PROFILE_DEVICE_METRICS_LINUX_SYSFS_SLOT_GPU_BUSY_PERCENT));
+  }
+  if (iree_status_is_ok(status)) {
+    status = iree_hal_amdgpu_profile_device_metrics_open_device_file(
+        state, "mem_busy_percent",
+        iree_hal_amdgpu_profile_device_metrics_scalar_file_descriptor(
+            state,
+            IREE_HAL_AMDGPU_PROFILE_DEVICE_METRICS_LINUX_SYSFS_SLOT_MEMORY_BUSY_PERCENT));
+  }
+  if (iree_status_is_ok(status)) {
+    status = iree_hal_amdgpu_profile_device_metrics_open_device_file(
+        state, "mem_info_vram_used",
+        iree_hal_amdgpu_profile_device_metrics_scalar_file_descriptor(
+            state,
+            IREE_HAL_AMDGPU_PROFILE_DEVICE_METRICS_LINUX_SYSFS_SLOT_MEMORY_LOCAL_USED));
+  }
+  if (iree_status_is_ok(status)) {
+    status = iree_hal_amdgpu_profile_device_metrics_open_device_file(
+        state, "mem_info_vram_total",
+        iree_hal_amdgpu_profile_device_metrics_scalar_file_descriptor(
+            state,
+            IREE_HAL_AMDGPU_PROFILE_DEVICE_METRICS_LINUX_SYSFS_SLOT_MEMORY_LOCAL_TOTAL));
+  }
+  if (iree_status_is_ok(status)) {
+    status = iree_hal_amdgpu_profile_device_metrics_discover_hwmon(state);
+  }
+  if (iree_status_is_ok(status) && state->discovery.readable_file_count == 0) {
+    status = iree_make_status(IREE_STATUS_NOT_FOUND,
+                              "AMDGPU device metrics found no readable sysfs "
+                              "files under %s",
+                              state->discovery.device_path);
+  }
+  if (!iree_status_is_ok(status)) {
+    iree_hal_amdgpu_profile_device_metric_source_deinitialize(out_source);
+  }
+  return status;
+}
+
+void iree_hal_amdgpu_profile_device_metric_source_deinitialize(
+    iree_hal_amdgpu_profile_device_metric_source_t* source) {
+  if (!source) return;
+  iree_hal_amdgpu_profile_linux_sysfs_source_state_t* state =
+      (iree_hal_amdgpu_profile_linux_sysfs_source_state_t*)
+          source->platform.state;
+  if (!state) {
+    memset(source, 0, sizeof(*source));
+    return;
+  }
+  iree_hal_amdgpu_profile_device_metrics_close_file(&state->files.gpu_metrics);
+  for (iree_host_size_t i = 0; i < IREE_ARRAYSIZE(state->files.scalars); ++i) {
+    iree_hal_amdgpu_profile_device_metrics_close_file(&state->files.scalars[i]);
+  }
+  iree_allocator_free(source->platform.host_allocator, state);
+  memset(source, 0, sizeof(*source));
+}
+
+iree_status_t iree_hal_amdgpu_profile_device_metric_source_sample(
+    iree_hal_amdgpu_profile_device_metric_source_t* source,
+    iree_hal_amdgpu_profile_device_metric_sample_builder_t* builder) {
+  iree_hal_amdgpu_profile_linux_sysfs_source_state_t* state =
+      (iree_hal_amdgpu_profile_linux_sysfs_source_state_t*)
+          source->platform.state;
+  if (IREE_UNLIKELY(!state)) {
+    return iree_make_status(IREE_STATUS_FAILED_PRECONDITION,
+                            "AMDGPU device metric source is uninitialized");
+  }
+  IREE_RETURN_IF_ERROR(
+      iree_hal_amdgpu_profile_device_metrics_sample_gpu_metrics(state,
+                                                                builder));
+  return iree_hal_amdgpu_profile_device_metrics_sample_scalars(state, builder);
+}
+
+#else
+
+iree_status_t iree_hal_amdgpu_profile_device_metric_source_initialize(
+    iree_hal_amdgpu_physical_device_t* physical_device,
+    iree_allocator_t host_allocator,
+    iree_hal_amdgpu_profile_device_metric_source_t* out_source) {
+  (void)physical_device;
+  (void)host_allocator;
+  memset(out_source, 0, sizeof(*out_source));
+  return iree_make_status(IREE_STATUS_UNIMPLEMENTED,
+                          "AMDGPU device metrics require Linux sysfs support");
+}
+
+void iree_hal_amdgpu_profile_device_metric_source_deinitialize(
+    iree_hal_amdgpu_profile_device_metric_source_t* source) {
+  if (!source) return;
+  memset(source, 0, sizeof(*source));
+}
+
+iree_status_t iree_hal_amdgpu_profile_device_metric_source_sample(
+    iree_hal_amdgpu_profile_device_metric_source_t* source,
+    iree_hal_amdgpu_profile_device_metric_sample_builder_t* builder) {
+  (void)source;
+  (void)builder;
+  return iree_make_status(IREE_STATUS_UNIMPLEMENTED,
+                          "AMDGPU device metrics require Linux sysfs support");
+}
+
+#endif  // IREE_PLATFORM_LINUX
diff --git a/runtime/src/iree/hal/drivers/amdgpu/profile_device_metrics_source.h b/runtime/src/iree/hal/drivers/amdgpu/profile_device_metrics_source.h
new file mode 100644
index 0000000..a78e4de
--- /dev/null
+++ b/runtime/src/iree/hal/drivers/amdgpu/profile_device_metrics_source.h
@@ -0,0 +1,106 @@
+// Copyright 2026 The IREE Authors
+//
+// Licensed under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+
+#ifndef IREE_HAL_DRIVERS_AMDGPU_PROFILE_DEVICE_METRICS_SOURCE_H_
+#define IREE_HAL_DRIVERS_AMDGPU_PROFILE_DEVICE_METRICS_SOURCE_H_
+
+#include "iree/hal/api.h"
+
+#ifdef __cplusplus
+extern "C" {
+#endif  // __cplusplus
+
+typedef struct iree_hal_amdgpu_physical_device_t
+    iree_hal_amdgpu_physical_device_t;
+
+#define IREE_HAL_AMDGPU_PROFILE_DEVICE_METRICS_MAX_NAME_LENGTH 64
+#define IREE_HAL_AMDGPU_PROFILE_DEVICE_METRICS_MAX_SAMPLE_VALUES 16
+
+// Per-physical-device source sampled by a device-metrics session.
+typedef struct iree_hal_amdgpu_profile_device_metric_source_t {
+  // Profile metadata describing this source.
+  struct {
+    // Producer-defined source id unique within the profiling session.
+    uint64_t id;
+
+    // Session-local physical device ordinal sampled by this source.
+    uint32_t physical_device_ordinal;
+
+    // Producer-defined source kind written to source metadata.
+    uint32_t kind;
+
+    // Source revision derived from the backing implementation.
+    uint32_t revision;
+
+    // Human-readable source name stored in source metadata.
+    char name[IREE_HAL_AMDGPU_PROFILE_DEVICE_METRICS_MAX_NAME_LENGTH];
+  } metadata;
+
+  // Built-in metric ids this source can emit.
+  struct {
+    // Number of entries in |ids|.
+    iree_host_size_t count;
+
+    // Pointer to static built-in metric ids emitted by the source.
+    const uint64_t* ids;
+  } metrics;
+
+  // Mutable sampling state.
+  struct {
+    // Next nonzero sample id emitted for this source.
+    uint64_t next_sample_id;
+  } sampling;
+
+  // Source implementation state.
+  struct {
+    // Host allocator used for |state| when owned by the source.
+    iree_allocator_t host_allocator;
+
+    // Opaque implementation-owned state.
+    void* state;
+  } platform;
+} iree_hal_amdgpu_profile_device_metric_source_t;
+
+// Builder for one packed device metric sample.
+typedef struct iree_hal_amdgpu_profile_device_metric_sample_builder_t {
+  // Sample record header being populated.
+  iree_hal_profile_device_metric_sample_record_t record;
+
+  // Fixed value storage written immediately after |record|.
+  iree_hal_profile_device_metric_value_t
+      values[IREE_HAL_AMDGPU_PROFILE_DEVICE_METRICS_MAX_SAMPLE_VALUES];
+} iree_hal_amdgpu_profile_device_metric_sample_builder_t;
+
+// Appends |value| to |builder| unless the metric is already present.
+void iree_hal_amdgpu_profile_device_metric_sample_builder_append_u64(
+    iree_hal_amdgpu_profile_device_metric_sample_builder_t* builder,
+    uint64_t metric_id, uint64_t value);
+
+// Appends signed |value| to |builder| unless the metric is already present.
+void iree_hal_amdgpu_profile_device_metric_sample_builder_append_i64(
+    iree_hal_amdgpu_profile_device_metric_sample_builder_t* builder,
+    uint64_t metric_id, int64_t value);
+
+// Initializes a per-device metric source for |physical_device|.
+iree_status_t iree_hal_amdgpu_profile_device_metric_source_initialize(
+    iree_hal_amdgpu_physical_device_t* physical_device,
+    iree_allocator_t host_allocator,
+    iree_hal_amdgpu_profile_device_metric_source_t* out_source);
+
+// Deinitializes |source| and releases platform handles it owns.
+void iree_hal_amdgpu_profile_device_metric_source_deinitialize(
+    iree_hal_amdgpu_profile_device_metric_source_t* source);
+
+// Samples |source| into |builder|.
+iree_status_t iree_hal_amdgpu_profile_device_metric_source_sample(
+    iree_hal_amdgpu_profile_device_metric_source_t* source,
+    iree_hal_amdgpu_profile_device_metric_sample_builder_t* builder);
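+
+// Usage sketch (illustrative only; error handling and builder setup elided):
+//   iree_hal_amdgpu_profile_device_metric_source_t source;
+//   IREE_RETURN_IF_ERROR(
+//       iree_hal_amdgpu_profile_device_metric_source_initialize(
+//           physical_device, host_allocator, &source));
+//   iree_hal_amdgpu_profile_device_metric_sample_builder_t builder;
+//   // ... reset builder.record between samples ...
+//   IREE_RETURN_IF_ERROR(
+//       iree_hal_amdgpu_profile_device_metric_source_sample(&source,
+//                                                           &builder));
+//   iree_hal_amdgpu_profile_device_metric_source_deinitialize(&source);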
+
+#ifdef __cplusplus
+}  // extern "C"
+#endif  // __cplusplus
+
+#endif  // IREE_HAL_DRIVERS_AMDGPU_PROFILE_DEVICE_METRICS_SOURCE_H_
diff --git a/runtime/src/iree/hal/drivers/amdgpu/profile_events.c b/runtime/src/iree/hal/drivers/amdgpu/profile_events.c
new file mode 100644
index 0000000..beb7dc8
--- /dev/null
+++ b/runtime/src/iree/hal/drivers/amdgpu/profile_events.c
@@ -0,0 +1,254 @@
+// Copyright 2026 The IREE Authors
+//
+// Licensed under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+
+#include "iree/hal/drivers/amdgpu/profile_events.h"
+
+static void iree_hal_amdgpu_profile_event_stream_initialize(
+    iree_hal_amdgpu_profile_event_stream_t* stream) {
+  iree_slim_mutex_initialize(&stream->mutex);
+}
+
+static void iree_hal_amdgpu_profile_event_stream_deallocate(
+    iree_hal_amdgpu_profile_event_stream_t* stream,
+    iree_allocator_t host_allocator) {
+  iree_allocator_free(host_allocator, stream->ring.records);
+  memset(&stream->ring, 0, sizeof(stream->ring));
+}
+
+static void iree_hal_amdgpu_profile_event_stream_deinitialize(
+    iree_hal_amdgpu_profile_event_stream_t* stream,
+    iree_allocator_t host_allocator) {
+  iree_hal_amdgpu_profile_event_stream_deallocate(stream, host_allocator);
+  iree_slim_mutex_deinitialize(&stream->mutex);
+}
+
+static iree_status_t iree_hal_amdgpu_profile_event_stream_ensure_storage(
+    iree_hal_amdgpu_profile_event_stream_t* stream, iree_host_size_t event_size,
+    iree_host_size_t event_capacity, iree_allocator_t host_allocator) {
+  if (stream->ring.records) return iree_ok_status();
+  IREE_TRACE_ZONE_BEGIN(z0);
+  IREE_TRACE_ZONE_APPEND_VALUE_I64(z0, event_capacity);
+
+  if (IREE_UNLIKELY(event_capacity == 0 ||
+                    !iree_host_size_is_power_of_two(event_capacity))) {
+    IREE_RETURN_AND_END_ZONE_IF_ERROR(
+        z0, iree_make_status(IREE_STATUS_OUT_OF_RANGE,
+                             "AMDGPU profile event stream capacity must be a "
+                             "non-zero power of two (got %" PRIhsz ")",
+                             event_capacity));
+  }
+
+  iree_host_size_t storage_size = 0;
+  if (IREE_UNLIKELY(!iree_host_size_checked_mul(event_capacity, event_size,
+                                                &storage_size))) {
+    IREE_RETURN_AND_END_ZONE_IF_ERROR(
+        z0, iree_make_status(IREE_STATUS_OUT_OF_RANGE,
+                             "AMDGPU profile event stream storage size "
+                             "overflows"));
+  }
+
+  void* events = NULL;
+  iree_status_t status =
+      iree_allocator_malloc(host_allocator, storage_size, &events);
+  if (iree_status_is_ok(status)) {
+    memset(events, 0, storage_size);
+    iree_hal_profile_event_ring_initialize(events, event_size, event_capacity,
+                                           &stream->ring);
+  }
+
+  IREE_TRACE_ZONE_END(z0);
+  return status;
+}
+
+static void iree_hal_amdgpu_profile_event_stream_clear(
+    iree_hal_amdgpu_profile_event_stream_t* stream) {
+  iree_slim_mutex_lock(&stream->mutex);
+  iree_hal_profile_event_ring_clear(&stream->ring);
+  iree_slim_mutex_unlock(&stream->mutex);
+}
+
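+// Flushes are two-phase: a snapshot of the ring is taken under the lock, the
+// sink write happens outside the lock, and the snapshot is committed
+// (advancing the read position) only if the write succeeded. Records appended
+// while the sink is writing land after the snapshot and survive to the next
+// flush.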
+static iree_status_t iree_hal_amdgpu_profile_event_stream_write(
+    iree_hal_amdgpu_profile_event_stream_t* stream,
+    iree_hal_profile_sink_t* sink, iree_hal_profile_chunk_metadata_t metadata,
+    iree_allocator_t host_allocator) {
+  (void)host_allocator;
+  if (!sink || !stream->ring.records) return iree_ok_status();
+  IREE_TRACE_ZONE_BEGIN(z0);
+
+  iree_hal_profile_event_ring_snapshot_t snapshot;
+  iree_slim_mutex_lock(&stream->mutex);
+  iree_status_t status =
+      iree_hal_profile_event_ring_snapshot(&stream->ring, &snapshot);
+  iree_slim_mutex_unlock(&stream->mutex);
+
+  if (iree_status_is_ok(status) &&
+      (snapshot.record_count != 0 || snapshot.dropped_record_count != 0)) {
+    if (snapshot.dropped_record_count != 0) {
+      metadata.flags |= IREE_HAL_PROFILE_CHUNK_FLAG_TRUNCATED;
+      metadata.dropped_record_count = snapshot.dropped_record_count;
+    }
+
+    status = iree_hal_profile_sink_write(
+        sink, &metadata, snapshot.record_span_count,
+        snapshot.record_span_count ? snapshot.record_spans : NULL);
+    if (iree_status_is_ok(status)) {
+      iree_slim_mutex_lock(&stream->mutex);
+      iree_hal_profile_event_ring_commit_snapshot(&stream->ring, &snapshot);
+      iree_slim_mutex_unlock(&stream->mutex);
+    }
+  }
+
+  IREE_TRACE_ZONE_END(z0);
+  return status;
+}
+
+void iree_hal_amdgpu_profile_event_streams_initialize(
+    iree_hal_amdgpu_profile_event_streams_t* streams) {
+  IREE_ASSERT_ARGUMENT(streams);
+  iree_hal_amdgpu_profile_event_stream_initialize(&streams->memory.stream);
+  iree_hal_amdgpu_profile_event_stream_initialize(&streams->queue.stream);
+}
+
+void iree_hal_amdgpu_profile_event_streams_deinitialize(
+    iree_hal_amdgpu_profile_event_streams_t* streams,
+    iree_allocator_t host_allocator) {
+  if (!streams) return;
+  iree_hal_amdgpu_profile_event_stream_deinitialize(&streams->memory.stream,
+                                                    host_allocator);
+  streams->memory.next_allocation_id = 0;
+  iree_hal_amdgpu_profile_event_stream_deinitialize(&streams->queue.stream,
+                                                    host_allocator);
+}
+
+bool iree_hal_amdgpu_profile_event_streams_has_memory_storage(
+    const iree_hal_amdgpu_profile_event_streams_t* streams) {
+  return streams->memory.stream.ring.records != NULL;
+}
+
+bool iree_hal_amdgpu_profile_event_streams_has_queue_storage(
+    const iree_hal_amdgpu_profile_event_streams_t* streams) {
+  return streams->queue.stream.ring.records != NULL;
+}
+
+iree_status_t iree_hal_amdgpu_profile_event_streams_ensure_memory_storage(
+    iree_hal_amdgpu_profile_event_streams_t* streams,
+    iree_host_size_t event_capacity, iree_allocator_t host_allocator) {
+  return iree_hal_amdgpu_profile_event_stream_ensure_storage(
+      &streams->memory.stream, sizeof(iree_hal_profile_memory_event_t),
+      event_capacity, host_allocator);
+}
+
+iree_status_t iree_hal_amdgpu_profile_event_streams_ensure_queue_storage(
+    iree_hal_amdgpu_profile_event_streams_t* streams,
+    iree_host_size_t event_capacity, iree_allocator_t host_allocator) {
+  return iree_hal_amdgpu_profile_event_stream_ensure_storage(
+      &streams->queue.stream, sizeof(iree_hal_profile_queue_event_t),
+      event_capacity, host_allocator);
+}
+
+void iree_hal_amdgpu_profile_event_streams_clear_memory(
+    iree_hal_amdgpu_profile_event_streams_t* streams) {
+  iree_hal_amdgpu_profile_event_stream_clear(&streams->memory.stream);
+  streams->memory.next_allocation_id = 1;
+}
+
+void iree_hal_amdgpu_profile_event_streams_clear_queue(
+    iree_hal_amdgpu_profile_event_streams_t* streams) {
+  iree_hal_amdgpu_profile_event_stream_clear(&streams->queue.stream);
+}
+
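+// The allocation counter is guarded by the memory stream mutex so ids stay
+// unique across concurrent allocators; the session id active at allocation
+// time is returned so later events for this allocation can be filtered if
+// the session changes.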
+uint64_t iree_hal_amdgpu_profile_event_streams_allocate_memory_allocation_id(
+    iree_hal_amdgpu_profile_event_streams_t* streams,
+    uint64_t active_session_id, uint64_t* out_session_id) {
+  *out_session_id = active_session_id;
+  iree_slim_mutex_lock(&streams->memory.stream.mutex);
+  const uint64_t allocation_id = streams->memory.next_allocation_id++;
+  iree_slim_mutex_unlock(&streams->memory.stream.mutex);
+  return allocation_id;
+}
+
+bool iree_hal_amdgpu_profile_event_streams_record_memory_event(
+    iree_hal_amdgpu_profile_event_streams_t* streams,
+    uint64_t active_session_id, uint64_t session_id,
+    const iree_hal_profile_memory_event_t* event) {
+  bool recorded = false;
+  iree_hal_amdgpu_profile_event_stream_t* stream = &streams->memory.stream;
+  if (!stream->ring.records) return false;
+  iree_slim_mutex_lock(&stream->mutex);
+  const bool session_matches =
+      session_id == 0 || active_session_id == session_id;
+  if (session_matches) {
+    uint64_t event_position = 0;
+    uint64_t event_id = 0;
+    if (iree_hal_profile_event_ring_try_append(&stream->ring, &event_position,
+                                               &event_id)) {
+      iree_hal_profile_memory_event_t record = *event;
+      record.record_length = sizeof(record);
+      record.event_id = event_id;
+      if (record.host_time_ns == 0) {
+        record.host_time_ns = iree_time_now();
+      }
+      iree_hal_profile_memory_event_t* target =
+          (iree_hal_profile_memory_event_t*)
+              iree_hal_profile_event_ring_record_at(&stream->ring,
+                                                    event_position);
+      *target = record;
+      recorded = true;
+    }
+  }
+  iree_slim_mutex_unlock(&stream->mutex);
+  return recorded;
+}
+
+void iree_hal_amdgpu_profile_event_streams_record_queue_event(
+    iree_hal_amdgpu_profile_event_streams_t* streams,
+    const iree_hal_profile_queue_event_t* event) {
+  iree_hal_amdgpu_profile_event_stream_t* stream = &streams->queue.stream;
+  if (!stream->ring.records) return;
+  iree_slim_mutex_lock(&stream->mutex);
+  uint64_t event_position = 0;
+  uint64_t event_id = 0;
+  if (iree_hal_profile_event_ring_try_append(&stream->ring, &event_position,
+                                             &event_id)) {
+    iree_hal_profile_queue_event_t record = *event;
+    record.record_length = sizeof(record);
+    record.event_id = event_id;
+    if (record.host_time_ns == 0) {
+      record.host_time_ns = iree_time_now();
+    }
+    iree_hal_profile_queue_event_t* target =
+        (iree_hal_profile_queue_event_t*)iree_hal_profile_event_ring_record_at(
+            &stream->ring, event_position);
+    *target = record;
+  }
+  iree_slim_mutex_unlock(&stream->mutex);
+}
+
+iree_status_t iree_hal_amdgpu_profile_event_streams_write_memory(
+    iree_hal_amdgpu_profile_event_streams_t* streams,
+    iree_hal_profile_sink_t* sink, uint64_t session_id,
+    iree_allocator_t host_allocator) {
+  iree_hal_profile_chunk_metadata_t metadata =
+      iree_hal_profile_chunk_metadata_default();
+  metadata.content_type = IREE_HAL_PROFILE_CONTENT_TYPE_MEMORY_EVENTS;
+  metadata.name = iree_make_cstring_view("amdgpu.memory");
+  metadata.session_id = session_id;
+  return iree_hal_amdgpu_profile_event_stream_write(
+      &streams->memory.stream, sink, metadata, host_allocator);
+}
+
+iree_status_t iree_hal_amdgpu_profile_event_streams_write_queue(
+    iree_hal_amdgpu_profile_event_streams_t* streams,
+    iree_hal_profile_sink_t* sink, uint64_t session_id,
+    iree_allocator_t host_allocator) {
+  iree_hal_profile_chunk_metadata_t metadata =
+      iree_hal_profile_chunk_metadata_default();
+  metadata.content_type = IREE_HAL_PROFILE_CONTENT_TYPE_QUEUE_EVENTS;
+  metadata.name = iree_make_cstring_view("amdgpu.queue");
+  metadata.session_id = session_id;
+  return iree_hal_amdgpu_profile_event_stream_write(
+      &streams->queue.stream, sink, metadata, host_allocator);
+}
diff --git a/runtime/src/iree/hal/drivers/amdgpu/profile_events.h b/runtime/src/iree/hal/drivers/amdgpu/profile_events.h
new file mode 100644
index 0000000..59d5998
--- /dev/null
+++ b/runtime/src/iree/hal/drivers/amdgpu/profile_events.h
@@ -0,0 +1,123 @@
+// Copyright 2026 The IREE Authors
+//
+// Licensed under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+
+#ifndef IREE_HAL_DRIVERS_AMDGPU_PROFILE_EVENTS_H_
+#define IREE_HAL_DRIVERS_AMDGPU_PROFILE_EVENTS_H_
+
+#include "iree/base/api.h"
+#include "iree/base/threading/mutex.h"
+#include "iree/hal/api.h"
+#include "iree/hal/profile_sink.h"
+#include "iree/hal/utils/profile_event_ring.h"
+
+#ifdef __cplusplus
+extern "C" {
+#endif  // __cplusplus
+
+//===----------------------------------------------------------------------===//
+// iree_hal_amdgpu_profile_event_streams_t
+//===----------------------------------------------------------------------===//
+
+// Lossy fixed-capacity host-side profiling event stream.
+//
+// Producers append records until the stream reaches capacity. Once full, new
+// records are dropped (and counted) until a flush writes retained records to
+// the sink and advances the read position. The stream writes a metadata-only
+// TRUNCATED chunk if only dropped records are available.
+typedef struct iree_hal_amdgpu_profile_event_stream_t {
+  // Mutex protecting positions, dropped counts, and event id allocation.
+  iree_slim_mutex_t mutex;
+
+  // Lossy fixed-capacity host-side event ring.
+  iree_hal_profile_event_ring_t ring;
+} iree_hal_amdgpu_profile_event_stream_t;
+
+// Host-side profiling event streams owned by an AMDGPU logical device.
+typedef struct iree_hal_amdgpu_profile_event_streams_t {
+  // Memory lifecycle event stream and allocation id state.
+  struct {
+    // Lossy ring for iree_hal_profile_memory_event_t records.
+    iree_hal_amdgpu_profile_event_stream_t stream;
+
+    // Next nonzero allocation id assigned to profiled memory objects.
+    uint64_t next_allocation_id;
+  } memory;
+
+  // Queue operation event stream.
+  struct {
+    // Lossy ring for iree_hal_profile_queue_event_t records.
+    iree_hal_amdgpu_profile_event_stream_t stream;
+  } queue;
+} iree_hal_amdgpu_profile_event_streams_t;
+
+// Initializes stream mutexes in caller-owned zeroed storage.
+void iree_hal_amdgpu_profile_event_streams_initialize(
+    iree_hal_amdgpu_profile_event_streams_t* streams);
+
+// Deinitializes streams and releases all event storage.
+void iree_hal_amdgpu_profile_event_streams_deinitialize(
+    iree_hal_amdgpu_profile_event_streams_t* streams,
+    iree_allocator_t host_allocator);
+
+// Returns true when memory event storage is allocated.
+bool iree_hal_amdgpu_profile_event_streams_has_memory_storage(
+    const iree_hal_amdgpu_profile_event_streams_t* streams);
+
+// Returns true when queue event storage is allocated.
+bool iree_hal_amdgpu_profile_event_streams_has_queue_storage(
+    const iree_hal_amdgpu_profile_event_streams_t* streams);
+
+// Allocates memory event storage if not already allocated.
+iree_status_t iree_hal_amdgpu_profile_event_streams_ensure_memory_storage(
+    iree_hal_amdgpu_profile_event_streams_t* streams,
+    iree_host_size_t event_capacity, iree_allocator_t host_allocator);
+
+// Allocates queue event storage if not already allocated.
+iree_status_t iree_hal_amdgpu_profile_event_streams_ensure_queue_storage(
+    iree_hal_amdgpu_profile_event_streams_t* streams,
+    iree_host_size_t event_capacity, iree_allocator_t host_allocator);
+
+// Clears the memory event stream and resets memory event/allocation ids.
+void iree_hal_amdgpu_profile_event_streams_clear_memory(
+    iree_hal_amdgpu_profile_event_streams_t* streams);
+
+// Clears the queue event stream and resets queue event ids.
+void iree_hal_amdgpu_profile_event_streams_clear_queue(
+    iree_hal_amdgpu_profile_event_streams_t* streams);
+
+// Allocates a unique memory allocation id and returns, in |out_session_id|,
+// the session id that was active at allocation time.
+uint64_t iree_hal_amdgpu_profile_event_streams_allocate_memory_allocation_id(
+    iree_hal_amdgpu_profile_event_streams_t* streams,
+    uint64_t active_session_id, uint64_t* out_session_id);
+
+// Records one memory event when |session_id| is zero or matches
+// |active_session_id|. Returns true if the event was appended.
+bool iree_hal_amdgpu_profile_event_streams_record_memory_event(
+    iree_hal_amdgpu_profile_event_streams_t* streams,
+    uint64_t active_session_id, uint64_t session_id,
+    const iree_hal_profile_memory_event_t* event);
+
+// Records one queue event.
+void iree_hal_amdgpu_profile_event_streams_record_queue_event(
+    iree_hal_amdgpu_profile_event_streams_t* streams,
+    const iree_hal_profile_queue_event_t* event);
+
+// Writes pending memory events to |sink|.
+iree_status_t iree_hal_amdgpu_profile_event_streams_write_memory(
+    iree_hal_amdgpu_profile_event_streams_t* streams,
+    iree_hal_profile_sink_t* sink, uint64_t session_id,
+    iree_allocator_t host_allocator);
+
+// Writes pending queue events to |sink|.
+iree_status_t iree_hal_amdgpu_profile_event_streams_write_queue(
+    iree_hal_amdgpu_profile_event_streams_t* streams,
+    iree_hal_profile_sink_t* sink, uint64_t session_id,
+    iree_allocator_t host_allocator);
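+
+// Usage sketch (illustrative only; mirrors profile_events_test.cc):
+//   iree_hal_amdgpu_profile_event_streams_t streams;
+//   iree_hal_amdgpu_profile_event_streams_initialize(&streams);
+//   IREE_RETURN_IF_ERROR(
+//       iree_hal_amdgpu_profile_event_streams_ensure_queue_storage(
+//           &streams, /*event_capacity=*/1024, host_allocator));
+//   iree_hal_amdgpu_profile_event_streams_clear_queue(&streams);
+//   // ... hot paths call ..._record_queue_event as work is submitted ...
+//   IREE_RETURN_IF_ERROR(iree_hal_amdgpu_profile_event_streams_write_queue(
+//       &streams, sink, session_id, host_allocator));
+//   iree_hal_amdgpu_profile_event_streams_deinitialize(&streams,
+//                                                      host_allocator);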
+
+#ifdef __cplusplus
+}  // extern "C"
+#endif  // __cplusplus
+
+#endif  // IREE_HAL_DRIVERS_AMDGPU_PROFILE_EVENTS_H_
diff --git a/runtime/src/iree/hal/drivers/amdgpu/profile_events_test.cc b/runtime/src/iree/hal/drivers/amdgpu/profile_events_test.cc
new file mode 100644
index 0000000..7c9ea40
--- /dev/null
+++ b/runtime/src/iree/hal/drivers/amdgpu/profile_events_test.cc
@@ -0,0 +1,252 @@
+// Copyright 2026 The IREE Authors
+//
+// Licensed under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+
+#include "iree/hal/drivers/amdgpu/profile_events.h"
+
+#include <cstring>
+#include <utility>
+#include <vector>
+
+#include "iree/testing/gtest.h"
+#include "iree/testing/status_matchers.h"
+
+namespace iree::hal::amdgpu {
+namespace {
+
+struct CapturedChunk {
+  // Metadata copied from the sink write callback.
+  iree_hal_profile_chunk_metadata_t metadata =
+      iree_hal_profile_chunk_metadata_default();
+
+  // Concatenated iovec payload copied from the sink write callback.
+  std::vector<uint8_t> payload;
+};
+
+struct CapturingProfileSink {
+  // HAL resource header for the sink.
+  iree_hal_resource_t resource;
+
+  // Chunks copied from sink write callbacks.
+  std::vector<CapturedChunk> chunks;
+
+  // Optional stream used to force a queue-event drop during a write callback.
+  iree_hal_amdgpu_profile_event_streams_t* queue_streams_to_record = nullptr;
+
+  // Queue event used when |queue_streams_to_record| is set.
+  iree_hal_profile_queue_event_t queue_event_to_record = {};
+
+  // Number of queue events to append during the next write callback.
+  int queue_event_record_count_on_write = 0;
+};
+
+static CapturingProfileSink* CapturingProfileSinkCast(
+    iree_hal_profile_sink_t* sink) {
+  return reinterpret_cast<CapturingProfileSink*>(sink);
+}
+
+static void CapturingProfileSinkDestroy(iree_hal_profile_sink_t* sink) {
+  (void)sink;
+}
+
+static iree_status_t CapturingProfileSinkBeginSession(
+    iree_hal_profile_sink_t* sink,
+    const iree_hal_profile_chunk_metadata_t* metadata) {
+  (void)sink;
+  (void)metadata;
+  return iree_ok_status();
+}
+
+static iree_status_t CapturingProfileSinkWrite(
+    iree_hal_profile_sink_t* sink,
+    const iree_hal_profile_chunk_metadata_t* metadata,
+    iree_host_size_t iovec_count, const iree_const_byte_span_t* iovecs) {
+  auto* captured_sink = CapturingProfileSinkCast(sink);
+  CapturedChunk chunk;
+  chunk.metadata = *metadata;
+  for (iree_host_size_t i = 0; i < iovec_count; ++i) {
+    const uint8_t* source = iovecs[i].data;
+    chunk.payload.insert(chunk.payload.end(), source,
+                         source + iovecs[i].data_length);
+  }
+  captured_sink->chunks.push_back(std::move(chunk));
+  while (captured_sink->queue_event_record_count_on_write > 0) {
+    --captured_sink->queue_event_record_count_on_write;
+    iree_hal_amdgpu_profile_event_streams_record_queue_event(
+        captured_sink->queue_streams_to_record,
+        &captured_sink->queue_event_to_record);
+  }
+  return iree_ok_status();
+}
+
+static iree_status_t CapturingProfileSinkEndSession(
+    iree_hal_profile_sink_t* sink,
+    const iree_hal_profile_chunk_metadata_t* metadata,
+    iree_status_code_t session_status_code) {
+  (void)sink;
+  (void)metadata;
+  (void)session_status_code;
+  return iree_ok_status();
+}
+
+static const iree_hal_profile_sink_vtable_t kCapturingProfileSinkVTable = {
+    .destroy = CapturingProfileSinkDestroy,
+    .begin_session = CapturingProfileSinkBeginSession,
+    .write = CapturingProfileSinkWrite,
+    .end_session = CapturingProfileSinkEndSession,
+};
+
+static void CapturingProfileSinkInitialize(CapturingProfileSink* sink) {
+  iree_hal_resource_initialize(&kCapturingProfileSinkVTable, &sink->resource);
+}
+
+static iree_hal_profile_sink_t* CapturingProfileSinkAsBase(
+    CapturingProfileSink* sink) {
+  return reinterpret_cast<iree_hal_profile_sink_t*>(sink);
+}
+
+template <typename T>
+static std::vector<T> DecodeRecords(const CapturedChunk& chunk) {
+  EXPECT_EQ(0u, chunk.payload.size() % sizeof(T));
+  std::vector<T> records(chunk.payload.size() / sizeof(T));
+  memcpy(records.data(), chunk.payload.data(), chunk.payload.size());
+  return records;
+}
+
+class ProfileEventsTest : public ::testing::Test {
+ protected:
+  void SetUp() override {
+    iree_hal_amdgpu_profile_event_streams_initialize(&streams_);
+    CapturingProfileSinkInitialize(&sink_);
+  }
+
+  void TearDown() override {
+    iree_hal_amdgpu_profile_event_streams_deinitialize(&streams_,
+                                                       iree_allocator_system());
+  }
+
+  iree_hal_amdgpu_profile_event_streams_t streams_ = {};
+  CapturingProfileSink sink_ = {};
+};
+
+TEST_F(ProfileEventsTest, MemoryEventsPreserveRecordsAndReportDrops) {
+  IREE_ASSERT_OK(iree_hal_amdgpu_profile_event_streams_ensure_memory_storage(
+      &streams_, /*event_capacity=*/2, iree_allocator_system()));
+  iree_hal_amdgpu_profile_event_streams_clear_memory(&streams_);
+
+  uint64_t session_id = 0;
+  const uint64_t allocation_id =
+      iree_hal_amdgpu_profile_event_streams_allocate_memory_allocation_id(
+          &streams_, /*active_session_id=*/7, &session_id);
+  EXPECT_EQ(7u, session_id);
+  EXPECT_EQ(1u, allocation_id);
+
+  iree_hal_profile_memory_event_t event =
+      iree_hal_profile_memory_event_default();
+  event.type = IREE_HAL_PROFILE_MEMORY_EVENT_TYPE_BUFFER_ALLOCATE;
+  event.allocation_id = allocation_id;
+  event.length = 4096;
+
+  EXPECT_TRUE(iree_hal_amdgpu_profile_event_streams_record_memory_event(
+      &streams_, /*active_session_id=*/7, session_id, &event));
+  event.type = IREE_HAL_PROFILE_MEMORY_EVENT_TYPE_BUFFER_FREE;
+  EXPECT_TRUE(iree_hal_amdgpu_profile_event_streams_record_memory_event(
+      &streams_, /*active_session_id=*/7, session_id, &event));
+  EXPECT_FALSE(iree_hal_amdgpu_profile_event_streams_record_memory_event(
+      &streams_, /*active_session_id=*/7, session_id, &event));
+
+  IREE_ASSERT_OK(iree_hal_amdgpu_profile_event_streams_write_memory(
+      &streams_, CapturingProfileSinkAsBase(&sink_), session_id,
+      iree_allocator_system()));
+
+  ASSERT_EQ(1u, sink_.chunks.size());
+  const CapturedChunk& chunk = sink_.chunks.front();
+  EXPECT_TRUE(
+      iree_string_view_equal(IREE_HAL_PROFILE_CONTENT_TYPE_MEMORY_EVENTS,
+                             chunk.metadata.content_type));
+  EXPECT_EQ(7u, chunk.metadata.session_id);
+  EXPECT_TRUE(iree_all_bits_set(chunk.metadata.flags,
+                                IREE_HAL_PROFILE_CHUNK_FLAG_TRUNCATED));
+  EXPECT_EQ(1u, chunk.metadata.dropped_record_count);
+
+  std::vector<iree_hal_profile_memory_event_t> records =
+      DecodeRecords<iree_hal_profile_memory_event_t>(chunk);
+  ASSERT_EQ(2u, records.size());
+  EXPECT_EQ(IREE_HAL_PROFILE_MEMORY_EVENT_TYPE_BUFFER_ALLOCATE,
+            records[0].type);
+  EXPECT_EQ(IREE_HAL_PROFILE_MEMORY_EVENT_TYPE_BUFFER_FREE, records[1].type);
+  EXPECT_EQ(1u, records[0].event_id);
+  EXPECT_EQ(2u, records[1].event_id);
+  EXPECT_NE(0, records[0].host_time_ns);
+}
+
+TEST_F(ProfileEventsTest, MemoryEventsSkipMismatchedSession) {
+  IREE_ASSERT_OK(iree_hal_amdgpu_profile_event_streams_ensure_memory_storage(
+      &streams_, /*event_capacity=*/2, iree_allocator_system()));
+  iree_hal_amdgpu_profile_event_streams_clear_memory(&streams_);
+
+  iree_hal_profile_memory_event_t event =
+      iree_hal_profile_memory_event_default();
+  event.type = IREE_HAL_PROFILE_MEMORY_EVENT_TYPE_BUFFER_ALLOCATE;
+  EXPECT_FALSE(iree_hal_amdgpu_profile_event_streams_record_memory_event(
+      &streams_, /*active_session_id=*/7, /*session_id=*/9, &event));
+
+  IREE_ASSERT_OK(iree_hal_amdgpu_profile_event_streams_write_memory(
+      &streams_, CapturingProfileSinkAsBase(&sink_), /*session_id=*/7,
+      iree_allocator_system()));
+  EXPECT_TRUE(sink_.chunks.empty());
+}
+
+TEST_F(ProfileEventsTest, QueueEventsReportRetainedAndMetadataOnlyDrops) {
+  IREE_ASSERT_OK(iree_hal_amdgpu_profile_event_streams_ensure_queue_storage(
+      &streams_, /*event_capacity=*/1, iree_allocator_system()));
+  iree_hal_amdgpu_profile_event_streams_clear_queue(&streams_);
+
+  iree_hal_profile_queue_event_t event = iree_hal_profile_queue_event_default();
+  event.type = IREE_HAL_PROFILE_QUEUE_EVENT_TYPE_DISPATCH;
+  event.stream_id = 42;
+  iree_hal_amdgpu_profile_event_streams_record_queue_event(&streams_, &event);
+  iree_hal_amdgpu_profile_event_streams_record_queue_event(&streams_, &event);
+
+  sink_.queue_streams_to_record = &streams_;
+  sink_.queue_event_to_record = event;
+  sink_.queue_event_record_count_on_write = 1;
+  IREE_ASSERT_OK(iree_hal_amdgpu_profile_event_streams_write_queue(
+      &streams_, CapturingProfileSinkAsBase(&sink_), /*session_id=*/11,
+      iree_allocator_system()));
+
+  ASSERT_EQ(1u, sink_.chunks.size());
+  const CapturedChunk& chunk = sink_.chunks.front();
+  EXPECT_TRUE(iree_string_view_equal(IREE_HAL_PROFILE_CONTENT_TYPE_QUEUE_EVENTS,
+                                     chunk.metadata.content_type));
+  EXPECT_EQ(11u, chunk.metadata.session_id);
+  EXPECT_TRUE(iree_all_bits_set(chunk.metadata.flags,
+                                IREE_HAL_PROFILE_CHUNK_FLAG_TRUNCATED));
+  EXPECT_EQ(1u, chunk.metadata.dropped_record_count);
+
+  std::vector<iree_hal_profile_queue_event_t> records =
+      DecodeRecords<iree_hal_profile_queue_event_t>(chunk);
+  ASSERT_EQ(1u, records.size());
+  EXPECT_EQ(IREE_HAL_PROFILE_QUEUE_EVENT_TYPE_DISPATCH, records[0].type);
+  EXPECT_EQ(1u, records[0].event_id);
+  EXPECT_EQ(42u, records[0].stream_id);
+
+  IREE_ASSERT_OK(iree_hal_amdgpu_profile_event_streams_write_queue(
+      &streams_, CapturingProfileSinkAsBase(&sink_), /*session_id=*/11,
+      iree_allocator_system()));
+
+  ASSERT_EQ(2u, sink_.chunks.size());
+  const CapturedChunk& metadata_only_chunk = sink_.chunks.back();
+  EXPECT_TRUE(
+      iree_string_view_equal(IREE_HAL_PROFILE_CONTENT_TYPE_QUEUE_EVENTS,
+                             metadata_only_chunk.metadata.content_type));
+  EXPECT_TRUE(iree_all_bits_set(metadata_only_chunk.metadata.flags,
+                                IREE_HAL_PROFILE_CHUNK_FLAG_TRUNCATED));
+  EXPECT_EQ(1u, metadata_only_chunk.metadata.dropped_record_count);
+  EXPECT_TRUE(metadata_only_chunk.payload.empty());
+}
+
+}  // namespace
+}  // namespace iree::hal::amdgpu
diff --git a/runtime/src/iree/hal/drivers/amdgpu/profile_metadata.c b/runtime/src/iree/hal/drivers/amdgpu/profile_metadata.c
new file mode 100644
index 0000000..e9f4fea
--- /dev/null
+++ b/runtime/src/iree/hal/drivers/amdgpu/profile_metadata.c
@@ -0,0 +1,1146 @@
+// Copyright 2026 The IREE Authors
+//
+// Licensed under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+
+#include "iree/hal/drivers/amdgpu/profile_metadata.h"
+
+#include <string.h>
+
+#include "iree/base/alignment.h"
+
+typedef struct iree_hal_amdgpu_profile_metadata_snapshot_t {
+  // Host allocator used for snapshot copies.
+  iree_allocator_t host_allocator;
+  // Executable records copied from the registry.
+  iree_hal_profile_executable_record_t* executable_records;
+  // Number of executable records in |executable_records|.
+  iree_host_size_t executable_record_count;
+  // Packed executable code-object records copied from the registry.
+  uint8_t* executable_code_object_record_data;
+  // Byte length of |executable_code_object_record_data|.
+  iree_host_size_t executable_code_object_record_data_length;
+  // Executable code-object load records copied from the registry.
+  iree_hal_profile_executable_code_object_load_record_t*
+      executable_code_object_load_records;
+  // Number of executable code-object load records in
+  // |executable_code_object_load_records|.
+  iree_host_size_t executable_code_object_load_record_count;
+  // Packed executable export records copied from the registry.
+  uint8_t* executable_export_record_data;
+  // Byte length of |executable_export_record_data|.
+  iree_host_size_t executable_export_record_data_length;
+  // Command-buffer records copied from the registry.
+  iree_hal_profile_command_buffer_record_t* command_buffer_records;
+  // Number of command-buffer records in |command_buffer_records|.
+  iree_host_size_t command_buffer_record_count;
+  // Command-operation records copied from the registry.
+  iree_hal_profile_command_operation_record_t* command_operation_records;
+  // Number of command-operation records in |command_operation_records|.
+  iree_host_size_t command_operation_record_count;
+  // Cursor position after this snapshot's copied records.
+  iree_hal_amdgpu_profile_metadata_cursor_t end_cursor;
+} iree_hal_amdgpu_profile_metadata_snapshot_t;
+
+typedef struct iree_hal_amdgpu_profile_hash64_state_t {
+  // First SipHash state word.
+  uint64_t v0;
+  // Second SipHash state word.
+  uint64_t v1;
+  // Third SipHash state word.
+  uint64_t v2;
+  // Fourth SipHash state word.
+  uint64_t v3;
+  // Total number of bytes absorbed by the hash.
+  uint64_t total_length;
+  // Partial input bytes waiting to form the next 64-bit word.
+  uint8_t tail[8];
+  // Number of valid bytes in |tail|.
+  uint8_t tail_length;
+} iree_hal_amdgpu_profile_hash64_state_t;
+
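+// Incremental SipHash-2-4 (two compression rounds per 64-bit word, four
+// finalization rounds) seeded with the reference constants. Two independently
+// keyed instances are combined by the hash128 helpers below to produce a
+// 128-bit content digest.
+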
+static uint64_t iree_hal_amdgpu_profile_hash64_rotate_left(uint64_t value,
+                                                           int bit_count) {
+  return (value << bit_count) | (value >> (64 - bit_count));
+}
+
+static void iree_hal_amdgpu_profile_hash64_round(
+    iree_hal_amdgpu_profile_hash64_state_t* state) {
+  state->v0 += state->v1;
+  state->v1 = iree_hal_amdgpu_profile_hash64_rotate_left(state->v1, 13);
+  state->v1 ^= state->v0;
+  state->v0 = iree_hal_amdgpu_profile_hash64_rotate_left(state->v0, 32);
+  state->v2 += state->v3;
+  state->v3 = iree_hal_amdgpu_profile_hash64_rotate_left(state->v3, 16);
+  state->v3 ^= state->v2;
+  state->v0 += state->v3;
+  state->v3 = iree_hal_amdgpu_profile_hash64_rotate_left(state->v3, 21);
+  state->v3 ^= state->v0;
+  state->v2 += state->v1;
+  state->v1 = iree_hal_amdgpu_profile_hash64_rotate_left(state->v1, 17);
+  state->v1 ^= state->v2;
+  state->v2 = iree_hal_amdgpu_profile_hash64_rotate_left(state->v2, 32);
+}
+
+static void iree_hal_amdgpu_profile_hash64_initialize(
+    uint64_t key0, uint64_t key1,
+    iree_hal_amdgpu_profile_hash64_state_t* out_state) {
+  memset(out_state, 0, sizeof(*out_state));
+  out_state->v0 = UINT64_C(0x736f6d6570736575) ^ key0;
+  out_state->v1 = UINT64_C(0x646f72616e646f6d) ^ key1;
+  out_state->v2 = UINT64_C(0x6c7967656e657261) ^ key0;
+  out_state->v3 = UINT64_C(0x7465646279746573) ^ key1;
+}
+
+static void iree_hal_amdgpu_profile_hash64_compress(
+    iree_hal_amdgpu_profile_hash64_state_t* state, uint64_t word) {
+  state->v3 ^= word;
+  for (int i = 0; i < 2; ++i) {
+    iree_hal_amdgpu_profile_hash64_round(state);
+  }
+  state->v0 ^= word;
+}
+
+static void iree_hal_amdgpu_profile_hash64_append(
+    iree_hal_amdgpu_profile_hash64_state_t* state, const void* data,
+    iree_host_size_t data_length) {
+  if (data_length == 0) return;
+
+  const uint8_t* cursor = (const uint8_t*)data;
+  const uint8_t* const end = cursor + data_length;
+  state->total_length += data_length;
+
+  if (state->tail_length != 0) {
+    while (cursor < end && state->tail_length < sizeof(state->tail)) {
+      state->tail[state->tail_length++] = *cursor++;
+    }
+    if (state->tail_length == sizeof(state->tail)) {
+      iree_hal_amdgpu_profile_hash64_compress(
+          state, iree_unaligned_load_le_u64((const uint64_t*)state->tail));
+      state->tail_length = 0;
+    }
+  }
+
+  const iree_host_size_t remaining_length = (iree_host_size_t)(end - cursor);
+  const uint8_t* const word_end =
+      cursor + remaining_length - (remaining_length % sizeof(uint64_t));
+  for (; cursor < word_end; cursor += sizeof(uint64_t)) {
+    iree_hal_amdgpu_profile_hash64_compress(
+        state, iree_unaligned_load_le_u64((const uint64_t*)cursor));
+  }
+
+  while (cursor < end) {
+    state->tail[state->tail_length++] = *cursor++;
+  }
+}
+
+static uint64_t iree_hal_amdgpu_profile_hash64_finalize(
+    iree_hal_amdgpu_profile_hash64_state_t* state) {
+  uint64_t final_word = (state->total_length & 0xFFu) << 56;
+  for (uint8_t i = 0; i < state->tail_length; ++i) {
+    final_word |= ((uint64_t)state->tail[i]) << (i * 8);
+  }
+
+  state->v3 ^= final_word;
+  for (int i = 0; i < 2; ++i) {
+    iree_hal_amdgpu_profile_hash64_round(state);
+  }
+  state->v0 ^= final_word;
+  state->v2 ^= 0xFFu;
+  for (int i = 0; i < 4; ++i) {
+    iree_hal_amdgpu_profile_hash64_round(state);
+  }
+  return state->v0 ^ state->v1 ^ state->v2 ^ state->v3;
+}
+
+static void iree_hal_amdgpu_profile_hash128_initialize(
+    iree_hal_amdgpu_profile_hash64_state_t out_states[2]) {
+  iree_hal_amdgpu_profile_hash64_initialize(UINT64_C(0x0706050403020100),
+                                            UINT64_C(0x0f0e0d0c0b0a0908),
+                                            &out_states[0]);
+  iree_hal_amdgpu_profile_hash64_initialize(UINT64_C(0x1716151413121110),
+                                            UINT64_C(0x1f1e1d1c1b1a1918),
+                                            &out_states[1]);
+}
+
+static void iree_hal_amdgpu_profile_hash128_append(
+    iree_hal_amdgpu_profile_hash64_state_t states[2], const void* data,
+    iree_host_size_t data_length) {
+  iree_hal_amdgpu_profile_hash64_append(&states[0], data, data_length);
+  iree_hal_amdgpu_profile_hash64_append(&states[1], data, data_length);
+}
+
+static void iree_hal_amdgpu_profile_hash128_append_u16(
+    iree_hal_amdgpu_profile_hash64_state_t states[2], uint16_t value) {
+  uint16_t storage = 0;
+  iree_unaligned_store_le_u16(&storage, value);
+  iree_hal_amdgpu_profile_hash128_append(states, &storage, sizeof(storage));
+}
+
+static void iree_hal_amdgpu_profile_hash128_append_u32(
+    iree_hal_amdgpu_profile_hash64_state_t states[2], uint32_t value) {
+  uint32_t storage = 0;
+  iree_unaligned_store_le_u32(&storage, value);
+  iree_hal_amdgpu_profile_hash128_append(states, &storage, sizeof(storage));
+}
+
+static void iree_hal_amdgpu_profile_hash128_append_u64(
+    iree_hal_amdgpu_profile_hash64_state_t states[2], uint64_t value) {
+  uint64_t storage = 0;
+  iree_unaligned_store_le_u64(&storage, value);
+  iree_hal_amdgpu_profile_hash128_append(states, &storage, sizeof(storage));
+}
+
+static void iree_hal_amdgpu_profile_hash128_finalize(
+    iree_hal_amdgpu_profile_hash64_state_t states[2], uint64_t out_hash[2]) {
+  out_hash[0] = iree_hal_amdgpu_profile_hash64_finalize(&states[0]);
+  out_hash[1] = iree_hal_amdgpu_profile_hash64_finalize(&states[1]);
+}
+
+void iree_hal_amdgpu_profile_metadata_initialize(
+    iree_allocator_t host_allocator,
+    iree_hal_amdgpu_profile_metadata_registry_t* out_registry) {
+  memset(out_registry, 0, sizeof(*out_registry));
+  out_registry->host_allocator = host_allocator;
+  out_registry->next_executable_id = 1;
+  out_registry->next_command_buffer_id = 1;
+  iree_slim_mutex_initialize(&out_registry->mutex);
+}
+
+void iree_hal_amdgpu_profile_metadata_deinitialize(
+    iree_hal_amdgpu_profile_metadata_registry_t* registry) {
+  iree_allocator_t host_allocator = registry->host_allocator;
+  iree_allocator_free(host_allocator, registry->command_operation_records);
+  iree_allocator_free(host_allocator, registry->command_buffer_records);
+  iree_allocator_free(host_allocator, registry->executable_export_record_data);
+  iree_allocator_free(host_allocator,
+                      registry->executable_code_object_load_records);
+  iree_allocator_free(host_allocator,
+                      registry->executable_code_object_record_data);
+  iree_allocator_free(host_allocator, registry->executable_records);
+  iree_slim_mutex_deinitialize(&registry->mutex);
+  memset(registry, 0, sizeof(*registry));
+}
+
+void iree_hal_amdgpu_profile_metadata_hash_code_object(
+    iree_const_byte_span_t code_object_data, uint64_t out_hash[2]) {
+  iree_hal_amdgpu_profile_hash64_state_t states[2];
+  iree_hal_amdgpu_profile_hash128_initialize(states);
+  iree_hal_amdgpu_profile_hash128_append(states, code_object_data.data,
+                                         code_object_data.data_length);
+  iree_hal_amdgpu_profile_hash128_finalize(states, out_hash);
+}
+
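+// Derives a stable pipeline identity from the owning code object hash, the
+// export ordinal, the dispatch signature (constant/binding counts and
+// workgroup size), and the length-prefixed export name.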
+static void iree_hal_amdgpu_profile_metadata_hash_pipeline(
+    const uint64_t code_object_hash[2], uint32_t export_ordinal,
+    const iree_hal_amdgpu_device_kernel_args_t* host_kernel_args,
+    iree_string_view_t export_name, uint64_t out_hash[2]) {
+  iree_hal_amdgpu_profile_hash64_state_t states[2];
+  iree_hal_amdgpu_profile_hash128_initialize(states);
+  iree_hal_amdgpu_profile_hash128_append_u64(states, code_object_hash[0]);
+  iree_hal_amdgpu_profile_hash128_append_u64(states, code_object_hash[1]);
+  iree_hal_amdgpu_profile_hash128_append_u32(states, export_ordinal);
+  iree_hal_amdgpu_profile_hash128_append_u16(states,
+                                             host_kernel_args->constant_count);
+  iree_hal_amdgpu_profile_hash128_append_u16(states,
+                                             host_kernel_args->binding_count);
+  iree_hal_amdgpu_profile_hash128_append_u32(
+      states, host_kernel_args->workgroup_size[0]);
+  iree_hal_amdgpu_profile_hash128_append_u32(
+      states, host_kernel_args->workgroup_size[1]);
+  iree_hal_amdgpu_profile_hash128_append_u32(
+      states, host_kernel_args->workgroup_size[2]);
+  const uint64_t export_name_length = export_name.size;
+  iree_hal_amdgpu_profile_hash128_append_u64(states, export_name_length);
+  iree_hal_amdgpu_profile_hash128_append(states, export_name.data,
+                                         export_name.size);
+  iree_hal_amdgpu_profile_hash128_finalize(states, out_hash);
+}
+
+static iree_status_t iree_hal_amdgpu_profile_metadata_export_record_length(
+    iree_string_view_t name, iree_host_size_t* out_record_length) {
+  *out_record_length = 0;
+  if (IREE_UNLIKELY(name.size > UINT32_MAX)) {
+    return iree_make_status(IREE_STATUS_OUT_OF_RANGE,
+                            "profile executable export name is too long");
+  }
+  iree_host_size_t record_length = 0;
+  IREE_RETURN_IF_ERROR(IREE_STRUCT_LAYOUT(
+      sizeof(iree_hal_profile_executable_export_record_t), &record_length,
+      IREE_STRUCT_FIELD(name.size, uint8_t, NULL)));
+  if (IREE_UNLIKELY(record_length > UINT32_MAX)) {
+    return iree_make_status(
+        IREE_STATUS_OUT_OF_RANGE,
+        "profile executable export record length exceeds uint32_t");
+  }
+  *out_record_length = record_length;
+  return iree_ok_status();
+}
+
+static iree_status_t iree_hal_amdgpu_profile_metadata_export_data_length(
+    iree_host_size_t export_count,
+    const iree_hal_executable_export_info_t* export_infos,
+    const iree_host_size_t* export_parameter_offsets,
+    iree_host_size_t* out_data_length) {
+  *out_data_length = 0;
+  iree_host_size_t data_length = 0;
+  iree_status_t status = iree_ok_status();
+  for (iree_host_size_t i = 0; i < export_count && iree_status_is_ok(status);
+       ++i) {
+    iree_host_size_t record_length = 0;
+    status = iree_hal_amdgpu_profile_metadata_export_record_length(
+        export_infos[i].name, &record_length);
+    if (iree_status_is_ok(status) &&
+        IREE_UNLIKELY(!iree_host_size_checked_add(data_length, record_length,
+                                                  &data_length))) {
+      status = iree_make_status(
+          IREE_STATUS_OUT_OF_RANGE,
+          "profile executable export metadata length overflow");
+    }
+    const iree_host_size_t parameter_count =
+        export_parameter_offsets[i + 1] - export_parameter_offsets[i];
+    if (iree_status_is_ok(status) &&
+        IREE_UNLIKELY(parameter_count > UINT32_MAX)) {
+      status = iree_make_status(
+          IREE_STATUS_OUT_OF_RANGE,
+          "profile executable export parameter count exceeds uint32_t");
+    }
+  }
+  if (iree_status_is_ok(status)) {
+    *out_data_length = data_length;
+  }
+  return status;
+}
+
+static iree_status_t iree_hal_amdgpu_profile_metadata_append_export_records(
+    uint64_t executable_id, iree_host_size_t export_count,
+    const iree_hal_executable_export_info_t* export_infos,
+    const iree_host_size_t* export_parameter_offsets,
+    const uint64_t code_object_hash[2],
+    const iree_hal_amdgpu_device_kernel_args_t* host_kernel_args,
+    uint8_t* target_data) {
+  uint8_t* cursor = target_data;
+  for (iree_host_size_t i = 0; i < export_count; ++i) {
+    const iree_string_view_t name = export_infos[i].name;
+    iree_host_size_t record_length = 0;
+    IREE_RETURN_IF_ERROR(iree_hal_amdgpu_profile_metadata_export_record_length(
+        name, &record_length));
+
+    iree_hal_profile_executable_export_record_t record =
+        iree_hal_profile_executable_export_record_default();
+    record.record_length = (uint32_t)record_length;
+    record.executable_id = executable_id;
+    record.export_ordinal = (uint32_t)i;
+    record.constant_count = host_kernel_args[i].constant_count;
+    record.binding_count = host_kernel_args[i].binding_count;
+    record.parameter_count = (uint32_t)(export_parameter_offsets[i + 1] -
+                                        export_parameter_offsets[i]);
+    record.workgroup_size[0] = host_kernel_args[i].workgroup_size[0];
+    record.workgroup_size[1] = host_kernel_args[i].workgroup_size[1];
+    record.workgroup_size[2] = host_kernel_args[i].workgroup_size[2];
+    if (code_object_hash) {
+      record.flags |= IREE_HAL_PROFILE_EXECUTABLE_EXPORT_FLAG_PIPELINE_HASH;
+      iree_hal_amdgpu_profile_metadata_hash_pipeline(
+          code_object_hash, record.export_ordinal, &host_kernel_args[i], name,
+          record.pipeline_hash);
+    }
+    record.name_length = (uint32_t)name.size;
+
+    memcpy(cursor, &record, sizeof(record));
+    if (name.size > 0) {
+      memcpy(cursor + sizeof(record), name.data, name.size);
+    }
+    cursor += record_length;
+  }
+  return iree_ok_status();
+}
+
+static iree_status_t
+iree_hal_amdgpu_profile_metadata_code_object_record_data_length(
+    iree_const_byte_span_t code_object_data,
+    iree_host_size_t* out_data_length) {
+  *out_data_length = 0;
+  if (IREE_UNLIKELY(!code_object_data.data ||
+                    code_object_data.data_length == 0)) {
+    return iree_make_status(IREE_STATUS_INVALID_ARGUMENT,
+                            "profile executable code-object data is required");
+  }
+  iree_host_size_t data_length = 0;
+  IREE_RETURN_IF_ERROR(IREE_STRUCT_LAYOUT(
+      sizeof(iree_hal_profile_executable_code_object_record_t), &data_length,
+      IREE_STRUCT_FIELD(code_object_data.data_length, uint8_t, NULL)));
+  if (IREE_UNLIKELY(data_length > UINT32_MAX)) {
+    return iree_make_status(
+        IREE_STATUS_OUT_OF_RANGE,
+        "profile executable code-object record length exceeds uint32_t");
+  }
+  *out_data_length = data_length;
+  return iree_ok_status();
+}
+
+static void iree_hal_amdgpu_profile_metadata_append_code_object_record(
+    uint64_t executable_id, iree_const_byte_span_t code_object_data,
+    const uint64_t code_object_hash[2], uint8_t* target_data) {
+  iree_hal_profile_executable_code_object_record_t record =
+      iree_hal_profile_executable_code_object_record_default();
+  record.record_length =
+      (uint32_t)(sizeof(record) + code_object_data.data_length);
+  record.executable_id = executable_id;
+  record.code_object_id = executable_id;
+  record.data_length = code_object_data.data_length;
+  if (code_object_hash) {
+    record.flags |=
+        IREE_HAL_PROFILE_EXECUTABLE_CODE_OBJECT_FLAG_CODE_OBJECT_HASH;
+    record.code_object_hash[0] = code_object_hash[0];
+    record.code_object_hash[1] = code_object_hash[1];
+  }
+  memcpy(target_data, &record, sizeof(record));
+  memcpy(target_data + sizeof(record), code_object_data.data,
+         code_object_data.data_length);
+}
+
+static void iree_hal_amdgpu_profile_metadata_append_code_object_load_records(
+    uint64_t executable_id, iree_host_size_t code_object_load_info_count,
+    const iree_hal_amdgpu_profile_code_object_load_info_t*
+        code_object_load_infos,
+    iree_hal_profile_executable_code_object_load_record_t* target_records) {
+  for (iree_host_size_t i = 0; i < code_object_load_info_count; ++i) {
+    iree_hal_profile_executable_code_object_load_record_t record =
+        iree_hal_profile_executable_code_object_load_record_default();
+    record.physical_device_ordinal =
+        code_object_load_infos[i].physical_device_ordinal;
+    record.executable_id = executable_id;
+    record.code_object_id = executable_id;
+    record.load_delta = code_object_load_infos[i].load_delta;
+    record.load_size = code_object_load_infos[i].load_size;
+    target_records[i] = record;
+  }
+}
+
+static bool iree_hal_amdgpu_profile_metadata_has_executable_locked(
+    const iree_hal_amdgpu_profile_metadata_registry_t* registry,
+    uint64_t executable_id) {
+  for (iree_host_size_t i = 0; i < registry->executable_record_count; ++i) {
+    if (registry->executable_records[i].executable_id == executable_id) {
+      return true;
+    }
+  }
+  return false;
+}
+
+static bool iree_hal_amdgpu_profile_metadata_has_code_object_locked(
+    const iree_hal_amdgpu_profile_metadata_registry_t* registry,
+    uint64_t executable_id) {
+  if (registry->executable_code_object_record_data_length == 0) return false;
+  const uint8_t* current = registry->executable_code_object_record_data;
+  const uint8_t* end =
+      current + registry->executable_code_object_record_data_length;
+  while (current < end) {
+    const iree_hal_profile_executable_code_object_record_t* record =
+        (const iree_hal_profile_executable_code_object_record_t*)current;
+    if (record->executable_id == executable_id) return true;
+    current += record->record_length;
+  }
+  return false;
+}
+
+iree_status_t iree_hal_amdgpu_profile_metadata_register_executable(
+    iree_hal_amdgpu_profile_metadata_registry_t* registry,
+    iree_host_size_t export_count,
+    const iree_hal_executable_export_info_t* export_infos,
+    const iree_host_size_t* export_parameter_offsets,
+    const uint64_t code_object_hash[2],
+    const iree_hal_amdgpu_device_kernel_args_t* host_kernel_args,
+    uint64_t* out_executable_id) {
+  IREE_ASSERT_ARGUMENT(out_executable_id);
+  *out_executable_id = 0;
+  if (IREE_UNLIKELY(
+          export_count > 0 &&
+          (!export_infos || !export_parameter_offsets || !host_kernel_args))) {
+    return iree_make_status(IREE_STATUS_INVALID_ARGUMENT,
+                            "profile executable metadata is incomplete");
+  }
+  if (IREE_UNLIKELY(export_count > UINT32_MAX)) {
+    return iree_make_status(IREE_STATUS_OUT_OF_RANGE,
+                            "profile executable export count exceeds uint32_t");
+  }
+
+  iree_host_size_t export_data_length = 0;
+  IREE_RETURN_IF_ERROR(iree_hal_amdgpu_profile_metadata_export_data_length(
+      export_count, export_infos, export_parameter_offsets,
+      &export_data_length));
+
+  IREE_TRACE_ZONE_BEGIN(z0);
+  IREE_TRACE_ZONE_APPEND_VALUE_I64(z0, export_count);
+
+  iree_slim_mutex_lock(&registry->mutex);
+
+  const uint64_t executable_id = registry->next_executable_id;
+  iree_status_t status = iree_ok_status();
+  if (IREE_UNLIKELY(executable_id == UINT64_MAX)) {
+    status = iree_make_status(IREE_STATUS_OUT_OF_RANGE,
+                              "profile executable id space exhausted");
+  }
+
+  if (iree_status_is_ok(status) && registry->executable_record_count + 1 >
+                                       registry->executable_record_capacity) {
+    status = iree_allocator_grow_array(
+        registry->host_allocator,
+        iree_max((iree_host_size_t)16, registry->executable_record_count + 1),
+        sizeof(registry->executable_records[0]),
+        &registry->executable_record_capacity,
+        (void**)&registry->executable_records);
+  }
+
+  iree_host_size_t new_export_data_length =
+      registry->executable_export_record_data_length;
+  if (iree_status_is_ok(status) &&
+      IREE_UNLIKELY(!iree_host_size_checked_add(new_export_data_length,
+                                                export_data_length,
+                                                &new_export_data_length))) {
+    status =
+        iree_make_status(IREE_STATUS_OUT_OF_RANGE,
+                         "profile executable export metadata length overflow");
+  }
+  if (iree_status_is_ok(status) &&
+      new_export_data_length >
+          registry->executable_export_record_data_capacity) {
+    status = iree_allocator_grow_array(
+        registry->host_allocator,
+        iree_max((iree_host_size_t)1024, new_export_data_length),
+        sizeof(registry->executable_export_record_data[0]),
+        &registry->executable_export_record_data_capacity,
+        (void**)&registry->executable_export_record_data);
+  }
+
+  if (iree_status_is_ok(status)) {
+    iree_hal_profile_executable_record_t record =
+        iree_hal_profile_executable_record_default();
+    record.executable_id = executable_id;
+    record.export_count = (uint32_t)export_count;
+    if (code_object_hash) {
+      record.flags |= IREE_HAL_PROFILE_EXECUTABLE_FLAG_CODE_OBJECT_HASH;
+      record.code_object_hash[0] = code_object_hash[0];
+      record.code_object_hash[1] = code_object_hash[1];
+    }
+
+    uint8_t* export_data =
+        export_data_length ? registry->executable_export_record_data +
+                                 registry->executable_export_record_data_length
+                           : NULL;
+    status = iree_hal_amdgpu_profile_metadata_append_export_records(
+        executable_id, export_count, export_infos, export_parameter_offsets,
+        code_object_hash, host_kernel_args, export_data);
+    if (iree_status_is_ok(status)) {
+      registry->executable_records[registry->executable_record_count++] =
+          record;
+      registry->executable_export_record_data_length = new_export_data_length;
+      ++registry->next_executable_id;
+      *out_executable_id = executable_id;
+    }
+  }
+
+  iree_slim_mutex_unlock(&registry->mutex);
+  IREE_TRACE_ZONE_END(z0);
+  return status;
+}
+
+iree_status_t iree_hal_amdgpu_profile_metadata_register_executable_artifacts(
+    iree_hal_amdgpu_profile_metadata_registry_t* registry,
+    uint64_t executable_id, iree_const_byte_span_t code_object_data,
+    const uint64_t code_object_hash[2],
+    iree_host_size_t code_object_load_info_count,
+    const iree_hal_amdgpu_profile_code_object_load_info_t*
+        code_object_load_infos) {
+  if (IREE_UNLIKELY(executable_id == 0)) {
+    return iree_make_status(IREE_STATUS_INVALID_ARGUMENT,
+                            "profile executable id is required");
+  }
+  if (IREE_UNLIKELY(code_object_load_info_count > UINT32_MAX)) {
+    return iree_make_status(
+        IREE_STATUS_OUT_OF_RANGE,
+        "profile executable code-object load count exceeds uint32_t");
+  }
+  if (IREE_UNLIKELY(code_object_load_info_count > 0 &&
+                    !code_object_load_infos)) {
+    return iree_make_status(
+        IREE_STATUS_INVALID_ARGUMENT,
+        "profile executable code-object load records are required");
+  }
+
+  iree_host_size_t code_object_record_data_length = 0;
+  IREE_RETURN_IF_ERROR(
+      iree_hal_amdgpu_profile_metadata_code_object_record_data_length(
+          code_object_data, &code_object_record_data_length));
+
+  IREE_TRACE_ZONE_BEGIN(z0);
+  IREE_TRACE_ZONE_APPEND_VALUE_I64(z0, code_object_load_info_count);
+
+  iree_slim_mutex_lock(&registry->mutex);
+
+  iree_status_t status = iree_ok_status();
+  if (IREE_UNLIKELY(!iree_hal_amdgpu_profile_metadata_has_executable_locked(
+          registry, executable_id))) {
+    status = iree_make_status(
+        IREE_STATUS_NOT_FOUND,
+        "profile executable metadata not found for executable %" PRIu64,
+        executable_id);
+  }
+  if (iree_status_is_ok(status) &&
+      IREE_UNLIKELY(iree_hal_amdgpu_profile_metadata_has_code_object_locked(
+          registry, executable_id))) {
+    status = iree_make_status(
+        IREE_STATUS_ALREADY_EXISTS,
+        "profile executable code-object artifacts already registered for "
+        "executable %" PRIu64,
+        executable_id);
+  }
+
+  iree_host_size_t new_code_object_record_data_length =
+      registry->executable_code_object_record_data_length;
+  if (iree_status_is_ok(status) &&
+      IREE_UNLIKELY(!iree_host_size_checked_add(
+          new_code_object_record_data_length, code_object_record_data_length,
+          &new_code_object_record_data_length))) {
+    status = iree_make_status(
+        IREE_STATUS_OUT_OF_RANGE,
+        "profile executable code-object metadata length overflow");
+  }
+  if (iree_status_is_ok(status) &&
+      new_code_object_record_data_length >
+          registry->executable_code_object_record_data_capacity) {
+    status = iree_allocator_grow_array(
+        registry->host_allocator,
+        iree_max((iree_host_size_t)1024, new_code_object_record_data_length),
+        sizeof(registry->executable_code_object_record_data[0]),
+        &registry->executable_code_object_record_data_capacity,
+        (void**)&registry->executable_code_object_record_data);
+  }
+
+  iree_host_size_t new_code_object_load_record_count =
+      registry->executable_code_object_load_record_count;
+  if (iree_status_is_ok(status) &&
+      IREE_UNLIKELY(!iree_host_size_checked_add(
+          new_code_object_load_record_count, code_object_load_info_count,
+          &new_code_object_load_record_count))) {
+    status = iree_make_status(
+        IREE_STATUS_OUT_OF_RANGE,
+        "profile executable code-object load metadata count overflow");
+  }
+  if (iree_status_is_ok(status) &&
+      new_code_object_load_record_count >
+          registry->executable_code_object_load_record_capacity) {
+    status = iree_allocator_grow_array(
+        registry->host_allocator,
+        iree_max((iree_host_size_t)16, new_code_object_load_record_count),
+        sizeof(registry->executable_code_object_load_records[0]),
+        &registry->executable_code_object_load_record_capacity,
+        (void**)&registry->executable_code_object_load_records);
+  }
+
+  if (iree_status_is_ok(status)) {
+    uint8_t* code_object_data_target =
+        registry->executable_code_object_record_data +
+        registry->executable_code_object_record_data_length;
+    iree_hal_amdgpu_profile_metadata_append_code_object_record(
+        executable_id, code_object_data, code_object_hash,
+        code_object_data_target);
+
+    iree_hal_profile_executable_code_object_load_record_t*
+        code_object_load_records =
+            code_object_load_info_count
+                ? registry->executable_code_object_load_records +
+                      registry->executable_code_object_load_record_count
+                : NULL;
+    iree_hal_amdgpu_profile_metadata_append_code_object_load_records(
+        executable_id, code_object_load_info_count, code_object_load_infos,
+        code_object_load_records);
+
+    registry->executable_code_object_record_data_length =
+        new_code_object_record_data_length;
+    registry->executable_code_object_load_record_count =
+        new_code_object_load_record_count;
+  }
+
+  iree_slim_mutex_unlock(&registry->mutex);
+  IREE_TRACE_ZONE_END(z0);
+  return status;
+}
+
+iree_status_t iree_hal_amdgpu_profile_metadata_lookup_code_object_load(
+    iree_hal_amdgpu_profile_metadata_registry_t* registry,
+    uint64_t executable_id, uint32_t physical_device_ordinal,
+    iree_hal_profile_executable_code_object_load_record_t* out_record) {
+  *out_record = iree_hal_profile_executable_code_object_load_record_default();
+
+  iree_slim_mutex_lock(&registry->mutex);
+  bool found = false;
+  for (iree_host_size_t i = 0;
+       i < registry->executable_code_object_load_record_count; ++i) {
+    const iree_hal_profile_executable_code_object_load_record_t* record =
+        &registry->executable_code_object_load_records[i];
+    if (record->executable_id == executable_id &&
+        record->physical_device_ordinal == physical_device_ordinal) {
+      *out_record = *record;
+      found = true;
+      break;
+    }
+  }
+  iree_slim_mutex_unlock(&registry->mutex);
+
+  if (IREE_UNLIKELY(!found)) {
+    return iree_make_status(
+        IREE_STATUS_NOT_FOUND,
+        "profile code-object load metadata not found for executable %" PRIu64
+        " on physical device %" PRIu32,
+        executable_id, physical_device_ordinal);
+  }
+  return iree_ok_status();
+}
+
+iree_status_t iree_hal_amdgpu_profile_metadata_register_command_buffer(
+    iree_hal_amdgpu_profile_metadata_registry_t* registry,
+    iree_hal_command_buffer_mode_t mode,
+    iree_hal_command_category_t command_categories,
+    iree_hal_queue_affinity_t queue_affinity,
+    iree_host_size_t physical_device_ordinal, uint64_t* out_command_buffer_id) {
+  IREE_ASSERT_ARGUMENT(out_command_buffer_id);
+  *out_command_buffer_id = 0;
+  if (IREE_UNLIKELY(physical_device_ordinal > UINT32_MAX)) {
+    return iree_make_status(
+        IREE_STATUS_OUT_OF_RANGE,
+        "profile command-buffer physical device ordinal exceeds uint32_t");
+  }
+
+  IREE_TRACE_ZONE_BEGIN(z0);
+
+  iree_slim_mutex_lock(&registry->mutex);
+
+  const uint64_t command_buffer_id = registry->next_command_buffer_id;
+  iree_status_t status = iree_ok_status();
+  if (IREE_UNLIKELY(command_buffer_id == UINT64_MAX)) {
+    status = iree_make_status(IREE_STATUS_OUT_OF_RANGE,
+                              "profile command-buffer id space exhausted");
+  }
+  if (iree_status_is_ok(status) &&
+      registry->command_buffer_record_count + 1 >
+          registry->command_buffer_record_capacity) {
+    status = iree_allocator_grow_array(
+        registry->host_allocator,
+        iree_max((iree_host_size_t)16,
+                 registry->command_buffer_record_count + 1),
+        sizeof(registry->command_buffer_records[0]),
+        &registry->command_buffer_record_capacity,
+        (void**)&registry->command_buffer_records);
+  }
+  if (iree_status_is_ok(status)) {
+    iree_hal_profile_command_buffer_record_t record =
+        iree_hal_profile_command_buffer_record_default();
+    record.command_buffer_id = command_buffer_id;
+    record.mode = mode;
+    record.command_categories = command_categories;
+    record.queue_affinity = queue_affinity;
+    record.physical_device_ordinal = (uint32_t)physical_device_ordinal;
+    registry->command_buffer_records[registry->command_buffer_record_count++] =
+        record;
+    ++registry->next_command_buffer_id;
+    *out_command_buffer_id = command_buffer_id;
+  }
+
+  iree_slim_mutex_unlock(&registry->mutex);
+  IREE_TRACE_ZONE_END(z0);
+  return status;
+}
+
+iree_status_t iree_hal_amdgpu_profile_metadata_register_command_operations(
+    iree_hal_amdgpu_profile_metadata_registry_t* registry,
+    iree_host_size_t operation_count,
+    const iree_hal_profile_command_operation_record_t* operations) {
+  if (operation_count == 0) {
+    return iree_ok_status();
+  }
+  if (IREE_UNLIKELY(!operations)) {
+    return iree_make_status(IREE_STATUS_INVALID_ARGUMENT,
+                            "profile command-operation records are required");
+  }
+
+  IREE_TRACE_ZONE_BEGIN(z0);
+  IREE_TRACE_ZONE_APPEND_VALUE_I64(z0, operation_count);
+
+  iree_slim_mutex_lock(&registry->mutex);
+
+  iree_host_size_t new_record_count = 0;
+  iree_host_size_t byte_length = 0;
+  iree_status_t status = iree_ok_status();
+  if (IREE_UNLIKELY(
+          !iree_host_size_checked_add(registry->command_operation_record_count,
+                                      operation_count, &new_record_count))) {
+    status =
+        iree_make_status(IREE_STATUS_OUT_OF_RANGE,
+                         "profile command-operation metadata count overflow");
+  }
+  if (iree_status_is_ok(status)) {
+    status = IREE_STRUCT_LAYOUT(
+        0, &byte_length,
+        IREE_STRUCT_FIELD(operation_count,
+                          iree_hal_profile_command_operation_record_t, NULL));
+  }
+  if (iree_status_is_ok(status) &&
+      new_record_count > registry->command_operation_record_capacity) {
+    status = iree_allocator_grow_array(
+        registry->host_allocator,
+        iree_max((iree_host_size_t)64, new_record_count),
+        sizeof(registry->command_operation_records[0]),
+        &registry->command_operation_record_capacity,
+        (void**)&registry->command_operation_records);
+  }
+  if (iree_status_is_ok(status)) {
+    memcpy(registry->command_operation_records +
+               registry->command_operation_record_count,
+           operations, byte_length);
+    registry->command_operation_record_count = new_record_count;
+  }
+
+  iree_slim_mutex_unlock(&registry->mutex);
+  IREE_TRACE_ZONE_END(z0);
+  return status;
+}
+
+bool iree_hal_amdgpu_profile_metadata_export_matches(
+    iree_hal_amdgpu_profile_metadata_registry_t* registry,
+    uint64_t executable_id, uint32_t export_ordinal,
+    iree_string_view_t pattern) {
+  if (iree_string_view_is_empty(pattern)) return true;
+
+  bool matches = false;
+  iree_slim_mutex_lock(&registry->mutex);
+  iree_host_size_t offset = 0;
+  while (offset + sizeof(iree_hal_profile_executable_export_record_t) <=
+         registry->executable_export_record_data_length) {
+    const uint8_t* record_data =
+        registry->executable_export_record_data + offset;
+    iree_hal_profile_executable_export_record_t record;
+    memcpy(&record, record_data, sizeof(record));
+    if (record.record_length < sizeof(record) ||
+        offset + record.record_length >
+            registry->executable_export_record_data_length ||
+        record.name_length != record.record_length - sizeof(record)) {
+      break;
+    }
+    if (record.executable_id == executable_id &&
+        record.export_ordinal == export_ordinal) {
+      matches = iree_string_view_match_pattern(
+          iree_make_string_view((const char*)record_data + sizeof(record),
+                                record.name_length),
+          pattern);
+      break;
+    }
+    offset += record.record_length;
+  }
+  iree_slim_mutex_unlock(&registry->mutex);
+  return matches;
+}
+
+static void iree_hal_amdgpu_profile_metadata_snapshot_deinitialize(
+    iree_hal_amdgpu_profile_metadata_snapshot_t* snapshot) {
+  iree_allocator_free(snapshot->host_allocator,
+                      snapshot->command_operation_records);
+  iree_allocator_free(snapshot->host_allocator,
+                      snapshot->command_buffer_records);
+  iree_allocator_free(snapshot->host_allocator,
+                      snapshot->executable_export_record_data);
+  iree_allocator_free(snapshot->host_allocator,
+                      snapshot->executable_code_object_load_records);
+  iree_allocator_free(snapshot->host_allocator,
+                      snapshot->executable_code_object_record_data);
+  iree_allocator_free(snapshot->host_allocator, snapshot->executable_records);
+  memset(snapshot, 0, sizeof(*snapshot));
+}
+
+static iree_status_t iree_hal_amdgpu_profile_metadata_snapshot_copy(
+    iree_hal_amdgpu_profile_metadata_registry_t* registry,
+    const iree_hal_amdgpu_profile_metadata_cursor_t* cursor,
+    bool emit_executable_artifacts,
+    iree_hal_amdgpu_profile_metadata_snapshot_t* out_snapshot) {
+  memset(out_snapshot, 0, sizeof(*out_snapshot));
+  out_snapshot->host_allocator = registry->host_allocator;
+
+  iree_slim_mutex_lock(&registry->mutex);
+
+  iree_status_t status = iree_ok_status();
+  if (IREE_UNLIKELY(cursor->executable_record_count >
+                        registry->executable_record_count ||
+                    cursor->executable_code_object_record_data_length >
+                        registry->executable_code_object_record_data_length ||
+                    cursor->executable_code_object_load_record_count >
+                        registry->executable_code_object_load_record_count ||
+                    cursor->executable_export_record_data_length >
+                        registry->executable_export_record_data_length ||
+                    cursor->command_buffer_record_count >
+                        registry->command_buffer_record_count ||
+                    cursor->command_operation_record_count >
+                        registry->command_operation_record_count)) {
+    status =
+        iree_make_status(IREE_STATUS_INTERNAL,
+                         "profile metadata cursor is outside registry bounds");
+  }
+
+  iree_host_size_t executable_record_count = 0;
+  iree_host_size_t executable_code_object_record_data_length = 0;
+  iree_host_size_t executable_code_object_load_record_count = 0;
+  iree_host_size_t executable_export_record_data_length = 0;
+  iree_host_size_t command_buffer_record_count = 0;
+  iree_host_size_t command_operation_record_count = 0;
+  if (iree_status_is_ok(status)) {
+    executable_record_count =
+        registry->executable_record_count - cursor->executable_record_count;
+    executable_code_object_record_data_length =
+        registry->executable_code_object_record_data_length -
+        cursor->executable_code_object_record_data_length;
+    executable_code_object_load_record_count =
+        registry->executable_code_object_load_record_count -
+        cursor->executable_code_object_load_record_count;
+    executable_export_record_data_length =
+        registry->executable_export_record_data_length -
+        cursor->executable_export_record_data_length;
+    command_buffer_record_count = registry->command_buffer_record_count -
+                                  cursor->command_buffer_record_count;
+    command_operation_record_count = registry->command_operation_record_count -
+                                     cursor->command_operation_record_count;
+  }
+
+  if (iree_status_is_ok(status) && executable_record_count > 0) {
+    iree_host_size_t byte_length = 0;
+    status = IREE_STRUCT_LAYOUT(
+        0, &byte_length,
+        IREE_STRUCT_FIELD(executable_record_count,
+                          iree_hal_profile_executable_record_t, NULL));
+    if (iree_status_is_ok(status)) {
+      status = iree_allocator_malloc(registry->host_allocator, byte_length,
+                                     (void**)&out_snapshot->executable_records);
+    }
+    if (iree_status_is_ok(status)) {
+      memcpy(out_snapshot->executable_records,
+             registry->executable_records + cursor->executable_record_count,
+             byte_length);
+      out_snapshot->executable_record_count = executable_record_count;
+    }
+  }
+
+  if (iree_status_is_ok(status) && emit_executable_artifacts &&
+      executable_code_object_record_data_length > 0) {
+    status = iree_allocator_malloc(
+        registry->host_allocator, executable_code_object_record_data_length,
+        (void**)&out_snapshot->executable_code_object_record_data);
+    if (iree_status_is_ok(status)) {
+      memcpy(out_snapshot->executable_code_object_record_data,
+             registry->executable_code_object_record_data +
+                 cursor->executable_code_object_record_data_length,
+             executable_code_object_record_data_length);
+      out_snapshot->executable_code_object_record_data_length =
+          executable_code_object_record_data_length;
+    }
+  }
+
+  if (iree_status_is_ok(status) && emit_executable_artifacts &&
+      executable_code_object_load_record_count > 0) {
+    iree_host_size_t byte_length = 0;
+    status = IREE_STRUCT_LAYOUT(
+        0, &byte_length,
+        IREE_STRUCT_FIELD(executable_code_object_load_record_count,
+                          iree_hal_profile_executable_code_object_load_record_t,
+                          NULL));
+    if (iree_status_is_ok(status)) {
+      status = iree_allocator_malloc(
+          registry->host_allocator, byte_length,
+          (void**)&out_snapshot->executable_code_object_load_records);
+    }
+    if (iree_status_is_ok(status)) {
+      memcpy(out_snapshot->executable_code_object_load_records,
+             registry->executable_code_object_load_records +
+                 cursor->executable_code_object_load_record_count,
+             byte_length);
+      out_snapshot->executable_code_object_load_record_count =
+          executable_code_object_load_record_count;
+    }
+  }
+
+  if (iree_status_is_ok(status) && executable_export_record_data_length > 0) {
+    status = iree_allocator_malloc(
+        registry->host_allocator, executable_export_record_data_length,
+        (void**)&out_snapshot->executable_export_record_data);
+    if (iree_status_is_ok(status)) {
+      memcpy(out_snapshot->executable_export_record_data,
+             registry->executable_export_record_data +
+                 cursor->executable_export_record_data_length,
+             executable_export_record_data_length);
+      out_snapshot->executable_export_record_data_length =
+          executable_export_record_data_length;
+    }
+  }
+
+  if (iree_status_is_ok(status) && command_buffer_record_count > 0) {
+    iree_host_size_t byte_length = 0;
+    status = IREE_STRUCT_LAYOUT(
+        0, &byte_length,
+        IREE_STRUCT_FIELD(command_buffer_record_count,
+                          iree_hal_profile_command_buffer_record_t, NULL));
+    if (iree_status_is_ok(status)) {
+      status =
+          iree_allocator_malloc(registry->host_allocator, byte_length,
+                                (void**)&out_snapshot->command_buffer_records);
+    }
+    if (iree_status_is_ok(status)) {
+      memcpy(out_snapshot->command_buffer_records,
+             registry->command_buffer_records +
+                 cursor->command_buffer_record_count,
+             byte_length);
+      out_snapshot->command_buffer_record_count = command_buffer_record_count;
+    }
+  }
+
+  if (iree_status_is_ok(status) && command_operation_record_count > 0) {
+    iree_host_size_t byte_length = 0;
+    status = IREE_STRUCT_LAYOUT(
+        0, &byte_length,
+        IREE_STRUCT_FIELD(command_operation_record_count,
+                          iree_hal_profile_command_operation_record_t, NULL));
+    if (iree_status_is_ok(status)) {
+      status = iree_allocator_malloc(
+          registry->host_allocator, byte_length,
+          (void**)&out_snapshot->command_operation_records);
+    }
+    if (iree_status_is_ok(status)) {
+      memcpy(out_snapshot->command_operation_records,
+             registry->command_operation_records +
+                 cursor->command_operation_record_count,
+             byte_length);
+      out_snapshot->command_operation_record_count =
+          command_operation_record_count;
+    }
+  }
+
+  if (iree_status_is_ok(status)) {
+    out_snapshot->end_cursor.executable_record_count =
+        registry->executable_record_count;
+    out_snapshot->end_cursor.executable_code_object_record_data_length =
+        registry->executable_code_object_record_data_length;
+    out_snapshot->end_cursor.executable_code_object_load_record_count =
+        registry->executable_code_object_load_record_count;
+    out_snapshot->end_cursor.executable_export_record_data_length =
+        registry->executable_export_record_data_length;
+    out_snapshot->end_cursor.command_buffer_record_count =
+        registry->command_buffer_record_count;
+    out_snapshot->end_cursor.command_operation_record_count =
+        registry->command_operation_record_count;
+  }
+
+  iree_slim_mutex_unlock(&registry->mutex);
+  if (!iree_status_is_ok(status)) {
+    iree_hal_amdgpu_profile_metadata_snapshot_deinitialize(out_snapshot);
+  }
+  return status;
+}
+
+iree_status_t iree_hal_amdgpu_profile_metadata_write(
+    iree_hal_amdgpu_profile_metadata_registry_t* registry,
+    iree_hal_profile_sink_t* sink, uint64_t session_id, iree_string_view_t name,
+    bool emit_executable_artifacts,
+    iree_hal_amdgpu_profile_metadata_cursor_t* cursor) {
+  if (!sink) {
+    return iree_ok_status();
+  }
+
+  IREE_TRACE_ZONE_BEGIN(z0);
+
+  iree_hal_amdgpu_profile_metadata_snapshot_t snapshot;
+  iree_status_t status = iree_hal_amdgpu_profile_metadata_snapshot_copy(
+      registry, cursor, emit_executable_artifacts, &snapshot);
+
+  if (iree_status_is_ok(status) && snapshot.executable_record_count > 0) {
+    iree_hal_profile_chunk_metadata_t metadata =
+        iree_hal_profile_chunk_metadata_default();
+    metadata.content_type = IREE_HAL_PROFILE_CONTENT_TYPE_EXECUTABLES;
+    metadata.name = name;
+    metadata.session_id = session_id;
+    iree_const_byte_span_t iovec =
+        iree_make_const_byte_span(snapshot.executable_records,
+                                  snapshot.executable_record_count *
+                                      sizeof(snapshot.executable_records[0]));
+    status = iree_hal_profile_sink_write(sink, &metadata, 1, &iovec);
+  }
+
+  if (iree_status_is_ok(status) &&
+      snapshot.executable_code_object_record_data_length > 0) {
+    iree_hal_profile_chunk_metadata_t metadata =
+        iree_hal_profile_chunk_metadata_default();
+    metadata.content_type =
+        IREE_HAL_PROFILE_CONTENT_TYPE_EXECUTABLE_CODE_OBJECTS;
+    metadata.name = name;
+    metadata.session_id = session_id;
+    iree_const_byte_span_t iovec = iree_make_const_byte_span(
+        snapshot.executable_code_object_record_data,
+        snapshot.executable_code_object_record_data_length);
+    status = iree_hal_profile_sink_write(sink, &metadata, 1, &iovec);
+  }
+
+  if (iree_status_is_ok(status) &&
+      snapshot.executable_code_object_load_record_count > 0) {
+    iree_hal_profile_chunk_metadata_t metadata =
+        iree_hal_profile_chunk_metadata_default();
+    metadata.content_type =
+        IREE_HAL_PROFILE_CONTENT_TYPE_EXECUTABLE_CODE_OBJECT_LOADS;
+    metadata.name = name;
+    metadata.session_id = session_id;
+    iree_const_byte_span_t iovec = iree_make_const_byte_span(
+        snapshot.executable_code_object_load_records,
+        snapshot.executable_code_object_load_record_count *
+            sizeof(snapshot.executable_code_object_load_records[0]));
+    status = iree_hal_profile_sink_write(sink, &metadata, 1, &iovec);
+  }
+
+  if (iree_status_is_ok(status) &&
+      snapshot.executable_export_record_data_length > 0) {
+    iree_hal_profile_chunk_metadata_t metadata =
+        iree_hal_profile_chunk_metadata_default();
+    metadata.content_type = IREE_HAL_PROFILE_CONTENT_TYPE_EXECUTABLE_EXPORTS;
+    metadata.name = name;
+    metadata.session_id = session_id;
+    iree_const_byte_span_t iovec = iree_make_const_byte_span(
+        snapshot.executable_export_record_data,
+        snapshot.executable_export_record_data_length);
+    status = iree_hal_profile_sink_write(sink, &metadata, 1, &iovec);
+  }
+
+  if (iree_status_is_ok(status) && snapshot.command_buffer_record_count > 0) {
+    iree_hal_profile_chunk_metadata_t metadata =
+        iree_hal_profile_chunk_metadata_default();
+    metadata.content_type = IREE_HAL_PROFILE_CONTENT_TYPE_COMMAND_BUFFERS;
+    metadata.name = name;
+    metadata.session_id = session_id;
+    iree_const_byte_span_t iovec = iree_make_const_byte_span(
+        snapshot.command_buffer_records,
+        snapshot.command_buffer_record_count *
+            sizeof(snapshot.command_buffer_records[0]));
+    status = iree_hal_profile_sink_write(sink, &metadata, 1, &iovec);
+  }
+
+  if (iree_status_is_ok(status) &&
+      snapshot.command_operation_record_count > 0) {
+    iree_hal_profile_chunk_metadata_t metadata =
+        iree_hal_profile_chunk_metadata_default();
+    metadata.content_type = IREE_HAL_PROFILE_CONTENT_TYPE_COMMAND_OPERATIONS;
+    metadata.name = name;
+    metadata.session_id = session_id;
+    iree_const_byte_span_t iovec = iree_make_const_byte_span(
+        snapshot.command_operation_records,
+        snapshot.command_operation_record_count *
+            sizeof(snapshot.command_operation_records[0]));
+    status = iree_hal_profile_sink_write(sink, &metadata, 1, &iovec);
+  }
+
+  if (iree_status_is_ok(status)) {
+    *cursor = snapshot.end_cursor;
+  }
+  iree_hal_amdgpu_profile_metadata_snapshot_deinitialize(&snapshot);
+
+  IREE_TRACE_ZONE_END(z0);
+  return status;
+}
diff --git a/runtime/src/iree/hal/drivers/amdgpu/profile_metadata.h b/runtime/src/iree/hal/drivers/amdgpu/profile_metadata.h
new file mode 100644
index 0000000..b1c2362
--- /dev/null
+++ b/runtime/src/iree/hal/drivers/amdgpu/profile_metadata.h
@@ -0,0 +1,199 @@
+// Copyright 2026 The IREE Authors
+//
+// Licensed under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+
+#ifndef IREE_HAL_DRIVERS_AMDGPU_PROFILE_METADATA_H_
+#define IREE_HAL_DRIVERS_AMDGPU_PROFILE_METADATA_H_
+
+#include "iree/base/api.h"
+#include "iree/base/threading/mutex.h"
+#include "iree/hal/api.h"
+#include "iree/hal/drivers/amdgpu/abi/kernel_args.h"
+
+#ifdef __cplusplus
+extern "C" {
+#endif  // __cplusplus
+
+// Cursor tracking the profiling metadata records already emitted to one sink.
+typedef struct iree_hal_amdgpu_profile_metadata_cursor_t {
+  // Number of executable records already emitted.
+  iree_host_size_t executable_record_count;
+  // Byte length of packed executable code-object records already emitted.
+  iree_host_size_t executable_code_object_record_data_length;
+  // Number of executable code-object load records already emitted.
+  iree_host_size_t executable_code_object_load_record_count;
+  // Byte length of packed executable export records already emitted.
+  iree_host_size_t executable_export_record_data_length;
+  // Number of command-buffer records already emitted.
+  iree_host_size_t command_buffer_record_count;
+  // Number of command-operation records already emitted.
+  iree_host_size_t command_operation_record_count;
+} iree_hal_amdgpu_profile_metadata_cursor_t;
+
+// Logical-device-owned registry of durable profiling metadata side tables.
+typedef struct iree_hal_amdgpu_profile_metadata_registry_t {
+  // Host allocator used for registry arrays.
+  iree_allocator_t host_allocator;
+  // Mutex protecting registry growth and snapshot copies.
+  iree_slim_mutex_t mutex;
+  // Next non-zero executable id to assign.
+  uint64_t next_executable_id;
+  // Next non-zero command-buffer id to assign.
+  uint64_t next_command_buffer_id;
+  // Executable records in id assignment order.
+  iree_hal_profile_executable_record_t* executable_records;
+  // Number of valid executable records.
+  iree_host_size_t executable_record_count;
+  // Allocated executable record capacity.
+  iree_host_size_t executable_record_capacity;
+  // Packed executable code-object records in executable id assignment order.
+  uint8_t* executable_code_object_record_data;
+  // Byte length of valid packed executable code-object records.
+  iree_host_size_t executable_code_object_record_data_length;
+  // Allocated byte capacity for packed executable code-object records.
+  iree_host_size_t executable_code_object_record_data_capacity;
+  // Executable code-object load records in executable id assignment order.
+  iree_hal_profile_executable_code_object_load_record_t*
+      executable_code_object_load_records;
+  // Number of valid executable code-object load records.
+  iree_host_size_t executable_code_object_load_record_count;
+  // Allocated executable code-object load record capacity.
+  iree_host_size_t executable_code_object_load_record_capacity;
+  // Packed executable export records in executable id assignment order.
+  uint8_t* executable_export_record_data;
+  // Byte length of valid packed executable export records.
+  iree_host_size_t executable_export_record_data_length;
+  // Allocated byte capacity for packed executable export records.
+  iree_host_size_t executable_export_record_data_capacity;
+  // Command-buffer records in id assignment order.
+  iree_hal_profile_command_buffer_record_t* command_buffer_records;
+  // Number of valid command-buffer records.
+  iree_host_size_t command_buffer_record_count;
+  // Allocated command-buffer record capacity.
+  iree_host_size_t command_buffer_record_capacity;
+  // Command-operation records in command-buffer recording order.
+  iree_hal_profile_command_operation_record_t* command_operation_records;
+  // Number of valid command-operation records.
+  iree_host_size_t command_operation_record_count;
+  // Allocated command-operation record capacity.
+  iree_host_size_t command_operation_record_capacity;
+} iree_hal_amdgpu_profile_metadata_registry_t;
+
+// Loader-reported code-object range for one physical device.
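+// A trace decoder can, for example, translate a sampled device PC back into
+// code-object space as file_address = device_pc - load_delta (assuming the
+// loader reports the delta as load base minus file base).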
+typedef struct iree_hal_amdgpu_profile_code_object_load_info_t {
+  // Session-local physical device ordinal owning this loaded code object.
+  uint32_t physical_device_ordinal;
+  // Loader-provided code-object load delta used for PC translation.
+  int64_t load_delta;
+  // Byte length of the loaded code-object range on the device.
+  uint64_t load_size;
+} iree_hal_amdgpu_profile_code_object_load_info_t;
+
+// Initializes |out_registry| for logical-device-lifetime metadata.
+void iree_hal_amdgpu_profile_metadata_initialize(
+    iree_allocator_t host_allocator,
+    iree_hal_amdgpu_profile_metadata_registry_t* out_registry);
+
+// Releases all host allocations owned by |registry|.
+void iree_hal_amdgpu_profile_metadata_deinitialize(
+    iree_hal_amdgpu_profile_metadata_registry_t* registry);
+
+// Computes the stable AMDGPU 128-bit code-object content hash.
+//
+// The hash is an IREE-defined identity value, not a security boundary:
+// consumers should use it only for equality/correlation. It is the pair of two
+// SipHash-2-4 outputs with fixed IREE keys over the exact loaded code-object
+// byte sequence.
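+//
+// Example (buffer names are illustrative):
+//   uint64_t hash[2] = {0, 0};
+//   iree_hal_amdgpu_profile_metadata_hash_code_object(
+//       iree_make_const_byte_span(elf_data, elf_length), hash);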
+void iree_hal_amdgpu_profile_metadata_hash_code_object(
+    iree_const_byte_span_t code_object_data, uint64_t out_hash[2]);
+
+// Registers immutable executable identity metadata and assigns
+// |out_executable_id|.
+//
+// This records the cheap executable/export metadata required to attribute
+// dispatch events and aggregate statistics. Code-object image bytes and loader
+// load ranges are registered separately so normal timing profiles can omit the
+// heavy records while executable trace profiles can still emit them.
+//
+// When |code_object_hash| is provided, each export receives an AMDGPU pipeline
+// hash. Inputs are appended in this order:
+//   - code_object_hash[0] and code_object_hash[1] as little-endian u64 values;
+//   - the export ordinal as a little-endian u32 value;
+//   - the HAL ABI constant count and binding count as little-endian u16
+//     values;
+//   - the static workgroup size x/y/z as little-endian u32 values;
+//   - the export name byte length as a little-endian u64 value, followed by
+//     the exact export name bytes.
+//
+// Loader-derived kernel-object, kernarg, private-segment, group-segment, ISA,
+// and code-generation facts are intentionally covered by the exact
+// code-object hash rather than duplicated in the pipeline hash.
+iree_status_t iree_hal_amdgpu_profile_metadata_register_executable(
+    iree_hal_amdgpu_profile_metadata_registry_t* registry,
+    iree_host_size_t export_count,
+    const iree_hal_executable_export_info_t* export_infos,
+    const iree_host_size_t* export_parameter_offsets,
+    const uint64_t code_object_hash[2],
+    const iree_hal_amdgpu_device_kernel_args_t* host_kernel_args,
+    uint64_t* out_executable_id);
+
+// Registers code-object image and load-range artifacts for an executable
+// previously registered with
+// iree_hal_amdgpu_profile_metadata_register_executable.
+//
+// These artifacts are needed by trace/disassembly workflows but are not
+// required for normal dispatch execution or aggregate timing attribution.
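+//
+// Typical usage registers the executable identity first and attaches the
+// artifacts once loader load ranges are known; profile_metadata_test.cc walks
+// through a worked example.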
+iree_status_t iree_hal_amdgpu_profile_metadata_register_executable_artifacts(
+    iree_hal_amdgpu_profile_metadata_registry_t* registry,
+    uint64_t executable_id, iree_const_byte_span_t code_object_data,
+    const uint64_t code_object_hash[2],
+    iree_host_size_t code_object_load_info_count,
+    const iree_hal_amdgpu_profile_code_object_load_info_t*
+        code_object_load_infos);
+
+// Looks up the code-object load record for |executable_id| on a physical
+// device.
+iree_status_t iree_hal_amdgpu_profile_metadata_lookup_code_object_load(
+    iree_hal_amdgpu_profile_metadata_registry_t* registry,
+    uint64_t executable_id, uint32_t physical_device_ordinal,
+    iree_hal_profile_executable_code_object_load_record_t* out_record);
+
+// Registers immutable command-buffer creation metadata and assigns
+// |out_command_buffer_id|.
+iree_status_t iree_hal_amdgpu_profile_metadata_register_command_buffer(
+    iree_hal_amdgpu_profile_metadata_registry_t* registry,
+    iree_hal_command_buffer_mode_t mode,
+    iree_hal_command_category_t command_categories,
+    iree_hal_queue_affinity_t queue_affinity,
+    iree_host_size_t physical_device_ordinal, uint64_t* out_command_buffer_id);
+
+// Registers immutable command-buffer operation records.
+iree_status_t iree_hal_amdgpu_profile_metadata_register_command_operations(
+    iree_hal_amdgpu_profile_metadata_registry_t* registry,
+    iree_host_size_t operation_count,
+    const iree_hal_profile_command_operation_record_t* operations);
+
+// Returns true if the registered executable export name matches |pattern|.
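+// |pattern| uses iree_string_view_match_pattern semantics ('*' and '?'
+// wildcards); an empty pattern matches every export.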
+bool iree_hal_amdgpu_profile_metadata_export_matches(
+    iree_hal_amdgpu_profile_metadata_registry_t* registry,
+    uint64_t executable_id, uint32_t export_ordinal,
+    iree_string_view_t pattern);
+
+// Writes metadata records newer than |cursor| and advances |cursor| on success.
+//
+// When |emit_executable_artifacts| is false, executable code-object image and
+// load-range records are treated as consumed for this session but are not
+// written. This keeps normal timing profiles small while later
+// executable-trace sessions can still emit self-contained code-object bundles
+// from the durable registry.
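+//
+// Example incremental flush (sink/session values are illustrative):
+//   iree_hal_amdgpu_profile_metadata_cursor_t cursor = {0};
+//   IREE_RETURN_IF_ERROR(iree_hal_amdgpu_profile_metadata_write(
+//       registry, sink, session_id, IREE_SV("device0"),
+//       /*emit_executable_artifacts=*/false, &cursor));
+//   // A later write with the same cursor emits only records added since.
+//   IREE_RETURN_IF_ERROR(iree_hal_amdgpu_profile_metadata_write(
+//       registry, sink, session_id, IREE_SV("device0"),
+//       /*emit_executable_artifacts=*/false, &cursor));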
+iree_status_t iree_hal_amdgpu_profile_metadata_write(
+    iree_hal_amdgpu_profile_metadata_registry_t* registry,
+    iree_hal_profile_sink_t* sink, uint64_t session_id, iree_string_view_t name,
+    bool emit_executable_artifacts,
+    iree_hal_amdgpu_profile_metadata_cursor_t* cursor);
+
+#ifdef __cplusplus
+}  // extern "C"
+#endif  // __cplusplus
+
+#endif  // IREE_HAL_DRIVERS_AMDGPU_PROFILE_METADATA_H_
diff --git a/runtime/src/iree/hal/drivers/amdgpu/profile_metadata_test.cc b/runtime/src/iree/hal/drivers/amdgpu/profile_metadata_test.cc
new file mode 100644
index 0000000..ce26ae6
--- /dev/null
+++ b/runtime/src/iree/hal/drivers/amdgpu/profile_metadata_test.cc
@@ -0,0 +1,181 @@
+// Copyright 2026 The IREE Authors
+//
+// Licensed under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+
+#include "iree/hal/drivers/amdgpu/profile_metadata.h"
+
+#include <cstdint>
+
+#include "iree/testing/gtest.h"
+#include "iree/testing/status_matchers.h"
+
+namespace iree::hal::amdgpu {
+namespace {
+
+class ProfileMetadataTest : public ::testing::Test {
+ protected:
+  void SetUp() override {
+    iree_hal_amdgpu_profile_metadata_initialize(iree_allocator_system(),
+                                                &registry_);
+  }
+
+  void TearDown() override {
+    iree_hal_amdgpu_profile_metadata_deinitialize(&registry_);
+  }
+
+  iree_hal_executable_export_info_t MakeExportInfo() {
+    iree_hal_executable_export_info_t export_info = {};
+    export_info.name = IREE_SV("test_dispatch");
+    export_info.constant_count = 3;
+    export_info.binding_count = 2;
+    export_info.workgroup_size[0] = 8;
+    export_info.workgroup_size[1] = 4;
+    export_info.workgroup_size[2] = 1;
+    return export_info;
+  }
+
+  iree_hal_amdgpu_device_kernel_args_t MakeKernelArgs() {
+    iree_hal_amdgpu_device_kernel_args_t kernel_args = {};
+    kernel_args.workgroup_size[0] = 8;
+    kernel_args.workgroup_size[1] = 4;
+    kernel_args.workgroup_size[2] = 1;
+    kernel_args.constant_count = 3;
+    kernel_args.binding_count = 2;
+    return kernel_args;
+  }
+
+  iree_hal_amdgpu_profile_metadata_registry_t registry_;
+};
+
+TEST_F(ProfileMetadataTest, HashCodeObjectGolden) {
+  const uint8_t code_object_data[] = {
+      0x7f, 'E', 'L', 'F', 0x02, 0x01, 0x01, 0x00,
+      'I',  'R', 'E', 'E', 0x00, 0x10, 0x80, 0xff,
+  };
+
+  uint64_t code_object_hash[2] = {};
+  iree_hal_amdgpu_profile_metadata_hash_code_object(
+      iree_make_const_byte_span(code_object_data, sizeof(code_object_data)),
+      code_object_hash);
+
+  EXPECT_EQ(code_object_hash[0], 0xc928aa6b62b629a9ull);
+  EXPECT_EQ(code_object_hash[1], 0xd1a52f9c3082fba5ull);
+}
+
+TEST_F(ProfileMetadataTest, RegisterExecutableRecordsOnlyIdentity) {
+  iree_hal_executable_export_info_t export_info = MakeExportInfo();
+  iree_host_size_t export_parameter_offsets[] = {0, 0};
+  iree_hal_amdgpu_device_kernel_args_t kernel_args = MakeKernelArgs();
+  uint64_t code_object_hash[2] = {0x1111111111111111ull, 0x2222222222222222ull};
+
+  uint64_t executable_id = 0;
+  IREE_ASSERT_OK(iree_hal_amdgpu_profile_metadata_register_executable(
+      &registry_, /*export_count=*/1, &export_info, export_parameter_offsets,
+      code_object_hash, &kernel_args, &executable_id));
+
+  EXPECT_EQ(executable_id, 1u);
+  ASSERT_EQ(registry_.executable_record_count, 1u);
+  EXPECT_EQ(registry_.executable_records[0].executable_id, executable_id);
+  EXPECT_EQ(registry_.executable_records[0].export_count, 1u);
+  EXPECT_NE(registry_.executable_export_record_data_length, 0u);
+  EXPECT_EQ(registry_.executable_code_object_record_data_length, 0u);
+  EXPECT_EQ(registry_.executable_code_object_load_record_count, 0u);
+}
+
+TEST_F(ProfileMetadataTest, RegisterExecutableComputesStablePipelineHash) {
+  iree_hal_executable_export_info_t export_info = MakeExportInfo();
+  iree_host_size_t export_parameter_offsets[] = {0, 3};
+  iree_hal_amdgpu_device_kernel_args_t kernel_args = MakeKernelArgs();
+  uint64_t code_object_hash[2] = {0x0706050403020100ull, 0x1716151413121110ull};
+
+  uint64_t executable_id = 0;
+  IREE_ASSERT_OK(iree_hal_amdgpu_profile_metadata_register_executable(
+      &registry_, /*export_count=*/1, &export_info, export_parameter_offsets,
+      code_object_hash, &kernel_args, &executable_id));
+
+  ASSERT_EQ(registry_.executable_export_record_data_length,
+            sizeof(iree_hal_profile_executable_export_record_t) +
+                export_info.name.size);
+  const auto* export_record =
+      reinterpret_cast<const iree_hal_profile_executable_export_record_t*>(
+          registry_.executable_export_record_data);
+  ASSERT_NE(export_record, nullptr);
+  EXPECT_EQ(export_record->flags,
+            IREE_HAL_PROFILE_EXECUTABLE_EXPORT_FLAG_PIPELINE_HASH);
+  EXPECT_EQ(export_record->executable_id, executable_id);
+  EXPECT_EQ(export_record->export_ordinal, 0u);
+  EXPECT_EQ(export_record->constant_count, 3u);
+  EXPECT_EQ(export_record->binding_count, 2u);
+  EXPECT_EQ(export_record->parameter_count, 3u);
+  EXPECT_EQ(export_record->workgroup_size[0], 8u);
+  EXPECT_EQ(export_record->workgroup_size[1], 4u);
+  EXPECT_EQ(export_record->workgroup_size[2], 1u);
+  EXPECT_EQ(export_record->pipeline_hash[0], 0x12dbd8b44277f553ull);
+  EXPECT_EQ(export_record->pipeline_hash[1], 0x873c5d1c5596dce4ull);
+}
+
+TEST_F(ProfileMetadataTest, RegisterExecutableArtifactsAttachToIdentity) {
+  iree_hal_executable_export_info_t export_info = MakeExportInfo();
+  iree_host_size_t export_parameter_offsets[] = {0, 0};
+  iree_hal_amdgpu_device_kernel_args_t kernel_args = MakeKernelArgs();
+  uint64_t code_object_hash[2] = {0x1111111111111111ull, 0x2222222222222222ull};
+
+  uint64_t executable_id = 0;
+  IREE_ASSERT_OK(iree_hal_amdgpu_profile_metadata_register_executable(
+      &registry_, /*export_count=*/1, &export_info, export_parameter_offsets,
+      code_object_hash, &kernel_args, &executable_id));
+
+  const uint8_t code_object_data[] = {0x7f, 'E', 'L', 'F', 0x01};
+  const iree_hal_amdgpu_profile_code_object_load_info_t load_infos[] = {
+      {
+          .physical_device_ordinal = 0,
+          .load_delta = 0x1000,
+          .load_size = 0x2000,
+      },
+      {
+          .physical_device_ordinal = 1,
+          .load_delta = 0x3000,
+          .load_size = 0x4000,
+      },
+  };
+  IREE_ASSERT_OK(iree_hal_amdgpu_profile_metadata_register_executable_artifacts(
+      &registry_, executable_id,
+      iree_make_const_byte_span(code_object_data, sizeof(code_object_data)),
+      code_object_hash, IREE_ARRAYSIZE(load_infos), load_infos));
+
+  EXPECT_NE(registry_.executable_code_object_record_data_length, 0u);
+  EXPECT_EQ(registry_.executable_code_object_load_record_count, 2u);
+
+  iree_hal_profile_executable_code_object_load_record_t load_record;
+  IREE_ASSERT_OK(iree_hal_amdgpu_profile_metadata_lookup_code_object_load(
+      &registry_, executable_id, /*physical_device_ordinal=*/1, &load_record));
+  EXPECT_EQ(load_record.executable_id, executable_id);
+  EXPECT_EQ(load_record.code_object_id, executable_id);
+  EXPECT_EQ(load_record.load_delta, 0x3000);
+  EXPECT_EQ(load_record.load_size, 0x4000u);
+
+  IREE_EXPECT_STATUS_IS(
+      IREE_STATUS_ALREADY_EXISTS,
+      iree_hal_amdgpu_profile_metadata_register_executable_artifacts(
+          &registry_, executable_id,
+          iree_make_const_byte_span(code_object_data, sizeof(code_object_data)),
+          code_object_hash, IREE_ARRAYSIZE(load_infos), load_infos));
+}
+
+TEST_F(ProfileMetadataTest, RegisterExecutableArtifactsRequiresIdentity) {
+  const uint8_t code_object_data[] = {0x7f, 'E', 'L', 'F', 0x01};
+  uint64_t code_object_hash[2] = {0x1111111111111111ull, 0x2222222222222222ull};
+
+  IREE_EXPECT_STATUS_IS(
+      IREE_STATUS_NOT_FOUND,
+      iree_hal_amdgpu_profile_metadata_register_executable_artifacts(
+          &registry_, /*executable_id=*/42,
+          iree_make_const_byte_span(code_object_data, sizeof(code_object_data)),
+          code_object_hash, /*code_object_load_info_count=*/0,
+          /*code_object_load_infos=*/nullptr));
+}
+
+}  // namespace
+}  // namespace iree::hal::amdgpu
diff --git a/runtime/src/iree/hal/drivers/amdgpu/profile_traces.c b/runtime/src/iree/hal/drivers/amdgpu/profile_traces.c
new file mode 100644
index 0000000..6e8e13a
--- /dev/null
+++ b/runtime/src/iree/hal/drivers/amdgpu/profile_traces.c
@@ -0,0 +1,685 @@
+// Copyright 2026 The IREE Authors
+//
+// Licensed under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+
+#include "iree/hal/drivers/amdgpu/profile_traces.h"
+
+#include <string.h>
+
+#include "iree/hal/drivers/amdgpu/host_queue.h"
+#include "iree/hal/drivers/amdgpu/host_queue_profile.h"
+#include "iree/hal/drivers/amdgpu/host_queue_profile_events.h"
+#include "iree/hal/drivers/amdgpu/logical_device.h"
+#include "iree/hal/drivers/amdgpu/physical_device.h"
+#include "iree/hal/drivers/amdgpu/profile_aqlprofile.h"
+#include "iree/hal/drivers/amdgpu/system.h"
+#include "iree/hal/drivers/amdgpu/util/libaqlprofile.h"
+
+//===----------------------------------------------------------------------===//
+// Executable trace support tables
+//===----------------------------------------------------------------------===//
+
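+// AQL packet budget for each traced dispatch. The layout assumed here is
+// inferred from the emplace/commit helpers below: the two start packets are
+// the ATT start and code-object marker emitted before the dispatch, and the
+// third packet is the ATT stop emitted after it.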
+enum { iree_hal_amdgpu_profile_trace_packets_per_event = 3u };
+enum { iree_hal_amdgpu_profile_trace_start_packets_per_event = 2u };
+
+// Requested ATT output bytes per selected in-flight dispatch. aqlprofile owns
+// the actual allocation and releases it with the per-slot packet handle.
+enum { iree_hal_amdgpu_profile_trace_default_buffer_size = 16 * 1024 * 1024 };
+enum { iree_hal_amdgpu_profile_trace_default_se_mask = 1u };
+enum { iree_hal_amdgpu_profile_trace_default_target_cu = 1u };
+enum { iree_hal_amdgpu_profile_trace_default_simd_select = 0xFu };
+
+// Per-queue/per-event-ring-slot mutable aqlprofile ATT capture state.
+struct iree_hal_amdgpu_profile_trace_slot_t {
+  // Callback context retained for the lifetime of |handle|.
+  iree_hal_amdgpu_profile_aqlprofile_memory_context_t memory_context;
+  // aqlprofile handle owning PM4 programs and trace output storage.
+  iree_hal_amdgpu_aqlprofile_handle_t handle;
+  // AQL PM4-IB packet templates referencing |handle|'s immutable PM4 programs.
+  iree_hal_amdgpu_aqlprofile_att_control_aql_packets_t packets;
+  // aqlprofile handle owning the PM4 program for |code_object_marker_packet|.
+  iree_hal_amdgpu_aqlprofile_handle_t code_object_marker_handle;
+  // AQL PM4-IB packet template that publishes the loaded code-object marker.
+  iree_hsa_amd_aql_pm4_ib_packet_t code_object_marker_packet;
+  // Code-object id currently represented by |code_object_marker_packet|.
+  uint64_t code_object_marker_id;
+  // Producer-local trace id assigned when this slot is reserved for a dispatch.
+  uint64_t trace_id;
+};
+
+// Logical-device profiling session for selected executable traces.
+struct iree_hal_amdgpu_profile_trace_session_t {
+  // Host allocator used for session and queue slot storage.
+  iree_allocator_t host_allocator;
+  // Borrowed HSA API table from the logical device.
+  const iree_hal_amdgpu_libhsa_t* libhsa;
+  // Dynamically loaded aqlprofile SDK.
+  iree_hal_amdgpu_libaqlprofile_t libaqlprofile;
+  // Requested bytes in each per-slot ATT trace output buffer.
+  uint64_t trace_buffer_size;
+  // Shader-engine mask passed to aqlprofile ATT packet generation.
+  uint32_t shader_engine_mask;
+  // Compute-unit target passed to aqlprofile ATT packet generation.
+  uint32_t target_compute_unit;
+  // SIMD lane mask passed to aqlprofile ATT packet generation.
+  uint32_t simd_select;
+  // Next nonzero producer-local executable trace id.
+  iree_atomic_int64_t next_trace_id;
+};
+
+// Context threaded through aqlprofile ATT data iteration.
+typedef struct iree_hal_amdgpu_profile_trace_collect_context_t {
+  // Host queue owning the trace slot.
+  iree_hal_amdgpu_host_queue_t* queue;
+  // Profile sink receiving trace chunks.
+  iree_hal_profile_sink_t* sink;
+  // Active profile session id.
+  uint64_t session_id;
+  // Copied dispatch event record correlated with the trace.
+  const iree_hal_profile_dispatch_event_t* event;
+  // Trace slot whose handle is being decoded.
+  const iree_hal_amdgpu_profile_trace_slot_t* slot;
+  // First callback or sink failure encountered during iteration.
+  iree_status_t status;
+} iree_hal_amdgpu_profile_trace_collect_context_t;
+
+static bool iree_hal_amdgpu_profile_trace_mode_requested(
+    const iree_hal_device_profiling_options_t* options) {
+  return iree_hal_device_profiling_options_requests_executable_traces(options);
+}
+
+static iree_status_t iree_hal_amdgpu_profile_trace_create_packets(
+    const iree_hal_amdgpu_profile_trace_session_t* session,
+    iree_hal_amdgpu_physical_device_t* physical_device,
+    const iree_hal_amdgpu_profile_aqlprofile_memory_context_t* memory_context,
+    iree_hal_amdgpu_aqlprofile_handle_t* out_handle,
+    iree_hal_amdgpu_aqlprofile_att_control_aql_packets_t* out_packets) {
+  iree_hal_amdgpu_aqlprofile_att_parameter_t parameters[6];
+  memset(parameters, 0, sizeof(parameters));
+  uint32_t parameter_count = 0;
+  parameters[parameter_count++] = (iree_hal_amdgpu_aqlprofile_att_parameter_t){
+      .parameter_name =
+          IREE_HAL_AMDGPU_AQLPROFILE_PARAMETER_NAME_COMPUTE_UNIT_TARGET,
+      .value = session->target_compute_unit,
+  };
+  parameters[parameter_count++] = (iree_hal_amdgpu_aqlprofile_att_parameter_t){
+      .parameter_name = IREE_HAL_AMDGPU_AQLPROFILE_PARAMETER_NAME_SE_MASK,
+      .value = session->shader_engine_mask,
+  };
+  parameters[parameter_count++] = (iree_hal_amdgpu_aqlprofile_att_parameter_t){
+      .parameter_name =
+          IREE_HAL_AMDGPU_AQLPROFILE_PARAMETER_NAME_SIMD_SELECTION,
+      .value = session->simd_select,
+  };
+  parameters[parameter_count++] = (iree_hal_amdgpu_aqlprofile_att_parameter_t){
+      .parameter_name =
+          IREE_HAL_AMDGPU_AQLPROFILE_PARAMETER_NAME_ATT_BUFFER_SIZE,
+      .value = (uint32_t)session->trace_buffer_size,
+  };
+  if ((session->trace_buffer_size >> 32) != 0) {
+    parameters[parameter_count++] =
+        (iree_hal_amdgpu_aqlprofile_att_parameter_t){
+            .parameter_name =
+                IREE_HAL_AMDGPU_AQLPROFILE_ATT_PARAMETER_NAME_BUFFER_SIZE_HIGH,
+            .value = (uint32_t)(session->trace_buffer_size >> 32),
+        };
+  }
+  parameters[parameter_count++] = (iree_hal_amdgpu_aqlprofile_att_parameter_t){
+      .parameter_name =
+          IREE_HAL_AMDGPU_AQLPROFILE_ATT_PARAMETER_NAME_RT_TIMESTAMP,
+      .value = IREE_HAL_AMDGPU_AQLPROFILE_ATT_PARAMETER_RT_TIMESTAMP_ENABLE,
+  };
+
+  iree_hal_amdgpu_aqlprofile_att_profile_t profile = {
+      .agent = physical_device->device_agent,
+      .parameters = parameters,
+      .parameter_count = parameter_count,
+  };
+  IREE_RETURN_IF_AQLPROFILE_ERROR(
+      &session->libaqlprofile,
+      session->libaqlprofile.aqlprofile_att_create_packets(
+          out_handle, out_packets, profile,
+          iree_hal_amdgpu_profile_aqlprofile_memory_alloc,
+          iree_hal_amdgpu_profile_aqlprofile_memory_dealloc,
+          iree_hal_amdgpu_profile_aqlprofile_memory_copy,
+          (void*)memory_context),
+      "creating AMDGPU ATT PM4 packets");
+  return iree_ok_status();
+}
+
+static iree_status_t iree_hal_amdgpu_profile_trace_create_code_object_marker(
+    const iree_hal_amdgpu_profile_trace_session_t* session,
+    iree_hal_amdgpu_physical_device_t* physical_device,
+    const iree_hal_amdgpu_profile_aqlprofile_memory_context_t* memory_context,
+    const iree_hal_profile_executable_code_object_load_record_t* load_record,
+    iree_hal_amdgpu_aqlprofile_handle_t* out_handle,
+    iree_hsa_amd_aql_pm4_ib_packet_t* out_packet) {
+  iree_hal_amdgpu_aqlprofile_att_code_object_data_t code_object_data = {
+      .id = load_record->code_object_id,
+      .address = (uint64_t)load_record->load_delta,
+      .length = load_record->load_size,
+      .agent = physical_device->device_agent,
+      .is_unload = 0,
+      .from_start = 1,
+  };
+  IREE_RETURN_IF_AQLPROFILE_ERROR(
+      &session->libaqlprofile,
+      session->libaqlprofile.aqlprofile_att_codeobj_marker(
+          out_packet, out_handle, code_object_data,
+          iree_hal_amdgpu_profile_aqlprofile_memory_alloc,
+          iree_hal_amdgpu_profile_aqlprofile_memory_dealloc,
+          (void*)memory_context),
+      "creating AMDGPU ATT code-object marker packet");
+  return iree_ok_status();
+}
+
+static void iree_hal_amdgpu_profile_trace_destroy_packets(
+    const iree_hal_amdgpu_libaqlprofile_t* libaqlprofile,
+    iree_hal_amdgpu_aqlprofile_handle_t* handle) {
+  if (!handle->handle) return;
+  libaqlprofile->aqlprofile_att_delete_packets(*handle);
+  handle->handle = 0;
+}
+
+static void iree_hal_amdgpu_profile_trace_slot_reset(
+    const iree_hal_amdgpu_libaqlprofile_t* libaqlprofile,
+    iree_hal_amdgpu_profile_trace_slot_t* slot) {
+  iree_hal_amdgpu_profile_trace_destroy_packets(libaqlprofile, &slot->handle);
+  iree_hal_amdgpu_profile_trace_destroy_packets(
+      libaqlprofile, &slot->code_object_marker_handle);
+  memset(slot, 0, sizeof(*slot));
+}
+
+iree_status_t iree_hal_amdgpu_profile_trace_session_allocate(
+    iree_hal_amdgpu_logical_device_t* logical_device,
+    const iree_hal_device_profiling_options_t* options,
+    iree_allocator_t host_allocator,
+    iree_hal_amdgpu_profile_trace_session_t** out_session) {
+  IREE_TRACE_ZONE_BEGIN(z0);
+  *out_session = NULL;
+
+  if (!iree_hal_amdgpu_profile_trace_mode_requested(options)) {
+    IREE_TRACE_ZONE_END(z0);
+    return iree_ok_status();
+  }
+  if (IREE_UNLIKELY(!options->sink)) {
+    IREE_TRACE_ZONE_END(z0);
+    return iree_make_status(
+        IREE_STATUS_INVALID_ARGUMENT,
+        "AMDGPU executable trace profiling requires a profile sink");
+  }
+  if (IREE_UNLIKELY(iree_hal_profile_capture_filter_is_default(
+          &options->capture_filter))) {
+    IREE_TRACE_ZONE_END(z0);
+    return iree_make_status(
+        IREE_STATUS_INVALID_ARGUMENT,
+        "AMDGPU executable trace profiling requires a capture filter; use an "
+        "export pattern, command buffer/id, physical device, or queue filter "
+        "to avoid tracing every dispatch");
+  }
+#if defined(IREE_SANITIZER_ADDRESS)
+  IREE_TRACE_ZONE_END(z0);
+  return iree_make_status(
+      IREE_STATUS_UNAVAILABLE,
+      "AMDGPU executable trace profiling is disabled in ASAN builds because "
+      "ROCm aqlprofile ATT packet creation can abort before returning a "
+      "status; use a non-ASAN build for ATT traces");
+#endif  // IREE_SANITIZER_ADDRESS
+
+  iree_hal_amdgpu_profile_trace_session_t* session = NULL;
+  IREE_RETURN_AND_END_ZONE_IF_ERROR(
+      z0, iree_allocator_malloc(host_allocator, sizeof(*session),
+                                (void**)&session));
+  memset(session, 0, sizeof(*session));
+  session->host_allocator = host_allocator;
+  session->libhsa = &logical_device->system->libhsa;
+  session->trace_buffer_size =
+      iree_hal_amdgpu_profile_trace_default_buffer_size;
+  session->shader_engine_mask = iree_hal_amdgpu_profile_trace_default_se_mask;
+  session->target_compute_unit =
+      iree_hal_amdgpu_profile_trace_default_target_cu;
+  session->simd_select = iree_hal_amdgpu_profile_trace_default_simd_select;
+  iree_atomic_store(&session->next_trace_id, 1, iree_memory_order_relaxed);
+
+  iree_status_t status = iree_hal_amdgpu_libaqlprofile_initialize(
+      session->libhsa, iree_string_view_list_empty(), host_allocator,
+      &session->libaqlprofile);
+  if (iree_status_is_ok(status)) {
+    status = iree_hal_amdgpu_libaqlprofile_require_att_support(
+        &session->libaqlprofile,
+        "AMDGPU executable trace profiling requires ATT/SQTT packet generation "
+        "and trace data iteration symbols");
+  }
+
+  if (iree_status_is_ok(status)) {
+    *out_session = session;
+  } else {
+    iree_hal_amdgpu_libaqlprofile_deinitialize(&session->libaqlprofile);
+    iree_allocator_free(host_allocator, session);
+  }
+
+  IREE_TRACE_ZONE_END(z0);
+  return status;
+}
+
+void iree_hal_amdgpu_profile_trace_session_free(
+    iree_hal_amdgpu_profile_trace_session_t* session) {
+  if (!session) return;
+  IREE_TRACE_ZONE_BEGIN(z0);
+  iree_allocator_t host_allocator = session->host_allocator;
+  iree_hal_amdgpu_libaqlprofile_deinitialize(&session->libaqlprofile);
+  iree_allocator_free(host_allocator, session);
+  IREE_TRACE_ZONE_END(z0);
+}
+
+bool iree_hal_amdgpu_profile_trace_session_is_active(
+    const iree_hal_amdgpu_profile_trace_session_t* session) {
+  return session != NULL;
+}
+
+static iree_hal_amdgpu_profile_trace_slot_t*
+iree_hal_amdgpu_host_queue_profile_trace_slot(
+    iree_hal_amdgpu_host_queue_t* queue, uint64_t event_position) {
+  const uint32_t event_index =
+      iree_hal_amdgpu_host_queue_profile_dispatch_event_index(queue,
+                                                              event_position);
+  return &queue->profiling.traces.slots[event_index];
+}
+
+iree_status_t iree_hal_amdgpu_host_queue_enable_profile_traces(
+    iree_hal_amdgpu_host_queue_t* queue,
+    iree_hal_amdgpu_profile_trace_session_t* session) {
+  if (!iree_hal_amdgpu_profile_trace_session_is_active(session)) {
+    return iree_ok_status();
+  }
+  IREE_TRACE_ZONE_BEGIN(z0);
+
+  const uint32_t dispatch_event_capacity =
+      iree_hal_amdgpu_host_queue_profile_dispatch_event_capacity(queue);
+  if (IREE_UNLIKELY(!dispatch_event_capacity)) {
+    IREE_TRACE_ZONE_END(z0);
+    return iree_make_status(
+        IREE_STATUS_FAILED_PRECONDITION,
+        "AMDGPU executable trace profiling requires dispatch event storage");
+  }
+  if (IREE_UNLIKELY(!iree_any_bit_set(
+          queue->vendor_packet_capabilities,
+          IREE_HAL_AMDGPU_VENDOR_PACKET_CAPABILITY_AQL_PM4_IB))) {
+    IREE_TRACE_ZONE_END(z0);
+    return iree_make_status(
+        IREE_STATUS_FAILED_PRECONDITION,
+        "AMDGPU executable trace profiling requires AQL PM4-IB packet support");
+  }
+  if (IREE_UNLIKELY(queue->profiling.traces.slots)) {
+    IREE_TRACE_ZONE_END(z0);
+    return iree_make_status(IREE_STATUS_FAILED_PRECONDITION,
+                            "AMDGPU executable trace profiling is already "
+                            "enabled");
+  }
+
+  iree_host_size_t slot_storage_size = 0;
+  IREE_RETURN_AND_END_ZONE_IF_ERROR(
+      z0, IREE_STRUCT_LAYOUT(
+              0, &slot_storage_size,
+              IREE_STRUCT_FIELD(dispatch_event_capacity,
+                                iree_hal_amdgpu_profile_trace_slot_t, NULL)));
+  iree_hal_amdgpu_profile_trace_slot_t* slots = NULL;
+  IREE_RETURN_AND_END_ZONE_IF_ERROR(
+      z0, iree_allocator_malloc(queue->host_allocator, slot_storage_size,
+                                (void**)&slots));
+  memset(slots, 0, slot_storage_size);
+
+  queue->profiling.traces.session = session;
+  queue->profiling.traces.slots = slots;
+
+  IREE_TRACE_ZONE_END(z0);
+  return iree_ok_status();
+}
+
+void iree_hal_amdgpu_host_queue_disable_profile_traces(
+    iree_hal_amdgpu_host_queue_t* queue) {
+  if (!queue->profiling.traces.slots) return;
+  IREE_TRACE_ZONE_BEGIN(z0);
+
+  iree_hal_amdgpu_profile_trace_session_t* session =
+      queue->profiling.traces.session;
+  const uint32_t dispatch_event_capacity =
+      iree_hal_amdgpu_host_queue_profile_dispatch_event_capacity(queue);
+  for (uint32_t i = 0; i < dispatch_event_capacity; ++i) {
+    iree_hal_amdgpu_profile_trace_slot_reset(&session->libaqlprofile,
+                                             &queue->profiling.traces.slots[i]);
+  }
+  iree_allocator_free(queue->host_allocator, queue->profiling.traces.slots);
+  queue->profiling.traces.session = NULL;
+  queue->profiling.traces.slots = NULL;
+
+  IREE_TRACE_ZONE_END(z0);
+}
+
+uint32_t iree_hal_amdgpu_host_queue_profile_trace_packet_count(
+    const iree_hal_amdgpu_host_queue_t* queue,
+    iree_hal_amdgpu_profile_dispatch_event_reservation_t reservation) {
+  if (!reservation.event_count || !queue->profiling.traces.session) return 0;
+  return reservation.event_count *
+         iree_hal_amdgpu_profile_trace_packets_per_event;
+}
+
+uint32_t iree_hal_amdgpu_host_queue_profile_trace_start_packet_count(
+    const iree_hal_amdgpu_host_queue_t* queue,
+    iree_hal_amdgpu_profile_dispatch_event_reservation_t reservation) {
+  if (!reservation.event_count || !queue->profiling.traces.session) return 0;
+  return reservation.event_count *
+         iree_hal_amdgpu_profile_trace_start_packets_per_event;
+}
+
+iree_status_t iree_hal_amdgpu_host_queue_prepare_profile_traces(
+    iree_hal_amdgpu_host_queue_t* queue,
+    iree_hal_amdgpu_profile_dispatch_event_reservation_t reservation) {
+  iree_hal_amdgpu_profile_trace_session_t* session =
+      queue->profiling.traces.session;
+  if (!reservation.event_count || !session) return iree_ok_status();
+
+  iree_hal_amdgpu_logical_device_t* logical_device =
+      (iree_hal_amdgpu_logical_device_t*)queue->logical_device;
+  iree_hal_amdgpu_physical_device_t* physical_device =
+      logical_device->physical_devices[queue->device_ordinal];
+  for (uint32_t event_ordinal = 0; event_ordinal < reservation.event_count;
+       ++event_ordinal) {
+    const uint64_t event_position =
+        reservation.first_event_position + event_ordinal;
+    iree_hal_amdgpu_profile_trace_slot_t* slot =
+        iree_hal_amdgpu_host_queue_profile_trace_slot(queue, event_position);
+    if (!slot->handle.handle) {
+      slot->memory_context =
+          (iree_hal_amdgpu_profile_aqlprofile_memory_context_t){
+              .libhsa = session->libhsa,
+              .device_agent = physical_device->device_agent,
+              .host_memory_pools = &physical_device->host_memory_pools,
+              .device_coarse_pool =
+                  physical_device->coarse_block_pools.large.memory_pool,
+          };
+      IREE_RETURN_IF_ERROR(iree_hal_amdgpu_profile_trace_create_packets(
+          session, physical_device, &slot->memory_context, &slot->handle,
+          &slot->packets));
+    }
+    slot->trace_id = (uint64_t)iree_atomic_fetch_add(&session->next_trace_id, 1,
+                                                     iree_memory_order_relaxed);
+  }
+  return iree_ok_status();
+}
+
+iree_status_t iree_hal_amdgpu_host_queue_prepare_profile_trace_code_object(
+    iree_hal_amdgpu_host_queue_t* queue, uint64_t event_position,
+    uint64_t executable_id) {
+  iree_hal_amdgpu_profile_trace_session_t* session =
+      queue->profiling.traces.session;
+  if (!session) return iree_ok_status();
+
+  iree_hal_amdgpu_profile_trace_slot_t* slot =
+      iree_hal_amdgpu_host_queue_profile_trace_slot(queue, event_position);
+  if (IREE_UNLIKELY(!slot->handle.handle)) {
+    return iree_make_status(
+        IREE_STATUS_FAILED_PRECONDITION,
+        "AMDGPU executable trace slot must be prepared before its code-object "
+        "marker");
+  }
+
+  const uint32_t physical_device_ordinal =
+      iree_hal_amdgpu_host_queue_profile_device_ordinal(queue);
+  iree_hal_amdgpu_logical_device_t* logical_device =
+      (iree_hal_amdgpu_logical_device_t*)queue->logical_device;
+  iree_hal_profile_executable_code_object_load_record_t load_record;
+  IREE_RETURN_IF_ERROR(iree_hal_amdgpu_profile_metadata_lookup_code_object_load(
+      &logical_device->profile_metadata, executable_id, physical_device_ordinal,
+      &load_record));
+
+  if (slot->code_object_marker_handle.handle &&
+      slot->code_object_marker_id == load_record.code_object_id) {
+    return iree_ok_status();
+  }
+
+  iree_hal_amdgpu_profile_trace_destroy_packets(
+      &session->libaqlprofile, &slot->code_object_marker_handle);
+  memset(&slot->code_object_marker_packet, 0,
+         sizeof(slot->code_object_marker_packet));
+  slot->code_object_marker_id = 0;
+
+  iree_hal_amdgpu_physical_device_t* physical_device =
+      logical_device->physical_devices[queue->device_ordinal];
+  IREE_RETURN_IF_ERROR(iree_hal_amdgpu_profile_trace_create_code_object_marker(
+      session, physical_device, &slot->memory_context, &load_record,
+      &slot->code_object_marker_handle, &slot->code_object_marker_packet));
+  slot->code_object_marker_id = load_record.code_object_id;
+  return iree_ok_status();
+}
+
+static void iree_hal_amdgpu_host_queue_emplace_profile_trace_packet_at(
+    iree_hal_amdgpu_host_queue_t* queue,
+    const iree_hsa_amd_aql_pm4_ib_packet_t* source_packet,
+    uint64_t first_packet_id, uint32_t packet_index,
+    iree_hal_amdgpu_aql_packet_control_t packet_control,
+    uint16_t* packet_headers, uint16_t* packet_setups) {
+  iree_hal_amdgpu_aql_packet_t* packet = iree_hal_amdgpu_aql_ring_packet(
+      &queue->aql_ring, first_packet_id + packet_index);
+  iree_hal_amdgpu_profile_aqlprofile_emplace_pm4_ib_packet(
+      source_packet, packet, packet_control, iree_hsa_signal_null(),
+      &packet_headers[packet_index], &packet_setups[packet_index]);
+}
+
+void iree_hal_amdgpu_host_queue_emplace_profile_trace_start_packet(
+    iree_hal_amdgpu_host_queue_t* queue, uint64_t event_position,
+    uint64_t first_packet_id, uint32_t first_packet_index,
+    iree_hal_amdgpu_aql_packet_control_t packet_control,
+    uint16_t* packet_headers, uint16_t* packet_setups) {
+  iree_hal_amdgpu_profile_trace_slot_t* slot =
+      iree_hal_amdgpu_host_queue_profile_trace_slot(queue, event_position);
+  iree_hal_amdgpu_host_queue_emplace_profile_trace_packet_at(
+      queue, &slot->packets.start_packet, first_packet_id, first_packet_index,
+      packet_control, packet_headers, packet_setups);
+}
+
+void iree_hal_amdgpu_host_queue_emplace_profile_trace_code_object_packet(
+    iree_hal_amdgpu_host_queue_t* queue, uint64_t event_position,
+    uint64_t first_packet_id, uint32_t first_packet_index,
+    iree_hal_amdgpu_aql_packet_control_t packet_control,
+    uint16_t* packet_headers, uint16_t* packet_setups) {
+  iree_hal_amdgpu_profile_trace_slot_t* slot =
+      iree_hal_amdgpu_host_queue_profile_trace_slot(queue, event_position);
+  iree_hal_amdgpu_host_queue_emplace_profile_trace_packet_at(
+      queue, &slot->code_object_marker_packet, first_packet_id,
+      first_packet_index, packet_control, packet_headers, packet_setups);
+}
+
+void iree_hal_amdgpu_host_queue_emplace_profile_trace_stop_packet(
+    iree_hal_amdgpu_host_queue_t* queue, uint64_t event_position,
+    uint64_t first_packet_id, uint32_t first_packet_index,
+    iree_hal_amdgpu_aql_packet_control_t packet_control,
+    uint16_t* packet_headers, uint16_t* packet_setups) {
+  iree_hal_amdgpu_profile_trace_slot_t* slot =
+      iree_hal_amdgpu_host_queue_profile_trace_slot(queue, event_position);
+  iree_hal_amdgpu_host_queue_emplace_profile_trace_packet_at(
+      queue, &slot->packets.stop_packet, first_packet_id, first_packet_index,
+      packet_control, packet_headers, packet_setups);
+}
+
+static void iree_hal_amdgpu_host_queue_commit_profile_trace_packet(
+    iree_hal_amdgpu_host_queue_t* queue,
+    const iree_hsa_amd_aql_pm4_ib_packet_t* source_packet, uint64_t packet_id,
+    iree_hal_amdgpu_aql_packet_control_t packet_control) {
+  iree_hal_amdgpu_aql_packet_t* packet =
+      iree_hal_amdgpu_aql_ring_packet(&queue->aql_ring, packet_id);
+  uint16_t header = 0;
+  uint16_t setup = 0;
+  iree_hal_amdgpu_profile_aqlprofile_emplace_pm4_ib_packet(
+      source_packet, packet, packet_control, iree_hsa_signal_null(), &header,
+      &setup);
+  iree_hal_amdgpu_aql_ring_commit(packet, header, setup);
+}
+
+void iree_hal_amdgpu_host_queue_commit_profile_trace_start_packet(
+    iree_hal_amdgpu_host_queue_t* queue, uint64_t event_position,
+    uint64_t packet_id, iree_hal_amdgpu_aql_packet_control_t packet_control) {
+  iree_hal_amdgpu_profile_trace_slot_t* slot =
+      iree_hal_amdgpu_host_queue_profile_trace_slot(queue, event_position);
+  iree_hal_amdgpu_host_queue_commit_profile_trace_packet(
+      queue, &slot->packets.start_packet, packet_id, packet_control);
+}
+
+void iree_hal_amdgpu_host_queue_commit_profile_trace_code_object_packet(
+    iree_hal_amdgpu_host_queue_t* queue, uint64_t event_position,
+    uint64_t packet_id, iree_hal_amdgpu_aql_packet_control_t packet_control) {
+  iree_hal_amdgpu_profile_trace_slot_t* slot =
+      iree_hal_amdgpu_host_queue_profile_trace_slot(queue, event_position);
+  iree_hal_amdgpu_host_queue_commit_profile_trace_packet(
+      queue, &slot->code_object_marker_packet, packet_id, packet_control);
+}
+
+void iree_hal_amdgpu_host_queue_commit_profile_trace_stop_packet(
+    iree_hal_amdgpu_host_queue_t* queue, uint64_t event_position,
+    uint64_t packet_id, iree_hal_amdgpu_aql_packet_control_t packet_control) {
+  iree_hal_amdgpu_profile_trace_slot_t* slot =
+      iree_hal_amdgpu_host_queue_profile_trace_slot(queue, event_position);
+  iree_hal_amdgpu_host_queue_commit_profile_trace_packet(
+      queue, &slot->packets.stop_packet, packet_id, packet_control);
+}
+
+static iree_status_t iree_hal_amdgpu_profile_trace_write_chunk(
+    iree_hal_amdgpu_profile_trace_collect_context_t* context,
+    uint32_t shader_engine, const void* trace_data,
+    iree_host_size_t data_size) {
+  const iree_hal_profile_dispatch_event_t* event = context->event;
+  const iree_hal_amdgpu_host_queue_t* queue = context->queue;
+
+  iree_hal_profile_executable_trace_record_t record =
+      iree_hal_profile_executable_trace_record_default();
+  record.format = IREE_HAL_PROFILE_EXECUTABLE_TRACE_FORMAT_AMDGPU_ATT;
+  record.flags = IREE_HAL_PROFILE_EXECUTABLE_TRACE_FLAG_DISPATCH_EVENT;
+  if (iree_any_bit_set(event->flags,
+                       IREE_HAL_PROFILE_DISPATCH_EVENT_FLAG_COMMAND_BUFFER)) {
+    record.flags |= IREE_HAL_PROFILE_EXECUTABLE_TRACE_FLAG_COMMAND_OPERATION;
+  }
+  record.shader_engine = shader_engine;
+  record.trace_id = context->slot->trace_id;
+  record.dispatch_event_id = event->event_id;
+  record.submission_id = event->submission_id;
+  record.command_buffer_id = event->command_buffer_id;
+  record.executable_id = event->executable_id;
+  record.stream_id = iree_hal_amdgpu_host_queue_profile_stream_id(queue);
+  record.command_index = event->command_index;
+  record.export_ordinal = event->export_ordinal;
+  record.physical_device_ordinal =
+      iree_hal_amdgpu_host_queue_profile_device_ordinal(queue);
+  record.queue_ordinal =
+      iree_hal_amdgpu_host_queue_profile_queue_ordinal(queue);
+  record.data_length = data_size;
+
+  iree_hal_profile_chunk_metadata_t metadata =
+      iree_hal_profile_chunk_metadata_default();
+  metadata.content_type = IREE_HAL_PROFILE_CONTENT_TYPE_EXECUTABLE_TRACES;
+  metadata.name = iree_make_cstring_view("amdgpu.att");
+  metadata.session_id = context->session_id;
+  metadata.stream_id = record.stream_id;
+  metadata.event_id = event->event_id;
+  metadata.executable_id = event->executable_id;
+  metadata.command_buffer_id = event->command_buffer_id;
+  metadata.physical_device_ordinal = record.physical_device_ordinal;
+  metadata.queue_ordinal = record.queue_ordinal;
+
+  iree_const_byte_span_t iovecs[] = {
+      iree_make_const_byte_span(&record, sizeof(record)),
+      iree_make_const_byte_span(trace_data, data_size),
+  };
+  return iree_hal_profile_sink_write(context->sink, &metadata,
+                                     IREE_ARRAYSIZE(iovecs), iovecs);
+}
+
+static hsa_status_t iree_hal_amdgpu_profile_trace_collect_callback(
+    uint32_t shader_engine, void* buffer, uint64_t size, void* user_data) {
+  iree_hal_amdgpu_profile_trace_collect_context_t* context =
+      (iree_hal_amdgpu_profile_trace_collect_context_t*)user_data;
+  if (!iree_status_is_ok(context->status)) return HSA_STATUS_ERROR;
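+  // Guard for 32-bit hosts where iree_host_size_t is narrower than the
+  // uint64_t length reported by the aqlprofile iteration callback.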
+  if (IREE_UNLIKELY(size > IREE_HOST_SIZE_MAX)) {
+    context->status =
+        iree_make_status(IREE_STATUS_OUT_OF_RANGE,
+                         "AMDGPU ATT trace byte length %" PRIu64
+                         " exceeds host addressable size %" PRIhsz,
+                         size, IREE_HOST_SIZE_MAX);
+    return HSA_STATUS_ERROR_OUT_OF_RESOURCES;
+  }
+  context->status = iree_hal_amdgpu_profile_trace_write_chunk(
+      context, shader_engine, buffer, (iree_host_size_t)size);
+  return iree_status_is_ok(context->status) ? HSA_STATUS_SUCCESS
+                                            : HSA_STATUS_ERROR;
+}
+
+static iree_status_t iree_hal_amdgpu_host_queue_write_profile_trace(
+    iree_hal_amdgpu_host_queue_t* queue,
+    iree_hal_amdgpu_profile_trace_session_t* session,
+    iree_hal_profile_sink_t* sink, uint64_t session_id, uint64_t event_position,
+    const iree_hal_profile_dispatch_event_t* event) {
+  iree_hal_amdgpu_profile_trace_slot_t* slot =
+      iree_hal_amdgpu_host_queue_profile_trace_slot(queue, event_position);
+  if (IREE_UNLIKELY(!slot->handle.handle || slot->trace_id == 0)) {
+    return iree_make_status(
+        IREE_STATUS_FAILED_PRECONDITION,
+        "AMDGPU executable trace slot was not prepared before flush");
+  }
+
+  iree_hal_amdgpu_profile_trace_collect_context_t context = {
+      .queue = queue,
+      .sink = sink,
+      .session_id = session_id,
+      .event = event,
+      .slot = slot,
+      .status = iree_ok_status(),
+  };
+  hsa_status_t hsa_status = session->libaqlprofile.aqlprofile_att_iterate_data(
+      slot->handle, iree_hal_amdgpu_profile_trace_collect_callback, &context);
+  if (!iree_status_is_ok(context.status)) return context.status;
+  return iree_status_from_aqlprofile_status(
+      &session->libaqlprofile, __FILE__, __LINE__, hsa_status,
+      "aqlprofile_att_iterate_data", "iterating AMDGPU ATT trace data");
+}
+
+iree_status_t iree_hal_amdgpu_host_queue_write_profile_traces(
+    iree_hal_amdgpu_host_queue_t* queue, iree_hal_profile_sink_t* sink,
+    uint64_t session_id, uint64_t event_read_position,
+    iree_host_size_t event_count,
+    const iree_hal_profile_dispatch_event_t* events) {
+  if (!sink || !event_count || !queue->profiling.traces.session) {
+    return iree_ok_status();
+  }
+  IREE_TRACE_ZONE_BEGIN(z0);
+
+  iree_hal_amdgpu_profile_trace_session_t* session =
+      queue->profiling.traces.session;
+  iree_status_t status = iree_ok_status();
+  for (iree_host_size_t event_ordinal = 0;
+       event_ordinal < event_count && iree_status_is_ok(status);
+       ++event_ordinal) {
+    const uint64_t event_position = event_read_position + event_ordinal;
+    status = iree_hal_amdgpu_host_queue_write_profile_trace(
+        queue, session, sink, session_id, event_position,
+        &events[event_ordinal]);
+  }
+
+  IREE_TRACE_ZONE_END(z0);
+  return status;
+}
+
+void iree_hal_amdgpu_host_queue_release_profile_trace_slots(
+    iree_hal_amdgpu_host_queue_t* queue, uint64_t event_read_position,
+    iree_host_size_t event_count) {
+  if (!event_count || !queue->profiling.traces.session) return;
+  iree_hal_amdgpu_profile_trace_session_t* session =
+      queue->profiling.traces.session;
+  for (iree_host_size_t event_ordinal = 0; event_ordinal < event_count;
+       ++event_ordinal) {
+    const uint64_t event_position = event_read_position + event_ordinal;
+    iree_hal_amdgpu_profile_trace_slot_t* slot =
+        iree_hal_amdgpu_host_queue_profile_trace_slot(queue, event_position);
+    iree_hal_amdgpu_profile_trace_slot_reset(&session->libaqlprofile, slot);
+  }
+}
diff --git a/runtime/src/iree/hal/drivers/amdgpu/profile_traces.h b/runtime/src/iree/hal/drivers/amdgpu/profile_traces.h
new file mode 100644
index 0000000..f5de925
--- /dev/null
+++ b/runtime/src/iree/hal/drivers/amdgpu/profile_traces.h
@@ -0,0 +1,160 @@
+// Copyright 2026 The IREE Authors
+//
+// Licensed under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+
+#ifndef IREE_HAL_DRIVERS_AMDGPU_PROFILE_TRACES_H_
+#define IREE_HAL_DRIVERS_AMDGPU_PROFILE_TRACES_H_
+
+#include "iree/hal/device.h"
+#include "iree/hal/drivers/amdgpu/util/aql_emitter.h"
+#include "iree/hal/drivers/amdgpu/util/aql_ring.h"
+#include "iree/hal/profile_schema.h"
+#include "iree/hal/profile_sink.h"
+
+#ifdef __cplusplus
+extern "C" {
+#endif  // __cplusplus
+
+typedef struct iree_hal_amdgpu_host_queue_t iree_hal_amdgpu_host_queue_t;
+typedef struct iree_hal_amdgpu_logical_device_t
+    iree_hal_amdgpu_logical_device_t;
+typedef struct iree_hal_amdgpu_profile_dispatch_event_reservation_t
+    iree_hal_amdgpu_profile_dispatch_event_reservation_t;
+typedef struct iree_hal_amdgpu_profile_trace_session_t
+    iree_hal_amdgpu_profile_trace_session_t;
+typedef struct iree_hal_amdgpu_profile_trace_slot_t
+    iree_hal_amdgpu_profile_trace_slot_t;
+
+// Allocates an executable trace profiling session from |options|.
+//
+// Succeeds with a NULL |out_session| when |options| does not request
+// executable trace profiling.
+//
+// The returned session is immutable after creation except for its monotonically
+// assigned trace identifiers. The logical-device profiling begin path owns the
+// session and publishes a borrowed pointer to queues while profiling is active.
+iree_status_t iree_hal_amdgpu_profile_trace_session_allocate(
+    iree_hal_amdgpu_logical_device_t* logical_device,
+    const iree_hal_device_profiling_options_t* options,
+    iree_allocator_t host_allocator,
+    iree_hal_amdgpu_profile_trace_session_t** out_session);
+
+// Frees |session| and releases its aqlprofile library reference.
+void iree_hal_amdgpu_profile_trace_session_free(
+    iree_hal_amdgpu_profile_trace_session_t* session);
+
+// Returns true when |session| should emit executable trace packets.
+bool iree_hal_amdgpu_profile_trace_session_is_active(
+    const iree_hal_amdgpu_profile_trace_session_t* session);
+
+// Enables queue-local executable trace storage for |queue|.
+//
+// Allocates one host-side slot per dispatch event ring entry. Slots hold only
+// small control metadata until selected dispatches prepare them; a prepared
+// slot owns one aqlprofile ATT packet handle and its trace output buffer until
+// the corresponding dispatch event is successfully flushed and released.
+iree_status_t iree_hal_amdgpu_host_queue_enable_profile_traces(
+    iree_hal_amdgpu_host_queue_t* queue,
+    iree_hal_amdgpu_profile_trace_session_t* session);
+
+// Disables queue-local executable trace storage and deletes all remaining slot
+// handles.
+void iree_hal_amdgpu_host_queue_disable_profile_traces(
+    iree_hal_amdgpu_host_queue_t* queue);
+
+// Returns the number of additional AQL packets needed for |reservation|.
+uint32_t iree_hal_amdgpu_host_queue_profile_trace_packet_count(
+    const iree_hal_amdgpu_host_queue_t* queue,
+    iree_hal_amdgpu_profile_dispatch_event_reservation_t reservation);
+
+// Returns the number of trace start packets emitted before profiled dispatches.
+uint32_t iree_hal_amdgpu_host_queue_profile_trace_start_packet_count(
+    const iree_hal_amdgpu_host_queue_t* queue,
+    iree_hal_amdgpu_profile_dispatch_event_reservation_t reservation);
+
+// Prepares executable trace slots for |reservation|.
+//
+// Caller must hold queue->locks.submission_mutex and must call this only after
+// the dispatch profile events have been reserved. Start/stop handles are
+// created lazily per event-ring slot and then reused only after the dispatch
+// event cursor has advanced past the slot. Because ATT output buffers are
+// large, the event flush path releases the slot handle after all sink writes
+// for the corresponding dispatch event have succeeded instead of retaining it
+// for the full profiling session.
+iree_status_t iree_hal_amdgpu_host_queue_prepare_profile_traces(
+    iree_hal_amdgpu_host_queue_t* queue,
+    iree_hal_amdgpu_profile_dispatch_event_reservation_t reservation);
+
+// Prepares the ATT code-object marker packet for |event_position|.
+//
+// Caller must hold queue->locks.submission_mutex and call
+// iree_hal_amdgpu_host_queue_prepare_profile_traces first for the same event
+// slot so the aqlprofile allocation context has been initialized.
+iree_status_t iree_hal_amdgpu_host_queue_prepare_profile_trace_code_object(
+    iree_hal_amdgpu_host_queue_t* queue, uint64_t event_position,
+    uint64_t executable_id);
+
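+// Illustrative submission-side ordering (hypothetical caller): reserve the
+// dispatch profile events for the packet batch, call
+// iree_hal_amdgpu_host_queue_prepare_profile_traces on the reservation, call
+// iree_hal_amdgpu_host_queue_prepare_profile_trace_code_object for each traced
+// dispatch, and only then emplace or commit the start/marker/stop packets
+// declared below.
+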
+// Emplaces one ATT start packet for |event_position| at |first_packet_index|.
+void iree_hal_amdgpu_host_queue_emplace_profile_trace_start_packet(
+    iree_hal_amdgpu_host_queue_t* queue, uint64_t event_position,
+    uint64_t first_packet_id, uint32_t first_packet_index,
+    iree_hal_amdgpu_aql_packet_control_t packet_control,
+    uint16_t* packet_headers, uint16_t* packet_setups);
+
+// Emplaces one ATT code-object marker packet for |event_position| at
+// |first_packet_index|.
+void iree_hal_amdgpu_host_queue_emplace_profile_trace_code_object_packet(
+    iree_hal_amdgpu_host_queue_t* queue, uint64_t event_position,
+    uint64_t first_packet_id, uint32_t first_packet_index,
+    iree_hal_amdgpu_aql_packet_control_t packet_control,
+    uint16_t* packet_headers, uint16_t* packet_setups);
+
+// Emplaces one ATT stop packet for |event_position| at |first_packet_index|.
+void iree_hal_amdgpu_host_queue_emplace_profile_trace_stop_packet(
+    iree_hal_amdgpu_host_queue_t* queue, uint64_t event_position,
+    uint64_t first_packet_id, uint32_t first_packet_index,
+    iree_hal_amdgpu_aql_packet_control_t packet_control,
+    uint16_t* packet_headers, uint16_t* packet_setups);
+
+// Commits one ATT start packet for |event_position| at |packet_id|.
+void iree_hal_amdgpu_host_queue_commit_profile_trace_start_packet(
+    iree_hal_amdgpu_host_queue_t* queue, uint64_t event_position,
+    uint64_t packet_id, iree_hal_amdgpu_aql_packet_control_t packet_control);
+
+// Commits one ATT code-object marker packet for |event_position| at
+// |packet_id|.
+void iree_hal_amdgpu_host_queue_commit_profile_trace_code_object_packet(
+    iree_hal_amdgpu_host_queue_t* queue, uint64_t event_position,
+    uint64_t packet_id, iree_hal_amdgpu_aql_packet_control_t packet_control);
+
+// Commits one ATT stop packet for |event_position| at |packet_id|.
+void iree_hal_amdgpu_host_queue_commit_profile_trace_stop_packet(
+    iree_hal_amdgpu_host_queue_t* queue, uint64_t event_position,
+    uint64_t packet_id, iree_hal_amdgpu_aql_packet_control_t packet_control);
+
+// Writes executable trace chunks for retired dispatch events in |events|.
+//
+// The caller must not advance the dispatch event read cursor until this returns
+// successfully. This only writes trace payloads; the caller must release the
+// flushed slots after all sink writes associated with the same event positions
+// have succeeded.
+iree_status_t iree_hal_amdgpu_host_queue_write_profile_traces(
+    iree_hal_amdgpu_host_queue_t* queue, iree_hal_profile_sink_t* sink,
+    uint64_t session_id, uint64_t event_read_position,
+    iree_host_size_t event_count,
+    const iree_hal_profile_dispatch_event_t* events);
+
+// Releases ATT handles for flushed dispatch event positions.
+//
+// The event flush path must call this only after every sink write associated
+// with the event positions has succeeded and before advancing the dispatch
+// event read cursor so those ring slots cannot be reused while they are being
+// reset.
+void iree_hal_amdgpu_host_queue_release_profile_trace_slots(
+    iree_hal_amdgpu_host_queue_t* queue, uint64_t event_read_position,
+    iree_host_size_t event_count);
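+
+// Illustrative flush ordering (hypothetical caller; |read_position| and
+// |event_count| are placeholder names): write trace payloads first, release
+// the flushed slots only after every sink write for the same event positions
+// has succeeded, and advance the dispatch event read cursor last:
+//   IREE_RETURN_IF_ERROR(iree_hal_amdgpu_host_queue_write_profile_traces(
+//       queue, sink, session_id, read_position, event_count, events));
+//   iree_hal_amdgpu_host_queue_release_profile_trace_slots(
+//       queue, read_position, event_count);
+//   // ... advance the dispatch event read cursor past |event_count| events.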
+
+#ifdef __cplusplus
+}  // extern "C"
+#endif  // __cplusplus
+
+#endif  // IREE_HAL_DRIVERS_AMDGPU_PROFILE_TRACES_H_
diff --git a/runtime/src/iree/hal/drivers/amdgpu/queue_affinity.c b/runtime/src/iree/hal/drivers/amdgpu/queue_affinity.c
new file mode 100644
index 0000000..3961590
--- /dev/null
+++ b/runtime/src/iree/hal/drivers/amdgpu/queue_affinity.c
@@ -0,0 +1,320 @@
+// Copyright 2026 The IREE Authors
+//
+// Licensed under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+
+#include "iree/hal/drivers/amdgpu/queue_affinity.h"
+
+#include <string.h>  // memset
+
+static bool iree_hal_amdgpu_queue_affinity_try_normalize(
+    iree_hal_queue_affinity_t supported_affinity,
+    iree_hal_queue_affinity_t requested_affinity,
+    iree_hal_queue_affinity_t* out_normalized_affinity) {
+  iree_hal_queue_affinity_t normalized_affinity =
+      iree_hal_queue_affinity_is_any(requested_affinity) ? supported_affinity
+                                                         : requested_affinity;
+  iree_hal_queue_affinity_and_into(normalized_affinity, supported_affinity);
+  if (iree_hal_queue_affinity_is_empty(normalized_affinity)) return false;
+  *out_normalized_affinity = normalized_affinity;
+  return true;
+}
+
+static bool iree_hal_amdgpu_queue_affinity_try_resolve_ordinal(
+    iree_hal_amdgpu_queue_affinity_domain_t domain,
+    iree_host_size_t queue_ordinal,
+    iree_hal_amdgpu_queue_affinity_resolved_t* out_resolved) {
+  if (domain.queue_count_per_physical_device == 0 ||
+      queue_ordinal >= IREE_HAL_MAX_QUEUES) {
+    return false;
+  }
+
+  const iree_host_size_t physical_device_ordinal =
+      queue_ordinal / domain.queue_count_per_physical_device;
+  if (physical_device_ordinal >= domain.physical_device_count) return false;
+
+  memset(out_resolved, 0, sizeof(*out_resolved));
+  out_resolved->queue_affinity = ((iree_hal_queue_affinity_t)1)
+                                 << queue_ordinal;
+  out_resolved->queue_ordinal = queue_ordinal;
+  out_resolved->physical_device_ordinal = physical_device_ordinal;
+  out_resolved->physical_queue_ordinal =
+      queue_ordinal % domain.queue_count_per_physical_device;
+  return true;
+}
+
+static bool iree_hal_amdgpu_queue_affinity_try_for_physical_device(
+    iree_hal_amdgpu_queue_affinity_domain_t domain,
+    iree_host_size_t physical_device_ordinal,
+    iree_hal_queue_affinity_t* out_queue_affinity) {
+  if (domain.queue_count_per_physical_device == 0 ||
+      physical_device_ordinal >= domain.physical_device_count) {
+    return false;
+  }
+
+  iree_host_size_t first_queue_ordinal = 0;
+  if (!iree_host_size_checked_mul(physical_device_ordinal,
+                                  domain.queue_count_per_physical_device,
+                                  &first_queue_ordinal) ||
+      first_queue_ordinal >= IREE_HAL_MAX_QUEUES ||
+      domain.queue_count_per_physical_device >
+          IREE_HAL_MAX_QUEUES - first_queue_ordinal) {
+    return false;
+  }
+
+  iree_hal_queue_affinity_t queue_affinity = 0;
+  for (iree_host_size_t i = 0; i < domain.queue_count_per_physical_device;
+       ++i) {
+    iree_hal_queue_affinity_or_into(queue_affinity,
+                                    ((iree_hal_queue_affinity_t)1)
+                                        << (first_queue_ordinal + i));
+  }
+  *out_queue_affinity = queue_affinity;
+  return true;
+}
+
+iree_status_t iree_hal_amdgpu_queue_affinity_normalize(
+    iree_hal_queue_affinity_t supported_affinity,
+    iree_hal_queue_affinity_t requested_affinity,
+    iree_hal_queue_affinity_t* out_normalized_affinity) {
+  *out_normalized_affinity = 0;
+
+  if (!iree_hal_amdgpu_queue_affinity_try_normalize(
+          supported_affinity, requested_affinity, out_normalized_affinity)) {
+    return iree_make_status(IREE_STATUS_INVALID_ARGUMENT,
+                            "no valid queue affinity bits specified");
+  }
+  return iree_ok_status();
+}
+
+iree_status_t iree_hal_amdgpu_queue_affinity_resolve_ordinal(
+    iree_hal_amdgpu_queue_affinity_domain_t domain,
+    iree_host_size_t queue_ordinal,
+    iree_hal_amdgpu_queue_affinity_resolved_t* out_resolved) {
+  memset(out_resolved, 0, sizeof(*out_resolved));
+
+  if (IREE_UNLIKELY(domain.queue_count_per_physical_device == 0)) {
+    return iree_make_status(
+        IREE_STATUS_INVALID_ARGUMENT,
+        "AMDGPU queue affinity domain has no queues per physical device");
+  }
+  if (IREE_UNLIKELY(queue_ordinal >= IREE_HAL_MAX_QUEUES)) {
+    return iree_make_status(
+        IREE_STATUS_OUT_OF_RANGE,
+        "queue ordinal %" PRIhsz " exceeds affinity bit capacity %" PRIhsz,
+        queue_ordinal, (iree_host_size_t)IREE_HAL_MAX_QUEUES);
+  }
+
+  const iree_host_size_t physical_device_ordinal =
+      queue_ordinal / domain.queue_count_per_physical_device;
+  if (IREE_UNLIKELY(physical_device_ordinal >= domain.physical_device_count)) {
+    return iree_make_status(IREE_STATUS_OUT_OF_RANGE,
+                            "queue ordinal %" PRIhsz
+                            " maps to invalid physical device ordinal %" PRIhsz,
+                            queue_ordinal, physical_device_ordinal);
+  }
+
+  out_resolved->queue_affinity = ((iree_hal_queue_affinity_t)1)
+                                 << queue_ordinal;
+  out_resolved->queue_ordinal = queue_ordinal;
+  out_resolved->physical_device_ordinal = physical_device_ordinal;
+  out_resolved->physical_queue_ordinal =
+      queue_ordinal % domain.queue_count_per_physical_device;
+  return iree_ok_status();
+}
+
+iree_status_t iree_hal_amdgpu_queue_affinity_resolve(
+    iree_hal_amdgpu_queue_affinity_domain_t domain,
+    iree_hal_queue_affinity_t requested_affinity,
+    iree_hal_amdgpu_queue_affinity_resolved_t* out_resolved) {
+  iree_hal_queue_affinity_t normalized_affinity = 0;
+  IREE_RETURN_IF_ERROR(iree_hal_amdgpu_queue_affinity_normalize(
+      domain.supported_affinity, requested_affinity, &normalized_affinity));
+
+  const iree_host_size_t queue_ordinal =
+      iree_hal_queue_affinity_find_first_set(normalized_affinity);
+  IREE_RETURN_IF_ERROR(iree_hal_amdgpu_queue_affinity_resolve_ordinal(
+      domain, queue_ordinal, out_resolved));
+  out_resolved->queue_affinity = normalized_affinity;
+  return iree_ok_status();
+}
+
+bool iree_hal_amdgpu_queue_affinity_try_resolve(
+    iree_hal_amdgpu_queue_affinity_domain_t domain,
+    iree_hal_queue_affinity_t requested_affinity,
+    iree_hal_amdgpu_queue_affinity_resolved_t* out_resolved) {
+  memset(out_resolved, 0, sizeof(*out_resolved));
+
+  iree_hal_queue_affinity_t normalized_affinity = 0;
+  if (!iree_hal_amdgpu_queue_affinity_try_normalize(domain.supported_affinity,
+                                                    requested_affinity,
+                                                    &normalized_affinity)) {
+    return false;
+  }
+
+  const iree_host_size_t queue_ordinal =
+      iree_hal_queue_affinity_find_first_set(normalized_affinity);
+  if (!iree_hal_amdgpu_queue_affinity_try_resolve_ordinal(domain, queue_ordinal,
+                                                          out_resolved)) {
+    return false;
+  }
+  out_resolved->queue_affinity = normalized_affinity;
+  return true;
+}
+
+iree_status_t iree_hal_amdgpu_queue_affinity_for_physical_device(
+    iree_hal_amdgpu_queue_affinity_domain_t domain,
+    iree_host_size_t physical_device_ordinal,
+    iree_hal_queue_affinity_t* out_queue_affinity) {
+  *out_queue_affinity = 0;
+
+  if (IREE_UNLIKELY(domain.queue_count_per_physical_device == 0)) {
+    return iree_make_status(
+        IREE_STATUS_INVALID_ARGUMENT,
+        "AMDGPU queue affinity domain has no queues per physical device");
+  }
+  if (IREE_UNLIKELY(physical_device_ordinal >= domain.physical_device_count)) {
+    return iree_make_status(IREE_STATUS_OUT_OF_RANGE,
+                            "physical device ordinal %" PRIhsz
+                            " exceeds physical device count %" PRIhsz,
+                            physical_device_ordinal,
+                            domain.physical_device_count);
+  }
+
+  iree_host_size_t first_queue_ordinal = 0;
+  if (!iree_host_size_checked_mul(physical_device_ordinal,
+                                  domain.queue_count_per_physical_device,
+                                  &first_queue_ordinal) ||
+      first_queue_ordinal >= IREE_HAL_MAX_QUEUES ||
+      domain.queue_count_per_physical_device >
+          IREE_HAL_MAX_QUEUES - first_queue_ordinal) {
+    return iree_make_status(
+        IREE_STATUS_OUT_OF_RANGE,
+        "physical device queue range does not fit in queue affinity "
+        "(physical_device_ordinal=%" PRIhsz
+        ", queue_count_per_physical_device=%" PRIhsz ")",
+        physical_device_ordinal, domain.queue_count_per_physical_device);
+  }
+
+  iree_hal_queue_affinity_t queue_affinity = 0;
+  for (iree_host_size_t i = 0; i < domain.queue_count_per_physical_device;
+       ++i) {
+    iree_hal_queue_affinity_or_into(queue_affinity,
+                                    ((iree_hal_queue_affinity_t)1)
+                                        << (first_queue_ordinal + i));
+  }
+  *out_queue_affinity = queue_affinity;
+  return iree_ok_status();
+}
+
+iree_status_t iree_hal_amdgpu_queue_affinity_select_physical_devices(
+    iree_hal_amdgpu_queue_affinity_domain_t domain,
+    iree_hal_queue_affinity_t requested_affinity,
+    iree_hal_amdgpu_queue_affinity_physical_device_set_t*
+        out_physical_device_set) {
+  memset(out_physical_device_set, 0, sizeof(*out_physical_device_set));
+
+  if (IREE_UNLIKELY(domain.physical_device_count > IREE_HAL_MAX_QUEUES)) {
+    return iree_make_status(IREE_STATUS_OUT_OF_RANGE,
+                            "AMDGPU physical device count %" PRIhsz
+                            " exceeds physical device mask capacity %" PRIhsz,
+                            domain.physical_device_count,
+                            (iree_host_size_t)IREE_HAL_MAX_QUEUES);
+  }
+
+  iree_hal_queue_affinity_t normalized_affinity = 0;
+  IREE_RETURN_IF_ERROR(iree_hal_amdgpu_queue_affinity_normalize(
+      domain.supported_affinity, requested_affinity, &normalized_affinity));
+
+  out_physical_device_set->queue_affinity = normalized_affinity;
+  for (iree_host_size_t physical_device_ordinal = 0;
+       physical_device_ordinal < domain.physical_device_count;
+       ++physical_device_ordinal) {
+    iree_hal_queue_affinity_t physical_device_affinity = 0;
+    IREE_RETURN_IF_ERROR(iree_hal_amdgpu_queue_affinity_for_physical_device(
+        domain, physical_device_ordinal, &physical_device_affinity));
+    iree_hal_queue_affinity_and_into(physical_device_affinity,
+                                     domain.supported_affinity);
+
+    iree_hal_queue_affinity_t selected_affinity = normalized_affinity;
+    iree_hal_queue_affinity_and_into(selected_affinity,
+                                     physical_device_affinity);
+    if (iree_hal_queue_affinity_is_empty(selected_affinity)) continue;
+
+    if (out_physical_device_set->physical_device_count == 0) {
+      out_physical_device_set->first_physical_device_ordinal =
+          physical_device_ordinal;
+    }
+    out_physical_device_set->physical_device_mask |= ((uint64_t)1)
+                                                     << physical_device_ordinal;
+    ++out_physical_device_set->physical_device_count;
+  }
+
+  if (IREE_UNLIKELY(out_physical_device_set->physical_device_count == 0)) {
+    return iree_make_status(IREE_STATUS_INVALID_ARGUMENT,
+                            "queue affinity 0x%" PRIx64
+                            " selects no physical devices",
+                            requested_affinity);
+  }
+  return iree_ok_status();
+}
+
+iree_status_t iree_hal_amdgpu_queue_affinity_normalize_for_physical_device(
+    iree_hal_amdgpu_queue_affinity_domain_t domain,
+    iree_hal_queue_affinity_t requested_affinity,
+    iree_hal_queue_affinity_t* out_queue_affinity,
+    iree_host_size_t* out_physical_device_ordinal) {
+  *out_queue_affinity = 0;
+  *out_physical_device_ordinal = 0;
+
+  iree_hal_amdgpu_queue_affinity_resolved_t resolved;
+  IREE_RETURN_IF_ERROR(iree_hal_amdgpu_queue_affinity_resolve(
+      domain, requested_affinity, &resolved));
+
+  iree_hal_queue_affinity_t physical_device_affinity = 0;
+  IREE_RETURN_IF_ERROR(iree_hal_amdgpu_queue_affinity_for_physical_device(
+      domain, resolved.physical_device_ordinal, &physical_device_affinity));
+  iree_hal_queue_affinity_and_into(physical_device_affinity,
+                                   domain.supported_affinity);
+
+  const bool is_any_affinity =
+      iree_hal_queue_affinity_is_any(requested_affinity);
+  if (!is_any_affinity &&
+      iree_any_bit_set(resolved.queue_affinity, ~physical_device_affinity)) {
+    return iree_make_status(IREE_STATUS_UNIMPLEMENTED,
+                            "AMDGPU queue affinity 0x%" PRIx64
+                            " spans multiple physical devices",
+                            requested_affinity);
+  }
+
+  iree_hal_queue_affinity_t selected_affinity = physical_device_affinity;
+  if (!is_any_affinity) {
+    selected_affinity = resolved.queue_affinity;
+    iree_hal_queue_affinity_and_into(selected_affinity,
+                                     physical_device_affinity);
+  }
+
+  *out_queue_affinity = selected_affinity;
+  *out_physical_device_ordinal = resolved.physical_device_ordinal;
+  return iree_ok_status();
+}
+
+bool iree_hal_amdgpu_queue_affinity_is_physical_device_local(
+    iree_hal_amdgpu_queue_affinity_domain_t domain,
+    iree_hal_queue_affinity_t requested_affinity,
+    iree_host_size_t physical_device_ordinal) {
+  iree_hal_queue_affinity_t normalized_affinity = 0;
+  if (!iree_hal_amdgpu_queue_affinity_try_normalize(domain.supported_affinity,
+                                                    requested_affinity,
+                                                    &normalized_affinity)) {
+    return false;
+  }
+
+  iree_hal_queue_affinity_t physical_device_affinity = 0;
+  if (!iree_hal_amdgpu_queue_affinity_try_for_physical_device(
+          domain, physical_device_ordinal, &physical_device_affinity)) {
+    return false;
+  }
+  iree_hal_queue_affinity_and_into(physical_device_affinity,
+                                   domain.supported_affinity);
+  return !iree_any_bit_set(normalized_affinity, ~physical_device_affinity);
+}
diff --git a/runtime/src/iree/hal/drivers/amdgpu/queue_affinity.h b/runtime/src/iree/hal/drivers/amdgpu/queue_affinity.h
new file mode 100644
index 0000000..e483d3c
--- /dev/null
+++ b/runtime/src/iree/hal/drivers/amdgpu/queue_affinity.h
@@ -0,0 +1,125 @@
+// Copyright 2026 The IREE Authors
+//
+// Licensed under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+
+#ifndef IREE_HAL_DRIVERS_AMDGPU_QUEUE_AFFINITY_H_
+#define IREE_HAL_DRIVERS_AMDGPU_QUEUE_AFFINITY_H_
+
+#include "iree/base/api.h"
+#include "iree/hal/api.h"
+
+#ifdef __cplusplus
+extern "C" {
+#endif  // __cplusplus
+
+// Describes the AMDGPU logical-device queue affinity domain.
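+//
+// Logical HAL queues are flattened across physical devices:
+//   physical_device_ordinal = queue_ordinal / queue_count_per_physical_device
+//   physical_queue_ordinal  = queue_ordinal % queue_count_per_physical_device
+// For example (illustrative), with two physical devices and four queues each,
+// logical queue ordinal 5 is physical device 1, physical queue 1.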
+typedef struct iree_hal_amdgpu_queue_affinity_domain_t {
+  // Queue bits supported by the logical device.
+  iree_hal_queue_affinity_t supported_affinity;
+
+  // Number of physical GPU devices in the logical device.
+  iree_host_size_t physical_device_count;
+
+  // Logical queues assigned to each physical GPU device.
+  iree_host_size_t queue_count_per_physical_device;
+} iree_hal_amdgpu_queue_affinity_domain_t;
+
+// Resolved queue affinity selection in the flattened HAL queue space.
+typedef struct iree_hal_amdgpu_queue_affinity_resolved_t {
+  // Queue bits remaining after applying the supported affinity mask.
+  iree_hal_queue_affinity_t queue_affinity;
+
+  // Flattened logical queue ordinal selected from |queue_affinity|.
+  iree_host_size_t queue_ordinal;
+
+  // Physical GPU device ordinal owning |queue_ordinal|.
+  iree_host_size_t physical_device_ordinal;
+
+  // Queue ordinal relative to |physical_device_ordinal|.
+  iree_host_size_t physical_queue_ordinal;
+} iree_hal_amdgpu_queue_affinity_resolved_t;
+
+// Physical GPU devices selected by a HAL queue-affinity mask.
+typedef struct iree_hal_amdgpu_queue_affinity_physical_device_set_t {
+  // Queue bits remaining after applying the supported affinity mask.
+  iree_hal_queue_affinity_t queue_affinity;
+
+  // Bitmask of selected physical GPU device ordinals.
+  uint64_t physical_device_mask;
+
+  // First selected physical GPU device ordinal.
+  iree_host_size_t first_physical_device_ordinal;
+
+  // Total selected physical GPU device count.
+  iree_host_size_t physical_device_count;
+} iree_hal_amdgpu_queue_affinity_physical_device_set_t;
+
+// Normalizes |requested_affinity| against |supported_affinity|.
+//
+// IREE_HAL_QUEUE_AFFINITY_ANY expands to |supported_affinity|. Explicit masks
+// are intersected with |supported_affinity|, matching HAL queue submission
+// behavior where a multi-bit mask means "any of these queues that exist".
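+//
+// For example (illustrative): with supported_affinity=0b1111, ANY normalizes
+// to 0b1111 and 0b0110 stays 0b0110; with supported_affinity=0b0011, a request
+// of 0b0110 normalizes to 0b0010 while 0b0100 yields an empty mask and fails
+// with INVALID_ARGUMENT.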
+iree_status_t iree_hal_amdgpu_queue_affinity_normalize(
+    iree_hal_queue_affinity_t supported_affinity,
+    iree_hal_queue_affinity_t requested_affinity,
+    iree_hal_queue_affinity_t* out_normalized_affinity);
+
+// Resolves a flattened logical queue ordinal within |domain|.
+iree_status_t iree_hal_amdgpu_queue_affinity_resolve_ordinal(
+    iree_hal_amdgpu_queue_affinity_domain_t domain,
+    iree_host_size_t queue_ordinal,
+    iree_hal_amdgpu_queue_affinity_resolved_t* out_resolved);
+
+// Resolves |requested_affinity| to the deterministic first selected queue.
+iree_status_t iree_hal_amdgpu_queue_affinity_resolve(
+    iree_hal_amdgpu_queue_affinity_domain_t domain,
+    iree_hal_queue_affinity_t requested_affinity,
+    iree_hal_amdgpu_queue_affinity_resolved_t* out_resolved);
+
+// Attempts to resolve |requested_affinity| without constructing a status.
+//
+// This is for compatibility queries and other predicate-style cold paths where
+// invalid input is expected to return false instead of producing diagnostics.
+bool iree_hal_amdgpu_queue_affinity_try_resolve(
+    iree_hal_amdgpu_queue_affinity_domain_t domain,
+    iree_hal_queue_affinity_t requested_affinity,
+    iree_hal_amdgpu_queue_affinity_resolved_t* out_resolved);
+
+// Builds the queue affinity mask for one physical device in |domain|.
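+// For example, with two queues per device, physical device 1 owns bits 2-3
+// (mask 0xC).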
+iree_status_t iree_hal_amdgpu_queue_affinity_for_physical_device(
+    iree_hal_amdgpu_queue_affinity_domain_t domain,
+    iree_host_size_t physical_device_ordinal,
+    iree_hal_queue_affinity_t* out_queue_affinity);
+
+// Selects every physical GPU device with at least one queue in
+// |requested_affinity|.
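+// For example, mask 0x5 in a two-device, two-queue domain selects physical
+// devices 0 and 1.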
+iree_status_t iree_hal_amdgpu_queue_affinity_select_physical_devices(
+    iree_hal_amdgpu_queue_affinity_domain_t domain,
+    iree_hal_queue_affinity_t requested_affinity,
+    iree_hal_amdgpu_queue_affinity_physical_device_set_t*
+        out_physical_device_set);
+
+// Normalizes |requested_affinity| to queues owned by a single physical device.
+//
+// IREE_HAL_QUEUE_AFFINITY_ANY selects all supported queues for the first
+// supported physical device in the domain. Explicit masks must not span
+// physical devices after intersecting with |domain.supported_affinity|;
+// cross-device masks currently fail with UNIMPLEMENTED.
+iree_status_t iree_hal_amdgpu_queue_affinity_normalize_for_physical_device(
+    iree_hal_amdgpu_queue_affinity_domain_t domain,
+    iree_hal_queue_affinity_t requested_affinity,
+    iree_hal_queue_affinity_t* out_queue_affinity,
+    iree_host_size_t* out_physical_device_ordinal);
+
+// Returns true if |requested_affinity| normalizes to queues owned solely by
+// |physical_device_ordinal|.
+bool iree_hal_amdgpu_queue_affinity_is_physical_device_local(
+    iree_hal_amdgpu_queue_affinity_domain_t domain,
+    iree_hal_queue_affinity_t requested_affinity,
+    iree_host_size_t physical_device_ordinal);
+
+#ifdef __cplusplus
+}  // extern "C"
+#endif  // __cplusplus
+
+#endif  // IREE_HAL_DRIVERS_AMDGPU_QUEUE_AFFINITY_H_
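
To make the flattened queue space concrete: affinity bit i maps to physical
device i / queue_count_per_physical_device and to local queue
i % queue_count_per_physical_device. A minimal usage sketch against this
header (the domain values and wrapper function are hypothetical; it mirrors
the ResolveSelectsFirstQueue test below):

```c
#include "iree/hal/drivers/amdgpu/queue_affinity.h"

static iree_status_t select_queue_example(void) {
  // Hypothetical 2-GPU domain with 2 logical queues per GPU (bits 0-3).
  iree_hal_amdgpu_queue_affinity_domain_t domain = {
      .supported_affinity = 0xFull,
      .physical_device_count = 2,
      .queue_count_per_physical_device = 2,
  };
  // 0xA selects bits 1 and 3; the resolver deterministically picks the
  // lowest set bit: flattened queue 1 on physical device 0, local queue 1.
  iree_hal_amdgpu_queue_affinity_resolved_t resolved;
  IREE_RETURN_IF_ERROR(iree_hal_amdgpu_queue_affinity_resolve(
      domain, /*requested_affinity=*/0xAull, &resolved));
  // resolved.queue_affinity == 0xA, resolved.queue_ordinal == 1,
  // resolved.physical_device_ordinal == 0, physical_queue_ordinal == 1.
  return iree_ok_status();
}
```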
diff --git a/runtime/src/iree/hal/drivers/amdgpu/queue_affinity_test.cc b/runtime/src/iree/hal/drivers/amdgpu/queue_affinity_test.cc
new file mode 100644
index 0000000..3cc8f34
--- /dev/null
+++ b/runtime/src/iree/hal/drivers/amdgpu/queue_affinity_test.cc
@@ -0,0 +1,196 @@
+// Copyright 2026 The IREE Authors
+//
+// Licensed under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+
+#include "iree/hal/drivers/amdgpu/queue_affinity.h"
+
+#include "iree/testing/gtest.h"
+#include "iree/testing/status_matchers.h"
+
+namespace iree::hal::amdgpu {
+namespace {
+
+static iree_hal_amdgpu_queue_affinity_domain_t TwoDeviceDomain() {
+  iree_hal_amdgpu_queue_affinity_domain_t domain = {
+      .supported_affinity = 0xFull,
+      .physical_device_count = 2,
+      .queue_count_per_physical_device = 2,
+  };
+  return domain;
+}
+
+TEST(QueueAffinityTest, NormalizeAnyExpandsToSupportedAffinity) {
+  iree_hal_queue_affinity_t normalized_affinity = 0;
+  IREE_ASSERT_OK(iree_hal_amdgpu_queue_affinity_normalize(
+      0xFull, IREE_HAL_QUEUE_AFFINITY_ANY, &normalized_affinity));
+  EXPECT_EQ(normalized_affinity, 0xFull);
+}
+
+TEST(QueueAffinityTest, NormalizeIntersectsExplicitAffinity) {
+  iree_hal_queue_affinity_t normalized_affinity = 0;
+  IREE_ASSERT_OK(iree_hal_amdgpu_queue_affinity_normalize(
+      0x3ull, 0x5ull, &normalized_affinity));
+  EXPECT_EQ(normalized_affinity, 0x1ull);
+}
+
+TEST(QueueAffinityTest, NormalizeRejectsEmptyIntersection) {
+  iree_hal_queue_affinity_t normalized_affinity = 0;
+  IREE_EXPECT_STATUS_IS(IREE_STATUS_INVALID_ARGUMENT,
+                        iree_hal_amdgpu_queue_affinity_normalize(
+                            0x3ull, 0x4ull, &normalized_affinity));
+}
+
+TEST(QueueAffinityTest, ResolveSelectsFirstQueue) {
+  iree_hal_amdgpu_queue_affinity_resolved_t resolved;
+  IREE_ASSERT_OK(iree_hal_amdgpu_queue_affinity_resolve(TwoDeviceDomain(),
+                                                        0xAull, &resolved));
+  EXPECT_EQ(resolved.queue_affinity, 0xAull);
+  EXPECT_EQ(resolved.queue_ordinal, 1);
+  EXPECT_EQ(resolved.physical_device_ordinal, 0);
+  EXPECT_EQ(resolved.physical_queue_ordinal, 1);
+}
+
+TEST(QueueAffinityTest, ResolveOrdinalMapsPhysicalDeviceAndQueue) {
+  iree_hal_amdgpu_queue_affinity_resolved_t resolved;
+  IREE_ASSERT_OK(iree_hal_amdgpu_queue_affinity_resolve_ordinal(
+      TwoDeviceDomain(), 3, &resolved));
+  EXPECT_EQ(resolved.queue_affinity, 0x8ull);
+  EXPECT_EQ(resolved.queue_ordinal, 3);
+  EXPECT_EQ(resolved.physical_device_ordinal, 1);
+  EXPECT_EQ(resolved.physical_queue_ordinal, 1);
+}
+
+TEST(QueueAffinityTest, ResolveOrdinalRejectsOutOfRangeDevice) {
+  iree_hal_amdgpu_queue_affinity_resolved_t resolved;
+  IREE_EXPECT_STATUS_IS(IREE_STATUS_OUT_OF_RANGE,
+                        iree_hal_amdgpu_queue_affinity_resolve_ordinal(
+                            TwoDeviceDomain(), 4, &resolved));
+}
+
+TEST(QueueAffinityTest, TryResolveReturnsFalseForInvalidAffinity) {
+  iree_hal_amdgpu_queue_affinity_resolved_t resolved;
+  EXPECT_FALSE(iree_hal_amdgpu_queue_affinity_try_resolve(TwoDeviceDomain(),
+                                                          0x10ull, &resolved));
+}
+
+TEST(QueueAffinityTest, PhysicalDeviceAffinityBuildsQueueRange) {
+  iree_hal_queue_affinity_t queue_affinity = 0;
+  IREE_ASSERT_OK(iree_hal_amdgpu_queue_affinity_for_physical_device(
+      TwoDeviceDomain(), 1, &queue_affinity));
+  EXPECT_EQ(queue_affinity, 0xCull);
+}
+
+TEST(QueueAffinityTest, SelectPhysicalDevicesForAnySelectsAllDevices) {
+  iree_hal_amdgpu_queue_affinity_physical_device_set_t physical_devices;
+  IREE_ASSERT_OK(iree_hal_amdgpu_queue_affinity_select_physical_devices(
+      TwoDeviceDomain(), IREE_HAL_QUEUE_AFFINITY_ANY, &physical_devices));
+  EXPECT_EQ(physical_devices.queue_affinity, 0xFull);
+  EXPECT_EQ(physical_devices.physical_device_mask, 0x3ull);
+  EXPECT_EQ(physical_devices.first_physical_device_ordinal, 0);
+  EXPECT_EQ(physical_devices.physical_device_count, 2);
+}
+
+TEST(QueueAffinityTest, SelectPhysicalDevicesForExplicitDeviceMask) {
+  iree_hal_amdgpu_queue_affinity_physical_device_set_t physical_devices;
+  IREE_ASSERT_OK(iree_hal_amdgpu_queue_affinity_select_physical_devices(
+      TwoDeviceDomain(), 0x8ull, &physical_devices));
+  EXPECT_EQ(physical_devices.queue_affinity, 0x8ull);
+  EXPECT_EQ(physical_devices.physical_device_mask, 0x2ull);
+  EXPECT_EQ(physical_devices.first_physical_device_ordinal, 1);
+  EXPECT_EQ(physical_devices.physical_device_count, 1);
+}
+
+TEST(QueueAffinityTest, SelectPhysicalDevicesForCrossDeviceMask) {
+  iree_hal_amdgpu_queue_affinity_physical_device_set_t physical_devices;
+  IREE_ASSERT_OK(iree_hal_amdgpu_queue_affinity_select_physical_devices(
+      TwoDeviceDomain(), 0x5ull, &physical_devices));
+  EXPECT_EQ(physical_devices.queue_affinity, 0x5ull);
+  EXPECT_EQ(physical_devices.physical_device_mask, 0x3ull);
+  EXPECT_EQ(physical_devices.first_physical_device_ordinal, 0);
+  EXPECT_EQ(physical_devices.physical_device_count, 2);
+}
+
+TEST(QueueAffinityTest, SelectPhysicalDevicesIntersectsUnsupportedBits) {
+  iree_hal_amdgpu_queue_affinity_domain_t domain = TwoDeviceDomain();
+  domain.supported_affinity = 0xDull;
+
+  iree_hal_amdgpu_queue_affinity_physical_device_set_t physical_devices;
+  IREE_ASSERT_OK(iree_hal_amdgpu_queue_affinity_select_physical_devices(
+      domain, 0xAull, &physical_devices));
+  EXPECT_EQ(physical_devices.queue_affinity, 0x8ull);
+  EXPECT_EQ(physical_devices.physical_device_mask, 0x2ull);
+  EXPECT_EQ(physical_devices.first_physical_device_ordinal, 1);
+  EXPECT_EQ(physical_devices.physical_device_count, 1);
+}
+
+TEST(QueueAffinityTest, SelectPhysicalDevicesRejectsEmptyIntersection) {
+  iree_hal_amdgpu_queue_affinity_physical_device_set_t physical_devices;
+  IREE_EXPECT_STATUS_IS(IREE_STATUS_INVALID_ARGUMENT,
+                        iree_hal_amdgpu_queue_affinity_select_physical_devices(
+                            TwoDeviceDomain(), 0x10ull, &physical_devices));
+}
+
+TEST(QueueAffinityTest, SelectPhysicalDevicesRejectsUnmappedAffinity) {
+  iree_hal_amdgpu_queue_affinity_domain_t domain = TwoDeviceDomain();
+  domain.supported_affinity = 0x10ull;
+
+  iree_hal_amdgpu_queue_affinity_physical_device_set_t physical_devices;
+  IREE_EXPECT_STATUS_IS(IREE_STATUS_INVALID_ARGUMENT,
+                        iree_hal_amdgpu_queue_affinity_select_physical_devices(
+                            domain, 0x10ull, &physical_devices));
+}
+
+TEST(QueueAffinityTest, SelectPhysicalDevicesRejectsUnrepresentableMask) {
+  iree_hal_amdgpu_queue_affinity_domain_t domain = {
+      .supported_affinity = 1ull,
+      .physical_device_count = 65,
+      .queue_count_per_physical_device = 1,
+  };
+
+  iree_hal_amdgpu_queue_affinity_physical_device_set_t physical_devices;
+  IREE_EXPECT_STATUS_IS(IREE_STATUS_OUT_OF_RANGE,
+                        iree_hal_amdgpu_queue_affinity_select_physical_devices(
+                            domain, 1ull, &physical_devices));
+}
+
+TEST(QueueAffinityTest, NormalizeAnyForPhysicalDeviceSelectsFirstDevice) {
+  iree_hal_queue_affinity_t queue_affinity = 0;
+  iree_host_size_t physical_device_ordinal = 0;
+  IREE_ASSERT_OK(iree_hal_amdgpu_queue_affinity_normalize_for_physical_device(
+      TwoDeviceDomain(), IREE_HAL_QUEUE_AFFINITY_ANY, &queue_affinity,
+      &physical_device_ordinal));
+  EXPECT_EQ(queue_affinity, 0x3ull);
+  EXPECT_EQ(physical_device_ordinal, 0);
+}
+
+TEST(QueueAffinityTest, NormalizeExplicitForPhysicalDeviceKeepsSelectedBits) {
+  iree_hal_queue_affinity_t queue_affinity = 0;
+  iree_host_size_t physical_device_ordinal = 0;
+  IREE_ASSERT_OK(iree_hal_amdgpu_queue_affinity_normalize_for_physical_device(
+      TwoDeviceDomain(), 0xCull, &queue_affinity, &physical_device_ordinal));
+  EXPECT_EQ(queue_affinity, 0xCull);
+  EXPECT_EQ(physical_device_ordinal, 1);
+}
+
+TEST(QueueAffinityTest, NormalizeForPhysicalDeviceRejectsCrossDeviceMask) {
+  iree_hal_queue_affinity_t queue_affinity = 0;
+  iree_host_size_t physical_device_ordinal = 0;
+  IREE_EXPECT_STATUS_IS(
+      IREE_STATUS_UNIMPLEMENTED,
+      iree_hal_amdgpu_queue_affinity_normalize_for_physical_device(
+          TwoDeviceDomain(), 0x5ull, &queue_affinity,
+          &physical_device_ordinal));
+}
+
+TEST(QueueAffinityTest, DeviceLocalAffinity) {
+  EXPECT_TRUE(iree_hal_amdgpu_queue_affinity_is_physical_device_local(
+      TwoDeviceDomain(), 0xCull, 1));
+  EXPECT_FALSE(iree_hal_amdgpu_queue_affinity_is_physical_device_local(
+      TwoDeviceDomain(), 0x5ull, 1));
+  EXPECT_FALSE(iree_hal_amdgpu_queue_affinity_is_physical_device_local(
+      TwoDeviceDomain(), IREE_HAL_QUEUE_AFFINITY_ANY, 1));
+}
+
+}  // namespace
+}  // namespace iree::hal::amdgpu
diff --git a/runtime/src/iree/hal/drivers/amdgpu/registration/driver_module.c b/runtime/src/iree/hal/drivers/amdgpu/registration/driver_module.c
index 66ea420..3189cb5 100644
--- a/runtime/src/iree/hal/drivers/amdgpu/registration/driver_module.c
+++ b/runtime/src/iree/hal/drivers/amdgpu/registration/driver_module.c
@@ -20,6 +20,9 @@
 IREE_FLAG(int64_t, amdgpu_host_block_pool_large_size, 0,
           "Size in bytes of a large host block in the pool. Must be a power of "
           "two or 0 for the default.");
+IREE_FLAG(int64_t, amdgpu_host_block_pool_command_buffer_size, 0,
+          "Usable size in bytes of a command-buffer recording block in the "
+          "host block pool. Must be a power of two or 0 for the default.");
 
 IREE_FLAG(int64_t, amdgpu_device_block_pool_small_size, 0,
           "Size in bytes of a small device block in the pool. Must be a power "
@@ -34,27 +37,56 @@
           "Initial large block pool block allocation count in blocks or 0 for "
           "the default.");
 
+IREE_FLAG(int64_t, amdgpu_default_pool_range_length, 0,
+          "Logical byte length of the default TLSF queue-allocation pool per "
+          "physical device or 0 for the default.");
+IREE_FLAG(int64_t, amdgpu_default_pool_alignment, 0,
+          "Minimum byte alignment for default-pool reservations. Must be a "
+          "power of two or 0 for the default.");
+IREE_FLAG(int32_t, amdgpu_default_pool_frontier_capacity, 0,
+          "Maximum death-frontier entry count stored per free default-pool "
+          "block or 0 for the default.");
+
 IREE_FLAG(string, amdgpu_queue_placement, "any",
-          "Device queue placement: 'any' (best possible based on topology), "
-          "'host', or 'device'.");
+          "Device queue placement: 'any' (currently host), 'host', or "
+          "'device' (reserved and currently unsupported).");
 
 IREE_FLAG(bool, amdgpu_preallocate_pools, true,
           "Preallocates a reasonable number of resources in pools to reduce "
           "initial execution latency.");
 
-IREE_FLAG(bool, amdgpu_trace_execution, false,
-          "Enables dispatch-level tracing (if device instrumentation is "
-          "compiled in).");
-
 IREE_FLAG(bool, amdgpu_exclusive_execution, false,
-          "Forces queues to run one entry at a time instead of overlapping or "
-          "aggressively scheduling queue entries out-of-order.");
+          "Reserved for exclusive queue scheduling; currently unsupported.");
+
+IREE_FLAG(
+    bool, amdgpu_force_wait_barrier_defer, false,
+    "Forces cross-queue wait barriers through the software deferral path "
+    "instead of using the device-side strategy selected from the GPU ISA.");
 
 IREE_FLAG(int64_t, amdgpu_wait_active_for_ns, 0,
-          "Uses HSA_WAIT_STATE_ACTIVE for up to duration before switching to "
-          "HSA_WAIT_STATE_BLOCKED. >0 will increase CPU usage in cases where "
-          "the waits are long and decrease latency in cases where "
-          "the waits are short.");
+          "Reserved for future HSA active-wait tuning. Must be 0 today.");
+
+static iree_status_t iree_hal_amdgpu_flag_int64_to_host_size(
+    const char* flag_name, int64_t flag_value, iree_host_size_t* out_value) {
+  if (flag_value < 0) {
+    return iree_make_status(IREE_STATUS_OUT_OF_RANGE,
+                            "%s must be non-negative (got %" PRIi64 ")",
+                            flag_name, flag_value);
+  }
+  *out_value = (iree_host_size_t)flag_value;
+  return iree_ok_status();
+}
+
+static iree_status_t iree_hal_amdgpu_flag_int64_to_device_size(
+    const char* flag_name, int64_t flag_value, iree_device_size_t* out_value) {
+  if (flag_value < 0) {
+    return iree_make_status(IREE_STATUS_OUT_OF_RANGE,
+                            "%s must be non-negative (got %" PRIi64 ")",
+                            flag_name, flag_value);
+  }
+  *out_value = (iree_device_size_t)flag_value;
+  return iree_ok_status();
+}
 
 static iree_status_t iree_hal_amdgpu_driver_factory_enumerate(
     void* self, iree_host_size_t* out_driver_info_count,
@@ -85,29 +117,75 @@
   options.libhsa_search_paths = FLAG_amdgpu_libhsa_search_path_list();
 
   if (FLAG_amdgpu_host_block_pool_small_size) {
-    device_options->host_block_pools.small.block_size =
-        FLAG_amdgpu_host_block_pool_small_size;
+    IREE_RETURN_IF_ERROR(iree_hal_amdgpu_flag_int64_to_host_size(
+        "--amdgpu_host_block_pool_small_size",
+        FLAG_amdgpu_host_block_pool_small_size,
+        &device_options->host_block_pools.small.block_size));
   }
   if (FLAG_amdgpu_host_block_pool_large_size) {
-    device_options->host_block_pools.large.block_size =
-        FLAG_amdgpu_host_block_pool_large_size;
+    IREE_RETURN_IF_ERROR(iree_hal_amdgpu_flag_int64_to_host_size(
+        "--amdgpu_host_block_pool_large_size",
+        FLAG_amdgpu_host_block_pool_large_size,
+        &device_options->host_block_pools.large.block_size));
+  }
+  if (FLAG_amdgpu_host_block_pool_command_buffer_size) {
+    IREE_RETURN_IF_ERROR(iree_hal_amdgpu_flag_int64_to_host_size(
+        "--amdgpu_host_block_pool_command_buffer_size",
+        FLAG_amdgpu_host_block_pool_command_buffer_size,
+        &device_options->host_block_pools.command_buffer.usable_block_size));
   }
 
   if (FLAG_amdgpu_device_block_pool_small_size) {
-    device_options->device_block_pools.small.block_size =
-        FLAG_amdgpu_device_block_pool_small_size;
+    IREE_RETURN_IF_ERROR(iree_hal_amdgpu_flag_int64_to_device_size(
+        "--amdgpu_device_block_pool_small_size",
+        FLAG_amdgpu_device_block_pool_small_size,
+        &device_options->device_block_pools.small.block_size));
   }
   if (FLAG_amdgpu_device_block_pool_small_capacity) {
-    device_options->device_block_pools.small.initial_capacity =
-        FLAG_amdgpu_device_block_pool_small_capacity;
+    IREE_RETURN_IF_ERROR(iree_hal_amdgpu_flag_int64_to_host_size(
+        "--amdgpu_device_block_pool_small_capacity",
+        FLAG_amdgpu_device_block_pool_small_capacity,
+        &device_options->device_block_pools.small.initial_capacity));
   }
   if (FLAG_amdgpu_device_block_pool_large_size) {
-    device_options->device_block_pools.large.block_size =
-        FLAG_amdgpu_device_block_pool_large_size;
+    IREE_RETURN_IF_ERROR(iree_hal_amdgpu_flag_int64_to_device_size(
+        "--amdgpu_device_block_pool_large_size",
+        FLAG_amdgpu_device_block_pool_large_size,
+        &device_options->device_block_pools.large.block_size));
   }
   if (FLAG_amdgpu_device_block_pool_large_capacity) {
-    device_options->device_block_pools.large.initial_capacity =
-        FLAG_amdgpu_device_block_pool_large_capacity;
+    IREE_RETURN_IF_ERROR(iree_hal_amdgpu_flag_int64_to_host_size(
+        "--amdgpu_device_block_pool_large_capacity",
+        FLAG_amdgpu_device_block_pool_large_capacity,
+        &device_options->device_block_pools.large.initial_capacity));
+  }
+
+  if (FLAG_amdgpu_default_pool_range_length) {
+    IREE_RETURN_IF_ERROR(iree_hal_amdgpu_flag_int64_to_device_size(
+        "--amdgpu_default_pool_range_length",
+        FLAG_amdgpu_default_pool_range_length,
+        &device_options->default_pool.range_length));
+  }
+  if (FLAG_amdgpu_default_pool_alignment) {
+    IREE_RETURN_IF_ERROR(iree_hal_amdgpu_flag_int64_to_device_size(
+        "--amdgpu_default_pool_alignment", FLAG_amdgpu_default_pool_alignment,
+        &device_options->default_pool.alignment));
+  }
+  if (FLAG_amdgpu_default_pool_frontier_capacity) {
+    if (FLAG_amdgpu_default_pool_frontier_capacity < 0) {
+      return iree_make_status(
+          IREE_STATUS_OUT_OF_RANGE,
+          "default pool frontier capacity must be non-negative (got %d)",
+          FLAG_amdgpu_default_pool_frontier_capacity);
+    }
+    if (FLAG_amdgpu_default_pool_frontier_capacity > UINT8_MAX) {
+      return iree_make_status(
+          IREE_STATUS_OUT_OF_RANGE,
+          "default pool frontier capacity %d exceeds maximum %u",
+          FLAG_amdgpu_default_pool_frontier_capacity, UINT8_MAX);
+    }
+    device_options->default_pool.frontier_capacity =
+        (uint8_t)FLAG_amdgpu_default_pool_frontier_capacity;
   }
 
   if (strcmp(FLAG_amdgpu_queue_placement, "any") == 0) {
@@ -124,10 +202,17 @@
 
   device_options->preallocate_pools = FLAG_amdgpu_preallocate_pools;
 
-  device_options->trace_execution = FLAG_amdgpu_trace_execution;
-
   device_options->exclusive_execution = FLAG_amdgpu_exclusive_execution;
 
+  device_options->force_wait_barrier_defer =
+      FLAG_amdgpu_force_wait_barrier_defer;
+
+  if (FLAG_amdgpu_wait_active_for_ns < 0) {
+    return iree_make_status(
+        IREE_STATUS_OUT_OF_RANGE,
+        "--amdgpu_wait_active_for_ns must be non-negative (got %" PRIi64 ")",
+        FLAG_amdgpu_wait_active_for_ns);
+  }
   device_options->wait_active_for_ns = FLAG_amdgpu_wait_active_for_ns;
 
   iree_status_t status = iree_hal_amdgpu_driver_create(
diff --git a/runtime/src/iree/hal/drivers/amdgpu/semaphore.c b/runtime/src/iree/hal/drivers/amdgpu/semaphore.c
index 2dc04ba..6b1dbc0 100644
--- a/runtime/src/iree/hal/drivers/amdgpu/semaphore.c
+++ b/runtime/src/iree/hal/drivers/amdgpu/semaphore.c
@@ -1,4 +1,4 @@
-// Copyright 2025 The IREE Authors
+// Copyright 2026 The IREE Authors
 //
 // Licensed under the Apache License v2.0 with LLVM Exceptions.
 // See https://llvm.org/LICENSE.txt for license information.
@@ -6,484 +6,288 @@
 
 #include "iree/hal/drivers/amdgpu/semaphore.h"
 
-#include "iree/hal/drivers/amdgpu/device/semaphore.h"
-
-// Frees an iree_status_t encoded in an HSA signal value, if any.
-// The STATUS_BIT indicates that the signal value contains an encoded pointer.
-static inline void iree_hal_amdgpu_semaphore_failure_free(uint64_t value) {
-  if (value & IREE_HAL_SEMAPHORE_FAILURE_VALUE_STATUS_BIT) {
-    iree_status_free((iree_status_t)(intptr_t)(((int64_t)value << 1) >> 1));
-  }
-}
+#include "iree/hal/drivers/amdgpu/util/notification_ring.h"
 
 //===----------------------------------------------------------------------===//
-// iree_hal_amdgpu_internal_semaphore_t
+// iree_hal_amdgpu_semaphore_t
 //===----------------------------------------------------------------------===//
 
-static const iree_hal_semaphore_vtable_t
-    iree_hal_amdgpu_internal_semaphore_vtable;
+typedef struct iree_hal_amdgpu_semaphore_t {
+  // Embedded async semaphore at offset 0 for toll-free bridging.
+  iree_async_semaphore_t async;
 
-static iree_hal_amdgpu_internal_semaphore_t*
-iree_hal_amdgpu_internal_semaphore_cast(iree_hal_semaphore_t* base_value) {
-  IREE_HAL_ASSERT_TYPE(base_value, &iree_hal_amdgpu_internal_semaphore_vtable);
-  return (iree_hal_amdgpu_internal_semaphore_t*)base_value;
+  // Allocator used to free this semaphore.
+  iree_allocator_t host_allocator;
+
+  // Back-pointer to the logical device that created this semaphore.
+  // Used for type discrimination (is_local check). Not retained.
+  iree_hal_amdgpu_logical_device_t* device;
+
+  // Creation flags controlling synchronization behavior.
+  iree_hal_semaphore_flags_t flags;
+
+  // Queue affinity provided at creation. When DEVICE_LOCAL is set this is the
+  // complete set of queues that may legally signal or wait on the semaphore.
+  iree_hal_queue_affinity_t queue_affinity;
+
+  // Seqlock-protected cache of the most recent signal from a queue.
+  // Updated by the submission path when queue_execute signals this semaphore.
+  // Read by the submission path for same-queue FIFO elision, cross-queue
+  // epoch lookup, and by the host-wait fast path for direct signal waits.
+  // Initialized to zero (flags = 0); no valid signal has been recorded yet.
+  iree_hal_amdgpu_last_signal_t last_signal;
+} iree_hal_amdgpu_semaphore_t;
+
+static const iree_hal_semaphore_vtable_t iree_hal_amdgpu_semaphore_vtable;
+
+static iree_hal_amdgpu_semaphore_t* iree_hal_amdgpu_semaphore_cast(
+    iree_hal_semaphore_t* base_value) {
+  IREE_HAL_ASSERT_TYPE(base_value, &iree_hal_amdgpu_semaphore_vtable);
+  return (iree_hal_amdgpu_semaphore_t*)base_value;
 }
 
-iree_status_t iree_hal_amdgpu_internal_semaphore_initialize(
-    const iree_hal_amdgpu_libhsa_t* libhsa,
-    iree_hal_amdgpu_semaphore_options_t options,
-    iree_hal_semaphore_flags_t flags,
-    IREE_AMDGPU_DEVICE_PTR iree_hal_amdgpu_device_semaphore_t* device_semaphore,
-    iree_hal_amdgpu_internal_semaphore_release_callback_t release_callback,
-    iree_hal_amdgpu_internal_semaphore_t* out_semaphore) {
-  IREE_ASSERT_ARGUMENT(libhsa);
-  IREE_ASSERT_ARGUMENT(device_semaphore);
+iree_status_t iree_hal_amdgpu_semaphore_create(
+    iree_hal_amdgpu_logical_device_t* device, iree_async_proactor_t* proactor,
+    iree_hal_queue_affinity_t queue_affinity, uint64_t initial_value,
+    iree_hal_semaphore_flags_t flags, iree_allocator_t host_allocator,
+    iree_hal_semaphore_t** out_semaphore) {
+  IREE_ASSERT_ARGUMENT(device);
+  IREE_ASSERT_ARGUMENT(proactor);
   IREE_ASSERT_ARGUMENT(out_semaphore);
+  *out_semaphore = NULL;
   IREE_TRACE_ZONE_BEGIN(z0);
 
-  memset(out_semaphore, 0, sizeof(*out_semaphore));
-
-  // Create the HSA signal.
-  // HSA has two signal types: default and interrupt. A default signal is like a
-  // futex and is relatively light-weight but the host can only busy-wait on it.
-  // An interrupt signal involves the OS but allows for platform-level waits.
-  //
-  // TODO(benvanik): add a semaphore flag for device-only? It's hard to know
-  // that in all cases but in the compiler we could do it for our locally-scoped
-  // ones. We aggressively pool semaphores and don't track if there's host
-  // waiters so for today we just take the hit and always use interrupt signals.
-  // If we wanted device-only we'd set the HSA_AMD_SIGNAL_AMD_GPU_ONLY flag.
-  uint64_t signal_flags = 0;
-  IREE_RETURN_AND_END_ZONE_IF_ERROR(
-      z0, iree_hsa_amd_signal_create(
-              IREE_LIBHSA(libhsa), /*initial_value=*/0ull,
-              /*num_consumers=*/0,
-              /*consumers=*/NULL, signal_flags, &out_semaphore->signal));
-
-  iree_hal_resource_initialize(&iree_hal_amdgpu_internal_semaphore_vtable,
-                               &out_semaphore->resource);
-  // Pooling behavior: maintain a 0 ref count until acquired.
-  iree_atomic_ref_count_dec(&out_semaphore->resource.ref_count);
-  out_semaphore->libhsa = libhsa;
-  out_semaphore->options = options;
-  out_semaphore->flags = flags;
-  out_semaphore->device_semaphore = device_semaphore;
-  out_semaphore->release_callback = release_callback;
-
-  // NOTE: today we assume the semaphore device memory is host-accessible. In
-  // the future we may make device-only semaphores and would need to do a
-  // host-to-device transfer to update the device semaphore values.
-  memset(device_semaphore, 0, sizeof(*device_semaphore));
-  device_semaphore->signal = (iree_amd_signal_t*)out_semaphore->signal.handle;
-  device_semaphore->host_semaphore = (uint64_t)out_semaphore;
-
-  IREE_TRACE_ZONE_END(z0);
-  return iree_ok_status();
-}
-
-void iree_hal_amdgpu_internal_semaphore_deinitialize(
-    iree_hal_amdgpu_internal_semaphore_t* semaphore) {
-  IREE_TRACE_ZONE_BEGIN(z0);
-
-  const iree_hal_amdgpu_libhsa_t* libhsa = semaphore->libhsa;
-
-  IREE_IGNORE_ERROR(
-      iree_hsa_signal_destroy(IREE_LIBHSA(libhsa), semaphore->signal));
-
-  IREE_TRACE_ZONE_END(z0);
-}
-
-static void iree_hal_amdgpu_internal_semaphore_destroy(
-    iree_async_semaphore_t* base_semaphore) {
-  iree_hal_amdgpu_internal_semaphore_t* semaphore =
-      iree_hal_amdgpu_internal_semaphore_cast(
-          iree_hal_semaphore_cast(base_semaphore));
-  IREE_TRACE_ZONE_BEGIN(z0);
-
-  // If the semaphore failed we need to free the status object, if any.
-  // The signal will be reset to a new initial value if it is reused.
-  const hsa_signal_value_t old_value = iree_hsa_signal_exchange_scacquire(
-      IREE_LIBHSA(semaphore->libhsa), semaphore->signal, 0);
-  iree_hal_amdgpu_semaphore_failure_free((uint64_t)old_value);
-
-  // Use the provided release callback to free or recycle the semaphore.
-  if (semaphore->release_callback.fn) {
-    semaphore->release_callback.fn(semaphore->release_callback.user_data,
-                                   semaphore);
+  iree_hal_amdgpu_semaphore_t* semaphore = NULL;
+  iree_host_size_t frontier_offset = 0, total_size = 0;
+  // Match the queue frontier/snapshot capacity so publishing a full queue
+  // frontier into a semaphore does not overflow just because the semaphore was
+  // allocated with a narrower async default.
+  iree_status_t status = iree_async_semaphore_layout(
+      sizeof(*semaphore), IREE_HAL_AMDGPU_MAX_FRONTIER_SNAPSHOT_ENTRY_COUNT,
+      &frontier_offset, &total_size);
+  if (iree_status_is_ok(status)) {
+    status =
+        iree_allocator_malloc(host_allocator, total_size, (void**)&semaphore);
   }
-
-  IREE_TRACE_ZONE_END(z0);
-}
-
-bool iree_hal_amdgpu_internal_semaphore_isa(iree_hal_semaphore_t* semaphore) {
-  return iree_hal_resource_is(semaphore,
-                              &iree_hal_amdgpu_internal_semaphore_vtable);
-}
-
-void iree_hal_amdgpu_internal_semaphore_reset(
-    iree_hal_amdgpu_internal_semaphore_t* semaphore, uint64_t initial_value) {
-  // Reset the HSA signal value to the user-provided initial value.
-  // Note that this is just a store here as we've already cleared any status
-  // that may have been embedded in the value prior to it being returned to the
-  // pool. We do a silent store here as no one should be waiting on the signal
-  // and they don't need to be notified.
-  //
-  // NOTE: ROCR implements the silent calls by just routing to the normal ones
-  // so this isn't actually silent. Darn.
-  // https://github.com/ROCm/ROCR-Runtime/issues/316
-  iree_hsa_signal_silent_store_screlease(IREE_LIBHSA(semaphore->libhsa),
-                                         semaphore->signal, initial_value);
-}
-
-static uint64_t iree_hal_amdgpu_internal_semaphore_query(
-    iree_async_semaphore_t* base_semaphore) {
-  iree_hal_amdgpu_internal_semaphore_t* semaphore =
-      iree_hal_amdgpu_internal_semaphore_cast(
-          iree_hal_semaphore_cast(base_semaphore));
-
-  // Load the HSA signal value with acquire semantics.
-  // If the semaphore has failed the value will be >=
-  // IREE_HAL_SEMAPHORE_FAILURE_VALUE and the HAL dispatch layer will convert
-  // that to the appropriate error status.
-  hsa_signal_value_t current_value = iree_hsa_signal_load_scacquire(
-      IREE_LIBHSA(semaphore->libhsa), semaphore->signal);
-  return (uint64_t)current_value;
-}
-
-static iree_status_t iree_hal_amdgpu_internal_semaphore_signal(
-    iree_async_semaphore_t* base_semaphore, uint64_t new_value,
-    const iree_async_frontier_t* frontier) {
-  (void)frontier;
-  iree_hal_amdgpu_internal_semaphore_t* semaphore =
-      iree_hal_amdgpu_internal_semaphore_cast(
-          iree_hal_semaphore_cast(base_semaphore));
-
-  // Check that we are incrementing the value. This also handles cases where the
-  // signal has failed as then the current value will always be larger than
-  // whatever value we are setting it to.
-  hsa_signal_value_t current_value = iree_hsa_signal_load_relaxed(
-      IREE_LIBHSA(semaphore->libhsa), semaphore->signal);
-  while (current_value != new_value) {
-    if (new_value < current_value) {
-      return iree_make_status(
-          IREE_STATUS_FAILED_PRECONDITION,
-          "semaphore signal requested to an older value; "
-          "semaphores must be monotonically increasing (previous=%" PRIu64
-          ", new=%" PRIu64 ")",
-          current_value, new_value);
-    }
-    // We update the signal to the new value (and notify host waiters) with a
-    // CAS. Immediately upon store some host thread or device agent may
-    // immediately wake and process whatever data is being signaled as
-    // available. If someone else came in and updated the value before us the
-    // CAS will fail and we'll try again (unless doing so would be invalid).
-    const hsa_signal_value_t observed_value = iree_hsa_signal_cas_scacq_screl(
-        IREE_LIBHSA(semaphore->libhsa), semaphore->signal, current_value,
-        (hsa_signal_value_t)new_value);
-    if (observed_value == current_value) {
-      // Swap took place.
-      break;
-    }
-    current_value = observed_value;  // try again
-  }
-
-  // TODO(benvanik): update device-side semaphore entry and wake any schedulers
-  // registered with it.
-
-  return iree_ok_status();
-}
-
-static void iree_hal_amdgpu_internal_semaphore_on_fail(
-    iree_async_semaphore_t* base_semaphore, iree_status_code_t status_code) {
-  (void)status_code;
-  iree_hal_amdgpu_internal_semaphore_t* semaphore =
-      iree_hal_amdgpu_internal_semaphore_cast(
-          iree_hal_semaphore_cast(base_semaphore));
-
-  // CAS the HSA signal to the failure sentinel so device-side waiters wake.
-  hsa_signal_value_t current_value = iree_hsa_signal_load_scacquire(
-      IREE_LIBHSA(semaphore->libhsa), semaphore->signal);
-  while (current_value < IREE_HAL_SEMAPHORE_FAILURE_VALUE) {
-    const hsa_signal_value_t observed_value = iree_hsa_signal_cas_scacq_screl(
-        IREE_LIBHSA(semaphore->libhsa), semaphore->signal, current_value,
-        (hsa_signal_value_t)IREE_HAL_SEMAPHORE_FAILURE_VALUE);
-    if (observed_value == current_value) break;
-    current_value = observed_value;
-  }
-}
-
-static iree_status_t iree_hal_amdgpu_internal_semaphore_wait(
-    iree_hal_semaphore_t* base_semaphore, uint64_t value,
-    iree_timeout_t timeout, iree_async_wait_flags_t flags) {
-  iree_hal_amdgpu_internal_semaphore_t* semaphore =
-      iree_hal_amdgpu_internal_semaphore_cast(base_semaphore);
-  iree_hal_semaphore_list_t semaphore_list = {
-      .count = 1,
-      .semaphores = &base_semaphore,
-      .payload_values = &value,
-  };
-  return iree_hal_amdgpu_wait_semaphores(semaphore->libhsa, semaphore->options,
-                                         IREE_ASYNC_WAIT_MODE_ALL,
-                                         semaphore_list, timeout, flags);
-}
-
-static iree_status_t iree_hal_amdgpu_internal_semaphore_import_timepoint(
-    iree_hal_semaphore_t* base_semaphore, uint64_t value,
-    iree_hal_queue_affinity_t queue_affinity,
-    iree_hal_external_timepoint_t external_timepoint) {
-  (void)base_semaphore;
-  (void)value;
-  (void)queue_affinity;
-  (void)external_timepoint;
-  return iree_make_status(IREE_STATUS_UNIMPLEMENTED,
-                          "timepoint import not implemented");
-}
-
-static iree_status_t iree_hal_amdgpu_internal_semaphore_export_timepoint(
-    iree_hal_semaphore_t* base_semaphore, uint64_t value,
-    iree_hal_queue_affinity_t queue_affinity,
-    iree_hal_external_timepoint_type_t requested_type,
-    iree_hal_external_timepoint_flags_t requested_flags,
-    iree_hal_external_timepoint_t* IREE_RESTRICT out_external_timepoint) {
-  (void)base_semaphore;
-  (void)value;
-  (void)queue_affinity;
-  (void)requested_type;
-  (void)requested_flags;
-  (void)out_external_timepoint;
-  return iree_make_status(IREE_STATUS_UNIMPLEMENTED,
-                          "timepoint export not implemented");
-}
-
-static const iree_hal_semaphore_vtable_t
-    iree_hal_amdgpu_internal_semaphore_vtable = {
-        .async =
-            {
-                .destroy = iree_hal_amdgpu_internal_semaphore_destroy,
-                .query = iree_hal_amdgpu_internal_semaphore_query,
-                .signal = iree_hal_amdgpu_internal_semaphore_signal,
-                .on_fail = iree_hal_amdgpu_internal_semaphore_on_fail,
-            },
-        .wait = iree_hal_amdgpu_internal_semaphore_wait,
-        .import_timepoint = iree_hal_amdgpu_internal_semaphore_import_timepoint,
-        .export_timepoint = iree_hal_amdgpu_internal_semaphore_export_timepoint,
-};
-
-//===----------------------------------------------------------------------===//
-// iree_hal_amdgpu_external_semaphore_t
-//===----------------------------------------------------------------------===//
-
-// TODO(benvanik): external imported semaphore wrapper.
-
-//===----------------------------------------------------------------------===//
-// Semaphore Operations
-//===----------------------------------------------------------------------===//
-
-iree_status_t iree_hal_amdgpu_semaphore_handle(
-    iree_hal_semaphore_t* base_semaphore,
-    IREE_AMDGPU_DEVICE_PTR iree_hal_amdgpu_device_semaphore_t** out_handle) {
-  if (iree_hal_amdgpu_internal_semaphore_isa(base_semaphore)) {
-    iree_hal_amdgpu_internal_semaphore_t* semaphore =
-        (iree_hal_amdgpu_internal_semaphore_t*)base_semaphore;
-    *out_handle = semaphore->device_semaphore;
-    return iree_ok_status();
-  }
-  return iree_make_status(IREE_STATUS_INVALID_ARGUMENT,
-                          "semaphore is not from the AMDGPU backend and has no "
-                          "corresponding device handle");
-}
-
-iree_status_t iree_hal_amdgpu_semaphore_hsa_signal(
-    iree_hal_semaphore_t* base_semaphore, hsa_signal_t* out_signal) {
-  if (iree_hal_amdgpu_internal_semaphore_isa(base_semaphore)) {
-    iree_hal_amdgpu_internal_semaphore_t* semaphore =
-        (iree_hal_amdgpu_internal_semaphore_t*)base_semaphore;
-    *out_signal = semaphore->signal;
-    return iree_ok_status();
-  }
-  return iree_make_status(IREE_STATUS_INVALID_ARGUMENT,
-                          "semaphore is not from the AMDGPU backend and has no "
-                          "corresponding HSA signal");
-}
-
-iree_status_t iree_hal_amdgpu_poll_semaphore(
-    iree_hal_semaphore_t* base_semaphore, uint64_t* out_current_value) {
-  if (iree_hal_amdgpu_internal_semaphore_isa(base_semaphore)) {
-    iree_hal_amdgpu_internal_semaphore_t* semaphore =
-        (iree_hal_amdgpu_internal_semaphore_t*)base_semaphore;
-    hsa_signal_value_t current_value = iree_hsa_signal_load_scacquire(
-        IREE_LIBHSA(semaphore->libhsa), semaphore->signal);
-    if (IREE_UNLIKELY(current_value >= IREE_HAL_SEMAPHORE_FAILURE_VALUE)) {
-      // If the semaphore failed then interpret the failure as an IREE status
-      // object and clone it for the caller.
-      return iree_hal_semaphore_failure_as_status(current_value);
-    }
-    *out_current_value = (uint64_t)current_value;
-    return iree_ok_status();
-  }
-  return iree_make_status(
-      IREE_STATUS_INVALID_ARGUMENT,
-      "only AMDGPU semaphores are supported; a fallback path for mixed "
-      "semaphores would be needed for polling");
-}
-
-iree_status_t iree_hal_amdgpu_poll_semaphores(
-    iree_async_wait_mode_t wait_mode,
-    const iree_hal_semaphore_list_t semaphore_list) {
-  // Poll every semaphore and check the >= condition.
-  // In wait-any mode the first satisfied condition will return OK.
-  // In wait-all mode the first unsatisfied condition will return
-  // DEADLINE_EXCEEDED.
-  for (iree_host_size_t i = 0; i < semaphore_list.count; ++i) {
-    uint64_t current_value = 0ull;
-    IREE_RETURN_IF_ERROR(iree_hal_amdgpu_poll_semaphore(
-        semaphore_list.semaphores[i], &current_value));
-    if (current_value >= semaphore_list.payload_values[i]) {
-      // Satisfied.
-      if (wait_mode == IREE_ASYNC_WAIT_MODE_ANY) {
-        // Only one semaphore needs to be reached in wait-any mode.
-        return iree_ok_status();
-      }
-    } else {
-      // Unsatisfied.
-      if (wait_mode == IREE_ASYNC_WAIT_MODE_ALL) {
-        // All semaphores need to be reached in wait-all mode.
-        return iree_status_from_code(IREE_STATUS_DEADLINE_EXCEEDED);
-      }
-    }
-  }
-  // In wait-any mode if none were satisfied then return DEADLINE_EXCEEDED.
-  // In wait-all mode if none were unsatisfied then return OK.
-  return wait_mode == IREE_ASYNC_WAIT_MODE_ANY
-             ? iree_status_from_code(IREE_STATUS_DEADLINE_EXCEEDED)
-             : iree_ok_status();
-}
-
-// The ROCR implementation of wait multiple is really bad/slow. We should really
-// get that rewritten if we want to continue using it. We could easily go to
-// hsaKmtWaitOnMultipleEvents_Ext ourselves but the special wait flag handling
-// in core::Signal is something we can't directly touch. I'm not sure we
-// actually need that, though, given that we have no way of keeping that in sync
-// with device-side waits.
-iree_status_t iree_hal_amdgpu_wait_semaphores(
-    const iree_hal_amdgpu_libhsa_t* libhsa,
-    iree_hal_amdgpu_semaphore_options_t options,
-    iree_async_wait_mode_t wait_mode,
-    const iree_hal_semaphore_list_t semaphore_list, iree_timeout_t timeout,
-    iree_async_wait_flags_t flags) {
-  IREE_ASSERT_ARGUMENT(libhsa);
-  if (semaphore_list.count == 0) return iree_ok_status();  // no-op
-  IREE_TRACE_ZONE_BEGIN(z0);
-  IREE_TRACE_ZONE_APPEND_VALUE_I64(z0, iree_timeout_as_duration_ns(timeout));
-
-  // Fast-path for immediate timeouts using this API to poll.
-  if (iree_timeout_is_immediate(timeout)) {
-    iree_status_t poll_status =
-        iree_hal_amdgpu_poll_semaphores(wait_mode, semaphore_list);
-    IREE_TRACE_ZONE_END(z0);
-    return poll_status;
-  }
-
-  // TODO(benvanik): use options.wait_active_for_ns to spin locally before we
-  // call into ROCR (which is significantly more expensive).
-  const hsa_wait_state_t wait_state =
-      options.wait_active_for_ns == IREE_DURATION_INFINITE
-          ? HSA_WAIT_STATE_ACTIVE
-          : HSA_WAIT_STATE_BLOCKED;
-  const iree_duration_t timeout_duration_ns =
-      iree_timeout_is_infinite(timeout) ? UINT64_MAX
-                                        : iree_timeout_as_duration_ns(timeout);
-
-  // Fast-path for single semaphore waits.
-  // ROCR's multi-wait is inefficient and we really want to avoid it if
-  // possible.
-  if (semaphore_list.count == 1) {
-    hsa_signal_t signal;
-    IREE_RETURN_AND_END_ZONE_IF_ERROR(
-        z0,
-        iree_hal_amdgpu_semaphore_hsa_signal(semaphore_list.semaphores[0],
-                                             &signal),
-        "retrieving HSA signal from semaphore");
-    hsa_signal_value_t expected_value = semaphore_list.payload_values[0];
-    iree_status_t status = iree_ok_status();
-    hsa_signal_value_t current_value = iree_hsa_signal_wait_scacquire(
-        IREE_LIBHSA(libhsa), signal, HSA_SIGNAL_CONDITION_GTE, expected_value,
-        timeout_duration_ns, wait_state);
-    if (IREE_UNLIKELY(current_value >= IREE_HAL_SEMAPHORE_FAILURE_VALUE)) {
-      // If the semaphore failed then interpret the failure as an IREE status
-      // object and clone it for the caller.
-      status = iree_hal_semaphore_failure_as_status(current_value);
-    } else if (current_value < expected_value) {
-      // Assume timeout. It may be a spurious wake and we should try again until
-      // the timeout duration has been reached.
-      // TODO(benvanik): retry while timeout remaining.
-      status = iree_status_from_code(IREE_STATUS_DEADLINE_EXCEEDED);
-    }
-    IREE_TRACE_ZONE_END(z0);
-    return status;
-  }
-
-  // Build array-of-structs for the individual wait operations.
-  hsa_signal_t* signals =
-      iree_alloca(semaphore_list.count * sizeof(hsa_signal_t));
-  hsa_signal_condition_t* conds =
-      iree_alloca(semaphore_list.count * sizeof(hsa_signal_condition_t));
-  hsa_signal_value_t* values =
-      iree_alloca(semaphore_list.count * sizeof(hsa_signal_value_t));
-  for (iree_host_size_t i = 0; i < semaphore_list.count; ++i) {
-    IREE_RETURN_AND_END_ZONE_IF_ERROR(
-        z0,
-        iree_hal_amdgpu_semaphore_hsa_signal(semaphore_list.semaphores[i],
-                                             &signals[i]),
-        "retrieving HSA signal from semaphore");
-    conds[i] = HSA_SIGNAL_CONDITION_GTE;
-    values[i] = semaphore_list.payload_values[i];
-  }
-
-  // NOTE: hsa_amd_signal_wait_all/hsa_amd_signal_wait_any has relaxed memory
-  // semantics and to have the proper acquire behavior we need to load the
-  // signal value ourselves.
-  iree_status_t status = iree_ok_status();
-  switch (wait_mode) {
-    case IREE_ASYNC_WAIT_MODE_ALL: {
-      const uint32_t wait_result = iree_hsa_amd_signal_wait_all(
-          IREE_LIBHSA(libhsa), semaphore_list.count, signals, conds, values,
-          timeout_duration_ns, wait_state, /*satisfying_values=*/NULL);
-      if (wait_result == 0) {
-        // If the wait succeeded then check for errors.
-        // This also issues an acquire fence on every semaphore.
-        status = iree_hal_amdgpu_poll_semaphores(wait_mode, semaphore_list);
-      } else {
-        status = iree_status_from_code(IREE_STATUS_DEADLINE_EXCEEDED);
-      }
-    } break;
-    case IREE_ASYNC_WAIT_MODE_ANY: {
-      hsa_signal_value_t satisfying_value = 0;
-      const uint32_t satisfying_index = iree_hsa_amd_signal_wait_any(
-          IREE_LIBHSA(libhsa), semaphore_list.count, signals, conds, values,
-          timeout_duration_ns, wait_state, &satisfying_value);
-      if (satisfying_index != UINT32_MAX) {
-        // Issue an acquire fence on the satisfying semaphore. This will
-        // propagate errors if the wait succeeded because the semaphore was
-        // signaled to a failure value. We could reuse the satisfying_value
-        // above but we'd still need the acquire fence.
-        //
-        // Note that more than one semaphore make have had its condition
-        // satisfied and more than one may be in a failure state; this API
-        // doesn't exhaustively check.
-        uint64_t current_value = 0ull;
-        status = iree_hal_amdgpu_poll_semaphore(
-            semaphore_list.semaphores[satisfying_index], &current_value);
-      } else {
-        status = iree_status_from_code(IREE_STATUS_DEADLINE_EXCEEDED);
-      }
-    } break;
-    default:
-      status = iree_make_status(IREE_STATUS_INVALID_ARGUMENT,
-                                "unknown wait mode %d", (int)wait_mode);
-      break;
+  if (iree_status_is_ok(status)) {
+    iree_async_semaphore_initialize(
+        (const iree_async_semaphore_vtable_t*)&iree_hal_amdgpu_semaphore_vtable,
+        proactor, initial_value, frontier_offset,
+        IREE_HAL_AMDGPU_MAX_FRONTIER_SNAPSHOT_ENTRY_COUNT, &semaphore->async);
+    semaphore->host_allocator = host_allocator;
+    semaphore->device = device;
+    semaphore->flags = flags;
+    semaphore->queue_affinity = queue_affinity;
+    memset(&semaphore->last_signal, 0, sizeof(semaphore->last_signal));
+    *out_semaphore = iree_hal_semaphore_cast(&semaphore->async);
   }
 
   IREE_TRACE_ZONE_END(z0);
   return status;
 }
+
+static void iree_hal_amdgpu_semaphore_destroy(
+    iree_async_semaphore_t* base_semaphore) {
+  iree_hal_amdgpu_semaphore_t* semaphore =
+      iree_hal_amdgpu_semaphore_cast(iree_hal_semaphore_cast(base_semaphore));
+  iree_allocator_t host_allocator = semaphore->host_allocator;
+  IREE_TRACE_ZONE_BEGIN(z0);
+
+  iree_async_semaphore_deinitialize(&semaphore->async);
+  iree_allocator_free(host_allocator, semaphore);
+
+  IREE_TRACE_ZONE_END(z0);
+}
+
+bool iree_hal_amdgpu_semaphore_isa(iree_hal_semaphore_t* semaphore) {
+  return iree_hal_resource_is((const iree_hal_resource_t*)semaphore,
+                              &iree_hal_amdgpu_semaphore_vtable);
+}
+
+bool iree_hal_amdgpu_semaphore_is_local(
+    iree_hal_semaphore_t* semaphore,
+    const iree_hal_amdgpu_logical_device_t* device) {
+  return iree_hal_resource_is((const iree_hal_resource_t*)semaphore,
+                              &iree_hal_amdgpu_semaphore_vtable) &&
+         ((const iree_hal_amdgpu_semaphore_t*)semaphore)->device == device;
+}
+
+iree_hal_semaphore_flags_t iree_hal_amdgpu_semaphore_flags(
+    iree_hal_semaphore_t* semaphore) {
+  return ((const iree_hal_amdgpu_semaphore_t*)semaphore)->flags;
+}
+
+iree_hal_queue_affinity_t iree_hal_amdgpu_semaphore_queue_affinity(
+    iree_hal_semaphore_t* semaphore) {
+  return ((const iree_hal_amdgpu_semaphore_t*)semaphore)->queue_affinity;
+}
+
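+// Returns true when |semaphore| qualifies for the private in-stream signaling
+// path: it must be local to |device|, created DEVICE_LOCAL | SINGLE_PRODUCER,
+// and must not request host interrupts or any export capability.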
+bool iree_hal_amdgpu_semaphore_has_private_stream_semantics(
+    iree_hal_semaphore_t* semaphore,
+    const iree_hal_amdgpu_logical_device_t* device) {
+  if (!iree_hal_amdgpu_semaphore_is_local(semaphore, device)) return false;
+
+  const iree_hal_semaphore_flags_t flags =
+      iree_hal_amdgpu_semaphore_flags(semaphore);
+  const iree_hal_semaphore_flags_t required_flags =
+      IREE_HAL_SEMAPHORE_FLAG_DEVICE_LOCAL |
+      IREE_HAL_SEMAPHORE_FLAG_SINGLE_PRODUCER;
+  const iree_hal_semaphore_flags_t public_flags =
+      IREE_HAL_SEMAPHORE_FLAG_HOST_INTERRUPT |
+      IREE_HAL_SEMAPHORE_FLAG_EXPORTABLE |
+      IREE_HAL_SEMAPHORE_FLAG_EXPORTABLE_TIMEPOINTS;
+  return iree_all_bits_set(flags, required_flags) &&
+         !iree_any_bit_set(flags, public_flags);
+}
+
+iree_hal_amdgpu_last_signal_t* iree_hal_amdgpu_semaphore_last_signal(
+    iree_hal_semaphore_t* semaphore) {
+  return &((iree_hal_amdgpu_semaphore_t*)semaphore)->last_signal;
+}
+
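+// Publishes a queue-produced signal into the semaphore: merges the producer's
+// frontier into the semaphore frontier and caches an axis/epoch/value
+// snapshot in |last_signal|. When the producer dominates the merged frontier
+// the snapshot is marked PRODUCER_FRONTIER_EXACT so waiters can barrier on
+// the producer epoch alone. Returns false (leaving the cache invalid) if the
+// merge could not be performed.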
+bool iree_hal_amdgpu_semaphore_publish_signal(
+    iree_hal_semaphore_t* base_semaphore, iree_async_axis_t producer_axis,
+    const iree_async_frontier_t* producer_frontier, uint64_t producer_epoch,
+    uint64_t producer_value) {
+  IREE_ASSERT_ARGUMENT(producer_frontier);
+  iree_hal_amdgpu_semaphore_t* semaphore =
+      iree_hal_amdgpu_semaphore_cast(base_semaphore);
+
+  iree_hal_amdgpu_last_signal_flags_t flags =
+      IREE_HAL_AMDGPU_LAST_SIGNAL_FLAG_VALID;
+  bool source_dominates_frontier = false;
+  iree_slim_mutex_lock(&semaphore->async.mutex);
+  bool merged = iree_async_frontier_merge_and_test_source_dominance(
+      semaphore->async.frontier, semaphore->async.frontier_capacity,
+      producer_frontier, &source_dominates_frontier);
+  if (merged && source_dominates_frontier) {
+    flags |= IREE_HAL_AMDGPU_LAST_SIGNAL_FLAG_PRODUCER_FRONTIER_EXACT;
+  }
+  iree_hal_amdgpu_last_signal_store(
+      &semaphore->last_signal, merged ? flags : 0,
+      merged ? producer_axis : (iree_async_axis_t)0,
+      merged ? producer_epoch : 0, merged ? producer_value : 0);
+  iree_slim_mutex_unlock(&semaphore->async.mutex);
+
+  return merged;
+}
+
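+// Publishes a signal for a semaphore with private stream semantics. The
+// SINGLE_PRODUCER guarantee means no other writer can race this store, so
+// the seqlock write proceeds without the semaphore mutex and no frontier
+// merge is needed; the producer epoch is exact by construction.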
+void iree_hal_amdgpu_semaphore_publish_private_stream_signal(
+    iree_hal_semaphore_t* base_semaphore, iree_async_axis_t producer_axis,
+    uint64_t producer_epoch, uint64_t producer_value) {
+  iree_hal_amdgpu_semaphore_t* semaphore =
+      iree_hal_amdgpu_semaphore_cast(base_semaphore);
+  iree_hal_amdgpu_last_signal_store(
+      &semaphore->last_signal,
+      IREE_HAL_AMDGPU_LAST_SIGNAL_FLAG_VALID |
+          IREE_HAL_AMDGPU_LAST_SIGNAL_FLAG_PRODUCER_FRONTIER_EXACT,
+      producer_axis, producer_epoch, producer_value);
+}
+
+void iree_hal_amdgpu_semaphore_clear_last_signal(
+    iree_hal_semaphore_t* base_semaphore) {
+  iree_hal_amdgpu_semaphore_t* semaphore =
+      iree_hal_amdgpu_semaphore_cast(base_semaphore);
+  iree_slim_mutex_lock(&semaphore->async.mutex);
+  iree_hal_amdgpu_last_signal_store(&semaphore->last_signal,
+                                    IREE_HAL_AMDGPU_LAST_SIGNAL_FLAG_NONE,
+                                    (iree_async_axis_t)0, 0, 0);
+  iree_slim_mutex_unlock(&semaphore->async.mutex);
+}
+
+static uint64_t iree_hal_amdgpu_semaphore_query(
+    iree_async_semaphore_t* base_semaphore) {
+  // Both fields are atomic, so the query is fully lock-free.
+  iree_status_t failure = (iree_status_t)iree_atomic_load(
+      &base_semaphore->failure_status, iree_memory_order_acquire);
+  if (!iree_status_is_ok(failure)) {
+    return iree_hal_status_as_semaphore_failure(failure);
+  }
+  return (uint64_t)iree_atomic_load(&base_semaphore->timeline_value,
+                                    iree_memory_order_acquire);
+}
+
+static iree_status_t iree_hal_amdgpu_semaphore_signal(
+    iree_async_semaphore_t* base_semaphore, uint64_t new_value,
+    const iree_async_frontier_t* frontier) {
+  // Advance the timeline (CAS) and merge frontier.
+  iree_status_t status = iree_async_semaphore_advance_timeline(
+      base_semaphore, new_value, frontier);
+  if (!iree_status_is_ok(status)) return status;
+
+  // Dispatch satisfied timepoints.
+  iree_async_semaphore_dispatch_timepoints(base_semaphore, new_value);
+
+  return iree_ok_status();
+}
+
+static iree_status_t iree_hal_amdgpu_semaphore_wait(
+    iree_hal_semaphore_t* base_semaphore, uint64_t value,
+    iree_timeout_t timeout, iree_async_wait_flags_t flags) {
+  iree_hal_amdgpu_semaphore_t* semaphore =
+      iree_hal_amdgpu_semaphore_cast(base_semaphore);
+
+  // Fast check: already reached or failed? Lock-free atomic load.
+  // Failure check must come first: failure values are numerically larger than
+  // any valid timeline value and would falsely satisfy the >= check.
+  uint64_t current = iree_async_semaphore_query(&semaphore->async);
+  if (current >= IREE_HAL_SEMAPHORE_FAILURE_VALUE) {
+    return iree_hal_semaphore_failure_as_status(current);
+  }
+  if (current >= value) return iree_ok_status();
+
+  // Epoch fast path: if we have a cached epoch for a value >= target, we can
+  // wait directly on the queue's epoch signal instead of going through the
+  // timepoint/notification machinery. This is the minimum-latency path for
+  // "submit work, wait for result."
+  //
+  // The epoch signal and hsa_signal_wait_scacquire call will be wired here
+  // when the per-queue epoch signal is implemented. Until then, we fall
+  // through to the software path which is functionally identical.
+  //
+  // iree_async_axis_t producer_axis = 0;
+  // iree_hal_amdgpu_last_signal_flags_t last_signal_flags = 0;
+  // uint64_t epoch = 0, cached_value = 0;
+  // if (iree_hal_amdgpu_last_signal_load(&semaphore->last_signal,
+  //                                       &last_signal_flags, &producer_axis,
+  //                                       &epoch, &cached_value)) {
+  //   if (cached_value >= value) {
+  //     // Wait directly on producer_axis' epoch signal with LT condition.
+  //   }
+  // }
+
+  // Software fallback: timepoint-based blocking wait.
+  return iree_async_semaphore_multi_wait(
+      IREE_ASYNC_WAIT_MODE_ALL, (iree_async_semaphore_t**)&base_semaphore,
+      &value, 1, timeout, flags, iree_allocator_system());
+}
+
+static iree_status_t iree_hal_amdgpu_semaphore_import_timepoint(
+    iree_hal_semaphore_t* base_semaphore, uint64_t value,
+    iree_hal_queue_affinity_t queue_affinity,
+    iree_hal_external_timepoint_t external_timepoint) {
+  return iree_make_status(IREE_STATUS_UNIMPLEMENTED,
+                          "AMDGPU timepoint import not yet implemented");
+}
+
+static iree_status_t iree_hal_amdgpu_semaphore_export_timepoint(
+    iree_hal_semaphore_t* base_semaphore, uint64_t value,
+    iree_hal_queue_affinity_t queue_affinity,
+    iree_hal_external_timepoint_type_t requested_type,
+    iree_hal_external_timepoint_flags_t requested_flags,
+    iree_hal_external_timepoint_t* IREE_RESTRICT out_external_timepoint) {
+  return iree_make_status(IREE_STATUS_UNIMPLEMENTED,
+                          "AMDGPU timepoint export not yet implemented");
+}
+
+static const iree_hal_semaphore_vtable_t iree_hal_amdgpu_semaphore_vtable = {
+    .async =
+        {
+            .destroy = iree_hal_amdgpu_semaphore_destroy,
+            .query = iree_hal_amdgpu_semaphore_query,
+            .signal = iree_hal_amdgpu_semaphore_signal,
+        },
+    .wait = iree_hal_amdgpu_semaphore_wait,
+    .import_timepoint = iree_hal_amdgpu_semaphore_import_timepoint,
+    .export_timepoint = iree_hal_amdgpu_semaphore_export_timepoint,
+};
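
The vtable bridging above only works because the async semaphore sits at
offset 0 of the driver type. A self-contained sketch of that layout invariant
(the type names here are shortened stand-ins, not the real structs):

```c
#include <assert.h>
#include <stddef.h>

// Stand-ins for iree_async_semaphore_t / iree_hal_amdgpu_semaphore_t.
typedef struct async_semaphore_t { int state; } async_semaphore_t;
typedef struct amdgpu_semaphore_t {
  async_semaphore_t async;  // must stay first for toll-free bridging
  void* host_allocator;     // driver fields follow
} amdgpu_semaphore_t;

// Casting amdgpu_semaphore_t* <-> async_semaphore_t* is a runtime no-op
// precisely because of this property.
static_assert(offsetof(amdgpu_semaphore_t, async) == 0,
              "async view must alias the driver semaphore at offset 0");
```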
diff --git a/runtime/src/iree/hal/drivers/amdgpu/semaphore.h b/runtime/src/iree/hal/drivers/amdgpu/semaphore.h
index a40b500..d5609f4 100644
--- a/runtime/src/iree/hal/drivers/amdgpu/semaphore.h
+++ b/runtime/src/iree/hal/drivers/amdgpu/semaphore.h
@@ -1,4 +1,4 @@
-// Copyright 2025 The IREE Authors
+// Copyright 2026 The IREE Authors
 //
 // Licensed under the Apache License v2.0 with LLVM Exceptions.
 // See https://llvm.org/LICENSE.txt for license information.
@@ -7,147 +7,206 @@
 #ifndef IREE_HAL_DRIVERS_AMDGPU_SEMAPHORE_H_
 #define IREE_HAL_DRIVERS_AMDGPU_SEMAPHORE_H_
 
+#include <string.h>
+
+#include "iree/async/semaphore.h"
 #include "iree/base/api.h"
+#include "iree/base/internal/atomics.h"
 #include "iree/hal/api.h"
-#include "iree/hal/drivers/amdgpu/util/libhsa.h"
 
 #ifdef __cplusplus
 extern "C" {
 #endif  // __cplusplus
 
-typedef struct iree_hal_amdgpu_device_semaphore_t
-    iree_hal_amdgpu_device_semaphore_t;
-
-typedef struct iree_hal_amdgpu_internal_semaphore_t
-    iree_hal_amdgpu_internal_semaphore_t;
-typedef struct iree_hal_amdgpu_system_t iree_hal_amdgpu_system_t;
+typedef struct iree_hal_amdgpu_logical_device_t
+    iree_hal_amdgpu_logical_device_t;
 
 //===----------------------------------------------------------------------===//
-// Utilities
+// iree_hal_amdgpu_last_signal_t
 //===----------------------------------------------------------------------===//
 
-// Options controlling global semaphore behavior.
-// Semaphore flags may override these options.
-typedef struct iree_hal_amdgpu_semaphore_options_t {
-  // Uses HSA_WAIT_STATE_ACTIVE for up to the given duration before switching to
-  // HSA_WAIT_STATE_BLOCKED. Above zero this will increase CPU usage in cases
-  // where the waits are long and decrease latency in cases where the waits are
-  // short. When IREE_DURATION_INFINITE waits will use HSA_WAIT_STATE_ACTIVE.
-  iree_duration_t wait_active_for_ns;
-} iree_hal_amdgpu_semaphore_options_t;
+typedef uint8_t iree_hal_amdgpu_last_signal_flags_t;
+enum iree_hal_amdgpu_last_signal_flag_bits_e {
+  IREE_HAL_AMDGPU_LAST_SIGNAL_FLAG_NONE = 0u,
+  // The cache contains a producer axis/epoch/value snapshot from at least one
+  // signal submission.
+  IREE_HAL_AMDGPU_LAST_SIGNAL_FLAG_VALID = 1u << 0,
+  // The semaphore's post-publish frontier is exactly the producer queue's
+  // frontier at |epoch|. A single barrier on |producer_axis|@|epoch| therefore
+  // implies all transitive dependencies carried by this semaphore signal, even
+  // when the producer frontier contains multiple peer axes.
+  IREE_HAL_AMDGPU_LAST_SIGNAL_FLAG_PRODUCER_FRONTIER_EXACT = 1u << 1,
+};
 
-typedef void(IREE_API_PTR* iree_hal_amdgpu_internal_semaphore_release_fn_t)(
-    void* user_data, iree_hal_amdgpu_internal_semaphore_t* semaphore);
+// Seqlock-protected cache of the most recent queue signal on a semaphore.
+// Written by the submission path when queue_execute signals the semaphore,
+// read by the submission path when processing waits (for same-queue FIFO
+// elision and direct producer-epoch cross-queue barriers) and by the
+// host-wait fast path.
+//
+// The seqlock ensures torn reads across the payload fields are detected and
+// retried. Writers increment the sequence counter to an odd value before the
+// update and to an even value after. Readers retry if the sequence is odd
+// (write in progress) or changed between the start and end of the read.
+typedef struct iree_hal_amdgpu_last_signal_t {
+  // Seqlock sequence counter; odd means a writer is updating payload fields.
+  iree_atomic_int32_t sequence;
+  // Cached signal validity and producer-frontier precision flags.
+  iree_hal_amdgpu_last_signal_flags_t flags;
+  // Reserved bytes kept zero so the payload stays naturally aligned.
+  uint8_t reserved[3];
+  // Producer queue axis that submitted the last cached signal.
+  iree_async_axis_t producer_axis;
+  // Producer queue epoch associated with the last cached signal.
+  uint64_t epoch;
+  // Semaphore payload value signaled at |producer_axis|/|epoch|.
+  uint64_t value;
+} iree_hal_amdgpu_last_signal_t;
 
-// A callback issued when a semaphore is released.
-typedef struct {
-  // Callback function pointer.
-  iree_hal_amdgpu_internal_semaphore_release_fn_t fn;
-  // User data passed to the callback function. Unowned.
-  void* user_data;
-} iree_hal_amdgpu_internal_semaphore_release_callback_t;
+// Stores a new last-signal snapshot (seqlock writer). Readers may run
+// concurrently; concurrent writers must be externally serialized (the driver
+// writes under the semaphore mutex or the single-producer contract).
+static inline void iree_hal_amdgpu_last_signal_store(
+    iree_hal_amdgpu_last_signal_t* cache,
+    iree_hal_amdgpu_last_signal_flags_t flags, iree_async_axis_t producer_axis,
+    uint64_t epoch, uint64_t value) {
+  // Increment to odd: signals write in progress.
+  iree_atomic_fetch_add(&cache->sequence, 1, iree_memory_order_acquire);
+  cache->flags = flags;
+  memset(cache->reserved, 0, sizeof(cache->reserved));
+  cache->producer_axis = producer_axis;
+  cache->epoch = epoch;
+  cache->value = value;
+  // Increment to even: signals write complete.
+  iree_atomic_fetch_add(&cache->sequence, 1, iree_memory_order_release);
+}
 
-// Returns a no-op release callback that implies that no cleanup is required.
-static inline iree_hal_amdgpu_internal_semaphore_release_callback_t
-iree_hal_amdgpu_internal_semaphore_release_callback_null(void) {
-  iree_hal_amdgpu_internal_semaphore_release_callback_t callback = {NULL, NULL};
-  return callback;
+// Loads the last-signal snapshot. Thread-safe (seqlock reader).
+// Returns true if the cache has been written at least once and remains valid.
+static inline bool iree_hal_amdgpu_last_signal_load(
+    const iree_hal_amdgpu_last_signal_t* cache,
+    iree_hal_amdgpu_last_signal_flags_t* out_flags,
+    iree_async_axis_t* out_producer_axis, uint64_t* out_epoch,
+    uint64_t* out_value) {
+  int32_t sequence;
+  do {
+    // Spin until the sequence is even: an odd value means a writer is
+    // mid-update and the payload fields may be torn.
+    do {
+      sequence = iree_atomic_load(&cache->sequence, iree_memory_order_acquire);
+    } while (IREE_UNLIKELY(sequence & 1));
+    *out_flags = cache->flags;
+    *out_producer_axis = cache->producer_axis;
+    *out_epoch = cache->epoch;
+    *out_value = cache->value;
+  } while (
+      IREE_UNLIKELY(iree_atomic_load(&cache->sequence,
+                                     iree_memory_order_acquire) != sequence));
+  return (*out_flags & IREE_HAL_AMDGPU_LAST_SIGNAL_FLAG_VALID) != 0;
 }
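
+// Illustrative usage of the last-signal cache (a sketch, not driver code;
+// |cache|, |producer_axis|, |epoch|, and |value| are assumed to be in scope).
+// Per the contracts above, writers are serialized externally while readers
+// may run concurrently and retry on torn snapshots:
+//
+//   // Writer (submission path, under the semaphore mutex):
+//   iree_hal_amdgpu_last_signal_store(
+//       cache, IREE_HAL_AMDGPU_LAST_SIGNAL_FLAG_VALID, producer_axis, epoch,
+//       value);
+//
+//   // Reader (e.g. the host-wait fast path, any thread):
+//   iree_hal_amdgpu_last_signal_flags_t flags;
+//   iree_async_axis_t axis;
+//   uint64_t cached_epoch = 0;
+//   uint64_t cached_value = 0;
+//   if (iree_hal_amdgpu_last_signal_load(cache, &flags, &axis, &cached_epoch,
+//                                        &cached_value)) {
+//     // Consistent VALID snapshot; axis/epoch/value can seed a
+//     // producer-epoch barrier instead of a software wait.
+//   }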
 
 //===----------------------------------------------------------------------===//
-// iree_hal_amdgpu_internal_semaphore_t
+// iree_hal_amdgpu_semaphore_t
 //===----------------------------------------------------------------------===//
 
-// An internally-tracked HAL semaphore.
-// These carry additional information used by the implementation to optimize
-// wait/wake behavior and allow device-side wait/wake.
-typedef struct iree_hal_amdgpu_internal_semaphore_t {
-  iree_hal_resource_t resource;  // must be at 0
+// Creates an AMDGPU HAL semaphore backed by an embedded async semaphore.
+//
+// Signal, query, and wait all delegate to the async semaphore infrastructure.
+// The semaphore embeds iree_async_semaphore_t at offset 0 for toll-free
+// bridging between HAL and async layers.
+//
+// |device| is stored as a back-pointer for type discrimination (checking
+// whether a semaphore belongs to a specific logical device). Not retained.
+//
+// |queue_affinity| hints which queues will signal/wait on the semaphore. If
+// IREE_HAL_SEMAPHORE_FLAG_DEVICE_LOCAL is set, the semaphore is only used on
+// those queues and the implementation may optimize accordingly.
+//
+// |flags| controls semaphore behavior:
+//   DEVICE_LOCAL: only signaled/waited by queues within this device. Enables
+//     epoch-based hardware synchronization (barrier-value packets).
+//   HOST_INTERRUPT: host may call iree_hal_semaphore_wait. Enables
+//     interrupt-driven host blocking via HSA signal waits.
+//   SINGLE_PRODUCER: signals come from one producer timeline, allowing the
+//     implementation to treat the latest producer queue epoch as the complete
+//     causal frontier for the latest payload value.
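+//
+// Example (an illustrative sketch; |device|, |proactor|, and |host_allocator|
+// are assumed to already exist):
+//   iree_hal_semaphore_t* semaphore = NULL;
+//   IREE_RETURN_IF_ERROR(iree_hal_amdgpu_semaphore_create(
+//       device, proactor, IREE_HAL_QUEUE_AFFINITY_ANY, /*initial_value=*/0,
+//       IREE_HAL_SEMAPHORE_FLAG_DEVICE_LOCAL |
+//           IREE_HAL_SEMAPHORE_FLAG_SINGLE_PRODUCER,
+//       host_allocator, &semaphore));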
+iree_status_t iree_hal_amdgpu_semaphore_create(
+    iree_hal_amdgpu_logical_device_t* device, iree_async_proactor_t* proactor,
+    iree_hal_queue_affinity_t queue_affinity, uint64_t initial_value,
+    iree_hal_semaphore_flags_t flags, iree_allocator_t host_allocator,
+    iree_hal_semaphore_t** out_semaphore);
 
-  // Unowned libhsa handle. Must be retained by the parent pool.
-  const iree_hal_amdgpu_libhsa_t* libhsa;
+// Returns true if |semaphore| is an AMDGPU semaphore.
+bool iree_hal_amdgpu_semaphore_isa(iree_hal_semaphore_t* semaphore);
 
-  // Global semaphore options, may be overridden based on flags.
-  iree_hal_amdgpu_semaphore_options_t options;
-
-  // Flags controlling semaphore behavior.
-  iree_hal_semaphore_flags_t flags;
-
-  // HSA signal handle. Contains the semaphore payload value.
-  hsa_signal_t signal;
-
-  // Device-visible semaphore in shared host/device memory.
-  // The allocation is owned by the parent semaphore pool.
-  IREE_AMDGPU_DEVICE_PTR iree_hal_amdgpu_device_semaphore_t* device_semaphore;
-
-  // Release callback that handles deallocation.
-  iree_hal_amdgpu_internal_semaphore_release_callback_t release_callback;
-} iree_hal_amdgpu_internal_semaphore_t;
-
-// Initializes an internal semaphore in-place with a 0 ref count.
-// The owning pool must increment the ref count to 1 before returning the
-// semaphore to users.
-iree_status_t iree_hal_amdgpu_internal_semaphore_initialize(
-    const iree_hal_amdgpu_libhsa_t* libhsa,
-    iree_hal_amdgpu_semaphore_options_t options,
-    iree_hal_semaphore_flags_t flags,
-    IREE_AMDGPU_DEVICE_PTR iree_hal_amdgpu_device_semaphore_t* device_semaphore,
-    iree_hal_amdgpu_internal_semaphore_release_callback_t release_callback,
-    iree_hal_amdgpu_internal_semaphore_t* out_semaphore);
-
-// Deinitializes an internal semaphore in-place assuming it has a 0 ref count.
-void iree_hal_amdgpu_internal_semaphore_deinitialize(
-    iree_hal_amdgpu_internal_semaphore_t* semaphore);
-
-// Returns true if |semaphore| is an iree_hal_amdgpu_internal_semaphore_t.
-bool iree_hal_amdgpu_internal_semaphore_isa(iree_hal_semaphore_t* semaphore);
-
-// Resets |semaphore| to |initial_value| as if it had just been allocated.
-void iree_hal_amdgpu_internal_semaphore_reset(
-    iree_hal_amdgpu_internal_semaphore_t* semaphore, uint64_t initial_value);
-
-//===----------------------------------------------------------------------===//
-// iree_hal_amdgpu_external_semaphore_t
-//===----------------------------------------------------------------------===//
-
-// TODO(benvanik): external imported semaphore wrapper.
-typedef uint64_t iree_hal_amdgpu_external_semaphore_t;
-
-//===----------------------------------------------------------------------===//
-// Semaphore Operations
-//===----------------------------------------------------------------------===//
-
-// Returns a device-side semaphore handle for the provided HAL semaphore.
-// Fails if there is no corresponding device-side handle (such as with a
-// semaphore from another HAL device). Such semaphores must be imported using
-// iree_hal_device_import_semaphore.
-iree_status_t iree_hal_amdgpu_semaphore_handle(
+// Returns true if |semaphore| is an AMDGPU semaphore belonging to |device|.
+// Used by the submission path to gate the epoch-based synchronization fast
+// path: only semaphores local to the submitting device can use barrier-value
+// packets on the device's queue epoch signals. Non-local semaphores (from
+// other HAL devices, remoting, etc.) always use the software timepoint path.
+bool iree_hal_amdgpu_semaphore_is_local(
     iree_hal_semaphore_t* semaphore,
-    IREE_AMDGPU_DEVICE_PTR iree_hal_amdgpu_device_semaphore_t** out_handle);
+    const iree_hal_amdgpu_logical_device_t* device);
 
-// Returns the HSA signal for the provided HAL semaphore.
-iree_status_t iree_hal_amdgpu_semaphore_hsa_signal(
-    iree_hal_semaphore_t* base_semaphore, hsa_signal_t* out_signal);
+// Returns the AMDGPU semaphore creation flags. Caller must verify
+// iree_hal_amdgpu_semaphore_isa() first.
+iree_hal_semaphore_flags_t iree_hal_amdgpu_semaphore_flags(
+    iree_hal_semaphore_t* semaphore);
 
-// Polls |base_semaphore| and returns its current value in |out_current_value|.
-// Returns ABORTED if the semaphore is in a failure state.
-iree_status_t iree_hal_amdgpu_poll_semaphore(
-    iree_hal_semaphore_t* base_semaphore, uint64_t* out_current_value);
+// Returns the AMDGPU semaphore creation queue affinity. Caller must verify
+// iree_hal_amdgpu_semaphore_isa() first.
+iree_hal_queue_affinity_t iree_hal_amdgpu_semaphore_queue_affinity(
+    iree_hal_semaphore_t* semaphore);
 
-// Polls |semaphore_list| and returns either OK or DEADLINE_EXCEEDED if
-// satisfied or unsatisfied at the time the method is called.
-// Returns ABORTED if any semaphore is in a failure state.
-iree_status_t iree_hal_amdgpu_poll_semaphores(
-    iree_async_wait_mode_t wait_mode,
-    const iree_hal_semaphore_list_t semaphore_list);
+// Returns true if |semaphore| has the strict private-stream contract used by
+// HIP-on-HAL stream timelines:
+//   - owned by |device|;
+//   - device-local;
+//   - single-producer; and
+//   - not host-interrupt/export/timepoint-export capable.
+//
+// Such semaphores are still normal HAL timeline semaphores, but AMDGPU may use
+// the single-producer proof to publish only the producer queue epoch on the
+// signal hot path. Completion drain still advances the timeline value, but
+// does not need to accumulate a multi-producer async frontier for the
+// private-stream handoff.
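+//
+// For illustration (mirroring the semaphore tests), a semaphore created on
+// the owning device with
+//   IREE_HAL_SEMAPHORE_FLAG_DEVICE_LOCAL |
+//       IREE_HAL_SEMAPHORE_FLAG_SINGLE_PRODUCER
+// satisfies the contract, while adding host-interrupt or timepoint-export
+// capability (or creating the semaphore on another device) does not.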
+bool iree_hal_amdgpu_semaphore_has_private_stream_semantics(
+    iree_hal_semaphore_t* semaphore,
+    const iree_hal_amdgpu_logical_device_t* device);
 
-// Blocks until |semaphore_list| is satisfied per |wait_mode| or |timeout|.
-iree_status_t iree_hal_amdgpu_wait_semaphores(
-    const iree_hal_amdgpu_libhsa_t* libhsa,
-    iree_hal_amdgpu_semaphore_options_t options,
-    iree_async_wait_mode_t wait_mode,
-    const iree_hal_semaphore_list_t semaphore_list, iree_timeout_t timeout,
-    iree_async_wait_flags_t flags);
+// Returns a pointer to the last_signal cache on an AMDGPU semaphore.
+// Caller must verify iree_hal_amdgpu_semaphore_isa() first.
+iree_hal_amdgpu_last_signal_t* iree_hal_amdgpu_semaphore_last_signal(
+    iree_hal_semaphore_t* semaphore);
+
+// Publishes the submission-time frontier and last-signal cache for a signal
+// from |producer_axis| at (|producer_epoch|, |producer_value|).
+//
+// Merges |producer_frontier| into the semaphore's accumulated frontier under
+// the semaphore mutex, then updates the last-signal cache while still holding
+// that mutex so PRODUCER_FRONTIER_EXACT reflects the post-merge frontier
+// precisely. Returns false if the frontier merge overflowed capacity; in that
+// case the cache is cleared and callers must fall back to software waits for
+// not-yet-complete values.
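+//
+// Worked example (mirroring the semaphore tests): if queue A signals value 1
+// with frontier {A@4} and queue B then signals value 2 with frontier
+// {A@4, B@7}, B's own frontier covers the accumulated frontier and the cache
+// is published as PRODUCER_FRONTIER_EXACT at B@7. If two independent queues
+// instead signal with disjoint frontiers {A@5} then {B@9}, the accumulated
+// frontier {A@5, B@9} is broader than the last producer's own frontier, so
+// the cache stays VALID but not EXACT.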
+//
+// Caller must verify iree_hal_amdgpu_semaphore_isa() first.
+bool iree_hal_amdgpu_semaphore_publish_signal(
+    iree_hal_semaphore_t* semaphore, iree_async_axis_t producer_axis,
+    const iree_async_frontier_t* producer_frontier, uint64_t producer_epoch,
+    uint64_t producer_value);
+
+// Publishes a single-producer private-stream signal without accumulating the
+// full semaphore frontier under the async semaphore mutex.
+//
+// Caller must prove iree_hal_amdgpu_semaphore_has_private_stream_semantics()
+// and serialize all signals through |producer_axis|. The last-signal cache is
+// updated as PRODUCER_FRONTIER_EXACT because waiting on the producer queue
+// epoch is sufficient to observe the signaled payload's transitive
+// dependencies.
+void iree_hal_amdgpu_semaphore_publish_private_stream_signal(
+    iree_hal_semaphore_t* semaphore, iree_async_axis_t producer_axis,
+    uint64_t producer_epoch, uint64_t producer_value);
+
+// Clears the semaphore's last-signal cache.
+//
+// Caller must verify iree_hal_amdgpu_semaphore_isa() first.
+void iree_hal_amdgpu_semaphore_clear_last_signal(
+    iree_hal_semaphore_t* semaphore);
 
 #ifdef __cplusplus
 }  // extern "C"
diff --git a/runtime/src/iree/hal/drivers/amdgpu/semaphore_pool.c b/runtime/src/iree/hal/drivers/amdgpu/semaphore_pool.c
deleted file mode 100644
index 1d7378c..0000000
--- a/runtime/src/iree/hal/drivers/amdgpu/semaphore_pool.c
+++ /dev/null
@@ -1,434 +0,0 @@
-// Copyright 2025 The IREE Authors
-//
-// Licensed under the Apache License v2.0 with LLVM Exceptions.
-// See https://llvm.org/LICENSE.txt for license information.
-// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
-
-#include "iree/hal/drivers/amdgpu/semaphore_pool.h"
-
-#include "iree/hal/drivers/amdgpu/device/semaphore.h"
-#include "iree/hal/drivers/amdgpu/semaphore.h"
-#include "iree/hal/drivers/amdgpu/util/topology.h"
-
-static void iree_hal_amdgpu_semaphore_pool_link_free_block(
-    iree_hal_amdgpu_semaphore_pool_t* semaphore_pool,
-    iree_hal_amdgpu_semaphore_pool_block_t* block);
-
-//===----------------------------------------------------------------------===//
-// iree_hal_amdgpu_semaphore_pool_block_t
-//===----------------------------------------------------------------------===//
-
-// A block of allocated semaphores. Manages both host heap memory and
-// device-visible memory for the device-side library resources.
-//
-// Thread-safe; each block has its own lock for free list management.
-typedef struct iree_hal_amdgpu_semaphore_pool_block_t {
-  // Pool that owns this block.
-  iree_hal_amdgpu_semaphore_pool_t* semaphore_pool;
-  // Previous block in the pool block linked list.
-  struct iree_hal_amdgpu_semaphore_pool_block_t* prev_block;
-  // Next block in the pool block linked list.
-  struct iree_hal_amdgpu_semaphore_pool_block_t* next_block;
-  // Next block in the pool block linked list with free entries.
-  struct iree_hal_amdgpu_semaphore_pool_block_t* next_free;
-  // Capacity of the block in semaphores.
-  iree_host_size_t capacity;
-  // Device memory base pointer used for `iree_hal_amdgpu_device_semaphore_t`.
-  IREE_AMDGPU_DEVICE_PTR uint8_t* device_allocation_ptr;
-  // Mutex guarding the mutable block fields.
-  iree_slim_mutex_t mutex;
-  // Count of free semaphores in the block stored in the free_list.
-  iree_host_size_t free_count IREE_GUARDED_BY(mutex);
-  // Free semaphores that are available for use.
-  iree_hal_amdgpu_internal_semaphore_t* free_list[/*capacity*/] IREE_GUARDED_BY(
-      mutex);
-  // Trailing list of iree_hal_amdgpu_internal_semaphore_t[capacity].
-} iree_hal_amdgpu_semaphore_pool_block_t;
-
-static void iree_hal_amdgpu_semaphore_pool_block_free(
-    iree_hal_amdgpu_semaphore_pool_block_t* block);
-static void iree_hal_amdgpu_semaphore_pool_block_recycle(
-    void* user_data, iree_hal_amdgpu_internal_semaphore_t* semaphore);
-
-// Allocates a block of |capacity| semaphores on host and device.
-static iree_status_t iree_hal_amdgpu_semaphore_pool_block_allocate(
-    iree_hal_amdgpu_semaphore_pool_t* semaphore_pool, iree_host_size_t capacity,
-    iree_hal_semaphore_flags_t flags,
-    iree_hal_amdgpu_semaphore_pool_block_t** out_block) {
-  IREE_ASSERT_ARGUMENT(out_block);
-  IREE_TRACE_ZONE_BEGIN(z0);
-  *out_block = NULL;
-
-  const iree_hal_amdgpu_libhsa_t* libhsa = semaphore_pool->libhsa;
-
-  // Allocate and initialize host memory.
-  iree_hal_amdgpu_semaphore_pool_block_t* block = NULL;
-  const iree_host_size_t free_list_size =
-      capacity * sizeof(block->free_list[0]);
-  const iree_host_size_t total_block_size =
-      sizeof(*block) + free_list_size +
-      capacity * sizeof(iree_hal_amdgpu_internal_semaphore_t);
-  IREE_RETURN_AND_END_ZONE_IF_ERROR(
-      z0, iree_allocator_malloc(semaphore_pool->host_allocator,
-                                total_block_size, (void**)&block));
-  block->semaphore_pool = semaphore_pool;
-  block->prev_block = NULL;
-  block->next_block = NULL;
-  block->next_free = NULL;
-  block->capacity = capacity;
-  block->device_allocation_ptr = NULL;
-  iree_slim_mutex_initialize(&block->mutex);
-
-  // Allocate device memory from the HSA memory pool.
-  const iree_host_size_t total_device_size =
-      capacity * sizeof(iree_hal_amdgpu_device_semaphore_t);
-  iree_status_t status = iree_hsa_amd_memory_pool_allocate(
-      IREE_LIBHSA(libhsa), semaphore_pool->memory_pool, total_device_size,
-      HSA_AMD_MEMORY_POOL_STANDARD_FLAG, (void**)&block->device_allocation_ptr);
-
-  // Make the allocation visible to all devices.
-  if (iree_status_is_ok(status)) {
-    const iree_hal_amdgpu_topology_t* topology = semaphore_pool->topology;
-    status = iree_hsa_amd_agents_allow_access(
-        IREE_LIBHSA(libhsa), topology->all_agent_count, topology->all_agents,
-        /*flags=*/NULL, block->device_allocation_ptr);
-  }
-
-  // Initialize each host semaphore and build the free list.
-  if (iree_status_is_ok(status)) {
-    iree_hal_amdgpu_internal_semaphore_t* base_host_ptr =
-        (iree_hal_amdgpu_internal_semaphore_t*)((uint8_t*)block +
-                                                sizeof(*block) +
-                                                free_list_size);
-    iree_hal_amdgpu_device_semaphore_t* base_device_ptr =
-        (iree_hal_amdgpu_device_semaphore_t*)block->device_allocation_ptr;
-    block->free_count = capacity;
-    iree_hal_amdgpu_internal_semaphore_release_callback_t release_callback = {
-        .fn = iree_hal_amdgpu_semaphore_pool_block_recycle,
-        .user_data = block,
-    };
-    for (iree_host_size_t i = 0; i < capacity; ++i) {
-      iree_hal_amdgpu_internal_semaphore_t* semaphore = &base_host_ptr[i];
-      iree_hal_amdgpu_device_semaphore_t* device_semaphore =
-          &base_device_ptr[i];
-      status = iree_hal_amdgpu_internal_semaphore_initialize(
-          semaphore_pool->libhsa, semaphore_pool->options, flags,
-          device_semaphore, release_callback, semaphore);
-      if (!iree_status_is_ok(status)) break;
-      block->free_list[i] = semaphore;
-    }
-  }
-
-  if (iree_status_is_ok(status)) {
-    *out_block = block;
-  } else {
-    iree_hal_amdgpu_semaphore_pool_block_free(block);
-  }
-  IREE_TRACE_ZONE_END(z0);
-  return status;
-}
-
-// Frees a |block| of semaphores and its device memory.
-static void iree_hal_amdgpu_semaphore_pool_block_free(
-    iree_hal_amdgpu_semaphore_pool_block_t* block) {
-  IREE_ASSERT_ARGUMENT(block);
-  IREE_TRACE_ZONE_BEGIN(z0);
-
-  iree_slim_mutex_lock(&block->mutex);
-  IREE_ASSERT_EQ(block->free_count, block->capacity);
-  iree_slim_mutex_unlock(&block->mutex);
-
-  // Deinitialize all host semaphores. They are allocated as part of the block
-  // and only need to be cleaned up.
-  const iree_host_size_t free_list_size =
-      block->capacity * sizeof(block->free_list[0]);
-  iree_hal_amdgpu_internal_semaphore_t* base_host_ptr =
-      (iree_hal_amdgpu_internal_semaphore_t*)((uint8_t*)block + sizeof(*block) +
-                                              free_list_size);
-  for (iree_host_size_t i = 0; i < block->capacity; ++i) {
-    iree_hal_amdgpu_internal_semaphore_t* semaphore = &base_host_ptr[i];
-    iree_hal_amdgpu_internal_semaphore_deinitialize(semaphore);
-  }
-
-  // Deallocate device memory.
-  if (block->device_allocation_ptr) {
-    IREE_IGNORE_ERROR(iree_hsa_amd_memory_pool_free(
-        IREE_LIBHSA(block->semaphore_pool->libhsa),
-        block->device_allocation_ptr));
-    block->device_allocation_ptr = NULL;
-  }
-
-  // Frees the block and its embedded storage.
-  iree_slim_mutex_deinitialize(&block->mutex);
-  iree_allocator_free(block->semaphore_pool->host_allocator, block);
-
-  IREE_TRACE_ZONE_END(z0);
-}
-
-// Recycles a semaphore after it has no remaining uses.
-static void iree_hal_amdgpu_semaphore_pool_block_recycle(
-    void* user_data, iree_hal_amdgpu_internal_semaphore_t* semaphore) {
-  iree_hal_amdgpu_semaphore_pool_block_t* block =
-      (iree_hal_amdgpu_semaphore_pool_block_t*)user_data;
-
-  // Semaphore should have zero references before being recycled.
-  IREE_ASSERT_REF_COUNT_ZERO(&semaphore->resource.ref_count);
-
-  // Add to the block free list.
-  iree_slim_mutex_lock(&block->mutex);
-
-  const bool full_to_free = block->free_count == 0;
-  block->free_list[block->free_count++] = semaphore;
-
-  iree_slim_mutex_unlock(&block->mutex);
-
-  // If the block has gone from 0 to >0 free entries then link it back into the
-  // pool free list for use. Note that we can only do this on the transition
-  // from full to free as otherwise the block is already in the free list.
-  //
-  // NOTE: this happens outside of the per-block lock as the pool will hold its
-  // lock over the free list while acquiring a new entry. This may lead to
-  // (safe) races where an acquire checks the free list while we are updating
-  // the block above but before we update the free list but that's rare and
-  // bounded (there may be one extra block in the pool).
-  if (full_to_free) {
-    iree_hal_amdgpu_semaphore_pool_link_free_block(block->semaphore_pool,
-                                                   block);
-  }
-}
-
-//===----------------------------------------------------------------------===//
-// iree_hal_amdgpu_semaphore_pool_t
-//===----------------------------------------------------------------------===//
-
-iree_status_t iree_hal_amdgpu_semaphore_pool_initialize(
-    const iree_hal_amdgpu_libhsa_t* libhsa,
-    const iree_hal_amdgpu_topology_t* topology, iree_host_size_t block_capacity,
-    iree_hal_amdgpu_semaphore_options_t options,
-    iree_hal_semaphore_flags_t flags, iree_allocator_t host_allocator,
-    hsa_amd_memory_pool_t memory_pool,
-    iree_hal_amdgpu_semaphore_pool_t* out_semaphore_pool) {
-  IREE_ASSERT_ARGUMENT(libhsa);
-  IREE_ASSERT_ARGUMENT(topology);
-  IREE_ASSERT_ARGUMENT(out_semaphore_pool);
-  IREE_TRACE_ZONE_BEGIN(z0);
-
-  out_semaphore_pool->libhsa = libhsa;
-  out_semaphore_pool->topology = topology;
-  out_semaphore_pool->options = options;
-  out_semaphore_pool->host_allocator = host_allocator;
-  out_semaphore_pool->memory_pool = memory_pool;
-  out_semaphore_pool->flags = flags;
-
-  iree_slim_mutex_initialize(&out_semaphore_pool->mutex);
-  out_semaphore_pool->list_head = NULL;
-  out_semaphore_pool->free_head = NULL;
-
-  // Query the memory pool for its allocation granularity.
-  // This is not the minimum allocation size
-  // (HSA_AMD_MEMORY_POOL_INFO_RUNTIME_ALLOC_GRANULE) but the recommended size
-  // to prevent internal fragmentation. We will always make allocations of this
-  // size and adjust the block capacity to match.
-  size_t alloc_rec_granule = 0;
-  IREE_RETURN_AND_END_ZONE_IF_ERROR(
-      z0,
-      iree_hsa_amd_memory_pool_get_info(
-          IREE_LIBHSA(libhsa), memory_pool,
-          HSA_AMD_MEMORY_POOL_INFO_RUNTIME_ALLOC_REC_GRANULE,
-          &alloc_rec_granule),
-      "querying HSA_AMD_MEMORY_POOL_INFO_RUNTIME_ALLOC_REC_GRANULE to "
-      "determine block capacity");
-  IREE_TRACE_ZONE_APPEND_VALUE_I64(z0, alloc_rec_granule);
-
-  // Allocate aligned to the recommended allocation granularity.
-  // We'll always be some multiple of the recommended size so we waste no device
-  // space.
-  const iree_host_size_t min_capacity_per_allocation = iree_host_size_ceil_div(
-      alloc_rec_granule, sizeof(iree_hal_amdgpu_device_semaphore_t));
-  const iree_host_size_t capacity_per_allocation =
-      iree_host_size_ceil_div(block_capacity, min_capacity_per_allocation) *
-      min_capacity_per_allocation;
-  out_semaphore_pool->block_capacity = capacity_per_allocation;
-
-  IREE_TRACE_ZONE_END(z0);
-  return iree_ok_status();
-}
-
-void iree_hal_amdgpu_semaphore_pool_deinitialize(
-    iree_hal_amdgpu_semaphore_pool_t* semaphore_pool) {
-  IREE_ASSERT_ARGUMENT(semaphore_pool);
-  IREE_TRACE_ZONE_BEGIN(z0);
-
-  iree_slim_mutex_lock(&semaphore_pool->mutex);
-  iree_hal_amdgpu_semaphore_pool_block_t* block = semaphore_pool->list_head;
-  while (block != NULL) {
-    iree_hal_amdgpu_semaphore_pool_block_t* next_block = block->next_block;
-    IREE_ASSERT_EQ(block->free_count, block->capacity);
-    iree_hal_amdgpu_semaphore_pool_block_free(block);
-    block = next_block;
-  }
-  semaphore_pool->list_head = NULL;
-  semaphore_pool->free_head = NULL;
-  iree_slim_mutex_unlock(&semaphore_pool->mutex);
-
-  iree_slim_mutex_deinitialize(&semaphore_pool->mutex);
-
-  IREE_TRACE_ZONE_END(z0);
-}
-
-// Grows the |semaphore_pool| by one block.
-// Requires the pool lock be held.
-static iree_status_t iree_hal_amdgpu_semaphore_pool_grow(
-    iree_hal_amdgpu_semaphore_pool_t* semaphore_pool) {
-  IREE_ASSERT_ARGUMENT(semaphore_pool);
-  IREE_TRACE_ZONE_BEGIN(z0);
-
-  // Allocate the new block and its resources.
-  iree_hal_amdgpu_semaphore_pool_block_t* block = NULL;
-  IREE_RETURN_AND_END_ZONE_IF_ERROR(
-      z0, iree_hal_amdgpu_semaphore_pool_block_allocate(
-              semaphore_pool, semaphore_pool->block_capacity,
-              semaphore_pool->flags, &block));
-
-  // Link the block into the allocated list and the free list.
-  block->prev_block = NULL;
-  block->next_block = semaphore_pool->list_head;
-  if (block->next_block) {
-    block->next_block->prev_block = block;
-  }
-  semaphore_pool->list_head = block;
-  block->next_free = semaphore_pool->free_head;
-  semaphore_pool->free_head = block;
-
-  IREE_TRACE_ZONE_END(z0);
-  return iree_ok_status();
-}
-
-iree_status_t iree_hal_amdgpu_semaphore_pool_preallocate(
-    iree_hal_amdgpu_semaphore_pool_t* semaphore_pool, iree_host_size_t count) {
-  IREE_ASSERT_ARGUMENT(semaphore_pool);
-  IREE_TRACE_ZONE_BEGIN(z0);
-
-  iree_status_t status = iree_ok_status();
-  const iree_host_size_t block_count =
-      iree_host_size_ceil_div(count, semaphore_pool->block_capacity);
-  for (iree_host_size_t i = 0; iree_status_is_ok(status) && i < block_count;
-       ++i) {
-    status = iree_hal_amdgpu_semaphore_pool_grow(semaphore_pool);
-  }
-
-  IREE_TRACE_ZONE_END(z0);
-  return status;
-}
-
-iree_status_t iree_hal_amdgpu_semaphore_pool_acquire(
-    iree_hal_amdgpu_semaphore_pool_t* semaphore_pool, uint64_t initial_value,
-    iree_hal_semaphore_flags_t flags, iree_hal_semaphore_t** out_semaphore) {
-  IREE_ASSERT_ARGUMENT(semaphore_pool);
-  IREE_ASSERT_ARGUMENT(out_semaphore);
-  IREE_TRACE_ZONE_BEGIN(z0);
-  *out_semaphore = NULL;
-
-  iree_slim_mutex_lock(&semaphore_pool->mutex);
-
-  // If there are no blocks with free semaphores allocate a new one.
-  iree_status_t status = iree_ok_status();
-  if (semaphore_pool->free_head == NULL) {
-    // TODO(benvanik): do this outside of the lock? This allocates device
-    // resources. We could have an exclusive growth lock that does not block
-    // recycling.
-    status = iree_hal_amdgpu_semaphore_pool_grow(semaphore_pool);
-  }
-
-  // Get the next free semaphore and possibly maintain the free list.
-  iree_hal_amdgpu_internal_semaphore_t* semaphore = NULL;
-  if (iree_status_is_ok(status)) {
-    iree_hal_amdgpu_semaphore_pool_block_t* block = semaphore_pool->free_head;
-    iree_slim_mutex_lock(&block->mutex);
-
-    // Pop the last free semaphore from the block.
-    semaphore = block->free_list[block->free_count - 1];
-    IREE_ASSERT_NE(semaphore, NULL);
-    block->free_list[block->free_count - 1] = NULL;
-    --block->free_count;
-
-    // If there are no more free semaphores in the block remove it from the
-    // free list.
-    if (block->free_count == 0) {
-      semaphore_pool->free_head = block->next_free;
-      block->next_free = NULL;
-    }
-
-    iree_slim_mutex_unlock(&block->mutex);
-  }
-
-  iree_slim_mutex_unlock(&semaphore_pool->mutex);
-
-  if (iree_status_is_ok(status)) {
-    // Return with a 1 ref count as if we had allocated it.
-    iree_atomic_ref_count_inc(&semaphore->resource.ref_count);
-    iree_hal_amdgpu_internal_semaphore_reset(semaphore, initial_value);
-    *out_semaphore = (iree_hal_semaphore_t*)semaphore;
-  }
-  IREE_TRACE_ZONE_END(z0);
-  return status;
-}
-
-// Links |block| into the |semaphore_pool| free list.
-// Must not already be in the list.
-// The block is inserted at the head to try to have new acquisitions reuse it
-// before any others and keep the utilization high.
-static void iree_hal_amdgpu_semaphore_pool_link_free_block(
-    iree_hal_amdgpu_semaphore_pool_t* semaphore_pool,
-    iree_hal_amdgpu_semaphore_pool_block_t* block) {
-  iree_slim_mutex_lock(&semaphore_pool->mutex);
-  block->next_free = semaphore_pool->free_head;
-  semaphore_pool->free_head = block;
-  iree_slim_mutex_unlock(&semaphore_pool->mutex);
-}
-
-void iree_hal_amdgpu_semaphore_pool_trim(
-    iree_hal_amdgpu_semaphore_pool_t* semaphore_pool) {
-  IREE_ASSERT_ARGUMENT(semaphore_pool);
-  IREE_TRACE_ZONE_BEGIN(z0);
-
-  // Walk each block in the free list. If all semaphores are free then drop it.
-  iree_slim_mutex_lock(&semaphore_pool->mutex);
-  iree_hal_amdgpu_semaphore_pool_block_t* prev_block = NULL;
-  iree_hal_amdgpu_semaphore_pool_block_t* block = semaphore_pool->free_head;
-  while (block != NULL) {
-    iree_hal_amdgpu_semaphore_pool_block_t* next_block = block->next_free;
-    if (block->free_count != block->capacity) {
-      // One or more semaphores in use - cannot free the block.
-      prev_block = block;
-      block = next_block;
-      continue;
-    }
-
-    // Unlink the block from the free list.
-    if (prev_block != NULL) {
-      prev_block->next_free = next_block;
-    } else {
-      semaphore_pool->free_head = next_block;
-    }
-
-    // Unlink the block from the main list.
-    if (block->prev_block != NULL) {
-      block->prev_block->next_block = block->next_block;
-    } else {
-      semaphore_pool->list_head = block->next_block;
-    }
-    if (block->next_block != NULL) {
-      block->next_block->prev_block = block->prev_block;
-    }
-
-    // Free the block and its resources.
-    iree_hal_amdgpu_semaphore_pool_block_free(block);
-
-    block = next_block;
-  }
-
-  iree_slim_mutex_unlock(&semaphore_pool->mutex);
-
-  IREE_TRACE_ZONE_END(z0);
-}
diff --git a/runtime/src/iree/hal/drivers/amdgpu/semaphore_pool.h b/runtime/src/iree/hal/drivers/amdgpu/semaphore_pool.h
deleted file mode 100644
index b6f2f1b..0000000
--- a/runtime/src/iree/hal/drivers/amdgpu/semaphore_pool.h
+++ /dev/null
@@ -1,106 +0,0 @@
-// Copyright 2025 The IREE Authors
-//
-// Licensed under the Apache License v2.0 with LLVM Exceptions.
-// See https://llvm.org/LICENSE.txt for license information.
-// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
-
-#ifndef IREE_HAL_DRIVERS_AMDGPU_SEMAPHORE_POOL_H_
-#define IREE_HAL_DRIVERS_AMDGPU_SEMAPHORE_POOL_H_
-
-#include "iree/base/api.h"
-#include "iree/base/threading/mutex.h"
-#include "iree/hal/api.h"
-#include "iree/hal/drivers/amdgpu/semaphore.h"
-#include "iree/hal/drivers/amdgpu/util/libhsa.h"
-
-#ifdef __cplusplus
-extern "C" {
-#endif  // __cplusplus
-
-typedef struct iree_hal_amdgpu_internal_semaphore_t
-    iree_hal_amdgpu_internal_semaphore_t;
-typedef struct iree_hal_amdgpu_semaphore_pool_block_t
-    iree_hal_amdgpu_semaphore_pool_block_t;
-typedef struct iree_hal_amdgpu_topology_t iree_hal_amdgpu_topology_t;
-
-//===----------------------------------------------------------------------===//
-// iree_hal_amdgpu_semaphore_pool_t
-//===----------------------------------------------------------------------===//
-
-// Default semaphore count per block in the pool.
-// Larger is better to reduce the number of device memory allocations but we
-// don't want to have too high of a fixed overhead. Most programs only have a
-// few dozen live semaphores at a time.
-#define IREE_HAL_AMDGPU_SEMAPHORE_POOL_DEFAULT_BLOCK_CAPACITY 512
-
-// A pool of allocated HAL semaphores and their corresponding device resources.
-// Semaphores are allocated in blocks to reduce the number of device allocations
-// we make (as some devices/drivers may have limits). Blocks are allocated
-// on-demand and contain a fixed-size set of HAL semaphores allocated inline.
-//
-// Thread-safe; multiple host threads may share the same pool.
-typedef struct iree_hal_amdgpu_semaphore_pool_t {
-  // Unowned libhsa handle. Must be retained by the owner.
-  const iree_hal_amdgpu_libhsa_t* libhsa;
-  // Topology with all CPU and GPU agents. Semaphores must be visible to all.
-  const iree_hal_amdgpu_topology_t* topology;
-  // Global semaphore options.
-  iree_hal_amdgpu_semaphore_options_t options;
-
-  // Allocator used for host allocations.
-  iree_allocator_t host_allocator;
-  // HSA memory pool for device allocations.
-  hsa_amd_memory_pool_t memory_pool;
-
-  // Common semaphore flags for all allocated in the pool.
-  // Semaphores acquired may adjust some flags if they don't change how the
-  // semaphore is allocated.
-  iree_hal_semaphore_flags_t flags;
-
-  // Capacity of each block in semaphores.
-  // Most likely IREE_HAL_AMDGPU_SEMAPHORE_POOL_DEFAULT_BLOCK_CAPACITY.
-  iree_host_size_t block_capacity;
-
-  // Guards pool resources during acquisition.
-  iree_slim_mutex_t mutex;
-  // A doubly-linked list of all allocated blocks.
-  iree_hal_amdgpu_semaphore_pool_block_t* list_head IREE_GUARDED_BY(mutex);
-  // A singly-linked list of blocks that have one or more free semaphore.
-  iree_hal_amdgpu_semaphore_pool_block_t* free_head IREE_GUARDED_BY(mutex);
-} iree_hal_amdgpu_semaphore_pool_t;
-
-// Initializes |out_semaphore_pool| for use. Performs no allocation.
-// Semaphores will be usable on all CPU and GPU devices in |topology|.
-// The device |memory_pool| will be used for device-visible allocations.
-iree_status_t iree_hal_amdgpu_semaphore_pool_initialize(
-    const iree_hal_amdgpu_libhsa_t* libhsa,
-    const iree_hal_amdgpu_topology_t* topology, iree_host_size_t block_capacity,
-    iree_hal_amdgpu_semaphore_options_t options,
-    iree_hal_semaphore_flags_t flags, iree_allocator_t host_allocator,
-    hsa_amd_memory_pool_t memory_pool,
-    iree_hal_amdgpu_semaphore_pool_t* out_semaphore_pool);
-
-// Deinitializes |semaphore_pool| and releases underlying memory.
-// All semaphores created from the pool must have been released back to it.
-void iree_hal_amdgpu_semaphore_pool_deinitialize(
-    iree_hal_amdgpu_semaphore_pool_t* semaphore_pool);
-
-// Preallocates |count| semaphores and adds them to the pool free list.
-iree_status_t iree_hal_amdgpu_semaphore_pool_preallocate(
-    iree_hal_amdgpu_semaphore_pool_t* semaphore_pool, iree_host_size_t count);
-
-// Acquires a semaphore from the pool with the given |initial_value|.
-// |flags| must be compatible with the flags used for pool initialization.
-iree_status_t iree_hal_amdgpu_semaphore_pool_acquire(
-    iree_hal_amdgpu_semaphore_pool_t* semaphore_pool, uint64_t initial_value,
-    iree_hal_semaphore_flags_t flags, iree_hal_semaphore_t** out_semaphore);
-
-// Trims all blocks that have no allocated semaphores.
-void iree_hal_amdgpu_semaphore_pool_trim(
-    iree_hal_amdgpu_semaphore_pool_t* semaphore_pool);
-
-#ifdef __cplusplus
-}  // extern "C"
-#endif  // __cplusplus
-
-#endif  // IREE_HAL_DRIVERS_AMDGPU_SEMAPHORE_POOL_H_
diff --git a/runtime/src/iree/hal/drivers/amdgpu/semaphore_pool_test.cc b/runtime/src/iree/hal/drivers/amdgpu/semaphore_pool_test.cc
deleted file mode 100644
index 5408494..0000000
--- a/runtime/src/iree/hal/drivers/amdgpu/semaphore_pool_test.cc
+++ /dev/null
@@ -1,214 +0,0 @@
-// Copyright 2025 The IREE Authors
-//
-// Licensed under the Apache License v2.0 with LLVM Exceptions.
-// See https://llvm.org/LICENSE.txt for license information.
-// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
-
-#include "iree/hal/drivers/amdgpu/semaphore_pool.h"
-
-#include <vector>
-
-#include "iree/base/api.h"
-#include "iree/hal/drivers/amdgpu/device/semaphore.h"
-#include "iree/hal/drivers/amdgpu/semaphore.h"
-#include "iree/hal/drivers/amdgpu/util/topology.h"
-#include "iree/hal/drivers/amdgpu/util/vmem.h"
-#include "iree/testing/gtest.h"
-#include "iree/testing/status_matchers.h"
-
-namespace iree::hal::amdgpu {
-namespace {
-
-using iree::testing::status::StatusIs;
-
-struct SemaphorePoolTest : public ::testing::Test {
-  static iree_allocator_t host_allocator;
-  static iree_hal_amdgpu_libhsa_t libhsa;
-  static iree_hal_amdgpu_topology_t topology;
-  static hsa_amd_memory_pool_t cpu_memory_pool;
-
-  static void SetUpTestSuite() {
-    IREE_TRACE_SCOPE();
-    host_allocator = iree_allocator_system();
-    iree_status_t status = iree_hal_amdgpu_libhsa_initialize(
-        IREE_HAL_AMDGPU_LIBHSA_FLAG_NONE, iree_string_view_list_empty(),
-        host_allocator, &libhsa);
-    if (!iree_status_is_ok(status)) {
-      iree_status_fprint(stderr, status);
-      iree_status_ignore(status);
-      GTEST_SKIP() << "HSA not available, skipping tests";
-    }
-    IREE_ASSERT_OK(
-        iree_hal_amdgpu_topology_initialize_with_defaults(&libhsa, &topology));
-    if (topology.gpu_agent_count == 0) {
-      GTEST_SKIP() << "no GPU devices available, skipping tests";
-    }
-
-    hsa_agent_t cpu_agent = topology.cpu_agents[0];
-    IREE_ASSERT_OK(iree_hal_amdgpu_find_fine_global_memory_pool(
-        &libhsa, cpu_agent, &cpu_memory_pool));
-  }
-
-  static void TearDownTestSuite() {
-    IREE_TRACE_SCOPE();
-    iree_hal_amdgpu_topology_deinitialize(&topology);
-    iree_hal_amdgpu_libhsa_deinitialize(&libhsa);
-  }
-};
-iree_allocator_t SemaphorePoolTest::host_allocator;
-iree_hal_amdgpu_libhsa_t SemaphorePoolTest::libhsa;
-iree_hal_amdgpu_topology_t SemaphorePoolTest::topology;
-hsa_amd_memory_pool_t SemaphorePoolTest::cpu_memory_pool;
-
-// Tests that a pool can be initialized/deinitialized successfully.
-// Note that pools do not allocate anything on initialization so this should
-// never allocate.
-TEST_F(SemaphorePoolTest, Lifetime) {
-  IREE_TRACE_SCOPE();
-
-  iree_hal_amdgpu_semaphore_options_t options = {0};
-  iree_hal_amdgpu_semaphore_pool_t semaphore_pool = {0};
-  IREE_ASSERT_OK(iree_hal_amdgpu_semaphore_pool_initialize(
-      &libhsa, &topology, IREE_HAL_AMDGPU_SEMAPHORE_POOL_DEFAULT_BLOCK_CAPACITY,
-      options, IREE_HAL_SEMAPHORE_FLAG_DEFAULT, host_allocator, cpu_memory_pool,
-      &semaphore_pool));
-
-  // No-op since nothing has been allocated.
-  iree_hal_amdgpu_semaphore_pool_trim(&semaphore_pool);
-
-  iree_hal_amdgpu_semaphore_pool_deinitialize(&semaphore_pool);
-}
-
-// Tests a pool that has preallocation requests.
-// We make a few requests interleaved with trims and then rely on
-// deinitialization to free the remaining resources to ensure there are no
-// leaks.
-TEST_F(SemaphorePoolTest, LifetimePreallocate) {
-  IREE_TRACE_SCOPE();
-
-  iree_hal_amdgpu_semaphore_options_t options = {0};
-  iree_hal_amdgpu_semaphore_pool_t semaphore_pool = {0};
-  IREE_ASSERT_OK(iree_hal_amdgpu_semaphore_pool_initialize(
-      &libhsa, &topology,
-      /*block_capacity=*/32, options, IREE_HAL_SEMAPHORE_FLAG_DEFAULT,
-      host_allocator, cpu_memory_pool, &semaphore_pool));
-
-  // No-op since nothing has been allocated yet.
-  iree_hal_amdgpu_semaphore_pool_trim(&semaphore_pool);
-
-  // No-op preallocation (can happen if we blindly pass options/flags of 0).
-  IREE_ASSERT_OK(
-      iree_hal_amdgpu_semaphore_pool_preallocate(&semaphore_pool, 0));
-
-  // Preallocate one block.
-  IREE_ASSERT_OK(iree_hal_amdgpu_semaphore_pool_preallocate(
-      &semaphore_pool, semaphore_pool.block_capacity));
-
-  // Trim the entire block (nothing is used).
-  iree_hal_amdgpu_semaphore_pool_trim(&semaphore_pool);
-
-  // Preallocate two blocks.
-  IREE_ASSERT_OK(iree_hal_amdgpu_semaphore_pool_preallocate(
-      &semaphore_pool, semaphore_pool.block_capacity + 1));
-
-  // Preallocate one more block (1 buffer ceildiv capacity = 1 block).
-  IREE_ASSERT_OK(
-      iree_hal_amdgpu_semaphore_pool_preallocate(&semaphore_pool, 1));
-
-  // Deinitialize with remaining preallocated blocks to test cleanup.
-  iree_hal_amdgpu_semaphore_pool_deinitialize(&semaphore_pool);
-}
-
-// Tests acquiring and releasing a buffer handle from the pool.
-TEST_F(SemaphorePoolTest, AcquireRelease) {
-  IREE_TRACE_SCOPE();
-
-  iree_hal_amdgpu_semaphore_options_t options = {0};
-  iree_hal_amdgpu_semaphore_pool_t semaphore_pool = {0};
-  IREE_ASSERT_OK(iree_hal_amdgpu_semaphore_pool_initialize(
-      &libhsa, &topology,
-      /*block_capacity=*/32, options, IREE_HAL_SEMAPHORE_FLAG_DEFAULT,
-      host_allocator, cpu_memory_pool, &semaphore_pool));
-
-  // Acquire a semaphore.
-  const uint64_t initial_value = 123ull;
-  iree_hal_semaphore_t* semaphore = NULL;
-  IREE_ASSERT_OK(iree_hal_amdgpu_semaphore_pool_acquire(
-      &semaphore_pool, initial_value, IREE_HAL_SEMAPHORE_FLAG_DEFAULT,
-      &semaphore));
-  ASSERT_NE(semaphore, nullptr);
-
-  // Ensure it reports the initial value that was specified.
-  uint64_t reported_value = 0ull;
-  IREE_ASSERT_OK(iree_hal_semaphore_query(semaphore, &reported_value));
-  EXPECT_EQ(reported_value, initial_value);
-
-  // Ensure the device-visible handle is initialized.
-  iree_hal_amdgpu_device_semaphore_t* handle = NULL;
-  IREE_ASSERT_OK(iree_hal_amdgpu_semaphore_handle(semaphore, &handle));
-  ASSERT_NE(semaphore, nullptr);
-  ASSERT_EQ(handle->host_semaphore, (uint64_t)semaphore);
-
-  // Release the semaphore back to the pool - we're the last reference and it
-  // should be recycled.
-  iree_hal_semaphore_release(semaphore);
-
-  iree_hal_amdgpu_semaphore_pool_deinitialize(&semaphore_pool);
-}
-
-// Explicitly tests pool growth by acquiring an entire block worth of
-// semaphores+1. We then release all the semaphores that should have been in the
-// first block and trim with the second block outstanding to ensure it is not
-// reclaimed with the buffer outstanding.
-TEST_F(SemaphorePoolTest, Growth) {
-  IREE_TRACE_SCOPE();
-
-  iree_hal_amdgpu_semaphore_options_t options = {0};
-  iree_hal_amdgpu_semaphore_pool_t semaphore_pool = {0};
-  IREE_ASSERT_OK(iree_hal_amdgpu_semaphore_pool_initialize(
-      &libhsa, &topology, /*block_capacity=*/32, options,
-      IREE_HAL_SEMAPHORE_FLAG_DEFAULT, host_allocator, cpu_memory_pool,
-      &semaphore_pool));
-  // NOTE: the capacity may be larger than requested due to alignment.
-  const iree_host_size_t block_capacity = semaphore_pool.block_capacity;
-
-  std::vector<iree_hal_semaphore_t*> semaphores(block_capacity);
-
-  // Preallocate the first block (just to put more load on that path).
-  IREE_ASSERT_OK(iree_hal_amdgpu_semaphore_pool_preallocate(&semaphore_pool,
-                                                            block_capacity));
-
-  // Allocate enough to consume the entire first block.
-  for (iree_host_size_t i = 0; i < block_capacity; ++i) {
-    IREE_ASSERT_OK(iree_hal_amdgpu_semaphore_pool_acquire(
-        &semaphore_pool, /*initial_value=*/0ull,
-        IREE_HAL_SEMAPHORE_FLAG_DEFAULT, &semaphores[i]));
-    ASSERT_NE(semaphores[i], nullptr);
-  }
-
-  // Allocate +1 to trigger growth and acquire the next block.
-  iree_hal_semaphore_t* growth_semaphore = NULL;
-  IREE_ASSERT_OK(iree_hal_amdgpu_semaphore_pool_acquire(
-      &semaphore_pool, /*initial_value=*/0ull, IREE_HAL_SEMAPHORE_FLAG_DEFAULT,
-      &growth_semaphore));
-  ASSERT_NE(growth_semaphore, nullptr);
-
-  // Recycle all the semaphores from the first block. After this it should have
-  // no outstanding semaphores allocated it from it and be a candidate for
-  // trimming.
-  for (iree_host_size_t i = 0; i < block_capacity; ++i) {
-    iree_hal_semaphore_release(semaphores[i]);
-  }
-
-  // Trim to drop the unused first block.
-  iree_hal_amdgpu_semaphore_pool_trim(&semaphore_pool);
-
-  // Release the last semaphore and let the deinitialize cleanup the second
-  // block.
-  iree_hal_semaphore_release(growth_semaphore);
-
-  iree_hal_amdgpu_semaphore_pool_deinitialize(&semaphore_pool);
-}
-
-}  // namespace
-}  // namespace iree::hal::amdgpu
diff --git a/runtime/src/iree/hal/drivers/amdgpu/semaphore_test.cc b/runtime/src/iree/hal/drivers/amdgpu/semaphore_test.cc
new file mode 100644
index 0000000..491db91
--- /dev/null
+++ b/runtime/src/iree/hal/drivers/amdgpu/semaphore_test.cc
@@ -0,0 +1,263 @@
+// Copyright 2026 The IREE Authors
+//
+// Licensed under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+
+#include "iree/hal/drivers/amdgpu/semaphore.h"
+
+#include <string.h>
+
+#include <vector>
+
+#include "iree/async/frontier.h"
+#include "iree/async/proactor_platform.h"
+#include "iree/hal/drivers/amdgpu/host_queue_policy.h"
+#include "iree/hal/drivers/amdgpu/logical_device.h"
+#include "iree/hal/drivers/amdgpu/system.h"
+#include "iree/testing/gtest.h"
+#include "iree/testing/status_matchers.h"
+
+namespace {
+
+static iree_async_proactor_t* test_proactor() {
+  static iree_async_proactor_t* proactor = nullptr;
+  if (!proactor) {
+    IREE_CHECK_OK(iree_async_proactor_create_platform(
+        iree_async_proactor_options_default(), iree_allocator_system(),
+        &proactor));
+    atexit([] {
+      iree_async_proactor_release(proactor);
+      proactor = nullptr;
+    });
+  }
+  return proactor;
+}
+
+static iree_async_axis_t test_queue_axis(uint8_t queue_index) {
+  return iree_async_axis_make_queue(/*session_epoch=*/1, /*machine_index=*/0,
+                                    /*device_index=*/0, queue_index);
+}
+
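+// Builds an iree_async_frontier_t over owned storage for test use.
+// NOTE: each Build() call reuses (and may reallocate) the same storage,
+// invalidating any previously returned frontier pointer; tests consume each
+// frontier before building the next.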
+class FrontierBuilder {
+ public:
+  iree_async_frontier_t* Build(
+      std::initializer_list<iree_async_frontier_entry_t> entries) {
+    storage_.resize(sizeof(iree_async_frontier_t) +
+                    entries.size() * sizeof(iree_async_frontier_entry_t));
+    auto* frontier = reinterpret_cast<iree_async_frontier_t*>(storage_.data());
+    iree_async_frontier_initialize(frontier,
+                                   static_cast<uint8_t>(entries.size()));
+    iree_host_size_t i = 0;
+    for (const auto& entry : entries) {
+      frontier->entries[i++] = entry;
+    }
+    return frontier;
+  }
+
+ private:
+  std::vector<uint8_t> storage_;
+};
+
+class SemaphoreTest : public ::testing::Test {
+ protected:
+  void SetUp() override {
+    static uintptr_t fake_device_storage = 0;
+    fake_device_ = reinterpret_cast<iree_hal_amdgpu_logical_device_t*>(
+        &fake_device_storage);
+    IREE_ASSERT_OK(CreateSemaphore(IREE_HAL_SEMAPHORE_FLAG_NONE, &semaphore_));
+  }
+
+  void TearDown() override { iree_hal_semaphore_release(semaphore_); }
+
+  iree_status_t CreateSemaphore(iree_hal_semaphore_flags_t flags,
+                                iree_hal_semaphore_t** out_semaphore) {
+    return iree_hal_amdgpu_semaphore_create(
+        fake_device_, test_proactor(), IREE_HAL_QUEUE_AFFINITY_ANY,
+        /*initial_value=*/0, flags, iree_allocator_system(), out_semaphore);
+  }
+
+  iree_hal_amdgpu_logical_device_t* fake_device_ = nullptr;
+  iree_hal_semaphore_t* semaphore_ = nullptr;
+};
+
+TEST_F(SemaphoreTest, PrivateStreamSemanticsRequireStrictFlags) {
+  iree_hal_semaphore_t* private_semaphore = nullptr;
+  IREE_ASSERT_OK(CreateSemaphore(IREE_HAL_SEMAPHORE_FLAG_DEVICE_LOCAL |
+                                     IREE_HAL_SEMAPHORE_FLAG_SINGLE_PRODUCER,
+                                 &private_semaphore));
+  EXPECT_TRUE(iree_hal_amdgpu_semaphore_has_private_stream_semantics(
+      private_semaphore, fake_device_));
+  iree_hal_semaphore_release(private_semaphore);
+
+  iree_hal_semaphore_t* public_local_semaphore = nullptr;
+  IREE_ASSERT_OK(CreateSemaphore(IREE_HAL_SEMAPHORE_FLAG_DEFAULT |
+                                     IREE_HAL_SEMAPHORE_FLAG_DEVICE_LOCAL |
+                                     IREE_HAL_SEMAPHORE_FLAG_SINGLE_PRODUCER,
+                                 &public_local_semaphore));
+  EXPECT_FALSE(iree_hal_amdgpu_semaphore_has_private_stream_semantics(
+      public_local_semaphore, fake_device_));
+  iree_hal_semaphore_release(public_local_semaphore);
+
+  iree_hal_semaphore_t* multi_producer_semaphore = nullptr;
+  IREE_ASSERT_OK(CreateSemaphore(IREE_HAL_SEMAPHORE_FLAG_DEVICE_LOCAL,
+                                 &multi_producer_semaphore));
+  EXPECT_FALSE(iree_hal_amdgpu_semaphore_has_private_stream_semantics(
+      multi_producer_semaphore, fake_device_));
+  iree_hal_semaphore_release(multi_producer_semaphore);
+}
+
+TEST_F(SemaphoreTest, QueuePolicyUsesAgentScopeOnlyForSamePhysicalDevice) {
+  iree_hal_amdgpu_system_t system;
+  memset(&system, 0, sizeof(system));
+  system.topology.gpu_agent_queue_count = 2;
+
+  iree_hal_amdgpu_logical_device_t logical_device;
+  memset(&logical_device, 0, sizeof(logical_device));
+  logical_device.system = &system;
+  logical_device.physical_device_count = 2;
+  logical_device.queue_affinity_mask = 0xFull;
+
+  iree_hal_amdgpu_host_queue_t queue;
+  memset(&queue, 0, sizeof(queue));
+  queue.logical_device = (iree_hal_device_t*)&logical_device;
+  queue.device_ordinal = 0;
+
+  iree_hal_semaphore_t* same_agent_semaphore = nullptr;
+  IREE_ASSERT_OK(iree_hal_amdgpu_semaphore_create(
+      &logical_device, test_proactor(), /*queue_affinity=*/0x3ull,
+      /*initial_value=*/0, IREE_HAL_SEMAPHORE_FLAG_DEVICE_LOCAL,
+      iree_allocator_system(), &same_agent_semaphore));
+  EXPECT_EQ(iree_hal_amdgpu_host_queue_wait_acquire_scope(&queue,
+                                                          same_agent_semaphore),
+            IREE_HSA_FENCE_SCOPE_AGENT);
+  EXPECT_EQ(iree_hal_amdgpu_host_queue_signal_release_scope(
+                &queue, same_agent_semaphore),
+            IREE_HSA_FENCE_SCOPE_AGENT);
+  iree_hal_semaphore_release(same_agent_semaphore);
+
+  iree_hal_semaphore_t* cross_agent_semaphore = nullptr;
+  IREE_ASSERT_OK(iree_hal_amdgpu_semaphore_create(
+      &logical_device, test_proactor(), /*queue_affinity=*/0x4ull,
+      /*initial_value=*/0, IREE_HAL_SEMAPHORE_FLAG_DEVICE_LOCAL,
+      iree_allocator_system(), &cross_agent_semaphore));
+  EXPECT_EQ(iree_hal_amdgpu_host_queue_wait_acquire_scope(
+                &queue, cross_agent_semaphore),
+            IREE_HSA_FENCE_SCOPE_SYSTEM);
+  EXPECT_EQ(iree_hal_amdgpu_host_queue_signal_release_scope(
+                &queue, cross_agent_semaphore),
+            IREE_HSA_FENCE_SCOPE_SYSTEM);
+  iree_hal_semaphore_release(cross_agent_semaphore);
+
+  iree_hal_semaphore_t* public_semaphore = nullptr;
+  IREE_ASSERT_OK(iree_hal_amdgpu_semaphore_create(
+      &logical_device, test_proactor(), /*queue_affinity=*/0x1ull,
+      /*initial_value=*/0,
+      IREE_HAL_SEMAPHORE_FLAG_DEVICE_LOCAL |
+          IREE_HAL_SEMAPHORE_FLAG_HOST_INTERRUPT,
+      iree_allocator_system(), &public_semaphore));
+  EXPECT_EQ(
+      iree_hal_amdgpu_host_queue_wait_acquire_scope(&queue, public_semaphore),
+      IREE_HSA_FENCE_SCOPE_SYSTEM);
+  EXPECT_EQ(
+      iree_hal_amdgpu_host_queue_signal_release_scope(&queue, public_semaphore),
+      IREE_HSA_FENCE_SCOPE_SYSTEM);
+  iree_hal_semaphore_release(public_semaphore);
+}
+
+TEST_F(SemaphoreTest, PrivateStreamSignalPublishesExactProducerEpoch) {
+  iree_hal_semaphore_t* private_semaphore = nullptr;
+  IREE_ASSERT_OK(CreateSemaphore(IREE_HAL_SEMAPHORE_FLAG_DEVICE_LOCAL |
+                                     IREE_HAL_SEMAPHORE_FLAG_SINGLE_PRODUCER,
+                                 &private_semaphore));
+
+  const iree_async_axis_t producer_axis = test_queue_axis(2);
+  iree_hal_amdgpu_semaphore_publish_private_stream_signal(
+      private_semaphore, producer_axis, /*producer_epoch=*/7,
+      /*producer_value=*/3);
+
+  iree_hal_amdgpu_last_signal_flags_t flags =
+      IREE_HAL_AMDGPU_LAST_SIGNAL_FLAG_NONE;
+  iree_async_axis_t cached_axis = 0;
+  uint64_t cached_epoch = 0;
+  uint64_t cached_value = 0;
+  EXPECT_TRUE(iree_hal_amdgpu_last_signal_load(
+      iree_hal_amdgpu_semaphore_last_signal(private_semaphore), &flags,
+      &cached_axis, &cached_epoch, &cached_value));
+  EXPECT_EQ(cached_axis, producer_axis);
+  EXPECT_EQ(cached_epoch, 7u);
+  EXPECT_EQ(cached_value, 3u);
+  EXPECT_EQ(flags,
+            IREE_HAL_AMDGPU_LAST_SIGNAL_FLAG_VALID |
+                IREE_HAL_AMDGPU_LAST_SIGNAL_FLAG_PRODUCER_FRONTIER_EXACT);
+
+  iree_hal_semaphore_release(private_semaphore);
+}
+
+TEST_F(SemaphoreTest,
+       PublishSignalMarksExactWhenProducerFrontierCoversTransitiveDeps) {
+  const iree_async_axis_t producer_axis = test_queue_axis(2);
+  const iree_async_axis_t peer_axis = test_queue_axis(1);
+
+  FrontierBuilder frontier_builder;
+  iree_async_frontier_t* initial_frontier =
+      frontier_builder.Build({iree_async_frontier_entry_t{peer_axis, 4}});
+  EXPECT_TRUE(iree_hal_amdgpu_semaphore_publish_signal(
+      semaphore_, peer_axis, initial_frontier, /*producer_epoch=*/4,
+      /*producer_value=*/1));
+
+  iree_async_frontier_t* transitive_frontier =
+      frontier_builder.Build({iree_async_frontier_entry_t{peer_axis, 4},
+                              iree_async_frontier_entry_t{producer_axis, 7}});
+  EXPECT_TRUE(iree_hal_amdgpu_semaphore_publish_signal(
+      semaphore_, producer_axis, transitive_frontier, /*producer_epoch=*/7,
+      /*producer_value=*/2));
+
+  iree_hal_amdgpu_last_signal_flags_t flags =
+      IREE_HAL_AMDGPU_LAST_SIGNAL_FLAG_NONE;
+  iree_async_axis_t cached_axis = 0;
+  uint64_t cached_epoch = 0;
+  uint64_t cached_value = 0;
+  EXPECT_TRUE(iree_hal_amdgpu_last_signal_load(
+      iree_hal_amdgpu_semaphore_last_signal(semaphore_), &flags, &cached_axis,
+      &cached_epoch, &cached_value));
+  EXPECT_EQ(cached_axis, producer_axis);
+  EXPECT_EQ(cached_epoch, 7u);
+  EXPECT_EQ(cached_value, 2u);
+  EXPECT_EQ(flags,
+            IREE_HAL_AMDGPU_LAST_SIGNAL_FLAG_VALID |
+                IREE_HAL_AMDGPU_LAST_SIGNAL_FLAG_PRODUCER_FRONTIER_EXACT);
+}
+
+TEST_F(SemaphoreTest, PublishSignalClearsExactForIndependentFanIn) {
+  const iree_async_axis_t first_axis = test_queue_axis(1);
+  const iree_async_axis_t second_axis = test_queue_axis(2);
+
+  FrontierBuilder frontier_builder;
+  iree_async_frontier_t* first_frontier =
+      frontier_builder.Build({iree_async_frontier_entry_t{first_axis, 5}});
+  EXPECT_TRUE(iree_hal_amdgpu_semaphore_publish_signal(
+      semaphore_, first_axis, first_frontier, /*producer_epoch=*/5,
+      /*producer_value=*/1));
+
+  iree_async_frontier_t* second_frontier =
+      frontier_builder.Build({iree_async_frontier_entry_t{second_axis, 9}});
+  EXPECT_TRUE(iree_hal_amdgpu_semaphore_publish_signal(
+      semaphore_, second_axis, second_frontier, /*producer_epoch=*/9,
+      /*producer_value=*/2));
+
+  iree_hal_amdgpu_last_signal_flags_t flags =
+      IREE_HAL_AMDGPU_LAST_SIGNAL_FLAG_NONE;
+  iree_async_axis_t cached_axis = 0;
+  uint64_t cached_epoch = 0;
+  uint64_t cached_value = 0;
+  EXPECT_TRUE(iree_hal_amdgpu_last_signal_load(
+      iree_hal_amdgpu_semaphore_last_signal(semaphore_), &flags, &cached_axis,
+      &cached_epoch, &cached_value));
+  EXPECT_EQ(cached_axis, second_axis);
+  EXPECT_EQ(cached_epoch, 9u);
+  EXPECT_EQ(cached_value, 2u);
+  EXPECT_EQ(flags, IREE_HAL_AMDGPU_LAST_SIGNAL_FLAG_VALID);
+}
+
+}  // namespace
diff --git a/runtime/src/iree/hal/drivers/amdgpu/slab_provider.c b/runtime/src/iree/hal/drivers/amdgpu/slab_provider.c
new file mode 100644
index 0000000..67bc2fe
--- /dev/null
+++ b/runtime/src/iree/hal/drivers/amdgpu/slab_provider.c
@@ -0,0 +1,566 @@
+// Copyright 2026 The IREE Authors
+//
+// Licensed under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+
+#include "iree/hal/drivers/amdgpu/slab_provider.h"
+
+#include <stddef.h>
+
+#include "iree/hal/drivers/amdgpu/access_policy.h"
+#include "iree/hal/drivers/amdgpu/buffer.h"
+#include "iree/hal/drivers/amdgpu/logical_device.h"
+#include "iree/hal/drivers/amdgpu/util/topology.h"
+#include "iree/hal/memory/tracing.h"
+
+typedef struct iree_hal_amdgpu_slab_provider_t {
+  // Base slab-provider interface header.
+  iree_hal_slab_provider_t base;
+
+  // Host allocator used to allocate and free the provider.
+  iree_allocator_t host_allocator;
+
+  // Borrowed HAL device used when wrapping slabs as AMDGPU buffers.
+  iree_hal_device_t* device;
+
+  // Borrowed HSA dispatch table used for memory-pool operations.
+  const iree_hal_amdgpu_libhsa_t* libhsa;
+
+  // Borrowed topology used to allow GPU peers to access acquired slabs.
+  const iree_hal_amdgpu_topology_t* topology;
+
+  // HSA pool this provider acquires slabs from.
+  hsa_amd_memory_pool_t memory_pool;
+
+  // Ordinal of the session-local physical device that owns this provider.
+  uint32_t physical_device_ordinal;
+
+  // Borrowed wrapper pool used for materialized HAL buffer views.
+  iree_hal_amdgpu_buffer_pool_t* buffer_pool;
+
+  // Queue affinities in this provider's physical memory domain.
+  iree_hal_queue_affinity_t queue_affinity_mask;
+
+  // Stable named-memory stream for HSA backing allocations from this provider.
+  iree_hal_memory_trace_t trace;
+
+  // Runtime allocation granule reported by the HSA pool; slab allocation
+  // sizes are rounded up to a multiple of this value.
+  iree_device_size_t allocation_granule;
+
+  // Base-pointer alignment guaranteed by HSA runtime allocations.
+  iree_device_size_t allocation_alignment;
+
+  // HAL memory type bits derived from the HSA pool flags.
+  iree_hal_memory_type_t memory_type;
+
+  // HAL buffer usage bits supported by slabs from the HSA pool.
+  iree_hal_buffer_usage_t supported_usage;
+
+  // Cumulative slab acquisitions reported through query_stats().
+  iree_atomic_int64_t total_acquired;
+
+  // Cumulative slab releases reported through query_stats().
+  iree_atomic_int64_t total_released;
+} iree_hal_amdgpu_slab_provider_t;
+
+typedef struct iree_hal_amdgpu_slab_handle_t {
+  // HSA allocation byte length used when releasing the slab.
+  iree_device_size_t allocation_size;
+
+  // Profiling session id owning |profile_allocation_id|.
+  uint64_t profile_session_id;
+
+  // Session-local allocation id for the slab acquire/release lifecycle.
+  uint64_t profile_allocation_id;
+} iree_hal_amdgpu_slab_handle_t;
+
+static const iree_hal_slab_provider_vtable_t
+    iree_hal_amdgpu_slab_provider_vtable;
+
+static const char* IREE_HAL_AMDGPU_SLAB_PROVIDER_TRACE_ID =
+    "iree-hal-amdgpu-slab-provider";
+
+static iree_hal_amdgpu_slab_provider_t* iree_hal_amdgpu_slab_provider_cast(
+    iree_hal_slab_provider_t* base_provider) {
+  return (iree_hal_amdgpu_slab_provider_t*)base_provider;
+}
+
+static const iree_hal_amdgpu_slab_provider_t*
+iree_hal_amdgpu_slab_provider_const_cast(
+    const iree_hal_slab_provider_t* base_provider) {
+  return (const iree_hal_amdgpu_slab_provider_t*)base_provider;
+}
+
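+// Records a slab acquire/release event with the owning device's profiling
+// machinery. Acquire events allocate a fresh session-local allocation id and
+// stash it on |slab_handle| so the matching release can reference the same id;
+// release events with no recorded acquire are dropped.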
+static bool iree_hal_amdgpu_slab_provider_record_memory_event(
+    iree_hal_amdgpu_slab_provider_t* provider,
+    iree_hal_profile_memory_event_type_t type,
+    iree_hal_amdgpu_slab_handle_t* slab_handle, const void* backing_ptr) {
+  uint64_t session_id = slab_handle->profile_session_id;
+  uint64_t allocation_id = slab_handle->profile_allocation_id;
+  if (type == IREE_HAL_PROFILE_MEMORY_EVENT_TYPE_SLAB_ACQUIRE) {
+    allocation_id =
+        iree_hal_amdgpu_logical_device_allocate_profile_memory_allocation_id(
+            provider->device, &session_id);
+    if (allocation_id == 0) return false;
+  } else if (allocation_id == 0) {
+    return false;
+  }
+
+  iree_hal_profile_memory_event_t event =
+      iree_hal_profile_memory_event_default();
+  event.type = type;
+  event.allocation_id = allocation_id;
+  event.pool_id = (uint64_t)(uintptr_t)provider;
+  event.backing_id = (uint64_t)(uintptr_t)backing_ptr;
+  event.physical_device_ordinal = provider->physical_device_ordinal;
+  event.memory_type = provider->memory_type;
+  event.buffer_usage = provider->supported_usage;
+  event.length = slab_handle->allocation_size;
+  event.alignment = provider->allocation_alignment;
+  const bool recorded =
+      iree_hal_amdgpu_logical_device_record_profile_memory_event_for_session(
+          provider->device, session_id, &event);
+  if (recorded && type == IREE_HAL_PROFILE_MEMORY_EVENT_TYPE_SLAB_ACQUIRE) {
+    slab_handle->profile_session_id = session_id;
+    slab_handle->profile_allocation_id = allocation_id;
+  }
+  return recorded;
+}
+
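+// Resolves the HSA agents that must be granted access to newly acquired slabs
+// based on the provider's queue affinity mask and the system topology.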
+static iree_status_t iree_hal_amdgpu_slab_provider_resolve_access_agents(
+    const iree_hal_amdgpu_slab_provider_t* provider,
+    iree_hal_amdgpu_access_agent_list_t* out_agent_list) {
+  const iree_hal_amdgpu_queue_affinity_domain_t domain = {
+      .supported_affinity = provider->queue_affinity_mask,
+      .physical_device_count = provider->topology->gpu_agent_count,
+      .queue_count_per_physical_device =
+          provider->topology->gpu_agent_queue_count,
+  };
+  return iree_hal_amdgpu_access_agent_list_resolve(
+      provider->topology, domain, provider->queue_affinity_mask,
+      out_agent_list);
+}
+
+IREE_API_EXPORT iree_status_t
+iree_hal_amdgpu_slab_provider_query_memory_pool_properties(
+    const iree_hal_amdgpu_libhsa_t* libhsa, hsa_amd_memory_pool_t memory_pool,
+    iree_hal_amdgpu_slab_provider_memory_pool_properties_t* out_properties) {
+  IREE_ASSERT_ARGUMENT(libhsa);
+  IREE_ASSERT_ARGUMENT(out_properties);
+  memset(out_properties, 0, sizeof(*out_properties));
+
+  size_t allocation_granule = 0;
+  IREE_RETURN_IF_ERROR(
+      iree_hsa_amd_memory_pool_get_info(
+          IREE_LIBHSA(libhsa), memory_pool,
+          HSA_AMD_MEMORY_POOL_INFO_RUNTIME_ALLOC_GRANULE, &allocation_granule),
+      "querying HSA_AMD_MEMORY_POOL_INFO_RUNTIME_ALLOC_GRANULE for an AMDGPU "
+      "slab provider");
+  if (allocation_granule == 0 ||
+      !iree_device_size_is_power_of_two(allocation_granule)) {
+    return iree_make_status(
+        IREE_STATUS_INTERNAL,
+        "invalid HSA runtime allocation granule for an AMDGPU memory pool: "
+        "%" PRIhsz,
+        (iree_host_size_t)allocation_granule);
+  }
+
+  size_t allocation_alignment = 0;
+  IREE_RETURN_IF_ERROR(iree_hsa_amd_memory_pool_get_info(
+      IREE_LIBHSA(libhsa), memory_pool,
+      HSA_AMD_MEMORY_POOL_INFO_RUNTIME_ALLOC_ALIGNMENT, &allocation_alignment));
+  if (allocation_alignment == 0 ||
+      !iree_device_size_is_power_of_two(allocation_alignment)) {
+    return iree_make_status(
+        IREE_STATUS_INTERNAL,
+        "invalid HSA runtime allocation alignment for an AMDGPU memory pool: "
+        "%" PRIhsz,
+        (iree_host_size_t)allocation_alignment);
+  }
+
+  hsa_region_segment_t segment = 0;
+  IREE_RETURN_IF_ERROR(iree_hsa_amd_memory_pool_get_info(
+      IREE_LIBHSA(libhsa), memory_pool, HSA_AMD_MEMORY_POOL_INFO_SEGMENT,
+      &segment));
+  if (segment != HSA_REGION_SEGMENT_GLOBAL) {
+    return iree_make_status(IREE_STATUS_INVALID_ARGUMENT,
+                            "AMDGPU slab providers require a GLOBAL HSA pool");
+  }
+
+  bool alloc_allowed = false;
+  IREE_RETURN_IF_ERROR(iree_hsa_amd_memory_pool_get_info(
+      IREE_LIBHSA(libhsa), memory_pool,
+      HSA_AMD_MEMORY_POOL_INFO_RUNTIME_ALLOC_ALLOWED, &alloc_allowed));
+  if (!alloc_allowed) {
+    return iree_make_status(
+        IREE_STATUS_INVALID_ARGUMENT,
+        "AMDGPU slab provider memory pool does not support runtime "
+        "allocations");
+  }
+
+  hsa_region_global_flag_t global_flags = 0;
+  IREE_RETURN_IF_ERROR(iree_hsa_amd_memory_pool_get_info(
+      IREE_LIBHSA(libhsa), memory_pool, HSA_AMD_MEMORY_POOL_INFO_GLOBAL_FLAGS,
+      &global_flags));
+
+  iree_hal_memory_type_t memory_type = IREE_HAL_MEMORY_TYPE_DEVICE_LOCAL;
+  // Sharing hints do not affect HSA pool selection. Export is omitted because
+  // it requires dedicated platform export support.
+  const iree_hal_buffer_usage_t sharing_usage =
+      IREE_HAL_BUFFER_USAGE_SHARING_REPLICATE |
+      IREE_HAL_BUFFER_USAGE_SHARING_CONCURRENT |
+      IREE_HAL_BUFFER_USAGE_SHARING_IMMUTABLE;
+  iree_hal_buffer_usage_t supported_usage = IREE_HAL_BUFFER_USAGE_TRANSFER |
+                                            IREE_HAL_BUFFER_USAGE_DISPATCH |
+                                            sharing_usage;
+  if (iree_any_bit_set(
+          global_flags,
+          HSA_AMD_MEMORY_POOL_GLOBAL_FLAG_FINE_GRAINED |
+              HSA_AMD_MEMORY_POOL_GLOBAL_FLAG_EXTENDED_SCOPE_FINE_GRAINED)) {
+    memory_type |=
+        IREE_HAL_MEMORY_TYPE_HOST_VISIBLE | IREE_HAL_MEMORY_TYPE_HOST_COHERENT;
+    supported_usage |= IREE_HAL_BUFFER_USAGE_MAPPING_SCOPED |
+                       IREE_HAL_BUFFER_USAGE_MAPPING_PERSISTENT |
+                       IREE_HAL_BUFFER_USAGE_MAPPING_OPTIONAL |
+                       IREE_HAL_BUFFER_USAGE_MAPPING_ACCESS_RANDOM |
+                       IREE_HAL_BUFFER_USAGE_MAPPING_ACCESS_SEQUENTIAL_WRITE;
+  }
+
+  out_properties->allocation_granule = (iree_device_size_t)allocation_granule;
+  out_properties->allocation_alignment =
+      (iree_device_size_t)allocation_alignment;
+  out_properties->memory_type = memory_type;
+  out_properties->supported_usage = supported_usage;
+  return iree_ok_status();
+}
+
+iree_status_t iree_hal_amdgpu_slab_provider_create(
+    iree_hal_device_t* device, const iree_hal_amdgpu_libhsa_t* libhsa,
+    const iree_hal_amdgpu_topology_t* topology,
+    iree_hal_amdgpu_slab_provider_options_t options,
+    iree_host_size_t physical_device_ordinal,
+    iree_hal_queue_affinity_t queue_affinity_mask,
+    iree_hal_amdgpu_buffer_pool_t* buffer_pool, iree_string_view_t trace_name,
+    iree_allocator_t host_allocator, iree_hal_slab_provider_t** out_provider) {
+  IREE_ASSERT_ARGUMENT(device);
+  IREE_ASSERT_ARGUMENT(libhsa);
+  IREE_ASSERT_ARGUMENT(topology);
+  IREE_ASSERT_ARGUMENT(buffer_pool);
+  IREE_ASSERT_ARGUMENT(out_provider);
+  IREE_TRACE_ZONE_BEGIN(z0);
+  *out_provider = NULL;
+
+  if (IREE_UNLIKELY(iree_hal_queue_affinity_is_empty(queue_affinity_mask))) {
+    IREE_TRACE_ZONE_END(z0);
+    return iree_make_status(IREE_STATUS_INVALID_ARGUMENT,
+                            "AMDGPU slab provider queue affinity mask must "
+                            "not be empty");
+  }
+  if (IREE_UNLIKELY(physical_device_ordinal > UINT32_MAX)) {
+    IREE_TRACE_ZONE_END(z0);
+    return iree_make_status(
+        IREE_STATUS_OUT_OF_RANGE,
+        "AMDGPU slab provider physical device ordinal out of range: %" PRIhsz,
+        physical_device_ordinal);
+  }
+  if (IREE_UNLIKELY(!options.memory_pool.handle)) {
+    IREE_TRACE_ZONE_END(z0);
+    return iree_make_status(IREE_STATUS_INVALID_ARGUMENT,
+                            "AMDGPU slab provider requires an HSA memory pool");
+  }
+  if (IREE_UNLIKELY(options.memory_type == IREE_HAL_MEMORY_TYPE_NONE)) {
+    IREE_TRACE_ZONE_END(z0);
+    return iree_make_status(
+        IREE_STATUS_INVALID_ARGUMENT,
+        "AMDGPU slab provider requires non-empty HAL memory type bits");
+  }
+  if (IREE_UNLIKELY(options.supported_usage == IREE_HAL_BUFFER_USAGE_NONE)) {
+    IREE_TRACE_ZONE_END(z0);
+    return iree_make_status(
+        IREE_STATUS_INVALID_ARGUMENT,
+        "AMDGPU slab provider requires non-empty HAL buffer usage bits");
+  }
+
+  iree_hal_amdgpu_slab_provider_t* provider = NULL;
+  IREE_RETURN_AND_END_ZONE_IF_ERROR(
+      z0, iree_allocator_malloc(host_allocator, sizeof(*provider),
+                                (void**)&provider));
+  memset(provider, 0, sizeof(*provider));
+  iree_hal_slab_provider_initialize(&iree_hal_amdgpu_slab_provider_vtable,
+                                    &provider->base);
+  provider->host_allocator = host_allocator;
+  provider->device = device;
+  provider->libhsa = libhsa;
+  provider->topology = topology;
+  provider->memory_pool = options.memory_pool;
+  provider->physical_device_ordinal = (uint32_t)physical_device_ordinal;
+  provider->buffer_pool = buffer_pool;
+  provider->queue_affinity_mask = queue_affinity_mask;
+  provider->total_acquired = IREE_ATOMIC_VAR_INIT(0);
+  provider->total_released = IREE_ATOMIC_VAR_INIT(0);
+
+  iree_status_t status = iree_hal_memory_trace_initialize(
+      trace_name, IREE_HAL_AMDGPU_SLAB_PROVIDER_TRACE_ID, host_allocator,
+      &provider->trace);
+  iree_hal_amdgpu_slab_provider_memory_pool_properties_t properties;
+  if (iree_status_is_ok(status)) {
+    status = iree_hal_amdgpu_slab_provider_query_memory_pool_properties(
+        libhsa, options.memory_pool, &properties);
+  }
+  if (iree_status_is_ok(status)) {
+    provider->allocation_granule = properties.allocation_granule;
+    provider->allocation_alignment = properties.allocation_alignment;
+    provider->memory_type = options.memory_type;
+    provider->supported_usage = options.supported_usage;
+    *out_provider = &provider->base;
+  } else {
+    iree_hal_slab_provider_release(&provider->base);
+  }
+
+  IREE_TRACE_ZONE_END(z0);
+  return status;
+}
+
+static void iree_hal_amdgpu_slab_provider_destroy(
+    iree_hal_slab_provider_t* base_provider) {
+  iree_hal_amdgpu_slab_provider_t* provider =
+      iree_hal_amdgpu_slab_provider_cast(base_provider);
+  iree_allocator_t host_allocator = provider->host_allocator;
+  IREE_TRACE_ZONE_BEGIN(z0);
+  iree_hal_memory_trace_deinitialize(&provider->trace);
+  iree_allocator_free(host_allocator, provider);
+  IREE_TRACE_ZONE_END(z0);
+}
+
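+// Acquires a new slab by rounding |min_length| up to the pool's allocation
+// granule, allocating from the HSA memory pool, and granting access to all
+// agents resolved from the topology before publishing the slab to the caller.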
+static iree_status_t iree_hal_amdgpu_slab_provider_acquire_slab(
+    iree_hal_slab_provider_t* base_provider, iree_device_size_t min_length,
+    iree_hal_slab_t* out_slab) {
+  IREE_ASSERT_ARGUMENT(out_slab);
+  iree_hal_amdgpu_slab_provider_t* provider =
+      iree_hal_amdgpu_slab_provider_cast(base_provider);
+  IREE_TRACE_ZONE_BEGIN(z0);
+  memset(out_slab, 0, sizeof(*out_slab));
+
+  if (IREE_UNLIKELY(min_length == 0)) {
+    IREE_RETURN_AND_END_ZONE_IF_ERROR(
+        z0, iree_make_status(IREE_STATUS_INVALID_ARGUMENT,
+                             "AMDGPU slab allocations must be non-empty"));
+  }
+  iree_device_size_t allocation_size = 0;
+  if (IREE_UNLIKELY(!iree_device_size_checked_align(
+          min_length, provider->allocation_granule, &allocation_size))) {
+    IREE_RETURN_AND_END_ZONE_IF_ERROR(
+        z0, iree_make_status(
+                IREE_STATUS_OUT_OF_RANGE,
+                "AMDGPU slab allocation size overflow aligning %" PRIu64
+                " bytes to a %" PRIu64 "-byte HSA allocation granule",
+                (uint64_t)min_length, (uint64_t)provider->allocation_granule));
+  }
+
+  iree_hal_amdgpu_slab_handle_t* slab_handle = NULL;
+  iree_status_t status = iree_allocator_malloc(
+      provider->host_allocator, sizeof(*slab_handle), (void**)&slab_handle);
+  if (iree_status_is_ok(status)) {
+    memset(slab_handle, 0, sizeof(*slab_handle));
+    slab_handle->allocation_size = allocation_size;
+  }
+
+  void* base_ptr = NULL;
+  if (iree_status_is_ok(status)) {
+    status = iree_hsa_amd_memory_pool_allocate(
+        IREE_LIBHSA(provider->libhsa), provider->memory_pool,
+        (size_t)allocation_size, HSA_AMD_MEMORY_POOL_STANDARD_FLAG, &base_ptr);
+  }
+  if (iree_status_is_ok(status)) {
+    iree_hal_amdgpu_access_agent_list_t access_agents;
+    status = iree_hal_amdgpu_slab_provider_resolve_access_agents(
+        provider, &access_agents);
+    if (iree_status_is_ok(status)) {
+      status = iree_hal_amdgpu_access_allow_agent_list(
+          provider->libhsa, &access_agents, base_ptr);
+    }
+  }
+
+  if (iree_status_is_ok(status)) {
+    out_slab->base_ptr = (uint8_t*)base_ptr;
+    // Preserve the requested logical slab length. The underlying HSA
+    // allocation may be larger due to runtime granule rounding, but exposing
+    // that padding here would incorrectly inflate HAL buffer byte lengths in
+    // pass-through pools.
+    out_slab->length = min_length;
+    out_slab->provider_handle = (uint64_t)(uintptr_t)slab_handle;
+    iree_hal_memory_trace_alloc(&provider->trace, base_ptr, allocation_size);
+    iree_atomic_fetch_add(&provider->total_acquired, 1,
+                          iree_memory_order_relaxed);
+    iree_hal_amdgpu_slab_provider_record_memory_event(
+        provider, IREE_HAL_PROFILE_MEMORY_EVENT_TYPE_SLAB_ACQUIRE, slab_handle,
+        base_ptr);
+  } else if (base_ptr) {
+    status = iree_status_join(
+        status,
+        iree_hsa_amd_memory_pool_free(IREE_LIBHSA(provider->libhsa), base_ptr));
+  }
+  if (!iree_status_is_ok(status)) {
+    iree_allocator_free(provider->host_allocator, slab_handle);
+  }
+
+  IREE_TRACE_ZONE_END(z0);
+  return status;
+}
+
+static void iree_hal_amdgpu_slab_provider_release_slab(
+    iree_hal_slab_provider_t* base_provider, const iree_hal_slab_t* slab) {
+  IREE_ASSERT_ARGUMENT(slab);
+  iree_hal_amdgpu_slab_provider_t* provider =
+      iree_hal_amdgpu_slab_provider_cast(base_provider);
+  IREE_TRACE_ZONE_BEGIN(z0);
+  if (slab->base_ptr) {
+    iree_hal_amdgpu_slab_handle_t* slab_handle =
+        (iree_hal_amdgpu_slab_handle_t*)(uintptr_t)slab->provider_handle;
+    iree_hal_amdgpu_slab_provider_record_memory_event(
+        provider, IREE_HAL_PROFILE_MEMORY_EVENT_TYPE_SLAB_RELEASE, slab_handle,
+        slab->base_ptr);
+    iree_hal_memory_trace_free(&provider->trace, slab->base_ptr);
+    iree_hal_amdgpu_hsa_cleanup_assert_success(
+        iree_hsa_amd_memory_pool_free_raw(provider->libhsa, slab->base_ptr));
+    iree_allocator_free(provider->host_allocator, slab_handle);
+    iree_atomic_fetch_add(&provider->total_released, 1,
+                          iree_memory_order_relaxed);
+  }
+  IREE_TRACE_ZONE_END(z0);
+}
+
+static void iree_hal_amdgpu_slab_provider_borrowed_buffer_release(
+    void* user_data, iree_hal_buffer_t* buffer) {
+  (void)user_data;
+  (void)buffer;
+}
+
+static iree_status_t iree_hal_amdgpu_slab_provider_wrap_buffer(
+    iree_hal_slab_provider_t* base_provider, const iree_hal_slab_t* slab,
+    iree_device_size_t slab_offset, iree_device_size_t allocation_size,
+    iree_hal_buffer_params_t params,
+    iree_hal_buffer_release_callback_t release_callback,
+    iree_hal_buffer_t** out_buffer) {
+  iree_hal_amdgpu_slab_provider_t* provider =
+      iree_hal_amdgpu_slab_provider_cast(base_provider);
+
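+  // Substitute the provider's concrete memory type for any OPTIMAL placeholder
+  // bits before validating the request against what the pool can satisfy.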
+  iree_hal_memory_type_t resolved_type = params.type;
+  if (iree_any_bit_set(resolved_type, IREE_HAL_MEMORY_TYPE_OPTIMAL)) {
+    resolved_type &= ~IREE_HAL_MEMORY_TYPE_OPTIMAL;
+    resolved_type |= provider->memory_type;
+  }
+  if (IREE_UNLIKELY(!iree_all_bits_set(provider->memory_type, resolved_type))) {
+#if IREE_STATUS_MODE
+    iree_bitfield_string_temp_t actual_temp;
+    iree_bitfield_string_temp_t requested_temp;
+    iree_string_view_t actual_string =
+        iree_hal_memory_type_format(provider->memory_type, &actual_temp);
+    iree_string_view_t requested_string =
+        iree_hal_memory_type_format(resolved_type, &requested_temp);
+    return iree_make_status(
+        IREE_STATUS_INVALID_ARGUMENT,
+        "AMDGPU slab provider memory type %.*s does not satisfy requested "
+        "buffer type %.*s",
+        (int)actual_string.size, actual_string.data, (int)requested_string.size,
+        requested_string.data);
+#else
+    return iree_status_from_code(IREE_STATUS_INVALID_ARGUMENT);
+#endif  // IREE_STATUS_MODE
+  }
+  if (IREE_UNLIKELY(
+          !iree_all_bits_set(provider->supported_usage, params.usage))) {
+#if IREE_STATUS_MODE
+    iree_bitfield_string_temp_t actual_temp;
+    iree_bitfield_string_temp_t requested_temp;
+    iree_string_view_t actual_string =
+        iree_hal_buffer_usage_format(provider->supported_usage, &actual_temp);
+    iree_string_view_t requested_string =
+        iree_hal_buffer_usage_format(params.usage, &requested_temp);
+    return iree_make_status(
+        IREE_STATUS_INVALID_ARGUMENT,
+        "AMDGPU slab provider usage %.*s does not satisfy requested buffer "
+        "usage %.*s",
+        (int)actual_string.size, actual_string.data, (int)requested_string.size,
+        requested_string.data);
+#else
+    return iree_status_from_code(IREE_STATUS_INVALID_ARGUMENT);
+#endif  // IREE_STATUS_MODE
+  }
+
+  iree_hal_queue_affinity_t queue_affinity = params.queue_affinity;
+  if (queue_affinity == IREE_HAL_QUEUE_AFFINITY_ANY) {
+    queue_affinity = provider->queue_affinity_mask;
+  } else if (IREE_UNLIKELY(!iree_all_bits_set(provider->queue_affinity_mask,
+                                              queue_affinity))) {
+    return iree_make_status(
+        IREE_STATUS_INVALID_ARGUMENT,
+        "AMDGPU slab provider queue affinity 0x%016" PRIx64
+        " does not cover requested buffer affinity 0x%016" PRIx64,
+        provider->queue_affinity_mask, queue_affinity);
+  }
+
+  const iree_hal_buffer_placement_t placement = {
+      .device = provider->device,
+      .queue_affinity = queue_affinity,
+      .flags = IREE_HAL_BUFFER_PLACEMENT_FLAG_NONE,
+  };
+  if (!release_callback.fn) {
+    release_callback.fn = iree_hal_amdgpu_slab_provider_borrowed_buffer_release;
+  }
+  return iree_hal_amdgpu_buffer_create_pooled(
+      provider->libhsa, placement, resolved_type, params.access, params.usage,
+      allocation_size, allocation_size, slab->base_ptr + slab_offset,
+      release_callback, provider->buffer_pool, provider->host_allocator,
+      out_buffer);
+}
+
+static void iree_hal_amdgpu_slab_provider_prefault(
+    iree_hal_slab_provider_t* base_provider, iree_hal_slab_t* slab) {
+  (void)base_provider;
+  (void)slab;
+}
+
+static void iree_hal_amdgpu_slab_provider_trim(
+    iree_hal_slab_provider_t* base_provider,
+    iree_hal_slab_provider_trim_flags_t flags) {
+  (void)base_provider;
+  (void)flags;
+}
+
+static void iree_hal_amdgpu_slab_provider_query_stats(
+    const iree_hal_slab_provider_t* base_provider,
+    iree_hal_slab_provider_visited_set_t* visited,
+    iree_hal_slab_provider_stats_t* out_stats) {
+  if (iree_hal_slab_provider_visited(visited, base_provider)) {
+    return;
+  }
+  const iree_hal_amdgpu_slab_provider_t* provider =
+      iree_hal_amdgpu_slab_provider_const_cast(base_provider);
+  out_stats->total_acquired += (uint64_t)iree_atomic_load(
+      &provider->total_acquired, iree_memory_order_relaxed);
+  out_stats->total_released += (uint64_t)iree_atomic_load(
+      &provider->total_released, iree_memory_order_relaxed);
+}
+
+static void iree_hal_amdgpu_slab_provider_query_properties(
+    const iree_hal_slab_provider_t* base_provider,
+    iree_hal_memory_type_t* out_memory_type,
+    iree_hal_buffer_usage_t* out_supported_usage) {
+  const iree_hal_amdgpu_slab_provider_t* provider =
+      iree_hal_amdgpu_slab_provider_const_cast(base_provider);
+  *out_memory_type = provider->memory_type;
+  *out_supported_usage = provider->supported_usage;
+}
+
+static const iree_hal_slab_provider_vtable_t
+    iree_hal_amdgpu_slab_provider_vtable = {
+        .destroy = iree_hal_amdgpu_slab_provider_destroy,
+        .acquire_slab = iree_hal_amdgpu_slab_provider_acquire_slab,
+        .release_slab = iree_hal_amdgpu_slab_provider_release_slab,
+        .wrap_buffer = iree_hal_amdgpu_slab_provider_wrap_buffer,
+        .prefault = iree_hal_amdgpu_slab_provider_prefault,
+        .trim = iree_hal_amdgpu_slab_provider_trim,
+        .query_stats = iree_hal_amdgpu_slab_provider_query_stats,
+        .query_properties = iree_hal_amdgpu_slab_provider_query_properties,
+};
diff --git a/runtime/src/iree/hal/drivers/amdgpu/slab_provider.h b/runtime/src/iree/hal/drivers/amdgpu/slab_provider.h
new file mode 100644
index 0000000..dff2029
--- /dev/null
+++ b/runtime/src/iree/hal/drivers/amdgpu/slab_provider.h
@@ -0,0 +1,84 @@
+// Copyright 2026 The IREE Authors
+//
+// Licensed under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+
+#ifndef IREE_HAL_DRIVERS_AMDGPU_SLAB_PROVIDER_H_
+#define IREE_HAL_DRIVERS_AMDGPU_SLAB_PROVIDER_H_
+
+#include "iree/base/api.h"
+#include "iree/hal/api.h"
+#include "iree/hal/drivers/amdgpu/buffer.h"
+#include "iree/hal/drivers/amdgpu/util/libhsa.h"
+#include "iree/hal/memory/slab_provider.h"
+
+#ifdef __cplusplus
+extern "C" {
+#endif  // __cplusplus
+
+typedef struct iree_hal_amdgpu_topology_t iree_hal_amdgpu_topology_t;
+
+//===----------------------------------------------------------------------===//
+// iree_hal_amdgpu_slab_provider_t
+//===----------------------------------------------------------------------===//
+
+// HSA memory-pool properties needed to configure slab-backed HAL pools.
+typedef struct iree_hal_amdgpu_slab_provider_memory_pool_properties_t {
+  // Allocation granule of hsa_amd_memory_pool_allocate: requested sizes must
+  // be a multiple of this value.
+  iree_device_size_t allocation_granule;
+
+  // Base-pointer alignment guaranteed by hsa_amd_memory_pool_allocate.
+  iree_device_size_t allocation_alignment;
+
+  // HAL memory type bits provided by this HSA memory pool.
+  iree_hal_memory_type_t memory_type;
+
+  // HAL buffer usage bits supported by buffers materialized from this pool.
+  iree_hal_buffer_usage_t supported_usage;
+} iree_hal_amdgpu_slab_provider_memory_pool_properties_t;
+
+// HSA memory pool and HAL capabilities exposed by a slab provider.
+typedef struct iree_hal_amdgpu_slab_provider_options_t {
+  // HSA memory pool used for slab allocations.
+  hsa_amd_memory_pool_t memory_pool;
+
+  // HAL memory type bits reported for buffers materialized from slabs.
+  iree_hal_memory_type_t memory_type;
+
+  // HAL buffer usage bits supported by buffers materialized from slabs.
+  iree_hal_buffer_usage_t supported_usage;
+} iree_hal_amdgpu_slab_provider_options_t;
+
+// Queries HSA memory-pool properties used by AMDGPU slab providers and pools.
+iree_status_t iree_hal_amdgpu_slab_provider_query_memory_pool_properties(
+    const iree_hal_amdgpu_libhsa_t* libhsa, hsa_amd_memory_pool_t memory_pool,
+    iree_hal_amdgpu_slab_provider_memory_pool_properties_t* out_properties);
+
+// Creates a slab provider backed by an HSA memory pool on one GPU agent.
+//
+// The provider acquires whole slabs with hsa_amd_memory_pool_allocate(), grants
+// every agent in |topology| access to the slab, and wraps slab slices as
+// iree_hal_amdgpu_buffer_t views. |queue_affinity_mask| identifies the HAL
+// queues in this physical memory domain; wrap_buffer() replaces
+// IREE_HAL_QUEUE_AFFINITY_ANY with that mask and rejects explicit affinities
+// outside it so placement metadata always routes PREFER_ORIGIN dealloca back
+// into the provider's domain. Materialized buffer view wrappers are allocated
+// from |buffer_pool|, which must be in the same physical-device lifetime domain
+// as the backing HSA memory. The provider borrows |device|, |libhsa|,
+// |topology|, and |buffer_pool|; the owning physical/logical device must
+// outlive the provider and every pool/buffer created from it.
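+//
+// Usage sketch (illustrative; |agent_pool|, |queue_affinity_mask|, and friends
+// are assumed to come from physical-device bring-up):
+//   iree_hal_amdgpu_slab_provider_memory_pool_properties_t properties;
+//   IREE_RETURN_IF_ERROR(
+//       iree_hal_amdgpu_slab_provider_query_memory_pool_properties(
+//           libhsa, agent_pool, &properties));
+//   iree_hal_amdgpu_slab_provider_options_t options = {
+//       .memory_pool = agent_pool,
+//       .memory_type = properties.memory_type,
+//       .supported_usage = properties.supported_usage,
+//   };
+//   iree_hal_slab_provider_t* provider = NULL;
+//   IREE_RETURN_IF_ERROR(iree_hal_amdgpu_slab_provider_create(
+//       device, libhsa, topology, options, /*physical_device_ordinal=*/0,
+//       queue_affinity_mask, buffer_pool, IREE_SV("device0"), host_allocator,
+//       &provider));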
+iree_status_t iree_hal_amdgpu_slab_provider_create(
+    iree_hal_device_t* device, const iree_hal_amdgpu_libhsa_t* libhsa,
+    const iree_hal_amdgpu_topology_t* topology,
+    iree_hal_amdgpu_slab_provider_options_t options,
+    iree_host_size_t physical_device_ordinal,
+    iree_hal_queue_affinity_t queue_affinity_mask,
+    iree_hal_amdgpu_buffer_pool_t* buffer_pool, iree_string_view_t trace_name,
+    iree_allocator_t host_allocator, iree_hal_slab_provider_t** out_provider);
+
+#ifdef __cplusplus
+}  // extern "C"
+#endif  // __cplusplus
+
+#endif  // IREE_HAL_DRIVERS_AMDGPU_SLAB_PROVIDER_H_
diff --git a/runtime/src/iree/hal/drivers/amdgpu/slab_provider_test.cc b/runtime/src/iree/hal/drivers/amdgpu/slab_provider_test.cc
new file mode 100644
index 0000000..a65d445
--- /dev/null
+++ b/runtime/src/iree/hal/drivers/amdgpu/slab_provider_test.cc
@@ -0,0 +1,245 @@
+// Copyright 2026 The IREE Authors
+//
+// Licensed under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+
+#include "iree/hal/api.h"
+#include "iree/hal/cts/util/test_base.h"
+#include "iree/hal/drivers/amdgpu/logical_device.h"
+#include "iree/hal/drivers/amdgpu/physical_device.h"
+#include "iree/testing/gtest.h"
+#include "iree/testing/status_matchers.h"
+
+namespace iree::hal::amdgpu {
+namespace {
+
+class SlabProviderTest : public ::testing::Test {
+ protected:
+  static void SetUpTestSuite() {
+    host_allocator_ = iree_allocator_system();
+    iree_status_t status = iree_hal_amdgpu_libhsa_initialize(
+        IREE_HAL_AMDGPU_LIBHSA_FLAG_NONE, iree_string_view_list_empty(),
+        host_allocator_, &libhsa_);
+    if (!iree_status_is_ok(status)) {
+      iree_status_fprint(stderr, status);
+      iree_status_free(status);
+      GTEST_SKIP() << "HSA not available, skipping tests";
+    }
+    IREE_ASSERT_OK(iree_hal_amdgpu_topology_initialize_with_defaults(
+        &libhsa_, &topology_));
+    if (topology_.gpu_agent_count == 0) {
+      GTEST_SKIP() << "no GPU devices available, skipping tests";
+    }
+  }
+
+  static void TearDownTestSuite() {
+    iree_hal_amdgpu_topology_deinitialize(&topology_);
+    iree_hal_amdgpu_libhsa_deinitialize(&libhsa_);
+  }
+
+  static iree_allocator_t host_allocator_;
+  static iree_hal_amdgpu_libhsa_t libhsa_;
+  static iree_hal_amdgpu_topology_t topology_;
+};
+
+iree_allocator_t SlabProviderTest::host_allocator_;
+iree_hal_amdgpu_libhsa_t SlabProviderTest::libhsa_;
+iree_hal_amdgpu_topology_t SlabProviderTest::topology_;
+
+class TestLogicalDevice {
+ public:
+  ~TestLogicalDevice() {
+    iree_hal_device_release(base_device);
+    iree_hal_device_group_release(device_group);
+  }
+
+  iree_status_t Initialize(
+      const iree_hal_amdgpu_logical_device_options_t* options,
+      const iree_hal_amdgpu_libhsa_t* libhsa,
+      const iree_hal_amdgpu_topology_t* topology,
+      iree_allocator_t host_allocator) {
+    IREE_RETURN_IF_ERROR(create_context.Initialize(host_allocator));
+    IREE_RETURN_IF_ERROR(iree_hal_amdgpu_logical_device_create(
+        IREE_SV("amdgpu"), options, libhsa, topology, create_context.params(),
+        host_allocator, &base_device));
+    return iree_hal_device_group_create_from_device(
+        base_device, create_context.frontier_tracker(), host_allocator,
+        &device_group);
+  }
+
+  iree_hal_amdgpu_logical_device_t* device() const {
+    return (iree_hal_amdgpu_logical_device_t*)base_device;
+  }
+
+  iree_hal_device_t* hal_device() const { return base_device; }
+
+ private:
+  // Creation context supplying the proactor pool and frontier tracker.
+  iree::hal::cts::DeviceCreateContext create_context;
+
+  // Test-owned device reference released before the topology-owning group.
+  iree_hal_device_t* base_device = NULL;
+
+  // Device group that owns the topology assigned to |base_device|.
+  iree_hal_device_group_t* device_group = NULL;
+};
+
+TEST_F(SlabProviderTest,
+       DefaultPhysicalDevicePoolMaterializesDeviceLocalBuffer) {
+  iree_hal_amdgpu_logical_device_options_t options;
+  iree_hal_amdgpu_logical_device_options_initialize(&options);
+
+  TestLogicalDevice test_device;
+  IREE_ASSERT_OK(
+      test_device.Initialize(&options, &libhsa_, &topology_, host_allocator_));
+
+  auto* device = test_device.device();
+  ASSERT_GE(device->physical_device_count, 1u);
+  iree_hal_pool_t* default_pool = device->physical_devices[0]->default_pool;
+  ASSERT_NE(default_pool, nullptr);
+
+  iree_hal_buffer_params_t params = {0};
+  params.type = IREE_HAL_MEMORY_TYPE_OPTIMAL_FOR_DEVICE;
+  params.usage = IREE_HAL_BUFFER_USAGE_TRANSFER;
+
+  iree_hal_buffer_t* buffer = NULL;
+  IREE_ASSERT_OK(iree_hal_pool_allocate_buffer(
+      default_pool, params, /*allocation_size=*/128,
+      /*requester_frontier=*/NULL, iree_make_timeout_ms(0), &buffer));
+  ASSERT_NE(buffer, nullptr);
+  EXPECT_GE(iree_hal_buffer_allocation_size(buffer), 128u);
+  EXPECT_GE(iree_hal_buffer_byte_length(buffer), 128u);
+  EXPECT_TRUE(iree_all_bits_set(iree_hal_buffer_memory_type(buffer),
+                                IREE_HAL_MEMORY_TYPE_DEVICE_LOCAL));
+  EXPECT_FALSE(iree_all_bits_set(iree_hal_buffer_memory_type(buffer),
+                                 IREE_HAL_MEMORY_TYPE_HOST_VISIBLE));
+
+  iree_hal_buffer_release(buffer);
+}
+
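+// With the default pool range capped at 4096 bytes an 8192-byte queue_alloca
+// cannot be served from pooled slabs and must be routed elsewhere while still
+// honoring the alloca/dealloca semaphore timeline.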
+TEST_F(SlabProviderTest, DefaultQueueAllocaRoutesOversizedRequests) {
+  iree_hal_amdgpu_logical_device_options_t options;
+  iree_hal_amdgpu_logical_device_options_initialize(&options);
+  options.default_pool.range_length = 4096;
+
+  TestLogicalDevice test_device;
+  IREE_ASSERT_OK(
+      test_device.Initialize(&options, &libhsa_, &topology_, host_allocator_));
+
+  iree_hal_semaphore_t* signal_semaphore = NULL;
+  IREE_ASSERT_OK(iree_hal_semaphore_create(
+      test_device.hal_device(), IREE_HAL_QUEUE_AFFINITY_ANY,
+      /*initial_value=*/0, IREE_HAL_SEMAPHORE_FLAG_DEFAULT, &signal_semaphore));
+  iree_hal_semaphore_t* alloca_signal_semaphores[] = {signal_semaphore};
+  uint64_t alloca_signal_values[] = {1};
+  iree_hal_semaphore_list_t alloca_signal_list = {
+      IREE_ARRAYSIZE(alloca_signal_semaphores),
+      alloca_signal_semaphores,
+      alloca_signal_values,
+  };
+
+  iree_hal_buffer_params_t params = {0};
+  params.type = IREE_HAL_MEMORY_TYPE_OPTIMAL_FOR_DEVICE;
+  params.access = IREE_HAL_MEMORY_ACCESS_ALL;
+  params.usage = IREE_HAL_BUFFER_USAGE_TRANSFER;
+
+  iree_hal_buffer_t* buffer = NULL;
+  IREE_ASSERT_OK(iree_hal_device_queue_alloca(
+      test_device.hal_device(), IREE_HAL_QUEUE_AFFINITY_ANY,
+      iree_hal_semaphore_list_empty(), alloca_signal_list, /*pool=*/NULL,
+      params, /*allocation_size=*/8192, IREE_HAL_ALLOCA_FLAG_NONE, &buffer));
+  ASSERT_NE(buffer, nullptr);
+  IREE_ASSERT_OK(iree_hal_semaphore_list_wait(
+      alloca_signal_list, iree_infinite_timeout(), IREE_ASYNC_WAIT_FLAG_NONE));
+
+  iree_hal_semaphore_t* dealloca_wait_semaphores[] = {signal_semaphore};
+  uint64_t dealloca_wait_values[] = {1};
+  iree_hal_semaphore_list_t dealloca_wait_list = {
+      IREE_ARRAYSIZE(dealloca_wait_semaphores),
+      dealloca_wait_semaphores,
+      dealloca_wait_values,
+  };
+  iree_hal_semaphore_t* dealloca_signal_semaphores[] = {signal_semaphore};
+  uint64_t dealloca_signal_values[] = {2};
+  iree_hal_semaphore_list_t dealloca_signal_list = {
+      IREE_ARRAYSIZE(dealloca_signal_semaphores),
+      dealloca_signal_semaphores,
+      dealloca_signal_values,
+  };
+  IREE_ASSERT_OK(iree_hal_device_queue_dealloca(
+      test_device.hal_device(), IREE_HAL_QUEUE_AFFINITY_ANY, dealloca_wait_list,
+      dealloca_signal_list, buffer, IREE_HAL_DEALLOCA_FLAG_NONE));
+  IREE_ASSERT_OK(iree_hal_semaphore_list_wait(dealloca_signal_list,
+                                              iree_infinite_timeout(),
+                                              IREE_ASYNC_WAIT_FLAG_NONE));
+
+  iree_hal_buffer_release(buffer);
+  iree_hal_semaphore_release(signal_semaphore);
+}
+
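+// Two back-to-back maximum-size reservations cannot share one slab, forcing
+// the pool to materialize a second slab; both acquisitions should report
+// OK_FRESH and the pool stats should count at least two slabs.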
+TEST_F(SlabProviderTest, DefaultPhysicalDevicePoolGrowsAdditionalSlabs) {
+  iree_hal_amdgpu_logical_device_options_t options;
+  iree_hal_amdgpu_logical_device_options_initialize(&options);
+  options.default_pool.range_length = 4096;
+
+  TestLogicalDevice test_device;
+  IREE_ASSERT_OK(
+      test_device.Initialize(&options, &libhsa_, &topology_, host_allocator_));
+
+  auto* device = test_device.device();
+  ASSERT_GE(device->physical_device_count, 1u);
+  iree_hal_pool_t* default_pool = device->physical_devices[0]->default_pool;
+  ASSERT_NE(default_pool, nullptr);
+
+  iree_hal_pool_capabilities_t capabilities;
+  iree_hal_pool_query_capabilities(default_pool, &capabilities);
+  ASSERT_NE(capabilities.max_allocation_size, 0u);
+
+  iree_hal_pool_reservation_t first_reservation = {0};
+  iree_hal_pool_acquire_info_t first_info = {0};
+  iree_hal_pool_acquire_result_t first_result = IREE_HAL_POOL_ACQUIRE_EXHAUSTED;
+  iree::Status first_status(iree_hal_pool_acquire_reservation(
+      default_pool, capabilities.max_allocation_size,
+      capabilities.min_allocation_size, /*requester_frontier=*/NULL,
+      IREE_HAL_POOL_RESERVE_FLAG_NONE, &first_reservation, &first_info,
+      &first_result));
+  const bool first_acquired =
+      first_status.ok() && (first_result == IREE_HAL_POOL_ACQUIRE_OK ||
+                            first_result == IREE_HAL_POOL_ACQUIRE_OK_FRESH);
+
+  iree_hal_pool_reservation_t second_reservation = {0};
+  iree_hal_pool_acquire_info_t second_info = {0};
+  iree_hal_pool_acquire_result_t second_result =
+      IREE_HAL_POOL_ACQUIRE_EXHAUSTED;
+  iree::Status second_status(iree_hal_pool_acquire_reservation(
+      default_pool, capabilities.max_allocation_size,
+      capabilities.min_allocation_size, /*requester_frontier=*/NULL,
+      IREE_HAL_POOL_RESERVE_FLAG_NONE, &second_reservation, &second_info,
+      &second_result));
+  const bool second_acquired =
+      second_status.ok() && (second_result == IREE_HAL_POOL_ACQUIRE_OK ||
+                             second_result == IREE_HAL_POOL_ACQUIRE_OK_FRESH);
+
+  iree_hal_pool_stats_t stats;
+  iree_hal_pool_query_stats(default_pool, &stats);
+
+  if (second_acquired) {
+    iree_hal_pool_release_reservation(default_pool, &second_reservation,
+                                      /*death_frontier=*/NULL);
+  }
+  if (first_acquired) {
+    iree_hal_pool_release_reservation(default_pool, &first_reservation,
+                                      /*death_frontier=*/NULL);
+  }
+
+  EXPECT_TRUE(first_status.ok()) << first_status.ToString();
+  EXPECT_EQ(first_result, IREE_HAL_POOL_ACQUIRE_OK_FRESH);
+  EXPECT_TRUE(second_status.ok()) << second_status.ToString();
+  EXPECT_EQ(second_result, IREE_HAL_POOL_ACQUIRE_OK_FRESH);
+  EXPECT_EQ(stats.reservation_count, 2u);
+  EXPECT_GE(stats.slab_count, 2u);
+}
+
+}  // namespace
+}  // namespace iree::hal::amdgpu
diff --git a/runtime/src/iree/hal/drivers/amdgpu/system.c b/runtime/src/iree/hal/drivers/amdgpu/system.c
index ad02695..a201cec 100644
--- a/runtime/src/iree/hal/drivers/amdgpu/system.c
+++ b/runtime/src/iree/hal/drivers/amdgpu/system.c
@@ -7,7 +7,6 @@
 #include "iree/hal/drivers/amdgpu/system.h"
 
 #include "iree/hal/drivers/amdgpu/executable.h"
-#include "iree/hal/drivers/amdgpu/util/kfd.h"
 #include "iree/hal/drivers/amdgpu/util/topology.h"
 
 //===----------------------------------------------------------------------===//
@@ -24,14 +23,16 @@
 // emit iree_status_t errors. Instead we iterate over all and then do things in
 // local for-loops.
 typedef struct iree_hal_amdgpu_hsa_memory_pool_list_t {
+  // Number of valid entries in |values|.
   iree_host_size_t count;
+  // Fixed-capacity memory-pool list populated by HSA iteration callbacks.
   hsa_amd_memory_pool_t values[32];
 } iree_hal_amdgpu_hsa_memory_pool_list_t;
 static hsa_status_t iree_hal_amdgpu_iterate_hsa_memory_pool(
     hsa_amd_memory_pool_t memory_pool, void* user_data) {
   iree_hal_amdgpu_hsa_memory_pool_list_t* pool_list =
       (iree_hal_amdgpu_hsa_memory_pool_list_t*)user_data;
-  if (pool_list->count + 1 >= IREE_ARRAYSIZE(pool_list->values)) {
+  if (pool_list->count >= IREE_ARRAYSIZE(pool_list->values)) {
     return HSA_STATUS_ERROR_OUT_OF_RESOURCES;
   }
   pool_list->values[pool_list->count++] = memory_pool;
@@ -52,6 +53,7 @@
               IREE_LIBHSA(libhsa), host_agent,
               iree_hal_amdgpu_iterate_hsa_memory_pool, &all_memory_pools));
 
+  bool fine_pool_is_kernarg = false;
   for (iree_host_size_t i = 0; i < all_memory_pools.count; ++i) {
     hsa_amd_memory_pool_t pool = all_memory_pools.values[i];
 
@@ -82,28 +84,53 @@
             HSA_AMD_MEMORY_POOL_INFO_RUNTIME_ALLOC_ALLOWED, &alloc_allowed));
     if (!alloc_allowed) continue;
 
-    // Only want fine-grained so we can use atomics.
-    hsa_region_global_flag_t global_flag = 0;
+    // Coarse-grained pools are used for write-once/read-many data.
+    // Kernarg-init pools are used for dispatch argument storage. Fine-grained
+    // pools are used for host/device shared state that requires atomics.
+    uint32_t global_flag = 0;
     IREE_RETURN_AND_END_ZONE_IF_ERROR(
         z0, iree_hsa_amd_memory_pool_get_info(
                 IREE_LIBHSA(libhsa), pool,
                 HSA_AMD_MEMORY_POOL_INFO_GLOBAL_FLAGS, &global_flag));
-    if (global_flag & HSA_AMD_MEMORY_POOL_GLOBAL_FLAG_FINE_GRAINED) {
-      if (!host_memory_pools->fine_pool.handle) {  // first only
-        host_memory_pools->fine_pool = pool;
-      }
+    const bool is_kernarg =
+        global_flag & HSA_AMD_MEMORY_POOL_GLOBAL_FLAG_KERNARG_INIT;
+    const bool is_fine =
+        global_flag & HSA_AMD_MEMORY_POOL_GLOBAL_FLAG_FINE_GRAINED;
+    if ((global_flag & HSA_AMD_MEMORY_POOL_GLOBAL_FLAG_COARSE_GRAINED) &&
+        !host_memory_pools->coarse_pool.handle) {
+      host_memory_pools->coarse_pool = pool;
+    }
+    if (is_kernarg && !host_memory_pools->kernarg_pool.handle) {
+      host_memory_pools->kernarg_pool = pool;
+    }
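+    // Prefer a fine-grained pool that is not also kernarg-init so general
+    // fine-grained allocations do not compete with kernarg storage.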
+    if (is_fine && (!host_memory_pools->fine_pool.handle ||
+                    (fine_pool_is_kernarg && !is_kernarg))) {
+      host_memory_pools->fine_pool = pool;
+      fine_pool_is_kernarg = is_kernarg;
     }
   }
 
-  iree_status_t status = iree_ok_status();
+  if (!host_memory_pools->coarse_pool.handle) {
+    IREE_RETURN_AND_END_ZONE_IF_ERROR(
+        z0, iree_make_status(IREE_STATUS_NOT_FOUND,
+                             "no accessible-by-all + coarse-grained shared "
+                             "memory pool is available in the system"));
+  }
   if (!host_memory_pools->fine_pool.handle) {
-    status = iree_make_status(IREE_STATUS_NOT_FOUND,
-                              "no accessible-by-all + fine-grained shared "
-                              "memory pool is available in the system");
+    IREE_RETURN_AND_END_ZONE_IF_ERROR(
+        z0, iree_make_status(IREE_STATUS_NOT_FOUND,
+                             "no accessible-by-all + fine-grained shared "
+                             "memory pool is available in the system"));
+  }
+  if (!host_memory_pools->kernarg_pool.handle) {
+    IREE_RETURN_AND_END_ZONE_IF_ERROR(
+        z0, iree_make_status(IREE_STATUS_NOT_FOUND,
+                             "no accessible-by-all + kernarg-init shared "
+                             "memory pool is available in the system"));
   }
 
   IREE_TRACE_ZONE_END(z0);
-  return status;
+  return iree_ok_status();
 }
 
 // NOTE: we could do the filtering inline in the iteration callback but that
@@ -111,17 +138,19 @@
 // emit iree_status_t errors. Instead we iterate over all and then do things in
 // local for-loops.
 typedef struct iree_hal_amdgpu_hsa_region_list_t {
+  // Number of valid entries in |values|.
   iree_host_size_t count;
+  // Fixed-capacity region list populated by HSA iteration callbacks.
   hsa_region_t values[32];
 } iree_hal_amdgpu_hsa_region_list_t;
 static hsa_status_t iree_hal_amdgpu_iterate_hsa_region(hsa_region_t region,
                                                        void* user_data) {
-  iree_hal_amdgpu_hsa_region_list_t* pool_list =
+  iree_hal_amdgpu_hsa_region_list_t* region_list =
       (iree_hal_amdgpu_hsa_region_list_t*)user_data;
-  if (pool_list->count + 1 >= IREE_ARRAYSIZE(pool_list->values)) {
+  if (region_list->count >= IREE_ARRAYSIZE(region_list->values)) {
     return HSA_STATUS_ERROR_OUT_OF_RESOURCES;
   }
-  pool_list->values[pool_list->count++] = region;
+  region_list->values[region_list->count++] = region;
   return HSA_STATUS_SUCCESS;
 }
 
@@ -210,9 +239,10 @@
   IREE_RETURN_AND_END_ZONE_IF_ERROR(
       z0, iree_hal_amdgpu_system_info_query(libhsa, &out_system->info));
 
-  // Open /dev/kfd so that we can issue ioctls directly.
+  // Initialize the platform source used for profile clock correlation.
   IREE_RETURN_AND_END_ZONE_IF_ERROR(
-      z0, iree_hal_amdgpu_kfd_open(&out_system->kfd_fd));
+      z0, iree_hal_amdgpu_device_clock_source_initialize(
+              &out_system->device_clock_source));
 
   // Copy the libhsa symbol table and retain HSA for the lifetime of the system.
   // The caller may destroy the provided libhsa after this call returns.
@@ -265,8 +295,9 @@
   // Unload the device library - no references to it should remain.
   iree_hal_amdgpu_device_library_deinitialize(&system->device_library);
 
-  // Close our handle to /dev/kfd prior to (potentially) unloading HSA.
-  iree_hal_amdgpu_kfd_close(system->kfd_fd);
+  // Release platform clock-sampling state before unloading HSA.
+  iree_hal_amdgpu_device_clock_source_deinitialize(
+      &system->device_clock_source);
 
   // This may unload HSA if we were the last retainer in the process.
   iree_hal_amdgpu_libhsa_deinitialize(&system->libhsa);
diff --git a/runtime/src/iree/hal/drivers/amdgpu/system.h b/runtime/src/iree/hal/drivers/amdgpu/system.h
index 3bf2cde..a2ec00f 100644
--- a/runtime/src/iree/hal/drivers/amdgpu/system.h
+++ b/runtime/src/iree/hal/drivers/amdgpu/system.h
@@ -8,6 +8,7 @@
 #define IREE_HAL_DRIVERS_AMDGPU_SYSTEM_H_
 
 #include "iree/base/api.h"
+#include "iree/hal/drivers/amdgpu/util/device_clock.h"
 #include "iree/hal/drivers/amdgpu/util/device_library.h"
 #include "iree/hal/drivers/amdgpu/util/info.h"
 #include "iree/hal/drivers/amdgpu/util/libhsa.h"
@@ -23,8 +24,6 @@
 
 // Options defining total system behavior.
 typedef struct iree_hal_amdgpu_system_options_t {
-  // Enable dispatch-level tracing (if device instrumentation is compiled in).
-  uint64_t trace_execution : 1;
   // Force queues to run one entry at a time instead of overlapping or
   // aggressively scheduling queue entries out-of-order.
   uint64_t exclusive_execution : 1;
@@ -36,6 +35,13 @@
 // access the CPU memory. We try to allocate resources that may be used
 // frequently on a particular cluster of CPU/GPU agents closest to the agents.
 typedef struct iree_hal_amdgpu_host_memory_pools_t {
+  // Coarse-grained shared host memory pool used for write-once/read-many data.
+  hsa_amd_memory_pool_t coarse_pool;
+  // Shared host memory pool suitable for kernel argument storage.
+  // The pool must have HSA_AMD_MEMORY_POOL_GLOBAL_FLAG_KERNARG_INIT so that
+  // allocations from it can be used as dispatch packet kernarg segment
+  // storage.
+  hsa_amd_memory_pool_t kernarg_pool;
   // Memory pool used for various system-level resources.
   // Allocations from this pool must be accessible to all agents. Must have the
   // HSA_REGION_GLOBAL_FLAG_FINE_GRAINED region flag as memory within this
@@ -63,9 +69,8 @@
   // HSA API handle.
   iree_hal_amdgpu_libhsa_t libhsa;
 
-  // /dev/kfd handle, if needed on the platform.
-  // TODO(benvanik): drop this when HSA supports all of the ioctls we need.
-  int kfd_fd;
+  // Platform source used for device/host clock-correlation sampling.
+  iree_hal_amdgpu_device_clock_source_t device_clock_source;
 
   // System topology as visible to the HAL device. This may be a subset of
   // the devices available in the system.
diff --git a/runtime/src/iree/hal/drivers/amdgpu/system_test.cc b/runtime/src/iree/hal/drivers/amdgpu/system_test.cc
index 98184cd..8afb9cf 100644
--- a/runtime/src/iree/hal/drivers/amdgpu/system_test.cc
+++ b/runtime/src/iree/hal/drivers/amdgpu/system_test.cc
@@ -28,7 +28,7 @@
         host_allocator, &libhsa);
     if (!iree_status_is_ok(status)) {
       iree_status_fprint(stderr, status);
-      iree_status_ignore(status);
+      iree_status_free(status);
       GTEST_SKIP() << "HSA not available, skipping tests";
     }
     IREE_ASSERT_OK(
@@ -53,7 +53,6 @@
 
   iree_hal_amdgpu_system_options_t options = {0};
   options.exclusive_execution = 0;
-  options.trace_execution = 0;
 
   iree_hal_amdgpu_system_t* system = NULL;
   IREE_ASSERT_OK(iree_hal_amdgpu_system_allocate(&libhsa, &topology, options,
diff --git a/runtime/src/iree/hal/drivers/amdgpu/transient_buffer.c b/runtime/src/iree/hal/drivers/amdgpu/transient_buffer.c
new file mode 100644
index 0000000..d6a575d
--- /dev/null
+++ b/runtime/src/iree/hal/drivers/amdgpu/transient_buffer.c
@@ -0,0 +1,542 @@
+// Copyright 2026 The IREE Authors
+//
+// Licensed under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+
+#include "iree/hal/drivers/amdgpu/transient_buffer.h"
+
+struct iree_hal_amdgpu_transient_buffer_t {
+  // Base HAL buffer resource returned to callers.
+  iree_hal_buffer_t base;
+
+  // Pool this wrapper returns to when its HAL buffer refcount reaches zero.
+  iree_hal_amdgpu_transient_buffer_pool_t* pool;
+
+  // Next wrapper in either the pool return stack or acquire-side cache.
+  iree_hal_amdgpu_transient_buffer_t* pool_next;
+
+  // Provider-backed view staged for queue packet emission and future commit;
+  // retained while non-NULL.
+  iree_hal_buffer_t* staged_backing;
+
+  // Host-visible committed backing. NULL before alloca commit and after
+  // dealloca decommit.
+  //
+  // The queue's semaphore edges provide the real ordering contract; the
+  // acquire/release atomics here make host-side wrapper state transitions
+  // data-race-free and visible to TSAN.
+  iree_atomic_intptr_t committed_backing;
+
+  // Borrowed source pool for the queue-owned reservation token.
+  iree_hal_pool_t* reservation_pool;
+
+  // Queue-owned pool reservation token attached after acquire succeeds.
+  iree_hal_pool_reservation_t reservation;
+
+  // Non-zero while |reservation| is valid and must be released.
+  iree_atomic_int32_t reservation_armed;
+
+  // Set when one dealloca has been accepted for this wrapper. This is
+  // single-owner bookkeeping for reservation release/decommit, not a queue-use
+  // lifetime validator; queue operation order is expressed by semaphores.
+  iree_atomic_int32_t dealloca_queued;
+
+  // Profiling session id owning |profile_allocation_id|.
+  uint64_t profile_session_id;
+
+  // Session-local profiling allocation id for this queue_alloca lifecycle.
+  uint64_t profile_allocation_id;
+};
+
+static const iree_hal_buffer_vtable_t iree_hal_amdgpu_transient_buffer_vtable;
+
+static inline iree_hal_amdgpu_transient_buffer_t*
+iree_hal_amdgpu_transient_buffer_cast(iree_hal_buffer_t* base_buffer) {
+  IREE_HAL_ASSERT_TYPE(base_buffer, &iree_hal_amdgpu_transient_buffer_vtable);
+  return (iree_hal_amdgpu_transient_buffer_t*)base_buffer;
+}
+
+static inline const iree_hal_buffer_vtable_t*
+iree_hal_amdgpu_transient_buffer_backing_vtable(iree_hal_buffer_t* buffer) {
+  return (const iree_hal_buffer_vtable_t*)((const iree_hal_resource_t*)buffer)
+      ->vtable;
+}
+
+static inline iree_hal_buffer_t*
+iree_hal_amdgpu_transient_buffer_load_committed_backing(
+    iree_hal_amdgpu_transient_buffer_t* buffer) {
+  return (iree_hal_buffer_t*)iree_atomic_load(&buffer->committed_backing,
+                                              iree_memory_order_acquire);
+}
+
+//===----------------------------------------------------------------------===//
+// Transient buffer pool
+//===----------------------------------------------------------------------===//
+
+static iree_host_size_t iree_hal_amdgpu_transient_buffer_pool_slot_size(void) {
+  return iree_host_align(sizeof(iree_hal_amdgpu_transient_buffer_t),
+                         iree_alignof(iree_hal_amdgpu_transient_buffer_t));
+}
+
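+// Grows the pool by one arena block and threads the block's wrapper slots onto
+// the acquire-side free list. Must be called with the pool mutex held.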
+static iree_status_t iree_hal_amdgpu_transient_buffer_pool_grow_locked(
+    iree_hal_amdgpu_transient_buffer_pool_t* pool) {
+  IREE_TRACE_ZONE_BEGIN(z0);
+
+  const iree_host_size_t slot_size =
+      iree_hal_amdgpu_transient_buffer_pool_slot_size();
+  const iree_host_size_t slot_count =
+      pool->block_pool->usable_block_size / slot_size;
+  IREE_TRACE_ZONE_APPEND_VALUE_I64(z0, (int64_t)slot_count);
+
+  iree_arena_block_t* block = NULL;
+  void* block_ptr = NULL;
+  IREE_RETURN_AND_END_ZONE_IF_ERROR(
+      z0, iree_arena_block_pool_acquire(pool->block_pool, &block, &block_ptr));
+
+  if (pool->block_tail) {
+    pool->block_tail->next = block;
+  } else {
+    pool->block_head = block;
+  }
+  pool->block_tail = block;
+
+  uint8_t* slot_ptr = (uint8_t*)block_ptr;
+  for (iree_host_size_t i = 0; i < slot_count; ++i) {
+    iree_hal_amdgpu_transient_buffer_t* buffer =
+        (iree_hal_amdgpu_transient_buffer_t*)slot_ptr;
+    buffer->pool = pool;
+    buffer->pool_next = pool->acquire_head;
+    pool->acquire_head = buffer;
+    slot_ptr += slot_size;
+  }
+
+  IREE_TRACE_ZONE_END(z0);
+  return iree_ok_status();
+}
+
+iree_status_t iree_hal_amdgpu_transient_buffer_pool_initialize(
+    iree_arena_block_pool_t* block_pool,
+    iree_hal_amdgpu_transient_buffer_pool_t* out_pool) {
+  IREE_ASSERT_ARGUMENT(block_pool);
+  IREE_ASSERT_ARGUMENT(out_pool);
+  IREE_TRACE_ZONE_BEGIN(z0);
+
+  memset(out_pool, 0, sizeof(*out_pool));
+  const iree_host_size_t slot_size =
+      iree_hal_amdgpu_transient_buffer_pool_slot_size();
+  if (IREE_UNLIKELY(block_pool->usable_block_size < slot_size)) {
+    IREE_TRACE_ZONE_END(z0);
+    return iree_make_status(IREE_STATUS_OUT_OF_RANGE,
+                            "transient buffer pool block usable size %" PRIhsz
+                            " is smaller than wrapper slot size %" PRIhsz,
+                            block_pool->usable_block_size, slot_size);
+  }
+
+  out_pool->block_pool = block_pool;
+  iree_atomic_store(&out_pool->return_head, 0, iree_memory_order_relaxed);
+  iree_slim_mutex_initialize(&out_pool->mutex);
+#if !defined(NDEBUG)
+  iree_atomic_store(&out_pool->live_count, 0, iree_memory_order_relaxed);
+#endif  // !defined(NDEBUG)
+
+  IREE_TRACE_ZONE_END(z0);
+  return iree_ok_status();
+}
+
+void iree_hal_amdgpu_transient_buffer_pool_deinitialize(
+    iree_hal_amdgpu_transient_buffer_pool_t* pool) {
+  if (!pool || !pool->block_pool) return;
+  IREE_TRACE_ZONE_BEGIN(z0);
+
+#if !defined(NDEBUG)
+  const int32_t live_count =
+      iree_atomic_load(&pool->live_count, iree_memory_order_acquire);
+  IREE_ASSERT(live_count == 0,
+              "deinitializing transient buffer pool with %d live wrappers",
+              live_count);
+#endif  // !defined(NDEBUG)
+
+  iree_atomic_store(&pool->return_head, 0, iree_memory_order_relaxed);
+  pool->acquire_head = NULL;
+  if (pool->block_head) {
+    iree_arena_block_pool_release(pool->block_pool, pool->block_head,
+                                  pool->block_tail);
+  }
+  pool->block_head = NULL;
+  pool->block_tail = NULL;
+  iree_slim_mutex_deinitialize(&pool->mutex);
+  pool->block_pool = NULL;
+
+  IREE_TRACE_ZONE_END(z0);
+}
+
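+// Acquires a wrapper, preferring the mutex-guarded acquire-side cache, then
+// draining the lock-free return stack, and finally growing the backing arena
+// by one block when both are empty.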
+static iree_status_t iree_hal_amdgpu_transient_buffer_pool_acquire(
+    iree_hal_amdgpu_transient_buffer_pool_t* pool,
+    iree_hal_amdgpu_transient_buffer_t** out_buffer) {
+  *out_buffer = NULL;
+
+  iree_slim_mutex_lock(&pool->mutex);
+
+  iree_status_t status = iree_ok_status();
+  iree_hal_amdgpu_transient_buffer_t* buffer = pool->acquire_head;
+  if (buffer) {
+    pool->acquire_head = buffer->pool_next;
+  } else {
+    buffer = (iree_hal_amdgpu_transient_buffer_t*)iree_atomic_exchange(
+        &pool->return_head, 0, iree_memory_order_acquire);
+    if (buffer) {
+      pool->acquire_head = buffer->pool_next;
+    } else {
+      status = iree_hal_amdgpu_transient_buffer_pool_grow_locked(pool);
+      if (iree_status_is_ok(status)) {
+        buffer = pool->acquire_head;
+        pool->acquire_head = buffer->pool_next;
+      }
+    }
+  }
+
+  iree_slim_mutex_unlock(&pool->mutex);
+
+  if (iree_status_is_ok(status)) {
+    buffer->pool_next = NULL;
+#if !defined(NDEBUG)
+    iree_atomic_fetch_add(&pool->live_count, 1, iree_memory_order_acq_rel);
+#endif  // !defined(NDEBUG)
+    *out_buffer = buffer;
+  }
+  return status;
+}
+
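+// Returns a wrapper to the pool by pushing it onto the lock-free return stack.
+// Releases may happen from arbitrary threads (HAL buffer refcounts dropping to
+// zero) so no pool mutex is taken here.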
+static void iree_hal_amdgpu_transient_buffer_pool_release(
+    iree_hal_amdgpu_transient_buffer_pool_t* pool,
+    iree_hal_amdgpu_transient_buffer_t* buffer) {
+#if !defined(NDEBUG)
+  const int32_t old_live_count =
+      iree_atomic_fetch_sub(&pool->live_count, 1, iree_memory_order_acq_rel);
+  IREE_ASSERT(old_live_count > 0,
+              "releasing transient buffer wrapper with no live wrapper count");
+#endif  // !defined(NDEBUG)
+
+  intptr_t expected = 0;
+  do {
+    expected = iree_atomic_load(&pool->return_head, iree_memory_order_relaxed);
+    buffer->pool_next = (iree_hal_amdgpu_transient_buffer_t*)expected;
+  } while (!iree_atomic_compare_exchange_weak(
+      &pool->return_head, &expected, (intptr_t)buffer,
+      iree_memory_order_release, iree_memory_order_relaxed));
+}
+
+//===----------------------------------------------------------------------===//
+// Transient buffer wrapper
+//===----------------------------------------------------------------------===//
+
+iree_status_t iree_hal_amdgpu_transient_buffer_create(
+    iree_hal_buffer_placement_t placement, iree_hal_buffer_params_t params,
+    iree_device_size_t allocation_size, iree_device_size_t byte_length,
+    iree_hal_amdgpu_transient_buffer_pool_t* pool,
+    iree_hal_buffer_t** out_buffer) {
+  IREE_ASSERT_ARGUMENT(pool);
+  IREE_ASSERT_ARGUMENT(out_buffer);
+  *out_buffer = NULL;
+  if (IREE_UNLIKELY(byte_length > allocation_size)) {
+    return iree_make_status(IREE_STATUS_OUT_OF_RANGE,
+                            "transient buffer byte length (%" PRIu64
+                            ") exceeds allocation size (%" PRIu64 ")",
+                            (uint64_t)byte_length, (uint64_t)allocation_size);
+  }
+  IREE_TRACE_ZONE_BEGIN(z0);
+
+  iree_hal_amdgpu_transient_buffer_t* buffer = NULL;
+  IREE_RETURN_AND_END_ZONE_IF_ERROR(
+      z0, iree_hal_amdgpu_transient_buffer_pool_acquire(pool, &buffer));
+
+  iree_hal_buffer_initialize(
+      placement, /*allocated_buffer=*/&buffer->base, allocation_size,
+      /*byte_offset=*/0, byte_length, params.type, params.access, params.usage,
+      &iree_hal_amdgpu_transient_buffer_vtable, &buffer->base);
+  buffer->pool = pool;
+  buffer->staged_backing = NULL;
+  iree_atomic_store(&buffer->committed_backing, 0, iree_memory_order_relaxed);
+  buffer->reservation_pool = NULL;
+  memset(&buffer->reservation, 0, sizeof(buffer->reservation));
+  iree_atomic_store(&buffer->reservation_armed, 0, iree_memory_order_relaxed);
+  iree_atomic_store(&buffer->dealloca_queued, 0, iree_memory_order_relaxed);
+  buffer->profile_session_id = 0;
+  buffer->profile_allocation_id = 0;
+
+  *out_buffer = &buffer->base;
+  IREE_TRACE_ZONE_END(z0);
+  return iree_ok_status();
+}
+
+bool iree_hal_amdgpu_transient_buffer_isa(const iree_hal_buffer_t* buffer) {
+  return iree_hal_resource_is(&buffer->resource,
+                              &iree_hal_amdgpu_transient_buffer_vtable);
+}
+
+void iree_hal_amdgpu_transient_buffer_set_profile_allocation(
+    iree_hal_buffer_t* base_buffer, uint64_t session_id,
+    uint64_t allocation_id) {
+  iree_hal_amdgpu_transient_buffer_t* buffer =
+      iree_hal_amdgpu_transient_buffer_cast(base_buffer);
+  buffer->profile_session_id = session_id;
+  buffer->profile_allocation_id = allocation_id;
+}
+
+uint64_t iree_hal_amdgpu_transient_buffer_profile_allocation_id(
+    iree_hal_buffer_t* base_buffer) {
+  iree_hal_amdgpu_transient_buffer_t* buffer =
+      iree_hal_amdgpu_transient_buffer_cast(base_buffer);
+  return buffer->profile_allocation_id;
+}
+
+uint64_t iree_hal_amdgpu_transient_buffer_profile_session_id(
+    iree_hal_buffer_t* base_buffer) {
+  iree_hal_amdgpu_transient_buffer_t* buffer =
+      iree_hal_amdgpu_transient_buffer_cast(base_buffer);
+  return buffer->profile_session_id;
+}
+
+void iree_hal_amdgpu_transient_buffer_attach_reservation(
+    iree_hal_buffer_t* base_buffer, iree_hal_pool_t* pool,
+    const iree_hal_pool_reservation_t* reservation) {
+  IREE_ASSERT_ARGUMENT(base_buffer);
+  IREE_ASSERT_ARGUMENT(pool);
+  IREE_ASSERT_ARGUMENT(reservation);
+  iree_hal_amdgpu_transient_buffer_t* buffer =
+      iree_hal_amdgpu_transient_buffer_cast(base_buffer);
+  IREE_ASSERT_TRUE(buffer->reservation_pool == NULL);
+  IREE_ASSERT_TRUE(iree_atomic_load(&buffer->reservation_armed,
+                                    iree_memory_order_acquire) == 0);
+  buffer->reservation_pool = pool;
+  buffer->reservation = *reservation;
+  iree_atomic_store(&buffer->reservation_armed, 1, iree_memory_order_release);
+}
+
+void iree_hal_amdgpu_transient_buffer_stage_backing(
+    iree_hal_buffer_t* base_buffer, iree_hal_buffer_t* backing_buffer) {
+  IREE_ASSERT_ARGUMENT(base_buffer);
+  IREE_ASSERT_ARGUMENT(backing_buffer);
+  iree_hal_amdgpu_transient_buffer_t* buffer =
+      iree_hal_amdgpu_transient_buffer_cast(base_buffer);
+  IREE_ASSERT_TRUE(buffer->staged_backing == NULL);
+  IREE_ASSERT_TRUE(
+      iree_hal_amdgpu_transient_buffer_load_committed_backing(buffer) == NULL);
+  buffer->staged_backing = backing_buffer;
+}
+
+void iree_hal_amdgpu_transient_buffer_commit(iree_hal_buffer_t* base_buffer) {
+  IREE_ASSERT_ARGUMENT(base_buffer);
+  iree_hal_amdgpu_transient_buffer_t* buffer =
+      iree_hal_amdgpu_transient_buffer_cast(base_buffer);
+  IREE_ASSERT_TRUE(buffer->staged_backing != NULL);
+  IREE_ASSERT_TRUE(
+      iree_hal_amdgpu_transient_buffer_load_committed_backing(buffer) == NULL);
+  iree_atomic_store(&buffer->committed_backing,
+                    (intptr_t)buffer->staged_backing,
+                    iree_memory_order_release);
+}
+
+void iree_hal_amdgpu_transient_buffer_decommit(iree_hal_buffer_t* base_buffer) {
+  IREE_ASSERT_ARGUMENT(base_buffer);
+  iree_hal_amdgpu_transient_buffer_t* buffer =
+      iree_hal_amdgpu_transient_buffer_cast(base_buffer);
+  iree_atomic_store(&buffer->committed_backing, 0, iree_memory_order_release);
+  if (buffer->staged_backing) {
+    iree_hal_buffer_release(buffer->staged_backing);
+    buffer->staged_backing = NULL;
+  }
+}
+
+bool iree_hal_amdgpu_transient_buffer_begin_dealloca(
+    iree_hal_buffer_t* base_buffer) {
+  IREE_ASSERT_ARGUMENT(base_buffer);
+  iree_hal_amdgpu_transient_buffer_t* buffer =
+      iree_hal_amdgpu_transient_buffer_cast(base_buffer);
+  int32_t expected = 0;
+  return iree_atomic_compare_exchange_strong(
+      &buffer->dealloca_queued, &expected, 1, iree_memory_order_acq_rel,
+      iree_memory_order_acquire);
+}
+
+void iree_hal_amdgpu_transient_buffer_abort_dealloca(
+    iree_hal_buffer_t* base_buffer) {
+  IREE_ASSERT_ARGUMENT(base_buffer);
+  iree_hal_amdgpu_transient_buffer_t* buffer =
+      iree_hal_amdgpu_transient_buffer_cast(base_buffer);
+  iree_atomic_store(&buffer->dealloca_queued, 0, iree_memory_order_release);
+}
+
+bool iree_hal_amdgpu_transient_buffer_query_reservation(
+    iree_hal_buffer_t* base_buffer, iree_hal_pool_t** out_pool,
+    iree_hal_pool_reservation_t* out_reservation) {
+  IREE_ASSERT_ARGUMENT(base_buffer);
+  IREE_ASSERT_ARGUMENT(out_pool);
+  IREE_ASSERT_ARGUMENT(out_reservation);
+  *out_pool = NULL;
+  memset(out_reservation, 0, sizeof(*out_reservation));
+
+  iree_hal_amdgpu_transient_buffer_t* buffer =
+      iree_hal_amdgpu_transient_buffer_cast(base_buffer);
+  if (!buffer->reservation_pool) return false;
+  if (iree_atomic_load(&buffer->reservation_armed, iree_memory_order_acquire) ==
+      0) {
+    return false;
+  }
+  *out_pool = buffer->reservation_pool;
+  *out_reservation = buffer->reservation;
+  return true;
+}
+
+void iree_hal_amdgpu_transient_buffer_release_reservation(
+    iree_hal_buffer_t* base_buffer,
+    const iree_async_frontier_t* death_frontier) {
+  IREE_ASSERT_ARGUMENT(base_buffer);
+  iree_hal_amdgpu_transient_buffer_t* buffer =
+      iree_hal_amdgpu_transient_buffer_cast(base_buffer);
+  iree_hal_pool_t* reservation_pool = buffer->reservation_pool;
+  if (!reservation_pool) return;
+  const int32_t was_armed = iree_atomic_exchange(&buffer->reservation_armed, 0,
+                                                 iree_memory_order_acq_rel);
+  if (was_armed) {
+    iree_hal_pool_release_reservation(reservation_pool, &buffer->reservation,
+                                      death_frontier);
+  }
+  buffer->reservation_pool = NULL;
+  memset(&buffer->reservation, 0, sizeof(buffer->reservation));
+}
+
+iree_hal_buffer_t* iree_hal_amdgpu_transient_buffer_backing_buffer(
+    iree_hal_buffer_t* base_buffer) {
+  IREE_ASSERT_ARGUMENT(base_buffer);
+  iree_hal_amdgpu_transient_buffer_t* buffer =
+      iree_hal_amdgpu_transient_buffer_cast(base_buffer);
+  return buffer->staged_backing;
+}
+
+iree_status_t iree_hal_amdgpu_transient_buffer_resolve_committed_backing(
+    iree_hal_buffer_t* base_buffer, iree_hal_buffer_t** out_backing_buffer) {
+  IREE_ASSERT_ARGUMENT(base_buffer);
+  IREE_ASSERT_ARGUMENT(out_backing_buffer);
+  *out_backing_buffer = NULL;
+  iree_hal_amdgpu_transient_buffer_t* buffer =
+      iree_hal_amdgpu_transient_buffer_cast(base_buffer);
+  if (IREE_UNLIKELY(iree_atomic_load(&buffer->dealloca_queued,
+                                     iree_memory_order_acquire) != 0)) {
+    return iree_make_status(
+        IREE_STATUS_FAILED_PRECONDITION,
+        "transient buffer has been queued for deallocation");
+  }
+  iree_hal_buffer_t* backing_buffer =
+      iree_hal_amdgpu_transient_buffer_load_committed_backing(buffer);
+  if (IREE_UNLIKELY(!backing_buffer)) {
+    return iree_make_status(
+        IREE_STATUS_FAILED_PRECONDITION,
+        "transient buffer has not been committed; wait on the alloca signal "
+        "semaphores before resolving its committed backing");
+  }
+  *out_backing_buffer = backing_buffer;
+  return iree_ok_status();
+}
+
+static void iree_hal_amdgpu_transient_buffer_destroy(
+    iree_hal_buffer_t* base_buffer) {
+  iree_hal_amdgpu_transient_buffer_t* buffer =
+      iree_hal_amdgpu_transient_buffer_cast(base_buffer);
+  iree_hal_amdgpu_transient_buffer_pool_t* pool = buffer->pool;
+  IREE_TRACE_ZONE_BEGIN(z0);
+
+  iree_hal_amdgpu_transient_buffer_decommit(base_buffer);
+  iree_hal_amdgpu_transient_buffer_release_reservation(base_buffer,
+                                                       /*death_frontier=*/NULL);
+
+  iree_atomic_store(&buffer->dealloca_queued, 0, iree_memory_order_relaxed);
+  buffer->profile_session_id = 0;
+  buffer->profile_allocation_id = 0;
+  iree_hal_amdgpu_transient_buffer_pool_release(pool, buffer);
+  IREE_TRACE_ZONE_END(z0);
+}
+
+static iree_status_t iree_hal_amdgpu_transient_buffer_load_host_backing(
+    iree_hal_amdgpu_transient_buffer_t* buffer,
+    iree_hal_buffer_t** out_backing_buffer) {
+  if (IREE_UNLIKELY(iree_atomic_load(&buffer->dealloca_queued,
+                                     iree_memory_order_acquire) != 0)) {
+    return iree_make_status(
+        IREE_STATUS_FAILED_PRECONDITION,
+        "transient buffer has been queued for deallocation");
+  }
+  iree_hal_buffer_t* backing_buffer =
+      iree_hal_amdgpu_transient_buffer_load_committed_backing(buffer);
+  if (IREE_UNLIKELY(!backing_buffer)) {
+    return iree_make_status(
+        IREE_STATUS_FAILED_PRECONDITION,
+        "transient buffer has not been committed; wait on the alloca "
+        "signal semaphores before accessing it");
+  }
+  *out_backing_buffer = backing_buffer;
+  return iree_ok_status();
+}
+
+static iree_status_t iree_hal_amdgpu_transient_buffer_map_range(
+    iree_hal_buffer_t* base_buffer, iree_hal_mapping_mode_t mapping_mode,
+    iree_hal_memory_access_t memory_access,
+    iree_device_size_t local_byte_offset, iree_device_size_t local_byte_length,
+    iree_hal_buffer_mapping_t* mapping) {
+  iree_hal_amdgpu_transient_buffer_t* buffer =
+      iree_hal_amdgpu_transient_buffer_cast(base_buffer);
+  iree_hal_buffer_t* backing_buffer = NULL;
+  IREE_RETURN_IF_ERROR(iree_hal_amdgpu_transient_buffer_load_host_backing(
+      buffer, &backing_buffer));
+  return iree_hal_amdgpu_transient_buffer_backing_vtable(backing_buffer)
+      ->map_range(backing_buffer, mapping_mode, memory_access,
+                  local_byte_offset, local_byte_length, mapping);
+}
+
+static iree_status_t iree_hal_amdgpu_transient_buffer_unmap_range(
+    iree_hal_buffer_t* base_buffer, iree_device_size_t local_byte_offset,
+    iree_device_size_t local_byte_length, iree_hal_buffer_mapping_t* mapping) {
+  iree_hal_amdgpu_transient_buffer_t* buffer =
+      iree_hal_amdgpu_transient_buffer_cast(base_buffer);
+  iree_hal_buffer_t* backing_buffer = NULL;
+  IREE_RETURN_IF_ERROR(iree_hal_amdgpu_transient_buffer_load_host_backing(
+      buffer, &backing_buffer));
+  return iree_hal_amdgpu_transient_buffer_backing_vtable(backing_buffer)
+      ->unmap_range(backing_buffer, local_byte_offset, local_byte_length,
+                    mapping);
+}
+
+static iree_status_t iree_hal_amdgpu_transient_buffer_invalidate_range(
+    iree_hal_buffer_t* base_buffer, iree_device_size_t local_byte_offset,
+    iree_device_size_t local_byte_length) {
+  iree_hal_amdgpu_transient_buffer_t* buffer =
+      iree_hal_amdgpu_transient_buffer_cast(base_buffer);
+  iree_hal_buffer_t* backing_buffer = NULL;
+  IREE_RETURN_IF_ERROR(iree_hal_amdgpu_transient_buffer_load_host_backing(
+      buffer, &backing_buffer));
+  return iree_hal_amdgpu_transient_buffer_backing_vtable(backing_buffer)
+      ->invalidate_range(backing_buffer, local_byte_offset, local_byte_length);
+}
+
+static iree_status_t iree_hal_amdgpu_transient_buffer_flush_range(
+    iree_hal_buffer_t* base_buffer, iree_device_size_t local_byte_offset,
+    iree_device_size_t local_byte_length) {
+  iree_hal_amdgpu_transient_buffer_t* buffer =
+      iree_hal_amdgpu_transient_buffer_cast(base_buffer);
+  iree_hal_buffer_t* backing_buffer = NULL;
+  IREE_RETURN_IF_ERROR(iree_hal_amdgpu_transient_buffer_load_host_backing(
+      buffer, &backing_buffer));
+  return iree_hal_amdgpu_transient_buffer_backing_vtable(backing_buffer)
+      ->flush_range(backing_buffer, local_byte_offset, local_byte_length);
+}
+
+static const iree_hal_buffer_vtable_t iree_hal_amdgpu_transient_buffer_vtable =
+    {
+        .recycle = iree_hal_buffer_recycle,
+        .destroy = iree_hal_amdgpu_transient_buffer_destroy,
+        .map_range = iree_hal_amdgpu_transient_buffer_map_range,
+        .unmap_range = iree_hal_amdgpu_transient_buffer_unmap_range,
+        .invalidate_range = iree_hal_amdgpu_transient_buffer_invalidate_range,
+        .flush_range = iree_hal_amdgpu_transient_buffer_flush_range,
+};
diff --git a/runtime/src/iree/hal/drivers/amdgpu/transient_buffer.h b/runtime/src/iree/hal/drivers/amdgpu/transient_buffer.h
new file mode 100644
index 0000000..c262fa2
--- /dev/null
+++ b/runtime/src/iree/hal/drivers/amdgpu/transient_buffer.h
@@ -0,0 +1,174 @@
+// Copyright 2026 The IREE Authors
+//
+// Licensed under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+
+#ifndef IREE_HAL_DRIVERS_AMDGPU_TRANSIENT_BUFFER_H_
+#define IREE_HAL_DRIVERS_AMDGPU_TRANSIENT_BUFFER_H_
+
+#include "iree/async/frontier.h"
+#include "iree/base/api.h"
+#include "iree/base/internal/arena.h"
+#include "iree/base/internal/atomics.h"
+#include "iree/base/threading/mutex.h"
+#include "iree/hal/api.h"
+#include "iree/hal/pool.h"
+
+#ifdef __cplusplus
+extern "C" {
+#endif  // __cplusplus
+
+typedef struct iree_hal_amdgpu_transient_buffer_t
+    iree_hal_amdgpu_transient_buffer_t;
+
+// Per-physical-device pool of queue_alloca transient HAL buffer wrappers.
+//
+// Wrappers returned from the pool borrow it rather than retaining it: the
+// owning physical device must outlive every transient buffer allocated from
+// the pool, matching the reservation lifetime rule for the allocation pools
+// backing those wrappers.
+typedef struct iree_hal_amdgpu_transient_buffer_pool_t {
+  // Per-physical-device host block pool used for cold wrapper-block growth.
+  iree_arena_block_pool_t* block_pool;
+
+  // Head of the lock-free return stack pushed by transient-buffer destroy.
+  iree_atomic_intptr_t return_head;
+
+  // Serializes acquire-side cache pops, return-stack migration, and growth.
+  iree_slim_mutex_t mutex;
+
+  // Head of the mutex-protected acquire-side cache.
+  iree_hal_amdgpu_transient_buffer_t* acquire_head;
+
+  // First host block owned by this pool.
+  iree_arena_block_t* block_head;
+
+  // Last host block owned by this pool.
+  iree_arena_block_t* block_tail;
+
+#if !defined(NDEBUG)
+  // Number of wrappers currently retained by users or in-flight operations.
+  iree_atomic_int32_t live_count;
+#endif  // !defined(NDEBUG)
+} iree_hal_amdgpu_transient_buffer_pool_t;
+
+// Initializes a per-physical-device transient wrapper pool.
+//
+// No wrapper memory is allocated until the first acquire. Wrapper storage grows
+// in blocks borrowed from |block_pool| and returned during deinitialization.
+iree_status_t iree_hal_amdgpu_transient_buffer_pool_initialize(
+    iree_arena_block_pool_t* block_pool,
+    iree_hal_amdgpu_transient_buffer_pool_t* out_pool);
+
+// Deinitializes the pool and releases all cold-grown wrapper blocks.
+//
+// All buffers allocated from the pool must have been released before this is
+// called. Violating that lifetime contract is a device teardown/use-after-free
+// bug and is checked in debug builds.
+void iree_hal_amdgpu_transient_buffer_pool_deinitialize(
+    iree_hal_amdgpu_transient_buffer_pool_t* pool);
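+
+// A minimal init/deinit pairing sketch (block pool setup is illustrative and
+// owned by the physical device in practice):
+//
+//   iree_hal_amdgpu_transient_buffer_pool_t pool;
+//   IREE_RETURN_IF_ERROR(iree_hal_amdgpu_transient_buffer_pool_initialize(
+//       host_block_pool, &pool));
+//   // ... wrappers are created from the pool by queue_alloca ...
+//   iree_hal_amdgpu_transient_buffer_pool_deinitialize(&pool);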
+
+// AMDGPU queue-ordered transient buffer wrapper.
+//
+// The wrapper is returned by queue_alloca before its backing allocation becomes
+// user-visible. The queue stages a borrowed provider buffer view immediately so
+// later queue submissions can resolve device pointers, then commits that staged
+// backing in the notification ring's pre-signal phase. queue_dealloca releases
+// the pool reservation at submit time with a death frontier and decommits the
+// wrapper in the pre-signal phase before publishing dealloca completion.
+iree_status_t iree_hal_amdgpu_transient_buffer_create(
+    iree_hal_buffer_placement_t placement, iree_hal_buffer_params_t params,
+    iree_device_size_t allocation_size, iree_device_size_t byte_length,
+    iree_hal_amdgpu_transient_buffer_pool_t* pool,
+    iree_hal_buffer_t** out_buffer);
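+
+// A hedged sketch of the queue_alloca lifecycle described above; the queue
+// orchestration and variable names around these calls are illustrative, not a
+// fixed API:
+//
+//   iree_hal_buffer_t* buffer = NULL;
+//   IREE_RETURN_IF_ERROR(iree_hal_amdgpu_transient_buffer_create(
+//       placement, params, allocation_size, byte_length, pool, &buffer));
+//   iree_hal_amdgpu_transient_buffer_attach_reservation(buffer, hal_pool,
+//                                                       &reservation);
+//   iree_hal_amdgpu_transient_buffer_stage_backing(buffer, backing_buffer);
+//   // ... queue packets may resolve the staged backing here ...
+//   iree_hal_amdgpu_transient_buffer_commit(buffer);  // pre-signal phase
+//   // ... user work; later queue_dealloca ...
+//   if (iree_hal_amdgpu_transient_buffer_begin_dealloca(buffer)) {
+//     iree_hal_amdgpu_transient_buffer_release_reservation(buffer,
+//                                                          death_frontier);
+//     iree_hal_amdgpu_transient_buffer_decommit(buffer);  // pre-signal phase
+//   }
+//   iree_hal_buffer_release(buffer);  // returns the wrapper to its pool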
+
+// Returns true if |buffer| is an AMDGPU transient wrapper.
+bool iree_hal_amdgpu_transient_buffer_isa(const iree_hal_buffer_t* buffer);
+
+// Tags |buffer| with the profiling identity for its queue_alloca lifecycle.
+//
+// The id is session-local and correlates the pool-reservation, queue
+// alloca/dealloca, queue-event, and eventual pool-release records for the same
+// transient allocation. The wrapper clears the id when it returns to its pool.
+void iree_hal_amdgpu_transient_buffer_set_profile_allocation(
+    iree_hal_buffer_t* buffer, uint64_t session_id, uint64_t allocation_id);
+
+// Returns the session-local profiling allocation id for |buffer|, or 0.
+uint64_t iree_hal_amdgpu_transient_buffer_profile_allocation_id(
+    iree_hal_buffer_t* buffer);
+
+// Returns the profiling session id owning the allocation id for |buffer|, or 0.
+uint64_t iree_hal_amdgpu_transient_buffer_profile_session_id(
+    iree_hal_buffer_t* buffer);
+
+// Attaches a queue-owned pool reservation to |buffer|.
+//
+// |pool| is borrowed and must outlive the transient buffer. |reservation|
+// ownership transfers to the wrapper until
+// iree_hal_amdgpu_transient_buffer_release_reservation() or destroy.
+void iree_hal_amdgpu_transient_buffer_attach_reservation(
+    iree_hal_buffer_t* buffer, iree_hal_pool_t* pool,
+    const iree_hal_pool_reservation_t* reservation);
+
+// Stages a backing view for future queue packet emission and commit.
+//
+// Takes ownership of one |backing_buffer| reference. The reference is released
+// by decommit or destroy.
+void iree_hal_amdgpu_transient_buffer_stage_backing(
+    iree_hal_buffer_t* buffer, iree_hal_buffer_t* backing_buffer);
+
+// Publishes the staged backing buffer to host-visible map/flush/invalidate
+// APIs.
+void iree_hal_amdgpu_transient_buffer_commit(iree_hal_buffer_t* buffer);
+
+// Decommits the wrapper and releases the staged backing view.
+void iree_hal_amdgpu_transient_buffer_decommit(iree_hal_buffer_t* buffer);
+
+// Marks the wrapper as queued for deallocation. Returns false if a dealloca has
+// already been queued for this wrapper.
+bool iree_hal_amdgpu_transient_buffer_begin_dealloca(iree_hal_buffer_t* buffer);
+
+// Clears a queued-dealloca marker after a submission/capture failure. Must only
+// be used when no dealloca completion action was published.
+void iree_hal_amdgpu_transient_buffer_abort_dealloca(iree_hal_buffer_t* buffer);
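+
+// A short sketch of the queued-dealloca handshake; submit_dealloca_packets is
+// a hypothetical stand-in for the queue's submission step:
+//
+//   if (iree_hal_amdgpu_transient_buffer_begin_dealloca(buffer)) {
+//     iree_status_t status = submit_dealloca_packets(queue, buffer);
+//     if (!iree_status_is_ok(status)) {
+//       // No completion action was published; clear the marker for retry.
+//       iree_hal_amdgpu_transient_buffer_abort_dealloca(buffer);
+//     }
+//   }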
+
+// Returns the attached pool reservation without transferring ownership.
+//
+// Used by cold diagnostic/profiling paths that need to describe the reservation
+// before queue_dealloca releases it. Returns false if |buffer| has no armed
+// reservation.
+bool iree_hal_amdgpu_transient_buffer_query_reservation(
+    iree_hal_buffer_t* buffer, iree_hal_pool_t** out_pool,
+    iree_hal_pool_reservation_t* out_reservation);
+
+// Releases the attached reservation exactly once. No-op if none is attached or
+// if the reservation has already been released.
+void iree_hal_amdgpu_transient_buffer_release_reservation(
+    iree_hal_buffer_t* buffer, const iree_async_frontier_t* death_frontier);
+
+// Returns a backing buffer suitable for queue packet emission, or NULL if the
+// wrapper has no staged backing.
+//
+// This is intentionally more permissive than map/flush/invalidate: queue packet
+// builders need the staged backing before user-visible alloca commit so
+// same-queue and cross-queue submissions can be chained by semaphore waits. The
+// same rule applies after a dealloca has been queued but before its decommit
+// action runs; HAL queue operations are ordered by semaphores, not by host
+// submission order.
+iree_hal_buffer_t* iree_hal_amdgpu_transient_buffer_backing_buffer(
+    iree_hal_buffer_t* buffer);
+
+// Resolves the committed backing buffer.
+//
+// Command buffer recording uses this stricter query when capturing static
+// buffer identity: queued work may use staged backing through semaphore
+// dependencies, but reusable command buffers must not capture an alloca result
+// before its allocation has completed or after deallocation has been queued.
+iree_status_t iree_hal_amdgpu_transient_buffer_resolve_committed_backing(
+    iree_hal_buffer_t* buffer, iree_hal_buffer_t** out_backing_buffer);
+
+#ifdef __cplusplus
+}  // extern "C"
+#endif  // __cplusplus
+
+#endif  // IREE_HAL_DRIVERS_AMDGPU_TRANSIENT_BUFFER_H_
diff --git a/runtime/src/iree/hal/drivers/amdgpu/util/BUILD.bazel b/runtime/src/iree/hal/drivers/amdgpu/util/BUILD.bazel
index 904c4d3..2db1bab 100644
--- a/runtime/src/iree/hal/drivers/amdgpu/util/BUILD.bazel
+++ b/runtime/src/iree/hal/drivers/amdgpu/util/BUILD.bazel
@@ -4,7 +4,12 @@
 # See https://llvm.org/LICENSE.txt for license information.
 # SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
 
-load("//build_tools/bazel:build_defs.oss.bzl", "iree_runtime_cc_library", "iree_runtime_cc_test")
+load("//build_tools/bazel:build_defs.oss.bzl", "iree_cmake_extra_content", "iree_runtime_cc_fuzz", "iree_runtime_cc_library", "iree_runtime_cc_test")
+load("//build_tools/bazel:cc_binary_benchmark.bzl", "cc_binary_benchmark")
+load(
+    "//build_tools/bazel:iree_hal_cts_test_suite.bzl",
+    "iree_hal_cts_testdata",
+)
 
 package(
     default_visibility = ["//visibility:public"],
@@ -36,6 +41,25 @@
     ],
 )
 
+iree_runtime_cc_library(
+    name = "libaqlprofile",
+    srcs = [
+        "libaqlprofile.c",
+    ],
+    hdrs = [
+        "libaqlprofile.h",
+    ],
+    deps = [
+        ":libhsa",
+        "//runtime/src/iree/base",
+        "//runtime/src/iree/base/internal",
+        "//runtime/src/iree/base/internal:dynamic_library",
+        "//runtime/src/iree/base/internal:path",
+        "//runtime/src/iree/hal/drivers/amdgpu/abi",
+        "@hsa_runtime_headers",
+    ],
+)
+
 iree_runtime_cc_test(
     name = "libhsa_test",
     srcs = ["libhsa_test.cc"],
@@ -52,6 +76,22 @@
     ],
 )
 
+iree_runtime_cc_test(
+    name = "libaqlprofile_test",
+    srcs = ["libaqlprofile_test.cc"],
+    group = "iree-hal-drivers-amdgpu-tests",
+    tags = [
+        "driver=amdgpu",
+        "nodocker",
+    ],
+    deps = [
+        ":libaqlprofile",
+        "//runtime/src/iree/base",
+        "//runtime/src/iree/testing:gtest",
+        "//runtime/src/iree/testing:gtest_main",
+    ],
+)
+
 iree_runtime_cc_library(
     name = "topology",
     srcs = [
@@ -63,6 +103,7 @@
     deps = [
         ":libhsa",
         "//runtime/src/iree/base",
+        "//runtime/src/iree/hal",
     ],
 )
 
@@ -78,43 +119,277 @@
         ":libhsa",
         ":topology",
         "//runtime/src/iree/base",
+        "//runtime/src/iree/hal",
         "//runtime/src/iree/testing:gtest",
         "//runtime/src/iree/testing:gtest_main",
     ],
 )
 
+iree_runtime_cc_library(
+    name = "hsaco_metadata",
+    srcs = ["hsaco_metadata.c"],
+    hdrs = ["hsaco_metadata.h"],
+    deps = [
+        "//runtime/src/iree/base",
+        "//runtime/src/iree/hal",
+    ],
+)
+
+iree_runtime_cc_test(
+    name = "hsaco_metadata_test",
+    srcs = ["hsaco_metadata_test.cc"],
+    group = "iree-hal-drivers-amdgpu-tests",
+    tags = [
+        "driver=amdgpu",
+        "nodocker",
+    ],
+    deps = [
+        ":hsaco_metadata",
+        "//runtime/src/iree/base",
+        "//runtime/src/iree/testing:gtest",
+        "//runtime/src/iree/testing:gtest_main",
+    ],
+)
+
+iree_runtime_cc_fuzz(
+    name = "hsaco_metadata_fuzz",
+    srcs = ["hsaco_metadata_fuzz.cc"],
+    deps = [
+        ":hsaco_metadata",
+        "//runtime/src/iree/base",
+    ],
+)
+
 ##----------------------------------------------------------------------------##
 ## Internal Utilities
 ##----------------------------------------------------------------------------##
 
-# TODO(benvanik): implement omitted files.
 iree_runtime_cc_library(
-    name = "util",
+    name = "affinity",
+    hdrs = ["affinity.h"],
+    deps = ["//runtime/src/iree/base/internal"],
+)
+
+iree_runtime_cc_library(
+    name = "packet",
+    hdrs = [
+        "aql_emitter.h",
+        "pm4_capabilities.h",
+        "pm4_emitter.h",
+    ],
+    deps = [
+        "//runtime/src/iree/base",
+        "//runtime/src/iree/hal/drivers/amdgpu/abi",
+    ],
+)
+
+iree_runtime_cc_library(
+    name = "info",
+    srcs = ["info.c"],
+    hdrs = ["info.h"],
+    deps = [
+        ":libhsa",
+        "//runtime/src/iree/base",
+    ],
+)
+
+iree_runtime_cc_library(
+    name = "target_id",
+    srcs = ["target_id.c"],
+    hdrs = ["target_id.h"],
+    textual_hdrs = ["target_id_map.inl"],
+    deps = [
+        "//runtime/src/iree/base",
+    ],
+)
+
+iree_runtime_cc_library(
+    name = "code_object_target",
+    srcs = ["code_object_target.c"],
+    hdrs = ["code_object_target.h"],
+    deps = [
+        ":target_id",
+        "//runtime/src/iree/base",
+        "//runtime/src/iree/base:core_headers",
+        "//runtime/src/iree/hal/utils:elf_format",
+    ],
+)
+
+iree_runtime_cc_test(
+    name = "code_object_target_test",
+    srcs = ["code_object_target_test.cc"],
+    group = "iree-hal-drivers-amdgpu-tests",
+    tags = [
+        "driver=amdgpu",
+        "nodocker",
+    ],
+    deps = [
+        ":code_object_target",
+        ":target_id",
+        "//runtime/src/iree/base",
+        "//runtime/src/iree/base:core_headers",
+        "//runtime/src/iree/testing:gtest",
+        "//runtime/src/iree/testing:gtest_main",
+    ],
+)
+
+iree_runtime_cc_library(
+    name = "device_clock",
+    srcs = ["device_clock.c"],
+    hdrs = ["device_clock.h"],
+    deps = [
+        ":kfd",
+        "//runtime/src/iree/base",
+    ],
+)
+
+iree_runtime_cc_library(
+    name = "kfd",
+    srcs = ["kfd.c"],
+    hdrs = ["kfd.h"],
+    deps = ["//runtime/src/iree/base"],
+)
+
+iree_runtime_cc_library(
+    name = "pm4_program",
+    srcs = ["pm4_program.c"],
+    hdrs = [
+        "pm4_program.h",
+    ],
+    deps = [
+        ":libhsa",
+        ":packet",
+        "//runtime/src/iree/base",
+    ],
+)
+
+iree_runtime_cc_library(
+    name = "queue_primitives",
     srcs = [
-        "block_pool.c",
-        "device_library.c",
-        "info.c",
-        "kfd.c",
-        "vmem.c",
+        "kernarg_ring.c",
+        "notification_ring.c",
+        "queue_upload_ring.c",
+        "signal_pool.c",
     ],
     hdrs = [
-        "affinity.h",
-        "block_pool.h",
-        "device_library.h",
+        "aql_ring.h",
+        "epoch_signal_table.h",
         "error_callback.h",
-        "info.h",
-        "kfd.h",
-        "vmem.h",
+        "kernarg_ring.h",
+        "notification_ring.h",
+        "queue_upload_ring.h",
+        "signal_pool.h",
     ],
     deps = [
         ":libhsa",
+        "//runtime/src/iree/async",
+        "//runtime/src/iree/base",
+        "//runtime/src/iree/base/internal",
+        "//runtime/src/iree/base/internal:arena",
+        "//runtime/src/iree/base/threading",
+        "//runtime/src/iree/hal",
+        "//runtime/src/iree/hal/utils:resource_set",
+    ],
+)
+
+iree_runtime_cc_library(
+    name = "vmem",
+    srcs = ["vmem.c"],
+    hdrs = ["vmem.h"],
+    deps = [
+        ":libhsa",
         ":topology",
         "//runtime/src/iree/base",
-        "//runtime/src/iree/base/internal",
+        "//runtime/src/iree/hal",
+    ],
+)
+
+iree_runtime_cc_library(
+    name = "block_pool",
+    srcs = ["block_pool.c"],
+    hdrs = ["block_pool.h"],
+    deps = [
+        ":libhsa",
+        "//runtime/src/iree/base",
         "//runtime/src/iree/base/threading",
-        "//runtime/src/iree/base/threading:thread",
-        "//runtime/src/iree/hal/drivers/amdgpu/device:binaries",
+        "//runtime/src/iree/hal/memory:tracing",
+    ],
+)
+
+iree_runtime_cc_library(
+    name = "device_library",
+    srcs = ["device_library.c"],
+    hdrs = ["device_library.h"],
+    deps = [
+        ":device_library_target",
+        ":libhsa",
+        ":topology",
+        "//runtime/src/iree/base",
+        "//runtime/src/iree/base/internal",
         "//runtime/src/iree/hal/drivers/amdgpu/device:headers",
+        "//runtime/src/iree/hal/drivers/amdgpu/device/binaries:toc",
+    ],
+)
+
+iree_runtime_cc_library(
+    name = "device_library_target",
+    srcs = ["device_library_target.c"],
+    hdrs = ["device_library_target.h"],
+    deps = [
+        ":target_id",
+        "//runtime/src/iree/base",
+    ],
+)
+
+iree_runtime_cc_library(
+    name = "util",
+    deps = [
+        ":affinity",
+        ":block_pool",
+        ":device_library",
+        ":device_library_target",
+        ":info",
+        ":kfd",
+        ":packet",
+        ":pm4_program",
+        ":queue_primitives",
+        ":target_id",
+        ":vmem",
+    ],
+)
+
+iree_runtime_cc_library(
+    name = "benchmark_flags",
+    testonly = True,
+    srcs = ["benchmark_flags.cc"],
+    hdrs = ["benchmark_flags.h"],
+    deps = [
+        "//runtime/src/iree/async",
+        "//runtime/src/iree/base",
+        "//runtime/src/iree/base/tooling:flags",
+        "@com_google_benchmark//:benchmark",
+    ],
+)
+
+cc_binary_benchmark(
+    name = "blit_benchmark",
+    srcs = ["blit_benchmark.cc"],
+    tags = [
+        "benchmark",
+        "driver=amdgpu",
+        "nodocker",
+    ],
+    deps = [
+        ":benchmark_flags",
+        "//runtime/src:runtime_defines",
+        "//runtime/src/iree/async",
+        "//runtime/src/iree/async/util:proactor_pool",
+        "//runtime/src/iree/base",
+        "//runtime/src/iree/base/threading",
+        "//runtime/src/iree/base/tooling:flags",
+        "//runtime/src/iree/hal",
+        "//runtime/src/iree/hal/drivers/amdgpu/registration",
+        "@com_google_benchmark//:benchmark",
     ],
 )
 
@@ -127,8 +402,9 @@
         "nodocker",
     ],
     deps = [
+        ":block_pool",
         ":topology",
-        ":util",
+        ":vmem",
         "//runtime/src/iree/base",
         "//runtime/src/iree/testing:gtest",
         "//runtime/src/iree/testing:gtest_main",
@@ -144,8 +420,8 @@
         "nodocker",
     ],
     deps = [
+        ":device_library",
         ":topology",
-        ":util",
         "//runtime/src/iree/base",
         "//runtime/src/iree/hal/drivers/amdgpu/device:headers",
         "//runtime/src/iree/testing:gtest",
@@ -154,6 +430,94 @@
 )
 
 iree_runtime_cc_test(
+    name = "device_library_target_test",
+    srcs = ["device_library_target_test.cc"],
+    group = "iree-hal-drivers-amdgpu-tests",
+    deps = [
+        ":device_library_target",
+        "//runtime/src/iree/base",
+        "//runtime/src/iree/testing:gtest",
+        "//runtime/src/iree/testing:gtest_main",
+    ],
+)
+
+iree_runtime_cc_test(
+    name = "target_id_test",
+    srcs = ["target_id_test.cc"],
+    group = "iree-hal-drivers-amdgpu-tests",
+    tags = [
+        "driver=amdgpu",
+        "nodocker",
+    ],
+    deps = [
+        ":target_id",
+        "//runtime/src/iree/base",
+        "//runtime/src/iree/testing:gtest",
+        "//runtime/src/iree/testing:gtest_main",
+    ],
+)
+
+iree_runtime_cc_test(
+    name = "epoch_signal_table_test",
+    srcs = ["epoch_signal_table_test.cc"],
+    group = "iree-hal-drivers-amdgpu-tests",
+    tags = [
+        "driver=amdgpu",
+        "nodocker",
+    ],
+    deps = [
+        ":queue_primitives",
+        "//runtime/src/iree/async",
+        "//runtime/src/iree/base",
+        "//runtime/src/iree/testing:gtest",
+        "//runtime/src/iree/testing:gtest_main",
+    ],
+)
+
+iree_runtime_cc_test(
+    name = "kernarg_ring_test",
+    srcs = ["kernarg_ring_test.cc"],
+    group = "iree-hal-drivers-amdgpu-tests",
+    tags = [
+        "driver=amdgpu",
+        "nodocker",
+    ],
+    deps = [
+        ":queue_primitives",
+        "//runtime/src/iree/base",
+        "//runtime/src/iree/testing:gtest",
+        "//runtime/src/iree/testing:gtest_main",
+    ],
+)
+
+iree_runtime_cc_test(
+    name = "queue_upload_ring_test",
+    srcs = ["queue_upload_ring_test.cc"],
+    group = "iree-hal-drivers-amdgpu-tests",
+    tags = [
+        "driver=amdgpu",
+        "nodocker",
+    ],
+    deps = [
+        ":queue_primitives",
+        "//runtime/src/iree/base",
+        "//runtime/src/iree/testing:gtest",
+        "//runtime/src/iree/testing:gtest_main",
+    ],
+)
+
+iree_runtime_cc_test(
+    name = "device_clock_test",
+    srcs = ["device_clock_test.cc"],
+    deps = [
+        ":device_clock",
+        "//runtime/src/iree/base",
+        "//runtime/src/iree/testing:gtest",
+        "//runtime/src/iree/testing:gtest_main",
+    ],
+)
+
+iree_runtime_cc_test(
     name = "kfd_test",
     srcs = ["kfd_test.cc"],
     group = "iree-hal-drivers-amdgpu-tests",
@@ -162,9 +526,9 @@
         "nodocker",
     ],
     deps = [
+        ":kfd",
         ":libhsa",
         ":topology",
-        ":util",
         "//runtime/src/iree/base",
         "//runtime/src/iree/testing:gtest",
         "//runtime/src/iree/testing:gtest_main",
@@ -172,6 +536,175 @@
 )
 
 iree_runtime_cc_test(
+    name = "notification_ring_test",
+    srcs = ["notification_ring_test.cc"],
+    group = "iree-hal-drivers-amdgpu-tests",
+    tags = [
+        "driver=amdgpu",
+        "nodocker",
+    ],
+    deps = [
+        ":libhsa",
+        ":queue_primitives",
+        "//runtime/src/iree/async",
+        "//runtime/src/iree/async:platform",
+        "//runtime/src/iree/base",
+        "//runtime/src/iree/base/internal:arena",
+        "//runtime/src/iree/testing:gtest",
+        "//runtime/src/iree/testing:gtest_main",
+    ],
+)
+
+iree_runtime_cc_test(
+    name = "aql_emitter_test",
+    srcs = ["aql_emitter_test.cc"],
+    group = "iree-hal-drivers-amdgpu-tests",
+    tags = [
+        "driver=amdgpu",
+        "nodocker",
+    ],
+    deps = [
+        ":packet",
+        "//runtime/src/iree/testing:gtest",
+        "//runtime/src/iree/testing:gtest_main",
+    ],
+)
+
+iree_runtime_cc_test(
+    name = "pm4_emitter_test",
+    srcs = ["pm4_emitter_test.cc"],
+    group = "iree-hal-drivers-amdgpu-tests",
+    tags = [
+        "driver=amdgpu",
+        "nodocker",
+    ],
+    deps = [
+        ":packet",
+        "//runtime/src/iree/testing:gtest",
+        "//runtime/src/iree/testing:gtest_main",
+    ],
+)
+
+iree_runtime_cc_test(
+    name = "pm4_program_test",
+    srcs = ["pm4_program_test.cc"],
+    group = "iree-hal-drivers-amdgpu-tests",
+    tags = [
+        "driver=amdgpu",
+        "nodocker",
+    ],
+    deps = [
+        ":packet",
+        ":pm4_program",
+        ":topology",
+        ":vmem",
+        "//runtime/src/iree/base",
+        "//runtime/src/iree/testing:gtest",
+        "//runtime/src/iree/testing:gtest_main",
+    ],
+)
+
+iree_cmake_extra_content(
+    content = "if(TARGET iree::hal::drivers::amdgpu::cts::testdata_amdgpu)",
+    inline = True,
+)
+
+filegroup(
+    name = "queue_benchmark_testdata_srcs",
+    testonly = True,
+    srcs = ["queue_benchmark_testdata.mlir"],
+)
+
+iree_hal_cts_testdata(
+    backend_name = "amdgpu",
+    flag_values = {
+        "ROCM_BC_DIR": "@amdgpu_device_libs//:bitcode",
+        "ROCM_TARGET": "//build_tools/bazel:rocm_test_target",
+    },
+    flags = [
+        "--iree-rocm-target={ROCM_TARGET}",
+        "--iree-rocm-bc-dir={ROCM_BC_DIR}",
+    ],
+    format_name = "amdgpu_queue_benchmark",
+    format_string = '"amdgcn-amd-amdhsa--{ROCM_TARGET}"',
+    identifier = "iree_queue_benchmark_testdata_amdgpu",
+    target_device = "amdgpu",
+    testdata = "//runtime/src/iree/hal/drivers/amdgpu/util:queue_benchmark_testdata_srcs",
+)
+
+cc_binary_benchmark(
+    name = "queue_benchmark",
+    srcs = ["queue_benchmark.cc"],
+    tags = [
+        "benchmark",
+        "driver=amdgpu",
+        "nodocker",
+    ],
+    deps = [
+        ":benchmark_flags",
+        ":testdata_amdgpu_queue_benchmark",
+        "//runtime/src:runtime_defines",
+        "//runtime/src/iree/async",
+        "//runtime/src/iree/async/util:proactor_pool",
+        "//runtime/src/iree/base",
+        "//runtime/src/iree/base/threading",
+        "//runtime/src/iree/base/tooling:flags",
+        "//runtime/src/iree/hal",
+        "//runtime/src/iree/hal/drivers/amdgpu",
+        "//runtime/src/iree/hal/drivers/amdgpu:headers",
+        "//runtime/src/iree/hal/drivers/amdgpu:queue_affinity",
+        "//runtime/src/iree/hal/drivers/amdgpu/cts:testdata_amdgpu",
+        "//runtime/src/iree/hal/drivers/amdgpu/device:dispatch",
+        "//runtime/src/iree/hal/drivers/amdgpu/device:headers",
+        "//runtime/src/iree/hal/drivers/amdgpu/registration",
+        "//runtime/src/iree/hal/memory:tlsf_pool",
+        "//runtime/src/iree/io:file_handle",
+        "@com_google_benchmark//:benchmark",
+    ],
+)
+
+iree_cmake_extra_content(
+    content = "endif()",
+    inline = True,
+)
+
+iree_runtime_cc_test(
+    name = "signal_pool_test",
+    srcs = ["signal_pool_test.cc"],
+    group = "iree-hal-drivers-amdgpu-tests",
+    tags = [
+        "driver=amdgpu",
+        "nodocker",
+    ],
+    deps = [
+        ":libhsa",
+        ":queue_primitives",
+        ":topology",
+        "//runtime/src/iree/base",
+        "//runtime/src/iree/testing:gtest",
+        "//runtime/src/iree/testing:gtest_main",
+    ],
+)
+
+cc_binary_benchmark(
+    name = "signal_pool_benchmark",
+    srcs = ["signal_pool_benchmark.cc"],
+    tags = [
+        "benchmark",
+        "driver=amdgpu",
+        "nodocker",
+    ],
+    deps = [
+        ":libhsa",
+        ":queue_primitives",
+        ":topology",
+        "//runtime/src:runtime_defines",
+        "//runtime/src/iree/base",
+        "@com_google_benchmark//:benchmark",
+    ],
+)
+
+iree_runtime_cc_test(
     name = "vmem_test",
     srcs = ["vmem_test.cc"],
     group = "iree-hal-drivers-amdgpu-tests",
@@ -182,7 +715,7 @@
     deps = [
         ":libhsa",
         ":topology",
-        ":util",
+        ":vmem",
         "//runtime/src/iree/base",
         "//runtime/src/iree/testing:gtest",
         "//runtime/src/iree/testing:gtest_main",
diff --git a/runtime/src/iree/hal/drivers/amdgpu/util/CMakeLists.txt b/runtime/src/iree/hal/drivers/amdgpu/util/CMakeLists.txt
index cd686cc..63dd650 100644
--- a/runtime/src/iree/hal/drivers/amdgpu/util/CMakeLists.txt
+++ b/runtime/src/iree/hal/drivers/amdgpu/util/CMakeLists.txt
@@ -28,6 +28,24 @@
   PUBLIC
 )
 
+iree_cc_library(
+  NAME
+    libaqlprofile
+  HDRS
+    "libaqlprofile.h"
+  SRCS
+    "libaqlprofile.c"
+  DEPS
+    ::libhsa
+    hsa_runtime::headers
+    iree::base
+    iree::base::internal
+    iree::base::internal::dynamic_library
+    iree::base::internal::path
+    iree::hal::drivers::amdgpu::abi
+  PUBLIC
+)
+
 iree_cc_test(
   NAME
     libhsa_test
@@ -45,6 +63,23 @@
     "iree-hal-drivers-amdgpu-tests"
 )
 
+iree_cc_test(
+  NAME
+    libaqlprofile_test
+  SRCS
+    "libaqlprofile_test.cc"
+  DEPS
+    ::libaqlprofile
+    iree::base
+    iree::testing::gtest
+    iree::testing::gtest_main
+  LABELS
+    "driver=amdgpu"
+    "nodocker"
+  GROUP
+    "iree-hal-drivers-amdgpu-tests"
+)
+
 iree_cc_library(
   NAME
     topology
@@ -55,6 +90,7 @@
   DEPS
     ::libhsa
     iree::base
+    iree::hal
   PUBLIC
 )
 
@@ -67,6 +103,7 @@
     ::libhsa
     ::topology
     iree::base
+    iree::hal
     iree::testing::gtest
     iree::testing::gtest_main
   LABELS
@@ -78,41 +115,321 @@
 
 iree_cc_library(
   NAME
-    util
+    hsaco_metadata
+  HDRS
+    "hsaco_metadata.h"
+  SRCS
+    "hsaco_metadata.c"
+  DEPS
+    iree::base
+    iree::hal
+  PUBLIC
+)
+
+iree_cc_test(
+  NAME
+    hsaco_metadata_test
+  SRCS
+    "hsaco_metadata_test.cc"
+  DEPS
+    ::hsaco_metadata
+    iree::base
+    iree::testing::gtest
+    iree::testing::gtest_main
+  LABELS
+    "driver=amdgpu"
+    "nodocker"
+  GROUP
+    "iree-hal-drivers-amdgpu-tests"
+)
+
+iree_cc_fuzz(
+  NAME
+    hsaco_metadata_fuzz
+  SRCS
+    "hsaco_metadata_fuzz.cc"
+  DEPS
+    ::hsaco_metadata
+    iree::base
+)
+
+iree_cc_library(
+  NAME
+    affinity
   HDRS
     "affinity.h"
-    "block_pool.h"
-    "device_library.h"
-    "error_callback.h"
+  DEPS
+    iree::base::internal
+  PUBLIC
+)
+
+iree_cc_library(
+  NAME
+    packet
+  HDRS
+    "aql_emitter.h"
+    "pm4_capabilities.h"
+    "pm4_emitter.h"
+  DEPS
+    iree::base
+    iree::hal::drivers::amdgpu::abi
+  PUBLIC
+)
+
+iree_cc_library(
+  NAME
+    info
+  HDRS
     "info.h"
+  SRCS
+    "info.c"
+  DEPS
+    ::libhsa
+    iree::base
+  PUBLIC
+)
+
+iree_cc_library(
+  NAME
+    target_id
+  HDRS
+    "target_id.h"
+  TEXTUAL_HDRS
+    "target_id_map.inl"
+  SRCS
+    "target_id.c"
+  DEPS
+    iree::base
+  PUBLIC
+)
+
+iree_cc_library(
+  NAME
+    code_object_target
+  HDRS
+    "code_object_target.h"
+  SRCS
+    "code_object_target.c"
+  DEPS
+    ::target_id
+    iree::base
+    iree::base::core_headers
+    iree::hal::utils::elf_format
+  PUBLIC
+)
+
+iree_cc_test(
+  NAME
+    code_object_target_test
+  SRCS
+    "code_object_target_test.cc"
+  DEPS
+    ::code_object_target
+    ::target_id
+    iree::base
+    iree::base::core_headers
+    iree::testing::gtest
+    iree::testing::gtest_main
+  LABELS
+    "driver=amdgpu"
+    "nodocker"
+  GROUP
+    "iree-hal-drivers-amdgpu-tests"
+)
+
+iree_cc_library(
+  NAME
+    device_clock
+  HDRS
+    "device_clock.h"
+  SRCS
+    "device_clock.c"
+  DEPS
+    ::kfd
+    iree::base
+  PUBLIC
+)
+
+iree_cc_library(
+  NAME
+    kfd
+  HDRS
     "kfd.h"
+  SRCS
+    "kfd.c"
+  DEPS
+    iree::base
+  PUBLIC
+)
+
+iree_cc_library(
+  NAME
+    pm4_program
+  HDRS
+    "pm4_program.h"
+  SRCS
+    "pm4_program.c"
+  DEPS
+    ::libhsa
+    ::packet
+    iree::base
+  PUBLIC
+)
+
+iree_cc_library(
+  NAME
+    queue_primitives
+  HDRS
+    "aql_ring.h"
+    "epoch_signal_table.h"
+    "error_callback.h"
+    "kernarg_ring.h"
+    "notification_ring.h"
+    "queue_upload_ring.h"
+    "signal_pool.h"
+  SRCS
+    "kernarg_ring.c"
+    "notification_ring.c"
+    "queue_upload_ring.c"
+    "signal_pool.c"
+  DEPS
+    ::libhsa
+    iree::async
+    iree::base
+    iree::base::internal
+    iree::base::internal::arena
+    iree::base::threading
+    iree::hal
+    iree::hal::utils::resource_set
+  PUBLIC
+)
+
+iree_cc_library(
+  NAME
+    vmem
+  HDRS
     "vmem.h"
   SRCS
-    "block_pool.c"
-    "device_library.c"
-    "info.c"
-    "kfd.c"
     "vmem.c"
   DEPS
     ::libhsa
     ::topology
     iree::base
-    iree::base::internal
+    iree::hal
+  PUBLIC
+)
+
+iree_cc_library(
+  NAME
+    block_pool
+  HDRS
+    "block_pool.h"
+  SRCS
+    "block_pool.c"
+  DEPS
+    ::libhsa
+    iree::base
     iree::base::threading
-    iree::base::threading::thread
-    iree::hal::drivers::amdgpu::device::binaries
+    iree::hal::memory::tracing
+  PUBLIC
+)
+
+iree_cc_library(
+  NAME
+    device_library
+  HDRS
+    "device_library.h"
+  SRCS
+    "device_library.c"
+  DEPS
+    ::device_library_target
+    ::libhsa
+    ::topology
+    iree::base
+    iree::base::internal
+    iree::hal::drivers::amdgpu::device::binaries::toc
     iree::hal::drivers::amdgpu::device::headers
   PUBLIC
 )
 
+iree_cc_library(
+  NAME
+    device_library_target
+  HDRS
+    "device_library_target.h"
+  SRCS
+    "device_library_target.c"
+  DEPS
+    ::target_id
+    iree::base
+  PUBLIC
+)
+
+iree_cc_library(
+  NAME
+    util
+  DEPS
+    ::affinity
+    ::block_pool
+    ::device_library
+    ::device_library_target
+    ::info
+    ::kfd
+    ::packet
+    ::pm4_program
+    ::queue_primitives
+    ::target_id
+    ::vmem
+  PUBLIC
+)
+
+iree_cc_library(
+  NAME
+    benchmark_flags
+  HDRS
+    "benchmark_flags.h"
+  SRCS
+    "benchmark_flags.cc"
+  DEPS
+    benchmark
+    iree::async
+    iree::base
+    iree::base::tooling::flags
+  TESTONLY
+  PUBLIC
+)
+
+iree_cc_binary_benchmark(
+  NAME
+    blit_benchmark
+  SRCS
+    "blit_benchmark.cc"
+  DEPS
+    ::benchmark_flags
+    benchmark
+    iree::async
+    iree::async::util::proactor_pool
+    iree::base
+    iree::base::threading
+    iree::base::tooling::flags
+    iree::hal
+    iree::hal::drivers::amdgpu::registration
+  TESTONLY
+  LABELS
+    "benchmark"
+    "driver=amdgpu"
+    "nodocker"
+)
+
 iree_cc_test(
   NAME
     block_pool_test
   SRCS
     "block_pool_test.cc"
   DEPS
+    ::block_pool
     ::topology
-    ::util
+    ::vmem
     iree::base
     iree::testing::gtest
     iree::testing::gtest_main
@@ -129,8 +446,8 @@
   SRCS
     "device_library_test.cc"
   DEPS
+    ::device_library
     ::topology
-    ::util
     iree::base
     iree::hal::drivers::amdgpu::device::headers
     iree::testing::gtest
@@ -144,13 +461,25 @@
 
 iree_cc_test(
   NAME
-    kfd_test
+    device_library_target_test
   SRCS
-    "kfd_test.cc"
+    "device_library_target_test.cc"
   DEPS
-    ::libhsa
-    ::topology
-    ::util
+    ::device_library_target
+    iree::base
+    iree::testing::gtest
+    iree::testing::gtest_main
+  GROUP
+    "iree-hal-drivers-amdgpu-tests"
+)
+
+iree_cc_test(
+  NAME
+    target_id_test
+  SRCS
+    "target_id_test.cc"
+  DEPS
+    ::target_id
     iree::base
     iree::testing::gtest
     iree::testing::gtest_main
@@ -163,13 +492,264 @@
 
 iree_cc_test(
   NAME
+    epoch_signal_table_test
+  SRCS
+    "epoch_signal_table_test.cc"
+  DEPS
+    ::queue_primitives
+    iree::async
+    iree::base
+    iree::testing::gtest
+    iree::testing::gtest_main
+  LABELS
+    "driver=amdgpu"
+    "nodocker"
+  GROUP
+    "iree-hal-drivers-amdgpu-tests"
+)
+
+iree_cc_test(
+  NAME
+    kernarg_ring_test
+  SRCS
+    "kernarg_ring_test.cc"
+  DEPS
+    ::queue_primitives
+    iree::base
+    iree::testing::gtest
+    iree::testing::gtest_main
+  LABELS
+    "driver=amdgpu"
+    "nodocker"
+  GROUP
+    "iree-hal-drivers-amdgpu-tests"
+)
+
+iree_cc_test(
+  NAME
+    queue_upload_ring_test
+  SRCS
+    "queue_upload_ring_test.cc"
+  DEPS
+    ::queue_primitives
+    iree::base
+    iree::testing::gtest
+    iree::testing::gtest_main
+  LABELS
+    "driver=amdgpu"
+    "nodocker"
+  GROUP
+    "iree-hal-drivers-amdgpu-tests"
+)
+
+iree_cc_test(
+  NAME
+    device_clock_test
+  SRCS
+    "device_clock_test.cc"
+  DEPS
+    ::device_clock
+    iree::base
+    iree::testing::gtest
+    iree::testing::gtest_main
+)
+
+iree_cc_test(
+  NAME
+    kfd_test
+  SRCS
+    "kfd_test.cc"
+  DEPS
+    ::kfd
+    ::libhsa
+    ::topology
+    iree::base
+    iree::testing::gtest
+    iree::testing::gtest_main
+  LABELS
+    "driver=amdgpu"
+    "nodocker"
+  GROUP
+    "iree-hal-drivers-amdgpu-tests"
+)
+
+iree_cc_test(
+  NAME
+    notification_ring_test
+  SRCS
+    "notification_ring_test.cc"
+  DEPS
+    ::libhsa
+    ::queue_primitives
+    iree::async
+    iree::async::platform
+    iree::base
+    iree::base::internal::arena
+    iree::testing::gtest
+    iree::testing::gtest_main
+  LABELS
+    "driver=amdgpu"
+    "nodocker"
+  GROUP
+    "iree-hal-drivers-amdgpu-tests"
+)
+
+iree_cc_test(
+  NAME
+    aql_emitter_test
+  SRCS
+    "aql_emitter_test.cc"
+  DEPS
+    ::packet
+    iree::testing::gtest
+    iree::testing::gtest_main
+  LABELS
+    "driver=amdgpu"
+    "nodocker"
+  GROUP
+    "iree-hal-drivers-amdgpu-tests"
+)
+
+iree_cc_test(
+  NAME
+    pm4_emitter_test
+  SRCS
+    "pm4_emitter_test.cc"
+  DEPS
+    ::packet
+    iree::testing::gtest
+    iree::testing::gtest_main
+  LABELS
+    "driver=amdgpu"
+    "nodocker"
+  GROUP
+    "iree-hal-drivers-amdgpu-tests"
+)
+
+iree_cc_test(
+  NAME
+    pm4_program_test
+  SRCS
+    "pm4_program_test.cc"
+  DEPS
+    ::packet
+    ::pm4_program
+    ::topology
+    ::vmem
+    iree::base
+    iree::testing::gtest
+    iree::testing::gtest_main
+  LABELS
+    "driver=amdgpu"
+    "nodocker"
+  GROUP
+    "iree-hal-drivers-amdgpu-tests"
+)
+if(TARGET iree::hal::drivers::amdgpu::cts::testdata_amdgpu)
+add_custom_command(
+  OUTPUT
+    queue_benchmark_testdata_srcs.stamp
+  COMMAND
+    ${CMAKE_COMMAND} -E touch queue_benchmark_testdata_srcs.stamp
+  DEPENDS
+    "queue_benchmark_testdata.mlir"
+)
+
+add_custom_target(queue_benchmark_testdata_srcs
+    DEPENDS queue_benchmark_testdata_srcs.stamp
+)
+
+iree_hal_cts_testdata(
+  FORMAT_NAME
+    amdgpu_queue_benchmark
+  TARGET_DEVICE
+    "amdgpu"
+  IDENTIFIER
+    "iree_queue_benchmark_testdata_amdgpu"
+  BACKEND_NAME
+    "amdgpu"
+  FORMAT_STRING
+    "\"amdgcn-amd-amdhsa--${IREE_ROCM_TEST_TARGET_CHIP}\""
+  TESTDATA_DIR
+    "${PROJECT_SOURCE_DIR}/runtime/src/iree/hal/drivers/amdgpu/util"
+  FLAGS
+    "--iree-rocm-target=${IREE_ROCM_TEST_TARGET_CHIP}"
+)
+
+iree_cc_binary_benchmark(
+  NAME
+    queue_benchmark
+  SRCS
+    "queue_benchmark.cc"
+  DEPS
+    ::benchmark_flags
+    ::testdata_amdgpu_queue_benchmark
+    benchmark
+    iree::async
+    iree::async::util::proactor_pool
+    iree::base
+    iree::base::threading
+    iree::base::tooling::flags
+    iree::hal
+    iree::hal::drivers::amdgpu
+    iree::hal::drivers::amdgpu::cts::testdata_amdgpu
+    iree::hal::drivers::amdgpu::device::dispatch
+    iree::hal::drivers::amdgpu::device::headers
+    iree::hal::drivers::amdgpu::headers
+    iree::hal::drivers::amdgpu::queue_affinity
+    iree::hal::drivers::amdgpu::registration
+    iree::hal::memory::tlsf_pool
+    iree::io::file_handle
+  TESTONLY
+  LABELS
+    "benchmark"
+    "driver=amdgpu"
+    "nodocker"
+)
+endif()
+iree_cc_test(
+  NAME
+    signal_pool_test
+  SRCS
+    "signal_pool_test.cc"
+  DEPS
+    ::libhsa
+    ::queue_primitives
+    ::topology
+    iree::base
+    iree::testing::gtest
+    iree::testing::gtest_main
+  LABELS
+    "driver=amdgpu"
+    "nodocker"
+  GROUP
+    "iree-hal-drivers-amdgpu-tests"
+)
+
+iree_cc_binary_benchmark(
+  NAME
+    signal_pool_benchmark
+  SRCS
+    "signal_pool_benchmark.cc"
+  DEPS
+    ::libhsa
+    ::queue_primitives
+    ::topology
+    benchmark
+    iree::base
+  TESTONLY
+  LABELS
+    "benchmark"
+    "driver=amdgpu"
+    "nodocker"
+)
+
+iree_cc_test(
+  NAME
     vmem_test
   SRCS
     "vmem_test.cc"
   DEPS
     ::libhsa
     ::topology
-    ::util
+    ::vmem
     iree::base
     iree::testing::gtest
     iree::testing::gtest_main
diff --git a/runtime/src/iree/hal/drivers/amdgpu/util/aql_emitter.h b/runtime/src/iree/hal/drivers/amdgpu/util/aql_emitter.h
new file mode 100644
index 0000000..d4e6e39
--- /dev/null
+++ b/runtime/src/iree/hal/drivers/amdgpu/util/aql_emitter.h
@@ -0,0 +1,208 @@
+// Copyright 2026 The IREE Authors
+//
+// Licensed under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+
+// AQL packet emission helpers. Pure functions that populate packet fields and
+// return the header bits. They do NOT write the header — the caller commits
+// it separately via iree_hal_amdgpu_aql_ring_commit(), which performs the
+// atomic store-release that publishes the packet to the CP.
+//
+// This separation allows the caller to:
+//   - Batch multiple packet commits before a single doorbell ring
+//   - Control completion_signal assignment (epoch signal on last packet only)
+//   - Populate kernarg memory between emission and commit
+//
+// All emitters zero reserved fields to prevent undefined behavior from stale
+// ring data.
+//
+// Direct host-queue submissions currently set the BARRIER bit on every packet
+// so one AQL queue behaves as a single in-order dependency chain.
+// Command-buffer replay is more precise: it sets the BARRIER bit only at
+// wait-prefix, execution-barrier, and final-completion boundaries. HAL queue
+// submission order alone is not user-visible ordering; semaphore signal->wait
+// edges define the visible DAG.
+//
+// Submission ordering contract (from kernarg_ring.h):
+//   1. Reserve AQL ring slots (backpressure gate)
+//   2. Allocate kernarg blocks (sizing invariant guarantees space)
+//   3. Populate kernarg + packet fields (emit helpers)
+//   4. Commit packet headers (atomic store-release)
+//   5. Ring doorbell (once per batch)
+
+#ifndef IREE_HAL_DRIVERS_AMDGPU_UTIL_AQL_EMITTER_H_
+#define IREE_HAL_DRIVERS_AMDGPU_UTIL_AQL_EMITTER_H_
+
+#include <stdbool.h>
+#include <string.h>
+
+#include "iree/base/api.h"
+#include "iree/hal/drivers/amdgpu/abi/queue.h"
+#include "iree/hal/drivers/amdgpu/abi/signal.h"
+
+#ifdef __cplusplus
+extern "C" {
+#endif  // __cplusplus
+
+// Packet header controls shared by all AQL emitters.
+//
+// Direct host-queue submissions still use barrier+system-fence packets so one
+// queue behaves as an in-order chain. Command-buffer replay can opt into
+// non-barrier dispatch packets and only set the barrier bit at logical ordering
+// boundaries.
+typedef struct iree_hal_amdgpu_aql_packet_control_t {
+  // True when the packet participates in AQL queue-order dependency chaining.
+  bool has_barrier;
+  // Acquire fence scope encoded in the packet header.
+  iree_hsa_fence_scope_t acquire_fence_scope;
+  // Release fence scope encoded in the packet header.
+  iree_hsa_fence_scope_t release_fence_scope;
+} iree_hal_amdgpu_aql_packet_control_t;
+
+// Returns packet control with caller-selected barrier policy and scopes.
+static inline iree_hal_amdgpu_aql_packet_control_t
+iree_hal_amdgpu_aql_packet_control(bool has_barrier,
+                                   iree_hsa_fence_scope_t acquire_fence_scope,
+                                   iree_hsa_fence_scope_t release_fence_scope) {
+  iree_hal_amdgpu_aql_packet_control_t packet_control;
+  packet_control.has_barrier = has_barrier;
+  packet_control.acquire_fence_scope = acquire_fence_scope;
+  packet_control.release_fence_scope = release_fence_scope;
+  return packet_control;
+}
+
+// Returns packet control for a barrier packet with caller-selected scopes.
+static inline iree_hal_amdgpu_aql_packet_control_t
+iree_hal_amdgpu_aql_packet_control_barrier(
+    iree_hsa_fence_scope_t acquire_fence_scope,
+    iree_hsa_fence_scope_t release_fence_scope) {
+  return iree_hal_amdgpu_aql_packet_control(
+      /*has_barrier=*/true, acquire_fence_scope, release_fence_scope);
+}
+
+// Returns the current host-queue packet policy: barrier + system-scope fences.
+static inline iree_hal_amdgpu_aql_packet_control_t
+iree_hal_amdgpu_aql_packet_control_barrier_system(void) {
+  return iree_hal_amdgpu_aql_packet_control_barrier(
+      IREE_HSA_FENCE_SCOPE_SYSTEM, IREE_HSA_FENCE_SCOPE_SYSTEM);
+}
+
+// Builds the 16-bit packet header from |packet_type| and |packet_control|.
+static inline uint16_t iree_hal_amdgpu_aql_make_header(
+    iree_hsa_packet_type_t packet_type,
+    iree_hal_amdgpu_aql_packet_control_t packet_control) {
+  return (uint16_t)iree_hsa_make_packet_header(
+      packet_type, packet_control.has_barrier,
+      packet_control.acquire_fence_scope, packet_control.release_fence_scope);
+}
+
+// Populates a kernel dispatch packet and returns the 16-bit AQL header.
+// The dispatch setup word (which encodes the grid dimension count) is
+// returned via |out_setup| for the caller to pass to
+// iree_hal_amdgpu_aql_ring_commit().
+//
+// Does NOT write the header word — the caller commits it after all packet
+// fields and kernarg memory are fully populated.
+static inline uint16_t iree_hal_amdgpu_aql_emit_dispatch(
+    iree_hsa_kernel_dispatch_packet_t* packet, uint64_t kernel_object,
+    const void* kernarg_address, const uint16_t workgroup_size[3],
+    const uint32_t grid_size[3], uint32_t private_segment_size,
+    uint32_t group_segment_size,
+    iree_hal_amdgpu_aql_packet_control_t packet_control,
+    iree_hsa_signal_t completion_signal, uint16_t* out_setup) {
+  // Setup encodes the number of grid dimensions (always 3 for IREE).
+  *out_setup = 3;
+
+  packet->workgroup_size[0] = workgroup_size[0];
+  packet->workgroup_size[1] = workgroup_size[1];
+  packet->workgroup_size[2] = workgroup_size[2];
+  packet->reserved0 = 0;
+  packet->grid_size[0] = grid_size[0];
+  packet->grid_size[1] = grid_size[1];
+  packet->grid_size[2] = grid_size[2];
+  packet->private_segment_size = private_segment_size;
+  packet->group_segment_size = group_segment_size;
+  packet->kernel_object = kernel_object;
+  packet->kernarg_address = (void*)kernarg_address;
+  packet->reserved2 = 0;
+  packet->completion_signal = completion_signal;
+
+  return iree_hal_amdgpu_aql_make_header(IREE_HSA_PACKET_TYPE_KERNEL_DISPATCH,
+                                         packet_control);
+}
+
+// Populates an AMD barrier-value packet and returns the 16-bit AQL header.
+// For this vendor packet, the upper 16 bits of the committed first dword carry
+// AmdFormat/reserved rather than the normal dispatch setup field; they are
+// returned via |out_setup|.
+//
+// The barrier halts the CP until:
+//   (signal_load(dep_signal) & mask) CONDITION compare_value
+//
+// For cross-queue epoch waits the typical usage is:
+//   dep_signal = source_queue->epoch.signal
+//   condition  = IREE_HSA_SIGNAL_CONDITION_LT
+//   compare_value = EPOCH_INITIAL_VALUE - target_epoch + 1
+//   mask = INT64_MAX (all non-sign bits)
+static inline uint16_t iree_hal_amdgpu_aql_emit_barrier_value(
+    iree_hsa_amd_barrier_value_packet_t* packet, iree_hsa_signal_t dep_signal,
+    iree_hsa_signal_condition_t condition,
+    iree_hsa_signal_value_t compare_value, iree_hsa_signal_value_t mask,
+    iree_hal_amdgpu_aql_packet_control_t packet_control,
+    iree_hsa_signal_t completion_signal, uint16_t* out_setup) {
+  // Keep the entire first dword (primary header + AmdFormat/reserved) untouched
+  // until aql_ring_commit publishes it with release semantics.
+  packet->reserved0 = 0;
+  packet->signal = dep_signal;
+  packet->value = compare_value;
+  packet->mask = mask;
+  packet->cond = (iree_hsa_signal_condition32_t)condition;
+  packet->reserved1 = 0;
+  packet->reserved2 = 0;
+  packet->reserved3 = 0;
+  packet->completion_signal = completion_signal;
+  *out_setup = IREE_HSA_AMD_AQL_FORMAT_BARRIER_VALUE;
+
+  // The primary header uses VENDOR_SPECIFIC packet type for AMD extensions.
+  return iree_hal_amdgpu_aql_make_header(IREE_HSA_PACKET_TYPE_VENDOR_SPECIFIC,
+                                         packet_control);
+}
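+
+// Sketch of the cross-queue epoch wait described above. |slot| is an
+// iree_hal_amdgpu_aql_packet_t from util/aql_ring.h; the source queue epoch
+// state and EPOCH_INITIAL_VALUE are illustrative names, not APIs in this file:
+//   uint16_t setup = 0;
+//   const uint16_t header = iree_hal_amdgpu_aql_emit_barrier_value(
+//       &slot->barrier_value, source_queue->epoch.signal,
+//       IREE_HSA_SIGNAL_CONDITION_LT,
+//       /*compare_value=*/EPOCH_INITIAL_VALUE - target_epoch + 1,
+//       /*mask=*/INT64_MAX,
+//       iree_hal_amdgpu_aql_packet_control_barrier_system(),
+//       /*completion_signal=*/iree_hsa_signal_null(), &setup);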
+
+// Populates a barrier-AND packet and returns the 16-bit AQL header.
+// The barrier halts the CP until all non-null dependency signals reach 0.
+// Up to 5 dependency signals are supported per packet.
+static inline uint16_t iree_hal_amdgpu_aql_emit_barrier_and(
+    iree_hsa_barrier_and_packet_t* packet, const iree_hsa_signal_t* dep_signals,
+    uint32_t dep_count, iree_hal_amdgpu_aql_packet_control_t packet_control,
+    iree_hsa_signal_t completion_signal) {
+  packet->reserved0 = 0;
+  packet->reserved1 = 0;
+  // Fill dependency signals, nulling any unused slots.
+  for (uint32_t i = 0; i < IREE_ARRAYSIZE(packet->dep_signal); ++i) {
+    packet->dep_signal[i] =
+        i < dep_count ? dep_signals[i] : iree_hsa_signal_null();
+  }
+  packet->reserved2 = 0;
+  packet->completion_signal = completion_signal;
+
+  return iree_hal_amdgpu_aql_make_header(IREE_HSA_PACKET_TYPE_BARRIER_AND,
+                                         packet_control);
+}
+
+// Populates a no-op packet in |packet| and returns its 16-bit AQL header.
+// Zero-dependency BARRIER_AND is the canonical "consume one slot, do no work"
+// packet. Callers choose whether that no-op packet carries a barrier edge and
+// what fence scopes it should use.
+static inline uint16_t iree_hal_amdgpu_aql_emit_nop(
+    iree_hsa_barrier_and_packet_t* packet,
+    iree_hal_amdgpu_aql_packet_control_t packet_control,
+    iree_hsa_signal_t completion_signal) {
+  return iree_hal_amdgpu_aql_emit_barrier_and(packet, /*dep_signals=*/NULL,
+                                              /*dep_count=*/0, packet_control,
+                                              completion_signal);
+}
+
+#ifdef __cplusplus
+}  // extern "C"
+#endif  // __cplusplus
+
+#endif  // IREE_HAL_DRIVERS_AMDGPU_UTIL_AQL_EMITTER_H_
diff --git a/runtime/src/iree/hal/drivers/amdgpu/util/aql_emitter_test.cc b/runtime/src/iree/hal/drivers/amdgpu/util/aql_emitter_test.cc
new file mode 100644
index 0000000..a289d92
--- /dev/null
+++ b/runtime/src/iree/hal/drivers/amdgpu/util/aql_emitter_test.cc
@@ -0,0 +1,54 @@
+// Copyright 2026 The IREE Authors
+//
+// Licensed under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+
+#include "iree/hal/drivers/amdgpu/util/aql_emitter.h"
+
+#include <cstring>
+
+#include "iree/testing/gtest.h"
+
+namespace {
+
+static iree_hsa_signal_t MakeSignal(uint64_t handle) {
+  iree_hsa_signal_t signal = {};
+  signal.handle = handle;
+  return signal;
+}
+
+TEST(AQLEmitterTest, EmitsBarrierValuePacketBody) {
+  iree_hsa_amd_barrier_value_packet_t packet;
+  std::memset(&packet, 0xCC, sizeof(packet));
+
+  const iree_hal_amdgpu_aql_packet_control_t packet_control =
+      iree_hal_amdgpu_aql_packet_control_barrier(IREE_HSA_FENCE_SCOPE_AGENT,
+                                                 IREE_HSA_FENCE_SCOPE_SYSTEM);
+  uint16_t setup = 0;
+  const uint16_t header = iree_hal_amdgpu_aql_emit_barrier_value(
+      &packet, MakeSignal(0x123456789ABCDEF0ull), IREE_HSA_SIGNAL_CONDITION_LT,
+      /*compare_value=*/0x1122334455667788ll,
+      /*mask=*/0x7FFFFFFFFFFFFFFFll, packet_control,
+      MakeSignal(0x0FEDCBA987654320ull), &setup);
+
+  EXPECT_EQ(header, iree_hal_amdgpu_aql_make_header(
+                        IREE_HSA_PACKET_TYPE_VENDOR_SPECIFIC, packet_control));
+  EXPECT_EQ(setup, IREE_HSA_AMD_AQL_FORMAT_BARRIER_VALUE);
+
+  uint32_t first_dword = 0;
+  std::memcpy(&first_dword, &packet, sizeof(first_dword));
+  EXPECT_EQ(first_dword, 0xCCCCCCCCu);
+  EXPECT_EQ(packet.reserved0, 0u);
+  EXPECT_EQ(packet.signal.handle, 0x123456789ABCDEF0ull);
+  EXPECT_EQ(packet.value, 0x1122334455667788ll);
+  EXPECT_EQ(packet.mask, 0x7FFFFFFFFFFFFFFFll);
+  EXPECT_EQ(packet.cond, static_cast<iree_hsa_signal_condition32_t>(
+                             IREE_HSA_SIGNAL_CONDITION_LT));
+  EXPECT_EQ(packet.reserved1, 0u);
+  EXPECT_EQ(packet.reserved2, 0u);
+  EXPECT_EQ(packet.reserved3, 0u);
+  EXPECT_EQ(packet.completion_signal.handle, 0x0FEDCBA987654320ull);
+}
+
+}  // namespace
diff --git a/runtime/src/iree/hal/drivers/amdgpu/util/aql_ring.h b/runtime/src/iree/hal/drivers/amdgpu/util/aql_ring.h
new file mode 100644
index 0000000..2b60a6c
--- /dev/null
+++ b/runtime/src/iree/hal/drivers/amdgpu/util/aql_ring.h
@@ -0,0 +1,230 @@
+// Copyright 2026 The IREE Authors
+//
+// Licensed under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+
+// Low-level interface to the hardware AQL ring buffer. Handles packet slot
+// reservation, header commit, and doorbell signaling.
+//
+// The ring caches hot pointers from iree_amd_queue_t at initialization:
+// ring base, mask, doorbell MMIO pointer, and the atomic write/read dispatch
+// IDs. All hot-path operations are inline with zero libhsa indirection.
+//
+// Thread safety:
+//   reserve() is multi-producer safe (atomic_fetch_add on write_dispatch_id).
+//   commit() is per-slot (no cross-slot contention).
+//   doorbell() is idempotent (multiple concurrent writes are harmless).
+//
+// Memory ordering:
+//   The CP starts processing a packet when it observes a valid header. The
+//   header commit (atomic store-release) is therefore the publication barrier:
+//   all prior writes to the packet's fields and to kernarg memory are ordered
+//   before the CP can read them. The doorbell is a wakeup hint with release
+//   semantics to ensure header writes are visible before the CP re-scans.
+
+#ifndef IREE_HAL_DRIVERS_AMDGPU_UTIL_AQL_RING_H_
+#define IREE_HAL_DRIVERS_AMDGPU_UTIL_AQL_RING_H_
+
+#include "iree/base/api.h"
+#include "iree/base/internal/atomics.h"
+#include "iree/base/threading/processor.h"
+#include "iree/hal/drivers/amdgpu/abi/queue.h"
+#include "iree/hal/drivers/amdgpu/abi/signal.h"
+
+#ifdef __cplusplus
+extern "C" {
+#endif  // __cplusplus
+
+//===----------------------------------------------------------------------===//
+// iree_hal_amdgpu_aql_packet_t
+//===----------------------------------------------------------------------===//
+
+// A single AQL packet slot. 64 bytes, cache-line aligned. Using a union gives
+// natural ring[i] indexing and allows type-punning to the specific packet type
+// without casts. The alignment attribute ensures the compiler knows every
+// packet is cache-line aligned at every access point.
+typedef union iree_alignas(64) iree_hal_amdgpu_aql_packet_t {
+  iree_hsa_kernel_dispatch_packet_t dispatch;
+  iree_hsa_barrier_and_packet_t barrier_and;
+  iree_hsa_barrier_or_packet_t barrier_or;
+  iree_hsa_amd_aql_pm4_ib_packet_t pm4_ib;
+  iree_hsa_amd_barrier_value_packet_t barrier_value;
+  uint8_t raw[64];
+} iree_hal_amdgpu_aql_packet_t;
+static_assert(sizeof(iree_hal_amdgpu_aql_packet_t) == 64,
+              "AQL packet must be exactly one cache line");
+
+//===----------------------------------------------------------------------===//
+// iree_hal_amdgpu_aql_ring_t
+//===----------------------------------------------------------------------===//
+
+// Cached hardware AQL ring buffer state. Initialized once from iree_amd_queue_t
+// and used for all subsequent packet operations. The cached pointers avoid
+// repeated indirection through the queue descriptor on the hot path.
+typedef struct iree_hal_amdgpu_aql_ring_t {
+  // Packet ring buffer base (from hsa_queue.base_address), cast for natural
+  // indexing: ring.base[id & ring.mask] gives the packet slot.
+  iree_hal_amdgpu_aql_packet_t* base;
+
+  // Power-of-two ring mask (hsa_queue.size - 1). Slot = packet_id & mask.
+  uint32_t mask;
+
+  // Cached hardware doorbell MMIO pointer. Resolved at init from the doorbell
+  // signal's iree_amd_signal_t.hardware_doorbell_ptr. Writing a packet ID here
+  // wakes the CP to process new packets. Inlined: no libhsa function pointer
+  // indirection, just an atomic store to MMIO.
+  volatile int64_t* doorbell;
+
+  // Atomic write dispatch ID. Points into the hardware queue descriptor
+  // (iree_amd_queue_t.write_dispatch_id). Multi-producer safe: each thread
+  // atomically increments to reserve a unique range of packet IDs.
+  iree_atomic_int64_t* write_dispatch_id;
+
+  // Read dispatch ID. Points into the hardware queue descriptor
+  // (iree_amd_queue_t.read_dispatch_id). Advanced by the packet processor as
+  // queue slots become reusable. This is a packet-slot lifetime signal only;
+  // it does not prove that sidecar storage associated with a dispatch (e.g.
+  // kernarg blocks) has been consumed by later queue work.
+  const volatile int64_t* read_dispatch_id;
+} iree_hal_amdgpu_aql_ring_t;
+
+// Initializes the AQL ring from a hardware queue descriptor.
+// Resolves the doorbell pointer from the signal's iree_amd_signal_t and
+// caches all hot pointers for zero-indirection access.
+static inline void iree_hal_amdgpu_aql_ring_initialize(
+    iree_amd_queue_t* hardware_queue, iree_hal_amdgpu_aql_ring_t* out_ring) {
+  out_ring->base =
+      (iree_hal_amdgpu_aql_packet_t*)hardware_queue->hsa_queue.base_address;
+  out_ring->mask = hardware_queue->hsa_queue.size - 1;
+
+  // Resolve the doorbell MMIO pointer from the signal handle. The signal is
+  // DOORBELL kind: its hardware_doorbell_ptr points to the memory-mapped
+  // doorbell register. Writing a packet ID there wakes the CP.
+  iree_amd_signal_t* doorbell_signal =
+      (iree_amd_signal_t*)hardware_queue->hsa_queue.doorbell_signal.handle;
+  out_ring->doorbell =
+      (volatile int64_t*)doorbell_signal->hardware_doorbell_ptr;
+
+  out_ring->write_dispatch_id =
+      (iree_atomic_int64_t*)&hardware_queue->write_dispatch_id;
+  out_ring->read_dispatch_id =
+      (const volatile int64_t*)&hardware_queue->read_dispatch_id;
+}
+
+// Reserves |count| contiguous packet slots. Returns the first packet ID.
+// Multi-producer safe (atomic_fetch_add on write_dispatch_id). Spins if the
+// ring is full (GPU hasn't consumed enough packets).
+//
+// The caller populates reserved slots via iree_hal_amdgpu_aql_ring_packet()
+// and commits each via iree_hal_amdgpu_aql_ring_commit(). The doorbell
+// should be rung once after all packets in a batch are committed.
+//
+// IMPORTANT: The returned packet slots have INVALID headers from the CP's
+// perspective. The CP will not process them until a valid header is published.
+// Normal host submissions commit every reserved slot before ringing the
+// doorbell; device-side patching may intentionally leave a later slot invalid
+// only when an earlier packet is guaranteed to publish it before the CP reaches
+// that slot.
+static inline uint64_t iree_hal_amdgpu_aql_ring_reserve(
+    iree_hal_amdgpu_aql_ring_t* ring, uint32_t count) {
+  // Atomically claim |count| slots. Each concurrent thread gets a unique,
+  // non-overlapping range. Relaxed ordering: we're only claiming indices,
+  // not publishing data. Data visibility comes from the header commit.
+  const uint64_t first_id = (uint64_t)iree_atomic_fetch_add(
+      ring->write_dispatch_id, (int64_t)count, iree_memory_order_relaxed);
+
+  // Backpressure: spin until the ring has space for all reserved slots.
+  // The ring can hold (mask + 1) packets. The CP advances read_dispatch_id
+  // as packet slots become reusable.
+  const uint64_t ring_capacity = (uint64_t)(ring->mask + 1);
+  while (IREE_UNLIKELY(first_id + count -
+                           (uint64_t)iree_atomic_load(
+                               (iree_atomic_int64_t*)ring->read_dispatch_id,
+                               iree_memory_order_acquire) >
+                       ring_capacity)) {
+    iree_processor_yield();
+  }
+
+  return first_id;
+}
+
+// Attempts to reserve |count| contiguous packet slots without waiting.
+//
+// Returns false if the ring does not currently have room for the entire range.
+// Unlike reserve(), this never advances write_dispatch_id unless the range is
+// immediately available, so callers can park and retry after completion drain.
+static inline bool iree_hal_amdgpu_aql_ring_try_reserve(
+    iree_hal_amdgpu_aql_ring_t* ring, uint32_t count, uint64_t* out_first_id) {
+  const uint64_t ring_capacity = (uint64_t)(ring->mask + 1);
+  int64_t current_write =
+      iree_atomic_load(ring->write_dispatch_id, iree_memory_order_relaxed);
+  for (;;) {
+    const uint64_t first_id = (uint64_t)current_write;
+    const uint64_t current_read =
+        (uint64_t)iree_atomic_load((iree_atomic_int64_t*)ring->read_dispatch_id,
+                                   iree_memory_order_acquire);
+    if (first_id + count - current_read > ring_capacity) {
+      *out_first_id = 0;
+      return false;
+    }
+    int64_t desired_write = current_write + (int64_t)count;
+    if (iree_atomic_compare_exchange_weak(
+            ring->write_dispatch_id, &current_write, desired_write,
+            iree_memory_order_acq_rel, iree_memory_order_relaxed)) {
+      *out_first_id = first_id;
+      return true;
+    }
+  }
+}
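+
+// Sketch of a park-and-retry caller; park_submission_for_retry() is a
+// hypothetical helper, not part of this API:
+//   uint64_t first_id = 0;
+//   if (!iree_hal_amdgpu_aql_ring_try_reserve(ring, count, &first_id)) {
+//     return park_submission_for_retry(...);  // retry after completion drain
+//   }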
+
+// Returns a pointer to the packet slot for |packet_id|. The caller
+// populates the packet fields (all except the header) and then commits
+// the header via iree_hal_amdgpu_aql_ring_commit().
+static inline iree_hal_amdgpu_aql_packet_t* iree_hal_amdgpu_aql_ring_packet(
+    iree_hal_amdgpu_aql_ring_t* ring, uint64_t packet_id) {
+  return &ring->base[packet_id & ring->mask];
+}
+
+// Commits a packet by writing its header + setup as a single atomic
+// store-release. This makes the packet visible to the CP: the release
+// semantics ensure all prior writes to the packet's fields (and to
+// kernarg memory) are ordered before the CP can read them.
+//
+// |header| is the 16-bit AQL packet header (type, barrier, fence scopes).
+// |setup| is the upper 16 bits: grid dimensions for dispatch packets, zero for
+// standard barriers, or vendor-specific format bits for extension packets.
+//
+// Use iree_hsa_make_packet_header() from abi/queue.h to construct |header|.
+static inline void iree_hal_amdgpu_aql_ring_commit(
+    iree_hal_amdgpu_aql_packet_t* packet, uint16_t header, uint16_t setup) {
+  const uint32_t header_setup = (uint32_t)header | ((uint32_t)setup << 16);
+  // The first 32 bits of every AQL packet are the header (16 bits) +
+  // setup/reserved (16 bits). An atomic store-release on this word is the
+  // publication barrier for the entire packet.
+  iree_atomic_store((iree_atomic_int32_t*)packet, (int32_t)header_setup,
+                    iree_memory_order_release);
+}
+
+// Rings the hardware doorbell to wake the CP. The |packet_id| should be the
+// highest committed packet ID (the CP processes from its current position up
+// to the doorbell value). Call once after committing all packets in a batch
+// to amortize the PCIe BAR write cost (~100-300ns).
+//
+// The doorbell is purely a wakeup hint — the CP will process any packet
+// whose header is valid, regardless of whether the doorbell has been rung.
+// Multiple concurrent doorbell writes are harmless (the CP re-scans).
+//
+// Release semantics ensure all committed packet headers are visible before
+// the CP wakes and starts scanning.
+static inline void iree_hal_amdgpu_aql_ring_doorbell(
+    iree_hal_amdgpu_aql_ring_t* ring, uint64_t packet_id) {
+  iree_atomic_store((iree_atomic_int64_t*)ring->doorbell, (int64_t)packet_id,
+                    iree_memory_order_release);
+}
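+
+// Sketch of a batched submission amortizing the doorbell write (header/setup
+// emission elided; see util/aql_emitter.h):
+//   const uint64_t first_id = iree_hal_amdgpu_aql_ring_reserve(ring, count);
+//   for (uint32_t i = 0; i < count; ++i) {
+//     iree_hal_amdgpu_aql_packet_t* packet =
+//         iree_hal_amdgpu_aql_ring_packet(ring, first_id + i);
+//     // ... populate packet body + kernarg memory, emit header/setup ...
+//     iree_hal_amdgpu_aql_ring_commit(packet, header, setup);
+//   }
+//   iree_hal_amdgpu_aql_ring_doorbell(ring, first_id + count - 1);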
+
+#ifdef __cplusplus
+}  // extern "C"
+#endif  // __cplusplus
+
+#endif  // IREE_HAL_DRIVERS_AMDGPU_UTIL_AQL_RING_H_
diff --git a/runtime/src/iree/hal/drivers/amdgpu/util/benchmark_flags.cc b/runtime/src/iree/hal/drivers/amdgpu/util/benchmark_flags.cc
new file mode 100644
index 0000000..c5e1169
--- /dev/null
+++ b/runtime/src/iree/hal/drivers/amdgpu/util/benchmark_flags.cc
@@ -0,0 +1,70 @@
+// Copyright 2026 The IREE Authors
+//
+// Licensed under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+
+#include "iree/hal/drivers/amdgpu/util/benchmark_flags.h"
+
+#include <cstdio>
+
+#include "iree/base/api.h"
+#include "iree/base/tooling/flags.h"
+
+static iree_status_t parse_completion_wait_flags(iree_string_view_t flag_name,
+                                                 void* storage,
+                                                 iree_string_view_t value) {
+  (void)flag_name;
+  iree_async_wait_flags_t* wait_flags = (iree_async_wait_flags_t*)storage;
+  if (iree_string_view_equal(value, IREE_SV("none"))) {
+    *wait_flags = IREE_ASYNC_WAIT_FLAG_NONE;
+    return iree_ok_status();
+  } else if (iree_string_view_equal(value, IREE_SV("yield"))) {
+    *wait_flags = IREE_ASYNC_WAIT_FLAG_YIELD;
+    return iree_ok_status();
+  } else if (iree_string_view_equal(value, IREE_SV("active"))) {
+    *wait_flags = IREE_ASYNC_WAIT_FLAG_ACTIVE;
+    return iree_ok_status();
+  }
+  return iree_make_status(IREE_STATUS_INVALID_ARGUMENT,
+                          "unsupported completion wait flags");
+}
+
+static void print_completion_wait_flags(iree_string_view_t flag_name,
+                                        void* storage, FILE* file) {
+  const iree_async_wait_flags_t wait_flags =
+      *(const iree_async_wait_flags_t*)storage;
+  const char* wait_flags_string = "none";
+  if (iree_any_bit_set(wait_flags, IREE_ASYNC_WAIT_FLAG_ACTIVE)) {
+    wait_flags_string = "active";
+  } else if (iree_any_bit_set(wait_flags, IREE_ASYNC_WAIT_FLAG_YIELD)) {
+    wait_flags_string = "yield";
+  }
+  fprintf(file, "--%.*s=\"%s\"\n", (int)flag_name.size, flag_name.data,
+          wait_flags_string);
+}
+
+static iree_async_wait_flags_t FLAG_completion_wait_flags =
+    IREE_ASYNC_WAIT_FLAG_NONE;
+IREE_FLAG_CALLBACK(
+    parse_completion_wait_flags, print_completion_wait_flags,
+    &FLAG_completion_wait_flags, completion_wait_flags,
+    "Wait strategy used by benchmark completion waits. One of 'none', "
+    "'yield', or 'active'. Active wait spins on the calling thread and should "
+    "only be used for latency-sensitive short waits.");
+
+iree_async_wait_flags_t iree_hal_amdgpu_benchmark_completion_wait_flags(void) {
+  return FLAG_completion_wait_flags;
+}
+
+void iree_hal_amdgpu_benchmark_set_completion_wait_counters(
+    benchmark::State& state) {
+  state.counters["completion_wait_active"] =
+      iree_any_bit_set(FLAG_completion_wait_flags, IREE_ASYNC_WAIT_FLAG_ACTIVE)
+          ? 1.0
+          : 0.0;
+  state.counters["completion_wait_yield"] =
+      iree_any_bit_set(FLAG_completion_wait_flags, IREE_ASYNC_WAIT_FLAG_YIELD)
+          ? 1.0
+          : 0.0;
+}
diff --git a/runtime/src/iree/hal/drivers/amdgpu/util/benchmark_flags.h b/runtime/src/iree/hal/drivers/amdgpu/util/benchmark_flags.h
new file mode 100644
index 0000000..a6f02ff
--- /dev/null
+++ b/runtime/src/iree/hal/drivers/amdgpu/util/benchmark_flags.h
@@ -0,0 +1,22 @@
+// Copyright 2026 The IREE Authors
+//
+// Licensed under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+
+#ifndef IREE_HAL_DRIVERS_AMDGPU_UTIL_BENCHMARK_FLAGS_H_
+#define IREE_HAL_DRIVERS_AMDGPU_UTIL_BENCHMARK_FLAGS_H_
+
+#include <benchmark/benchmark.h>
+
+#include "iree/async/operation.h"
+
+// Returns the async wait strategy selected by --completion_wait_flags.
+iree_async_wait_flags_t iree_hal_amdgpu_benchmark_completion_wait_flags(void);
+
+// Adds completion wait mode counters to |state| so benchmark rows remain
+// self-describing when none/yield/active wait results are compared.
+void iree_hal_amdgpu_benchmark_set_completion_wait_counters(
+    benchmark::State& state);
+
+#endif  // IREE_HAL_DRIVERS_AMDGPU_UTIL_BENCHMARK_FLAGS_H_
diff --git a/runtime/src/iree/hal/drivers/amdgpu/util/blit_benchmark.cc b/runtime/src/iree/hal/drivers/amdgpu/util/blit_benchmark.cc
new file mode 100644
index 0000000..930936c
--- /dev/null
+++ b/runtime/src/iree/hal/drivers/amdgpu/util/blit_benchmark.cc
@@ -0,0 +1,707 @@
+// Copyright 2026 The IREE Authors
+//
+// Licensed under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+
+// TODO(benvanik): move this to the CTS (add a benchmark path ala iree/async/).
+
+#include <benchmark/benchmark.h>
+
+#include <cstdint>
+#include <vector>
+
+#include "iree/async/frontier_tracker.h"
+#include "iree/async/util/proactor_pool.h"
+#include "iree/base/api.h"
+#include "iree/base/threading/numa.h"
+#include "iree/base/tooling/flags.h"
+#include "iree/hal/api.h"
+#include "iree/hal/drivers/amdgpu/registration/driver_module.h"
+#include "iree/hal/drivers/amdgpu/util/benchmark_flags.h"
+
+namespace {
+
+constexpr iree_device_size_t kBenchmarkBufferAlignment = 16;
+constexpr int64_t kBatchCount = 20;
+constexpr int64_t kSubmitOnlyIterations = 200;
+constexpr uint32_t kFrontierAxisTableCapacity = 256;
+
+class BlitBenchmark : public benchmark::Fixture {
+ public:
+  static void InitializeOnce() {
+    if (initialized_) return;
+    initialized_ = true;
+    host_allocator_ = iree_allocator_system();
+
+    iree_status_t status = iree_hal_amdgpu_driver_module_register(
+        iree_hal_driver_registry_default());
+    if (iree_status_is_already_exists(status)) {
+      iree_status_free(status);
+      status = iree_ok_status();
+    }
+
+    if (iree_status_is_ok(status)) {
+      status = iree_hal_driver_registry_try_create(
+          iree_hal_driver_registry_default(), iree_make_cstring_view("amdgpu"),
+          host_allocator_, &driver_);
+    }
+
+    iree_async_proactor_pool_t* proactor_pool = nullptr;
+    if (iree_status_is_ok(status)) {
+      status = iree_async_proactor_pool_create(
+          iree_numa_node_count(), /*node_ids=*/nullptr,
+          iree_async_proactor_pool_options_default(), host_allocator_,
+          &proactor_pool);
+    }
+
+    if (iree_status_is_ok(status)) {
+      iree_hal_device_create_params_t create_params =
+          iree_hal_device_create_params_default();
+      create_params.proactor_pool = proactor_pool;
+      status = iree_hal_driver_create_default_device(driver_, &create_params,
+                                                     host_allocator_, &device_);
+    }
+    iree_async_proactor_pool_release(proactor_pool);
+
+    iree_async_frontier_tracker_t* frontier_tracker = nullptr;
+    if (iree_status_is_ok(status)) {
+      iree_async_frontier_tracker_options_t options =
+          iree_async_frontier_tracker_options_default();
+      options.axis_table_capacity = kFrontierAxisTableCapacity;
+      status = iree_async_frontier_tracker_create(options, host_allocator_,
+                                                  &frontier_tracker);
+    }
+    if (iree_status_is_ok(status)) {
+      status = iree_hal_device_group_create_from_device(
+          device_, frontier_tracker, host_allocator_, &device_group_);
+    }
+    iree_async_frontier_tracker_release(frontier_tracker);
+
+    if (iree_status_is_ok(status)) {
+      available_ = true;
+      return;
+    }
+
+    iree_status_fprint(stderr, status);
+    iree_status_free(status);
+    iree_hal_device_release(device_);
+    iree_hal_driver_release(driver_);
+    device_ = nullptr;
+    driver_ = nullptr;
+  }
+
+  static void DeinitializeOnce() {
+    if (!initialized_) return;
+    iree_hal_device_release(device_);
+    iree_hal_device_group_release(device_group_);
+    iree_hal_driver_release(driver_);
+    device_ = nullptr;
+    device_group_ = nullptr;
+    driver_ = nullptr;
+    available_ = false;
+  }
+
+  void SetUp(benchmark::State& state) override {
+    InitializeOnce();
+    if (!available_) {
+      state.SkipWithError("AMDGPU HAL device not available");
+    }
+  }
+
+  void TearDown(benchmark::State& state) override { ReleaseBuffers(); }
+
+ protected:
+  bool PrepareCopy(benchmark::State& state) {
+    if (!available_) return false;
+    length_ = static_cast<iree_device_size_t>(state.range(0));
+    source_offset_ = static_cast<iree_device_size_t>(state.range(1));
+    target_offset_ = static_cast<iree_device_size_t>(state.range(2));
+    batch_count_ = 1;
+
+    iree_device_size_t required_size = source_offset_ + length_;
+    if (target_offset_ + length_ > required_size) {
+      required_size = target_offset_ + length_;
+    }
+    allocation_size_ = iree_device_align(
+        required_size + kBenchmarkBufferAlignment, kBenchmarkBufferAlignment);
+    return AllocateBenchmarkBuffers(state, /*needs_source=*/true);
+  }
+
+  bool PrepareCopyBatch(benchmark::State& state, int64_t batch_count) {
+    if (!PrepareCopy(state)) return false;
+    batch_count_ = batch_count;
+    return true;
+  }
+
+  bool PrepareFill(benchmark::State& state) {
+    if (!available_) return false;
+    length_ = static_cast<iree_device_size_t>(state.range(0));
+    target_offset_ = static_cast<iree_device_size_t>(state.range(1));
+    pattern_length_ = static_cast<iree_host_size_t>(state.range(2));
+    batch_count_ = 1;
+    if ((target_offset_ % pattern_length_) != 0 ||
+        (length_ % pattern_length_) != 0) {
+      state.SkipWithError("fill offset/length are not pattern-aligned");
+      return false;
+    }
+
+    allocation_size_ =
+        iree_device_align(target_offset_ + length_ + kBenchmarkBufferAlignment,
+                          kBenchmarkBufferAlignment);
+    return AllocateBenchmarkBuffers(state, /*needs_source=*/false);
+  }
+
+  bool PrepareFillBatch(benchmark::State& state, int64_t batch_count) {
+    if (!PrepareFill(state)) return false;
+    batch_count_ = batch_count;
+    return true;
+  }
+
+  bool PrepareUpdate(benchmark::State& state) {
+    if (!available_) return false;
+    length_ = static_cast<iree_device_size_t>(state.range(0));
+    source_offset_ = static_cast<iree_device_size_t>(state.range(1));
+    target_offset_ = static_cast<iree_device_size_t>(state.range(2));
+    batch_count_ = 1;
+
+    allocation_size_ =
+        iree_device_align(target_offset_ + length_ + kBenchmarkBufferAlignment,
+                          kBenchmarkBufferAlignment);
+    update_source_.resize(source_offset_ + length_ + kBenchmarkBufferAlignment);
+    for (size_t i = 0; i < update_source_.size(); ++i) {
+      update_source_[i] = static_cast<uint8_t>(0xA0u + (i & 0x3Fu));
+    }
+    return AllocateBenchmarkBuffers(state, /*needs_source=*/false);
+  }
+
+  bool PrepareUpdateBatch(benchmark::State& state, int64_t batch_count) {
+    if (!PrepareUpdate(state)) return false;
+    batch_count_ = batch_count;
+    return true;
+  }
+
+  iree_status_t QueueCopyAndWait() {
+    iree_hal_semaphore_t* semaphore = completion_semaphore_;
+    uint64_t payload_value = ++completion_payload_value_;
+    iree_hal_semaphore_list_t signal_semaphore_list = {
+        /*count=*/1,
+        /*semaphores=*/&semaphore,
+        /*payload_values=*/&payload_value,
+    };
+    IREE_RETURN_IF_ERROR(iree_hal_device_queue_copy(
+        device_, IREE_HAL_QUEUE_AFFINITY_ANY, iree_hal_semaphore_list_empty(),
+        signal_semaphore_list, source_buffer_, source_offset_, target_buffer_,
+        target_offset_, length_, IREE_HAL_COPY_FLAG_NONE));
+    return WaitForCompletion(payload_value);
+  }
+
+  iree_status_t QueueCopyBatchAndWait() {
+    uint64_t payload_value = 0;
+    IREE_RETURN_IF_ERROR(QueueCopyBatchSubmit(&payload_value));
+    return WaitForCompletion(payload_value);
+  }
+
+  iree_status_t QueueCopyBatchSubmit(uint64_t* out_payload_value) {
+    iree_hal_semaphore_t* semaphore = completion_semaphore_;
+    uint64_t payload_value = completion_payload_value_;
+    for (int64_t i = 0; i < batch_count_; ++i) {
+      uint64_t wait_payload_value = payload_value;
+      uint64_t signal_payload_value = payload_value + 1;
+      iree_hal_semaphore_list_t wait_semaphore_list =
+          iree_hal_semaphore_list_empty();
+      if (i > 0) {
+        wait_semaphore_list = {
+            /*count=*/1,
+            /*semaphores=*/&semaphore,
+            /*payload_values=*/&wait_payload_value,
+        };
+      }
+      iree_hal_semaphore_list_t signal_semaphore_list = {
+          /*count=*/1,
+          /*semaphores=*/&semaphore,
+          /*payload_values=*/&signal_payload_value,
+      };
+      IREE_RETURN_IF_ERROR(iree_hal_device_queue_copy(
+          device_, IREE_HAL_QUEUE_AFFINITY_ANY, wait_semaphore_list,
+          signal_semaphore_list, source_buffer_, source_offset_, target_buffer_,
+          target_offset_, length_, IREE_HAL_COPY_FLAG_NONE));
+      payload_value = signal_payload_value;
+    }
+    completion_payload_value_ = payload_value;
+    *out_payload_value = payload_value;
+    return iree_ok_status();
+  }
+
+  iree_status_t QueueFillAndWait(iree_hal_buffer_t* target_buffer,
+                                 iree_device_size_t target_offset,
+                                 iree_device_size_t length, const void* pattern,
+                                 iree_host_size_t pattern_length) {
+    iree_hal_semaphore_t* semaphore = completion_semaphore_;
+    uint64_t payload_value = ++completion_payload_value_;
+    iree_hal_semaphore_list_t signal_semaphore_list = {
+        /*count=*/1,
+        /*semaphores=*/&semaphore,
+        /*payload_values=*/&payload_value,
+    };
+    IREE_RETURN_IF_ERROR(iree_hal_device_queue_fill(
+        device_, IREE_HAL_QUEUE_AFFINITY_ANY, iree_hal_semaphore_list_empty(),
+        signal_semaphore_list, target_buffer, target_offset, length, pattern,
+        pattern_length, IREE_HAL_FILL_FLAG_NONE));
+    return WaitForCompletion(payload_value);
+  }
+
+  iree_status_t QueueBenchmarkFillAndWait() {
+    return QueueFillAndWait(target_buffer_, target_offset_, length_,
+                            &fill_pattern_, pattern_length_);
+  }
+
+  iree_status_t QueueFillBatchAndWait() {
+    uint64_t payload_value = 0;
+    IREE_RETURN_IF_ERROR(QueueFillBatchSubmit(&payload_value));
+    return WaitForCompletion(payload_value);
+  }
+
+  iree_status_t QueueFillBatchSubmit(uint64_t* out_payload_value) {
+    iree_hal_semaphore_t* semaphore = completion_semaphore_;
+    uint64_t payload_value = completion_payload_value_;
+    for (int64_t i = 0; i < batch_count_; ++i) {
+      uint64_t wait_payload_value = payload_value;
+      uint64_t signal_payload_value = payload_value + 1;
+      iree_hal_semaphore_list_t wait_semaphore_list =
+          iree_hal_semaphore_list_empty();
+      if (i > 0) {
+        wait_semaphore_list = {
+            /*count=*/1,
+            /*semaphores=*/&semaphore,
+            /*payload_values=*/&wait_payload_value,
+        };
+      }
+      iree_hal_semaphore_list_t signal_semaphore_list = {
+          /*count=*/1,
+          /*semaphores=*/&semaphore,
+          /*payload_values=*/&signal_payload_value,
+      };
+      IREE_RETURN_IF_ERROR(iree_hal_device_queue_fill(
+          device_, IREE_HAL_QUEUE_AFFINITY_ANY, wait_semaphore_list,
+          signal_semaphore_list, target_buffer_, target_offset_, length_,
+          &fill_pattern_, pattern_length_, IREE_HAL_FILL_FLAG_NONE));
+      payload_value = signal_payload_value;
+    }
+    completion_payload_value_ = payload_value;
+    *out_payload_value = payload_value;
+    return iree_ok_status();
+  }
+
+  iree_status_t QueueUpdateAndWait() {
+    iree_hal_semaphore_t* semaphore = completion_semaphore_;
+    uint64_t payload_value = ++completion_payload_value_;
+    iree_hal_semaphore_list_t signal_semaphore_list = {
+        /*count=*/1,
+        /*semaphores=*/&semaphore,
+        /*payload_values=*/&payload_value,
+    };
+    IREE_RETURN_IF_ERROR(iree_hal_device_queue_update(
+        device_, IREE_HAL_QUEUE_AFFINITY_ANY, iree_hal_semaphore_list_empty(),
+        signal_semaphore_list, update_source_.data(),
+        (iree_host_size_t)source_offset_, target_buffer_, target_offset_,
+        length_, IREE_HAL_UPDATE_FLAG_NONE));
+    return WaitForCompletion(payload_value);
+  }
+
+  iree_status_t QueueUpdateBatchAndWait() {
+    uint64_t payload_value = 0;
+    IREE_RETURN_IF_ERROR(QueueUpdateBatchSubmit(&payload_value));
+    return WaitForCompletion(payload_value);
+  }
+
+  iree_status_t QueueUpdateBatchSubmit(uint64_t* out_payload_value) {
+    iree_hal_semaphore_t* semaphore = completion_semaphore_;
+    uint64_t payload_value = completion_payload_value_;
+    for (int64_t i = 0; i < batch_count_; ++i) {
+      uint64_t wait_payload_value = payload_value;
+      uint64_t signal_payload_value = payload_value + 1;
+      iree_hal_semaphore_list_t wait_semaphore_list =
+          iree_hal_semaphore_list_empty();
+      if (i > 0) {
+        wait_semaphore_list = {
+            /*count=*/1,
+            /*semaphores=*/&semaphore,
+            /*payload_values=*/&wait_payload_value,
+        };
+      }
+      iree_hal_semaphore_list_t signal_semaphore_list = {
+          /*count=*/1,
+          /*semaphores=*/&semaphore,
+          /*payload_values=*/&signal_payload_value,
+      };
+      IREE_RETURN_IF_ERROR(iree_hal_device_queue_update(
+          device_, IREE_HAL_QUEUE_AFFINITY_ANY, wait_semaphore_list,
+          signal_semaphore_list, update_source_.data(),
+          (iree_host_size_t)source_offset_, target_buffer_, target_offset_,
+          length_, IREE_HAL_UPDATE_FLAG_NONE));
+      payload_value = signal_payload_value;
+    }
+    completion_payload_value_ = payload_value;
+    *out_payload_value = payload_value;
+    return iree_ok_status();
+  }
+
+  iree_status_t WaitForCompletion(uint64_t payload_value) {
+    return iree_hal_semaphore_wait(
+        completion_semaphore_, payload_value, iree_infinite_timeout(),
+        iree_hal_amdgpu_benchmark_completion_wait_flags());
+  }
+
+  bool HandleStatus(benchmark::State& state, iree_status_t status,
+                    const char* message) {
+    if (iree_status_is_ok(status)) return true;
+    iree_status_fprint(stderr, status);
+    iree_status_free(status);
+    state.SkipWithError(message);
+    return false;
+  }
+
+  void SetBytesProcessed(benchmark::State& state) {
+    iree_hal_amdgpu_benchmark_set_completion_wait_counters(state);
+    state.SetBytesProcessed(static_cast<int64_t>(state.iterations()) *
+                            batch_count_ * static_cast<int64_t>(length_));
+  }
+
+ private:
+  bool AllocateBenchmarkBuffers(benchmark::State& state, bool needs_source) {
+    iree_hal_allocator_t* allocator = iree_hal_device_allocator(device_);
+    iree_hal_buffer_params_t params = {0};
+    params.usage = IREE_HAL_BUFFER_USAGE_TRANSFER;
+    params.type = IREE_HAL_MEMORY_TYPE_OPTIMAL_FOR_DEVICE;
+    params.min_alignment = kBenchmarkBufferAlignment;
+
+    if (!HandleStatus(state,
+                      iree_hal_semaphore_create(
+                          device_, IREE_HAL_QUEUE_AFFINITY_ANY,
+                          /*initial_value=*/0, IREE_HAL_SEMAPHORE_FLAG_DEFAULT,
+                          &completion_semaphore_),
+                      "failed to create completion semaphore")) {
+      return false;
+    }
+
+    if (needs_source) {
+      if (!HandleStatus(
+              state,
+              iree_hal_allocator_allocate_buffer(
+                  allocator, params, allocation_size_, &source_buffer_),
+              "failed to allocate source buffer")) {
+        return false;
+      }
+    }
+    if (!HandleStatus(state,
+                      iree_hal_allocator_allocate_buffer(
+                          allocator, params, allocation_size_, &target_buffer_),
+                      "failed to allocate target buffer")) {
+      return false;
+    }
+
+    uint8_t source_pattern = 0x5A;
+    if (needs_source &&
+        !HandleStatus(state,
+                      QueueFillAndWait(source_buffer_, /*target_offset=*/0,
+                                       allocation_size_, &source_pattern,
+                                       sizeof(source_pattern)),
+                      "failed to pre-initialize source buffer")) {
+      return false;
+    }
+    uint8_t target_pattern = 0x00;
+    return HandleStatus(
+        state,
+        QueueFillAndWait(target_buffer_, /*target_offset=*/0, allocation_size_,
+                         &target_pattern, sizeof(target_pattern)),
+        "failed to pre-initialize target buffer");
+  }
+
+  void ReleaseBuffers() {
+    iree_hal_buffer_release(source_buffer_);
+    iree_hal_buffer_release(target_buffer_);
+    iree_hal_semaphore_release(completion_semaphore_);
+    source_buffer_ = nullptr;
+    target_buffer_ = nullptr;
+    completion_semaphore_ = nullptr;
+    completion_payload_value_ = 0;
+    update_source_.clear();
+  }
+
+  static bool initialized_;
+  static bool available_;
+  static iree_allocator_t host_allocator_;
+  static iree_hal_driver_t* driver_;
+  static iree_hal_device_group_t* device_group_;
+  static iree_hal_device_t* device_;
+
+  iree_hal_buffer_t* source_buffer_ = nullptr;
+  iree_hal_buffer_t* target_buffer_ = nullptr;
+  iree_hal_semaphore_t* completion_semaphore_ = nullptr;
+  uint64_t completion_payload_value_ = 0;
+  iree_device_size_t allocation_size_ = 0;
+  iree_device_size_t length_ = 0;
+  iree_device_size_t source_offset_ = 0;
+  iree_device_size_t target_offset_ = 0;
+  iree_host_size_t pattern_length_ = 1;
+  int64_t batch_count_ = 1;
+  uint32_t fill_pattern_ = 0xDEADBEEFu;
+  std::vector<uint8_t> update_source_;
+};
+
+bool BlitBenchmark::initialized_ = false;
+bool BlitBenchmark::available_ = false;
+iree_allocator_t BlitBenchmark::host_allocator_;
+iree_hal_driver_t* BlitBenchmark::driver_ = nullptr;
+iree_hal_device_group_t* BlitBenchmark::device_group_ = nullptr;
+iree_hal_device_t* BlitBenchmark::device_ = nullptr;
+
+BENCHMARK_DEFINE_F(BlitBenchmark, QueueCopy)(benchmark::State& state) {
+  if (!PrepareCopy(state)) return;
+  for (auto _ : state) {
+    if (!HandleStatus(state, QueueCopyAndWait(), "queue_copy failed")) break;
+  }
+  SetBytesProcessed(state);
+}
+
+BENCHMARK_DEFINE_F(BlitBenchmark, QueueFill)(benchmark::State& state) {
+  if (!PrepareFill(state)) return;
+  for (auto _ : state) {
+    if (!HandleStatus(state, QueueBenchmarkFillAndWait(),
+                      "queue_fill failed")) {
+      break;
+    }
+  }
+  SetBytesProcessed(state);
+}
+
+BENCHMARK_DEFINE_F(BlitBenchmark, QueueCopyBatch20)(benchmark::State& state) {
+  if (!PrepareCopyBatch(state, kBatchCount)) return;
+  for (auto _ : state) {
+    if (!HandleStatus(state, QueueCopyBatchAndWait(),
+                      "queue_copy batch failed")) {
+      break;
+    }
+  }
+  SetBytesProcessed(state);
+}
+
+BENCHMARK_DEFINE_F(BlitBenchmark,
+                   QueueCopyBatch20SubmitOnly)(benchmark::State& state) {
+  if (!PrepareCopyBatch(state, kBatchCount)) return;
+  for (auto _ : state) {
+    uint64_t payload_value = 0;
+    if (!HandleStatus(state, QueueCopyBatchSubmit(&payload_value),
+                      "queue_copy batch submit failed")) {
+      break;
+    }
+    state.PauseTiming();
+    iree_status_t status = WaitForCompletion(payload_value);
+    state.ResumeTiming();
+    if (!HandleStatus(state, status, "queue_copy batch wait failed")) break;
+  }
+  SetBytesProcessed(state);
+}
+
+BENCHMARK_DEFINE_F(BlitBenchmark, QueueFillBatch20)(benchmark::State& state) {
+  if (!PrepareFillBatch(state, kBatchCount)) return;
+  for (auto _ : state) {
+    if (!HandleStatus(state, QueueFillBatchAndWait(),
+                      "queue_fill batch failed")) {
+      break;
+    }
+  }
+  SetBytesProcessed(state);
+}
+
+BENCHMARK_DEFINE_F(BlitBenchmark,
+                   QueueFillBatch20SubmitOnly)(benchmark::State& state) {
+  if (!PrepareFillBatch(state, kBatchCount)) return;
+  for (auto _ : state) {
+    uint64_t payload_value = 0;
+    if (!HandleStatus(state, QueueFillBatchSubmit(&payload_value),
+                      "queue_fill batch submit failed")) {
+      break;
+    }
+    state.PauseTiming();
+    iree_status_t status = WaitForCompletion(payload_value);
+    state.ResumeTiming();
+    if (!HandleStatus(state, status, "queue_fill batch wait failed")) break;
+  }
+  SetBytesProcessed(state);
+}
+
+BENCHMARK_DEFINE_F(BlitBenchmark, QueueUpdate)(benchmark::State& state) {
+  if (!PrepareUpdate(state)) return;
+  for (auto _ : state) {
+    if (!HandleStatus(state, QueueUpdateAndWait(), "queue_update failed")) {
+      break;
+    }
+  }
+  SetBytesProcessed(state);
+}
+
+BENCHMARK_DEFINE_F(BlitBenchmark, QueueUpdateBatch20)(benchmark::State& state) {
+  if (!PrepareUpdateBatch(state, kBatchCount)) return;
+  for (auto _ : state) {
+    if (!HandleStatus(state, QueueUpdateBatchAndWait(),
+                      "queue_update batch failed")) {
+      break;
+    }
+  }
+  SetBytesProcessed(state);
+}
+
+BENCHMARK_DEFINE_F(BlitBenchmark,
+                   QueueUpdateBatch20SubmitOnly)(benchmark::State& state) {
+  if (!PrepareUpdateBatch(state, kBatchCount)) return;
+  for (auto _ : state) {
+    uint64_t payload_value = 0;
+    if (!HandleStatus(state, QueueUpdateBatchSubmit(&payload_value),
+                      "queue_update batch submit failed")) {
+      break;
+    }
+    state.PauseTiming();
+    iree_status_t status = WaitForCompletion(payload_value);
+    state.ResumeTiming();
+    if (!HandleStatus(state, status, "queue_update batch wait failed")) break;
+  }
+  SetBytesProcessed(state);
+}
+
+void ApplyCopyArguments(benchmark::Benchmark* benchmark) {
+  benchmark->ArgNames({"length", "source_offset", "target_offset"});
+  const int64_t common_sizes[] = {
+      4,   8,   16,   31,       32,        33,        64,
+      128, 256, 1024, 4 * 1024, 16 * 1024, 64 * 1024, 2 * 1024 * 1024,
+  };
+  const int64_t alignment_cases[][2] = {
+      {0, 0},
+      {8, 8},
+      {4, 4},
+      {1, 2},
+  };
+  for (int64_t length : common_sizes) {
+    for (const auto& alignment_case : alignment_cases) {
+      benchmark->Args({length, alignment_case[0], alignment_case[1]});
+    }
+  }
+  benchmark->Args({500ll * 1024 * 1024, 0, 0});
+  benchmark->Args({1024ll * 1024 * 1024, 0, 0});
+}
+
+void ApplyFillArguments(benchmark::Benchmark* benchmark) {
+  benchmark->ArgNames({"length", "target_offset", "pattern_length"});
+  const int64_t common_sizes[] = {
+      4,  8,  16,  28,  30,   31,       32,        33,        34,
+      36, 64, 128, 256, 1024, 4 * 1024, 16 * 1024, 64 * 1024, 2 * 1024 * 1024,
+  };
+  const int64_t pattern_lengths[] = {1, 2, 4};
+  for (int64_t pattern_length : pattern_lengths) {
+    for (int64_t length : common_sizes) {
+      if ((length % pattern_length) != 0) continue;
+      const int64_t offsets[] = {0, pattern_length};
+      for (int64_t offset : offsets) {
+        benchmark->Args({length, offset, pattern_length});
+      }
+    }
+  }
+  benchmark->Args({500ll * 1024 * 1024, 0, 4});
+  benchmark->Args({1024ll * 1024 * 1024, 0, 4});
+}
+
+void ApplyCopySubmitOnlyArguments(benchmark::Benchmark* benchmark) {
+  benchmark->ArgNames({"length", "source_offset", "target_offset"});
+  benchmark->Args({4, 0, 0});
+  benchmark->Args({8, 0, 0});
+  benchmark->Args({64, 0, 0});
+  benchmark->Args({4096, 0, 0});
+}
+
+void ApplyFillSubmitOnlyArguments(benchmark::Benchmark* benchmark) {
+  benchmark->ArgNames({"length", "target_offset", "pattern_length"});
+  benchmark->Args({4, 0, 4});
+  benchmark->Args({64, 0, 4});
+  benchmark->Args({4096, 0, 4});
+}
+
+void ApplyUpdateArguments(benchmark::Benchmark* benchmark) {
+  benchmark->ArgNames({"length", "source_offset", "target_offset"});
+  const int64_t common_sizes[] = {
+      4, 8, 16, 31, 32, 33, 64, 128, 256, 1024, 4 * 1024, 16 * 1024, 64 * 1024,
+  };
+  const int64_t alignment_cases[][2] = {
+      {0, 0},
+      {3, 0},
+      {5, 4},
+      {1, 1},
+  };
+  for (int64_t length : common_sizes) {
+    for (const auto& alignment_case : alignment_cases) {
+      benchmark->Args({length, alignment_case[0], alignment_case[1]});
+    }
+  }
+}
+
+void ApplyUpdateSubmitOnlyArguments(benchmark::Benchmark* benchmark) {
+  benchmark->ArgNames({"length", "source_offset", "target_offset"});
+  benchmark->Args({4, 0, 0});
+  benchmark->Args({8, 3, 0});
+  benchmark->Args({64, 0, 0});
+  benchmark->Args({4096, 0, 0});
+}
+
+BENCHMARK_REGISTER_F(BlitBenchmark, QueueCopy)
+    ->Apply(ApplyCopyArguments)
+    ->UseRealTime()
+    ->Unit(benchmark::kMicrosecond);
+BENCHMARK_REGISTER_F(BlitBenchmark, QueueFill)
+    ->Apply(ApplyFillArguments)
+    ->UseRealTime()
+    ->Unit(benchmark::kMicrosecond);
+BENCHMARK_REGISTER_F(BlitBenchmark, QueueCopyBatch20)
+    ->Apply(ApplyCopyArguments)
+    ->UseRealTime()
+    ->Unit(benchmark::kMicrosecond);
+BENCHMARK_REGISTER_F(BlitBenchmark, QueueCopyBatch20SubmitOnly)
+    ->Apply(ApplyCopySubmitOnlyArguments)
+    ->Iterations(kSubmitOnlyIterations)
+    ->UseRealTime()
+    ->Unit(benchmark::kMicrosecond);
+BENCHMARK_REGISTER_F(BlitBenchmark, QueueFillBatch20)
+    ->Apply(ApplyFillArguments)
+    ->UseRealTime()
+    ->Unit(benchmark::kMicrosecond);
+BENCHMARK_REGISTER_F(BlitBenchmark, QueueFillBatch20SubmitOnly)
+    ->Apply(ApplyFillSubmitOnlyArguments)
+    ->Iterations(kSubmitOnlyIterations)
+    ->UseRealTime()
+    ->Unit(benchmark::kMicrosecond);
+BENCHMARK_REGISTER_F(BlitBenchmark, QueueUpdate)
+    ->Apply(ApplyUpdateArguments)
+    ->UseRealTime()
+    ->Unit(benchmark::kMicrosecond);
+BENCHMARK_REGISTER_F(BlitBenchmark, QueueUpdateBatch20)
+    ->Apply(ApplyUpdateArguments)
+    ->UseRealTime()
+    ->Unit(benchmark::kMicrosecond);
+BENCHMARK_REGISTER_F(BlitBenchmark, QueueUpdateBatch20SubmitOnly)
+    ->Apply(ApplyUpdateSubmitOnlyArguments)
+    ->Iterations(kSubmitOnlyIterations)
+    ->UseRealTime()
+    ->Unit(benchmark::kMicrosecond);
+
+}  // namespace
+
+int main(int argc, char** argv) {
+  iree_flags_parse_checked(IREE_FLAGS_PARSE_MODE_UNDEFINED_OK |
+                               IREE_FLAGS_PARSE_MODE_CONTINUE_AFTER_HELP,
+                           &argc, &argv);
+  benchmark::Initialize(&argc, argv);
+  if (benchmark::ReportUnrecognizedArguments(argc, argv)) return 1;
+  benchmark::RunSpecifiedBenchmarks();
+  benchmark::Shutdown();
+  BlitBenchmark::DeinitializeOnce();
+  return 0;
+}
diff --git a/runtime/src/iree/hal/drivers/amdgpu/util/block_pool.c b/runtime/src/iree/hal/drivers/amdgpu/util/block_pool.c
index 1aebc90..82932c9 100644
--- a/runtime/src/iree/hal/drivers/amdgpu/util/block_pool.c
+++ b/runtime/src/iree/hal/drivers/amdgpu/util/block_pool.c
@@ -13,6 +13,9 @@
 static iree_status_t iree_hal_amdgpu_block_pool_grow(
     iree_hal_amdgpu_block_pool_t* block_pool);
 
+static const char* IREE_HAL_AMDGPU_BLOCK_POOL_TRACE_ID =
+    "iree-hal-amdgpu-block-pool";
+
 iree_status_t iree_hal_amdgpu_block_pool_initialize(
     const iree_hal_amdgpu_libhsa_t* libhsa,
     iree_hal_amdgpu_block_pool_options_t options, hsa_agent_t agent,
@@ -37,6 +40,10 @@
   out_block_pool->agent = agent;
   out_block_pool->memory_pool = memory_pool;
   out_block_pool->block_size = options.block_size;
+  IREE_RETURN_AND_END_ZONE_IF_ERROR(
+      z0, iree_hal_memory_trace_initialize_pool(
+              options.trace_name, IREE_HAL_AMDGPU_BLOCK_POOL_TRACE_ID,
+              host_allocator, &out_block_pool->trace));
 
   // Query the memory pool for its allocation granularity.
   // This is not the minimum allocation size
@@ -104,6 +111,7 @@
               "must have freed all blocks prior to deallocating the pool");
 
   iree_slim_mutex_deinitialize(&block_pool->mutex);
+  iree_hal_memory_trace_deinitialize(&block_pool->trace);
 
   IREE_TRACE_ZONE_END(z0);
 }
@@ -141,6 +149,10 @@
     block_allocation->used_count = 0;
     block_pool->allocations_head = block_allocation;
 
+    iree_hal_memory_trace_alloc(
+        &block_pool->trace, base_ptr,
+        block_pool->blocks_per_allocation * block_pool->block_size);
+
     // Setup all blocks to point at their relevant memory.
     // We append to the block pool free list as we go.
     for (iree_host_size_t i = 0; i < block_pool->blocks_per_allocation; ++i) {
@@ -151,8 +163,9 @@
       block_pool->free_blocks_head = block;
     }
   } else {
-    IREE_IGNORE_ERROR(iree_hsa_amd_memory_pool_free(
-        IREE_LIBHSA(block_pool->libhsa), base_ptr));
+    status = iree_status_join(
+        status, iree_hsa_amd_memory_pool_free(IREE_LIBHSA(block_pool->libhsa),
+                                              base_ptr));
   }
 
   IREE_TRACE_ZONE_END(z0);
@@ -202,8 +215,10 @@
     iree_hal_amdgpu_block_allocation_t* next_allocation = allocation->next;
     if (allocation->used_count == 0) {
       // No blocks outstanding - can free and remove from the allocation list.
-      IREE_IGNORE_ERROR(iree_hsa_amd_memory_pool_free(
-          IREE_LIBHSA(block_pool->libhsa), allocation->base_ptr));
+      iree_hal_memory_trace_free(&block_pool->trace, allocation->base_ptr);
+      iree_hal_amdgpu_hsa_cleanup_assert_success(
+          iree_hsa_amd_memory_pool_free_raw(block_pool->libhsa,
+                                            allocation->base_ptr));
       if (allocation == block_pool->allocations_head) {
         IREE_ASSERT(!prev_allocation);
         block_pool->allocations_head = next_allocation;
@@ -235,24 +250,26 @@
 
   // If there are no free blocks available grow the pool by one block allocation
   // (which may allocate multiple blocks worth of memory).
+  iree_status_t status = iree_ok_status();
   if (!block_pool->free_blocks_head) {
-    IREE_RETURN_AND_END_ZONE_IF_ERROR(
-        z0, iree_hal_amdgpu_block_pool_grow(block_pool));
+    status = iree_hal_amdgpu_block_pool_grow(block_pool);
   }
 
-  // Slice off the next free block.
-  iree_hal_amdgpu_block_t* block = block_pool->free_blocks_head;
-  block_pool->free_blocks_head = block->next;
-  block->next = NULL;  // user may use this
-  block->prev = NULL;
-  memset(block->user_data, 0, sizeof(block->user_data));
-  ++block->allocation->used_count;
+  if (iree_status_is_ok(status)) {
+    // Slice off the next free block.
+    iree_hal_amdgpu_block_t* block = block_pool->free_blocks_head;
+    block_pool->free_blocks_head = block->next;
+    block->next = NULL;  // user may use this
+    block->prev = NULL;
+    memset(block->user_data, 0, sizeof(block->user_data));
+    ++block->allocation->used_count;
+    *out_block = block;
+  }
 
   iree_slim_mutex_unlock(&block_pool->mutex);
 
-  *out_block = block;
   IREE_TRACE_ZONE_END(z0);
-  return iree_ok_status();
+  return status;
 }
 
 void iree_hal_amdgpu_block_pool_release(
diff --git a/runtime/src/iree/hal/drivers/amdgpu/util/block_pool.h b/runtime/src/iree/hal/drivers/amdgpu/util/block_pool.h
index 64946ff..dcaace0 100644
--- a/runtime/src/iree/hal/drivers/amdgpu/util/block_pool.h
+++ b/runtime/src/iree/hal/drivers/amdgpu/util/block_pool.h
@@ -10,6 +10,7 @@
 #include "iree/base/api.h"
 #include "iree/base/threading/mutex.h"
 #include "iree/hal/drivers/amdgpu/util/libhsa.h"
+#include "iree/hal/memory/tracing.h"
 
 #ifdef __cplusplus
 extern "C" {
@@ -47,6 +48,9 @@
   // At least this number of blocks will be allocated during pool
   // initialization, possibly split into multiple block pool allocations.
   iree_host_size_t initial_capacity;
+  // Optional named-memory trace identifier for HSA backing allocations made by
+  // this pool. When empty, a generic process-stable identifier is used.
+  iree_string_view_t trace_name;
 } iree_hal_amdgpu_block_pool_options_t;
 
 // A block in the block pool.
@@ -117,6 +121,8 @@
   hsa_agent_t agent;
   // Memory pool blocks are allocated from.
   hsa_amd_memory_pool_t memory_pool;
+  // Stable named-memory stream for HSA backing allocations in this pool.
+  iree_hal_memory_trace_t trace;
   // Size in bytes of a block on device.
   iree_device_size_t block_size;
   // Number of blocks in a single device allocation.
diff --git a/runtime/src/iree/hal/drivers/amdgpu/util/block_pool_test.cc b/runtime/src/iree/hal/drivers/amdgpu/util/block_pool_test.cc
index 632af02..ce5d4bd 100644
--- a/runtime/src/iree/hal/drivers/amdgpu/util/block_pool_test.cc
+++ b/runtime/src/iree/hal/drivers/amdgpu/util/block_pool_test.cc
@@ -34,7 +34,7 @@
         host_allocator, &libhsa);
     if (!iree_status_is_ok(status)) {
       iree_status_fprint(stderr, status);
-      iree_status_ignore(status);
+      iree_status_free(status);
       GTEST_SKIP() << "HSA not available, skipping tests";
     }
     IREE_ASSERT_OK(
diff --git a/runtime/src/iree/hal/drivers/amdgpu/util/code_object_target.c b/runtime/src/iree/hal/drivers/amdgpu/util/code_object_target.c
new file mode 100644
index 0000000..fdb08fe
--- /dev/null
+++ b/runtime/src/iree/hal/drivers/amdgpu/util/code_object_target.c
@@ -0,0 +1,314 @@
+// Copyright 2026 The IREE Authors
+//
+// Licensed under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+
+#include "iree/hal/drivers/amdgpu/util/code_object_target.h"
+
+#include <string.h>
+
+#include "iree/base/alignment.h"
+#include "iree/hal/utils/elf_format.h"
+
+//===----------------------------------------------------------------------===//
+// AMDGPU ELF Constants
+//===----------------------------------------------------------------------===//
+
+typedef enum iree_hal_amdgpu_elf_ident_e {
+  IREE_HAL_AMDGPU_ELF_EI_CLASS = 4,
+  IREE_HAL_AMDGPU_ELF_EI_DATA = 5,
+  IREE_HAL_AMDGPU_ELF_EI_VERSION = 6,
+  IREE_HAL_AMDGPU_ELF_EI_OSABI = 7,
+  IREE_HAL_AMDGPU_ELF_EI_ABIVERSION = 8,
+} iree_hal_amdgpu_elf_ident_t;
+
+typedef enum iree_hal_amdgpu_elf_class_e {
+  IREE_HAL_AMDGPU_ELF_CLASS_64 = 2,
+} iree_hal_amdgpu_elf_class_t;
+
+typedef enum iree_hal_amdgpu_elf_data_e {
+  IREE_HAL_AMDGPU_ELF_DATA_2LSB = 1,
+} iree_hal_amdgpu_elf_data_t;
+
+typedef enum iree_hal_amdgpu_elf_version_e {
+  IREE_HAL_AMDGPU_ELF_VERSION_CURRENT = 1,
+} iree_hal_amdgpu_elf_version_t;
+
+typedef enum iree_hal_amdgpu_elf_machine_e {
+  IREE_HAL_AMDGPU_ELF_MACHINE_AMDGPU = 224,
+} iree_hal_amdgpu_elf_machine_t;
+
+typedef enum iree_hal_amdgpu_elf_osabi_e {
+  IREE_HAL_AMDGPU_ELF_OSABI_HSA = 64,
+} iree_hal_amdgpu_elf_osabi_t;
+
+typedef enum iree_hal_amdgpu_elf_hsa_abi_version_e {
+  IREE_HAL_AMDGPU_ELF_HSA_ABI_VERSION_V3 = 1,
+  IREE_HAL_AMDGPU_ELF_HSA_ABI_VERSION_V4 = 2,
+  IREE_HAL_AMDGPU_ELF_HSA_ABI_VERSION_V5 = 3,
+  IREE_HAL_AMDGPU_ELF_HSA_ABI_VERSION_V6 = 4,
+} iree_hal_amdgpu_elf_hsa_abi_version_t;
+
+typedef enum iree_hal_amdgpu_elf_header_offset_e {
+  IREE_HAL_AMDGPU_ELF_HEADER_E_MACHINE_OFFSET = 18,
+  IREE_HAL_AMDGPU_ELF_HEADER_E_VERSION_OFFSET = 20,
+  IREE_HAL_AMDGPU_ELF64_HEADER_E_FLAGS_OFFSET = 48,
+  IREE_HAL_AMDGPU_ELF64_HEADER_SIZE = 64,
+} iree_hal_amdgpu_elf_header_offset_t;
+
+enum {
+  IREE_HAL_AMDGPU_EF_MACH = 0x0ffu,
+  IREE_HAL_AMDGPU_EF_FEATURE_XNACK_V3 = 0x100u,
+  IREE_HAL_AMDGPU_EF_FEATURE_SRAMECC_V3 = 0x200u,
+  IREE_HAL_AMDGPU_EF_FEATURE_XNACK_V4 = 0x300u,
+  IREE_HAL_AMDGPU_EF_FEATURE_XNACK_UNSUPPORTED_V4 = 0x000u,
+  IREE_HAL_AMDGPU_EF_FEATURE_XNACK_ANY_V4 = 0x100u,
+  IREE_HAL_AMDGPU_EF_FEATURE_XNACK_OFF_V4 = 0x200u,
+  IREE_HAL_AMDGPU_EF_FEATURE_XNACK_ON_V4 = 0x300u,
+  IREE_HAL_AMDGPU_EF_FEATURE_SRAMECC_V4 = 0xc00u,
+  IREE_HAL_AMDGPU_EF_FEATURE_SRAMECC_UNSUPPORTED_V4 = 0x000u,
+  IREE_HAL_AMDGPU_EF_FEATURE_SRAMECC_ANY_V4 = 0x400u,
+  IREE_HAL_AMDGPU_EF_FEATURE_SRAMECC_OFF_V4 = 0x800u,
+  IREE_HAL_AMDGPU_EF_FEATURE_SRAMECC_ON_V4 = 0xc00u,
+  IREE_HAL_AMDGPU_EF_GENERIC_VERSION = 0xff000000u,
+  IREE_HAL_AMDGPU_EF_GENERIC_VERSION_OFFSET = 24,
+};
+
+typedef struct iree_hal_amdgpu_elf_machine_target_t {
+  // AMDGPU EF_AMDGPU_MACH_* value.
+  uint32_t machine;
+  // Processor string represented by |machine|.
+  iree_string_view_t processor;
+  // True if the target supports SRAM ECC: V3 e_flags lacking the feature bit
+  // then decode as OFF rather than UNSUPPORTED.
+  bool sramecc_supported;
+  // True if the target supports XNACK: V3 e_flags lacking the feature bit
+  // then decode as OFF rather than UNSUPPORTED.
+  bool xnack_supported;
+} iree_hal_amdgpu_elf_machine_target_t;
+
+static const iree_hal_amdgpu_elf_machine_target_t
+    iree_hal_amdgpu_elf_machine_targets[] = {
+        {0x020, IREE_SVL("gfx600"), false, false},
+        {0x021, IREE_SVL("gfx601"), false, false},
+        {0x022, IREE_SVL("gfx700"), false, false},
+        {0x023, IREE_SVL("gfx701"), false, false},
+        {0x024, IREE_SVL("gfx702"), false, false},
+        {0x025, IREE_SVL("gfx703"), false, false},
+        {0x026, IREE_SVL("gfx704"), false, false},
+        {0x028, IREE_SVL("gfx801"), false, true},
+        {0x029, IREE_SVL("gfx802"), false, false},
+        {0x02a, IREE_SVL("gfx803"), false, false},
+        {0x02b, IREE_SVL("gfx810"), false, true},
+        {0x02c, IREE_SVL("gfx900"), false, true},
+        {0x02d, IREE_SVL("gfx902"), false, true},
+        {0x02e, IREE_SVL("gfx904"), false, true},
+        {0x02f, IREE_SVL("gfx906"), true, true},
+        {0x030, IREE_SVL("gfx908"), true, true},
+        {0x031, IREE_SVL("gfx909"), false, true},
+        {0x032, IREE_SVL("gfx90c"), false, true},
+        {0x033, IREE_SVL("gfx1010"), false, true},
+        {0x034, IREE_SVL("gfx1011"), false, true},
+        {0x035, IREE_SVL("gfx1012"), false, true},
+        {0x036, IREE_SVL("gfx1030"), false, false},
+        {0x037, IREE_SVL("gfx1031"), false, false},
+        {0x038, IREE_SVL("gfx1032"), false, false},
+        {0x039, IREE_SVL("gfx1033"), false, false},
+        {0x03a, IREE_SVL("gfx602"), false, false},
+        {0x03b, IREE_SVL("gfx705"), false, false},
+        {0x03c, IREE_SVL("gfx805"), false, false},
+        {0x03d, IREE_SVL("gfx1035"), false, false},
+        {0x03e, IREE_SVL("gfx1034"), false, false},
+        {0x03f, IREE_SVL("gfx90a"), true, true},
+        {0x040, IREE_SVL("gfx940"), true, true},
+        {0x041, IREE_SVL("gfx1100"), false, false},
+        {0x042, IREE_SVL("gfx1013"), false, true},
+        {0x043, IREE_SVL("gfx1150"), false, false},
+        {0x044, IREE_SVL("gfx1103"), false, false},
+        {0x045, IREE_SVL("gfx1036"), false, false},
+        {0x046, IREE_SVL("gfx1101"), false, false},
+        {0x047, IREE_SVL("gfx1102"), false, false},
+        {0x048, IREE_SVL("gfx1200"), false, false},
+        {0x049, IREE_SVL("gfx1250"), false, false},
+        {0x04a, IREE_SVL("gfx1151"), false, false},
+        {0x04b, IREE_SVL("gfx941"), true, true},
+        {0x04c, IREE_SVL("gfx942"), true, true},
+        {0x04e, IREE_SVL("gfx1201"), false, false},
+        {0x04f, IREE_SVL("gfx950"), true, true},
+        {0x050, IREE_SVL("gfx1310"), false, false},
+        {0x051, IREE_SVL("gfx9-generic"), false, true},
+        {0x052, IREE_SVL("gfx10-1-generic"), false, true},
+        {0x053, IREE_SVL("gfx10-3-generic"), false, false},
+        {0x054, IREE_SVL("gfx11-generic"), false, false},
+        {0x055, IREE_SVL("gfx1152"), false, false},
+        {0x058, IREE_SVL("gfx1153"), false, false},
+        {0x059, IREE_SVL("gfx12-generic"), false, false},
+        {0x05a, IREE_SVL("gfx1251"), false, false},
+        {0x05b, IREE_SVL("gfx12-5-generic"), false, false},
+        {0x05c, IREE_SVL("gfx1172"), false, false},
+        {0x05d, IREE_SVL("gfx1170"), false, false},
+        {0x05e, IREE_SVL("gfx1171"), false, false},
+        {0x05f, IREE_SVL("gfx9-4-generic"), true, true},
+};
+
+//===----------------------------------------------------------------------===//
+// AMDGPU Code Object Target Parsing
+//===----------------------------------------------------------------------===//
+
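+// Returns true if |elf_data| is non-NULL and either declares no length or
+// contains at least |byte_count| bytes.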
+static bool iree_hal_amdgpu_elf_has_available_bytes(
+    iree_const_byte_span_t elf_data, iree_host_size_t byte_count) {
+  return elf_data.data != NULL &&
+         (elf_data.data_length == 0 || byte_count <= elf_data.data_length);
+}
+
+static const iree_hal_amdgpu_elf_machine_target_t*
+iree_hal_amdgpu_lookup_elf_machine_target(uint32_t machine) {
+  for (iree_host_size_t i = 0;
+       i < IREE_ARRAYSIZE(iree_hal_amdgpu_elf_machine_targets); ++i) {
+    if (iree_hal_amdgpu_elf_machine_targets[i].machine == machine) {
+      return &iree_hal_amdgpu_elf_machine_targets[i];
+    }
+  }
+  return NULL;
+}
+
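+// Decodes a tri-state target feature from V4+ |e_flags|: the masked field maps
+// to ANY/OFF/ON and any other encoding (such as the all-zero "unsupported"
+// value) yields UNSUPPORTED.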
+static iree_hal_amdgpu_target_feature_state_t
+iree_hal_amdgpu_code_object_decode_v4_feature(uint32_t e_flags, uint32_t mask,
+                                              uint32_t any_value,
+                                              uint32_t off_value,
+                                              uint32_t on_value) {
+  const uint32_t value = e_flags & mask;
+  if (value == any_value) {
+    return IREE_HAL_AMDGPU_TARGET_FEATURE_STATE_ANY;
+  } else if (value == off_value) {
+    return IREE_HAL_AMDGPU_TARGET_FEATURE_STATE_OFF;
+  } else if (value == on_value) {
+    return IREE_HAL_AMDGPU_TARGET_FEATURE_STATE_ON;
+  }
+  return IREE_HAL_AMDGPU_TARGET_FEATURE_STATE_UNSUPPORTED;
+}
+
+IREE_API_EXPORT iree_status_t iree_hal_amdgpu_code_object_target_id_from_elf(
+    iree_const_byte_span_t elf_data,
+    iree_hal_amdgpu_target_id_t* out_target_id) {
+  IREE_ASSERT_ARGUMENT(out_target_id);
+  memset(out_target_id, 0, sizeof(*out_target_id));
+
+  if (!iree_hal_elf_data_starts_with_magic(elf_data)) {
+    return iree_make_status(IREE_STATUS_INVALID_ARGUMENT,
+                            "AMDGPU code object does not begin with ELF magic");
+  }
+  if (!iree_hal_amdgpu_elf_has_available_bytes(
+          elf_data, IREE_HAL_AMDGPU_ELF64_HEADER_SIZE)) {
+    return iree_make_status(IREE_STATUS_OUT_OF_RANGE,
+                            "AMDGPU code object ELF header truncated");
+  }
+
+  const uint8_t* header = elf_data.data;
+  if (header[IREE_HAL_AMDGPU_ELF_EI_CLASS] != IREE_HAL_AMDGPU_ELF_CLASS_64) {
+    return iree_make_status(IREE_STATUS_INVALID_ARGUMENT,
+                            "AMDGPU code object must be a 64-bit ELF");
+  }
+  if (header[IREE_HAL_AMDGPU_ELF_EI_DATA] != IREE_HAL_AMDGPU_ELF_DATA_2LSB) {
+    return iree_make_status(IREE_STATUS_INVALID_ARGUMENT,
+                            "AMDGPU code object must be little-endian");
+  }
+  if (header[IREE_HAL_AMDGPU_ELF_EI_VERSION] !=
+      IREE_HAL_AMDGPU_ELF_VERSION_CURRENT) {
+    return iree_make_status(IREE_STATUS_INVALID_ARGUMENT,
+                            "unsupported AMDGPU code object ELF version %u",
+                            header[IREE_HAL_AMDGPU_ELF_EI_VERSION]);
+  }
+  if (header[IREE_HAL_AMDGPU_ELF_EI_OSABI] != IREE_HAL_AMDGPU_ELF_OSABI_HSA) {
+    return iree_make_status(IREE_STATUS_INVALID_ARGUMENT,
+                            "AMDGPU code object must use HSA OSABI");
+  }
+
+  const uint16_t e_machine = iree_unaligned_load_le_u16(
+      (const uint16_t*)(header + IREE_HAL_AMDGPU_ELF_HEADER_E_MACHINE_OFFSET));
+  if (e_machine != IREE_HAL_AMDGPU_ELF_MACHINE_AMDGPU) {
+    return iree_make_status(IREE_STATUS_INVALID_ARGUMENT,
+                            "ELF machine %u is not AMDGPU", e_machine);
+  }
+  const uint32_t e_version = iree_unaligned_load_le_u32(
+      (const uint32_t*)(header + IREE_HAL_AMDGPU_ELF_HEADER_E_VERSION_OFFSET));
+  if (e_version != IREE_HAL_AMDGPU_ELF_VERSION_CURRENT) {
+    return iree_make_status(IREE_STATUS_INVALID_ARGUMENT,
+                            "unsupported AMDGPU code object e_version %u",
+                            e_version);
+  }
+
+  const uint32_t e_flags = iree_unaligned_load_le_u32(
+      (const uint32_t*)(header + IREE_HAL_AMDGPU_ELF64_HEADER_E_FLAGS_OFFSET));
+  const iree_hal_amdgpu_elf_machine_target_t* machine_target =
+      iree_hal_amdgpu_lookup_elf_machine_target(e_flags &
+                                                IREE_HAL_AMDGPU_EF_MACH);
+  if (machine_target == NULL) {
+    return iree_make_status(
+        IREE_STATUS_INVALID_ARGUMENT,
+        "unsupported AMDGPU code object processor e_flags value 0x%x",
+        e_flags & IREE_HAL_AMDGPU_EF_MACH);
+  }
+
+  IREE_RETURN_IF_ERROR(iree_hal_amdgpu_target_id_parse(
+      machine_target->processor, IREE_HAL_AMDGPU_TARGET_ID_PARSE_FLAG_NONE,
+      out_target_id));
+
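+  // Code object V3 e_flags only encode "feature on" bits: an absent bit
+  // decodes as OFF when the target supports the feature and as UNSUPPORTED
+  // otherwise.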
+  const uint8_t abi_version = header[IREE_HAL_AMDGPU_ELF_EI_ABIVERSION];
+  if (abi_version == IREE_HAL_AMDGPU_ELF_HSA_ABI_VERSION_V3) {
+    out_target_id->sramecc =
+        iree_all_bits_set(e_flags, IREE_HAL_AMDGPU_EF_FEATURE_SRAMECC_V3)
+            ? IREE_HAL_AMDGPU_TARGET_FEATURE_STATE_ON
+        : machine_target->sramecc_supported
+            ? IREE_HAL_AMDGPU_TARGET_FEATURE_STATE_OFF
+            : IREE_HAL_AMDGPU_TARGET_FEATURE_STATE_UNSUPPORTED;
+    out_target_id->xnack =
+        iree_all_bits_set(e_flags, IREE_HAL_AMDGPU_EF_FEATURE_XNACK_V3)
+            ? IREE_HAL_AMDGPU_TARGET_FEATURE_STATE_ON
+        : machine_target->xnack_supported
+            ? IREE_HAL_AMDGPU_TARGET_FEATURE_STATE_OFF
+            : IREE_HAL_AMDGPU_TARGET_FEATURE_STATE_UNSUPPORTED;
+  } else if (abi_version == IREE_HAL_AMDGPU_ELF_HSA_ABI_VERSION_V4 ||
+             abi_version == IREE_HAL_AMDGPU_ELF_HSA_ABI_VERSION_V5 ||
+             abi_version == IREE_HAL_AMDGPU_ELF_HSA_ABI_VERSION_V6) {
+    out_target_id->sramecc = iree_hal_amdgpu_code_object_decode_v4_feature(
+        e_flags, IREE_HAL_AMDGPU_EF_FEATURE_SRAMECC_V4,
+        IREE_HAL_AMDGPU_EF_FEATURE_SRAMECC_ANY_V4,
+        IREE_HAL_AMDGPU_EF_FEATURE_SRAMECC_OFF_V4,
+        IREE_HAL_AMDGPU_EF_FEATURE_SRAMECC_ON_V4);
+    out_target_id->xnack = iree_hal_amdgpu_code_object_decode_v4_feature(
+        e_flags, IREE_HAL_AMDGPU_EF_FEATURE_XNACK_V4,
+        IREE_HAL_AMDGPU_EF_FEATURE_XNACK_ANY_V4,
+        IREE_HAL_AMDGPU_EF_FEATURE_XNACK_OFF_V4,
+        IREE_HAL_AMDGPU_EF_FEATURE_XNACK_ON_V4);
+    if (abi_version == IREE_HAL_AMDGPU_ELF_HSA_ABI_VERSION_V6) {
+      out_target_id->generic_version =
+          (e_flags & IREE_HAL_AMDGPU_EF_GENERIC_VERSION) >>
+          IREE_HAL_AMDGPU_EF_GENERIC_VERSION_OFFSET;
+    }
+  } else {
+    return iree_make_status(IREE_STATUS_INVALID_ARGUMENT,
+                            "unsupported AMDGPU HSA code object ABI version %u",
+                            abi_version);
+  }
+
+  if (out_target_id->kind == IREE_HAL_AMDGPU_TARGET_KIND_GENERIC &&
+      abi_version != IREE_HAL_AMDGPU_ELF_HSA_ABI_VERSION_V6) {
+    return iree_make_status(
+        IREE_STATUS_INVALID_ARGUMENT,
+        "generic AMDGPU code object target requires HSA ABI v6");
+  }
+  if (out_target_id->kind == IREE_HAL_AMDGPU_TARGET_KIND_GENERIC &&
+      out_target_id->generic_version == 0) {
+    return iree_make_status(
+        IREE_STATUS_INVALID_ARGUMENT,
+        "generic AMDGPU code object target has no generic version");
+  }
+  if (out_target_id->kind != IREE_HAL_AMDGPU_TARGET_KIND_GENERIC &&
+      out_target_id->generic_version != 0) {
+    return iree_make_status(
+        IREE_STATUS_INVALID_ARGUMENT,
+        "non-generic AMDGPU code object target has generic version %u",
+        out_target_id->generic_version);
+  }
+  return iree_ok_status();
+}
diff --git a/runtime/src/iree/hal/drivers/amdgpu/util/code_object_target.h b/runtime/src/iree/hal/drivers/amdgpu/util/code_object_target.h
new file mode 100644
index 0000000..8bc8c50
--- /dev/null
+++ b/runtime/src/iree/hal/drivers/amdgpu/util/code_object_target.h
@@ -0,0 +1,32 @@
+// Copyright 2026 The IREE Authors
+//
+// Licensed under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+
+#ifndef IREE_HAL_DRIVERS_AMDGPU_UTIL_CODE_OBJECT_TARGET_H_
+#define IREE_HAL_DRIVERS_AMDGPU_UTIL_CODE_OBJECT_TARGET_H_
+
+#include "iree/hal/drivers/amdgpu/util/target_id.h"
+
+#ifdef __cplusplus
+extern "C" {
+#endif  // __cplusplus
+
+//===----------------------------------------------------------------------===//
+// AMDGPU Code Object Targets
+//===----------------------------------------------------------------------===//
+
+// Recovers the AMDGPU target ID encoded in an HSA code-object ELF header.
+//
+// The returned processor string is borrowed from static target tables and
+// remains valid for the lifetime of the process.
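+//
+// Usage sketch (|data|/|size| span a loaded code object):
+//   iree_hal_amdgpu_target_id_t target_id;
+//   IREE_RETURN_IF_ERROR(iree_hal_amdgpu_code_object_target_id_from_elf(
+//       iree_make_const_byte_span(data, size), &target_id));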
+iree_status_t iree_hal_amdgpu_code_object_target_id_from_elf(
+    iree_const_byte_span_t elf_data,
+    iree_hal_amdgpu_target_id_t* out_target_id);
+
+#ifdef __cplusplus
+}  // extern "C"
+#endif  // __cplusplus
+
+#endif  // IREE_HAL_DRIVERS_AMDGPU_UTIL_CODE_OBJECT_TARGET_H_
diff --git a/runtime/src/iree/hal/drivers/amdgpu/util/code_object_target_test.cc b/runtime/src/iree/hal/drivers/amdgpu/util/code_object_target_test.cc
new file mode 100644
index 0000000..7088b4a
--- /dev/null
+++ b/runtime/src/iree/hal/drivers/amdgpu/util/code_object_target_test.cc
@@ -0,0 +1,147 @@
+// Copyright 2026 The IREE Authors
+//
+// Licensed under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+
+#include "iree/hal/drivers/amdgpu/util/code_object_target.h"
+
+#include <array>
+#include <string>
+
+#include "iree/base/alignment.h"
+#include "iree/base/api.h"
+#include "iree/testing/gtest.h"
+#include "iree/testing/status_matchers.h"
+
+namespace iree::hal::amdgpu {
+namespace {
+
+static constexpr uint8_t kElfClass64 = 2;
+static constexpr uint8_t kElfData2Lsb = 1;
+static constexpr uint8_t kElfVersionCurrent = 1;
+static constexpr uint8_t kElfOsAbiAmdgpuHsa = 64;
+static constexpr uint8_t kElfAbiVersionV3 = 1;
+static constexpr uint8_t kElfAbiVersionV5 = 3;
+static constexpr uint8_t kElfAbiVersionV6 = 4;
+static constexpr uint16_t kElfMachineAmdgpu = 224;
+static constexpr uint32_t kElfMachineGfx906 = 0x02f;
+static constexpr uint32_t kElfMachineGfx1100 = 0x041;
+static constexpr uint32_t kElfMachineGfx942 = 0x04c;
+static constexpr uint32_t kElfMachineGfx11Generic = 0x054;
+static constexpr uint32_t kElfFeatureXnackUnsupportedV4 = 0x000;
+static constexpr uint32_t kElfFeatureXnackOffV4 = 0x200;
+static constexpr uint32_t kElfFeatureSrameccAnyV4 = 0x400;
+static constexpr uint32_t kElfFeatureSrameccOnV4 = 0xc00;
+static constexpr uint32_t kElfGenericVersionOffset = 24;
+
+static std::array<uint8_t, 64> MakeElf64AmdgpuHsa(uint8_t abi_version,
+                                                  uint16_t machine,
+                                                  uint32_t e_flags) {
+  std::array<uint8_t, 64> elf = {};
+  elf[0] = 0x7f;
+  elf[1] = 'E';
+  elf[2] = 'L';
+  elf[3] = 'F';
+  elf[4] = kElfClass64;
+  elf[5] = kElfData2Lsb;
+  elf[6] = kElfVersionCurrent;
+  elf[7] = kElfOsAbiAmdgpuHsa;
+  elf[8] = abi_version;
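+  // ELF64 header layout: e_machine at offset 18, e_version at 20, e_flags at
+  // 48, e_ehsize at 52.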
+  iree_unaligned_store_le_u16((uint16_t*)&elf[18], machine);
+  iree_unaligned_store_le_u32((uint32_t*)&elf[20], kElfVersionCurrent);
+  iree_unaligned_store_le_u32((uint32_t*)&elf[48], e_flags);
+  iree_unaligned_store_le_u16((uint16_t*)&elf[52], (uint16_t)elf.size());
+  return elf;
+}
+
+static iree_hal_amdgpu_target_id_t ParseCodeObjectTarget(
+    const std::array<uint8_t, 64>& elf) {
+  iree_hal_amdgpu_target_id_t target_id;
+  IREE_CHECK_OK(iree_hal_amdgpu_code_object_target_id_from_elf(
+      iree_make_const_byte_span(elf.data(), elf.size()), &target_id));
+  return target_id;
+}
+
+static std::string FormatTargetId(
+    const iree_hal_amdgpu_target_id_t* target_id) {
+  char buffer[64] = {0};
+  IREE_CHECK_OK(iree_hal_amdgpu_target_id_format(
+      target_id, sizeof(buffer), buffer, /*out_buffer_length=*/nullptr));
+  return std::string(buffer);
+}
+
+TEST(CodeObjectTargetTest, ParsesV5FeatureStates) {
+  const auto elf = MakeElf64AmdgpuHsa(
+      kElfAbiVersionV5, kElfMachineAmdgpu,
+      kElfMachineGfx942 | kElfFeatureSrameccOnV4 | kElfFeatureXnackOffV4);
+  auto target_id = ParseCodeObjectTarget(elf);
+  EXPECT_EQ(target_id.kind, IREE_HAL_AMDGPU_TARGET_KIND_EXACT);
+  EXPECT_EQ(target_id.generic_version, 0u);
+  EXPECT_EQ(target_id.sramecc, IREE_HAL_AMDGPU_TARGET_FEATURE_STATE_ON);
+  EXPECT_EQ(target_id.xnack, IREE_HAL_AMDGPU_TARGET_FEATURE_STATE_OFF);
+  EXPECT_EQ(FormatTargetId(&target_id), "gfx942:sramecc+:xnack-");
+}
+
+TEST(CodeObjectTargetTest, ParsesV5AnyAndUnsupportedFeatures) {
+  const auto elf =
+      MakeElf64AmdgpuHsa(kElfAbiVersionV5, kElfMachineAmdgpu,
+                         kElfMachineGfx1100 | kElfFeatureSrameccAnyV4 |
+                             kElfFeatureXnackUnsupportedV4);
+  auto target_id = ParseCodeObjectTarget(elf);
+  EXPECT_EQ(target_id.sramecc, IREE_HAL_AMDGPU_TARGET_FEATURE_STATE_ANY);
+  EXPECT_EQ(target_id.xnack, IREE_HAL_AMDGPU_TARGET_FEATURE_STATE_UNSUPPORTED);
+  EXPECT_EQ(FormatTargetId(&target_id), "gfx1100");
+}
+
+TEST(CodeObjectTargetTest, ParsesV6GenericVersion) {
+  const auto elf = MakeElf64AmdgpuHsa(
+      kElfAbiVersionV6, kElfMachineAmdgpu,
+      kElfMachineGfx11Generic | (1u << kElfGenericVersionOffset));
+  auto target_id = ParseCodeObjectTarget(elf);
+  EXPECT_EQ(target_id.kind, IREE_HAL_AMDGPU_TARGET_KIND_GENERIC);
+  EXPECT_EQ(target_id.generic_version, 1u);
+  EXPECT_EQ(FormatTargetId(&target_id), "gfx11-generic");
+}
+
+TEST(CodeObjectTargetTest, ParsesV3SupportedAbsentFeaturesAsOff) {
+  const auto elf = MakeElf64AmdgpuHsa(kElfAbiVersionV3, kElfMachineAmdgpu,
+                                      kElfMachineGfx906);
+  auto target_id = ParseCodeObjectTarget(elf);
+  EXPECT_EQ(target_id.sramecc, IREE_HAL_AMDGPU_TARGET_FEATURE_STATE_OFF);
+  EXPECT_EQ(target_id.xnack, IREE_HAL_AMDGPU_TARGET_FEATURE_STATE_OFF);
+  EXPECT_EQ(FormatTargetId(&target_id), "gfx906:sramecc-:xnack-");
+}
+
+TEST(CodeObjectTargetTest, RejectsV6GenericWithoutVersion) {
+  const auto elf = MakeElf64AmdgpuHsa(kElfAbiVersionV6, kElfMachineAmdgpu,
+                                      kElfMachineGfx11Generic);
+  iree_hal_amdgpu_target_id_t target_id;
+  IREE_EXPECT_STATUS_IS(
+      IREE_STATUS_INVALID_ARGUMENT,
+      iree_hal_amdgpu_code_object_target_id_from_elf(
+          iree_make_const_byte_span(elf.data(), elf.size()), &target_id));
+}
+
+TEST(CodeObjectTargetTest, RejectsUnsupportedMachineValue) {
+  const auto elf =
+      MakeElf64AmdgpuHsa(kElfAbiVersionV5, kElfMachineAmdgpu, 0x027);
+  iree_hal_amdgpu_target_id_t target_id;
+  IREE_EXPECT_STATUS_IS(
+      IREE_STATUS_INVALID_ARGUMENT,
+      iree_hal_amdgpu_code_object_target_id_from_elf(
+          iree_make_const_byte_span(elf.data(), elf.size()), &target_id));
+}
+
+TEST(CodeObjectTargetTest, RejectsNonAmdgpuElfMachine) {
+  const auto elf =
+      MakeElf64AmdgpuHsa(kElfAbiVersionV5, /*machine=*/3, kElfMachineGfx1100);
+  iree_hal_amdgpu_target_id_t target_id;
+  IREE_EXPECT_STATUS_IS(
+      IREE_STATUS_INVALID_ARGUMENT,
+      iree_hal_amdgpu_code_object_target_id_from_elf(
+          iree_make_const_byte_span(elf.data(), elf.size()), &target_id));
+}
+
+}  // namespace
+}  // namespace iree::hal::amdgpu
diff --git a/runtime/src/iree/hal/drivers/amdgpu/util/device_clock.c b/runtime/src/iree/hal/drivers/amdgpu/util/device_clock.c
new file mode 100644
index 0000000..68464f1
--- /dev/null
+++ b/runtime/src/iree/hal/drivers/amdgpu/util/device_clock.c
@@ -0,0 +1,97 @@
+// Copyright 2026 The IREE Authors
+//
+// Licensed under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+
+#include "iree/hal/drivers/amdgpu/util/device_clock.h"
+
+#include <inttypes.h>
+#include <string.h>
+
+#include "iree/hal/drivers/amdgpu/util/kfd.h"
+
+iree_status_t iree_hal_amdgpu_device_clock_counters_validate(
+    uint32_t driver_uid,
+    const iree_hal_amdgpu_device_clock_counters_t* counters) {
+  IREE_ASSERT_ARGUMENT(counters);
+  if (IREE_UNLIKELY(counters->device_clock_counter == 0)) {
+    return iree_make_status(
+        IREE_STATUS_FAILED_PRECONDITION,
+        "device clock source returned an invalid zero device_clock_counter for "
+        "driver_uid=%" PRIu32,
+        driver_uid);
+  }
+  if (IREE_UNLIKELY(counters->host_cpu_timestamp_ns == 0)) {
+    return iree_make_status(IREE_STATUS_FAILED_PRECONDITION,
+                            "device clock source returned an invalid zero "
+                            "host_cpu_timestamp_ns for driver_uid=%" PRIu32,
+                            driver_uid);
+  }
+  if (IREE_UNLIKELY(counters->host_system_timestamp == 0)) {
+    return iree_make_status(IREE_STATUS_FAILED_PRECONDITION,
+                            "device clock source returned an invalid zero "
+                            "host_system_timestamp for driver_uid=%" PRIu32,
+                            driver_uid);
+  }
+  if (IREE_UNLIKELY(counters->host_system_frequency_hz == 0)) {
+    return iree_make_status(IREE_STATUS_FAILED_PRECONDITION,
+                            "device clock source returned an invalid zero "
+                            "host_system_frequency_hz for driver_uid=%" PRIu32,
+                            driver_uid);
+  }
+  return iree_ok_status();
+}
+
+iree_status_t iree_hal_amdgpu_device_clock_source_initialize(
+    iree_hal_amdgpu_device_clock_source_t* out_source) {
+  IREE_ASSERT_ARGUMENT(out_source);
+  memset(out_source, 0, sizeof(*out_source));
+  out_source->platform_handle = -1;
+
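+  // Platforms without a sampling implementation keep the source at
+  // IREE_HAL_AMDGPU_DEVICE_CLOCK_SOURCE_TYPE_UNAVAILABLE (zeroed above);
+  // initialization still succeeds and sampling fails with UNIMPLEMENTED.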
+#if defined(IREE_PLATFORM_LINUX)
+  int kfd = -1;
+  IREE_RETURN_IF_ERROR(iree_hal_amdgpu_kfd_open(&kfd));
+  out_source->platform_handle = (intptr_t)kfd;
+  out_source->type = IREE_HAL_AMDGPU_DEVICE_CLOCK_SOURCE_TYPE_LINUX_KFD;
+#endif  // IREE_PLATFORM_LINUX
+
+  return iree_ok_status();
+}
+
+void iree_hal_amdgpu_device_clock_source_deinitialize(
+    iree_hal_amdgpu_device_clock_source_t* source) {
+  if (!source) return;
+  if (source->type == IREE_HAL_AMDGPU_DEVICE_CLOCK_SOURCE_TYPE_LINUX_KFD) {
+    iree_hal_amdgpu_kfd_close((int)source->platform_handle);
+  }
+  memset(source, 0, sizeof(*source));
+  source->platform_handle = -1;
+}
+
+iree_status_t iree_hal_amdgpu_device_clock_source_sample(
+    const iree_hal_amdgpu_device_clock_source_t* source, uint32_t driver_uid,
+    iree_hal_amdgpu_device_clock_counters_t* out_counters) {
+  IREE_ASSERT_ARGUMENT(source);
+  IREE_ASSERT_ARGUMENT(out_counters);
+  memset(out_counters, 0, sizeof(*out_counters));
+
+  switch (source->type) {
+    case IREE_HAL_AMDGPU_DEVICE_CLOCK_SOURCE_TYPE_LINUX_KFD: {
+      iree_hal_amdgpu_kfd_clock_counters_t kfd_counters = {0};
+      IREE_RETURN_IF_ERROR(iree_hal_amdgpu_kfd_get_clock_counters(
+          (int)source->platform_handle, driver_uid, &kfd_counters));
+      out_counters->device_clock_counter = kfd_counters.gpu_clock_counter;
+      out_counters->host_cpu_timestamp_ns = kfd_counters.cpu_clock_counter;
+      out_counters->host_system_timestamp = kfd_counters.system_clock_counter;
+      out_counters->host_system_frequency_hz = kfd_counters.system_clock_freq;
+      return iree_hal_amdgpu_device_clock_counters_validate(driver_uid,
+                                                            out_counters);
+    }
+    case IREE_HAL_AMDGPU_DEVICE_CLOCK_SOURCE_TYPE_UNAVAILABLE:
+    default:
+      return iree_make_status(
+          IREE_STATUS_UNIMPLEMENTED,
+          "AMDGPU device clock sampling is unavailable on this platform");
+  }
+}
diff --git a/runtime/src/iree/hal/drivers/amdgpu/util/device_clock.h b/runtime/src/iree/hal/drivers/amdgpu/util/device_clock.h
new file mode 100644
index 0000000..b98a765
--- /dev/null
+++ b/runtime/src/iree/hal/drivers/amdgpu/util/device_clock.h
@@ -0,0 +1,72 @@
+// Copyright 2026 The IREE Authors
+//
+// Licensed under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+
+#ifndef IREE_HAL_DRIVERS_AMDGPU_UTIL_DEVICE_CLOCK_H_
+#define IREE_HAL_DRIVERS_AMDGPU_UTIL_DEVICE_CLOCK_H_
+
+#include "iree/base/api.h"
+
+#ifdef __cplusplus
+extern "C" {
+#endif  // __cplusplus
+
+// Device and host clock counters sampled by one platform source.
+typedef struct iree_hal_amdgpu_device_clock_counters_t {
+  // Device clock counter sampled for the requested GPU.
+  uint64_t device_clock_counter;
+
+  // Host CPU timestamp sampled near the device clock read.
+  uint64_t host_cpu_timestamp_ns;
+
+  // Host system clock counter sampled near the device clock read.
+  uint64_t host_system_timestamp;
+
+  // Frequency in Hz for |host_system_timestamp|.
+  uint64_t host_system_frequency_hz;
+} iree_hal_amdgpu_device_clock_counters_t;
+
+// Validates that |counters| contains a usable clock-correlation sample.
+iree_status_t iree_hal_amdgpu_device_clock_counters_validate(
+    uint32_t driver_uid,
+    const iree_hal_amdgpu_device_clock_counters_t* counters);
+
+// Platform implementation used for device/host clock-correlation sampling.
+typedef enum iree_hal_amdgpu_device_clock_source_type_e {
+  IREE_HAL_AMDGPU_DEVICE_CLOCK_SOURCE_TYPE_UNAVAILABLE = 0,
+  IREE_HAL_AMDGPU_DEVICE_CLOCK_SOURCE_TYPE_LINUX_KFD = 1,
+} iree_hal_amdgpu_device_clock_source_type_t;
+
+// Platform device-clock sampling source.
+//
+// Linux currently backs this with KFD's AMDKFD_IOC_GET_CLOCK_COUNTERS ioctl.
+// Other platforms keep the source unavailable until their HSA runtime exposes
+// equivalent device/host clock correlation.
+typedef struct iree_hal_amdgpu_device_clock_source_t {
+  // Active platform sampling implementation.
+  iree_hal_amdgpu_device_clock_source_type_t type;
+
+  // Opaque platform handle for the active clock source, or -1 when unavailable.
+  intptr_t platform_handle;
+} iree_hal_amdgpu_device_clock_source_t;
+
+// Initializes a platform device-clock source.
+iree_status_t iree_hal_amdgpu_device_clock_source_initialize(
+    iree_hal_amdgpu_device_clock_source_t* out_source);
+
+// Deinitializes |source| and releases its platform handle, if any.
+void iree_hal_amdgpu_device_clock_source_deinitialize(
+    iree_hal_amdgpu_device_clock_source_t* source);
+
+// Samples clock counters for the GPU with HSA driver UID |driver_uid|.
+iree_status_t iree_hal_amdgpu_device_clock_source_sample(
+    const iree_hal_amdgpu_device_clock_source_t* source, uint32_t driver_uid,
+    iree_hal_amdgpu_device_clock_counters_t* out_counters);
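+//
+// Usage sketch (error handling elided; |driver_uid| identifies the GPU):
+//   iree_hal_amdgpu_device_clock_source_t source;
+//   IREE_RETURN_IF_ERROR(
+//       iree_hal_amdgpu_device_clock_source_initialize(&source));
+//   iree_hal_amdgpu_device_clock_counters_t counters;
+//   iree_status_t status = iree_hal_amdgpu_device_clock_source_sample(
+//       &source, driver_uid, &counters);
+//   iree_hal_amdgpu_device_clock_source_deinitialize(&source);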
+
+#ifdef __cplusplus
+}  // extern "C"
+#endif  // __cplusplus
+
+#endif  // IREE_HAL_DRIVERS_AMDGPU_UTIL_DEVICE_CLOCK_H_
diff --git a/runtime/src/iree/hal/drivers/amdgpu/util/device_clock_test.cc b/runtime/src/iree/hal/drivers/amdgpu/util/device_clock_test.cc
new file mode 100644
index 0000000..f96b3aa
--- /dev/null
+++ b/runtime/src/iree/hal/drivers/amdgpu/util/device_clock_test.cc
@@ -0,0 +1,71 @@
+// Copyright 2026 The IREE Authors
+//
+// Licensed under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+
+#include "iree/hal/drivers/amdgpu/util/device_clock.h"
+
+#include "iree/testing/gtest.h"
+#include "iree/testing/status_matchers.h"
+
+namespace iree::hal::amdgpu {
+namespace {
+
+TEST(DeviceClockTest, ValidateCounters) {
+  iree_hal_amdgpu_device_clock_counters_t counters = {
+      /*.device_clock_counter=*/1,
+      /*.host_cpu_timestamp_ns=*/2,
+      /*.host_system_timestamp=*/3,
+      /*.host_system_frequency_hz=*/4,
+  };
+  IREE_EXPECT_OK(
+      iree_hal_amdgpu_device_clock_counters_validate(1234, &counters));
+
+  counters.device_clock_counter = 0;
+  IREE_EXPECT_STATUS_IS(
+      IREE_STATUS_FAILED_PRECONDITION,
+      iree_hal_amdgpu_device_clock_counters_validate(1234, &counters));
+  counters.device_clock_counter = 1;
+
+  counters.host_cpu_timestamp_ns = 0;
+  IREE_EXPECT_STATUS_IS(
+      IREE_STATUS_FAILED_PRECONDITION,
+      iree_hal_amdgpu_device_clock_counters_validate(1234, &counters));
+  counters.host_cpu_timestamp_ns = 2;
+
+  counters.host_system_timestamp = 0;
+  IREE_EXPECT_STATUS_IS(
+      IREE_STATUS_FAILED_PRECONDITION,
+      iree_hal_amdgpu_device_clock_counters_validate(1234, &counters));
+  counters.host_system_timestamp = 3;
+
+  counters.host_system_frequency_hz = 0;
+  IREE_EXPECT_STATUS_IS(
+      IREE_STATUS_FAILED_PRECONDITION,
+      iree_hal_amdgpu_device_clock_counters_validate(1234, &counters));
+}
+
+TEST(DeviceClockTest, UnavailableSourceSampleFailsExplicitly) {
+  iree_hal_amdgpu_device_clock_source_t source = {
+      /*.type=*/IREE_HAL_AMDGPU_DEVICE_CLOCK_SOURCE_TYPE_UNAVAILABLE,
+      /*.platform_handle=*/-1,
+  };
+  iree_hal_amdgpu_device_clock_counters_t counters = {
+      /*.device_clock_counter=*/1,
+      /*.host_cpu_timestamp_ns=*/2,
+      /*.host_system_timestamp=*/3,
+      /*.host_system_frequency_hz=*/4,
+  };
+
+  IREE_EXPECT_STATUS_IS(
+      IREE_STATUS_UNIMPLEMENTED,
+      iree_hal_amdgpu_device_clock_source_sample(&source, 1234, &counters));
+  EXPECT_EQ(counters.device_clock_counter, 0);
+  EXPECT_EQ(counters.host_cpu_timestamp_ns, 0);
+  EXPECT_EQ(counters.host_system_timestamp, 0);
+  EXPECT_EQ(counters.host_system_frequency_hz, 0);
+}
+
+}  // namespace
+}  // namespace iree::hal::amdgpu
diff --git a/runtime/src/iree/hal/drivers/amdgpu/util/device_library.c b/runtime/src/iree/hal/drivers/amdgpu/util/device_library.c
index 41b359c..ed55d2a 100644
--- a/runtime/src/iree/hal/drivers/amdgpu/util/device_library.c
+++ b/runtime/src/iree/hal/drivers/amdgpu/util/device_library.c
@@ -6,8 +6,10 @@
 
 #include "iree/hal/drivers/amdgpu/util/device_library.h"
 
-#include "iree/hal/drivers/amdgpu/device/binaries.h"
+#include "iree/base/internal/debugging.h"
+#include "iree/hal/drivers/amdgpu/device/binaries/toc.h"
 #include "iree/hal/drivers/amdgpu/device/kernels.h"
+#include "iree/hal/drivers/amdgpu/util/device_library_target.h"
 #include "iree/hal/drivers/amdgpu/util/topology.h"
 
 //===----------------------------------------------------------------------===//
@@ -65,6 +67,44 @@
   return iree_ok_status();
 }
 
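+// Finds an embedded device-library binary whose file name matches |arch| after
+// the "amdgcn-amd-amdhsa--" ISA prefix, e.g. "amdgcn-amd-amdhsa--gfx942.so".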
+static const iree_file_toc_t* iree_hal_amdgpu_device_library_find_file_for_arch(
+    iree_string_view_t arch) {
+  const iree_string_view_t isa_prefix = IREE_SVL("amdgcn-amd-amdhsa--");
+  for (iree_host_size_t i = 0; i < iree_hal_amdgpu_device_binaries_size();
+       ++i) {
+    const iree_file_toc_t* file_toc =
+        &iree_hal_amdgpu_device_binaries_create()[i];
+    iree_string_view_t file_name = iree_make_cstring_view(file_toc->name);
+    if (!iree_string_view_starts_with(file_name, isa_prefix)) continue;
+    iree_string_view_t file_arch = iree_string_view_substr(
+        file_name, isa_prefix.size, IREE_STRING_VIEW_NPOS);
+    if (iree_hal_amdgpu_device_library_target_matches_file_arch(file_arch,
+                                                                arch)) {
+      return file_toc;
+    }
+  }
+  return NULL;
+}
+
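+// Resolves the embedded device-library binary for |isa_name| by trying target
+// candidates from most- to least-specific. A miss is not an error: it leaves
+// |out_file_toc| NULL so the caller can report the agent ISA as unsupported.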
+static iree_status_t iree_hal_amdgpu_device_library_find_file_for_isa(
+    iree_string_view_t isa_name, const iree_file_toc_t** out_file_toc) {
+  *out_file_toc = NULL;
+  iree_hal_amdgpu_device_library_target_candidate_list_t candidates = {0};
+  IREE_RETURN_IF_ERROR(
+      iree_hal_amdgpu_device_library_target_candidates_from_isa(isa_name,
+                                                                &candidates));
+  for (iree_host_size_t i = 0; i < candidates.count; ++i) {
+    const iree_file_toc_t* file_toc =
+        iree_hal_amdgpu_device_library_find_file_for_arch(
+            candidates.values[i].value);
+    if (file_toc) {
+      *out_file_toc = file_toc;
+      break;
+    }
+  }
+  return iree_ok_status();
+}
+
 // Selects a device library binary file that supports the ISA of the provided
 // |agent|.
 static iree_status_t iree_hal_amdgpu_device_library_select_file(
@@ -108,15 +148,12 @@
                                       HSA_ISA_INFO_NAME, isa_name_buffer));
     iree_string_view_t isa_name =
         iree_make_string_view(isa_name_buffer, isa_name_length - /*NUL*/ 1);
-    for (iree_host_size_t j = 0; j < iree_hal_amdgpu_device_binaries_size();
-         ++j) {
-      const iree_file_toc_t* file_toc =
-          &iree_hal_amdgpu_device_binaries_create()[j];
-      if (iree_string_view_starts_with(IREE_SV(file_toc->name), isa_name)) {
-        best_isa = isa;
-        best_file_toc = file_toc;
-        break;
-      }
+    IREE_RETURN_AND_END_ZONE_IF_ERROR(
+        z0, iree_hal_amdgpu_device_library_find_file_for_isa(isa_name,
+                                                             &best_file_toc));
+    if (best_file_toc) {
+      best_isa = isa;
+      break;
     }
   }
 
@@ -135,20 +172,33 @@
 #if IREE_STATUS_MODE >= 2
     iree_string_builder_t builder;
     iree_string_builder_initialize(host_allocator, &builder);
-    IREE_IGNORE_ERROR(iree_string_builder_append_string(
-        &builder, IREE_SV("available in runtime build: [")));
-    IREE_IGNORE_ERROR(iree_file_toc_append_names_to_builder(
-        iree_hal_amdgpu_device_binaries_create(),
-        iree_hal_amdgpu_device_binaries_size(), &builder));
-    IREE_IGNORE_ERROR(iree_string_builder_append_string(
-        &builder, IREE_SV("], supported by agent: [")));
-    IREE_IGNORE_ERROR(iree_hal_amdgpu_agent_available_isas_append_to_builder(
-        libhsa, &available_isas, &builder));
-    IREE_IGNORE_ERROR(
-        iree_string_builder_append_string(&builder, IREE_SV("]")));
-    status = iree_status_annotate_f(status, "%.*s",
-                                    (int)iree_string_builder_size(&builder),
-                                    iree_string_builder_buffer(&builder));
+    iree_status_t annotation_status = iree_string_builder_append_string(
+        &builder, IREE_SV("available in runtime build: ["));
+    if (iree_status_is_ok(annotation_status)) {
+      annotation_status = iree_file_toc_append_names_to_builder(
+          iree_hal_amdgpu_device_binaries_create(),
+          iree_hal_amdgpu_device_binaries_size(), &builder);
+    }
+    if (iree_status_is_ok(annotation_status)) {
+      annotation_status = iree_string_builder_append_string(
+          &builder, IREE_SV("], supported by agent: ["));
+    }
+    if (iree_status_is_ok(annotation_status)) {
+      annotation_status =
+          iree_hal_amdgpu_agent_available_isas_append_to_builder(
+              libhsa, &available_isas, &builder);
+    }
+    if (iree_status_is_ok(annotation_status)) {
+      annotation_status =
+          iree_string_builder_append_string(&builder, IREE_SV("]"));
+    }
+    if (iree_status_is_ok(annotation_status)) {
+      status = iree_status_annotate_f(status, "%.*s",
+                                      (int)iree_string_builder_size(&builder),
+                                      iree_string_builder_buffer(&builder));
+    } else {
+      status = iree_status_join(status, annotation_status);
+    }
     iree_string_builder_deinitialize(&builder);
 #endif  // IREE_STATUS_MODE >= 2
   }
@@ -188,19 +238,26 @@
   // lacking. These may have only been used for HSAIL anyway.
   const char* options = NULL;
 
+  // ROCR's executable loader retains some process-lifetime bookkeeping while
+  // building executable/code-object state. Keep LeakSanitizer focused on
+  // IREE-owned allocations by bracketing those HSA setup calls.
+
   // Bind a code object reader to the memory sourced from our rodata.
   hsa_code_object_reader_t code_object_reader;
-  IREE_RETURN_AND_END_ZONE_IF_ERROR(
-      z0, iree_hsa_code_object_reader_create_from_memory(
-              IREE_LIBHSA(libhsa), file_toc->data, file_toc->size,
-              &code_object_reader));
+  IREE_LEAK_CHECK_DISABLE_PUSH();
+  iree_status_t status = iree_hsa_code_object_reader_create_from_memory(
+      IREE_LIBHSA(libhsa), file_toc->data, file_toc->size, &code_object_reader);
+  IREE_LEAK_CHECK_DISABLE_POP();
+  IREE_RETURN_AND_END_ZONE_IF_ERROR(z0, status);
 
   // Create the executable that will hold all of the loaded code objects.
   // TODO(benvanik): pass profile/rounding mode from queried info.
-  iree_status_t status =
+  IREE_LEAK_CHECK_DISABLE_PUSH();
+  status =
       iree_hsa_executable_create_alt(IREE_LIBHSA(libhsa), HSA_PROFILE_FULL,
                                      HSA_DEFAULT_FLOAT_ROUNDING_MODE_DEFAULT,
                                      options, &out_library->executable);
+  IREE_LEAK_CHECK_DISABLE_POP();
 
   // Load the code object for each agent.
   // Note that we could save off the loaded_code_object per-agent here but then
@@ -210,9 +267,11 @@
   // loaded_code_objects caches the results.
   if (iree_status_is_ok(status)) {
     for (iree_host_size_t i = 0; i < topology->gpu_agent_count; ++i) {
+      IREE_LEAK_CHECK_DISABLE_PUSH();
       status = iree_hsa_executable_load_agent_code_object(
           IREE_LIBHSA(libhsa), out_library->executable, topology->gpu_agents[i],
           code_object_reader, options, NULL);
+      IREE_LEAK_CHECK_DISABLE_POP();
       if (!iree_status_is_ok(status)) break;
     }
   }
@@ -220,13 +279,16 @@
   // Freeze the executable now that loading has completed. Most queries require
   // that the executable be frozen.
   if (iree_status_is_ok(status)) {
+    IREE_LEAK_CHECK_DISABLE_PUSH();
     status = iree_hsa_executable_freeze(IREE_LIBHSA(libhsa),
                                         out_library->executable, options);
+    IREE_LEAK_CHECK_DISABLE_POP();
   }
 
   // Release the reader now that the executable has been fully loaded.
-  IREE_IGNORE_ERROR(iree_hsa_code_object_reader_destroy(IREE_LIBHSA(libhsa),
-                                                        code_object_reader));
+  status =
+      iree_status_join(status, iree_hsa_code_object_reader_destroy(
+                                   IREE_LIBHSA(libhsa), code_object_reader));
 
   if (!iree_status_is_ok(status)) {
     iree_hal_amdgpu_device_library_deinitialize(out_library);
@@ -241,8 +303,8 @@
   IREE_TRACE_ZONE_BEGIN(z0);
 
   if (library->executable.handle) {
-    IREE_IGNORE_ERROR(iree_hsa_executable_destroy(IREE_LIBHSA(library->libhsa),
-                                                  library->executable));
+    iree_hal_amdgpu_hsa_cleanup_assert_success(
+        iree_hsa_executable_destroy_raw(library->libhsa, library->executable));
   }
 
   memset(library, 0, sizeof(*library));
@@ -424,14 +486,6 @@
               HSA_EXECUTABLE_SYMBOL_INFO_KERNEL_KERNARG_SEGMENT_ALIGNMENT,
               &out_kernel_args->kernarg_alignment));
 
-#if IREE_TRACING_FEATURES & IREE_TRACING_FEATURE_INSTRUMENTATION_DEVICE
-  // TODO(benvanik): intern an export_loc? We don't have a Tracy API for this
-  // yet and our option is to leak the value unconditionally.
-  out_kernel_args->trace_src_loc = 0;
-#else
-  out_kernel_args->trace_src_loc = 0;
-#endif  // IREE_TRACING_FEATURES & IREE_TRACING_FEATURE_INSTRUMENTATION_DEVICE
-
   IREE_TRACE_ZONE_END(z0);
   return iree_ok_status();
 }
diff --git a/runtime/src/iree/hal/drivers/amdgpu/util/device_library_target.c b/runtime/src/iree/hal/drivers/amdgpu/util/device_library_target.c
new file mode 100644
index 0000000..02ea84e
--- /dev/null
+++ b/runtime/src/iree/hal/drivers/amdgpu/util/device_library_target.c
@@ -0,0 +1,99 @@
+// Copyright 2026 The IREE Authors
+//
+// Licensed under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+
+#include "iree/hal/drivers/amdgpu/util/device_library_target.h"
+
+#include "iree/hal/drivers/amdgpu/util/target_id.h"
+
+bool iree_hal_amdgpu_device_library_target_matches_file_arch(
+    iree_string_view_t file_arch, iree_string_view_t target) {
+  if (iree_string_view_is_empty(target)) return false;
+  if (!iree_string_view_starts_with(file_arch, target)) {
+    return false;
+  }
+  iree_string_view_t suffix =
+      iree_string_view_remove_prefix(file_arch, target.size);
+  return iree_string_view_is_empty(suffix) ||
+         iree_string_view_starts_with(suffix, IREE_SV("."));
+}
+
+static iree_status_t
+iree_hal_amdgpu_device_library_target_append_unique_candidate(
+    iree_string_view_t target,
+    iree_hal_amdgpu_device_library_target_candidate_list_t* candidates) {
+  if (iree_string_view_is_empty(target)) return iree_ok_status();
+  for (iree_host_size_t i = 0; i < candidates->count; ++i) {
+    if (iree_string_view_equal(target, candidates->values[i].value)) {
+      return iree_ok_status();
+    }
+  }
+  if (candidates->count >= IREE_ARRAYSIZE(candidates->values)) {
+    return iree_make_status(
+        IREE_STATUS_OUT_OF_RANGE,
+        "AMDGPU device library target candidate list capacity %" PRIhsz
+        " exceeded",
+        IREE_ARRAYSIZE(candidates->values));
+  }
+  iree_hal_amdgpu_device_library_target_candidate_t* candidate =
+      &candidates->values[candidates->count];
+  if (target.size >= sizeof(candidate->storage)) {
+    return iree_make_status(
+        IREE_STATUS_OUT_OF_RANGE,
+        "AMDGPU device library target candidate length %" PRIhsz " exceeded",
+        sizeof(candidate->storage) - 1);
+  }
+  memcpy(candidate->storage, target.data, target.size);
+  candidate->storage[target.size] = 0;
+  candidate->value = iree_make_string_view(candidate->storage, target.size);
+  ++candidates->count;
+  return iree_ok_status();
+}
+
+static iree_status_t
+iree_hal_amdgpu_device_library_target_append_target_id_candidate(
+    const iree_hal_amdgpu_target_id_t* target_id,
+    iree_hal_amdgpu_device_library_target_candidate_list_t* candidates) {
+  char target_id_buffer[64] = {0};
+  iree_host_size_t target_id_length = 0;
+  IREE_RETURN_IF_ERROR(
+      iree_hal_amdgpu_target_id_format(target_id, sizeof(target_id_buffer),
+                                       target_id_buffer, &target_id_length));
+  return iree_hal_amdgpu_device_library_target_append_unique_candidate(
+      iree_make_string_view(target_id_buffer, target_id_length), candidates);
+}
+
+iree_status_t iree_hal_amdgpu_device_library_target_candidates_from_isa(
+    iree_string_view_t isa_name,
+    iree_hal_amdgpu_device_library_target_candidate_list_t* out_candidates) {
+  IREE_ASSERT_ARGUMENT(out_candidates);
+  memset(out_candidates, 0, sizeof(*out_candidates));
+
+  iree_hal_amdgpu_target_id_t agent_target_id;
+  IREE_RETURN_IF_ERROR(
+      iree_hal_amdgpu_target_id_parse_hsa_isa_name(isa_name, &agent_target_id));
+
+  // Try the most specific runtime binary names first. Direct arch names beat
+  // code-object target fallbacks because a concrete code object is preferable
+  // to a family-generic code object when both are bundled into the runtime.
+  IREE_RETURN_IF_ERROR(
+      iree_hal_amdgpu_device_library_target_append_target_id_candidate(
+          &agent_target_id, out_candidates));
+  IREE_RETURN_IF_ERROR(
+      iree_hal_amdgpu_device_library_target_append_unique_candidate(
+          agent_target_id.processor, out_candidates));
+  if (agent_target_id.kind == IREE_HAL_AMDGPU_TARGET_KIND_EXACT) {
+    iree_hal_amdgpu_target_id_t code_object_target_id;
+    IREE_RETURN_IF_ERROR(iree_hal_amdgpu_target_id_lookup_code_object_target(
+        &agent_target_id, &code_object_target_id));
+    IREE_RETURN_IF_ERROR(
+        iree_hal_amdgpu_device_library_target_append_target_id_candidate(
+            &code_object_target_id, out_candidates));
+    IREE_RETURN_IF_ERROR(
+        iree_hal_amdgpu_device_library_target_append_unique_candidate(
+            code_object_target_id.processor, out_candidates));
+  }
+  return iree_ok_status();
+}
diff --git a/runtime/src/iree/hal/drivers/amdgpu/util/device_library_target.h b/runtime/src/iree/hal/drivers/amdgpu/util/device_library_target.h
new file mode 100644
index 0000000..216d0e2
--- /dev/null
+++ b/runtime/src/iree/hal/drivers/amdgpu/util/device_library_target.h
@@ -0,0 +1,50 @@
+// Copyright 2026 The IREE Authors
+//
+// Licensed under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+
+#ifndef IREE_HAL_DRIVERS_AMDGPU_UTIL_DEVICE_LIBRARY_TARGET_H_
+#define IREE_HAL_DRIVERS_AMDGPU_UTIL_DEVICE_LIBRARY_TARGET_H_
+
+#include "iree/base/api.h"
+
+#ifdef __cplusplus
+extern "C" {
+#endif  // __cplusplus
+
+//===----------------------------------------------------------------------===//
+// AMDGPU Device-Library Targets
+//===----------------------------------------------------------------------===//
+
+// Candidate embedded device-library target string.
+typedef struct iree_hal_amdgpu_device_library_target_candidate_t {
+  // NUL-terminated target string storage.
+  char storage[64];
+  // Candidate target view pointing into |storage|.
+  iree_string_view_t value;
+} iree_hal_amdgpu_device_library_target_candidate_t;
+
+// Device-library target candidates in descending specificity order.
+typedef struct iree_hal_amdgpu_device_library_target_candidate_list_t {
+  // Number of populated candidate entries.
+  iree_host_size_t count;
+  // Candidate entries from most-specific to least-specific.
+  iree_hal_amdgpu_device_library_target_candidate_t values[4];
+} iree_hal_amdgpu_device_library_target_candidate_list_t;
+
+// Returns true when the embedded file architecture segment |file_arch| matches
+// |target| exactly, either alone or followed by a dot-separated binary suffix
+// (e.g. ".so" or ".debug.so").
+bool iree_hal_amdgpu_device_library_target_matches_file_arch(
+    iree_string_view_t file_arch, iree_string_view_t target);
+
+// Builds ordered device-library target candidates for an HSA ISA name.
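+//
+// Usage sketch (expected candidate order mirrors device_library_target_test):
+//   iree_hal_amdgpu_device_library_target_candidate_list_t candidates;
+//   IREE_RETURN_IF_ERROR(
+//       iree_hal_amdgpu_device_library_target_candidates_from_isa(
+//           IREE_SV("amdgcn-amd-amdhsa--gfx942:sramecc+:xnack-"), &candidates));
+//   // -> "gfx942:sramecc+:xnack-", "gfx942",
+//   //    "gfx9-4-generic:sramecc+:xnack-", "gfx9-4-generic"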
+iree_status_t iree_hal_amdgpu_device_library_target_candidates_from_isa(
+    iree_string_view_t isa_name,
+    iree_hal_amdgpu_device_library_target_candidate_list_t* out_candidates);
+
+#ifdef __cplusplus
+}  // extern "C"
+#endif  // __cplusplus
+
+#endif  // IREE_HAL_DRIVERS_AMDGPU_UTIL_DEVICE_LIBRARY_TARGET_H_
diff --git a/runtime/src/iree/hal/drivers/amdgpu/util/device_library_target_test.cc b/runtime/src/iree/hal/drivers/amdgpu/util/device_library_target_test.cc
new file mode 100644
index 0000000..61feadb
--- /dev/null
+++ b/runtime/src/iree/hal/drivers/amdgpu/util/device_library_target_test.cc
@@ -0,0 +1,67 @@
+// Copyright 2026 The IREE Authors
+//
+// Licensed under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+
+#include "iree/hal/drivers/amdgpu/util/device_library_target.h"
+
+#include <string>
+#include <vector>
+
+#include "iree/base/api.h"
+#include "iree/testing/gtest.h"
+#include "iree/testing/status_matchers.h"
+
+namespace iree::hal::amdgpu {
+namespace {
+
+static std::vector<std::string> CandidateValues(const char* isa_name) {
+  iree_hal_amdgpu_device_library_target_candidate_list_t candidates = {0};
+  IREE_CHECK_OK(iree_hal_amdgpu_device_library_target_candidates_from_isa(
+      iree_make_cstring_view(isa_name), &candidates));
+  std::vector<std::string> values;
+  values.reserve(candidates.count);
+  for (iree_host_size_t i = 0; i < candidates.count; ++i) {
+    values.emplace_back(candidates.values[i].value.data,
+                        candidates.values[i].value.size);
+  }
+  return values;
+}
+
+TEST(DeviceLibraryTargetTest,
+     PreservesFeatureBearingCandidatesBeforeFallbacks) {
+  const auto values =
+      CandidateValues("amdgcn-amd-amdhsa--gfx942:sramecc+:xnack-");
+
+  ASSERT_EQ(values.size(), 4u);
+  EXPECT_EQ(values[0], "gfx942:sramecc+:xnack-");
+  EXPECT_EQ(values[1], "gfx942");
+  EXPECT_EQ(values[2], "gfx9-4-generic:sramecc+:xnack-");
+  EXPECT_EQ(values[3], "gfx9-4-generic");
+}
+
+TEST(DeviceLibraryTargetTest, MapsGfx12_5TargetsToGenericFamily) {
+  const auto values = CandidateValues("amdgcn-amd-amdhsa--gfx1250");
+
+  ASSERT_EQ(values.size(), 2u);
+  EXPECT_EQ(values[0], "gfx1250");
+  EXPECT_EQ(values[1], "gfx12-5-generic");
+}
+
+TEST(DeviceLibraryTargetTest, MatchesOnlyWholeFileArchSegments) {
+  EXPECT_TRUE(iree_hal_amdgpu_device_library_target_matches_file_arch(
+      IREE_SV("gfx9-4-generic.so"), IREE_SV("gfx9-4-generic")));
+  EXPECT_TRUE(iree_hal_amdgpu_device_library_target_matches_file_arch(
+      IREE_SV("gfx942.debug.so"), IREE_SV("gfx942")));
+
+  EXPECT_FALSE(iree_hal_amdgpu_device_library_target_matches_file_arch(
+      IREE_SV("gfx942x.so"), IREE_SV("gfx942")));
+  EXPECT_FALSE(iree_hal_amdgpu_device_library_target_matches_file_arch(
+      IREE_SV("gfx9-4-generic.so"), IREE_SV("gfx9-4-generic:sramecc+:xnack-")));
+  EXPECT_FALSE(iree_hal_amdgpu_device_library_target_matches_file_arch(
+      IREE_SV("gfx942.so"), IREE_SV("")));
+}
+
+}  // namespace
+}  // namespace iree::hal::amdgpu
diff --git a/runtime/src/iree/hal/drivers/amdgpu/util/device_library_test.cc b/runtime/src/iree/hal/drivers/amdgpu/util/device_library_test.cc
index 6bca475..68d5946 100644
--- a/runtime/src/iree/hal/drivers/amdgpu/util/device_library_test.cc
+++ b/runtime/src/iree/hal/drivers/amdgpu/util/device_library_test.cc
@@ -28,7 +28,7 @@
         host_allocator, &libhsa);
     if (!iree_status_is_ok(status)) {
       iree_status_fprint(stderr, status);
-      iree_status_ignore(status);
+      iree_status_free(status);
       GTEST_SKIP() << "HSA not available, skipping tests";
     }
     IREE_ASSERT_OK(
diff --git a/runtime/src/iree/hal/drivers/amdgpu/util/epoch_signal_table.h b/runtime/src/iree/hal/drivers/amdgpu/util/epoch_signal_table.h
new file mode 100644
index 0000000..d1a0a8f
--- /dev/null
+++ b/runtime/src/iree/hal/drivers/amdgpu/util/epoch_signal_table.h
@@ -0,0 +1,161 @@
+// Copyright 2026 The IREE Authors
+//
+// Licensed under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+
+#ifndef IREE_HAL_DRIVERS_AMDGPU_UTIL_EPOCH_SIGNAL_TABLE_H_
+#define IREE_HAL_DRIVERS_AMDGPU_UTIL_EPOCH_SIGNAL_TABLE_H_
+
+#include <string.h>
+
+#include "iree/async/frontier.h"
+#include "iree/base/api.h"
+#include "iree/hal/drivers/amdgpu/util/libhsa.h"
+
+#ifdef __cplusplus
+extern "C" {
+#endif  // __cplusplus
+
+//===----------------------------------------------------------------------===//
+// iree_hal_amdgpu_epoch_signal_table_t
+//===----------------------------------------------------------------------===//
+
+// Flat lookup table mapping (device_index, queue_index) to the hsa_signal_t
+// epoch signal for that queue. Shared read-only across all queues on the same
+// machine during normal operation; mutated only during queue init/deinit.
+//
+// The epoch signal is the single hsa_signal_t that the CP decrements on each
+// AQL packet completion. It is the mechanism by which tier 2 (device-side)
+// cross-queue waits work: a queue waiting on a peer emits an AQL barrier-value
+// packet referencing the peer's epoch signal with a condition that fires when
+// the peer's epoch reaches the required value.
+//
+// For producer-frontier-exact cross-queue waits, the submission path reads the
+// semaphore's last_signal cache to identify the producer axis/epoch directly,
+// then does one lookup here to map that producer axis to an hsa_signal_t for a
+// single barrier-value packet. For multi-dependency cases, TP collective joins
+// can still require N lookups for N undominated peer axes discovered from the
+// semaphore frontier.
+//
+// The table is allocated once at device group init, sized from the topology
+// (device_count * queue_stride). Each queue registers its epoch signal during
+// init and deregisters during deinit. Lookup verifies the axis's session epoch
+// and machine index match this table's — axes from other sessions or machines
+// fail the lookup (tier 3 fallback).
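+//
+// Tier 2 usage sketch (names illustrative):
+//   hsa_signal_t peer_epoch_signal;
+//   if (iree_hal_amdgpu_epoch_signal_table_lookup(table, producer_axis,
+//                                                 &peer_epoch_signal)) {
+//     // Emit an AQL barrier-value packet against |peer_epoch_signal|.
+//   } else {
+//     // Fall back to tier 3 host deferral.
+//   }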
+typedef struct iree_hal_amdgpu_epoch_signal_table_t {
+  // Session epoch from the axis encoding. Used to verify that a lookup axis
+  // belongs to the same session as this table. Prevents cross-session aliasing
+  // if axes from different sessions happen to share device/queue indices.
+  uint8_t session_epoch;
+  // Machine index from the axis encoding. Used to verify that a lookup axis
+  // belongs to the same machine. Cross-machine waits use tier 3 (host deferral)
+  // since there is no shared HSA signal.
+  uint8_t machine_index;
+  // Maximum queues per device (uniform across all devices in the topology).
+  // Columns in the 2D array: signals[device * queue_stride + queue].
+  uint8_t queue_stride;
+  // Number of devices in the table. Rows in the 2D array.
+  uint8_t device_count;
+  uint8_t reserved[4];
+  // Flat 2D array of epoch signals indexed by [device_index * queue_stride +
+  // queue_index]. Unregistered slots have handle == 0 (null signal). Registered
+  // slots contain the epoch signal from the queue's notification ring.
+  hsa_signal_t signals[];
+} iree_hal_amdgpu_epoch_signal_table_t;
+
+// Returns the total allocation size in bytes for an epoch signal table with
+// the given dimensions.
+static inline iree_host_size_t iree_hal_amdgpu_epoch_signal_table_size(
+    uint8_t device_count, uint8_t queue_stride) {
+  // uint8_t * uint8_t cannot overflow iree_host_size_t.
+  return sizeof(iree_hal_amdgpu_epoch_signal_table_t) +
+         (iree_host_size_t)device_count * queue_stride * sizeof(hsa_signal_t);
+}
+
+// Initializes an epoch signal table in caller-provided memory. The caller
+// must have allocated at least iree_hal_amdgpu_epoch_signal_table_size()
+// bytes. All signal slots are zeroed (unregistered).
+static inline void iree_hal_amdgpu_epoch_signal_table_initialize(
+    iree_hal_amdgpu_epoch_signal_table_t* table, uint8_t session_epoch,
+    uint8_t machine_index, uint8_t device_count, uint8_t queue_stride) {
+  table->session_epoch = session_epoch;
+  table->machine_index = machine_index;
+  table->queue_stride = queue_stride;
+  table->device_count = device_count;
+  memset(table->reserved, 0, sizeof(table->reserved));
+  memset(table->signals, 0,
+         (iree_host_size_t)device_count * queue_stride * sizeof(hsa_signal_t));
+}
+
+// Registers a queue's epoch signal in the table. Called during queue init
+// after the notification ring (which owns the epoch signal) is created.
+//
+// The slot must not already be registered (programming error if it is).
+static inline void iree_hal_amdgpu_epoch_signal_table_register(
+    iree_hal_amdgpu_epoch_signal_table_t* table, uint8_t device_index,
+    uint8_t queue_index, hsa_signal_t epoch_signal) {
+  IREE_ASSERT(device_index < table->device_count, "device_index out of range");
+  IREE_ASSERT(queue_index < table->queue_stride, "queue_index out of range");
+  iree_host_size_t slot =
+      (iree_host_size_t)device_index * table->queue_stride + queue_index;
+  IREE_ASSERT(table->signals[slot].handle == 0,
+              "epoch signal slot already registered");
+  IREE_ASSERT(epoch_signal.handle != 0, "cannot register null epoch signal");
+  table->signals[slot] = epoch_signal;
+}
+
+// Deregisters a queue's epoch signal from the table. Called during queue
+// deinit before the notification ring (which owns the epoch signal) is
+// destroyed. The slot must currently be registered.
+static inline void iree_hal_amdgpu_epoch_signal_table_deregister(
+    iree_hal_amdgpu_epoch_signal_table_t* table, uint8_t device_index,
+    uint8_t queue_index) {
+  IREE_ASSERT(device_index < table->device_count, "device_index out of range");
+  IREE_ASSERT(queue_index < table->queue_stride, "queue_index out of range");
+  iree_host_size_t slot =
+      (iree_host_size_t)device_index * table->queue_stride + queue_index;
+  IREE_ASSERT(table->signals[slot].handle != 0,
+              "epoch signal slot not registered");
+  table->signals[slot].handle = 0;
+}
+
+// Looks up the epoch signal for the queue identified by |axis|. Returns true
+// and writes the signal to |out_signal| if the axis matches this table's
+// session/machine, is a QUEUE-domain axis, is within bounds, and the slot
+// is registered. Returns false otherwise (caller should fall back to tier 3).
+//
+// This is the hot-path lookup for tier 2 barrier emission. Two byte
+// comparisons (session + machine), one domain check, two bounds checks,
+// one array index. ~15 instructions.
+static inline bool iree_hal_amdgpu_epoch_signal_table_lookup(
+    const iree_hal_amdgpu_epoch_signal_table_t* table, iree_async_axis_t axis,
+    hsa_signal_t* out_signal) {
+  // Verify this axis is from our session and machine.
+  if (iree_async_axis_session(axis) != table->session_epoch ||
+      iree_async_axis_machine(axis) != table->machine_index) {
+    return false;
+  }
+  // Must be a QUEUE-domain axis (not collective, host, etc.).
+  if (iree_async_axis_domain(axis) != IREE_ASYNC_CAUSAL_DOMAIN_QUEUE) {
+    return false;
+  }
+  uint8_t device_index = iree_async_axis_device_index(axis);
+  uint8_t queue_index = iree_async_axis_queue_index(axis);
+  if (device_index >= table->device_count ||
+      queue_index >= table->queue_stride) {
+    return false;
+  }
+  hsa_signal_t signal =
+      table->signals[(iree_host_size_t)device_index * table->queue_stride +
+                     queue_index];
+  if (signal.handle == 0) return false;  // Slot not registered.
+  *out_signal = signal;
+  return true;
+}
+
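+// Usage sketch (illustrative: allocation and error handling are elided and
+// |epoch_signal| stands in for a signal owned by a queue's notification
+// ring):
+//
+//   iree_hal_amdgpu_epoch_signal_table_t* table =
+//       ...;  // iree_hal_amdgpu_epoch_signal_table_size(2, 2) bytes
+//   iree_hal_amdgpu_epoch_signal_table_initialize(
+//       table, /*session_epoch=*/1, /*machine_index=*/0,
+//       /*device_count=*/2, /*queue_stride=*/2);
+//
+//   // Queue init: publish this queue's epoch signal.
+//   iree_hal_amdgpu_epoch_signal_table_register(
+//       table, /*device_index=*/0, /*queue_index=*/1, epoch_signal);
+//
+//   // Tier 2 barrier emission: resolve a peer axis to its HSA signal.
+//   hsa_signal_t peer_signal;
+//   if (iree_hal_amdgpu_epoch_signal_table_lookup(
+//           table, iree_async_axis_make_queue(1, 0, 0, 1), &peer_signal)) {
+//     // Emit a barrier-value packet waiting on |peer_signal|.
+//   } else {
+//     // Tier 3: defer the wait to the host.
+//   }
+//
+//   // Queue deinit: unpublish before the notification ring is destroyed.
+//   iree_hal_amdgpu_epoch_signal_table_deregister(table, 0, 1);
+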
+#ifdef __cplusplus
+}  // extern "C"
+#endif  // __cplusplus
+
+#endif  // IREE_HAL_DRIVERS_AMDGPU_UTIL_EPOCH_SIGNAL_TABLE_H_
diff --git a/runtime/src/iree/hal/drivers/amdgpu/util/epoch_signal_table_test.cc b/runtime/src/iree/hal/drivers/amdgpu/util/epoch_signal_table_test.cc
new file mode 100644
index 0000000..94031a3
--- /dev/null
+++ b/runtime/src/iree/hal/drivers/amdgpu/util/epoch_signal_table_test.cc
@@ -0,0 +1,263 @@
+// Copyright 2026 The IREE Authors
+//
+// Licensed under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+
+#include "iree/hal/drivers/amdgpu/util/epoch_signal_table.h"
+
+#include <vector>
+
+#include "iree/testing/gtest.h"
+
+namespace {
+
+// Helper to allocate and initialize a table with the given dimensions.
+class EpochSignalTable {
+ public:
+  EpochSignalTable(uint8_t session_epoch, uint8_t machine_index,
+                   uint8_t device_count, uint8_t queue_stride) {
+    iree_host_size_t size =
+        iree_hal_amdgpu_epoch_signal_table_size(device_count, queue_stride);
+    storage_.resize(size);
+    table_ = reinterpret_cast<iree_hal_amdgpu_epoch_signal_table_t*>(
+        storage_.data());
+    iree_hal_amdgpu_epoch_signal_table_initialize(
+        table_, session_epoch, machine_index, device_count, queue_stride);
+  }
+
+  iree_hal_amdgpu_epoch_signal_table_t* get() { return table_; }
+  const iree_hal_amdgpu_epoch_signal_table_t* get() const { return table_; }
+
+ private:
+  std::vector<uint8_t> storage_;
+  iree_hal_amdgpu_epoch_signal_table_t* table_;
+};
+
+// Makes a fake hsa_signal_t with the given handle value. No HSA runtime needed.
+static hsa_signal_t make_signal(uint64_t handle) {
+  hsa_signal_t signal;
+  signal.handle = handle;
+  return signal;
+}
+
+TEST(EpochSignalTable, SizeComputation) {
+  // 1 device, 1 queue: header + 1 signal.
+  EXPECT_EQ(
+      iree_hal_amdgpu_epoch_signal_table_size(1, 1),
+      sizeof(iree_hal_amdgpu_epoch_signal_table_t) + 1 * sizeof(hsa_signal_t));
+  // 8 devices, 4 queues: header + 32 signals.
+  EXPECT_EQ(
+      iree_hal_amdgpu_epoch_signal_table_size(8, 4),
+      sizeof(iree_hal_amdgpu_epoch_signal_table_t) + 32 * sizeof(hsa_signal_t));
+  // 0 devices: header only (degenerate but valid).
+  EXPECT_EQ(iree_hal_amdgpu_epoch_signal_table_size(0, 4),
+            sizeof(iree_hal_amdgpu_epoch_signal_table_t));
+}
+
+TEST(EpochSignalTable, InitializationZerosAllSlots) {
+  EpochSignalTable table(/*session_epoch=*/1, /*machine_index=*/0,
+                         /*device_count=*/4, /*queue_stride=*/2);
+  // Every slot should be unregistered (handle == 0).
+  for (uint8_t device = 0; device < 4; ++device) {
+    for (uint8_t queue = 0; queue < 2; ++queue) {
+      iree_async_axis_t axis = iree_async_axis_make_queue(1, 0, device, queue);
+      hsa_signal_t signal;
+      EXPECT_FALSE(iree_hal_amdgpu_epoch_signal_table_lookup(table.get(), axis,
+                                                             &signal));
+    }
+  }
+}
+
+TEST(EpochSignalTable, RegisterAndLookup) {
+  EpochSignalTable table(/*session_epoch=*/3, /*machine_index=*/7,
+                         /*device_count=*/2, /*queue_stride=*/2);
+
+  // Register device 0, queue 1.
+  iree_hal_amdgpu_epoch_signal_table_register(table.get(), 0, 1,
+                                              make_signal(42));
+
+  // Lookup should succeed with the correct signal.
+  iree_async_axis_t axis = iree_async_axis_make_queue(3, 7, 0, 1);
+  hsa_signal_t signal;
+  EXPECT_TRUE(
+      iree_hal_amdgpu_epoch_signal_table_lookup(table.get(), axis, &signal));
+  EXPECT_EQ(signal.handle, 42u);
+}
+
+TEST(EpochSignalTable, SessionMismatch) {
+  EpochSignalTable table(/*session_epoch=*/3, /*machine_index=*/7,
+                         /*device_count=*/2, /*queue_stride=*/2);
+  iree_hal_amdgpu_epoch_signal_table_register(table.get(), 0, 0,
+                                              make_signal(100));
+
+  // Same machine, different session.
+  iree_async_axis_t axis = iree_async_axis_make_queue(4, 7, 0, 0);
+  hsa_signal_t signal;
+  EXPECT_FALSE(
+      iree_hal_amdgpu_epoch_signal_table_lookup(table.get(), axis, &signal));
+}
+
+TEST(EpochSignalTable, MachineMismatch) {
+  EpochSignalTable table(/*session_epoch=*/3, /*machine_index=*/7,
+                         /*device_count=*/2, /*queue_stride=*/2);
+  iree_hal_amdgpu_epoch_signal_table_register(table.get(), 0, 0,
+                                              make_signal(100));
+
+  // Same session, different machine.
+  iree_async_axis_t axis = iree_async_axis_make_queue(3, 8, 0, 0);
+  hsa_signal_t signal;
+  EXPECT_FALSE(
+      iree_hal_amdgpu_epoch_signal_table_lookup(table.get(), axis, &signal));
+}
+
+TEST(EpochSignalTable, NonQueueDomainRejected) {
+  EpochSignalTable table(/*session_epoch=*/1, /*machine_index=*/0,
+                         /*device_count=*/4, /*queue_stride=*/4);
+  iree_hal_amdgpu_epoch_signal_table_register(table.get(), 0, 0,
+                                              make_signal(100));
+
+  // COLLECTIVE domain axis with the same ordinal bits.
+  iree_async_axis_t collective_axis =
+      iree_async_axis_make(1, 0, IREE_ASYNC_CAUSAL_DOMAIN_COLLECTIVE, 0);
+  hsa_signal_t signal;
+  EXPECT_FALSE(iree_hal_amdgpu_epoch_signal_table_lookup(
+      table.get(), collective_axis, &signal));
+
+  // HOST domain.
+  iree_async_axis_t host_axis =
+      iree_async_axis_make(1, 0, IREE_ASYNC_CAUSAL_DOMAIN_HOST, 0);
+  EXPECT_FALSE(iree_hal_amdgpu_epoch_signal_table_lookup(table.get(), host_axis,
+                                                         &signal));
+}
+
+TEST(EpochSignalTable, DeviceIndexOutOfBounds) {
+  EpochSignalTable table(/*session_epoch=*/1, /*machine_index=*/0,
+                         /*device_count=*/2, /*queue_stride=*/2);
+
+  // device_index 2 is out of bounds (only 0 and 1 exist).
+  iree_async_axis_t axis = iree_async_axis_make_queue(1, 0, 2, 0);
+  hsa_signal_t signal;
+  EXPECT_FALSE(
+      iree_hal_amdgpu_epoch_signal_table_lookup(table.get(), axis, &signal));
+}
+
+TEST(EpochSignalTable, QueueIndexOutOfBounds) {
+  EpochSignalTable table(/*session_epoch=*/1, /*machine_index=*/0,
+                         /*device_count=*/2, /*queue_stride=*/2);
+
+  // queue_index 2 is out of bounds (stride is 2, so only 0 and 1).
+  iree_async_axis_t axis = iree_async_axis_make_queue(1, 0, 0, 2);
+  hsa_signal_t signal;
+  EXPECT_FALSE(
+      iree_hal_amdgpu_epoch_signal_table_lookup(table.get(), axis, &signal));
+}
+
+TEST(EpochSignalTable, UnregisteredSlotReturnsFalse) {
+  EpochSignalTable table(/*session_epoch=*/1, /*machine_index=*/0,
+                         /*device_count=*/4, /*queue_stride=*/4);
+
+  // Register only slot (1, 2).
+  iree_hal_amdgpu_epoch_signal_table_register(table.get(), 1, 2,
+                                              make_signal(99));
+
+  // Adjacent unregistered slot (1, 3) should fail.
+  iree_async_axis_t axis = iree_async_axis_make_queue(1, 0, 1, 3);
+  hsa_signal_t signal;
+  EXPECT_FALSE(
+      iree_hal_amdgpu_epoch_signal_table_lookup(table.get(), axis, &signal));
+
+  // Registered slot (1, 2) should succeed.
+  axis = iree_async_axis_make_queue(1, 0, 1, 2);
+  EXPECT_TRUE(
+      iree_hal_amdgpu_epoch_signal_table_lookup(table.get(), axis, &signal));
+  EXPECT_EQ(signal.handle, 99u);
+}
+
+TEST(EpochSignalTable, Deregister) {
+  EpochSignalTable table(/*session_epoch=*/1, /*machine_index=*/0,
+                         /*device_count=*/2, /*queue_stride=*/2);
+  iree_hal_amdgpu_epoch_signal_table_register(table.get(), 1, 0,
+                                              make_signal(77));
+
+  // Verify registered.
+  iree_async_axis_t axis = iree_async_axis_make_queue(1, 0, 1, 0);
+  hsa_signal_t signal;
+  EXPECT_TRUE(
+      iree_hal_amdgpu_epoch_signal_table_lookup(table.get(), axis, &signal));
+  EXPECT_EQ(signal.handle, 77u);
+
+  // Deregister.
+  iree_hal_amdgpu_epoch_signal_table_deregister(table.get(), 1, 0);
+
+  // Should no longer be found.
+  EXPECT_FALSE(
+      iree_hal_amdgpu_epoch_signal_table_lookup(table.get(), axis, &signal));
+}
+
+TEST(EpochSignalTable, MultiSlotIndependence) {
+  // 4 devices × 2 queues = 8 slots. Register unique signals in each.
+  EpochSignalTable table(/*session_epoch=*/5, /*machine_index=*/2,
+                         /*device_count=*/4, /*queue_stride=*/2);
+
+  for (uint8_t device = 0; device < 4; ++device) {
+    for (uint8_t queue = 0; queue < 2; ++queue) {
+      uint64_t handle = (uint64_t)device * 100 + queue + 1;
+      iree_hal_amdgpu_epoch_signal_table_register(table.get(), device, queue,
+                                                  make_signal(handle));
+    }
+  }
+
+  // Verify each slot returns its unique signal.
+  for (uint8_t device = 0; device < 4; ++device) {
+    for (uint8_t queue = 0; queue < 2; ++queue) {
+      iree_async_axis_t axis = iree_async_axis_make_queue(5, 2, device, queue);
+      hsa_signal_t signal;
+      ASSERT_TRUE(iree_hal_amdgpu_epoch_signal_table_lookup(table.get(), axis,
+                                                            &signal));
+      uint64_t expected_handle = (uint64_t)device * 100 + queue + 1;
+      EXPECT_EQ(signal.handle, expected_handle);
+    }
+  }
+}
+
+TEST(EpochSignalTable, TPCollectiveJoinPattern) {
+  // Simulates the 4-GPU TP collective join: Q0 needs to look up epoch signals
+  // for Q1, Q2, Q3 to emit barrier-value packets.
+  const uint8_t session = 1;
+  const uint8_t machine = 0;
+  const uint8_t device_count = 4;
+  const uint8_t queues_per_device = 1;
+
+  EpochSignalTable table(session, machine, device_count, queues_per_device);
+
+  // Each device has one queue with a unique epoch signal.
+  for (uint8_t device = 0; device < device_count; ++device) {
+    iree_hal_amdgpu_epoch_signal_table_register(table.get(), device, 0,
+                                                make_signal(1000 + device));
+  }
+
+  // Q0 (device 0) needs to wait on Q1, Q2, Q3. Look up their signals.
+  for (uint8_t peer = 1; peer < device_count; ++peer) {
+    iree_async_axis_t peer_axis =
+        iree_async_axis_make_queue(session, machine, peer, 0);
+    hsa_signal_t peer_signal;
+    ASSERT_TRUE(iree_hal_amdgpu_epoch_signal_table_lookup(
+        table.get(), peer_axis, &peer_signal));
+    EXPECT_EQ(peer_signal.handle, 1000u + peer);
+  }
+
+  // Q0's own signal should also be in the table (for other queues looking it
+  // up), though Q0 wouldn't barrier on itself.
+  iree_async_axis_t self_axis =
+      iree_async_axis_make_queue(session, machine, 0, 0);
+  hsa_signal_t self_signal;
+  ASSERT_TRUE(iree_hal_amdgpu_epoch_signal_table_lookup(table.get(), self_axis,
+                                                        &self_signal));
+  EXPECT_EQ(self_signal.handle, 1000u);
+}
+
+}  // namespace
diff --git a/runtime/src/iree/hal/drivers/amdgpu/util/hsaco_metadata.c b/runtime/src/iree/hal/drivers/amdgpu/util/hsaco_metadata.c
new file mode 100644
index 0000000..7e529ef
--- /dev/null
+++ b/runtime/src/iree/hal/drivers/amdgpu/util/hsaco_metadata.c
@@ -0,0 +1,1411 @@
+// Copyright 2026 The IREE Authors
+//
+// Licensed under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+
+#include "iree/hal/drivers/amdgpu/util/hsaco_metadata.h"
+
+#include <string.h>
+
+//===----------------------------------------------------------------------===//
+// ELF note discovery
+//===----------------------------------------------------------------------===//
+
+#define IREE_HAL_AMDGPU_ELF_MAGIC0 0x7F
+#define IREE_HAL_AMDGPU_ELF_MAGIC1 'E'
+#define IREE_HAL_AMDGPU_ELF_MAGIC2 'L'
+#define IREE_HAL_AMDGPU_ELF_MAGIC3 'F'
+#define IREE_HAL_AMDGPU_ELF_CLASS_64 2
+#define IREE_HAL_AMDGPU_ELF_DATA_LITTLE 1
+#define IREE_HAL_AMDGPU_ELF_VERSION_CURRENT 1
+#define IREE_HAL_AMDGPU_ELF_MACHINE_AMDGPU 224
+#define IREE_HAL_AMDGPU_ELF_PT_NOTE 4
+#define IREE_HAL_AMDGPU_ELF_NOTE_AMDGPU_METADATA 32
+
+#define IREE_HAL_AMDGPU_ELF64_HEADER_SIZE 64
+#define IREE_HAL_AMDGPU_ELF64_PROGRAM_HEADER_SIZE 56
+
+static uint16_t iree_hal_amdgpu_hsaco_metadata_load_le_u16(const uint8_t* ptr) {
+  return iree_unaligned_load_le((const uint16_t*)ptr);
+}
+
+static uint32_t iree_hal_amdgpu_hsaco_metadata_load_le_u32(const uint8_t* ptr) {
+  return iree_unaligned_load_le((const uint32_t*)ptr);
+}
+
+static uint64_t iree_hal_amdgpu_hsaco_metadata_load_le_u64(const uint8_t* ptr) {
+  return iree_unaligned_load_le((const uint64_t*)ptr);
+}
+
+static bool iree_hal_amdgpu_hsaco_metadata_range_in_bounds(
+    iree_host_size_t offset, iree_host_size_t length, iree_host_size_t limit) {
+  iree_host_size_t end = 0;
+  return offset <= limit && iree_host_size_checked_add(offset, length, &end) &&
+         end <= limit;
+}
+
+static bool iree_hal_amdgpu_hsaco_metadata_u64_to_host_size(
+    uint64_t value, iree_host_size_t* out_value) {
+  if (value > (uint64_t)IREE_HOST_SIZE_MAX) return false;
+  *out_value = (iree_host_size_t)value;
+  return true;
+}
+
+static iree_status_t iree_hal_amdgpu_hsaco_metadata_checked_align4(
+    iree_host_size_t value, iree_host_size_t* out_aligned_value) {
+  if (!iree_host_size_checked_align(value, 4, out_aligned_value)) {
+    return iree_make_status(IREE_STATUS_OUT_OF_RANGE,
+                            "AMDGPU ELF note alignment overflow");
+  }
+  return iree_ok_status();
+}
+
+static iree_string_view_t iree_hal_amdgpu_hsaco_metadata_note_name_view(
+    const uint8_t* data, iree_host_size_t length) {
+  if (length > 0 && data[length - 1] == 0) --length;
+  return iree_make_string_view((const char*)data, length);
+}
+
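+// Scans a single PT_NOTE segment for the AMDGPU metadata note. Each note
+// record is a 12-byte header followed by a name and a descriptor, each padded
+// to 4-byte alignment:
+//
+//   u32 name_size | u32 desc_size | u32 type | name[pad4] | desc[pad4]
+//
+// The metadata note has name "AMDGPU" (NUL-terminated) and type
+// NT_AMDGPU_METADATA (32); its descriptor is the MessagePack blob.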
+static iree_status_t iree_hal_amdgpu_hsaco_metadata_scan_note_segment(
+    iree_const_byte_span_t segment_data,
+    iree_const_byte_span_t* out_message_pack_data, bool* out_found) {
+  *out_found = false;
+  iree_host_size_t offset = 0;
+  // |offset| never exceeds data_length (each iteration re-validates bounds
+  // before advancing), so the unsigned subtraction below cannot underflow.
+  while (segment_data.data_length - offset >= 12) {
+    const uint8_t* note_header = segment_data.data + offset;
+    const uint32_t name_size =
+        iree_hal_amdgpu_hsaco_metadata_load_le_u32(note_header + 0);
+    const uint32_t desc_size =
+        iree_hal_amdgpu_hsaco_metadata_load_le_u32(note_header + 4);
+    const uint32_t note_type =
+        iree_hal_amdgpu_hsaco_metadata_load_le_u32(note_header + 8);
+    offset += 12;
+
+    const iree_host_size_t name_offset = offset;
+    if (!iree_hal_amdgpu_hsaco_metadata_range_in_bounds(
+            name_offset, name_size, segment_data.data_length)) {
+      return iree_make_status(IREE_STATUS_INVALID_ARGUMENT,
+                              "AMDGPU ELF note name exceeds PT_NOTE bounds");
+    }
+    iree_host_size_t desc_offset = 0;
+    IREE_RETURN_IF_ERROR(iree_hal_amdgpu_hsaco_metadata_checked_align4(
+        name_offset + name_size, &desc_offset));
+    if (desc_offset > segment_data.data_length) {
+      return iree_make_status(IREE_STATUS_INVALID_ARGUMENT,
+                              "AMDGPU ELF note descriptor offset exceeds "
+                              "PT_NOTE bounds");
+    }
+    if (!iree_hal_amdgpu_hsaco_metadata_range_in_bounds(
+            desc_offset, desc_size, segment_data.data_length)) {
+      return iree_make_status(IREE_STATUS_INVALID_ARGUMENT,
+                              "AMDGPU ELF note descriptor exceeds PT_NOTE "
+                              "bounds");
+    }
+    iree_host_size_t next_offset = 0;
+    IREE_RETURN_IF_ERROR(iree_hal_amdgpu_hsaco_metadata_checked_align4(
+        desc_offset + desc_size, &next_offset));
+    if (next_offset > segment_data.data_length) {
+      // Some producers omit final padding from the segment size. The descriptor
+      // itself is still fully present, so allow the final record to end exactly
+      // at the segment end.
+      if (desc_offset + desc_size == segment_data.data_length) {
+        next_offset = segment_data.data_length;
+      } else {
+        return iree_make_status(IREE_STATUS_INVALID_ARGUMENT,
+                                "AMDGPU ELF note padding exceeds PT_NOTE "
+                                "bounds");
+      }
+    }
+
+    iree_string_view_t note_name =
+        iree_hal_amdgpu_hsaco_metadata_note_name_view(
+            segment_data.data + name_offset, name_size);
+    if (note_type == IREE_HAL_AMDGPU_ELF_NOTE_AMDGPU_METADATA &&
+        iree_string_view_equal(note_name, IREE_SV("AMDGPU"))) {
+      *out_message_pack_data =
+          iree_make_const_byte_span(segment_data.data + desc_offset, desc_size);
+      *out_found = true;
+      return iree_ok_status();
+    }
+
+    offset = next_offset;
+  }
+  return iree_ok_status();
+}
+
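+// Finds the AMDGPU metadata note in |elf_data|. Validates the ELF64 header
+// (magic, 64-bit, little-endian, EM_AMDGPU machine) and walks the program
+// header table, scanning each PT_NOTE segment until the metadata note is
+// found. All offsets/sizes are bounds-checked against the file before use
+// since the input is untrusted.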
+static iree_status_t iree_hal_amdgpu_hsaco_metadata_find_note(
+    iree_const_byte_span_t elf_data,
+    iree_const_byte_span_t* out_message_pack_data) {
+  *out_message_pack_data = iree_const_byte_span_empty();
+  if (elf_data.data_length < IREE_HAL_AMDGPU_ELF64_HEADER_SIZE) {
+    return iree_make_status(IREE_STATUS_INVALID_ARGUMENT,
+                            "AMDGPU ELF data too small");
+  }
+  const uint8_t* header = elf_data.data;
+  if (header[0] != IREE_HAL_AMDGPU_ELF_MAGIC0 ||
+      header[1] != IREE_HAL_AMDGPU_ELF_MAGIC1 ||
+      header[2] != IREE_HAL_AMDGPU_ELF_MAGIC2 ||
+      header[3] != IREE_HAL_AMDGPU_ELF_MAGIC3) {
+    return iree_make_status(IREE_STATUS_INVALID_ARGUMENT,
+                            "AMDGPU metadata input is not an ELF file");
+  }
+  if (header[4] != IREE_HAL_AMDGPU_ELF_CLASS_64) {
+    return iree_make_status(IREE_STATUS_INVALID_ARGUMENT,
+                            "AMDGPU metadata ELF must be 64-bit");
+  }
+  if (header[5] != IREE_HAL_AMDGPU_ELF_DATA_LITTLE) {
+    return iree_make_status(IREE_STATUS_INVALID_ARGUMENT,
+                            "AMDGPU metadata ELF must be little-endian");
+  }
+  if (header[6] != IREE_HAL_AMDGPU_ELF_VERSION_CURRENT) {
+    return iree_make_status(IREE_STATUS_INVALID_ARGUMENT,
+                            "AMDGPU metadata ELF has unsupported version");
+  }
+  const uint16_t machine =
+      iree_hal_amdgpu_hsaco_metadata_load_le_u16(header + 18);
+  if (machine != IREE_HAL_AMDGPU_ELF_MACHINE_AMDGPU) {
+    return iree_make_status(IREE_STATUS_INVALID_ARGUMENT,
+                            "AMDGPU metadata ELF has non-AMDGPU machine %u",
+                            machine);
+  }
+
+  iree_host_size_t program_header_offset = 0;
+  if (!iree_hal_amdgpu_hsaco_metadata_u64_to_host_size(
+          iree_hal_amdgpu_hsaco_metadata_load_le_u64(header + 32),
+          &program_header_offset)) {
+    return iree_make_status(IREE_STATUS_OUT_OF_RANGE,
+                            "AMDGPU ELF program header offset overflows host "
+                            "size");
+  }
+  const uint16_t program_header_entry_size =
+      iree_hal_amdgpu_hsaco_metadata_load_le_u16(header + 54);
+  const uint16_t program_header_count =
+      iree_hal_amdgpu_hsaco_metadata_load_le_u16(header + 56);
+  if (program_header_count == 0) {
+    return iree_make_status(IREE_STATUS_NOT_FOUND,
+                            "AMDGPU ELF has no program headers");
+  }
+  if (program_header_entry_size < IREE_HAL_AMDGPU_ELF64_PROGRAM_HEADER_SIZE) {
+    return iree_make_status(IREE_STATUS_INVALID_ARGUMENT,
+                            "AMDGPU ELF program header entries are too small");
+  }
+  iree_host_size_t program_headers_size = 0;
+  if (!iree_host_size_checked_mul(program_header_count,
+                                  program_header_entry_size,
+                                  &program_headers_size) ||
+      !iree_hal_amdgpu_hsaco_metadata_range_in_bounds(
+          program_header_offset, program_headers_size, elf_data.data_length)) {
+    return iree_make_status(IREE_STATUS_INVALID_ARGUMENT,
+                            "AMDGPU ELF program headers exceed file bounds");
+  }
+
+  for (uint16_t i = 0; i < program_header_count; ++i) {
+    const uint8_t* program_header =
+        elf_data.data + program_header_offset +
+        (iree_host_size_t)i * program_header_entry_size;
+    const uint32_t program_header_type =
+        iree_hal_amdgpu_hsaco_metadata_load_le_u32(program_header + 0);
+    if (program_header_type != IREE_HAL_AMDGPU_ELF_PT_NOTE) continue;
+
+    iree_host_size_t note_offset = 0;
+    iree_host_size_t note_size = 0;
+    if (!iree_hal_amdgpu_hsaco_metadata_u64_to_host_size(
+            iree_hal_amdgpu_hsaco_metadata_load_le_u64(program_header + 8),
+            &note_offset) ||
+        !iree_hal_amdgpu_hsaco_metadata_u64_to_host_size(
+            iree_hal_amdgpu_hsaco_metadata_load_le_u64(program_header + 32),
+            &note_size)) {
+      return iree_make_status(IREE_STATUS_OUT_OF_RANGE,
+                              "AMDGPU ELF PT_NOTE range overflows host size");
+    }
+    if (!iree_hal_amdgpu_hsaco_metadata_range_in_bounds(note_offset, note_size,
+                                                        elf_data.data_length)) {
+      return iree_make_status(IREE_STATUS_INVALID_ARGUMENT,
+                              "AMDGPU ELF PT_NOTE exceeds file bounds");
+    }
+    bool found = false;
+    IREE_RETURN_IF_ERROR(iree_hal_amdgpu_hsaco_metadata_scan_note_segment(
+        iree_make_const_byte_span(elf_data.data + note_offset, note_size),
+        out_message_pack_data, &found));
+    if (found) return iree_ok_status();
+  }
+
+  return iree_make_status(IREE_STATUS_NOT_FOUND,
+                          "AMDGPU metadata note not found");
+}
+
+//===----------------------------------------------------------------------===//
+// MessagePack reader
+//===----------------------------------------------------------------------===//
+
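+// Minimal MessagePack reader covering the subset the AMDGPU metadata emitter
+// produces: maps, arrays, strings, and integers, with everything else
+// skippable. Example (standard MessagePack encoding) of the two-entry map
+// {".offset": 0, ".size": 8}:
+//
+//   0x82             fixmap, 2 entries
+//   0xA7 ".offset"   fixstr, 7 bytes
+//   0x00             positive fixint 0
+//   0xA5 ".size"     fixstr, 5 bytes
+//   0x08             positive fixint 8
+//
+// Multi-byte lengths and integers are big-endian per the spec.
+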
+typedef struct iree_hal_amdgpu_msgpack_reader_t {
+  const uint8_t* current;
+  const uint8_t* end;
+} iree_hal_amdgpu_msgpack_reader_t;
+
+static iree_host_size_t iree_hal_amdgpu_msgpack_remaining(
+    const iree_hal_amdgpu_msgpack_reader_t* reader) {
+  return (iree_host_size_t)(reader->end - reader->current);
+}
+
+static iree_status_t iree_hal_amdgpu_msgpack_require(
+    const iree_hal_amdgpu_msgpack_reader_t* reader, iree_host_size_t length) {
+  if (iree_hal_amdgpu_msgpack_remaining(reader) < length) {
+    return iree_make_status(IREE_STATUS_INVALID_ARGUMENT,
+                            "truncated AMDGPU MessagePack metadata");
+  }
+  return iree_ok_status();
+}
+
+static iree_status_t iree_hal_amdgpu_msgpack_read_u8(
+    iree_hal_amdgpu_msgpack_reader_t* reader, uint8_t* out_value) {
+  IREE_RETURN_IF_ERROR(iree_hal_amdgpu_msgpack_require(reader, 1));
+  *out_value = *reader->current++;
+  return iree_ok_status();
+}
+
+static iree_status_t iree_hal_amdgpu_msgpack_read_be_u16(
+    iree_hal_amdgpu_msgpack_reader_t* reader, uint16_t* out_value) {
+  IREE_RETURN_IF_ERROR(iree_hal_amdgpu_msgpack_require(reader, 2));
+  *out_value = ((uint16_t)reader->current[0] << 8) | reader->current[1];
+  reader->current += 2;
+  return iree_ok_status();
+}
+
+static iree_status_t iree_hal_amdgpu_msgpack_read_be_u32(
+    iree_hal_amdgpu_msgpack_reader_t* reader, uint32_t* out_value) {
+  IREE_RETURN_IF_ERROR(iree_hal_amdgpu_msgpack_require(reader, 4));
+  *out_value = ((uint32_t)reader->current[0] << 24) |
+               ((uint32_t)reader->current[1] << 16) |
+               ((uint32_t)reader->current[2] << 8) | reader->current[3];
+  reader->current += 4;
+  return iree_ok_status();
+}
+
+static iree_status_t iree_hal_amdgpu_msgpack_read_be_u64(
+    iree_hal_amdgpu_msgpack_reader_t* reader, uint64_t* out_value) {
+  IREE_RETURN_IF_ERROR(iree_hal_amdgpu_msgpack_require(reader, 8));
+  *out_value = ((uint64_t)reader->current[0] << 56) |
+               ((uint64_t)reader->current[1] << 48) |
+               ((uint64_t)reader->current[2] << 40) |
+               ((uint64_t)reader->current[3] << 32) |
+               ((uint64_t)reader->current[4] << 24) |
+               ((uint64_t)reader->current[5] << 16) |
+               ((uint64_t)reader->current[6] << 8) | reader->current[7];
+  reader->current += 8;
+  return iree_ok_status();
+}
+
+static iree_status_t iree_hal_amdgpu_msgpack_skip_bytes(
+    iree_hal_amdgpu_msgpack_reader_t* reader, iree_host_size_t length) {
+  IREE_RETURN_IF_ERROR(iree_hal_amdgpu_msgpack_require(reader, length));
+  reader->current += length;
+  return iree_ok_status();
+}
+
+static iree_status_t iree_hal_amdgpu_msgpack_read_count_after_tag(
+    iree_hal_amdgpu_msgpack_reader_t* reader, uint8_t tag, uint8_t fix_base,
+    uint8_t type16, uint8_t type32, uint32_t* out_count) {
+  if ((tag & 0xF0u) == fix_base) {
+    *out_count = tag & 0x0Fu;
+    return iree_ok_status();
+  }
+  if (tag == type16) {
+    uint16_t value = 0;
+    IREE_RETURN_IF_ERROR(iree_hal_amdgpu_msgpack_read_be_u16(reader, &value));
+    *out_count = value;
+    return iree_ok_status();
+  }
+  if (tag == type32) {
+    return iree_hal_amdgpu_msgpack_read_be_u32(reader, out_count);
+  }
+  return iree_make_status(IREE_STATUS_INVALID_ARGUMENT,
+                          "unexpected AMDGPU MessagePack container tag 0x%02X",
+                          tag);
+}
+
+static iree_status_t iree_hal_amdgpu_msgpack_read_map_count(
+    iree_hal_amdgpu_msgpack_reader_t* reader, uint32_t* out_count) {
+  uint8_t tag = 0;
+  IREE_RETURN_IF_ERROR(iree_hal_amdgpu_msgpack_read_u8(reader, &tag));
+  return iree_hal_amdgpu_msgpack_read_count_after_tag(reader, tag, 0x80, 0xDE,
+                                                      0xDF, out_count);
+}
+
+static iree_status_t iree_hal_amdgpu_msgpack_read_array_count(
+    iree_hal_amdgpu_msgpack_reader_t* reader, uint32_t* out_count) {
+  uint8_t tag = 0;
+  IREE_RETURN_IF_ERROR(iree_hal_amdgpu_msgpack_read_u8(reader, &tag));
+  return iree_hal_amdgpu_msgpack_read_count_after_tag(reader, tag, 0x90, 0xDC,
+                                                      0xDD, out_count);
+}
+
+static iree_status_t iree_hal_amdgpu_msgpack_read_string_after_tag(
+    iree_hal_amdgpu_msgpack_reader_t* reader, uint8_t tag,
+    iree_string_view_t* out_value) {
+  uint32_t length = 0;
+  if ((tag & 0xE0u) == 0xA0u) {
+    length = tag & 0x1Fu;
+  } else if (tag == 0xD9) {
+    uint8_t value = 0;
+    IREE_RETURN_IF_ERROR(iree_hal_amdgpu_msgpack_read_u8(reader, &value));
+    length = value;
+  } else if (tag == 0xDA) {
+    uint16_t value = 0;
+    IREE_RETURN_IF_ERROR(iree_hal_amdgpu_msgpack_read_be_u16(reader, &value));
+    length = value;
+  } else if (tag == 0xDB) {
+    IREE_RETURN_IF_ERROR(iree_hal_amdgpu_msgpack_read_be_u32(reader, &length));
+  } else {
+    return iree_make_status(IREE_STATUS_INVALID_ARGUMENT,
+                            "expected AMDGPU MessagePack string tag, got "
+                            "0x%02X",
+                            tag);
+  }
+  IREE_RETURN_IF_ERROR(iree_hal_amdgpu_msgpack_require(reader, length));
+  *out_value = iree_make_string_view((const char*)reader->current, length);
+  reader->current += length;
+  return iree_ok_status();
+}
+
+static iree_status_t iree_hal_amdgpu_msgpack_read_string(
+    iree_hal_amdgpu_msgpack_reader_t* reader, iree_string_view_t* out_value) {
+  uint8_t tag = 0;
+  IREE_RETURN_IF_ERROR(iree_hal_amdgpu_msgpack_read_u8(reader, &tag));
+  return iree_hal_amdgpu_msgpack_read_string_after_tag(reader, tag, out_value);
+}
+
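+// Reads any non-negative MessagePack integer: positive fixint, uint8-uint64,
+// or a signed int8-int64 whose value is non-negative (tolerated since
+// encoders may use signed tags for small values).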
+static iree_status_t iree_hal_amdgpu_msgpack_read_uint64(
+    iree_hal_amdgpu_msgpack_reader_t* reader, uint64_t* out_value) {
+  uint8_t tag = 0;
+  IREE_RETURN_IF_ERROR(iree_hal_amdgpu_msgpack_read_u8(reader, &tag));
+  if (tag <= 0x7F) {
+    *out_value = tag;
+    return iree_ok_status();
+  }
+  switch (tag) {
+    case 0xCC: {
+      uint8_t value = 0;
+      IREE_RETURN_IF_ERROR(iree_hal_amdgpu_msgpack_read_u8(reader, &value));
+      *out_value = value;
+      return iree_ok_status();
+    }
+    case 0xCD: {
+      uint16_t value = 0;
+      IREE_RETURN_IF_ERROR(iree_hal_amdgpu_msgpack_read_be_u16(reader, &value));
+      *out_value = value;
+      return iree_ok_status();
+    }
+    case 0xCE: {
+      uint32_t value = 0;
+      IREE_RETURN_IF_ERROR(iree_hal_amdgpu_msgpack_read_be_u32(reader, &value));
+      *out_value = value;
+      return iree_ok_status();
+    }
+    case 0xCF:
+      return iree_hal_amdgpu_msgpack_read_be_u64(reader, out_value);
+    case 0xD0: {
+      uint8_t value = 0;
+      IREE_RETURN_IF_ERROR(iree_hal_amdgpu_msgpack_read_u8(reader, &value));
+      if (value & 0x80u) break;
+      *out_value = value;
+      return iree_ok_status();
+    }
+    case 0xD1: {
+      uint16_t value = 0;
+      IREE_RETURN_IF_ERROR(iree_hal_amdgpu_msgpack_read_be_u16(reader, &value));
+      if (value & 0x8000u) break;
+      *out_value = value;
+      return iree_ok_status();
+    }
+    case 0xD2: {
+      uint32_t value = 0;
+      IREE_RETURN_IF_ERROR(iree_hal_amdgpu_msgpack_read_be_u32(reader, &value));
+      if (value & 0x80000000u) break;
+      *out_value = value;
+      return iree_ok_status();
+    }
+    case 0xD3: {
+      uint64_t value = 0;
+      IREE_RETURN_IF_ERROR(iree_hal_amdgpu_msgpack_read_be_u64(reader, &value));
+      if (value & 0x8000000000000000ull) break;
+      *out_value = value;
+      return iree_ok_status();
+    }
+    default:
+      break;
+  }
+  return iree_make_status(IREE_STATUS_INVALID_ARGUMENT,
+                          "expected non-negative AMDGPU MessagePack integer");
+}
+
+static iree_status_t iree_hal_amdgpu_msgpack_read_uint32(
+    iree_hal_amdgpu_msgpack_reader_t* reader, uint32_t* out_value) {
+  uint64_t value = 0;
+  IREE_RETURN_IF_ERROR(iree_hal_amdgpu_msgpack_read_uint64(reader, &value));
+  if (value > UINT32_MAX) {
+    return iree_make_status(IREE_STATUS_OUT_OF_RANGE,
+                            "AMDGPU metadata integer exceeds uint32_t");
+  }
+  *out_value = (uint32_t)value;
+  return iree_ok_status();
+}
+
+static iree_status_t iree_hal_amdgpu_msgpack_skip(
+    iree_hal_amdgpu_msgpack_reader_t* reader, uint32_t depth) {
+  if (depth > 64) {
+    return iree_make_status(IREE_STATUS_INVALID_ARGUMENT,
+                            "AMDGPU MessagePack metadata is nested too deeply");
+  }
+  uint8_t tag = 0;
+  IREE_RETURN_IF_ERROR(iree_hal_amdgpu_msgpack_read_u8(reader, &tag));
+  if (tag <= 0x7F || tag >= 0xE0 || tag == 0xC0 || tag == 0xC2 || tag == 0xC3) {
+    return iree_ok_status();
+  }
+  if ((tag & 0xE0u) == 0xA0u) {
+    return iree_hal_amdgpu_msgpack_skip_bytes(reader, tag & 0x1Fu);
+  }
+  if ((tag & 0xF0u) == 0x90u) {
+    const uint32_t count = tag & 0x0Fu;
+    for (uint32_t i = 0; i < count; ++i) {
+      IREE_RETURN_IF_ERROR(iree_hal_amdgpu_msgpack_skip(reader, depth + 1));
+    }
+    return iree_ok_status();
+  }
+  if ((tag & 0xF0u) == 0x80u) {
+    const uint32_t count = tag & 0x0Fu;
+    for (uint32_t i = 0; i < count; ++i) {
+      IREE_RETURN_IF_ERROR(iree_hal_amdgpu_msgpack_skip(reader, depth + 1));
+      IREE_RETURN_IF_ERROR(iree_hal_amdgpu_msgpack_skip(reader, depth + 1));
+    }
+    return iree_ok_status();
+  }
+
+  switch (tag) {
+    case 0xC4: {
+      uint8_t length = 0;
+      IREE_RETURN_IF_ERROR(iree_hal_amdgpu_msgpack_read_u8(reader, &length));
+      return iree_hal_amdgpu_msgpack_skip_bytes(reader, length);
+    }
+    case 0xC5: {
+      uint16_t length = 0;
+      IREE_RETURN_IF_ERROR(
+          iree_hal_amdgpu_msgpack_read_be_u16(reader, &length));
+      return iree_hal_amdgpu_msgpack_skip_bytes(reader, length);
+    }
+    case 0xC6: {
+      uint32_t length = 0;
+      IREE_RETURN_IF_ERROR(
+          iree_hal_amdgpu_msgpack_read_be_u32(reader, &length));
+      return iree_hal_amdgpu_msgpack_skip_bytes(reader, length);
+    }
+    case 0xC7: {
+      uint8_t length = 0;
+      IREE_RETURN_IF_ERROR(iree_hal_amdgpu_msgpack_read_u8(reader, &length));
+      IREE_RETURN_IF_ERROR(iree_hal_amdgpu_msgpack_skip_bytes(reader, 1));
+      return iree_hal_amdgpu_msgpack_skip_bytes(reader, length);
+    }
+    case 0xC8: {
+      uint16_t length = 0;
+      IREE_RETURN_IF_ERROR(
+          iree_hal_amdgpu_msgpack_read_be_u16(reader, &length));
+      IREE_RETURN_IF_ERROR(iree_hal_amdgpu_msgpack_skip_bytes(reader, 1));
+      return iree_hal_amdgpu_msgpack_skip_bytes(reader, length);
+    }
+    case 0xC9: {
+      uint32_t length = 0;
+      IREE_RETURN_IF_ERROR(
+          iree_hal_amdgpu_msgpack_read_be_u32(reader, &length));
+      IREE_RETURN_IF_ERROR(iree_hal_amdgpu_msgpack_skip_bytes(reader, 1));
+      return iree_hal_amdgpu_msgpack_skip_bytes(reader, length);
+    }
+    case 0xCA:
+      return iree_hal_amdgpu_msgpack_skip_bytes(reader, 4);
+    case 0xCB:
+      return iree_hal_amdgpu_msgpack_skip_bytes(reader, 8);
+    case 0xCC:
+    case 0xD0:
+      return iree_hal_amdgpu_msgpack_skip_bytes(reader, 1);
+    case 0xCD:
+    case 0xD1:
+      return iree_hal_amdgpu_msgpack_skip_bytes(reader, 2);
+    case 0xCE:
+    case 0xD2:
+      return iree_hal_amdgpu_msgpack_skip_bytes(reader, 4);
+    case 0xCF:
+    case 0xD3:
+      return iree_hal_amdgpu_msgpack_skip_bytes(reader, 8);
+    case 0xD4:
+      return iree_hal_amdgpu_msgpack_skip_bytes(reader, 2);
+    case 0xD5:
+      return iree_hal_amdgpu_msgpack_skip_bytes(reader, 3);
+    case 0xD6:
+      return iree_hal_amdgpu_msgpack_skip_bytes(reader, 5);
+    case 0xD7:
+      return iree_hal_amdgpu_msgpack_skip_bytes(reader, 9);
+    case 0xD8:
+      return iree_hal_amdgpu_msgpack_skip_bytes(reader, 17);
+    case 0xD9: {
+      uint8_t length = 0;
+      IREE_RETURN_IF_ERROR(iree_hal_amdgpu_msgpack_read_u8(reader, &length));
+      return iree_hal_amdgpu_msgpack_skip_bytes(reader, length);
+    }
+    case 0xDA: {
+      uint16_t length = 0;
+      IREE_RETURN_IF_ERROR(
+          iree_hal_amdgpu_msgpack_read_be_u16(reader, &length));
+      return iree_hal_amdgpu_msgpack_skip_bytes(reader, length);
+    }
+    case 0xDB: {
+      uint32_t length = 0;
+      IREE_RETURN_IF_ERROR(
+          iree_hal_amdgpu_msgpack_read_be_u32(reader, &length));
+      return iree_hal_amdgpu_msgpack_skip_bytes(reader, length);
+    }
+    case 0xDC:
+    case 0xDD: {
+      uint32_t count = 0;
+      IREE_RETURN_IF_ERROR(iree_hal_amdgpu_msgpack_read_count_after_tag(
+          reader, tag, 0x90, 0xDC, 0xDD, &count));
+      for (uint32_t i = 0; i < count; ++i) {
+        IREE_RETURN_IF_ERROR(iree_hal_amdgpu_msgpack_skip(reader, depth + 1));
+      }
+      return iree_ok_status();
+    }
+    case 0xDE:
+    case 0xDF: {
+      uint32_t count = 0;
+      IREE_RETURN_IF_ERROR(iree_hal_amdgpu_msgpack_read_count_after_tag(
+          reader, tag, 0x80, 0xDE, 0xDF, &count));
+      for (uint32_t i = 0; i < count; ++i) {
+        IREE_RETURN_IF_ERROR(iree_hal_amdgpu_msgpack_skip(reader, depth + 1));
+        IREE_RETURN_IF_ERROR(iree_hal_amdgpu_msgpack_skip(reader, depth + 1));
+      }
+      return iree_ok_status();
+    }
+    default:
+      return iree_make_status(IREE_STATUS_INVALID_ARGUMENT,
+                              "unsupported AMDGPU MessagePack tag 0x%02X", tag);
+  }
+}
+
+//===----------------------------------------------------------------------===//
+// AMDGPU metadata decoding
+//===----------------------------------------------------------------------===//
+
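+// Decoding is two-pass: a count pass walks the MessagePack tree to size the
+// kernel and argument arrays, then a parse pass fills caller-allocated
+// storage. The "changed between parse passes" checks guard the second walk
+// against disagreeing with the first (corrupt input or a reader bug).
+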
+typedef struct iree_hal_amdgpu_hsaco_metadata_count_t {
+  iree_host_size_t kernel_count;
+  iree_host_size_t arg_count;
+} iree_hal_amdgpu_hsaco_metadata_count_t;
+
+typedef struct iree_hal_amdgpu_hsaco_metadata_kernel_fields_t {
+  bool has_name;
+  bool has_symbol_name;
+  bool has_kernarg_segment_size;
+  bool has_kernarg_segment_alignment;
+  bool has_group_segment_fixed_size;
+  bool has_private_segment_fixed_size;
+  bool has_required_workgroup_size;
+  bool has_args;
+} iree_hal_amdgpu_hsaco_metadata_kernel_fields_t;
+
+typedef struct iree_hal_amdgpu_hsaco_metadata_arg_fields_t {
+  bool has_name;
+  bool has_offset;
+  bool has_size;
+  bool has_value_kind;
+  bool has_address_space;
+  bool has_access;
+  bool has_actual_access;
+  bool has_alignment;
+} iree_hal_amdgpu_hsaco_metadata_arg_fields_t;
+
+static iree_status_t iree_hal_amdgpu_hsaco_metadata_duplicate_field_status(
+    iree_string_view_t key) {
+  return iree_make_status(IREE_STATUS_INVALID_ARGUMENT,
+                          "AMDGPU metadata repeats field `%.*s`", (int)key.size,
+                          key.data);
+}
+
+static iree_hal_amdgpu_hsaco_metadata_arg_kind_t
+iree_hal_amdgpu_hsaco_metadata_classify_arg_kind(
+    iree_string_view_t value_kind) {
+  if (iree_string_view_equal(value_kind, IREE_SV("by_value"))) {
+    return IREE_HAL_AMDGPU_HSACO_METADATA_ARG_KIND_BY_VALUE;
+  }
+  if (iree_string_view_equal(value_kind, IREE_SV("global_buffer"))) {
+    return IREE_HAL_AMDGPU_HSACO_METADATA_ARG_KIND_GLOBAL_BUFFER;
+  }
+  if (iree_string_view_equal(value_kind, IREE_SV("dynamic_shared_pointer"))) {
+    return IREE_HAL_AMDGPU_HSACO_METADATA_ARG_KIND_DYNAMIC_SHARED_POINTER;
+  }
+  if (iree_string_view_equal(value_kind, IREE_SV("image"))) {
+    return IREE_HAL_AMDGPU_HSACO_METADATA_ARG_KIND_IMAGE;
+  }
+  if (iree_string_view_equal(value_kind, IREE_SV("sampler"))) {
+    return IREE_HAL_AMDGPU_HSACO_METADATA_ARG_KIND_SAMPLER;
+  }
+  if (iree_string_view_equal(value_kind, IREE_SV("pipe"))) {
+    return IREE_HAL_AMDGPU_HSACO_METADATA_ARG_KIND_PIPE;
+  }
+  if (iree_string_view_equal(value_kind, IREE_SV("queue"))) {
+    return IREE_HAL_AMDGPU_HSACO_METADATA_ARG_KIND_QUEUE;
+  }
+  if (iree_string_view_equal(value_kind, IREE_SV("hidden_none"))) {
+    return IREE_HAL_AMDGPU_HSACO_METADATA_ARG_KIND_HIDDEN_NONE;
+  }
+  if (iree_string_view_starts_with(value_kind, IREE_SV("hidden_"))) {
+    return IREE_HAL_AMDGPU_HSACO_METADATA_ARG_KIND_HIDDEN;
+  }
+  return IREE_HAL_AMDGPU_HSACO_METADATA_ARG_KIND_UNKNOWN;
+}
+
+static iree_status_t iree_hal_amdgpu_hsaco_metadata_count_kernel_args(
+    iree_hal_amdgpu_msgpack_reader_t* reader,
+    iree_host_size_t* inout_arg_count) {
+  uint32_t field_count = 0;
+  IREE_RETURN_IF_ERROR(
+      iree_hal_amdgpu_msgpack_read_map_count(reader, &field_count));
+  bool has_args = false;
+  for (uint32_t i = 0; i < field_count; ++i) {
+    iree_string_view_t key = iree_string_view_empty();
+    IREE_RETURN_IF_ERROR(iree_hal_amdgpu_msgpack_read_string(reader, &key));
+    if (iree_string_view_equal(key, IREE_SV(".args"))) {
+      if (has_args) {
+        return iree_make_status(IREE_STATUS_INVALID_ARGUMENT,
+                                "AMDGPU kernel metadata repeats `.args`");
+      }
+      has_args = true;
+      uint32_t arg_count = 0;
+      IREE_RETURN_IF_ERROR(
+          iree_hal_amdgpu_msgpack_read_array_count(reader, &arg_count));
+      if (!iree_host_size_checked_add(*inout_arg_count, arg_count,
+                                      inout_arg_count)) {
+        return iree_make_status(IREE_STATUS_OUT_OF_RANGE,
+                                "AMDGPU metadata argument count overflow");
+      }
+      for (uint32_t j = 0; j < arg_count; ++j) {
+        IREE_RETURN_IF_ERROR(iree_hal_amdgpu_msgpack_skip(reader, 0));
+      }
+    } else {
+      IREE_RETURN_IF_ERROR(iree_hal_amdgpu_msgpack_skip(reader, 0));
+    }
+  }
+  if (!has_args) {
+    return iree_make_status(IREE_STATUS_INVALID_ARGUMENT,
+                            "AMDGPU kernel metadata missing `.args`");
+  }
+  return iree_ok_status();
+}
+
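+// Count pass: tallies kernels and total arguments under `amdhsa.kernels`
+// without materializing any metadata, validating container structure as it
+// goes.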
+static iree_status_t iree_hal_amdgpu_hsaco_metadata_count_message_pack(
+    iree_const_byte_span_t message_pack_data,
+    iree_hal_amdgpu_hsaco_metadata_count_t* out_count) {
+  memset(out_count, 0, sizeof(*out_count));
+  iree_hal_amdgpu_msgpack_reader_t reader = {
+      .current = message_pack_data.data,
+      .end = message_pack_data.data + message_pack_data.data_length,
+  };
+  uint32_t root_field_count = 0;
+  IREE_RETURN_IF_ERROR(
+      iree_hal_amdgpu_msgpack_read_map_count(&reader, &root_field_count));
+  bool has_kernels = false;
+  for (uint32_t i = 0; i < root_field_count; ++i) {
+    iree_string_view_t key = iree_string_view_empty();
+    IREE_RETURN_IF_ERROR(iree_hal_amdgpu_msgpack_read_string(&reader, &key));
+    if (iree_string_view_equal(key, IREE_SV("amdhsa.kernels"))) {
+      if (has_kernels) {
+        return iree_make_status(IREE_STATUS_INVALID_ARGUMENT,
+                                "AMDGPU metadata repeats `amdhsa.kernels`");
+      }
+      has_kernels = true;
+      uint32_t kernel_count = 0;
+      IREE_RETURN_IF_ERROR(
+          iree_hal_amdgpu_msgpack_read_array_count(&reader, &kernel_count));
+      out_count->kernel_count = kernel_count;
+      for (uint32_t j = 0; j < kernel_count; ++j) {
+        IREE_RETURN_IF_ERROR(iree_hal_amdgpu_hsaco_metadata_count_kernel_args(
+            &reader, &out_count->arg_count));
+      }
+    } else {
+      IREE_RETURN_IF_ERROR(iree_hal_amdgpu_msgpack_skip(&reader, 0));
+    }
+  }
+  if (!has_kernels) {
+    return iree_make_status(IREE_STATUS_INVALID_ARGUMENT,
+                            "AMDGPU metadata missing `amdhsa.kernels`");
+  }
+  if (reader.current != reader.end) {
+    return iree_make_status(IREE_STATUS_INVALID_ARGUMENT,
+                            "AMDGPU metadata has trailing MessagePack bytes");
+  }
+  return iree_ok_status();
+}
+
+static iree_status_t iree_hal_amdgpu_hsaco_metadata_parse_workgroup_size(
+    iree_hal_amdgpu_msgpack_reader_t* reader, uint32_t out_value[3]) {
+  uint32_t value_count = 0;
+  IREE_RETURN_IF_ERROR(
+      iree_hal_amdgpu_msgpack_read_array_count(reader, &value_count));
+  if (value_count != 3) {
+    return iree_make_status(IREE_STATUS_INVALID_ARGUMENT,
+                            "AMDGPU workgroup size metadata must have three "
+                            "elements");
+  }
+  for (iree_host_size_t i = 0; i < 3; ++i) {
+    IREE_RETURN_IF_ERROR(
+        iree_hal_amdgpu_msgpack_read_uint32(reader, &out_value[i]));
+  }
+  return iree_ok_status();
+}
+
+static iree_status_t iree_hal_amdgpu_hsaco_metadata_parse_arg(
+    iree_hal_amdgpu_msgpack_reader_t* reader,
+    iree_hal_amdgpu_hsaco_metadata_arg_t* out_arg) {
+  memset(out_arg, 0, sizeof(*out_arg));
+  iree_hal_amdgpu_hsaco_metadata_arg_fields_t fields = {0};
+  uint32_t field_count = 0;
+  IREE_RETURN_IF_ERROR(
+      iree_hal_amdgpu_msgpack_read_map_count(reader, &field_count));
+  for (uint32_t i = 0; i < field_count; ++i) {
+    iree_string_view_t key = iree_string_view_empty();
+    IREE_RETURN_IF_ERROR(iree_hal_amdgpu_msgpack_read_string(reader, &key));
+    if (iree_string_view_equal(key, IREE_SV(".name"))) {
+      if (fields.has_name) {
+        return iree_hal_amdgpu_hsaco_metadata_duplicate_field_status(key);
+      }
+      IREE_RETURN_IF_ERROR(
+          iree_hal_amdgpu_msgpack_read_string(reader, &out_arg->name));
+      fields.has_name = true;
+    } else if (iree_string_view_equal(key, IREE_SV(".offset"))) {
+      if (fields.has_offset) {
+        return iree_hal_amdgpu_hsaco_metadata_duplicate_field_status(key);
+      }
+      IREE_RETURN_IF_ERROR(
+          iree_hal_amdgpu_msgpack_read_uint32(reader, &out_arg->offset));
+      fields.has_offset = true;
+    } else if (iree_string_view_equal(key, IREE_SV(".size"))) {
+      if (fields.has_size) {
+        return iree_hal_amdgpu_hsaco_metadata_duplicate_field_status(key);
+      }
+      IREE_RETURN_IF_ERROR(
+          iree_hal_amdgpu_msgpack_read_uint32(reader, &out_arg->size));
+      fields.has_size = true;
+    } else if (iree_string_view_equal(key, IREE_SV(".value_kind"))) {
+      if (fields.has_value_kind) {
+        return iree_hal_amdgpu_hsaco_metadata_duplicate_field_status(key);
+      }
+      IREE_RETURN_IF_ERROR(
+          iree_hal_amdgpu_msgpack_read_string(reader, &out_arg->value_kind));
+      out_arg->kind =
+          iree_hal_amdgpu_hsaco_metadata_classify_arg_kind(out_arg->value_kind);
+      fields.has_value_kind = true;
+    } else if (iree_string_view_equal(key, IREE_SV(".address_space"))) {
+      if (fields.has_address_space) {
+        return iree_hal_amdgpu_hsaco_metadata_duplicate_field_status(key);
+      }
+      IREE_RETURN_IF_ERROR(
+          iree_hal_amdgpu_msgpack_read_string(reader, &out_arg->address_space));
+      fields.has_address_space = true;
+    } else if (iree_string_view_equal(key, IREE_SV(".access"))) {
+      if (fields.has_access) {
+        return iree_hal_amdgpu_hsaco_metadata_duplicate_field_status(key);
+      }
+      iree_string_view_t access = iree_string_view_empty();
+      IREE_RETURN_IF_ERROR(
+          iree_hal_amdgpu_msgpack_read_string(reader, &access));
+      // `.actual_access` takes precedence over `.access` regardless of the
+      // order in which the two fields appear in the map.
+      if (!fields.has_actual_access) out_arg->access = access;
+      fields.has_access = true;
+    } else if (iree_string_view_equal(key, IREE_SV(".actual_access"))) {
+      if (fields.has_actual_access) {
+        return iree_hal_amdgpu_hsaco_metadata_duplicate_field_status(key);
+      }
+      IREE_RETURN_IF_ERROR(
+          iree_hal_amdgpu_msgpack_read_string(reader, &out_arg->access));
+      fields.has_actual_access = true;
+    } else if (iree_string_view_equal(key, IREE_SV(".align")) ||
+               iree_string_view_equal(key, IREE_SV(".alignment"))) {
+      if (fields.has_alignment) {
+        return iree_hal_amdgpu_hsaco_metadata_duplicate_field_status(key);
+      }
+      IREE_RETURN_IF_ERROR(
+          iree_hal_amdgpu_msgpack_read_uint32(reader, &out_arg->alignment));
+      fields.has_alignment = true;
+    } else {
+      IREE_RETURN_IF_ERROR(iree_hal_amdgpu_msgpack_skip(reader, 0));
+    }
+  }
+  if (!fields.has_offset || !fields.has_size || !fields.has_value_kind) {
+    return iree_make_status(IREE_STATUS_INVALID_ARGUMENT,
+                            "AMDGPU kernel argument metadata missing required "
+                            "offset, size, or value_kind");
+  }
+  if (out_arg->alignment != 0 &&
+      !iree_host_size_is_power_of_two(out_arg->alignment)) {
+    return iree_make_status(IREE_STATUS_INVALID_ARGUMENT,
+                            "AMDGPU kernel argument alignment must be a power "
+                            "of two");
+  }
+  if (out_arg->alignment != 0 &&
+      !iree_host_size_has_alignment(out_arg->offset, out_arg->alignment)) {
+    return iree_make_status(IREE_STATUS_INVALID_ARGUMENT,
+                            "AMDGPU kernel argument offset is not aligned");
+  }
+  return iree_ok_status();
+}
+
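+// Verifies that kernel arguments are sorted by offset, do not overlap, and
+// fit within the declared kernarg segment size.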
+static iree_status_t iree_hal_amdgpu_hsaco_metadata_validate_arg_ranges(
+    const iree_hal_amdgpu_hsaco_metadata_kernel_t* kernel) {
+  iree_host_size_t previous_arg_end = 0;
+  for (iree_host_size_t i = 0; i < kernel->arg_count; ++i) {
+    const iree_hal_amdgpu_hsaco_metadata_arg_t* arg = &kernel->args[i];
+    iree_host_size_t arg_end = 0;
+    if (!iree_host_size_checked_add(arg->offset, arg->size, &arg_end) ||
+        arg_end > kernel->kernarg_segment_size) {
+      return iree_make_status(IREE_STATUS_INVALID_ARGUMENT,
+                              "AMDGPU kernel `%.*s` argument %" PRIhsz
+                              " exceeds kernarg segment size %u",
+                              (int)kernel->symbol_name.size,
+                              kernel->symbol_name.data, i,
+                              kernel->kernarg_segment_size);
+    }
+    if (i > 0 && arg->offset < previous_arg_end) {
+      return iree_make_status(
+          IREE_STATUS_INVALID_ARGUMENT,
+          "AMDGPU kernel `%.*s` argument %" PRIhsz
+          " overlaps a previous argument or is not sorted by offset",
+          (int)kernel->symbol_name.size, kernel->symbol_name.data, i);
+    }
+    previous_arg_end = arg_end;
+  }
+  return iree_ok_status();
+}
+
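+// Picks the name reported through reflection: the explicit `.name` when
+// present, otherwise the `.symbol` with any `.kd` suffix stripped.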
+static iree_string_view_t iree_hal_amdgpu_hsaco_metadata_select_reflection_name(
+    const iree_hal_amdgpu_hsaco_metadata_kernel_t* kernel) {
+  if (!iree_string_view_is_empty(kernel->name)) return kernel->name;
+  return iree_string_view_strip_suffix(kernel->symbol_name, IREE_SV(".kd"));
+}
+
+static iree_status_t iree_hal_amdgpu_hsaco_metadata_record_name_storage(
+    iree_hal_amdgpu_hsaco_metadata_t* metadata,
+    iree_hal_amdgpu_hsaco_metadata_kernel_t* kernel) {
+  kernel->reflection_name =
+      iree_hal_amdgpu_hsaco_metadata_select_reflection_name(kernel);
+  if (!iree_host_size_checked_add(metadata->reflection_name_storage_size,
+                                  kernel->reflection_name.size,
+                                  &metadata->reflection_name_storage_size)) {
+    return iree_make_status(IREE_STATUS_OUT_OF_RANGE,
+                            "AMDGPU metadata reflection name storage overflow");
+  }
+  for (iree_host_size_t i = 0; i < kernel->arg_count; ++i) {
+    const iree_hal_amdgpu_hsaco_metadata_arg_t* arg = &kernel->args[i];
+    if (!iree_host_size_checked_add(kernel->arg_name_storage_size,
+                                    arg->name.size,
+                                    &kernel->arg_name_storage_size) ||
+        !iree_host_size_checked_add(metadata->arg_name_storage_size,
+                                    arg->name.size,
+                                    &metadata->arg_name_storage_size)) {
+      return iree_make_status(IREE_STATUS_OUT_OF_RANGE,
+                              "AMDGPU metadata argument name storage overflow");
+    }
+  }
+  return iree_ok_status();
+}
+
+static iree_status_t iree_hal_amdgpu_hsaco_metadata_parse_kernel(
+    iree_hal_amdgpu_msgpack_reader_t* reader,
+    iree_hal_amdgpu_hsaco_metadata_t* metadata,
+    iree_host_size_t* inout_arg_index,
+    iree_hal_amdgpu_hsaco_metadata_kernel_t* out_kernel) {
+  memset(out_kernel, 0, sizeof(*out_kernel));
+  iree_hal_amdgpu_hsaco_metadata_kernel_fields_t fields = {0};
+  uint32_t field_count = 0;
+  IREE_RETURN_IF_ERROR(
+      iree_hal_amdgpu_msgpack_read_map_count(reader, &field_count));
+  for (uint32_t i = 0; i < field_count; ++i) {
+    iree_string_view_t key = iree_string_view_empty();
+    IREE_RETURN_IF_ERROR(iree_hal_amdgpu_msgpack_read_string(reader, &key));
+    if (iree_string_view_equal(key, IREE_SV(".name"))) {
+      if (fields.has_name) {
+        return iree_hal_amdgpu_hsaco_metadata_duplicate_field_status(key);
+      }
+      IREE_RETURN_IF_ERROR(
+          iree_hal_amdgpu_msgpack_read_string(reader, &out_kernel->name));
+      fields.has_name = true;
+    } else if (iree_string_view_equal(key, IREE_SV(".symbol"))) {
+      if (fields.has_symbol_name) {
+        return iree_hal_amdgpu_hsaco_metadata_duplicate_field_status(key);
+      }
+      IREE_RETURN_IF_ERROR(iree_hal_amdgpu_msgpack_read_string(
+          reader, &out_kernel->symbol_name));
+      fields.has_symbol_name = true;
+    } else if (iree_string_view_equal(key, IREE_SV(".kernarg_segment_size"))) {
+      if (fields.has_kernarg_segment_size) {
+        return iree_hal_amdgpu_hsaco_metadata_duplicate_field_status(key);
+      }
+      IREE_RETURN_IF_ERROR(iree_hal_amdgpu_msgpack_read_uint32(
+          reader, &out_kernel->kernarg_segment_size));
+      fields.has_kernarg_segment_size = true;
+    } else if (iree_string_view_equal(key, IREE_SV(".kernarg_segment_align"))) {
+      if (fields.has_kernarg_segment_alignment) {
+        return iree_hal_amdgpu_hsaco_metadata_duplicate_field_status(key);
+      }
+      IREE_RETURN_IF_ERROR(iree_hal_amdgpu_msgpack_read_uint32(
+          reader, &out_kernel->kernarg_segment_alignment));
+      fields.has_kernarg_segment_alignment = true;
+    } else if (iree_string_view_equal(key,
+                                      IREE_SV(".group_segment_fixed_size"))) {
+      if (fields.has_group_segment_fixed_size) {
+        return iree_hal_amdgpu_hsaco_metadata_duplicate_field_status(key);
+      }
+      IREE_RETURN_IF_ERROR(iree_hal_amdgpu_msgpack_read_uint32(
+          reader, &out_kernel->group_segment_fixed_size));
+      fields.has_group_segment_fixed_size = true;
+    } else if (iree_string_view_equal(key,
+                                      IREE_SV(".private_segment_fixed_size"))) {
+      if (fields.has_private_segment_fixed_size) {
+        return iree_hal_amdgpu_hsaco_metadata_duplicate_field_status(key);
+      }
+      IREE_RETURN_IF_ERROR(iree_hal_amdgpu_msgpack_read_uint32(
+          reader, &out_kernel->private_segment_fixed_size));
+      fields.has_private_segment_fixed_size = true;
+    } else if (iree_string_view_equal(key, IREE_SV(".reqd_workgroup_size"))) {
+      if (fields.has_required_workgroup_size) {
+        return iree_hal_amdgpu_hsaco_metadata_duplicate_field_status(key);
+      }
+      IREE_RETURN_IF_ERROR(iree_hal_amdgpu_hsaco_metadata_parse_workgroup_size(
+          reader, out_kernel->required_workgroup_size));
+      out_kernel->has_required_workgroup_size = true;
+      fields.has_required_workgroup_size = true;
+    } else if (iree_string_view_equal(key, IREE_SV(".args"))) {
+      if (fields.has_args) {
+        return iree_hal_amdgpu_hsaco_metadata_duplicate_field_status(key);
+      }
+      fields.has_args = true;
+      uint32_t arg_count = 0;
+      IREE_RETURN_IF_ERROR(
+          iree_hal_amdgpu_msgpack_read_array_count(reader, &arg_count));
+      if (arg_count > metadata->arg_count ||
+          *inout_arg_index > metadata->arg_count - arg_count) {
+        return iree_make_status(IREE_STATUS_OUT_OF_RANGE,
+                                "AMDGPU metadata argument count changed "
+                                "between parse passes");
+      }
+      out_kernel->arg_count = arg_count;
+      out_kernel->args = arg_count ? &metadata->args[*inout_arg_index] : NULL;
+      for (uint32_t j = 0; j < arg_count; ++j) {
+        IREE_RETURN_IF_ERROR(iree_hal_amdgpu_hsaco_metadata_parse_arg(
+            reader, &metadata->args[*inout_arg_index + j]));
+      }
+      *inout_arg_index += arg_count;
+    } else {
+      IREE_RETURN_IF_ERROR(iree_hal_amdgpu_msgpack_skip(reader, 0));
+    }
+  }
+
+  if (!fields.has_symbol_name || !fields.has_kernarg_segment_size ||
+      !fields.has_kernarg_segment_alignment ||
+      !fields.has_group_segment_fixed_size ||
+      !fields.has_private_segment_fixed_size || !fields.has_args) {
+    return iree_make_status(IREE_STATUS_INVALID_ARGUMENT,
+                            "AMDGPU kernel metadata missing required fields");
+  }
+  if (!iree_host_size_is_power_of_two(out_kernel->kernarg_segment_alignment)) {
+    return iree_make_status(IREE_STATUS_INVALID_ARGUMENT,
+                            "AMDGPU kernel kernarg alignment must be a power "
+                            "of two");
+  }
+  IREE_RETURN_IF_ERROR(
+      iree_hal_amdgpu_hsaco_metadata_validate_arg_ranges(out_kernel));
+  return iree_hal_amdgpu_hsaco_metadata_record_name_storage(metadata,
+                                                            out_kernel);
+}
+
+static iree_status_t iree_hal_amdgpu_hsaco_metadata_parse_message_pack(
+    iree_const_byte_span_t message_pack_data,
+    iree_hal_amdgpu_hsaco_metadata_t* metadata) {
+  iree_hal_amdgpu_msgpack_reader_t reader = {
+      .current = message_pack_data.data,
+      .end = message_pack_data.data + message_pack_data.data_length,
+  };
+  uint32_t root_field_count = 0;
+  IREE_RETURN_IF_ERROR(
+      iree_hal_amdgpu_msgpack_read_map_count(&reader, &root_field_count));
+  bool has_target = false;
+  bool has_kernels = false;
+  iree_host_size_t arg_index = 0;
+  for (uint32_t i = 0; i < root_field_count; ++i) {
+    iree_string_view_t key = iree_string_view_empty();
+    IREE_RETURN_IF_ERROR(iree_hal_amdgpu_msgpack_read_string(&reader, &key));
+    if (iree_string_view_equal(key, IREE_SV("amdhsa.target"))) {
+      if (has_target) {
+        return iree_make_status(IREE_STATUS_INVALID_ARGUMENT,
+                                "AMDGPU metadata repeats `amdhsa.target`");
+      }
+      IREE_RETURN_IF_ERROR(
+          iree_hal_amdgpu_msgpack_read_string(&reader, &metadata->target));
+      has_target = true;
+    } else if (iree_string_view_equal(key, IREE_SV("amdhsa.kernels"))) {
+      if (has_kernels) {
+        return iree_make_status(IREE_STATUS_INVALID_ARGUMENT,
+                                "AMDGPU metadata repeats `amdhsa.kernels`");
+      }
+      has_kernels = true;
+      uint32_t kernel_count = 0;
+      IREE_RETURN_IF_ERROR(
+          iree_hal_amdgpu_msgpack_read_array_count(&reader, &kernel_count));
+      if (kernel_count != metadata->kernel_count) {
+        return iree_make_status(IREE_STATUS_OUT_OF_RANGE,
+                                "AMDGPU metadata kernel count changed between "
+                                "parse passes");
+      }
+      for (uint32_t j = 0; j < kernel_count; ++j) {
+        IREE_RETURN_IF_ERROR(iree_hal_amdgpu_hsaco_metadata_parse_kernel(
+            &reader, metadata, &arg_index, &metadata->kernels[j]));
+      }
+    } else {
+      IREE_RETURN_IF_ERROR(iree_hal_amdgpu_msgpack_skip(&reader, 0));
+    }
+  }
+  if (!has_kernels) {
+    return iree_make_status(IREE_STATUS_INVALID_ARGUMENT,
+                            "AMDGPU metadata missing `amdhsa.kernels`");
+  }
+  if (arg_index != metadata->arg_count) {
+    return iree_make_status(IREE_STATUS_OUT_OF_RANGE,
+                            "AMDGPU metadata argument count changed between "
+                            "parse passes");
+  }
+  if (reader.current != reader.end) {
+    return iree_make_status(IREE_STATUS_INVALID_ARGUMENT,
+                            "AMDGPU metadata has trailing MessagePack bytes");
+  }
+  return iree_ok_status();
+}
+
+static iree_status_t iree_hal_amdgpu_hsaco_metadata_allocate_storage(
+    iree_hal_amdgpu_hsaco_metadata_count_t count,
+    iree_allocator_t host_allocator,
+    iree_hal_amdgpu_hsaco_metadata_t* metadata) {
+  metadata->kernel_count = count.kernel_count;
+  metadata->arg_count = count.arg_count;
+  if (count.kernel_count == 0 && count.arg_count == 0) {
+    return iree_ok_status();
+  }
+  IREE_TRACE_ZONE_BEGIN(z0);
+  IREE_TRACE_ZONE_APPEND_VALUE_I64(z0, count.kernel_count);
+  IREE_TRACE_ZONE_APPEND_VALUE_I64(z0, count.arg_count);
+
+  iree_host_size_t kernels_size = 0;
+  iree_host_size_t args_offset = 0;
+  iree_host_size_t args_size = 0;
+  iree_host_size_t total_size = 0;
+  if (!iree_host_size_checked_mul(
+          count.kernel_count, sizeof(metadata->kernels[0]), &kernels_size) ||
+      !iree_host_size_checked_align(
+          kernels_size, iree_alignof(iree_hal_amdgpu_hsaco_metadata_arg_t),
+          &args_offset) ||
+      !iree_host_size_checked_mul(count.arg_count, sizeof(metadata->args[0]),
+                                  &args_size) ||
+      !iree_host_size_checked_add(args_offset, args_size, &total_size)) {
+    IREE_RETURN_AND_END_ZONE_IF_ERROR(
+        z0, iree_make_status(IREE_STATUS_OUT_OF_RANGE,
+                             "AMDGPU metadata storage size overflow"));
+  }
+  IREE_TRACE_ZONE_APPEND_VALUE_I64(z0, total_size);
+
+  uint8_t* storage = NULL;
+  IREE_RETURN_AND_END_ZONE_IF_ERROR(
+      z0, iree_allocator_malloc(host_allocator, total_size, (void**)&storage));
+  memset(storage, 0, total_size);
+  metadata->kernels = (iree_hal_amdgpu_hsaco_metadata_kernel_t*)storage;
+  metadata->args =
+      count.arg_count
+          ? (iree_hal_amdgpu_hsaco_metadata_arg_t*)(storage + args_offset)
+          : NULL;
+  IREE_TRACE_ZONE_END(z0);
+  return iree_ok_status();
+}
+
+iree_status_t iree_hal_amdgpu_hsaco_metadata_initialize_from_elf(
+    iree_const_byte_span_t elf_data, iree_allocator_t host_allocator,
+    iree_hal_amdgpu_hsaco_metadata_t* out_metadata) {
+  IREE_ASSERT_ARGUMENT(out_metadata);
+  IREE_TRACE_ZONE_BEGIN(z0);
+  IREE_TRACE_ZONE_APPEND_VALUE_I64(z0, elf_data.data_length);
+  memset(out_metadata, 0, sizeof(*out_metadata));
+  out_metadata->host_allocator = host_allocator;
+  out_metadata->elf_data = elf_data;
+
+  iree_status_t status = iree_hal_amdgpu_hsaco_metadata_find_note(
+      elf_data, &out_metadata->message_pack_data);
+
+  iree_hal_amdgpu_hsaco_metadata_count_t count = {0};
+  if (iree_status_is_ok(status)) {
+    status = iree_hal_amdgpu_hsaco_metadata_count_message_pack(
+        out_metadata->message_pack_data, &count);
+  }
+  if (iree_status_is_ok(status)) {
+    status = iree_hal_amdgpu_hsaco_metadata_allocate_storage(
+        count, host_allocator, out_metadata);
+  }
+  if (iree_status_is_ok(status)) {
+    status = iree_hal_amdgpu_hsaco_metadata_parse_message_pack(
+        out_metadata->message_pack_data, out_metadata);
+  }
+  if (!iree_status_is_ok(status)) {
+    iree_hal_amdgpu_hsaco_metadata_deinitialize(out_metadata);
+  }
+  IREE_TRACE_ZONE_END(z0);
+  return status;
+}
+
+void iree_hal_amdgpu_hsaco_metadata_deinitialize(
+    iree_hal_amdgpu_hsaco_metadata_t* metadata) {
+  if (!metadata) {
+    return;
+  }
+  IREE_TRACE_ZONE_BEGIN(z0);
+  if (metadata->kernels) {
+    iree_allocator_free(metadata->host_allocator, metadata->kernels);
+  }
+  memset(metadata, 0, sizeof(*metadata));
+  IREE_TRACE_ZONE_END(z0);
+}
+
+static bool iree_hal_amdgpu_hsaco_metadata_arg_kind_is_hidden(
+    iree_hal_amdgpu_hsaco_metadata_arg_kind_t kind) {
+  return kind == IREE_HAL_AMDGPU_HSACO_METADATA_ARG_KIND_HIDDEN ||
+         kind == IREE_HAL_AMDGPU_HSACO_METADATA_ARG_KIND_HIDDEN_NONE;
+}
+
+static iree_status_t iree_hal_amdgpu_hsaco_metadata_add_parameter_name_size(
+    const iree_hal_amdgpu_hsaco_metadata_arg_t* arg,
+    iree_host_size_t* inout_name_storage_size) {
+  if (!iree_host_size_checked_add(*inout_name_storage_size, arg->name.size,
+                                  inout_name_storage_size)) {
+    return iree_make_status(
+        IREE_STATUS_OUT_OF_RANGE,
+        "AMDGPU metadata reflected parameter name storage overflow");
+  }
+  return iree_ok_status();
+}
+
+iree_status_t
+iree_hal_amdgpu_hsaco_metadata_calculate_default_export_parameter_requirements(
+    const iree_hal_amdgpu_hsaco_metadata_kernel_t* kernel,
+    iree_hal_amdgpu_hsaco_metadata_export_parameter_requirements_t*
+        out_requirements) {
+  IREE_ASSERT_ARGUMENT(kernel);
+  IREE_ASSERT_ARGUMENT(out_requirements);
+  memset(out_requirements, 0, sizeof(*out_requirements));
+
+  iree_host_size_t parameter_count = 0;
+  iree_host_size_t binding_count = 0;
+  iree_host_size_t constant_byte_count = 0;
+  iree_host_size_t name_storage_size = 0;
+  for (iree_host_size_t i = 0; i < kernel->arg_count; ++i) {
+    const iree_hal_amdgpu_hsaco_metadata_arg_t* arg = &kernel->args[i];
+    if (iree_hal_amdgpu_hsaco_metadata_arg_kind_is_hidden(arg->kind)) {
+      continue;
+    }
+    switch (arg->kind) {
+      case IREE_HAL_AMDGPU_HSACO_METADATA_ARG_KIND_GLOBAL_BUFFER:
+        if (arg->size != sizeof(uint64_t)) {
+          return iree_make_status(
+              IREE_STATUS_INVALID_ARGUMENT,
+              "AMDGPU kernel `%.*s` global_buffer argument %" PRIhsz
+              " has unsupported size %u",
+              (int)kernel->symbol_name.size, kernel->symbol_name.data, i,
+              arg->size);
+        }
+        ++parameter_count;
+        ++binding_count;
+        IREE_RETURN_IF_ERROR(
+            iree_hal_amdgpu_hsaco_metadata_add_parameter_name_size(
+                arg, &name_storage_size));
+        break;
+      case IREE_HAL_AMDGPU_HSACO_METADATA_ARG_KIND_BY_VALUE: {
+        if (arg->size > UINT8_MAX) {
+          return iree_make_status(
+              IREE_STATUS_OUT_OF_RANGE,
+              "AMDGPU kernel `%.*s` by_value argument %" PRIhsz
+              " size %u exceeds HAL parameter size range",
+              (int)kernel->symbol_name.size, kernel->symbol_name.data, i,
+              arg->size);
+        }
+        if (!iree_host_size_has_alignment(arg->size, sizeof(uint32_t))) {
+          return iree_make_status(
+              IREE_STATUS_INVALID_ARGUMENT,
+              "AMDGPU kernel `%.*s` by_value argument %" PRIhsz
+              " size %u is not a whole number of 32-bit HAL constants",
+              (int)kernel->symbol_name.size, kernel->symbol_name.data, i,
+              arg->size);
+        }
+        if (constant_byte_count > UINT16_MAX) {
+          return iree_make_status(
+              IREE_STATUS_OUT_OF_RANGE,
+              "AMDGPU kernel `%.*s` by_value argument %" PRIhsz
+              " constant offset exceeds HAL parameter offset range",
+              (int)kernel->symbol_name.size, kernel->symbol_name.data, i);
+        }
+        ++parameter_count;
+        if (!iree_host_size_checked_add(constant_byte_count, arg->size,
+                                        &constant_byte_count)) {
+          return iree_make_status(
+              IREE_STATUS_OUT_OF_RANGE,
+              "AMDGPU metadata reflected constant byte count overflow");
+        }
+        IREE_RETURN_IF_ERROR(
+            iree_hal_amdgpu_hsaco_metadata_add_parameter_name_size(
+                arg, &name_storage_size));
+        break;
+      }
+      default:
+        return iree_make_status(
+            IREE_STATUS_INVALID_ARGUMENT,
+            "AMDGPU kernel `%.*s` argument %" PRIhsz
+            " uses unsupported reflected value_kind `%.*s`",
+            (int)kernel->symbol_name.size, kernel->symbol_name.data, i,
+            (int)arg->value_kind.size, arg->value_kind.data);
+    }
+  }
+
+  iree_host_size_t aligned_constant_byte_count = 0;
+  if (!iree_host_size_checked_align(constant_byte_count, sizeof(uint32_t),
+                                    &aligned_constant_byte_count)) {
+    return iree_make_status(
+        IREE_STATUS_OUT_OF_RANGE,
+        "AMDGPU metadata reflected constant byte count alignment overflow");
+  }
+  const iree_host_size_t constant_count =
+      aligned_constant_byte_count / sizeof(uint32_t);
+  if (parameter_count > UINT16_MAX || binding_count > UINT16_MAX ||
+      constant_count > UINT16_MAX) {
+    return iree_make_status(
+        IREE_STATUS_OUT_OF_RANGE,
+        "AMDGPU metadata reflected parameter counts exceed HAL ranges");
+  }
+
+  out_requirements->parameter_count = (uint16_t)parameter_count;
+  out_requirements->constant_count = (uint16_t)constant_count;
+  out_requirements->binding_count = (uint16_t)binding_count;
+  out_requirements->name_storage_size = name_storage_size;
+  return iree_ok_status();
+}
+
+iree_status_t iree_hal_amdgpu_hsaco_metadata_populate_default_export_parameters(
+    const iree_hal_amdgpu_hsaco_metadata_kernel_t* kernel,
+    iree_host_size_t parameter_capacity,
+    iree_hal_executable_export_parameter_t* out_parameters,
+    iree_host_size_t name_storage_capacity, char* name_storage) {
+  IREE_ASSERT_ARGUMENT(kernel);
+  IREE_ASSERT_ARGUMENT(out_parameters || parameter_capacity == 0);
+  IREE_ASSERT_ARGUMENT(name_storage || name_storage_capacity == 0);
+
+  iree_hal_amdgpu_hsaco_metadata_export_parameter_requirements_t requirements;
+  IREE_RETURN_IF_ERROR(
+      iree_hal_amdgpu_hsaco_metadata_calculate_default_export_parameter_requirements(
+          kernel, &requirements));
+  if (parameter_capacity < requirements.parameter_count ||
+      name_storage_capacity < requirements.name_storage_size) {
+    return iree_make_status(
+        IREE_STATUS_RESOURCE_EXHAUSTED,
+        "AMDGPU metadata export parameter output capacity too small");
+  }
+  if ((requirements.parameter_count != 0 && !out_parameters) ||
+      (requirements.name_storage_size != 0 && !name_storage)) {
+    return iree_make_status(
+        IREE_STATUS_INVALID_ARGUMENT,
+        "AMDGPU metadata export parameter output storage is null");
+  }
+
+  iree_host_size_t parameter_index = 0;
+  iree_host_size_t name_storage_offset = 0;
+  uint16_t binding_ordinal = 0;
+  iree_host_size_t constant_offset = 0;
+  for (iree_host_size_t i = 0; i < kernel->arg_count; ++i) {
+    const iree_hal_amdgpu_hsaco_metadata_arg_t* arg = &kernel->args[i];
+    if (iree_hal_amdgpu_hsaco_metadata_arg_kind_is_hidden(arg->kind)) {
+      continue;
+    }
+
+    iree_hal_executable_export_parameter_t* parameter =
+        &out_parameters[parameter_index++];
+    memset(parameter, 0, sizeof(*parameter));
+    parameter->flags = IREE_HAL_EXECUTABLE_EXPORT_PARAMETER_FLAG_NONE;
+    parameter->size = (uint8_t)arg->size;
+    if (!iree_string_view_is_empty(arg->name)) {
+      memcpy(name_storage + name_storage_offset, arg->name.data,
+             arg->name.size);
+      parameter->name = iree_make_string_view(
+          name_storage + name_storage_offset, arg->name.size);
+      name_storage_offset += arg->name.size;
+    } else {
+      parameter->name = iree_string_view_empty();
+    }
+
+    switch (arg->kind) {
+      case IREE_HAL_AMDGPU_HSACO_METADATA_ARG_KIND_GLOBAL_BUFFER:
+        parameter->type = IREE_HAL_EXECUTABLE_EXPORT_PARAMETER_TYPE_BINDING;
+        parameter->offset = binding_ordinal++;
+        break;
+      case IREE_HAL_AMDGPU_HSACO_METADATA_ARG_KIND_BY_VALUE:
+        parameter->type = IREE_HAL_EXECUTABLE_EXPORT_PARAMETER_TYPE_CONSTANT;
+        parameter->offset = (uint16_t)constant_offset;
+        constant_offset += arg->size;
+        break;
+      default:
+        return iree_make_status(
+            IREE_STATUS_INVALID_ARGUMENT,
+            "AMDGPU kernel `%.*s` argument %" PRIhsz
+            " uses unsupported reflected value_kind `%.*s`",
+            (int)kernel->symbol_name.size, kernel->symbol_name.data, i,
+            (int)arg->value_kind.size, arg->value_kind.data);
+    }
+  }
+  return iree_ok_status();
+}
+
+iree_status_t iree_hal_amdgpu_hsaco_metadata_find_kernel_by_symbol(
+    const iree_hal_amdgpu_hsaco_metadata_t* metadata,
+    iree_string_view_t symbol_name,
+    const iree_hal_amdgpu_hsaco_metadata_kernel_t** out_kernel) {
+  IREE_ASSERT_ARGUMENT(metadata);
+  IREE_ASSERT_ARGUMENT(out_kernel);
+  *out_kernel = NULL;
+  for (iree_host_size_t i = 0; i < metadata->kernel_count; ++i) {
+    const iree_hal_amdgpu_hsaco_metadata_kernel_t* kernel =
+        &metadata->kernels[i];
+    if (!iree_string_view_equal(kernel->symbol_name, symbol_name)) continue;
+    if (*out_kernel) {
+      return iree_make_status(IREE_STATUS_FAILED_PRECONDITION,
+                              "AMDGPU metadata has duplicate kernel symbol "
+                              "`%.*s`",
+                              (int)symbol_name.size, symbol_name.data);
+    }
+    *out_kernel = kernel;
+  }
+  if (!*out_kernel) {
+    return iree_make_status(IREE_STATUS_NOT_FOUND,
+                            "AMDGPU metadata kernel symbol `%.*s` not found",
+                            (int)symbol_name.size, symbol_name.data);
+  }
+  return iree_ok_status();
+}
diff --git a/runtime/src/iree/hal/drivers/amdgpu/util/hsaco_metadata.h b/runtime/src/iree/hal/drivers/amdgpu/util/hsaco_metadata.h
new file mode 100644
index 0000000..06f3c38
--- /dev/null
+++ b/runtime/src/iree/hal/drivers/amdgpu/util/hsaco_metadata.h
@@ -0,0 +1,187 @@
+// Copyright 2026 The IREE Authors
+//
+// Licensed under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+
+#ifndef IREE_HAL_DRIVERS_AMDGPU_UTIL_HSACO_METADATA_H_
+#define IREE_HAL_DRIVERS_AMDGPU_UTIL_HSACO_METADATA_H_
+
+#include "iree/base/api.h"
+#include "iree/hal/api.h"
+
+#ifdef __cplusplus
+extern "C" {
+#endif  // __cplusplus
+
+//===----------------------------------------------------------------------===//
+// AMDGPU Code Object Metadata
+//===----------------------------------------------------------------------===//
+
+// Kernel argument ABI classification from AMDGPU metadata `.value_kind`.
+typedef enum iree_hal_amdgpu_hsaco_metadata_arg_kind_e {
+  // Unknown or unsupported value kind.
+  IREE_HAL_AMDGPU_HSACO_METADATA_ARG_KIND_UNKNOWN = 0,
+  // Argument bytes are copied directly into the kernarg segment.
+  IREE_HAL_AMDGPU_HSACO_METADATA_ARG_KIND_BY_VALUE,
+  // Argument is a pointer to global memory.
+  IREE_HAL_AMDGPU_HSACO_METADATA_ARG_KIND_GLOBAL_BUFFER,
+  // Argument is a pointer to dynamically allocated LDS.
+  IREE_HAL_AMDGPU_HSACO_METADATA_ARG_KIND_DYNAMIC_SHARED_POINTER,
+  // Argument is an image descriptor pointer.
+  IREE_HAL_AMDGPU_HSACO_METADATA_ARG_KIND_IMAGE,
+  // Argument is a sampler descriptor pointer.
+  IREE_HAL_AMDGPU_HSACO_METADATA_ARG_KIND_SAMPLER,
+  // Argument is an OpenCL pipe pointer.
+  IREE_HAL_AMDGPU_HSACO_METADATA_ARG_KIND_PIPE,
+  // Argument is an OpenCL device enqueue queue pointer.
+  IREE_HAL_AMDGPU_HSACO_METADATA_ARG_KIND_QUEUE,
+  // Argument is a hidden ABI/runtime value.
+  IREE_HAL_AMDGPU_HSACO_METADATA_ARG_KIND_HIDDEN,
+  // Argument reserves hidden ABI space but does not need a value.
+  IREE_HAL_AMDGPU_HSACO_METADATA_ARG_KIND_HIDDEN_NONE,
+} iree_hal_amdgpu_hsaco_metadata_arg_kind_t;
+
+// Decoded kernel argument metadata.
+typedef struct iree_hal_amdgpu_hsaco_metadata_arg_t {
+  // Source-level argument name from `.name`, if present.
+  iree_string_view_t name;
+  // Byte offset in the kernel's kernarg segment.
+  uint32_t offset;
+  // Byte length of the kernarg storage for this argument.
+  uint32_t size;
+  // Storage alignment in bytes when explicitly available; otherwise 0.
+  uint32_t alignment;
+  // Parsed classification of |value_kind|.
+  iree_hal_amdgpu_hsaco_metadata_arg_kind_t kind;
+  // Raw `.value_kind` string view borrowed from the metadata blob.
+  iree_string_view_t value_kind;
+  // Raw `.address_space` string view borrowed from the metadata blob, if any.
+  iree_string_view_t address_space;
+  // Effective access string view borrowed from the metadata blob, if any.
+  // `.actual_access` is preferred over `.access` when both are present.
+  iree_string_view_t access;
+} iree_hal_amdgpu_hsaco_metadata_arg_t;
+
+// Decoded per-kernel metadata.
+typedef struct iree_hal_amdgpu_hsaco_metadata_kernel_t {
+  // Source-level kernel name from `.name`, if present.
+  iree_string_view_t name;
+  // Kernel descriptor symbol name from `.symbol`, usually `foo.kd`.
+  iree_string_view_t symbol_name;
+  // Export reflection name derived from `.name` or `.symbol`, with any
+  // trailing `.kd` suffix removed.
+  iree_string_view_t reflection_name;
+  // Bytes required to clone all argument names for this kernel.
+  iree_host_size_t arg_name_storage_size;
+  // Kernel kernarg segment size from `.kernarg_segment_size`.
+  uint32_t kernarg_segment_size;
+  // Kernel kernarg segment alignment from `.kernarg_segment_align`.
+  uint32_t kernarg_segment_alignment;
+  // Fixed group segment size from `.group_segment_fixed_size`.
+  uint32_t group_segment_fixed_size;
+  // Fixed private segment size from `.private_segment_fixed_size`.
+  uint32_t private_segment_fixed_size;
+  // Required workgroup size from `.reqd_workgroup_size`, if present.
+  uint32_t required_workgroup_size[3];
+  // True when |required_workgroup_size| was present.
+  bool has_required_workgroup_size;
+  // Number of argument records in |args|.
+  iree_host_size_t arg_count;
+  // Argument records borrowed from the owning metadata object.
+  const iree_hal_amdgpu_hsaco_metadata_arg_t* args;
+} iree_hal_amdgpu_hsaco_metadata_kernel_t;
+
+// Decoded AMDGPU code object metadata.
+//
+// All string views and |message_pack_data| borrow from |elf_data|. Callers must
+// keep the ELF bytes alive for as long as this metadata object is in use.
+typedef struct iree_hal_amdgpu_hsaco_metadata_t {
+  // Allocator used for kernel and argument arrays.
+  iree_allocator_t host_allocator;
+  // Borrowed ELF bytes used as the source of all string views.
+  iree_const_byte_span_t elf_data;
+  // Borrowed AMDGPU MessagePack note descriptor payload.
+  iree_const_byte_span_t message_pack_data;
+  // Borrowed target ISA string from `amdhsa.target`, if present.
+  iree_string_view_t target;
+  // Bytes required to clone all kernel reflection names.
+  iree_host_size_t reflection_name_storage_size;
+  // Bytes required to clone all decoded argument names.
+  iree_host_size_t arg_name_storage_size;
+  // Number of decoded kernels.
+  iree_host_size_t kernel_count;
+  // Decoded kernel records.
+  iree_hal_amdgpu_hsaco_metadata_kernel_t* kernels;
+  // Total number of decoded argument records.
+  iree_host_size_t arg_count;
+  // Contiguous argument storage referenced by |kernels|.
+  iree_hal_amdgpu_hsaco_metadata_arg_t* args;
+} iree_hal_amdgpu_hsaco_metadata_t;
+
+// Requirements for materializing default HAL export parameter reflection.
+typedef struct iree_hal_amdgpu_hsaco_metadata_export_parameter_requirements_t {
+  // Number of HAL-visible parameters after hidden ABI arguments are skipped.
+  uint16_t parameter_count;
+  // Number of 32-bit constants consumed by reflected by-value parameters.
+  uint16_t constant_count;
+  // Number of HAL bindings consumed by reflected global-buffer parameters.
+  uint16_t binding_count;
+  // Bytes required to clone all reflected parameter names for this kernel.
+  iree_host_size_t name_storage_size;
+} iree_hal_amdgpu_hsaco_metadata_export_parameter_requirements_t;
+
+// Initializes |out_metadata| from a raw AMDGPU ELF code object.
+//
+// This locates the `AMDGPU`/`NT_AMDGPU_METADATA` note and decodes only the
+// fields needed for kernel argument reflection. The parser accepts a normal
+// LLVM-produced 64-bit little-endian AMDGPU ELF. It intentionally does not
+// implement HIP fat binary, clang offload bundle, or compressed code object
+// handling.
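+//
+// Illustrative usage (a sketch, not additional API): the decoded metadata
+// borrows from |elf_data|, so the ELF bytes must stay live until
+// deinitialization:
+//
+//   iree_hal_amdgpu_hsaco_metadata_t metadata;
+//   IREE_RETURN_IF_ERROR(iree_hal_amdgpu_hsaco_metadata_initialize_from_elf(
+//       elf_data, host_allocator, &metadata));
+//   // ... use metadata.kernels[0..kernel_count) while elf_data is live ...
+//   iree_hal_amdgpu_hsaco_metadata_deinitialize(&metadata);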
+iree_status_t iree_hal_amdgpu_hsaco_metadata_initialize_from_elf(
+    iree_const_byte_span_t elf_data, iree_allocator_t host_allocator,
+    iree_hal_amdgpu_hsaco_metadata_t* out_metadata);
+
+// Releases storage owned by |metadata|.
+void iree_hal_amdgpu_hsaco_metadata_deinitialize(
+    iree_hal_amdgpu_hsaco_metadata_t* metadata);
+
+// Calculates the storage and export-info counts required by the default HAL
+// reflection projection for |kernel|.
+//
+// This projection maps `.value_kind == "global_buffer"` arguments to bindings
+// in metadata argument order and `.value_kind == "by_value"` arguments to
+// constants in metadata argument order. Reflected by-value sizes must be a
+// whole number of 32-bit constants. Hidden ABI arguments are skipped. Any
+// other visible argument kind fails with IREE_STATUS_INVALID_ARGUMENT so
+// callers do not accidentally publish partial reflection for unsupported ABIs.
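+//
+// Example (illustrative, mirroring the unit tests): reflected args
+// (global_buffer lhs, global_buffer rhs, by_value uint32_t n, by_value
+// uint32_t alpha) yield parameter_count=4, binding_count=2, constant_count=2,
+// and name_storage_size=12 ("lhs" + "rhs" + "n" + "alpha").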
+iree_status_t
+iree_hal_amdgpu_hsaco_metadata_calculate_default_export_parameter_requirements(
+    const iree_hal_amdgpu_hsaco_metadata_kernel_t* kernel,
+    iree_hal_amdgpu_hsaco_metadata_export_parameter_requirements_t*
+        out_requirements);
+
+// Populates |out_parameters| using the default HAL reflection projection.
+//
+// |parameter_capacity| and |name_storage_capacity| must satisfy the
+// requirements returned by
+// iree_hal_amdgpu_hsaco_metadata_calculate_default_export_parameter_requirements.
+// Reflected parameter names are cloned into |name_storage| and borrowed by the
+// returned parameter records. No NUL terminators are written or required.
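+//
+// Typical two-phase usage (illustrative; storage strategy is the caller's):
+//
+//   iree_hal_amdgpu_hsaco_metadata_export_parameter_requirements_t reqs;
+//   IREE_RETURN_IF_ERROR(
+//       iree_hal_amdgpu_hsaco_metadata_calculate_default_export_parameter_requirements(
+//           kernel, &reqs));
+//   // Allocate reqs.parameter_count records + reqs.name_storage_size chars.
+//   IREE_RETURN_IF_ERROR(
+//       iree_hal_amdgpu_hsaco_metadata_populate_default_export_parameters(
+//           kernel, reqs.parameter_count, parameters, reqs.name_storage_size,
+//           name_storage));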
+iree_status_t iree_hal_amdgpu_hsaco_metadata_populate_default_export_parameters(
+    const iree_hal_amdgpu_hsaco_metadata_kernel_t* kernel,
+    iree_host_size_t parameter_capacity,
+    iree_hal_executable_export_parameter_t* out_parameters,
+    iree_host_size_t name_storage_capacity, char* name_storage);
+
+// Finds a decoded kernel by its descriptor symbol name.
+iree_status_t iree_hal_amdgpu_hsaco_metadata_find_kernel_by_symbol(
+    const iree_hal_amdgpu_hsaco_metadata_t* metadata,
+    iree_string_view_t symbol_name,
+    const iree_hal_amdgpu_hsaco_metadata_kernel_t** out_kernel);
+
+#ifdef __cplusplus
+}  // extern "C"
+#endif  // __cplusplus
+
+#endif  // IREE_HAL_DRIVERS_AMDGPU_UTIL_HSACO_METADATA_H_
diff --git a/runtime/src/iree/hal/drivers/amdgpu/util/hsaco_metadata_fuzz.cc b/runtime/src/iree/hal/drivers/amdgpu/util/hsaco_metadata_fuzz.cc
new file mode 100644
index 0000000..8c955db
--- /dev/null
+++ b/runtime/src/iree/hal/drivers/amdgpu/util/hsaco_metadata_fuzz.cc
@@ -0,0 +1,124 @@
+// Copyright 2026 The IREE Authors
+//
+// Licensed under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+
+#include <stddef.h>
+#include <stdint.h>
+
+#include <vector>
+
+#include "iree/base/api.h"
+#include "iree/hal/drivers/amdgpu/util/hsaco_metadata.h"
+
+static void iree_hal_amdgpu_hsaco_metadata_fuzz_append_u32le(
+    std::vector<uint8_t>* output, uint32_t value) {
+  output->push_back((uint8_t)value);
+  output->push_back((uint8_t)(value >> 8));
+  output->push_back((uint8_t)(value >> 16));
+  output->push_back((uint8_t)(value >> 24));
+}
+
+static void iree_hal_amdgpu_hsaco_metadata_fuzz_store_u16le(
+    std::vector<uint8_t>* output, size_t offset, uint16_t value) {
+  (*output)[offset + 0] = (uint8_t)value;
+  (*output)[offset + 1] = (uint8_t)(value >> 8);
+}
+
+static void iree_hal_amdgpu_hsaco_metadata_fuzz_store_u32le(
+    std::vector<uint8_t>* output, size_t offset, uint32_t value) {
+  (*output)[offset + 0] = (uint8_t)value;
+  (*output)[offset + 1] = (uint8_t)(value >> 8);
+  (*output)[offset + 2] = (uint8_t)(value >> 16);
+  (*output)[offset + 3] = (uint8_t)(value >> 24);
+}
+
+static void iree_hal_amdgpu_hsaco_metadata_fuzz_store_u64le(
+    std::vector<uint8_t>* output, size_t offset, uint64_t value) {
+  for (int i = 0; i < 8; ++i) {
+    (*output)[offset + i] = (uint8_t)(value >> (i * 8));
+  }
+}
+
+static void iree_hal_amdgpu_hsaco_metadata_fuzz_append_aligned4_padding(
+    std::vector<uint8_t>* output) {
+  while ((output->size() & 3) != 0) output->push_back(0);
+}
+
+static std::vector<uint8_t> iree_hal_amdgpu_hsaco_metadata_fuzz_wrap_as_elf(
+    const uint8_t* data, size_t size) {
+  constexpr size_t kElfHeaderSize = 64;
+  constexpr size_t kProgramHeaderOffset = 64;
+  constexpr size_t kProgramHeaderSize = 56;
+  constexpr size_t kNoteOffset = 128;
+  static const uint8_t kNoteName[] = {'A', 'M', 'D', 'G', 'P', 'U'};
+
+  std::vector<uint8_t> note;
+  // Note header: namesz (name + NUL), descsz, then NT_AMDGPU_METADATA (32).
+  iree_hal_amdgpu_hsaco_metadata_fuzz_append_u32le(&note,
+                                                   sizeof(kNoteName) + 1);
+  iree_hal_amdgpu_hsaco_metadata_fuzz_append_u32le(&note, (uint32_t)size);
+  iree_hal_amdgpu_hsaco_metadata_fuzz_append_u32le(&note, 32);
+  note.insert(note.end(), kNoteName, kNoteName + sizeof(kNoteName));
+  note.push_back(0);
+  iree_hal_amdgpu_hsaco_metadata_fuzz_append_aligned4_padding(&note);
+  note.insert(note.end(), data, data + size);
+  iree_hal_amdgpu_hsaco_metadata_fuzz_append_aligned4_padding(&note);
+
+  std::vector<uint8_t> elf(kNoteOffset, 0);
+  elf[0] = 0x7F;
+  elf[1] = 'E';
+  elf[2] = 'L';
+  elf[3] = 'F';
+  elf[4] = 2;
+  elf[5] = 1;
+  elf[6] = 1;
+  iree_hal_amdgpu_hsaco_metadata_fuzz_store_u16le(&elf, 16, 3);    // ET_DYN.
+  iree_hal_amdgpu_hsaco_metadata_fuzz_store_u16le(&elf, 18, 224);  // EM_AMDGPU.
+  iree_hal_amdgpu_hsaco_metadata_fuzz_store_u32le(&elf, 20, 1);    // EV_CURRENT.
+  iree_hal_amdgpu_hsaco_metadata_fuzz_store_u64le(&elf, 32,
+                                                  kProgramHeaderOffset);
+  iree_hal_amdgpu_hsaco_metadata_fuzz_store_u16le(&elf, 52, kElfHeaderSize);
+  iree_hal_amdgpu_hsaco_metadata_fuzz_store_u16le(&elf, 54, kProgramHeaderSize);
+  iree_hal_amdgpu_hsaco_metadata_fuzz_store_u16le(&elf, 56, 1);
+
+  iree_hal_amdgpu_hsaco_metadata_fuzz_store_u32le(&elf, kProgramHeaderOffset,
+                                                  4);  // PT_NOTE.
+  iree_hal_amdgpu_hsaco_metadata_fuzz_store_u64le(
+      &elf, kProgramHeaderOffset + 8, kNoteOffset);
+  iree_hal_amdgpu_hsaco_metadata_fuzz_store_u64le(
+      &elf, kProgramHeaderOffset + 32, note.size());
+  iree_hal_amdgpu_hsaco_metadata_fuzz_store_u64le(
+      &elf, kProgramHeaderOffset + 40, note.size());
+  iree_hal_amdgpu_hsaco_metadata_fuzz_store_u64le(&elf,
+                                                  kProgramHeaderOffset + 48, 4);
+
+  elf.insert(elf.end(), note.begin(), note.end());
+  return elf;
+}
+
+static void iree_hal_amdgpu_hsaco_metadata_fuzz_parse(
+    iree_const_byte_span_t elf_data) {
+  iree_hal_amdgpu_hsaco_metadata_t metadata;
+  iree_status_t status = iree_hal_amdgpu_hsaco_metadata_initialize_from_elf(
+      elf_data, iree_allocator_system(), &metadata);
+  if (iree_status_is_ok(status)) {
+    iree_hal_amdgpu_hsaco_metadata_deinitialize(&metadata);
+  } else {
+    iree_status_free(status);
+  }
+}
+
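+// Runs the parser twice per input: once on the raw bytes so the ELF/note
+// locator sees arbitrary data, and once wrapped in a minimal valid ELF with a
+// PT_NOTE segment so the fuzzed bytes reach the MessagePack parser itself.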
+extern "C" int LLVMFuzzerTestOneInput(const uint8_t* data, size_t size) {
+  constexpr size_t kMaxInputSize = 64 * 1024;
+  if (size > kMaxInputSize) size = kMaxInputSize;
+
+  iree_hal_amdgpu_hsaco_metadata_fuzz_parse(
+      iree_make_const_byte_span(data, size));
+
+  std::vector<uint8_t> elf =
+      iree_hal_amdgpu_hsaco_metadata_fuzz_wrap_as_elf(data, size);
+  iree_hal_amdgpu_hsaco_metadata_fuzz_parse(
+      iree_make_const_byte_span(elf.data(), elf.size()));
+  return 0;
+}
diff --git a/runtime/src/iree/hal/drivers/amdgpu/util/hsaco_metadata_test.cc b/runtime/src/iree/hal/drivers/amdgpu/util/hsaco_metadata_test.cc
new file mode 100644
index 0000000..41e0aa2
--- /dev/null
+++ b/runtime/src/iree/hal/drivers/amdgpu/util/hsaco_metadata_test.cc
@@ -0,0 +1,604 @@
+// Copyright 2026 The IREE Authors
+//
+// Licensed under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+
+#include "iree/hal/drivers/amdgpu/util/hsaco_metadata.h"
+
+#include <cstddef>
+#include <cstdint>
+#include <string>
+#include <vector>
+
+#include "iree/base/api.h"
+#include "iree/testing/gtest.h"
+#include "iree/testing/status_matchers.h"
+
+namespace iree::hal::amdgpu {
+namespace {
+
+static void AppendByte(std::vector<uint8_t>* output, uint8_t value) {
+  output->push_back(value);
+}
+
+static void AppendU16BE(std::vector<uint8_t>* output, uint16_t value) {
+  output->push_back((uint8_t)(value >> 8));
+  output->push_back((uint8_t)value);
+}
+
+static void AppendU32BE(std::vector<uint8_t>* output, uint32_t value) {
+  output->push_back((uint8_t)(value >> 24));
+  output->push_back((uint8_t)(value >> 16));
+  output->push_back((uint8_t)(value >> 8));
+  output->push_back((uint8_t)value);
+}
+
+static void AppendU32LE(std::vector<uint8_t>* output, uint32_t value) {
+  output->push_back((uint8_t)value);
+  output->push_back((uint8_t)(value >> 8));
+  output->push_back((uint8_t)(value >> 16));
+  output->push_back((uint8_t)(value >> 24));
+}
+
+static void StoreU16LE(std::vector<uint8_t>* output, size_t offset,
+                       uint16_t value) {
+  (*output)[offset + 0] = (uint8_t)value;
+  (*output)[offset + 1] = (uint8_t)(value >> 8);
+}
+
+static void StoreU32LE(std::vector<uint8_t>* output, size_t offset,
+                       uint32_t value) {
+  (*output)[offset + 0] = (uint8_t)value;
+  (*output)[offset + 1] = (uint8_t)(value >> 8);
+  (*output)[offset + 2] = (uint8_t)(value >> 16);
+  (*output)[offset + 3] = (uint8_t)(value >> 24);
+}
+
+static void StoreU64LE(std::vector<uint8_t>* output, size_t offset,
+                       uint64_t value) {
+  for (int i = 0; i < 8; ++i) {
+    (*output)[offset + i] = (uint8_t)(value >> (i * 8));
+  }
+}
+
+static void AppendAligned4Padding(std::vector<uint8_t>* output) {
+  while ((output->size() & 3) != 0) output->push_back(0);
+}
+
+// Minimal MessagePack encoders; format byte names follow the MessagePack spec.
+static void AppendMsgPackMap(std::vector<uint8_t>* output, uint32_t count) {
+  if (count < 16) {
+    AppendByte(output, (uint8_t)(0x80 | count));  // fixmap.
+  } else {
+    AppendByte(output, 0xDE);  // map 16.
+    AppendU16BE(output, (uint16_t)count);
+  }
+}
+
+static void AppendMsgPackArray(std::vector<uint8_t>* output, uint32_t count) {
+  if (count < 16) {
+    AppendByte(output, (uint8_t)(0x90 | count));  // fixarray.
+  } else {
+    AppendByte(output, 0xDC);  // array 16.
+    AppendU16BE(output, (uint16_t)count);
+  }
+}
+
+static void AppendMsgPackString(std::vector<uint8_t>* output,
+                                iree_string_view_t value) {
+  if (value.size < 32) {
+    AppendByte(output, (uint8_t)(0xA0 | value.size));  // fixstr.
+  } else if (value.size <= UINT8_MAX) {
+    AppendByte(output, 0xD9);  // str 8.
+    AppendByte(output, (uint8_t)value.size);
+  } else {
+    AppendByte(output, 0xDA);  // str 16.
+    AppendU16BE(output, (uint16_t)value.size);
+  }
+  output->insert(output->end(), value.data, value.data + value.size);
+}
+
+static void AppendMsgPackUint(std::vector<uint8_t>* output, uint32_t value) {
+  if (value <= 0x7F) {
+    AppendByte(output, (uint8_t)value);  // positive fixint.
+  } else if (value <= UINT8_MAX) {
+    AppendByte(output, 0xCC);  // uint 8.
+    AppendByte(output, (uint8_t)value);
+  } else if (value <= UINT16_MAX) {
+    AppendByte(output, 0xCD);  // uint 16.
+    AppendU16BE(output, (uint16_t)value);
+  } else {
+    AppendByte(output, 0xCE);  // uint 32.
+    AppendU32BE(output, value);
+  }
+}
+
+static void AppendStringField(std::vector<uint8_t>* output,
+                              iree_string_view_t key,
+                              iree_string_view_t value) {
+  AppendMsgPackString(output, key);
+  AppendMsgPackString(output, value);
+}
+
+static void AppendUintField(std::vector<uint8_t>* output,
+                            iree_string_view_t key, uint32_t value) {
+  AppendMsgPackString(output, key);
+  AppendMsgPackUint(output, value);
+}
+
+static std::vector<uint8_t> BuildKernelMetadata(
+    bool out_of_range_arg = false, bool unknown_value_kind = false,
+    bool narrow_by_value_arg = false) {
+  std::vector<uint8_t> output;
+  AppendMsgPackMap(&output, 3);
+
+  AppendMsgPackString(&output, IREE_SV("amdhsa.version"));
+  AppendMsgPackArray(&output, 2);
+  AppendMsgPackUint(&output, 1);
+  AppendMsgPackUint(&output, 2);
+
+  AppendStringField(&output, IREE_SV("amdhsa.target"),
+                    IREE_SV("amdgcn-amd-amdhsa--gfx1100"));
+
+  AppendMsgPackString(&output, IREE_SV("amdhsa.kernels"));
+  AppendMsgPackArray(&output, 1);
+  AppendMsgPackMap(&output, 8);
+  AppendStringField(&output, IREE_SV(".name"), IREE_SV("vector_add"));
+  AppendStringField(&output, IREE_SV(".symbol"), IREE_SV("vector_add.kd"));
+  AppendUintField(&output, IREE_SV(".kernarg_segment_size"), 24);
+  AppendUintField(&output, IREE_SV(".kernarg_segment_align"), 8);
+  AppendUintField(&output, IREE_SV(".group_segment_fixed_size"), 1024);
+  AppendUintField(&output, IREE_SV(".private_segment_fixed_size"), 64);
+  AppendMsgPackString(&output, IREE_SV(".reqd_workgroup_size"));
+  AppendMsgPackArray(&output, 3);
+  AppendMsgPackUint(&output, 16);
+  AppendMsgPackUint(&output, 4);
+  AppendMsgPackUint(&output, 1);
+
+  AppendMsgPackString(&output, IREE_SV(".args"));
+  AppendMsgPackArray(&output, 4);
+
+  AppendMsgPackMap(&output, 8);
+  AppendStringField(&output, IREE_SV(".name"), IREE_SV("lhs"));
+  AppendUintField(&output, IREE_SV(".offset"), 0);
+  AppendUintField(&output, IREE_SV(".size"), 8);
+  AppendStringField(
+      &output, IREE_SV(".value_kind"),
+      unknown_value_kind ? IREE_SV("made_up_kind") : IREE_SV("global_buffer"));
+  AppendStringField(&output, IREE_SV(".address_space"), IREE_SV("global"));
+  AppendStringField(&output, IREE_SV(".access"), IREE_SV("read_write"));
+  AppendStringField(&output, IREE_SV(".actual_access"), IREE_SV("read_only"));
+  AppendUintField(&output, IREE_SV(".align"), 8);
+
+  AppendMsgPackMap(&output, 7);
+  AppendStringField(&output, IREE_SV(".name"), IREE_SV("rhs"));
+  AppendUintField(&output, IREE_SV(".offset"), 8);
+  AppendUintField(&output, IREE_SV(".size"), 8);
+  AppendStringField(&output, IREE_SV(".value_kind"), IREE_SV("global_buffer"));
+  AppendStringField(&output, IREE_SV(".address_space"), IREE_SV("global"));
+  AppendStringField(&output, IREE_SV(".access"), IREE_SV("write_only"));
+  AppendUintField(&output, IREE_SV(".align"), 8);
+
+  AppendMsgPackMap(&output, 5);
+  AppendStringField(&output, IREE_SV(".name"), IREE_SV("n"));
+  AppendUintField(&output, IREE_SV(".offset"), 16);
+  AppendUintField(&output, IREE_SV(".size"), narrow_by_value_arg ? 2 : 4);
+  AppendStringField(&output, IREE_SV(".value_kind"), IREE_SV("by_value"));
+  AppendUintField(&output, IREE_SV(".align"), 4);
+
+  AppendMsgPackMap(&output, 5);
+  AppendStringField(&output, IREE_SV(".name"), IREE_SV("alpha"));
+  AppendUintField(&output, IREE_SV(".offset"), out_of_range_arg ? 20 : 20);
+  AppendUintField(&output, IREE_SV(".size"), out_of_range_arg ? 8 : 4);
+  AppendStringField(&output, IREE_SV(".value_kind"), IREE_SV("by_value"));
+  AppendUintField(&output, IREE_SV(".align"), 4);
+
+  return output;
+}
+
+static std::vector<uint8_t> BuildHiddenArgumentMetadata() {
+  std::vector<uint8_t> output;
+  AppendMsgPackMap(&output, 1);
+  AppendMsgPackString(&output, IREE_SV("amdhsa.kernels"));
+  AppendMsgPackArray(&output, 1);
+  AppendMsgPackMap(&output, 6);
+  AppendStringField(&output, IREE_SV(".symbol"), IREE_SV("hidden_args.kd"));
+  AppendUintField(&output, IREE_SV(".kernarg_segment_size"), 20);
+  AppendUintField(&output, IREE_SV(".kernarg_segment_align"), 8);
+  AppendUintField(&output, IREE_SV(".group_segment_fixed_size"), 0);
+  AppendUintField(&output, IREE_SV(".private_segment_fixed_size"), 0);
+  AppendMsgPackString(&output, IREE_SV(".args"));
+  AppendMsgPackArray(&output, 3);
+
+  AppendMsgPackMap(&output, 5);
+  AppendStringField(&output, IREE_SV(".name"), IREE_SV("buffer"));
+  AppendUintField(&output, IREE_SV(".offset"), 0);
+  AppendUintField(&output, IREE_SV(".size"), 8);
+  AppendStringField(&output, IREE_SV(".value_kind"), IREE_SV("global_buffer"));
+  AppendUintField(&output, IREE_SV(".align"), 8);
+
+  AppendMsgPackMap(&output, 5);
+  AppendStringField(&output, IREE_SV(".name"), IREE_SV("grid_x"));
+  AppendUintField(&output, IREE_SV(".offset"), 8);
+  AppendUintField(&output, IREE_SV(".size"), 8);
+  AppendStringField(&output, IREE_SV(".value_kind"),
+                    IREE_SV("hidden_global_offset_x"));
+  AppendUintField(&output, IREE_SV(".align"), 8);
+
+  AppendMsgPackMap(&output, 5);
+  AppendStringField(&output, IREE_SV(".name"), IREE_SV("value"));
+  AppendUintField(&output, IREE_SV(".offset"), 16);
+  AppendUintField(&output, IREE_SV(".size"), 4);
+  AppendStringField(&output, IREE_SV(".value_kind"), IREE_SV("by_value"));
+  AppendUintField(&output, IREE_SV(".align"), 4);
+
+  return output;
+}
+
+static std::vector<uint8_t> BuildMalformedMissingKernelFieldsMetadata() {
+  std::vector<uint8_t> output;
+  AppendMsgPackMap(&output, 1);
+  AppendMsgPackString(&output, IREE_SV("amdhsa.kernels"));
+  AppendMsgPackArray(&output, 1);
+  AppendMsgPackMap(&output, 0);
+  return output;
+}
+
+static std::vector<uint8_t> BuildDuplicateArgumentFieldMetadata() {
+  std::vector<uint8_t> output;
+  AppendMsgPackMap(&output, 1);
+  AppendMsgPackString(&output, IREE_SV("amdhsa.kernels"));
+  AppendMsgPackArray(&output, 1);
+  AppendMsgPackMap(&output, 6);
+  AppendStringField(&output, IREE_SV(".symbol"), IREE_SV("duplicate.kd"));
+  AppendUintField(&output, IREE_SV(".kernarg_segment_size"), 8);
+  AppendUintField(&output, IREE_SV(".kernarg_segment_align"), 8);
+  AppendUintField(&output, IREE_SV(".group_segment_fixed_size"), 0);
+  AppendUintField(&output, IREE_SV(".private_segment_fixed_size"), 0);
+  AppendMsgPackString(&output, IREE_SV(".args"));
+  AppendMsgPackArray(&output, 1);
+  AppendMsgPackMap(&output, 4);
+  AppendUintField(&output, IREE_SV(".offset"), 0);
+  AppendUintField(&output, IREE_SV(".offset"), 4);
+  AppendUintField(&output, IREE_SV(".size"), 4);
+  AppendStringField(&output, IREE_SV(".value_kind"), IREE_SV("by_value"));
+  return output;
+}
+
+static std::vector<uint8_t> BuildElfWithNote(
+    const std::vector<uint8_t>& metadata, iree_string_view_t note_name,
+    uint32_t note_type) {
+  constexpr size_t kElfHeaderSize = 64;
+  constexpr size_t kProgramHeaderOffset = 64;
+  constexpr size_t kProgramHeaderSize = 56;
+  constexpr size_t kNoteOffset = 128;
+
+  std::vector<uint8_t> note;
+  // ELF note header: namesz (name + NUL), descsz, then the note type.
+  AppendU32LE(&note, (uint32_t)note_name.size + 1);
+  AppendU32LE(&note, (uint32_t)metadata.size());
+  AppendU32LE(&note, note_type);
+  note.insert(note.end(), note_name.data, note_name.data + note_name.size);
+  note.push_back(0);
+  AppendAligned4Padding(&note);
+  note.insert(note.end(), metadata.begin(), metadata.end());
+  AppendAligned4Padding(&note);
+
+  std::vector<uint8_t> elf(kNoteOffset, 0);
+  elf[0] = 0x7F;
+  elf[1] = 'E';
+  elf[2] = 'L';
+  elf[3] = 'F';
+  elf[4] = 2;               // ELFCLASS64.
+  elf[5] = 1;               // ELFDATA2LSB.
+  elf[6] = 1;               // EV_CURRENT.
+  StoreU16LE(&elf, 16, 3);  // ET_DYN.
+  StoreU16LE(&elf, 18, 224);                   // EM_AMDGPU.
+  StoreU32LE(&elf, 20, 1);                     // EV_CURRENT.
+  StoreU64LE(&elf, 32, kProgramHeaderOffset);  // e_phoff.
+  StoreU16LE(&elf, 52, kElfHeaderSize);        // e_ehsize.
+  StoreU16LE(&elf, 54, kProgramHeaderSize);    // e_phentsize.
+  StoreU16LE(&elf, 56, 1);                     // e_phnum.
+
+  StoreU32LE(&elf, kProgramHeaderOffset + 0, 4);  // PT_NOTE.
+  StoreU64LE(&elf, kProgramHeaderOffset + 8, kNoteOffset);   // p_offset.
+  StoreU64LE(&elf, kProgramHeaderOffset + 32, note.size());  // p_filesz.
+  StoreU64LE(&elf, kProgramHeaderOffset + 40, note.size());  // p_memsz.
+  StoreU64LE(&elf, kProgramHeaderOffset + 48, 4);            // p_align.
+
+  elf.insert(elf.end(), note.begin(), note.end());
+  return elf;
+}
+
+static std::vector<uint8_t> BuildElfWithMetadata(
+    const std::vector<uint8_t>& metadata) {
+  // Note type 32 is NT_AMDGPU_METADATA.
+  return BuildElfWithNote(metadata, IREE_SV("AMDGPU"), 32);
+}
+
+static iree_const_byte_span_t ByteSpan(const std::vector<uint8_t>& data) {
+  return iree_make_const_byte_span(data.data(), data.size());
+}
+
+static std::string ToString(iree_string_view_t value) {
+  return std::string(value.data, value.size);
+}
+
+TEST(HsacoMetadataTest, ParsesValidMetadata) {
+  std::vector<uint8_t> elf = BuildElfWithMetadata(BuildKernelMetadata());
+
+  iree_hal_amdgpu_hsaco_metadata_t metadata;
+  IREE_ASSERT_OK(iree_hal_amdgpu_hsaco_metadata_initialize_from_elf(
+      ByteSpan(elf), iree_allocator_system(), &metadata));
+
+  ASSERT_EQ(metadata.kernel_count, 1);
+  ASSERT_EQ(metadata.arg_count, 4);
+  ASSERT_GT(metadata.message_pack_data.data_length, 0);
+  EXPECT_TRUE(iree_string_view_starts_with(metadata.target,
+                                           IREE_SV("amdgcn-amd-amdhsa--gfx")));
+  ASSERT_NE(metadata.kernels, nullptr);
+  ASSERT_NE(metadata.args, nullptr);
+
+  const iree_hal_amdgpu_hsaco_metadata_kernel_t& kernel = metadata.kernels[0];
+  EXPECT_EQ(ToString(kernel.name), "vector_add");
+  EXPECT_EQ(ToString(kernel.symbol_name), "vector_add.kd");
+  EXPECT_EQ(ToString(kernel.reflection_name), "vector_add");
+  EXPECT_EQ(kernel.kernarg_segment_size, 24);
+  EXPECT_EQ(kernel.kernarg_segment_alignment, 8);
+  EXPECT_EQ(kernel.group_segment_fixed_size, 1024);
+  EXPECT_EQ(kernel.private_segment_fixed_size, 64);
+  ASSERT_TRUE(kernel.has_required_workgroup_size);
+  EXPECT_EQ(kernel.required_workgroup_size[0], 16);
+  EXPECT_EQ(kernel.required_workgroup_size[1], 4);
+  EXPECT_EQ(kernel.required_workgroup_size[2], 1);
+  ASSERT_EQ(kernel.arg_count, 4);
+  ASSERT_EQ(kernel.args, metadata.args);
+  EXPECT_EQ(kernel.arg_name_storage_size, 12);
+  EXPECT_EQ(metadata.reflection_name_storage_size, 10);
+  EXPECT_EQ(metadata.arg_name_storage_size, 12);
+
+  EXPECT_EQ(ToString(kernel.args[0].name), "lhs");
+  EXPECT_EQ(kernel.args[0].offset, 0);
+  EXPECT_EQ(kernel.args[0].size, 8);
+  EXPECT_EQ(kernel.args[0].alignment, 8);
+  EXPECT_EQ(kernel.args[0].kind,
+            IREE_HAL_AMDGPU_HSACO_METADATA_ARG_KIND_GLOBAL_BUFFER);
+  EXPECT_EQ(ToString(kernel.args[0].value_kind), "global_buffer");
+  EXPECT_EQ(ToString(kernel.args[0].address_space), "global");
+  EXPECT_EQ(ToString(kernel.args[0].access), "read_only");
+
+  EXPECT_EQ(ToString(kernel.args[1].name), "rhs");
+  EXPECT_EQ(kernel.args[1].offset, 8);
+  EXPECT_EQ(kernel.args[1].size, 8);
+  EXPECT_EQ(ToString(kernel.args[1].access), "write_only");
+
+  EXPECT_EQ(ToString(kernel.args[2].name), "n");
+  EXPECT_EQ(kernel.args[2].offset, 16);
+  EXPECT_EQ(kernel.args[2].size, 4);
+  EXPECT_EQ(kernel.args[2].kind,
+            IREE_HAL_AMDGPU_HSACO_METADATA_ARG_KIND_BY_VALUE);
+
+  EXPECT_EQ(ToString(kernel.args[3].name), "alpha");
+  EXPECT_EQ(kernel.args[3].offset, 20);
+  EXPECT_EQ(kernel.args[3].size, 4);
+  EXPECT_EQ(kernel.args[3].kind,
+            IREE_HAL_AMDGPU_HSACO_METADATA_ARG_KIND_BY_VALUE);
+
+  iree_hal_amdgpu_hsaco_metadata_deinitialize(&metadata);
+}
+
+TEST(HsacoMetadataTest, PopulatesDefaultExportParameters) {
+  std::vector<uint8_t> elf = BuildElfWithMetadata(BuildKernelMetadata());
+
+  iree_hal_amdgpu_hsaco_metadata_t metadata;
+  IREE_ASSERT_OK(iree_hal_amdgpu_hsaco_metadata_initialize_from_elf(
+      ByteSpan(elf), iree_allocator_system(), &metadata));
+
+  const iree_hal_amdgpu_hsaco_metadata_kernel_t& kernel = metadata.kernels[0];
+  iree_hal_amdgpu_hsaco_metadata_export_parameter_requirements_t requirements;
+  IREE_ASSERT_OK(
+      iree_hal_amdgpu_hsaco_metadata_calculate_default_export_parameter_requirements(
+          &kernel, &requirements));
+  EXPECT_EQ(requirements.parameter_count, 4);
+  EXPECT_EQ(requirements.binding_count, 2);
+  EXPECT_EQ(requirements.constant_count, 2);
+  EXPECT_EQ(requirements.name_storage_size, 12);
+
+  std::vector<iree_hal_executable_export_parameter_t> parameters(
+      requirements.parameter_count);
+  std::vector<char> name_storage(requirements.name_storage_size);
+  IREE_ASSERT_OK(
+      iree_hal_amdgpu_hsaco_metadata_populate_default_export_parameters(
+          &kernel, parameters.size(), parameters.data(), name_storage.size(),
+          name_storage.data()));
+
+  EXPECT_EQ(parameters[0].type,
+            IREE_HAL_EXECUTABLE_EXPORT_PARAMETER_TYPE_BINDING);
+  EXPECT_EQ(parameters[0].size, 8);
+  EXPECT_EQ(parameters[0].offset, 0);
+  EXPECT_EQ(ToString(parameters[0].name), "lhs");
+
+  EXPECT_EQ(parameters[1].type,
+            IREE_HAL_EXECUTABLE_EXPORT_PARAMETER_TYPE_BINDING);
+  EXPECT_EQ(parameters[1].size, 8);
+  EXPECT_EQ(parameters[1].offset, 1);
+  EXPECT_EQ(ToString(parameters[1].name), "rhs");
+
+  EXPECT_EQ(parameters[2].type,
+            IREE_HAL_EXECUTABLE_EXPORT_PARAMETER_TYPE_CONSTANT);
+  EXPECT_EQ(parameters[2].size, 4);
+  EXPECT_EQ(parameters[2].offset, 0);
+  EXPECT_EQ(ToString(parameters[2].name), "n");
+
+  EXPECT_EQ(parameters[3].type,
+            IREE_HAL_EXECUTABLE_EXPORT_PARAMETER_TYPE_CONSTANT);
+  EXPECT_EQ(parameters[3].size, 4);
+  EXPECT_EQ(parameters[3].offset, 4);
+  EXPECT_EQ(ToString(parameters[3].name), "alpha");
+
+  IREE_EXPECT_STATUS_IS(
+      IREE_STATUS_RESOURCE_EXHAUSTED,
+      iree_hal_amdgpu_hsaco_metadata_populate_default_export_parameters(
+          &kernel, parameters.size() - 1, parameters.data(),
+          name_storage.size(), name_storage.data()));
+
+  iree_hal_amdgpu_hsaco_metadata_deinitialize(&metadata);
+}
+
+TEST(HsacoMetadataTest, DefaultExportParametersSkipHiddenArguments) {
+  std::vector<uint8_t> elf =
+      BuildElfWithMetadata(BuildHiddenArgumentMetadata());
+
+  iree_hal_amdgpu_hsaco_metadata_t metadata;
+  IREE_ASSERT_OK(iree_hal_amdgpu_hsaco_metadata_initialize_from_elf(
+      ByteSpan(elf), iree_allocator_system(), &metadata));
+
+  const iree_hal_amdgpu_hsaco_metadata_kernel_t& kernel = metadata.kernels[0];
+  EXPECT_EQ(ToString(kernel.reflection_name), "hidden_args");
+  ASSERT_EQ(kernel.arg_count, 3);
+  EXPECT_EQ(kernel.args[1].kind,
+            IREE_HAL_AMDGPU_HSACO_METADATA_ARG_KIND_HIDDEN);
+
+  iree_hal_amdgpu_hsaco_metadata_export_parameter_requirements_t requirements;
+  IREE_ASSERT_OK(
+      iree_hal_amdgpu_hsaco_metadata_calculate_default_export_parameter_requirements(
+          &kernel, &requirements));
+  EXPECT_EQ(requirements.parameter_count, 2);
+  EXPECT_EQ(requirements.binding_count, 1);
+  EXPECT_EQ(requirements.constant_count, 1);
+  EXPECT_EQ(requirements.name_storage_size, 11);
+
+  std::vector<iree_hal_executable_export_parameter_t> parameters(
+      requirements.parameter_count);
+  std::vector<char> name_storage(requirements.name_storage_size);
+  IREE_ASSERT_OK(
+      iree_hal_amdgpu_hsaco_metadata_populate_default_export_parameters(
+          &kernel, parameters.size(), parameters.data(), name_storage.size(),
+          name_storage.data()));
+  EXPECT_EQ(ToString(parameters[0].name), "buffer");
+  EXPECT_EQ(parameters[0].type,
+            IREE_HAL_EXECUTABLE_EXPORT_PARAMETER_TYPE_BINDING);
+  EXPECT_EQ(ToString(parameters[1].name), "value");
+  EXPECT_EQ(parameters[1].type,
+            IREE_HAL_EXECUTABLE_EXPORT_PARAMETER_TYPE_CONSTANT);
+
+  iree_hal_amdgpu_hsaco_metadata_deinitialize(&metadata);
+}
+
+TEST(HsacoMetadataTest, RejectsNarrowByValueDefaultExportParameter) {
+  std::vector<uint8_t> elf = BuildElfWithMetadata(BuildKernelMetadata(
+      /*out_of_range_arg=*/false, /*unknown_value_kind=*/false,
+      /*narrow_by_value_arg=*/true));
+
+  iree_hal_amdgpu_hsaco_metadata_t metadata;
+  IREE_ASSERT_OK(iree_hal_amdgpu_hsaco_metadata_initialize_from_elf(
+      ByteSpan(elf), iree_allocator_system(), &metadata));
+
+  iree_hal_amdgpu_hsaco_metadata_export_parameter_requirements_t requirements;
+  IREE_EXPECT_STATUS_IS(
+      IREE_STATUS_INVALID_ARGUMENT,
+      iree_hal_amdgpu_hsaco_metadata_calculate_default_export_parameter_requirements(
+          &metadata.kernels[0], &requirements));
+
+  iree_hal_amdgpu_hsaco_metadata_deinitialize(&metadata);
+}
+
+TEST(HsacoMetadataTest, FindsKernelBySymbol) {
+  std::vector<uint8_t> elf = BuildElfWithMetadata(BuildKernelMetadata());
+
+  iree_hal_amdgpu_hsaco_metadata_t metadata;
+  IREE_ASSERT_OK(iree_hal_amdgpu_hsaco_metadata_initialize_from_elf(
+      ByteSpan(elf), iree_allocator_system(), &metadata));
+
+  const iree_hal_amdgpu_hsaco_metadata_kernel_t* kernel = nullptr;
+  IREE_EXPECT_OK(iree_hal_amdgpu_hsaco_metadata_find_kernel_by_symbol(
+      &metadata, IREE_SV("vector_add.kd"), &kernel));
+  ASSERT_NE(kernel, nullptr);
+  EXPECT_EQ(ToString(kernel->name), "vector_add");
+
+  IREE_EXPECT_STATUS_IS(IREE_STATUS_NOT_FOUND,
+                        iree_hal_amdgpu_hsaco_metadata_find_kernel_by_symbol(
+                            &metadata, IREE_SV("missing.kd"), &kernel));
+
+  iree_hal_amdgpu_hsaco_metadata_deinitialize(&metadata);
+}
+
+TEST(HsacoMetadataTest, AllowsUnknownValueKindAsOpaqueMetadata) {
+  std::vector<uint8_t> elf =
+      BuildElfWithMetadata(BuildKernelMetadata(/*out_of_range_arg=*/false,
+                                               /*unknown_value_kind=*/true));
+
+  iree_hal_amdgpu_hsaco_metadata_t metadata;
+  IREE_ASSERT_OK(iree_hal_amdgpu_hsaco_metadata_initialize_from_elf(
+      ByteSpan(elf), iree_allocator_system(), &metadata));
+
+  ASSERT_EQ(metadata.kernel_count, 1);
+  ASSERT_EQ(metadata.kernels[0].arg_count, 4);
+  EXPECT_EQ(metadata.kernels[0].args[0].kind,
+            IREE_HAL_AMDGPU_HSACO_METADATA_ARG_KIND_UNKNOWN);
+  EXPECT_EQ(ToString(metadata.kernels[0].args[0].value_kind), "made_up_kind");
+  iree_hal_amdgpu_hsaco_metadata_export_parameter_requirements_t requirements;
+  IREE_EXPECT_STATUS_IS(
+      IREE_STATUS_INVALID_ARGUMENT,
+      iree_hal_amdgpu_hsaco_metadata_calculate_default_export_parameter_requirements(
+          &metadata.kernels[0], &requirements));
+
+  iree_hal_amdgpu_hsaco_metadata_deinitialize(&metadata);
+}
+
+TEST(HsacoMetadataTest, RejectsOutOfRangeArgument) {
+  std::vector<uint8_t> elf =
+      BuildElfWithMetadata(BuildKernelMetadata(/*out_of_range_arg=*/true));
+
+  iree_hal_amdgpu_hsaco_metadata_t metadata;
+  IREE_EXPECT_STATUS_IS(IREE_STATUS_INVALID_ARGUMENT,
+                        iree_hal_amdgpu_hsaco_metadata_initialize_from_elf(
+                            ByteSpan(elf), iree_allocator_system(), &metadata));
+}
+
+TEST(HsacoMetadataTest, RejectsMissingMetadataNote) {
+  std::vector<uint8_t> elf =
+      BuildElfWithNote(BuildKernelMetadata(), IREE_SV("OTHER"), 32);
+
+  iree_hal_amdgpu_hsaco_metadata_t metadata;
+  IREE_EXPECT_STATUS_IS(IREE_STATUS_NOT_FOUND,
+                        iree_hal_amdgpu_hsaco_metadata_initialize_from_elf(
+                            ByteSpan(elf), iree_allocator_system(), &metadata));
+}
+
+TEST(HsacoMetadataTest, RejectsMalformedMessagePackMetadata) {
+  std::vector<uint8_t> elf =
+      BuildElfWithMetadata(BuildMalformedMissingKernelFieldsMetadata());
+
+  iree_hal_amdgpu_hsaco_metadata_t metadata;
+  IREE_EXPECT_STATUS_IS(IREE_STATUS_INVALID_ARGUMENT,
+                        iree_hal_amdgpu_hsaco_metadata_initialize_from_elf(
+                            ByteSpan(elf), iree_allocator_system(), &metadata));
+}
+
+TEST(HsacoMetadataTest, RejectsDuplicateArgumentField) {
+  std::vector<uint8_t> elf =
+      BuildElfWithMetadata(BuildDuplicateArgumentFieldMetadata());
+
+  iree_hal_amdgpu_hsaco_metadata_t metadata;
+  IREE_EXPECT_STATUS_IS(IREE_STATUS_INVALID_ARGUMENT,
+                        iree_hal_amdgpu_hsaco_metadata_initialize_from_elf(
+                            ByteSpan(elf), iree_allocator_system(), &metadata));
+}
+
+TEST(HsacoMetadataTest, TruncatedElfPrefixesNeverSucceed) {
+  std::vector<uint8_t> elf = BuildElfWithMetadata(BuildKernelMetadata());
+  for (size_t length = 0; length < elf.size(); ++length) {
+    iree_hal_amdgpu_hsaco_metadata_t metadata;
+    iree_status_t status = iree_hal_amdgpu_hsaco_metadata_initialize_from_elf(
+        iree_make_const_byte_span(elf.data(), length), iree_allocator_system(),
+        &metadata);
+    if (iree_status_is_ok(status)) {
+      iree_hal_amdgpu_hsaco_metadata_deinitialize(&metadata);
+      ADD_FAILURE() << "unexpected success for truncated ELF prefix " << length;
+      return;
+    }
+    iree_status_free(status);
+  }
+}
+
+}  // namespace
+}  // namespace iree::hal::amdgpu
diff --git a/runtime/src/iree/hal/drivers/amdgpu/util/info.c b/runtime/src/iree/hal/drivers/amdgpu/util/info.c
index c57054e..fae2b3e 100644
--- a/runtime/src/iree/hal/drivers/amdgpu/util/info.c
+++ b/runtime/src/iree/hal/drivers/amdgpu/util/info.c
@@ -74,6 +74,7 @@
                              "only systems with SVM are supported "
                              "(HSA_AMD_SYSTEM_INFO_SVM_SUPPORTED == true)"));
   }
+  out_info->svm.supported = 1;
 
   bool svm_accessible_by_default = false;
   IREE_RETURN_AND_END_ZONE_IF_ERROR(
@@ -82,7 +83,16 @@
                                HSA_AMD_SYSTEM_INFO_SVM_ACCESSIBLE_BY_DEFAULT,
                                &svm_accessible_by_default),
       "querying HSA_AMD_SYSTEM_INFO_SVM_ACCESSIBLE_BY_DEFAULT");
-  out_info->svm_accessible_by_default = svm_accessible_by_default ? 1 : 0;
+  out_info->svm.accessible_by_default = svm_accessible_by_default ? 1 : 0;
+
+  bool xnack_enabled = false;
+  IREE_RETURN_AND_END_ZONE_IF_ERROR(
+      z0,
+      iree_hsa_system_get_info(IREE_LIBHSA(libhsa),
+                               HSA_AMD_SYSTEM_INFO_XNACK_ENABLED,
+                               &xnack_enabled),
+      "querying HSA_AMD_SYSTEM_INFO_XNACK_ENABLED");
+  out_info->svm.xnack_enabled = xnack_enabled ? 1 : 0;
 
   bool dmabuf_supported = false;
   IREE_RETURN_AND_END_ZONE_IF_ERROR(
diff --git a/runtime/src/iree/hal/drivers/amdgpu/util/info.h b/runtime/src/iree/hal/drivers/amdgpu/util/info.h
index c77e3f0..d1fb5ae 100644
--- a/runtime/src/iree/hal/drivers/amdgpu/util/info.h
+++ b/runtime/src/iree/hal/drivers/amdgpu/util/info.h
@@ -23,10 +23,18 @@
   // Timestamp value increase rate in hz.
   // Query of HSA_SYSTEM_INFO_TIMESTAMP_FREQUENCY.
   uint64_t timestamp_frequency;
-  // Whether all agents have access to system allocated memory by default.
-  // This is true on APUs and discrete GPUs with XNACK enabled.
-  // Query of HSA_AMD_SYSTEM_INFO_SVM_ACCESSIBLE_BY_DEFAULT.
-  uint32_t svm_accessible_by_default : 1;
+  // HSA SVM/HMM process-wide capability facts.
+  struct {
+    // Whether the HSA SVM attribute and prefetch APIs are supported.
+    // Query of HSA_AMD_SYSTEM_INFO_SVM_SUPPORTED.
+    uint32_t supported : 1;
+    // Whether all agents have access to system allocated memory by default.
+    // Query of HSA_AMD_SYSTEM_INFO_SVM_ACCESSIBLE_BY_DEFAULT.
+    uint32_t accessible_by_default : 1;
+    // Whether the process is bound to XNACK-enabled execution.
+    // Query of HSA_AMD_SYSTEM_INFO_XNACK_ENABLED.
+    uint32_t xnack_enabled : 1;
+  } svm;
   // Whether the dmabuf APIs are supported by the driver.
   // Query of HSA_AMD_SYSTEM_INFO_DMABUF_SUPPORTED.
   uint32_t dmabuf_supported : 1;
diff --git a/runtime/src/iree/hal/drivers/amdgpu/util/kernarg_ring.c b/runtime/src/iree/hal/drivers/amdgpu/util/kernarg_ring.c
new file mode 100644
index 0000000..f8cfca8
--- /dev/null
+++ b/runtime/src/iree/hal/drivers/amdgpu/util/kernarg_ring.c
@@ -0,0 +1,148 @@
+// Copyright 2026 The IREE Authors
+//
+// Licensed under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+
+#include "iree/hal/drivers/amdgpu/util/kernarg_ring.h"
+
+//===----------------------------------------------------------------------===//
+// iree_hal_amdgpu_kernarg_ring_t
+//===----------------------------------------------------------------------===//
+
+iree_status_t iree_hal_amdgpu_kernarg_ring_initialize(
+    const iree_hal_amdgpu_libhsa_t* libhsa,
+    const iree_hal_amdgpu_kernarg_ring_memory_t* memory,
+    uint32_t min_capacity_in_blocks, iree_hal_amdgpu_kernarg_ring_t* out_ring) {
+  IREE_ASSERT_ARGUMENT(libhsa);
+  IREE_ASSERT_ARGUMENT(memory);
+  IREE_ASSERT_ARGUMENT(out_ring);
+  IREE_TRACE_ZONE_BEGIN(z0);
+  IREE_TRACE_ZONE_APPEND_VALUE_I64(z0, (int64_t)min_capacity_in_blocks);
+  memset(out_ring, 0, sizeof(*out_ring));
+
+  IREE_ASSERT(memory->memory_pool.handle,
+              "kernarg ring memory descriptor must provide a memory pool");
+  IREE_ASSERT(memory->access_agents,
+              "kernarg ring memory descriptor must provide access agents");
+  IREE_ASSERT(memory->access_agent_count > 0 &&
+                  memory->access_agent_count <= UINT32_MAX,
+              "kernarg ring access agent count must fit in HSA's uint32_t");
+  if (!min_capacity_in_blocks ||
+      !iree_host_size_is_power_of_two(min_capacity_in_blocks)) {
+    IREE_RETURN_AND_END_ZONE_IF_ERROR(
+        z0, iree_make_status(IREE_STATUS_INVALID_ARGUMENT,
+                             "kernarg ring capacity must be a non-zero "
+                             "power of two; got %u",
+                             min_capacity_in_blocks));
+  }
+  IREE_ASSERT(memory->publication.mode ==
+                  IREE_HAL_AMDGPU_KERNARG_RING_PUBLICATION_MODE_NONE ||
+              memory->publication.mode ==
+                  IREE_HAL_AMDGPU_KERNARG_RING_PUBLICATION_MODE_HDP_FLUSH);
+  IREE_ASSERT(memory->publication.mode !=
+                  IREE_HAL_AMDGPU_KERNARG_RING_PUBLICATION_MODE_HDP_FLUSH ||
+              memory->publication.hdp_mem_flush_control);
+
+  // HSA VMEM handles are not supported for CPU pools on current ROCm stacks,
+  // so the host kernarg ring uses a plain pool allocation and wraps by
+  // skipping the tail fragment in the allocator.
+  size_t alloc_granule = 0;
+  IREE_RETURN_AND_END_ZONE_IF_ERROR(
+      z0,
+      iree_hsa_amd_memory_pool_get_info(
+          IREE_LIBHSA(libhsa), memory->memory_pool,
+          HSA_AMD_MEMORY_POOL_INFO_RUNTIME_ALLOC_GRANULE, &alloc_granule),
+      "querying HSA_AMD_MEMORY_POOL_INFO_RUNTIME_ALLOC_GRANULE for kernarg "
+      "ring allocation");
+  if (IREE_UNLIKELY(!alloc_granule ||
+                    !iree_host_size_is_power_of_two(alloc_granule))) {
+    IREE_RETURN_AND_END_ZONE_IF_ERROR(
+        z0, iree_make_status(IREE_STATUS_FAILED_PRECONDITION,
+                             "kernarg memory pool allocation granule must be "
+                             "a non-zero power of two (got %zu)",
+                             alloc_granule));
+  }
+
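+  // Grow the requested block capacity until the byte size reaches the pool's
+  // allocation granule. Doubling preserves the power-of-two block count that
+  // the physical index masking below relies on.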
+  uint32_t capacity = min_capacity_in_blocks;
+  while ((uint64_t)capacity * sizeof(iree_hal_amdgpu_kernarg_block_t) <
+             alloc_granule &&
+         capacity <= UINT32_MAX / 2) {
+    capacity <<= 1;
+  }
+  IREE_ASSERT(capacity >= min_capacity_in_blocks);
+  IREE_ASSERT(iree_host_size_is_power_of_two(capacity));
+  const size_t capacity_in_bytes =
+      (size_t)capacity * sizeof(iree_hal_amdgpu_kernarg_block_t);
+  if (IREE_UNLIKELY(
+          !iree_host_size_has_alignment(capacity_in_bytes, alloc_granule))) {
+    IREE_RETURN_AND_END_ZONE_IF_ERROR(
+        z0, iree_make_status(IREE_STATUS_FAILED_PRECONDITION,
+                             "kernarg ring capacity %" PRIhsz
+                             " bytes is not aligned to pool allocation "
+                             "granule %zu",
+                             capacity_in_bytes, alloc_granule));
+  }
+
+  IREE_RETURN_AND_END_ZONE_IF_ERROR(
+      z0,
+      iree_hsa_amd_memory_pool_allocate(
+          IREE_LIBHSA(libhsa), memory->memory_pool, capacity_in_bytes,
+          HSA_AMD_MEMORY_POOL_STANDARD_FLAG, (void**)&out_ring->base),
+      "allocating kernarg ring of %" PRIhsz " bytes (%u blocks)",
+      capacity_in_bytes, capacity);
+  iree_status_t status = iree_hsa_amd_agents_allow_access(
+      IREE_LIBHSA(libhsa), (uint32_t)memory->access_agent_count,
+      memory->access_agents, /*flags=*/NULL, out_ring->base);
+  if (!iree_status_is_ok(status)) {
+    status = iree_status_join(status, iree_hsa_amd_memory_pool_free(
+                                          IREE_LIBHSA(libhsa), out_ring->base));
+    memset(out_ring, 0, sizeof(*out_ring));
+    IREE_RETURN_AND_END_ZONE_IF_ERROR(
+        z0, status,
+        "making kernarg ring allocation visible to %" PRIhsz " HSA agents",
+        memory->access_agent_count);
+  }
+
+  out_ring->capacity = capacity;
+  out_ring->mask = capacity - 1;
+  out_ring->publication = memory->publication;
+  iree_atomic_store(&out_ring->write_position, 0, iree_memory_order_relaxed);
+  iree_atomic_store(&out_ring->read_position, 0, iree_memory_order_relaxed);
+
+  // Touch the allocation so the host mapping is faulted in on the
+  // initialization path instead of by the first submitter that writes into
+  // the ring.
+  out_ring->base[0].data[0] = 0;
+
+  IREE_TRACE_ZONE_APPEND_VALUE_I64(z0, (int64_t)capacity);
+  IREE_TRACE_ZONE_END(z0);
+  return iree_ok_status();
+}
+
+void iree_hal_amdgpu_kernarg_ring_deinitialize(
+    const iree_hal_amdgpu_libhsa_t* libhsa,
+    iree_hal_amdgpu_kernarg_ring_t* ring) {
+  IREE_ASSERT_ARGUMENT(libhsa);
+  IREE_ASSERT_ARGUMENT(ring);
+  IREE_TRACE_ZONE_BEGIN(z0);
+
+  // All in-flight work must have completed and been reclaimed. Any remaining
+  // gap between write and read indicates leaked kernarg blocks.
+  const uint64_t write = (uint64_t)iree_atomic_load(&ring->write_position,
+                                                    iree_memory_order_relaxed);
+  const uint64_t read = (uint64_t)iree_atomic_load(&ring->read_position,
+                                                   iree_memory_order_relaxed);
+  IREE_ASSERT(write == read,
+              "kernarg ring has %" PRIu64
+              " unreleased blocks at deinit (write=%" PRIu64 ", read=%" PRIu64
+              ")",
+              write - read, write, read);
+
+  if (ring->base) {
+    iree_hal_amdgpu_hsa_cleanup_assert_success(
+        iree_hsa_amd_memory_pool_free(IREE_LIBHSA(libhsa), ring->base));
+  }
+  memset(ring, 0, sizeof(*ring));
+
+  IREE_TRACE_ZONE_END(z0);
+}
diff --git a/runtime/src/iree/hal/drivers/amdgpu/util/kernarg_ring.h b/runtime/src/iree/hal/drivers/amdgpu/util/kernarg_ring.h
new file mode 100644
index 0000000..c536530
--- /dev/null
+++ b/runtime/src/iree/hal/drivers/amdgpu/util/kernarg_ring.h
@@ -0,0 +1,348 @@
+// Copyright 2026 The IREE Authors
+//
+// Licensed under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+
+#ifndef IREE_HAL_DRIVERS_AMDGPU_UTIL_KERNARG_RING_H_
+#define IREE_HAL_DRIVERS_AMDGPU_UTIL_KERNARG_RING_H_
+
+#include "iree/base/api.h"
+
+#if defined(IREE_ARCH_X86_64)
+#if defined(IREE_COMPILER_MSVC_COMPAT)
+#include <intrin.h>
+#else
+#include <emmintrin.h>
+#endif  // IREE_COMPILER_MSVC_COMPAT
+#endif  // IREE_ARCH_X86_64
+
+#include "iree/base/internal/atomics.h"
+#include "iree/hal/drivers/amdgpu/util/libhsa.h"
+
+#ifdef __cplusplus
+extern "C" {
+#endif  // __cplusplus
+
+//===----------------------------------------------------------------------===//
+// iree_hal_amdgpu_kernarg_block_t
+//===----------------------------------------------------------------------===//
+
+// 64-byte aligned kernarg slot. Each dispatch's kernarg region starts on a
+// cache-line boundary, preventing false sharing between the GPU command
+// processor reading one dispatch's args and the CPU writing the next. CLR uses
+// the same alignment (L3 cache line, 64 bytes on CDNA/RDNA).
+//
+// The ring is indexed in units of this block type: sizes and offsets are in
+// block counts, not bytes. A dispatch needing 128 bytes of kernarg space
+// consumes 2 blocks.
+typedef struct iree_alignas(64) iree_hal_amdgpu_kernarg_block_t {
+  uint8_t data[64];
+} iree_hal_amdgpu_kernarg_block_t;
+static_assert(sizeof(iree_hal_amdgpu_kernarg_block_t) == 64,
+              "kernarg block must be exactly one cache line");
+
+//===----------------------------------------------------------------------===//
+// iree_hal_amdgpu_kernarg_ring_t
+//===----------------------------------------------------------------------===//
+
+typedef enum iree_hal_amdgpu_kernarg_ring_publication_mode_e {
+  // Host writes to the ring need no extra publication beyond the packet-header
+  // release store.
+  IREE_HAL_AMDGPU_KERNARG_RING_PUBLICATION_MODE_NONE = 0,
+  // Host writes to the ring are published with the agent's HDP memory-flush
+  // register before packet-header commits.
+  IREE_HAL_AMDGPU_KERNARG_RING_PUBLICATION_MODE_HDP_FLUSH = 1,
+} iree_hal_amdgpu_kernarg_ring_publication_mode_t;
+
+typedef struct iree_hal_amdgpu_kernarg_ring_publication_t {
+  // Publication mode used before AQL packet headers reference the written
+  // kernargs.
+  iree_hal_amdgpu_kernarg_ring_publication_mode_t mode;
+  // HDP memory-flush register used when |mode| is
+  // IREE_HAL_AMDGPU_KERNARG_RING_PUBLICATION_MODE_HDP_FLUSH.
+  volatile uint32_t* hdp_mem_flush_control;
+} iree_hal_amdgpu_kernarg_ring_publication_t;
+
+// Memory backing policy for a queue-owned kernarg ring.
+//
+// The descriptor is consumed during ring initialization and may reference
+// caller-owned stack storage for |access_agents|.
+typedef struct iree_hal_amdgpu_kernarg_ring_memory_t {
+  // HSA memory pool used for the ring allocation.
+  hsa_amd_memory_pool_t memory_pool;
+  // Agents granted explicit access to the ring allocation.
+  const hsa_agent_t* access_agents;
+  // Number of entries in |access_agents|.
+  iree_host_size_t access_agent_count;
+  // Host-write publication mechanism for this memory pool.
+  iree_hal_amdgpu_kernarg_ring_publication_t publication;
+} iree_hal_amdgpu_kernarg_ring_memory_t;
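+
+// Non-normative example of filling the descriptor for a pool that needs no
+// extra publication; |kernarg_pool|, |agents|, and |agent_count| stand in for
+// values a physical device would select during topology setup:
+//   const iree_hal_amdgpu_kernarg_ring_memory_t memory = {
+//       .memory_pool = kernarg_pool,
+//       .access_agents = agents,
+//       .access_agent_count = agent_count,
+//       .publication = {IREE_HAL_AMDGPU_KERNARG_RING_PUBLICATION_MODE_NONE,
+//                       NULL},
+//   };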
+
+// Per-queue bump allocator for dispatch kernarg memory backed by an HSA memory
+// pool. The backing pool is selected per physical device: host kernarg-init
+// memory is universally valid, while host-visible device-local memory avoids
+// device reads over PCIe on systems where CPU agents can directly access the
+// GPU's coarse-grained pool.
+//
+// Thread safety:
+//   allocate() is multi-producer safe (CAS on write_position).
+//   reclaim() must be called from a single thread (proactor drain).
+//   Multiple threads may call allocate() concurrently.
+//
+// Backpressure contract:
+//   The caller must prove capacity for both the AQL ring and kernarg ring
+//   before publishing work. These resources are intentionally independent:
+//   crossing a 64-byte kernarg block boundary must not require extra AQL
+//   packets. Callers use can_allocate() as a non-mutating admission check
+//   before reserving AQL packets, then allocate() the same block count before
+//   committing packet headers (see the usage sketch following this struct).
+//
+//   A non-VMEM ring may need to skip one tail fragment at wrap to preserve
+//   contiguous multi-block allocations. The skipped fragment is counted as
+//   in-flight space until the allocation that caused the wrap retires. Both
+//   can_allocate() and allocate() model that skip identically so admission
+//   cannot succeed and then fail in normal execution.
+//
+// Memory ordering:
+//   write_position uses relaxed CAS. It only claims space. The GPU cannot read
+//   kernargs until it sees a valid AQL packet header. Rings backed by normal
+//   host kernarg memory rely on the AQL packet header release store to order
+//   prior kernarg writes. Rings backed by host-visible device memory also
+//   require publish_host_writes() before committing any packet header that
+//   references queue-owned kernargs so CPU write-combining/BAR effects are
+//   drained before the doorbell-visible packet stream can consume them.
+//
+//   read_position uses release (drain) / acquire (allocators). The drain
+//   publishes reclamation; allocators observe it for the defensive fullness
+//   check. On x86 (the primary host architecture for this driver) release
+//   and relaxed stores compile to the same store instruction, but we use
+//   release so the code is correct by construction on any supported host.
+typedef struct iree_hal_amdgpu_kernarg_ring_t {
+  // Base pointer to the HSA memory-pool allocation, cast to block type for
+  // natural indexing (base[i] gives block i without byte arithmetic).
+  iree_hal_amdgpu_kernarg_block_t* base;
+
+  // Power-of-two capacity in blocks.
+  uint32_t capacity;
+  // capacity - 1, for masking logical positions to physical ring indices.
+  uint32_t mask;
+
+  // Host-write publication mechanism for this ring.
+  iree_hal_amdgpu_kernarg_ring_publication_t publication;
+
+  // Monotonically increasing write position in blocks. Each allocate()
+  // atomically advances this via a CAS loop. The position is a logical
+  // (unwrapped) index; (write_position & mask) gives the physical ring offset.
+  //
+  // Relaxed ordering: see the memory ordering discussion above.
+  iree_atomic_int64_t write_position;
+
+  // Read position in blocks. Advanced by the proactor drain when the GPU
+  // completes work that referenced the kernarg space. The number of blocks
+  // currently in use is (write_position - read_position).
+  //
+  // Single writer (proactor drain), multiple readers (allocating threads).
+  // Release store from drain, acquire load from allocators.
+  iree_atomic_int64_t read_position;
+} iree_hal_amdgpu_kernarg_ring_t;
+
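+// Non-normative usage sketch of the admission/allocate/reclaim contract
+// described above, assuming the caller holds the host queue submission mutex
+// and has computed |block_count| from executable metadata:
+//   if (!iree_hal_amdgpu_kernarg_ring_can_allocate(ring, block_count)) {
+//     // Back off: retire completed submissions and retry.
+//   }
+//   uint64_t end_position = 0;
+//   iree_hal_amdgpu_kernarg_block_t* kernargs =
+//       iree_hal_amdgpu_kernarg_ring_allocate(ring, block_count,
+//                                             &end_position);
+//   // ... write kernargs, publish host writes, commit the AQL header ...
+//   // On completion the proactor drain releases the space:
+//   //   iree_hal_amdgpu_kernarg_ring_reclaim(ring, end_position);
+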
+// Initializes the kernarg ring by allocating at least |min_capacity_in_blocks|
+// blocks from |memory->memory_pool|.
+//
+// |min_capacity_in_blocks| must be a power of two. The actual capacity may be
+// larger if the HSA allocation granule requires rounding; the ring preserves a
+// power-of-two block count so physical indices can be masked.
+//
+// HSA VMEM handles are not supported for the host pools used here on current
+// ROCm stacks, so this ring uses a plain pool allocation and skips a tail
+// fragment when needed to preserve contiguous multi-block spans.
+iree_status_t iree_hal_amdgpu_kernarg_ring_initialize(
+    const iree_hal_amdgpu_libhsa_t* libhsa,
+    const iree_hal_amdgpu_kernarg_ring_memory_t* memory,
+    uint32_t min_capacity_in_blocks, iree_hal_amdgpu_kernarg_ring_t* out_ring);
+
+// Deinitializes the kernarg ring and frees the backing HSA memory-pool
+// allocation.
+// All in-flight work must have completed and been reclaimed before calling.
+void iree_hal_amdgpu_kernarg_ring_deinitialize(
+    const iree_hal_amdgpu_libhsa_t* libhsa,
+    iree_hal_amdgpu_kernarg_ring_t* ring);
+
+// Returns true if host writes into |ring| need explicit publication before
+// packet headers referencing queue-owned kernargs become visible.
+static inline bool iree_hal_amdgpu_kernarg_ring_requires_host_write_publication(
+    const iree_hal_amdgpu_kernarg_ring_t* ring) {
+  return ring->publication.mode !=
+         IREE_HAL_AMDGPU_KERNARG_RING_PUBLICATION_MODE_NONE;
+}
+
+// Returns true if this host architecture has an explicit store fence suitable
+// for ordering BAR/write-combined kernarg writes before device-visible
+// publication.
+static inline bool iree_hal_amdgpu_kernarg_ring_supports_host_write_publication(
+    void) {
+#if defined(IREE_ARCH_X86_64) || \
+    (defined(IREE_ARCH_ARM_64) && defined(IREE_COMPILER_GCC_COMPAT))
+  return true;
+#else
+  return false;
+#endif  // IREE_ARCH_*
+}
+
+// Orders host writes to BAR/write-combined kernarg memory before the
+// device-visible publication operation that follows.
+static inline void iree_hal_amdgpu_kernarg_ring_host_write_fence(void) {
+#if defined(IREE_ARCH_X86_64)
+  _mm_sfence();
+#elif defined(IREE_ARCH_ARM_64) && defined(IREE_COMPILER_GCC_COMPAT)
+  __asm__ __volatile__("dmb oshst" ::: "memory");
+#else
+  // This fallback is only valid for normal host memory. Physical-device
+  // selection must not choose a publication-requiring kernarg pool when
+  // supports_host_write_publication() is false.
+  iree_atomic_thread_fence(iree_memory_order_seq_cst);
+#endif  // IREE_ARCH_*
+}
+
+// Publishes host writes to queue-owned kernargs before their packet headers
+// become visible to the command processor.
+//
+// Host kernarg-init memory needs no extra work: the packet header release store
+// provides the ordering edge. Host-visible device memory can be mapped through
+// write-combining/BAR paths, so the selected backing policy must drain CPU
+// writes and make them visible to the GPU before packet publication.
+static inline void iree_hal_amdgpu_kernarg_ring_publish_host_writes(
+    const iree_hal_amdgpu_kernarg_ring_t* ring) {
+  if (ring->publication.mode ==
+      IREE_HAL_AMDGPU_KERNARG_RING_PUBLICATION_MODE_NONE) {
+    return;
+  }
+  IREE_ASSERT(ring->publication.mode ==
+              IREE_HAL_AMDGPU_KERNARG_RING_PUBLICATION_MODE_HDP_FLUSH);
+  IREE_ASSERT(ring->publication.hdp_mem_flush_control);
+  iree_hal_amdgpu_kernarg_ring_host_write_fence();
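+  // Write the flush register and read it back so the posted register write is
+  // known to have reached the HDP block before the caller commits any packet
+  // header.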
+  *ring->publication.hdp_mem_flush_control = 1u;
+  (void)*ring->publication.hdp_mem_flush_control;
+}
+
+// Returns true if |block_count| contiguous blocks can currently be allocated
+// without exceeding the ring capacity. The check accounts for the same
+// tail-fragment skip used by allocate() and does not mutate the ring.
+//
+// This is a snapshot. It is a reservation proof only when the caller serializes
+// this check with a following allocate() call, such as under the host queue
+// submission mutex.
+static inline bool iree_hal_amdgpu_kernarg_ring_can_allocate(
+    iree_hal_amdgpu_kernarg_ring_t* ring, uint32_t block_count) {
+  if (IREE_UNLIKELY(block_count == 0 || block_count > ring->capacity)) {
+    return false;
+  }
+  uint64_t first_block = (uint64_t)iree_atomic_load(&ring->write_position,
+                                                    iree_memory_order_relaxed);
+  const uint64_t tail_block_count =
+      (uint64_t)ring->capacity - (first_block & ring->mask);
+  if (block_count > tail_block_count) {
+    first_block += tail_block_count;
+  }
+  const uint64_t next_write_position = first_block + block_count;
+  const uint64_t read_position = (uint64_t)iree_atomic_load(
+      &ring->read_position, iree_memory_order_acquire);
+  return next_write_position - read_position <= ring->capacity;
+}
+
+// Allocates |block_count| contiguous blocks from the ring.
+//
+// Returns a pointer to the first block, 64-byte aligned and suitable for use
+// as a dispatch packet's kernarg_address. Multi-block allocations never wrap in
+// physical memory: if the tail fragment is too small, the allocator skips to
+// the first block and leaves the tail fragment to be reclaimed with the
+// allocation's end position.
+//
+// |out_end_position| receives the logical write position after this allocation
+// (first_block + block_count). The caller records this value for epoch-driven
+// reclamation: when the GPU completes the associated submission, the drain
+// calls reclaim() with this position.
+//
+// REQUIRES: The caller must have already proved capacity with can_allocate().
+// Under the host queue submission mutex, no other allocator can consume the
+// proved space before this call. A NULL return after a successful admission
+// check indicates an internal synchronization or sizing invariant failure.
+//
+// Returns NULL if block_count is 0 or exceeds ring capacity (checked before
+// touching any atomics), or if the ring is full.
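+//
+// Worked example (capacity = 8): with write_position = 7 and read_position =
+// 3, a 2-block request cannot fit in the 1-block tail fragment, so the
+// allocator skips block 7 and returns block 0 with end_position = 10; the
+// skipped block is reclaimed together with the allocation.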
+static inline iree_hal_amdgpu_kernarg_block_t*
+iree_hal_amdgpu_kernarg_ring_allocate(iree_hal_amdgpu_kernarg_ring_t* ring,
+                                      uint32_t block_count,
+                                      uint64_t* out_end_position) {
+  IREE_ASSERT_ARGUMENT(out_end_position);
+  if (IREE_UNLIKELY(!block_count || block_count > ring->capacity)) {
+    *out_end_position = 0;
+    return NULL;
+  }
+
+  // Claim one contiguous physical span with a CAS loop. If the current tail
+  // fragment is too small, skip it and wrap to block 0. Relaxed ordering is
+  // sufficient: we are only claiming indices. The subsequent kernarg writes are
+  // ordered by the AQL header commit's release store.
+  int64_t observed_write_position =
+      iree_atomic_load(&ring->write_position, iree_memory_order_relaxed);
+  uint64_t first_block = 0;
+  uint64_t next_write_position = 0;
+  for (;;) {
+    first_block = (uint64_t)observed_write_position;
+    const uint64_t tail_block_count =
+        (uint64_t)ring->capacity - (first_block & ring->mask);
+    if (block_count > tail_block_count) {
+      first_block += tail_block_count;
+    }
+    next_write_position = first_block + block_count;
+    const uint64_t read_position = (uint64_t)iree_atomic_load(
+        &ring->read_position, iree_memory_order_acquire);
+    if (IREE_UNLIKELY(next_write_position - read_position > ring->capacity)) {
+      *out_end_position = 0;
+      return NULL;
+    }
+    if (iree_atomic_compare_exchange_weak(
+            &ring->write_position, &observed_write_position,
+            (int64_t)next_write_position, iree_memory_order_relaxed,
+            iree_memory_order_relaxed)) {
+      break;
+    }
+  }
+
+  // Record the end position for the caller's reclamation tracking.
+  *out_end_position = next_write_position;
+  return &ring->base[first_block & ring->mask];
+}
+
+// Reclaims all kernarg blocks up to |new_read_position|. Called by the
+// proactor drain after confirming the GPU has completed work that referenced
+// the kernarg space.
+//
+// |new_read_position| is the end_position returned by allocate() at
+// submission time. The drain processes completions in epoch order, so
+// read_position advances monotonically.
+//
+// Must only be called from the proactor drain thread (single writer).
+static inline void iree_hal_amdgpu_kernarg_ring_reclaim(
+    iree_hal_amdgpu_kernarg_ring_t* ring, uint64_t new_read_position) {
+  // The drain processes completions in epoch order, so read_position must
+  // advance monotonically and never exceed write_position.
+  IREE_ASSERT(new_read_position >=
+              (uint64_t)iree_atomic_load(&ring->read_position,
+                                         iree_memory_order_relaxed));
+  IREE_ASSERT(new_read_position <=
+              (uint64_t)iree_atomic_load(&ring->write_position,
+                                         iree_memory_order_relaxed));
+  // Release store publishes the reclamation. Allocating threads loading
+  // read_position with acquire will see the updated available space.
+  iree_atomic_store(&ring->read_position, (int64_t)new_read_position,
+                    iree_memory_order_release);
+}
+
+#ifdef __cplusplus
+}  // extern "C"
+#endif  // __cplusplus
+
+#endif  // IREE_HAL_DRIVERS_AMDGPU_UTIL_KERNARG_RING_H_
diff --git a/runtime/src/iree/hal/drivers/amdgpu/util/kernarg_ring_test.cc b/runtime/src/iree/hal/drivers/amdgpu/util/kernarg_ring_test.cc
new file mode 100644
index 0000000..12789d3
--- /dev/null
+++ b/runtime/src/iree/hal/drivers/amdgpu/util/kernarg_ring_test.cc
@@ -0,0 +1,95 @@
+// Copyright 2026 The IREE Authors
+//
+// Licensed under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+
+#include "iree/hal/drivers/amdgpu/util/kernarg_ring.h"
+
+#include <cstdint>
+
+#include "iree/testing/gtest.h"
+
+namespace {
+
+static void InitializeTestRing(iree_hal_amdgpu_kernarg_block_t* storage,
+                               uint32_t capacity,
+                               iree_hal_amdgpu_kernarg_ring_t* out_ring) {
+  out_ring->base = storage;
+  out_ring->capacity = capacity;
+  out_ring->mask = capacity - 1;
+  iree_atomic_store(&out_ring->write_position, 0, iree_memory_order_relaxed);
+  iree_atomic_store(&out_ring->read_position, 0, iree_memory_order_relaxed);
+}
+
+TEST(KernargRingTest, AllocatesContiguousBlocksAndReclaims) {
+  iree_hal_amdgpu_kernarg_block_t storage[8] = {};
+  iree_hal_amdgpu_kernarg_ring_t ring = {};
+  InitializeTestRing(storage, IREE_ARRAYSIZE(storage), &ring);
+
+  EXPECT_TRUE(iree_hal_amdgpu_kernarg_ring_can_allocate(&ring, 3));
+
+  uint64_t end_position = 0;
+  iree_hal_amdgpu_kernarg_block_t* first =
+      iree_hal_amdgpu_kernarg_ring_allocate(&ring, 3, &end_position);
+  EXPECT_EQ(first, &storage[0]);
+  EXPECT_EQ(end_position, 3u);
+
+  iree_hal_amdgpu_kernarg_block_t* second =
+      iree_hal_amdgpu_kernarg_ring_allocate(&ring, 4, &end_position);
+  EXPECT_EQ(second, &storage[3]);
+  EXPECT_EQ(end_position, 7u);
+
+  EXPECT_FALSE(iree_hal_amdgpu_kernarg_ring_can_allocate(&ring, 2));
+
+  iree_hal_amdgpu_kernarg_ring_reclaim(&ring, 3);
+  EXPECT_TRUE(iree_hal_amdgpu_kernarg_ring_can_allocate(&ring, 2));
+
+  iree_hal_amdgpu_kernarg_block_t* wrapped =
+      iree_hal_amdgpu_kernarg_ring_allocate(&ring, 2, &end_position);
+  EXPECT_EQ(wrapped, &storage[0]);
+  EXPECT_EQ(end_position, 10u);
+}
+
+TEST(KernargRingTest, RejectsInvalidBlockCounts) {
+  iree_hal_amdgpu_kernarg_block_t storage[4] = {};
+  iree_hal_amdgpu_kernarg_ring_t ring = {};
+  InitializeTestRing(storage, IREE_ARRAYSIZE(storage), &ring);
+
+  EXPECT_FALSE(iree_hal_amdgpu_kernarg_ring_can_allocate(&ring, 0));
+  EXPECT_FALSE(iree_hal_amdgpu_kernarg_ring_can_allocate(&ring, 5));
+
+  uint64_t end_position = UINT64_MAX;
+  EXPECT_EQ(iree_hal_amdgpu_kernarg_ring_allocate(&ring, 0, &end_position),
+            nullptr);
+  EXPECT_EQ(end_position, 0u);
+  end_position = UINT64_MAX;
+  EXPECT_EQ(iree_hal_amdgpu_kernarg_ring_allocate(&ring, 5, &end_position),
+            nullptr);
+  EXPECT_EQ(end_position, 0u);
+}
+
+TEST(KernargRingTest, PublicationModeNoneSkipsRegisterWrite) {
+  volatile uint32_t hdp_mem_flush_control = 0xCAFEu;
+  iree_hal_amdgpu_kernarg_ring_t ring = {};
+  ring.publication.mode = IREE_HAL_AMDGPU_KERNARG_RING_PUBLICATION_MODE_NONE;
+  ring.publication.hdp_mem_flush_control = &hdp_mem_flush_control;
+
+  iree_hal_amdgpu_kernarg_ring_publish_host_writes(&ring);
+
+  EXPECT_EQ(hdp_mem_flush_control, 0xCAFEu);
+}
+
+TEST(KernargRingTest, PublicationModeHdpFlushWritesRegister) {
+  volatile uint32_t hdp_mem_flush_control = 0u;
+  iree_hal_amdgpu_kernarg_ring_t ring = {};
+  ring.publication.mode =
+      IREE_HAL_AMDGPU_KERNARG_RING_PUBLICATION_MODE_HDP_FLUSH;
+  ring.publication.hdp_mem_flush_control = &hdp_mem_flush_control;
+
+  iree_hal_amdgpu_kernarg_ring_publish_host_writes(&ring);
+
+  EXPECT_EQ(hdp_mem_flush_control, 1u);
+}
+
+}  // namespace
diff --git a/runtime/src/iree/hal/drivers/amdgpu/util/kfd.c b/runtime/src/iree/hal/drivers/amdgpu/util/kfd.c
index 932a4d3..9b1af7a 100644
--- a/runtime/src/iree/hal/drivers/amdgpu/util/kfd.c
+++ b/runtime/src/iree/hal/drivers/amdgpu/util/kfd.c
@@ -6,12 +6,16 @@
 
 #include "iree/hal/drivers/amdgpu/util/kfd.h"
 
+#include <inttypes.h>
+#include <string.h>  // memset
+
 //===----------------------------------------------------------------------===//
 // KFD IOCTL Workaround
 //===----------------------------------------------------------------------===//
 
 #if defined(IREE_PLATFORM_LINUX)
 
+#include <errno.h>
 #include <fcntl.h>      // open
 #include <sys/ioctl.h>  // ioctl
 #include <unistd.h>     // close
@@ -25,19 +29,21 @@
 iree_status_t iree_hal_amdgpu_kfd_open(int* out_fd) {
   IREE_ASSERT_ARGUMENT(out_fd);
   IREE_TRACE_ZONE_BEGIN(z0);
-  *out_fd = 0;
+  *out_fd = -1;
 
   iree_status_t status = iree_ok_status();
   const int fd = open("/dev/kfd", O_RDWR | O_CLOEXEC);
   if (fd == -1) {
+    const int errsv = errno;
     status = iree_make_status(IREE_STATUS_INTERNAL,
                               "unable to open /dev/kfd channel; platform file "
-                              "handle limit may be reached");
+                              "handle limit may be reached; errno=%d (%s)",
+                              errsv, strerror(errsv));
   }
 
   if (iree_status_is_ok(status)) {
     *out_fd = fd;
-  } else if (fd > 0) {
+  } else if (fd >= 0) {
     close(fd);
   }
   IREE_TRACE_ZONE_END(z0);
@@ -46,7 +52,7 @@
 
 void iree_hal_amdgpu_kfd_close(int fd) {
   IREE_TRACE_ZONE_BEGIN(z0);
-  if (fd > 0) {
+  if (fd >= 0) {
     close(fd);
   }
   IREE_TRACE_ZONE_END(z0);
@@ -63,13 +69,17 @@
 #else
 
 iree_status_t iree_hal_amdgpu_kfd_open(int* out_fd) {
-  *out_fd = 0;
+  IREE_ASSERT_ARGUMENT(out_fd);
+  *out_fd = -1;
   return iree_ok_status();
 }
 
-void iree_hal_amdgpu_kfd_close(int fd) {}
+void iree_hal_amdgpu_kfd_close(int fd) { (void)fd; }
 
 int iree_hal_amdgpu_ioctl(int fd, unsigned long request, void* arg) {
+  (void)fd;
+  (void)request;
+  (void)arg;
   return -1;
 }
 
@@ -85,24 +95,48 @@
   IREE_AMDKFD_IOWR(0x05, struct iree_kfd_ioctl_get_clock_counters_args)
 
 struct iree_kfd_ioctl_get_clock_counters_args {
-  uint64_t gpu_clock_counter;     // from KFD
-  uint64_t cpu_clock_counter;     // from KFD
-  uint64_t system_clock_counter;  // from KFD
-  uint64_t system_clock_freq;     // from KFD
-  uint32_t gpu_id;                // to KFD
+  // GPU clock counter returned by KFD.
+  uint64_t gpu_clock_counter;
+
+  // Host CPU timestamp returned by KFD.
+  uint64_t cpu_clock_counter;
+
+  // Host system clock counter returned by KFD.
+  uint64_t system_clock_counter;
+
+  // Frequency in Hz for system_clock_counter returned by KFD.
+  uint64_t system_clock_freq;
+
+  // GPU identifier passed to KFD.
+  uint32_t gpu_id;
+
+  // Reserved padding matching the KFD ABI.
   uint32_t pad;
 };
 
 iree_status_t iree_hal_amdgpu_kfd_get_clock_counters(
-    int fd, uint32_t gpu_uid, iree_hal_amdgpu_clock_counters_t* out_counters) {
+    int fd, uint32_t driver_uid,
+    iree_hal_amdgpu_kfd_clock_counters_t* out_counters) {
+  IREE_ASSERT_ARGUMENT(out_counters);
+  memset(out_counters, 0, sizeof(*out_counters));
+  if (IREE_UNLIKELY(fd < 0)) {
+    return iree_make_status(IREE_STATUS_INVALID_ARGUMENT,
+                            "invalid /dev/kfd file descriptor for "
+                            "AMDKFD_IOC_GET_CLOCK_COUNTERS(driver_uid=%" PRIu32
+                            ")",
+                            driver_uid);
+  }
+
   struct iree_kfd_ioctl_get_clock_counters_args args = {0};
-  args.gpu_id = gpu_uid;
+  args.gpu_id = driver_uid;
   int kmt_err =
       iree_hal_amdgpu_ioctl(fd, IREE_AMDKFD_IOC_GET_CLOCK_COUNTERS, &args);
   if (IREE_UNLIKELY(kmt_err < 0)) {
+    const int errsv = errno;
     return iree_make_status(IREE_STATUS_INTERNAL,
-                            "AMDKFD_IOC_GET_CLOCK_COUNTERS failed with %d",
-                            kmt_err);
+                            "AMDKFD_IOC_GET_CLOCK_COUNTERS(driver_uid=%" PRIu32
+                            ") failed with %d; errno=%d (%s)",
+                            driver_uid, kmt_err, errsv, strerror(errsv));
   }
   out_counters->gpu_clock_counter = args.gpu_clock_counter;
   out_counters->cpu_clock_counter = args.cpu_clock_counter;
@@ -114,9 +148,14 @@
 #else
 
 iree_status_t iree_hal_amdgpu_kfd_get_clock_counters(
-    int fd, uint32_t gpu_uid, iree_hal_amdgpu_clock_counters_t* out_counters) {
+    int fd, uint32_t driver_uid,
+    iree_hal_amdgpu_kfd_clock_counters_t* out_counters) {
+  (void)fd;
+  (void)driver_uid;
   memset(out_counters, 0, sizeof(*out_counters));
-  return iree_ok_status();
+  return iree_make_status(
+      IREE_STATUS_UNIMPLEMENTED,
+      "AMDKFD_IOC_GET_CLOCK_COUNTERS requires Linux /dev/kfd support");
 }
 
 #endif  // IREE_PLATFORM_LINUX
diff --git a/runtime/src/iree/hal/drivers/amdgpu/util/kfd.h b/runtime/src/iree/hal/drivers/amdgpu/util/kfd.h
index 9e3e580..416e637 100644
--- a/runtime/src/iree/hal/drivers/amdgpu/util/kfd.h
+++ b/runtime/src/iree/hal/drivers/amdgpu/util/kfd.h
@@ -22,12 +22,15 @@
 // platform).
 
 // Tries to open /dev/kfd for read/write access.
+//
+// Returns a non-negative file descriptor in |out_fd| on success and -1 on
+// failure or unsupported platforms.
 // It should be exceptionally rare that this fails: if HSA has already been
 // initialized successfully the only expected failure condition would be file
 // handle exhaustion.
 iree_status_t iree_hal_amdgpu_kfd_open(int* out_fd);
 
-// Closes an open /dev/kfd handle.
+// Closes an open /dev/kfd handle. Ignores the -1 invalid sentinel.
 void iree_hal_amdgpu_kfd_close(int fd);
 
 // Interrupt-tolerant ioctl.
@@ -39,18 +42,26 @@
 // Tracking for adding AMDKFD_IOC_GET_CLOCK_COUNTERS to the API:
 // https://github.com/ROCm/ROCR-Runtime/issues/278
 
-typedef struct iree_hal_amdgpu_clock_counters_t {
+typedef struct iree_hal_amdgpu_kfd_clock_counters_t {
+  // GPU clock counter sampled by KFD for the requested GPU.
   uint64_t gpu_clock_counter;
+
+  // Host CPU timestamp sampled by KFD near the GPU clock read.
   uint64_t cpu_clock_counter;
+
+  // Host system clock counter sampled by KFD near the GPU clock read.
   uint64_t system_clock_counter;
+
+  // Frequency in Hz for system_clock_counter.
   uint64_t system_clock_freq;
-} iree_hal_amdgpu_clock_counters_t;
+} iree_hal_amdgpu_kfd_clock_counters_t;
 
 // Equivalent to `hsaKmtGetClockCounters` in the ROCR KMT.
 // |fd| must be an open /dev/kfd file handle.
-// |gpu_uid| must be the HSA_AMD_AGENT_INFO_DRIVER_UID of the node to query.
+// |driver_uid| must be the HSA_AMD_AGENT_INFO_DRIVER_UID of the node to query.
 iree_status_t iree_hal_amdgpu_kfd_get_clock_counters(
-    int fd, uint32_t gpu_uid, iree_hal_amdgpu_clock_counters_t* out_counters);
+    int fd, uint32_t driver_uid,
+    iree_hal_amdgpu_kfd_clock_counters_t* out_counters);
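+
+// Non-normative sketch of sampling the correlated clock domains, assuming
+// |fd| and |driver_uid| were obtained as described above:
+//   iree_hal_amdgpu_kfd_clock_counters_t counters;
+//   IREE_RETURN_IF_ERROR(
+//       iree_hal_amdgpu_kfd_get_clock_counters(fd, driver_uid, &counters));
+//   // system_clock_counter ticks at system_clock_freq Hz; pairing it with
+//   // gpu_clock_counter yields a host/GPU correlation point for profiling.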
 
 #ifdef __cplusplus
 }  // extern "C"
diff --git a/runtime/src/iree/hal/drivers/amdgpu/util/kfd_test.cc b/runtime/src/iree/hal/drivers/amdgpu/util/kfd_test.cc
index 662a73b..7f87f1f 100644
--- a/runtime/src/iree/hal/drivers/amdgpu/util/kfd_test.cc
+++ b/runtime/src/iree/hal/drivers/amdgpu/util/kfd_test.cc
@@ -12,9 +12,78 @@
 #include "iree/testing/gtest.h"
 #include "iree/testing/status_matchers.h"
 
+#if defined(IREE_PLATFORM_LINUX)
+#include <errno.h>
+#include <fcntl.h>
+#include <unistd.h>
+#endif  // IREE_PLATFORM_LINUX
+
 namespace iree::hal::amdgpu {
 namespace {
 
+#if defined(IREE_PLATFORM_LINUX)
+class StdinRestorer {
+ public:
+  StdinRestorer() : saved_stdin_(dup(STDIN_FILENO)) {}
+
+  ~StdinRestorer() {
+    if (saved_stdin_ >= 0) {
+      dup2(saved_stdin_, STDIN_FILENO);
+      close(saved_stdin_);
+    }
+  }
+
+  bool ok() const { return saved_stdin_ >= 0; }
+
+ private:
+  int saved_stdin_ = -1;
+};
+#endif  // IREE_PLATFORM_LINUX
+
+TEST(KFDStandaloneTest, GetClockCountersFailsForInvalidDescriptor) {
+  iree_hal_amdgpu_kfd_clock_counters_t counters = {
+      /*.gpu_clock_counter=*/1,
+      /*.cpu_clock_counter=*/2,
+      /*.system_clock_counter=*/3,
+      /*.system_clock_freq=*/4,
+  };
+
+#if defined(IREE_PLATFORM_LINUX)
+  IREE_EXPECT_STATUS_IS(
+      IREE_STATUS_INVALID_ARGUMENT,
+      iree_hal_amdgpu_kfd_get_clock_counters(-1, 1234, &counters));
+#else
+  IREE_EXPECT_STATUS_IS(
+      IREE_STATUS_UNIMPLEMENTED,
+      iree_hal_amdgpu_kfd_get_clock_counters(-1, 1234, &counters));
+#endif  // IREE_PLATFORM_LINUX
+
+  EXPECT_EQ(counters.gpu_clock_counter, 0);
+  EXPECT_EQ(counters.cpu_clock_counter, 0);
+  EXPECT_EQ(counters.system_clock_counter, 0);
+  EXPECT_EQ(counters.system_clock_freq, 0);
+}
+
+TEST(KFDStandaloneTest, CloseTreatsFdZeroAsValid) {
+#if defined(IREE_PLATFORM_LINUX)
+  StdinRestorer stdin_restorer;
+  ASSERT_TRUE(stdin_restorer.ok());
+
+  int pipe_fds[2] = {-1, -1};
+  ASSERT_EQ(0, pipe(pipe_fds));
+  ASSERT_EQ(STDIN_FILENO, dup2(pipe_fds[0], STDIN_FILENO));
+  close(pipe_fds[0]);
+  close(pipe_fds[1]);
+
+  iree_hal_amdgpu_kfd_close(STDIN_FILENO);
+  errno = 0;
+  EXPECT_EQ(-1, fcntl(STDIN_FILENO, F_GETFD));
+  EXPECT_EQ(EBADF, errno);
+#else
+  iree_hal_amdgpu_kfd_close(0);
+#endif  // IREE_PLATFORM_LINUX
+}
+
 // NOTE: ROCR also opens the KFD - if it initializes then we're likely to
 // succeed as well. We need information we can only get from HSA to make the
 // ioctls so we have to setup a full topology here.
@@ -31,7 +100,7 @@
         host_allocator, &libhsa);
     if (!iree_status_is_ok(status)) {
       iree_status_fprint(stderr, status);
-      iree_status_ignore(status);
+      iree_status_free(status);
       GTEST_SKIP() << "HSA not available, skipping tests";
     }
     IREE_ASSERT_OK(
@@ -53,7 +122,7 @@
         IREE_LIBHSA(&libhsa), topology.gpu_agents[0],
         (hsa_agent_info_t)HSA_AMD_AGENT_INFO_DRIVER_UID, &gpu_uid);
     if (!iree_status_is_ok(status)) {
-      iree_status_ignore(status);
+      iree_status_free(status);
       return std::nullopt;
     }
     return gpu_uid;
@@ -64,17 +133,18 @@
 iree_hal_amdgpu_topology_t KFDTest::topology;
 
 // Tests opening and closing the KFD. It should not crash/leak/etc.
-// Note that we currently no-op the helpers on non-Linux platforms and this
-// will always succeed.
 TEST_F(KFDTest, Lifetime) {
   int kfd = -1;
   IREE_ASSERT_OK(iree_hal_amdgpu_kfd_open(&kfd));
+#if defined(IREE_PLATFORM_LINUX)
+  ASSERT_GE(kfd, 0);
+#else
+  ASSERT_EQ(kfd, -1);
+#endif  // IREE_PLATFORM_LINUX
   iree_hal_amdgpu_kfd_close(kfd);
 }
 
 // Tests that we get non-zero counters from the clock.
-// We always make the call but only expect non-zero if the returned kfd fd is
-// not 0 (the special value for "not a real fd" we use on non-Linux).
 TEST_F(KFDTest, GetClockCounters) {
   // Find a GPU ID we can use to make the ioctl. If we can't find one we skip
   // the test.
@@ -87,17 +157,15 @@
   int kfd = -1;
   IREE_ASSERT_OK(iree_hal_amdgpu_kfd_open(&kfd));
 
-  iree_hal_amdgpu_clock_counters_t counters = {0};
+  iree_hal_amdgpu_kfd_clock_counters_t counters = {0};
   IREE_ASSERT_OK(
       iree_hal_amdgpu_kfd_get_clock_counters(kfd, *gpu_uid, &counters));
 
-  if (kfd != 0) {
-    // Don't care about the values, just that they were populated.
-    ASSERT_NE(counters.gpu_clock_counter, 0);
-    ASSERT_NE(counters.cpu_clock_counter, 0);
-    ASSERT_NE(counters.system_clock_counter, 0);
-    ASSERT_NE(counters.system_clock_freq, 0);
-  }
+  // Don't care about the values, just that they were populated.
+  ASSERT_NE(counters.gpu_clock_counter, 0);
+  ASSERT_NE(counters.cpu_clock_counter, 0);
+  ASSERT_NE(counters.system_clock_counter, 0);
+  ASSERT_NE(counters.system_clock_freq, 0);
 
   iree_hal_amdgpu_kfd_close(kfd);
 }
diff --git a/runtime/src/iree/hal/drivers/amdgpu/util/libaqlprofile.c b/runtime/src/iree/hal/drivers/amdgpu/util/libaqlprofile.c
new file mode 100644
index 0000000..1ee198f
--- /dev/null
+++ b/runtime/src/iree/hal/drivers/amdgpu/util/libaqlprofile.c
@@ -0,0 +1,470 @@
+// Copyright 2026 The IREE Authors
+//
+// Licensed under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+
+#include "iree/hal/drivers/amdgpu/util/libaqlprofile.h"
+
+#include "iree/base/internal/dynamic_library.h"
+#include "iree/base/internal/path.h"
+#include "third_party/hsa-runtime-headers/include/aqlprofile-sdk/aql_profile_v2.h"
+
+// Keep vendor aqlprofile headers private to this translation unit. The public
+// wrapper mirrors only the narrow ABI surface used by the HAL so callers do not
+// inherit a large global vendor namespace.
+IREE_AMDGPU_STATIC_ASSERT(sizeof(iree_hal_amdgpu_aqlprofile_handle_t) ==
+                              sizeof(aqlprofile_handle_t),
+                          "aqlprofile handle layout must match the v2 SDK");
+IREE_AMDGPU_STATIC_ASSERT(sizeof(iree_hal_amdgpu_aqlprofile_version_t) ==
+                              sizeof(aqlprofile_version_t),
+                          "aqlprofile version layout must match the v2 SDK");
+IREE_AMDGPU_STATIC_ASSERT(
+    sizeof(iree_hal_amdgpu_aqlprofile_buffer_desc_flags_t) ==
+        sizeof(aqlprofile_buffer_desc_flags_t),
+    "aqlprofile buffer flags layout must match the v2 SDK");
+IREE_AMDGPU_STATIC_ASSERT(
+    sizeof(iree_hal_amdgpu_aqlprofile_att_parameter_t) ==
+        sizeof(aqlprofile_att_parameter_t),
+    "aqlprofile ATT parameter layout must match the v2 SDK");
+IREE_AMDGPU_STATIC_ASSERT(
+    sizeof(iree_hal_amdgpu_aqlprofile_att_profile_t) ==
+        sizeof(aqlprofile_att_profile_t),
+    "aqlprofile ATT profile layout must match the v2 SDK");
+IREE_AMDGPU_STATIC_ASSERT(
+    sizeof(iree_hal_amdgpu_aqlprofile_pmc_event_flags_t) ==
+        sizeof(aqlprofile_pmc_event_flags_t),
+    "aqlprofile PMC event flags layout must match the v2 SDK");
+IREE_AMDGPU_STATIC_ASSERT(sizeof(iree_hal_amdgpu_aqlprofile_pmc_event_t) ==
+                              sizeof(aqlprofile_pmc_event_t),
+                          "aqlprofile PMC event layout must match the v2 SDK");
+IREE_AMDGPU_STATIC_ASSERT(sizeof(iree_hal_amdgpu_aqlprofile_agent_info_v1_t) ==
+                              sizeof(aqlprofile_agent_info_v1_t),
+                          "aqlprofile agent info layout must match the v2 SDK");
+IREE_AMDGPU_STATIC_ASSERT(
+    sizeof(iree_hal_amdgpu_aqlprofile_agent_handle_t) ==
+        sizeof(aqlprofile_agent_handle_t),
+    "aqlprofile agent handle layout must match the v2 SDK");
+IREE_AMDGPU_STATIC_ASSERT(
+    sizeof(iree_hal_amdgpu_aqlprofile_pmc_profile_t) ==
+        sizeof(aqlprofile_pmc_profile_t),
+    "aqlprofile PMC profile layout must match the v2 SDK");
+IREE_AMDGPU_STATIC_ASSERT(
+    sizeof(iree_hal_amdgpu_aqlprofile_pmc_aql_packets_t) ==
+        sizeof(aqlprofile_pmc_aql_packets_t),
+    "aqlprofile PMC AQL packet bundle layout must match the v2 SDK");
+IREE_AMDGPU_STATIC_ASSERT(
+    sizeof(iree_hal_amdgpu_aqlprofile_att_control_aql_packets_t) ==
+        sizeof(aqlprofile_att_control_aql_packets_t),
+    "aqlprofile ATT AQL packet bundle layout must match the v2 SDK");
+IREE_AMDGPU_STATIC_ASSERT(
+    sizeof(iree_hal_amdgpu_aqlprofile_att_code_object_data_t) ==
+        sizeof(aqlprofile_att_codeobj_data_t),
+    "aqlprofile ATT code object marker layout must match the v2 SDK");
+IREE_AMDGPU_STATIC_ASSERT(
+    sizeof(iree_hsa_amd_aql_pm4_ib_packet_t) ==
+        sizeof(hsa_ext_amd_aql_pm4_packet_t),
+    "PM4-IB AQL packet layout must match the HSA vendor SDK");
+IREE_AMDGPU_STATIC_ASSERT((uint32_t)IREE_HAL_AMDGPU_AQLPROFILE_BLOCK_NAME_SQ ==
+                              (uint32_t)HSA_VEN_AMD_AQLPROFILE_BLOCK_NAME_SQ,
+                          "SQ block id must match the HSA vendor SDK");
+IREE_AMDGPU_STATIC_ASSERT(
+    (uint32_t)IREE_HAL_AMDGPU_AQLPROFILE_PARAMETER_NAME_COMPUTE_UNIT_TARGET ==
+        (uint32_t)HSA_VEN_AMD_AQLPROFILE_PARAMETER_NAME_COMPUTE_UNIT_TARGET,
+    "ATT compute-unit parameter id must match the HSA vendor SDK");
+IREE_AMDGPU_STATIC_ASSERT(
+    (uint32_t)IREE_HAL_AMDGPU_AQLPROFILE_PARAMETER_NAME_SE_MASK ==
+        (uint32_t)HSA_VEN_AMD_AQLPROFILE_PARAMETER_NAME_SE_MASK,
+    "ATT shader-engine-mask parameter id must match the HSA vendor SDK");
+IREE_AMDGPU_STATIC_ASSERT(
+    (uint32_t)IREE_HAL_AMDGPU_AQLPROFILE_PARAMETER_NAME_SIMD_SELECTION ==
+        (uint32_t)HSA_VEN_AMD_AQLPROFILE_PARAMETER_NAME_SIMD_SELECTION,
+    "ATT SIMD-selection parameter id must match the HSA vendor SDK");
+IREE_AMDGPU_STATIC_ASSERT(
+    (uint32_t)IREE_HAL_AMDGPU_AQLPROFILE_PARAMETER_NAME_ATT_BUFFER_SIZE ==
+        (uint32_t)HSA_VEN_AMD_AQLPROFILE_PARAMETER_NAME_ATT_BUFFER_SIZE,
+    "ATT buffer-size parameter id must match the HSA vendor SDK");
+IREE_AMDGPU_STATIC_ASSERT(
+    (uint32_t)IREE_HAL_AMDGPU_AQLPROFILE_ATT_PARAMETER_NAME_BUFFER_SIZE_HIGH ==
+        (uint32_t)AQLPROFILE_ATT_PARAMETER_NAME_BUFFER_SIZE_HIGH,
+    "ATT buffer-size-high parameter id must match the v2 SDK");
+IREE_AMDGPU_STATIC_ASSERT(
+    (uint32_t)IREE_HAL_AMDGPU_AQLPROFILE_ATT_PARAMETER_NAME_RT_TIMESTAMP ==
+        (uint32_t)AQLPROFILE_ATT_PARAMETER_NAME_RT_TIMESTAMP,
+    "ATT runtime-timestamp parameter id must match the v2 SDK");
+
+enum {
+  IREE_HAL_AMDGPU_AQLPROFILE_SUPPORTED_MAJOR_VERSION = AQLPROFILE_VERSION_MAJOR,
+};
+
+//===----------------------------------------------------------------------===//
+// Utilities
+//===----------------------------------------------------------------------===//
+
+static const char* iree_hal_amdgpu_libaqlprofile_names[] = {
+#if defined(IREE_PLATFORM_WINDOWS)
+    "hsa-amd-aqlprofile64.dll",
+#else
+    // Versioned soname first: this is present in normal runtime packages. The
+    // unversioned .so is usually only present in development installs.
+    "libhsa-amd-aqlprofile64.so.1",
+    "libhsa-amd-aqlprofile64.so",
+#endif  // IREE_PLATFORM_WINDOWS
+};
+
+static iree_status_t iree_hal_amdgpu_libaqlprofile_load_symbols(
+    iree_dynamic_library_t* library,
+    iree_hal_amdgpu_libaqlprofile_t* out_libaqlprofile) {
+#define IREE_HAL_AMDGPU_LIBAQLPROFILE_LOOKUP(symbol)       \
+  IREE_RETURN_IF_ERROR(iree_dynamic_library_lookup_symbol( \
+      library, #symbol, (void**)&out_libaqlprofile->symbol))
+
+  IREE_HAL_AMDGPU_LIBAQLPROFILE_LOOKUP(aqlprofile_get_version);
+  IREE_HAL_AMDGPU_LIBAQLPROFILE_LOOKUP(aqlprofile_register_agent_info);
+  IREE_HAL_AMDGPU_LIBAQLPROFILE_LOOKUP(aqlprofile_validate_pmc_event);
+  IREE_HAL_AMDGPU_LIBAQLPROFILE_LOOKUP(aqlprofile_pmc_create_packets);
+  IREE_HAL_AMDGPU_LIBAQLPROFILE_LOOKUP(aqlprofile_pmc_delete_packets);
+  IREE_HAL_AMDGPU_LIBAQLPROFILE_LOOKUP(aqlprofile_pmc_iterate_data);
+
+  out_libaqlprofile->aqlprofile_att_create_packets = (hsa_status_t(HSA_API*)(
+      iree_hal_amdgpu_aqlprofile_handle_t*,
+      iree_hal_amdgpu_aqlprofile_att_control_aql_packets_t*,
+      iree_hal_amdgpu_aqlprofile_att_profile_t,
+      iree_hal_amdgpu_aqlprofile_memory_alloc_callback_t,
+      iree_hal_amdgpu_aqlprofile_memory_dealloc_callback_t,
+      iree_hal_amdgpu_aqlprofile_memory_copy_callback_t, void*))
+      iree_dynamic_library_try_lookup_symbol(library,
+                                             "aqlprofile_att_create_packets");
+  out_libaqlprofile->aqlprofile_att_delete_packets =
+      (void(HSA_API*)(iree_hal_amdgpu_aqlprofile_handle_t))
+          iree_dynamic_library_try_lookup_symbol(
+              library, "aqlprofile_att_delete_packets");
+  out_libaqlprofile->aqlprofile_att_iterate_data = (hsa_status_t(HSA_API*)(
+      iree_hal_amdgpu_aqlprofile_handle_t,
+      iree_hal_amdgpu_aqlprofile_att_data_callback_t, void*))
+      iree_dynamic_library_try_lookup_symbol(library,
+                                             "aqlprofile_att_iterate_data");
+  out_libaqlprofile->aqlprofile_att_codeobj_marker = (hsa_status_t(HSA_API*)(
+      iree_hsa_amd_aql_pm4_ib_packet_t*, iree_hal_amdgpu_aqlprofile_handle_t*,
+      iree_hal_amdgpu_aqlprofile_att_code_object_data_t,
+      iree_hal_amdgpu_aqlprofile_memory_alloc_callback_t,
+      iree_hal_amdgpu_aqlprofile_memory_dealloc_callback_t, void*))
+      iree_dynamic_library_try_lookup_symbol(library,
+                                             "aqlprofile_att_codeobj_marker");
+
+  out_libaqlprofile->hsa_ven_amd_aqlprofile_error_string =
+      (hsa_status_t(HSA_API*)(const char**))
+          iree_dynamic_library_try_lookup_symbol(
+              library, "hsa_ven_amd_aqlprofile_error_string");
+
+#undef IREE_HAL_AMDGPU_LIBAQLPROFILE_LOOKUP
+
+  return iree_ok_status();
+}
+
+static bool iree_hal_amdgpu_libaqlprofile_version_is_supported(
+    iree_hal_amdgpu_aqlprofile_version_t version) {
+  return version.major == IREE_HAL_AMDGPU_AQLPROFILE_SUPPORTED_MAJOR_VERSION;
+}
+
+static bool iree_hal_amdgpu_libaqlprofile_has_queried_version(
+    const iree_hal_amdgpu_libaqlprofile_t* libaqlprofile) {
+  return libaqlprofile &&
+         (libaqlprofile->version.major || libaqlprofile->version.minor ||
+          libaqlprofile->version.patch);
+}
+
+static iree_status_t iree_hal_amdgpu_libaqlprofile_query_version(
+    iree_hal_amdgpu_libaqlprofile_t* libaqlprofile) {
+  iree_hal_amdgpu_aqlprofile_version_t version = {0};
+  hsa_status_t hsa_status = libaqlprofile->aqlprofile_get_version(&version);
+  IREE_RETURN_IF_ERROR(iree_status_from_aqlprofile_status(
+      libaqlprofile, __FILE__, __LINE__, hsa_status, "aqlprofile_get_version",
+      "querying AMDGPU aqlprofile runtime version"));
+
+  libaqlprofile->version = version;
+  if (!iree_hal_amdgpu_libaqlprofile_version_is_supported(version)) {
+    return iree_make_status(
+        IREE_STATUS_FAILED_PRECONDITION,
+        "unsupported AMDGPU aqlprofile runtime version %u.%u.%u; expected "
+        "major version %u matching the SDK headers used to build IREE",
+        version.major, version.minor, version.patch,
+        IREE_HAL_AMDGPU_AQLPROFILE_SUPPORTED_MAJOR_VERSION);
+  }
+  return iree_ok_status();
+}
+
+static iree_status_t iree_hal_amdgpu_libaqlprofile_try_load_library_from_file(
+    const char* file_path, iree_string_builder_t* error_builder,
+    iree_allocator_t host_allocator, iree_dynamic_library_t** out_library) {
+  IREE_ASSERT_ARGUMENT(out_library);
+  IREE_TRACE_ZONE_BEGIN(z0);
+  IREE_TRACE_ZONE_APPEND_TEXT(z0, file_path);
+  *out_library = NULL;
+
+  iree_status_t status = iree_dynamic_library_load_from_file(
+      file_path, IREE_DYNAMIC_LIBRARY_FLAG_NONE, host_allocator, out_library);
+  if (!iree_status_is_ok(status)) {
+    iree_status_t load_status = status;
+    status = iree_string_builder_append_format(
+        error_builder, "\n  Tried: %s\n    ", file_path);
+    if (iree_status_is_ok(status)) {
+      status = iree_string_builder_append_status(error_builder, load_status);
+    }
+    iree_status_free(load_status);
+  }
+
+  IREE_TRACE_ZONE_END(z0);
+  return status;
+}
+
+static iree_status_t iree_hal_amdgpu_libaqlprofile_try_load_library_from_path(
+    iree_string_view_t path_fragment, iree_string_builder_t* error_builder,
+    iree_allocator_t host_allocator, iree_dynamic_library_t** out_library) {
+  IREE_ASSERT_ARGUMENT(out_library);
+  IREE_TRACE_ZONE_BEGIN(z0);
+  IREE_TRACE_ZONE_APPEND_TEXT(z0, path_fragment.data, path_fragment.size);
+  *out_library = NULL;
+
+  iree_string_builder_t path_builder;
+  iree_string_builder_initialize(host_allocator, &path_builder);
+  iree_status_t status = iree_ok_status();
+
+  if (iree_file_path_is_dynamic_library(path_fragment)) {
+    status = iree_string_builder_append_string(&path_builder, path_fragment);
+    if (iree_status_is_ok(status)) {
+      status = iree_hal_amdgpu_libaqlprofile_try_load_library_from_file(
+          iree_string_builder_buffer(&path_builder), error_builder,
+          host_allocator, out_library);
+    }
+  } else {
+    for (iree_host_size_t i = 0;
+         iree_status_is_ok(status) &&
+         i < IREE_ARRAYSIZE(iree_hal_amdgpu_libaqlprofile_names) &&
+         !*out_library;
+         ++i) {
+      iree_string_builder_reset(&path_builder);
+      status = iree_string_builder_append_format(
+          &path_builder, "%.*s/%s", (int)path_fragment.size, path_fragment.data,
+          iree_hal_amdgpu_libaqlprofile_names[i]);
+      if (iree_status_is_ok(status)) {
+        path_builder.size = iree_file_path_canonicalize(
+            (char*)iree_string_builder_buffer(&path_builder),
+            iree_string_builder_size(&path_builder));
+        status = iree_hal_amdgpu_libaqlprofile_try_load_library_from_file(
+            iree_string_builder_buffer(&path_builder), error_builder,
+            host_allocator, out_library);
+      }
+    }
+  }
+
+  iree_string_builder_deinitialize(&path_builder);
+
+  IREE_TRACE_ZONE_END(z0);
+  return status;
+}
+
+static iree_status_t iree_hal_amdgpu_libaqlprofile_try_load_adjacent_to_libhsa(
+    const iree_hal_amdgpu_libhsa_t* libhsa,
+    iree_string_builder_t* error_builder, iree_allocator_t host_allocator,
+    iree_dynamic_library_t** out_library) {
+  IREE_ASSERT_ARGUMENT(out_library);
+  *out_library = NULL;
+
+  iree_string_builder_t path_builder;
+  iree_string_builder_initialize(host_allocator, &path_builder);
+  iree_status_t status =
+      iree_hal_amdgpu_libhsa_append_path_to_builder(libhsa, &path_builder);
+  if (iree_status_is_ok(status)) {
+    iree_string_view_t libhsa_path = iree_string_builder_view(&path_builder);
+    iree_string_view_t libhsa_dirname = iree_file_path_dirname(libhsa_path);
+    if (!iree_string_view_is_empty(libhsa_dirname)) {
+      status = iree_hal_amdgpu_libaqlprofile_try_load_library_from_path(
+          libhsa_dirname, error_builder, host_allocator, out_library);
+    }
+  }
+  iree_string_builder_deinitialize(&path_builder);
+  return status;
+}
+
+//===----------------------------------------------------------------------===//
+// iree_hal_amdgpu_libaqlprofile_t
+//===----------------------------------------------------------------------===//
+
+iree_status_t iree_hal_amdgpu_libaqlprofile_initialize(
+    const iree_hal_amdgpu_libhsa_t* libhsa,
+    iree_string_view_list_t search_paths, iree_allocator_t host_allocator,
+    iree_hal_amdgpu_libaqlprofile_t* out_libaqlprofile) {
+  IREE_ASSERT_ARGUMENT(libhsa);
+  IREE_ASSERT_ARGUMENT(out_libaqlprofile);
+  IREE_TRACE_ZONE_BEGIN(z0);
+  memset(out_libaqlprofile, 0, sizeof(*out_libaqlprofile));
+
+  iree_string_builder_t error_builder;
+  iree_string_builder_initialize(host_allocator, &error_builder);
+
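+  // Search order: caller-provided |search_paths| first, then the
+  // IREE_HAL_AMDGPU_LIBAQLPROFILE_PATH environment variable, then the
+  // directory containing the loaded libhsa, and finally the default system
+  // loader paths via the bare library names.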
+  iree_dynamic_library_t* library = NULL;
+  iree_status_t status = iree_ok_status();
+  for (iree_host_size_t i = 0;
+       iree_status_is_ok(status) && i < search_paths.count && !library; ++i) {
+    status = iree_hal_amdgpu_libaqlprofile_try_load_library_from_path(
+        search_paths.values[i], &error_builder, host_allocator, &library);
+  }
+
+  // getenv may return NULL for an unset variable; guard before wrapping it in
+  // a string view.
+  const char* env_path_cstr = getenv("IREE_HAL_AMDGPU_LIBAQLPROFILE_PATH");
+  iree_string_view_t env_path =
+      env_path_cstr ? iree_make_cstring_view(env_path_cstr)
+                    : iree_string_view_empty();
+  if (iree_status_is_ok(status) && !library &&
+      !iree_string_view_is_empty(env_path)) {
+    status = iree_hal_amdgpu_libaqlprofile_try_load_library_from_path(
+        env_path, &error_builder, host_allocator, &library);
+  }
+
+  if (iree_status_is_ok(status) && !library) {
+    status = iree_hal_amdgpu_libaqlprofile_try_load_adjacent_to_libhsa(
+        libhsa, &error_builder, host_allocator, &library);
+  }
+  if (iree_status_is_ok(status) && !library) {
+    for (iree_host_size_t i = 0;
+         iree_status_is_ok(status) &&
+         i < IREE_ARRAYSIZE(iree_hal_amdgpu_libaqlprofile_names) && !library;
+         ++i) {
+      status = iree_hal_amdgpu_libaqlprofile_try_load_library_from_file(
+          iree_hal_amdgpu_libaqlprofile_names[i], &error_builder,
+          host_allocator, &library);
+    }
+  }
+
+  if (iree_status_is_ok(status) && !library) {
+    status = iree_make_status(
+        IREE_STATUS_NOT_FOUND,
+        "AMDGPU aqlprofile library not found; hardware counter profiling "
+        "requires libhsa-amd-aqlprofile64 to be installed next to HSA, on a "
+        "system search path, or specified with "
+        "IREE_HAL_AMDGPU_LIBAQLPROFILE_PATH: %.*s",
+        (int)iree_string_builder_size(&error_builder),
+        iree_string_builder_buffer(&error_builder));
+  }
+  iree_string_builder_deinitialize(&error_builder);
+
+  if (iree_status_is_ok(status)) {
+    status =
+        iree_hal_amdgpu_libaqlprofile_load_symbols(library, out_libaqlprofile);
+  }
+  if (iree_status_is_ok(status)) {
+    status = iree_hal_amdgpu_libaqlprofile_query_version(out_libaqlprofile);
+  }
+  if (iree_status_is_ok(status)) {
+    out_libaqlprofile->library = library;
+  } else {
+    iree_dynamic_library_release(library);
+    memset(out_libaqlprofile, 0, sizeof(*out_libaqlprofile));
+  }
+
+  IREE_TRACE_ZONE_END(z0);
+  return status;
+}
+
+void iree_hal_amdgpu_libaqlprofile_deinitialize(
+    iree_hal_amdgpu_libaqlprofile_t* libaqlprofile) {
+  IREE_ASSERT_ARGUMENT(libaqlprofile);
+  iree_dynamic_library_release(libaqlprofile->library);
+  memset(libaqlprofile, 0, sizeof(*libaqlprofile));
+}
+
+bool iree_hal_amdgpu_libaqlprofile_has_att_support(
+    const iree_hal_amdgpu_libaqlprofile_t* libaqlprofile) {
+  return libaqlprofile && libaqlprofile->aqlprofile_att_create_packets &&
+         libaqlprofile->aqlprofile_att_delete_packets &&
+         libaqlprofile->aqlprofile_att_iterate_data &&
+         libaqlprofile->aqlprofile_att_codeobj_marker;
+}
+
+static void iree_hal_amdgpu_libaqlprofile_append_missing_symbol(
+    const char* symbol_name, bool is_missing, char* buffer,
+    iree_host_size_t buffer_capacity, iree_host_size_t* inout_length) {
+  if (!is_missing || *inout_length >= buffer_capacity) return;
+
+  const char* separator = *inout_length ? ", " : "";
+  const int written =
+      iree_snprintf(buffer + *inout_length, buffer_capacity - *inout_length,
+                    "%s%s", separator, symbol_name);
+  if (written <= 0) return;
+
+  const iree_host_size_t available = buffer_capacity - *inout_length;
+  if ((iree_host_size_t)written >= available) {
+    *inout_length = buffer_capacity - 1;
+  } else {
+    *inout_length += (iree_host_size_t)written;
+  }
+}
+
+iree_status_t iree_hal_amdgpu_libaqlprofile_require_att_support(
+    const iree_hal_amdgpu_libaqlprofile_t* libaqlprofile,
+    const char* context_message) {
+  if (iree_hal_amdgpu_libaqlprofile_has_att_support(libaqlprofile)) {
+    return iree_ok_status();
+  }
+
+  char missing_symbols[256] = {0};
+  iree_host_size_t missing_symbols_length = 0;
+  iree_hal_amdgpu_libaqlprofile_append_missing_symbol(
+      "aqlprofile_att_create_packets",
+      !libaqlprofile || !libaqlprofile->aqlprofile_att_create_packets,
+      missing_symbols, sizeof(missing_symbols), &missing_symbols_length);
+  iree_hal_amdgpu_libaqlprofile_append_missing_symbol(
+      "aqlprofile_att_delete_packets",
+      !libaqlprofile || !libaqlprofile->aqlprofile_att_delete_packets,
+      missing_symbols, sizeof(missing_symbols), &missing_symbols_length);
+  iree_hal_amdgpu_libaqlprofile_append_missing_symbol(
+      "aqlprofile_att_iterate_data",
+      !libaqlprofile || !libaqlprofile->aqlprofile_att_iterate_data,
+      missing_symbols, sizeof(missing_symbols), &missing_symbols_length);
+  iree_hal_amdgpu_libaqlprofile_append_missing_symbol(
+      "aqlprofile_att_codeobj_marker",
+      !libaqlprofile || !libaqlprofile->aqlprofile_att_codeobj_marker,
+      missing_symbols, sizeof(missing_symbols), &missing_symbols_length);
+
+  const iree_hal_amdgpu_aqlprofile_version_t version =
+      libaqlprofile ? libaqlprofile->version
+                    : (iree_hal_amdgpu_aqlprofile_version_t){0};
+  return iree_make_status(
+      IREE_STATUS_UNIMPLEMENTED,
+      "loaded AMDGPU aqlprofile library version %u.%u.%u does not export "
+      "required ATT/SQTT symbol(s): %s%s%s",
+      version.major, version.minor, version.patch, missing_symbols,
+      context_message ? ": " : "", context_message ? context_message : "");
+}
+
+iree_status_t iree_status_from_aqlprofile_status(
+    const iree_hal_amdgpu_libaqlprofile_t* libaqlprofile, const char* file,
+    uint32_t line, hsa_status_t hsa_status, const char* symbol,
+    const char* message) {
+  if (hsa_status == HSA_STATUS_SUCCESS) return iree_ok_status();
+
+  const char* error_string = NULL;
+  if (libaqlprofile && libaqlprofile->hsa_ven_amd_aqlprofile_error_string) {
+    hsa_status_t error_string_status =
+        libaqlprofile->hsa_ven_amd_aqlprofile_error_string(&error_string);
+    if (error_string_status != HSA_STATUS_SUCCESS) {
+      error_string = NULL;
+    }
+  }
+  if (!error_string) {
+    error_string = "unknown aqlprofile error";
+  }
+
+  if (iree_hal_amdgpu_libaqlprofile_has_queried_version(libaqlprofile)) {
+    return iree_make_status_with_location(
+        file, line, IREE_STATUS_INTERNAL,
+        "%s failed with hsa_status=0x%08X (%s) from AMDGPU aqlprofile runtime "
+        "version %u.%u.%u%s%s",
+        symbol, (uint32_t)hsa_status, error_string,
+        libaqlprofile->version.major, libaqlprofile->version.minor,
+        libaqlprofile->version.patch, message ? ": " : "",
+        message ? message : "");
+  }
+  return iree_make_status_with_location(
+      file, line, IREE_STATUS_INTERNAL,
+      "%s failed with hsa_status=0x%08X (%s)%s%s", symbol, (uint32_t)hsa_status,
+      error_string, message ? ": " : "", message ? message : "");
+}
diff --git a/runtime/src/iree/hal/drivers/amdgpu/util/libaqlprofile.h b/runtime/src/iree/hal/drivers/amdgpu/util/libaqlprofile.h
new file mode 100644
index 0000000..2fd1f7e
--- /dev/null
+++ b/runtime/src/iree/hal/drivers/amdgpu/util/libaqlprofile.h
@@ -0,0 +1,364 @@
+// Copyright 2026 The IREE Authors
+//
+// Licensed under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+
+#ifndef IREE_HAL_DRIVERS_AMDGPU_UTIL_LIBAQLPROFILE_H_
+#define IREE_HAL_DRIVERS_AMDGPU_UTIL_LIBAQLPROFILE_H_
+
+#include "iree/hal/drivers/amdgpu/abi/queue.h"
+#include "iree/hal/drivers/amdgpu/util/libhsa.h"
+
+#ifdef __cplusplus
+extern "C" {
+#endif  // __cplusplus
+
+typedef struct iree_dynamic_library_t iree_dynamic_library_t;
+
+//===----------------------------------------------------------------------===//
+// aqlprofile SDK ABI subset
+//===----------------------------------------------------------------------===//
+
+typedef struct iree_hal_amdgpu_aqlprofile_handle_t {
+  // Opaque handle owned by the aqlprofile runtime.
+  uint64_t handle;
+} iree_hal_amdgpu_aqlprofile_handle_t;
+
+// Mirrors aqlprofile_version_t from the ROCm aqlprofile SDK v2 ABI. The
+// runtime wrapper validates the loaded library against this ABI before exposing
+// packet generation entry points.
+typedef struct iree_hal_amdgpu_aqlprofile_version_t {
+  // Major aqlprofile runtime version.
+  uint32_t major;
+  // Minor aqlprofile runtime version.
+  uint32_t minor;
+  // Patch aqlprofile runtime version.
+  uint32_t patch;
+} iree_hal_amdgpu_aqlprofile_version_t;
+
+typedef enum iree_hal_amdgpu_aqlprofile_memory_hint_e {
+  IREE_HAL_AMDGPU_AQLPROFILE_MEMORY_HINT_NONE = 0,
+  IREE_HAL_AMDGPU_AQLPROFILE_MEMORY_HINT_HOST = 1,
+  IREE_HAL_AMDGPU_AQLPROFILE_MEMORY_HINT_DEVICE_UNCACHED = 2,
+  IREE_HAL_AMDGPU_AQLPROFILE_MEMORY_HINT_DEVICE_COHERENT = 3,
+  IREE_HAL_AMDGPU_AQLPROFILE_MEMORY_HINT_DEVICE_NONCOHERENT = 4,
+} iree_hal_amdgpu_aqlprofile_memory_hint_t;
+
+typedef enum iree_hal_amdgpu_aqlprofile_agent_version_e {
+  IREE_HAL_AMDGPU_AQLPROFILE_AGENT_VERSION_NONE = 0,
+  IREE_HAL_AMDGPU_AQLPROFILE_AGENT_VERSION_V0 = 1,
+  IREE_HAL_AMDGPU_AQLPROFILE_AGENT_VERSION_V1 = 2,
+} iree_hal_amdgpu_aqlprofile_agent_version_t;
+
+typedef uint32_t iree_hal_amdgpu_aqlprofile_parameter_name_t;
+enum iree_hal_amdgpu_aqlprofile_parameter_name_e {
+  IREE_HAL_AMDGPU_AQLPROFILE_PARAMETER_NAME_COMPUTE_UNIT_TARGET = 0,
+  IREE_HAL_AMDGPU_AQLPROFILE_PARAMETER_NAME_SE_MASK = 5,
+  IREE_HAL_AMDGPU_AQLPROFILE_PARAMETER_NAME_SIMD_SELECTION = 8,
+  IREE_HAL_AMDGPU_AQLPROFILE_PARAMETER_NAME_ATT_BUFFER_SIZE = 10,
+};
+
+typedef uint32_t iree_hal_amdgpu_aqlprofile_block_name_t;
+enum iree_hal_amdgpu_aqlprofile_block_name_e {
+  IREE_HAL_AMDGPU_AQLPROFILE_BLOCK_NAME_SQ = 6,
+};
+
+typedef union iree_hal_amdgpu_aqlprofile_buffer_desc_flags_t {
+  // Raw aqlprofile buffer descriptor flags.
+  uint32_t raw;
+  // Decoded aqlprofile buffer descriptor fields.
+  struct {
+    // True when the requested allocation must be visible to the profiled GPU.
+    uint32_t device_access : 1;
+    // True when the requested allocation must be visible to the host.
+    uint32_t host_access : 1;
+    // Requested memory placement hint.
+    uint32_t memory_hint : 6;
+    // Reserved bits reported by aqlprofile.
+    uint32_t reserved : 24;
+  };
+} iree_hal_amdgpu_aqlprofile_buffer_desc_flags_t;
+
+typedef hsa_status_t(
+    IREE_API_PTR* iree_hal_amdgpu_aqlprofile_memory_alloc_callback_t)(
+    void** ptr, uint64_t size,
+    iree_hal_amdgpu_aqlprofile_buffer_desc_flags_t flags, void* user_data);
+
+typedef void(
+    IREE_API_PTR* iree_hal_amdgpu_aqlprofile_memory_dealloc_callback_t)(
+    void* ptr, void* user_data);
+
+typedef hsa_status_t(
+    IREE_API_PTR* iree_hal_amdgpu_aqlprofile_memory_copy_callback_t)(
+    void* target, const void* source, size_t size, void* user_data);
+
+typedef enum iree_hal_amdgpu_aqlprofile_att_parameter_rt_timestamp_e {
+  IREE_HAL_AMDGPU_AQLPROFILE_ATT_PARAMETER_RT_TIMESTAMP_DEFAULT = 0,
+  IREE_HAL_AMDGPU_AQLPROFILE_ATT_PARAMETER_RT_TIMESTAMP_ENABLE = 1,
+  IREE_HAL_AMDGPU_AQLPROFILE_ATT_PARAMETER_RT_TIMESTAMP_DISABLE = 2,
+} iree_hal_amdgpu_aqlprofile_att_parameter_rt_timestamp_t;
+
+typedef enum iree_hal_amdgpu_aqlprofile_att_parameter_name_ext_e {
+  IREE_HAL_AMDGPU_AQLPROFILE_ATT_PARAMETER_NAME_BUFFER_SIZE_HIGH = 11,
+  IREE_HAL_AMDGPU_AQLPROFILE_ATT_PARAMETER_NAME_RT_TIMESTAMP = 12,
+} iree_hal_amdgpu_aqlprofile_att_parameter_name_ext_t;
+
+typedef struct iree_hal_amdgpu_aqlprofile_att_parameter_t {
+  // aqlprofile ATT parameter name or extended ATT parameter name.
+  uint32_t parameter_name;
+  union {
+    // Scalar parameter value.
+    uint32_t value;
+    struct {
+      // Performance counter event id packed into ATT perf-counter parameters.
+      uint32_t counter_id : 28;
+      // SIMD mask packed into ATT perf-counter parameters.
+      uint32_t simd_mask : 4;
+    };
+  };
+} iree_hal_amdgpu_aqlprofile_att_parameter_t;
+
+typedef struct iree_hal_amdgpu_aqlprofile_att_profile_t {
+  // HSA GPU agent being traced.
+  hsa_agent_t agent;
+  // Borrowed array of ATT trace parameters.
+  const iree_hal_amdgpu_aqlprofile_att_parameter_t* parameters;
+  // Number of entries in |parameters|.
+  uint32_t parameter_count;
+} iree_hal_amdgpu_aqlprofile_att_profile_t;
+
+typedef hsa_status_t(
+    IREE_API_PTR* iree_hal_amdgpu_aqlprofile_att_data_callback_t)(
+    uint32_t shader_engine, void* buffer, uint64_t size, void* user_data);
+
+typedef union iree_hal_amdgpu_aqlprofile_pmc_event_flags_t {
+  // Raw aqlprofile PMC event flags.
+  uint32_t raw;
+  // Decoded SQ event flag fields.
+  struct {
+    // SQ accumulation mode requested by the event.
+    uint32_t accum : 3;
+    // Reserved bits reported by aqlprofile.
+    uint32_t reserved : 25;
+    // SPM decode depth requested by the event.
+    uint32_t depth : 4;
+  } sq_flags;
+  // Decoded SPM event flag fields.
+  struct {
+    // Reserved bits reported by aqlprofile.
+    uint32_t reserved : 28;
+    // SPM decode depth requested by the event.
+    uint32_t depth : 4;
+  } spm_flags;
+} iree_hal_amdgpu_aqlprofile_pmc_event_flags_t;
+
+typedef struct iree_hal_amdgpu_aqlprofile_pmc_event_t {
+  // Hardware block instance index.
+  uint32_t block_index;
+  // Hardware event selector id within |block_name|.
+  uint32_t event_id;
+  // Event-specific flags such as accumulation mode.
+  iree_hal_amdgpu_aqlprofile_pmc_event_flags_t flags;
+  // Hardware block family containing |event_id|.
+  iree_hal_amdgpu_aqlprofile_block_name_t block_name;
+} iree_hal_amdgpu_aqlprofile_pmc_event_t;
+
+typedef struct iree_hal_amdgpu_aqlprofile_agent_info_v1_t {
+  // NUL-terminated GPU ISA name such as "gfx1100".
+  const char* agent_gfxip;
+  // Number of XCC partitions reported for the GPU.
+  uint32_t xcc_num;
+  // Number of shader engines reported for the GPU.
+  uint32_t se_num;
+  // Number of compute units reported for the GPU.
+  uint32_t cu_num;
+  // Number of shader arrays per shader engine.
+  uint32_t shader_arrays_per_se;
+  // PCI domain reported for the GPU.
+  uint32_t domain;
+  // PCI BDF location id reported for the GPU.
+  uint32_t location_id;
+} iree_hal_amdgpu_aqlprofile_agent_info_v1_t;
+
+typedef struct iree_hal_amdgpu_aqlprofile_agent_handle_t {
+  // Opaque registered-agent handle owned by the aqlprofile runtime.
+  uint64_t handle;
+} iree_hal_amdgpu_aqlprofile_agent_handle_t;
+
+typedef struct iree_hal_amdgpu_aqlprofile_pmc_profile_t {
+  // Registered aqlprofile GPU agent handle.
+  iree_hal_amdgpu_aqlprofile_agent_handle_t agent;
+  // Borrowed array of hardware PMC event requests.
+  const iree_hal_amdgpu_aqlprofile_pmc_event_t* events;
+  // Number of entries in |events|.
+  uint32_t event_count;
+} iree_hal_amdgpu_aqlprofile_pmc_profile_t;
+
+typedef hsa_status_t(
+    IREE_API_PTR* iree_hal_amdgpu_aqlprofile_pmc_data_callback_t)(
+    iree_hal_amdgpu_aqlprofile_pmc_event_t event, uint64_t counter_id,
+    uint64_t counter_value, void* user_data);
+
+typedef struct iree_hal_amdgpu_aqlprofile_pmc_aql_packets_t {
+  // AQL PM4-IB packet that resets and starts the selected counters.
+  iree_hsa_amd_aql_pm4_ib_packet_t start_packet;
+  // AQL PM4-IB packet that stops the selected counters.
+  iree_hsa_amd_aql_pm4_ib_packet_t stop_packet;
+  // AQL PM4-IB packet that reads the selected counters to output storage.
+  iree_hsa_amd_aql_pm4_ib_packet_t read_packet;
+} iree_hal_amdgpu_aqlprofile_pmc_aql_packets_t;
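+// Typical enqueue order (a sketch, not mandated by the ABI): start_packet
+// before the profiled dispatches, stop_packet after them, and read_packet
+// last so counter values land in output storage before
+// aqlprofile_pmc_iterate_data runs.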
+
+typedef struct iree_hal_amdgpu_aqlprofile_att_control_aql_packets_t {
+  // AQL PM4-IB packet that starts thread trace.
+  iree_hsa_amd_aql_pm4_ib_packet_t start_packet;
+  // AQL PM4-IB packet that stops thread trace and flushes trace metadata.
+  iree_hsa_amd_aql_pm4_ib_packet_t stop_packet;
+} iree_hal_amdgpu_aqlprofile_att_control_aql_packets_t;
+
+typedef struct iree_hal_amdgpu_aqlprofile_att_code_object_data_t {
+  // Stable code-object marker id.
+  uint64_t id;
+  // Loader-provided code-object load delta used by the ATT decoder.
+  uint64_t address;
+  // Loaded code-object range byte length.
+  uint64_t length;
+  // HSA GPU agent owning the code object.
+  hsa_agent_t agent;
+  // True when this marker records an unload instead of a load.
+  uint32_t is_unload : 1;
+  // True when the code object was loaded before thread trace started.
+  uint32_t from_start : 1;
+} iree_hal_amdgpu_aqlprofile_att_code_object_data_t;
+
+//===----------------------------------------------------------------------===//
+// iree_hal_amdgpu_libaqlprofile_t
+//===----------------------------------------------------------------------===//
+
+// Dynamically loaded libhsa-amd-aqlprofile64.so.
+//
+// Thread-safe; immutable after initialization.
+typedef struct iree_hal_amdgpu_libaqlprofile_t {
+  // Loaded aqlprofile dynamic library.
+  iree_dynamic_library_t* library;
+
+  // Loaded aqlprofile runtime version returned by aqlprofile_get_version.
+  iree_hal_amdgpu_aqlprofile_version_t version;
+
+  // Returns the aqlprofile runtime version.
+  hsa_status_t(HSA_API* aqlprofile_get_version)(
+      iree_hal_amdgpu_aqlprofile_version_t* version);
+
+  // Registers one HSA agent with the aqlprofile runtime.
+  hsa_status_t(HSA_API* aqlprofile_register_agent_info)(
+      iree_hal_amdgpu_aqlprofile_agent_handle_t* agent_id,
+      const void* agent_info,
+      iree_hal_amdgpu_aqlprofile_agent_version_t version);
+
+  // Validates that a PMC event can be collected on an agent.
+  hsa_status_t(HSA_API* aqlprofile_validate_pmc_event)(
+      iree_hal_amdgpu_aqlprofile_agent_handle_t agent,
+      const iree_hal_amdgpu_aqlprofile_pmc_event_t* event, bool* result);
+
+  // Creates persistent PM4 programs and AQL PM4-IB packet templates for one
+  // PMC profile handle.
+  hsa_status_t(HSA_API* aqlprofile_pmc_create_packets)(
+      iree_hal_amdgpu_aqlprofile_handle_t* handle,
+      iree_hal_amdgpu_aqlprofile_pmc_aql_packets_t* packets,
+      iree_hal_amdgpu_aqlprofile_pmc_profile_t profile,
+      iree_hal_amdgpu_aqlprofile_memory_alloc_callback_t alloc_cb,
+      iree_hal_amdgpu_aqlprofile_memory_dealloc_callback_t dealloc_cb,
+      iree_hal_amdgpu_aqlprofile_memory_copy_callback_t memcpy_cb,
+      void* user_data);
+
+  // Deletes PM4 programs and output buffers associated with a PMC handle.
+  void(HSA_API* aqlprofile_pmc_delete_packets)(
+      iree_hal_amdgpu_aqlprofile_handle_t handle);
+
+  // Iterates decoded PMC values from a completed profile handle.
+  hsa_status_t(HSA_API* aqlprofile_pmc_iterate_data)(
+      iree_hal_amdgpu_aqlprofile_handle_t handle,
+      iree_hal_amdgpu_aqlprofile_pmc_data_callback_t callback, void* user_data);
+
+  // Creates persistent PM4 programs and AQL PM4-IB packet templates for one
+  // ATT/SQTT trace control profile handle. Optional; use
+  // iree_hal_amdgpu_libaqlprofile_has_att_support before calling.
+  hsa_status_t(HSA_API* aqlprofile_att_create_packets)(
+      iree_hal_amdgpu_aqlprofile_handle_t* handle,
+      iree_hal_amdgpu_aqlprofile_att_control_aql_packets_t* packets,
+      iree_hal_amdgpu_aqlprofile_att_profile_t profile,
+      iree_hal_amdgpu_aqlprofile_memory_alloc_callback_t alloc_cb,
+      iree_hal_amdgpu_aqlprofile_memory_dealloc_callback_t dealloc_cb,
+      iree_hal_amdgpu_aqlprofile_memory_copy_callback_t memcpy_cb,
+      void* user_data);
+
+  // Deletes PM4 programs and buffers associated with an ATT/SQTT handle.
+  void(HSA_API* aqlprofile_att_delete_packets)(
+      iree_hal_amdgpu_aqlprofile_handle_t handle);
+
+  // Iterates decoded ATT/SQTT trace data by shader engine.
+  hsa_status_t(HSA_API* aqlprofile_att_iterate_data)(
+      iree_hal_amdgpu_aqlprofile_handle_t handle,
+      iree_hal_amdgpu_aqlprofile_att_data_callback_t callback, void* user_data);
+
+  // Creates a persistent AQL PM4-IB packet for a code-object marker.
+  hsa_status_t(HSA_API* aqlprofile_att_codeobj_marker)(
+      iree_hsa_amd_aql_pm4_ib_packet_t* packet,
+      iree_hal_amdgpu_aqlprofile_handle_t* handle,
+      iree_hal_amdgpu_aqlprofile_att_code_object_data_t data,
+      iree_hal_amdgpu_aqlprofile_memory_alloc_callback_t alloc_cb,
+      iree_hal_amdgpu_aqlprofile_memory_dealloc_callback_t dealloc_cb,
+      void* user_data);
+
+  // Returns a textual description of the most recent aqlprofile error.
+  // Optional; older libraries may not export it.
+  hsa_status_t(HSA_API* hsa_ven_amd_aqlprofile_error_string)(const char** str);
+} iree_hal_amdgpu_libaqlprofile_t;
+
+// Initializes |out_libaqlprofile| by dynamically loading the aqlprofile SDK.
+//
+// |search_paths| overrides the default library search paths and looks for the
+// canonical library file under each path before falling back to
+// IREE_HAL_AMDGPU_LIBAQLPROFILE_PATH, the directory containing the loaded HSA
+// library, and the system search paths.
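+//
+// Usage (sketch; assumes an already-initialized |libhsa|):
+//   iree_hal_amdgpu_libaqlprofile_t libaqlprofile;
+//   IREE_RETURN_IF_ERROR(iree_hal_amdgpu_libaqlprofile_initialize(
+//       &libhsa, iree_string_view_list_empty(), host_allocator,
+//       &libaqlprofile));
+//   // ... use the loaded symbols for PMC/ATT packet generation ...
+//   iree_hal_amdgpu_libaqlprofile_deinitialize(&libaqlprofile);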
+iree_status_t iree_hal_amdgpu_libaqlprofile_initialize(
+    const iree_hal_amdgpu_libhsa_t* libhsa,
+    iree_string_view_list_t search_paths, iree_allocator_t host_allocator,
+    iree_hal_amdgpu_libaqlprofile_t* out_libaqlprofile);
+
+// Deinitializes |libaqlprofile| by unloading the backing library.
+void iree_hal_amdgpu_libaqlprofile_deinitialize(
+    iree_hal_amdgpu_libaqlprofile_t* libaqlprofile);
+
+// Returns true when the loaded aqlprofile library exports the full ATT/SQTT
+// packet generation and data iteration ABI.
+bool iree_hal_amdgpu_libaqlprofile_has_att_support(
+    const iree_hal_amdgpu_libaqlprofile_t* libaqlprofile);
+
+// Returns OK if |libaqlprofile| exports the ATT/SQTT packet generation and data
+// iteration ABI, or a detailed UNIMPLEMENTED status naming missing symbols.
+iree_status_t iree_hal_amdgpu_libaqlprofile_require_att_support(
+    const iree_hal_amdgpu_libaqlprofile_t* libaqlprofile,
+    const char* context_message);
+
+// Returns an IREE status with the aqlprofile error string when available.
+iree_status_t iree_status_from_aqlprofile_status(
+    const iree_hal_amdgpu_libaqlprofile_t* libaqlprofile, const char* file,
+    uint32_t line, hsa_status_t hsa_status, const char* symbol,
+    const char* message);
+
+// Wraps an iree_hal_amdgpu_libaqlprofile_t* for error helper calls.
+#define IREE_LIBAQLPROFILE(libaqlprofile) (libaqlprofile), __FILE__, __LINE__
+
+#define IREE_RETURN_IF_AQLPROFILE_ERROR(libaqlprofile, expr, message)        \
+  do {                                                                       \
+    hsa_status_t hsa_status_ = (expr);                                       \
+    if (IREE_UNLIKELY(hsa_status_ != HSA_STATUS_SUCCESS)) {                  \
+      return iree_status_from_aqlprofile_status(                             \
+          (libaqlprofile), __FILE__, __LINE__, hsa_status_, #expr, message); \
+    }                                                                        \
+  } while (0)
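+
+// Example (illustrative; |version| is a local shown only for shape):
+//   iree_hal_amdgpu_aqlprofile_version_t version;
+//   IREE_RETURN_IF_AQLPROFILE_ERROR(
+//       libaqlprofile, libaqlprofile->aqlprofile_get_version(&version),
+//       "querying aqlprofile version");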
+
+#ifdef __cplusplus
+}  // extern "C"
+#endif  // __cplusplus
+
+#endif  // IREE_HAL_DRIVERS_AMDGPU_UTIL_LIBAQLPROFILE_H_
diff --git a/runtime/src/iree/hal/drivers/amdgpu/util/libaqlprofile_test.cc b/runtime/src/iree/hal/drivers/amdgpu/util/libaqlprofile_test.cc
new file mode 100644
index 0000000..1cf1803
--- /dev/null
+++ b/runtime/src/iree/hal/drivers/amdgpu/util/libaqlprofile_test.cc
@@ -0,0 +1,58 @@
+// Copyright 2026 The IREE Authors
+//
+// Licensed under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+
+#include "iree/hal/drivers/amdgpu/util/libaqlprofile.h"
+
+#include "iree/base/api.h"
+#include "iree/testing/gtest.h"
+#include "iree/testing/status_matchers.h"
+
+namespace iree::hal::amdgpu {
+namespace {
+
+// Tests that we can find, load, query, and unload aqlprofile. If HSA or
+// aqlprofile cannot be found then we skip the test so that it doesn't fail on
+// machines without ROCm profiling libraries installed.
+TEST(LibAqlprofileTest, Load) {
+  iree_hal_amdgpu_libhsa_t libhsa;
+  iree_status_t status = iree_hal_amdgpu_libhsa_initialize(
+      IREE_HAL_AMDGPU_LIBHSA_FLAG_NONE, iree_string_view_list_empty(),
+      iree_allocator_system(), &libhsa);
+  if (!iree_status_is_ok(status)) {
+    iree_status_fprint(stderr, status);
+    iree_status_free(status);
+    GTEST_SKIP() << "HSA not available, skipping test";
+  }
+
+  iree_hal_amdgpu_libaqlprofile_t libaqlprofile;
+  status = iree_hal_amdgpu_libaqlprofile_initialize(
+      &libhsa, iree_string_view_list_empty(), iree_allocator_system(),
+      &libaqlprofile);
+  if (iree_status_code(status) == IREE_STATUS_NOT_FOUND) {
+    iree_status_fprint(stderr, status);
+    iree_status_free(status);
+    iree_hal_amdgpu_libhsa_deinitialize(&libhsa);
+    GTEST_SKIP() << "aqlprofile not available, skipping test";
+  }
+  if (!iree_status_is_ok(status)) {
+    iree_hal_amdgpu_libhsa_deinitialize(&libhsa);
+  }
+  IREE_ASSERT_OK(status);
+
+  EXPECT_NE(libaqlprofile.version.major, 0u);
+  EXPECT_NE(libaqlprofile.aqlprofile_get_version, nullptr);
+  EXPECT_NE(libaqlprofile.aqlprofile_register_agent_info, nullptr);
+  EXPECT_NE(libaqlprofile.aqlprofile_validate_pmc_event, nullptr);
+  EXPECT_NE(libaqlprofile.aqlprofile_pmc_create_packets, nullptr);
+  EXPECT_NE(libaqlprofile.aqlprofile_pmc_delete_packets, nullptr);
+  EXPECT_NE(libaqlprofile.aqlprofile_pmc_iterate_data, nullptr);
+
+  iree_hal_amdgpu_libaqlprofile_deinitialize(&libaqlprofile);
+  iree_hal_amdgpu_libhsa_deinitialize(&libhsa);
+}
+
+}  // namespace
+}  // namespace iree::hal::amdgpu
diff --git a/runtime/src/iree/hal/drivers/amdgpu/util/libhsa.c b/runtime/src/iree/hal/drivers/amdgpu/util/libhsa.c
index 58b02a4..2205441 100644
--- a/runtime/src/iree/hal/drivers/amdgpu/util/libhsa.c
+++ b/runtime/src/iree/hal/drivers/amdgpu/util/libhsa.c
@@ -248,7 +248,7 @@
   return iree_ok_status();
 }
 
-static bool iree_hal_amdgpu_libhsa_try_load_library_from_file(
+static iree_status_t iree_hal_amdgpu_libhsa_try_load_library_from_file(
     iree_hal_amdgpu_libhsa_flags_t flags, const char* file_path,
     iree_string_builder_t* error_builder, iree_allocator_t host_allocator,
     iree_dynamic_library_t** out_library) {
@@ -263,14 +263,17 @@
 
   // Append error message to the status builder.
   if (!iree_status_is_ok(status)) {
-    IREE_IGNORE_ERROR(iree_string_builder_append_format(
-        error_builder, "\n  Tried: %s\n    ", file_path));
-    IREE_IGNORE_ERROR(iree_string_builder_append_status(error_builder, status));
+    iree_status_t load_status = status;
+    status = iree_string_builder_append_format(
+        error_builder, "\n  Tried: %s\n    ", file_path);
+    if (iree_status_is_ok(status)) {
+      status = iree_string_builder_append_status(error_builder, load_status);
+    }
+    iree_status_free(load_status);
   }
 
-  iree_status_ignore(status);
   IREE_TRACE_ZONE_END(z0);
-  return *out_library != NULL;
+  return status;
 }
 
 static const char* iree_hal_amdgpu_libhsa_names[] = {
@@ -280,11 +283,15 @@
     // users can still build the HAL driver but it won't run.
     "hsa-runtime64.dll",
 #else
+    // Versioned soname first — this is the real soname baked into the library
+    // and is always present. The unversioned .so is a development symlink that
+    // only exists in -dev packages or full ROCm installs.
+    "libhsa-runtime64.so.1",
     "libhsa-runtime64.so",
 #endif  // IREE_PLATFORM_WINDOWS
 };
 
-static bool iree_hal_amdgpu_libhsa_try_load_library_from_path(
+static iree_status_t iree_hal_amdgpu_libhsa_try_load_library_from_path(
     iree_hal_amdgpu_libhsa_flags_t flags, iree_string_view_t path_fragment,
     iree_string_builder_t* error_builder, iree_allocator_t host_allocator,
     iree_dynamic_library_t** out_library) {
@@ -297,38 +304,43 @@
   // we do that locally in a heap-allocated NUL-terminated string builder.
   iree_string_builder_t path_builder;
   iree_string_builder_initialize(host_allocator, &path_builder);
+  iree_status_t status = iree_ok_status();
 
   if (iree_file_path_is_dynamic_library(path_fragment)) {
     // User provided a filename - try to use it directly. If it's an absolute
     // file path the system will try that and otherwise it'll search all library
     // paths for the given filename.
-    iree_status_ignore(
-        iree_string_builder_append_string(&path_builder, path_fragment));
-    iree_hal_amdgpu_libhsa_try_load_library_from_file(
-        flags, iree_string_builder_buffer(&path_builder), error_builder,
-        host_allocator, out_library);
-  } else {
-    // Join the provided path with each canonical name and try that.
-    iree_string_builder_reset(&path_builder);
-    for (iree_host_size_t i = 0;
-         i < IREE_ARRAYSIZE(iree_hal_amdgpu_libhsa_names) && !*out_library;
-         ++i) {
-      iree_status_ignore(iree_string_builder_append_format(
-          &path_builder, "%.*s/%s", (int)path_fragment.size, path_fragment.data,
-          iree_hal_amdgpu_libhsa_names[i]));
-      path_builder.size = iree_file_path_canonicalize(
-          (char*)iree_string_builder_buffer(&path_builder),
-          iree_string_builder_size(&path_builder));
-      iree_hal_amdgpu_libhsa_try_load_library_from_file(
+    status = iree_string_builder_append_string(&path_builder, path_fragment);
+    if (iree_status_is_ok(status)) {
+      status = iree_hal_amdgpu_libhsa_try_load_library_from_file(
           flags, iree_string_builder_buffer(&path_builder), error_builder,
           host_allocator, out_library);
     }
+  } else {
+    // Join the provided path with each canonical name and try that.
+    for (iree_host_size_t i = 0;
+         iree_status_is_ok(status) &&
+         i < IREE_ARRAYSIZE(iree_hal_amdgpu_libhsa_names) && !*out_library;
+         ++i) {
+      iree_string_builder_reset(&path_builder);
+      status = iree_string_builder_append_format(
+          &path_builder, "%.*s/%s", (int)path_fragment.size, path_fragment.data,
+          iree_hal_amdgpu_libhsa_names[i]);
+      if (iree_status_is_ok(status)) {
+        path_builder.size = iree_file_path_canonicalize(
+            (char*)iree_string_builder_buffer(&path_builder),
+            iree_string_builder_size(&path_builder));
+        status = iree_hal_amdgpu_libhsa_try_load_library_from_file(
+            flags, iree_string_builder_buffer(&path_builder), error_builder,
+            host_allocator, out_library);
+      }
+    }
   }
 
   iree_string_builder_deinitialize(&path_builder);
 
   IREE_TRACE_ZONE_END(z0);
-  return *out_library != NULL;
+  return status;
 }
 
 static iree_status_t iree_hal_amdgpu_libhsa_load_library(
@@ -345,11 +357,13 @@
   iree_string_builder_initialize(host_allocator, &error_builder);
 
   iree_dynamic_library_t* library = NULL;
+  iree_status_t status = iree_ok_status();
 
   // If the caller provided explicit paths we always try to use those first.
   // This allows a hosting application to handle overrides as they see fit.
-  for (iree_host_size_t i = 0; i < search_paths.count && !library; ++i) {
-    iree_hal_amdgpu_libhsa_try_load_library_from_path(
+  for (iree_host_size_t i = 0;
+       iree_status_is_ok(status) && i < search_paths.count && !library; ++i) {
+    status = iree_hal_amdgpu_libhsa_try_load_library_from_path(
         flags, search_paths.values[i], &error_builder, host_allocator,
         &library);
   }
@@ -359,27 +373,27 @@
   // recompile their application/pass through flags down into IREE.
   iree_string_view_t env_path =
       iree_make_cstring_view(getenv("IREE_HAL_AMDGPU_LIBHSA_PATH"));
-  if (!library && !iree_string_view_is_empty(env_path)) {
-    iree_hal_amdgpu_libhsa_try_load_library_from_path(
+  if (iree_status_is_ok(status) && !library &&
+      !iree_string_view_is_empty(env_path)) {
+    status = iree_hal_amdgpu_libhsa_try_load_library_from_path(
         flags, env_path, &error_builder, host_allocator, &library);
   }
 
   // Fallback (that is the common case) and try loading with the canonical
   // library names from the system search paths.
-  if (!library) {
+  if (iree_status_is_ok(status) && !library) {
     for (iree_host_size_t i = 0;
-         i < IREE_ARRAYSIZE(iree_hal_amdgpu_libhsa_names) && !library; ++i) {
-      if (iree_hal_amdgpu_libhsa_try_load_library_from_file(
-              flags, iree_hal_amdgpu_libhsa_names[i], &error_builder,
-              host_allocator, &library)) {
-        break;
-      }
+         iree_status_is_ok(status) &&
+         i < IREE_ARRAYSIZE(iree_hal_amdgpu_libhsa_names) && !library;
+         ++i) {
+      status = iree_hal_amdgpu_libhsa_try_load_library_from_file(
+          flags, iree_hal_amdgpu_libhsa_names[i], &error_builder,
+          host_allocator, &library);
     }
   }
 
   // If no library was found emit the full failure status.
-  iree_status_t status = iree_ok_status();
-  if (!library) {
+  if (iree_status_is_ok(status) && !library) {
     status =
         iree_make_status(IREE_STATUS_NOT_FOUND,
                          "HSA/ROCR-Runtime library not found; ensure it is "
@@ -395,18 +409,25 @@
     if (!iree_status_is_ok(status) && out_libhsa->hsa_init) {
       iree_string_builder_t annotation_builder;
       iree_string_builder_initialize(host_allocator, &annotation_builder);
-      IREE_IGNORE_ERROR(iree_dynamic_library_append_symbol_path_to_builder(
-          out_libhsa->hsa_init, &annotation_builder));
-      status = iree_status_annotate_f(
-          status, "using %.*s",
-          (int)iree_string_builder_size(&annotation_builder),
-          iree_string_builder_buffer(&annotation_builder));
+      iree_status_t annotation_status =
+          iree_dynamic_library_append_symbol_path_to_builder(
+              out_libhsa->hsa_init, &annotation_builder);
+      if (iree_status_is_ok(annotation_status)) {
+        status = iree_status_annotate_f(
+            status, "using %.*s",
+            (int)iree_string_builder_size(&annotation_builder),
+            iree_string_builder_buffer(&annotation_builder));
+      } else {
+        status = iree_status_join(status, annotation_status);
+      }
       iree_string_builder_deinitialize(&annotation_builder);
     }
   }
 
   if (iree_status_is_ok(status)) {
     out_libhsa->library = library;
+  } else {
+    iree_dynamic_library_release(library);
   }
   IREE_TRACE_ZONE_END(z0);
   return status;
@@ -468,11 +489,12 @@
 
   // Initialize HSA. If already loaded this increments the refcount to be paired
   // with the hsa_shut_down we call in deinitialize.
+  // ROCR leaks global singleton state during initialization and extension
+  // table queries. Suppress the resulting LSAN reports for all of it.
+  IREE_LEAK_CHECK_DISABLE_PUSH();
+
   if (iree_status_is_ok(status)) {
-    // ROCR leaks a tremendous amount of global junk.
-    IREE_LEAK_CHECK_DISABLE_PUSH();
     status = iree_hsa_init(IREE_LIBHSA(out_libhsa));
-    IREE_LEAK_CHECK_DISABLE_POP();
     if (iree_status_is_ok(status)) {
       out_libhsa->initialized = true;
     }
@@ -488,6 +510,8 @@
         IREE_SV("querying HSA_EXTENSION_AMD_LOADER"));
   }
 
+  IREE_LEAK_CHECK_DISABLE_POP();
+
   if (!iree_status_is_ok(status)) {
     iree_hal_amdgpu_libhsa_deinitialize(out_libhsa);
   }
@@ -503,7 +527,7 @@
   // Decrement HSA ref count; others may still have it loaded/in-use.
   if (libhsa->initialized) {
     IREE_LEAK_CHECK_DISABLE_PUSH();
-    IREE_IGNORE_ERROR(iree_hsa_shut_down(IREE_LIBHSA(libhsa)));
+    iree_hal_amdgpu_hsa_cleanup_assert_success(iree_hsa_shut_down_raw(libhsa));
     IREE_LEAK_CHECK_DISABLE_POP();
   }
 
@@ -598,13 +622,17 @@
 // all the iree_hsa_* methods directly to the HSA functions.
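+// Each hsa_status_t-returning entry point also gets an iree_<symbol>_raw
+// variant that returns the raw hsa_status_t without tracing or iree_status_t
+// allocation; cleanup paths that cannot propagate errors (such as
+// iree_hsa_shut_down_raw above) call these directly.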
 #define IREE_HAL_AMDGPU_LIBHSA_PFN_hsa_status_t(trace_category, result_type,  \
                                                 symbol, decl, args)           \
+  hsa_status_t iree_##symbol##_raw(                                           \
+      const iree_hal_amdgpu_libhsa_t* IREE_RESTRICT libhsa _COMMA_DECL(       \
+          decl)) {                                                            \
+    return IREE_HAL_AMDGPU_LIBHSA_LIBPTR(libhsa) symbol(args);                \
+  }                                                                           \
   iree_status_t iree_##symbol(                                                \
       const iree_hal_amdgpu_libhsa_t* IREE_RESTRICT libhsa, const char* file, \
       const uint32_t line _COMMA_DECL(decl)) {                                \
     IREE_HAL_AMDGPU_LIBHSA_TRACE_ZONE_BEGIN_##trace_category(z0);             \
                                                                               \
-    hsa_status_t hsa_status =                                                 \
-        IREE_HAL_AMDGPU_LIBHSA_LIBPTR(libhsa) symbol(args);                   \
+    hsa_status_t hsa_status = iree_##symbol##_raw(libhsa _COMMA_ARGS(args));  \
                                                                               \
     iree_status_t iree_status = iree_ok_status();                             \
     if (IREE_UNLIKELY(hsa_status != HSA_STATUS_SUCCESS &&                     \
@@ -660,5 +688,7 @@
 #define DECL(...) __VA_ARGS__
 #define ARGS(...) __VA_ARGS__
 #define _COMMA_DECL(...) __VA_OPT__(, ) __VA_ARGS__
+#define _COMMA_ARGS(...) __VA_OPT__(, ) __VA_ARGS__
 #include "iree/hal/drivers/amdgpu/util/libhsa_tables.h"  // IWYU pragma: export
+#undef _COMMA_ARGS
 #undef _COMMA_DECL
diff --git a/runtime/src/iree/hal/drivers/amdgpu/util/libhsa.h b/runtime/src/iree/hal/drivers/amdgpu/util/libhsa.h
index cbc722e..c79e377 100644
--- a/runtime/src/iree/hal/drivers/amdgpu/util/libhsa.h
+++ b/runtime/src/iree/hal/drivers/amdgpu/util/libhsa.h
@@ -186,7 +186,9 @@
                                                 symbol, decl, ...)            \
   iree_status_t iree_##symbol(                                                \
       const iree_hal_amdgpu_libhsa_t* IREE_RESTRICT libhsa, const char* file, \
-      const uint32_t line _COMMA_DECL(decl));
+      const uint32_t line _COMMA_DECL(decl));                                 \
+  hsa_status_t iree_##symbol##_raw(                                           \
+      const iree_hal_amdgpu_libhsa_t* IREE_RESTRICT libhsa _COMMA_DECL(decl));
 #define IREE_HAL_AMDGPU_LIBHSA_PFN_result(trace_category, result_type, symbol, \
                                           decl, ...)                           \
   result_type iree_##symbol(                                                   \
@@ -214,6 +216,12 @@
 #undef IREE_HAL_AMDGPU_LIBHSA_PFN_uint64_t
 #undef IREE_HAL_AMDGPU_LIBHSA_PFN_hsa_signal_value_t
 
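+// Asserts that an HSA call made on a cleanup path succeeded. Cleanup paths
+// have no way to propagate a status so a failure trips an assert when
+// assertions are enabled and is otherwise ignored.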
+static inline void iree_hal_amdgpu_hsa_cleanup_assert_success(
+    hsa_status_t status) {
+  IREE_ASSERT(status == HSA_STATUS_SUCCESS, "HSA cleanup failed: %d",
+              (int)status);
+}
+
 #ifdef __cplusplus
 }  // extern "C"
 #endif  // __cplusplus
diff --git a/runtime/src/iree/hal/drivers/amdgpu/util/libhsa_tables.h b/runtime/src/iree/hal/drivers/amdgpu/util/libhsa_tables.h
index 51f9904..b90d077 100644
--- a/runtime/src/iree/hal/drivers/amdgpu/util/libhsa_tables.h
+++ b/runtime/src/iree/hal/drivers/amdgpu/util/libhsa_tables.h
@@ -123,6 +123,11 @@
 IREE_HAL_AMDGPU_LIBHSA_PFN(TRACE_ALWAYS, hsa_status_t, hsa_amd_memory_pool_free,
                            DECL(void* ptr), ARGS(ptr))
 
+IREE_HAL_AMDGPU_LIBHSA_PFN(TRACE_ALWAYS, hsa_status_t, hsa_amd_memory_lock,
+                           DECL(void* host_ptr, size_t size,
+                                hsa_agent_t* agents, int num_agent,
+                                void** agent_ptr),
+                           ARGS(host_ptr, size, agents, num_agent, agent_ptr))
 IREE_HAL_AMDGPU_LIBHSA_PFN(
     TRACE_ALWAYS, hsa_status_t, hsa_amd_memory_lock_to_pool,
     DECL(void* host_ptr, size_t size, hsa_agent_t* agents, int num_agent,
diff --git a/runtime/src/iree/hal/drivers/amdgpu/util/libhsa_test.cc b/runtime/src/iree/hal/drivers/amdgpu/util/libhsa_test.cc
index bf4aad0..d252365 100644
--- a/runtime/src/iree/hal/drivers/amdgpu/util/libhsa_test.cc
+++ b/runtime/src/iree/hal/drivers/amdgpu/util/libhsa_test.cc
@@ -24,7 +24,7 @@
       iree_allocator_system(), &libhsa);
   if (!iree_status_is_ok(status)) {
     iree_status_fprint(stderr, status);
-    iree_status_ignore(status);
+    iree_status_free(status);
     GTEST_SKIP() << "HSA not available, skipping tests";
   }
 
diff --git a/runtime/src/iree/hal/drivers/amdgpu/util/notification_ring.c b/runtime/src/iree/hal/drivers/amdgpu/util/notification_ring.c
new file mode 100644
index 0000000..032de7b
--- /dev/null
+++ b/runtime/src/iree/hal/drivers/amdgpu/util/notification_ring.c
@@ -0,0 +1,776 @@
+// Copyright 2026 The IREE Authors
+//
+// Licensed under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+
+#include "iree/hal/drivers/amdgpu/util/notification_ring.h"
+
+#include <string.h>
+
+#include "iree/hal/utils/resource_set.h"
+
+// All frontier snapshot sizes are multiples of the 16-byte header size
+// (header=16, entry=16 each), so positions within the byte ring stay aligned
+// as long as the base is aligned. Verify at compile time that the component
+// sizes at least preserve 8-byte alignment.
+static_assert(sizeof(iree_hal_amdgpu_frontier_snapshot_t) % 8 == 0,
+              "frontier snapshot header size must be a multiple of 8 bytes");
+static_assert(sizeof(iree_async_frontier_entry_t) % 8 == 0,
+              "frontier entry size must be a multiple of 8 bytes");
+
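+// Maps a monotonically increasing byte position to an offset within the
+// frontier byte ring; the ring capacity is always a power of two so the mask
+// is exact.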
+static inline iree_host_size_t
+iree_hal_amdgpu_notification_ring_frontier_offset(
+    const iree_hal_amdgpu_notification_ring_t* ring,
+    iree_host_size_t position) {
+  return position & (ring->frontier_ring.capacity - 1);
+}
+
+static inline uint64_t iree_hal_amdgpu_notification_ring_load_position(
+    const iree_atomic_int64_t* position, iree_memory_order_t memory_order) {
+  return (uint64_t)iree_atomic_load(position, memory_order);
+}
+
+static inline void iree_hal_amdgpu_notification_ring_store_position(
+    iree_atomic_int64_t* position, uint64_t value,
+    iree_memory_order_t memory_order) {
+  iree_atomic_store(position, (int64_t)value, memory_order);
+}
+
+static inline iree_host_size_t
+iree_hal_amdgpu_notification_ring_frontier_snapshot_size(
+    const iree_hal_amdgpu_frontier_snapshot_t* snapshot) {
+  return sizeof(*snapshot) +
+         snapshot->frontier.entry_count * sizeof(iree_async_frontier_entry_t);
+}
+
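+// Prepares |entry| to hold up to |count| retained resources, pointing
+// |out_resources| at the inline storage when it fits and otherwise at a
+// block acquired from |block_pool| that is returned in
+// iree_hal_amdgpu_reclaim_entry_release.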
+iree_status_t iree_hal_amdgpu_reclaim_entry_prepare(
+    iree_hal_amdgpu_reclaim_entry_t* entry, iree_arena_block_pool_t* block_pool,
+    uint16_t count, iree_hal_resource_t*** out_resources) {
+  IREE_ASSERT_ARGUMENT(entry);
+  IREE_ASSERT_ARGUMENT(out_resources);
+  entry->pre_signal_action.fn = NULL;
+  entry->pre_signal_action.user_data = NULL;
+  entry->profile_event_first_position = 0;
+  entry->profile_event_count = 0;
+  entry->queue_device_event_first_position = 0;
+  entry->queue_device_event_count = 0;
+  entry->resource_set = NULL;
+  entry->kernarg_write_position = 0;
+  entry->queue_upload_write_position = 0;
+  entry->count = 0;
+  if (count <= IREE_HAL_AMDGPU_RECLAIM_INLINE_CAPACITY) {
+    entry->resources = entry->inline_resources;
+  } else {
+    iree_host_size_t required_size = 0;
+    IREE_RETURN_IF_ERROR(IREE_STRUCT_LAYOUT(
+        0, &required_size,
+        IREE_STRUCT_FIELD(count, iree_hal_resource_t*, NULL)));
+    if (required_size > block_pool->usable_block_size) {
+      return iree_make_status(IREE_STATUS_RESOURCE_EXHAUSTED,
+                              "reclaim overflow (%" PRIhsz
+                              " bytes) exceeds block pool block size (%" PRIhsz
+                              " bytes)",
+                              required_size, block_pool->usable_block_size);
+    }
+    iree_arena_block_t* block = NULL;
+    void* block_ptr = NULL;
+    IREE_TRACE_ZONE_BEGIN(z0);
+    IREE_TRACE_ZONE_APPEND_VALUE_I64(z0, count);
+    IREE_RETURN_AND_END_ZONE_IF_ERROR(
+        z0, iree_arena_block_pool_acquire(block_pool, &block, &block_ptr));
+    entry->resources = (iree_hal_resource_t**)block_ptr;
+    IREE_TRACE_ZONE_END(z0);
+  }
+  *out_resources = entry->resources;
+  return iree_ok_status();
+}
+
+void iree_hal_amdgpu_reclaim_entry_release(
+    iree_hal_amdgpu_reclaim_entry_t* entry,
+    iree_arena_block_pool_t* block_pool) {
+  for (uint16_t i = 0; i < entry->count; ++i) {
+    iree_hal_resource_release(entry->resources[i]);
+  }
+  iree_hal_resource_set_free(entry->resource_set);
+  if (entry->resources != entry->inline_resources && entry->resources != NULL) {
+    IREE_TRACE_ZONE_BEGIN(z0);
+    iree_arena_block_t* block =
+        iree_arena_block_trailer(block_pool, entry->resources);
+    iree_arena_block_pool_release(block_pool, block, block);
+    IREE_TRACE_ZONE_END(z0);
+  }
+  entry->resources = NULL;
+  entry->pre_signal_action.fn = NULL;
+  entry->pre_signal_action.user_data = NULL;
+  entry->profile_event_first_position = 0;
+  entry->profile_event_count = 0;
+  entry->queue_device_event_first_position = 0;
+  entry->queue_device_event_count = 0;
+  entry->resource_set = NULL;
+  entry->kernarg_write_position = 0;
+  entry->queue_upload_write_position = 0;
+  entry->count = 0;
+}
+
+static inline void iree_hal_amdgpu_reclaim_entry_execute_pre_signal_action(
+    iree_hal_amdgpu_reclaim_entry_t* entry, iree_status_t status) {
+  if (!entry->pre_signal_action.fn) return;
+  iree_hal_amdgpu_reclaim_action_fn_t fn = entry->pre_signal_action.fn;
+  void* user_data = entry->pre_signal_action.user_data;
+  entry->pre_signal_action.fn = NULL;
+  entry->pre_signal_action.user_data = NULL;
+  fn(entry, user_data, status);
+}
+
+iree_status_t iree_hal_amdgpu_notification_ring_initialize(
+    const iree_hal_amdgpu_libhsa_t* libhsa, iree_arena_block_pool_t* block_pool,
+    uint32_t capacity, iree_allocator_t host_allocator,
+    iree_hal_amdgpu_notification_ring_t* out_ring) {
+  IREE_ASSERT_ARGUMENT(libhsa);
+  IREE_ASSERT_ARGUMENT(block_pool);
+  IREE_ASSERT_ARGUMENT(out_ring);
+  IREE_TRACE_ZONE_BEGIN(z0);
+
+  if (!iree_host_size_is_power_of_two(capacity)) {
+    IREE_TRACE_ZONE_END(z0);
+    return iree_make_status(
+        IREE_STATUS_INVALID_ARGUMENT,
+        "notification ring capacity must be a power of two");
+  }
+
+  memset(out_ring, 0, sizeof(*out_ring));
+  out_ring->libhsa = libhsa;
+  out_ring->block_pool = block_pool;
+  out_ring->host_allocator = host_allocator;
+
+  // Allocate hot entries + frontier byte ring + reclaim entries in one block.
+  // Reserve one extra max-size snapshot so a wrap sentinel's tail padding can
+  // coexist with a full hot ring's worth of transition snapshots.
+  iree_host_size_t min_frontier_ring_capacity = 0;
+  if (!iree_host_size_checked_mul_add(
+          IREE_HAL_AMDGPU_MAX_FRONTIER_SNAPSHOT_SIZE,
+          (iree_host_size_t)capacity,
+          IREE_HAL_AMDGPU_MAX_FRONTIER_SNAPSHOT_SIZE,
+          &min_frontier_ring_capacity)) {
+    IREE_TRACE_ZONE_END(z0);
+    return iree_make_status(
+        IREE_STATUS_RESOURCE_EXHAUSTED,
+        "notification ring frontier snapshot capacity overflow");
+  }
+  iree_host_size_t frontier_ring_capacity =
+      iree_host_size_next_power_of_two(min_frontier_ring_capacity);
+  if (!iree_host_size_is_power_of_two(frontier_ring_capacity)) {
+    IREE_TRACE_ZONE_END(z0);
+    return iree_make_status(
+        IREE_STATUS_RESOURCE_EXHAUSTED,
+        "notification ring frontier snapshot capacity overflow");
+  }
+
+  iree_host_size_t entries_offset = 0;
+  iree_host_size_t frontier_ring_offset = 0;
+  iree_host_size_t reclaim_offset = 0;
+  iree_host_size_t total_size = 0;
+  IREE_RETURN_AND_END_ZONE_IF_ERROR(
+      z0, IREE_STRUCT_LAYOUT(
+              0, &total_size,
+              IREE_STRUCT_FIELD(capacity, iree_hal_amdgpu_notification_entry_t,
+                                &entries_offset),
+              IREE_STRUCT_ARRAY_FIELD_ALIGNED(
+                  frontier_ring_capacity, 1, uint8_t,
+                  iree_alignof(iree_hal_amdgpu_frontier_snapshot_t),
+                  &frontier_ring_offset),
+              IREE_STRUCT_FIELD(capacity, iree_hal_amdgpu_reclaim_entry_t,
+                                &reclaim_offset)));
+  IREE_RETURN_AND_END_ZONE_IF_ERROR(
+      z0,
+      iree_allocator_malloc(host_allocator, total_size, &out_ring->storage));
+  memset(out_ring->storage, 0, total_size);
+  uint8_t* base = (uint8_t*)out_ring->storage;
+  out_ring->entries =
+      (iree_hal_amdgpu_notification_entry_t*)(base + entries_offset);
+  iree_hal_amdgpu_notification_ring_store_position(&out_ring->write, 0,
+                                                   iree_memory_order_release);
+  iree_hal_amdgpu_notification_ring_store_position(&out_ring->read, 0,
+                                                   iree_memory_order_release);
+  out_ring->frontier_ring.data = base + frontier_ring_offset;
+  out_ring->frontier_ring.capacity = frontier_ring_capacity;
+  iree_hal_amdgpu_notification_ring_store_position(
+      &out_ring->frontier_ring.write, 0, iree_memory_order_release);
+  iree_hal_amdgpu_notification_ring_store_position(
+      &out_ring->frontier_ring.read, 0, iree_memory_order_release);
+  out_ring->reclaim_entries =
+      (iree_hal_amdgpu_reclaim_entry_t*)(base + reclaim_offset);
+  out_ring->capacity = capacity;
+  iree_atomic_store(&out_ring->epoch.last_drained, 0,
+                    iree_memory_order_release);
+
+  // Create the epoch signal.
+  iree_status_t status = iree_hsa_amd_signal_create(
+      IREE_LIBHSA(libhsa), IREE_HAL_AMDGPU_EPOCH_INITIAL_VALUE,
+      /*num_consumers=*/0, /*consumers=*/NULL, /*attributes=*/0,
+      &out_ring->epoch.signal);
+
+  if (!iree_status_is_ok(status)) {
+    iree_hal_amdgpu_notification_ring_deinitialize(out_ring);
+  }
+
+  IREE_TRACE_ZONE_END(z0);
+  return status;
+}
+
+void iree_hal_amdgpu_notification_ring_deinitialize(
+    iree_hal_amdgpu_notification_ring_t* ring) {
+  IREE_ASSERT_ARGUMENT(ring);
+  IREE_TRACE_ZONE_BEGIN(z0);
+
+  // Release any outstanding reclaim entries (should have been drained, but
+  // handle partial teardown gracefully).
+  if (ring->reclaim_entries && ring->block_pool) {
+    uint64_t last_drained = (uint64_t)iree_atomic_load(
+        &ring->epoch.last_drained, iree_memory_order_acquire);
+    for (uint64_t epoch = last_drained; epoch < ring->epoch.next_submission;
+         ++epoch) {
+      uint32_t index = (uint32_t)(epoch & (ring->capacity - 1));
+      iree_hal_amdgpu_reclaim_entry_release(&ring->reclaim_entries[index],
+                                            ring->block_pool);
+    }
+  }
+
+  if (ring->epoch.signal.handle) {
+    iree_hal_amdgpu_hsa_cleanup_assert_success(
+        iree_hsa_signal_destroy_raw(ring->libhsa, ring->epoch.signal));
+    ring->epoch.signal.handle = 0;
+  }
+  iree_allocator_free(ring->host_allocator, ring->storage);
+  ring->storage = NULL;
+  ring->entries = NULL;
+  ring->reclaim_entries = NULL;
+  ring->frontier_ring.data = NULL;
+
+  IREE_TRACE_ZONE_END(z0);
+}
+
+hsa_signal_t iree_hal_amdgpu_notification_ring_epoch_signal(
+    const iree_hal_amdgpu_notification_ring_t* ring) {
+  return ring->epoch.signal;
+}
+
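+// NOTE: submission epochs are 1-based: the pre-increment below hands out 1
+// for the first submission so that last_drained == 0 means nothing has
+// retired yet.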
+uint64_t iree_hal_amdgpu_notification_ring_advance_epoch(
+    iree_hal_amdgpu_notification_ring_t* ring) {
+  return ++ring->epoch.next_submission;
+}
+
+iree_status_t iree_hal_amdgpu_notification_ring_reserve(
+    const iree_hal_amdgpu_notification_ring_t* ring,
+    iree_host_size_t entry_count, iree_host_size_t frontier_snapshot_count) {
+  IREE_ASSERT_ARGUMENT(ring);
+
+  uint64_t last_drained = (uint64_t)iree_atomic_load(&ring->epoch.last_drained,
+                                                     iree_memory_order_acquire);
+  if (ring->epoch.next_submission - last_drained >= ring->capacity) {
+    return iree_make_status(
+        IREE_STATUS_RESOURCE_EXHAUSTED,
+        "notification ring reclaim capacity exhausted (pending_epochs=%" PRIu64
+        ", capacity=%u)",
+        ring->epoch.next_submission - last_drained, ring->capacity);
+  }
+
+  uint64_t write = iree_hal_amdgpu_notification_ring_load_position(
+      &ring->write, iree_memory_order_relaxed);
+  uint64_t read = iree_hal_amdgpu_notification_ring_load_position(
+      &ring->read, iree_memory_order_acquire);
+  if (entry_count > ring->capacity ||
+      write - read + entry_count > ring->capacity) {
+    return iree_make_status(
+        IREE_STATUS_RESOURCE_EXHAUSTED,
+        "notification ring capacity exhausted (available=%" PRIu64
+        ", required=%" PRIhsz ")",
+        ring->capacity - (write - read), entry_count);
+  }
+
+  if (frontier_snapshot_count == 0) {
+    return iree_ok_status();
+  }
+
+  iree_host_size_t reserved_snapshot_bytes = 0;
+  if (!iree_host_size_checked_mul_add(
+          IREE_HAL_AMDGPU_MAX_FRONTIER_SNAPSHOT_SIZE, frontier_snapshot_count,
+          IREE_HAL_AMDGPU_MAX_FRONTIER_SNAPSHOT_SIZE,
+          &reserved_snapshot_bytes)) {
+    return iree_make_status(
+        IREE_STATUS_RESOURCE_EXHAUSTED,
+        "notification ring frontier snapshot reservation overflow");
+  }
+
+  iree_host_size_t frontier_write =
+      (iree_host_size_t)iree_hal_amdgpu_notification_ring_load_position(
+          &ring->frontier_ring.write, iree_memory_order_relaxed);
+  iree_host_size_t frontier_read =
+      (iree_host_size_t)iree_hal_amdgpu_notification_ring_load_position(
+          &ring->frontier_ring.read, iree_memory_order_acquire);
+  iree_host_size_t frontier_occupied = frontier_write - frontier_read;
+  if (frontier_occupied + reserved_snapshot_bytes >
+      ring->frontier_ring.capacity) {
+    return iree_make_status(
+        IREE_STATUS_RESOURCE_EXHAUSTED,
+        "notification ring frontier snapshot capacity exhausted "
+        "(available=%" PRIhsz ", required=%" PRIhsz ")",
+        ring->frontier_ring.capacity - frontier_occupied,
+        reserved_snapshot_bytes);
+  }
+
+  return iree_ok_status();
+}
+
+bool iree_hal_amdgpu_notification_ring_can_reserve(
+    const iree_hal_amdgpu_notification_ring_t* ring,
+    iree_host_size_t entry_count, iree_host_size_t frontier_snapshot_count) {
+  IREE_ASSERT_ARGUMENT(ring);
+
+  const uint64_t last_drained = (uint64_t)iree_atomic_load(
+      &ring->epoch.last_drained, iree_memory_order_acquire);
+  if (ring->epoch.next_submission - last_drained >= ring->capacity) {
+    return false;
+  }
+
+  const uint64_t write = iree_hal_amdgpu_notification_ring_load_position(
+      &ring->write, iree_memory_order_relaxed);
+  const uint64_t read = iree_hal_amdgpu_notification_ring_load_position(
+      &ring->read, iree_memory_order_acquire);
+  if (entry_count > ring->capacity ||
+      write - read + entry_count > ring->capacity) {
+    return false;
+  }
+
+  if (frontier_snapshot_count == 0) {
+    return true;
+  }
+
+  iree_host_size_t reserved_snapshot_bytes = 0;
+  if (!iree_host_size_checked_mul_add(
+          IREE_HAL_AMDGPU_MAX_FRONTIER_SNAPSHOT_SIZE, frontier_snapshot_count,
+          IREE_HAL_AMDGPU_MAX_FRONTIER_SNAPSHOT_SIZE,
+          &reserved_snapshot_bytes)) {
+    return false;
+  }
+
+  const iree_host_size_t frontier_write =
+      (iree_host_size_t)iree_hal_amdgpu_notification_ring_load_position(
+          &ring->frontier_ring.write, iree_memory_order_relaxed);
+  const iree_host_size_t frontier_read =
+      (iree_host_size_t)iree_hal_amdgpu_notification_ring_load_position(
+          &ring->frontier_ring.read, iree_memory_order_acquire);
+  return frontier_write - frontier_read + reserved_snapshot_bytes <=
+         ring->frontier_ring.capacity;
+}
+
+void iree_hal_amdgpu_notification_ring_push(
+    iree_hal_amdgpu_notification_ring_t* ring, uint64_t submission_epoch,
+    iree_async_semaphore_t* semaphore, uint64_t timeline_value,
+    iree_hal_amdgpu_notification_entry_flags_t flags) {
+  IREE_ASSERT_ARGUMENT(ring);
+  IREE_ASSERT_ARGUMENT(semaphore);
+
+  uint64_t write = iree_hal_amdgpu_notification_ring_load_position(
+      &ring->write, iree_memory_order_relaxed);
+  uint64_t read = iree_hal_amdgpu_notification_ring_load_position(
+      &ring->read, iree_memory_order_acquire);
+  IREE_ASSERT(write - read < ring->capacity, "notification ring overflow");
+
+  uint32_t index = (uint32_t)(write & (ring->capacity - 1));
+  iree_hal_amdgpu_notification_entry_t* entry = &ring->entries[index];
+
+  entry->semaphore = semaphore;
+  entry->timeline_value = timeline_value;
+  entry->submission_epoch = submission_epoch;
+  entry->flags = flags;
+  entry->reserved0 = 0;
+
+  iree_hal_amdgpu_notification_ring_store_position(&ring->write, write + 1,
+                                                   iree_memory_order_release);
+}
+
+void iree_hal_amdgpu_notification_ring_push_frontier_snapshot(
+    iree_hal_amdgpu_notification_ring_t* ring, uint64_t epoch,
+    const iree_async_frontier_t* frontier) {
+  IREE_ASSERT_ARGUMENT(ring);
+  IREE_ASSERT_ARGUMENT(frontier);
+
+  uint8_t entry_count = frontier->entry_count;
+  IREE_ASSERT(entry_count <= IREE_HAL_AMDGPU_MAX_FRONTIER_SNAPSHOT_ENTRY_COUNT,
+              "frontier snapshot exceeds notification ring storage capacity");
+  iree_host_size_t snapshot_size =
+      sizeof(iree_hal_amdgpu_frontier_snapshot_t) +
+      entry_count * sizeof(iree_async_frontier_entry_t);
+
+  iree_host_size_t write =
+      (iree_host_size_t)iree_hal_amdgpu_notification_ring_load_position(
+          &ring->frontier_ring.write, iree_memory_order_relaxed);
+  iree_host_size_t read =
+      (iree_host_size_t)iree_hal_amdgpu_notification_ring_load_position(
+          &ring->frontier_ring.read, iree_memory_order_acquire);
+  IREE_ASSERT(write - read + snapshot_size <= ring->frontier_ring.capacity,
+              "notification ring frontier snapshot overflow");
+
+  iree_host_size_t write_offset =
+      iree_hal_amdgpu_notification_ring_frontier_offset(ring, write);
+  iree_host_size_t remaining = ring->frontier_ring.capacity - write_offset;
+
+  // If the snapshot doesn't fit in the remaining space, write a sentinel
+  // header and wrap to byte 0. All snapshot sizes are multiples of
+  // sizeof(frontier_snapshot_t), so remaining is also a multiple — there
+  // is always room for a sentinel header if remaining > 0.
+  if (remaining < snapshot_size) {
+    IREE_ASSERT(write - read + remaining + snapshot_size <=
+                    ring->frontier_ring.capacity,
+                "notification ring frontier snapshot overflow");
+    if (remaining >= sizeof(iree_hal_amdgpu_frontier_snapshot_t)) {
+      iree_hal_amdgpu_frontier_snapshot_t* sentinel =
+          (iree_hal_amdgpu_frontier_snapshot_t*)(ring->frontier_ring.data +
+                                                 write_offset);
+      sentinel->frontier.entry_count =
+          IREE_HAL_AMDGPU_FRONTIER_SNAPSHOT_SENTINEL;
+    }
+    write += remaining;
+    write_offset = 0;
+  }
+
+  // Write the snapshot at the current position.
+  iree_hal_amdgpu_frontier_snapshot_t* snapshot =
+      (iree_hal_amdgpu_frontier_snapshot_t*)(ring->frontier_ring.data +
+                                             write_offset);
+  snapshot->epoch = epoch;
+  snapshot->frontier.entry_count = entry_count;
+  memset(snapshot->frontier.reserved, 0, sizeof(snapshot->frontier.reserved));
+  if (entry_count > 0) {
+    memcpy(snapshot + 1, frontier->entries,
+           entry_count * sizeof(iree_async_frontier_entry_t));
+  }
+
+  iree_hal_amdgpu_notification_ring_store_position(&ring->frontier_ring.write,
+                                                   write + snapshot_size,
+                                                   iree_memory_order_release);
+}
+
+static const iree_hal_amdgpu_frontier_snapshot_t*
+iree_hal_amdgpu_notification_ring_frontier_snapshot_at(
+    iree_hal_amdgpu_notification_ring_t* ring, iree_host_size_t* inout_read) {
+  iree_host_size_t read_offset =
+      iree_hal_amdgpu_notification_ring_frontier_offset(ring, *inout_read);
+  const iree_hal_amdgpu_frontier_snapshot_t* snapshot =
+      (const iree_hal_amdgpu_frontier_snapshot_t*)(ring->frontier_ring.data +
+                                                   read_offset);
+  if (snapshot->frontier.entry_count ==
+      IREE_HAL_AMDGPU_FRONTIER_SNAPSHOT_SENTINEL) {
+    *inout_read += ring->frontier_ring.capacity - read_offset;
+    snapshot =
+        (const iree_hal_amdgpu_frontier_snapshot_t*)ring->frontier_ring.data;
+  }
+  return snapshot;
+}
+
+static void iree_hal_amdgpu_notification_ring_discard_stale_frontier_snapshots(
+    iree_hal_amdgpu_notification_ring_t* ring, uint64_t last_drained_epoch) {
+  iree_host_size_t read =
+      (iree_host_size_t)iree_hal_amdgpu_notification_ring_load_position(
+          &ring->frontier_ring.read, iree_memory_order_relaxed);
+  iree_host_size_t write =
+      (iree_host_size_t)iree_hal_amdgpu_notification_ring_load_position(
+          &ring->frontier_ring.write, iree_memory_order_acquire);
+  const iree_host_size_t original_read = read;
+  while (read < write) {
+    iree_host_size_t snapshot_read = read;
+    const iree_hal_amdgpu_frontier_snapshot_t* snapshot =
+        iree_hal_amdgpu_notification_ring_frontier_snapshot_at(ring,
+                                                               &snapshot_read);
+    if (snapshot->epoch > last_drained_epoch) break;
+    read = snapshot_read +
+           iree_hal_amdgpu_notification_ring_frontier_snapshot_size(snapshot);
+  }
+  if (read != original_read) {
+    iree_hal_amdgpu_notification_ring_store_position(
+        &ring->frontier_ring.read, read, iree_memory_order_release);
+  }
+}
+
+// Reads the next frontier snapshot from the frontier byte ring. Returns a
+// pointer to an iree_async_frontier_t that can be passed to semaphore_signal.
+// The returned frontier points into the ring buffer and is only valid until
+// the next read advances past it.
+//
+// If the ring is empty, returns |fallback|.
+static const iree_async_frontier_t*
+iree_hal_amdgpu_notification_ring_read_frontier_snapshot(
+    iree_hal_amdgpu_notification_ring_t* ring,
+    const iree_async_frontier_t* fallback) {
+  iree_host_size_t read =
+      (iree_host_size_t)iree_hal_amdgpu_notification_ring_load_position(
+          &ring->frontier_ring.read, iree_memory_order_relaxed);
+  iree_host_size_t write =
+      (iree_host_size_t)iree_hal_amdgpu_notification_ring_load_position(
+          &ring->frontier_ring.write, iree_memory_order_acquire);
+  if (read == write) {
+    return fallback;
+  }
+
+  const iree_hal_amdgpu_frontier_snapshot_t* snapshot =
+      iree_hal_amdgpu_notification_ring_frontier_snapshot_at(ring, &read);
+
+  iree_host_size_t snapshot_size =
+      iree_hal_amdgpu_notification_ring_frontier_snapshot_size(snapshot);
+  iree_hal_amdgpu_notification_ring_store_position(&ring->frontier_ring.read,
+                                                   read + snapshot_size,
+                                                   iree_memory_order_release);
+
+  // The snapshot's frontier header is followed by entries[] and is
+  // layout-compatible with iree_async_frontier_t. Return a pointer to it —
+  // valid until the next read advances past it.
+  return (const iree_async_frontier_t*)&snapshot->frontier;
+}
+
+static const iree_async_frontier_t*
+iree_hal_amdgpu_notification_ring_read_span_frontier(
+    iree_hal_amdgpu_notification_ring_t* ring,
+    iree_hal_amdgpu_notification_entry_flags_t flags,
+    bool has_transition_snapshot, const iree_async_frontier_t* fallback) {
+  if (!has_transition_snapshot ||
+      iree_any_bit_set(
+          flags,
+          IREE_HAL_AMDGPU_NOTIFICATION_ENTRY_FLAG_OMIT_FRONTIER_SNAPSHOT)) {
+    return fallback;
+  }
+  return iree_hal_amdgpu_notification_ring_read_frontier_snapshot(ring,
+                                                                  fallback);
+}
+
+iree_host_size_t iree_hal_amdgpu_notification_ring_drain_reclaim_positions(
+    iree_hal_amdgpu_notification_ring_t* ring,
+    const iree_async_frontier_t* fallback_frontier,
+    iree_hal_amdgpu_reclaim_retire_fn_t retire_fn, void* retire_user_data,
+    iree_hal_amdgpu_reclaim_positions_t* out_reclaim_positions) {
+  IREE_ASSERT_ARGUMENT(ring);
+  IREE_ASSERT_ARGUMENT(out_reclaim_positions);
+
+  memset(out_reclaim_positions, 0, sizeof(*out_reclaim_positions));
+
+  // Early out if the ring was never initialized or already deinitialized.
+  if (!ring->epoch.signal.handle) return 0;
+
+  hsa_signal_value_t signal_value = iree_hsa_signal_load_scacquire(
+      IREE_LIBHSA(ring->libhsa), ring->epoch.signal);
+  uint64_t current_epoch =
+      (uint64_t)(IREE_HAL_AMDGPU_EPOCH_INITIAL_VALUE - signal_value);
+
+  uint64_t previous_drained = (uint64_t)iree_atomic_load(
+      &ring->epoch.last_drained, iree_memory_order_relaxed);
+  if (current_epoch <= previous_drained) return 0;
+  iree_hal_amdgpu_notification_ring_discard_stale_frontier_snapshots(
+      ring, previous_drained);
+
+  // Execute all pre-signal completion actions first so wrapper-visible state
+  // transitions happen-before any semaphore publication for the completed
+  // epochs. This is intentionally a separate pass from post-signal resource
+  // release to keep queue retire ordering explicit.
+  // The loop variable |epoch| is the zero-based completion interval; the
+  // corresponding one-based submission epoch is epoch + 1 (as passed to
+  // retire_fn below).
+  for (uint64_t epoch = previous_drained; epoch < current_epoch; ++epoch) {
+    uint32_t reclaim_index = (uint32_t)(epoch & (ring->capacity - 1));
+    iree_hal_amdgpu_reclaim_entry_execute_pre_signal_action(
+        &ring->reclaim_entries[reclaim_index], iree_ok_status());
+  }
+  if (retire_fn) {
+    for (uint64_t epoch = previous_drained; epoch < current_epoch; ++epoch) {
+      uint32_t reclaim_index = (uint32_t)(epoch & (ring->capacity - 1));
+      retire_fn(&ring->reclaim_entries[reclaim_index], epoch + 1,
+                retire_user_data);
+    }
+  }
+
+  // Single-slot coalescing: accumulate consecutive same-semaphore entries
+  // and signal once per unique semaphore span. This turns N signals into 1
+  // for the common case of N dispatches on the same stream semaphore.
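+  //
+  // Example: entries {A:1, A:2, A:3, B:7} flush as signal(A, 3) when the scan
+  // reaches B and signal(B, 7) on the final flush.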
+  iree_async_semaphore_t* pending_semaphore = NULL;
+  uint64_t pending_value = 0;
+  iree_hal_amdgpu_notification_entry_flags_t pending_flags =
+      IREE_HAL_AMDGPU_NOTIFICATION_ENTRY_FLAG_NONE;
+  iree_host_size_t drained_count = 0;
+  uint64_t read = iree_hal_amdgpu_notification_ring_load_position(
+      &ring->read, iree_memory_order_relaxed);
+  uint64_t write = iree_hal_amdgpu_notification_ring_load_position(
+      &ring->write, iree_memory_order_acquire);
+
+  while (read < write) {
+    uint32_t index = (uint32_t)(read & (ring->capacity - 1));
+    iree_hal_amdgpu_notification_entry_t* entry = &ring->entries[index];
+    if (entry->submission_epoch > current_epoch) break;
+
+    if (entry->semaphore != pending_semaphore) {
+      // Semaphore changed — flush the previous span.
+      if (pending_semaphore != NULL) {
+        const iree_async_frontier_t* frontier =
+            iree_hal_amdgpu_notification_ring_read_span_frontier(
+                ring, pending_flags, /*has_transition_snapshot=*/true,
+                fallback_frontier);
+        iree_status_t signal_status = iree_async_semaphore_signal_untainted(
+            pending_semaphore, pending_value, frontier);
+        if (IREE_UNLIKELY(!iree_status_is_ok(signal_status))) {
+          iree_async_semaphore_fail(pending_semaphore, signal_status);
+        }
+      }
+      pending_semaphore = entry->semaphore;
+      pending_value = entry->timeline_value;
+      pending_flags = entry->flags;
+    } else {
+      // Same semaphore, later epoch — take the later value (monotonic).
+      pending_value = entry->timeline_value;
+      pending_flags |= entry->flags;
+    }
+
+    ++read;
+    ++drained_count;
+  }
+
+  // Flush the final span. If the next unread entry is a different semaphore
+  // then a transition snapshot for this completed span has already been
+  // written, even though that next entry has not completed yet.
+  if (pending_semaphore != NULL) {
+    bool has_transition_snapshot = false;
+    if (read < write) {
+      uint32_t next_index = (uint32_t)(read & (ring->capacity - 1));
+      const iree_hal_amdgpu_notification_entry_t* next_entry =
+          &ring->entries[next_index];
+      has_transition_snapshot = next_entry->semaphore != pending_semaphore;
+    }
+    const iree_async_frontier_t* frontier =
+        iree_hal_amdgpu_notification_ring_read_span_frontier(
+            ring, pending_flags, has_transition_snapshot, fallback_frontier);
+    iree_status_t signal_status = iree_async_semaphore_signal_untainted(
+        pending_semaphore, pending_value, frontier);
+    if (IREE_UNLIKELY(!iree_status_is_ok(signal_status))) {
+      iree_async_semaphore_fail(pending_semaphore, signal_status);
+    }
+  }
+
+  // Release retained resources for all completed epochs.
+  uint64_t highest_kernarg_position = 0;
+  uint64_t highest_queue_upload_position = 0;
+  for (uint64_t epoch = previous_drained; epoch < current_epoch; ++epoch) {
+    uint32_t reclaim_index = (uint32_t)(epoch & (ring->capacity - 1));
+    uint64_t kernarg_write_position =
+        ring->reclaim_entries[reclaim_index].kernarg_write_position;
+    if (kernarg_write_position > highest_kernarg_position) {
+      highest_kernarg_position = kernarg_write_position;
+    }
+    uint64_t queue_upload_write_position =
+        ring->reclaim_entries[reclaim_index].queue_upload_write_position;
+    if (queue_upload_write_position > highest_queue_upload_position) {
+      highest_queue_upload_position = queue_upload_write_position;
+    }
+    iree_hal_amdgpu_reclaim_entry_release(&ring->reclaim_entries[reclaim_index],
+                                          ring->block_pool);
+  }
+  iree_atomic_store(&ring->epoch.last_drained, (int64_t)current_epoch,
+                    iree_memory_order_release);
+
+  iree_hal_amdgpu_notification_ring_store_position(&ring->read, read,
+                                                   iree_memory_order_release);
+
+  out_reclaim_positions->kernarg_write_position = highest_kernarg_position;
+  out_reclaim_positions->queue_upload_write_position =
+      highest_queue_upload_position;
+  return drained_count;
+}
+
+iree_host_size_t iree_hal_amdgpu_notification_ring_drain(
+    iree_hal_amdgpu_notification_ring_t* ring,
+    const iree_async_frontier_t* fallback_frontier,
+    iree_hal_amdgpu_reclaim_retire_fn_t retire_fn, void* retire_user_data,
+    uint64_t* out_kernarg_reclaim_position) {
+  IREE_ASSERT_ARGUMENT(out_kernarg_reclaim_position);
+  iree_hal_amdgpu_reclaim_positions_t reclaim_positions = {0};
+  iree_host_size_t drained_count =
+      iree_hal_amdgpu_notification_ring_drain_reclaim_positions(
+          ring, fallback_frontier, retire_fn, retire_user_data,
+          &reclaim_positions);
+  *out_kernarg_reclaim_position = reclaim_positions.kernarg_write_position;
+  return drained_count;
+}
+
+iree_host_size_t iree_hal_amdgpu_notification_ring_fail_all_reclaim_positions(
+    iree_hal_amdgpu_notification_ring_t* ring, iree_status_t error_status,
+    iree_hal_amdgpu_reclaim_positions_t* out_reclaim_positions) {
+  IREE_ASSERT_ARGUMENT(ring);
+  IREE_ASSERT_ARGUMENT(out_reclaim_positions);
+
+  memset(out_reclaim_positions, 0, sizeof(*out_reclaim_positions));
+
+  iree_host_size_t failed_count = 0;
+  uint64_t read = iree_hal_amdgpu_notification_ring_load_position(
+      &ring->read, iree_memory_order_relaxed);
+  uint64_t write = iree_hal_amdgpu_notification_ring_load_position(
+      &ring->write, iree_memory_order_acquire);
+  while (read < write) {
+    uint32_t index = (uint32_t)(read & (ring->capacity - 1));
+    iree_hal_amdgpu_notification_entry_t* entry = &ring->entries[index];
+
+    // Check-before-clone: only clone and fail if this semaphore hasn't been
+    // failed yet. Avoids cloning status objects (which contain stack traces)
+    // for every entry when many entries share a semaphore. The TOCTOU between
+    // the load and the CAS inside semaphore_fail is harmless — fail_all runs
+    // single-threaded on the proactor, so the only way failure_status is
+    // non-zero is from an earlier entry in this same loop.
+    if (iree_atomic_load(&entry->semaphore->failure_status,
+                         iree_memory_order_acquire) == 0) {
+      iree_async_semaphore_fail(entry->semaphore,
+                                iree_status_clone(error_status));
+    }
+
+    ++read;
+    ++failed_count;
+  }
+
+  // Release retained resources for all epochs.
+  uint64_t highest_kernarg_position = 0;
+  uint64_t highest_queue_upload_position = 0;
+  uint64_t last_drained = (uint64_t)iree_atomic_load(&ring->epoch.last_drained,
+                                                     iree_memory_order_relaxed);
+  for (uint64_t epoch = last_drained; epoch < ring->epoch.next_submission;
+       ++epoch) {
+    uint32_t reclaim_index = (uint32_t)(epoch & (ring->capacity - 1));
+    uint64_t kernarg_write_position =
+        ring->reclaim_entries[reclaim_index].kernarg_write_position;
+    if (kernarg_write_position > highest_kernarg_position) {
+      highest_kernarg_position = kernarg_write_position;
+    }
+    uint64_t queue_upload_write_position =
+        ring->reclaim_entries[reclaim_index].queue_upload_write_position;
+    if (queue_upload_write_position > highest_queue_upload_position) {
+      highest_queue_upload_position = queue_upload_write_position;
+    }
+    iree_hal_amdgpu_reclaim_entry_execute_pre_signal_action(
+        &ring->reclaim_entries[reclaim_index], error_status);
+    iree_hal_amdgpu_reclaim_entry_release(&ring->reclaim_entries[reclaim_index],
+                                          ring->block_pool);
+  }
+  iree_atomic_store(&ring->epoch.last_drained,
+                    (int64_t)ring->epoch.next_submission,
+                    iree_memory_order_release);
+
+  iree_hal_amdgpu_notification_ring_store_position(&ring->read, read,
+                                                   iree_memory_order_release);
+
+  out_reclaim_positions->kernarg_write_position = highest_kernarg_position;
+  out_reclaim_positions->queue_upload_write_position =
+      highest_queue_upload_position;
+  return failed_count;
+}
+
+iree_host_size_t iree_hal_amdgpu_notification_ring_fail_all(
+    iree_hal_amdgpu_notification_ring_t* ring, iree_status_t error_status,
+    uint64_t* out_kernarg_reclaim_position) {
+  IREE_ASSERT_ARGUMENT(out_kernarg_reclaim_position);
+  iree_hal_amdgpu_reclaim_positions_t reclaim_positions = {0};
+  iree_host_size_t failed_count =
+      iree_hal_amdgpu_notification_ring_fail_all_reclaim_positions(
+          ring, error_status, &reclaim_positions);
+  *out_kernarg_reclaim_position = reclaim_positions.kernarg_write_position;
+  return failed_count;
+}
diff --git a/runtime/src/iree/hal/drivers/amdgpu/util/notification_ring.h b/runtime/src/iree/hal/drivers/amdgpu/util/notification_ring.h
new file mode 100644
index 0000000..d02119a
--- /dev/null
+++ b/runtime/src/iree/hal/drivers/amdgpu/util/notification_ring.h
@@ -0,0 +1,471 @@
+// Copyright 2026 The IREE Authors
+//
+// Licensed under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+
+// Epoch-driven notification ring for mapping GPU submission completions to
+// async semaphore signals. Each queue has one notification ring; the ring
+// maps monotonic submission epochs to pending semaphore signals that the
+// host queue drains when the GPU advances the epoch.
+//
+// The epoch signal is a single hsa_signal_t initialized to a large value
+// and decremented by 1 on each submission's last AQL packet completion.
+// The current epoch (count of completed submissions) is:
+//   INITIAL_VALUE - hsa_signal_load(epoch_signal)
+//
+// The ring uses a hot/cold split for cache-friendly drain:
+//   - Hot entries (32 bytes each): semaphore, value, epoch, and reserved
+//     padding. Stored in a power-of-two ring buffer, dense and L1-resident for
+//     the coalescing scan.
+//   - Cold frontier snapshots (variable-size): written to a byte ring only
+//     at semaphore transition points that still have an undrained span. The
+//     drain reads snapshots only when flushing a span that actually has a
+//     transition snapshot, avoiding per-entry frontier overhead.
+//
+// The drain coalesces consecutive same-semaphore entries into a single
+// signal call using a single-slot accumulator. For N dispatches on the same
+// stream semaphore, this produces 1 signal instead of N.
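+//
+// A minimal submission/drain sketch (illustrative, not normative; assumes an
+// initialized |ring|, a caller-owned |semaphore|, and a |timeline_value|):
+//
+//   // Submission path: reserve before emitting AQL packets so the debug-only
+//   // overflow asserts remain programmer-error-only.
+//   IREE_RETURN_IF_ERROR(iree_hal_amdgpu_notification_ring_reserve(
+//       ring, /*entry_count=*/1, /*frontier_snapshot_count=*/1));
+//   iree_hal_amdgpu_reclaim_entry_t* reclaim =
+//       iree_hal_amdgpu_notification_ring_reclaim_entry(ring);
+//   // ... fill |reclaim| (retained resources, ring write positions) ...
+//   uint64_t epoch = iree_hal_amdgpu_notification_ring_advance_epoch(ring);
+//   // ... emit AQL packets using the epoch signal as completion_signal ...
+//   iree_hal_amdgpu_notification_ring_push(
+//       ring, epoch, semaphore, timeline_value,
+//       IREE_HAL_AMDGPU_NOTIFICATION_ENTRY_FLAG_NONE);
+//
+//   // Drain path (single consumer), after the GPU decrements the signal:
+//   iree_hal_amdgpu_reclaim_positions_t positions;
+//   iree_hal_amdgpu_notification_ring_drain_reclaim_positions(
+//       ring, /*fallback_frontier=*/NULL, /*retire_fn=*/NULL,
+//       /*retire_user_data=*/NULL, &positions);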
+
+#ifndef IREE_HAL_DRIVERS_AMDGPU_UTIL_NOTIFICATION_RING_H_
+#define IREE_HAL_DRIVERS_AMDGPU_UTIL_NOTIFICATION_RING_H_
+
+#include "iree/async/frontier.h"
+#include "iree/async/semaphore.h"
+#include "iree/base/api.h"
+#include "iree/base/internal/arena.h"
+#include "iree/base/internal/atomics.h"
+#include "iree/hal/drivers/amdgpu/util/libhsa.h"
+#include "iree/hal/resource.h"
+
+#ifdef __cplusplus
+extern "C" {
+#endif  // __cplusplus
+
+typedef struct iree_hal_resource_set_t iree_hal_resource_set_t;
+
+// Initial value for the epoch signal. The CP decrements by 1 on each
+// submission's last packet completion. The epoch (number of completed
+// submissions) is: INITIAL_VALUE - hsa_signal_load(epoch_signal).
+// INT64_MAX/2 gives ~4.6e18 decrements before overflow (~146 years at 1
+// billion submissions/second).
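+//
+// For example (illustrative): after three completed submissions the signal
+// reads INITIAL_VALUE - 3, so the computed epoch is
+// INITIAL_VALUE - (INITIAL_VALUE - 3) = 3.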
+#define IREE_HAL_AMDGPU_EPOCH_INITIAL_VALUE (INT64_MAX / 2)
+
+// Default notification ring capacity.
+#define IREE_HAL_AMDGPU_DEFAULT_NOTIFICATION_CAPACITY 1024
+
+// Sentinel value in frontier_snapshot_t::frontier.entry_count indicating the
+// reader should wrap to byte 0 of the frontier ring. Written when a snapshot
+// doesn't fit in the remaining buffer space.
+#define IREE_HAL_AMDGPU_FRONTIER_SNAPSHOT_SENTINEL 0xFF
+
+// Maximum number of frontier entries in one snapshot.
+//
+// This must be >= the queue frontier capacity used by host_queue.c, because
+// transition snapshots serialize the queue's accumulated frontier verbatim.
+#define IREE_HAL_AMDGPU_MAX_FRONTIER_SNAPSHOT_ENTRY_COUNT 64
+
+// Maximum size of a single frontier snapshot in bytes.
+#define IREE_HAL_AMDGPU_MAX_FRONTIER_SNAPSHOT_SIZE     \
+  (sizeof(iree_hal_amdgpu_frontier_snapshot_t) +       \
+   IREE_HAL_AMDGPU_MAX_FRONTIER_SNAPSHOT_ENTRY_COUNT * \
+       sizeof(iree_async_frontier_entry_t))
+
+//===----------------------------------------------------------------------===//
+// iree_hal_amdgpu_notification_entry_t (hot, 32 bytes)
+//===----------------------------------------------------------------------===//
+
+typedef uint32_t iree_hal_amdgpu_notification_entry_flags_t;
+enum iree_hal_amdgpu_notification_entry_flag_bits_t {
+  IREE_HAL_AMDGPU_NOTIFICATION_ENTRY_FLAG_NONE = 0u,
+  // This entry's same-semaphore span does not own a cold frontier snapshot.
+  // Drain must signal the semaphore with |fallback_frontier| instead of
+  // consuming from the cold frontier ring.
+  IREE_HAL_AMDGPU_NOTIFICATION_ENTRY_FLAG_OMIT_FRONTIER_SNAPSHOT = 1u << 0,
+};
+
+// A pending semaphore signal associated with a submission epoch. Contains
+// only the data needed for the drain coalescing scan — frontier data is
+// stored separately in the frontier snapshot ring.
+//
+// Entries are stored in a power-of-two ring buffer (32 bytes each, two per
+// 64-byte cache line).
+typedef struct iree_hal_amdgpu_notification_entry_t {
+  // Semaphore to signal when the epoch is reached. Not retained — the caller
+  // ensures the semaphore outlives the notification (queue teardown waits for
+  // all in-flight work and drains before destroying semaphores).
+  iree_async_semaphore_t* semaphore;
+  // Timeline value to signal the semaphore to.
+  uint64_t timeline_value;
+  // One-based submission epoch on this queue. When the queue's current epoch
+  // reaches this value (current_epoch >= submission_epoch), this entry is ready
+  // to drain.
+  uint64_t submission_epoch;
+  // Flags controlling how this entry is drained.
+  iree_hal_amdgpu_notification_entry_flags_t flags;
+  // Reserved padding to keep hot entries at 32 bytes (2 per 64-byte cache
+  // line).
+  uint32_t reserved0;
+} iree_hal_amdgpu_notification_entry_t;
+static_assert(sizeof(iree_hal_amdgpu_notification_entry_t) == 32,
+              "notification entries must remain 32 bytes");
+
+//===----------------------------------------------------------------------===//
+// iree_hal_amdgpu_frontier_snapshot_t (cold, variable-size)
+//===----------------------------------------------------------------------===//
+
+// Frontier snapshot written at each semaphore transition point. Records the
+// queue's accumulated frontier at the end of a same-semaphore span. The
+// drain reads one snapshot per coalesced flush.
+//
+// Variable-size: the header is followed by frontier.entry_count frontier
+// entries. Total size: sizeof(header) + frontier.entry_count *
+// sizeof(iree_async_frontier_entry_t).
+typedef struct iree_hal_amdgpu_frontier_snapshot_t {
+  // Epoch at the end of the same-semaphore span this snapshot covers.
+  uint64_t epoch;
+  // Frontier header followed by frontier.entry_count entries. An entry_count
+  // of 0xFF is a sentinel indicating the reader should wrap to byte 0 (not a
+  // real snapshot).
+  iree_async_frontier_header_t frontier;
+  // Followed by frontier.entry_count x iree_async_frontier_entry_t.
+} iree_hal_amdgpu_frontier_snapshot_t;
+
+//===----------------------------------------------------------------------===//
+// iree_hal_amdgpu_reclaim_entry_t (cold, epoch-indexed)
+//===----------------------------------------------------------------------===//
+
+// Number of resource pointers stored inline in each reclaim entry. Covers
+// the common case of 1 signal semaphore + up to 7 operation resources
+// (buffers, executables, command buffers) without any block pool allocation.
+// Dispatches with more than 7 bindings spill to a block-pool-allocated array.
+#define IREE_HAL_AMDGPU_RECLAIM_INLINE_CAPACITY 8
+
+typedef struct iree_hal_amdgpu_reclaim_entry_t iree_hal_amdgpu_reclaim_entry_t;
+
+// Infallible callback executed for one completed epoch before that epoch's
+// user-visible semaphore signals are published.
+//
+// This is the pre-signal state-transition lane for operations like transient
+// buffer commit/decommit. |status| is OK for normal GPU completion and a
+// borrowed queue/device failure status when the queue fails outstanding work.
+// Any object referenced by |user_data| must also be retained in the reclaim
+// entry's post-signal |resources| array if its lifetime must extend past
+// callback execution.
+typedef void(IREE_API_PTR* iree_hal_amdgpu_reclaim_action_fn_t)(
+    iree_hal_amdgpu_reclaim_entry_t* entry, void* user_data,
+    iree_status_t status);
+
+typedef struct iree_hal_amdgpu_reclaim_action_t {
+  // Callback invoked with |user_data| when the epoch is retired or failed.
+  iree_hal_amdgpu_reclaim_action_fn_t fn;
+  // Opaque callback state. Any object it references must be kept alive past
+  // callback execution by retaining it in the reclaim entry's |resources|.
+  void* user_data;
+} iree_hal_amdgpu_reclaim_action_t;
+
+// Queue-owned ring positions retired by one or more completed epochs.
+typedef struct iree_hal_amdgpu_reclaim_positions_t {
+  // Highest kernarg ring write position retired by the completed epochs.
+  uint64_t kernarg_write_position;
+  // Highest queue upload ring write position retired by the completed epochs.
+  uint64_t queue_upload_write_position;
+} iree_hal_amdgpu_reclaim_positions_t;
+
+// Optional callback invoked for one completed epoch after pre-signal actions
+// execute and before user-visible semaphore signals publish.
+typedef void(IREE_API_PTR* iree_hal_amdgpu_reclaim_retire_fn_t)(
+    iree_hal_amdgpu_reclaim_entry_t* entry, uint64_t epoch, void* user_data);
+
+// Per-epoch resource reclaim entry. Stores retained HAL resource pointers
+// that are released when the epoch completes (drain time). One entry per
+// advance_epoch call, indexed by epoch & (capacity - 1).
+//
+// Resources include signal semaphores (the notification entry stores
+// unretained semaphore pointers — the reclaim entry keeps them alive)
+// and operation-specific resources (buffers, executables, command buffers).
+struct iree_hal_amdgpu_reclaim_entry_t {
+  // Pointer to the retained-resource pointer array. Points to inline_resources
+  // when count <= INLINE_CAPACITY, otherwise to a block-pool-allocated array.
+  iree_hal_resource_t** resources;
+  // Optional resource set released with this entry after user signals publish.
+  iree_hal_resource_set_t* resource_set;
+  // One bounded pre-signal action for this epoch. Executed before any
+  // user-visible signal publication for the epoch when drain observes normal
+  // completion, and during fail_all with the failure status before resources
+  // are released.
+  iree_hal_amdgpu_reclaim_action_t pre_signal_action;
+  // First dispatch profiling event position reserved by this epoch.
+  // Valid only when |profile_event_count| is non-zero.
+  uint64_t profile_event_first_position;
+  // First queue device profiling event position reserved by this epoch.
+  // Valid only when |queue_device_event_count| is non-zero.
+  uint64_t queue_device_event_first_position;
+  // Kernarg ring write position at the time of this submission. Drain/fail_all
+  // report the highest position across retired epochs so the caller can reclaim
+  // kernarg blocks. 0 means no kernarg was allocated.
+  uint64_t kernarg_write_position;
+  // Queue upload ring write position at the time of this submission.
+  // Drain/fail_all report the highest position across retired epochs so the
+  // caller can reclaim upload bytes. 0 means no upload bytes were allocated.
+  uint64_t queue_upload_write_position;
+  // Number of dispatch profiling events reserved by this epoch.
+  uint32_t profile_event_count;
+  // Number of queue device profiling events reserved by this epoch.
+  uint32_t queue_device_event_count;
+  // Number of retained resources stored in |resources|.
+  uint16_t count;
+  // Reserved padding for stable layout.
+  uint16_t reserved[1];
+  // Inline retained-resource storage for common small submissions.
+  iree_hal_resource_t*
+      inline_resources[IREE_HAL_AMDGPU_RECLAIM_INLINE_CAPACITY];
+};
+
+// Prepares a reclaim entry for |count| resources. If count fits inline,
+// sets |*out_resources| to the entry's inline storage. Otherwise acquires
+// a block from |block_pool| and sets |*out_resources| to point into it.
+// The caller fills the array with retained resource pointers, sets any
+// queue-owned ring write positions, and sets entry->count before advancing the
+// submission epoch.
+iree_status_t iree_hal_amdgpu_reclaim_entry_prepare(
+    iree_hal_amdgpu_reclaim_entry_t* entry, iree_arena_block_pool_t* block_pool,
+    uint16_t count, iree_hal_resource_t*** out_resources);
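+//
+// A minimal fill sketch (illustrative; |entry| from reclaim_entry(), |buffer|
+// a HAL buffer that must stay live until the epoch completes):
+//   iree_hal_resource_t** resources = NULL;
+//   IREE_RETURN_IF_ERROR(iree_hal_amdgpu_reclaim_entry_prepare(
+//       entry, block_pool, /*count=*/1, &resources));
+//   iree_hal_buffer_retain(buffer);
+//   resources[0] = (iree_hal_resource_t*)buffer;
+//   entry->count = 1;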
+
+// Releases all resources in the entry and returns any overflow block to
+// the pool. Zeros entry->count.
+void iree_hal_amdgpu_reclaim_entry_release(
+    iree_hal_amdgpu_reclaim_entry_t* entry,
+    iree_arena_block_pool_t* block_pool);
+
+//===----------------------------------------------------------------------===//
+// iree_hal_amdgpu_notification_ring_t
+//===----------------------------------------------------------------------===//
+
+// Epoch-driven notification ring with hot/cold split storage.
+//
+// The hot entry ring stores 32-byte entries for the drain coalescing scan.
+// The cold frontier ring stores variable-size frontier snapshots written only
+// at semaphore transition points — one snapshot per same-semaphore span.
+// Both arrays are allocated in a single contiguous block.
+typedef struct iree_hal_amdgpu_notification_ring_t {
+  // HSA API handle for signal operations. Not retained.
+  const iree_hal_amdgpu_libhsa_t* libhsa;
+  // Host allocator used for the ring's contiguous storage allocation.
+  iree_allocator_t host_allocator;
+
+  // Monotonic completion counter.
+  struct {
+    // Per-queue hsa_signal_t created at init, destroyed at deinit. Set as
+    // completion_signal on the last AQL packet of each submission; the CP
+    // decrements it by 1 on completion.
+    hsa_signal_t signal;
+    // Next epoch to assign. Incremented by 1 per submission.
+    uint64_t next_submission;
+    // Last epoch observed by drain. The consumer stores with release after
+    // releasing reclaim entries; the submission path acquires this in reserve()
+    // to avoid reusing a still-live reclaim slot for a zero-signal epoch.
+    iree_atomic_int64_t last_drained;
+  } epoch;
+
+  // Hot entry ring (32 bytes per entry, cache-friendly for drain scan).
+  iree_hal_amdgpu_notification_entry_t* entries;
+  // Producer index (submission path advances with a release store after
+  // writing entries). The consumer acquires this before reading entries.
+  iree_atomic_int64_t write;
+  // Consumer index (drain/fail_all advances with a release store after
+  // consuming entries). The producer acquires this before capacity checks.
+  iree_atomic_int64_t read;
+  // Power-of-two ring capacity. Indices are masked by (capacity - 1).
+  uint32_t capacity;
+
+  // Cold frontier snapshot byte ring (variable-size, sparse).
+  // Written at semaphore transition points by the submission path via
+  // push_frontier_snapshot. Read sequentially by drain when a completed span
+  // reaches a different next semaphore; late snapshots for already-drained
+  // spans are discarded before processing new completions.
+  struct {
+    uint8_t* data;
+    // Power-of-two byte capacity. Monotonic byte positions are masked by
+    // (capacity - 1) to derive in-buffer offsets.
+    iree_host_size_t capacity;
+    // Monotonic byte positions, not modulo offsets. Same SPSC release/acquire
+    // contract as the hot entry ring indices.
+    iree_atomic_int64_t write;
+    iree_atomic_int64_t read;
+  } frontier_ring;
+
+  // Block pool for overflow reclaim allocations. Borrowed from the physical
+  // device (NUMA-pinned); valid for the lifetime of the ring.
+  iree_arena_block_pool_t* block_pool;
+
+  // Per-epoch resource reclaim entries. Indexed by
+  // epoch.next_submission & (capacity - 1) on the submission path, drained
+  // in lockstep with the notification entries. Same capacity as the hot
+  // entry ring (one reclaim entry per epoch, bounded by notification capacity).
+  iree_hal_amdgpu_reclaim_entry_t* reclaim_entries;
+
+  // Pointer to the base of the single allocation backing the entry array,
+  // frontier ring buffer, and reclaim entries. Freed in deinitialize.
+  void* storage;
+} iree_hal_amdgpu_notification_ring_t;
+
+// Initializes a notification ring. Creates the epoch signal and allocates
+// the hot entry array, frontier snapshot byte ring, and reclaim entries in
+// a single allocation.
+//
+// |block_pool| is used for overflow reclaim allocations (dispatches with
+// more than IREE_HAL_AMDGPU_RECLAIM_INLINE_CAPACITY resources). Must
+// outlive the ring.
+//
+// |capacity| must be a power of two.
+iree_status_t iree_hal_amdgpu_notification_ring_initialize(
+    const iree_hal_amdgpu_libhsa_t* libhsa, iree_arena_block_pool_t* block_pool,
+    uint32_t capacity, iree_allocator_t host_allocator,
+    iree_hal_amdgpu_notification_ring_t* out_ring);
+
+// Deinitializes the notification ring. Destroys the epoch signal and frees
+// the backing storage.
+//
+// All in-flight work must have completed and been drained before calling.
+void iree_hal_amdgpu_notification_ring_deinitialize(
+    iree_hal_amdgpu_notification_ring_t* ring);
+
+// Returns the epoch signal for use as completion_signal on AQL packets.
+hsa_signal_t iree_hal_amdgpu_notification_ring_epoch_signal(
+    const iree_hal_amdgpu_notification_ring_t* ring);
+
+// Advances the submission epoch counter and returns the assigned one-based
+// frontier epoch. Called by the submission path after all AQL packets for a
+// submission have been written to the hardware queue.
+//
+// Epochs are one-based because the device-side wait formula in
+// host_queue.c (compare_value = INITIAL_VALUE - target_epoch + 1) collapses
+// to "signal < INITIAL_VALUE + 1" for target_epoch == 0, which is
+// trivially true for any signal value. With one-based epochs, target == 0
+// is reserved for "no submission has happened yet" and the formula only
+// fires once at least one completion has been observed.
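+//
+// As a concrete check (illustrative): a wait for target_epoch == 1 compares
+// against INITIAL_VALUE - 1 + 1 == INITIAL_VALUE; "signal < INITIAL_VALUE" is
+// false at the initial value and first holds at INITIAL_VALUE - 1, i.e. after
+// exactly one completion.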
+uint64_t iree_hal_amdgpu_notification_ring_advance_epoch(
+    iree_hal_amdgpu_notification_ring_t* ring);
+
+// Verifies that the ring has enough space for |entry_count| notification
+// entries and up to |frontier_snapshot_count| max-size frontier snapshots.
+//
+// The snapshot reservation is conservative: it includes one extra max-size
+// snapshot worth of bytes for a wrap sentinel/tail gap. Callers should check
+// this before emitting AQL packets and calling push/push_frontier_snapshot so
+// debug-only overflow asserts remain programmer-error-only.
+iree_status_t iree_hal_amdgpu_notification_ring_reserve(
+    const iree_hal_amdgpu_notification_ring_t* ring,
+    iree_host_size_t entry_count, iree_host_size_t frontier_snapshot_count);
+
+// Returns true if reserve() would currently succeed for the requested counts.
+//
+// This is intended for continuation-style submissions that can park and retry
+// after drain instead of treating temporary ring pressure as an error. Callers
+// must still use reserve() on paths where exhaustion is a terminal status.
+bool iree_hal_amdgpu_notification_ring_can_reserve(
+    const iree_hal_amdgpu_notification_ring_t* ring,
+    iree_host_size_t entry_count, iree_host_size_t frontier_snapshot_count);
+
+// Returns the reclaim entry for the next submission. Reclaim entries are
+// indexed by the zero-based completion interval, so callers must fill this
+// before calling advance_epoch.
+static inline iree_hal_amdgpu_reclaim_entry_t*
+iree_hal_amdgpu_notification_ring_reclaim_entry(
+    iree_hal_amdgpu_notification_ring_t* ring) {
+  return &ring->reclaim_entries[ring->epoch.next_submission &
+                                (ring->capacity - 1)];
+}
+
+// Pushes a notification entry for a semaphore signal at the given epoch.
+//
+// The caller must ensure the ring has capacity by calling
+// iree_hal_amdgpu_notification_ring_reserve() before publishing AQL packets and
+// then pushing the corresponding notification entries.
+//
+// Frontier data is NOT stored per-entry. The caller must separately call
+// push_frontier_snapshot at semaphore transition points.
+void iree_hal_amdgpu_notification_ring_push(
+    iree_hal_amdgpu_notification_ring_t* ring, uint64_t submission_epoch,
+    iree_async_semaphore_t* semaphore, uint64_t timeline_value,
+    iree_hal_amdgpu_notification_entry_flags_t flags);
+
+// Pushes a frontier snapshot to the frontier byte ring. Called by the
+// submission path when the signal semaphore changes between consecutive
+// submissions (semaphore transition). |epoch| is the epoch of the last entry
+// in the ending same-semaphore span. |frontier| is the queue's accumulated
+// frontier at that point (before merging new dependencies). Must be a valid
+// frontier (entry_count may be 0).
+//
+// The drain reads snapshots when flushing a completed span that has reached a
+// different next semaphore. A final flush with no visible transition uses the
+// fallback_frontier provided to drain, and late snapshots whose covered epoch
+// has already drained are discarded on the next drain.
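+//
+// For example (illustrative): epochs 1 and 2 signal semaphore A and epoch 3
+// signals semaphore B. The A -> B transition pushes one snapshot with
+// epoch == 2 (the last entry of the A span) before the B entry is pushed;
+// drain later flushes the coalesced A span using that snapshot.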
+void iree_hal_amdgpu_notification_ring_push_frontier_snapshot(
+    iree_hal_amdgpu_notification_ring_t* ring, uint64_t epoch,
+    const iree_async_frontier_t* frontier);
+
+// Drains all completed notification entries, coalescing consecutive
+// same-semaphore entries into a single signal call (single-slot accumulator).
+//
+// |fallback_frontier| is used for the final coalesced flush when no frontier
+// snapshot exists for the last same-semaphore span. It may be NULL if the
+// caller has already merged frontier state into the semaphore at submission
+// time and only needs completion-time timeline advancement/untainting.
+//
+// Stores the highest queue-owned ring positions across all retired epochs in
+// |out_reclaim_positions|. Positions are set to 0 if no epochs were retired.
+//
+// |retire_fn|, when provided, is called once per retired epoch before
+// user-visible semaphore publication. It must not publish user-visible
+// completion itself.
+//
+// Returns the number of entries drained.
+iree_host_size_t iree_hal_amdgpu_notification_ring_drain_reclaim_positions(
+    iree_hal_amdgpu_notification_ring_t* ring,
+    const iree_async_frontier_t* fallback_frontier,
+    iree_hal_amdgpu_reclaim_retire_fn_t retire_fn, void* retire_user_data,
+    iree_hal_amdgpu_reclaim_positions_t* out_reclaim_positions);
+
+// Drains all completed notification entries; a convenience wrapper around
+// drain_reclaim_positions that only reports the kernarg position.
+//
+// Stores the highest kernarg_write_position across all retired epochs in
+// |out_kernarg_reclaim_position|. Set to 0 if no epochs were retired.
+//
+// |retire_fn|, when provided, is called once per retired epoch before
+// user-visible semaphore publication. It must not publish user-visible
+// completion itself.
+//
+// Returns the number of entries drained.
+iree_host_size_t iree_hal_amdgpu_notification_ring_drain(
+    iree_hal_amdgpu_notification_ring_t* ring,
+    const iree_async_frontier_t* fallback_frontier,
+    iree_hal_amdgpu_reclaim_retire_fn_t retire_fn, void* retire_user_data,
+    uint64_t* out_kernarg_reclaim_position);
+
+// Fails all pending notification entries with |error_status|.
+// Each unique semaphore is failed exactly once; duplicate entries for the same
+// semaphore skip the clone+fail (check-before-clone: status objects contain
+// stack traces and are not free to clone).
+//
+// |error_status| is borrowed, not consumed — the caller retains ownership.
+//
+// Stores the highest queue-owned ring positions across all failed entries in
+// |out_reclaim_positions| (same semantics as drain_reclaim_positions).
+//
+// Returns the number of entries failed.
+iree_host_size_t iree_hal_amdgpu_notification_ring_fail_all_reclaim_positions(
+    iree_hal_amdgpu_notification_ring_t* ring, iree_status_t error_status,
+    iree_hal_amdgpu_reclaim_positions_t* out_reclaim_positions);
+
+// Fails all pending notification entries with |error_status|; a convenience
+// wrapper around fail_all_reclaim_positions that only reports the kernarg
+// position.
+//
+// Stores the highest kernarg_write_position across all failed entries in
+// |out_kernarg_reclaim_position| (same semantics as drain).
+//
+// Returns the number of entries failed.
+iree_host_size_t iree_hal_amdgpu_notification_ring_fail_all(
+    iree_hal_amdgpu_notification_ring_t* ring, iree_status_t error_status,
+    uint64_t* out_kernarg_reclaim_position);
+
+#ifdef __cplusplus
+}  // extern "C"
+#endif  // __cplusplus
+
+#endif  // IREE_HAL_DRIVERS_AMDGPU_UTIL_NOTIFICATION_RING_H_
diff --git a/runtime/src/iree/hal/drivers/amdgpu/util/notification_ring_test.cc b/runtime/src/iree/hal/drivers/amdgpu/util/notification_ring_test.cc
new file mode 100644
index 0000000..6579861
--- /dev/null
+++ b/runtime/src/iree/hal/drivers/amdgpu/util/notification_ring_test.cc
@@ -0,0 +1,820 @@
+// Copyright 2026 The IREE Authors
+//
+// Licensed under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+
+#include "iree/hal/drivers/amdgpu/util/notification_ring.h"
+
+#include <array>
+#include <memory>
+
+#include "iree/async/proactor_platform.h"
+#include "iree/async/semaphore.h"
+#include "iree/base/api.h"
+#include "iree/base/internal/arena.h"
+#include "iree/hal/drivers/amdgpu/util/libhsa.h"
+#include "iree/testing/gtest.h"
+#include "iree/testing/status_matchers.h"
+
+namespace iree::hal::amdgpu {
+namespace {
+
+struct MaxFrontierStorage {
+  uint8_t entry_count;
+  uint8_t reserved[7];
+  iree_async_frontier_entry_t
+      entries[IREE_HAL_AMDGPU_MAX_FRONTIER_SNAPSHOT_ENTRY_COUNT];
+};
+
+typedef struct PreSignalActionState {
+  // Semaphore that must not be user-visible when the callback runs.
+  iree_async_semaphore_t* semaphore;
+  // Number of times the callback has run.
+  int callback_count;
+} PreSignalActionState;
+
+typedef struct RetireCallbackState {
+  // Semaphore that must not be user-visible when the callback runs.
+  iree_async_semaphore_t* semaphore;
+  // Reclaim entry observed by the callback.
+  iree_hal_amdgpu_reclaim_entry_t* entry;
+  // Submission epoch observed by the callback.
+  uint64_t epoch;
+  // Number of times the callback has run.
+  int callback_count;
+} RetireCallbackState;
+
+static void VerifySemaphoreNotVisibleBeforePreSignalAction(
+    iree_hal_amdgpu_reclaim_entry_t* entry, void* user_data,
+    iree_status_t status) {
+  IREE_EXPECT_OK(status);
+  EXPECT_NE(entry, nullptr);
+  auto* state = static_cast<PreSignalActionState*>(user_data);
+  EXPECT_EQ(iree_async_semaphore_query(state->semaphore), 0u);
+  ++state->callback_count;
+}
+
+static void VerifySemaphoreNotVisibleBeforeRetireCallback(
+    iree_hal_amdgpu_reclaim_entry_t* entry, uint64_t epoch, void* user_data) {
+  EXPECT_NE(entry, nullptr);
+  auto* state = static_cast<RetireCallbackState*>(user_data);
+  EXPECT_EQ(iree_async_semaphore_query(state->semaphore), 0u);
+  state->entry = entry;
+  state->epoch = epoch;
+  ++state->callback_count;
+}
+
+// RAII wrapper for notification rings. Ensures deinitialize is called on
+// destruction.
+struct NotificationRingDeleter {
+  void operator()(iree_hal_amdgpu_notification_ring_t* ring) {
+    iree_hal_amdgpu_notification_ring_deinitialize(ring);
+    delete ring;
+  }
+};
+using NotificationRingPtr = std::unique_ptr<iree_hal_amdgpu_notification_ring_t,
+                                            NotificationRingDeleter>;
+
+struct NotificationRingTest : public ::testing::Test {
+  static iree_allocator_t host_allocator;
+  static iree_hal_amdgpu_libhsa_t libhsa;
+  static iree_async_proactor_t* proactor;
+  static iree_arena_block_pool_t block_pool;
+
+  static void SetUpTestSuite() {
+    IREE_TRACE_SCOPE();
+    host_allocator = iree_allocator_system();
+    iree_status_t status = iree_hal_amdgpu_libhsa_initialize(
+        IREE_HAL_AMDGPU_LIBHSA_FLAG_NONE, iree_string_view_list_empty(),
+        host_allocator, &libhsa);
+    if (!iree_status_is_ok(status)) {
+      iree_status_fprint(stderr, status);
+      iree_status_free(status);
+      GTEST_SKIP() << "HSA not available, skipping tests";
+    }
+    IREE_ASSERT_OK(iree_async_proactor_create_platform(
+        iree_async_proactor_options_default(), host_allocator, &proactor));
+    iree_arena_block_pool_initialize(4096, host_allocator, &block_pool);
+  }
+
+  static void TearDownTestSuite() {
+    IREE_TRACE_SCOPE();
+    iree_arena_block_pool_deinitialize(&block_pool);
+    iree_async_proactor_release(proactor);
+    proactor = NULL;
+    iree_hal_amdgpu_libhsa_deinitialize(&libhsa);
+  }
+
+  // Initializes a notification ring with the given capacity and returns an
+  // RAII wrapper that ensures deinitialize is called on destruction.
+  iree::StatusOr<NotificationRingPtr> InitializeRing(
+      uint32_t capacity = IREE_HAL_AMDGPU_DEFAULT_NOTIFICATION_CAPACITY) {
+    auto ring = std::make_unique<iree_hal_amdgpu_notification_ring_t>();
+    IREE_RETURN_IF_ERROR(iree_hal_amdgpu_notification_ring_initialize(
+        &libhsa, &block_pool, capacity, host_allocator, ring.get()));
+    return NotificationRingPtr(ring.release());
+  }
+
+  // Creates an async semaphore with initial value 0.
+  iree_async_semaphore_t* CreateSemaphore(
+      uint8_t frontier_capacity =
+          IREE_ASYNC_SEMAPHORE_DEFAULT_FRONTIER_CAPACITY) {
+    iree_async_semaphore_t* semaphore = NULL;
+    IREE_CHECK_OK(iree_async_semaphore_create(proactor, /*initial_value=*/0,
+                                              frontier_capacity, host_allocator,
+                                              &semaphore));
+    return semaphore;
+  }
+
+  // Simulates GPU completion of N total submissions by storing the
+  // appropriate epoch signal value (INITIAL - N).
+  void SimulateCompletions(iree_hal_amdgpu_notification_ring_t* ring,
+                           uint64_t total_completed) {
+    hsa_signal_value_t target_value =
+        (hsa_signal_value_t)(IREE_HAL_AMDGPU_EPOCH_INITIAL_VALUE -
+                             total_completed);
+    iree_hsa_signal_store_screlease(IREE_LIBHSA(&libhsa), ring->epoch.signal,
+                                    target_value);
+  }
+
+  // Empty frontier for drain calls (no accumulated causal context).
+  static constexpr iree_async_frontier_t kEmptyFrontier = {0};
+
+  static const iree_async_frontier_t* InitializeMaxFrontier(
+      MaxFrontierStorage* storage) {
+    auto* frontier = reinterpret_cast<iree_async_frontier_t*>(storage);
+    iree_async_frontier_initialize(
+        frontier, IREE_HAL_AMDGPU_MAX_FRONTIER_SNAPSHOT_ENTRY_COUNT);
+    for (uint8_t i = 0; i < IREE_HAL_AMDGPU_MAX_FRONTIER_SNAPSHOT_ENTRY_COUNT;
+         ++i) {
+      frontier->entries[i].axis = iree_async_axis_make_queue(
+          /*session_epoch=*/1, /*machine_index=*/2, /*device_index=*/3, i);
+      frontier->entries[i].epoch = i + 1;
+    }
+    return frontier;
+  }
+
+  static iree_hal_amdgpu_reclaim_entry_t* ReclaimEntryForNextEpoch(
+      iree_hal_amdgpu_notification_ring_t* ring,
+      uint64_t kernarg_write_position = 0,
+      uint64_t queue_upload_write_position = 0) {
+    iree_hal_amdgpu_reclaim_entry_t* reclaim_entry =
+        iree_hal_amdgpu_notification_ring_reclaim_entry(ring);
+    reclaim_entry->kernarg_write_position = kernarg_write_position;
+    reclaim_entry->queue_upload_write_position = queue_upload_write_position;
+    return reclaim_entry;
+  }
+
+  static void PushNotification(iree_hal_amdgpu_notification_ring_t* ring,
+                               uint64_t epoch,
+                               iree_async_semaphore_t* semaphore,
+                               uint64_t value) {
+    iree_hal_amdgpu_notification_ring_push(
+        ring, epoch, semaphore, value,
+        IREE_HAL_AMDGPU_NOTIFICATION_ENTRY_FLAG_NONE);
+  }
+
+  static void PushNotificationOmittingFrontierSnapshot(
+      iree_hal_amdgpu_notification_ring_t* ring, uint64_t epoch,
+      iree_async_semaphore_t* semaphore, uint64_t value) {
+    iree_hal_amdgpu_notification_ring_push(
+        ring, epoch, semaphore, value,
+        IREE_HAL_AMDGPU_NOTIFICATION_ENTRY_FLAG_OMIT_FRONTIER_SNAPSHOT);
+  }
+};
+iree_allocator_t NotificationRingTest::host_allocator;
+iree_hal_amdgpu_libhsa_t NotificationRingTest::libhsa;
+iree_async_proactor_t* NotificationRingTest::proactor = NULL;
+iree_arena_block_pool_t NotificationRingTest::block_pool;
+
+TEST_F(NotificationRingTest, InitDeinit) {
+  IREE_ASSERT_OK_AND_ASSIGN(auto ring, InitializeRing());
+}
+
+TEST_F(NotificationRingTest, InvalidCapacity) {
+  iree_hal_amdgpu_notification_ring_t ring;
+  IREE_EXPECT_STATUS_IS(
+      IREE_STATUS_INVALID_ARGUMENT,
+      iree_hal_amdgpu_notification_ring_initialize(
+          &libhsa, &block_pool, /*capacity=*/100, host_allocator, &ring));
+}
+
+TEST_F(NotificationRingTest, DrainEmpty) {
+  IREE_ASSERT_OK_AND_ASSIGN(auto ring, InitializeRing());
+
+  uint64_t kernarg_position = 0;
+  EXPECT_EQ(
+      iree_hal_amdgpu_notification_ring_drain(
+          ring.get(), &kEmptyFrontier, nullptr, nullptr, &kernarg_position),
+      0u);
+  EXPECT_EQ(kernarg_position, 0u);
+}
+
+TEST_F(NotificationRingTest, SingleNotification) {
+  IREE_ASSERT_OK_AND_ASSIGN(auto ring, InitializeRing());
+  iree_async_semaphore_t* semaphore = CreateSemaphore();
+
+  // Push a notification for epoch 1.
+  ReclaimEntryForNextEpoch(ring.get());
+  uint64_t epoch = iree_hal_amdgpu_notification_ring_advance_epoch(ring.get());
+  EXPECT_EQ(epoch, 1u);
+  PushNotification(ring.get(), epoch, semaphore, 1);
+
+  // Drain before completion: nothing happens.
+  uint64_t kernarg_position = 0;
+  EXPECT_EQ(
+      iree_hal_amdgpu_notification_ring_drain(
+          ring.get(), &kEmptyFrontier, nullptr, nullptr, &kernarg_position),
+      0u);
+  EXPECT_EQ(iree_async_semaphore_query(semaphore), 0u);
+
+  // Simulate GPU completing epoch 1.
+  SimulateCompletions(ring.get(), 1);
+
+  // Drain signals the semaphore.
+  EXPECT_EQ(
+      iree_hal_amdgpu_notification_ring_drain(
+          ring.get(), &kEmptyFrontier, nullptr, nullptr, &kernarg_position),
+      1u);
+  EXPECT_EQ(iree_async_semaphore_query(semaphore), 1u);
+
+  // Draining again is a no-op.
+  EXPECT_EQ(
+      iree_hal_amdgpu_notification_ring_drain(
+          ring.get(), &kEmptyFrontier, nullptr, nullptr, &kernarg_position),
+      0u);
+
+  iree_async_semaphore_release(semaphore);
+}
+
+TEST_F(NotificationRingTest, MultiplePerEpochAndSparseEpochs) {
+  IREE_ASSERT_OK_AND_ASSIGN(auto ring, InitializeRing());
+  iree_async_semaphore_t* semaphore_a = CreateSemaphore();
+  iree_async_semaphore_t* semaphore_b = CreateSemaphore();
+
+  // Epoch 1: two semaphores signaled from one submission.
+  // This is a semaphore transition (A -> B), so push a frontier snapshot.
+  ReclaimEntryForNextEpoch(ring.get());
+  uint64_t epoch1 = iree_hal_amdgpu_notification_ring_advance_epoch(ring.get());
+  PushNotification(ring.get(), epoch1, semaphore_a, 1);
+  iree_hal_amdgpu_notification_ring_push_frontier_snapshot(ring.get(), epoch1,
+                                                           &kEmptyFrontier);
+  PushNotification(ring.get(), epoch1, semaphore_b, 5);
+
+  // Epoch 2: no notification (non-signaling submission).
+  ReclaimEntryForNextEpoch(ring.get());
+  iree_hal_amdgpu_notification_ring_advance_epoch(ring.get());
+
+  // Epoch 3: signals semaphore_a again (transition B -> A).
+  ReclaimEntryForNextEpoch(ring.get());
+  uint64_t epoch3 = iree_hal_amdgpu_notification_ring_advance_epoch(ring.get());
+  EXPECT_EQ(epoch3, 3u);
+  // The snapshot covers the ending B span whose last entry landed at epoch1,
+  // so it is tagged with epoch1 even though it is pushed alongside epoch 3.
+  iree_hal_amdgpu_notification_ring_push_frontier_snapshot(ring.get(), epoch1,
+                                                           &kEmptyFrontier);
+  PushNotification(ring.get(), epoch3, semaphore_a, 10);
+
+  // Complete epoch 1 only. Drain should coalesce the A and B entries at
+  // epoch 1 into two signals (one per semaphore).
+  SimulateCompletions(ring.get(), 1);
+  uint64_t kernarg_position = 0;
+  EXPECT_EQ(
+      iree_hal_amdgpu_notification_ring_drain(
+          ring.get(), &kEmptyFrontier, nullptr, nullptr, &kernarg_position),
+      2u);
+  EXPECT_EQ(iree_async_semaphore_query(semaphore_a), 1u);
+  EXPECT_EQ(iree_async_semaphore_query(semaphore_b), 5u);
+
+  // Complete all 3 epochs.
+  SimulateCompletions(ring.get(), 3);
+  EXPECT_EQ(
+      iree_hal_amdgpu_notification_ring_drain(
+          ring.get(), &kEmptyFrontier, nullptr, nullptr, &kernarg_position),
+      1u);
+  EXPECT_EQ(iree_async_semaphore_query(semaphore_a), 10u);
+
+  iree_async_semaphore_release(semaphore_b);
+  iree_async_semaphore_release(semaphore_a);
+}
+
+TEST_F(NotificationRingTest, OmittedSnapshotDoesNotConsumeNextSnapshot) {
+  IREE_ASSERT_OK_AND_ASSIGN(auto ring, InitializeRing());
+  iree_async_semaphore_t* private_semaphore = CreateSemaphore();
+  iree_async_semaphore_t* public_semaphore = CreateSemaphore();
+  iree_async_semaphore_t* final_semaphore = CreateSemaphore();
+
+  iree_async_single_frontier_t public_frontier_storage;
+  const iree_async_axis_t public_axis =
+      iree_async_axis_make_queue(/*session_epoch=*/1, /*machine_index=*/0,
+                                 /*device_index=*/0, /*queue_index=*/3);
+  iree_async_single_frontier_initialize(&public_frontier_storage, public_axis,
+                                        /*epoch=*/42);
+
+  ReclaimEntryForNextEpoch(ring.get());
+  uint64_t private_epoch =
+      iree_hal_amdgpu_notification_ring_advance_epoch(ring.get());
+  PushNotificationOmittingFrontierSnapshot(ring.get(), private_epoch,
+                                           private_semaphore,
+                                           /*value=*/1);
+
+  ReclaimEntryForNextEpoch(ring.get());
+  uint64_t public_epoch =
+      iree_hal_amdgpu_notification_ring_advance_epoch(ring.get());
+  PushNotification(ring.get(), public_epoch, public_semaphore, /*value=*/1);
+
+  iree_hal_amdgpu_notification_ring_push_frontier_snapshot(
+      ring.get(), public_epoch,
+      iree_async_single_frontier_as_const_frontier(&public_frontier_storage));
+  ReclaimEntryForNextEpoch(ring.get());
+  uint64_t final_epoch =
+      iree_hal_amdgpu_notification_ring_advance_epoch(ring.get());
+  PushNotification(ring.get(), final_epoch, final_semaphore, /*value=*/1);
+
+  SimulateCompletions(ring.get(), 2);
+  uint64_t kernarg_position = 0;
+  EXPECT_EQ(
+      iree_hal_amdgpu_notification_ring_drain(
+          ring.get(), &kEmptyFrontier, nullptr, nullptr, &kernarg_position),
+      2u);
+
+  iree_async_single_frontier_t queried_frontier;
+  EXPECT_EQ(iree_async_semaphore_query_frontier(
+                private_semaphore,
+                iree_async_single_frontier_as_frontier(&queried_frontier),
+                /*capacity=*/1),
+            0u);
+  EXPECT_EQ(iree_async_semaphore_query_frontier(
+                public_semaphore,
+                iree_async_single_frontier_as_frontier(&queried_frontier),
+                /*capacity=*/1),
+            1u);
+  EXPECT_EQ(queried_frontier.entries[0].axis, public_axis);
+  EXPECT_EQ(queried_frontier.entries[0].epoch, 42u);
+
+  SimulateCompletions(ring.get(), 3);
+  EXPECT_EQ(
+      iree_hal_amdgpu_notification_ring_drain(
+          ring.get(), &kEmptyFrontier, nullptr, nullptr, &kernarg_position),
+      1u);
+
+  iree_async_semaphore_release(final_semaphore);
+  iree_async_semaphore_release(public_semaphore);
+  iree_async_semaphore_release(private_semaphore);
+}
+
+TEST_F(NotificationRingTest, LateSnapshotForAlreadyDrainedSpanIsDiscarded) {
+  IREE_ASSERT_OK_AND_ASSIGN(auto ring, InitializeRing());
+  iree_async_semaphore_t* semaphore_a = CreateSemaphore();
+  iree_async_semaphore_t* semaphore_b = CreateSemaphore();
+
+  ReclaimEntryForNextEpoch(ring.get());
+  uint64_t epoch1 = iree_hal_amdgpu_notification_ring_advance_epoch(ring.get());
+  PushNotification(ring.get(), epoch1, semaphore_a, /*value=*/1);
+
+  SimulateCompletions(ring.get(), 1);
+  uint64_t kernarg_position = 0;
+  EXPECT_EQ(
+      iree_hal_amdgpu_notification_ring_drain(
+          ring.get(), &kEmptyFrontier, nullptr, nullptr, &kernarg_position),
+      1u);
+
+  iree_async_single_frontier_t stale_frontier_storage;
+  const iree_async_axis_t stale_axis =
+      iree_async_axis_make_queue(/*session_epoch=*/1, /*machine_index=*/0,
+                                 /*device_index=*/0, /*queue_index=*/4);
+  iree_async_single_frontier_initialize(&stale_frontier_storage, stale_axis,
+                                        /*epoch=*/99);
+  iree_hal_amdgpu_notification_ring_push_frontier_snapshot(
+      ring.get(), epoch1,
+      iree_async_single_frontier_as_const_frontier(&stale_frontier_storage));
+
+  ReclaimEntryForNextEpoch(ring.get());
+  uint64_t epoch2 = iree_hal_amdgpu_notification_ring_advance_epoch(ring.get());
+  PushNotification(ring.get(), epoch2, semaphore_b, /*value=*/1);
+
+  SimulateCompletions(ring.get(), 2);
+  EXPECT_EQ(
+      iree_hal_amdgpu_notification_ring_drain(
+          ring.get(), &kEmptyFrontier, nullptr, nullptr, &kernarg_position),
+      1u);
+
+  iree_host_size_t frontier_write = (iree_host_size_t)iree_atomic_load(
+      &ring->frontier_ring.write, iree_memory_order_acquire);
+  iree_host_size_t frontier_read = (iree_host_size_t)iree_atomic_load(
+      &ring->frontier_ring.read, iree_memory_order_acquire);
+  EXPECT_EQ(frontier_read, frontier_write);
+
+  iree_async_semaphore_release(semaphore_b);
+  iree_async_semaphore_release(semaphore_a);
+}
+
+TEST_F(NotificationRingTest, CoalescingSameSemaphore) {
+  IREE_ASSERT_OK_AND_ASSIGN(auto ring, InitializeRing());
+  iree_async_semaphore_t* semaphore = CreateSemaphore();
+
+  // Three submissions, each signaling the semaphore to increasing values.
+  // All same semaphore — no frontier snapshots needed (no transitions).
+  for (uint64_t i = 0; i < 3; ++i) {
+    ReclaimEntryForNextEpoch(ring.get());
+    uint64_t epoch =
+        iree_hal_amdgpu_notification_ring_advance_epoch(ring.get());
+    PushNotification(ring.get(), epoch, semaphore, i + 1);
+  }
+
+  // Complete only epoch 1.
+  SimulateCompletions(ring.get(), 1);
+  uint64_t kernarg_position = 0;
+  EXPECT_EQ(
+      iree_hal_amdgpu_notification_ring_drain(
+          ring.get(), &kEmptyFrontier, nullptr, nullptr, &kernarg_position),
+      1u);
+  EXPECT_EQ(iree_async_semaphore_query(semaphore), 1u);
+
+  // Complete epochs 2 and 3. Drain should coalesce both entries (same
+  // semaphore) into a single signal to value 3.
+  SimulateCompletions(ring.get(), 3);
+  EXPECT_EQ(
+      iree_hal_amdgpu_notification_ring_drain(
+          ring.get(), &kEmptyFrontier, nullptr, nullptr, &kernarg_position),
+      2u);
+  // Coalesced: signaled directly to 3, skipping 2.
+  EXPECT_EQ(iree_async_semaphore_query(semaphore), 3u);
+
+  iree_async_semaphore_release(semaphore);
+}
+
+TEST_F(NotificationRingTest, RingWrapAround) {
+  const uint32_t capacity = 4;
+  IREE_ASSERT_OK_AND_ASSIGN(auto ring, InitializeRing(capacity));
+  iree_async_semaphore_t* semaphore = CreateSemaphore();
+
+  // Fill the ring, drain, repeat three times to exercise wrap-around.
+  for (int round = 0; round < 3; ++round) {
+    for (uint32_t i = 0; i < capacity; ++i) {
+      ReclaimEntryForNextEpoch(ring.get());
+      uint64_t epoch =
+          iree_hal_amdgpu_notification_ring_advance_epoch(ring.get());
+      PushNotification(ring.get(), epoch, semaphore,
+                       (round * capacity) + i + 1);
+    }
+    SimulateCompletions(ring.get(), ring->epoch.next_submission);
+    uint64_t kernarg_position = 0;
+    // All same semaphore — coalesced into 1 signal per round.
+    EXPECT_EQ(
+        iree_hal_amdgpu_notification_ring_drain(
+            ring.get(), &kEmptyFrontier, nullptr, nullptr, &kernarg_position),
+        capacity);
+  }
+
+  // Coalesced across each round: signaled to 4, then 8, then 12.
+  EXPECT_EQ(iree_async_semaphore_query(semaphore), 3u * capacity);
+
+  iree_async_semaphore_release(semaphore);
+}
+
+TEST_F(NotificationRingTest, ReserveReturnsResourceExhaustedWhenFull) {
+  IREE_ASSERT_OK_AND_ASSIGN(auto ring, InitializeRing(/*capacity=*/2));
+  iree_async_semaphore_t* semaphore = CreateSemaphore();
+
+  IREE_EXPECT_OK(iree_hal_amdgpu_notification_ring_reserve(
+      ring.get(), /*entry_count=*/2, /*frontier_snapshot_count=*/0));
+
+  ReclaimEntryForNextEpoch(ring.get());
+  uint64_t epoch1 = iree_hal_amdgpu_notification_ring_advance_epoch(ring.get());
+  PushNotification(ring.get(), epoch1, semaphore, /*timeline_value=*/1);
+  ReclaimEntryForNextEpoch(ring.get());
+  uint64_t epoch2 = iree_hal_amdgpu_notification_ring_advance_epoch(ring.get());
+  PushNotification(ring.get(), epoch2, semaphore, /*timeline_value=*/2);
+
+  IREE_EXPECT_STATUS_IS(
+      IREE_STATUS_RESOURCE_EXHAUSTED,
+      iree_hal_amdgpu_notification_ring_reserve(ring.get(), /*entry_count=*/1,
+                                                /*frontier_snapshot_count=*/0));
+
+  SimulateCompletions(ring.get(), 1);
+  uint64_t kernarg_position = 0;
+  EXPECT_EQ(
+      iree_hal_amdgpu_notification_ring_drain(
+          ring.get(), &kEmptyFrontier, nullptr, nullptr, &kernarg_position),
+      1u);
+
+  IREE_EXPECT_OK(iree_hal_amdgpu_notification_ring_reserve(
+      ring.get(), /*entry_count=*/1, /*frontier_snapshot_count=*/0));
+
+  iree_async_semaphore_release(semaphore);
+}
+
+TEST_F(NotificationRingTest, FrontierSnapshotWrapAround) {
+  IREE_ASSERT_OK_AND_ASSIGN(auto ring, InitializeRing(/*capacity=*/4));
+
+  std::array<iree_async_semaphore_t*, 10> semaphores = {};
+  for (iree_async_semaphore_t*& semaphore : semaphores) {
+    semaphore =
+        CreateSemaphore(IREE_HAL_AMDGPU_MAX_FRONTIER_SNAPSHOT_ENTRY_COUNT);
+  }
+
+  MaxFrontierStorage snapshot_storage;
+  const iree_async_frontier_t* snapshot_frontier =
+      InitializeMaxFrontier(&snapshot_storage);
+
+  ReclaimEntryForNextEpoch(ring.get());
+  uint64_t previous_epoch =
+      iree_hal_amdgpu_notification_ring_advance_epoch(ring.get());
+  PushNotification(ring.get(), previous_epoch, semaphores[0], 1);
+
+  // Build enough transitions to force the snapshot byte-ring write offset to
+  // wrap while one unread snapshot remains near the tail. Draining two entries
+  // after each batch preserves one transition snapshot for the next batch.
+  for (size_t i = 1; i < 4; ++i) {
+    iree_hal_amdgpu_notification_ring_push_frontier_snapshot(
+        ring.get(), previous_epoch, snapshot_frontier);
+    ReclaimEntryForNextEpoch(ring.get());
+    previous_epoch =
+        iree_hal_amdgpu_notification_ring_advance_epoch(ring.get());
+    PushNotification(ring.get(), previous_epoch, semaphores[i], 1);
+  }
+
+  uint64_t kernarg_position = 0;
+  SimulateCompletions(ring.get(), 2);
+  EXPECT_EQ(
+      iree_hal_amdgpu_notification_ring_drain(
+          ring.get(), &kEmptyFrontier, nullptr, nullptr, &kernarg_position),
+      2u);
+
+  SimulateCompletions(ring.get(), 4);
+  for (size_t i = 4; i < 6; ++i) {
+    iree_hal_amdgpu_notification_ring_push_frontier_snapshot(
+        ring.get(), previous_epoch, snapshot_frontier);
+    ReclaimEntryForNextEpoch(ring.get());
+    previous_epoch =
+        iree_hal_amdgpu_notification_ring_advance_epoch(ring.get());
+    PushNotification(ring.get(), previous_epoch, semaphores[i], 1);
+  }
+  EXPECT_EQ(
+      iree_hal_amdgpu_notification_ring_drain(
+          ring.get(), &kEmptyFrontier, nullptr, nullptr, &kernarg_position),
+      2u);
+
+  SimulateCompletions(ring.get(), 6);
+  EXPECT_EQ(
+      iree_hal_amdgpu_notification_ring_drain(
+          ring.get(), &kEmptyFrontier, nullptr, nullptr, &kernarg_position),
+      2u);
+
+  for (size_t i = 6; i < semaphores.size(); ++i) {
+    iree_hal_amdgpu_notification_ring_push_frontier_snapshot(
+        ring.get(), previous_epoch, snapshot_frontier);
+    ReclaimEntryForNextEpoch(ring.get());
+    previous_epoch =
+        iree_hal_amdgpu_notification_ring_advance_epoch(ring.get());
+    PushNotification(ring.get(), previous_epoch, semaphores[i], 1);
+  }
+  iree_host_size_t frontier_write = (iree_host_size_t)iree_atomic_load(
+      &ring->frontier_ring.write, iree_memory_order_acquire);
+  iree_host_size_t frontier_read = (iree_host_size_t)iree_atomic_load(
+      &ring->frontier_ring.read, iree_memory_order_acquire);
+  EXPECT_LT(frontier_write & (ring->frontier_ring.capacity - 1),
+            frontier_read & (ring->frontier_ring.capacity - 1));
+  SimulateCompletions(ring.get(), semaphores.size());
+  EXPECT_EQ(
+      iree_hal_amdgpu_notification_ring_drain(
+          ring.get(), &kEmptyFrontier, nullptr, nullptr, &kernarg_position),
+      4u);
+
+  MaxFrontierStorage queried_storage;
+  auto* queried_frontier =
+      reinterpret_cast<iree_async_frontier_t*>(&queried_storage);
+  // semaphores[8]'s frontier snapshot is written after the byte-ring write
+  // offset wraps back to the beginning, so this specifically verifies the
+  // wrapped snapshot was not mistaken for an empty ring.
+  EXPECT_EQ(iree_async_semaphore_query_frontier(
+                semaphores[8], queried_frontier,
+                IREE_HAL_AMDGPU_MAX_FRONTIER_SNAPSHOT_ENTRY_COUNT),
+            IREE_HAL_AMDGPU_MAX_FRONTIER_SNAPSHOT_ENTRY_COUNT);
+  EXPECT_EQ(queried_frontier->entries[0].axis,
+            snapshot_frontier->entries[0].axis);
+  EXPECT_EQ(queried_frontier->entries[0].epoch,
+            snapshot_frontier->entries[0].epoch);
+  EXPECT_EQ(queried_frontier
+                ->entries[IREE_HAL_AMDGPU_MAX_FRONTIER_SNAPSHOT_ENTRY_COUNT - 1]
+                .axis,
+            snapshot_frontier
+                ->entries[IREE_HAL_AMDGPU_MAX_FRONTIER_SNAPSHOT_ENTRY_COUNT - 1]
+                .axis);
+  EXPECT_EQ(queried_frontier
+                ->entries[IREE_HAL_AMDGPU_MAX_FRONTIER_SNAPSHOT_ENTRY_COUNT - 1]
+                .epoch,
+            snapshot_frontier
+                ->entries[IREE_HAL_AMDGPU_MAX_FRONTIER_SNAPSHOT_ENTRY_COUNT - 1]
+                .epoch);
+
+  for (iree_async_semaphore_t* semaphore : semaphores) {
+    iree_async_semaphore_release(semaphore);
+  }
+}
+
+TEST_F(NotificationRingTest, KernargPositionReporting) {
+  IREE_ASSERT_OK_AND_ASSIGN(auto ring, InitializeRing());
+  iree_async_semaphore_t* semaphore = CreateSemaphore();
+
+  // Three submissions with increasing kernarg positions.
+  for (uint64_t i = 0; i < 3; ++i) {
+    ReclaimEntryForNextEpoch(ring.get(), (i + 1) * 64);
+    uint64_t epoch =
+        iree_hal_amdgpu_notification_ring_advance_epoch(ring.get());
+    PushNotification(ring.get(), epoch, semaphore, i + 1);
+  }
+
+  // Complete epoch 1 only. Drain should report position 64.
+  SimulateCompletions(ring.get(), 1);
+  uint64_t kernarg_position = 0;
+  EXPECT_EQ(
+      iree_hal_amdgpu_notification_ring_drain(
+          ring.get(), &kEmptyFrontier, nullptr, nullptr, &kernarg_position),
+      1u);
+  EXPECT_EQ(kernarg_position, 64u);
+
+  // Complete all. Drain should report position 192 (max of 128, 192).
+  SimulateCompletions(ring.get(), 3);
+  EXPECT_EQ(
+      iree_hal_amdgpu_notification_ring_drain(
+          ring.get(), &kEmptyFrontier, nullptr, nullptr, &kernarg_position),
+      2u);
+  EXPECT_EQ(kernarg_position, 192u);
+
+  iree_async_semaphore_release(semaphore);
+}
+
+TEST_F(NotificationRingTest, QueueOwnedReclaimPositionReporting) {
+  IREE_ASSERT_OK_AND_ASSIGN(auto ring, InitializeRing());
+
+  // Epoch 1: no user-visible signals, but both queue-owned rings must retire.
+  ReclaimEntryForNextEpoch(ring.get(), /*kernarg_write_position=*/64,
+                           /*queue_upload_write_position=*/256);
+  EXPECT_EQ(iree_hal_amdgpu_notification_ring_advance_epoch(ring.get()), 1u);
+
+  // Epoch 2: later kernargs but an earlier upload watermark.
+  ReclaimEntryForNextEpoch(ring.get(), /*kernarg_write_position=*/192,
+                           /*queue_upload_write_position=*/128);
+  EXPECT_EQ(iree_hal_amdgpu_notification_ring_advance_epoch(ring.get()), 2u);
+
+  SimulateCompletions(ring.get(), 2);
+  iree_hal_amdgpu_reclaim_positions_t reclaim_positions = {0};
+  EXPECT_EQ(
+      iree_hal_amdgpu_notification_ring_drain_reclaim_positions(
+          ring.get(), &kEmptyFrontier, nullptr, nullptr, &reclaim_positions),
+      0u);
+  EXPECT_EQ(reclaim_positions.kernarg_write_position, 192u);
+  EXPECT_EQ(reclaim_positions.queue_upload_write_position, 256u);
+}
+
+TEST_F(NotificationRingTest, KernargPositionReportingForZeroSignalEpochs) {
+  IREE_ASSERT_OK_AND_ASSIGN(auto ring, InitializeRing());
+
+  // Epoch 1: no user-visible signals, but kernarg memory must still retire.
+  ReclaimEntryForNextEpoch(ring.get(), 64);
+  EXPECT_EQ(iree_hal_amdgpu_notification_ring_advance_epoch(ring.get()), 1u);
+
+  // Epoch 2: another no-signal submission with a later kernarg watermark.
+  ReclaimEntryForNextEpoch(ring.get(), 192);
+  EXPECT_EQ(iree_hal_amdgpu_notification_ring_advance_epoch(ring.get()), 2u);
+
+  // Complete epoch 1 only.
+  SimulateCompletions(ring.get(), 1);
+  uint64_t kernarg_position = 0;
+  EXPECT_EQ(
+      iree_hal_amdgpu_notification_ring_drain(
+          ring.get(), &kEmptyFrontier, nullptr, nullptr, &kernarg_position),
+      0u);
+  EXPECT_EQ(kernarg_position, 64u);
+
+  // Complete both epochs.
+  SimulateCompletions(ring.get(), 2);
+  EXPECT_EQ(
+      iree_hal_amdgpu_notification_ring_drain(
+          ring.get(), &kEmptyFrontier, nullptr, nullptr, &kernarg_position),
+      0u);
+  EXPECT_EQ(kernarg_position, 192u);
+}
+
+TEST_F(NotificationRingTest, PreSignalActionRunsBeforeSemaphorePublication) {
+  IREE_ASSERT_OK_AND_ASSIGN(auto ring, InitializeRing());
+  iree_async_semaphore_t* semaphore = CreateSemaphore();
+  PreSignalActionState action_state = {
+      .semaphore = semaphore,
+      .callback_count = 0,
+  };
+
+  iree_hal_amdgpu_reclaim_entry_t* reclaim_entry =
+      ReclaimEntryForNextEpoch(ring.get());
+  reclaim_entry->pre_signal_action = {
+      .fn = VerifySemaphoreNotVisibleBeforePreSignalAction,
+      .user_data = &action_state,
+  };
+  uint64_t epoch = iree_hal_amdgpu_notification_ring_advance_epoch(ring.get());
+  PushNotification(ring.get(), epoch, semaphore, 1);
+
+  SimulateCompletions(ring.get(), 1);
+  uint64_t kernarg_position = 0;
+  EXPECT_EQ(
+      iree_hal_amdgpu_notification_ring_drain(
+          ring.get(), &kEmptyFrontier, nullptr, nullptr, &kernarg_position),
+      1u);
+  EXPECT_EQ(action_state.callback_count, 1);
+  EXPECT_EQ(iree_async_semaphore_query(semaphore), 1u);
+
+  iree_async_semaphore_release(semaphore);
+}
+
+TEST_F(NotificationRingTest, RetireCallbackRunsBeforeSemaphorePublication) {
+  IREE_ASSERT_OK_AND_ASSIGN(auto ring, InitializeRing());
+  iree_async_semaphore_t* semaphore = CreateSemaphore();
+  RetireCallbackState callback_state = {
+      .semaphore = semaphore,
+      .entry = nullptr,
+      .epoch = 0,
+      .callback_count = 0,
+  };
+
+  iree_hal_amdgpu_reclaim_entry_t* reclaim_entry =
+      ReclaimEntryForNextEpoch(ring.get());
+  reclaim_entry->profile_event_first_position = 42;
+  reclaim_entry->profile_event_count = 3;
+  uint64_t epoch = iree_hal_amdgpu_notification_ring_advance_epoch(ring.get());
+  PushNotification(ring.get(), epoch, semaphore, 1);
+
+  SimulateCompletions(ring.get(), 1);
+  uint64_t kernarg_position = 0;
+  EXPECT_EQ(iree_hal_amdgpu_notification_ring_drain(
+                ring.get(), &kEmptyFrontier,
+                VerifySemaphoreNotVisibleBeforeRetireCallback, &callback_state,
+                &kernarg_position),
+            1u);
+  EXPECT_EQ(callback_state.callback_count, 1);
+  EXPECT_EQ(callback_state.entry, reclaim_entry);
+  EXPECT_EQ(callback_state.epoch, epoch);
+  EXPECT_EQ(iree_async_semaphore_query(semaphore), 1u);
+
+  iree_async_semaphore_release(semaphore);
+}
+
+TEST_F(NotificationRingTest, NonMonotonicValuesAcrossSemaphores) {
+  IREE_ASSERT_OK_AND_ASSIGN(auto ring, InitializeRing());
+  iree_async_semaphore_t* semaphore = CreateSemaphore();
+
+  // Epoch 1: signal semaphore to value 5.
+  ReclaimEntryForNextEpoch(ring.get(), 64);
+  uint64_t epoch1 = iree_hal_amdgpu_notification_ring_advance_epoch(ring.get());
+  PushNotification(ring.get(), epoch1, semaphore, 5);
+
+  // Epoch 2: signal a second semaphore to the lower value 3. A second
+  // semaphore is required to preserve the non-monotonic ordering: drain
+  // coalesces consecutive entries for the same semaphore into a single signal
+  // of the last value, so a (5 then 3) pair on one semaphore would collapse
+  // into one signal(3) instead of exercising two distinct signals. With
+  // distinct semaphores both entries stay visible to drain and both values
+  // are published as recorded.
+  iree_async_semaphore_t* semaphore2 = CreateSemaphore();
+  iree_hal_amdgpu_notification_ring_push_frontier_snapshot(ring.get(), epoch1,
+                                                           &kEmptyFrontier);
+  ReclaimEntryForNextEpoch(ring.get(), 128);
+  uint64_t epoch2 = iree_hal_amdgpu_notification_ring_advance_epoch(ring.get());
+  PushNotification(ring.get(), epoch2, semaphore2, 3);
+
+  // Complete both epochs.
+  SimulateCompletions(ring.get(), 2);
+
+  // Drain processes both entries (different semaphores, no coalescing).
+  uint64_t kernarg_position = 0;
+  EXPECT_EQ(
+      iree_hal_amdgpu_notification_ring_drain(
+          ring.get(), &kEmptyFrontier, nullptr, nullptr, &kernarg_position),
+      2u);
+
+  // Semaphore got signaled to 5, semaphore2 got signaled to 3.
+  EXPECT_EQ(iree_async_semaphore_query(semaphore), 5u);
+  EXPECT_EQ(iree_async_semaphore_query(semaphore2), 3u);
+
+  // Kernarg position reporting is unaffected.
+  EXPECT_EQ(kernarg_position, 128u);
+
+  iree_async_semaphore_release(semaphore2);
+  iree_async_semaphore_release(semaphore);
+}
+
+}  // namespace
+}  // namespace iree::hal::amdgpu
diff --git a/runtime/src/iree/hal/drivers/amdgpu/util/pm4_capabilities.h b/runtime/src/iree/hal/drivers/amdgpu/util/pm4_capabilities.h
new file mode 100644
index 0000000..ff748a4
--- /dev/null
+++ b/runtime/src/iree/hal/drivers/amdgpu/util/pm4_capabilities.h
@@ -0,0 +1,114 @@
+// Copyright 2026 The IREE Authors
+//
+// Licensed under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+
+#ifndef IREE_HAL_DRIVERS_AMDGPU_UTIL_PM4_CAPABILITIES_H_
+#define IREE_HAL_DRIVERS_AMDGPU_UTIL_PM4_CAPABILITIES_H_
+
+#include "iree/base/api.h"
+
+#ifdef __cplusplus
+extern "C" {
+#endif  // __cplusplus
+
+// Hardware mechanism used for cross-queue epoch waits after wait resolution
+// proves that a dependency has already been submitted by a local peer queue.
+typedef enum iree_hal_amdgpu_wait_barrier_strategy_e {
+  // No device-side 64-bit epoch wait is known for this agent; unresolved
+  // cross-queue waits must use software deferral.
+  IREE_HAL_AMDGPU_WAIT_BARRIER_STRATEGY_DEFER = 0,
+  // AMD vendor AQL BARRIER_VALUE packet.
+  IREE_HAL_AMDGPU_WAIT_BARRIER_STRATEGY_AQL_BARRIER_VALUE = 1,
+  // AMD vendor AQL PM4-IB packet executing a WAIT_REG_MEM64 PM4 packet.
+  IREE_HAL_AMDGPU_WAIT_BARRIER_STRATEGY_PM4_WAIT_REG_MEM64 = 2,
+} iree_hal_amdgpu_wait_barrier_strategy_t;
+
+// AMD vendor-packet and PM4 packet-family capabilities available on a physical
+// device.
+enum iree_hal_amdgpu_vendor_packet_capability_bits_t {
+  // AMD vendor AQL PM4-IB packets can jump to device-visible PM4 programs.
+  IREE_HAL_AMDGPU_VENDOR_PACKET_CAPABILITY_AQL_PM4_IB = 1u << 0,
+  // AMD vendor AQL BARRIER_VALUE packets can wait on arbitrary signal values.
+  IREE_HAL_AMDGPU_VENDOR_PACKET_CAPABILITY_AQL_BARRIER_VALUE = 1u << 1,
+  // PM4 WAIT_REG_MEM64 packets can perform 64-bit memory comparisons.
+  IREE_HAL_AMDGPU_VENDOR_PACKET_CAPABILITY_WAIT_REG_MEM64 = 1u << 2,
+  // PM4 COPY_DATA can copy the immediate timestamp counter to memory.
+  IREE_HAL_AMDGPU_VENDOR_PACKET_CAPABILITY_PM4_COPY_TIMESTAMP = 1u << 3,
+  // PM4 RELEASE_MEM can write a bottom-of-pipe timestamp to memory.
+  IREE_HAL_AMDGPU_VENDOR_PACKET_CAPABILITY_PM4_RELEASE_MEM_TIMESTAMP = 1u << 4,
+  // PM4 EVENT_WRITE can emit compute-pipeline events.
+  IREE_HAL_AMDGPU_VENDOR_PACKET_CAPABILITY_PM4_EVENT_WRITE = 1u << 5,
+  // PM4 SET_SH_REG can program persistent shader registers.
+  IREE_HAL_AMDGPU_VENDOR_PACKET_CAPABILITY_PM4_SET_SH_REG = 1u << 6,
+  // PM4 SET_UCONFIG_REG can program user configuration registers.
+  IREE_HAL_AMDGPU_VENDOR_PACKET_CAPABILITY_PM4_SET_UCONFIG_REG = 1u << 7,
+  // PM4 COPY_DATA can read register values into memory.
+  IREE_HAL_AMDGPU_VENDOR_PACKET_CAPABILITY_PM4_REGISTER_READBACK = 1u << 8,
+  // PM4 COPY_DATA can read performance-counter values into memory.
+  IREE_HAL_AMDGPU_VENDOR_PACKET_CAPABILITY_PM4_PERFCOUNTER_READBACK = 1u << 9,
+  // PM4 COPY_DATA can write immediate values into registers/perfcounters.
+  IREE_HAL_AMDGPU_VENDOR_PACKET_CAPABILITY_PM4_IMMEDIATE_WRITE = 1u << 10,
+  // PM4 WRITE_DATA can write immediate values into memory through TC_L2.
+  IREE_HAL_AMDGPU_VENDOR_PACKET_CAPABILITY_PM4_WRITE_DATA_MEMORY = 1u << 11,
+  // PM4 COPY_DATA can copy memory through TC_L2 into memory through TC_L2.
+  IREE_HAL_AMDGPU_VENDOR_PACKET_CAPABILITY_PM4_COPY_DATA_MEMORY = 1u << 12,
+};
+typedef uint32_t iree_hal_amdgpu_vendor_packet_capability_flags_t;
+
+// Returns true if the device can emit queue-private PM4 WRITE_DATA memory
+// writes.
+static inline bool
+iree_hal_amdgpu_vendor_packet_capabilities_support_pm4_memory_write_data(
+    iree_hal_amdgpu_vendor_packet_capability_flags_t capabilities) {
+  return iree_all_bits_set(
+      capabilities,
+      IREE_HAL_AMDGPU_VENDOR_PACKET_CAPABILITY_AQL_PM4_IB |
+          IREE_HAL_AMDGPU_VENDOR_PACKET_CAPABILITY_PM4_WRITE_DATA_MEMORY);
+}
+
+// Returns true if the device can emit queue-private PM4 COPY_DATA memory
+// copies.
+static inline bool
+iree_hal_amdgpu_vendor_packet_capabilities_support_pm4_memory_copy_data(
+    iree_hal_amdgpu_vendor_packet_capability_flags_t capabilities) {
+  return iree_all_bits_set(
+      capabilities,
+      IREE_HAL_AMDGPU_VENDOR_PACKET_CAPABILITY_AQL_PM4_IB |
+          IREE_HAL_AMDGPU_VENDOR_PACKET_CAPABILITY_PM4_COPY_DATA_MEMORY);
+}
+
+// Returns true if the device can emit queue-private PM4 timestamp ranges using
+// a COPY_DATA start timestamp and RELEASE_MEM end timestamp.
+static inline bool
+iree_hal_amdgpu_vendor_packet_capabilities_support_timestamp_range(
+    iree_hal_amdgpu_vendor_packet_capability_flags_t capabilities) {
+  return iree_all_bits_set(
+      capabilities,
+      IREE_HAL_AMDGPU_VENDOR_PACKET_CAPABILITY_AQL_PM4_IB |
+          IREE_HAL_AMDGPU_VENDOR_PACKET_CAPABILITY_PM4_COPY_TIMESTAMP |
+          IREE_HAL_AMDGPU_VENDOR_PACKET_CAPABILITY_PM4_RELEASE_MEM_TIMESTAMP);
+}
+
+// Returns true if the device can emit the gfx10+ packet families needed for
+// queue-local PMC start/read/stop programs.
+static inline bool
+iree_hal_amdgpu_vendor_packet_capabilities_support_gfx10_pmc_programs(
+    iree_hal_amdgpu_vendor_packet_capability_flags_t capabilities) {
+  return iree_all_bits_set(
+      capabilities,
+      IREE_HAL_AMDGPU_VENDOR_PACKET_CAPABILITY_AQL_PM4_IB |
+          IREE_HAL_AMDGPU_VENDOR_PACKET_CAPABILITY_PM4_EVENT_WRITE |
+          IREE_HAL_AMDGPU_VENDOR_PACKET_CAPABILITY_PM4_SET_SH_REG |
+          IREE_HAL_AMDGPU_VENDOR_PACKET_CAPABILITY_PM4_SET_UCONFIG_REG |
+          IREE_HAL_AMDGPU_VENDOR_PACKET_CAPABILITY_PM4_REGISTER_READBACK |
+          IREE_HAL_AMDGPU_VENDOR_PACKET_CAPABILITY_PM4_PERFCOUNTER_READBACK |
+          IREE_HAL_AMDGPU_VENDOR_PACKET_CAPABILITY_PM4_IMMEDIATE_WRITE);
+}
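+
+// Illustrative wait-strategy selection (a sketch only; the real selection
+// lives in physical-device capability probing and may weigh more inputs):
+//   iree_hal_amdgpu_wait_barrier_strategy_t strategy =
+//       IREE_HAL_AMDGPU_WAIT_BARRIER_STRATEGY_DEFER;
+//   if (iree_all_bits_set(
+//           capabilities,
+//           IREE_HAL_AMDGPU_VENDOR_PACKET_CAPABILITY_AQL_BARRIER_VALUE)) {
+//     strategy = IREE_HAL_AMDGPU_WAIT_BARRIER_STRATEGY_AQL_BARRIER_VALUE;
+//   } else if (iree_all_bits_set(
+//                  capabilities,
+//                  IREE_HAL_AMDGPU_VENDOR_PACKET_CAPABILITY_AQL_PM4_IB |
+//                  IREE_HAL_AMDGPU_VENDOR_PACKET_CAPABILITY_WAIT_REG_MEM64)) {
+//     strategy = IREE_HAL_AMDGPU_WAIT_BARRIER_STRATEGY_PM4_WAIT_REG_MEM64;
+//   }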
+
+#ifdef __cplusplus
+}  // extern "C"
+#endif  // __cplusplus
+
+#endif  // IREE_HAL_DRIVERS_AMDGPU_UTIL_PM4_CAPABILITIES_H_
diff --git a/runtime/src/iree/hal/drivers/amdgpu/util/pm4_emitter.h b/runtime/src/iree/hal/drivers/amdgpu/util/pm4_emitter.h
new file mode 100644
index 0000000..f57a0f3
--- /dev/null
+++ b/runtime/src/iree/hal/drivers/amdgpu/util/pm4_emitter.h
@@ -0,0 +1,579 @@
+// Copyright 2026 The IREE Authors
+//
+// Licensed under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+
+// PM4 packet emission helpers. These build indirect-buffer payloads and the
+// vendor AQL PM4-IB packet bodies used to jump to those payloads. They do not
+// reserve IB storage, commit AQL packet headers, or ring doorbells; queue code
+// owns those publication and lifetime rules.
+
+#ifndef IREE_HAL_DRIVERS_AMDGPU_UTIL_PM4_EMITTER_H_
+#define IREE_HAL_DRIVERS_AMDGPU_UTIL_PM4_EMITTER_H_
+
+#include <string.h>
+
+#include "iree/base/api.h"
+#include "iree/hal/drivers/amdgpu/abi/queue.h"
+#include "iree/hal/drivers/amdgpu/abi/signal.h"
+#include "iree/hal/drivers/amdgpu/util/aql_emitter.h"
+
+#ifdef __cplusplus
+extern "C" {
+#endif  // __cplusplus
+
+// Queue-private PM4 IB slot. The slot count always matches the AQL ring
+// capacity, so AQL packet id N uses pm4_ib_slots[N & aql_ring.mask].
+typedef struct IREE_AMDGPU_ALIGNAS(64) iree_hal_amdgpu_pm4_ib_slot_t {
+  // Encoded PM4 packet words consumed by a PM4-IB AQL packet.
+  uint32_t dwords[16];
+} iree_hal_amdgpu_pm4_ib_slot_t;
+IREE_AMDGPU_STATIC_ASSERT(sizeof(iree_hal_amdgpu_pm4_ib_slot_t) == 64,
+                          "PM4 IB slot must be exactly one cache line");
+
+enum {
+  IREE_HAL_AMDGPU_PM4_IB_SLOT_DWORD_CAPACITY = 16,
+  // PM4-IB AQL envelopes encode the indirect-buffer dword count in 20 bits.
+  IREE_HAL_AMDGPU_PM4_IB_MAX_DWORD_COUNT = 0xFFFFF,
+  IREE_HAL_AMDGPU_PM4_COPY_TIMESTAMP_DWORD_COUNT = 6,
+  IREE_HAL_AMDGPU_PM4_RELEASE_MEM_TIMESTAMP_DWORD_COUNT = 8,
+  IREE_HAL_AMDGPU_PM4_TIMESTAMP_RANGE_DWORD_COUNT =
+      IREE_HAL_AMDGPU_PM4_COPY_TIMESTAMP_DWORD_COUNT +
+      IREE_HAL_AMDGPU_PM4_RELEASE_MEM_TIMESTAMP_DWORD_COUNT,
+  IREE_HAL_AMDGPU_PM4_EVENT_WRITE_DWORD_COUNT = 2,
+  IREE_HAL_AMDGPU_PM4_SET_REGISTER_DWORD_COUNT = 3,
+  IREE_HAL_AMDGPU_PM4_COPY_DATA_DWORD_COUNT = 6,
+  IREE_HAL_AMDGPU_PM4_HDR_IT_OPCODE_WRITE_DATA = 0x37,
+  IREE_HAL_AMDGPU_PM4_HDR_IT_OPCODE_INDIRECT_BUFFER = 0x3F,
+  IREE_HAL_AMDGPU_PM4_HDR_IT_OPCODE_COPY_DATA = 0x40,
+  IREE_HAL_AMDGPU_PM4_HDR_IT_OPCODE_EVENT_WRITE = 0x46,
+  IREE_HAL_AMDGPU_PM4_HDR_IT_OPCODE_RELEASE_MEM = 0x49,
+  IREE_HAL_AMDGPU_PM4_HDR_IT_OPCODE_SET_SH_REG = 0x76,
+  IREE_HAL_AMDGPU_PM4_HDR_IT_OPCODE_SET_UCONFIG_REG = 0x79,
+  IREE_HAL_AMDGPU_PM4_HDR_IT_OPCODE_WAIT_REG_MEM64 = 0x93,
+  IREE_HAL_AMDGPU_PM4_REGISTER_OFFSET_MASK = 0x3FFFF,
+  IREE_HAL_AMDGPU_PM4_PERSISTENT_SPACE_START = 0x00002C00,
+  IREE_HAL_AMDGPU_PM4_PERSISTENT_SPACE_END = 0x00002FFF,
+  IREE_HAL_AMDGPU_PM4_UCONFIG_SPACE_START = 0x0000C000,
+  IREE_HAL_AMDGPU_PM4_UCONFIG_SPACE_END = 0x0000FFFF,
+  IREE_HAL_AMDGPU_PM4_COPY_DATA_SRC_SEL_TIMESTAMP = 9 << 0,
+  IREE_HAL_AMDGPU_PM4_COPY_DATA_SRC_SEL_TC_L2 = 2 << 0,
+  IREE_HAL_AMDGPU_PM4_COPY_DATA_SRC_SEL_MEM_MAPPED_REGISTER = 0 << 0,
+  IREE_HAL_AMDGPU_PM4_COPY_DATA_SRC_SEL_PERFCOUNTER = 4 << 0,
+  IREE_HAL_AMDGPU_PM4_COPY_DATA_SRC_SEL_IMMEDIATE_DATA = 5 << 0,
+  IREE_HAL_AMDGPU_PM4_COPY_DATA_DST_SEL_MEM_MAPPED_REGISTER = 0 << 8,
+  IREE_HAL_AMDGPU_PM4_COPY_DATA_DST_SEL_MEM = 5 << 8,
+  IREE_HAL_AMDGPU_PM4_COPY_DATA_DST_SEL_TC_L2 = 2 << 8,
+  IREE_HAL_AMDGPU_PM4_COPY_DATA_DST_SEL_PERFCOUNTER = 4 << 8,
+  IREE_HAL_AMDGPU_PM4_COPY_DATA_COUNT_SEL_64_BITS = 1 << 16,
+  IREE_HAL_AMDGPU_PM4_COPY_DATA_WR_CONFIRM_WAIT_CONFIRMATION = 1 << 20,
+  IREE_HAL_AMDGPU_PM4_EVENT_WRITE_EVENT_TYPE_CS_PARTIAL_FLUSH = 7 << 0,
+  IREE_HAL_AMDGPU_PM4_EVENT_WRITE_EVENT_INDEX_CS_PARTIAL_FLUSH = 4 << 8,
+  IREE_HAL_AMDGPU_PM4_RELEASE_MEM_EVENT_TYPE_BOTTOM_OF_PIPE_TS = 40 << 0,
+  IREE_HAL_AMDGPU_PM4_RELEASE_MEM_EVENT_INDEX_END_OF_PIPE = 5 << 8,
+  IREE_HAL_AMDGPU_PM4_RELEASE_MEM_INT_SEL_SEND_DATA_AFTER_WR_CONFIRM = 3 << 24,
+  IREE_HAL_AMDGPU_PM4_RELEASE_MEM_DATA_SEL_TIMESTAMP = 3u << 29,
+  IREE_HAL_AMDGPU_PM4_WRITE_DATA_DST_SEL_TC_L2 = 2 << 8,
+  IREE_HAL_AMDGPU_PM4_WRITE_DATA_WR_CONFIRM_WAIT_CONFIRMATION = 1 << 20,
+  IREE_HAL_AMDGPU_PM4_WAIT_REG_MEM_FUNC_LESS_THAN = 1,
+  IREE_HAL_AMDGPU_PM4_WAIT_REG_MEM_SPACE_MEMORY = 1,
+  IREE_HAL_AMDGPU_PM4_WAIT_REG_MEM_OPERATION_WAIT_REG_MEM = 0,
+};
+
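+// Declared outside the enum above because C requires enumeration constants to
+// be representable as a signed int and 0x80000000 is not.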
+static const uint32_t
+    IREE_HAL_AMDGPU_PM4_WAIT_REG_MEM_OPTIMIZE_ACE_OFFLOAD_MODE = 0x80000000u;
+
+typedef enum iree_hal_amdgpu_pm4_register_space_e {
+  IREE_HAL_AMDGPU_PM4_REGISTER_SPACE_MEM_MAPPED_REGISTER = 0,
+  IREE_HAL_AMDGPU_PM4_REGISTER_SPACE_PERFCOUNTER = 4,
+} iree_hal_amdgpu_pm4_register_space_t;
+
+typedef enum iree_hal_amdgpu_pm4_write_confirmation_e {
+  IREE_HAL_AMDGPU_PM4_WRITE_CONFIRMATION_NONE = 0,
+  IREE_HAL_AMDGPU_PM4_WRITE_CONFIRMATION_WAIT = 1,
+} iree_hal_amdgpu_pm4_write_confirmation_t;
+
+static inline uint32_t iree_hal_amdgpu_pm4_make_header(uint32_t opcode,
+                                                       uint32_t dword_count) {
+  return (3u << 30) | (opcode << 8) | ((dword_count - 2u) << 16);
+}
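+
+// For example, a WRITE_DATA packet occupying 5 total dwords encodes as
+// (3u << 30) | (0x37u << 8) | ((5u - 2u) << 16) == 0xC0033700u.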
+
+// Bounded builder for one queue-private PM4 IB slot.
+typedef struct iree_hal_amdgpu_pm4_ib_builder_t {
+  // Queue-owned PM4 IB slot being populated.
+  iree_hal_amdgpu_pm4_ib_slot_t* slot;
+  // Number of dwords already populated in |slot|.
+  uint32_t dword_count;
+} iree_hal_amdgpu_pm4_ib_builder_t;
+
+// Clears |slot| and initializes |out_builder| for bounded packet appends.
+static inline void iree_hal_amdgpu_pm4_ib_builder_initialize(
+    iree_hal_amdgpu_pm4_ib_slot_t* slot,
+    iree_hal_amdgpu_pm4_ib_builder_t* out_builder) {
+  memset(slot, 0, sizeof(*slot));
+  out_builder->slot = slot;
+  out_builder->dword_count = 0;
+}
+
+// Returns the number of dwords populated by |builder|.
+static inline uint32_t iree_hal_amdgpu_pm4_ib_builder_dword_count(
+    const iree_hal_amdgpu_pm4_ib_builder_t* builder) {
+  return builder->dword_count;
+}
+
+// Returns the remaining dword capacity in |builder|.
+static inline uint32_t iree_hal_amdgpu_pm4_ib_builder_remaining(
+    const iree_hal_amdgpu_pm4_ib_builder_t* builder) {
+  return IREE_HAL_AMDGPU_PM4_IB_SLOT_DWORD_CAPACITY - builder->dword_count;
+}
+
+// Appends |dword_count| uninitialized dwords and returns their start, or NULL
+// if the requested span does not fit.
+static inline uint32_t* iree_hal_amdgpu_pm4_ib_builder_append_dwords(
+    iree_hal_amdgpu_pm4_ib_builder_t* builder, uint32_t dword_count) {
+  if (dword_count > iree_hal_amdgpu_pm4_ib_builder_remaining(builder)) {
+    return NULL;
+  }
+  uint32_t* dwords = &builder->slot->dwords[builder->dword_count];
+  builder->dword_count += dword_count;
+  return dwords;
+}
+
+// Appends one PM4 packet header and reserves |dword_count| packet dwords.
+// Returns the packet start, or NULL if the packet is malformed or does not fit.
+static inline uint32_t* iree_hal_amdgpu_pm4_ib_builder_append_packet(
+    iree_hal_amdgpu_pm4_ib_builder_t* builder, uint32_t opcode,
+    uint32_t dword_count) {
+  if (dword_count < 2) return NULL;
+  uint32_t* packet =
+      iree_hal_amdgpu_pm4_ib_builder_append_dwords(builder, dword_count);
+  if (!packet) return NULL;
+  packet[0] = iree_hal_amdgpu_pm4_make_header(opcode, dword_count);
+  return packet;
+}
+
+// Address-splitting helpers for PM4 packet fields: the _lo forms clear the
+// low 2 or 3 address bits and the _hi forms take the upper 32 bits (16 for
+// indirect-buffer base addresses).
+static inline uint32_t iree_hal_amdgpu_pm4_addr_lo(uintptr_t address) {
+  return (uint32_t)(address & 0xFFFFFFFCu);
+}
+
+static inline uint32_t iree_hal_amdgpu_pm4_addr_lo_8(uintptr_t address) {
+  return (uint32_t)(address & 0xFFFFFFF8u);
+}
+
+static inline uint32_t iree_hal_amdgpu_pm4_addr_hi(uintptr_t address) {
+  return (uint32_t)(address >> 32);
+}
+
+static inline uint32_t iree_hal_amdgpu_pm4_ib_addr_hi(uintptr_t address) {
+  return (uint32_t)((address >> 32) & 0xFFFFu);
+}
+
+static inline bool iree_hal_amdgpu_pm4_register_space_is_valid(
+    iree_hal_amdgpu_pm4_register_space_t register_space) {
+  return register_space ==
+             IREE_HAL_AMDGPU_PM4_REGISTER_SPACE_MEM_MAPPED_REGISTER ||
+         register_space == IREE_HAL_AMDGPU_PM4_REGISTER_SPACE_PERFCOUNTER;
+}
+
+static inline bool iree_hal_amdgpu_pm4_write_confirmation_is_valid(
+    iree_hal_amdgpu_pm4_write_confirmation_t write_confirmation) {
+  return write_confirmation == IREE_HAL_AMDGPU_PM4_WRITE_CONFIRMATION_NONE ||
+         write_confirmation == IREE_HAL_AMDGPU_PM4_WRITE_CONFIRMATION_WAIT;
+}
+
+static inline uint32_t iree_hal_amdgpu_pm4_copy_data_source_register_space(
+    iree_hal_amdgpu_pm4_register_space_t register_space) {
+  return (uint32_t)register_space << 0;
+}
+
+static inline uint32_t iree_hal_amdgpu_pm4_copy_data_target_register_space(
+    iree_hal_amdgpu_pm4_register_space_t register_space) {
+  return (uint32_t)register_space << 8;
+}
+
+static inline uint32_t iree_hal_amdgpu_pm4_copy_data_write_confirmation(
+    iree_hal_amdgpu_pm4_write_confirmation_t write_confirmation) {
+  return write_confirmation == IREE_HAL_AMDGPU_PM4_WRITE_CONFIRMATION_WAIT
+             ? IREE_HAL_AMDGPU_PM4_COPY_DATA_WR_CONFIRM_WAIT_CONFIRMATION
+             : 0;
+}
+
+// Appends an EVENT_WRITE CS_PARTIAL_FLUSH packet. This is the queue-local
+// wait-idle building block around counter programming; stronger
+// cache-management packets should remain separate helpers.
+static inline bool
+iree_hal_amdgpu_pm4_ib_builder_emit_event_write_cs_partial_flush(
+    iree_hal_amdgpu_pm4_ib_builder_t* builder) {
+  uint32_t* dword = iree_hal_amdgpu_pm4_ib_builder_append_packet(
+      builder, IREE_HAL_AMDGPU_PM4_HDR_IT_OPCODE_EVENT_WRITE,
+      IREE_HAL_AMDGPU_PM4_EVENT_WRITE_DWORD_COUNT);
+  if (!dword) return false;
+  dword[1] = IREE_HAL_AMDGPU_PM4_EVENT_WRITE_EVENT_TYPE_CS_PARTIAL_FLUSH |
+             IREE_HAL_AMDGPU_PM4_EVENT_WRITE_EVENT_INDEX_CS_PARTIAL_FLUSH;
+  return true;
+}
+
+// Appends a SET_SH_REG packet for persistent shader registers.
+static inline bool iree_hal_amdgpu_pm4_ib_builder_emit_set_sh_reg(
+    iree_hal_amdgpu_pm4_ib_builder_t* builder, uint32_t register_address,
+    uint32_t value) {
+  if (register_address < IREE_HAL_AMDGPU_PM4_PERSISTENT_SPACE_START ||
+      register_address > IREE_HAL_AMDGPU_PM4_PERSISTENT_SPACE_END) {
+    return false;
+  }
+  uint32_t* dword = iree_hal_amdgpu_pm4_ib_builder_append_packet(
+      builder, IREE_HAL_AMDGPU_PM4_HDR_IT_OPCODE_SET_SH_REG,
+      IREE_HAL_AMDGPU_PM4_SET_REGISTER_DWORD_COUNT);
+  if (!dword) return false;
+  dword[1] = register_address - IREE_HAL_AMDGPU_PM4_PERSISTENT_SPACE_START;
+  dword[2] = value;
+  return true;
+}
+
+// Appends a SET_UCONFIG_REG packet for user configuration registers.
+static inline bool iree_hal_amdgpu_pm4_ib_builder_emit_set_uconfig_reg(
+    iree_hal_amdgpu_pm4_ib_builder_t* builder, uint32_t register_address,
+    uint32_t value) {
+  if (register_address < IREE_HAL_AMDGPU_PM4_UCONFIG_SPACE_START ||
+      register_address > IREE_HAL_AMDGPU_PM4_UCONFIG_SPACE_END) {
+    return false;
+  }
+  uint32_t* dword = iree_hal_amdgpu_pm4_ib_builder_append_packet(
+      builder, IREE_HAL_AMDGPU_PM4_HDR_IT_OPCODE_SET_UCONFIG_REG,
+      IREE_HAL_AMDGPU_PM4_SET_REGISTER_DWORD_COUNT);
+  if (!dword) return false;
+  dword[1] = register_address - IREE_HAL_AMDGPU_PM4_UCONFIG_SPACE_START;
+  dword[2] = value;
+  return true;
+}
+
+// Appends a COPY_DATA packet that writes an immediate 32-bit value into a
+// memory-mapped register or perfcounter register address.
+static inline bool
+iree_hal_amdgpu_pm4_ib_builder_emit_copy_immediate32_to_register(
+    iree_hal_amdgpu_pm4_ib_builder_t* builder,
+    iree_hal_amdgpu_pm4_register_space_t register_space,
+    uint32_t register_address, uint32_t value,
+    iree_hal_amdgpu_pm4_write_confirmation_t write_confirmation) {
+  if (!iree_hal_amdgpu_pm4_register_space_is_valid(register_space) ||
+      !iree_hal_amdgpu_pm4_write_confirmation_is_valid(write_confirmation) ||
+      register_address > IREE_HAL_AMDGPU_PM4_REGISTER_OFFSET_MASK) {
+    return false;
+  }
+  uint32_t* dword = iree_hal_amdgpu_pm4_ib_builder_append_packet(
+      builder, IREE_HAL_AMDGPU_PM4_HDR_IT_OPCODE_COPY_DATA,
+      IREE_HAL_AMDGPU_PM4_COPY_DATA_DWORD_COUNT);
+  if (!dword) return false;
+  dword[1] =
+      IREE_HAL_AMDGPU_PM4_COPY_DATA_SRC_SEL_IMMEDIATE_DATA |
+      iree_hal_amdgpu_pm4_copy_data_target_register_space(register_space) |
+      iree_hal_amdgpu_pm4_copy_data_write_confirmation(write_confirmation);
+  dword[2] = value;
+  dword[3] = 0;
+  dword[4] = register_address;
+  dword[5] = 0;
+  return true;
+}
+
+// Appends a COPY_DATA packet that copies a 32-bit register or perfcounter value
+// into memory.
+static inline bool
+iree_hal_amdgpu_pm4_ib_builder_emit_copy_register32_to_memory(
+    iree_hal_amdgpu_pm4_ib_builder_t* builder,
+    iree_hal_amdgpu_pm4_register_space_t register_space,
+    uint32_t register_address, void* target,
+    iree_hal_amdgpu_pm4_write_confirmation_t write_confirmation) {
+  if (!iree_hal_amdgpu_pm4_register_space_is_valid(register_space) ||
+      !iree_hal_amdgpu_pm4_write_confirmation_is_valid(write_confirmation) ||
+      register_address > IREE_HAL_AMDGPU_PM4_REGISTER_OFFSET_MASK ||
+      !iree_host_ptr_has_alignment(target, 4)) {
+    return false;
+  }
+  const uintptr_t address = (uintptr_t)target;
+  uint32_t* dword = iree_hal_amdgpu_pm4_ib_builder_append_packet(
+      builder, IREE_HAL_AMDGPU_PM4_HDR_IT_OPCODE_COPY_DATA,
+      IREE_HAL_AMDGPU_PM4_COPY_DATA_DWORD_COUNT);
+  if (!dword) return false;
+  dword[1] =
+      iree_hal_amdgpu_pm4_copy_data_source_register_space(register_space) |
+      IREE_HAL_AMDGPU_PM4_COPY_DATA_DST_SEL_TC_L2 |
+      iree_hal_amdgpu_pm4_copy_data_write_confirmation(write_confirmation);
+  dword[2] = register_address;
+  dword[3] = 0;
+  dword[4] = iree_hal_amdgpu_pm4_addr_lo(address);
+  dword[5] = iree_hal_amdgpu_pm4_addr_hi(address);
+  return true;
+}
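+
+// Illustrative queue-local counter readback (a sketch only; |builder| is an
+// initialized builder and the register offsets/values below are placeholders,
+// not real gfx10 perfcounter addresses):
+//   iree_hal_amdgpu_pm4_ib_builder_emit_event_write_cs_partial_flush(&builder);
+//   iree_hal_amdgpu_pm4_ib_builder_emit_set_uconfig_reg(
+//       &builder, IREE_HAL_AMDGPU_PM4_UCONFIG_SPACE_START + 0x100,
+//       select_value);
+//   iree_hal_amdgpu_pm4_ib_builder_emit_copy_register32_to_memory(
+//       &builder, IREE_HAL_AMDGPU_PM4_REGISTER_SPACE_PERFCOUNTER,
+//       counter_offset, readback_target,
+//       IREE_HAL_AMDGPU_PM4_WRITE_CONFIRMATION_WAIT);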
+
+// Appends a COPY_DATA timestamp write to |target|. This is the RADV-style
+// top-of-pipe/immediate timestamp form and is intended for profiling records,
+// not memory copies. The packet form was probed on gfx1100 AQL compute queues;
+// callers must select it only for architectures where the physical-device
+// capability table says this form is valid.
+static inline bool iree_hal_amdgpu_pm4_ib_builder_emit_copy_timestamp_to_memory(
+    iree_hal_amdgpu_pm4_ib_builder_t* builder, void* target) {
+  if (!iree_host_ptr_has_alignment(target, 8)) return false;
+  const uintptr_t address = (uintptr_t)target;
+  uint32_t* dword = iree_hal_amdgpu_pm4_ib_builder_append_packet(
+      builder, IREE_HAL_AMDGPU_PM4_HDR_IT_OPCODE_COPY_DATA,
+      IREE_HAL_AMDGPU_PM4_COPY_TIMESTAMP_DWORD_COUNT);
+  if (!dword) return false;
+  dword[1] = IREE_HAL_AMDGPU_PM4_COPY_DATA_SRC_SEL_TIMESTAMP |
+             IREE_HAL_AMDGPU_PM4_COPY_DATA_DST_SEL_MEM |
+             IREE_HAL_AMDGPU_PM4_COPY_DATA_COUNT_SEL_64_BITS |
+             IREE_HAL_AMDGPU_PM4_COPY_DATA_WR_CONFIRM_WAIT_CONFIRMATION;
+  dword[2] = 0;
+  dword[3] = 0;
+  dword[4] = (uint32_t)address;
+  dword[5] = iree_hal_amdgpu_pm4_addr_hi(address);
+  return true;
+}
+
+// Appends a RELEASE_MEM bottom-of-pipe timestamp write to |target|. This uses
+// only the common timestamp event/data fields and deliberately avoids cache or
+// PWS bits whose layout differs across gfx generations. The packet form was
+// probed on gfx1100 AQL compute queues; callers must select it only for
+// architectures where the physical-device capability table says this form is
+// valid.
+static inline bool
+iree_hal_amdgpu_pm4_ib_builder_emit_release_mem_timestamp_to_memory(
+    iree_hal_amdgpu_pm4_ib_builder_t* builder, void* target) {
+  if (!iree_host_ptr_has_alignment(target, 8)) return false;
+  const uintptr_t address = (uintptr_t)target;
+  uint32_t* dword = iree_hal_amdgpu_pm4_ib_builder_append_packet(
+      builder, IREE_HAL_AMDGPU_PM4_HDR_IT_OPCODE_RELEASE_MEM,
+      IREE_HAL_AMDGPU_PM4_RELEASE_MEM_TIMESTAMP_DWORD_COUNT);
+  if (!dword) return false;
+  dword[1] = IREE_HAL_AMDGPU_PM4_RELEASE_MEM_EVENT_TYPE_BOTTOM_OF_PIPE_TS |
+             IREE_HAL_AMDGPU_PM4_RELEASE_MEM_EVENT_INDEX_END_OF_PIPE;
+  dword[2] =
+      IREE_HAL_AMDGPU_PM4_RELEASE_MEM_INT_SEL_SEND_DATA_AFTER_WR_CONFIRM |
+      IREE_HAL_AMDGPU_PM4_RELEASE_MEM_DATA_SEL_TIMESTAMP;
+  dword[3] = (uint32_t)address;
+  dword[4] = iree_hal_amdgpu_pm4_addr_hi(address);
+  dword[5] = 0;
+  dword[6] = 0;
+  dword[7] = 0;
+  return true;
+}
+
+// Appends a timestamp range around subsequent queue work. The start timestamp
+// is an immediate COPY_DATA timestamp and the end timestamp is a bottom-of-pipe
+// RELEASE_MEM timestamp. The helper preflights all space and alignment so it
+// either appends the complete range marker or leaves |builder| unchanged.
+static inline bool
+iree_hal_amdgpu_pm4_ib_builder_emit_timestamp_range_to_memory(
+    iree_hal_amdgpu_pm4_ib_builder_t* builder, void* start_target,
+    void* end_target) {
+  if (!iree_host_ptr_has_alignment(start_target, 8) ||
+      !iree_host_ptr_has_alignment(end_target, 8)) {
+    return false;
+  }
+  if (iree_hal_amdgpu_pm4_ib_builder_remaining(builder) <
+      IREE_HAL_AMDGPU_PM4_TIMESTAMP_RANGE_DWORD_COUNT) {
+    return false;
+  }
+  iree_hal_amdgpu_pm4_ib_builder_emit_copy_timestamp_to_memory(builder,
+                                                               start_target);
+  iree_hal_amdgpu_pm4_ib_builder_emit_release_mem_timestamp_to_memory(
+      builder, end_target);
+  return true;
+}
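+
+// Illustrative capability-gated use (a sketch only; |capabilities| comes from
+// the pm4_capabilities.h device table and |query| is a placeholder record):
+//   if (iree_hal_amdgpu_vendor_packet_capabilities_support_timestamp_range(
+//           capabilities)) {
+//     iree_hal_amdgpu_pm4_ib_builder_emit_timestamp_range_to_memory(
+//         &builder, &query->start_tick, &query->end_tick);
+//   }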
+
+static inline uint32_t iree_hal_amdgpu_pm4_wait_reg_mem_dw1(
+    uint32_t function, uint32_t mem_space, uint32_t operation) {
+  return (function & 0x7u) | ((mem_space & 0x3u) << 4) |
+         ((operation & 0x3u) << 6);
+}
+
+// Emits a WAIT_REG_MEM64 PM4 program into |slot| that stalls the queue on a
+// 64-bit masked memory comparison against |epoch_signal|'s value. Returns the
+// populated dword count; the caller wraps it in an AQL PM4-IB envelope.
+static inline uint32_t iree_hal_amdgpu_pm4_emit_wait_reg_mem64(
+    iree_hal_amdgpu_pm4_ib_slot_t* slot, iree_hsa_signal_t epoch_signal,
+    iree_hsa_signal_value_t compare_value, iree_hsa_signal_value_t mask) {
+  memset(slot, 0, sizeof(*slot));
+  iree_amd_signal_t* signal_abi = (iree_amd_signal_t*)epoch_signal.handle;
+  volatile iree_hsa_signal_value_t* value_address = &signal_abi->value;
+  const uintptr_t address = (uintptr_t)value_address;
+  uint32_t* dword = slot->dwords;
+  dword[0] = iree_hal_amdgpu_pm4_make_header(
+      IREE_HAL_AMDGPU_PM4_HDR_IT_OPCODE_WAIT_REG_MEM64, 9);
+  dword[1] = iree_hal_amdgpu_pm4_wait_reg_mem_dw1(
+      IREE_HAL_AMDGPU_PM4_WAIT_REG_MEM_FUNC_LESS_THAN,
+      IREE_HAL_AMDGPU_PM4_WAIT_REG_MEM_SPACE_MEMORY,
+      IREE_HAL_AMDGPU_PM4_WAIT_REG_MEM_OPERATION_WAIT_REG_MEM);
+  dword[2] = iree_hal_amdgpu_pm4_addr_lo_8(address);
+  dword[3] = iree_hal_amdgpu_pm4_addr_hi(address);
+  dword[4] = (uint32_t)compare_value;
+  dword[5] = (uint32_t)((uint64_t)compare_value >> 32);
+  dword[6] = (uint32_t)mask;
+  dword[7] = (uint32_t)((uint64_t)mask >> 32);
+  // Final dword: poll interval (4) plus the optimize-ACE-offload mode flag.
+  dword[8] = 4 | IREE_HAL_AMDGPU_PM4_WAIT_REG_MEM_OPTIMIZE_ACE_OFFLOAD_MODE;
+  return 9;
+}
+
+// Emits a WRITE_DATA packet into |slot| that writes a 32-bit |value| to
+// |target| through TC_L2 with write confirmation. Returns the dword count.
+static inline uint32_t iree_hal_amdgpu_pm4_emit_write_data32(
+    iree_hal_amdgpu_pm4_ib_slot_t* slot, void* target, uint32_t value) {
+  memset(slot, 0, sizeof(*slot));
+  const uintptr_t address = (uintptr_t)target;
+  uint32_t* dword = slot->dwords;
+  dword[0] = iree_hal_amdgpu_pm4_make_header(
+      IREE_HAL_AMDGPU_PM4_HDR_IT_OPCODE_WRITE_DATA, 5);
+  dword[1] = IREE_HAL_AMDGPU_PM4_WRITE_DATA_DST_SEL_TC_L2 |
+             IREE_HAL_AMDGPU_PM4_WRITE_DATA_WR_CONFIRM_WAIT_CONFIRMATION;
+  dword[2] = iree_hal_amdgpu_pm4_addr_lo(address);
+  dword[3] = iree_hal_amdgpu_pm4_addr_hi(address);
+  dword[4] = value;
+  return 5;
+}
+
+// As iree_hal_amdgpu_pm4_emit_write_data32 but writing a 64-bit |value|.
+static inline uint32_t iree_hal_amdgpu_pm4_emit_write_data64(
+    iree_hal_amdgpu_pm4_ib_slot_t* slot, void* target, uint64_t value) {
+  memset(slot, 0, sizeof(*slot));
+  const uintptr_t address = (uintptr_t)target;
+  uint32_t* dword = slot->dwords;
+  dword[0] = iree_hal_amdgpu_pm4_make_header(
+      IREE_HAL_AMDGPU_PM4_HDR_IT_OPCODE_WRITE_DATA, 6);
+  dword[1] = IREE_HAL_AMDGPU_PM4_WRITE_DATA_DST_SEL_TC_L2 |
+             IREE_HAL_AMDGPU_PM4_WRITE_DATA_WR_CONFIRM_WAIT_CONFIRMATION;
+  dword[2] = iree_hal_amdgpu_pm4_addr_lo(address);
+  dword[3] = iree_hal_amdgpu_pm4_addr_hi(address);
+  dword[4] = (uint32_t)value;
+  dword[5] = (uint32_t)(value >> 32);
+  return 6;
+}
+
+// Emits a COPY_DATA packet into |slot| that copies 32 bits from |source| to
+// |target| through TC_L2 with write confirmation. Returns the dword count.
+static inline uint32_t iree_hal_amdgpu_pm4_emit_copy_data32(
+    iree_hal_amdgpu_pm4_ib_slot_t* slot, const void* source, void* target) {
+  memset(slot, 0, sizeof(*slot));
+  const uintptr_t source_address = (uintptr_t)source;
+  const uintptr_t target_address = (uintptr_t)target;
+  uint32_t* dword = slot->dwords;
+  dword[0] = iree_hal_amdgpu_pm4_make_header(
+      IREE_HAL_AMDGPU_PM4_HDR_IT_OPCODE_COPY_DATA, 6);
+  dword[1] = IREE_HAL_AMDGPU_PM4_COPY_DATA_SRC_SEL_TC_L2 |
+             IREE_HAL_AMDGPU_PM4_COPY_DATA_DST_SEL_TC_L2 |
+             IREE_HAL_AMDGPU_PM4_COPY_DATA_WR_CONFIRM_WAIT_CONFIRMATION;
+  dword[2] = iree_hal_amdgpu_pm4_addr_lo(source_address);
+  dword[3] = iree_hal_amdgpu_pm4_addr_hi(source_address);
+  dword[4] = iree_hal_amdgpu_pm4_addr_lo(target_address);
+  dword[5] = iree_hal_amdgpu_pm4_addr_hi(target_address);
+  return 6;
+}
+
+// As iree_hal_amdgpu_pm4_emit_copy_data32 but copying 64 bits between 8-byte
+// aligned addresses.
+static inline uint32_t iree_hal_amdgpu_pm4_emit_copy_data64(
+    iree_hal_amdgpu_pm4_ib_slot_t* slot, const void* source, void* target) {
+  memset(slot, 0, sizeof(*slot));
+  const uintptr_t source_address = (uintptr_t)source;
+  const uintptr_t target_address = (uintptr_t)target;
+  uint32_t* dword = slot->dwords;
+  dword[0] = iree_hal_amdgpu_pm4_make_header(
+      IREE_HAL_AMDGPU_PM4_HDR_IT_OPCODE_COPY_DATA, 6);
+  dword[1] = IREE_HAL_AMDGPU_PM4_COPY_DATA_SRC_SEL_TC_L2 |
+             IREE_HAL_AMDGPU_PM4_COPY_DATA_DST_SEL_TC_L2 |
+             IREE_HAL_AMDGPU_PM4_COPY_DATA_COUNT_SEL_64_BITS |
+             IREE_HAL_AMDGPU_PM4_COPY_DATA_WR_CONFIRM_WAIT_CONFIRMATION;
+  dword[2] = iree_hal_amdgpu_pm4_addr_lo_8(source_address);
+  dword[3] = iree_hal_amdgpu_pm4_addr_hi(source_address);
+  dword[4] = iree_hal_amdgpu_pm4_addr_lo_8(target_address);
+  dword[5] = iree_hal_amdgpu_pm4_addr_hi(target_address);
+  return 6;
+}
+
+// Emits an AQL PM4-IB envelope referencing |ib_dwords|. The referenced dword
+// storage must remain immutable and live until the AQL packet retires.
+static inline uint16_t iree_hal_amdgpu_aql_emit_pm4_ib_dwords(
+    iree_hsa_amd_aql_pm4_ib_packet_t* packet, const uint32_t* ib_dwords,
+    uint32_t ib_dword_count,
+    iree_hal_amdgpu_aql_packet_control_t packet_control,
+    iree_hsa_signal_t completion_signal, uint16_t* out_setup) {
+  IREE_ASSERT(ib_dword_count <= IREE_HAL_AMDGPU_PM4_IB_MAX_DWORD_COUNT);
+  const uintptr_t ib_address = (uintptr_t)ib_dwords;
+  packet->ib_jump_cmd[0] = iree_hal_amdgpu_pm4_make_header(
+      IREE_HAL_AMDGPU_PM4_HDR_IT_OPCODE_INDIRECT_BUFFER, 4);
+  packet->ib_jump_cmd[1] = iree_hal_amdgpu_pm4_addr_lo(ib_address);
+  packet->ib_jump_cmd[2] = iree_hal_amdgpu_pm4_ib_addr_hi(ib_address);
+  packet->ib_jump_cmd[3] = ib_dword_count | (1u << 23);
+  // 10 dwords (reserved fields + completion signal) follow this field in the
+  // 64-byte AQL packet.
+  packet->dw_cnt_remain = 0xA;
+  for (iree_host_size_t i = 0; i < IREE_ARRAYSIZE(packet->reserved); ++i) {
+    packet->reserved[i] = 0;
+  }
+  packet->completion_signal = completion_signal;
+  *out_setup = IREE_HSA_AMD_AQL_FORMAT_PM4_IB;
+  return iree_hal_amdgpu_aql_make_header(IREE_HSA_PACKET_TYPE_VENDOR_SPECIFIC,
+                                         packet_control);
+}
+
+// As iree_hal_amdgpu_aql_emit_pm4_ib_dwords but sourcing from a PM4 IB slot.
+static inline uint16_t iree_hal_amdgpu_aql_emit_pm4_ib(
+    iree_hsa_amd_aql_pm4_ib_packet_t* packet,
+    const iree_hal_amdgpu_pm4_ib_slot_t* ib_slot, uint32_t ib_dword_count,
+    iree_hal_amdgpu_aql_packet_control_t packet_control,
+    iree_hsa_signal_t completion_signal, uint16_t* out_setup) {
+  return iree_hal_amdgpu_aql_emit_pm4_ib_dwords(packet, ib_slot->dwords,
+                                                ib_dword_count, packet_control,
+                                                completion_signal, out_setup);
+}
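+
+// Illustrative end-to-end emission (a sketch only; queue code owns IB slot
+// assignment, AQL packet reservation, header publication, and doorbells):
+//   iree_hal_amdgpu_pm4_ib_builder_t builder;
+//   iree_hal_amdgpu_pm4_ib_builder_initialize(ib_slot, &builder);
+//   iree_hal_amdgpu_pm4_ib_builder_emit_copy_timestamp_to_memory(&builder,
+//                                                                target);
+//   uint16_t setup = 0;
+//   uint16_t header = iree_hal_amdgpu_aql_emit_pm4_ib(
+//       packet, ib_slot, iree_hal_amdgpu_pm4_ib_builder_dword_count(&builder),
+//       packet_control, iree_hsa_signal_null(), &setup);
+//   // The queue then stores |setup|, publishes |header|, and rings the
+//   // doorbell.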
+
+// Emits an AQL PM4-IB packet that writes a top-of-pipe timestamp to
+// |start_tick|. The caller owns packet publication.
+static inline uint16_t iree_hal_amdgpu_aql_emit_timestamp_start(
+    iree_hsa_amd_aql_pm4_ib_packet_t* packet,
+    iree_hal_amdgpu_pm4_ib_slot_t* ib_slot,
+    iree_hal_amdgpu_aql_packet_control_t packet_control, uint64_t* start_tick,
+    uint16_t* out_setup) {
+  iree_hal_amdgpu_pm4_ib_builder_t builder;
+  iree_hal_amdgpu_pm4_ib_builder_initialize(ib_slot, &builder);
+  const bool did_emit =
+      iree_hal_amdgpu_pm4_ib_builder_emit_copy_timestamp_to_memory(&builder,
+                                                                   start_tick);
+  IREE_ASSERT(did_emit, "PM4 start timestamp must fit PM4 IB slot");
+  (void)did_emit;
+  return iree_hal_amdgpu_aql_emit_pm4_ib(
+      packet, ib_slot, iree_hal_amdgpu_pm4_ib_builder_dword_count(&builder),
+      packet_control, iree_hsa_signal_null(), out_setup);
+}
+
+// Emits an AQL PM4-IB packet that writes a bottom-of-pipe timestamp to
+// |end_tick|. The caller owns packet publication.
+static inline uint16_t iree_hal_amdgpu_aql_emit_timestamp_end(
+    iree_hsa_amd_aql_pm4_ib_packet_t* packet,
+    iree_hal_amdgpu_pm4_ib_slot_t* ib_slot,
+    iree_hal_amdgpu_aql_packet_control_t packet_control,
+    iree_hsa_signal_t completion_signal, uint64_t* end_tick,
+    uint16_t* out_setup) {
+  iree_hal_amdgpu_pm4_ib_builder_t builder;
+  iree_hal_amdgpu_pm4_ib_builder_initialize(ib_slot, &builder);
+  const bool did_emit =
+      iree_hal_amdgpu_pm4_ib_builder_emit_release_mem_timestamp_to_memory(
+          &builder, end_tick);
+  IREE_ASSERT(did_emit, "PM4 end timestamp must fit PM4 IB slot");
+  (void)did_emit;
+  return iree_hal_amdgpu_aql_emit_pm4_ib(
+      packet, ib_slot, iree_hal_amdgpu_pm4_ib_builder_dword_count(&builder),
+      packet_control, completion_signal, out_setup);
+}
+
+// Emits one AQL PM4-IB packet that writes both start and end timestamp fields.
+// The caller owns packet publication.
+static inline uint16_t iree_hal_amdgpu_aql_emit_timestamp_range(
+    iree_hsa_amd_aql_pm4_ib_packet_t* packet,
+    iree_hal_amdgpu_pm4_ib_slot_t* ib_slot,
+    iree_hal_amdgpu_aql_packet_control_t packet_control,
+    iree_hsa_signal_t completion_signal, uint64_t* start_tick,
+    uint64_t* end_tick, uint16_t* out_setup) {
+  iree_hal_amdgpu_pm4_ib_builder_t builder;
+  iree_hal_amdgpu_pm4_ib_builder_initialize(ib_slot, &builder);
+  const bool did_emit =
+      iree_hal_amdgpu_pm4_ib_builder_emit_timestamp_range_to_memory(
+          &builder, start_tick, end_tick);
+  IREE_ASSERT(did_emit, "PM4 timestamp range must fit PM4 IB slot");
+  (void)did_emit;
+  return iree_hal_amdgpu_aql_emit_pm4_ib(
+      packet, ib_slot, iree_hal_amdgpu_pm4_ib_builder_dword_count(&builder),
+      packet_control, completion_signal, out_setup);
+}
+
+#ifdef __cplusplus
+}  // extern "C"
+#endif  // __cplusplus
+
+#endif  // IREE_HAL_DRIVERS_AMDGPU_UTIL_PM4_EMITTER_H_
diff --git a/runtime/src/iree/hal/drivers/amdgpu/util/pm4_emitter_test.cc b/runtime/src/iree/hal/drivers/amdgpu/util/pm4_emitter_test.cc
new file mode 100644
index 0000000..d1728c8
--- /dev/null
+++ b/runtime/src/iree/hal/drivers/amdgpu/util/pm4_emitter_test.cc
@@ -0,0 +1,541 @@
+// Copyright 2026 The IREE Authors
+//
+// Licensed under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+
+#include "iree/hal/drivers/amdgpu/util/pm4_emitter.h"
+
+#include <cstring>
+
+#include "iree/hal/drivers/amdgpu/util/pm4_capabilities.h"
+#include "iree/testing/gtest.h"
+
+namespace {
+
+TEST(PM4CapabilitiesTest, MemoryWriteDataRequiresPM4IBAndPacketFamily) {
+  EXPECT_TRUE(
+      iree_hal_amdgpu_vendor_packet_capabilities_support_pm4_memory_write_data(
+          IREE_HAL_AMDGPU_VENDOR_PACKET_CAPABILITY_AQL_PM4_IB |
+          IREE_HAL_AMDGPU_VENDOR_PACKET_CAPABILITY_PM4_WRITE_DATA_MEMORY));
+  EXPECT_FALSE(
+      iree_hal_amdgpu_vendor_packet_capabilities_support_pm4_memory_write_data(
+          IREE_HAL_AMDGPU_VENDOR_PACKET_CAPABILITY_AQL_PM4_IB));
+  EXPECT_FALSE(
+      iree_hal_amdgpu_vendor_packet_capabilities_support_pm4_memory_write_data(
+          IREE_HAL_AMDGPU_VENDOR_PACKET_CAPABILITY_PM4_WRITE_DATA_MEMORY));
+}
+
+TEST(PM4CapabilitiesTest, MemoryCopyDataRequiresPM4IBAndPacketFamily) {
+  EXPECT_TRUE(
+      iree_hal_amdgpu_vendor_packet_capabilities_support_pm4_memory_copy_data(
+          IREE_HAL_AMDGPU_VENDOR_PACKET_CAPABILITY_AQL_PM4_IB |
+          IREE_HAL_AMDGPU_VENDOR_PACKET_CAPABILITY_PM4_COPY_DATA_MEMORY));
+  EXPECT_FALSE(
+      iree_hal_amdgpu_vendor_packet_capabilities_support_pm4_memory_copy_data(
+          IREE_HAL_AMDGPU_VENDOR_PACKET_CAPABILITY_AQL_PM4_IB));
+  EXPECT_FALSE(
+      iree_hal_amdgpu_vendor_packet_capabilities_support_pm4_memory_copy_data(
+          IREE_HAL_AMDGPU_VENDOR_PACKET_CAPABILITY_PM4_COPY_DATA_MEMORY));
+}
+
+TEST(PM4CapabilitiesTest, TimestampRangeRequiresAllPacketFamilies) {
+  EXPECT_TRUE(
+      iree_hal_amdgpu_vendor_packet_capabilities_support_timestamp_range(
+          IREE_HAL_AMDGPU_VENDOR_PACKET_CAPABILITY_AQL_PM4_IB |
+          IREE_HAL_AMDGPU_VENDOR_PACKET_CAPABILITY_PM4_COPY_TIMESTAMP |
+          IREE_HAL_AMDGPU_VENDOR_PACKET_CAPABILITY_PM4_RELEASE_MEM_TIMESTAMP));
+  EXPECT_FALSE(
+      iree_hal_amdgpu_vendor_packet_capabilities_support_timestamp_range(
+          IREE_HAL_AMDGPU_VENDOR_PACKET_CAPABILITY_PM4_COPY_TIMESTAMP |
+          IREE_HAL_AMDGPU_VENDOR_PACKET_CAPABILITY_PM4_RELEASE_MEM_TIMESTAMP));
+  EXPECT_FALSE(
+      iree_hal_amdgpu_vendor_packet_capabilities_support_timestamp_range(
+          IREE_HAL_AMDGPU_VENDOR_PACKET_CAPABILITY_AQL_PM4_IB |
+          IREE_HAL_AMDGPU_VENDOR_PACKET_CAPABILITY_PM4_COPY_TIMESTAMP));
+}
+
+TEST(PM4CapabilitiesTest, Gfx10PmcProgramsRequireAllPacketFamilies) {
+  iree_hal_amdgpu_vendor_packet_capability_flags_t capabilities =
+      IREE_HAL_AMDGPU_VENDOR_PACKET_CAPABILITY_AQL_PM4_IB |
+      IREE_HAL_AMDGPU_VENDOR_PACKET_CAPABILITY_PM4_EVENT_WRITE |
+      IREE_HAL_AMDGPU_VENDOR_PACKET_CAPABILITY_PM4_SET_SH_REG |
+      IREE_HAL_AMDGPU_VENDOR_PACKET_CAPABILITY_PM4_SET_UCONFIG_REG |
+      IREE_HAL_AMDGPU_VENDOR_PACKET_CAPABILITY_PM4_REGISTER_READBACK |
+      IREE_HAL_AMDGPU_VENDOR_PACKET_CAPABILITY_PM4_PERFCOUNTER_READBACK |
+      IREE_HAL_AMDGPU_VENDOR_PACKET_CAPABILITY_PM4_IMMEDIATE_WRITE;
+  EXPECT_TRUE(
+      iree_hal_amdgpu_vendor_packet_capabilities_support_gfx10_pmc_programs(
+          capabilities));
+  capabilities &=
+      ~IREE_HAL_AMDGPU_VENDOR_PACKET_CAPABILITY_PM4_PERFCOUNTER_READBACK;
+  EXPECT_FALSE(
+      iree_hal_amdgpu_vendor_packet_capabilities_support_gfx10_pmc_programs(
+          capabilities));
+}
+
+TEST(PM4EmitterTest, BuilderInitializesSlot) {
+  iree_hal_amdgpu_pm4_ib_slot_t slot;
+  std::memset(&slot, 0xCC, sizeof(slot));
+
+  iree_hal_amdgpu_pm4_ib_builder_t builder;
+  iree_hal_amdgpu_pm4_ib_builder_initialize(&slot, &builder);
+
+  EXPECT_EQ(builder.slot, &slot);
+  EXPECT_EQ(iree_hal_amdgpu_pm4_ib_builder_dword_count(&builder), 0u);
+  EXPECT_EQ(iree_hal_amdgpu_pm4_ib_builder_remaining(&builder),
+            IREE_HAL_AMDGPU_PM4_IB_SLOT_DWORD_CAPACITY);
+  for (uint32_t i = 0; i < IREE_HAL_AMDGPU_PM4_IB_SLOT_DWORD_CAPACITY; ++i) {
+    EXPECT_EQ(slot.dwords[i], 0u);
+  }
+}
+
+TEST(PM4EmitterTest, BuilderAppendsPackets) {
+  iree_hal_amdgpu_pm4_ib_slot_t slot;
+  iree_hal_amdgpu_pm4_ib_builder_t builder;
+  iree_hal_amdgpu_pm4_ib_builder_initialize(&slot, &builder);
+
+  uint32_t* write_data_packet = iree_hal_amdgpu_pm4_ib_builder_append_packet(
+      &builder, IREE_HAL_AMDGPU_PM4_HDR_IT_OPCODE_WRITE_DATA,
+      /*dword_count=*/5);
+  ASSERT_NE(write_data_packet, nullptr);
+  EXPECT_EQ(write_data_packet, &slot.dwords[0]);
+  EXPECT_EQ(write_data_packet[0],
+            iree_hal_amdgpu_pm4_make_header(
+                IREE_HAL_AMDGPU_PM4_HDR_IT_OPCODE_WRITE_DATA,
+                /*dword_count=*/5));
+  EXPECT_EQ(iree_hal_amdgpu_pm4_ib_builder_dword_count(&builder), 5u);
+
+  uint32_t* copy_data_packet = iree_hal_amdgpu_pm4_ib_builder_append_packet(
+      &builder, IREE_HAL_AMDGPU_PM4_HDR_IT_OPCODE_COPY_DATA,
+      /*dword_count=*/6);
+  ASSERT_NE(copy_data_packet, nullptr);
+  EXPECT_EQ(copy_data_packet, &slot.dwords[5]);
+  EXPECT_EQ(copy_data_packet[0],
+            iree_hal_amdgpu_pm4_make_header(
+                IREE_HAL_AMDGPU_PM4_HDR_IT_OPCODE_COPY_DATA,
+                /*dword_count=*/6));
+  EXPECT_EQ(iree_hal_amdgpu_pm4_ib_builder_dword_count(&builder), 11u);
+  EXPECT_EQ(iree_hal_amdgpu_pm4_ib_builder_remaining(&builder), 5u);
+}
+
+TEST(PM4EmitterTest, BuilderAppendsRegisterProgrammingPackets) {
+  iree_hal_amdgpu_pm4_ib_slot_t slot;
+  iree_hal_amdgpu_pm4_ib_builder_t builder;
+  iree_hal_amdgpu_pm4_ib_builder_initialize(&slot, &builder);
+
+  EXPECT_TRUE(iree_hal_amdgpu_pm4_ib_builder_emit_event_write_cs_partial_flush(
+      &builder));
+  EXPECT_TRUE(iree_hal_amdgpu_pm4_ib_builder_emit_set_sh_reg(
+      &builder, IREE_HAL_AMDGPU_PM4_PERSISTENT_SPACE_START + 0x34,
+      0x11111111u));
+  EXPECT_TRUE(iree_hal_amdgpu_pm4_ib_builder_emit_set_uconfig_reg(
+      &builder, IREE_HAL_AMDGPU_PM4_UCONFIG_SPACE_START + 0x123, 0x22222222u));
+  EXPECT_EQ(iree_hal_amdgpu_pm4_ib_builder_dword_count(&builder), 8u);
+
+  const uint32_t* dwords = slot.dwords;
+  EXPECT_EQ(dwords[0], iree_hal_amdgpu_pm4_make_header(
+                           IREE_HAL_AMDGPU_PM4_HDR_IT_OPCODE_EVENT_WRITE,
+                           IREE_HAL_AMDGPU_PM4_EVENT_WRITE_DWORD_COUNT));
+  EXPECT_EQ(dwords[1],
+            IREE_HAL_AMDGPU_PM4_EVENT_WRITE_EVENT_TYPE_CS_PARTIAL_FLUSH |
+                IREE_HAL_AMDGPU_PM4_EVENT_WRITE_EVENT_INDEX_CS_PARTIAL_FLUSH);
+  EXPECT_EQ(dwords[2], iree_hal_amdgpu_pm4_make_header(
+                           IREE_HAL_AMDGPU_PM4_HDR_IT_OPCODE_SET_SH_REG,
+                           IREE_HAL_AMDGPU_PM4_SET_REGISTER_DWORD_COUNT));
+  EXPECT_EQ(dwords[3], 0x34u);
+  EXPECT_EQ(dwords[4], 0x11111111u);
+  EXPECT_EQ(dwords[5], iree_hal_amdgpu_pm4_make_header(
+                           IREE_HAL_AMDGPU_PM4_HDR_IT_OPCODE_SET_UCONFIG_REG,
+                           IREE_HAL_AMDGPU_PM4_SET_REGISTER_DWORD_COUNT));
+  EXPECT_EQ(dwords[6], 0x123u);
+  EXPECT_EQ(dwords[7], 0x22222222u);
+}
+
+TEST(PM4EmitterTest, BuilderAppendsRegisterCopyPackets) {
+  iree_hal_amdgpu_pm4_ib_slot_t slot;
+  iree_hal_amdgpu_pm4_ib_builder_t builder;
+  iree_hal_amdgpu_pm4_ib_builder_initialize(&slot, &builder);
+
+  void* target =
+      reinterpret_cast<void*>(static_cast<uintptr_t>(0x123456789ABCDEF0ull));
+  EXPECT_TRUE(iree_hal_amdgpu_pm4_ib_builder_emit_copy_immediate32_to_register(
+      &builder, IREE_HAL_AMDGPU_PM4_REGISTER_SPACE_PERFCOUNTER,
+      /*register_address=*/0x2345, /*value=*/0x33333333u,
+      IREE_HAL_AMDGPU_PM4_WRITE_CONFIRMATION_NONE));
+  EXPECT_TRUE(iree_hal_amdgpu_pm4_ib_builder_emit_copy_register32_to_memory(
+      &builder, IREE_HAL_AMDGPU_PM4_REGISTER_SPACE_PERFCOUNTER,
+      /*register_address=*/0x3456, target,
+      IREE_HAL_AMDGPU_PM4_WRITE_CONFIRMATION_WAIT));
+  EXPECT_EQ(iree_hal_amdgpu_pm4_ib_builder_dword_count(&builder), 12u);
+
+  const uint32_t* dwords = slot.dwords;
+  EXPECT_EQ(dwords[0], iree_hal_amdgpu_pm4_make_header(
+                           IREE_HAL_AMDGPU_PM4_HDR_IT_OPCODE_COPY_DATA,
+                           IREE_HAL_AMDGPU_PM4_COPY_DATA_DWORD_COUNT));
+  EXPECT_EQ(dwords[1], IREE_HAL_AMDGPU_PM4_COPY_DATA_SRC_SEL_IMMEDIATE_DATA |
+                           IREE_HAL_AMDGPU_PM4_COPY_DATA_DST_SEL_PERFCOUNTER);
+  EXPECT_EQ(dwords[2], 0x33333333u);
+  EXPECT_EQ(dwords[3], 0u);
+  EXPECT_EQ(dwords[4], 0x2345u);
+  EXPECT_EQ(dwords[5], 0u);
+  EXPECT_EQ(dwords[6], iree_hal_amdgpu_pm4_make_header(
+                           IREE_HAL_AMDGPU_PM4_HDR_IT_OPCODE_COPY_DATA,
+                           IREE_HAL_AMDGPU_PM4_COPY_DATA_DWORD_COUNT));
+  EXPECT_EQ(dwords[7],
+            IREE_HAL_AMDGPU_PM4_COPY_DATA_SRC_SEL_PERFCOUNTER |
+                IREE_HAL_AMDGPU_PM4_COPY_DATA_DST_SEL_TC_L2 |
+                IREE_HAL_AMDGPU_PM4_COPY_DATA_WR_CONFIRM_WAIT_CONFIRMATION);
+  EXPECT_EQ(dwords[8], 0x3456u);
+  EXPECT_EQ(dwords[9], 0u);
+  EXPECT_EQ(dwords[10], 0x9ABCDEF0u);
+  EXPECT_EQ(dwords[11], 0x12345678u);
+}
+
+TEST(PM4EmitterTest, BuilderRejectsInvalidRegisterPackets) {
+  iree_hal_amdgpu_pm4_ib_slot_t slot;
+  iree_hal_amdgpu_pm4_ib_builder_t builder;
+  iree_hal_amdgpu_pm4_ib_builder_initialize(&slot, &builder);
+
+  EXPECT_FALSE(iree_hal_amdgpu_pm4_ib_builder_emit_set_sh_reg(
+      &builder, IREE_HAL_AMDGPU_PM4_PERSISTENT_SPACE_START - 1, 0x11111111u));
+  EXPECT_FALSE(iree_hal_amdgpu_pm4_ib_builder_emit_set_uconfig_reg(
+      &builder, IREE_HAL_AMDGPU_PM4_UCONFIG_SPACE_START - 1, 0x22222222u));
+  EXPECT_FALSE(iree_hal_amdgpu_pm4_ib_builder_emit_copy_immediate32_to_register(
+      &builder, (iree_hal_amdgpu_pm4_register_space_t)7,
+      /*register_address=*/0x2345, /*value=*/0x33333333u,
+      IREE_HAL_AMDGPU_PM4_WRITE_CONFIRMATION_NONE));
+  EXPECT_FALSE(iree_hal_amdgpu_pm4_ib_builder_emit_copy_immediate32_to_register(
+      &builder, IREE_HAL_AMDGPU_PM4_REGISTER_SPACE_PERFCOUNTER,
+      /*register_address=*/0x2345, /*value=*/0x33333333u,
+      (iree_hal_amdgpu_pm4_write_confirmation_t)7));
+  void* unaligned_target =
+      reinterpret_cast<void*>(static_cast<uintptr_t>(0x123456789ABCDEFull));
+  EXPECT_FALSE(iree_hal_amdgpu_pm4_ib_builder_emit_copy_register32_to_memory(
+      &builder, IREE_HAL_AMDGPU_PM4_REGISTER_SPACE_MEM_MAPPED_REGISTER,
+      /*register_address=*/0x3456, unaligned_target,
+      IREE_HAL_AMDGPU_PM4_WRITE_CONFIRMATION_NONE));
+  EXPECT_EQ(iree_hal_amdgpu_pm4_ib_builder_dword_count(&builder), 0u);
+}
+
+TEST(PM4EmitterTest, EmitsWriteDataMemoryPackets) {
+  void* target =
+      reinterpret_cast<void*>(static_cast<uintptr_t>(0x123456789ABCDEF0ull));
+  iree_hal_amdgpu_pm4_ib_slot_t slot;
+
+  uint32_t dword_count = iree_hal_amdgpu_pm4_emit_write_data32(
+      &slot, target, /*value=*/0xAABBCCDDu);
+  EXPECT_EQ(dword_count, 5u);
+  EXPECT_EQ(slot.dwords[0], iree_hal_amdgpu_pm4_make_header(
+                                IREE_HAL_AMDGPU_PM4_HDR_IT_OPCODE_WRITE_DATA,
+                                /*dword_count=*/5));
+  EXPECT_EQ(slot.dwords[1],
+            IREE_HAL_AMDGPU_PM4_WRITE_DATA_DST_SEL_TC_L2 |
+                IREE_HAL_AMDGPU_PM4_WRITE_DATA_WR_CONFIRM_WAIT_CONFIRMATION);
+  EXPECT_EQ(slot.dwords[2], 0x9ABCDEF0u);
+  EXPECT_EQ(slot.dwords[3], 0x12345678u);
+  EXPECT_EQ(slot.dwords[4], 0xAABBCCDDu);
+
+  dword_count = iree_hal_amdgpu_pm4_emit_write_data64(
+      &slot, target, /*value=*/0x1122334455667788ull);
+  EXPECT_EQ(dword_count, 6u);
+  EXPECT_EQ(slot.dwords[0], iree_hal_amdgpu_pm4_make_header(
+                                IREE_HAL_AMDGPU_PM4_HDR_IT_OPCODE_WRITE_DATA,
+                                /*dword_count=*/6));
+  EXPECT_EQ(slot.dwords[1],
+            IREE_HAL_AMDGPU_PM4_WRITE_DATA_DST_SEL_TC_L2 |
+                IREE_HAL_AMDGPU_PM4_WRITE_DATA_WR_CONFIRM_WAIT_CONFIRMATION);
+  EXPECT_EQ(slot.dwords[2], 0x9ABCDEF0u);
+  EXPECT_EQ(slot.dwords[3], 0x12345678u);
+  EXPECT_EQ(slot.dwords[4], 0x55667788u);
+  EXPECT_EQ(slot.dwords[5], 0x11223344u);
+}
+
+TEST(PM4EmitterTest, EmitsCopyDataMemoryPackets) {
+  const void* source = reinterpret_cast<const void*>(
+      static_cast<uintptr_t>(0x123456789ABCDEF0ull));
+  void* target =
+      reinterpret_cast<void*>(static_cast<uintptr_t>(0x0FEDCBA987654320ull));
+  iree_hal_amdgpu_pm4_ib_slot_t slot;
+
+  uint32_t dword_count =
+      iree_hal_amdgpu_pm4_emit_copy_data32(&slot, source, target);
+  EXPECT_EQ(dword_count, 6u);
+  EXPECT_EQ(slot.dwords[0], iree_hal_amdgpu_pm4_make_header(
+                                IREE_HAL_AMDGPU_PM4_HDR_IT_OPCODE_COPY_DATA,
+                                /*dword_count=*/6));
+  EXPECT_EQ(slot.dwords[1],
+            IREE_HAL_AMDGPU_PM4_COPY_DATA_SRC_SEL_TC_L2 |
+                IREE_HAL_AMDGPU_PM4_COPY_DATA_DST_SEL_TC_L2 |
+                IREE_HAL_AMDGPU_PM4_COPY_DATA_WR_CONFIRM_WAIT_CONFIRMATION);
+  EXPECT_EQ(slot.dwords[2], 0x9ABCDEF0u);
+  EXPECT_EQ(slot.dwords[3], 0x12345678u);
+  EXPECT_EQ(slot.dwords[4], 0x87654320u);
+  EXPECT_EQ(slot.dwords[5], 0x0FEDCBA9u);
+
+  dword_count = iree_hal_amdgpu_pm4_emit_copy_data64(&slot, source, target);
+  EXPECT_EQ(dword_count, 6u);
+  EXPECT_EQ(slot.dwords[0], iree_hal_amdgpu_pm4_make_header(
+                                IREE_HAL_AMDGPU_PM4_HDR_IT_OPCODE_COPY_DATA,
+                                /*dword_count=*/6));
+  EXPECT_EQ(slot.dwords[1],
+            IREE_HAL_AMDGPU_PM4_COPY_DATA_SRC_SEL_TC_L2 |
+                IREE_HAL_AMDGPU_PM4_COPY_DATA_DST_SEL_TC_L2 |
+                IREE_HAL_AMDGPU_PM4_COPY_DATA_COUNT_SEL_64_BITS |
+                IREE_HAL_AMDGPU_PM4_COPY_DATA_WR_CONFIRM_WAIT_CONFIRMATION);
+  EXPECT_EQ(slot.dwords[2], 0x9ABCDEF0u);
+  EXPECT_EQ(slot.dwords[3], 0x12345678u);
+  EXPECT_EQ(slot.dwords[4], 0x87654320u);
+  EXPECT_EQ(slot.dwords[5], 0x0FEDCBA9u);
+}
+
+TEST(PM4EmitterTest, EmitsWaitRegMem64Packet) {
+  iree_amd_signal_t signal_abi = {};
+  iree_hsa_signal_t epoch_signal = {};
+  epoch_signal.handle = reinterpret_cast<uint64_t>(&signal_abi);
+  iree_hal_amdgpu_pm4_ib_slot_t slot;
+
+  const uint32_t dword_count = iree_hal_amdgpu_pm4_emit_wait_reg_mem64(
+      &slot, epoch_signal, /*compare_value=*/0x1122334455667788ll,
+      /*mask=*/0x7FFFFFFFFFFFFFFFll);
+  EXPECT_EQ(dword_count, 9u);
+  EXPECT_EQ(slot.dwords[0],
+            iree_hal_amdgpu_pm4_make_header(
+                IREE_HAL_AMDGPU_PM4_HDR_IT_OPCODE_WAIT_REG_MEM64,
+                /*dword_count=*/9));
+  EXPECT_EQ(slot.dwords[1],
+            iree_hal_amdgpu_pm4_wait_reg_mem_dw1(
+                IREE_HAL_AMDGPU_PM4_WAIT_REG_MEM_FUNC_LESS_THAN,
+                IREE_HAL_AMDGPU_PM4_WAIT_REG_MEM_SPACE_MEMORY,
+                IREE_HAL_AMDGPU_PM4_WAIT_REG_MEM_OPERATION_WAIT_REG_MEM));
+
+  const uintptr_t value_address =
+      reinterpret_cast<uintptr_t>(&signal_abi.value);
+  EXPECT_EQ(slot.dwords[2], iree_hal_amdgpu_pm4_addr_lo_8(value_address));
+  EXPECT_EQ(slot.dwords[3], iree_hal_amdgpu_pm4_addr_hi(value_address));
+  EXPECT_EQ(slot.dwords[4], 0x55667788u);
+  EXPECT_EQ(slot.dwords[5], 0x11223344u);
+  EXPECT_EQ(slot.dwords[6], 0xFFFFFFFFu);
+  EXPECT_EQ(slot.dwords[7], 0x7FFFFFFFu);
+  EXPECT_EQ(slot.dwords[8],
+            4u | IREE_HAL_AMDGPU_PM4_WAIT_REG_MEM_OPTIMIZE_ACE_OFFLOAD_MODE);
+}
+
+TEST(PM4EmitterTest, BuilderAppendsTimestampPackets) {
+  iree_hal_amdgpu_pm4_ib_slot_t slot;
+  iree_hal_amdgpu_pm4_ib_builder_t builder;
+  iree_hal_amdgpu_pm4_ib_builder_initialize(&slot, &builder);
+
+  void* copy_timestamp_target =
+      reinterpret_cast<void*>(static_cast<uintptr_t>(0x123456789ABCDEF0ull));
+  void* release_timestamp_target =
+      reinterpret_cast<void*>(static_cast<uintptr_t>(0x0FEDCBA987654320ull));
+
+  EXPECT_TRUE(iree_hal_amdgpu_pm4_ib_builder_emit_copy_timestamp_to_memory(
+      &builder, copy_timestamp_target));
+  EXPECT_TRUE(
+      iree_hal_amdgpu_pm4_ib_builder_emit_release_mem_timestamp_to_memory(
+          &builder, release_timestamp_target));
+  EXPECT_EQ(iree_hal_amdgpu_pm4_ib_builder_dword_count(&builder), 14u);
+  EXPECT_EQ(iree_hal_amdgpu_pm4_ib_builder_remaining(&builder), 2u);
+
+  const uint32_t* dwords = slot.dwords;
+  EXPECT_EQ(dwords[0], iree_hal_amdgpu_pm4_make_header(
+                           IREE_HAL_AMDGPU_PM4_HDR_IT_OPCODE_COPY_DATA,
+                           /*dword_count=*/6));
+  EXPECT_EQ(dwords[1],
+            IREE_HAL_AMDGPU_PM4_COPY_DATA_SRC_SEL_TIMESTAMP |
+                IREE_HAL_AMDGPU_PM4_COPY_DATA_DST_SEL_MEM |
+                IREE_HAL_AMDGPU_PM4_COPY_DATA_COUNT_SEL_64_BITS |
+                IREE_HAL_AMDGPU_PM4_COPY_DATA_WR_CONFIRM_WAIT_CONFIRMATION);
+  EXPECT_EQ(dwords[2], 0u);
+  EXPECT_EQ(dwords[3], 0u);
+  EXPECT_EQ(dwords[4], 0x9ABCDEF0u);
+  EXPECT_EQ(dwords[5], 0x12345678u);
+  EXPECT_EQ(dwords[6], iree_hal_amdgpu_pm4_make_header(
+                           IREE_HAL_AMDGPU_PM4_HDR_IT_OPCODE_RELEASE_MEM,
+                           /*dword_count=*/8));
+  EXPECT_EQ(dwords[7],
+            IREE_HAL_AMDGPU_PM4_RELEASE_MEM_EVENT_TYPE_BOTTOM_OF_PIPE_TS |
+                IREE_HAL_AMDGPU_PM4_RELEASE_MEM_EVENT_INDEX_END_OF_PIPE);
+  EXPECT_EQ(dwords[8],
+            IREE_HAL_AMDGPU_PM4_RELEASE_MEM_INT_SEL_SEND_DATA_AFTER_WR_CONFIRM |
+                IREE_HAL_AMDGPU_PM4_RELEASE_MEM_DATA_SEL_TIMESTAMP);
+  EXPECT_EQ(dwords[9], 0x87654320u);
+  EXPECT_EQ(dwords[10], 0x0FEDCBA9u);
+  EXPECT_EQ(dwords[11], 0u);
+  EXPECT_EQ(dwords[12], 0u);
+  EXPECT_EQ(dwords[13], 0u);
+}
+
+TEST(PM4EmitterTest, BuilderAppendsTimestampRangeAtomically) {
+  iree_hal_amdgpu_pm4_ib_slot_t slot;
+  iree_hal_amdgpu_pm4_ib_builder_t builder;
+  iree_hal_amdgpu_pm4_ib_builder_initialize(&slot, &builder);
+
+  void* start_target =
+      reinterpret_cast<void*>(static_cast<uintptr_t>(0x123456789ABCDEF0ull));
+  void* end_target =
+      reinterpret_cast<void*>(static_cast<uintptr_t>(0x0FEDCBA987654320ull));
+
+  EXPECT_TRUE(iree_hal_amdgpu_pm4_ib_builder_emit_timestamp_range_to_memory(
+      &builder, start_target, end_target));
+  EXPECT_EQ(iree_hal_amdgpu_pm4_ib_builder_dword_count(&builder),
+            IREE_HAL_AMDGPU_PM4_TIMESTAMP_RANGE_DWORD_COUNT);
+
+  const uint32_t* dwords = slot.dwords;
+  EXPECT_EQ(dwords[0], iree_hal_amdgpu_pm4_make_header(
+                           IREE_HAL_AMDGPU_PM4_HDR_IT_OPCODE_COPY_DATA,
+                           IREE_HAL_AMDGPU_PM4_COPY_TIMESTAMP_DWORD_COUNT));
+  EXPECT_EQ(dwords[6],
+            iree_hal_amdgpu_pm4_make_header(
+                IREE_HAL_AMDGPU_PM4_HDR_IT_OPCODE_RELEASE_MEM,
+                IREE_HAL_AMDGPU_PM4_RELEASE_MEM_TIMESTAMP_DWORD_COUNT));
+}
+
+TEST(PM4EmitterTest, EmitsTimestampAqlPackets) {
+  iree_hal_amdgpu_aql_packet_control_t packet_control =
+      iree_hal_amdgpu_aql_packet_control_barrier_system();
+  const iree_hsa_signal_t completion_signal = {0x12345678ull};
+  uint64_t start_tick = 0;
+  uint64_t end_tick = 0;
+
+  iree_hsa_amd_aql_pm4_ib_packet_t packet = {};
+  iree_hal_amdgpu_pm4_ib_slot_t slot;
+  uint16_t setup = 0;
+  uint16_t header = iree_hal_amdgpu_aql_emit_timestamp_start(
+      &packet, &slot, packet_control, &start_tick, &setup);
+  EXPECT_EQ(header, iree_hal_amdgpu_aql_make_header(
+                        IREE_HSA_PACKET_TYPE_VENDOR_SPECIFIC, packet_control));
+  EXPECT_EQ(setup, IREE_HSA_AMD_AQL_FORMAT_PM4_IB);
+  EXPECT_EQ(packet.completion_signal.handle, iree_hsa_signal_null().handle);
+  EXPECT_EQ(slot.dwords[0],
+            iree_hal_amdgpu_pm4_make_header(
+                IREE_HAL_AMDGPU_PM4_HDR_IT_OPCODE_COPY_DATA,
+                IREE_HAL_AMDGPU_PM4_COPY_TIMESTAMP_DWORD_COUNT));
+
+  packet = {};
+  header = iree_hal_amdgpu_aql_emit_timestamp_end(
+      &packet, &slot, packet_control, completion_signal, &end_tick, &setup);
+  EXPECT_EQ(header, iree_hal_amdgpu_aql_make_header(
+                        IREE_HSA_PACKET_TYPE_VENDOR_SPECIFIC, packet_control));
+  EXPECT_EQ(setup, IREE_HSA_AMD_AQL_FORMAT_PM4_IB);
+  EXPECT_EQ(packet.completion_signal.handle, completion_signal.handle);
+  EXPECT_EQ(slot.dwords[0],
+            iree_hal_amdgpu_pm4_make_header(
+                IREE_HAL_AMDGPU_PM4_HDR_IT_OPCODE_RELEASE_MEM,
+                IREE_HAL_AMDGPU_PM4_RELEASE_MEM_TIMESTAMP_DWORD_COUNT));
+
+  packet = {};
+  header = iree_hal_amdgpu_aql_emit_timestamp_range(
+      &packet, &slot, packet_control, completion_signal, &start_tick, &end_tick,
+      &setup);
+  EXPECT_EQ(header, iree_hal_amdgpu_aql_make_header(
+                        IREE_HSA_PACKET_TYPE_VENDOR_SPECIFIC, packet_control));
+  EXPECT_EQ(setup, IREE_HSA_AMD_AQL_FORMAT_PM4_IB);
+  EXPECT_EQ(packet.completion_signal.handle, completion_signal.handle);
+  EXPECT_EQ(slot.dwords[0],
+            iree_hal_amdgpu_pm4_make_header(
+                IREE_HAL_AMDGPU_PM4_HDR_IT_OPCODE_COPY_DATA,
+                IREE_HAL_AMDGPU_PM4_COPY_TIMESTAMP_DWORD_COUNT));
+  EXPECT_EQ(slot.dwords[6],
+            iree_hal_amdgpu_pm4_make_header(
+                IREE_HAL_AMDGPU_PM4_HDR_IT_OPCODE_RELEASE_MEM,
+                IREE_HAL_AMDGPU_PM4_RELEASE_MEM_TIMESTAMP_DWORD_COUNT));
+}
+
+TEST(PM4EmitterTest, BuilderRejectsTimestampRangeWithoutPartialAppend) {
+  iree_hal_amdgpu_pm4_ib_slot_t slot;
+  iree_hal_amdgpu_pm4_ib_builder_t builder;
+  iree_hal_amdgpu_pm4_ib_builder_initialize(&slot, &builder);
+
+  uint32_t* prefix = iree_hal_amdgpu_pm4_ib_builder_append_dwords(&builder, 3);
+  ASSERT_NE(prefix, nullptr);
+  void* start_target =
+      reinterpret_cast<void*>(static_cast<uintptr_t>(0x123456789ABCDEF0ull));
+  void* end_target =
+      reinterpret_cast<void*>(static_cast<uintptr_t>(0x0FEDCBA987654320ull));
+  EXPECT_FALSE(iree_hal_amdgpu_pm4_ib_builder_emit_timestamp_range_to_memory(
+      &builder, start_target, end_target));
+  EXPECT_EQ(iree_hal_amdgpu_pm4_ib_builder_dword_count(&builder), 3u);
+
+  iree_hal_amdgpu_pm4_ib_builder_initialize(&slot, &builder);
+  void* unaligned_start =
+      reinterpret_cast<void*>(static_cast<uintptr_t>(0x123456789ABCDEFull));
+  EXPECT_FALSE(iree_hal_amdgpu_pm4_ib_builder_emit_timestamp_range_to_memory(
+      &builder, unaligned_start, end_target));
+  EXPECT_EQ(iree_hal_amdgpu_pm4_ib_builder_dword_count(&builder), 0u);
+}
+
+TEST(PM4EmitterTest, BuilderRejectsMalformedAndOverflowPackets) {
+  iree_hal_amdgpu_pm4_ib_slot_t slot;
+  iree_hal_amdgpu_pm4_ib_builder_t builder;
+  iree_hal_amdgpu_pm4_ib_builder_initialize(&slot, &builder);
+
+  EXPECT_EQ(iree_hal_amdgpu_pm4_ib_builder_append_packet(
+                &builder, IREE_HAL_AMDGPU_PM4_HDR_IT_OPCODE_WRITE_DATA,
+                /*dword_count=*/1),
+            nullptr);
+  EXPECT_EQ(iree_hal_amdgpu_pm4_ib_builder_dword_count(&builder), 0u);
+
+  uint32_t* full_slot = iree_hal_amdgpu_pm4_ib_builder_append_dwords(
+      &builder, IREE_HAL_AMDGPU_PM4_IB_SLOT_DWORD_CAPACITY);
+  ASSERT_NE(full_slot, nullptr);
+  EXPECT_EQ(iree_hal_amdgpu_pm4_ib_builder_dword_count(&builder),
+            IREE_HAL_AMDGPU_PM4_IB_SLOT_DWORD_CAPACITY);
+  EXPECT_EQ(iree_hal_amdgpu_pm4_ib_builder_append_dwords(&builder, 1), nullptr);
+  EXPECT_EQ(iree_hal_amdgpu_pm4_ib_builder_dword_count(&builder),
+            IREE_HAL_AMDGPU_PM4_IB_SLOT_DWORD_CAPACITY);
+}
+
+TEST(PM4EmitterTest, BuilderRejectsTimestampAlignmentAndOverflow) {
+  iree_hal_amdgpu_pm4_ib_slot_t slot;
+  iree_hal_amdgpu_pm4_ib_builder_t builder;
+  iree_hal_amdgpu_pm4_ib_builder_initialize(&slot, &builder);
+
+  void* unaligned_target =
+      reinterpret_cast<void*>(static_cast<uintptr_t>(0x123456789ABCDEFull));
+  EXPECT_FALSE(iree_hal_amdgpu_pm4_ib_builder_emit_copy_timestamp_to_memory(
+      &builder, unaligned_target));
+  EXPECT_FALSE(
+      iree_hal_amdgpu_pm4_ib_builder_emit_release_mem_timestamp_to_memory(
+          &builder, unaligned_target));
+  EXPECT_EQ(iree_hal_amdgpu_pm4_ib_builder_dword_count(&builder), 0u);
+
+  uint32_t* prefix = iree_hal_amdgpu_pm4_ib_builder_append_dwords(&builder, 10);
+  ASSERT_NE(prefix, nullptr);
+  void* aligned_target =
+      reinterpret_cast<void*>(static_cast<uintptr_t>(0x123456789ABCDEF0ull));
+  EXPECT_FALSE(
+      iree_hal_amdgpu_pm4_ib_builder_emit_release_mem_timestamp_to_memory(
+          &builder, aligned_target));
+  EXPECT_EQ(iree_hal_amdgpu_pm4_ib_builder_dword_count(&builder), 10u);
+}
+
+TEST(PM4EmitterTest, EmitsArbitraryPM4IBDwordEnvelope) {
+  uint32_t dwords[32] = {0};
+  iree_hsa_amd_aql_pm4_ib_packet_t packet = {};
+
+  iree_hal_amdgpu_aql_packet_control_t packet_control =
+      iree_hal_amdgpu_aql_packet_control_barrier_system();
+  uint16_t setup = 0;
+  uint16_t header = iree_hal_amdgpu_aql_emit_pm4_ib_dwords(
+      &packet, dwords, IREE_ARRAYSIZE(dwords), packet_control,
+      iree_hsa_signal_null(), &setup);
+
+  EXPECT_EQ(header, iree_hal_amdgpu_aql_make_header(
+                        IREE_HSA_PACKET_TYPE_VENDOR_SPECIFIC, packet_control));
+  EXPECT_EQ(setup, IREE_HSA_AMD_AQL_FORMAT_PM4_IB);
+  EXPECT_EQ(packet.ib_jump_cmd[0],
+            iree_hal_amdgpu_pm4_make_header(
+                IREE_HAL_AMDGPU_PM4_HDR_IT_OPCODE_INDIRECT_BUFFER,
+                /*dword_count=*/4));
+  uintptr_t dword_address = reinterpret_cast<uintptr_t>(dwords);
+  EXPECT_EQ(packet.ib_jump_cmd[1], iree_hal_amdgpu_pm4_addr_lo(dword_address));
+  EXPECT_EQ(packet.ib_jump_cmd[2],
+            iree_hal_amdgpu_pm4_ib_addr_hi(dword_address));
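+  // ib_jump_cmd[3] packs the IB length in dwords together with a control bit
+  // (bit 23) that the emitter sets on every envelope.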
+  EXPECT_EQ(packet.ib_jump_cmd[3], IREE_ARRAYSIZE(dwords) | (1u << 23));
+  EXPECT_EQ(packet.dw_cnt_remain, 0xAu);
+}
+
+}  // namespace
diff --git a/runtime/src/iree/hal/drivers/amdgpu/util/pm4_program.c b/runtime/src/iree/hal/drivers/amdgpu/util/pm4_program.c
new file mode 100644
index 0000000..27bf941
--- /dev/null
+++ b/runtime/src/iree/hal/drivers/amdgpu/util/pm4_program.c
@@ -0,0 +1,89 @@
+// Copyright 2026 The IREE Authors
+//
+// Licensed under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+
+#include "iree/hal/drivers/amdgpu/util/pm4_program.h"
+
+#include <stddef.h>
+#include <string.h>
+
+#include "iree/hal/drivers/amdgpu/util/pm4_emitter.h"
+
+iree_status_t iree_hal_amdgpu_pm4_program_initialize(
+    const iree_hal_amdgpu_libhsa_t* libhsa, hsa_agent_t device_agent,
+    hsa_amd_memory_pool_t memory_pool, const uint32_t* source_dwords,
+    uint32_t dword_count, iree_hal_amdgpu_pm4_program_t* out_program) {
+  IREE_ASSERT_ARGUMENT(libhsa);
+  IREE_ASSERT_ARGUMENT(source_dwords);
+  IREE_ASSERT_ARGUMENT(out_program);
+  memset(out_program, 0, sizeof(*out_program));
+  if (IREE_UNLIKELY(!memory_pool.handle)) {
+    return iree_make_status(IREE_STATUS_INVALID_ARGUMENT,
+                            "PM4 program memory pool is required");
+  }
+  if (IREE_UNLIKELY(dword_count == 0)) {
+    return iree_make_status(IREE_STATUS_INVALID_ARGUMENT,
+                            "PM4 program must contain at least one dword");
+  }
+  if (IREE_UNLIKELY(dword_count > IREE_HAL_AMDGPU_PM4_IB_MAX_DWORD_COUNT)) {
+    return iree_make_status(
+        IREE_STATUS_OUT_OF_RANGE,
+        "PM4 program dword count %u exceeds PM4-IB maximum %u", dword_count,
+        IREE_HAL_AMDGPU_PM4_IB_MAX_DWORD_COUNT);
+  }
+
+  IREE_TRACE_ZONE_BEGIN(z0);
+  IREE_TRACE_ZONE_APPEND_VALUE_I64(z0, dword_count);
+
+  iree_host_size_t byte_length = 0;
+  IREE_RETURN_AND_END_ZONE_IF_ERROR(
+      z0, IREE_STRUCT_LAYOUT(0, &byte_length,
+                             IREE_STRUCT_FIELD(dword_count, uint32_t, NULL)));
+  IREE_TRACE_ZONE_APPEND_VALUE_I64(z0, byte_length);
+
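+  // Allocate device-visible executable storage, grant the target agent
+  // access, and copy the immutable program contents into place.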
+  IREE_AMDGPU_DEVICE_PTR uint32_t* dwords = NULL;
+  iree_status_t status = iree_hsa_amd_memory_pool_allocate(
+      IREE_LIBHSA(libhsa), memory_pool, byte_length,
+      HSA_AMD_MEMORY_POOL_EXECUTABLE_FLAG, (void**)&dwords);
+
+  if (iree_status_is_ok(status)) {
+    status =
+        iree_hsa_amd_agents_allow_access(IREE_LIBHSA(libhsa), /*num_agents=*/1,
+                                         &device_agent, /*flags=*/NULL, dwords);
+  }
+
+  if (iree_status_is_ok(status)) {
+    status = iree_hsa_memory_copy(IREE_LIBHSA(libhsa), dwords, source_dwords,
+                                  byte_length);
+  }
+
+  if (iree_status_is_ok(status)) {
+    out_program->libhsa = libhsa;
+    out_program->memory_pool = memory_pool;
+    out_program->dwords = dwords;
+    out_program->dword_count = dword_count;
+    out_program->byte_length = byte_length;
+  } else if (dwords) {
+    status = iree_status_join(
+        status, iree_hsa_amd_memory_pool_free(IREE_LIBHSA(libhsa), dwords));
+  }
+
+  IREE_TRACE_ZONE_END(z0);
+  return status;
+}
+
+iree_status_t iree_hal_amdgpu_pm4_program_release(
+    iree_hal_amdgpu_pm4_program_t* program) {
+  if (!program || !program->dwords) return iree_ok_status();
+  IREE_TRACE_ZONE_BEGIN(z0);
+  IREE_TRACE_ZONE_APPEND_VALUE_I64(z0, program->byte_length);
+  iree_status_t status = iree_hsa_amd_memory_pool_free(
+      IREE_LIBHSA(program->libhsa), program->dwords);
+  if (iree_status_is_ok(status)) {
+    memset(program, 0, sizeof(*program));
+  }
+  IREE_TRACE_ZONE_END(z0);
+  return status;
+}
diff --git a/runtime/src/iree/hal/drivers/amdgpu/util/pm4_program.h b/runtime/src/iree/hal/drivers/amdgpu/util/pm4_program.h
new file mode 100644
index 0000000..ced42db
--- /dev/null
+++ b/runtime/src/iree/hal/drivers/amdgpu/util/pm4_program.h
@@ -0,0 +1,51 @@
+// Copyright 2026 The IREE Authors
+//
+// Licensed under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+
+#ifndef IREE_HAL_DRIVERS_AMDGPU_UTIL_PM4_PROGRAM_H_
+#define IREE_HAL_DRIVERS_AMDGPU_UTIL_PM4_PROGRAM_H_
+
+#include "iree/hal/drivers/amdgpu/util/libhsa.h"
+
+#ifdef __cplusplus
+extern "C" {
+#endif  // __cplusplus
+
+// Persistent immutable PM4 program storage.
+//
+// Queue-private PM4 IB slots are intentionally one AQL slot wide. Larger or
+// longer-lived PM4 programs, such as profiling start/read/stop streams, use
+// this object to keep immutable PM4 dwords in executable memory with an
+// explicit owner-managed lifetime.
+typedef struct iree_hal_amdgpu_pm4_program_t {
+  // HSA API table used to free |dwords|. Not retained.
+  const iree_hal_amdgpu_libhsa_t* libhsa;
+  // HSA memory pool that owns |dwords|.
+  hsa_amd_memory_pool_t memory_pool;
+  // Device-visible immutable PM4 dwords referenced by PM4-IB AQL packets.
+  IREE_AMDGPU_DEVICE_PTR uint32_t* dwords;
+  // Number of valid PM4 dwords in |dwords|.
+  uint32_t dword_count;
+  // Allocated byte length of |dwords|.
+  iree_host_size_t byte_length;
+} iree_hal_amdgpu_pm4_program_t;
+
+// Copies |source_dwords| into executable memory and grants |device_agent|
+// access so the command processor can execute the program via PM4-IB AQL
+// envelopes. The copied program is immutable after initialization.
+iree_status_t iree_hal_amdgpu_pm4_program_initialize(
+    const iree_hal_amdgpu_libhsa_t* libhsa, hsa_agent_t device_agent,
+    hsa_amd_memory_pool_t memory_pool, const uint32_t* source_dwords,
+    uint32_t dword_count, iree_hal_amdgpu_pm4_program_t* out_program);
+
+// Releases the executable storage backing |program| and clears it on success.
+iree_status_t iree_hal_amdgpu_pm4_program_release(
+    iree_hal_amdgpu_pm4_program_t* program);
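+
+// A minimal lifecycle sketch (|gpu_agent| and |memory_pool| stand in for
+// values obtained from topology and memory-pool queries):
+//   iree_hal_amdgpu_pm4_program_t program;
+//   IREE_RETURN_IF_ERROR(iree_hal_amdgpu_pm4_program_initialize(
+//       libhsa, gpu_agent, memory_pool, dwords, dword_count, &program));
+//   // ... reference program.dwords from PM4-IB AQL envelopes ...
+//   IREE_RETURN_IF_ERROR(iree_hal_amdgpu_pm4_program_release(&program));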
+
+#ifdef __cplusplus
+}  // extern "C"
+#endif  // __cplusplus
+
+#endif  // IREE_HAL_DRIVERS_AMDGPU_UTIL_PM4_PROGRAM_H_
diff --git a/runtime/src/iree/hal/drivers/amdgpu/util/pm4_program_test.cc b/runtime/src/iree/hal/drivers/amdgpu/util/pm4_program_test.cc
new file mode 100644
index 0000000..2576bc8
--- /dev/null
+++ b/runtime/src/iree/hal/drivers/amdgpu/util/pm4_program_test.cc
@@ -0,0 +1,132 @@
+// Copyright 2026 The IREE Authors
+//
+// Licensed under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+
+#include "iree/hal/drivers/amdgpu/util/pm4_program.h"
+
+#include <cstring>
+
+#include "iree/base/api.h"
+#include "iree/hal/drivers/amdgpu/util/pm4_emitter.h"
+#include "iree/hal/drivers/amdgpu/util/topology.h"
+#include "iree/hal/drivers/amdgpu/util/vmem.h"
+#include "iree/testing/gtest.h"
+#include "iree/testing/status_matchers.h"
+
+namespace iree::hal::amdgpu {
+namespace {
+
+using iree::testing::status::StatusIs;
+
+struct PM4ProgramTest : public ::testing::Test {
+  static iree_allocator_t host_allocator;
+  static iree_hal_amdgpu_libhsa_t libhsa;
+  static iree_hal_amdgpu_topology_t topology;
+
+  static void SetUpTestSuite() {
+    IREE_TRACE_SCOPE();
+    host_allocator = iree_allocator_system();
+    iree_status_t status = iree_hal_amdgpu_libhsa_initialize(
+        IREE_HAL_AMDGPU_LIBHSA_FLAG_NONE, iree_string_view_list_empty(),
+        host_allocator, &libhsa);
+    if (!iree_status_is_ok(status)) {
+      iree_status_fprint(stderr, status);
+      iree_status_free(status);
+      GTEST_SKIP() << "HSA not available, skipping tests";
+    }
+    IREE_ASSERT_OK(
+        iree_hal_amdgpu_topology_initialize_with_defaults(&libhsa, &topology));
+    if (topology.gpu_agent_count == 0 || topology.cpu_agent_count == 0) {
+      GTEST_SKIP() << "CPU and GPU agents are required, skipping tests";
+    }
+  }
+
+  static void TearDownTestSuite() {
+    IREE_TRACE_SCOPE();
+    iree_hal_amdgpu_topology_deinitialize(&topology);
+    iree_hal_amdgpu_libhsa_deinitialize(&libhsa);
+  }
+};
+iree_allocator_t PM4ProgramTest::host_allocator;
+iree_hal_amdgpu_libhsa_t PM4ProgramTest::libhsa;
+iree_hal_amdgpu_topology_t PM4ProgramTest::topology;
+
+TEST_F(PM4ProgramTest, InitializePersistentProgram) {
+  IREE_TRACE_SCOPE();
+
+  hsa_agent_t cpu_agent = topology.cpu_agents[0];
+  hsa_agent_t gpu_agent = topology.gpu_agents[0];
+  hsa_amd_memory_pool_t memory_pool;
+  IREE_ASSERT_OK(iree_hal_amdgpu_find_coarse_global_memory_pool(
+      &libhsa, cpu_agent, &memory_pool));
+
+  uint32_t source_dwords[32] = {0};
+  for (uint32_t i = 0; i < IREE_ARRAYSIZE(source_dwords); ++i) {
+    source_dwords[i] = 0xC0DEC000u + i;
+  }
+
+  iree_hal_amdgpu_pm4_program_t program = {0};
+  IREE_ASSERT_OK(iree_hal_amdgpu_pm4_program_initialize(
+      &libhsa, gpu_agent, memory_pool, source_dwords,
+      IREE_ARRAYSIZE(source_dwords), &program));
+
+  EXPECT_EQ(program.libhsa, &libhsa);
+  EXPECT_EQ(program.memory_pool.handle, memory_pool.handle);
+  ASSERT_NE(program.dwords, nullptr);
+  EXPECT_EQ(program.dword_count, IREE_ARRAYSIZE(source_dwords));
+  EXPECT_EQ(program.byte_length, sizeof(source_dwords));
+  EXPECT_EQ(std::memcmp(program.dwords, source_dwords, sizeof(source_dwords)),
+            0);
+
+  iree_hsa_amd_aql_pm4_ib_packet_t packet = {};
+  uint16_t setup = 0;
+  iree_hal_amdgpu_aql_packet_control_t packet_control =
+      iree_hal_amdgpu_aql_packet_control_barrier_system();
+  uint16_t header = iree_hal_amdgpu_aql_emit_pm4_ib_dwords(
+      &packet, program.dwords, program.dword_count, packet_control,
+      iree_hsa_signal_null(), &setup);
+
+  EXPECT_EQ(header, iree_hal_amdgpu_aql_make_header(
+                        IREE_HSA_PACKET_TYPE_VENDOR_SPECIFIC, packet_control));
+  EXPECT_EQ(setup, IREE_HSA_AMD_AQL_FORMAT_PM4_IB);
+  EXPECT_EQ(packet.ib_jump_cmd[0],
+            iree_hal_amdgpu_pm4_make_header(
+                IREE_HAL_AMDGPU_PM4_HDR_IT_OPCODE_INDIRECT_BUFFER, 4));
+  uintptr_t dword_address = reinterpret_cast<uintptr_t>(program.dwords);
+  EXPECT_EQ(packet.ib_jump_cmd[1], iree_hal_amdgpu_pm4_addr_lo(dword_address));
+  EXPECT_EQ(packet.ib_jump_cmd[2],
+            iree_hal_amdgpu_pm4_ib_addr_hi(dword_address));
+  EXPECT_EQ(packet.ib_jump_cmd[3], program.dword_count | (1u << 23));
+  EXPECT_EQ(packet.dw_cnt_remain, 0xAu);
+
+  IREE_ASSERT_OK(iree_hal_amdgpu_pm4_program_release(&program));
+  EXPECT_EQ(program.dwords, nullptr);
+  EXPECT_EQ(program.dword_count, 0u);
+  EXPECT_EQ(program.byte_length, 0u);
+}
+
+TEST_F(PM4ProgramTest, RejectsInvalidProgramShape) {
+  IREE_TRACE_SCOPE();
+
+  hsa_agent_t cpu_agent = topology.cpu_agents[0];
+  hsa_agent_t gpu_agent = topology.gpu_agents[0];
+  hsa_amd_memory_pool_t memory_pool;
+  IREE_ASSERT_OK(iree_hal_amdgpu_find_coarse_global_memory_pool(
+      &libhsa, cpu_agent, &memory_pool));
+
+  const uint32_t source_dword = 0xC0DEC000u;
+  iree_hal_amdgpu_pm4_program_t program = {0};
+  EXPECT_THAT(Status(iree_hal_amdgpu_pm4_program_initialize(
+                  &libhsa, gpu_agent, memory_pool, &source_dword,
+                  /*dword_count=*/0, &program)),
+              StatusIs(StatusCode::kInvalidArgument));
+  EXPECT_THAT(Status(iree_hal_amdgpu_pm4_program_initialize(
+                  &libhsa, gpu_agent, memory_pool, &source_dword,
+                  IREE_HAL_AMDGPU_PM4_IB_MAX_DWORD_COUNT + 1, &program)),
+              StatusIs(StatusCode::kOutOfRange));
+}
+
+}  // namespace
+}  // namespace iree::hal::amdgpu
diff --git a/runtime/src/iree/hal/drivers/amdgpu/util/queue_benchmark.cc b/runtime/src/iree/hal/drivers/amdgpu/util/queue_benchmark.cc
new file mode 100644
index 0000000..60ad3b6
--- /dev/null
+++ b/runtime/src/iree/hal/drivers/amdgpu/util/queue_benchmark.cc
@@ -0,0 +1,3985 @@
+// Copyright 2026 The IREE Authors
+//
+// Licensed under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+
+#include <benchmark/benchmark.h>
+
+#include <algorithm>
+#include <cstdint>
+#include <cstdio>
+#include <cstring>
+
+#include "iree/async/frontier_tracker.h"
+#include "iree/async/util/proactor_pool.h"
+#include "iree/base/api.h"
+#include "iree/base/threading/numa.h"
+#include "iree/base/tooling/flags.h"
+#include "iree/hal/api.h"
+#include "iree/hal/drivers/amdgpu/aql_command_buffer.h"
+#include "iree/hal/drivers/amdgpu/buffer.h"
+#include "iree/hal/drivers/amdgpu/device/dispatch.h"
+#include "iree/hal/drivers/amdgpu/executable.h"
+#include "iree/hal/drivers/amdgpu/host_queue_dispatch.h"
+#include "iree/hal/drivers/amdgpu/logical_device.h"
+#include "iree/hal/drivers/amdgpu/physical_device.h"
+#include "iree/hal/drivers/amdgpu/queue_affinity.h"
+#include "iree/hal/drivers/amdgpu/registration/driver_module.h"
+#include "iree/hal/drivers/amdgpu/semaphore.h"
+#include "iree/hal/drivers/amdgpu/util/benchmark_flags.h"
+#include "iree/hal/memory/tlsf_pool.h"
+#include "iree/io/file_contents.h"
+#include "runtime/src/iree/hal/drivers/amdgpu/cts/testdata_amdgpu.h"
+#include "runtime/src/iree/hal/drivers/amdgpu/util/testdata_amdgpu_queue_benchmark.h"
+
+IREE_FLAG(
+    string, binding_count_executable_file, "",
+    "Optional raw HSACO or AMDGPU executable file used for binding-count "
+    "dispatch benchmark rows. When empty, the embedded benchmark executable "
+    "is used.");
+IREE_FLAG(
+    int32_t, binding_count_workgroup_size_x, 0,
+    "Optional 1D workgroup-size override for binding-count dispatch benchmark "
+    "rows. Use 64 to match common HIP launch benchmark shapes. The default 0 "
+    "uses the executable export's reflected workgroup size.");
+
+namespace {
+
+constexpr int64_t kBatchCount = 20;
+constexpr uint32_t kFrontierAxisTableCapacity = 256;
+constexpr iree_device_size_t kPayloadBufferAlignment = 16;
+constexpr iree_device_size_t kPayloadLength = sizeof(uint32_t);
+constexpr iree_host_size_t kDispatchBindingBenchmarkMaxCount = 256;
+constexpr iree_host_size_t kDispatchBindingBenchmarkVariantCapacity =
+    kDispatchBindingBenchmarkMaxCount + 1;
+constexpr int64_t kProfileGuardrailBindingCount = 1;
+constexpr int64_t kProfileGuardrailIterations = 200;
+constexpr int64_t kProfileGuardrailOperationCount = 20;
+constexpr iree_hal_queue_affinity_t kQueue0 =
+    (iree_hal_queue_affinity_t)1ull << 0;
+constexpr iree_hal_queue_affinity_t kQueue1 =
+    (iree_hal_queue_affinity_t)1ull << 1;
+
+enum class PayloadKind {
+  kCopy,
+  kDispatch,
+  kFill,
+  kNoopDispatch,
+  kPreResolvedDispatch,
+};
+
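+// Profiling modes selected per benchmark row via state.range(0); backed by
+// int64_t so values round-trip through benchmark arguments.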
+enum class ProfileGuardrailMode : int64_t {
+  kDisabled = 0,
+  kQueueDeviceEvents = 1,
+  kDispatchEvents = 2,
+  kQueueDeviceAndDispatchEvents = 3,
+};
+
+enum class QueueAllocaTlsfGrowthMode : int64_t {
+  kWarm = 0,
+  kForcedGrowth = 1,
+};
+
+constexpr iree_hal_buffer_params_t kQueueAllocaBufferParams = {
+    /*usage=*/IREE_HAL_BUFFER_USAGE_TRANSFER |
+        IREE_HAL_BUFFER_USAGE_DISPATCH_STORAGE,
+    /*access=*/IREE_HAL_MEMORY_ACCESS_ALL,
+    /*type=*/IREE_HAL_MEMORY_TYPE_OPTIMAL_FOR_DEVICE,
+    /*queue_affinity=*/0,
+    /*min_alignment=*/0,
+};
+
+iree_hal_device_profiling_data_families_t ProfileGuardrailDataFamilies(
+    ProfileGuardrailMode mode) {
+  switch (mode) {
+    case ProfileGuardrailMode::kDisabled:
+      return IREE_HAL_DEVICE_PROFILING_DATA_NONE;
+    case ProfileGuardrailMode::kQueueDeviceEvents:
+      return IREE_HAL_DEVICE_PROFILING_DATA_QUEUE_EVENTS |
+             IREE_HAL_DEVICE_PROFILING_DATA_DEVICE_QUEUE_EVENTS;
+    case ProfileGuardrailMode::kDispatchEvents:
+      return IREE_HAL_DEVICE_PROFILING_DATA_DISPATCH_EVENTS;
+    case ProfileGuardrailMode::kQueueDeviceAndDispatchEvents:
+      return IREE_HAL_DEVICE_PROFILING_DATA_QUEUE_EVENTS |
+             IREE_HAL_DEVICE_PROFILING_DATA_DEVICE_QUEUE_EVENTS |
+             IREE_HAL_DEVICE_PROFILING_DATA_DISPATCH_EVENTS;
+  }
+  return IREE_HAL_DEVICE_PROFILING_DATA_NONE;
+}
+
+typedef struct QueueBenchmarkDiscardProfileSink {
+  // Resource header for iree_hal_profile_sink_t lifetime management.
+  iree_hal_resource_t resource;
+  // Host allocator used for sink lifetime.
+  iree_allocator_t host_allocator;
+  // Number of profile chunks observed by the sink.
+  uint64_t write_count;
+  // Total bytes observed across all profile chunk iovecs.
+  uint64_t payload_byte_count;
+} QueueBenchmarkDiscardProfileSink;
+
+static void QueueBenchmarkDiscardProfileSinkDestroy(
+    iree_hal_profile_sink_t* base_sink) {
+  QueueBenchmarkDiscardProfileSink* sink =
+      (QueueBenchmarkDiscardProfileSink*)base_sink;
+  iree_allocator_t host_allocator = sink->host_allocator;
+  iree_allocator_free(host_allocator, sink);
+}
+
+static iree_status_t QueueBenchmarkDiscardProfileSinkBeginSession(
+    iree_hal_profile_sink_t* base_sink,
+    const iree_hal_profile_chunk_metadata_t* metadata) {
+  (void)base_sink;
+  (void)metadata;
+  return iree_ok_status();
+}
+
+static iree_status_t QueueBenchmarkDiscardProfileSinkWrite(
+    iree_hal_profile_sink_t* base_sink,
+    const iree_hal_profile_chunk_metadata_t* metadata,
+    iree_host_size_t iovec_count, const iree_const_byte_span_t* iovecs) {
+  (void)metadata;
+  QueueBenchmarkDiscardProfileSink* sink =
+      (QueueBenchmarkDiscardProfileSink*)base_sink;
+  ++sink->write_count;
+  for (iree_host_size_t i = 0; i < iovec_count; ++i) {
+    sink->payload_byte_count += iovecs[i].data_length;
+  }
+  return iree_ok_status();
+}
+
+static iree_status_t QueueBenchmarkDiscardProfileSinkEndSession(
+    iree_hal_profile_sink_t* base_sink,
+    const iree_hal_profile_chunk_metadata_t* metadata,
+    iree_status_code_t session_status_code) {
+  (void)base_sink;
+  (void)metadata;
+  (void)session_status_code;
+  return iree_ok_status();
+}
+
+static const iree_hal_profile_sink_vtable_t
+    kQueueBenchmarkDiscardProfileSinkVtable = {
+        .destroy = QueueBenchmarkDiscardProfileSinkDestroy,
+        .begin_session = QueueBenchmarkDiscardProfileSinkBeginSession,
+        .write = QueueBenchmarkDiscardProfileSinkWrite,
+        .end_session = QueueBenchmarkDiscardProfileSinkEndSession,
+};
+
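+// Creates a sink that counts profile chunk writes and payload bytes but
+// otherwise discards them so benchmarks measure producer-side profiling
+// overhead without sink I/O cost.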
+iree_status_t QueueBenchmarkDiscardProfileSinkCreate(
+    iree_allocator_t host_allocator, iree_hal_profile_sink_t** out_sink) {
+  *out_sink = nullptr;
+  QueueBenchmarkDiscardProfileSink* sink = nullptr;
+  IREE_RETURN_IF_ERROR(
+      iree_allocator_malloc(host_allocator, sizeof(*sink), (void**)&sink));
+  iree_hal_resource_initialize(&kQueueBenchmarkDiscardProfileSinkVtable,
+                               &sink->resource);
+  sink->host_allocator = host_allocator;
+  sink->write_count = 0;
+  sink->payload_byte_count = 0;
+  *out_sink = (iree_hal_profile_sink_t*)sink;
+  return iree_ok_status();
+}
+
+class QueueBenchmark : public benchmark::Fixture {
+ public:
+  static void InitializeOnce() {
+    if (initialized_) return;
+    initialized_ = true;
+    host_allocator_ = iree_allocator_system();
+
+    iree_status_t status = iree_hal_amdgpu_driver_module_register(
+        iree_hal_driver_registry_default());
+    if (iree_status_is_already_exists(status)) {
+      iree_status_free(status);
+      status = iree_ok_status();
+    }
+
+    if (iree_status_is_ok(status)) {
+      status = iree_hal_driver_registry_try_create(
+          iree_hal_driver_registry_default(), iree_make_cstring_view("amdgpu"),
+          host_allocator_, &driver_);
+    }
+
+    iree_async_proactor_pool_t* proactor_pool = nullptr;
+    if (iree_status_is_ok(status)) {
+      status = iree_async_proactor_pool_create(
+          iree_numa_node_count(), /*node_ids=*/nullptr,
+          iree_async_proactor_pool_options_default(), host_allocator_,
+          &proactor_pool);
+    }
+
+    if (iree_status_is_ok(status)) {
+      iree_hal_device_create_params_t create_params =
+          iree_hal_device_create_params_default();
+      create_params.proactor_pool = proactor_pool;
+      status = iree_hal_driver_create_default_device(driver_, &create_params,
+                                                     host_allocator_, &device_);
+    }
+    iree_async_proactor_pool_release(proactor_pool);
+
+    iree_async_frontier_tracker_t* frontier_tracker = nullptr;
+    if (iree_status_is_ok(status)) {
+      iree_async_frontier_tracker_options_t options =
+          iree_async_frontier_tracker_options_default();
+      options.axis_table_capacity = kFrontierAxisTableCapacity;
+      status = iree_async_frontier_tracker_create(options, host_allocator_,
+                                                  &frontier_tracker);
+    }
+    if (iree_status_is_ok(status)) {
+      status = iree_hal_device_group_create_from_device(
+          device_, frontier_tracker, host_allocator_, &device_group_);
+    }
+    iree_async_frontier_tracker_release(frontier_tracker);
+
+    if (iree_status_is_ok(status)) {
+      available_ = true;
+      return;
+    }
+
+    iree_status_fprint(stderr, status);
+    iree_status_free(status);
+    iree_hal_device_release(device_);
+    iree_hal_driver_release(driver_);
+    device_ = nullptr;
+    driver_ = nullptr;
+  }
+
+  static void DeinitializeOnce() {
+    if (!initialized_) return;
+    iree_hal_executable_release(binding_count_executable_);
+    iree_hal_executable_cache_release(binding_count_executable_cache_);
+    iree_io_file_contents_free(binding_count_executable_file_contents_);
+    iree_hal_executable_release(dispatch_executable_);
+    iree_hal_executable_cache_release(dispatch_executable_cache_);
+    iree_hal_device_release(device_);
+    iree_hal_device_group_release(device_group_);
+    iree_hal_driver_release(driver_);
+    binding_count_executable_ = nullptr;
+    binding_count_executable_cache_ = nullptr;
+    binding_count_executable_file_contents_ = nullptr;
+    dispatch_executable_ = nullptr;
+    dispatch_executable_cache_ = nullptr;
+    device_ = nullptr;
+    device_group_ = nullptr;
+    driver_ = nullptr;
+    available_ = false;
+  }
+
+  void SetUp(benchmark::State& state) override {
+    InitializeOnce();
+    if (!available_) {
+      state.SkipWithError("AMDGPU HAL device not available");
+      return;
+    }
+
+    if (!CreatePublicSemaphore(state, &completion_semaphore_) ||
+        !CreatePrivateStreamSemaphore(state, &stream_semaphore_) ||
+        !CreatePrivateStreamSemaphore(state, &producer_semaphore_)) {
+      return;
+    }
+  }
+
+  void TearDown(benchmark::State& state) override {
+    if (profile_session_active_) {
+      EndProfileSession(state,
+                        "profiling end failed during benchmark teardown");
+    }
+    iree_hal_profile_sink_release(profile_sink_);
+    profile_sink_ = nullptr;
+    ReleasePreResolvedDispatch();
+    for (iree_host_size_t i = 0; i < kDispatchBindingBenchmarkVariantCapacity;
+         ++i) {
+      iree_hal_command_buffer_release(binding_count_command_buffers_[i]);
+      binding_count_command_buffers_[i] = nullptr;
+    }
+    for (iree_host_size_t i = 0; i < kDispatchBindingBenchmarkMaxCount; ++i) {
+      iree_hal_buffer_release(binding_count_buffers_[i]);
+      binding_count_buffers_[i] = nullptr;
+    }
+    iree_hal_buffer_release(source_buffer_);
+    iree_hal_buffer_release(target_buffer_);
+    iree_hal_semaphore_release(completion_semaphore_);
+    iree_hal_semaphore_release(stream_semaphore_);
+    iree_hal_semaphore_release(producer_semaphore_);
+    source_buffer_ = nullptr;
+    target_buffer_ = nullptr;
+    completion_semaphore_ = nullptr;
+    stream_semaphore_ = nullptr;
+    producer_semaphore_ = nullptr;
+    completion_payload_value_ = 0;
+    stream_payload_value_ = 0;
+    producer_payload_value_ = 0;
+  }
+
+ protected:
+  struct SubmittedCompletion {
+    iree_hal_semaphore_t* semaphore;
+    uint64_t payload_value;
+  };
+
+  static iree_hal_queue_affinity_t CrossQueuePingPongFinalQueue(
+      int64_t handoff_count) {
+    return (handoff_count & 1) ? kQueue1 : kQueue0;
+  }
+
+  bool EnsureQueueAvailable(benchmark::State& state,
+                            iree_hal_queue_affinity_t queue_affinity) {
+    if (queue_affinity == kQueue1) {
+      // Cross-queue rows are same-physical-agent measurements; skip instead
+      // of silently turning them into cross-device/system-scope rows.
+      return EnsurePrivateStreamQueuePair(state);
+    }
+    return HandleStatus(state,
+                        iree_hal_device_queue_flush(device_, queue_affinity),
+                        "queue affinity not available");
+  }
+
+  bool BeginProfileSession(benchmark::State& state, ProfileGuardrailMode mode) {
+    const iree_hal_device_profiling_data_families_t data_families =
+        ProfileGuardrailDataFamilies(mode);
+    if (data_families == IREE_HAL_DEVICE_PROFILING_DATA_NONE) return true;
+
+    if (!HandleStatus(state,
+                      QueueBenchmarkDiscardProfileSinkCreate(host_allocator_,
+                                                             &profile_sink_),
+                      "failed to create profile benchmark sink")) {
+      return false;
+    }
+    iree_hal_device_profiling_options_t options = {0};
+    options.data_families = data_families;
+    options.sink = profile_sink_;
+    if (!HandleStatus(state, iree_hal_device_profiling_begin(device_, &options),
+                      "failed to begin profile benchmark session")) {
+      iree_hal_profile_sink_release(profile_sink_);
+      profile_sink_ = nullptr;
+      return false;
+    }
+    profile_session_active_ = true;
+    return true;
+  }
+
+  void EndProfileSession(benchmark::State& state, const char* message) {
+    if (!profile_session_active_) return;
+    profile_session_active_ = false;
+    HandleStatus(state, iree_hal_device_profiling_end(device_), message);
+  }
+
+  bool FlushProfileSession(benchmark::State& state, const char* message) {
+    if (!profile_session_active_) return true;
+    return HandleStatus(state, iree_hal_device_profiling_flush(device_),
+                        message);
+  }
+
+  bool FlushProfileSessionWithTimingPaused(benchmark::State& state,
+                                           const char* message) {
+    if (!profile_session_active_) return true;
+    state.PauseTiming();
+    const bool result = FlushProfileSession(state, message);
+    state.ResumeTiming();
+    return result;
+  }
+
+  bool WaitAndFlushProfileSessionWithTimingPaused(
+      benchmark::State& state, const SubmittedCompletion& completion,
+      const char* wait_message, const char* flush_message) {
+    state.PauseTiming();
+    iree_status_t status = Wait(completion.semaphore, completion.payload_value);
+    bool result = HandleStatus(state, status, wait_message);
+    if (result) {
+      result = FlushProfileSession(state, flush_message);
+    }
+    state.ResumeTiming();
+    return result;
+  }
+
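+  // Profile guardrail rows run the same hot loop under each profiling mode
+  // and flush with timing paused so only producer-side profiling cost lands
+  // in the measured region.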
+  template <typename RunWaitFn>
+  void RunProfileGuardrailFinalWaitBenchmark(
+      benchmark::State& state, int64_t queue_submissions_per_sync,
+      int64_t profiled_operations_per_sync, const char* flush_message,
+      const char* end_message, RunWaitFn run_wait) {
+    const ProfileGuardrailMode profile_mode =
+        static_cast<ProfileGuardrailMode>(state.range(0));
+    if (!BeginProfileSession(state, profile_mode)) return;
+    for (auto _ : state) {
+      if (!run_wait()) break;
+      if (!FlushProfileSessionWithTimingPaused(state, flush_message)) break;
+    }
+    SetProfileGuardrailCounters(state, profile_mode, queue_submissions_per_sync,
+                                profiled_operations_per_sync);
+    EndProfileSession(state, end_message);
+  }
+
+  template <typename SubmitFn>
+  void RunProfileGuardrailSubmitOnlyBenchmark(
+      benchmark::State& state, int64_t queue_submissions_per_sync,
+      int64_t profiled_operations_per_sync, const char* wait_message,
+      const char* flush_message, const char* end_message, SubmitFn submit) {
+    const ProfileGuardrailMode profile_mode =
+        static_cast<ProfileGuardrailMode>(state.range(0));
+    if (!BeginProfileSession(state, profile_mode)) return;
+    for (auto _ : state) {
+      SubmittedCompletion completion;
+      if (!submit(&completion)) break;
+      if (!WaitAndFlushProfileSessionWithTimingPaused(
+              state, completion, wait_message, flush_message)) {
+        break;
+      }
+    }
+    SetProfileGuardrailCounters(state, profile_mode, queue_submissions_per_sync,
+                                profiled_operations_per_sync);
+    EndProfileSession(state, end_message);
+  }
+
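+  // Resolves a HAL queue affinity bit to the driver's backing host queue so
+  // the benchmark can reach internals (epoch signals, error status) directly.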
+  iree_status_t LookupHostQueue(iree_hal_queue_affinity_t queue_affinity,
+                                iree_hal_amdgpu_host_queue_t** out_host_queue) {
+    *out_host_queue = nullptr;
+    auto* logical_device =
+        reinterpret_cast<iree_hal_amdgpu_logical_device_t*>(device_);
+    const iree_hal_amdgpu_queue_affinity_domain_t domain = {
+        .supported_affinity = logical_device->queue_affinity_mask,
+        .physical_device_count = logical_device->physical_device_count,
+        .queue_count_per_physical_device =
+            logical_device->system->topology.gpu_agent_queue_count,
+    };
+    iree_hal_amdgpu_queue_affinity_resolved_t resolved;
+    IREE_RETURN_IF_ERROR(iree_hal_amdgpu_queue_affinity_resolve(
+        domain, queue_affinity, &resolved));
+    iree_hal_amdgpu_physical_device_t* physical_device =
+        logical_device->physical_devices[resolved.physical_device_ordinal];
+    if (IREE_UNLIKELY(resolved.physical_queue_ordinal >=
+                      physical_device->host_queue_count)) {
+      return iree_make_status(IREE_STATUS_OUT_OF_RANGE,
+                              "queue ordinal has no initialized host queue");
+    }
+
+    *out_host_queue =
+        &physical_device->host_queues[resolved.physical_queue_ordinal];
+    return iree_ok_status();
+  }
+
+  iree_status_t LookupHostQueueByAxis(
+      iree_async_axis_t axis, iree_hal_amdgpu_host_queue_t** out_host_queue) {
+    *out_host_queue = nullptr;
+    if (IREE_UNLIKELY(iree_async_axis_domain(axis) !=
+                      IREE_ASYNC_CAUSAL_DOMAIN_QUEUE)) {
+      return iree_make_status(IREE_STATUS_INVALID_ARGUMENT,
+                              "producer axis is not a queue axis");
+    }
+
+    auto* logical_device =
+        reinterpret_cast<iree_hal_amdgpu_logical_device_t*>(device_);
+    const uint8_t device_index = iree_async_axis_device_index(axis);
+    if (IREE_UNLIKELY(device_index >= logical_device->physical_device_count)) {
+      return iree_make_status(IREE_STATUS_OUT_OF_RANGE,
+                              "producer axis has no physical device");
+    }
+    iree_hal_amdgpu_physical_device_t* physical_device =
+        logical_device->physical_devices[device_index];
+    const uint8_t queue_index = iree_async_axis_queue_index(axis);
+    if (IREE_UNLIKELY(queue_index >= physical_device->host_queue_count)) {
+      return iree_make_status(IREE_STATUS_OUT_OF_RANGE,
+                              "producer axis has no initialized host queue");
+    }
+
+    *out_host_queue = &physical_device->host_queues[queue_index];
+    return iree_ok_status();
+  }
+
+  bool EnsurePrivateStreamQueuePair(benchmark::State& state) {
+    iree_hal_amdgpu_host_queue_t* queue0 = nullptr;
+    iree_hal_amdgpu_host_queue_t* queue1 = nullptr;
+    if (!HandleStatus(state, LookupHostQueue(kQueue0, &queue0),
+                      "queue 0 is not available")) {
+      return false;
+    }
+    if (!HandleStatus(state, LookupHostQueue(kQueue1, &queue1),
+                      "queue 1 is not available")) {
+      return false;
+    }
+    if (queue0->device_ordinal != queue1->device_ordinal) {
+      state.SkipWithError(
+          "queue 0 and queue 1 are not on the same physical AMDGPU agent");
+      return false;
+    }
+    return true;
+  }
+
+  iree_status_t WaitForSubmittedProducerEpoch(
+      const SubmittedCompletion& completion) {
+    if (IREE_UNLIKELY(!iree_hal_amdgpu_semaphore_isa(completion.semaphore))) {
+      return iree_make_status(
+          IREE_STATUS_INVALID_ARGUMENT,
+          "epoch completion floor requires an AMDGPU semaphore");
+    }
+
+    iree_hal_amdgpu_last_signal_flags_t signal_flags =
+        IREE_HAL_AMDGPU_LAST_SIGNAL_FLAG_NONE;
+    iree_async_axis_t producer_axis = 0;
+    uint64_t producer_epoch = 0;
+    uint64_t producer_value = 0;
+    if (IREE_UNLIKELY(!iree_hal_amdgpu_last_signal_load(
+            iree_hal_amdgpu_semaphore_last_signal(completion.semaphore),
+            &signal_flags, &producer_axis, &producer_epoch, &producer_value))) {
+      return iree_make_status(
+          IREE_STATUS_FAILED_PRECONDITION,
+          "completion semaphore has no submitted producer epoch");
+    }
+    if (IREE_UNLIKELY(producer_value < completion.payload_value)) {
+      return iree_make_status(IREE_STATUS_FAILED_PRECONDITION,
+                              "completion semaphore producer value %" PRIu64
+                              " is below target %" PRIu64,
+                              producer_value, completion.payload_value);
+    }
+
+    iree_hal_amdgpu_host_queue_t* host_queue = nullptr;
+    IREE_RETURN_IF_ERROR(LookupHostQueueByAxis(producer_axis, &host_queue));
+    hsa_signal_t epoch_signal = iree_hal_amdgpu_notification_ring_epoch_signal(
+        &host_queue->notification_ring);
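+    // The notification ring's epoch signal counts down from
+    // IREE_HAL_AMDGPU_EPOCH_INITIAL_VALUE as epochs retire, so reaching
+    // |producer_epoch| corresponds to the signal dropping strictly below the
+    // compare value computed below.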
+    const hsa_signal_value_t compare_value =
+        (hsa_signal_value_t)(IREE_HAL_AMDGPU_EPOCH_INITIAL_VALUE -
+                             producer_epoch + 1);
+
+    auto* logical_device =
+        reinterpret_cast<iree_hal_amdgpu_logical_device_t*>(device_);
+    uint64_t wait_timeout_hint =
+        logical_device->system->info.timestamp_frequency / 1000;
+    if (wait_timeout_hint == 0) wait_timeout_hint = 1;
+
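+    // Wait in ~1ms slices (timestamp_frequency / 1000 ticks) so a queue
+    // error can interrupt an otherwise unbounded wait on the epoch signal.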
+    for (;;) {
+      hsa_signal_value_t signal_value = iree_hsa_signal_wait_scacquire(
+          IREE_LIBHSA(host_queue->libhsa), epoch_signal,
+          HSA_SIGNAL_CONDITION_LT, compare_value, wait_timeout_hint,
+          HSA_WAIT_STATE_BLOCKED);
+      if (signal_value < compare_value) return iree_ok_status();
+
+      iree_status_t queue_error = (iree_status_t)iree_atomic_load(
+          &host_queue->error_status, iree_memory_order_acquire);
+      if (IREE_UNLIKELY(!iree_status_is_ok(queue_error))) {
+        return iree_status_clone(queue_error);
+      }
+    }
+  }
+
+  static iree_status_t ResolveBufferDevicePointer(iree_hal_buffer_t* buffer,
+                                                  uint64_t* out_device_ptr) {
+    *out_device_ptr = 0;
+    iree_hal_buffer_t* allocated_buffer =
+        iree_hal_buffer_allocated_buffer(buffer);
+    void* device_ptr = iree_hal_amdgpu_buffer_device_pointer(allocated_buffer);
+    if (IREE_UNLIKELY(!device_ptr)) {
+      return iree_make_status(
+          IREE_STATUS_INVALID_ARGUMENT,
+          "dispatch benchmark buffer must be backed by an AMDGPU allocation");
+    }
+
+    const iree_device_size_t device_offset =
+        iree_hal_buffer_byte_offset(buffer);
+    if (IREE_UNLIKELY(device_offset > UINTPTR_MAX)) {
+      return iree_make_status(
+          IREE_STATUS_OUT_OF_RANGE,
+          "dispatch benchmark buffer offset exceeds host pointer size");
+    }
+    *out_device_ptr = (uint64_t)((uintptr_t)device_ptr + device_offset);
+    return iree_ok_status();
+  }
+
+  iree_status_t SubmitBarrierWithLists(
+      iree_hal_queue_affinity_t queue_affinity,
+      iree_hal_semaphore_list_t wait_semaphore_list,
+      iree_hal_semaphore_list_t signal_semaphore_list) {
+    return iree_hal_device_queue_barrier(
+        device_, queue_affinity, wait_semaphore_list, signal_semaphore_list,
+        IREE_HAL_EXECUTE_FLAG_NONE);
+  }
+
+  iree_status_t SubmitPayloadWithLists(
+      PayloadKind payload_kind, iree_hal_queue_affinity_t queue_affinity,
+      iree_hal_semaphore_list_t wait_semaphore_list,
+      iree_hal_semaphore_list_t signal_semaphore_list) {
+    if (payload_kind == PayloadKind::kCopy) {
+      return iree_hal_device_queue_copy(
+          device_, queue_affinity, wait_semaphore_list, signal_semaphore_list,
+          source_buffer_, /*source_offset=*/0, target_buffer_,
+          /*target_offset=*/0, kPayloadLength, IREE_HAL_COPY_FLAG_NONE);
+    }
+    if (payload_kind == PayloadKind::kDispatch) {
+      const uint32_t constant_data[] = {3, 10};
+      iree_const_byte_span_t constants =
+          iree_make_const_byte_span(constant_data, sizeof(constant_data));
+      iree_hal_buffer_ref_t binding_refs[2] = {
+          iree_hal_make_buffer_ref(source_buffer_, /*offset=*/0,
+                                   iree_hal_buffer_byte_length(source_buffer_)),
+          iree_hal_make_buffer_ref(target_buffer_, /*offset=*/0,
+                                   iree_hal_buffer_byte_length(target_buffer_)),
+      };
+      iree_hal_buffer_ref_list_t bindings = {
+          /*count=*/IREE_ARRAYSIZE(binding_refs),
+          /*values=*/binding_refs,
+      };
+      return iree_hal_device_queue_dispatch(
+          device_, queue_affinity, wait_semaphore_list, signal_semaphore_list,
+          dispatch_executable_, /*export_ordinal=*/0,
+          iree_hal_make_static_dispatch_config(1, 1, 1), constants, bindings,
+          IREE_HAL_DISPATCH_FLAG_NONE);
+    }
+    if (payload_kind == PayloadKind::kNoopDispatch) {
+      return iree_hal_device_queue_dispatch(
+          device_, queue_affinity, wait_semaphore_list, signal_semaphore_list,
+          dispatch_executable_, /*export_ordinal=*/0,
+          iree_hal_make_static_dispatch_config(0, 0, 0),
+          iree_const_byte_span_empty(), iree_hal_buffer_ref_list_empty(),
+          IREE_HAL_DISPATCH_FLAG_NONE);
+    }
+    if (payload_kind == PayloadKind::kPreResolvedDispatch) {
+      return SubmitPreResolvedDispatchWithLists(
+          queue_affinity, wait_semaphore_list, signal_semaphore_list);
+    }
+    return iree_hal_device_queue_fill(
+        device_, queue_affinity, wait_semaphore_list, signal_semaphore_list,
+        target_buffer_, /*target_offset=*/0, kPayloadLength, &fill_pattern_,
+        sizeof(fill_pattern_), IREE_HAL_FILL_FLAG_NONE);
+  }
+
+  iree_status_t SubmitBarrier(iree_hal_queue_affinity_t queue_affinity,
+                              iree_hal_semaphore_t* wait_semaphore,
+                              uint64_t wait_payload_value,
+                              iree_hal_semaphore_t* signal_semaphore,
+                              uint64_t signal_payload_value) {
+    iree_hal_semaphore_t* wait_semaphore_storage = wait_semaphore;
+    iree_hal_semaphore_t* signal_semaphore_storage = signal_semaphore;
+    iree_hal_semaphore_list_t wait_semaphore_list =
+        iree_hal_semaphore_list_empty();
+    iree_hal_semaphore_list_t signal_semaphore_list =
+        iree_hal_semaphore_list_empty();
+    if (wait_semaphore) {
+      wait_semaphore_list = {
+          /*count=*/1,
+          /*semaphores=*/&wait_semaphore_storage,
+          /*payload_values=*/&wait_payload_value,
+      };
+    }
+    if (signal_semaphore) {
+      signal_semaphore_list = {
+          /*count=*/1,
+          /*semaphores=*/&signal_semaphore_storage,
+          /*payload_values=*/&signal_payload_value,
+      };
+    }
+    return SubmitBarrierWithLists(queue_affinity, wait_semaphore_list,
+                                  signal_semaphore_list);
+  }
+
+  iree_status_t SubmitBarrierWithWaitList(
+      iree_hal_queue_affinity_t queue_affinity,
+      iree_hal_semaphore_list_t wait_semaphore_list,
+      iree_hal_semaphore_t* signal_semaphore, uint64_t signal_payload_value) {
+    iree_hal_semaphore_t* signal_semaphore_storage = signal_semaphore;
+    iree_hal_semaphore_list_t signal_semaphore_list =
+        iree_hal_semaphore_list_empty();
+    if (signal_semaphore) {
+      signal_semaphore_list = {
+          /*count=*/1,
+          /*semaphores=*/&signal_semaphore_storage,
+          /*payload_values=*/&signal_payload_value,
+      };
+    }
+    return SubmitBarrierWithLists(queue_affinity, wait_semaphore_list,
+                                  signal_semaphore_list);
+  }
+
+  iree_status_t SubmitBarrierWithSingleWaitAndSignalList(
+      iree_hal_queue_affinity_t queue_affinity,
+      iree_hal_semaphore_t* wait_semaphore, uint64_t wait_payload_value,
+      iree_hal_semaphore_list_t signal_semaphore_list) {
+    iree_hal_semaphore_t* wait_semaphore_storage = wait_semaphore;
+    iree_hal_semaphore_list_t wait_semaphore_list = {
+        /*count=*/1,
+        /*semaphores=*/&wait_semaphore_storage,
+        /*payload_values=*/&wait_payload_value,
+    };
+    return SubmitBarrierWithLists(queue_affinity, wait_semaphore_list,
+                                  signal_semaphore_list);
+  }
+
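+  // Blocks the host until |semaphore| reaches |payload_value| using the
+  // benchmark-configured completion wait flags.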
+  iree_status_t Wait(iree_hal_semaphore_t* semaphore, uint64_t payload_value) {
+    return iree_hal_semaphore_wait(
+        semaphore, payload_value, iree_infinite_timeout(),
+        iree_hal_amdgpu_benchmark_completion_wait_flags());
+  }
+
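+  // Creates a TLSF pool over queue 0's slab provider and notification
+  // backend for the queue-ordered alloca benchmarks; fails if the backend
+  // bundle query comes back incomplete.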
+  iree_status_t CreateQueueAllocaTlsfPool(iree_device_size_t allocation_size,
+                                          iree_hal_pool_t** out_pool) {
+    *out_pool = nullptr;
+    iree_hal_queue_pool_backend_t backend = {0};
+    IREE_RETURN_IF_ERROR(
+        iree_hal_device_query_queue_pool_backend(device_, kQueue0, &backend));
+    if (!backend.slab_provider || !backend.notification) {
+      return iree_make_status(
+          IREE_STATUS_FAILED_PRECONDITION,
+          "queue pool backend query returned an incomplete backend bundle");
+    }
+
+    iree_hal_tlsf_pool_options_t options = {};
+    options.tlsf_options.range_length = allocation_size;
+    options.tlsf_options.alignment = kPayloadBufferAlignment;
+    options.tlsf_options.initial_block_capacity = 16;
+    options.tlsf_options.frontier_capacity = 2;
+    return iree_hal_tlsf_pool_create(
+        options, backend.slab_provider, backend.notification,
+        iree_hal_pool_epoch_query_null(), host_allocator_, out_pool);
+  }
+
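+  // Submits a queue-ordered alloca on queue 0 and returns the
+  // (semaphore, payload) completion the caller must wait on before the
+  // buffer contents are usable.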
+  iree_status_t QueueAllocaSubmit(iree_hal_pool_t* pool,
+                                  iree_device_size_t allocation_size,
+                                  iree_hal_buffer_t** out_buffer,
+                                  SubmittedCompletion* out_completion) {
+    *out_buffer = nullptr;
+    uint64_t signal_payload_value = ++completion_payload_value_;
+    iree_hal_semaphore_t* signal_semaphore = completion_semaphore_;
+    iree_hal_semaphore_list_t signal_semaphore_list = {
+        /*count=*/1,
+        /*semaphores=*/&signal_semaphore,
+        /*payload_values=*/&signal_payload_value,
+    };
+    IREE_RETURN_IF_ERROR(iree_hal_device_queue_alloca(
+        device_, kQueue0, iree_hal_semaphore_list_empty(),
+        signal_semaphore_list, pool, kQueueAllocaBufferParams, allocation_size,
+        IREE_HAL_ALLOCA_FLAG_NONE, out_buffer));
+    *out_completion = {completion_semaphore_, signal_payload_value};
+    return iree_ok_status();
+  }
+
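+  // Waits for the alloca to land, then submits the matching queue-ordered
+  // dealloca and waits for it to retire before releasing the buffer handle.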
+  iree_status_t QueueAllocaCleanup(iree_hal_buffer_t* buffer,
+                                   SubmittedCompletion alloca_completion) {
+    IREE_RETURN_IF_ERROR(
+        Wait(alloca_completion.semaphore, alloca_completion.payload_value));
+
+    uint64_t signal_payload_value = ++completion_payload_value_;
+    iree_hal_semaphore_t* wait_semaphore = alloca_completion.semaphore;
+    iree_hal_semaphore_t* signal_semaphore = completion_semaphore_;
+    iree_hal_semaphore_list_t wait_semaphore_list = {
+        /*count=*/1,
+        /*semaphores=*/&wait_semaphore,
+        /*payload_values=*/&alloca_completion.payload_value,
+    };
+    iree_hal_semaphore_list_t signal_semaphore_list = {
+        /*count=*/1,
+        /*semaphores=*/&signal_semaphore,
+        /*payload_values=*/&signal_payload_value,
+    };
+    IREE_RETURN_IF_ERROR(iree_hal_device_queue_dealloca(
+        device_, kQueue0, wait_semaphore_list, signal_semaphore_list, buffer,
+        IREE_HAL_DEALLOCA_FLAG_NONE));
+    IREE_RETURN_IF_ERROR(Wait(completion_semaphore_, signal_payload_value));
+    iree_hal_buffer_release(buffer);
+    return iree_ok_status();
+  }
+
+  iree_status_t QueueAllocaSubmitAndCleanup(
+      iree_hal_pool_t* pool, iree_device_size_t allocation_size) {
+    iree_hal_buffer_t* buffer = nullptr;
+    SubmittedCompletion completion;
+    IREE_RETURN_IF_ERROR(
+        QueueAllocaSubmit(pool, allocation_size, &buffer, &completion));
+    return QueueAllocaCleanup(buffer, completion);
+  }
+
+  iree_status_t FillBufferAndWait(iree_hal_buffer_t* target_buffer,
+                                  const void* pattern,
+                                  iree_host_size_t pattern_length) {
+    uint64_t payload_value = ++completion_payload_value_;
+    iree_hal_semaphore_t* signal_semaphore = completion_semaphore_;
+    iree_hal_semaphore_list_t signal_semaphore_list = {
+        /*count=*/1,
+        /*semaphores=*/&signal_semaphore,
+        /*payload_values=*/&payload_value,
+    };
+    IREE_RETURN_IF_ERROR(iree_hal_device_queue_fill(
+        device_, kQueue0, iree_hal_semaphore_list_empty(),
+        signal_semaphore_list, target_buffer, /*target_offset=*/0,
+        kPayloadBufferAlignment, pattern, pattern_length,
+        IREE_HAL_FILL_FLAG_NONE));
+    return Wait(completion_semaphore_, payload_value);
+  }
+
+  iree_status_t SameQueueBarrierAndWait() {
+    uint64_t payload_value = ++completion_payload_value_;
+    IREE_RETURN_IF_ERROR(SubmitBarrier(kQueue0, /*wait_semaphore=*/nullptr,
+                                       /*wait_payload_value=*/0,
+                                       completion_semaphore_, payload_value));
+    return Wait(completion_semaphore_, payload_value);
+  }
+
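+  // Host-call target that does nothing; used to measure the queue->host
+  // round-trip overhead in isolation.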
+  static iree_status_t NoopHostCall(void* user_data, const uint64_t args[4],
+                                    iree_hal_host_call_context_t* context) {
+    (void)user_data;
+    (void)args;
+    (void)context;
+    return iree_ok_status();
+  }
+
+  iree_status_t HostCallAndWait(iree_hal_host_call_flags_t flags) {
+    uint64_t payload_value = ++completion_payload_value_;
+    iree_hal_semaphore_t* signal_semaphore = completion_semaphore_;
+    iree_hal_semaphore_list_t signal_semaphore_list = {
+        /*count=*/1,
+        /*semaphores=*/&signal_semaphore,
+        /*payload_values=*/&payload_value,
+    };
+    iree_hal_host_call_t call =
+        iree_hal_make_host_call(NoopHostCall, /*user_data=*/nullptr);
+    uint64_t args[4] = {0, 0, 0, 0};
+    IREE_RETURN_IF_ERROR(iree_hal_device_queue_host_call(
+        device_, kQueue0, iree_hal_semaphore_list_empty(),
+        signal_semaphore_list, call, args, flags));
+    return Wait(completion_semaphore_, payload_value);
+  }
+
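+  // Chains |batch_count| no-op host calls on queue 0 through the private
+  // stream semaphore; only the final call additionally signals the public
+  // completion semaphore so the host syncs once per batch.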
+  iree_status_t HostCallBatchSubmit(iree_hal_host_call_flags_t flags,
+                                    int64_t batch_count,
+                                    SubmittedCompletion* out_completion) {
+    for (int64_t i = 0; i < batch_count; ++i) {
+      const bool is_final_operation = i + 1 == batch_count;
+      uint64_t wait_payload_value = stream_payload_value_;
+      uint64_t signal_payload_value = stream_payload_value_ + 1;
+      iree_hal_semaphore_t* wait_semaphore = stream_semaphore_;
+      iree_hal_semaphore_list_t wait_semaphore_list =
+          iree_hal_semaphore_list_empty();
+      if (i > 0) {
+        wait_semaphore_list = iree_hal_semaphore_list_t{
+            /*count=*/1,
+            /*semaphores=*/&wait_semaphore,
+            /*payload_values=*/&wait_payload_value,
+        };
+      }
+
+      iree_hal_semaphore_t* signal_semaphores[2] = {
+          stream_semaphore_,
+          completion_semaphore_,
+      };
+      uint64_t signal_payload_values[2] = {
+          signal_payload_value,
+          completion_payload_value_ + 1,
+      };
+      iree_hal_semaphore_list_t signal_semaphore_list = {
+          /*count=*/is_final_operation ? 2u : 1u,
+          /*semaphores=*/signal_semaphores,
+          /*payload_values=*/signal_payload_values,
+      };
+      iree_hal_host_call_t call =
+          iree_hal_make_host_call(NoopHostCall, /*user_data=*/nullptr);
+      uint64_t args[4] = {0, 0, 0, 0};
+      IREE_RETURN_IF_ERROR(iree_hal_device_queue_host_call(
+          device_, kQueue0, wait_semaphore_list, signal_semaphore_list, call,
+          args, flags));
+      stream_payload_value_ = signal_payload_value;
+      if (is_final_operation) {
+        ++completion_payload_value_;
+      }
+    }
+
+    *out_completion = {completion_semaphore_, completion_payload_value_};
+    return iree_ok_status();
+  }
+
+  iree_status_t HostCallBatchAndWait(iree_hal_host_call_flags_t flags,
+                                     int64_t batch_count) {
+    SubmittedCompletion completion;
+    IREE_RETURN_IF_ERROR(HostCallBatchSubmit(flags, batch_count, &completion));
+    return Wait(completion.semaphore, completion.payload_value);
+  }
+
+  iree_status_t SameQueueBarrierBatchSubmit(
+      int64_t batch_count, SubmittedCompletion* out_completion) {
+    uint64_t payload_value = completion_payload_value_;
+    for (int64_t i = 0; i < batch_count; ++i) {
+      const uint64_t wait_payload_value = payload_value;
+      const uint64_t signal_payload_value = payload_value + 1;
+      IREE_RETURN_IF_ERROR(SubmitBarrier(
+          kQueue0, i == 0 ? nullptr : completion_semaphore_, wait_payload_value,
+          completion_semaphore_, signal_payload_value));
+      payload_value = signal_payload_value;
+    }
+    completion_payload_value_ = payload_value;
+    *out_completion = {completion_semaphore_, payload_value};
+    return iree_ok_status();
+  }
+
+  iree_status_t SameQueueBarrierBatchAndWait(int64_t batch_count) {
+    SubmittedCompletion completion;
+    IREE_RETURN_IF_ERROR(SameQueueBarrierBatchSubmit(batch_count, &completion));
+    return Wait(completion.semaphore, completion.payload_value);
+  }
+
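+  // Chains |batch_count| barriers on queue 0 through the private stream
+  // semaphore only; the returned completion is the last stream epoch.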
+  iree_status_t SameQueueEpochChainSubmit(int64_t batch_count,
+                                          SubmittedCompletion* out_completion) {
+    for (int64_t i = 0; i < batch_count; ++i) {
+      const uint64_t wait_payload_value = stream_payload_value_;
+      const uint64_t signal_payload_value = stream_payload_value_ + 1;
+      IREE_RETURN_IF_ERROR(SubmitBarrier(
+          kQueue0, i == 0 ? nullptr : stream_semaphore_, wait_payload_value,
+          stream_semaphore_, signal_payload_value));
+      stream_payload_value_ = signal_payload_value;
+    }
+    *out_completion = {stream_semaphore_, stream_payload_value_};
+    return iree_ok_status();
+  }
+
+  iree_status_t SameQueueEpochChainAndWait(int64_t batch_count) {
+    SubmittedCompletion completion;
+    IREE_RETURN_IF_ERROR(SameQueueEpochChainSubmit(batch_count, &completion));
+    return Wait(completion.semaphore, completion.payload_value);
+  }
+
+  iree_status_t CrossQueueAlreadyCompletedWaitAndSignal() {
+    const uint64_t completion_payload_value = ++completion_payload_value_;
+    IREE_RETURN_IF_ERROR(
+        SubmitBarrier(kQueue1, producer_semaphore_, producer_payload_value_,
+                      completion_semaphore_, completion_payload_value));
+    return Wait(completion_semaphore_, completion_payload_value);
+  }
+
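+  // Signals the producer semaphore to payload 1 and waits, so the
+  // already-completed cross-queue wait benchmark above starts with a
+  // satisfied wait.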
+  iree_status_t PrimeProducerSemaphore() {
+    producer_payload_value_ = 1;
+    IREE_RETURN_IF_ERROR(SubmitBarrier(
+        kQueue0, /*wait_semaphore=*/nullptr, /*wait_payload_value=*/0,
+        producer_semaphore_, producer_payload_value_));
+    return Wait(producer_semaphore_, producer_payload_value_);
+  }
+
+  iree_status_t CrossQueueBarrierValueAndWait() {
+    const uint64_t producer_payload_value = ++producer_payload_value_;
+    const uint64_t completion_payload_value = ++completion_payload_value_;
+    IREE_RETURN_IF_ERROR(SubmitBarrier(
+        kQueue0, /*wait_semaphore=*/nullptr, /*wait_payload_value=*/0,
+        producer_semaphore_, producer_payload_value));
+    IREE_RETURN_IF_ERROR(
+        SubmitBarrier(kQueue1, producer_semaphore_, producer_payload_value,
+                      completion_semaphore_, completion_payload_value));
+    return Wait(completion_semaphore_, completion_payload_value);
+  }
+
+  iree_status_t CrossQueueBarrierValueBatchSubmit(
+      int64_t batch_count, SubmittedCompletion* out_completion) {
+    uint64_t completion_payload_value = completion_payload_value_;
+    for (int64_t i = 0; i < batch_count; ++i) {
+      const uint64_t producer_payload_value = ++producer_payload_value_;
+      IREE_RETURN_IF_ERROR(SubmitBarrier(
+          kQueue0, /*wait_semaphore=*/nullptr, /*wait_payload_value=*/0,
+          producer_semaphore_, producer_payload_value));
+
+      iree_hal_semaphore_t* wait_semaphores[2] = {producer_semaphore_, nullptr};
+      uint64_t wait_payload_values[2] = {producer_payload_value, 0};
+      iree_host_size_t wait_semaphore_count = 1;
+      if (i > 0) {
+        wait_semaphores[wait_semaphore_count] = completion_semaphore_;
+        wait_payload_values[wait_semaphore_count] = completion_payload_value;
+        ++wait_semaphore_count;
+      }
+      iree_hal_semaphore_list_t wait_semaphore_list = {
+          /*count=*/wait_semaphore_count,
+          /*semaphores=*/wait_semaphores,
+          /*payload_values=*/wait_payload_values,
+      };
+
+      const uint64_t signal_completion_payload_value =
+          completion_payload_value + 1;
+      IREE_RETURN_IF_ERROR(SubmitBarrierWithWaitList(
+          kQueue1, wait_semaphore_list, completion_semaphore_,
+          signal_completion_payload_value));
+      completion_payload_value = signal_completion_payload_value;
+    }
+    completion_payload_value_ = completion_payload_value;
+    *out_completion = {completion_semaphore_, completion_payload_value};
+    return iree_ok_status();
+  }
+
+  iree_status_t CrossQueueBarrierValueBatchAndWait(int64_t batch_count) {
+    SubmittedCompletion completion;
+    IREE_RETURN_IF_ERROR(
+        CrossQueueBarrierValueBatchSubmit(batch_count, &completion));
+    return Wait(completion.semaphore, completion.payload_value);
+  }
+
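+  // Bounces a dependency chain between queue 0 and queue 1 |handoff_count|
+  // times, alternating the producer and stream semaphores, and returns
+  // whichever (semaphore, payload) pair terminates the chain.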
+  iree_status_t CrossQueuePingPongChainSubmit(
+      int64_t handoff_count, SubmittedCompletion* out_completion) {
+    uint64_t producer_payload_value = ++producer_payload_value_;
+    IREE_RETURN_IF_ERROR(SubmitBarrier(
+        kQueue0, /*wait_semaphore=*/nullptr, /*wait_payload_value=*/0,
+        producer_semaphore_, producer_payload_value));
+
+    iree_hal_semaphore_t* final_semaphore = producer_semaphore_;
+    uint64_t final_payload_value = producer_payload_value;
+    for (int64_t i = 0; i < handoff_count; ++i) {
+      if ((i & 1) == 0) {
+        const uint64_t stream_payload_value = ++stream_payload_value_;
+        IREE_RETURN_IF_ERROR(
+            SubmitBarrier(kQueue1, producer_semaphore_, producer_payload_value,
+                          stream_semaphore_, stream_payload_value));
+        final_semaphore = stream_semaphore_;
+        final_payload_value = stream_payload_value;
+      } else {
+        producer_payload_value = ++producer_payload_value_;
+        IREE_RETURN_IF_ERROR(
+            SubmitBarrier(kQueue0, stream_semaphore_, stream_payload_value_,
+                          producer_semaphore_, producer_payload_value));
+        final_semaphore = producer_semaphore_;
+        final_payload_value = producer_payload_value;
+      }
+    }
+    *out_completion = {final_semaphore, final_payload_value};
+    return iree_ok_status();
+  }
+
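+  // Same ping-pong chain, but the final hop signals the public completion
+  // semaphore inline from its own signal list instead of paying for a
+  // separate trailing barrier.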
+  iree_status_t CrossQueuePingPongChainSubmitPublicFinalInline(
+      int64_t handoff_count, SubmittedCompletion* out_completion) {
+    if (handoff_count == 0) {
+      SubmittedCompletion private_completion;
+      IREE_RETURN_IF_ERROR(
+          CrossQueuePingPongChainSubmit(handoff_count, &private_completion));
+      const uint64_t completion_payload_value = ++completion_payload_value_;
+      IREE_RETURN_IF_ERROR(SubmitBarrier(
+          CrossQueuePingPongFinalQueue(handoff_count),
+          private_completion.semaphore, private_completion.payload_value,
+          completion_semaphore_, completion_payload_value));
+      *out_completion = {completion_semaphore_, completion_payload_value};
+      return iree_ok_status();
+    }
+
+    uint64_t producer_payload_value = ++producer_payload_value_;
+    IREE_RETURN_IF_ERROR(SubmitBarrier(
+        kQueue0, /*wait_semaphore=*/nullptr, /*wait_payload_value=*/0,
+        producer_semaphore_, producer_payload_value));
+
+    for (int64_t i = 0; i < handoff_count; ++i) {
+      const bool is_final_handoff = i + 1 == handoff_count;
+      if ((i & 1) == 0) {
+        const uint64_t stream_payload_value = ++stream_payload_value_;
+        iree_hal_semaphore_t* signal_semaphores[2] = {
+            stream_semaphore_,
+            completion_semaphore_,
+        };
+        uint64_t signal_payload_values[2] = {
+            stream_payload_value,
+            completion_payload_value_ + 1,
+        };
+        iree_hal_semaphore_list_t signal_semaphore_list = {
+            /*count=*/is_final_handoff ? 2u : 1u,
+            /*semaphores=*/signal_semaphores,
+            /*payload_values=*/signal_payload_values,
+        };
+        IREE_RETURN_IF_ERROR(SubmitBarrierWithSingleWaitAndSignalList(
+            kQueue1, producer_semaphore_, producer_payload_value,
+            signal_semaphore_list));
+        if (is_final_handoff) {
+          ++completion_payload_value_;
+        }
+      } else {
+        producer_payload_value = ++producer_payload_value_;
+        iree_hal_semaphore_t* signal_semaphores[2] = {
+            producer_semaphore_,
+            completion_semaphore_,
+        };
+        uint64_t signal_payload_values[2] = {
+            producer_payload_value,
+            completion_payload_value_ + 1,
+        };
+        iree_hal_semaphore_list_t signal_semaphore_list = {
+            /*count=*/is_final_handoff ? 2u : 1u,
+            /*semaphores=*/signal_semaphores,
+            /*payload_values=*/signal_payload_values,
+        };
+        IREE_RETURN_IF_ERROR(SubmitBarrierWithSingleWaitAndSignalList(
+            kQueue0, stream_semaphore_, stream_payload_value_,
+            signal_semaphore_list));
+        if (is_final_handoff) {
+          ++completion_payload_value_;
+        }
+      }
+    }
+    *out_completion = {completion_semaphore_, completion_payload_value_};
+    return iree_ok_status();
+  }
+
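+  // Variant that appends a dedicated barrier to forward the private chain
+  // completion onto the public completion semaphore, for comparison with
+  // the inline-final path above.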
+  iree_status_t CrossQueuePingPongChainSubmitPublicFinalSeparate(
+      int64_t handoff_count, SubmittedCompletion* out_completion) {
+    SubmittedCompletion private_completion;
+    IREE_RETURN_IF_ERROR(
+        CrossQueuePingPongChainSubmit(handoff_count, &private_completion));
+    const uint64_t completion_payload_value = ++completion_payload_value_;
+    IREE_RETURN_IF_ERROR(SubmitBarrier(
+        CrossQueuePingPongFinalQueue(handoff_count),
+        private_completion.semaphore, private_completion.payload_value,
+        completion_semaphore_, completion_payload_value));
+    *out_completion = {completion_semaphore_, completion_payload_value};
+    return iree_ok_status();
+  }
+
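+  // Ping-pong variant that submits a real payload operation per handoff
+  // instead of a bare barrier, signaling the public completion semaphore
+  // inline on the final hop.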
+  iree_status_t CrossQueuePingPongPayloadSubmitPublicFinalInline(
+      PayloadKind payload_kind, int64_t handoff_count,
+      SubmittedCompletion* out_completion) {
+    uint64_t producer_payload_value = ++producer_payload_value_;
+    IREE_RETURN_IF_ERROR(SubmitBarrier(
+        kQueue0, /*wait_semaphore=*/nullptr, /*wait_payload_value=*/0,
+        producer_semaphore_, producer_payload_value));
+
+    for (int64_t i = 0; i < handoff_count; ++i) {
+      const bool is_final_handoff = i + 1 == handoff_count;
+      iree_hal_semaphore_t* signal_semaphores[2] = {
+          nullptr,
+          completion_semaphore_,
+      };
+      uint64_t signal_payload_values[2] = {
+          0,
+          completion_payload_value_ + 1,
+      };
+      iree_hal_semaphore_list_t signal_semaphore_list = {
+          /*count=*/is_final_handoff ? 2u : 1u,
+          /*semaphores=*/signal_semaphores,
+          /*payload_values=*/signal_payload_values,
+      };
+
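+      // The signal list references the arrays above; slot [0] is filled in
+      // by whichever branch below submits the operation.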
+      if ((i & 1) == 0) {
+        const uint64_t stream_payload_value = ++stream_payload_value_;
+        iree_hal_semaphore_t* wait_semaphore = producer_semaphore_;
+        iree_hal_semaphore_list_t wait_semaphore_list = {
+            /*count=*/1,
+            /*semaphores=*/&wait_semaphore,
+            /*payload_values=*/&producer_payload_value,
+        };
+        signal_semaphores[0] = stream_semaphore_;
+        signal_payload_values[0] = stream_payload_value;
+        IREE_RETURN_IF_ERROR(SubmitPayloadWithLists(
+            payload_kind, kQueue1, wait_semaphore_list, signal_semaphore_list));
+      } else {
+        producer_payload_value = ++producer_payload_value_;
+        iree_hal_semaphore_t* wait_semaphore = stream_semaphore_;
+        iree_hal_semaphore_list_t wait_semaphore_list = {
+            /*count=*/1,
+            /*semaphores=*/&wait_semaphore,
+            /*payload_values=*/&stream_payload_value_,
+        };
+        signal_semaphores[0] = producer_semaphore_;
+        signal_payload_values[0] = producer_payload_value;
+        IREE_RETURN_IF_ERROR(SubmitPayloadWithLists(
+            payload_kind, kQueue0, wait_semaphore_list, signal_semaphore_list));
+      }
+      if (is_final_handoff) {
+        ++completion_payload_value_;
+      }
+    }
+
+    *out_completion = {completion_semaphore_, completion_payload_value_};
+    return iree_ok_status();
+  }
+
+  iree_status_t CrossQueuePingPongPayloadPublicFinalInlineAndWait(
+      PayloadKind payload_kind, int64_t handoff_count) {
+    SubmittedCompletion completion;
+    IREE_RETURN_IF_ERROR(CrossQueuePingPongPayloadSubmitPublicFinalInline(
+        payload_kind, handoff_count, &completion));
+    return Wait(completion.semaphore, completion.payload_value);
+  }
+
+  iree_status_t SameQueuePrivateStreamPayloadSubmitPublicFinalInline(
+      PayloadKind payload_kind, int64_t operation_count,
+      SubmittedCompletion* out_completion) {
+    for (int64_t i = 0; i < operation_count; ++i) {
+      const bool is_final_operation = i + 1 == operation_count;
+      uint64_t wait_payload_value = stream_payload_value_;
+      uint64_t signal_payload_value = stream_payload_value_ + 1;
+      iree_hal_semaphore_t* wait_semaphore = stream_semaphore_;
+      iree_hal_semaphore_list_t wait_semaphore_list =
+          iree_hal_semaphore_list_empty();
+      if (i > 0) {
+        wait_semaphore_list = iree_hal_semaphore_list_t{
+            /*count=*/1,
+            /*semaphores=*/&wait_semaphore,
+            /*payload_values=*/&wait_payload_value,
+        };
+      }
+      iree_hal_semaphore_t* signal_semaphores[2] = {
+          stream_semaphore_,
+          completion_semaphore_,
+      };
+      uint64_t signal_payload_values[2] = {
+          signal_payload_value,
+          completion_payload_value_ + 1,
+      };
+      iree_hal_semaphore_list_t signal_semaphore_list = {
+          /*count=*/is_final_operation ? 2u : 1u,
+          /*semaphores=*/signal_semaphores,
+          /*payload_values=*/signal_payload_values,
+      };
+      IREE_RETURN_IF_ERROR(SubmitPayloadWithLists(
+          payload_kind, kQueue0, wait_semaphore_list, signal_semaphore_list));
+      stream_payload_value_ = signal_payload_value;
+      if (is_final_operation) {
+        ++completion_payload_value_;
+      }
+    }
+
+    *out_completion = {completion_semaphore_, completion_payload_value_};
+    return iree_ok_status();
+  }
+
+  iree_status_t SameQueuePrivateStreamPayloadPublicFinalInlineAndWait(
+      PayloadKind payload_kind, int64_t operation_count) {
+    SubmittedCompletion completion;
+    IREE_RETURN_IF_ERROR(SameQueuePrivateStreamPayloadSubmitPublicFinalInline(
+        payload_kind, operation_count, &completion));
+    return Wait(completion.semaphore, completion.payload_value);
+  }
+
+  iree_status_t CrossQueuePingPongChainAndWait(int64_t handoff_count) {
+    SubmittedCompletion completion;
+    IREE_RETURN_IF_ERROR(
+        CrossQueuePingPongChainSubmit(handoff_count, &completion));
+    return Wait(completion.semaphore, completion.payload_value);
+  }
+
+  iree_status_t CrossQueuePingPongChainPublicFinalInlineAndWait(
+      int64_t handoff_count) {
+    SubmittedCompletion completion;
+    IREE_RETURN_IF_ERROR(CrossQueuePingPongChainSubmitPublicFinalInline(
+        handoff_count, &completion));
+    return Wait(completion.semaphore, completion.payload_value);
+  }
+
+  iree_status_t CrossQueuePingPongChainPublicFinalSeparateAndWait(
+      int64_t handoff_count) {
+    SubmittedCompletion completion;
+    IREE_RETURN_IF_ERROR(CrossQueuePingPongChainSubmitPublicFinalSeparate(
+        handoff_count, &completion));
+    return Wait(completion.semaphore, completion.payload_value);
+  }
+
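+  // Intentionally submits the waiting barrier on queue 1 before the
+  // producer signal on queue 0 to exercise the wait-before-signal path.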
+  iree_status_t WaitBeforeSignalChainAndWait() {
+    const uint64_t producer_payload_value = ++producer_payload_value_;
+    const uint64_t completion_payload_value = ++completion_payload_value_;
+    IREE_RETURN_IF_ERROR(
+        SubmitBarrier(kQueue1, producer_semaphore_, producer_payload_value,
+                      completion_semaphore_, completion_payload_value));
+    IREE_RETURN_IF_ERROR(SubmitBarrier(
+        kQueue0, /*wait_semaphore=*/nullptr, /*wait_payload_value=*/0,
+        producer_semaphore_, producer_payload_value));
+    return Wait(completion_semaphore_, completion_payload_value);
+  }
+
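+  // Returns true when |status| is OK; otherwise prints and frees the
+  // status and marks the benchmark as skipped.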
+  bool HandleStatus(benchmark::State& state, iree_status_t status,
+                    const char* message) {
+    if (iree_status_is_ok(status)) return true;
+    iree_status_fprint(stderr, status);
+    iree_status_free(status);
+    state.SkipWithError(message);
+    return false;
+  }
+
+  void SetQueueSubmissionsProcessed(benchmark::State& state,
+                                    int64_t queue_submissions_per_sync) {
+    state.counters["queue_submissions_per_sync"] =
+        static_cast<double>(queue_submissions_per_sync);
+    iree_hal_amdgpu_benchmark_set_completion_wait_counters(state);
+    state.SetItemsProcessed(state.iterations() * queue_submissions_per_sync);
+  }
+
+  void SetCrossQueuePingPongCounters(
+      benchmark::State& state, int64_t handoff_count,
+      int64_t queue_submissions_per_sync,
+      int64_t public_completion_signals_per_sync) {
+    state.counters["cross_queue_handoffs_per_sync"] =
+        static_cast<double>(handoff_count);
+    state.counters["hip_equivalent_round_trips_per_sync"] =
+        static_cast<double>(handoff_count) / 2.0;
+    state.counters["public_completion_signals_per_sync"] =
+        static_cast<double>(public_completion_signals_per_sync);
+    SetQueueSubmissionsProcessed(state, queue_submissions_per_sync);
+  }
+
+  void SetPayloadPingPongCounters(benchmark::State& state,
+                                  int64_t handoff_count,
+                                  int64_t queue_submissions_per_sync,
+                                  int64_t public_completion_signals_per_sync) {
+    SetCrossQueuePingPongCounters(state, handoff_count,
+                                  queue_submissions_per_sync,
+                                  public_completion_signals_per_sync);
+    state.counters["payload_operations_per_sync"] =
+        static_cast<double>(handoff_count);
+  }
+
+  void SetSingleStreamPayloadCounters(benchmark::State& state,
+                                      int64_t operation_count) {
+    state.counters["operations_per_sync"] =
+        static_cast<double>(operation_count);
+    state.counters["public_completion_signals_per_sync"] = 1.0;
+    SetQueueSubmissionsProcessed(state, operation_count);
+  }
+
+  void SetQueueAllocaCounters(benchmark::State& state,
+                              iree_device_size_t allocation_size,
+                              QueueAllocaTlsfGrowthMode growth_mode) {
+    state.counters["queue_allocas_per_sync"] = 1.0;
+    state.counters["allocation_bytes_per_sync"] =
+        static_cast<double>(allocation_size);
+    state.counters["forced_tlsf_growth"] =
+        growth_mode == QueueAllocaTlsfGrowthMode::kForcedGrowth ? 1.0 : 0.0;
+    SetQueueSubmissionsProcessed(state, /*queue_submissions_per_sync=*/1);
+  }
+
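+  // Records the enabled profiling data families as 0/1 counters, plus the
+  // profiled operation count, so per-family overhead can be compared
+  // across guardrail modes.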
+  void SetProfileGuardrailCounters(benchmark::State& state,
+                                   ProfileGuardrailMode mode,
+                                   int64_t queue_submissions_per_sync,
+                                   int64_t profiled_operations_per_sync) {
+    const iree_hal_device_profiling_data_families_t data_families =
+        ProfileGuardrailDataFamilies(mode);
+    state.counters["profile_data_families"] =
+        static_cast<double>(data_families);
+    state.counters["profile_queue_events"] =
+        iree_any_bit_set(data_families,
+                         IREE_HAL_DEVICE_PROFILING_DATA_QUEUE_EVENTS)
+            ? 1.0
+            : 0.0;
+    state.counters["profile_device_queue_events"] =
+        iree_any_bit_set(data_families,
+                         IREE_HAL_DEVICE_PROFILING_DATA_DEVICE_QUEUE_EVENTS)
+            ? 1.0
+            : 0.0;
+    state.counters["profile_dispatch_events"] =
+        iree_any_bit_set(data_families,
+                         IREE_HAL_DEVICE_PROFILING_DATA_DISPATCH_EVENTS)
+            ? 1.0
+            : 0.0;
+    state.counters["profiled_operations_per_sync"] =
+        static_cast<double>(profiled_operations_per_sync);
+    SetQueueSubmissionsProcessed(state, queue_submissions_per_sync);
+  }
+
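+  // Waits on a submitted completion with benchmark timing paused so the
+  // wait time is excluded from the measured interval.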
+  bool WaitWithTimingPaused(benchmark::State& state,
+                            const SubmittedCompletion& completion,
+                            const char* message) {
+    state.PauseTiming();
+    iree_status_t status = Wait(completion.semaphore, completion.payload_value);
+    state.ResumeTiming();
+    return HandleStatus(state, status, message);
+  }
+
+  bool EnsurePayloadBuffers(benchmark::State& state) {
+    if (source_buffer_ && target_buffer_) return true;
+    return AllocatePayloadBuffers(state);
+  }
+
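+  // Creates an executable cache, infers the format of the provided binary,
+  // and prepares an executable that aliases the caller-owned data.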
+  iree_status_t LoadExecutableFromData(
+      iree_const_byte_span_t executable_data,
+      iree_hal_executable_cache_t** out_executable_cache,
+      iree_hal_executable_t** out_executable) {
+    *out_executable_cache = nullptr;
+    *out_executable = nullptr;
+
+    iree_hal_executable_cache_t* executable_cache = nullptr;
+    iree_hal_executable_t* executable = nullptr;
+    iree_status_t status = iree_hal_executable_cache_create(
+        device_, iree_make_cstring_view("default"), &executable_cache);
+
+    char executable_format[128] = {0};
+    iree_host_size_t inferred_size = 0;
+    if (iree_status_is_ok(status)) {
+      status = iree_hal_executable_cache_infer_format(
+          executable_cache,
+          IREE_HAL_EXECUTABLE_CACHING_MODE_ALIAS_PROVIDED_DATA, executable_data,
+          IREE_ARRAYSIZE(executable_format), executable_format, &inferred_size);
+    }
+
+    if (iree_status_is_ok(status)) {
+      iree_hal_executable_params_t executable_params;
+      iree_hal_executable_params_initialize(&executable_params);
+      executable_params.caching_mode =
+          IREE_HAL_EXECUTABLE_CACHING_MODE_ALIAS_PROVIDED_DATA;
+      executable_params.executable_format =
+          iree_make_cstring_view(executable_format);
+      executable_params.executable_data = executable_data;
+      status = iree_hal_executable_cache_prepare_executable(
+          executable_cache, &executable_params, &executable);
+    }
+
+    if (iree_status_is_ok(status)) {
+      *out_executable_cache = executable_cache;
+      *out_executable = executable;
+    } else {
+      iree_hal_executable_release(executable);
+      iree_hal_executable_cache_release(executable_cache);
+    }
+    return status;
+  }
+
+  bool EnsureDispatchExecutable(benchmark::State& state) {
+    if (dispatch_executable_) return true;
+
+    iree_status_t status = iree_ok_status();
+    iree_const_byte_span_t executable_data =
+        FindCtsExecutableData(iree_make_cstring_view(
+            "command_buffer_dispatch_constants_bindings_test.bin"));
+    if (executable_data.data_length == 0) {
+      status = iree_make_status(IREE_STATUS_NOT_FOUND,
+                                "AMDGPU CTS dispatch executable not found");
+    }
+
+    if (iree_status_is_ok(status)) {
+      status = LoadExecutableFromData(
+          executable_data, &dispatch_executable_cache_, &dispatch_executable_);
+    }
+    return HandleStatus(state, status, "failed to load dispatch executable");
+  }
+
+  bool EnsureBindingCountExecutable(benchmark::State& state) {
+    if (binding_count_executable_) return true;
+
+    const iree_string_view_t executable_file =
+        iree_make_cstring_view(FLAG_binding_count_executable_file);
+    iree_io_file_contents_t* executable_file_contents = nullptr;
+    iree_const_byte_span_t executable_data = iree_const_byte_span_empty();
+    iree_status_t status = iree_ok_status();
+    if (!iree_string_view_is_empty(executable_file)) {
+      status = iree_io_file_contents_read(executable_file, host_allocator_,
+                                          &executable_file_contents);
+      if (iree_status_is_ok(status)) {
+        executable_data = executable_file_contents->const_buffer;
+      }
+    } else {
+      executable_data = FindQueueBenchmarkExecutableData(
+          iree_make_cstring_view("queue_benchmark_testdata.bin"));
+      if (executable_data.data_length == 0) {
+        status = iree_make_status(
+            IREE_STATUS_NOT_FOUND,
+            "AMDGPU queue benchmark dispatch executable not found");
+      }
+    }
+    if (iree_status_is_ok(status)) {
+      status = LoadExecutableFromData(executable_data,
+                                      &binding_count_executable_cache_,
+                                      &binding_count_executable_);
+    }
+    if (iree_status_is_ok(status)) {
+      binding_count_executable_file_contents_ = executable_file_contents;
+    } else {
+      iree_io_file_contents_free(executable_file_contents);
+    }
+    return HandleStatus(state, status,
+                        "failed to load binding-count dispatch executable");
+  }
+
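+  // Lazily allocates up to |binding_count| minimal device buffers (one
+  // alignment-sized allocation each) used only as dispatch bindings.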
+  bool EnsureBindingCountBuffers(benchmark::State& state,
+                                 int64_t binding_count) {
+    if (binding_count < 0 ||
+        binding_count > (int64_t)kDispatchBindingBenchmarkMaxCount) {
+      return HandleStatus(
+          state,
+          iree_make_status(
+              IREE_STATUS_OUT_OF_RANGE,
+              "unsupported dispatch benchmark binding count %" PRId64,
+              binding_count),
+          "invalid dispatch benchmark binding count");
+    }
+
+    iree_hal_allocator_t* allocator = iree_hal_device_allocator(device_);
+    iree_hal_buffer_params_t params = {0};
+    params.usage = IREE_HAL_BUFFER_USAGE_DISPATCH_STORAGE;
+    params.type = IREE_HAL_MEMORY_TYPE_OPTIMAL_FOR_DEVICE;
+    params.min_alignment = kPayloadBufferAlignment;
+    for (iree_host_size_t i = 0; i < (iree_host_size_t)binding_count; ++i) {
+      if (binding_count_buffers_[i]) continue;
+      iree_status_t status = iree_hal_allocator_allocate_buffer(
+          allocator, params, kPayloadBufferAlignment,
+          &binding_count_buffers_[i]);
+      if (!HandleStatus(state, status,
+                        "failed to allocate dispatch benchmark binding")) {
+        return false;
+      }
+    }
+    return true;
+  }
+
+  bool EnsurePreResolvedDispatch(benchmark::State& state) {
+    if (pre_resolved_dispatch_kernargs_) return true;
+    if (!EnsurePayloadBuffers(state) || !EnsureDispatchExecutable(state)) {
+      return false;
+    }
+    return HandleStatus(state, PreparePreResolvedDispatch(),
+                        "failed to prepare pre-resolved dispatch");
+  }
+
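+  // Resolves the dispatch descriptor, binding device pointers, kernarg
+  // storage, and AQL packet template once up front so the benchmark loop
+  // can submit without per-iteration resolution work.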
+  iree_status_t PreparePreResolvedDispatch() {
+    iree_hal_amdgpu_host_queue_t* host_queue = nullptr;
+    IREE_RETURN_IF_ERROR(LookupHostQueue(kQueue0, &host_queue));
+
+    const iree_hal_amdgpu_executable_dispatch_descriptor_t* descriptor =
+        nullptr;
+    IREE_RETURN_IF_ERROR(
+        iree_hal_amdgpu_executable_lookup_dispatch_descriptor_for_device(
+            dispatch_executable_, /*export_ordinal=*/0,
+            host_queue->device_ordinal, &descriptor));
+    if (IREE_UNLIKELY(!descriptor)) {
+      return iree_make_status(
+          IREE_STATUS_FAILED_PRECONDITION,
+          "dispatch executable has no descriptor for device ordinal "
+          "%" PRIhsz,
+          host_queue->device_ordinal);
+    }
+    const iree_hal_amdgpu_device_kernel_args_t* kernel_args =
+        &descriptor->kernel_args;
+
+    uint64_t binding_ptrs[2] = {0, 0};
+    IREE_RETURN_IF_ERROR(
+        ResolveBufferDevicePointer(source_buffer_, &binding_ptrs[0]));
+    IREE_RETURN_IF_ERROR(
+        ResolveBufferDevicePointer(target_buffer_, &binding_ptrs[1]));
+
+    const uint32_t workgroup_count[3] = {1, 1, 1};
+    const uint32_t dynamic_workgroup_local_memory = 0;
+    const uint32_t kernarg_block_count = descriptor->hal_kernarg_block_count;
+
+    iree_host_size_t kernarg_length = 0;
+    if (IREE_UNLIKELY(!iree_host_size_checked_mul(
+            kernarg_block_count, sizeof(iree_hal_amdgpu_kernarg_block_t),
+            &kernarg_length))) {
+      return iree_make_status(IREE_STATUS_OUT_OF_RANGE,
+                              "pre-resolved dispatch kernarg storage "
+                              "overflows host size");
+    }
+
+    uint8_t* kernargs = nullptr;
+    IREE_RETURN_IF_ERROR(iree_allocator_malloc(host_allocator_, kernarg_length,
+                                               (void**)&kernargs));
+    std::memset(kernargs, 0, kernarg_length);
+
+    const uint32_t constant_data[] = {3, 10};
+    iree_hal_amdgpu_device_dispatch_emplace_hal_kernargs(
+        kernel_args, workgroup_count, dynamic_workgroup_local_memory,
+        &descriptor->hal_kernarg_layout, binding_ptrs, constant_data, kernargs);
+    std::memset(&pre_resolved_dispatch_packet_template_, 0,
+                sizeof(pre_resolved_dispatch_packet_template_));
+    iree_hal_amdgpu_device_dispatch_emplace_packet(
+        kernel_args, workgroup_count, dynamic_workgroup_local_memory,
+        &pre_resolved_dispatch_packet_template_, /*kernarg_ptr=*/nullptr);
+
+    pre_resolved_dispatch_kernargs_ = kernargs;
+    pre_resolved_dispatch_kernarg_length_ = kernarg_length;
+    pre_resolved_dispatch_kernarg_block_count_ = (uint32_t)kernarg_block_count;
+    return iree_ok_status();
+  }
+
+  void ReleasePreResolvedDispatch() {
+    iree_allocator_free(host_allocator_, pre_resolved_dispatch_kernargs_);
+    pre_resolved_dispatch_kernargs_ = nullptr;
+    pre_resolved_dispatch_kernarg_length_ = 0;
+    pre_resolved_dispatch_kernarg_block_count_ = 0;
+    std::memset(&pre_resolved_dispatch_packet_template_, 0,
+                sizeof(pre_resolved_dispatch_packet_template_));
+  }
+
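+  // Submits the pre-resolved dispatch directly against the host queue
+  // internals under the submission mutex, failing fast if the waits would
+  // need deferral or the queue is temporarily at capacity.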
+  iree_status_t SubmitPreResolvedDispatchWithLists(
+      iree_hal_queue_affinity_t queue_affinity,
+      iree_hal_semaphore_list_t wait_semaphore_list,
+      iree_hal_semaphore_list_t signal_semaphore_list) {
+    iree_hal_amdgpu_host_queue_t* host_queue = nullptr;
+    IREE_RETURN_IF_ERROR(LookupHostQueue(queue_affinity, &host_queue));
+
+    iree_slim_mutex_lock(&host_queue->locks.submission_mutex);
+    iree_hal_amdgpu_wait_resolution_t resolution;
+    iree_hal_amdgpu_host_queue_resolve_waits(host_queue, wait_semaphore_list,
+                                             &resolution);
+    iree_status_t status = iree_ok_status();
+    if (IREE_UNLIKELY(resolution.needs_deferral)) {
+      status = iree_make_status(
+          IREE_STATUS_FAILED_PRECONDITION,
+          "pre-resolved dispatch benchmark path cannot defer waits");
+    }
+
+    if (iree_status_is_ok(status)) {
+      iree_hal_resource_t* operation_resources[3] = {
+          (iree_hal_resource_t*)dispatch_executable_,
+          (iree_hal_resource_t*)source_buffer_,
+          (iree_hal_resource_t*)target_buffer_,
+      };
+      iree_hal_amdgpu_host_queue_dispatch_submission_t submission;
+      bool ready = false;
+      status = iree_hal_amdgpu_host_queue_try_begin_dispatch_submission(
+          host_queue, &resolution, signal_semaphore_list,
+          IREE_ARRAYSIZE(operation_resources),
+          pre_resolved_dispatch_kernarg_block_count_,
+          iree_hal_amdgpu_profile_dispatch_event_reservation_t{0},
+          /*profile_queue_event_info=*/nullptr, &ready, &submission);
+      if (iree_status_is_ok(status) && !ready) {
+        status = iree_make_status(
+            IREE_STATUS_RESOURCE_EXHAUSTED,
+            "pre-resolved dispatch benchmark path hit temporary queue "
+            "capacity");
+      }
+      if (iree_status_is_ok(status)) {
+        std::memcpy(submission.kernel.kernargs.blocks->data,
+                    pre_resolved_dispatch_kernargs_,
+                    pre_resolved_dispatch_kernarg_length_);
+        submission.dispatch_setup =
+            iree_hal_amdgpu_host_queue_write_dispatch_packet_body(
+                &submission.dispatch_slot->dispatch,
+                &pre_resolved_dispatch_packet_template_,
+                submission.kernel.kernargs.blocks->data,
+                submission.dispatch_completion_signal);
+        iree_hal_amdgpu_host_queue_finish_dispatch_submission(
+            host_queue, &resolution, signal_semaphore_list, operation_resources,
+            IREE_ARRAYSIZE(operation_resources),
+            /*profile_queue_event_info=*/nullptr,
+            IREE_HAL_AMDGPU_HOST_QUEUE_SUBMISSION_FLAG_RETAIN_RESOURCES,
+            &submission);
+      }
+    }
+    iree_slim_mutex_unlock(&host_queue->locks.submission_mutex);
+    return status;
+  }
+
+  iree_status_t ValidateDispatchOnce(
+      iree_hal_amdgpu_host_queue_t* host_queue,
+      iree_host_size_t* out_operation_resource_count) {
+    const uint32_t constant_data[] = {3, 10};
+    iree_const_byte_span_t constants =
+        iree_make_const_byte_span(constant_data, sizeof(constant_data));
+    iree_hal_buffer_ref_t binding_refs[2] = {
+        iree_hal_make_buffer_ref(source_buffer_, /*offset=*/0,
+                                 iree_hal_buffer_byte_length(source_buffer_)),
+        iree_hal_make_buffer_ref(target_buffer_, /*offset=*/0,
+                                 iree_hal_buffer_byte_length(target_buffer_)),
+    };
+    iree_hal_buffer_ref_list_t bindings = {
+        /*count=*/IREE_ARRAYSIZE(binding_refs),
+        /*values=*/binding_refs,
+    };
+    return iree_hal_amdgpu_host_queue_validate_dispatch(
+        host_queue, dispatch_executable_, /*export_ordinal=*/0,
+        iree_hal_make_static_dispatch_config(1, 1, 1), constants, bindings,
+        IREE_HAL_DISPATCH_FLAG_NONE, out_operation_resource_count);
+  }
+
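+  // Builds a 1x1x1 dispatch config; a nonzero
+  // --binding_count_workgroup_size_x flag overrides workgroup size X after
+  // range checking.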
+  static iree_status_t BindingCountDispatchConfig(
+      iree_hal_dispatch_config_t* out_config) {
+    *out_config = iree_hal_make_static_dispatch_config(1, 1, 1);
+    if (FLAG_binding_count_workgroup_size_x == 0) {
+      return iree_ok_status();
+    }
+    if (FLAG_binding_count_workgroup_size_x < 0 ||
+        FLAG_binding_count_workgroup_size_x > UINT16_MAX) {
+      return iree_make_status(
+          IREE_STATUS_OUT_OF_RANGE,
+          "binding_count_workgroup_size_x must be in [0, %" PRIu16 "]",
+          UINT16_MAX);
+    }
+    out_config->workgroup_size[0] =
+        (uint32_t)FLAG_binding_count_workgroup_size_x;
+    out_config->workgroup_size[1] = 1;
+    out_config->workgroup_size[2] = 1;
+    return iree_ok_status();
+  }
+
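+  // Formats the export name "binding_count_<count>" for the requested
+  // binding count, bounds-checking against the benchmark maximum.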
+  static iree_status_t BindingCountExportName(int64_t binding_count,
+                                              char* buffer,
+                                              iree_host_size_t capacity,
+                                              iree_string_view_t* out_name) {
+    *out_name = iree_string_view_empty();
+    if (binding_count < 0 ||
+        binding_count > (int64_t)kDispatchBindingBenchmarkMaxCount) {
+      return iree_make_status(IREE_STATUS_OUT_OF_RANGE,
+                              "unsupported dispatch benchmark binding count "
+                              "%" PRId64,
+                              binding_count);
+    }
+    int result =
+        snprintf(buffer, capacity, "binding_count_%" PRId64, binding_count);
+    if (result < 0 || (iree_host_size_t)result >= capacity) {
+      return iree_make_status(IREE_STATUS_OUT_OF_RANGE,
+                              "binding-count export name is too long");
+    }
+    *out_name = iree_make_string_view(buffer, (iree_host_size_t)result);
+    return iree_ok_status();
+  }
+
+  iree_status_t BindingCountExportOrdinal(
+      int64_t binding_count,
+      iree_hal_executable_export_ordinal_t* out_export_ordinal) {
+    *out_export_ordinal = 0;
+    char export_name_buffer[32] = {0};
+    iree_string_view_t export_name = iree_string_view_empty();
+    IREE_RETURN_IF_ERROR(BindingCountExportName(
+        binding_count, export_name_buffer, IREE_ARRAYSIZE(export_name_buffer),
+        &export_name));
+    return iree_hal_executable_lookup_export_by_name(
+        binding_count_executable_, export_name, out_export_ordinal);
+  }
+
+  iree_hal_buffer_ref_list_t BindingCountDispatchBindings(
+      int64_t binding_count) {
+    for (iree_host_size_t i = 0; i < (iree_host_size_t)binding_count; ++i) {
+      binding_count_binding_ref_scratch_[i] = iree_hal_make_buffer_ref(
+          binding_count_buffers_[i], /*offset=*/0,
+          iree_hal_buffer_byte_length(binding_count_buffers_[i]));
+    }
+    return (iree_hal_buffer_ref_list_t){
+        /*count=*/(iree_host_size_t)binding_count,
+        /*values=*/binding_count_binding_ref_scratch_,
+    };
+  }
+
+  iree_status_t ValidateBindingCountDispatchOnce(
+      iree_hal_amdgpu_host_queue_t* host_queue, int64_t binding_count,
+      iree_host_size_t* out_operation_resource_count) {
+    iree_hal_executable_export_ordinal_t export_ordinal = 0;
+    IREE_RETURN_IF_ERROR(
+        BindingCountExportOrdinal(binding_count, &export_ordinal));
+    iree_hal_dispatch_config_t dispatch_config;
+    IREE_RETURN_IF_ERROR(BindingCountDispatchConfig(&dispatch_config));
+    return iree_hal_amdgpu_host_queue_validate_dispatch(
+        host_queue, binding_count_executable_, export_ordinal, dispatch_config,
+        iree_const_byte_span_empty(),
+        BindingCountDispatchBindings(binding_count),
+        IREE_HAL_DISPATCH_FLAG_NONE, out_operation_resource_count);
+  }
+
+  iree_status_t SubmitBindingCountDispatchWithLists(
+      iree_hal_queue_affinity_t queue_affinity,
+      iree_hal_semaphore_list_t wait_semaphore_list,
+      iree_hal_semaphore_list_t signal_semaphore_list, int64_t binding_count) {
+    iree_hal_executable_export_ordinal_t export_ordinal = 0;
+    IREE_RETURN_IF_ERROR(
+        BindingCountExportOrdinal(binding_count, &export_ordinal));
+    iree_hal_dispatch_config_t dispatch_config;
+    IREE_RETURN_IF_ERROR(BindingCountDispatchConfig(&dispatch_config));
+    return iree_hal_device_queue_dispatch(
+        device_, queue_affinity, wait_semaphore_list, signal_semaphore_list,
+        binding_count_executable_, export_ordinal, dispatch_config,
+        iree_const_byte_span_empty(),
+        BindingCountDispatchBindings(binding_count),
+        IREE_HAL_DISPATCH_FLAG_NONE);
+  }
+
+  iree_status_t RecordBindingCountCommandBuffer(
+      int64_t binding_count, iree_hal_command_buffer_t** out_command_buffer) {
+    *out_command_buffer = nullptr;
+    iree_hal_executable_export_ordinal_t export_ordinal = 0;
+    IREE_RETURN_IF_ERROR(
+        BindingCountExportOrdinal(binding_count, &export_ordinal));
+    iree_hal_dispatch_config_t dispatch_config;
+    IREE_RETURN_IF_ERROR(BindingCountDispatchConfig(&dispatch_config));
+
+    iree_hal_command_buffer_t* command_buffer = nullptr;
+    iree_status_t status = iree_hal_command_buffer_create(
+        device_, IREE_HAL_COMMAND_BUFFER_MODE_DEFAULT,
+        IREE_HAL_COMMAND_CATEGORY_DISPATCH, IREE_HAL_QUEUE_AFFINITY_ANY,
+        /*binding_capacity=*/0, &command_buffer);
+    if (iree_status_is_ok(status)) {
+      status = iree_hal_command_buffer_begin(command_buffer);
+    }
+    if (iree_status_is_ok(status)) {
+      status = iree_hal_command_buffer_dispatch(
+          command_buffer, binding_count_executable_, export_ordinal,
+          dispatch_config, iree_const_byte_span_empty(),
+          BindingCountDispatchBindings(binding_count),
+          IREE_HAL_DISPATCH_FLAG_NONE);
+    }
+    if (iree_status_is_ok(status)) {
+      status = iree_hal_command_buffer_end(command_buffer);
+    }
+    if (iree_status_is_ok(status)) {
+      *out_command_buffer = command_buffer;
+    } else {
+      iree_hal_command_buffer_release(command_buffer);
+    }
+    return status;
+  }
+
+  bool EnsureBindingCountCommandBuffer(benchmark::State& state,
+                                       int64_t binding_count) {
+    if (!EnsureBindingCountExecutable(state) ||
+        !EnsureBindingCountBuffers(state, binding_count)) {
+      return false;
+    }
+    iree_hal_command_buffer_t** command_buffer_slot =
+        &binding_count_command_buffers_[binding_count];
+    if (*command_buffer_slot) return true;
+    iree_status_t status =
+        RecordBindingCountCommandBuffer(binding_count, command_buffer_slot);
+    return HandleStatus(state, status,
+                        "failed to record binding-count command buffer");
+  }
+
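+  // Records a command buffer containing |operation_count| identical
+  // dispatches so replay cost can be measured per dispatch.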
+  iree_status_t RecordBindingCountDispatchChainCommandBuffer(
+      int64_t operation_count, int64_t binding_count,
+      iree_hal_command_buffer_t** out_command_buffer) {
+    *out_command_buffer = nullptr;
+    if (operation_count <= 0) {
+      return iree_make_status(IREE_STATUS_OUT_OF_RANGE,
+                              "operation count must be positive");
+    }
+
+    iree_hal_executable_export_ordinal_t export_ordinal = 0;
+    IREE_RETURN_IF_ERROR(
+        BindingCountExportOrdinal(binding_count, &export_ordinal));
+    iree_hal_dispatch_config_t dispatch_config;
+    IREE_RETURN_IF_ERROR(BindingCountDispatchConfig(&dispatch_config));
+
+    iree_hal_command_buffer_t* command_buffer = nullptr;
+    iree_status_t status = iree_hal_command_buffer_create(
+        device_, IREE_HAL_COMMAND_BUFFER_MODE_DEFAULT,
+        IREE_HAL_COMMAND_CATEGORY_DISPATCH, IREE_HAL_QUEUE_AFFINITY_ANY,
+        /*binding_capacity=*/0, &command_buffer);
+    if (iree_status_is_ok(status)) {
+      status = iree_hal_command_buffer_begin(command_buffer);
+    }
+    for (int64_t i = 0; i < operation_count && iree_status_is_ok(status); ++i) {
+      status = iree_hal_command_buffer_dispatch(
+          command_buffer, binding_count_executable_, export_ordinal,
+          dispatch_config, iree_const_byte_span_empty(),
+          BindingCountDispatchBindings(binding_count),
+          IREE_HAL_DISPATCH_FLAG_NONE);
+    }
+    if (iree_status_is_ok(status)) {
+      status = iree_hal_command_buffer_end(command_buffer);
+    }
+    if (iree_status_is_ok(status)) {
+      *out_command_buffer = command_buffer;
+    } else {
+      iree_hal_command_buffer_release(command_buffer);
+    }
+    return status;
+  }
+
+  void SetBindingCountCounters(benchmark::State& state, int64_t binding_count) {
+    state.counters["binding_count"] = static_cast<double>(binding_count);
+    state.counters["binding_count_external_executable"] =
+        iree_string_view_is_empty(
+            iree_make_cstring_view(FLAG_binding_count_executable_file))
+            ? 0.0
+            : 1.0;
+    state.counters["binding_count_workgroup_size_x"] =
+        static_cast<double>(FLAG_binding_count_workgroup_size_x);
+  }
+
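+  // Walks the recorded AQL program block-by-block to report packet counts,
+  // block occupancy, and kernarg placement (prepublished vs.
+  // queue-reserved) alongside the standard submission counters.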
+  void SetCommandBufferProgramCounters(
+      benchmark::State& state, iree_hal_command_buffer_t* command_buffer,
+      int64_t operation_count, int64_t binding_count) {
+    const iree_hal_amdgpu_aql_program_t* program =
+        iree_hal_amdgpu_aql_command_buffer_program(command_buffer);
+    uint64_t payload_block_count = 0;
+    uint64_t total_aql_packet_count = 0;
+    uint64_t total_block_bytes = 0;
+    uint64_t total_used_bytes = 0;
+    struct {
+      uint64_t dispatch_count = 0;
+      uint64_t payload_bytes = 0;
+      uint64_t storage_span_bytes = 0;
+    } prepublished_kernarg;
+    struct {
+      uint64_t dispatch_count = 0;
+      uint64_t payload_bytes = 0;
+      uint64_t reserved_bytes = 0;
+    } queue_kernarg;
+    for (const iree_hal_amdgpu_command_buffer_block_header_t* block =
+             program->first_block;
+         block; block = iree_hal_amdgpu_aql_program_block_next(
+                    program->block_pool, block)) {
+      const uint64_t binding_source_length =
+          (uint64_t)block->binding_source_count *
+          sizeof(iree_hal_amdgpu_command_buffer_binding_source_t);
+      const uint64_t used_bytes = block->header_length + block->command_length +
+                                  binding_source_length + block->rodata_length;
+      if (block->aql_packet_count > 0) ++payload_block_count;
+      total_aql_packet_count += block->aql_packet_count;
+      total_block_bytes += block->block_length;
+      total_used_bytes += used_bytes;
+
+      const uint8_t* command_end =
+          (const uint8_t*)block + block->command_offset + block->command_length;
+      for (const iree_hal_amdgpu_command_buffer_command_header_t* command =
+               iree_hal_amdgpu_command_buffer_block_commands_const(block);
+           (const uint8_t*)command < command_end;
+           command =
+               iree_hal_amdgpu_command_buffer_command_next_const(command)) {
+        if (command->opcode != IREE_HAL_AMDGPU_COMMAND_BUFFER_OPCODE_DISPATCH) {
+          continue;
+        }
+        const iree_hal_amdgpu_command_buffer_dispatch_command_t*
+            dispatch_command =
+                (const iree_hal_amdgpu_command_buffer_dispatch_command_t*)
+                    command;
+        const uint64_t kernarg_bytes =
+            (uint64_t)dispatch_command->kernarg_length_qwords * 8u;
+        if (dispatch_command->kernarg_strategy ==
+            IREE_HAL_AMDGPU_COMMAND_BUFFER_KERNARG_STRATEGY_PREPUBLISHED) {
+          ++prepublished_kernarg.dispatch_count;
+          prepublished_kernarg.payload_bytes += kernarg_bytes;
+          const uint64_t storage_end =
+              (uint64_t)dispatch_command->payload_reference + kernarg_bytes;
+          prepublished_kernarg.storage_span_bytes =
+              std::max(prepublished_kernarg.storage_span_bytes, storage_end);
+        } else {
+          ++queue_kernarg.dispatch_count;
+          queue_kernarg.payload_bytes += kernarg_bytes;
+          uint64_t kernarg_block_count = std::max<uint64_t>(
+              1, (kernarg_bytes + sizeof(iree_hal_amdgpu_kernarg_block_t) - 1) /
+                     sizeof(iree_hal_amdgpu_kernarg_block_t));
+          if (iree_any_bit_set(
+                  dispatch_command->dispatch_flags,
+                  IREE_HAL_AMDGPU_COMMAND_BUFFER_DISPATCH_FLAG_INDIRECT_PARAMETERS)) {
+            ++kernarg_block_count;
+          }
+          queue_kernarg.reserved_bytes +=
+              kernarg_block_count * sizeof(iree_hal_amdgpu_kernarg_block_t);
+        }
+      }
+    }
+
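+    // NOTE: assumes blocks in the pool share a uniform block_length, so the
+    // first block's length is reported as the per-block byte size.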
+    state.counters["operation_count"] = static_cast<double>(operation_count);
+    SetBindingCountCounters(state, binding_count);
+    state.counters["command_buffer_blocks"] =
+        static_cast<double>(program->block_count);
+    state.counters["command_buffer_payload_blocks"] =
+        static_cast<double>(payload_block_count);
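+    // Blocks are pool-allocated; the first block's length is reported as the
+    // (assumed uniform) per-block byte size.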
+    state.counters["command_buffer_block_bytes"] =
+        static_cast<double>(program->first_block->block_length);
+    state.counters["command_buffer_occupancy_pct"] =
+        total_block_bytes
+            ? 100.0 * (double)total_used_bytes / (double)total_block_bytes
+            : 0.0;
+    state.counters["aql_packets_per_sync"] =
+        static_cast<double>(total_aql_packet_count);
+    state.counters["max_block_aql_packets"] =
+        static_cast<double>(program->max_block_aql_packet_count);
+    state.counters["max_block_kernarg_bytes"] =
+        static_cast<double>(program->max_block_kernarg_length);
+    state.counters["prepublished_dispatches_per_sync"] =
+        static_cast<double>(prepublished_kernarg.dispatch_count);
+    state.counters["prepublished_kernarg_payload_bytes_per_sync"] =
+        static_cast<double>(prepublished_kernarg.payload_bytes);
+    state.counters["prepublished_storage_span_bytes"] =
+        static_cast<double>(prepublished_kernarg.storage_span_bytes);
+    state.counters["queue_kernarg_dispatches_per_sync"] =
+        static_cast<double>(queue_kernarg.dispatch_count);
+    state.counters["queue_kernarg_payload_bytes_per_sync"] =
+        static_cast<double>(queue_kernarg.payload_bytes);
+    state.counters["queue_kernarg_reserved_bytes_per_sync"] =
+        static_cast<double>(queue_kernarg.reserved_bytes);
+    state.counters["queue_submissions_per_sync"] = 1.0;
+    iree_hal_amdgpu_benchmark_set_completion_wait_counters(state);
+    state.SetItemsProcessed(state.iterations() * operation_count);
+  }
+
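+  // The *SubmitPublicFinalInline helpers below share one shape: bump the
+  // monotonic payload value on the public completion semaphore, submit with
+  // an empty wait list, and return the (semaphore, payload) pair so callers
+  // decide whether the host wait lands inside or outside the timed region.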
+  iree_status_t BindingCountDispatchSubmitPublicFinalInline(
+      int64_t binding_count, SubmittedCompletion* out_completion) {
+    uint64_t completion_payload_value = ++completion_payload_value_;
+    iree_hal_semaphore_t* signal_semaphore = completion_semaphore_;
+    iree_hal_semaphore_list_t signal_semaphore_list = {
+        /*count=*/1,
+        /*semaphores=*/&signal_semaphore,
+        /*payload_values=*/&completion_payload_value,
+    };
+    IREE_RETURN_IF_ERROR(SubmitBindingCountDispatchWithLists(
+        kQueue0, iree_hal_semaphore_list_empty(), signal_semaphore_list,
+        binding_count));
+    *out_completion = {completion_semaphore_, completion_payload_value};
+    return iree_ok_status();
+  }
+
+  iree_status_t BindingCountCommandBufferSubmitPublicFinalInline(
+      int64_t binding_count, SubmittedCompletion* out_completion) {
+    uint64_t completion_payload_value = ++completion_payload_value_;
+    iree_hal_semaphore_t* signal_semaphore = completion_semaphore_;
+    iree_hal_semaphore_list_t signal_semaphore_list = {
+        /*count=*/1,
+        /*semaphores=*/&signal_semaphore,
+        /*payload_values=*/&completion_payload_value,
+    };
+    IREE_RETURN_IF_ERROR(iree_hal_device_queue_execute(
+        device_, kQueue0, iree_hal_semaphore_list_empty(),
+        signal_semaphore_list, binding_count_command_buffers_[binding_count],
+        iree_hal_buffer_binding_table_empty(), IREE_HAL_EXECUTE_FLAG_NONE));
+    *out_completion = {completion_semaphore_, completion_payload_value};
+    return iree_ok_status();
+  }
+
+  iree_status_t BindingCountDispatchChainCommandBufferSubmitPublicFinalInline(
+      iree_hal_command_buffer_t* command_buffer,
+      SubmittedCompletion* out_completion) {
+    uint64_t completion_payload_value = ++completion_payload_value_;
+    iree_hal_semaphore_t* signal_semaphore = completion_semaphore_;
+    iree_hal_semaphore_list_t signal_semaphore_list = {
+        /*count=*/1,
+        /*semaphores=*/&signal_semaphore,
+        /*payload_values=*/&completion_payload_value,
+    };
+    IREE_RETURN_IF_ERROR(iree_hal_device_queue_execute(
+        device_, kQueue0, iree_hal_semaphore_list_empty(),
+        signal_semaphore_list, command_buffer,
+        iree_hal_buffer_binding_table_empty(), IREE_HAL_EXECUTE_FLAG_NONE));
+    *out_completion = {completion_semaphore_, completion_payload_value};
+    return iree_ok_status();
+  }
+
+ private:
+  static iree_const_byte_span_t FindExecutableData(
+      const iree_file_toc_t* toc, iree_string_view_t file_name) {
+    for (iree_host_size_t i = 0; toc[i].name != nullptr; ++i) {
+      if (iree_string_view_equal(file_name,
+                                 iree_make_cstring_view(toc[i].name))) {
+        return iree_make_const_byte_span(
+            reinterpret_cast<const uint8_t*>(toc[i].data), toc[i].size);
+      }
+    }
+    return iree_const_byte_span_empty();
+  }
+
+  static iree_const_byte_span_t FindCtsExecutableData(
+      iree_string_view_t file_name) {
+    return FindExecutableData(iree_cts_testdata_amdgpu_create(), file_name);
+  }
+
+  static iree_const_byte_span_t FindQueueBenchmarkExecutableData(
+      iree_string_view_t file_name) {
+    return FindExecutableData(iree_queue_benchmark_testdata_amdgpu_create(),
+                              file_name);
+  }
+
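+  // Allocates the small device-local payload buffers used by copy/fill
+  // payload rows and seeds them with distinct fill patterns. Each buffer is a
+  // single aligned block (kPayloadBufferAlignment bytes), keeping payload
+  // rows dominated by queue overhead rather than transfer bandwidth.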
+  bool AllocatePayloadBuffers(benchmark::State& state) {
+    iree_hal_allocator_t* allocator = iree_hal_device_allocator(device_);
+    iree_hal_buffer_params_t params = {0};
+    params.usage =
+        IREE_HAL_BUFFER_USAGE_DISPATCH_STORAGE | IREE_HAL_BUFFER_USAGE_TRANSFER;
+    params.type = IREE_HAL_MEMORY_TYPE_OPTIMAL_FOR_DEVICE;
+    params.min_alignment = kPayloadBufferAlignment;
+
+    if (!HandleStatus(
+            state,
+            iree_hal_allocator_allocate_buffer(
+                allocator, params, kPayloadBufferAlignment, &source_buffer_),
+            "failed to allocate source payload buffer")) {
+      return false;
+    }
+    if (!HandleStatus(
+            state,
+            iree_hal_allocator_allocate_buffer(
+                allocator, params, kPayloadBufferAlignment, &target_buffer_),
+            "failed to allocate target payload buffer")) {
+      return false;
+    }
+
+    uint8_t source_pattern = 0x5A;
+    if (!HandleStatus(state,
+                      FillBufferAndWait(source_buffer_, &source_pattern,
+                                        sizeof(source_pattern)),
+                      "failed to initialize source payload buffer")) {
+      return false;
+    }
+    uint8_t target_pattern = 0x00;
+    return HandleStatus(state,
+                        FillBufferAndWait(target_buffer_, &target_pattern,
+                                          sizeof(target_pattern)),
+                        "failed to initialize target payload buffer");
+  }
+
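+  // Public semaphores are host-waitable with any queue affinity; the private
+  // stream variant below is device-local and single-producer, scoped to the
+  // two benchmark queues so the driver can take a cheaper signaling path.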
+  bool CreatePublicSemaphore(benchmark::State& state,
+                             iree_hal_semaphore_t** out_semaphore) {
+    return HandleStatus(state,
+                        iree_hal_semaphore_create(
+                            device_, IREE_HAL_QUEUE_AFFINITY_ANY,
+                            /*initial_value=*/0,
+                            IREE_HAL_SEMAPHORE_FLAG_DEFAULT, out_semaphore),
+                        "failed to create semaphore");
+  }
+
+  bool CreatePrivateStreamSemaphore(benchmark::State& state,
+                                    iree_hal_semaphore_t** out_semaphore) {
+    return HandleStatus(
+        state,
+        iree_hal_semaphore_create(device_, kQueue0 | kQueue1,
+                                  /*initial_value=*/0,
+                                  IREE_HAL_SEMAPHORE_FLAG_DEVICE_LOCAL |
+                                      IREE_HAL_SEMAPHORE_FLAG_SINGLE_PRODUCER,
+                                  out_semaphore),
+        "failed to create private stream semaphore");
+  }
+
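+  // Process-wide state initialized once and shared by all benchmark rows;
+  // available_ gates rows when no AMDGPU device can be created.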
+  static bool initialized_;
+  static bool available_;
+  static iree_allocator_t host_allocator_;
+  static iree_hal_driver_t* driver_;
+  static iree_hal_device_group_t* device_group_;
+  static iree_hal_device_t* device_;
+  // Executable cache used for the CTS-derived tiny dispatch benchmark payload.
+  static iree_hal_executable_cache_t* dispatch_executable_cache_;
+  // CTS-derived tiny dispatch executable shared by dispatch benchmark rows.
+  static iree_hal_executable_t* dispatch_executable_;
+  // Executable cache used for binding-count dispatch benchmark exports.
+  static iree_hal_executable_cache_t* binding_count_executable_cache_;
+  // Empty-kernel executable with exports that vary only in ABI binding count.
+  static iree_hal_executable_t* binding_count_executable_;
+  // Optional file contents backing an externally loaded binding-count HSACO.
+  static iree_io_file_contents_t* binding_count_executable_file_contents_;
+
+  // Precomputed dispatch packet body used by direct-substrate attribution rows.
+  iree_hsa_kernel_dispatch_packet_t pre_resolved_dispatch_packet_template_ = {};
+  // Precomputed kernarg bytes used by direct-substrate attribution rows.
+  uint8_t* pre_resolved_dispatch_kernargs_ = nullptr;
+  // Byte length of |pre_resolved_dispatch_kernargs_|.
+  iree_host_size_t pre_resolved_dispatch_kernarg_length_ = 0;
+  // Queue kernarg-ring block count required by the precomputed kernarg bytes.
+  uint32_t pre_resolved_dispatch_kernarg_block_count_ = 0;
+  // Small source buffer used by queue copy payload benchmark rows.
+  iree_hal_buffer_t* source_buffer_ = nullptr;
+  // Small target buffer used by queue copy and fill payload benchmark rows.
+  iree_hal_buffer_t* target_buffer_ = nullptr;
+  // Unique device buffers passed to binding-count dispatch benchmark rows.
+  iree_hal_buffer_t* binding_count_buffers_[kDispatchBindingBenchmarkMaxCount] =
+      {};
+  // queue_dispatch binding refs assembled from |binding_count_buffers_|.
+  iree_hal_buffer_ref_t
+      binding_count_binding_ref_scratch_[kDispatchBindingBenchmarkMaxCount] =
+          {};
+  // Static direct command buffers for binding-count dispatch benchmark rows.
+  iree_hal_command_buffer_t*
+      binding_count_command_buffers_[kDispatchBindingBenchmarkVariantCapacity] =
+          {};
+  // Public semaphore used for final host-observable completion.
+  iree_hal_semaphore_t* completion_semaphore_ = nullptr;
+  // Private single-producer stream semaphore used by queue 1.
+  iree_hal_semaphore_t* stream_semaphore_ = nullptr;
+  // Private single-producer stream semaphore used by queue 0.
+  iree_hal_semaphore_t* producer_semaphore_ = nullptr;
+  // Discard sink used by profiling overhead guardrail rows.
+  iree_hal_profile_sink_t* profile_sink_ = nullptr;
+  // True while the fixture owns an active HAL profiling session.
+  bool profile_session_active_ = false;
+  // Last-assigned public completion payload value (monotonic; pre-incremented
+  // before each use).
+  uint64_t completion_payload_value_ = 0;
+  // Last-assigned private queue 1 stream payload value.
+  uint64_t stream_payload_value_ = 0;
+  // Last-assigned private queue 0 stream payload value.
+  uint64_t producer_payload_value_ = 0;
+  // Dword fill pattern used by fill payload benchmark rows.
+  uint32_t fill_pattern_ = 0xDEADBEEFu;
+};
+
+bool QueueBenchmark::initialized_ = false;
+bool QueueBenchmark::available_ = false;
+iree_allocator_t QueueBenchmark::host_allocator_;
+iree_hal_driver_t* QueueBenchmark::driver_ = nullptr;
+iree_hal_device_group_t* QueueBenchmark::device_group_ = nullptr;
+iree_hal_device_t* QueueBenchmark::device_ = nullptr;
+iree_hal_executable_cache_t* QueueBenchmark::dispatch_executable_cache_ =
+    nullptr;
+iree_hal_executable_t* QueueBenchmark::dispatch_executable_ = nullptr;
+iree_hal_executable_cache_t* QueueBenchmark::binding_count_executable_cache_ =
+    nullptr;
+iree_hal_executable_t* QueueBenchmark::binding_count_executable_ = nullptr;
+iree_io_file_contents_t*
+    QueueBenchmark::binding_count_executable_file_contents_ = nullptr;
+
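+// Naming conventions for the benchmark rows below:
+//   *FinalWait   - the timed region includes the host wait on the final
+//                  public semaphore signal.
+//   *SubmitOnly  - the host wait runs with timing paused so the row
+//                  approximates pure submission cost.
+//   *PublicFinalInline / *PublicFinalSeparate - whether the host-visible
+//                  completion signal rides on the last queue operation or on
+//                  a dedicated extra submission.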
+BENCHMARK_DEFINE_F(QueueBenchmark,
+                   SameQueueBarrierWait)(benchmark::State& state) {
+  for (auto _ : state) {
+    if (!HandleStatus(state, SameQueueBarrierAndWait(),
+                      "same-queue barrier failed")) {
+      break;
+    }
+  }
+  SetQueueSubmissionsProcessed(state, /*queue_submissions_per_sync=*/1);
+}
+
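+// Times queue_alloca submission against a TLSF pool that has already served
+// (and recycled) an allocation of the requested size, so the timed path
+// should hit the warm suballocator instead of growing the pool.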
+BENCHMARK_DEFINE_F(QueueBenchmark,
+                   QueueAllocaWarmTlsfSubmitOnly)(benchmark::State& state) {
+  const iree_device_size_t allocation_size =
+      static_cast<iree_device_size_t>(state.range(0));
+  iree_hal_pool_t* pool = nullptr;
+  if (!HandleStatus(state, CreateQueueAllocaTlsfPool(allocation_size, &pool),
+                    "failed to create queue_alloca TLSF pool")) {
+    return;
+  }
+  if (!HandleStatus(state, QueueAllocaSubmitAndCleanup(pool, allocation_size),
+                    "failed to prewarm queue_alloca TLSF pool")) {
+    iree_hal_pool_release(pool);
+    return;
+  }
+
+  for (auto _ : state) {
+    iree_hal_buffer_t* buffer = nullptr;
+    SubmittedCompletion completion;
+    if (!HandleStatus(
+            state,
+            QueueAllocaSubmit(pool, allocation_size, &buffer, &completion),
+            "warm queue_alloca submit failed")) {
+      break;
+    }
+    state.PauseTiming();
+    const bool cleanup_ok =
+        HandleStatus(state, QueueAllocaCleanup(buffer, completion),
+                     "warm queue_alloca cleanup failed");
+    state.ResumeTiming();
+    if (!cleanup_ok) break;
+  }
+
+  iree_hal_pool_release(pool);
+  SetQueueAllocaCounters(state, allocation_size,
+                         QueueAllocaTlsfGrowthMode::kWarm);
+}
+
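+// Complement of the warm row: each iteration seeds a fresh pool with a held
+// allocation, then times a second allocation of the same size that must grow
+// the pool because the held buffer still pins the seeded block.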
+BENCHMARK_DEFINE_F(QueueBenchmark, QueueAllocaForcedTlsfGrowthSubmitOnly)(
+    benchmark::State& state) {
+  const iree_device_size_t allocation_size =
+      static_cast<iree_device_size_t>(state.range(0));
+
+  for (auto _ : state) {
+    state.PauseTiming();
+    iree_hal_pool_t* pool = nullptr;
+    iree_hal_buffer_t* held_buffer = nullptr;
+    SubmittedCompletion held_completion = {};
+    iree_status_t status = CreateQueueAllocaTlsfPool(allocation_size, &pool);
+    if (iree_status_is_ok(status)) {
+      status = QueueAllocaSubmit(pool, allocation_size, &held_buffer,
+                                 &held_completion);
+    }
+    if (iree_status_is_ok(status)) {
+      status = Wait(held_completion.semaphore, held_completion.payload_value);
+    }
+    if (!HandleStatus(state, status,
+                      "failed to seed forced-growth queue_alloca pool")) {
+      iree_hal_buffer_release(held_buffer);
+      iree_hal_pool_release(pool);
+      state.ResumeTiming();
+      break;
+    }
+
+    state.ResumeTiming();
+    iree_hal_buffer_t* grown_buffer = nullptr;
+    SubmittedCompletion grown_completion;
+    status = QueueAllocaSubmit(pool, allocation_size, &grown_buffer,
+                               &grown_completion);
+    state.PauseTiming();
+
+    bool ok =
+        HandleStatus(state, status, "forced-growth queue_alloca submit failed");
+    if (ok) {
+      ok = HandleStatus(state,
+                        QueueAllocaCleanup(grown_buffer, grown_completion),
+                        "forced-growth queue_alloca cleanup failed");
+    } else {
+      iree_hal_buffer_release(grown_buffer);
+    }
+    ok = HandleStatus(state, QueueAllocaCleanup(held_buffer, held_completion),
+                      "forced-growth held alloca cleanup failed") &&
+         ok;
+    iree_hal_pool_release(pool);
+    state.ResumeTiming();
+    if (!ok) break;
+  }
+
+  SetQueueAllocaCounters(state, allocation_size,
+                         QueueAllocaTlsfGrowthMode::kForcedGrowth);
+}
+
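+// Host-call rows differ only in the flags passed to the queued host call
+// (blocking, relaxed blocking, and relaxed non-blocking variants), plus
+// batched forms that sync once per kBatchCount calls.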
+BENCHMARK_DEFINE_F(QueueBenchmark,
+                   HostCallBlockingWait)(benchmark::State& state) {
+  for (auto _ : state) {
+    if (!HandleStatus(state, HostCallAndWait(IREE_HAL_HOST_CALL_FLAG_NONE),
+                      "blocking host call failed")) {
+      break;
+    }
+  }
+  SetQueueSubmissionsProcessed(state, /*queue_submissions_per_sync=*/1);
+}
+
+BENCHMARK_DEFINE_F(QueueBenchmark,
+                   HostCallBlockingRelaxedWait)(benchmark::State& state) {
+  for (auto _ : state) {
+    if (!HandleStatus(state, HostCallAndWait(IREE_HAL_HOST_CALL_FLAG_RELAXED),
+                      "relaxed blocking host call failed")) {
+      break;
+    }
+  }
+  SetQueueSubmissionsProcessed(state, /*queue_submissions_per_sync=*/1);
+}
+
+BENCHMARK_DEFINE_F(QueueBenchmark,
+                   HostCallNonBlockingWait)(benchmark::State& state) {
+  for (auto _ : state) {
+    if (!HandleStatus(state,
+                      HostCallAndWait(IREE_HAL_HOST_CALL_FLAG_NON_BLOCKING |
+                                      IREE_HAL_HOST_CALL_FLAG_RELAXED),
+                      "relaxed nonblocking host call failed")) {
+      break;
+    }
+  }
+  SetQueueSubmissionsProcessed(state, /*queue_submissions_per_sync=*/1);
+}
+
+BENCHMARK_DEFINE_F(QueueBenchmark,
+                   HostCallBlockingBatch20FinalWait)(benchmark::State& state) {
+  for (auto _ : state) {
+    if (!HandleStatus(
+            state,
+            HostCallBatchAndWait(IREE_HAL_HOST_CALL_FLAG_NONE, kBatchCount),
+            "blocking host call batch failed")) {
+      break;
+    }
+  }
+  state.counters["host_functions_per_sync"] = static_cast<double>(kBatchCount);
+  SetQueueSubmissionsProcessed(state, kBatchCount);
+}
+
+BENCHMARK_DEFINE_F(QueueBenchmark, HostCallBlockingRelaxedBatch20FinalWait)(
+    benchmark::State& state) {
+  for (auto _ : state) {
+    if (!HandleStatus(
+            state,
+            HostCallBatchAndWait(IREE_HAL_HOST_CALL_FLAG_RELAXED, kBatchCount),
+            "relaxed blocking host call batch failed")) {
+      break;
+    }
+  }
+  state.counters["host_functions_per_sync"] = static_cast<double>(kBatchCount);
+  SetQueueSubmissionsProcessed(state, kBatchCount);
+}
+
+BENCHMARK_DEFINE_F(QueueBenchmark, HostCallNonBlockingBatch20FinalWait)(
+    benchmark::State& state) {
+  for (auto _ : state) {
+    if (!HandleStatus(
+            state,
+            HostCallBatchAndWait(IREE_HAL_HOST_CALL_FLAG_NON_BLOCKING |
+                                     IREE_HAL_HOST_CALL_FLAG_RELAXED,
+                                 kBatchCount),
+            "relaxed nonblocking host call batch failed")) {
+      break;
+    }
+  }
+  state.counters["host_functions_per_sync"] = static_cast<double>(kBatchCount);
+  SetQueueSubmissionsProcessed(state, kBatchCount);
+}
+
+BENCHMARK_DEFINE_F(QueueBenchmark,
+                   SameQueueBarrierBatch20FinalWait)(benchmark::State& state) {
+  for (auto _ : state) {
+    if (!HandleStatus(state, SameQueueBarrierBatchAndWait(kBatchCount),
+                      "same-queue barrier batch failed")) {
+      break;
+    }
+  }
+  SetQueueSubmissionsProcessed(state, kBatchCount);
+}
+
+BENCHMARK_DEFINE_F(QueueBenchmark,
+                   SameQueueBarrierBatchFinalWait)(benchmark::State& state) {
+  const int64_t batch_count = state.range(0);
+  for (auto _ : state) {
+    if (!HandleStatus(state, SameQueueBarrierBatchAndWait(batch_count),
+                      "same-queue barrier batch failed")) {
+      break;
+    }
+  }
+  SetQueueSubmissionsProcessed(state, batch_count);
+}
+
+BENCHMARK_DEFINE_F(QueueBenchmark,
+                   SameQueueBarrierBatchSubmitOnly)(benchmark::State& state) {
+  const int64_t batch_count = state.range(0);
+  for (auto _ : state) {
+    SubmittedCompletion completion;
+    if (!HandleStatus(state,
+                      SameQueueBarrierBatchSubmit(batch_count, &completion),
+                      "same-queue barrier batch submit failed")) {
+      break;
+    }
+    if (!WaitWithTimingPaused(state, completion,
+                              "same-queue barrier batch wait failed")) {
+      break;
+    }
+  }
+  SetQueueSubmissionsProcessed(state, batch_count);
+}
+
+BENCHMARK_DEFINE_F(QueueBenchmark,
+                   SameQueueEpochChain20)(benchmark::State& state) {
+  for (auto _ : state) {
+    if (!HandleStatus(state, SameQueueEpochChainAndWait(kBatchCount),
+                      "same-queue epoch chain failed")) {
+      break;
+    }
+  }
+  SetQueueSubmissionsProcessed(state, kBatchCount);
+}
+
+BENCHMARK_DEFINE_F(QueueBenchmark,
+                   SameQueueEpochChain)(benchmark::State& state) {
+  const int64_t batch_count = state.range(0);
+  for (auto _ : state) {
+    if (!HandleStatus(state, SameQueueEpochChainAndWait(batch_count),
+                      "same-queue epoch chain failed")) {
+      break;
+    }
+  }
+  SetQueueSubmissionsProcessed(state, batch_count);
+}
+
+BENCHMARK_DEFINE_F(QueueBenchmark,
+                   SameQueueEpochChainSubmitOnly)(benchmark::State& state) {
+  const int64_t batch_count = state.range(0);
+  for (auto _ : state) {
+    SubmittedCompletion completion;
+    if (!HandleStatus(state,
+                      SameQueueEpochChainSubmit(batch_count, &completion),
+                      "same-queue epoch chain submit failed")) {
+      break;
+    }
+    if (!WaitWithTimingPaused(state, completion,
+                              "same-queue epoch chain wait failed")) {
+      break;
+    }
+  }
+  SetQueueSubmissionsProcessed(state, batch_count);
+}
+
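+// Cross-queue rows need a second hardware queue; EnsureQueueAvailable bails
+// out early when kQueue1 is not present on the device.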
+BENCHMARK_DEFINE_F(QueueBenchmark,
+                   CrossQueueAlreadyCompletedWait)(benchmark::State& state) {
+  if (!EnsureQueueAvailable(state, kQueue1)) return;
+  if (!HandleStatus(state, PrimeProducerSemaphore(),
+                    "failed to prime producer semaphore")) {
+    return;
+  }
+
+  for (auto _ : state) {
+    if (!HandleStatus(state, CrossQueueAlreadyCompletedWaitAndSignal(),
+                      "cross-queue completed wait failed")) {
+      break;
+    }
+  }
+  SetQueueSubmissionsProcessed(state, /*queue_submissions_per_sync=*/1);
+}
+
+BENCHMARK_DEFINE_F(QueueBenchmark,
+                   CrossQueueBarrierValue)(benchmark::State& state) {
+  if (!EnsureQueueAvailable(state, kQueue1)) return;
+  for (auto _ : state) {
+    if (!HandleStatus(state, CrossQueueBarrierValueAndWait(),
+                      "cross-queue barrier-value wait failed")) {
+      break;
+    }
+  }
+  SetQueueSubmissionsProcessed(state, /*queue_submissions_per_sync=*/2);
+}
+
+BENCHMARK_DEFINE_F(QueueBenchmark, CrossQueueBarrierValueBatch20FinalWait)(
+    benchmark::State& state) {
+  if (!EnsureQueueAvailable(state, kQueue1)) return;
+  for (auto _ : state) {
+    if (!HandleStatus(state, CrossQueueBarrierValueBatchAndWait(kBatchCount),
+                      "cross-queue barrier-value batch failed")) {
+      break;
+    }
+  }
+  state.counters["cross_queue_handoffs_per_sync"] =
+      static_cast<double>(kBatchCount);
+  SetQueueSubmissionsProcessed(state,
+                               /*queue_submissions_per_sync=*/2 * kBatchCount);
+}
+
+BENCHMARK_DEFINE_F(QueueBenchmark, CrossQueueBarrierValueBatchFinalWait)(
+    benchmark::State& state) {
+  if (!EnsureQueueAvailable(state, kQueue1)) return;
+  const int64_t batch_count = state.range(0);
+  for (auto _ : state) {
+    if (!HandleStatus(state, CrossQueueBarrierValueBatchAndWait(batch_count),
+                      "cross-queue barrier-value batch failed")) {
+      break;
+    }
+  }
+  state.counters["cross_queue_handoffs_per_sync"] =
+      static_cast<double>(batch_count);
+  SetQueueSubmissionsProcessed(state,
+                               /*queue_submissions_per_sync=*/2 * batch_count);
+}
+
+BENCHMARK_DEFINE_F(QueueBenchmark, CrossQueueBarrierValueBatchSubmitOnly)(
+    benchmark::State& state) {
+  if (!EnsureQueueAvailable(state, kQueue1)) return;
+  const int64_t batch_count = state.range(0);
+  for (auto _ : state) {
+    SubmittedCompletion completion;
+    if (!HandleStatus(
+            state, CrossQueueBarrierValueBatchSubmit(batch_count, &completion),
+            "cross-queue barrier-value batch submit failed")) {
+      break;
+    }
+    if (!WaitWithTimingPaused(state, completion,
+                              "cross-queue barrier-value batch wait failed")) {
+      break;
+    }
+  }
+  state.counters["cross_queue_handoffs_per_sync"] =
+      static_cast<double>(batch_count);
+  SetQueueSubmissionsProcessed(state,
+                               /*queue_submissions_per_sync=*/2 * batch_count);
+}
+
+BENCHMARK_DEFINE_F(QueueBenchmark,
+                   CrossQueuePingPongChain20)(benchmark::State& state) {
+  if (!EnsureQueueAvailable(state, kQueue1)) return;
+  for (auto _ : state) {
+    if (!HandleStatus(state, CrossQueuePingPongChainAndWait(kBatchCount),
+                      "cross-queue ping-pong chain failed")) {
+      break;
+    }
+  }
+  SetCrossQueuePingPongCounters(state, kBatchCount,
+                                /*queue_submissions_per_sync=*/1 + kBatchCount,
+                                /*public_completion_signals_per_sync=*/0);
+}
+
+BENCHMARK_DEFINE_F(QueueBenchmark,
+                   CrossQueuePingPongChain)(benchmark::State& state) {
+  if (!EnsureQueueAvailable(state, kQueue1)) return;
+  const int64_t handoff_count = state.range(0);
+  for (auto _ : state) {
+    if (!HandleStatus(state, CrossQueuePingPongChainAndWait(handoff_count),
+                      "cross-queue ping-pong chain failed")) {
+      break;
+    }
+  }
+  SetCrossQueuePingPongCounters(
+      state, handoff_count, /*queue_submissions_per_sync=*/1 + handoff_count,
+      /*public_completion_signals_per_sync=*/0);
+}
+
+BENCHMARK_DEFINE_F(QueueBenchmark,
+                   CrossQueuePingPongChainSubmitOnly)(benchmark::State& state) {
+  if (!EnsureQueueAvailable(state, kQueue1)) return;
+  const int64_t handoff_count = state.range(0);
+  for (auto _ : state) {
+    SubmittedCompletion completion;
+    if (!HandleStatus(state,
+                      CrossQueuePingPongChainSubmit(handoff_count, &completion),
+                      "cross-queue ping-pong chain submit failed")) {
+      break;
+    }
+    if (!WaitWithTimingPaused(state, completion,
+                              "cross-queue ping-pong chain wait failed")) {
+      break;
+    }
+  }
+  SetCrossQueuePingPongCounters(
+      state, handoff_count, /*queue_submissions_per_sync=*/1 + handoff_count,
+      /*public_completion_signals_per_sync=*/0);
+}
+
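+// "PublicFinalInline" folds the host-visible completion signal into the last
+// ping-pong submission (1 + handoff_count submissions per sync); the
+// "Separate" variants below add a dedicated final submission
+// (2 + handoff_count).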
+BENCHMARK_DEFINE_F(QueueBenchmark, CrossQueuePingPongPublicFinalInline)(
+    benchmark::State& state) {
+  if (!EnsureQueueAvailable(state, kQueue1)) return;
+  const int64_t handoff_count = state.range(0);
+  for (auto _ : state) {
+    if (!HandleStatus(
+            state,
+            CrossQueuePingPongChainPublicFinalInlineAndWait(handoff_count),
+            "cross-queue ping-pong public-final inline chain failed")) {
+      break;
+    }
+  }
+  SetCrossQueuePingPongCounters(
+      state, handoff_count, /*queue_submissions_per_sync=*/1 + handoff_count,
+      /*public_completion_signals_per_sync=*/1);
+}
+
+BENCHMARK_DEFINE_F(QueueBenchmark,
+                   CrossQueuePingPongPublicFinalInlineSubmitOnly)(
+    benchmark::State& state) {
+  if (!EnsureQueueAvailable(state, kQueue1)) return;
+  const int64_t handoff_count = state.range(0);
+  for (auto _ : state) {
+    SubmittedCompletion completion;
+    if (!HandleStatus(
+            state,
+            CrossQueuePingPongChainSubmitPublicFinalInline(handoff_count,
+                                                           &completion),
+            "cross-queue ping-pong public-final inline submit failed")) {
+      break;
+    }
+    if (!WaitWithTimingPaused(
+            state, completion,
+            "cross-queue ping-pong public-final inline wait failed")) {
+      break;
+    }
+  }
+  SetCrossQueuePingPongCounters(
+      state, handoff_count, /*queue_submissions_per_sync=*/1 + handoff_count,
+      /*public_completion_signals_per_sync=*/1);
+}
+
+BENCHMARK_DEFINE_F(QueueBenchmark, CrossQueuePingPongPublicFinalSeparate)(
+    benchmark::State& state) {
+  if (!EnsureQueueAvailable(state, kQueue1)) return;
+  const int64_t handoff_count = state.range(0);
+  for (auto _ : state) {
+    if (!HandleStatus(
+            state,
+            CrossQueuePingPongChainPublicFinalSeparateAndWait(handoff_count),
+            "cross-queue ping-pong public-final separate chain failed")) {
+      break;
+    }
+  }
+  SetCrossQueuePingPongCounters(
+      state, handoff_count, /*queue_submissions_per_sync=*/2 + handoff_count,
+      /*public_completion_signals_per_sync=*/1);
+}
+
+BENCHMARK_DEFINE_F(QueueBenchmark,
+                   CrossQueuePingPongPublicFinalSeparateSubmitOnly)(
+    benchmark::State& state) {
+  if (!EnsureQueueAvailable(state, kQueue1)) return;
+  const int64_t handoff_count = state.range(0);
+  for (auto _ : state) {
+    SubmittedCompletion completion;
+    if (!HandleStatus(
+            state,
+            CrossQueuePingPongChainSubmitPublicFinalSeparate(handoff_count,
+                                                             &completion),
+            "cross-queue ping-pong public-final separate submit failed")) {
+      break;
+    }
+    if (!WaitWithTimingPaused(
+            state, completion,
+            "cross-queue ping-pong public-final separate wait failed")) {
+      break;
+    }
+  }
+  SetCrossQueuePingPongCounters(
+      state, handoff_count, /*queue_submissions_per_sync=*/2 + handoff_count,
+      /*public_completion_signals_per_sync=*/1);
+}
+
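+// Payload ping-pong rows attach a real queue operation to every hop (copy,
+// fill, dispatch, noop dispatch, or a pre-resolved dispatch packet) so
+// per-operation cost can be attributed on top of the bare handoff rows above.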
+BENCHMARK_DEFINE_F(QueueBenchmark, CrossQueuePingPongCopyPublicFinalInline)(
+    benchmark::State& state) {
+  if (!EnsureQueueAvailable(state, kQueue1)) return;
+  if (!EnsurePayloadBuffers(state)) return;
+  const int64_t handoff_count = state.range(0);
+  for (auto _ : state) {
+    if (!HandleStatus(
+            state,
+            CrossQueuePingPongPayloadPublicFinalInlineAndWait(
+                PayloadKind::kCopy, handoff_count),
+            "cross-queue ping-pong copy public-final inline chain failed")) {
+      break;
+    }
+  }
+  SetPayloadPingPongCounters(state, handoff_count,
+                             /*queue_submissions_per_sync=*/1 + handoff_count,
+                             /*public_completion_signals_per_sync=*/1);
+}
+
+BENCHMARK_DEFINE_F(QueueBenchmark,
+                   CrossQueuePingPongCopyPublicFinalInlineSubmitOnly)(
+    benchmark::State& state) {
+  if (!EnsureQueueAvailable(state, kQueue1)) return;
+  if (!EnsurePayloadBuffers(state)) return;
+  const int64_t handoff_count = state.range(0);
+  for (auto _ : state) {
+    SubmittedCompletion completion;
+    if (!HandleStatus(
+            state,
+            CrossQueuePingPongPayloadSubmitPublicFinalInline(
+                PayloadKind::kCopy, handoff_count, &completion),
+            "cross-queue ping-pong copy public-final inline submit failed")) {
+      break;
+    }
+    if (!WaitWithTimingPaused(
+            state, completion,
+            "cross-queue ping-pong copy public-final inline wait failed")) {
+      break;
+    }
+  }
+  SetPayloadPingPongCounters(state, handoff_count,
+                             /*queue_submissions_per_sync=*/1 + handoff_count,
+                             /*public_completion_signals_per_sync=*/1);
+}
+
+BENCHMARK_DEFINE_F(QueueBenchmark, CrossQueuePingPongFillPublicFinalInline)(
+    benchmark::State& state) {
+  if (!EnsureQueueAvailable(state, kQueue1)) return;
+  if (!EnsurePayloadBuffers(state)) return;
+  const int64_t handoff_count = state.range(0);
+  for (auto _ : state) {
+    if (!HandleStatus(
+            state,
+            CrossQueuePingPongPayloadPublicFinalInlineAndWait(
+                PayloadKind::kFill, handoff_count),
+            "cross-queue ping-pong fill public-final inline chain failed")) {
+      break;
+    }
+  }
+  SetPayloadPingPongCounters(state, handoff_count,
+                             /*queue_submissions_per_sync=*/1 + handoff_count,
+                             /*public_completion_signals_per_sync=*/1);
+}
+
+BENCHMARK_DEFINE_F(QueueBenchmark,
+                   CrossQueuePingPongFillPublicFinalInlineSubmitOnly)(
+    benchmark::State& state) {
+  if (!EnsureQueueAvailable(state, kQueue1)) return;
+  if (!EnsurePayloadBuffers(state)) return;
+  const int64_t handoff_count = state.range(0);
+  for (auto _ : state) {
+    SubmittedCompletion completion;
+    if (!HandleStatus(
+            state,
+            CrossQueuePingPongPayloadSubmitPublicFinalInline(
+                PayloadKind::kFill, handoff_count, &completion),
+            "cross-queue ping-pong fill public-final inline submit failed")) {
+      break;
+    }
+    if (!WaitWithTimingPaused(
+            state, completion,
+            "cross-queue ping-pong fill public-final inline wait failed")) {
+      break;
+    }
+  }
+  SetPayloadPingPongCounters(state, handoff_count,
+                             /*queue_submissions_per_sync=*/1 + handoff_count,
+                             /*public_completion_signals_per_sync=*/1);
+}
+
+BENCHMARK_DEFINE_F(QueueBenchmark, CrossQueuePingPongDispatchPublicFinalInline)(
+    benchmark::State& state) {
+  if (!EnsureQueueAvailable(state, kQueue1)) return;
+  if (!EnsurePayloadBuffers(state)) return;
+  if (!EnsureDispatchExecutable(state)) return;
+  const int64_t handoff_count = state.range(0);
+  for (auto _ : state) {
+    if (!HandleStatus(state,
+                      CrossQueuePingPongPayloadPublicFinalInlineAndWait(
+                          PayloadKind::kDispatch, handoff_count),
+                      "cross-queue ping-pong dispatch public-final inline "
+                      "chain failed")) {
+      break;
+    }
+  }
+  SetPayloadPingPongCounters(state, handoff_count,
+                             /*queue_submissions_per_sync=*/1 + handoff_count,
+                             /*public_completion_signals_per_sync=*/1);
+}
+
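+// The "EpochCompletionFloor" variant times the submit plus a wait on the
+// producer queue's completion epoch, then waits on the public semaphore with
+// timing paused; this bounds device-side completion cost without charging
+// host-visible signal delivery.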
+BENCHMARK_DEFINE_F(
+    QueueBenchmark,
+    CrossQueuePingPongDispatchPublicFinalInlineEpochCompletionFloor)(
+    benchmark::State& state) {
+  if (!EnsureQueueAvailable(state, kQueue1)) return;
+  if (!EnsurePayloadBuffers(state)) return;
+  if (!EnsureDispatchExecutable(state)) return;
+  const int64_t handoff_count = state.range(0);
+  for (auto _ : state) {
+    SubmittedCompletion completion;
+    if (!HandleStatus(state,
+                      CrossQueuePingPongPayloadSubmitPublicFinalInline(
+                          PayloadKind::kDispatch, handoff_count, &completion),
+                      "cross-queue ping-pong dispatch public-final inline "
+                      "submit failed")) {
+      break;
+    }
+    if (!HandleStatus(state, WaitForSubmittedProducerEpoch(completion),
+                      "cross-queue ping-pong dispatch producer epoch wait "
+                      "failed")) {
+      break;
+    }
+    if (!WaitWithTimingPaused(
+            state, completion,
+            "cross-queue ping-pong dispatch public-final inline wait failed")) {
+      break;
+    }
+  }
+  SetPayloadPingPongCounters(state, handoff_count,
+                             /*queue_submissions_per_sync=*/1 + handoff_count,
+                             /*public_completion_signals_per_sync=*/1);
+}
+
+BENCHMARK_DEFINE_F(QueueBenchmark,
+                   CrossQueuePingPongDispatchPublicFinalInlineSubmitOnly)(
+    benchmark::State& state) {
+  if (!EnsureQueueAvailable(state, kQueue1)) return;
+  if (!EnsurePayloadBuffers(state)) return;
+  if (!EnsureDispatchExecutable(state)) return;
+  const int64_t handoff_count = state.range(0);
+  for (auto _ : state) {
+    SubmittedCompletion completion;
+    if (!HandleStatus(state,
+                      CrossQueuePingPongPayloadSubmitPublicFinalInline(
+                          PayloadKind::kDispatch, handoff_count, &completion),
+                      "cross-queue ping-pong dispatch public-final inline "
+                      "submit failed")) {
+      break;
+    }
+    if (!WaitWithTimingPaused(
+            state, completion,
+            "cross-queue ping-pong dispatch public-final inline wait failed")) {
+      break;
+    }
+  }
+  SetPayloadPingPongCounters(state, handoff_count,
+                             /*queue_submissions_per_sync=*/1 + handoff_count,
+                             /*public_completion_signals_per_sync=*/1);
+}
+
+BENCHMARK_DEFINE_F(QueueBenchmark,
+                   CrossQueuePingPongNoopDispatchPublicFinalInline)(
+    benchmark::State& state) {
+  if (!EnsureQueueAvailable(state, kQueue1)) return;
+  if (!EnsureDispatchExecutable(state)) return;
+  const int64_t handoff_count = state.range(0);
+  for (auto _ : state) {
+    if (!HandleStatus(state,
+                      CrossQueuePingPongPayloadPublicFinalInlineAndWait(
+                          PayloadKind::kNoopDispatch, handoff_count),
+                      "cross-queue ping-pong noop dispatch public-final inline "
+                      "chain failed")) {
+      break;
+    }
+  }
+  SetPayloadPingPongCounters(state, handoff_count,
+                             /*queue_submissions_per_sync=*/1 + handoff_count,
+                             /*public_completion_signals_per_sync=*/1);
+}
+
+BENCHMARK_DEFINE_F(QueueBenchmark,
+                   CrossQueuePingPongNoopDispatchPublicFinalInlineSubmitOnly)(
+    benchmark::State& state) {
+  if (!EnsureQueueAvailable(state, kQueue1)) return;
+  if (!EnsureDispatchExecutable(state)) return;
+  const int64_t handoff_count = state.range(0);
+  for (auto _ : state) {
+    SubmittedCompletion completion;
+    if (!HandleStatus(
+            state,
+            CrossQueuePingPongPayloadSubmitPublicFinalInline(
+                PayloadKind::kNoopDispatch, handoff_count, &completion),
+            "cross-queue ping-pong noop dispatch public-final inline "
+            "submit failed")) {
+      break;
+    }
+    if (!WaitWithTimingPaused(
+            state, completion,
+            "cross-queue ping-pong noop dispatch public-final inline wait "
+            "failed")) {
+      break;
+    }
+  }
+  SetPayloadPingPongCounters(state, handoff_count,
+                             /*queue_submissions_per_sync=*/1 + handoff_count,
+                             /*public_completion_signals_per_sync=*/1);
+}
+
+BENCHMARK_DEFINE_F(QueueBenchmark,
+                   CrossQueuePingPongPreResolvedDispatchPublicFinalInline)(
+    benchmark::State& state) {
+  if (!EnsureQueueAvailable(state, kQueue1)) return;
+  if (!EnsurePreResolvedDispatch(state)) return;
+  const int64_t handoff_count = state.range(0);
+  for (auto _ : state) {
+    if (!HandleStatus(state,
+                      CrossQueuePingPongPayloadPublicFinalInlineAndWait(
+                          PayloadKind::kPreResolvedDispatch, handoff_count),
+                      "cross-queue ping-pong pre-resolved dispatch "
+                      "public-final inline chain failed")) {
+      break;
+    }
+  }
+  SetPayloadPingPongCounters(state, handoff_count,
+                             /*queue_submissions_per_sync=*/1 + handoff_count,
+                             /*public_completion_signals_per_sync=*/1);
+}
+
+BENCHMARK_DEFINE_F(
+    QueueBenchmark,
+    CrossQueuePingPongPreResolvedDispatchPublicFinalInlineSubmitOnly)(
+    benchmark::State& state) {
+  if (!EnsureQueueAvailable(state, kQueue1)) return;
+  if (!EnsurePreResolvedDispatch(state)) return;
+  const int64_t handoff_count = state.range(0);
+  for (auto _ : state) {
+    SubmittedCompletion completion;
+    if (!HandleStatus(
+            state,
+            CrossQueuePingPongPayloadSubmitPublicFinalInline(
+                PayloadKind::kPreResolvedDispatch, handoff_count, &completion),
+            "cross-queue ping-pong pre-resolved dispatch "
+            "public-final inline submit failed")) {
+      break;
+    }
+    if (!WaitWithTimingPaused(
+            state, completion,
+            "cross-queue ping-pong pre-resolved dispatch public-final inline "
+            "wait failed")) {
+      break;
+    }
+  }
+  SetPayloadPingPongCounters(state, handoff_count,
+                             /*queue_submissions_per_sync=*/1 + handoff_count,
+                             /*public_completion_signals_per_sync=*/1);
+}
+
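+// Single-queue private-stream rows chain payload operations on one queue
+// through the private stream semaphore, removing cross-queue handoff cost
+// from the per-operation attribution.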
+BENCHMARK_DEFINE_F(QueueBenchmark,
+                   SameQueuePrivateStreamCopyChainPublicFinalInline)(
+    benchmark::State& state) {
+  if (!EnsurePayloadBuffers(state)) return;
+  const int64_t operation_count = state.range(0);
+  for (auto _ : state) {
+    if (!HandleStatus(state,
+                      SameQueuePrivateStreamPayloadPublicFinalInlineAndWait(
+                          PayloadKind::kCopy, operation_count),
+                      "same-queue private-stream copy chain failed")) {
+      break;
+    }
+  }
+  SetSingleStreamPayloadCounters(state, operation_count);
+}
+
+BENCHMARK_DEFINE_F(QueueBenchmark,
+                   SameQueuePrivateStreamCopyChainPublicFinalInlineSubmitOnly)(
+    benchmark::State& state) {
+  if (!EnsurePayloadBuffers(state)) return;
+  const int64_t operation_count = state.range(0);
+  for (auto _ : state) {
+    SubmittedCompletion completion;
+    if (!HandleStatus(state,
+                      SameQueuePrivateStreamPayloadSubmitPublicFinalInline(
+                          PayloadKind::kCopy, operation_count, &completion),
+                      "same-queue private-stream copy submit failed")) {
+      break;
+    }
+    if (!WaitWithTimingPaused(state, completion,
+                              "same-queue private-stream copy wait failed")) {
+      break;
+    }
+  }
+  SetSingleStreamPayloadCounters(state, operation_count);
+}
+
+BENCHMARK_DEFINE_F(QueueBenchmark,
+                   SameQueuePrivateStreamDispatchChainPublicFinalInline)(
+    benchmark::State& state) {
+  if (!EnsurePayloadBuffers(state)) return;
+  if (!EnsureDispatchExecutable(state)) return;
+  const int64_t operation_count = state.range(0);
+  for (auto _ : state) {
+    if (!HandleStatus(state,
+                      SameQueuePrivateStreamPayloadPublicFinalInlineAndWait(
+                          PayloadKind::kDispatch, operation_count),
+                      "same-queue private-stream dispatch chain failed")) {
+      break;
+    }
+  }
+  SetSingleStreamPayloadCounters(state, operation_count);
+}
+
+BENCHMARK_DEFINE_F(
+    QueueBenchmark,
+    SameQueuePrivateStreamDispatchChainPublicFinalInlineEpochCompletionFloor)(
+    benchmark::State& state) {
+  if (!EnsurePayloadBuffers(state)) return;
+  if (!EnsureDispatchExecutable(state)) return;
+  const int64_t operation_count = state.range(0);
+  for (auto _ : state) {
+    SubmittedCompletion completion;
+    if (!HandleStatus(state,
+                      SameQueuePrivateStreamPayloadSubmitPublicFinalInline(
+                          PayloadKind::kDispatch, operation_count, &completion),
+                      "same-queue private-stream dispatch submit failed")) {
+      break;
+    }
+    if (!HandleStatus(state, WaitForSubmittedProducerEpoch(completion),
+                      "same-queue private-stream dispatch producer epoch wait "
+                      "failed")) {
+      break;
+    }
+    if (!WaitWithTimingPaused(
+            state, completion,
+            "same-queue private-stream dispatch wait failed")) {
+      break;
+    }
+  }
+  SetSingleStreamPayloadCounters(state, operation_count);
+}
+
+BENCHMARK_DEFINE_F(
+    QueueBenchmark,
+    SameQueuePrivateStreamDispatchChainPublicFinalInlineSubmitOnly)(
+    benchmark::State& state) {
+  if (!EnsurePayloadBuffers(state)) return;
+  if (!EnsureDispatchExecutable(state)) return;
+  const int64_t operation_count = state.range(0);
+  for (auto _ : state) {
+    SubmittedCompletion completion;
+    if (!HandleStatus(state,
+                      SameQueuePrivateStreamPayloadSubmitPublicFinalInline(
+                          PayloadKind::kDispatch, operation_count, &completion),
+                      "same-queue private-stream dispatch submit failed")) {
+      break;
+    }
+    if (!WaitWithTimingPaused(
+            state, completion,
+            "same-queue private-stream dispatch wait failed")) {
+      break;
+    }
+  }
+  SetSingleStreamPayloadCounters(state, operation_count);
+}
+
+BENCHMARK_DEFINE_F(QueueBenchmark,
+                   SameQueuePrivateStreamNoopDispatchChainPublicFinalInline)(
+    benchmark::State& state) {
+  if (!EnsureDispatchExecutable(state)) return;
+  const int64_t operation_count = state.range(0);
+  for (auto _ : state) {
+    if (!HandleStatus(state,
+                      SameQueuePrivateStreamPayloadPublicFinalInlineAndWait(
+                          PayloadKind::kNoopDispatch, operation_count),
+                      "same-queue private-stream noop dispatch chain failed")) {
+      break;
+    }
+  }
+  SetSingleStreamPayloadCounters(state, operation_count);
+}
+
+BENCHMARK_DEFINE_F(
+    QueueBenchmark,
+    SameQueuePrivateStreamNoopDispatchChainPublicFinalInlineSubmitOnly)(
+    benchmark::State& state) {
+  if (!EnsureDispatchExecutable(state)) return;
+  const int64_t operation_count = state.range(0);
+  for (auto _ : state) {
+    SubmittedCompletion completion;
+    if (!HandleStatus(
+            state,
+            SameQueuePrivateStreamPayloadSubmitPublicFinalInline(
+                PayloadKind::kNoopDispatch, operation_count, &completion),
+            "same-queue private-stream noop dispatch submit "
+            "failed")) {
+      break;
+    }
+    if (!WaitWithTimingPaused(
+            state, completion,
+            "same-queue private-stream noop dispatch wait failed")) {
+      break;
+    }
+  }
+  SetSingleStreamPayloadCounters(state, operation_count);
+}
+
+BENCHMARK_DEFINE_F(
+    QueueBenchmark,
+    SameQueuePrivateStreamPreResolvedDispatchChainPublicFinalInline)(
+    benchmark::State& state) {
+  if (!EnsurePreResolvedDispatch(state)) return;
+  const int64_t operation_count = state.range(0);
+  for (auto _ : state) {
+    if (!HandleStatus(state,
+                      SameQueuePrivateStreamPayloadPublicFinalInlineAndWait(
+                          PayloadKind::kPreResolvedDispatch, operation_count),
+                      "same-queue private-stream pre-resolved dispatch chain "
+                      "failed")) {
+      break;
+    }
+  }
+  SetSingleStreamPayloadCounters(state, operation_count);
+}
+
+BENCHMARK_DEFINE_F(
+    QueueBenchmark,
+    SameQueuePrivateStreamPreResolvedDispatchChainPublicFinalInlineSubmitOnly)(
+    benchmark::State& state) {
+  if (!EnsurePreResolvedDispatch(state)) return;
+  const int64_t operation_count = state.range(0);
+  for (auto _ : state) {
+    SubmittedCompletion completion;
+    if (!HandleStatus(state,
+                      SameQueuePrivateStreamPayloadSubmitPublicFinalInline(
+                          PayloadKind::kPreResolvedDispatch, operation_count,
+                          &completion),
+                      "same-queue private-stream pre-resolved dispatch submit "
+                      "failed")) {
+      break;
+    }
+    if (!WaitWithTimingPaused(
+            state, completion,
+            "same-queue private-stream pre-resolved dispatch wait failed")) {
+      break;
+    }
+  }
+  SetSingleStreamPayloadCounters(state, operation_count);
+}
+
+BENCHMARK_DEFINE_F(QueueBenchmark,
+                   SameQueuePrivateStreamFillChainPublicFinalInline)(
+    benchmark::State& state) {
+  if (!EnsurePayloadBuffers(state)) return;
+  const int64_t operation_count = state.range(0);
+  for (auto _ : state) {
+    if (!HandleStatus(state,
+                      SameQueuePrivateStreamPayloadPublicFinalInlineAndWait(
+                          PayloadKind::kFill, operation_count),
+                      "same-queue private-stream fill chain failed")) {
+      break;
+    }
+  }
+  SetSingleStreamPayloadCounters(state, operation_count);
+}
+
+BENCHMARK_DEFINE_F(QueueBenchmark,
+                   SameQueuePrivateStreamFillChainPublicFinalInlineSubmitOnly)(
+    benchmark::State& state) {
+  if (!EnsurePayloadBuffers(state)) return;
+  const int64_t operation_count = state.range(0);
+  for (auto _ : state) {
+    SubmittedCompletion completion;
+    if (!HandleStatus(state,
+                      SameQueuePrivateStreamPayloadSubmitPublicFinalInline(
+                          PayloadKind::kFill, operation_count, &completion),
+                      "same-queue private-stream fill submit failed")) {
+      break;
+    }
+    if (!WaitWithTimingPaused(state, completion,
+                              "same-queue private-stream fill wait failed")) {
+      break;
+    }
+  }
+  SetSingleStreamPayloadCounters(state, operation_count);
+}
+
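+// Validate-only rows exercise just the host-side dispatch validation path
+// against the looked-up host queue; nothing is submitted, isolating
+// per-dispatch CPU cost from packet construction and completion.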
+BENCHMARK_DEFINE_F(QueueBenchmark,
+                   DispatchValidateOnly)(benchmark::State& state) {
+  if (!EnsurePayloadBuffers(state)) return;
+  if (!EnsureDispatchExecutable(state)) return;
+  iree_hal_amdgpu_host_queue_t* host_queue = nullptr;
+  if (!HandleStatus(state, LookupHostQueue(kQueue0, &host_queue),
+                    "failed to find queue 0")) {
+    return;
+  }
+
+  for (auto _ : state) {
+    iree_host_size_t operation_resource_count = 0;
+    iree_status_t status =
+        ValidateDispatchOnce(host_queue, &operation_resource_count);
+    if (!HandleStatus(state, status, "dispatch validation failed")) {
+      break;
+    }
+    benchmark::DoNotOptimize(operation_resource_count);
+  }
+  state.SetItemsProcessed(state.iterations());
+}
+
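+// Binding-count rows sweep the number of ABI bindings per dispatch to expose
+// per-binding validation and submission cost.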
+BENCHMARK_DEFINE_F(QueueBenchmark,
+                   DispatchBindingCountValidateOnly)(benchmark::State& state) {
+  const int64_t binding_count = state.range(0);
+  if (!EnsureBindingCountExecutable(state)) return;
+  if (!EnsureBindingCountBuffers(state, binding_count)) return;
+  iree_hal_amdgpu_host_queue_t* host_queue = nullptr;
+  if (!HandleStatus(state, LookupHostQueue(kQueue0, &host_queue),
+                    "failed to find queue 0")) {
+    return;
+  }
+
+  for (auto _ : state) {
+    iree_host_size_t operation_resource_count = 0;
+    iree_status_t status = ValidateBindingCountDispatchOnce(
+        host_queue, binding_count, &operation_resource_count);
+    if (!HandleStatus(state, status,
+                      "binding-count dispatch validation failed")) {
+      break;
+    }
+    benchmark::DoNotOptimize(operation_resource_count);
+  }
+  SetBindingCountCounters(state, binding_count);
+  state.SetItemsProcessed(state.iterations());
+}
+
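+// The next two rows compare the same binding-count dispatch submitted
+// directly via queue_dispatch against replaying a pre-recorded static
+// command buffer.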
+BENCHMARK_DEFINE_F(QueueBenchmark,
+                   SameQueueDispatchBindingCountPublicFinalInlineSubmitOnly)(
+    benchmark::State& state) {
+  const int64_t binding_count = state.range(0);
+  if (!EnsureBindingCountExecutable(state)) return;
+  if (!EnsureBindingCountBuffers(state, binding_count)) return;
+  for (auto _ : state) {
+    SubmittedCompletion completion;
+    if (!HandleStatus(state,
+                      BindingCountDispatchSubmitPublicFinalInline(binding_count,
+                                                                  &completion),
+                      "binding-count dispatch submit failed")) {
+      break;
+    }
+    if (!WaitWithTimingPaused(state, completion,
+                              "binding-count dispatch wait failed")) {
+      break;
+    }
+  }
+  SetBindingCountCounters(state, binding_count);
+  SetQueueSubmissionsProcessed(state, /*queue_submissions_per_sync=*/1);
+}
+
+BENCHMARK_DEFINE_F(
+    QueueBenchmark,
+    SameQueueCommandBufferBindingCountStaticPublicFinalInlineSubmitOnly)(
+    benchmark::State& state) {
+  const int64_t binding_count = state.range(0);
+  if (!EnsureBindingCountCommandBuffer(state, binding_count)) return;
+  for (auto _ : state) {
+    SubmittedCompletion completion;
+    if (!HandleStatus(state,
+                      BindingCountCommandBufferSubmitPublicFinalInline(
+                          binding_count, &completion),
+                      "binding-count command-buffer submit failed")) {
+      break;
+    }
+    if (!WaitWithTimingPaused(state, completion,
+                              "binding-count command-buffer wait failed")) {
+      break;
+    }
+  }
+  SetBindingCountCounters(state, binding_count);
+  SetQueueSubmissionsProcessed(state, /*queue_submissions_per_sync=*/1);
+}
+
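+// Records a reusable dispatch-chain command buffer once, replays it each
+// iteration, and reports the resulting AQL program shape (block occupancy,
+// packet counts, kernarg strategy mix) via SetCommandBufferProgramCounters.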
+BENCHMARK_DEFINE_F(QueueBenchmark,
+                   SameQueueCommandBufferDispatchChainStaticPublicFinalInline)(
+    benchmark::State& state) {
+  const int64_t operation_count = state.range(0);
+  const int64_t binding_count = state.range(1);
+  if (!EnsureBindingCountExecutable(state)) return;
+  if (!EnsureBindingCountBuffers(state, binding_count)) return;
+
+  iree_hal_command_buffer_t* command_buffer = nullptr;
+  if (!HandleStatus(state,
+                    RecordBindingCountDispatchChainCommandBuffer(
+                        operation_count, binding_count, &command_buffer),
+                    "failed to record command-buffer dispatch chain")) {
+    return;
+  }
+  for (auto _ : state) {
+    SubmittedCompletion completion;
+    if (!HandleStatus(
+            state,
+            BindingCountDispatchChainCommandBufferSubmitPublicFinalInline(
+                command_buffer, &completion),
+            "command-buffer dispatch chain submit failed")) {
+      break;
+    }
+    if (!HandleStatus(state,
+                      Wait(completion.semaphore, completion.payload_value),
+                      "command-buffer dispatch chain wait failed")) {
+      break;
+    }
+  }
+  SetCommandBufferProgramCounters(state, command_buffer, operation_count,
+                                  binding_count);
+  iree_hal_command_buffer_release(command_buffer);
+}
+
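+// Profile guardrail rows rerun representative batches with a profiling
+// session attached to the discard sink, so profiling overhead shows up as a
+// delta against the matching unprofiled rows above.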
+BENCHMARK_DEFINE_F(QueueBenchmark, ProfileGuardrailBarrierBatch20FinalWait)(
+    benchmark::State& state) {
+  RunProfileGuardrailFinalWaitBenchmark(
+      state, kProfileGuardrailOperationCount, kProfileGuardrailOperationCount,
+      "profile guardrail barrier batch flush failed",
+      "profile guardrail barrier batch end failed", [&]() {
+        return HandleStatus(
+            state,
+            SameQueueBarrierBatchAndWait(kProfileGuardrailOperationCount),
+            "profile guardrail barrier batch failed");
+      });
+}
+
+BENCHMARK_DEFINE_F(QueueBenchmark, ProfileGuardrailBarrierBatch20SubmitOnly)(
+    benchmark::State& state) {
+  RunProfileGuardrailSubmitOnlyBenchmark(
+      state, kProfileGuardrailOperationCount, kProfileGuardrailOperationCount,
+      "profile guardrail barrier batch wait failed",
+      "profile guardrail barrier batch flush failed",
+      "profile guardrail barrier batch end failed",
+      [&](SubmittedCompletion* completion) {
+        return HandleStatus(state,
+                            SameQueueBarrierBatchSubmit(
+                                kProfileGuardrailOperationCount, completion),
+                            "profile guardrail barrier batch submit failed");
+      });
+}
+
+BENCHMARK_DEFINE_F(QueueBenchmark, ProfileGuardrailCopyBatch20FinalWait)(
+    benchmark::State& state) {
+  if (!EnsurePayloadBuffers(state)) return;
+  RunProfileGuardrailFinalWaitBenchmark(
+      state, kProfileGuardrailOperationCount, kProfileGuardrailOperationCount,
+      "profile guardrail copy batch flush failed",
+      "profile guardrail copy batch end failed", [&]() {
+        return HandleStatus(
+            state,
+            SameQueuePrivateStreamPayloadPublicFinalInlineAndWait(
+                PayloadKind::kCopy, kProfileGuardrailOperationCount),
+            "profile guardrail copy batch failed");
+      });
+}
+
+BENCHMARK_DEFINE_F(QueueBenchmark, ProfileGuardrailCopyBatch20SubmitOnly)(
+    benchmark::State& state) {
+  if (!EnsurePayloadBuffers(state)) return;
+  RunProfileGuardrailSubmitOnlyBenchmark(
+      state, kProfileGuardrailOperationCount, kProfileGuardrailOperationCount,
+      "profile guardrail copy batch wait failed",
+      "profile guardrail copy batch flush failed",
+      "profile guardrail copy batch end failed",
+      [&](SubmittedCompletion* completion) {
+        return HandleStatus(
+            state,
+            SameQueuePrivateStreamPayloadSubmitPublicFinalInline(
+                PayloadKind::kCopy, kProfileGuardrailOperationCount,
+                completion),
+            "profile guardrail copy batch submit failed");
+      });
+}
+
+BENCHMARK_DEFINE_F(QueueBenchmark, ProfileGuardrailFillBatch20FinalWait)(
+    benchmark::State& state) {
+  if (!EnsurePayloadBuffers(state)) return;
+  RunProfileGuardrailFinalWaitBenchmark(
+      state, kProfileGuardrailOperationCount, kProfileGuardrailOperationCount,
+      "profile guardrail fill batch flush failed",
+      "profile guardrail fill batch end failed", [&]() {
+        return HandleStatus(
+            state,
+            SameQueuePrivateStreamPayloadPublicFinalInlineAndWait(
+                PayloadKind::kFill, kProfileGuardrailOperationCount),
+            "profile guardrail fill batch failed");
+      });
+}
+
+BENCHMARK_DEFINE_F(QueueBenchmark, ProfileGuardrailFillBatch20SubmitOnly)(
+    benchmark::State& state) {
+  if (!EnsurePayloadBuffers(state)) return;
+  RunProfileGuardrailSubmitOnlyBenchmark(
+      state, kProfileGuardrailOperationCount, kProfileGuardrailOperationCount,
+      "profile guardrail fill batch wait failed",
+      "profile guardrail fill batch flush failed",
+      "profile guardrail fill batch end failed",
+      [&](SubmittedCompletion* completion) {
+        return HandleStatus(
+            state,
+            SameQueuePrivateStreamPayloadSubmitPublicFinalInline(
+                PayloadKind::kFill, kProfileGuardrailOperationCount,
+                completion),
+            "profile guardrail fill batch submit failed");
+      });
+}
+
+BENCHMARK_DEFINE_F(QueueBenchmark, ProfileGuardrailDispatchBatch20FinalWait)(
+    benchmark::State& state) {
+  if (!EnsurePayloadBuffers(state)) return;
+  if (!EnsureDispatchExecutable(state)) return;
+  RunProfileGuardrailFinalWaitBenchmark(
+      state, kProfileGuardrailOperationCount, kProfileGuardrailOperationCount,
+      "profile guardrail dispatch batch flush failed",
+      "profile guardrail dispatch batch end failed", [&]() {
+        return HandleStatus(
+            state,
+            SameQueuePrivateStreamPayloadPublicFinalInlineAndWait(
+                PayloadKind::kDispatch, kProfileGuardrailOperationCount),
+            "profile guardrail dispatch batch failed");
+      });
+}
+
+BENCHMARK_DEFINE_F(QueueBenchmark, ProfileGuardrailDispatchBatch20SubmitOnly)(
+    benchmark::State& state) {
+  if (!EnsurePayloadBuffers(state)) return;
+  if (!EnsureDispatchExecutable(state)) return;
+  RunProfileGuardrailSubmitOnlyBenchmark(
+      state, kProfileGuardrailOperationCount, kProfileGuardrailOperationCount,
+      "profile guardrail dispatch batch wait failed",
+      "profile guardrail dispatch batch flush failed",
+      "profile guardrail dispatch batch end failed",
+      [&](SubmittedCompletion* completion) {
+        return HandleStatus(
+            state,
+            SameQueuePrivateStreamPayloadSubmitPublicFinalInline(
+                PayloadKind::kDispatch, kProfileGuardrailOperationCount,
+                completion),
+            "profile guardrail dispatch batch submit failed");
+      });
+}
+
+BENCHMARK_DEFINE_F(QueueBenchmark,
+                   ProfileGuardrailCommandBufferDispatchChain20FinalWait)(
+    benchmark::State& state) {
+  if (!EnsureBindingCountExecutable(state)) return;
+  if (!EnsureBindingCountBuffers(state, kProfileGuardrailBindingCount)) return;
+
+  iree_hal_command_buffer_t* command_buffer = nullptr;
+  if (!HandleStatus(state,
+                    RecordBindingCountDispatchChainCommandBuffer(
+                        kProfileGuardrailOperationCount,
+                        kProfileGuardrailBindingCount, &command_buffer),
+                    "failed to record profile guardrail command buffer")) {
+    return;
+  }
+  RunProfileGuardrailFinalWaitBenchmark(
+      state, /*queue_submissions_per_sync=*/1, kProfileGuardrailOperationCount,
+      "profile guardrail command-buffer flush failed",
+      "profile guardrail command-buffer end failed", [&]() {
+        SubmittedCompletion completion;
+        if (!HandleStatus(
+                state,
+                BindingCountDispatchChainCommandBufferSubmitPublicFinalInline(
+                    command_buffer, &completion),
+                "profile guardrail command-buffer submit failed")) {
+          return false;
+        }
+        return HandleStatus(
+            state, Wait(completion.semaphore, completion.payload_value),
+            "profile guardrail command-buffer wait failed");
+      });
+  iree_hal_command_buffer_release(command_buffer);
+}
+
+BENCHMARK_DEFINE_F(QueueBenchmark,
+                   ProfileGuardrailCommandBufferDispatchChain20SubmitOnly)(
+    benchmark::State& state) {
+  if (!EnsureBindingCountExecutable(state)) return;
+  if (!EnsureBindingCountBuffers(state, kProfileGuardrailBindingCount)) return;
+
+  iree_hal_command_buffer_t* command_buffer = nullptr;
+  if (!HandleStatus(state,
+                    RecordBindingCountDispatchChainCommandBuffer(
+                        kProfileGuardrailOperationCount,
+                        kProfileGuardrailBindingCount, &command_buffer),
+                    "failed to record profile guardrail command buffer")) {
+    return;
+  }
+  RunProfileGuardrailSubmitOnlyBenchmark(
+      state, /*queue_submissions_per_sync=*/1, kProfileGuardrailOperationCount,
+      "profile guardrail command-buffer wait failed",
+      "profile guardrail command-buffer flush failed",
+      "profile guardrail command-buffer end failed",
+      [&](SubmittedCompletion* completion) {
+        return HandleStatus(
+            state,
+            BindingCountDispatchChainCommandBufferSubmitPublicFinalInline(
+                command_buffer, completion),
+            "profile guardrail command-buffer submit failed");
+      });
+  iree_hal_command_buffer_release(command_buffer);
+}
+
+BENCHMARK_DEFINE_F(
+    QueueBenchmark,
+    SameQueueCommandBufferDispatchChainStaticPublicFinalInlineSubmitOnly)(
+    benchmark::State& state) {
+  const int64_t operation_count = state.range(0);
+  const int64_t binding_count = state.range(1);
+  if (!EnsureBindingCountExecutable(state)) return;
+  if (!EnsureBindingCountBuffers(state, binding_count)) return;
+
+  iree_hal_command_buffer_t* command_buffer = nullptr;
+  if (!HandleStatus(state,
+                    RecordBindingCountDispatchChainCommandBuffer(
+                        operation_count, binding_count, &command_buffer),
+                    "failed to record command-buffer dispatch chain")) {
+    return;
+  }
+  for (auto _ : state) {
+    SubmittedCompletion completion;
+    if (!HandleStatus(
+            state,
+            BindingCountDispatchChainCommandBufferSubmitPublicFinalInline(
+                command_buffer, &completion),
+            "command-buffer dispatch chain submit failed")) {
+      break;
+    }
+    if (!WaitWithTimingPaused(state, completion,
+                              "command-buffer dispatch chain wait failed")) {
+      break;
+    }
+  }
+  SetCommandBufferProgramCounters(state, command_buffer, operation_count,
+                                  binding_count);
+  iree_hal_command_buffer_release(command_buffer);
+}
+
+BENCHMARK_DEFINE_F(QueueBenchmark,
+                   WaitBeforeSignalChain)(benchmark::State& state) {
+  if (!EnsureQueueAvailable(state, kQueue1)) return;
+  for (auto _ : state) {
+    if (!HandleStatus(state, WaitBeforeSignalChainAndWait(),
+                      "wait-before-signal chain failed")) {
+      break;
+    }
+  }
+  SetQueueSubmissionsProcessed(state, /*queue_submissions_per_sync=*/2);
+}
+
+void ApplyProfileGuardrailModes(benchmark::Benchmark* benchmark) {
+  benchmark->ArgName("profile_mode");
+  benchmark->Arg(static_cast<int64_t>(ProfileGuardrailMode::kDisabled));
+  benchmark->Arg(
+      static_cast<int64_t>(ProfileGuardrailMode::kQueueDeviceEvents));
+  benchmark->Arg(static_cast<int64_t>(ProfileGuardrailMode::kDispatchEvents));
+  benchmark->Arg(static_cast<int64_t>(
+      ProfileGuardrailMode::kQueueDeviceAndDispatchEvents));
+  benchmark->Iterations(kProfileGuardrailIterations);
+}
+
+BENCHMARK_REGISTER_F(QueueBenchmark, SameQueueBarrierWait)
+    ->UseRealTime()
+    ->Unit(benchmark::kMicrosecond);
+BENCHMARK_REGISTER_F(QueueBenchmark, QueueAllocaWarmTlsfSubmitOnly)
+    ->Arg(4096)
+    ->Arg(65536)
+    ->ArgName("allocation_size")
+    ->UseRealTime()
+    ->Unit(benchmark::kMicrosecond);
+BENCHMARK_REGISTER_F(QueueBenchmark, QueueAllocaForcedTlsfGrowthSubmitOnly)
+    ->Arg(4096)
+    ->Arg(65536)
+    ->ArgName("allocation_size")
+    ->Iterations(100)
+    ->UseRealTime()
+    ->Unit(benchmark::kMicrosecond);
+BENCHMARK_REGISTER_F(QueueBenchmark, HostCallBlockingWait)
+    ->UseRealTime()
+    ->Unit(benchmark::kMicrosecond);
+BENCHMARK_REGISTER_F(QueueBenchmark, HostCallBlockingRelaxedWait)
+    ->UseRealTime()
+    ->Unit(benchmark::kMicrosecond);
+BENCHMARK_REGISTER_F(QueueBenchmark, HostCallNonBlockingWait)
+    ->UseRealTime()
+    ->Unit(benchmark::kMicrosecond);
+BENCHMARK_REGISTER_F(QueueBenchmark, HostCallBlockingBatch20FinalWait)
+    ->UseRealTime()
+    ->Unit(benchmark::kMicrosecond);
+BENCHMARK_REGISTER_F(QueueBenchmark, HostCallBlockingRelaxedBatch20FinalWait)
+    ->UseRealTime()
+    ->Unit(benchmark::kMicrosecond);
+BENCHMARK_REGISTER_F(QueueBenchmark, HostCallNonBlockingBatch20FinalWait)
+    ->UseRealTime()
+    ->Unit(benchmark::kMicrosecond);
+BENCHMARK_REGISTER_F(QueueBenchmark, SameQueueBarrierBatch20FinalWait)
+    ->UseRealTime()
+    ->Unit(benchmark::kMicrosecond);
+BENCHMARK_REGISTER_F(QueueBenchmark, SameQueueBarrierBatchFinalWait)
+    ->Arg(20)
+    ->Arg(100)
+    ->Arg(1000)
+    ->ArgName("batch_count")
+    ->UseRealTime()
+    ->Unit(benchmark::kMicrosecond);
+BENCHMARK_REGISTER_F(QueueBenchmark, SameQueueBarrierBatchSubmitOnly)
+    ->Arg(20)
+    ->Arg(100)
+    ->Arg(1000)
+    ->ArgName("batch_count")
+    ->UseRealTime()
+    ->Unit(benchmark::kMicrosecond);
+BENCHMARK_REGISTER_F(QueueBenchmark, SameQueueEpochChain20)
+    ->UseRealTime()
+    ->Unit(benchmark::kMicrosecond);
+BENCHMARK_REGISTER_F(QueueBenchmark, SameQueueEpochChain)
+    ->Arg(20)
+    ->Arg(100)
+    ->Arg(1000)
+    ->ArgName("batch_count")
+    ->UseRealTime()
+    ->Unit(benchmark::kMicrosecond);
+BENCHMARK_REGISTER_F(QueueBenchmark, SameQueueEpochChainSubmitOnly)
+    ->Arg(20)
+    ->Arg(100)
+    ->Arg(1000)
+    ->ArgName("batch_count")
+    ->UseRealTime()
+    ->Unit(benchmark::kMicrosecond);
+BENCHMARK_REGISTER_F(QueueBenchmark, CrossQueueAlreadyCompletedWait)
+    ->UseRealTime()
+    ->Unit(benchmark::kMicrosecond);
+BENCHMARK_REGISTER_F(QueueBenchmark, CrossQueueBarrierValue)
+    ->UseRealTime()
+    ->Unit(benchmark::kMicrosecond);
+BENCHMARK_REGISTER_F(QueueBenchmark, CrossQueueBarrierValueBatch20FinalWait)
+    ->UseRealTime()
+    ->Unit(benchmark::kMicrosecond);
+BENCHMARK_REGISTER_F(QueueBenchmark, CrossQueueBarrierValueBatchFinalWait)
+    ->Arg(20)
+    ->Arg(100)
+    ->Arg(1000)
+    ->ArgName("handoff_count")
+    ->UseRealTime()
+    ->Unit(benchmark::kMicrosecond);
+BENCHMARK_REGISTER_F(QueueBenchmark, CrossQueueBarrierValueBatchSubmitOnly)
+    ->Arg(20)
+    ->Arg(100)
+    ->Arg(1000)
+    ->ArgName("handoff_count")
+    ->UseRealTime()
+    ->Unit(benchmark::kMicrosecond);
+BENCHMARK_REGISTER_F(QueueBenchmark, CrossQueuePingPongChain20)
+    ->UseRealTime()
+    ->Unit(benchmark::kMicrosecond);
+BENCHMARK_REGISTER_F(QueueBenchmark, CrossQueuePingPongChain)
+    ->Arg(20)
+    ->Arg(100)
+    ->Arg(1000)
+    ->ArgName("handoff_count")
+    ->UseRealTime()
+    ->Unit(benchmark::kMicrosecond);
+BENCHMARK_REGISTER_F(QueueBenchmark, CrossQueuePingPongChainSubmitOnly)
+    ->Arg(20)
+    ->Arg(100)
+    ->Arg(1000)
+    ->ArgName("handoff_count")
+    ->UseRealTime()
+    ->Unit(benchmark::kMicrosecond);
+BENCHMARK_REGISTER_F(QueueBenchmark, CrossQueuePingPongPublicFinalInline)
+    ->Arg(2)
+    ->Arg(32)
+    ->Arg(512)
+    ->Arg(2048)
+    ->ArgName("handoff_count")
+    ->UseRealTime()
+    ->Unit(benchmark::kMicrosecond);
+BENCHMARK_REGISTER_F(QueueBenchmark,
+                     CrossQueuePingPongPublicFinalInlineSubmitOnly)
+    ->Arg(2)
+    ->Arg(32)
+    ->Arg(512)
+    ->Arg(2048)
+    ->ArgName("handoff_count")
+    ->UseRealTime()
+    ->Unit(benchmark::kMicrosecond);
+BENCHMARK_REGISTER_F(QueueBenchmark, CrossQueuePingPongPublicFinalSeparate)
+    ->Arg(2)
+    ->Arg(32)
+    ->Arg(512)
+    ->Arg(2048)
+    ->ArgName("handoff_count")
+    ->UseRealTime()
+    ->Unit(benchmark::kMicrosecond);
+BENCHMARK_REGISTER_F(QueueBenchmark,
+                     CrossQueuePingPongPublicFinalSeparateSubmitOnly)
+    ->Arg(2)
+    ->Arg(32)
+    ->Arg(512)
+    ->Arg(2048)
+    ->ArgName("handoff_count")
+    ->UseRealTime()
+    ->Unit(benchmark::kMicrosecond);
+BENCHMARK_REGISTER_F(QueueBenchmark, CrossQueuePingPongCopyPublicFinalInline)
+    ->Arg(2)
+    ->Arg(32)
+    ->Arg(512)
+    ->Arg(2048)
+    ->ArgName("handoff_count")
+    ->UseRealTime()
+    ->Unit(benchmark::kMicrosecond);
+BENCHMARK_REGISTER_F(QueueBenchmark,
+                     CrossQueuePingPongCopyPublicFinalInlineSubmitOnly)
+    ->Arg(2)
+    ->Arg(32)
+    ->Arg(512)
+    ->Arg(2048)
+    ->ArgName("handoff_count")
+    ->UseRealTime()
+    ->Unit(benchmark::kMicrosecond);
+BENCHMARK_REGISTER_F(QueueBenchmark, CrossQueuePingPongFillPublicFinalInline)
+    ->Arg(2)
+    ->Arg(32)
+    ->Arg(512)
+    ->Arg(2048)
+    ->ArgName("handoff_count")
+    ->UseRealTime()
+    ->Unit(benchmark::kMicrosecond);
+BENCHMARK_REGISTER_F(QueueBenchmark,
+                     CrossQueuePingPongFillPublicFinalInlineSubmitOnly)
+    ->Arg(2)
+    ->Arg(32)
+    ->Arg(512)
+    ->Arg(2048)
+    ->ArgName("handoff_count")
+    ->UseRealTime()
+    ->Unit(benchmark::kMicrosecond);
+BENCHMARK_REGISTER_F(QueueBenchmark,
+                     CrossQueuePingPongDispatchPublicFinalInline)
+    ->Arg(2)
+    ->Arg(32)
+    ->Arg(512)
+    ->Arg(2048)
+    ->ArgName("handoff_count")
+    ->UseRealTime()
+    ->Unit(benchmark::kMicrosecond);
+BENCHMARK_REGISTER_F(
+    QueueBenchmark,
+    CrossQueuePingPongDispatchPublicFinalInlineEpochCompletionFloor)
+    ->Arg(2)
+    ->Arg(32)
+    ->Arg(512)
+    ->Arg(2048)
+    ->ArgName("handoff_count")
+    ->UseRealTime()
+    ->Unit(benchmark::kMicrosecond);
+BENCHMARK_REGISTER_F(QueueBenchmark,
+                     CrossQueuePingPongDispatchPublicFinalInlineSubmitOnly)
+    ->Arg(2)
+    ->Arg(32)
+    ->Arg(512)
+    ->Arg(2048)
+    ->ArgName("handoff_count")
+    ->UseRealTime()
+    ->Unit(benchmark::kMicrosecond);
+BENCHMARK_REGISTER_F(QueueBenchmark,
+                     CrossQueuePingPongNoopDispatchPublicFinalInline)
+    ->Arg(2)
+    ->Arg(32)
+    ->Arg(512)
+    ->Arg(2048)
+    ->ArgName("handoff_count")
+    ->UseRealTime()
+    ->Unit(benchmark::kMicrosecond);
+BENCHMARK_REGISTER_F(QueueBenchmark,
+                     CrossQueuePingPongNoopDispatchPublicFinalInlineSubmitOnly)
+    ->Arg(2)
+    ->Arg(32)
+    ->Arg(512)
+    ->Arg(2048)
+    ->ArgName("handoff_count")
+    ->UseRealTime()
+    ->Unit(benchmark::kMicrosecond);
+BENCHMARK_REGISTER_F(QueueBenchmark,
+                     CrossQueuePingPongPreResolvedDispatchPublicFinalInline)
+    ->Arg(2)
+    ->Arg(32)
+    ->Arg(512)
+    ->Arg(2048)
+    ->ArgName("handoff_count")
+    ->UseRealTime()
+    ->Unit(benchmark::kMicrosecond);
+BENCHMARK_REGISTER_F(
+    QueueBenchmark,
+    CrossQueuePingPongPreResolvedDispatchPublicFinalInlineSubmitOnly)
+    ->Arg(2)
+    ->Arg(32)
+    ->Arg(512)
+    ->Arg(2048)
+    ->ArgName("handoff_count")
+    ->UseRealTime()
+    ->Unit(benchmark::kMicrosecond);
+BENCHMARK_REGISTER_F(QueueBenchmark,
+                     SameQueuePrivateStreamCopyChainPublicFinalInline)
+    ->Arg(1)
+    ->Arg(20)
+    ->Arg(1000)
+    ->ArgName("operation_count")
+    ->UseRealTime()
+    ->Unit(benchmark::kMicrosecond);
+BENCHMARK_REGISTER_F(QueueBenchmark,
+                     SameQueuePrivateStreamCopyChainPublicFinalInlineSubmitOnly)
+    ->Arg(1)
+    ->Arg(20)
+    ->Arg(1000)
+    ->ArgName("operation_count")
+    ->UseRealTime()
+    ->Unit(benchmark::kMicrosecond);
+BENCHMARK_REGISTER_F(QueueBenchmark,
+                     SameQueuePrivateStreamDispatchChainPublicFinalInline)
+    ->Arg(1)
+    ->Arg(20)
+    ->Arg(1000)
+    ->ArgName("operation_count")
+    ->UseRealTime()
+    ->Unit(benchmark::kMicrosecond);
+BENCHMARK_REGISTER_F(
+    QueueBenchmark,
+    SameQueuePrivateStreamDispatchChainPublicFinalInlineEpochCompletionFloor)
+    ->Arg(1)
+    ->Arg(20)
+    ->Arg(1000)
+    ->ArgName("operation_count")
+    ->UseRealTime()
+    ->Unit(benchmark::kMicrosecond);
+BENCHMARK_REGISTER_F(
+    QueueBenchmark,
+    SameQueuePrivateStreamDispatchChainPublicFinalInlineSubmitOnly)
+    ->Arg(1)
+    ->Arg(20)
+    ->Arg(1000)
+    ->ArgName("operation_count")
+    ->UseRealTime()
+    ->Unit(benchmark::kMicrosecond);
+BENCHMARK_REGISTER_F(QueueBenchmark,
+                     SameQueuePrivateStreamNoopDispatchChainPublicFinalInline)
+    ->Arg(1)
+    ->Arg(20)
+    ->Arg(1000)
+    ->ArgName("operation_count")
+    ->UseRealTime()
+    ->Unit(benchmark::kMicrosecond);
+BENCHMARK_REGISTER_F(
+    QueueBenchmark,
+    SameQueuePrivateStreamNoopDispatchChainPublicFinalInlineSubmitOnly)
+    ->Arg(1)
+    ->Arg(20)
+    ->Arg(1000)
+    ->ArgName("operation_count")
+    ->UseRealTime()
+    ->Unit(benchmark::kMicrosecond);
+BENCHMARK_REGISTER_F(
+    QueueBenchmark,
+    SameQueuePrivateStreamPreResolvedDispatchChainPublicFinalInline)
+    ->Arg(1)
+    ->Arg(20)
+    ->Arg(1000)
+    ->ArgName("operation_count")
+    ->UseRealTime()
+    ->Unit(benchmark::kMicrosecond);
+BENCHMARK_REGISTER_F(
+    QueueBenchmark,
+    SameQueuePrivateStreamPreResolvedDispatchChainPublicFinalInlineSubmitOnly)
+    ->Arg(1)
+    ->Arg(20)
+    ->Arg(1000)
+    ->ArgName("operation_count")
+    ->UseRealTime()
+    ->Unit(benchmark::kMicrosecond);
+BENCHMARK_REGISTER_F(QueueBenchmark,
+                     SameQueuePrivateStreamFillChainPublicFinalInline)
+    ->Arg(1)
+    ->Arg(20)
+    ->Arg(1000)
+    ->ArgName("operation_count")
+    ->UseRealTime()
+    ->Unit(benchmark::kMicrosecond);
+BENCHMARK_REGISTER_F(QueueBenchmark,
+                     SameQueuePrivateStreamFillChainPublicFinalInlineSubmitOnly)
+    ->Arg(1)
+    ->Arg(20)
+    ->Arg(1000)
+    ->ArgName("operation_count")
+    ->UseRealTime()
+    ->Unit(benchmark::kMicrosecond);
+BENCHMARK_REGISTER_F(QueueBenchmark, DispatchValidateOnly)
+    ->UseRealTime()
+    ->Unit(benchmark::kNanosecond);
+BENCHMARK_REGISTER_F(QueueBenchmark, DispatchBindingCountValidateOnly)
+    ->Arg(0)
+    ->Arg(1)
+    ->Arg(4)
+    ->Arg(8)
+    ->Arg(9)
+    ->Arg(16)
+    ->Arg(17)
+    ->Arg(24)
+    ->Arg(25)
+    ->Arg(256)
+    ->ArgName("binding_count")
+    ->UseRealTime()
+    ->Unit(benchmark::kNanosecond);
+BENCHMARK_REGISTER_F(QueueBenchmark,
+                     SameQueueDispatchBindingCountPublicFinalInlineSubmitOnly)
+    ->Arg(0)
+    ->Arg(1)
+    ->Arg(4)
+    ->Arg(8)
+    ->Arg(9)
+    ->Arg(16)
+    ->Arg(17)
+    ->Arg(24)
+    ->Arg(25)
+    ->Arg(256)
+    ->ArgName("binding_count")
+    ->UseRealTime()
+    ->Unit(benchmark::kMicrosecond);
+BENCHMARK_REGISTER_F(
+    QueueBenchmark,
+    SameQueueCommandBufferBindingCountStaticPublicFinalInlineSubmitOnly)
+    ->Arg(0)
+    ->Arg(1)
+    ->Arg(4)
+    ->Arg(8)
+    ->Arg(9)
+    ->Arg(16)
+    ->Arg(17)
+    ->Arg(24)
+    ->Arg(25)
+    ->Arg(256)
+    ->ArgName("binding_count")
+    ->UseRealTime()
+    ->Unit(benchmark::kMicrosecond);
+BENCHMARK_REGISTER_F(QueueBenchmark,
+                     SameQueueCommandBufferDispatchChainStaticPublicFinalInline)
+    ->ArgsProduct({{1, 10, 100, 1000, 5000}, {0, 1, 4, 8, 16}})
+    ->ArgNames({"operation_count", "binding_count"})
+    ->UseRealTime()
+    ->Unit(benchmark::kMicrosecond);
+BENCHMARK_REGISTER_F(QueueBenchmark, ProfileGuardrailBarrierBatch20FinalWait)
+    ->Apply(ApplyProfileGuardrailModes)
+    ->UseRealTime()
+    ->Unit(benchmark::kMicrosecond);
+BENCHMARK_REGISTER_F(QueueBenchmark, ProfileGuardrailBarrierBatch20SubmitOnly)
+    ->Apply(ApplyProfileGuardrailModes)
+    ->UseRealTime()
+    ->Unit(benchmark::kMicrosecond);
+BENCHMARK_REGISTER_F(QueueBenchmark, ProfileGuardrailCopyBatch20FinalWait)
+    ->Apply(ApplyProfileGuardrailModes)
+    ->UseRealTime()
+    ->Unit(benchmark::kMicrosecond);
+BENCHMARK_REGISTER_F(QueueBenchmark, ProfileGuardrailCopyBatch20SubmitOnly)
+    ->Apply(ApplyProfileGuardrailModes)
+    ->UseRealTime()
+    ->Unit(benchmark::kMicrosecond);
+BENCHMARK_REGISTER_F(QueueBenchmark, ProfileGuardrailFillBatch20FinalWait)
+    ->Apply(ApplyProfileGuardrailModes)
+    ->UseRealTime()
+    ->Unit(benchmark::kMicrosecond);
+BENCHMARK_REGISTER_F(QueueBenchmark, ProfileGuardrailFillBatch20SubmitOnly)
+    ->Apply(ApplyProfileGuardrailModes)
+    ->UseRealTime()
+    ->Unit(benchmark::kMicrosecond);
+BENCHMARK_REGISTER_F(QueueBenchmark, ProfileGuardrailDispatchBatch20FinalWait)
+    ->Apply(ApplyProfileGuardrailModes)
+    ->UseRealTime()
+    ->Unit(benchmark::kMicrosecond);
+BENCHMARK_REGISTER_F(QueueBenchmark, ProfileGuardrailDispatchBatch20SubmitOnly)
+    ->Apply(ApplyProfileGuardrailModes)
+    ->UseRealTime()
+    ->Unit(benchmark::kMicrosecond);
+BENCHMARK_REGISTER_F(QueueBenchmark,
+                     ProfileGuardrailCommandBufferDispatchChain20FinalWait)
+    ->Apply(ApplyProfileGuardrailModes)
+    ->UseRealTime()
+    ->Unit(benchmark::kMicrosecond);
+BENCHMARK_REGISTER_F(QueueBenchmark,
+                     ProfileGuardrailCommandBufferDispatchChain20SubmitOnly)
+    ->Apply(ApplyProfileGuardrailModes)
+    ->UseRealTime()
+    ->Unit(benchmark::kMicrosecond);
+BENCHMARK_REGISTER_F(
+    QueueBenchmark,
+    SameQueueCommandBufferDispatchChainStaticPublicFinalInlineSubmitOnly)
+    ->ArgsProduct({{1, 10, 100, 1000, 5000}, {0, 1, 4, 8, 16}})
+    ->ArgNames({"operation_count", "binding_count"})
+    ->UseRealTime()
+    ->Unit(benchmark::kMicrosecond);
+BENCHMARK_REGISTER_F(QueueBenchmark, WaitBeforeSignalChain)
+    ->UseRealTime()
+    ->Unit(benchmark::kMicrosecond);
+
+}  // namespace
+
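+// Parses IREE flags first (tolerating flags it does not recognize), then hands
+// the remainder to the benchmark library. Example invocation (binary name
+// illustrative; --benchmark_filter is a standard benchmark-library flag):
+//   queue_benchmark --benchmark_filter='ProfileGuardrail.*'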
+int main(int argc, char** argv) {
+  iree_flags_parse_checked(IREE_FLAGS_PARSE_MODE_UNDEFINED_OK |
+                               IREE_FLAGS_PARSE_MODE_CONTINUE_AFTER_HELP,
+                           &argc, &argv);
+  benchmark::Initialize(&argc, argv);
+  if (benchmark::ReportUnrecognizedArguments(argc, argv)) return 1;
+  benchmark::RunSpecifiedBenchmarks();
+  benchmark::Shutdown();
+  QueueBenchmark::DeinitializeOnce();
+  return 0;
+}
diff --git a/runtime/src/iree/hal/drivers/amdgpu/util/queue_benchmark_testdata.mlir b/runtime/src/iree/hal/drivers/amdgpu/util/queue_benchmark_testdata.mlir
new file mode 100644
index 0000000..afbb8d2
--- /dev/null
+++ b/runtime/src/iree/hal/drivers/amdgpu/util/queue_benchmark_testdata.mlir
@@ -0,0 +1,470 @@
+// Benchmark executable with exports that vary only in HAL ABI binding count.
+//
+// The kernel bodies intentionally do no work. These exports are used to
+// measure queue_dispatch host-side validation, device pointer packing, and
+// reclaim bookkeeping as binding table width changes without mixing in device
+// memory traffic.
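+//
+// The binding_count arguments registered by the queue benchmark
+// (0, 1, 4, 8, 9, 16, 17, 24, 25, 256) map one-to-one onto the exports
+// declared below.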
+
+#layout_0 = #hal.pipeline.layout<constants = 0, bindings = []>
+
+#layout_1 = #hal.pipeline.layout<constants = 0, bindings = [
+  #hal.pipeline.binding<storage_buffer>
+]>
+
+#layout_4 = #hal.pipeline.layout<constants = 0, bindings = [
+  #hal.pipeline.binding<storage_buffer>,
+  #hal.pipeline.binding<storage_buffer>,
+  #hal.pipeline.binding<storage_buffer>,
+  #hal.pipeline.binding<storage_buffer>
+]>
+
+#layout_8 = #hal.pipeline.layout<constants = 0, bindings = [
+  #hal.pipeline.binding<storage_buffer>,
+  #hal.pipeline.binding<storage_buffer>,
+  #hal.pipeline.binding<storage_buffer>,
+  #hal.pipeline.binding<storage_buffer>,
+  #hal.pipeline.binding<storage_buffer>,
+  #hal.pipeline.binding<storage_buffer>,
+  #hal.pipeline.binding<storage_buffer>,
+  #hal.pipeline.binding<storage_buffer>
+]>
+
+#layout_9 = #hal.pipeline.layout<constants = 0, bindings = [
+  #hal.pipeline.binding<storage_buffer>,
+  #hal.pipeline.binding<storage_buffer>,
+  #hal.pipeline.binding<storage_buffer>,
+  #hal.pipeline.binding<storage_buffer>,
+  #hal.pipeline.binding<storage_buffer>,
+  #hal.pipeline.binding<storage_buffer>,
+  #hal.pipeline.binding<storage_buffer>,
+  #hal.pipeline.binding<storage_buffer>,
+  #hal.pipeline.binding<storage_buffer>
+]>
+
+#layout_16 = #hal.pipeline.layout<constants = 0, bindings = [
+  #hal.pipeline.binding<storage_buffer>,
+  #hal.pipeline.binding<storage_buffer>,
+  #hal.pipeline.binding<storage_buffer>,
+  #hal.pipeline.binding<storage_buffer>,
+  #hal.pipeline.binding<storage_buffer>,
+  #hal.pipeline.binding<storage_buffer>,
+  #hal.pipeline.binding<storage_buffer>,
+  #hal.pipeline.binding<storage_buffer>,
+  #hal.pipeline.binding<storage_buffer>,
+  #hal.pipeline.binding<storage_buffer>,
+  #hal.pipeline.binding<storage_buffer>,
+  #hal.pipeline.binding<storage_buffer>,
+  #hal.pipeline.binding<storage_buffer>,
+  #hal.pipeline.binding<storage_buffer>,
+  #hal.pipeline.binding<storage_buffer>,
+  #hal.pipeline.binding<storage_buffer>
+]>
+
+#layout_17 = #hal.pipeline.layout<constants = 0, bindings = [
+  #hal.pipeline.binding<storage_buffer>,
+  #hal.pipeline.binding<storage_buffer>,
+  #hal.pipeline.binding<storage_buffer>,
+  #hal.pipeline.binding<storage_buffer>,
+  #hal.pipeline.binding<storage_buffer>,
+  #hal.pipeline.binding<storage_buffer>,
+  #hal.pipeline.binding<storage_buffer>,
+  #hal.pipeline.binding<storage_buffer>,
+  #hal.pipeline.binding<storage_buffer>,
+  #hal.pipeline.binding<storage_buffer>,
+  #hal.pipeline.binding<storage_buffer>,
+  #hal.pipeline.binding<storage_buffer>,
+  #hal.pipeline.binding<storage_buffer>,
+  #hal.pipeline.binding<storage_buffer>,
+  #hal.pipeline.binding<storage_buffer>,
+  #hal.pipeline.binding<storage_buffer>,
+  #hal.pipeline.binding<storage_buffer>
+]>
+
+#layout_24 = #hal.pipeline.layout<constants = 0, bindings = [
+  #hal.pipeline.binding<storage_buffer>,
+  #hal.pipeline.binding<storage_buffer>,
+  #hal.pipeline.binding<storage_buffer>,
+  #hal.pipeline.binding<storage_buffer>,
+  #hal.pipeline.binding<storage_buffer>,
+  #hal.pipeline.binding<storage_buffer>,
+  #hal.pipeline.binding<storage_buffer>,
+  #hal.pipeline.binding<storage_buffer>,
+  #hal.pipeline.binding<storage_buffer>,
+  #hal.pipeline.binding<storage_buffer>,
+  #hal.pipeline.binding<storage_buffer>,
+  #hal.pipeline.binding<storage_buffer>,
+  #hal.pipeline.binding<storage_buffer>,
+  #hal.pipeline.binding<storage_buffer>,
+  #hal.pipeline.binding<storage_buffer>,
+  #hal.pipeline.binding<storage_buffer>,
+  #hal.pipeline.binding<storage_buffer>,
+  #hal.pipeline.binding<storage_buffer>,
+  #hal.pipeline.binding<storage_buffer>,
+  #hal.pipeline.binding<storage_buffer>,
+  #hal.pipeline.binding<storage_buffer>,
+  #hal.pipeline.binding<storage_buffer>,
+  #hal.pipeline.binding<storage_buffer>,
+  #hal.pipeline.binding<storage_buffer>
+]>
+
+#layout_25 = #hal.pipeline.layout<constants = 0, bindings = [
+  #hal.pipeline.binding<storage_buffer>,
+  #hal.pipeline.binding<storage_buffer>,
+  #hal.pipeline.binding<storage_buffer>,
+  #hal.pipeline.binding<storage_buffer>,
+  #hal.pipeline.binding<storage_buffer>,
+  #hal.pipeline.binding<storage_buffer>,
+  #hal.pipeline.binding<storage_buffer>,
+  #hal.pipeline.binding<storage_buffer>,
+  #hal.pipeline.binding<storage_buffer>,
+  #hal.pipeline.binding<storage_buffer>,
+  #hal.pipeline.binding<storage_buffer>,
+  #hal.pipeline.binding<storage_buffer>,
+  #hal.pipeline.binding<storage_buffer>,
+  #hal.pipeline.binding<storage_buffer>,
+  #hal.pipeline.binding<storage_buffer>,
+  #hal.pipeline.binding<storage_buffer>,
+  #hal.pipeline.binding<storage_buffer>,
+  #hal.pipeline.binding<storage_buffer>,
+  #hal.pipeline.binding<storage_buffer>,
+  #hal.pipeline.binding<storage_buffer>,
+  #hal.pipeline.binding<storage_buffer>,
+  #hal.pipeline.binding<storage_buffer>,
+  #hal.pipeline.binding<storage_buffer>,
+  #hal.pipeline.binding<storage_buffer>,
+  #hal.pipeline.binding<storage_buffer>
+]>
+
+#layout_256 = #hal.pipeline.layout<constants = 0, bindings = [
+  #hal.pipeline.binding<storage_buffer>,
+  #hal.pipeline.binding<storage_buffer>,
+  #hal.pipeline.binding<storage_buffer>,
+  #hal.pipeline.binding<storage_buffer>,
+  #hal.pipeline.binding<storage_buffer>,
+  #hal.pipeline.binding<storage_buffer>,
+  #hal.pipeline.binding<storage_buffer>,
+  #hal.pipeline.binding<storage_buffer>,
+  #hal.pipeline.binding<storage_buffer>,
+  #hal.pipeline.binding<storage_buffer>,
+  #hal.pipeline.binding<storage_buffer>,
+  #hal.pipeline.binding<storage_buffer>,
+  #hal.pipeline.binding<storage_buffer>,
+  #hal.pipeline.binding<storage_buffer>,
+  #hal.pipeline.binding<storage_buffer>,
+  #hal.pipeline.binding<storage_buffer>,
+  #hal.pipeline.binding<storage_buffer>,
+  #hal.pipeline.binding<storage_buffer>,
+  #hal.pipeline.binding<storage_buffer>,
+  #hal.pipeline.binding<storage_buffer>,
+  #hal.pipeline.binding<storage_buffer>,
+  #hal.pipeline.binding<storage_buffer>,
+  #hal.pipeline.binding<storage_buffer>,
+  #hal.pipeline.binding<storage_buffer>,
+  #hal.pipeline.binding<storage_buffer>,
+  #hal.pipeline.binding<storage_buffer>,
+  #hal.pipeline.binding<storage_buffer>,
+  #hal.pipeline.binding<storage_buffer>,
+  #hal.pipeline.binding<storage_buffer>,
+  #hal.pipeline.binding<storage_buffer>,
+  #hal.pipeline.binding<storage_buffer>,
+  #hal.pipeline.binding<storage_buffer>,
+  #hal.pipeline.binding<storage_buffer>,
+  #hal.pipeline.binding<storage_buffer>,
+  #hal.pipeline.binding<storage_buffer>,
+  #hal.pipeline.binding<storage_buffer>,
+  #hal.pipeline.binding<storage_buffer>,
+  #hal.pipeline.binding<storage_buffer>,
+  #hal.pipeline.binding<storage_buffer>,
+  #hal.pipeline.binding<storage_buffer>,
+  #hal.pipeline.binding<storage_buffer>,
+  #hal.pipeline.binding<storage_buffer>,
+  #hal.pipeline.binding<storage_buffer>,
+  #hal.pipeline.binding<storage_buffer>,
+  #hal.pipeline.binding<storage_buffer>,
+  #hal.pipeline.binding<storage_buffer>,
+  #hal.pipeline.binding<storage_buffer>,
+  #hal.pipeline.binding<storage_buffer>,
+  #hal.pipeline.binding<storage_buffer>,
+  #hal.pipeline.binding<storage_buffer>,
+  #hal.pipeline.binding<storage_buffer>,
+  #hal.pipeline.binding<storage_buffer>,
+  #hal.pipeline.binding<storage_buffer>,
+  #hal.pipeline.binding<storage_buffer>,
+  #hal.pipeline.binding<storage_buffer>,
+  #hal.pipeline.binding<storage_buffer>,
+  #hal.pipeline.binding<storage_buffer>,
+  #hal.pipeline.binding<storage_buffer>,
+  #hal.pipeline.binding<storage_buffer>,
+  #hal.pipeline.binding<storage_buffer>,
+  #hal.pipeline.binding<storage_buffer>,
+  #hal.pipeline.binding<storage_buffer>,
+  #hal.pipeline.binding<storage_buffer>,
+  #hal.pipeline.binding<storage_buffer>,
+  #hal.pipeline.binding<storage_buffer>,
+  #hal.pipeline.binding<storage_buffer>,
+  #hal.pipeline.binding<storage_buffer>,
+  #hal.pipeline.binding<storage_buffer>,
+  #hal.pipeline.binding<storage_buffer>,
+  #hal.pipeline.binding<storage_buffer>,
+  #hal.pipeline.binding<storage_buffer>,
+  #hal.pipeline.binding<storage_buffer>,
+  #hal.pipeline.binding<storage_buffer>,
+  #hal.pipeline.binding<storage_buffer>,
+  #hal.pipeline.binding<storage_buffer>,
+  #hal.pipeline.binding<storage_buffer>,
+  #hal.pipeline.binding<storage_buffer>,
+  #hal.pipeline.binding<storage_buffer>,
+  #hal.pipeline.binding<storage_buffer>,
+  #hal.pipeline.binding<storage_buffer>,
+  #hal.pipeline.binding<storage_buffer>,
+  #hal.pipeline.binding<storage_buffer>,
+  #hal.pipeline.binding<storage_buffer>,
+  #hal.pipeline.binding<storage_buffer>,
+  #hal.pipeline.binding<storage_buffer>,
+  #hal.pipeline.binding<storage_buffer>,
+  #hal.pipeline.binding<storage_buffer>,
+  #hal.pipeline.binding<storage_buffer>,
+  #hal.pipeline.binding<storage_buffer>,
+  #hal.pipeline.binding<storage_buffer>,
+  #hal.pipeline.binding<storage_buffer>,
+  #hal.pipeline.binding<storage_buffer>,
+  #hal.pipeline.binding<storage_buffer>,
+  #hal.pipeline.binding<storage_buffer>,
+  #hal.pipeline.binding<storage_buffer>,
+  #hal.pipeline.binding<storage_buffer>,
+  #hal.pipeline.binding<storage_buffer>,
+  #hal.pipeline.binding<storage_buffer>,
+  #hal.pipeline.binding<storage_buffer>,
+  #hal.pipeline.binding<storage_buffer>,
+  #hal.pipeline.binding<storage_buffer>,
+  #hal.pipeline.binding<storage_buffer>,
+  #hal.pipeline.binding<storage_buffer>,
+  #hal.pipeline.binding<storage_buffer>,
+  #hal.pipeline.binding<storage_buffer>,
+  #hal.pipeline.binding<storage_buffer>,
+  #hal.pipeline.binding<storage_buffer>,
+  #hal.pipeline.binding<storage_buffer>,
+  #hal.pipeline.binding<storage_buffer>,
+  #hal.pipeline.binding<storage_buffer>,
+  #hal.pipeline.binding<storage_buffer>,
+  #hal.pipeline.binding<storage_buffer>,
+  #hal.pipeline.binding<storage_buffer>,
+  #hal.pipeline.binding<storage_buffer>,
+  #hal.pipeline.binding<storage_buffer>,
+  #hal.pipeline.binding<storage_buffer>,
+  #hal.pipeline.binding<storage_buffer>,
+  #hal.pipeline.binding<storage_buffer>,
+  #hal.pipeline.binding<storage_buffer>,
+  #hal.pipeline.binding<storage_buffer>,
+  #hal.pipeline.binding<storage_buffer>,
+  #hal.pipeline.binding<storage_buffer>,
+  #hal.pipeline.binding<storage_buffer>,
+  #hal.pipeline.binding<storage_buffer>,
+  #hal.pipeline.binding<storage_buffer>,
+  #hal.pipeline.binding<storage_buffer>,
+  #hal.pipeline.binding<storage_buffer>,
+  #hal.pipeline.binding<storage_buffer>,
+  #hal.pipeline.binding<storage_buffer>,
+  #hal.pipeline.binding<storage_buffer>,
+  #hal.pipeline.binding<storage_buffer>,
+  #hal.pipeline.binding<storage_buffer>,
+  #hal.pipeline.binding<storage_buffer>,
+  #hal.pipeline.binding<storage_buffer>,
+  #hal.pipeline.binding<storage_buffer>,
+  #hal.pipeline.binding<storage_buffer>,
+  #hal.pipeline.binding<storage_buffer>,
+  #hal.pipeline.binding<storage_buffer>,
+  #hal.pipeline.binding<storage_buffer>,
+  #hal.pipeline.binding<storage_buffer>,
+  #hal.pipeline.binding<storage_buffer>,
+  #hal.pipeline.binding<storage_buffer>,
+  #hal.pipeline.binding<storage_buffer>,
+  #hal.pipeline.binding<storage_buffer>,
+  #hal.pipeline.binding<storage_buffer>,
+  #hal.pipeline.binding<storage_buffer>,
+  #hal.pipeline.binding<storage_buffer>,
+  #hal.pipeline.binding<storage_buffer>,
+  #hal.pipeline.binding<storage_buffer>,
+  #hal.pipeline.binding<storage_buffer>,
+  #hal.pipeline.binding<storage_buffer>,
+  #hal.pipeline.binding<storage_buffer>,
+  #hal.pipeline.binding<storage_buffer>,
+  #hal.pipeline.binding<storage_buffer>,
+  #hal.pipeline.binding<storage_buffer>,
+  #hal.pipeline.binding<storage_buffer>,
+  #hal.pipeline.binding<storage_buffer>,
+  #hal.pipeline.binding<storage_buffer>,
+  #hal.pipeline.binding<storage_buffer>,
+  #hal.pipeline.binding<storage_buffer>,
+  #hal.pipeline.binding<storage_buffer>,
+  #hal.pipeline.binding<storage_buffer>,
+  #hal.pipeline.binding<storage_buffer>,
+  #hal.pipeline.binding<storage_buffer>,
+  #hal.pipeline.binding<storage_buffer>,
+  #hal.pipeline.binding<storage_buffer>,
+  #hal.pipeline.binding<storage_buffer>,
+  #hal.pipeline.binding<storage_buffer>,
+  #hal.pipeline.binding<storage_buffer>,
+  #hal.pipeline.binding<storage_buffer>,
+  #hal.pipeline.binding<storage_buffer>,
+  #hal.pipeline.binding<storage_buffer>,
+  #hal.pipeline.binding<storage_buffer>,
+  #hal.pipeline.binding<storage_buffer>,
+  #hal.pipeline.binding<storage_buffer>,
+  #hal.pipeline.binding<storage_buffer>,
+  #hal.pipeline.binding<storage_buffer>,
+  #hal.pipeline.binding<storage_buffer>,
+  #hal.pipeline.binding<storage_buffer>,
+  #hal.pipeline.binding<storage_buffer>,
+  #hal.pipeline.binding<storage_buffer>,
+  #hal.pipeline.binding<storage_buffer>,
+  #hal.pipeline.binding<storage_buffer>,
+  #hal.pipeline.binding<storage_buffer>,
+  #hal.pipeline.binding<storage_buffer>,
+  #hal.pipeline.binding<storage_buffer>,
+  #hal.pipeline.binding<storage_buffer>,
+  #hal.pipeline.binding<storage_buffer>,
+  #hal.pipeline.binding<storage_buffer>,
+  #hal.pipeline.binding<storage_buffer>,
+  #hal.pipeline.binding<storage_buffer>,
+  #hal.pipeline.binding<storage_buffer>,
+  #hal.pipeline.binding<storage_buffer>,
+  #hal.pipeline.binding<storage_buffer>,
+  #hal.pipeline.binding<storage_buffer>,
+  #hal.pipeline.binding<storage_buffer>,
+  #hal.pipeline.binding<storage_buffer>,
+  #hal.pipeline.binding<storage_buffer>,
+  #hal.pipeline.binding<storage_buffer>,
+  #hal.pipeline.binding<storage_buffer>,
+  #hal.pipeline.binding<storage_buffer>,
+  #hal.pipeline.binding<storage_buffer>,
+  #hal.pipeline.binding<storage_buffer>,
+  #hal.pipeline.binding<storage_buffer>,
+  #hal.pipeline.binding<storage_buffer>,
+  #hal.pipeline.binding<storage_buffer>,
+  #hal.pipeline.binding<storage_buffer>,
+  #hal.pipeline.binding<storage_buffer>,
+  #hal.pipeline.binding<storage_buffer>,
+  #hal.pipeline.binding<storage_buffer>,
+  #hal.pipeline.binding<storage_buffer>,
+  #hal.pipeline.binding<storage_buffer>,
+  #hal.pipeline.binding<storage_buffer>,
+  #hal.pipeline.binding<storage_buffer>,
+  #hal.pipeline.binding<storage_buffer>,
+  #hal.pipeline.binding<storage_buffer>,
+  #hal.pipeline.binding<storage_buffer>,
+  #hal.pipeline.binding<storage_buffer>,
+  #hal.pipeline.binding<storage_buffer>,
+  #hal.pipeline.binding<storage_buffer>,
+  #hal.pipeline.binding<storage_buffer>,
+  #hal.pipeline.binding<storage_buffer>,
+  #hal.pipeline.binding<storage_buffer>,
+  #hal.pipeline.binding<storage_buffer>,
+  #hal.pipeline.binding<storage_buffer>,
+  #hal.pipeline.binding<storage_buffer>,
+  #hal.pipeline.binding<storage_buffer>,
+  #hal.pipeline.binding<storage_buffer>,
+  #hal.pipeline.binding<storage_buffer>,
+  #hal.pipeline.binding<storage_buffer>,
+  #hal.pipeline.binding<storage_buffer>,
+  #hal.pipeline.binding<storage_buffer>,
+  #hal.pipeline.binding<storage_buffer>,
+  #hal.pipeline.binding<storage_buffer>,
+  #hal.pipeline.binding<storage_buffer>,
+  #hal.pipeline.binding<storage_buffer>,
+  #hal.pipeline.binding<storage_buffer>,
+  #hal.pipeline.binding<storage_buffer>,
+  #hal.pipeline.binding<storage_buffer>,
+  #hal.pipeline.binding<storage_buffer>,
+  #hal.pipeline.binding<storage_buffer>,
+  #hal.pipeline.binding<storage_buffer>,
+  #hal.pipeline.binding<storage_buffer>,
+  #hal.pipeline.binding<storage_buffer>,
+  #hal.pipeline.binding<storage_buffer>,
+  #hal.pipeline.binding<storage_buffer>,
+  #hal.pipeline.binding<storage_buffer>,
+  #hal.pipeline.binding<storage_buffer>,
+  #hal.pipeline.binding<storage_buffer>,
+  #hal.pipeline.binding<storage_buffer>,
+  #hal.pipeline.binding<storage_buffer>,
+  #hal.pipeline.binding<storage_buffer>,
+  #hal.pipeline.binding<storage_buffer>,
+  #hal.pipeline.binding<storage_buffer>,
+  #hal.pipeline.binding<storage_buffer>,
+  #hal.pipeline.binding<storage_buffer>
+]>
+
+hal.executable.source public @queue_benchmark {
+  hal.executable.export public @binding_count_0 ordinal(0) layout(#layout_0) count(%arg0: !hal.device) -> (index, index, index) {
+    %c1 = arith.constant 1 : index
+    hal.return %c1, %c1, %c1 : index, index, index
+  } attributes {workgroup_size = [1 : index, 1 : index, 1 : index]}
+  hal.executable.export public @binding_count_1 ordinal(1) layout(#layout_1) count(%arg0: !hal.device) -> (index, index, index) {
+    %c1 = arith.constant 1 : index
+    hal.return %c1, %c1, %c1 : index, index, index
+  } attributes {workgroup_size = [1 : index, 1 : index, 1 : index]}
+  hal.executable.export public @binding_count_4 ordinal(2) layout(#layout_4) count(%arg0: !hal.device) -> (index, index, index) {
+    %c1 = arith.constant 1 : index
+    hal.return %c1, %c1, %c1 : index, index, index
+  } attributes {workgroup_size = [1 : index, 1 : index, 1 : index]}
+  hal.executable.export public @binding_count_8 ordinal(3) layout(#layout_8) count(%arg0: !hal.device) -> (index, index, index) {
+    %c1 = arith.constant 1 : index
+    hal.return %c1, %c1, %c1 : index, index, index
+  } attributes {workgroup_size = [1 : index, 1 : index, 1 : index]}
+  hal.executable.export public @binding_count_9 ordinal(4) layout(#layout_9) count(%arg0: !hal.device) -> (index, index, index) {
+    %c1 = arith.constant 1 : index
+    hal.return %c1, %c1, %c1 : index, index, index
+  } attributes {workgroup_size = [1 : index, 1 : index, 1 : index]}
+  hal.executable.export public @binding_count_16 ordinal(5) layout(#layout_16) count(%arg0: !hal.device) -> (index, index, index) {
+    %c1 = arith.constant 1 : index
+    hal.return %c1, %c1, %c1 : index, index, index
+  } attributes {workgroup_size = [1 : index, 1 : index, 1 : index]}
+  hal.executable.export public @binding_count_17 ordinal(6) layout(#layout_17) count(%arg0: !hal.device) -> (index, index, index) {
+    %c1 = arith.constant 1 : index
+    hal.return %c1, %c1, %c1 : index, index, index
+  } attributes {workgroup_size = [1 : index, 1 : index, 1 : index]}
+  hal.executable.export public @binding_count_24 ordinal(7) layout(#layout_24) count(%arg0: !hal.device) -> (index, index, index) {
+    %c1 = arith.constant 1 : index
+    hal.return %c1, %c1, %c1 : index, index, index
+  } attributes {workgroup_size = [1 : index, 1 : index, 1 : index]}
+  hal.executable.export public @binding_count_25 ordinal(8) layout(#layout_25) count(%arg0: !hal.device) -> (index, index, index) {
+    %c1 = arith.constant 1 : index
+    hal.return %c1, %c1, %c1 : index, index, index
+  } attributes {workgroup_size = [1 : index, 1 : index, 1 : index]}
+  hal.executable.export public @binding_count_256 ordinal(9) layout(#layout_256) count(%arg0: !hal.device) -> (index, index, index) {
+    %c1 = arith.constant 1 : index
+    hal.return %c1, %c1, %c1 : index, index, index
+  } attributes {workgroup_size = [1 : index, 1 : index, 1 : index]}
+  builtin.module {
+    func.func @binding_count_0() {
+      return
+    }
+    func.func @binding_count_1() {
+      return
+    }
+    func.func @binding_count_4() {
+      return
+    }
+    func.func @binding_count_8() {
+      return
+    }
+    func.func @binding_count_9() {
+      return
+    }
+    func.func @binding_count_16() {
+      return
+    }
+    func.func @binding_count_17() {
+      return
+    }
+    func.func @binding_count_24() {
+      return
+    }
+    func.func @binding_count_25() {
+      return
+    }
+    func.func @binding_count_256() {
+      return
+    }
+  }
+}
diff --git a/runtime/src/iree/hal/drivers/amdgpu/util/queue_upload_ring.c b/runtime/src/iree/hal/drivers/amdgpu/util/queue_upload_ring.c
new file mode 100644
index 0000000..1925ff7
--- /dev/null
+++ b/runtime/src/iree/hal/drivers/amdgpu/util/queue_upload_ring.c
@@ -0,0 +1,137 @@
+// Copyright 2026 The IREE Authors
+//
+// Licensed under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+
+#include "iree/hal/drivers/amdgpu/util/queue_upload_ring.h"
+
+//===----------------------------------------------------------------------===//
+// iree_hal_amdgpu_queue_upload_ring_t
+//===----------------------------------------------------------------------===//
+
+iree_status_t iree_hal_amdgpu_queue_upload_ring_initialize(
+    const iree_hal_amdgpu_libhsa_t* libhsa,
+    const iree_hal_amdgpu_queue_upload_ring_memory_t* memory,
+    uint32_t min_capacity, iree_hal_amdgpu_queue_upload_ring_t* out_ring) {
+  IREE_ASSERT_ARGUMENT(libhsa);
+  IREE_ASSERT_ARGUMENT(memory);
+  IREE_ASSERT_ARGUMENT(out_ring);
+  IREE_TRACE_ZONE_BEGIN(z0);
+  IREE_TRACE_ZONE_APPEND_VALUE_I64(z0, min_capacity);
+  memset(out_ring, 0, sizeof(*out_ring));
+
+  IREE_ASSERT(memory->memory_pool.handle,
+              "queue upload ring memory descriptor must provide a memory pool");
+  IREE_ASSERT(memory->access_agents,
+              "queue upload ring memory descriptor must provide access agents");
+  IREE_ASSERT(memory->access_agent_count > 0 &&
+                  memory->access_agent_count <= UINT32_MAX,
+              "queue upload ring access agent count must fit in HSA's "
+              "uint32_t");
+  if (!min_capacity || !iree_host_size_is_power_of_two(min_capacity)) {
+    IREE_RETURN_AND_END_ZONE_IF_ERROR(
+        z0, iree_make_status(IREE_STATUS_INVALID_ARGUMENT,
+                             "queue upload ring capacity must be a non-zero "
+                             "power of two; got %u",
+                             min_capacity));
+  }
+  IREE_ASSERT(memory->publication.mode ==
+                  IREE_HAL_AMDGPU_KERNARG_RING_PUBLICATION_MODE_NONE ||
+              memory->publication.mode ==
+                  IREE_HAL_AMDGPU_KERNARG_RING_PUBLICATION_MODE_HDP_FLUSH);
+  IREE_ASSERT(memory->publication.mode !=
+                  IREE_HAL_AMDGPU_KERNARG_RING_PUBLICATION_MODE_HDP_FLUSH ||
+              memory->publication.hdp_mem_flush_control);
+
+  size_t alloc_granule = 0;
+  IREE_RETURN_AND_END_ZONE_IF_ERROR(
+      z0,
+      iree_hsa_amd_memory_pool_get_info(
+          IREE_LIBHSA(libhsa), memory->memory_pool,
+          HSA_AMD_MEMORY_POOL_INFO_RUNTIME_ALLOC_GRANULE, &alloc_granule),
+      "querying HSA_AMD_MEMORY_POOL_INFO_RUNTIME_ALLOC_GRANULE for queue "
+      "upload ring allocation");
+  if (IREE_UNLIKELY(!alloc_granule ||
+                    !iree_host_size_is_power_of_two(alloc_granule))) {
+    IREE_RETURN_AND_END_ZONE_IF_ERROR(
+        z0, iree_make_status(IREE_STATUS_FAILED_PRECONDITION,
+                             "queue upload memory pool allocation granule must "
+                             "be a non-zero power of two (got %zu)",
+                             alloc_granule));
+  }
+
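+  // Round the requested capacity up to the pool's allocation granule so the
+  // allocation below is granule-sized; both values are powers of two, so
+  // doubling preserves the power-of-two invariant asserted afterward.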
+  uint32_t capacity = min_capacity;
+  while ((uint64_t)capacity < alloc_granule && capacity <= UINT32_MAX / 2) {
+    capacity <<= 1;
+  }
+  IREE_ASSERT(capacity >= min_capacity);
+  IREE_ASSERT(iree_host_size_is_power_of_two(capacity));
+  if (IREE_UNLIKELY(!iree_host_size_has_alignment(capacity, alloc_granule))) {
+    IREE_RETURN_AND_END_ZONE_IF_ERROR(
+        z0, iree_make_status(IREE_STATUS_FAILED_PRECONDITION,
+                             "queue upload ring capacity %u bytes is not "
+                             "aligned to pool allocation granule %zu",
+                             capacity, alloc_granule));
+  }
+
+  IREE_RETURN_AND_END_ZONE_IF_ERROR(
+      z0,
+      iree_hsa_amd_memory_pool_allocate(
+          IREE_LIBHSA(libhsa), memory->memory_pool, capacity,
+          HSA_AMD_MEMORY_POOL_STANDARD_FLAG, (void**)&out_ring->base),
+      "allocating queue upload ring of %u bytes", capacity);
+  iree_status_t status = iree_hsa_amd_agents_allow_access(
+      IREE_LIBHSA(libhsa), (uint32_t)memory->access_agent_count,
+      memory->access_agents, /*flags=*/NULL, out_ring->base);
+  if (!iree_status_is_ok(status)) {
+    status = iree_status_join(status, iree_hsa_amd_memory_pool_free(
+                                          IREE_LIBHSA(libhsa), out_ring->base));
+    memset(out_ring, 0, sizeof(*out_ring));
+    IREE_RETURN_AND_END_ZONE_IF_ERROR(
+        z0, status,
+        "making queue upload ring allocation visible to %" PRIhsz " HSA agents",
+        memory->access_agent_count);
+  }
+
+  out_ring->device_base = (uint64_t)(uintptr_t)out_ring->base;
+  out_ring->capacity = capacity;
+  out_ring->mask = capacity - 1;
+  out_ring->publication = memory->publication;
+  iree_atomic_store(&out_ring->write_position, 0, iree_memory_order_relaxed);
+  iree_atomic_store(&out_ring->read_position, 0, iree_memory_order_relaxed);
+
+  // Fault in the host mapping on the initialization path instead of the first
+  // submitter that writes into the ring.
+  out_ring->base[0] = 0;
+
+  IREE_TRACE_ZONE_APPEND_VALUE_I64(z0, capacity);
+  IREE_TRACE_ZONE_END(z0);
+  return iree_ok_status();
+}
+
+void iree_hal_amdgpu_queue_upload_ring_deinitialize(
+    const iree_hal_amdgpu_libhsa_t* libhsa,
+    iree_hal_amdgpu_queue_upload_ring_t* ring) {
+  IREE_ASSERT_ARGUMENT(libhsa);
+  IREE_ASSERT_ARGUMENT(ring);
+  IREE_TRACE_ZONE_BEGIN(z0);
+
+  const uint64_t write = (uint64_t)iree_atomic_load(&ring->write_position,
+                                                    iree_memory_order_relaxed);
+  const uint64_t read = (uint64_t)iree_atomic_load(&ring->read_position,
+                                                   iree_memory_order_relaxed);
+  IREE_ASSERT(write == read,
+              "queue upload ring has %" PRIu64
+              " unreleased bytes at deinit (write=%" PRIu64 ", read=%" PRIu64
+              ")",
+              write - read, write, read);
+
+  if (ring->base) {
+    iree_hal_amdgpu_hsa_cleanup_assert_success(
+        iree_hsa_amd_memory_pool_free_raw(libhsa, ring->base));
+  }
+  memset(ring, 0, sizeof(*ring));
+
+  IREE_TRACE_ZONE_END(z0);
+}
diff --git a/runtime/src/iree/hal/drivers/amdgpu/util/queue_upload_ring.h b/runtime/src/iree/hal/drivers/amdgpu/util/queue_upload_ring.h
new file mode 100644
index 0000000..767d28a
--- /dev/null
+++ b/runtime/src/iree/hal/drivers/amdgpu/util/queue_upload_ring.h
@@ -0,0 +1,222 @@
+// Copyright 2026 The IREE Authors
+//
+// Licensed under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+
+#ifndef IREE_HAL_DRIVERS_AMDGPU_UTIL_QUEUE_UPLOAD_RING_H_
+#define IREE_HAL_DRIVERS_AMDGPU_UTIL_QUEUE_UPLOAD_RING_H_
+
+#include "iree/base/api.h"
+#include "iree/base/internal/atomics.h"
+#include "iree/hal/drivers/amdgpu/util/kernarg_ring.h"
+#include "iree/hal/drivers/amdgpu/util/libhsa.h"
+
+#ifdef __cplusplus
+extern "C" {
+#endif  // __cplusplus
+
+// Memory backing policy for a queue-owned upload ring.
+//
+// The descriptor is consumed during ring initialization and may reference
+// caller-owned stack storage for |access_agents|. Upload rings use the same
+// host-write publication policy as queue-owned kernargs because both carry
+// host-produced records consumed by device packets after AQL publication.
+typedef struct iree_hal_amdgpu_queue_upload_ring_memory_t {
+  // HSA memory pool used for the ring allocation.
+  hsa_amd_memory_pool_t memory_pool;
+  // Agents granted explicit access to the ring allocation.
+  const hsa_agent_t* access_agents;
+  // Number of entries in |access_agents|.
+  iree_host_size_t access_agent_count;
+  // Host-write publication mechanism for this memory pool.
+  iree_hal_amdgpu_kernarg_ring_publication_t publication;
+} iree_hal_amdgpu_queue_upload_ring_memory_t;
+
+// Byte span reserved from a queue upload ring.
+typedef struct iree_hal_amdgpu_queue_upload_span_t {
+  // Host pointer used by the CPU submission path to populate the record.
+  uint8_t* host_ptr;
+  // Device-visible pointer value corresponding to |host_ptr|.
+  uint64_t device_ptr;
+  // Number of bytes reserved in this span.
+  iree_host_size_t length;
+  // Logical write position after this allocation, used for epoch reclaim.
+  uint64_t end_position;
+} iree_hal_amdgpu_queue_upload_span_t;
+
+// Per-queue byte-granular upload allocator for small device-visible control
+// records such as binding pointer tables and device-side fixup arguments.
+//
+// Thread safety:
+//   allocate() is multi-producer safe: any number of threads may call it
+//   concurrently (space is claimed with a CAS on write_position).
+//   reclaim() must be called from a single thread (the notification drain).
+//
+// Backpressure contract:
+//   Callers use can_allocate() as a non-mutating admission check under the
+//   queue submission mutex, then allocate() the same request before publishing
+//   packets. Alignment padding and wrap skips are in-flight bytes until the
+//   submission that caused them retires.
+//
+// Memory ordering:
+//   write_position only claims space and uses relaxed ordering. Device
+//   consumers cannot observe records until the caller publishes host writes and
+//   commits the dependent AQL packet headers. read_position uses release
+//   (drain) / acquire (allocators) so submitters observe reclaimed capacity.
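+//
+// Usage sketch (illustrative; error handling and the surrounding queue
+// submission lock are elided):
+//   if (!iree_hal_amdgpu_queue_upload_ring_can_allocate(ring, size, align)) {
+//     // Defer the submission; retry after reclaim() advances read_position.
+//   }
+//   iree_hal_amdgpu_queue_upload_span_t span =
+//       iree_hal_amdgpu_queue_upload_ring_allocate(ring, size, align);
+//   memcpy(span.host_ptr, record, size);
+//   iree_hal_amdgpu_queue_upload_ring_publish_host_writes(ring);
+//   // Reference span.device_ptr from the AQL packet, publish the packet, and
+//   // later pass span.end_position to reclaim() once the GPU retires it.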
+typedef struct iree_hal_amdgpu_queue_upload_ring_t {
+  // Base pointer to the HSA memory-pool allocation.
+  uint8_t* base;
+
+  // Device-visible pointer value corresponding to |base|.
+  uint64_t device_base;
+
+  // Power-of-two capacity in bytes.
+  uint32_t capacity;
+  // capacity - 1, for masking logical positions to physical ring offsets.
+  uint32_t mask;
+
+  // Host-write publication mechanism for this ring.
+  iree_hal_amdgpu_kernarg_ring_publication_t publication;
+
+  // Monotonically increasing write position in bytes.
+  iree_atomic_int64_t write_position;
+
+  // Read position in bytes. Advanced by notification drain when the GPU
+  // completes work that referenced upload records.
+  iree_atomic_int64_t read_position;
+} iree_hal_amdgpu_queue_upload_ring_t;
+
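+// Rounds |position| up to the next multiple of |alignment|, which must be a
+// non-zero power of two (e.g. align_position(7, 8) == 8, align_position(8, 8)
+// == 8).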
+static inline uint64_t iree_hal_amdgpu_queue_upload_ring_align_position(
+    uint64_t position, iree_host_size_t alignment) {
+  return (position + alignment - 1) & ~((uint64_t)alignment - 1);
+}
+
+// Initializes the upload ring by allocating at least |min_capacity| bytes from
+// |memory->memory_pool|. |min_capacity| must be a non-zero power of two.
+iree_status_t iree_hal_amdgpu_queue_upload_ring_initialize(
+    const iree_hal_amdgpu_libhsa_t* libhsa,
+    const iree_hal_amdgpu_queue_upload_ring_memory_t* memory,
+    uint32_t min_capacity, iree_hal_amdgpu_queue_upload_ring_t* out_ring);
+
+// Deinitializes the upload ring and frees the backing HSA allocation. All
+// in-flight work must have completed and been reclaimed before calling.
+void iree_hal_amdgpu_queue_upload_ring_deinitialize(
+    const iree_hal_amdgpu_libhsa_t* libhsa,
+    iree_hal_amdgpu_queue_upload_ring_t* ring);
+
+// Publishes host writes to upload records before packet headers referencing
+// them become visible to the command processor.
+static inline void iree_hal_amdgpu_queue_upload_ring_publish_host_writes(
+    const iree_hal_amdgpu_queue_upload_ring_t* ring) {
+  if (ring->publication.mode ==
+      IREE_HAL_AMDGPU_KERNARG_RING_PUBLICATION_MODE_NONE) {
+    return;
+  }
+  IREE_ASSERT(ring->publication.mode ==
+              IREE_HAL_AMDGPU_KERNARG_RING_PUBLICATION_MODE_HDP_FLUSH);
+  IREE_ASSERT(ring->publication.hdp_mem_flush_control);
+  iree_hal_amdgpu_kernarg_ring_host_write_fence();
+  *ring->publication.hdp_mem_flush_control = 1u;
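+  // Read back so the flush write is known to have reached the HDP register
+  // before any subsequent packet-header store (assumed HDP flush protocol:
+  // write 1, then read to serialize).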
+  (void)*ring->publication.hdp_mem_flush_control;
+}
+
+// Returns true if |length| bytes with |alignment| can currently be allocated.
+// The check accounts for the same alignment padding and wrap skip used by
+// allocate() and does not mutate the ring.
+static inline bool iree_hal_amdgpu_queue_upload_ring_can_allocate(
+    iree_hal_amdgpu_queue_upload_ring_t* ring, iree_host_size_t length,
+    iree_host_size_t alignment) {
+  if (IREE_UNLIKELY(length == 0 || length > ring->capacity || alignment == 0 ||
+                    alignment > ring->capacity ||
+                    !iree_host_size_is_power_of_two(alignment))) {
+    return false;
+  }
+  uint64_t first_byte = (uint64_t)iree_atomic_load(&ring->write_position,
+                                                   iree_memory_order_relaxed);
+  first_byte =
+      iree_hal_amdgpu_queue_upload_ring_align_position(first_byte, alignment);
+  const uint64_t tail_length =
+      (uint64_t)ring->capacity - (first_byte & ring->mask);
+  if (length > tail_length) {
+    first_byte += tail_length;
+    first_byte =
+        iree_hal_amdgpu_queue_upload_ring_align_position(first_byte, alignment);
+  }
+  const uint64_t next_write_position = first_byte + length;
+  const uint64_t read_position = (uint64_t)iree_atomic_load(
+      &ring->read_position, iree_memory_order_acquire);
+  return next_write_position - read_position <= ring->capacity;
+}
+
+// Allocates |length| contiguous bytes aligned to |alignment| from the ring.
+//
+// REQUIRES: The caller must have already proved capacity with can_allocate()
+// under the same queue submission admission critical section.
+static inline iree_hal_amdgpu_queue_upload_span_t
+iree_hal_amdgpu_queue_upload_ring_allocate(
+    iree_hal_amdgpu_queue_upload_ring_t* ring, iree_host_size_t length,
+    iree_host_size_t alignment) {
+  iree_hal_amdgpu_queue_upload_span_t span = {0};
+  if (IREE_UNLIKELY(length == 0 || length > ring->capacity || alignment == 0 ||
+                    alignment > ring->capacity ||
+                    !iree_host_size_is_power_of_two(alignment))) {
+    return span;
+  }
+
+  int64_t observed_write_position =
+      iree_atomic_load(&ring->write_position, iree_memory_order_relaxed);
+  uint64_t first_byte = 0;
+  uint64_t next_write_position = 0;
+  for (;;) {
+    first_byte = (uint64_t)observed_write_position;
+    first_byte =
+        iree_hal_amdgpu_queue_upload_ring_align_position(first_byte, alignment);
+    const uint64_t tail_length =
+        (uint64_t)ring->capacity - (first_byte & ring->mask);
+    if (length > tail_length) {
+      first_byte += tail_length;
+      first_byte = iree_hal_amdgpu_queue_upload_ring_align_position(first_byte,
+                                                                    alignment);
+    }
+    next_write_position = first_byte + length;
+    const uint64_t read_position = (uint64_t)iree_atomic_load(
+        &ring->read_position, iree_memory_order_acquire);
+    if (IREE_UNLIKELY(next_write_position - read_position > ring->capacity)) {
+      return span;
+    }
+    if (iree_atomic_compare_exchange_weak(
+            &ring->write_position, &observed_write_position,
+            (int64_t)next_write_position, iree_memory_order_relaxed,
+            iree_memory_order_relaxed)) {
+      break;
+    }
+  }
+
+  const uint32_t physical_offset = (uint32_t)(first_byte & ring->mask);
+  span.host_ptr = ring->base + physical_offset;
+  span.device_ptr = ring->device_base + physical_offset;
+  span.length = length;
+  span.end_position = next_write_position;
+  return span;
+}
+
+// Reclaims all upload bytes up to |new_read_position|. Called by notification
+// drain after confirming the GPU has completed work that referenced the span.
+static inline void iree_hal_amdgpu_queue_upload_ring_reclaim(
+    iree_hal_amdgpu_queue_upload_ring_t* ring, uint64_t new_read_position) {
+  IREE_ASSERT(new_read_position >=
+              (uint64_t)iree_atomic_load(&ring->read_position,
+                                         iree_memory_order_relaxed));
+  IREE_ASSERT(new_read_position <=
+              (uint64_t)iree_atomic_load(&ring->write_position,
+                                         iree_memory_order_relaxed));
+  iree_atomic_store(&ring->read_position, (int64_t)new_read_position,
+                    iree_memory_order_release);
+}
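+
+// Illustrative producer flow (a sketch; error handling elided, |source| is a
+// hypothetical host staging buffer, and the 64-byte alignment is arbitrary):
+//
+//   if (iree_hal_amdgpu_queue_upload_ring_can_allocate(ring, length, 64)) {
+//     iree_hal_amdgpu_queue_upload_span_t span =
+//         iree_hal_amdgpu_queue_upload_ring_allocate(ring, length, 64);
+//     memcpy(span.host_ptr, source, length);
+//     iree_hal_amdgpu_queue_upload_ring_publish_host_writes(ring);
+//     // ... enqueue device work that reads span.device_ptr ...
+//     // Once the GPU completes, the notification drain reclaims the bytes:
+//     //   iree_hal_amdgpu_queue_upload_ring_reclaim(ring, span.end_position);
+//   }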
+
+#ifdef __cplusplus
+}  // extern "C"
+#endif  // __cplusplus
+
+#endif  // IREE_HAL_DRIVERS_AMDGPU_UTIL_QUEUE_UPLOAD_RING_H_
diff --git a/runtime/src/iree/hal/drivers/amdgpu/util/queue_upload_ring_test.cc b/runtime/src/iree/hal/drivers/amdgpu/util/queue_upload_ring_test.cc
new file mode 100644
index 0000000..e11a9d5
--- /dev/null
+++ b/runtime/src/iree/hal/drivers/amdgpu/util/queue_upload_ring_test.cc
@@ -0,0 +1,130 @@
+// Copyright 2026 The IREE Authors
+//
+// Licensed under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+
+#include "iree/hal/drivers/amdgpu/util/queue_upload_ring.h"
+
+#include <cstdint>
+
+#include "iree/testing/gtest.h"
+
+namespace {
+
+static void InitializeTestRing(uint8_t* storage, uint32_t capacity,
+                               iree_hal_amdgpu_queue_upload_ring_t* out_ring) {
+  out_ring->base = storage;
+  out_ring->device_base = 0x10000000ull;
+  out_ring->capacity = capacity;
+  out_ring->mask = capacity - 1;
+  iree_atomic_store(&out_ring->write_position, 0, iree_memory_order_relaxed);
+  iree_atomic_store(&out_ring->read_position, 0, iree_memory_order_relaxed);
+}
+
+TEST(QueueUploadRingTest, AllocatesAlignedSpansAndReclaims) {
+  alignas(64) uint8_t storage[64] = {};
+  iree_hal_amdgpu_queue_upload_ring_t ring = {};
+  InitializeTestRing(storage, IREE_ARRAYSIZE(storage), &ring);
+
+  EXPECT_TRUE(iree_hal_amdgpu_queue_upload_ring_can_allocate(
+      &ring, /*length=*/7, /*alignment=*/1));
+  iree_hal_amdgpu_queue_upload_span_t first =
+      iree_hal_amdgpu_queue_upload_ring_allocate(&ring, /*length=*/7,
+                                                 /*alignment=*/1);
+  EXPECT_EQ(first.host_ptr, &storage[0]);
+  EXPECT_EQ(first.device_ptr, 0x10000000ull);
+  EXPECT_EQ(first.length, 7u);
+  EXPECT_EQ(first.end_position, 7u);
+
+  iree_hal_amdgpu_queue_upload_span_t second =
+      iree_hal_amdgpu_queue_upload_ring_allocate(&ring, /*length=*/8,
+                                                 /*alignment=*/8);
+  EXPECT_EQ(second.host_ptr, &storage[8]);
+  EXPECT_EQ(second.device_ptr, 0x10000008ull);
+  EXPECT_EQ(second.length, 8u);
+  EXPECT_EQ(second.end_position, 16u);
+
+  iree_hal_amdgpu_queue_upload_ring_reclaim(&ring, second.end_position);
+  EXPECT_TRUE(iree_hal_amdgpu_queue_upload_ring_can_allocate(
+      &ring, /*length=*/16, /*alignment=*/32));
+
+  iree_hal_amdgpu_queue_upload_span_t aligned =
+      iree_hal_amdgpu_queue_upload_ring_allocate(&ring, /*length=*/16,
+                                                 /*alignment=*/32);
+  EXPECT_EQ(aligned.host_ptr, &storage[32]);
+  EXPECT_EQ(aligned.device_ptr, 0x10000020ull);
+  EXPECT_EQ(aligned.length, 16u);
+  EXPECT_EQ(aligned.end_position, 48u);
+}
+
+TEST(QueueUploadRingTest, SkipsTailWhenAlignedSpanWouldWrap) {
+  alignas(64) uint8_t storage[64] = {};
+  iree_hal_amdgpu_queue_upload_ring_t ring = {};
+  InitializeTestRing(storage, IREE_ARRAYSIZE(storage), &ring);
+
+  iree_hal_amdgpu_queue_upload_span_t first =
+      iree_hal_amdgpu_queue_upload_ring_allocate(&ring, /*length=*/48,
+                                                 /*alignment=*/1);
+  ASSERT_NE(first.host_ptr, nullptr);
+  iree_hal_amdgpu_queue_upload_ring_reclaim(&ring, first.end_position);
+
+  iree_hal_amdgpu_queue_upload_span_t wrapped =
+      iree_hal_amdgpu_queue_upload_ring_allocate(&ring, /*length=*/24,
+                                                 /*alignment=*/32);
+  EXPECT_EQ(wrapped.host_ptr, &storage[0]);
+  EXPECT_EQ(wrapped.device_ptr, 0x10000000ull);
+  EXPECT_EQ(wrapped.length, 24u);
+  EXPECT_EQ(wrapped.end_position, 88u);
+}
+
+TEST(QueueUploadRingTest, RejectsInvalidRequests) {
+  alignas(64) uint8_t storage[32] = {};
+  iree_hal_amdgpu_queue_upload_ring_t ring = {};
+  InitializeTestRing(storage, IREE_ARRAYSIZE(storage), &ring);
+
+  EXPECT_FALSE(iree_hal_amdgpu_queue_upload_ring_can_allocate(
+      &ring, /*length=*/0, /*alignment=*/1));
+  EXPECT_FALSE(iree_hal_amdgpu_queue_upload_ring_can_allocate(
+      &ring, /*length=*/33, /*alignment=*/1));
+  EXPECT_FALSE(iree_hal_amdgpu_queue_upload_ring_can_allocate(
+      &ring, /*length=*/8, /*alignment=*/0));
+  EXPECT_FALSE(iree_hal_amdgpu_queue_upload_ring_can_allocate(
+      &ring, /*length=*/8, /*alignment=*/3));
+  EXPECT_FALSE(iree_hal_amdgpu_queue_upload_ring_can_allocate(
+      &ring, /*length=*/8, /*alignment=*/64));
+
+  EXPECT_EQ(iree_hal_amdgpu_queue_upload_ring_allocate(&ring, /*length=*/0,
+                                                       /*alignment=*/1)
+                .host_ptr,
+            nullptr);
+  EXPECT_EQ(iree_hal_amdgpu_queue_upload_ring_allocate(&ring, /*length=*/8,
+                                                       /*alignment=*/3)
+                .host_ptr,
+            nullptr);
+}
+
+TEST(QueueUploadRingTest, PublicationModeNoneSkipsRegisterWrite) {
+  volatile uint32_t hdp_mem_flush_control = 0xCAFEu;
+  iree_hal_amdgpu_queue_upload_ring_t ring = {};
+  ring.publication.mode = IREE_HAL_AMDGPU_KERNARG_RING_PUBLICATION_MODE_NONE;
+  ring.publication.hdp_mem_flush_control = &hdp_mem_flush_control;
+
+  iree_hal_amdgpu_queue_upload_ring_publish_host_writes(&ring);
+
+  EXPECT_EQ(hdp_mem_flush_control, 0xCAFEu);
+}
+
+TEST(QueueUploadRingTest, PublicationModeHdpFlushWritesRegister) {
+  volatile uint32_t hdp_mem_flush_control = 0u;
+  iree_hal_amdgpu_queue_upload_ring_t ring = {};
+  ring.publication.mode =
+      IREE_HAL_AMDGPU_KERNARG_RING_PUBLICATION_MODE_HDP_FLUSH;
+  ring.publication.hdp_mem_flush_control = &hdp_mem_flush_control;
+
+  iree_hal_amdgpu_queue_upload_ring_publish_host_writes(&ring);
+
+  EXPECT_EQ(hdp_mem_flush_control, 1u);
+}
+
+}  // namespace
diff --git a/runtime/src/iree/hal/drivers/amdgpu/util/signal_pool.c b/runtime/src/iree/hal/drivers/amdgpu/util/signal_pool.c
new file mode 100644
index 0000000..869e9af
--- /dev/null
+++ b/runtime/src/iree/hal/drivers/amdgpu/util/signal_pool.c
@@ -0,0 +1,163 @@
+// Copyright 2026 The IREE Authors
+//
+// Licensed under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+
+#include "iree/hal/drivers/amdgpu/util/signal_pool.h"
+
+//===----------------------------------------------------------------------===//
+// iree_hal_amdgpu_host_signal_pool_t
+//===----------------------------------------------------------------------===//
+
+// Creates |count| signals and pushes them onto the free list.
+// Grows the free list array if needed. Caller must hold the mutex.
+static iree_status_t iree_hal_amdgpu_host_signal_pool_grow(
+    iree_hal_amdgpu_host_signal_pool_t* pool, iree_host_size_t count) {
+  IREE_TRACE_ZONE_BEGIN(z0);
+  IREE_TRACE_ZONE_APPEND_VALUE_I64(z0, (int64_t)count);
+
+  // Grow the array to hold all existing + new signals. The array capacity
+  // must always be >= allocated_count so that every signal can be returned
+  // to the free list simultaneously.
+  iree_host_size_t required_capacity = pool->allocated_count + count;
+  iree_status_t status = iree_ok_status();
+  if (required_capacity > pool->free_capacity) {
+    status = iree_allocator_grow_array(
+        pool->host_allocator, required_capacity, sizeof(hsa_signal_t),
+        &pool->free_capacity, (void**)&pool->free_signals);
+  }
+
+  // Create signals and push onto the free list.
+  iree_host_size_t created_count = 0;
+  for (iree_host_size_t i = 0; i < count && iree_status_is_ok(status); ++i) {
+    hsa_signal_t signal = {0};
+    status = iree_hsa_amd_signal_create(IREE_LIBHSA(pool->libhsa),
+                                        /*initial_value=*/0,
+                                        /*num_consumers=*/0, /*consumers=*/NULL,
+                                        /*attributes=*/0, &signal);
+    if (iree_status_is_ok(status)) {
+      pool->free_signals[pool->free_count + i] = signal;
+      ++created_count;
+    }
+  }
+
+  if (iree_status_is_ok(status)) {
+    pool->free_count += created_count;
+    pool->allocated_count += created_count;
+  } else {
+    // Destroy any signals we created before the failure.
+    for (iree_host_size_t i = 0; i < created_count; ++i) {
+      status = iree_status_join(
+          status,
+          iree_hsa_signal_destroy(IREE_LIBHSA(pool->libhsa),
+                                  pool->free_signals[pool->free_count + i]));
+    }
+  }
+
+  IREE_TRACE_ZONE_END(z0);
+  return status;
+}
+
+iree_status_t iree_hal_amdgpu_host_signal_pool_initialize(
+    const iree_hal_amdgpu_libhsa_t* libhsa, iree_host_size_t initial_capacity,
+    iree_host_size_t batch_size, iree_allocator_t host_allocator,
+    iree_hal_amdgpu_host_signal_pool_t* out_pool) {
+  IREE_ASSERT_ARGUMENT(libhsa);
+  IREE_ASSERT_ARGUMENT(out_pool);
+  IREE_TRACE_ZONE_BEGIN(z0);
+  IREE_TRACE_ZONE_APPEND_VALUE_I64(z0, (int64_t)initial_capacity);
+
+  memset(out_pool, 0, sizeof(*out_pool));
+  out_pool->libhsa = libhsa;
+  out_pool->host_allocator = host_allocator;
+  out_pool->batch_size =
+      batch_size ? batch_size
+                 : IREE_HAL_AMDGPU_HOST_SIGNAL_POOL_BATCH_SIZE_DEFAULT;
+
+  iree_slim_mutex_initialize(&out_pool->mutex);
+
+  iree_status_t status = iree_ok_status();
+  if (initial_capacity > 0) {
+    iree_slim_mutex_lock(&out_pool->mutex);
+    status = iree_hal_amdgpu_host_signal_pool_grow(out_pool, initial_capacity);
+    iree_slim_mutex_unlock(&out_pool->mutex);
+  }
+
+  if (!iree_status_is_ok(status)) {
+    iree_hal_amdgpu_host_signal_pool_deinitialize(out_pool);
+  }
+  IREE_TRACE_ZONE_END(z0);
+  return status;
+}
+
+void iree_hal_amdgpu_host_signal_pool_deinitialize(
+    iree_hal_amdgpu_host_signal_pool_t* pool) {
+  IREE_ASSERT_ARGUMENT(pool);
+  IREE_TRACE_ZONE_BEGIN(z0);
+
+  // All signals must have been released back to the pool.
+  IREE_ASSERT(pool->free_count == pool->allocated_count,
+              "signal pool has outstanding unreleased signals");
+
+  // Destroy all signals via the free list.
+  for (iree_host_size_t i = 0; i < pool->free_count; ++i) {
+    iree_hal_amdgpu_hsa_cleanup_assert_success(
+        iree_hsa_signal_destroy_raw(pool->libhsa, pool->free_signals[i]));
+  }
+
+  iree_allocator_free(pool->host_allocator, pool->free_signals);
+  iree_slim_mutex_deinitialize(&pool->mutex);
+  memset(pool, 0, sizeof(*pool));
+
+  IREE_TRACE_ZONE_END(z0);
+}
+
+iree_status_t iree_hal_amdgpu_host_signal_pool_acquire(
+    iree_hal_amdgpu_host_signal_pool_t* pool, hsa_signal_value_t initial_value,
+    hsa_signal_t* out_signal) {
+  IREE_ASSERT_ARGUMENT(pool);
+  IREE_ASSERT_ARGUMENT(out_signal);
+  *out_signal = (hsa_signal_t){0};
+
+  iree_slim_mutex_lock(&pool->mutex);
+
+  // Grow the pool if empty.
+  iree_status_t status = iree_ok_status();
+  if (pool->free_count == 0) {
+    status = iree_hal_amdgpu_host_signal_pool_grow(pool, pool->batch_size);
+  }
+
+  // Pop from the free list (LIFO).
+  hsa_signal_t signal = {0};
+  if (iree_status_is_ok(status)) {
+    signal = pool->free_signals[--pool->free_count];
+  }
+
+  iree_slim_mutex_unlock(&pool->mutex);
+
+  // Reset the signal value outside the lock — the HSA signal lives in CPU
+  // kernarg memory so this is a local RAM write with no PCIe traffic.
+  if (iree_status_is_ok(status)) {
+    iree_hsa_signal_store_relaxed(IREE_LIBHSA(pool->libhsa), signal,
+                                  initial_value);
+    *out_signal = signal;
+  }
+  return status;
+}
+
+void iree_hal_amdgpu_host_signal_pool_release(
+    iree_hal_amdgpu_host_signal_pool_t* pool, hsa_signal_t signal) {
+  IREE_ASSERT_ARGUMENT(pool);
+
+  iree_slim_mutex_lock(&pool->mutex);
+
+  // The free list capacity is always >= allocated_count, so this can never
+  // overflow unless the signal was double-freed or came from another pool.
+  IREE_ASSERT(pool->free_count < pool->allocated_count,
+              "signal pool release overflow: double-free or wrong pool");
+
+  pool->free_signals[pool->free_count++] = signal;
+
+  iree_slim_mutex_unlock(&pool->mutex);
+}
diff --git a/runtime/src/iree/hal/drivers/amdgpu/util/signal_pool.h b/runtime/src/iree/hal/drivers/amdgpu/util/signal_pool.h
new file mode 100644
index 0000000..a6bcec7
--- /dev/null
+++ b/runtime/src/iree/hal/drivers/amdgpu/util/signal_pool.h
@@ -0,0 +1,86 @@
+// Copyright 2026 The IREE Authors
+//
+// Licensed under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+
+#ifndef IREE_HAL_DRIVERS_AMDGPU_UTIL_SIGNAL_POOL_H_
+#define IREE_HAL_DRIVERS_AMDGPU_UTIL_SIGNAL_POOL_H_
+
+#include "iree/base/api.h"
+#include "iree/base/threading/mutex.h"
+#include "iree/hal/drivers/amdgpu/util/libhsa.h"
+
+#ifdef __cplusplus
+extern "C" {
+#endif  // __cplusplus
+
+//===----------------------------------------------------------------------===//
+// iree_hal_amdgpu_host_signal_pool_t
+//===----------------------------------------------------------------------===//
+
+// Default number of signals to pre-create in a batch when the pool is empty.
+// Each signal costs ~10-50us to create via hsa_amd_signal_create, so batching
+// amortizes the overhead across multiple acquisitions.
+#define IREE_HAL_AMDGPU_HOST_SIGNAL_POOL_BATCH_SIZE_DEFAULT 32
+
+// Pool of HSA signals created via hsa_amd_signal_create.
+// These are full-featured signals with interrupt capability (mailbox event,
+// eventfd bridge) suitable for host waits, cross-device synchronization, and
+// proactor integration.
+//
+// Signals are expensive to create (~10-50us each) so the pool pre-creates them
+// in batches and maintains a free list for O(1) acquire/release. Released
+// signals are returned to the free list for reuse; they are only destroyed when
+// the pool is deinitialized.
+//
+// Thread-safe.
+typedef struct iree_hal_amdgpu_host_signal_pool_t {
+  // HSA API handle. Unowned.
+  const iree_hal_amdgpu_libhsa_t* libhsa;
+  // Host allocator for the free list array.
+  iree_allocator_t host_allocator;
+  // Number of signals to create per batch when the pool is empty.
+  iree_host_size_t batch_size;
+
+  // Guards access to the free list and growth.
+  iree_slim_mutex_t mutex;
+  // LIFO stack of available signals. Capacity grows to steady state and stays.
+  hsa_signal_t* free_signals IREE_GUARDED_BY(mutex);
+  iree_host_size_t free_count IREE_GUARDED_BY(mutex);
+  iree_host_size_t free_capacity IREE_GUARDED_BY(mutex);
+  // Total signals created by this pool. Used to assert all signals are returned
+  // before deinitialization.
+  iree_host_size_t allocated_count IREE_GUARDED_BY(mutex);
+} iree_hal_amdgpu_host_signal_pool_t;
+
+// Initializes the host signal pool.
+// |initial_capacity| signals will be pre-created. If 0, signals are created
+// lazily on first acquire.
+iree_status_t iree_hal_amdgpu_host_signal_pool_initialize(
+    const iree_hal_amdgpu_libhsa_t* libhsa, iree_host_size_t initial_capacity,
+    iree_host_size_t batch_size, iree_allocator_t host_allocator,
+    iree_hal_amdgpu_host_signal_pool_t* out_pool);
+
+// Deinitializes the pool and destroys all signals.
+// All acquired signals must have been released before calling.
+void iree_hal_amdgpu_host_signal_pool_deinitialize(
+    iree_hal_amdgpu_host_signal_pool_t* pool);
+
+// Acquires a signal from the pool, resetting its value to |initial_value|.
+// The signal will have interrupt capability for host waits.
+// Must be released back to the pool when no longer needed.
+iree_status_t iree_hal_amdgpu_host_signal_pool_acquire(
+    iree_hal_amdgpu_host_signal_pool_t* pool, hsa_signal_value_t initial_value,
+    hsa_signal_t* out_signal);
+
+// Releases a signal back to the pool for reuse.
+// The signal must not be in use by any pending operations.
+void iree_hal_amdgpu_host_signal_pool_release(
+    iree_hal_amdgpu_host_signal_pool_t* pool, hsa_signal_t signal);
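+
+// Illustrative lifecycle (a sketch; the capacities below are arbitrary and
+// error handling is elided):
+//
+//   iree_hal_amdgpu_host_signal_pool_t pool;
+//   iree_hal_amdgpu_host_signal_pool_initialize(
+//       libhsa, /*initial_capacity=*/64, /*batch_size=*/32, host_allocator,
+//       &pool);
+//   hsa_signal_t signal = {0};
+//   iree_hal_amdgpu_host_signal_pool_acquire(&pool, /*initial_value=*/1,
+//                                            &signal);
+//   // ... host waits / cross-device chaining via |signal| ...
+//   iree_hal_amdgpu_host_signal_pool_release(&pool, signal);
+//   iree_hal_amdgpu_host_signal_pool_deinitialize(&pool);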
+
+#ifdef __cplusplus
+}  // extern "C"
+#endif  // __cplusplus
+
+#endif  // IREE_HAL_DRIVERS_AMDGPU_UTIL_SIGNAL_POOL_H_
diff --git a/runtime/src/iree/hal/drivers/amdgpu/util/signal_pool_benchmark.cc b/runtime/src/iree/hal/drivers/amdgpu/util/signal_pool_benchmark.cc
new file mode 100644
index 0000000..c1c3dde
--- /dev/null
+++ b/runtime/src/iree/hal/drivers/amdgpu/util/signal_pool_benchmark.cc
@@ -0,0 +1,127 @@
+// Copyright 2026 The IREE Authors
+//
+// Licensed under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+
+#include <benchmark/benchmark.h>
+
+#include "iree/base/api.h"
+#include "iree/hal/drivers/amdgpu/util/signal_pool.h"
+#include "iree/hal/drivers/amdgpu/util/topology.h"
+
+namespace {
+
+//===----------------------------------------------------------------------===//
+// Shared HSA context (one per process, like gtest's SetUpTestSuite)
+//===----------------------------------------------------------------------===//
+
+class SignalPoolBenchmark : public benchmark::Fixture {
+ public:
+  static void InitializeOnce() {
+    if (initialized_) return;
+    initialized_ = true;
+    host_allocator_ = iree_allocator_system();
+    iree_status_t status = iree_hal_amdgpu_libhsa_initialize(
+        IREE_HAL_AMDGPU_LIBHSA_FLAG_NONE, iree_string_view_list_empty(),
+        host_allocator_, &libhsa_);
+    if (!iree_status_is_ok(status)) {
+      iree_status_free(status);
+      return;
+    }
+    status =
+        iree_hal_amdgpu_topology_initialize_with_defaults(&libhsa_, &topology_);
+    if (!iree_status_is_ok(status)) {
+      iree_status_free(status);
+      iree_hal_amdgpu_libhsa_deinitialize(&libhsa_);
+      return;
+    }
+    if (topology_.gpu_agent_count == 0) {
+      iree_hal_amdgpu_topology_deinitialize(&topology_);
+      iree_hal_amdgpu_libhsa_deinitialize(&libhsa_);
+      return;
+    }
+    available_ = true;
+  }
+
+  static void DeinitializeOnce() {
+    if (!available_) return;
+    iree_hal_amdgpu_topology_deinitialize(&topology_);
+    iree_hal_amdgpu_libhsa_deinitialize(&libhsa_);
+    available_ = false;
+  }
+
+  void SetUp(benchmark::State& state) override {
+    InitializeOnce();
+    if (!available_) {
+      state.SkipWithError("HSA not available or no GPU devices");
+    }
+  }
+
+ protected:
+  static bool initialized_;
+  static bool available_;
+  static iree_allocator_t host_allocator_;
+  static iree_hal_amdgpu_libhsa_t libhsa_;
+  static iree_hal_amdgpu_topology_t topology_;
+};
+
+bool SignalPoolBenchmark::initialized_ = false;
+bool SignalPoolBenchmark::available_ = false;
+iree_allocator_t SignalPoolBenchmark::host_allocator_;
+iree_hal_amdgpu_libhsa_t SignalPoolBenchmark::libhsa_;
+iree_hal_amdgpu_topology_t SignalPoolBenchmark::topology_;
+
+//===----------------------------------------------------------------------===//
+// Baseline: raw hsa_amd_signal_create + hsa_signal_destroy
+//===----------------------------------------------------------------------===//
+// Measures ROCR's SharedSignalPool pop+push plus signal construction.
+// This includes ROCR's process-global HybridMutex.
+
+BENCHMARK_DEFINE_F(SignalPoolBenchmark, RawHsaSignalCreateDestroy)
+(benchmark::State& state) {
+  for (auto _ : state) {
+    hsa_signal_t signal = {0};
+    IREE_CHECK_OK(iree_hsa_amd_signal_create(
+        IREE_LIBHSA(&libhsa_), /*initial_value=*/1, /*num_consumers=*/0,
+        /*consumers=*/NULL, /*attributes=*/0, &signal));
+    benchmark::DoNotOptimize(signal);
+    IREE_CHECK_OK(iree_hsa_signal_destroy(IREE_LIBHSA(&libhsa_), signal));
+  }
+}
+BENCHMARK_REGISTER_F(SignalPoolBenchmark, RawHsaSignalCreateDestroy);
+
+//===----------------------------------------------------------------------===//
+// Host signal pool: acquire + release (steady state)
+//===----------------------------------------------------------------------===//
+// Measures our pool's LIFO pop+push plus hsa_signal_store_relaxed for value
+// reset. The HSA signal memory lives in CPU kernarg memory (placed by ROCR),
+// so the relaxed store is a local system RAM write with no PCIe traffic.
+
+BENCHMARK_DEFINE_F(SignalPoolBenchmark, HostPoolAcquireRelease)
+(benchmark::State& state) {
+  // SetUp reports the skip but the benchmark body still runs; bail out before
+  // touching the uninitialized HSA context.
+  if (!available_) return;
+  iree_hal_amdgpu_host_signal_pool_t pool;
+  IREE_CHECK_OK(iree_hal_amdgpu_host_signal_pool_initialize(
+      &libhsa_, /*initial_capacity=*/64, /*batch_size=*/32, host_allocator_,
+      &pool));
+
+  for (auto _ : state) {
+    hsa_signal_t signal = {0};
+    IREE_CHECK_OK(iree_hal_amdgpu_host_signal_pool_acquire(&pool, 1, &signal));
+    benchmark::DoNotOptimize(signal);
+    iree_hal_amdgpu_host_signal_pool_release(&pool, signal);
+  }
+
+  iree_hal_amdgpu_host_signal_pool_deinitialize(&pool);
+}
+BENCHMARK_REGISTER_F(SignalPoolBenchmark, HostPoolAcquireRelease);
+
+}  // namespace
+
+int main(int argc, char** argv) {
+  benchmark::Initialize(&argc, argv);
+  benchmark::RunSpecifiedBenchmarks();
+  benchmark::Shutdown();
+  SignalPoolBenchmark::DeinitializeOnce();
+  return 0;
+}
diff --git a/runtime/src/iree/hal/drivers/amdgpu/util/signal_pool_test.cc b/runtime/src/iree/hal/drivers/amdgpu/util/signal_pool_test.cc
new file mode 100644
index 0000000..7dc0840
--- /dev/null
+++ b/runtime/src/iree/hal/drivers/amdgpu/util/signal_pool_test.cc
@@ -0,0 +1,165 @@
+// Copyright 2026 The IREE Authors
+//
+// Licensed under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+
+#include "iree/hal/drivers/amdgpu/util/signal_pool.h"
+
+#include "iree/base/api.h"
+#include "iree/hal/drivers/amdgpu/util/topology.h"
+#include "iree/testing/gtest.h"
+#include "iree/testing/status_matchers.h"
+
+namespace iree::hal::amdgpu {
+namespace {
+
+using iree::testing::status::StatusIs;
+
+//===----------------------------------------------------------------------===//
+// Test fixture with HSA initialization
+//===----------------------------------------------------------------------===//
+
+struct SignalPoolTest : public ::testing::Test {
+  static iree_allocator_t host_allocator;
+  static iree_hal_amdgpu_libhsa_t libhsa;
+  static iree_hal_amdgpu_topology_t topology;
+
+  static void SetUpTestSuite() {
+    IREE_TRACE_SCOPE();
+    host_allocator = iree_allocator_system();
+    iree_status_t status = iree_hal_amdgpu_libhsa_initialize(
+        IREE_HAL_AMDGPU_LIBHSA_FLAG_NONE, iree_string_view_list_empty(),
+        host_allocator, &libhsa);
+    if (!iree_status_is_ok(status)) {
+      iree_status_fprint(stderr, status);
+      iree_status_free(status);
+      GTEST_SKIP() << "HSA not available, skipping tests";
+    }
+    IREE_ASSERT_OK(
+        iree_hal_amdgpu_topology_initialize_with_defaults(&libhsa, &topology));
+    if (topology.gpu_agent_count == 0) {
+      GTEST_SKIP() << "no GPU devices available, skipping tests";
+    }
+  }
+
+  static void TearDownTestSuite() {
+    IREE_TRACE_SCOPE();
+    iree_hal_amdgpu_topology_deinitialize(&topology);
+    iree_hal_amdgpu_libhsa_deinitialize(&libhsa);
+  }
+};
+iree_allocator_t SignalPoolTest::host_allocator;
+iree_hal_amdgpu_libhsa_t SignalPoolTest::libhsa;
+iree_hal_amdgpu_topology_t SignalPoolTest::topology;
+
+//===----------------------------------------------------------------------===//
+// iree_hal_amdgpu_host_signal_pool_t
+//===----------------------------------------------------------------------===//
+
+TEST_F(SignalPoolTest, HostPoolLifetimeEmpty) {
+  IREE_TRACE_SCOPE();
+  iree_hal_amdgpu_host_signal_pool_t pool;
+  IREE_ASSERT_OK(iree_hal_amdgpu_host_signal_pool_initialize(
+      &libhsa, /*initial_capacity=*/0, /*batch_size=*/4, host_allocator,
+      &pool));
+  iree_hal_amdgpu_host_signal_pool_deinitialize(&pool);
+}
+
+TEST_F(SignalPoolTest, HostPoolLifetimePreallocated) {
+  IREE_TRACE_SCOPE();
+  iree_hal_amdgpu_host_signal_pool_t pool;
+  IREE_ASSERT_OK(iree_hal_amdgpu_host_signal_pool_initialize(
+      &libhsa, /*initial_capacity=*/16, /*batch_size=*/8, host_allocator,
+      &pool));
+  iree_hal_amdgpu_host_signal_pool_deinitialize(&pool);
+}
+
+TEST_F(SignalPoolTest, HostPoolAcquireRelease) {
+  IREE_TRACE_SCOPE();
+  iree_hal_amdgpu_host_signal_pool_t pool;
+  IREE_ASSERT_OK(iree_hal_amdgpu_host_signal_pool_initialize(
+      &libhsa, /*initial_capacity=*/4, /*batch_size=*/4, host_allocator,
+      &pool));
+
+  // Acquire a signal and verify it has the requested initial value.
+  hsa_signal_t signal = {0};
+  IREE_ASSERT_OK(iree_hal_amdgpu_host_signal_pool_acquire(
+      &pool, /*initial_value=*/42, &signal));
+  EXPECT_NE(signal.handle, 0u);
+  hsa_signal_value_t value =
+      iree_hsa_signal_load_relaxed(IREE_LIBHSA(&libhsa), signal);
+  EXPECT_EQ(value, 42);
+
+  // Release and re-acquire — should get a recycled signal.
+  iree_hal_amdgpu_host_signal_pool_release(&pool, signal);
+  hsa_signal_t signal2 = {0};
+  IREE_ASSERT_OK(iree_hal_amdgpu_host_signal_pool_acquire(
+      &pool, /*initial_value=*/99, &signal2));
+  EXPECT_NE(signal2.handle, 0u);
+  // LIFO: should get the same signal back.
+  EXPECT_EQ(signal2.handle, signal.handle);
+  value = iree_hsa_signal_load_relaxed(IREE_LIBHSA(&libhsa), signal2);
+  EXPECT_EQ(value, 99);
+
+  iree_hal_amdgpu_host_signal_pool_release(&pool, signal2);
+  iree_hal_amdgpu_host_signal_pool_deinitialize(&pool);
+}
+
+TEST_F(SignalPoolTest, HostPoolBatchGrowth) {
+  IREE_TRACE_SCOPE();
+  iree_hal_amdgpu_host_signal_pool_t pool;
+  IREE_ASSERT_OK(iree_hal_amdgpu_host_signal_pool_initialize(
+      &libhsa, /*initial_capacity=*/0, /*batch_size=*/4, host_allocator,
+      &pool));
+
+  // Acquire more signals than one batch — forces growth.
+  hsa_signal_t signals[10];
+  for (int i = 0; i < 10; ++i) {
+    IREE_ASSERT_OK(iree_hal_amdgpu_host_signal_pool_acquire(
+        &pool, /*initial_value=*/i, &signals[i]));
+    EXPECT_NE(signals[i].handle, 0u);
+  }
+
+  // All should be unique handles.
+  for (int i = 0; i < 10; ++i) {
+    for (int j = i + 1; j < 10; ++j) {
+      EXPECT_NE(signals[i].handle, signals[j].handle)
+          << "signals[" << i << "] == signals[" << j << "]";
+    }
+  }
+
+  // Release all — capacity is always >= allocated_count.
+  for (int i = 0; i < 10; ++i) {
+    iree_hal_amdgpu_host_signal_pool_release(&pool, signals[i]);
+  }
+
+  iree_hal_amdgpu_host_signal_pool_deinitialize(&pool);
+}
+
+TEST_F(SignalPoolTest, HostPoolAllOutstandingThenRelease) {
+  IREE_TRACE_SCOPE();
+  iree_hal_amdgpu_host_signal_pool_t pool;
+  IREE_ASSERT_OK(iree_hal_amdgpu_host_signal_pool_initialize(
+      &libhsa, /*initial_capacity=*/8, /*batch_size=*/8, host_allocator,
+      &pool));
+
+  // Acquire all 8 pre-created signals — pool is now empty.
+  hsa_signal_t signals[8];
+  for (int i = 0; i < 8; ++i) {
+    IREE_ASSERT_OK(iree_hal_amdgpu_host_signal_pool_acquire(
+        &pool, /*initial_value=*/0, &signals[i]));
+  }
+
+  // Release all 8 back. This exercises the case where free_count grows back
+  // to allocated_count — the free list must have been sized to accommodate
+  // this.
+  for (int i = 0; i < 8; ++i) {
+    iree_hal_amdgpu_host_signal_pool_release(&pool, signals[i]);
+  }
+
+  iree_hal_amdgpu_host_signal_pool_deinitialize(&pool);
+}
+
+}  // namespace
+}  // namespace iree::hal::amdgpu
diff --git a/runtime/src/iree/hal/drivers/amdgpu/util/target_id.c b/runtime/src/iree/hal/drivers/amdgpu/util/target_id.c
new file mode 100644
index 0000000..0e3bad1
--- /dev/null
+++ b/runtime/src/iree/hal/drivers/amdgpu/util/target_id.c
@@ -0,0 +1,575 @@
+// Copyright 2026 The IREE Authors
+//
+// Licensed under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+
+#include "iree/hal/drivers/amdgpu/util/target_id.h"
+
+typedef enum iree_hal_amdgpu_target_feature_support_bits_e {
+  IREE_HAL_AMDGPU_TARGET_FEATURE_SUPPORT_NONE = 0u,
+  IREE_HAL_AMDGPU_TARGET_FEATURE_SUPPORT_SRAMECC = 1u << 0,
+  IREE_HAL_AMDGPU_TARGET_FEATURE_SUPPORT_XNACK = 1u << 1,
+} iree_hal_amdgpu_target_feature_support_bits_t;
+typedef uint32_t iree_hal_amdgpu_target_feature_support_flags_t;
+
+typedef struct iree_hal_amdgpu_target_id_mapping_t {
+  // Exact HSA ISA processor name.
+  iree_string_view_t exact_processor;
+  // Code-object processor selected for the exact processor.
+  iree_string_view_t code_object_processor;
+  // Feature support flags from
+  // iree_hal_amdgpu_target_feature_support_bits_t.
+  iree_hal_amdgpu_target_feature_support_flags_t feature_support;
+} iree_hal_amdgpu_target_id_mapping_t;
+
+static const iree_hal_amdgpu_target_id_mapping_t
+    iree_hal_amdgpu_target_id_mappings[] = {
+#include "iree/hal/drivers/amdgpu/util/target_id_map.inl"
+};
+
+static bool iree_hal_amdgpu_parse_decimal_digit(char c, uint32_t* out_value) {
+  if (c < '0' || c > '9') return false;
+  *out_value = (uint32_t)(c - '0');
+  return true;
+}
+
+static bool iree_hal_amdgpu_parse_hex_digit(char c, uint32_t* out_value) {
+  if (c >= '0' && c <= '9') {
+    *out_value = (uint32_t)(c - '0');
+    return true;
+  } else if (c >= 'a' && c <= 'f') {
+    *out_value = (uint32_t)(c - 'a' + 10);
+    return true;
+  } else if (c >= 'A' && c <= 'F') {
+    *out_value = (uint32_t)(c - 'A' + 10);
+    return true;
+  }
+  return false;
+}
+
+static bool iree_hal_amdgpu_parse_decimal_number(iree_string_view_t value,
+                                                 uint32_t* out_number) {
+  if (iree_string_view_is_empty(value)) return false;
+  uint64_t number = 0;
+  for (iree_host_size_t i = 0; i < value.size; ++i) {
+    uint32_t digit = 0;
+    if (!iree_hal_amdgpu_parse_decimal_digit(value.data[i], &digit)) {
+      return false;
+    }
+    number = number * 10 + digit;
+    if (number > UINT32_MAX) return false;
+  }
+  *out_number = (uint32_t)number;
+  return true;
+}
+
+static bool iree_hal_amdgpu_gfxip_version_equal(
+    iree_hal_amdgpu_gfxip_version_t lhs, iree_hal_amdgpu_gfxip_version_t rhs) {
+  return lhs.major == rhs.major && lhs.minor == rhs.minor &&
+         lhs.stepping == rhs.stepping;
+}
+
+static bool iree_hal_amdgpu_parse_exact_processor(
+    iree_string_view_t processor,
+    iree_hal_amdgpu_gfxip_version_t* out_version) {
+  memset(out_version, 0, sizeof(*out_version));
+  if (!iree_string_view_consume_prefix(&processor, IREE_SV("gfx"))) {
+    return false;
+  }
+
+  uint32_t major0 = 0;
+  uint32_t major1 = 0;
+  uint32_t minor = 0;
+  uint32_t stepping = 0;
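+  // gfx10+ names carry four digits after the prefix ("gfx1100" -> 11.0.0);
+  // the leading '1' pairs with the next digit to form the major version.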
+  if (processor.size == 4 &&
+      iree_hal_amdgpu_parse_decimal_digit(processor.data[0], &major0) &&
+      major0 == 1 &&
+      iree_hal_amdgpu_parse_decimal_digit(processor.data[1], &major1) &&
+      iree_hal_amdgpu_parse_decimal_digit(processor.data[2], &minor) &&
+      iree_hal_amdgpu_parse_hex_digit(processor.data[3], &stepping)) {
+    out_version->major = 10 + major1;
+    out_version->minor = minor;
+    out_version->stepping = stepping;
+    return true;
+  }
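+  // Pre-gfx10 names carry three digits ("gfx942" -> 9.4.2); the stepping
+  // digit may be hex ("gfx90c").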
+  if (processor.size == 3 &&
+      iree_hal_amdgpu_parse_decimal_digit(processor.data[0], &major0) &&
+      iree_hal_amdgpu_parse_decimal_digit(processor.data[1], &minor) &&
+      iree_hal_amdgpu_parse_hex_digit(processor.data[2], &stepping)) {
+    out_version->major = major0;
+    out_version->minor = minor;
+    out_version->stepping = stepping;
+    return true;
+  }
+  return false;
+}
+
+static const iree_hal_amdgpu_target_id_mapping_t*
+iree_hal_amdgpu_target_id_lookup_mapping(iree_string_view_t exact_processor) {
+  for (iree_host_size_t i = 0;
+       i < IREE_ARRAYSIZE(iree_hal_amdgpu_target_id_mappings); ++i) {
+    if (iree_string_view_equal(
+            exact_processor,
+            iree_hal_amdgpu_target_id_mappings[i].exact_processor)) {
+      return &iree_hal_amdgpu_target_id_mappings[i];
+    }
+  }
+  return NULL;
+}
+
+static void iree_hal_amdgpu_target_id_apply_known_feature_support(
+    iree_hal_amdgpu_target_id_t* target_id) {
+  if (target_id->kind != IREE_HAL_AMDGPU_TARGET_KIND_EXACT) return;
+  const iree_hal_amdgpu_target_id_mapping_t* mapping =
+      iree_hal_amdgpu_target_id_lookup_mapping(target_id->processor);
+  if (mapping == NULL) return;
+
+  if (target_id->sramecc == IREE_HAL_AMDGPU_TARGET_FEATURE_STATE_ANY &&
+      !iree_any_bit_set(mapping->feature_support,
+                        IREE_HAL_AMDGPU_TARGET_FEATURE_SUPPORT_SRAMECC)) {
+    target_id->sramecc = IREE_HAL_AMDGPU_TARGET_FEATURE_STATE_UNSUPPORTED;
+  }
+  if (target_id->xnack == IREE_HAL_AMDGPU_TARGET_FEATURE_STATE_ANY &&
+      !iree_any_bit_set(mapping->feature_support,
+                        IREE_HAL_AMDGPU_TARGET_FEATURE_SUPPORT_XNACK)) {
+    target_id->xnack = IREE_HAL_AMDGPU_TARGET_FEATURE_STATE_UNSUPPORTED;
+  }
+}
+
+static bool iree_hal_amdgpu_parse_generic_processor(
+    iree_string_view_t processor,
+    iree_hal_amdgpu_gfxip_version_t* out_version) {
+  memset(out_version, 0, sizeof(*out_version));
+  if (!iree_string_view_consume_prefix(&processor, IREE_SV("gfx")) ||
+      !iree_string_view_consume_suffix(&processor, IREE_SV("-generic"))) {
+    return false;
+  }
+
+  iree_string_view_t major = iree_string_view_empty();
+  iree_string_view_t minor = iree_string_view_empty();
+  if (iree_string_view_split(processor, '-', &major, &minor) == -1) {
+    major = processor;
+  } else if (iree_string_view_find_char(minor, '-', 0) !=
+             IREE_STRING_VIEW_NPOS) {
+    return false;
+  }
+
+  uint32_t major_value = 0;
+  uint32_t minor_value = 0;
+  if (!iree_hal_amdgpu_parse_decimal_number(major, &major_value)) {
+    return false;
+  }
+  if (!iree_string_view_is_empty(minor) &&
+      !iree_hal_amdgpu_parse_decimal_number(minor, &minor_value)) {
+    return false;
+  }
+  out_version->major = major_value;
+  out_version->minor = minor_value;
+  out_version->stepping = 0;
+  return true;
+}
+
+static iree_status_t iree_hal_amdgpu_target_id_parse_processor(
+    iree_string_view_t processor, iree_hal_amdgpu_target_id_t* out_target_id) {
+  if (iree_string_view_is_empty(processor)) {
+    return iree_make_status(IREE_STATUS_INVALID_ARGUMENT,
+                            "AMDGPU target ID has an empty processor name");
+  }
+
+  out_target_id->processor = processor;
+  if (iree_hal_amdgpu_parse_generic_processor(processor,
+                                              &out_target_id->version)) {
+    out_target_id->kind = IREE_HAL_AMDGPU_TARGET_KIND_GENERIC;
+    return iree_ok_status();
+  }
+  if (iree_hal_amdgpu_parse_exact_processor(processor,
+                                            &out_target_id->version)) {
+    out_target_id->kind = IREE_HAL_AMDGPU_TARGET_KIND_EXACT;
+    return iree_ok_status();
+  }
+  return iree_make_status(IREE_STATUS_INVALID_ARGUMENT,
+                          "unsupported AMDGPU target processor syntax: %.*s",
+                          (int)processor.size, processor.data);
+}
+
+static iree_status_t iree_hal_amdgpu_target_id_parse_feature(
+    iree_string_view_t feature,
+    iree_hal_amdgpu_target_feature_state_t* inout_sramecc,
+    iree_hal_amdgpu_target_feature_state_t* inout_xnack) {
+  if (feature.size < 2) {
+    return iree_make_status(
+        IREE_STATUS_INVALID_ARGUMENT,
+        "AMDGPU target feature suffix is empty or too short to contain a "
+        "name and +/- selector");
+  }
+
+  const char selector = feature.data[feature.size - 1];
+  iree_hal_amdgpu_target_feature_state_t state =
+      IREE_HAL_AMDGPU_TARGET_FEATURE_STATE_ANY;
+  if (selector == '+') {
+    state = IREE_HAL_AMDGPU_TARGET_FEATURE_STATE_ON;
+  } else if (selector == '-') {
+    state = IREE_HAL_AMDGPU_TARGET_FEATURE_STATE_OFF;
+  } else {
+    return iree_make_status(IREE_STATUS_INVALID_ARGUMENT,
+                            "AMDGPU target feature suffix missing +/-: %.*s",
+                            (int)feature.size, feature.data);
+  }
+
+  iree_string_view_t name = iree_string_view_remove_suffix(feature, /*n=*/1);
+  iree_hal_amdgpu_target_feature_state_t* feature_state = NULL;
+  if (iree_string_view_equal(name, IREE_SV("sramecc"))) {
+    feature_state = inout_sramecc;
+  } else if (iree_string_view_equal(name, IREE_SV("xnack"))) {
+    feature_state = inout_xnack;
+  } else {
+    return iree_make_status(IREE_STATUS_INVALID_ARGUMENT,
+                            "unsupported AMDGPU target feature suffix: %.*s",
+                            (int)feature.size, feature.data);
+  }
+  if (*feature_state != IREE_HAL_AMDGPU_TARGET_FEATURE_STATE_ANY) {
+    return iree_make_status(IREE_STATUS_INVALID_ARGUMENT,
+                            "duplicate AMDGPU target feature suffix: %.*s",
+                            (int)name.size, name.data);
+  }
+  *feature_state = state;
+  return iree_ok_status();
+}
+
+IREE_API_EXPORT iree_status_t iree_hal_amdgpu_target_id_parse(
+    iree_string_view_t value, iree_hal_amdgpu_target_id_parse_flags_t flags,
+    iree_hal_amdgpu_target_id_t* out_target_id) {
+  IREE_ASSERT_ARGUMENT(out_target_id);
+  memset(out_target_id, 0, sizeof(*out_target_id));
+  out_target_id->sramecc = IREE_HAL_AMDGPU_TARGET_FEATURE_STATE_ANY;
+  out_target_id->xnack = IREE_HAL_AMDGPU_TARGET_FEATURE_STATE_ANY;
+
+  const iree_hal_amdgpu_target_id_parse_flags_t known_flags =
+      IREE_HAL_AMDGPU_TARGET_ID_PARSE_FLAG_ALLOW_HSA_PREFIX |
+      IREE_HAL_AMDGPU_TARGET_ID_PARSE_FLAG_ALLOW_ARCH_ONLY |
+      IREE_HAL_AMDGPU_TARGET_ID_PARSE_FLAG_ALLOW_FEATURE_SUFFIXES;
+  if (IREE_UNLIKELY(iree_any_bit_set(flags, ~known_flags))) {
+    return iree_make_status(IREE_STATUS_INVALID_ARGUMENT,
+                            "unknown AMDGPU target ID parse flags: 0x%x",
+                            flags);
+  }
+
+  const bool allow_hsa_prefix = iree_any_bit_set(
+      flags, IREE_HAL_AMDGPU_TARGET_ID_PARSE_FLAG_ALLOW_HSA_PREFIX);
+  const bool allow_arch_only =
+      flags == IREE_HAL_AMDGPU_TARGET_ID_PARSE_FLAG_NONE ||
+      iree_any_bit_set(flags,
+                       IREE_HAL_AMDGPU_TARGET_ID_PARSE_FLAG_ALLOW_ARCH_ONLY);
+  const bool allow_feature_suffixes = iree_any_bit_set(
+      flags, IREE_HAL_AMDGPU_TARGET_ID_PARSE_FLAG_ALLOW_FEATURE_SUFFIXES);
+
+  if (iree_string_view_starts_with(value, IREE_SV("amdgcn-amd-amdhsa--"))) {
+    if (!allow_hsa_prefix) {
+      return iree_make_status(
+          IREE_STATUS_INVALID_ARGUMENT,
+          "HSA ISA prefix not accepted in AMDGPU target ID: %.*s",
+          (int)value.size, value.data);
+    }
+    value = iree_string_view_substr(value, IREE_SV("amdgcn-amd-amdhsa--").size,
+                                    IREE_STRING_VIEW_NPOS);
+  } else if (!allow_arch_only) {
+    return iree_make_status(
+        IREE_STATUS_INVALID_ARGUMENT,
+        "bare processor name not accepted in AMDGPU target ID: %.*s",
+        (int)value.size, value.data);
+  }
+
+  iree_string_view_t processor = value;
+  iree_string_view_t feature_list = iree_string_view_empty();
+  if (iree_string_view_split(value, ':', &processor, &feature_list) != -1) {
+    if (!allow_feature_suffixes) {
+      return iree_make_status(
+          IREE_STATUS_INVALID_ARGUMENT,
+          "AMDGPU target feature suffixes not accepted in target ID: %.*s",
+          (int)value.size, value.data);
+    }
+    if (iree_string_view_is_empty(feature_list) ||
+        value.data[value.size - 1] == ':') {
+      return iree_make_status(IREE_STATUS_INVALID_ARGUMENT,
+                              "AMDGPU target feature suffix is empty");
+    }
+  }
+  IREE_RETURN_IF_ERROR(
+      iree_hal_amdgpu_target_id_parse_processor(processor, out_target_id));
+
+  while (!iree_string_view_is_empty(feature_list)) {
+    iree_string_view_t feature = iree_string_view_empty();
+    iree_string_view_t remaining_features = iree_string_view_empty();
+    if (iree_string_view_split(feature_list, ':', &feature,
+                               &remaining_features) == -1) {
+      feature = feature_list;
+    }
+    IREE_RETURN_IF_ERROR(iree_hal_amdgpu_target_id_parse_feature(
+        feature, &out_target_id->sramecc, &out_target_id->xnack));
+    feature_list = remaining_features;
+  }
+  iree_hal_amdgpu_target_id_apply_known_feature_support(out_target_id);
+  return iree_ok_status();
+}
+
+IREE_API_EXPORT iree_status_t iree_hal_amdgpu_target_id_parse_hsa_isa_name(
+    iree_string_view_t value, iree_hal_amdgpu_target_id_t* out_target_id) {
+  return iree_hal_amdgpu_target_id_parse(
+      value,
+      IREE_HAL_AMDGPU_TARGET_ID_PARSE_FLAG_ALLOW_HSA_PREFIX |
+          IREE_HAL_AMDGPU_TARGET_ID_PARSE_FLAG_ALLOW_FEATURE_SUFFIXES,
+      out_target_id);
+}
+
+typedef struct iree_hal_amdgpu_target_id_formatter_t {
+  // Caller-provided output buffer; NULL when only querying required length.
+  char* buffer;
+  // Caller-provided output buffer capacity in bytes.
+  iree_host_size_t capacity;
+  // Required output length excluding the NUL terminator.
+  iree_host_size_t length;
+} iree_hal_amdgpu_target_id_formatter_t;
+
+static void iree_hal_amdgpu_target_id_formatter_append(
+    iree_hal_amdgpu_target_id_formatter_t* formatter,
+    iree_string_view_t value) {
+  if (formatter->buffer != NULL && formatter->capacity > 0 &&
+      formatter->length < formatter->capacity - 1) {
+    const iree_host_size_t available =
+        formatter->capacity - 1 - formatter->length;
+    const iree_host_size_t copy_length = iree_min(value.size, available);
+    memcpy(formatter->buffer + formatter->length, value.data, copy_length);
+    formatter->buffer[formatter->length + copy_length] = 0;
+  }
+  formatter->length += value.size;
+}
+
+static void iree_hal_amdgpu_target_id_formatter_append_feature(
+    iree_hal_amdgpu_target_id_formatter_t* formatter, iree_string_view_t name,
+    iree_hal_amdgpu_target_feature_state_t state) {
+  if (state == IREE_HAL_AMDGPU_TARGET_FEATURE_STATE_OFF) {
+    iree_hal_amdgpu_target_id_formatter_append(formatter, IREE_SV(":"));
+    iree_hal_amdgpu_target_id_formatter_append(formatter, name);
+    iree_hal_amdgpu_target_id_formatter_append(formatter, IREE_SV("-"));
+  } else if (state == IREE_HAL_AMDGPU_TARGET_FEATURE_STATE_ON) {
+    iree_hal_amdgpu_target_id_formatter_append(formatter, IREE_SV(":"));
+    iree_hal_amdgpu_target_id_formatter_append(formatter, name);
+    iree_hal_amdgpu_target_id_formatter_append(formatter, IREE_SV("+"));
+  }
+}
+
+IREE_API_EXPORT iree_status_t
+iree_hal_amdgpu_target_id_format(const iree_hal_amdgpu_target_id_t* target_id,
+                                 iree_host_size_t buffer_capacity, char* buffer,
+                                 iree_host_size_t* out_buffer_length) {
+  IREE_ASSERT_ARGUMENT(target_id);
+  if (IREE_UNLIKELY(iree_string_view_is_empty(target_id->processor))) {
+    return iree_make_status(IREE_STATUS_INVALID_ARGUMENT,
+                            "AMDGPU target ID has no processor name");
+  }
+
+  iree_hal_amdgpu_target_id_formatter_t formatter = {
+      .buffer = buffer,
+      .capacity = buffer_capacity,
+      .length = 0,
+  };
+  if (buffer != NULL && buffer_capacity > 0) buffer[0] = 0;
+  iree_hal_amdgpu_target_id_formatter_append(&formatter, target_id->processor);
+  iree_hal_amdgpu_target_id_formatter_append_feature(
+      &formatter, IREE_SV("sramecc"), target_id->sramecc);
+  iree_hal_amdgpu_target_id_formatter_append_feature(
+      &formatter, IREE_SV("xnack"), target_id->xnack);
+  if (out_buffer_length != NULL) {
+    *out_buffer_length = formatter.length;
+  }
+  if (buffer != NULL && buffer_capacity <= formatter.length) {
+    return iree_make_status(IREE_STATUS_OUT_OF_RANGE,
+                            "AMDGPU target ID buffer capacity exceeded");
+  }
+  return iree_ok_status();
+}
+
+static bool iree_hal_amdgpu_target_id_lookup_code_object_processor(
+    iree_string_view_t exact_processor,
+    iree_string_view_t* out_code_object_processor) {
+  const iree_hal_amdgpu_target_id_mapping_t* mapping =
+      iree_hal_amdgpu_target_id_lookup_mapping(exact_processor);
+  if (mapping == NULL) return false;
+  *out_code_object_processor = mapping->code_object_processor;
+  return true;
+}
+
+IREE_API_EXPORT iree_status_t
+iree_hal_amdgpu_target_id_lookup_code_object_target(
+    const iree_hal_amdgpu_target_id_t* exact_target_id,
+    iree_hal_amdgpu_target_id_t* out_code_object_target_id) {
+  IREE_ASSERT_ARGUMENT(exact_target_id);
+  IREE_ASSERT_ARGUMENT(out_code_object_target_id);
+  memset(out_code_object_target_id, 0, sizeof(*out_code_object_target_id));
+
+  if (exact_target_id->kind != IREE_HAL_AMDGPU_TARGET_KIND_EXACT) {
+    return iree_make_status(IREE_STATUS_INVALID_ARGUMENT,
+                            "AMDGPU code-object target lookup requires an "
+                            "exact processor target ID");
+  }
+  iree_string_view_t code_object_processor = exact_target_id->processor;
+  iree_hal_amdgpu_target_id_lookup_code_object_processor(
+      exact_target_id->processor, &code_object_processor);
+  if (iree_string_view_equal(code_object_processor,
+                             exact_target_id->processor)) {
+    *out_code_object_target_id = *exact_target_id;
+  } else {
+    IREE_RETURN_IF_ERROR(iree_hal_amdgpu_target_id_parse(
+        code_object_processor, IREE_HAL_AMDGPU_TARGET_ID_PARSE_FLAG_NONE,
+        out_code_object_target_id));
+  }
+  out_code_object_target_id->sramecc = exact_target_id->sramecc;
+  out_code_object_target_id->xnack = exact_target_id->xnack;
+  return iree_ok_status();
+}
+
+static bool iree_hal_amdgpu_generic_version_compatible(
+    iree_hal_amdgpu_gfxip_version_t code_object_version,
+    iree_hal_amdgpu_gfxip_version_t agent_version) {
+  if (code_object_version.major != agent_version.major) return false;
+  if (code_object_version.minor > agent_version.minor) return false;
+  if (code_object_version.minor == agent_version.minor &&
+      code_object_version.stepping > agent_version.stepping) {
+    return false;
+  }
+  return true;
+}
+
+static bool iree_hal_amdgpu_target_feature_compatible(
+    iree_hal_amdgpu_target_feature_state_t code_object_feature,
+    iree_hal_amdgpu_target_feature_state_t agent_feature) {
+  if (code_object_feature == IREE_HAL_AMDGPU_TARGET_FEATURE_STATE_ON ||
+      code_object_feature == IREE_HAL_AMDGPU_TARGET_FEATURE_STATE_OFF) {
+    return code_object_feature == agent_feature;
+  }
+  return true;
+}
+
+static uint32_t iree_hal_amdgpu_generic_code_object_minimum_version(
+    const iree_hal_amdgpu_target_id_t* generic_target_id) {
+  return generic_target_id->kind == IREE_HAL_AMDGPU_TARGET_KIND_GENERIC ? 1 : 0;
+}
+
+IREE_API_EXPORT iree_hal_amdgpu_target_compatibility_t
+iree_hal_amdgpu_target_id_check_compatible(
+    const iree_hal_amdgpu_target_id_t* code_object_target_id,
+    const iree_hal_amdgpu_target_id_t* agent_target_id) {
+  IREE_ASSERT_ARGUMENT(code_object_target_id);
+  IREE_ASSERT_ARGUMENT(agent_target_id);
+
+  iree_hal_amdgpu_target_compatibility_t compatibility =
+      IREE_HAL_AMDGPU_TARGET_COMPATIBILITY_COMPATIBLE;
+  if (code_object_target_id->kind == IREE_HAL_AMDGPU_TARGET_KIND_EXACT) {
+    if (agent_target_id->kind != IREE_HAL_AMDGPU_TARGET_KIND_EXACT ||
+        !iree_hal_amdgpu_gfxip_version_equal(code_object_target_id->version,
+                                             agent_target_id->version)) {
+      compatibility |= IREE_HAL_AMDGPU_TARGET_COMPATIBILITY_MISMATCH_PROCESSOR;
+    }
+  } else {
+    iree_string_view_t agent_code_object_processor = iree_string_view_empty();
+    if (agent_target_id->kind == IREE_HAL_AMDGPU_TARGET_KIND_EXACT) {
+      if (!iree_hal_amdgpu_target_id_lookup_code_object_processor(
+              agent_target_id->processor, &agent_code_object_processor)) {
+        if (!iree_hal_amdgpu_generic_version_compatible(
+                code_object_target_id->version, agent_target_id->version)) {
+          compatibility |=
+              IREE_HAL_AMDGPU_TARGET_COMPATIBILITY_MISMATCH_GENERIC_FAMILY;
+        }
+      } else if (!iree_string_view_equal(code_object_target_id->processor,
+                                         agent_code_object_processor)) {
+        compatibility |=
+            IREE_HAL_AMDGPU_TARGET_COMPATIBILITY_MISMATCH_GENERIC_FAMILY;
+      }
+    } else if (!iree_string_view_equal(code_object_target_id->processor,
+                                       agent_target_id->processor)) {
+      compatibility |=
+          IREE_HAL_AMDGPU_TARGET_COMPATIBILITY_MISMATCH_GENERIC_FAMILY;
+    }
+    const uint32_t minimum_generic_version =
+        iree_hal_amdgpu_generic_code_object_minimum_version(
+            code_object_target_id);
+    if (code_object_target_id->generic_version != 0 &&
+        code_object_target_id->generic_version < minimum_generic_version) {
+      compatibility |=
+          IREE_HAL_AMDGPU_TARGET_COMPATIBILITY_MISMATCH_GENERIC_VERSION;
+    }
+  }
+  if (!iree_hal_amdgpu_target_feature_compatible(code_object_target_id->sramecc,
+                                                 agent_target_id->sramecc)) {
+    compatibility |= IREE_HAL_AMDGPU_TARGET_COMPATIBILITY_MISMATCH_SRAMECC;
+  }
+  if (!iree_hal_amdgpu_target_feature_compatible(code_object_target_id->xnack,
+                                                 agent_target_id->xnack)) {
+    compatibility |= IREE_HAL_AMDGPU_TARGET_COMPATIBILITY_MISMATCH_XNACK;
+  }
+  return compatibility;
+}
+
+static void iree_hal_amdgpu_target_compatibility_formatter_append_reason(
+    iree_hal_amdgpu_target_id_formatter_t* formatter,
+    iree_host_size_t* inout_reason_count, iree_string_view_t reason) {
+  if (*inout_reason_count != 0) {
+    iree_hal_amdgpu_target_id_formatter_append(formatter, IREE_SV(", "));
+  }
+  iree_hal_amdgpu_target_id_formatter_append(formatter, reason);
+  ++*inout_reason_count;
+}
+
+IREE_API_EXPORT iree_status_t iree_hal_amdgpu_target_compatibility_format(
+    iree_hal_amdgpu_target_compatibility_t compatibility,
+    iree_host_size_t buffer_capacity, char* buffer,
+    iree_host_size_t* out_buffer_length) {
+  iree_hal_amdgpu_target_id_formatter_t formatter = {
+      .buffer = buffer,
+      .capacity = buffer_capacity,
+      .length = 0,
+  };
+  if (buffer != NULL && buffer_capacity > 0) buffer[0] = 0;
+
+  iree_host_size_t reason_count = 0;
+  if (compatibility == IREE_HAL_AMDGPU_TARGET_COMPATIBILITY_COMPATIBLE) {
+    iree_hal_amdgpu_target_compatibility_formatter_append_reason(
+        &formatter, &reason_count, IREE_SV("compatible"));
+  }
+  if (iree_any_bit_set(
+          compatibility,
+          IREE_HAL_AMDGPU_TARGET_COMPATIBILITY_MISMATCH_PROCESSOR)) {
+    iree_hal_amdgpu_target_compatibility_formatter_append_reason(
+        &formatter, &reason_count, IREE_SV("processor"));
+  }
+  if (iree_any_bit_set(
+          compatibility,
+          IREE_HAL_AMDGPU_TARGET_COMPATIBILITY_MISMATCH_GENERIC_FAMILY)) {
+    iree_hal_amdgpu_target_compatibility_formatter_append_reason(
+        &formatter, &reason_count, IREE_SV("generic family"));
+  }
+  if (iree_any_bit_set(
+          compatibility,
+          IREE_HAL_AMDGPU_TARGET_COMPATIBILITY_MISMATCH_GENERIC_VERSION)) {
+    iree_hal_amdgpu_target_compatibility_formatter_append_reason(
+        &formatter, &reason_count, IREE_SV("generic version"));
+  }
+  if (iree_any_bit_set(compatibility,
+                       IREE_HAL_AMDGPU_TARGET_COMPATIBILITY_MISMATCH_SRAMECC)) {
+    iree_hal_amdgpu_target_compatibility_formatter_append_reason(
+        &formatter, &reason_count, IREE_SV("sramecc"));
+  }
+  if (iree_any_bit_set(compatibility,
+                       IREE_HAL_AMDGPU_TARGET_COMPATIBILITY_MISMATCH_XNACK)) {
+    iree_hal_amdgpu_target_compatibility_formatter_append_reason(
+        &formatter, &reason_count, IREE_SV("xnack"));
+  }
+  if (out_buffer_length != NULL) {
+    *out_buffer_length = formatter.length;
+  }
+  if (buffer != NULL && buffer_capacity <= formatter.length) {
+    return iree_make_status(
+        IREE_STATUS_OUT_OF_RANGE,
+        "AMDGPU target compatibility buffer capacity exceeded");
+  }
+  return iree_ok_status();
+}
diff --git a/runtime/src/iree/hal/drivers/amdgpu/util/target_id.h b/runtime/src/iree/hal/drivers/amdgpu/util/target_id.h
new file mode 100644
index 0000000..8d29829
--- /dev/null
+++ b/runtime/src/iree/hal/drivers/amdgpu/util/target_id.h
@@ -0,0 +1,143 @@
+// Copyright 2026 The IREE Authors
+//
+// Licensed under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+
+#ifndef IREE_HAL_DRIVERS_AMDGPU_UTIL_TARGET_ID_H_
+#define IREE_HAL_DRIVERS_AMDGPU_UTIL_TARGET_ID_H_
+
+#include "iree/base/api.h"
+
+#ifdef __cplusplus
+extern "C" {
+#endif  // __cplusplus
+
+//===----------------------------------------------------------------------===//
+// AMDGPU Target IDs
+//===----------------------------------------------------------------------===//
+
+// Parsed gfx IP version.
+typedef struct iree_hal_amdgpu_gfxip_version_t {
+  // Major gfx ISA version, such as 9, 10, 11, or 12.
+  uint32_t major;
+  // Minor gfx ISA version within |major|.
+  uint32_t minor;
+  // Stepping digit within |major|.|minor|.
+  uint32_t stepping;
+} iree_hal_amdgpu_gfxip_version_t;
+
+// Target feature selector state from AMDGPU target IDs.
+typedef enum iree_hal_amdgpu_target_feature_state_e {
+  // Feature is not represented by the parsed target ID.
+  IREE_HAL_AMDGPU_TARGET_FEATURE_STATE_ANY = 0,
+  // Feature is known not to be supported by the target.
+  IREE_HAL_AMDGPU_TARGET_FEATURE_STATE_UNSUPPORTED,
+  // Feature is explicitly disabled, such as `:xnack-`.
+  IREE_HAL_AMDGPU_TARGET_FEATURE_STATE_OFF,
+  // Feature is explicitly enabled, such as `:sramecc+`.
+  IREE_HAL_AMDGPU_TARGET_FEATURE_STATE_ON,
+} iree_hal_amdgpu_target_feature_state_t;
+
+// Target processor name class.
+typedef enum iree_hal_amdgpu_target_kind_e {
+  // Exact target processor such as `gfx942` or `gfx1100`.
+  IREE_HAL_AMDGPU_TARGET_KIND_EXACT = 0,
+  // Generic target processor such as `gfx9-4-generic` or `gfx11-generic`.
+  IREE_HAL_AMDGPU_TARGET_KIND_GENERIC,
+} iree_hal_amdgpu_target_kind_t;
+
+// Parsed AMDGPU target ID.
+typedef struct iree_hal_amdgpu_target_id_t {
+  // Processor name class used to interpret |version|.
+  iree_hal_amdgpu_target_kind_t kind;
+  // Parsed gfx IP version or generic family version.
+  iree_hal_amdgpu_gfxip_version_t version;
+  // Generic code-object format version from ELF e_flags, or 0 if unspecified.
+  uint32_t generic_version;
+  // SRAM ECC selector state.
+  iree_hal_amdgpu_target_feature_state_t sramecc;
+  // XNACK selector state.
+  iree_hal_amdgpu_target_feature_state_t xnack;
+  // Borrowed processor name without feature suffixes or HSA triple prefix.
+  iree_string_view_t processor;
+} iree_hal_amdgpu_target_id_t;
+
+// Parser modes for AMDGPU target IDs.
+typedef enum iree_hal_amdgpu_target_id_parse_flag_bits_e {
+  // Requires a bare processor name with no feature suffixes.
+  IREE_HAL_AMDGPU_TARGET_ID_PARSE_FLAG_NONE = 0u,
+  // Accepts `amdgcn-amd-amdhsa--`-prefixed HSA ISA names.
+  IREE_HAL_AMDGPU_TARGET_ID_PARSE_FLAG_ALLOW_HSA_PREFIX = 1u << 0,
+  // Accepts bare processor names such as `gfx942`.
+  IREE_HAL_AMDGPU_TARGET_ID_PARSE_FLAG_ALLOW_ARCH_ONLY = 1u << 1,
+  // Accepts AMDGPU feature suffixes such as `:sramecc+:xnack-`.
+  IREE_HAL_AMDGPU_TARGET_ID_PARSE_FLAG_ALLOW_FEATURE_SUFFIXES = 1u << 2,
+} iree_hal_amdgpu_target_id_parse_flag_bits_t;
+typedef uint32_t iree_hal_amdgpu_target_id_parse_flags_t;
+
+// Compatibility reasons reported by iree_hal_amdgpu_target_id_check_compatible.
+typedef enum iree_hal_amdgpu_target_compatibility_bits_e {
+  // Code object target identity is compatible with the agent target identity.
+  IREE_HAL_AMDGPU_TARGET_COMPATIBILITY_COMPATIBLE = 0u,
+  // Exact processor versions do not match.
+  IREE_HAL_AMDGPU_TARGET_COMPATIBILITY_MISMATCH_PROCESSOR = 1u << 0,
+  // Generic processor family does not match the agent's mapped family.
+  IREE_HAL_AMDGPU_TARGET_COMPATIBILITY_MISMATCH_GENERIC_FAMILY = 1u << 1,
+  // Generic code-object version is older than the agent's supported floor.
+  IREE_HAL_AMDGPU_TARGET_COMPATIBILITY_MISMATCH_GENERIC_VERSION = 1u << 2,
+  // Explicit SRAM ECC mode does not match the agent.
+  IREE_HAL_AMDGPU_TARGET_COMPATIBILITY_MISMATCH_SRAMECC = 1u << 3,
+  // Explicit XNACK mode does not match the agent.
+  IREE_HAL_AMDGPU_TARGET_COMPATIBILITY_MISMATCH_XNACK = 1u << 4,
+} iree_hal_amdgpu_target_compatibility_bits_t;
+typedef uint32_t iree_hal_amdgpu_target_compatibility_t;
+
+// Parses an AMDGPU target ID into |out_target_id|.
+//
+// String views in |out_target_id| borrow from |value| and are only valid while
+// |value| storage remains live.
+iree_status_t iree_hal_amdgpu_target_id_parse(
+    iree_string_view_t value, iree_hal_amdgpu_target_id_parse_flags_t flags,
+    iree_hal_amdgpu_target_id_t* out_target_id);
+
+// Parses an HSA ISA name reported by HSA_ISA_INFO_NAME.
+iree_status_t iree_hal_amdgpu_target_id_parse_hsa_isa_name(
+    iree_string_view_t value, iree_hal_amdgpu_target_id_t* out_target_id);
+
+// Formats |target_id| into canonical AMDGPU target-ID syntax.
+//
+// A zero |buffer_capacity| queries the length: the call succeeds and
+// |out_buffer_length| receives the required character length excluding the
+// NUL terminator. A non-zero but insufficient capacity fails with
+// OUT_OF_RANGE and still sets |out_buffer_length|.
+iree_status_t iree_hal_amdgpu_target_id_format(
+    const iree_hal_amdgpu_target_id_t* target_id,
+    iree_host_size_t buffer_capacity, char* buffer,
+    iree_host_size_t* out_buffer_length);
+
+// Maps an exact target ID to the processor used for code objects and device
+// libraries. If no generic-compatible mapping is known, the exact target is
+// returned unchanged.
+iree_status_t iree_hal_amdgpu_target_id_lookup_code_object_target(
+    const iree_hal_amdgpu_target_id_t* exact_target_id,
+    iree_hal_amdgpu_target_id_t* out_code_object_target_id);
+
+// Checks whether |code_object_target_id| can execute on |agent_target_id|.
+iree_hal_amdgpu_target_compatibility_t
+iree_hal_amdgpu_target_id_check_compatible(
+    const iree_hal_amdgpu_target_id_t* code_object_target_id,
+    const iree_hal_amdgpu_target_id_t* agent_target_id);
+
+// Formats compatibility mismatch bits into a comma-separated diagnostic string.
+//
+// A zero |buffer_capacity| queries the length: the call succeeds and
+// |out_buffer_length| receives the required character length excluding the
+// NUL terminator. A non-zero but insufficient capacity fails with
+// OUT_OF_RANGE and still sets |out_buffer_length|.
+iree_status_t iree_hal_amdgpu_target_compatibility_format(
+    iree_hal_amdgpu_target_compatibility_t compatibility,
+    iree_host_size_t buffer_capacity, char* buffer,
+    iree_host_size_t* out_buffer_length);
+
+#ifdef __cplusplus
+}  // extern "C"
+#endif  // __cplusplus
+
+#endif  // IREE_HAL_DRIVERS_AMDGPU_UTIL_TARGET_ID_H_
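
For reviewers, a minimal sketch of how this API surface composes: parse a compiler-emitted target ID, parse the agent's HSA ISA name, check compatibility, and format the mismatch bits into a diagnostic. The helper name and error shaping below are illustrative only and do not appear in the branch; the `iree_hal_amdgpu_target_id_*` calls are the real API.

```c
#include "iree/hal/drivers/amdgpu/util/target_id.h"

// Hypothetical helper (not in the branch): gates code object loading on
// target-ID compatibility and surfaces the mismatch reasons in the status.
static iree_status_t validate_code_object_for_agent(
    iree_string_view_t code_object_target, iree_string_view_t agent_isa_name) {
  // Compiler-emitted IDs arrive as bare arch names with optional features.
  iree_hal_amdgpu_target_id_t code_object_id;
  IREE_RETURN_IF_ERROR(iree_hal_amdgpu_target_id_parse(
      code_object_target,
      IREE_HAL_AMDGPU_TARGET_ID_PARSE_FLAG_ALLOW_ARCH_ONLY |
          IREE_HAL_AMDGPU_TARGET_ID_PARSE_FLAG_ALLOW_FEATURE_SUFFIXES,
      &code_object_id));

  // Agent IDs come back from HSA_ISA_INFO_NAME with the triple prefix.
  iree_hal_amdgpu_target_id_t agent_id;
  IREE_RETURN_IF_ERROR(iree_hal_amdgpu_target_id_parse_hsa_isa_name(
      agent_isa_name, &agent_id));

  iree_hal_amdgpu_target_compatibility_t compatibility =
      iree_hal_amdgpu_target_id_check_compatible(&code_object_id, &agent_id);
  if (compatibility == IREE_HAL_AMDGPU_TARGET_COMPATIBILITY_COMPATIBLE) {
    return iree_ok_status();
  }

  // Turn the mismatch bits into a human-readable reason list.
  char reasons[64] = {0};
  IREE_RETURN_IF_ERROR(iree_hal_amdgpu_target_compatibility_format(
      compatibility, sizeof(reasons), reasons, /*out_buffer_length=*/NULL));
  return iree_make_status(
      IREE_STATUS_INVALID_ARGUMENT,
      "code object target '%.*s' is incompatible with the agent (%s)",
      (int)code_object_target.size, code_object_target.data, reasons);
}
```
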
diff --git a/runtime/src/iree/hal/drivers/amdgpu/util/target_id_map.inl b/runtime/src/iree/hal/drivers/amdgpu/util/target_id_map.inl
new file mode 100644
index 0000000..834f24b
--- /dev/null
+++ b/runtime/src/iree/hal/drivers/amdgpu/util/target_id_map.inl
@@ -0,0 +1,45 @@
+// Generated by build_tools/scripts/amdgpu_target_map.py.
+// Do not edit directly; edit the map in that script and regenerate.
+// Output: runtime/src/iree/hal/drivers/amdgpu/util/target_id_map.inl
+//
+// Included inside iree_hal_amdgpu_target_id_mappings.
+
+// clang-format off
+{IREE_SVL("gfx900"), IREE_SVL("gfx9-generic"), IREE_HAL_AMDGPU_TARGET_FEATURE_SUPPORT_XNACK},
+{IREE_SVL("gfx902"), IREE_SVL("gfx9-generic"), IREE_HAL_AMDGPU_TARGET_FEATURE_SUPPORT_XNACK},
+{IREE_SVL("gfx904"), IREE_SVL("gfx9-generic"), IREE_HAL_AMDGPU_TARGET_FEATURE_SUPPORT_XNACK},
+{IREE_SVL("gfx90c"), IREE_SVL("gfx9-generic"), IREE_HAL_AMDGPU_TARGET_FEATURE_SUPPORT_XNACK},
+{IREE_SVL("gfx906"), IREE_SVL("gfx9-generic"), IREE_HAL_AMDGPU_TARGET_FEATURE_SUPPORT_SRAMECC | IREE_HAL_AMDGPU_TARGET_FEATURE_SUPPORT_XNACK},
+{IREE_SVL("gfx908"), IREE_SVL("gfx908"), IREE_HAL_AMDGPU_TARGET_FEATURE_SUPPORT_SRAMECC | IREE_HAL_AMDGPU_TARGET_FEATURE_SUPPORT_XNACK},
+{IREE_SVL("gfx909"), IREE_SVL("gfx9-generic"), IREE_HAL_AMDGPU_TARGET_FEATURE_SUPPORT_XNACK},
+{IREE_SVL("gfx90a"), IREE_SVL("gfx90a"), IREE_HAL_AMDGPU_TARGET_FEATURE_SUPPORT_SRAMECC | IREE_HAL_AMDGPU_TARGET_FEATURE_SUPPORT_XNACK},
+{IREE_SVL("gfx940"), IREE_SVL("gfx9-4-generic"), IREE_HAL_AMDGPU_TARGET_FEATURE_SUPPORT_SRAMECC | IREE_HAL_AMDGPU_TARGET_FEATURE_SUPPORT_XNACK},
+{IREE_SVL("gfx941"), IREE_SVL("gfx9-4-generic"), IREE_HAL_AMDGPU_TARGET_FEATURE_SUPPORT_SRAMECC | IREE_HAL_AMDGPU_TARGET_FEATURE_SUPPORT_XNACK},
+{IREE_SVL("gfx942"), IREE_SVL("gfx9-4-generic"), IREE_HAL_AMDGPU_TARGET_FEATURE_SUPPORT_SRAMECC | IREE_HAL_AMDGPU_TARGET_FEATURE_SUPPORT_XNACK},
+{IREE_SVL("gfx950"), IREE_SVL("gfx9-4-generic"), IREE_HAL_AMDGPU_TARGET_FEATURE_SUPPORT_SRAMECC | IREE_HAL_AMDGPU_TARGET_FEATURE_SUPPORT_XNACK},
+{IREE_SVL("gfx1010"), IREE_SVL("gfx10-1-generic"), IREE_HAL_AMDGPU_TARGET_FEATURE_SUPPORT_XNACK},
+{IREE_SVL("gfx1011"), IREE_SVL("gfx10-1-generic"), IREE_HAL_AMDGPU_TARGET_FEATURE_SUPPORT_XNACK},
+{IREE_SVL("gfx1012"), IREE_SVL("gfx10-1-generic"), IREE_HAL_AMDGPU_TARGET_FEATURE_SUPPORT_XNACK},
+{IREE_SVL("gfx1013"), IREE_SVL("gfx10-1-generic"), IREE_HAL_AMDGPU_TARGET_FEATURE_SUPPORT_XNACK},
+{IREE_SVL("gfx1030"), IREE_SVL("gfx10-3-generic"), IREE_HAL_AMDGPU_TARGET_FEATURE_SUPPORT_NONE},
+{IREE_SVL("gfx1031"), IREE_SVL("gfx10-3-generic"), IREE_HAL_AMDGPU_TARGET_FEATURE_SUPPORT_NONE},
+{IREE_SVL("gfx1032"), IREE_SVL("gfx10-3-generic"), IREE_HAL_AMDGPU_TARGET_FEATURE_SUPPORT_NONE},
+{IREE_SVL("gfx1033"), IREE_SVL("gfx10-3-generic"), IREE_HAL_AMDGPU_TARGET_FEATURE_SUPPORT_NONE},
+{IREE_SVL("gfx1034"), IREE_SVL("gfx10-3-generic"), IREE_HAL_AMDGPU_TARGET_FEATURE_SUPPORT_NONE},
+{IREE_SVL("gfx1035"), IREE_SVL("gfx10-3-generic"), IREE_HAL_AMDGPU_TARGET_FEATURE_SUPPORT_NONE},
+{IREE_SVL("gfx1036"), IREE_SVL("gfx10-3-generic"), IREE_HAL_AMDGPU_TARGET_FEATURE_SUPPORT_NONE},
+{IREE_SVL("gfx1100"), IREE_SVL("gfx11-generic"), IREE_HAL_AMDGPU_TARGET_FEATURE_SUPPORT_NONE},
+{IREE_SVL("gfx1101"), IREE_SVL("gfx11-generic"), IREE_HAL_AMDGPU_TARGET_FEATURE_SUPPORT_NONE},
+{IREE_SVL("gfx1102"), IREE_SVL("gfx11-generic"), IREE_HAL_AMDGPU_TARGET_FEATURE_SUPPORT_NONE},
+{IREE_SVL("gfx1103"), IREE_SVL("gfx11-generic"), IREE_HAL_AMDGPU_TARGET_FEATURE_SUPPORT_NONE},
+{IREE_SVL("gfx1150"), IREE_SVL("gfx11-generic"), IREE_HAL_AMDGPU_TARGET_FEATURE_SUPPORT_NONE},
+{IREE_SVL("gfx1151"), IREE_SVL("gfx11-generic"), IREE_HAL_AMDGPU_TARGET_FEATURE_SUPPORT_NONE},
+{IREE_SVL("gfx1152"), IREE_SVL("gfx11-generic"), IREE_HAL_AMDGPU_TARGET_FEATURE_SUPPORT_NONE},
+{IREE_SVL("gfx1153"), IREE_SVL("gfx11-generic"), IREE_HAL_AMDGPU_TARGET_FEATURE_SUPPORT_NONE},
+{IREE_SVL("gfx1170"), IREE_SVL("gfx11-generic"), IREE_HAL_AMDGPU_TARGET_FEATURE_SUPPORT_NONE},
+{IREE_SVL("gfx1171"), IREE_SVL("gfx11-generic"), IREE_HAL_AMDGPU_TARGET_FEATURE_SUPPORT_NONE},
+{IREE_SVL("gfx1172"), IREE_SVL("gfx11-generic"), IREE_HAL_AMDGPU_TARGET_FEATURE_SUPPORT_NONE},
+{IREE_SVL("gfx1200"), IREE_SVL("gfx12-generic"), IREE_HAL_AMDGPU_TARGET_FEATURE_SUPPORT_NONE},
+{IREE_SVL("gfx1201"), IREE_SVL("gfx12-generic"), IREE_HAL_AMDGPU_TARGET_FEATURE_SUPPORT_NONE},
+{IREE_SVL("gfx1250"), IREE_SVL("gfx12-5-generic"), IREE_HAL_AMDGPU_TARGET_FEATURE_SUPPORT_NONE},
+{IREE_SVL("gfx1251"), IREE_SVL("gfx12-5-generic"), IREE_HAL_AMDGPU_TARGET_FEATURE_SUPPORT_NONE},
diff --git a/runtime/src/iree/hal/drivers/amdgpu/util/target_id_test.cc b/runtime/src/iree/hal/drivers/amdgpu/util/target_id_test.cc
new file mode 100644
index 0000000..852e5be
--- /dev/null
+++ b/runtime/src/iree/hal/drivers/amdgpu/util/target_id_test.cc
@@ -0,0 +1,270 @@
+// Copyright 2026 The IREE Authors
+//
+// Licensed under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+
+#include "iree/hal/drivers/amdgpu/util/target_id.h"
+
+#include <cstring>
+#include <string>
+
+#include "iree/base/api.h"
+#include "iree/testing/gtest.h"
+#include "iree/testing/status_matchers.h"
+
+namespace iree::hal::amdgpu {
+namespace {
+
+static constexpr iree_hal_amdgpu_target_id_parse_flags_t
+    kArchFeatureParseFlags =
+        IREE_HAL_AMDGPU_TARGET_ID_PARSE_FLAG_ALLOW_ARCH_ONLY |
+        IREE_HAL_AMDGPU_TARGET_ID_PARSE_FLAG_ALLOW_FEATURE_SUFFIXES;
+
+static iree_hal_amdgpu_target_id_t ParseTargetId(
+    const char* value, iree_hal_amdgpu_target_id_parse_flags_t flags =
+                           IREE_HAL_AMDGPU_TARGET_ID_PARSE_FLAG_NONE) {
+  iree_hal_amdgpu_target_id_t target_id;
+  IREE_CHECK_OK(iree_hal_amdgpu_target_id_parse(iree_make_cstring_view(value),
+                                                flags, &target_id));
+  return target_id;
+}
+
+static std::string FormatTargetId(
+    const iree_hal_amdgpu_target_id_t* target_id) {
+  char buffer[64] = {0};
+  IREE_CHECK_OK(iree_hal_amdgpu_target_id_format(
+      target_id, sizeof(buffer), buffer, /*out_buffer_length=*/nullptr));
+  return std::string(buffer);
+}
+
+TEST(TargetIdTest, ParsesExactProcessor) {
+  auto target_id = ParseTargetId("gfx1100");
+  EXPECT_EQ(target_id.kind, IREE_HAL_AMDGPU_TARGET_KIND_EXACT);
+  EXPECT_EQ(target_id.version.major, 11u);
+  EXPECT_EQ(target_id.version.minor, 0u);
+  EXPECT_EQ(target_id.version.stepping, 0u);
+  EXPECT_TRUE(iree_string_view_equal(target_id.processor, IREE_SV("gfx1100")));
+  EXPECT_EQ(target_id.sramecc,
+            IREE_HAL_AMDGPU_TARGET_FEATURE_STATE_UNSUPPORTED);
+  EXPECT_EQ(target_id.xnack, IREE_HAL_AMDGPU_TARGET_FEATURE_STATE_UNSUPPORTED);
+}
+
+TEST(TargetIdTest, ParsesExactProcessorWithHexStepping) {
+  auto target_id = ParseTargetId("gfx90a");
+  EXPECT_EQ(target_id.kind, IREE_HAL_AMDGPU_TARGET_KIND_EXACT);
+  EXPECT_EQ(target_id.version.major, 9u);
+  EXPECT_EQ(target_id.version.minor, 0u);
+  EXPECT_EQ(target_id.version.stepping, 10u);
+}
+
+TEST(TargetIdTest, ParsesGenericProcessor) {
+  auto target_id = ParseTargetId("gfx9-4-generic");
+  EXPECT_EQ(target_id.kind, IREE_HAL_AMDGPU_TARGET_KIND_GENERIC);
+  EXPECT_EQ(target_id.version.major, 9u);
+  EXPECT_EQ(target_id.version.minor, 4u);
+  EXPECT_EQ(target_id.version.stepping, 0u);
+  EXPECT_TRUE(
+      iree_string_view_equal(target_id.processor, IREE_SV("gfx9-4-generic")));
+
+  target_id = ParseTargetId("gfx11-generic");
+  EXPECT_EQ(target_id.kind, IREE_HAL_AMDGPU_TARGET_KIND_GENERIC);
+  EXPECT_EQ(target_id.version.major, 11u);
+  EXPECT_EQ(target_id.version.minor, 0u);
+  EXPECT_EQ(target_id.version.stepping, 0u);
+}
+
+TEST(TargetIdTest, ParsesKnownFeatureSupport) {
+  auto target_id = ParseTargetId("gfx942");
+  EXPECT_EQ(target_id.sramecc, IREE_HAL_AMDGPU_TARGET_FEATURE_STATE_ANY);
+  EXPECT_EQ(target_id.xnack, IREE_HAL_AMDGPU_TARGET_FEATURE_STATE_ANY);
+
+  target_id = ParseTargetId("gfx1030");
+  EXPECT_EQ(target_id.sramecc,
+            IREE_HAL_AMDGPU_TARGET_FEATURE_STATE_UNSUPPORTED);
+  EXPECT_EQ(target_id.xnack, IREE_HAL_AMDGPU_TARGET_FEATURE_STATE_UNSUPPORTED);
+
+  target_id = ParseTargetId("gfx1013");
+  EXPECT_EQ(target_id.sramecc,
+            IREE_HAL_AMDGPU_TARGET_FEATURE_STATE_UNSUPPORTED);
+  EXPECT_EQ(target_id.xnack, IREE_HAL_AMDGPU_TARGET_FEATURE_STATE_ANY);
+}
+
+TEST(TargetIdTest, ParsesHsaIsaNameWithFeatureSuffixes) {
+  iree_hal_amdgpu_target_id_t target_id;
+  IREE_ASSERT_OK(iree_hal_amdgpu_target_id_parse_hsa_isa_name(
+      IREE_SV("amdgcn-amd-amdhsa--gfx942:xnack-:sramecc+"), &target_id));
+  EXPECT_EQ(target_id.kind, IREE_HAL_AMDGPU_TARGET_KIND_EXACT);
+  EXPECT_EQ(target_id.version.major, 9u);
+  EXPECT_EQ(target_id.version.minor, 4u);
+  EXPECT_EQ(target_id.version.stepping, 2u);
+  EXPECT_EQ(target_id.sramecc, IREE_HAL_AMDGPU_TARGET_FEATURE_STATE_ON);
+  EXPECT_EQ(target_id.xnack, IREE_HAL_AMDGPU_TARGET_FEATURE_STATE_OFF);
+  EXPECT_EQ(FormatTargetId(&target_id), "gfx942:sramecc+:xnack-");
+}
+
+TEST(TargetIdTest, RejectsUnsupportedSyntax) {
+  iree_hal_amdgpu_target_id_t target_id;
+  IREE_EXPECT_STATUS_IS(
+      IREE_STATUS_INVALID_ARGUMENT,
+      iree_hal_amdgpu_target_id_parse(IREE_SV("amdgcn-amd-amdhsa--gfx942"),
+                                      IREE_HAL_AMDGPU_TARGET_ID_PARSE_FLAG_NONE,
+                                      &target_id));
+  IREE_EXPECT_STATUS_IS(
+      IREE_STATUS_INVALID_ARGUMENT,
+      iree_hal_amdgpu_target_id_parse(IREE_SV("gfx942:xnack+"),
+                                      IREE_HAL_AMDGPU_TARGET_ID_PARSE_FLAG_NONE,
+                                      &target_id));
+  IREE_EXPECT_STATUS_IS(
+      IREE_STATUS_INVALID_ARGUMENT,
+      iree_hal_amdgpu_target_id_parse(IREE_SV("gfx942foo"),
+                                      IREE_HAL_AMDGPU_TARGET_ID_PARSE_FLAG_NONE,
+                                      &target_id));
+  IREE_EXPECT_STATUS_IS(
+      IREE_STATUS_INVALID_ARGUMENT,
+      iree_hal_amdgpu_target_id_parse(IREE_SV("gfx942:xnack+:xnack-"),
+                                      kArchFeatureParseFlags, &target_id));
+  IREE_EXPECT_STATUS_IS(
+      IREE_STATUS_INVALID_ARGUMENT,
+      iree_hal_amdgpu_target_id_parse(IREE_SV("gfx942:wavefrontsize64+"),
+                                      kArchFeatureParseFlags, &target_id));
+  IREE_EXPECT_STATUS_IS(
+      IREE_STATUS_INVALID_ARGUMENT,
+      iree_hal_amdgpu_target_id_parse(IREE_SV("gfx942:"),
+                                      kArchFeatureParseFlags, &target_id));
+  IREE_EXPECT_STATUS_IS(
+      IREE_STATUS_INVALID_ARGUMENT,
+      iree_hal_amdgpu_target_id_parse(IREE_SV("gfx942:xnack+:"),
+                                      kArchFeatureParseFlags, &target_id));
+}
+
+TEST(TargetIdTest, FormatsIntoQueriedBufferLength) {
+  auto target_id =
+      ParseTargetId("gfx942:sramecc+:xnack-", kArchFeatureParseFlags);
+  iree_host_size_t required_length = 0;
+  IREE_EXPECT_OK(iree_hal_amdgpu_target_id_format(
+      &target_id, /*buffer_capacity=*/0, /*buffer=*/nullptr, &required_length));
+  EXPECT_EQ(required_length, strlen("gfx942:sramecc+:xnack-"));
+
+  char buffer[8] = {0};
+  IREE_EXPECT_STATUS_IS(
+      IREE_STATUS_OUT_OF_RANGE,
+      iree_hal_amdgpu_target_id_format(&target_id, sizeof(buffer), buffer,
+                                       &required_length));
+  EXPECT_EQ(required_length, strlen("gfx942:sramecc+:xnack-"));
+}
+
+TEST(TargetIdTest, LooksUpCodeObjectTarget) {
+  auto target_id =
+      ParseTargetId("gfx942:sramecc+:xnack-", kArchFeatureParseFlags);
+  iree_hal_amdgpu_target_id_t code_object_target_id;
+  IREE_ASSERT_OK(iree_hal_amdgpu_target_id_lookup_code_object_target(
+      &target_id, &code_object_target_id));
+  EXPECT_EQ(code_object_target_id.kind, IREE_HAL_AMDGPU_TARGET_KIND_GENERIC);
+  EXPECT_TRUE(iree_string_view_equal(code_object_target_id.processor,
+                                     IREE_SV("gfx9-4-generic")));
+  EXPECT_EQ(code_object_target_id.sramecc,
+            IREE_HAL_AMDGPU_TARGET_FEATURE_STATE_ON);
+  EXPECT_EQ(code_object_target_id.xnack,
+            IREE_HAL_AMDGPU_TARGET_FEATURE_STATE_OFF);
+
+  target_id = ParseTargetId("gfx908");
+  IREE_ASSERT_OK(iree_hal_amdgpu_target_id_lookup_code_object_target(
+      &target_id, &code_object_target_id));
+  EXPECT_EQ(code_object_target_id.kind, IREE_HAL_AMDGPU_TARGET_KIND_EXACT);
+  EXPECT_TRUE(iree_string_view_equal(code_object_target_id.processor,
+                                     IREE_SV("gfx908")));
+
+  target_id = ParseTargetId("gfx1300");
+  IREE_ASSERT_OK(iree_hal_amdgpu_target_id_lookup_code_object_target(
+      &target_id, &code_object_target_id));
+  EXPECT_EQ(code_object_target_id.kind, IREE_HAL_AMDGPU_TARGET_KIND_EXACT);
+  EXPECT_TRUE(iree_string_view_equal(code_object_target_id.processor,
+                                     IREE_SV("gfx1300")));
+}
+
+TEST(TargetIdTest, ChecksExactCompatibility) {
+  auto code_object_target_id = ParseTargetId("gfx1100");
+  auto agent_target_id = ParseTargetId("gfx1100");
+  EXPECT_EQ(iree_hal_amdgpu_target_id_check_compatible(&code_object_target_id,
+                                                       &agent_target_id),
+            IREE_HAL_AMDGPU_TARGET_COMPATIBILITY_COMPATIBLE);
+
+  agent_target_id = ParseTargetId("gfx1101");
+  EXPECT_TRUE(iree_any_bit_set(
+      iree_hal_amdgpu_target_id_check_compatible(&code_object_target_id,
+                                                 &agent_target_id),
+      IREE_HAL_AMDGPU_TARGET_COMPATIBILITY_MISMATCH_PROCESSOR));
+}
+
+TEST(TargetIdTest, ChecksGenericCompatibilityWithMappedFamily) {
+  auto code_object_target_id = ParseTargetId("gfx11-generic");
+  auto agent_target_id = ParseTargetId("gfx1100");
+  EXPECT_EQ(iree_hal_amdgpu_target_id_check_compatible(&code_object_target_id,
+                                                       &agent_target_id),
+            IREE_HAL_AMDGPU_TARGET_COMPATIBILITY_COMPATIBLE);
+
+  code_object_target_id = ParseTargetId("gfx9-4-generic");
+  agent_target_id = ParseTargetId("gfx942");
+  EXPECT_EQ(iree_hal_amdgpu_target_id_check_compatible(&code_object_target_id,
+                                                       &agent_target_id),
+            IREE_HAL_AMDGPU_TARGET_COMPATIBILITY_COMPATIBLE);
+
+  code_object_target_id = ParseTargetId("gfx9-generic");
+  EXPECT_TRUE(iree_any_bit_set(
+      iree_hal_amdgpu_target_id_check_compatible(&code_object_target_id,
+                                                 &agent_target_id),
+      IREE_HAL_AMDGPU_TARGET_COMPATIBILITY_MISMATCH_GENERIC_FAMILY));
+
+  code_object_target_id = ParseTargetId("gfx9-4-generic");
+  agent_target_id = ParseTargetId("gfx940");
+  EXPECT_EQ(iree_hal_amdgpu_target_id_check_compatible(&code_object_target_id,
+                                                       &agent_target_id),
+            IREE_HAL_AMDGPU_TARGET_COMPATIBILITY_COMPATIBLE);
+
+  code_object_target_id = ParseTargetId("gfx9-generic");
+  EXPECT_TRUE(iree_any_bit_set(
+      iree_hal_amdgpu_target_id_check_compatible(&code_object_target_id,
+                                                 &agent_target_id),
+      IREE_HAL_AMDGPU_TARGET_COMPATIBILITY_MISMATCH_GENERIC_FAMILY));
+
+  code_object_target_id = ParseTargetId("gfx12-5-generic");
+  agent_target_id = ParseTargetId("gfx1250");
+  EXPECT_EQ(iree_hal_amdgpu_target_id_check_compatible(&code_object_target_id,
+                                                       &agent_target_id),
+            IREE_HAL_AMDGPU_TARGET_COMPATIBILITY_COMPATIBLE);
+
+  code_object_target_id = ParseTargetId("gfx12-generic");
+  EXPECT_TRUE(iree_any_bit_set(
+      iree_hal_amdgpu_target_id_check_compatible(&code_object_target_id,
+                                                 &agent_target_id),
+      IREE_HAL_AMDGPU_TARGET_COMPATIBILITY_MISMATCH_GENERIC_FAMILY));
+}
+
+TEST(TargetIdTest, ChecksFeatureCompatibility) {
+  auto code_object_target_id =
+      ParseTargetId("gfx942:xnack+", kArchFeatureParseFlags);
+  auto agent_target_id = ParseTargetId("gfx942:xnack-", kArchFeatureParseFlags);
+  EXPECT_TRUE(
+      iree_any_bit_set(iree_hal_amdgpu_target_id_check_compatible(
+                           &code_object_target_id, &agent_target_id),
+                       IREE_HAL_AMDGPU_TARGET_COMPATIBILITY_MISMATCH_XNACK));
+
+  code_object_target_id = ParseTargetId("gfx942");
+  EXPECT_EQ(iree_hal_amdgpu_target_id_check_compatible(&code_object_target_id,
+                                                       &agent_target_id),
+            IREE_HAL_AMDGPU_TARGET_COMPATIBILITY_COMPATIBLE);
+}
+
+TEST(TargetIdTest, FormatsCompatibilityReasons) {
+  char buffer[64] = {0};
+  IREE_ASSERT_OK(iree_hal_amdgpu_target_compatibility_format(
+      IREE_HAL_AMDGPU_TARGET_COMPATIBILITY_MISMATCH_GENERIC_FAMILY |
+          IREE_HAL_AMDGPU_TARGET_COMPATIBILITY_MISMATCH_SRAMECC |
+          IREE_HAL_AMDGPU_TARGET_COMPATIBILITY_MISMATCH_XNACK,
+      sizeof(buffer), buffer, /*out_buffer_length=*/nullptr));
+  EXPECT_STREQ(buffer, "generic family, sramecc, xnack");
+}
+
+}  // namespace
+}  // namespace iree::hal::amdgpu
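
The tests above pin down the two-call length-query convention (zero capacity succeeds and reports the length; a short buffer fails with OUT_OF_RANGE). A sketch of the grow-to-fit pattern this enables; the helper is hypothetical and the allocation details are illustrative:

```c
#include "iree/base/api.h"
#include "iree/hal/drivers/amdgpu/util/target_id.h"

// Hypothetical helper: formats |target_id| into a heap-allocated,
// NUL-terminated string sized by the zero-capacity length query.
static iree_status_t format_target_id_to_heap(
    const iree_hal_amdgpu_target_id_t* target_id,
    iree_allocator_t host_allocator, char** out_str) {
  *out_str = NULL;
  // Capacity 0 succeeds and reports the required length (excluding NUL).
  iree_host_size_t required_length = 0;
  IREE_RETURN_IF_ERROR(iree_hal_amdgpu_target_id_format(
      target_id, /*buffer_capacity=*/0, /*buffer=*/NULL, &required_length));
  char* buffer = NULL;
  IREE_RETURN_IF_ERROR(iree_allocator_malloc(
      host_allocator, required_length + 1, (void**)&buffer));
  iree_status_t status = iree_hal_amdgpu_target_id_format(
      target_id, required_length + 1, buffer, /*out_buffer_length=*/NULL);
  if (iree_status_is_ok(status)) {
    *out_str = buffer;
  } else {
    iree_allocator_free(host_allocator, buffer);
  }
  return status;
}
```
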
diff --git a/runtime/src/iree/hal/drivers/amdgpu/util/topology.c b/runtime/src/iree/hal/drivers/amdgpu/util/topology.c
index 89b4dcb..4d37029 100644
--- a/runtime/src/iree/hal/drivers/amdgpu/util/topology.c
+++ b/runtime/src/iree/hal/drivers/amdgpu/util/topology.c
@@ -6,10 +6,119 @@
 
 #include "iree/hal/drivers/amdgpu/util/topology.h"
 
+#include "iree/hal/api.h"
+
 //===----------------------------------------------------------------------===//
 // iree_hal_amdgpu_topology_t
 //===----------------------------------------------------------------------===//
 
+static bool iree_hal_amdgpu_agent_list_contains(const hsa_agent_t* agents,
+                                                iree_host_size_t agent_count,
+                                                hsa_agent_t agent) {
+  for (iree_host_size_t i = 0; i < agent_count; ++i) {
+    if (agents[i].handle == agent.handle) return true;
+  }
+  return false;
+}
+
+static iree_status_t iree_hal_amdgpu_topology_verify_storage_counts(
+    const iree_hal_amdgpu_topology_t* topology) {
+  if (topology->all_agent_count > IREE_ARRAYSIZE(topology->all_agents)) {
+    return iree_make_status(IREE_STATUS_INVALID_ARGUMENT,
+                            "all_agent_count=%" PRIhsz
+                            " exceeds topology storage capacity %" PRIhsz,
+                            topology->all_agent_count,
+                            IREE_ARRAYSIZE(topology->all_agents));
+  }
+  if (topology->cpu_agent_count > IREE_ARRAYSIZE(topology->cpu_agents)) {
+    return iree_make_status(IREE_STATUS_INVALID_ARGUMENT,
+                            "cpu_agent_count=%" PRIhsz
+                            " exceeds topology storage capacity %" PRIhsz,
+                            topology->cpu_agent_count,
+                            IREE_ARRAYSIZE(topology->cpu_agents));
+  }
+  if (topology->gpu_agent_count > IREE_ARRAYSIZE(topology->gpu_agents)) {
+    return iree_make_status(IREE_STATUS_INVALID_ARGUMENT,
+                            "gpu_agent_count=%" PRIhsz
+                            " exceeds topology storage capacity %" PRIhsz,
+                            topology->gpu_agent_count,
+                            IREE_ARRAYSIZE(topology->gpu_agents));
+  }
+  return iree_ok_status();
+}
+
+static iree_status_t iree_hal_amdgpu_topology_verify_unique_agent_list(
+    const char* list_name, const hsa_agent_t* agents,
+    iree_host_size_t agent_count) {
+  for (iree_host_size_t i = 0; i < agent_count; ++i) {
+    for (iree_host_size_t j = i + 1; j < agent_count; ++j) {
+      if (agents[i].handle == agents[j].handle) {
+        return iree_make_status(
+            IREE_STATUS_INVALID_ARGUMENT,
+            "topology %s contains duplicate agent handles at ordinals %" PRIhsz
+            " and %" PRIhsz,
+            list_name, i, j);
+      }
+    }
+  }
+  return iree_ok_status();
+}
+
+static iree_status_t iree_hal_amdgpu_topology_verify_agent_sets(
+    const iree_hal_amdgpu_topology_t* topology) {
+  IREE_RETURN_IF_ERROR(iree_hal_amdgpu_topology_verify_unique_agent_list(
+      "all_agents", topology->all_agents, topology->all_agent_count));
+  IREE_RETURN_IF_ERROR(iree_hal_amdgpu_topology_verify_unique_agent_list(
+      "cpu_agents", topology->cpu_agents, topology->cpu_agent_count));
+  IREE_RETURN_IF_ERROR(iree_hal_amdgpu_topology_verify_unique_agent_list(
+      "gpu_agents", topology->gpu_agents, topology->gpu_agent_count));
+
+  for (iree_host_size_t i = 0; i < topology->cpu_agent_count; ++i) {
+    hsa_agent_t cpu_agent = topology->cpu_agents[i];
+    if (iree_hal_amdgpu_agent_list_contains(
+            topology->gpu_agents, topology->gpu_agent_count, cpu_agent)) {
+      return iree_make_status(IREE_STATUS_INVALID_ARGUMENT,
+                              "topology CPU agent ordinal %" PRIhsz
+                              " is also present in gpu_agents",
+                              i);
+    }
+    if (!iree_hal_amdgpu_agent_list_contains(
+            topology->all_agents, topology->all_agent_count, cpu_agent)) {
+      return iree_make_status(IREE_STATUS_INVALID_ARGUMENT,
+                              "topology CPU agent ordinal %" PRIhsz
+                              " is missing from all_agents",
+                              i);
+    }
+  }
+
+  for (iree_host_size_t i = 0; i < topology->gpu_agent_count; ++i) {
+    hsa_agent_t gpu_agent = topology->gpu_agents[i];
+    if (!iree_hal_amdgpu_agent_list_contains(
+            topology->all_agents, topology->all_agent_count, gpu_agent)) {
+      return iree_make_status(IREE_STATUS_INVALID_ARGUMENT,
+                              "topology GPU agent ordinal %" PRIhsz
+                              " is missing from all_agents",
+                              i);
+    }
+  }
+
+  for (iree_host_size_t i = 0; i < topology->all_agent_count; ++i) {
+    hsa_agent_t agent = topology->all_agents[i];
+    const bool is_cpu = iree_hal_amdgpu_agent_list_contains(
+        topology->cpu_agents, topology->cpu_agent_count, agent);
+    const bool is_gpu = iree_hal_amdgpu_agent_list_contains(
+        topology->gpu_agents, topology->gpu_agent_count, agent);
+    if (!is_cpu && !is_gpu) {
+      return iree_make_status(IREE_STATUS_INVALID_ARGUMENT,
+                              "topology all_agents ordinal %" PRIhsz
+                              " is not present in cpu_agents or gpu_agents",
+                              i);
+    }
+  }
+
+  return iree_ok_status();
+}
+
 IREE_API_EXPORT void iree_hal_amdgpu_topology_initialize(
     iree_hal_amdgpu_topology_t* out_topology) {
   IREE_ASSERT_ARGUMENT(out_topology);
@@ -17,7 +126,8 @@
 
   memset(out_topology, 0, sizeof(*out_topology));
 
-  out_topology->gpu_agent_queue_count = 1;
+  out_topology->gpu_agent_queue_count =
+      IREE_HAL_AMDGPU_DEFAULT_GPU_AGENT_QUEUE_COUNT;
 
   IREE_TRACE_ZONE_END(z0);
 }
@@ -40,6 +150,9 @@
   IREE_ASSERT_ARGUMENT(libhsa);
   if (out_index) *out_index = 0;
 
+  IREE_RETURN_IF_ERROR(
+      iree_hal_amdgpu_topology_verify_storage_counts(topology));
+
   // Scan for the agent in the current topology.
   for (iree_host_size_t i = 0; i < topology->cpu_agent_count; ++i) {
     if (topology->cpu_agents[i].handle == cpu_agent.handle) {
@@ -49,11 +162,17 @@
   }
 
   // Check capacity before mutating the topology.
-  if (topology->cpu_agent_count + 1 >= IREE_ARRAYSIZE(topology->cpu_agents)) {
+  if (topology->cpu_agent_count >= IREE_ARRAYSIZE(topology->cpu_agents)) {
     return iree_make_status(IREE_STATUS_RESOURCE_EXHAUSTED,
                             "max CPU agent count reached (limit %" PRIhsz ")",
                             IREE_ARRAYSIZE(topology->cpu_agents));
   }
+  if (topology->all_agent_count >= IREE_ARRAYSIZE(topology->all_agents)) {
+    return iree_make_status(IREE_STATUS_RESOURCE_EXHAUSTED,
+                            "max topology agent count reached (limit %" PRIhsz
+                            ")",
+                            IREE_ARRAYSIZE(topology->all_agents));
+  }
 
   // Verify the agent is a supported CPU agent.
   hsa_device_type_t device_type = 0;
@@ -82,19 +201,32 @@
   IREE_ASSERT_ARGUMENT(topology);
   IREE_ASSERT_ARGUMENT(libhsa);
 
+  IREE_RETURN_IF_ERROR(
+      iree_hal_amdgpu_topology_verify_storage_counts(topology));
+
   // Ignore if the GPU agent has already been added.
   for (iree_host_size_t i = 0; i < topology->gpu_agent_count; ++i) {
     if (topology->gpu_agents[i].handle == gpu_agent.handle) {
-      return iree_ok_status();  // already present
+      return iree_ok_status();
     }
   }
 
   // Check capacity before mutating the topology.
-  if (topology->gpu_agent_count + 1 >= IREE_ARRAYSIZE(topology->gpu_agents)) {
+  if (topology->gpu_agent_count >= IREE_ARRAYSIZE(topology->gpu_agents)) {
     return iree_make_status(IREE_STATUS_RESOURCE_EXHAUSTED,
                             "max GPU agent count reached (limit %" PRIhsz ")",
                             IREE_ARRAYSIZE(topology->gpu_agents));
   }
+  const bool cpu_agent_present = iree_hal_amdgpu_agent_list_contains(
+      topology->cpu_agents, topology->cpu_agent_count, cpu_agent);
+  const iree_host_size_t required_all_agent_slots = cpu_agent_present ? 1 : 2;
+  if (topology->all_agent_count >
+      IREE_ARRAYSIZE(topology->all_agents) - required_all_agent_slots) {
+    return iree_make_status(IREE_STATUS_RESOURCE_EXHAUSTED,
+                            "max topology agent count reached (limit %" PRIhsz
+                            ")",
+                            IREE_ARRAYSIZE(topology->all_agents));
+  }
 
   // Verify the agent is a supported GPU agent.
   hsa_device_type_t device_type = 0;
@@ -258,6 +390,9 @@
   IREE_ASSERT_ARGUMENT(topology);
   IREE_ASSERT_ARGUMENT(libhsa);
 
+  IREE_RETURN_IF_ERROR(
+      iree_hal_amdgpu_topology_verify_storage_counts(topology));
+
   // Must have at least one of each agent type in the topology.
   // This is just a guard for creating systems that don't have any GPUs so that
   // code in the implementation can assume that there's always _something_ to
@@ -273,6 +408,49 @@
         topology->gpu_agent_queue_count);
   }
 
+  const iree_host_size_t expected_all_agent_count =
+      topology->cpu_agent_count + topology->gpu_agent_count;
+  if (topology->all_agent_count != expected_all_agent_count) {
+    return iree_make_status(IREE_STATUS_INVALID_ARGUMENT,
+                            "topology all_agent_count=%" PRIhsz
+                            " does not match cpu_agent_count + "
+                            "gpu_agent_count (%" PRIhsz ")",
+                            topology->all_agent_count,
+                            expected_all_agent_count);
+  }
+
+  if (topology->gpu_agent_queue_count > UINT8_MAX) {
+    return iree_make_status(IREE_STATUS_OUT_OF_RANGE,
+                            "gpu_agent_queue_count=%" PRIhsz
+                            " exceeds the queue-axis encoding limit (%u)",
+                            topology->gpu_agent_queue_count, UINT8_MAX);
+  }
+  iree_host_size_t total_queue_count = 0;
+  if (!iree_host_size_checked_mul(topology->gpu_agent_count,
+                                  topology->gpu_agent_queue_count,
+                                  &total_queue_count) ||
+      total_queue_count > IREE_HAL_MAX_QUEUES) {
+    return iree_make_status(
+        IREE_STATUS_OUT_OF_RANGE,
+        "topology queue space does not fit in iree_hal_queue_affinity_t "
+        "(gpu_agent_count=%" PRIhsz ", gpu_agent_queue_count=%" PRIhsz
+        ", max_total_queues=%" PRIhsz ")",
+        topology->gpu_agent_count, topology->gpu_agent_queue_count,
+        (iree_host_size_t)IREE_HAL_MAX_QUEUES);
+  }
+
+  for (iree_host_size_t i = 0; i < topology->gpu_agent_count; ++i) {
+    if (topology->gpu_cpu_map[i] >= topology->cpu_agent_count) {
+      return iree_make_status(IREE_STATUS_INVALID_ARGUMENT,
+                              "topology gpu_cpu_map[%" PRIhsz
+                              "]=%u exceeds cpu_agent_count=%" PRIhsz,
+                              i, topology->gpu_cpu_map[i],
+                              topology->cpu_agent_count);
+    }
+  }
+
+  IREE_RETURN_IF_ERROR(iree_hal_amdgpu_topology_verify_agent_sets(topology));
+
   // Ensure all CPU agents are compatible with each other.
   for (iree_host_size_t i = 1; i < topology->cpu_agent_count; ++i) {
     bool are_compatible = false;
@@ -555,13 +733,25 @@
   IREE_RETURN_AND_END_ZONE_IF_ERROR(
       z0, iree_hal_amdgpu_query_available_agents(libhsa, &agents));
 
+  const iree_hal_amdgpu_gpu_agent_mask_t valid_gpu_agent_mask =
+      agents.gpu_agent_count >= 64 ? UINT64_MAX
+                                   : ((1ull << agents.gpu_agent_count) - 1ull);
+  if ((gpu_agent_mask & ~valid_gpu_agent_mask) != 0) {
+    IREE_RETURN_AND_END_ZONE_IF_ERROR(
+        z0, iree_make_status(IREE_STATUS_OUT_OF_RANGE,
+                             "GPU agent mask 0x%016" PRIx64
+                             " selects unavailable GPU ordinals; only %" PRIhsz
+                             " visible GPU agent(s) are available",
+                             gpu_agent_mask, agents.gpu_agent_count));
+  }
+
   // Initialize an empty topology.
   iree_hal_amdgpu_topology_initialize(out_topology);
 
   // Add each device to the topology.
   iree_status_t status = iree_ok_status();
-  for (iree_host_size_t gpu_ordinal = 0;
-       gpu_ordinal < IREE_ARRAYSIZE(agents.gpu_agents); ++gpu_ordinal) {
+  for (iree_host_size_t gpu_ordinal = 0; gpu_ordinal < agents.gpu_agent_count;
+       ++gpu_ordinal) {
     if ((gpu_agent_mask & (1ull << gpu_ordinal)) == 0) continue;
     status = iree_hal_amdgpu_topology_insert_gpu_agent_with_nearest_cpu_agent(
         out_topology, libhsa, agents.gpu_agents[gpu_ordinal]);
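
Mask semantics for the new validation above, as a usage sketch: bit i selects visible GPU ordinal i, and bits beyond the visible count now fail fast with OUT_OF_RANGE instead of being silently ignored. The wrapper function is hypothetical; only the topology calls are real:

```c
#include "iree/hal/drivers/amdgpu/util/topology.h"

// Hypothetical helper: builds a topology from visible GPU ordinals 0 and 2
// (mask 0b101) and verifies it before use.
static iree_status_t select_gpus_0_and_2(
    const iree_hal_amdgpu_libhsa_t* libhsa,
    iree_hal_amdgpu_topology_t* out_topology) {
  IREE_RETURN_IF_ERROR(iree_hal_amdgpu_topology_initialize_from_gpu_agent_mask(
      libhsa, (1ull << 0) | (1ull << 2), out_topology));
  iree_status_t status = iree_hal_amdgpu_topology_verify(out_topology, libhsa);
  if (!iree_status_is_ok(status)) {
    iree_hal_amdgpu_topology_deinitialize(out_topology);
  }
  return status;
}
```
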
diff --git a/runtime/src/iree/hal/drivers/amdgpu/util/topology.h b/runtime/src/iree/hal/drivers/amdgpu/util/topology.h
index 5a07f23..870ebef 100644
--- a/runtime/src/iree/hal/drivers/amdgpu/util/topology.h
+++ b/runtime/src/iree/hal/drivers/amdgpu/util/topology.h
@@ -21,6 +21,9 @@
 #define IREE_HAL_AMDGPU_MAX_CPU_AGENT 64
 #define IREE_HAL_AMDGPU_MAX_GPU_AGENT 64
 
+// Default number of logical HAL queues exposed per GPU agent.
+#define IREE_HAL_AMDGPU_DEFAULT_GPU_AGENT_QUEUE_COUNT (2)
+
 // Defines a system topology specifying which agents are to be used by the HAL.
 //
 // Today many internal structures assume at most 64 CPU agents and 64 GPU
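
A worked instance of the queue-space rule that verification now enforces, assuming IREE_HAL_MAX_QUEUES is the bit width of iree_hal_queue_affinity_t (64 for a 64-bit mask). The helper below is a hypothetical mirror of the check, not a replacement for it:

```c
#include "iree/base/api.h"
#include "iree/hal/api.h"

// Hypothetical mirror of the verification rule (the real check also uses a
// checked multiply to guard against iree_host_size_t overflow).
static bool topology_queue_space_fits(iree_host_size_t gpu_agent_count,
                                      iree_host_size_t queues_per_agent) {
  return queues_per_agent <= UINT8_MAX &&
         gpu_agent_count * queues_per_agent <= IREE_HAL_MAX_QUEUES;
}

// With the default of 2 queues per agent and a 64-bit affinity mask:
//   topology_queue_space_fits(8, 2)  -> true  (16 logical queues)
//   topology_queue_space_fits(40, 2) -> false (80 > 64)
```
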
diff --git a/runtime/src/iree/hal/drivers/amdgpu/util/topology_test.cc b/runtime/src/iree/hal/drivers/amdgpu/util/topology_test.cc
index b6dcd8b..857de54 100644
--- a/runtime/src/iree/hal/drivers/amdgpu/util/topology_test.cc
+++ b/runtime/src/iree/hal/drivers/amdgpu/util/topology_test.cc
@@ -6,7 +6,10 @@
 
 #include "iree/hal/drivers/amdgpu/util/topology.h"
 
+#include <string>
+
 #include "iree/base/api.h"
+#include "iree/hal/api.h"
 #include "iree/hal/drivers/amdgpu/util/libhsa.h"
 #include "iree/testing/gtest.h"
 #include "iree/testing/status_matchers.h"
@@ -14,7 +17,107 @@
 namespace iree::hal::amdgpu {
 namespace {
 
-using iree::testing::status::StatusIs;
+static hsa_agent_t MakeFakeAgent(uint64_t handle) {
+  hsa_agent_t agent;
+  agent.handle = handle;
+  return agent;
+}
+
+static const iree_hal_amdgpu_libhsa_t* FakeLibHsa() {
+  static const iree_hal_amdgpu_libhsa_t libhsa = {};
+  return &libhsa;
+}
+
+static iree_hal_amdgpu_topology_t MakeStructurallyValidTopology() {
+  iree_hal_amdgpu_topology_t topology;
+  iree_hal_amdgpu_topology_initialize(&topology);
+  topology.gpu_agent_queue_count = 1;
+  topology.cpu_agent_count = 1;
+  topology.cpu_agents[0] = MakeFakeAgent(1);
+  topology.gpu_agent_count = 1;
+  topology.gpu_agents[0] = MakeFakeAgent(2);
+  topology.all_agent_count = 2;
+  topology.all_agents[0] = topology.cpu_agents[0];
+  topology.all_agents[1] = topology.gpu_agents[0];
+  topology.gpu_cpu_map[0] = 0;
+  return topology;
+}
+
+static void ExpectTopologyHasTwoGpus(
+    const iree_hal_amdgpu_topology_t& topology) {
+  EXPECT_GE(topology.all_agent_count, 3);
+  EXPECT_GE(topology.cpu_agent_count, 1);
+  ASSERT_EQ(topology.gpu_agent_count, 2);
+  EXPECT_GE(topology.gpu_agent_queue_count, 1);
+  for (iree_host_size_t i = 0; i < topology.gpu_agent_count; ++i) {
+    EXPECT_LT(topology.gpu_cpu_map[i], topology.cpu_agent_count);
+  }
+}
+
+static iree_status_t AppendAgentUuidPathFragment(
+    const iree_hal_amdgpu_libhsa_t* libhsa, hsa_agent_t agent,
+    std::string* path) {
+  char agent_uuid[32] = {0};
+  IREE_RETURN_IF_ERROR(iree_hsa_agent_get_info(
+      IREE_LIBHSA(libhsa), agent, (hsa_agent_info_t)HSA_AMD_AGENT_INFO_UUID,
+      agent_uuid));
+  if (!path->empty()) path->append(",");
+  path->append(agent_uuid);
+  return iree_ok_status();
+}
+
+TEST(TopologyStructureTest, VerifyAcceptsStructurallyValidTopology) {
+  iree_hal_amdgpu_topology_t topology = MakeStructurallyValidTopology();
+  IREE_EXPECT_OK(iree_hal_amdgpu_topology_verify(&topology, FakeLibHsa()));
+}
+
+TEST(TopologyStructureTest, VerifyRejectsStorageCountBeyondCapacity) {
+  iree_hal_amdgpu_topology_t topology = MakeStructurallyValidTopology();
+  topology.cpu_agent_count = IREE_HAL_AMDGPU_MAX_CPU_AGENT + 1;
+  IREE_EXPECT_STATUS_IS(
+      IREE_STATUS_INVALID_ARGUMENT,
+      iree_hal_amdgpu_topology_verify(&topology, FakeLibHsa()));
+}
+
+TEST(TopologyStructureTest, VerifyRejectsAllAgentCountMismatch) {
+  iree_hal_amdgpu_topology_t topology = MakeStructurallyValidTopology();
+  topology.all_agent_count = 1;
+  IREE_EXPECT_STATUS_IS(
+      IREE_STATUS_INVALID_ARGUMENT,
+      iree_hal_amdgpu_topology_verify(&topology, FakeLibHsa()));
+}
+
+TEST(TopologyStructureTest, VerifyRejectsQueueSpaceOverflow) {
+  iree_hal_amdgpu_topology_t topology = MakeStructurallyValidTopology();
+  topology.gpu_agent_queue_count = IREE_HAL_MAX_QUEUES + 1;
+  IREE_EXPECT_STATUS_IS(
+      IREE_STATUS_OUT_OF_RANGE,
+      iree_hal_amdgpu_topology_verify(&topology, FakeLibHsa()));
+}
+
+TEST(TopologyStructureTest, VerifyRejectsGpuCpuMapOutOfRange) {
+  iree_hal_amdgpu_topology_t topology = MakeStructurallyValidTopology();
+  topology.gpu_cpu_map[0] = 1;
+  IREE_EXPECT_STATUS_IS(
+      IREE_STATUS_INVALID_ARGUMENT,
+      iree_hal_amdgpu_topology_verify(&topology, FakeLibHsa()));
+}
+
+TEST(TopologyStructureTest, VerifyRejectsDuplicateAllAgents) {
+  iree_hal_amdgpu_topology_t topology = MakeStructurallyValidTopology();
+  topology.all_agents[1] = topology.all_agents[0];
+  IREE_EXPECT_STATUS_IS(
+      IREE_STATUS_INVALID_ARGUMENT,
+      iree_hal_amdgpu_topology_verify(&topology, FakeLibHsa()));
+}
+
+TEST(TopologyStructureTest, VerifyRejectsGpuMissingFromAllAgents) {
+  iree_hal_amdgpu_topology_t topology = MakeStructurallyValidTopology();
+  topology.all_agents[1] = MakeFakeAgent(3);
+  IREE_EXPECT_STATUS_IS(
+      IREE_STATUS_INVALID_ARGUMENT,
+      iree_hal_amdgpu_topology_verify(&topology, FakeLibHsa()));
+}
 
 struct TopologyTest : public ::testing::Test {
   static iree_allocator_t host_allocator;
@@ -28,7 +131,7 @@
         host_allocator, &libhsa);
     if (!iree_status_is_ok(status)) {
       iree_status_fprint(stderr, status);
-      iree_status_ignore(status);
+      iree_status_free(status);
       GTEST_SKIP() << "HSA not available, skipping tests";
     }
   }
@@ -45,8 +148,77 @@
   iree_hal_amdgpu_topology_t topology;
   iree_hal_amdgpu_topology_initialize(&topology);
   // Need at least 1 CPU and GPU agent.
-  EXPECT_THAT(Status(iree_hal_amdgpu_topology_verify(&topology, &libhsa)),
-              StatusIs(StatusCode::kInvalidArgument));
+  IREE_EXPECT_STATUS_IS(IREE_STATUS_INVALID_ARGUMENT,
+                        iree_hal_amdgpu_topology_verify(&topology, &libhsa));
+  iree_hal_amdgpu_topology_deinitialize(&topology);
+}
+
+TEST_F(TopologyTest, InsertCpuAgentAllowsLastSlot) {
+  iree_hal_amdgpu_topology_t defaults;
+  IREE_ASSERT_OK(
+      iree_hal_amdgpu_topology_initialize_with_defaults(&libhsa, &defaults));
+  if (defaults.cpu_agent_count == 0) {
+    iree_hal_amdgpu_topology_deinitialize(&defaults);
+    GTEST_SKIP() << "no CPU agents found";
+    return;
+  }
+  hsa_agent_t cpu_agent = defaults.cpu_agents[0];
+  iree_hal_amdgpu_topology_deinitialize(&defaults);
+
+  iree_hal_amdgpu_topology_t topology;
+  iree_hal_amdgpu_topology_initialize(&topology);
+  topology.cpu_agent_count = IREE_HAL_AMDGPU_MAX_CPU_AGENT - 1;
+  topology.all_agent_count = topology.cpu_agent_count;
+  for (iree_host_size_t i = 0; i < topology.cpu_agent_count; ++i) {
+    hsa_agent_t fake_agent = MakeFakeAgent(0xCAFE000000000000ull + i);
+    topology.cpu_agents[i] = fake_agent;
+    topology.all_agents[i] = fake_agent;
+  }
+
+  iree_host_size_t index = 0;
+  IREE_ASSERT_OK(iree_hal_amdgpu_topology_insert_cpu_agent(&topology, &libhsa,
+                                                           cpu_agent, &index));
+  EXPECT_EQ(index, IREE_HAL_AMDGPU_MAX_CPU_AGENT - 1);
+  EXPECT_EQ(topology.cpu_agent_count, IREE_HAL_AMDGPU_MAX_CPU_AGENT);
+  EXPECT_EQ(topology.all_agent_count, IREE_HAL_AMDGPU_MAX_CPU_AGENT);
+  EXPECT_EQ(topology.cpu_agents[index].handle, cpu_agent.handle);
+  iree_hal_amdgpu_topology_deinitialize(&topology);
+}
+
+TEST_F(TopologyTest, InsertGpuAgentAllowsLastSlot) {
+  iree_hal_amdgpu_topology_t defaults;
+  IREE_ASSERT_OK(
+      iree_hal_amdgpu_topology_initialize_with_defaults(&libhsa, &defaults));
+  if (defaults.gpu_agent_count == 0 || defaults.cpu_agent_count == 0) {
+    iree_hal_amdgpu_topology_deinitialize(&defaults);
+    GTEST_SKIP() << "no GPU agents found";
+    return;
+  }
+  hsa_agent_t cpu_agent = defaults.cpu_agents[0];
+  hsa_agent_t gpu_agent = defaults.gpu_agents[0];
+  iree_hal_amdgpu_topology_deinitialize(&defaults);
+
+  iree_hal_amdgpu_topology_t topology;
+  iree_hal_amdgpu_topology_initialize(&topology);
+  topology.cpu_agent_count = 1;
+  topology.cpu_agents[0] = cpu_agent;
+  topology.all_agent_count = 1;
+  topology.all_agents[0] = cpu_agent;
+  topology.gpu_agent_count = IREE_HAL_AMDGPU_MAX_GPU_AGENT - 1;
+  for (iree_host_size_t i = 0; i < topology.gpu_agent_count; ++i) {
+    hsa_agent_t fake_agent = MakeFakeAgent(0xC0DE000000000000ull + i);
+    topology.gpu_agents[i] = fake_agent;
+    topology.gpu_cpu_map[i] = 0;
+    topology.all_agents[topology.all_agent_count++] = fake_agent;
+  }
+
+  IREE_ASSERT_OK(iree_hal_amdgpu_topology_insert_gpu_agent(
+      &topology, &libhsa, gpu_agent, cpu_agent));
+  EXPECT_EQ(topology.gpu_agent_count, IREE_HAL_AMDGPU_MAX_GPU_AGENT);
+  EXPECT_EQ(topology.all_agent_count, IREE_HAL_AMDGPU_MAX_GPU_AGENT + 1);
+  EXPECT_EQ(topology.gpu_agents[IREE_HAL_AMDGPU_MAX_GPU_AGENT - 1].handle,
+            gpu_agent.handle);
+  EXPECT_EQ(topology.gpu_cpu_map[IREE_HAL_AMDGPU_MAX_GPU_AGENT - 1], 0);
   iree_hal_amdgpu_topology_deinitialize(&topology);
 }
 
@@ -93,19 +265,60 @@
   IREE_ASSERT_OK(iree_hal_amdgpu_topology_initialize_from_path(
       &libhsa, IREE_SV("0"), &topology));
   if (topology.gpu_agent_count == 0) {
-    // This could be ignoring an error, but it usually just indicates no agents
-    // on the machine.
     GTEST_SKIP() << "no GPU agents found";
     return;
   }
   EXPECT_EQ(topology.all_agent_count, 2);
   EXPECT_EQ(topology.cpu_agent_count, 1);
   EXPECT_EQ(topology.gpu_agent_count, 1);
-  EXPECT_EQ(topology.gpu_agent_queue_count, 1);
+  EXPECT_EQ(topology.gpu_agent_queue_count,
+            IREE_HAL_AMDGPU_DEFAULT_GPU_AGENT_QUEUE_COUNT);
   EXPECT_EQ(topology.gpu_cpu_map[0], 0);
   iree_hal_amdgpu_topology_deinitialize(&topology);
 }
 
+TEST_F(TopologyTest, InitializeFromPathTwoOrdinals) {
+  iree_hal_amdgpu_topology_t topology;
+  iree_status_t status = iree_hal_amdgpu_topology_initialize_from_path(
+      &libhsa, IREE_SV("0,1"), &topology);
+  if (!iree_status_is_ok(status)) {
+    iree_status_code_t status_code = iree_status_code(status);
+    if (status_code == IREE_STATUS_INVALID_ARGUMENT) {
+      iree_status_free(status);
+      GTEST_SKIP() << "fewer than two visible GPU agents";
+      return;
+    }
+    IREE_ASSERT_OK(status);
+  }
+  ExpectTopologyHasTwoGpus(topology);
+  iree_hal_amdgpu_topology_deinitialize(&topology);
+}
+
+TEST_F(TopologyTest, InitializeFromPathTwoDefaultGpuUuidsVerifies) {
+  iree_hal_amdgpu_topology_t defaults;
+  IREE_ASSERT_OK(
+      iree_hal_amdgpu_topology_initialize_with_defaults(&libhsa, &defaults));
+  if (defaults.gpu_agent_count < 2) {
+    iree_hal_amdgpu_topology_deinitialize(&defaults);
+    GTEST_SKIP() << "fewer than two compatible GPU agents";
+    return;
+  }
+
+  std::string path;
+  IREE_ASSERT_OK(
+      AppendAgentUuidPathFragment(&libhsa, defaults.gpu_agents[0], &path));
+  IREE_ASSERT_OK(
+      AppendAgentUuidPathFragment(&libhsa, defaults.gpu_agents[1], &path));
+
+  iree_hal_amdgpu_topology_t topology;
+  IREE_ASSERT_OK(iree_hal_amdgpu_topology_initialize_from_path(
+      &libhsa, iree_make_cstring_view(path.c_str()), &topology));
+  ExpectTopologyHasTwoGpus(topology);
+  IREE_EXPECT_OK(iree_hal_amdgpu_topology_verify(&topology, &libhsa));
+  iree_hal_amdgpu_topology_deinitialize(&topology);
+  iree_hal_amdgpu_topology_deinitialize(&defaults);
+}
+
 // Tests that initialize_from_gpu_agent_mask with a 0 mask is the same as
 // initializing from defaults.
 TEST_F(TopologyTest, InitializeFromGPUAgentMask0) {
@@ -113,8 +326,6 @@
   IREE_ASSERT_OK(iree_hal_amdgpu_topology_initialize_from_gpu_agent_mask(
       &libhsa, 0ull, &topology));
   if (topology.gpu_agent_count == 0) {
-    // This could be ignoring an error, but it usually just indicates no agents
-    // on the machine.
     GTEST_SKIP() << "no GPU agents found";
     return;
   }
@@ -131,18 +342,35 @@
   IREE_ASSERT_OK(iree_hal_amdgpu_topology_initialize_from_gpu_agent_mask(
       &libhsa, 1ull << 0, &topology));
   if (topology.gpu_agent_count == 0) {
-    // This could be ignoring an error, but it usually just indicates no agents
-    // on the machine.
     GTEST_SKIP() << "no GPU agents found";
     return;
   }
   EXPECT_EQ(topology.all_agent_count, 2);
   EXPECT_EQ(topology.cpu_agent_count, 1);
   EXPECT_EQ(topology.gpu_agent_count, 1);
-  EXPECT_EQ(topology.gpu_agent_queue_count, 1);
+  EXPECT_EQ(topology.gpu_agent_queue_count,
+            IREE_HAL_AMDGPU_DEFAULT_GPU_AGENT_QUEUE_COUNT);
   EXPECT_EQ(topology.gpu_cpu_map[0], 0);
   iree_hal_amdgpu_topology_deinitialize(&topology);
 }
 
+TEST_F(TopologyTest, InitializeFromGPUAgentMaskTwoDevices) {
+  iree_hal_amdgpu_topology_t topology;
+  iree_status_t status =
+      iree_hal_amdgpu_topology_initialize_from_gpu_agent_mask(&libhsa, 0x3ull,
+                                                              &topology);
+  if (!iree_status_is_ok(status)) {
+    iree_status_code_t status_code = iree_status_code(status);
+    if (status_code == IREE_STATUS_OUT_OF_RANGE) {
+      iree_status_free(status);
+      GTEST_SKIP() << "fewer than two visible GPU agents";
+      return;
+    }
+    IREE_ASSERT_OK(status);
+  }
+  ExpectTopologyHasTwoGpus(topology);
+  iree_hal_amdgpu_topology_deinitialize(&topology);
+}
+
 }  // namespace
 }  // namespace iree::hal::amdgpu
diff --git a/runtime/src/iree/hal/drivers/amdgpu/util/vmem.c b/runtime/src/iree/hal/drivers/amdgpu/util/vmem.c
index 57d2e38..256df72 100644
--- a/runtime/src/iree/hal/drivers/amdgpu/util/vmem.c
+++ b/runtime/src/iree/hal/drivers/amdgpu/util/vmem.c
@@ -24,24 +24,26 @@
 
   // Filter to the global segment only.
   hsa_region_segment_t segment = 0;
-  IREE_IGNORE_ERROR(iree_hsa_amd_memory_pool_get_info(
-      IREE_LIBHSA(state->libhsa), memory_pool, HSA_AMD_MEMORY_POOL_INFO_SEGMENT,
-      &segment));
+  hsa_status_t hsa_status = iree_hsa_amd_memory_pool_get_info_raw(
+      state->libhsa, memory_pool, HSA_AMD_MEMORY_POOL_INFO_SEGMENT, &segment);
+  if (hsa_status != HSA_STATUS_SUCCESS) return hsa_status;
   if (segment != HSA_REGION_SEGMENT_GLOBAL) return HSA_STATUS_SUCCESS;
 
   // Must be able to allocate. This should be true for any pool we query that
   // matches the other flags. Workgroup-private pools won't have this set.
   bool alloc_allowed = false;
-  IREE_IGNORE_ERROR(iree_hsa_amd_memory_pool_get_info(
-      IREE_LIBHSA(state->libhsa), memory_pool,
-      HSA_AMD_MEMORY_POOL_INFO_RUNTIME_ALLOC_ALLOWED, &alloc_allowed));
+  hsa_status = iree_hsa_amd_memory_pool_get_info_raw(
+      state->libhsa, memory_pool,
+      HSA_AMD_MEMORY_POOL_INFO_RUNTIME_ALLOC_ALLOWED, &alloc_allowed);
+  if (hsa_status != HSA_STATUS_SUCCESS) return hsa_status;
   if (!alloc_allowed) return HSA_STATUS_SUCCESS;
 
   // Match if flags are present.
   hsa_region_global_flag_t global_flag = 0;
-  IREE_IGNORE_ERROR(iree_hsa_amd_memory_pool_get_info(
-      IREE_LIBHSA(state->libhsa), memory_pool,
-      HSA_AMD_MEMORY_POOL_INFO_GLOBAL_FLAGS, &global_flag));
+  hsa_status = iree_hsa_amd_memory_pool_get_info_raw(
+      state->libhsa, memory_pool, HSA_AMD_MEMORY_POOL_INFO_GLOBAL_FLAGS,
+      &global_flag);
+  if (hsa_status != HSA_STATUS_SUCCESS) return hsa_status;
   if (global_flag & state->match_flags) {
     state->best_pool = memory_pool;
     return HSA_STATUS_INFO_BREAK;
@@ -97,14 +99,67 @@
       out_pool);
 }
 
+bool iree_hal_amdgpu_try_find_coarse_global_memory_pool(
+    const iree_hal_amdgpu_libhsa_t* libhsa, hsa_agent_t agent,
+    hsa_amd_memory_pool_t* out_pool) {
+  memset(out_pool, 0, sizeof(*out_pool));
+  iree_hal_amdgpu_find_global_memory_pool_state_t find_state = {
+      .libhsa = libhsa,
+      .match_flags = HSA_AMD_MEMORY_POOL_GLOBAL_FLAG_COARSE_GRAINED,
+      .best_pool = {0},
+  };
+  (void)iree_hsa_amd_agent_iterate_memory_pools_raw(
+      libhsa, agent, iree_hal_amdgpu_find_global_memory_pool_iterator,
+      &find_state);
+  *out_pool = find_state.best_pool;
+  return find_state.best_pool.handle != 0;
+}
+
+bool iree_hal_amdgpu_try_find_fine_global_memory_pool(
+    const iree_hal_amdgpu_libhsa_t* libhsa, hsa_agent_t agent,
+    hsa_amd_memory_pool_t* out_pool) {
+  memset(out_pool, 0, sizeof(*out_pool));
+  iree_hal_amdgpu_find_global_memory_pool_state_t find_state = {
+      .libhsa = libhsa,
+      .match_flags =
+          HSA_AMD_MEMORY_POOL_GLOBAL_FLAG_FINE_GRAINED |
+          HSA_AMD_MEMORY_POOL_GLOBAL_FLAG_EXTENDED_SCOPE_FINE_GRAINED,
+      .best_pool = {0},
+  };
+  (void)iree_hsa_amd_agent_iterate_memory_pools_raw(
+      libhsa, agent, iree_hal_amdgpu_find_global_memory_pool_iterator,
+      &find_state);
+  *out_pool = find_state.best_pool;
+  return find_state.best_pool.handle != 0;
+}
+
 //===----------------------------------------------------------------------===//
 // iree_hal_amdgpu_vmem_ringbuffer_t
 //===----------------------------------------------------------------------===//
 
+static iree_status_t iree_hal_amdgpu_vmem_translate_memory_type(
+    iree_hal_amdgpu_vmem_memory_type_t memory_type,
+    hsa_amd_memory_type_t* out_hsa_memory_type) {
+  IREE_ASSERT_ARGUMENT(out_hsa_memory_type);
+  switch (memory_type) {
+    case IREE_HAL_AMDGPU_VMEM_MEMORY_TYPE_DEFAULT:
+      *out_hsa_memory_type = MEMORY_TYPE_NONE;
+      return iree_ok_status();
+    case IREE_HAL_AMDGPU_VMEM_MEMORY_TYPE_PINNED_HOST:
+      *out_hsa_memory_type = MEMORY_TYPE_PINNED;
+      return iree_ok_status();
+    default:
+      return iree_make_status(IREE_STATUS_INVALID_ARGUMENT,
+                              "unsupported vmem memory type: %d",
+                              (int)memory_type);
+  }
+}
+
 iree_status_t iree_hal_amdgpu_vmem_ringbuffer_initialize(
     const iree_hal_amdgpu_libhsa_t* libhsa, hsa_agent_t local_agent,
-    hsa_amd_memory_pool_t memory_pool, iree_device_size_t min_capacity,
-    iree_host_size_t access_desc_count,
+    hsa_amd_memory_pool_t memory_pool,
+    iree_hal_amdgpu_vmem_memory_type_t memory_type,
+    iree_device_size_t min_capacity, iree_host_size_t access_desc_count,
     const hsa_amd_memory_access_desc_t* access_descs,
     iree_hal_amdgpu_vmem_ringbuffer_t* out_ringbuffer) {
   IREE_ASSERT_ARGUMENT(libhsa);
@@ -113,6 +168,11 @@
   IREE_TRACE_ZONE_APPEND_VALUE_I64(z0, min_capacity);
   memset(out_ringbuffer, 0, sizeof(*out_ringbuffer));
 
+  hsa_amd_memory_type_t hsa_memory_type = MEMORY_TYPE_NONE;
+  IREE_RETURN_AND_END_ZONE_IF_ERROR(
+      z0, iree_hal_amdgpu_vmem_translate_memory_type(memory_type,
+                                                     &hsa_memory_type));
+
   // hsa_amd_vmem_handle_create wants values aligned to this value.
   size_t alloc_rec_granule = 0;
   IREE_RETURN_AND_END_ZONE_IF_ERROR(
@@ -140,8 +200,8 @@
 
   // Allocate the physical memory for backing the ringbuffer.
   iree_status_t status = iree_hsa_amd_vmem_handle_create(
-      IREE_LIBHSA(libhsa), memory_pool, capacity, MEMORY_TYPE_NONE,
-      /*flags=*/0, &out_ringbuffer->alloc_handle);
+      IREE_LIBHSA(libhsa), memory_pool, capacity, hsa_memory_type, /*flags=*/0,
+      &out_ringbuffer->alloc_handle);
 
   void* va_offsets[3] = {
       (uint8_t*)out_ringbuffer->va_base_ptr + 0 * capacity,
@@ -175,8 +235,9 @@
 
 iree_status_t iree_hal_amdgpu_vmem_ringbuffer_initialize_with_topology(
     const iree_hal_amdgpu_libhsa_t* libhsa, hsa_agent_t local_agent,
-    hsa_amd_memory_pool_t memory_pool, iree_device_size_t min_capacity,
-    const iree_hal_amdgpu_topology_t* topology,
+    hsa_amd_memory_pool_t memory_pool,
+    iree_hal_amdgpu_vmem_memory_type_t memory_type,
+    iree_device_size_t min_capacity, const iree_hal_amdgpu_topology_t* topology,
     iree_hal_amdgpu_vmem_access_mode_t access_mode,
     iree_hal_amdgpu_vmem_ringbuffer_t* out_ringbuffer) {
   IREE_ASSERT_ARGUMENT(libhsa);
@@ -252,8 +313,8 @@
 
   // Route to the explicit initializer.
   iree_status_t status = iree_hal_amdgpu_vmem_ringbuffer_initialize(
-      libhsa, local_agent, memory_pool, min_capacity, access_desc_count,
-      access_descs, out_ringbuffer);
+      libhsa, local_agent, memory_pool, memory_type, min_capacity,
+      access_desc_count, access_descs, out_ringbuffer);
 
   IREE_TRACE_ZONE_END(z0);
   return status;
@@ -274,17 +335,17 @@
         (uint8_t*)ringbuffer->va_base_ptr + 2 * ringbuffer->capacity,
     };
     for (iree_host_size_t i = 0; i < IREE_ARRAYSIZE(va_offsets); ++i) {
-      IREE_IGNORE_ERROR(iree_hsa_amd_vmem_unmap(
-          IREE_LIBHSA(libhsa), va_offsets[i], ringbuffer->capacity));
+      iree_hal_amdgpu_hsa_cleanup_assert_success(iree_hsa_amd_vmem_unmap_raw(
+          libhsa, va_offsets[i], ringbuffer->capacity));
     }
-    IREE_IGNORE_ERROR(iree_hsa_amd_vmem_handle_release(
-        IREE_LIBHSA(libhsa), ringbuffer->alloc_handle));
+    iree_hal_amdgpu_hsa_cleanup_assert_success(
+        iree_hsa_amd_vmem_handle_release_raw(libhsa, ringbuffer->alloc_handle));
   }
 
   if (ringbuffer->va_base_ptr) {
-    IREE_IGNORE_ERROR(iree_hsa_amd_vmem_address_free(IREE_LIBHSA(libhsa),
-                                                     ringbuffer->va_base_ptr,
-                                                     ringbuffer->capacity * 3));
+    iree_hal_amdgpu_hsa_cleanup_assert_success(
+        iree_hsa_amd_vmem_address_free_raw(libhsa, ringbuffer->va_base_ptr,
+                                           ringbuffer->capacity * 3));
   }
 
   memset(ringbuffer, 0, sizeof(*ringbuffer));
diff --git a/runtime/src/iree/hal/drivers/amdgpu/util/vmem.h b/runtime/src/iree/hal/drivers/amdgpu/util/vmem.h
index b75a6c6..66782fb 100644
--- a/runtime/src/iree/hal/drivers/amdgpu/util/vmem.h
+++ b/runtime/src/iree/hal/drivers/amdgpu/util/vmem.h
@@ -37,6 +37,18 @@
   IREE_HAL_AMDGPU_ACCESS_MODE_EXCLUSIVE_PRODUCER,
 } iree_hal_amdgpu_vmem_access_mode_t;
 
+// Selects the HSA vmem allocation type for a ringbuffer's backing memory.
+//
+// AMD's HSA extension exposes this as hsa_amd_memory_type_t with bare
+// MEMORY_TYPE_* enumerants; this local enum keeps that upstream namespace leak
+// out of the rest of the driver.
+typedef enum iree_hal_amdgpu_vmem_memory_type_e {
+  // Default vmem allocation mode for device-local pools.
+  IREE_HAL_AMDGPU_VMEM_MEMORY_TYPE_DEFAULT = 0,
+  // Pinned host allocation mode for CPU memory pools.
+  IREE_HAL_AMDGPU_VMEM_MEMORY_TYPE_PINNED_HOST = 1,
+} iree_hal_amdgpu_vmem_memory_type_t;
+
 // Finds a global memory pool on the |agent| matching any of the specified
 // global flags.
 iree_status_t iree_hal_amdgpu_find_global_memory_pool(
@@ -59,6 +71,18 @@
     const iree_hal_amdgpu_libhsa_t* libhsa, hsa_agent_t agent,
     hsa_amd_memory_pool_t* out_pool);
 
+// Tries to find a coarse-grained global memory pool on the |agent|.
+// Returns true and populates |out_pool| if found; returns false otherwise.
+bool iree_hal_amdgpu_try_find_coarse_global_memory_pool(
+    const iree_hal_amdgpu_libhsa_t* libhsa, hsa_agent_t agent,
+    hsa_amd_memory_pool_t* out_pool);
+
+// Tries to find a fine-grained (including extended-scope fine-grained) global
+// memory pool on the |agent|. Returns true and populates |out_pool| if found.
+bool iree_hal_amdgpu_try_find_fine_global_memory_pool(
+    const iree_hal_amdgpu_libhsa_t* libhsa, hsa_agent_t agent,
+    hsa_amd_memory_pool_t* out_pool);
+
 //===----------------------------------------------------------------------===//
 // iree_hal_amdgpu_vmem_ringbuffer_t
 //===----------------------------------------------------------------------===//
@@ -102,11 +126,16 @@
 
 // Initializes a ringbuffer by allocating the physical and virtual memory of at
 // least the requested |min_capacity| with at least 64 byte alignment.
+// |memory_type| selects the HSA allocation mode for the selected pool; callers
+// allocating from host CPU pools should use
+// IREE_HAL_AMDGPU_VMEM_MEMORY_TYPE_PINNED_HOST while device-local pools
+// generally use IREE_HAL_AMDGPU_VMEM_MEMORY_TYPE_DEFAULT.
 // |access_descs| will be used to setup accessibility.
 iree_status_t iree_hal_amdgpu_vmem_ringbuffer_initialize(
     const iree_hal_amdgpu_libhsa_t* libhsa, hsa_agent_t local_agent,
-    hsa_amd_memory_pool_t memory_pool, iree_device_size_t min_capacity,
-    iree_host_size_t access_desc_count,
+    hsa_amd_memory_pool_t memory_pool,
+    iree_hal_amdgpu_vmem_memory_type_t memory_type,
+    iree_device_size_t min_capacity, iree_host_size_t access_desc_count,
     const hsa_amd_memory_access_desc_t* access_descs,
     iree_hal_amdgpu_vmem_ringbuffer_t* out_ringbuffer);
 
@@ -116,8 +145,9 @@
 // accessibility.
 iree_status_t iree_hal_amdgpu_vmem_ringbuffer_initialize_with_topology(
     const iree_hal_amdgpu_libhsa_t* libhsa, hsa_agent_t local_agent,
-    hsa_amd_memory_pool_t memory_pool, iree_device_size_t min_capacity,
-    const iree_hal_amdgpu_topology_t* topology,
+    hsa_amd_memory_pool_t memory_pool,
+    iree_hal_amdgpu_vmem_memory_type_t memory_type,
+    iree_device_size_t min_capacity, const iree_hal_amdgpu_topology_t* topology,
     iree_hal_amdgpu_vmem_access_mode_t access_mode,
     iree_hal_amdgpu_vmem_ringbuffer_t* out_ringbuffer);
 
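Aside on usage: with the new `memory_type` parameter, host CPU pools pair the pinned-host mode with a fine-grained pool. A minimal sketch, assuming `libhsa`, `cpu_agent`, and `topology` come from existing driver setup (the capacity is illustrative):

```c
// Sketch only: allocate a host-visible ringbuffer with the new memory type.
// libhsa/cpu_agent/topology are assumed to come from existing driver setup.
hsa_amd_memory_pool_t host_pool;
if (!iree_hal_amdgpu_try_find_fine_global_memory_pool(libhsa, cpu_agent,
                                                      &host_pool)) {
  return iree_make_status(IREE_STATUS_NOT_FOUND,
                          "no fine-grained host memory pool");
}
iree_hal_amdgpu_vmem_ringbuffer_t ringbuffer = {0};
IREE_RETURN_IF_ERROR(iree_hal_amdgpu_vmem_ringbuffer_initialize_with_topology(
    libhsa, cpu_agent, host_pool,
    IREE_HAL_AMDGPU_VMEM_MEMORY_TYPE_PINNED_HOST,
    /*min_capacity=*/64 * 1024, topology, IREE_HAL_AMDGPU_ACCESS_MODE_SHARED,
    &ringbuffer));
```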
diff --git a/runtime/src/iree/hal/drivers/amdgpu/util/vmem_test.cc b/runtime/src/iree/hal/drivers/amdgpu/util/vmem_test.cc
index 9d8d739..6cf2a4a 100644
--- a/runtime/src/iree/hal/drivers/amdgpu/util/vmem_test.cc
+++ b/runtime/src/iree/hal/drivers/amdgpu/util/vmem_test.cc
@@ -27,7 +27,7 @@
         host_allocator, &libhsa);
     if (!iree_status_is_ok(status)) {
       iree_status_fprint(stderr, status);
-      iree_status_ignore(status);
+      iree_status_free(status);
       GTEST_SKIP() << "HSA not available, skipping tests";
     }
     IREE_ASSERT_OK(
@@ -104,8 +104,9 @@
   const iree_device_size_t min_capacity = 1 * 1024 * 1024;
   iree_hal_amdgpu_vmem_ringbuffer_t ringbuffer = {0};
   IREE_ASSERT_OK(iree_hal_amdgpu_vmem_ringbuffer_initialize_with_topology(
-      &libhsa, gpu_agent, memory_pool, min_capacity, &topology,
-      IREE_HAL_AMDGPU_ACCESS_MODE_SHARED, &ringbuffer));
+      &libhsa, gpu_agent, memory_pool, IREE_HAL_AMDGPU_VMEM_MEMORY_TYPE_DEFAULT,
+      min_capacity, &topology, IREE_HAL_AMDGPU_ACCESS_MODE_SHARED,
+      &ringbuffer));
 
   EXPECT_GE(ringbuffer.capacity, min_capacity);
   EXPECT_EQ(ringbuffer.ring_base_ptr,
@@ -128,8 +129,9 @@
   const iree_device_size_t min_capacity = 1 * 1024 * 1024;
   iree_hal_amdgpu_vmem_ringbuffer_t ringbuffer = {0};
   IREE_ASSERT_OK(iree_hal_amdgpu_vmem_ringbuffer_initialize_with_topology(
-      &libhsa, gpu_agent, memory_pool, min_capacity, &topology,
-      IREE_HAL_AMDGPU_ACCESS_MODE_SHARED, &ringbuffer));
+      &libhsa, gpu_agent, memory_pool, IREE_HAL_AMDGPU_VMEM_MEMORY_TYPE_DEFAULT,
+      min_capacity, &topology, IREE_HAL_AMDGPU_ACCESS_MODE_SHARED,
+      &ringbuffer));
 
   // Fill entire range [0,capacity).
   iree_device_size_t capacity_u32 = ringbuffer.capacity / sizeof(uint32_t);
diff --git a/runtime/src/iree/hal/drivers/amdgpu/virtual_queue.c b/runtime/src/iree/hal/drivers/amdgpu/virtual_queue.c
deleted file mode 100644
index efc3893..0000000
--- a/runtime/src/iree/hal/drivers/amdgpu/virtual_queue.c
+++ /dev/null
@@ -1,130 +0,0 @@
-// Copyright 2025 The IREE Authors
-//
-// Licensed under the Apache License v2.0 with LLVM Exceptions.
-// See https://llvm.org/LICENSE.txt for license information.
-// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
-
-#include "iree/hal/drivers/amdgpu/virtual_queue.h"
-
-//===----------------------------------------------------------------------===//
-// iree_hal_amdgpu_queue_options_t
-//===----------------------------------------------------------------------===//
-
-iree_status_t iree_hal_amdgpu_queue_infer_placement(
-    const iree_hal_amdgpu_libhsa_t* libhsa, hsa_agent_t cpu_agent,
-    hsa_agent_t gpu_agent, iree_hal_amdgpu_queue_placement_t* out_placement) {
-  // TODO(benvanik): implement conditions:
-  // * PCIe Atomics
-  // * !PCIe Atomics && APU
-  // * !PCIe Atomics && gfx90a && xGMI
-  *out_placement = IREE_HAL_AMDGPU_QUEUE_PLACEMENT_HOST;
-  return iree_ok_status();
-}
-
-void iree_hal_amdgpu_queue_options_initialize(
-    iree_hal_amdgpu_queue_options_t* out_options) {
-  IREE_ASSERT_ARGUMENT(out_options);
-  memset(out_options, 0, sizeof(*out_options));
-  out_options->placement = IREE_HAL_AMDGPU_QUEUE_PLACEMENT_HOST;
-  out_options->flags = IREE_HAL_AMDGPU_QUEUE_FLAG_NONE;
-  out_options->mode = IREE_HAL_AMDGPU_QUEUE_SCHEDULING_MODE_DEFAULT;
-  out_options->control_queue_capacity =
-      IREE_HAL_AMDGPU_DEFAULT_CONTROL_QUEUE_CAPACITY;
-  out_options->execution_queue_count =
-      IREE_HAL_AMDGPU_DEFAULT_EXECUTION_QUEUE_COUNT;
-  out_options->execution_queue_capacity =
-      IREE_HAL_AMDGPU_DEFAULT_EXECUTION_QUEUE_CAPACITY;
-  out_options->kernarg_ringbuffer_capacity =
-      IREE_HAL_AMDGPU_DEFAULT_KERNARG_RINGBUFFER_CAPACITY;
-  out_options->trace_buffer_capacity =
-      IREE_HAL_AMDGPU_DEFAULT_TRACE_BUFFER_CAPACITY;
-}
-
-// Verifies that the given |queue_capacity| is between the agent min/max
-// requirements and a power-of-two.
-static iree_status_t iree_hal_amdgpu_verify_hsa_queue_size(
-    iree_string_view_t queue_name, iree_host_size_t queue_size,
-    uint32_t queue_min_size, uint32_t queue_max_size) {
-  // Queues must meet the min/max size requirements.
-  if (queue_size < queue_min_size || queue_size > queue_max_size) {
-    return iree_make_status(
-        IREE_STATUS_INVALID_ARGUMENT,
-        "%.*s queue capacity on this agent must be between "
-        "HSA_AGENT_INFO_QUEUE_MIN_SIZE=%u and HSA_AGENT_INFO_QUEUE_MAX_SIZE=%u "
-        "(provided %" PRIhsz ")",
-        (int)queue_name.size, queue_name.data, queue_min_size, queue_max_size,
-        queue_size);
-  }
-
-  // All queues must be a power-of-two due to ringbuffer masking.
-  if (!iree_host_size_is_power_of_two(queue_size)) {
-    return iree_make_status(
-        IREE_STATUS_INVALID_ARGUMENT,
-        "%.*s queue capacity must be a power of two (provided %" PRIhsz ")",
-        (int)queue_name.size, queue_name.data, queue_size);
-  }
-
-  return iree_ok_status();
-}
-
-iree_status_t iree_hal_amdgpu_queue_options_verify(
-    const iree_hal_amdgpu_queue_options_t* options,
-    const iree_hal_amdgpu_libhsa_t* libhsa, hsa_agent_t cpu_agent,
-    hsa_agent_t gpu_agent) {
-  IREE_ASSERT_ARGUMENT(options);
-  IREE_ASSERT_ARGUMENT(libhsa);
-
-  // If the queue is placed on the device it must support PCIe atomics or be
-  // connected via xGMI.
-  if (options->placement == IREE_HAL_AMDGPU_QUEUE_PLACEMENT_DEVICE) {
-    iree_hal_amdgpu_queue_placement_t possible_placement =
-        IREE_HAL_AMDGPU_QUEUE_PLACEMENT_ANY;
-    IREE_RETURN_IF_ERROR(iree_hal_amdgpu_queue_infer_placement(
-        libhsa, cpu_agent, gpu_agent, &possible_placement));
-    if (possible_placement != options->placement) {
-      return iree_make_status(
-          IREE_STATUS_INCOMPATIBLE,
-          "device-side queue placement requested but the device does not meet "
-          "the minimum requirements (PCIe atomics, xGMI, or APU)");
-    }
-  }
-
-  // Query agent min/max queue size.
-  uint32_t queue_min_size = 0;
-  IREE_RETURN_IF_ERROR(iree_hsa_agent_get_info(IREE_LIBHSA(libhsa), gpu_agent,
-                                               HSA_AGENT_INFO_QUEUE_MIN_SIZE,
-                                               &queue_min_size));
-  uint32_t queue_max_size = 0;
-  IREE_RETURN_IF_ERROR(iree_hsa_agent_get_info(IREE_LIBHSA(libhsa), gpu_agent,
-                                               HSA_AGENT_INFO_QUEUE_MAX_SIZE,
-                                               &queue_max_size));
-
-  // Verify HSA queues.
-  IREE_RETURN_IF_ERROR(iree_hal_amdgpu_verify_hsa_queue_size(
-      IREE_SV("control"), options->control_queue_capacity, queue_min_size,
-      queue_max_size));
-  IREE_RETURN_IF_ERROR(iree_hal_amdgpu_verify_hsa_queue_size(
-      IREE_SV("execution"), options->execution_queue_capacity, queue_min_size,
-      queue_max_size));
-
-  // Verify kernarg ringbuffer capacity (our ringbuffer so no HSA min/max
-  // required).
-  if (!iree_device_size_is_power_of_two(options->kernarg_ringbuffer_capacity)) {
-    return iree_make_status(
-        IREE_STATUS_INVALID_ARGUMENT,
-        "kernarg ringbuffer capacity must be a power of two (provided %" PRIdsz
-        ")",
-        options->kernarg_ringbuffer_capacity);
-  }
-
-  // Verify trace buffer capacity (our ringbuffer so no HSA min/max required).
-  if (options->trace_buffer_capacity &&
-      !iree_device_size_is_power_of_two(options->trace_buffer_capacity)) {
-    return iree_make_status(
-        IREE_STATUS_INVALID_ARGUMENT,
-        "trace buffer capacity must be a power of two (provided %" PRIdsz ")",
-        options->trace_buffer_capacity);
-  }
-
-  return iree_ok_status();
-}
diff --git a/runtime/src/iree/hal/drivers/amdgpu/virtual_queue.h b/runtime/src/iree/hal/drivers/amdgpu/virtual_queue.h
index 72faa89..9e07f70 100644
--- a/runtime/src/iree/hal/drivers/amdgpu/virtual_queue.h
+++ b/runtime/src/iree/hal/drivers/amdgpu/virtual_queue.h
@@ -174,9 +174,6 @@
       iree_hal_buffer_t* target_buffer, iree_device_size_t target_offset,
       iree_device_size_t length, iree_hal_copy_flags_t flags);
 
-  // NULL if not implemented and emulation should be used.
-  // TODO(benvanik): when all queue implementations support native I/O we should
-  // drop the emulation (it's bad).
   iree_status_t(IREE_API_PTR* read)(
       iree_hal_amdgpu_virtual_queue_t* queue,
       const iree_hal_semaphore_list_t wait_semaphore_list,
@@ -185,9 +182,6 @@
       iree_hal_buffer_t* target_buffer, iree_device_size_t target_offset,
       iree_device_size_t length, iree_hal_read_flags_t flags);
 
-  // NULL if not implemented and emulation should be used.
-  // TODO(benvanik): when all queue implementations support native I/O we should
-  // drop the emulation (it's bad).
   iree_status_t(IREE_API_PTR* write)(
       iree_hal_amdgpu_virtual_queue_t* queue,
       const iree_hal_semaphore_list_t wait_semaphore_list,
@@ -196,6 +190,23 @@
       iree_hal_file_t* target_file, uint64_t target_offset,
       iree_device_size_t length, iree_hal_write_flags_t flags);
 
+  iree_status_t(IREE_API_PTR* host_call)(
+      iree_hal_amdgpu_virtual_queue_t* queue,
+      const iree_hal_semaphore_list_t wait_semaphore_list,
+      const iree_hal_semaphore_list_t signal_semaphore_list,
+      iree_hal_host_call_t call, const uint64_t args[4],
+      iree_hal_host_call_flags_t flags);
+
+  iree_status_t(IREE_API_PTR* dispatch)(
+      iree_hal_amdgpu_virtual_queue_t* queue,
+      const iree_hal_semaphore_list_t wait_semaphore_list,
+      const iree_hal_semaphore_list_t signal_semaphore_list,
+      iree_hal_executable_t* executable,
+      iree_hal_executable_export_ordinal_t export_ordinal,
+      const iree_hal_dispatch_config_t config, iree_const_byte_span_t constants,
+      const iree_hal_buffer_ref_list_t bindings,
+      iree_hal_dispatch_flags_t flags);
+
   iree_status_t(IREE_API_PTR* execute)(
       iree_hal_amdgpu_virtual_queue_t* queue,
       const iree_hal_semaphore_list_t wait_semaphore_list,
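The vtable grows queue-ordered `host_call` and `dispatch` entry points alongside `execute`. A stub sketch of the dispatch entry (the `my_queue_dispatch` name is hypothetical; the signature is the vtable entry above):

```c
// Illustrative stub for the new queue-ordered dispatch entry point. A real
// backend would reserve kernargs, build an AQL dispatch packet for
// |executable|/|export_ordinal|, and fold the semaphore lists into its queue
// frontier; this sketch just declines the request.
static iree_status_t my_queue_dispatch(
    iree_hal_amdgpu_virtual_queue_t* queue,
    const iree_hal_semaphore_list_t wait_semaphore_list,
    const iree_hal_semaphore_list_t signal_semaphore_list,
    iree_hal_executable_t* executable,
    iree_hal_executable_export_ordinal_t export_ordinal,
    const iree_hal_dispatch_config_t config, iree_const_byte_span_t constants,
    const iree_hal_buffer_ref_list_t bindings,
    iree_hal_dispatch_flags_t flags) {
  (void)queue;
  return iree_make_status(IREE_STATUS_UNIMPLEMENTED,
                          "queue-ordered dispatch not yet implemented");
}
```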
diff --git a/runtime/src/iree/hal/local/BUILD.bazel b/runtime/src/iree/hal/local/BUILD.bazel
index 4d3270f..3f5e1be 100644
--- a/runtime/src/iree/hal/local/BUILD.bazel
+++ b/runtime/src/iree/hal/local/BUILD.bazel
@@ -147,6 +147,7 @@
         "//runtime/src/iree/base",
         "//runtime/src/iree/base/threading",
         "//runtime/src/iree/hal",
+        "//runtime/src/iree/hal/utils:profile_event_ring",
     ],
 )
 
diff --git a/runtime/src/iree/hal/local/CMakeLists.txt b/runtime/src/iree/hal/local/CMakeLists.txt
index 413bf53..bf4ad5c 100644
--- a/runtime/src/iree/hal/local/CMakeLists.txt
+++ b/runtime/src/iree/hal/local/CMakeLists.txt
@@ -169,6 +169,7 @@
     iree::base
     iree::base::threading
     iree::hal
+    iree::hal::utils::profile_event_ring
   PUBLIC
 )
 
diff --git a/runtime/src/iree/hal/local/profile.c b/runtime/src/iree/hal/local/profile.c
index e13337d..ca70d61 100644
--- a/runtime/src/iree/hal/local/profile.c
+++ b/runtime/src/iree/hal/local/profile.c
@@ -11,6 +11,7 @@
 
 #include "iree/base/threading/mutex.h"
 #include "iree/hal/local/local_executable.h"
+#include "iree/hal/utils/profile_event_ring.h"
 
 //===----------------------------------------------------------------------===//
 // iree_hal_local_profile_recorder_t
@@ -19,55 +20,6 @@
 // Default number of records retained per enabled event stream between flushes.
 #define IREE_HAL_LOCAL_PROFILE_DEFAULT_EVENT_CAPACITY 4096
 
-typedef struct iree_hal_local_profile_event_ring_t {
-  // Storage for fixed-size event records, or NULL when the stream is disabled.
-  void* records;
-
-  // Size in bytes of each event record in |records|.
-  iree_host_size_t record_size;
-
-  // Power-of-two number of event records in |records|.
-  iree_host_size_t capacity;
-
-  // Bit mask used to wrap absolute positions into |records|.
-  iree_host_size_t mask;
-
-  // Absolute position of the first unflushed event record.
-  uint64_t read_position;
-
-  // Absolute position one past the last appended event record.
-  uint64_t write_position;
-
-  // Next nonzero event id assigned to a captured record.
-  uint64_t next_event_id;
-
-  // Records dropped since the last successful truncated flush.
-  uint64_t dropped_record_count;
-} iree_hal_local_profile_event_ring_t;
-
-typedef struct iree_hal_local_profile_event_ring_snapshot_t {
-  // Absolute read position captured for this flush attempt.
-  uint64_t read_position;
-
-  // Number of event records captured for this flush attempt.
-  iree_host_size_t record_count;
-
-  // Dropped records captured for this flush attempt.
-  uint64_t dropped_record_count;
-
-  // First contiguous record span in the ring.
-  const void* first_records;
-
-  // Number of records in |first_records|.
-  iree_host_size_t first_record_count;
-
-  // Second contiguous record span after ring wraparound, or NULL.
-  const void* second_records;
-
-  // Number of records in |second_records|.
-  iree_host_size_t second_record_count;
-} iree_hal_local_profile_event_ring_snapshot_t;
-
 typedef struct iree_hal_local_profile_id_set_t {
   // Open-addressed nonzero ids whose metadata was emitted.
   uint64_t* ids;
@@ -111,16 +63,16 @@
   void* event_storage;
 
   // Ring of host queue event records.
-  iree_hal_local_profile_event_ring_t queue_event_ring;
+  iree_hal_profile_event_ring_t queue_event_ring;
 
   // Ring of host execution span records.
-  iree_hal_local_profile_event_ring_t host_execution_event_ring;
+  iree_hal_profile_event_ring_t host_execution_event_ring;
 
   // Ring of memory lifecycle event records.
-  iree_hal_local_profile_event_ring_t memory_event_ring;
+  iree_hal_profile_event_ring_t memory_event_ring;
 
   // Ring of command-buffer region event records.
-  iree_hal_local_profile_event_ring_t command_region_event_ring;
+  iree_hal_profile_event_ring_t command_region_event_ring;
 
   // Metadata ids already emitted to the session sink.
   struct {
@@ -281,17 +233,6 @@
   return iree_ok_status();
 }
 
-static void iree_hal_local_profile_event_ring_initialize(
-    void* records, iree_host_size_t record_size, iree_host_size_t capacity,
-    iree_hal_local_profile_event_ring_t* out_ring) {
-  memset(out_ring, 0, sizeof(*out_ring));
-  out_ring->records = records;
-  out_ring->record_size = record_size;
-  out_ring->capacity = capacity;
-  out_ring->mask = capacity - 1;
-  out_ring->next_event_id = 1;
-}
-
 static iree_status_t iree_hal_local_profile_recorder_allocate_events(
     iree_hal_local_profile_recorder_t* recorder,
     const iree_hal_local_profile_recorder_options_t* recorder_options) {
@@ -354,25 +295,25 @@
   recorder->event_storage = event_storage;
 
   if (queue_event_capacity != 0) {
-    iree_hal_local_profile_event_ring_initialize(
+    iree_hal_profile_event_ring_initialize(
         (uint8_t*)event_storage + queue_events_offset,
         sizeof(iree_hal_profile_queue_event_t), queue_event_capacity,
         &recorder->queue_event_ring);
   }
   if (host_execution_event_capacity != 0) {
-    iree_hal_local_profile_event_ring_initialize(
+    iree_hal_profile_event_ring_initialize(
         (uint8_t*)event_storage + host_execution_events_offset,
         sizeof(iree_hal_profile_host_execution_event_t),
         host_execution_event_capacity, &recorder->host_execution_event_ring);
   }
   if (memory_event_capacity != 0) {
-    iree_hal_local_profile_event_ring_initialize(
+    iree_hal_profile_event_ring_initialize(
         (uint8_t*)event_storage + memory_events_offset,
         sizeof(iree_hal_profile_memory_event_t), memory_event_capacity,
         &recorder->memory_event_ring);
   }
   if (command_region_event_capacity != 0) {
-    iree_hal_local_profile_event_ring_initialize(
+    iree_hal_profile_event_ring_initialize(
         (uint8_t*)event_storage + command_region_events_offset,
         sizeof(iree_hal_profile_command_region_event_t),
         command_region_event_capacity, &recorder->command_region_event_ring);
@@ -856,33 +797,29 @@
 }
 
 static iree_hal_profile_queue_event_t* iree_hal_local_profile_queue_event_at(
-    const iree_hal_local_profile_event_ring_t* ring, uint64_t position) {
-  iree_hal_profile_queue_event_t* records =
-      (iree_hal_profile_queue_event_t*)ring->records;
-  return &records[position & ring->mask];
+    const iree_hal_profile_event_ring_t* ring, uint64_t position) {
+  return (iree_hal_profile_queue_event_t*)iree_hal_profile_event_ring_record_at(
+      ring, position);
 }
 
 static iree_hal_profile_host_execution_event_t*
 iree_hal_local_profile_host_execution_event_at(
-    const iree_hal_local_profile_event_ring_t* ring, uint64_t position) {
-  iree_hal_profile_host_execution_event_t* records =
-      (iree_hal_profile_host_execution_event_t*)ring->records;
-  return &records[position & ring->mask];
+    const iree_hal_profile_event_ring_t* ring, uint64_t position) {
+  return (iree_hal_profile_host_execution_event_t*)
+      iree_hal_profile_event_ring_record_at(ring, position);
 }
 
 static iree_hal_profile_memory_event_t* iree_hal_local_profile_memory_event_at(
-    const iree_hal_local_profile_event_ring_t* ring, uint64_t position) {
-  iree_hal_profile_memory_event_t* records =
-      (iree_hal_profile_memory_event_t*)ring->records;
-  return &records[position & ring->mask];
+    const iree_hal_profile_event_ring_t* ring, uint64_t position) {
+  return (iree_hal_profile_memory_event_t*)
+      iree_hal_profile_event_ring_record_at(ring, position);
 }
 
 static iree_hal_profile_command_region_event_t*
 iree_hal_local_profile_command_region_event_at(
-    const iree_hal_local_profile_event_ring_t* ring, uint64_t position) {
-  iree_hal_profile_command_region_event_t* records =
-      (iree_hal_profile_command_region_event_t*)ring->records;
-  return &records[position & ring->mask];
+    const iree_hal_profile_event_ring_t* ring, uint64_t position) {
+  return (iree_hal_profile_command_region_event_t*)
+      iree_hal_profile_event_ring_record_at(ring, position);
 }
 
 void iree_hal_local_profile_recorder_append_queue_event(
@@ -901,17 +838,18 @@
   IREE_ASSERT(is_valid);
   if (IREE_UNLIKELY(!is_valid)) return;
 
-  iree_hal_local_profile_event_ring_t* ring = &recorder->queue_event_ring;
+  iree_hal_profile_event_ring_t* ring = &recorder->queue_event_ring;
   iree_slim_mutex_lock(&recorder->mutex);
-  if (ring->write_position - ring->read_position >= ring->capacity) {
-    ++ring->dropped_record_count;
+  uint64_t event_position = 0;
+  uint64_t event_id = 0;
+  if (!iree_hal_profile_event_ring_try_append(ring, &event_position,
+                                              &event_id)) {
     iree_slim_mutex_unlock(&recorder->mutex);
     return;
   }
 
-  const uint64_t event_id = ring->next_event_id++;
   iree_hal_profile_queue_event_t* event =
-      iree_hal_local_profile_queue_event_at(ring, ring->write_position);
+      iree_hal_local_profile_queue_event_at(ring, event_position);
   *event = iree_hal_profile_queue_event_default();
   event->type = event_info->type;
   event->flags = event_info->flags;
@@ -931,7 +869,6 @@
   event->barrier_count = event_info->barrier_count;
   event->operation_count = event_info->operation_count;
   event->payload_length = event_info->payload_length;
-  ++ring->write_position;
   if (out_event_id) *out_event_id = event_id;
   iree_slim_mutex_unlock(&recorder->mutex);
 }
@@ -960,19 +897,18 @@
   IREE_ASSERT(has_valid_range);
   if (IREE_UNLIKELY(!has_valid_range)) end_time_ns = start_time_ns;
 
-  iree_hal_local_profile_event_ring_t* ring =
-      &recorder->host_execution_event_ring;
+  iree_hal_profile_event_ring_t* ring = &recorder->host_execution_event_ring;
   iree_slim_mutex_lock(&recorder->mutex);
-  if (ring->write_position - ring->read_position >= ring->capacity) {
-    ++ring->dropped_record_count;
+  uint64_t event_position = 0;
+  uint64_t event_id = 0;
+  if (!iree_hal_profile_event_ring_try_append(ring, &event_position,
+                                              &event_id)) {
     iree_slim_mutex_unlock(&recorder->mutex);
     return;
   }
 
-  const uint64_t event_id = ring->next_event_id++;
   iree_hal_profile_host_execution_event_t* event =
-      iree_hal_local_profile_host_execution_event_at(ring,
-                                                     ring->write_position);
+      iree_hal_local_profile_host_execution_event_at(ring, event_position);
   *event = iree_hal_profile_host_execution_event_default();
   event->type = event_info->type;
   event->flags = event_info->flags;
@@ -997,7 +933,6 @@
   event->tile_count = event_info->tile_count;
   event->tile_duration_sum_ns = event_info->tile_duration_sum_ns;
   event->operation_count = event_info->operation_count;
-  ++ring->write_position;
   if (out_event_id) *out_event_id = event_id;
   iree_slim_mutex_unlock(&recorder->mutex);
 }
@@ -1027,19 +962,18 @@
   IREE_ASSERT(has_valid_range);
   if (IREE_UNLIKELY(!has_valid_range)) end_time_ns = start_time_ns;
 
-  iree_hal_local_profile_event_ring_t* ring =
-      &recorder->command_region_event_ring;
+  iree_hal_profile_event_ring_t* ring = &recorder->command_region_event_ring;
   iree_slim_mutex_lock(&recorder->mutex);
-  if (ring->write_position - ring->read_position >= ring->capacity) {
-    ++ring->dropped_record_count;
+  uint64_t event_position = 0;
+  uint64_t event_id = 0;
+  if (!iree_hal_profile_event_ring_try_append(ring, &event_position,
+                                              &event_id)) {
     iree_slim_mutex_unlock(&recorder->mutex);
     return;
   }
 
-  const uint64_t event_id = ring->next_event_id++;
   iree_hal_profile_command_region_event_t* event =
-      iree_hal_local_profile_command_region_event_at(ring,
-                                                     ring->write_position);
+      iree_hal_local_profile_command_region_event_at(ring, event_position);
   *event = iree_hal_profile_command_region_event_default();
   event->flags = event_info->flags;
   event->event_id = event_id;
@@ -1102,7 +1036,6 @@
   event->retention.publish_keep_active_count =
       event_info->retention.publish_keep_active_count;
   event->retention.keep_warm_count = event_info->retention.keep_warm_count;
-  ++ring->write_position;
   if (out_event_id) *out_event_id = event_id;
   iree_slim_mutex_unlock(&recorder->mutex);
 }
@@ -1132,93 +1065,41 @@
   IREE_ASSERT(is_valid);
   if (IREE_UNLIKELY(!is_valid)) return;
 
-  iree_hal_local_profile_event_ring_t* ring = &recorder->memory_event_ring;
+  iree_hal_profile_event_ring_t* ring = &recorder->memory_event_ring;
   iree_slim_mutex_lock(&recorder->mutex);
-  if (ring->write_position - ring->read_position >= ring->capacity) {
-    ++ring->dropped_record_count;
+  uint64_t event_position = 0;
+  uint64_t event_id = 0;
+  if (!iree_hal_profile_event_ring_try_append(ring, &event_position,
+                                              &event_id)) {
     iree_slim_mutex_unlock(&recorder->mutex);
     return;
   }
 
-  const uint64_t event_id = ring->next_event_id++;
   iree_hal_profile_memory_event_t* record =
-      iree_hal_local_profile_memory_event_at(ring, ring->write_position);
+      iree_hal_local_profile_memory_event_at(ring, event_position);
   *record = *event;
   record->record_length = sizeof(*record);
   record->event_id = event_id;
   if (record->host_time_ns == 0) {
     record->host_time_ns = iree_time_now();
   }
-  ++ring->write_position;
   if (out_event_id) *out_event_id = event_id;
   iree_slim_mutex_unlock(&recorder->mutex);
 }
 
-static void iree_hal_local_profile_event_ring_snapshot(
-    const iree_hal_local_profile_event_ring_t* ring,
-    iree_hal_local_profile_event_ring_snapshot_t* out_snapshot) {
-  memset(out_snapshot, 0, sizeof(*out_snapshot));
-  if (!ring->records) return;
-
-  out_snapshot->read_position = ring->read_position;
-  out_snapshot->record_count =
-      (iree_host_size_t)(ring->write_position - ring->read_position);
-  out_snapshot->dropped_record_count = ring->dropped_record_count;
-  IREE_ASSERT_LE(out_snapshot->record_count, ring->capacity);
-  if (out_snapshot->record_count == 0) return;
-
-  const iree_host_size_t first_record_index =
-      (iree_host_size_t)(ring->read_position & ring->mask);
-  out_snapshot->first_record_count =
-      iree_min(out_snapshot->record_count, ring->capacity - first_record_index);
-  out_snapshot->first_records =
-      (const uint8_t*)ring->records + first_record_index * ring->record_size;
-  out_snapshot->second_record_count =
-      out_snapshot->record_count - out_snapshot->first_record_count;
-  if (out_snapshot->second_record_count != 0) {
-    out_snapshot->second_records = ring->records;
-  }
-}
-
-static iree_status_t iree_hal_local_profile_append_iovec(
-    const void* records, iree_host_size_t record_count,
-    iree_host_size_t record_size, iree_host_size_t* iovec_count,
-    iree_const_byte_span_t iovecs[2]) {
-  if (record_count == 0) return iree_ok_status();
-  iree_host_size_t byte_length = 0;
-  if (IREE_UNLIKELY(!iree_host_size_checked_mul(record_count, record_size,
-                                                &byte_length))) {
-    return iree_make_status(
-        IREE_STATUS_OUT_OF_RANGE,
-        "local profiling event chunk size overflow for %" PRIhsz " records",
-        record_count);
-  }
-  iovecs[(*iovec_count)++] = iree_make_const_byte_span(records, byte_length);
-  return iree_ok_status();
-}
-
 static iree_status_t iree_hal_local_profile_recorder_write_event_ring(
     iree_hal_local_profile_recorder_t* recorder,
-    iree_string_view_t content_type,
-    iree_hal_local_profile_event_ring_t* ring) {
-  iree_hal_local_profile_event_ring_snapshot_t snapshot;
+    iree_string_view_t content_type, iree_hal_profile_event_ring_t* ring) {
+  iree_hal_profile_event_ring_snapshot_t snapshot;
   iree_slim_mutex_lock(&recorder->mutex);
-  iree_hal_local_profile_event_ring_snapshot(ring, &snapshot);
+  iree_status_t status = iree_hal_profile_event_ring_snapshot(ring, &snapshot);
   iree_slim_mutex_unlock(&recorder->mutex);
+  IREE_RETURN_IF_ERROR(status);
 
   if (snapshot.record_count == 0 && snapshot.dropped_record_count == 0) {
     return iree_ok_status();
   }
 
-  iree_const_byte_span_t iovecs[2];
-  iree_host_size_t iovec_count = 0;
-  IREE_RETURN_IF_ERROR(iree_hal_local_profile_append_iovec(
-      snapshot.first_records, snapshot.first_record_count, ring->record_size,
-      &iovec_count, iovecs));
-  IREE_RETURN_IF_ERROR(iree_hal_local_profile_append_iovec(
-      snapshot.second_records, snapshot.second_record_count, ring->record_size,
-      &iovec_count, iovecs));
-
   iree_hal_profile_chunk_metadata_t metadata =
       iree_hal_local_profile_recorder_metadata(recorder, content_type);
   if (snapshot.dropped_record_count != 0) {
@@ -1226,15 +1107,11 @@
     metadata.dropped_record_count = snapshot.dropped_record_count;
   }
   IREE_RETURN_IF_ERROR(iree_hal_profile_sink_write(
-      recorder->options.sink, &metadata, iovec_count, iovecs));
+      recorder->options.sink, &metadata, snapshot.record_span_count,
+      snapshot.record_span_count ? snapshot.record_spans : NULL));
 
   iree_slim_mutex_lock(&recorder->mutex);
-  ring->read_position = snapshot.read_position + snapshot.record_count;
-  if (ring->dropped_record_count >= snapshot.dropped_record_count) {
-    ring->dropped_record_count -= snapshot.dropped_record_count;
-  } else {
-    ring->dropped_record_count = 0;
-  }
+  iree_hal_profile_event_ring_commit_snapshot(ring, &snapshot);
   iree_slim_mutex_unlock(&recorder->mutex);
   return iree_ok_status();
 }
diff --git a/runtime/src/iree/hal/memory/tlsf_pool.c b/runtime/src/iree/hal/memory/tlsf_pool.c
index c36af22..fbebfbf 100644
--- a/runtime/src/iree/hal/memory/tlsf_pool.c
+++ b/runtime/src/iree/hal/memory/tlsf_pool.c
@@ -703,7 +703,6 @@
     iree_hal_pool_acquire_info_t* out_info,
     iree_hal_pool_acquire_result_t* out_result) {
   iree_hal_tlsf_pool_t* pool = (iree_hal_tlsf_pool_t*)base_pool;
-  (void)flags;
 
   if (size == 0) {
     return iree_make_status(IREE_STATUS_INVALID_ARGUMENT,
@@ -743,6 +742,7 @@
   iree_hal_pool_acquire_result_t selected_result =
       IREE_HAL_POOL_ACQUIRE_EXHAUSTED;
   bool has_selected_allocation = false;
+  bool growth_required = false;
 
   iree_slim_mutex_lock(&pool->mutex);
   iree_hal_tlsf_pool_drain_pending_releases(pool);
@@ -772,15 +772,19 @@
   }
 
   if (iree_status_is_ok(status) && !has_selected_allocation) {
-    uint16_t slab_index = 0;
-    status = iree_hal_tlsf_pool_append_slab(pool, &slab_index);
-    if (iree_status_is_ok(status)) {
-      status = iree_hal_tlsf_pool_try_acquire_from_slab(
-          pool, slab_index, size, requester_frontier,
-          /*record_reuse_miss=*/true, &selected_allocation, &selected_result);
-      has_selected_allocation =
-          selected_result == IREE_HAL_POOL_ACQUIRE_OK ||
-          selected_result == IREE_HAL_POOL_ACQUIRE_OK_FRESH;
+    if (iree_all_bits_set(flags, IREE_HAL_POOL_RESERVE_FLAG_DISALLOW_GROWTH)) {
+      growth_required = true;
+    } else {
+      uint16_t slab_index = 0;
+      status = iree_hal_tlsf_pool_append_slab(pool, &slab_index);
+      if (iree_status_is_ok(status)) {
+        status = iree_hal_tlsf_pool_try_acquire_from_slab(
+            pool, slab_index, size, requester_frontier,
+            /*record_reuse_miss=*/true, &selected_allocation, &selected_result);
+        has_selected_allocation =
+            selected_result == IREE_HAL_POOL_ACQUIRE_OK ||
+            selected_result == IREE_HAL_POOL_ACQUIRE_OK_FRESH;
+      }
     }
   }
 
@@ -815,6 +819,9 @@
                             iree_memory_order_relaxed);
       memset(out_reservation, 0, sizeof(*out_reservation));
       memset(out_info, 0, sizeof(*out_info));
+      if (growth_required) {
+        out_info->flags |= IREE_HAL_POOL_ACQUIRE_FLAG_GROWTH_REQUIRED;
+      }
       *out_result = IREE_HAL_POOL_ACQUIRE_EXHAUSTED;
       iree_hal_tlsf_pool_uncharge_reservation(pool, charged_length);
       charged_length = 0;
diff --git a/runtime/src/iree/hal/memory/tlsf_pool_test.cc b/runtime/src/iree/hal/memory/tlsf_pool_test.cc
index 3f1889b..7053e98 100644
--- a/runtime/src/iree/hal/memory/tlsf_pool_test.cc
+++ b/runtime/src/iree/hal/memory/tlsf_pool_test.cc
@@ -520,6 +520,65 @@
   iree_hal_slab_provider_release(slab_provider);
 }
 
+TEST(TLSFPool, ReserveCanReportGrowthRequiredWithoutGrowing) {
+  iree_allocator_t allocator = iree_allocator_system();
+  iree_hal_slab_provider_t* slab_provider = NULL;
+  IREE_ASSERT_OK(iree_hal_cpu_slab_provider_create(allocator, &slab_provider));
+  iree_async_notification_t* notification = NULL;
+  IREE_ASSERT_OK(iree_async_notification_create(
+      test_proactor(), IREE_ASYNC_NOTIFICATION_FLAG_NONE, &notification));
+
+  iree_hal_tlsf_pool_options_t options = DefaultOptions();
+  options.tlsf_options.range_length = 256;
+
+  iree_hal_pool_t* pool = NULL;
+  IREE_ASSERT_OK(iree_hal_tlsf_pool_create(options, slab_provider, notification,
+                                           iree_hal_pool_epoch_query_null(),
+                                           allocator, &pool));
+
+  iree_hal_pool_reservation_t reservation;
+  iree_hal_pool_acquire_info_t reserve_info;
+  iree_hal_pool_acquire_result_t result;
+  IREE_ASSERT_OK(iree_hal_pool_acquire_reservation(
+      pool, 256, 16, /*requester_frontier=*/NULL,
+      IREE_HAL_POOL_RESERVE_FLAG_NONE, &reservation, &reserve_info, &result));
+  EXPECT_EQ(result, IREE_HAL_POOL_ACQUIRE_OK_FRESH);
+
+  MAKE_FRONTIER(death, 1, E(TestQueueAxis(0), 20));
+  iree_hal_pool_release_reservation(pool, &reservation, death);
+
+  MAKE_FRONTIER(requester, 1, E(TestQueueAxis(0), 10));
+  IREE_ASSERT_OK(iree_hal_pool_acquire_reservation(
+      pool, 256, 16, requester,
+      IREE_HAL_POOL_RESERVE_FLAG_ALLOW_WAIT_FRONTIER |
+          IREE_HAL_POOL_RESERVE_FLAG_DISALLOW_GROWTH,
+      &reservation, &reserve_info, &result));
+  EXPECT_EQ(result, IREE_HAL_POOL_ACQUIRE_EXHAUSTED);
+  EXPECT_TRUE(iree_all_bits_set(reserve_info.flags,
+                                IREE_HAL_POOL_ACQUIRE_FLAG_GROWTH_REQUIRED));
+
+  iree_hal_pool_stats_t stats;
+  iree_hal_pool_query_stats(pool, &stats);
+  EXPECT_EQ(stats.slab_count, 1u);
+  EXPECT_EQ(stats.reuse_miss_count, 1u);
+  EXPECT_EQ(stats.exhausted_count, 1u);
+  EXPECT_EQ(stats.wait_count, 0u);
+
+  iree_hal_pool_reservation_t grown_reservation;
+  IREE_ASSERT_OK(iree_hal_pool_acquire_reservation(
+      pool, 256, 16, requester, IREE_HAL_POOL_RESERVE_FLAG_ALLOW_WAIT_FRONTIER,
+      &grown_reservation, &reserve_info, &result));
+  EXPECT_EQ(result, IREE_HAL_POOL_ACQUIRE_OK_FRESH);
+  EXPECT_FALSE(iree_all_bits_set(reserve_info.flags,
+                                 IREE_HAL_POOL_ACQUIRE_FLAG_GROWTH_REQUIRED));
+  EXPECT_NE(grown_reservation.slab_index, 0u);
+
+  iree_hal_pool_release_reservation(pool, &grown_reservation, NULL);
+  iree_hal_pool_release(pool);
+  iree_async_notification_release(notification);
+  iree_hal_slab_provider_release(slab_provider);
+}
+
 TEST(TLSFPool, ReserveRejectedTaintRemainsRejected) {
   iree_allocator_t allocator = iree_allocator_system();
   iree_hal_slab_provider_t* slab_provider = NULL;
diff --git a/runtime/src/iree/hal/pool.h b/runtime/src/iree/hal/pool.h
index b388819..930ae62 100644
--- a/runtime/src/iree/hal/pool.h
+++ b/runtime/src/iree/hal/pool.h
@@ -128,6 +128,16 @@
   // this flag. Such calls should receive only immediately-usable reservations
   // or transient EXHAUSTED/OVER_BUDGET results from well-behaved pools.
   IREE_HAL_POOL_RESERVE_FLAG_ALLOW_WAIT_FRONTIER = 1u << 0,
+
+  // Prevents growable pools from acquiring additional backing storage during
+  // this reservation attempt. Pools that could satisfy the request by growing
+  // should return IREE_HAL_POOL_ACQUIRE_EXHAUSTED with
+  // IREE_HAL_POOL_ACQUIRE_FLAG_GROWTH_REQUIRED instead of calling into their
+  // slab provider.
+  //
+  // Queue implementations use this inside critical sections so unbounded
+  // platform memory allocation is routed through an explicit cold path.
+  IREE_HAL_POOL_RESERVE_FLAG_DISALLOW_GROWTH = 1u << 1,
 };
 
 // Generic metadata flags returned by a pool reservation acquisition.
@@ -140,6 +150,10 @@
   // intentionally disabled for safety. Queue implementations should treat this
   // as a conservative dependency edge, not proof of precise happens-before.
   IREE_HAL_POOL_ACQUIRE_FLAG_WAIT_FRONTIER_TAINTED = 1u << 0,
+
+  // The pool did not make a reservation because the caller prohibited growth
+  // and the request could only be satisfied by acquiring more backing storage.
+  IREE_HAL_POOL_ACQUIRE_FLAG_GROWTH_REQUIRED = 1u << 1,
 };
 
 // Generic metadata returned by a pool reservation acquisition.
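A sketch of the intended caller pattern, assuming a hypothetical `my_queue_grow_and_retry` cold-path helper: reserve with growth disallowed while holding the queue lock, and only a GROWTH_REQUIRED result routes to the cold path that may touch the platform allocator:

```c
// Sketch of the fast-path/cold-path split; my_queue_grow_and_retry is
// hypothetical. With DISALLOW_GROWTH set the pool never calls into its slab
// provider, so no platform allocation happens under the queue lock.
iree_hal_pool_reservation_t reservation;
iree_hal_pool_acquire_info_t info;
iree_hal_pool_acquire_result_t result;
IREE_RETURN_IF_ERROR(iree_hal_pool_acquire_reservation(
    pool, size, alignment, requester_frontier,
    IREE_HAL_POOL_RESERVE_FLAG_ALLOW_WAIT_FRONTIER |
        IREE_HAL_POOL_RESERVE_FLAG_DISALLOW_GROWTH,
    &reservation, &info, &result));
if (result == IREE_HAL_POOL_ACQUIRE_EXHAUSTED &&
    iree_all_bits_set(info.flags,
                      IREE_HAL_POOL_ACQUIRE_FLAG_GROWTH_REQUIRED)) {
  // Cold path: drop the queue lock and retry without DISALLOW_GROWTH so the
  // pool may append a fresh slab.
  return my_queue_grow_and_retry(queue, size, alignment);
}
```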
diff --git a/runtime/src/iree/hal/profile_options.h b/runtime/src/iree/hal/profile_options.h
index d4b025a..640b490 100644
--- a/runtime/src/iree/hal/profile_options.h
+++ b/runtime/src/iree/hal/profile_options.h
@@ -49,8 +49,8 @@
   // when they cannot retain complete selected events.
   IREE_HAL_DEVICE_PROFILING_DATA_DISPATCH_EVENTS = 1ull << 3,
 
-  // Explicitly selected hardware/software counter samples. Requested counters
-  // are described by |counter_sets|.
+  // Explicitly selected hardware/software counter samples attributed to
+  // individual operations. Requested counters are described by |counter_sets|.
   IREE_HAL_DEVICE_PROFILING_DATA_COUNTER_SAMPLES = 1ull << 4,
 
   // Executable/code-object/export metadata needed for offline analysis. Some
@@ -81,6 +81,11 @@
   // around groups of command operations; dispatch/kernel execution details stay
   // in dispatch or host-execution event families.
   IREE_HAL_DEVICE_PROFILING_DATA_COMMAND_REGION_EVENTS = 1ull << 9,
+
+  // Explicitly selected hardware/software counter ranges. Requested counters
+  // are described by |counter_sets| and are sampled over producer-defined time
+  // ranges without requiring operation attribution.
+  IREE_HAL_DEVICE_PROFILING_DATA_COUNTER_RANGES = 1ull << 10,
 };
 
 // Bitfield selecting producer-side profiling behavior that is not itself a
@@ -246,9 +251,8 @@
   // string views it contains before returning from profiling_begin.
   iree_hal_profile_capture_filter_t capture_filter;
 
-  // Number of explicitly requested hardware/software counter sets.
-  // Must be nonzero when requesting
-  // IREE_HAL_DEVICE_PROFILING_DATA_COUNTER_SAMPLES.
+  // Number of explicitly requested hardware/software counter sets. Must be
+  // nonzero when requesting counter samples or counter ranges.
   iree_host_size_t counter_set_count;
 
   // Borrowed begin-call-only array of explicitly requested counter sets.
@@ -296,13 +300,28 @@
       options->flags, IREE_HAL_DEVICE_PROFILING_FLAG_LIGHTWEIGHT_STATISTICS);
 }
 
-// Returns true when |options| requests explicit hardware counter capture.
+// Returns true when |options| requests operation-attributed counter samples.
 static inline bool iree_hal_device_profiling_options_requests_counter_samples(
     const iree_hal_device_profiling_options_t* options) {
   return iree_hal_device_profiling_options_requests_data(
       options, IREE_HAL_DEVICE_PROFILING_DATA_COUNTER_SAMPLES);
 }
 
+// Returns true when |options| requests range-scoped counter samples.
+static inline bool iree_hal_device_profiling_options_requests_counter_ranges(
+    const iree_hal_device_profiling_options_t* options) {
+  return iree_hal_device_profiling_options_requests_data(
+      options, IREE_HAL_DEVICE_PROFILING_DATA_COUNTER_RANGES);
+}
+
+// Returns true when |options| requests any explicit counter capture.
+static inline bool iree_hal_device_profiling_options_requests_counters(
+    const iree_hal_device_profiling_options_t* options) {
+  return iree_hal_device_profiling_options_requests_data(
+      options, IREE_HAL_DEVICE_PROFILING_DATA_COUNTER_SAMPLES |
+                   IREE_HAL_DEVICE_PROFILING_DATA_COUNTER_RANGES);
+}
+
 // Returns true when |options| requests executable trace artifacts.
 static inline bool iree_hal_device_profiling_options_requests_executable_traces(
     const iree_hal_device_profiling_options_t* options) {
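A sketch of how a producer might gate its counter machinery on the two families; the `my_producer_*` helpers are hypothetical, while the option queries are the inline helpers added above:

```c
// Sketch: arm only the counter machinery the session asked for. The
// my_producer_* helpers are hypothetical.
if (iree_hal_device_profiling_options_requests_counters(options)) {
  IREE_RETURN_IF_ERROR(my_producer_validate_counter_sets(
      producer, options->counter_set_count, options->counter_sets));
}
if (iree_hal_device_profiling_options_requests_counter_samples(options)) {
  // Per-operation attribution: may inject packets around selected dispatches.
  IREE_RETURN_IF_ERROR(my_producer_enable_dispatch_sampling(producer));
}
if (iree_hal_device_profiling_options_requests_counter_ranges(options)) {
  // Producer-defined time windows with no operation attribution required.
  IREE_RETURN_IF_ERROR(my_producer_enable_range_sampling(producer));
}
```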
diff --git a/runtime/src/iree/hal/topology.h b/runtime/src/iree/hal/topology.h
index 166a232..6d7c337 100644
--- a/runtime/src/iree/hal/topology.h
+++ b/runtime/src/iree/hal/topology.h
@@ -276,6 +276,10 @@
   // Shared virtual addressing (SVA/SVM) across this link.
   // Both devices can use the same virtual addresses for shared memory.
   IREE_HAL_TOPOLOGY_CAPABILITY_SHARED_VIRTUAL_ADDRESS = 1u << 11,
+  // One or more direct peer access paths represented by this edge require
+  // per-allocation access grants. Until the allocation/access policy proves a
+  // grant was applied, the affected buffer modes must not report NATIVE access.
+  IREE_HAL_TOPOLOGY_CAPABILITY_PEER_ACCESS_REQUIRES_GRANT = 1u << 12,
 };
 typedef uint16_t iree_hal_topology_capability_t;
 
@@ -510,6 +514,7 @@
 // - TIMELINE_SEMAPHORE: Supports timeline semaphores for fine-grained sync
 // - REMOTE_DMA: RDMA transfers supported across this link
 // - SHARED_VIRTUAL_ADDRESS: SVA/SVM across this link
+// - PEER_ACCESS_REQUIRES_GRANT: direct peer access needs allocation grants
 //
 // Implementations should be conservative - only set flags that hardware truly
 // guarantees. ATOMIC_SYSTEM requires platform support (PCIe atomics, vendor
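An illustrative policy check for the new capability, assuming a hypothetical `my_grant_table_contains` lookup into the allocation/access policy's grant records:

```c
// Illustrative only; my_grant_table_contains is a hypothetical lookup into
// the allocation/access policy's record of applied grants.
static bool my_edge_allows_native_access(iree_hal_topology_edge_t edge,
                                         const void* allocation_key) {
  const iree_hal_topology_capability_t caps =
      iree_hal_topology_edge_capability_flags(edge.lo);
  if (!(caps & IREE_HAL_TOPOLOGY_CAPABILITY_PEER_ACCESS_REQUIRES_GRANT)) {
    return true;  // No per-allocation grant needed on this link.
  }
  return my_grant_table_contains(allocation_key);
}
```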
diff --git a/runtime/src/iree/hal/topology_builder.c b/runtime/src/iree/hal/topology_builder.c
index 91d57c7..bb02787 100644
--- a/runtime/src/iree/hal/topology_builder.c
+++ b/runtime/src/iree/hal/topology_builder.c
@@ -188,7 +188,8 @@
       IREE_HAL_TOPOLOGY_CAPABILITY_CONCURRENT_SAFE |
       IREE_HAL_TOPOLOGY_CAPABILITY_ATOMIC_DEVICE |
       IREE_HAL_TOPOLOGY_CAPABILITY_ATOMIC_SYSTEM |
-      IREE_HAL_TOPOLOGY_CAPABILITY_TIMELINE_SEMAPHORE;
+      IREE_HAL_TOPOLOGY_CAPABILITY_TIMELINE_SEMAPHORE |
+      IREE_HAL_TOPOLOGY_CAPABILITY_SHARED_VIRTUAL_ADDRESS;
   lo = iree_hal_topology_edge_set_capability_flags(lo, caps);
 
   // Zero cost for all operations on self.
@@ -338,8 +339,9 @@
   //     Determined by PEER_ADDRESSABLE (large BAR mapping) and P2P_COPY.
   //
   //   Coherent: memory with hardware-maintained coherency (fine-grained, SVM).
-  //     Often more accessible — SVM provides direct load/store across devices.
-  //     Determined by UNIFIED_MEMORY (SVM) or PEER_COHERENT + PEER_ADDRESSABLE.
+  //     UNIFIED_MEMORY means device-visible coherent memory is accessible by
+  //     default. SHARED_VIRTUAL_ADDRESS only says matching virtual addresses
+  //     can be made meaningful; it may still require per-range access grants.
   //
   // NATIVE: load/store addressable — scheduler references the buffer directly.
   // IMPORT: buffer handle import — one-time setup, then directly usable.
@@ -355,6 +357,9 @@
   bool unified_memory =
       (src_caps->flags & IREE_HAL_DEVICE_CAPABILITY_UNIFIED_MEMORY) &&
       (dst_caps->flags & IREE_HAL_DEVICE_CAPABILITY_UNIFIED_MEMORY);
+  bool shared_virtual_address =
+      (src_caps->flags & IREE_HAL_DEVICE_CAPABILITY_SHARED_VIRTUAL_ADDRESS) &&
+      (dst_caps->flags & IREE_HAL_DEVICE_CAPABILITY_SHARED_VIRTUAL_ADDRESS);
 
   // Non-coherent buffer modes (device-local, coarse-grained).
   iree_hal_topology_interop_mode_t nc_buffer_read_mode, nc_buffer_write_mode;
@@ -408,6 +413,10 @@
     caps |= IREE_HAL_TOPOLOGY_CAPABILITY_UNIFIED_MEMORY;
   }
 
+  if (shared_virtual_address) {
+    caps |= IREE_HAL_TOPOLOGY_CAPABILITY_SHARED_VIRTUAL_ADDRESS;
+  }
+
   if ((src_caps->flags & IREE_HAL_DEVICE_CAPABILITY_PEER_COHERENT) &&
       (dst_caps->flags & IREE_HAL_DEVICE_CAPABILITY_PEER_COHERENT) &&
       same_driver) {
diff --git a/runtime/src/iree/hal/topology_test.cc b/runtime/src/iree/hal/topology_test.cc
index d971ba3..3e99793 100644
--- a/runtime/src/iree/hal/topology_test.cc
+++ b/runtime/src/iree/hal/topology_test.cc
@@ -171,7 +171,8 @@
       IREE_HAL_TOPOLOGY_CAPABILITY_CONCURRENT_SAFE |
       IREE_HAL_TOPOLOGY_CAPABILITY_ATOMIC_DEVICE |
       IREE_HAL_TOPOLOGY_CAPABILITY_ATOMIC_SYSTEM |
-      IREE_HAL_TOPOLOGY_CAPABILITY_TIMELINE_SEMAPHORE;
+      IREE_HAL_TOPOLOGY_CAPABILITY_TIMELINE_SEMAPHORE |
+      IREE_HAL_TOPOLOGY_CAPABILITY_SHARED_VIRTUAL_ADDRESS;
   EXPECT_EQ(iree_hal_topology_edge_capability_flags(edge.lo), expected_caps);
 
   // Self-edges use SAME_DIE link class.
@@ -282,6 +283,25 @@
   EXPECT_NE(iree_hal_topology_edge_copy_cost(edge.lo), 0);
 }
 
+TEST(TopologyEdge, SharedVirtualAddressDoesNotImplyUnifiedMemory) {
+  iree_hal_device_capabilities_t caps = {0};
+  caps.flags = IREE_HAL_DEVICE_CAPABILITY_SHARED_VIRTUAL_ADDRESS;
+
+  iree_hal_topology_edge_t edge = iree_hal_topology_edge_from_capabilities(
+      &caps, &caps, IREE_SV("amdgpu"), IREE_SV("amdgpu"));
+  iree_hal_topology_capability_t topology_caps =
+      iree_hal_topology_edge_capability_flags(edge.lo);
+
+  EXPECT_TRUE(topology_caps &
+              IREE_HAL_TOPOLOGY_CAPABILITY_SHARED_VIRTUAL_ADDRESS);
+  EXPECT_FALSE(topology_caps & IREE_HAL_TOPOLOGY_CAPABILITY_UNIFIED_MEMORY);
+  EXPECT_FALSE(topology_caps & IREE_HAL_TOPOLOGY_CAPABILITY_PEER_COHERENT);
+  EXPECT_EQ(iree_hal_topology_edge_buffer_read_mode_coherent(edge.lo),
+            IREE_HAL_TOPOLOGY_INTEROP_MODE_COPY);
+  EXPECT_EQ(iree_hal_topology_edge_buffer_write_mode_coherent(edge.lo),
+            IREE_HAL_TOPOLOGY_INTEROP_MODE_COPY);
+}
+
 //===----------------------------------------------------------------------===//
 // Resource origin tests
 //===----------------------------------------------------------------------===//
diff --git a/runtime/src/iree/hal/utils/BUILD.bazel b/runtime/src/iree/hal/utils/BUILD.bazel
index 58f9985..ee0f715 100644
--- a/runtime/src/iree/hal/utils/BUILD.bazel
+++ b/runtime/src/iree/hal/utils/BUILD.bazel
@@ -220,6 +220,15 @@
 )
 
 iree_runtime_cc_library(
+    name = "profile_event_ring",
+    srcs = ["profile_event_ring.c"],
+    hdrs = ["profile_event_ring.h"],
+    deps = [
+        "//runtime/src/iree/base",
+    ],
+)
+
+iree_runtime_cc_library(
     name = "statistics_sink",
     srcs = ["statistics_sink.c"],
     hdrs = ["statistics_sink.h"],
diff --git a/runtime/src/iree/hal/utils/CMakeLists.txt b/runtime/src/iree/hal/utils/CMakeLists.txt
index b71fb67..42c1de8 100644
--- a/runtime/src/iree/hal/utils/CMakeLists.txt
+++ b/runtime/src/iree/hal/utils/CMakeLists.txt
@@ -257,6 +257,18 @@
 
 iree_cc_library(
   NAME
+    profile_event_ring
+  HDRS
+    "profile_event_ring.h"
+  SRCS
+    "profile_event_ring.c"
+  DEPS
+    iree::base
+  PUBLIC
+)
+
+iree_cc_library(
+  NAME
     statistics_sink
   HDRS
     "statistics_sink.h"
diff --git a/runtime/src/iree/hal/utils/profile_event_ring.c b/runtime/src/iree/hal/utils/profile_event_ring.c
new file mode 100644
index 0000000..a6459c5
--- /dev/null
+++ b/runtime/src/iree/hal/utils/profile_event_ring.c
@@ -0,0 +1,124 @@
+// Copyright 2026 The IREE Authors
+//
+// Licensed under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+
+#include "iree/hal/utils/profile_event_ring.h"
+
+#include <string.h>
+
+void iree_hal_profile_event_ring_initialize(
+    void* records, iree_host_size_t record_size, iree_host_size_t capacity,
+    iree_hal_profile_event_ring_t* out_ring) {
+  IREE_ASSERT_ARGUMENT(out_ring);
+  IREE_ASSERT(capacity == 0 || iree_host_size_is_power_of_two(capacity));
+  IREE_ASSERT(capacity == 0 || records != NULL);
+  memset(out_ring, 0, sizeof(*out_ring));
+  out_ring->records = records;
+  out_ring->record_size = record_size;
+  out_ring->capacity = capacity;
+  out_ring->mask = capacity ? capacity - 1 : 0;
+  out_ring->next_event_id = 1;
+}
+
+void iree_hal_profile_event_ring_clear(iree_hal_profile_event_ring_t* ring) {
+  IREE_ASSERT_ARGUMENT(ring);
+  ring->read_position = 0;
+  ring->write_position = 0;
+  ring->dropped_record_count = 0;
+  ring->next_event_id = 1;
+  if (ring->records && ring->capacity != 0) {
+    memset(ring->records, 0, ring->capacity * ring->record_size);
+  }
+}
+
+void* iree_hal_profile_event_ring_record_at(
+    const iree_hal_profile_event_ring_t* ring, uint64_t position) {
+  IREE_ASSERT_ARGUMENT(ring);
+  IREE_ASSERT(ring->records != NULL);
+  IREE_ASSERT(ring->capacity != 0);
+  return (uint8_t*)ring->records + (position & ring->mask) * ring->record_size;
+}
+
+bool iree_hal_profile_event_ring_try_append(iree_hal_profile_event_ring_t* ring,
+                                            uint64_t* out_position,
+                                            uint64_t* out_event_id) {
+  IREE_ASSERT_ARGUMENT(ring);
+  IREE_ASSERT_ARGUMENT(out_position);
+  IREE_ASSERT_ARGUMENT(out_event_id);
+  *out_position = 0;
+  *out_event_id = 0;
+  if (!ring->records || ring->capacity == 0) return false;
+
+  const uint64_t read_position = ring->read_position;
+  const uint64_t write_position = ring->write_position;
+  const uint64_t occupied_count = write_position - read_position;
+  if (occupied_count >= ring->capacity) {
+    ++ring->dropped_record_count;
+    return false;
+  }
+
+  *out_position = write_position;
+  *out_event_id = ring->next_event_id++;
+  ring->write_position = write_position + 1;
+  return true;
+}
+
+iree_status_t iree_hal_profile_event_ring_snapshot(
+    const iree_hal_profile_event_ring_t* ring,
+    iree_hal_profile_event_ring_snapshot_t* out_snapshot) {
+  IREE_ASSERT_ARGUMENT(ring);
+  IREE_ASSERT_ARGUMENT(out_snapshot);
+  memset(out_snapshot, 0, sizeof(*out_snapshot));
+  if (!ring->records || ring->capacity == 0) return iree_ok_status();
+
+  out_snapshot->read_position = ring->read_position;
+  out_snapshot->record_count =
+      (iree_host_size_t)(ring->write_position - ring->read_position);
+  out_snapshot->dropped_record_count = ring->dropped_record_count;
+  IREE_ASSERT_LE(out_snapshot->record_count, ring->capacity);
+  if (out_snapshot->record_count == 0) return iree_ok_status();
+
+  const iree_host_size_t first_record_index =
+      (iree_host_size_t)(ring->read_position & ring->mask);
+  const iree_host_size_t first_record_count =
+      iree_min(out_snapshot->record_count, ring->capacity - first_record_index);
+  iree_host_size_t first_byte_length = 0;
+  if (IREE_UNLIKELY(!iree_host_size_checked_mul(
+          first_record_count, ring->record_size, &first_byte_length))) {
+    return iree_make_status(IREE_STATUS_OUT_OF_RANGE,
+                            "profile event ring snapshot size overflows");
+  }
+  out_snapshot->record_spans[out_snapshot->record_span_count++] =
+      iree_make_const_byte_span((const uint8_t*)ring->records +
+                                    first_record_index * ring->record_size,
+                                first_byte_length);
+
+  const iree_host_size_t second_record_count =
+      out_snapshot->record_count - first_record_count;
+  if (second_record_count != 0) {
+    iree_host_size_t second_byte_length = 0;
+    if (IREE_UNLIKELY(!iree_host_size_checked_mul(
+            second_record_count, ring->record_size, &second_byte_length))) {
+      return iree_make_status(IREE_STATUS_OUT_OF_RANGE,
+                              "profile event ring snapshot size overflows");
+    }
+    out_snapshot->record_spans[out_snapshot->record_span_count++] =
+        iree_make_const_byte_span(ring->records, second_byte_length);
+  }
+  return iree_ok_status();
+}
+
+void iree_hal_profile_event_ring_commit_snapshot(
+    iree_hal_profile_event_ring_t* ring,
+    const iree_hal_profile_event_ring_snapshot_t* snapshot) {
+  IREE_ASSERT_ARGUMENT(ring);
+  IREE_ASSERT_ARGUMENT(snapshot);
+  ring->read_position = snapshot->read_position + snapshot->record_count;
+  if (ring->dropped_record_count >= snapshot->dropped_record_count) {
+    ring->dropped_record_count -= snapshot->dropped_record_count;
+  } else {
+    ring->dropped_record_count = 0;
+  }
+}
diff --git a/runtime/src/iree/hal/utils/profile_event_ring.h b/runtime/src/iree/hal/utils/profile_event_ring.h
new file mode 100644
index 0000000..7501c32
--- /dev/null
+++ b/runtime/src/iree/hal/utils/profile_event_ring.h
@@ -0,0 +1,107 @@
+// Copyright 2026 The IREE Authors
+//
+// Licensed under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+
+#ifndef IREE_HAL_UTILS_PROFILE_EVENT_RING_H_
+#define IREE_HAL_UTILS_PROFILE_EVENT_RING_H_
+
+#include "iree/base/api.h"
+
+#ifdef __cplusplus
+extern "C" {
+#endif  // __cplusplus
+
+//===----------------------------------------------------------------------===//
+// iree_hal_profile_event_ring_t
+//===----------------------------------------------------------------------===//
+
+// Lossy fixed-capacity profiling event ring.
+//
+// The ring provides no synchronization and does not own its record storage.
+// Callers decide which mutex or higher-level exclusion contract protects
+// positions, event ids, drop counts, and record contents.
+typedef struct iree_hal_profile_event_ring_t {
+  // Storage for fixed-size event records, or NULL when disabled.
+  void* records;
+
+  // Byte size of each event record in |records|.
+  iree_host_size_t record_size;
+
+  // Power-of-two number of event records in |records|.
+  iree_host_size_t capacity;
+
+  // Capacity minus one, for mapping logical positions to physical slots.
+  iree_host_size_t mask;
+
+  // Absolute position of the first retained event record.
+  uint64_t read_position;
+
+  // Absolute position one past the last retained event record.
+  uint64_t write_position;
+
+  // Records dropped and not yet accounted for by a committed flush.
+  uint64_t dropped_record_count;
+
+  // Next nonzero event id assigned by this ring.
+  uint64_t next_event_id;
+} iree_hal_profile_event_ring_t;
+
+// Immutable view of a ring flush attempt.
+typedef struct iree_hal_profile_event_ring_snapshot_t {
+  // Absolute read position captured for this flush attempt.
+  uint64_t read_position;
+
+  // Number of event records captured for this flush attempt.
+  iree_host_size_t record_count;
+
+  // Dropped records captured for this flush attempt.
+  uint64_t dropped_record_count;
+
+  // Contiguous byte spans covering captured records in ring order.
+  iree_const_byte_span_t record_spans[2];
+
+  // Number of initialized entries in |record_spans|.
+  iree_host_size_t record_span_count;
+} iree_hal_profile_event_ring_snapshot_t;
+
+// Initializes |out_ring| over caller-owned record storage.
+//
+// |capacity| must be zero or a nonzero power of two. |records| may be NULL
+// only when |capacity| is zero, which creates a disabled ring.
+void iree_hal_profile_event_ring_initialize(
+    void* records, iree_host_size_t record_size, iree_host_size_t capacity,
+    iree_hal_profile_event_ring_t* out_ring);
+
+// Clears positions, drop counts, event ids, and retained record bytes.
+void iree_hal_profile_event_ring_clear(iree_hal_profile_event_ring_t* ring);
+
+// Returns the fixed-size record slot for |position|.
+void* iree_hal_profile_event_ring_record_at(
+    const iree_hal_profile_event_ring_t* ring, uint64_t position);
+
+// Attempts to reserve one event record.
+//
+// Returns false when the ring is disabled. Returns false and accounts one
+// dropped record when the ring is full. On success, returns the reserved
+// logical position and assigned event id.
+bool iree_hal_profile_event_ring_try_append(iree_hal_profile_event_ring_t* ring,
+                                            uint64_t* out_position,
+                                            uint64_t* out_event_id);
+
+// Captures retained records and dropped-record count for a flush attempt.
+iree_status_t iree_hal_profile_event_ring_snapshot(
+    const iree_hal_profile_event_ring_t* ring,
+    iree_hal_profile_event_ring_snapshot_t* out_snapshot);
+
+// Advances |ring| after a successful sink write of |snapshot|.
+void iree_hal_profile_event_ring_commit_snapshot(
+    iree_hal_profile_event_ring_t* ring,
+    const iree_hal_profile_event_ring_snapshot_t* snapshot);
+
+#ifdef __cplusplus
+}  // extern "C"
+#endif  // __cplusplus
+
+#endif  // IREE_HAL_UTILS_PROFILE_EVENT_RING_H_
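The locking contract, condensed into one sketch that mirrors the local profile recorder above (the mutex and `my_*` names are illustrative): append and snapshot under the caller's lock, write to the sink outside it, commit under the lock:

```c
// Append path, under the caller-owned mutex.
iree_slim_mutex_lock(&mutex);
uint64_t position = 0;
uint64_t event_id = 0;
if (iree_hal_profile_event_ring_try_append(&ring, &position, &event_id)) {
  my_event_t* event =
      (my_event_t*)iree_hal_profile_event_ring_record_at(&ring, position);
  event->event_id = event_id;
  // ... fill the remaining record fields ...
}
iree_slim_mutex_unlock(&mutex);

// Flush path: snapshot under the lock, write outside it, commit under it.
iree_hal_profile_event_ring_snapshot_t snapshot;
iree_slim_mutex_lock(&mutex);
iree_status_t status = iree_hal_profile_event_ring_snapshot(&ring, &snapshot);
iree_slim_mutex_unlock(&mutex);
IREE_RETURN_IF_ERROR(status);
IREE_RETURN_IF_ERROR(
    my_sink_write(snapshot.record_span_count, snapshot.record_spans));
iree_slim_mutex_lock(&mutex);
iree_hal_profile_event_ring_commit_snapshot(&ring, &snapshot);
iree_slim_mutex_unlock(&mutex);
```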
diff --git a/runtime/src/iree/schemas/amdgpu_executable_def.fbs b/runtime/src/iree/schemas/amdgpu_executable_def.fbs
index e16bc56..d2fea40 100644
--- a/runtime/src/iree/schemas/amdgpu_executable_def.fbs
+++ b/runtime/src/iree/schemas/amdgpu_executable_def.fbs
@@ -51,8 +51,11 @@
 }
 
 table ExecutableDef {
-  // Canonical ISA name this executable has been compiled for.
-  // E.g. `amdgcn-amd-amdhsa--gfx1100`
+  // AMDGPU target ID this executable has been compiled for.
+  //
+  // E.g. `gfx1100` or `gfx942:sramecc+:xnack-`. Runtime compatibility checks
+  // recover the load target from embedded ELF module flags; this field is kept
+  // as producer metadata and for older tools that inspect the flatbuffer.
   isa:string;
 
   // Exported functions in canonical executable entry point order.
diff --git a/runtime/src/iree/tooling/device_util.c b/runtime/src/iree/tooling/device_util.c
index 364b0e2..d35afd5 100644
--- a/runtime/src/iree/tooling/device_util.c
+++ b/runtime/src/iree/tooling/device_util.c
@@ -450,7 +450,8 @@
     "HAL device profiling data families as a comma-separated list drawn from\n"
     "['queue-events', 'host-execution', 'device-queue-events',\n"
     "'dispatch-events', 'memory-events', 'device-metrics',\n"
-    "'command-region-events', 'counters', 'executable-metadata',\n"
+    "'command-region-events', 'counters', 'counter-ranges',\n"
+    "'executable-metadata',\n"
     "'executable-traces'] or empty to disable profiling. HAL implementations\n"
     "may require additional flags in order to configure profiling support on\n"
     "their devices. Tooling may force VM-created command buffers to retain\n"
@@ -488,9 +489,13 @@
     "Optional implementation-specific hardware counter name to capture. May "
     "be\n"
     "specified multiple times; the selected HAL driver decides which counter\n"
-    "names and combinations are supported. Some backends collect counters by\n"
-    "injecting profiling packets around selected dispatches, which perturbs\n"
-    "queue timing even though it enables per-dispatch attribution.");
+    "names and combinations are supported. Use "
+    "--device_profiling_mode=counters\n"
+    "for dispatch-scoped attribution, which may inject packets around "
+    "selected\n"
+    "dispatches and perturb queue timing. Use\n"
+    "--device_profiling_mode=counter-ranges for low-disturbance range "
+    "samples.");
 IREE_FLAG(
     int64_t, device_profiling_flush_interval_ms, 0,
     "Optional interval in milliseconds for a tooling-owned background thread\n"
@@ -662,6 +667,9 @@
           IREE_HAL_DEVICE_PROFILING_DATA_COMMAND_REGION_EVENTS;
     } else if (iree_string_view_equal(family_part, IREE_SV("counters"))) {
       *out_data_families |= IREE_HAL_DEVICE_PROFILING_DATA_COUNTER_SAMPLES;
+    } else if (iree_string_view_equal(family_part, IREE_SV("counter-ranges")) ||
+               iree_string_view_equal(family_part, IREE_SV("pmc-ranges"))) {
+      *out_data_families |= IREE_HAL_DEVICE_PROFILING_DATA_COUNTER_RANGES;
     } else if (iree_string_view_equal(family_part,
                                       IREE_SV("executable-metadata"))) {
       *out_data_families |= IREE_HAL_DEVICE_PROFILING_DATA_EXECUTABLE_METADATA;
@@ -894,7 +902,8 @@
     if (counter_names.count != 0) {
       return iree_make_status(IREE_STATUS_INVALID_ARGUMENT,
                               "--device_profiling_counter requires "
-                              "--device_profiling_mode=counters");
+                              "--device_profiling_mode=counters or "
+                              "--device_profiling_mode=counter-ranges");
     }
     if (strlen(FLAG_device_profiling_output) != 0) {
       return iree_make_status(
@@ -922,10 +931,12 @@
   }
   if (counter_names.count != 0 &&
       !iree_any_bit_set(options.data_families,
-                        IREE_HAL_DEVICE_PROFILING_DATA_COUNTER_SAMPLES)) {
+                        IREE_HAL_DEVICE_PROFILING_DATA_COUNTER_SAMPLES |
+                            IREE_HAL_DEVICE_PROFILING_DATA_COUNTER_RANGES)) {
     return iree_make_status(
         IREE_STATUS_INVALID_ARGUMENT,
-        "--device_profiling_counter requires --device_profiling_mode=counters");
+        "--device_profiling_counter requires --device_profiling_mode=counters "
+        "or --device_profiling_mode=counter-ranges");
   }
   if (options.data_families != IREE_HAL_DEVICE_PROFILING_DATA_NONE &&
       strlen(FLAG_device_profiling_output) == 0 && !statistics_requested) {
diff --git a/runtime/src/iree/tooling/profile/help.c b/runtime/src/iree/tooling/profile/help.c
index a269ac8..6d259ea 100644
--- a/runtime/src/iree/tooling/profile/help.c
+++ b/runtime/src/iree/tooling/profile/help.c
@@ -600,6 +600,16 @@
       "         map({key,counter,samples,avg,sum})'\n"
       "```\n"
       "\n"
+      "Capture low-disturbance PMC ranges without dispatch attribution:\n"
+      "\n"
+      "```bash\n"
+      "iree-benchmark-module --device=amdgpu --module=model.vmfb \\\n"
+      "  --function=main --benchmark_min_time=20x \\\n"
+      "  --device_profiling_mode=counter-ranges \\\n"
+      "  --device_profiling_counter=SQ_WAVES \\\n"
+      "  --device_profiling_output=/tmp/model.ireeprof\n"
+      "```\n"
+      "\n"
       "Show raw counter samples for one dispatch event:\n"
       "\n"
       "```bash\n"
diff --git a/runtime/src/iree/tooling/profile/render/perfetto.py b/runtime/src/iree/tooling/profile/render/perfetto.py
index 43e586f..acbff79 100644
--- a/runtime/src/iree/tooling/profile/render/perfetto.py
+++ b/runtime/src/iree/tooling/profile/render/perfetto.py
@@ -91,8 +91,11 @@
     diagnostic_instants: int = 0
     counter_samples: int = 0
     dispatch_scoped_counter_values: int = 0
+    range_counter_values: int = 0
     skipped_dispatch_scoped_counter_samples: int = 0
     skipped_dispatch_scoped_counter_values: int = 0
+    skipped_range_counter_samples: int = 0
+    skipped_range_counter_values: int = 0
     device_metric_counter_values: int = 0
     device_metric_partial_samples: int = 0
     skipped_device_metric_samples: int = 0
@@ -728,7 +731,10 @@
         elif record_type == "clock_correlation":
             self.collect_clock_correlation(record)
         elif record_type == "counter_sample":
-            self.collect_dispatch_scoped_counter_sample(record)
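+            # Range-scoped samples carry their own device tick range rather
+            # than attaching to a dispatch event, so route them separately.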
+            if record.get("scope") == "device_time_range":
+                self.collect_range_counter_sample(record)
+            else:
+                self.collect_dispatch_scoped_counter_sample(record)
         elif record_type == "device_metric_sample":
             self.collect_device_metric_sample(record)
         elif record_type == "diagnostic":
@@ -1146,7 +1152,7 @@
 
         emitted_value_count = 0
         for counter in counters:
-            counter_value = self.dispatch_scoped_counter_value(values, counter)
+            counter_value = self.counter_sample_value(values, counter)
             if counter_value is None:
                 self.stats.skipped_dispatch_scoped_counter_values += 1
                 continue
@@ -1188,7 +1194,7 @@
         if emitted_value_count == 0:
             self.stats.skipped_dispatch_scoped_counter_samples += 1
 
-    def dispatch_scoped_counter_value(
+    def counter_sample_value(
         self, values: list[Any], counter: dict[str, Any]
     ) -> int | None:
         offset = parse_integer(counter.get("sample_value_offset", 0))
@@ -1200,6 +1206,116 @@
             value_sum += parse_integer(value)
         return value_sum
 
+    def collect_range_counter_sample(self, record: dict[str, Any]) -> None:
+        values = record.get("values", [])
+        counter_set_id = parse_integer(record.get("counter_set_id", 0))
+        counter_set = self.counter_sets_by_id.get(counter_set_id)
+        counters = self.counters_by_counter_set_id.get(counter_set_id, [])
+        if (
+            record.get("device_tick_range_valid") is False
+            or not isinstance(values, list)
+            or counter_set is None
+            or not counters
+        ):
+            self.stats.skipped_range_counter_samples += 1
+            return
+        expected_value_count = parse_integer(
+            counter_set.get("sample_value_count", len(values))
+        )
+        if expected_value_count != len(values):
+            self.stats.skipped_range_counter_samples += 1
+            return
+        host_range = device_event_host_time_range(record, self.clock_mappers)
+        if host_range is None:
+            self.stats.skipped_range_counter_samples += 1
+            return
+        start_time_ns, end_time_ns, time_domain = host_range
+        normalized_range = normalized_time_range(start_time_ns, end_time_ns)
+        if normalized_range is None:
+            self.stats.skipped_range_counter_samples += 1
+            return
+        start_time_ns, end_time_ns = normalized_range
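+        # Perfetto counter tracks take point samples, so collapse the captured
+        # range to its midpoint instant (annotated below as
+        # range_counter_sample_midpoint).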
+        event_time_ns = (start_time_ns + end_time_ns) // 2
+        physical_device_ordinal, _ = queue_key(record)
+
+        emitted_value_count = 0
+        for counter in counters:
+            counter_value = self.counter_sample_value(values, counter)
+            if counter_value is None:
+                self.stats.skipped_range_counter_values += 1
+                continue
+            counter_ordinal = parse_ordinal(counter.get("counter_ordinal"))
+            track_uuid = self.define_range_counter_track(
+                physical_device_ordinal, counter_set_id, counter_set, counter
+            )
+            annotations = self.range_counter_annotations(
+                record, counter_set, counter, counter_value, time_domain
+            )
+            self.timeline_events.append(
+                TimelineEvent(
+                    event_time_ns,
+                    (
+                        2,
+                        "range-counter",
+                        physical_device_ordinal,
+                        counter_set_id,
+                        counter_ordinal,
+                        event_time_ns,
+                        parse_integer(record.get("sample_id", 0)),
+                    ),
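+                    # Default-argument binding snapshots this iteration's
+                    # values; a plain closure would observe only the final
+                    # loop iteration when the deferred emit runs.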
+                    lambda timestamp_ns=event_time_ns, track_uuid=track_uuid, value=counter_value, annotations=annotations: (
+                        add_counter(
+                            self.builder,
+                            self.perfetto.track_event,
+                            timestamp_ns,
+                            track_uuid,
+                            value,
+                            annotations,
+                        )
+                    ),
+                )
+            )
+            self.all_timestamp_ns.append(event_time_ns)
+            self.stats.counter_samples += 1
+            self.stats.range_counter_values += 1
+            emitted_value_count += 1
+        if emitted_value_count == 0:
+            self.stats.skipped_range_counter_samples += 1
+
+    def range_counter_annotations(
+        self,
+        sample_record: dict[str, Any],
+        counter_set: dict[str, Any],
+        counter: dict[str, Any],
+        counter_value: int,
+        time_domain: str,
+    ) -> dict[str, Any]:
+        return {
+            "iree_counter_sample_id": sample_record.get("sample_id"),
+            "iree_counter_set_id": sample_record.get("counter_set_id"),
+            "iree_counter_set_name": counter_set.get("name"),
+            "iree_counter_ordinal": counter.get("counter_ordinal"),
+            "iree_counter_name": counter.get("name"),
+            "iree_counter_block": counter.get("block"),
+            "iree_counter_unit": counter.get("unit"),
+            "iree_counter_value_aggregation": "sum",
+            "iree_counter_value": counter_value,
+            "iree_counter_raw_value_offset": counter.get("sample_value_offset"),
+            "iree_counter_raw_value_count": counter.get("sample_value_count"),
+            "iree_counter_sample_scope": sample_record.get("scope"),
+            "iree_counter_sample_scope_value": sample_record.get("scope_value"),
+            "iree_counter_sample_flags": sample_record.get("flags"),
+            "iree_physical_device_ordinal": sample_record.get(
+                "physical_device_ordinal"
+            ),
+            "iree_queue_ordinal": sample_record.get("queue_ordinal"),
+            "iree_stream_id": sample_record.get("stream_id"),
+            "iree_duration_ticks": sample_record.get("duration_ticks"),
+            "iree_duration_ns": sample_record.get("duration_ns"),
+            "iree_perfetto_time_domain": time_domain,
+            "iree_perfetto_timing_source": "range_counter_sample_midpoint",
+        }
+
     def dispatch_scoped_counter_annotations(
         self,
         sample_record: dict[str, Any],
@@ -1318,6 +1434,71 @@
         )
         return track_uuid
 
+    def ensure_range_counter_group_track(self, physical_device_ordinal: int) -> int:
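+        # Range counter tracks nest device -> "range counters" group ->
+        # counter set -> individual counter so all range samples group under
+        # a single child of the device track.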
+        device_uuid = self.ensure_device_track(physical_device_ordinal)
+        track_uuid = self.tracks.uuid("iree", "range-counters", physical_device_ordinal)
+        self.tracks.define(
+            track_uuid,
+            "range counters",
+            parent_uuid=device_uuid,
+            sibling_order_rank=8100,
+            explicit_child_order=True,
+        )
+        return track_uuid
+
+    def define_range_counter_set_track(
+        self,
+        physical_device_ordinal: int,
+        counter_set_id: int,
+        counter_set: dict[str, Any],
+    ) -> int:
+        group_uuid = self.ensure_range_counter_group_track(physical_device_ordinal)
+        track_uuid = self.tracks.uuid(
+            "iree", "range-counter-set", physical_device_ordinal, counter_set_id
+        )
+        counter_set_name = str(
+            counter_set.get("name") or f"counter set {counter_set_id}"
+        )
+        self.tracks.define(
+            track_uuid,
+            counter_set_name,
+            parent_uuid=group_uuid,
+            sibling_order_rank=min(counter_set_id, 999999),
+            explicit_child_order=True,
+        )
+        return track_uuid
+
+    def define_range_counter_track(
+        self,
+        physical_device_ordinal: int,
+        counter_set_id: int,
+        counter_set: dict[str, Any],
+        counter: dict[str, Any],
+    ) -> int:
+        counter_set_uuid = self.define_range_counter_set_track(
+            physical_device_ordinal, counter_set_id, counter_set
+        )
+        counter_ordinal = parse_ordinal(counter.get("counter_ordinal"))
+        track_uuid = self.tracks.uuid(
+            "iree",
+            "range-counter",
+            physical_device_ordinal,
+            counter_set_id,
+            counter_ordinal,
+        )
+        counter_name = str(counter.get("name") or f"counter[{counter_ordinal}]")
+        block_name = str(counter.get("block") or "")
+        if block_name and not counter_name.startswith(f"{block_name}_"):
+            counter_name = f"{block_name}.{counter_name}"
+        self.tracks.define(
+            track_uuid,
+            counter_name,
+            parent_uuid=counter_set_uuid,
+            sibling_order_rank=min(counter_ordinal, 999999),
+            is_counter=True,
+        )
+        return track_uuid
+
     def collect_device_metric_sample(self, record: dict[str, Any]) -> None:
         event_time_ns = timestamp_midpoint(record)
         values = record.get("values", [])
@@ -1728,6 +1909,7 @@
         ("clock_instants", stats.clock_instants),
         ("counter_samples", stats.counter_samples),
         ("dispatch_scoped_counter_values", stats.dispatch_scoped_counter_values),
+        ("range_counter_values", stats.range_counter_values),
         (
             "skipped_dispatch_scoped_counter_samples",
             stats.skipped_dispatch_scoped_counter_samples,
@@ -1736,6 +1918,8 @@
             "skipped_dispatch_scoped_counter_values",
             stats.skipped_dispatch_scoped_counter_values,
         ),
+        ("skipped_range_counter_samples", stats.skipped_range_counter_samples),
+        ("skipped_range_counter_values", stats.skipped_range_counter_values),
         ("device_metric_counter_values", stats.device_metric_counter_values),
         ("device_metric_partial_samples", stats.device_metric_partial_samples),
         ("skipped_device_metric_samples", stats.skipped_device_metric_samples),
diff --git a/runtime/src/iree/tooling/profile/render_test.py b/runtime/src/iree/tooling/profile/render_test.py
index b4fe1c6..731ef32 100644
--- a/runtime/src/iree/tooling/profile/render_test.py
+++ b/runtime/src/iree/tooling/profile/render_test.py
@@ -465,6 +465,103 @@
         self.assertIn("iree_counter_value_aggregation", annotation_names)
         self.assertIn("iree_perfetto_timing_source", annotation_names)
 
+    def test_build_trace_projects_range_counters_on_separate_tracks(self):
+        builder = _FakeTraceProtoBuilder()
+        trace_bytes, stats = perfetto.build_trace(
+            [
+                {"record_type": "device", "physical_device_ordinal": 0},
+                {
+                    "record_type": "queue",
+                    "physical_device_ordinal": 0,
+                    "queue_ordinal": 0,
+                    "stream_id": 0,
+                },
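+                # Two clock correlations bracket the sample's tick range so
+                # the renderer can map device ticks onto host nanoseconds.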
+                {
+                    "record_type": "clock_correlation",
+                    "physical_device_ordinal": 0,
+                    "device_tick": 100,
+                    "host_time_begin_ns": 1000,
+                    "host_time_end_ns": 1000,
+                },
+                {
+                    "record_type": "clock_correlation",
+                    "physical_device_ordinal": 0,
+                    "device_tick": 200,
+                    "host_time_begin_ns": 2000,
+                    "host_time_end_ns": 2000,
+                },
+                {
+                    "record_type": "counter_set",
+                    "counter_set_id": 1,
+                    "physical_device_ordinal": 0,
+                    "counter_count": 1,
+                    "sample_value_count": 2,
+                    "name": "amdgpu.pmc",
+                },
+                {
+                    "record_type": "counter",
+                    "counter_set_id": 1,
+                    "counter_ordinal": 0,
+                    "physical_device_ordinal": 0,
+                    "unit": 1,
+                    "sample_value_offset": 0,
+                    "sample_value_count": 2,
+                    "block": "SQ",
+                    "name": "SQ_WAVES",
+                },
+                {
+                    "record_type": "counter_sample",
+                    "sample_id": 7,
+                    "counter_set_id": 1,
+                    "scope": "device_time_range",
+                    "physical_device_ordinal": 0,
+                    "queue_ordinal": 0,
+                    "stream_id": 0,
+                    "flags": 2,
+                    "device_tick_range_present": True,
+                    "start_tick": 110,
+                    "end_tick": 120,
+                    "duration_ticks": 10,
+                    "device_tick_range_valid": True,
+                    "duration_ns": 100,
+                    "values": [1, 2],
+                },
+            ],
+            perfetto.PerfettoImports(
+                trace_proto_builder=lambda: builder,
+                track_descriptor=_FakeTrackDescriptor,
+                track_event=_FakeTrackEvent,
+            ),
+        )
+
+        counter_packets = [
+            packet
+            for packet in builder.packets
+            if getattr(packet.track_event, "type", None) == _FakeTrackEvent.TYPE_COUNTER
+        ]
+        track_names = [
+            getattr(packet.track_descriptor, "name", "")
+            for packet in builder.packets
+            if hasattr(packet.track_descriptor, "name")
+        ]
+        annotation_values = [
+            annotation.string_value
+            for packet in counter_packets
+            for annotation in packet.track_event.debug_annotations
+            if getattr(annotation, "name", "") == "iree_perfetto_timing_source"
+        ]
+
+        self.assertTrue(trace_bytes.startswith(b"packets="))
+        self.assertEqual(1, stats.counter_samples)
+        self.assertEqual(0, stats.dispatch_scoped_counter_values)
+        self.assertEqual(1, stats.range_counter_values)
+        self.assertEqual(0, stats.skipped_range_counter_samples)
+        self.assertEqual(1, len(counter_packets))
+        self.assertEqual(3, counter_packets[0].track_event.counter_value)
+        self.assertIn("range counters", track_names)
+        self.assertNotIn("dispatch-scoped counters", track_names)
+        self.assertIn("range_counter_sample_midpoint", annotation_values)
+
     def test_build_trace_skips_unmapped_device_ticks_but_keeps_host_spans(self):
         trace_bytes, stats = perfetto.build_trace(
             [
diff --git a/third_party/hsa-runtime-headers b/third_party/hsa-runtime-headers
index 98297e8..cc2b5f4 160000
--- a/third_party/hsa-runtime-headers
+++ b/third_party/hsa-runtime-headers
@@ -1 +1 @@
-Subproject commit 98297e81ce23bbf302c296aa318fa66fbfd6200f
+Subproject commit cc2b5f429de4d1cb2be96ed10e6f45246e408d0e