Adding selection condition to hal.executable.variant. (#15284)
This allows variants to declare host logic that determines whether
the variant should be selected for loading. When multiple variants are
available, their declared conditions will be evaluated in op order along
with the existing executable format match.
Unfortunately, the MLIR SymbolTable trait disallows multiple regions on
any op holding it, so a new `hal.executable.condition` region op was
added that may optionally be present on any `hal.executable.variant`.
Ideally we clean this up and make it an optional region, but that'll need
relaxing of upstream assertions like
https://sourcegraph.com/github.com/llvm/llvm-project/-/blob/mlir/lib/IR/SymbolTable.cpp?L122-123
(ideally either treating region 0 as the symbol table on ops or having
an interface override for selecting the region à la `getCallableRegion`,
such as `getSymbolTableRegion`).
This removes the `hal.device.switch` op in favor of `scf.index_switch`.
When we start running const expr hoisting during the HAL pipeline, this
should allow variant selection to be completely hoisted to
initialization time (or at least memoized per device). There's a decent
amount of low-hanging future work on optimizing the ranking/selection
and improving `scf.index_switch` hoisting/canonicalization to make
things better.
diff --git a/compiler/plugins/target/CUDA/CUDATarget.cpp b/compiler/plugins/target/CUDA/CUDATarget.cpp
index 8d853fc..6054688 100644
--- a/compiler/plugins/target/CUDA/CUDATarget.cpp
+++ b/compiler/plugins/target/CUDA/CUDATarget.cpp
@@ -431,7 +431,7 @@
// Collect all the entry point parameters.
SmallVector<std::array<int32_t, 3>> workgroupSizes;
SmallVector<uint32_t> workgroupLocalMemories;
- for (auto exportOp : variantOp.getOps<IREE::HAL::ExecutableExportOp>()) {
+ for (auto exportOp : variantOp.getExportOps()) {
std::array<int32_t, 3> workgroupSize;
if (std::optional<ArrayAttr> workgroupSizeAttr =
exportOp.getWorkgroupSize()) {
@@ -472,7 +472,7 @@
// these to match the names in their kernels. We don't support any kind of
// mangling and if the user was silly enough to rely on nvcc C++ mangling
// they'll have to figure that out.
- for (auto exportOp : variantOp.getOps<IREE::HAL::ExecutableExportOp>()) {
+ for (auto exportOp : variantOp.getExportOps()) {
entryPointNames.emplace_back(exportOp.getSymName());
}
@@ -503,8 +503,7 @@
}
for (auto [exportOp, workgroupSize] :
- llvm::zip_equal(variantOp.getOps<IREE::HAL::ExecutableExportOp>(),
- workgroupSizes)) {
+ llvm::zip_equal(variantOp.getExportOps(), workgroupSizes)) {
auto *llvmFunc = llvmModule->getFunction(exportOp.getName());
if (llvmFunc->isDeclaration())
continue;
diff --git a/compiler/src/iree/compiler/Codegen/LLVMCPU/BUILD.bazel b/compiler/src/iree/compiler/Codegen/LLVMCPU/BUILD.bazel
index 900d220..7b1ab5f 100644
--- a/compiler/src/iree/compiler/Codegen/LLVMCPU/BUILD.bazel
+++ b/compiler/src/iree/compiler/Codegen/LLVMCPU/BUILD.bazel
@@ -96,7 +96,6 @@
"//compiler/src/iree/compiler/Dialect/Flow/IR",
"//compiler/src/iree/compiler/Dialect/HAL/IR",
"//compiler/src/iree/compiler/Dialect/HAL/IR:HALDialect",
- "//compiler/src/iree/compiler/Dialect/HAL/Utils",
"//compiler/src/iree/compiler/Dialect/Util/IR",
"//compiler/src/iree/compiler/Dialect/Util/Transforms",
"//compiler/src/iree/compiler/Utils",
diff --git a/compiler/src/iree/compiler/Codegen/LLVMCPU/CMakeLists.txt b/compiler/src/iree/compiler/Codegen/LLVMCPU/CMakeLists.txt
index 8d5feaf..d798720 100644
--- a/compiler/src/iree/compiler/Codegen/LLVMCPU/CMakeLists.txt
+++ b/compiler/src/iree/compiler/Codegen/LLVMCPU/CMakeLists.txt
@@ -146,7 +146,6 @@
iree::compiler::Dialect::Flow::IR
iree::compiler::Dialect::HAL::IR
iree::compiler::Dialect::HAL::IR::HALDialect
- iree::compiler::Dialect::HAL::Utils
iree::compiler::Dialect::Util::IR
iree::compiler::Dialect::Util::Transforms
iree::compiler::Utils
diff --git a/compiler/src/iree/compiler/Codegen/LLVMGPU/ConvertToLLVM.cpp b/compiler/src/iree/compiler/Codegen/LLVMGPU/ConvertToLLVM.cpp
index 0748dc1..589d9dc 100644
--- a/compiler/src/iree/compiler/Codegen/LLVMGPU/ConvertToLLVM.cpp
+++ b/compiler/src/iree/compiler/Codegen/LLVMGPU/ConvertToLLVM.cpp
@@ -85,7 +85,7 @@
// Add the amount of shared memory required as an attribute.
auto variantOp = moduleOp->getParentOfType<IREE::HAL::ExecutableVariantOp>();
if (variantOp != nullptr) {
- for (auto exportOp : variantOp.getOps<IREE::HAL::ExecutableExportOp>()) {
+ for (auto exportOp : variantOp.getExportOps()) {
exportOp->setAttr(exportOp.getWorkgroupLocalMemoryAttrName(),
builder.getIndexAttr(numberOfBytes));
}
diff --git a/compiler/src/iree/compiler/Codegen/SPIRV/ConvertToSPIRVPass.cpp b/compiler/src/iree/compiler/Codegen/SPIRV/ConvertToSPIRVPass.cpp
index 83de13f..3ba39a1 100644
--- a/compiler/src/iree/compiler/Codegen/SPIRV/ConvertToSPIRVPass.cpp
+++ b/compiler/src/iree/compiler/Codegen/SPIRV/ConvertToSPIRVPass.cpp
@@ -179,8 +179,7 @@
// TODO(#1519): this conversion should look up the entry point information
// to get the total push constant count.
auto variantOp = loadOp->getParentOfType<IREE::HAL::ExecutableVariantOp>();
- auto exportOps =
- llvm::to_vector<1>(variantOp.getOps<IREE::HAL::ExecutableExportOp>());
+ auto exportOps = llvm::to_vector<1>(variantOp.getExportOps());
assert(exportOps.size() == 1);
auto layoutAttr = exportOps.front().getLayout();
diff --git a/compiler/src/iree/compiler/Codegen/Utils/LinkingUtils.cpp b/compiler/src/iree/compiler/Codegen/Utils/LinkingUtils.cpp
index 824d8c2..97d6124 100644
--- a/compiler/src/iree/compiler/Codegen/Utils/LinkingUtils.cpp
+++ b/compiler/src/iree/compiler/Codegen/Utils/LinkingUtils.cpp
@@ -236,14 +236,14 @@
// Move any constant blocks that need to be preserved for future host
// translation. There may be duplicates provided but they'll be cleaned
// up in future passes.
- for (auto constantBlockOp : llvm::make_early_inc_range(
- variantOp.getOps<IREE::HAL::ExecutableConstantBlockOp>())) {
+ for (auto constantBlockOp :
+ llvm::make_early_inc_range(variantOp.getConstantBlockOps())) {
constantBlockOp->moveBefore(&*linkedTargetBuilder.getInsertionPoint());
}
// Clone export ops and queue remapping ordinals and updating
// symbol refs.
- for (auto exportOp : variantOp.getOps<IREE::HAL::ExecutableExportOp>()) {
+ for (auto exportOp : variantOp.getExportOps()) {
auto newExportOp =
linkedTargetBuilder.create<IREE::HAL::ExecutableExportOp>(
exportOp.getLoc(), exportOp.getSymNameAttr(),
@@ -290,7 +290,7 @@
replaceEntryPointUses(moduleOp, symbolReplacements);
// Remove if we didn't add anything.
- if (linkedTargetOp.getOps<IREE::HAL::ExecutableExportOp>().empty()) {
+ if (linkedTargetOp.getExportOps().empty()) {
linkedTargetOp.erase();
linkedExecutableOp.erase();
}
diff --git a/compiler/src/iree/compiler/Codegen/Utils/Utils.cpp b/compiler/src/iree/compiler/Codegen/Utils/Utils.cpp
index c443639..bbfc47c 100644
--- a/compiler/src/iree/compiler/Codegen/Utils/Utils.cpp
+++ b/compiler/src/iree/compiler/Codegen/Utils/Utils.cpp
@@ -39,7 +39,7 @@
if (!variantOp)
return failure();
- for (auto op : variantOp.getOps<IREE::HAL::ExecutableExportOp>()) {
+ for (auto op : variantOp.getExportOps()) {
if (op.getSymName() == funcOp.getName()) {
return op;
}
@@ -66,7 +66,7 @@
getAllEntryPoints(ModuleOp module) {
auto variantOp = module->getParentOfType<IREE::HAL::ExecutableVariantOp>();
llvm::StringMap<IREE::HAL::ExecutableExportOp> exportOps;
- for (auto op : variantOp.getOps<IREE::HAL::ExecutableExportOp>()) {
+ for (auto op : variantOp.getExportOps()) {
exportOps[op.getSymName()] = op;
}
return exportOps;
diff --git a/compiler/src/iree/compiler/Dialect/Flow/Transforms/test/outline_dispatch_regions.mlir b/compiler/src/iree/compiler/Dialect/Flow/Transforms/test/outline_dispatch_regions.mlir
index 56f05a5..03462c5 100644
--- a/compiler/src/iree/compiler/Dialect/Flow/Transforms/test/outline_dispatch_regions.mlir
+++ b/compiler/src/iree/compiler/Dialect/Flow/Transforms/test/outline_dispatch_regions.mlir
@@ -202,6 +202,15 @@
// CHECK-SAME: hal.interface.bindings = [#hal.interface.binding<0, 0>, #hal.interface.binding<0, 1>]
// CHECK-SAME: } : (tensor<4xi32>, tensor<8xi32>, i32) -> %arg1
%result = hal.dispatch.extern "main"[%x, %y](%arg0, %arg1, %arg2) : (tensor<4xi32>, tensor<8xi32>, i32) -> %arg1
+ // Translates the workload (%x and %y captured above) into an XYZ workgroup
+ // count, optionally using device information.
+ count(%device: !hal.device, %x_capture: index, %y_capture: index) -> (index, index, index) {
+ // Shows how device queries can be used when computing the workgroup count.
+ // The device is the one used at runtime.
+ %ok, %z_i32 = hal.device.query<%device : !hal.device> key("some" :: "value") : i1, i32
+ %z = arith.index_cast %z_i32 : i32 to index
+ hal.return %x_capture, %y_capture, %z : index, index, index
+ }
// Must match the external definition.
layout(#hal.pipeline.layout<push_constants = 1, sets = [
<0, bindings = [
@@ -219,15 +228,6 @@
#hal.executable.target<"llvm-cpu", "a"> = [#hal.executable.object<{path = "a.o"}>],
#hal.executable.target<"llvm-cpu", "b"> = [#hal.executable.object<{path = "b.o"}>]
}>)
- // Translates the workload (%x and %y captured above) into an XYZ workgroup
- // count, optionally using device information.
- count(%device: !hal.device, %x_capture: index, %y_capture: index) -> (index, index, index) {
- // Shows how device queries can be used when computing the workgroup count.
- // The device is the one used at runtime.
- %ok, %z_i32 = hal.device.query<%device : !hal.device> key("some" :: "value") : i1, i32
- %z = arith.index_cast %z_i32 : i32 to index
- hal.return %x_capture, %y_capture, %z : index, index, index
- }
// CHECK: return %[[RESULT]]
return %result : tensor<8xi32>
}
diff --git a/compiler/src/iree/compiler/Dialect/HAL/Conversion/BUILD.bazel b/compiler/src/iree/compiler/Dialect/HAL/Conversion/BUILD.bazel
index 292353f..d87a54f 100644
--- a/compiler/src/iree/compiler/Dialect/HAL/Conversion/BUILD.bazel
+++ b/compiler/src/iree/compiler/Dialect/HAL/Conversion/BUILD.bazel
@@ -25,7 +25,6 @@
],
deps = [
"//compiler/src/iree/compiler/Dialect/HAL/IR",
- "//compiler/src/iree/compiler/Dialect/HAL/Utils",
"//compiler/src/iree/compiler/Dialect/Util/IR",
"@llvm-project//mlir:FuncDialect",
"@llvm-project//mlir:IR",
diff --git a/compiler/src/iree/compiler/Dialect/HAL/Conversion/CMakeLists.txt b/compiler/src/iree/compiler/Dialect/HAL/Conversion/CMakeLists.txt
index c452ba5..04da911 100644
--- a/compiler/src/iree/compiler/Dialect/HAL/Conversion/CMakeLists.txt
+++ b/compiler/src/iree/compiler/Dialect/HAL/Conversion/CMakeLists.txt
@@ -26,7 +26,6 @@
MLIRMemRefDialect
MLIRTransforms
iree::compiler::Dialect::HAL::IR
- iree::compiler::Dialect::HAL::Utils
iree::compiler::Dialect::Util::IR
PUBLIC
)
diff --git a/compiler/src/iree/compiler/Dialect/HAL/Conversion/HALToVM/BUILD.bazel b/compiler/src/iree/compiler/Dialect/HAL/Conversion/HALToVM/BUILD.bazel
index 9f7dba7..c630dc4 100644
--- a/compiler/src/iree/compiler/Dialect/HAL/Conversion/HALToVM/BUILD.bazel
+++ b/compiler/src/iree/compiler/Dialect/HAL/Conversion/HALToVM/BUILD.bazel
@@ -32,7 +32,6 @@
deps = [
"//compiler/src/iree/compiler/Dialect/HAL:hal_imports",
"//compiler/src/iree/compiler/Dialect/HAL/IR",
- "//compiler/src/iree/compiler/Dialect/HAL/Utils",
"//compiler/src/iree/compiler/Dialect/Util/IR",
"//compiler/src/iree/compiler/Dialect/VM/Conversion",
"//compiler/src/iree/compiler/Dialect/VM/Conversion/StandardToVM",
diff --git a/compiler/src/iree/compiler/Dialect/HAL/Conversion/HALToVM/CMakeLists.txt b/compiler/src/iree/compiler/Dialect/HAL/Conversion/HALToVM/CMakeLists.txt
index 421ef74..5a5cb8a 100644
--- a/compiler/src/iree/compiler/Dialect/HAL/Conversion/HALToVM/CMakeLists.txt
+++ b/compiler/src/iree/compiler/Dialect/HAL/Conversion/HALToVM/CMakeLists.txt
@@ -34,7 +34,6 @@
MLIRPass
MLIRTransforms
iree::compiler::Dialect::HAL::IR
- iree::compiler::Dialect::HAL::Utils
iree::compiler::Dialect::HAL::hal_imports
iree::compiler::Dialect::Util::IR
iree::compiler::Dialect::VM::Conversion
diff --git a/compiler/src/iree/compiler/Dialect/HAL/Conversion/StreamToHAL/BUILD.bazel b/compiler/src/iree/compiler/Dialect/HAL/Conversion/StreamToHAL/BUILD.bazel
index c9803ce..52653fb 100644
--- a/compiler/src/iree/compiler/Dialect/HAL/Conversion/StreamToHAL/BUILD.bazel
+++ b/compiler/src/iree/compiler/Dialect/HAL/Conversion/StreamToHAL/BUILD.bazel
@@ -25,7 +25,6 @@
"//compiler/src/iree/compiler/Dialect/HAL/IR",
"//compiler/src/iree/compiler/Dialect/HAL/IR:HALDialect",
"//compiler/src/iree/compiler/Dialect/HAL/Target",
- "//compiler/src/iree/compiler/Dialect/HAL/Utils",
"//compiler/src/iree/compiler/Dialect/Stream/IR",
"//compiler/src/iree/compiler/Dialect/Util/IR",
"@llvm-project//llvm:Support",
diff --git a/compiler/src/iree/compiler/Dialect/HAL/Conversion/StreamToHAL/CMakeLists.txt b/compiler/src/iree/compiler/Dialect/HAL/Conversion/StreamToHAL/CMakeLists.txt
index 1fe1cbe..92dd0c6 100644
--- a/compiler/src/iree/compiler/Dialect/HAL/Conversion/StreamToHAL/CMakeLists.txt
+++ b/compiler/src/iree/compiler/Dialect/HAL/Conversion/StreamToHAL/CMakeLists.txt
@@ -29,7 +29,6 @@
iree::compiler::Dialect::HAL::IR
iree::compiler::Dialect::HAL::IR::HALDialect
iree::compiler::Dialect::HAL::Target
- iree::compiler::Dialect::HAL::Utils
iree::compiler::Dialect::Stream::IR
iree::compiler::Dialect::Util::IR
PUBLIC
diff --git a/compiler/src/iree/compiler/Dialect/HAL/Conversion/StreamToHAL/Patterns.cpp b/compiler/src/iree/compiler/Dialect/HAL/Conversion/StreamToHAL/Patterns.cpp
index d0a809e..46e21e6 100644
--- a/compiler/src/iree/compiler/Dialect/HAL/Conversion/StreamToHAL/Patterns.cpp
+++ b/compiler/src/iree/compiler/Dialect/HAL/Conversion/StreamToHAL/Patterns.cpp
@@ -9,7 +9,6 @@
#include "iree/compiler/Dialect/HAL/IR/HALDialect.h"
#include "iree/compiler/Dialect/HAL/IR/HALOps.h"
#include "iree/compiler/Dialect/HAL/IR/HALTypes.h"
-#include "iree/compiler/Dialect/HAL/Utils/DeviceSwitchBuilder.h"
#include "iree/compiler/Dialect/Stream/IR/StreamDialect.h"
#include "iree/compiler/Dialect/Stream/IR/StreamOps.h"
#include "iree/compiler/Dialect/Stream/IR/StreamTypes.h"
@@ -920,16 +919,6 @@
}
};
-// Returns a hal.device.switch match expression that selects the given export.
-static Attribute
-getExportConditionAttr(IREE::HAL::ExecutableExportOp exportOp) {
- // TODO(benvanik): customizable selection logic. Today this just checks
- // whether the variant target is supported but we can also allow
- // specialization of entry points based on dispatch site parameters.
- auto variantOp = exportOp->getParentOfType<IREE::HAL::ExecutableVariantOp>();
- return variantOp.getTarget().getMatchExpression();
-}
-
struct CmdDispatchOpPattern
: public StreamConversionPattern<IREE::Stream::CmdDispatchOp> {
using StreamConversionPattern::StreamConversionPattern;
@@ -942,42 +931,65 @@
// Get the device handle we're executing against in this execution region.
// Note that this is a dynamic value: we have to treat the device as unknown
// here.
- auto device = rewriter.create<IREE::HAL::CommandBufferDeviceOp>(
+ auto deviceValue = rewriter.create<IREE::HAL::CommandBufferDeviceOp>(
loc, rewriter.getType<IREE::HAL::DeviceType>(), commandBuffer);
- // Ask each target backend to record their dispatch logic.
- IREE::HAL::DeviceSwitchRewriter switchRewriter(loc,
- /*resultTypes=*/TypeRange{},
- device, rewriter);
+ // Prepare for variant switch table by gathering the conditions selecting
+ // each variant.
+ SmallVector<int64_t> caseIndices;
+ SmallVector<std::pair<SymbolRefAttr, IREE::HAL::ExecutableExportOp>>
+ caseExportOps;
dispatchOp.forEachEntryPointAttr([&](SymbolRefAttr entryPointAttr) {
// NOTE: slow lookup!
auto exportOp =
SymbolTable::lookupNearestSymbolFrom<IREE::HAL::ExecutableExportOp>(
dispatchOp, entryPointAttr);
assert(exportOp && "dispatch target export not found");
+ caseIndices.push_back(caseIndices.size());
+ caseExportOps.push_back(std::make_pair(entryPointAttr, exportOp));
+ });
- // Setup the case condition for the entry point.
- auto *caseRegion =
- switchRewriter.addConditionRegion(getExportConditionAttr(exportOp));
- auto &entryBlock = caseRegion->front();
- auto caseBuilder = OpBuilder::atBlockBegin(&entryBlock);
+ // Select the variant index.
+ Value selectedIndex = buildIfElseTree(
+ loc, caseExportOps.size(),
+ [&](Location loc, size_t i, OpBuilder &builder) {
+ auto exportOp = caseExportOps[i].second;
+ auto variantOp =
+ exportOp->getParentOfType<IREE::HAL::ExecutableVariantOp>();
+ return variantOp.buildCondition(deviceValue, rewriter);
+ },
+ rewriter);
+
+ // Allow each variant to define how it is dispatched.
+ auto switchOp = rewriter.replaceOpWithNewOp<scf::IndexSwitchOp>(
+ dispatchOp, TypeRange{}, selectedIndex, caseIndices,
+ caseIndices.size());
+ for (size_t i = 0; i < caseExportOps.size(); ++i) {
+ auto entryPointAttr = caseExportOps[i].first;
+ auto exportOp = caseExportOps[i].second;
+ auto &caseBlock = switchOp.getCaseRegions()[i].emplaceBlock();
+ auto caseBuilder = OpBuilder::atBlockBegin(&caseBlock);
// Record push constants and buffer bindings.
- recordParameters(loc, device, commandBuffer, dispatchOp, adaptor,
+ recordParameters(loc, deviceValue, commandBuffer, dispatchOp, adaptor,
exportOp.getLayout(), caseBuilder);
// Dispatch with a target-specific workgroup count.
auto caseWorkgroupCount = exportOp.calculateWorkgroupCount(
- loc, device, adaptor.getWorkload(), caseBuilder);
+ loc, deviceValue, adaptor.getWorkload(), caseBuilder);
caseBuilder.create<IREE::HAL::CommandBufferDispatchSymbolOp>(
loc, commandBuffer, entryPointAttr, caseWorkgroupCount[0],
caseWorkgroupCount[1], caseWorkgroupCount[2]);
- caseBuilder.create<IREE::HAL::ReturnOp>(loc);
- });
- switchRewriter.build();
+ caseBuilder.create<scf::YieldOp>(loc);
+ }
- rewriter.eraseOp(dispatchOp);
+ // Fallback for no available variant. Today we just no-op as executable
+ // loading should have already failed.
+ auto &defaultBlock = switchOp.getDefaultRegion().emplaceBlock();
+ auto defaultBuilder = OpBuilder::atBlockBegin(&defaultBlock);
+ defaultBuilder.create<scf::YieldOp>(loc);
+
return success();
}
diff --git a/compiler/src/iree/compiler/Dialect/HAL/Conversion/StreamToHAL/test/cmd_ops.mlir b/compiler/src/iree/compiler/Dialect/HAL/Conversion/StreamToHAL/test/cmd_ops.mlir
index 23eac6d..d9e6873 100644
--- a/compiler/src/iree/compiler/Dialect/HAL/Conversion/StreamToHAL/test/cmd_ops.mlir
+++ b/compiler/src/iree/compiler/Dialect/HAL/Conversion/StreamToHAL/test/cmd_ops.mlir
@@ -164,9 +164,10 @@
// -----
-#executable_target_embedded_elf_x86_64 = #hal.executable.target<"llvm-cpu", "embedded-elf-x86_64">
+#executable_target_aarch64 = #hal.executable.target<"llvm-cpu", "embedded-elf-aarch64">
+#executable_target_x86_64 = #hal.executable.target<"llvm-cpu", "embedded-elf-x86_64">
#device_target_cpu = #hal.device.target<"llvm-cpu", {
- executable_targets = [#executable_target_embedded_elf_x86_64]
+ executable_targets = [#executable_target_aarch64, #executable_target_x86_64]
}>
#pipeline_layout = #hal.pipeline.layout<push_constants = 0, sets = [
#hal.descriptor_set.layout<0, bindings = [
@@ -177,7 +178,24 @@
]>
]>
hal.executable private @ex {
- hal.executable.variant public @embedded_elf_x86_64 target(#executable_target_embedded_elf_x86_64) {
+ hal.executable.variant public @aarch64 target(#executable_target_aarch64) {
+ hal.executable.condition(%device: !hal.device) -> i1 {
+ %ok, %selected = hal.device.query<%device : !hal.device> key("some" :: "feature") : i1, i1
+ hal.return %selected : i1
+ }
+ hal.executable.export public @dispatch ordinal(0) layout(#pipeline_layout) attributes {
+ translation_info = #iree_codegen.translation_info<CPUDefault>
+ } {
+ ^bb0(%device: !hal.device, %arg0: index, %arg1: index, %arg2: index): // no predecessors
+ %c1 = arith.constant 1 : index
+ %0 = affine.apply affine_map<()[s0] -> (s0 ceildiv 4)>()[%arg0]
+ hal.return %0, %c1, %c1 : index, index, index
+ }
+ builtin.module {
+ // Opaque at this point (in some target-specific dialects).
+ }
+ }
+ hal.executable.variant public @x86_64 target(#executable_target_x86_64) {
hal.executable.export public @dispatch ordinal(0) layout(#pipeline_layout) attributes {
translation_info = #iree_codegen.translation_info<CPUDefault>
} {
@@ -203,9 +221,19 @@
%c128 = arith.constant 128 : index
// CHECK: %[[CMD:.+]] = hal.command_buffer.create
%0 = stream.cmd.execute with(%arg0 as %arg4: !stream.resource<transient>{%arg1}, %arg2 as %arg5: !stream.resource<external>{%arg3}) {
- // Switch for each executable variant:
- // CHECK: hal.device.switch
- // CHECK-NEXT: #hal.device.match.executable.format<"embedded-elf-x86_64">
+ // Switch for each executable variant by checking conditions and ranking:
+ // CHECK: %[[DEVICE:.+]] = hal.command_buffer.device<%[[CMD]] : !hal.command_buffer>
+ // CHECK-DAG: %{{.+}}, %[[AARCH64_FORMAT:.+]] = hal.device.query<%[[DEVICE]] : !hal.device> key("hal.executable.format" :: "embedded-elf-aarch64")
+ // CHECK-DAG: %[[AARCH64_FEATURE:.+]] = scf.execute_region -> i1 {
+ // CHECK-NEXT: %{{.+}}, %[[FEATURE:.+]] = hal.device.query<%[[DEVICE]] : !hal.device> key("some" :: "feature")
+ // CHECK-NEXT: scf.yield %[[FEATURE]]
+ // CHECK-NEXT: }
+ // CHECK-DAG: %[[AARCH64_SELECTED:.+]] = arith.andi %[[AARCH64_FORMAT]], %[[AARCH64_FEATURE]]
+ // CHECK-DAG: %{{.+}}, %[[X86_64_SELECTED:.+]] = hal.device.query<%[[DEVICE]] : !hal.device> key("hal.executable.format" :: "embedded-elf-x86_64")
+ // CHECK: %[[VARIANT1:.+]] = arith.select %[[X86_64_SELECTED]], %c1
+ // CHECK: %[[VARIANT0:.+]] = arith.select %[[AARCH64_SELECTED]], %c0{{.+}}, %[[VARIANT1]]
+ // CHECK: scf.index_switch %[[VARIANT0]]
+ // CHECK-NEXT: case 0 {
// Cache queries:
// CHECK-DAG: %[[LAYOUT:.+]] = hal.pipeline_layout.lookup {{.+}} layout(#pipeline_layout)
@@ -230,9 +258,14 @@
// Dispatch:
// CHECK: hal.command_buffer.dispatch.symbol<%[[CMD]]
- // CHECK-SAME: target(@ex::@embedded_elf_x86_64::@dispatch)
+ // CHECK-SAME: target(@ex::@aarch64::@dispatch)
// CHECK-SAME: workgroups([%[[X]], %[[YZ]], %[[YZ]]])
- stream.cmd.dispatch @ex::@embedded_elf_x86_64::@dispatch[%c1, %c2, %c3](%c4_i32, %c5_i32 : i32, i32) {
+
+ // Other variant, when selected:
+ // CHECK: case 1 {
+ // CHECK: hal.command_buffer.dispatch.symbol<%[[CMD]]
+ // CHECK-SAME: target(@ex::@x86_64::@dispatch)
+ stream.cmd.dispatch {@ex::@aarch64::@dispatch, @ex::@x86_64::@dispatch}[%c1, %c2, %c3](%c4_i32, %c5_i32 : i32, i32) {
ro %arg4[%c0 for %c128] : !stream.resource<transient>{%arg1},
wo %arg5[%c0 for %c128] : !stream.resource<external>{%arg3}
} attributes {
diff --git a/compiler/src/iree/compiler/Dialect/HAL/IR/BUILD.bazel b/compiler/src/iree/compiler/Dialect/HAL/IR/BUILD.bazel
index e5a47ac..4c543db 100644
--- a/compiler/src/iree/compiler/Dialect/HAL/IR/BUILD.bazel
+++ b/compiler/src/iree/compiler/Dialect/HAL/IR/BUILD.bazel
@@ -81,6 +81,7 @@
"@llvm-project//mlir:IR",
"@llvm-project//mlir:MemRefDialect",
"@llvm-project//mlir:Parser",
+ "@llvm-project//mlir:SCFDialect",
"@llvm-project//mlir:SideEffectInterfaces",
"@llvm-project//mlir:Support",
"@llvm-project//mlir:TransformUtils",
@@ -105,6 +106,7 @@
"@llvm-project//mlir:IR",
"@llvm-project//mlir:MemRefDialect",
"@llvm-project//mlir:Parser",
+ "@llvm-project//mlir:SCFDialect",
"@llvm-project//mlir:TransformUtils",
],
)
diff --git a/compiler/src/iree/compiler/Dialect/HAL/IR/CMakeLists.txt b/compiler/src/iree/compiler/Dialect/HAL/IR/CMakeLists.txt
index 8d6c892..64d0fb1 100644
--- a/compiler/src/iree/compiler/Dialect/HAL/IR/CMakeLists.txt
+++ b/compiler/src/iree/compiler/Dialect/HAL/IR/CMakeLists.txt
@@ -47,6 +47,7 @@
MLIRIR
MLIRMemRefDialect
MLIRParser
+ MLIRSCFDialect
MLIRSideEffectInterfaces
MLIRSupport
MLIRTransformUtils
@@ -73,6 +74,7 @@
MLIRIR
MLIRMemRefDialect
MLIRParser
+ MLIRSCFDialect
MLIRTransformUtils
iree::compiler::Dialect::HAL::Conversion::HALToVM
iree::compiler::Dialect::HAL::hal_imports
diff --git a/compiler/src/iree/compiler/Dialect/HAL/IR/HALDialect.cpp b/compiler/src/iree/compiler/Dialect/HAL/IR/HALDialect.cpp
index 7d76552..decdf93 100644
--- a/compiler/src/iree/compiler/Dialect/HAL/IR/HALDialect.cpp
+++ b/compiler/src/iree/compiler/Dialect/HAL/IR/HALDialect.cpp
@@ -17,6 +17,7 @@
#include "mlir/Dialect/Arith/IR/Arith.h"
#include "mlir/Dialect/ControlFlow/IR/ControlFlow.h"
#include "mlir/Dialect/Func/IR/FuncOps.h"
+#include "mlir/Dialect/SCF/IR/SCF.h"
#include "mlir/IR/BuiltinAttributes.h"
#include "mlir/IR/BuiltinTypes.h"
#include "mlir/IR/DialectImplementation.h"
@@ -127,6 +128,7 @@
HALDialect::HALDialect(MLIRContext *context)
: Dialect(getDialectNamespace(), context, TypeID::get<HALDialect>()) {
context->loadDialect<mlir::cf::ControlFlowDialect>();
+ context->loadDialect<mlir::scf::SCFDialect>();
context->loadDialect<IREE::Util::UtilDialect>();
registerAttributes();
diff --git a/compiler/src/iree/compiler/Dialect/HAL/IR/HALInterfaces.td b/compiler/src/iree/compiler/Dialect/HAL/IR/HALInterfaces.td
index cdf89ae..b1df6ff 100644
--- a/compiler/src/iree/compiler/Dialect/HAL/IR/HALInterfaces.td
+++ b/compiler/src/iree/compiler/Dialect/HAL/IR/HALInterfaces.td
@@ -25,7 +25,7 @@
attribute is true for the given value.
}],
"Value", "buildConditionExpression",
- (ins "Location":$loc, "Value":$value, "OpBuilder":$builder)
+ (ins "Location":$loc, "Value":$device, "OpBuilder":$builder)
>,
];
}
diff --git a/compiler/src/iree/compiler/Dialect/HAL/IR/HALOpFolders.cpp b/compiler/src/iree/compiler/Dialect/HAL/IR/HALOpFolders.cpp
index 6d52a20..dfc893d 100644
--- a/compiler/src/iree/compiler/Dialect/HAL/IR/HALOpFolders.cpp
+++ b/compiler/src/iree/compiler/Dialect/HAL/IR/HALOpFolders.cpp
@@ -351,16 +351,6 @@
}
//===----------------------------------------------------------------------===//
-// hal.device.switch
-//===----------------------------------------------------------------------===//
-
-// TODO(benvanik): fold conditions with the same IR tree.
-// TODO(benvanik): remove duplicate conditions.
-// TODO(benvanik): fold condition expressions (any(always, ...) -> always, etc).
-// TODO(benvanik): completely replace switches with just one always block.
-// TODO(benvanik): remove conditions with no side-effects.
-
-//===----------------------------------------------------------------------===//
// hal.device.match.id
//===----------------------------------------------------------------------===//
@@ -682,8 +672,7 @@
using OpRewritePattern::OpRewritePattern;
LogicalResult matchAndRewrite(ExecutableVariantOp variantOp,
PatternRewriter &rewriter) const override {
- auto blockOps =
- llvm::to_vector(variantOp.getOps<ExecutableConstantBlockOp>());
+ auto blockOps = llvm::to_vector(variantOp.getConstantBlockOps());
if (blockOps.size() <= 1) {
return rewriter.notifyMatchFailure(variantOp,
"not enough blocks to merge");
diff --git a/compiler/src/iree/compiler/Dialect/HAL/IR/HALOps.cpp b/compiler/src/iree/compiler/Dialect/HAL/IR/HALOps.cpp
index 4157fe1..018cc77 100644
--- a/compiler/src/iree/compiler/Dialect/HAL/IR/HALOps.cpp
+++ b/compiler/src/iree/compiler/Dialect/HAL/IR/HALOps.cpp
@@ -10,6 +10,7 @@
#include "iree/compiler/Dialect/Util/IR/UtilTypes.h"
#include "llvm/ADT/STLExtras.h"
#include "mlir/Dialect/Arith/IR/Arith.h"
+#include "mlir/Dialect/SCF/IR/SCF.h"
#include "mlir/IR/Attributes.h"
#include "mlir/IR/Builders.h"
#include "mlir/IR/BuiltinTypes.h"
@@ -114,6 +115,146 @@
}
//===----------------------------------------------------------------------===//
+// custom<TargetConditionRegion>($body)
+//===----------------------------------------------------------------------===//
+
+static FunctionType getTargetConditionRegionType(MLIRContext *context) {
+ return FunctionType::get(context,
+ {
+ IREE::HAL::DeviceType::get(context),
+ },
+ {
+ IntegerType::get(context, 1),
+ });
+}
+
+static LogicalResult verifyTargetConditionRegion(Operation *op,
+ Region &region) {
+ // Ignore if empty.
+ if (region.empty())
+ return success();
+
+ // Verify region takes a !hal.device.
+ if (region.getNumArguments() != 1 ||
+ !isa<IREE::HAL::DeviceType>(region.getArgumentTypes().front())) {
+ return op->emitOpError()
+ << "target condition region must take a !hal.device";
+ }
+
+ // Verify i1 return.
+ for (auto returnOp : region.getOps<IREE::HAL::ReturnOp>()) {
+ if (returnOp.getNumOperands() != 1) {
+ return returnOp.emitOpError()
+ << "target condition region must return a single i1 result";
+ }
+ for (auto returnType : returnOp.getOperandTypes()) {
+ if (!returnType.isInteger(1)) {
+ return returnOp.emitOpError()
+ << "target condition region must return a single i1 result";
+ }
+ }
+ }
+
+ return success();
+}
+
+static ParseResult parseTargetConditionRegion(OpAsmParser &parser,
+ Region &body) {
+ SmallVector<OpAsmParser::Argument> args;
+ if (failed(parser.parseArgumentList(args, AsmParser::Delimiter::Paren,
+ /*allowType=*/true,
+ /*allowAttrs=*/true))) {
+ return failure();
+ }
+
+ SmallVector<Type> returnTypes;
+ if (failed(parser.parseArrowTypeList(returnTypes))) {
+ return failure();
+ }
+ if (returnTypes.size() != 1 ||
+ !llvm::all_of(returnTypes, [](Type type) { return type.isInteger(1); })) {
+ return parser.emitError(parser.getCurrentLocation())
+ << "target condition region must return one i1";
+ }
+
+ return parser.parseRegion(body, args, /*enableNameShadowing=*/false);
+}
+
+static void printTargetConditionRegion(OpAsmPrinter &p, Operation *op,
+ Region &body) {
+ if (body.empty())
+ return;
+ p << "(";
+ llvm::interleaveComma(body.getArguments(), p,
+ [&](BlockArgument arg) { p.printRegionArgument(arg); });
+ p << ")";
+ p.printArrowTypeList(TypeRange{IntegerType::get(body.getContext(), 1)});
+ p << " ";
+ p.printRegion(body, /*printEntryBlockArgs=*/false,
+ /*printBlockTerminators=*/true);
+}
+
+//===----------------------------------------------------------------------===//
+// custom<ConditionalTargetRegions>($targets, $objects, $target_regions)
+//===----------------------------------------------------------------------===//
+
+static ParseResult parseConditionalTargetRegions(
+ OpAsmParser &parser, ArrayAttr &targetsAttr, ArrayAttr &objectsAttr,
+ SmallVectorImpl<std::unique_ptr<Region>> &targetRegions) {
+ auto builder = parser.getBuilder();
+ SmallVector<Attribute> targetAttrs;
+ SmallVector<Attribute> objectsAttrs;
+ do {
+ IREE::HAL::ExecutableTargetAttr targetAttr;
+ if (failed(parser.parseAttribute(targetAttr)))
+ return failure();
+ targetAttrs.push_back(targetAttr);
+ std::unique_ptr<Region> targetRegion = std::make_unique<Region>();
+ if (succeeded(parser.parseOptionalKeyword("if"))) {
+ if (failed(parseTargetConditionRegion(parser, *targetRegion)))
+ return failure();
+ }
+ targetRegions.emplace_back(std::move(targetRegion));
+ if (failed(parser.parseEqual()))
+ return failure();
+ ArrayAttr targetObjectsAttr;
+ if (failed(parser.parseAttribute(targetObjectsAttr)))
+ return failure();
+ objectsAttrs.push_back(targetObjectsAttr);
+ } while (succeeded(parser.parseOptionalComma()));
+ targetsAttr = builder.getArrayAttr(targetAttrs);
+ objectsAttr = builder.getArrayAttr(objectsAttrs);
+ return success();
+}
+
+static void
+printConditionalTargetRegions(OpAsmPrinter &p, Operation *op,
+ ArrayAttr targetsAttr, ArrayAttr objectsAttr,
+ MutableArrayRef<Region> targetRegions) {
+ p.increaseIndent();
+ p.printNewline();
+ llvm::interleave(
+ llvm::zip_equal(targetsAttr.getAsRange<IREE::HAL::ExecutableTargetAttr>(),
+ objectsAttr.getAsRange<ArrayAttr>(), targetRegions),
+ [&](auto it) {
+ auto [targetAttr, targetObjectsAttr, targetRegion] = it;
+ p.printAttribute(targetAttr);
+ if (!targetRegion.empty()) {
+ p << " if";
+ printTargetConditionRegion(p, op, targetRegion);
+ }
+ p << " = ";
+ p.printAttribute(targetObjectsAttr);
+ },
+ [&]() {
+ p << ",";
+ p.printNewline();
+ });
+ p.decreaseIndent();
+ p.printNewline();
+}
+
+//===----------------------------------------------------------------------===//
// custom<WorkgroupCountRegion>($body)
//===----------------------------------------------------------------------===//
@@ -161,12 +302,8 @@
if (body.empty())
return;
p << "(";
- auto args = body.getArguments();
- for (unsigned i = 0; i < args.size(); ++i) {
- if (i > 0)
- p << ", ";
- p.printRegionArgument(args[i]);
- }
+ llvm::interleaveComma(body.getArguments(), p,
+ [&](BlockArgument arg) { p.printRegionArgument(arg); });
p << ")";
Type indexType = IndexType::get(body.getContext());
p.printArrowTypeList(TypeRange{indexType, indexType, indexType});
@@ -705,113 +842,6 @@
}
//===----------------------------------------------------------------------===//
-// hal.device.switch
-//===----------------------------------------------------------------------===//
-
-void DeviceSwitchOp::build(OpBuilder &builder, OperationState &state,
- TypeRange resultTypes, Value device,
- ArrayRef<Attribute> conditions,
- ArrayRef<NamedAttribute> attributes) {
- state.addOperands({device});
- state.addAttribute("conditions", builder.getArrayAttr(conditions));
- for (size_t i = 0; i < conditions.size(); ++i) {
- state.addRegion();
- }
- state.addTypes(resultTypes);
- state.addAttributes(attributes);
-}
-
-ParseResult DeviceSwitchOp::parse(OpAsmParser &parser, OperationState &result) {
- OpAsmParser::UnresolvedOperand device;
- Type deviceType;
- if (failed(parser.parseLess()) || failed(parser.parseOperand(device)) ||
- failed(parser.parseColonType(deviceType)) ||
- failed(parser.resolveOperand(device, deviceType, result.operands)) ||
- failed(parser.parseGreater()) ||
- failed(parser.parseOptionalArrowTypeList(result.types))) {
- return failure();
- }
-
- // Parses each switch condition attribute and region, like:
- // #hal.device.match.id<"vulkan-v1.?-*"> {
- // hal.return %c1 : i32
- // }, ...
- SmallVector<Attribute> conditionAttrs;
- do {
- Attribute conditionAttr;
- NamedAttrList dummyAttrs;
- if (failed(parser.parseAttribute(conditionAttr, "condition", dummyAttrs))) {
- return failure();
- }
- conditionAttrs.push_back(conditionAttr);
- SmallVector<OpAsmParser::Argument> regionArgs;
- auto *regionBody = result.addRegion();
- if (failed(parser.parseRegion(*regionBody, regionArgs))) {
- return failure();
- }
- } while (succeeded(parser.parseOptionalComma()));
- result.addAttribute("conditions",
- ArrayAttr::get(result.getContext(), conditionAttrs));
-
- if (failed(parser.parseOptionalAttrDictWithKeyword(result.attributes))) {
- return failure();
- }
- return success();
-}
-
-void DeviceSwitchOp::print(OpAsmPrinter &p) {
- Operation *op = getOperation();
- p << "<";
- p.printOperand(getDevice());
- p << " : ";
- p.printType(getDevice().getType());
- p << ">";
- p.printOptionalArrowTypeList(getResultTypes());
- p << "\n";
- p.getStream().indent(4);
- interleave(
- llvm::zip_equal(getConditions(), getConditionRegions()),
- [&](std::tuple<Attribute, Region &> it) {
- auto &conditionAttr = std::get<0>(it);
- auto &conditionRegion = std::get<1>(it);
- p.printAttribute(conditionAttr);
- p << " ";
- p.printRegion(conditionRegion,
- /*printEntryBlockArgs=*/false,
- /*printBlockTerminators=*/true);
- },
- [&]() {
- p << ",\n";
- p.getStream().indent(4);
- });
- p.printOptionalAttrDictWithKeyword(op->getAttrs(),
- /*elidedAttrs=*/{"conditions"});
-}
-
-LogicalResult DeviceSwitchOp::verify() {
- DeviceSwitchOp op = *this;
- if (op.getConditions().size() != op.getConditionRegions().size()) {
- return op.emitOpError() << "requires conditions and regions be matched 1:1";
- } else if (op.getConditionRegions().empty()) {
- return op.emitOpError() << "requires at least one condition";
- }
- for (auto &region : op.getConditionRegions()) {
- for (auto &block : region) {
- if (auto returnOp =
- dyn_cast_or_null<IREE::HAL::ReturnOp>(block.getTerminator())) {
- if (!std::equal(returnOp.getOperandTypes().begin(),
- returnOp.getOperandTypes().end(),
- op.getResultTypes().begin())) {
- return op.emitOpError()
- << "requires all regions return the same types";
- }
- }
- }
- }
- return success();
-}
-
-//===----------------------------------------------------------------------===//
// hal.device.queue.*
//===----------------------------------------------------------------------===//
@@ -858,6 +888,21 @@
}
//===----------------------------------------------------------------------===//
+// hal.executable.source
+//===----------------------------------------------------------------------===//
+
+// Verifies that the executable source contains at most one
+// hal.executable.condition op.
+LogicalResult ExecutableSourceOp::verify() {
+  const auto conditionCount =
+      llvm::range_size(getOps<IREE::HAL::ExecutableConditionOp>());
+  if (conditionCount <= 1)
+    return success();
+
+  ExecutableSourceOp self = *this;
+  return self.emitOpError()
+         << "only one condition op is allowed in an executable";
+}
+
+//===----------------------------------------------------------------------===//
// hal.executable
//===----------------------------------------------------------------------===//
@@ -1102,9 +1147,19 @@
state.addAttribute("target", target);
}
+// Verifies that the variant contains at most one hal.executable.condition op.
+LogicalResult ExecutableVariantOp::verify() {
+  const auto conditionCount =
+      llvm::range_size(getOps<IREE::HAL::ExecutableConditionOp>());
+  if (conditionCount <= 1)
+    return success();
+
+  ExecutableVariantOp self = *this;
+  return self.emitOpError() << "only one condition op is allowed in a variant";
+}
+
DenseMap<Attribute, int> ExecutableVariantOp::gatherConstantOrdinals() {
DenseMap<Attribute, int> map;
- for (auto blockOp : getOps<IREE::HAL::ExecutableConstantBlockOp>()) {
+ for (auto blockOp : getConstantBlockOps()) {
int baseCount = map.size();
for (auto [i, keyAttr] : llvm::enumerate(blockOp.getKeys())) {
map.try_emplace(keyAttr, baseCount + i);
@@ -1113,6 +1168,89 @@
return map;
}
+// Builds IR with |builder| that computes whether this variant should be
+// selected for |device| and returns the resulting i1 value.
+//
+// The value starts as the condition expression built from the target's match
+// expression. If the variant contains a hal.executable.condition op its region
+// is cloned into an scf.execute_region (with the region's first argument
+// remapped to |device|) and the region result is ANDed into the value.
+Value ExecutableVariantOp::buildCondition(Value device, OpBuilder &builder) {
+  // Base case dependent on target information.
+  auto matchAttr =
+      cast<IREE::HAL::MatchAttrInterface>(getTarget().getMatchExpression());
+  auto selected = matchAttr.buildConditionExpression(getLoc(), device, builder);
+
+  // Factor in variant condition region, if any.
+  auto conditionOp = getConditionOp();
+  if (conditionOp) {
+    auto regionOp = builder.create<scf::ExecuteRegionOp>(conditionOp.getLoc(),
+                                                         builder.getI1Type());
+
+    // Clone the condition body so the op nested in the variant is left as-is.
+    IRMapping mapper;
+    mapper.map(conditionOp.getRegion().getArgument(0), device);
+    conditionOp.getRegion().cloneInto(&regionOp.getRegion(), mapper);
+
+    // Replace each cloned hal.return with an scf.yield of the same operands.
+    // Early-inc iteration is needed because the returns are erased in place.
+    for (auto returnOp :
+         llvm::make_early_inc_range(regionOp.getOps<IREE::HAL::ReturnOp>())) {
+      OpBuilder(returnOp).create<scf::YieldOp>(returnOp.getLoc(),
+                                               returnOp.getOperands());
+      returnOp.erase();
+    }
+
+    selected = builder.create<arith::AndIOp>(getLoc(), selected,
+                                             regionOp.getResult(0));
+  }
+
+  return selected;
+}
+
+//===----------------------------------------------------------------------===//
+// hal.executable.condition
+//===----------------------------------------------------------------------===//
+
+// Verifies the op body with the shared target condition region verifier.
+LogicalResult ExecutableConditionOp::verify() {
+  ExecutableConditionOp self = *this;
+  Region &conditionBody = self.getBody();
+  return verifyTargetConditionRegion(self, conditionBody);
+}
+
+// Populates |result| for a condition op: sets the fixed `function_type`
+// attribute, adds the (initially empty) body region, and then appends any
+// caller-provided |attrs|.
+void ExecutableConditionOp::build(OpBuilder &builder, OperationState &result,
+                                  ArrayRef<NamedAttribute> attrs) {
+  auto functionType = getTargetConditionRegionType(builder.getContext());
+  result.addAttribute("function_type", TypeAttr::get(functionType));
+  result.addRegion();
+  result.addAttributes(attrs);
+}
+
+// Parses the condition region followed by an optional `attributes {...}`
+// dictionary. The `function_type` attribute is never spelled in the textual
+// form; it is added here from the fixed target condition region type.
+ParseResult ExecutableConditionOp::parse(OpAsmParser &parser,
+                                         OperationState &result) {
+  Region *conditionBody = result.addRegion();
+  if (failed(parseTargetConditionRegion(parser, *conditionBody)))
+    return failure();
+
+  auto functionType = getTargetConditionRegionType(parser.getContext());
+  result.addAttribute("function_type", TypeAttr::get(functionType));
+
+  if (failed(parser.parseOptionalAttrDictWithKeyword(result.attributes)))
+    return failure();
+  return success();
+}
+
+// Prints the condition region followed by any attributes other than
+// `function_type`, which is elided from the textual form.
+void ExecutableConditionOp::print(OpAsmPrinter &p) {
+  Operation *self = getOperation();
+  printTargetConditionRegion(p, self, getBody());
+  p.printOptionalAttrDictWithKeyword(self->getAttrs(),
+                                     /*elidedAttrs=*/{"function_type"});
+}
+
+// Appends the entry block to a body that has no blocks yet. The block receives
+// one argument per function input type, each located at this op's location.
+Block *ExecutableConditionOp::addEntryBlock() {
+  assert(empty() && "function already has an entry block");
+  ArrayRef<Type> inputTypes = getArgumentTypes();
+  SmallVector<Location> inputLocs(inputTypes.size(), getLoc());
+  Block *entryBlock = new Block();
+  entryBlock->addArguments(inputTypes, inputLocs);
+  push_back(entryBlock);
+  return entryBlock;
+}
+
+// Appends a new argument-less block after the existing blocks and returns it.
+// The entry block must already exist.
+Block *ExecutableConditionOp::addBlock() {
+  assert(!empty() && "function should at least have an entry block");
+  Block *appendedBlock = new Block();
+  push_back(appendedBlock);
+  return appendedBlock;
+}
+
//===----------------------------------------------------------------------===//
// hal.executable.constant.block
//===----------------------------------------------------------------------===//
diff --git a/compiler/src/iree/compiler/Dialect/HAL/IR/HALOps.td b/compiler/src/iree/compiler/Dialect/HAL/IR/HALOps.td
index fc27fe4..29f9457 100644
--- a/compiler/src/iree/compiler/Dialect/HAL/IR/HALOps.td
+++ b/compiler/src/iree/compiler/Dialect/HAL/IR/HALOps.td
@@ -394,10 +394,10 @@
type($arguments), $argument_dims,
type($results), $result_dims,
$tied_operands)
+ `count` `` custom<WorkgroupCountRegion>($workgroup_count)
`layout` `(` $layout `)`
(`bindings` `(` $bindings^ `)`)?
`objects` `(` $objects `)`
- `count` `` custom<WorkgroupCountRegion>($workgroup_count)
attr-dict-with-keyword
}];
@@ -1534,88 +1534,8 @@
];
}
-def HAL_DeviceSwitchOp : HAL_Op<"device.switch", [
- NoRegionArguments,
- RecursiveMemoryEffects,
- ]> {
- let summary = [{runtime device switch pseudo op}];
- let description = [{
- Switches between multiple regions based on the runtime device type.
- The provided regions are matched against the runtime backend of the given
- device and executed only when the device matches the conditions.
-
- Conditions can match on wildcards and be folded to enable conditions that
- have similar bodies to be folded. The patterns themselves are only matched
- once at startup and then the results are cached; the runtime overhead is
- equivalent to a normal switch statement. In cases where the compiler can
- statically identify the device type entire cases can be folded away.
-
- Supported conditions:
- * `#hal.match...`: execute the region if the expression matches.
-
- Supported match expressions:
- * `#hal.match.always`: always matches; useful for defaults.
- * `#hal.match.any<[...]>`: matches if any of the nested expressions match.
- * `#hal.match.all<[...]>`: matches only if all of the nested expressions
- match.
- * `#hal.device.match.id<"pattern*-?-*">`: matches against the device
- identifier. The pattern is evaluated with standard file path wildcards
- (`*` for zero or more characters and `?` for one character).
-
- If more than one condition is satisfied the first listed will be chosen.
- More specific conditions should be earlier in the set. If no condition is
- matched but there are return values the switch will abort at runtime. It's
- strongly recommend that all switches that return values end with a trailing
- `#hal.match.always` condition to handle the fallthrough case.
-
- Upon creation each condition region will have an empty entry block with the
- specified operands available as arguments. Each region must be setup to
- return the same types.
-
- ```mlir
- %c0 = arith.constant 0 : i32
- %c1 = arith.constant 1 : i32
- %c2 = arith.constant 2 : i32
- %device = ... : !hal.device
- %0 = hal.device.switch<%device : !hal.device> -> i32
- #hal.device.match.id<"vulkan-v1.?-*"> {
- hal.return %c1 : i32
- },
- #hal.match.any<[#hal.device.match.id<"vmvx">, #hal.device.match.id<"vulkan-*">]> {
- hal.return %c2 : i32
- },
- #hal.match.always {
- hal.return %c0 : i32
- }
- ```
- }];
-
- let arguments = (ins
- HAL_Device:$device,
- ArrayAttr:$conditions
- );
- let results = (outs
- Variadic<AnyType>:$results
- );
-
- let regions = (region VariadicRegion<AnyRegion>:$condition_regions);
-
- let skipDefaultBuilders = 1;
- let builders = [
- OpBuilder<(ins
- "TypeRange":$resultTypes,
- "Value":$device,
- "ArrayRef<Attribute>":$conditions,
- CArg<"ArrayRef<NamedAttribute>", "{}">:$attributes
- )>,
- ];
-
-
- let hasVerifier = 1;
-}
-
def HAL_ReturnOp : HAL_Op<"return", [Terminator]> {
- let summary = [{return from a hal.device.switch region}];
+ let summary = [{return from a hal.* region}];
let description = [{
Returns the given values from the region and back to the host code.
}];
@@ -1955,19 +1875,34 @@
OptionalAttr<HAL_ExecutableObjectsAttr>:$objects
);
- let regions = (region SizedRegion<1>:$body);
+ let regions = (region
+ SizedRegion<1>:$body
+ );
let assemblyFormat = [{
custom<SymbolVisibility>($sym_visibility)
$sym_name
attr-dict-with-keyword
``
- regions
+ $body
}];
let extraClassDeclaration = [{
Block& getBlock() { return getBody().front(); }
+ // Returns the first hal.executable.condition op found in the body region or
+ // a null op when the body contains none.
+ IREE::HAL::ExecutableConditionOp getConditionOp() {
+   for (auto conditionOp :
+        getBody().getOps<IREE::HAL::ExecutableConditionOp>()) {
+     return conditionOp;
+   }
+   return IREE::HAL::ExecutableConditionOp{};
+ }
+ // Returns the constant block ops found in the body region.
+ iterator_range<Region::op_iterator<IREE::HAL::ExecutableConstantBlockOp>>
+ getConstantBlockOps() {
+   Region &bodyRegion = getBody();
+   return bodyRegion.getOps<IREE::HAL::ExecutableConstantBlockOp>();
+ }
+ // Returns the export ops found in the body region.
+ iterator_range<Region::op_iterator<IREE::HAL::ExecutableExportOp>>
+ getExportOps() {
+   Region &bodyRegion = getBody();
+   return bodyRegion.getOps<IREE::HAL::ExecutableExportOp>();
+ }
+
bool isExternal() {
return getBlock().getOps<::mlir::ModuleOp>().empty();
}
@@ -1978,6 +1913,8 @@
return *it.begin();
}
}];
+
+ let hasVerifier = 1;
}
def HAL_ExecutableSourceEndOp : HAL_Op<"executable.source_end", [
@@ -2114,6 +2051,12 @@
let description = [{
The target IR for the executable. This can be preserved for debugging but
is usually removed during transformation.
+
+ Variants are selected based on their target and an optional condition
+ op that returns true if the variant is valid for use on the provided
+ runtime `!hal.device`. If no variants within an executable are valid then
+ loading will fail at runtime. If multiple variants are valid the first valid
+ one found will be loaded and used for execution.
}];
let arguments = (ins
@@ -2123,7 +2066,9 @@
OptionalAttr<HAL_ExecutableObjectArrayAttr>:$objects
);
- let regions = (region SizedRegion<1>:$body);
+ let regions = (region
+ SizedRegion<1>:$body
+ );
let assemblyFormat = [{
custom<SymbolVisibility>($sym_visibility)
@@ -2131,7 +2076,7 @@
`target` `(` $target `)`
(`objects` `(` $objects^ `)` )?
attr-dict-with-keyword
- regions
+ $body
}];
let skipDefaultBuilders = 1;
@@ -2142,6 +2087,19 @@
let extraClassDeclaration = [{
Block& getBlock() { return getBody().front(); }
+ // Returns the first hal.executable.condition op found in the body region or
+ // a null op when the body contains none.
+ IREE::HAL::ExecutableConditionOp getConditionOp() {
+   for (auto conditionOp :
+        getBody().getOps<IREE::HAL::ExecutableConditionOp>()) {
+     return conditionOp;
+   }
+   return IREE::HAL::ExecutableConditionOp{};
+ }
+ // Returns the constant block ops found in the body region.
+ iterator_range<Region::op_iterator<IREE::HAL::ExecutableConstantBlockOp>>
+ getConstantBlockOps() {
+   Region &bodyRegion = getBody();
+   return bodyRegion.getOps<IREE::HAL::ExecutableConstantBlockOp>();
+ }
+ // Returns the export ops found in the body region.
+ iterator_range<Region::op_iterator<IREE::HAL::ExecutableExportOp>>
+ getExportOps() {
+   Region &bodyRegion = getBody();
+   return bodyRegion.getOps<IREE::HAL::ExecutableExportOp>();
+ }
+
bool isExternal() {
return getBlock().getOps<::mlir::ModuleOp>().empty();
}
@@ -2155,9 +2113,13 @@
// Returns a map of constant key attributes to ordinals across all constant
// blocks inside the variant.
DenseMap<Attribute, int> gatherConstantOrdinals();
+
+ // Returns an i1 indicating whether this variant should be selected.
+ Value buildCondition(Value device, OpBuilder &builder);
}];
let hasCanonicalizer = 1;
+ let hasVerifier = 1;
}
def HAL_ExecutableVariantEndOp : HAL_Op<"executable.variant_end", [
@@ -2168,6 +2130,59 @@
let assemblyFormat = "attr-dict";
}
+// Function-like region op (FunctionOpInterface + CallableOpInterface) that is
+// isolated from above and carries no symbol name. Its body holds the host
+// logic that decides whether the enclosing executable variant is selected.
+def HAL_ExecutableConditionOp : HAL_Op<"executable.condition", [
+ IsolatedFromAbove,
+ FunctionOpInterface,
+ CallableOpInterface,
+ ]> {
+ let summary = [{host code to determine if the executable is enabled}];
+ let description = [{
+ Variants are selected based on their target and this optional condition
+ op that returns true if the variant is valid for use on the provided
+ runtime `!hal.device`. If no variants within an executable are valid then
+ loading will fail at runtime. If multiple variants are valid the first valid
+ one found will be loaded and used for execution.
+ }];
+
+ // Attributes declared for the function-like signature of the op.
+ let arguments = (ins
+ TypeAttrOf<FunctionType>:$function_type,
+ OptionalAttr<DictArrayAttr>:$arg_attrs,
+ OptionalAttr<DictArrayAttr>:$res_attrs
+ );
+
+ let regions = (region AnyRegion:$body);
+
+ // Only the custom builder below is generated; it takes optional extra
+ // attributes and no operands or result types.
+ let skipDefaultBuilders = 1;
+ let builders = [
+ OpBuilder<(ins
+ CArg<"ArrayRef<NamedAttribute>", "{}">:$attrs
+ )>,
+ ];
+
+ let extraClassDeclaration = [{
+ /// Add an entry block to an empty function and set up the block arguments
+ /// to match the signature of the function.
+ Block *addEntryBlock();
+ Block *addBlock();
+
+ // Signature accessors that forward to the `function_type` attribute.
+ ArrayRef<Type> getArgumentTypes() { return getFunctionType().getInputs(); }
+ ArrayRef<Type> getResultTypes() { return getFunctionType().getResults(); }
+
+ // No additional signature checks are performed here; always succeeds.
+ LogicalResult verifyType() { return success(); }
+
+ // Callable hooks: the callable region is the op body and the callable
+ // results are the function type results.
+ Region *getCallableRegion() { return &getBody(); }
+ ArrayRef<Type> getCallableResults() { return getResultTypes(); }
+
+ // Always null: `arg_attrs`/`res_attrs` are not exposed through these hooks.
+ ::mlir::ArrayAttr getCallableArgAttrs() { return nullptr; }
+ ::mlir::ArrayAttr getCallableResAttrs() { return nullptr; }
+
+ /// Make symbol optional as this op has no symbol.
+ bool isOptionalSymbol() { return true; }
+ }];
+
+ let hasVerifier = 1;
+}
+
def HAL_ExecutableConstantBlockOp :
HAL_Op<"executable.constant.block", [
ParentOneOf<[
diff --git a/compiler/src/iree/compiler/Dialect/HAL/IR/HALTypes.h b/compiler/src/iree/compiler/Dialect/HAL/IR/HALTypes.h
index 665625e..2d5a196 100644
--- a/compiler/src/iree/compiler/Dialect/HAL/IR/HALTypes.h
+++ b/compiler/src/iree/compiler/Dialect/HAL/IR/HALTypes.h
@@ -41,7 +41,7 @@
#include "iree/compiler/Dialect/HAL/IR/HALTypeInterfaces.h.inc" // IWYU pragma: export
//===----------------------------------------------------------------------===//
-// Enum utilities
+// Utilities
//===----------------------------------------------------------------------===//
// Returns a stable identifier for the MLIR element type or nullopt if the
diff --git a/compiler/src/iree/compiler/Dialect/HAL/IR/test/device_ops.mlir b/compiler/src/iree/compiler/Dialect/HAL/IR/test/device_ops.mlir
index c163e14..a04e5db 100644
--- a/compiler/src/iree/compiler/Dialect/HAL/IR/test/device_ops.mlir
+++ b/compiler/src/iree/compiler/Dialect/HAL/IR/test/device_ops.mlir
@@ -10,40 +10,6 @@
// -----
-// CHECK-LABEL: @device_switch
-// CHECK-SAME: (%[[DEVICE:.+]]: !hal.device)
-func.func @device_switch(%device: !hal.device) -> i32 {
- // CHECK-DAG: %[[C0:.+]] = arith.constant 0
- %c0 = arith.constant 0 : i32
- // CHECK-DAG: %[[C1:.+]] = arith.constant 1
- %c1 = arith.constant 1 : i32
- // CHECK-DAG: %[[C2:.+]] = arith.constant 2
- %c2 = arith.constant 2 : i32
- // CHECK: = hal.device.switch<%[[DEVICE]] : !hal.device> -> i32
- %0 = hal.device.switch<%device : !hal.device> -> i32
- // CHECK-NEXT: #hal.device.match.id<"vulkan-v1.?-*"> {
- #hal.device.match.id<"vulkan-v1.?-*"> {
- // CHECK-NEXT: hal.return %[[C1]] : i32
- hal.return %c1 : i32
- // CHECK-NEXT: },
- },
- // CHECK-NEXT: #hal.match.any<[#hal.device.match.id<"vmvx">, #hal.device.match.id<"vulkan-*">]> {
- #hal.match.any<[#hal.device.match.id<"vmvx">, #hal.device.match.id<"vulkan-*">]> {
- // CHECK-NEXT: hal.return %[[C2]] : i32
- hal.return %c2 : i32
- // CHECK-NEXT: },
- },
- // CHECK-NEXT: #hal.match.always {
- #hal.match.always {
- // CHECK-NEXT: hal.return %[[C0]] : i32
- hal.return %c0 : i32
- // CHECK-NEXT: }
- }
- return %0 : i32
-}
-
-// -----
-
// CHECK-LABEL: @device_query
// CHECK-SAME: (%[[DEVICE:.+]]: !hal.device)
func.func @device_query(%device : !hal.device) -> (i1, i32) {
diff --git a/compiler/src/iree/compiler/Dialect/HAL/IR/test/executable_ops.mlir b/compiler/src/iree/compiler/Dialect/HAL/IR/test/executable_ops.mlir
index a2590db..acd872d 100644
--- a/compiler/src/iree/compiler/Dialect/HAL/IR/test/executable_ops.mlir
+++ b/compiler/src/iree/compiler/Dialect/HAL/IR/test/executable_ops.mlir
@@ -1,6 +1,7 @@
// RUN: iree-opt --split-input-file %s | FileCheck %s
#executable_target_format = #hal.executable.target<"backend", "format">
+
// CHECK-LABEL: @ex
hal.executable @ex {
// CHECK: hal.executable.variant public @backend
@@ -67,6 +68,47 @@
#executable_target_format = #hal.executable.target<"backend", "format">
+// Round-trips a variant that nests a hal.executable.condition region taking
+// the runtime device and returning an i1 produced by a device query.
+// CHECK-LABEL: @ex_with_condition
+hal.executable @ex_with_condition {
+ // CHECK: hal.executable.variant public @backend target(#executable_target_format
+ hal.executable.variant @backend target(#executable_target_format) {
+ // CHECK: hal.executable.condition(%[[DEVICE:.+]]: !hal.device) -> i1 {
+ hal.executable.condition(%device: !hal.device) -> i1 {
+ // CHECK-NEXT: %[[OK:.+]], %[[VALUE:.+]] = hal.device.query<%[[DEVICE]]
+ %ok, %value = hal.device.query<%device : !hal.device> key("some" :: "value") : i1, i32
+ // CHECK-NEXT: return %[[OK]]
+ hal.return %ok : i1
+ }
+
+ // CHECK-DAG: hal.executable.export public @entry0 ordinal(0) layout(#pipeline_layout) attributes {
+ // CHECK-SAME: subgroup_size = 64 : index
+ // CHECK-SAME: workgroup_size = [4 : index, 1 : index, 1 : index]
+ hal.executable.export @entry0 ordinal(0) layout(#hal.pipeline.layout<push_constants = 0, sets = [
+ #hal.descriptor_set.layout<0, bindings = [
+ #hal.descriptor_set.binding<0, storage_buffer>,
+ #hal.descriptor_set.binding<1, storage_buffer>
+ ]>
+ ]>) attributes {
+ subgroup_size = 64 : index,
+ workgroup_size = [4 : index, 1 : index, 1 : index]
+ } {
+ ^bb0(%device: !hal.device, %arg0: index, %arg1: index, %arg2: index):
+ hal.return %arg0, %arg1, %arg2 : index, index, index
+ }
+ }
+ // CHECK: hal.executable.binary
+ hal.executable.binary @backend_binary attributes {
+ // CHECK-SAME: data = dense<1> : vector<128xi8>,
+ data = dense<1> : vector<128xi8>,
+ // CHECK-SAME: format = "some_format"
+ format = "some_format"
+ }
+}
+
+// -----
+
+#executable_target_format = #hal.executable.target<"backend", "format">
+
// CHECK-LABEL: @ex_with_constants
hal.executable @ex_with_constants {
// CHECK: hal.executable.variant public @backend
diff --git a/compiler/src/iree/compiler/Dialect/HAL/IR/test/tensor_ops.mlir b/compiler/src/iree/compiler/Dialect/HAL/IR/test/tensor_ops.mlir
index 86b81c9..42f1d9e 100644
--- a/compiler/src/iree/compiler/Dialect/HAL/IR/test/tensor_ops.mlir
+++ b/compiler/src/iree/compiler/Dialect/HAL/IR/test/tensor_ops.mlir
@@ -64,6 +64,15 @@
// Dispatch workgroups to the externally defined function "main" in the
// referenced object files.
%0 = hal.dispatch.extern "main"[%x, %y](%arg0, %arg1, %arg2) : (tensor<4xi32>, tensor<8xi32>, i32) -> %arg1
+ // Translates the workload (%x and %y captured above) into an XYZ workgroup
+ // count, optionally using device information.
+ count(%device: !hal.device, %x_capture: index, %y_capture: index) -> (index, index, index) {
+ // Shows how device queries can be used when computing the workgroup count.
+ // The device is the one used at runtime.
+ %ok, %z_i32 = hal.device.query<%device : !hal.device> key("some" :: "value") : i1, i32
+ %z = arith.index_cast %z_i32 : i32 to index
+ hal.return %x_capture, %y_capture, %z : index, index, index
+ }
// Must match the external definition.
layout(#hal.pipeline.layout<push_constants = 1, sets = [
<0, bindings = [
@@ -81,14 +90,5 @@
#hal.executable.target<"llvm-cpu", "a"> = [#hal.executable.object<{path = "a.o"}>],
#hal.executable.target<"llvm-cpu", "b"> = [#hal.executable.object<{path = "b.o"}>]
}>)
- // Translates the workload (%x and %y captured above) into an XYZ workgroup
- // count, optionally using device information.
- count(%device: !hal.device, %x_capture: index, %y_capture: index) -> (index, index, index) {
- // Shows how device queries can be used when computing the workgroup count.
- // The device is the one used at runtime.
- %ok, %z_i32 = hal.device.query<%device : !hal.device> key("some" :: "value") : i1, i32
- %z = arith.index_cast %z_i32 : i32 to index
- hal.return %x_capture, %y_capture, %z : index, index, index
- }
return %0 : tensor<8xi32>
}
diff --git a/compiler/src/iree/compiler/Dialect/HAL/Target/BUILD.bazel b/compiler/src/iree/compiler/Dialect/HAL/Target/BUILD.bazel
index 0ba1199..655eb9a 100644
--- a/compiler/src/iree/compiler/Dialect/HAL/Target/BUILD.bazel
+++ b/compiler/src/iree/compiler/Dialect/HAL/Target/BUILD.bazel
@@ -25,7 +25,6 @@
deps = [
"//compiler/src/iree/compiler/Dialect/Flow/IR",
"//compiler/src/iree/compiler/Dialect/HAL/IR",
- "//compiler/src/iree/compiler/Dialect/HAL/Utils",
"//compiler/src/iree/compiler/Dialect/Util/IR",
"//compiler/src/iree/compiler/Utils",
"@llvm-project//llvm:Support",
diff --git a/compiler/src/iree/compiler/Dialect/HAL/Target/CMakeLists.txt b/compiler/src/iree/compiler/Dialect/HAL/Target/CMakeLists.txt
index 689783e..12aca6d 100644
--- a/compiler/src/iree/compiler/Dialect/HAL/Target/CMakeLists.txt
+++ b/compiler/src/iree/compiler/Dialect/HAL/Target/CMakeLists.txt
@@ -27,7 +27,6 @@
MLIRTransforms
iree::compiler::Dialect::Flow::IR
iree::compiler::Dialect::HAL::IR
- iree::compiler::Dialect::HAL::Utils
iree::compiler::Dialect::Util::IR
iree::compiler::Utils
PUBLIC
diff --git a/compiler/src/iree/compiler/Dialect/HAL/Target/ROCM/ROCMTarget.cpp b/compiler/src/iree/compiler/Dialect/HAL/Target/ROCM/ROCMTarget.cpp
index 88ee535..b72f31d 100644
--- a/compiler/src/iree/compiler/Dialect/HAL/Target/ROCM/ROCMTarget.cpp
+++ b/compiler/src/iree/compiler/Dialect/HAL/Target/ROCM/ROCMTarget.cpp
@@ -147,7 +147,7 @@
// Collect all the entry point names.
llvm::StringMap<IREE::HAL::ExecutableExportOp> exportOps;
- for (auto op : variantOp.getOps<IREE::HAL::ExecutableExportOp>()) {
+ for (auto op : variantOp.getExportOps()) {
exportOps[op.getSymName()] = op;
}
std::vector<std::array<int32_t, 3>> workgroupSizes;
diff --git a/compiler/src/iree/compiler/Dialect/HAL/Target/TargetBackend.h b/compiler/src/iree/compiler/Dialect/HAL/Target/TargetBackend.h
index 4c61dcb..33fc136 100644
--- a/compiler/src/iree/compiler/Dialect/HAL/Target/TargetBackend.h
+++ b/compiler/src/iree/compiler/Dialect/HAL/Target/TargetBackend.h
@@ -14,7 +14,6 @@
#include "iree/compiler/Dialect/Flow/IR/FlowOps.h"
#include "iree/compiler/Dialect/HAL/IR/HALOps.h"
-#include "iree/compiler/Dialect/HAL/Utils/DeviceSwitchBuilder.h"
#include "iree/compiler/Utils/OptionUtils.h"
#include "llvm/ADT/StringMap.h"
#include "llvm/ADT/StringRef.h"
diff --git a/compiler/src/iree/compiler/Dialect/HAL/Target/VulkanSPIRV/VulkanSPIRVTarget.cpp b/compiler/src/iree/compiler/Dialect/HAL/Target/VulkanSPIRV/VulkanSPIRVTarget.cpp
index 829dbef..51df7ca 100644
--- a/compiler/src/iree/compiler/Dialect/HAL/Target/VulkanSPIRV/VulkanSPIRVTarget.cpp
+++ b/compiler/src/iree/compiler/Dialect/HAL/Target/VulkanSPIRV/VulkanSPIRVTarget.cpp
@@ -250,7 +250,7 @@
// Take exported names verbatim for passing into VkShaderModuleCreateInfo.
SmallVector<StringRef, 8> entryPointNames;
- for (auto exportOp : variantOp.getOps<IREE::HAL::ExecutableExportOp>()) {
+ for (auto exportOp : variantOp.getExportOps()) {
entryPointNames.emplace_back(exportOp.getSymName());
}
diff --git a/compiler/src/iree/compiler/Dialect/HAL/Target/WebGPU/WebGPUTarget.cpp b/compiler/src/iree/compiler/Dialect/HAL/Target/WebGPU/WebGPUTarget.cpp
index 0905020..8018ca2 100644
--- a/compiler/src/iree/compiler/Dialect/HAL/Target/WebGPU/WebGPUTarget.cpp
+++ b/compiler/src/iree/compiler/Dialect/HAL/Target/WebGPU/WebGPUTarget.cpp
@@ -148,8 +148,7 @@
// For each executable entry point op, rename the entry point symbol using
// that convention and keep track of the mapping between entry point
// ordinals to which shader module they reference.
- auto exportOps =
- llvm::to_vector(variantOp.getOps<IREE::HAL::ExecutableExportOp>());
+ auto exportOps = llvm::to_vector(variantOp.getExportOps());
llvm::SmallVector<uint32_t> entryPointOrdinals(exportOps.size());
SymbolTableCollection symbolTable;
SymbolUserMap symbolUsers(symbolTable, variantOp);
diff --git a/compiler/src/iree/compiler/Dialect/HAL/Transforms/BUILD.bazel b/compiler/src/iree/compiler/Dialect/HAL/Transforms/BUILD.bazel
index f511949..6b69403 100644
--- a/compiler/src/iree/compiler/Dialect/HAL/Transforms/BUILD.bazel
+++ b/compiler/src/iree/compiler/Dialect/HAL/Transforms/BUILD.bazel
@@ -22,7 +22,6 @@
"DumpExecutableSources.cpp",
"ElideRedundantCommands.cpp",
"FixupLegacySync.cpp",
- "InlineDeviceSwitches.cpp",
"LinkExecutables.cpp",
"MaterializeDispatchInstrumentation.cpp",
"MaterializeInterfaces.cpp",
@@ -52,7 +51,6 @@
"//compiler/src/iree/compiler/Dialect/HAL/IR",
"//compiler/src/iree/compiler/Dialect/HAL/IR:HALDialect",
"//compiler/src/iree/compiler/Dialect/HAL/Target",
- "//compiler/src/iree/compiler/Dialect/HAL/Utils",
"//compiler/src/iree/compiler/Dialect/Stream/IR",
"//compiler/src/iree/compiler/Dialect/Stream/Transforms",
"//compiler/src/iree/compiler/Dialect/Util/Conversion",
diff --git a/compiler/src/iree/compiler/Dialect/HAL/Transforms/CMakeLists.txt b/compiler/src/iree/compiler/Dialect/HAL/Transforms/CMakeLists.txt
index 703768b..81bee1b 100644
--- a/compiler/src/iree/compiler/Dialect/HAL/Transforms/CMakeLists.txt
+++ b/compiler/src/iree/compiler/Dialect/HAL/Transforms/CMakeLists.txt
@@ -23,7 +23,6 @@
"DumpExecutableSources.cpp"
"ElideRedundantCommands.cpp"
"FixupLegacySync.cpp"
- "InlineDeviceSwitches.cpp"
"LinkExecutables.cpp"
"MaterializeDispatchInstrumentation.cpp"
"MaterializeInterfaces.cpp"
@@ -63,7 +62,6 @@
iree::compiler::Dialect::HAL::IR
iree::compiler::Dialect::HAL::IR::HALDialect
iree::compiler::Dialect::HAL::Target
- iree::compiler::Dialect::HAL::Utils
iree::compiler::Dialect::Stream::IR
iree::compiler::Dialect::Stream::Transforms
iree::compiler::Dialect::Util::Conversion
diff --git a/compiler/src/iree/compiler/Dialect/HAL/Transforms/DumpExecutableBenchmarks.cpp b/compiler/src/iree/compiler/Dialect/HAL/Transforms/DumpExecutableBenchmarks.cpp
index 4bad200..1496e28 100644
--- a/compiler/src/iree/compiler/Dialect/HAL/Transforms/DumpExecutableBenchmarks.cpp
+++ b/compiler/src/iree/compiler/Dialect/HAL/Transforms/DumpExecutableBenchmarks.cpp
@@ -408,7 +408,7 @@
// Add functions to test each entry point with its various dispatch
// parameters.
bool hasAnyBenchmarks = false;
- for (auto exportOp : variantOp.getOps<IREE::HAL::ExecutableExportOp>()) {
+ for (auto exportOp : variantOp.getExportOps()) {
auto symbolRefAttr =
SymbolRefAttr::get(executableOp.getNameAttr(),
{
diff --git a/compiler/src/iree/compiler/Dialect/HAL/Transforms/InlineDeviceSwitches.cpp b/compiler/src/iree/compiler/Dialect/HAL/Transforms/InlineDeviceSwitches.cpp
deleted file mode 100644
index a40e7b8..0000000
--- a/compiler/src/iree/compiler/Dialect/HAL/Transforms/InlineDeviceSwitches.cpp
+++ /dev/null
@@ -1,175 +0,0 @@
-// Copyright 2020 The IREE Authors
-//
-// Licensed under the Apache License v2.0 with LLVM Exceptions.
-// See https://llvm.org/LICENSE.txt for license information.
-// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
-
-#include <utility>
-
-#include "iree/compiler/Dialect/HAL/IR/HALOps.h"
-#include "iree/compiler/Dialect/HAL/Transforms/Passes.h"
-#include "iree/compiler/Dialect/Util/IR/UtilDialect.h"
-#include "iree/compiler/Dialect/Util/IR/UtilOps.h"
-#include "llvm/ADT/StringSet.h"
-#include "mlir/Dialect/ControlFlow/IR/ControlFlowOps.h"
-#include "mlir/IR/Attributes.h"
-#include "mlir/IR/Builders.h"
-#include "mlir/IR/BuiltinTypes.h"
-#include "mlir/IR/Diagnostics.h"
-#include "mlir/IR/IRMapping.h"
-#include "mlir/Pass/Pass.h"
-
-namespace mlir {
-namespace iree_compiler {
-namespace IREE {
-namespace HAL {
-
-// Inlines a condition region from a switch op into the function at the given
-// point. This assumes that the insertion point will only be reached if the
-// condition the region is predicated on is true.
-static void inlineConditionRegion(Region &conditionRegion, Block *exitBlock,
- OpBuilder funcBuilder) {
- assert(!conditionRegion.empty() && "source regions must not be empty");
- assert(conditionRegion.front().getNumArguments() == 0 &&
- "switch does not capture");
-
- // Splice in the region blocks.
- auto *insertBlock = funcBuilder.getBlock();
- auto postInsertBlockIt = std::next(insertBlock->getIterator())->getIterator();
- auto *insertRegion = insertBlock->getParent();
- insertRegion->getBlocks().splice(postInsertBlockIt,
- conditionRegion.getBlocks());
- auto newBlocks = llvm::make_range(std::next(insertBlock->getIterator()),
- postInsertBlockIt);
- auto *firstNewBlock = &*newBlocks.begin();
-
- // Handle the hal.return ops which will transfer control to the exitBlock.
- for (auto &newBlock : newBlocks) {
- if (auto returnOp =
- dyn_cast<IREE::HAL::ReturnOp>(newBlock.getTerminator())) {
- OpBuilder branchBuilder(returnOp);
- branchBuilder.create<cf::BranchOp>(returnOp.getLoc(), exitBlock,
- returnOp.getOperands());
- returnOp.erase();
- }
- }
-
- // Splice the instructions of the inlined entry block into the insert block.
- insertBlock->getOperations().splice(insertBlock->end(),
- firstNewBlock->getOperations());
- firstNewBlock->erase();
-}
-
-// Inlines each switch condition region into the parent function predicated on
-// the switch condition expression.
-//
-// Since switch conditions are evaluated in the order they are defined we can
-// trivially turn the switch into a chain of if-else blocks.
-// if condition_0_match:
-// <inlined condition_0>
-// else
-// if condition_1_match:
-// <inlined condition_1>
-// else ...
-static void buildConditionDispatchTable(IREE::HAL::DeviceSwitchOp switchOp,
- OpBuilder funcBuilder) {
- // Split the block containing the switch op such that all ops before the
- // switch are before and the switch and the following ops are after.
- // We'll have all of our inlined regions bounce over to the afterBlock with
- // the results of the call and use that to replace the switch op.
- auto *beforeBlock = funcBuilder.getBlock();
- auto *afterBlock = beforeBlock->splitBlock(switchOp);
- SmallVector<Location> locs(switchOp.getNumResults(), switchOp.getLoc());
- auto finalValues = llvm::to_vector(
- afterBlock->addArguments(switchOp.getResultTypes(), locs));
-
- // Create the blocks we'll use for all our conditions so that we can
- // reference them when inserting the branch ops.
- SmallVector<Block *> conditionMatchBlocks(
- switchOp.getConditionRegions().size());
- SmallVector<Block *> conditionFallthroughBlocks(
- switchOp.getConditionRegions().size());
- for (int i = 0; i < conditionMatchBlocks.size(); ++i) {
- conditionMatchBlocks[i] = funcBuilder.createBlock(afterBlock);
- conditionFallthroughBlocks[i] = funcBuilder.createBlock(afterBlock);
- }
-
- funcBuilder.setInsertionPoint(beforeBlock, beforeBlock->end());
- for (auto condition :
- llvm::enumerate(llvm::zip_equal(switchOp.getConditions().getValue(),
- switchOp.getConditionRegions()))) {
- auto conditionAttr = llvm::cast<IREE::HAL::MatchAttrInterface>(
- std::get<0>(condition.value()));
- auto &conditionRegion = std::get<1>(condition.value());
-
- // Insert the branch based on the match. We either match and jump to a
- // block that will contain the inlined region or don't match and need to
- // fall through.
- auto isMatch = conditionAttr.buildConditionExpression(
- switchOp.getLoc(), switchOp.getDevice(), funcBuilder);
- auto *matchBlock = conditionMatchBlocks[condition.index()];
- auto *fallthroughBlock = conditionFallthroughBlocks[condition.index()];
- funcBuilder.create<cf::CondBranchOp>(switchOp.getLoc(), isMatch, matchBlock,
- fallthroughBlock);
-
- // Block that contains the inlined region and then jumps out of the chain.
- funcBuilder.setInsertionPointToStart(matchBlock);
- inlineConditionRegion(conditionRegion, afterBlock, funcBuilder);
-
- // Block that we enter to check the next condition.
- funcBuilder.setInsertionPointToStart(fallthroughBlock);
- if (condition.index() + 1 < conditionFallthroughBlocks.size()) {
- // Just continue on - the next loop iteration for the following
- // condition will add its IR to the block.
- } else {
- // Fallthrough of all expressions; die if we expected return values.
- funcBuilder.create<IREE::Util::UnreachableOp>(
- switchOp.getLoc(),
- "device not supported in the compiled configuration");
- }
- }
-
- // Remove the switch op and replace its results with the final joined
- // results.
- switchOp.replaceAllUsesWith(finalValues);
-}
-
-class InlineDeviceSwitchesPass
- : public PassWrapper<InlineDeviceSwitchesPass, OperationPass<void>> {
-public:
- void getDependentDialects(DialectRegistry &registry) const override {
- registry.insert<IREE::Util::UtilDialect>();
- }
-
- StringRef getArgument() const override {
- return "iree-hal-inline-device-switches";
- }
-
- StringRef getDescription() const override {
- return "Inlines hal.device.switch condition regions";
- }
-
- void runOnOperation() override {
- auto funcOp = getOperation();
- SmallVector<IREE::HAL::DeviceSwitchOp> switchOps;
- funcOp->walk([&](IREE::HAL::DeviceSwitchOp switchOp) {
- switchOps.push_back(switchOp);
- });
- for (auto switchOp : switchOps) {
- OpBuilder funcBuilder(switchOp);
- buildConditionDispatchTable(switchOp, funcBuilder);
- switchOp.erase();
- }
- }
-};
-
-std::unique_ptr<OperationPass<void>> createInlineDeviceSwitchesPass() {
- return std::make_unique<InlineDeviceSwitchesPass>();
-}
-
-static PassRegistration<InlineDeviceSwitchesPass> pass;
-
-} // namespace HAL
-} // namespace IREE
-} // namespace iree_compiler
-} // namespace mlir
diff --git a/compiler/src/iree/compiler/Dialect/HAL/Transforms/MaterializeInterfaces.cpp b/compiler/src/iree/compiler/Dialect/HAL/Transforms/MaterializeInterfaces.cpp
index 63cf4b7..6979939 100644
--- a/compiler/src/iree/compiler/Dialect/HAL/Transforms/MaterializeInterfaces.cpp
+++ b/compiler/src/iree/compiler/Dialect/HAL/Transforms/MaterializeInterfaces.cpp
@@ -35,7 +35,7 @@
namespace {
// Map of original SymbolRefAttr to a list of SymbolRefAttrs in variants.
-using EntryPointExpansions = DenseMap<Attribute, SmallVector<Attribute>>;
+using ExportExpansions = DenseMap<Attribute, SmallVector<Attribute>>;
//===----------------------------------------------------------------------===//
// Linkage utilities
@@ -70,7 +70,7 @@
static LogicalResult materializeExecutableFromSourceOp(
IREE::HAL::ExecutableSourceOp sourceOp,
ArrayRef<IREE::HAL::ExecutableTargetAttr> targetAttrs,
- EntryPointExpansions &entryPointExpansions) {
+ ExportExpansions &exportExpansions) {
OpBuilder moduleBuilder(sourceOp);
// Create the op that will contain the translated executable.
@@ -80,7 +80,7 @@
// With this hand-authored path all variants have the same layout and entry
// points and we can just clone them.
- auto sourceEntryPointOps = sourceOp.getOps<IREE::HAL::ExecutableExportOp>();
+ auto sourceExportOps = sourceOp.getExportOps();
// Materialize all of the hal.executable.variant ops for all backends we are
// targeting.
@@ -92,16 +92,15 @@
sourceOp->getLoc(), targetAttr.getSymbolNameFragment(), targetAttr);
targetSymbolTable.insert(targetVariantOp);
OpBuilder variantBuilder(&targetVariantOp.getBlock().back());
- for (auto sourceEntryPointOp : sourceEntryPointOps) {
- variantBuilder.clone(*sourceEntryPointOp);
+ for (auto sourceExportOp : sourceExportOps) {
+ variantBuilder.clone(*sourceExportOp);
// Map the original export names to the new variant exports.
- entryPointExpansions[SymbolRefAttr::get(
- executableOp.getNameAttr(),
- {FlatSymbolRefAttr::get(
- sourceEntryPointOp.getNameAttr())})]
+ exportExpansions[SymbolRefAttr::get(executableOp.getNameAttr(),
+ {FlatSymbolRefAttr::get(
+ sourceExportOp.getNameAttr())})]
.push_back(makeExportSymbolRefAttr(executableOp, targetVariantOp,
- sourceEntryPointOp));
+ sourceExportOp));
}
// Clone any target-specific object files specified.
@@ -124,8 +123,9 @@
return success();
}
-static LogicalResult materializeExecutablesFromSourceOps(
- mlir::ModuleOp moduleOp, EntryPointExpansions &entryPointExpansions) {
+static LogicalResult
+materializeExecutablesFromSourceOps(mlir::ModuleOp moduleOp,
+ ExportExpansions &exportExpansions) {
auto sourceOps =
llvm::to_vector<32>(moduleOp.getOps<IREE::HAL::ExecutableSourceOp>());
for (auto sourceOp : sourceOps) {
@@ -139,7 +139,7 @@
}
if (failed(materializeExecutableFromSourceOp(sourceOp, targetAttrs,
- entryPointExpansions))) {
+ exportExpansions))) {
return failure();
}
}
@@ -271,14 +271,13 @@
}
// Updates the target entry point symbols of |dispatchOp| to the expanded set of
-// variant exports in |entryPointExpansions|.
-static void
-updateDispatchTargets(IREE::Stream::CmdDispatchOp dispatchOp,
- const EntryPointExpansions &entryPointExpansions) {
+// variant exports in |exportExpansions|.
+static void updateDispatchTargets(IREE::Stream::CmdDispatchOp dispatchOp,
+ const ExportExpansions &exportExpansions) {
SmallVector<Attribute> newAttrs;
for (auto oldAttr : dispatchOp.getEntryPointRefs()) {
- auto it = entryPointExpansions.find(oldAttr);
- if (it == entryPointExpansions.end()) {
+ auto it = exportExpansions.find(oldAttr);
+ if (it == exportExpansions.end()) {
newAttrs.push_back(oldAttr); // preserve existing
continue;
}
@@ -313,7 +312,7 @@
declareEntryPointOps(IREE::Stream::ExecutableOp sourceExecutableOp,
IREE::HAL::ExecutableOp targetExecutableOp,
const BindingLayoutAnalysis &layoutAnalysis,
- EntryPointExpansions &entryPointExpansions) {
+ ExportExpansions &exportExpansions) {
auto variantOps =
targetExecutableOp.getBlock().getOps<IREE::HAL::ExecutableVariantOp>();
OpBuilder executableBuilder(&targetExecutableOp.getBlock().front());
@@ -387,9 +386,9 @@
/*workgroup_local_memory=*/IntegerAttr{});
// Map the original export name to the new variant export.
- entryPointExpansions[SymbolRefAttr::get(sourceExecutableOp.getNameAttr(),
- {FlatSymbolRefAttr::get(
- exportOp.getNameAttr())})]
+ exportExpansions[SymbolRefAttr::get(
+ sourceExecutableOp.getNameAttr(),
+ {FlatSymbolRefAttr::get(exportOp.getNameAttr())})]
.push_back(makeExportSymbolRefAttr(targetExecutableOp, variantOp,
newExportOp));
@@ -537,12 +536,12 @@
void runOnOperation() override {
SymbolTable symbolTable(getOperation());
- EntryPointExpansions entryPointExpansions;
+ ExportExpansions exportExpansions;
// Handle any hand-authored executables; these only need variant expansion
// and no layout analysis as the user specified the layout themselves.
if (failed(materializeExecutablesFromSourceOps(getOperation(),
- entryPointExpansions))) {
+ exportExpansions))) {
return signalPassFailure();
}
@@ -595,7 +594,7 @@
// Define interfaces for each exported function based on analysis.
if (failed(declareEntryPointOps(sourceOp, executableOp, layoutAnalysis,
- entryPointExpansions))) {
+ exportExpansions))) {
return signalPassFailure();
}
@@ -615,7 +614,7 @@
// pipeline layout, though, and any that fall through are errors.
auto updateDispatchSites = [&](IREE::Stream::CmdDispatchOp dispatchOp) {
// Update the export targets to point at the new variants.
- updateDispatchTargets(dispatchOp, entryPointExpansions);
+ updateDispatchTargets(dispatchOp, exportExpansions);
// Annotate the dispatch site with binding information if required.
// TODO(benvanik): remove this path; shouldn't be needed in real usage.
diff --git a/compiler/src/iree/compiler/Dialect/HAL/Transforms/MaterializeResourceCaches.cpp b/compiler/src/iree/compiler/Dialect/HAL/Transforms/MaterializeResourceCaches.cpp
index e962542..249ff78 100644
--- a/compiler/src/iree/compiler/Dialect/HAL/Transforms/MaterializeResourceCaches.cpp
+++ b/compiler/src/iree/compiler/Dialect/HAL/Transforms/MaterializeResourceCaches.cpp
@@ -10,12 +10,10 @@
#include "iree/compiler/Dialect/HAL/IR/HALDialect.h"
#include "iree/compiler/Dialect/HAL/IR/HALOps.h"
#include "iree/compiler/Dialect/HAL/Transforms/Passes.h"
-#include "iree/compiler/Dialect/HAL/Utils/DeviceSwitchBuilder.h"
#include "iree/compiler/Dialect/Util/IR/UtilOps.h"
#include "mlir/Dialect/Arith/IR/Arith.h"
-#include "mlir/Dialect/ControlFlow/IR/ControlFlow.h"
-#include "mlir/Dialect/ControlFlow/IR/ControlFlowOps.h"
#include "mlir/Dialect/Func/IR/FuncOps.h"
+#include "mlir/Dialect/SCF/IR/SCF.h"
#include "mlir/IR/Attributes.h"
#include "mlir/IR/Builders.h"
#include "mlir/IR/BuiltinTypes.h"
@@ -44,7 +42,7 @@
void getDependentDialects(DialectRegistry &registry) const override {
registry.insert<mlir::arith::ArithDialect>();
- registry.insert<mlir::cf::ControlFlowDialect>();
+ registry.insert<mlir::scf::SCFDialect>();
registry.insert<IREE::HAL::HALDialect>();
}
@@ -91,8 +89,7 @@
for (auto executableOp : executableOps) {
for (auto variantOp :
executableOp.getOps<IREE::HAL::ExecutableVariantOp>()) {
- for (auto exportOp :
- variantOp.getOps<IREE::HAL::ExecutableExportOp>()) {
+ for (auto exportOp : variantOp.getExportOps()) {
definePipelineLayoutOp(exportOp.getLoc(), exportOp.getLayout());
}
}
@@ -226,21 +223,33 @@
// Each case should then cache only executables which contain a matching
// ExecutableVariantOp.
// Afterwards, canonicalization will take care of de-duping/etc.
- DeviceSwitchBuilder switchBuilder(loc,
- /*resultTypes=*/TypeRange{executableType},
- deviceValue, blockBuilder);
- for (auto executableVariantOp :
+ SmallVector<int64_t> caseIndices;
+ SmallVector<IREE::HAL::ExecutableVariantOp> caseVariantOps;
+ for (auto variantOp :
executableOp.getOps<IREE::HAL::ExecutableVariantOp>()) {
- auto *region = switchBuilder.addConditionRegion(
- executableVariantOp.getTarget().getMatchExpression());
- auto &entryBlock = region->front();
- auto caseBuilder = OpBuilder::atBlockBegin(&entryBlock);
+ caseIndices.push_back(caseIndices.size());
+ caseVariantOps.push_back(variantOp);
+ }
+
+ // Select the variant index.
+ Value selectedIndex = buildIfElseTree(
+ loc, caseVariantOps.size(),
+ [&](Location loc, size_t i, OpBuilder &builder) {
+ return caseVariantOps[i].buildCondition(deviceValue, builder);
+ },
+ blockBuilder);
+
+ // Allow each variant to define how it is loaded and what pipeline it has.
+ auto switchOp = blockBuilder.create<scf::IndexSwitchOp>(
+ loc, executableType, selectedIndex, caseIndices, caseIndices.size());
+ for (auto [i, variantOp] : llvm::enumerate(caseVariantOps)) {
+ auto &caseBlock = switchOp.getCaseRegions()[i].emplaceBlock();
+ auto caseBuilder = OpBuilder::atBlockBegin(&caseBlock);
// Gather each of the pipeline layouts needed for each entry point in
// the executable.
SmallVector<Value, 8> pipelineLayoutValues;
- for (auto exportOp :
- executableVariantOp.getOps<IREE::HAL::ExecutableExportOp>()) {
+ for (auto exportOp : variantOp.getExportOps()) {
auto pipelineLayoutGlobalOp =
definePipelineLayoutOp(executableOp.getLoc(), exportOp.getLayout());
pipelineLayoutValues.push_back(
@@ -253,32 +262,29 @@
// We want these to all happen inside of this device switch case; they'll
// get deduplicated/hoisted if possible in future canonicalization passes.
SmallVector<Value> constantValues;
- for (auto blockOp : llvm::make_early_inc_range(
- executableVariantOp
- .getOps<IREE::HAL::ExecutableConstantBlockOp>())) {
+ for (auto blockOp :
+ llvm::make_early_inc_range(variantOp.getConstantBlockOps())) {
constantValues.append(inlineConstantBlockOp(blockOp, moduleBuilder,
caseBuilder, deviceValue));
blockOp.erase();
}
auto executableValue = caseBuilder.createOrFold<ExecutableCreateOp>(
- loc, ExecutableType::get(loc.getContext()), deviceValue,
- SymbolRefAttr::get(
- executableOp.getSymNameAttr(),
- {SymbolRefAttr::get(executableVariantOp.getSymNameAttr())}),
+ loc, executableType, deviceValue,
+ SymbolRefAttr::get(executableOp.getSymNameAttr(),
+ {SymbolRefAttr::get(variantOp.getSymNameAttr())}),
pipelineLayoutValues, constantValues);
- caseBuilder.create<IREE::HAL::ReturnOp>(loc, executableValue);
+ caseBuilder.create<scf::YieldOp>(loc, executableValue);
}
- auto *defaultRegion = switchBuilder.addConditionRegion(
- IREE::HAL::MatchAlwaysAttr::get(loc.getContext()));
- auto defaultBuilder = OpBuilder::atBlockBegin(&defaultRegion->front());
+ // Fallback for no available variant.
+ auto &defaultBlock = switchOp.getDefaultRegion().emplaceBlock();
+ auto defaultBuilder = OpBuilder::atBlockBegin(&defaultBlock);
auto nullValue =
defaultBuilder.createOrFold<IREE::Util::NullOp>(loc, executableType);
- defaultBuilder.create<IREE::HAL::ReturnOp>(loc, nullValue);
+ defaultBuilder.create<scf::YieldOp>(loc, nullValue);
- auto switchOp = switchBuilder.build();
auto executableValue = switchOp.getResult(0);
blockBuilder.create<IREE::Util::GlobalStoreOp>(loc, executableValue,
globalOp.getName());
diff --git a/compiler/src/iree/compiler/Dialect/HAL/Transforms/Passes.cpp b/compiler/src/iree/compiler/Dialect/HAL/Transforms/Passes.cpp
index abc6758..944ae53 100644
--- a/compiler/src/iree/compiler/Dialect/HAL/Transforms/Passes.cpp
+++ b/compiler/src/iree/compiler/Dialect/HAL/Transforms/Passes.cpp
@@ -318,10 +318,6 @@
// Device management and specialization
//----------------------------------------------------------------------------
- // Inline hal.device.switch ops and memoize their queries such that we can
- // better CSE/fold dispatch logic.
- FunctionLikeNest(passManager).addPass(createInlineDeviceSwitchesPass);
-
// Memoize device queries such that we don't need to repeatedly ask the same
// information at runtime.
passManager.addPass(createMemoizeDeviceQueriesPass());
diff --git a/compiler/src/iree/compiler/Dialect/HAL/Transforms/Passes.h b/compiler/src/iree/compiler/Dialect/HAL/Transforms/Passes.h
index 6b9146f..0c0cd24 100644
--- a/compiler/src/iree/compiler/Dialect/HAL/Transforms/Passes.h
+++ b/compiler/src/iree/compiler/Dialect/HAL/Transforms/Passes.h
@@ -90,9 +90,6 @@
// removed.
std::unique_ptr<OperationPass<mlir::ModuleOp>> createFixupLegacySyncPass();
-// Outlines hal.device.switch conditions into functions and inlines conditions.
-std::unique_ptr<OperationPass<void>> createInlineDeviceSwitchesPass();
-
// Finds hal.device.query ops and creates variables initialized on startup.
std::unique_ptr<OperationPass<mlir::ModuleOp>> createMemoizeDeviceQueriesPass();
@@ -208,7 +205,6 @@
createConvertToHALPass();
createDumpExecutableSourcesPass("");
createElideRedundantCommandsPass();
- createInlineDeviceSwitchesPass();
createFixupLegacySyncPass();
createLinkExecutablesPass(TargetBackendRegistry::getGlobal());
createLinkTargetExecutablesPass(TargetBackendRegistry::getGlobal(), "");
diff --git a/compiler/src/iree/compiler/Dialect/HAL/Transforms/SubstituteExecutables.cpp b/compiler/src/iree/compiler/Dialect/HAL/Transforms/SubstituteExecutables.cpp
index d7ce7f5..492799a 100644
--- a/compiler/src/iree/compiler/Dialect/HAL/Transforms/SubstituteExecutables.cpp
+++ b/compiler/src/iree/compiler/Dialect/HAL/Transforms/SubstituteExecutables.cpp
@@ -180,8 +180,8 @@
variantOp.setObjectsAttr(builder.getArrayAttr({dataObjectAttr}));
// Drop the inner module if present (may already be external).
- for (auto moduleOp :
- llvm::make_early_inc_range(variantOp.getOps<mlir::ModuleOp>())) {
+ for (auto moduleOp : llvm::make_early_inc_range(
+ variantOp.getBody().getOps<mlir::ModuleOp>())) {
moduleOp.erase();
}
diff --git a/compiler/src/iree/compiler/Dialect/HAL/Transforms/test/BUILD.bazel b/compiler/src/iree/compiler/Dialect/HAL/Transforms/test/BUILD.bazel
index 623f210..d9b760a 100644
--- a/compiler/src/iree/compiler/Dialect/HAL/Transforms/test/BUILD.bazel
+++ b/compiler/src/iree/compiler/Dialect/HAL/Transforms/test/BUILD.bazel
@@ -23,7 +23,6 @@
"dump_executable_sources.mlir",
"elide_redundant_commands.mlir",
"fixup_legacy_sync.mlir",
- "inline_device_switches.mlir",
"materialize_dispatch_instrumentation.mlir",
"materialize_interfaces.mlir",
"materialize_resource_caches.mlir",
diff --git a/compiler/src/iree/compiler/Dialect/HAL/Transforms/test/CMakeLists.txt b/compiler/src/iree/compiler/Dialect/HAL/Transforms/test/CMakeLists.txt
index 6c2f342..7505915 100644
--- a/compiler/src/iree/compiler/Dialect/HAL/Transforms/test/CMakeLists.txt
+++ b/compiler/src/iree/compiler/Dialect/HAL/Transforms/test/CMakeLists.txt
@@ -21,7 +21,6 @@
"dump_executable_sources.mlir"
"elide_redundant_commands.mlir"
"fixup_legacy_sync.mlir"
- "inline_device_switches.mlir"
"materialize_dispatch_instrumentation.mlir"
"materialize_interfaces.mlir"
"materialize_resource_caches.mlir"
diff --git a/compiler/src/iree/compiler/Dialect/HAL/Transforms/test/convert_to_hal.mlir b/compiler/src/iree/compiler/Dialect/HAL/Transforms/test/convert_to_hal.mlir
index 84e7e27..1b9e70a 100644
--- a/compiler/src/iree/compiler/Dialect/HAL/Transforms/test/convert_to_hal.mlir
+++ b/compiler/src/iree/compiler/Dialect/HAL/Transforms/test/convert_to_hal.mlir
@@ -94,8 +94,12 @@
%arg1_resource as %arg1_capture: !stream.resource<external>{%c16},
%result_resource as %result_capture: !stream.resource<external>{%c16}) {
- // CHECK: hal.device.switch<%[[DEVICE]] : !hal.device>
- // CHECK: #hal.device.match.executable.format<"embedded-elf-x86_64"> {
+ // CHECK-DAG: %{{.+}}, %[[FORMAT_AARCH64:.+]] = hal.device.query<%[[DEVICE]] : !hal.device> key("hal.executable.format" :: "embedded-elf-aarch64")
+ // CHECK-DAG: %{{.+}}, %[[FORMAT_X86_64:.+]] = hal.device.query<%[[DEVICE]] : !hal.device> key("hal.executable.format" :: "embedded-elf-x86_64")
+ // CHECK-DAG: %[[SWITCH1:.+]] = arith.select %[[FORMAT_X86_64]], %c1, %c-1
+ // CHECK-DAG: %[[SWITCH0:.+]] = arith.select %[[FORMAT_AARCH64]], %c0, %[[SWITCH1]]
+ // CHECK: scf.index_switch %[[SWITCH0]]
+ // CHECK: case 0 {
// CHECK: %[[PIPELINE_LAYOUT:.+]] = hal.pipeline_layout.lookup
// CHECK-SAME: device(%[[DEVICE]] : !hal.device)
// CHECK-SAME: layout(#pipeline_layout) : !hal.pipeline_layout
@@ -107,9 +111,14 @@
// CHECK: %c2 = (%[[RESULT_BUFFER]] : !hal.buffer)[%c0, %c16]
// CHECK: ])
// CHECK: hal.command_buffer.dispatch.symbol<%[[CMD]] : !hal.command_buffer>
- // CHECK-SAME: target(@ex::@embedded_elf_x86_64::@dispatch)
+ // CHECK-SAME: target(@ex::@embedded_elf_aarch64::@dispatch)
// CHECK-SAME: workgroups([%c1, %c1, %c1])
- // CHECK: hal.return
+ // CHECK: scf.yield
+ // CHECK: }
+ // CHECK: case 1 {
+ // CHECK: hal.command_buffer.dispatch.symbol<%[[CMD]] : !hal.command_buffer>
+ // CHECK-SAME: target(@ex::@embedded_elf_x86_64::@dispatch)
+ // CHECK: scf.yield
// CHECK: }
stream.cmd.dispatch {@ex::@embedded_elf_aarch64::@dispatch, @ex::@embedded_elf_x86_64::@dispatch}[%c4, %c1, %c1] {
ro %arg0_capture[%c0 for %c16] : !stream.resource<external>{%c16},
diff --git a/compiler/src/iree/compiler/Dialect/HAL/Transforms/test/inline_device_switches.mlir b/compiler/src/iree/compiler/Dialect/HAL/Transforms/test/inline_device_switches.mlir
deleted file mode 100644
index 7a06975..0000000
--- a/compiler/src/iree/compiler/Dialect/HAL/Transforms/test/inline_device_switches.mlir
+++ /dev/null
@@ -1,84 +0,0 @@
-// RUN: iree-opt --allow-unregistered-dialect --split-input-file --iree-hal-inline-device-switches --canonicalize %s | FileCheck %s
-
-// CHECK-LABEL: @simple_constants
-// CHECK-SAME: %[[DEVICE:.+]]: !hal.device
-// CHECK-SAME: %[[ARG:.+]]: i32
-func.func @simple_constants(%device : !hal.device, %arg : i32) -> i32 {
- // CHECK-DAG: %[[C0:.+]] = arith.constant 0
- %c0 = arith.constant 0 : i32
- // CHECK-DAG: %[[C1:.+]] = arith.constant 1
- %c1 = arith.constant 1 : i32
- // CHECK-DAG: %[[C2:.+]] = arith.constant 2
- %c2 = arith.constant 2 : i32
- // CHECK-DAG: %[[C3:.+]] = arith.constant 3
- // CHECK-DAG: %[[C4:.+]] = arith.constant 4
- %0 = hal.device.switch<%device : !hal.device> -> i32
- // CHECK-NEXT: %{{.+}}, %[[IS0:.+]] = hal.device.query<%[[DEVICE]] : !hal.device> key("hal.device.id" :: "vulkan-v1.?-*") : i1, i1 = false
- // CHECK-NEXT: cf.cond_br %[[IS0]], ^bb3(%[[C1]] : i32), ^bb1
- #hal.device.match.id<"vulkan-v1.?-*"> {
- hal.return %c1 : i32
- },
- // CHECK-NEXT: ^bb1:
- // CHECK-NEXT: %{{.+}}, %[[IS1L:.+]] = hal.device.query<%[[DEVICE]] : !hal.device> key("hal.device.id" :: "vmvx") : i1, i1 = false
- // CHECK-NEXT: %{{.+}}, %[[IS1R:.+]] = hal.device.query<%[[DEVICE]] : !hal.device> key("hal.device.id" :: "vulkan-*") : i1, i1 = false
- // CHECK-NEXT: %[[IS1:.+]] = arith.ori %[[IS1L]], %[[IS1R]] : i1
- // CHECK-NEXT: cf.cond_br %[[IS1]], ^bb2, ^bb3(%[[C0]] : i32)
- // CHECK-NEXT: ^bb2:
- // CHECK-NEXT: %[[EQZ:.+]] = arith.cmpi eq, %[[ARG]], %[[C2]] : i32
- // CHECK-NEXT: cf.cond_br %[[EQZ]], ^bb3(%[[C3]] : i32), ^bb3(%[[C4]] : i32)
- #hal.match.any<[#hal.device.match.id<"vmvx">, #hal.device.match.id<"vulkan-*">]> {
- %eqz = arith.cmpi eq, %arg, %c2 : i32
- cf.cond_br %eqz, ^bb_true, ^bb_false
- ^bb_true:
- %c3 = arith.constant 3 : i32
- hal.return %c3 : i32
- ^bb_false:
- %c4 = arith.constant 4 : i32
- hal.return %c4 : i32
- },
- #hal.match.always {
- hal.return %c0 : i32
- }
- // CHECK-NEXT: ^bb3(%[[RES:.+]]: i32):
- // CHECK-NEXT: return %[[RES]] : i32
- return %0 : i32
-}
-
-// -----
-
-// CHECK-LABEL: @no_results
-// CHECK-SAME: %[[DEVICE:.+]]: !hal.device
-func.func @no_results(%device : !hal.device) {
- hal.device.switch<%device : !hal.device>
- // CHECK-NEXT: %{{.+}}, %[[IS0:.+]] = hal.device.query<%[[DEVICE]] : !hal.device> key("hal.device.id" :: "vulkan-v1.?-*") : i1, i1 = false
- // CHECK-NEXT: cf.cond_br %[[IS0]], ^bb1, ^bb2
- // CHECK-NEXT: ^bb1:
- // CHECK-NEXT: "some.op_a"()
- // CHECK-NEXT: cf.br ^bb5
- #hal.device.match.id<"vulkan-v1.?-*"> {
- "some.op_a"() : () -> ()
- hal.return
- },
- // CHECK-NEXT: ^bb2:
- // CHECK-NEXT: %{{.+}}, %[[IS1L:.+]] = hal.device.query<%[[DEVICE]] : !hal.device> key("hal.device.id" :: "vmvx") : i1, i1 = false
- // CHECK-NEXT: %{{.+}}, %[[IS1R:.+]] = hal.device.query<%[[DEVICE]] : !hal.device> key("hal.device.id" :: "vulkan-*") : i1, i1 = false
- // CHECK-NEXT: %[[IS1:.+]] = arith.ori %[[IS1L]], %[[IS1R]] : i1
- // CHECK-NEXT: cf.cond_br %[[IS1]], ^bb3, ^bb4
- // CHECK-NEXT: ^bb3:
- // CHECK-NEXT: "some.op_b"()
- // CHECK-NEXT: cf.br ^bb5
- #hal.match.any<[#hal.device.match.id<"vmvx">, #hal.device.match.id<"vulkan-*">]> {
- "some.op_b"() : () -> ()
- hal.return
- },
- // CHECK-NEXT: ^bb4:
- // CHECK-NEXT: "some.op_c"()
- // CHECK-NEXT: cf.br ^bb5
- #hal.match.always {
- "some.op_c"() : () -> ()
- hal.return
- }
- // CHECK-NEXT: ^bb5:
- // CHECK-NEXT: return
- return
-}
diff --git a/compiler/src/iree/compiler/Dialect/HAL/Transforms/test/materialize_resource_caches.mlir b/compiler/src/iree/compiler/Dialect/HAL/Transforms/test/materialize_resource_caches.mlir
index b129edf..e2aa4f8 100644
--- a/compiler/src/iree/compiler/Dialect/HAL/Transforms/test/materialize_resource_caches.mlir
+++ b/compiler/src/iree/compiler/Dialect/HAL/Transforms/test/materialize_resource_caches.mlir
@@ -122,6 +122,10 @@
// - If there is no matching hal.executable.variant then the executable will not be cached
hal.executable @exe {
hal.executable.variant @vmvx target(<"vmvx", "vmvx-bytecode-fb">) {
+ hal.executable.condition(%device: !hal.device) -> i1 {
+ %ok, %selected = hal.device.query<%device : !hal.device> key("some" :: "feature") : i1, i1
+ hal.return %selected : i1
+ }
hal.executable.export @entry0 ordinal(0) layout(#pipeline_layout_0) attributes {
workgroup_size = [32 : index, 1 : index, 1 : index]
}
@@ -155,9 +159,18 @@
// CHECK: util.global private @_executable_exe : !hal.executable
// CHECK-NEXT: util.initializer {
+
+// Switch on the supported formats:
// CHECK: %[[DEVICE:.+]] = hal.ex.shared_device : !hal.device
-// CHECK: %[[RET:.+]] = hal.device.switch<%[[DEVICE]] : !hal.device> -> !hal.executable
-// CHECK: #hal.device.match.executable.format<"vmvx-bytecode-fb"> {
+// CHECK: %{{.+}}, %[[FORMAT_VMVX:.+]] = hal.device.query<%[[DEVICE]] : !hal.device> key("hal.executable.format" :: "vmvx-bytecode-fb")
+// CHECK: %[[VMVX_CONDITION:.+]] = scf.execute_region -> i1 {
+// CHECK: %{{.+}}, %[[FEATURE:.+]] = hal.device.query<%[[DEVICE]] : !hal.device> key("some" :: "feature")
+// CHECK: scf.yield %[[FEATURE]]
+// CHECK: }
+// CHECK: %[[VMVX_VARIANT_SELECTED:.+]] = arith.andi %[[FORMAT_VMVX]], %[[VMVX_CONDITION]]
+// CHECK: %[[VARIANT_INDEX:.+]] = arith.select %[[VMVX_VARIANT_SELECTED]], %c0, %c-1
+// CHECK: %[[RET:.+]] = scf.index_switch %[[VARIANT_INDEX]] -> !hal.executable
+// CHECK: case 0 {
// Dependent layouts:
// CHECK: %[[LAYOUT0:.+]] = util.global.load @_pipeline_layout_0 : !hal.pipeline_layout
@@ -176,11 +189,11 @@
// CHECK-SAME: constants([%[[CONST_01]]#0, %[[CONST_01]]#1, %[[CONST_2]]])
// CHECK-SAME: : !hal.executable
-// CHECK: hal.return %[[EXE]] : !hal.executable
-// CHECK: },
-// CHECK: #hal.match.always {
+// CHECK: scf.yield %[[EXE]] : !hal.executable
+// CHECK: }
+// CHECK: default {
// CHECK: %[[NULL:.+]] = util.null : !hal.executable
-// CHECK: hal.return %[[NULL]] : !hal.executable
+// CHECK: scf.yield %[[NULL]] : !hal.executable
// CHECK: }
// CHECK: util.global.store %[[RET]], @_executable_exe : !hal.executable
@@ -247,17 +260,21 @@
util.global private @_executable_exe : !hal.executable
util.initializer {
%device = hal.ex.shared_device : !hal.device
- %0 = hal.device.switch<%device : !hal.device> -> !hal.executable
- #hal.device.match.executable.format<"vmvx-bytecode-fb"> {
+ %format_ok, %format_supported = hal.device.query<%device : !hal.device> key("hal.executable.format" :: "some-format") : i1, i1
+ %c0 = arith.constant 0 : index
+ %c-1 = arith.constant -1 : index
+ %variant = arith.select %format_supported, %c0, %c-1 : index
+ %selected = scf.index_switch %variant -> !hal.executable
+ case 0 {
%_pipeline_layout_0 = util.global.load @_pipeline_layout_0 : !hal.pipeline_layout
%exe = hal.executable.create device(%device : !hal.device) target(@exe0::@vmvx) layouts([%_pipeline_layout_0]) : !hal.executable
- hal.return %exe : !hal.executable
- },
- #hal.match.always {
- %1 = util.null : !hal.executable
- hal.return %1 : !hal.executable
+ scf.yield %exe : !hal.executable
}
- util.global.store %0, @_executable_exe : !hal.executable
+ default {
+ %null = util.null : !hal.executable
+ scf.yield %null : !hal.executable
+ }
+ util.global.store %selected, @_executable_exe : !hal.executable
util.initializer.return
}
diff --git a/compiler/src/iree/compiler/Dialect/HAL/Utils/BUILD.bazel b/compiler/src/iree/compiler/Dialect/HAL/Utils/BUILD.bazel
deleted file mode 100644
index e7696df..0000000
--- a/compiler/src/iree/compiler/Dialect/HAL/Utils/BUILD.bazel
+++ /dev/null
@@ -1,29 +0,0 @@
-# Copyright 2019 The IREE Authors
-#
-# Licensed under the Apache License v2.0 with LLVM Exceptions.
-# See https://llvm.org/LICENSE.txt for license information.
-# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
-
-load("//build_tools/bazel:build_defs.oss.bzl", "iree_compiler_cc_library")
-
-package(
- default_visibility = ["//visibility:public"],
- features = ["layering_check"],
- licenses = ["notice"], # Apache 2.0
-)
-
-iree_compiler_cc_library(
- name = "Utils",
- hdrs = [
- "DeviceSwitchBuilder.h",
- ],
- deps = [
- "//compiler/src/iree/compiler/Dialect/HAL/IR",
- "//compiler/src/iree/compiler/Utils",
- "@llvm-project//llvm:Support",
- "@llvm-project//mlir:FuncDialect",
- "@llvm-project//mlir:IR",
- "@llvm-project//mlir:Support",
- "@llvm-project//mlir:Transforms",
- ],
-)
diff --git a/compiler/src/iree/compiler/Dialect/HAL/Utils/CMakeLists.txt b/compiler/src/iree/compiler/Dialect/HAL/Utils/CMakeLists.txt
deleted file mode 100644
index 5509ab2..0000000
--- a/compiler/src/iree/compiler/Dialect/HAL/Utils/CMakeLists.txt
+++ /dev/null
@@ -1,29 +0,0 @@
-################################################################################
-# Autogenerated by build_tools/bazel_to_cmake/bazel_to_cmake.py from #
-# compiler/src/iree/compiler/Dialect/HAL/Utils/BUILD.bazel #
-# #
-# Use iree_cmake_extra_content from iree/build_defs.oss.bzl to add arbitrary #
-# CMake-only content. #
-# #
-# To disable autogeneration for this file entirely, delete this header. #
-################################################################################
-
-iree_add_all_subdirs()
-
-iree_cc_library(
- NAME
- Utils
- HDRS
- "DeviceSwitchBuilder.h"
- DEPS
- LLVMSupport
- MLIRFuncDialect
- MLIRIR
- MLIRSupport
- MLIRTransforms
- iree::compiler::Dialect::HAL::IR
- iree::compiler::Utils
- PUBLIC
-)
-
-### BAZEL_TO_CMAKE_PRESERVES_ALL_CONTENT_BELOW_THIS_LINE ###
diff --git a/compiler/src/iree/compiler/Dialect/HAL/Utils/DeviceSwitchBuilder.h b/compiler/src/iree/compiler/Dialect/HAL/Utils/DeviceSwitchBuilder.h
deleted file mode 100644
index abb93b5..0000000
--- a/compiler/src/iree/compiler/Dialect/HAL/Utils/DeviceSwitchBuilder.h
+++ /dev/null
@@ -1,207 +0,0 @@
-// Copyright 2020 The IREE Authors
-//
-// Licensed under the Apache License v2.0 with LLVM Exceptions.
-// See https://llvm.org/LICENSE.txt for license information.
-// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
-
-#ifndef IREE_COMPILER_DIALECT_HAL_UTILS_DEVICE_SWITCH_BUILDER_H_
-#define IREE_COMPILER_DIALECT_HAL_UTILS_DEVICE_SWITCH_BUILDER_H_
-
-#include "iree/compiler/Dialect/HAL/IR/HALOps.h"
-#include "mlir/IR/Builders.h"
-#include "mlir/IR/IRMapping.h"
-#include "mlir/Transforms/DialectConversion.h"
-#include "mlir/Transforms/RegionUtils.h"
-
-namespace mlir {
-namespace iree_compiler {
-namespace IREE {
-namespace HAL {
-
-// See DeviceSwitchBuilder for details.
-class DeviceSwitchCaseBuilder {
-public:
- DeviceSwitchCaseBuilder(Location loc, TypeRange resultTypes, Value device,
- Attribute initialCondition,
- SmallVectorImpl<IREE::HAL::DeviceSwitchOp> &caseOps,
- OpBuilder &builder)
- : loc_(loc), resultTypes_(resultTypes), device_(device),
- initialCondition_(initialCondition), caseOps_(caseOps),
- builder_(builder) {}
-
- // Result types that each region must return.
- TypeRange resultTypes() { return resultTypes_; }
-
- // Runtime device the switch will match against.
- Value device() { return device_; }
-
- // Pushes a new condition onto the stack and returns a builder that must have
- // all previously nested conditions met in order to execute any conditions.
- DeviceSwitchCaseBuilder nest(Attribute conditionAttr) {
- auto matchAttr =
- initialCondition_
- ? IREE::HAL::MatchAllAttr::get(
- conditionAttr.getContext(),
- ArrayRef<Attribute>{initialCondition_, conditionAttr})
- : conditionAttr;
- return DeviceSwitchCaseBuilder(loc_, resultTypes_, device_, matchAttr,
- caseOps_, builder_);
- }
-
- // Adds a new condition region that must satisfy all parent conditions.
- // The region will have a single empty entry block.
- Region *addRegion() {
- auto switchOp = builder_.create<IREE::HAL::DeviceSwitchOp>(
- loc_, resultTypes_, device_, ArrayRef<Attribute>{initialCondition_});
- auto *region = &switchOp.getRegion(0);
- OpBuilder(region).createBlock(region);
- caseOps_.emplace_back(switchOp);
- return region;
- }
-
- // Adds a new condition region that must satisfy |conditionAttr| and all
- // parent conditions. The region will have a single empty entry block.
- Region *addConditionRegion(Attribute conditionAttr) {
- return nest(conditionAttr).addRegion();
- }
-
-private:
- Location loc_;
- SmallVector<Type> resultTypes_;
- Value device_;
- Attribute initialCondition_;
- SmallVectorImpl<IREE::HAL::DeviceSwitchOp> &caseOps_;
- OpBuilder &builder_;
-};
-
-// Builder for hal.device.switch ops that allows for nesting of conditions.
-//
-// Example:
-// DeviceSwitchBuilder builder();
-// auto b0 = builder.nest(Z);
-// b0.addRegion(); // condition: Z
-// b0.addConditionRegion(A); // condition: Z && A
-// auto b1 = b0.nest(B);
-// b1.addConditionRegion(C); // condition: Z && B && C
-// b1.addConditionRegion(D); // condition: Z && B && D
-// auto b2 = b1.nest(E);
-// b2.addRegion(); // condition: Z && B && E
-// b2.addConditionRegion(F); // condition: Z && B && E && F
-//
-// Note that the arguments passed into addRegion/addConditionRegion are captured
-// from outside of the switch and accessible as entry block arguments on the
-// region that captured them. You must query the returned Region entry block
-// arguments to use them within the region.
-class DeviceSwitchBuilder {
-public:
- DeviceSwitchBuilder(Location loc, TypeRange resultTypes, Value device,
- OpBuilder builder)
- : loc_(loc), resultTypes_(resultTypes), device_(device),
- builder_(builder) {}
-
- // Pushes a new condition onto the stack and returns a builder that must have
- // all previously nested conditions met in order to execute any conditions.
- DeviceSwitchCaseBuilder nest(Attribute conditionAttr) {
- return DeviceSwitchCaseBuilder(loc_, resultTypes_, device_, conditionAttr,
- caseOps_, builder_);
- }
-
- // Adds a new condition region that must satisfy |conditionAttr| and all
- // parent conditions. The region will have a single entry block with the
- // given |args|.
- Region *addConditionRegion(Attribute conditionAttr) {
- return nest(conditionAttr).addRegion();
- }
-
- // Constructs a single hal.device.switch from all added regions.
- IREE::HAL::DeviceSwitchOp build() {
- SmallVector<Attribute> conditionAttrs;
- llvm::SetVector<Value> capturedFromAbove;
- for (auto caseOp : caseOps_) {
- conditionAttrs.push_back(caseOp.getConditions().getValue()[0]);
- }
- auto switchOp = builder_.create<IREE::HAL::DeviceSwitchOp>(
- loc_, resultTypes_, device_, conditionAttrs);
- for (int i = 0; i < caseOps_.size(); ++i) {
- switchOp.getRegion(i).takeBody(caseOps_[i].getRegion(0));
- caseOps_[i].erase();
- }
- return switchOp;
- }
-
-private:
- Location loc_;
- SmallVector<Type> resultTypes_;
- Value device_;
- SmallVector<IREE::HAL::DeviceSwitchOp> caseOps_;
- OpBuilder builder_;
-};
-
-// Rewriter-compatible version of DeviceSwitchBuilder.
-class DeviceSwitchRewriter {
-public:
- DeviceSwitchRewriter(Location loc, TypeRange resultTypes, Value device,
- ConversionPatternRewriter &rewriter)
- : loc_(loc), resultTypes_(resultTypes), device_(device),
- rewriter_(rewriter) {}
-
- // Pushes a new condition onto the stack and returns a builder that must have
- // all previously nested conditions met in order to execute any conditions.
- DeviceSwitchCaseBuilder nest(Attribute conditionAttr) {
- return DeviceSwitchCaseBuilder(loc_, resultTypes_, device_, conditionAttr,
- caseOps_, rewriter_);
- }
-
- // Adds a new condition region that must satisfy |conditionAttr| and all
- // parent conditions. The region will have a single empty entry block.
- Region *addConditionRegion(Attribute conditionAttr) {
- return nest(conditionAttr).addRegion();
- }
-
- // Constructs a single hal.device.switch from all added regions.
- IREE::HAL::DeviceSwitchOp build() {
- SmallVector<Attribute> conditionAttrs;
- llvm::SetVector<Value> capturedFromAbove;
- for (auto caseOp : caseOps_) {
- conditionAttrs.push_back(caseOp.getConditions().getValue()[0]);
- }
- auto switchOp = rewriter_.create<IREE::HAL::DeviceSwitchOp>(
- loc_, resultTypes_, device_, conditionAttrs);
- for (int i = 0; i < caseOps_.size(); ++i) {
- Region &targetRegion = switchOp.getRegion(i);
-
- SmallVector<Type> entryTypes;
- Block *entryBlock =
- rewriter_.createBlock(&targetRegion, targetRegion.end(), entryTypes);
- rewriter_.setInsertionPointAfter(switchOp);
-
- IRMapping mapper;
-
- Region &sourceRegion = caseOps_[i].getRegion(0);
- // When cloning `sourceRegion` into `targetRegion` remap the captured
- // values to use arguments of the `targetRegion`.
- rewriter_.cloneRegionBefore(sourceRegion, targetRegion,
- ++(Region::iterator(entryBlock)), mapper);
- Block *secondBlock = entryBlock->getNextNode();
- rewriter_.mergeBlocks(secondBlock, entryBlock, {});
- rewriter_.eraseOp(caseOps_[i]);
- }
- return switchOp;
- }
-
- ConversionPatternRewriter &getRewriter() const { return rewriter_; }
-
-private:
- Location loc_;
- SmallVector<Type> resultTypes_;
- Value device_;
- SmallVector<IREE::HAL::DeviceSwitchOp> caseOps_;
- ConversionPatternRewriter &rewriter_;
-};
-
-} // namespace HAL
-} // namespace IREE
-} // namespace iree_compiler
-} // namespace mlir
-
-#endif // IREE_COMPILER_DIALECT_HAL_UTILS_DEVICE_SWITCH_BUILDER_H_
diff --git a/compiler/src/iree/compiler/Dialect/Util/IR/UtilOps.cpp b/compiler/src/iree/compiler/Dialect/Util/IR/UtilOps.cpp
index 11865dc..5e4f0ea 100644
--- a/compiler/src/iree/compiler/Dialect/Util/IR/UtilOps.cpp
+++ b/compiler/src/iree/compiler/Dialect/Util/IR/UtilOps.cpp
@@ -26,6 +26,37 @@
namespace iree_compiler {
//===----------------------------------------------------------------------===//
+// Experimental
+//===----------------------------------------------------------------------===//
+
+// Builds every case condition first, then folds them last-to-first into a
+// chain of arith.select ops so the lowest matching index wins. For three cases:
+//   %case0 = ...
+//   %case1 = ...
+//   %case2 = ...
+//   %0 = arith.select %case2, %c2, %c-1
+//   %1 = arith.select %case1, %c1, %0
+//   %2 = arith.select %case0, %c0, %1
+//   // %2 is now -1 if nothing matched or the index of the match
+Value buildIfElseTree(
+    Location loc, size_t count,
+    std::function<Value(Location, size_t, OpBuilder &)> caseBuilder,
+    OpBuilder &builder) {
+  SmallVector<Value> caseValues;
+  caseValues.reserve(count);
+  for (size_t i = 0; i < count; ++i) {
+    caseValues.push_back(caseBuilder(loc, i, builder));
+  }
+  Value result = builder.create<arith::ConstantIndexOp>(loc, -1);
+  for (size_t i = count; i-- > 0;) {
+    Value caseIndex = builder.create<arith::ConstantIndexOp>(loc, i);
+    result =
+        builder.create<arith::SelectOp>(loc, caseValues[i], caseIndex, result);
+  }
+  return result;
+}
+
+//===----------------------------------------------------------------------===//
// Utils
//===----------------------------------------------------------------------===//
diff --git a/compiler/src/iree/compiler/Dialect/Util/IR/UtilOps.h b/compiler/src/iree/compiler/Dialect/Util/IR/UtilOps.h
index a815359..b89675f 100644
--- a/compiler/src/iree/compiler/Dialect/Util/IR/UtilOps.h
+++ b/compiler/src/iree/compiler/Dialect/Util/IR/UtilOps.h
@@ -31,6 +31,20 @@
namespace iree_compiler {
//===----------------------------------------------------------------------===//
+// Experimental
+//===----------------------------------------------------------------------===//
+
+// NOTE: this is a placeholder for a util.tree_switch (or something) op that
+// looks like scf.index_switch but with a region per case. For now we emit a
+// sequence of arith.select ops and return the index of the first condition that
+// is true. Would be nicer with some range template magic instead of an index.
+// Returns an index of -1 if no case matches.
+Value buildIfElseTree(
+ Location loc, size_t count,
+ std::function<Value(Location, size_t, OpBuilder &)> caseBuilder,
+ OpBuilder &builder);
+
+//===----------------------------------------------------------------------===//
// Utils
//===----------------------------------------------------------------------===//
diff --git a/compiler/src/iree/compiler/Modules/HAL/Inline/Conversion/HALToHALInline/BUILD.bazel b/compiler/src/iree/compiler/Modules/HAL/Inline/Conversion/HALToHALInline/BUILD.bazel
index d1998b0..8b34227 100644
--- a/compiler/src/iree/compiler/Modules/HAL/Inline/Conversion/HALToHALInline/BUILD.bazel
+++ b/compiler/src/iree/compiler/Modules/HAL/Inline/Conversion/HALToHALInline/BUILD.bazel
@@ -25,7 +25,6 @@
"//compiler/src/iree/compiler/Dialect/HAL/IR",
"//compiler/src/iree/compiler/Dialect/HAL/IR:HALDialect",
"//compiler/src/iree/compiler/Dialect/HAL/Target",
- "//compiler/src/iree/compiler/Dialect/HAL/Utils",
"//compiler/src/iree/compiler/Dialect/Util/IR",
"//compiler/src/iree/compiler/Modules/HAL/Inline/IR",
"//compiler/src/iree/compiler/Modules/HAL/Inline/IR:HALInlineDialect",
diff --git a/compiler/src/iree/compiler/Modules/HAL/Inline/Conversion/HALToHALInline/CMakeLists.txt b/compiler/src/iree/compiler/Modules/HAL/Inline/Conversion/HALToHALInline/CMakeLists.txt
index c289898..a895f81 100644
--- a/compiler/src/iree/compiler/Modules/HAL/Inline/Conversion/HALToHALInline/CMakeLists.txt
+++ b/compiler/src/iree/compiler/Modules/HAL/Inline/Conversion/HALToHALInline/CMakeLists.txt
@@ -28,7 +28,6 @@
iree::compiler::Dialect::HAL::IR
iree::compiler::Dialect::HAL::IR::HALDialect
iree::compiler::Dialect::HAL::Target
- iree::compiler::Dialect::HAL::Utils
iree::compiler::Dialect::Util::IR
iree::compiler::Modules::HAL::Inline::IR
iree::compiler::Modules::HAL::Inline::IR::HALInlineDialect
diff --git a/compiler/src/iree/compiler/Modules/HAL/Inline/Conversion/StreamToHALInline/BUILD.bazel b/compiler/src/iree/compiler/Modules/HAL/Inline/Conversion/StreamToHALInline/BUILD.bazel
index 2e7554d..e21b910 100644
--- a/compiler/src/iree/compiler/Modules/HAL/Inline/Conversion/StreamToHALInline/BUILD.bazel
+++ b/compiler/src/iree/compiler/Modules/HAL/Inline/Conversion/StreamToHALInline/BUILD.bazel
@@ -25,7 +25,6 @@
"//compiler/src/iree/compiler/Dialect/HAL/IR",
"//compiler/src/iree/compiler/Dialect/HAL/IR:HALDialect",
"//compiler/src/iree/compiler/Dialect/HAL/Target",
- "//compiler/src/iree/compiler/Dialect/HAL/Utils",
"//compiler/src/iree/compiler/Dialect/Stream/IR",
"//compiler/src/iree/compiler/Dialect/Util/IR",
"//compiler/src/iree/compiler/Modules/HAL/Inline/IR",
diff --git a/compiler/src/iree/compiler/Modules/HAL/Inline/Conversion/StreamToHALInline/CMakeLists.txt b/compiler/src/iree/compiler/Modules/HAL/Inline/Conversion/StreamToHALInline/CMakeLists.txt
index 645b1e0..a20f7a6 100644
--- a/compiler/src/iree/compiler/Modules/HAL/Inline/Conversion/StreamToHALInline/CMakeLists.txt
+++ b/compiler/src/iree/compiler/Modules/HAL/Inline/Conversion/StreamToHALInline/CMakeLists.txt
@@ -28,7 +28,6 @@
iree::compiler::Dialect::HAL::IR
iree::compiler::Dialect::HAL::IR::HALDialect
iree::compiler::Dialect::HAL::Target
- iree::compiler::Dialect::HAL::Utils
iree::compiler::Dialect::Stream::IR
iree::compiler::Dialect::Util::IR
iree::compiler::Modules::HAL::Inline::IR
diff --git a/compiler/src/iree/compiler/Modules/HAL/Inline/Transforms/InlineExecutables.cpp b/compiler/src/iree/compiler/Modules/HAL/Inline/Transforms/InlineExecutables.cpp
index 27880e0..abb6179 100644
--- a/compiler/src/iree/compiler/Modules/HAL/Inline/Transforms/InlineExecutables.cpp
+++ b/compiler/src/iree/compiler/Modules/HAL/Inline/Transforms/InlineExecutables.cpp
@@ -100,7 +100,7 @@
auto indexType = innerModuleBuilder.getIndexType();
auto i32Type = innerModuleBuilder.getI32Type();
auto bufferType = innerModuleBuilder.getType<IREE::Util::BufferType>();
- for (auto exportOp : variantOp.getOps<IREE::HAL::ExecutableExportOp>()) {
+ for (auto exportOp : variantOp.getExportOps()) {
// Build dispatch function signature that the stream.cmd.dispatch ops will
// map to.
auto layoutAttr = exportOp.getLayout();
diff --git a/compiler/src/iree/compiler/Modules/HAL/Loader/Conversion/StreamToHALLoader/BUILD.bazel b/compiler/src/iree/compiler/Modules/HAL/Loader/Conversion/StreamToHALLoader/BUILD.bazel
index 49aa543..12996ee 100644
--- a/compiler/src/iree/compiler/Modules/HAL/Loader/Conversion/StreamToHALLoader/BUILD.bazel
+++ b/compiler/src/iree/compiler/Modules/HAL/Loader/Conversion/StreamToHALLoader/BUILD.bazel
@@ -25,7 +25,6 @@
"//compiler/src/iree/compiler/Dialect/HAL/IR",
"//compiler/src/iree/compiler/Dialect/HAL/IR:HALDialect",
"//compiler/src/iree/compiler/Dialect/HAL/Target",
- "//compiler/src/iree/compiler/Dialect/HAL/Utils",
"//compiler/src/iree/compiler/Dialect/Stream/IR",
"//compiler/src/iree/compiler/Dialect/Util/IR",
"//compiler/src/iree/compiler/Modules/HAL/Inline/IR",
diff --git a/compiler/src/iree/compiler/Modules/HAL/Loader/Conversion/StreamToHALLoader/CMakeLists.txt b/compiler/src/iree/compiler/Modules/HAL/Loader/Conversion/StreamToHALLoader/CMakeLists.txt
index a6ffbb4..44ae182 100644
--- a/compiler/src/iree/compiler/Modules/HAL/Loader/Conversion/StreamToHALLoader/CMakeLists.txt
+++ b/compiler/src/iree/compiler/Modules/HAL/Loader/Conversion/StreamToHALLoader/CMakeLists.txt
@@ -28,7 +28,6 @@
iree::compiler::Dialect::HAL::IR
iree::compiler::Dialect::HAL::IR::HALDialect
iree::compiler::Dialect::HAL::Target
- iree::compiler::Dialect::HAL::Utils
iree::compiler::Dialect::Stream::IR
iree::compiler::Dialect::Util::IR
iree::compiler::Modules::HAL::Inline::IR
diff --git a/compiler/src/iree/compiler/Modules/HAL/Loader/Conversion/StreamToHALLoader/Patterns.cpp b/compiler/src/iree/compiler/Modules/HAL/Loader/Conversion/StreamToHALLoader/Patterns.cpp
index 005d442..7861ee2 100644
--- a/compiler/src/iree/compiler/Modules/HAL/Loader/Conversion/StreamToHALLoader/Patterns.cpp
+++ b/compiler/src/iree/compiler/Modules/HAL/Loader/Conversion/StreamToHALLoader/Patterns.cpp
@@ -47,8 +47,8 @@
auto loc = dispatchOp.getLoc();
// TODO(benvanik): support a lightweight switch builder for picking variants
- // that doesn't pull in the full HAL dialect - today the
- // DeviceSwitchRewriter needs a !hal.device and its query methods.
+ // that doesn't pull in the full HAL dialect. We could make the match
+ // expressions take a callback that performs the query, for example.
// For now we bail if there's multiple.
auto entryPointAttrs = dispatchOp.getEntryPoints().getValue();
if (entryPointAttrs.size() != 1) {
@@ -76,10 +76,9 @@
loc, rewriter.getType<IREE::HAL::ExecutableType>(),
executableOp.getName());
- // TODO(benvanik): a real switch op. For now we inline what the
- // hal.device.switch op does.
+ // TODO(benvanik): use scf.index_switch as with the full HAL.
for (auto variantOp : variantOps) {
- auto exportOps = variantOp.getOps<IREE::HAL::ExecutableExportOp>();
+ auto exportOps = variantOp.getExportOps();
auto exportIt =
llvm::find_if(exportOps, [&](IREE::HAL::ExecutableExportOp op) {
return op.getNameAttr() == entryPointAttr.getLeafReference();
diff --git a/samples/custom_dispatch/vulkan/shaders/example_inline.mlir b/samples/custom_dispatch/vulkan/shaders/example_inline.mlir
index b213bd3..980df87 100644
--- a/samples/custom_dispatch/vulkan/shaders/example_inline.mlir
+++ b/samples/custom_dispatch/vulkan/shaders/example_inline.mlir
@@ -53,6 +53,16 @@
// Dispatch a basic `ret = lhs * rhs` shader.
%0 = hal.dispatch.extern "main"[%dim](%dim_i32, %arg0, %arg1) : (i32, tensor<?xf32>{%dim}, tensor<?xf32>{%dim}) -> tensor<?xf32>{%dim}
+ count(%device: !hal.device, %workload: index) -> (index, index, index) {
+ // This host function is used to compute the XYZ workgroup count
+ // dispatched at runtime. It can query the %device for capabilities
+ // and limits (shared memory size, etc). The other arguments are the
+ // values passed in the dispatch operation (usually things like root
+ // output op tensor dimensions and other abstract values).
+ %x = affine.apply affine_map<()[s0] -> (s0 ceildiv 64)>()[%workload]
+ %c1 = arith.constant 1 : index
+ hal.return %x, %c1, %c1 : index, index, index
+ }
// The layout defines the required bindings and push constants and can be
// thought of as the function signature.
layout(#hal.pipeline.layout<push_constants = 1, sets = [
@@ -94,16 +104,6 @@
}>
]
}>)
- count(%device: !hal.device, %workload: index) -> (index, index, index) {
- // This host function is used to compute the XYZ workgroup count
- // dispatched at runtime. It can query the %device for capabilities
- // and limits (shared memory size, etc). The other arguments are the
- // values passed in the dispatch operation (usually things like root
- // output op tensor dimensions and other abstract values).
- %x = affine.apply affine_map<()[s0] -> (s0 ceildiv 64)>()[%workload]
- %c1 = arith.constant 1 : index
- hal.return %x, %c1, %c1 : index, index, index
- }
// Code gen some other ops - these will interleave with the hand-authored
// ones but naturally won't be able to fuse with them.