[Flow] Allow element-wise fusion of multi-reduction ops (#16503)
This enables element-wise fusion of producers into consumer ops that have multiple reduction dimensions in FusionOfTensorOps. The old hard requirement that the consumer have exactly one reduction loop is relaxed behind a new `fuseMultiReduction` pass option (default: `true`, also exposed as `--iree-flow-element-wise-fuse-multi-reduction`), while consumers matching the contraction or convolution op interfaces are explicitly excluded from this fusion.
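Concretely, `areFusableOps` used to bail out whenever the consumer had more than one reduction loop, so patterns like a dequantization generic feeding a two-reduction (matvec-style) accumulation stayed in separate dispatches. Below is a minimal sketch of the shape of IR this unlocks; the function name, shapes, and region bodies are illustrative stand-ins, not taken from the patch:

```mlir
func.func @fuse_into_multi_reduction(%arg0: tensor<11008x32x128xf16>) -> tensor<11008xf16> {
  %cst = arith.constant 0.0 : f16
  %init = tensor.empty() : tensor<11008x32x128xf16>
  // Elementwise producer (stand-in for a dequantization computation).
  %elem = linalg.generic {
      indexing_maps = [affine_map<(d0, d1, d2) -> (d0, d1, d2)>,
                       affine_map<(d0, d1, d2) -> (d0, d1, d2)>],
      iterator_types = ["parallel", "parallel", "parallel"]}
      ins(%arg0 : tensor<11008x32x128xf16>)
      outs(%init : tensor<11008x32x128xf16>) {
    ^bb0(%in: f16, %out: f16):
      %m = arith.mulf %in, %in : f16
      linalg.yield %m : f16
  } -> tensor<11008x32x128xf16>
  %empty = tensor.empty() : tensor<11008xf16>
  %fill = linalg.fill ins(%cst : f16) outs(%empty : tensor<11008xf16>) -> tensor<11008xf16>
  // Consumer with two reduction iterators; previously areFusableOps rejected
  // it (getNumReductionLoops() != 1), now it is a legal fusion target.
  %red = linalg.generic {
      indexing_maps = [affine_map<(d0, d1, d2) -> (d0, d1, d2)>,
                       affine_map<(d0, d1, d2) -> (d0)>],
      iterator_types = ["parallel", "reduction", "reduction"]}
      ins(%elem : tensor<11008x32x128xf16>)
      outs(%fill : tensor<11008xf16>) {
    ^bb0(%in: f16, %out: f16):
      %a = arith.addf %in, %out : f16
      linalg.yield %a : f16
  } -> tensor<11008xf16>
  return %red : tensor<11008xf16>
}
```

With `fuse-multi-reduction=true` (the default) the two generics above can be merged into a single generic with iterator types `["parallel", "reduction", "reduction"]`, provided the consumer is not a contraction or convolution.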
diff --git a/compiler/src/iree/compiler/Dialect/Flow/Transforms/FusionOfTensorOps.cpp b/compiler/src/iree/compiler/Dialect/Flow/Transforms/FusionOfTensorOps.cpp
index ea4f55a..060439e 100644
--- a/compiler/src/iree/compiler/Dialect/Flow/Transforms/FusionOfTensorOps.cpp
+++ b/compiler/src/iree/compiler/Dialect/Flow/Transforms/FusionOfTensorOps.cpp
@@ -19,6 +19,7 @@
#include "llvm/Support/Debug.h"
#include "mlir/Dialect/Affine/IR/AffineOps.h"
#include "mlir/Dialect/Linalg/IR/Linalg.h"
+#include "mlir/Dialect/Linalg/IR/LinalgInterfaces.h"
#include "mlir/Dialect/Linalg/Transforms/Transforms.h"
#include "mlir/Dialect/Math/IR/Math.h"
#include "mlir/Dialect/MemRef/Transforms/Transforms.h"
@@ -61,7 +62,8 @@
}
/// Check if the producer generic op is fusable with the consumer generic op.
-static bool areFusableOps(MLIRContext *context, OpOperand *fusedOperand) {
+static bool areFusableOps(MLIRContext *context, OpOperand *fusedOperand,
+ bool fuseMultiReduction) {
Operation *producerOp = fusedOperand->get().getDefiningOp();
Operation *consumerOp = fusedOperand->getOwner();
if (!producerOp)
@@ -115,11 +117,17 @@
linalgConsumerOp.getNumLoops()) {
return true;
}
- if (linalgConsumerOp.getNumReductionLoops() != 1 ||
- !linalgConsumerOp.getMatchingIndexingMap(fusedOperand)
+ if (!linalgConsumerOp.getMatchingIndexingMap(fusedOperand)
.isPermutation()) {
return false;
}
+ if (!fuseMultiReduction && linalgConsumerOp.getNumReductionLoops() != 1) {
+ return false;
+ }
+ if (linalg::isaContractionOpInterface(linalgConsumerOp) ||
+ linalg::isaConvolutionOpInterface(linalgConsumerOp)) {
+ return false;
+ }
return true;
}
@@ -304,13 +312,15 @@
registry.insert<affine::AffineDialect, arith::ArithDialect,
linalg::LinalgDialect, math::MathDialect>();
}
- FusionOfTensorOpsPass(bool fuseMultiUse, unsigned multiUseFusionIteration) {
+ FusionOfTensorOpsPass(bool fuseMultiUse, bool fuseMultiReduction,
+ unsigned multiUseFusionIteration) {
this->fuseMultiUse = fuseMultiUse;
+ this->fuseMultiReduction = fuseMultiReduction;
this->multiUseFusionIteration = multiUseFusionIteration;
}
FusionOfTensorOpsPass(const FusionOfTensorOpsPass &pass)
- : FusionOfTensorOpsPass(pass.fuseMultiUse, pass.multiUseFusionIteration) {
- }
+ : FusionOfTensorOpsPass(pass.fuseMultiUse, pass.fuseMultiReduction,
+ pass.multiUseFusionIteration) {}
void runOnOperation() override {
Operation *funcOp = getOperation();
@@ -345,7 +355,7 @@
if (operands.size() >= kIreeMaxOperandCount)
return false;
- return areFusableOps(context, fusedOperand);
+ return areFusableOps(context, fusedOperand, fuseMultiReduction);
};
linalg::populateElementwiseOpsFusionPatterns(fusionPatterns,
fuseElementwiseOpsControlFn);
@@ -501,10 +511,10 @@
} // namespace
std::unique_ptr<InterfacePass<mlir::FunctionOpInterface>>
-createFusionOfTensorOpsPass(bool fuseMultiUse,
+createFusionOfTensorOpsPass(bool fuseMultiUse, bool fuseMultiReduction,
unsigned multiUseFusionIteration) {
- return std::make_unique<FusionOfTensorOpsPass>(fuseMultiUse,
- multiUseFusionIteration);
+ return std::make_unique<FusionOfTensorOpsPass>(
+ fuseMultiUse, fuseMultiReduction, multiUseFusionIteration);
}
} // namespace mlir::iree_compiler::IREE::Flow
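Note the ordering of the new checks in `areFusableOps`: the permutation requirement on the fused operand's indexing map is kept unconditionally, the single-reduction restriction moves behind `fuseMultiReduction`, and contraction/convolution consumers are now rejected explicitly. A hedged sketch of the kind of consumer the new `isaContractionOpInterface` guard keeps out of elementwise fusion (shapes and names are illustrative):

```mlir
// A matvec written as linalg.generic. Its %lhs operand has a permutation
// indexing map, but isaContractionOpInterface recognizes the mul-accumulate
// body, so elementwise producers are kept out of it.
func.func @contraction_consumer(%lhs: tensor<128x256xf32>,
                                %rhs: tensor<256xf32>) -> tensor<128xf32> {
  %cst = arith.constant 0.0 : f32
  %empty = tensor.empty() : tensor<128xf32>
  %fill = linalg.fill ins(%cst : f32) outs(%empty : tensor<128xf32>) -> tensor<128xf32>
  %mv = linalg.generic {
      indexing_maps = [affine_map<(d0, d1) -> (d0, d1)>,
                       affine_map<(d0, d1) -> (d1)>,
                       affine_map<(d0, d1) -> (d0)>],
      iterator_types = ["parallel", "reduction"]}
      ins(%lhs, %rhs : tensor<128x256xf32>, tensor<256xf32>)
      outs(%fill : tensor<128xf32>) {
    ^bb0(%a: f32, %b: f32, %acc: f32):
      %m = arith.mulf %a, %b : f32
      %s = arith.addf %m, %acc : f32
      linalg.yield %s : f32
  } -> tensor<128xf32>
  return %mv : tensor<128xf32>
}
```

Keeping contractions and convolutions unfused here leaves them to the dedicated dispatch-formation heuristics rather than generic elementwise fusion.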
diff --git a/compiler/src/iree/compiler/Dialect/Flow/Transforms/Passes.cpp b/compiler/src/iree/compiler/Dialect/Flow/Transforms/Passes.cpp
index 2cdc605..56edc9e 100644
--- a/compiler/src/iree/compiler/Dialect/Flow/Transforms/Passes.cpp
+++ b/compiler/src/iree/compiler/Dialect/Flow/Transforms/Passes.cpp
@@ -80,6 +80,11 @@
llvm::cl::desc("Fuse multi-use ops."),
llvm::cl::init(false));
+static llvm::cl::opt<bool> clEnableElementWiseFuseMultiReduction(
+ "iree-flow-element-wise-fuse-multi-reduction",
+ llvm::cl::desc("Enable element-wise fusion of multi-reduction loop ops."),
+ llvm::cl::init(true));
+
static llvm::cl::opt<bool> clDispatchGenerateWorkloadRegion(
"iree-flow-dispatch-generate-workload-region",
llvm::cl::desc("Generate the workload region."), llvm::cl::init(true));
@@ -135,8 +140,10 @@
.addPass(mlir::createCanonicalizerPass)
.addPass(mlir::createCSEPass)
// Elementwise fusion.
- .addPass(
- []() { return createFusionOfTensorOpsPass(clEnableFuseMultiUse); })
+ .addPass([]() {
+ return createFusionOfTensorOpsPass(
+ clEnableFuseMultiUse, clEnableElementWiseFuseMultiReduction);
+ })
.addPredicatedPass(clDetensoring,
[&]() { return mlir::createLinalgDetensorizePass(); })
.addPass(mlir::createCanonicalizerPass)
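The new flag defaults to true, so multi-reduction fusion is on by default in the Flow pipeline. A sketch of how it could be toggled off for comparison; the `iree-opt` invocation and pipeline flag are assumptions, only `--iree-flow-element-wise-fuse-multi-reduction` comes from this patch:

```mlir
// Hypothetical RUN line for A/B-ing the new behavior:
// RUN: iree-opt --iree-flow-transformation-pipeline \
// RUN:   --iree-flow-element-wise-fuse-multi-reduction=false %s | FileCheck %s
```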
diff --git a/compiler/src/iree/compiler/Dialect/Flow/Transforms/Passes.h b/compiler/src/iree/compiler/Dialect/Flow/Transforms/Passes.h
index 2bfc4bc..60cf678 100644
--- a/compiler/src/iree/compiler/Dialect/Flow/Transforms/Passes.h
+++ b/compiler/src/iree/compiler/Dialect/Flow/Transforms/Passes.h
@@ -65,6 +65,7 @@
// Creates a pass to fuse Linalg operations on tensors.
std::unique_ptr<InterfacePass<mlir::FunctionOpInterface>>
createFusionOfTensorOpsPass(bool fuseMultiUse = false,
+ bool fuseMultiReduction = true,
unsigned multiUseFusionIteration = 2);
// Create a pass to initialize all empty tensors after dispatch formation to
diff --git a/compiler/src/iree/compiler/Dialect/Flow/Transforms/Passes.td b/compiler/src/iree/compiler/Dialect/Flow/Transforms/Passes.td
index 55041e1..743b19e 100644
--- a/compiler/src/iree/compiler/Dialect/Flow/Transforms/Passes.td
+++ b/compiler/src/iree/compiler/Dialect/Flow/Transforms/Passes.td
@@ -168,6 +168,8 @@
let options = [
Option<"fuseMultiUse", "fuse-multi-use", "bool",
/*default=*/"false", "Fuse ops with multiuse">,
+ Option<"fuseMultiReduction", "fuse-multi-reduction", "bool",
+ /*default=*/"true", "Fuse ops that have multiple reduction iterators">,
Option<"multiUseFusionIteration", "multi-use-fusion-iteration", "unsigned",
/*default=*/"2", "Number of iterations to fuse multiuse ops">
];
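The same knob is exposed as a pass option for standalone testing. A sketch, assuming the pass keeps its registered `iree-flow-fusion-of-tensor-ops` flag (the `def` line is outside this hunk); the option name comes from the table above:

```mlir
// Hypothetical standalone invocation with multi-reduction fusion disabled:
// RUN: iree-opt --split-input-file \
// RUN:   --iree-flow-fusion-of-tensor-ops="fuse-multi-reduction=false" %s | FileCheck %s
```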
diff --git a/compiler/src/iree/compiler/Dialect/Flow/Transforms/test/fusion_of_tensor_ops.mlir b/compiler/src/iree/compiler/Dialect/Flow/Transforms/test/fusion_of_tensor_ops.mlir
index f3741b6..da68ca1 100644
--- a/compiler/src/iree/compiler/Dialect/Flow/Transforms/test/fusion_of_tensor_ops.mlir
+++ b/compiler/src/iree/compiler/Dialect/Flow/Transforms/test/fusion_of_tensor_ops.mlir
@@ -366,10 +366,9 @@
// -----
-util.func public @nofuse_by_expand_dequant(%arg0 : tensor<11008x4096xi4>, %arg1 : tensor<11008x32x1xf16>, %arg2 : tensor<11008x32x1xf16>) -> (tensor<11008xf16>) {
+util.func public @nofuse_by_expand_dequant(%arg0 : tensor<11008x4096xi4>, %arg1 : tensor<11008x32x1xf16>, %arg2 : tensor<11008x32x1xf16>, %arg3 : tensor<1x1x32x128xf16>) -> (tensor<11008xf16>) {
%cst_1 = arith.constant 0.000000e+00 : f16
%0 = tensor.empty() : tensor<11008x32x128xf16>
- %1 = arith.constant dense<0.000000e+00> : tensor<1x1x32x128xf16>
%expanded = tensor.expand_shape %arg0 [[0], [1, 2]] : tensor<11008x4096xi4> into tensor<11008x32x128xi4>
%collapsed = tensor.collapse_shape %arg2 [[0], [1, 2]] : tensor<11008x32x1xf16> into tensor<11008x32xf16>
%collapsed_2 = tensor.collapse_shape %arg1 [[0], [1, 2]] : tensor<11008x32x1xf16> into tensor<11008x32xf16>
@@ -381,7 +380,7 @@
%15 = arith.mulf %14, %in_8 : f16
linalg.yield %15 : f16
} -> tensor<11008x32x128xf16>
- %collapsed_3 = tensor.collapse_shape %1 [[0, 1, 2], [3]] : tensor<1x1x32x128xf16> into tensor<32x128xf16>
+ %collapsed_3 = tensor.collapse_shape %arg3 [[0, 1, 2], [3]] : tensor<1x1x32x128xf16> into tensor<32x128xf16>
%3 = tensor.empty() : tensor<11008xf16>
%4 = linalg.fill ins(%cst_1 : f16) outs(%3 : tensor<11008xf16>) -> tensor<11008xf16>
%5 = linalg.generic {indexing_maps = [affine_map<(d0, d1, d2) -> (d1, d2)>, affine_map<(d0, d1, d2) -> (d0, d1, d2)>, affine_map<(d0, d1, d2) -> (d0)>], iterator_types = ["parallel", "reduction", "reduction"]} ins(%collapsed_3, %2 : tensor<32x128xf16>, tensor<11008x32x128xf16>) outs(%4 : tensor<11008xf16>) {
@@ -399,7 +398,7 @@
// CHECK-NOT: tensor.collapse_shape %[[DEQUANT]]
// CHECK: %[[MATVEC:.+]] = linalg.generic
// CHECK-SAME: iterator_types = ["parallel", "reduction", "reduction"]
-// CHECK-SAME: ins(%[[DEQUANT]] : tensor<11008x32x128xf16>)
+// CHECK-SAME: ins({{.+}}%[[DEQUANT]]
// -----
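The test now threads the `1x1x32x128` tensor in as `%arg3` instead of materializing it as a splat constant. This is plausibly because a splat `arith.constant` folds through `tensor.collapse_shape` during canonicalization, which would erase the reshape the test needs to keep live; a sketch of that fold (my inference, not stated in the patch):

```mlir
// With a splat constant operand, the collapse_shape below canonicalizes to
// a reshaped splat constant, so the reshape/fusion interaction under test
// would disappear. Using a function argument keeps the collapse_shape live.
func.func @splat_collapse_folds() -> tensor<32x128xf16> {
  %splat = arith.constant dense<0.000000e+00> : tensor<1x1x32x128xf16>
  // Folds to: arith.constant dense<0.000000e+00> : tensor<32x128xf16>
  %c = tensor.collapse_shape %splat [[0, 1, 2], [3]]
      : tensor<1x1x32x128xf16> into tensor<32x128xf16>
  return %c : tensor<32x128xf16>
}
```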
diff --git a/compiler/src/iree/compiler/Dialect/Flow/Transforms/test/pipeline_tests.mlir b/compiler/src/iree/compiler/Dialect/Flow/Transforms/test/pipeline_tests.mlir
index 5485f59..290d2d7 100644
--- a/compiler/src/iree/compiler/Dialect/Flow/Transforms/test/pipeline_tests.mlir
+++ b/compiler/src/iree/compiler/Dialect/Flow/Transforms/test/pipeline_tests.mlir
@@ -44,16 +44,11 @@
// CHECK: flow.executable private @[[EXECUTABLE0:[a-zA-Z0-9_]+]]
// CHECK: func.func @[[FUNC0:[a-zA-Z0-9_x]+]]
// CHECK: linalg.generic
-// CHECK-SAME: ["parallel", "parallel"]
-// CHECK: flow.executable private @[[EXECUTABLE1:[a-zA-Z0-9_]+]]
-// CHECK: func.func @[[FUNC1:[a-zA-Z0-9_x]+]]
-// CHECK: linalg.generic
-// CHECK-SAME: ["reduction"]
+// CHECK-SAME: ["reduction", "reduction"]
+// CHECK-NOT: linalg.generic
// CHECK: util.func public @main(
// CHECK: %[[T0:.+]] = flow.dispatch @[[EXECUTABLE0]]::@[[FUNC0]]
-// CHECK: %[[T1:.+]] = flow.tensor.reshape %[[T0]] : tensor<833x833xf32> -> tensor<693889xf32>
-// CHECK: %[[T2:.+]] = flow.dispatch @[[EXECUTABLE1]]::@[[FUNC1]](%[[T1]])
-// CHECK: util.return %[[T2]]
+// CHECK: util.return %[[T0]]
// -----
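The pipeline test captures the end-to-end payoff: the elementwise generic and the full reduction that previously landed in two executables, bridged by a `flow.tensor.reshape` from `tensor<833x833xf32>` to `tensor<693889xf32>`, now fuse into a single dispatch whose generic has iterator types `["reduction", "reduction"]`. A hedged sketch of the fused form implied by the new CHECK lines (the elementwise body is a stand-in):

```mlir
// One generic whose two loops are both reductions, consuming the original
// 833x833 operand directly, with no intermediate full-sized tensor or
// flow.tensor.reshape between dispatches.
func.func @fused_reduction(%arg0: tensor<833x833xf32>) -> tensor<f32> {
  %cst = arith.constant 0.0 : f32
  %empty = tensor.empty() : tensor<f32>
  %fill = linalg.fill ins(%cst : f32) outs(%empty : tensor<f32>) -> tensor<f32>
  %r = linalg.generic {
      indexing_maps = [affine_map<(d0, d1) -> (d0, d1)>,
                       affine_map<(d0, d1) -> ()>],
      iterator_types = ["reduction", "reduction"]}
      ins(%arg0 : tensor<833x833xf32>)
      outs(%fill : tensor<f32>) {
    ^bb0(%in: f32, %out: f32):
      // Stand-in for the fused elementwise computation + accumulation.
      %sq = arith.mulf %in, %in : f32
      %s = arith.addf %sq, %out : f32
      linalg.yield %s : f32
  } -> tensor<f32>
  return %r : tensor<f32>
}
```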