Add a preprocessing pass to move entire function into a single dispatch. (#14578)

For cases where the model is very small and does not have much concurrency, it is better to move the entire function body into a single dispatch. Eventually the default heuristics can probably figure out when a model is "too small", but for now this PR adds a pass that moves the entire function body into a single dispatch. This serves both as a way to surface the codegen issues such an approach exposes, and as a vehicle for experimenting with the heuristics needed to find such dispatches automatically.
diff --git a/compiler/src/iree/compiler/Dialect/Flow/IR/FlowOps.cpp b/compiler/src/iree/compiler/Dialect/Flow/IR/FlowOps.cpp
index 78d4e3b..c1a295f 100644
--- a/compiler/src/iree/compiler/Dialect/Flow/IR/FlowOps.cpp
+++ b/compiler/src/iree/compiler/Dialect/Flow/IR/FlowOps.cpp
@@ -280,33 +280,80 @@
 // flow.dispatch.region
 //===----------------------------------------------------------------------===//
 
-LogicalResult DispatchRegionOp::verify() {
-  // No block arguments.
-  if (!getBody().getArguments().empty())
-    return emitOpError() << "expected no block arguments";
+// Verifies the workgroup count region: the workload operands must match the
+// region's captured arguments, and every return must yield the XYZ index
+// counts.
 
-  // Only one block.
-  if (!getBody().hasOneBlock())
-    return emitOpError() << "expected exactly 1 block";
+static LogicalResult
+verifyWorkgroupCountRegion(Operation *op, ValueRange workload, Region &region) {
+  // Verify the workload operands match the expected capture args.
+  if (workload.size() != region.getNumArguments()) {
+    return op->emitOpError()
+           << "workload operands and workgroup count args mismatch ("
+           << workload.size() << " vs " << region.getNumArguments() << ")";
+  }
+  for (auto [index, values] :
+       llvm::enumerate(llvm::zip_equal(workload, region.getArguments()))) {
+    auto [workloadValue, capturedArg] = values;
+    if (workloadValue.getType() != capturedArg.getType()) {
+      return op->emitOpError()
+             << "workload value " << index << " type mismatch; operand is "
+             << workloadValue.getType() << " but region captures "
+             << capturedArg.getType();
+    }
+  }
 
-  // Verify terminator.
-  auto returnOp = dyn_cast<Flow::ReturnOp>(getBody().front().getTerminator());
-  if (!returnOp)
-    return emitOpError() << "expected 'flow.return' terminator";
-  for (const auto [resultType, returnType] :
-       llvm::zip_equal(getResultTypes(), returnOp->getOperandTypes()))
-    if (resultType != returnType)
-      return returnOp->emitOpError()
-             << "operand types do not match with parent results";
-
-  // Make sure that all returned values are ranked tensors.
-  for (Type t : getResultTypes())
-    if (!llvm::isa<RankedTensorType>(t))
-      return emitOpError() << "only ranked tensor results are allowed";
+  // Verify the return ops all provide XYZ values.
+  for (auto returnOp : region.getOps<IREE::Flow::ReturnOp>()) {
+    if (returnOp.getNumOperands() != 3 ||
+        !llvm::all_of(returnOp.getOperandTypes(),
+                      [](Type type) { return type.isIndex(); })) {
+      return returnOp.emitOpError() << "workgroup count region must return "
+                                       "the XYZ dimension counts";
+    }
+  }
 
   return success();
 }
 
+LogicalResult DispatchRegionOp::verify() {
+  // No block arguments.
+  if (!getBody().getArguments().empty()) {
+    return emitOpError() << "expected no block arguments";
+  }
+
+  // Verify terminator.
+  SmallVector<Flow::ReturnOp> returnOps;
+  for (Block &block : getBody()) {
+    if (auto returnOp =
+            dyn_cast_or_null<Flow::ReturnOp>(block.getTerminator())) {
+      returnOps.push_back(returnOp);
+    }
+  }
+  for (auto returnOp : returnOps) {
+    for (const auto [resultType, returnType] :
+         llvm::zip_equal(getResultTypes(), returnOp->getOperandTypes()))
+      if (resultType != returnType) {
+        return returnOp->emitOpError()
+               << "operand types do not match with parent results";
+      }
+  }
+
+  // Make sure that all returned values are ranked tensors.
+  for (Type t : getResultTypes()) {
+    if (!llvm::isa<RankedTensorType>(t)) {
+      return emitOpError() << "only ranked tensor results are allowed";
+    }
+  }
+
+  Region &workgroupCount = getWorkgroupCount();
+  if (workgroupCount.empty()) {
+    return success();
+  }
+
+  // If workgroup count region exists, check it has a single block.
+  return verifyWorkgroupCountRegion(getOperation(), getWorkload(),
+                                    getWorkgroupCount());
+}
+
 ParseResult DispatchRegionOp::parse(OpAsmParser &parser,
                                     OperationState &result) {
   SmallVector<Type> resultTypes;
@@ -348,7 +395,6 @@
     return failure();
   if (parser.parseRegion(*bodyRegion))
     return failure();
-  ensureTerminator(*bodyRegion, parser.getBuilder(), result.location);
 
   if (parseDispatchWorkgroupsCountRegion(parser, *workloadCountRegion)) {
     return failure();
@@ -868,38 +914,6 @@
                 /*printBlockTerminators=*/true);
 }
 
-LogicalResult verifyWorkgroupCountRegion(Operation *op, ValueRange workload,
-                                         Region &region) {
-  // Verify the workload operands match the expected capture args.
-  if (workload.size() != region.getNumArguments()) {
-    return op->emitOpError()
-           << "workload operands and workgroup count args mismatch ("
-           << workload.size() << " vs " << region.getNumArguments() << ")";
-  }
-  for (auto [index, values] :
-       llvm::enumerate(llvm::zip_equal(workload, region.getArguments()))) {
-    auto [workloadValue, capturedArg] = values;
-    if (workloadValue.getType() != capturedArg.getType()) {
-      return op->emitOpError()
-             << "workload value " << index << " type mismatch; operand is "
-             << workloadValue.getType() << " but region captures "
-             << capturedArg.getType();
-    }
-  }
-
-  // Verify the return ops all provide XYZ values.
-  for (auto returnOp : region.getOps<IREE::Flow::ReturnOp>()) {
-    if (returnOp.getNumOperands() != 3 ||
-        !llvm::all_of(returnOp.getOperandTypes(),
-                      [](Type type) { return type.isIndex(); })) {
-      return returnOp.emitOpError() << "workgroup count region must return "
-                                       "the XYZ dimension counts";
-    }
-  }
-
-  return success();
-}
-
 LogicalResult DispatchWorkgroupsOp::verify() {
   Operation *op = getOperation();
 
@@ -1043,7 +1057,7 @@
 
 IREE::Util::ValueAccess
 DispatchWorkgroupsOp::getResultAccess(unsigned resultIndex) {
-  unsigned startIndex = getBody()->getNumArguments() - getNumResults();
+  unsigned startIndex = getWorkgroupBody().getNumArguments() - getNumResults();
   BlockArgument arg =
       getWorkgroupBody().front().getArgument(startIndex + resultIndex);
   if (auto tensorType = llvm::dyn_cast<DispatchTensorType>(arg.getType())) {
diff --git a/compiler/src/iree/compiler/Dialect/Flow/IR/FlowOps.td b/compiler/src/iree/compiler/Dialect/Flow/IR/FlowOps.td
index 8affac2..e2a2335 100644
--- a/compiler/src/iree/compiler/Dialect/Flow/IR/FlowOps.td
+++ b/compiler/src/iree/compiler/Dialect/Flow/IR/FlowOps.td
@@ -35,8 +35,7 @@
 
 def FLOW_DispatchRegionOp : FLOW_PureOp<"dispatch.region", [
     Util_ShapeAwareOp,
-    AttrSizedOperandSegments,
-    SingleBlockImplicitTerminator<"IREE::Flow::ReturnOp">]> {
+    AttrSizedOperandSegments]> {
   let summary = [{a group of ops}];
   let description = [{
     This op is a container/grouping of ops. It represents a fusion group before
@@ -76,7 +75,6 @@
 def FLOW_DispatchWorkgroupsOp : FLOW_PureOp<"dispatch.workgroups", [
   IsolatedFromAbove,
   AttrSizedOperandSegments,
-  SingleBlockImplicitTerminator<"IREE::Flow::ReturnOp">,
   DeclareOpInterfaceMethods<Util_ClosureOpInterface>,
   DeclareOpInterfaceMethods<Util_TiedOpInterface, [
     "getTiedOperandsIndexAndLength",
diff --git a/compiler/src/iree/compiler/Dialect/Flow/Transforms/CaptureDispatchDynamicDims.cpp b/compiler/src/iree/compiler/Dialect/Flow/Transforms/CaptureDispatchDynamicDims.cpp
index 0c99de0..614a391 100644
--- a/compiler/src/iree/compiler/Dialect/Flow/Transforms/CaptureDispatchDynamicDims.cpp
+++ b/compiler/src/iree/compiler/Dialect/Flow/Transforms/CaptureDispatchDynamicDims.cpp
@@ -30,7 +30,11 @@
 // leave the cleanup of redundant work to further optimization passes to keep
 // this simple.
 static void captureDims(IREE::Flow::DispatchWorkgroupsOp dispatchOp) {
-  auto *entryBlock = dispatchOp.getBody();
+  Region &body = dispatchOp.getWorkgroupBody();
+  if (body.empty()) {
+    return;
+  }
+  auto *entryBlock = &body.front();
 
   // Map of SSA values on the outside of the op to arguments on the inside.
   // This lets us avoid capturing duplicate values - they'd be cleaned up
diff --git a/compiler/src/iree/compiler/Dialect/Flow/Transforms/ConvertRegionToWorkgroups.cpp b/compiler/src/iree/compiler/Dialect/Flow/Transforms/ConvertRegionToWorkgroups.cpp
index 601f640..802b6ef 100644
--- a/compiler/src/iree/compiler/Dialect/Flow/Transforms/ConvertRegionToWorkgroups.cpp
+++ b/compiler/src/iree/compiler/Dialect/Flow/Transforms/ConvertRegionToWorkgroups.cpp
@@ -63,13 +63,15 @@
 
   while (!isOutside(value)) {
     auto tiedOpInterface = value.getDefiningOp<IREE::Util::TiedOpInterface>();
-    if (!tiedOpInterface)
+    if (!tiedOpInterface) {
       // Reached an op that does not implement the interface.
       return std::nullopt;
+    }
     value = tiedOpInterface.getTiedResultOperand(value);
-    if (!value)
+    if (!value) {
       // Nothing is tied here.
       return std::nullopt;
+    }
   }
 
   return value;
@@ -84,13 +86,13 @@
 FailureOr<Flow::DispatchWorkgroupsOp>
 rewriteFlowDispatchRegionToFlowDispatchWorkgroups(
     Flow::DispatchRegionOp regionOp, RewriterBase &rewriter) {
-  // Only ops with a single block are supported.
   Region &region = regionOp.getBody();
-  if (!region.hasOneBlock())
-    return failure();
-  Block &body = region.front();
-  auto terminator = cast<Flow::ReturnOp>(body.getTerminator());
-  unsigned numResults = terminator->getNumOperands();
+  // Currently this does not handle empty `flow.dispatch.region` ops.
+  if (region.empty()) {
+    return rewriter.notifyMatchFailure(regionOp,
+                                       "unhandled op with empty region");
+  }
+  unsigned numResults = regionOp->getNumResults();
 
   // Prepare rewriter.
   OpBuilder::InsertionGuard guard(rewriter);
@@ -118,7 +120,14 @@
   DenseSet<Value> tiedArgumentsSet;
   SmallVector<int64_t> tiedArguments(numResults,
                                      IREE::Util::TiedOpInterface::kUntiedIndex);
-  for (const auto &it : llvm::enumerate(terminator->getOperands())) {
+  SmallVector<Flow::ReturnOp> origTerminators;
+  region.walk(
+      [&](Flow::ReturnOp returnOp) { origTerminators.push_back(returnOp); });
+  assert(!origTerminators.empty() && "expected at least one terminator");
+  // Use one of the terminators to get the `tiedArguments` set.
+  // TODO: Check that using all terminators gives you the same result.
+  for (const auto &it :
+       llvm::enumerate(origTerminators.front()->getOperands())) {
     auto tiedArgument =
         findFirstTiedValueOutsideOfRegionOp(regionOp, it.value());
     if (!tiedArgument.has_value())
@@ -166,15 +175,20 @@
   bvm.map(arguments, workgroupsOp.getInputBlockArguments());
 
   // Create DispatchTensorLoadOp for all tensor arguments.
-  assert(workgroupsOp.getWorkgroupBody().hasOneBlock() &&
-         "expected one block after constructor");
-  Block &newBody = workgroupsOp.getWorkgroupBody().getBlocks().front();
-  assert(newBody.empty() && "expected empty block after constructor");
-  rewriter.setInsertionPointToStart(&newBody);
+  Region &newBody = workgroupsOp.getWorkgroupBody();
+  assert(llvm::hasSingleElement(newBody) &&
+         "expected `flow.dispatch.workgroup` op to be created with a single "
+         "block");
+
+  Block *newBodyEntry = &newBody.front();
+  rewriter.setInsertionPointToStart(newBodyEntry);
+  SmallVector<Value> argValues;
   for (const auto &it : llvm::enumerate(arguments)) {
     auto tensorType = llvm::dyn_cast<RankedTensorType>(it.value().getType());
-    if (!tensorType)
+    if (!tensorType) {
+      argValues.push_back(it.value());
       continue;
+    }
     auto inputBbArg = workgroupsOp.getInputBlockArgument(it.index());
     auto dims =
         Util::findVariadicDynamicDims(it.index(), arguments, argumentDims);
@@ -185,10 +199,16 @@
     Value loadedTensor = rewriter.create<IREE::Flow::DispatchTensorLoadOp>(
         loc, tensorType, inputBbArg, bbArgDims);
     bvm.map(it.value(), loadedTensor);
+    argValues.push_back(loadedTensor);
   }
 
   // Move regionOp body into the workgroupsOp.
-  newBody.getOperations().splice(newBody.end(), body.getOperations());
+  rewriter.inlineRegionBefore(region, newBody, newBody.end());
+  // Merge the entry block of `newBody` with the original entry block from the
+  // region.
+  Block *origEntry = &(*(std::next(newBody.begin())));
+  rewriter.mergeBlocks(origEntry, newBodyEntry);
+
   for (Value argument : arguments) {
     argument.replaceUsesWithIf(bvm.lookup(argument), [&](OpOperand &operand) {
       return workgroupsOp->isProperAncestor(operand.getOwner());
@@ -196,33 +216,38 @@
   }
 
   // Update terminator.
-  rewriter.setInsertionPoint(terminator);
-  for (const auto &it : llvm::enumerate(terminator->getOperands())) {
-    auto outputBbArg = workgroupsOp.getOutputBlockArgument(it.index());
-    ValueRange dims;
-    if (tiedArguments[it.index()] ==
-        IREE::Util::TiedOpInterface::kUntiedIndex) {
-      dims = regionOp.getResultDynamicDims(it.index());
-    } else {
-      // This assumes that the number of dynamic dims does not change when
-      // following an SSA use-def chain of tied values.
-      dims = Util::findVariadicDynamicDims(tiedArguments[it.index()], arguments,
-                                           argumentDims);
-    }
+  SmallVector<Flow::ReturnOp> terminators;
+  newBody.walk(
+      [&](Flow::ReturnOp returnOp) { terminators.push_back(returnOp); });
+  for (auto terminator : terminators) {
+    rewriter.setInsertionPoint(terminator);
+    for (const auto &it : llvm::enumerate(terminator->getOperands())) {
+      auto outputBbArg = workgroupsOp.getOutputBlockArgument(it.index());
+      ValueRange dims;
+      if (tiedArguments[it.index()] ==
+          IREE::Util::TiedOpInterface::kUntiedIndex) {
+        dims = regionOp.getResultDynamicDims(it.index());
+      } else {
+        // This assumes that the number of dynamic dims does not change when
+        // following an SSA use-def chain of tied values.
+        dims = Util::findVariadicDynamicDims(tiedArguments[it.index()],
+                                             arguments, argumentDims);
+      }
 #ifndef NDEBUG
-    auto tensorType = it.value().getType().cast<RankedTensorType>();
-    assert(dims.size() == tensorType.getNumDynamicDims() &&
-           "mismatching number of dynamic dims");
+      auto tensorType = it.value().getType().cast<RankedTensorType>();
+      assert(dims.size() == tensorType.getNumDynamicDims() &&
+             "mismatching number of dynamic dims");
 #endif // NDEBUG
-    SmallVector<Value> bbArgDims =
-        llvm::map_to_vector(dims, [&](Value v) { return bvm.lookup(v); });
-    rewriter.create<IREE::Flow::DispatchTensorStoreOp>(loc, it.value(),
-                                                       outputBbArg, bbArgDims);
-  }
+      SmallVector<Value> bbArgDims =
+          llvm::map_to_vector(dims, [&](Value v) { return bvm.lookup(v); });
+      rewriter.create<IREE::Flow::DispatchTensorStoreOp>(
+          loc, it.value(), outputBbArg, bbArgDims);
+    }
 
-  // Delete the old terminator and create a new one.
-  rewriter.create<IREE::Flow::ReturnOp>(loc);
-  rewriter.eraseOp(terminator);
+    // Delete the old terminator and create a new one.
+    rewriter.create<IREE::Flow::ReturnOp>(loc);
+    rewriter.eraseOp(terminator);
+  }
 
   rewriter.replaceOp(regionOp, workgroupsOp.getResults());
   return workgroupsOp;
diff --git a/compiler/src/iree/compiler/Dialect/Flow/Transforms/FormDispatchWorkgroups.cpp b/compiler/src/iree/compiler/Dialect/Flow/Transforms/FormDispatchWorkgroups.cpp
index 5a6396f..d5cd0fa 100644
--- a/compiler/src/iree/compiler/Dialect/Flow/Transforms/FormDispatchWorkgroups.cpp
+++ b/compiler/src/iree/compiler/Dialect/Flow/Transforms/FormDispatchWorkgroups.cpp
@@ -236,7 +236,11 @@
 
     // Annotate the values captures as workload with their position in the
     // workload list.
-    rewriter.setInsertionPointToStart(workgroupsOp.getBody());
+    Region &body = workgroupsOp.getWorkgroupBody();
+    if (body.empty()) {
+      return;
+    }
+    rewriter.setInsertionPointToStart(&body.front());
     int ordinalNumber = 0;
     for (auto [index, operand] : llvm::enumerate(workgroupsOp.getArguments())) {
       if (!llvm::isa<IndexType>(operand.getType()))
diff --git a/compiler/src/iree/compiler/Dialect/Flow/Transforms/OutlineDispatchRegions.cpp b/compiler/src/iree/compiler/Dialect/Flow/Transforms/OutlineDispatchRegions.cpp
index 6088473..f15114e 100644
--- a/compiler/src/iree/compiler/Dialect/Flow/Transforms/OutlineDispatchRegions.cpp
+++ b/compiler/src/iree/compiler/Dialect/Flow/Transforms/OutlineDispatchRegions.cpp
@@ -201,7 +201,7 @@
   Operation *bestOp = NULL;
   const int64_t kMinEstimatedCost = -1;
   int64_t bestEstimatedCost = kMinEstimatedCost;
-  regionOp.getBodyRegion().walk([&](Operation *op) {
+  regionOp.getWorkgroupBody().walk([&](Operation *op) {
     TypeSwitch<Operation *>(op)
         .Case<linalg::LinalgOp>([&](auto op) {
           int64_t estimatedCost = estimateLinalgOpCost(op);
diff --git a/compiler/src/iree/compiler/Dialect/Flow/Transforms/test/form_dispatch_workgroups.mlir b/compiler/src/iree/compiler/Dialect/Flow/Transforms/test/form_dispatch_workgroups.mlir
index 01fe2cc..85f263b 100644
--- a/compiler/src/iree/compiler/Dialect/Flow/Transforms/test/form_dispatch_workgroups.mlir
+++ b/compiler/src/iree/compiler/Dialect/Flow/Transforms/test/form_dispatch_workgroups.mlir
@@ -14,3 +14,32 @@
 //       CHECK:   count(%[[ARG2:[a-zA-Z0-9]+]]: index, %[[ARG3:[a-zA-Z0-9]+]]: index)
 //       CHECK:     %[[C1:.+]] = arith.constant 1 : index
 //       CHECK:     flow.return %[[ARG2]], %[[ARG3]], %[[C1]]
+
+// -----
+
+func.func @simple_test_with_cfg(%arg0: i1) -> (tensor<10x20xf32>) {
+  %cst = arith.constant dense<1.000000e+00> : tensor<10x20xf32>
+  %0 = flow.dispatch.region -> (tensor<10x20xf32>) {
+    %cst_0 = arith.constant dense<1.000000e+00> : tensor<10x20xf32>
+    cf.cond_br %arg0, ^bb1, ^bb2
+  ^bb1:  // pred: ^bb0                                                                                                                                                                                                                                                                                             
+    %2 = tensor.empty() : tensor<10x20xf32>
+    flow.return %2 : tensor<10x20xf32>
+  ^bb2:  // pred: ^bb0                                                                                                                                                                                                                                                                                             
+    flow.return %cst_0 : tensor<10x20xf32>
+  }
+  return %0 : tensor<10x20xf32>
+}
+// CHECK-LABEL: func @simple_test_with_cfg
+//  CHECK-SAME:     %[[ARG0:.+]]: i1
+//       CHECK:   %[[RESULT:.+]] = flow.dispatch.workgroups(%[[ARG0]])
+//  CHECK-NEXT:       %[[ARG1:.+]]: i1, %[[ARG2:.+]]: !flow.dispatch.tensor
+//       CHECK:     %[[CST:.+]] = arith.constant
+//       CHECK:     ^[[BB1:.+]]:
+//       CHECK:       %[[EMPTY:.+]] = tensor.empty()
+//       CHECK:       flow.dispatch.tensor.store %[[EMPTY]], %[[ARG2]]
+//       CHECK:       flow.return
+//       CHECK:     ^[[BB2:.+]]:
+//       CHECK:       flow.dispatch.tensor.store %[[CST]], %[[ARG2]]
+//       CHECK:       flow.return
+//       CHECK:   return %[[RESULT]]
diff --git a/compiler/src/iree/compiler/Preprocessing/Common/BUILD.bazel b/compiler/src/iree/compiler/Preprocessing/Common/BUILD.bazel
index c7a6702..dfd3178 100644
--- a/compiler/src/iree/compiler/Preprocessing/Common/BUILD.bazel
+++ b/compiler/src/iree/compiler/Preprocessing/Common/BUILD.bazel
@@ -31,6 +31,7 @@
     name = "Transforms",
     srcs = [
         "ConvertConv2DToImg2Col.cpp",
+        "MakeSingleDispatchForFunction.cpp",
         "PadLinalgOps.cpp",
         "PassDetail.h",
         "Passes.cpp",
@@ -42,6 +43,7 @@
     ],
     deps = [
         ":PassesIncGen",
+        "//compiler/src/iree/compiler/Dialect/Flow/IR",
         "//compiler/src/iree/compiler/Dialect/Flow/Transforms",
         "@llvm-project//llvm:Support",
         "@llvm-project//mlir:ArithDialect",
diff --git a/compiler/src/iree/compiler/Preprocessing/Common/CMakeLists.txt b/compiler/src/iree/compiler/Preprocessing/Common/CMakeLists.txt
index b0ec7b5..7f59a88 100644
--- a/compiler/src/iree/compiler/Preprocessing/Common/CMakeLists.txt
+++ b/compiler/src/iree/compiler/Preprocessing/Common/CMakeLists.txt
@@ -27,6 +27,7 @@
     "Passes.h.inc"
   SRCS
     "ConvertConv2DToImg2Col.cpp"
+    "MakeSingleDispatchForFunction.cpp"
     "PadLinalgOps.cpp"
     "PassDetail.h"
     "Passes.cpp"
@@ -43,6 +44,7 @@
     MLIRTensorDialect
     MLIRTensorUtils
     MLIRTransforms
+    iree::compiler::Dialect::Flow::IR
     iree::compiler::Dialect::Flow::Transforms
   PUBLIC
 )
diff --git a/compiler/src/iree/compiler/Preprocessing/Common/MakeSingleDispatchForFunction.cpp b/compiler/src/iree/compiler/Preprocessing/Common/MakeSingleDispatchForFunction.cpp
new file mode 100644
index 0000000..b452302
--- /dev/null
+++ b/compiler/src/iree/compiler/Preprocessing/Common/MakeSingleDispatchForFunction.cpp
@@ -0,0 +1,97 @@
+// Copyright 2023 The IREE Authors
+//
+// Licensed under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+
+#include "iree/compiler/Dialect/Flow/IR/FlowOps.h"
+#include "iree/compiler/Preprocessing/Common/PassDetail.h"
+#include "iree/compiler/Preprocessing/Common/Passes.h"
+#include "mlir/Dialect/Arith/IR/Arith.h"
+#include "mlir/IR/PatternMatch.h"
+
+namespace mlir {
+namespace iree_compiler {
+namespace IREE {
+
+namespace {
+
+struct MakeSingleDispatchForFunctionPass
+    : public MakeSingleDispatchForFunctionBase<
+          MakeSingleDispatchForFunctionPass> {
+  void getDependentDialects(DialectRegistry &registry) const override {
+    registry.insert<IREE::Flow::FlowDialect>();
+  }
+
+  void runOnOperation() override;
+};
+} // namespace
+
+void MakeSingleDispatchForFunctionPass::runOnOperation() {
+  auto funcOp = getOperation();
+
+  // Abort if there are any operations that prevent moving all operations
+  // into a single dispatch.
+  auto walkResult = funcOp.walk([](Operation *op) -> WalkResult {
+    return success(!isa<func::CallOp>(op));
+  });
+  if (walkResult.wasInterrupted()) {
+    funcOp->emitOpError("unhandled operation in function body prevents moving "
+                        "body into a single dispatch");
+  }
+
+  // Currently this can only be done for static shapes cause
+  // there is no way of getting the tied dynamic shapes for
+  // a function.
+  auto resultTypes = funcOp.getFunctionType().getResults();
+  if (llvm::any_of(resultTypes, [&](Type t) {
+        auto shapedType = t.dyn_cast<ShapedType>();
+        return shapedType && !shapedType.hasStaticShape();
+      })) {
+    return;
+  }
+
+  IRRewriter rewriter(&getContext());
+  Location loc = funcOp.getLoc();
+  Region &funcBody = funcOp.getBody();
+
+  // Split the function entry block to create a new entry block into which the
+  // new operations will be added.
+  Block &entryBlock = funcBody.front();
+  Block *funcBodyStart = rewriter.splitBlock(&entryBlock, entryBlock.begin());
+
+  // Create an empty `flow.dispatch.region` operation with same result type as
+  // the function.
+  rewriter.setInsertionPointToEnd(&entryBlock);
+  auto dispatchRegionOp = rewriter.create<IREE::Flow::DispatchRegionOp>(
+      loc, resultTypes, /*result_dims=*/ValueRange{},
+      /*workload=*/ValueRange{});
+
+  // Move the body of the function into the region.
+  Region &region = dispatchRegionOp.getBody();
+  region.getBlocks().splice(region.begin(), funcBody.getBlocks(),
+                            Region::iterator(funcBodyStart), funcBody.end());
+
+  // Replace all `func.return` with `flow.return`.
+  SmallVector<func::ReturnOp> returnOps =
+      llvm::to_vector(region.getOps<func::ReturnOp>());
+  for (auto returnOp : returnOps) {
+    OpBuilder::InsertionGuard g(rewriter);
+    rewriter.setInsertionPoint(returnOp);
+    rewriter.replaceOpWithNewOp<IREE::Flow::ReturnOp>(returnOp,
+                                                      returnOp.getOperands());
+  }
+
+  // Return the results of the `flow.dispatch.region`.
+  rewriter.setInsertionPointAfter(dispatchRegionOp);
+  rewriter.create<func::ReturnOp>(loc, dispatchRegionOp.getResults());
+}
+
+std::unique_ptr<OperationPass<func::FuncOp>>
+createMakeSingleDispatchForFunctionPass() {
+  return std::make_unique<MakeSingleDispatchForFunctionPass>();
+}
+
+} // namespace IREE
+} // namespace iree_compiler
+} // namespace mlir
diff --git a/compiler/src/iree/compiler/Preprocessing/Common/Passes.h b/compiler/src/iree/compiler/Preprocessing/Common/Passes.h
index bd5d07b..c9835a9 100644
--- a/compiler/src/iree/compiler/Preprocessing/Common/Passes.h
+++ b/compiler/src/iree/compiler/Preprocessing/Common/Passes.h
@@ -17,11 +17,15 @@
 namespace iree_compiler {
 namespace IREE {
 
-// Creates a pass to convert linalg convolution ops into linalg.matmul ops
-// using im2col tranformation.
+/// Creates a pass to convert linalg convolution ops into linalg.matmul ops
+/// using im2col tranformation.
 std::unique_ptr<Pass> createConvertConv2DToImg2ColPass();
 
-// A pass to pad linalg ops to the next integer multiple of `paddingSize`.
+/// Moves the body of the entire function into a single dispatch.
+std::unique_ptr<OperationPass<func::FuncOp>>
+createMakeSingleDispatchForFunctionPass();
+
+/// A pass to pad linalg ops to the next integer multiple of `paddingSize`.
 std::unique_ptr<Pass> createPadLinalgOpsToIntegerMultiplePass();
 
 /// Pass to merge parallel linalg operations.
diff --git a/compiler/src/iree/compiler/Preprocessing/Common/Passes.td b/compiler/src/iree/compiler/Preprocessing/Common/Passes.td
index e063ef7..9c88781 100644
--- a/compiler/src/iree/compiler/Preprocessing/Common/Passes.td
+++ b/compiler/src/iree/compiler/Preprocessing/Common/Passes.td
@@ -15,6 +15,12 @@
   let constructor = "mlir::iree_compiler::IREE::createConvertConv2DToImg2ColPass()";
 }
 
+def MakeSingleDispatchForFunction :
+    Pass<"iree-preprocessing-make-single-dispatch-for-function", "func::FuncOp"> {
+  let summary = "Convert entire function into a single dispatch";
+  let constructor = "mlir::iree_compiler::IREE::createMakeSingleDispatchForFunctionPass()";
+}
+
 def PadLinalgOps :
     Pass<"iree-preprocessing-pad-linalg-ops", ""> {
   let summary = "Pad linalg ops to the next integer multiple of paddingSize.";
diff --git a/compiler/src/iree/compiler/Preprocessing/Common/test/BUILD.bazel b/compiler/src/iree/compiler/Preprocessing/Common/test/BUILD.bazel
index 9dd501c..37d77d7 100644
--- a/compiler/src/iree/compiler/Preprocessing/Common/test/BUILD.bazel
+++ b/compiler/src/iree/compiler/Preprocessing/Common/test/BUILD.bazel
@@ -17,6 +17,7 @@
     srcs = enforce_glob(
         [
             "conv2d_to_img2col.mlir",
+            "make_single_dispatch_for_function.mlir",
             "pad_linalg_ops.mlir",
             "rematerialize_parallel_ops.mlir",
         ],
diff --git a/compiler/src/iree/compiler/Preprocessing/Common/test/CMakeLists.txt b/compiler/src/iree/compiler/Preprocessing/Common/test/CMakeLists.txt
index a754aea..0cca180 100644
--- a/compiler/src/iree/compiler/Preprocessing/Common/test/CMakeLists.txt
+++ b/compiler/src/iree/compiler/Preprocessing/Common/test/CMakeLists.txt
@@ -15,6 +15,7 @@
     lit
   SRCS
     "conv2d_to_img2col.mlir"
+    "make_single_dispatch_for_function.mlir"
     "pad_linalg_ops.mlir"
     "rematerialize_parallel_ops.mlir"
   TOOLS
diff --git a/compiler/src/iree/compiler/Preprocessing/Common/test/make_single_dispatch_for_function.mlir b/compiler/src/iree/compiler/Preprocessing/Common/test/make_single_dispatch_for_function.mlir
new file mode 100644
index 0000000..8b8a155
--- /dev/null
+++ b/compiler/src/iree/compiler/Preprocessing/Common/test/make_single_dispatch_for_function.mlir
@@ -0,0 +1,35 @@
+// RUN: iree-opt --iree-preprocessing-make-single-dispatch-for-function --split-input-file %s | FileCheck %s
+
+func.func @simple_test() -> tensor<10x20xf32> {
+  %0 = tensor.empty() : tensor<10x20xf32>
+  return %0 : tensor<10x20xf32>
+}
+// CHECK-LABEL: func @simple_test() -> tensor<10x20xf32>
+//  CHECK-NEXT:   %[[DISPATCH:.+]] = flow.dispatch.region
+//       CHECK:     %[[EMPTY:.+]] = tensor.empty
+//       CHECK:     flow.return %[[EMPTY]]
+//       CHECK:   return %[[DISPATCH]]
+
+// -----
+
+func.func @simple_test_with_cfg(%arg0 : i1) -> tensor<10x20xf32> {
+    cf.cond_br %arg0, ^bb1, ^bb2
+  ^bb1:
+    %0 = tensor.empty() : tensor<10x20xf32>
+    return %0 : tensor<10x20xf32>
+  ^bb2:
+    %1 = arith.constant dense<1.0> : tensor<10x20xf32>
+    return %1 : tensor<10x20xf32>
+}
+// CHECK-LABEL: func @simple_test_with_cfg
+//  CHECK-SAME:     %[[ARG0:.+]]: i1
+//  CHECK-NEXT:   %[[DISPATCH:.+]] = flow.dispatch.region -> (tensor<10x20xf32>) {
+//       CHECK:       cf.cond_br %[[ARG0]], ^[[BB1:[a-zA-Z0-9]+]], ^[[BB2:[a-zA-Z0-9]+]]
+//       CHECK:     ^[[BB1]]:
+//       CHECK:       %[[EMPTY:.+]] = tensor.empty
+//       CHECK:       flow.return %[[EMPTY]]
+//       CHECK:     ^[[BB2]]:
+//       CHECK:       %[[CST:.+]] = arith.constant
+//       CHECK:       flow.return %[[CST]]
+//       CHECK:   }
+//       CHECK:   return %[[DISPATCH]]
diff --git a/tests/e2e/regression/BUILD.bazel b/tests/e2e/regression/BUILD.bazel
index bd18c9c..11fd236 100644
--- a/tests/e2e/regression/BUILD.bazel
+++ b/tests/e2e/regression/BUILD.bazel
@@ -38,6 +38,7 @@
     name = "lit",
     srcs = [
         "fill_i64.mlir",
+        "force_single_dispatch.mlir",
         "globals.mlir",
         "libm_linking.mlir",
         "scalar.mlir",
diff --git a/tests/e2e/regression/CMakeLists.txt b/tests/e2e/regression/CMakeLists.txt
index 7f506a2..cb6b2d1 100644
--- a/tests/e2e/regression/CMakeLists.txt
+++ b/tests/e2e/regression/CMakeLists.txt
@@ -15,6 +15,7 @@
     lit
   SRCS
     "fill_i64.mlir"
+    "force_single_dispatch.mlir"
     "globals.mlir"
     "libm_linking.mlir"
     "scalar.mlir"
diff --git a/tests/e2e/regression/force_single_dispatch.mlir b/tests/e2e/regression/force_single_dispatch.mlir
new file mode 100644
index 0000000..8876d95
--- /dev/null
+++ b/tests/e2e/regression/force_single_dispatch.mlir
@@ -0,0 +1,14 @@
+// RUN: iree-opt --iree-preprocessing-make-single-dispatch-for-function %s | iree-run-mlir --Xcompiler,iree-hal-target-backends=llvm-cpu --input="1" -
+func.func @simple_test_with_cfg(%arg0 : i8) -> tensor<2x4xf32> {
+    %c0_i8 = arith.constant 0 : i8
+    %cond = arith.cmpi eq, %arg0, %c0_i8 : i8
+    cf.cond_br %cond, ^bb1, ^bb2
+  ^bb1:
+    %0 = tensor.empty() : tensor<2x4xf32>
+    return %0 : tensor<2x4xf32>
+  ^bb2:
+    %1 = arith.constant dense<1.0> : tensor<2x4xf32>
+    return %1 : tensor<2x4xf32>
+}
+// CHECK-LABEL: EXEC @simple_test_with_cfg
+//       CHECK: 2x4xf32=[1 1 1 1][1 1 1 1]