Fix size calculation in the tensor.empty materialization pattern. (#15359)
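
The removed code built the source sizes for the materialized tensor.empty as

    getMixedValues(resultType.getShape(), emptyOp.getDynamicSizes(), rewriter)

which pairs the dynamic dims of resultType one-to-one with the op's dynamic
size operands. In the new test the tensor.empty is created with three dynamic
sizes while its encoding's original_type, tensor<16x?x4096xf32>, has only a
single dynamic dim, so that pairing can pick the wrong operand for the padded
dimension. emptyOp.getMixedSizes() always matches the op's own operands, and
foldDynamicIndexList then folds operands defined by constants back into
static sizes.

A minimal sketch of the intended effect (shapes taken from the new
pack_batch_matmul_fill_partial_dynamic test below; the long encoding
attribute is abbreviated here as #encoding):

    // Before materialization: constants %c16 and %c4096 arrive as dynamic sizes.
    %3 = tensor.empty(%c16, %2, %c4096) : tensor<?x?x?xf32, #encoding>
    // After the fix, folding those constants keeps the packed type mostly
    // static instead of fully dynamic:
    %empty = tensor.empty(%out_d1) : tensor<16x?x512x8x8xf32>
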
diff --git a/llvm-external-projects/iree-dialects/lib/Dialect/LinalgExt/Passes/MaterializeEncoding.cpp b/llvm-external-projects/iree-dialects/lib/Dialect/LinalgExt/Passes/MaterializeEncoding.cpp
index 1fff730..0dd823b 100644
--- a/llvm-external-projects/iree-dialects/lib/Dialect/LinalgExt/Passes/MaterializeEncoding.cpp
+++ b/llvm-external-projects/iree-dialects/lib/Dialect/LinalgExt/Passes/MaterializeEncoding.cpp
@@ -379,8 +379,10 @@
return rewriter.notifyMatchFailure(
emptyOp, "failed to generate runtime tile size query");
}
- SmallVector<OpFoldResult> sourceDims = getMixedValues(
- resultType.getShape(), emptyOp.getDynamicSizes(), rewriter);
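+ // Take the sizes from the op itself and fold constant operands, so sizes
+ // that were passed as dynamic values (e.g. %c16) become static again.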
+ SmallVector<OpFoldResult> sourceDims = emptyOp.getMixedSizes();
+ (void)foldDynamicIndexList(sourceDims);
SmallVector<OpFoldResult> newShape =
PackOp::getResultShape(rewriter, loc, sourceDims, *innerTileSizesOfr,
materializeEncodingInfo->innerDimsPos,
diff --git a/llvm-external-projects/iree-dialects/test/Dialect/iree_linalg_ext/materialize_encoding.mlir b/llvm-external-projects/iree-dialects/test/Dialect/iree_linalg_ext/materialize_encoding.mlir
index 2ca2772..fb40bf3 100644
--- a/llvm-external-projects/iree-dialects/test/Dialect/iree_linalg_ext/materialize_encoding.mlir
+++ b/llvm-external-projects/iree-dialects/test/Dialect/iree_linalg_ext/materialize_encoding.mlir
@@ -297,6 +297,45 @@
// -----
+#map = affine_map<()[s0, s1] -> ((s1 ceildiv s0) * s0)>
+func.func @pack_batch_matmul_fill_partial_dynamic(%arg0: tensor<16x?x4096xf32>, %arg1: tensor<16x4096x4096xf32>) -> tensor<16x?x4096xf32> {
+ %c16 = arith.constant 16 : index
+ %c8 = arith.constant 8 : index
+ %c0 = arith.constant 0 : index
+ %c1 = arith.constant 1 : index
+ %c4096 = arith.constant 4096 : index
+ %cst = arith.constant 0.000000e+00 : f32
+ %dim = tensor.dim %arg0, %c1 : tensor<16x?x4096xf32>
+ %0 = iree_linalg_ext.set_encoding %arg0 : tensor<16x?x4096xf32> -> tensor<16x?x4096xf32, #iree_linalg_ext.encoding<user = BATCH_MATMUL, role = LHS, element_types = [f32, f32, f32], original_type = tensor<16x?x4096xf32>>>
+ %1 = iree_linalg_ext.set_encoding %arg1 : tensor<16x4096x4096xf32> -> tensor<16x4096x4096xf32, #iree_linalg_ext.encoding<user = BATCH_MATMUL, role = RHS, element_types = [f32, f32, f32], original_type = tensor<16x4096x4096xf32>>>
+ %2 = affine.apply #map()[%c8, %dim]
+ %3 = tensor.empty(%c16, %2, %c4096) : tensor<?x?x?xf32, #iree_linalg_ext.encoding<user = BATCH_MATMUL, role = RESULT, element_types = [f32, f32, f32], original_type = tensor<16x?x4096xf32>>>
+ %4 = linalg.fill ins(%cst : f32) outs(%3 : tensor<?x?x?xf32, #iree_linalg_ext.encoding<user = BATCH_MATMUL, role = RESULT, element_types = [f32, f32, f32], original_type = tensor<16x?x4096xf32>>>) -> tensor<?x?x?xf32, #iree_linalg_ext.encoding<user = BATCH_MATMUL, role = RESULT, element_types = [f32, f32, f32], original_type = tensor<16x?x4096xf32>>>
+ %5 = linalg.batch_matmul ins(%0, %1 : tensor<16x?x4096xf32, #iree_linalg_ext.encoding<user = BATCH_MATMUL, role = LHS, element_types = [f32, f32, f32], original_type = tensor<16x?x4096xf32>>>, tensor<16x4096x4096xf32, #iree_linalg_ext.encoding<user = BATCH_MATMUL, role = RHS, element_types = [f32, f32, f32], original_type = tensor<16x4096x4096xf32>>>) outs(%4 : tensor<?x?x?xf32, #iree_linalg_ext.encoding<user = BATCH_MATMUL, role = RESULT, element_types = [f32, f32, f32], original_type = tensor<16x?x4096xf32>>>) -> tensor<?x?x?xf32, #iree_linalg_ext.encoding<user = BATCH_MATMUL, role = RESULT, element_types = [f32, f32, f32], original_type = tensor<16x?x4096xf32>>>
+ %6 = iree_linalg_ext.unset_encoding %5 : tensor<?x?x?xf32, #iree_linalg_ext.encoding<user = BATCH_MATMUL, role = RESULT, element_types = [f32, f32, f32], original_type = tensor<16x?x4096xf32>>> -> tensor<?x?x?xf32>
+ %extracted_slice = tensor.extract_slice %6[0, 0, 0] [16, %dim, 4096] [1, 1, 1] : tensor<?x?x?xf32> to tensor<16x?x4096xf32>
+ return %extracted_slice : tensor<16x?x4096xf32>
+}
+// CHECK-DAG: #[[MAP0:.+]] = affine_map<()[s0] -> (s0 ceildiv 8)>
+// CHECK: func @pack_batch_matmul_fill_partial_dynamic(
+// CHECK-SAME: %[[ARG0:[a-zA-Z0-9]+]]
+// CHECK-SAME: %[[ARG1:[a-zA-Z0-9]+]]
+// CHECK-DAG: %[[C1:.+]] = arith.constant 1 : index
+// CHECK-DAG: %[[D1:.+]] = tensor.dim %[[ARG0]], %[[C1]]
+// CHECK-DAG: %[[OUT_D1:.+]] = affine.apply #[[MAP0]]()[%[[D1]]]
+// CHECK-DAG: %[[PACK_LHS:.+]] = tensor.pack %[[ARG0]]
+// CHECK-DAG: %[[PACK_RHS:.+]] = tensor.pack %[[ARG1]]
+// CHECK-DAG: %[[EMPTY:.+]] = tensor.empty(%[[OUT_D1]]) : tensor<16x?x512x8x8xf32>
+// CHECK: %[[FILL:.+]] = linalg.fill
+// CHECK-SAME: outs(%[[EMPTY]] : tensor<16x?x512x8x8xf32>)
+// CHECK: %[[BATCH_MMT4D:.+]] = linalg.batch_mmt4d
+// CHECK-SAME: ins(%[[PACK_LHS]], %[[PACK_RHS]] :
+// CHECK-SAME: outs(%[[FILL]] :
+// CHECK: %[[UNPACK:.+]] = tensor.unpack %[[BATCH_MMT4D]]
+// CHECK: return %[[UNPACK]]
+
+// -----
+
func.func @pack_batch_matmul_fill_dynamic(%arg0 : tensor<?x?x?xf32>, %arg1 : tensor<?x?x?xf32>) -> tensor<?x?x?xf32> {
%c0 = arith.constant 0 : index
%c1 = arith.constant 1 : index