Adding condition support to `hal.dispatch.extern`. (#15797)

Required switching from `hal.dispatch.extern` outlining into a
`flow.executable` to instead outlining into a `hal.executable` with
variants. This is nice as it prevents us from needing to carry through
all the information a `hal.executable` could have through flow/stream
but does mean we need to support `flow.dispatch`/`stream.async.dispatch`
on expanded executables with variants that may have multiple export
symbols. In the common case (codegen) there's no change.

It's not great this is happening in flow - future work may try to move
outlining of the extern ops to the end of the global optimization
pipeline (assuming that happens _after_ plugins that may insert the
extern ops). There's a few refactorings in other passes we'd need to
make that happen.
diff --git a/compiler/src/iree/compiler/Dialect/Flow/IR/FlowOps.td b/compiler/src/iree/compiler/Dialect/Flow/IR/FlowOps.td
index 36eb30a..f0a1565 100644
--- a/compiler/src/iree/compiler/Dialect/Flow/IR/FlowOps.td
+++ b/compiler/src/iree/compiler/Dialect/Flow/IR/FlowOps.td
@@ -787,7 +787,6 @@
                                $tied_operands)
   }];
 
-  let skipDefaultBuilders = 1;
   let builders = [
     OpBuilder<(ins
       "ExecutableExportOp":$exportOp, "ValueRange":$workload,
diff --git a/compiler/src/iree/compiler/Dialect/Flow/Transforms/BUILD.bazel b/compiler/src/iree/compiler/Dialect/Flow/Transforms/BUILD.bazel
index 5a2d2c9..f626437 100644
--- a/compiler/src/iree/compiler/Dialect/Flow/Transforms/BUILD.bazel
+++ b/compiler/src/iree/compiler/Dialect/Flow/Transforms/BUILD.bazel
@@ -52,6 +52,7 @@
         "InsertDispatchDebugTargets.cpp",
         "InterchangeGenericOps.cpp",
         "InterchangeTransposeGenericOps.cpp",
+        "OutlineDispatchExterns.cpp",
         "OutlineDispatchRegions.cpp",
         "PassDetail.h",
         "Passes.cpp",
diff --git a/compiler/src/iree/compiler/Dialect/Flow/Transforms/CMakeLists.txt b/compiler/src/iree/compiler/Dialect/Flow/Transforms/CMakeLists.txt
index 67eea6f..b0ba163 100644
--- a/compiler/src/iree/compiler/Dialect/Flow/Transforms/CMakeLists.txt
+++ b/compiler/src/iree/compiler/Dialect/Flow/Transforms/CMakeLists.txt
@@ -51,6 +51,7 @@
     "InsertDispatchDebugTargets.cpp"
     "InterchangeGenericOps.cpp"
     "InterchangeTransposeGenericOps.cpp"
+    "OutlineDispatchExterns.cpp"
     "OutlineDispatchRegions.cpp"
     "PassDetail.h"
     "Passes.cpp"
diff --git a/compiler/src/iree/compiler/Dialect/Flow/Transforms/OutlineDispatchExterns.cpp b/compiler/src/iree/compiler/Dialect/Flow/Transforms/OutlineDispatchExterns.cpp
new file mode 100644
index 0000000..4389863
--- /dev/null
+++ b/compiler/src/iree/compiler/Dialect/Flow/Transforms/OutlineDispatchExterns.cpp
@@ -0,0 +1,177 @@
+// Copyright 2023 The IREE Authors
+//
+// Licensed under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+
+#include <utility>
+
+#include "iree/compiler/Dialect/Flow/IR/FlowOps.h"
+#include "iree/compiler/Dialect/Flow/Transforms/PassDetail.h"
+#include "iree/compiler/Dialect/Flow/Transforms/Passes.h"
+#include "iree/compiler/Dialect/HAL/IR/HALDialect.h"
+#include "iree/compiler/Dialect/HAL/IR/HALOps.h"
+#include "iree/compiler/Dialect/Util/IR/UtilOps.h"
+#include "llvm/Support/Debug.h"
+#include "mlir/Dialect/Func/IR/FuncOps.h"
+#include "mlir/IR/Builders.h"
+#include "mlir/IR/BuiltinTypes.h"
+#include "mlir/IR/Diagnostics.h"
+#include "mlir/IR/IRMapping.h"
+#include "mlir/IR/SymbolTable.h"
+#include "mlir/Pass/Pass.h"
+
+namespace mlir {
+namespace iree_compiler {
+namespace IREE {
+namespace Flow {
+namespace {
+
+//===----------------------------------------------------------------------===//
+// hal.dispatch.extern
+//===----------------------------------------------------------------------===//
+
+// Converts a dispatch region op into a dispatch op to the outlined region.
+static LogicalResult
+convertDispatchExternToDispatchOp(IREE::HAL::DispatchExternOp dispatchExternOp,
+                                  ArrayRef<Attribute> exportRefs) {
+  // Insert at the same place as the original region.
+  OpBuilder builder(dispatchExternOp);
+
+  // Create the dispatch op to the executable function.
+  // Note that we copy the tied operand indices from the workgroups op - it
+  // lines up 1:1 with the dispatch once we've outlined things.
+  auto dispatchOp = builder.create<IREE::Flow::DispatchOp>(
+      dispatchExternOp.getLoc(), dispatchExternOp.getResultTypes(),
+      dispatchExternOp.getWorkload(), builder.getArrayAttr(exportRefs),
+      dispatchExternOp.getArguments(), dispatchExternOp.getArgumentDims(),
+      dispatchExternOp.getResultDims(), dispatchExternOp.getTiedOperandsAttr());
+  dispatchOp->setDialectAttrs(dispatchExternOp->getDialectAttrs());
+  if (auto bindingsAttr = dispatchExternOp.getBindingsAttr()) {
+    dispatchOp->setAttr("hal.interface.bindings", bindingsAttr);
+  }
+
+  // Replace uses of the existing results with the new results.
+  for (int i = 0; i < dispatchExternOp.getNumResults(); ++i) {
+    dispatchExternOp.getResult(i).replaceAllUsesWith(dispatchOp.getResult(i));
+  }
+
+  return success();
+}
+
+// Outlines a dispatch region into a flow.executable and replaces the region op
+// with a dispatch to that outlined executable.
+static LogicalResult
+outlineDispatchExternOp(std::string name,
+                        IREE::HAL::DispatchExternOp dispatchExternOp) {
+  // Create the executable that will contain the outlined region.
+  // NOTE: this will get uniquified if we have multiple in the same block.
+  auto parentFuncOp = dispatchExternOp->getParentOfType<FunctionOpInterface>();
+  auto parentModuleOp = parentFuncOp->getParentOfType<mlir::ModuleOp>();
+  OpBuilder parentModuleBuilder(&parentModuleOp.getBody()->back());
+  auto executableOp = parentModuleBuilder.create<IREE::HAL::ExecutableOp>(
+      dispatchExternOp.getLoc(), name);
+  executableOp.getOperation()->moveBefore(parentFuncOp);
+  executableOp.setPrivate();
+
+  // Add one variant per object target.
+  SymbolTable executableSymbolTable(executableOp);
+  OpBuilder executableBuilder(executableOp.getBody());
+  SmallVector<Attribute> exportRefs;
+  for (auto [targetAttr, targetOrdinalAttr, targetObjectsAttr,
+             targetConditionRegion] :
+       llvm::zip_equal(
+           dispatchExternOp.getTargetsAttr()
+               .getAsRange<IREE::HAL::ExecutableTargetAttr>(),
+           dispatchExternOp.getTargetOrdinalsAttr().getAsRange<IntegerAttr>(),
+           dispatchExternOp.getTargetObjectsAttr().getAsRange<ArrayAttr>(),
+           dispatchExternOp.getTargetRegions())) {
+    // Create the variant for the given target. Note that we may have multiple
+    // variants that use the same base targetAttr but have unique condition
+    // regions so we rely on the symbol table for uniquing names.
+    auto variantOp = executableBuilder.create<IREE::HAL::ExecutableVariantOp>(
+        dispatchExternOp.getLoc(), targetAttr.getSymbolNameFragment(),
+        targetAttr);
+    variantOp.setObjectsAttr(targetObjectsAttr);
+    executableSymbolTable.insert(variantOp);
+
+    // Move over optional target condition region to a condition op.
+    OpBuilder variantBuilder(variantOp.getBody());
+    if (!targetConditionRegion.empty()) {
+      auto conditionOp =
+          variantBuilder.create<IREE::HAL::ExecutableConditionOp>(
+              dispatchExternOp.getLoc());
+      IRMapping mapper;
+      targetConditionRegion.cloneInto(&conditionOp.getBody(), mapper);
+    }
+
+    // Add an export pointing at the entry point function.
+    auto exportOp = variantBuilder.create<IREE::HAL::ExecutableExportOp>(
+        dispatchExternOp.getLoc(), dispatchExternOp.getExportAttr(),
+        targetOrdinalAttr, dispatchExternOp.getLayoutAttr(),
+        dispatchExternOp.getWorkgroupSizeAttr(),
+        dispatchExternOp.getSubgroupSizeAttr(),
+        dispatchExternOp.getWorkgroupLocalMemoryAttr());
+    exportOp->setDialectAttrs(dispatchExternOp->getDialectAttrs());
+    if (!dispatchExternOp.getWorkgroupCount().empty()) {
+      IRMapping mapper;
+      dispatchExternOp.getWorkgroupCount().cloneInto(
+          &exportOp.getWorkgroupCount(), mapper);
+    }
+
+    exportRefs.push_back(
+        SymbolRefAttr::get(executableOp.getNameAttr(),
+                           {FlatSymbolRefAttr::get(variantOp.getNameAttr()),
+                            FlatSymbolRefAttr::get(exportOp.getNameAttr())}));
+  }
+
+  // Finally convert the dispatch region into a dispatch to the external
+  // exports.
+  return convertDispatchExternToDispatchOp(dispatchExternOp, exportRefs);
+}
+
+} // namespace
+
+class OutlineDispatchExternsPass
+    : public OutlineDispatchExternsBase<OutlineDispatchExternsPass> {
+public:
+  OutlineDispatchExternsPass() = default;
+  void getDependentDialects(DialectRegistry &registry) const override {
+    registry.insert<IREE::Flow::FlowDialect>();
+    registry.insert<IREE::HAL::HALDialect>();
+  }
+
+  void runOnOperation() override {
+    for (auto funcOp : getOperation().getOps<FunctionOpInterface>()) {
+      // Outline all of the dispatch externs ops in this function.
+      SmallVector<Operation *> deadOps;
+      auto outlineOps = [&](Operation *op) {
+        return TypeSwitch<Operation *, WalkResult>(op)
+            .Case<IREE::HAL::DispatchExternOp>([&](auto dispatchExternOp) {
+              if (failed(outlineDispatchExternOp(
+                      ("extern_dispatch_" + llvm::Twine(deadOps.size())).str(),
+                      dispatchExternOp))) {
+                return WalkResult::interrupt();
+              }
+              deadOps.push_back(op);
+              return WalkResult::advance();
+            })
+            .Default(WalkResult::advance());
+      };
+      if (funcOp.walk(outlineOps).wasInterrupted())
+        return signalPassFailure();
+      for (auto *deadOp : deadOps)
+        deadOp->erase();
+    }
+  }
+};
+
+std::unique_ptr<OperationPass<mlir::ModuleOp>>
+createOutlineDispatchExternsPass() {
+  return std::make_unique<OutlineDispatchExternsPass>();
+}
+
+} // namespace Flow
+} // namespace IREE
+} // namespace iree_compiler
+} // namespace mlir
diff --git a/compiler/src/iree/compiler/Dialect/Flow/Transforms/OutlineDispatchRegions.cpp b/compiler/src/iree/compiler/Dialect/Flow/Transforms/OutlineDispatchRegions.cpp
index c8d097a..54de80e 100644
--- a/compiler/src/iree/compiler/Dialect/Flow/Transforms/OutlineDispatchRegions.cpp
+++ b/compiler/src/iree/compiler/Dialect/Flow/Transforms/OutlineDispatchRegions.cpp
@@ -9,7 +9,6 @@
 #include "iree/compiler/Dialect/Flow/IR/FlowOps.h"
 #include "iree/compiler/Dialect/Flow/Transforms/PassDetail.h"
 #include "iree/compiler/Dialect/Flow/Transforms/Passes.h"
-#include "iree/compiler/Dialect/HAL/IR/HALOps.h"
 #include "iree/compiler/Dialect/Util/IR/UtilOps.h"
 #include "llvm/Support/Debug.h"
 #include "mlir/Dialect/Func/IR/FuncOps.h"
@@ -155,75 +154,6 @@
                                                executableOp, exportOp);
 }
 
