[Stream] Selecting unified encodings from encoding resolvers. (#22898)

Add a new `getUnifiedEncoding` method to the `LayoutResolverAttr`
interface that returns a unified encoding given multiple candidate
encodings. This is used by the UnifyEncodingForGlobals pass to select an
appropriate encoding when the same source data has multiple encoded
versions.

Also refactor GlobalEncodingAnalyzer to:
- Set up the layout resolver from dialect interfaces internally
- Compute unified encodings as part of the analysis phase
- Provide a `getUnifiedEncoding(name)` getter for querying results

This simplifies the pass by moving analysis-related logic into the
analyzer class, making the pass focused on applying transformations.

Update lit tests to include device definitions and affinity attributes
on stream.tensor.* ops, which are required for the pass to resolve
layout attributes properly. The tests also switch from identity_resolver
to specialization_resolver, which improves test quality. The identity
encoding is used as the fallback solution.

This is a step towards https://github.com/iree-org/iree/issues/22485

---------

Signed-off-by: hanhanW <hanhan0912@gmail.com>
diff --git a/compiler/src/iree/compiler/Dialect/Encoding/IR/EncodingAttrs.cpp b/compiler/src/iree/compiler/Dialect/Encoding/IR/EncodingAttrs.cpp
index 945daa5..bbd9c60 100644
--- a/compiler/src/iree/compiler/Dialect/Encoding/IR/EncodingAttrs.cpp
+++ b/compiler/src/iree/compiler/Dialect/Encoding/IR/EncodingAttrs.cpp
@@ -600,6 +600,12 @@
   return Encoding::IdentityAttr::get(getContext());
 }
 
+Attribute
+IdentityResolverAttr::getUnifiedEncoding(ArrayRef<Attribute> encodings) const {
+  MLIRContext *ctx = getContext();
+  return Encoding::IdentityAttr::get(ctx);
+}
+
 Type IdentityResolverAttr::convertType(Type type) const {
   using IREE::TensorExt::DispatchTensorType;
   return TypeSwitch<Type, Type>(type)
@@ -673,6 +679,12 @@
                               TypeAttr::get(type.dropEncoding()));
 }
 
+Attribute SpecializationResolverAttr::getUnifiedEncoding(
+    ArrayRef<Attribute> encodings) const {
+  MLIRContext *ctx = getContext();
+  return SpecializedAttr::get(ctx, getSeed(), /*type=*/nullptr);
+}
+
 } // namespace mlir::iree_compiler::IREE::Encoding
 
 using namespace mlir::iree_compiler::IREE::Encoding;
diff --git a/compiler/src/iree/compiler/Dialect/Encoding/IR/EncodingAttrs.td b/compiler/src/iree/compiler/Dialect/Encoding/IR/EncodingAttrs.td
index ad45bce..df4c82d 100644
--- a/compiler/src/iree/compiler/Dialect/Encoding/IR/EncodingAttrs.td
+++ b/compiler/src/iree/compiler/Dialect/Encoding/IR/EncodingAttrs.td
@@ -278,6 +278,7 @@
       DeclareAttrInterfaceMethods<IREEEncoding_LayoutResolverAttr, [
         "cloneWithSimplifiedConfig",
         "getLayout",
+        "getUnifiedEncoding",
       ]>,
       DeclareAttrInterfaceMethods<IREEEncoding_LayoutMaterializerAttr, [
         "convertType",
@@ -361,6 +362,7 @@
       DeclareAttrInterfaceMethods<IREEEncoding_LayoutResolverAttr, [
         "cloneWithSimplifiedConfig",
         "getLayout",
+        "getUnifiedEncoding",
       ]>
     ]> {
   let mnemonic = "specialization_resolver";
diff --git a/compiler/src/iree/compiler/Dialect/Encoding/IR/EncodingInterfaces.td b/compiler/src/iree/compiler/Dialect/Encoding/IR/EncodingInterfaces.td
index 71f78b1..2514add 100644
--- a/compiler/src/iree/compiler/Dialect/Encoding/IR/EncodingInterfaces.td
+++ b/compiler/src/iree/compiler/Dialect/Encoding/IR/EncodingInterfaces.td
@@ -69,6 +69,33 @@
         assert(false && "unimplemented interface method");
         return {};
       }]
+    >,
+    InterfaceMethod<
+      /*desc=*/[{
+        Returns a unified encoding attribute given a list of candidate encoding
+        attributes. This is used when multiple encodings are used for the same
+        source data and need to be unified to a single encoding.
+
+        Returns nullptr if it fails to get a unified encoding.
+
+        The resolver implementation can create any encoding based on
+        backend-specific heuristics. The returned encoding does not need to be
+        one of the input encodings - for example, a resolver might return an
+        identity encoding for simplicity, or synthesize a new encoding that
+        introduces less relayout overhead.
+      }],
+      /*retTy=*/"::mlir::Attribute",
+      /*methodName=*/"getUnifiedEncoding",
+      /*args=*/(ins
+        "llvm::ArrayRef<::mlir::Attribute>":$encodings
+      ),
+      /*methodBody=*/"",
+      /*defaultImplementation=*/[{
+        // TODO(hanchung): Remove the assertion once we know what to do for the
+        // default implementation.
+        assert(false && "unimplemented interface method");
+        return {};
+      }]
     >
   ];
 }
diff --git a/compiler/src/iree/compiler/Dialect/Stream/Transforms/UnifyEncodingForGlobals.cpp b/compiler/src/iree/compiler/Dialect/Stream/Transforms/UnifyEncodingForGlobals.cpp
index f735158..1af1ea6 100644
--- a/compiler/src/iree/compiler/Dialect/Stream/Transforms/UnifyEncodingForGlobals.cpp
+++ b/compiler/src/iree/compiler/Dialect/Stream/Transforms/UnifyEncodingForGlobals.cpp
@@ -6,6 +6,7 @@
 
 #include "iree/compiler/Dialect/Encoding/IR/EncodingDialect.h"
 #include "iree/compiler/Dialect/Encoding/IR/EncodingTypes.h"
+#include "iree/compiler/Dialect/Stream/IR/StreamInterfaces.h"
 #include "iree/compiler/Dialect/Stream/IR/StreamOps.h"
 #include "iree/compiler/Dialect/Stream/IR/StreamTypes.h"
 #include "iree/compiler/Dialect/Stream/Transforms/Passes.h"
@@ -28,6 +29,29 @@
 
 namespace {
 
+/// Returns a stably sorted list of dialect interfaces of T for all dialects
+/// used within the given module.
+template <typename T>
+SmallVector<const T *> gatherUsedDialectInterfaces(mlir::ModuleOp moduleOp) {
+  SmallPtrSet<const T *, 4> resultSet;
+  for (Dialect *dialect : moduleOp.getContext()->getLoadedDialects()) {
+    const T *dialectInterface = dialect->getRegisteredInterface<T>();
+    if (!dialectInterface)
+      continue;
+    resultSet.insert(dialectInterface);
+  }
+
+  // NOTE: to ensure deterministic output we sort the result so that imports are
+  // always added in a consistent order.
+  auto results = llvm::to_vector_of<const T *>(resultSet);
+  llvm::sort(
+      results, +[](const T *a, const T *b) {
+        return a->getDialect()->getNamespace().compare(
+                   b->getDialect()->getNamespace()) < 0;
+      });
+  return results;
+}
+
 //===----------------------------------------------------------------------===//
 // Analysis.
 //===----------------------------------------------------------------------===//
@@ -64,13 +88,17 @@
 }
 
 // Analyzes a module to find immutable globals that have multiple encoded
-// versions. Use run() to perform analysis, then query results with
-// getSourcesWithMultipleEncodings() or getSourceGlobals().
+// versions, and computes unified encodings for them using the layout resolver
+// from the dialect interface. Use run() to perform analysis, then query results
+// with getSourcesWithMultipleEncodings(), getSourceGlobals(), or
+// getUnifiedEncoding().
 class GlobalEncodingAnalyzer {
 public:
   explicit GlobalEncodingAnalyzer(ModuleOp moduleOp)
       : moduleOp(moduleOp), symbolTable(moduleOp), globalTable(moduleOp) {}
 
+  // Runs the full analysis: sets up resolver, collects encodings, and computes
+  // unified encodings.
   LogicalResult run();
 
   // Returns all source globals that have multiple distinct encodings.
@@ -88,16 +116,22 @@
     return result;
   }
 
-  // Returns the SourceGlobalInfo for the given source global name, or
+  // Returns the SourceGlobalInfo for the given source global name. There is a
+  // copy in the call, so it is not a cheap call.
+  SourceGlobalInfo getSourceGlobals(StringRef name) {
+    return sourceGlobals.at(name);
+  }
+
+  // Returns the unified encoding for the given source global name, or
   // std::nullopt if not found.
-  std::optional<SourceGlobalInfo> getSourceGlobals(StringRef name) const {
-    if (sourceGlobals.contains(name)) {
-      return sourceGlobals.find(name)->second;
-    }
-    return std::nullopt;
+  Attribute getUnifiedEncoding(StringRef name) const {
+    return unifiedEncodings.at(name);
   }
 
 private:
+  // Sets up the layout resolver from dialect interfaces.
+  LogicalResult setupLayoutResolver();
+
   // Walks all initializers to find encoding patterns and populates
   // sourceGlobals map. Looks for patterns like:
   //   %source = util.global.load @source_global
@@ -106,6 +140,9 @@
   // Only considers immutable source and encoded globals.
   LogicalResult collectGlobalEncodings();
 
+  // Computes unified encodings for all source globals with multiple encodings.
+  LogicalResult computeUnifiedEncodings();
+
   // Traces from encode op's source operand back to a source global.
   // Returns nullptr if tracing fails or source is mutable.
   IREE::Util::GlobalOpInterface traceToSourceGlobal(Value value);
@@ -117,14 +154,23 @@
   SymbolTable symbolTable;
   IREE::Util::GlobalTable globalTable;
 
+  // Layout resolver function from dialect interface.
+  IREE::Stream::ResolveLayoutAttrFn resolveLayoutAttr;
+
   // Maps source global name to its info. Populated by run(). The global name
   // must match the name of `sourceGlobal` inside SourceGlobalInfo. StringRef is
   // used for easier lookup, which works better with SymbolTable, etc.
   llvm::MapVector<StringRef, SourceGlobalInfo> sourceGlobals;
+
+  // Maps source global name to its unified encoding. Populated by run().
+  llvm::StringMap<Attribute> unifiedEncodings;
 };
 
 LogicalResult GlobalEncodingAnalyzer::run() {
   LDBG() << "=== GlobalEncodingAnalyzer::run() ===";
+  if (failed(setupLayoutResolver())) {
+    return failure();
+  }
   globalTable.rebuild();
   if (failed(collectGlobalEncodings())) {
     return failure();
@@ -137,6 +183,93 @@
          << " encoded versions\n";
     }
   });
