Add new aggressive fusion heuristics to DispatchLinalgOnTensorsPass (#10472)
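
This replaces the --iree-flow-enable-multi-result-dispatches flag with an
aggressive-fusion option on the pass (exposed through
createDispatchLinalgOnTensorsPass and driven by clEnableAggressiveFusion in the
default flow pipeline). When enabled, dispatch region formation fuses an op with
the use that dominates all its other uses (allowing multi-result dispatches),
fuses linalg.generic producer/consumer pairs whose outer parallel loops are
compatible, allows fusion along the ins operands of generic ops, and walks
producers with a worklist instead of only looking at a root's immediate
operands. A new test, dispatch_linalg_on_tensors_aggressive_fusion.mlir,
exercises the heuristics on softmax- and batchnorm-style programs.

A sketch of how the option is exercised, mirroring the updated RUN lines in the
tests below (the standalone input file name here is only an example):

  iree-opt --split-input-file \
    --pass-pipeline="func.func(iree-flow-dispatch-linalg-on-tensors-pass{aggressive-fusion=true})" \
    dispatch_linalg_on_tensors_aggressive_fusion.mlir
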
diff --git a/compiler/src/iree/compiler/Dialect/Flow/Transforms/DispatchLinalgOnTensors.cpp b/compiler/src/iree/compiler/Dialect/Flow/Transforms/DispatchLinalgOnTensors.cpp
index 4a8f212..2ab384c 100644
--- a/compiler/src/iree/compiler/Dialect/Flow/Transforms/DispatchLinalgOnTensors.cpp
+++ b/compiler/src/iree/compiler/Dialect/Flow/Transforms/DispatchLinalgOnTensors.cpp
@@ -31,6 +31,7 @@
#include "mlir/Dialect/MemRef/Transforms/Passes.h"
#include "mlir/Dialect/SCF/IR/SCF.h"
#include "mlir/Dialect/Tensor/IR/Tensor.h"
+#include "mlir/Dialect/Utils/StructuredOpsUtils.h"
#include "mlir/IR/Block.h"
#include "mlir/IR/BlockAndValueMapping.h"
#include "mlir/IR/Builders.h"
@@ -60,12 +61,6 @@
"dispatch region"),
llvm::cl::init(256));
-static llvm::cl::opt<bool> clEnableMultiResultDispatches(
- "iree-flow-enable-multi-result-dispatches",
- llvm::cl::desc(
- "Enable dispatch region formation to enable multi-result dispatches"),
- llvm::cl::init(false));
-
static const char kRootOpAttr[] = "__root_op__";
static const char kFusionGroupsAttr[] = "__fused_op__";
@@ -381,22 +376,22 @@
int64_t groupNum = getRootNumber(rootOp);
std::deque<Operation *> worklist;
worklist.push_back(rootOp);
- llvm::SmallDenseSet<Operation *, 2> movedOps;
- movedOps.insert(rootOp);
+ llvm::SmallDenseSet<Operation *, 2> visitedOps;
+ visitedOps.insert(rootOp);
while (!worklist.empty()) {
Operation *currRoot = worklist.front();
worklist.pop_front();
for (auto operand : currRoot->getOperands()) {
auto producer = operand.getDefiningOp();
- if (movedOps.count(producer)) continue;
- if (!producer || !isInFusionGroup(producer, groupNum)) continue;
- movedOps.insert(producer);
+ if (!producer || visitedOps.count(producer)) continue;
+ visitedOps.insert(producer);
+ if (!isInFusionGroup(producer, groupNum)) continue;
worklist.push_back(producer);
dispatchOps.push_back(producer);
}
}
- return dispatchOps;
+ return llvm::to_vector(llvm::reverse(orderOperations(dispatchOps)));
}
//===---------------------------------------------------------------------===//
@@ -721,27 +716,91 @@
// Heuristics for fusing dispatchable ops with root ops using tile + fuse.
//===----------------------------------------------------------------------===//
-/// Checks if the producer and consumer LinalgOps can be fused.
-static bool areFusableLinalgOps(OpOperand &use) {
- return areLinalgOpsFusableUsingTileAndFuse(use);
+/// Returns a bit vector with one bit per loop of the `interfaceOp`, with the
+/// bits corresponding to outer parallel loops set to `true`.
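+/// For example, for iterator types ["parallel", "parallel", "reduction",
+/// "parallel"] only bits 0 and 1 are set.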
+static llvm::SmallBitVector getOuterParallelLoops(TilingInterface interfaceOp) {
+ SmallVector<StringRef> loopIteratorTypes = interfaceOp.getLoopIteratorTypes();
+ llvm::SmallBitVector parallelLoops(loopIteratorTypes.size());
+ for (auto iteratorType : llvm::enumerate(loopIteratorTypes)) {
+ if (iteratorType.value() != getParallelIteratorTypeName()) break;
+ parallelLoops.set(iteratorType.index());
+ }
+ return parallelLoops;
}
-/// Returns true if this is a fusable use.
-static bool isFusableWithConsumer(OpOperand &use) {
- // Check for linalg producer -> consumer fusion with tile + fuse.
- return areFusableLinalgOps(use);
+/// Returns true if `map` is an identity map with zeros, i.e. if you
+/// drop the result exprs that are constant zeros, the `map` will become an
+/// identity.
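+/// For example, (d0, d1) -> (d0, 0, d1) is an identity map with zeros, while
+/// (d0, d1) -> (d1, 0, d0) and (d0, d1) -> (d0, 0) are not.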
+static bool isIdentityMapWithZeros(AffineMap map) {
+ if (map.getNumSymbols() != 0) return false;
+ unsigned dimsSeen = 0;
+ for (auto result : map.getResults()) {
+ bool isValidExpr = TypeSwitch<AffineExpr, bool>(result)
+ .Case<AffineDimExpr>([&dimsSeen](auto dimExpr) {
+ if (dimExpr.getPosition() != dimsSeen)
+ return false;
+ dimsSeen++;
+ return true;
+ })
+ .Case<AffineConstantExpr>([](auto constExpr) {
+ return constExpr.getValue() == 0;
+ })
+ .Default([](AffineExpr) { return false; });
+ if (!isValidExpr) return false;
+ }
+ return dimsSeen == map.getNumDims();
+}
+
+/// Method to check if two `linalg.generic` ops with a producer-consumer
+/// relationship through `operand` have compatible outer parallel loops.
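+/// If `allowConsumerParallelismPessimization` is true, the producer may have
+/// fewer outer parallel loops than the consumer, i.e. fusing may reduce the
+/// parallelism the consumer could otherwise be distributed over.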
+static bool hasCompatibleOuterParallelLoops(
+ OpOperand &operand, bool allowConsumerParallelismPessimization) {
+ auto producer = operand.get().getDefiningOp<linalg::LinalgOp>();
+ auto consumer = dyn_cast<linalg::LinalgOp>(operand.getOwner());
+ if (!producer || !consumer) return false;
+
+ llvm::SmallBitVector producerParallelLoops =
+ getOuterParallelLoops(cast<TilingInterface>(producer.getOperation()));
+ llvm::SmallBitVector consumerParallelLoops =
+ getOuterParallelLoops(cast<TilingInterface>(consumer.getOperation()));
+
+ if (allowConsumerParallelismPessimization) {
+ if (producerParallelLoops.count() > consumerParallelLoops.count())
+ return false;
+ } else if (producerParallelLoops.count() != consumerParallelLoops.count()) {
+ return false;
+ }
+
+ auto producerIndexingMap =
+ producer.getTiedIndexingMapForResult(operand.get().cast<OpResult>());
+ auto consumerIndexingMap = consumer.getTiedIndexingMap(&operand);
+ if (!producerIndexingMap.isProjectedPermutation() ||
+ !consumerIndexingMap.isProjectedPermutation()) {
+ return false;
+ }
+
+ /// Project out the non-parallel dimensions.
+ llvm::SmallBitVector producerProjectedDims(producerParallelLoops);
+ producerProjectedDims.flip();
+ auto projectedProducerMap =
+ getProjectedMap(producerIndexingMap, producerProjectedDims);
+
+ llvm::SmallBitVector consumerProjectedDims(producerParallelLoops);
+ consumerProjectedDims.flip();
+ consumerProjectedDims.resize(consumer.getNumLoops(), true);
+ auto projectedConsumerMap =
+ getProjectedMap(consumerIndexingMap, consumerProjectedDims);
+
+ return isIdentityMapWithZeros(projectedProducerMap) &&
+ isIdentityMapWithZeros(projectedConsumerMap);
}
/// For all uses of an operation, finds the use that dominates all other uses.
static Optional<OpOperand *> getFusableUse(Operation *op,
- DominanceInfo const &dominanceInfo) {
- if (!clEnableMultiResultDispatches) {
- if (op->hasOneUse()) {
- OpOperand &use = *(op->use_begin());
- return &use;
- }
- return llvm::None;
- }
+ DominanceInfo const &dominanceInfo,
+ bool fuseMultiUse) {
+ if (!fuseMultiUse && !op->hasOneUse()) return llvm::None;
+
for (auto &use : op->getUses()) {
Operation *user = use.getOwner();
if (llvm::all_of(op->getUsers(), [&](Operation *c) {
@@ -753,11 +812,65 @@
return llvm::None;
}
+/// Returns true if all uses of `producer` in `consumer` are fusable under the
+/// aggressive fusion heuristics.
+static bool areOpsAggresiveFusable(Operation *producer, Operation *consumer,
+ bool allowConsumerParallelismPessimization) {
+ // Collect all the uses from producer to consumer.
+ SmallVector<OpOperand *> allUses;
+ for (OpOperand &producerUse : producer->getUses()) {
+ if (producerUse.getOwner() != consumer) continue;
+ allUses.push_back(&producerUse);
+ }
+
+ // Check that the consumer and producer have compatible outer parallel loops.
+ if (!llvm::all_of(allUses, [&](OpOperand *operand) {
+ return hasCompatibleOuterParallelLoops(
+ *operand, allowConsumerParallelismPessimization);
+ })) {
+ return false;
+ }
+
+ // Finally only fuse if the `ins` operand can be properly bufferized.
+ // TODO(#10498): Handle the multi-result case.
+ return llvm::all_of(allUses, [](OpOperand *operand) {
+ return isInsOperandBufferizable(operand, /*aggressiveFusion=*/true);
+ });
+}
+
+/// Returns true if this is a fusable use, while fusing a root with its
+/// consumer.
+static bool isFusableWithConsumer(OpOperand &fusedOperand,
+ bool aggressiveFusion) {
+ // Use the original fusion heuristics if aggressive fusion isn't enabled.
+ if (!aggressiveFusion)
+ return areLinalgOpsFusableUsingTileAndFuse(fusedOperand);
+
+  // Logic for the aggressive fusion heuristics.
+ Operation *producer = fusedOperand.get().getDefiningOp();
+ Operation *consumer = fusedOperand.getOwner();
+
+ if (!isa<linalg::LinalgOp>(producer) || !isa<linalg::LinalgOp>(consumer))
+ return false;
+
+ auto consumerLinalgOp = cast<linalg::LinalgOp>(consumer);
+
+ // Check that the consumer is all parallel.
+ if (consumerLinalgOp.getNumLoops() !=
+ consumerLinalgOp.getNumParallelLoops()) {
+ return false;
+ }
+
+ return areOpsAggresiveFusable(producer, consumer,
+ /*allowConsumerParallelismPessimization=*/true);
+}
+
/// Fuses roots with their consumers. If a root is fused with its consumer, it
/// is no longer tagged as a root, to aid with dispatch region formation.
static void fuseRootsWithConsumers(MLIRContext *context,
ArrayRef<Operation *> roots,
- DominanceInfo const &dominanceInfo) {
+ DominanceInfo const &dominanceInfo,
+ bool aggressiveFusion) {
SmallVector<Operation *> workList(roots.begin(), roots.end());
// Fuse with consumers where possible.
while (!workList.empty()) {
@@ -774,7 +887,8 @@
appendToFusionGroup(currRoot, rootNumber);
};
- Optional<OpOperand *> fusableUse = getFusableUse(currRoot, dominanceInfo);
+ Optional<OpOperand *> fusableUse = getFusableUse(
+ currRoot, dominanceInfo, /*fuseMultiUse=*/aggressiveFusion);
if (!fusableUse) continue;
// Analyse the use to see if it is fusable.
@@ -784,7 +898,7 @@
continue;
}
- if (isFusableWithConsumer(*(fusableUse.value()))) {
+ if (isFusableWithConsumer(*(fusableUse.value()), aggressiveFusion)) {
updateRootTo(consumerOp);
workList.push_back(consumerOp);
}
@@ -792,19 +906,29 @@
}
/// Method to check if the consumer of a use can be fused with its producer.
-static bool isFusableWithProducer(OpOperand &operand) {
+static bool isFusableWithProducer(OpOperand &operand, bool aggressiveFusion) {
Operation *producer = operand.get().getDefiningOp();
Operation *consumer = operand.getOwner();
- if (isa<linalg::LinalgOp>(consumer) && isa<linalg::LinalgOp>(producer)) {
- auto consumerLinalgOp = cast<linalg::LinalgOp>(consumer);
- auto producerLinalgOp = cast<linalg::LinalgOp>(producer);
- if (consumerLinalgOp.isOutputTensor(&operand) &&
- producerLinalgOp.getNumLoops() ==
- producerLinalgOp.getNumParallelLoops()) {
- return true;
- }
+ if (!isa<linalg::LinalgOp>(consumer) || !isa<linalg::LinalgOp>(producer))
+ return false;
+
+ auto consumerLinalgOp = cast<linalg::LinalgOp>(consumer);
+ auto producerLinalgOp = cast<linalg::LinalgOp>(producer);
+ if (consumerLinalgOp.isOutputTensor(&operand) &&
+ producerLinalgOp.getNumLoops() ==
+ producerLinalgOp.getNumParallelLoops()) {
+ return true;
}
+
+ // Only fuse on inputs if both are generic ops.
+ if (aggressiveFusion && consumerLinalgOp.isInputTensor(&operand) &&
+ isa<linalg::GenericOp>(consumer) && isa<linalg::GenericOp>(producer)) {
+ return areOpsAggresiveFusable(
+ producer, consumer,
+ /*allowConsumerParallelismPessimization=*/false);
+ }
+
return false;
}
@@ -812,21 +936,28 @@
/// in reverse to fuse with producers.
static void fuseRootsWithProducers(MLIRContext *context, Operation *root,
unsigned groupNum,
- DominanceInfo const &dominanceInfo) {
- // We probably want a worklist algorithm here, but for now just look at
- // immediate producers.
- for (OpOperand &operand : root->getOpOperands()) {
- Operation *producer = operand.get().getDefiningOp();
- if (!producer) continue;
- if (hasFusionGroupsAttribute(producer) || hasRootOpAttribute(producer)) {
- continue;
- }
+ DominanceInfo const &dominanceInfo,
+ bool aggressiveFusion) {
+ SmallVector<Operation *> worklist;
+ worklist.push_back(root);
- Optional<OpOperand *> fusableUse = getFusableUse(producer, dominanceInfo);
- if (!fusableUse || fusableUse.value()->getOwner() != root) continue;
+ while (!worklist.empty()) {
+ Operation *candidate = worklist.pop_back_val();
+ for (OpOperand &operand : candidate->getOpOperands()) {
+ Operation *producer = operand.get().getDefiningOp();
+ if (!producer) continue;
+ if (hasFusionGroupsAttribute(producer) || hasRootOpAttribute(producer)) {
+ continue;
+ }
- if (isFusableWithProducer(operand)) {
+ Optional<OpOperand *> fusableUse = getFusableUse(
+ producer, dominanceInfo, /*fuseMultiUse=*/aggressiveFusion);
+ if (!fusableUse || fusableUse.value()->getOwner() != candidate) continue;
+
+ if (!isFusableWithProducer(operand, aggressiveFusion)) continue;
+
appendToFusionGroup(producer, groupNum);
+ worklist.push_back(producer);
}
}
}
@@ -840,7 +971,8 @@
/// very simple heuristic is used below, but the mechanism should be general
/// enough to capture any heuristic.
static unsigned decideFusableLinalgOps(FunctionOpInterface funcOp,
- DominanceInfo const &dominanceInfo) {
+ DominanceInfo const &dominanceInfo,
+ bool aggressiveFusion) {
unsigned numRootOps = 0;
MLIRContext *context = funcOp->getContext();
OpBuilder builder(context);
@@ -857,11 +989,12 @@
unsigned newGroup = numRootOps++;
setRootAttribute(context, &op, newGroup);
- fuseRootsWithProducers(context, &op, newGroup, dominanceInfo);
+ fuseRootsWithProducers(context, &op, newGroup, dominanceInfo,
+ aggressiveFusion);
roots.push_back(&op);
}
roots = llvm::to_vector(llvm::reverse(roots));
- fuseRootsWithConsumers(context, roots, dominanceInfo);
+ fuseRootsWithConsumers(context, roots, dominanceInfo, aggressiveFusion);
}
// Once all root linalg ops have been tagged, put all remaining generic ops
@@ -881,7 +1014,7 @@
roots.push_back(&op);
}
roots = llvm::to_vector(llvm::reverse(roots));
- fuseRootsWithConsumers(context, roots, dominanceInfo);
+ fuseRootsWithConsumers(context, roots, dominanceInfo, aggressiveFusion);
}
return numRootOps;
@@ -896,8 +1029,11 @@
.insert<AffineDialect, IREE::Flow::FlowDialect, linalg::LinalgDialect,
scf::SCFDialect, tensor::TensorDialect>();
}
- DispatchLinalgOnTensorsPass() = default;
- DispatchLinalgOnTensorsPass(const DispatchLinalgOnTensorsPass &pass) {}
+ DispatchLinalgOnTensorsPass(bool aggressiveFusion) {
+ this->aggressiveFusion = aggressiveFusion;
+ }
+ DispatchLinalgOnTensorsPass(const DispatchLinalgOnTensorsPass &pass)
+ : DispatchLinalgOnTensorsPass(pass.aggressiveFusion) {}
void runOnOperation() override;
private:
@@ -980,7 +1116,7 @@
auto funcOp = getOperation();
MLIRContext *context = &getContext();
DominanceInfo const &dominanceInfo = getAnalysis<DominanceInfo>();
- decideFusableLinalgOps(funcOp, dominanceInfo);
+ decideFusableLinalgOps(funcOp, dominanceInfo, aggressiveFusion);
LLVM_DEBUG({
llvm::dbgs() << "\n--- After annotating linalg op fusion scheme ---\n";
@@ -1063,8 +1199,8 @@
}
std::unique_ptr<InterfacePass<mlir::FunctionOpInterface>>
-createDispatchLinalgOnTensorsPass() {
- return std::make_unique<DispatchLinalgOnTensorsPass>();
+createDispatchLinalgOnTensorsPass(bool aggressiveFusion) {
+ return std::make_unique<DispatchLinalgOnTensorsPass>(aggressiveFusion);
}
} // namespace Flow
diff --git a/compiler/src/iree/compiler/Dialect/Flow/Transforms/FusionUtils.cpp b/compiler/src/iree/compiler/Dialect/Flow/Transforms/FusionUtils.cpp
index 616b964..a888110 100644
--- a/compiler/src/iree/compiler/Dialect/Flow/Transforms/FusionUtils.cpp
+++ b/compiler/src/iree/compiler/Dialect/Flow/Transforms/FusionUtils.cpp
@@ -38,16 +38,25 @@
/// indexing map.
// TODO: This restriction can go away if we can vectorize always, but that has
// a long tail of tasks.
-static bool canInsOperandTieWithOutsOperand(OpOperand *insOperand) {
+bool isInsOperandBufferizable(OpOperand *insOperand, bool aggressiveFusion) {
+ // Ignore the check if in-place bufferization is not required.
+ if (!clEnsureInplaceableConsumer) return true;
+
auto linalgOp = dyn_cast<linalg::LinalgOp>(insOperand->getOwner());
if (!linalgOp) return false;
AffineMap insOperandIndexingMap = linalgOp.getTiedIndexingMap(insOperand);
auto canTieWithOutsOperand = [&](OpOperand *outsOperand) {
- if (linalgOp.getTiedIndexingMap(outsOperand) != insOperandIndexingMap) {
- return false;
+ AffineMap outsOperandIndexingMap = linalgOp.getTiedIndexingMap(outsOperand);
+
+ if (outsOperandIndexingMap != insOperandIndexingMap) {
+ if (!aggressiveFusion) return false;
+      // If the operand's indexing map is a projected permutation, a small
+      // stack allocation might be fine.
+ if (!insOperandIndexingMap.isProjectedPermutation()) return false;
}
+
// TODO(#8411): Until ops are vectorized (always), we need
// to check that the elementtype matches for the operands to be tied.
// For now just doing this check for convolution ops since we expect
@@ -219,7 +228,7 @@
// 4. In-place bufferization requirements (for now) require that the use in
// the consumer can re-use the buffer for a result.
- return !clEnsureInplaceableConsumer || canInsOperandTieWithOutsOperand(&use);
+ return isInsOperandBufferizable(&use, /*aggressiveFusion=*/false);
}
} // namespace Flow
diff --git a/compiler/src/iree/compiler/Dialect/Flow/Transforms/FusionUtils.h b/compiler/src/iree/compiler/Dialect/Flow/Transforms/FusionUtils.h
index 245ad75..2f1b5f4 100644
--- a/compiler/src/iree/compiler/Dialect/Flow/Transforms/FusionUtils.h
+++ b/compiler/src/iree/compiler/Dialect/Flow/Transforms/FusionUtils.h
@@ -18,6 +18,10 @@
namespace IREE {
namespace Flow {
+/// Returns true if the `ins` operand can be properly bufferized after the
+/// fusion.
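+/// With `aggressiveFusion` set, the check is relaxed: an `ins` operand whose
+/// indexing map is a projected permutation may be bufferizable even when it
+/// does not match the result indexing map (possibly at the cost of a small
+/// stack allocation).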
+bool isInsOperandBufferizable(OpOperand *insOperand, bool aggressiveFusion);
+
/// Returns true if the `use` is from a producer linalg op that can be fused
/// with the consumer linalg op using tile + fuse.
bool areLinalgOpsFusableUsingTileAndFuse(OpOperand &use);
diff --git a/compiler/src/iree/compiler/Dialect/Flow/Transforms/Passes.cpp b/compiler/src/iree/compiler/Dialect/Flow/Transforms/Passes.cpp
index a5017fd..11ffd5e 100644
--- a/compiler/src/iree/compiler/Dialect/Flow/Transforms/Passes.cpp
+++ b/compiler/src/iree/compiler/Dialect/Flow/Transforms/Passes.cpp
@@ -273,8 +273,11 @@
})
// Only want to use the transform dialect for some dispatch regions and let
// the DispatchLinalgOnTensorsPass handle the rest.
- .addPredicatedPass(!clDispatchViaRegionOps,
- createDispatchLinalgOnTensorsPass)
+ .addPredicatedPass(
+ !clDispatchViaRegionOps,
+ []() {
+ return createDispatchLinalgOnTensorsPass(clEnableAggressiveFusion);
+ })
// DispatchLinalgOnTensorsViaRegionsPass is a variant of
// DispatchLinalgOnTensorsPass that lowers via DispatchRegionOps. This is
// on an opt-in basis until the pass is stable enough to replace
diff --git a/compiler/src/iree/compiler/Dialect/Flow/Transforms/Passes.h b/compiler/src/iree/compiler/Dialect/Flow/Transforms/Passes.h
index adf7ab7..71a1226 100644
--- a/compiler/src/iree/compiler/Dialect/Flow/Transforms/Passes.h
+++ b/compiler/src/iree/compiler/Dialect/Flow/Transforms/Passes.h
@@ -143,7 +143,7 @@
// Pass to perform dispatch of Linalg on tensor ops by tiling and distribution.
// A dispatch region is created for each tiled loop nest.
std::unique_ptr<InterfacePass<mlir::FunctionOpInterface>>
-createDispatchLinalgOnTensorsPass();
+createDispatchLinalgOnTensorsPass(bool aggressiveFusion = false);
// Pass to perform dispatch of Linalg on tensor ops by tiling and distribution.
// A dispatch region is created for each tiled loop nest. (First create
diff --git a/compiler/src/iree/compiler/Dialect/Flow/Transforms/Passes.td b/compiler/src/iree/compiler/Dialect/Flow/Transforms/Passes.td
index d68df0d..3c4f6f8 100644
--- a/compiler/src/iree/compiler/Dialect/Flow/Transforms/Passes.td
+++ b/compiler/src/iree/compiler/Dialect/Flow/Transforms/Passes.td
@@ -66,6 +66,10 @@
InterfacePass<"iree-flow-dispatch-linalg-on-tensors-pass", "mlir::FunctionOpInterface"> {
let summary = "Dispatch Linalg operations on tensors by using tile and distribute";
let constructor = "mlir::iree_compiler::IREE::Flow::createDispatchLinalgOnTensorsPass()";
+ let options = [
+ Option<"aggressiveFusion", "aggressive-fusion", "bool",
+ /*default=*/"false", "Fuse with aggressive heuristics">,
+ ];
}
def DispatchLinalgOnTensorsViaRegionOps :
diff --git a/compiler/src/iree/compiler/Dialect/Flow/Transforms/test/BUILD b/compiler/src/iree/compiler/Dialect/Flow/Transforms/test/BUILD
index e01d0d7..abbe416 100644
--- a/compiler/src/iree/compiler/Dialect/Flow/Transforms/test/BUILD
+++ b/compiler/src/iree/compiler/Dialect/Flow/Transforms/test/BUILD
@@ -24,6 +24,7 @@
"deduplicate_executables.mlir",
"detach_elementwise_from_named_ops.mlir",
"dispatch_linalg_on_tensors.mlir",
+ "dispatch_linalg_on_tensors_aggressive_fusion.mlir",
"dispatch_linalg_on_tensors_fusion.mlir",
"dispatch_linalg_on_tensors_fusion_reduction_broadcast_elementwise.mlir",
"dispatch_linalg_on_tensors_fusion_with_transpose.mlir",
diff --git a/compiler/src/iree/compiler/Dialect/Flow/Transforms/test/CMakeLists.txt b/compiler/src/iree/compiler/Dialect/Flow/Transforms/test/CMakeLists.txt
index d033ea3..e5cce09 100644
--- a/compiler/src/iree/compiler/Dialect/Flow/Transforms/test/CMakeLists.txt
+++ b/compiler/src/iree/compiler/Dialect/Flow/Transforms/test/CMakeLists.txt
@@ -22,6 +22,7 @@
"deduplicate_executables.mlir"
"detach_elementwise_from_named_ops.mlir"
"dispatch_linalg_on_tensors.mlir"
+ "dispatch_linalg_on_tensors_aggressive_fusion.mlir"
"dispatch_linalg_on_tensors_fusion.mlir"
"dispatch_linalg_on_tensors_fusion_reduction_broadcast_elementwise.mlir"
"dispatch_linalg_on_tensors_fusion_with_transpose.mlir"
diff --git a/compiler/src/iree/compiler/Dialect/Flow/Transforms/test/dispatch_linalg_on_tensors.mlir b/compiler/src/iree/compiler/Dialect/Flow/Transforms/test/dispatch_linalg_on_tensors.mlir
index 43651b4..89cd23e 100644
--- a/compiler/src/iree/compiler/Dialect/Flow/Transforms/test/dispatch_linalg_on_tensors.mlir
+++ b/compiler/src/iree/compiler/Dialect/Flow/Transforms/test/dispatch_linalg_on_tensors.mlir
@@ -1,4 +1,4 @@
-// RUN: iree-opt --split-input-file --verify-diagnostics --iree-flow-enable-multi-result-dispatches --pass-pipeline="func.func(iree-flow-dispatch-linalg-on-tensors-pass), cse, canonicalize, cse" %s | FileCheck %s
+// RUN: iree-opt --split-input-file --verify-diagnostics --pass-pipeline="func.func(iree-flow-dispatch-linalg-on-tensors-pass{aggressive-fusion=true}), cse, canonicalize, cse" %s | FileCheck %s
func.func @tile_matmul_alone(%arg0 : tensor<?x?xf32>, %arg1 : tensor<?x?xf32>,
%arg2 : tensor<?x?xf32>) -> tensor<?x?xf32> {
diff --git a/compiler/src/iree/compiler/Dialect/Flow/Transforms/test/dispatch_linalg_on_tensors_aggressive_fusion.mlir b/compiler/src/iree/compiler/Dialect/Flow/Transforms/test/dispatch_linalg_on_tensors_aggressive_fusion.mlir
new file mode 100644
index 0000000..5ffa94c
--- /dev/null
+++ b/compiler/src/iree/compiler/Dialect/Flow/Transforms/test/dispatch_linalg_on_tensors_aggressive_fusion.mlir
@@ -0,0 +1,107 @@
+// RUN: iree-opt --split-input-file --pass-pipeline="func.func(iree-flow-dispatch-linalg-on-tensors-pass{aggressive-fusion=true})" %s | FileCheck %s
+
+#map0 = affine_map<(d0, d1, d2) -> (d0, d1, d2)>
+#map1 = affine_map<(d0, d1, d2) -> (d0, d1)>
+module {
+ func.func @softmax(%arg0: tensor<12x128x128xf32>) -> tensor<12x128x128xf32> {
+ %cst = arith.constant 1.000000e+00 : f32
+ %cst_0 = arith.constant 0.000000e+00 : f32
+ %cst_1 = arith.constant -3.40282347E+38 : f32
+ %0 = linalg.init_tensor [12, 128] : tensor<12x128xf32>
+ %1 = linalg.fill ins(%cst_1 : f32) outs(%0 : tensor<12x128xf32>) -> tensor<12x128xf32>
+ %2 = linalg.generic {indexing_maps = [#map0, #map1], iterator_types = ["parallel", "parallel", "reduction"]} ins(%arg0 : tensor<12x128x128xf32>) outs(%1 : tensor<12x128xf32>) {
+ ^bb0(%arg1: f32, %arg2: f32):
+ %7 = arith.maxf %arg1, %arg2 : f32
+ linalg.yield %7 : f32
+ } -> tensor<12x128xf32>
+ %3 = linalg.init_tensor [12, 128, 128] : tensor<12x128x128xf32>
+ %4 = linalg.fill ins(%cst_0 : f32) outs(%0 : tensor<12x128xf32>) -> tensor<12x128xf32>
+ %5:2 = linalg.generic {indexing_maps = [#map0, #map1, #map0, #map1], iterator_types = ["parallel", "parallel", "reduction"]} ins(%arg0, %2 : tensor<12x128x128xf32>, tensor<12x128xf32>) outs(%3, %4 : tensor<12x128x128xf32>, tensor<12x128xf32>) {
+ ^bb0(%arg1: f32, %arg2: f32, %arg3: f32, %arg4: f32):
+ %7 = arith.subf %arg1, %arg2 : f32
+ %8 = math.exp %7 : f32
+ %9 = arith.addf %8, %arg4 : f32
+ linalg.yield %8, %9 : f32, f32
+ } -> (tensor<12x128x128xf32>, tensor<12x128xf32>)
+ %6 = linalg.generic {indexing_maps = [#map0, #map1, #map0], iterator_types = ["parallel", "parallel", "parallel"]} ins(%5#0, %5#1 : tensor<12x128x128xf32>, tensor<12x128xf32>) outs(%3 : tensor<12x128x128xf32>) {
+ ^bb0(%arg1: f32, %arg2: f32, %arg3: f32):
+ %7 = arith.divf %cst, %arg2 : f32
+ %8 = arith.mulf %arg1, %7 : f32
+ linalg.yield %8 : f32
+ } -> tensor<12x128x128xf32>
+ return %6 : tensor<12x128x128xf32>
+ }
+}
+// CHECK-LABEL: func @softmax(
+// CHECK-SAME: %[[ARG0:[a-zA-Z0-9]+]]: tensor<12x128x128xf32>
+// CHECK: %[[DISPATCH:.+]] = flow.dispatch.workgroups
+// CHECK-SAME: (%[[ARG0]])
+// CHECK-NEXT: %[[ARG1:.+]]: !flow.dispatch.tensor<readonly:12x128x128xf32>
+// CHECK: %[[LOAD0:.+]] = flow.dispatch.tensor.load %[[ARG1]]
+// CHECK: %[[FILL0:.+]] = linalg.fill
+// CHECK: %[[FILL1:.+]] = linalg.fill
+// CHECK: %[[GENERIC0:.+]] = linalg.generic
+// CHECK-SAME: ins(%[[LOAD0]] :
+// CHECK: %[[GENERIC1:.+]]:2 = linalg.generic
+// CHECK-SAME: ins(%[[LOAD0]], %[[GENERIC0]] :
+// CHECK: %[[GENERIC2:.+]] = linalg.generic
+// CHECK-SAME: ins(%[[GENERIC1]]#0, %[[GENERIC1]]#1 :
+// CHECK: flow.dispatch.tensor.store %[[GENERIC2]]
+// CHECK: flow.return
+// CHECK: return %[[DISPATCH]]
+
+// -----
+
+#map0 = affine_map<(d0, d1, d2, d3, d4) -> (d1, d2, d3, d4, d0)>
+#map1 = affine_map<(d0, d1, d2, d3, d4) -> (d0)>
+#map2 = affine_map<(d0) -> (d0)>
+module {
+ func.func @batchnorm_training(%arg0: tensor<12xf32>, %arg1: tensor<12x12x12x12x12xf32>, %arg2: tensor<12xf32>) -> (tensor<12xf32>, tensor<12xf32>, tensor<12xf32>) {
+ %cst = arith.constant 1.420000e+00 : f32
+ %cst_0 = arith.constant 1.450000e+00 : f32
+ %cst_1 = arith.constant 1.300000e+00 : f32
+ %cst_2 = arith.constant 0.000000e+00 : f32
+ %0 = linalg.init_tensor [12] : tensor<12xf32>
+ %1 = linalg.fill ins(%cst_2 : f32) outs(%0 : tensor<12xf32>) -> tensor<12xf32>
+ %2 = linalg.generic {indexing_maps = [#map0, #map1, #map1], iterator_types = ["parallel", "reduction", "reduction", "reduction", "reduction"]} ins(%arg1, %arg2 : tensor<12x12x12x12x12xf32>, tensor<12xf32>) outs(%1 : tensor<12xf32>) {
+ ^bb0(%arg3: f32, %arg4: f32, %arg5: f32):
+ %4 = arith.subf %arg3, %arg4 : f32
+ %5 = arith.mulf %4, %4 : f32
+ %6 = arith.addf %arg5, %5 : f32
+ linalg.yield %6 : f32
+ } -> tensor<12xf32>
+ %3:3 = linalg.generic {indexing_maps = [#map2, #map2, #map2, #map2, #map2], iterator_types = ["parallel"]} ins(%arg0, %2 : tensor<12xf32>, tensor<12xf32>) outs(%0, %0, %0 : tensor<12xf32>, tensor<12xf32>, tensor<12xf32>) {
+ ^bb0(%arg3: f32, %arg4: f32, %arg5: f32, %arg6: f32, %arg7: f32):
+ %4 = arith.divf %arg4, %cst_0 : f32
+ %5 = arith.addf %4, %cst_1 : f32
+ %6 = math.sqrt %5 : f32
+ %7 = arith.subf %arg3, %6 : f32
+ %8 = arith.mulf %7, %cst : f32
+ %9 = arith.subf %arg3, %8 : f32
+ linalg.yield %5, %6, %9 : f32, f32, f32
+ } -> (tensor<12xf32>, tensor<12xf32>, tensor<12xf32>)
+ return %3#0, %3#1, %3#2 : tensor<12xf32>, tensor<12xf32>, tensor<12xf32>
+ }
+}
+// CHECK-LABEL: func @batchnorm_training(
+// CHECK-SAME: %[[ARG0:[a-zA-Z0-9]+]]: tensor<12xf32>
+// CHECK-SAME: %[[ARG1:[a-zA-Z0-9]+]]: tensor<12x12x12x12x12xf32>
+// CHECK-SAME: %[[ARG2:[a-zA-Z0-9]+]]: tensor<12xf32>
+// CHECK: %[[DISPATCH:.+]]:3 = flow.dispatch.workgroups
+// CHECK-SAME: (%[[ARG1]], %[[ARG2]], %[[ARG0]])
+// CHECK-NEXT: %[[ARG3:.+]]: !flow.dispatch.tensor<readonly:12x12x12x12x12xf32>
+// CHECK-SAME: %[[ARG4:[a-zA-Z0-9]+]]: !flow.dispatch.tensor<readonly:12xf32>
+// CHECK-SAME: %[[ARG5:[a-zA-Z0-9]+]]: !flow.dispatch.tensor<readonly:12xf32>
+// CHECK-DAG: %[[LOAD0:.+]] = flow.dispatch.tensor.load %[[ARG3]]
+// CHECK-DAG: %[[LOAD1:.+]] = flow.dispatch.tensor.load %[[ARG4]]
+// CHECK-DAG: %[[LOAD2:.+]] = flow.dispatch.tensor.load %[[ARG5]]
+// CHECK: %[[FILL:.+]] = linalg.fill
+// CHECK: %[[GENERIC0:.+]] = linalg.generic
+// CHECK-SAME: ins(%[[LOAD0]], %[[LOAD1]] :
+// CHECK: %[[GENERIC1:.+]]:3 = linalg.generic
+// CHECK-SAME: ins(%[[LOAD2]], %[[GENERIC0]] :
+// CHECK-DAG: flow.dispatch.tensor.store %[[GENERIC1]]#0
+// CHECK-DAG: flow.dispatch.tensor.store %[[GENERIC1]]#1
+// CHECK-DAG: flow.dispatch.tensor.store %[[GENERIC1]]#2
+// CHECK: flow.return
+// CHECK: return %[[DISPATCH]]#0, %[[DISPATCH]]#1, %[[DISPATCH]]#2
diff --git a/compiler/src/iree/compiler/Dialect/Flow/Transforms/test/dispatch_linalg_on_tensors_fusion.mlir b/compiler/src/iree/compiler/Dialect/Flow/Transforms/test/dispatch_linalg_on_tensors_fusion.mlir
index ef08d5b..e17cd08 100644
--- a/compiler/src/iree/compiler/Dialect/Flow/Transforms/test/dispatch_linalg_on_tensors_fusion.mlir
+++ b/compiler/src/iree/compiler/Dialect/Flow/Transforms/test/dispatch_linalg_on_tensors_fusion.mlir
@@ -1,4 +1,4 @@
-// RUN: iree-opt --split-input-file --verify-diagnostics --iree-flow-enable-multi-result-dispatches --pass-pipeline="func.func(iree-flow-dispatch-linalg-on-tensors-pass)" --canonicalize -cse %s | FileCheck %s
+// RUN: iree-opt --split-input-file --verify-diagnostics --pass-pipeline="func.func(iree-flow-dispatch-linalg-on-tensors-pass{aggressive-fusion=true})" --canonicalize -cse %s | FileCheck %s
func.func @fuse_conv2d_elementwise(%input: tensor<1x225x225x16xf32>, %filter: tensor<3x3x16x32xf32>, %offset: tensor<32xf32>) -> tensor<1x112x112x32xf32> {
%cst = arith.constant 0.000000e+00 : f32
diff --git a/compiler/src/iree/compiler/Dialect/Flow/Transforms/test/dispatch_linalg_on_tensors_fusion_reduction_broadcast_elementwise.mlir b/compiler/src/iree/compiler/Dialect/Flow/Transforms/test/dispatch_linalg_on_tensors_fusion_reduction_broadcast_elementwise.mlir
index a552edd..2095fe4 100644
--- a/compiler/src/iree/compiler/Dialect/Flow/Transforms/test/dispatch_linalg_on_tensors_fusion_reduction_broadcast_elementwise.mlir
+++ b/compiler/src/iree/compiler/Dialect/Flow/Transforms/test/dispatch_linalg_on_tensors_fusion_reduction_broadcast_elementwise.mlir
@@ -106,7 +106,7 @@
#map1 = affine_map<(d0, d1, d2) -> (d0, d1)>
#map2 = affine_map<(d0, d1, d2) -> (d0, d1, d2)>
-func.func @reduction_broadcast_elementwise_type_mismatch(%a: tensor<12x16x16xf32>, %b: tensor<12x16x32xf32>) -> tensor<12x16x32xf32> {
+func.func @reduction_broadcast_elementwise_type_mismatch(%a: tensor<12x16x16xf32>, %b: tensor<12x16x32xf32>) -> tensor<12x16x32xi32> {
%cst_47 = arith.constant 0.000000e+00 : f32
%37 = linalg.init_tensor [12, 16] : tensor<12x16xf32>
%38 = linalg.fill ins(%cst_47 : f32) outs(%37 : tensor<12x16xf32>) -> tensor<12x16xf32>
@@ -115,13 +115,14 @@
%780 = arith.maxf %arg3, %arg4 : f32
linalg.yield %780 : f32
} -> tensor<12x16xf32>
- %40 = linalg.init_tensor [12, 16, 32] : tensor<12x16x32xf32>
- %42 = linalg.generic {indexing_maps = [#map2, #map1, #map2], iterator_types = ["parallel", "parallel", "parallel"]} ins(%b, %39 : tensor<12x16x32xf32>, tensor<12x16xf32>) outs(%40 : tensor<12x16x32xf32>) {
- ^bb0(%arg3: f32, %arg4: f32, %arg5: f32):
+ %40 = linalg.init_tensor [12, 16, 32] : tensor<12x16x32xi32>
+ %42 = linalg.generic {indexing_maps = [#map2, #map1, #map2], iterator_types = ["parallel", "parallel", "parallel"]} ins(%b, %39 : tensor<12x16x32xf32>, tensor<12x16xf32>) outs(%40 : tensor<12x16x32xi32>) {
+ ^bb0(%arg3: f32, %arg4: f32, %arg5: i32):
%780 = arith.subf %arg3, %arg4 : f32
- linalg.yield %780 : f32
- } -> tensor<12x16x32xf32>
- return %42 : tensor<12x16x32xf32>
+ %781 = arith.fptosi %780 : f32 to i32
+ linalg.yield %781 : i32
+ } -> tensor<12x16x32xi32>
+ return %42 : tensor<12x16x32xi32>
}
// Check that two generic ops are NOT dispatched together since the input type
@@ -163,4 +164,3 @@
// CHECK-LABEL: func.func @reduction_broadcast_elementwise_dynamic
// CHECK: flow.dispatch.workgroups
// CHECK: flow.dispatch.workgroups
-