Adding condition support to `hal.dispatch.extern`. (#15797)
Required switching from `hal.dispatch.extern` outlining into a
`flow.executable` to instead outlining into a `hal.executable` with
variants. This is nice as it prevents us from needing to carry through
all the information a `hal.executable` could have through flow/stream
but does mean we need to support `flow.dispatch`/`stream.async.dispatch`
on expanded executables with variants that may have multiple export
symbols. In the common case (codegen) there's no change.
It's not great this is happening in flow - future work may try to move
outlining of the extern ops to the end of the global optimization
pipeline (assuming that happens _after_ plugins that may insert the
extern ops). There's a few refactorings in other passes we'd need to
make that happen.
diff --git a/compiler/src/iree/compiler/Dialect/Flow/IR/FlowOps.td b/compiler/src/iree/compiler/Dialect/Flow/IR/FlowOps.td
index 36eb30a..f0a1565 100644
--- a/compiler/src/iree/compiler/Dialect/Flow/IR/FlowOps.td
+++ b/compiler/src/iree/compiler/Dialect/Flow/IR/FlowOps.td
@@ -787,7 +787,6 @@
$tied_operands)
}];
- let skipDefaultBuilders = 1;
let builders = [
OpBuilder<(ins
"ExecutableExportOp":$exportOp, "ValueRange":$workload,
diff --git a/compiler/src/iree/compiler/Dialect/Flow/Transforms/BUILD.bazel b/compiler/src/iree/compiler/Dialect/Flow/Transforms/BUILD.bazel
index 5a2d2c9..f626437 100644
--- a/compiler/src/iree/compiler/Dialect/Flow/Transforms/BUILD.bazel
+++ b/compiler/src/iree/compiler/Dialect/Flow/Transforms/BUILD.bazel
@@ -52,6 +52,7 @@
"InsertDispatchDebugTargets.cpp",
"InterchangeGenericOps.cpp",
"InterchangeTransposeGenericOps.cpp",
+ "OutlineDispatchExterns.cpp",
"OutlineDispatchRegions.cpp",
"PassDetail.h",
"Passes.cpp",
diff --git a/compiler/src/iree/compiler/Dialect/Flow/Transforms/CMakeLists.txt b/compiler/src/iree/compiler/Dialect/Flow/Transforms/CMakeLists.txt
index 67eea6f..b0ba163 100644
--- a/compiler/src/iree/compiler/Dialect/Flow/Transforms/CMakeLists.txt
+++ b/compiler/src/iree/compiler/Dialect/Flow/Transforms/CMakeLists.txt
@@ -51,6 +51,7 @@
"InsertDispatchDebugTargets.cpp"
"InterchangeGenericOps.cpp"
"InterchangeTransposeGenericOps.cpp"
+ "OutlineDispatchExterns.cpp"
"OutlineDispatchRegions.cpp"
"PassDetail.h"
"Passes.cpp"
diff --git a/compiler/src/iree/compiler/Dialect/Flow/Transforms/OutlineDispatchExterns.cpp b/compiler/src/iree/compiler/Dialect/Flow/Transforms/OutlineDispatchExterns.cpp
new file mode 100644
index 0000000..4389863
--- /dev/null
+++ b/compiler/src/iree/compiler/Dialect/Flow/Transforms/OutlineDispatchExterns.cpp
@@ -0,0 +1,177 @@
+// Copyright 2023 The IREE Authors
+//
+// Licensed under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+
+#include <utility>
+
+#include "iree/compiler/Dialect/Flow/IR/FlowOps.h"
+#include "iree/compiler/Dialect/Flow/Transforms/PassDetail.h"
+#include "iree/compiler/Dialect/Flow/Transforms/Passes.h"
+#include "iree/compiler/Dialect/HAL/IR/HALDialect.h"
+#include "iree/compiler/Dialect/HAL/IR/HALOps.h"
+#include "iree/compiler/Dialect/Util/IR/UtilOps.h"
+#include "llvm/Support/Debug.h"
+#include "mlir/Dialect/Func/IR/FuncOps.h"
+#include "mlir/IR/Builders.h"
+#include "mlir/IR/BuiltinTypes.h"
+#include "mlir/IR/Diagnostics.h"
+#include "mlir/IR/IRMapping.h"
+#include "mlir/IR/SymbolTable.h"
+#include "mlir/Pass/Pass.h"
+
+namespace mlir {
+namespace iree_compiler {
+namespace IREE {
+namespace Flow {
+namespace {
+
+//===----------------------------------------------------------------------===//
+// hal.dispatch.extern
+//===----------------------------------------------------------------------===//
+
+// Converts a dispatch region op into a dispatch op to the outlined region.
+static LogicalResult
+convertDispatchExternToDispatchOp(IREE::HAL::DispatchExternOp dispatchExternOp,
+ ArrayRef<Attribute> exportRefs) {
+ // Insert at the same place as the original region.
+ OpBuilder builder(dispatchExternOp);
+
+ // Create the dispatch op to the executable function.
+ // Note that we copy the tied operand indices from the workgroups op - it
+ // lines up 1:1 with the dispatch once we've outlined things.
+ auto dispatchOp = builder.create<IREE::Flow::DispatchOp>(
+ dispatchExternOp.getLoc(), dispatchExternOp.getResultTypes(),
+ dispatchExternOp.getWorkload(), builder.getArrayAttr(exportRefs),
+ dispatchExternOp.getArguments(), dispatchExternOp.getArgumentDims(),
+ dispatchExternOp.getResultDims(), dispatchExternOp.getTiedOperandsAttr());
+ dispatchOp->setDialectAttrs(dispatchExternOp->getDialectAttrs());
+ if (auto bindingsAttr = dispatchExternOp.getBindingsAttr()) {
+ dispatchOp->setAttr("hal.interface.bindings", bindingsAttr);
+ }
+
+ // Replace uses of the existing results with the new results.
+ for (int i = 0; i < dispatchExternOp.getNumResults(); ++i) {
+ dispatchExternOp.getResult(i).replaceAllUsesWith(dispatchOp.getResult(i));
+ }
+
+ return success();
+}
+
+// Outlines a dispatch region into a flow.executable and replaces the region op
+// with a dispatch to that outlined executable.
+static LogicalResult
+outlineDispatchExternOp(std::string name,
+ IREE::HAL::DispatchExternOp dispatchExternOp) {
+ // Create the executable that will contain the outlined region.
+ // NOTE: this will get uniquified if we have multiple in the same block.
+ auto parentFuncOp = dispatchExternOp->getParentOfType<FunctionOpInterface>();
+ auto parentModuleOp = parentFuncOp->getParentOfType<mlir::ModuleOp>();
+ OpBuilder parentModuleBuilder(&parentModuleOp.getBody()->back());
+ auto executableOp = parentModuleBuilder.create<IREE::HAL::ExecutableOp>(
+ dispatchExternOp.getLoc(), name);
+ executableOp.getOperation()->moveBefore(parentFuncOp);
+ executableOp.setPrivate();
+
+ // Add one variant per object target.
+ SymbolTable executableSymbolTable(executableOp);
+ OpBuilder executableBuilder(executableOp.getBody());
+ SmallVector<Attribute> exportRefs;
+ for (auto [targetAttr, targetOrdinalAttr, targetObjectsAttr,
+ targetConditionRegion] :
+ llvm::zip_equal(
+ dispatchExternOp.getTargetsAttr()
+ .getAsRange<IREE::HAL::ExecutableTargetAttr>(),
+ dispatchExternOp.getTargetOrdinalsAttr().getAsRange<IntegerAttr>(),
+ dispatchExternOp.getTargetObjectsAttr().getAsRange<ArrayAttr>(),
+ dispatchExternOp.getTargetRegions())) {
+ // Create the variant for the given target. Note that we may have multiple
+ // variants that use the same base targetAttr but have unique condition
+ // regions so we rely on the symbol table for uniquing names.
+ auto variantOp = executableBuilder.create<IREE::HAL::ExecutableVariantOp>(
+ dispatchExternOp.getLoc(), targetAttr.getSymbolNameFragment(),
+ targetAttr);
+ variantOp.setObjectsAttr(targetObjectsAttr);
+ executableSymbolTable.insert(variantOp);
+
+ // Move over optional target condition region to a condition op.
+ OpBuilder variantBuilder(variantOp.getBody());
+ if (!targetConditionRegion.empty()) {
+ auto conditionOp =
+ variantBuilder.create<IREE::HAL::ExecutableConditionOp>(
+ dispatchExternOp.getLoc());
+ IRMapping mapper;
+ targetConditionRegion.cloneInto(&conditionOp.getBody(), mapper);
+ }
+
+ // Add an export pointing at the entry point function.
+ auto exportOp = variantBuilder.create<IREE::HAL::ExecutableExportOp>(
+ dispatchExternOp.getLoc(), dispatchExternOp.getExportAttr(),
+ targetOrdinalAttr, dispatchExternOp.getLayoutAttr(),
+ dispatchExternOp.getWorkgroupSizeAttr(),
+ dispatchExternOp.getSubgroupSizeAttr(),
+ dispatchExternOp.getWorkgroupLocalMemoryAttr());
+ exportOp->setDialectAttrs(dispatchExternOp->getDialectAttrs());
+ if (!dispatchExternOp.getWorkgroupCount().empty()) {
+ IRMapping mapper;
+ dispatchExternOp.getWorkgroupCount().cloneInto(
+ &exportOp.getWorkgroupCount(), mapper);
+ }
+
+ exportRefs.push_back(
+ SymbolRefAttr::get(executableOp.getNameAttr(),
+ {FlatSymbolRefAttr::get(variantOp.getNameAttr()),
+ FlatSymbolRefAttr::get(exportOp.getNameAttr())}));
+ }
+
+ // Finally convert the dispatch region into a dispatch to the external
+ // exports.
+ return convertDispatchExternToDispatchOp(dispatchExternOp, exportRefs);
+}
+
+} // namespace
+
+class OutlineDispatchExternsPass
+ : public OutlineDispatchExternsBase<OutlineDispatchExternsPass> {
+public:
+ OutlineDispatchExternsPass() = default;
+ void getDependentDialects(DialectRegistry ®istry) const override {
+ registry.insert<IREE::Flow::FlowDialect>();
+ registry.insert<IREE::HAL::HALDialect>();
+ }
+
+ void runOnOperation() override {
+ for (auto funcOp : getOperation().getOps<FunctionOpInterface>()) {
+ // Outline all of the dispatch externs ops in this function.
+ SmallVector<Operation *> deadOps;
+ auto outlineOps = [&](Operation *op) {
+ return TypeSwitch<Operation *, WalkResult>(op)
+ .Case<IREE::HAL::DispatchExternOp>([&](auto dispatchExternOp) {
+ if (failed(outlineDispatchExternOp(
+ ("extern_dispatch_" + llvm::Twine(deadOps.size())).str(),
+ dispatchExternOp))) {
+ return WalkResult::interrupt();
+ }
+ deadOps.push_back(op);
+ return WalkResult::advance();
+ })
+ .Default(WalkResult::advance());
+ };
+ if (funcOp.walk(outlineOps).wasInterrupted())
+ return signalPassFailure();
+ for (auto *deadOp : deadOps)
+ deadOp->erase();
+ }
+ }
+};
+
+std::unique_ptr<OperationPass<mlir::ModuleOp>>
+createOutlineDispatchExternsPass() {
+ return std::make_unique<OutlineDispatchExternsPass>();
+}
+
+} // namespace Flow
+} // namespace IREE
+} // namespace iree_compiler
+} // namespace mlir
diff --git a/compiler/src/iree/compiler/Dialect/Flow/Transforms/OutlineDispatchRegions.cpp b/compiler/src/iree/compiler/Dialect/Flow/Transforms/OutlineDispatchRegions.cpp
index c8d097a..54de80e 100644
--- a/compiler/src/iree/compiler/Dialect/Flow/Transforms/OutlineDispatchRegions.cpp
+++ b/compiler/src/iree/compiler/Dialect/Flow/Transforms/OutlineDispatchRegions.cpp
@@ -9,7 +9,6 @@
#include "iree/compiler/Dialect/Flow/IR/FlowOps.h"
#include "iree/compiler/Dialect/Flow/Transforms/PassDetail.h"
#include "iree/compiler/Dialect/Flow/Transforms/Passes.h"
-#include "iree/compiler/Dialect/HAL/IR/HALOps.h"
#include "iree/compiler/Dialect/Util/IR/UtilOps.h"
#include "llvm/Support/Debug.h"
#include "mlir/Dialect/Func/IR/FuncOps.h"
@@ -155,75 +154,6 @@
executableOp, exportOp);
}
-//===----------------------------------------------------------------------===//
-// hal.dispatch.extern
-//===----------------------------------------------------------------------===//
-
-// Converts a dispatch region op into a dispatch op to the outlined region.
-static LogicalResult
-convertDispatchExternToDispatchOp(IREE::HAL::DispatchExternOp dispatchExternOp,
- IREE::Flow::ExecutableOp executableOp,
- IREE::Flow::ExecutableExportOp exportOp) {
- // Insert at the same place as the original region.
- OpBuilder builder(dispatchExternOp);
-
- // Create the dispatch op to the executable function.
- // Note that we copy the tied operand indices from the workgroups op - it
- // lines up 1:1 with the dispatch once we've outlined things.
- auto dispatchOp = builder.create<IREE::Flow::DispatchOp>(
- dispatchExternOp.getLoc(), exportOp, dispatchExternOp.getWorkload(),
- dispatchExternOp.getResultTypes(), dispatchExternOp.getResultDims(),
- dispatchExternOp.getArguments(), dispatchExternOp.getArgumentDims(),
- dispatchExternOp.getTiedOperandsAttr());
- dispatchOp->setDialectAttrs(dispatchExternOp->getDialectAttrs());
- if (auto bindingsAttr = dispatchExternOp.getBindingsAttr()) {
- dispatchOp->setAttr("hal.interface.bindings", bindingsAttr);
- }
-
- // Replace uses of the existing results with the new results.
- for (int i = 0; i < dispatchExternOp.getNumResults(); ++i) {
- dispatchExternOp.getResult(i).replaceAllUsesWith(dispatchOp.getResult(i));
- }
-
- return success();
-}
-
-// Outlines a dispatch region into a flow.executable and replaces the region op
-// with a dispatch to that outlined executable.
-static LogicalResult
-outlineDispatchExternOp(std::string name,
- IREE::HAL::DispatchExternOp dispatchExternOp) {
- // Create the executable that will contain the outlined region.
- // NOTE: this will get uniquified if we have multiple in the same block.
- auto parentFuncOp = dispatchExternOp->getParentOfType<FunctionOpInterface>();
- auto parentModuleOp = parentFuncOp->getParentOfType<mlir::ModuleOp>();
- OpBuilder parentModuleBuilder(&parentModuleOp.getBody()->back());
- auto executableOp = parentModuleBuilder.create<IREE::Flow::ExecutableOp>(
- dispatchExternOp.getLoc(), name);
- executableOp.getOperation()->moveBefore(parentFuncOp);
- executableOp.setPrivate();
- executableOp->setAttr("hal.executable.objects",
- dispatchExternOp.getObjectsAttr());
-
- // Add an export pointing at the entry point function.
- OpBuilder builder(executableOp.getBody());
- auto exportOp = builder.create<IREE::Flow::ExecutableExportOp>(
- dispatchExternOp.getLoc(), dispatchExternOp.getExport(),
- FlatSymbolRefAttr::get(builder.getContext(),
- dispatchExternOp.getExport()));
- exportOp->setDialectAttrs(dispatchExternOp->getDialectAttrs());
- exportOp->setAttr("hal.interface.layout", dispatchExternOp.getLayoutAttr());
-
- // Move over the workgroup count region, if present.
- if (!dispatchExternOp.getWorkgroupCount().empty()) {
- exportOp.getWorkgroupCount().takeBody(dispatchExternOp.getWorkgroupCount());
- }
-
- // Finally convert the dispatch region into a dispatch to the outlined func.
- return convertDispatchExternToDispatchOp(dispatchExternOp, executableOp,
- exportOp);
-}
-
} // namespace
class OutlineDispatchRegionsPass
@@ -269,16 +199,6 @@
deadOps.push_back(op);
return WalkResult::advance();
})
- .Case<IREE::HAL::DispatchExternOp>([&](auto dispatchExternOp) {
- if (failed(outlineDispatchExternOp(
- (namePrefix + "_dispatch_" + llvm::Twine(deadOps.size()))
- .str(),
- dispatchExternOp))) {
- return WalkResult::interrupt();
- }
- deadOps.push_back(op);
- return WalkResult::advance();
- })
.Default(WalkResult::advance());
};
if (funcOp.walk(outlineOps).wasInterrupted())
diff --git a/compiler/src/iree/compiler/Dialect/Flow/Transforms/Passes.cpp b/compiler/src/iree/compiler/Dialect/Flow/Transforms/Passes.cpp
index b5d3814..1ff1d52 100644
--- a/compiler/src/iree/compiler/Dialect/Flow/Transforms/Passes.cpp
+++ b/compiler/src/iree/compiler/Dialect/Flow/Transforms/Passes.cpp
@@ -199,6 +199,7 @@
// Module pass to outline dispatch regions (and similar ops) into their own
// functions wrapped in executables.
+ passManager.addPass(IREE::Flow::createOutlineDispatchExternsPass());
passManager.addPass(IREE::Flow::createOutlineDispatchRegionsPass());
// Annotate executables based on their contents.
diff --git a/compiler/src/iree/compiler/Dialect/Flow/Transforms/Passes.h b/compiler/src/iree/compiler/Dialect/Flow/Transforms/Passes.h
index 3c09524..ca7f569 100644
--- a/compiler/src/iree/compiler/Dialect/Flow/Transforms/Passes.h
+++ b/compiler/src/iree/compiler/Dialect/Flow/Transforms/Passes.h
@@ -153,6 +153,10 @@
// Captures dynamic shape dimensions required by dispatch operands.
std::unique_ptr<Pass> createCaptureDispatchDynamicDimsPass();
+// Outlines external dispatches into executables.
+std::unique_ptr<OperationPass<mlir::ModuleOp>>
+createOutlineDispatchExternsPass();
+
// Outlines dispatch regions into executables.
std::unique_ptr<OperationPass<mlir::ModuleOp>>
createOutlineDispatchRegionsPass();
diff --git a/compiler/src/iree/compiler/Dialect/Flow/Transforms/Passes.td b/compiler/src/iree/compiler/Dialect/Flow/Transforms/Passes.td
index c6973de..8ac6e5a 100644
--- a/compiler/src/iree/compiler/Dialect/Flow/Transforms/Passes.td
+++ b/compiler/src/iree/compiler/Dialect/Flow/Transforms/Passes.td
@@ -229,6 +229,12 @@
let constructor = "mlir::iree_compiler::IREE::Flow::createInterchangeTransposeGenericOpsPass()";
}
+def OutlineDispatchExterns :
+ Pass<"iree-flow-outline-dispatch-externs", "mlir::ModuleOp"> {
+ let summary = "Outlines external dispatches into executables";
+ let constructor = "mlir::iree_compiler::IREE::Flow::createOutlineDispatchExternsPass()";
+}
+
def OutlineDispatchRegions :
Pass<"iree-flow-outline-dispatch-regions", "mlir::ModuleOp"> {
let summary = "Outlines dispatch regions into executables";
diff --git a/compiler/src/iree/compiler/Dialect/Flow/Transforms/test/BUILD.bazel b/compiler/src/iree/compiler/Dialect/Flow/Transforms/test/BUILD.bazel
index f009f30..58a515a 100644
--- a/compiler/src/iree/compiler/Dialect/Flow/Transforms/test/BUILD.bazel
+++ b/compiler/src/iree/compiler/Dialect/Flow/Transforms/test/BUILD.bazel
@@ -39,6 +39,7 @@
"insert_dispatch_debug_markers.mlir",
"interchange_generic_ops.mlir",
"interchange_transpose_generic_ops.mlir",
+ "outline_dispatch_externs.mlir",
"outline_dispatch_regions.mlir",
"pad_fusion_with_consumer.mlir",
"pad_fusion_with_producer.mlir",
diff --git a/compiler/src/iree/compiler/Dialect/Flow/Transforms/test/CMakeLists.txt b/compiler/src/iree/compiler/Dialect/Flow/Transforms/test/CMakeLists.txt
index aa57301..ea8e021 100644
--- a/compiler/src/iree/compiler/Dialect/Flow/Transforms/test/CMakeLists.txt
+++ b/compiler/src/iree/compiler/Dialect/Flow/Transforms/test/CMakeLists.txt
@@ -37,6 +37,7 @@
"insert_dispatch_debug_markers.mlir"
"interchange_generic_ops.mlir"
"interchange_transpose_generic_ops.mlir"
+ "outline_dispatch_externs.mlir"
"outline_dispatch_regions.mlir"
"pad_fusion_with_consumer.mlir"
"pad_fusion_with_producer.mlir"
diff --git a/compiler/src/iree/compiler/Dialect/Flow/Transforms/test/outline_dispatch_externs.mlir b/compiler/src/iree/compiler/Dialect/Flow/Transforms/test/outline_dispatch_externs.mlir
new file mode 100644
index 0000000..7571eb3
--- /dev/null
+++ b/compiler/src/iree/compiler/Dialect/Flow/Transforms/test/outline_dispatch_externs.mlir
@@ -0,0 +1,65 @@
+// RUN: iree-opt --allow-unregistered-dialect --split-input-file --iree-flow-outline-dispatch-externs --mlir-print-local-scope %s | FileCheck %s
+
+// CHECK: hal.executable private @extern_dispatch_0
+// CHECK-NEXT: hal.executable.variant public @a target(<"llvm-cpu", "a">)
+// CHECK-SAME: objects([#hal.executable.object<{path = "a.o"}>])
+// CHECK-NEXT: hal.executable.export public @main ordinal(100)
+// CHECK-SAME: layout(#hal.pipeline.layout<push_constants = 1, sets = [<0, bindings = [<0, storage_buffer, ReadOnly>, <1, storage_buffer>]>]>)
+// CHECK-NEXT: ^bb0(%arg0: !hal.device, %arg1: index, %arg2: index):
+// CHECK-NEXT: %ok, %value = hal.device.query<%arg0 : !hal.device> key("some" :: "value") : i1, i32
+// CHECK-NEXT: %0 = arith.index_cast %value : i32 to index
+// CHECK-NEXT: hal.return %arg1, %arg2, %0 : index, index, index
+// CHECK: hal.executable.variant public @b target(<"llvm-cpu", "b">)
+// CHECK-SAME: objects([#hal.executable.object<{path = "b.o"}>])
+// CHECK-NEXT: hal.executable.condition(%arg0: !hal.device) -> i1 {
+// CHECK-NEXT: %ok, %value = hal.device.query<%arg0 : !hal.device> key("some" :: "feature") : i1, i32
+// CHECK-NEXT: hal.return %ok : i1
+// CHECK: hal.executable.export public @main ordinal(200)
+// CHECK-SAME: layout(#hal.pipeline.layout<push_constants = 1, sets = [<0, bindings = [<0, storage_buffer, ReadOnly>, <1, storage_buffer>]>]>)
+// CHECK-NEXT: ^bb0(%arg0: !hal.device, %arg1: index, %arg2: index):
+
+// Demonstrates the full functionality of an extern dispatch op.
+// Note that some fields are optional.
+
+// CHECK-LABEL: func.func @dispatchExtern
+func.func @dispatchExtern(%arg0: tensor<4xi32>, %arg1: tensor<8xi32>, %arg2: i32) -> tensor<8xi32> {
+ %x = arith.constant 100 : index
+ %y = arith.constant 50 : index
+ // Dispatch workgroups to the externally defined function "main" in the
+ // referenced object files.
+ // CHECK: %[[RESULT:.+]] = flow.dispatch {@extern_dispatch_0::@a::@main, @extern_dispatch_0::@b::@main}[%c100, %c50](%arg0, %arg1, %arg2) {
+ // CHECK-SAME: hal.interface.bindings = [#hal.interface.binding<0, 0>, #hal.interface.binding<0, 1>]
+ // CHECK-SAME: } : (tensor<4xi32>, tensor<8xi32>, i32) -> %arg1
+ %result = hal.dispatch.extern "main"[%x, %y](%arg0, %arg1, %arg2) : (tensor<4xi32>, tensor<8xi32>, i32) -> %arg1
+ // Translates the workload (%x and %y captured above) into an XYZ workgroup
+ // count, optionally using device information.
+ count(%device: !hal.device, %x_capture: index, %y_capture: index) -> (index, index, index) {
+ // Shows how device queries can be used when computing the workgroup count.
+ // The device is the one used at runtime.
+ %ok, %z_i32 = hal.device.query<%device : !hal.device> key("some" :: "value") : i1, i32
+ %z = arith.index_cast %z_i32 : i32 to index
+ hal.return %x_capture, %y_capture, %z : index, index, index
+ }
+ // Must match the external definition.
+ layout(#hal.pipeline.layout<push_constants = 1, sets = [
+ <0, bindings = [
+ <0, storage_buffer, ReadOnly>,
+ <1, storage_buffer>
+ ]>
+ ]>)
+ // Optional, automatically inferred if omitted.
+ bindings([
+ #hal.interface.binding<0, 0>,
+ #hal.interface.binding<0, 1>
+ ])
+ // Can have object references for multiple targets or configurations.
+ objects({
+ #hal.executable.target<"llvm-cpu", "a"> ordinal(100) = [#hal.executable.object<{path = "a.o"}>],
+ #hal.executable.target<"llvm-cpu", "b"> if(%device: !hal.device) -> i1 {
+ %ok, %z_i32 = hal.device.query<%device : !hal.device> key("some" :: "feature") : i1, i32
+ hal.return %ok : i1
+ } ordinal(200) = [#hal.executable.object<{path = "b.o"}>]
+ })
+ // CHECK: return %[[RESULT]]
+ return %result : tensor<8xi32>
+}
diff --git a/compiler/src/iree/compiler/Dialect/Flow/Transforms/test/outline_dispatch_regions.mlir b/compiler/src/iree/compiler/Dialect/Flow/Transforms/test/outline_dispatch_regions.mlir
index 03462c5..67e90e1 100644
--- a/compiler/src/iree/compiler/Dialect/Flow/Transforms/test/outline_dispatch_regions.mlir
+++ b/compiler/src/iree/compiler/Dialect/Flow/Transforms/test/outline_dispatch_regions.mlir
@@ -176,58 +176,3 @@
}
return %0 : tensor<4xi32>
}
-
-// -----
-
-// CHECK: flow.executable private @dispatchExtern_dispatch_0
-// CHECK-NEXT: flow.executable.export public @main
-// CHECK-SAME: workgroups(%arg0: !hal.device, %arg1: index, %arg2: index) -> (index, index, index) {
-// CHECK-NEXT: %ok, %value = hal.device.query<%arg0 : !hal.device> key("some" :: "value") : i1, i32
-// CHECK-NEXT: %0 = arith.index_cast %value : i32 to index
-// CHECK-NEXT: hal.return %arg1, %arg2, %0 : index, index, index
-// CHECK-NEXT: } attributes {
-// CHECK-SAME: hal.interface.layout = #hal.pipeline.layout<push_constants = 1, sets = [<0, bindings = [<0, storage_buffer, ReadOnly>, <1, storage_buffer>]>]>
-// CHECK-SAME: }
-
-// Demonstrates the full functionality of an extern dispatch op.
-// Note that some fields are optional.
-
-// CHECK-LABEL: func.func @dispatchExtern
-func.func @dispatchExtern(%arg0: tensor<4xi32>, %arg1: tensor<8xi32>, %arg2: i32) -> tensor<8xi32> {
- %x = arith.constant 100 : index
- %y = arith.constant 50 : index
- // Dispatch workgroups to the externally defined function "main" in the
- // referenced object files.
- // CHECK: %[[RESULT:.+]] = flow.dispatch @dispatchExtern_dispatch_0::@main[%c100, %c50](%arg0, %arg1, %arg2) {
- // CHECK-SAME: hal.interface.bindings = [#hal.interface.binding<0, 0>, #hal.interface.binding<0, 1>]
- // CHECK-SAME: } : (tensor<4xi32>, tensor<8xi32>, i32) -> %arg1
- %result = hal.dispatch.extern "main"[%x, %y](%arg0, %arg1, %arg2) : (tensor<4xi32>, tensor<8xi32>, i32) -> %arg1
- // Translates the workload (%x and %y captured above) into an XYZ workgroup
- // count, optionally using device information.
- count(%device: !hal.device, %x_capture: index, %y_capture: index) -> (index, index, index) {
- // Shows how device queries can be used when computing the workgroup count.
- // The device is the one used at runtime.
- %ok, %z_i32 = hal.device.query<%device : !hal.device> key("some" :: "value") : i1, i32
- %z = arith.index_cast %z_i32 : i32 to index
- hal.return %x_capture, %y_capture, %z : index, index, index
- }
- // Must match the external definition.
- layout(#hal.pipeline.layout<push_constants = 1, sets = [
- <0, bindings = [
- <0, storage_buffer, ReadOnly>,
- <1, storage_buffer>
- ]>
- ]>)
- // Optional, automatically inferred if omitted.
- bindings([
- #hal.interface.binding<0, 0>,
- #hal.interface.binding<0, 1>
- ])
- // Can have object references for multiple targets or configurations.
- objects(#hal.executable.objects<{
- #hal.executable.target<"llvm-cpu", "a"> = [#hal.executable.object<{path = "a.o"}>],
- #hal.executable.target<"llvm-cpu", "b"> = [#hal.executable.object<{path = "b.o"}>]
- }>)
- // CHECK: return %[[RESULT]]
- return %result : tensor<8xi32>
-}
diff --git a/compiler/src/iree/compiler/Dialect/HAL/IR/HALBase.td b/compiler/src/iree/compiler/Dialect/HAL/IR/HALBase.td
index 3ecd4bc..1d42aa0 100644
--- a/compiler/src/iree/compiler/Dialect/HAL/IR/HALBase.td
+++ b/compiler/src/iree/compiler/Dialect/HAL/IR/HALBase.td
@@ -493,6 +493,7 @@
]>;
def HAL_OrdinalAttr : Util_IndexAttrBase<"size_t">;
+def HAL_OrdinalArrayAttr : TypedArrayAttrBase<HAL_OrdinalAttr, "Array of index ordinal attributes">;
def HAL_ExecutableDataAttr : SignlessIntElementsAttr<8>;
diff --git a/compiler/src/iree/compiler/Dialect/HAL/IR/HALOps.cpp b/compiler/src/iree/compiler/Dialect/HAL/IR/HALOps.cpp
index 77c2125..b9655c3 100644
--- a/compiler/src/iree/compiler/Dialect/HAL/IR/HALOps.cpp
+++ b/compiler/src/iree/compiler/Dialect/HAL/IR/HALOps.cpp
@@ -194,6 +194,85 @@
/*printBlockTerminators=*/true);
}
+static ParseResult parseTargetConditionObjects(
+ OpAsmParser &parser, ArrayAttr &targetsAttr, ArrayAttr &targetOrdinalsAttr,
+ ArrayAttr &targetObjectsAttr,
+ SmallVector<std::unique_ptr<Region>, 2> &targetRegions) {
+ SmallVector<Attribute> targetsAttrs;
+ SmallVector<Attribute> targetOrdinalsAttrs;
+ SmallVector<Attribute> targetObjectsAttrs;
+ do {
+ // #hal.executable.target<...>
+ Attribute targetAttr;
+ if (failed(parser.parseAttribute(targetAttr)))
+ return failure();
+ targetsAttrs.push_back(targetAttr);
+
+ // if(...) -> i1 { ... }
+ auto region = std::make_unique<Region>();
+ if (succeeded(parser.parseOptionalKeyword("if"))) {
+ if (failed(parseTargetConditionRegion(parser, *region)))
+ return failure();
+ }
+ targetRegions.push_back(std::move(region));
+
+ // ordinal(#)
+ Attribute targetOrdinalAttr;
+ if (failed(parser.parseKeyword("ordinal")) ||
+ failed(parser.parseLParen()) ||
+ failed(parser.parseAttribute(targetOrdinalAttr,
+ IndexType::get(parser.getContext()))) ||
+ failed(parser.parseRParen()))
+ return failure();
+ targetOrdinalsAttrs.push_back(targetOrdinalAttr);
+
+ // = [#hal.executable.object<...>, ...]
+ ArrayAttr targetObjectsAttr;
+ if (failed(parser.parseEqual()) ||
+ failed(parser.parseAttribute(targetObjectsAttr)))
+ return failure();
+ targetObjectsAttrs.push_back(targetObjectsAttr);
+ } while (succeeded(parser.parseOptionalComma()));
+ targetsAttr = ArrayAttr::get(parser.getContext(), targetsAttrs);
+ targetOrdinalsAttr = ArrayAttr::get(parser.getContext(), targetOrdinalsAttrs);
+ targetObjectsAttr = ArrayAttr::get(parser.getContext(), targetObjectsAttrs);
+ return success();
+}
+
+static void printTargetConditionObjects(OpAsmPrinter &p, Operation *op,
+ ArrayAttr targetsAttr,
+ ArrayAttr targetOrdinalsAttr,
+ ArrayAttr targetObjectsAttr,
+ MutableArrayRef<Region> targetRegions) {
+ p.increaseIndent();
+ p.printNewline();
+
+ llvm::interleave(
+ llvm::zip_equal(targetsAttr, targetOrdinalsAttr, targetObjectsAttr,
+ targetRegions),
+ [&](auto it) {
+ auto &[targetAttr, targetOrdinalAttr, targetObjectsAttr, targetRegion] =
+ it;
+ p.printAttribute(targetAttr);
+ if (!targetRegion.empty()) {
+ p << " if";
+ printTargetConditionRegion(p, op, targetRegion);
+ }
+ p << " ordinal(";
+ p.printAttributeWithoutType(targetOrdinalAttr);
+ p << ")";
+ p << " = ";
+ p.printAttribute(targetObjectsAttr);
+ },
+ [&]() {
+ p << ",";
+ p.printNewline();
+ });
+
+ p.decreaseIndent();
+ p.printNewline();
+}
+
//===----------------------------------------------------------------------===//
// custom<WorkgroupCountRegion>($body)
//===----------------------------------------------------------------------===//
@@ -467,6 +546,7 @@
ValueRange resultDims, ValueRange arguments,
ValueRange argumentDims,
ArrayRef<int64_t> tiedOperands,
+ IREE::HAL::ExecutableObjectsAttr targetObjects,
ArrayRef<NamedAttribute> attributes) {
state.addTypes(resultTypes);
state.addOperands(workload);
@@ -477,6 +557,8 @@
state.attributes.erase(IREE::Util::TiedOpInterface::getStorageAttrName());
state.addAttribute(IREE::Util::TiedOpInterface::getStorageAttrName(),
builder.getIndexArrayAttr(tiedOperands));
+ state.addAttribute("targets", targetObjects.getTargets());
+ state.addAttribute("target_objects", targetObjects.getTargetObjects());
state.attributes.erase(getOperandSegmentSizeAttr());
state.addAttribute(getOperandSegmentSizeAttr(),
builder.getDenseI32ArrayAttr({
@@ -488,6 +570,10 @@
// NOTE: workgroup count region is empty; callers are expected to populate it.
state.addRegion();
+
+ // Add one empty region per target.
+ for (size_t i = 0; i < targetObjects.getTargets().size(); ++i)
+ state.addRegion();
}
// Verifies that |dynamicDims| contains the appropriate number of dims for all
@@ -569,6 +655,19 @@
return failure();
}
+ if (getTargets().size() != getTargetObjects().size()) {
+ return op->emitOpError() << "target and objects arrays must match";
+ }
+ if (getTargets().size() != getTargetRegions().size()) {
+ return op->emitOpError()
+ << "target and condition regions must match (but they may be empty)";
+ }
+ for (auto &targetRegion : getTargetRegions()) {
+ if (failed(verifyTargetConditionRegion(op, targetRegion))) {
+ return failure();
+ }
+ }
+
return success();
}
diff --git a/compiler/src/iree/compiler/Dialect/HAL/IR/HALOps.td b/compiler/src/iree/compiler/Dialect/HAL/IR/HALOps.td
index 1d65042..50414a6 100644
--- a/compiler/src/iree/compiler/Dialect/HAL/IR/HALOps.td
+++ b/compiler/src/iree/compiler/Dialect/HAL/IR/HALOps.td
@@ -362,6 +362,13 @@
E.g., in the above example, `-> %0` would tie the first argument to the
result. In that case, there would be no separate block argument for the
result.
+
+ Objects for multiple targets can be specified and the ones used are selected
+ based on their target and an optional condition region that returns true if
+ the variant is valid for use on the provided runtime `!hal.device`. If no
+ variants within an executable are valid then loading will fail at runtime.
+ If multiple variants are valid the first valid one found will be loaded and
+ used for execution.
}];
let arguments = (ins
@@ -371,7 +378,9 @@
HAL_ShapeDynamicDims:$argument_dims,
HAL_ShapeDynamicDims:$result_dims,
HAL_PipelineLayoutAttr:$layout,
- HAL_ExecutableObjectsAttr:$objects,
+ ArrayAttr:$targets,
+ HAL_OrdinalArrayAttr:$target_ordinals,
+ ArrayAttr:$target_objects,
OptionalAttr<HAL_WorkgroupSizeAttr>:$workgroup_size,
OptionalAttr<HAL_SubgroupSizeAttr>:$subgroup_size,
OptionalAttr<IndexAttr>:$workgroup_local_memory,
@@ -383,7 +392,8 @@
);
let regions = (region
- AnyRegion:$workgroup_count
+ AnyRegion:$workgroup_count,
+ VariadicRegion<AnyRegion>:$target_regions
);
let assemblyFormat = [{
@@ -397,7 +407,10 @@
`count` `` custom<WorkgroupCountRegion>($workgroup_count)
`layout` `(` $layout `)`
(`bindings` `(` $bindings^ `)`)?
- `objects` `(` $objects `)`
+ `objects` `(` `{` custom<TargetConditionObjects>($targets,
+ $target_ordinals,
+ $target_objects,
+ $target_regions) `}` `)`
attr-dict-with-keyword
}];
@@ -408,6 +421,7 @@
"TypeRange":$resultTypes, "ValueRange":$resultDims,
"ValueRange":$arguments, "ValueRange":$argumentDims,
"ArrayRef<int64_t>":$tiedOperands,
+ "IREE::HAL::ExecutableObjectsAttr":$targetObjects,
CArg<"ArrayRef<NamedAttribute>", "{}">:$attributes)>,
];
@@ -429,6 +443,10 @@
ValueRange getResultDynamicDims(unsigned idx) {
return IREE::Util::findVariadicDynamicDims(idx, getResults(), getResultDims());
}
+
+ IREE::HAL::ExecutableObjectsAttr getObjectsAttr() {
+ return IREE::HAL::ExecutableObjectsAttr::get(getContext(), getTargets(), getTargetObjects());
+ }
}];
let hasVerifier = 1;
diff --git a/compiler/src/iree/compiler/Dialect/HAL/IR/test/tensor_ops.mlir b/compiler/src/iree/compiler/Dialect/HAL/IR/test/tensor_ops.mlir
index 42f1d9e..45e1c6f 100644
--- a/compiler/src/iree/compiler/Dialect/HAL/IR/test/tensor_ops.mlir
+++ b/compiler/src/iree/compiler/Dialect/HAL/IR/test/tensor_ops.mlir
@@ -1,4 +1,4 @@
-// RUN: iree-opt --split-input-file %s | iree-opt --split-input-file | FileCheck %s
+// RUN: iree-opt --split-input-file --mlir-print-local-scope %s | iree-opt --split-input-file --mlir-print-local-scope | FileCheck %s
// CHECK-LABEL: @tensorImportStatic
func.func @tensorImportStatic(%arg0: !hal.buffer_view) -> tensor<5xi32> {
@@ -59,21 +59,28 @@
// CHECK-LABEL: func.func @dispatchExtern
func.func @dispatchExtern(%arg0: tensor<4xi32>, %arg1: tensor<8xi32>, %arg2: i32) -> tensor<8xi32> {
- %x = arith.constant 100 : index
- %y = arith.constant 50 : index
+ // CHECK-DAG: %[[WORKLOAD_X:.+]] = arith.constant 100
+ %workload_x = arith.constant 100 : index
+ // CHECK-DAG: %[[WORKLOAD_Y:.+]] = arith.constant 50
+ %workload_y = arith.constant 50 : index
+
// Dispatch workgroups to the externally defined function "main" in the
- // referenced object files.
- %0 = hal.dispatch.extern "main"[%x, %y](%arg0, %arg1, %arg2) : (tensor<4xi32>, tensor<8xi32>, i32) -> %arg1
+ // referenced object files with the ordinal specified per object group.
+ // CHECK: %[[RESULT:.+]] = hal.dispatch.extern "main"[%[[WORKLOAD_X]], %[[WORKLOAD_Y]]](%arg0, %arg1, %arg2) : (tensor<4xi32>, tensor<8xi32>, i32) -> %arg1
+ %0 = hal.dispatch.extern "main"[%workload_x, %workload_y](%arg0, %arg1, %arg2) : (tensor<4xi32>, tensor<8xi32>, i32) -> %arg1
// Translates the workload (%x and %y captured above) into an XYZ workgroup
// count, optionally using device information.
+ // CHECK: count(%[[DEVICE:.+]]: !hal.device, %[[X_CAPTURE:.+]]: index, %[[Y_CAPTURE:.+]]: index) -> (index, index, index) {
count(%device: !hal.device, %x_capture: index, %y_capture: index) -> (index, index, index) {
// Shows how device queries can be used when computing the workgroup count.
// The device is the one used at runtime.
+ // CHECK: = hal.device.query<%[[DEVICE]] : !hal.device>
%ok, %z_i32 = hal.device.query<%device : !hal.device> key("some" :: "value") : i1, i32
%z = arith.index_cast %z_i32 : i32 to index
hal.return %x_capture, %y_capture, %z : index, index, index
}
// Must match the external definition.
+ // CHECK: layout(<push_constants = 1, sets =
layout(#hal.pipeline.layout<push_constants = 1, sets = [
<0, bindings = [
<0, storage_buffer, ReadOnly>,
@@ -81,14 +88,31 @@
]>
]>)
// Optional, automatically inferred if omitted.
+ // CHECK: bindings([#hal.interface.binding<0, 0>, #hal.interface.binding<0, 1>])
bindings([
#hal.interface.binding<0, 0>,
#hal.interface.binding<0, 1>
])
// Can have object references for multiple targets or configurations.
- objects(#hal.executable.objects<{
- #hal.executable.target<"llvm-cpu", "a"> = [#hal.executable.object<{path = "a.o"}>],
- #hal.executable.target<"llvm-cpu", "b"> = [#hal.executable.object<{path = "b.o"}>]
- }>)
+ // CHECK: objects({
+ objects({
+ // CHECK: #hal.executable.target<"llvm-cpu", "a"> ordinal(100) = [#hal.executable.object<{path = "a.o"}>]
+ #hal.executable.target<"llvm-cpu", "a"> ordinal(100) = [#hal.executable.object<{path = "a.o"}>],
+ // CHECK: #hal.executable.target<"llvm-cpu", "b"> if(%[[B_DEVICE:.+]]: !hal.device) -> i1 {
+ #hal.executable.target<"llvm-cpu", "b"> if(%device: !hal.device) -> i1 {
+ // CHECK: = hal.device.query<%[[B_DEVICE]] : !hal.device>
+ %ok, %z_i32 = hal.device.query<%device : !hal.device> key("some" :: "feature_b") : i1, i32
+ hal.return %ok : i1
+ // CHECK: } ordinal(200) = [#hal.executable.object<{path = "b.o"}>]
+ } ordinal(200) = [#hal.executable.object<{path = "b.o"}>],
+ // CHECK: #hal.executable.target<"llvm-cpu", "c"> if(%[[C_DEVICE:.+]]: !hal.device) -> i1 {
+ #hal.executable.target<"llvm-cpu", "c"> if(%device: !hal.device) -> i1 {
+ // CHECK: = hal.device.query<%[[C_DEVICE]] : !hal.device>
+ %ok, %z_i32 = hal.device.query<%device : !hal.device> key("some" :: "feature_c") : i1, i32
+ hal.return %ok : i1
+ // CHECK: } ordinal(300) = [#hal.executable.object<{path = "c.o"}>]
+ } ordinal(300) = [#hal.executable.object<{path = "c.o"}>]
+ })
+ // CHECK: return %[[RESULT]]
return %0 : tensor<8xi32>
}
diff --git a/samples/custom_dispatch/vulkan/shaders/example_inline.mlir b/samples/custom_dispatch/vulkan/shaders/example_inline.mlir
index 980df87..5979d76 100644
--- a/samples/custom_dispatch/vulkan/shaders/example_inline.mlir
+++ b/samples/custom_dispatch/vulkan/shaders/example_inline.mlir
@@ -52,13 +52,15 @@
%dim_i32 = arith.index_cast %dim : index to i32
// Dispatch a basic `ret = lhs * rhs` shader.
+ // Note that not all backends use names or the names are derived from
+ // ordinals so we include that (`:ordinal`).
%0 = hal.dispatch.extern "main"[%dim](%dim_i32, %arg0, %arg1) : (i32, tensor<?xf32>{%dim}, tensor<?xf32>{%dim}) -> tensor<?xf32>{%dim}
+ // This host function is used to compute the XYZ workgroup count
+ // dispatched at runtime. It can query the %device for capabilities
+ // and limits (shared memory size, etc). The other arguments are the
+ // values passed in the dispatch operation (usually things like root
+ // output op tensor dimensions and other abstract values).
count(%device: !hal.device, %workload: index) -> (index, index, index) {
- // This host function is used to compute the XYZ workgroup count
- // dispatched at runtime. It can query the %device for capabilities
- // and limits (shared memory size, etc). The other arguments are the
- // values passed in the dispatch operation (usually things like root
- // output op tensor dimensions and other abstract values).
%x = affine.apply affine_map<()[s0] -> (s0 ceildiv 64)>()[%workload]
%c1 = arith.constant 1 : index
hal.return %x, %c1, %c1 : index, index, index
@@ -90,8 +92,8 @@
// keys can be generic. This allows for an object file to be linked in based
// only on the target triple while allowing for more specialized ones
// requiring certain CPU features to be only included when building those.
- objects(#hal.executable.objects<{
- #spirv_target = [
+ objects({
+ #spirv_target ordinal(0) = [
#hal.executable.object<{
// Referencing a file path on disk but could also have the data
// embedded in order to make the MLIR file hermetic/portable across
@@ -103,7 +105,7 @@
path = "samples/custom_dispatch/vulkan/shaders/simple_mul.spv"
}>
]
- }>)
+ })
// Code gen some other ops - these will interleave with the hand-authored
// ones but naturally won't be able to fuse with them.