Plumb e2e support for packing on dynamic inner tiles. (#11487)
diff --git a/compiler/src/iree/compiler/Codegen/Common/TileAndDistributeToWorkgroupsPass.cpp b/compiler/src/iree/compiler/Codegen/Common/TileAndDistributeToWorkgroupsPass.cpp
index a934afa..8d54b6e 100644
--- a/compiler/src/iree/compiler/Codegen/Common/TileAndDistributeToWorkgroupsPass.cpp
+++ b/compiler/src/iree/compiler/Codegen/Common/TileAndDistributeToWorkgroupsPass.cpp
@@ -118,14 +118,15 @@
encodingInfo.innerTileSizes.reserve(mixedTileSizes.size());
for (auto tileSize : mixedTileSizes) {
if (tileSize.is<Value>()) {
- return packOp.emitOpError(
- "unhandled distribution of pack op with dynamic inner tile size");
+ encodingInfo.innerTileSizes.push_back(ShapedType::kDynamic);
+ } else {
+ encodingInfo.innerTileSizes.push_back(
+ tileSize.get<Attribute>().cast<IntegerAttr>().getInt());
}
- encodingInfo.innerTileSizes.push_back(
- tileSize.get<Attribute>().cast<IntegerAttr>().getInt());
}
encodingInfo.innerDimsPos = llvm::to_vector(packOp.getInnerDimsPos());
encodingInfo.outerDimsPerm = llvm::to_vector(packOp.getOuterDimsPerm());
+ encodingInfo.srcRank = packOp.getInputRank();
return encodingInfo;
}
@@ -268,6 +269,7 @@
getAsOpFoldResults(materializeEncodingInfo.innerTileSizes),
materializeEncodingInfo.innerDimsPos,
materializeEncodingInfo.outerDimsPerm);
+ resultShape.resize(materializeEncodingInfo.srcRank);
rewriter
.replaceOpWithNewOp<IREE::Flow::DispatchWorkgroupCountFromDagRootOp>(
diff --git a/compiler/src/iree/compiler/Codegen/Interfaces/BufferizationInterfaces.cpp b/compiler/src/iree/compiler/Codegen/Interfaces/BufferizationInterfaces.cpp
index d4abb96..ed6c253 100644
--- a/compiler/src/iree/compiler/Codegen/Interfaces/BufferizationInterfaces.cpp
+++ b/compiler/src/iree/compiler/Codegen/Interfaces/BufferizationInterfaces.cpp
@@ -280,6 +280,8 @@
// PackOp has other operands besides ins and outs.
if (auto packOp = dyn_cast<IREE::LinalgExt::PackOp>(op.getOperation())) {
+ newOperands.append(packOp.getInnerTiles().begin(),
+ packOp.getInnerTiles().end());
if (auto pad = packOp.getPaddingValue()) newOperands.push_back(pad);
}
diff --git a/llvm-external-projects/iree-dialects/include/iree-dialects/Dialect/LinalgExt/Passes/Passes.h b/llvm-external-projects/iree-dialects/include/iree-dialects/Dialect/LinalgExt/Passes/Passes.h
index b715a4c..1d16af5 100644
--- a/llvm-external-projects/iree-dialects/include/iree-dialects/Dialect/LinalgExt/Passes/Passes.h
+++ b/llvm-external-projects/iree-dialects/include/iree-dialects/Dialect/LinalgExt/Passes/Passes.h
@@ -92,6 +92,7 @@
SmallVector<int64_t> innerDimsPos;
SmallVector<int64_t> innerTileSizes;
SmallVector<int64_t> outerDimsPerm;
+ unsigned srcRank = 0;
};
using MaterializeEncodingFn =
std::function<FailureOr<MaterializeEncodingInfo>(RankedTensorType)>;
diff --git a/llvm-external-projects/iree-dialects/lib/Dialect/LinalgExt/Transforms/Transforms.cpp b/llvm-external-projects/iree-dialects/lib/Dialect/LinalgExt/Transforms/Transforms.cpp
index b335776..935ddd2 100644
--- a/llvm-external-projects/iree-dialects/lib/Dialect/LinalgExt/Transforms/Transforms.cpp
+++ b/llvm-external-projects/iree-dialects/lib/Dialect/LinalgExt/Transforms/Transforms.cpp
@@ -817,6 +817,11 @@
return rewriter.notifyMatchFailure(
packOp, "require the outer dimension of the result are all 1s");
}
+ if (llvm::any_of(packOp.getMixedTiles(),
+ [](OpFoldResult tile) { return tile.is<Value>(); })) {
+ return rewriter.notifyMatchFailure(
+ packOp, "require inner tile sizes being static");
+ }
Value input = getInputOrPaddedInput(rewriter, packOp);
@@ -954,9 +959,17 @@
SimpleRewriter rewriter(ctx);
auto packOptions = scf::SCFTileAndFuseOptions().setTilingOptions(
scf::SCFTilingOptions().setTileSizeComputationFunction(
- [](OpBuilder &builder, Operation *op) {
+ [](OpBuilder &builder, Operation *op) -> SmallVector<Value> {
Location loc = op->getLoc();
- int inputRank = cast<PackOp>(op).getInputRank();
+ auto packOp = cast<PackOp>(op);
+
+                // Do nothing if any of the inner tile sizes is dynamic.
+ if (llvm::any_of(packOp.getMixedTiles(), [](OpFoldResult tile) {
+ return tile.is<Value>();
+ }))
+ return {};
+
+ int inputRank = packOp.getInputRank();
SmallVector<Value> tileSizes(
inputRank, builder.create<arith::ConstantIndexOp>(loc, 1));
return tileSizes;
diff --git a/tests/e2e/linalg_ext_ops/pack.mlir b/tests/e2e/linalg_ext_ops/pack.mlir
index 2fb47df..2031db6 100644
--- a/tests/e2e/linalg_ext_ops/pack.mlir
+++ b/tests/e2e/linalg_ext_ops/pack.mlir
@@ -437,3 +437,60 @@
check.expect_eq(%cast_pack, %transpose) : tensor<16x4x16x32xi32>
return
}
+
+func.func @fully_dynamic_pack_simple() {  // e2e check: pack a 4x4 input using SSA-valued 2x2 inner tiles and fully dynamic tensor types.
+  %iree_input = flow.tensor.constant dense<[
+    [0, 1, 2, 3],
+    [4, 5, 6, 7],
+    [8, 9, 10, 11],
+    [12, 13, 14, 15]]> : tensor<4x4xi32> -> tensor<?x?xi32>
+  %c0 = arith.constant 0 : index
+  %c1 = arith.constant 1 : index
+  %c2 = util.unfoldable_constant 2 : index  // Inner tile size; NOTE(review): presumably unfoldable so the tile stays a dynamic value -- confirm.
+  %in_d0 = tensor.dim %iree_input, %c0 : tensor<?x?xi32>
+  %in_d1 = tensor.dim %iree_input, %c1 : tensor<?x?xi32>
+  %out_d0 = arith.ceildivui %in_d0, %c2 : index  // Outer dim = ceil(input dim / tile size).
+  %out_d1 = arith.ceildivui %in_d1, %c2 : index
+  %init = tensor.empty(%out_d0, %out_d1, %c2, %c2) : tensor<?x?x?x?xi32>  // Destination sized as [outer0, outer1, tile0, tile1].
+  %pack = iree_linalg_ext.pack %iree_input inner_dims_pos = [0, 1] inner_tiles = [%c2, %c2] into %init
+      : (tensor<?x?xi32> tensor<?x?x?x?xi32>) -> tensor<?x?x?x?xi32>
+  %cast = tensor.cast %pack : tensor<?x?x?x?xi32> to tensor<2x2x2x2xi32>  // Cast to a static shape so it can be compared with the constant below.
+  check.expect_eq_const(%cast, dense<[[[[0, 1], [4, 5]], [[2, 3], [6, 7]]], [[[8, 9], [12, 13]], [[10 ,11], [14, 15]]]]> : tensor<2x2x2x2xi32>) : tensor<2x2x2x2xi32>
+  return
+}
+
+func.func @fully_dynamic_pack_pad_transpose_inner_and_outer_dims_large() {  // e2e check: dynamic-tile pack with padding and permuted inner/outer dims vs. a pad + expand_shape + transpose reference.
+  %d0 = util.unfoldable_constant 100 : index
+  %d1 = util.unfoldable_constant 250 : index
+  %source = call @generate_2D_source(%d0, %d1) : (index, index) -> tensor<?x?xi32>  // @generate_2D_source is not visible in this hunk.
+  %padding_value = arith.constant 42 : i32
+  // Tile sizes are passed to the pack op as SSA values rather than literal integers.
+  %c16 = util.unfoldable_constant 16 : index
+  %c32 = util.unfoldable_constant 32 : index
+  %tiled_d0 = arith.ceildivui %d0, %c32 : index  // ceil(d0 / 32)
+  %tiled_d1 = arith.ceildivui %d1, %c16 : index  // ceil(d1 / 16)
+  %init_pack = tensor.empty(%tiled_d1, %tiled_d0, %c16, %c32) : tensor<?x?x?x?xi32>  // Outer dims listed as (d1 tiles, d0 tiles), alongside outer_dims_perm = [1, 0] below.
+  %pack = iree_linalg_ext.pack %source padding_value(%padding_value : i32)
+      outer_dims_perm = [1, 0] inner_dims_pos = [1, 0] inner_tiles = [%c16, %c32] into %init_pack
+      : (tensor<?x?xi32> tensor<?x?x?x?xi32>) -> tensor<?x?x?x?xi32>
+  %cast_pack = tensor.cast %pack : tensor<?x?x?x?xi32> to tensor<16x4x16x32xi32>  // Static shape for the final comparison.
+  // Reference path: call the same generator with plain constants and cast the result to a static shape.
+  %c100 = arith.constant 100 : index
+  %c250 = arith.constant 250 : index
+  %source2 = call @generate_2D_source(%c100, %c250) : (index, index) -> tensor<?x?xi32>
+  %static_source = tensor.cast %source2 : tensor<?x?xi32> to tensor<100x250xi32>
+  // Pad with the same padding value: 100 + 28 = 128 rows, 250 + 6 = 256 columns.
+  %pad = tensor.pad %static_source low[0, 0] high[28, 6] {
+    ^bb0(%b0 : index, %b1 : index):
+      tensor.yield %padding_value : i32
+  } : tensor<100x250xi32> to tensor<128x256xi32>
+  %reshape = tensor.expand_shape %pad [[0, 1], [2, 3]] : tensor<128x256xi32> into tensor<4x32x16x16xi32>  // 128 -> 4x32, 256 -> 16x16.
+  %init_transpose = tensor.empty() : tensor<16x4x16x32xi32>
+  %transpose = linalg.transpose
+    ins(%reshape : tensor<4x32x16x16xi32>)
+    outs(%init_transpose : tensor<16x4x16x32xi32>)
+    permutation = [2, 0, 3, 1]  // Reorders 4x32x16x16 into 16x4x16x32, the layout of %cast_pack.
+  // The dynamically packed result must equal the statically built reference.
+  check.expect_eq(%cast_pack, %transpose) : tensor<16x4x16x32xi32>
+  return
+}