Making MaterializeInterfaces anchor on dispatch site device targets. (#16536)

Now which executable targets are selected for materialization is derived
from the dispatch sites for exports in the source executables. This
allows us to join all required targets for a particular export for
compilation while keeping each dispatch site referencing only the
targets it may dispatch.

To support easier testing and direct HAL executable compilation any
`hal.executable.source` or `stream.executable` that is public will take
all targets specified on the module in addition to any from dispatch
sites. In all normal programs the executables should be private and only
use the dispatch sites to determine their targets.
diff --git a/compiler/src/iree/compiler/Dialect/HAL/Analysis/BindingLayout.cpp b/compiler/src/iree/compiler/Dialect/HAL/Analysis/BindingLayout.cpp
index 9676ef0..f762a1f 100644
--- a/compiler/src/iree/compiler/Dialect/HAL/Analysis/BindingLayout.cpp
+++ b/compiler/src/iree/compiler/Dialect/HAL/Analysis/BindingLayout.cpp
@@ -36,25 +36,9 @@
   }
 }
 
-// Finds all dispatches within |rootOp| and groups them by executable export.
-static BindingLayoutAnalysis::ExportDispatchMap
-findAllDispatchSites(Operation *rootOp) {
-  SymbolTable symbolTable(rootOp);
-  BindingLayoutAnalysis::ExportDispatchMap dispatchMap;
-  rootOp->walk([&](IREE::Stream::CmdDispatchOp dispatchOp) {
-    dispatchOp.forEachEntryPointAttr([&](SymbolRefAttr entryPointAttr) {
-      auto exportOp =
-          symbolTable.lookupNearestSymbolFrom(dispatchOp, entryPointAttr);
-      dispatchMap[exportOp].push_back(dispatchOp);
-    });
-  });
-  return dispatchMap;
-}
-
 // Assumes an explicit layout as specified on an export.
 static PipelineLayout
-assumeExportLayout(IREE::Stream::ExecutableExportOp exportOp,
-                   IREE::HAL::PipelineLayoutAttr layoutAttr) {
+assumeExportLayout(IREE::HAL::PipelineLayoutAttr layoutAttr) {
   PipelineLayout pipelineLayout;
   pipelineLayout.pushConstantCount = layoutAttr.getPushConstants();
 
@@ -86,23 +70,24 @@
     pipelineLayout.setLayouts[setLayout.ordinal] = setLayout;
   }
 
-  LLVM_DEBUG({
-    auto executableOp = exportOp->getParentOfType<IREE::Stream::ExecutableOp>();
-    llvm::dbgs() << "assumeExportLayout(@" << executableOp.getSymName() << "::@"
-                 << exportOp.getSymName() << "):\n";
-    pipelineLayout.print(llvm::dbgs());
-  });
-
   return pipelineLayout;
 }
 
 // Derives an pipeline layout from all of the dispatches to |exportOp|.
 static PipelineLayout
