[Flow] Add pass to bubble and hoist encoding ops out of dispatch regions (#18063)

This PR adds a new pass to the Flow data tiling pipeline that hoists
encoding ops out of their dispatch regions. After SetEncoding, the
encoding ops are inserted directly inside the dispatch regions that
contain the data-tiled ops. The set_encoding ops then need to be
hoisted out of those dispatch regions so they can be fused into their
producer dispatches.
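
To illustrate the hoisting itself, here is a simplified before/after
sketch (encoding attributes and generic-op bodies are elided; the
`hoist_below` lit test added in this PR shows the full IR):

```mlir
// Before: set_encoding lives inside the dispatch region with its producer.
%0 = flow.dispatch.region -> (tensor<2x11008x128xf32, #encoding>) {
  %1 = linalg.generic ... -> tensor<2x11008x128xf32>
  %2 = iree_encoding.set_encoding %1
      : tensor<2x11008x128xf32> -> tensor<2x11008x128xf32, #encoding>
  flow.return %2 : tensor<2x11008x128xf32, #encoding>
}

// After: the dispatch yields the unencoded value, and set_encoding is
// hoisted out below the dispatch region.
%0 = flow.dispatch.region -> (tensor<2x11008x128xf32>) {
  %1 = linalg.generic ... -> tensor<2x11008x128xf32>
  flow.return %1 : tensor<2x11008x128xf32>
}
%2 = iree_encoding.set_encoding %0
    : tensor<2x11008x128xf32> -> tensor<2x11008x128xf32, #encoding>
```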

Sometimes producer operations are fused into the same dispatch as the
data-tiled op, in which case the set_encoding ops have producers inside
the dispatch. To hoist such a set_encoding op, it must first be bubbled
up through these producer operations until it has no producers left
inside its dispatch. This pass supports bubbling set_encoding ops
through bit-extending ops and broadcasting ops.
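
As a rough sketch of the bubbling step (based on the
`bubble_through_broadcast` lit test below; encoding attributes and
generic-op bodies are elided, and the `#encoding_with_bcast_map` name is
hypothetical), the encoding moves onto the broadcast's input, carrying a
`bcast_map` composed from the input's indexing map, and the broadcast
then operates on encoded tensors:

```mlir
// Before: the broadcast produces a plain tensor, then set_encoding encodes it.
%bcast = linalg.generic
    {indexing_maps = [affine_map<(d0, d1, d2) -> (d1, d2)>,
                      affine_map<(d0, d1, d2) -> (d0, d1, d2)>],
     iterator_types = ["parallel", "parallel", "parallel"]}
    ins(%src : tensor<11008x128xf32>)
    outs(%init : tensor<2x11008x128xf32>) { ... }
%enc = iree_encoding.set_encoding %bcast
    : tensor<2x11008x128xf32> -> tensor<2x11008x128xf32, #encoding>

// After: the input is encoded first, with bcast_map = (d0, d1, d2) -> (d1, d2),
// and the broadcast yields the encoded result directly.
%src_enc = iree_encoding.set_encoding %src
    : tensor<11008x128xf32> -> tensor<11008x128xf32, #encoding_with_bcast_map>
%init_enc = tensor.empty() : tensor<2x11008x128xf32, #encoding>
%enc = linalg.generic { ... }
    ins(%src_enc : tensor<11008x128xf32, #encoding_with_bcast_map>)
    outs(%init_enc : tensor<2x11008x128xf32, #encoding>) { ... }
    -> tensor<2x11008x128xf32, #encoding>
```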

After this pass, all set_encoding ops should be outside of dispatch
regions, but they still need to be fused with their producers. A
follow-up pass in the next PR will fuse set_encoding ops into their
producer dispatch regions or wrap them in new dispatch regions.

---------

Signed-off-by: Max Dawkins <max.dawkins@gmail.com>
diff --git a/compiler/src/iree/compiler/Dialect/Encoding/IR/EncodingBase.td b/compiler/src/iree/compiler/Dialect/Encoding/IR/EncodingBase.td
index dc89604..befda32 100644
--- a/compiler/src/iree/compiler/Dialect/Encoding/IR/EncodingBase.td
+++ b/compiler/src/iree/compiler/Dialect/Encoding/IR/EncodingBase.td
@@ -113,6 +113,9 @@
 
     /// Returns an integer array with values in `round_dims_to`.
     ArrayRef<int64_t> getRoundDimsToArray();
+
+    /// Clones this encoding with a new bcast_map.
+    EncodingAttr clone(AffineMap bcastMap);
   }];
 
   let genVerifyDecl = 0;
diff --git a/compiler/src/iree/compiler/Dialect/Encoding/IR/EncodingOps.cpp b/compiler/src/iree/compiler/Dialect/Encoding/IR/EncodingOps.cpp
index 01bcd7e..0c3ef6d 100644
--- a/compiler/src/iree/compiler/Dialect/Encoding/IR/EncodingOps.cpp
+++ b/compiler/src/iree/compiler/Dialect/Encoding/IR/EncodingOps.cpp
@@ -153,6 +153,13 @@
   return llvm::cast<DenseI64ArrayAttr>(roundDimsTo).asArrayRef();
 }
 
+EncodingAttr EncodingAttr::clone(AffineMap bcastMap) {
+  return get(bcastMap.getContext(), getOperandIndex(), getOpType(),
+             getElementTypes(), getOriginalType(), getMatmulNarrow_M(),
+             getMatmulNarrow_N(), getUserIndexingMaps(),
+             AffineMapAttr::get(bcastMap), getRoundDimsTo());
+}
+
 //===---------------------------------------------------------------------===//
 // Encoding Dialect Helpers
 //===---------------------------------------------------------------------===//
diff --git a/compiler/src/iree/compiler/Dialect/Flow/Transforms/BUILD.bazel b/compiler/src/iree/compiler/Dialect/Flow/Transforms/BUILD.bazel
index fe5eaab..b790a99 100644
--- a/compiler/src/iree/compiler/Dialect/Flow/Transforms/BUILD.bazel
+++ b/compiler/src/iree/compiler/Dialect/Flow/Transforms/BUILD.bazel
@@ -55,6 +55,7 @@
         "FuseMultiUseElementwiseProducer.cpp",
         "FusionPreprocessing.cpp",
         "FusionUtils.cpp",
+        "HoistEncodingOps.cpp",
         "InitializeEmptyTensors.cpp",
         "InjectDispatchTracing.cpp",
         "InjectTensorTracing.cpp",
diff --git a/compiler/src/iree/compiler/Dialect/Flow/Transforms/CMakeLists.txt b/compiler/src/iree/compiler/Dialect/Flow/Transforms/CMakeLists.txt
index 701eefe..7bbb5d5 100644
--- a/compiler/src/iree/compiler/Dialect/Flow/Transforms/CMakeLists.txt
+++ b/compiler/src/iree/compiler/Dialect/Flow/Transforms/CMakeLists.txt
@@ -55,6 +55,7 @@
     "FuseMultiUseElementwiseProducer.cpp"
     "FusionPreprocessing.cpp"
     "FusionUtils.cpp"
+    "HoistEncodingOps.cpp"
     "InitializeEmptyTensors.cpp"
     "InjectDispatchTracing.cpp"
     "InjectTensorTracing.cpp"
diff --git a/compiler/src/iree/compiler/Dialect/Flow/Transforms/HoistEncodingOps.cpp b/compiler/src/iree/compiler/Dialect/Flow/Transforms/HoistEncodingOps.cpp
new file mode 100644
index 0000000..0068744
--- /dev/null
+++ b/compiler/src/iree/compiler/Dialect/Flow/Transforms/HoistEncodingOps.cpp
@@ -0,0 +1,217 @@
+// Copyright 2024 The IREE Authors
+//
+// Licensed under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+
+#include "iree/compiler/Dialect/Encoding/IR/EncodingDialect.h"
+#include "iree/compiler/Dialect/Encoding/IR/EncodingOps.h"
+#include "iree/compiler/Dialect/Flow/IR/FlowDialect.h"
+#include "iree/compiler/Dialect/Flow/IR/FlowOps.h"
+#include "iree/compiler/Dialect/Flow/Transforms/Passes.h"
+#include "iree/compiler/Dialect/Flow/Transforms/RegionOpUtils.h"
+#include "iree/compiler/Dialect/LinalgExt/Utils/Utils.h"
+#include "llvm/ADT/STLExtras.h"
+#include "llvm/Support/Debug.h"
+#include "mlir/Dialect/Linalg/IR/Linalg.h"
+#include "mlir/Dialect/Linalg/Utils/Utils.h"
+#include "mlir/Dialect/MemRef/Transforms/Transforms.h"
+#include "mlir/Dialect/Tensor/IR/Tensor.h"
+#include "mlir/Dialect/Utils/StructuredOpsUtils.h"
+#include "mlir/IR/Builders.h"
+#include "mlir/IR/BuiltinAttributes.h"
+#include "mlir/IR/BuiltinTypes.h"
+#include "mlir/IR/OpDefinition.h"
+#include "mlir/IR/PatternMatch.h"
+#include "mlir/IR/Value.h"
+#include "mlir/Interfaces/FunctionInterfaces.h"
+#include "mlir/Support/LLVM.h"
+#include "mlir/Support/LogicalResult.h"
+#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
+
+#define DEBUG_TYPE "iree-flow-hoist-encoding-ops"
+
+namespace mlir::iree_compiler::IREE::Flow {
+#define GEN_PASS_DEF_HOISTENCODINGOPSPASS
+#include "iree/compiler/Dialect/Flow/Transforms/Passes.h.inc"
+
+static AffineMap getBcastMapOrIdentity(RewriterBase &rewriter,
+                                       RankedTensorType encodedType) {
+  auto encoding = cast<IREE::Encoding::EncodingAttr>(encodedType.getEncoding());
+  AffineMapAttr bcastMapAttr = encoding.getBcastMap();
+  return bcastMapAttr ? bcastMapAttr.getAffineMap()
+                      : rewriter.getMultiDimIdentityMap(encodedType.getRank());
+}
+
+/// Bubbles a SetEncodingOp up through a linalg::GenericOp. The `genericOp`
+/// must:
+///  1. Have a single result.
+///  2. Have a single use.
+///  3. Have all parallel iterators.
+///  4. Have an identity output indexing map.
+///  5. Have a tensor.empty init operand.
+///  6. Have as many indexing map dims as there are results in the encoding's
+///     bcast_map.
+///
+/// This function creates SetEncoding ops on all of the inputs to the
+/// `genericOp`, and replaces the op with an encoded version. If any of
+/// the above conditions are false, then it returns failure.
+///
+/// Note: The bcast_map on the set_encoding op must be identity or absent.
+///       The implementation should work for cases where it is not, but it is
+///       unexpected in IREE compilation to find such cases, and it will not
+///       be well tested.
+static LogicalResult
+bubbleUpSetEncodingThroughGenericOp(RewriterBase &rewriter,
+                                    Encoding::SetEncodingOp encodingOp,
+                                    linalg::GenericOp genericOp) {
+  if (!genericOp->hasOneUse()) {
+    return rewriter.notifyMatchFailure(genericOp,
+                                       "genericOp must have one use");
+  }
+  if (genericOp.getNumDpsInits() != 1) {
+    return rewriter.notifyMatchFailure(genericOp,
+                                       "genericOp must have a single init");
+  }
+  if (genericOp.getNumReductionLoops() != 0) {
+    return rewriter.notifyMatchFailure(
+        genericOp, "genericOp must have all parallel loops");
+  }
+  if (!genericOp.getDpsInitOperand(0)->get().getDefiningOp<tensor::EmptyOp>()) {
+    return rewriter.notifyMatchFailure(genericOp,
+                                       "init operand must be tensor.empty");
+  }
+  AffineMap outputMap =
+      genericOp.getMatchingIndexingMap(genericOp.getDpsInitOperand(0));
+  if (!outputMap.isIdentity()) {
+    return rewriter.notifyMatchFailure(genericOp, "output map not identity");
+  }
+
+  RankedTensorType encodedType = encodingOp.getResultType();
+  AffineMap bcastMap = getBcastMapOrIdentity(rewriter, encodedType);
+  if (!bcastMap.isIdentity()) {
+    return rewriter.notifyMatchFailure(genericOp, "bcast_map map not identity");
+  }
+  if (outputMap.getNumDims() != bcastMap.getNumResults()) {
+    return rewriter.notifyMatchFailure(
+        genericOp, "output map numDims do not match bcast_map numResults");
+  }
+
+  // Set encodings on each input
+  Location loc = genericOp->getLoc();
+  SmallVector<Value> encodedOperands;
+  auto encoding = cast<IREE::Encoding::EncodingAttr>(encodedType.getEncoding());
+  for (OpOperand *operand : genericOp.getDpsInputOperands()) {
+    // Compute the new bcastMap from the operand's indexing map.
+    AffineMap operandMap = genericOp.getMatchingIndexingMap(operand);
+    AffineMap newBcastMap = operandMap.compose(bcastMap);
+
+    // Create new encoding and set encoding on the operand.
+    auto newEncoding = encoding.clone(newBcastMap);
+    auto operandType = cast<RankedTensorType>(operand->get().getType());
+    auto resType = RankedTensorType::get(
+        operandType.getShape(), operandType.getElementType(), newEncoding);
+    Value encodedInput =
+        rewriter.create<Encoding::SetEncodingOp>(loc, resType, operand->get());
+    encodedOperands.push_back(encodedInput);
+  }
+
+  // Create encoded generic op.
+  SmallVector<OpFoldResult> mixedSizes =
+      tensor::getMixedSizes(rewriter, loc, encodingOp.getSource());
+  Value encodedInit = rewriter.create<tensor::EmptyOp>(
+      loc, mixedSizes, encodedType.getElementType(), encoding);
+  encodedOperands.push_back(encodedInit);
+  auto encodedGenericOp =
+      clone(rewriter, genericOp, encodingOp.getResultType(), encodedOperands);
+
+  rewriter.replaceOp(encodingOp, encodedGenericOp);
+  return success();
+}
+
+static LogicalResult bubbleUpSetEncoding(RewriterBase &rewriter,
+                                         OpOperand &operand) {
+  auto setEncoding = cast<Encoding::SetEncodingOp>(operand.getOwner());
+  auto producer = operand.get().getDefiningOp<linalg::GenericOp>();
+  if (!producer) {
+    return failure();
+  }
+  // Only bubble through dequantization ops and broadcasting ops for now.
+  if (!LinalgExt::isBitExtendOp(producer) &&
+      !LinalgExt::isBroadcastingOp(producer)) {
+    return failure();
+  }
+  return bubbleUpSetEncodingThroughGenericOp(rewriter, setEncoding, producer);
+}
+
+namespace {
+/// Pass declaration.
+struct HoistEncodingOpsPass
+    : public IREE::Flow::impl::HoistEncodingOpsPassBase<HoistEncodingOpsPass> {
+  using IREE::Flow::impl::HoistEncodingOpsPassBase<
+      HoistEncodingOpsPass>::HoistEncodingOpsPassBase;
+  void runOnOperation() override;
+};
+
+/// Pattern to bubble SetEncoding ops upwards through producers. This pattern
+/// runs until bubbling is not possible, or until the SetEncoding op is outside
+/// of a dispatch.
+struct BubbleUpSetEncodingOp
+    : public OpRewritePattern<Encoding::SetEncodingOp> {
+  using OpRewritePattern<Encoding::SetEncodingOp>::OpRewritePattern;
+
+  LogicalResult matchAndRewrite(Encoding::SetEncodingOp encodingOp,
+                                PatternRewriter &rewriter) const override {
+    if (isNonNullAndOutsideDispatch(encodingOp)) {
+      return failure();
+    }
+    // Fail if the encodingOp is not in the same dispatch as its producer.
+    Operation *producer = encodingOp.getSource().getDefiningOp();
+    if (!producer) {
+      return failure();
+    }
+    auto dispatch = producer->getParentOfType<DispatchRegionOp>();
+    if (!dispatch ||
+        dispatch != encodingOp->getParentOfType<DispatchRegionOp>()) {
+      return failure();
+    }
+
+    return bubbleUpSetEncoding(rewriter, encodingOp->getOpOperand(0));
+  }
+};
+
+} // namespace
+
+/// Bubbles SetEncoding ops up through producers and hoists them out of
+/// dispatch regions.
+void HoistEncodingOpsPass::runOnOperation() {
+  MLIRContext *ctx = &getContext();
+  auto funcOp = getOperation();
+
+  RewritePatternSet bubblingPatterns(ctx);
+  bubblingPatterns.insert<BubbleUpSetEncodingOp>(ctx);
+  if (failed(
+          applyPatternsAndFoldGreedily(funcOp, std::move(bubblingPatterns)))) {
+    return signalPassFailure();
+  }
+
+  SmallVector<Encoding::SetEncodingOp> candidates;
+  funcOp->walk([&](Encoding::SetEncodingOp setEncodingOp) {
+    if (setEncodingOp->getParentOfType<DispatchRegionOp>()) {
+      candidates.push_back(setEncodingOp);
+    }
+  });
+  IRRewriter rewriter(ctx);
+  for (auto setEncodingOp : candidates) {
+    if (failed(hoistOutOfDispatch(rewriter, setEncodingOp))) {
+      return signalPassFailure();
+    }
+  }
+
+  RewritePatternSet cleanPatterns(ctx);
+  memref::populateResolveRankedShapedTypeResultDimsPatterns(cleanPatterns);
+  DispatchRegionOp::getCanonicalizationPatterns(cleanPatterns, ctx);
+  if (failed(applyPatternsAndFoldGreedily(funcOp, std::move(cleanPatterns)))) {
+    return signalPassFailure();
+  }
+}
+} // namespace mlir::iree_compiler::IREE::Flow
diff --git a/compiler/src/iree/compiler/Dialect/Flow/Transforms/Passes.cpp b/compiler/src/iree/compiler/Dialect/Flow/Transforms/Passes.cpp
index 4c522a4..5ce1211 100644
--- a/compiler/src/iree/compiler/Dialect/Flow/Transforms/Passes.cpp
+++ b/compiler/src/iree/compiler/Dialect/Flow/Transforms/Passes.cpp
@@ -279,12 +279,24 @@
       // afterwards that would need the full dispatch content but don't want to
       // handle explicit captures as materialized as dispatch workgroup operands
       // and block arguments.
-      .addPass(IREE::Flow::createCloneProducersIntoDispatchRegionsPass)
-      .addPredicatedPass(clEnableDataTiling,
-                         [&]() {
-                           return createSetEncodingPass(
-                               SetEncodingPassOptions{clPadFactor});
-                         })
+      .addPass(IREE::Flow::createCloneProducersIntoDispatchRegionsPass);
+  // Experimental data tiling path. The intent of this path is to set encodings
+  // after fusion decisions have already been made, so encodings can be
+  // separated from compiler fusion decisions.
+  if (clEnableDataTiling) {
+    SetEncodingPassOptions options{clPadFactor};
+    FunctionLikeNest(passManager)
+        // Set encodings on all eligible ops. All ops should be in compiler
+        // formed dispatch regions, so encodings will be placed inside of the
+        // dispatch regions with the data-tiled op.
+        .addPass([&]() { return createSetEncodingPass(options); })
+        // SetEncodingOps should not be in the same dispatch as the data-tiled
+        // op, so hoist them out of their current dispatch regions. Also, bubble
+        // SetEncodingOps through special operations like bit-extending ops and
+        // broadcasting ops.
+        .addPass(IREE::Flow::createHoistEncodingOpsPass);
+  }
+  FunctionLikeNest(passManager)
       // Collapse dimensions of linalg Ops.
       .addPass(IREE::Flow::createCollapseDimensionsPass)
       // Convert dispatch regions into dispatch workgroups by capturing values
diff --git a/compiler/src/iree/compiler/Dialect/Flow/Transforms/Passes.td b/compiler/src/iree/compiler/Dialect/Flow/Transforms/Passes.td
index 8fe891b..fabde78 100644
--- a/compiler/src/iree/compiler/Dialect/Flow/Transforms/Passes.td
+++ b/compiler/src/iree/compiler/Dialect/Flow/Transforms/Passes.td
@@ -437,6 +437,16 @@
 }
 
 
+def HoistEncodingOpsPass :
+    InterfacePass<"iree-flow-hoist-encoding-ops", "mlir::FunctionOpInterface"> {
+  let summary = "Hoists tensor encoding ops out of flow dispatch regions.";
+  let dependentDialects = [
+    "mlir::linalg::LinalgDialect",
+    "IREE::Flow::FlowDialect",
+    "IREE::Encoding::IREEEncodingDialect",
+  ];
+}
+
 def InitializeEmptyTensorsPass :
     Pass<"iree-flow-initialize-empty-tensors", ""> {
   let summary = "Initialize empty tensors.";
diff --git a/compiler/src/iree/compiler/Dialect/Flow/Transforms/RegionOpUtils.cpp b/compiler/src/iree/compiler/Dialect/Flow/Transforms/RegionOpUtils.cpp
index 4b1d017..4b0e5fe 100644
--- a/compiler/src/iree/compiler/Dialect/Flow/Transforms/RegionOpUtils.cpp
+++ b/compiler/src/iree/compiler/Dialect/Flow/Transforms/RegionOpUtils.cpp
@@ -11,6 +11,7 @@
 #include "iree/compiler/Dialect/Flow/Transforms/FormDispatchRegions.h"
 #include "iree/compiler/Dialect/LinalgExt/Utils/Utils.h"
 #include "iree/compiler/Dialect/Util/IR/UtilTypes.h"
+#include "llvm/ADT/STLExtras.h"
 #include "llvm/ADT/SetVector.h"
 #include "llvm/Support/CommandLine.h"
 #include "mlir/Analysis/SliceAnalysis.h"
@@ -528,6 +529,182 @@
   return newRegionOp;
 }
 
+FailureOr<Operation *> hoistOutOfDispatch(RewriterBase &rewriter,
+                                          Operation *op) {
+  assert(op && !isNonNullAndOutsideDispatch(op) &&
+         "op expected to be in a dispatch");
+
+  // Step 1: Clone the op outside of the dispatch region.
+
+  OpBuilder::InsertionGuard g(rewriter);
+  auto dispatchRegionOp = op->getParentOfType<DispatchRegionOp>();
+
+  // If all operands of the `op` come from outside the dispatch, then the op can
+  // be hoisted out before the dispatch region. Otherwise, the op can be hoisted
+  // out below the dispatch if the only users of the op are the dispatch return.
+  if (llvm::none_of(op->getOperands(), [&](Value operand) {
+        Operation *producer = operand.getDefiningOp();
+        return producer && producer->getParentOfType<DispatchRegionOp>();
+      })) {
+    rewriter.setInsertionPoint(dispatchRegionOp);
+  } else if (llvm::all_of(op->getUsers(), [&](Operation *user) {
+               return isa<IREE::Flow::ReturnOp>(user);
+             })) {
+    rewriter.setInsertionPointAfter(dispatchRegionOp);
+  } else {
+    return rewriter.notifyMatchFailure(
+        op, "op has both operands and users insided of its dispatch");
+  }
+  Operation *hoistedOp = rewriter.clone(*op);
+
+  // Step 2: Replace op uses inside and outside of the dispatch region with the
+  //         hoisted results.
+
+  auto getMatchingDispatchResult =
+      [&](Value result) -> std::optional<OpResult> {
+    for (OpOperand &use : result.getUses()) {
+      if (isa<IREE::Flow::ReturnOp>(use.getOwner())) {
+        return dispatchRegionOp.getResults()[use.getOperandNumber()];
+      }
+    }
+    return std::nullopt;
+  };
+  bool yieldsResults = false;
+  for (OpResult result : op->getResults()) {
+    Value hoistedResult = hoistedOp->getResult(result.getResultNumber());
+    // Replace all results yielded by the dispatch region with the hoisted
+    // op results.
+    std::optional<OpResult> dispResult = getMatchingDispatchResult(result);
+    if (dispResult.has_value()) {
+      yieldsResults = true;
+      rewriter.replaceAllUsesWith(dispResult.value(), hoistedResult);
+    }
+    // Replace uses inside the dispatch region.
+    rewriter.replaceAllUsesWith(result, hoistedResult);
+  }
+  // If no results were yielded from `op`, then nothing more to do.
+  if (!yieldsResults) {
+    return hoistedOp;
+  }
+
+  // Step 3: Collect the new set of dispatch results and dynamic dims, and
+  //         create a new dispatch region to replace the old one. The new
+  //         dispatch may have duplicated results.
+
+  // Get the new dispatch region return values and dynamic dims, excluding the
+  // ones coming from the `hoistedOp`.
+  auto dispatchReturnOp = cast<IREE::Flow::ReturnOp>(
+      dispatchRegionOp.getBody().front().getTerminator());
+  SmallVector<Value, 2> newDispatchReturnOperands;
+  SmallVector<Value, 4> newDispatchResultDynamicDims;
+  // Keep track of which results in the original dispatch region correspond to
+  // which results in the new dispatch region with `oldDispatchResultInds`.
+  SmallVector<int64_t, 2> oldDispatchResultInds;
+  for (OpOperand &operand : dispatchReturnOp->getOpOperands()) {
+    if (operand.get().getDefiningOp() == hoistedOp) {
+      continue;
+    }
+    oldDispatchResultInds.push_back(operand.getOperandNumber());
+    newDispatchReturnOperands.push_back(operand.get());
+    auto dims =
+        dispatchRegionOp.getResultDynamicDims(operand.getOperandNumber());
+    newDispatchResultDynamicDims.append(dims.begin(), dims.end());
+  }
+
+  // Add the operands of the `op` to the new return values of the dispatch, and
+  // add their result dynamic dims to the new result dynamic dims.
+  // Save the result index in the new dispatch corresponding to each hoisted op
+  // operand in `resultIndsForHoistedOperands`, so uses can be replaced later.
+  SmallVector<int64_t> resultIndsForHoistedOperands;
+  for (OpOperand &operand : op->getOpOperands()) {
+    // Only need to yield operands defined in the dispatch region.
+    if (operand.get().getParentRegion() != &dispatchRegionOp.getBody()) {
+      continue;
+    }
+
+    // If the operand is already yielded by the dispatch, don't yield it again,
+    // and save the result index.
+    bool resultAlreadyYielded = false;
+    for (auto [idx, returnOperand] :
+         llvm::enumerate(newDispatchReturnOperands)) {
+      if (returnOperand == operand.get()) {
+        resultAlreadyYielded = true;
+        resultIndsForHoistedOperands.push_back(idx);
+        break;
+      }
+    }
+    if (resultAlreadyYielded) {
+      break;
+    }
+    resultIndsForHoistedOperands.push_back(newDispatchReturnOperands.size());
+    newDispatchReturnOperands.push_back(operand.get());
+
+    // Save operand and dynamic dims to add to the dispatch region.
+    SmallVector<Value> dims;
+    if (failed(reifyDynamicResultDims(rewriter, operand.get(), dims))) {
+      return op->emitOpError(
+          "failed to reify dynamic dims of result to be yielded from "
+          "dispatch region");
+    }
+    newDispatchResultDynamicDims.append(dims.begin(), dims.end());
+  }
+
+  // Create the new dispatch region op. `newDispatchReturnOperands` now has all
+  // the original return operands, excluding the hoisted op's results, and
+  // including any new results coming from the hoisted op's old operands. The
+  // `newDispatchResultDynamicDims` contains the corresponding result dynamic
+  // dims for `newDispatchReturnOperands`.
+  SmallVector<Type> newResultTypes =
+      llvm::map_to_vector(newDispatchReturnOperands,
+                          [](Value operand) { return operand.getType(); });
+  rewriter.setInsertionPoint(dispatchRegionOp);
+  auto newDispatchRegionOp = rewriter.create<IREE::Flow::DispatchRegionOp>(
+      dispatchRegionOp->getLoc(), newResultTypes, newDispatchResultDynamicDims,
+      dispatchRegionOp.getWorkload());
+  rewriter.inlineRegionBefore(dispatchRegionOp.getBody(),
+                              newDispatchRegionOp.getBody(),
+                              newDispatchRegionOp.getBody().begin());
+  // Move the workgroup count region over.
+  if (!dispatchRegionOp.getWorkgroupCount().empty()) {
+    Region &newWorkgroupCountRegion = newDispatchRegionOp.getWorkgroupCount();
+    rewriter.inlineRegionBefore(dispatchRegionOp.getWorkgroupCount(),
+                                newWorkgroupCountRegion,
+                                newWorkgroupCountRegion.begin());
+  }
+  // Need to make a new flow.return op, since the body was copied from the
+  // old dispatch region.
+  auto newDispatchReturnOp = cast<IREE::Flow::ReturnOp>(
+      newDispatchRegionOp.getBody().front().getTerminator());
+  rewriter.setInsertionPoint(newDispatchReturnOp);
+  rewriter.replaceOpWithNewOp<IREE::Flow::ReturnOp>(newDispatchReturnOp,
+                                                    newDispatchReturnOperands);
+
+  // Replace operands of the `hoistedOp` with dispatch region results. They are
+  // currently using values from inside the dispatch region.
+  for (auto [idx, operand] : llvm::enumerate(hoistedOp->getOperands())) {
+    auto newResultIdx = resultIndsForHoistedOperands[idx];
+    Value newDispatchResult = newDispatchRegionOp->getResults()[newResultIdx];
+    rewriter.replaceUsesWithIf(operand, newDispatchResult,
+                               [&](OpOperand &opOperand) {
+                                 return opOperand.getOwner() == hoistedOp;
+                               });
+  }
+
+  // Step 4: Fix up remaining uses. The hoisted op's operands were rewired
+  //         above; now replace the remaining uses of the old dispatch region
+  //         results with the new dispatch region results.
+
+  // Replace the uses of the original dispatch region results with the final
+  // dispatch region results.
+  for (auto [oldIdx, newIdx] : llvm::enumerate(oldDispatchResultInds)) {
+    Value newDispatchResult = newDispatchRegionOp->getResults()[newIdx];
+    Value dispatchResult = dispatchRegionOp->getResults()[oldIdx];
+    rewriter.replaceAllUsesWith(dispatchResult, newDispatchResult);
+  }
+
+  return hoistedOp;
+}
+
 //===---------------------------------------------------------------------===//
 // Utilities to make a dispatch region isolated from above
 //===---------------------------------------------------------------------===//
diff --git a/compiler/src/iree/compiler/Dialect/Flow/Transforms/RegionOpUtils.h b/compiler/src/iree/compiler/Dialect/Flow/Transforms/RegionOpUtils.h
index 45a375f..1451cc3 100644
--- a/compiler/src/iree/compiler/Dialect/Flow/Transforms/RegionOpUtils.h
+++ b/compiler/src/iree/compiler/Dialect/Flow/Transforms/RegionOpUtils.h
@@ -104,6 +104,18 @@
 /// into a dispatch region.
 bool isClonableIntoDispatchOp(Operation *op);
 
+/// Hoists an operation out of a dispatch region, as long as either it has no
+/// producers inside of the dispatch region, or all of its uses are part of
+/// the dispatch region op's return. If neither criterion is met, returns
+/// failure.
+///
+/// If all producers are defined outside of the dispatch region, then the op
+/// will be hoisted above the dispatch region op. Otherwise, the op will be
+/// hoisted below the dispatch region op, and the operands of the hoisted op
+/// will be added to the yielded values of the dispatch region op.
+FailureOr<Operation *> hoistOutOfDispatch(RewriterBase &rewriter,
+                                          Operation *op);
+
 /// Collect all ops that should be cloned into the given dispatch region op.
 SmallVector<Operation *> getCloneableOps(Flow::DispatchRegionOp regionOp);
 
diff --git a/compiler/src/iree/compiler/Dialect/Flow/Transforms/test/BUILD.bazel b/compiler/src/iree/compiler/Dialect/Flow/Transforms/test/BUILD.bazel
index ee4f936..1c16925 100644
--- a/compiler/src/iree/compiler/Dialect/Flow/Transforms/test/BUILD.bazel
+++ b/compiler/src/iree/compiler/Dialect/Flow/Transforms/test/BUILD.bazel
@@ -42,6 +42,7 @@
             "fuse_horizontal_contractions.mlir",
             "fuse_multiuse_elementwise_producer.mlir",
             "fusion_preprocessing.mlir",
+            "hoist_encoding_ops.mlir",
             "initialize_empty_tensors.mlir",
             "inject_dispatch_tracing.mlir",
             "inject_tensor_tracing.mlir",
diff --git a/compiler/src/iree/compiler/Dialect/Flow/Transforms/test/CMakeLists.txt b/compiler/src/iree/compiler/Dialect/Flow/Transforms/test/CMakeLists.txt
index a88bb4f..203a4ad 100644
--- a/compiler/src/iree/compiler/Dialect/Flow/Transforms/test/CMakeLists.txt
+++ b/compiler/src/iree/compiler/Dialect/Flow/Transforms/test/CMakeLists.txt
@@ -40,6 +40,7 @@
     "fuse_horizontal_contractions.mlir"
     "fuse_multiuse_elementwise_producer.mlir"
     "fusion_preprocessing.mlir"
+    "hoist_encoding_ops.mlir"
     "initialize_empty_tensors.mlir"
     "inject_dispatch_tracing.mlir"
     "inject_tensor_tracing.mlir"
diff --git a/compiler/src/iree/compiler/Dialect/Flow/Transforms/test/hoist_encoding_ops.mlir b/compiler/src/iree/compiler/Dialect/Flow/Transforms/test/hoist_encoding_ops.mlir
new file mode 100644
index 0000000..ca97bb0
--- /dev/null
+++ b/compiler/src/iree/compiler/Dialect/Flow/Transforms/test/hoist_encoding_ops.mlir
@@ -0,0 +1,192 @@
+// RUN: iree-opt --pass-pipeline="builtin.module(util.func(iree-flow-hoist-encoding-ops))" --split-input-file %s | FileCheck %s
+
+#map1 = affine_map<(d0, d1, d2, d3) -> (d0, d3, d2)>
+#map2 = affine_map<(d0, d1, d2, d3) -> (d0, d1, d3)>
+#map3 = affine_map<(d0, d1, d2, d3) -> (d0, d1, d2)>
+#lhs_encoding = #iree_encoding.encoding<operand_index = 0 : index, op_type =  matmul, element_types = [f32, f32, f32], original_type = tensor<2x128x64xf32>, user_indexing_maps = [#map1, #map2, #map3], round_dims_to = array<i64: 32, 32, 32>>
+#rhs_encoding = #iree_encoding.encoding<operand_index = 1 : index, op_type =  matmul, element_types = [f32, f32, f32], original_type = tensor<2x11008x128xf32>, user_indexing_maps = [#map1, #map2, #map3], round_dims_to = array<i64: 32, 32, 32>>
+#result_encoding = #iree_encoding.encoding<operand_index = 2 : index, op_type =  matmul, element_types = [f32, f32, f32], original_type = tensor<2x11008x64xf32>, user_indexing_maps = [#map1, #map2, #map3], round_dims_to = array<i64: 32, 32, 32>>
+module {
+  util.func public @hoist_matmul_encodings(%arg0: tensor<2x128x64xf32>, %arg1: tensor<2x11008x128xf32>) -> tensor<2x11008x64xf32> {
+    %cst = arith.constant 0.000000e+00 : f32
+    %2 = flow.dispatch.region -> (tensor<2x11008x64xf32>) {
+      %3 = iree_encoding.set_encoding %arg0 : tensor<2x128x64xf32> -> tensor<2x128x64xf32, #lhs_encoding>
+      %4 = iree_encoding.set_encoding %arg1 : tensor<2x11008x128xf32> -> tensor<2x11008x128xf32, #rhs_encoding>
+      %5 = tensor.empty() : tensor<2x11008x64xf32, #result_encoding>
+      %6 = linalg.fill ins(%cst : f32) outs(%5 : tensor<2x11008x64xf32, #result_encoding>) -> tensor<2x11008x64xf32, #result_encoding>
+      %7 = linalg.generic {
+          indexing_maps = [#map1, #map2, #map3],
+          iterator_types = ["parallel", "parallel", "parallel", "reduction"]}
+          ins(%3, %4 : tensor<2x128x64xf32, #lhs_encoding>, tensor<2x11008x128xf32, #rhs_encoding>)
+          outs(%6 : tensor<2x11008x64xf32, #result_encoding>) {
+      ^bb0(%in: f32, %in_0: f32, %out: f32):
+        %9 = arith.mulf %in, %in_0 : f32
+        %10 = arith.addf %9, %out : f32
+        linalg.yield %10 : f32
+      } -> tensor<2x11008x64xf32, #result_encoding>
+      %8 = iree_encoding.unset_encoding %7 : tensor<2x11008x64xf32, #result_encoding> -> tensor<2x11008x64xf32>
+      flow.return %8 : tensor<2x11008x64xf32>
+    }
+    util.return %2 : tensor<2x11008x64xf32>
+  }
+}
+
+// CHECK-LABEL: @hoist_matmul_encodings
+// CHECK-SAME:    (%[[ARG0:.+]]: tensor<2x128x64xf32>, %[[ARG1:.+]]: tensor<2x11008x128xf32>)
+// CHECK-DAG:   %[[SET_ENCODING0:.+]] = iree_encoding.set_encoding %[[ARG0]] : tensor<2x128x64xf32> -> tensor<2x128x64xf32, #iree_encoding.encoding
+// CHECK-DAG:   %[[SET_ENCODING1:.+]] = iree_encoding.set_encoding %[[ARG1]] : tensor<2x11008x128xf32> -> tensor<2x11008x128xf32, #iree_encoding.encoding
+// CHECK:       %[[DISPATCH:.+]] = flow.dispatch.region -> (tensor<2x11008x64xf32>) {
+// CHECK:         %[[MATMUL:.+]] = linalg.generic {{.*}} ins(%[[SET_ENCODING0]], %[[SET_ENCODING1]]
+// CHECK:         %[[UNSET_ENCODING1:.+]] = iree_encoding.unset_encoding %[[MATMUL]] : tensor<2x11008x64xf32, #iree_encoding.encoding
+// CHECK:         flow.return %[[UNSET_ENCODING1]] : tensor<2x11008x64xf32>
+// CHECK:       }
+// CHECK:       util.return %[[DISPATCH]] : tensor<2x11008x64xf32>
+
+// -----
+
+#map = affine_map<(d0, d1, d2) -> (d0, d1, d2)>
+#map1 = affine_map<(d0, d1, d2) -> (d0, d1)>
+#encoding = #iree_encoding.encoding<operand_index = 1 : index, op_type = matmul, element_types = [f32, f32, f32], original_type = tensor<2x11008x128xf32>, user_indexing_maps = [affine_map<(d0, d1, d2, d3) -> (d0, d3, d2)>, affine_map<(d0, d1, d2, d3) -> (d0, d1, d3)>, affine_map<(d0, d1, d2, d3) -> (d0, d1, d2)>], round_dims_to = array<i64: 32, 32, 32>>
+util.func public @bubble_through_dequant(
+    %arg0: tensor<2x11008x128xi8>, %arg1: tensor<2x11008xf32>, %arg2: tensor<2x11008xf32>) -> tensor<2x11008x128xf32, #encoding> {
+  %6 = flow.dispatch.region -> (tensor<2x11008x128xf32, #encoding>) {
+    %8 = tensor.empty() : tensor<2x11008x128xf32>
+    %11 = linalg.generic
+        {indexing_maps = [#map, #map1, #map1, #map],
+        iterator_types = ["parallel", "parallel", "parallel"]}
+        ins(%arg0, %arg1, %arg2 : tensor<2x11008x128xi8>, tensor<2x11008xf32>, tensor<2x11008xf32>)
+        outs(%8 : tensor<2x11008x128xf32>) {
+    ^bb0(%in: i8, %in_0: f32, %in_1: f32, %out: f32):
+      %18 = arith.extui %in : i8 to i32
+      %19 = arith.uitofp %18 : i32 to f32
+      %20 = arith.subf %19, %in_1 : f32
+      %21 = arith.mulf %20, %in_0 : f32
+      linalg.yield %21 : f32
+    } -> tensor<2x11008x128xf32>
+    %13 = iree_encoding.set_encoding %11 : tensor<2x11008x128xf32> -> tensor<2x11008x128xf32, #encoding>
+    flow.return %13 : tensor<2x11008x128xf32, #encoding>
+  }
+  util.return %6 : tensor<2x11008x128xf32, #encoding>
+}
+
+// CHECK-DAG:   #[[$MAP:.+]] = affine_map<(d0, d1, d2, d3) -> (d0, d3, d2)>
+// CHECK-DAG:   #[[$MAP1:.+]] = affine_map<(d0, d1, d2, d3) -> (d0, d1, d3)>
+// CHECK-DAG:   #[[$MAP2:.+]] = affine_map<(d0, d1, d2, d3) -> (d0, d1, d2)>
+// CHECK-DAG:   #[[$MAP3:.+]] = affine_map<(d0, d1, d2) -> (d0, d1)>
+// CHECK-DAG:   #[[$MAP4:.+]] = affine_map<(d0, d1, d2) -> (d0, d1, d2)>
+// CHECK-LABEL: @bubble_through_dequant
+// CHECK-SAME:    %[[ARG0:.+]]: tensor<2x11008x128xi8>,
+// CHECK-SAME:    %[[ARG1:.+]]: tensor<2x11008xf32>, %[[ARG2:.+]]: tensor<2x11008xf32>
+// CHECK-DAG:   %[[SET_ENCODING0:.+]] = iree_encoding.set_encoding %[[ARG0]] : {{.*}} bcast_map = #[[$MAP4]]
+// CHECK-DAG:   %[[SET_ENCODING1:.+]] = iree_encoding.set_encoding %[[ARG1]] : {{.*}} bcast_map = #[[$MAP3]]
+// CHECK-DAG:   %[[SET_ENCODING2:.+]] = iree_encoding.set_encoding %[[ARG2]] : {{.*}} bcast_map = #[[$MAP3]]
+// CHECK:       %[[DISPATCH:.+]] = flow.dispatch.region
+// CHECK:         %[[INIT:.+]] = tensor.empty() : tensor<2x11008x128xf32, #iree_encoding.encoding
+// CHECK:         %[[DEQUANT:.+]] = linalg.generic {{.*}} ins(%[[SET_ENCODING0]], %[[SET_ENCODING1]], %[[SET_ENCODING2]] : {{.*}} outs(%[[INIT]] :
+// CHECK:         flow.return %[[DEQUANT]]
+// CHECK:       }
+// CHECK:       util.return %[[DISPATCH]]
+
+// -----
+
+#map = affine_map<(d0, d1, d2) -> (d0, d1, d2)>
+#map1 = affine_map<(d0, d1, d2) -> (d1, d2)>
+#encoding = #iree_encoding.encoding<operand_index = 1 : index, op_type = matmul, element_types = [f32, f32, f32], original_type = tensor<2x11008x128xf32>, user_indexing_maps = [affine_map<(d0, d1, d2, d3) -> (d0, d3, d2)>, affine_map<(d0, d1, d2, d3) -> (d0, d1, d3)>, affine_map<(d0, d1, d2, d3) -> (d0, d1, d2)>], round_dims_to = array<i64: 32, 32, 32>>
+util.func public @bubble_through_broadcast(
+    %arg0: tensor<11008x128xf32>) -> tensor<2x11008x128xf32, #encoding> {
+  %6 = flow.dispatch.region -> (tensor<2x11008x128xf32, #encoding>) {
+    %8 = tensor.empty() : tensor<2x11008x128xf32>
+    %11 = linalg.generic
+        {indexing_maps = [#map1, #map],
+        iterator_types = ["parallel", "parallel", "parallel"]}
+        ins(%arg0 : tensor<11008x128xf32>)
+        outs(%8 : tensor<2x11008x128xf32>) {
+    ^bb0(%in: f32, %out: f32):
+      linalg.yield %in : f32
+    } -> tensor<2x11008x128xf32>
+    %13 = iree_encoding.set_encoding %11 : tensor<2x11008x128xf32> -> tensor<2x11008x128xf32, #encoding>
+    flow.return %13 : tensor<2x11008x128xf32, #encoding>
+  }
+  util.return %6 : tensor<2x11008x128xf32, #encoding>
+}
+
+// CHECK-DAG:   #[[$MAP:.+]] = affine_map<(d0, d1, d2, d3) -> (d0, d3, d2)>
+// CHECK-DAG:   #[[$MAP1:.+]] = affine_map<(d0, d1, d2, d3) -> (d0, d1, d3)>
+// CHECK-DAG:   #[[$MAP2:.+]] = affine_map<(d0, d1, d2, d3) -> (d0, d1, d2)>
+// CHECK-DAG:   #[[$MAP3:.+]] = affine_map<(d0, d1, d2) -> (d1, d2)>
+// CHECK-DAG:   #[[$MAP4:.+]] = affine_map<(d0, d1, d2) -> (d0, d1, d2)>
+// CHECK-LABEL: @bubble_through_broadcast
+// CHECK-SAME:    %[[ARG0:.+]]: tensor<11008x128xf32>
+// CHECK-DAG:   %[[SET_ENCODING:.+]] = iree_encoding.set_encoding %[[ARG0]] : {{.*}} bcast_map = #[[$MAP3]]
+// CHECK:       %[[DISPATCH:.+]] = flow.dispatch.region
+// CHECK:         %[[INIT:.+]] = tensor.empty() : tensor<2x11008x128xf32, #iree_encoding.encoding
+// CHECK:         %[[BROADCAST:.+]] = linalg.generic {{.*}} ins(%[[SET_ENCODING]] : {{.*}} outs(%[[INIT]] :
+// CHECK:         flow.return %[[BROADCAST]]
+// CHECK:       }
+// CHECK:       util.return %[[DISPATCH]]
+
+// -----
+
+#map = affine_map<(d0, d1, d2, d3) -> (d0, d3, d2)>
+#map1 = affine_map<(d0, d1, d2, d3) -> (d0, d1, d3)>
+#map2 = affine_map<(d0, d1, d2, d3) -> (d0, d1, d2)>
+#map3 = affine_map<(d0, d1, d2) -> (d0, d1, d2)>
+#encoding = #iree_encoding.encoding<operand_index = 1 : index, op_type =  matmul, element_types = [f32, f32, f32], original_type = tensor<2x11008x128xf32>, user_indexing_maps = [#map, #map1, #map2], round_dims_to = array<i64: 32, 32, 32>>
+module {
+  util.func public @hoist_below(%arg0: tensor<2x11008x128xf32>) -> tensor<2x11008x128xf32, #encoding> {
+    %0 = flow.dispatch.region -> (tensor<2x11008x128xf32, #encoding>) {
+      %1 = tensor.empty() : tensor<2x11008x128xf32>
+      %2 = linalg.generic {indexing_maps = [#map3, #map3, #map3], iterator_types = ["parallel", "parallel", "parallel"]} ins(%arg0, %arg0 : tensor<2x11008x128xf32>, tensor<2x11008x128xf32>) outs(%1 : tensor<2x11008x128xf32>) {
+      ^bb0(%in: f32, %in_0: f32, %out: f32):
+        %4 = arith.addf %in, %in_0 : f32
+        linalg.yield %4 : f32
+      } -> tensor<2x11008x128xf32>
+      %3 = iree_encoding.set_encoding %2 : tensor<2x11008x128xf32> -> tensor<2x11008x128xf32, #encoding>
+      flow.return %3 : tensor<2x11008x128xf32, #encoding>
+    }
+    util.return %0 : tensor<2x11008x128xf32, #encoding>
+  }
+}
+
+// CHECK-LABEL: @hoist_below
+// CHECK-SAME:    %[[ARG0:.+]]: tensor<2x11008x128xf32>
+// CHECK:       %[[DISPATCH:.+]] = flow.dispatch.region
+// CHECK:         %[[INIT:.+]] = tensor.empty() : tensor<2x11008x128xf32>
+// CHECK:         %[[ADD:.+]] = linalg.generic {{.*}} ins(%[[ARG0]], %[[ARG0]] : {{.*}} outs(%[[INIT]] :
+// CHECK:         flow.return %[[ADD]]
+// CHECK:       }
+// CHECK:       %[[SET_ENCODING:.+]] = iree_encoding.set_encoding %[[DISPATCH]]
+// CHECK:       util.return %[[SET_ENCODING]]
+
+// -----
+
+#map = affine_map<(d0, d1, d2, d3) -> (d0, d3, d2)>
+#map1 = affine_map<(d0, d1, d2, d3) -> (d0, d1, d3)>
+#map2 = affine_map<(d0, d1, d2, d3) -> (d0, d1, d2)>
+#map3 = affine_map<(d0, d1, d2) -> (d0, d1, d2)>
+#encoding = #iree_encoding.encoding<operand_index = 1 : index, op_type =  matmul, element_types = [f32, f32, f32], original_type = tensor<?x?x?xf32>, user_indexing_maps = [#map, #map1, #map2], round_dims_to = array<i64: 32, 32, 32>>
+module {
+  util.func public @hoist_dynamic(%arg0: tensor<?x?x?xf32>, %d0: index, %d1: index, %d2: index) -> (tensor<?x?x?xf32>, tensor<?x?x?xf32, #encoding>) {
+    %0:2 = flow.dispatch.region -> (tensor<?x?x?xf32>{%d0, %d1, %d2}, tensor<?x?x?xf32, #encoding>{%d0, %d1, %d2}) {
+      %1 = tensor.empty(%d0, %d1, %d2) : tensor<?x?x?xf32>
+      %2 = linalg.generic {indexing_maps = [#map3, #map3, #map3], iterator_types = ["parallel", "parallel", "parallel"]} ins(%arg0, %arg0 : tensor<?x?x?xf32>, tensor<?x?x?xf32>) outs(%1 : tensor<?x?x?xf32>) {
+      ^bb0(%in: f32, %in_0: f32, %out: f32):
+        %4 = arith.addf %in, %in_0 : f32
+        linalg.yield %4 : f32
+      } -> tensor<?x?x?xf32>
+      %3 = iree_encoding.set_encoding %2 : tensor<?x?x?xf32> -> tensor<?x?x?xf32, #encoding>
+      flow.return %2, %3 : tensor<?x?x?xf32>, tensor<?x?x?xf32, #encoding>
+    }
+    util.return %0#0, %0#1 : tensor<?x?x?xf32>, tensor<?x?x?xf32, #encoding>
+  }
+}
+
+// CHECK-LABEL: @hoist_dynamic
+// CHECK-SAME:    %[[ARG0:.+]]: tensor<?x?x?xf32>, %[[D0:.+]]: index, %[[D1:.+]]: index, %[[D2:.+]]: index)
+// CHECK:       %[[DISPATCH:.+]] = flow.dispatch.region -> (tensor<?x?x?xf32>{%[[D0]], %[[D1]], %[[D2]]})
+// CHECK:         %[[INIT:.+]] = tensor.empty(%[[D0]], %[[D1]], %[[D2]]) : tensor<?x?x?xf32>
+// CHECK:         %[[ADD:.+]] = linalg.generic {{.*}} ins(%[[ARG0]], %[[ARG0]] : {{.*}} outs(%[[INIT]] :
+// CHECK:         flow.return %[[ADD]]
+// CHECK:       }
+// CHECK:       %[[SET_ENCODING:.+]] = iree_encoding.set_encoding %[[DISPATCH]]
+// CHECK:       util.return %[[DISPATCH]], %[[SET_ENCODING]]
diff --git a/compiler/src/iree/compiler/Dialect/LinalgExt/Utils/Utils.cpp b/compiler/src/iree/compiler/Dialect/LinalgExt/Utils/Utils.cpp
index 30dccb0..d46c58b 100644
--- a/compiler/src/iree/compiler/Dialect/LinalgExt/Utils/Utils.cpp
+++ b/compiler/src/iree/compiler/Dialect/LinalgExt/Utils/Utils.cpp
@@ -244,4 +244,44 @@
   return isBitExtendOrTruncateOp(op) == BitWidthChangeInfo::kTruncate;
 }
 
+//===---------------------------------------------------------------------===//
+// Classification of other ops
+//===---------------------------------------------------------------------===//
+
+bool isBroadcastingOp(linalg::LinalgOp op) {
+  if (isa<linalg::BroadcastOp>(op)) {
+    return true;
+  }
+  auto genericOp = dyn_cast<linalg::GenericOp>(op.getOperation());
+  if (!genericOp) {
+    return false;
+  }
+
+  // Only allow a single input and init.
+  if (genericOp.getNumDpsInits() != 1 || genericOp.getNumDpsInputs() != 1) {
+    return false;
+  }
+
+  // Check that all loops are parallel.
+  unsigned numLoops = genericOp.getNumLoops();
+  unsigned numParallelLoops = genericOp.getNumParallelLoops();
+  if (numLoops != numParallelLoops) {
+    return false;
+  }
+
+  // Check that indexing maps are broadcasting.
+  SmallVector<AffineMap> indexingMaps = genericOp.getIndexingMapsArray();
+  auto inMap =
+      genericOp.getMatchingIndexingMap(genericOp.getDpsInputOperand(0));
+  auto outMap =
+      genericOp.getMatchingIndexingMap(genericOp.getDpsInitOperand(0));
+  if (inMap.getNumResults() >= outMap.getNumResults()) {
+    return false;
+  }
+  if (!inMap.isProjectedPermutation() || !outMap.isIdentity()) {
+    return false;
+  }
+  return llvm::hasSingleElement(op.getBlock()->getOperations());
+}
+
 } // namespace mlir::iree_compiler::IREE::LinalgExt
diff --git a/compiler/src/iree/compiler/Dialect/LinalgExt/Utils/Utils.h b/compiler/src/iree/compiler/Dialect/LinalgExt/Utils/Utils.h
index 3c4e139..d6794af 100644
--- a/compiler/src/iree/compiler/Dialect/LinalgExt/Utils/Utils.h
+++ b/compiler/src/iree/compiler/Dialect/LinalgExt/Utils/Utils.h
@@ -7,6 +7,7 @@
 #ifndef IREE_COMPILER_DIALECT_LINALGEXT_UTILS_UTILS_H_
 #define IREE_COMPILER_DIALECT_LINALGEXT_UTILS_UTILS_H_
 
+#include "mlir/Dialect/Linalg/IR/LinalgInterfaces.h"
 #include "mlir/IR/Attributes.h"
 #include "mlir/IR/BuiltinTypes.h"
 #include "mlir/IR/Dialect.h"
@@ -127,5 +128,16 @@
 ///    the output element type has a lower bitwidth.
 bool isBitTruncateOp(Operation *op);
 
+/// Returns true if the operation is a BroadcastOp or a GenericOp performing
+/// a broadcast.
+/// This function checks that the genericOp:
+///     1. Has a single input and output.
+///     2. Has all parallel loops.
+///     3. Has an identity output map.
+///     4. Has a projected permutation input map.
+///     5. The input map has fewer results than the output map.
+///     6. Has a body with only a linalg.yield op.
+bool isBroadcastingOp(linalg::LinalgOp op);
+
 } // namespace mlir::iree_compiler::IREE::LinalgExt
 #endif // IREE_COMPILER_DIALECT_LINALGEXT_UTILS_UTILS_H_