Moving MaterializeInterfaces' spooky action at a distance around a little. (#16521)

The magic `hal.interface.bindings` attribute, which aligns
`stream.cmd.dispatch` ops with HAL interfaces during conversion, is now
attached to the export ops instead of being added on the dispatches.
This allows each variant to have its own binding mapping and thus its
own pipeline layout in subsequent steps. We still generate the same
mapping for everything today, but externally-provided executables are
now allowed to differ, and in the future target backends can specify
their own. This will also allow us to potentially perform
pruning/linking prior to conversion for cases where some variants are
only dispatched on certain devices and we want to optimize layouts to
reduce the number of layout changes during execution.

Future changes will rework this code further to reduce the number of
full-module walks we perform and to allow for scoped device targets.
diff --git a/compiler/src/iree/compiler/Dialect/Flow/Transforms/OutlineDispatchExterns.cpp b/compiler/src/iree/compiler/Dialect/Flow/Transforms/OutlineDispatchExterns.cpp
index 32487bf..19ce74e 100644
--- a/compiler/src/iree/compiler/Dialect/Flow/Transforms/OutlineDispatchExterns.cpp
+++ b/compiler/src/iree/compiler/Dialect/Flow/Transforms/OutlineDispatchExterns.cpp
@@ -43,9 +43,6 @@
       dispatchExternOp.getArguments(), dispatchExternOp.getArgumentDims(),
       dispatchExternOp.getResultDims(), dispatchExternOp.getTiedOperandsAttr());
   dispatchOp->setDialectAttrs(dispatchExternOp->getDialectAttrs());
-  if (auto bindingsAttr = dispatchExternOp.getBindingsAttr()) {
-    dispatchOp->setAttr("hal.interface.bindings", bindingsAttr);
-  }
 
   // Replace uses of the existing results with the new results.
   for (int i = 0; i < dispatchExternOp.getNumResults(); ++i) {
@@ -110,6 +107,9 @@
         dispatchExternOp.getSubgroupSizeAttr(),
         dispatchExternOp.getWorkgroupLocalMemoryAttr());
     exportOp->setDialectAttrs(dispatchExternOp->getDialectAttrs());
