Add a preprocessing pass to move an entire function into a single dispatch. (#14578)
For cases where the model is very small and does not have much concurrency, it is better to move the entire function body into a single dispatch. Eventually the default heuristics should be able to figure out when a model is "too small", but for now this PR adds a pass that moves the entire function body into a single dispatch. This serves both as a way to surface the codegen issues such an approach runs into and as a vehicle for experimenting with the heuristics needed to find such dispatches automatically.
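As a rough before/after sketch of what the pass does (mirroring the `make_single_dispatch_for_function.mlir` lit test added below; SSA value names are illustrative), a function such as

    func.func @simple_test() -> tensor<10x20xf32> {
      %0 = tensor.empty() : tensor<10x20xf32>
      return %0 : tensor<10x20xf32>
    }

is rewritten so that its whole body lives inside a single `flow.dispatch.region`, with each `func.return` replaced by a `flow.return`:

    func.func @simple_test() -> tensor<10x20xf32> {
      %0 = flow.dispatch.region -> (tensor<10x20xf32>) {
        %1 = tensor.empty() : tensor<10x20xf32>
        flow.return %1 : tensor<10x20xf32>
      }
      return %0 : tensor<10x20xf32>
    }

Because the moved body may contain control flow, this also relaxes `flow.dispatch.region` and `flow.dispatch.workgroups` to allow multiple blocks (dropping their `SingleBlockImplicitTerminator` trait) and teaches the region-to-workgroups conversion to handle multiple `flow.return` terminators.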
diff --git a/compiler/src/iree/compiler/Dialect/Flow/IR/FlowOps.cpp b/compiler/src/iree/compiler/Dialect/Flow/IR/FlowOps.cpp
index 78d4e3b..c1a295f 100644
--- a/compiler/src/iree/compiler/Dialect/Flow/IR/FlowOps.cpp
+++ b/compiler/src/iree/compiler/Dialect/Flow/IR/FlowOps.cpp
@@ -280,33 +280,80 @@
// flow.dispatch.region
//===----------------------------------------------------------------------===//
-LogicalResult DispatchRegionOp::verify() {
- // No block arguments.
- if (!getBody().getArguments().empty())
- return emitOpError() << "expected no block arguments";
+// Verifies the workgroup count region.
- // Only one block.
- if (!getBody().hasOneBlock())
- return emitOpError() << "expected exactly 1 block";
+static LogicalResult
+verifyWorkgroupCountRegion(Operation *op, ValueRange workload, Region &region) {
+ // Verify the workload operands match the expected capture args.
+ if (workload.size() != region.getNumArguments()) {
+ return op->emitOpError()
+ << "workload operands and workgroup count args mismatch ("
+ << workload.size() << " vs " << region.getNumArguments() << ")";
+ }
+ for (auto [index, values] :
+ llvm::enumerate(llvm::zip_equal(workload, region.getArguments()))) {
+ auto [workloadValue, capturedArg] = values;
+ if (workloadValue.getType() != capturedArg.getType()) {
+ return op->emitOpError()
+ << "workload value " << index << " type mismatch; operand is "
+ << workloadValue.getType() << " but region captures "
+ << capturedArg.getType();
+ }
+ }
- // Verify terminator.
- auto returnOp = dyn_cast<Flow::ReturnOp>(getBody().front().getTerminator());
- if (!returnOp)
- return emitOpError() << "expected 'flow.return' terminator";
- for (const auto [resultType, returnType] :
- llvm::zip_equal(getResultTypes(), returnOp->getOperandTypes()))
- if (resultType != returnType)
- return returnOp->emitOpError()
- << "operand types do not match with parent results";
-
- // Make sure that all returned values are ranked tensors.
- for (Type t : getResultTypes())
- if (!llvm::isa<RankedTensorType>(t))
- return emitOpError() << "only ranked tensor results are allowed";
+ // Verify the return ops all provide XYZ values.
+ for (auto returnOp : region.getOps<IREE::Flow::ReturnOp>()) {
+ if (returnOp.getNumOperands() != 3 ||
+ !llvm::all_of(returnOp.getOperandTypes(),
+ [](Type type) { return type.isIndex(); })) {
+ return returnOp.emitOpError() << "workgroup count region must return "
+ "the XYZ dimension counts";
+ }
+ }
return success();
}
+LogicalResult DispatchRegionOp::verify() {
+ // No block arguments.
+ if (!getBody().getArguments().empty()) {
+ return emitOpError() << "expected no block arguments";
+ }
+
+ // Verify terminator.
+ SmallVector<Flow::ReturnOp> returnOps;
+ for (Block &block : getBody()) {
+ if (auto returnOp =
+ dyn_cast_or_null<Flow::ReturnOp>(block.getTerminator())) {
+ returnOps.push_back(returnOp);
+ }
+ }
+ for (auto returnOp : returnOps) {
+ for (const auto [resultType, returnType] :
+ llvm::zip_equal(getResultTypes(), returnOp->getOperandTypes()))
+ if (resultType != returnType) {
+ return returnOp->emitOpError()
+ << "operand types do not match with parent results";
+ }
+ }
+
+ // Make sure that all returned values are ranked tensors.
+ for (Type t : getResultTypes()) {
+ if (!llvm::isa<RankedTensorType>(t)) {
+ return emitOpError() << "only ranked tensor results are allowed";
+ }
+ }
+
+ Region &workgroupCount = getWorkgroupCount();
+ if (workgroupCount.empty()) {
+ return success();
+ }
+
+ // If the workgroup count region exists, verify its captures and return values.
+ return verifyWorkgroupCountRegion(getOperation(), getWorkload(),
+ getWorkgroupCount());
+}
+
ParseResult DispatchRegionOp::parse(OpAsmParser &parser,
OperationState &result) {
SmallVector<Type> resultTypes;
@@ -348,7 +395,6 @@
return failure();
if (parser.parseRegion(*bodyRegion))
return failure();
- ensureTerminator(*bodyRegion, parser.getBuilder(), result.location);
if (parseDispatchWorkgroupsCountRegion(parser, *workloadCountRegion)) {
return failure();
@@ -868,38 +914,6 @@
/*printBlockTerminators=*/true);
}
-LogicalResult verifyWorkgroupCountRegion(Operation *op, ValueRange workload,
- Region &region) {
- // Verify the workload operands match the expected capture args.
- if (workload.size() != region.getNumArguments()) {
- return op->emitOpError()
- << "workload operands and workgroup count args mismatch ("
- << workload.size() << " vs " << region.getNumArguments() << ")";
- }
- for (auto [index, values] :
- llvm::enumerate(llvm::zip_equal(workload, region.getArguments()))) {
- auto [workloadValue, capturedArg] = values;
- if (workloadValue.getType() != capturedArg.getType()) {
- return op->emitOpError()
- << "workload value " << index << " type mismatch; operand is "
- << workloadValue.getType() << " but region captures "
- << capturedArg.getType();
- }
- }
-
- // Verify the return ops all provide XYZ values.
- for (auto returnOp : region.getOps<IREE::Flow::ReturnOp>()) {
- if (returnOp.getNumOperands() != 3 ||
- !llvm::all_of(returnOp.getOperandTypes(),
- [](Type type) { return type.isIndex(); })) {
- return returnOp.emitOpError() << "workgroup count region must return "
- "the XYZ dimension counts";
- }
- }
-
- return success();
-}
-
LogicalResult DispatchWorkgroupsOp::verify() {
Operation *op = getOperation();
@@ -1043,7 +1057,7 @@
IREE::Util::ValueAccess
DispatchWorkgroupsOp::getResultAccess(unsigned resultIndex) {
- unsigned startIndex = getBody()->getNumArguments() - getNumResults();
+ unsigned startIndex = getWorkgroupBody().getNumArguments() - getNumResults();
BlockArgument arg =
getWorkgroupBody().front().getArgument(startIndex + resultIndex);
if (auto tensorType = llvm::dyn_cast<DispatchTensorType>(arg.getType())) {
diff --git a/compiler/src/iree/compiler/Dialect/Flow/IR/FlowOps.td b/compiler/src/iree/compiler/Dialect/Flow/IR/FlowOps.td
index 8affac2..e2a2335 100644
--- a/compiler/src/iree/compiler/Dialect/Flow/IR/FlowOps.td
+++ b/compiler/src/iree/compiler/Dialect/Flow/IR/FlowOps.td
@@ -35,8 +35,7 @@
def FLOW_DispatchRegionOp : FLOW_PureOp<"dispatch.region", [
Util_ShapeAwareOp,
- AttrSizedOperandSegments,
- SingleBlockImplicitTerminator<"IREE::Flow::ReturnOp">]> {
+ AttrSizedOperandSegments]> {
let summary = [{a group of ops}];
let description = [{
This op is a container/grouping of ops. It represents a fusion group before
@@ -76,7 +75,6 @@
def FLOW_DispatchWorkgroupsOp : FLOW_PureOp<"dispatch.workgroups", [
IsolatedFromAbove,
AttrSizedOperandSegments,
- SingleBlockImplicitTerminator<"IREE::Flow::ReturnOp">,
DeclareOpInterfaceMethods<Util_ClosureOpInterface>,
DeclareOpInterfaceMethods<Util_TiedOpInterface, [
"getTiedOperandsIndexAndLength",
diff --git a/compiler/src/iree/compiler/Dialect/Flow/Transforms/CaptureDispatchDynamicDims.cpp b/compiler/src/iree/compiler/Dialect/Flow/Transforms/CaptureDispatchDynamicDims.cpp
index 0c99de0..614a391 100644
--- a/compiler/src/iree/compiler/Dialect/Flow/Transforms/CaptureDispatchDynamicDims.cpp
+++ b/compiler/src/iree/compiler/Dialect/Flow/Transforms/CaptureDispatchDynamicDims.cpp
@@ -30,7 +30,11 @@
// leave the cleanup of redundant work to further optimization passes to keep
// this simple.
static void captureDims(IREE::Flow::DispatchWorkgroupsOp dispatchOp) {
- auto *entryBlock = dispatchOp.getBody();
+ Region &body = dispatchOp.getWorkgroupBody();
+ if (body.empty()) {
+ return;
+ }
+ auto *entryBlock = &body.front();
// Map of SSA values on the outside of the op to arguments on the inside.
// This lets us avoid capturing duplicate values - they'd be cleaned up
diff --git a/compiler/src/iree/compiler/Dialect/Flow/Transforms/ConvertRegionToWorkgroups.cpp b/compiler/src/iree/compiler/Dialect/Flow/Transforms/ConvertRegionToWorkgroups.cpp
index 601f640..802b6ef 100644
--- a/compiler/src/iree/compiler/Dialect/Flow/Transforms/ConvertRegionToWorkgroups.cpp
+++ b/compiler/src/iree/compiler/Dialect/Flow/Transforms/ConvertRegionToWorkgroups.cpp
@@ -63,13 +63,15 @@
while (!isOutside(value)) {
auto tiedOpInterface = value.getDefiningOp<IREE::Util::TiedOpInterface>();
- if (!tiedOpInterface)
+ if (!tiedOpInterface) {
// Reached an op that does not implement the interface.
return std::nullopt;
+ }
value = tiedOpInterface.getTiedResultOperand(value);
- if (!value)
+ if (!value) {
// Nothing is tied here.
return std::nullopt;
+ }
}
return value;
@@ -84,13 +86,13 @@
FailureOr<Flow::DispatchWorkgroupsOp>
rewriteFlowDispatchRegionToFlowDispatchWorkgroups(
Flow::DispatchRegionOp regionOp, RewriterBase &rewriter) {
- // Only ops with a single block are supported.
Region &region = regionOp.getBody();
- if (!region.hasOneBlock())
- return failure();
- Block &body = region.front();
- auto terminator = cast<Flow::ReturnOp>(body.getTerminator());
- unsigned numResults = terminator->getNumOperands();
+ // Currently this does not handle empty `flow.dispatch.region` ops.
+ if (region.empty()) {
+ return rewriter.notifyMatchFailure(regionOp,
+ "unhandled op with empty region");
+ }
+ unsigned numResults = regionOp->getNumResults();
// Prepare rewriter.
OpBuilder::InsertionGuard guard(rewriter);
@@ -118,7 +120,14 @@
DenseSet<Value> tiedArgumentsSet;
SmallVector<int64_t> tiedArguments(numResults,
IREE::Util::TiedOpInterface::kUntiedIndex);
- for (const auto &it : llvm::enumerate(terminator->getOperands())) {
+ SmallVector<Flow::ReturnOp> origTerminators;
+ region.walk(
+ [&](Flow::ReturnOp returnOp) { origTerminators.push_back(returnOp); });
+ assert(!origTerminators.empty() && "expected at least one terminator");
+ // Use one of the terminators to get the `tiedArguments` set.
+ // TODO: Check that using all terminators gives the same result.
+ for (const auto &it :
+ llvm::enumerate(origTerminators.front()->getOperands())) {
auto tiedArgument =
findFirstTiedValueOutsideOfRegionOp(regionOp, it.value());
if (!tiedArgument.has_value())
@@ -166,15 +175,20 @@
bvm.map(arguments, workgroupsOp.getInputBlockArguments());
// Create DispatchTensorLoadOp for all tensor arguments.
- assert(workgroupsOp.getWorkgroupBody().hasOneBlock() &&
- "expected one block after constructor");
- Block &newBody = workgroupsOp.getWorkgroupBody().getBlocks().front();
- assert(newBody.empty() && "expected empty block after constructor");
- rewriter.setInsertionPointToStart(&newBody);
+ Region &newBody = workgroupsOp.getWorkgroupBody();
+ assert(llvm::hasSingleElement(newBody) &&
+ "expected `flow.dispatch.workgroup` op to be created with a single "
+ "block");
+
+ Block *newBodyEntry = &newBody.front();
+ rewriter.setInsertionPointToStart(newBodyEntry);
+ SmallVector<Value> argValues;
for (const auto &it : llvm::enumerate(arguments)) {
auto tensorType = llvm::dyn_cast<RankedTensorType>(it.value().getType());
- if (!tensorType)
+ if (!tensorType) {
+ argValues.push_back(it.value());
continue;
+ }
auto inputBbArg = workgroupsOp.getInputBlockArgument(it.index());
auto dims =
Util::findVariadicDynamicDims(it.index(), arguments, argumentDims);
@@ -185,10 +199,16 @@
Value loadedTensor = rewriter.create<IREE::Flow::DispatchTensorLoadOp>(
loc, tensorType, inputBbArg, bbArgDims);
bvm.map(it.value(), loadedTensor);
+ argValues.push_back(loadedTensor);
}
// Move regionOp body into the workgroupsOp.
- newBody.getOperations().splice(newBody.end(), body.getOperations());
+ rewriter.inlineRegionBefore(region, newBody, newBody.end());
+ // Merge the entry block of `newBody` with the original entry block from
+ // the region.
+ Block *origEntry = &(*(std::next(newBody.begin())));
+ rewriter.mergeBlocks(origEntry, newBodyEntry);
+
for (Value argument : arguments) {
argument.replaceUsesWithIf(bvm.lookup(argument), [&](OpOperand &operand) {
return workgroupsOp->isProperAncestor(operand.getOwner());
@@ -196,33 +216,38 @@
}
// Update terminator.
- rewriter.setInsertionPoint(terminator);
- for (const auto &it : llvm::enumerate(terminator->getOperands())) {
- auto outputBbArg = workgroupsOp.getOutputBlockArgument(it.index());
- ValueRange dims;
- if (tiedArguments[it.index()] ==
- IREE::Util::TiedOpInterface::kUntiedIndex) {
- dims = regionOp.getResultDynamicDims(it.index());
- } else {
- // This assumes that the number of dynamic dims does not change when
- // following an SSA use-def chain of tied values.
- dims = Util::findVariadicDynamicDims(tiedArguments[it.index()], arguments,
- argumentDims);
- }
+ SmallVector<Flow::ReturnOp> terminators;
+ newBody.walk(
+ [&](Flow::ReturnOp returnOp) { terminators.push_back(returnOp); });
+ for (auto terminator : terminators) {
+ rewriter.setInsertionPoint(terminator);
+ for (const auto &it : llvm::enumerate(terminator->getOperands())) {
+ auto outputBbArg = workgroupsOp.getOutputBlockArgument(it.index());
+ ValueRange dims;
+ if (tiedArguments[it.index()] ==
+ IREE::Util::TiedOpInterface::kUntiedIndex) {
+ dims = regionOp.getResultDynamicDims(it.index());
+ } else {
+ // This assumes that the number of dynamic dims does not change when
+ // following an SSA use-def chain of tied values.
+ dims = Util::findVariadicDynamicDims(tiedArguments[it.index()],
+ arguments, argumentDims);
+ }
#ifndef NDEBUG
- auto tensorType = it.value().getType().cast<RankedTensorType>();
- assert(dims.size() == tensorType.getNumDynamicDims() &&
- "mismatching number of dynamic dims");
+ auto tensorType = it.value().getType().cast<RankedTensorType>();
+ assert(dims.size() == tensorType.getNumDynamicDims() &&
+ "mismatching number of dynamic dims");
#endif // NDEBUG
- SmallVector<Value> bbArgDims =
- llvm::map_to_vector(dims, [&](Value v) { return bvm.lookup(v); });
- rewriter.create<IREE::Flow::DispatchTensorStoreOp>(loc, it.value(),
- outputBbArg, bbArgDims);
- }
+ SmallVector<Value> bbArgDims =
+ llvm::map_to_vector(dims, [&](Value v) { return bvm.lookup(v); });
+ rewriter.create<IREE::Flow::DispatchTensorStoreOp>(
+ loc, it.value(), outputBbArg, bbArgDims);
+ }
- // Delete the old terminator and create a new one.
- rewriter.create<IREE::Flow::ReturnOp>(loc);
- rewriter.eraseOp(terminator);
+ // Delete the old terminator and create a new one.
+ rewriter.create<IREE::Flow::ReturnOp>(loc);
+ rewriter.eraseOp(terminator);
+ }
rewriter.replaceOp(regionOp, workgroupsOp.getResults());
return workgroupsOp;
diff --git a/compiler/src/iree/compiler/Dialect/Flow/Transforms/FormDispatchWorkgroups.cpp b/compiler/src/iree/compiler/Dialect/Flow/Transforms/FormDispatchWorkgroups.cpp
index 5a6396f..d5cd0fa 100644
--- a/compiler/src/iree/compiler/Dialect/Flow/Transforms/FormDispatchWorkgroups.cpp
+++ b/compiler/src/iree/compiler/Dialect/Flow/Transforms/FormDispatchWorkgroups.cpp
@@ -236,7 +236,11 @@
// Annotate the values captures as workload with their position in the
// workload list.
- rewriter.setInsertionPointToStart(workgroupsOp.getBody());
+ Region &body = workgroupsOp.getWorkgroupBody();
+ if (body.empty()) {
+ return;
+ }
+ rewriter.setInsertionPointToStart(&body.front());
int ordinalNumber = 0;
for (auto [index, operand] : llvm::enumerate(workgroupsOp.getArguments())) {
if (!llvm::isa<IndexType>(operand.getType()))
diff --git a/compiler/src/iree/compiler/Dialect/Flow/Transforms/OutlineDispatchRegions.cpp b/compiler/src/iree/compiler/Dialect/Flow/Transforms/OutlineDispatchRegions.cpp
index 6088473..f15114e 100644
--- a/compiler/src/iree/compiler/Dialect/Flow/Transforms/OutlineDispatchRegions.cpp
+++ b/compiler/src/iree/compiler/Dialect/Flow/Transforms/OutlineDispatchRegions.cpp
@@ -201,7 +201,7 @@
Operation *bestOp = NULL;
const int64_t kMinEstimatedCost = -1;
int64_t bestEstimatedCost = kMinEstimatedCost;
- regionOp.getBodyRegion().walk([&](Operation *op) {
+ regionOp.getWorkgroupBody().walk([&](Operation *op) {
TypeSwitch<Operation *>(op)
.Case<linalg::LinalgOp>([&](auto op) {
int64_t estimatedCost = estimateLinalgOpCost(op);
diff --git a/compiler/src/iree/compiler/Dialect/Flow/Transforms/test/form_dispatch_workgroups.mlir b/compiler/src/iree/compiler/Dialect/Flow/Transforms/test/form_dispatch_workgroups.mlir
index 01fe2cc..85f263b 100644
--- a/compiler/src/iree/compiler/Dialect/Flow/Transforms/test/form_dispatch_workgroups.mlir
+++ b/compiler/src/iree/compiler/Dialect/Flow/Transforms/test/form_dispatch_workgroups.mlir
@@ -14,3 +14,32 @@
// CHECK: count(%[[ARG2:[a-zA-Z0-9]+]]: index, %[[ARG3:[a-zA-Z0-9]+]]: index)
// CHECK: %[[C1:.+]] = arith.constant 1 : index
// CHECK: flow.return %[[ARG2]], %[[ARG3]], %[[C1]]
+
+// -----
+
+func.func @simple_test_with_cfg(%arg0: i1) -> (tensor<10x20xf32>) {
+ %cst = arith.constant dense<1.000000e+00> : tensor<10x20xf32>
+ %0 = flow.dispatch.region -> (tensor<10x20xf32>) {
+ %cst_0 = arith.constant dense<1.000000e+00> : tensor<10x20xf32>
+ cf.cond_br %arg0, ^bb1, ^bb2
+ ^bb1: // pred: ^bb0
+ %2 = tensor.empty() : tensor<10x20xf32>
+ flow.return %2 : tensor<10x20xf32>
+ ^bb2: // pred: ^bb0
+ flow.return %cst_0 : tensor<10x20xf32>
+ }
+ return %0 : tensor<10x20xf32>
+}
+// CHECK-LABEL: func @simple_test_with_cfg
+// CHECK-SAME: %[[ARG0:.+]]: i1
+// CHECK: %[[RESULT:.+]] = flow.dispatch.workgroups(%[[ARG0]])
+// CHECK-NEXT: %[[ARG1:.+]]: i1, %[[ARG2:.+]]: !flow.dispatch.tensor
+// CHECK: %[[CST:.+]] = arith.constant
+// CHECK: ^[[BB1:.+]]:
+// CHECK: %[[EMPTY:.+]] = tensor.empty()
+// CHECK: flow.dispatch.tensor.store %[[EMPTY]], %[[ARG2]]
+// CHECK: flow.return
+// CHECK: ^[[BB2:.+]]:
+// CHECK: flow.dispatch.tensor.store %[[CST]], %[[ARG2]]
+// CHECK: flow.return
+// CHECK: return %[[RESULT]]
diff --git a/compiler/src/iree/compiler/Preprocessing/Common/BUILD.bazel b/compiler/src/iree/compiler/Preprocessing/Common/BUILD.bazel
index c7a6702..dfd3178 100644
--- a/compiler/src/iree/compiler/Preprocessing/Common/BUILD.bazel
+++ b/compiler/src/iree/compiler/Preprocessing/Common/BUILD.bazel
@@ -31,6 +31,7 @@
name = "Transforms",
srcs = [
"ConvertConv2DToImg2Col.cpp",
+ "MakeSingleDispatchForFunction.cpp",
"PadLinalgOps.cpp",
"PassDetail.h",
"Passes.cpp",
@@ -42,6 +43,7 @@
],
deps = [
":PassesIncGen",
+ "//compiler/src/iree/compiler/Dialect/Flow/IR",
"//compiler/src/iree/compiler/Dialect/Flow/Transforms",
"@llvm-project//llvm:Support",
"@llvm-project//mlir:ArithDialect",
diff --git a/compiler/src/iree/compiler/Preprocessing/Common/CMakeLists.txt b/compiler/src/iree/compiler/Preprocessing/Common/CMakeLists.txt
index b0ec7b5..7f59a88 100644
--- a/compiler/src/iree/compiler/Preprocessing/Common/CMakeLists.txt
+++ b/compiler/src/iree/compiler/Preprocessing/Common/CMakeLists.txt
@@ -27,6 +27,7 @@
"Passes.h.inc"
SRCS
"ConvertConv2DToImg2Col.cpp"
+ "MakeSingleDispatchForFunction.cpp"
"PadLinalgOps.cpp"
"PassDetail.h"
"Passes.cpp"
@@ -43,6 +44,7 @@
MLIRTensorDialect
MLIRTensorUtils
MLIRTransforms
+ iree::compiler::Dialect::Flow::IR
iree::compiler::Dialect::Flow::Transforms
PUBLIC
)
diff --git a/compiler/src/iree/compiler/Preprocessing/Common/MakeSingleDispatchForFunction.cpp b/compiler/src/iree/compiler/Preprocessing/Common/MakeSingleDispatchForFunction.cpp
new file mode 100644
index 0000000..b452302
--- /dev/null
+++ b/compiler/src/iree/compiler/Preprocessing/Common/MakeSingleDispatchForFunction.cpp
@@ -0,0 +1,97 @@
+// Copyright 2023 The IREE Authors
+//
+// Licensed under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+
+#include "iree/compiler/Dialect/Flow/IR/FlowOps.h"
+#include "iree/compiler/Preprocessing/Common/PassDetail.h"
+#include "iree/compiler/Preprocessing/Common/Passes.h"
+#include "mlir/Dialect/Arith/IR/Arith.h"
+#include "mlir/IR/PatternMatch.h"
+
+namespace mlir {
+namespace iree_compiler {
+namespace IREE {
+
+namespace {
+
+struct MakeSingleDispatchForFunctionPass
+ : public MakeSingleDispatchForFunctionBase<
+ MakeSingleDispatchForFunctionPass> {
+ void getDependentDialects(DialectRegistry &registry) const override {
+ registry.insert<IREE::Flow::FlowDialect>();
+ }
+
+ void runOnOperation() override;
+};
+} // namespace
+
+void MakeSingleDispatchForFunctionPass::runOnOperation() {
+ auto funcOp = getOperation();
+
+ // Abort if there are any operations that prevent moving all operations
+ // into a single dispatch.
+ auto walkResult = funcOp.walk([](Operation *op) -> WalkResult {
+ return success(!isa<func::CallOp>(op));
+ });
+ if (walkResult.wasInterrupted()) {
+ funcOp->emitOpError("unhandled operation in function body prevents moving "
+ "body into a single dispatch");
+ return signalPassFailure();
+ }
+
+ // Currently this can only be done for static shapes because there is no
+ // way of getting the tied dynamic shapes for a function.
+ auto resultTypes = funcOp.getFunctionType().getResults();
+ if (llvm::any_of(resultTypes, [&](Type t) {
+ auto shapedType = t.dyn_cast<ShapedType>();
+ return shapedType && !shapedType.hasStaticShape();
+ })) {
+ return;
+ }
+
+ IRRewriter rewriter(&getContext());
+ Location loc = funcOp.getLoc();
+ Region &funcBody = funcOp.getBody();
+
+ // Split the function entry block to create a new entry block into which the
+ // new operations will be added.
+ Block &entryBlock = funcBody.front();
+ Block *funcBodyStart = rewriter.splitBlock(&entryBlock, entryBlock.begin());
+
+ // Create an empty `flow.dispatch.region` operation with same result type as
+ // the function.
+ rewriter.setInsertionPointToEnd(&entryBlock);
+ auto dispatchRegionOp = rewriter.create<IREE::Flow::DispatchRegionOp>(
+ loc, resultTypes, /*result_dims=*/ValueRange{},
+ /*workload=*/ValueRange{});
+
+ // Move the body of the function into the region.
+ Region ®ion = dispatchRegionOp.getBody();
+ region.getBlocks().splice(region.begin(), funcBody.getBlocks(),
+ Region::iterator(funcBodyStart), funcBody.end());
+
+ // Replace all `func.return` with `flow.return`.
+ SmallVector<func::ReturnOp> returnOps =
+ llvm::to_vector(region.getOps<func::ReturnOp>());
+ for (auto returnOp : returnOps) {
+ OpBuilder::InsertionGuard g(rewriter);
+ rewriter.setInsertionPoint(returnOp);
+ rewriter.replaceOpWithNewOp<IREE::Flow::ReturnOp>(returnOp,
+ returnOp.getOperands());
+ }
+
+ // Return the results of the `flow.dispatch.region`.
+ rewriter.setInsertionPointAfter(dispatchRegionOp);
+ rewriter.create<func::ReturnOp>(loc, dispatchRegionOp.getResults());
+}
+
+std::unique_ptr<OperationPass<func::FuncOp>>
+createMakeSingleDispatchForFunctionPass() {
+ return std::make_unique<MakeSingleDispatchForFunctionPass>();
+}
+
+} // namespace IREE
+} // namespace iree_compiler
+} // namespace mlir
diff --git a/compiler/src/iree/compiler/Preprocessing/Common/Passes.h b/compiler/src/iree/compiler/Preprocessing/Common/Passes.h
index bd5d07b..c9835a9 100644
--- a/compiler/src/iree/compiler/Preprocessing/Common/Passes.h
+++ b/compiler/src/iree/compiler/Preprocessing/Common/Passes.h
@@ -17,11 +17,15 @@
namespace iree_compiler {
namespace IREE {
-// Creates a pass to convert linalg convolution ops into linalg.matmul ops
-// using im2col tranformation.
+/// Creates a pass to convert linalg convolution ops into linalg.matmul ops
+/// using im2col transformation.
std::unique_ptr<Pass> createConvertConv2DToImg2ColPass();
-// A pass to pad linalg ops to the next integer multiple of `paddingSize`.
+/// Moves the body of the entire function into a single dispatch.
+std::unique_ptr<OperationPass<func::FuncOp>>
+createMakeSingleDispatchForFunctionPass();
+
+/// A pass to pad linalg ops to the next integer multiple of `paddingSize`.
std::unique_ptr<Pass> createPadLinalgOpsToIntegerMultiplePass();
/// Pass to merge parallel linalg operations.
diff --git a/compiler/src/iree/compiler/Preprocessing/Common/Passes.td b/compiler/src/iree/compiler/Preprocessing/Common/Passes.td
index e063ef7..9c88781 100644
--- a/compiler/src/iree/compiler/Preprocessing/Common/Passes.td
+++ b/compiler/src/iree/compiler/Preprocessing/Common/Passes.td
@@ -15,6 +15,12 @@
let constructor = "mlir::iree_compiler::IREE::createConvertConv2DToImg2ColPass()";
}
+def MakeSingleDispatchForFunction :
+ Pass<"iree-preprocessing-make-single-dispatch-for-function", "func::FuncOp"> {
+ let summary = "Convert entire function into a single dispatch";
+ let constructor = "mlir::iree_compiler::IREE::createMakeSingleDispatchForFunctionPass()";
+}
+
def PadLinalgOps :
Pass<"iree-preprocessing-pad-linalg-ops", ""> {
let summary = "Pad linalg ops to the next integer multiple of paddingSize.";
diff --git a/compiler/src/iree/compiler/Preprocessing/Common/test/BUILD.bazel b/compiler/src/iree/compiler/Preprocessing/Common/test/BUILD.bazel
index 9dd501c..37d77d7 100644
--- a/compiler/src/iree/compiler/Preprocessing/Common/test/BUILD.bazel
+++ b/compiler/src/iree/compiler/Preprocessing/Common/test/BUILD.bazel
@@ -17,6 +17,7 @@
srcs = enforce_glob(
[
"conv2d_to_img2col.mlir",
+ "make_single_dispatch_for_function.mlir",
"pad_linalg_ops.mlir",
"rematerialize_parallel_ops.mlir",
],
diff --git a/compiler/src/iree/compiler/Preprocessing/Common/test/CMakeLists.txt b/compiler/src/iree/compiler/Preprocessing/Common/test/CMakeLists.txt
index a754aea..0cca180 100644
--- a/compiler/src/iree/compiler/Preprocessing/Common/test/CMakeLists.txt
+++ b/compiler/src/iree/compiler/Preprocessing/Common/test/CMakeLists.txt
@@ -15,6 +15,7 @@
lit
SRCS
"conv2d_to_img2col.mlir"
+ "make_single_dispatch_for_function.mlir"
"pad_linalg_ops.mlir"
"rematerialize_parallel_ops.mlir"
TOOLS
diff --git a/compiler/src/iree/compiler/Preprocessing/Common/test/make_single_dispatch_for_function.mlir b/compiler/src/iree/compiler/Preprocessing/Common/test/make_single_dispatch_for_function.mlir
new file mode 100644
index 0000000..8b8a155
--- /dev/null
+++ b/compiler/src/iree/compiler/Preprocessing/Common/test/make_single_dispatch_for_function.mlir
@@ -0,0 +1,35 @@
+// RUN: iree-opt --iree-preprocessing-make-single-dispatch-for-function --split-input-file %s | FileCheck %s
+
+func.func @simple_test() -> tensor<10x20xf32> {
+ %0 = tensor.empty() : tensor<10x20xf32>
+ return %0 : tensor<10x20xf32>
+}
+// CHECK-LABEL: func @simple_test() -> tensor<10x20xf32>
+// CHECK-NEXT: %[[DISPATCH:.+]] = flow.dispatch.region
+// CHECK: %[[EMPTY:.+]] = tensor.empty
+// CHECK: flow.return %[[EMPTY]]
+// CHECK: return %[[DISPATCH]]
+
+// -----
+
+func.func @simple_test_with_cfg(%arg0 : i1) -> tensor<10x20xf32> {
+ cf.cond_br %arg0, ^bb1, ^bb2
+ ^bb1:
+ %0 = tensor.empty() : tensor<10x20xf32>
+ return %0 : tensor<10x20xf32>
+ ^bb2:
+ %1 = arith.constant dense<1.0> : tensor<10x20xf32>
+ return %1 : tensor<10x20xf32>
+}
+// CHECK-LABEL: func @simple_test_with_cfg
+// CHECK-SAME: %[[ARG0:.+]]: i1
+// CHECK-NEXT: %[[DISPATCH:.+]] = flow.dispatch.region -> (tensor<10x20xf32>) {
+// CHECK: cf.cond_br %[[ARG0]], ^[[BB1:[a-zA-Z0-9]+]], ^[[BB2:[a-zA-Z0-9]+]]
+// CHECK: ^[[BB1]]:
+// CHECK: %[[EMPTY:.+]] = tensor.empty
+// CHECK: flow.return %[[EMPTY]]
+// CHECK: ^[[BB2]]:
+// CHECK: %[[CST:.+]] = arith.constant
+// CHECK: flow.return %[[CST]]
+// CHECK: }
+// CHECK: return %[[DISPATCH]]
diff --git a/tests/e2e/regression/BUILD.bazel b/tests/e2e/regression/BUILD.bazel
index bd18c9c..11fd236 100644
--- a/tests/e2e/regression/BUILD.bazel
+++ b/tests/e2e/regression/BUILD.bazel
@@ -38,6 +38,7 @@
name = "lit",
srcs = [
"fill_i64.mlir",
+ "force_single_dispatch.mlir",
"globals.mlir",
"libm_linking.mlir",
"scalar.mlir",
diff --git a/tests/e2e/regression/CMakeLists.txt b/tests/e2e/regression/CMakeLists.txt
index 7f506a2..cb6b2d1 100644
--- a/tests/e2e/regression/CMakeLists.txt
+++ b/tests/e2e/regression/CMakeLists.txt
@@ -15,6 +15,7 @@
lit
SRCS
"fill_i64.mlir"
+ "force_single_dispatch.mlir"
"globals.mlir"
"libm_linking.mlir"
"scalar.mlir"
diff --git a/tests/e2e/regression/force_single_dispatch.mlir b/tests/e2e/regression/force_single_dispatch.mlir
new file mode 100644
index 0000000..8876d95
--- /dev/null
+++ b/tests/e2e/regression/force_single_dispatch.mlir
@@ -0,0 +1,14 @@
+// RUN: iree-opt --iree-preprocessing-make-single-dispatch-for-function %s | iree-run-mlir --Xcompiler,iree-hal-target-backends=llvm-cpu --input="1" -
+func.func @simple_test_with_cfg(%arg0 : i8) -> tensor<2x4xf32> {
+ %c0_i8 = arith.constant 0 : i8
+ %cond = arith.cmpi eq, %arg0, %c0_i8 : i8
+ cf.cond_br %cond, ^bb1, ^bb2
+ ^bb1:
+ %0 = tensor.empty() : tensor<2x4xf32>
+ return %0 : tensor<2x4xf32>
+ ^bb2:
+ %1 = arith.constant dense<1.0> : tensor<2x4xf32>
+ return %1 : tensor<2x4xf32>
+}
+// CHECK-LABEL: EXEC @simple_test_with_cfg
+// CHECK: 2x4xf32=[1 1 1 1][1 1 1 1]