-//===----------------------------------------------------------------------===//
-// hal.dispatch.extern
-//===----------------------------------------------------------------------===//
-
-// Converts a dispatch region op into a dispatch op to the outlined region.
-static LogicalResult
-convertDispatchExternToDispatchOp(IREE::HAL::DispatchExternOp dispatchExternOp,
-                                  IREE::Flow::ExecutableOp executableOp,
-                                  IREE::Flow::ExecutableExportOp exportOp) {
-  // Insert at the same place as the original region.
-  OpBuilder builder(dispatchExternOp);
-
-  // Create the dispatch op to the executable function.
-  // Note that we copy the tied operand indices from the workgroups op - it
-  // lines up 1:1 with the dispatch once we've outlined things.
-  auto dispatchOp = builder.create<IREE::Flow::DispatchOp>(
-      dispatchExternOp.getLoc(), exportOp, dispatchExternOp.getWorkload(),
-      dispatchExternOp.getResultTypes(), dispatchExternOp.getResultDims(),
-      dispatchExternOp.getArguments(), dispatchExternOp.getArgumentDims(),
-      dispatchExternOp.getTiedOperandsAttr());
-  dispatchOp->setDialectAttrs(dispatchExternOp->getDialectAttrs());
-  if (auto bindingsAttr = dispatchExternOp.getBindingsAttr()) {
-    dispatchOp->setAttr("hal.interface.bindings", bindingsAttr);
-  }
-
-  // Replace uses of the existing results with the new results.
-  for (int i = 0; i < dispatchExternOp.getNumResults(); ++i) {
-    dispatchExternOp.getResult(i).replaceAllUsesWith(dispatchOp.getResult(i));
-  }
-
-  return success();
-}
-
-// Outlines a dispatch region into a flow.executable and replaces the region op
-// with a dispatch to that outlined executable.
-static LogicalResult
-outlineDispatchExternOp(std::string name,
-                        IREE::HAL::DispatchExternOp dispatchExternOp) {
-  // Create the executable that will contain the outlined region.
-  // NOTE: this will get uniquified if we have multiple in the same block.
-  auto parentFuncOp = dispatchExternOp->getParentOfType<FunctionOpInterface>();
-  auto parentModuleOp = parentFuncOp->getParentOfType<mlir::ModuleOp>();
-  OpBuilder parentModuleBuilder(&parentModuleOp.getBody()->back());
-  auto executableOp = parentModuleBuilder.create<IREE::Flow::ExecutableOp>(
-      dispatchExternOp.getLoc(), name);
-  executableOp.getOperation()->moveBefore(parentFuncOp);
-  executableOp.setPrivate();
-  executableOp->setAttr("hal.executable.objects",
-                        dispatchExternOp.getObjectsAttr());
-
-  // Add an export pointing at the entry point function.
-  OpBuilder builder(executableOp.getBody());
-  auto exportOp = builder.create<IREE::Flow::ExecutableExportOp>(
-      dispatchExternOp.getLoc(), dispatchExternOp.getExport(),
-      FlatSymbolRefAttr::get(builder.getContext(),
-                             dispatchExternOp.getExport()));
-  exportOp->setDialectAttrs(dispatchExternOp->getDialectAttrs());
-  exportOp->setAttr("hal.interface.layout", dispatchExternOp.getLayoutAttr());
-
-  // Move over the workgroup count region, if present.
-  if (!dispatchExternOp.getWorkgroupCount().empty()) {
-    exportOp.getWorkgroupCount().takeBody(dispatchExternOp.getWorkgroupCount());
-  }
-
-  // Finally convert the dispatch region into a dispatch to the outlined func.
-  return convertDispatchExternToDispatchOp(dispatchExternOp, executableOp,
-                                           exportOp);
-}
-
 } // namespace
 
 class OutlineDispatchRegionsPass