+    if (auto bindingsAttr = dispatchExternOp.getBindingsAttr()) {
+      exportOp->setAttr("hal.interface.bindings", bindingsAttr);
+    }
     if (!dispatchExternOp.getWorkgroupCount().empty()) {
       IRMapping mapper;
       dispatchExternOp.getWorkgroupCount().cloneInto(
diff --git a/compiler/src/iree/compiler/Dialect/Flow/Transforms/test/outline_dispatch_externs.mlir b/compiler/src/iree/compiler/Dialect/Flow/Transforms/test/outline_dispatch_externs.mlir
index 70bb373..4e73fe4 100644
--- a/compiler/src/iree/compiler/Dialect/Flow/Transforms/test/outline_dispatch_externs.mlir
+++ b/compiler/src/iree/compiler/Dialect/Flow/Transforms/test/outline_dispatch_externs.mlir
@@ -5,6 +5,7 @@
 // CHECK-SAME:       objects([#hal.executable.object<{path = "a.o"}>])
 // CHECK-NEXT:     hal.executable.export public @main ordinal(100)
 // CHECK-SAME:         layout(#hal.pipeline.layout<push_constants = 1, sets = [<0, bindings = [<0, storage_buffer, ReadOnly>, <1, storage_buffer>]>]>)
+// CHECK-SAME:         hal.interface.bindings = [#hal.interface.binding<0, 0>, #hal.interface.binding<0, 1>]
 // CHECK-NEXT:     ^bb0(%arg0: !hal.device, %arg1: index, %arg2: index):
 // CHECK-NEXT:       %ok, %value = hal.device.query<%arg0 : !hal.device> key("some" :: "value") : i1, i32
 // CHECK-NEXT:       %0 = arith.index_cast %value : i32 to index
@@ -16,6 +17,7 @@
 // CHECK-NEXT:       hal.return %ok : i1
 //      CHECK:     hal.executable.export public @main ordinal(200)
 // CHECK-SAME:         layout(#hal.pipeline.layout<push_constants = 1, sets = [<0, bindings = [<0, storage_buffer, ReadOnly>, <1, storage_buffer>]>]>)
+// CHECK-SAME:         hal.interface.bindings = [#hal.interface.binding<0, 0>, #hal.interface.binding<0, 1>]
 // CHECK-NEXT:     ^bb0(%arg0: !hal.device, %arg1: index, %arg2: index):
 
 // Demonstrates the full functionality of an extern dispatch op.
@@ -27,9 +29,7 @@
   %y = arith.constant 50 : index
   // Dispatch workgroups to the externally defined function "main" in the
   // referenced object files.
-  // CHECK: %[[RESULT:.+]] = flow.dispatch {@extern_dispatch_0::@a::@main, @extern_dispatch_0::@b::@main}[%c100, %c50](%arg0, %arg1, %arg2) {
-  // CHECK-SAME: hal.interface.bindings = [#hal.interface.binding<0, 0>, #hal.interface.binding<0, 1>]
-  // CHECK-SAME: } : (tensor<4xi32>, tensor<8xi32>, i32) -> %arg1
+  // CHECK: %[[RESULT:.+]] = flow.dispatch {@extern_dispatch_0::@a::@main, @extern_dispatch_0::@b::@main}
   %result = hal.dispatch.extern "main"[%x, %y](%arg0, %arg1, %arg2) : (tensor<4xi32>, tensor<8xi32>, i32) -> %arg1
     // Translates the workload (%x and %y captured above) into an XYZ workgroup
     // count, optionally using device information.
diff --git a/compiler/src/iree/compiler/Dialect/HAL/Analysis/BindingLayout.cpp b/compiler/src/iree/compiler/Dialect/HAL/Analysis/BindingLayout.cpp
index 486edc1..9676ef0 100644
--- a/compiler/src/iree/compiler/Dialect/HAL/Analysis/BindingLayout.cpp
+++ b/compiler/src/iree/compiler/Dialect/HAL/Analysis/BindingLayout.cpp
@@ -128,7 +128,7 @@
   unsigned operandCount = 0;
   unsigned bindingCount = 0;
   for (auto arg : funcOp.getArgumentTypes()) {
-    if (llvm::isa<IREE::Stream::BindingType>(arg)) {
+    if (isa<IREE::Stream::BindingType>(arg)) {
       ++bindingCount;
     } else {
       ++operandCount;
@@ -140,9 +140,8 @@
   for (auto dispatchOp : dispatchOps) {
     auto resourceAccessesAttrs = dispatchOp.getResourceAccesses().getValue();
     for (unsigned i = 0; i < bindingCount; ++i) {
-      auto resourceAccessAttr =
-          llvm::cast<IREE::Stream::ResourceAccessBitfieldAttr>(
-              resourceAccessesAttrs[i]);
+      auto resourceAccessAttr = cast<IREE::Stream::ResourceAccessBitfieldAttr>(
+          resourceAccessesAttrs[i]);
       auto resourceAccess = static_cast<IREE::Stream::ResourceAccessBitfield>(
           resourceAccessAttr.getInt());
       if (!bitEnumContainsAll(resourceAccess,
diff --git a/compiler/src/iree/compiler/Dialect/HAL/Conversion/StreamToHAL/Patterns.cpp b/compiler/src/iree/compiler/Dialect/HAL/Conversion/StreamToHAL/Patterns.cpp
index 6311a2b..b905fee 100644
--- a/compiler/src/iree/compiler/Dialect/HAL/Conversion/StreamToHAL/Patterns.cpp
+++ b/compiler/src/iree/compiler/Dialect/HAL/Conversion/StreamToHAL/Patterns.cpp
@@ -720,8 +720,8 @@
       auto caseBuilder = OpBuilder::atBlockBegin(&caseBlock);
 
       // Record push constants and buffer bindings.
-      recordParameters(loc, affinityAttr, device, commandBuffer, dispatchOp,
-                       adaptor, exportOp.getLayout(), caseBuilder);
+      recordParameters(loc, affinityAttr, device, commandBuffer, exportOp,
+                       dispatchOp, adaptor, caseBuilder);
 
       // Dispatch with a target-specific workgroup count.
       auto caseWorkgroupCount = exportOp.calculateWorkgroupCount(
@@ -749,10 +749,10 @@
 
   void recordParameters(Location loc, IREE::Stream::AffinityAttr affinityAttr,
                         Value device, Value commandBuffer,
+                        IREE::HAL::ExecutableExportOp exportOp,
                         IREE::Stream::CmdDispatchOp dispatchOp,
-                        OpAdaptor adaptor,
-                        IREE::HAL::PipelineLayoutAttr layoutAttr,
-                        OpBuilder &builder) const {
+                        OpAdaptor adaptor, OpBuilder &builder) const {
+    auto layoutAttr = exportOp.getLayout();
     auto pipelineLayout =
         builder
             .create<IREE::HAL::PipelineLayoutLookupOp>(
@@ -777,12 +777,6 @@
           builder.getIndexAttr(pushConstantBase), pushConstants);
     }
 
-    // TODO(benvanik): typed accessors for bindings.
-    auto bindingAttrs = llvm::dyn_cast_if_present<ArrayAttr>(
-        dispatchOp->getAttr("hal.interface.bindings"));
-    assert(bindingAttrs &&
-           "interface materialization must annotate dispatch sites");
-
     // Push descriptor bindings.
     int64_t currentSet = -1;
     SmallVector<IREE::HAL::DescriptorSetBindingValue> bindings;
@@ -791,9 +785,9 @@
           loc, commandBuffer, pipelineLayout, currentSet, bindings);
       bindings.clear();
     };
-    for (unsigned i = 0; i < adaptor.getResources().size(); ++i) {
-      auto bindingAttr =
-          llvm::cast<IREE::HAL::InterfaceBindingAttr>(bindingAttrs[i]);
+    auto bindingAttrs = IREE::HAL::getInterfaceBindingAttrs(
+        exportOp, dispatchOp.getResources().size());
+    for (auto [i, bindingAttr] : llvm::enumerate(bindingAttrs)) {
       int64_t set = bindingAttr.getSet();
       if (currentSet != -1 && currentSet != set)
         flushSet();
diff --git a/compiler/src/iree/compiler/Dialect/HAL/Conversion/StreamToHAL/test/cmd_ops.mlir b/compiler/src/iree/compiler/Dialect/HAL/Conversion/StreamToHAL/test/cmd_ops.mlir
index bcc2ff2..9093f31 100644
--- a/compiler/src/iree/compiler/Dialect/HAL/Conversion/StreamToHAL/test/cmd_ops.mlir
+++ b/compiler/src/iree/compiler/Dialect/HAL/Conversion/StreamToHAL/test/cmd_ops.mlir
@@ -184,6 +184,10 @@
       hal.return %selected : i1
     }
     hal.executable.export public @dispatch ordinal(0) layout(#pipeline_layout) attributes {
+      hal.interface.bindings = [
+        #hal.interface.binding<0, 4>,
+        #hal.interface.binding<1, 5>
+      ],
       translation_info = #iree_codegen.translation_info<CPUDefault>
     } {
     ^bb0(%device: !hal.device, %arg0: index, %arg1: index, %arg2: index):  // no predecessors
@@ -197,6 +201,10 @@
   }
   hal.executable.variant public @x86_64 target(#executable_target_x86_64) {
     hal.executable.export public @dispatch ordinal(0) layout(#pipeline_layout) attributes {
+      hal.interface.bindings = [
+        #hal.interface.binding<0, 4>,
+        #hal.interface.binding<1, 5>
+      ],
       translation_info = #iree_codegen.translation_info<CPUDefault>
     } {
     ^bb0(%device: !hal.device, %arg0: index, %arg1: index, %arg2: index):  // no predecessors
@@ -276,11 +284,6 @@
     stream.cmd.dispatch {@ex::@aarch64::@dispatch, @ex::@x86_64::@dispatch}[%c1, %c2, %c3](%c4_i32, %c5_i32 : i32, i32) {
       ro %arg4[%c0 for %c128] : !stream.resource<transient>{%arg1},
       wo %arg5[%c0 for %c128] : !stream.resource<external>{%arg3}
-    } attributes {
-      hal.interface.bindings = [
-        #hal.interface.binding<0, 4>,
-        #hal.interface.binding<1, 5>
-      ]
     }
     // CHECK: hal.command_buffer.execution_barrier<%[[CMD]]
   } => !stream.timepoint
diff --git a/compiler/src/iree/compiler/Dialect/HAL/IR/HALTypes.cpp b/compiler/src/iree/compiler/Dialect/HAL/IR/HALTypes.cpp
index 292c05a..afc2932 100644
--- a/compiler/src/iree/compiler/Dialect/HAL/IR/HALTypes.cpp
+++ b/compiler/src/iree/compiler/Dialect/HAL/IR/HALTypes.cpp
@@ -87,6 +87,30 @@
 }
 
 //===----------------------------------------------------------------------===//
+// Utilities
+//===----------------------------------------------------------------------===//
+
+SmallVector<IREE::HAL::InterfaceBindingAttr>
+getInterfaceBindingAttrs(Operation *op, size_t resourceCount) {
+  // It'd be nice if we had something typed here but this is just used for
+  // spooky action at a distance or user overrides. If the attribute is not
+  // found (not set by MaterializeInterfaces or the user) we construct one by
+  // convention (dense set 0 bindings for each resource).
+  auto bindingAttrs = op->getAttrOfType<ArrayAttr>("hal.interface.bindings");
+  if (bindingAttrs) {
+    return llvm::to_vector(
+        bindingAttrs.getAsRange<IREE::HAL::InterfaceBindingAttr>());
+  }
+  SmallVector<IREE::HAL::InterfaceBindingAttr> bindings;
+  for (size_t i = 0; i < resourceCount; ++i) {
+    bindings.push_back(IREE::HAL::InterfaceBindingAttr::get(op->getContext(),
+                                                            /*set=*/0,
+                                                            /*binding=*/i));
+  }
+  return bindings;
+}
+
+//===----------------------------------------------------------------------===//
 // Dialect registration
 //===----------------------------------------------------------------------===//
 
diff --git a/compiler/src/iree/compiler/Dialect/HAL/IR/HALTypes.h b/compiler/src/iree/compiler/Dialect/HAL/IR/HALTypes.h
index 8bf6036..e50a9d2 100644
--- a/compiler/src/iree/compiler/Dialect/HAL/IR/HALTypes.h
+++ b/compiler/src/iree/compiler/Dialect/HAL/IR/HALTypes.h
@@ -278,4 +278,17 @@
 #include "iree/compiler/Dialect/HAL/IR/HALAttrs.h.inc" // IWYU pragma: keep
 // clang-format on
 
+//===----------------------------------------------------------------------===//
+// Utilities
+//===----------------------------------------------------------------------===//
+
+namespace mlir::iree_compiler::IREE::HAL {
+
+// Returns the assigned bindings via the `hal.interface.bindings` attribute
+// on the operation or default bindings in set 0 with bindings [0, count).
+SmallVector<IREE::HAL::InterfaceBindingAttr>
+getInterfaceBindingAttrs(Operation *op, size_t resourceCount);
+
+} // namespace mlir::iree_compiler::IREE::HAL
+
 #endif // IREE_COMPILER_DIALECT_HAL_IR_HALTYPES_H_
diff --git a/compiler/src/iree/compiler/Dialect/HAL/Transforms/ConvertToHAL.cpp b/compiler/src/iree/compiler/Dialect/HAL/Transforms/ConvertToHAL.cpp
index 4ccb387..06d2366 100644
--- a/compiler/src/iree/compiler/Dialect/HAL/Transforms/ConvertToHAL.cpp
+++ b/compiler/src/iree/compiler/Dialect/HAL/Transforms/ConvertToHAL.cpp
@@ -88,6 +88,17 @@
                                       std::move(patterns)))) {
       return signalPassFailure();
     }
+
+    // Cleanup conversion attributes used for spooky action at a distance.
+    for (auto executableOp : getOperation().getOps<IREE::HAL::ExecutableOp>()) {
+      for (auto variantOp :
+           executableOp.getOps<IREE::HAL::ExecutableVariantOp>()) {
+        for (auto exportOp :
+             variantOp.getOps<IREE::HAL::ExecutableExportOp>()) {
+          exportOp->removeAttr("hal.interface.bindings");
+        }
+      }
+    }
   }
 };
 
diff --git a/compiler/src/iree/compiler/Dialect/HAL/Transforms/DumpExecutableBenchmarks.cpp b/compiler/src/iree/compiler/Dialect/HAL/Transforms/DumpExecutableBenchmarks.cpp
index b1f902f..f090386 100644
--- a/compiler/src/iree/compiler/Dialect/HAL/Transforms/DumpExecutableBenchmarks.cpp
+++ b/compiler/src/iree/compiler/Dialect/HAL/Transforms/DumpExecutableBenchmarks.cpp
@@ -68,19 +68,14 @@
 // one entry for all dispatches with a given workgroup count.
 // Dispatches will be ignored if they have a dynamic workload or any dynamically
 // sized resources.
-static DispatchParamsMap gatherDispatchParams(mlir::ModuleOp moduleOp) {
+static DispatchParamsMap gatherDispatchParams(mlir::ModuleOp moduleOp,
+                                              SymbolTable &symbolTable) {
   DispatchParamsMap map;
 
   for (auto funcOp : moduleOp.getOps<mlir::FunctionOpInterface>()) {
     funcOp.walk([&](IREE::Stream::CmdDispatchOp dispatchOp) {
       auto affinityAttr = IREE::Stream::AffinityAttr::lookup(dispatchOp);
 
-      // TODO(benvanik): typed accessors for bindings.
-      auto bindingAttrs = llvm::dyn_cast_if_present<ArrayAttr>(
-          dispatchOp->getAttr("hal.interface.bindings"));
-      assert(bindingAttrs &&
-             "interface materialization must annotate dispatch sites");
-
       auto workloadValues = dispatchOp.getWorkload();
       SmallVector<unsigned> workload;
       workload.reserve(workloadValues.size());
@@ -97,25 +92,6 @@
         workload.push_back(workloadConstValue.getSExtValue());
       }
 
-      SmallVector<Binding> bindings;
-      for (auto [bindingAttr, resourceLength] : llvm::zip_equal(
-               bindingAttrs.getAsRange<IREE::HAL::InterfaceBindingAttr>(),
-               dispatchOp.getResourceLengths())) {
-        APInt resourceLengthInt;
-        if (!matchPattern(resourceLength, m_ConstantInt(&resourceLengthInt))) {
-          LLVM_DEBUG({
-            auto firstEntryPoint = *dispatchOp.getEntryPointRefs().begin();
-            llvm::dbgs() << "Skipping dispatch of entry point `"
-                         << firstEntryPoint
-                         << "` (non-constant resource length)\n";
-          });
-          return;
-        }
-        bindings.push_back({(unsigned)bindingAttr.getSet(),
-                            (unsigned)bindingAttr.getBinding(),
-                            resourceLengthInt.getSExtValue()});
-      }
-
       SmallVector<TypedAttr> uniformOperands;
       for (auto operand : dispatchOp.getUniformOperands()) {
         TypedAttr uniformOperand;
@@ -135,6 +111,28 @@
 
       // Work around needing a mutable key for the set; C++ was a mistake.
       dispatchOp.forEachEntryPointAttr([&](SymbolRefAttr entryPointAttr) {
+        auto exportOp =
+            symbolTable.lookupNearestSymbolFrom<IREE::HAL::ExecutableExportOp>(
+                dispatchOp, entryPointAttr);
+        auto bindingAttrs = IREE::HAL::getInterfaceBindingAttrs(
+            exportOp, dispatchOp.getResources().size());
+
+        SmallVector<Binding> bindings;
+        for (auto [bindingAttr, resourceLength] :
+             llvm::zip_equal(bindingAttrs, dispatchOp.getResourceLengths())) {
+          APInt resourceLengthInt;
+          if (!matchPattern(resourceLength,
+                            m_ConstantInt(&resourceLengthInt))) {
+            LLVM_DEBUG(llvm::dbgs() << "Skipping dispatch of entry point `"
+                                    << entryPointAttr
+                                    << "` (non-constant resource length)\n";);
+            return;
+          }
+          bindings.push_back({(unsigned)bindingAttr.getSet(),
+                              (unsigned)bindingAttr.getBinding(),
+                              resourceLengthInt.getSExtValue()});
+        }
+
         auto &dispatchParamsSet = map[entryPointAttr];
         DispatchParams *dispatchParams = nullptr;
         for (auto &it : dispatchParamsSet) {
@@ -486,12 +484,13 @@
   void runOnOperation() override {
     auto moduleOp = getOperation();
     auto moduleName = moduleOp.getName().value_or("module");
+    SymbolTable symbolTable(moduleOp);
 
     // Analyze the module to find dispatch parameters.
     // This is a full walk of all stream.cmd.dispatch ops and will handle
     // filtering out dispatches that have dynamic parameters we don't
     // currently support.
-    auto dispatchParamsMap = gatherDispatchParams(moduleOp);
+    auto dispatchParamsMap = gatherDispatchParams(moduleOp, symbolTable);
     if (dispatchParamsMap.empty()) {
       mlir::emitRemark(moduleOp.getLoc())
           << "Executable benchmarks were requested but none were generated. "
diff --git a/compiler/src/iree/compiler/Dialect/HAL/Transforms/MaterializeInterfaces.cpp b/compiler/src/iree/compiler/Dialect/HAL/Transforms/MaterializeInterfaces.cpp
index 8586f75..022898f 100644
--- a/compiler/src/iree/compiler/Dialect/HAL/Transforms/MaterializeInterfaces.cpp
+++ b/compiler/src/iree/compiler/Dialect/HAL/Transforms/MaterializeInterfaces.cpp
@@ -40,7 +40,7 @@
 using ExportExpansions = DenseMap<Attribute, SmallVector<Attribute>>;
 
 //===----------------------------------------------------------------------===//
-// Linkage utilities
+// Utilities
 //===----------------------------------------------------------------------===//
 
 static void setApplicableObjects(Operation *sourceOp,
@@ -55,6 +55,25 @@
   targetOp.setObjectsAttr(*objects);
 }
 
+// Updates the target entry point symbols of |dispatchOp| to the expanded set of
+// variant exports in |exportExpansions|.
+static void updateDispatchTargets(IREE::Stream::CmdDispatchOp dispatchOp,
+                                  const ExportExpansions &exportExpansions) {
+  SmallVector<Attribute> newAttrs;
+  for (auto oldAttr : dispatchOp.getEntryPointRefs()) {
+    auto it = exportExpansions.find(oldAttr);
+    if (it == exportExpansions.end()) {
+      newAttrs.push_back(oldAttr); // preserve existing
+      continue;
+    }
+    for (auto newAttr : it->second) {
+      newAttrs.push_back(newAttr);
+    }
+  }
+  dispatchOp.setEntryPointsAttr(
+      ArrayAttr::get(dispatchOp.getContext(), newAttrs));
+}
+
 //===----------------------------------------------------------------------===//
 // hal.executable.source materialization
 //===----------------------------------------------------------------------===//
@@ -126,8 +145,9 @@
 }
 
 static LogicalResult
-materializeExecutablesFromSourceOps(mlir::ModuleOp moduleOp,
-                                    ExportExpansions &exportExpansions) {
+materializeExecutablesFromSourceOps(mlir::ModuleOp moduleOp) {
+  ExportExpansions exportExpansions;
+
   auto sourceOps =
       llvm::to_vector<32>(moduleOp.getOps<IREE::HAL::ExecutableSourceOp>());
   for (auto sourceOp : sourceOps) {
@@ -145,6 +165,15 @@
       return failure();
     }
   }
+  if (exportExpansions.empty())
+    return success();
+
+  for (auto funcOp : moduleOp.getOps<FunctionOpInterface>()) {
+    funcOp.walk([&](IREE::Stream::CmdDispatchOp dispatchOp) {
+      updateDispatchTargets(dispatchOp, exportExpansions);
+    });
+  }
+
   return success();
 }
 
@@ -272,49 +301,13 @@
   return clonedFuncOp;
 }
 
-// Updates the target entry point symbols of |dispatchOp| to the expanded set of
-// variant exports in |exportExpansions|.
-static void updateDispatchTargets(IREE::Stream::CmdDispatchOp dispatchOp,
-                                  const ExportExpansions &exportExpansions) {
-  SmallVector<Attribute> newAttrs;
-  for (auto oldAttr : dispatchOp.getEntryPointRefs()) {
-    auto it = exportExpansions.find(oldAttr);
-    if (it == exportExpansions.end()) {
-      newAttrs.push_back(oldAttr); // preserve existing
-      continue;
-    }
-    for (auto newAttr : it->second) {
-      newAttrs.push_back(newAttr);
-    }
-  }
-  dispatchOp.setEntryPointsAttr(
-      ArrayAttr::get(dispatchOp.getContext(), newAttrs));
-}
-
-// Annotates |dispatchOp| with resource binding to interface binding mappings.
-// TODO(benvanik): have a HAL op with structured information instead.
-static void annotateDispatchSite(IREE::Stream::CmdDispatchOp dispatchOp,
-                                 const PipelineResourceMap &resourceMap) {
-  // Ignore if bindings already defined.
-  if (dispatchOp->hasAttr("hal.interface.bindings"))
-    return;
-  SmallVector<Attribute> bindingAttrs;
-  for (auto setBinding : resourceMap) {
-    bindingAttrs.push_back(IREE::HAL::InterfaceBindingAttr::get(
-        dispatchOp.getContext(), setBinding.first, setBinding.second));
-  }
-  dispatchOp->setAttr("hal.interface.bindings",
-                      ArrayAttr::get(dispatchOp.getContext(), bindingAttrs));
-}
-
 // Adds the entry point ops with assigned ordinals for each entry function.
 // The entry points will all use the provided |interfaceOp| and be exported with
 // hal.executable.export ops.
 static LogicalResult
 declareEntryPointOps(IREE::Stream::ExecutableOp sourceExecutableOp,
                      IREE::HAL::ExecutableOp targetExecutableOp,
-                     const BindingLayoutAnalysis &layoutAnalysis,
-                     ExportExpansions &exportExpansions) {
+                     const BindingLayoutAnalysis &layoutAnalysis) {
   auto variantOps =
       targetExecutableOp.getBlock().getOps<IREE::HAL::ExecutableVariantOp>();
   OpBuilder executableBuilder(&targetExecutableOp.getBlock().front());
@@ -342,15 +335,8 @@
     const auto &pipelineLayout = layoutAnalysis.getPipelineLayout(exportOp);
     const PipelineResourceMap &resourceMap = pipelineLayout.resourceMap;
 
-    // Update all dispatch sites with the binding information required for
-    // conversion into the HAL dialect. By doing this here we ensure that the
-    // dialect conversion needs only local information on the ops and that it's
-    // not possible for the dispatches and their targets to get out of sync.
-    for (auto dispatchOp : layoutAnalysis.getExportDispatches(exportOp)) {
-      annotateDispatchSite(dispatchOp, resourceMap);
-    }
-
     // Clone the updated function declaration into each variant.
+    ExportExpansions exportExpansions;
     int ordinal = nextOrdinal++;
     for (auto variantOp : variantOps) {
       auto targetBuilder = OpBuilder::atBlockBegin(&variantOp.getBlock());
@@ -394,6 +380,17 @@
           .push_back(makeExportSymbolRefAttr(targetExecutableOp, variantOp,
                                              newExportOp));
 
+      // Annotate the export with a mapping of the resources to the
+      // interface bindings. This is used during conversion.
+      SmallVector<Attribute> bindingAttrs;
+      for (auto setBinding : resourceMap) {
+        bindingAttrs.push_back(IREE::HAL::InterfaceBindingAttr::get(
+            newExportOp.getContext(), setBinding.first, setBinding.second));
+      }
+      newExportOp->setAttr(
+          "hal.interface.bindings",
+          ArrayAttr::get(newExportOp.getContext(), bindingAttrs));
+
       // Clone the workgroup count calculation function.
       if (!exportOp.getWorkgroupCount().empty()) {
         mlir::IRMapping mapper;
@@ -415,6 +412,11 @@
         targetFuncOps[sourceFuncOp][variantOp] = variantFuncOp;
       }
     }
+
+    // Update all dispatch sites to reference the new expanded variants.
+    for (auto dispatchOp : layoutAnalysis.getExportDispatches(exportOp)) {
+      updateDispatchTargets(dispatchOp, exportExpansions);
+    }
   }
 
   // Clone all of the ops in the source module to each variant.
@@ -524,12 +526,9 @@
   void runOnOperation() override {
     SymbolTable symbolTable(getOperation());
 
-    ExportExpansions exportExpansions;
-
     // Handle any hand-authored executables; these only need variant expansion
     // and no layout analysis as the user specified the layout themselves.
-    if (failed(materializeExecutablesFromSourceOps(getOperation(),
-                                                   exportExpansions))) {
+    if (failed(materializeExecutablesFromSourceOps(getOperation()))) {
       return signalPassFailure();
     }
 
@@ -581,8 +580,8 @@
       }
 
       // Define interfaces for each exported function based on analysis.
-      if (failed(declareEntryPointOps(sourceOp, executableOp, layoutAnalysis,
-                                      exportExpansions))) {
+      if (failed(
+              declareEntryPointOps(sourceOp, executableOp, layoutAnalysis))) {
         return signalPassFailure();
       }
 
@@ -594,52 +593,6 @@
 
       sourceOp.erase();
     }
-
-    // Do a cleanup pass for any dispatches that don't yet have interfaces
-    // assigned. If we had dispatches to externally-defined HAL executables we
-    // won't have materialized them from the stream ops above. We do expect to
-    // be able to find the dispatch targets such that we can pull out the
-    // pipeline layout, though, and any that fall through are errors.
-    auto updateDispatchSites = [&](IREE::Stream::CmdDispatchOp dispatchOp) {
-      // Update the export targets to point at the new variants.
-      updateDispatchTargets(dispatchOp, exportExpansions);
-
-      // Annotate the dispatch site with binding information if required.
-      // TODO(benvanik): remove this path; shouldn't be needed in real usage.
-      // Because this is a hack we just look for the first target entry point.
-      PipelineResourceMap resourceMap;
-      auto anyEntryPointAttr = *dispatchOp.getEntryPointRefs().begin();
-      auto anyExportOp =
-          symbolTable.lookupNearestSymbolFrom<IREE::HAL::ExecutableExportOp>(
-              dispatchOp, anyEntryPointAttr);
-      if (anyExportOp) {
-        // Export found - we can use the pipeline layout defined there to infer
-        // the bindings. This allows for bindings to be sparse or have
-        // additional information declared.
-        for (auto setLayout : anyExportOp.getLayoutAttr().getSetLayouts()) {
-          for (auto binding : setLayout.getBindings()) {
-            resourceMap.emplace_back(setLayout.getOrdinal(),
-                                     binding.getOrdinal());
-          }
-        }
-      } else {
-        // No export found - this is likely an external executable and we can
-        // infer a dense pipeline layout. This is kind of shady as we may want
-        // to error in these cases where users have something special explicitly
-        // defined but then typo things but the ergonomic improvements in the
-        // normal case are worth that risk.
-        size_t resourceCount = dispatchOp.getResources().size();
-        for (int i = 0; i < resourceCount; ++i) {
-          // set=0, binding=resource ordinal
-          resourceMap.emplace_back(0, i);
-        }
-      }
-      annotateDispatchSite(dispatchOp, resourceMap);
-      return WalkResult::advance();
-    };
-    if (getOperation()->walk(updateDispatchSites).wasInterrupted()) {
-      return signalPassFailure();
-    }
   }
 };
 
diff --git a/compiler/src/iree/compiler/Dialect/HAL/Transforms/test/convert_to_hal.mlir b/compiler/src/iree/compiler/Dialect/HAL/Transforms/test/convert_to_hal.mlir
index ce6168b..f77cd08 100644
--- a/compiler/src/iree/compiler/Dialect/HAL/Transforms/test/convert_to_hal.mlir
+++ b/compiler/src/iree/compiler/Dialect/HAL/Transforms/test/convert_to_hal.mlir
@@ -4,18 +4,36 @@
 
 #executable_target_embedded_elf_aarch64 = #hal.executable.target<"llvm-cpu", "embedded-elf-aarch64">
 #executable_target_embedded_elf_x86_64 = #hal.executable.target<"llvm-cpu", "embedded-elf-x86_64">
-#pipeline_layout = #hal.pipeline.layout<push_constants = 0, sets = [
+
+// CHECK: #[[PIPELINE_LAYOUT_ATTR_0:.+]] = #hal.pipeline.layout
+#pipeline_layout_0 = #hal.pipeline.layout<push_constants = 0, sets = [
   #hal.descriptor_set.layout<0, bindings = [
+    // CHECK-SAME: <0, storage_buffer>
     #hal.descriptor_set.binding<0, storage_buffer>,
+    // CHECK-SAME: <1, storage_buffer>
     #hal.descriptor_set.binding<1, storage_buffer>,
+    // CHECK-SAME: <2, storage_buffer>
     #hal.descriptor_set.binding<2, storage_buffer>
   ]>
 ]>
+// CHECK: #[[PIPELINE_LAYOUT_ATTR_1:.+]] = #hal.pipeline.layout
+#pipeline_layout_1 = #hal.pipeline.layout<push_constants = 0, sets = [
+  #hal.descriptor_set.layout<0, bindings = [
+    // CHECK-SAME: <4, storage_buffer>
+    #hal.descriptor_set.binding<4, storage_buffer>
+  ]>,
+  #hal.descriptor_set.layout<1, bindings = [
+    // CHECK-SAME: <5, storage_buffer>
+    #hal.descriptor_set.binding<5, storage_buffer>,
+    // CHECK-SAME: <6, storage_buffer>
+    #hal.descriptor_set.binding<6, storage_buffer>
+  ]>
+]>
 
 // CHECK: hal.executable private @ex
 hal.executable private @ex {
   hal.executable.variant public @embedded_elf_aarch64 target(#executable_target_embedded_elf_aarch64) {
-    hal.executable.export public @dispatch ordinal(0) layout(#pipeline_layout) attributes {
+    hal.executable.export public @dispatch ordinal(0) layout(#pipeline_layout_0) attributes {
       translation_info = #iree_codegen.translation_info<CPUDefault>
     } {
     ^bb0(%device: !hal.device, %arg0: index, %arg1: index, %arg2: index):  // no predecessors
@@ -28,7 +46,14 @@
     }
   }
   hal.executable.variant public @embedded_elf_x86_64 target(#executable_target_embedded_elf_x86_64) {
-    hal.executable.export public @dispatch ordinal(0) layout(#pipeline_layout) attributes {
+    hal.executable.export public @dispatch ordinal(0) layout(#pipeline_layout_1) attributes {
+      // Override the bindings. The other variant uses the default ones.
+      // CHECK-NOT: hal.interface.bindings
+      hal.interface.bindings = [
+        #hal.interface.binding<0, 4>,
+        #hal.interface.binding<1, 5>,
+        #hal.interface.binding<1, 6>
+      ],
       translation_info = #iree_codegen.translation_info<CPUDefault>
     } {
     ^bb0(%device: !hal.device, %arg0: index, %arg1: index, %arg2: index):  // no predecessors
@@ -42,8 +67,8 @@
   }
 }
 
-// CHECK-LABEL: util.func public @simpleDispatch
-//  CHECK-SAME: (%[[ARG0:.+]]: !hal.buffer_view, %[[ARG1:.+]]: !hal.buffer_view) -> !hal.buffer_view
+// CHECK: util.func public @simpleDispatch
+// CHECK-SAME: (%[[ARG0:.+]]: !hal.buffer_view, %[[ARG1:.+]]: !hal.buffer_view) -> !hal.buffer_view
 util.func public @simpleDispatch(%arg0: !hal.buffer_view, %arg1: !hal.buffer_view) -> !hal.buffer_view attributes {iree.abi.stub} {
   %c1 = arith.constant 1 : index
   %c4 = arith.constant 4 : index
@@ -94,11 +119,11 @@
     // CHECK-DAG: %[[SWITCH0:.+]] = arith.select %[[FORMAT_AARCH64]], %c0, %[[SWITCH1]]
     // CHECK: scf.index_switch %[[SWITCH0]]
     // CHECK: case 0 {
-    // CHECK:   %[[PIPELINE_LAYOUT:.+]] = hal.pipeline_layout.lookup
+    // CHECK:   %[[PIPELINE_LAYOUT_0:.+]] = hal.pipeline_layout.lookup
     // CHECK-SAME: device(%[[DEVICE]] : !hal.device)
-    // CHECK-SAME: layout(#pipeline_layout) : !hal.pipeline_layout
+    // CHECK-SAME: layout(#[[PIPELINE_LAYOUT_ATTR_0]]) : !hal.pipeline_layout
     // CHECK:   hal.command_buffer.push_descriptor_set<%[[CMD]] : !hal.command_buffer>
-    // CHECK-SAME: layout(%[[PIPELINE_LAYOUT]] : !hal.pipeline_layout)[%c0]
+    // CHECK-SAME: layout(%[[PIPELINE_LAYOUT_0]] : !hal.pipeline_layout)[%c0]
     // CHECK-SAME: bindings([
     // CHECK:     %c0 = (%[[ARG0_BUFFER]] : !hal.buffer)[%c0, %c16],
     // CHECK:     %c1 = (%[[ARG1_BUFFER]] : !hal.buffer)[%c0, %c16],
@@ -112,6 +137,20 @@
     // CHECK:   scf.yield
     // CHECK: }
     // CHECK: case 1 {
+    // CHECK:   %[[PIPELINE_LAYOUT_1:.+]] = hal.pipeline_layout.lookup
+    // CHECK-SAME: device(%[[DEVICE]] : !hal.device)
+    // CHECK-SAME: layout(#[[PIPELINE_LAYOUT_ATTR_1]]) : !hal.pipeline_layout
+    // CHECK:   hal.command_buffer.push_descriptor_set<%[[CMD]] : !hal.command_buffer>
+    // CHECK-SAME: layout(%[[PIPELINE_LAYOUT_1]] : !hal.pipeline_layout)[%c0]
+    // CHECK-SAME: bindings([
+    // CHECK:     %c4 = (%[[ARG0_BUFFER]] : !hal.buffer)[%c0, %c16]
+    // CHECK:   ])
+    // CHECK:   hal.command_buffer.push_descriptor_set<%[[CMD]] : !hal.command_buffer>
+    // CHECK-SAME: layout(%[[PIPELINE_LAYOUT_1]] : !hal.pipeline_layout)[%c1]
+    // CHECK-SAME: bindings([
+    // CHECK:     %c5 = (%[[ARG1_BUFFER]] : !hal.buffer)[%c0, %c16],
+    // CHECK:     %c6 = (%[[RESULT_BUFFER]] : !hal.buffer)[%c0, %c16]
+    // CHECK:   ])
     // CHECK-DAG: %[[EXECUTABLE_1:.+]] = hal.executable.lookup device(%[[DEVICE]] : !hal.device) executable(@ex) : !hal.executable
     // CHECK-DAG: %[[ORDINAL_1:.+]] = hal.executable.export.ordinal target(@ex::@embedded_elf_x86_64::@dispatch) : index
     // CHECK:   hal.command_buffer.dispatch<%[[CMD]] : !hal.command_buffer>
@@ -125,12 +164,6 @@
       ro %arg0_capture[%c0 for %c16] : !stream.resource<external>{%c16},
       ro %arg1_capture[%c0 for %c16] : !stream.resource<external>{%c16},
       wo %result_capture[%c0 for %c16] : !stream.resource<external>{%c16}
-    } attributes {
-      hal.interface.bindings = [
-        #hal.interface.binding<0, 0>,
-        #hal.interface.binding<0, 1>,
-        #hal.interface.binding<0, 2>
-      ]
     }
 
   // CHECK: hal.command_buffer.execution_barrier<%[[CMD]] : !hal.command_buffer>
diff --git a/compiler/src/iree/compiler/Dialect/HAL/Transforms/test/dump_executable_benchmarks.mlir b/compiler/src/iree/compiler/Dialect/HAL/Transforms/test/dump_executable_benchmarks.mlir
index a7f276c..98bcc79 100644
--- a/compiler/src/iree/compiler/Dialect/HAL/Transforms/test/dump_executable_benchmarks.mlir
+++ b/compiler/src/iree/compiler/Dialect/HAL/Transforms/test/dump_executable_benchmarks.mlir
@@ -134,46 +134,29 @@
         ro %result_capture[%c0 for %c32] : !stream.resource<transient>{%c128},
         rw %result_capture[%c32 for %c32] : !stream.resource<transient>{%c128},
         rw %result_capture[%c64 for %c32] : !stream.resource<transient>{%c128}
-      } attributes {hal.interface.bindings = [
-        #hal.interface.binding<0, 0>,
-        #hal.interface.binding<0, 1>,
-        #hal.interface.binding<0, 2>
-      ]}
+      }
       // NOTE: today the dynamic args will prevent us from generating
       // benchmarks. We could handle this better by tracking alignment and such.
       stream.cmd.dispatch @ex0::@embedded_elf_x86_64::@dispatch0[%c512](%c300_i32, %dynamic_arg : i32, i32) {
         ro %result_capture[%c0 for %c32] : !stream.resource<transient>{%c128},
         rw %result_capture[%c32 for %c32] : !stream.resource<transient>{%c128},
         rw %result_capture[%c64 for %c32] : !stream.resource<transient>{%c128}
-      } attributes {hal.interface.bindings = [
-        #hal.interface.binding<0, 0>,
-        #hal.interface.binding<0, 1>,
-        #hal.interface.binding<0, 2>
-      ]}
+      }
 
       // Multiple dispatches to a single entry point.
       // Dispatches are deduplicated and the two 128x32x1 should combine.
       stream.cmd.dispatch @ex0::@embedded_elf_x86_64::@dispatch1[%c512, %c1] {
         ro %result_capture[%c0 for %c64] : !stream.resource<transient>{%c128},
         rw %result_capture[%c64 for %c32] : !stream.resource<transient>{%c128}
-      } attributes {hal.interface.bindings = [
-        #hal.interface.binding<0, 0>,
-        #hal.interface.binding<0, 1>
-      ]}
+      }
       stream.cmd.dispatch @ex0::@embedded_elf_x86_64::@dispatch1[%c128, %c32] {
         ro %result_capture[%c0 for %c64] : !stream.resource<transient>{%c128},
         rw %result_capture[%c64 for %c32] : !stream.resource<transient>{%c128}
-      } attributes {hal.interface.bindings = [
-        #hal.interface.binding<0, 0>,
-        #hal.interface.binding<0, 1>
-      ]}
+      }
       stream.cmd.dispatch @ex0::@embedded_elf_x86_64::@dispatch1[%c128, %c32] {
         ro %result_capture[%c0 for %c64] : !stream.resource<transient>{%c128},
         rw %result_capture[%c64 for %c32] : !stream.resource<transient>{%c128}
-      } attributes {hal.interface.bindings = [
-        #hal.interface.binding<0, 0>,
-        #hal.interface.binding<0, 1>
-      ]}
+      }
     } => !stream.timepoint
     %39 = stream.resource.dealloca await(%6) => %result : !stream.resource<transient>{%c128} => !stream.timepoint
     util.return %39 : !stream.timepoint
diff --git a/compiler/src/iree/compiler/Dialect/HAL/Transforms/test/materialize_interfaces.mlir b/compiler/src/iree/compiler/Dialect/HAL/Transforms/test/materialize_interfaces.mlir
index 7697d0d..db45b5e 100644
--- a/compiler/src/iree/compiler/Dialect/HAL/Transforms/test/materialize_interfaces.mlir
+++ b/compiler/src/iree/compiler/Dialect/HAL/Transforms/test/materialize_interfaces.mlir
@@ -21,7 +21,8 @@
 
   // CHECK: hal.executable private @ex_workgroups
   // CHECK:   hal.executable.variant public @embedded_elf_arm_64 target(#executable_target_embedded_elf_arm_64
-  // CHECK:     hal.executable.export public @entry ordinal(0) layout(#pipeline_layout) {
+  // CHECK:     hal.executable.export public @entry ordinal(0) layout(#pipeline_layout)
+  // CHECK-SAME:   hal.interface.bindings = [#hal.interface.binding<0, 0>, #hal.interface.binding<0, 1>, #hal.interface.binding<0, 2>]
   // CHECK-NEXT: ^bb0(%[[DEVICE:.+]]: !hal.device, %[[ARG0:.+]]: index, %[[ARG1:.+]]: index):
   // CHECK-NEXT:   hal.return %[[ARG0]], %[[ARG1]], %[[ARG0]] : index, index, index
   // CHECK-NEXT: }
@@ -29,7 +30,8 @@
   // CHECK-NEXT:  func.func private @extern_func()
   // CHECK-NEXT:  func.func @entry
   // CHECK:   hal.executable.variant public @embedded_elf_x86_64 target(#executable_target_embedded_elf_x86_64
-  // CHECK:     hal.executable.export public @entry ordinal(0) layout(#pipeline_layout) {
+  // CHECK:     hal.executable.export public @entry ordinal(0) layout(#pipeline_layout)
+  // CHECK-SAME:   hal.interface.bindings = [#hal.interface.binding<0, 0>, #hal.interface.binding<0, 1>, #hal.interface.binding<0, 2>]
   // CHECK-NEXT: ^bb0(%[[DEVICE:.+]]: !hal.device, %[[ARG0:.+]]: index, %[[ARG1:.+]]: index):
   // CHECK-NEXT:   hal.return %[[ARG0]], %[[ARG1]], %[[ARG0]] : index, index, index
   // CHECK-NEXT: }
@@ -55,11 +57,6 @@
     %0 = stream.resource.alloc uninitialized : !stream.resource<transient>{%arg2}
     %1 = stream.cmd.execute with(%arg0 as %arg4: !stream.resource<constant>{%arg2}, %arg1 as %arg5: !stream.resource<transient>{%arg2}, %0 as %arg6: !stream.resource<transient>{%arg2}) {
       // CHECK: stream.cmd.dispatch {@ex_workgroups::@embedded_elf_arm_64::@entry, @ex_workgroups::@embedded_elf_x86_64::@entry}
-      // CHECK: attributes {
-      // CHECK-SAME: hal.interface.bindings = [
-      // CHECK-SAME:   #hal.interface.binding<0, 0>,
-      // CHECK-SAME:   #hal.interface.binding<0, 1>,
-      // CHECK-SAME:   #hal.interface.binding<0, 2>
       stream.cmd.dispatch @ex_workgroups::@entry[%c1, %c2](%arg3 : i32) {
         ro %arg4[%c0 for %arg2] : !stream.resource<constant>{%arg2},
         ro %arg5[%c0 for %arg2] : !stream.resource<transient>{%arg2},
diff --git a/compiler/src/iree/compiler/Modules/HAL/Inline/Transforms/test/inline_executables.mlir b/compiler/src/iree/compiler/Modules/HAL/Inline/Transforms/test/inline_executables.mlir
index f88d821..1ee7e25 100644
--- a/compiler/src/iree/compiler/Modules/HAL/Inline/Transforms/test/inline_executables.mlir
+++ b/compiler/src/iree/compiler/Modules/HAL/Inline/Transforms/test/inline_executables.mlir
@@ -178,12 +178,6 @@
       ro %resource0_inner[%binding0_offset for %binding0_length] : !stream.resource<constant>{%resource_size},
       ro %resource1_inner[%binding1_offset for %binding1_length] : !stream.resource<transient>{%resource_size},
       wo %resource2_inner[%binding2_offset for %binding2_length] : !stream.resource<external>{%resource_size}
-    } attributes {
-      hal.interface.bindings = [
-        #hal.interface.binding<0, 0>,
-        #hal.interface.binding<0, 1>,
-        #hal.interface.binding<0, 2>
-      ]
     }
   } => !stream.timepoint
   util.return
diff --git a/compiler/src/iree/compiler/Modules/HAL/Loader/Conversion/StreamToHALLoader/test/cmd_ops.mlir b/compiler/src/iree/compiler/Modules/HAL/Loader/Conversion/StreamToHALLoader/test/cmd_ops.mlir
index 75625e8..a767492 100644
--- a/compiler/src/iree/compiler/Modules/HAL/Loader/Conversion/StreamToHALLoader/test/cmd_ops.mlir
+++ b/compiler/src/iree/compiler/Modules/HAL/Loader/Conversion/StreamToHALLoader/test/cmd_ops.mlir
@@ -80,11 +80,6 @@
     stream.cmd.dispatch @ex::@dispatch[%workload_x, %workload_y](%constant0, %constant1 : i32, i32) {
       ro %buffer0_inner[%buffer0_offset for %buffer0_length] : !stream.resource<transient>{%buffer0_size},
       wo %buffer1_inner[%buffer1_offset for %buffer1_length] : !stream.resource<external>{%buffer1_size}
-    } attributes {
-      hal.interface.bindings = [
-        #hal.interface.binding<0, 4>,
-        #hal.interface.binding<1, 5>
-      ]
     }
   } => !stream.timepoint
   // CHECK: return %c0
diff --git a/samples/custom_dispatch/cpu/embedded/example_hal.mlir b/samples/custom_dispatch/cpu/embedded/example_hal.mlir
index 7bddf07..d94d34c 100644
--- a/samples/custom_dispatch/cpu/embedded/example_hal.mlir
+++ b/samples/custom_dispatch/cpu/embedded/example_hal.mlir
@@ -90,7 +90,19 @@
                 <1, storage_buffer, ReadOnly>,
                 <2, storage_buffer>
             ]>
