Plumb e2e support for packing on dynamic inner tiles. (#11487)

diff --git a/compiler/src/iree/compiler/Codegen/Common/TileAndDistributeToWorkgroupsPass.cpp b/compiler/src/iree/compiler/Codegen/Common/TileAndDistributeToWorkgroupsPass.cpp
index a934afa..8d54b6e 100644
--- a/compiler/src/iree/compiler/Codegen/Common/TileAndDistributeToWorkgroupsPass.cpp
+++ b/compiler/src/iree/compiler/Codegen/Common/TileAndDistributeToWorkgroupsPass.cpp
@@ -118,14 +118,15 @@
   encodingInfo.innerTileSizes.reserve(mixedTileSizes.size());
   for (auto tileSize : mixedTileSizes) {
     if (tileSize.is<Value>()) {
-      return packOp.emitOpError(
-          "unhandled distribution of pack op with dynamic inner tile size");
+      encodingInfo.innerTileSizes.push_back(ShapedType::kDynamic);
+    } else {
+      encodingInfo.innerTileSizes.push_back(
+          tileSize.get<Attribute>().cast<IntegerAttr>().getInt());
     }
-    encodingInfo.innerTileSizes.push_back(
-        tileSize.get<Attribute>().cast<IntegerAttr>().getInt());
   }
   encodingInfo.innerDimsPos = llvm::to_vector(packOp.getInnerDimsPos());
   encodingInfo.outerDimsPerm = llvm::to_vector(packOp.getOuterDimsPerm());
+  encodingInfo.srcRank = packOp.getInputRank();
   return encodingInfo;
 }
 
@@ -268,6 +269,7 @@
             getAsOpFoldResults(materializeEncodingInfo.innerTileSizes),
             materializeEncodingInfo.innerDimsPos,
             materializeEncodingInfo.outerDimsPerm);
+    resultShape.resize(materializeEncodingInfo.srcRank);
 
     rewriter
         .replaceOpWithNewOp<IREE::Flow::DispatchWorkgroupCountFromDagRootOp>(
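
The two hunks above are what stop this pass from rejecting pack ops with dynamic inner tiles: a tile size carried as an SSA Value is now recorded as ShapedType::kDynamic instead of producing an error, and the newly tracked srcRank is used to resize the computed result shape to the source rank before the workgroup-count op is built. A minimal sketch of the tile-size translation, using a hypothetical toInnerTileSizes helper that is not part of the patch:

    // Illustrative helper (assumed name); mirrors the loop added above.
    // Dynamic tile sizes are remembered only as "dynamic", static ones as
    // their constant value.
    static SmallVector<int64_t>
    toInnerTileSizes(ArrayRef<OpFoldResult> mixedTileSizes) {
      SmallVector<int64_t> result;
      result.reserve(mixedTileSizes.size());
      for (OpFoldResult tileSize : mixedTileSizes) {
        if (tileSize.is<Value>())
          result.push_back(ShapedType::kDynamic);
        else
          result.push_back(
              tileSize.get<Attribute>().cast<IntegerAttr>().getInt());
      }
      return result;
    }
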
diff --git a/compiler/src/iree/compiler/Codegen/Interfaces/BufferizationInterfaces.cpp b/compiler/src/iree/compiler/Codegen/Interfaces/BufferizationInterfaces.cpp
index d4abb96..ed6c253 100644
--- a/compiler/src/iree/compiler/Codegen/Interfaces/BufferizationInterfaces.cpp
+++ b/compiler/src/iree/compiler/Codegen/Interfaces/BufferizationInterfaces.cpp
@@ -280,6 +280,8 @@
 
   // PackOp has other operands besides ins and outs.
   if (auto packOp = dyn_cast<IREE::LinalgExt::PackOp>(op.getOperation())) {
+    newOperands.append(packOp.getInnerTiles().begin(),
+                       packOp.getInnerTiles().end());
     if (auto pad = packOp.getPaddingValue()) newOperands.push_back(pad);
   }
 
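Because the inner tile sizes can now be SSA index values, they must survive bufferization as regular operands; the hunk above forwards them into newOperands ahead of the optional padding value. A rough sketch of the same operand collection as a standalone helper (assumed name, not part of the patch):

    // Sketch: gather the extra operands a LinalgExt pack op carries besides
    // its ins/outs, i.e. the (possibly dynamic) inner tile sizes and the
    // optional padding value.
    static SmallVector<Value>
    extraPackOperands(IREE::LinalgExt::PackOp packOp) {
      SmallVector<Value> operands;
      llvm::append_range(operands, packOp.getInnerTiles());
      if (Value pad = packOp.getPaddingValue())
        operands.push_back(pad);
      return operands;
    }
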
diff --git a/llvm-external-projects/iree-dialects/include/iree-dialects/Dialect/LinalgExt/Passes/Passes.h b/llvm-external-projects/iree-dialects/include/iree-dialects/Dialect/LinalgExt/Passes/Passes.h
index b715a4c..1d16af5 100644
--- a/llvm-external-projects/iree-dialects/include/iree-dialects/Dialect/LinalgExt/Passes/Passes.h
+++ b/llvm-external-projects/iree-dialects/include/iree-dialects/Dialect/LinalgExt/Passes/Passes.h
@@ -92,6 +92,7 @@
   SmallVector<int64_t> innerDimsPos;
   SmallVector<int64_t> innerTileSizes;
   SmallVector<int64_t> outerDimsPerm;
+  unsigned srcRank = 0;
 };
 using MaterializeEncodingFn =
     std::function<FailureOr<MaterializeEncodingInfo>(RankedTensorType)>;
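
The new srcRank field records the rank of the pack op's source tensor so that consumers of MaterializeEncodingInfo (such as the workgroup-count logic earlier in this patch, which resizes the computed result shape to srcRank) know the rank of the original source even when the inner tile sizes are dynamic. An illustrative (not prescriptive) filling of the struct for a rank-2 source packed with inner_tiles = [8, %d]:

    // Values below are purely illustrative.
    MaterializeEncodingInfo info;
    info.innerDimsPos = {0, 1};
    info.innerTileSizes = {8, ShapedType::kDynamic};  // second tile is dynamic
    info.outerDimsPerm = {1, 0};
    info.srcRank = 2;
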
diff --git a/llvm-external-projects/iree-dialects/lib/Dialect/LinalgExt/Transforms/Transforms.cpp b/llvm-external-projects/iree-dialects/lib/Dialect/LinalgExt/Transforms/Transforms.cpp
index b335776..935ddd2 100644
--- a/llvm-external-projects/iree-dialects/lib/Dialect/LinalgExt/Transforms/Transforms.cpp
+++ b/llvm-external-projects/iree-dialects/lib/Dialect/LinalgExt/Transforms/Transforms.cpp
@@ -817,6 +817,11 @@
       return rewriter.notifyMatchFailure(
           packOp, "require the outer dimension of the result are all 1s");
     }
+    if (llvm::any_of(packOp.getMixedTiles(),
+                     [](OpFoldResult tile) { return tile.is<Value>(); })) {
+      return rewriter.notifyMatchFailure(
+          packOp, "require inner tile sizes to be static");
+    }
 
     Value input = getInputOrPaddedInput(rewriter, packOp);
 
@@ -954,9 +959,17 @@
       SimpleRewriter rewriter(ctx);
       auto packOptions = scf::SCFTileAndFuseOptions().setTilingOptions(
           scf::SCFTilingOptions().setTileSizeComputationFunction(
-              [](OpBuilder &builder, Operation *op) {
+              [](OpBuilder &builder, Operation *op) -> SmallVector<Value> {
                 Location loc = op->getLoc();
-                int inputRank = cast<PackOp>(op).getInputRank();
+                auto packOp = cast<PackOp>(op);
+
+                // Do nothing if any of the inner tile sizes is dynamic.
+                if (llvm::any_of(packOp.getMixedTiles(), [](OpFoldResult tile) {
+                      return tile.is<Value>();
+                    }))
+                  return {};
+
+                int inputRank = packOp.getInputRank();
                 SmallVector<Value> tileSizes(
                     inputRank, builder.create<arith::ConstantIndexOp>(loc, 1));
                 return tileSizes;
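
Both hunks in this file guard on the same condition: the pattern above, which rewrites packs whose outer result dims are all 1s, now bails out on dynamic inner tiles, and the tile-size callback returns an empty vector so tiling is simply skipped for such ops. The shared predicate could be factored out roughly like this (hypothetical helper; the patch keeps the lambdas inline):

    // Hypothetical helper: true if any inner tile size of the pack op is an
    // SSA value rather than a constant attribute.
    static bool hasDynamicInnerTiles(PackOp packOp) {
      return llvm::any_of(packOp.getMixedTiles(),
                          [](OpFoldResult tile) { return tile.is<Value>(); });
    }
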
diff --git a/tests/e2e/linalg_ext_ops/pack.mlir b/tests/e2e/linalg_ext_ops/pack.mlir
index 2fb47df..2031db6 100644
--- a/tests/e2e/linalg_ext_ops/pack.mlir
+++ b/tests/e2e/linalg_ext_ops/pack.mlir
@@ -437,3 +437,60 @@
   check.expect_eq(%cast_pack, %transpose) : tensor<16x4x16x32xi32>
   return
 }
+
+func.func @fully_dynamic_pack_simple() {
+  %iree_input = flow.tensor.constant dense<[
+    [0, 1, 2, 3],
+    [4, 5, 6, 7],
+    [8, 9, 10, 11],
+    [12, 13, 14, 15]]> : tensor<4x4xi32> -> tensor<?x?xi32>
+  %c0 = arith.constant 0 : index
+  %c1 = arith.constant 1 : index
+  %c2 = util.unfoldable_constant 2 : index
+  %in_d0 = tensor.dim %iree_input, %c0 : tensor<?x?xi32>
+  %in_d1 = tensor.dim %iree_input, %c1 : tensor<?x?xi32>
+  %out_d0 = arith.ceildivui %in_d0, %c2 : index
+  %out_d1 = arith.ceildivui %in_d1, %c2 : index
+  %init = tensor.empty(%out_d0, %out_d1, %c2, %c2) : tensor<?x?x?x?xi32>
+  %pack = iree_linalg_ext.pack %iree_input inner_dims_pos = [0, 1] inner_tiles = [%c2, %c2] into %init
+      : (tensor<?x?xi32> tensor<?x?x?x?xi32>) -> tensor<?x?x?x?xi32>
+  %cast = tensor.cast %pack : tensor<?x?x?x?xi32> to tensor<2x2x2x2xi32>
+  check.expect_eq_const(%cast, dense<[[[[0, 1], [4, 5]], [[2, 3], [6, 7]]], [[[8, 9], [12, 13]], [[10, 11], [14, 15]]]]> : tensor<2x2x2x2xi32>) : tensor<2x2x2x2xi32>
+  return
+}
+
+func.func @fully_dynamic_pack_pad_transpose_inner_and_outer_dims_large() {
+  %d0 = util.unfoldable_constant 100 : index
+  %d1 = util.unfoldable_constant 250 : index
+  %source = call @generate_2D_source(%d0, %d1) : (index, index) -> tensor<?x?xi32>
+  %padding_value = arith.constant 42 : i32
+
+  %c16 = util.unfoldable_constant 16 : index
+  %c32 = util.unfoldable_constant 32 : index
+  %tiled_d0 = arith.ceildivui %d0, %c32 : index
+  %tiled_d1 = arith.ceildivui %d1, %c16 : index
+  %init_pack = tensor.empty(%tiled_d1, %tiled_d0, %c16, %c32) : tensor<?x?x?x?xi32>
+  %pack = iree_linalg_ext.pack %source padding_value(%padding_value : i32)
+      outer_dims_perm = [1, 0] inner_dims_pos = [1, 0] inner_tiles = [%c16, %c32] into %init_pack
+      : (tensor<?x?xi32> tensor<?x?x?x?xi32>) -> tensor<?x?x?x?xi32>
+  %cast_pack = tensor.cast %pack : tensor<?x?x?x?xi32> to tensor<16x4x16x32xi32>
+
+  %c100 = arith.constant 100 : index
+  %c250 = arith.constant 250 : index
+  %source2 = call @generate_2D_source(%c100, %c250) : (index, index) -> tensor<?x?xi32>
+  %static_source = tensor.cast %source2 : tensor<?x?xi32> to tensor<100x250xi32>
+
+  %pad = tensor.pad %static_source low[0, 0] high[28, 6] {
+    ^bb0(%b0 : index, %b1 : index):
+      tensor.yield %padding_value : i32
+  } : tensor<100x250xi32> to tensor<128x256xi32>
+  %reshape = tensor.expand_shape %pad [[0, 1], [2, 3]] : tensor<128x256xi32> into tensor<4x32x16x16xi32>
+  %init_transpose = tensor.empty() : tensor<16x4x16x32xi32>
+  %transpose = linalg.transpose
+    ins(%reshape : tensor<4x32x16x16xi32>)
+    outs(%init_transpose : tensor<16x4x16x32xi32>)
+    permutation = [2, 0, 3, 1]
+
+  check.expect_eq(%cast_pack, %transpose) : tensor<16x4x16x32xi32>
+  return
+}
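
Both new tests exercise the dynamic pack path end to end: the first packs a small constant tensor with unfoldable dynamic 2x2 tile sizes and compares against a hard-coded result, while the second packs a dynamically shaped 100x250 source with a padding value, transposed inner and outer dims, and dynamic 16/32 tiles, then cross-checks it against an equivalent static tensor.pad + tensor.expand_shape + linalg.transpose sequence.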