@@ -269,16 +199,6 @@
                   deadOps.push_back(op);
                   return WalkResult::advance();
                 })
-            .Case<IREE::HAL::DispatchExternOp>([&](auto dispatchExternOp) {
-              if (failed(outlineDispatchExternOp(
-                      (namePrefix + "_dispatch_" + llvm::Twine(deadOps.size()))
-                          .str(),
-                      dispatchExternOp))) {
-                return WalkResult::interrupt();
-              }
-              deadOps.push_back(op);
-              return WalkResult::advance();
-            })
             .Default(WalkResult::advance());
       };
       if (funcOp.walk(outlineOps).wasInterrupted())
diff --git a/compiler/src/iree/compiler/Dialect/Flow/Transforms/Passes.cpp b/compiler/src/iree/compiler/Dialect/Flow/Transforms/Passes.cpp
index b5d3814..1ff1d52 100644
--- a/compiler/src/iree/compiler/Dialect/Flow/Transforms/Passes.cpp
+++ b/compiler/src/iree/compiler/Dialect/Flow/Transforms/Passes.cpp
@@ -199,6 +199,7 @@
 
   // Module pass to outline dispatch regions (and similar ops) into their own
   // functions wrapped in executables.
+  passManager.addPass(IREE::Flow::createOutlineDispatchExternsPass());
   passManager.addPass(IREE::Flow::createOutlineDispatchRegionsPass());
 
   // Annotate executables based on their contents.
diff --git a/compiler/src/iree/compiler/Dialect/Flow/Transforms/Passes.h b/compiler/src/iree/compiler/Dialect/Flow/Transforms/Passes.h
index 3c09524..ca7f569 100644
--- a/compiler/src/iree/compiler/Dialect/Flow/Transforms/Passes.h
+++ b/compiler/src/iree/compiler/Dialect/Flow/Transforms/Passes.h
@@ -153,6 +153,10 @@
 // Captures dynamic shape dimensions required by dispatch operands.
 std::unique_ptr<Pass> createCaptureDispatchDynamicDimsPass();
 