-          ]>) {
+          ]>) attributes {
+            // Bindings are automatically inferred when possible as part of the
+            // ABI but can be overridden if the user wants to use features such
+            // as sparse bindings or multiple descriptor sets. To do so the
+            // `hal.interface.bindings` attribute can be added to an export op
+            // as follows mapping tensor operands/results to the pipeline layout
+            // sets/bindings:
+            hal.interface.bindings = [
+              #hal.interface.binding<0, 0>,
+              #hal.interface.binding<0, 1>,
+              #hal.interface.binding<0, 2>
+            ]
+          } {
       ^bb0(%device: !hal.device, %workload: index):
         // This host function is used to compute the XYZ workgroup count
         // dispatched at runtime. It can query the %device for capabilities
@@ -245,17 +257,6 @@
 
     // Dispatch a basic `ret = lhs * rhs` using an external function.
     %0 = flow.dispatch @executable::@x86_64::@simple_mul[%dim](%dim_i32, %arg0, %arg1) {
-      // Bindings are automatically inferred when possible as part of the ABI
-      // but can be overridden if the user wants to use features such as sparse
-      // bindings or multiple descriptor sets. To do so the
-      // `hal.interface.bindings` attribute can be added to a dispatch op as
-      // follows mapping tensor operands/results to the pipeline layout
-      // sets/bindings:
-      hal.interface.bindings = [
-        #hal.interface.binding<0, 0>,
-        #hal.interface.binding<0, 1>,
-        #hal.interface.binding<0, 2>
-      ],
       // HACK: keep the executable live through DCE. Only required when
       // using the automatic variant selection.
       // TODO(benvanik): automatically add this when required.
diff --git a/samples/custom_dispatch/cpu/embedded/example_transform_spec.mlir b/samples/custom_dispatch/cpu/embedded/example_transform_spec.mlir
index c709e20..b19e6d3 100644
--- a/samples/custom_dispatch/cpu/embedded/example_transform_spec.mlir
+++ b/samples/custom_dispatch/cpu/embedded/example_transform_spec.mlir
@@ -74,13 +74,7 @@
     %dim_i32 = arith.index_cast %dim : index to i32
 
     // Dispatch a basic `ret = -|lhs * rhs|` using an external function.
-    %0 = flow.dispatch @executable::@x86_64::@simple_mul_abs_negate[%dim](%dim_i32, %arg0, %arg1) {
-      hal.interface.bindings = [
-        #hal.interface.binding<0, 0>,
-        #hal.interface.binding<0, 1>,
-        #hal.interface.binding<0, 2>
-      ]
-    } : (i32, tensor<?xf32>{%dim}, tensor<?xf32>{%dim}) -> tensor<?xf32>{%dim}
+    %0 = flow.dispatch @executable::@x86_64::@simple_mul_abs_negate[%dim](%dim_i32, %arg0, %arg1) : (i32, tensor<?xf32>{%dim}, tensor<?xf32>{%dim}) -> tensor<?xf32>{%dim}
 
     util.return %0 : tensor<?xf32>
   }
diff --git a/samples/custom_dispatch/cpu/mlp_plugin/mlp_spec.mlir b/samples/custom_dispatch/cpu/mlp_plugin/mlp_spec.mlir
index eec83f7..58f3f15 100644
--- a/samples/custom_dispatch/cpu/mlp_plugin/mlp_spec.mlir
+++ b/samples/custom_dispatch/cpu/mlp_plugin/mlp_spec.mlir
@@ -61,13 +61,7 @@
     %n_i32 = arith.index_cast %n : index to i32
     %k_i32 = arith.index_cast %k : index to i32
 
-    %mlp_result = flow.dispatch @executable::@x86_64::@mlp(%lhs, %rhs, %m_i32, %n_i32, %k_i32) {
-      hal.interface.bindings = [
-        #hal.interface.binding<0, 0>,
-        #hal.interface.binding<0, 1>,
-        #hal.interface.binding<0, 2>
-      ]
-    } : (tensor<?x?xf32>{%m, %k}, tensor<?x?xf32>{%k, %n}, i32, i32, i32) -> tensor<?x?xf32>{%m, %n}
+    %mlp_result = flow.dispatch @executable::@x86_64::@mlp(%lhs, %rhs, %m_i32, %n_i32, %k_i32) : (tensor<?x?xf32>{%m, %k}, tensor<?x?xf32>{%k, %n}, i32, i32, i32) -> tensor<?x?xf32>{%m, %n}
 
     util.return %mlp_result : tensor<?x?xf32>
   }