+
+  if (failed(computeUnifiedEncodings())) {
+    return failure();
+  }
+
+  return success();
+}
+
+LogicalResult GlobalEncodingAnalyzer::setupLayoutResolver() {
+  auto usedDialects = gatherUsedDialectInterfaces<
+      IREE::Stream::AffinityAnalysisDialectInterface>(moduleOp);
+  if (usedDialects.size() != 1) {
+    LDBG() << "Expected only one dialect implementing "
+              "AffinityAnalysisDialectInterface";
+    return failure();
+  }
+  resolveLayoutAttr = usedDialects[0]->makeLayoutAttrResolver(moduleOp);
+  return success();
+}
+
+LogicalResult GlobalEncodingAnalyzer::computeUnifiedEncodings() {
+  SmallVector<StringRef> candidates = getSourcesWithMultipleEncodings();
+  if (candidates.empty()) {
+    LDBG() << "No source globals with multiple encodings found.";
+    return success();
+  }
+
+  // Build queries for layout resolution.
+  SmallVector<IREE::Stream::AffinityAndOpPair> queries;
+  for (StringRef sourceName : candidates) {
+    SourceGlobalInfo sourceInfo = getSourceGlobals(sourceName);
+    for (EncodedGlobalInfo &encodedInfo : sourceInfo.encodedVersions) {
+      queries.push_back(
+          {encodedInfo.encodeOp.getAffinityAttr(), encodedInfo.encodedGlobal});
+    }
+  }
+
+  // Resolve layout attributes for all queries.
+  llvm::DenseMap<IREE::Stream::AffinityAndOpPair, SetVector<Attribute>>
+      cachedLayoutAttrs;
+  if (failed(resolveLayoutAttr(queries, cachedLayoutAttrs))) {
+    LDBG() << "Failed to resolve layouts for a query";
+    return failure();
+  }
+
+  // Compute unified encoding for each source global.
+  MLIRContext *ctx = moduleOp.getContext();
+  for (StringRef sourceName : candidates) {
+    SetVector<Attribute> layoutResolvers;
+    SmallVector<Attribute> encodingAttrVersions;
+    SourceGlobalInfo sourceInfo = getSourceGlobals(sourceName);
+    for (EncodedGlobalInfo &encodedInfo : sourceInfo.encodedVersions) {
+      const SetVector<Attribute> &resolvers =
+          cachedLayoutAttrs[IREE::Stream::AffinityAndOpPair(
+              encodedInfo.encodeOp.getAffinityAttr(),
+              encodedInfo.encodedGlobal)];
+      layoutResolvers.insert(resolvers.begin(), resolvers.end());
+      encodingAttrVersions.push_back(encodedInfo.encodingAttr);
+    }
+
+    // TODO: It is not clear which encoding to pick when there are multiple
+    // layout resolvers. For now, fall back to identity encoding for safety.
+    // A minor improvement would be to check whether all the resolvers return
+    // the same unified encoding and, if so, use it.
+    if (layoutResolvers.size() != 1) {
+      unifiedEncodings[sourceName] = IREE::Encoding::IdentityAttr::get(ctx);
+      continue;
+    }
+
+    // Invalid layout resolver, use identity encoding.
+    IREE::Encoding::LayoutResolverAttr layoutResolver =
+        dyn_cast<IREE::Encoding::LayoutResolverAttr>(layoutResolvers[0]);
+    if (!layoutResolver) {
+      unifiedEncodings[sourceName] = IREE::Encoding::IdentityAttr::get(ctx);
+      continue;
+    }
+
+    LDBG() << "Use encoding resolver " << layoutResolver
+           << " to unify encodings for source global: " << sourceName;
+    unifiedEncodings[sourceName] =
+        layoutResolver.getUnifiedEncoding(encodingAttrVersions);
+    // Fallback to identity encoding on failure.
+    if (!unifiedEncodings[sourceName]) {
+      unifiedEncodings[sourceName] = IREE::Encoding::IdentityAttr::get(ctx);
+    }
+  }
+
   return success();
 }
 
@@ -463,16 +596,12 @@
       LDBG() << "Analysis failed, skipping.";
       return;
     }
-    auto candidates = analyzer.getSourcesWithMultipleEncodings();
+    SmallVector<StringRef> candidates =
+        analyzer.getSourcesWithMultipleEncodings();
     if (candidates.empty()) {
       LDBG() << "No source globals with multiple encodings found.";
       return;
     }
-    LDBG() << "Found " << candidates.size()
-           << " source globals with multiple encodings:";
-    for (auto name : candidates) {
-      LDBG() << "  - " << name;
-    }
 
     // Unify encodings for each source global with multiple encodings, and cache
     // the updates.
@@ -480,21 +609,13 @@
     explorer.setOpAction<IREE::Stream::ExecutableOp>(TraversalAction::IGNORE);
     explorer.initialize();
     TensorEncodingUpdates tensorEncodingUpdates;