+// Outlines external dispatches into executables.
+std::unique_ptr<OperationPass<mlir::ModuleOp>>
+createOutlineDispatchExternsPass();
+
 // Outlines dispatch regions into executables.
 std::unique_ptr<OperationPass<mlir::ModuleOp>>
 createOutlineDispatchRegionsPass();
diff --git a/compiler/src/iree/compiler/Dialect/Flow/Transforms/Passes.td b/compiler/src/iree/compiler/Dialect/Flow/Transforms/Passes.td
index c6973de..8ac6e5a 100644
--- a/compiler/src/iree/compiler/Dialect/Flow/Transforms/Passes.td
+++ b/compiler/src/iree/compiler/Dialect/Flow/Transforms/Passes.td
@@ -229,6 +229,12 @@
   let constructor = "mlir::iree_compiler::IREE::Flow::createInterchangeTransposeGenericOpsPass()";
 }
 
+def OutlineDispatchExterns :
+    Pass<"iree-flow-outline-dispatch-externs", "mlir::ModuleOp"> {
+  let summary = "Outlines external dispatches into executables";
+  let constructor = "mlir::iree_compiler::IREE::Flow::createOutlineDispatchExternsPass()";
+}
+
 def OutlineDispatchRegions :
     Pass<"iree-flow-outline-dispatch-regions", "mlir::ModuleOp"> {
   let summary = "Outlines dispatch regions into executables";
diff --git a/compiler/src/iree/compiler/Dialect/Flow/Transforms/test/BUILD.bazel b/compiler/src/iree/compiler/Dialect/Flow/Transforms/test/BUILD.bazel
index f009f30..58a515a 100644
--- a/compiler/src/iree/compiler/Dialect/Flow/Transforms/test/BUILD.bazel
+++ b/compiler/src/iree/compiler/Dialect/Flow/Transforms/test/BUILD.bazel
@@ -39,6 +39,7 @@
             "insert_dispatch_debug_markers.mlir",
             "interchange_generic_ops.mlir",
             "interchange_transpose_generic_ops.mlir",
+            "outline_dispatch_externs.mlir",
             "outline_dispatch_regions.mlir",
             "pad_fusion_with_consumer.mlir",
             "pad_fusion_with_producer.mlir",
diff --git a/compiler/src/iree/compiler/Dialect/Flow/Transforms/test/CMakeLists.txt b/compiler/src/iree/compiler/Dialect/Flow/Transforms/test/CMakeLists.txt
index aa57301..ea8e021 100644
--- a/compiler/src/iree/compiler/Dialect/Flow/Transforms/test/CMakeLists.txt
+++ b/compiler/src/iree/compiler/Dialect/Flow/Transforms/test/CMakeLists.txt
@@ -37,6 +37,7 @@
     "insert_dispatch_debug_markers.mlir"
     "interchange_generic_ops.mlir"
     "interchange_transpose_generic_ops.mlir"
+    "outline_dispatch_externs.mlir"
     "outline_dispatch_regions.mlir"
     "pad_fusion_with_consumer.mlir"
     "pad_fusion_with_producer.mlir"
diff --git a/compiler/src/iree/compiler/Dialect/Flow/Transforms/test/outline_dispatch_externs.mlir b/compiler/src/iree/compiler/Dialect/Flow/Transforms/test/outline_dispatch_externs.mlir
new file mode 100644
index 0000000..7571eb3
--- /dev/null
+++ b/compiler/src/iree/compiler/Dialect/Flow/Transforms/test/outline_dispatch_externs.mlir
@@ -0,0 +1,65 @@
+// RUN: iree-opt --allow-unregistered-dialect --split-input-file --iree-flow-outline-dispatch-externs --mlir-print-local-scope %s | FileCheck %s
+
+//      CHECK: hal.executable private @extern_dispatch_0
+// CHECK-NEXT:   hal.executable.variant public @a target(<"llvm-cpu", "a">)
+// CHECK-SAME:       objects([#hal.executable.object<{path = "a.o"}>])
+// CHECK-NEXT:     hal.executable.export public @main ordinal(100)
+// CHECK-SAME:         layout(#hal.pipeline.layout<push_constants = 1, sets = [<0, bindings = [<0, storage_buffer, ReadOnly>, <1, storage_buffer>]>]>)
+// CHECK-NEXT:     ^bb0(%arg0: !hal.device, %arg1: index, %arg2: index):
+// CHECK-NEXT:       %ok, %value = hal.device.query<%arg0 : !hal.device> key("some" :: "value") : i1, i32
+// CHECK-NEXT:       %0 = arith.index_cast %value : i32 to index
+// CHECK-NEXT:       hal.return %arg1, %arg2, %0 : index, index, index
+//      CHECK:   hal.executable.variant public @b target(<"llvm-cpu", "b">)
+// CHECK-SAME:       objects([#hal.executable.object<{path = "b.o"}>])
+// CHECK-NEXT:     hal.executable.condition(%arg0: !hal.device) -> i1 {
+// CHECK-NEXT:       %ok, %value = hal.device.query<%arg0 : !hal.device> key("some" :: "feature") : i1, i32
+// CHECK-NEXT:       hal.return %ok : i1
+//      CHECK:     hal.executable.export public @main ordinal(200)
+// CHECK-SAME:         layout(#hal.pipeline.layout<push_constants = 1, sets = [<0, bindings = [<0, storage_buffer, ReadOnly>, <1, storage_buffer>]>]>)
+// CHECK-NEXT:     ^bb0(%arg0: !hal.device, %arg1: index, %arg2: index):
+
+// Demonstrates the full functionality of an extern dispatch op.
+// Note that some fields are optional.
+
+// CHECK-LABEL: func.func @dispatchExtern
+func.func @dispatchExtern(%arg0: tensor<4xi32>, %arg1: tensor<8xi32>, %arg2: i32) -> tensor<8xi32> {
+  %x = arith.constant 100 : index
+  %y = arith.constant 50 : index
+  // Dispatch workgroups to the externally defined function "main" in the
+  // referenced object files.
+  // CHECK: %[[RESULT:.+]] = flow.dispatch {@extern_dispatch_0::@a::@main, @extern_dispatch_0::@b::@main}[%c100, %c50](%arg0, %arg1, %arg2) {
+  // CHECK-SAME: hal.interface.bindings = [#hal.interface.binding<0, 0>, #hal.interface.binding<0, 1>]
+  // CHECK-SAME: } : (tensor<4xi32>, tensor<8xi32>, i32) -> %arg1
+  %result = hal.dispatch.extern "main"[%x, %y](%arg0, %arg1, %arg2) : (tensor<4xi32>, tensor<8xi32>, i32) -> %arg1
+    // Translates the workload (%x and %y captured above) into an XYZ workgroup
+    // count, optionally using device information.
+    count(%device: !hal.device, %x_capture: index, %y_capture: index) -> (index, index, index) {
+      // Shows how device queries can be used when computing the workgroup count.
+      // The device is the one used at runtime.
+      %ok, %z_i32 = hal.device.query<%device : !hal.device> key("some" :: "value") : i1, i32
+      %z = arith.index_cast %z_i32 : i32 to index
+      hal.return %x_capture, %y_capture, %z : index, index, index
+    }
+    // Must match the external definition.
+    layout(#hal.pipeline.layout<push_constants = 1, sets = [
+      <0, bindings = [
+          <0, storage_buffer, ReadOnly>,
+          <1, storage_buffer>
+      ]>
+    ]>)
+    // Optional, automatically inferred if omitted.
+    bindings([
+      #hal.interface.binding<0, 0>,
+      #hal.interface.binding<0, 1>
+    ])
+    // Can have object references for multiple targets or configurations.
+    objects({
+      #hal.executable.target<"llvm-cpu", "a"> ordinal(100) = [#hal.executable.object<{path = "a.o"}>],
+      #hal.executable.target<"llvm-cpu", "b"> if(%device: !hal.device) -> i1 {
+        %ok, %z_i32 = hal.device.query<%device : !hal.device> key("some" :: "feature") : i1, i32
+        hal.return %ok : i1
+      } ordinal(200) = [#hal.executable.object<{path = "b.o"}>]
+    })
+  // CHECK: return %[[RESULT]]
+  return %result : tensor<8xi32>
+}
diff --git a/compiler/src/iree/compiler/Dialect/Flow/Transforms/test/outline_dispatch_regions.mlir b/compiler/src/iree/compiler/Dialect/Flow/Transforms/test/outline_dispatch_regions.mlir
index 03462c5..67e90e1 100644
--- a/compiler/src/iree/compiler/Dialect/Flow/Transforms/test/outline_dispatch_regions.mlir
+++ b/compiler/src/iree/compiler/Dialect/Flow/Transforms/test/outline_dispatch_regions.mlir
@@ -176,58 +176,3 @@
   }
   return %0 : tensor<4xi32>
 }
