Fix the infinite application of TileAndDistribute on unpack ops. (#12179)

It removes a cleanup pattern that is no longer needed. The fix is using arith::MulIOp instead of affine ops to compute expanded output sizes in non-perfect tiling cases. Because the affine map could be too complicated, which triggers issues in affine ops simplification.

Fixes https://github.com/iree-org/iree/issues/11607
diff --git a/compiler/src/iree/compiler/Codegen/Common/BufferizationAnalysis.cpp b/compiler/src/iree/compiler/Codegen/Common/BufferizationAnalysis.cpp
index b0dacc6..d793646 100644
--- a/compiler/src/iree/compiler/Codegen/Common/BufferizationAnalysis.cpp
+++ b/compiler/src/iree/compiler/Codegen/Common/BufferizationAnalysis.cpp
@@ -560,7 +560,8 @@
         .Case<scf::ForOp>(
             [&](scf::ForOp forOp) { return analyseScfForOp(forOp, plan); })
         .Case<scf::YieldOp, tensor::EmptyOp, tensor::DimOp, tensor::ExtractOp,
-              tensor::PadOp, bufferization::ToMemrefOp>(
+              tensor::PadOp, bufferization::ToMemrefOp,
+              bufferization::AllocTensorOp>(
             [&](Operation *op) { return success(); })
         .Default([&](Operation *op) -> LogicalResult {
           if (llvm::any_of(op->getOperands(),
diff --git a/compiler/src/iree/compiler/Codegen/Common/ConvertToDestinationPassingStylePass.cpp b/compiler/src/iree/compiler/Codegen/Common/ConvertToDestinationPassingStylePass.cpp
index 55db046..04e36c8 100644
--- a/compiler/src/iree/compiler/Codegen/Common/ConvertToDestinationPassingStylePass.cpp
+++ b/compiler/src/iree/compiler/Codegen/Common/ConvertToDestinationPassingStylePass.cpp
@@ -30,6 +30,7 @@
 #include "llvm/ADT/TypeSwitch.h"
 #include "mlir/Analysis/SliceAnalysis.h"
 #include "mlir/Dialect/Arith/IR/Arith.h"
+#include "mlir/Dialect/Bufferization/IR/Bufferization.h"
 #include "mlir/Dialect/Linalg/IR/Linalg.h"
 #include "mlir/Dialect/Linalg/Utils/Utils.h"
 #include "mlir/Dialect/MemRef/Transforms/Passes.h"
@@ -62,7 +63,8 @@
   }
 
   void getDependentDialects(DialectRegistry &registry) const override {
-    registry.insert<linalg::LinalgDialect>();
+    registry
+        .insert<linalg::LinalgDialect, bufferization::BufferizationDialect>();
   }
   void runOnOperation() override;
 };
@@ -463,6 +465,29 @@
   return success();
 }
 
+/// Replaces a tensor.empty op with bufferization.alloc_tensor op which is
+/// created by tiling tensor.unpack op. It is intended because tiling unpack ops
+/// with non-perfect sizes needs extra elements. See the tiling implementation
+/// of tensor.unpack op for more details.
+static LogicalResult replaceUnpackEmptyWithAllocTensor(OpBuilder &b,
+                                                       func::FuncOp funcOp) {
+  funcOp.walk([&](IREE::LinalgExt::UnPackOp unpackOp) {
+    if (!unpackOp->hasOneUse() ||
+        !isa<tensor::ExtractSliceOp>(*(unpackOp->user_begin()))) {
+      return;
+    }
+    auto emptyOp = unpackOp.getOutput().getDefiningOp<tensor::EmptyOp>();
+    if (!emptyOp) return;
+
+    OpBuilder::InsertionGuard g(b);
+    b.setInsertionPointAfter(emptyOp);
+    auto allocTensor = b.create<bufferization::AllocTensorOp>(
+        emptyOp.getLoc(), emptyOp.getType(), emptyOp.getDynamicSizes());
+    emptyOp.replaceAllUsesWith(allocTensor.getResult());
+  });
+  return success();
+}
+
 namespace {
 struct RemoveCstOutsDependency
     : public OpInterfaceRewritePattern<linalg::LinalgOp> {
@@ -525,6 +550,10 @@
     }
   }
 
+  if (failed(replaceUnpackEmptyWithAllocTensor(b, funcOp))) {
+    return signalPassFailure();
+  }
+
   if (failed(convertToDestinationPassingStyle(b, funcOp))) {
     return signalPassFailure();
   }
diff --git a/compiler/src/iree/compiler/Codegen/Common/TileDispatchUsingInterface.cpp b/compiler/src/iree/compiler/Codegen/Common/TileDispatchUsingInterface.cpp
index a840e5d..0aec05e 100644
--- a/compiler/src/iree/compiler/Codegen/Common/TileDispatchUsingInterface.cpp
+++ b/compiler/src/iree/compiler/Codegen/Common/TileDispatchUsingInterface.cpp
@@ -557,35 +557,6 @@
 namespace {
 
 //===----------------------------------------------------------------------===//
-// SwapExtractSliceWithTiledProducer
-//===----------------------------------------------------------------------===//
-
-/// Pattern to swap a `tilinginterface op` -> `tensor.extract_slice` with
-/// `tensor.extract_slice` of operands of the op -> tiled `tilinginterface
-/// op`.
-struct SwapExtractSliceWithTiledProducer
-    : public OpRewritePattern<tensor::ExtractSliceOp> {
-  using OpRewritePattern<tensor::ExtractSliceOp>::OpRewritePattern;
-
-  LogicalResult matchAndRewrite(tensor::ExtractSliceOp sliceOp,
-                                PatternRewriter &rewriter) const override {
-    OpResult producer = sliceOp.getSource().dyn_cast<OpResult>();
-    if (!producer) {
-      return rewriter.notifyMatchFailure(sliceOp, "source uses bb arg");
-    }
-    FailureOr<Value> tiledProducer =
-        tensor::replaceExtractSliceWithTiledProducer(rewriter, sliceOp,
-                                                     producer);
-    if (failed(tiledProducer)) {
-      return failure();
-    }
-    // Replace all uses of the producer within the
-    rewriter.replaceOp(sliceOp, tiledProducer.value());
-    return success();
-  }
-};
-
-//===----------------------------------------------------------------------===//
 // SwapExtractSliceWithDispatchTensorLoad
 //===----------------------------------------------------------------------===//
 
@@ -653,8 +624,7 @@
     RewritePatternSet &patterns, linalg::LinalgTilingOptions options) {
   MLIRContext *context = patterns.getContext();
   patterns.insert<SwapExtractSliceWithDispatchTensorLoad,
-                  SwapExtractSliceWithTensorEmpty,
-                  SwapExtractSliceWithTiledProducer>(context);
+                  SwapExtractSliceWithTensorEmpty>(context);
 }
 
 }  // namespace iree_compiler
diff --git a/compiler/src/iree/compiler/Codegen/Common/test/convert_to_destination_passing_style.mlir b/compiler/src/iree/compiler/Codegen/Common/test/convert_to_destination_passing_style.mlir
index af46133..c78f23b 100644
--- a/compiler/src/iree/compiler/Codegen/Common/test/convert_to_destination_passing_style.mlir
+++ b/compiler/src/iree/compiler/Codegen/Common/test/convert_to_destination_passing_style.mlir
@@ -791,6 +791,52 @@
 
 // -----
 
+func.func @non_perfect_tiling_unpack() {
+  %c1 = arith.constant 1 : index
+  %c512 = arith.constant 512 : index
+  %c0 = arith.constant 0 : index
+  %c16 = arith.constant 16 : index
+  %0:2 = vmvx.query_tile_sizes sizes(%c16, %c16) flags(1245184) -> index, index
+  %1 = affine.apply affine_map<()[s0] -> (16 ceildiv s0)>()[%0#0]
+  %2 = affine.apply affine_map<()[s0] -> (16 ceildiv s0)>()[%0#1]
+  %3 = hal.interface.binding.subspan set(0) binding(0) type(storage_buffer) alignment(64) offset(%c512) flags(ReadOnly) : !flow.dispatch.tensor<readonly:tensor<?x?x?x?xi32>>{%1, %2, %0#0, %0#1}
+  %4 = hal.interface.binding.subspan set(0) binding(1) type(storage_buffer) alignment(64) offset(%c0) : !flow.dispatch.tensor<writeonly:tensor<1x1xi32>>
+  %5:2 = vmvx.query_tile_sizes sizes(%c16, %c16) flags(1245184) -> index, index
+  %6 = affine.apply affine_map<()[s0] -> (16 ceildiv s0)>()[%5#0]
+  %7 = affine.apply affine_map<()[s0] -> (16 ceildiv s0)>()[%5#1]
+  %8:2 = vmvx.query_tile_sizes sizes(%c16, %c16) flags(1245184) -> index, index
+  %workgroup_id_x = hal.interface.workgroup.id[0] : index
+  %workgroup_count_x = hal.interface.workgroup.count[0] : index
+  %workgroup_id_y = hal.interface.workgroup.id[1] : index
+  %workgroup_count_y = hal.interface.workgroup.count[1] : index
+  %9 = affine.apply affine_map<()[s0] -> (s0 * 16)>()[%workgroup_id_y]
+  %10 = affine.apply affine_map<()[s0] -> (s0 * 16)>()[%workgroup_count_y]
+  scf.for %arg0 = %9 to %c1 step %10 {
+    %11 = affine.apply affine_map<()[s0] -> (s0 * 16)>()[%workgroup_id_x]
+    %12 = affine.apply affine_map<()[s0] -> (s0 * 16)>()[%workgroup_count_x]
+    scf.for %arg1 = %11 to %c1 step %12 {
+      %13 = affine.apply affine_map<(d0)[s0] -> (d0 mod s0)>(%arg0)[%8#0]
+      %14 = affine.apply affine_map<(d0)[s0] -> (d0 mod s0)>(%arg1)[%8#1]
+      %15 = affine.apply affine_map<(d0)[s0] -> (d0 floordiv s0)>(%arg0)[%8#0]
+      %16 = affine.apply affine_map<(d0)[s0] -> (d0 floordiv s0)>(%arg1)[%8#1]
+      %17 = flow.dispatch.tensor.load %3, offsets = [%15, %16, 0, 0], sizes = [%c1, %c1, %8#0, %8#1], strides = [1, 1, 1, 1] : !flow.dispatch.tensor<readonly:tensor<?x?x?x?xi32>>{%6, %7, %5#0, %5#1} -> tensor<?x?x?x?xi32>
+      %18 = tensor.empty(%8#0, %8#1) : tensor<?x?xi32>
+      %19 = iree_linalg_ext.unpack %17 inner_dims_pos = [0, 1] inner_tiles = [%8#0, %8#1] into %18 : (tensor<?x?x?x?xi32> tensor<?x?xi32>) -> tensor<?x?xi32>
+      %extracted_slice = tensor.extract_slice %19[%13, %14] [1, 1] [1, 1] : tensor<?x?xi32> to tensor<1x1xi32>
+      %cast = tensor.cast %extracted_slice : tensor<1x1xi32> to tensor<?x?xi32>
+      flow.dispatch.tensor.store %cast, %4, offsets = [%arg0, %arg1], sizes = [%c1, %c1], strides = [1, 1] : tensor<?x?xi32> -> !flow.dispatch.tensor<writeonly:tensor<1x1xi32>>
+    }
+  }
+  return
+}
+// CHECK-LABEL: func.func @non_perfect_tiling_unpack
+// CHECK:         %[[ALLOC:.+]] = bufferization.alloc_tensor
+// CHECK:         %[[UNPACK:.+]] = iree_linalg_ext.unpack
+// CHECK-SAME:      into %[[ALLOC]]
+// CHECK:         %[[SLICE:.+]] = tensor.extract_slice %[[UNPACK]]
+
+// -----
+
 func.func @multi_result_dispatches() {
   %c0 = arith.constant 0 : index
   %0 = hal.interface.binding.subspan set(0) binding(0) type(storage_buffer) alignment(64) offset(%c0)
diff --git a/llvm-external-projects/iree-dialects/lib/Dialect/LinalgExt/IR/LinalgExtOps.cpp b/llvm-external-projects/iree-dialects/lib/Dialect/LinalgExt/IR/LinalgExtOps.cpp
index 99dd29a..b6d691e 100644
--- a/llvm-external-projects/iree-dialects/lib/Dialect/LinalgExt/IR/LinalgExtOps.cpp
+++ b/llvm-external-projects/iree-dialects/lib/Dialect/LinalgExt/IR/LinalgExtOps.cpp
@@ -2224,19 +2224,6 @@
 UnPackOp::getTiledImplementation(OpBuilder &builder,
                                  ArrayRef<OpFoldResult> offsets,
                                  ArrayRef<OpFoldResult> sizes) {
-  Operation *unpackOp = *this;
-  // Dynamic inner tile sizes currently trigger infinite application of
-  // tile-and-distribute on the unpack op, each tile calling
-  // getTiledImplementation -> getSlice creating more and more extract_slice.
-  // As a temporary work-around, we annotate unpack ops with a custom
-  // already_tiled attribute to keep track of what's already been tiled.
-  // For some reason this causes errors in non-dynamic-shape cases, but it's
-  // not needed there anyway, so we simply check for dynamic inner tiles before
-  // applying this tweak.
-  if (ShapedType::isDynamicShape(getStaticInnerTiles())) {
-    if (unpackOp->hasAttr("already_tiled"))
-      return {unpackOp};
-  }
   // TODO(hanchung): Extend it to handle memref version.
   // Tiling on buffers needs extra buffer because tiled unpack op could produce
   // more data for incomplete tiles. Tiling on tensors satisfies IREE's needs.
@@ -2339,9 +2326,13 @@
       AffineExpr i, tile;
       bindDims(builder.getContext(), i);
       bindSymbols(builder.getContext(), tile);
-      OpFoldResult size = makeComposedFoldedAffineApply(
-          builder, loc, i * tile,
-          ArrayRef<OpFoldResult>{inputSizes.back(), dimAndTileMapping[dim]});
+      // Do not create an Affine ops for output size because the affine op is
+      // too complicated which would trigger an issue in affine ops
+      // simplification.
+      OpFoldResult size = builder.createOrFold<arith::MulIOp>(
+          loc, getValueOrCreateConstantIndexOp(builder, loc, inputSizes.back()),
+          getValueOrCreateConstantIndexOp(builder, loc,
+                                          dimAndTileMapping[dim]));
       outputExpandedSizes.push_back(size);
     }
   }
@@ -2388,8 +2379,6 @@
 
   Operation *tiledUnpackOp =
       mlir::clone(builder, getOperation(), tiledResultTypes, tiledOperands);
-  tiledUnpackOp->setAttr(StringAttr::get(getContext(), "already_tiled"),
-                         BoolAttr::get(getContext(), true));
 
   if (isPerfectTilingCase)
     return {tiledUnpackOp};
diff --git a/llvm-external-projects/iree-dialects/test/Dialect/iree_linalg_ext/tiling.mlir b/llvm-external-projects/iree-dialects/test/Dialect/iree_linalg_ext/tiling.mlir
index efba85a..fc8ce7e 100644
--- a/llvm-external-projects/iree-dialects/test/Dialect/iree_linalg_ext/tiling.mlir
+++ b/llvm-external-projects/iree-dialects/test/Dialect/iree_linalg_ext/tiling.mlir
@@ -1127,11 +1127,9 @@
 // CHECK-DAG:   #[[MAP0:.+]] = affine_map<(d0) -> (d0 floordiv 32)>
 // CHECK-DAG:   #[[MAP1:.+]] = affine_map<(d0) -> (d0 mod 32)>
 // CHECK-DAG:   #[[MAP2:.+]] = affine_map<(d0) -> ((d0 + 1) floordiv 32 - d0 floordiv 32 + 1)>
-// CHECK-DAG:   #[[MAP3:.+]] = affine_map<(d0) -> (((d0 + 1) floordiv 32) * 32 - (d0 floordiv 32) * 32 + 32)>
 // CHECK-DAG:   #[[MAP4:.+]] = affine_map<(d0) -> (d0 floordiv 16)>
 // CHECK-DAG:   #[[MAP5:.+]] = affine_map<(d0) -> (d0 mod 16)>
 // CHECK-DAG:   #[[MAP6:.+]] = affine_map<(d0) -> ((d0 + 3) floordiv 16 - d0 floordiv 16 + 1)>
-// CHECK-DAG:   #[[MAP7:.+]] = affine_map<(d0) -> (((d0 + 3) floordiv 16) * 16 - (d0 floordiv 16) * 16 + 16)>
 // CHECK-LABEL: func.func @NCnc_to_NC
 // CHECK-SAME:    %[[IN:[A-Za-z0-9]+]]:
 // CHECK-SAME:    %[[OUT:[A-Za-z0-9]+]]:
@@ -1171,11 +1169,9 @@
 // CHECK-DAG:   #[[MAP0:.+]] = affine_map<(d0) -> (d0 floordiv 32)>
 // CHECK-DAG:   #[[MAP1:.+]] = affine_map<(d0) -> (d0 mod 32)>
 // CHECK-DAG:   #[[MAP2:.+]] = affine_map<(d0) -> ((d0 + 1) floordiv 32 - d0 floordiv 32 + 1)>
-// CHECK-DAG:   #[[MAP3:.+]] = affine_map<(d0) -> (((d0 + 1) floordiv 32) * 32 - (d0 floordiv 32) * 32 + 32)>
 // CHECK-DAG:   #[[MAP4:.+]] = affine_map<(d0) -> (d0 floordiv 8)>
 // CHECK-DAG:   #[[MAP5:.+]] = affine_map<(d0) -> (d0 mod 8)>
 // CHECK-DAG:   #[[MAP6:.+]] = affine_map<(d0) -> ((d0 + 3) floordiv 8 - d0 floordiv 8 + 1)>
-// CHECK-DAG:   #[[MAP7:.+]] = affine_map<(d0) -> (((d0 + 3) floordiv 8) * 8 - (d0 floordiv 8) * 8 + 8)>
 // CHECK-LABEL: func.func @CKkc_to_KC
 // CHECK-SAME:    %[[IN:[A-Za-z0-9]+]]:
 // CHECK-SAME:    %[[OUT:[A-Za-z0-9]+]]:
@@ -1675,5 +1671,3 @@
 // CHECK:        }
 // CHECK:        return
 // CHECK:      }
-
-// -----