Replacing hal.ex.shared_device with hal.devices.* ops. (#15916)

This bumps the HAL version as the existing hal.ex.shared_device op is
removed. The runtime HAL module still only reports a single device and
the compiler code is still selecting the "default" device (0). Future
changes will support multiple `--device=` flags to register multiple
available devices in tools and add attributes/lookup for mapping
stream/device affinity to device ordinals.

The plan is to have a `hal.devices.lookup` op that is hoisted into an
initializer to enumerate devices and store the selected ones in globals.
The enumeration code will emit a loop over `hal.devices.count` and use
`hal.devices.get` and the various device query methods to match the
requested devices.
diff --git a/compiler/src/iree/compiler/Bindings/Native/Transforms/WrapEntryPoints.cpp b/compiler/src/iree/compiler/Bindings/Native/Transforms/WrapEntryPoints.cpp
index 144a501..2d0e90e 100644
--- a/compiler/src/iree/compiler/Bindings/Native/Transforms/WrapEntryPoints.cpp
+++ b/compiler/src/iree/compiler/Bindings/Native/Transforms/WrapEntryPoints.cpp
@@ -130,8 +130,12 @@
     // HACK: this is relying on the fact that there's only one HAL device.
     // We should instead have a way of creating fences on the device that
     // is used to produce the tensors we're wrapping.
-    auto device =
-        entryBuilder.create<IREE::HAL::ExSharedDeviceOp>(importOp.getLoc());
+    //
+    // TODO(multi-device): emit get with derived ordinal or lookup with attr. We
+    // could always say device 0 for now but could instead look for an
+    // iree.abi.affinity/iree.abi.device/etc.
+    Value device =
+        IREE::HAL::DeviceType::resolveAny(importOp.getLoc(), entryBuilder);
 
     // When exporting a fence we need to put a barrier between the rest of the
     // program and the tensors consumed by the import.
diff --git a/compiler/src/iree/compiler/Bindings/Native/Transforms/test/wrap_entry_points_coarse_fences.mlir b/compiler/src/iree/compiler/Bindings/Native/Transforms/test/wrap_entry_points_coarse_fences.mlir
index c89451a..f6ff6d1 100644
--- a/compiler/src/iree/compiler/Bindings/Native/Transforms/test/wrap_entry_points_coarse_fences.mlir
+++ b/compiler/src/iree/compiler/Bindings/Native/Transforms/test/wrap_entry_points_coarse_fences.mlir
@@ -117,7 +117,7 @@
 // CHECK: func.func private @_import(%[[ARG0_TENSOR:.+]]: tensor<?x2xi32>, %[[ARG1_TENSOR:.+]]: tensor<?x3xi32>) -> (tensor<2x?xi32>, tensor<3x?xi32>) {
 
 // Prepare fences and put a barrier on input arguments:
-// CHECK:   %[[DEVICE:.+]] = hal.ex.shared_device
+// CHECK:   %[[DEVICE:.+]] = hal.devices.get %{{.+}}
 // CHECK:   %[[WAIT_FENCE:.+]] = hal.fence.create device(%[[DEVICE]]
 // CHECK:   %[[ARG_BARRIER:.+]]:2 = hal.tensor.barrier join(%[[ARG0_TENSOR]], %[[ARG1_TENSOR]] : tensor<?x2xi32>, tensor<?x3xi32>) => %[[WAIT_FENCE]] : !hal.fence
 // CHECK:   %[[SIGNAL_FENCE:.+]] = hal.fence.create device(%[[DEVICE]]
@@ -186,7 +186,7 @@
 // CHECK: func.func private @_importI32Effects(%[[ARG0_TENSOR:.+]]: tensor<4xf32>) -> i32 {
 
 // Wait for the inputs to be ready and create the signal fence to wait on.
-// CHECK:   %[[DEVICE:.+]] = hal.ex.shared_device
+// CHECK:   %[[DEVICE:.+]] = hal.devices.get %{{.+}}
 // CHECK:   %[[WAIT_FENCE:.+]] = hal.fence.create device(%[[DEVICE]]
 // CHECK:   %[[ARG0_BARRIER:.+]] = hal.tensor.barrier join(%[[ARG0_TENSOR]] : tensor<4xf32>) => %[[WAIT_FENCE]] : !hal.fence
 // CHECK:   %[[SIGNAL_FENCE:.+]] = hal.fence.create device(%[[DEVICE]]
diff --git a/compiler/src/iree/compiler/Codegen/SPIRV/test/link_executables.mlir b/compiler/src/iree/compiler/Codegen/SPIRV/test/link_executables.mlir
index 9e08de1..7b407a6 100644
--- a/compiler/src/iree/compiler/Codegen/SPIRV/test/link_executables.mlir
+++ b/compiler/src/iree/compiler/Codegen/SPIRV/test/link_executables.mlir
@@ -73,7 +73,8 @@
   testing.func.b = @dispatch_0::@spirv,
   testing.func.c = @dispatch_0::@spirv::@dispatch_0
 } {
-  %device = hal.ex.shared_device : !hal.device
+  %c0 = arith.constant 0 : index
+  %device = hal.devices.get %c0 : !hal.device
   %cmd = hal.command_buffer.create device(%device : !hal.device) mode("OneShot") categories("Transfer|Dispatch") : !hal.command_buffer attributes {
     testing.op.a = @dispatch_0,
     testing.op.b = @dispatch_0::@spirv,
@@ -86,7 +87,8 @@
   return
 }
 util.initializer {
-  %device = hal.ex.shared_device : !hal.device
+  %c0 = arith.constant 0 : index
+  %device = hal.devices.get %c0 : !hal.device
   %cmd = hal.command_buffer.create device(%device : !hal.device) mode("OneShot") categories("Transfer|Dispatch") : !hal.command_buffer
   %c1 = arith.constant 1 : index
   hal.command_buffer.dispatch.symbol<%cmd : !hal.command_buffer> target(@dispatch_0::@spirv::@dispatch_0) workgroups([%c1, %c1, %c1])
@@ -237,7 +239,8 @@
   }
 }
 func.func @two_target_environments_1() -> () {
-  %device = hal.ex.shared_device : !hal.device
+  %c0 = arith.constant 0 : index
+  %device = hal.devices.get %c0 : !hal.device
   %cmd = hal.command_buffer.create device(%device : !hal.device) mode("OneShot") categories("Transfer|Dispatch") : !hal.command_buffer
   %c1 = arith.constant 1 : index
   hal.command_buffer.dispatch.symbol<%cmd : !hal.command_buffer> target(@dispatch_0::@spirv::@dispatch_0) workgroups([%c1, %c1, %c1])
@@ -245,7 +248,8 @@
   return
 }
 func.func @two_target_environments_2() -> () {
-  %device = hal.ex.shared_device : !hal.device
+  %c0 = arith.constant 0 : index
+  %device = hal.devices.get %c0 : !hal.device
   %cmd = hal.command_buffer.create device(%device : !hal.device) mode("OneShot") categories("Transfer|Dispatch") : !hal.command_buffer
   %c1 = arith.constant 1 : index
   hal.command_buffer.dispatch.symbol<%cmd : !hal.command_buffer> target(@dispatch_1::@spirv::@dispatch_1) workgroups([%c1, %c1, %c1])
diff --git a/compiler/src/iree/compiler/Codegen/VMVX/test/link_executables.mlir b/compiler/src/iree/compiler/Codegen/VMVX/test/link_executables.mlir
index 98dcd6c..93425d8 100644
--- a/compiler/src/iree/compiler/Codegen/VMVX/test/link_executables.mlir
+++ b/compiler/src/iree/compiler/Codegen/VMVX/test/link_executables.mlir
@@ -72,7 +72,8 @@
   testing.func.b = @dispatch_0::@vmvx,
   testing.func.c = @dispatch_0::@vmvx::@dispatch_0
 } {
-  %device = hal.ex.shared_device : !hal.device
+  %c0 = arith.constant 0 : index
+  %device = hal.devices.get %c0 : !hal.device
   %cmd = hal.command_buffer.create device(%device : !hal.device) mode("OneShot") categories("Transfer|Dispatch") : !hal.command_buffer attributes {
     testing.op.a = @dispatch_0,
     testing.op.b = @dispatch_0::@vmvx,
@@ -85,7 +86,8 @@
   return
 }
 util.initializer {
-  %device = hal.ex.shared_device : !hal.device
+  %c0 = arith.constant 0 : index
+  %device = hal.devices.get %c0 : !hal.device
   %cmd = hal.command_buffer.create device(%device : !hal.device) mode("OneShot") categories("Transfer|Dispatch") : !hal.command_buffer
   %c1 = arith.constant 1 : index
   hal.command_buffer.dispatch.symbol<%cmd : !hal.command_buffer> target(@dispatch_0::@vmvx::@dispatch_0) workgroups([%c1, %c1, %c1])
diff --git a/compiler/src/iree/compiler/ConstEval/Runtime.cpp b/compiler/src/iree/compiler/ConstEval/Runtime.cpp
index 25d3c43..c5a93cf 100644
--- a/compiler/src/iree/compiler/ConstEval/Runtime.cpp
+++ b/compiler/src/iree/compiler/ConstEval/Runtime.cpp
@@ -440,9 +440,10 @@
   iree_hal_driver_release(driver);
 
   // Create hal module.
-  IREE_CHECK_OK(iree_hal_module_create(runtime.instance.get(), device.get(),
-                                       IREE_HAL_MODULE_FLAG_NONE,
-                                       iree_allocator_system(), &hal_module));
+  iree_hal_device_t *device_ptr = device.get();
+  IREE_CHECK_OK(iree_hal_module_create(
+      runtime.instance.get(), /*device_count=*/1, &device_ptr,
+      IREE_HAL_MODULE_FLAG_NONE, iree_allocator_system(), &hal_module));
 
   // Bytecode module.
   IREE_CHECK_OK(iree_vm_bytecode_module_create(
diff --git a/compiler/src/iree/compiler/Dialect/HAL/Conversion/HALToVM/BUILD.bazel b/compiler/src/iree/compiler/Dialect/HAL/Conversion/HALToVM/BUILD.bazel
index c630dc4..fcbb192 100644
--- a/compiler/src/iree/compiler/Dialect/HAL/Conversion/HALToVM/BUILD.bazel
+++ b/compiler/src/iree/compiler/Dialect/HAL/Conversion/HALToVM/BUILD.bazel
@@ -21,6 +21,7 @@
         "ConvertChannelOps.cpp",
         "ConvertCommandBufferOps.cpp",
         "ConvertDeviceOps.cpp",
+        "ConvertDevicesOps.cpp",
         "ConvertExecutableOps.cpp",
         "ConvertExperimentalOps.cpp",
         "ConvertFenceOps.cpp",
diff --git a/compiler/src/iree/compiler/Dialect/HAL/Conversion/HALToVM/CMakeLists.txt b/compiler/src/iree/compiler/Dialect/HAL/Conversion/HALToVM/CMakeLists.txt
index 5a5cb8a..713982b 100644
--- a/compiler/src/iree/compiler/Dialect/HAL/Conversion/HALToVM/CMakeLists.txt
+++ b/compiler/src/iree/compiler/Dialect/HAL/Conversion/HALToVM/CMakeLists.txt
@@ -22,6 +22,7 @@
     "ConvertChannelOps.cpp"
     "ConvertCommandBufferOps.cpp"
     "ConvertDeviceOps.cpp"
+    "ConvertDevicesOps.cpp"
     "ConvertExecutableOps.cpp"
     "ConvertExperimentalOps.cpp"
     "ConvertFenceOps.cpp"
diff --git a/compiler/src/iree/compiler/Dialect/HAL/Conversion/HALToVM/ConvertDevicesOps.cpp b/compiler/src/iree/compiler/Dialect/HAL/Conversion/HALToVM/ConvertDevicesOps.cpp
new file mode 100644
index 0000000..266233d
--- /dev/null
+++ b/compiler/src/iree/compiler/Dialect/HAL/Conversion/HALToVM/ConvertDevicesOps.cpp
@@ -0,0 +1,23 @@
+// Copyright 2023 The IREE Authors
+//
+// Licensed under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+
+#include "iree/compiler/Dialect/HAL/IR/HALOps.h"
+#include "iree/compiler/Dialect/VM/Conversion/ImportUtils.h"
+#include "mlir/Transforms/DialectConversion.h"
+
+namespace mlir::iree_compiler {
+
+void populateHALDevicesToVMPatterns(MLIRContext *context,
+                                    SymbolTable &importSymbols,
+                                    TypeConverter &typeConverter,
+                                    RewritePatternSet &patterns) {
+  patterns.insert<VMImportOpConversion<IREE::HAL::DevicesCountOp>>(
+      context, importSymbols, typeConverter, "hal.devices.count");
+  patterns.insert<VMImportOpConversion<IREE::HAL::DevicesGetOp>>(
+      context, importSymbols, typeConverter, "hal.devices.get");
+}
+
+} // namespace mlir::iree_compiler
diff --git a/compiler/src/iree/compiler/Dialect/HAL/Conversion/HALToVM/ConvertExperimentalOps.cpp b/compiler/src/iree/compiler/Dialect/HAL/Conversion/HALToVM/ConvertExperimentalOps.cpp
index e724412..b8bc13d 100644
--- a/compiler/src/iree/compiler/Dialect/HAL/Conversion/HALToVM/ConvertExperimentalOps.cpp
+++ b/compiler/src/iree/compiler/Dialect/HAL/Conversion/HALToVM/ConvertExperimentalOps.cpp
@@ -14,8 +14,6 @@
                                          SymbolTable &importSymbols,
                                          TypeConverter &typeConverter,
                                          RewritePatternSet &patterns) {
-  patterns.insert<VMImportOpConversion<IREE::HAL::ExSharedDeviceOp>>(
-      context, importSymbols, typeConverter, "hal.ex.shared_device");
   patterns.insert<VMImportOpConversion<IREE::HAL::ExFileFromMemoryOp>>(
       context, importSymbols, typeConverter, "hal.ex.file.from_memory");
 }
diff --git a/compiler/src/iree/compiler/Dialect/HAL/Conversion/HALToVM/Patterns.cpp b/compiler/src/iree/compiler/Dialect/HAL/Conversion/HALToVM/Patterns.cpp
index 983e2cf..592e78f 100644
--- a/compiler/src/iree/compiler/Dialect/HAL/Conversion/HALToVM/Patterns.cpp
+++ b/compiler/src/iree/compiler/Dialect/HAL/Conversion/HALToVM/Patterns.cpp
@@ -51,6 +51,10 @@
                                           SymbolTable &importSymbols,
                                           TypeConverter &typeConverter,
                                           RewritePatternSet &patterns);
+extern void populateHALDevicesToVMPatterns(MLIRContext *context,
+                                           SymbolTable &importSymbols,
+                                           TypeConverter &typeConverter,
+                                           RewritePatternSet &patterns);
 extern void populateHALExecutableToVMPatterns(MLIRContext *context,
                                               SymbolTable &importSymbols,
                                               TypeConverter &typeConverter,
@@ -79,6 +83,8 @@
                                        patterns);
   populateHALDeviceToVMPatterns(context, importSymbols, typeConverter,
                                 patterns);
+  populateHALDevicesToVMPatterns(context, importSymbols, typeConverter,
+                                 patterns);
   populateHALExecutableToVMPatterns(context, importSymbols, typeConverter,
                                     patterns);
   populateHALExperimentalToVMPatterns(context, importSymbols, typeConverter,
diff --git a/compiler/src/iree/compiler/Dialect/HAL/Conversion/HALToVM/test/BUILD.bazel b/compiler/src/iree/compiler/Dialect/HAL/Conversion/HALToVM/test/BUILD.bazel
index 0f80024..a486719 100644
--- a/compiler/src/iree/compiler/Dialect/HAL/Conversion/HALToVM/test/BUILD.bazel
+++ b/compiler/src/iree/compiler/Dialect/HAL/Conversion/HALToVM/test/BUILD.bazel
@@ -22,6 +22,7 @@
             "channel_ops.mlir",
             "command_buffer_ops.mlir",
             "device_ops.mlir",
+            "devices_ops.mlir",
             "executable_ops.mlir",
             "fence_ops.mlir",
         ],
diff --git a/compiler/src/iree/compiler/Dialect/HAL/Conversion/HALToVM/test/CMakeLists.txt b/compiler/src/iree/compiler/Dialect/HAL/Conversion/HALToVM/test/CMakeLists.txt
index 9ec645e..b65d9cf 100644
--- a/compiler/src/iree/compiler/Dialect/HAL/Conversion/HALToVM/test/CMakeLists.txt
+++ b/compiler/src/iree/compiler/Dialect/HAL/Conversion/HALToVM/test/CMakeLists.txt
@@ -20,6 +20,7 @@
     "channel_ops.mlir"
     "command_buffer_ops.mlir"
     "device_ops.mlir"
+    "devices_ops.mlir"
     "executable_ops.mlir"
     "fence_ops.mlir"
   TOOLS
diff --git a/compiler/src/iree/compiler/Dialect/HAL/Conversion/HALToVM/test/devices_ops.mlir b/compiler/src/iree/compiler/Dialect/HAL/Conversion/HALToVM/test/devices_ops.mlir
new file mode 100644
index 0000000..b8423ad
--- /dev/null
+++ b/compiler/src/iree/compiler/Dialect/HAL/Conversion/HALToVM/test/devices_ops.mlir
@@ -0,0 +1,18 @@
+// RUN: iree-opt --split-input-file --iree-convert-hal-to-vm --canonicalize --iree-vm-target-index-bits=32 %s | FileCheck %s
+
+// CHECK-LABEL: @devices_count
+func.func @devices_count() -> index {
+  // CHECK: = vm.call @hal.devices.count() {nosideeffects} : () -> i32
+  %device_count = hal.devices.count : index
+  return %device_count : index
+}
+
+// -----
+
+// CHECK-LABEL: @devices_get
+// CHECK-SAME: (%[[INDEX:.+]]: i32)
+func.func @devices_get(%index: index) -> !hal.device {
+  // CHECK: = vm.call @hal.devices.get(%[[INDEX]]) {nosideeffects} : (i32) -> !vm.ref<!hal.device>
+  %device = hal.devices.get %index : !hal.device
+  return %device : !hal.device
+}
diff --git a/compiler/src/iree/compiler/Dialect/HAL/Conversion/StreamToHAL/Patterns.cpp b/compiler/src/iree/compiler/Dialect/HAL/Conversion/StreamToHAL/Patterns.cpp
index 55994dd..8b34b92 100644
--- a/compiler/src/iree/compiler/Dialect/HAL/Conversion/StreamToHAL/Patterns.cpp
+++ b/compiler/src/iree/compiler/Dialect/HAL/Conversion/StreamToHAL/Patterns.cpp
@@ -51,9 +51,9 @@
     auto resultTypes = llvm::to_vector(resolveOp.getResultTypes());
     assert(!resultTypes.empty() && "must have at least one result");
 
-    // TODO(benvanik): make this do multi-device lookup and other fancy things.
+    // TODO(multi-device): emit get with derived ordinal or lookup with attr.
     Value device =
-        rewriter.create<IREE::HAL::ExSharedDeviceOp>(resolveOp.getLoc());
+        IREE::HAL::DeviceType::resolveAny(resolveOp.getLoc(), rewriter);
 
     SmallVector<Value> results;
     if (resultTypes[0].isa<IREE::HAL::DeviceType>()) {
@@ -678,7 +678,7 @@
     // Get the device handle we're executing against in this execution region.
     // Note that this is a dynamic value: we have to treat the device as unknown
     // here.
-    auto deviceValue = rewriter.create<IREE::HAL::CommandBufferDeviceOp>(
+    Value device = rewriter.create<IREE::HAL::CommandBufferDeviceOp>(
         loc, rewriter.getType<IREE::HAL::DeviceType>(), commandBuffer);
 
     // Prepare for variant switch table by gathering the conditions selecting
@@ -703,7 +703,7 @@
           auto exportOp = caseExportOps[i].second;
           auto variantOp =
               exportOp->getParentOfType<IREE::HAL::ExecutableVariantOp>();
-          return variantOp.buildCondition(deviceValue, rewriter);
+          return variantOp.buildCondition(device, rewriter);
         },
         rewriter);
 
@@ -718,12 +718,12 @@
       auto caseBuilder = OpBuilder::atBlockBegin(&caseBlock);
 
       // Record push constants and buffer bindings.
-      recordParameters(loc, deviceValue, commandBuffer, dispatchOp, adaptor,
+      recordParameters(loc, device, commandBuffer, dispatchOp, adaptor,
                        exportOp.getLayout(), caseBuilder);
 
       // Dispatch with a target-specific workgroup count.
       auto caseWorkgroupCount = exportOp.calculateWorkgroupCount(
-          loc, deviceValue, adaptor.getWorkload(), caseBuilder);
+          loc, device, adaptor.getWorkload(), caseBuilder);
       caseBuilder.create<IREE::HAL::CommandBufferDispatchSymbolOp>(
           loc, commandBuffer, entryPointAttr, caseWorkgroupCount[0],
           caseWorkgroupCount[1], caseWorkgroupCount[2]);
diff --git a/compiler/src/iree/compiler/Dialect/HAL/Conversion/StreamToHAL/test/channel_ops.mlir b/compiler/src/iree/compiler/Dialect/HAL/Conversion/StreamToHAL/test/channel_ops.mlir
index 9161467..d11abe8 100644
--- a/compiler/src/iree/compiler/Dialect/HAL/Conversion/StreamToHAL/test/channel_ops.mlir
+++ b/compiler/src/iree/compiler/Dialect/HAL/Conversion/StreamToHAL/test/channel_ops.mlir
@@ -3,7 +3,7 @@
 // CHECK-LABEL: @channel_create
 //  CHECK-SAME: () -> !hal.channel
 func.func @channel_create() -> !stream.channel {
-  // CHECK-DAG: %[[DEVICE:.+]] = hal.ex.shared_device : !hal.device
+  // CHECK-DAG: %[[DEVICE:.+]] = hal.devices.get %{{.+}} : !hal.device
   // CHECK-DAG: %[[AFFINITY:.+]] = arith.constant 3
   // CHECK-DAG: %[[ID:.+]] = util.null : !util.buffer
   // CHECK-DAG: %[[GROUP:.+]] = util.buffer.constant : !util.buffer = "group"
diff --git a/compiler/src/iree/compiler/Dialect/HAL/Conversion/StreamToHAL/test/context_ops.mlir b/compiler/src/iree/compiler/Dialect/HAL/Conversion/StreamToHAL/test/context_ops.mlir
index 659c11b..dd43362 100644
--- a/compiler/src/iree/compiler/Dialect/HAL/Conversion/StreamToHAL/test/context_ops.mlir
+++ b/compiler/src/iree/compiler/Dialect/HAL/Conversion/StreamToHAL/test/context_ops.mlir
@@ -2,7 +2,7 @@
 
 // CHECK-LABEL: @contextResolveAllocator
 func.func @contextResolveAllocator() -> !hal.allocator {
-  // CHECK: %[[DEVICE:.+]] = hal.ex.shared_device
+  // CHECK: %[[DEVICE:.+]] = hal.devices.get %{{.+}}
   // CHECK: %[[ALLOCATOR:.+]] = hal.device.allocator<%[[DEVICE]] : !hal.device> : !hal.allocator
   %allocator = stream.context.resolve : !hal.allocator
   // CHECK: return %[[ALLOCATOR]]
@@ -13,7 +13,7 @@
 
 // CHECK-LABEL: @contextResolveDevice
 func.func @contextResolveDevice() -> !hal.device {
-  // CHECK: %[[DEVICE:.+]] = hal.ex.shared_device
+  // CHECK: %[[DEVICE:.+]] = hal.devices.get %{{.+}}
   %device = stream.context.resolve : !hal.device
   // CHECK: return %[[DEVICE]]
   return %device : !hal.device
@@ -23,7 +23,7 @@
 
 // CHECK-LABEL: @contextResolveDeviceQueueAffinityAny
 func.func @contextResolveDeviceQueueAffinityAny() -> (!hal.device, i64) {
-  // CHECK-DAG: %[[DEVICE:.+]] = hal.ex.shared_device
+  // CHECK-DAG: %[[DEVICE:.+]] = hal.devices.get %{{.+}}
   // CHECK-DAG: %[[QUEUE_AFFINITY:.+]] = arith.constant -1 : i64
   %device, %queue_affinity_any = stream.context.resolve on(#hal.affinity.queue<*>) : !hal.device, i64
   // CHECK: return %[[DEVICE]], %[[QUEUE_AFFINITY]]
@@ -34,7 +34,7 @@
 
 // CHECK-LABEL: @contextResolveDeviceQueueAffinity45
 func.func @contextResolveDeviceQueueAffinity45() -> (!hal.device, i64) {
-  // CHECK: %[[DEVICE:.+]] = hal.ex.shared_device
+  // CHECK: %[[DEVICE:.+]] = hal.devices.get %{{.+}}
   // CHECK-DAG: %[[QUEUE_AFFINITY:.+]] = arith.constant 48 : i64
   %device, %queue_affinity_45 = stream.context.resolve on(#hal.affinity.queue<[4, 5]>) : !hal.device, i64
   // CHECK: return %[[DEVICE]], %[[QUEUE_AFFINITY]]
diff --git a/compiler/src/iree/compiler/Dialect/HAL/Conversion/StreamToHAL/test/file_ops.mlir b/compiler/src/iree/compiler/Dialect/HAL/Conversion/StreamToHAL/test/file_ops.mlir
index 2ac1385..9473df7 100644
--- a/compiler/src/iree/compiler/Dialect/HAL/Conversion/StreamToHAL/test/file_ops.mlir
+++ b/compiler/src/iree/compiler/Dialect/HAL/Conversion/StreamToHAL/test/file_ops.mlir
@@ -5,7 +5,8 @@
 func.func @file_constant(%buffer: !util.buffer) {
   %c0 = arith.constant 0 : index
   %c1088 = arith.constant 1088 : index
-  // CHECK: = hal.ex.file.from_memory device(%device : !hal.device) affinity(%c-1_i64) access(Read) buffer(%[[BUFFER]] : !util.buffer)[%c0 for %c1088] flags(%c0_i32) : !hal.file
+  // CHECK: %[[DEVICE:.+]] = hal.devices.get %{{.+}}
+  // CHECK: = hal.ex.file.from_memory device(%[[DEVICE]] : !hal.device) affinity(%c-1_i64) access(Read) buffer(%[[BUFFER]] : !util.buffer)[%c0 for %c1088] flags(%c0_i32) : !hal.file
   %file = stream.file.constant %buffer[%c0 for %c1088] : !util.buffer{%c1088} -> !stream.file
   return
 }
@@ -18,8 +19,9 @@
   %c0 = arith.constant 0 : index
   %c0_i64 = arith.constant 0 : i64
   %c1088 = arith.constant 1088 : index
+  // CHECK: %[[DEVICE:.+]] = hal.devices.get %{{.+}}
   // CHECK: %[[SIGNAL:.+]] = hal.fence.create
-  // CHECK: hal.device.queue.read<%device : !hal.device> affinity(%c-1_i64) wait(%[[WAIT]]) signal(%[[SIGNAL]]) source(%[[FILE]] : !hal.file)[%c0_i64] target(%[[RESOURCE]] : !hal.buffer)[%c0] length(%c1088) flags(0)
+  // CHECK: hal.device.queue.read<%[[DEVICE]] : !hal.device> affinity(%c-1_i64) wait(%[[WAIT]]) signal(%[[SIGNAL]]) source(%[[FILE]] : !hal.file)[%c0_i64] target(%[[RESOURCE]] : !hal.buffer)[%c0] length(%c1088) flags(0)
   %signal = stream.file.read await(%wait) => %file[%c0_i64], %resource[%c0], %c1088 : !stream.file -> !stream.resource<variable>{%c1088} => !stream.timepoint
   // CHECK: return %[[SIGNAL]]
   return %signal : !stream.timepoint
@@ -33,8 +35,9 @@
   %c0 = arith.constant 0 : index
   %c0_i64 = arith.constant 0 : i64
   %c1088 = arith.constant 1088 : index
+  // CHECK: %[[DEVICE:.+]] = hal.devices.get %{{.+}}
   // CHECK: %[[SIGNAL:.+]] = hal.fence.create
-  // CHECK: hal.device.queue.write<%device : !hal.device> affinity(%c-1_i64) wait(%[[WAIT]]) signal(%[[SIGNAL]]) source(%[[RESOURCE]] : !hal.buffer)[%c0] target(%[[FILE]] : !hal.file)[%c0_i64] length(%c1088) flags(0)
+  // CHECK: hal.device.queue.write<%[[DEVICE]] : !hal.device> affinity(%c-1_i64) wait(%[[WAIT]]) signal(%[[SIGNAL]]) source(%[[RESOURCE]] : !hal.buffer)[%c0] target(%[[FILE]] : !hal.file)[%c0_i64] length(%c1088) flags(0)
   %signal = stream.file.write await(%wait) => %resource[%c0], %file[%c0_i64], %c1088 : !stream.resource<variable>{%c1088} -> !stream.file => !stream.timepoint
   // CHECK: return %[[SIGNAL]]
   return %signal : !stream.timepoint
diff --git a/compiler/src/iree/compiler/Dialect/HAL/Conversion/StreamToHAL/test/timepoint_ops.mlir b/compiler/src/iree/compiler/Dialect/HAL/Conversion/StreamToHAL/test/timepoint_ops.mlir
index 7d71e1d..cca49b1 100644
--- a/compiler/src/iree/compiler/Dialect/HAL/Conversion/StreamToHAL/test/timepoint_ops.mlir
+++ b/compiler/src/iree/compiler/Dialect/HAL/Conversion/StreamToHAL/test/timepoint_ops.mlir
@@ -45,7 +45,7 @@
 // CHECK-LABEL: @timepointChainExternal
 //  CHECK-SAME: (%[[TIMEPOINT:.+]]: !hal.fence, %[[SIGNAL:.+]]: !hal.fence)
 func.func @timepointChainExternal(%timepoint: !stream.timepoint, %signal: !hal.fence) {
-  // CHECK: %[[DEVICE:.+]] = hal.ex.shared_device
+  // CHECK: %[[DEVICE:.+]] = hal.devices.get %{{.+}}
   // CHECK: hal.device.queue.execute<%[[DEVICE]] : !hal.device> affinity(%c-1_i64) wait(%[[TIMEPOINT]]) signal(%[[SIGNAL]])
   stream.timepoint.chain_external %timepoint => (%signal : !hal.fence)
   return
diff --git a/compiler/src/iree/compiler/Dialect/HAL/IR/HALOps.cpp b/compiler/src/iree/compiler/Dialect/HAL/IR/HALOps.cpp
index 3711768..2abfd23 100644
--- a/compiler/src/iree/compiler/Dialect/HAL/IR/HALOps.cpp
+++ b/compiler/src/iree/compiler/Dialect/HAL/IR/HALOps.cpp
@@ -332,11 +332,6 @@
 // hal.ex.*
 //===----------------------------------------------------------------------===//
 
-void ExSharedDeviceOp::getAsmResultNames(
-    function_ref<void(Value, StringRef)> setNameFn) {
-  setNameFn(getResult(), "device");
-}
-
 void ExFileFromMemoryOp::getAsmResultNames(
     function_ref<void(Value, StringRef)> setNameFn) {
   setNameFn(getResult(), "memory_file");
@@ -924,6 +919,27 @@
 }
 
 //===----------------------------------------------------------------------===//
+// hal.devices.*
+//===----------------------------------------------------------------------===//
+
+void DevicesCountOp::getAsmResultNames(
+    function_ref<void(Value, StringRef)> setNameFn) {
+  setNameFn(getResult(), "device_count"); // suggested SSA name for the count
+}
+
+void DevicesGetOp::getAsmResultNames(
+    function_ref<void(Value, StringRef)> setNameFn) {
+  APInt ordinal;
+  if (!matchPattern(getIndex(), m_ConstantInt(&ordinal))) {
+    setNameFn(getResult(), "device_n"); // index is not a constant integer
+    return;
+  }
+  llvm::SmallString<16> name("device_");
+  ordinal.toStringUnsigned(name); // appends the constant index as decimal
+  setNameFn(getResult(), name);
+}
+
+//===----------------------------------------------------------------------===//
 // hal.executable.source
 //===----------------------------------------------------------------------===//
 
diff --git a/compiler/src/iree/compiler/Dialect/HAL/IR/HALOps.td b/compiler/src/iree/compiler/Dialect/HAL/IR/HALOps.td
index 50414a6..059e451 100644
--- a/compiler/src/iree/compiler/Dialect/HAL/IR/HALOps.td
+++ b/compiler/src/iree/compiler/Dialect/HAL/IR/HALOps.td
@@ -37,24 +37,6 @@
 
 let opDocGroup = OpGroupExperimentalOps in {
 
-def HAL_ExSharedDeviceOp : HAL_PureOp<"ex.shared_device", [
-    DeclareOpInterfaceMethods<OpAsmOpInterface, ["getAsmResultNames"]>,
-  ]> {
-  let results = (outs
-    HAL_Device:$result
-  );
-
-  let assemblyFormat = "attr-dict `:` type($result)";
-
-  let skipDefaultBuilders = 1;
-  let builders = [
-    OpBuilder<(ins),
-    [{
-      $_state.addTypes({DeviceType::get($_builder.getContext())});
-    }]>,
-  ];
-}
-
 def HAL_ExFileFromMemoryOp : HAL_Op<"ex.file.from_memory", [
     DeclareOpInterfaceMethods<OpAsmOpInterface, ["getAsmResultNames"]>,
   ]> {
@@ -1864,6 +1846,57 @@
 } // OpGroupDeviceOps
 
 //===----------------------------------------------------------------------===//
+// !hal.device management
+//===----------------------------------------------------------------------===//
+
+def OpGroupDeviceManagementOps : OpDocGroup {
+  let summary = "Device management ops";
+  let description = "Device availability and selection support.";
+}
+
+let opDocGroup = OpGroupDeviceManagementOps in {
+
+def HAL_DevicesCountOp : HAL_PureOp<"devices.count", [
+  DeclareOpInterfaceMethods<OpAsmOpInterface, ["getAsmResultNames"]>,
+]> {
+  let summary = [{returns the number of available devices}];
+  let description = [{
+    Returns the total number of available devices registered at runtime.
+  }];
+
+  let results = (outs
+    Index:$result
+  );
+
+  let assemblyFormat = [{
+    attr-dict `:` type($result)
+  }];
+}
+
+def HAL_DevicesGetOp : HAL_PureOp<"devices.get", [
+  DeclareOpInterfaceMethods<OpAsmOpInterface, ["getAsmResultNames"]>,
+]> {
+  let summary = [{returns the device with the given index}];
+  let description = [{
+    Returns the device with the given index in the [0, hal.devices.count) range.
+    Devices may be lazily initialized upon first use.
+  }];
+
+  let arguments = (ins
+    Index:$index
+  );
+  let results = (outs
+    HAL_Device:$result
+  );
+
+  let assemblyFormat = [{
+    $index attr-dict `:` type($result)
+  }];
+}
+
+}  // OpGroupDeviceManagementOps
+
+//===----------------------------------------------------------------------===//
 // !hal.executable / iree_hal_executable_t
 //===----------------------------------------------------------------------===//
 
diff --git a/compiler/src/iree/compiler/Dialect/HAL/IR/HALTypes.cpp b/compiler/src/iree/compiler/Dialect/HAL/IR/HALTypes.cpp
index bcf3b84..9359e0f 100644
--- a/compiler/src/iree/compiler/Dialect/HAL/IR/HALTypes.cpp
+++ b/compiler/src/iree/compiler/Dialect/HAL/IR/HALTypes.cpp
@@ -218,6 +218,13 @@
           loc, builder.getType<IREE::HAL::BufferType>(), value));
 }
 
+// static: emits `hal.devices.get` on a constant index of 0.
+Value DeviceType::resolveAny(Location loc, OpBuilder &builder) {
+  auto deviceType = builder.getType<IREE::HAL::DeviceType>();
+  Value ordinal = builder.create<arith::ConstantIndexOp>(loc, 0);
+  return builder.create<IREE::HAL::DevicesGetOp>(loc, deviceType, ordinal);
+}
+
 //===----------------------------------------------------------------------===//
 // #hal.device.target
 //===----------------------------------------------------------------------===//
diff --git a/compiler/src/iree/compiler/Dialect/HAL/IR/HALTypes.h b/compiler/src/iree/compiler/Dialect/HAL/IR/HALTypes.h
index acdca66..963f404 100644
--- a/compiler/src/iree/compiler/Dialect/HAL/IR/HALTypes.h
+++ b/compiler/src/iree/compiler/Dialect/HAL/IR/HALTypes.h
@@ -125,6 +125,12 @@
   using Base::Base;
 
   static constexpr StringLiteral name = "hal.device";
+
+  // Resolves to any device at runtime.
+  // This is unlikely to be what any particular caller wants and should be
+  // avoided outside of testing/debugging passes that don't care about
+  // multi-targeting.
+  static Value resolveAny(Location loc, OpBuilder &builder);
 };
 
 struct EventType : public Type::TypeBase<EventType, Type, TypeStorage> {
diff --git a/compiler/src/iree/compiler/Dialect/HAL/IR/test/BUILD.bazel b/compiler/src/iree/compiler/Dialect/HAL/IR/test/BUILD.bazel
index 013db09..c8e6437 100644
--- a/compiler/src/iree/compiler/Dialect/HAL/IR/test/BUILD.bazel
+++ b/compiler/src/iree/compiler/Dialect/HAL/IR/test/BUILD.bazel
@@ -27,6 +27,7 @@
             "descriptor_set_ops.mlir",
             "device_folding.mlir",
             "device_ops.mlir",
+            "devices_ops.mlir",
             "executable_folding.mlir",
             "executable_ops.mlir",
             "executable_targets.mlir",
diff --git a/compiler/src/iree/compiler/Dialect/HAL/IR/test/CMakeLists.txt b/compiler/src/iree/compiler/Dialect/HAL/IR/test/CMakeLists.txt
index 55bd72b..f1fb349 100644
--- a/compiler/src/iree/compiler/Dialect/HAL/IR/test/CMakeLists.txt
+++ b/compiler/src/iree/compiler/Dialect/HAL/IR/test/CMakeLists.txt
@@ -25,6 +25,7 @@
     "descriptor_set_ops.mlir"
     "device_folding.mlir"
     "device_ops.mlir"
+    "devices_ops.mlir"
     "executable_folding.mlir"
     "executable_ops.mlir"
     "executable_targets.mlir"
diff --git a/compiler/src/iree/compiler/Dialect/HAL/IR/test/command_buffer_folding.mlir b/compiler/src/iree/compiler/Dialect/HAL/IR/test/command_buffer_folding.mlir
index 392dc87..3e1b94d 100644
--- a/compiler/src/iree/compiler/Dialect/HAL/IR/test/command_buffer_folding.mlir
+++ b/compiler/src/iree/compiler/Dialect/HAL/IR/test/command_buffer_folding.mlir
@@ -1,18 +1,17 @@
 // RUN: iree-opt --split-input-file --canonicalize %s | iree-opt --split-input-file | FileCheck %s
 
 // CHECK-LABEL: @skip_command_buffer_device
-func.func @skip_command_buffer_device() -> !hal.executable {
-  // CHECK: %[[DEVICE:.+]] = hal.ex.shared_device
-  %dev = hal.ex.shared_device : !hal.device
-  %cmd = hal.command_buffer.create device(%dev : !hal.device)
+// CHECK-SAME: (%[[DEVICE:.+]]: !hal.device)
+func.func @skip_command_buffer_device(%device: !hal.device) -> !hal.executable {
+  %cmd = hal.command_buffer.create device(%device : !hal.device)
                                      mode(OneShot)
-                                     categories("Transfer|Dispatch") : !hal.command_buffer
+                               categories("Transfer|Dispatch") : !hal.command_buffer
 
   // CHECK-NOT: hal.command_buffer.device
   //      CHECK: = hal.executable.lookup device(%[[DEVICE]] : !hal.device)
   // CHECK-SAME:     executable(@executable_name) : !hal.executable
-  %0 = hal.command_buffer.device<%cmd : !hal.command_buffer> : !hal.device
-  %exe = hal.executable.lookup device(%dev : !hal.device)
+  %device2 = hal.command_buffer.device<%cmd : !hal.command_buffer> : !hal.device
+  %exe = hal.executable.lookup device(%device2 : !hal.device)
                            executable(@executable_name) : !hal.executable
 
   return %exe : !hal.executable
diff --git a/compiler/src/iree/compiler/Dialect/HAL/IR/test/devices_ops.mlir b/compiler/src/iree/compiler/Dialect/HAL/IR/test/devices_ops.mlir
new file mode 100644
index 0000000..633318b
--- /dev/null
+++ b/compiler/src/iree/compiler/Dialect/HAL/IR/test/devices_ops.mlir
@@ -0,0 +1,18 @@
+// RUN: iree-opt --split-input-file %s | iree-opt --split-input-file | FileCheck %s
+
+// CHECK-LABEL: @devices_count
+func.func @devices_count() -> index {
+  // CHECK: = hal.devices.count : index
+  %device_count = hal.devices.count : index
+  return %device_count : index
+}
+
+// -----
+
+// CHECK-LABEL: @devices_get
+// CHECK-SAME: (%[[INDEX:.+]]: index)
+func.func @devices_get(%index: index) -> !hal.device {
+  // CHECK: = hal.devices.get %[[INDEX]] : !hal.device
+  %device = hal.devices.get %index : !hal.device
+  return %device : !hal.device
+}
diff --git a/compiler/src/iree/compiler/Dialect/HAL/IR/test/experimental_ops.mlir b/compiler/src/iree/compiler/Dialect/HAL/IR/test/experimental_ops.mlir
index 3fb6585..3c8fd84 100644
--- a/compiler/src/iree/compiler/Dialect/HAL/IR/test/experimental_ops.mlir
+++ b/compiler/src/iree/compiler/Dialect/HAL/IR/test/experimental_ops.mlir
@@ -1,14 +1,5 @@
 // RUN: iree-opt --split-input-file %s | iree-opt --split-input-file | FileCheck %s
 
-// CHECK-LABEL: @shared_device
-func.func @shared_device() -> !hal.device {
-  // CHECK: %device = hal.ex.shared_device : !hal.device
-  %device = hal.ex.shared_device : !hal.device
-  return %device : !hal.device
-}
-
-// -----
-
 // CHECK-LABEL: @file_from_memory
 // CHECK-SAME: (%[[DEVICE:.+]]: !hal.device, %[[BUFFER:.+]]: !util.buffer)
 func.func @file_from_memory(%device: !hal.device, %buffer: !util.buffer) -> !hal.file {
diff --git a/compiler/src/iree/compiler/Dialect/HAL/Transforms/DumpExecutableBenchmarks.cpp b/compiler/src/iree/compiler/Dialect/HAL/Transforms/DumpExecutableBenchmarks.cpp
index 4a6f515..5fc02a5 100644
--- a/compiler/src/iree/compiler/Dialect/HAL/Transforms/DumpExecutableBenchmarks.cpp
+++ b/compiler/src/iree/compiler/Dialect/HAL/Transforms/DumpExecutableBenchmarks.cpp
@@ -173,8 +173,8 @@
   auto initBuilder = OpBuilder::atBlockBegin(initOp.addEntryBlock());
   IndexSet indexSet(loc, initBuilder);
 
-  // TODO(benvanik): real device lookup.
-  auto device = initBuilder.create<IREE::HAL::ExSharedDeviceOp>(loc);
+  // TODO(multi-device): support multiple devices in benchmark generation.
+  Value device = IREE::HAL::DeviceType::resolveAny(loc, initBuilder);
   auto allocator =
       initBuilder.create<IREE::HAL::DeviceAllocatorOp>(loc, device).getResult();
 
@@ -242,8 +242,8 @@
   auto batchSizeArg = funcBuilder.create<arith::IndexCastOp>(
       loc, funcBuilder.getIndexType(), entryBlock->getArgument(0));
 
-  // TODO(benvanik): real device lookup.
-  auto device = funcBuilder.create<IREE::HAL::ExSharedDeviceOp>(loc);
+  // TODO(multi-device): support multiple devices in benchmark generation.
+  Value device = IREE::HAL::DeviceType::resolveAny(loc, funcBuilder);
 
   // Create and begin command buffer.
   // TODO(benvanik): reuse the command buffer (initialize once and store).
diff --git a/compiler/src/iree/compiler/Dialect/HAL/Transforms/MaterializeResourceCaches.cpp b/compiler/src/iree/compiler/Dialect/HAL/Transforms/MaterializeResourceCaches.cpp
index 29f8580..3b0bb79 100644
--- a/compiler/src/iree/compiler/Dialect/HAL/Transforms/MaterializeResourceCaches.cpp
+++ b/compiler/src/iree/compiler/Dialect/HAL/Transforms/MaterializeResourceCaches.cpp
@@ -136,10 +136,11 @@
     auto initializerOp = moduleBuilder.create<IREE::Util::InitializerOp>(loc);
     OpBuilder blockBuilder =
         OpBuilder::atBlockEnd(initializerOp.addEntryBlock());
-    auto deviceValue = blockBuilder.createOrFold<ExSharedDeviceOp>(loc);
-    auto layoutValue = blockBuilder.createOrFold<DescriptorSetLayoutCreateOp>(
-        loc, layoutType, deviceValue, flags, bindingAttrs);
-    blockBuilder.create<IREE::Util::GlobalStoreOp>(loc, layoutValue,
+    // TODO(multi-device): pass in resolve info to the call and reuse.
+    Value device = IREE::HAL::DeviceType::resolveAny(loc, blockBuilder);
+    Value layout = blockBuilder.createOrFold<DescriptorSetLayoutCreateOp>(
+        loc, layoutType, device, flags, bindingAttrs);
+    blockBuilder.create<IREE::Util::GlobalStoreOp>(loc, layout,
                                                    globalOp.getName());
     blockBuilder.create<IREE::Util::InitializerReturnOp>(loc);
 
@@ -188,12 +189,13 @@
           setLayoutGlobalOp.getSymName());
       setLayoutValues.push_back(setLayoutValue);
     }
-    auto deviceValue = blockBuilder.createOrFold<ExSharedDeviceOp>(loc);
-    auto layoutValue = blockBuilder.createOrFold<PipelineLayoutCreateOp>(
-        loc, layoutType, deviceValue,
+    // TODO(multi-device): pass in resolve info to the call and reuse.
+    Value device = IREE::HAL::DeviceType::resolveAny(loc, blockBuilder);
+    Value layout = blockBuilder.createOrFold<PipelineLayoutCreateOp>(
+        loc, layoutType, device,
         blockBuilder.getIndexAttr(layoutAttr.getPushConstants()),
         setLayoutValues);
-    blockBuilder.create<IREE::Util::GlobalStoreOp>(loc, layoutValue,
+    blockBuilder.create<IREE::Util::GlobalStoreOp>(loc, layout,
                                                    globalOp.getName());
     blockBuilder.create<IREE::Util::InitializerReturnOp>(loc);
 
@@ -214,7 +216,8 @@
     auto initializerOp = moduleBuilder.create<IREE::Util::InitializerOp>(loc);
     OpBuilder blockBuilder =
         OpBuilder::atBlockEnd(initializerOp.addEntryBlock());
-    auto deviceValue = blockBuilder.createOrFold<ExSharedDeviceOp>(loc);
+    // TODO(multi-device): pass in resolve info to the call and reuse.
+    Value device = IREE::HAL::DeviceType::resolveAny(loc, blockBuilder);
 
     // Create a switch statement with a case for each variant.
     // Each case should then cache only executables which contain a matching
@@ -232,7 +235,7 @@
     Value selectedIndex = buildIfElseTree(
         loc, caseVariantOps.size(),
         [&](Location loc, size_t i, OpBuilder &builder) {
-          return caseVariantOps[i].buildCondition(deviceValue, builder);
+          return caseVariantOps[i].buildCondition(device, builder);
         },
         blockBuilder);
 
@@ -261,18 +264,18 @@
       SmallVector<Value> constantValues;
       for (auto blockOp :
            llvm::make_early_inc_range(variantOp.getConstantBlockOps())) {
-        constantValues.append(inlineConstantBlockOp(blockOp, moduleBuilder,
-                                                    caseBuilder, deviceValue));
+        constantValues.append(
+            inlineConstantBlockOp(blockOp, moduleBuilder, caseBuilder, device));
         blockOp.erase();
       }
 
-      auto executableValue = caseBuilder.createOrFold<ExecutableCreateOp>(
-          loc, executableType, deviceValue,
+      Value executable = caseBuilder.createOrFold<ExecutableCreateOp>(
+          loc, executableType, device,
           SymbolRefAttr::get(executableOp.getSymNameAttr(),
                              {SymbolRefAttr::get(variantOp.getSymNameAttr())}),
           pipelineLayoutValues, constantValues);
 
-      caseBuilder.create<scf::YieldOp>(loc, executableValue);
+      caseBuilder.create<scf::YieldOp>(loc, executable);
     }
 
     // Fallback for no available variant.
@@ -299,7 +302,7 @@
   SmallVector<Value> inlineConstantBlockOp(ExecutableConstantBlockOp blockOp,
                                            OpBuilder &moduleBuilder,
                                            OpBuilder &callerBuilder,
-                                           Value deviceValue) {
+                                           Value device) {
     // Create the function with the region contents of the constant block.
     auto funcName = (StringRef("__constant_block_") +
                      std::to_string(nextUniqueConstantBlockId++))
@@ -320,7 +323,7 @@
     // Create the call passing in the device if needed.
     SmallVector<Value> callOperands;
     if (funcOp.getNumArguments() > 0) {
-      callOperands.push_back(deviceValue);
+      callOperands.push_back(device);
     }
     auto callOp = callerBuilder.create<func::CallOp>(blockOp.getLoc(), funcOp,
                                                      callOperands);
diff --git a/compiler/src/iree/compiler/Dialect/HAL/Transforms/MemoizeDeviceQueries.cpp b/compiler/src/iree/compiler/Dialect/HAL/Transforms/MemoizeDeviceQueries.cpp
index 4e47212..80845d3 100644
--- a/compiler/src/iree/compiler/Dialect/HAL/Transforms/MemoizeDeviceQueries.cpp
+++ b/compiler/src/iree/compiler/Dialect/HAL/Transforms/MemoizeDeviceQueries.cpp
@@ -44,6 +44,7 @@
         auto fullKey = ArrayAttr::get(
             moduleOp.getContext(),
             {
+                // TODO(multi-device): add attr key on device resolve source.
                 StringAttr::get(moduleOp.getContext(),
                                 queryOp.getCategory() + queryOp.getKey()),
                 queryOp.getDefaultValue().has_value()
@@ -90,8 +91,8 @@
       auto initializerOp =
           moduleBuilder.create<IREE::Util::InitializerOp>(fusedLoc);
       auto funcBuilder = OpBuilder::atBlockBegin(initializerOp.addEntryBlock());
-      auto device =
-          funcBuilder.createOrFold<IREE::HAL::ExSharedDeviceOp>(fusedLoc);
+      // TODO(multi-device): pass in resolve info to the call and reuse.
+      Value device = IREE::HAL::DeviceType::resolveAny(fusedLoc, funcBuilder);
       auto queryOp = funcBuilder.create<IREE::HAL::DeviceQueryOp>(
           fusedLoc, funcBuilder.getI1Type(), queryType, device,
           anyQueryOp.getCategoryAttr(), anyQueryOp.getKeyAttr(),
diff --git a/compiler/src/iree/compiler/Dialect/HAL/Transforms/test/convert_to_hal.mlir b/compiler/src/iree/compiler/Dialect/HAL/Transforms/test/convert_to_hal.mlir
index 1b9e70a..d1540ba 100644
--- a/compiler/src/iree/compiler/Dialect/HAL/Transforms/test/convert_to_hal.mlir
+++ b/compiler/src/iree/compiler/Dialect/HAL/Transforms/test/convert_to_hal.mlir
@@ -59,7 +59,7 @@
     // CHECK: %[[ARG0_BUFFER:.+]] = hal.buffer_view.buffer<%[[ARG0]] : !hal.buffer_view> : !hal.buffer
 
     // (annoyingly out of order)
-    // CHECK-DAG: %[[DEVICE:.+]] = hal.ex.shared_device : !hal.device
+    // CHECK-DAG: %[[DEVICE:.+]] = hal.devices.get %{{.+}}
     // CHECK-DAG: %[[ALLOCATOR:.+]] = hal.device.allocator<%[[DEVICE]] : !hal.device> : !hal.allocator
 
     // CHECK: hal.buffer.assert<%[[ARG0_BUFFER]] : !hal.buffer>
diff --git a/compiler/src/iree/compiler/Dialect/HAL/Transforms/test/materialize_resource_caches.mlir b/compiler/src/iree/compiler/Dialect/HAL/Transforms/test/materialize_resource_caches.mlir
index 8d468b0..ce67b16 100644
--- a/compiler/src/iree/compiler/Dialect/HAL/Transforms/test/materialize_resource_caches.mlir
+++ b/compiler/src/iree/compiler/Dialect/HAL/Transforms/test/materialize_resource_caches.mlir
@@ -2,15 +2,15 @@
 
 //      CHECK: util.global private @_descriptor_set_layout_0 : !hal.descriptor_set_layout
 // CHECK-NEXT: util.initializer {
-// CHECK-NEXT:   %device = hal.ex.shared_device : !hal.device
-// CHECK-NEXT:   %descriptor_set_layout = hal.descriptor_set_layout.create
-// CHECK-SAME:     device(%device : !hal.device)
+//      CHECK:   %[[DEVICE:.+]] = hal.devices.get %{{.+}}
+// CHECK-NEXT:   %[[LAYOUT:.+]] = hal.descriptor_set_layout.create
+// CHECK-SAME:     device(%[[DEVICE]] : !hal.device)
 // CHECK-SAME:     flags("None")
 // CHECK-SAME:     bindings([
 // CHECK-SAME:       #hal.descriptor_set.binding<0, storage_buffer>,
 // CHECK-SAME:       #hal.descriptor_set.binding<1, storage_buffer>
 // CHECK-SAME:     ]) : !hal.descriptor_set_layout
-// CHECK-NEXT:   util.global.store %descriptor_set_layout, @_descriptor_set_layout_0 : !hal.descriptor_set_layout
+// CHECK-NEXT:   util.global.store %[[LAYOUT]], @_descriptor_set_layout_0 : !hal.descriptor_set_layout
 
 // CHECK-LABEL: @descriptorSetLayoutLookup
 func.func @descriptorSetLayoutLookup(%device : !hal.device) -> !hal.descriptor_set_layout {
@@ -31,13 +31,13 @@
 
 //      CHECK: util.global private @_pipeline_layout_0 : !hal.pipeline_layout
 // CHECK-NEXT: util.initializer {
-// CHECK-NEXT:   %[[SET0:.+]] = util.global.load @_descriptor_set_layout_0 : !hal.descriptor_set_layout
-// CHECK-NEXT:   %device = hal.ex.shared_device : !hal.device
-// CHECK-NEXT:   %pipeline_layout = hal.pipeline_layout.create
-// CHECK-SAME:     device(%device : !hal.device)
+//  CHECK-DAG:   %[[SET0:.+]] = util.global.load @_descriptor_set_layout_0 : !hal.descriptor_set_layout
+//  CHECK-DAG:   %[[DEVICE:.+]] = hal.devices.get %{{.+}}
+// CHECK-NEXT:   %[[LAYOUT:.+]] = hal.pipeline_layout.create
+// CHECK-SAME:     device(%[[DEVICE]] : !hal.device)
 // CHECK-SAME:     push_constants(1)
 // CHECK-SAME:     layouts([%[[SET0]]]) : !hal.pipeline_layout
-// CHECK-NEXT:   util.global.store %pipeline_layout, @_pipeline_layout_0 : !hal.pipeline_layout
+// CHECK-NEXT:   util.global.store %[[LAYOUT]], @_pipeline_layout_0 : !hal.pipeline_layout
 
 // CHECK-LABEL: @exeLayoutLookup
 func.func @exeLayoutLookup(%device : !hal.device) -> !hal.pipeline_layout {
@@ -60,14 +60,14 @@
 
 //      CHECK: util.global private @_pipeline_layout_0 : !hal.pipeline_layout
 // CHECK-NEXT: util.initializer {
-// CHECK-NEXT:   %[[SET0:.+]] = util.global.load @_descriptor_set_layout_0 : !hal.descriptor_set_layout
-// CHECK-NEXT:   %[[SET1:.+]] = util.global.load @_descriptor_set_layout_1 : !hal.descriptor_set_layout
-// CHECK-NEXT:   %device = hal.ex.shared_device : !hal.device
-// CHECK-NEXT:   %pipeline_layout = hal.pipeline_layout.create
-// CHECK-SAME:     device(%device : !hal.device)
+//  CHECK-DAG:   %[[SET0:.+]] = util.global.load @_descriptor_set_layout_0 : !hal.descriptor_set_layout
+//  CHECK-DAG:   %[[SET1:.+]] = util.global.load @_descriptor_set_layout_1 : !hal.descriptor_set_layout
+//  CHECK-DAG:   %[[DEVICE:.+]] = hal.devices.get %{{.+}}
+// CHECK-NEXT:   %[[LAYOUT:.+]] = hal.pipeline_layout.create
+// CHECK-SAME:     device(%[[DEVICE]] : !hal.device)
 // CHECK-SAME:     push_constants(1)
 // CHECK-SAME:     layouts([%[[SET0]], %[[SET1]]]) : !hal.pipeline_layout
-// CHECK-NEXT:   util.global.store %pipeline_layout, @_pipeline_layout_0 : !hal.pipeline_layout
+// CHECK-NEXT:   util.global.store %[[LAYOUT]], @_pipeline_layout_0 : !hal.pipeline_layout
 
 // CHECK-LABEL: @sharedLayoutLookup
 func.func @sharedLayoutLookup(%device : !hal.device) -> !hal.pipeline_layout {
@@ -161,14 +161,16 @@
 // CHECK-NEXT: util.initializer {
 
 // Switch on the supported formats:
-// CHECK:   %[[DEVICE:.+]] = hal.ex.shared_device : !hal.device
+// CHECK:   %[[DEVICE:.+]] = hal.devices.get %{{.+}}
 // CHECK:   %{{.+}}, %[[FORMAT_VMVX:.+]] = hal.device.query<%[[DEVICE]] : !hal.device> key("hal.executable.format" :: "vmvx-bytecode-fb")
 // CHECK:   %[[VMVX_CONDITION:.+]] = scf.execute_region -> i1 {
 // CHECK:     %{{.+}}, %[[FEATURE:.+]] = hal.device.query<%[[DEVICE]] : !hal.device> key("some" :: "feature")
 // CHECK:     scf.yield %[[FEATURE]]
 // CHECK:   }
 // CHECK:   %[[VMVX_VARIANT_SELECTED:.+]] = arith.andi %[[FORMAT_VMVX]], %[[VMVX_CONDITION]]
-// CHECK:   %[[VARIANT_INDEX:.+]] = arith.select %[[VMVX_VARIANT_SELECTED]], %c0, %c-1
+// CHECK-DAG: %[[VARIANT_VMVX:.+]] = arith.constant 0
+// CHECK-DAG: %[[VARIANT_DEFAULT:.+]] = arith.constant -1
+// CHECK:   %[[VARIANT_INDEX:.+]] = arith.select %[[VMVX_VARIANT_SELECTED]], %[[VARIANT_VMVX]], %[[VARIANT_DEFAULT]]
 // CHECK:   %[[RET:.+]] = scf.index_switch %[[VARIANT_INDEX]] -> !hal.executable
 // CHECK:   case 0 {
 
@@ -244,7 +246,8 @@
 
 util.global private @_descriptor_set_layout_0 : !hal.descriptor_set_layout
 util.initializer {
-  %device = hal.ex.shared_device : !hal.device
+  %c0 = arith.constant 0 : index
+  %device = hal.devices.get %c0 : !hal.device
   %descriptor_set_layout = hal.descriptor_set_layout.create device(%device : !hal.device) flags("None") bindings([#hal.descriptor_set.binding<0, storage_buffer>, #hal.descriptor_set.binding<1, storage_buffer>]) : !hal.descriptor_set_layout
   util.global.store %descriptor_set_layout, @_descriptor_set_layout_0 : !hal.descriptor_set_layout
   util.initializer.return
@@ -253,7 +256,8 @@
 util.global private @_pipeline_layout_0 : !hal.pipeline_layout
 util.initializer {
   %_descriptor_set_layout_0 = util.global.load @_descriptor_set_layout_0 : !hal.descriptor_set_layout
-  %device = hal.ex.shared_device : !hal.device
+  %c0 = arith.constant 0 : index
+  %device = hal.devices.get %c0 : !hal.device
   %pipeline_layout = hal.pipeline_layout.create device(%device : !hal.device) push_constants(0) layouts([%_descriptor_set_layout_0]) : !hal.pipeline_layout
   util.global.store %pipeline_layout, @_pipeline_layout_0 : !hal.pipeline_layout
   util.initializer.return
@@ -261,9 +265,9 @@
 
 util.global private @_executable_exe : !hal.executable
 util.initializer {
-  %device = hal.ex.shared_device : !hal.device
-  %format_ok, %format_supported = hal.device.query<%device : !hal.device> key("hal.executable.format" :: "some-format") : i1, i1
   %c0 = arith.constant 0 : index
+  %device = hal.devices.get %c0 : !hal.device
+  %format_ok, %format_supported = hal.device.query<%device : !hal.device> key("hal.executable.format" :: "some-format") : i1, i1
   %c-1 = arith.constant -1 : index
   %variant = arith.select %format_supported, %c0, %c-1 : index
   %selected = scf.index_switch %variant -> !hal.executable
diff --git a/compiler/src/iree/compiler/Dialect/HAL/Transforms/test/memoize_device_queries.mlir b/compiler/src/iree/compiler/Dialect/HAL/Transforms/test/memoize_device_queries.mlir
index 6884811..1fd8492 100644
--- a/compiler/src/iree/compiler/Dialect/HAL/Transforms/test/memoize_device_queries.mlir
+++ b/compiler/src/iree/compiler/Dialect/HAL/Transforms/test/memoize_device_queries.mlir
@@ -3,7 +3,7 @@
 //      CHECK: util.global private @_device_query_0 : i1
 // CHECK-NEXT: util.global private @_device_query_0_ok : i1
 // CHECK-NEXT: util.initializer {
-// CHECK-NEXT:   %[[DEVICE:.+]] = hal.ex.shared_device : !hal.device
+//  CHECK-DAG:   %[[DEVICE:.+]] = hal.devices.get %{{.+}}
 // CHECK-NEXT:   %[[OK0:.+]], %[[VALUE0:.+]] = hal.device.query<%[[DEVICE]] : !hal.device> key("hal.device.id" :: "id0*") : i1, i1 = false
 // CHECK-NEXT:   util.global.store %[[OK0]], @_device_query_0_ok : i1
 // CHECK-NEXT:   util.global.store %[[VALUE0]], @_device_query_0 : i1
@@ -11,7 +11,7 @@
 //      CHECK: util.global private @_device_query_1 : i1
 // CHECK-NEXT: util.global private @_device_query_1_ok : i1
 // CHECK-NEXT: util.initializer {
-// CHECK-NEXT:   %[[DEVICE:.+]] = hal.ex.shared_device : !hal.device
+//  CHECK-DAG:   %[[DEVICE:.+]] = hal.devices.get %{{.+}}
 // CHECK-NEXT:   %[[OK1:.+]], %[[VALUE1:.+]] = hal.device.query<%[[DEVICE]] : !hal.device> key("hal.device.id" :: "id1") : i1, i1 = false
 // CHECK-NEXT:   util.global.store %[[OK1]], @_device_query_1_ok : i1
 // CHECK-NEXT:   util.global.store %[[VALUE1]], @_device_query_1 : i1
diff --git a/compiler/src/iree/compiler/Dialect/HAL/hal.imports.mlir b/compiler/src/iree/compiler/Dialect/HAL/hal.imports.mlir
index 5166f5a..445c508 100644
--- a/compiler/src/iree/compiler/Dialect/HAL/hal.imports.mlir
+++ b/compiler/src/iree/compiler/Dialect/HAL/hal.imports.mlir
@@ -8,9 +8,6 @@
 // Experimental/temporary ops
 //===----------------------------------------------------------------------===//
 
-vm.import private @ex.shared_device() -> !vm.ref<!hal.device>
-attributes {nosideeffects}
-
 // Creates a file mapped into a byte range of a host buffer.
 // EXPERIMENTAL: may be removed in future versions.
 vm.import private @ex.file.from_memory(
@@ -419,6 +416,22 @@
 )
 
 //===----------------------------------------------------------------------===//
+// iree_hal_device_t management
+//===----------------------------------------------------------------------===//
+
+vm.import private @devices.count() -> i32
+attributes {
+  minimum_version = 2 : i32,
+  nosideeffects
+}
+
+vm.import private @devices.get(%index : i32) -> !vm.ref<!hal.device>
+attributes {
+  minimum_version = 2 : i32,
+  nosideeffects
+}
+
+//===----------------------------------------------------------------------===//
 // iree_hal_executable_t
 //===----------------------------------------------------------------------===//
 
diff --git a/compiler/src/iree/compiler/Modules/Check/Conversion/ConversionPatterns.cpp b/compiler/src/iree/compiler/Modules/Check/Conversion/ConversionPatterns.cpp
index c08f5b6..9cf13da 100644
--- a/compiler/src/iree/compiler/Modules/Check/Conversion/ConversionPatterns.cpp
+++ b/compiler/src/iree/compiler/Modules/Check/Conversion/ConversionPatterns.cpp
@@ -68,7 +68,8 @@
   state.addAttributes(srcOp->getAttrs());
 
   // Add device argument.
-  Value device = rewriter.create<IREE::HAL::ExSharedDeviceOp>(srcOp->getLoc());
+  // TODO(multi-device): support multiple devices in check tests.
+  Value device = IREE::HAL::DeviceType::resolveAny(srcOp->getLoc(), rewriter);
   state.addOperands({device});
 
   for (auto [srcOperand, dstOperand] :
diff --git a/compiler/src/iree/compiler/Modules/IO/Parameters/Conversion/StreamToParams/test/parameter_ops.mlir b/compiler/src/iree/compiler/Modules/IO/Parameters/Conversion/StreamToParams/test/parameter_ops.mlir
index eb75719..92035c3 100644
--- a/compiler/src/iree/compiler/Modules/IO/Parameters/Conversion/StreamToParams/test/parameter_ops.mlir
+++ b/compiler/src/iree/compiler/Modules/IO/Parameters/Conversion/StreamToParams/test/parameter_ops.mlir
@@ -7,7 +7,7 @@
   %c51_i64 = arith.constant 51 : i64
   %c100 = arith.constant 100 : index
   %c101 = arith.constant 101 : index
-  // CHECK-DAG: %[[DEVICE:.+]] = hal.ex.shared_device
+  // CHECK-DAG: %[[DEVICE:.+]] = hal.devices.get %{{.+}}
   // CHECK-DAG: %[[AFFINITY:.+]] = arith.constant -1
   // CHECK-DAG: %[[SIGNAL:.+]] = hal.fence.create device(%[[DEVICE]] : !hal.device)
   // CHECK: %[[BUFFERS:.+]]:2 = io_parameters.load<%[[DEVICE]] : !hal.device> affinity(%[[AFFINITY]])
@@ -30,7 +30,7 @@
 func.func @parameterLoadNoScope(%wait: !stream.timepoint) -> (!stream.resource<constant>, !stream.timepoint) {
   %c50_i64 = arith.constant 50 : i64
   %c100 = arith.constant 100 : index
-  // CHECK-DAG: %[[DEVICE:.+]] = hal.ex.shared_device
+  // CHECK-DAG: %[[DEVICE:.+]] = hal.devices.get %{{.+}}
   // CHECK-DAG: %[[AFFINITY:.+]] = arith.constant -1
   // CHECK-DAG: %[[SIGNAL:.+]] = hal.fence.create device(%[[DEVICE]] : !hal.device)
   // CHECK: %[[BUFFER:.+]] = io_parameters.load<%[[DEVICE]] : !hal.device> affinity(%[[AFFINITY]])
@@ -53,7 +53,7 @@
   %c100 = arith.constant 100 : index
   %c200 = arith.constant 200 : index
   %c300 = arith.constant 300 : index
-  // CHECK-DAG: %[[DEVICE:.+]] = hal.ex.shared_device
+  // CHECK-DAG: %[[DEVICE:.+]] = hal.devices.get %{{.+}}
   // CHECK-DAG: %[[AFFINITY:.+]] = arith.constant -1
   // CHECK-DAG: %[[SIGNAL:.+]] = hal.fence.create device(%[[DEVICE]] : !hal.device)
   // CHECK: io_parameters.gather<%[[DEVICE]] : !hal.device> affinity(%[[AFFINITY]])
@@ -73,7 +73,7 @@
   %c100 = arith.constant 100 : index
   %c200 = arith.constant 200 : index
   %c300 = arith.constant 300 : index
-  // CHECK-DAG: %[[DEVICE:.+]] = hal.ex.shared_device
+  // CHECK-DAG: %[[DEVICE:.+]] = hal.devices.get %{{.+}}
   // CHECK-DAG: %[[AFFINITY:.+]] = arith.constant -1
   // CHECK-DAG: %[[SIGNAL:.+]] = hal.fence.create device(%[[DEVICE]] : !hal.device)
   // CHECK: io_parameters.scatter<%[[DEVICE]] : !hal.device> affinity(%[[AFFINITY]])
@@ -99,7 +99,7 @@
   %c201 = arith.constant 201 : index
   %c202 = arith.constant 202 : index
   %c300 = arith.constant 300 : index
-  // CHECK-DAG: %[[DEVICE:.+]] = hal.ex.shared_device
+  // CHECK-DAG: %[[DEVICE:.+]] = hal.devices.get %{{.+}}
   // CHECK-DAG: %[[AFFINITY:.+]] = arith.constant -1
   // CHECK-DAG: %[[SIGNAL:.+]] = hal.fence.create device(%[[DEVICE]] : !hal.device)
   // CHECK: io_parameters.gather<%[[DEVICE]] : !hal.device> affinity(%[[AFFINITY]])
@@ -128,7 +128,7 @@
   %c200 = arith.constant 200 : index
   %c201 = arith.constant 201 : index
   %c300 = arith.constant 300 : index
-  // CHECK-DAG: %[[DEVICE:.+]] = hal.ex.shared_device
+  // CHECK-DAG: %[[DEVICE:.+]] = hal.devices.get %{{.+}}
   // CHECK-DAG: %[[AFFINITY:.+]] = arith.constant -1
   // CHECK-DAG: %[[SIGNAL:.+]] = hal.fence.create device(%[[DEVICE]] : !hal.device)
   // CHECK: io_parameters.gather<%[[DEVICE]] : !hal.device> affinity(%[[AFFINITY]])
@@ -158,7 +158,7 @@
   %c201 = arith.constant 201 : index
   %c202 = arith.constant 202 : index
   %c300 = arith.constant 300 : index
-  // CHECK-DAG: %[[DEVICE:.+]] = hal.ex.shared_device
+  // CHECK-DAG: %[[DEVICE:.+]] = hal.devices.get %{{.+}}
   // CHECK-DAG: %[[AFFINITY:.+]] = arith.constant -1
   // CHECK-DAG: %[[SIGNAL:.+]] = hal.fence.create device(%[[DEVICE]] : !hal.device)
   // CHECK: io_parameters.scatter<%[[DEVICE]] : !hal.device> affinity(%[[AFFINITY]])
diff --git a/docs/website/docs/developers/general/developer-tips.md b/docs/website/docs/developers/general/developer-tips.md
index c830cfd..f4e2090 100644
--- a/docs/website/docs/developers/general/developer-tips.md
+++ b/docs/website/docs/developers/general/developer-tips.md
@@ -123,8 +123,8 @@
   hal, version >= 0, required
 
 Imported Functions:
-  [  0] hal.ex.shared_device() -> (!vm.ref<?>)
-  [  1] hal.allocator.allocate(!vm.ref<?>, i32, i32, i64) -> (!vm.ref<?>)
+  [  0] hal.allocator.allocate(!vm.ref<?>, i32, i32, i64) -> (!vm.ref<?>)
+  [  1] hal.devices.get(i32) -> (!vm.ref<?>)
   ...
 
 Exported Functions:
diff --git a/integrations/pjrt/src/iree_pjrt/common/api_impl.cc b/integrations/pjrt/src/iree_pjrt/common/api_impl.cc
index d6f1dc4..a49faa1 100644
--- a/integrations/pjrt/src/iree_pjrt/common/api_impl.cc
+++ b/integrations/pjrt/src/iree_pjrt/common/api_impl.cc
@@ -1597,8 +1597,8 @@
   // HAL module.
   modules.push_back({});
   IREE_RETURN_IF_ERROR(iree_hal_module_create(
-      vm_instance(), hal_device, IREE_HAL_MODULE_FLAG_NONE, host_allocator(),
-      &modules.back()));
+      vm_instance(), /*device_count=*/1, &hal_device, IREE_HAL_MODULE_FLAG_NONE,
+      host_allocator(), &modules.back()));
 
   // Main module.
   modules.push_back(main_module);
diff --git a/runtime/bindings/python/hal.cc b/runtime/bindings/python/hal.cc
index 9843c68..2460b36 100644
--- a/runtime/bindings/python/hal.cc
+++ b/runtime/bindings/python/hal.cc
@@ -768,10 +768,12 @@
 // HAL module
 //------------------------------------------------------------------------------
 
+// TODO(multi-device): allow for multiple devices to be passed in.
 VmModule CreateHalModule(VmInstance* instance, HalDevice* device) {
+  iree_hal_device_t* device_ptr = device->raw_ptr();
   iree_vm_module_t* module = NULL;
-  CheckApiStatus(iree_hal_module_create(instance->raw_ptr(), device->raw_ptr(),
-                                        IREE_HAL_MODULE_FLAG_NONE,
+  CheckApiStatus(iree_hal_module_create(instance->raw_ptr(), /*device_count=*/1,
+                                        &device_ptr, IREE_HAL_MODULE_FLAG_NONE,
                                         iree_allocator_system(), &module),
                  "Error creating hal module");
   return VmModule::StealFromRawPtr(module);
diff --git a/runtime/bindings/tflite/interpreter.c b/runtime/bindings/tflite/interpreter.c
index 3dae25d..6f0120b 100644
--- a/runtime/bindings/tflite/interpreter.c
+++ b/runtime/bindings/tflite/interpreter.c
@@ -60,9 +60,10 @@
       "failed creating the default device for driver '%.*s'",
       (int)driver_name.size, driver_name.data);
 
-  IREE_RETURN_IF_ERROR(iree_hal_module_create(
-      interpreter->instance, interpreter->device, IREE_HAL_MODULE_FLAG_NONE,
-      interpreter->allocator, &interpreter->hal_module));
+  IREE_RETURN_IF_ERROR(
+      iree_hal_module_create(interpreter->instance, /*device_count=*/1,
+                             &interpreter->device, IREE_HAL_MODULE_FLAG_NONE,
+                             interpreter->allocator, &interpreter->hal_module));
 
   return iree_ok_status();
 }
diff --git a/runtime/src/iree/modules/check/check_test.cc b/runtime/src/iree/modules/check/check_test.cc
index 0966580..e6afaee 100644
--- a/runtime/src/iree/modules/check/check_test.cc
+++ b/runtime/src/iree/modules/check/check_test.cc
@@ -44,9 +44,9 @@
     }
     IREE_ASSERT_OK(iree_hal_driver_create_default_device(
         hal_driver, iree_allocator_system(), &device_));
-    IREE_ASSERT_OK(
-        iree_hal_module_create(instance_, device_, IREE_HAL_MODULE_FLAG_NONE,
-                               iree_allocator_system(), &hal_module_));
+    IREE_ASSERT_OK(iree_hal_module_create(
+        instance_, /*device_count=*/1, &device_, IREE_HAL_MODULE_FLAG_NONE,
+        iree_allocator_system(), &hal_module_));
     iree_hal_driver_release(hal_driver);
 
     IREE_ASSERT_OK(iree_check_module_create(instance_, iree_allocator_system(),
diff --git a/runtime/src/iree/modules/check/test/success.mlir b/runtime/src/iree/modules/check/test/success.mlir
index 45b7939..7c5012e 100644
--- a/runtime/src/iree/modules/check/test/success.mlir
+++ b/runtime/src/iree/modules/check/test/success.mlir
@@ -13,7 +13,8 @@
 }
 
 func.func @expect_all_true() {
-  %device = hal.ex.shared_device : !hal.device
+  %c0 = arith.constant 0 : index
+  %device = hal.devices.get %c0 : !hal.device
   %all_true = util.unfoldable_constant dense<1> : tensor<2x2xi32>
   %all_true_view = hal.tensor.export %all_true : tensor<2x2xi32> -> !hal.buffer_view
   check.expect_all_true<%device>(%all_true_view) : !hal.buffer_view
diff --git a/runtime/src/iree/modules/hal/exports.inl b/runtime/src/iree/modules/hal/exports.inl
index c9eae4b..e61638a 100644
--- a/runtime/src/iree/modules/hal/exports.inl
+++ b/runtime/src/iree/modules/hal/exports.inl
@@ -71,8 +71,10 @@
 EXPORT_FN("device.queue.read", iree_hal_module_device_queue_read, rIrrrIrIIi, v)
 EXPORT_FN("device.queue.write", iree_hal_module_device_queue_write, rIrrrIrIIi, v)
 
+EXPORT_FN("devices.count", iree_hal_module_devices_count, v, i)
+EXPORT_FN("devices.get", iree_hal_module_devices_get, i, r)
+
 EXPORT_FN("ex.file.from_memory", iree_hal_module_ex_file_from_memory, rIirIIi, r)
-EXPORT_FN("ex.shared_device", iree_hal_module_ex_shared_device, v, r)
 
 EXPORT_FN("executable.create", iree_hal_module_executable_create, rrrrCrD, r)
 
diff --git a/runtime/src/iree/modules/hal/module.c b/runtime/src/iree/modules/hal/module.c
index 57a4cc6..d00c39a 100644
--- a/runtime/src/iree/modules/hal/module.c
+++ b/runtime/src/iree/modules/hal/module.c
@@ -32,18 +32,26 @@
 // Module type definitions
 //===----------------------------------------------------------------------===//
 
-#define IREE_HAL_MODULE_VERSION_0_1 0x00000001u
-#define IREE_HAL_MODULE_VERSION_LATEST IREE_HAL_MODULE_VERSION_0_1
+#define IREE_HAL_MODULE_VERSION_0_2 0x00000002u
+#define IREE_HAL_MODULE_VERSION_LATEST IREE_HAL_MODULE_VERSION_0_2
 
 typedef struct iree_hal_module_t {
   iree_allocator_t host_allocator;
   iree_hal_module_flags_t flags;
-  iree_hal_device_t* shared_device;
+  iree_host_size_t device_count;
+  iree_hal_device_t* devices[];
 } iree_hal_module_t;
 
 #define IREE_HAL_MODULE_CAST(module) \
   (iree_hal_module_t*)((uint8_t*)(module) + iree_vm_native_module_size());
 
+static void IREE_API_PTR iree_hal_module_destroy(void* base_module) {
+  iree_hal_module_t* module = IREE_HAL_MODULE_CAST(base_module);
+  for (iree_host_size_t i = 0; i < module->device_count; ++i) {
+    iree_hal_device_release(module->devices[i]);
+  }
+}
+
 typedef struct iree_hal_module_state_t {
   iree_allocator_t host_allocator;
 
@@ -51,29 +59,26 @@
   // application. All instantiations of a module share the same flags.
   iree_hal_module_flags_t flags;
 
-  // HACK: today we only support a single device per context - in the future
-  // this should be a set of available devices that the module is able to pick
-  // from - the module will then hang on to them and use them as native globals
-  // instead of storing anything in module state here.
-  iree_hal_device_t* shared_device;
+  // Total number of devices available to the module.
+  iree_host_size_t device_count;
+  // Devices referencing the storage in the parent module.
+  // Unretained as the parent module must remain live longer than any module
+  // state allocated from it and we can rely on it to keep the devices retained.
+  iree_hal_device_t** devices;
 
   // TODO(benvanik): add iree_loop_t to module constructor.
   // Status of the nested loop we run for executable creation today. We should
   // instead be taking a loop upon creation and scheduling work against that.
   iree_status_t loop_status;
 
-  // Shared executable cache for all executables created in the context.
-  // We could have multiple to allow for modules to create distinct sets of
-  // executables like ones for training vs inference in the same model, or just
-  // always use this.
-  iree_hal_executable_cache_t* executable_cache;
+  // Per-device executable caches, indexed the same as |devices|, used to cache
+  // all executables created in the context. We could have multiple to allow for
+  // modules to create distinct sets of executables like ones for training vs
+  // inference in the same model or allow these to be injected so that multiple
+  // loaded contexts share the caches.
+  iree_hal_executable_cache_t* executable_caches[];
 } iree_hal_module_state_t;
 
-static void IREE_API_PTR iree_hal_module_destroy(void* base_module) {
-  iree_hal_module_t* module = IREE_HAL_MODULE_CAST(base_module);
-  iree_hal_device_release(module->shared_device);
-}
-
 static iree_status_t IREE_API_PTR
 iree_hal_module_alloc_state(void* self, iree_allocator_t host_allocator,
                             iree_vm_module_state_t** out_module_state) {
@@ -81,24 +86,36 @@
 
   iree_hal_module_t* module = IREE_HAL_MODULE_CAST(self);
   iree_hal_module_state_t* state = NULL;
+  iree_host_size_t total_size =
+      sizeof(*state) +
+      module->device_count * sizeof(state->executable_caches[0]);
   IREE_RETURN_AND_END_ZONE_IF_ERROR(
-      z0,
-      iree_allocator_malloc(host_allocator, sizeof(*state), (void**)&state));
-  memset(state, 0, sizeof(*state));
+      z0, iree_allocator_malloc(host_allocator, total_size, (void**)&state));
+  memset(state, 0, total_size);
   state->host_allocator = host_allocator;
   state->flags = module->flags;
-  state->shared_device = module->shared_device;
-  iree_hal_device_retain(state->shared_device);
-
+  state->device_count = module->device_count;
+  state->devices = module->devices;
   state->loop_status = iree_ok_status();
-  IREE_RETURN_AND_END_ZONE_IF_ERROR(
-      z0, iree_hal_executable_cache_create(
-              state->shared_device, iree_string_view_empty(),
-              iree_loop_inline(&state->loop_status), &state->executable_cache));
 
-  *out_module_state = (iree_vm_module_state_t*)state;
+  iree_status_t status = iree_ok_status();
+  for (iree_host_size_t i = 0; i < state->device_count; ++i) {
+    status = iree_hal_executable_cache_create(
+        state->devices[i], iree_string_view_empty(),
+        iree_loop_inline(&state->loop_status), &state->executable_caches[i]);
+    if (!iree_status_is_ok(status)) break;
+  }
+
+  if (iree_status_is_ok(status)) {
+    *out_module_state = (iree_vm_module_state_t*)state;
+  } else {
+    for (iree_host_size_t i = 0; i < state->device_count; ++i) {
+      iree_hal_executable_cache_release(state->executable_caches[i]);
+    }
+    iree_allocator_free(host_allocator, state);
+  }
   IREE_TRACE_ZONE_END(z0);
-  return iree_ok_status();
+  return status;
 }
 
 static void IREE_API_PTR
@@ -106,23 +123,52 @@
   IREE_TRACE_ZONE_BEGIN(z0);
 
   iree_hal_module_state_t* state = (iree_hal_module_state_t*)module_state;
-  iree_hal_executable_cache_release(state->executable_cache);
+  for (iree_host_size_t i = 0; i < state->device_count; ++i) {
+    iree_hal_executable_cache_release(state->executable_caches[i]);
+  }
   iree_status_ignore(state->loop_status);
-  iree_hal_device_release(state->shared_device);
   iree_allocator_free(state->host_allocator, state);
 
   IREE_TRACE_ZONE_END(z0);
 }
 
+// Returns an unretained reference to the executable cache for the given device.
+// If the same device is registered multiple times the first cache is returned.
+static iree_status_t iree_hal_module_state_lookup_executable_cache(
+    iree_hal_module_state_t* state, iree_hal_device_t* device,
+    iree_hal_executable_cache_t** out_executable_cache) {
+  IREE_ASSERT_ARGUMENT(state);
+  IREE_ASSERT_ARGUMENT(device);
+  IREE_ASSERT_ARGUMENT(out_executable_cache);
+  *out_executable_cache = NULL;
+  for (iree_host_size_t i = 0; i < state->device_count; ++i) {
+    if (state->devices[i] == device) {
+      *out_executable_cache = state->executable_caches[i];
+      return iree_ok_status();
+    }
+  }
+  return iree_make_status(
+      IREE_STATUS_NOT_FOUND,
+      "no executable cache for the given device found; possibly a device not "
+      "registered with the HAL module");
+}
+
 static iree_status_t IREE_API_PTR iree_hal_module_notify(
     void* self, iree_vm_module_state_t* module_state, iree_vm_signal_t signal) {
   iree_hal_module_state_t* state = (iree_hal_module_state_t*)module_state;
   switch (signal) {
     case IREE_VM_SIGNAL_SUSPEND:
-    case IREE_VM_SIGNAL_LOW_MEMORY:
-      return iree_hal_device_trim(state->shared_device);
-    default:
+    case IREE_VM_SIGNAL_LOW_MEMORY: {
+      for (iree_host_size_t i = 0; i < state->device_count; ++i) {
+        IREE_RETURN_IF_ERROR(iree_hal_device_trim(state->devices[i]));
+      }
       return iree_ok_status();
+    }
+    default: {
+      // Ignored today but if we started managing device power down we could
+      // use this to wake them back up again.
+      return iree_ok_status();
+    }
   }
 }
 
@@ -150,13 +196,6 @@
 // NOTE: Ex* APIs are experimental and likely to be removed soon. Modules
 // using these APIs are not forward compatible.
 
-IREE_VM_ABI_EXPORT(iree_hal_module_ex_shared_device,  //
-                   iree_hal_module_state_t,           //
-                   v, r) {
-  rets->r0 = iree_hal_device_retain_ref(state->shared_device);
-  return iree_ok_status();
-}
-
 static void iree_hal_module_file_buffer_release(
     void* user_data, iree_io_file_handle_primitive_t handle_primitive) {
   iree_vm_buffer_t* backing_buffer = (iree_vm_buffer_t*)user_data;
@@ -411,9 +450,8 @@
                             "load length byte count %d exceeds max", length);
   }
 
-  IREE_RETURN_IF_ERROR(iree_hal_device_transfer_d2h(
-      state->shared_device, source_buffer, source_offset, &target_buffer,
-      length, IREE_HAL_TRANSFER_BUFFER_FLAG_DEFAULT, iree_infinite_timeout()));
+  IREE_RETURN_IF_ERROR(iree_hal_buffer_map_read(source_buffer, source_offset,
+                                                &target_buffer, length));
 
   rets->i0 = target_buffer;
   return iree_ok_status();
@@ -440,9 +478,8 @@
                             iree_hal_buffer_byte_length(target_buffer));
   }
 
-  return iree_hal_device_transfer_h2d(
-      state->shared_device, &value, target_buffer, target_offset, length,
-      IREE_HAL_TRANSFER_BUFFER_FLAG_DEFAULT, iree_infinite_timeout());
+  return iree_hal_buffer_map_write(target_buffer, target_offset, &value,
+                                   length);
 }
 
 //===----------------------------------------------------------------------===//
@@ -1104,6 +1141,30 @@
   return iree_hal_device_queue_flush(device, queue_affinity);
 }
 
+//===----------------------------------------------------------------------===//
+// iree_hal_device_t management
+//===----------------------------------------------------------------------===//
+
+IREE_VM_ABI_EXPORT(iree_hal_module_devices_count,  //
+                   iree_hal_module_state_t,        //
+                   v, i) {
+  rets->i0 = (int32_t)state->device_count;
+  return iree_ok_status();
+}
+
+IREE_VM_ABI_EXPORT(iree_hal_module_devices_get,  //
+                   iree_hal_module_state_t,      //
+                   i, r) {
+  if (args->i0 >= state->device_count) {
+    return iree_make_status(IREE_STATUS_INVALID_ARGUMENT,
+                            "device index %d out of bounds (%" PRIhsz
+                            " devices available)",
+                            args->i0, state->device_count);
+  }
+  rets->r0 = iree_hal_device_retain_ref(state->devices[args->i0]);
+  return iree_ok_status();
+}
+
 //===--------------------------------------------------------------------===//
 // iree_hal_executable_t
 //===--------------------------------------------------------------------===//
@@ -1135,6 +1196,11 @@
     constant_count = constant_buffer->data.data_length / sizeof(uint32_t);
     constants = (const uint32_t*)constant_buffer->data.data;
   }
+
+  iree_hal_executable_cache_t* executable_cache = NULL;
+  IREE_RETURN_IF_ERROR(iree_hal_module_state_lookup_executable_cache(
+      state, device, &executable_cache));
+
   iree_host_size_t pipeline_layout_count = args->a4_count;
   iree_hal_pipeline_layout_t** pipeline_layouts = NULL;
   IREE_RETURN_IF_ERROR(
@@ -1164,7 +1230,7 @@
     executable_params.constant_count = constant_count;
     executable_params.constants = constants;
     status = iree_hal_executable_cache_prepare_executable(
-        state->executable_cache, &executable_params, &executable);
+        executable_cache, &executable_params, &executable);
   }
 
   iree_allocator_free(state->host_allocator, pipeline_layouts);
@@ -1573,13 +1639,15 @@
 };
 
 IREE_API_EXPORT iree_status_t iree_hal_module_create(
-    iree_vm_instance_t* instance, iree_hal_device_t* device,
-    iree_hal_module_flags_t flags, iree_allocator_t host_allocator,
-    iree_vm_module_t** out_module) {
+    iree_vm_instance_t* instance, iree_host_size_t device_count,
+    iree_hal_device_t** devices, iree_hal_module_flags_t flags,
+    iree_allocator_t host_allocator, iree_vm_module_t** out_module) {
   IREE_ASSERT_ARGUMENT(instance);
-  IREE_ASSERT_ARGUMENT(device);
+  IREE_ASSERT_ARGUMENT(device_count);
+  IREE_ASSERT_ARGUMENT(devices);
   IREE_ASSERT_ARGUMENT(out_module);
   *out_module = NULL;
+  IREE_TRACE_ZONE_BEGIN(z0);
 
   // Setup the interface with the functions we implement ourselves. Any function
   // we omit will be handled by the base native module.
@@ -1591,8 +1659,10 @@
   };
 
   // Allocate shared module state.
-  iree_host_size_t total_size =
-      iree_vm_native_module_size() + sizeof(iree_hal_module_t);
+  iree_host_size_t total_size = iree_vm_native_module_size() +
+                                sizeof(iree_hal_module_t) +
+                                device_count * sizeof(iree_hal_device_t*);
   iree_vm_module_t* base_module = NULL;
-  IREE_RETURN_IF_ERROR(
-      iree_allocator_malloc(host_allocator, total_size, (void**)&base_module));
+  IREE_RETURN_AND_END_ZONE_IF_ERROR(
+      z0, iree_allocator_malloc(host_allocator, total_size,
+                                (void**)&base_module));
@@ -1602,6 +1672,7 @@
                                        instance, host_allocator, base_module);
   if (!iree_status_is_ok(status)) {
     iree_allocator_free(host_allocator, base_module);
+    IREE_TRACE_ZONE_END(z0);
     return status;
   }
 
@@ -1609,15 +1680,25 @@
   module->host_allocator = host_allocator;
   // TODO(benvanik): fix vm yield with result storage.
   module->flags = flags | IREE_HAL_MODULE_FLAG_SYNCHRONOUS;
-  module->shared_device = device;
-  iree_hal_device_retain(module->shared_device);
+  module->device_count = device_count;
+  for (iree_host_size_t i = 0; i < device_count; ++i) {
+    module->devices[i] = devices[i];
+    iree_hal_device_retain(module->devices[i]);
+  }
 
   *out_module = base_module;
+  IREE_TRACE_ZONE_END(z0);
   return iree_ok_status();
 }
 
-IREE_API_EXPORT iree_hal_device_t* iree_hal_module_state_device(
-    iree_vm_module_state_t* module_state) {
+IREE_API_EXPORT iree_host_size_t
+iree_hal_module_state_device_count(iree_vm_module_state_t* module_state) {
   iree_hal_module_state_t* state = (iree_hal_module_state_t*)module_state;
-  return state->shared_device;
+  return state->device_count;
+}
+
+IREE_API_EXPORT iree_hal_device_t* iree_hal_module_state_device_get(
+    iree_vm_module_state_t* module_state, iree_host_size_t index) {
+  iree_hal_module_state_t* state = (iree_hal_module_state_t*)module_state;
+  return index < state->device_count ? state->devices[index] : NULL;
 }
diff --git a/runtime/src/iree/modules/hal/module.h b/runtime/src/iree/modules/hal/module.h
index 5ba71af..ee45570 100644
--- a/runtime/src/iree/modules/hal/module.h
+++ b/runtime/src/iree/modules/hal/module.h
@@ -26,18 +26,23 @@
 };
 typedef uint32_t iree_hal_module_flags_t;
 
-// Creates the HAL module initialized to use a specific |device|.
-// Each context using this module will share the device and have compatible
+// Creates the HAL module initialized to use one or more |devices|.
+// Each context using this module will share the devices and have compatible
 // allocations.
 IREE_API_EXPORT iree_status_t iree_hal_module_create(
-    iree_vm_instance_t* instance, iree_hal_device_t* device,
-    iree_hal_module_flags_t flags, iree_allocator_t host_allocator,
-    iree_vm_module_t** out_module);
+    iree_vm_instance_t* instance, iree_host_size_t device_count,
+    iree_hal_device_t** devices, iree_hal_module_flags_t flags,
+    iree_allocator_t host_allocator, iree_vm_module_t** out_module);
 
-// Returns the device currently in use by the HAL module.
-// Returns NULL if no device has been initialized yet.
-IREE_API_EXPORT iree_hal_device_t* iree_hal_module_state_device(
-    iree_vm_module_state_t* module_state);
+// Returns the total number of available devices registered with the HAL module.
+IREE_API_EXPORT iree_host_size_t
+iree_hal_module_state_device_count(iree_vm_module_state_t* module_state);
+
+// Returns an unretained reference to the device at |index| in the HAL module.
+// Returns NULL if no device has been initialized yet or the index is out of
+// bounds.
+IREE_API_EXPORT iree_hal_device_t* iree_hal_module_state_device_get(
+    iree_vm_module_state_t* module_state, iree_host_size_t index);
 
 #ifdef __cplusplus
 }  // extern "C"
diff --git a/runtime/src/iree/runtime/session.c b/runtime/src/iree/runtime/session.c
index e166569..4286178 100644
--- a/runtime/src/iree/runtime/session.c
+++ b/runtime/src/iree/runtime/session.c
@@ -94,8 +94,9 @@
   iree_vm_module_t* hal_module = NULL;
   if (iree_status_is_ok(status)) {
     status = iree_hal_module_create(iree_runtime_instance_vm_instance(instance),
-                                    device, IREE_HAL_MODULE_FLAG_NONE,
-                                    host_allocator, &hal_module);
+                                    /*device_count=*/1, &device,
+                                    IREE_HAL_MODULE_FLAG_NONE, host_allocator,
+                                    &hal_module);
   }
   if (iree_status_is_ok(status)) {
     status = iree_vm_context_register_modules(
@@ -163,7 +164,8 @@
 IREE_API_EXPORT iree_hal_device_t* iree_runtime_session_device(
     const iree_runtime_session_t* session) {
   IREE_ASSERT_ARGUMENT(session);
-  return iree_hal_module_state_device(session->hal_module_state);
+  // NOTE: only one device is supported via this API today.
+  return iree_hal_module_state_device_get(session->hal_module_state, 0);
 }
 
 IREE_API_EXPORT iree_hal_allocator_t* iree_runtime_session_device_allocator(
diff --git a/runtime/src/iree/runtime/session.h b/runtime/src/iree/runtime/session.h
index c78d75c..f867da1 100644
--- a/runtime/src/iree/runtime/session.h
+++ b/runtime/src/iree/runtime/session.h
@@ -116,6 +116,8 @@
 //
 // NOTE: this device will not be available until initialized by a user module
 // and will return NULL if queried prior.
+//
+// NOTE: this API does not support multiple devices.
 IREE_API_EXPORT iree_hal_device_t* iree_runtime_session_device(
     const iree_runtime_session_t* session);
 
@@ -125,6 +127,8 @@
 //
 // NOTE: this device allocator will not be available until initialized by a
 // user module and will return NULL if queried prior.
+//
+// NOTE: this API does not support multiple devices.
 IREE_API_EXPORT iree_hal_allocator_t* iree_runtime_session_device_allocator(
     const iree_runtime_session_t* session);
 
diff --git a/runtime/src/iree/tooling/context_util.c b/runtime/src/iree/tooling/context_util.c
index 135a111..ccb0d9c 100644
--- a/runtime/src/iree/tooling/context_util.c
+++ b/runtime/src/iree/tooling/context_util.c
@@ -196,6 +196,8 @@
   IREE_RETURN_AND_END_ZONE_IF_ERROR(
       z0, iree_hal_module_register_all_types(instance));
 
+  // TODO(multi-device): create multiple devices (maybe with an
+  // iree_hal_device_list_t helper for retaining/managing the dynamic list).
   // Create the device to use.
   // In the future this will change to a set of available devices instead.
   if (iree_string_view_is_empty(default_device_uri)) {
@@ -214,8 +216,8 @@
   // Create HAL module wrapping the device created above.
   iree_hal_module_flags_t flags = IREE_HAL_MODULE_FLAG_NONE;
   iree_vm_module_t* module = NULL;
-  iree_status_t status =
-      iree_hal_module_create(instance, device, flags, host_allocator, &module);
+  iree_status_t status = iree_hal_module_create(
+      instance, /*device_count=*/1, &device, flags, host_allocator, &module);
 
   if (iree_status_is_ok(status)) {
     *out_module = module;
diff --git a/runtime/src/iree/tooling/trace_replay.c b/runtime/src/iree/tooling/trace_replay.c
index 06ed094..5f2350a 100644
--- a/runtime/src/iree/tooling/trace_replay.c
+++ b/runtime/src/iree/tooling/trace_replay.c
@@ -176,8 +176,8 @@
     IREE_RETURN_IF_ERROR(iree_trace_replay_create_device(
         replay, device_node, replay->host_allocator, &replay->device));
     IREE_RETURN_IF_ERROR(iree_hal_module_create(
-        replay->instance, replay->device, IREE_HAL_MODULE_FLAG_NONE,
-        replay->host_allocator, &module));
+        replay->instance, /*device_count=*/1, &replay->device,
+        IREE_HAL_MODULE_FLAG_NONE, replay->host_allocator, &module));
   }
   if (!module) {
     return iree_make_status(
diff --git a/runtime/src/iree/vm/shims.c b/runtime/src/iree/vm/shims.c
index 83245cb..80f11f5 100644
--- a/runtime/src/iree/vm/shims.c
+++ b/runtime/src/iree/vm/shims.c
@@ -7,6 +7,7 @@
 #include "iree/vm/shims.h"
 
 IREE_VM_ABI_DEFINE_SHIM(irIi, v);
+IREE_VM_ABI_DEFINE_SHIM(i, r);
 IREE_VM_ABI_DEFINE_SHIM(r, i);
 IREE_VM_ABI_DEFINE_SHIM(r, I);
 IREE_VM_ABI_DEFINE_SHIM(r, ii);
diff --git a/runtime/src/iree/vm/shims.h b/runtime/src/iree/vm/shims.h
index af75456..860d96f 100644
--- a/runtime/src/iree/vm/shims.h
+++ b/runtime/src/iree/vm/shims.h
@@ -614,6 +614,7 @@
 //===----------------------------------------------------------------------===//
 
 IREE_VM_ABI_DECLARE_SHIM(irIi, v);
+IREE_VM_ABI_DECLARE_SHIM(i, r);
 IREE_VM_ABI_DECLARE_SHIM(r, i);
 IREE_VM_ABI_DECLARE_SHIM(r, I);
 IREE_VM_ABI_DECLARE_SHIM(r, ii);
diff --git a/samples/simple_embedding/simple_embedding.c b/samples/simple_embedding/simple_embedding.c
index 57f6761..b5df80b 100644
--- a/samples/simple_embedding/simple_embedding.c
+++ b/samples/simple_embedding/simple_embedding.c
@@ -40,9 +40,9 @@
   IREE_RETURN_IF_ERROR(create_sample_device(iree_allocator_system(), &device),
                        "create device");
   iree_vm_module_t* hal_module = NULL;
-  IREE_RETURN_IF_ERROR(
-      iree_hal_module_create(instance, device, IREE_HAL_MODULE_FLAG_SYNCHRONOUS,
-                             iree_allocator_system(), &hal_module));
+  IREE_RETURN_IF_ERROR(iree_hal_module_create(
+      instance, /*device_count=*/1, &device, IREE_HAL_MODULE_FLAG_SYNCHRONOUS,
+      iree_allocator_system(), &hal_module));
 
   // Load bytecode module from the embedded data.
   const iree_const_byte_span_t module_data = load_bytecode_module_data();