Materialize batch_matmul to batch_mmt4d (#14731)
Add a pattern to materialize `batch_matmul` ops with data-tiling encodings to
`batch_mmt4d`.
Tracking issue: #14431
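
For context, a minimal before/after sketch of the rewrite, based on the new `pack_batch_matmul` test added below. The `tensor.pack`/`tensor.unpack` attributes are elided here, and the actual inner tile sizes come from the target's `MatmulTileParams`:

```mlir
// Before: linalg.batch_matmul whose operands carry BATCH_MATMUL_* encodings.
%3 = linalg.batch_matmul
    ins(%0, %1 : tensor<128x80x32xf32, #iree_linalg_ext.encoding<user = BATCH_MATMUL_F32F32F32, role = LHS>>,
                 tensor<128x32x320xf32, #iree_linalg_ext.encoding<user = BATCH_MATMUL_F32F32F32, role = RHS>>)
    outs(%2 : tensor<128x80x320xf32, #iree_linalg_ext.encoding<user = BATCH_MATMUL_F32F32F32, role = RESULT>>)
    -> tensor<128x80x320xf32, #iree_linalg_ext.encoding<user = BATCH_MATMUL_F32F32F32, role = RESULT>>

// After: operands are materialized into packed layouts and the contraction
// becomes linalg.batch_mmt4d, with the result unpacked back to 3-D.
%pack_lhs = tensor.pack %arg0 ...       // packed LHS
%pack_rhs = tensor.pack %arg1 ...       // packed RHS
%pack_acc = tensor.pack %arg2 ...       // packed accumulator
%batch_mmt4d = linalg.batch_mmt4d ins(%pack_lhs, %pack_rhs : ...)
                                  outs(%pack_acc : ...) -> ...
%result = tensor.unpack %batch_mmt4d ...
```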
diff --git a/llvm-external-projects/iree-dialects/include/iree-dialects/Dialect/LinalgExt/Utils/EncodingUtils.h b/llvm-external-projects/iree-dialects/include/iree-dialects/Dialect/LinalgExt/Utils/EncodingUtils.h
index 700dce5..889e0a7 100644
--- a/llvm-external-projects/iree-dialects/include/iree-dialects/Dialect/LinalgExt/Utils/EncodingUtils.h
+++ b/llvm-external-projects/iree-dialects/include/iree-dialects/Dialect/LinalgExt/Utils/EncodingUtils.h
@@ -15,6 +15,12 @@
namespace IREE {
namespace LinalgExt {
+// Check if encoding user is one of matmul encodings.
+bool isMatmulEncodingUser(EncodingUser user);
+
+// Check if encoding user is one of batch matmul encodings.
+bool isBatchMatmulEncodingUser(EncodingUser user);
+
struct MatmulTileParams {
int64_t M = 1;
int64_t K = 1;
diff --git a/llvm-external-projects/iree-dialects/lib/Dialect/LinalgExt/Passes/MaterializeEncoding.cpp b/llvm-external-projects/iree-dialects/lib/Dialect/LinalgExt/Passes/MaterializeEncoding.cpp
index 769541b..342a483 100644
--- a/llvm-external-projects/iree-dialects/lib/Dialect/LinalgExt/Passes/MaterializeEncoding.cpp
+++ b/llvm-external-projects/iree-dialects/lib/Dialect/LinalgExt/Passes/MaterializeEncoding.cpp
@@ -248,7 +248,10 @@
if (!lhsEncoding || !rhsEncoding || !resultEncoding) {
return failure();
}
- if (lhsEncoding.getRole().getValue() !=
+ if (!isMatmulEncodingUser(lhsEncoding.getUser().getValue()) ||
+ !isMatmulEncodingUser(rhsEncoding.getUser().getValue()) ||
+ !isMatmulEncodingUser(resultEncoding.getUser().getValue()) ||
+ lhsEncoding.getRole().getValue() !=
mlir::iree_compiler::IREE::LinalgExt::EncodingRole::LHS ||
rhsEncoding.getRole().getValue() !=
mlir::iree_compiler::IREE::LinalgExt::EncodingRole::RHS ||
@@ -262,8 +265,46 @@
return mmt4DOp;
}
-/// Utility method to convert from `linalg.fill` on `tensor` type with encoding
-/// to fill of the materialized type
+/// Utility method to convert from `linalg.batch_matmul` with
+/// - lhs encoding with user=BATCH_MATMUL_*, role=LHS
+/// - rhs encoding with user=BATCH_MATMUL_*, role=RHS
+/// - result encoding with user=BATCH_MATMUL_*, role=RESULT
+/// to linalg.batch_mmt4d op.
+static FailureOr<Operation *>
+lowerOpWithEncoding(RewriterBase &rewriter, linalg::BatchMatmulOp batchMatmulOp,
+ ValueRange convertedInputOperands,
+ ValueRange convertedOutputOperands, MaterializeEncodingFn,
+ MaterializeEncodingValueFn) {
+ if (!batchMatmulOp.hasTensorSemantics())
+ return failure();
+ auto inputs = batchMatmulOp.getDpsInputOperands();
+ auto outputs = batchMatmulOp.getDpsInitOperands();
+ auto lhsEncoding =
+ getEncodingAttr(inputs[0]->get().getType().cast<RankedTensorType>());
+ auto rhsEncoding =
+ getEncodingAttr(inputs[1]->get().getType().cast<RankedTensorType>());
+ auto resultEncoding =
+ getEncodingAttr(outputs[0]->get().getType().cast<RankedTensorType>());
+ if (!lhsEncoding || !rhsEncoding || !resultEncoding) {
+ return failure();
+ }
+
+ if (!isBatchMatmulEncodingUser(lhsEncoding.getUser().getValue()) ||
+ !isBatchMatmulEncodingUser(rhsEncoding.getUser().getValue()) ||
+ !isBatchMatmulEncodingUser(resultEncoding.getUser().getValue()) ||
+ lhsEncoding.getRole().getValue() != EncodingRole::LHS ||
+ rhsEncoding.getRole().getValue() != EncodingRole::RHS ||
+ resultEncoding.getRole().getValue() != EncodingRole::RESULT) {
+ return failure();
+ }
+ Operation *batchMmt4DOp = rewriter.create<linalg::BatchMmt4DOp>(
+ batchMatmulOp.getLoc(), convertedOutputOperands[0].getType(),
+ convertedInputOperands, convertedOutputOperands);
+ return batchMmt4DOp;
+}
+
+/// Utility method to convert from `linalg.fill` on `tensor` type with
+/// encoding to fill of the materialized type
static FailureOr<Operation *>
lowerOpWithEncoding(RewriterBase &rewriter, linalg::FillOp fillOp,
ValueRange convertedInputOperands,
@@ -515,9 +556,11 @@
MaterializeEncodingTypeConverter &typeConverter,
MaterializeEncodingValueFn materializeEncodingValueFn) {
- // Add all patterns for converting from encoded type to the materialized type
+ // Add all patterns for converting from encoded type to the materialized
+ // type
patterns.insert<MaterializeDPSOperation<linalg::FillOp>,
MaterializeDPSOperation<linalg::MatmulOp>,
+ MaterializeDPSOperation<linalg::BatchMatmulOp>,
MaterializeOperation<tensor::EmptyOp>,
SetEncodingOpToPackOpConversion,
UnsetEncodingOpToPackOpConversion>(
diff --git a/llvm-external-projects/iree-dialects/lib/Dialect/LinalgExt/Utils/EncodingUtils.cpp b/llvm-external-projects/iree-dialects/lib/Dialect/LinalgExt/Utils/EncodingUtils.cpp
index cae238d..73c141d 100644
--- a/llvm-external-projects/iree-dialects/lib/Dialect/LinalgExt/Utils/EncodingUtils.cpp
+++ b/llvm-external-projects/iree-dialects/lib/Dialect/LinalgExt/Utils/EncodingUtils.cpp
@@ -11,11 +11,21 @@
namespace IREE {
namespace LinalgExt {
-MaterializeEncodingInfo
-chooseEncodingInfoForMatmul(EncodingUser user, EncodingRole role,
- MatmulTileParams tileParams) {
- // Start dim of the MxK (LHS), KxN (RHS), or MxN (RESULT) 2D matrix.
- int64_t matmulDimBase = 0;
+bool isMatmulEncodingUser(EncodingUser user) {
+ switch (user) {
+ case EncodingUser::MATMUL_F32F32F32:
+ case EncodingUser::MATMUL_F16F16F32:
+ case EncodingUser::MATMUL_F16F16F16:
+ case EncodingUser::MATMUL_BF16BF16F32:
+ case EncodingUser::MATMUL_BF16BF16BF16:
+ case EncodingUser::MATMUL_I8I8I32:
+ return true;
+ default:
+ return false;
+ }
+}
+
+bool isBatchMatmulEncodingUser(EncodingUser user) {
switch (user) {
case EncodingUser::BATCH_MATMUL_F32F32F32:
case EncodingUser::BATCH_MATMUL_F16F16F32:
@@ -23,11 +33,17 @@
case EncodingUser::BATCH_MATMUL_BF16BF16F32:
case EncodingUser::BATCH_MATMUL_BF16BF16BF16:
case EncodingUser::BATCH_MATMUL_I8I8I32:
- matmulDimBase = 1;
- break;
+ return true;
default:
- break;
+ return false;
}
+}
+
+MaterializeEncodingInfo
+chooseEncodingInfoForMatmul(EncodingUser user, EncodingRole role,
+ MatmulTileParams tileParams) {
+ // Start dim of the MxK (LHS), KxN (RHS), or MxN (RESULT) 2D matrix.
+ int64_t matmulDimBase = isBatchMatmulEncodingUser(user) ? 1 : 0;
MaterializeEncodingInfo encodingInfo;
encodingInfo.innerDimsPos = {matmulDimBase, matmulDimBase + 1};
diff --git a/llvm-external-projects/iree-dialects/test/Dialect/iree_linalg_ext/materialize_encoding.mlir b/llvm-external-projects/iree-dialects/test/Dialect/iree_linalg_ext/materialize_encoding.mlir
index 7f757aa..d600287 100644
--- a/llvm-external-projects/iree-dialects/test/Dialect/iree_linalg_ext/materialize_encoding.mlir
+++ b/llvm-external-projects/iree-dialects/test/Dialect/iree_linalg_ext/materialize_encoding.mlir
@@ -240,3 +240,101 @@
// CHECK: %[[UNPACK_DEST:.+]] = tensor.empty(%[[D0]], %[[D1]], %[[D2]]) : tensor<?x?x?xf32>
// CHECK: %[[UNPACK:.+]] = tensor.unpack %[[PACK]] inner_dims_pos = [1, 2] inner_tiles = [8, 8] into %[[UNPACK_DEST]]
// CHECK: return %[[UNPACK]]
+
+// -----
+
+func.func @pack_batch_matmul(%arg0 : tensor<128x80x32xf32>, %arg1 : tensor<128x32x320xf32>, %arg2 : tensor<128x80x320xf32>) -> tensor<128x80x320xf32> {
+ %0 = iree_linalg_ext.set_encoding %arg0 : tensor<128x80x32xf32> -> tensor<128x80x32xf32, #iree_linalg_ext.encoding<user = BATCH_MATMUL_F32F32F32, role = LHS>>
+ %1 = iree_linalg_ext.set_encoding %arg1 : tensor<128x32x320xf32> -> tensor<128x32x320xf32, #iree_linalg_ext.encoding<user = BATCH_MATMUL_F32F32F32, role = RHS>>
+ %2 = iree_linalg_ext.set_encoding %arg2 : tensor<128x80x320xf32> -> tensor<128x80x320xf32, #iree_linalg_ext.encoding<user = BATCH_MATMUL_F32F32F32, role = RESULT>>
+ %3 = linalg.batch_matmul ins(%0, %1 : tensor<128x80x32xf32, #iree_linalg_ext.encoding<user = BATCH_MATMUL_F32F32F32, role = LHS>>, tensor<128x32x320xf32, #iree_linalg_ext.encoding<user = BATCH_MATMUL_F32F32F32, role = RHS>>)
+ outs(%2 : tensor<128x80x320xf32, #iree_linalg_ext.encoding<user = BATCH_MATMUL_F32F32F32, role = RESULT>>) -> tensor<128x80x320xf32, #iree_linalg_ext.encoding<user = BATCH_MATMUL_F32F32F32, role = RESULT>>
+ %4 = iree_linalg_ext.unset_encoding %3 : tensor<128x80x320xf32, #iree_linalg_ext.encoding<user = BATCH_MATMUL_F32F32F32, role = RESULT>> -> tensor<128x80x320xf32>
+ return %4 : tensor<128x80x320xf32>
+}
+// CHECK: func @pack_batch_matmul(
+// CHECK-SAME: %[[ARG0:[a-zA-Z0-9]+]]: tensor<128x80x32xf32>
+// CHECK-SAME: %[[ARG1:[a-zA-Z0-9]+]]: tensor<128x32x320xf32>
+// CHECK-SAME: %[[ARG2:[a-zA-Z0-9]+]]: tensor<128x80x320xf32>
+// CHECK: %[[PACK_LHS:.+]] = tensor.pack
+// CHECK-SAME: %[[ARG0]]
+// CHECK: %[[PACK_RHS:.+]] = tensor.pack
+// CHECK-SAME: %[[ARG1]]
+// CHECK: %[[PACK_RESULT:.+]] = tensor.pack
+// CHECK-SAME: %[[ARG2]]
+// CHECK: %[[BATCH_MMT4D:.+]] = linalg.batch_mmt4d
+// CHECK-SAME: ins(%[[PACK_LHS]], %[[PACK_RHS]] :
+// CHECK-SAME: outs(%[[PACK_RESULT]] :
+// CHECK: %[[UNPACK:.+]] = tensor.unpack %[[BATCH_MMT4D]]
+// CHECK: return %[[UNPACK]]
+
+// -----
+
+func.func @pack_batch_matmul_dynamic(%arg0 : tensor<?x?x?xf32>, %arg1 : tensor<?x?x?xf32>, %arg2 : tensor<?x?x?xf32>) -> tensor<?x?x?xf32> {
+ %0 = iree_linalg_ext.set_encoding %arg0 : tensor<?x?x?xf32> -> tensor<?x?x?xf32, #iree_linalg_ext.encoding<user = BATCH_MATMUL_F32F32F32, role = LHS>>
+ %1 = iree_linalg_ext.set_encoding %arg1 : tensor<?x?x?xf32> -> tensor<?x?x?xf32, #iree_linalg_ext.encoding<user = BATCH_MATMUL_F32F32F32, role = RHS>>
+ %2 = iree_linalg_ext.set_encoding %arg2 : tensor<?x?x?xf32> -> tensor<?x?x?xf32, #iree_linalg_ext.encoding<user = BATCH_MATMUL_F32F32F32, role = RESULT>>
+ %3 = linalg.batch_matmul ins(%0, %1 : tensor<?x?x?xf32, #iree_linalg_ext.encoding<user = BATCH_MATMUL_F32F32F32, role = LHS>>, tensor<?x?x?xf32, #iree_linalg_ext.encoding<user = BATCH_MATMUL_F32F32F32, role = RHS>>)
+ outs(%2 : tensor<?x?x?xf32, #iree_linalg_ext.encoding<user = BATCH_MATMUL_F32F32F32, role = RESULT>>) -> tensor<?x?x?xf32, #iree_linalg_ext.encoding<user = BATCH_MATMUL_F32F32F32, role = RESULT>>
+ %4 = iree_linalg_ext.unset_encoding %3 : tensor<?x?x?xf32, #iree_linalg_ext.encoding<user = BATCH_MATMUL_F32F32F32, role = RESULT>> -> tensor<?x?x?xf32>
+ return %4 : tensor<?x?x?xf32>
+}
+// CHECK: func @pack_batch_matmul_dynamic(
+// CHECK-SAME: %[[ARG0:[a-zA-Z0-9]+]]: tensor<?x?x?xf32>
+// CHECK-SAME: %[[ARG1:[a-zA-Z0-9]+]]: tensor<?x?x?xf32>
+// CHECK-SAME: %[[ARG2:[a-zA-Z0-9]+]]: tensor<?x?x?xf32>
+// CHECK: %[[PACK_LHS:.+]] = tensor.pack
+// CHECK-SAME: %[[ARG0]]
+// CHECK: %[[PACK_RHS:.+]] = tensor.pack
+// CHECK-SAME: %[[ARG1]]
+// CHECK: %[[PACK_RESULT:.+]] = tensor.pack
+// CHECK-SAME: %[[ARG2]]
+// CHECK: %[[BATCH_MMT4D:.+]] = linalg.batch_mmt4d
+// CHECK-SAME: ins(%[[PACK_LHS]], %[[PACK_RHS]] :
+// CHECK-SAME: outs(%[[PACK_RESULT]] :
+// CHECK: %[[UNPACK:.+]] = tensor.unpack %[[BATCH_MMT4D]]
+// CHECK: return %[[UNPACK]]
+
+// -----
+
+func.func @pack_batch_matmul_fill_dynamic(%arg0 : tensor<?x?x?xf32>, %arg1 : tensor<?x?x?xf32>) -> tensor<?x?x?xf32> {
+ %c0 = arith.constant 0 : index
+ %c1 = arith.constant 1 : index
+ %c2 = arith.constant 2 : index
+ %cst = arith.constant 0.0 : f32
+ %d0 = tensor.dim %arg0, %c0 : tensor<?x?x?xf32>
+ %d1 = tensor.dim %arg0, %c1 : tensor<?x?x?xf32>
+ %d2 = tensor.dim %arg1, %c2 : tensor<?x?x?xf32>
+ %0 = iree_linalg_ext.set_encoding %arg0 : tensor<?x?x?xf32> -> tensor<?x?x?xf32, #iree_linalg_ext.encoding<user = BATCH_MATMUL_F32F32F32, role = LHS>>
+ %1 = iree_linalg_ext.set_encoding %arg1 : tensor<?x?x?xf32> -> tensor<?x?x?xf32, #iree_linalg_ext.encoding<user = BATCH_MATMUL_F32F32F32, role = RHS>>
+ %2 = tensor.empty(%d0, %d1, %d2) : tensor<?x?x?xf32, #iree_linalg_ext.encoding<user = BATCH_MATMUL_F32F32F32, role = RESULT>>
+ %3 = linalg.fill ins(%cst : f32) outs(%2 : tensor<?x?x?xf32, #iree_linalg_ext.encoding<user = BATCH_MATMUL_F32F32F32, role = RESULT>>)
+ -> tensor<?x?x?xf32, #iree_linalg_ext.encoding<user = BATCH_MATMUL_F32F32F32, role = RESULT>>
+ %4 = linalg.batch_matmul ins(%0, %1 : tensor<?x?x?xf32, #iree_linalg_ext.encoding<user = BATCH_MATMUL_F32F32F32, role = LHS>>, tensor<?x?x?xf32, #iree_linalg_ext.encoding<user = BATCH_MATMUL_F32F32F32, role = RHS>>)
+ outs(%3 : tensor<?x?x?xf32, #iree_linalg_ext.encoding<user = BATCH_MATMUL_F32F32F32, role = RESULT>>) -> tensor<?x?x?xf32, #iree_linalg_ext.encoding<user = BATCH_MATMUL_F32F32F32, role = RESULT>>
+ %5 = iree_linalg_ext.unset_encoding %4 : tensor<?x?x?xf32, #iree_linalg_ext.encoding<user = BATCH_MATMUL_F32F32F32, role = RESULT>> -> tensor<?x?x?xf32>
+ return %5 : tensor<?x?x?xf32>
+}
+// CHECK-DAG: #[[MAP0:.+]] = affine_map<()[s0] -> (s0 ceildiv 8)>
+// CHECK-DAG: #[[MAP1:.+]] = affine_map<()[s0] -> (s0 ceildiv 4)>
+// CHECK: func @pack_batch_matmul_fill_dynamic(
+// CHECK-SAME: %[[ARG0:[a-zA-Z0-9]+]]: tensor<?x?x?xf32>
+// CHECK-SAME: %[[ARG1:[a-zA-Z0-9]+]]: tensor<?x?x?xf32>
+// CHECK-DAG: %[[C0:.+]] = arith.constant 0 : index
+// CHECK-DAG: %[[C1:.+]] = arith.constant 1 : index
+// CHECK-DAG: %[[C2:.+]] = arith.constant 2 : index
+// CHECK-DAG: %[[D0:.+]] = tensor.dim %[[ARG0]], %[[C0]]
+// CHECK-DAG: %[[D1:.+]] = tensor.dim %[[ARG0]], %[[C1]]
+// CHECK-DAG: %[[D2:.+]] = tensor.dim %[[ARG1]], %[[C2]]
+// CHECK-DAG: %[[OUT_D1:.+]] = affine.apply #[[MAP0]]()[%[[D1]]]
+// CHECK-DAG: %[[OUT_D2:.+]] = affine.apply #[[MAP0]]()[%[[D2]]]
+// CHECK-DAG: %[[PACK_LHS:.+]] = tensor.pack %[[ARG0]]
+// CHECK-DAG: %[[PACK_RHS:.+]] = tensor.pack %[[ARG1]]
+// CHECK-DAG: %[[EMPTY:.+]] = tensor.empty(%[[D0]], %[[OUT_D1]], %[[OUT_D2]]) : tensor<?x?x?x8x8xf32>
+// CHECK: %[[FILL:.+]] = linalg.fill
+// CHECK-SAME: outs(%[[EMPTY]] : tensor<?x?x?x8x8xf32>)
+// CHECK: %[[BATCH_MMT4D:.+]] = linalg.batch_mmt4d
+// CHECK-SAME: ins(%[[PACK_LHS]], %[[PACK_RHS]] :
+// CHECK-SAME: outs(%[[FILL]] :
+// CHECK: %[[UNPACK:.+]] = tensor.unpack %[[BATCH_MMT4D]]
+// CHECK: return %[[UNPACK]]