diff --git a/samples/custom_dispatch/cuda/kernels/example.mlir b/samples/custom_dispatch/cuda/kernels/example.mlir
index 1438b20..1c64964 100644
--- a/samples/custom_dispatch/cuda/kernels/example.mlir
+++ b/samples/custom_dispatch/cuda/kernels/example.mlir
@@ -86,7 +86,18 @@
         ]>) attributes {
       // Certain backends (like CUDA) require a workgroup size (aka block
       // size) to be defined ahead of time.
-      workgroup_size = [64 : index, 1 : index, 1 : index]
+      workgroup_size = [64 : index, 1 : index, 1 : index],
+      // Bindings are automatically inferred when possible as part of the ABI
+      // but can be overridden if the user wants to use features such as sparse
+      // bindings or multiple descriptor sets. To do so the
+      // `hal.interface.bindings` attribute can be added to an export op as
+      // follows mapping tensor operands/results to the pipeline layout
+      // sets/bindings:
+      hal.interface.bindings = [
+        #hal.interface.binding<0, 0>,
+        #hal.interface.binding<0, 1>,
+        #hal.interface.binding<0, 2>
+      ]
     } {
     ^bb0(%device: !hal.device, %workload: index):
       // This host function is used to compute the XYZ workgroup count
@@ -137,19 +148,7 @@
     %dim_i32 = arith.index_cast %dim : index to i32
 
     // Dispatch a basic `ret = lhs * rhs` kernel.
-    %0 = flow.dispatch @executable::@simple_mul[%dim](%dim_i32, %arg0, %arg1) {
-      // Bindings are automatically inferred when possible as part of the ABI
-      // but can be overridden if the user wants to use features such as sparse
-      // bindings or multiple descriptor sets. To do so the
-      // `hal.interface.bindings` attribute can be added to a dispatch op as
-      // follows mapping tensor operands/results to the pipeline layout
-      // sets/bindings:
-      hal.interface.bindings = [
-        #hal.interface.binding<0, 0>,
-        #hal.interface.binding<0, 1>,
-        #hal.interface.binding<0, 2>
-      ]
-    } : (i32, tensor<?xf32>{%dim}, tensor<?xf32>{%dim}) -> tensor<?xf32>{%dim}
+    %0 = flow.dispatch @executable::@simple_mul[%dim](%dim_i32, %arg0, %arg1) : (i32, tensor<?xf32>{%dim}, tensor<?xf32>{%dim}) -> tensor<?xf32>{%dim}
 
     // Code gen some other ops - these will interleave with the hand-authored
     // ones but naturally won't be able to fuse with them.
diff --git a/samples/custom_dispatch/vulkan/shaders/example.mlir b/samples/custom_dispatch/vulkan/shaders/example.mlir
index 0aa2b31..4dd479d 100644
--- a/samples/custom_dispatch/vulkan/shaders/example.mlir
+++ b/samples/custom_dispatch/vulkan/shaders/example.mlir
@@ -78,7 +78,19 @@
               <1, storage_buffer, ReadOnly>,
               <2, storage_buffer>
           ]>
