[LLVMCPU] Add option `onlyFuseProducerInputOperands` to tileRootFuseConsumerProducer Pass (#18114)
Previously, we only tiled the root op with the reduction tile sizes and did not fuse the producers of its input operands. This led to transfer read/write ops with large vector sizes, since the dequant operation materialized its own tensor and was not fused inside the reduction loop.

Adds an `onlyFuseProducerInputOperands` option to the `iree-llvmcpu-tile-root-and-fuse-producer-consumer` pass. When the option is set to true, the pass tiles the root op (e.g. along its reduction dimensions) and fuses only the producers of its input operands; destination-operand producers and consumers are not fused.

Issue link: https://github.com/iree-org/iree/issues/18005
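
For reference, a minimal standalone invocation of the pass with the new option, mirroring the RUN line added to `tile-root-fuse-consumer-producer.mlir` below (`input.mlir` is a placeholder file name):

```shell
iree-opt \
  --pass-pipeline="builtin.module(func.func(iree-llvmcpu-tile-root-and-fuse-producer-consumer{tiling-level=2 only-fuse-producer-input-operands=true}), canonicalize)" \
  --split-input-file input.mlir
```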
diff --git a/compiler/src/iree/compiler/Codegen/LLVMCPU/LLVMCPUTileRootAndFuseProducerConsumer.cpp b/compiler/src/iree/compiler/Codegen/LLVMCPU/LLVMCPUTileRootAndFuseProducerConsumer.cpp
index 3f6c727..78b3b6f 100644
--- a/compiler/src/iree/compiler/Codegen/LLVMCPU/LLVMCPUTileRootAndFuseProducerConsumer.cpp
+++ b/compiler/src/iree/compiler/Codegen/LLVMCPU/LLVMCPUTileRootAndFuseProducerConsumer.cpp
@@ -32,34 +32,99 @@
namespace {
-/// Implementation of tile root and fuse producers and consumers greedily.
-static LogicalResult tileRootAndFuseProducerConsumerUsingSCF(
- RewriterBase &rewriter, TilingInterface root,
- const scf::SCFTileAndFuseOptions &options) {
+/// Starting from `op` walk all operands backwards to find all
+/// potentially fusable operations, i.e. operations that implement
+/// the `TilingInterface`.
+static void collectTiledAndFusedOps(Operation *rootOp,
+ llvm::SmallDenseSet<Operation *> &result) {
+ SmallVector<Operation *> worklist;
+ worklist.push_back(rootOp);
+ result.insert(rootOp);
+ while (!worklist.empty()) {
+ Operation *current = worklist.pop_back_val();
+ for (OpOperand &operand : current->getOpOperands()) {
+ Operation *producer = operand.get().getDefiningOp();
+ if (!producer || !isa<TilingInterface>(producer) ||
+ result.count(producer))
+ continue;
+ worklist.push_back(producer);
+ result.insert(producer);
+ }
+ }
+}
- // This transformation is only valid for ops that return values (i.e. not
- // valid to use with operations that have memref operands).
- if (!root->getNumResults()) {
- return rewriter.notifyMatchFailure(
- root, "invalid pattern for op with no results");
+/// Tile the root operation and fuse the producers of the root operation.
+/// If `onlyFuseProducerInputOperands` is set, only fuse producer input
+/// operands. Returns the tiled operation to be used for fusing consumers.
+FailureOr<Operation *>
+tileRootAndFuseProducers(IRRewriter &rewriter, TilingInterface rootOp,
+ int64_t tilingLevel,
+ bool onlyFuseProducerInputOperands) {
+ mlir::DominanceInfo dominanceInfo(rootOp);
+ llvm::SmallDenseSet<Operation *> tiledAndFusedOps;
+ collectTiledAndFusedOps(rootOp, tiledAndFusedOps);
+
+ llvm::DenseSet<Operation *> yieldReplacementsFor;
+ for (auto op : tiledAndFusedOps) {
+ if (llvm::any_of(op->getUsers(), [&](Operation *user) {
+ return dominanceInfo.properlyDominates(rootOp, user);
+ })) {
+ yieldReplacementsFor.insert(op);
+ }
}
- // 1. Tile root op and Fuse Producers.
+ SmallVector<OpFoldResult> tileSizes =
+ getLoweringConfig(rootOp).getTilingLevelSizes(rewriter, tilingLevel,
+ rootOp);
+
+ // Pad the tile sizes with zero.
+ auto zero = rewriter.getIndexAttr(0);
+ int64_t numLoops = rootOp.getLoopIteratorTypes().size();
+ if (tileSizes.size() > numLoops) {
+ LLVM_DEBUG(llvm::dbgs()
+ << "tile sizes size " << tileSizes.size()
+ << " exceeds the number of loops " << numLoops << "\n");
+ return failure();
+ }
+ tileSizes.resize(numLoops, zero);
+
+ scf::SCFTilingOptions tilingOptions;
+ tilingOptions.setTileSizes(tileSizes);
+
+ scf::SCFTileAndFuseOptions tileAndFuseOptions;
+ tileAndFuseOptions.setTilingOptions(tilingOptions);
+
+ scf::SCFTileAndFuseOptions::ControlFnTy controlFn =
+ [&](tensor::ExtractSliceOp candidateSliceOp, OpResult originalProducer,
+ bool isDestinationOperand) {
+ Operation *owner = originalProducer.getOwner();
+ bool yieldProducerReplacement = yieldReplacementsFor.contains(owner);
+ // Do not fuse destination operands if onlyFuseProducerInputOperands is
+ // true.
+ bool shouldFuse =
+ !(onlyFuseProducerInputOperands && isDestinationOperand);
+ return std::make_tuple(shouldFuse, yieldProducerReplacement);
+ };
+ tileAndFuseOptions.setFusionControlFn(controlFn);
+
FailureOr<scf::SCFTileAndFuseResult> tiledResults =
- scf::tileConsumerAndFuseProducersUsingSCF(rewriter, root, options);
-
+ scf::tileConsumerAndFuseProducersUsingSCF(rewriter, rootOp,
+ tileAndFuseOptions);
if (failed(tiledResults)) {
- return rewriter.notifyMatchFailure(
- root, "failed to tile root and fuse producers");
+ return failure();
}
- // 2. Replace the producers with the tiled verison.
- SmallVector<Operation *> opsToReplace = {root};
+ // Perform the replacement of tiled and fused values.
+ SmallVector<Operation *> opsToReplace{rootOp};
llvm::append_range(opsToReplace, tiledResults->fusedProducers);
for (Operation *toReplace : opsToReplace) {
for (OpResult res : toReplace->getResults())
if (auto replacement = tiledResults->replacements.lookup(res)) {
- rewriter.replaceAllUsesWith(res, replacement);
+ Operation *replacementOp = replacement.getDefiningOp();
+ rewriter.replaceUsesWithIf(res, replacement, [&](OpOperand &use) {
+ Operation *user = use.getOwner();
+ return dominanceInfo.properlyDominates(replacementOp, user);
+ });
}
if (toReplace->use_empty()) {
@@ -67,13 +132,18 @@
}
}
- // 3. Typically, the consumers of the tiled operation are slices of the
- // results of the tiled operation. These are expressed in IR using
- // `tensor.insert_slice` operations, whose outputs are the operands of the
- // untiled operation. Create a worklist of these `tensor.insert_siices`
- // operations. If the consumers of the source of the `tensor.insert_slices`
- // can be tiled such that the tiled value is generated in-place, that
- // effectively tiles + fuses the operations.
+ return tiledResults->tiledAndFusedOps.front();
+}
+
+static void fuseConsumers(RewriterBase &rewriter, Operation *tiledOp) {
+
+ // Typically, the consumers of the tiled operation are slices of the
+ // results of the tiled operation. These are expressed in IR using
+ // `tensor.insert_slice` operations, whose outputs are the operands of the
+  // untiled operation. Create a worklist of these `tensor.insert_slice`
+ // operations. If the consumers of the source of the `tensor.insert_slices`
+ // can be tiled such that the tiled value is generated in-place, that
+ // effectively tiles + fuses the operations.
auto addCandidateSlices = [](Operation *fusedOp,
std::queue<tensor::InsertSliceOp> &candidates) {
for (auto *userOp : fusedOp->getResults().getUsers()) {
@@ -86,7 +156,7 @@
// Collect the candidate slices which can be potential consumers that can be
// fused.
std::queue<tensor::InsertSliceOp> candidates;
- addCandidateSlices(tiledResults->tiledAndFusedOps.front(), candidates);
+ addCandidateSlices(tiledOp, candidates);
while (!candidates.empty()) {
@@ -112,42 +182,44 @@
addCandidateSlices(fusedResult->tiledAndFusedConsumerOperand->getOwner(),
candidates);
}
- return success();
}
-static LogicalResult tileRootAndFuseProducerConsumer(IRRewriter &rewriter,
- TilingInterface rootOp,
- int64_t tilingLevel) {
+/// Implementation of tile root and fuse producers and consumers greedily.
+/// If `onlyFuseProducerInputOperands` is set, only fuse producer input operands
+/// and disable consumer fusion.
+static LogicalResult tileRootAndFuse(IRRewriter &rewriter,
+ TilingInterface rootOp,
+ int64_t tilingLevel,
+ bool onlyFuseProducerInputOperands) {
- SmallVector<OpFoldResult> tileSizes =
- getLoweringConfig(rootOp).getTilingLevelSizes(rewriter, tilingLevel,
- rootOp);
- int64_t numLoops = rootOp.getLoopIteratorTypes().size();
- if (tileSizes.size() > numLoops)
+ FailureOr<Operation *> tiledOp = tileRootAndFuseProducers(
+ rewriter, rootOp, tilingLevel, onlyFuseProducerInputOperands);
+
+ if (failed(tiledOp))
return failure();
- scf::SCFTilingOptions tilingOptions;
- tilingOptions.setTileSizes(tileSizes);
+ if (!onlyFuseProducerInputOperands)
+ fuseConsumers(rewriter, tiledOp.value());
- scf::SCFTileAndFuseOptions tileAndFuseOptions;
- tileAndFuseOptions.setTilingOptions(tilingOptions);
-
- return tileRootAndFuseProducerConsumerUsingSCF(rewriter, rootOp,
- tileAndFuseOptions);
+ return success();
}
/// This pass starts with the first TilingInterface operation that has
/// lowering_config attribute, tiles the op and fuses its consumers and
-/// producers recursively. The `tilingLevel` must be specified. It picks the
-/// `tilingLevel`-th list as tiling sizes from lowering_config.
+/// producers recursively. If the `onlyFuseProducerInputOperands` is set, it
+/// only fuses producer input operands and disables consumer fusion. The
+/// `tilingLevel` must be specified. It picks the `tilingLevel`-th list as
+/// tiling sizes from lowering_config.
struct LLVMCPUTileRootAndFuseProducerConsumer
: impl::LLVMCPUTileRootAndFuseProducerConsumerPassBase<
LLVMCPUTileRootAndFuseProducerConsumer> {
using impl::LLVMCPUTileRootAndFuseProducerConsumerPassBase<
LLVMCPUTileRootAndFuseProducerConsumer>::
LLVMCPUTileRootAndFuseProducerConsumerPassBase;
- explicit LLVMCPUTileRootAndFuseProducerConsumer(int64_t tilingLevel) {
+ explicit LLVMCPUTileRootAndFuseProducerConsumer(
+ int64_t tilingLevel, bool onlyFuseProducerInputOperands) {
this->tilingLevel = tilingLevel;
+ this->onlyFuseProducerInputOperands = onlyFuseProducerInputOperands;
}
void getDependentDialects(DialectRegistry ®istry) const override {
registry.insert<arith::ArithDialect, affine::AffineDialect,
@@ -186,9 +258,9 @@
return signalPassFailure();
}
- if (failed(tileRootAndFuseProducerConsumer(
+ if (failed(tileRootAndFuse(
rewriter, dyn_cast<TilingInterface>(rootOp.value()),
- tilingLevel.getValue()))) {
+ tilingLevel.getValue(), onlyFuseProducerInputOperands.getValue()))) {
funcOp.emitError() << "tiling of level " << tilingLevel.getValue()
<< " failed\n";
return signalPassFailure();
@@ -212,6 +284,12 @@
std::unique_ptr<InterfacePass<mlir::FunctionOpInterface>>
createLLVMCPUTileRootAndFuseProducerConsumer(int64_t tilingLevel) {
- return std::make_unique<LLVMCPUTileRootAndFuseProducerConsumer>(tilingLevel);
+ return std::make_unique<LLVMCPUTileRootAndFuseProducerConsumer>(
+ tilingLevel, /*onlyFuseProducerInputOperands=*/false);
+}
+std::unique_ptr<InterfacePass<mlir::FunctionOpInterface>>
+createLLVMCPUTileRootAndFuseInputOperands(int64_t tilingLevel) {
+ return std::make_unique<LLVMCPUTileRootAndFuseProducerConsumer>(
+ tilingLevel, /*onlyFuseProducerInputOperands=*/true);
}
} // namespace mlir::iree_compiler
diff --git a/compiler/src/iree/compiler/Codegen/LLVMCPU/Passes.cpp b/compiler/src/iree/compiler/Codegen/LLVMCPU/Passes.cpp
index aeb9d7f..70cbe0d 100644
--- a/compiler/src/iree/compiler/Codegen/LLVMCPU/Passes.cpp
+++ b/compiler/src/iree/compiler/Codegen/LLVMCPU/Passes.cpp
@@ -451,8 +451,8 @@
funcPassManager.addPass(createFuseTensorPadWithConsumerPass());
funcPassManager.addPass(createConcretizePadResultShapePass());
- funcPassManager.addPass(
- createLLVMCPUTilePass(tilingConfig.getVectorReductionLevel()));
+ funcPassManager.addPass(createLLVMCPUTileRootAndFuseInputOperands(
+ tilingConfig.getVectorReductionLevel()));
funcPassManager.addPass(
createLLVMCPUTileAndFusePass(tilingConfig.getVectorInnerParallelLevel()));
funcPassManager.addPass(createDecomposeConvolutionToLowerDimOpsPass());
diff --git a/compiler/src/iree/compiler/Codegen/LLVMCPU/Passes.h b/compiler/src/iree/compiler/Codegen/LLVMCPU/Passes.h
index bac30de..a8cb91a 100644
--- a/compiler/src/iree/compiler/Codegen/LLVMCPU/Passes.h
+++ b/compiler/src/iree/compiler/Codegen/LLVMCPU/Passes.h
@@ -42,6 +42,9 @@
createLLVMCPUTileRootAndFuseProducerConsumer(int64_t tilingLevel);
std::unique_ptr<InterfacePass<mlir::FunctionOpInterface>>
+createLLVMCPUTileRootAndFuseInputOperands(int64_t tilingLevel);
+
+std::unique_ptr<InterfacePass<mlir::FunctionOpInterface>>
createLLVMCPUVerifyVectorSizeLegalityPass(
int64_t maxAllowedNumberOfNativeVectors);
diff --git a/compiler/src/iree/compiler/Codegen/LLVMCPU/Passes.td b/compiler/src/iree/compiler/Codegen/LLVMCPU/Passes.td
index c85666b..81dbdcf 100644
--- a/compiler/src/iree/compiler/Codegen/LLVMCPU/Passes.td
+++ b/compiler/src/iree/compiler/Codegen/LLVMCPU/Passes.td
@@ -140,13 +140,20 @@
];
}
-def LLVMCPUTileRootAndFuseProducerConsumerPass :
- InterfacePass<"iree-llvmcpu-tile-root-and-fuse-producer-consumer", "mlir::FunctionOpInterface"> {
- let summary = "Pass to tile root op and fuse with producer and consumer TilingInterface ops.";
- let options = [
- Option<"tilingLevel", "tiling-level", "int64_t", /*default=*/"-1",
- "Use default tiling level used to retrieve the configuration from lowering_config">
- ];
+def LLVMCPUTileRootAndFuseProducerConsumerPass
+ : InterfacePass<"iree-llvmcpu-tile-root-and-fuse-producer-consumer",
+ "mlir::FunctionOpInterface"> {
+ let summary = "Pass to tile root op and fuse with producer and consumer "
+ "TilingInterface ops.";
+ let options =
+ [Option<"tilingLevel", "tiling-level", "int64_t", /*default=*/"-1",
+ "Use default tiling level used to retrieve the configuration "
+ "from lowering_config">,
+ Option<"onlyFuseProducerInputOperands",
+ "only-fuse-producer-input-operands", "bool",
+ /*default=*/"false",
+ "Specifies if we only want to fuse producer's input operands. "
+ "This is helpful to tile&fuse in case of reduction dimensions.">];
}
def LLVMCPUVerifyVectorSizeLegalityPass :
diff --git a/compiler/src/iree/compiler/Codegen/LLVMCPU/test/tile-root-fuse-consumer-producer.mlir b/compiler/src/iree/compiler/Codegen/LLVMCPU/test/tile-root-fuse-consumer-producer.mlir
index 8710da6..4d8805d 100644
--- a/compiler/src/iree/compiler/Codegen/LLVMCPU/test/tile-root-fuse-consumer-producer.mlir
+++ b/compiler/src/iree/compiler/Codegen/LLVMCPU/test/tile-root-fuse-consumer-producer.mlir
@@ -1,4 +1,6 @@
// RUN: iree-opt --pass-pipeline="builtin.module(func.func(iree-llvmcpu-tile-root-and-fuse-producer-consumer{tiling-level=0}), canonicalize)" --split-input-file %s | FileCheck %s
+// RUN: iree-opt --pass-pipeline="builtin.module(func.func(iree-llvmcpu-tile-root-and-fuse-producer-consumer{tiling-level=2 only-fuse-producer-input-operands=true}), canonicalize)" --split-input-file %s | FileCheck %s --check-prefix=CHECK-REDUCTION
+
#config1 = #iree_codegen.lowering_config<tile_sizes = [[1, 0, 0, 0, 0, 0], [1, 0, 0, 0, 0, 0], [0, 0, 0, 0, 0, 0], [1, 0, 0, 16, 16, 0], [0, 0, 1, 0, 0, 1], [0, 0, 0, 0, 0, 0]]>
#map = affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)>
@@ -30,7 +32,8 @@
// CHECK: }
// -----
-#config = #iree_codegen.lowering_config<tile_sizes = [[1, 0, 0, 0, 0, 0], [1, 0, 0, 0, 0, 0], [0, 0, 0, 0, 0, 0], [1, 0, 0, 16, 16, 0], [0, 0, 1, 0, 0, 1], [0, 0, 0, 0, 0, 0]]>
+
+#config2 = #iree_codegen.lowering_config<tile_sizes = [[1, 0, 0, 0, 0, 0], [1, 0, 0, 0, 0, 0], [0, 0, 0, 0, 0, 0], [1, 0, 0, 16, 16, 0], [0, 0, 1, 0, 0, 1], [0, 0, 0, 0, 0, 0]]>
#map2 = affine_map<(d0, d1, d2, d3, d4) -> (d0, d1, d2, d3, d4)>
#map3 = affine_map<(d0, d1, d2, d3, d4) -> (d0, d1, d3)>
func.func @quantized_matmul(%arg0: tensor<2x4x128x16x1xi8>, %arg1: tensor<2x4x16xf32>, %arg2: tensor<2x4x16xf32>, %arg3: tensor<2x688x128x16x1xi8>, %arg4: tensor<2x688x16xf32>, %arg5: tensor<2x688x16xf32>) -> tensor<2x11008x64xf32> {
@@ -61,7 +64,7 @@
} -> tensor<2x688x128x16x1xf32>
%4 = tensor.empty() : tensor<2x4x688x16x16xf32>
%5 = linalg.fill ins(%cst : f32) outs(%4 : tensor<2x4x688x16x16xf32>) -> tensor<2x4x688x16x16xf32>
- %6 = linalg.batch_mmt4d {lowering_config = #config} ins(%1, %3 : tensor<2x4x128x16x1xf32>, tensor<2x688x128x16x1xf32>) outs(%5 : tensor<2x4x688x16x16xf32>) -> tensor<2x4x688x16x16xf32>
+ %6 = linalg.batch_mmt4d {lowering_config = #config2} ins(%1, %3 : tensor<2x4x128x16x1xf32>, tensor<2x688x128x16x1xf32>) outs(%5 : tensor<2x4x688x16x16xf32>) -> tensor<2x4x688x16x16xf32>
%7 = tensor.empty() : tensor<2x11008x64xf32>
%unpack = tensor.unpack %6 outer_dims_perm = [0, 2, 1] inner_dims_pos = [2, 1] inner_tiles = [16, 16] into %7 : tensor<2x4x688x16x16xf32> -> tensor<2x11008x64xf32>
return %unpack : tensor<2x11008x64xf32>
@@ -75,3 +78,42 @@
// CHECK: linalg.batch_mmt4d
// CHECK: tensor.unpack
// CHECK: }
+
+
+// -----
+
+#config3 = #iree_codegen.lowering_config<tile_sizes = [[0, 32, 0, 0, 0, 0], [1, 16, 1, 1, 0, 0], [0, 0, 0, 0, 1, 5], [0, 0, 0, 0, 0, 0]]>
+#map = affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)>
+func.func @dequant_avgpool(%arg0: tensor<1x320x65x65xi8>) -> tensor<1x320x1x1xf32> {
+ %cst = arith.constant 1.250000e-01 : f32
+ %cst_0 = arith.constant 0.000000e+00 : f32
+ %c5408000 = arith.constant 5408000 : index
+ %c0 = arith.constant 0 : index
+ %0 = tensor.empty() : tensor<1x320x1x1xf32>
+ %1 = tensor.empty() : tensor<65x65xf32>
+ %2 = linalg.fill ins(%cst_0 : f32) outs(%1 : tensor<65x65xf32>) -> tensor<65x65xf32>
+ %3 = tensor.empty() : tensor<1x320x65x65xf32>
+ %4 = tensor.empty() : tensor<1x320x1x1xf32>
+ %5 = linalg.generic {indexing_maps = [#map, #map], iterator_types = ["parallel", "parallel", "parallel", "parallel"]} ins(%arg0 : tensor<1x320x65x65xi8>) outs(%3 : tensor<1x320x65x65xf32>) {
+ ^bb0(%in: i8, %out: f32):
+ %7 = arith.extsi %in : i8 to i32
+ %8 = arith.sitofp %7 : i32 to f32
+ %9 = arith.mulf %8, %cst : f32
+ linalg.yield %9 : f32
+ } -> tensor<1x320x65x65xf32>
+ %6 = linalg.pooling_nchw_sum {lowering_config = #config3} ins(%5, %2 : tensor<1x320x65x65xf32>, tensor<65x65xf32>) outs(%4 : tensor<1x320x1x1xf32>) -> tensor<1x320x1x1xf32>
+ return %6 : tensor<1x320x1x1xf32>
+}
+
+// CHECK-REDUCTION-LABEL: func.func @dequant_avgpool(
+// CHECK-REDUCTION-SAME: {
+// CHECK-REDUCTION: scf.for
+// CHECK-REDUCTION-SAME: {
+// CHECK-REDUCTION: scf.for
+// CHECK-REDUCTION-SAME: {
+// CHECK-REDUCTION: linalg.generic
+// CHECK-REDUCTION: %[[POOL:.+]] = linalg.pooling_nchw_sum
+// CHECK-REDUCTION: scf.yield %[[POOL]]
+// CHECK-REDUCTION: }
+// CHECK-REDUCTION: }
+// CHECK-REDUCTION: }