Adding an -iree-hal-dump-executable-benchmarks-to= flag.
This produces a directory of MLIR files (one standalone benchmark module per
hal.executable variant) that can be translated and passed to
iree-benchmark-module.
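
Example workflow (the compile invocation shown is illustrative; only the dump
flag itself is added by this change):

  # While compiling, dump one benchmark module per executable variant:
  iree-translate ... -iree-hal-dump-executable-benchmarks-to=benchmarks/ ...

Each dumped benchmarks/<module>_<executable>_<variant>.mlir exports one public
benchmark function per (entry point, workload) pair that takes an i32 batch
size; translate it as usual and run the result with iree-benchmark-module.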
diff --git a/iree/compiler/Dialect/HAL/Conversion/StreamToHAL/ConvertStreamToHAL.cpp b/iree/compiler/Dialect/HAL/Conversion/StreamToHAL/ConvertStreamToHAL.cpp
index 3a9903b..a77a52f 100644
--- a/iree/compiler/Dialect/HAL/Conversion/StreamToHAL/ConvertStreamToHAL.cpp
+++ b/iree/compiler/Dialect/HAL/Conversion/StreamToHAL/ConvertStreamToHAL.cpp
@@ -698,9 +698,8 @@
SymbolRefAttr::get(caseBuilder.getContext(), executableOp.getName(),
{SymbolRefAttr::get(entryPointOp->getParentOp()),
SymbolRefAttr::get(entryPointOp)});
- auto caseWorkgroupCount = calculateDispatchWorkgroupCount(
- loc, executableOp, entryPointOp, adaptor.workgroup_count(),
- caseBuilder);
+ auto caseWorkgroupCount = entryPointOp.calculateWorkgroupCount(
+ loc, adaptor.workgroup_count(), caseBuilder);
caseBuilder.create<IREE::HAL::CommandBufferDispatchSymbolOp>(
loc, commandBuffer, entryPointSymRef, caseWorkgroupCount[0],
caseWorkgroupCount[1], caseWorkgroupCount[2]);
@@ -772,110 +771,6 @@
}
if (currentSet != -1) flushSet();
}
-
- // Calculates the workgroup count (x, y, z) for dispatching to the given
- // |entryPointOp|. The provided N-dimensional |workload| is the total number
- // of invocations required as calculated by the generic workload logic
- // (basically, number of output elements in tensors).
- static std::array<Value, 3> calculateDispatchWorkgroupCount(
- Location loc, IREE::HAL::ExecutableOp executableOp,
- IREE::HAL::ExecutableEntryPointOp entryPointOp, ValueRange workload,
- OpBuilder &builder) {
- Block *body = entryPointOp.getWorkgroupCountBody();
- if (body) {
- return calculateDispatchWorkgroupCountFromRegion(loc, entryPointOp,
- workload, builder);
- }
- auto workgroupSize = calculateDispatchWorkgroupSize(
- loc, executableOp, entryPointOp, workload, builder);
- return calculateWorkloadWorkgroupCount(loc, workload, workgroupSize,
- builder);
- }
-
- // Calculates the workgroup size (x, y, z). These are the dimension numbers
- // for a single workgroup.
- static std::array<Value, 3> calculateDispatchWorkgroupSize(
- Location loc, IREE::HAL::ExecutableOp executableOp,
- IREE::HAL::ExecutableEntryPointOp entryPointOp, ValueRange workload,
- OpBuilder &builder) {
- // When no workgroup size is specified we just assume [1,1,1].
- // This yields a workgroup count that models the extents of the workload.
- return {
- builder.createOrFold<arith::ConstantIndexOp>(loc, 1),
- builder.createOrFold<arith::ConstantIndexOp>(loc, 1),
- builder.createOrFold<arith::ConstantIndexOp>(loc, 1),
- };
- }
-
- static std::array<Value, 3> calculateDispatchWorkgroupCountFromRegion(
- Location loc, IREE::HAL::ExecutableEntryPointOp entryPointOp,
- ValueRange workload, OpBuilder &builder) {
- // TODO(benvanik): replace with region inlining util.
- Block *body = entryPointOp.getWorkgroupCountBody();
- BlockAndValueMapping bvm;
- for (auto args : llvm::enumerate(workload)) {
- bvm.map(body->getArgument(args.index()), args.value());
- }
- for (Operation &op : body->without_terminator()) {
- builder.clone(op, bvm);
- }
- auto returnOp = cast<IREE::HAL::ReturnOp>(body->getTerminator());
- return {
- bvm.lookup(returnOp.operands()[0]),
- bvm.lookup(returnOp.operands()[1]),
- bvm.lookup(returnOp.operands()[2]),
- };
- }
-
- // Calculates the workgroup count (x, y, z) given the total N-dimensional
- // |workload| and specific |workgroupSize|.
- static std::array<Value, 3> calculateWorkloadWorkgroupCount(
- Location loc, ValueRange workload,
- const std::array<Value, 3> &workgroupSize, OpBuilder &builder) {
- std::array<Value, 3> result;
-
- auto constantOne = builder.createOrFold<arith::ConstantIndexOp>(loc, 1);
- if (workload.size() <= 3) {
- // 1-D to 3-D are easy (pad 2 to 0 dimensions) and divide by workgroup
- // size.
- for (int i = 0; i < 3; ++i) {
- // Round up: (workload[i] + workgroup_size - 1) / workgroup_size;
- Value workloadI = i < workload.size() ? workload[i] : constantOne;
- workloadI = builder.createOrFold<arith::SubIOp>(
- loc,
- builder.createOrFold<arith::AddIOp>(loc, workloadI,
- workgroupSize[i]),
- constantOne);
- result[i] = builder.createOrFold<arith::DivUIOp>(loc, workloadI,
- workgroupSize[i]);
- }
- } else {
- // TODO(#4140): remapping of N-D to 3-D: this is not how you do this!
- Value flatWorkload = constantOne;
- for (auto workloadI : workload) {
- flatWorkload =
- builder.createOrFold<arith::MulIOp>(loc, flatWorkload, workloadI);
- }
- for (int i = 0; i < 3; ++i) {
- // Round up: (workload[i] + workgroup_size - 1) / workgroup_size;
- auto rounded = builder.createOrFold<arith::SubIOp>(
- loc,
- builder.createOrFold<arith::AddIOp>(loc, flatWorkload,
- workgroupSize[i]),
- constantOne);
- auto workgroupCountI = builder.createOrFold<arith::DivUIOp>(
- loc, rounded, workgroupSize[i]);
- result[i] = workgroupCountI;
-
- // Multiply back out and subtract from invocations.
- flatWorkload = builder.createOrFold<arith::SubIOp>(
- loc, flatWorkload,
- builder.createOrFold<arith::MulIOp>(loc, workgroupCountI, rounded));
- }
- }
-
- return result;
- }
};
static void insertSerializationBarriers(Location loc, Block &block,
diff --git a/iree/compiler/Dialect/HAL/IR/HALOps.cpp b/iree/compiler/Dialect/HAL/IR/HALOps.cpp
index 714271c..e6f1ec6 100644
--- a/iree/compiler/Dialect/HAL/IR/HALOps.cpp
+++ b/iree/compiler/Dialect/HAL/IR/HALOps.cpp
@@ -13,6 +13,7 @@
#include "llvm/Support/SMLoc.h"
#include "mlir/Dialect/Func/IR/FuncOps.h"
#include "mlir/IR/Attributes.h"
+#include "mlir/IR/BlockAndValueMapping.h"
#include "mlir/IR/Builders.h"
#include "mlir/IR/BuiltinTypes.h"
#include "mlir/IR/Matchers.h"
@@ -754,6 +755,100 @@
return success();
}
+// Calculates the workgroup count (x, y, z) given the total N-dimensional
+// |workload| and specific |workgroupSize|.
+static std::array<Value, 3> calculateWorkloadWorkgroupCount(
+ Location loc, ValueRange workload,
+ const std::array<Value, 3> &workgroupSize, OpBuilder &builder) {
+ std::array<Value, 3> result;
+
+ auto constantOne = builder.createOrFold<arith::ConstantIndexOp>(loc, 1);
+ if (workload.size() <= 3) {
+    // 1-D to 3-D workloads are easy: pad out to 3 dimensions with 1s and
+    // divide by the workgroup size.
+ for (int i = 0; i < 3; ++i) {
+ // Round up: (workload[i] + workgroup_size - 1) / workgroup_size;
+ Value workloadI = i < workload.size() ? workload[i] : constantOne;
+ workloadI = builder.createOrFold<arith::SubIOp>(
+ loc,
+ builder.createOrFold<arith::AddIOp>(loc, workloadI, workgroupSize[i]),
+ constantOne);
+ result[i] = builder.createOrFold<arith::DivUIOp>(loc, workloadI,
+ workgroupSize[i]);
+ }
+ } else {
+ // TODO(#4140): remapping of N-D to 3-D: this is not how you do this!
+ Value flatWorkload = constantOne;
+ for (auto workloadI : workload) {
+ flatWorkload =
+ builder.createOrFold<arith::MulIOp>(loc, flatWorkload, workloadI);
+ }
+ for (int i = 0; i < 3; ++i) {
+ // Round up: (workload[i] + workgroup_size - 1) / workgroup_size;
+ auto rounded = builder.createOrFold<arith::SubIOp>(
+ loc,
+ builder.createOrFold<arith::AddIOp>(loc, flatWorkload,
+ workgroupSize[i]),
+ constantOne);
+ auto workgroupCountI =
+ builder.createOrFold<arith::DivUIOp>(loc, rounded, workgroupSize[i]);
+ result[i] = workgroupCountI;
+
+ // Multiply back out and subtract from invocations.
+ flatWorkload = builder.createOrFold<arith::SubIOp>(
+ loc, flatWorkload,
+ builder.createOrFold<arith::MulIOp>(loc, workgroupCountI, rounded));
+ }
+ }
+
+ return result;
+}
+
+static std::array<Value, 3> calculateWorkgroupCountFromRegion(
+ Location loc, Block *body, ValueRange workload, OpBuilder &builder) {
+ // TODO(benvanik): replace with region inlining util.
+ BlockAndValueMapping bvm;
+ for (auto args : llvm::enumerate(workload)) {
+ bvm.map(body->getArgument(args.index()), args.value());
+ }
+ for (Operation &op : body->without_terminator()) {
+ builder.clone(op, bvm);
+ }
+ auto returnOp = cast<IREE::HAL::ReturnOp>(body->getTerminator());
+ return {
+ bvm.lookup(returnOp.operands()[0]),
+ bvm.lookup(returnOp.operands()[1]),
+ bvm.lookup(returnOp.operands()[2]),
+ };
+}
+
+// Calculates the workgroup count (x, y, z) for dispatching to the entry point.
+// The provided N-dimensional |workload| is the total number of invocations
+// required as calculated by the generic workload logic (basically, number of
+// output elements in tensors).
+std::array<Value, 3> ExecutableEntryPointOp::calculateWorkgroupCount(
+ Location loc, ValueRange workload, OpBuilder &builder) {
+ Block *body = getWorkgroupCountBody();
+ if (body) {
+ return calculateWorkgroupCountFromRegion(loc, body, workload, builder);
+ }
+ auto workgroupSize = calculateWorkgroupSize(loc, workload, builder);
+ return calculateWorkloadWorkgroupCount(loc, workload, workgroupSize, builder);
+}
+
+// Calculates the workgroup size (x, y, z). These are the dimension numbers
+// for a single workgroup.
+std::array<Value, 3> ExecutableEntryPointOp::calculateWorkgroupSize(
+ Location loc, ValueRange workload, OpBuilder &builder) {
+ // When no workgroup size is specified we just assume [1,1,1].
+ // This yields a workgroup count that models the extents of the workload.
+ return {
+ builder.createOrFold<arith::ConstantIndexOp>(loc, 1),
+ builder.createOrFold<arith::ConstantIndexOp>(loc, 1),
+ builder.createOrFold<arith::ConstantIndexOp>(loc, 1),
+ };
+}
+
//===----------------------------------------------------------------------===//
// hal.executable.variant
//===----------------------------------------------------------------------===//
diff --git a/iree/compiler/Dialect/HAL/IR/HALOps.td b/iree/compiler/Dialect/HAL/IR/HALOps.td
index bf76989..e1da397 100644
--- a/iree/compiler/Dialect/HAL/IR/HALOps.td
+++ b/iree/compiler/Dialect/HAL/IR/HALOps.td
@@ -1569,6 +1569,14 @@
if (workgroup_count_region().empty()) return nullptr;
return &workgroup_count_region().front();
}
+
+ // Calculates an XYZ workgroup count based on the given |workload|.
+ std::array<Value, 3> calculateWorkgroupCount(
+ Location loc, ValueRange workload, OpBuilder &builder);
+
+ // Calculates an XYZ workgroup size based on the given |workload|.
+ std::array<Value, 3> calculateWorkgroupSize(
+ Location loc, ValueRange workload, OpBuilder &builder);
}];
}
diff --git a/iree/compiler/Dialect/HAL/Target/TargetBackend.cpp b/iree/compiler/Dialect/HAL/Target/TargetBackend.cpp
index 161b688..0b149dc 100644
--- a/iree/compiler/Dialect/HAL/Target/TargetBackend.cpp
+++ b/iree/compiler/Dialect/HAL/Target/TargetBackend.cpp
@@ -35,6 +35,12 @@
llvm::cl::desc("Path to write individual hal.executable input "
"source listings into (- for stdout)."),
llvm::cl::cat(halTargetOptionsCategory));
+
+ binder.opt<std::string>(
+ "iree-hal-dump-executable-benchmarks-to", executableBenchmarksPath,
+ llvm::cl::desc("Path to write standalone hal.executable benchmarks into "
+ "(- for stdout)."),
+ llvm::cl::cat(halTargetOptionsCategory));
}
// Renames |op| within |moduleOp| with a new name that is unique within both
diff --git a/iree/compiler/Dialect/HAL/Target/TargetBackend.h b/iree/compiler/Dialect/HAL/Target/TargetBackend.h
index 3fe4564..542a32e 100644
--- a/iree/compiler/Dialect/HAL/Target/TargetBackend.h
+++ b/iree/compiler/Dialect/HAL/Target/TargetBackend.h
@@ -34,6 +34,9 @@
// A path to write individual executable source listings into.
std::string sourceListingPath;
+  // A path to write standalone executable benchmarks into.
+ std::string executableBenchmarksPath;
+
// TODO(benvanik): flags for debug/optimization/etc.
// The intent is that we can have a global debug/-ON flag that then each
// target backend can have tickle it's own flags in the right way. Right now
diff --git a/iree/compiler/Dialect/HAL/Transforms/BUILD b/iree/compiler/Dialect/HAL/Transforms/BUILD
index 270a799..dc57eed 100644
--- a/iree/compiler/Dialect/HAL/Transforms/BUILD
+++ b/iree/compiler/Dialect/HAL/Transforms/BUILD
@@ -16,6 +16,7 @@
"AssignTargetDevices.cpp",
"BenchmarkBatchDispatches.cpp",
"ConvertToHAL.cpp",
+ "DumpExecutableBenchmarks.cpp",
"DumpExecutableSources.cpp",
"ElideRedundantCommands.cpp",
"InlineDeviceSwitches.cpp",
@@ -49,6 +50,7 @@
"//iree/compiler/Dialect/Util/Conversion",
"//iree/compiler/Dialect/Util/IR",
"//iree/compiler/Dialect/Util/Transforms",
+ "//iree/compiler/Utils",
"@llvm-project//llvm:Support",
"@llvm-project//mlir:AffineToStandard",
"@llvm-project//mlir:ArithmeticDialect",
@@ -57,6 +59,7 @@
"@llvm-project//mlir:FuncDialect",
"@llvm-project//mlir:IR",
"@llvm-project//mlir:Pass",
+ "@llvm-project//mlir:SCFDialect",
"@llvm-project//mlir:Support",
"@llvm-project//mlir:Transforms",
],
diff --git a/iree/compiler/Dialect/HAL/Transforms/CMakeLists.txt b/iree/compiler/Dialect/HAL/Transforms/CMakeLists.txt
index 9fa8a9a..35ce91e 100644
--- a/iree/compiler/Dialect/HAL/Transforms/CMakeLists.txt
+++ b/iree/compiler/Dialect/HAL/Transforms/CMakeLists.txt
@@ -19,6 +19,7 @@
"AssignTargetDevices.cpp"
"BenchmarkBatchDispatches.cpp"
"ConvertToHAL.cpp"
+ "DumpExecutableBenchmarks.cpp"
"DumpExecutableSources.cpp"
"ElideRedundantCommands.cpp"
"InlineDeviceSwitches.cpp"
@@ -41,6 +42,7 @@
MLIRFunc
MLIRIR
MLIRPass
+ MLIRSCF
MLIRSupport
MLIRTransforms
iree::compiler::Dialect::Flow::IR
@@ -58,6 +60,7 @@
iree::compiler::Dialect::Util::Conversion
iree::compiler::Dialect::Util::IR
iree::compiler::Dialect::Util::Transforms
+ iree::compiler::Utils
PUBLIC
)
diff --git a/iree/compiler/Dialect/HAL/Transforms/DumpExecutableBenchmarks.cpp b/iree/compiler/Dialect/HAL/Transforms/DumpExecutableBenchmarks.cpp
new file mode 100644
index 0000000..616026a
--- /dev/null
+++ b/iree/compiler/Dialect/HAL/Transforms/DumpExecutableBenchmarks.cpp
@@ -0,0 +1,465 @@
+// Copyright 2022 The IREE Authors
+//
+// Licensed under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+
+#include <memory>
+#include <utility>
+
+#include "iree/compiler/Dialect/HAL/IR/HALDialect.h"
+#include "iree/compiler/Dialect/HAL/IR/HALOps.h"
+#include "iree/compiler/Dialect/HAL/Transforms/Passes.h"
+#include "iree/compiler/Dialect/Stream/IR/StreamOps.h"
+#include "iree/compiler/Utils/IndexSet.h"
+#include "llvm/Support/FileSystem.h"
+#include "llvm/Support/Path.h"
+#include "llvm/Support/ToolOutputFile.h"
+#include "mlir/Dialect/SCF/SCF.h"
+#include "mlir/Pass/Pass.h"
+#include "mlir/Pass/PassManager.h"
+#include "mlir/Support/FileUtilities.h"
+#include "mlir/Transforms/Passes.h"
+
+// NOTE: redundant bindings will result in unique buffer locations during the
+// benchmark and will impact caching behavior.
+
+namespace mlir {
+namespace iree_compiler {
+namespace IREE {
+namespace HAL {
+
+// We could use the resource constraints in the module when we have them.
+static const int64_t kBufferAlignment = 256;
+
+using Vec3 = std::tuple<unsigned, unsigned, unsigned>;
+
+struct Binding {
+ unsigned set = 0;
+ unsigned binding = 0;
+ int64_t size = 0;
+};
+
+// Combined data for all dispatches of a particular static workload size.
+struct DispatchParams {
+ // All locations that dispatch with these parameters.
+ SmallVector<Location> locs;
+ // Workload used as input to the workgroup count calculation function.
+ Vec3 workload;
+ // Analyzed minimum binding sizes.
+ SmallVector<Binding> bindings;
+};
+
+using DispatchParamsMap =
+ llvm::DenseMap<SymbolRefAttr, llvm::MapVector<Vec3, DispatchParams>>;
+
+// Walk |moduleOp| and gather all of the dispatches to each executable.
+// Dispatch parameters are deduplicated by workload so that there's only ever
+// one entry for all dispatches with a given workgroup count.
+// Dispatches will be ignored if they have a dynamic workload or any dynamically
+// sized resources.
+static DispatchParamsMap gatherDispatchParams(mlir::ModuleOp moduleOp) {
+ DispatchParamsMap map;
+
+ for (auto funcOp : moduleOp.getOps<FunctionOpInterface>()) {
+ funcOp.walk([&](IREE::Stream::CmdDispatchOp dispatchOp) {
+ // NOTE: we could record operands here if we wanted real push constants.
+
+ // TODO(benvanik): typed accessors for bindings.
+ auto bindingAttrs = dispatchOp->getAttr("hal.interface.bindings")
+ .dyn_cast_or_null<ArrayAttr>();
+ assert(bindingAttrs &&
+ "interface materialization must annotate dispatch sites");
+
+ auto workloadValues = dispatchOp.workgroup_count();
+ APInt workloadX, workloadY, workloadZ;
+ if (!matchPattern(workloadValues[0], m_ConstantInt(&workloadX)) ||
+ !matchPattern(workloadValues[1], m_ConstantInt(&workloadY)) ||
+ !matchPattern(workloadValues[2], m_ConstantInt(&workloadZ))) {
+ // Non-constant workload; skip this dispatch.
+ return;
+ }
+ Vec3 workload =
+ std::make_tuple(workloadX.getSExtValue(), workloadY.getSExtValue(),
+ workloadZ.getSExtValue());
+
+ SmallVector<Binding> bindings;
+ for (auto it : llvm::zip(bindingAttrs, dispatchOp.resource_lengths())) {
+ auto bindingAttr =
+ std::get<0>(it).cast<IREE::HAL::InterfaceBindingAttr>();
+ APInt resourceLength;
+ if (!matchPattern(std::get<1>(it), m_ConstantInt(&resourceLength))) {
+ // Non-constant resource length; skip this dispatch.
+ return;
+ }
+ bindings.push_back({(unsigned)bindingAttr.getSet(),
+ (unsigned)bindingAttr.getBinding(),
+ resourceLength.getSExtValue()});
+ }
+
+ auto &dispatchParamsSet = map[dispatchOp.entry_point()];
+ auto &dispatchParams = dispatchParamsSet[workload];
+ dispatchParams.locs.push_back(dispatchOp.getLoc());
+ dispatchParams.workload = workload;
+ dispatchParams.bindings = bindings;
+ });
+ }
+
+ return map;
+}
+
+// Appends a global hal.buffer initialized to the size required for all
+// of the bindings in |dispatchParams| (plus alignment).
+static IREE::Util::GlobalOp appendGlobalBuffer(
+ Location loc, StringRef baseName, const DispatchParams &dispatchParams,
+ OpBuilder &moduleBuilder) {
+ // Create a global to hold the HAL buffer.
+ auto globalOp = moduleBuilder.create<IREE::Util::GlobalOp>(
+ loc, (baseName + "_buffer").str(),
+ /*isMutable=*/true,
+ IREE::HAL::BufferType::get(moduleBuilder.getContext()));
+ globalOp.setPrivate();
+
+ // Compute the total size of the buffer based on all binding sizes when
+ // aligned.
+ int64_t totalLength = 0;
+ for (auto binding : dispatchParams.bindings) {
+ totalLength =
+ IREE::Util::align(totalLength + binding.size, kBufferAlignment);
+ }
+
+ // Build an initializer to allocate the buffer.
+ auto initOp = moduleBuilder.create<IREE::Util::InitializerOp>(loc);
+ auto initBuilder = OpBuilder::atBlockBegin(initOp.addEntryBlock());
+ IndexSet indexSet(loc, initBuilder);
+
+ // TODO(benvanik): real device lookup.
+ auto device = initBuilder.create<IREE::HAL::ExSharedDeviceOp>(loc);
+ auto allocator =
+ initBuilder.create<IREE::HAL::DeviceAllocatorOp>(loc, device).result();
+
+ auto memoryTypes = IREE::HAL::MemoryTypeBitfield::DeviceLocal;
+ auto bufferUsage = IREE::HAL::BufferUsageBitfield::Dispatch;
+ auto allocateOp = initBuilder.create<IREE::HAL::AllocatorAllocateOp>(
+ loc, globalOp.type(), allocator, memoryTypes, bufferUsage,
+ indexSet.get(totalLength));
+
+ initBuilder.create<IREE::Util::GlobalStoreOp>(loc, allocateOp.result(),
+ globalOp.getNameAttr());
+ initBuilder.create<IREE::Util::InitializerReturnOp>(loc);
+
+ return globalOp;
+}
+
+// Appends a function calling the given |entryPointOp| with |dispatchParams|.
+// This will add a global value for the resources required.
+//
+// Expects the runner to pass an i32 value indicating the number of dispatches
+// to be made in one submission.
+static void appendDispatchBenchmark(
+ IREE::HAL::ExecutableOp executableOp,
+ IREE::HAL::ExecutableVariantOp variantOp,
+ IREE::HAL::ExecutableEntryPointOp entryPointOp,
+ DispatchParams dispatchParams, OpBuilder &moduleBuilder) {
+ auto loc = FusedLoc::get(executableOp.getContext(), dispatchParams.locs);
+
+ std::string baseName =
+ (executableOp.getName() + "_" + variantOp.getName() + "_" +
+ entryPointOp.getName() + "_" +
+ std::to_string(std::get<0>(dispatchParams.workload)) + "x" +
+ std::to_string(std::get<1>(dispatchParams.workload)) + "x" +
+ std::to_string(std::get<2>(dispatchParams.workload)))
+ .str();
+
+ // Add a global variable holding an initialized buffer for the dispatch IO.
+ auto bufferGlobalOp =
+ appendGlobalBuffer(loc, baseName, dispatchParams, moduleBuilder);
+
+ // Create an exported benchmark function that runs the dispatches.
+ auto funcType =
+ moduleBuilder.getFunctionType({moduleBuilder.getI32Type()}, {});
+ auto funcOp = moduleBuilder.create<mlir::FuncOp>(loc, baseName, funcType);
+ funcOp.setVisibility(SymbolTable::Visibility::Public);
+
+ // Mark the function as being a dispatch benchmark.
+ // This tells iree-benchmark-module to pass in the arguments we need.
+ funcOp->setAttr("iree.abi.stub", moduleBuilder.getUnitAttr());
+ funcOp->setAttr(
+ "iree.reflection",
+ moduleBuilder.getDictionaryAttr({
+ moduleBuilder.getNamedAttr("iree.benchmark",
+ moduleBuilder.getStringAttr("dispatch")),
+ }));
+
+ // Build the function that runs the dispatches.
+ auto *entryBlock = funcOp.addEntryBlock();
+ OpBuilder funcBuilder = OpBuilder::atBlockBegin(entryBlock);
+ IndexSet indexSet(loc, funcBuilder);
+ auto batchSizeArg = funcBuilder.create<arith::IndexCastOp>(
+ loc, funcBuilder.getIndexType(), entryBlock->getArgument(0));
+
+ // TODO(benvanik): real device lookup.
+ auto device = funcBuilder.create<IREE::HAL::ExSharedDeviceOp>(loc);
+
+ // Create and begin command buffer.
+ // TODO(benvanik): reuse the command buffer (initialize once and store).
+ auto commandBufferModes =
+ IREE::HAL::CommandBufferModeBitfield::OneShot |
+ IREE::HAL::CommandBufferModeBitfield::AllowInlineExecution;
+ auto commandBuffer =
+ funcBuilder
+ .create<IREE::HAL::CommandBufferCreateOp>(
+ loc, funcBuilder.getType<IREE::HAL::CommandBufferType>(), device,
+ commandBufferModes, IREE::HAL::CommandCategoryBitfield::Dispatch)
+ .result();
+ funcBuilder.create<IREE::HAL::CommandBufferBeginOp>(loc, commandBuffer);
+
+ // Get the layout required to set up the dispatches.
+ auto layoutAttr = entryPointOp.layoutAttr();
+ auto executableLayout =
+ funcBuilder
+ .create<IREE::HAL::ExecutableLayoutLookupOp>(
+ loc, IREE::HAL::ExecutableLayoutType::get(loc.getContext()),
+ device, layoutAttr)
+ .result();
+
+ // Push constant values.
+  // TODO(benvanik): use the push constants the program used? can help with
+  // specialization that may have been applied in the streams dialect.
+ if (int64_t pushConstantCount = layoutAttr.getPushConstants()) {
+ int pushConstantBase = 0; // always 0 today
+ SmallVector<Value> pushConstants(pushConstantCount,
+ funcBuilder.create<arith::ConstantIntOp>(
+ loc, 0, funcBuilder.getI32Type()));
+ funcBuilder.create<IREE::HAL::CommandBufferPushConstantsOp>(
+ loc, commandBuffer, executableLayout,
+ funcBuilder.getIndexAttr(pushConstantBase), pushConstants);
+ }
+
+ // Push descriptor sets.
+ auto buffer =
+ funcBuilder.create<IREE::Util::GlobalLoadOp>(loc, bufferGlobalOp)
+ .result();
+ int64_t currentSet = -1;
+ SmallVector<IREE::HAL::DescriptorSetBindingValue> bindingValues;
+ auto flushSet = [&]() {
+ funcBuilder.create<IREE::HAL::CommandBufferPushDescriptorSetOp>(
+ loc, commandBuffer, executableLayout, currentSet, bindingValues);
+ bindingValues.clear();
+ };
+ int64_t bufferOffset = 0;
+ for (auto binding : dispatchParams.bindings) {
+ if (currentSet != -1 && currentSet != binding.set) flushSet();
+ currentSet = binding.set;
+ IREE::HAL::DescriptorSetBindingValue bindingValue;
+ bindingValue.ordinal =
+ funcBuilder.create<arith::ConstantIndexOp>(loc, binding.binding);
+ bindingValue.buffer = buffer;
+ bindingValue.byteOffset = indexSet.get(bufferOffset);
+ bindingValue.byteLength = indexSet.get(binding.size);
+ bindingValues.push_back(bindingValue);
+ bufferOffset =
+ IREE::Util::align(bufferOffset + binding.size, kBufferAlignment);
+ }
+ if (currentSet != -1) flushSet();
+
+ // Compute the workgroup parameters.
+ auto workgroupCount = entryPointOp.calculateWorkgroupCount(
+ loc,
+ {
+ indexSet.get(std::get<0>(dispatchParams.workload)),
+ indexSet.get(std::get<1>(dispatchParams.workload)),
+ indexSet.get(std::get<2>(dispatchParams.workload)),
+ },
+ funcBuilder);
+
+ // Loop around dispatches based on batch size.
+  // Note that we insert a barrier between dispatches; we could make this
+  // optional so that concurrent utilization is measured.
+ funcBuilder.create<scf::ForOp>(
+ loc, indexSet.get(0), batchSizeArg, indexSet.get(1), ValueRange{},
+ [&](OpBuilder &forBuilder, Location loc, Value iv, ValueRange iters) {
+ // Dispatch.
+ auto symbolRefAttr = SymbolRefAttr::get(
+ executableOp.getNameAttr(),
+ {
+ SymbolRefAttr::get(variantOp.getNameAttr()),
+ SymbolRefAttr::get(entryPointOp.getNameAttr()),
+ });
+ forBuilder.create<IREE::HAL::CommandBufferDispatchSymbolOp>(
+ loc, commandBuffer, symbolRefAttr, workgroupCount[0],
+ workgroupCount[1], workgroupCount[2]);
+
+ // Barrier following the dispatch to block the next dispatch.
+ auto sourceStage = IREE::HAL::ExecutionStageBitfield::CommandRetire |
+ IREE::HAL::ExecutionStageBitfield::Dispatch;
+ auto targetStage = IREE::HAL::ExecutionStageBitfield::CommandIssue |
+ IREE::HAL::ExecutionStageBitfield::Dispatch;
+ auto barrierFlags = IREE::HAL::ExecutionBarrierFlagBitfield::None;
+ forBuilder.create<IREE::HAL::CommandBufferExecutionBarrierOp>(
+ loc, commandBuffer, sourceStage, targetStage, barrierFlags);
+
+ forBuilder.create<scf::YieldOp>(loc);
+ });
+
+ // Submit command buffer.
+ funcBuilder.create<IREE::HAL::CommandBufferEndOp>(loc, commandBuffer);
+ funcBuilder.create<IREE::HAL::ExSubmitAndWaitOp>(loc, device, commandBuffer);
+
+ funcBuilder.create<mlir::func::ReturnOp>(loc);
+}
+
+// Builds a module exporting one function for each dispatch configuration
+// targeting |sourceExecutableOp|.
+static mlir::OwningOpRef<mlir::ModuleOp> buildBenchmarkModule(
+ IREE::HAL::ExecutableOp sourceExecutableOp,
+ IREE::HAL::ExecutableVariantOp sourceVariantOp,
+ const DispatchParamsMap &dispatchParamsMap) {
+ // Empty module with default name.
+ // We could use the original module name here to make tracking nicer.
+ mlir::OwningOpRef<mlir::ModuleOp> moduleOp =
+ mlir::ModuleOp::create(sourceExecutableOp.getLoc());
+ auto moduleBuilder = OpBuilder::atBlockBegin(moduleOp->getBody());
+
+ // Copy over the device targets from the original module.
+ // TODO(benvanik): filter this by the target of the variant.
+ moduleOp->getOperation()->setAttr(
+ "hal.device.targets",
+ sourceExecutableOp->getParentOfType<mlir::ModuleOp>()->getAttr(
+ "hal.device.targets"));
+
+ // Clone the executable variant into the new module.
+ auto executableOp = moduleBuilder.create<IREE::HAL::ExecutableOp>(
+ sourceExecutableOp.getLoc(), sourceExecutableOp.getName());
+ executableOp.setVisibility(sourceExecutableOp.getVisibility());
+ auto variantOp = cast<IREE::HAL::ExecutableVariantOp>(
+ OpBuilder::atBlockBegin(executableOp.getBody())
+ .clone(*sourceVariantOp.getOperation()));
+
+ // Add functions to test each entry point with its various dispatch
+ // parameters.
+ bool hasAnyBenchmarks = false;
+ for (auto entryPointOp :
+ variantOp.getOps<IREE::HAL::ExecutableEntryPointOp>()) {
+ auto symbolRefAttr = SymbolRefAttr::get(
+ executableOp.getNameAttr(),
+ {FlatSymbolRefAttr::get(entryPointOp.getNameAttr())});
+ auto dispatchParamsSet = dispatchParamsMap.find(symbolRefAttr);
+ if (dispatchParamsSet != dispatchParamsMap.end()) {
+ for (auto dispatchParams : dispatchParamsSet->second) {
+ appendDispatchBenchmark(executableOp, variantOp, entryPointOp,
+ dispatchParams.second, moduleBuilder);
+ hasAnyBenchmarks = true;
+ }
+ }
+ }
+
+ // Skip the file when we could not generate any benchmarks.
+ if (!hasAnyBenchmarks) return {};
+
+ // Run CSE and the canonicalizer to pretty up the output.
+ PassManager passManager(moduleOp->getContext());
+ passManager.addPass(mlir::createCanonicalizerPass());
+ passManager.addPass(mlir::createCSEPass());
+ if (failed(passManager.run(*moduleOp))) {
+ moduleOp->emitError("failed to run canonicalizer; malformed output");
+ return {};
+ }
+
+ return moduleOp;
+}
+
+static void dumpModuleToStream(mlir::ModuleOp moduleOp, StringRef fileName,
+ llvm::raw_ostream &os) {
+ OpPrintingFlags flags;
+ flags.useLocalScope(); // could use global scope, but IR gets messy fast
+ moduleOp.print(os, flags);
+ os << "\n"; // newline at end of file
+}
+
+class DumpExecutableBenchmarksPass
+ : public PassWrapper<DumpExecutableBenchmarksPass,
+ OperationPass<ModuleOp>> {
+ public:
+ DumpExecutableBenchmarksPass() = default;
+ DumpExecutableBenchmarksPass(const DumpExecutableBenchmarksPass &pass) {}
+ DumpExecutableBenchmarksPass(StringRef path) { this->path = path.str(); }
+
+  void getDependentDialects(DialectRegistry &registry) const override {
+ registry.insert<IREE::HAL::HALDialect>();
+ registry.insert<arith::ArithmeticDialect>();
+ registry.insert<scf::SCFDialect>();
+ }
+
+ StringRef getArgument() const override {
+ return "iree-hal-dump-executable-benchmarks";
+ }
+
+ StringRef getDescription() const override {
+ return "Dumps standalone hal.executable benchmarks to a path.";
+ }
+
+ void runOnOperation() override {
+ auto moduleOp = getOperation();
+ auto moduleName = moduleOp.getName().getValueOr("module");
+
+ // Analyze the module to find dispatch parameters.
+ // This is a full walk of all stream.cmd.dispatch ops and will handle
+ // filtering out dispatches that have dynamic parameters we don't
+ // currently support.
+ auto dispatchParamsMap = gatherDispatchParams(moduleOp);
+ if (dispatchParamsMap.empty()) return;
+
+ // Help people out and mkdir if needed.
+ if (!path.empty() && path != "-") {
+ llvm::sys::fs::create_directories(path);
+ }
+
+ // Produce one file per executable containing all entry points.
+ for (auto executableOp : moduleOp.getOps<IREE::HAL::ExecutableOp>()) {
+ for (auto variantOp :
+ executableOp.getOps<IREE::HAL::ExecutableVariantOp>()) {
+ auto benchmarkModuleOp =
+ buildBenchmarkModule(executableOp, variantOp, dispatchParamsMap);
+ if (!benchmarkModuleOp) continue;
+ auto fileName = (moduleName + "_" + executableOp.getName() + "_" +
+ variantOp.getName() + ".mlir")
+ .str();
+ if (path.empty() || path == "-") {
+ dumpModuleToStream(*benchmarkModuleOp, fileName, llvm::outs());
+ } else {
+ auto filePath =
+ (path + llvm::sys::path::get_separator() + fileName).str();
+ std::string error;
+ auto file = mlir::openOutputFile(filePath, &error);
+ if (!file) {
+ executableOp.emitError()
+ << "while dumping to " << path << ": " << error;
+ return signalPassFailure();
+ }
+ dumpModuleToStream(*benchmarkModuleOp, fileName, file->os());
+ file->keep();
+ }
+ }
+ }
+ }
+
+ private:
+ Option<std::string> path{
+ *this, "path",
+ llvm::cl::desc("Path to write hal.executable benchmarks into.")};
+};
+
+std::unique_ptr<OperationPass<ModuleOp>> createDumpExecutableBenchmarksPass(
+ StringRef path) {
+ return std::make_unique<DumpExecutableBenchmarksPass>(path);
+}
+
+static PassRegistration<DumpExecutableBenchmarksPass> pass([] {
+ return std::make_unique<DumpExecutableBenchmarksPass>();
+});
+
+} // namespace HAL
+} // namespace IREE
+} // namespace iree_compiler
+} // namespace mlir
diff --git a/iree/compiler/Dialect/HAL/Transforms/Passes.cpp b/iree/compiler/Dialect/HAL/Transforms/Passes.cpp
index 6baf677..d1904d7 100644
--- a/iree/compiler/Dialect/HAL/Transforms/Passes.cpp
+++ b/iree/compiler/Dialect/HAL/Transforms/Passes.cpp
@@ -125,6 +125,14 @@
createDumpExecutableSourcesPass(targetOptions.sourceListingPath));
}
+ // Dump standalone hal.executable benchmark modules.
+ // Today this only works for executables that have static dispatch parameters
+ // and is only useful for basic microbenchmarking.
+ if (!targetOptions.executableBenchmarksPath.empty()) {
+ passManager.addPass(createDumpExecutableBenchmarksPass(
+ targetOptions.executableBenchmarksPath));
+ }
+
// TODO(benvanik): move translation after conversion; today translation
// inserts the workgroup count logic we need to convert but we could instead
// insert placeholder ops that are expanded after translation.
diff --git a/iree/compiler/Dialect/HAL/Transforms/Passes.h b/iree/compiler/Dialect/HAL/Transforms/Passes.h
index 0067fd9..aac5702 100644
--- a/iree/compiler/Dialect/HAL/Transforms/Passes.h
+++ b/iree/compiler/Dialect/HAL/Transforms/Passes.h
@@ -84,6 +84,10 @@
std::unique_ptr<OperationPass<mlir::ModuleOp>> createDumpExecutableSourcesPass(
StringRef path);
+// Dumps standalone hal.executable benchmarks to |path|.
+std::unique_ptr<OperationPass<mlir::ModuleOp>>
+createDumpExecutableBenchmarksPass(StringRef path);
+
// Translates hal.executable.variant ops via a nested translation pipeline.
std::unique_ptr<OperationPass<IREE::HAL::ExecutableOp>>
createTranslateExecutablesPass();
diff --git a/iree/compiler/Dialect/HAL/Transforms/test/BUILD b/iree/compiler/Dialect/HAL/Transforms/test/BUILD
index c9b0985..475332d 100644
--- a/iree/compiler/Dialect/HAL/Transforms/test/BUILD
+++ b/iree/compiler/Dialect/HAL/Transforms/test/BUILD
@@ -20,6 +20,7 @@
"assign_target_devices.mlir",
"benchmark_batch_dispatches.mlir",
"convert_to_hal.mlir",
+ "dump_executable_benchmarks.mlir",
"dump_executable_sources.mlir",
"elide_redundant_commands.mlir",
"inline_device_switches.mlir",
diff --git a/iree/compiler/Dialect/HAL/Transforms/test/CMakeLists.txt b/iree/compiler/Dialect/HAL/Transforms/test/CMakeLists.txt
index a641b86..f87d205 100644
--- a/iree/compiler/Dialect/HAL/Transforms/test/CMakeLists.txt
+++ b/iree/compiler/Dialect/HAL/Transforms/test/CMakeLists.txt
@@ -17,6 +17,7 @@
"assign_target_devices.mlir"
"benchmark_batch_dispatches.mlir"
"convert_to_hal.mlir"
+ "dump_executable_benchmarks.mlir"
"dump_executable_sources.mlir"
"elide_redundant_commands.mlir"
"inline_device_switches.mlir"
diff --git a/iree/compiler/Dialect/HAL/Transforms/test/dump_executable_benchmarks.mlir b/iree/compiler/Dialect/HAL/Transforms/test/dump_executable_benchmarks.mlir
new file mode 100644
index 0000000..d5aab40
--- /dev/null
+++ b/iree/compiler/Dialect/HAL/Transforms/test/dump_executable_benchmarks.mlir
@@ -0,0 +1,157 @@
+// RUN: iree-opt -split-input-file -iree-hal-dump-executable-benchmarks %s | FileCheck %s
+
+// Tests dumping executable benchmarks to stdout - it's more common to use files
+// but this is much easier to test with lit.
+
+#executable_target_embedded_elf_x86_64_ = #hal.executable.target<"llvm", "embedded-elf-x86_64">
+#device_target_cpu = #hal.device.target<"cpu", {
+ executable_targets = [#executable_target_embedded_elf_x86_64_]
+}>
+#executable_layout_0 = #hal.executable.layout<push_constants = 2, sets = [
+ #hal.descriptor_set.layout<0, bindings = [
+ #hal.descriptor_set.binding<0, storage_buffer>,
+ #hal.descriptor_set.binding<1, storage_buffer>,
+ #hal.descriptor_set.binding<2, storage_buffer>
+ ]>
+]>
+#executable_layout_1 = #hal.executable.layout<push_constants = 0, sets = [
+ #hal.descriptor_set.layout<0, bindings = [
+ #hal.descriptor_set.binding<0, storage_buffer>,
+ #hal.descriptor_set.binding<1, storage_buffer>
+ ]>
+]>
+
+module attributes {hal.device.targets = [#device_target_cpu]} {
+
+ // Executable should be dumped:
+ // CHECK: hal.executable private @ex0
+ hal.executable private @ex0 {
+ hal.executable.variant public @embedded_elf_x86_64, target = #executable_target_embedded_elf_x86_64_ {
+ hal.executable.entry_point public @dispatch0 ordinal(0) layout(#executable_layout_0) {
+ translation_info = #iree_codegen.translation_info<CPUDefault, workload_per_wg = [4]>
+ } {
+ ^bb0(%arg0: index, %arg1: index, %arg2: index):
+ %c1 = arith.constant 1 : index
+ %0 = affine.apply affine_map<()[s0] -> (s0 ceildiv 4)>()[%arg0]
+ hal.return %0, %c1, %c1 : index, index, index
+ }
+ builtin.module {
+ func.func @dispatch0() {
+ func.return
+ }
+ }
+
+ hal.executable.entry_point public @dispatch1 ordinal(1) layout(#executable_layout_1) {
+ translation_info = #iree_codegen.translation_info<CPUDefault, workload_per_wg = [4]>
+ } {
+ ^bb0(%arg0: index, %arg1: index, %arg2: index):
+ %c1 = arith.constant 1 : index
+ %0 = affine.apply affine_map<()[s0] -> (s0 ceildiv 4)>()[%arg0]
+ hal.return %0, %c1, %c1 : index, index, index
+ }
+ builtin.module {
+ func.func @dispatch1() {
+ func.return
+ }
+ }
+ }
+ }
+
+ // ===========================================================================
+ // @dispatch0 benchmark logic:
+ // ===========================================================================
+
+ // CHECK: util.global private mutable @ex0_embedded_elf_x86_64_dispatch0_512x1x1_buffer : !hal.buffer
+ // CHECK-NEXT: util.initializer {
+ // CHECK: %[[BUFFER:.+]] = hal.allocator.allocate<%{{.+}} : !hal.allocator> type("DeviceVisible|DeviceLocal") usage(Dispatch) : !hal.buffer{%c768}
+ // CHECK-NEXT: util.global.store %[[BUFFER]], @ex0_embedded_elf_x86_64_dispatch0_512x1x1_buffer : !hal.buffer
+
+ // CHECK: func @ex0_embedded_elf_x86_64_dispatch0_512x1x1(%arg0: i32)
+ // CHECK-SAME: attributes {iree.abi.stub, iree.reflection = {iree.benchmark = "dispatch"}} {
+ // CHECK: %[[BATCH_SIZE:.+]] = arith.index_cast %arg0 : i32 to index
+
+ // Create command buffer:
+ // CHECK: %[[CMD:.+]] = hal.command_buffer.create
+ // CHECK: hal.command_buffer.begin<%[[CMD]] : !hal.command_buffer>
+
+ // Setup dispatch constants and bindings:
+ // CHECK: hal.command_buffer.push_constants<%[[CMD]] : !hal.command_buffer> layout(%{{.+}} : !hal.executable_layout) offset(0) values([%c0_i32, %c0_i32]) : i32, i32
+ // CHECK: %[[BUFFER:.+]] = util.global.load @ex0_embedded_elf_x86_64_dispatch0_512x1x1_buffer
+ // CHECK: hal.command_buffer.push_descriptor_set<%[[CMD]] : !hal.command_buffer> layout(%{{.+}} : !hal.executable_layout)[%c0] bindings([
+ // CHECK-NEXT: %c0 = (%[[BUFFER]] : !hal.buffer)[%c0, %c32],
+ // CHECK-NEXT: %c1 = (%[[BUFFER]] : !hal.buffer)[%c256, %c32],
+ // CHECK-NEXT: %c2 = (%[[BUFFER]] : !hal.buffer)[%c512, %c32]
+ // CHECK-NEXT: ])
+
+ // Dispatch up to batch size dispatches:
+ // CHECK: scf.for %{{.+}} = %c0 to %[[BATCH_SIZE]] step %c1 {
+ // CHECK-NEXT: hal.command_buffer.dispatch.symbol<%[[CMD]] : !hal.command_buffer> target(@ex0::@embedded_elf_x86_64::@dispatch0) workgroups([%c128, %c1, %c1])
+ // CHECK-NEXT: hal.command_buffer.execution_barrier
+ // CHECK-NEXT: }
+
+ // Submit and wait for dispatches to complete:
+ // CHECK: hal.command_buffer.end<%[[CMD]] : !hal.command_buffer>
+ // CHECK: hal.ex.submit_and_wait %{{.+}}, %[[CMD]]
+
+ // ===========================================================================
+ // @dispatch1 benchmark logic (note two deduplicated dispatches):
+ // ===========================================================================
+
+ // CHECK: util.global private mutable @ex0_embedded_elf_x86_64_dispatch1_512x1x1_buffer : !hal.buffer
+ // CHECK: func @ex0_embedded_elf_x86_64_dispatch1_512x1x1(%arg0: i32)
+ // CHECK: hal.command_buffer.dispatch.symbol<%{{.+}} : !hal.command_buffer> target(@ex0::@embedded_elf_x86_64::@dispatch1) workgroups([%c128, %c1, %c1])
+
+ // CHECK: util.global private mutable @ex0_embedded_elf_x86_64_dispatch1_128x32x1_buffer : !hal.buffer
+ // CHECK: func @ex0_embedded_elf_x86_64_dispatch1_128x32x1(%arg0: i32)
+ // CHECK: hal.command_buffer.dispatch.symbol<%{{.+}} : !hal.command_buffer> target(@ex0::@embedded_elf_x86_64::@dispatch1) workgroups([%c32, %c1, %c1])
+
+ func private @main() -> !stream.timepoint {
+ %c0 = arith.constant 0 : index
+ %c1 = arith.constant 1 : index
+ %c32 = arith.constant 32 : index
+ %c64 = arith.constant 64 : index
+ %c128 = arith.constant 128 : index
+ %c512 = arith.constant 512 : index
+ %c0_i32 = arith.constant 0 : i32
+ %c1_i32 = arith.constant 1 : i32
+ %result, %result_timepoint = stream.resource.alloca uninitialized : !stream.resource<transient>{%c128} => !stream.timepoint
+ %6 = stream.cmd.execute await(%result_timepoint) => with(%result as %arg0: !stream.resource<transient>{%c128}) {
+    // Dispatch with push constant operands.
+ stream.cmd.dispatch @ex0::@dispatch0[%c512, %c1, %c1](%c0_i32, %c1_i32 : i32, i32) {
+ ro %arg0[%c0 for %c32] : !stream.resource<transient>{%c128},
+ rw %arg0[%c32 for %c32] : !stream.resource<transient>{%c128},
+ rw %arg0[%c64 for %c32] : !stream.resource<transient>{%c128}
+ } attributes {hal.interface.bindings = [
+ #hal.interface.binding<0, 0>,
+ #hal.interface.binding<0, 1>,
+ #hal.interface.binding<0, 2>
+ ]}
+
+ // Multiple dispatches to a single entry point.
+ // Dispatches are deduplicated and the two 128x32x1 should combine.
+ stream.cmd.dispatch @ex0::@dispatch1[%c512, %c1, %c1] {
+ ro %arg0[%c0 for %c64] : !stream.resource<transient>{%c128},
+ rw %arg0[%c64 for %c32] : !stream.resource<transient>{%c128}
+ } attributes {hal.interface.bindings = [
+ #hal.interface.binding<0, 0>,
+ #hal.interface.binding<0, 1>
+ ]}
+ stream.cmd.dispatch @ex0::@dispatch1[%c128, %c32, %c1] {
+ ro %arg0[%c0 for %c64] : !stream.resource<transient>{%c128},
+ rw %arg0[%c64 for %c32] : !stream.resource<transient>{%c128}
+ } attributes {hal.interface.bindings = [
+ #hal.interface.binding<0, 0>,
+ #hal.interface.binding<0, 1>
+ ]}
+ stream.cmd.dispatch @ex0::@dispatch1[%c128, %c32, %c1] {
+ ro %arg0[%c0 for %c64] : !stream.resource<transient>{%c128},
+ rw %arg0[%c64 for %c32] : !stream.resource<transient>{%c128}
+ } attributes {hal.interface.bindings = [
+ #hal.interface.binding<0, 0>,
+ #hal.interface.binding<0, 1>
+ ]}
+ } => !stream.timepoint
+ %39 = stream.resource.dealloca await(%6) => %result : !stream.resource<transient>{%c128} => !stream.timepoint
+ return %39 : !stream.timepoint
+ }
+}
diff --git a/iree/compiler/Dialect/Stream/Conversion/HALToStream/ConvertHALToStream.cpp b/iree/compiler/Dialect/Stream/Conversion/HALToStream/ConvertHALToStream.cpp
index c92f063..87bebc1 100644
--- a/iree/compiler/Dialect/Stream/Conversion/HALToStream/ConvertHALToStream.cpp
+++ b/iree/compiler/Dialect/Stream/Conversion/HALToStream/ConvertHALToStream.cpp
@@ -208,6 +208,10 @@
ConversionTarget &conversionTarget,
TypeConverter &typeConverter,
RewritePatternSet &patterns) {
+ // Allow executables through without modification.
+ conversionTarget.addLegalOp<IREE::HAL::ExecutableOp>();
+ conversionTarget.markOpRecursivelyLegal<IREE::HAL::ExecutableOp>();
+
conversionTarget.addDynamicallyLegalOp<IREE::HAL::TensorImportOp>(
[&](IREE::HAL::TensorImportOp op) {
return typeConverter.isLegal(op.source().getType()) &&
diff --git a/iree/compiler/Dialect/Stream/Transforms/BUILD b/iree/compiler/Dialect/Stream/Transforms/BUILD
index 531da5a..1c01300 100644
--- a/iree/compiler/Dialect/Stream/Transforms/BUILD
+++ b/iree/compiler/Dialect/Stream/Transforms/BUILD
@@ -46,6 +46,7 @@
deps = [
":PassesIncGen",
"//iree/compiler/Dialect/Flow/IR",
+ "//iree/compiler/Dialect/HAL/IR",
"//iree/compiler/Dialect/Stream/Analysis",
"//iree/compiler/Dialect/Stream/Conversion",
"//iree/compiler/Dialect/Stream/Conversion/FlowToStream",
diff --git a/iree/compiler/Dialect/Stream/Transforms/CMakeLists.txt b/iree/compiler/Dialect/Stream/Transforms/CMakeLists.txt
index c537436..2dfc4c7 100644
--- a/iree/compiler/Dialect/Stream/Transforms/CMakeLists.txt
+++ b/iree/compiler/Dialect/Stream/Transforms/CMakeLists.txt
@@ -59,6 +59,7 @@
MLIRTransforms
MLIRVector
iree::compiler::Dialect::Flow::IR
+ iree::compiler::Dialect::HAL::IR
iree::compiler::Dialect::Stream::Analysis
iree::compiler::Dialect::Stream::Conversion
iree::compiler::Dialect::Stream::Conversion::FlowToStream
diff --git a/iree/compiler/Dialect/Stream/Transforms/VerifyLowerings.cpp b/iree/compiler/Dialect/Stream/Transforms/VerifyLowerings.cpp
index 3ceeaab..062b0e9 100644
--- a/iree/compiler/Dialect/Stream/Transforms/VerifyLowerings.cpp
+++ b/iree/compiler/Dialect/Stream/Transforms/VerifyLowerings.cpp
@@ -6,6 +6,7 @@
#include <utility>
+#include "iree/compiler/Dialect/HAL/IR/HALOps.h"
#include "iree/compiler/Dialect/Stream/IR/StreamDialect.h"
#include "iree/compiler/Dialect/Stream/IR/StreamOps.h"
#include "iree/compiler/Dialect/Stream/IR/StreamTraits.h"
@@ -207,6 +208,10 @@
DenseMap<TypeID, TypeVerifierFn> typeVerifiers;
};
+static void setupDefaultOpLegality(Verifier &verifier) {
+ verifier.addRecursivelyLegalOp<IREE::HAL::ExecutableOp>();
+}
+
static void markStreamTensorOpsIllegal(Verifier &verifier) {
verifier.addOpVerifier([](Operation *op) -> Optional<Verifier::Legality> {
if (op->hasTrait<OpTrait::IREE::Stream::TensorPhaseOp>()) {
@@ -246,6 +251,7 @@
void runOnOperation() override {
Verifier verifier;
+ setupDefaultOpLegality(verifier);
// TODO(#7432): add indirect global expansion support to streams.
verifier.addIllegalOp<IREE::Util::GlobalAddressOp>();
@@ -297,6 +303,7 @@
// We cannot have stream.cmd.* ops mixed with stream.tensor/async.* ops
// as they use different memory models.
Verifier verifier;
+ setupDefaultOpLegality(verifier);
markTensorInputsIllegal(verifier);
markStreamCmdOpsIllegal(verifier);
if (failed(verifier.run(getOperation()))) {
@@ -332,6 +339,7 @@
// We cannot have stream.cmd.* ops mixed with stream.tensor/async.* ops
// as they use different memory models.
Verifier verifier;
+ setupDefaultOpLegality(verifier);
markTensorInputsIllegal(verifier);
markStreamTensorOpsIllegal(verifier);
markStreamCmdOpsIllegal(verifier);
@@ -390,6 +398,7 @@
void runOnOperation() override {
Verifier verifier;
+ setupDefaultOpLegality(verifier);
markTensorInputsIllegal(verifier);
markStreamTensorOpsIllegal(verifier);
markStreamAsyncOpsIllegal(verifier);