-        ]>) {
+        ]>) attributes {
+          // Bindings are automatically inferred when possible as part of the
+          // ABI but can be overridden if the user wants to use features such as
+          // sparse bindings or multiple descriptor sets. To do so the
+          // `hal.interface.bindings` attribute can be added to an export op as
+          // follows mapping tensor operands/results to the pipeline layout
+          // sets/bindings:
+          hal.interface.bindings = [
+            #hal.interface.binding<0, 0>,
+            #hal.interface.binding<0, 1>,
+            #hal.interface.binding<0, 2>
+          ]
+        } {
     ^bb0(%device: !hal.device, %workload: index):
       // This host function is used to compute the XYZ workgroup count
       // dispatched at runtime. It can query the %device for capabilities
@@ -136,29 +148,13 @@
     %dim_i32 = arith.index_cast %dim : index to i32
 
     // Dispatch a basic `ret = lhs * rhs` shader.
-    %0 = flow.dispatch @simple_mul::@main[%dim](%dim_i32, %arg0, %arg1) {
-      // Bindings are automatically inferred when possible as part of the ABI
-      // but can be overridden if the user wants to use features such as sparse
-      // bindings or multiple descriptor sets. To do so the
-      // `hal.interface.bindings` attribute can be added to a dispatch op as
-      // follows mapping tensor operands/results to the pipeline layout
-      // sets/bindings:
-      hal.interface.bindings = [
-        #hal.interface.binding<0, 0>,
-        #hal.interface.binding<0, 1>,
-        #hal.interface.binding<0, 2>
-      ]
-    } : (i32, tensor<?xf32>{%dim}, tensor<?xf32>{%dim}) -> tensor<?xf32>{%dim}
+    %0 = flow.dispatch @simple_mul::@main[%dim](%dim_i32, %arg0, %arg1) : (i32, tensor<?xf32>{%dim}, tensor<?xf32>{%dim}) -> tensor<?xf32>{%dim}
 
     // Code gen some other ops - these will interleave with the hand-authored
     // ones but naturally won't be able to fuse with them.
     %1 = arith.addf %0, %arg1 : tensor<?xf32>
 
     // Dispatch an in-place `rhs *= lhs` shader.
-    //
-    // Note that we don't declare the hal.interface.bindings and let them be
-    // inferred - this only works when either specifying the variant that has
-    // a pipeline layout defined or all variants have the same pipeline layouts.
     %2 = flow.dispatch @simple_mul_inplace::@main[%dim](%dim_i32, %0, %1) : (i32, tensor<?xf32>{%dim}, tensor<?xf32>{%dim}) -> %1{%dim}
 
     // CHECK: 8xf32=96 96 96 96 96 96 96 96