-
-// -----
-
-//      CHECK: flow.executable private @dispatchExtern_dispatch_0
-// CHECK-NEXT:   flow.executable.export public @main
-// CHECK-SAME:     workgroups(%arg0: !hal.device, %arg1: index, %arg2: index) -> (index, index, index) {
-// CHECK-NEXT:       %ok, %value = hal.device.query<%arg0 : !hal.device> key("some" :: "value") : i1, i32
-// CHECK-NEXT:       %0 = arith.index_cast %value : i32 to index
-// CHECK-NEXT:       hal.return %arg1, %arg2, %0 : index, index, index
-// CHECK-NEXT:     } attributes {
-// CHECK-SAME:       hal.interface.layout = #hal.pipeline.layout<push_constants = 1, sets = [<0, bindings = [<0, storage_buffer, ReadOnly>, <1, storage_buffer>]>]>
-// CHECK-SAME:     }
-
-// Demonstrates the full functionality of an extern dispatch op.
-// Note that some fields are optional.
-
-// CHECK-LABEL: func.func @dispatchExtern
-func.func @dispatchExtern(%arg0: tensor<4xi32>, %arg1: tensor<8xi32>, %arg2: i32) -> tensor<8xi32> {
-  %x = arith.constant 100 : index
-  %y = arith.constant 50 : index
-  // Dispatch workgroups to the externally defined function "main" in the
-  // referenced object files.
-  // CHECK: %[[RESULT:.+]] = flow.dispatch @dispatchExtern_dispatch_0::@main[%c100, %c50](%arg0, %arg1, %arg2) {
-  // CHECK-SAME: hal.interface.bindings = [#hal.interface.binding<0, 0>, #hal.interface.binding<0, 1>]
-  // CHECK-SAME: } : (tensor<4xi32>, tensor<8xi32>, i32) -> %arg1
-  %result = hal.dispatch.extern "main"[%x, %y](%arg0, %arg1, %arg2) : (tensor<4xi32>, tensor<8xi32>, i32) -> %arg1
-    // Translates the workload (%x and %y captured above) into an XYZ workgroup
-    // count, optionally using device information.
-    count(%device: !hal.device, %x_capture: index, %y_capture: index) -> (index, index, index) {
-      // Shows how device queries can be used when computing the workgroup count.
-      // The device is the one used at runtime.
-      %ok, %z_i32 = hal.device.query<%device : !hal.device> key("some" :: "value") : i1, i32
-      %z = arith.index_cast %z_i32 : i32 to index
-      hal.return %x_capture, %y_capture, %z : index, index, index
-    }
-    // Must match the external definition.
-    layout(#hal.pipeline.layout<push_constants = 1, sets = [
-      <0, bindings = [
-          <0, storage_buffer, ReadOnly>,
-          <1, storage_buffer>
-      ]>
-    ]>)
-    // Optional, automatically inferred if omitted.
-    bindings([
-      #hal.interface.binding<0, 0>,
-      #hal.interface.binding<0, 1>
-    ])
-    // Can have object references for multiple targets or configurations.
-    objects(#hal.executable.objects<{
-      #hal.executable.target<"llvm-cpu", "a"> = [#hal.executable.object<{path = "a.o"}>],
-      #hal.executable.target<"llvm-cpu", "b"> = [#hal.executable.object<{path = "b.o"}>]
-    }>)
-  // CHECK: return %[[RESULT]]
-  return %result : tensor<8xi32>
-}
diff --git a/compiler/src/iree/compiler/Dialect/HAL/IR/HALBase.td b/compiler/src/iree/compiler/Dialect/HAL/IR/HALBase.td
index 3ecd4bc..1d42aa0 100644
--- a/compiler/src/iree/compiler/Dialect/HAL/IR/HALBase.td
+++ b/compiler/src/iree/compiler/Dialect/HAL/IR/HALBase.td
@@ -493,6 +493,7 @@
 ]>;
 
 def HAL_OrdinalAttr : Util_IndexAttrBase<"size_t">;
