[Flow] Move first part of Flow transforms to new pipeline (#18290)
Refactor the logic in `Flow` that forms dispatches into a new `DispatchCreation` pipeline. This moves the passes from the start of `Flow` up to and including `createMaterializeDefaultWorkgroupCountRegionPass`. This commit will most likely be followed by additional commit(s) that do further (but less invasive) cleanup.
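For orientation, below is a minimal sketch of how the moved passes are assumed to hang together after the move. The `DispatchCreation` namespace comes from this commit's description, but the pipeline-builder function shown is an assumption for illustration, not code from this diff; the two `addDispatchRegionCreation*` helpers are the ones removed from `Flow/Transforms/Passes.cpp` here.

```cpp
// Hypothetical sketch, not part of this diff: the dispatch-formation portion
// of the old Flow pipeline, re-homed under DispatchCreation. The namespace and
// function names are assumptions; the two helpers are the ones removed from
// Flow/Transforms/Passes.cpp in this commit.
#include "mlir/Pass/PassManager.h"

namespace mlir::iree_compiler::DispatchCreation {

// Declarations of the moved helpers (defined with the moved passes).
void addDispatchRegionCreationPreprocessingPasses(mlir::OpPassManager &pm);
void addDispatchRegionCreationPasses(mlir::OpPassManager &pm);

void buildDispatchCreationPassPipeline(mlir::OpPassManager &passManager) {
  // Preprocessing before region formation: elementwise op fusion, reshape
  // propagation, split reduction, and generic-op transposition.
  addDispatchRegionCreationPreprocessingPasses(passManager);
  // Region formation proper, ending at the last moved pass,
  // createMaterializeDefaultWorkgroupCountRegionPass.
  addDispatchRegionCreationPasses(passManager);
}

} // namespace mlir::iree_compiler::DispatchCreation
```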
diff --git a/compiler/src/iree/compiler/Dialect/Flow/TransformExtensions/FlowExtensions.cpp b/compiler/src/iree/compiler/Dialect/Flow/TransformExtensions/FlowExtensions.cpp
index 163badc..a7d746a 100644
--- a/compiler/src/iree/compiler/Dialect/Flow/TransformExtensions/FlowExtensions.cpp
+++ b/compiler/src/iree/compiler/Dialect/Flow/TransformExtensions/FlowExtensions.cpp
@@ -8,7 +8,6 @@
#include "iree/compiler/Dialect/Flow/IR/FlowOps.h"
#include "iree/compiler/Dialect/Flow/Transforms/ConvertRegionToWorkgroups.h"
-#include "iree/compiler/Dialect/Flow/Transforms/FormDispatchRegions.h"
#include "iree/compiler/Dialect/Flow/Transforms/RegionOpUtils.h"
#include "mlir/Analysis/TopologicalSortUtils.h"
#include "mlir/Dialect/Arith/IR/Arith.h"
diff --git a/compiler/src/iree/compiler/Dialect/Flow/Transforms/BUILD.bazel b/compiler/src/iree/compiler/Dialect/Flow/Transforms/BUILD.bazel
index b790a99..36e3248 100644
--- a/compiler/src/iree/compiler/Dialect/Flow/Transforms/BUILD.bazel
+++ b/compiler/src/iree/compiler/Dialect/Flow/Transforms/BUILD.bazel
@@ -31,53 +31,31 @@
name = "Transforms",
srcs = [
"AnnotateDispatches.cpp",
- "BubbleUpExpandShapes.cpp",
"Canonicalizer.cpp",
"CaptureDynamicDims.cpp",
"CleanupTensorShapes.cpp",
- "CloneProducersIntoDispatchRegions.cpp",
- "CollapseDimensions.cpp",
- "CollapseReductionDimensions.cpp",
- "ConvertDispatchRegionsToWorkgroups.cpp",
"ConvertMeshToFlow.cpp",
"ConvertRegionToWorkgroups.cpp",
- "ConvertTensorToFlow.cpp",
"ConvertToFlow.cpp",
"DeduplicateExecutables.cpp",
- "DispatchWithTransformDialect.cpp",
"DumpDispatchGraph.cpp",
- "ElementwiseOpFusion.cpp",
"ExportBenchmarkFuncs.cpp",
- "FoldUnitExtentDims.cpp",
"FormDispatchRegions.cpp",
- "FormScalarDispatches.cpp",
- "FuseHorizontalContractions.cpp",
- "FuseMultiUseElementwiseProducer.cpp",
- "FusionPreprocessing.cpp",
- "FusionUtils.cpp",
- "HoistEncodingOps.cpp",
"InitializeEmptyTensors.cpp",
"InjectDispatchTracing.cpp",
"InjectTensorTracing.cpp",
"InsertDispatchDebugTargets.cpp",
- "MaterializeDefaultWorkgroupCountRegion.cpp",
"OutlineConstants.cpp",
"OutlineDispatchExterns.cpp",
"OutlineDispatchRegions.cpp",
"Passes.cpp",
"RegionOpUtils.cpp",
- "SetEncoding.cpp",
- "SinkReshapes.cpp",
- "SplitReduction.cpp",
- "TensorPadToTensorInsertSlice.cpp",
"TopLevelSCFToCFG.cpp",
- "TransposeGenericOps.cpp",
"VerifyInputLegality.cpp",
],
hdrs = [
"ConvertRegionToWorkgroups.h",
"FormDispatchRegions.h",
- "FusionUtils.h",
"Passes.h",
"Passes.h.inc",
"RegionOpUtils.h",
@@ -120,8 +98,6 @@
"@llvm-project//mlir:MemRefDialect",
"@llvm-project//mlir:MemRefTransforms",
"@llvm-project//mlir:MeshDialect",
- "@llvm-project//mlir:PDLDialect",
- "@llvm-project//mlir:PDLInterpDialect",
"@llvm-project//mlir:Parser",
"@llvm-project//mlir:Pass",
"@llvm-project//mlir:SCFDialect",
@@ -132,8 +108,6 @@
"@llvm-project//mlir:TensorUtils",
"@llvm-project//mlir:TilingInterface",
"@llvm-project//mlir:TosaDialect",
- "@llvm-project//mlir:TransformDialect",
- "@llvm-project//mlir:TransformDialectTransforms",
"@llvm-project//mlir:TransformUtils",
"@llvm-project//mlir:Transforms",
],
diff --git a/compiler/src/iree/compiler/Dialect/Flow/Transforms/CMakeLists.txt b/compiler/src/iree/compiler/Dialect/Flow/Transforms/CMakeLists.txt
index 7bbb5d5..2fa2e77 100644
--- a/compiler/src/iree/compiler/Dialect/Flow/Transforms/CMakeLists.txt
+++ b/compiler/src/iree/compiler/Dialect/Flow/Transforms/CMakeLists.txt
@@ -25,53 +25,31 @@
HDRS
"ConvertRegionToWorkgroups.h"
"FormDispatchRegions.h"
- "FusionUtils.h"
"Passes.h"
"Passes.h.inc"
"RegionOpUtils.h"
SRCS
"AnnotateDispatches.cpp"
- "BubbleUpExpandShapes.cpp"
"Canonicalizer.cpp"
"CaptureDynamicDims.cpp"
"CleanupTensorShapes.cpp"
- "CloneProducersIntoDispatchRegions.cpp"
- "CollapseDimensions.cpp"
- "CollapseReductionDimensions.cpp"
- "ConvertDispatchRegionsToWorkgroups.cpp"
"ConvertMeshToFlow.cpp"
"ConvertRegionToWorkgroups.cpp"
- "ConvertTensorToFlow.cpp"
"ConvertToFlow.cpp"
"DeduplicateExecutables.cpp"
- "DispatchWithTransformDialect.cpp"
"DumpDispatchGraph.cpp"
- "ElementwiseOpFusion.cpp"
"ExportBenchmarkFuncs.cpp"
- "FoldUnitExtentDims.cpp"
"FormDispatchRegions.cpp"
- "FormScalarDispatches.cpp"
- "FuseHorizontalContractions.cpp"
- "FuseMultiUseElementwiseProducer.cpp"
- "FusionPreprocessing.cpp"
- "FusionUtils.cpp"
- "HoistEncodingOps.cpp"
"InitializeEmptyTensors.cpp"
"InjectDispatchTracing.cpp"
"InjectTensorTracing.cpp"
"InsertDispatchDebugTargets.cpp"
- "MaterializeDefaultWorkgroupCountRegion.cpp"
"OutlineConstants.cpp"
"OutlineDispatchExterns.cpp"
"OutlineDispatchRegions.cpp"
"Passes.cpp"
"RegionOpUtils.cpp"
- "SetEncoding.cpp"
- "SinkReshapes.cpp"
- "SplitReduction.cpp"
- "TensorPadToTensorInsertSlice.cpp"
"TopLevelSCFToCFG.cpp"
- "TransposeGenericOps.cpp"
"VerifyInputLegality.cpp"
DEPS
::PassesIncGen
@@ -95,8 +73,6 @@
MLIRMemRefDialect
MLIRMemRefTransforms
MLIRMeshDialect
- MLIRPDLDialect
- MLIRPDLInterpDialect
MLIRParser
MLIRPass
MLIRSCFDialect
@@ -107,8 +83,6 @@
MLIRTensorUtils
MLIRTilingInterface
MLIRTosaDialect
- MLIRTransformDialect
- MLIRTransformDialectTransforms
MLIRTransformUtils
MLIRTransforms
iree::compiler::Codegen::Dialect::Codegen::IR::IREECodegenDialect
diff --git a/compiler/src/iree/compiler/Dialect/Flow/Transforms/ConvertRegionToWorkgroups.cpp b/compiler/src/iree/compiler/Dialect/Flow/Transforms/ConvertRegionToWorkgroups.cpp
index fcd5bf3..eb83ea8 100644
--- a/compiler/src/iree/compiler/Dialect/Flow/Transforms/ConvertRegionToWorkgroups.cpp
+++ b/compiler/src/iree/compiler/Dialect/Flow/Transforms/ConvertRegionToWorkgroups.cpp
@@ -8,7 +8,6 @@
#include "iree/compiler/Dialect/Flow/IR/FlowDialect.h"
#include "iree/compiler/Dialect/Flow/IR/FlowOps.h"
-#include "iree/compiler/Dialect/Flow/Transforms/Passes.h"
#include "mlir/Dialect/Arith/IR/Arith.h"
#include "mlir/Dialect/Tensor/IR/Tensor.h"
#include "mlir/IR/BuiltinTypes.h"
diff --git a/compiler/src/iree/compiler/Dialect/Flow/Transforms/FormDispatchRegions.cpp b/compiler/src/iree/compiler/Dialect/Flow/Transforms/FormDispatchRegions.cpp
index 553a097..3dbc951 100644
--- a/compiler/src/iree/compiler/Dialect/Flow/Transforms/FormDispatchRegions.cpp
+++ b/compiler/src/iree/compiler/Dialect/Flow/Transforms/FormDispatchRegions.cpp
@@ -6,46 +6,19 @@
#include "iree/compiler/Dialect/Flow/Transforms/FormDispatchRegions.h"
-#include "iree/compiler/Dialect/Encoding/IR/EncodingOps.h"
-#include "iree/compiler/Dialect/Flow/IR/FlowDialect.h"
-#include "iree/compiler/Dialect/Flow/IR/FlowOps.h"
#include "iree/compiler/Dialect/Flow/Transforms/ConvertRegionToWorkgroups.h"
-#include "iree/compiler/Dialect/Flow/Transforms/Passes.h"
#include "iree/compiler/Dialect/Flow/Transforms/RegionOpUtils.h"
-#include "iree/compiler/Dialect/LinalgExt/IR/LinalgExtDialect.h"
-#include "iree/compiler/Dialect/LinalgExt/IR/LinalgExtInterfaces.h"
-#include "iree/compiler/Dialect/LinalgExt/IR/LinalgExtOps.h"
-#include "iree/compiler/Dialect/LinalgExt/Utils/Utils.h"
-#include "llvm/ADT/STLExtras.h"
-#include "llvm/ADT/TypeSwitch.h"
#include "llvm/Support/Casting.h"
-#include "llvm/Support/Debug.h"
-#include "mlir/Analysis/TopologicalSortUtils.h"
-#include "mlir/Dialect/Affine/IR/AffineOps.h"
#include "mlir/Dialect/Arith/IR/Arith.h"
-#include "mlir/Dialect/Linalg/IR/Linalg.h"
-#include "mlir/Dialect/Linalg/Utils/Utils.h"
-#include "mlir/Dialect/SCF/IR/SCF.h"
#include "mlir/Dialect/Tensor/IR/Tensor.h"
#include "mlir/Dialect/Utils/StructuredOpsUtils.h"
-#include "mlir/IR/Block.h"
#include "mlir/IR/Builders.h"
#include "mlir/IR/BuiltinAttributes.h"
#include "mlir/IR/Dominance.h"
-#include "mlir/IR/MLIRContext.h"
#include "mlir/IR/Operation.h"
#include "mlir/IR/PatternMatch.h"
-#include "mlir/Interfaces/DestinationStyleOpInterface.h"
-#include "mlir/Interfaces/FunctionInterfaces.h"
-#include "mlir/Interfaces/TilingInterface.h"
-#include "mlir/Pass/Pass.h"
#include "mlir/Support/LLVM.h"
-#define DEBUG_TYPE "iree-flow-form-dispatch-regions"
-
-static const char kRootOpAttr[] = "__root_op__";
-static const char kFusionGroupsAttr[] = "__fused_op__";
-
//===----------------------------------------------------------------------===//
// Definition of TensorDimTrackingRewriter
//===----------------------------------------------------------------------===//
@@ -80,9 +53,6 @@
namespace mlir::iree_compiler::IREE::Flow {
-#define GEN_PASS_DEF_FORMDISPATCHREGIONSPASS
-#include "iree/compiler/Dialect/Flow/Transforms/Passes.h.inc"
-
LogicalResult simplifyDimOps(RewriterBase &rewriter,
const SmallVector<tensor::DimOp> &dimOps) {
for (tensor::DimOp dimOp : dimOps) {
@@ -120,902 +90,4 @@
return success();
}
-//===----------------------------------------------------------------------===//
-// Root and fusion group attribute handling
-//===----------------------------------------------------------------------===//
-
-/// Returns true if the op is tagged with the root op attribute.
-static bool hasRootOpAttribute(Operation *op) {
- return static_cast<bool>(op->getAttrOfType<IntegerAttr>(kRootOpAttr));
-}
-/// Removes the root attribute (no-op if it is not present).
-static void removeRootOpAttribute(Operation *op) {
- op->removeAttr(kRootOpAttr);
-}
-/// Sets the root attribute for an operation. The root attribute needs a number
-/// to identify the root. Asserts if root attribute is already set on an
-/// operation.
-static void setRootAttribute(MLIRContext *context, Operation *op,
- int64_t rootNumber) {
- assert(!op->hasAttr(kRootOpAttr) &&
- "invalid to update root attribute on an op");
- op->setAttr(kRootOpAttr,
- IntegerAttr::get(IntegerType::get(context, 64), rootNumber));
-}
-/// Returns the number of the root. Asserts if the operation is not already set
-/// as a root.
-static int64_t getRootNumber(Operation *op) {
- return op->getAttrOfType<IntegerAttr>(kRootOpAttr).getInt();
-}
-/// Returns true if an op is part of a fusion group.
-static bool hasFusionGroupsAttribute(Operation *op) {
- return static_cast<bool>(op->getAttrOfType<ArrayAttr>(kFusionGroupsAttr));
-}
-/// Returns the fusion groups for the given `op`.
-static SmallVector<int64_t, 1> getFusionGroups(Operation *op) {
- SmallVector<int64_t, 1> fusionGroups = {};
- if (auto fusionGroupsAttr = op->getAttrOfType<ArrayAttr>(kFusionGroupsAttr)) {
- fusionGroups = llvm::map_to_vector(fusionGroupsAttr, [](Attribute attr) {
- return llvm::cast<IntegerAttr>(attr).getInt();
- });
- }
- return fusionGroups;
-}
-/// Appends `newGroups` to the fusion groups of the given `op`.
-static void appendToFusionGroup(Operation *op, ArrayRef<int64_t> newGroups) {
- SmallVector<int64_t> fusionGroups = getFusionGroups(op);
- fusionGroups.append(newGroups.begin(), newGroups.end());
- op->setAttr(kFusionGroupsAttr, Builder(op).getI64ArrayAttr(fusionGroups));
-}
-/// Removes the fusion groups attribute.
-static void removeFusionGroupsAttribute(Operation *op) {
- op->removeAttr(kFusionGroupsAttr);
-}
-
-//===----------------------------------------------------------------------===//
-// Op property characterizations
-//===----------------------------------------------------------------------===//
-
-/// Returns true if the dimensions reduced by the linalgOp on the unpack result
-/// are not unpacked by the producer tensor::UnPackOp, i.e. the reduced
-/// dimensions of the unpack result are not part of the inner_dims_pos.
-static bool hasNoPackedReductionDimensions(linalg::LinalgOp linalgOp,
- Operation *producer) {
- auto unpack = dyn_cast<tensor::UnPackOp>(producer);
- if (!unpack) {
- return false;
- }
- AffineMap map;
- for (auto &use : producer->getResult(0).getUses()) {
- if (use.getOwner() == linalgOp) {
- map = linalgOp.getMatchingIndexingMap(&use);
- break;
- }
- }
- if (!map) {
- return false;
- }
- auto iterators = linalgOp.getIteratorTypesArray();
- auto reduction = utils::IteratorType::reduction;
- for (auto expr : llvm::enumerate(map.getResults())) {
- auto dim = dyn_cast<AffineDimExpr>(expr.value());
- if (!dim) {
- return false;
- }
- unsigned pos = dim.getPosition();
- if (iterators[pos] == reduction &&
- llvm::any_of(unpack.getInnerDimsPos(),
- [expr](int64_t idp) { return expr.index() == idp; })) {
- return false;
- }
- }
- return true;
-}
-
-/// Returns true if the linalgOp is fusable with an unpack producer
-static bool hasFusableUnpackProducer(linalg::LinalgOp linalgOp) {
- return llvm::any_of(linalgOp->getOperands(), [&](Value operand) {
- auto producer = operand.getDefiningOp<tensor::UnPackOp>();
- return producer && hasNoPackedReductionDimensions(linalgOp, producer);
- });
-}
-
-/// Operations that are treated as root operations for dispatch region
-/// formation.
-static bool isRootOp(Operation *op) {
- if (op->getParentOfType<IREE::Flow::DispatchWorkgroupsOp>()) {
- return false;
- }
- // Dequantization-like ops get cloned into dispatches later.
- if (LinalgExt::isBitExtendOp(op)) {
- return false;
- }
- // Any Linalg named op or generic op with reduction iterator types is a root
- // op.
- if (auto linalgOp = dyn_cast<linalg::LinalgOp>(op)) {
- if (isa<linalg::GenericOp>(op)) {
- return linalgOp.getNumReductionLoops() != 0 &&
- !hasFusableUnpackProducer(linalgOp);
- }
- return !isa<linalg::FillOp>(op);
- }
- if (isa<TilingInterface>(op)) {
- return !isa<tensor::PadOp, tensor::PackOp>(op);
- }
- return isa<Encoding::UnsetEncodingOp, tensor::UnPackOp>(op);
-}
-
-/// Returns true if the operation is a `pack` op or a `set_encoding` op that
-/// has pack semantics.
-// TODO(ravishankarm): This seems like a use case for an interface.
-static bool isPackLikeOp(Operation *op) {
- return isa<IREE::Encoding::SetEncodingOp, tensor::PackOp>(op);
-}
-
-/// Returns true if the operation is an `unpack` op or an `unset_encoding` op,
-/// or an `extract_slice` op whose source operand matches those criteria,
-/// recursively.
-/// The idea is that we want to ensure that `extract_slice` ops can't prevent
-/// fusion between a `unset_encoding` producer and some linalg consumer. In
-/// %0 = unset_encoding ...
-/// %1 = extract_slice %0 ...
-/// %2 = linalg.generic ins(%1) ...
-/// we are not content to be fusing %1 into %0, we also want to be fusing %2,
-/// so we want to prevent %1 from acting as a consumer fusion barrier.
-static bool isUnpackLikeOpViaExtractSliceOps(Operation *op) {
- if (isa<IREE::Encoding::UnsetEncodingOp, tensor::UnPackOp>(op)) {
- return true;
- }
- if (isa<tensor::ExtractSliceOp>(op)) {
- Value source = op->getOperand(0);
- Operation *producer = source.getDefiningOp();
- if (isUnpackLikeOpViaExtractSliceOps(producer)) {
- return true;
- }
- }
- return false;
-}
-
-/// Since `iree_encoding.set_encoding` doesn't have padding semantics, a
-/// `tensor.pad` is introduced to get the shapes of the input and output to
-/// match. The `tensor.pad` -> `set_encoding` can be folded later on into a
-/// single `tensor.pack` operation. But it means the fusion has to try to keep
-/// these in the same dispatch.
-// TODO(ravishankarm): Maybe make `set_encoding` have pad semantics that can be
-// explicitly broken down if needed.
-static bool isPadUsedInSetEncoding(tensor::PadOp padOp) {
- return llvm::any_of(padOp->getUsers(),
- llvm::IsaPred<IREE::Encoding::SetEncodingOp>);
-}
-
-//===----------------------------------------------------------------------===//
-// Heuristics for fusing dispatchable ops with root ops using tile + fuse.
-//===----------------------------------------------------------------------===//
-
-/// Returns a bit vector whose size equals the number of loops of the
-/// `interfaceOp`, with the bits corresponding to outer parallel loops set to
-/// `true`.
-static llvm::SmallBitVector getOuterParallelLoops(Operation *op) {
- if (auto setEncodingOp = dyn_cast<IREE::Encoding::SetEncodingOp>(op)) {
- return llvm::SmallBitVector(setEncodingOp.getResultType().getRank(), true);
- }
- if (auto unsetEncodingOp = dyn_cast<IREE::Encoding::UnsetEncodingOp>(op)) {
- return llvm::SmallBitVector(unsetEncodingOp.getResultType().getRank(),
- true);
- }
-
- auto interfaceOp = dyn_cast<TilingInterface>(op);
- if (!interfaceOp) {
- // For ops that don't implement the `TilingInterface`, just return empty.
- return llvm::SmallBitVector{};
- }
- SmallVector<utils::IteratorType> loopIteratorTypes =
- interfaceOp.getLoopIteratorTypes();
- llvm::SmallBitVector parallelLoops(loopIteratorTypes.size());
- for (auto iteratorType : llvm::enumerate(loopIteratorTypes)) {
- if (iteratorType.value() != utils::IteratorType::parallel)
- break;
- parallelLoops.set(iteratorType.index());
- }
- return parallelLoops;
-}
-
-/// Returns true if `map` is an identity map with zeros, i.e. if you
-/// drop the result exprs that are constant zeros, the `map` will become an
-/// identity.
-static bool isIdentityMapWithZeros(AffineMap map) {
- if (map.getNumSymbols() != 0)
- return false;
- if (map.isEmpty())
- return false;
- unsigned dimsSeen = 0;
- for (AffineExpr result : map.getResults()) {
- if (auto dimExpr = dyn_cast<AffineDimExpr>(result)) {
- if (dimExpr.getPosition() != dimsSeen) {
- return false;
- }
- dimsSeen++;
- } else if (auto constExpr = dyn_cast<AffineConstantExpr>(result)) {
- if (constExpr.getValue() != 0) {
- return false;
- }
- } else {
- return false;
- }
- }
- return dimsSeen == map.getNumDims();
-}
-
-static bool
-matchIteratorTypes(const llvm::SmallBitVector &rootOuterParallelLoop,
- const llvm::SmallBitVector &candidateOuterParallelLoop) {
- // If the candidate is not all parallel, then its loop configuration should be
- // the same as the root.
- if (candidateOuterParallelLoop.size() != candidateOuterParallelLoop.count()) {
- return rootOuterParallelLoop == candidateOuterParallelLoop;
- }
-
- // If the candidate is all parallel, then it should be at least as parallel as
- // the root.
- for (int pos : llvm::seq<int>(0, rootOuterParallelLoop.size())) {
- // If we reach the end of the outer loops of the root, break out of the
- // loop.
- if (!rootOuterParallelLoop.test(pos))
- break;
- // If the root loop is parallel, the candidate loop should also be parallel.
- if (pos >= candidateOuterParallelLoop.size() ||
- !candidateOuterParallelLoop.test(pos))
- return false;
- }
- return true;
-}
-
-// Checks whether the op has a compatible indexing map on the outer-parallel
-// loops. Currently this means the map needs to be identity on those
-// dimensions, ignoring its reduction dimensions.
-static bool hasCompatibleOuterParallelLoops(
- TilingInterface tileOp, AffineMap indexingMap,
- const llvm::SmallBitVector &rootOuterParallelLoops) {
- if (!indexingMap.isProjectedPermutation()) {
- return false;
- }
-
- llvm::SmallBitVector parallelLoops = getOuterParallelLoops(tileOp);
- if (!matchIteratorTypes(rootOuterParallelLoops, parallelLoops)) {
- return false;
- }
-
- /// Project out the non-parallel dimensions.
- llvm::SmallBitVector projectedDims(rootOuterParallelLoops);
- projectedDims.flip();
- projectedDims.resize(tileOp.getLoopIteratorTypes().size(), true);
- auto projectedMap = getProjectedMap(indexingMap, projectedDims);
- return isIdentityMapWithZeros(projectedMap);
-}
-
-// Method to check if two `linalg.generic` ops with a producer-consumer
-// relationship through `operand` have compatible outer-parallel loops.
-static bool hasCompatibleOuterParallelLoops(
- OpOperand &operand, const llvm::SmallBitVector &rootOuterParallelLoops) {
- auto producer =
- operand.get().getDefiningOp<LinalgExt::LinalgFusionOpInterface>();
- auto consumer =
- dyn_cast<LinalgExt::LinalgFusionOpInterface>(operand.getOwner());
- if (!producer || !consumer)
- return false;
-
- auto producerIndexingMap = producer.getIndexingMapMatchingResult(
- llvm::cast<OpResult>(operand.get()));
- auto consumerIndexingMap = consumer.getMatchingIndexingMap(&operand);
-
- if (!producerIndexingMap || !consumerIndexingMap) {
- return false;
- }
-
- return hasCompatibleOuterParallelLoops(
- cast<TilingInterface>(producer.getOperation()),
- producerIndexingMap, rootOuterParallelLoops) &&
- hasCompatibleOuterParallelLoops(
- cast<TilingInterface>(consumer.getOperation()),
- consumerIndexingMap, rootOuterParallelLoops);
-}
-
-/// For all uses of an operation, finds the use that dominates all other uses.
-static std::optional<OpOperand *>
-getFusableUse(Operation *op, DominanceInfo const &dominanceInfo,
- bool aggressiveFusion) {
- if (!aggressiveFusion && llvm::count_if(op->getUses(), [](OpOperand &use) {
- return !isa<tensor::DimOp>(use.getOwner());
- }) != 1) {
- return std::nullopt;
- }
-
- // Collect non-dim users.
- SmallVector<Operation *> nonDimUsers;
- for (Operation *user : op->getUsers()) {
- if (isa<tensor::DimOp>(user))
- continue;
- nonDimUsers.push_back(user);
- }
-
- // Find the use in a non-dim user that dominates all other non-dim users.
- for (auto &use : op->getUses()) {
- Operation *user = use.getOwner();
- if (isa<tensor::DimOp>(user))
- continue;
- if (llvm::all_of(nonDimUsers, [&](Operation *c) {
- return dominanceInfo.dominates(user, c);
- })) {
- return &use;
- }
- }
- return std::nullopt;
-}
-
-/// Returns true if the operands are fusable.
-static bool areOpsFusable(Operation *producer, Operation *consumer,
- const llvm::SmallBitVector &rootOuterParallelLoops) {
- // Collect all the uses from producer to consumer.
- SmallVector<OpOperand *> allUses;
- for (OpOperand &producerUse : producer->getUses()) {
- if (producerUse.getOwner() != consumer)
- continue;
- allUses.push_back(&producerUse);
- }
-
- // Check that the consumer and producer have compatible outer parallel loops.
- if (!llvm::all_of(allUses, [&](OpOperand *operand) {
- return hasCompatibleOuterParallelLoops(*operand,
- rootOuterParallelLoops);
- })) {
- return false;
- }
- return true;
-}
-
-/// For the fusion of root op -> elementwise operation to be bufferized
-/// in-place without use of extra memory, the result of the root operation
-/// must be able to reuse the buffer for the result of the elementwise
-/// operation. Check if that is possible for the input/init operand pair.
-static bool canUseInOperandAsInitOperand(OpOperand *inOperand,
- OpOperand *initOperand) {
- assert(inOperand->getOwner() == initOperand->getOwner() &&
- "expected in-operand and init-operand to be owned by same operation");
-
- // Check that the owner is a `generic` op.
- auto genericOp = dyn_cast<linalg::GenericOp>(inOperand->getOwner());
- if (!genericOp)
- return false;
-
- // All loops to be parallel.
- if (genericOp.getNumLoops() != genericOp.getNumParallelLoops()) {
- return false;
- }
-
- /// The input operand cannot be an init operand already.
- if (genericOp.isDpsInit(inOperand))
- return false;
-
- // If the init operand value is used it cannot be reused for the input
- // operand.
- if (genericOp.payloadUsesValueFromOperand(initOperand))
- return false;
-
- // Indexing maps used to access the input and init have to match.
- if (genericOp.getMatchingIndexingMap(inOperand) !=
- genericOp.getMatchingIndexingMap(initOperand)) {
- return false;
- }
-
- // Types have to match for the input operand to reuse the buffer from the init
- // operand
- if (inOperand->get().getType() != initOperand->get().getType())
- return false;
-
- return true;
-}
-
-/// All operations in a dispatch should be vectorized, which isn't the case
-/// today. This is an explicit list of operations that aren't vectorized yet
-/// and hence require special handling in dispatch region formation to avoid
-/// large stack allocations.
-static bool isVectorizedAlways(Operation *producer) {
- // TODO(#17155) : This is a black list of operations that are not vectorized
- // today (under the aggressive fusion flag). Remove this blacklist to return
- // true always.
- if (auto convOp = dyn_cast<linalg::Conv2DNhwcHwcfOp>(producer)) {
- auto strides = convOp.getStrides();
- return strides.isSplat() && strides.getSplatValue<int64_t>() == 1;
- }
- return true;
-}
-
-/// Returns true if this is a fusable use, while fusing a root with its
-/// consumer.
-static bool
-isFusableWithConsumer(OpOperand &fusedOperand,
- const llvm::SmallBitVector &rootOuterParallelLoops,
- FormDispatchRegionsPassOptions const &options) {
- Operation *producer = fusedOperand.get().getDefiningOp();
- Operation *consumer = fusedOperand.getOwner();
-
- // If consumer is a dequant operation, don't fuse it. These get cloned
- // into their consumers.
- if (LinalgExt::isBitExtendOp(consumer)) {
- return false;
- }
-
- // Fuse unset_encoding operations with `tensor.extract_slice` and elementwise
- // generic ops.
- if (isUnpackLikeOpViaExtractSliceOps(producer)) {
- // Fuse `unset_encoding` -> `extract_slice` op since they get folded into
- // `unpack` on materialization.
- if (isa<tensor::ExtractSliceOp>(consumer)) {
- auto sliceOp = cast<tensor::ExtractSliceOp>(consumer);
- return llvm::all_of(
- sliceOp.getMixedOffsets(),
- [](OpFoldResult ofr) { return isConstantIntValue(ofr, 0); }) &&
- llvm::all_of(sliceOp.getMixedStrides(), [](OpFoldResult ofr) {
- return isConstantIntValue(ofr, 1);
- });
- }
- // Fuse `unset_encoding/unpack` -> elementwise operations. Fuse unpack with
- // non-overlapping reductions (i.e., the reduction dimension is not packed).
- if (auto consumerLinalgOp = dyn_cast<linalg::LinalgOp>(consumer)) {
- if (hasNoPackedReductionDimensions(consumerLinalgOp, producer)) {
- return true;
- }
- return linalg::isElementwise(consumerLinalgOp) &&
- consumerLinalgOp.getNumLoops() ==
- llvm::cast<RankedTensorType>(producer->getResult(0).getType())
- .getRank();
- }
- return false;
- }
-
- if (isPackLikeOp(consumer)) {
- return TypeSwitch<Operation *, bool>(producer)
- .Case<tensor::PadOp>([&](auto padOp) { return true; })
- .Case<linalg::LinalgOp>([&](auto linalgOp) {
- auto producerIndexingMap = linalgOp.getIndexingMapMatchingResult(
- llvm::cast<OpResult>(fusedOperand.get()));
- // Make sure the producer op has an identity result indexing map, as the
- // CPU backend currently can't handle transpose between fused ops.
- return hasCompatibleOuterParallelLoops(
- cast<TilingInterface>(linalgOp.getOperation()),
- producerIndexingMap, rootOuterParallelLoops);
- })
- .Default([](Operation *) { return false; });
- }
-
- // By default, padding should be fused with producers. It is hard to square
- // this with fusion of pad with consumer. So for now split the difference.
- // Either fuse pad with producer or with consumer.
- if (auto padOp = dyn_cast<tensor::PadOp>(consumer)) {
- if (options.fusePadWithProducers || isPadUsedInSetEncoding(padOp)) {
- return isa<linalg::LinalgOp>(producer);
- }
- return false;
- }
-
- // Insert slice ops should always be fused with their producers.
- if (auto insertSliceOp = dyn_cast<tensor::InsertSliceOp>(consumer)) {
- // TODO: Enable multi-use slice source fusion.
- Value source = insertSliceOp.getSource();
- if (!source.hasOneUse() || source.getDefiningOp() != producer) {
- return false;
- }
- // Fuse in `insert_slice` consumer operations if destination is a fill.
- // TODO: This can be generalized, but the destination cannot be an
- // `arith.constant` or other constant-like objects. `linalg.fill` captures a
- // common case of pad generalization.
- return insertSliceOp.getDest().getDefiningOp<linalg::FillOp>();
- }
-
- // TODO(#16025): Enable mmt4d fusion. It is disabled because the backends
- // cannot set multiple lowering_configs properly. See the issue for more details.
- if (isa<linalg::Mmt4DOp>(producer)) {
- return false;
- }
-
- auto producerFusionOp =
- dyn_cast<LinalgExt::LinalgFusionOpInterface>(producer);
- auto consumerFusionOp =
- dyn_cast<LinalgExt::LinalgFusionOpInterface>(consumer);
- if (!producerFusionOp || !consumerFusionOp)
- return false;
-
- // Check that the consumer is all parallel.
- if (consumerFusionOp.getNumLoops() !=
- consumerFusionOp.getNumParallelLoops()) {
- return false;
- }
-
- if (!areOpsFusable(producer, consumer, rootOuterParallelLoops)) {
- return false;
- }
-
- // Check if the iteration spaces of the producer and consumer are the same.
- // TODO(#12664): This is an unnecessary requirement, but we need a better config
- // to tile the consumer with a larger iteration space.
- if (!options.aggressiveFusion) {
- auto producerIterationSpace = producerFusionOp.getStaticLoopRanges();
- auto consumerIterationSpace = consumerFusionOp.getStaticLoopRanges();
- if (producerIterationSpace.size() < consumerIterationSpace.size()) {
- return false;
- }
- }
-
- // Under aggressive fusion, assume that the dispatches are vectorized, in
- // which case we don't need to account for the subsequent stack allocation
- // condition.
- if (options.aggressiveFusion) {
- if (isVectorizedAlways(producer)) {
- return true;
- }
- }
-
- // While fusing with consumer, the result of the root might not be the final
- // result of the dispatch. To avoid a stack allocation we have to ensure that
- // all operations can bufferize without needing additional memory.
- auto consumerDstOp =
- dyn_cast<DestinationStyleOpInterface>(consumerFusionOp.getOperation());
- if (!consumerDstOp) {
- return true;
- }
-
- for (OpOperand *inputOperand : consumerDstOp.getDpsInputOperands()) {
- if (inputOperand->get().getDefiningOp() != producer)
- continue;
- if (isa<linalg::ConvolutionOpInterface>(producer) &&
- !llvm::any_of(
- consumerDstOp.getDpsInitsMutable(), [&](OpOperand &initOperand) {
- return canUseInOperandAsInitOperand(inputOperand, &initOperand);
- })) {
- return false;
- }
- }
-
- return true;
-}
-
-/// Fuses roots with their consumers. If a root is fused with its consumer, it
-/// is no longer tagged as a root, to aid with dispatch region formation.
-static void
-fuseRootsWithConsumers(MLIRContext *context, ArrayRef<Operation *> roots,
- DominanceInfo const &dominanceInfo,
- FormDispatchRegionsPassOptions const &options) {
- // Fuse with consumers where possible.
- for (Operation *root : roots) {
- SmallVector<Operation *> workList;
- llvm::SmallBitVector rootOuterParallelLoops = getOuterParallelLoops(root);
- workList.push_back(root);
- while (!workList.empty()) {
- Operation *currRoot = workList.pop_back_val();
- assert(hasRootOpAttribute(currRoot) &&
- "unexpected non-root op in worklist");
-
- // Helper function to make the consumer the root instead of the producer
- // when they are to be fused.
- auto updateRootTo = [&context, &currRoot](Operation *newRoot) {
- int64_t rootNumber = getRootNumber(currRoot);
- setRootAttribute(context, newRoot, rootNumber);
- removeRootOpAttribute(currRoot);
- appendToFusionGroup(currRoot, rootNumber);
- };
-
- std::optional<OpOperand *> fusableUse =
- getFusableUse(currRoot, dominanceInfo,
- /*aggressiveFusion=*/options.aggressiveFusion);
- if (!fusableUse)
- continue;
-
- // Analyse the use to see if it is fusable.
- Operation *consumerOp = fusableUse.value()->getOwner();
- if (hasRootOpAttribute(consumerOp) ||
- hasFusionGroupsAttribute(consumerOp)) {
- continue;
- }
-
- if (isFusableWithConsumer(*(fusableUse.value()), rootOuterParallelLoops,
- options)) {
- updateRootTo(consumerOp);
- workList.push_back(consumerOp);
- }
- }
- }
-}
-
-/// Method to check if the consumer of a use can be fused with its producer.
-static bool
-isFusableWithProducer(OpOperand &operand,
- const llvm::SmallBitVector &rootOuterParallelLoops,
- FormDispatchRegionsPassOptions const &options) {
- Operation *producer = operand.get().getDefiningOp();
- Operation *consumer = operand.getOwner();
-
- if (auto padOp = dyn_cast<tensor::PadOp>(consumer)) {
- if (options.fusePadWithProducers || isPadUsedInSetEncoding(padOp)) {
- return isa<linalg::LinalgOp>(producer);
- }
- return false;
- }
-
- if (options.fusePadWithConsumers && isa<tensor::PadOp>(producer) &&
- isa<linalg::ConvolutionOpInterface>(consumer)) {
- return true;
- }
-
- // Don't fuse attention with its producer.
- if (isa<LinalgExt::AttentionOp>(consumer)) {
- return false;
- }
-
- if (isPackLikeOp(consumer)) {
- return TypeSwitch<Operation *, bool>(producer)
- .Case<tensor::PadOp>([&](auto padOp) { return true; })
- .Case<linalg::LinalgOp>([&](auto linalgOp) {
- if (auto packOp = dyn_cast<tensor::PackOp>(consumer)) {
- // TODO(#12746): fusion of pack with dynamic inner tile size
- // causes an error in backend. Disable for now.
- if (!packOp.getInnerTiles().empty()) {
- return false;
- }
- }
- auto producerIndexingMap = linalgOp.getIndexingMapMatchingResult(
- llvm::cast<OpResult>(operand.get()));
- // Make sure the producer op has an identity result indexing map, as the
- // CPU backend currently can't handle transpose between fused ops.
- return hasCompatibleOuterParallelLoops(
- cast<TilingInterface>(linalgOp.getOperation()),
- producerIndexingMap, rootOuterParallelLoops);
- })
- .Default([](Operation *) { return false; });
- }
-
- if (!isa<LinalgExt::LinalgFusionOpInterface>(consumer) ||
- !isa<LinalgExt::LinalgFusionOpInterface>(producer)) {
- return false;
- }
-
- if (!options.aggressiveFusion) {
- auto consumerFusionOp = dyn_cast<DestinationStyleOpInterface>(consumer);
- if (consumerFusionOp && !consumerFusionOp.isDpsInit(&operand)) {
- return false;
- }
- }
-
- return areOpsFusable(producer, consumer, rootOuterParallelLoops);
-}
-
-/// Starting from the `root` op, traverse the operand use-def chain
-/// in reverse to fuse with producers.
-static void
-fuseRootsWithProducers(MLIRContext *context, Operation *root, unsigned groupNum,
- DominanceInfo const &dominanceInfo,
- FormDispatchRegionsPassOptions const &options) {
- SmallVector<Operation *> worklist;
- worklist.push_back(root);
- llvm::SmallBitVector rootOuterParallelLoops = getOuterParallelLoops(root);
- while (!worklist.empty()) {
- Operation *candidate = worklist.pop_back_val();
- for (OpOperand &operand : candidate->getOpOperands()) {
- Operation *producer = operand.get().getDefiningOp();
- if (!producer)
- continue;
- if (isClonableIntoDispatchOp(producer) ||
- hasFusionGroupsAttribute(producer) || hasRootOpAttribute(producer)) {
- continue;
- }
-
- std::optional<OpOperand *> fusableUse =
- getFusableUse(producer, dominanceInfo,
- /*aggressiveFusion=*/options.aggressiveFusion);
- if (!fusableUse || fusableUse.value()->getOwner() != candidate)
- continue;
-
- if (!isFusableWithProducer(operand, rootOuterParallelLoops, options)) {
- continue;
- }
-
- appendToFusionGroup(producer, groupNum);
- worklist.push_back(producer);
- }
- }
-}
-
-/// Some heuristic is needed to fuse a dispatchable op with root operations
-/// using tile + fuse: each root operation is tagged with
-/// an ID (using an IntegerAttr with name `kRootOpAttr`) and all dispatchable
-/// ops to be fused with it are tagged with the same ID (using a list of
-/// IntegerAttr with name `kFusionGroupsAttr`). Each dispatchable operation can
-/// be marked to fuse with multiple root operations (i.e. replicated). For now a
-/// very simple heuristic is used below, but the mechanism should be general
-/// enough to capture any heuristic.
-static unsigned
-decideFusableLinalgOps(Region ®ion, DominanceInfo const &dominanceInfo,
- FormDispatchRegionsPassOptions const &options,
- unsigned numRootOps = 0) {
- MLIRContext *context = region.getContext();
- OpBuilder builder(context);
- for (Block &block : region) {
- // Dispatch region formation works by first cloning the root into
- // the dispatch region and then pulling operations in.
- // So the procedure here is to:
- // - First find the roots.
- // - To fuse with consumers, make the consumer the root.
- SmallVector<Operation *> roots;
- for (Operation &op : llvm::reverse(block)) {
- if (isa<scf::SCFDialect>(op.getDialect())) {
- for (auto ®ion : op.getRegions()) {
- numRootOps = decideFusableLinalgOps(region, dominanceInfo, options,
- numRootOps);
- }
- continue;
- }
-
- // Start with a root operation and fuse its producers.
- if (hasFusionGroupsAttribute(&op) || !isRootOp(&op))
- continue;
- unsigned newGroup = numRootOps++;
- setRootAttribute(context, &op, newGroup);
-
- fuseRootsWithProducers(context, &op, newGroup, dominanceInfo, options);
- roots.push_back(&op);
- }
- roots = llvm::to_vector(llvm::reverse(roots));
- fuseRootsWithConsumers(context, roots, dominanceInfo, options);
- }
-
- // Once all root linalg ops have been tagged, put all remaining generic ops
- // into their own dispatches.
- for (Block &block : region) {
- SmallVector<Operation *> roots;
- for (Operation &op : llvm::reverse(block)) {
- // If it is part of a fusion group or root op, ignore it.
- if (hasFusionGroupsAttribute(&op) || hasRootOpAttribute(&op))
- continue;
- // Only look for Linalg ops here. Avoid moving `linalg.fill` ops that aren't
- // fused with anything else into their own dispatches since it is better
- // to convert them to splats. Also avoid moving dequantization-like ops
- // into their own dispatch since it is better to clone these ops and avoid
- // materializing large tensors between dispatches.
- if (!isa<linalg::LinalgOp, tensor::PadOp, tensor::PackOp,
- IREE::Encoding::SetEncodingOp>(op) ||
- isa<linalg::FillOp>(op) || LinalgExt::isBitExtendOp(&op)) {
- continue;
- }
-
- unsigned newGroup = numRootOps++;
- setRootAttribute(context, &op, newGroup);
-
- fuseRootsWithProducers(context, &op, newGroup, dominanceInfo, options);
- roots.push_back(&op);
- }
- roots = llvm::to_vector(llvm::reverse(roots));
- fuseRootsWithConsumers(context, roots, dominanceInfo, options);
- }
-
- return numRootOps;
-}
-
-//===----------------------------------------------------------------------===//
-// Dispatch region formation
-//===----------------------------------------------------------------------===//
-
-/// Create IREE::Flow::DispatchRegionOps based on a fusion heuristic.
-static LogicalResult
-createFusionGroups(TensorDimTrackingRewriter &rewriter,
- mlir::FunctionOpInterface funcOp,
- DominanceInfo const &dominanceInfo,
- FormDispatchRegionsPassOptions const &options) {
- // Step 1: Decide fusion groups (heuristic). This marks rootOps with an
- // attribute.
- unsigned numRoots =
- decideFusableLinalgOps(funcOp.getFunctionBody(), dominanceInfo, options);
- SmallVector<Operation *> roots(numRoots, nullptr);
- DenseMap<unsigned, SmallVector<Operation *>> producers;
-
- LLVM_DEBUG({
- llvm::dbgs() << "\n--- After deciding fusion groups ---\n";
- funcOp->print(llvm::dbgs(), OpPrintingFlags().useLocalScope());
- llvm::dbgs() << "\n\n";
- });
-
- // TODO: Incrementally add ops to an empty DispatchGroupOp instead of
- // annotating fusion group IDs via attributes.
- funcOp.walk([&](Operation *op) {
- if (hasRootOpAttribute(op)) {
- roots[getRootNumber(op)] = op;
- removeRootOpAttribute(op);
- }
- if (hasFusionGroupsAttribute(op)) {
- assert(getFusionGroups(op).size() == 1 && "expected exactly one group");
- producers[getFusionGroups(op).front()].push_back(op);
- removeFusionGroupsAttribute(op);
- }
- });
-
- // Step 2. Create a DispatchRegionOp for every fusion group.
- OpBuilder::InsertionGuard g(rewriter);
- SmallVector<IREE::Flow::DispatchRegionOp> regionOps;
- for (const auto &it : llvm::enumerate(roots)) {
- // Simplify tensor::DimOps.
- {
- SmallVector<tensor::DimOp> dimOps = rewriter.getTensorDimOps();
- if (failed(IREE::Flow::simplifyDimOps(rewriter, dimOps))) {
- return failure();
- }
- }
-
- // Create fusion group.
- IREE::Flow::DispatchRegionOp regionOp;
- auto maybeRegionOp =
- IREE::Flow::wrapOpInDispatchRegion(rewriter, it.value());
- if (failed(maybeRegionOp))
- return failure();
- regionOp = *maybeRegionOp;
-
- // Sort producers topologically. All producers must be in the same block
- // as the root.
- bool sortResult = mlir::computeTopologicalSorting(producers[it.index()]);
- (void)sortResult;
- assert(sortResult && "could not compute topological sorting");
-
- // Move ops into the region.
- for (Operation *producer : llvm::reverse(producers[it.index()])) {
- // Simplify tensor::DimOps.
- {
- SmallVector<tensor::DimOp> dimOps = rewriter.getTensorDimOps();
- if (failed(IREE::Flow::simplifyDimOps(rewriter, dimOps))) {
- return failure();
- }
- }
-
- auto newRegionOp =
- movePrecedingOpsIntoDispatchRegion(rewriter, producer, regionOp);
- if (failed(newRegionOp))
- return failure();
- regionOp = *newRegionOp;
- }
- // Simplify tensor::DimOps.
- {
- SmallVector<tensor::DimOp> dimOps = rewriter.getTensorDimOps();
- if (failed(IREE::Flow::simplifyDimOps(rewriter, dimOps))) {
- return failure();
- }
- }
- regionOps.push_back(regionOp);
- }
-
- LLVM_DEBUG({
- llvm::dbgs() << "\n--- After creating flow.dispatch.region ---\n";
- funcOp->print(llvm::dbgs(), OpPrintingFlags().useLocalScope());
- llvm::dbgs() << "\n\n";
- });
-
- return success();
-}
-
-namespace {
-/// Pass declaration.
-struct FormDispatchRegionsPass
- : public IREE::Flow::impl::FormDispatchRegionsPassBase<
- FormDispatchRegionsPass> {
- using IREE::Flow::impl::FormDispatchRegionsPassBase<
- FormDispatchRegionsPass>::FormDispatchRegionsPassBase;
- void runOnOperation() override;
-};
-} // namespace
-
-/// Create dispatch.region Ops based on a fusion heuristic.
-void FormDispatchRegionsPass::runOnOperation() {
- mlir::FunctionOpInterface funcOp = getOperation();
- DominanceInfo const &dominanceInfo = getAnalysis<DominanceInfo>();
- TensorDimTrackingRewriter rewriter(funcOp);
- FormDispatchRegionsPassOptions options{aggressiveFusion, fusePadWithConsumers,
- fusePadWithProducers};
- if (failed(createFusionGroups(rewriter, funcOp, dominanceInfo, options))) {
- funcOp->emitOpError("failed to create fusion groups");
- return signalPassFailure();
- }
-}
} // namespace mlir::iree_compiler::IREE::Flow
diff --git a/compiler/src/iree/compiler/Dialect/Flow/Transforms/FormDispatchRegions.h b/compiler/src/iree/compiler/Dialect/Flow/Transforms/FormDispatchRegions.h
index ab31090..c849123 100644
--- a/compiler/src/iree/compiler/Dialect/Flow/Transforms/FormDispatchRegions.h
+++ b/compiler/src/iree/compiler/Dialect/Flow/Transforms/FormDispatchRegions.h
@@ -37,8 +37,8 @@
/// Computes the workload and provides a workload region builder for the given
/// root op.
-FailureOr<Flow::WorkloadBuilder> getWorkloadBuilder(OpBuilder &builder,
- Operation *rootOp);
+FailureOr<IREE::Flow::WorkloadBuilder> getWorkloadBuilder(OpBuilder &builder,
+ Operation *rootOp);
/// Simplify the given tensor::DimOps as much as possible.
/// * Static dimensions are replaced by constants.
diff --git a/compiler/src/iree/compiler/Dialect/Flow/Transforms/Passes.cpp b/compiler/src/iree/compiler/Dialect/Flow/Transforms/Passes.cpp
index 5ce1211..c646852 100644
--- a/compiler/src/iree/compiler/Dialect/Flow/Transforms/Passes.cpp
+++ b/compiler/src/iree/compiler/Dialect/Flow/Transforms/Passes.cpp
@@ -53,54 +53,6 @@
"occurrences of the dispatch symbol."),
llvm::cl::init(""));
-static llvm::cl::opt<bool> clDetensoring(
- "iree-flow-enable-detensoring",
- llvm::cl::desc(
- "Enable changing of tensor operations into scalar operations."),
- llvm::cl::init(false));
-
-static llvm::cl::opt<bool> clEnablePadHandling(
- "iree-flow-enable-pad-handling",
- llvm::cl::desc("Enable native handling of tensor.pad operations."),
- llvm::cl::init(false));
-
-static llvm::cl::opt<bool> clEnableFusePaddingIntoLinalgConsumerOps(
- "iree-flow-enable-fuse-padding-into-linalg-consumer-ops",
- llvm::cl::desc("Enable fusing tensor.pad ops into Linalg consumer ops."),
- llvm::cl::init(false));
-
-static llvm::cl::opt<bool> clEnableFusePaddingIntoLinalgProducerOps(
- "iree-flow-enable-fuse-padding-into-linalg-producer-ops",
- llvm::cl::desc("Enable fusing tensor.pad ops into Linalg producer ops."),
- llvm::cl::init(false));
-
-static llvm::cl::opt<bool> clEnableFuseHorizontalContractions(
- "iree-flow-enable-fuse-horizontal-contractions",
- llvm::cl::desc(
- "Enables horizontal fusion of contractions with one common operand"),
- llvm::cl::init(false));
-
-static llvm::cl::opt<bool> clCollapseReductionDims(
- "iree-flow-collapse-reduction-dims",
- llvm::cl::desc("Enable collapsing of reduction dims"),
- llvm::cl::init(false));
-
-static llvm::cl::opt<bool>
- clEnableFuseMultiUse("iree-flow-fuse-multi-use",
- llvm::cl::desc("Fuse multi-use ops."),
- llvm::cl::init(false));
-
-static llvm::cl::opt<bool> clEnableElementWiseFuseMultiReduction(
- "iree-flow-element-wise-fuse-multi-reduction",
- llvm::cl::desc("Enable element-wise fusion of multi-reduction loop ops."),
- llvm::cl::init(true));
-
-static llvm::cl::opt<bool> clEnableAggressiveFusion(
- "iree-flow-enable-aggressive-fusion",
- llvm::cl::desc("Aggressive fusion opportunities that are behind a flag "
- "since not all backends support them yet"),
- llvm::cl::init(false));
-
static llvm::cl::opt<bool>
clDumpDispatchGraph("iree-flow-dump-dispatch-graph",
llvm::cl::desc("Dump a dot graph for dispatches."),
@@ -111,35 +63,12 @@
llvm::cl::desc("Output file name for a dispatch graph dump."),
llvm::cl::init("dispatch.dot"));
-static llvm::cl::opt<std::string> clDispatchTransformFileName(
- "iree-flow-dispatch-use-transform-dialect",
- llvm::cl::desc("MLIR file containing a top-level module that specifies "
- "the transformations to apply to form dispatch regions."),
- llvm::cl::init(""));
-
static llvm::cl::opt<bool> clZeroFillEmptyTensors(
"iree-flow-zero-fill-empty-tensors",
llvm::cl::desc(
"Zero fill empty tensors instead of leaving them uninitialized."),
llvm::cl::init(false));
-static llvm::cl::opt<bool> clEnableDataTiling(
- "iree-flow-experimental-data-tiling",
- llvm::cl::desc("Enable data-tiling at flow level, i.e., it sets encodings "
- "in dispatch regions, hoist them out of region, and enables "
- "fusion for the set_encodings. This is still an "
- "experimental path. The current main data tiling path is "
- "iree-opt-data-tiling, which is on by default. To use this "
- "path, --iree-opt-data-tiling=false must be set as wells"),
- llvm::cl::init(false));
-
-static llvm::cl::opt<int> clPadFactor(
- "iree-flow-pad-factor",
- llvm::cl::desc("Provides padding size hints that will be attached to "
- "encodings. This only affects the experimental data tiling "
- "path in Flow with iree-flow-experimental-data-tiling."),
- llvm::cl::init(32));
-
namespace mlir::iree_compiler::IREE::Flow {
using FunctionLikeNest =
@@ -175,203 +104,11 @@
// Pipelines
//===----------------------------------------------------------------------===//
-void addDispatchRegionCreationPreprocessingPasses(OpPassManager &passManager) {
- // 1. Do some simple elementwise op fusion. This could be skipped,
- // but doing it reduces the surface area of ops to handle later.
- FunctionLikeNest(passManager)
- .addPass([]() {
- return IREE::Flow::createElementwiseOpFusionPass(
- ElementwiseOpFusionPassOptions{
- clEnableElementWiseFuseMultiReduction});
- })
- .addPass(IREE::Flow::createCanonicalizerPass)
- .addPass(mlir::createCSEPass)
-
- // 2. Bubble up expand_shape ops (or sink collapse_shape ops) to get
- // elementwise operations into higher dimensions for more fusion
- // opportunities.
- .addPass(IREE::Flow::createBubbleUpExpandShapesPass)
- .addPass(IREE::Flow::createCanonicalizerPass)
- .addPass(mlir::createCSEPass)
-
- // 3. Perform elementwise operation fusion again (now with higher
- // dimensionality).
- .addPass([]() {
- return IREE::Flow::createElementwiseOpFusionPass(
- ElementwiseOpFusionPassOptions{
- clEnableElementWiseFuseMultiReduction});
- })
- .addPass(IREE::Flow::createCanonicalizerPass)
- .addPass(mlir::createCSEPass)
-
- // 4. After elementwise operation fusion, sink reshapes that block
- // producer-consumer fusion.
- .addPass(IREE::Flow::createSinkReshapesPass)
- .addPass(IREE::Flow::createCanonicalizerPass)
- .addPass(mlir::createCSEPass);
-
- if (clEnableFuseHorizontalContractions) {
- FunctionLikeNest(passManager)
- .addPass(createFuseHorizontalContractionsPass)
- .addPass(mlir::createCanonicalizerPass)
- .addPass(mlir::createCSEPass);
- }
-
- FunctionLikeNest(passManager)
- // 5. After all the reshape propagations, fuse elementwise operations
- // even if the producer has multiple uses.
- .addPass(IREE::Flow::createFuseMultiUseElementwiseProducerPass)
-
- // 6. Some more "post elementwise fusion passes".
- // a. Detensorize.
- // TODO: This is probably not in the right place.
- .addPredicatedPass(clDetensoring,
- [&]() { return mlir::createLinalgDetensorizePass(); })
- .addPass(IREE::Flow::createCanonicalizerPass)
- .addPass(mlir::createCSEPass)
-
- // b. For ops with multiple reduction dimensions, collapse the
- // reduction dimension.
- // TODO: This pass is only needed until all backends can handle
- // multiple reduction dimensions.
- .addPredicatedPass(clCollapseReductionDims,
- IREE::Flow::createCollapseReductionDimensionsPass)
-
- // c. Split reduction operations into parallel and reduction parts.
- .addPass(IREE::Flow::createSplitReductionPass)
-
- // d. Transpose generic ops to
- // - help with dispatch region formation.
- // - move reduction iterators to be innermost.
- .addPass(IREE::Flow::createTransposeGenericOpsPass);
-}
-
-// Pipeline to first create `flow.dispatch.region` ops and then lower to
-// `flow.dispatch.workgroup` ops.
-static void addDispatchRegionCreationPasses(OpPassManager &passManager) {
- FunctionLikeNest(passManager)
- // Only want to use the transform dialect for some dispatch regions and
- // let FormDispatchRegions handle the rest. This only moves the root
- // compute op into the dispatch region, so that we can run additional
- // transformations afterwards with a simple region and without bothering
- // producers.
- .addPredicatedPass(
- !clDispatchTransformFileName.empty(),
- [&]() {
- DispatchWithTransformDialectPassOptions options;
- options.transformSpecPath = clDispatchTransformFileName;
- return createDispatchWithTransformDialectPass(options);
- })
- // Create dispatches for scalar operations as roots
- .addPass(IREE::Flow::createFormScalarDispatchesPass)
- // Create `flow.dispatch.region` centered around a root and fuse with
- // producers and consumers.
- .addPass([&]() {
- return IREE::Flow::createFormDispatchRegionsPass(
- FormDispatchRegionsPassOptions{
- clEnableAggressiveFusion,
- clEnableFusePaddingIntoLinalgConsumerOps,
- clEnableFusePaddingIntoLinalgProducerOps});
- })
- // Clone all producers into the dispatch region to prepare for being
- // isolated from above. This enables running additional transformations
- // afterwards that would need the full dispatch content but don't want to
- // handle explicit captures materialized as dispatch workgroup operands
- // and block arguments.
- .addPass(IREE::Flow::createCloneProducersIntoDispatchRegionsPass);
- // Experimental data tiling path. The intent of this path is to set encodings
- // after fusion decisions have already been made, so encodings can be
- // separated from compiler fusion decisions.
- if (clEnableDataTiling) {
- SetEncodingPassOptions options{clPadFactor};
- FunctionLikeNest(passManager)
- // Set encodings on all eligible ops. All ops should be in compiler
- // formed dispatch regions, so encodings will be placed inside of the
- // dispatch regions with the data-tiled op.
- .addPass([&]() { return createSetEncodingPass(options); })
- // SetEncodingOps should not be in the same dispatch as the data-tiled
- // op, so hoist them out of their current dispatch regions. Also, bubble
- // SetEncodingOps through special operations like bit-extending ops and
- // broadcasting ops.
- .addPass(IREE::Flow::createHoistEncodingOpsPass);
- }
- FunctionLikeNest(passManager)
- // Collapse dimensions of linalg Ops.
- .addPass(IREE::Flow::createCollapseDimensionsPass)
- // Convert dispatch regions into dispatch workgroups by capturing values
- // and making the new workgroups isolated from above.
- .addPass(IREE::Flow::createConvertDispatchRegionsToWorkgroupsPass)
- // Convert tensor operations to flow.tensor ops.
- // - Convert extract/insert slice to flow update ops when the tensor op
- // acts as a contiguous view of the tensor
- // - Apply tensor -> flow patterns
- .addPass(IREE::Flow::createConvertTensorToFlowPass)
- .addPass(IREE::Flow::createCanonicalizerPass)
- /// Creates the workgroup count region where the materialized computation
- /// is derived as a program slice of the body of the dispatch. This method
- /// - Computes the `workload` to use for the `workgroupsOp`, which are
- /// derived from the values captured by the `workgroupsOp`.
- /// - Populates the workgroup count region for this with the placeholder
- /// op `flow.dispatch.workgroups_count_from_body_slice`. This op is
- /// resolved in the backends into the actual workgroup count
- /// computation.
- /// - To correlate back to the captured workload, adds
- /// `flow.dispatch.workload.ordinal` ops
- /// to map the captured operands to their positions in the workload list.
- .addPass(IREE::Flow::createMaterializeDefaultWorkgroupCountRegionPass);
-}
-
-// Apply preprocessing and form dispatch regions
-void addDispatchRegionCreationPasses(OpPassManager &passManager,
- const TransformOptions &transformOptions) {
- FunctionLikeNest(passManager)
- // Preprocess the input to a form more amenable for fusion.
- .addPass(IREE::Flow::createFusionPreprocessingPass)
- .addPass(IREE::Flow::createCanonicalizerPass)
- .addPass(mlir::createCSEPass);
-
- addDispatchRegionCreationPreprocessingPasses(passManager);
- addDispatchRegionCreationPasses(passManager);
-}
-
void buildFlowTransformPassPipeline(OpPassManager &passManager,
const TransformOptions &transformOptions) {
// Start of Flow pipeline, verify input legality.
passManager.addPass(IREE::Flow::createVerifyInputLegalityPass());
- // Inject tensor tracing early as we need to have the tracers in the IR
- // prior to dispatch region formation where we may lose access to them.
- FunctionLikeNest(passManager)
- .addPass(IREE::Flow::createInjectTensorTracingPass);
-
- // Transform pad operations into linalg.fill + tensor.insert_slice.
- // This is a WAR for not having native pad handling.
- if (!clEnablePadHandling && !clEnableFusePaddingIntoLinalgProducerOps) {
- passManager.addPass(IREE::Flow::createTensorPadToTensorInsertSlicePass(
- TensorPadToTensorInsertSlicePassOptions{
- /*skipSingleLinalgOpUses=*/
- clEnableFusePaddingIntoLinalgConsumerOps}));
- }
-
- {
- // We run these under a fixed-point iteration such that we can perform
- // inter-procedural and intra-procedural optimization and canonicalization
- // as separately verifiable/reusable passes. IPO will fold duplicate
- // arguments/results and inline constants to allow the local optimizations
- // to work more effectively.
- OpPassManager ipoPipeline(mlir::ModuleOp::getOperationName());
-
- // IPO and other cleanups.
- addCleanupPatterns(ipoPipeline);
-
- // Run fixed-point iteration on the IPO pipeline.
- passManager.addPass(
- IREE::Util::createFixedPointIteratorPass(std::move(ipoPipeline)));
- }
-
- addDispatchRegionCreationPasses(passManager, transformOptions);
-
FunctionLikeNest(passManager)
.addPass(IREE::Flow::createCaptureDynamicDimsPass)
.addPass(IREE::Flow::createCanonicalizerPass)
@@ -504,21 +241,6 @@
[](OpPassManager &passManager, const TransformOptions &transformOptions) {
buildFlowTransformPassPipeline(passManager, transformOptions);
});
-
- PassPipelineRegistration<> flowDispatchRegionFormationPreprocessingPipeline(
- "iree-flow-dispatch-region-formation-preprocessing-pipeline",
- "Flag used to run preprocessing passes that run before dispatch "
- "region formation. Used only for testing",
- [](OpPassManager &passManager) {
- addDispatchRegionCreationPreprocessingPasses(passManager);
- });
-
- PassPipelineRegistration<> flowDispatchRegionCreationPipeline(
- "iree-flow-dispatch-region-creation-pipeline",
- "Flag used to run passes that form dispatch regions",
- [](OpPassManager &passManager) {
- addDispatchRegionCreationPasses(passManager);
- });
}
namespace {
diff --git a/compiler/src/iree/compiler/Dialect/Flow/Transforms/Passes.td b/compiler/src/iree/compiler/Dialect/Flow/Transforms/Passes.td
index fabde78..1328cce 100644
--- a/compiler/src/iree/compiler/Dialect/Flow/Transforms/Passes.td
+++ b/compiler/src/iree/compiler/Dialect/Flow/Transforms/Passes.td
@@ -9,297 +9,6 @@
include "mlir/Pass/PassBase.td"
-// File organization:
-// Groups passes that are related under one banner //===....===//. For example
-// the dispatch region creation preprocessing passes and dispatch region
-// formation passes are a couple of such groups. For any new pass add it to the
-// relevant group and keep them alphabetical within a group.
-
-//===---------------------------------------------------------------------===//
-// Dispatch region creation preprocessing passes :
-// Passes that transform the program before forming dispatches, like
-// - Elementwise operation fusion
-// - Reshape propagation passes
-//===---------------------------------------------------------------------===//
-
-def BubbleUpExpandShapesPass :
- Pass<"iree-flow-bubble-up-expand-shapes"> {
- let summary = "Propagate expand_shapes up the program (and collapse_shapes down).";
- let dependentDialects = [
- "mlir::affine::AffineDialect",
- ];
-}
-
-def CollapseReductionDimensionsPass :
- Pass<"iree-flow-collapse-reduction-dimensions", ""> {
- let summary = "Collapse reduction dimensions when possible.";
- let dependentDialects = [
- "mlir::linalg::LinalgDialect",
- ];
-}
-
-def ElementwiseOpFusionPass :
- Pass<"iree-flow-elementwise-op-fusion", ""> {
- let summary = "Fuse elementwise operations.";
- let options = [
- Option<"fuseMultiReduction", "fuse-multi-reduction", "bool",
- /*default=*/"true", "Fuse ops that have multiple reduction iterators">
- ];
- let dependentDialects = [
- "mlir::affine::AffineDialect",
- ];
-}
-
-def FoldUnitExtentDimsPass :
- Pass<"iree-flow-fold-unit-extent-dims", "mlir::ModuleOp"> {
- let summary = "Fold unit extent dimension of operations.";
- let description = [{
- Imports upstream patterns to fold unit extent dims but with IREE control.
- }];
- let dependentDialects = [
- "mlir::affine::AffineDialect",
- "mlir::arith::ArithDialect",
- "mlir::linalg::LinalgDialect",
- "mlir::tensor::TensorDialect",
- ];
-}
-
-def FuseHorizontalContractionsPass:
- InterfacePass<"iree-flow-fuse-horizontal-contractions", "mlir::FunctionOpInterface"> {
- let summary = "Fuses horizontal contraction ops without fusions";
- let dependentDialects = [
- "mlir::arith::ArithDialect",
- "mlir::tensor::TensorDialect",
- ];
- let options = [
- Option<"fusionLimit", "fusion-limit", "int",
- /*default=*/"3", "Maximum number of contractions fused into one">
- ];
- let statistics = [
- Statistic<"numFusionGroups", "num-fusion-groups", "Number of fusion groups found">,
- Statistic<"numSize2FusionGroups", "num-size-2-groups", "Number of fusion groups of size 2">,
- Statistic<"numSize3FusionGroups", "num-size-3-groups", "Number of fusion groups of size 3">
- ];
-}
-
-def FuseMultiUseElementwiseProducerPass :
- InterfacePass<"iree-flow-fuse-multi-use-elementwise-producer",
- "mlir::FunctionOpInterface"> {
- let summary = "Fuse elementwise linalg operations on tensors when producers have multiple uses.";
- let options = [
- Option<"numIterations", "num-iterations", "unsigned",
- /*default=*/"2", "Number of iterations to fuse multiuse ops">
- ];
- let dependentDialects = [
- "mlir::affine::AffineDialect",
- "mlir::arith::ArithDialect",
- "mlir::linalg::LinalgDialect",
- "mlir::math::MathDialect",
- ];
-}
-
-def FusionPreprocessingPass :
- Pass<"iree-flow-fusion-preprocessing", ""> {
- let summary = "Run useful preprocessing patterns that help with fusion.";
- let dependentDialects = [
- "mlir::affine::AffineDialect",
- ];
-}
-
-def SinkReshapesPass :
- Pass<"iree-flow-sink-reshapes", ""> {
- let summary = "Sink reshapes to allow for compute op -> consumer fusion.";
- let dependentDialects = [
- "mlir::affine::AffineDialect",
- "mlir::arith::ArithDialect",
- ];
-}
-
-def SplitReductionPass :
- Pass<"iree-flow-split-reduction-ops", ""> {
- let summary = "Split reduction dimension to increase parallelism.";
- let dependentDialects = [
- "mlir::linalg::LinalgDialect",
- ];
-}
-
-def TensorPadToTensorInsertSlicePass :
- Pass<"iree-flow-tensor-pad-to-tensor-insert-slice", ""> {
- let summary = "Convert tensor.pad into linalg.fill + tensor.insert_slice.";
- let options = [
- Option<"skipSingleLinalgOpUses", "skip-one-linalg-use-case", "bool",
- /*default=*/"false",
- "Skip the op that has only one use which is used"
- "by a Linalg op">,
- ];
- let dependentDialects = [
- "mlir::arith::ArithDialect",
- "mlir::linalg::LinalgDialect",
- "mlir::math::MathDialect",
- "mlir::memref::MemRefDialect",
- ];
-}
-
-def TransposeGenericOpsPass :
- Pass<"iree-flow-transpose-generic-ops", ""> {
- let summary = "Transpose generic op loops.";
- let dependentDialects = [
- "mlir::linalg::LinalgDialect",
- ];
-}
-
-//===---------------------------------------------------------------------===//
-// Dispatch region creation passes.
-//===---------------------------------------------------------------------===//
-
-def CloneProducersIntoDispatchRegionsPass :
- InterfacePass<"iree-flow-clone-producers-into-dispatch-regions", "mlir::FunctionOpInterface"> {
- let summary = "Clone producers into dispatch regions to be isolated above.";
- let description = [{
- Pass to clone into dispatch regions producers of values used in the dispatch
- regions but defined in the above. This prepares the dispatch regions for
- converting to dispatch workgroups with explicit captures.
- }];
-}
-
-def CollapseDimensionsPass :
- InterfacePass<"iree-flow-collapse-dimensions", "mlir::FunctionOpInterface"> {
- let summary = "Collapse dimensions of Linalg Ops on tensor ops.";
- let options = [
- Option<"maxIterations", "max-iterations", "int",
- /*default=*/"10",
- "Maximum number of iterations to wait for collapse dimensions to converge"
->,
- ];
- let description = [{
- Collapse dimensions of Linalg Ops on tensor ops inside dispatch.region ops
- and hoist the reshaping operations out of the dispatch.
- }];
-}
-
-def ConvertDispatchRegionsToWorkgroupsPass :
- InterfacePass<"iree-flow-convert-dispatch-regions-to-workgroups", "mlir::FunctionOpInterface"> {
- let summary = "Convert dispatch regions to dispatch workgroups.";
- let description = [{
- Pass to convert dispatch regions to dispatch workgroups. This pass is
- intended to be used after dispatch regions have been formed.
- }];
- let dependentDialects = [
- "mlir::affine::AffineDialect",
- "mlir::linalg::LinalgDialect",
- "mlir::scf::SCFDialect",
- "mlir::tensor::TensorDialect",
- "IREE::Flow::FlowDialect",
- ];
- let statistics = [
- Statistic<"numDispatches", "num-dispatches", "Number of dispatches created">
- ];
-}
-
-def ConvertTensorToFlowPass :
- InterfacePass<"iree-flow-convert-tensor-to-flow", "mlir::FunctionOpInterface"> {
- let summary = "Convert tensor operations to flow";
- let description = [{
- Pass to convert tensor operations to flow.tensor.* operations.
- }];
- let dependentDialects = [
- "mlir::affine::AffineDialect",
- "mlir::arith::ArithDialect",
- "mlir::linalg::LinalgDialect",
- "mlir::tensor::TensorDialect",
- "IREE::Flow::FlowDialect",
- ];
- let statistics = [
- Statistic<"numSlowCopyDispatches", "num-slow-copy-dispatches",
- "Number of slow copy dispatches (for handling slices) created">
- ];
-}
-
-def DispatchWithTransformDialectPass : Pass<"iree-flow-dispatch-with-transform-dialect"> {
- let summary = "Dispatch Linalg operations on tensors by using the transform dialect interpreter.";
- let description = [{
- Pass to perform dispatch of Linalg on tensor ops by using the transform
- dialect. Dispatch regions are created as specified by the transform module
- that is parsed from `transformSpecPath`.
-
- TODO: Drop this pass in favor of the one upstream. The one upstream requires
- separate loading of the module and thus isn't suited for single-use
- transform scripts.
- }];
- let dependentDialects = [
- "mlir::affine::AffineDialect",
- "mlir::arith::ArithDialect",
- "mlir::linalg::LinalgDialect",
- "mlir::pdl::PDLDialect",
- "mlir::pdl_interp::PDLInterpDialect",
- "mlir::scf::SCFDialect",
- "mlir::tensor::TensorDialect",
- "mlir::transform::TransformDialect",
- "IREE::Flow::FlowDialect",
- "IREE::LinalgExt::IREELinalgExtDialect",
- ];
- let options = [
- Option<"disableExpensiveChecks", "disable-expensive-checks", "bool",
- "false",
- "Disable expensive checks in the interpreter for a faster run.">,
- Option<"transformSpecPath", "transform-spec-path", "std::string",
- /*default=*/"", "File path to the transform spec to use.">,
- ];
-}
-
-def FormDispatchRegionsPass :
- InterfacePass<"iree-flow-form-dispatch-regions", "mlir::FunctionOpInterface"> {
- let summary = "Form Dispatch Region Ops from Linalg operations on tensors to form dispatch.regions.";
- let options = [
- Option<"aggressiveFusion", "aggressive-fusion", "bool",
- /*default=*/"false", "Aggressive mode enabling fusions not ready for all backends">,
- Option<"fusePadWithConsumers", "fuse-pad-with-consumers", "bool",
- /*default=*/"false", "Enable fusing pad with consumer">,
- Option<"fusePadWithProducers", "fuse-pad-with-producers", "bool",
- /*default=*/"false", "Enable fusion of pad with producers">
- ];
- let description = [{
- Pass to form dispatch.region ops from Linalg on tensor ops. A dispatch region
- is created for each tiled loop nest. This pass only moves the root compute op
- into the dispatch region, allowing producers to be outside.
- }];
- let dependentDialects = [
- "mlir::affine::AffineDialect",
- "mlir::linalg::LinalgDialect",
- "mlir::scf::SCFDialect",
- "mlir::tensor::TensorDialect",
- "IREE::Flow::FlowDialect",
- "IREE::LinalgExt::IREELinalgExtDialect",
- ];
-}
-
-def FormScalarDispatchesPass :
- InterfacePass<"iree-flow-form-scalar-dispatches", "mlir::FunctionOpInterface"> {
- let summary = "Form Dispatch Regions for scalar computations.";
- let dependentDialects = [
- "mlir::affine::AffineDialect",
- "mlir::linalg::LinalgDialect",
- "mlir::tensor::TensorDialect",
- "IREE::Flow::FlowDialect",
- ];
-}
-
-def MaterializeDefaultWorkgroupCountRegionPass:
- InterfacePass<"iree-flow-materialize-default-workgroup-count-region",
- "mlir::FunctionOpInterface"> {
- let summary = "Canonicalize dispatch workgroups ops.";
- let description = [{
- Apply dispatch workgroups canonicalization patterns.
- }];
- let dependentDialects = [
- "mlir::affine::AffineDialect",
- "mlir::arith::ArithDialect",
- "mlir::linalg::LinalgDialect",
- "mlir::scf::SCFDialect",
- "IREE::Flow::FlowDialect",
- ];
-}
-
//===---------------------------------------------------------------------===//
// General Flow passes
//===---------------------------------------------------------------------===//
@@ -436,17 +145,6 @@
];
}
-
-def HoistEncodingOpsPass :
- InterfacePass<"iree-flow-hoist-encoding-ops", "mlir::FunctionOpInterface"> {
- let summary = "Hoists tensor encoding ops out of flow dispatch regions.";
- let dependentDialects = [
- "mlir::linalg::LinalgDialect",
- "IREE::Flow::FlowDialect",
- "IREE::Encoding::IREEEncodingDialect",
- ];
-}
-
def InitializeEmptyTensorsPass :
Pass<"iree-flow-initialize-empty-tensors", ""> {
let summary = "Initialize empty tensors.";
@@ -547,20 +245,6 @@
];
}
-def SetEncodingPass :
- InterfacePass<"iree-flow-set-encoding", "mlir::FunctionOpInterface"> {
- let summary = "Introduces tensor encoding for flow dispatch regions.";
- let dependentDialects = [
- "mlir::linalg::LinalgDialect",
- "IREE::Flow::FlowDialect",
- "IREE::Encoding::IREEEncodingDialect",
- ];
- let options = [
- Option<"padFactor", "pad-factor", "int64_t", /*default=*/"32",
- "provides padding size hints that will be attached to encodings.">,
- ];
-}
-
def TopLevelSCFToCFGPass :
InterfacePass<"iree-top-level-scf-to-cfg", "mlir::FunctionOpInterface"> {
let summary = "Converts non-nested SCF constructs to CFG (not traversing into opaque operations).";
diff --git a/compiler/src/iree/compiler/Dialect/Flow/Transforms/RegionOpUtils.cpp b/compiler/src/iree/compiler/Dialect/Flow/Transforms/RegionOpUtils.cpp
index 4b0e5fe..175f822 100644
--- a/compiler/src/iree/compiler/Dialect/Flow/Transforms/RegionOpUtils.cpp
+++ b/compiler/src/iree/compiler/Dialect/Flow/Transforms/RegionOpUtils.cpp
@@ -8,7 +8,6 @@
#include "iree/compiler/Dialect/Encoding/IR/EncodingOps.h"
#include "iree/compiler/Dialect/Flow/IR/FlowOps.h"
-#include "iree/compiler/Dialect/Flow/Transforms/FormDispatchRegions.h"
#include "iree/compiler/Dialect/LinalgExt/Utils/Utils.h"
#include "iree/compiler/Dialect/Util/IR/UtilTypes.h"
#include "llvm/ADT/STLExtras.h"
diff --git a/compiler/src/iree/compiler/Dialect/Flow/Transforms/test/BUILD.bazel b/compiler/src/iree/compiler/Dialect/Flow/Transforms/test/BUILD.bazel
index 1c16925..3ce7528 100644
--- a/compiler/src/iree/compiler/Dialect/Flow/Transforms/test/BUILD.bazel
+++ b/compiler/src/iree/compiler/Dialect/Flow/Transforms/test/BUILD.bazel
@@ -17,32 +17,12 @@
srcs = enforce_glob(
[
"annotate_dispatches.mlir",
- "attention_fuse_by_expansion.mlir",
"capture_dispatch_dynamic_dims.mlir",
"capture_scf_for_dynamic_dims.mlir",
"cleanup_tensor_shapes.mlir",
- "clone_producers_into_dispatch_regions.mlir",
- "collapse_dimensions.mlir",
- "collapse_linalg_generic_on_tensors.mlir",
- "collapse_reduction.mlir",
- "convert_region_to_workgroups.mlir",
"deduplicate_executables.mlir",
- "dispatch_linalg_on_tensors.mlir",
- "dispatch_linalg_on_tensors_default.mlir",
- "dispatch_linalg_on_tensors_fusion_with_transpose.mlir",
- "dispatch_linalg_transform_dialect.mlir",
- "dispatch_region_formation_preprocessing.mlir",
"export_benchmark_funcs.mlir",
"flow_canonicalize.mlir",
- "fold_unit_dims.mlir",
- "form_dispatch_regions.mlir",
- "form_dispatch_workgroups.mlir",
- "form_scalar_dispatches.mlir",
- "dispatch_linalg_ext_fusion.mlir",
- "fuse_horizontal_contractions.mlir",
- "fuse_multiuse_elementwise_producer.mlir",
- "fusion_preprocessing.mlir",
- "hoist_encoding_ops.mlir",
"initialize_empty_tensors.mlir",
"inject_dispatch_tracing.mlir",
"inject_tensor_tracing.mlir",
@@ -50,27 +30,13 @@
"outline_constants.mlir",
"outline_dispatch_externs.mlir",
"outline_dispatch_regions.mlir",
- "pad_fusion_with_consumer.mlir",
- "pad_fusion_with_producer.mlir",
"pipeline_tests.mlir",
- "set_encoding.mlir",
- "sink_reshapes.mlir",
- "split_reduction.mlir",
- "tensor_pad_to_tensor_insert_slice.mlir",
"top_level_scf_to_cfg.mlir",
- "transform_dispatch_region_formation.mlir",
- "transpose_generic_ops.mlir",
"verify_input_ir.mlir",
],
include = ["*.mlir"],
- # transform_dialect_dispatch_spec is a an MLIR file that specifies a
- # transformation, it needs to be included as data.
- exclude = [
- "transform_dialect_dispatch_spec.mlir",
- ],
),
cfg = "//compiler:lit.cfg.py",
- data = ["transform_dialect_dispatch_spec.mlir"],
tools = [
"//tools:iree-opt",
"@llvm-project//llvm:FileCheck",
diff --git a/compiler/src/iree/compiler/Dialect/Flow/Transforms/test/CMakeLists.txt b/compiler/src/iree/compiler/Dialect/Flow/Transforms/test/CMakeLists.txt
index 203a4ad..53280e7 100644
--- a/compiler/src/iree/compiler/Dialect/Flow/Transforms/test/CMakeLists.txt
+++ b/compiler/src/iree/compiler/Dialect/Flow/Transforms/test/CMakeLists.txt
@@ -15,32 +15,12 @@
lit
SRCS
"annotate_dispatches.mlir"
- "attention_fuse_by_expansion.mlir"
"capture_dispatch_dynamic_dims.mlir"
"capture_scf_for_dynamic_dims.mlir"
"cleanup_tensor_shapes.mlir"
- "clone_producers_into_dispatch_regions.mlir"
- "collapse_dimensions.mlir"
- "collapse_linalg_generic_on_tensors.mlir"
- "collapse_reduction.mlir"
- "convert_region_to_workgroups.mlir"
"deduplicate_executables.mlir"
- "dispatch_linalg_ext_fusion.mlir"
- "dispatch_linalg_on_tensors.mlir"
- "dispatch_linalg_on_tensors_default.mlir"
- "dispatch_linalg_on_tensors_fusion_with_transpose.mlir"
- "dispatch_linalg_transform_dialect.mlir"
- "dispatch_region_formation_preprocessing.mlir"
"export_benchmark_funcs.mlir"
"flow_canonicalize.mlir"
- "fold_unit_dims.mlir"
- "form_dispatch_regions.mlir"
- "form_dispatch_workgroups.mlir"
- "form_scalar_dispatches.mlir"
- "fuse_horizontal_contractions.mlir"
- "fuse_multiuse_elementwise_producer.mlir"
- "fusion_preprocessing.mlir"
- "hoist_encoding_ops.mlir"
"initialize_empty_tensors.mlir"
"inject_dispatch_tracing.mlir"
"inject_tensor_tracing.mlir"
@@ -48,22 +28,12 @@
"outline_constants.mlir"
"outline_dispatch_externs.mlir"
"outline_dispatch_regions.mlir"
- "pad_fusion_with_consumer.mlir"
- "pad_fusion_with_producer.mlir"
"pipeline_tests.mlir"
- "set_encoding.mlir"
- "sink_reshapes.mlir"
- "split_reduction.mlir"
- "tensor_pad_to_tensor_insert_slice.mlir"
"top_level_scf_to_cfg.mlir"
- "transform_dispatch_region_formation.mlir"
- "transpose_generic_ops.mlir"
"verify_input_ir.mlir"
TOOLS
FileCheck
iree-opt
- DATA
- transform_dialect_dispatch_spec.mlir
)
### BAZEL_TO_CMAKE_PRESERVES_ALL_CONTENT_BELOW_THIS_LINE ###
diff --git a/compiler/src/iree/compiler/Dialect/Flow/Transforms/test/pipeline_tests.mlir b/compiler/src/iree/compiler/Dialect/Flow/Transforms/test/pipeline_tests.mlir
index 89e6d77..8973ba5 100644
--- a/compiler/src/iree/compiler/Dialect/Flow/Transforms/test/pipeline_tests.mlir
+++ b/compiler/src/iree/compiler/Dialect/Flow/Transforms/test/pipeline_tests.mlir
@@ -1,5 +1,5 @@
// TODO(hanchung): Split the transformation pipeline tests into two mlir files.
-// RUN: iree-opt --iree-global-optimization-transformation-pipeline --iree-flow-transformation-pipeline --split-input-file %s | FileCheck %s
+// RUN: iree-opt --iree-global-optimization-transformation-pipeline --iree-dispatch-creation-pipeline --iree-flow-transformation-pipeline --split-input-file %s | FileCheck %s
#map = affine_map<(d0, d1) -> (d0)>
#map1 = affine_map<(d0, d1) -> (d1)>
diff --git a/compiler/src/iree/compiler/DispatchCreation/BUILD.bazel b/compiler/src/iree/compiler/DispatchCreation/BUILD.bazel
new file mode 100644
index 0000000..3be4817
--- /dev/null
+++ b/compiler/src/iree/compiler/DispatchCreation/BUILD.bazel
@@ -0,0 +1,129 @@
+# Copyright 2024 The IREE Authors
+#
+# Licensed under the Apache License v2.0 with LLVM Exceptions.
+# See https://llvm.org/LICENSE.txt for license information.
+# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+
+load("//build_tools/bazel:build_defs.oss.bzl", "iree_compiler_cc_library", "iree_gentbl_cc_library")
+
+package(
+ default_visibility = ["//visibility:public"],
+ features = ["layering_check"],
+ licenses = ["notice"], # Apache 2.0
+)
+
+iree_compiler_cc_library(
+ name = "DispatchCreation",
+ srcs = [
+ "BubbleUpExpandShapes.cpp",
+ "CloneProducersIntoDispatchRegions.cpp",
+ "CollapseDimensions.cpp",
+ "CollapseReductionDimensions.cpp",
+ "ConvertDispatchRegionsToWorkgroups.cpp",
+ "ConvertTensorToFlow.cpp",
+ "DispatchWithTransformDialect.cpp",
+ "ElementwiseOpFusion.cpp",
+ "FoldUnitExtentDims.cpp",
+ "FormDispatchRegions.cpp",
+ "FormScalarDispatches.cpp",
+ "FuseHorizontalContractions.cpp",
+ "FuseMultiUseElementwiseProducer.cpp",
+ "FusionPreprocessing.cpp",
+ "FusionUtils.cpp",
+ "HoistEncodingOps.cpp",
+ "MaterializeDefaultWorkgroupCountRegion.cpp",
+ "Passes.cpp",
+ "SetEncoding.cpp",
+ "SinkReshapes.cpp",
+ "SplitReduction.cpp",
+ "TensorPadToTensorInsertSlice.cpp",
+ "TransposeGenericOps.cpp",
+ ],
+ hdrs = [
+ "FusionUtils.h",
+ "Passes.h",
+ ],
+ deps = [
+ ":PassHeaders",
+ ":PassesIncGen",
+ "//compiler/src/iree/compiler/Codegen/Dialect/Codegen/IR:IREECodegenDialect",
+ "//compiler/src/iree/compiler/Dialect/Encoding/IR",
+ "//compiler/src/iree/compiler/Dialect/Flow/Conversion/TensorToFlow",
+ "//compiler/src/iree/compiler/Dialect/Flow/IR",
+ "//compiler/src/iree/compiler/Dialect/Flow/Transforms",
+ "//compiler/src/iree/compiler/Dialect/LinalgExt/IR",
+ "//compiler/src/iree/compiler/Dialect/LinalgExt/Transforms",
+ "//compiler/src/iree/compiler/Dialect/LinalgExt/Utils",
+ "//compiler/src/iree/compiler/Dialect/Stream/IR",
+ "//compiler/src/iree/compiler/Dialect/Util/Analysis",
+ "//compiler/src/iree/compiler/Dialect/Util/Analysis/Attributes",
+ "//compiler/src/iree/compiler/Dialect/Util/Analysis/DFX",
+ "//compiler/src/iree/compiler/Dialect/Util/IR",
+ "//compiler/src/iree/compiler/Dialect/Util/Transforms",
+ "//compiler/src/iree/compiler/Pipelines:Options",
+ "//compiler/src/iree/compiler/Utils",
+ "//llvm-external-projects/iree-dialects:IREEDialectsTransforms",
+ "//llvm-external-projects/iree-dialects:IREELinalgTransformDialect",
+ "@llvm-project//llvm:Support",
+ "@llvm-project//mlir:AffineDialect",
+ "@llvm-project//mlir:AffineUtils",
+ "@llvm-project//mlir:Analysis",
+ "@llvm-project//mlir:ArithDialect",
+ "@llvm-project//mlir:ArithUtils",
+ "@llvm-project//mlir:ControlFlowDialect",
+ "@llvm-project//mlir:DestinationStyleOpInterface",
+ "@llvm-project//mlir:DialectUtils",
+ "@llvm-project//mlir:FuncDialect",
+ "@llvm-project//mlir:FunctionInterfaces",
+ "@llvm-project//mlir:IR",
+ "@llvm-project//mlir:LinalgDialect",
+ "@llvm-project//mlir:LinalgTransforms",
+ "@llvm-project//mlir:LinalgUtils",
+ "@llvm-project//mlir:MathDialect",
+ "@llvm-project//mlir:MemRefDialect",
+ "@llvm-project//mlir:MemRefTransforms",
+ "@llvm-project//mlir:PDLDialect",
+ "@llvm-project//mlir:PDLInterpDialect",
+ "@llvm-project//mlir:Pass",
+ "@llvm-project//mlir:SCFDialect",
+ "@llvm-project//mlir:SCFTransforms",
+ "@llvm-project//mlir:Support",
+ "@llvm-project//mlir:TensorDialect",
+ "@llvm-project//mlir:TensorTransforms",
+ "@llvm-project//mlir:TensorUtils",
+ "@llvm-project//mlir:TilingInterface",
+ "@llvm-project//mlir:TransformDialect",
+ "@llvm-project//mlir:TransformDialectTransforms",
+ "@llvm-project//mlir:TransformUtils",
+ "@llvm-project//mlir:Transforms",
+ ],
+)
+
+iree_gentbl_cc_library(
+ name = "PassesIncGen",
+ tbl_outs = [
+ (
+ ["--gen-pass-decls"],
+ "Passes.h.inc",
+ ),
+ ],
+ tblgen = "@llvm-project//mlir:mlir-tblgen",
+ td_file = "Passes.td",
+ deps = [
+ "@llvm-project//mlir:PassBaseTdFiles",
+ ],
+)
+
+iree_compiler_cc_library(
+ name = "PassHeaders",
+ hdrs = [
+ "Passes.h",
+ "Passes.h.inc",
+ ],
+ deps = [
+ ":PassesIncGen",
+ "@llvm-project//mlir:Pass",
+ "@llvm-project//mlir:TransformUtils",
+ "@llvm-project//mlir:Transforms",
+ ],
+)
diff --git a/compiler/src/iree/compiler/Dialect/Flow/Transforms/BubbleUpExpandShapes.cpp b/compiler/src/iree/compiler/DispatchCreation/BubbleUpExpandShapes.cpp
similarity index 85%
rename from compiler/src/iree/compiler/Dialect/Flow/Transforms/BubbleUpExpandShapes.cpp
rename to compiler/src/iree/compiler/DispatchCreation/BubbleUpExpandShapes.cpp
index 422ace8..79ae8d3 100644
--- a/compiler/src/iree/compiler/Dialect/Flow/Transforms/BubbleUpExpandShapes.cpp
+++ b/compiler/src/iree/compiler/DispatchCreation/BubbleUpExpandShapes.cpp
@@ -12,11 +12,11 @@
//
//===----------------------------------------------------------------------===//
-#include "iree/compiler/Dialect/Flow/Transforms/FusionUtils.h"
-#include "iree/compiler/Dialect/Flow/Transforms/Passes.h"
#include "iree/compiler/Dialect/Flow/Transforms/RegionOpUtils.h"
#include "iree/compiler/Dialect/LinalgExt/Transforms/Transforms.h"
#include "iree/compiler/Dialect/LinalgExt/Utils/Utils.h"
+#include "iree/compiler/DispatchCreation/FusionUtils.h"
+#include "iree/compiler/DispatchCreation/Passes.h"
#include "llvm/Support/Debug.h"
#include "mlir/Dialect/Affine/IR/AffineOps.h"
#include "mlir/Dialect/Linalg/Transforms/Transforms.h"
@@ -24,20 +24,17 @@
#include "mlir/Dialect/Tensor/Transforms/Transforms.h"
#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
-#define DEBUG_TYPE "iree-flow-bubble-up-expand-shapes"
+#define DEBUG_TYPE "iree-dispatch-creation-bubble-up-expand-shapes"
-namespace mlir::iree_compiler::IREE::Flow {
+namespace mlir::iree_compiler::DispatchCreation {
#define GEN_PASS_DEF_BUBBLEUPEXPANDSHAPESPASS
-#include "iree/compiler/Dialect/Flow/Transforms/Passes.h.inc"
+#include "iree/compiler/DispatchCreation/Passes.h.inc"
namespace {
-class BubbleUpExpandShapesPass
+struct BubbleUpExpandShapesPass final
: public impl::BubbleUpExpandShapesPassBase<BubbleUpExpandShapesPass> {
-public:
- using Base::Base;
-
void runOnOperation() override;
};
@@ -51,12 +48,12 @@
[](OpOperand *fusedOperand) {
Operation *producer = fusedOperand->get().getDefiningOp();
Operation *consumer = fusedOperand->getOwner();
- if (!isNonNullAndOutsideDispatch({producer, consumer})) {
+ if (!IREE::Flow::isNonNullAndOutsideDispatch({producer, consumer})) {
return false;
}
// Do not fuse by expand if consumer is dequant.
- if (LinalgExt::isBitExtendOp(consumer)) {
+ if (IREE::LinalgExt::isBitExtendOp(consumer)) {
return false;
}
@@ -84,7 +81,7 @@
};
linalg::populateFoldReshapeOpsByExpansionPatterns(bubbleExpandShapePatterns,
bubbleUpExpansionControlFn);
- LinalgExt::populateFoldReshapeOpsByExpansionPatterns(
+ IREE::LinalgExt::populateFoldReshapeOpsByExpansionPatterns(
bubbleExpandShapePatterns, bubbleUpExpansionControlFn);
// Add patterns to do some additional cleanup (on top of canonicalizations
@@ -101,4 +98,4 @@
}
}
-} // namespace mlir::iree_compiler::IREE::Flow
+} // namespace mlir::iree_compiler::DispatchCreation
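
For reference, a minimal sketch of the rewrite this pass performs (shapes, maps, and function names below are illustrative assumptions, not taken from this change): a tensor.expand_shape that consumes an elementwise linalg.generic is bubbled above it, so reshapes drift toward function boundaries instead of blocking dispatch fusion.

  #map = affine_map<(d0) -> (d0)>
  func.func @bubble_sketch(%arg0: tensor<16xf32>) -> tensor<4x4xf32> {
    %init = tensor.empty() : tensor<16xf32>
    %0 = linalg.generic {indexing_maps = [#map, #map], iterator_types = ["parallel"]}
        ins(%arg0 : tensor<16xf32>) outs(%init : tensor<16xf32>) {
    ^bb0(%in: f32, %out: f32):
      %1 = arith.addf %in, %in : f32
      linalg.yield %1 : f32
    } -> tensor<16xf32>
    // After the pass, this expand_shape is applied to %arg0 instead and the
    // generic computes directly on tensor<4x4xf32>.
    %2 = tensor.expand_shape %0 [[0, 1]] output_shape [4, 4]
        : tensor<16xf32> into tensor<4x4xf32>
    return %2 : tensor<4x4xf32>
  }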
diff --git a/compiler/src/iree/compiler/DispatchCreation/CMakeLists.txt b/compiler/src/iree/compiler/DispatchCreation/CMakeLists.txt
new file mode 100644
index 0000000..d8cbd3b
--- /dev/null
+++ b/compiler/src/iree/compiler/DispatchCreation/CMakeLists.txt
@@ -0,0 +1,121 @@
+################################################################################
+# Autogenerated by build_tools/bazel_to_cmake/bazel_to_cmake.py from #
+# compiler/src/iree/compiler/DispatchCreation/BUILD.bazel #
+# #
+# Use iree_cmake_extra_content from iree/build_defs.oss.bzl to add arbitrary #
+# CMake-only content. #
+# #
+# To disable autogeneration for this file entirely, delete this header. #
+################################################################################
+
+iree_add_all_subdirs()
+
+iree_cc_library(
+ NAME
+ DispatchCreation
+ HDRS
+ "FusionUtils.h"
+ "Passes.h"
+ SRCS
+ "BubbleUpExpandShapes.cpp"
+ "CloneProducersIntoDispatchRegions.cpp"
+ "CollapseDimensions.cpp"
+ "CollapseReductionDimensions.cpp"
+ "ConvertDispatchRegionsToWorkgroups.cpp"
+ "ConvertTensorToFlow.cpp"
+ "DispatchWithTransformDialect.cpp"
+ "ElementwiseOpFusion.cpp"
+ "FoldUnitExtentDims.cpp"
+ "FormDispatchRegions.cpp"
+ "FormScalarDispatches.cpp"
+ "FuseHorizontalContractions.cpp"
+ "FuseMultiUseElementwiseProducer.cpp"
+ "FusionPreprocessing.cpp"
+ "FusionUtils.cpp"
+ "HoistEncodingOps.cpp"
+ "MaterializeDefaultWorkgroupCountRegion.cpp"
+ "Passes.cpp"
+ "SetEncoding.cpp"
+ "SinkReshapes.cpp"
+ "SplitReduction.cpp"
+ "TensorPadToTensorInsertSlice.cpp"
+ "TransposeGenericOps.cpp"
+ DEPS
+ ::PassHeaders
+ ::PassesIncGen
+ IREEDialectsTransforms
+ IREELinalgTransformDialect
+ LLVMSupport
+ MLIRAffineDialect
+ MLIRAffineUtils
+ MLIRAnalysis
+ MLIRArithDialect
+ MLIRArithUtils
+ MLIRControlFlowDialect
+ MLIRDestinationStyleOpInterface
+ MLIRFuncDialect
+ MLIRFunctionInterfaces
+ MLIRIR
+ MLIRLinalgDialect
+ MLIRLinalgTransforms
+ MLIRLinalgUtils
+ MLIRMathDialect
+ MLIRMemRefDialect
+ MLIRMemRefTransforms
+ MLIRPDLDialect
+ MLIRPDLInterpDialect
+ MLIRPass
+ MLIRSCFDialect
+ MLIRSCFTransforms
+ MLIRSupport
+ MLIRTensorDialect
+ MLIRTensorTransforms
+ MLIRTensorUtils
+ MLIRTilingInterface
+ MLIRTransformDialect
+ MLIRTransformDialectTransforms
+ MLIRTransformUtils
+ MLIRTransforms
+ iree::compiler::Codegen::Dialect::Codegen::IR::IREECodegenDialect
+ iree::compiler::Dialect::Encoding::IR
+ iree::compiler::Dialect::Flow::Conversion::TensorToFlow
+ iree::compiler::Dialect::Flow::IR
+ iree::compiler::Dialect::Flow::Transforms
+ iree::compiler::Dialect::LinalgExt::IR
+ iree::compiler::Dialect::LinalgExt::Transforms
+ iree::compiler::Dialect::LinalgExt::Utils
+ iree::compiler::Dialect::Stream::IR
+ iree::compiler::Dialect::Util::Analysis
+ iree::compiler::Dialect::Util::Analysis::Attributes
+ iree::compiler::Dialect::Util::Analysis::DFX
+ iree::compiler::Dialect::Util::IR
+ iree::compiler::Dialect::Util::Transforms
+ iree::compiler::Pipelines::Options
+ iree::compiler::Utils
+ PUBLIC
+)
+
+iree_tablegen_library(
+ NAME
+ PassesIncGen
+ TD_FILE
+ "Passes.td"
+ OUTS
+ --gen-pass-decls Passes.h.inc
+)
+
+iree_cc_library(
+ NAME
+ PassHeaders
+ HDRS
+ "Passes.h"
+ "Passes.h.inc"
+ DEPS
+ ::PassesIncGen
+ MLIRPass
+ MLIRTransformUtils
+ MLIRTransforms
+ PUBLIC
+)
+
+### BAZEL_TO_CMAKE_PRESERVES_ALL_CONTENT_BELOW_THIS_LINE ###
diff --git a/compiler/src/iree/compiler/Dialect/Flow/Transforms/CloneProducersIntoDispatchRegions.cpp b/compiler/src/iree/compiler/DispatchCreation/CloneProducersIntoDispatchRegions.cpp
similarity index 60%
rename from compiler/src/iree/compiler/Dialect/Flow/Transforms/CloneProducersIntoDispatchRegions.cpp
rename to compiler/src/iree/compiler/DispatchCreation/CloneProducersIntoDispatchRegions.cpp
index 0d35f44..5d9bdd9 100644
--- a/compiler/src/iree/compiler/Dialect/Flow/Transforms/CloneProducersIntoDispatchRegions.cpp
+++ b/compiler/src/iree/compiler/DispatchCreation/CloneProducersIntoDispatchRegions.cpp
@@ -5,45 +5,47 @@
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
#include "iree/compiler/Dialect/Flow/IR/FlowOps.h"
-#include "iree/compiler/Dialect/Flow/Transforms/Passes.h"
#include "iree/compiler/Dialect/Flow/Transforms/RegionOpUtils.h"
+#include "iree/compiler/DispatchCreation/Passes.h"
#include "mlir/Dialect/Linalg/IR/Linalg.h"
#include "mlir/IR/PatternMatch.h"
#include "mlir/Interfaces/FunctionInterfaces.h"
-#define DEBUG_TYPE "iree-flow-clone-producers-into-dispatch-regions"
+#define DEBUG_TYPE \
+ "iree-dispatch-creation-clone-producers-into-dispatch-regions"
-namespace mlir::iree_compiler::IREE::Flow {
+namespace mlir::iree_compiler::DispatchCreation {
#define GEN_PASS_DEF_CLONEPRODUCERSINTODISPATCHREGIONSPASS
-#include "iree/compiler/Dialect/Flow/Transforms/Passes.h.inc"
+#include "iree/compiler/DispatchCreation/Passes.h.inc"
namespace {
-struct CloneProducersIntoDispatchRegionsPass
- : public IREE::Flow::impl::CloneProducersIntoDispatchRegionsPassBase<
+struct CloneProducersIntoDispatchRegionsPass final
+ : public impl::CloneProducersIntoDispatchRegionsPassBase<
CloneProducersIntoDispatchRegionsPass> {
void runOnOperation() override {
mlir::FunctionOpInterface funcOp = getOperation();
IRRewriter rewriter(funcOp->getContext());
- funcOp->walk([&](DispatchRegionOp regionOp) {
+ funcOp->walk([&](IREE::Flow::DispatchRegionOp regionOp) {
if (failed(cloneProducersToRegion(rewriter, regionOp)))
return signalPassFailure();
});
funcOp->walk([&](Operation *op) {
- if (!isNonNullAndOutsideDispatch(op) || !isa<linalg::GenericOp>(op)) {
+ if (!IREE::Flow::isNonNullAndOutsideDispatch(op) ||
+ !isa<linalg::GenericOp>(op)) {
return;
}
- if (failed(wrapOpInDispatchRegion(rewriter, op))) {
+ if (failed(IREE::Flow::wrapOpInDispatchRegion(rewriter, op))) {
return signalPassFailure();
}
});
// Rerun the cloning again to move still clonable operations into
// dispatches.
- funcOp->walk([&](DispatchRegionOp regionOp) {
+ funcOp->walk([&](IREE::Flow::DispatchRegionOp regionOp) {
if (failed(cloneProducersToRegion(rewriter, regionOp)))
return signalPassFailure();
});
@@ -52,4 +54,4 @@
} // namespace
-} // namespace mlir::iree_compiler::IREE::Flow
+} // namespace mlir::iree_compiler::DispatchCreation
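
A sketch of what the two walks above accomplish (IR is illustrative, not from a test in this change): values defined above a region but used inside it are cloned into the region, and standalone linalg.generic ops are first wrapped in their own regions so they get the same treatment.

  func.func @clone_sketch(%arg0: tensor<4xf32>) -> tensor<4xf32> {
    %cst = arith.constant 1.0 : f32
    %empty = tensor.empty() : tensor<4xf32>
    %0 = flow.dispatch.region -> (tensor<4xf32>) {
      // After the pass, clones of %cst and %empty live here, so the later
      // conversion to flow.dispatch.workgroups captures fewer values.
      %1 = linalg.fill ins(%cst : f32) outs(%empty : tensor<4xf32>) -> tensor<4xf32>
      flow.return %1 : tensor<4xf32>
    }
    return %0 : tensor<4xf32>
  }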
diff --git a/compiler/src/iree/compiler/Dialect/Flow/Transforms/CollapseDimensions.cpp b/compiler/src/iree/compiler/DispatchCreation/CollapseDimensions.cpp
similarity index 95%
rename from compiler/src/iree/compiler/Dialect/Flow/Transforms/CollapseDimensions.cpp
rename to compiler/src/iree/compiler/DispatchCreation/CollapseDimensions.cpp
index 086acb5..1db69fc 100644
--- a/compiler/src/iree/compiler/Dialect/Flow/Transforms/CollapseDimensions.cpp
+++ b/compiler/src/iree/compiler/DispatchCreation/CollapseDimensions.cpp
@@ -6,8 +6,8 @@
#include "iree/compiler/Dialect/Flow/IR/FlowOps.h"
#include "iree/compiler/Dialect/Flow/Transforms/FormDispatchRegions.h"
-#include "iree/compiler/Dialect/Flow/Transforms/Passes.h"
#include "iree/compiler/Dialect/Flow/Transforms/RegionOpUtils.h"
+#include "iree/compiler/DispatchCreation/Passes.h"
#include "llvm/ADT/STLExtras.h"
#include "llvm/ADT/SetVector.h"
#include "llvm/Support/Debug.h"
@@ -33,18 +33,17 @@
#include "mlir/Support/LogicalResult.h"
#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
-#define DEBUG_TYPE "iree-flow-collapse-dimensions"
+#define DEBUG_TYPE "iree-dispatch-creation-collapse-dimensions"
-namespace mlir::iree_compiler::IREE::Flow {
+namespace mlir::iree_compiler::DispatchCreation {
#define GEN_PASS_DEF_COLLAPSEDIMENSIONSPASS
-#include "iree/compiler/Dialect/Flow/Transforms/Passes.h.inc"
+#include "iree/compiler/DispatchCreation/Passes.h.inc"
namespace {
/// Pass declaration.
-struct CollapseDimensionsPass
- : public IREE::Flow::impl::CollapseDimensionsPassBase<
- CollapseDimensionsPass> {
+struct CollapseDimensionsPass final
+ : public impl::CollapseDimensionsPassBase<CollapseDimensionsPass> {
using Base::Base;
void runOnOperation() override;
};
@@ -271,7 +270,7 @@
CollapseInfo() = default;
CollapseInfo(linalg::GenericOp genericOp) {
- reassociation = Flow::getCollapsibleLoops(genericOp);
+ reassociation = DispatchCreation::getCollapsibleLoops(genericOp);
collapsableLoops = getCollapsedFromReassociation(reassociation);
}
@@ -469,10 +468,10 @@
/// Traverses all the Ops in DispatchRegionOps and finds a linalg.generic Op
/// which is the sole producer of the flow.return's operand.
static FailureOr<linalg::GenericOp>
-findRootGenericOp(DispatchRegionOp regionOp) {
+findRootGenericOp(IREE::Flow::DispatchRegionOp regionOp) {
// Check the yielded value is from a single `linalg.generic`.
auto returnOp =
- cast<Flow::ReturnOp>(regionOp.getBody().front().getTerminator());
+ cast<IREE::Flow::ReturnOp>(regionOp.getBody().front().getTerminator());
if (!returnOp->getOperands().size()) {
return failure();
}
@@ -497,17 +496,17 @@
/// Hoist `tensor.collapse_shape` ops at the beginning of the `dispatchOp`
/// and `tensor.expand_shape` ops at the end of the `dispatchOp`, out of the
/// dispatch.
-static FailureOr<DispatchRegionOp>
-hoistTensorReshapesOutOfDispatchRegion(RewriterBase &rewriter,
- DispatchRegionOp dispatchOp) {
+static FailureOr<IREE::Flow::DispatchRegionOp>
+hoistTensorReshapesOutOfDispatchRegion(
+ RewriterBase &rewriter, IREE::Flow::DispatchRegionOp dispatchOp) {
Block &body = dispatchOp.getBody().front();
- auto returnOp = cast<Flow::ReturnOp>(body.getTerminator());
+ auto returnOp = cast<IREE::Flow::ReturnOp>(body.getTerminator());
// 1. Get the slice of operations within `dispatchOp` that produce the yielded
// value.
BackwardSliceOptions sliceOptions;
sliceOptions.filter = [&](Operation *op) {
- return op->getParentOfType<DispatchRegionOp>();
+ return op->getParentOfType<IREE::Flow::DispatchRegionOp>();
};
SetVector<Operation *> slice;
getBackwardSlice(returnOp, &slice, sliceOptions);
@@ -603,7 +602,7 @@
}
// 5. Create the new dispatch op.
- auto newDispatchOp = rewriter.create<DispatchRegionOp>(
+ auto newDispatchOp = rewriter.create<IREE::Flow::DispatchRegionOp>(
loc, newReturnTypes, newDynamicDims, dispatchOp.getWorkload());
// 5a. Move the body over, but replace the `flow.return` to use the new yield
@@ -614,7 +613,7 @@
Operation *terminator = newBody.front().getTerminator();
OpBuilder::InsertionGuard g(rewriter);
rewriter.setInsertionPoint(terminator);
- rewriter.replaceOpWithNewOp<Flow::ReturnOp>(terminator, newYieldVals);
+ rewriter.replaceOpWithNewOp<IREE::Flow::ReturnOp>(terminator, newYieldVals);
}
// 5b. Move the workgroup count region over.
@@ -668,7 +667,7 @@
for (auto operand : consumerOp.getDpsInputOperands()) {
auto definingOp = operand->get().getDefiningOp();
- if (!definingOp || isNonNullAndOutsideDispatch(definingOp)) {
+ if (!definingOp || IREE::Flow::isNonNullAndOutsideDispatch(definingOp)) {
continue;
}
@@ -765,9 +764,10 @@
// Construct a DAG of `linalg.generic` operations with 1 root op. Find
// dimensions that can be collapsed all the way from the root to the leaves,
// ensuring that all `collapse_shape` ops can be hoisted out of the dispatch.
-static bool collapseDimensionsForDispatch(IRRewriter &rewriter,
- DispatchRegionOp ®ionOp,
- int maxIterations) {
+static bool
+collapseDimensionsForDispatch(IRRewriter &rewriter,
+ IREE::Flow::DispatchRegionOp ®ionOp,
+ int maxIterations) {
// Only collapse dispatches with 1 block
if (!llvm::hasSingleElement(regionOp.getBody())) {
return false;
@@ -783,7 +783,7 @@
sliceOptions.omitBlockArguments = true;
sliceOptions.filter = [&](Operation *op) -> bool {
auto genericOp = dyn_cast<linalg::GenericOp>(op);
- auto parentOp = op->getParentOfType<DispatchRegionOp>();
+ auto parentOp = op->getParentOfType<IREE::Flow::DispatchRegionOp>();
return genericOp && isEligibleForCollapse(genericOp) &&
parentOp == regionOp;
};
@@ -880,8 +880,8 @@
MLIRContext *context = funcOp->getContext();
IRRewriter rewriter(context);
- SmallVector<DispatchRegionOp> modifiedDispatchOps;
- funcOp->walk([&](DispatchRegionOp dispatchOp) {
+ SmallVector<IREE::Flow::DispatchRegionOp> modifiedDispatchOps;
+ funcOp->walk([&](IREE::Flow::DispatchRegionOp dispatchOp) {
if (collapseDimensionsForDispatch(rewriter, dispatchOp, maxIterations)) {
modifiedDispatchOps.push_back(dispatchOp);
}
@@ -899,9 +899,9 @@
// Hoist tensor reshape ops out of dispatch region first. Otherwise, the
// reshape(cst) will be folded into a constant living in the dispatch. It
// could introduce big constants inlined in the dispatch.
- FailureOr<DispatchRegionOp> newDispatchOp =
+ FailureOr<IREE::Flow::DispatchRegionOp> newDispatchOp =
hoistTensorReshapesOutOfDispatchRegion(
- rewriter, cast<DispatchRegionOp>(dispatchOp));
+ rewriter, cast<IREE::Flow::DispatchRegionOp>(dispatchOp));
if (failed(newDispatchOp)) {
dispatchOp->emitOpError("failed to hoist reshapes out of dispatch");
return signalPassFailure();
@@ -929,4 +929,4 @@
}
}
-} // namespace mlir::iree_compiler::IREE::Flow
+} // namespace mlir::iree_compiler::DispatchCreation
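
The overall effect, as a hedged sketch (shapes and op bodies are assumptions for illustration): contiguous loop dimensions of generics inside a flow.dispatch.region are collapsed, and the compensating reshapes are hoisted out of the dispatch so they do not fold constants into it.

  #map = affine_map<(d0, d1) -> (d0, d1)>
  func.func @collapse_sketch(%arg0: tensor<4x8xf32>) -> tensor<4x8xf32> {
    %0 = flow.dispatch.region -> (tensor<4x8xf32>) {
      %init = tensor.empty() : tensor<4x8xf32>
      %1 = linalg.generic {indexing_maps = [#map, #map],
                           iterator_types = ["parallel", "parallel"]}
          ins(%arg0 : tensor<4x8xf32>) outs(%init : tensor<4x8xf32>) {
      ^bb0(%in: f32, %out: f32):
        %2 = arith.negf %in : f32
        linalg.yield %2 : f32
      } -> tensor<4x8xf32>
      flow.return %1 : tensor<4x8xf32>
    }
    // After the pass the generic runs on a collapsed tensor<32xf32>, with the
    // tensor.collapse_shape on the input and the tensor.expand_shape on the
    // result hoisted out of the dispatch region.
    return %0 : tensor<4x8xf32>
  }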
diff --git a/compiler/src/iree/compiler/Dialect/Flow/Transforms/CollapseReductionDimensions.cpp b/compiler/src/iree/compiler/DispatchCreation/CollapseReductionDimensions.cpp
similarity index 86%
rename from compiler/src/iree/compiler/Dialect/Flow/Transforms/CollapseReductionDimensions.cpp
rename to compiler/src/iree/compiler/DispatchCreation/CollapseReductionDimensions.cpp
index 0a4e5d6..8b76e7d 100644
--- a/compiler/src/iree/compiler/Dialect/Flow/Transforms/CollapseReductionDimensions.cpp
+++ b/compiler/src/iree/compiler/DispatchCreation/CollapseReductionDimensions.cpp
@@ -4,17 +4,17 @@
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
-#include "iree/compiler/Dialect/Flow/Transforms/Passes.h"
#include "iree/compiler/Dialect/Flow/Transforms/RegionOpUtils.h"
#include "mlir/Dialect/Linalg/IR/Linalg.h"
#include "mlir/Dialect/Linalg/IR/LinalgInterfaces.h"
#include "mlir/Dialect/Linalg/Transforms/Transforms.h"
+#include "mlir/Pass/Pass.h"
#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
-namespace mlir::iree_compiler::IREE::Flow {
+namespace mlir::iree_compiler::DispatchCreation {
#define GEN_PASS_DEF_COLLAPSEREDUCTIONDIMENSIONSPASS
-#include "iree/compiler/Dialect/Flow/Transforms/Passes.h.inc"
+#include "iree/compiler/DispatchCreation/Passes.h.inc"
namespace {
@@ -47,7 +47,7 @@
collapseDimensions(linalg::LinalgOp linalgOp) {
SmallVector<ReassociationIndices> collapseIndices;
- if (!isNonNullAndOutsideDispatch(linalgOp)) {
+ if (!IREE::Flow::isNonNullAndOutsideDispatch(linalgOp)) {
return collapseIndices;
}
@@ -68,8 +68,8 @@
return collapseIndices;
}
-struct CollapseReductionDimensionsPass
- : public IREE::Flow::impl::CollapseReductionDimensionsPassBase<
+struct CollapseReductionDimensionsPass final
+ : public impl::CollapseReductionDimensionsPassBase<
CollapseReductionDimensionsPass> {
void runOnOperation() override {
RewritePatternSet patterns(&getContext());
@@ -83,4 +83,4 @@
} // namespace
-} // namespace mlir::iree_compiler::IREE::Flow
+} // namespace mlir::iree_compiler::DispatchCreation
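
A small sketch of the transformation (illustrative shapes): adjacent reduction iterators are collapsed into one so downstream tiling sees a single reduction dimension.

  #in = affine_map<(d0, d1, d2) -> (d0, d1, d2)>
  #out = affine_map<(d0, d1, d2) -> (d0)>
  func.func @reduce_sketch(%arg0: tensor<4x8x16xf32>,
                           %init: tensor<4xf32>) -> tensor<4xf32> {
    // The pass collapses the two trailing reduction dims, producing an
    // equivalent generic over tensor<4x128xf32> with one reduction iterator.
    %0 = linalg.generic {indexing_maps = [#in, #out],
                         iterator_types = ["parallel", "reduction", "reduction"]}
        ins(%arg0 : tensor<4x8x16xf32>) outs(%init : tensor<4xf32>) {
    ^bb0(%in: f32, %out: f32):
      %1 = arith.addf %in, %out : f32
      linalg.yield %1 : f32
    } -> tensor<4xf32>
    return %0 : tensor<4xf32>
  }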
diff --git a/compiler/src/iree/compiler/Dialect/Flow/Transforms/ConvertDispatchRegionsToWorkgroups.cpp b/compiler/src/iree/compiler/DispatchCreation/ConvertDispatchRegionsToWorkgroups.cpp
similarity index 74%
rename from compiler/src/iree/compiler/Dialect/Flow/Transforms/ConvertDispatchRegionsToWorkgroups.cpp
rename to compiler/src/iree/compiler/DispatchCreation/ConvertDispatchRegionsToWorkgroups.cpp
index 078c8c4..7ea310b 100644
--- a/compiler/src/iree/compiler/Dialect/Flow/Transforms/ConvertDispatchRegionsToWorkgroups.cpp
+++ b/compiler/src/iree/compiler/DispatchCreation/ConvertDispatchRegionsToWorkgroups.cpp
@@ -7,8 +7,8 @@
#include "iree/compiler/Dialect/Flow/IR/FlowOps.h"
#include "iree/compiler/Dialect/Flow/Transforms/ConvertRegionToWorkgroups.h"
#include "iree/compiler/Dialect/Flow/Transforms/FormDispatchRegions.h"
-#include "iree/compiler/Dialect/Flow/Transforms/Passes.h"
#include "iree/compiler/Dialect/Flow/Transforms/RegionOpUtils.h"
+#include "iree/compiler/DispatchCreation/Passes.h"
#include "llvm/Support/Debug.h"
#include "mlir/Dialect/Affine/IR/AffineOps.h"
#include "mlir/Dialect/Linalg/IR/Linalg.h"
@@ -19,20 +19,19 @@
#include "mlir/Support/LLVM.h"
#include "mlir/Support/LogicalResult.h"
-#define DEBUG_TYPE "iree-flow-convert-dispatch-regions-to-workgroups"
+#define DEBUG_TYPE \
+ "iree-dispatch-creation-convert-dispatch-regions-to-workgroups"
-namespace mlir::iree_compiler::IREE::Flow {
+namespace mlir::iree_compiler::DispatchCreation {
#define GEN_PASS_DEF_CONVERTDISPATCHREGIONSTOWORKGROUPSPASS
-#include "iree/compiler/Dialect/Flow/Transforms/Passes.h.inc"
+#include "iree/compiler/DispatchCreation/Passes.h.inc"
namespace {
struct ConvertDispatchRegionsToWorkgroupsPass
- : public IREE::Flow::impl::ConvertDispatchRegionsToWorkgroupsPassBase<
+ : public impl::ConvertDispatchRegionsToWorkgroupsPassBase<
ConvertDispatchRegionsToWorkgroupsPass> {
- using IREE::Flow::impl::ConvertDispatchRegionsToWorkgroupsPassBase<
- ConvertDispatchRegionsToWorkgroupsPass>::
- ConvertDispatchRegionsToWorkgroupsPassBase;
+ using Base::Base;
void runOnOperation() override;
};
} // namespace
@@ -43,7 +42,8 @@
TensorDimTrackingRewriter rewriter(funcOp);
SmallVector<IREE::Flow::DispatchRegionOp> regionOps;
- funcOp.walk([&](Flow::DispatchRegionOp op) { regionOps.push_back(op); });
+ funcOp.walk(
+ [&](IREE::Flow::DispatchRegionOp op) { regionOps.push_back(op); });
numDispatches += regionOps.size();
@@ -58,4 +58,4 @@
}
}
}
-} // namespace mlir::iree_compiler::IREE::Flow
+} // namespace mlir::iree_compiler::DispatchCreation
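
For orientation, a sketch of the input this pass consumes (the resulting workgroups form is summarized in a comment rather than spelled out, since its binding syntax is verbose):

  func.func @to_workgroups_sketch(%arg0: tensor<4xf32>) -> tensor<4xf32> {
    %0 = flow.dispatch.region -> (tensor<4xf32>) {
      %cst = arith.constant 1.0 : f32
      %1 = tensor.empty() : tensor<4xf32>
      %2 = linalg.fill ins(%cst : f32) outs(%1 : tensor<4xf32>) -> tensor<4xf32>
      flow.return %2 : tensor<4xf32>
    }
    // After the pass this becomes a flow.dispatch.workgroups op whose body is
    // isolated from above: operands become explicit captures and tensors are
    // accessed through !flow.dispatch.tensor bindings.
    return %0 : tensor<4xf32>
  }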
diff --git a/compiler/src/iree/compiler/Dialect/Flow/Transforms/ConvertTensorToFlow.cpp b/compiler/src/iree/compiler/DispatchCreation/ConvertTensorToFlow.cpp
similarity index 90%
rename from compiler/src/iree/compiler/Dialect/Flow/Transforms/ConvertTensorToFlow.cpp
rename to compiler/src/iree/compiler/DispatchCreation/ConvertTensorToFlow.cpp
index 9ab55f8..b840924 100644
--- a/compiler/src/iree/compiler/Dialect/Flow/Transforms/ConvertTensorToFlow.cpp
+++ b/compiler/src/iree/compiler/DispatchCreation/ConvertTensorToFlow.cpp
@@ -10,8 +10,8 @@
#include "iree/compiler/Dialect/Flow/IR/FlowOps.h"
#include "iree/compiler/Dialect/Flow/Transforms/ConvertRegionToWorkgroups.h"
#include "iree/compiler/Dialect/Flow/Transforms/FormDispatchRegions.h"
-#include "iree/compiler/Dialect/Flow/Transforms/Passes.h"
#include "iree/compiler/Dialect/Flow/Transforms/RegionOpUtils.h"
+#include "iree/compiler/DispatchCreation/Passes.h"
#include "llvm/Support/Debug.h"
#include "mlir/Dialect/Affine/IR/AffineOps.h"
#include "mlir/Dialect/Linalg/IR/Linalg.h"
@@ -24,12 +24,12 @@
#include "mlir/Support/LLVM.h"
#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
-#define DEBUG_TYPE "iree-flow-convert-tensor-to-flow"
+#define DEBUG_TYPE "iree-dispatch-creation-convert-tensor-to-flow"
-namespace mlir::iree_compiler::IREE::Flow {
+namespace mlir::iree_compiler::DispatchCreation {
#define GEN_PASS_DEF_CONVERTTENSORTOFLOWPASS
-#include "iree/compiler/Dialect/Flow/Transforms/Passes.h.inc"
+#include "iree/compiler/DispatchCreation/Passes.h.inc"
/// Return `true` if the given op is contained in DispatchWorkgroupsOp or in a
/// DispatchRegionOp.
@@ -43,8 +43,7 @@
wrapInWorkgroupsOp(mlir::TensorDimTrackingRewriter &rewriter, Operation *op) {
SmallVector<tensor::DimOp> dimOps = rewriter.getTensorDimOps();
- if (failed(iree_compiler::IREE::Flow::simplifyDimOps(
- rewriter, rewriter.getTensorDimOps())))
+ if (failed(IREE::Flow::simplifyDimOps(rewriter, rewriter.getTensorDimOps())))
return failure();
// Wrap operation.
@@ -91,7 +90,8 @@
// Rewrite InsertSliceOps to FlowUpdateOps.
SmallVector<Operation *> remainingInsertSliceOps;
for (tensor::InsertSliceOp insertSliceOp : insertSliceOps) {
- if (failed(convertInsertSliceOpToFlowUpdateOp(rewriter, insertSliceOp))) {
+ if (failed(IREE::Flow::convertInsertSliceOpToFlowUpdateOp(rewriter,
+ insertSliceOp))) {
remainingInsertSliceOps.push_back(insertSliceOp);
}
}
@@ -124,7 +124,8 @@
// Rewrite ExtractSliceOps to FlowSliceOps.
SmallVector<Operation *> remainingExtractSliceOps;
for (tensor::ExtractSliceOp extractSliceOp : extractSliceOps) {
- if (failed(convertExtractSliceOpToFlowSliceOp(rewriter, extractSliceOp))) {
+ if (failed(IREE::Flow::convertExtractSliceOpToFlowSliceOp(
+ rewriter, extractSliceOp))) {
remainingExtractSliceOps.push_back(extractSliceOp);
}
}
@@ -144,10 +145,8 @@
namespace {
struct ConvertTensorToFlowPass
- : public IREE::Flow::impl::ConvertTensorToFlowPassBase<
- ConvertTensorToFlowPass> {
- using IREE::Flow::impl::ConvertTensorToFlowPassBase<
- ConvertTensorToFlowPass>::ConvertTensorToFlowPassBase;
+ : public impl::ConvertTensorToFlowPassBase<ConvertTensorToFlowPass> {
+ using Base::Base;
void runOnOperation() override;
};
} // namespace
@@ -158,7 +157,7 @@
mlir::MLIRContext *context = &getContext();
auto workgroupsOps = SmallVector<IREE::Flow::DispatchWorkgroupsOp>();
- funcOp->walk([&](Flow::DispatchWorkgroupsOp workgroupsOp) {
+ funcOp->walk([&](IREE::Flow::DispatchWorkgroupsOp workgroupsOp) {
workgroupsOps.push_back(workgroupsOp);
});
@@ -211,4 +210,4 @@
}
}
-} // namespace mlir::iree_compiler::IREE::Flow
+} // namespace mlir::iree_compiler::DispatchCreation
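
A sketch of the slice handling above (IR is illustrative; whether a given slice qualifies depends on the conversion's contiguity checks):

  func.func @tensor_to_flow_sketch(%target: tensor<4x4xf32>,
                                   %update: tensor<2x4xf32>) -> tensor<4x4xf32> {
    // A unit-stride, full-width insert like this may be rewritten to a
    // flow.tensor.update by convertInsertSliceOpToFlowUpdateOp; slices that do
    // not qualify are wrapped in dispatches instead (the "slow copy" path
    // counted by the num-slow-copy-dispatches statistic).
    %0 = tensor.insert_slice %update into %target[2, 0] [2, 4] [1, 1]
        : tensor<2x4xf32> into tensor<4x4xf32>
    return %0 : tensor<4x4xf32>
  }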
diff --git a/compiler/src/iree/compiler/Dialect/Flow/Transforms/DispatchWithTransformDialect.cpp b/compiler/src/iree/compiler/DispatchCreation/DispatchWithTransformDialect.cpp
similarity index 83%
rename from compiler/src/iree/compiler/Dialect/Flow/Transforms/DispatchWithTransformDialect.cpp
rename to compiler/src/iree/compiler/DispatchCreation/DispatchWithTransformDialect.cpp
index fae34dc..e5803bd 100644
--- a/compiler/src/iree/compiler/Dialect/Flow/Transforms/DispatchWithTransformDialect.cpp
+++ b/compiler/src/iree/compiler/DispatchCreation/DispatchWithTransformDialect.cpp
@@ -5,8 +5,8 @@
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
#include "iree/compiler/Dialect/Flow/IR/FlowDialect.h"
-#include "iree/compiler/Dialect/Flow/Transforms/Passes.h"
#include "iree/compiler/Dialect/LinalgExt/IR/LinalgExtDialect.h"
+#include "iree/compiler/DispatchCreation/Passes.h"
#include "llvm/Support/CommandLine.h"
#include "mlir/Dialect/Affine/IR/AffineOps.h"
#include "mlir/Dialect/Arith/IR/Arith.h"
@@ -20,22 +20,20 @@
#include "mlir/Pass/Pass.h"
#include "mlir/Support/LLVM.h"
-namespace mlir::iree_compiler::IREE::Flow {
+namespace mlir::iree_compiler::DispatchCreation {
#define GEN_PASS_DEF_DISPATCHWITHTRANSFORMDIALECTPASS
-#include "iree/compiler/Dialect/Flow/Transforms/Passes.h.inc"
+#include "iree/compiler/DispatchCreation/Passes.h.inc"
/// Pass declaration.
/// Interpreter pass that applies transform dialect ops for dispatch region
/// formation. This needs to be its own pass because the registration mechanism
/// and ops available are different than for other interpreters.
namespace {
-class DispatchWithTransformDialectPass
- : public IREE::Flow::impl::DispatchWithTransformDialectPassBase<
+struct DispatchWithTransformDialectPass final
+ : public impl::DispatchWithTransformDialectPassBase<
DispatchWithTransformDialectPass> {
-public:
- using IREE::Flow::impl::DispatchWithTransformDialectPassBase<
- DispatchWithTransformDialectPass>::DispatchWithTransformDialectPassBase;
+ using Base::Base;
void runOnOperation() override {
MLIRContext *context = &getContext();
// Load the module from the spec path. The module will be unloaded once the
@@ -67,4 +65,4 @@
} // namespace
-} // namespace mlir::iree_compiler::IREE::Flow
+} // namespace mlir::iree_compiler::DispatchCreation
diff --git a/compiler/src/iree/compiler/Dialect/Flow/Transforms/ElementwiseOpFusion.cpp b/compiler/src/iree/compiler/DispatchCreation/ElementwiseOpFusion.cpp
similarity index 86%
rename from compiler/src/iree/compiler/Dialect/Flow/Transforms/ElementwiseOpFusion.cpp
rename to compiler/src/iree/compiler/DispatchCreation/ElementwiseOpFusion.cpp
index 306268d..2841c98 100644
--- a/compiler/src/iree/compiler/Dialect/Flow/Transforms/ElementwiseOpFusion.cpp
+++ b/compiler/src/iree/compiler/DispatchCreation/ElementwiseOpFusion.cpp
@@ -11,29 +11,26 @@
//
//===----------------------------------------------------------------------===//
-#include "iree/compiler/Dialect/Flow/Transforms/FusionUtils.h"
-#include "iree/compiler/Dialect/Flow/Transforms/Passes.h"
#include "iree/compiler/Dialect/Flow/Transforms/RegionOpUtils.h"
+#include "iree/compiler/DispatchCreation/FusionUtils.h"
+#include "iree/compiler/DispatchCreation/Passes.h"
#include "llvm/Support/Debug.h"
#include "mlir/Dialect/Affine/IR/AffineOps.h"
#include "mlir/Dialect/Linalg/Transforms/Transforms.h"
#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
-#define DEBUG_TYPE "iree-flow-elementwise-op-fusion"
+#define DEBUG_TYPE "iree-dispatch-creation-elementwise-op-fusion"
-namespace mlir::iree_compiler::IREE::Flow {
+namespace mlir::iree_compiler::DispatchCreation {
#define GEN_PASS_DEF_ELEMENTWISEOPFUSIONPASS
-#include "iree/compiler/Dialect/Flow/Transforms/Passes.h.inc"
+#include "iree/compiler/DispatchCreation/Passes.h.inc"
namespace {
-class ElementwiseOpFusionPass
+struct ElementwiseOpFusionPass final
: public impl::ElementwiseOpFusionPassBase<ElementwiseOpFusionPass> {
-
-public:
using Base::Base;
-
void runOnOperation() override;
};
@@ -51,7 +48,7 @@
Operation *producer = fusedOperand->get().getDefiningOp();
Operation *consumer = fusedOperand->getOwner();
- if (!isNonNullAndOutsideDispatch({producer, consumer})) {
+ if (!IREE::Flow::isNonNullAndOutsideDispatch({producer, consumer})) {
return false;
}
@@ -84,4 +81,4 @@
}
}
-} // namespace mlir::iree_compiler::IREE::Flow
+} // namespace mlir::iree_compiler::DispatchCreation
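
A minimal sketch of the producer-consumer fusion the control function gates (illustrative IR):

  #map = affine_map<(d0) -> (d0)>
  func.func @fusion_sketch(%arg0: tensor<8xf32>) -> tensor<8xf32> {
    %init = tensor.empty() : tensor<8xf32>
    %0 = linalg.generic {indexing_maps = [#map, #map], iterator_types = ["parallel"]}
        ins(%arg0 : tensor<8xf32>) outs(%init : tensor<8xf32>) {
    ^bb0(%in: f32, %out: f32):
      %1 = arith.addf %in, %in : f32
      linalg.yield %1 : f32
    } -> tensor<8xf32>
    // When the control function allows it (both ops non-null and outside any
    // dispatch), producer and consumer fuse into one generic whose body does
    // addf followed by mulf.
    %2 = linalg.generic {indexing_maps = [#map, #map], iterator_types = ["parallel"]}
        ins(%0 : tensor<8xf32>) outs(%init : tensor<8xf32>) {
    ^bb0(%in: f32, %out: f32):
      %3 = arith.mulf %in, %in : f32
      linalg.yield %3 : f32
    } -> tensor<8xf32>
    return %2 : tensor<8xf32>
  }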
diff --git a/compiler/src/iree/compiler/Dialect/Flow/Transforms/FoldUnitExtentDims.cpp b/compiler/src/iree/compiler/DispatchCreation/FoldUnitExtentDims.cpp
similarity index 92%
rename from compiler/src/iree/compiler/Dialect/Flow/Transforms/FoldUnitExtentDims.cpp
rename to compiler/src/iree/compiler/DispatchCreation/FoldUnitExtentDims.cpp
index 390e7f6..f802f0b 100644
--- a/compiler/src/iree/compiler/Dialect/Flow/Transforms/FoldUnitExtentDims.cpp
+++ b/compiler/src/iree/compiler/DispatchCreation/FoldUnitExtentDims.cpp
@@ -11,10 +11,10 @@
//
//===----------------------------------------------------------------------===//
-#include "iree/compiler/Dialect/Flow/Transforms/Passes.h"
#include "iree/compiler/Dialect/Flow/Transforms/RegionOpUtils.h"
#include "iree/compiler/Dialect/Stream/IR/StreamTypes.h"
#include "iree/compiler/Dialect/Util/Analysis/Explorer.h"
+#include "iree/compiler/DispatchCreation/Passes.h"
#include "llvm/ADT/TypeSwitch.h"
#include "llvm/Support/Debug.h"
#include "mlir/Dialect/Affine/IR/AffineOps.h"
@@ -25,12 +25,12 @@
#include "mlir/Pass/Pass.h"
#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
-#define DEBUG_TYPE "iree-flow-fold-unit-extent-dims"
+#define DEBUG_TYPE "iree-dispatch-creation-fold-unit-extent-dims"
-namespace mlir::iree_compiler::IREE::Flow {
+namespace mlir::iree_compiler::DispatchCreation {
#define GEN_PASS_DEF_FOLDUNITEXTENTDIMSPASS
-#include "iree/compiler/Dialect/Flow/Transforms/Passes.h.inc"
+#include "iree/compiler/DispatchCreation/Passes.h.inc"
//===----------------------------------------------------------------------===//
// Pass helpers
@@ -107,9 +107,8 @@
}
namespace {
-struct FoldUnitExtentDimsPass
- : public IREE::Flow::impl::FoldUnitExtentDimsPassBase<
- FoldUnitExtentDimsPass> {
+struct FoldUnitExtentDimsPass final
+ : public impl::FoldUnitExtentDimsPassBase<FoldUnitExtentDimsPass> {
void runOnOperation() override;
};
} // namespace
@@ -149,7 +148,7 @@
auto defaultFn = options.controlFn;
options.controlFn = [&](Operation *op) {
// Ignore operations already in dispatches.
- if (!isNonNullAndOutsideDispatch(op)) {
+ if (!IREE::Flow::isNonNullAndOutsideDispatch(op)) {
return SmallVector<unsigned>{};
}
return defaultFn(op);
@@ -162,4 +161,4 @@
}
}
-} // namespace mlir::iree_compiler::IREE::Flow
+} // namespace mlir::iree_compiler::DispatchCreation
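
A sketch of the folding, under the controlFn above (illustrative shapes):

  #map = affine_map<(d0, d1) -> (d0, d1)>
  func.func @unit_dims_sketch(%arg0: tensor<1x8xf32>) -> tensor<1x8xf32> {
    %init = tensor.empty() : tensor<1x8xf32>
    // Outside a dispatch, the unit dimension is folded away: the generic is
    // rewritten over tensor<8xf32> bracketed by tensor.collapse_shape and
    // tensor.expand_shape. Ops already inside dispatches are skipped by the
    // control function.
    %0 = linalg.generic {indexing_maps = [#map, #map],
                         iterator_types = ["parallel", "parallel"]}
        ins(%arg0 : tensor<1x8xf32>) outs(%init : tensor<1x8xf32>) {
    ^bb0(%in: f32, %out: f32):
      %1 = arith.negf %in : f32
      linalg.yield %1 : f32
    } -> tensor<1x8xf32>
    return %0 : tensor<1x8xf32>
  }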
diff --git a/compiler/src/iree/compiler/DispatchCreation/FormDispatchRegions.cpp b/compiler/src/iree/compiler/DispatchCreation/FormDispatchRegions.cpp
new file mode 100644
index 0000000..d272e3b
--- /dev/null
+++ b/compiler/src/iree/compiler/DispatchCreation/FormDispatchRegions.cpp
@@ -0,0 +1,949 @@
+// Copyright 2022 The IREE Authors
+//
+// Licensed under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+
+#include "iree/compiler/Dialect/Flow/Transforms/FormDispatchRegions.h"
+#include "iree/compiler/Dialect/Encoding/IR/EncodingOps.h"
+#include "iree/compiler/Dialect/Flow/IR/FlowDialect.h"
+#include "iree/compiler/Dialect/Flow/IR/FlowOps.h"
+#include "iree/compiler/Dialect/Flow/Transforms/ConvertRegionToWorkgroups.h"
+#include "iree/compiler/Dialect/Flow/Transforms/RegionOpUtils.h"
+#include "iree/compiler/Dialect/LinalgExt/IR/LinalgExtDialect.h"
+#include "iree/compiler/Dialect/LinalgExt/IR/LinalgExtInterfaces.h"
+#include "iree/compiler/Dialect/LinalgExt/IR/LinalgExtOps.h"
+#include "iree/compiler/Dialect/LinalgExt/Utils/Utils.h"
+#include "iree/compiler/DispatchCreation/Passes.h"
+#include "llvm/ADT/STLExtras.h"
+#include "llvm/ADT/TypeSwitch.h"
+#include "llvm/Support/Casting.h"
+#include "llvm/Support/Debug.h"
+#include "mlir/Analysis/TopologicalSortUtils.h"
+#include "mlir/Dialect/Affine/IR/AffineOps.h"
+#include "mlir/Dialect/Arith/IR/Arith.h"
+#include "mlir/Dialect/Linalg/IR/Linalg.h"
+#include "mlir/Dialect/Linalg/Utils/Utils.h"
+#include "mlir/Dialect/SCF/IR/SCF.h"
+#include "mlir/Dialect/Tensor/IR/Tensor.h"
+#include "mlir/Dialect/Utils/StructuredOpsUtils.h"
+#include "mlir/IR/Block.h"
+#include "mlir/IR/Builders.h"
+#include "mlir/IR/BuiltinAttributes.h"
+#include "mlir/IR/Dominance.h"
+#include "mlir/IR/MLIRContext.h"
+#include "mlir/IR/Operation.h"
+#include "mlir/IR/PatternMatch.h"
+#include "mlir/Interfaces/DestinationStyleOpInterface.h"
+#include "mlir/Interfaces/FunctionInterfaces.h"
+#include "mlir/Interfaces/TilingInterface.h"
+#include "mlir/Pass/Pass.h"
+#include "mlir/Support/LLVM.h"
+
+#define DEBUG_TYPE "iree-dispatch-creation-form-dispatch-regions"
+
+static const char kRootOpAttr[] = "__root_op__";
+static const char kFusionGroupsAttr[] = "__fused_op__";
+
+namespace mlir::iree_compiler::DispatchCreation {
+
+#define GEN_PASS_DEF_FORMDISPATCHREGIONSPASS
+#include "iree/compiler/DispatchCreation/Passes.h.inc"
+
+//===----------------------------------------------------------------------===//
+// Root and fusion group attribute handling
+//===----------------------------------------------------------------------===//
+
+/// Returns true if an op has the root op attribute set.
+static bool hasRootOpAttribute(Operation *op) {
+ return static_cast<bool>(op->getAttrOfType<IntegerAttr>(kRootOpAttr));
+}
+/// Removes the root attribute if present.
+static void removeRootOpAttribute(Operation *op) {
+ op->removeAttr(kRootOpAttr);
+}
+/// Sets the root attribute for an operation. The root attribute needs a number
+/// to identify the root. Asserts if root attribute is already set on an
+/// operation.
+static void setRootAttribute(MLIRContext *context, Operation *op,
+ int64_t rootNumber) {
+ assert(!op->hasAttr(kRootOpAttr) &&
+ "invalid to update root attribute on an op");
+ op->setAttr(kRootOpAttr,
+ IntegerAttr::get(IntegerType::get(context, 64), rootNumber));
+}
+/// Returns the number of the root. Asserts if the operation is not already set
+/// as a root.
+static int64_t getRootNumber(Operation *op) {
+ return op->getAttrOfType<IntegerAttr>(kRootOpAttr).getInt();
+}
+/// Returns true if an op is part of a fusion group.
+static bool hasFusionGroupsAttribute(Operation *op) {
+ return static_cast<bool>(op->getAttrOfType<ArrayAttr>(kFusionGroupsAttr));
+}
+/// Returns the fusion groups for the given `op`.
+static SmallVector<int64_t, 1> getFusionGroups(Operation *op) {
+ SmallVector<int64_t, 1> fusionGroups = {};
+ if (auto fusionGroupsAttr = op->getAttrOfType<ArrayAttr>(kFusionGroupsAttr)) {
+ fusionGroups = llvm::map_to_vector(fusionGroupsAttr, [](Attribute attr) {
+ return llvm::cast<IntegerAttr>(attr).getInt();
+ });
+ }
+ return fusionGroups;
+}
+/// Appends `newGroups` to the fusion groups of the given `op`.
+static void appendToFusionGroup(Operation *op, ArrayRef<int64_t> newGroups) {
+ SmallVector<int64_t> fusionGroups = getFusionGroups(op);
+ fusionGroups.append(newGroups.begin(), newGroups.end());
+ op->setAttr(kFusionGroupsAttr, Builder(op).getI64ArrayAttr(fusionGroups));
+}
+/// Removes the fusion groups attribute.
+static void removeFusionGroupsAttribute(Operation *op) {
+ op->removeAttr(kFusionGroupsAttr);
+}
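+
+// Illustrative annotation (attribute rendering assumed, not taken from a
+// test): after marking, a fused producer and its root might print as
+//   %0 = linalg.generic ... {__fused_op__ = [0]} ...
+//   %1 = linalg.generic ... {__root_op__ = 0 : i64} ...
+// Both attributes are transient bookkeeping, stripped again by the remove*
+// helpers above once dispatch regions have been formed.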
+
+//===----------------------------------------------------------------------===//
+// Op property characterizations
+//===----------------------------------------------------------------------===//
+
+/// Returns true if none of the reduction dimensions of `linalgOp` that index
+/// the unpack result are unpacked by the producer tensor::UnPackOp, i.e. the
+/// reduced dimensions of the unpack result are not part of inner_dims_pos.
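+/// Illustrative example: with `inner_dims_pos = [1]` on the unpack, a
+/// consumer that reduces dimension 1 of the unpacked operand makes this
+/// return false, while one that only reduces dimension 0 makes it return
+/// true.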
+static bool hasNoPackedReductionDimensions(linalg::LinalgOp linalgOp,
+ Operation *producer) {
+ auto unpack = dyn_cast<tensor::UnPackOp>(producer);
+ if (!unpack) {
+ return false;
+ }
+ AffineMap map;
+ for (auto &use : producer->getResult(0).getUses()) {
+ if (use.getOwner() == linalgOp) {
+ map = linalgOp.getMatchingIndexingMap(&use);
+ break;
+ }
+ }
+ if (!map) {
+ return false;
+ }
+ auto iterators = linalgOp.getIteratorTypesArray();
+ auto reduction = utils::IteratorType::reduction;
+ for (auto expr : llvm::enumerate(map.getResults())) {
+ auto dim = dyn_cast<AffineDimExpr>(expr.value());
+ if (!dim) {
+ return false;
+ }
+ unsigned pos = dim.getPosition();
+ if (iterators[pos] == reduction &&
+ llvm::any_of(unpack.getInnerDimsPos(),
+ [expr](int64_t idp) { return expr.index() == idp; })) {
+ return false;
+ }
+ }
+ return true;
+}
+
+/// Returns true if the linalgOp is fusable with an unpack producer.
+static bool hasFusableUnpackProducer(linalg::LinalgOp linalgOp) {
+ return llvm::any_of(linalgOp->getOperands(), [&](Value operand) {
+ auto producer = operand.getDefiningOp<tensor::UnPackOp>();
+ return producer && hasNoPackedReductionDimensions(linalgOp, producer);
+ });
+}
+
+/// Returns true if the op is treated as a root operation for dispatch region
+/// formation.
+static bool isRootOp(Operation *op) {
+ if (op->getParentOfType<IREE::Flow::DispatchWorkgroupsOp>()) {
+ return false;
+ }
+ // Dequantization-like ops get cloned into dispatches later.
+ if (IREE::LinalgExt::isBitExtendOp(op)) {
+ return false;
+ }
+ // Any Linalg named op or generic op with reduction iterator types is a root
+ // op.
+ if (auto linalgOp = dyn_cast<linalg::LinalgOp>(op)) {
+ if (isa<linalg::GenericOp>(op)) {
+ return linalgOp.getNumReductionLoops() != 0 &&
+ !hasFusableUnpackProducer(linalgOp);
+ }
+ return !isa<linalg::FillOp>(op);
+ }
+ if (isa<TilingInterface>(op)) {
+ return !isa<tensor::PadOp, tensor::PackOp>(op);
+ }
+ return isa<IREE::Encoding::UnsetEncodingOp, tensor::UnPackOp>(op);
+}
+
+/// Returns true if the operation is a `pack` op or a `set_encoding` op that
+/// has pack semantics.
+// TODO(ravishankarm): This seems like a use case for an interface.
+static bool isPackLikeOp(Operation *op) {
+ return isa<IREE::Encoding::SetEncodingOp, tensor::PackOp>(op);
+}
+
+/// Returns true if the operation is an `unpack` op or an `unset_encoding` op,
+/// or an `extract_slice` op whose source operand matches those criteria,
+/// recursively.
+/// The idea is to ensure that `extract_slice` ops can't prevent fusion
+/// between an `unset_encoding` producer and some linalg consumer. In
+///   %0 = unset_encoding ...
+///   %1 = extract_slice %0 ...
+///   %2 = linalg.generic ins(%1) ...
+/// fusing %1 into %0 is not enough; we also want to fuse %2, so %1 must not
+/// act as a consumer fusion barrier.
+static bool isUnpackLikeOpViaExtractSliceOps(Operation *op) {
+ if (isa<IREE::Encoding::UnsetEncodingOp, tensor::UnPackOp>(op)) {
+ return true;
+ }
+ if (isa<tensor::ExtractSliceOp>(op)) {
+ Value source = op->getOperand(0);
+ Operation *producer = source.getDefiningOp();
+    // `producer` may be null when `source` is a block argument.
+    if (producer && isUnpackLikeOpViaExtractSliceOps(producer)) {
+ return true;
+ }
+ }
+ return false;
+}
+
+/// Since `iree_encoding.set_encoding` doesn't have padding semantics, a
+/// `tensor.pad` is introduced to make the shapes of the input and output
+/// match. The `tensor.pad` -> `set_encoding` sequence can later be folded
+/// into a single `tensor.pack` operation, but it means the fusion has to try
+/// to keep these in the same dispatch.
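+/// Illustrative IR sketch (names invented):
+///   %padded = tensor.pad %src low[0, 0] high[%h0, %h1] ...
+///   %encoded = iree_encoding.set_encoding %padded ...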
+// TODO(ravishankarm): Maybe make `set_encoding` have pad semantics that can be
+// explicitly broken down if needed.
+static bool isPadUsedInSetEncoding(tensor::PadOp padOp) {
+ return llvm::any_of(padOp->getUsers(),
+ llvm::IsaPred<IREE::Encoding::SetEncodingOp>);
+}
+
+//===----------------------------------------------------------------------===//
+// Heuristics for fusing dispatchable ops with root ops using tile + fuse.
+//===----------------------------------------------------------------------===//
+
+/// Returns a bit vector whose size is the number of loops of `op`, with the
+/// bits corresponding to the outer parallel loops set to `true`.
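+/// Illustrative example: iterator types [parallel, parallel, reduction] give
+/// {1, 1, 0}, while [parallel, reduction, parallel] give {1, 0, 0} because
+/// the scan stops at the first non-parallel iterator.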
+static llvm::SmallBitVector getOuterParallelLoops(Operation *op) {
+ if (auto setEncodingOp = dyn_cast<IREE::Encoding::SetEncodingOp>(op)) {
+ return llvm::SmallBitVector(setEncodingOp.getResultType().getRank(), true);
+ }
+ if (auto unsetEncodingOp = dyn_cast<IREE::Encoding::UnsetEncodingOp>(op)) {
+ return llvm::SmallBitVector(unsetEncodingOp.getResultType().getRank(),
+ true);
+ }
+
+ auto interfaceOp = dyn_cast<TilingInterface>(op);
+ if (!interfaceOp) {
+    // For ops that don't implement the `TilingInterface`, just return empty.
+ return llvm::SmallBitVector{};
+ }
+ SmallVector<utils::IteratorType> loopIteratorTypes =
+ interfaceOp.getLoopIteratorTypes();
+ llvm::SmallBitVector parallelLoops(loopIteratorTypes.size());
+ for (auto iteratorType : llvm::enumerate(loopIteratorTypes)) {
+ if (iteratorType.value() != utils::IteratorType::parallel)
+ break;
+ parallelLoops.set(iteratorType.index());
+ }
+ return parallelLoops;
+}
+
+/// Returns true if `map` is an identity map with zeros, i.e. dropping the
+/// result exprs that are constant zeros turns `map` into an identity map.
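+/// Illustrative example: `(d0, d1) -> (d0, 0, d1)` qualifies, while
+/// `(d0, d1) -> (d1, d0)` (permuted) and `(d0, d1) -> (d0, 1, d1)` (non-zero
+/// constant) do not.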
+static bool isIdentityMapWithZeros(AffineMap map) {
+ if (map.getNumSymbols() != 0)
+ return false;
+ if (map.isEmpty())
+ return false;
+ unsigned dimsSeen = 0;
+ for (AffineExpr result : map.getResults()) {
+ if (auto dimExpr = dyn_cast<AffineDimExpr>(result)) {
+ if (dimExpr.getPosition() != dimsSeen) {
+ return false;
+ }
+ dimsSeen++;
+ } else if (auto constExpr = dyn_cast<AffineConstantExpr>(result)) {
+ if (constExpr.getValue() != 0) {
+ return false;
+ }
+ } else {
+ return false;
+ }
+ }
+ return dimsSeen == map.getNumDims();
+}
+
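+/// Returns true if the candidate's outer parallel loops are compatible with
+/// the root's. Illustrative example: with root bits {1, 1, 0}, an
+/// all-parallel candidate {1, 1} matches, while a candidate {1, 0, 0} must
+/// match the root bit-for-bit and so does not.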
+static bool
+matchIteratorTypes(const llvm::SmallBitVector &rootOuterParallelLoop,
+ const llvm::SmallBitVector &candidateOuterParallelLoop) {
+ // If the candidate is not all parallel, then its loop configuration should be
+ // the same as the root.
+ if (candidateOuterParallelLoop.size() != candidateOuterParallelLoop.count()) {
+ return rootOuterParallelLoop == candidateOuterParallelLoop;
+ }
+
+ // If the candidate is all parallel, then it should be at least as parallel as
+ // the root.
+ for (int pos : llvm::seq<int>(0, rootOuterParallelLoop.size())) {
+ // If we reach the end of the outer loops of the root, break out of the
+ // loop.
+ if (!rootOuterParallelLoop.test(pos))
+ break;
+ // If the root loop is parallel, the candidate loop should also be parallel.
+ if (pos >= candidateOuterParallelLoop.size() ||
+ !candidateOuterParallelLoop.test(pos))
+ return false;
+ }
+ return true;
+}
+
+// Checks whether the op has a compatible indexing map on the outer parallel
+// loops. Currently this means the map must be the identity on those
+// dimensions, ignoring the reduction dimensions.
+static bool hasCompatibleOuterParallelLoops(
+ TilingInterface tileOp, AffineMap indexingMap,
+ const llvm::SmallBitVector &rootOuterParallelLoops) {
+ if (!indexingMap.isProjectedPermutation()) {
+ return false;
+ }
+
+ llvm::SmallBitVector parallelLoops = getOuterParallelLoops(tileOp);
+ if (!matchIteratorTypes(rootOuterParallelLoops, parallelLoops)) {
+ return false;
+ }
+
+  // Project out the non-parallel dimensions.
+ llvm::SmallBitVector projectedDims(rootOuterParallelLoops);
+ projectedDims.flip();
+ projectedDims.resize(tileOp.getLoopIteratorTypes().size(), true);
+ auto projectedMap = getProjectedMap(indexingMap, projectedDims);
+ return isIdentityMapWithZeros(projectedMap);
+}
+
+// Checks whether two ops with a producer-consumer relationship through
+// `operand` have compatible outer parallel loops.
+static bool hasCompatibleOuterParallelLoops(
+ OpOperand &operand, const llvm::SmallBitVector &rootOuterParallelLoops) {
+ auto producer =
+ operand.get().getDefiningOp<IREE::LinalgExt::LinalgFusionOpInterface>();
+ auto consumer =
+ dyn_cast<IREE::LinalgExt::LinalgFusionOpInterface>(operand.getOwner());
+ if (!producer || !consumer)
+ return false;
+
+ auto producerIndexingMap = producer.getIndexingMapMatchingResult(
+ llvm::cast<OpResult>(operand.get()));
+ auto consumerIndexingMap = consumer.getMatchingIndexingMap(&operand);
+
+ if (!producerIndexingMap || !consumerIndexingMap) {
+ return false;
+ }
+
+ return hasCompatibleOuterParallelLoops(
+ cast<TilingInterface>(producer.getOperation()),
+ producerIndexingMap, rootOuterParallelLoops) &&
+ hasCompatibleOuterParallelLoops(
+ cast<TilingInterface>(consumer.getOperation()),
+ consumerIndexingMap, rootOuterParallelLoops);
+}
+
+/// For all uses of an operation, finds the use that dominates all other uses.
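+/// Illustrative example: if %0 is used by ops A and B (ignoring `tensor.dim`
+/// users) and A dominates B, the use in A is returned; without aggressive
+/// fusion, having more than one non-dim use yields std::nullopt.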
+static std::optional<OpOperand *>
+getFusableUse(Operation *op, DominanceInfo const &dominanceInfo,
+ bool aggressiveFusion) {
+ if (!aggressiveFusion && llvm::count_if(op->getUses(), [](OpOperand &use) {
+ return !isa<tensor::DimOp>(use.getOwner());
+ }) != 1) {
+ return std::nullopt;
+ }
+
+ // Collect non-dim users.
+ SmallVector<Operation *> nonDimUsers;
+ for (Operation *user : op->getUsers()) {
+ if (isa<tensor::DimOp>(user))
+ continue;
+ nonDimUsers.push_back(user);
+ }
+
+ // Find the use in a non-dim user that dominates all other non-dim users.
+ for (auto &use : op->getUses()) {
+ Operation *user = use.getOwner();
+ if (isa<tensor::DimOp>(user))
+ continue;
+ if (llvm::all_of(nonDimUsers, [&](Operation *c) {
+ return dominanceInfo.dominates(user, c);
+ })) {
+ return &use;
+ }
+ }
+ return std::nullopt;
+}
+
+/// Returns true if the `producer` and `consumer` are fusable across all of
+/// the uses connecting them.
+static bool areOpsFusable(Operation *producer, Operation *consumer,
+ const llvm::SmallBitVector &rootOuterParallelLoops) {
+ // Collect all the uses from producer to consumer.
+ SmallVector<OpOperand *> allUses;
+ for (OpOperand &producerUse : producer->getUses()) {
+ if (producerUse.getOwner() != consumer)
+ continue;
+ allUses.push_back(&producerUse);
+ }
+
+ // Check that the consumer and producer have compatible outer parallel loops.
+ if (!llvm::all_of(allUses, [&](OpOperand *operand) {
+ return hasCompatibleOuterParallelLoops(*operand,
+ rootOuterParallelLoops);
+ })) {
+ return false;
+ }
+ return true;
+}
+
+/// For the fusion of root op -> elementwise operation to be bufferized
+/// in-place without use of extra memory, the result of the root operation
+/// must be able to reuse the buffer for the result of the elementwise
+/// operation. Check if that is possible for the input/init operand pair.
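+/// Illustrative example: in an all-parallel `linalg.generic` computing
+/// out = in + 1, where `in` and the init have identical types and indexing
+/// maps and the payload never reads the init value, the buffer of `in` can
+/// be reused as the result buffer.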
+static bool canUseInOperandAsInitOperand(OpOperand *inOperand,
+ OpOperand *initOperand) {
+ assert(inOperand->getOwner() == initOperand->getOwner() &&
+ "expected in-operand and init-operand to be owned by same operation");
+
+ // Check that the owner is a `generic` op.
+ auto genericOp = dyn_cast<linalg::GenericOp>(inOperand->getOwner());
+ if (!genericOp)
+ return false;
+
+  // All loops need to be parallel.
+ if (genericOp.getNumLoops() != genericOp.getNumParallelLoops()) {
+ return false;
+ }
+
+  // The input operand cannot already be an init operand.
+ if (genericOp.isDpsInit(inOperand))
+ return false;
+
+ // If the init operand value is used it cannot be reused for the input
+ // operand.
+ if (genericOp.payloadUsesValueFromOperand(initOperand))
+ return false;
+
+  // Indexing maps used to access the input and init have to match.
+ if (genericOp.getMatchingIndexingMap(inOperand) !=
+ genericOp.getMatchingIndexingMap(initOperand)) {
+ return false;
+ }
+
+  // Types have to match for the input operand to reuse the buffer from the
+  // init operand.
+ if (inOperand->get().getType() != initOperand->get().getType())
+ return false;
+
+ return true;
+}
+
+/// All operations in a dispatch should be vectorized, but that isn't the case
+/// today. This is an explicit list of operations that aren't vectorized yet
+/// and therefore need special handling during dispatch region formation to
+/// avoid large stack allocations.
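+/// Illustrative example: a `linalg.conv_2d_nhwc_hwcf` with non-unit strides
+/// is treated as not vectorized, so the stack allocation checks below still
+/// apply to it under aggressive fusion.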
+static bool isVectorizedAlways(Operation *producer) {
+  // TODO(#17155): This is a blacklist of operations that are not vectorized
+  // today (under the aggressive fusion flag). Remove this blacklist to return
+  // true always.
+ if (auto convOp = dyn_cast<linalg::Conv2DNhwcHwcfOp>(producer)) {
+ auto strides = convOp.getStrides();
+ return strides.isSplat() && strides.getSplatValue<int64_t>() == 1;
+ }
+ return true;
+}
+
+/// Returns true if `fusedOperand` is a fusable use when fusing a root with
+/// its consumer.
+static bool
+isFusableWithConsumer(OpOperand &fusedOperand,
+ const llvm::SmallBitVector &rootOuterParallelLoops,
+ FormDispatchRegionsPassOptions const &options) {
+ Operation *producer = fusedOperand.get().getDefiningOp();
+ Operation *consumer = fusedOperand.getOwner();
+
+  // If the consumer is a dequant operation, don't fuse it. These get cloned
+  // into their consumers.
+ if (IREE::LinalgExt::isBitExtendOp(consumer)) {
+ return false;
+ }
+
+ // Fuse unset_encoding operations with `tensor.extract_slice` and elementwise
+ // generic ops.
+ if (isUnpackLikeOpViaExtractSliceOps(producer)) {
+ // Fuse `unset_encoding` -> `extract_slice` op since they get folded into
+ // `unpack` on materialization.
+ if (isa<tensor::ExtractSliceOp>(consumer)) {
+ auto sliceOp = cast<tensor::ExtractSliceOp>(consumer);
+ return llvm::all_of(
+ sliceOp.getMixedOffsets(),
+ [](OpFoldResult ofr) { return isConstantIntValue(ofr, 0); }) &&
+ llvm::all_of(sliceOp.getMixedStrides(), [](OpFoldResult ofr) {
+ return isConstantIntValue(ofr, 1);
+ });
+ }
+ // Fuse `unset_encoding/unpack` -> elementwise operations. Fuse unpack with
+ // non-overlapping reductions (i.e., the reduction dimension is not packed).
+ if (auto consumerLinalgOp = dyn_cast<linalg::LinalgOp>(consumer)) {
+ if (hasNoPackedReductionDimensions(consumerLinalgOp, producer)) {
+ return true;
+ }
+ return linalg::isElementwise(consumerLinalgOp) &&
+ consumerLinalgOp.getNumLoops() ==
+ llvm::cast<RankedTensorType>(producer->getResult(0).getType())
+ .getRank();
+ }
+ return false;
+ }
+
+ if (isPackLikeOp(consumer)) {
+ return TypeSwitch<Operation *, bool>(producer)
+ .Case<tensor::PadOp>([&](auto padOp) { return true; })
+ .Case<linalg::LinalgOp>([&](auto linalgOp) {
+ auto producerIndexingMap = linalgOp.getIndexingMapMatchingResult(
+ llvm::cast<OpResult>(fusedOperand.get()));
+        // Make sure the producer op has an identity result indexing map, as
+        // the CPU backend currently can't handle transpose between fused ops.
+ return hasCompatibleOuterParallelLoops(
+ cast<TilingInterface>(linalgOp.getOperation()),
+ producerIndexingMap, rootOuterParallelLoops);
+ })
+ .Default([](Operation *) { return false; });
+ }
+
+  // By default, padding should be fused with producers. It is hard to square
+  // this with fusing the pad with its consumer, so for now split the
+  // difference: fuse the pad either with its producer or with its consumer.
+ if (auto padOp = dyn_cast<tensor::PadOp>(consumer)) {
+ if (options.fusePadWithProducers || isPadUsedInSetEncoding(padOp)) {
+ return isa<linalg::LinalgOp>(producer);
+ }
+ return false;
+ }
+
+ // Insert slice ops should always be fused with their producers.
+ if (auto insertSliceOp = dyn_cast<tensor::InsertSliceOp>(consumer)) {
+ // TODO: Enable multi-use slice source fusion.
+ Value source = insertSliceOp.getSource();
+ if (!source.hasOneUse() || source.getDefiningOp() != producer) {
+ return false;
+ }
+ // Fuse in `insert_slice` consumer operations if destination is a fill.
+    // TODO: This can be generalized, but the destination cannot be an
+    // `arith.constant` or other constant-like objects. `linalg.fill` captures
+    // a common case of pad generalization.
+ // common case of pad generalization.
+ return insertSliceOp.getDest().getDefiningOp<linalg::FillOp>();
+ }
+
+  // TODO(#16025): Enable mmt4d fusion. It is disabled because the backends
+  // cannot set multiple lowering_configs properly. See the issue for details.
+ if (isa<linalg::Mmt4DOp>(producer)) {
+ return false;
+ }
+
+ auto producerFusionOp =
+ dyn_cast<IREE::LinalgExt::LinalgFusionOpInterface>(producer);
+ auto consumerFusionOp =
+ dyn_cast<IREE::LinalgExt::LinalgFusionOpInterface>(consumer);
+ if (!producerFusionOp || !consumerFusionOp)
+ return false;
+
+ // Check that the consumer is all parallel.
+ if (consumerFusionOp.getNumLoops() !=
+ consumerFusionOp.getNumParallelLoops()) {
+ return false;
+ }
+
+ if (!areOpsFusable(producer, consumer, rootOuterParallelLoops)) {
+ return false;
+ }
+
+  // Check if the iteration spaces of the producer and consumer are the same.
+  // TODO(#12664): This is an unnecessary requirement, but we need a better
+  // config to tile the consumer with a larger iteration space.
+ if (!options.aggressiveFusion) {
+ auto producerIterationSpace = producerFusionOp.getStaticLoopRanges();
+ auto consumerIterationSpace = consumerFusionOp.getStaticLoopRanges();
+ if (producerIterationSpace.size() < consumerIterationSpace.size()) {
+ return false;
+ }
+ }
+
+  // Under aggressive fusion, assume that the dispatches are vectorized, in
+  // which case we don't need to account for the subsequent stack allocation
+  // condition.
+ if (options.aggressiveFusion) {
+ if (isVectorizedAlways(producer)) {
+ return true;
+ }
+ }
+
+ // While fusing with consumer, the result of the root might not be the final
+ // result of the dispatch. To avoid a stack allocation we have to ensure that
+ // all operations can bufferize without needing additional memory.
+ auto consumerDstOp =
+ dyn_cast<DestinationStyleOpInterface>(consumerFusionOp.getOperation());
+ if (!consumerDstOp) {
+ return true;
+ }
+
+ for (OpOperand *inputOperand : consumerDstOp.getDpsInputOperands()) {
+ if (inputOperand->get().getDefiningOp() != producer)
+ continue;
+ if (isa<linalg::ConvolutionOpInterface>(producer) &&
+ !llvm::any_of(
+ consumerDstOp.getDpsInitsMutable(), [&](OpOperand &initOperand) {
+ return canUseInOperandAsInitOperand(inputOperand, &initOperand);
+ })) {
+ return false;
+ }
+ }
+
+ return true;
+}
+
+/// Fuses roots with their consumers. If a root is fused with its consumer, it
+/// is no longer tagged as a root, to aid dispatch region formation.
+static void
+fuseRootsWithConsumers(MLIRContext *context, ArrayRef<Operation *> roots,
+ DominanceInfo const &dominanceInfo,
+ FormDispatchRegionsPassOptions const &options) {
+ // Fuse with consumers where possible.
+ for (Operation *root : roots) {
+ SmallVector<Operation *> workList;
+ llvm::SmallBitVector rootOuterParallelLoops = getOuterParallelLoops(root);
+ workList.push_back(root);
+ while (!workList.empty()) {
+ Operation *currRoot = workList.pop_back_val();
+ assert(hasRootOpAttribute(currRoot) &&
+ "unexpected non-root op in worklist");
+
+ // Helper function to make the consumer the root instead of the producer
+ // when they are to be fused.
+ auto updateRootTo = [&context, &currRoot](Operation *newRoot) {
+ int64_t rootNumber = getRootNumber(currRoot);
+ setRootAttribute(context, newRoot, rootNumber);
+ removeRootOpAttribute(currRoot);
+ appendToFusionGroup(currRoot, rootNumber);
+ };
+
+ std::optional<OpOperand *> fusableUse =
+ getFusableUse(currRoot, dominanceInfo,
+ /*aggressiveFusion=*/options.aggressiveFusion);
+ if (!fusableUse)
+ continue;
+
+ // Analyse the use to see if it is fusable.
+ Operation *consumerOp = fusableUse.value()->getOwner();
+ if (hasRootOpAttribute(consumerOp) ||
+ hasFusionGroupsAttribute(consumerOp)) {
+ continue;
+ }
+
+ if (isFusableWithConsumer(*(fusableUse.value()), rootOuterParallelLoops,
+ options)) {
+ updateRootTo(consumerOp);
+ workList.push_back(consumerOp);
+ }
+ }
+ }
+}
+
+/// Method to check if the consumer of a use can be fused with its producer.
+static bool
+isFusableWithProducer(OpOperand &operand,
+ const llvm::SmallBitVector &rootOuterParallelLoops,
+ FormDispatchRegionsPassOptions const &options) {
+ Operation *producer = operand.get().getDefiningOp();
+ Operation *consumer = operand.getOwner();
+
+ if (auto padOp = dyn_cast<tensor::PadOp>(consumer)) {
+ if (options.fusePadWithProducers || isPadUsedInSetEncoding(padOp)) {
+ return isa<linalg::LinalgOp>(producer);
+ }
+ return false;
+ }
+
+ if (options.fusePadWithConsumers && isa<tensor::PadOp>(producer) &&
+ isa<linalg::ConvolutionOpInterface>(consumer)) {
+ return true;
+ }
+
+  // Don't fuse attention with its producer.
+ if (isa<IREE::LinalgExt::AttentionOp>(consumer)) {
+ return false;
+ }
+
+ if (isPackLikeOp(consumer)) {
+ return TypeSwitch<Operation *, bool>(producer)
+ .Case<tensor::PadOp>([&](auto padOp) { return true; })
+ .Case<linalg::LinalgOp>([&](auto linalgOp) {
+ if (auto packOp = dyn_cast<tensor::PackOp>(consumer)) {
+          // TODO(#12746): fusion of pack with dynamic inner tile sizes
+          // causes an error in the backend. Disable for now.
+ if (!packOp.getInnerTiles().empty()) {
+ return false;
+ }
+ }
+ auto producerIndexingMap = linalgOp.getIndexingMapMatchingResult(
+ llvm::cast<OpResult>(operand.get()));
+        // Make sure the producer op has an identity result indexing map, as
+        // the CPU backend currently can't handle transpose between fused ops.
+ return hasCompatibleOuterParallelLoops(
+ cast<TilingInterface>(linalgOp.getOperation()),
+ producerIndexingMap, rootOuterParallelLoops);
+ })
+ .Default([](Operation *) { return false; });
+ }
+
+ if (!isa<IREE::LinalgExt::LinalgFusionOpInterface>(consumer) ||
+ !isa<IREE::LinalgExt::LinalgFusionOpInterface>(producer)) {
+ return false;
+ }
+
+ if (!options.aggressiveFusion) {
+ auto consumerFusionOp = dyn_cast<DestinationStyleOpInterface>(consumer);
+ if (consumerFusionOp && !consumerFusionOp.isDpsInit(&operand)) {
+ return false;
+ }
+ }
+
+ return areOpsFusable(producer, consumer, rootOuterParallelLoops);
+}
+
+/// Starting from the `root` op, traverse the operand use-def chain
+/// in reverse to fuse with producers.
+static void
+fuseRootsWithProducers(MLIRContext *context, Operation *root, unsigned groupNum,
+ DominanceInfo const &dominanceInfo,
+ FormDispatchRegionsPassOptions const &options) {
+ SmallVector<Operation *> worklist;
+ worklist.push_back(root);
+ llvm::SmallBitVector rootOuterParallelLoops = getOuterParallelLoops(root);
+ while (!worklist.empty()) {
+ Operation *candidate = worklist.pop_back_val();
+ for (OpOperand &operand : candidate->getOpOperands()) {
+ Operation *producer = operand.get().getDefiningOp();
+ if (!producer)
+ continue;
+ if (IREE::Flow::isClonableIntoDispatchOp(producer) ||
+ hasFusionGroupsAttribute(producer) || hasRootOpAttribute(producer)) {
+ continue;
+ }
+
+ std::optional<OpOperand *> fusableUse =
+ getFusableUse(producer, dominanceInfo,
+ /*aggressiveFusion=*/options.aggressiveFusion);
+ if (!fusableUse || fusableUse.value()->getOwner() != candidate)
+ continue;
+
+ if (!isFusableWithProducer(operand, rootOuterParallelLoops, options)) {
+ continue;
+ }
+
+ appendToFusionGroup(producer, groupNum);
+ worklist.push_back(producer);
+ }
+ }
+}
+
+/// Some heuristic is needed to fuse dispatchable ops with root operations
+/// using tile + fuse. Each root operation is tagged with an ID (using an
+/// IntegerAttr with name `kRootOpAttr`) and all dispatchable ops to be fused
+/// with it are tagged with the same ID (using a list of IntegerAttr with name
+/// `kFusionGroupsAttr`). Each dispatchable operation can
+/// be marked to fuse with multiple root operations (i.e. replicated). For now a
+/// very simple heuristic is used below, but the mechanism should be general
+/// enough to capture any heuristic.
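+/// Illustrative example (not from this patch): for a matmul followed by an
+/// elementwise generic, the matmul is first tagged `__root_op__ = N`;
+/// consumer fusion then transfers the root tag to the generic and tags the
+/// matmul with `__fused_op__ = [N]`, placing both in fusion group N.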
+static unsigned
+decideFusableLinalgOps(Region ®ion, DominanceInfo const &dominanceInfo,
+ FormDispatchRegionsPassOptions const &options,
+ unsigned numRootOps = 0) {
+ MLIRContext *context = region.getContext();
+ OpBuilder builder(context);
+ for (Block &block : region) {
+ // Dispatch region formation works by first cloning the root into
+ // the dispatch region and then pulling operations in.
+    // So the procedure here is to:
+    // - First find the roots.
+    // - To fuse with consumers, make the consumer the root.
+ SmallVector<Operation *> roots;
+ for (Operation &op : llvm::reverse(block)) {
+ if (isa<scf::SCFDialect>(op.getDialect())) {
+ for (auto ®ion : op.getRegions()) {
+ numRootOps = decideFusableLinalgOps(region, dominanceInfo, options,
+ numRootOps);
+ }
+ continue;
+ }
+
+ // Start with a root operation and fuse its producers.
+ if (hasFusionGroupsAttribute(&op) || !isRootOp(&op))
+ continue;
+ unsigned newGroup = numRootOps++;
+ setRootAttribute(context, &op, newGroup);
+
+ fuseRootsWithProducers(context, &op, newGroup, dominanceInfo, options);
+ roots.push_back(&op);
+ }
+ roots = llvm::to_vector(llvm::reverse(roots));
+ fuseRootsWithConsumers(context, roots, dominanceInfo, options);
+ }
+
+ // Once all root linalg ops have been tagged, put all remaining generic ops
+ // into their own dispatches.
+ for (Block &block : region) {
+ SmallVector<Operation *> roots;
+ for (Operation &op : llvm::reverse(block)) {
+ // If it is part of a fusion group or root op, ignore it.
+ if (hasFusionGroupsAttribute(&op) || hasRootOpAttribute(&op))
+ continue;
+      // Only look for Linalg ops here. Avoid moving `linalg.fill` ops that
+      // aren't fused with anything else into their own dispatches since it is
+      // better to convert them to splats. Also avoid moving
+      // dequantization-like ops into their own dispatches since it is better
+      // to clone these ops and avoid materializing large tensors between
+      // dispatches.
+ if (!isa<linalg::LinalgOp, tensor::PadOp, tensor::PackOp,
+ IREE::Encoding::SetEncodingOp>(op) ||
+ isa<linalg::FillOp>(op) || IREE::LinalgExt::isBitExtendOp(&op)) {
+ continue;
+ }
+
+ unsigned newGroup = numRootOps++;
+ setRootAttribute(context, &op, newGroup);
+
+ fuseRootsWithProducers(context, &op, newGroup, dominanceInfo, options);
+ roots.push_back(&op);
+ }
+ roots = llvm::to_vector(llvm::reverse(roots));
+ fuseRootsWithConsumers(context, roots, dominanceInfo, options);
+ }
+
+ return numRootOps;
+}
+
+//===----------------------------------------------------------------------===//
+// Dispatch region formation
+//===----------------------------------------------------------------------===//
+
+/// Create IREE::Flow::DispatchRegionOps based on a fusion heuristic.
+static LogicalResult
+createFusionGroups(TensorDimTrackingRewriter &rewriter,
+ mlir::FunctionOpInterface funcOp,
+ DominanceInfo const &dominanceInfo,
+ FormDispatchRegionsPassOptions const &options) {
+  // Step 1: Decide fusion groups (heuristic). This marks rootOps with an
+  // attribute.
+ unsigned numRoots =
+ decideFusableLinalgOps(funcOp.getFunctionBody(), dominanceInfo, options);
+ SmallVector<Operation *> roots(numRoots, nullptr);
+ DenseMap<unsigned, SmallVector<Operation *>> producers;
+
+ LLVM_DEBUG({
+ llvm::dbgs() << "\n--- After deciding fusion groups ---\n";
+ funcOp->print(llvm::dbgs(), OpPrintingFlags().useLocalScope());
+ llvm::dbgs() << "\n\n";
+ });
+
+ // TODO: Incrementally add ops to an empty DispatchGroupOp instead of
+ // annotating fusion group IDs via attributes.
+ funcOp.walk([&](Operation *op) {
+ if (hasRootOpAttribute(op)) {
+ roots[getRootNumber(op)] = op;
+ removeRootOpAttribute(op);
+ }
+ if (hasFusionGroupsAttribute(op)) {
+ assert(getFusionGroups(op).size() == 1 && "expected exactly one group");
+ producers[getFusionGroups(op).front()].push_back(op);
+ removeFusionGroupsAttribute(op);
+ }
+ });
+
+ // Step 2. Create a DispatchRegionOp for every fusion group.
+ OpBuilder::InsertionGuard g(rewriter);
+ SmallVector<IREE::Flow::DispatchRegionOp> regionOps;
+ for (const auto &it : llvm::enumerate(roots)) {
+ // Simplify tensor::DimOps.
+ {
+ SmallVector<tensor::DimOp> dimOps = rewriter.getTensorDimOps();
+ if (failed(IREE::Flow::simplifyDimOps(rewriter, dimOps))) {
+ return failure();
+ }
+ }
+
+ // Create fusion group.
+ IREE::Flow::DispatchRegionOp regionOp;
+ auto maybeRegionOp =
+ IREE::Flow::wrapOpInDispatchRegion(rewriter, it.value());
+ if (failed(maybeRegionOp))
+ return failure();
+ regionOp = *maybeRegionOp;
+
+ // Sort producers topologically. All producers must be in the same block
+ // as the root.
+ bool sortResult = mlir::computeTopologicalSorting(producers[it.index()]);
+ (void)sortResult;
+ assert(sortResult && "could not compute topological sorting");
+
+ // Move ops into the region.
+ for (Operation *producer : llvm::reverse(producers[it.index()])) {
+ // Simplify tensor::DimOps.
+ {
+ SmallVector<tensor::DimOp> dimOps = rewriter.getTensorDimOps();
+ if (failed(IREE::Flow::simplifyDimOps(rewriter, dimOps))) {
+ return failure();
+ }
+ }
+
+ auto newRegionOp =
+ movePrecedingOpsIntoDispatchRegion(rewriter, producer, regionOp);
+ if (failed(newRegionOp))
+ return failure();
+ regionOp = *newRegionOp;
+ }
+ // Simplify tensor::DimOps.
+ {
+ SmallVector<tensor::DimOp> dimOps = rewriter.getTensorDimOps();
+ if (failed(IREE::Flow::simplifyDimOps(rewriter, dimOps))) {
+ return failure();
+ }
+ }
+ regionOps.push_back(regionOp);
+ }
+
+ LLVM_DEBUG({
+ llvm::dbgs() << "\n--- After creating flow.dispatch.region ---\n";
+ funcOp->print(llvm::dbgs(), OpPrintingFlags().useLocalScope());
+ llvm::dbgs() << "\n\n";
+ });
+
+ return success();
+}
+
+namespace {
+/// Pass declaration.
+struct FormDispatchRegionsPass final
+ : public impl::FormDispatchRegionsPassBase<FormDispatchRegionsPass> {
+ using Base::Base;
+ void runOnOperation() override;
+};
+} // namespace
+
+/// Create `flow.dispatch.region` ops based on a fusion heuristic.
+void FormDispatchRegionsPass::runOnOperation() {
+ mlir::FunctionOpInterface funcOp = getOperation();
+ DominanceInfo const &dominanceInfo = getAnalysis<DominanceInfo>();
+ TensorDimTrackingRewriter rewriter(funcOp);
+ FormDispatchRegionsPassOptions options{aggressiveFusion, fusePadWithConsumers,
+ fusePadWithProducers};
+ if (failed(createFusionGroups(rewriter, funcOp, dominanceInfo, options))) {
+ funcOp->emitOpError("failed to create fusion groups");
+ return signalPassFailure();
+ }
+}
+} // namespace mlir::iree_compiler::DispatchCreation
diff --git a/compiler/src/iree/compiler/Dialect/Flow/Transforms/FormScalarDispatches.cpp b/compiler/src/iree/compiler/DispatchCreation/FormScalarDispatches.cpp
similarity index 89%
rename from compiler/src/iree/compiler/Dialect/Flow/Transforms/FormScalarDispatches.cpp
rename to compiler/src/iree/compiler/DispatchCreation/FormScalarDispatches.cpp
index fcc9d56..16e5a32 100644
--- a/compiler/src/iree/compiler/Dialect/Flow/Transforms/FormScalarDispatches.cpp
+++ b/compiler/src/iree/compiler/DispatchCreation/FormScalarDispatches.cpp
@@ -4,8 +4,9 @@
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
-#include "iree/compiler/Dialect/Flow/Transforms/Passes.h"
+#include "iree/compiler/Dialect/Flow/IR/FlowOps.h"
#include "iree/compiler/Dialect/Flow/Transforms/RegionOpUtils.h"
+#include "iree/compiler/DispatchCreation/Passes.h"
#include "llvm/ADT/SetVector.h"
#include "llvm/Support/Debug.h"
#include "mlir/Analysis/SliceAnalysis.h"
@@ -17,19 +18,18 @@
#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
#include "mlir/Transforms/RegionUtils.h"
-#define DEBUG_TYPE "iree-flow-form-scalar-dispatches"
+#define DEBUG_TYPE "iree-dispatch-creation-form-scalar-dispatches"
-namespace mlir::iree_compiler::IREE::Flow {
+namespace mlir::iree_compiler::DispatchCreation {
#define GEN_PASS_DEF_FORMSCALARDISPATCHESPASS
-#include "iree/compiler/Dialect/Flow/Transforms/Passes.h.inc"
+#include "iree/compiler/DispatchCreation/Passes.h.inc"
namespace {
/// Pass declaration.
-struct FormScalarDispatchesPass
- : public IREE::Flow::impl::FormScalarDispatchesPassBase<
- FormScalarDispatchesPass> {
+struct FormScalarDispatchesPass final
+ : public impl::FormScalarDispatchesPassBase<FormScalarDispatchesPass> {
void runOnOperation() override;
};
} // namespace
@@ -89,7 +89,7 @@
// 3. Do not move operations that are cloned into the dispatch region.
// TODO: This might prevent moving all scalar operations into dispatch
// resulting in artificial splits. Revisit after more examples.
- return !isClonableIntoDispatchOp(op);
+ return !IREE::Flow::isClonableIntoDispatchOp(op);
}
/// Given a `rootOp` return a DAG of the program that represents
@@ -127,22 +127,22 @@
/// Return `true` if the op is to be treated as a root of a scalar dispatch.
static bool isSliceRoot(int workload, Operation *op) {
- return !op->getParentOfType<DispatchRegionOp>() &&
+ return !op->getParentOfType<IREE::Flow::DispatchRegionOp>() &&
isScalarOperation(workload, op);
}
// Form dispatch regions from slice of the operation.
-static FailureOr<DispatchRegionOp>
+static FailureOr<IREE::Flow::DispatchRegionOp>
formDispatchRegionFromSlice(RewriterBase &rewriter, Operation *rootOp,
ArrayRef<Operation *> slice) {
OpBuilder::InsertionGuard g(rewriter);
rewriter.setInsertionPoint(rootOp);
- FailureOr<DispatchRegionOp> dispatchRegionOp =
- wrapOpInDispatchRegion(rewriter, rootOp);
+ FailureOr<IREE::Flow::DispatchRegionOp> dispatchRegionOp =
+ IREE::Flow::wrapOpInDispatchRegion(rewriter, rootOp);
if (failed(dispatchRegionOp)) {
return rootOp->emitOpError("failed to form dispatch region with root op");
}
- FailureOr<DispatchRegionOp> newDispatchOp =
+ FailureOr<IREE::Flow::DispatchRegionOp> newDispatchOp =
movePrecedingOpsIntoDispatchRegion(rewriter, slice,
dispatchRegionOp.value());
if (failed(newDispatchOp)) {
@@ -268,8 +268,9 @@
IRRewriter rewriter(context);
for (auto &currDispatch : dispatches) {
rewriter.setInsertionPoint(currDispatch.rootOp);
- FailureOr<DispatchRegionOp> dispatchRegionOp = formDispatchRegionFromSlice(
- rewriter, currDispatch.rootOp, currDispatch.fusedOps);
+ FailureOr<IREE::Flow::DispatchRegionOp> dispatchRegionOp =
+ formDispatchRegionFromSlice(rewriter, currDispatch.rootOp,
+ currDispatch.fusedOps);
if (failed(dispatchRegionOp)) {
currDispatch.rootOp->emitOpError(
"failed to form scalar dispatch region with operation as root");
@@ -284,9 +285,9 @@
rewriter.setInsertionPointToStart(countBody);
auto one = rewriter.create<arith::ConstantIndexOp>(
dispatchRegionOp.value()->getLoc(), 1);
- rewriter.create<Flow::ReturnOp>(dispatchRegionOp.value()->getLoc(),
- ValueRange{one, one, one});
+ rewriter.create<IREE::Flow::ReturnOp>(dispatchRegionOp.value()->getLoc(),
+ ValueRange{one, one, one});
}
}
-} // namespace mlir::iree_compiler::IREE::Flow
+} // namespace mlir::iree_compiler::DispatchCreation
diff --git a/compiler/src/iree/compiler/Dialect/Flow/Transforms/FuseHorizontalContractions.cpp b/compiler/src/iree/compiler/DispatchCreation/FuseHorizontalContractions.cpp
similarity index 98%
rename from compiler/src/iree/compiler/Dialect/Flow/Transforms/FuseHorizontalContractions.cpp
rename to compiler/src/iree/compiler/DispatchCreation/FuseHorizontalContractions.cpp
index 7561627..a78b6b8 100644
--- a/compiler/src/iree/compiler/Dialect/Flow/Transforms/FuseHorizontalContractions.cpp
+++ b/compiler/src/iree/compiler/DispatchCreation/FuseHorizontalContractions.cpp
@@ -5,9 +5,9 @@
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
#include "iree/compiler/Dialect/Flow/IR/FlowOps.h"
-#include "iree/compiler/Dialect/Flow/Transforms/Passes.h"
#include "iree/compiler/Dialect/Flow/Transforms/RegionOpUtils.h"
#include "iree/compiler/Dialect/LinalgExt/Utils/Utils.h"
+#include "iree/compiler/DispatchCreation/Passes.h"
#include "mlir/Analysis/SliceAnalysis.h"
#include "mlir/Analysis/TopologicalSortUtils.h"
#include "mlir/Dialect/Arith/IR/Arith.h"
@@ -29,16 +29,16 @@
#include "mlir/Support/LogicalResult.h"
#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
-#define DEBUG_TYPE "iree-flow-fuse-horizontal-contractions"
+#define DEBUG_TYPE "iree-dispatch-creation-fuse-horizontal-contractions"
-namespace mlir::iree_compiler::IREE::Flow {
+namespace mlir::iree_compiler::DispatchCreation {
#define GEN_PASS_DEF_FUSEHORIZONTALCONTRACTIONSPASS
-#include "iree/compiler/Dialect/Flow/Transforms/Passes.h.inc"
+#include "iree/compiler/DispatchCreation/Passes.h.inc"
namespace {
-struct FuseHorizontalContractionsPass
+struct FuseHorizontalContractionsPass final
: public impl::FuseHorizontalContractionsPassBase<
FuseHorizontalContractionsPass> {
using Base::Base;
@@ -678,4 +678,4 @@
return signalPassFailure();
}
}
-} // namespace mlir::iree_compiler::IREE::Flow
+} // namespace mlir::iree_compiler::DispatchCreation
diff --git a/compiler/src/iree/compiler/Dialect/Flow/Transforms/FuseMultiUseElementwiseProducer.cpp b/compiler/src/iree/compiler/DispatchCreation/FuseMultiUseElementwiseProducer.cpp
similarity index 93%
rename from compiler/src/iree/compiler/Dialect/Flow/Transforms/FuseMultiUseElementwiseProducer.cpp
rename to compiler/src/iree/compiler/DispatchCreation/FuseMultiUseElementwiseProducer.cpp
index f31cfaa..9d9d477 100644
--- a/compiler/src/iree/compiler/Dialect/Flow/Transforms/FuseMultiUseElementwiseProducer.cpp
+++ b/compiler/src/iree/compiler/DispatchCreation/FuseMultiUseElementwiseProducer.cpp
@@ -12,11 +12,11 @@
//
//===----------------------------------------------------------------------===//
-#include "iree/compiler/Dialect/Flow/Transforms/Passes.h"
#include "iree/compiler/Dialect/Flow/Transforms/RegionOpUtils.h"
#include "iree/compiler/Dialect/LinalgExt/IR/LinalgExtInterfaces.h"
#include "iree/compiler/Dialect/LinalgExt/IR/LinalgExtOps.h"
#include "iree/compiler/Dialect/LinalgExt/Utils/Utils.h"
+#include "iree/compiler/DispatchCreation/Passes.h"
#include "llvm/Support/CommandLine.h"
#include "llvm/Support/Debug.h"
#include "mlir/Analysis/TopologicalSortUtils.h"
@@ -31,12 +31,12 @@
#include "mlir/IR/Iterators.h"
#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
-#define DEBUG_TYPE "iree-flow-fusion-of-tensor-ops"
+#define DEBUG_TYPE "iree-dispatch-creation-fusion-of-tensor-ops"
-namespace mlir::iree_compiler::IREE::Flow {
+namespace mlir::iree_compiler::DispatchCreation {
#define GEN_PASS_DEF_FUSEMULTIUSEELEMENTWISEPRODUCERPASS
-#include "iree/compiler/Dialect/Flow/Transforms/Passes.h.inc"
+#include "iree/compiler/DispatchCreation/Passes.h.inc"
// TODO: Remove this and the backing code once consteval is beyond being
// rolled back.
@@ -147,7 +147,7 @@
DenseMap<Operation *, Operation *> opToRootMap;
funcOp->walk<WalkOrder::PostOrder, ReverseIterator>(
[&](linalg::GenericOp genericOp) {
- if (!isNonNullAndOutsideDispatch(genericOp)) {
+ if (!IREE::Flow::isNonNullAndOutsideDispatch(genericOp)) {
return;
}
@@ -158,7 +158,7 @@
// Dequantization-like operations should be fused with consumers to keep
// the smaller bit width on the dispatch boundary.
- if (LinalgExt::isBitExtendOp(genericOp)) {
+ if (IREE::LinalgExt::isBitExtendOp(genericOp)) {
return;
}
@@ -198,7 +198,7 @@
// 7. Skip dequantization-like `producer` ops as we would rather fuse
// by cloning the producer instead of multi-use fusion.
- if (LinalgExt::isBitExtendOp(producer)) {
+ if (IREE::LinalgExt::isBitExtendOp(producer)) {
return;
}
@@ -249,12 +249,10 @@
/// Pass to fuse linalg on tensor operations as well as fusion of hal.interface*
/// operations with linalg.tensor_reshape operation.
-struct FuseMultiUseElementwiseProducerPass
- : public IREE::Flow::impl::FuseMultiUseElementwiseProducerPassBase<
+struct FuseMultiUseElementwiseProducerPass final
+ : public impl::FuseMultiUseElementwiseProducerPassBase<
FuseMultiUseElementwiseProducerPass> {
- using IREE::Flow::impl::FuseMultiUseElementwiseProducerPassBase<
- FuseMultiUseElementwiseProducerPass>::
- FuseMultiUseElementwiseProducerPassBase;
+ using Base::Base;
void runOnOperation() override;
};
@@ -281,4 +279,4 @@
}
}
-} // namespace mlir::iree_compiler::IREE::Flow
+} // namespace mlir::iree_compiler::DispatchCreation
diff --git a/compiler/src/iree/compiler/Dialect/Flow/Transforms/FusionPreprocessing.cpp b/compiler/src/iree/compiler/DispatchCreation/FusionPreprocessing.cpp
similarity index 92%
rename from compiler/src/iree/compiler/Dialect/Flow/Transforms/FusionPreprocessing.cpp
rename to compiler/src/iree/compiler/DispatchCreation/FusionPreprocessing.cpp
index 28932b3..3fdafd7 100644
--- a/compiler/src/iree/compiler/Dialect/Flow/Transforms/FusionPreprocessing.cpp
+++ b/compiler/src/iree/compiler/DispatchCreation/FusionPreprocessing.cpp
@@ -10,9 +10,9 @@
//
//===----------------------------------------------------------------------===//
-#include "iree/compiler/Dialect/Flow/Transforms/Passes.h"
#include "iree/compiler/Dialect/Flow/Transforms/RegionOpUtils.h"
#include "iree/compiler/Dialect/LinalgExt/Utils/Utils.h"
+#include "iree/compiler/DispatchCreation/Passes.h"
#include "llvm/ADT/STLExtras.h"
#include "llvm/ADT/SmallVector.h"
#include "llvm/Support/Casting.h"
@@ -31,10 +31,10 @@
#include "mlir/Support/LogicalResult.h"
#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
-namespace mlir::iree_compiler::IREE::Flow {
+namespace mlir::iree_compiler::DispatchCreation {
#define GEN_PASS_DEF_FUSIONPREPROCESSINGPASS
-#include "iree/compiler/Dialect/Flow/Transforms/Passes.h.inc"
+#include "iree/compiler/DispatchCreation/Passes.h.inc"
namespace {
@@ -43,9 +43,9 @@
//===----------------------------------------------------------------------===//
// If possible, interchange indexing maps to make input maps all identity.
-struct ElementwiseOpInterchangePattern
+struct ElementwiseOpInterchangePattern final
: public OpRewritePattern<linalg::GenericOp> {
- using OpRewritePattern<linalg::GenericOp>::OpRewritePattern;
+ using OpRewritePattern::OpRewritePattern;
LogicalResult matchAndRewrite(linalg::GenericOp genericOp,
PatternRewriter &rewriter) const override {
if (!linalg::isElementwise(genericOp) || genericOp.getNumResults() != 1 ||
@@ -95,7 +95,7 @@
/// %2 = linalg.fill ins(%cst : )
/// %3 = tensor.insert_slice %a into %2
/// ```
-struct FoldSuccessiveTensorInsertSliceOps
+struct FoldSuccessiveTensorInsertSliceOps final
: public OpRewritePattern<tensor::InsertSliceOp> {
using OpRewritePattern::OpRewritePattern;
LogicalResult matchAndRewrite(tensor::InsertSliceOp sliceOp,
@@ -157,8 +157,8 @@
// cannot be fused because there is no producer-consumer
// relationship between the two generics. This is because the indexing
// is not affine (index values come from a tensor).
-struct GatherFusionPattern : public OpRewritePattern<tensor::ExtractOp> {
- using OpRewritePattern<tensor::ExtractOp>::OpRewritePattern;
+struct GatherFusionPattern final : public OpRewritePattern<tensor::ExtractOp> {
+ using OpRewritePattern::OpRewritePattern;
LogicalResult matchAndRewrite(tensor::ExtractOp extractOp,
PatternRewriter &rewriter) const override {
// Check if extractOp is inside a generic op
@@ -177,7 +177,8 @@
// Check if the producerOp is fusible
if (producerOp.getNumDpsInputs() != 1 || producerOp.getNumResults() != 1 ||
- !isElementwise(producerOp) || !LinalgExt::isBitExtendOp(producerOp)) {
+ !isElementwise(producerOp) ||
+ !IREE::LinalgExt::isBitExtendOp(producerOp)) {
return rewriter.notifyMatchFailure(producerOp,
"producer op is not fusible");
}
@@ -211,9 +212,8 @@
}
};
-struct FusionPreprocessingPass
- : public IREE::Flow::impl::FusionPreprocessingPassBase<
- FusionPreprocessingPass> {
+struct FusionPreprocessingPass final
+ : public impl::FusionPreprocessingPassBase<FusionPreprocessingPass> {
void runOnOperation() override {
RewritePatternSet patterns(&getContext());
patterns.add<ElementwiseOpInterchangePattern,
@@ -234,4 +234,4 @@
} // namespace
-} // namespace mlir::iree_compiler::IREE::Flow
+} // namespace mlir::iree_compiler::DispatchCreation
diff --git a/compiler/src/iree/compiler/Dialect/Flow/Transforms/FusionUtils.cpp b/compiler/src/iree/compiler/DispatchCreation/FusionUtils.cpp
similarity index 93%
rename from compiler/src/iree/compiler/Dialect/Flow/Transforms/FusionUtils.cpp
rename to compiler/src/iree/compiler/DispatchCreation/FusionUtils.cpp
index 70f4a17..c428091 100644
--- a/compiler/src/iree/compiler/Dialect/Flow/Transforms/FusionUtils.cpp
+++ b/compiler/src/iree/compiler/DispatchCreation/FusionUtils.cpp
@@ -7,12 +7,12 @@
//===--- FusionUtils.h --- Implementation of fusion utility functions -----===//
//===----------------------------------------------------------------------===//
-#include "compiler/src/iree/compiler/Dialect/Flow/Transforms/FusionUtils.h"
+#include "compiler/src/iree/compiler/DispatchCreation/FusionUtils.h"
#include "compiler/src/iree/compiler/Dialect/Flow/Transforms/RegionOpUtils.h"
#include "iree/compiler/Dialect/LinalgExt/Utils/Utils.h"
#include "mlir/Dialect/Linalg/IR/Linalg.h"
-namespace mlir::iree_compiler::IREE::Flow {
+namespace mlir::iree_compiler::DispatchCreation {
bool areFusableAsElementwiseOps(MLIRContext *context, OpOperand *fusedOperand,
bool fuseMultiReduction) {
@@ -65,7 +65,7 @@
// (except for bit-extend ops). If the consumer has only one use, then this
// fusion is fine since cloning wont result in redundant computation of the
// producer. (Also note that the producer is always an elementwise operation).
- if (LinalgExt::isBitExtendOp(consumerOp) && !consumerOp->hasOneUse()) {
+ if (IREE::LinalgExt::isBitExtendOp(consumerOp) && !consumerOp->hasOneUse()) {
return false;
}
@@ -97,4 +97,4 @@
return true;
}
-} // namespace mlir::iree_compiler::IREE::Flow
+} // namespace mlir::iree_compiler::DispatchCreation
diff --git a/compiler/src/iree/compiler/Dialect/Flow/Transforms/FusionUtils.h b/compiler/src/iree/compiler/DispatchCreation/FusionUtils.h
similarity index 87%
rename from compiler/src/iree/compiler/Dialect/Flow/Transforms/FusionUtils.h
rename to compiler/src/iree/compiler/DispatchCreation/FusionUtils.h
index dd4f96a..1d9c930 100644
--- a/compiler/src/iree/compiler/Dialect/Flow/Transforms/FusionUtils.h
+++ b/compiler/src/iree/compiler/DispatchCreation/FusionUtils.h
@@ -12,11 +12,11 @@
#include "mlir/IR/Operation.h"
-namespace mlir::iree_compiler::IREE::Flow {
+namespace mlir::iree_compiler::DispatchCreation {
/// Return true if the producer and consumer of `operand` are fusable
/// using elementwise op fusion transformation.
bool areFusableAsElementwiseOps(MLIRContext *context, OpOperand *operand,
bool fuseMultiReduction);
-} // namespace mlir::iree_compiler::IREE::Flow
+} // namespace mlir::iree_compiler::DispatchCreation
diff --git a/compiler/src/iree/compiler/Dialect/Flow/Transforms/HoistEncodingOps.cpp b/compiler/src/iree/compiler/DispatchCreation/HoistEncodingOps.cpp
similarity index 82%
rename from compiler/src/iree/compiler/Dialect/Flow/Transforms/HoistEncodingOps.cpp
rename to compiler/src/iree/compiler/DispatchCreation/HoistEncodingOps.cpp
index 0068744..ce30ab1 100644
--- a/compiler/src/iree/compiler/Dialect/Flow/Transforms/HoistEncodingOps.cpp
+++ b/compiler/src/iree/compiler/DispatchCreation/HoistEncodingOps.cpp
@@ -8,9 +8,9 @@
#include "iree/compiler/Dialect/Encoding/IR/EncodingOps.h"
#include "iree/compiler/Dialect/Flow/IR/FlowDialect.h"
#include "iree/compiler/Dialect/Flow/IR/FlowOps.h"
-#include "iree/compiler/Dialect/Flow/Transforms/Passes.h"
#include "iree/compiler/Dialect/Flow/Transforms/RegionOpUtils.h"
#include "iree/compiler/Dialect/LinalgExt/Utils/Utils.h"
+#include "iree/compiler/DispatchCreation/Passes.h"
#include "llvm/ADT/STLExtras.h"
#include "llvm/Support/Debug.h"
#include "mlir/Dialect/Linalg/IR/Linalg.h"
@@ -29,11 +29,11 @@
#include "mlir/Support/LogicalResult.h"
#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
-#define DEBUG_TYPE "iree-flow-hoist-encoding-ops"
+#define DEBUG_TYPE "iree-dispatch-creation-hoist-encoding-ops"
-namespace mlir::iree_compiler::IREE::Flow {
+namespace mlir::iree_compiler::DispatchCreation {
#define GEN_PASS_DEF_HOISTENCODINGOPSPASS
-#include "iree/compiler/Dialect/Flow/Transforms/Passes.h.inc"
+#include "iree/compiler/DispatchCreation/Passes.h.inc"
static AffineMap getBcastMapOrIdentity(RewriterBase &rewriter,
RankedTensorType encodedType) {
@@ -63,7 +63,7 @@
/// be well tested.
static LogicalResult
bubbleUpSetEncodingThroughGenericOp(RewriterBase &rewriter,
- Encoding::SetEncodingOp encodingOp,
+ IREE::Encoding::SetEncodingOp encodingOp,
linalg::GenericOp genericOp) {
if (!genericOp->hasOneUse()) {
return rewriter.notifyMatchFailure(genericOp,
@@ -111,8 +111,8 @@
auto operandType = cast<RankedTensorType>(operand->get().getType());
auto resType = RankedTensorType::get(
operandType.getShape(), operandType.getElementType(), newEncoding);
- Value encodedInput =
- rewriter.create<Encoding::SetEncodingOp>(loc, resType, operand->get());
+ Value encodedInput = rewriter.create<IREE::Encoding::SetEncodingOp>(
+ loc, resType, operand->get());
encodedOperands.push_back(encodedInput);
}
@@ -131,14 +131,14 @@
static LogicalResult bubbleUpSetEncoding(RewriterBase &rewriter,
OpOperand &operand) {
- auto setEncoding = cast<Encoding::SetEncodingOp>(operand.getOwner());
+ auto setEncoding = cast<IREE::Encoding::SetEncodingOp>(operand.getOwner());
auto producer = operand.get().getDefiningOp<linalg::GenericOp>();
if (!producer) {
return failure();
}
// Only bubble through dequantization ops and broadcasting ops for now.
- if (!LinalgExt::isBitExtendOp(producer) &&
- !LinalgExt::isBroadcastingOp(producer)) {
+ if (!IREE::LinalgExt::isBitExtendOp(producer) &&
+ !IREE::LinalgExt::isBroadcastingOp(producer)) {
return failure();
}
return bubbleUpSetEncodingThroughGenericOp(rewriter, setEncoding, producer);
@@ -147,9 +147,8 @@
namespace {
/// Pass declaration.
struct HoistEncodingOpsPass
- : public IREE::Flow::impl::HoistEncodingOpsPassBase<HoistEncodingOpsPass> {
- using IREE::Flow::impl::HoistEncodingOpsPassBase<
- HoistEncodingOpsPass>::HoistEncodingOpsPassBase;
+ : public impl::HoistEncodingOpsPassBase<HoistEncodingOpsPass> {
+ using Base::Base;
void runOnOperation() override;
};
@@ -157,12 +156,12 @@
/// runs until bubbling is not possible, or until the SetEncoding op is outside
/// of a dispatch.
struct BubbleUpSetEncodingOp
- : public OpRewritePattern<Encoding::SetEncodingOp> {
- using OpRewritePattern<Encoding::SetEncodingOp>::OpRewritePattern;
+ : public OpRewritePattern<IREE::Encoding::SetEncodingOp> {
+ using OpRewritePattern<IREE::Encoding::SetEncodingOp>::OpRewritePattern;
- LogicalResult matchAndRewrite(Encoding::SetEncodingOp encodingOp,
+ LogicalResult matchAndRewrite(IREE::Encoding::SetEncodingOp encodingOp,
PatternRewriter &rewriter) const override {
- if (isNonNullAndOutsideDispatch(encodingOp)) {
+ if (IREE::Flow::isNonNullAndOutsideDispatch(encodingOp)) {
return failure();
}
// Fail if the encodingOp is not in the same dispatch as its producer.
@@ -170,9 +169,10 @@
if (!producer) {
return failure();
}
- auto dispatch = producer->getParentOfType<DispatchRegionOp>();
+ auto dispatch = producer->getParentOfType<IREE::Flow::DispatchRegionOp>();
if (!dispatch ||
- dispatch != encodingOp->getParentOfType<DispatchRegionOp>()) {
+ dispatch !=
+ encodingOp->getParentOfType<IREE::Flow::DispatchRegionOp>()) {
return failure();
}
@@ -194,24 +194,24 @@
return signalPassFailure();
}
- SmallVector<Encoding::SetEncodingOp> candidates;
- funcOp->walk([&](Encoding::SetEncodingOp setEncodingOp) {
- if (setEncodingOp->getParentOfType<DispatchRegionOp>()) {
+ SmallVector<IREE::Encoding::SetEncodingOp> candidates;
+ funcOp->walk([&](IREE::Encoding::SetEncodingOp setEncodingOp) {
+ if (setEncodingOp->getParentOfType<IREE::Flow::DispatchRegionOp>()) {
candidates.push_back(setEncodingOp);
}
});
IRRewriter rewriter(ctx);
for (auto setEncodingOp : candidates) {
- if (failed(hoistOutOfDispatch(rewriter, setEncodingOp))) {
+ if (failed(IREE::Flow::hoistOutOfDispatch(rewriter, setEncodingOp))) {
return signalPassFailure();
}
}
RewritePatternSet cleanPatterns(ctx);
memref::populateResolveRankedShapedTypeResultDimsPatterns(cleanPatterns);
- DispatchRegionOp::getCanonicalizationPatterns(cleanPatterns, ctx);
+ IREE::Flow::DispatchRegionOp::getCanonicalizationPatterns(cleanPatterns, ctx);
if (failed(applyPatternsAndFoldGreedily(funcOp, std::move(cleanPatterns)))) {
return signalPassFailure();
}
}
-} // namespace mlir::iree_compiler::IREE::Flow
+} // namespace mlir::iree_compiler::DispatchCreation
diff --git a/compiler/src/iree/compiler/Dialect/Flow/Transforms/MaterializeDefaultWorkgroupCountRegion.cpp b/compiler/src/iree/compiler/DispatchCreation/MaterializeDefaultWorkgroupCountRegion.cpp
similarity index 87%
rename from compiler/src/iree/compiler/Dialect/Flow/Transforms/MaterializeDefaultWorkgroupCountRegion.cpp
rename to compiler/src/iree/compiler/DispatchCreation/MaterializeDefaultWorkgroupCountRegion.cpp
index 717edbc..a858bc2 100644
--- a/compiler/src/iree/compiler/Dialect/Flow/Transforms/MaterializeDefaultWorkgroupCountRegion.cpp
+++ b/compiler/src/iree/compiler/DispatchCreation/MaterializeDefaultWorkgroupCountRegion.cpp
@@ -8,8 +8,8 @@
#include "iree/compiler/Dialect/Flow/IR/FlowOps.h"
#include "iree/compiler/Dialect/Flow/Transforms/ConvertRegionToWorkgroups.h"
#include "iree/compiler/Dialect/Flow/Transforms/FormDispatchRegions.h"
-#include "iree/compiler/Dialect/Flow/Transforms/Passes.h"
#include "iree/compiler/Dialect/Flow/Transforms/RegionOpUtils.h"
+#include "iree/compiler/DispatchCreation/Passes.h"
#include "llvm/ADT/STLExtras.h"
#include "llvm/Support/CommandLine.h"
#include "llvm/Support/Debug.h"
@@ -23,12 +23,13 @@
#include "mlir/Interfaces/FunctionInterfaces.h"
#include "mlir/Support/LLVM.h"
-#define DEBUG_TYPE "iree-flow-materialize-default-workgroup-count-region"
+#define DEBUG_TYPE \
+ "iree-dispatch-creation-materialize-default-workgroup-count-region"
-namespace mlir::iree_compiler::IREE::Flow {
+namespace mlir::iree_compiler::DispatchCreation {
#define GEN_PASS_DEF_MATERIALIZEDEFAULTWORKGROUPCOUNTREGIONPASS
-#include "iree/compiler/Dialect/Flow/Transforms/Passes.h.inc"
+#include "iree/compiler/DispatchCreation/Passes.h.inc"
/// Creates the workgroup count region where the materialized computation
/// is derived as a program slice of the body of the dispatch. This method
@@ -99,11 +100,9 @@
namespace {
struct MaterializeDefaultWorkgroupCountRegionPass
- : public IREE::Flow::impl::MaterializeDefaultWorkgroupCountRegionPassBase<
+ : public impl::MaterializeDefaultWorkgroupCountRegionPassBase<
MaterializeDefaultWorkgroupCountRegionPass> {
- using IREE::Flow::impl::MaterializeDefaultWorkgroupCountRegionPassBase<
- MaterializeDefaultWorkgroupCountRegionPass>::
- MaterializeDefaultWorkgroupCountRegionPassBase;
+ using Base::Base;
void runOnOperation() override;
};
} // namespace
@@ -115,9 +114,9 @@
// Populate the workgroup_count region of flow.dispatch.workgroups operation
// that dont already have a region
- funcOp.walk([&](Flow::DispatchWorkgroupsOp workgroupsOp) {
+ funcOp.walk([&](IREE::Flow::DispatchWorkgroupsOp workgroupsOp) {
createDefaultWorkgroupCountRegion(rewriter, workgroupsOp);
});
}
-} // namespace mlir::iree_compiler::IREE::Flow
+} // namespace mlir::iree_compiler::DispatchCreation
diff --git a/compiler/src/iree/compiler/DispatchCreation/Passes.cpp b/compiler/src/iree/compiler/DispatchCreation/Passes.cpp
new file mode 100644
index 0000000..9f18ceb
--- /dev/null
+++ b/compiler/src/iree/compiler/DispatchCreation/Passes.cpp
@@ -0,0 +1,351 @@
+// Copyright 2024 The IREE Authors
+//
+// Licensed under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+
+#include "iree/compiler/DispatchCreation/Passes.h"
+
+#include "iree/compiler/Dialect/Flow/Transforms/Passes.h"
+#include "iree/compiler/Dialect/Util/IR/UtilOps.h"
+#include "iree/compiler/Dialect/Util/Transforms/Passes.h"
+#include "iree/compiler/Utils/PassUtils.h"
+#include "mlir/Dialect/Func/IR/FuncOps.h"
+#include "mlir/Dialect/Linalg/Passes.h"
+#include "mlir/Pass/PassOptions.h"
+#include "mlir/Pass/PassRegistry.h"
+#include "mlir/Transforms/Passes.h"
+
+//===----------------------------------------------------------------------===//
+// Command Line Options
+//===----------------------------------------------------------------------===//
+
+static llvm::cl::opt<std::string> clDispatchTransformFileName(
+ "iree-dispatch-creation-dispatch-use-transform-dialect",
+ llvm::cl::desc("MLIR file containing a top-level module that specifies "
+ "the transformations to apply to form dispatch regions."),
+ llvm::cl::init(""));
+
+static llvm::cl::opt<bool> clDetensoring(
+ "iree-dispatch-creation-enable-detensoring",
+ llvm::cl::desc(
+ "Enable changing of tensor operations into scalar operations."),
+ llvm::cl::init(false));
+
+static llvm::cl::opt<bool> clEnableElementWiseFuseMultiReduction(
+ "iree-dispatch-creation-element-wise-fuse-multi-reduction",
+ llvm::cl::desc("Enable element-wise fusion of multi-reduction loop ops."),
+ llvm::cl::init(true));
+
+static llvm::cl::opt<bool> clEnableFusePaddingIntoLinalgConsumerOps(
+ "iree-dispatch-creation-enable-fuse-padding-into-linalg-consumer-ops",
+ llvm::cl::desc("Enable fusing tensor.pad ops into Linalg consumer ops."),
+ llvm::cl::init(false));
+
+static llvm::cl::opt<bool> clEnableFusePaddingIntoLinalgProducerOps(
+ "iree-dispatch-creation-enable-fuse-padding-into-linalg-producer-ops",
+ llvm::cl::desc("Enable fusing tensor.pad ops into Linalg consumer ops."),
+ llvm::cl::init(false));
+
+static llvm::cl::opt<int> clPadFactor(
+ "iree-dispatch-creation-pad-factor",
+ llvm::cl::desc(
+ "Provides padding size hints that will be attached to "
+ "encodings. This only affects the experimental data tiling "
+ "path in Flow with iree-dispatch-creation-experimental-data-tiling."),
+ llvm::cl::init(32));
+
+static llvm::cl::opt<bool> clEnablePadHandling(
+ "iree-flow-enable-pad-handling",
+ llvm::cl::desc("Enable native handling of tensor.pad operations."),
+ llvm::cl::init(false));
+
+static llvm::cl::opt<bool> clEnableFuseHorizontalContractions(
+ "iree-dispatch-creation-enable-fuse-horizontal-contractions",
+ llvm::cl::desc(
+ "Enables horizontal fusion of contractions with one common operand"),
+ llvm::cl::init(false));
+
+static llvm::cl::opt<bool> clCollapseReductionDims(
+ "iree-dispatch-creation-collapse-reduction-dims",
+ llvm::cl::desc("Enable collapsing of reduction dims"),
+ llvm::cl::init(false));
+
+static llvm::cl::opt<bool>
+ clEnableFuseMultiUse("iree-dispatch-creation-fuse-multi-use",
+ llvm::cl::desc("Fuse multi-use ops."),
+ llvm::cl::init(false));
+
+static llvm::cl::opt<bool> clEnableAggressiveFusion(
+ "iree-dispatch-creation-enable-aggressive-fusion",
+ llvm::cl::desc("Aggressive fusion opportunities that are behind a flag "
+ "since all backends dont support it yet"),
+ llvm::cl::init(false));
+
+static llvm::cl::opt<bool> clEnableDataTiling(
+ "iree-dispatch-creation-experimental-data-tiling",
+ llvm::cl::desc("Enable data-tiling at flow level, i.e., it sets encodings "
+ "in dispatch regions, hoist them out of region, and enables "
+ "fusion for the set_encodings. This is still an "
+ "experimental path. The current main data tiling path is "
+ "iree-opt-data-tiling, which is on by default. To use this "
+ "path, --iree-opt-data-tiling=false must be set as wells"),
+ llvm::cl::init(false));
+
+namespace mlir::iree_compiler::DispatchCreation {
+
+//===----------------------------------------------------------------------===//
+// Utilities
+//===----------------------------------------------------------------------===//
+using FunctionLikeNest =
+ MultiOpNest<func::FuncOp, IREE::Util::InitializerOp, IREE::Util::FuncOp>;
+
+static void addCleanupPatterns(OpPassManager &passManager) {
+ FunctionLikeNest(passManager)
+ // Standard MLIR cleanup.
+ .addPass(IREE::Flow::createCanonicalizerPass)
+ .addPass(mlir::createCSEPass)
+
+ // Simplify util.global accesses; this can help with data flow tracking as
+ // redundant store-loads are removed.
+ .addPass(IREE::Util::createSimplifyGlobalAccessesPass);
+
+ // Cleanup and canonicalization of util.global (and other util ops).
+ passManager.addPass(IREE::Util::createApplyPatternsPass());
+ passManager.addPass(IREE::Util::createFoldGlobalsPass());
+ passManager.addPass(IREE::Util::createFuseGlobalsPass());
+
+ // Large IPO pass. Note that this can introduce a significant amount of
+ // duplication/inlined constants and we'll want to ensure we're running
+ // cleanup again after (this entire set of patterns is run in a fixed-point
+ // iteration to do that).
+ passManager.addPass(IREE::Util::createIPOPass());
+}
+
+//===----------------------------------------------------------------------===//
+// Pipelines
+//===----------------------------------------------------------------------===//
+
+void addDispatchRegionCreationPreprocessingPasses(OpPassManager &passManager) {
+  // 1. Do some simple elementwise op fusion. This could be skipped, but
+  //    fusing here reduces the surface area of ops to handle later.
+ FunctionLikeNest(passManager)
+ .addPass([]() {
+ return DispatchCreation::createElementwiseOpFusionPass(
+ ElementwiseOpFusionPassOptions{
+ clEnableElementWiseFuseMultiReduction});
+ })
+ .addPass(IREE::Flow::createCanonicalizerPass)
+ .addPass(mlir::createCSEPass)
+
+      // 2. Bubble up expand_shape ops (or sink collapse_shape ops) to get
+      //    elementwise operations into higher dimensions for more fusion
+      //    opportunities.
+ .addPass(DispatchCreation::createBubbleUpExpandShapesPass)
+ .addPass(IREE::Flow::createCanonicalizerPass)
+ .addPass(mlir::createCSEPass)
+
+ // 3. Perform elementwise operation fusion again (now with higher
+ // dimensionality).
+ .addPass([]() {
+ return DispatchCreation::createElementwiseOpFusionPass(
+ ElementwiseOpFusionPassOptions{
+ clEnableElementWiseFuseMultiReduction});
+ })
+ .addPass(IREE::Flow::createCanonicalizerPass)
+ .addPass(mlir::createCSEPass)
+
+      // 4. After elementwise operation fusion, sink reshapes that block
+      //    producer-consumer fusion.
+ .addPass(DispatchCreation::createSinkReshapesPass)
+ .addPass(IREE::Flow::createCanonicalizerPass)
+ .addPass(mlir::createCSEPass);
+
+ if (clEnableFuseHorizontalContractions) {
+ FunctionLikeNest(passManager)
+ .addPass(createFuseHorizontalContractionsPass)
+ .addPass(mlir::createCanonicalizerPass)
+ .addPass(mlir::createCSEPass);
+ }
+
+ FunctionLikeNest(passManager)
+ // 5. After all the reshape propagations, fuse elementwise operations
+ // even if the producer has multiple uses.
+ .addPass(DispatchCreation::createFuseMultiUseElementwiseProducerPass)
+
+ // 6. Some more "post elementwise fusion passes".
+ // a. Detensorize.
+ // TODO: This is probably not in the right place.
+ .addPredicatedPass(clDetensoring,
+ [&]() { return mlir::createLinalgDetensorizePass(); })
+ .addPass(IREE::Flow::createCanonicalizerPass)
+ .addPass(mlir::createCSEPass)
+
+      // b. For ops with multiple reduction dimensions, collapse the
+      //    reduction dimensions into one.
+      // TODO: This pass is only needed until all backends can handle
+      //    multiple reduction dimensions.
+ .addPredicatedPass(
+ clCollapseReductionDims,
+ DispatchCreation::createCollapseReductionDimensionsPass)
+
+      // c. Split reduction operations into a parallel part and a
+      //    reduction part.
+ .addPass(DispatchCreation::createSplitReductionPass)
+
+ // d. Transpose generic ops to
+ // - help with dispatch region formation.
+ // - move reduction iterators to be innermost.
+ .addPass(DispatchCreation::createTransposeGenericOpsPass);
+}
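+
+// Usage sketch (illustrative only; `context` and `moduleOp` are assumed to
+// exist): the preprocessing pipeline can be exercised on its own, e.g. from
+// a test driver:
+//   mlir::PassManager pm(context, mlir::ModuleOp::getOperationName());
+//   DispatchCreation::addDispatchRegionCreationPreprocessingPasses(pm);
+//   if (failed(pm.run(moduleOp))) { /* handle failure */ }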
+
+// Pipeline to first create `flow.dispatch.region` ops and then lower to
+// `flow.dispatch.workgroup` ops.
+static void addDispatchRegionCreationPasses(OpPassManager &passManager) {
+ FunctionLikeNest(passManager)
+      // We only want to use the transform dialect for some dispatch regions
+      // and let FormDispatchRegions handle the rest. This only moves the root
+      // compute op into the dispatch region, so that we can run additional
+      // transformations afterwards with a simple region and without bothering
+      // producers.
+ .addPredicatedPass(
+ !clDispatchTransformFileName.empty(),
+ [&]() {
+ DispatchWithTransformDialectPassOptions options;
+ options.transformSpecPath = clDispatchTransformFileName;
+ return createDispatchWithTransformDialectPass(options);
+ })
+ // Create dispatches for scalar operations as roots
+ .addPass(DispatchCreation::createFormScalarDispatchesPass)
+ // Create `flow.dispatch.region` centered around a root and fuse with
+ // producers and consumers.
+ .addPass([&]() {
+ return DispatchCreation::createFormDispatchRegionsPass(
+ FormDispatchRegionsPassOptions{
+ clEnableAggressiveFusion,
+ clEnableFusePaddingIntoLinalgConsumerOps,
+ clEnableFusePaddingIntoLinalgProducerOps});
+ })
+      // Clone all producers into the dispatch region to prepare for being
+      // isolated from above. This enables running additional transformations
+      // afterwards that would need the full dispatch content but don't want
+      // to handle explicit captures materialized as dispatch workgroup
+      // operands and block arguments.
+ .addPass(DispatchCreation::createCloneProducersIntoDispatchRegionsPass);
+
+ // Experimental data tiling path. The intent of this path is to set encodings
+ // after fusion decisions have already been made, so encodings can be
+ // separated from compiler fusion decisions.
+ if (clEnableDataTiling) {
+ SetEncodingPassOptions options{clPadFactor};
+ FunctionLikeNest(passManager)
+ // Set encodings on all eligible ops. All ops should be in compiler
+ // formed dispatch regions, so encodings will be placed inside of the
+ // dispatch regions with the data-tiled op.
+ .addPass([&]() { return createSetEncodingPass(options); })
+ // SetEncodingOps should not be in the same dispatch as the data-tiled
+ // op, so hoist them out of their current dispatch regions. Also, bubble
+ // SetEncodingOps through special operations like bit-extending ops and
+ // broadcasting ops.
+ .addPass(DispatchCreation::createHoistEncodingOpsPass);
+ }
+ FunctionLikeNest(passManager)
+ // Collapse dimensions of linalg Ops.
+ .addPass(DispatchCreation::createCollapseDimensionsPass);
+}
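+
+// Illustrative before/after (simplified IR; shapes and operands elided):
+// after region formation and producer cloning, a fused root such as
+//   %fill = linalg.fill ins(%cst : f32) outs(%init : tensor<?x?xf32>) ...
+//   %gemm = linalg.matmul ins(%a, %b : ...) outs(%fill : tensor<?x?xf32>)
+// ends up wrapped as
+//   %r = flow.dispatch.region -> (tensor<?x?xf32>{%d0, %d1}) {
+//     %fill = linalg.fill ...
+//     %gemm = linalg.matmul ...
+//     flow.return %gemm : tensor<?x?xf32>
+//   }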
+
+// Apply preprocessing and form dispatch regions
+void buildDispatchCreationPassPipeline(
+ OpPassManager &passManager, const TransformOptions &transformOptions) {
+
+ // Inject tensor tracing early as we need to have the tracers in the IR
+ // prior to dispatch region formation where we may lose access to them.
+ FunctionLikeNest(passManager)
+ .addPass(IREE::Flow::createInjectTensorTracingPass);
+
+ // Transform pad operations into linalg.fill + tensor.insert_slice.
+ // This is a WAR for not having native pad handling.
+ if (!clEnablePadHandling && !clEnableFusePaddingIntoLinalgProducerOps) {
+ passManager.addPass(
+ DispatchCreation::createTensorPadToTensorInsertSlicePass(
+ TensorPadToTensorInsertSlicePassOptions{
+ /*skipSingleLinalgOpUses=*/
+ clEnableFusePaddingIntoLinalgConsumerOps}));
+ }
+
+ {
+ // We run these under a fixed-point iteration such that we can perform
+ // inter-procedural, intra-procedural, and canonicalization as separably
+ // verifiable/reusable passes. IPO will fold duplicate arguments/results
+ // and inline constants to allow the local optimizations to work more
+ // effectively.
+ OpPassManager ipoPipeline(mlir::ModuleOp::getOperationName());
+
+ // IPO and other cleanups.
+ addCleanupPatterns(ipoPipeline);
+
+ // Run fixed-point iteration on the IPO pipeline.
+ passManager.addPass(
+ IREE::Util::createFixedPointIteratorPass(std::move(ipoPipeline)));
+ }
+
+ FunctionLikeNest(passManager)
+ // Preprocess the input to a form more amenable for fusion.
+ .addPass(DispatchCreation::createFusionPreprocessingPass)
+ .addPass(IREE::Flow::createCanonicalizerPass)
+ .addPass(mlir::createCSEPass);
+
+ addDispatchRegionCreationPreprocessingPasses(passManager);
+ addDispatchRegionCreationPasses(passManager);
+
+ FunctionLikeNest(passManager)
+ .addPass(DispatchCreation::createConvertDispatchRegionsToWorkgroupsPass)
+ // Convert tensor operations to flow.tensor ops.
+ // - Convert extract/insert slice to flow update ops when the tensor op
+ // acts as a contiguous view of the tensor
+ // - Apply tensor -> flow patterns
+ .addPass(DispatchCreation::createConvertTensorToFlowPass)
+ .addPass(IREE::Flow::createCanonicalizerPass)
+      // Create the workgroup count region where the materialized computation
+      // is derived as a program slice of the body of the dispatch. This pass
+      // - Computes the `workload` to use for the `workgroupsOp`, derived
+      //   from the values captured by the `workgroupsOp`.
+      // - Populates the workgroup count region with the placeholder op
+      //   `flow.dispatch.workgroups_count_from_body_slice`. This op is
+      //   resolved in the backends into the actual workgroup count
+      //   computation.
+      // - Inserts `flow.dispatch.workload.ordinal` ops to correlate the
+      //   captured operands back to their positions in the workload list.
+ .addPass(
+ DispatchCreation::createMaterializeDefaultWorkgroupCountRegionPass);
+}
+
+namespace {
+#define GEN_PASS_REGISTRATION
+#include "iree/compiler/DispatchCreation/Passes.h.inc" // IWYU pragma: keep
+} // namespace
+
+void registerDispatchCreationPasses() {
+ // Generated from Passes.td
+ registerPasses();
+}
+
+void registerDispatchCreationPipelines() {
+ PassPipelineRegistration<TransformOptions> flowDispatchRegionCreationPipeline(
+ "iree-dispatch-creation-pipeline",
+ "Flag used to run passes that form dispatch regions",
+ [](OpPassManager &passManager, const TransformOptions &transformOptions) {
+ buildDispatchCreationPassPipeline(passManager, transformOptions);
+ });
+
+ PassPipelineRegistration<> flowDispatchRegionFormationPreprocessingPipeline(
+ "iree-dispatch-creation-preprocessing-pipeline",
+ "Flag used to run preprocessing passes that run passes before dispatch "
+ "region formation. Used only for testing",
+ [](OpPassManager &passManager) {
+ addDispatchRegionCreationPreprocessingPasses(passManager);
+ });
+}
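+
+// Example invocation (illustrative): once registered, the pipelines can be
+// driven from iree-opt, e.g.:
+//   iree-opt --pass-pipeline="builtin.module(iree-dispatch-creation-pipeline)" input.mlir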
+
+} // namespace mlir::iree_compiler::DispatchCreation
diff --git a/compiler/src/iree/compiler/DispatchCreation/Passes.h b/compiler/src/iree/compiler/DispatchCreation/Passes.h
new file mode 100644
index 0000000..e129fe6
--- /dev/null
+++ b/compiler/src/iree/compiler/DispatchCreation/Passes.h
@@ -0,0 +1,47 @@
+// Copyright 2024 The IREE Authors
+//
+// Licensed under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+
+#ifndef IREE_COMPILER_DISPATCHCREATION_PASSES_H_
+#define IREE_COMPILER_DISPATCHCREATION_PASSES_H_
+
+#include <functional>
+
+#include "iree/compiler/Pipelines/Options.h"
+#include "mlir/IR/BuiltinOps.h"
+#include "mlir/Interfaces/FunctionInterfaces.h"
+#include "mlir/Pass/Pass.h"
+#include "mlir/Pass/PassManager.h"
+
+namespace mlir::iree_compiler::DispatchCreation {
+
+//===----------------------------------------------------------------------===//
+// Pipelines
+//===----------------------------------------------------------------------===//
+
+/// This is a placeholder for the future. We should pass all the options
+/// through this struct.
+struct TransformOptions : public PassPipelineOptions<TransformOptions> {};
+
+void buildDispatchCreationPassPipeline(
+ OpPassManager &passManager, const TransformOptions &transformOptions);
+
+//===----------------------------------------------------------------------===//
+// Register all Passes
+//===----------------------------------------------------------------------===//
+
+#define GEN_PASS_DECL
+#include "iree/compiler/DispatchCreation/Passes.h.inc" // IWYU pragma: keep
+
+void registerDispatchCreationPasses();
+
+//===----------------------------------------------------------------------===//
+// Register Pipelines
+//===----------------------------------------------------------------------===//
+void registerDispatchCreationPipelines();
+
+} // namespace mlir::iree_compiler::DispatchCreation
+
+#endif // IREE_COMPILER_DISPATCHCREATION_PASSES_H_
diff --git a/compiler/src/iree/compiler/DispatchCreation/Passes.td b/compiler/src/iree/compiler/DispatchCreation/Passes.td
new file mode 100644
index 0000000..a89579e
--- /dev/null
+++ b/compiler/src/iree/compiler/DispatchCreation/Passes.td
@@ -0,0 +1,322 @@
+// Copyright 2023 The IREE Authors
+//
+// Licensed under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+
+#ifndef IREE_COMPILER_DISPATCHCREATION_PASSES
+#define IREE_COMPILER_DISPATCHCREATION_PASSES
+
+include "mlir/Pass/PassBase.td"
+
+//===---------------------------------------------------------------------===//
+// Dispatch region creation preprocessing passes:
+// Passes that transform the program before forming dispatches, like
+// - Elementwise operation fusion
+// - Reshape propagation passes
+//===---------------------------------------------------------------------===//
+
+def TensorPadToTensorInsertSlicePass :
+ Pass<"iree-dispatch-creation-tensor-pad-to-tensor-insert-slice"> {
+ let summary = "Convert tensor.pad into linalg.fill + tensor.insert_slice.";
+ let options = [
+ Option<"skipSingleLinalgOpUses", "skip-one-linalg-use-case", "bool",
+ /*default=*/"false",
+ "Skip the op that has only one use which is used"
+ "by a Linalg op">,
+ ];
+ let dependentDialects = [
+ "mlir::arith::ArithDialect",
+ "mlir::linalg::LinalgDialect",
+ "mlir::math::MathDialect",
+ "mlir::memref::MemRefDialect",
+ ];
+}
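+
+// Illustrative rewrite (simplified, static shapes assumed): the pass above
+// turns
+//   %p = tensor.pad %src low[0, 0] high[2, 2] {
+//     ^bb0(%i: index, %j: index):
+//       tensor.yield %cst : f32
+//   } : tensor<6x6xf32> to tensor<8x8xf32>
+// into
+//   %empty = tensor.empty() : tensor<8x8xf32>
+//   %fill = linalg.fill ins(%cst : f32) outs(%empty : tensor<8x8xf32>)
+//           -> tensor<8x8xf32>
+//   %p = tensor.insert_slice %src into %fill[0, 0] [6, 6] [1, 1]
+//        : tensor<6x6xf32> into tensor<8x8xf32>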
+
+def BubbleUpExpandShapesPass :
+ Pass<"iree-dispatch-creation-bubble-up-expand-shapes"> {
+ let summary = "Propagate expand_shapes up the program (and collapse_shapes down).";
+ let dependentDialects = [
+ "mlir::affine::AffineDialect",
+ ];
+}
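+
+// Illustrative effect (simplified): an expand_shape below an elementwise op
+// is bubbled above it so the op computes at the higher rank:
+//   %0 = linalg.generic ... ins(%in : tensor<16xf32>) ...
+//   %1 = tensor.expand_shape %0 [[0, 1]] output_shape [4, 4]
+//        : tensor<16xf32> into tensor<4x4xf32>
+// becomes
+//   %0 = tensor.expand_shape %in [[0, 1]] output_shape [4, 4]
+//        : tensor<16xf32> into tensor<4x4xf32>
+//   %1 = linalg.generic ... ins(%0 : tensor<4x4xf32>) ...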
+
+def CollapseReductionDimensionsPass :
+ Pass<"iree-dispatch-creation-collapse-reduction-dimensions", ""> {
+ let summary = "Collapse reduction dimensions when possible.";
+ let dependentDialects = [
+ "mlir::linalg::LinalgDialect",
+ ];
+}
+
+def ElementwiseOpFusionPass :
+ Pass<"iree-dispatch-creation-elementwise-op-fusion", ""> {
+ let summary = "Fuse elementwise operations.";
+ let options = [
+ Option<"fuseMultiReduction", "fuse-multi-reduction", "bool",
+ /*default=*/"true", "Fuse ops that have multiple reduction iterators">
+ ];
+ let dependentDialects = [
+ "mlir::affine::AffineDialect",
+ ];
+}
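+
+// Illustrative effect (simplified): a producer/consumer pair of elementwise
+// ops over the same iteration space, e.g.
+//   %0 = linalg.generic ... { %s = arith.addf %a, %b ... }
+//   %1 = linalg.generic ... ins(%0, ...) { %t = arith.mulf ... }
+// is fused into a single linalg.generic whose body performs the addf and
+// then the mulf, eliminating the intermediate tensor.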
+
+def FoldUnitExtentDimsPass :
+ Pass<"iree-dispatch-creation-fold-unit-extent-dims", "mlir::ModuleOp"> {
+ let summary = "Fold unit extent dimension of operations.";
+ let description = [{
+ Imports upstream patterns to fold unit extent dims but with IREE control.
+ }];
+ let dependentDialects = [
+ "mlir::affine::AffineDialect",
+ "mlir::arith::ArithDialect",
+ "mlir::linalg::LinalgDialect",
+ "mlir::tensor::TensorDialect",
+ ];
+}
+
+def FuseHorizontalContractionsPass:
+ InterfacePass<"iree-dispatch-creation-fuse-horizontal-contractions", "mlir::FunctionOpInterface"> {
+ let summary = "Fuses horizontal contraction ops without fusions";
+ let dependentDialects = [
+ "mlir::arith::ArithDialect",
+ "mlir::tensor::TensorDialect",
+ ];
+ let options = [
+ Option<"fusionLimit", "fusion-limit", "int",
+ /*default=*/"3", "Maximum number of contractions fused into one">
+ ];
+ let statistics = [
+ Statistic<"numFusionGroups", "num-fusion-groups", "Number of fusion groups found">,
+ Statistic<"numSize2FusionGroups", "num-size-2-groups", "Number of fusion groups of size 2">,
+ Statistic<"numSize3FusionGroups", "num-size-3-groups", "Number of fusion groups of size 3">
+ ];
+}
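+
+// Illustrative effect (simplified): two independent contractions that share
+// an operand, e.g.
+//   %a = linalg.matmul ins(%lhs, %rhs0 : ...) ...
+//   %b = linalg.matmul ins(%lhs, %rhs1 : ...) ...
+// are fused "horizontally" into one contraction over the concatenated
+// right-hand sides, so %lhs is read only once.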
+
+def FuseMultiUseElementwiseProducerPass :
+ InterfacePass<"iree-dispatch-creation-fuse-multi-use-elementwise-producer",
+ "mlir::FunctionOpInterface"> {
+ let summary = "Fuse elementwise linalg operations on tensors when producers have multiple uses.";
+ let options = [
+ Option<"numIterations", "num-iterations", "unsigned",
+ /*default=*/"2", "Number of iterations to fuse multiuse ops">
+ ];
+ let dependentDialects = [
+ "mlir::affine::AffineDialect",
+ "mlir::arith::ArithDialect",
+ "mlir::linalg::LinalgDialect",
+ "mlir::math::MathDialect",
+ ];
+}
+
+def FusionPreprocessingPass :
+ Pass<"iree-dispatch-creation-fusion-preprocessing", ""> {
+ let summary = "Run useful preprocessing patterns that help with fusion.";
+ let dependentDialects = [
+ "mlir::affine::AffineDialect",
+ ];
+}
+
+def SinkReshapesPass :
+ Pass<"iree-dispatch-creation-sink-reshapes", ""> {
+ let summary = "Sink reshapes to allow for compute op -> consumer fusion.";
+ let dependentDialects = [
+ "mlir::affine::AffineDialect",
+ "mlir::arith::ArithDialect",
+ ];
+}
+
+def SplitReductionPass :
+ Pass<"iree-dispatch-creation-split-reduction-ops", ""> {
+ let summary = "Split reduction dimension to increase parallelism.";
+ let dependentDialects = [
+ "mlir::linalg::LinalgDialect",
+ ];
+}
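+
+// Illustrative effect (simplified pseudo-IR, split ratio 4): a single long
+// reduction
+//   %s = linalg.generic {reduction} ins(%x : tensor<1024xf32>) ...
+// is rewritten as a partial reduction with an extra parallel dimension
+//   %p = linalg.generic {parallel, reduction}
+//        ins(%expanded : tensor<4x256xf32>) ... -> tensor<4xf32>
+// followed by a small final reduction of %p, exposing parallelism across
+// the four tiles.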
+
+def TransposeGenericOpsPass :
+ Pass<"iree-dispatch-creation-transpose-generic-ops", ""> {
+ let summary = "Transpose generic op loops.";
+ let dependentDialects = [
+ "mlir::linalg::LinalgDialect",
+ ];
+}
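+
+// Illustrative effect (simplified): for a generic op with iterator types
+// [reduction, parallel], the loop order is permuted so the iterator types
+// become [parallel, reduction] (reduction innermost), with the indexing
+// maps transposed to match.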
+
+//===---------------------------------------------------------------------===//
+// Dispatch region creation passes.
+//===---------------------------------------------------------------------===//
+
+def CloneProducersIntoDispatchRegionsPass :
+ InterfacePass<"iree-dispatch-creation-clone-producers-into-dispatch-regions", "mlir::FunctionOpInterface"> {
+ let summary = "Clone producers into dispatch regions to be isolated above.";
+ let description = [{
+ Pass to clone into dispatch regions producers of values used in the dispatch
+ regions but defined in the above. This prepares the dispatch regions for
+ converting to dispatch workgroups with explicit captures.
+ }];
+}
+
+def CollapseDimensionsPass :
+ InterfacePass<"iree-dispatch-creation-collapse-dimensions", "mlir::FunctionOpInterface"> {
+ let summary = "Collapse dimensions of Linalg Ops on tensor ops.";
+ let options = [
+ Option<"maxIterations", "max-iterations", "int",
+ /*default=*/"10",
+ "Maximum number of iterations to wait for collapse dimensions to converge"
+ >,
+ ];
+ let description = [{
+    Collapse dimensions of Linalg ops on tensors inside dispatch.region ops
+ and hoist the reshaping operations out of the dispatch.
+ }];
+}
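+
+// Illustrative effect (simplified): inside a dispatch region, an elementwise
+// op on tensor<2x4x8xf32> whose dimensions are contiguous in every indexing
+// map is collapsed to operate on tensor<64xf32>, and the resulting
+// collapse_shape/expand_shape pairs are hoisted out of the dispatch.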
+
+def DispatchWithTransformDialectPass : Pass<"iree-dispatch-creation-dispatch-with-transform-dialect"> {
+ let summary = "Dispatch Linalg operations on tensors by using the transform dialect interpreter.";
+ let description = [{
+ Pass to perform dispatch of Linalg on tensor ops by using the transform
+ dialect. Dispatch regions are created as specified by the transform module
+ that is parsed from `transformSpecPath`.
+
+ TODO: Drop this pass in favor of the one upstream. The one upstream requires
+ separate loading of the module and thus isn't suited for single-use
+ transform scripts.
+ }];
+ let dependentDialects = [
+ "mlir::affine::AffineDialect",
+ "mlir::arith::ArithDialect",
+ "mlir::linalg::LinalgDialect",
+ "mlir::pdl::PDLDialect",
+ "mlir::pdl_interp::PDLInterpDialect",
+ "mlir::scf::SCFDialect",
+ "mlir::tensor::TensorDialect",
+ "mlir::transform::TransformDialect",
+ "IREE::Flow::FlowDialect",
+ "IREE::LinalgExt::IREELinalgExtDialect",
+ ];
+ let options = [
+ Option<"disableExpensiveChecks", "disable-expensive-checks", "bool",
+ "false",
+ "Disable expensive checks in the interpreter for a faster run.">,
+ Option<"transformSpecPath", "transform-spec-path", "std::string",
+ /*default=*/"", "File path to the transform spec to use.">,
+ ];
+}
+
+def FormDispatchRegionsPass :
+ InterfacePass<"iree-dispatch-creation-form-dispatch-regions", "mlir::FunctionOpInterface"> {
+ let summary = "Form Dispatch Region Ops from Linalg operations on tensors to form dispatch.regions.";
+ let options = [
+ Option<"aggressiveFusion", "aggressive-fusion", "bool",
+ /*default=*/"false", "Aggressive mode enabling fusions not ready for all backends">,
+ Option<"fusePadWithConsumers", "fuse-pad-with-consumers", "bool",
+ /*default=*/"false", "Enable fusing pad with consumer">,
+ Option<"fusePadWithProducers", "fuse-pad-with-producers", "bool",
+ /*default=*/"false", "Enable fusion of pad with producers">
+ ];
+ let description = [{
+ Pass to form dispatch.region ops from Linalg on tensor ops. A dispatch region
+ is created for each tiled loop nest. This pass only moves the root compute op
+ into the dispatch region, allowing producers to be outside.
+ }];
+ let dependentDialects = [
+ "mlir::affine::AffineDialect",
+ "mlir::linalg::LinalgDialect",
+ "mlir::scf::SCFDialect",
+ "mlir::tensor::TensorDialect",
+ "IREE::Flow::FlowDialect",
+ "IREE::LinalgExt::IREELinalgExtDialect",
+ ];
+}
+
+def FormScalarDispatchesPass :
+ InterfacePass<"iree-dispatch-creation-form-scalar-dispatches", "mlir::FunctionOpInterface"> {
+ let summary = "Form Dispatch Regions for scalar computations.";
+ let dependentDialects = [
+ "mlir::affine::AffineDialect",
+ "mlir::linalg::LinalgDialect",
+ "mlir::tensor::TensorDialect",
+ "IREE::Flow::FlowDialect",
+ ];
+}
+
+def SetEncodingPass :
+ InterfacePass<"iree-dispatch-creation-set-encoding", "mlir::FunctionOpInterface"> {
+ let summary = "Introduces tensor encoding for flow dispatch regions.";
+ let dependentDialects = [
+ "mlir::linalg::LinalgDialect",
+ "IREE::Flow::FlowDialect",
+ "IREE::Encoding::IREEEncodingDialect",
+ ];
+ let options = [
+ Option<"padFactor", "pad-factor", "int64_t", /*default=*/"32",
+ "provides padding size hints that will be attached to encodings.">,
+ ];
+}
+
+def ConvertDispatchRegionsToWorkgroupsPass :
+ InterfacePass<"iree-dispatch-creation-convert-dispatch-regions-to-workgroups", "mlir::FunctionOpInterface"> {
+ let summary = "Convert dispatch regions to dispatch workgroups.";
+ let description = [{
+ Pass to convert dispatch regions to dispatch workgroups. This pass is
+ intended to be used after dispatch regions have been formed.
+ }];
+ let dependentDialects = [
+ "mlir::affine::AffineDialect",
+ "mlir::linalg::LinalgDialect",
+ "mlir::scf::SCFDialect",
+ "mlir::tensor::TensorDialect",
+ "IREE::Flow::FlowDialect",
+ ];
+ let statistics = [
+ Statistic<"numDispatches", "num-dispatches", "Number of dispatches created">
+ ];
+}
+
+def ConvertTensorToFlowPass :
+ InterfacePass<"iree-dispatch-creation-convert-tensor-to-flow", "mlir::FunctionOpInterface"> {
+ let summary = "Convert tensor operations to flow";
+ let description = [{
+ Pass to convert tensor operations to flow.tensor.* operations.
+ }];
+ let dependentDialects = [
+ "mlir::affine::AffineDialect",
+ "mlir::arith::ArithDialect",
+ "mlir::linalg::LinalgDialect",
+ "mlir::tensor::TensorDialect",
+ "IREE::Flow::FlowDialect",
+ ];
+ let statistics = [
+ Statistic<"numSlowCopyDispatches", "num-slow-copy-dispatches",
+ "Number of slow copy dispatches (for handling slices) created">
+ ];
+}
+
+def MaterializeDefaultWorkgroupCountRegionPass:
+ InterfacePass<"iree-dispatch-creation-materialize-default-workgroup-count-region",
+ "mlir::FunctionOpInterface"> {
+ let summary = "Canonicalize dispatch workgroups ops.";
+ let description = [{
+ Apply dispatch workgroups canonicalization patterns.
+ }];
+ let dependentDialects = [
+ "mlir::affine::AffineDialect",
+ "mlir::arith::ArithDialect",
+ "mlir::linalg::LinalgDialect",
+ "mlir::scf::SCFDialect",
+ "IREE::Flow::FlowDialect",
+ ];
+}
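+
+// Illustrative result (simplified; uses the placeholder op named in this
+// change): the pass populates the count region of a flow.dispatch.workgroups
+// op roughly as
+//   count(%wl0: index, %wl1: index) -> (index, index, index) {
+//     %x, %y, %z = flow.dispatch.workgroups_count_from_body_slice %wl0, %wl1
+//     flow.return %x, %y, %z : index, index, index
+//   }
+// which the backends later resolve into the actual workgroup counts.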
+
+def HoistEncodingOpsPass :
+ InterfacePass<"iree-dispatch-creation-hoist-encoding-ops", "mlir::FunctionOpInterface"> {
+ let summary = "Hoists tensor encoding ops out of flow dispatch regions.";
+ let dependentDialects = [
+ "mlir::linalg::LinalgDialect",
+ "IREE::Flow::FlowDialect",
+ "IREE::Encoding::IREEEncodingDialect",
+ ];
+}
+
+#endif // IREE_COMPILER_DISPATCHCREATION_PASSES
diff --git a/compiler/src/iree/compiler/Dialect/Flow/Transforms/SetEncoding.cpp b/compiler/src/iree/compiler/DispatchCreation/SetEncoding.cpp
similarity index 96%
rename from compiler/src/iree/compiler/Dialect/Flow/Transforms/SetEncoding.cpp
rename to compiler/src/iree/compiler/DispatchCreation/SetEncoding.cpp
index 9df0c40..c7cb30b 100644
--- a/compiler/src/iree/compiler/Dialect/Flow/Transforms/SetEncoding.cpp
+++ b/compiler/src/iree/compiler/DispatchCreation/SetEncoding.cpp
@@ -14,8 +14,8 @@
#include "iree/compiler/Dialect/Encoding/IR/EncodingOps.h"
#include "iree/compiler/Dialect/Flow/IR/FlowDialect.h"
#include "iree/compiler/Dialect/Flow/IR/FlowOps.h"
-#include "iree/compiler/Dialect/Flow/Transforms/Passes.h"
#include "iree/compiler/Dialect/Flow/Transforms/RegionOpUtils.h"
+#include "iree/compiler/DispatchCreation/Passes.h"
#include "mlir/Dialect/Linalg/IR/Linalg.h"
#include "mlir/Dialect/Linalg/Utils/Utils.h"
#include "mlir/Dialect/MemRef/Transforms/Transforms.h"
@@ -25,11 +25,11 @@
#include "mlir/Support/LLVM.h"
#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
-#define DEBUG_TYPE "iree-flow-set-encoding"
+#define DEBUG_TYPE "iree-dispatch-creation-set-encoding"
-namespace mlir::iree_compiler::IREE::Flow {
+namespace mlir::iree_compiler::DispatchCreation {
#define GEN_PASS_DEF_SETENCODINGPASS
-#include "iree/compiler/Dialect/Flow/Transforms/Passes.h.inc"
+#include "iree/compiler/DispatchCreation/Passes.h.inc"
using IREE::Encoding::EncodingAttr;
@@ -293,7 +293,7 @@
/// Pattern to fold a `linalg.fill` -> `iree_encoding.set_encoding`
/// operation into a `linalg.fill` of the encoded type.
-struct FoldFillWithSetEncoding
+struct FoldFillWithSetEncoding final
: public OpRewritePattern<IREE::Encoding::SetEncodingOp> {
using OpRewritePattern<IREE::Encoding::SetEncodingOp>::OpRewritePattern;
@@ -317,11 +317,9 @@
}
};
-class SetEncodingPass : public impl::SetEncodingPassBase<SetEncodingPass> {
-
-public:
+struct SetEncodingPass final
+ : public impl::SetEncodingPassBase<SetEncodingPass> {
using Base::Base;
-
void runOnOperation() override {
MLIRContext *context = &getContext();
RewritePatternSet patterns(context);
@@ -337,4 +335,4 @@
};
} // namespace
-} // namespace mlir::iree_compiler::IREE::Flow
+} // namespace mlir::iree_compiler::DispatchCreation
diff --git a/compiler/src/iree/compiler/Dialect/Flow/Transforms/SinkReshapes.cpp b/compiler/src/iree/compiler/DispatchCreation/SinkReshapes.cpp
similarity index 91%
rename from compiler/src/iree/compiler/Dialect/Flow/Transforms/SinkReshapes.cpp
rename to compiler/src/iree/compiler/DispatchCreation/SinkReshapes.cpp
index 3705f7b..6e7c707 100644
--- a/compiler/src/iree/compiler/Dialect/Flow/Transforms/SinkReshapes.cpp
+++ b/compiler/src/iree/compiler/DispatchCreation/SinkReshapes.cpp
@@ -14,10 +14,10 @@
//===----------------------------------------------------------------------===//
#include "iree/compiler/Dialect/Encoding/IR/EncodingOps.h"
-#include "iree/compiler/Dialect/Flow/Transforms/FusionUtils.h"
-#include "iree/compiler/Dialect/Flow/Transforms/Passes.h"
#include "iree/compiler/Dialect/Flow/Transforms/RegionOpUtils.h"
#include "iree/compiler/Dialect/LinalgExt/Utils/Utils.h"
+#include "iree/compiler/DispatchCreation/FusionUtils.h"
+#include "iree/compiler/DispatchCreation/Passes.h"
#include "llvm/Support/Debug.h"
#include "mlir/Dialect/Affine/IR/AffineOps.h"
#include "mlir/Dialect/Arith/IR/Arith.h"
@@ -25,19 +25,18 @@
#include "mlir/Dialect/Tensor/Transforms/Transforms.h"
#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
-#define DEBUG_TYPE "iree-flow-sink-reshapes"
+#define DEBUG_TYPE "iree-dispatch-creation-sink-reshapes"
-namespace mlir::iree_compiler::IREE::Flow {
+namespace mlir::iree_compiler::DispatchCreation {
#define GEN_PASS_DEF_SINKRESHAPESPASS
-#include "iree/compiler/Dialect/Flow/Transforms/Passes.h.inc"
+#include "iree/compiler/DispatchCreation/Passes.h.inc"
namespace {
-class SinkReshapesPass : public impl::SinkReshapesPassBase<SinkReshapesPass> {
-public:
+struct SinkReshapesPass final
+ : public impl::SinkReshapesPassBase<SinkReshapesPass> {
using Base::Base;
-
void runOnOperation() override;
};
@@ -49,7 +48,7 @@
static bool isFusableUsingTileAndFuse(Operation *producer,
Operation *consumer) {
return llvm::isa_and_nonnull<linalg::LinalgOp, tensor::UnPackOp,
- Encoding::UnsetEncodingOp>(producer);
+ IREE::Encoding::UnsetEncodingOp>(producer);
}
/// Control function to check if a `tensor.expand_shape` (which is producer of
@@ -62,7 +61,7 @@
return false;
}
Operation *consumer = opOperand->getOwner();
- if (!isNonNullAndOutsideDispatch({reshapeOp, consumer})) {
+ if (!IREE::Flow::isNonNullAndOutsideDispatch({reshapeOp, consumer})) {
return false;
}
auto consumerGenericOp = dyn_cast<linalg::GenericOp>(consumer);
@@ -77,7 +76,7 @@
// Do not sink reshapes across dequantize operations since they are
// cloned into their consumers.
- if (LinalgExt::isBitExtendOp(consumer)) {
+ if (IREE::LinalgExt::isBitExtendOp(consumer)) {
return false;
}
@@ -176,4 +175,4 @@
} // namespace
-} // namespace mlir::iree_compiler::IREE::Flow
+} // namespace mlir::iree_compiler::DispatchCreation
diff --git a/compiler/src/iree/compiler/Dialect/Flow/Transforms/SplitReduction.cpp b/compiler/src/iree/compiler/DispatchCreation/SplitReduction.cpp
similarity index 82%
rename from compiler/src/iree/compiler/Dialect/Flow/Transforms/SplitReduction.cpp
rename to compiler/src/iree/compiler/DispatchCreation/SplitReduction.cpp
index 3322530..a02b5a5 100644
--- a/compiler/src/iree/compiler/Dialect/Flow/Transforms/SplitReduction.cpp
+++ b/compiler/src/iree/compiler/DispatchCreation/SplitReduction.cpp
@@ -10,25 +10,25 @@
//
//===----------------------------------------------------------------------===//
-#include "iree/compiler/Dialect/Flow/Transforms/Passes.h"
#include "iree/compiler/Dialect/LinalgExt/Transforms/Passes.h"
+#include "iree/compiler/DispatchCreation/Passes.h"
#include "mlir/Dialect/Affine/IR/AffineOps.h"
#include "mlir/Dialect/Linalg/IR/Linalg.h"
#include "mlir/Dialect/Linalg/Transforms/Transforms.h"
#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
-namespace mlir::iree_compiler::IREE::Flow {
+namespace mlir::iree_compiler::DispatchCreation {
#define GEN_PASS_DEF_SPLITREDUCTIONPASS
-#include "iree/compiler/Dialect/Flow/Transforms/Passes.h.inc"
+#include "iree/compiler/DispatchCreation/Passes.h.inc"
// TODO(thomasraoux): Move to attributes.
static llvm::cl::opt<int64_t>
- splitReductionRatio("iree-flow-split-matmul-reduction",
+ splitReductionRatio("iree-dispatch-creation-split-matmul-reduction",
llvm::cl::desc("split ratio"), llvm::cl::init(1));
static llvm::cl::list<int64_t> topkSplitReductionRatio(
- "iree-flow-topk-split-reduction",
+ "iree-dispatch-creation-topk-split-reduction",
llvm::cl::desc("comma separated list of split ratios"),
llvm::cl::CommaSeparated);
@@ -51,8 +51,8 @@
}
namespace {
-struct SplitReductionPass
- : public IREE::Flow::impl::SplitReductionPassBase<SplitReductionPass> {
+struct SplitReductionPass final
+ : public impl::SplitReductionPassBase<SplitReductionPass> {
void runOnOperation() override {
if (splitReductionRatio.getValue() <= 1 &&
topkSplitReductionRatio.empty()) {
@@ -76,7 +76,7 @@
(void)splitReductionOnMatmul(rewriter, op, matmulSplitReductionControlFn);
}
- LinalgExt::TopkSplitReductionControlFn topkSplitReductionControlFn =
+ IREE::LinalgExt::TopkSplitReductionControlFn topkSplitReductionControlFn =
[&](int64_t splitReductionDepth) -> int64_t {
SmallVector<int64_t> reductionRatios(topkSplitReductionRatio.begin(),
topkSplitReductionRatio.end());
@@ -87,8 +87,9 @@
}
};
- SmallVector<LinalgExt::TopkOp> topkCandidates;
- funcOp->walk([&](LinalgExt::TopkOp op) { topkCandidates.push_back(op); });
+ SmallVector<IREE::LinalgExt::TopkOp> topkCandidates;
+ funcOp->walk(
+ [&](IREE::LinalgExt::TopkOp op) { topkCandidates.push_back(op); });
for (auto op : topkCandidates) {
(void)splitReduction(rewriter, op, topkSplitReductionControlFn);
}
@@ -97,4 +98,4 @@
} // namespace
-} // namespace mlir::iree_compiler::IREE::Flow
+} // namespace mlir::iree_compiler::DispatchCreation
diff --git a/compiler/src/iree/compiler/Dialect/Flow/Transforms/TensorPadToTensorInsertSlice.cpp b/compiler/src/iree/compiler/DispatchCreation/TensorPadToTensorInsertSlice.cpp
similarity index 88%
rename from compiler/src/iree/compiler/Dialect/Flow/Transforms/TensorPadToTensorInsertSlice.cpp
rename to compiler/src/iree/compiler/DispatchCreation/TensorPadToTensorInsertSlice.cpp
index 8eb612e..3e24b37 100644
--- a/compiler/src/iree/compiler/Dialect/Flow/Transforms/TensorPadToTensorInsertSlice.cpp
+++ b/compiler/src/iree/compiler/DispatchCreation/TensorPadToTensorInsertSlice.cpp
@@ -11,7 +11,7 @@
//===----------------------------------------------------------------------===//
#include "iree/compiler/Dialect/Encoding/IR/EncodingOps.h"
-#include "iree/compiler/Dialect/Flow/Transforms/Passes.h"
+#include "iree/compiler/DispatchCreation/Passes.h"
#include "mlir/Dialect/Affine/IR/AffineOps.h"
#include "mlir/Dialect/Arith/IR/Arith.h"
#include "mlir/Dialect/Func/IR/FuncOps.h"
@@ -24,10 +24,10 @@
#include "mlir/Pass/Pass.h"
#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
-namespace mlir::iree_compiler::IREE::Flow {
+namespace mlir::iree_compiler::DispatchCreation {
#define GEN_PASS_DEF_TENSORPADTOTENSORINSERTSLICEPASS
-#include "iree/compiler/Dialect/Flow/Transforms/Passes.h.inc"
+#include "iree/compiler/DispatchCreation/Passes.h.inc"
namespace {
/// Pattern to convert a tensor.tensor operation into a fill +
@@ -84,11 +84,10 @@
bool skipSingleLinalgOpUses = false;
};
-struct TensorPadToTensorInsertSlicePass
- : public IREE::Flow::impl::TensorPadToTensorInsertSlicePassBase<
+struct TensorPadToTensorInsertSlicePass final
+ : public impl::TensorPadToTensorInsertSlicePassBase<
TensorPadToTensorInsertSlicePass> {
- using IREE::Flow::impl::TensorPadToTensorInsertSlicePassBase<
- TensorPadToTensorInsertSlicePass>::TensorPadToTensorInsertSlicePassBase;
+ using Base::Base;
void runOnOperation() override {
MLIRContext *context = &getContext();
RewritePatternSet patterns(context);
@@ -102,4 +101,4 @@
} // namespace
-} // namespace mlir::iree_compiler::IREE::Flow
+} // namespace mlir::iree_compiler::DispatchCreation
diff --git a/compiler/src/iree/compiler/Dialect/Flow/Transforms/TransposeGenericOps.cpp b/compiler/src/iree/compiler/DispatchCreation/TransposeGenericOps.cpp
similarity index 90%
rename from compiler/src/iree/compiler/Dialect/Flow/Transforms/TransposeGenericOps.cpp
rename to compiler/src/iree/compiler/DispatchCreation/TransposeGenericOps.cpp
index ea05324..02abc36 100644
--- a/compiler/src/iree/compiler/Dialect/Flow/Transforms/TransposeGenericOps.cpp
+++ b/compiler/src/iree/compiler/DispatchCreation/TransposeGenericOps.cpp
@@ -11,15 +11,15 @@
//
//===----------------------------------------------------------------------===//
-#include "iree/compiler/Dialect/Flow/Transforms/Passes.h"
+#include "iree/compiler/DispatchCreation/Passes.h"
#include "mlir/Dialect/Linalg/IR/Linalg.h"
#include "mlir/Dialect/Linalg/Transforms/Transforms.h"
#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
-namespace mlir::iree_compiler::IREE::Flow {
+namespace mlir::iree_compiler::DispatchCreation {
#define GEN_PASS_DEF_TRANSPOSEGENERICOPSPASS
-#include "iree/compiler/Dialect/Flow/Transforms/Passes.h.inc"
+#include "iree/compiler/DispatchCreation/Passes.h.inc"
namespace {
@@ -29,7 +29,7 @@
/// For generic ops that are reduction, make the reduction the innermost
/// dimension.
-struct MakeReductionInnermostPattern
+struct MakeReductionInnermostPattern final
: public OpRewritePattern<linalg::GenericOp> {
using OpRewritePattern<linalg::GenericOp>::OpRewritePattern;
LogicalResult matchAndRewrite(linalg::GenericOp genericOp,
@@ -62,7 +62,8 @@
/// ops), the dispatch region fusion logic requires the indexing maps to be
/// identity (or projections that are not transposing as well). This pattern
/// fixes up elementwise operations for which that is not the case.
-struct TransposeGenericOpPattern : public OpRewritePattern<linalg::GenericOp> {
+struct TransposeGenericOpPattern final
+ : public OpRewritePattern<linalg::GenericOp> {
using OpRewritePattern<linalg::GenericOp>::OpRewritePattern;
LogicalResult matchAndRewrite(linalg::GenericOp genericOp,
PatternRewriter &rewriter) const override {
@@ -112,9 +113,8 @@
}
};
-struct TransposeGenericOpsPass
- : public IREE::Flow::impl::TransposeGenericOpsPassBase<
- TransposeGenericOpsPass> {
+struct TransposeGenericOpsPass final
+ : public impl::TransposeGenericOpsPassBase<TransposeGenericOpsPass> {
void runOnOperation() override {
RewritePatternSet patterns(&getContext());
patterns.add<MakeReductionInnermostPattern, TransposeGenericOpPattern>(
@@ -128,4 +128,4 @@
} // namespace
-} // namespace mlir::iree_compiler::IREE::Flow
+} // namespace mlir::iree_compiler::DispatchCreation
diff --git a/compiler/src/iree/compiler/DispatchCreation/test/BUILD.bazel b/compiler/src/iree/compiler/DispatchCreation/test/BUILD.bazel
new file mode 100644
index 0000000..8ca4ad2
--- /dev/null
+++ b/compiler/src/iree/compiler/DispatchCreation/test/BUILD.bazel
@@ -0,0 +1,59 @@
+# Copyright 2024 The IREE Authors
+#
+# Licensed under the Apache License v2.0 with LLVM Exceptions.
+# See https://llvm.org/LICENSE.txt for license information.
+# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+
+load("//build_tools/bazel:enforce_glob.bzl", "enforce_glob")
+load("//build_tools/bazel:iree_lit_test.bzl", "iree_lit_test_suite")
+
+package(
+ features = ["layering_check"],
+ licenses = ["notice"], # Apache 2.0
+)
+
+iree_lit_test_suite(
+ name = "lit",
+ srcs = enforce_glob(
+ [
+ "clone_producers_into_dispatch_regions.mlir",
+ "collapse_dimensions.mlir",
+ "collapse_linalg_generic_on_tensors.mlir",
+ "collapse_reduction.mlir",
+ "attention_fuse_by_expansion.mlir",
+ "dispatch_linalg_transform_dialect.mlir",
+ "dispatch_region_formation_preprocessing.mlir",
+ "fold_unit_dims.mlir",
+ "form_dispatch_regions.mlir",
+ "dispatch_linalg_on_tensors.mlir",
+ "convert_region_to_workgroups.mlir",
+ "form_dispatch_workgroups.mlir",
+ "dispatch_linalg_ext_fusion.mlir",
+ "hoist_encoding_ops.mlir",
+ "dispatch_linalg_on_tensors_default.mlir",
+ "dispatch_linalg_on_tensors_fusion_with_transpose.mlir",
+ "form_scalar_dispatches.mlir",
+ "fuse_horizontal_contractions.mlir",
+ "fuse_multiuse_elementwise_producer.mlir",
+ "fusion_preprocessing.mlir",
+ "pad_fusion_with_consumer.mlir",
+ "pad_fusion_with_producer.mlir",
+ "set_encoding.mlir",
+ "sink_reshapes.mlir",
+ "split_reduction.mlir",
+ "tensor_pad_to_tensor_insert_slice.mlir",
+ "transform_dispatch_region_formation.mlir",
+ "transpose_generic_ops.mlir",
+ ],
+ include = ["*.mlir"],
+ exclude = [
+ "transform_dialect_dispatch_spec.mlir",
+ ],
+ ),
+ cfg = "//compiler:lit.cfg.py",
+ data = ["transform_dialect_dispatch_spec.mlir"],
+ tools = [
+ "//tools:iree-opt",
+ "@llvm-project//llvm:FileCheck",
+ ],
+)
diff --git a/compiler/src/iree/compiler/DispatchCreation/test/CMakeLists.txt b/compiler/src/iree/compiler/DispatchCreation/test/CMakeLists.txt
new file mode 100644
index 0000000..152c1ac
--- /dev/null
+++ b/compiler/src/iree/compiler/DispatchCreation/test/CMakeLists.txt
@@ -0,0 +1,52 @@
+################################################################################
+# Autogenerated by build_tools/bazel_to_cmake/bazel_to_cmake.py from #
+# compiler/src/iree/compiler/DispatchCreation/test/BUILD.bazel #
+# #
+# Use iree_cmake_extra_content from iree/build_defs.oss.bzl to add arbitrary #
+# CMake-only content. #
+# #
+# To disable autogeneration for this file entirely, delete this header. #
+################################################################################
+
+iree_add_all_subdirs()
+
+iree_lit_test_suite(
+ NAME
+ lit
+ SRCS
+ "attention_fuse_by_expansion.mlir"
+ "clone_producers_into_dispatch_regions.mlir"
+ "collapse_dimensions.mlir"
+ "collapse_linalg_generic_on_tensors.mlir"
+ "collapse_reduction.mlir"
+ "convert_region_to_workgroups.mlir"
+ "dispatch_linalg_ext_fusion.mlir"
+ "dispatch_linalg_on_tensors.mlir"
+ "dispatch_linalg_on_tensors_default.mlir"
+ "dispatch_linalg_on_tensors_fusion_with_transpose.mlir"
+ "dispatch_linalg_transform_dialect.mlir"
+ "dispatch_region_formation_preprocessing.mlir"
+ "fold_unit_dims.mlir"
+ "form_dispatch_regions.mlir"
+ "form_dispatch_workgroups.mlir"
+ "form_scalar_dispatches.mlir"
+ "fuse_horizontal_contractions.mlir"
+ "fuse_multiuse_elementwise_producer.mlir"
+ "fusion_preprocessing.mlir"
+ "hoist_encoding_ops.mlir"
+ "pad_fusion_with_consumer.mlir"
+ "pad_fusion_with_producer.mlir"
+ "set_encoding.mlir"
+ "sink_reshapes.mlir"
+ "split_reduction.mlir"
+ "tensor_pad_to_tensor_insert_slice.mlir"
+ "transform_dispatch_region_formation.mlir"
+ "transpose_generic_ops.mlir"
+ TOOLS
+ FileCheck
+ iree-opt
+ DATA
+ transform_dialect_dispatch_spec.mlir
+)
+
+### BAZEL_TO_CMAKE_PRESERVES_ALL_CONTENT_BELOW_THIS_LINE ###
diff --git a/compiler/src/iree/compiler/Dialect/Flow/Transforms/test/attention_fuse_by_expansion.mlir b/compiler/src/iree/compiler/DispatchCreation/test/attention_fuse_by_expansion.mlir
similarity index 98%
rename from compiler/src/iree/compiler/Dialect/Flow/Transforms/test/attention_fuse_by_expansion.mlir
rename to compiler/src/iree/compiler/DispatchCreation/test/attention_fuse_by_expansion.mlir
index b243b59..d7d41b0 100644
--- a/compiler/src/iree/compiler/Dialect/Flow/Transforms/test/attention_fuse_by_expansion.mlir
+++ b/compiler/src/iree/compiler/DispatchCreation/test/attention_fuse_by_expansion.mlir
@@ -1,4 +1,4 @@
-// RUN: iree-opt --split-input-file --pass-pipeline="builtin.module(util.func(iree-flow-bubble-up-expand-shapes, canonicalize, cse, canonicalize))" %s | FileCheck %s
+// RUN: iree-opt --split-input-file --pass-pipeline="builtin.module(util.func(iree-dispatch-creation-bubble-up-expand-shapes, canonicalize, cse, canonicalize))" %s | FileCheck %s
#map = affine_map<(d0, d1, d2, d3, d4) -> (d0, d1, d2)>
#map1 = affine_map<(d0, d1, d2, d3, d4) -> (d0, d3, d2)>
diff --git a/compiler/src/iree/compiler/Dialect/Flow/Transforms/test/clone_producers_into_dispatch_regions.mlir b/compiler/src/iree/compiler/DispatchCreation/test/clone_producers_into_dispatch_regions.mlir
similarity index 99%
rename from compiler/src/iree/compiler/Dialect/Flow/Transforms/test/clone_producers_into_dispatch_regions.mlir
rename to compiler/src/iree/compiler/DispatchCreation/test/clone_producers_into_dispatch_regions.mlir
index 844a581..3c65e7a 100644
--- a/compiler/src/iree/compiler/Dialect/Flow/Transforms/test/clone_producers_into_dispatch_regions.mlir
+++ b/compiler/src/iree/compiler/DispatchCreation/test/clone_producers_into_dispatch_regions.mlir
@@ -1,4 +1,4 @@
-// RUN: iree-opt --split-input-file --pass-pipeline="builtin.module(util.func(iree-flow-clone-producers-into-dispatch-regions))" %s | FileCheck %s
+// RUN: iree-opt --split-input-file --pass-pipeline="builtin.module(util.func(iree-dispatch-creation-clone-producers-into-dispatch-regions))" %s | FileCheck %s
util.func public @complex_element_type(%input: tensor<4xi32>, %table: tensor<8x2xcomplex<f32>>) -> tensor<4x2xcomplex<f32>> {
%c4095 = arith.constant 4095 : i32
diff --git a/compiler/src/iree/compiler/Dialect/Flow/Transforms/test/collapse_dimensions.mlir b/compiler/src/iree/compiler/DispatchCreation/test/collapse_dimensions.mlir
similarity index 99%
rename from compiler/src/iree/compiler/Dialect/Flow/Transforms/test/collapse_dimensions.mlir
rename to compiler/src/iree/compiler/DispatchCreation/test/collapse_dimensions.mlir
index d5e9079..5e21f86 100644
--- a/compiler/src/iree/compiler/Dialect/Flow/Transforms/test/collapse_dimensions.mlir
+++ b/compiler/src/iree/compiler/DispatchCreation/test/collapse_dimensions.mlir
@@ -1,4 +1,4 @@
-// RUN: iree-opt --split-input-file --pass-pipeline="builtin.module(util.func(iree-flow-collapse-dimensions))" %s | FileCheck %s
+// RUN: iree-opt --split-input-file --pass-pipeline="builtin.module(util.func(iree-dispatch-creation-collapse-dimensions))" %s | FileCheck %s
#map = affine_map<(d0, d1, d2) -> (d0, d1, d2)>
util.func public @do_not_collapse_cst_in_place(%arg0: tensor<1x1x2304xf32>) {
diff --git a/compiler/src/iree/compiler/Dialect/Flow/Transforms/test/collapse_linalg_generic_on_tensors.mlir b/compiler/src/iree/compiler/DispatchCreation/test/collapse_linalg_generic_on_tensors.mlir
similarity index 98%
rename from compiler/src/iree/compiler/Dialect/Flow/Transforms/test/collapse_linalg_generic_on_tensors.mlir
rename to compiler/src/iree/compiler/DispatchCreation/test/collapse_linalg_generic_on_tensors.mlir
index 376f1d8..182af1c 100644
--- a/compiler/src/iree/compiler/Dialect/Flow/Transforms/test/collapse_linalg_generic_on_tensors.mlir
+++ b/compiler/src/iree/compiler/DispatchCreation/test/collapse_linalg_generic_on_tensors.mlir
@@ -1,4 +1,4 @@
-// RUN: iree-opt --split-input-file --pass-pipeline="builtin.module(util.func(iree-flow-form-dispatch-regions{aggressive-fusion=true}, iree-flow-clone-producers-into-dispatch-regions, iree-flow-collapse-dimensions, cse))" %s | FileCheck %s
+// RUN: iree-opt --split-input-file --pass-pipeline="builtin.module(util.func(iree-dispatch-creation-form-dispatch-regions{aggressive-fusion=true}, iree-dispatch-creation-clone-producers-into-dispatch-regions, iree-dispatch-creation-collapse-dimensions, cse))" %s | FileCheck %s
!type = tensor<2x4x8x16x32x64xf32>
util.global private @"__transpose_10_input" {inlining_policy = #util.inline.never} = dense<1.0> : !type
diff --git a/compiler/src/iree/compiler/Dialect/Flow/Transforms/test/collapse_reduction.mlir b/compiler/src/iree/compiler/DispatchCreation/test/collapse_reduction.mlir
similarity index 96%
rename from compiler/src/iree/compiler/Dialect/Flow/Transforms/test/collapse_reduction.mlir
rename to compiler/src/iree/compiler/DispatchCreation/test/collapse_reduction.mlir
index 013df03..b58f9be 100644
--- a/compiler/src/iree/compiler/Dialect/Flow/Transforms/test/collapse_reduction.mlir
+++ b/compiler/src/iree/compiler/DispatchCreation/test/collapse_reduction.mlir
@@ -1,4 +1,4 @@
-// RUN: iree-opt --split-input-file -iree-flow-collapse-reduction-dimensions %s | FileCheck %s
+// RUN: iree-opt --split-input-file -iree-dispatch-creation-collapse-reduction-dimensions %s | FileCheck %s
util.func public @multi_reduce_dim(%arg0: tensor<2x32x10x4096xf32>) -> tensor<2x32x1x1xf32> {
%cst = arith.constant -0.000000e+00 : f32
diff --git a/compiler/src/iree/compiler/Dialect/Flow/Transforms/test/convert_region_to_workgroups.mlir b/compiler/src/iree/compiler/DispatchCreation/test/convert_region_to_workgroups.mlir
similarity index 96%
rename from compiler/src/iree/compiler/Dialect/Flow/Transforms/test/convert_region_to_workgroups.mlir
rename to compiler/src/iree/compiler/DispatchCreation/test/convert_region_to_workgroups.mlir
index 92a0208..e77cf35 100644
--- a/compiler/src/iree/compiler/Dialect/Flow/Transforms/test/convert_region_to_workgroups.mlir
+++ b/compiler/src/iree/compiler/DispatchCreation/test/convert_region_to_workgroups.mlir
@@ -1,4 +1,4 @@
-// RUN: iree-opt %s --pass-pipeline="builtin.module(util.func(iree-flow-convert-dispatch-regions-to-workgroups, iree-flow-canonicalize, cse))" -split-input-file | FileCheck %s
+// RUN: iree-opt %s --pass-pipeline="builtin.module(util.func(iree-dispatch-creation-convert-dispatch-regions-to-workgroups, iree-flow-canonicalize, cse))" -split-input-file | FileCheck %s
util.global private @device : !hal.device
diff --git a/compiler/src/iree/compiler/Dialect/Flow/Transforms/test/dispatch_linalg_ext_fusion.mlir b/compiler/src/iree/compiler/DispatchCreation/test/dispatch_linalg_ext_fusion.mlir
similarity index 96%
rename from compiler/src/iree/compiler/Dialect/Flow/Transforms/test/dispatch_linalg_ext_fusion.mlir
rename to compiler/src/iree/compiler/DispatchCreation/test/dispatch_linalg_ext_fusion.mlir
index 25385da..993b690 100644
--- a/compiler/src/iree/compiler/Dialect/Flow/Transforms/test/dispatch_linalg_ext_fusion.mlir
+++ b/compiler/src/iree/compiler/DispatchCreation/test/dispatch_linalg_ext_fusion.mlir
@@ -1,4 +1,4 @@
-// RUN: iree-opt --split-input-file --verify-diagnostics --pass-pipeline="builtin.module(util.func(iree-flow-form-dispatch-regions{aggressive-fusion=true}, iree-flow-clone-producers-into-dispatch-regions), cse, canonicalize, cse)" %s | FileCheck %s
+// RUN: iree-opt --split-input-file --verify-diagnostics --pass-pipeline="builtin.module(util.func(iree-dispatch-creation-form-dispatch-regions{aggressive-fusion=true}, iree-dispatch-creation-clone-producers-into-dispatch-regions), cse, canonicalize, cse)" %s | FileCheck %s
#map = affine_map<(d0, d1) -> (d0, d1)>
#map1 = affine_map<(d0, d1, d2, d3, d4) -> (d0, d1, d2, d3, d4)>
diff --git a/compiler/src/iree/compiler/Dialect/Flow/Transforms/test/dispatch_linalg_on_tensors.mlir b/compiler/src/iree/compiler/DispatchCreation/test/dispatch_linalg_on_tensors.mlir
similarity index 99%
rename from compiler/src/iree/compiler/Dialect/Flow/Transforms/test/dispatch_linalg_on_tensors.mlir
rename to compiler/src/iree/compiler/DispatchCreation/test/dispatch_linalg_on_tensors.mlir
index ea22db1..eb78651 100644
--- a/compiler/src/iree/compiler/Dialect/Flow/Transforms/test/dispatch_linalg_on_tensors.mlir
+++ b/compiler/src/iree/compiler/DispatchCreation/test/dispatch_linalg_on_tensors.mlir
@@ -1,4 +1,5 @@
-// RUN: iree-opt --split-input-file --verify-diagnostics --pass-pipeline="builtin.module(util.func(iree-flow-form-dispatch-regions{aggressive-fusion=true}, iree-flow-clone-producers-into-dispatch-regions, iree-flow-convert-dispatch-regions-to-workgroups, iree-flow-convert-tensor-to-flow, canonicalize, iree-flow-materialize-default-workgroup-count-region), cse, iree-flow-canonicalize, cse)" %s | FileCheck %s
+// RUN: iree-opt --split-input-file --verify-diagnostics --pass-pipeline="builtin.module(util.func(iree-dispatch-creation-form-dispatch-regions{aggressive-fusion=true}, iree-dispatch-creation-clone-producers-into-dispatch-regions, iree-dispatch-creation-convert-dispatch-regions-to-workgroups, iree-dispatch-creation-convert-tensor-to-flow, canonicalize, iree-dispatch-creation-materialize-default-workgroup-count-region), cse, iree-flow-canonicalize, cse)" %s | FileCheck %s
+
util.func public @tile_matmul_alone(%arg0 : tensor<?x?xf32>, %arg1 : tensor<?x?xf32>,
%arg2 : tensor<?x?xf32>) -> tensor<?x?xf32> {
%1 = linalg.matmul ins(%arg0, %arg1 : tensor<?x?xf32>, tensor<?x?xf32>)
diff --git a/compiler/src/iree/compiler/Dialect/Flow/Transforms/test/dispatch_linalg_on_tensors_default.mlir b/compiler/src/iree/compiler/DispatchCreation/test/dispatch_linalg_on_tensors_default.mlir
similarity index 95%
rename from compiler/src/iree/compiler/Dialect/Flow/Transforms/test/dispatch_linalg_on_tensors_default.mlir
rename to compiler/src/iree/compiler/DispatchCreation/test/dispatch_linalg_on_tensors_default.mlir
index 897fbcc..6deeda3 100644
--- a/compiler/src/iree/compiler/Dialect/Flow/Transforms/test/dispatch_linalg_on_tensors_default.mlir
+++ b/compiler/src/iree/compiler/DispatchCreation/test/dispatch_linalg_on_tensors_default.mlir
@@ -1,4 +1,4 @@
-// RUN: iree-opt --split-input-file --verify-diagnostics --pass-pipeline="builtin.module(util.func(iree-flow-form-dispatch-regions, iree-flow-clone-producers-into-dispatch-regions,iree-flow-convert-dispatch-regions-to-workgroups), cse, iree-flow-canonicalize, cse)" %s | FileCheck %s
+// RUN: iree-opt --split-input-file --verify-diagnostics --pass-pipeline="builtin.module(util.func(iree-dispatch-creation-form-dispatch-regions, iree-dispatch-creation-clone-producers-into-dispatch-regions,iree-dispatch-creation-convert-dispatch-regions-to-workgroups), cse, iree-flow-canonicalize, cse)" %s | FileCheck %s
util.func public @no_fuse_quantized(%arg0 : tensor<?x113x113x64xi8>, %arg1 : tensor<3x3x64xi8>,
%arg2 : i32, %arg3 : i32) -> tensor<?x56x56x64xi8> {
diff --git a/compiler/src/iree/compiler/Dialect/Flow/Transforms/test/dispatch_linalg_on_tensors_fusion_with_transpose.mlir b/compiler/src/iree/compiler/DispatchCreation/test/dispatch_linalg_on_tensors_fusion_with_transpose.mlir
similarity index 86%
rename from compiler/src/iree/compiler/Dialect/Flow/Transforms/test/dispatch_linalg_on_tensors_fusion_with_transpose.mlir
rename to compiler/src/iree/compiler/DispatchCreation/test/dispatch_linalg_on_tensors_fusion_with_transpose.mlir
index a091b90..68e4e9e 100644
--- a/compiler/src/iree/compiler/Dialect/Flow/Transforms/test/dispatch_linalg_on_tensors_fusion_with_transpose.mlir
+++ b/compiler/src/iree/compiler/DispatchCreation/test/dispatch_linalg_on_tensors_fusion_with_transpose.mlir
@@ -1,4 +1,4 @@
-// RUN: iree-opt --split-input-file --verify-diagnostics --pass-pipeline="builtin.module(util.func(iree-flow-transpose-generic-ops,iree-flow-form-dispatch-regions{aggressive-fusion=true}, iree-flow-convert-dispatch-regions-to-workgroups, iree-flow-canonicalize, cse))" --mlir-print-local-scope %s | FileCheck %s
+// RUN: iree-opt --split-input-file --verify-diagnostics --pass-pipeline="builtin.module(util.func(iree-dispatch-creation-transpose-generic-ops,iree-dispatch-creation-form-dispatch-regions{aggressive-fusion=true}, iree-dispatch-creation-convert-dispatch-regions-to-workgroups, canonicalize, cse))" --mlir-print-local-scope %s | FileCheck %s
util.func @fuse_conv(%arg0 : tensor<2x130x130x16xf32>, %arg1 : tensor<3x3x16x320xf32>) -> tensor<2x320x128x128xf32> {
%empty = tensor.empty() : tensor<2x128x128x320xf32>
diff --git a/compiler/src/iree/compiler/Dialect/Flow/Transforms/test/dispatch_linalg_transform_dialect.mlir b/compiler/src/iree/compiler/DispatchCreation/test/dispatch_linalg_transform_dialect.mlir
similarity index 92%
rename from compiler/src/iree/compiler/Dialect/Flow/Transforms/test/dispatch_linalg_transform_dialect.mlir
rename to compiler/src/iree/compiler/DispatchCreation/test/dispatch_linalg_transform_dialect.mlir
index bcb65c8..e2e0ef8 100644
--- a/compiler/src/iree/compiler/Dialect/Flow/Transforms/test/dispatch_linalg_transform_dialect.mlir
+++ b/compiler/src/iree/compiler/DispatchCreation/test/dispatch_linalg_transform_dialect.mlir
@@ -1,4 +1,4 @@
-// RUN: iree-opt --pass-pipeline="builtin.module(util.func(iree-flow-dispatch-with-transform-dialect{transform-spec-path=%p/transform_dialect_dispatch_spec.mlir}))" %s | \
+// RUN: iree-opt --pass-pipeline="builtin.module(util.func(iree-dispatch-creation-dispatch-with-transform-dialect{transform-spec-path=%p/transform_dialect_dispatch_spec.mlir}))" %s | \
// RUN: FileCheck %s
util.func public @tile_matmul_alone(%arg0 : tensor<?x?xf32>, %arg1 : tensor<?x?xf32>,
diff --git a/compiler/src/iree/compiler/Dialect/Flow/Transforms/test/dispatch_region_formation_preprocessing.mlir b/compiler/src/iree/compiler/DispatchCreation/test/dispatch_region_formation_preprocessing.mlir
similarity index 99%
rename from compiler/src/iree/compiler/Dialect/Flow/Transforms/test/dispatch_region_formation_preprocessing.mlir
rename to compiler/src/iree/compiler/DispatchCreation/test/dispatch_region_formation_preprocessing.mlir
index 1e855d1..30e78fc 100644
--- a/compiler/src/iree/compiler/Dialect/Flow/Transforms/test/dispatch_region_formation_preprocessing.mlir
+++ b/compiler/src/iree/compiler/DispatchCreation/test/dispatch_region_formation_preprocessing.mlir
@@ -1,4 +1,4 @@
-// RUN: iree-opt --split-input-file --pass-pipeline="builtin.module(iree-flow-dispatch-region-formation-preprocessing-pipeline)" %s | FileCheck %s
+// RUN: iree-opt --split-input-file --pass-pipeline="builtin.module(iree-dispatch-creation-preprocessing-pipeline)" %s | FileCheck %s
util.func public @softmax(%arg0 : tensor<12x128x128xf32>) -> tensor<12x128x128xf32> {
%cst = arith.constant 1.000000e+00 : f32
diff --git a/compiler/src/iree/compiler/Dialect/Flow/Transforms/test/fold_unit_dims.mlir b/compiler/src/iree/compiler/DispatchCreation/test/fold_unit_dims.mlir
similarity index 97%
rename from compiler/src/iree/compiler/Dialect/Flow/Transforms/test/fold_unit_dims.mlir
rename to compiler/src/iree/compiler/DispatchCreation/test/fold_unit_dims.mlir
index e652c5e..249a8b1 100644
--- a/compiler/src/iree/compiler/Dialect/Flow/Transforms/test/fold_unit_dims.mlir
+++ b/compiler/src/iree/compiler/DispatchCreation/test/fold_unit_dims.mlir
@@ -1,4 +1,4 @@
-// RUN: iree-opt --pass-pipeline="builtin.module(iree-flow-fold-unit-extent-dims)" %s --split-input-file | FileCheck %s
+// RUN: iree-opt --pass-pipeline="builtin.module(iree-dispatch-creation-fold-unit-extent-dims)" %s --split-input-file | FileCheck %s
util.func public @no_fold_unit_dims_in_dispatches(%arg0 : tensor<1x1x10xf32>) -> tensor<1x1x10xf32> {
%0 = tensor.empty() : tensor<1x1x10xf32>
diff --git a/compiler/src/iree/compiler/Dialect/Flow/Transforms/test/form_dispatch_regions.mlir b/compiler/src/iree/compiler/DispatchCreation/test/form_dispatch_regions.mlir
similarity index 99%
rename from compiler/src/iree/compiler/Dialect/Flow/Transforms/test/form_dispatch_regions.mlir
rename to compiler/src/iree/compiler/DispatchCreation/test/form_dispatch_regions.mlir
index b59e4a2..bfceeda 100644
--- a/compiler/src/iree/compiler/Dialect/Flow/Transforms/test/form_dispatch_regions.mlir
+++ b/compiler/src/iree/compiler/DispatchCreation/test/form_dispatch_regions.mlir
@@ -1,4 +1,4 @@
-// RUN: iree-opt --pass-pipeline="builtin.module(util.func(iree-flow-form-dispatch-regions{aggressive-fusion=true}))" --split-input-file %s | FileCheck %s
+// RUN: iree-opt --pass-pipeline="builtin.module(util.func(iree-dispatch-creation-form-dispatch-regions{aggressive-fusion=true}))" --split-input-file %s | FileCheck %s
util.func public @pack_elementwise_fusion(%arg0 : tensor<?xf32>,
%arg1 : tensor<?x?xf32>) -> tensor<?x?x8x32xf32> {
diff --git a/compiler/src/iree/compiler/Dialect/Flow/Transforms/test/form_dispatch_workgroups.mlir b/compiler/src/iree/compiler/DispatchCreation/test/form_dispatch_workgroups.mlir
similarity index 94%
rename from compiler/src/iree/compiler/Dialect/Flow/Transforms/test/form_dispatch_workgroups.mlir
rename to compiler/src/iree/compiler/DispatchCreation/test/form_dispatch_workgroups.mlir
index ea8fa00..f3a687b 100644
--- a/compiler/src/iree/compiler/Dialect/Flow/Transforms/test/form_dispatch_workgroups.mlir
+++ b/compiler/src/iree/compiler/DispatchCreation/test/form_dispatch_workgroups.mlir
@@ -1,4 +1,4 @@
-// RUN: iree-opt --pass-pipeline="builtin.module(util.func(iree-flow-convert-dispatch-regions-to-workgroups))" --split-input-file %s | FileCheck %s
+// RUN: iree-opt --pass-pipeline="builtin.module(util.func(iree-dispatch-creation-convert-dispatch-regions-to-workgroups))" --split-input-file %s | FileCheck %s
util.func public @existing_count_region(%arg0 : index, %arg1 : index) -> tensor<?x?xf32> {
%c1 = arith.constant 1 : index
diff --git a/compiler/src/iree/compiler/Dialect/Flow/Transforms/test/form_scalar_dispatches.mlir b/compiler/src/iree/compiler/DispatchCreation/test/form_scalar_dispatches.mlir
similarity index 98%
rename from compiler/src/iree/compiler/Dialect/Flow/Transforms/test/form_scalar_dispatches.mlir
rename to compiler/src/iree/compiler/DispatchCreation/test/form_scalar_dispatches.mlir
index f91d7c5..930477a 100644
--- a/compiler/src/iree/compiler/Dialect/Flow/Transforms/test/form_scalar_dispatches.mlir
+++ b/compiler/src/iree/compiler/DispatchCreation/test/form_scalar_dispatches.mlir
@@ -1,4 +1,4 @@
-// RUN: iree-opt --pass-pipeline="builtin.module(util.func(iree-flow-form-scalar-dispatches))" --split-input-file %s | FileCheck %s
+// RUN: iree-opt --pass-pipeline="builtin.module(util.func(iree-dispatch-creation-form-scalar-dispatches))" --split-input-file %s | FileCheck %s
#map = affine_map<() -> ()>
util.func public @simpleDAG(
diff --git a/compiler/src/iree/compiler/Dialect/Flow/Transforms/test/fuse_horizontal_contractions.mlir b/compiler/src/iree/compiler/DispatchCreation/test/fuse_horizontal_contractions.mlir
similarity index 99%
rename from compiler/src/iree/compiler/Dialect/Flow/Transforms/test/fuse_horizontal_contractions.mlir
rename to compiler/src/iree/compiler/DispatchCreation/test/fuse_horizontal_contractions.mlir
index d43fdab..a872654 100644
--- a/compiler/src/iree/compiler/Dialect/Flow/Transforms/test/fuse_horizontal_contractions.mlir
+++ b/compiler/src/iree/compiler/DispatchCreation/test/fuse_horizontal_contractions.mlir
@@ -1,4 +1,4 @@
-// RUN: iree-opt --pass-pipeline="builtin.module(util.func(iree-flow-fuse-horizontal-contractions))" --split-input-file %s | FileCheck %s
+// RUN: iree-opt --pass-pipeline="builtin.module(util.func(iree-dispatch-creation-fuse-horizontal-contractions))" --split-input-file %s | FileCheck %s
#map = affine_map<(d0, d1, d2, d3, d4) -> (d0, d2, d4)>
#map1 = affine_map<(d0, d1, d2, d3, d4) -> (d1, d3, d4)>
diff --git a/compiler/src/iree/compiler/Dialect/Flow/Transforms/test/fuse_multiuse_elementwise_producer.mlir b/compiler/src/iree/compiler/DispatchCreation/test/fuse_multiuse_elementwise_producer.mlir
similarity index 98%
rename from compiler/src/iree/compiler/Dialect/Flow/Transforms/test/fuse_multiuse_elementwise_producer.mlir
rename to compiler/src/iree/compiler/DispatchCreation/test/fuse_multiuse_elementwise_producer.mlir
index 5c4e8f9..cc3e159 100644
--- a/compiler/src/iree/compiler/Dialect/Flow/Transforms/test/fuse_multiuse_elementwise_producer.mlir
+++ b/compiler/src/iree/compiler/DispatchCreation/test/fuse_multiuse_elementwise_producer.mlir
@@ -1,4 +1,4 @@
-// RUN: iree-opt --pass-pipeline="builtin.module(util.func(iree-flow-fuse-multi-use-elementwise-producer))" --split-input-file %s | FileCheck %s
+// RUN: iree-opt --pass-pipeline="builtin.module(util.func(iree-dispatch-creation-fuse-multi-use-elementwise-producer))" --split-input-file %s | FileCheck %s
util.func public @batchnorm_training(%10 : tensor<12xf32>, %11 : tensor<12x12x12x12x12xf32>, %12 : tensor<12xf32>) -> (tensor<12xf32>, tensor<12xf32>, tensor<12xf32>) {
%cst = arith.constant 1.42 : f32
diff --git a/compiler/src/iree/compiler/Dialect/Flow/Transforms/test/fusion_preprocessing.mlir b/compiler/src/iree/compiler/DispatchCreation/test/fusion_preprocessing.mlir
similarity index 98%
rename from compiler/src/iree/compiler/Dialect/Flow/Transforms/test/fusion_preprocessing.mlir
rename to compiler/src/iree/compiler/DispatchCreation/test/fusion_preprocessing.mlir
index eacef6a..209785d 100644
--- a/compiler/src/iree/compiler/Dialect/Flow/Transforms/test/fusion_preprocessing.mlir
+++ b/compiler/src/iree/compiler/DispatchCreation/test/fusion_preprocessing.mlir
@@ -1,4 +1,4 @@
-// RUN: iree-opt --iree-flow-fusion-preprocessing --split-input-file %s | FileCheck %s
+// RUN: iree-opt --iree-dispatch-creation-fusion-preprocessing --split-input-file %s | FileCheck %s
util.func public @fold_insert_slices(%source : tensor<?x?xf32>,
%dest0 : tensor<?x?xf32>, %dest1 : tensor<?x?xf32>, %val: f32,
diff --git a/compiler/src/iree/compiler/Dialect/Flow/Transforms/test/hoist_encoding_ops.mlir b/compiler/src/iree/compiler/DispatchCreation/test/hoist_encoding_ops.mlir
similarity index 99%
rename from compiler/src/iree/compiler/Dialect/Flow/Transforms/test/hoist_encoding_ops.mlir
rename to compiler/src/iree/compiler/DispatchCreation/test/hoist_encoding_ops.mlir
index ca97bb0..ad5935e 100644
--- a/compiler/src/iree/compiler/Dialect/Flow/Transforms/test/hoist_encoding_ops.mlir
+++ b/compiler/src/iree/compiler/DispatchCreation/test/hoist_encoding_ops.mlir
@@ -1,4 +1,4 @@
-// RUN: iree-opt --pass-pipeline="builtin.module(util.func(iree-flow-hoist-encoding-ops))" --split-input-file %s | FileCheck %s
+// RUN: iree-opt --pass-pipeline="builtin.module(util.func(iree-dispatch-creation-hoist-encoding-ops))" --split-input-file %s | FileCheck %s
#map1 = affine_map<(d0, d1, d2, d3) -> (d0, d3, d2)>
#map2 = affine_map<(d0, d1, d2, d3) -> (d0, d1, d3)>
diff --git a/compiler/src/iree/compiler/Dialect/Flow/Transforms/test/pad_fusion_with_consumer.mlir b/compiler/src/iree/compiler/DispatchCreation/test/pad_fusion_with_consumer.mlir
similarity index 92%
rename from compiler/src/iree/compiler/Dialect/Flow/Transforms/test/pad_fusion_with_consumer.mlir
rename to compiler/src/iree/compiler/DispatchCreation/test/pad_fusion_with_consumer.mlir
index 2945fca..23e01ee 100644
--- a/compiler/src/iree/compiler/Dialect/Flow/Transforms/test/pad_fusion_with_consumer.mlir
+++ b/compiler/src/iree/compiler/DispatchCreation/test/pad_fusion_with_consumer.mlir
@@ -1,4 +1,4 @@
-// RUN: iree-opt --pass-pipeline="builtin.module(util.func(iree-flow-form-dispatch-regions{fuse-pad-with-consumers}))" --split-input-file %s | FileCheck %s
+// RUN: iree-opt --pass-pipeline="builtin.module(util.func(iree-dispatch-creation-form-dispatch-regions{fuse-pad-with-consumers}))" --split-input-file %s | FileCheck %s
util.func public @fuse_with_consumer(%arg0 : tensor<?x?x?x?xf32>, %arg1 : index,
%arg2 : index, %arg3 : index, %arg4 : index,
diff --git a/compiler/src/iree/compiler/Dialect/Flow/Transforms/test/pad_fusion_with_producer.mlir b/compiler/src/iree/compiler/DispatchCreation/test/pad_fusion_with_producer.mlir
similarity index 95%
rename from compiler/src/iree/compiler/Dialect/Flow/Transforms/test/pad_fusion_with_producer.mlir
rename to compiler/src/iree/compiler/DispatchCreation/test/pad_fusion_with_producer.mlir
index 541a142..f3345dd 100644
--- a/compiler/src/iree/compiler/Dialect/Flow/Transforms/test/pad_fusion_with_producer.mlir
+++ b/compiler/src/iree/compiler/DispatchCreation/test/pad_fusion_with_producer.mlir
@@ -1,4 +1,4 @@
-// RUN: iree-opt --pass-pipeline="builtin.module(util.func(iree-flow-form-dispatch-regions{fuse-pad-with-producers}))" --split-input-file %s | FileCheck %s
+// RUN: iree-opt --pass-pipeline="builtin.module(util.func(iree-dispatch-creation-form-dispatch-regions{fuse-pad-with-producers}))" --split-input-file %s | FileCheck %s
util.func public @fuse_pad_with_producer(%arg0 : tensor<?x?x?x?xf32>,
%arg1 : tensor<?x?x?x?xf32>, %arg2 : tensor<?x?x?x?xf32>,
diff --git a/compiler/src/iree/compiler/Dialect/Flow/Transforms/test/set_encoding.mlir b/compiler/src/iree/compiler/DispatchCreation/test/set_encoding.mlir
similarity index 99%
rename from compiler/src/iree/compiler/Dialect/Flow/Transforms/test/set_encoding.mlir
rename to compiler/src/iree/compiler/DispatchCreation/test/set_encoding.mlir
index 7fde2d7..800dd04 100644
--- a/compiler/src/iree/compiler/Dialect/Flow/Transforms/test/set_encoding.mlir
+++ b/compiler/src/iree/compiler/DispatchCreation/test/set_encoding.mlir
@@ -1,4 +1,4 @@
-// RUN: iree-opt --split-input-file --pass-pipeline="builtin.module(util.func(iree-flow-set-encoding))" %s | FileCheck %s
+// RUN: iree-opt --split-input-file --pass-pipeline="builtin.module(util.func(iree-dispatch-creation-set-encoding))" %s | FileCheck %s
util.func public @matmul_f32f32f32(%arg0 : tensor<100x250xf32>, %arg1 : tensor<250x500xf32>,
%arg2 : tensor<100x500xf32>) -> tensor<100x500xf32> {
diff --git a/compiler/src/iree/compiler/Dialect/Flow/Transforms/test/sink_reshapes.mlir b/compiler/src/iree/compiler/DispatchCreation/test/sink_reshapes.mlir
similarity index 98%
rename from compiler/src/iree/compiler/Dialect/Flow/Transforms/test/sink_reshapes.mlir
rename to compiler/src/iree/compiler/DispatchCreation/test/sink_reshapes.mlir
index 92587fd..15a7e39 100644
--- a/compiler/src/iree/compiler/Dialect/Flow/Transforms/test/sink_reshapes.mlir
+++ b/compiler/src/iree/compiler/DispatchCreation/test/sink_reshapes.mlir
@@ -1,4 +1,4 @@
-// RUN: iree-opt --pass-pipeline="builtin.module(func.func(iree-flow-sink-reshapes))" --split-input-file %s | FileCheck %s
+// RUN: iree-opt --pass-pipeline="builtin.module(func.func(iree-dispatch-creation-sink-reshapes))" --split-input-file %s | FileCheck %s
/// For a `tensor.expand_shape` -> consumer pair, if the consumer
/// can already be fused with an op by tile and fuse, do nothing. In
diff --git a/compiler/src/iree/compiler/Dialect/Flow/Transforms/test/split_reduction.mlir b/compiler/src/iree/compiler/DispatchCreation/test/split_reduction.mlir
similarity index 90%
rename from compiler/src/iree/compiler/Dialect/Flow/Transforms/test/split_reduction.mlir
rename to compiler/src/iree/compiler/DispatchCreation/test/split_reduction.mlir
index b9c2742..2ce033a 100644
--- a/compiler/src/iree/compiler/Dialect/Flow/Transforms/test/split_reduction.mlir
+++ b/compiler/src/iree/compiler/DispatchCreation/test/split_reduction.mlir
@@ -1,4 +1,4 @@
-// RUN: iree-opt --pass-pipeline='builtin.module(util.func(iree-flow-split-reduction-ops))' --iree-flow-split-matmul-reduction=4 %s | FileCheck %s
+// RUN: iree-opt --pass-pipeline='builtin.module(util.func(iree-dispatch-creation-split-reduction-ops))' --iree-dispatch-creation-split-matmul-reduction=4 %s | FileCheck %s
#compilation = #iree_codegen.compilation_info<
lowering_config = #iree_codegen.lowering_config<tile_sizes = [[64, 64, 0]]>,
diff --git a/compiler/src/iree/compiler/Dialect/Flow/Transforms/test/tensor_pad_to_tensor_insert_slice.mlir b/compiler/src/iree/compiler/DispatchCreation/test/tensor_pad_to_tensor_insert_slice.mlir
similarity index 94%
rename from compiler/src/iree/compiler/Dialect/Flow/Transforms/test/tensor_pad_to_tensor_insert_slice.mlir
rename to compiler/src/iree/compiler/DispatchCreation/test/tensor_pad_to_tensor_insert_slice.mlir
index 357a2b5..62ed8cc 100644
--- a/compiler/src/iree/compiler/Dialect/Flow/Transforms/test/tensor_pad_to_tensor_insert_slice.mlir
+++ b/compiler/src/iree/compiler/DispatchCreation/test/tensor_pad_to_tensor_insert_slice.mlir
@@ -1,5 +1,5 @@
-// RUN: iree-opt --split-input-file --iree-flow-tensor-pad-to-tensor-insert-slice --iree-flow-canonicalize %s | FileCheck %s
-// RUN: iree-opt --split-input-file --iree-flow-tensor-pad-to-tensor-insert-slice=skip-one-linalg-use-case --iree-flow-canonicalize %s | FileCheck %s --check-prefix=SKIP
+// RUN: iree-opt --split-input-file --iree-dispatch-creation-tensor-pad-to-tensor-insert-slice --iree-flow-canonicalize %s | FileCheck %s
+// RUN: iree-opt --split-input-file --iree-dispatch-creation-tensor-pad-to-tensor-insert-slice=skip-one-linalg-use-case --iree-flow-canonicalize %s | FileCheck %s --check-prefix=SKIP
util.func public @tensor_pad(%arg0 : tensor<?x?xf32>, %arg1 : tensor<f32>, %arg2 : index, %arg3 : index) -> tensor<?x?xf32> {
%c0 = arith.constant 0 : index
diff --git a/compiler/src/iree/compiler/Dialect/Flow/Transforms/test/transform_dialect_dispatch_spec.mlir b/compiler/src/iree/compiler/DispatchCreation/test/transform_dialect_dispatch_spec.mlir
similarity index 100%
rename from compiler/src/iree/compiler/Dialect/Flow/Transforms/test/transform_dialect_dispatch_spec.mlir
rename to compiler/src/iree/compiler/DispatchCreation/test/transform_dialect_dispatch_spec.mlir
diff --git a/compiler/src/iree/compiler/Dialect/Flow/Transforms/test/transform_dispatch_region_formation.mlir b/compiler/src/iree/compiler/DispatchCreation/test/transform_dispatch_region_formation.mlir
similarity index 100%
rename from compiler/src/iree/compiler/Dialect/Flow/Transforms/test/transform_dispatch_region_formation.mlir
rename to compiler/src/iree/compiler/DispatchCreation/test/transform_dispatch_region_formation.mlir
diff --git a/compiler/src/iree/compiler/Dialect/Flow/Transforms/test/transpose_generic_ops.mlir b/compiler/src/iree/compiler/DispatchCreation/test/transpose_generic_ops.mlir
similarity index 97%
rename from compiler/src/iree/compiler/Dialect/Flow/Transforms/test/transpose_generic_ops.mlir
rename to compiler/src/iree/compiler/DispatchCreation/test/transpose_generic_ops.mlir
index e857cdb..e278ef9 100644
--- a/compiler/src/iree/compiler/Dialect/Flow/Transforms/test/transpose_generic_ops.mlir
+++ b/compiler/src/iree/compiler/DispatchCreation/test/transpose_generic_ops.mlir
@@ -1,4 +1,4 @@
-// RUN: iree-opt --split-input-file --verify-diagnostics --iree-flow-transpose-generic-ops --iree-flow-canonicalize -cse --mlir-print-local-scope %s | FileCheck %s
+// RUN: iree-opt --split-input-file --verify-diagnostics --iree-dispatch-creation-transpose-generic-ops -canonicalize -cse --mlir-print-local-scope %s | FileCheck %s
util.func @supported_conv(%arg0 : tensor<2x130x130x16xf16>, %arg1 : tensor<3x3x16x320xf16>) -> tensor<2x320x128x128xf16> {
%empty = tensor.empty() : tensor<2x128x128x320xf32>
diff --git a/compiler/src/iree/compiler/GlobalOptimization/BUILD.bazel b/compiler/src/iree/compiler/GlobalOptimization/BUILD.bazel
index d03b723..5575a05 100644
--- a/compiler/src/iree/compiler/GlobalOptimization/BUILD.bazel
+++ b/compiler/src/iree/compiler/GlobalOptimization/BUILD.bazel
@@ -92,6 +92,7 @@
"//compiler/src/iree/compiler/Dialect/Util/Analysis/DFX",
"//compiler/src/iree/compiler/Dialect/Util/IR",
"//compiler/src/iree/compiler/Dialect/Util/Transforms",
+ "//compiler/src/iree/compiler/DispatchCreation",
"//compiler/src/iree/compiler/Modules/IO/Parameters/Transforms",
"//compiler/src/iree/compiler/Pipelines:Options",
"//compiler/src/iree/compiler/Utils",
diff --git a/compiler/src/iree/compiler/GlobalOptimization/CMakeLists.txt b/compiler/src/iree/compiler/GlobalOptimization/CMakeLists.txt
index a2f6b63..19e2d0e 100644
--- a/compiler/src/iree/compiler/GlobalOptimization/CMakeLists.txt
+++ b/compiler/src/iree/compiler/GlobalOptimization/CMakeLists.txt
@@ -107,6 +107,7 @@
iree::compiler::Dialect::Util::Analysis::DFX
iree::compiler::Dialect::Util::IR
iree::compiler::Dialect::Util::Transforms
+ iree::compiler::DispatchCreation
iree::compiler::Modules::IO::Parameters::Transforms
iree::compiler::Pipelines::Options
iree::compiler::Utils
diff --git a/compiler/src/iree/compiler/GlobalOptimization/Passes.cpp b/compiler/src/iree/compiler/GlobalOptimization/Passes.cpp
index 50197eb..ec684f1 100644
--- a/compiler/src/iree/compiler/GlobalOptimization/Passes.cpp
+++ b/compiler/src/iree/compiler/GlobalOptimization/Passes.cpp
@@ -7,6 +7,7 @@
#include "iree/compiler/Dialect/Flow/IR/FlowDialect.h"
#include "iree/compiler/Dialect/Flow/Transforms/Passes.h"
#include "iree/compiler/Dialect/Util/Transforms/Passes.h"
+#include "iree/compiler/DispatchCreation/Passes.h"
#include "iree/compiler/Modules/IO/Parameters/Transforms/Passes.h"
#include "iree/compiler/Utils/PassUtils.h"
#include "mlir/Dialect/Linalg/Passes.h"
@@ -127,7 +128,7 @@
// specialized raising and the op names are no longer useful.
.addPass(createGeneralizeLinalgNamedOpsPass);
- mainPassManager.addPass(IREE::Flow::createFoldUnitExtentDimsPass());
+ mainPassManager.addPass(DispatchCreation::createFoldUnitExtentDimsPass());
FunctionLikeNest(mainPassManager)
.addPredicatedPass(clEnableFuseSiluHorizontalMatmul,
createFuseSiluHorizontalMatmulPass)
@@ -157,8 +158,8 @@
// Enable data tiling after they are in a canonical form.
if (transformOptions.options.dataTiling) {
FunctionLikeNest(mainPassManager).addPass([&]() {
- return IREE::Flow::createSetEncodingPass(
- IREE::Flow::SetEncodingPassOptions{clPadFactor});
+ return DispatchCreation::createSetEncodingPass(
+ DispatchCreation::SetEncodingPassOptions{clPadFactor});
});
  // TODO(hanchung): Make the data-tiling passes FunctionOpInterface passes so
  // we can use `FunctionLikeNest` here.
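
The two GlobalOptimization hunks above show the reuse pattern for passes that moved: the module-scoped unit-dim folding pass is added directly, while SetEncoding stays nested on function-like ops. A condensed sketch under the same namespaces; the wrapper and its `padFactor` parameter are hypothetical stand-ins for the surrounding pipeline code and the `clPadFactor` option, and the `FunctionLikeNest` alias mirrors the one used in the surrounding Passes.cpp:

```cpp
#include "iree/compiler/Dialect/Util/IR/UtilOps.h"
#include "iree/compiler/DispatchCreation/Passes.h"
#include "iree/compiler/Utils/PassUtils.h"
#include "mlir/Dialect/Func/IR/FuncOps.h"
#include "mlir/Pass/PassManager.h"

namespace mlir::iree_compiler {

// Per-file alias as in the surrounding Passes.cpp: nests a pass on every
// function-like op in the module.
using FunctionLikeNest =
    MultiOpNest<func::FuncOp, IREE::Util::InitializerOp, IREE::Util::FuncOp>;

// Hypothetical wrapper condensing the two hunks above.
static void addBorrowedDispatchCreationPasses(OpPassManager &mainPassManager,
                                              int64_t padFactor) {
  // Module-scoped unit-dim folding, as in the first hunk.
  mainPassManager.addPass(DispatchCreation::createFoldUnitExtentDimsPass());
  // SetEncoding stays nested on function-like ops, as in the data-tiling hunk.
  FunctionLikeNest(mainPassManager).addPass([=]() {
    return DispatchCreation::createSetEncodingPass(
        DispatchCreation::SetEncodingPassOptions{padFactor});
  });
}

} // namespace mlir::iree_compiler
```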
diff --git a/compiler/src/iree/compiler/Pipelines/BUILD.bazel b/compiler/src/iree/compiler/Pipelines/BUILD.bazel
index d51bdf1..9407077 100644
--- a/compiler/src/iree/compiler/Pipelines/BUILD.bazel
+++ b/compiler/src/iree/compiler/Pipelines/BUILD.bazel
@@ -42,6 +42,7 @@
"//compiler/src/iree/compiler/Dialect/VM/Conversion",
"//compiler/src/iree/compiler/Dialect/VM/Target/Bytecode",
"//compiler/src/iree/compiler/Dialect/VM/Transforms",
+ "//compiler/src/iree/compiler/DispatchCreation",
"//compiler/src/iree/compiler/GlobalOptimization",
"//compiler/src/iree/compiler/InputConversion/Common",
"//compiler/src/iree/compiler/InputConversion/Common:AutoInputConversionPipeline",
diff --git a/compiler/src/iree/compiler/Pipelines/CMakeLists.txt b/compiler/src/iree/compiler/Pipelines/CMakeLists.txt
index 9bb62f8..530c63e 100644
--- a/compiler/src/iree/compiler/Pipelines/CMakeLists.txt
+++ b/compiler/src/iree/compiler/Pipelines/CMakeLists.txt
@@ -46,6 +46,7 @@
iree::compiler::Dialect::VM::Conversion
iree::compiler::Dialect::VM::Target::Bytecode
iree::compiler::Dialect::VM::Transforms
+ iree::compiler::DispatchCreation
iree::compiler::GlobalOptimization
iree::compiler::InputConversion::Common
iree::compiler::InputConversion::Common::AutoInputConversionPipeline
diff --git a/compiler/src/iree/compiler/Pipelines/Pipelines.cpp b/compiler/src/iree/compiler/Pipelines/Pipelines.cpp
index bd56dc8..77319e2 100644
--- a/compiler/src/iree/compiler/Pipelines/Pipelines.cpp
+++ b/compiler/src/iree/compiler/Pipelines/Pipelines.cpp
@@ -13,6 +13,7 @@
#include "iree/compiler/Dialect/Stream/Transforms/Passes.h"
#include "iree/compiler/Dialect/Util/Transforms/Passes.h"
#include "iree/compiler/Dialect/VM/Transforms/Passes.h"
+#include "iree/compiler/DispatchCreation/Passes.h"
#include "iree/compiler/GlobalOptimization/Passes.h"
#include "iree/compiler/InputConversion/Common/Passes.h"
#include "iree/compiler/Modules/HAL/Inline/Transforms/Passes.h"
@@ -271,6 +272,20 @@
// No flow/stream processing (implies no tensors).
break;
default:
+ DispatchCreation::TransformOptions dispatchCreationOptions;
+ if (compileFrom < IREEVMPipelinePhase::DispatchCreation) { // late-entry
+ IREE_TRACE_ADD_BEGIN_FRAME_PASS(passManager, "DispatchCreation");
+ if (hooks.beforePhase)
+ hooks.beforePhase(IREEVMPipelinePhase::DispatchCreation, passManager);
+ DispatchCreation::buildDispatchCreationPassPipeline(
+ passManager, dispatchCreationOptions);
+ if (hooks.afterPhase)
+ hooks.afterPhase(IREEVMPipelinePhase::DispatchCreation, passManager);
+ IREE_TRACE_ADD_END_FRAME_PASS(passManager, "DispatchCreation");
+ }
+ if (compileTo == IREEVMPipelinePhase::DispatchCreation)
+ return; // early-exit
+
IREE::Flow::TransformOptions flowOptions;
if (compileFrom < IREEVMPipelinePhase::Flow) { // late-entry
IREE_TRACE_ADD_BEGIN_FRAME_PASS(passManager, "Flow");
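
The new block mirrors the Flow block that follows it, including the late-entry and early-exit hooks. To isolate just this phase (the programmatic analogue of stopping compilation at the new `dispatch-creation` phase), a minimal sketch, assuming a default-constructed `TransformOptions` is valid as it is in the hunk above; the helper function is hypothetical:

```cpp
#include "iree/compiler/DispatchCreation/Passes.h"
#include "mlir/IR/BuiltinOps.h"
#include "mlir/Pass/PassManager.h"

// Builds and runs only the DispatchCreation pipeline on a module.
mlir::LogicalResult runDispatchCreationOnly(mlir::ModuleOp module) {
  mlir::PassManager passManager(module.getContext());
  mlir::iree_compiler::DispatchCreation::TransformOptions options;
  mlir::iree_compiler::DispatchCreation::buildDispatchCreationPassPipeline(
      passManager, options);
  return passManager.run(module);
}
```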
diff --git a/compiler/src/iree/compiler/Pipelines/Pipelines.h b/compiler/src/iree/compiler/Pipelines/Pipelines.h
index cdc754f..a7c3af1 100644
--- a/compiler/src/iree/compiler/Pipelines/Pipelines.h
+++ b/compiler/src/iree/compiler/Pipelines/Pipelines.h
@@ -28,6 +28,7 @@
ABI,
Preprocessing,
GlobalOptimization,
+ DispatchCreation,
Flow,
Stream,
ExecutableSources,
@@ -53,6 +54,8 @@
"Compiles up to the `preprocessing` specified");
callback(IREEVMPipelinePhase::GlobalOptimization, "global-optimization",
"Compiles up to global optimization.");
+ callback(IREEVMPipelinePhase::DispatchCreation, "dispatch-creation",
+ "Compiles up to dispatch creation.");
callback(IREEVMPipelinePhase::Flow, "flow",
"Compiles up to the `flow` dialect.");
callback(IREEVMPipelinePhase::Stream, "stream",
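
Position in this enum is load-bearing: the late-entry and early-exit checks in Pipelines.cpp compare phase values directly, so `DispatchCreation` must be declared between `GlobalOptimization` and `Flow` for the new block to run in the right slot. A sketch of the gating predicate those checks imply (the helper itself is hypothetical):

```cpp
// Phases are declared in execution order, so ordinary comparisons suffice.
bool phaseIsActive(IREEVMPipelinePhase compileFrom,
                   IREEVMPipelinePhase compileTo,
                   IREEVMPipelinePhase phase) {
  // Late-entry: resume after compileFrom; early-exit: stop once compileTo ran.
  return compileFrom < phase && phase <= compileTo;
}
```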
diff --git a/compiler/src/iree/compiler/Preprocessing/BUILD.bazel b/compiler/src/iree/compiler/Preprocessing/BUILD.bazel
index 399f8f2..8187291 100644
--- a/compiler/src/iree/compiler/Preprocessing/BUILD.bazel
+++ b/compiler/src/iree/compiler/Preprocessing/BUILD.bazel
@@ -21,7 +21,8 @@
"Passes.h",
],
deps = [
- "//compiler/src/iree/compiler/Dialect/Flow/Transforms",
+ "//compiler/src/iree/compiler/Dialect/Util/IR",
+ "//compiler/src/iree/compiler/DispatchCreation",
"//compiler/src/iree/compiler/GlobalOptimization",
"//compiler/src/iree/compiler/Pipelines:Options",
"//compiler/src/iree/compiler/PluginAPI",
diff --git a/compiler/src/iree/compiler/Preprocessing/CMakeLists.txt b/compiler/src/iree/compiler/Preprocessing/CMakeLists.txt
index 9a70837..b32fc8f 100644
--- a/compiler/src/iree/compiler/Preprocessing/CMakeLists.txt
+++ b/compiler/src/iree/compiler/Preprocessing/CMakeLists.txt
@@ -23,7 +23,8 @@
MLIRLinalgTransforms
MLIRPass
MLIRTransforms
- iree::compiler::Dialect::Flow::Transforms
+ iree::compiler::Dialect::Util::IR
+ iree::compiler::DispatchCreation
iree::compiler::GlobalOptimization
iree::compiler::Pipelines::Options
iree::compiler::PluginAPI
diff --git a/compiler/src/iree/compiler/Preprocessing/Passes.cpp b/compiler/src/iree/compiler/Preprocessing/Passes.cpp
index c149342..af2c5bc 100644
--- a/compiler/src/iree/compiler/Preprocessing/Passes.cpp
+++ b/compiler/src/iree/compiler/Preprocessing/Passes.cpp
@@ -5,7 +5,8 @@
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
#include "iree/compiler/Preprocessing/Passes.h"
-#include "iree/compiler/Dialect/Flow/Transforms/Passes.h"
+#include "iree/compiler/Dialect/Util/IR/UtilOps.h"
+#include "iree/compiler/DispatchCreation/Passes.h"
#include "iree/compiler/GlobalOptimization/Passes.h"
#include "iree/compiler/Preprocessing/Common/Passes.h"
#include "iree/compiler/Utils/PassUtils.h"
@@ -103,7 +104,7 @@
.addPass(mlir::createLinalgNamedOpConversionPass)
.addPass(GlobalOptimization::createConvert1X1FilterConv2DToMatmulPass)
.addPass(createConvertConvToChannelsLastPass);
- passManager.addPass(IREE::Flow::createFoldUnitExtentDimsPass());
+ passManager.addPass(DispatchCreation::createFoldUnitExtentDimsPass());
passManager.addPass(createCanonicalizerPass());
passManager.addPass(createCSEPass());
}
diff --git a/compiler/src/iree/compiler/Tools/BUILD.bazel b/compiler/src/iree/compiler/Tools/BUILD.bazel
index 94d6554..ad3bf87 100644
--- a/compiler/src/iree/compiler/Tools/BUILD.bazel
+++ b/compiler/src/iree/compiler/Tools/BUILD.bazel
@@ -57,6 +57,7 @@
"//compiler/src/iree/compiler/Dialect/VM/Transforms",
"//compiler/src/iree/compiler/Dialect/VMVX/IR:VMVXDialect",
"//compiler/src/iree/compiler/Dialect/VMVX/Transforms",
+ "//compiler/src/iree/compiler/DispatchCreation",
"//compiler/src/iree/compiler/ExternalInterfaces:ExternalModels",
"//compiler/src/iree/compiler/GlobalOptimization/Interfaces",
"//compiler/src/iree/compiler/InputConversion/Common",
diff --git a/compiler/src/iree/compiler/Tools/init_iree_passes.h b/compiler/src/iree/compiler/Tools/init_iree_passes.h
index 69f519e..1bd8b3d 100644
--- a/compiler/src/iree/compiler/Tools/init_iree_passes.h
+++ b/compiler/src/iree/compiler/Tools/init_iree_passes.h
@@ -25,6 +25,7 @@
#include "iree/compiler/Dialect/VM/Analysis/TestPasses.h"
#include "iree/compiler/Dialect/VM/Transforms/Passes.h"
#include "iree/compiler/Dialect/VMVX/Transforms/Passes.h"
+#include "iree/compiler/DispatchCreation/Passes.h"
#include "iree/compiler/GlobalOptimization/Passes.h"
#include "iree/compiler/InputConversion/Common/Passes.h"
#include "iree/compiler/Modules/HAL/Inline/Transforms/Passes.h"
@@ -51,7 +52,9 @@
InputConversion::registerCommonInputConversionPasses();
ConstEval::registerConstEvalPasses();
GlobalOptimization::registerGlobalOptimizationPipeline();
+ DispatchCreation::registerDispatchCreationPipelines();
Preprocessing::registerPreprocessingPasses();
+ DispatchCreation::registerDispatchCreationPasses();
IREE::Flow::registerFlowPasses();
IREE::HAL::registerHALPasses();
IREE::HAL::Inline::registerHALInlinePasses();
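
Both registration entry points matter for tools that parse textual pipelines: `registerDispatchCreationPasses` exposes the individual `iree-dispatch-creation-*` passes, while `registerDispatchCreationPipelines` exposes the named pipelines such as `iree-dispatch-creation-pipeline`. A hypothetical downstream `iree-opt`-style driver would call both before handing control to MLIR, sketched here with the standard `MlirOptMain` entry point (dialect registration elided):

```cpp
#include "iree/compiler/DispatchCreation/Passes.h"
#include "mlir/IR/DialectRegistry.h"
#include "mlir/Tools/mlir-opt/MlirOptMain.h"

int main(int argc, char **argv) {
  // Without these, textual names like iree-dispatch-creation-pipeline
  // fail to resolve in --pass-pipeline strings.
  mlir::iree_compiler::DispatchCreation::registerDispatchCreationPasses();
  mlir::iree_compiler::DispatchCreation::registerDispatchCreationPipelines();
  mlir::DialectRegistry registry; // dialects would be registered here
  return mlir::asMainReturnCode(
      mlir::MlirOptMain(argc, argv, "demo-opt", registry));
}
```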
diff --git a/experimental/regression_suite/shark-test-suite-models/sd3/test_clip.py b/experimental/regression_suite/shark-test-suite-models/sd3/test_clip.py
index 997dc3f..6136876 100644
--- a/experimental/regression_suite/shark-test-suite-models/sd3/test_clip.py
+++ b/experimental/regression_suite/shark-test-suite-models/sd3/test_clip.py
@@ -110,8 +110,8 @@
"--iree-opt-outer-dim-concat=true",
"--iree-hip-waves-per-eu=2",
"--iree-llvmgpu-enable-prefetch",
- "--iree-flow-enable-aggressive-fusion",
- "--iree-flow-enable-fuse-horizontal-contractions=true",
+ "--iree-dispatch-creation-enable-aggressive-fusion",
+ "--iree-dispatch-creation-enable-fuse-horizontal-contractions=true",
"--iree-opt-aggressively-propagate-transposes=true",
"--iree-codegen-llvmgpu-use-vector-distribution=true",
"--iree-execution-model=async-external",
diff --git a/experimental/regression_suite/shark-test-suite-models/sd3/test_mmdit.py b/experimental/regression_suite/shark-test-suite-models/sd3/test_mmdit.py
index 5fc96c7..8486d9b 100644
--- a/experimental/regression_suite/shark-test-suite-models/sd3/test_mmdit.py
+++ b/experimental/regression_suite/shark-test-suite-models/sd3/test_mmdit.py
@@ -87,8 +87,8 @@
"--iree-opt-const-eval=false",
f"--iree-codegen-transform-dialect-library={iree_test_path_extension}/attention_and_matmul_spec.mlir",
"--iree-global-opt-propagate-transposes=true",
- "--iree-flow-enable-fuse-horizontal-contractions=true",
- "--iree-flow-enable-aggressive-fusion=true",
+ "--iree-dispatch-creation-enable-fuse-horizontal-contractions=true",
+ "--iree-dispatch-creation-enable-aggressive-fusion=true",
"--iree-opt-aggressively-propagate-transposes=true",
"--iree-opt-outer-dim-concat=true",
"--iree-vm-target-truncate-unsupported-floats",
diff --git a/experimental/regression_suite/shark-test-suite-models/sd3/test_vae.py b/experimental/regression_suite/shark-test-suite-models/sd3/test_vae.py
index 8037ef0..b41566c 100644
--- a/experimental/regression_suite/shark-test-suite-models/sd3/test_vae.py
+++ b/experimental/regression_suite/shark-test-suite-models/sd3/test_vae.py
@@ -67,7 +67,7 @@
"--iree-opt-outer-dim-concat=true",
"--iree-llvmgpu-enable-prefetch=true",
"--iree-hip-waves-per-eu=2",
- "--iree-flow-enable-aggressive-fusion=true",
+ "--iree-dispatch-creation-enable-aggressive-fusion=true",
"--iree-codegen-llvmgpu-use-vector-distribution=true",
"--iree-execution-model=async-external",
"--iree-preprocessing-pass-pipeline=builtin.module(iree-preprocessing-transpose-convolution-pipeline,iree-preprocessing-pad-to-intrinsics)",
diff --git a/experimental/regression_suite/shark-test-suite-models/sdxl/test_clip.py b/experimental/regression_suite/shark-test-suite-models/sdxl/test_clip.py
index 1eb686d..4e5c48a 100644
--- a/experimental/regression_suite/shark-test-suite-models/sdxl/test_clip.py
+++ b/experimental/regression_suite/shark-test-suite-models/sdxl/test_clip.py
@@ -96,8 +96,8 @@
"--iree-opt-outer-dim-concat=true",
"--iree-hip-waves-per-eu=2",
"--iree-llvmgpu-enable-prefetch",
- "--iree-flow-enable-aggressive-fusion",
- "--iree-flow-enable-fuse-horizontal-contractions=true",
+ "--iree-dispatch-creation-enable-aggressive-fusion",
+ "--iree-dispatch-creation-enable-fuse-horizontal-contractions=true",
"--iree-opt-aggressively-propagate-transposes=true",
"--iree-codegen-llvmgpu-use-vector-distribution=true",
"--iree-execution-model=async-external",
diff --git a/experimental/regression_suite/shark-test-suite-models/sdxl/test_unet.py b/experimental/regression_suite/shark-test-suite-models/sdxl/test_unet.py
index 0a5af53..5602251 100644
--- a/experimental/regression_suite/shark-test-suite-models/sdxl/test_unet.py
+++ b/experimental/regression_suite/shark-test-suite-models/sdxl/test_unet.py
@@ -93,8 +93,8 @@
"--iree-opt-const-eval=false",
f"--iree-codegen-transform-dialect-library={iree_test_path_extension}/attention_and_matmul_spec.mlir",
"--iree-global-opt-propagate-transposes=true",
- "--iree-flow-enable-fuse-horizontal-contractions=true",
- "--iree-flow-enable-aggressive-fusion=true",
+ "--iree-dispatch-creation-enable-fuse-horizontal-contractions=true",
+ "--iree-dispatch-creation-enable-aggressive-fusion=true",
"--iree-opt-aggressively-propagate-transposes=true",
"--iree-opt-outer-dim-concat=true",
"--iree-vm-target-truncate-unsupported-floats",
diff --git a/experimental/regression_suite/shark-test-suite-models/sdxl/test_vae.py b/experimental/regression_suite/shark-test-suite-models/sdxl/test_vae.py
index 515dd99..1267f62 100644
--- a/experimental/regression_suite/shark-test-suite-models/sdxl/test_vae.py
+++ b/experimental/regression_suite/shark-test-suite-models/sdxl/test_vae.py
@@ -67,7 +67,7 @@
"--iree-opt-outer-dim-concat=true",
"--iree-llvmgpu-enable-prefetch=true",
"--iree-hip-waves-per-eu=2",
- "--iree-flow-enable-aggressive-fusion=true",
+ "--iree-dispatch-creation-enable-aggressive-fusion=true",
"--iree-codegen-llvmgpu-use-vector-distribution=true",
"--iree-execution-model=async-external",
"--iree-preprocessing-pass-pipeline=builtin.module(iree-preprocessing-transpose-convolution-pipeline,iree-preprocessing-pad-to-intrinsics)",
diff --git a/tests/e2e/linalg_ext_ops/BUILD.bazel b/tests/e2e/linalg_ext_ops/BUILD.bazel
index f655d9b..80af04d 100644
--- a/tests/e2e/linalg_ext_ops/BUILD.bazel
+++ b/tests/e2e/linalg_ext_ops/BUILD.bazel
@@ -148,7 +148,7 @@
srcs = [
"top-k.mlir",
],
- compiler_flags = ["--iree-flow-topk-split-reduction=2"],
+ compiler_flags = ["--iree-dispatch-creation-topk-split-reduction=2"],
driver = "cuda",
tags = [
# CUDA cuInit fails with sanitizer on.
@@ -166,7 +166,7 @@
srcs = [
"top-k.mlir",
],
- compiler_flags = ["--iree-flow-topk-split-reduction=3,2"],
+ compiler_flags = ["--iree-dispatch-creation-topk-split-reduction=3,2"],
driver = "cuda",
tags = [
# CUDA cuInit fails with sanitizer on.
diff --git a/tests/e2e/linalg_ext_ops/CMakeLists.txt b/tests/e2e/linalg_ext_ops/CMakeLists.txt
index 97dc732..9865284 100644
--- a/tests/e2e/linalg_ext_ops/CMakeLists.txt
+++ b/tests/e2e/linalg_ext_ops/CMakeLists.txt
@@ -123,7 +123,7 @@
DRIVER
"cuda"
COMPILER_FLAGS
- "--iree-flow-topk-split-reduction=2"
+ "--iree-dispatch-creation-topk-split-reduction=2"
LABELS
"noasan"
"nomsan"
@@ -142,7 +142,7 @@
DRIVER
"cuda"
COMPILER_FLAGS
- "--iree-flow-topk-split-reduction=3,2"
+ "--iree-dispatch-creation-topk-split-reduction=3,2"
LABELS
"noasan"
"nomsan"
diff --git a/tests/e2e/matmul/BUILD.bazel b/tests/e2e/matmul/BUILD.bazel
index d73f468..e712f9f 100644
--- a/tests/e2e/matmul/BUILD.bazel
+++ b/tests/e2e/matmul/BUILD.bazel
@@ -592,7 +592,7 @@
[iree_generated_e2e_runner_test(
name = "e2e_matmul_cuda_%s_large_splitk" % lhs_rhs_type,
compiler_flags = [
- "--iree-flow-split-matmul-reduction=4",
+ "--iree-dispatch-creation-split-matmul-reduction=4",
],
generator = ":generate_e2e_matmul_tests",
generator_args = [
diff --git a/tests/e2e/matmul/CMakeLists.txt b/tests/e2e/matmul/CMakeLists.txt
index 5651d17..f117ba1 100644
--- a/tests/e2e/matmul/CMakeLists.txt
+++ b/tests/e2e/matmul/CMakeLists.txt
@@ -2070,7 +2070,7 @@
DRIVERS
"cuda"
COMPILER_FLAGS
- "--iree-flow-split-matmul-reduction=4"
+ "--iree-dispatch-creation-split-matmul-reduction=4"
LABELS
"noasan"
"nomsan"
diff --git a/tests/e2e/regression/BUILD.bazel b/tests/e2e/regression/BUILD.bazel
index 4879530..87fdccb 100644
--- a/tests/e2e/regression/BUILD.bazel
+++ b/tests/e2e/regression/BUILD.bazel
@@ -141,7 +141,7 @@
"softmax_large.mlir",
],
compiler_flags = [
- "--iree-flow-fuse-multi-use",
+ "--iree-dispatch-creation-fuse-multi-use",
],
driver = "cuda",
tags = [
@@ -171,7 +171,7 @@
"softmax.mlir",
],
compiler_flags = [
- "--iree-flow-fuse-multi-use",
+ "--iree-dispatch-creation-fuse-multi-use",
],
driver = "local-task",
target_backend = "llvm-cpu",
diff --git a/tests/e2e/regression/CMakeLists.txt b/tests/e2e/regression/CMakeLists.txt
index 9334955..06597e6 100644
--- a/tests/e2e/regression/CMakeLists.txt
+++ b/tests/e2e/regression/CMakeLists.txt
@@ -190,7 +190,7 @@
DRIVER
"cuda"
COMPILER_FLAGS
- "--iree-flow-fuse-multi-use"
+ "--iree-dispatch-creation-fuse-multi-use"
LABELS
"noasan"
"nomsan"
@@ -222,7 +222,7 @@
DRIVER
"local-task"
COMPILER_FLAGS
- "--iree-flow-fuse-multi-use"
+ "--iree-dispatch-creation-fuse-multi-use"
)
iree_check_single_backend_test_suite(
diff --git a/tools/test/compile_pipelines.mlir b/tools/test/compile_pipelines.mlir
index fb6dbbe..8546056 100644
--- a/tools/test/compile_pipelines.mlir
+++ b/tools/test/compile_pipelines.mlir
@@ -2,6 +2,7 @@
// RUN: iree-opt --iree-abi-transformation-pipeline - | \
// RUN: iree-opt --pass-pipeline='builtin.module(iree-hal-device-assignment-pipeline{target-devices=local})' --iree-hal-local-target-device-backends=vmvx - | \
// RUN: iree-opt --iree-global-optimization-transformation-pipeline - | \
+// RUN: iree-opt --iree-dispatch-creation-pipeline - | \
// RUN: iree-opt --iree-flow-transformation-pipeline - | \
// RUN: iree-opt --iree-stream-transformation-pipeline - | \
// RUN: iree-opt --iree-hal-transformation-pipeline - | \