-deriveExportLayout(IREE::Stream::ExecutableExportOp exportOp,
-                   SmallVector<IREE::Stream::CmdDispatchOp> &dispatchOps) {
+deriveStreamExportLayout(IREE::Stream::ExecutableExportOp exportOp,
+                         ArrayRef<IREE::Stream::CmdDispatchOp> dispatchOps) {
   if (auto layoutAttr = exportOp->getAttrOfType<IREE::HAL::PipelineLayoutAttr>(
           "hal.interface.layout")) {
-    return assumeExportLayout(exportOp, layoutAttr);
+    auto assumedLayout = assumeExportLayout(layoutAttr);
+    LLVM_DEBUG({
+      auto executableOp =
+          exportOp->getParentOfType<IREE::Stream::ExecutableOp>();
+      llvm::dbgs() << "assumeExportLayout(@" << executableOp.getSymName()
+                   << "::@" << exportOp.getSymName() << "):\n";
+      assumedLayout.print(llvm::dbgs());
+    });
+    return assumedLayout;
   }
 
   auto funcOp = exportOp.lookupFunctionRef();
@@ -184,36 +169,61 @@
   return pipelineLayout;
 }
 
-static BindingLayoutAnalysis::ExportLayoutMap
-deriveExportLayouts(Operation *rootOp,
-                    BindingLayoutAnalysis::ExportDispatchMap dispatchMap) {
-  BindingLayoutAnalysis::ExportLayoutMap layoutMap;
-  rootOp->walk([&](IREE::Stream::ExecutableExportOp exportOp) {
-    auto &dispatchOps = dispatchMap[exportOp];
-    layoutMap[exportOp] = deriveExportLayout(exportOp, dispatchOps);
+BindingLayoutAnalysis::BindingLayoutAnalysis(Operation *rootOp,
+                                             SymbolTable &symbolTable) {
+  // Finds all exports and dispatches within rootOp and groups them by
+  // executable export. We need to complete gathering all of the information
+  // before we derive the layouts.
+  auto getExportInfo = [&](Operation *exportOp) -> ExportInfo & {
+    auto &exportInfo = exportInfos[exportOp];
+    if (!exportInfo)
+      exportInfo = std::make_unique<ExportInfo>();
+    return *exportInfo;
+  };
+  rootOp->walk([&](Operation *op) {
+    TypeSwitch<Operation *>(op)
+        .Case<IREE::Stream::ExecutableExportOp>(
+            [&](auto exportOp) { (void)getExportInfo(exportOp); })
+        .Case<IREE::HAL::ExecutableExportOp>([&](auto exportOp) {
+          auto &exportInfo = getExportInfo(exportOp);
+          exportInfo.pipelineLayout =
+              assumeExportLayout(exportOp.getLayoutAttr());
+        })
+        .Case<IREE::Stream::CmdDispatchOp>([&](auto dispatchOp) {
+          dispatchOp.forEachEntryPointAttr([&](SymbolRefAttr entryPointAttr) {
+            auto exportOp =
+                symbolTable.lookupNearestSymbolFrom(dispatchOp, entryPointAttr);
+            auto &exportInfo = getExportInfo(exportOp);
+            exportInfo.dispatchOps.push_back(dispatchOp);
+          });
+        })
+        .Default([](auto op) {});
   });
-  return layoutMap;
+
+  // Derive the layouts for each export op.
+  for (auto &it : exportInfos) {
+    TypeSwitch<Operation *>(it.first)
+        .Case<IREE::Stream::ExecutableExportOp>([&](auto exportOp) {
+          it.second->pipelineLayout =
+              deriveStreamExportLayout(exportOp, it.second->dispatchOps);
+        })
+        .Default([&](auto op) {});
+  }
 }
 
-BindingLayoutAnalysis::BindingLayoutAnalysis(Operation *rootOp) {
-  exportDispatches = findAllDispatchSites(rootOp);
-  exportLayouts = deriveExportLayouts(rootOp, exportDispatches);
+ArrayRef<IREE::Stream::CmdDispatchOp>
+BindingLayoutAnalysis::getExportDispatches(Operation *exportOp) const {
+  auto it = exportInfos.find(exportOp);
+  if (it == exportInfos.end())
+    return {}; // not analyzed
+  return it->second.get()->dispatchOps;
 }
 
-SmallVector<IREE::Stream::CmdDispatchOp>
-BindingLayoutAnalysis::getExportDispatches(
-    IREE::Stream::ExecutableExportOp exportOp) const {
-  auto it = exportDispatches.find(exportOp);
-  if (it == exportDispatches.end())
-    return {}; // no dispatches
-  return it->second;
-}
-
-const PipelineLayout &BindingLayoutAnalysis::getPipelineLayout(
-    IREE::Stream::ExecutableExportOp exportOp) const {
-  auto it = exportLayouts.find(exportOp);
-  assert(it != exportLayouts.end() && "unanalyzed export");
-  return it->second;
+const PipelineLayout &
+BindingLayoutAnalysis::getPipelineLayout(Operation *exportOp) const {
+  auto it = exportInfos.find(exportOp);
+  assert(it != exportInfos.end() && "unanalyzed export");
+  return it->second.get()->pipelineLayout;
 }
 
 } // namespace mlir::iree_compiler::IREE::HAL
diff --git a/compiler/src/iree/compiler/Dialect/HAL/Analysis/BindingLayout.h b/compiler/src/iree/compiler/Dialect/HAL/Analysis/BindingLayout.h
index c309f74..1e8704d 100644
--- a/compiler/src/iree/compiler/Dialect/HAL/Analysis/BindingLayout.h
+++ b/compiler/src/iree/compiler/Dialect/HAL/Analysis/BindingLayout.h
@@ -7,6 +7,7 @@
 #ifndef IREE_COMPILER_DIALECT_HAL_ANALYSIS_BINDINGLAYOUT_
 #define IREE_COMPILER_DIALECT_HAL_ANALYSIS_BINDINGLAYOUT_
 
+#include "iree/compiler/Dialect/HAL/IR/HALOps.h"
 #include "iree/compiler/Dialect/HAL/IR/HALTypes.h"
 #include "iree/compiler/Dialect/Stream/IR/StreamOps.h"
 #include "llvm/Support/raw_ostream.h"
@@ -56,26 +57,39 @@
 // NOTE: erasing dispatch ops will invalidate this analysis.
 class BindingLayoutAnalysis {
 public:
-  using ExportDispatchMap =
-      DenseMap<Operation *, SmallVector<IREE::Stream::CmdDispatchOp>>;
-  using ExportLayoutMap = DenseMap<Operation *, PipelineLayout>;
-
-  explicit BindingLayoutAnalysis(Operation *rootOp);
+  explicit BindingLayoutAnalysis(Operation *rootOp, SymbolTable &symbolTable);
 
   // Returns all of the dispatches to the given executable export.
-  SmallVector<IREE::Stream::CmdDispatchOp>
-  getExportDispatches(IREE::Stream::ExecutableExportOp exportOp) const;
+  ArrayRef<IREE::Stream::CmdDispatchOp>
+  getExportDispatches(IREE::Stream::ExecutableExportOp exportOp) const {
+    return getExportDispatches(exportOp.getOperation());
+  }
+  ArrayRef<IREE::Stream::CmdDispatchOp>
+  getExportDispatches(IREE::HAL::ExecutableExportOp exportOp) const {
+    return getExportDispatches(exportOp.getOperation());
+  }
 
   // Returns a layout used for the given executable export op.
   const PipelineLayout &
-  getPipelineLayout(IREE::Stream::ExecutableExportOp exportOp) const;
+  getPipelineLayout(IREE::Stream::ExecutableExportOp exportOp) const {
+    return getPipelineLayout(exportOp.getOperation());
+  }
+  const PipelineLayout &
+  getPipelineLayout(IREE::HAL::ExecutableExportOp exportOp) const {
+    return getPipelineLayout(exportOp.getOperation());
+  }
 
 private:
-  // All dispatches to a particular executable IREE::Stream::ExecutableExportOp.
-  ExportDispatchMap exportDispatches;
-  // Pipeline layout for each IREE::Stream::ExecutableExportOp.
-  // Many of these may be duplicates.
-  ExportLayoutMap exportLayouts;
+  ArrayRef<IREE::Stream::CmdDispatchOp>
+  getExportDispatches(Operation *exportOp) const;
+  const PipelineLayout &getPipelineLayout(Operation *exportOp) const;
+
+  struct ExportInfo {
+    SmallVector<IREE::Stream::CmdDispatchOp> dispatchOps;
+    PipelineLayout pipelineLayout;
+  };
+  using ExportInfoMap = DenseMap<Operation *, std::unique_ptr<ExportInfo>>;
+  ExportInfoMap exportInfos;
 };
 
 } // namespace mlir::iree_compiler::IREE::HAL
diff --git a/compiler/src/iree/compiler/Dialect/HAL/Transforms/MaterializeInterfaces.cpp b/compiler/src/iree/compiler/Dialect/HAL/Transforms/MaterializeInterfaces.cpp
index 022898f..a6518ec 100644
--- a/compiler/src/iree/compiler/Dialect/HAL/Transforms/MaterializeInterfaces.cpp
+++ b/compiler/src/iree/compiler/Dialect/HAL/Transforms/MaterializeInterfaces.cpp
@@ -36,8 +36,11 @@
 
 namespace {
 
-// Map of original SymbolRefAttr to a list of SymbolRefAttrs in variants.
-using ExportExpansions = DenseMap<Attribute, SmallVector<Attribute>>;
+// Map of original SymbolRefAttr to a list of SymbolRefAttrs in variants marked
+// with the executable target the export is assigned.
+using ExportExpansions = DenseMap<
+    Attribute,
+    SmallVector<std::pair<Attribute, IREE::HAL::ExecutableTargetAttr>>>;
 
 //===----------------------------------------------------------------------===//
 // Utilities
@@ -55,10 +58,48 @@
   targetOp.setObjectsAttr(*objects);
 }
 
+// Returns a set of executable targets required by any dispatch to the given
+// executable. Not all exports may be dispatched on the targets.
+// If the |executableOp| is public then targets specified on the module will be
+// used in addition to any from the dispatches.
+template <typename OpT>
+static SmallVector<IREE::HAL::ExecutableTargetAttr>
+gatherExecutableTargetAttrs(SymbolOpInterface executableOp,
+                            llvm::iterator_range<OpT> exportOps,
+                            const BindingLayoutAnalysis &layoutAnalysis) {
+  llvm::SetVector<IREE::HAL::ExecutableTargetAttr,
+                  SmallVector<IREE::HAL::ExecutableTargetAttr>>
+      targetAttrsSet;
+  if (executableOp.isPublic()) {
+    for (auto targetAttr :
+         IREE::HAL::DeviceTargetAttr::lookupExecutableTargets(executableOp)) {
+      targetAttrsSet.insert(targetAttr);
+    }
+  }
+  for (auto exportOp : exportOps) {
+    for (auto dispatchOp : layoutAnalysis.getExportDispatches(exportOp)) {
+      for (auto targetAttr :
+           IREE::HAL::DeviceTargetAttr::lookupExecutableTargets(dispatchOp)) {
+        targetAttrsSet.insert(targetAttr);
+      }
+    }
+  }
+  auto targetAttrs = targetAttrsSet.takeVector();
+  llvm::stable_sort(targetAttrs, [](auto lhs, auto rhs) {
+    return lhs.getSymbolNameFragment() < rhs.getSymbolNameFragment();
+  });
+  return targetAttrs;
+}
+
 // Updates the target entry point symbols of |dispatchOp| to the expanded set of
 // variant exports in |exportExpansions|.
 static void updateDispatchTargets(IREE::Stream::CmdDispatchOp dispatchOp,
                                   const ExportExpansions &exportExpansions) {
+  DenseSet<IREE::HAL::ExecutableTargetAttr> requiredTargetAttrs;
+  for (auto targetAttr :
+       IREE::HAL::DeviceTargetAttr::lookupExecutableTargets(dispatchOp)) {
+    requiredTargetAttrs.insert(targetAttr);
+  }
   SmallVector<Attribute> newAttrs;
   for (auto oldAttr : dispatchOp.getEntryPointRefs()) {
     auto it = exportExpansions.find(oldAttr);
@@ -66,8 +107,10 @@
       newAttrs.push_back(oldAttr); // preserve existing
       continue;
     }
-    for (auto newAttr : it->second) {
-      newAttrs.push_back(newAttr);
+    for (auto [newAttr, targetAttr] : it->second) {
+      // Filter the new expansions to only those used by the dispatch.
+      if (requiredTargetAttrs.contains(targetAttr))
+        newAttrs.push_back(newAttr);
     }
   }
   dispatchOp.setEntryPointsAttr(
@@ -88,13 +131,19 @@
                             });
 }
 
-static LogicalResult materializeExecutableFromSourceOp(
-    IREE::HAL::ExecutableSourceOp sourceOp,
-    ArrayRef<IREE::HAL::ExecutableTargetAttr> targetAttrs,
-    ExportExpansions &exportExpansions) {
-  OpBuilder moduleBuilder(sourceOp);
+static void
+materializeExecutableFromSourceOp(IREE::HAL::ExecutableSourceOp sourceOp,
+                                  BindingLayoutAnalysis &layoutAnalysis) {
+  // Gather the required executable targets based on the dispatches to exports
+  // in the source op.
+  auto targetAttrs = gatherExecutableTargetAttrs(
+      sourceOp, sourceOp.getOps<IREE::HAL::ExecutableExportOp>(),
+      layoutAnalysis);
+  if (targetAttrs.empty())
+    return;
 
   // Create the op that will contain the translated executable.
+  OpBuilder moduleBuilder(sourceOp);
   auto executableOp = moduleBuilder.create<IREE::HAL::ExecutableOp>(
       sourceOp.getLoc(), sourceOp.getName());
   executableOp.setVisibility(sourceOp.getVisibility());
@@ -105,6 +154,7 @@
 
   // Materialize all of the hal.executable.variant ops for all backends we are
   // targeting.
+  ExportExpansions exportExpansions;
   SymbolTable targetSymbolTable(executableOp);
   OpBuilder targetBuilder(&executableOp.getBlock().back());
   for (auto targetAttr : targetAttrs) {
@@ -117,11 +167,13 @@
       variantBuilder.clone(*sourceExportOp);
 
       // Map the original export names to the new variant exports.
-      exportExpansions[SymbolRefAttr::get(executableOp.getNameAttr(),
-                                          {FlatSymbolRefAttr::get(
-                                              sourceExportOp.getNameAttr())})]
-          .push_back(makeExportSymbolRefAttr(executableOp, targetVariantOp,
-                                             sourceExportOp));
+      auto oldRefAttr = SymbolRefAttr::get(
+          executableOp.getNameAttr(),
+          {FlatSymbolRefAttr::get(sourceExportOp.getNameAttr())});
+      auto newRefAttr = makeExportSymbolRefAttr(executableOp, targetVariantOp,
+                                                sourceExportOp);
+      exportExpansions[oldRefAttr].push_back(
+          std::make_pair(newRefAttr, targetAttr));
     }
 
     // Clone any target-specific object files specified.
@@ -138,43 +190,15 @@
     }
   }
 
+  // Update all dispatch sites to reference the new expanded variants.
+  for (auto exportOp : sourceExportOps) {
+    for (auto dispatchOp : layoutAnalysis.getExportDispatches(exportOp)) {
+      updateDispatchTargets(dispatchOp, exportExpansions);
+    }
+  }
+
   // Remove the original.
   sourceOp.erase();
-
-  return success();
-}
-
-static LogicalResult
-materializeExecutablesFromSourceOps(mlir::ModuleOp moduleOp) {
-  ExportExpansions exportExpansions;
-
-  auto sourceOps =
-      llvm::to_vector<32>(moduleOp.getOps<IREE::HAL::ExecutableSourceOp>());
-  for (auto sourceOp : sourceOps) {
-    // Gather a list of all #hal.executable.targets that we should produce
-    // variants for.
-    auto targetAttrs =
-        IREE::HAL::DeviceTargetAttr::lookupExecutableTargets(sourceOp);
-    if (targetAttrs.empty()) {
-      return sourceOp.emitError()
-             << "no executable targets specified for translation";
-    }
-
-    if (failed(materializeExecutableFromSourceOp(sourceOp, targetAttrs,
-                                                 exportExpansions))) {
-      return failure();
-    }
-  }
-  if (exportExpansions.empty())
-    return success();
-
-  for (auto funcOp : moduleOp.getOps<FunctionOpInterface>()) {
-    funcOp.walk([&](IREE::Stream::CmdDispatchOp dispatchOp) {
-      updateDispatchTargets(dispatchOp, exportExpansions);
-    });
-  }
-
-  return success();
 }
 
 //===----------------------------------------------------------------------===//
@@ -374,11 +398,13 @@
           /*workgroup_local_memory=*/IntegerAttr{});
 
       // Map the original export name to the new variant export.
-      exportExpansions[SymbolRefAttr::get(
-                           sourceExecutableOp.getNameAttr(),
-                           {FlatSymbolRefAttr::get(exportOp.getNameAttr())})]
-          .push_back(makeExportSymbolRefAttr(targetExecutableOp, variantOp,
-                                             newExportOp));
+      auto oldRefAttr =
+          SymbolRefAttr::get(sourceExecutableOp.getNameAttr(),
+                             {FlatSymbolRefAttr::get(exportOp.getNameAttr())});
+      auto newRefAttr =
+          makeExportSymbolRefAttr(targetExecutableOp, variantOp, newExportOp);
+      exportExpansions[oldRefAttr].push_back(
+          std::make_pair(newRefAttr, variantOp.getTargetAttr()));
 
       // Annotate the export with the a mapping of the resources to the
       // interface bindings. This is used during conversion.
@@ -524,39 +550,38 @@
     : public IREE::HAL::impl::MaterializeInterfacesPassBase<
           MaterializeInterfacesPass> {
   void runOnOperation() override {
-    SymbolTable symbolTable(getOperation());
+    auto moduleOp = getOperation();
+    SymbolTable symbolTable(moduleOp);
+    BindingLayoutAnalysis layoutAnalysis(moduleOp, symbolTable);
 
     // Handle any hand-authored executables; these only need variant expansion
     // and no layout analysis as the user specified the layout themselves.
-    if (failed(materializeExecutablesFromSourceOps(getOperation()))) {
-      return signalPassFailure();
+    for (auto sourceOp : llvm::make_early_inc_range(
+             moduleOp.getOps<IREE::HAL::ExecutableSourceOp>())) {
+      materializeExecutableFromSourceOp(sourceOp, layoutAnalysis);
     }
 
-    const auto &layoutAnalysis = getAnalysis<BindingLayoutAnalysis>();
-
     // Processes all executables within the input module and produce the
     // output HAL ops. We should ensure all deduping is performed prior to
     // this when it's easier to diff IR and where we still have the flow
     // context.
-    auto sourceOps = llvm::to_vector<32>(
-        getOperation().getOps<IREE::Stream::ExecutableOp>());
-    for (auto sourceOp : sourceOps) {
+    for (auto sourceOp : llvm::make_early_inc_range(
+             moduleOp.getOps<IREE::Stream::ExecutableOp>())) {
       auto exportOps = sourceOp.getOps<IREE::Stream::ExecutableExportOp>();
       if (exportOps.empty())
         continue;
 
       // Gather a list of all #hal.executable.targets that we should produce
-      // variants for.
+      // variants for based on the dispatches performed. Not all exports may be
+      // used on any particular target but we let future DCE/pruning passes
+      // remove them instead of modifying the inner modules here.
       auto targetAttrs =
-          IREE::HAL::DeviceTargetAttr::lookupExecutableTargets(sourceOp);
-      if (targetAttrs.empty()) {
-        sourceOp.emitError()
-            << "no executable targets specified for translation";
-        return signalPassFailure();
-      }
+          gatherExecutableTargetAttrs(sourceOp, exportOps, layoutAnalysis);
+      if (targetAttrs.empty())
+        continue;
 
       // Create the op that will contain the translated executable.
-      OpBuilder builder = OpBuilder::atBlockEnd(getOperation().getBody());
+      OpBuilder builder = OpBuilder::atBlockEnd(moduleOp.getBody());
       builder.setInsertionPointAfter(sourceOp);
       auto executableOp = builder.create<IREE::HAL::ExecutableOp>(
           sourceOp.getLoc(), sourceOp.getName());
diff --git a/compiler/src/iree/compiler/Dialect/HAL/Transforms/test/materialize_interfaces.mlir b/compiler/src/iree/compiler/Dialect/HAL/Transforms/test/materialize_interfaces.mlir
index db45b5e..87eb7a1 100644
--- a/compiler/src/iree/compiler/Dialect/HAL/Transforms/test/materialize_interfaces.mlir
+++ b/compiler/src/iree/compiler/Dialect/HAL/Transforms/test/materialize_interfaces.mlir
@@ -5,12 +5,11 @@
 module attributes {hal.device.targets = [
   #hal.device.target<"llvm-cpu", {
     executable_targets = [
-      #hal.executable.target<"llvm-cpu", "embedded-elf-arm_64">,
-      #hal.executable.target<"llvm-cpu", "embedded-elf-x86_64">
+      #hal.executable.target<"llvm-cpu", "arm_64">,
+      #hal.executable.target<"llvm-cpu", "x86_64">
     ]
   }>
 ]} {
-
   // CHECK: #pipeline_layout = #hal.pipeline.layout<
   // CHECK-SAME: push_constants = 1
   // CHECK-SAME: sets = [
@@ -19,8 +18,8 @@
   // CHECK-SAME:     <1, storage_buffer, ReadOnly>
   // CHECK-SAME:     <2, storage_buffer>
 
-  // CHECK: hal.executable private @ex_workgroups
-  // CHECK:   hal.executable.variant public @embedded_elf_arm_64 target(#executable_target_embedded_elf_arm_64
+  // CHECK: hal.executable private @ex
+  // CHECK:   hal.executable.variant public @arm_64 target(#executable_target_arm_64
   // CHECK:     hal.executable.export public @entry ordinal(0) layout(#pipeline_layout)
   // CHECK-SAME:   hal.interface.bindings = [#hal.interface.binding<0, 0>, #hal.interface.binding<0, 1>, #hal.interface.binding<0, 2>]
   // CHECK-NEXT: ^bb0(%[[DEVICE:.+]]: !hal.device, %[[ARG0:.+]]: index, %[[ARG1:.+]]: index):
@@ -29,7 +28,7 @@
   // CHECK:     builtin.module
   // CHECK-NEXT:  func.func private @extern_func()
   // CHECK-NEXT:  func.func @entry
-  // CHECK:   hal.executable.variant public @embedded_elf_x86_64 target(#executable_target_embedded_elf_x86_64
+  // CHECK:   hal.executable.variant public @x86_64 target(#executable_target_x86_64
   // CHECK:     hal.executable.export public @entry ordinal(0) layout(#pipeline_layout)
   // CHECK-SAME:   hal.interface.bindings = [#hal.interface.binding<0, 0>, #hal.interface.binding<0, 1>, #hal.interface.binding<0, 2>]
   // CHECK-NEXT: ^bb0(%[[DEVICE:.+]]: !hal.device, %[[ARG0:.+]]: index, %[[ARG1:.+]]: index):
@@ -37,9 +36,9 @@
   // CHECK-NEXT: }
   // CHECK:     builtin.module
   // CHECK-NEXT:  func.func private @extern_func()
-  // CHECK-NEXT:  func.func @entry
 
-  stream.executable private @ex_workgroups {
+  // CHECK-NEXT:  func.func @entry
+  stream.executable private @ex {
     stream.executable.export public @entry workgroups(%arg0: index, %arg1: index) -> (index, index, index) {
       stream.return %arg0, %arg1, %arg0 : index, index, index
     }
@@ -56,8 +55,10 @@
     %c2 = arith.constant 2 : index
     %0 = stream.resource.alloc uninitialized : !stream.resource<transient>{%arg2}
     %1 = stream.cmd.execute with(%arg0 as %arg4: !stream.resource<constant>{%arg2}, %arg1 as %arg5: !stream.resource<transient>{%arg2}, %0 as %arg6: !stream.resource<transient>{%arg2}) {
-      // CHECK: stream.cmd.dispatch {@ex_workgroups::@embedded_elf_arm_64::@entry, @ex_workgroups::@embedded_elf_x86_64::@entry}
-      stream.cmd.dispatch @ex_workgroups::@entry[%c1, %c2](%arg3 : i32) {
+      // CHECK: stream.cmd.dispatch
+      // CHECK-SAME: @ex::@arm_64::@entry
+      // CHECK-SAME: @ex::@x86_64::@entry
+      stream.cmd.dispatch @ex::@entry[%c1, %c2](%arg3 : i32) {
         ro %arg4[%c0 for %arg2] : !stream.resource<constant>{%arg2},
         ro %arg5[%c0 for %arg2] : !stream.resource<transient>{%arg2},
         wo %arg6[%c0 for %arg2] : !stream.resource<transient>{%arg2}
@@ -70,56 +71,204 @@
 
 // -----
 
+// Tests that executable variants are expanded based on what devices they are
+// dispatched on.
+
+module attributes {
+  // The default device when none is specified.
+  // Functions and scopes can override the target device.
+  hal.device.targets = [
+    #hal.device.target<"cpu", {
+      executable_targets = [
+        #hal.executable.target<"llvm-cpu", "arm_64">,
+        #hal.executable.target<"llvm-cpu", "x86_64">
+      ]
+    }>
+  ]
+} {
+  // CHECK: hal.executable private @ex
+  // CHECK:   hal.executable.variant public @arm_64
+  // CHECK:   hal.executable.variant public @riscv_32
+  // CHECK:   hal.executable.variant public @x86_64
+  stream.executable private @ex {
+    stream.executable.export public @entry workgroups() -> (index, index, index) {
+      %c1 = arith.constant 1 : index
+      stream.return %c1, %c1, %c1 : index, index, index
+    }
+    builtin.module {
+      func.func @entry(%arg0: !stream.binding {stream.alignment = 64 : index}) {
+        return
+      }
+    }
+  }
+  // This function uses the default HAL device targeting arm_64 and x86_64.
+  // CHECK-LABEL: @using_default
+  util.func public @using_default(%arg0: !stream.resource<transient>, %arg1: index) -> !stream.timepoint {
+    %c0 = arith.constant 0 : index
+    %0 = stream.cmd.execute with(%arg0 as %arg2: !stream.resource<transient>{%arg1}) {
+      // CHECK: stream.cmd.dispatch
+      // CHECK-SAME: @ex::@arm_64::@entry
+      // CHECK-NOT: @ex::@riscv_32::@entry
+      // CHECK-SAME: @ex::@x86_64::@entry
+      stream.cmd.dispatch @ex::@entry {
+        rw %arg2[%c0 for %arg1] : !stream.resource<transient>{%arg1}
+      }
+    } => !stream.timepoint
+    util.return %0 : !stream.timepoint
+  }
+  // This function is specialized to only run on only riscv_32 and should
+  // not get assigned the arm_64/x86_64 variant entry points.
+  // CHECK-LABEL: @using_specialized
+  util.func public @using_specialized(%arg0: !stream.resource<transient>, %arg1: index) -> !stream.timepoint attributes {
+    hal.device.targets = [
+      #hal.device.target<"cpu", {
+        executable_targets = [
+          #hal.executable.target<"llvm-cpu", "riscv_32">
+        ]
+      }>
+    ]
+  } {
+    %c0 = arith.constant 0 : index
+    %0 = stream.cmd.execute with(%arg0 as %arg2: !stream.resource<transient>{%arg1}) {
+      // CHECK: stream.cmd.dispatch
+      // CHECK-NOT: @ex::@arm_64::@entry
+      // CHECK-SAME: @ex::@riscv_32::@entry
+      // CHECK-NOT: @ex::@x86_64::@entry
+      stream.cmd.dispatch @ex::@entry {
+        rw %arg2[%c0 for %arg1] : !stream.resource<transient>{%arg1}
+      }
+    } => !stream.timepoint
+    util.return %0 : !stream.timepoint
+  }
+}
+
+// -----
+
 // Tests an already-specified executable source op is expanded into the variants
 // specified by the target configuration. These source executables may come from
 // hand-authored code or other dialects that perform interface assignment
 // themselves.
 
-module attributes {hal.device.targets = [
-  #hal.device.target<"llvm-cpu", {
-    executable_targets = [
-      #hal.executable.target<"llvm-cpu", "embedded-elf-arm_64">,
-      #hal.executable.target<"llvm-cpu", "embedded-elf-x86_64">
-    ]
-  }>
-]} {
-
-hal.executable.source public @ex {
-  hal.executable.export public @entry layout(#hal.pipeline.layout<push_constants = 1, sets = [
-    #hal.descriptor_set.layout<0, bindings = [
-      #hal.descriptor_set.binding<0, storage_buffer>
-    ]>,
-    #hal.descriptor_set.layout<1, bindings = [
-      #hal.descriptor_set.binding<0, storage_buffer>,
-      #hal.descriptor_set.binding<1, storage_buffer>
-    ]>
-  ]>)
-  builtin.module {
-    func.func @entry() {
-      %const0 = hal.interface.constant.load[0] : index
-      %const1 = hal.interface.constant.load[1] : index
-      %s0b0 = hal.interface.binding.subspan set(0) binding(0) type(storage_buffer) alignment(32) offset(%const0) : !flow.dispatch.tensor<readonly:tensor<4xf32>>
-      %s1b0 = hal.interface.binding.subspan set(1) binding(0) type(storage_buffer) alignment(32) offset(%const1) : !flow.dispatch.tensor<readonly:tensor<4xf32>>
-      %s1b1 = hal.interface.binding.subspan set(1) binding(1) type(storage_buffer) alignment(16) : !flow.dispatch.tensor<writeonly:tensor<4xf32>>
-      %workgroup_size_x = hal.interface.workgroup.size[0] : index
-      %workgroup_id_x = hal.interface.workgroup.id[0] : index
-      %workgroup_count_x = hal.interface.workgroup.count[0] : index
-      return
+module attributes {
+  // The default device when none is specified.
+  // Functions and scopes can override the target device.
+  hal.device.targets = [
+    #hal.device.target<"cpu", {
+      executable_targets = [
+        #hal.executable.target<"llvm-cpu", "arm_64">,
+        #hal.executable.target<"llvm-cpu", "x86_64">
+      ]
+    }>
+  ]
+} {
+  // CHECK: hal.executable private @ex
+  // CHECK:   hal.executable.variant public @arm_64
+  // CHECK:   hal.executable.variant public @riscv_32
+  // CHECK:   hal.executable.variant public @x86_64
+  hal.executable.source private @ex {
+    hal.executable.export public @entry layout(#hal.pipeline.layout<push_constants = 0, sets = [
+      #hal.descriptor_set.layout<0, bindings = [
+        #hal.descriptor_set.binding<0, storage_buffer>
+      ]>
+    ]>)
+    builtin.module {
+      func.func @entry() {
+        return
+      }
     }
   }
+  // This function uses the default HAL device targeting arm_64 and x86_64.
+  // CHECK-LABEL: @using_default
+  util.func public @using_default(%arg0: !stream.resource<transient>, %arg1: index) -> !stream.timepoint {
+    %c0 = arith.constant 0 : index
+    %0 = stream.cmd.execute with(%arg0 as %arg2: !stream.resource<transient>{%arg1}) {
+      // CHECK: stream.cmd.dispatch
+      // CHECK-SAME: @ex::@arm_64::@entry
+      // CHECK-NOT: @ex::@riscv_32::@entry
+      // CHECK-SAME: @ex::@x86_64::@entry
+      stream.cmd.dispatch @ex::@entry {
+        rw %arg2[%c0 for %arg1] : !stream.resource<transient>{%arg1}
+      }
+    } => !stream.timepoint
+    util.return %0 : !stream.timepoint
+  }
+  // This function is specialized to only run on only riscv_32 and should
+  // not get assigned the arm_64/x86_64 variant entry points.
+  // CHECK-LABEL: @using_specialized
+  util.func public @using_specialized(%arg0: !stream.resource<transient>, %arg1: index) -> !stream.timepoint attributes {
+    hal.device.targets = [
+      #hal.device.target<"cpu", {
+        executable_targets = [
+          #hal.executable.target<"llvm-cpu", "riscv_32">
+        ]
+      }>
+    ]
+  } {
+    %c0 = arith.constant 0 : index
+    %0 = stream.cmd.execute with(%arg0 as %arg2: !stream.resource<transient>{%arg1}) {
+      // CHECK: stream.cmd.dispatch
+      // CHECK-NOT: @ex::@arm_64::@entry
+      // CHECK-SAME: @ex::@riscv_32::@entry
+      // CHECK-NOT: @ex::@x86_64::@entry
+      stream.cmd.dispatch @ex::@entry {
+        rw %arg2[%c0 for %arg1] : !stream.resource<transient>{%arg1}
+      }
+    } => !stream.timepoint
+    util.return %0 : !stream.timepoint
+  }
 }
 
-// CHECK: hal.executable public @ex
-// CHECK:   hal.executable.variant public @embedded_elf_arm_64 target(#executable_target_embedded_elf_arm_64
-// CHECK:     hal.executable.export public @entry layout(#pipeline_layout)
-// CHECK:     builtin.module
-// CHECK-NEXT:  func.func @entry()
-// CHECK:   hal.executable.variant public @embedded_elf_x86_64 target(#executable_target_embedded_elf_x86_64
-// CHECK:     hal.executable.export public @entry layout(#pipeline_layout)
-// CHECK:     builtin.module
-// CHECK-NEXT:  func.func @entry()
+// -----
 
-// TODO(benvanik): test fixup of stream ops when attrs to specify the
-// layout bindings are implemented.
+// Tests that a hal.executable.source op gets expanded to all default targets
+// when it's public in addition to any ones from dispatch sites.
 
+module attributes {
+  hal.device.targets = [
+    #hal.device.target<"cpu", {
+      executable_targets = [
+        #hal.executable.target<"llvm-cpu", "arm_64">,
+        #hal.executable.target<"llvm-cpu", "x86_64">
+      ]
+    }>
+  ]
+} {
+  // CHECK: hal.executable public @ex
+  // CHECK:   hal.executable.variant public @arm_64
+  // CHECK:   hal.executable.variant public @riscv_32
+  // CHECK:   hal.executable.variant public @x86_64
+  hal.executable.source public @ex {
+    hal.executable.export public @entry layout(#hal.pipeline.layout<push_constants = 0, sets = [
+      #hal.descriptor_set.layout<0, bindings = [
+        #hal.descriptor_set.binding<0, storage_buffer>
+      ]>
+    ]>)
+    builtin.module {
+      func.func @entry() {
+        return
+      }
+    }
+  }
+  // CHECK-LABEL: @using_specialized
+  util.func public @using_specialized(%arg0: !stream.resource<transient>, %arg1: index) -> !stream.timepoint attributes {
+    hal.device.targets = [
+      #hal.device.target<"cpu", {
+        executable_targets = [
+          #hal.executable.target<"llvm-cpu", "riscv_32">
+        ]
+      }>
+    ]
+  } {
+    %c0 = arith.constant 0 : index
+    %0 = stream.cmd.execute with(%arg0 as %arg2: !stream.resource<transient>{%arg1}) {
+      // CHECK: stream.cmd.dispatch
+      // CHECK-NOT: @ex::@arm_64::@entry
+      // CHECK-SAME: @ex::@riscv_32::@entry
+      // CHECK-NOT: @ex::@x86_64::@entry
+      stream.cmd.dispatch @ex::@entry {
+        rw %arg2[%c0 for %arg1] : !stream.resource<transient>{%arg1}
+      }
+    } => !stream.timepoint
+    util.return %0 : !stream.timepoint
+  }
 }