-    for (auto sourceName : candidates) {
-      std::optional<SourceGlobalInfo> sourceInfo =
-          analyzer.getSourceGlobals(sourceName);
-      if (!sourceInfo) {
-        LDBG() << "  ERROR: source global info not found for " << sourceName;
-        continue;
-      }
-
-      // TODO(#22485): Select unified encoding via resolver. For now, use
-      // identity encoding.
-      auto unifiedEncoding =
-          IREE::Encoding::IdentityAttr::get(moduleOp.getContext());
-
+    for (StringRef sourceName : candidates) {
+      SourceGlobalInfo sourceInfo = analyzer.getSourceGlobals(sourceName);
       // Update each encode op to use the unified encoding.
-      for (EncodedGlobalInfo &encodedInfo : sourceInfo->encodedVersions) {
+      Attribute unifiedEncoding = analyzer.getUnifiedEncoding(sourceName);
+      LDBG() << "Unifying encodings for source global: " << sourceName << " to "
+             << unifiedEncoding;
+      for (EncodedGlobalInfo &encodedInfo : sourceInfo.encodedVersions) {
         auto encodeOp = encodedInfo.encodeOp;
         auto oldResultType =
             cast<RankedTensorType>(encodeOp.getResultEncoding());
diff --git a/compiler/src/iree/compiler/Dialect/Stream/Transforms/test/unify_encoding_for_globals.mlir b/compiler/src/iree/compiler/Dialect/Stream/Transforms/test/unify_encoding_for_globals.mlir
index d520d11..4839c3f 100644
--- a/compiler/src/iree/compiler/Dialect/Stream/Transforms/test/unify_encoding_for_globals.mlir
+++ b/compiler/src/iree/compiler/Dialect/Stream/Transforms/test/unify_encoding_for_globals.mlir
@@ -3,11 +3,15 @@
 // Test: immutable source global (with initial value) with two encodings -
 // should unify to identity encoding.
 
+#executable_target_vmvx_bytecode_fb = #hal.executable.target<"vmvx", "vmvx-bytecode-fb", {iree.encoding.resolver = #iree_encoding.identity_resolver}>
+#device_target_local = #hal.device.target<"local", {ordinal = 0 : index}, [#executable_target_vmvx_bytecode_fb]> : !hal.device
 #encoding1 = #iree_encoding.testing<layouts = [#iree_encoding.specialized<123>]>
 #encoding2 = #iree_encoding.testing<layouts = [#iree_encoding.specialized<456>]>
 
 // CHECK-LABEL: module @immutable_source_with_initial_value
+//       CHECK:   util.global private @[[$DEVICE_A:.+]] =
 module @immutable_source_with_initial_value {
+  util.global private @device_a = #device_target_local
   util.global private @source = #stream.parameter.named<"model"::"weight"> : !stream.resource<constant>
   util.global private @encoded_v1 : !stream.resource<constant>
   util.global private @encoded_v2 : !stream.resource<constant>
@@ -18,18 +22,18 @@
     %source = util.global.load @source : !stream.resource<constant>
     %source_size = stream.resource.size %source : !stream.resource<constant>
 
-    // CHECK: stream.tensor.sizeof tensor<4096x4096xf32, #iree_encoding.identity>
-    // CHECK: stream.tensor.encode %[[SOURCE]] : {{.*}} -> tensor<4096x4096xf32, #iree_encoding.identity>
-    %size1 = stream.tensor.sizeof tensor<4096x4096xf32, #encoding1> : index
-    %enc1 = stream.tensor.encode %source : tensor<4096x4096xf32> in !stream.resource<constant>{%source_size} -> tensor<4096x4096xf32, #encoding1> in !stream.resource<*>{%size1}
-    %const1 = stream.async.clone %enc1 : !stream.resource<*>{%size1} -> !stream.resource<constant>{%size1}
+    // CHECK: stream.tensor.sizeof on(#hal.device.affinity<@[[$DEVICE_A]]>) tensor<4096x4096xf32, #iree_encoding.identity>
+    // CHECK: stream.tensor.encode on(#hal.device.affinity<@[[$DEVICE_A]]>) %[[SOURCE]] : {{.*}} -> tensor<4096x4096xf32, #iree_encoding.identity>
+    %size1 = stream.tensor.sizeof on(#hal.device.affinity<@device_a>) tensor<4096x4096xf32, #encoding1> : index
+    %enc1 = stream.tensor.encode on(#hal.device.affinity<@device_a>) %source : tensor<4096x4096xf32> in !stream.resource<constant>{%source_size} -> tensor<4096x4096xf32, #encoding1> in !stream.resource<*>{%size1}
+    %const1 = stream.async.clone on(#hal.device.affinity<@device_a>) %enc1 : !stream.resource<*>{%size1} -> !stream.resource<constant>{%size1}
     util.global.store %const1, @encoded_v1 : !stream.resource<constant>
 
-    // CHECK: stream.tensor.sizeof tensor<4096x4096xf32, #iree_encoding.identity>
-    // CHECK: stream.tensor.encode %[[SOURCE]] : {{.*}} -> tensor<4096x4096xf32, #iree_encoding.identity>
-    %size2 = stream.tensor.sizeof tensor<4096x4096xf32, #encoding2> : index
-    %enc2 = stream.tensor.encode %source : tensor<4096x4096xf32> in !stream.resource<constant>{%source_size} -> tensor<4096x4096xf32, #encoding2> in !stream.resource<*>{%size2}
-    %const2 = stream.async.clone %enc2 : !stream.resource<*>{%size2} -> !stream.resource<constant>{%size2}
+    // CHECK: stream.tensor.sizeof on(#hal.device.affinity<@[[$DEVICE_A]]>) tensor<4096x4096xf32, #iree_encoding.identity>
+    // CHECK: stream.tensor.encode on(#hal.device.affinity<@[[$DEVICE_A]]>) %[[SOURCE]] : {{.*}} -> tensor<4096x4096xf32, #iree_encoding.identity>
+    %size2 = stream.tensor.sizeof on(#hal.device.affinity<@device_a>) tensor<4096x4096xf32, #encoding2> : index
+    %enc2 = stream.tensor.encode on(#hal.device.affinity<@device_a>) %source : tensor<4096x4096xf32> in !stream.resource<constant>{%source_size} -> tensor<4096x4096xf32, #encoding2> in !stream.resource<*>{%size2}
+    %const2 = stream.async.clone on(#hal.device.affinity<@device_a>) %enc2 : !stream.resource<*>{%size2} -> !stream.resource<constant>{%size2}
     util.global.store %const2, @encoded_v2 : !stream.resource<constant>
 
     util.return
@@ -38,16 +42,206 @@
 
 // -----
 
+// Checks that the identity encoding is generated if no resolver is specified.
+
+#executable_target_vmvx_bytecode_fb = #hal.executable.target<"vmvx", "vmvx-bytecode-fb", {}>
+#device_target_local = #hal.device.target<"local", {ordinal = 0 : index}, [#executable_target_vmvx_bytecode_fb]> : !hal.device
+#encoding1 = #iree_encoding.testing<layouts = [#iree_encoding.specialized<456>]>
+#encoding2 = #iree_encoding.testing<layouts = [#iree_encoding.specialized<789>]>
+
+// CHECK: util.global private @[[$DEVICE_A:.+]] =
+util.global private @device_a = #device_target_local
+util.global private @weight : !stream.resource<constant>
+util.global private @weight_size : index
+util.global private @encoded_v1 : !stream.resource<constant>
+util.global private @encoded_v1_size : index
+util.global private @encoded_v2 : !stream.resource<constant>
+util.global private @encoded_v2_size : index
+
+// CHECK: util.initializer
+util.initializer {
+  %cst = stream.tensor.constant on(#hal.device.affinity<@device_a>) : tensor<4096x4096xf32> in !stream.resource<constant> = #stream.parameter.named<"model"::"weight"> : tensor<4096x4096xf32>
+  %0 = stream.resource.size %cst : !stream.resource<constant>
+  util.global.store %cst, @weight : !stream.resource<constant>
+  util.global.store %0, @weight_size : index
+  // CHECK: %[[SOURCE:.+]] = util.global.load @weight
+  %source = util.global.load @weight : !stream.resource<constant>
+  %source_size = util.global.load @weight_size : index
+
+  // CHECK: stream.tensor.sizeof on(#hal.device.affinity<@[[$DEVICE_A]]>) tensor<4096x4096xf32, #iree_encoding.identity>
+  // CHECK: stream.tensor.encode on(#hal.device.affinity<@[[$DEVICE_A]]>) %[[SOURCE]] : {{.*}} -> tensor<4096x4096xf32, #iree_encoding.identity>
+  %size1 = stream.tensor.sizeof on(#hal.device.affinity<@device_a>) tensor<4096x4096xf32, #encoding1> : index
+  %enc1 = stream.tensor.encode on(#hal.device.affinity<@device_a>) %source : tensor<4096x4096xf32> in !stream.resource<constant>{%source_size} -> tensor<4096x4096xf32, #encoding1> in !stream.resource<constant>{%size1}
+  util.global.store %enc1, @encoded_v1 : !stream.resource<constant>
+  util.global.store %size1, @encoded_v1_size : index
+
+  // CHECK: stream.tensor.sizeof on(#hal.device.affinity<@[[$DEVICE_A]]>) tensor<4096x4096xf32, #iree_encoding.identity>
+  // CHECK: stream.tensor.encode on(#hal.device.affinity<@[[$DEVICE_A]]>) %[[SOURCE]] : {{.*}} -> tensor<4096x4096xf32, #iree_encoding.identity>
+  %size2 = stream.tensor.sizeof on(#hal.device.affinity<@device_a>) tensor<4096x4096xf32, #encoding2> : index
+  %enc2 = stream.tensor.encode on(#hal.device.affinity<@device_a>) %source : tensor<4096x4096xf32> in !stream.resource<constant>{%source_size} -> tensor<4096x4096xf32, #encoding2> in !stream.resource<constant>{%size2}
+  util.global.store %enc2, @encoded_v2 : !stream.resource<constant>
+  util.global.store %size2, @encoded_v2_size : index
+
+  util.return
+}
+
+// -----
+
+// Test: multiple devices with different resolvers encoding the same source global.
+// Since different resolvers produce different encodings and they share the same source,
+// there's no common encoding - should fall back to identity encoding.
+
+#executable_target_0 = #hal.executable.target<"vmvx", "vmvx-bytecode-fb", {iree.encoding.resolver = #iree_encoding.specialization_resolver<123>}>
+#executable_target_1 = #hal.executable.target<"vmvx", "vmvx-bytecode-fb", {iree.encoding.resolver = #iree_encoding.specialization_resolver<456>}>
+#device_target_local_0 = #hal.device.target<"local", {ordinal = 0 : index}, [#executable_target_0]> : !hal.device
+#device_target_local_1 = #hal.device.target<"local", {ordinal = 1 : index}, [#executable_target_1]> : !hal.device
+#encoding1 = #iree_encoding.testing<layouts = [#iree_encoding.specialized<111>]>
+#encoding2 = #iree_encoding.testing<layouts = [#iree_encoding.specialized<222>]>
+
+// CHECK: util.global private @[[$DEVICE_A:.+]] =
+// CHECK: util.global private @[[$DEVICE_B:.+]] =
+util.global private @device_a = #device_target_local_0
+util.global private @device_b = #device_target_local_1
+// Single source global shared by both devices.
+util.global private @weight : !stream.resource<constant>
+util.global private @weight_size : index
+util.global private @encoded_a_v1 : !stream.resource<constant>
+util.global private @encoded_a_v2 : !stream.resource<constant>
+util.global private @encoded_b_v1 : !stream.resource<constant>
+util.global private @encoded_b_v2 : !stream.resource<constant>
+
+// CHECK: util.initializer
+util.initializer {
+  %cst = stream.tensor.constant on(#hal.device.affinity<@device_a>) : tensor<4096x4096xf32> in !stream.resource<constant> = #stream.parameter.named<"model"::"weight"> : tensor<4096x4096xf32>
+  %0 = stream.resource.size %cst : !stream.resource<constant>
+  util.global.store %cst, @weight : !stream.resource<constant>
+  util.global.store %0, @weight_size : index
+
+  // CHECK: %[[SOURCE:.+]] = util.global.load @weight
+  %source = util.global.load @weight : !stream.resource<constant>
+  %source_size = util.global.load @weight_size : index
+
+  // Device A encodes the shared source - should get identity encoding.
+  // CHECK: stream.tensor.sizeof on(#hal.device.affinity<@[[$DEVICE_A]]>) tensor<4096x4096xf32, #iree_encoding.identity>
+  // CHECK: stream.tensor.encode on(#hal.device.affinity<@[[$DEVICE_A]]>) %[[SOURCE]] : {{.*}} -> tensor<4096x4096xf32, #iree_encoding.identity>
+  %size1 = stream.tensor.sizeof on(#hal.device.affinity<@device_a>) tensor<4096x4096xf32, #encoding1> : index
+  %enc1 = stream.tensor.encode on(#hal.device.affinity<@device_a>) %source : tensor<4096x4096xf32> in !stream.resource<constant>{%source_size} -> tensor<4096x4096xf32, #encoding1> in !stream.resource<constant>{%size1}
+  util.global.store %enc1, @encoded_a_v1 : !stream.resource<constant>
+
+  // CHECK: stream.tensor.sizeof on(#hal.device.affinity<@[[$DEVICE_A]]>) tensor<4096x4096xf32, #iree_encoding.identity>
+  // CHECK: stream.tensor.encode on(#hal.device.affinity<@[[$DEVICE_A]]>) %[[SOURCE]] : {{.*}} -> tensor<4096x4096xf32, #iree_encoding.identity>
+  %size2 = stream.tensor.sizeof on(#hal.device.affinity<@device_a>) tensor<4096x4096xf32, #encoding2> : index
+  %enc2 = stream.tensor.encode on(#hal.device.affinity<@device_a>) %source : tensor<4096x4096xf32> in !stream.resource<constant>{%source_size} -> tensor<4096x4096xf32, #encoding2> in !stream.resource<constant>{%size2}
+  util.global.store %enc2, @encoded_a_v2 : !stream.resource<constant>
+
+  // Device B encodes the same shared source - should also get identity encoding.
+  // CHECK: stream.tensor.sizeof on(#hal.device.affinity<@[[$DEVICE_B]]>) tensor<4096x4096xf32, #iree_encoding.identity>
+  // CHECK: stream.tensor.encode on(#hal.device.affinity<@[[$DEVICE_B]]>) %[[SOURCE]] : {{.*}} -> tensor<4096x4096xf32, #iree_encoding.identity>
+  %size3 = stream.tensor.sizeof on(#hal.device.affinity<@device_b>) tensor<4096x4096xf32, #encoding1> : index
+  %enc3 = stream.tensor.encode on(#hal.device.affinity<@device_b>) %source : tensor<4096x4096xf32> in !stream.resource<constant>{%source_size} -> tensor<4096x4096xf32, #encoding1> in !stream.resource<constant>{%size3}
+  util.global.store %enc3, @encoded_b_v1 : !stream.resource<constant>
+
+  // CHECK: stream.tensor.sizeof on(#hal.device.affinity<@[[$DEVICE_B]]>) tensor<4096x4096xf32, #iree_encoding.identity>
+  // CHECK: stream.tensor.encode on(#hal.device.affinity<@[[$DEVICE_B]]>) %[[SOURCE]] : {{.*}} -> tensor<4096x4096xf32, #iree_encoding.identity>
+  %size4 = stream.tensor.sizeof on(#hal.device.affinity<@device_b>) tensor<4096x4096xf32, #encoding2> : index
+  %enc4 = stream.tensor.encode on(#hal.device.affinity<@device_b>) %source : tensor<4096x4096xf32> in !stream.resource<constant>{%source_size} -> tensor<4096x4096xf32, #encoding2> in !stream.resource<constant>{%size4}
+  util.global.store %enc4, @encoded_b_v2 : !stream.resource<constant>
+
+  util.return
+}
+
+// -----
+
+// Test: multiple devices with different resolvers - each device should use its own resolver.
+
+#executable_target_0 = #hal.executable.target<"vmvx", "vmvx-bytecode-fb", {iree.encoding.resolver = #iree_encoding.specialization_resolver<123>}>
+#executable_target_1 = #hal.executable.target<"vmvx", "vmvx-bytecode-fb", {iree.encoding.resolver = #iree_encoding.specialization_resolver<456>}>
+#device_target_local_0 = #hal.device.target<"local", {ordinal = 0 : index}, [#executable_target_0]> : !hal.device
+#device_target_local_1 = #hal.device.target<"local", {ordinal = 1 : index}, [#executable_target_1]> : !hal.device
+#encoding1 = #iree_encoding.testing<layouts = [#iree_encoding.specialized<111>]>
+#encoding2 = #iree_encoding.testing<layouts = [#iree_encoding.specialized<222>]>
+
+// CHECK: util.global private @[[$DEVICE_A:.+]] =
+// CHECK: util.global private @[[$DEVICE_B:.+]] =
+util.global private @device_a = #device_target_local_0
+util.global private @device_b = #device_target_local_1
+// Two separate source globals - one for each device.
+util.global private @weight_a : !stream.resource<constant>
+util.global private @weight_a_size : index
+util.global private @weight_b : !stream.resource<constant>
+util.global private @weight_b_size : index
+util.global private @encoded_a_v1 : !stream.resource<constant>
+util.global private @encoded_a_v2 : !stream.resource<constant>
+util.global private @encoded_b_v1 : !stream.resource<constant>
+util.global private @encoded_b_v2 : !stream.resource<constant>
+
+// CHECK: util.initializer
+util.initializer {
+  // Initialize weight_a for device_a.
+  %cst_a = stream.tensor.constant on(#hal.device.affinity<@device_a>) : tensor<4096x4096xf32> in !stream.resource<constant> = #stream.parameter.named<"model"::"weight_a"> : tensor<4096x4096xf32>
+  %size_a = stream.resource.size %cst_a : !stream.resource<constant>
+  util.global.store %cst_a, @weight_a : !stream.resource<constant>
+  util.global.store %size_a, @weight_a_size : index
+
+  // Initialize weight_b for device_b.
+  %cst_b = stream.tensor.constant on(#hal.device.affinity<@device_b>) : tensor<4096x4096xf32> in !stream.resource<constant> = #stream.parameter.named<"model"::"weight_b"> : tensor<4096x4096xf32>
+  %size_b = stream.resource.size %cst_b : !stream.resource<constant>
+  util.global.store %cst_b, @weight_b : !stream.resource<constant>
+  util.global.store %size_b, @weight_b_size : index
+
+  // CHECK: %[[SOURCE_A:.+]] = util.global.load @weight_a
+  %source_a = util.global.load @weight_a : !stream.resource<constant>
+  %source_a_size = util.global.load @weight_a_size : index
+
+  // CHECK: %[[SOURCE_B:.+]] = util.global.load @weight_b
+  %source_b = util.global.load @weight_b : !stream.resource<constant>
+  %source_b_size = util.global.load @weight_b_size : index
+
+  // Device A encodes weight_a with specialization_resolver<123>.
+  // CHECK: stream.tensor.sizeof on(#hal.device.affinity<@[[$DEVICE_A]]>) tensor<4096x4096xf32, #iree_encoding.specialized<123>>
+  // CHECK: stream.tensor.encode on(#hal.device.affinity<@[[$DEVICE_A]]>) %[[SOURCE_A]] : {{.*}} -> tensor<4096x4096xf32, #iree_encoding.specialized<123>>
+  %size1 = stream.tensor.sizeof on(#hal.device.affinity<@device_a>) tensor<4096x4096xf32, #encoding1> : index
+  %enc1 = stream.tensor.encode on(#hal.device.affinity<@device_a>) %source_a : tensor<4096x4096xf32> in !stream.resource<constant>{%source_a_size} -> tensor<4096x4096xf32, #encoding1> in !stream.resource<constant>{%size1}
+  util.global.store %enc1, @encoded_a_v1 : !stream.resource<constant>
+
+  // CHECK: stream.tensor.sizeof on(#hal.device.affinity<@[[$DEVICE_A]]>) tensor<4096x4096xf32, #iree_encoding.specialized<123>>
+  // CHECK: stream.tensor.encode on(#hal.device.affinity<@[[$DEVICE_A]]>) %[[SOURCE_A]] : {{.*}} -> tensor<4096x4096xf32, #iree_encoding.specialized<123>>
+  %size2 = stream.tensor.sizeof on(#hal.device.affinity<@device_a>) tensor<4096x4096xf32, #encoding2> : index
+  %enc2 = stream.tensor.encode on(#hal.device.affinity<@device_a>) %source_a : tensor<4096x4096xf32> in !stream.resource<constant>{%source_a_size} -> tensor<4096x4096xf32, #encoding2> in !stream.resource<constant>{%size2}
+  util.global.store %enc2, @encoded_a_v2 : !stream.resource<constant>
+
+  // Device B encodes weight_b with specialization_resolver<456>.
+  // CHECK: stream.tensor.sizeof on(#hal.device.affinity<@[[$DEVICE_B]]>) tensor<4096x4096xf32, #iree_encoding.specialized<456>>
+  // CHECK: stream.tensor.encode on(#hal.device.affinity<@[[$DEVICE_B]]>) %[[SOURCE_B]] : {{.*}} -> tensor<4096x4096xf32, #iree_encoding.specialized<456>>
+  %size3 = stream.tensor.sizeof on(#hal.device.affinity<@device_b>) tensor<4096x4096xf32, #encoding1> : index
+  %enc3 = stream.tensor.encode on(#hal.device.affinity<@device_b>) %source_b : tensor<4096x4096xf32> in !stream.resource<constant>{%source_b_size} -> tensor<4096x4096xf32, #encoding1> in !stream.resource<constant>{%size3}
+  util.global.store %enc3, @encoded_b_v1 : !stream.resource<constant>
+
+  // CHECK: stream.tensor.sizeof on(#hal.device.affinity<@[[$DEVICE_B]]>) tensor<4096x4096xf32, #iree_encoding.specialized<456>>
+  // CHECK: stream.tensor.encode on(#hal.device.affinity<@[[$DEVICE_B]]>) %[[SOURCE_B]] : {{.*}} -> tensor<4096x4096xf32, #iree_encoding.specialized<456>>
+  %size4 = stream.tensor.sizeof on(#hal.device.affinity<@device_b>) tensor<4096x4096xf32, #encoding2> : index
+  %enc4 = stream.tensor.encode on(#hal.device.affinity<@device_b>) %source_b : tensor<4096x4096xf32> in !stream.resource<constant>{%source_b_size} -> tensor<4096x4096xf32, #encoding2> in !stream.resource<constant>{%size4}
+  util.global.store %enc4, @encoded_b_v2 : !stream.resource<constant>
+
+  util.return
+}
+
+// -----
+
 // Test: immutable source global (initialized from parameter in initializer) with
 // two encodings - should unify to identity encoding.
 // This test also verifies that stream.async.clone between load and encode is
 // properly traced through (matching real-world patterns from input pipelines).
 
-#encoding1 = #iree_encoding.testing<layouts = [#iree_encoding.specialized<123>]>
-#encoding2 = #iree_encoding.testing<layouts = [#iree_encoding.specialized<456>]>
+#executable_target_vmvx_bytecode_fb = #hal.executable.target<"vmvx", "vmvx-bytecode-fb", {iree.encoding.resolver = #iree_encoding.specialization_resolver<123>}>
+#device_target_local = #hal.device.target<"local", {ordinal = 0 : index}, [#executable_target_vmvx_bytecode_fb]> : !hal.device
+#encoding1 = #iree_encoding.testing<layouts = [#iree_encoding.specialized<456>]>
+#encoding2 = #iree_encoding.testing<layouts = [#iree_encoding.specialized<789>]>
 
 // CHECK-LABEL: module @immutable_source_initialized_from_parameter
+//       CHECK:   util.global private @[[$DEVICE_A:.+]] =
 module @immutable_source_initialized_from_parameter {
+  util.global private @device_a = #device_target_local
   util.global private @weight : !stream.resource<constant>
   util.global private @weight_size : index
   util.global private @encoded_v1 : !stream.resource<constant>
@@ -57,7 +251,7 @@
 
   // CHECK: util.initializer
   util.initializer {
-    %cst = stream.tensor.constant : tensor<4096x4096xf8E4M3FNUZ> in !stream.resource<constant> = #stream.parameter.named<"model"::"weight"> : tensor<4096x4096xf32>
+    %cst = stream.tensor.constant on(#hal.device.affinity<@device_a>) : tensor<4096x4096xf32> in !stream.resource<constant> = #stream.parameter.named<"model"::"weight"> : tensor<4096x4096xf32>
     %0 = stream.resource.size %cst : !stream.resource<constant>
     util.global.store %cst, @weight : !stream.resource<constant>
     util.global.store %0, @weight_size : index
@@ -66,25 +260,25 @@
     %source_size = util.global.load @weight_size : index
 
     // Clone before encode (common pattern in real pipelines).
-    // CHECK: %[[CLONE1:.+]] = stream.async.clone %[[SOURCE]]
-    %cloned1 = stream.async.clone %source : !stream.resource<constant>{%source_size} -> !stream.resource<*>{%source_size}
+    // CHECK: %[[CLONE1:.+]] = stream.async.clone on(#hal.device.affinity<@[[$DEVICE_A]]>) %[[SOURCE]]
+    %cloned1 = stream.async.clone on(#hal.device.affinity<@device_a>) %source : !stream.resource<constant>{%source_size} -> !stream.resource<*>{%source_size}
 
-    // CHECK: stream.tensor.sizeof tensor<4096x4096xf32, #iree_encoding.identity>
-    // CHECK: stream.tensor.encode %[[CLONE1]] : {{.*}} -> tensor<4096x4096xf32, #iree_encoding.identity>
-    %size1 = stream.tensor.sizeof tensor<4096x4096xf32, #encoding1> : index
-    %enc1 = stream.tensor.encode %cloned1 : tensor<4096x4096xf32> in !stream.resource<*>{%source_size} -> tensor<4096x4096xf32, #encoding1> in !stream.resource<*>{%size1}
-    %const1 = stream.async.clone %enc1 : !stream.resource<*>{%size1} -> !stream.resource<constant>{%size1}
+    // CHECK: stream.tensor.sizeof on(#hal.device.affinity<@[[$DEVICE_A]]>) tensor<4096x4096xf32, #iree_encoding.specialized<123>>
+    // CHECK: stream.tensor.encode on(#hal.device.affinity<@[[$DEVICE_A]]>) %[[CLONE1]] : {{.*}} -> tensor<4096x4096xf32, #iree_encoding.specialized<123>>
+    %size1 = stream.tensor.sizeof on(#hal.device.affinity<@device_a>) tensor<4096x4096xf32, #encoding1> : index
+    %enc1 = stream.tensor.encode on(#hal.device.affinity<@device_a>) %cloned1 : tensor<4096x4096xf32> in !stream.resource<*>{%source_size} -> tensor<4096x4096xf32, #encoding1> in !stream.resource<*>{%size1}
+    %const1 = stream.async.clone on(#hal.device.affinity<@device_a>) %enc1 : !stream.resource<*>{%size1} -> !stream.resource<constant>{%size1}
     util.global.store %const1, @encoded_v1 : !stream.resource<constant>
     util.global.store %size1, @encoded_v1_size : index
 
-    // CHECK: %[[CLONE2:.+]] = stream.async.clone %[[SOURCE]]
-    %cloned2 = stream.async.clone %source : !stream.resource<constant>{%source_size} -> !stream.resource<*>{%source_size}
+    // CHECK: %[[CLONE2:.+]] = stream.async.clone on(#hal.device.affinity<@[[$DEVICE_A]]>) %[[SOURCE]]
+    %cloned2 = stream.async.clone on(#hal.device.affinity<@device_a>) %source : !stream.resource<constant>{%source_size} -> !stream.resource<*>{%source_size}
 
-    // CHECK: stream.tensor.sizeof tensor<4096x4096xf32, #iree_encoding.identity>
-    // CHECK: stream.tensor.encode %[[CLONE2]] : {{.*}} -> tensor<4096x4096xf32, #iree_encoding.identity>
-    %size2 = stream.tensor.sizeof tensor<4096x4096xf32, #encoding2> : index
-    %enc2 = stream.tensor.encode %cloned2 : tensor<4096x4096xf32> in !stream.resource<*>{%source_size} -> tensor<4096x4096xf32, #encoding2> in !stream.resource<*>{%size2}
-    %const2 = stream.async.clone %enc2 : !stream.resource<*>{%size2} -> !stream.resource<constant>{%size2}
+    // CHECK: stream.tensor.sizeof on(#hal.device.affinity<@[[$DEVICE_A]]>) tensor<4096x4096xf32, #iree_encoding.specialized<123>>
+    // CHECK: stream.tensor.encode on(#hal.device.affinity<@[[$DEVICE_A]]>) %[[CLONE2]] : {{.*}} -> tensor<4096x4096xf32, #iree_encoding.specialized<123>>
+    %size2 = stream.tensor.sizeof on(#hal.device.affinity<@device_a>) tensor<4096x4096xf32, #encoding2> : index
+    %enc2 = stream.tensor.encode on(#hal.device.affinity<@device_a>) %cloned2 : tensor<4096x4096xf32> in !stream.resource<*>{%source_size} -> tensor<4096x4096xf32, #encoding2> in !stream.resource<*>{%size2}
+    %const2 = stream.async.clone on(#hal.device.affinity<@device_a>) %enc2 : !stream.resource<*>{%size2} -> !stream.resource<constant>{%size2}
     util.global.store %const2, @encoded_v2 : !stream.resource<constant>
     util.global.store %size2, @encoded_v2_size : index
 
@@ -136,7 +330,7 @@
     %encoded = util.global.load @encoded_v1 : !stream.resource<constant>
     %encoded_size = util.global.load @encoded_v1_size : index
     // CHECK:      stream.tensor.dispatch
-    // CHECK-SAME:   tensor<4096x4096xf32, #iree_encoding.identity>
+    // CHECK-SAME:   tensor<4096x4096xf32, #iree_encoding.specialized<123>>
     %result = stream.tensor.dispatch @executable_v1::@dispatch(%encoded)
       : (tensor<4096x4096xf32, #encoding1> in !stream.resource<constant>{%encoded_size})
       -> tensor<16xf32> in !stream.resource<*>{%arg0}
@@ -148,7 +342,7 @@
     %encoded = util.global.load @encoded_v2 : !stream.resource<constant>
     %encoded_size = util.global.load @encoded_v2_size : index
     // CHECK:      stream.tensor.dispatch
-    // CHECK-SAME:   tensor<4096x4096xf32, #iree_encoding.identity>
+    // CHECK-SAME:   tensor<4096x4096xf32, #iree_encoding.specialized<123>>
     %result = stream.tensor.dispatch @executable_v2::@dispatch(%encoded)
       : (tensor<4096x4096xf32, #encoding2> in !stream.resource<constant>{%encoded_size})
       -> tensor<16xf32> in !stream.resource<*>{%arg0}
@@ -173,8 +367,8 @@
     %cloned_v2 = stream.async.clone %encoded_v2 : !stream.resource<constant>{%encoded_v2_size} -> !stream.resource<*>{%encoded_v2_size}
 
     // CHECK:      stream.tensor.dispatch @executable_both::@dispatch
-    // CHECK-SAME:   tensor<4096x4096xf32, #iree_encoding.identity>
-    // CHECK-SAME:   tensor<4096x4096xf32, #iree_encoding.identity>
+    // CHECK-SAME:   tensor<4096x4096xf32, #iree_encoding.specialized<123>>
+    // CHECK-SAME:   tensor<4096x4096xf32, #iree_encoding.specialized<123>>
     %result = stream.tensor.dispatch @executable_both::@dispatch[%c16, %c32](%cloned_v1, %c0, %cloned_v2, %c1) : (tensor<4096x4096xf32, #encoding1> in !stream.resource<*>{%encoded_v1_size}, index, tensor<4096x4096xf32, #encoding2> in !stream.resource<*>{%encoded_v2_size}, index) -> tensor<16xf32> in !stream.resource<*>{%arg0}
 
     util.return %result : !stream.resource<*>
@@ -186,11 +380,15 @@
 // Test: cross-function tracking - load encoded global, pass to callee via
 // util.call, and verify dispatch site encoding is updated in callee.
 
-#encoding1 = #iree_encoding.testing<layouts = [#iree_encoding.specialized<123>]>
-#encoding2 = #iree_encoding.testing<layouts = [#iree_encoding.specialized<456>]>
+#executable_target_vmvx_bytecode_fb = #hal.executable.target<"vmvx", "vmvx-bytecode-fb", {iree.encoding.resolver = #iree_encoding.specialization_resolver<123>}>
+#device_target_local = #hal.device.target<"local", {ordinal = 0 : index}, [#executable_target_vmvx_bytecode_fb]> : !hal.device
+#encoding1 = #iree_encoding.testing<layouts = [#iree_encoding.specialized<456>]>
+#encoding2 = #iree_encoding.testing<layouts = [#iree_encoding.specialized<789>]>
 
 // CHECK-LABEL: module @cross_function_tracking
+//       CHECK:   util.global private @[[$DEVICE_A:.+]] =
 module @cross_function_tracking {
+  util.global private @device_a = #device_target_local
   util.global private @source = #stream.parameter.named<"model"::"weight"> : !stream.resource<constant>
   util.global private @encoded_v1 : !stream.resource<constant>
   util.global private @encoded_v1_size : index
@@ -201,19 +399,19 @@
     %source = util.global.load @source : !stream.resource<constant>
     %source_size = stream.resource.size %source : !stream.resource<constant>
 
-    // CHECK: stream.tensor.sizeof tensor<4096x4096xf32, #iree_encoding.identity>
-    // CHECK: stream.tensor.encode {{.*}} -> tensor<4096x4096xf32, #iree_encoding.identity>
-    %size1 = stream.tensor.sizeof tensor<4096x4096xf32, #encoding1> : index
-    %enc1 = stream.tensor.encode %source : tensor<4096x4096xf32> in !stream.resource<constant>{%source_size} -> tensor<4096x4096xf32, #encoding1> in !stream.resource<*>{%size1}
-    %const1 = stream.async.clone %enc1 : !stream.resource<*>{%size1} -> !stream.resource<constant>{%size1}
+    // CHECK: stream.tensor.sizeof on(#hal.device.affinity<@[[$DEVICE_A]]>) tensor<4096x4096xf32, #iree_encoding.specialized<123>>
+    // CHECK: stream.tensor.encode on(#hal.device.affinity<@[[$DEVICE_A]]>) {{.*}} -> tensor<4096x4096xf32, #iree_encoding.specialized<123>>
+    %size1 = stream.tensor.sizeof on(#hal.device.affinity<@device_a>) tensor<4096x4096xf32, #encoding1> : index
+    %enc1 = stream.tensor.encode on(#hal.device.affinity<@device_a>) %source : tensor<4096x4096xf32> in !stream.resource<constant>{%source_size} -> tensor<4096x4096xf32, #encoding1> in !stream.resource<*>{%size1}
+    %const1 = stream.async.clone on(#hal.device.affinity<@device_a>) %enc1 : !stream.resource<*>{%size1} -> !stream.resource<constant>{%size1}
     util.global.store %const1, @encoded_v1 : !stream.resource<constant>
     util.global.store %size1, @encoded_v1_size : index
 
-    // CHECK: stream.tensor.sizeof tensor<4096x4096xf32, #iree_encoding.identity>
-    // CHECK: stream.tensor.encode {{.*}} -> tensor<4096x4096xf32, #iree_encoding.identity>
-    %size2 = stream.tensor.sizeof tensor<4096x4096xf32, #encoding2> : index
-    %enc2 = stream.tensor.encode %source : tensor<4096x4096xf32> in !stream.resource<constant>{%source_size} -> tensor<4096x4096xf32, #encoding2> in !stream.resource<*>{%size2}
-    %const2 = stream.async.clone %enc2 : !stream.resource<*>{%size2} -> !stream.resource<constant>{%size2}
+    // CHECK: stream.tensor.sizeof on(#hal.device.affinity<@[[$DEVICE_A]]>) tensor<4096x4096xf32, #iree_encoding.specialized<123>>
+    // CHECK: stream.tensor.encode on(#hal.device.affinity<@[[$DEVICE_A]]>) {{.*}} -> tensor<4096x4096xf32, #iree_encoding.specialized<123>>
+    %size2 = stream.tensor.sizeof on(#hal.device.affinity<@device_a>) tensor<4096x4096xf32, #encoding2> : index
+    %enc2 = stream.tensor.encode on(#hal.device.affinity<@device_a>) %source : tensor<4096x4096xf32> in !stream.resource<constant>{%source_size} -> tensor<4096x4096xf32, #encoding2> in !stream.resource<*>{%size2}
+    %const2 = stream.async.clone on(#hal.device.affinity<@device_a>) %enc2 : !stream.resource<*>{%size2} -> !stream.resource<constant>{%size2}
     util.global.store %const2, @encoded_v2 : !stream.resource<constant>
     util.global.store %size2, @encoded_v2_size : index
 
@@ -237,8 +435,8 @@
   // CHECK-LABEL: util.func private @dispatch_helper
   util.func private @dispatch_helper(%arg0: !stream.resource<*>, %arg1: index, %arg2: !stream.resource<*>, %arg3: index, %arg4: index) -> !stream.resource<*> {
     // CHECK:      stream.tensor.dispatch @executable_both::@dispatch
-    // CHECK-SAME:   tensor<4096x4096xf32, #iree_encoding.identity>
-    // CHECK-SAME:   tensor<4096x4096xf32, #iree_encoding.identity>
+    // CHECK-SAME:   tensor<4096x4096xf32, #iree_encoding.specialized<123>>
+    // CHECK-SAME:   tensor<4096x4096xf32, #iree_encoding.specialized<123>>
     %result = stream.tensor.dispatch @executable_both::@dispatch(%arg0, %arg2) : (tensor<4096x4096xf32, #encoding1> in !stream.resource<*>{%arg1}, tensor<4096x4096xf32, #encoding2> in !stream.resource<*>{%arg3}) -> tensor<16xf32> in !stream.resource<*>{%arg4}
     util.return %result : !stream.resource<*>
   }
@@ -257,10 +455,18 @@
 
 // -----
 
-// CHECK: #[[$ENC:.+]] = #iree_encoding.testing<layouts = [#iree_encoding.specialized<123>]>
-#encoding1 = #iree_encoding.testing<layouts = [#iree_encoding.specialized<123>]>
-#encoding2 = #iree_encoding.testing<layouts = [#iree_encoding.specialized<456>]>
+// Test tied operand: dispatch result tied to input operand.
+// When encoding changes, both operand and result encoding must change.
+// A re-encode op should be inserted after dispatch.
 
+#executable_target_vmvx_bytecode_fb = #hal.executable.target<"vmvx", "vmvx-bytecode-fb", {iree.encoding.resolver = #iree_encoding.specialization_resolver<123>}>
+#device_target_local = #hal.device.target<"local", {ordinal = 0 : index}, [#executable_target_vmvx_bytecode_fb]> : !hal.device
+// CHECK: #[[$ENC:.+]] = #iree_encoding.testing<layouts = [#iree_encoding.specialized<456>]>
+#encoding1 = #iree_encoding.testing<layouts = [#iree_encoding.specialized<456>]>
+#encoding2 = #iree_encoding.testing<layouts = [#iree_encoding.specialized<789>]>
+
+// CHECK: util.global private @[[$DEVICE_A:.+]] =
+util.global private @device_a = #device_target_local
 util.global private @weight : !stream.resource<constant>
 util.global private @weight_size : index
 util.global private @encoded_v1 : !stream.resource<constant>
@@ -270,7 +476,7 @@
 
 // CHECK: util.initializer
 util.initializer {
-  %cst = stream.tensor.constant : tensor<4096x4096xf32> in !stream.resource<constant> = #stream.parameter.named<"model"::"weight"> : tensor<4096x4096xf32>
+  %cst = stream.tensor.constant on(#hal.device.affinity<@device_a>) : tensor<4096x4096xf32> in !stream.resource<constant> = #stream.parameter.named<"model"::"weight"> : tensor<4096x4096xf32>
   %0 = stream.resource.size %cst : !stream.resource<constant>
   util.global.store %cst, @weight : !stream.resource<constant>
   util.global.store %0, @weight_size : index
@@ -278,17 +484,17 @@
   %source = util.global.load @weight : !stream.resource<constant>
   %source_size = util.global.load @weight_size : index
 
-  // CHECK: stream.tensor.sizeof tensor<4096x4096xf32, #iree_encoding.identity>
-  // CHECK: stream.tensor.encode %[[SOURCE]] : {{.*}} -> tensor<4096x4096xf32, #iree_encoding.identity>
-  %size1 = stream.tensor.sizeof tensor<4096x4096xf32, #encoding1> : index
-  %enc1 = stream.tensor.encode %source : tensor<4096x4096xf32> in !stream.resource<constant>{%source_size} -> tensor<4096x4096xf32, #encoding1> in !stream.resource<constant>{%size1}
+  // CHECK: stream.tensor.sizeof on(#hal.device.affinity<@[[$DEVICE_A]]>) tensor<4096x4096xf32, #iree_encoding.specialized<123>>
+  // CHECK: stream.tensor.encode on(#hal.device.affinity<@[[$DEVICE_A]]>) %[[SOURCE]] : {{.*}} -> tensor<4096x4096xf32, #iree_encoding.specialized<123>>
+  %size1 = stream.tensor.sizeof on(#hal.device.affinity<@device_a>) tensor<4096x4096xf32, #encoding1> : index
+  %enc1 = stream.tensor.encode on(#hal.device.affinity<@device_a>) %source : tensor<4096x4096xf32> in !stream.resource<constant>{%source_size} -> tensor<4096x4096xf32, #encoding1> in !stream.resource<constant>{%size1}
   util.global.store %enc1, @encoded_v1 : !stream.resource<constant>
   util.global.store %size1, @encoded_v1_size : index
 
-  // CHECK: stream.tensor.sizeof tensor<4096x4096xf32, #iree_encoding.identity>
-  // CHECK: stream.tensor.encode %[[SOURCE]] : {{.*}} -> tensor<4096x4096xf32, #iree_encoding.identity>
-  %size2 = stream.tensor.sizeof tensor<4096x4096xf32, #encoding2> : index
-  %enc2 = stream.tensor.encode %source : tensor<4096x4096xf32> in !stream.resource<constant>{%source_size} -> tensor<4096x4096xf32, #encoding2> in !stream.resource<constant>{%size2}
+  // CHECK: stream.tensor.sizeof on(#hal.device.affinity<@[[$DEVICE_A]]>) tensor<4096x4096xf32, #iree_encoding.specialized<123>>
+  // CHECK: stream.tensor.encode on(#hal.device.affinity<@[[$DEVICE_A]]>) %[[SOURCE]] : {{.*}} -> tensor<4096x4096xf32, #iree_encoding.specialized<123>>
+  %size2 = stream.tensor.sizeof on(#hal.device.affinity<@device_a>) tensor<4096x4096xf32, #encoding2> : index
+  %enc2 = stream.tensor.encode on(#hal.device.affinity<@device_a>) %source : tensor<4096x4096xf32> in !stream.resource<constant>{%source_size} -> tensor<4096x4096xf32, #encoding2> in !stream.resource<constant>{%size2}
   util.global.store %enc2, @encoded_v2 : !stream.resource<constant>
   util.global.store %size2, @encoded_v2_size : index
 
@@ -307,9 +513,6 @@
   }
 }
 
-// Test tied operand: dispatch result tied to input operand.
-// When encoding changes, both operand and result encoding must change.
-// A re-encode op should be inserted after dispatch.
 // CHECK-LABEL: util.func public @dispatch_with_tied_operand
 // CHECK-SAME:    %[[N:[a-zA-Z0-9_]+]]: index
 // CHECK-SAME:    %[[M:[a-zA-Z0-9_]+]]: index
@@ -319,15 +522,15 @@
   %cloned = stream.async.clone %encoded : !stream.resource<constant>{%encoded_size} -> !stream.resource<*>{%encoded_size}
 
   // The dispatch has a tied result (result -> %cloned).
-  // CHECK:      %[[DISPATCH:.+]] = stream.tensor.dispatch @executable_tied::@dispatch_inplace
-  // CHECK-SAME:   tensor<?x?xf32, #iree_encoding.identity>{%[[N]], %[[M]]}
-  // CHECK-SAME:   -> tensor<?x?xf32, #iree_encoding.identity>{%[[N]], %[[M]]}
+  // CHECK:      %[[DISPATCH:.+]] = stream.tensor.dispatch on(#hal.device.affinity<@[[$DEVICE_A]]>) @executable_tied::@dispatch_inplace
+  // CHECK-SAME:   tensor<?x?xf32, #iree_encoding.specialized<123>>{%[[N]], %[[M]]}
+  // CHECK-SAME:   -> tensor<?x?xf32, #iree_encoding.specialized<123>>{%[[N]], %[[M]]}
   // Re-encode sizeof and encode ops are inserted after dispatch.
-  // CHECK:      stream.tensor.sizeof tensor<?x?xf32, #[[$ENC]]>{%[[N]], %[[M]]}
-  // CHECK:      stream.tensor.encode %[[DISPATCH]] :
-  // CHECK-SAME:   tensor<?x?xf32, #iree_encoding.identity>
+  // CHECK:      stream.tensor.sizeof on(#hal.device.affinity<@[[$DEVICE_A]]>) tensor<?x?xf32, #[[$ENC]]>{%[[N]], %[[M]]}
+  // CHECK:      stream.tensor.encode on(#hal.device.affinity<@[[$DEVICE_A]]>) %[[DISPATCH]] :
+  // CHECK-SAME:   tensor<?x?xf32, #iree_encoding.specialized<123>>
   // CHECK-SAME:   -> tensor<?x?xf32, #[[$ENC]]>{%[[N]], %[[M]]}
-  %result = stream.tensor.dispatch @executable_tied::@dispatch_inplace(%cloned)
+  %result = stream.tensor.dispatch on(#hal.device.affinity<@device_a>) @executable_tied::@dispatch_inplace(%cloned)
     : (tensor<?x?xf32, #encoding1>{%N, %M} in !stream.resource<*>{%encoded_size})
     -> tensor<?x?xf32, #encoding1>{%N, %M} in %cloned{%encoded_size}
 
@@ -338,27 +541,31 @@
 
 // Test: mutable source global - should be skipped, encoding unchanged.
 
+#executable_target_vmvx_bytecode_fb = #hal.executable.target<"vmvx", "vmvx-bytecode-fb", {iree.encoding.resolver = #iree_encoding.identity_resolver}>
+#device_target_local = #hal.device.target<"local", {ordinal = 0 : index}, [#executable_target_vmvx_bytecode_fb]> : !hal.device
 #encoding1 = #iree_encoding.testing<layouts = [#iree_encoding.specialized<123>]>
 
 // CHECK: #[[$ENC:.+]] = #iree_encoding.testing<layouts = [#iree_encoding.specialized<123>]>
 // CHECK-LABEL: module @mutable_source_skipped
+//       CHECK:   util.global private @[[$DEVICE_A:.+]] =
 module @mutable_source_skipped {
+  util.global private @device_a = #device_target_local
   util.global private mutable @mutable_source : !stream.resource<constant>
   util.global private @encoded : !stream.resource<constant>
 
   util.initializer {
-    %cst = stream.tensor.constant : tensor<4096x4096xf32> in !stream.resource<constant> = dense<0.0> : tensor<4096x4096xf32>
+    %cst = stream.tensor.constant on(#hal.device.affinity<@device_a>) : tensor<4096x4096xf32> in !stream.resource<constant> = dense<0.0> : tensor<4096x4096xf32>
     %cst_size = stream.resource.size %cst : !stream.resource<constant>
     util.global.store %cst, @mutable_source : !stream.resource<constant>
 
     %source = util.global.load @mutable_source : !stream.resource<constant>
     %source_size = stream.resource.size %source : !stream.resource<constant>
 
-    // CHECK: stream.tensor.sizeof tensor<4096x4096xf32, #[[$ENC]]>
-    // CHECK: stream.tensor.encode {{.*}} -> tensor<4096x4096xf32, #[[$ENC]]>
-    %size1 = stream.tensor.sizeof tensor<4096x4096xf32, #encoding1> : index
-    %enc1 = stream.tensor.encode %source : tensor<4096x4096xf32> in !stream.resource<constant>{%source_size} -> tensor<4096x4096xf32, #encoding1> in !stream.resource<*>{%size1}
-    %const1 = stream.async.clone %enc1 : !stream.resource<*>{%size1} -> !stream.resource<constant>{%size1}
+    // CHECK: stream.tensor.sizeof on(#hal.device.affinity<@[[$DEVICE_A]]>) tensor<4096x4096xf32, #[[$ENC]]>
+    // CHECK: stream.tensor.encode on(#hal.device.affinity<@[[$DEVICE_A]]>) {{.*}} -> tensor<4096x4096xf32, #[[$ENC]]>
+    %size1 = stream.tensor.sizeof on(#hal.device.affinity<@device_a>) tensor<4096x4096xf32, #encoding1> : index
+    %enc1 = stream.tensor.encode on(#hal.device.affinity<@device_a>) %source : tensor<4096x4096xf32> in !stream.resource<constant>{%source_size} -> tensor<4096x4096xf32, #encoding1> in !stream.resource<*>{%size1}
+    %const1 = stream.async.clone on(#hal.device.affinity<@device_a>) %enc1 : !stream.resource<*>{%size1} -> !stream.resource<constant>{%size1}
     util.global.store %const1, @encoded : !stream.resource<constant>
 
     util.return
@@ -369,13 +576,17 @@
 
 // Test: mutable encoded global - should be skipped, encoding unchanged.
 
+#executable_target_vmvx_bytecode_fb = #hal.executable.target<"vmvx", "vmvx-bytecode-fb", {iree.encoding.resolver = #iree_encoding.identity_resolver}>
+#device_target_local = #hal.device.target<"local", {ordinal = 0 : index}, [#executable_target_vmvx_bytecode_fb]> : !hal.device
 #encoding1 = #iree_encoding.testing<layouts = [#iree_encoding.specialized<123>]>
 #encoding2 = #iree_encoding.testing<layouts = [#iree_encoding.specialized<456>]>
 
 // CHECK: #[[$ENC1:.+]] = #iree_encoding.testing<layouts = [#iree_encoding.specialized<123>]>
 // CHECK: #[[$ENC2:.+]] = #iree_encoding.testing<layouts = [#iree_encoding.specialized<456>]>
 // CHECK-LABEL: module @mutable_encoded_global_skipped
+//       CHECK:   util.global private @[[$DEVICE_A:.+]] =
 module @mutable_encoded_global_skipped {
+  util.global private @device_a = #device_target_local
   util.global private @source = #stream.parameter.named<"model"::"weight"> : !stream.resource<constant>
   util.global private mutable @encoded_mutable_v1 : !stream.resource<constant>
   util.global private mutable @encoded_mutable_v2 : !stream.resource<constant>
@@ -384,18 +595,18 @@
     %source = util.global.load @source : !stream.resource<constant>
     %source_size = stream.resource.size %source : !stream.resource<constant>
 
-    // CHECK: stream.tensor.sizeof tensor<4096x4096xf32, #[[$ENC1]]>
-    // CHECK: stream.tensor.encode {{.*}} -> tensor<4096x4096xf32, #[[$ENC1]]>
-    %size1 = stream.tensor.sizeof tensor<4096x4096xf32, #encoding1> : index
-    %enc1 = stream.tensor.encode %source : tensor<4096x4096xf32> in !stream.resource<constant>{%source_size} -> tensor<4096x4096xf32, #encoding1> in !stream.resource<*>{%size1}
-    %const1 = stream.async.clone %enc1 : !stream.resource<*>{%size1} -> !stream.resource<constant>{%size1}
+    // CHECK: stream.tensor.sizeof on(#hal.device.affinity<@[[$DEVICE_A]]>) tensor<4096x4096xf32, #[[$ENC1]]>
+    // CHECK: stream.tensor.encode on(#hal.device.affinity<@[[$DEVICE_A]]>) {{.*}} -> tensor<4096x4096xf32, #[[$ENC1]]>
+    %size1 = stream.tensor.sizeof on(#hal.device.affinity<@device_a>) tensor<4096x4096xf32, #encoding1> : index
+    %enc1 = stream.tensor.encode on(#hal.device.affinity<@device_a>) %source : tensor<4096x4096xf32> in !stream.resource<constant>{%source_size} -> tensor<4096x4096xf32, #encoding1> in !stream.resource<*>{%size1}
+    %const1 = stream.async.clone on(#hal.device.affinity<@device_a>) %enc1 : !stream.resource<*>{%size1} -> !stream.resource<constant>{%size1}
     util.global.store %const1, @encoded_mutable_v1 : !stream.resource<constant>
 
-    // CHECK: stream.tensor.sizeof tensor<4096x4096xf32, #[[$ENC2]]>
-    // CHECK: stream.tensor.encode {{.*}} -> tensor<4096x4096xf32, #[[$ENC2]]>
-    %size2 = stream.tensor.sizeof tensor<4096x4096xf32, #encoding2> : index
-    %enc2 = stream.tensor.encode %source : tensor<4096x4096xf32> in !stream.resource<constant>{%source_size} -> tensor<4096x4096xf32, #encoding2> in !stream.resource<*>{%size2}
-    %const2 = stream.async.clone %enc2 : !stream.resource<*>{%size2} -> !stream.resource<constant>{%size2}
+    // CHECK: stream.tensor.sizeof on(#hal.device.affinity<@[[$DEVICE_A]]>) tensor<4096x4096xf32, #[[$ENC2]]>
+    // CHECK: stream.tensor.encode on(#hal.device.affinity<@[[$DEVICE_A]]>) {{.*}} -> tensor<4096x4096xf32, #[[$ENC2]]>
+    %size2 = stream.tensor.sizeof on(#hal.device.affinity<@device_a>) tensor<4096x4096xf32, #encoding2> : index
+    %enc2 = stream.tensor.encode on(#hal.device.affinity<@device_a>) %source : tensor<4096x4096xf32> in !stream.resource<constant>{%source_size} -> tensor<4096x4096xf32, #encoding2> in !stream.resource<*>{%size2}
+    %const2 = stream.async.clone on(#hal.device.affinity<@device_a>) %enc2 : !stream.resource<*>{%size2} -> !stream.resource<constant>{%size2}
     util.global.store %const2, @encoded_mutable_v2 : !stream.resource<constant>
     util.return
   }
@@ -405,11 +616,15 @@
 
 // Test: single encoding - not a candidate for unification, encoding unchanged.
 
+#executable_target_vmvx_bytecode_fb = #hal.executable.target<"vmvx", "vmvx-bytecode-fb", {iree.encoding.resolver = #iree_encoding.identity_resolver}>
+#device_target_local = #hal.device.target<"local", {ordinal = 0 : index}, [#executable_target_vmvx_bytecode_fb]> : !hal.device
 #encoding1 = #iree_encoding.testing<layouts = [#iree_encoding.specialized<123>]>
 
 // CHECK: #[[$ENC:.+]] = #iree_encoding.testing<layouts = [#iree_encoding.specialized<123>]>
 // CHECK-LABEL: module @single_encoding_no_unification
+//       CHECK:   util.global private @[[$DEVICE_A:.+]] =
 module @single_encoding_no_unification {
+  util.global private @device_a = #device_target_local
   util.global private @source = #stream.parameter.named<"model"::"weight"> : !stream.resource<constant>
   util.global private @encoded : !stream.resource<constant>
 
@@ -417,11 +632,11 @@
     %source = util.global.load @source : !stream.resource<constant>
     %source_size = stream.resource.size %source : !stream.resource<constant>
 
-    // CHECK: stream.tensor.sizeof tensor<4096x4096xf32, #[[$ENC]]>
-    // CHECK: stream.tensor.encode {{.*}} -> tensor<4096x4096xf32, #[[$ENC]]>
-    %size1 = stream.tensor.sizeof tensor<4096x4096xf32, #encoding1> : index
-    %enc1 = stream.tensor.encode %source : tensor<4096x4096xf32> in !stream.resource<constant>{%source_size} -> tensor<4096x4096xf32, #encoding1> in !stream.resource<*>{%size1}
-    %const1 = stream.async.clone %enc1 : !stream.resource<*>{%size1} -> !stream.resource<constant>{%size1}
+    // CHECK: stream.tensor.sizeof on(#hal.device.affinity<@[[$DEVICE_A]]>) tensor<4096x4096xf32, #[[$ENC]]>
+    // CHECK: stream.tensor.encode on(#hal.device.affinity<@[[$DEVICE_A]]>) {{.*}} -> tensor<4096x4096xf32, #[[$ENC]]>
+    %size1 = stream.tensor.sizeof on(#hal.device.affinity<@device_a>) tensor<4096x4096xf32, #encoding1> : index
+    %enc1 = stream.tensor.encode on(#hal.device.affinity<@device_a>) %source : tensor<4096x4096xf32> in !stream.resource<constant>{%source_size} -> tensor<4096x4096xf32, #encoding1> in !stream.resource<*>{%size1}
+    %const1 = stream.async.clone on(#hal.device.affinity<@device_a>) %enc1 : !stream.resource<*>{%size1} -> !stream.resource<constant>{%size1}
     util.global.store %const1, @encoded : !stream.resource<constant>
 
     util.return
@@ -432,11 +647,15 @@
 
 // Test: same encoding used twice - not a candidate (only one unique encoding).
 
+#executable_target_vmvx_bytecode_fb = #hal.executable.target<"vmvx", "vmvx-bytecode-fb", {iree.encoding.resolver = #iree_encoding.identity_resolver}>
+#device_target_local = #hal.device.target<"local", {ordinal = 0 : index}, [#executable_target_vmvx_bytecode_fb]> : !hal.device
 #encoding1 = #iree_encoding.testing<layouts = [#iree_encoding.specialized<123>]>
 
 // CHECK: #[[$ENC:.+]] = #iree_encoding.testing<layouts = [#iree_encoding.specialized<123>]>
 // CHECK-LABEL: module @same_encoding_twice_no_unification
+//       CHECK:   util.global private @[[$DEVICE_A:.+]] =
 module @same_encoding_twice_no_unification {
+  util.global private @device_a = #device_target_local
   util.global private @source = #stream.parameter.named<"model"::"weight"> : !stream.resource<constant>
   util.global private @encoded_v1 : !stream.resource<constant>
   util.global private @encoded_v2 : !stream.resource<constant>
@@ -445,18 +664,18 @@
     %source = util.global.load @source : !stream.resource<constant>
     %source_size = stream.resource.size %source : !stream.resource<constant>
 
-    // CHECK: stream.tensor.sizeof tensor<4096x4096xf32, #[[$ENC]]>
-    // CHECK: stream.tensor.encode {{.*}} -> tensor<4096x4096xf32, #[[$ENC]]>
-    %size1 = stream.tensor.sizeof tensor<4096x4096xf32, #encoding1> : index
-    %enc1 = stream.tensor.encode %source : tensor<4096x4096xf32> in !stream.resource<constant>{%source_size} -> tensor<4096x4096xf32, #encoding1> in !stream.resource<*>{%size1}
-    %const1 = stream.async.clone %enc1 : !stream.resource<*>{%size1} -> !stream.resource<constant>{%size1}
+    // CHECK: stream.tensor.sizeof on(#hal.device.affinity<@[[$DEVICE_A]]>) tensor<4096x4096xf32, #[[$ENC]]>
+    // CHECK: stream.tensor.encode on(#hal.device.affinity<@[[$DEVICE_A]]>) {{.*}} -> tensor<4096x4096xf32, #[[$ENC]]>
+    %size1 = stream.tensor.sizeof on(#hal.device.affinity<@device_a>) tensor<4096x4096xf32, #encoding1> : index
+    %enc1 = stream.tensor.encode on(#hal.device.affinity<@device_a>) %source : tensor<4096x4096xf32> in !stream.resource<constant>{%source_size} -> tensor<4096x4096xf32, #encoding1> in !stream.resource<*>{%size1}
+    %const1 = stream.async.clone on(#hal.device.affinity<@device_a>) %enc1 : !stream.resource<*>{%size1} -> !stream.resource<constant>{%size1}
     util.global.store %const1, @encoded_v1 : !stream.resource<constant>
 
-    // CHECK: stream.tensor.sizeof tensor<4096x4096xf32, #[[$ENC]]>
-    // CHECK: stream.tensor.encode {{.*}} -> tensor<4096x4096xf32, #[[$ENC]]>
-    %size2 = stream.tensor.sizeof tensor<4096x4096xf32, #encoding1> : index
-    %enc2 = stream.tensor.encode %source : tensor<4096x4096xf32> in !stream.resource<constant>{%source_size} -> tensor<4096x4096xf32, #encoding1> in !stream.resource<*>{%size2}
-    %const2 = stream.async.clone %enc2 : !stream.resource<*>{%size2} -> !stream.resource<constant>{%size2}
+    // CHECK: stream.tensor.sizeof on(#hal.device.affinity<@[[$DEVICE_A]]>) tensor<4096x4096xf32, #[[$ENC]]>
+    // CHECK: stream.tensor.encode on(#hal.device.affinity<@[[$DEVICE_A]]>) {{.*}} -> tensor<4096x4096xf32, #[[$ENC]]>
+    %size2 = stream.tensor.sizeof on(#hal.device.affinity<@device_a>) tensor<4096x4096xf32, #encoding1> : index
+    %enc2 = stream.tensor.encode on(#hal.device.affinity<@device_a>) %source : tensor<4096x4096xf32> in !stream.resource<constant>{%source_size} -> tensor<4096x4096xf32, #encoding1> in !stream.resource<*>{%size2}
+    %const2 = stream.async.clone on(#hal.device.affinity<@device_a>) %enc2 : !stream.resource<*>{%size2} -> !stream.resource<constant>{%size2}
     util.global.store %const2, @encoded_v2 : !stream.resource<constant>
 
     util.return