+def HAL_OrdinalArrayAttr : TypedArrayAttrBase<HAL_OrdinalAttr, "Array of index ordinal attributes">;
 
 def HAL_ExecutableDataAttr : SignlessIntElementsAttr<8>;
 
diff --git a/compiler/src/iree/compiler/Dialect/HAL/IR/HALOps.cpp b/compiler/src/iree/compiler/Dialect/HAL/IR/HALOps.cpp
index 77c2125..b9655c3 100644
--- a/compiler/src/iree/compiler/Dialect/HAL/IR/HALOps.cpp
+++ b/compiler/src/iree/compiler/Dialect/HAL/IR/HALOps.cpp
@@ -194,6 +194,85 @@
                 /*printBlockTerminators=*/true);
 }
 
+static ParseResult parseTargetConditionObjects(
+    OpAsmParser &parser, ArrayAttr &targetsAttr, ArrayAttr &targetOrdinalsAttr,
+    ArrayAttr &targetObjectsAttr,
+    SmallVector<std::unique_ptr<Region>, 2> &targetRegions) {
+  SmallVector<Attribute> targetsAttrs;
+  SmallVector<Attribute> targetOrdinalsAttrs;
+  SmallVector<Attribute> targetObjectsAttrs;
+  do {
+    // #hal.executable.target<...>
+    Attribute targetAttr;
+    if (failed(parser.parseAttribute(targetAttr)))
+      return failure();
+    targetsAttrs.push_back(targetAttr);
+
+    // if(...) -> i1 { ... }
+    auto region = std::make_unique<Region>();
+    if (succeeded(parser.parseOptionalKeyword("if"))) {
+      if (failed(parseTargetConditionRegion(parser, *region)))
+        return failure();
+    }
+    targetRegions.push_back(std::move(region));
+
+    // ordinal(#)
+    Attribute targetOrdinalAttr;
+    if (failed(parser.parseKeyword("ordinal")) ||
+        failed(parser.parseLParen()) ||
+        failed(parser.parseAttribute(targetOrdinalAttr,
+                                     IndexType::get(parser.getContext()))) ||
+        failed(parser.parseRParen()))
+      return failure();
+    targetOrdinalsAttrs.push_back(targetOrdinalAttr);
+
+    // = [#hal.executable.object<...>, ...]
+    ArrayAttr targetObjectsAttr;
+    if (failed(parser.parseEqual()) ||
+        failed(parser.parseAttribute(targetObjectsAttr)))
+      return failure();
+    targetObjectsAttrs.push_back(targetObjectsAttr);
+  } while (succeeded(parser.parseOptionalComma()));
+  targetsAttr = ArrayAttr::get(parser.getContext(), targetsAttrs);
+  targetOrdinalsAttr = ArrayAttr::get(parser.getContext(), targetOrdinalsAttrs);
+  targetObjectsAttr = ArrayAttr::get(parser.getContext(), targetObjectsAttrs);
+  return success();
+}
+
+static void printTargetConditionObjects(OpAsmPrinter &p, Operation *op,
+                                        ArrayAttr targetsAttr,
+                                        ArrayAttr targetOrdinalsAttr,
+                                        ArrayAttr targetObjectsAttr,
+                                        MutableArrayRef<Region> targetRegions) {
+  p.increaseIndent();
+  p.printNewline();
+
+  llvm::interleave(
+      llvm::zip_equal(targetsAttr, targetOrdinalsAttr, targetObjectsAttr,
+                      targetRegions),
+      [&](auto it) {
+        auto &[targetAttr, targetOrdinalAttr, targetObjectsAttr, targetRegion] =
+            it;
+        p.printAttribute(targetAttr);
+        if (!targetRegion.empty()) {
+          p << " if";
+          printTargetConditionRegion(p, op, targetRegion);
+        }
+        p << " ordinal(";
+        p.printAttributeWithoutType(targetOrdinalAttr);
+        p << ")";
+        p << " = ";
+        p.printAttribute(targetObjectsAttr);
+      },
+      [&]() {
+        p << ",";
+        p.printNewline();
+      });
+
+  p.decreaseIndent();
+  p.printNewline();
+}
+
 //===----------------------------------------------------------------------===//
 // custom<WorkgroupCountRegion>($body)
 //===----------------------------------------------------------------------===//
@@ -467,6 +546,7 @@
                              ValueRange resultDims, ValueRange arguments,
                              ValueRange argumentDims,
                              ArrayRef<int64_t> tiedOperands,
