Work around infinite application of tile-and-distribute on `linalg_ext.unpack` with dynamic tile sizes (#11601)
Dynamic inner tile sizes currently trigger infinite application of
tile-and-distribute on the unpack op, each time calling
`getTiledImplementation` -> `getSlice` creating more and more
`extract_slice`. As a temporary work-around, we annotate unpack ops with
a custom already_tiled attribute to keep track of what's already been
tiled.
For some reason this causes errors in non-dynamic-shape cases, but it's
not needed there anyway, so we simply check for dynamic inner tiles
before applying this tweak.
diff --git a/compiler/src/iree/compiler/Codegen/Common/test/materialize_encoding.mlir b/compiler/src/iree/compiler/Codegen/Common/test/materialize_encoding.mlir
index 013858d..900c17d 100644
--- a/compiler/src/iree/compiler/Codegen/Common/test/materialize_encoding.mlir
+++ b/compiler/src/iree/compiler/Codegen/Common/test/materialize_encoding.mlir
@@ -41,7 +41,8 @@
// CHECK-SAME: !flow.dispatch.tensor<writeonly:tensor<?x?x8x4xf32>>{%[[TILED_OUTD0]], %[[TILED_OUTD1]]}
// CHECK: %[[INPUT:.+]] = flow.dispatch.tensor.load %[[INPUT_BINDING]]
// CHECK: %[[EMPTY:.+]] = tensor.empty(%[[TILED_OUTD0]], %[[TILED_OUTD1]])
-// CHECK: %[[PACK:.+]] = iree_linalg_ext.pack %[[INPUT]] padding_value(%[[CST]] : f32)
+// CHECK: %[[PACK:.+]] = iree_linalg_ext.pack
+// CHECK-SAME: %[[INPUT]] padding_value(%[[CST]] : f32)
// CHECK-SAME: inner_dims_pos = [0, 1] inner_tiles = [8, 4] into %[[EMPTY]]
// CHECK: flow.dispatch.tensor.store %[[PACK]], %[[OUTPUT_BINDING]]
// CHECK-SAME: offsets = [0, 0, 0, 0], sizes = [%[[TILED_OUTD0]], %[[TILED_OUTD1]], 8, 4], strides = [1, 1, 1, 1]
diff --git a/llvm-external-projects/iree-dialects/lib/Dialect/LinalgExt/IR/LinalgExtOps.cpp b/llvm-external-projects/iree-dialects/lib/Dialect/LinalgExt/IR/LinalgExtOps.cpp
index 8ee5291..9f3698b 100644
--- a/llvm-external-projects/iree-dialects/lib/Dialect/LinalgExt/IR/LinalgExtOps.cpp
+++ b/llvm-external-projects/iree-dialects/lib/Dialect/LinalgExt/IR/LinalgExtOps.cpp
@@ -2217,6 +2217,19 @@
UnPackOp::getTiledImplementation(OpBuilder &builder,
ArrayRef<OpFoldResult> offsets,
ArrayRef<OpFoldResult> sizes) {
+ Operation *unpackOp = *this;
+ // Dynamic inner tile sizes currently trigger infinite application of
+ // tile-and-distribute on the unpack op, each tile calling
+ // getTiledImplementation -> getSlice creating more and more extract_slice.
+ // As a temporary work-around, we annotate unpack ops with a custom
+ // already_tiled attribute to keep track of what's already been tiled.
+ // For some reason this causes errors in non-dynamic-shape cases, but it's
+ // not needed there anyway, so we simply check for dynamic inner tiles before
+ // applying this tweak.
+ if (ShapedType::isDynamicShape(getStaticInnerTiles())) {
+ if (unpackOp->hasAttr("already_tiled"))
+ return {unpackOp};
+ }
// TODO(hanchung): Extend it to handle memref version.
// Tiling on buffers needs extra buffer because tiled unpack op could produce
// more data for incomplete tiles. Tiling on tensors satisfies IREE's needs.
@@ -2368,6 +2381,8 @@
Operation *tiledUnpackOp =
mlir::clone(builder, getOperation(), tiledResultTypes, tiledOperands);
+ tiledUnpackOp->setAttr(StringAttr::get(getContext(), "already_tiled"),
+ BoolAttr::get(getContext(), true));
if (isPerfectTilingCase)
return {tiledUnpackOp};
diff --git a/llvm-external-projects/iree-dialects/test/Dialect/iree_linalg_ext/tiling.mlir b/llvm-external-projects/iree-dialects/test/Dialect/iree_linalg_ext/tiling.mlir
index b48219e..854a6e8 100644
--- a/llvm-external-projects/iree-dialects/test/Dialect/iree_linalg_ext/tiling.mlir
+++ b/llvm-external-projects/iree-dialects/test/Dialect/iree_linalg_ext/tiling.mlir
@@ -1153,7 +1153,7 @@
// CHECK-SAME: : tensor<8x8x32x16xf32> to tensor<?x?x32x16xf32>
// CHECK: %[[EMPTY:.+]] = tensor.empty
// CHECK: %[[UNPACK:.+]] = iree_linalg_ext.unpack
-// CHECK-SAME: {__internal_linalg_transform__ = "tiling_pack_output"}
+// CHECK-SAME: {__internal_linalg_transform__ = "tiling_pack_output"
// CHECK-SAME: %[[SLICE]] inner_dims_pos = [0, 1] inner_tiles = [32, 16]
// CHECK-SAME: into %[[EMPTY]]
// CHECK: %[[UNPACK_SLICE:.+]] = tensor.extract_slice %[[UNPACK]]