Fix the infinite application of TileAndDistribute on unpack ops. (#12179)
It removes a cleanup pattern that is no longer needed. The fix uses arith::MulIOp instead of affine ops to compute the expanded output sizes in non-perfect tiling cases, because the affine map could become too complicated, which triggered issues in affine-op simplification.
Fixes https://github.com/iree-org/iree/issues/11607
diff --git a/compiler/src/iree/compiler/Codegen/Common/BufferizationAnalysis.cpp b/compiler/src/iree/compiler/Codegen/Common/BufferizationAnalysis.cpp
index b0dacc6..d793646 100644
--- a/compiler/src/iree/compiler/Codegen/Common/BufferizationAnalysis.cpp
+++ b/compiler/src/iree/compiler/Codegen/Common/BufferizationAnalysis.cpp
@@ -560,7 +560,8 @@
.Case<scf::ForOp>(
[&](scf::ForOp forOp) { return analyseScfForOp(forOp, plan); })
.Case<scf::YieldOp, tensor::EmptyOp, tensor::DimOp, tensor::ExtractOp,
- tensor::PadOp, bufferization::ToMemrefOp>(
+ tensor::PadOp, bufferization::ToMemrefOp,
+ bufferization::AllocTensorOp>(
[&](Operation *op) { return success(); })
.Default([&](Operation *op) -> LogicalResult {
if (llvm::any_of(op->getOperands(),
diff --git a/compiler/src/iree/compiler/Codegen/Common/ConvertToDestinationPassingStylePass.cpp b/compiler/src/iree/compiler/Codegen/Common/ConvertToDestinationPassingStylePass.cpp
index 55db046..04e36c8 100644
--- a/compiler/src/iree/compiler/Codegen/Common/ConvertToDestinationPassingStylePass.cpp
+++ b/compiler/src/iree/compiler/Codegen/Common/ConvertToDestinationPassingStylePass.cpp
@@ -30,6 +30,7 @@
#include "llvm/ADT/TypeSwitch.h"
#include "mlir/Analysis/SliceAnalysis.h"
#include "mlir/Dialect/Arith/IR/Arith.h"
+#include "mlir/Dialect/Bufferization/IR/Bufferization.h"
#include "mlir/Dialect/Linalg/IR/Linalg.h"
#include "mlir/Dialect/Linalg/Utils/Utils.h"
#include "mlir/Dialect/MemRef/Transforms/Passes.h"
@@ -62,7 +63,8 @@
}
void getDependentDialects(DialectRegistry &registry) const override {
- registry.insert<linalg::LinalgDialect>();
+ registry
+ .insert<linalg::LinalgDialect, bufferization::BufferizationDialect>();
}
void runOnOperation() override;
};
@@ -463,6 +465,29 @@
return success();
}
+/// Replaces a tensor.empty op with bufferization.alloc_tensor op which is
+/// created by tiling tensor.unpack op. It is intended because tiling unpack ops
+/// with non-perfect sizes needs extra elements. See the tiling implementation
+/// of tensor.unpack op for more details.
+static LogicalResult replaceUnpackEmptyWithAllocTensor(OpBuilder &b,
+ func::FuncOp funcOp) {
+ funcOp.walk([&](IREE::LinalgExt::UnPackOp unpackOp) {
+ if (!unpackOp->hasOneUse() ||
+ !isa<tensor::ExtractSliceOp>(*(unpackOp->user_begin()))) {
+ return;
+ }
+ auto emptyOp = unpackOp.getOutput().getDefiningOp<tensor::EmptyOp>();
+ if (!emptyOp) return;
+
+ OpBuilder::InsertionGuard g(b);
+ b.setInsertionPointAfter(emptyOp);
+ auto allocTensor = b.create<bufferization::AllocTensorOp>(
+ emptyOp.getLoc(), emptyOp.getType(), emptyOp.getDynamicSizes());
+ emptyOp.replaceAllUsesWith(allocTensor.getResult());
+ });
+ return success();
+}
+
namespace {
struct RemoveCstOutsDependency
: public OpInterfaceRewritePattern<linalg::LinalgOp> {
@@ -525,6 +550,10 @@
}
}
+ if (failed(replaceUnpackEmptyWithAllocTensor(b, funcOp))) {
+ return signalPassFailure();
+ }
+
if (failed(convertToDestinationPassingStyle(b, funcOp))) {
return signalPassFailure();
}
diff --git a/compiler/src/iree/compiler/Codegen/Common/TileDispatchUsingInterface.cpp b/compiler/src/iree/compiler/Codegen/Common/TileDispatchUsingInterface.cpp
index a840e5d..0aec05e 100644
--- a/compiler/src/iree/compiler/Codegen/Common/TileDispatchUsingInterface.cpp
+++ b/compiler/src/iree/compiler/Codegen/Common/TileDispatchUsingInterface.cpp
@@ -557,35 +557,6 @@
namespace {
//===----------------------------------------------------------------------===//
-// SwapExtractSliceWithTiledProducer
-//===----------------------------------------------------------------------===//
-
-/// Pattern to swap a `tilinginterface op` -> `tensor.extract_slice` with
-/// `tensor.extract_slice` of operands of the op -> tiled `tilinginterface
-/// op`.
-struct SwapExtractSliceWithTiledProducer
- : public OpRewritePattern<tensor::ExtractSliceOp> {
- using OpRewritePattern<tensor::ExtractSliceOp>::OpRewritePattern;
-
- LogicalResult matchAndRewrite(tensor::ExtractSliceOp sliceOp,
- PatternRewriter &rewriter) const override {
- OpResult producer = sliceOp.getSource().dyn_cast<OpResult>();
- if (!producer) {
- return rewriter.notifyMatchFailure(sliceOp, "source uses bb arg");
- }
- FailureOr<Value> tiledProducer =
- tensor::replaceExtractSliceWithTiledProducer(rewriter, sliceOp,
- producer);
- if (failed(tiledProducer)) {
- return failure();
- }
- // Replace all uses of the producer within the
- rewriter.replaceOp(sliceOp, tiledProducer.value());
- return success();
- }
-};
-
-//===----------------------------------------------------------------------===//
// SwapExtractSliceWithDispatchTensorLoad
//===----------------------------------------------------------------------===//
@@ -653,8 +624,7 @@
RewritePatternSet &patterns, linalg::LinalgTilingOptions options) {
MLIRContext *context = patterns.getContext();
patterns.insert<SwapExtractSliceWithDispatchTensorLoad,
- SwapExtractSliceWithTensorEmpty,
- SwapExtractSliceWithTiledProducer>(context);
+ SwapExtractSliceWithTensorEmpty>(context);
}
} // namespace iree_compiler
diff --git a/compiler/src/iree/compiler/Codegen/Common/test/convert_to_destination_passing_style.mlir b/compiler/src/iree/compiler/Codegen/Common/test/convert_to_destination_passing_style.mlir
index af46133..c78f23b 100644
--- a/compiler/src/iree/compiler/Codegen/Common/test/convert_to_destination_passing_style.mlir
+++ b/compiler/src/iree/compiler/Codegen/Common/test/convert_to_destination_passing_style.mlir
@@ -791,6 +791,52 @@
// -----
+func.func @non_perfect_tiling_unpack() {
+ %c1 = arith.constant 1 : index
+ %c512 = arith.constant 512 : index
+ %c0 = arith.constant 0 : index
+ %c16 = arith.constant 16 : index
+ %0:2 = vmvx.query_tile_sizes sizes(%c16, %c16) flags(1245184) -> index, index
+ %1 = affine.apply affine_map<()[s0] -> (16 ceildiv s0)>()[%0#0]
+ %2 = affine.apply affine_map<()[s0] -> (16 ceildiv s0)>()[%0#1]
+ %3 = hal.interface.binding.subspan set(0) binding(0) type(storage_buffer) alignment(64) offset(%c512) flags(ReadOnly) : !flow.dispatch.tensor<readonly:tensor<?x?x?x?xi32>>{%1, %2, %0#0, %0#1}
+ %4 = hal.interface.binding.subspan set(0) binding(1) type(storage_buffer) alignment(64) offset(%c0) : !flow.dispatch.tensor<writeonly:tensor<1x1xi32>>
+ %5:2 = vmvx.query_tile_sizes sizes(%c16, %c16) flags(1245184) -> index, index
+ %6 = affine.apply affine_map<()[s0] -> (16 ceildiv s0)>()[%5#0]
+ %7 = affine.apply affine_map<()[s0] -> (16 ceildiv s0)>()[%5#1]
+ %8:2 = vmvx.query_tile_sizes sizes(%c16, %c16) flags(1245184) -> index, index
+ %workgroup_id_x = hal.interface.workgroup.id[0] : index
+ %workgroup_count_x = hal.interface.workgroup.count[0] : index
+ %workgroup_id_y = hal.interface.workgroup.id[1] : index
+ %workgroup_count_y = hal.interface.workgroup.count[1] : index
+ %9 = affine.apply affine_map<()[s0] -> (s0 * 16)>()[%workgroup_id_y]
+ %10 = affine.apply affine_map<()[s0] -> (s0 * 16)>()[%workgroup_count_y]
+ scf.for %arg0 = %9 to %c1 step %10 {
+ %11 = affine.apply affine_map<()[s0] -> (s0 * 16)>()[%workgroup_id_x]
+ %12 = affine.apply affine_map<()[s0] -> (s0 * 16)>()[%workgroup_count_x]
+ scf.for %arg1 = %11 to %c1 step %12 {
+ %13 = affine.apply affine_map<(d0)[s0] -> (d0 mod s0)>(%arg0)[%8#0]
+ %14 = affine.apply affine_map<(d0)[s0] -> (d0 mod s0)>(%arg1)[%8#1]
+ %15 = affine.apply affine_map<(d0)[s0] -> (d0 floordiv s0)>(%arg0)[%8#0]
+ %16 = affine.apply affine_map<(d0)[s0] -> (d0 floordiv s0)>(%arg1)[%8#1]
+ %17 = flow.dispatch.tensor.load %3, offsets = [%15, %16, 0, 0], sizes = [%c1, %c1, %8#0, %8#1], strides = [1, 1, 1, 1] : !flow.dispatch.tensor<readonly:tensor<?x?x?x?xi32>>{%6, %7, %5#0, %5#1} -> tensor<?x?x?x?xi32>
+ %18 = tensor.empty(%8#0, %8#1) : tensor<?x?xi32>
+ %19 = iree_linalg_ext.unpack %17 inner_dims_pos = [0, 1] inner_tiles = [%8#0, %8#1] into %18 : (tensor<?x?x?x?xi32> tensor<?x?xi32>) -> tensor<?x?xi32>
+ %extracted_slice = tensor.extract_slice %19[%13, %14] [1, 1] [1, 1] : tensor<?x?xi32> to tensor<1x1xi32>
+ %cast = tensor.cast %extracted_slice : tensor<1x1xi32> to tensor<?x?xi32>
+ flow.dispatch.tensor.store %cast, %4, offsets = [%arg0, %arg1], sizes = [%c1, %c1], strides = [1, 1] : tensor<?x?xi32> -> !flow.dispatch.tensor<writeonly:tensor<1x1xi32>>
+ }
+ }
+ return
+}
+// CHECK-LABEL: func.func @non_perfect_tiling_unpack
+// CHECK: %[[ALLOC:.+]] = bufferization.alloc_tensor
+// CHECK: %[[UNPACK:.+]] = iree_linalg_ext.unpack
+// CHECK-SAME: into %[[ALLOC]]
+// CHECK: %[[SLICE:.+]] = tensor.extract_slice %[[UNPACK]]
+
+// -----
+
func.func @multi_result_dispatches() {
%c0 = arith.constant 0 : index
%0 = hal.interface.binding.subspan set(0) binding(0) type(storage_buffer) alignment(64) offset(%c0)
diff --git a/llvm-external-projects/iree-dialects/lib/Dialect/LinalgExt/IR/LinalgExtOps.cpp b/llvm-external-projects/iree-dialects/lib/Dialect/LinalgExt/IR/LinalgExtOps.cpp
index 99dd29a..b6d691e 100644
--- a/llvm-external-projects/iree-dialects/lib/Dialect/LinalgExt/IR/LinalgExtOps.cpp
+++ b/llvm-external-projects/iree-dialects/lib/Dialect/LinalgExt/IR/LinalgExtOps.cpp
@@ -2224,19 +2224,6 @@
UnPackOp::getTiledImplementation(OpBuilder &builder,
ArrayRef<OpFoldResult> offsets,
ArrayRef<OpFoldResult> sizes) {
- Operation *unpackOp = *this;
- // Dynamic inner tile sizes currently trigger infinite application of
- // tile-and-distribute on the unpack op, each tile calling
- // getTiledImplementation -> getSlice creating more and more extract_slice.
- // As a temporary work-around, we annotate unpack ops with a custom
- // already_tiled attribute to keep track of what's already been tiled.
- // For some reason this causes errors in non-dynamic-shape cases, but it's
- // not needed there anyway, so we simply check for dynamic inner tiles before
- // applying this tweak.
- if (ShapedType::isDynamicShape(getStaticInnerTiles())) {
- if (unpackOp->hasAttr("already_tiled"))
- return {unpackOp};
- }
// TODO(hanchung): Extend it to handle memref version.
// Tiling on buffers needs extra buffer because tiled unpack op could produce
// more data for incomplete tiles. Tiling on tensors satisfies IREE's needs.
@@ -2339,9 +2326,13 @@
AffineExpr i, tile;
bindDims(builder.getContext(), i);
bindSymbols(builder.getContext(), tile);
- OpFoldResult size = makeComposedFoldedAffineApply(
- builder, loc, i * tile,
- ArrayRef<OpFoldResult>{inputSizes.back(), dimAndTileMapping[dim]});
+      // Do not create affine ops for the output size, because the affine map
+      // could be too complicated, which would trigger issues in affine-op
+      // simplification.
+ OpFoldResult size = builder.createOrFold<arith::MulIOp>(
+ loc, getValueOrCreateConstantIndexOp(builder, loc, inputSizes.back()),
+ getValueOrCreateConstantIndexOp(builder, loc,
+ dimAndTileMapping[dim]));
outputExpandedSizes.push_back(size);
}
}
@@ -2388,8 +2379,6 @@
Operation *tiledUnpackOp =
mlir::clone(builder, getOperation(), tiledResultTypes, tiledOperands);
- tiledUnpackOp->setAttr(StringAttr::get(getContext(), "already_tiled"),
- BoolAttr::get(getContext(), true));
if (isPerfectTilingCase)
return {tiledUnpackOp};
diff --git a/llvm-external-projects/iree-dialects/test/Dialect/iree_linalg_ext/tiling.mlir b/llvm-external-projects/iree-dialects/test/Dialect/iree_linalg_ext/tiling.mlir
index efba85a..fc8ce7e 100644
--- a/llvm-external-projects/iree-dialects/test/Dialect/iree_linalg_ext/tiling.mlir
+++ b/llvm-external-projects/iree-dialects/test/Dialect/iree_linalg_ext/tiling.mlir
@@ -1127,11 +1127,9 @@
// CHECK-DAG: #[[MAP0:.+]] = affine_map<(d0) -> (d0 floordiv 32)>
// CHECK-DAG: #[[MAP1:.+]] = affine_map<(d0) -> (d0 mod 32)>
// CHECK-DAG: #[[MAP2:.+]] = affine_map<(d0) -> ((d0 + 1) floordiv 32 - d0 floordiv 32 + 1)>
-// CHECK-DAG: #[[MAP3:.+]] = affine_map<(d0) -> (((d0 + 1) floordiv 32) * 32 - (d0 floordiv 32) * 32 + 32)>
// CHECK-DAG: #[[MAP4:.+]] = affine_map<(d0) -> (d0 floordiv 16)>
// CHECK-DAG: #[[MAP5:.+]] = affine_map<(d0) -> (d0 mod 16)>
// CHECK-DAG: #[[MAP6:.+]] = affine_map<(d0) -> ((d0 + 3) floordiv 16 - d0 floordiv 16 + 1)>
-// CHECK-DAG: #[[MAP7:.+]] = affine_map<(d0) -> (((d0 + 3) floordiv 16) * 16 - (d0 floordiv 16) * 16 + 16)>
// CHECK-LABEL: func.func @NCnc_to_NC
// CHECK-SAME: %[[IN:[A-Za-z0-9]+]]:
// CHECK-SAME: %[[OUT:[A-Za-z0-9]+]]:
@@ -1171,11 +1169,9 @@
// CHECK-DAG: #[[MAP0:.+]] = affine_map<(d0) -> (d0 floordiv 32)>
// CHECK-DAG: #[[MAP1:.+]] = affine_map<(d0) -> (d0 mod 32)>
// CHECK-DAG: #[[MAP2:.+]] = affine_map<(d0) -> ((d0 + 1) floordiv 32 - d0 floordiv 32 + 1)>
-// CHECK-DAG: #[[MAP3:.+]] = affine_map<(d0) -> (((d0 + 1) floordiv 32) * 32 - (d0 floordiv 32) * 32 + 32)>
// CHECK-DAG: #[[MAP4:.+]] = affine_map<(d0) -> (d0 floordiv 8)>
// CHECK-DAG: #[[MAP5:.+]] = affine_map<(d0) -> (d0 mod 8)>
// CHECK-DAG: #[[MAP6:.+]] = affine_map<(d0) -> ((d0 + 3) floordiv 8 - d0 floordiv 8 + 1)>
-// CHECK-DAG: #[[MAP7:.+]] = affine_map<(d0) -> (((d0 + 3) floordiv 8) * 8 - (d0 floordiv 8) * 8 + 8)>
// CHECK-LABEL: func.func @CKkc_to_KC
// CHECK-SAME: %[[IN:[A-Za-z0-9]+]]:
// CHECK-SAME: %[[OUT:[A-Za-z0-9]+]]:
@@ -1675,5 +1671,3 @@
// CHECK: }
// CHECK: return
// CHECK: }
-
-// -----