+                             IREE::HAL::ExecutableObjectsAttr targetObjects,
                              ArrayRef<NamedAttribute> attributes) {
   state.addTypes(resultTypes);
   state.addOperands(workload);
@@ -477,6 +557,8 @@
   state.attributes.erase(IREE::Util::TiedOpInterface::getStorageAttrName());
   state.addAttribute(IREE::Util::TiedOpInterface::getStorageAttrName(),
                      builder.getIndexArrayAttr(tiedOperands));
+  state.addAttribute("targets", targetObjects.getTargets());
+  state.addAttribute("target_objects", targetObjects.getTargetObjects());
   state.attributes.erase(getOperandSegmentSizeAttr());
   state.addAttribute(getOperandSegmentSizeAttr(),
                      builder.getDenseI32ArrayAttr({
@@ -488,6 +570,10 @@
 
   // NOTE: workgroup count region is empty; callers are expected to populate it.
   state.addRegion();
+
+  // Add one empty region per target.
+  for (size_t i = 0; i < targetObjects.getTargets().size(); ++i)
+    state.addRegion();
 }
 
 // Verifies that |dynamicDims| contains the appropriate number of dims for all
@@ -569,6 +655,19 @@
     return failure();
   }
 
+  if (getTargets().size() != getTargetObjects().size()) {
+    return op->emitOpError() << "target and objects arrays must match";
+  }
+  if (getTargets().size() != getTargetRegions().size()) {
+    return op->emitOpError()
+           << "target and condition regions must match (but they may be empty)";
+  }
+  for (auto &targetRegion : getTargetRegions()) {
+    if (failed(verifyTargetConditionRegion(op, targetRegion))) {
+      return failure();
+    }
+  }
+
   return success();
 }
 
diff --git a/compiler/src/iree/compiler/Dialect/HAL/IR/HALOps.td b/compiler/src/iree/compiler/Dialect/HAL/IR/HALOps.td
index 1d65042..50414a6 100644
--- a/compiler/src/iree/compiler/Dialect/HAL/IR/HALOps.td
+++ b/compiler/src/iree/compiler/Dialect/HAL/IR/HALOps.td
@@ -362,6 +362,13 @@
     E.g., in the above example, `-> %0` would tie the first argument to the
     result. In that case, there would be no separate block argument for the
     result.
+
+    Objects for multiple targets can be specified and the ones used are selected
+    based on their target and an optional condition region that returns true if
+    the variant is valid for use on the provided runtime `!hal.device`. If no
+    variants within an executable are valid then loading will fail at runtime.
+    If multiple variants are valid the first valid one found will be loaded and
+    used for execution.
   }];
 
   let arguments = (ins
@@ -371,7 +378,9 @@
     HAL_ShapeDynamicDims:$argument_dims,
     HAL_ShapeDynamicDims:$result_dims,
     HAL_PipelineLayoutAttr:$layout,
-    HAL_ExecutableObjectsAttr:$objects,
+    ArrayAttr:$targets,
+    HAL_OrdinalArrayAttr:$target_ordinals,
+    ArrayAttr:$target_objects,
     OptionalAttr<HAL_WorkgroupSizeAttr>:$workgroup_size,
     OptionalAttr<HAL_SubgroupSizeAttr>:$subgroup_size,
     OptionalAttr<IndexAttr>:$workgroup_local_memory,
@@ -383,7 +392,8 @@
   );
 
   let regions = (region
-    AnyRegion:$workgroup_count
+    AnyRegion:$workgroup_count,
+    VariadicRegion<AnyRegion>:$target_regions
   );
 
   let assemblyFormat = [{
@@ -397,7 +407,10 @@
     `count` `` custom<WorkgroupCountRegion>($workgroup_count)
     `layout` `(` $layout `)`
     (`bindings` `(` $bindings^ `)`)?
-    `objects` `(` $objects `)`
+    `objects` `(` `{` custom<TargetConditionObjects>($targets,
+                                                     $target_ordinals,
+                                                     $target_objects,
+                                                     $target_regions) `}` `)`
     attr-dict-with-keyword
   }];
 
@@ -408,6 +421,7 @@
       "TypeRange":$resultTypes, "ValueRange":$resultDims,
       "ValueRange":$arguments, "ValueRange":$argumentDims,
       "ArrayRef<int64_t>":$tiedOperands,
+      "IREE::HAL::ExecutableObjectsAttr":$targetObjects,
       CArg<"ArrayRef<NamedAttribute>", "{}">:$attributes)>,
   ];
 
@@ -429,6 +443,10 @@
     ValueRange getResultDynamicDims(unsigned idx) {
       return IREE::Util::findVariadicDynamicDims(idx, getResults(), getResultDims());
     }
+
+    IREE::HAL::ExecutableObjectsAttr getObjectsAttr() {
+      return IREE::HAL::ExecutableObjectsAttr::get(getContext(), getTargets(), getTargetObjects());
+    }
   }];
 
   let hasVerifier = 1;
diff --git a/compiler/src/iree/compiler/Dialect/HAL/IR/test/tensor_ops.mlir b/compiler/src/iree/compiler/Dialect/HAL/IR/test/tensor_ops.mlir
index 42f1d9e..45e1c6f 100644
--- a/compiler/src/iree/compiler/Dialect/HAL/IR/test/tensor_ops.mlir
+++ b/compiler/src/iree/compiler/Dialect/HAL/IR/test/tensor_ops.mlir
@@ -1,4 +1,4 @@
-// RUN: iree-opt --split-input-file %s | iree-opt --split-input-file | FileCheck %s
+// RUN: iree-opt --split-input-file --mlir-print-local-scope %s | iree-opt --split-input-file --mlir-print-local-scope | FileCheck %s
 
 // CHECK-LABEL: @tensorImportStatic
 func.func @tensorImportStatic(%arg0: !hal.buffer_view) -> tensor<5xi32> {
@@ -59,21 +59,28 @@
 
 // CHECK-LABEL: func.func @dispatchExtern
 func.func @dispatchExtern(%arg0: tensor<4xi32>, %arg1: tensor<8xi32>, %arg2: i32) -> tensor<8xi32> {
-  %x = arith.constant 100 : index
-  %y = arith.constant 50 : index
+  // CHECK-DAG: %[[WORKLOAD_X:.+]] = arith.constant 100
+  %workload_x = arith.constant 100 : index
+  // CHECK-DAG: %[[WORKLOAD_Y:.+]] = arith.constant 50
+  %workload_y = arith.constant 50 : index
+
   // Dispatch workgroups to the externally defined function "main" in the
-  // referenced object files.
-  %0 = hal.dispatch.extern "main"[%x, %y](%arg0, %arg1, %arg2) : (tensor<4xi32>, tensor<8xi32>, i32) -> %arg1
+  // referenced object files with the ordinal specified per object group.
+  // CHECK: %[[RESULT:.+]] = hal.dispatch.extern "main"[%[[WORKLOAD_X]], %[[WORKLOAD_Y]]](%arg0, %arg1, %arg2) : (tensor<4xi32>, tensor<8xi32>, i32) -> %arg1
+  %0 = hal.dispatch.extern "main"[%workload_x, %workload_y](%arg0, %arg1, %arg2) : (tensor<4xi32>, tensor<8xi32>, i32) -> %arg1
     // Translates the workload (%x and %y captured above) into an XYZ workgroup
     // count, optionally using device information.
+    // CHECK: count(%[[DEVICE:.+]]: !hal.device, %[[X_CAPTURE:.+]]: index, %[[Y_CAPTURE:.+]]: index) -> (index, index, index) {
     count(%device: !hal.device, %x_capture: index, %y_capture: index) -> (index, index, index) {
       // Shows how device queries can be used when computing the workgroup count.
       // The device is the one used at runtime.
+      // CHECK: = hal.device.query<%[[DEVICE]] : !hal.device>
       %ok, %z_i32 = hal.device.query<%device : !hal.device> key("some" :: "value") : i1, i32
       %z = arith.index_cast %z_i32 : i32 to index
       hal.return %x_capture, %y_capture, %z : index, index, index
     }
     // Must match the external definition.
+    // CHECK: layout(<push_constants = 1, sets =
     layout(#hal.pipeline.layout<push_constants = 1, sets = [
       <0, bindings = [
           <0, storage_buffer, ReadOnly>,
@@ -81,14 +88,31 @@
       ]>
     ]>)
     // Optional, automatically inferred if omitted.
+    // CHECK: bindings([#hal.interface.binding<0, 0>, #hal.interface.binding<0, 1>])
     bindings([
       #hal.interface.binding<0, 0>,
       #hal.interface.binding<0, 1>
     ])
     // Can have object references for multiple targets or configurations.
-    objects(#hal.executable.objects<{
-      #hal.executable.target<"llvm-cpu", "a"> = [#hal.executable.object<{path = "a.o"}>],
-      #hal.executable.target<"llvm-cpu", "b"> = [#hal.executable.object<{path = "b.o"}>]
-    }>)
+    // CHECK: objects({
+    objects({
+      // CHECK: #hal.executable.target<"llvm-cpu", "a"> ordinal(100) = [#hal.executable.object<{path = "a.o"}>]
+      #hal.executable.target<"llvm-cpu", "a"> ordinal(100) = [#hal.executable.object<{path = "a.o"}>],
+      // CHECK: #hal.executable.target<"llvm-cpu", "b"> if(%[[B_DEVICE:.+]]: !hal.device) -> i1 {
+      #hal.executable.target<"llvm-cpu", "b"> if(%device: !hal.device) -> i1 {
+        // CHECK: = hal.device.query<%[[B_DEVICE]] : !hal.device>
+        %ok, %z_i32 = hal.device.query<%device : !hal.device> key("some" :: "feature_b") : i1, i32
+        hal.return %ok : i1
+      // CHECK: } ordinal(200) = [#hal.executable.object<{path = "b.o"}>]
+      } ordinal(200) = [#hal.executable.object<{path = "b.o"}>],
+      // CHECK: #hal.executable.target<"llvm-cpu", "c"> if(%[[C_DEVICE:.+]]: !hal.device) -> i1 {
+      #hal.executable.target<"llvm-cpu", "c"> if(%device: !hal.device) -> i1 {
+        // CHECK: = hal.device.query<%[[C_DEVICE]] : !hal.device>
+        %ok, %z_i32 = hal.device.query<%device : !hal.device> key("some" :: "feature_c") : i1, i32
+        hal.return %ok : i1
+      // CHECK: } ordinal(300) = [#hal.executable.object<{path = "c.o"}>]
+      } ordinal(300) = [#hal.executable.object<{path = "c.o"}>]
+    })
+  // CHECK: return %[[RESULT]]
   return %0 : tensor<8xi32>
 }
diff --git a/samples/custom_dispatch/vulkan/shaders/example_inline.mlir b/samples/custom_dispatch/vulkan/shaders/example_inline.mlir
index 980df87..5979d76 100644
--- a/samples/custom_dispatch/vulkan/shaders/example_inline.mlir
+++ b/samples/custom_dispatch/vulkan/shaders/example_inline.mlir
@@ -52,13 +52,15 @@
     %dim_i32 = arith.index_cast %dim : index to i32
 
     // Dispatch a basic `ret = lhs * rhs` shader.
+    // Note that not all backends use names or the names are derived from
+    // ordinals so we include that (`:ordinal`).
     %0 = hal.dispatch.extern "main"[%dim](%dim_i32, %arg0, %arg1) : (i32, tensor<?xf32>{%dim}, tensor<?xf32>{%dim}) -> tensor<?xf32>{%dim}
+      // This host function is used to compute the XYZ workgroup count
+      // dispatched at runtime. It can query the %device for capabilities
+      // and limits (shared memory size, etc). The other arguments are the
+      // values passed in the dispatch operation (usually things like root
+      // output op tensor dimensions and other abstract values).
       count(%device: !hal.device, %workload: index) -> (index, index, index) {
-        // This host function is used to compute the XYZ workgroup count
-        // dispatched at runtime. It can query the %device for capabilities
-        // and limits (shared memory size, etc). The other arguments are the
-        // values passed in the dispatch operation (usually things like root
-        // output op tensor dimensions and other abstract values).
         %x = affine.apply affine_map<()[s0] -> (s0 ceildiv 64)>()[%workload]
         %c1 = arith.constant 1 : index
         hal.return %x, %c1, %c1 : index, index, index
@@ -90,8 +92,8 @@
       // keys can be generic. This allows for an object file to be linked in based
       // only on the target triple while allowing for more specialized ones
       // requiring certain CPU features to be only included when building those.
-      objects(#hal.executable.objects<{
-        #spirv_target = [
+      objects({
+        #spirv_target ordinal(0) = [
           #hal.executable.object<{
             // Referencing a file path on disk but could also have the data
             // embedded in order to make the MLIR file hermetic/portable across
@@ -103,7 +105,7 @@
             path = "samples/custom_dispatch/vulkan/shaders/simple_mul.spv"
           }>
         ]
-      }>)
+      })
 
     // Code gen some other ops - these will interleave with the hand-authored
     // ones but naturally won't be able to fuse with them.