Materialize batch_matmul to batch_mmt4d (#14731)

Add a pattern to materialize `linalg.batch_matmul` ops carrying data-tiling
encodings to `linalg.batch_mmt4d`, along with `isMatmulEncodingUser` /
`isBatchMatmulEncodingUser` helpers so the matmul and batch_matmul lowerings
only match their respective encodings.

Tracking issue: #14431
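
For illustration, this is the shape of the conversion exercised by the new
`materialize_encoding.mlir` tests: a `linalg.batch_matmul` whose operands carry
`BATCH_MATMUL_F32F32F32` encodings is rewritten into `tensor.pack` of each
operand, a `linalg.batch_mmt4d` on the packed operands, and a `tensor.unpack`
of the result. This is only a sketch; the actual packed shapes depend on the
tile parameters chosen by the `MaterializeEncodingFn` for the target.

```mlir
// Input: batch_matmul on tensors annotated with data-tiling encodings
// (taken from the @pack_batch_matmul test below).
%3 = linalg.batch_matmul
    ins(%0, %1 : tensor<128x80x32xf32, #iree_linalg_ext.encoding<user = BATCH_MATMUL_F32F32F32, role = LHS>>,
                 tensor<128x32x320xf32, #iree_linalg_ext.encoding<user = BATCH_MATMUL_F32F32F32, role = RHS>>)
    outs(%2 : tensor<128x80x320xf32, #iree_linalg_ext.encoding<user = BATCH_MATMUL_F32F32F32, role = RESULT>>)
    -> tensor<128x80x320xf32, #iree_linalg_ext.encoding<user = BATCH_MATMUL_F32F32F32, role = RESULT>>

// After MaterializeEncoding (schematically):
//   %lhs = tensor.pack %arg0 ...
//   %rhs = tensor.pack %arg1 ...
//   %acc = tensor.pack %arg2 ...
//   %res = linalg.batch_mmt4d ins(%lhs, %rhs : ...) outs(%acc : ...)
//   ...  = tensor.unpack %res ...
```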
diff --git a/llvm-external-projects/iree-dialects/include/iree-dialects/Dialect/LinalgExt/Utils/EncodingUtils.h b/llvm-external-projects/iree-dialects/include/iree-dialects/Dialect/LinalgExt/Utils/EncodingUtils.h
index 700dce5..889e0a7 100644
--- a/llvm-external-projects/iree-dialects/include/iree-dialects/Dialect/LinalgExt/Utils/EncodingUtils.h
+++ b/llvm-external-projects/iree-dialects/include/iree-dialects/Dialect/LinalgExt/Utils/EncodingUtils.h
@@ -15,6 +15,12 @@
 namespace IREE {
 namespace LinalgExt {
 
+// Check if the encoding user is one of the (non-batch) matmul encodings.
+bool isMatmulEncodingUser(EncodingUser user);
+
+// Check if the encoding user is one of the batch_matmul encodings.
+bool isBatchMatmulEncodingUser(EncodingUser user);
+
 struct MatmulTileParams {
   int64_t M = 1;
   int64_t K = 1;
diff --git a/llvm-external-projects/iree-dialects/lib/Dialect/LinalgExt/Passes/MaterializeEncoding.cpp b/llvm-external-projects/iree-dialects/lib/Dialect/LinalgExt/Passes/MaterializeEncoding.cpp
index 769541b..342a483 100644
--- a/llvm-external-projects/iree-dialects/lib/Dialect/LinalgExt/Passes/MaterializeEncoding.cpp
+++ b/llvm-external-projects/iree-dialects/lib/Dialect/LinalgExt/Passes/MaterializeEncoding.cpp
@@ -248,7 +248,10 @@
   if (!lhsEncoding || !rhsEncoding || !resultEncoding) {
     return failure();
   }
-  if (lhsEncoding.getRole().getValue() !=
+  if (!isMatmulEncodingUser(lhsEncoding.getUser().getValue()) ||
+      !isMatmulEncodingUser(rhsEncoding.getUser().getValue()) ||
+      !isMatmulEncodingUser(resultEncoding.getUser().getValue()) ||
+      lhsEncoding.getRole().getValue() !=
           mlir::iree_compiler::IREE::LinalgExt::EncodingRole::LHS ||
       rhsEncoding.getRole().getValue() !=
           mlir::iree_compiler::IREE::LinalgExt::EncodingRole::RHS ||
@@ -262,8 +265,46 @@
   return mmt4DOp;
 }
 
-/// Utility method to convert from `linalg.fill` on `tensor` type with encoding
-/// to fill of the materialized type
+/// Utility method to convert from `linalg.batch_matmul` with
+/// - lhs encoding with user=BATCH_MATMUL_*, role=LHS
+/// - rhs encoding with user=BATCH_MATMUL_*, role=RHS
+/// - result encoding with user=BATCH_MATMUL_*, role=RESULT
+/// to linalg.batch_mmt4d op.
+static FailureOr<Operation *>
+lowerOpWithEncoding(RewriterBase &rewriter, linalg::BatchMatmulOp batchMatmulOp,
+                    ValueRange convertedInputOperands,
+                    ValueRange convertedOutputOperands, MaterializeEncodingFn,
+                    MaterializeEncodingValueFn) {
+  if (!batchMatmulOp.hasTensorSemantics())
+    return failure();
+  auto inputs = batchMatmulOp.getDpsInputOperands();
+  auto outputs = batchMatmulOp.getDpsInitOperands();
+  auto lhsEncoding =
+      getEncodingAttr(inputs[0]->get().getType().cast<RankedTensorType>());
+  auto rhsEncoding =
+      getEncodingAttr(inputs[1]->get().getType().cast<RankedTensorType>());
+  auto resultEncoding =
+      getEncodingAttr(outputs[0]->get().getType().cast<RankedTensorType>());
+  if (!lhsEncoding || !rhsEncoding || !resultEncoding) {
+    return failure();
+  }
+
+  if (!isBatchMatmulEncodingUser(lhsEncoding.getUser().getValue()) ||
+      !isBatchMatmulEncodingUser(rhsEncoding.getUser().getValue()) ||
+      !isBatchMatmulEncodingUser(resultEncoding.getUser().getValue()) ||
+      lhsEncoding.getRole().getValue() != EncodingRole::LHS ||
+      rhsEncoding.getRole().getValue() != EncodingRole::RHS ||
+      resultEncoding.getRole().getValue() != EncodingRole::RESULT) {
+    return failure();
+  }
+  Operation *batchMmt4DOp = rewriter.create<linalg::BatchMmt4DOp>(
+      batchMatmulOp.getLoc(), convertedOutputOperands[0].getType(),
+      convertedInputOperands, convertedOutputOperands);
+  return batchMmt4DOp;
+}
+
+/// Utility method to convert from `linalg.fill` on `tensor` type with
+/// encoding to fill of the materialized type
 static FailureOr<Operation *>
 lowerOpWithEncoding(RewriterBase &rewriter, linalg::FillOp fillOp,
                     ValueRange convertedInputOperands,
@@ -515,9 +556,11 @@
     MaterializeEncodingTypeConverter &typeConverter,
     MaterializeEncodingValueFn materializeEncodingValueFn) {
 
-  // Add all patterns for converting from encoded type to the materialized type
+  // Add all patterns for converting from encoded type to the materialized
+  // type
   patterns.insert<MaterializeDPSOperation<linalg::FillOp>,
                   MaterializeDPSOperation<linalg::MatmulOp>,
+                  MaterializeDPSOperation<linalg::BatchMatmulOp>,
                   MaterializeOperation<tensor::EmptyOp>,
                   SetEncodingOpToPackOpConversion,
                   UnsetEncodingOpToPackOpConversion>(
diff --git a/llvm-external-projects/iree-dialects/lib/Dialect/LinalgExt/Utils/EncodingUtils.cpp b/llvm-external-projects/iree-dialects/lib/Dialect/LinalgExt/Utils/EncodingUtils.cpp
index cae238d..73c141d 100644
--- a/llvm-external-projects/iree-dialects/lib/Dialect/LinalgExt/Utils/EncodingUtils.cpp
+++ b/llvm-external-projects/iree-dialects/lib/Dialect/LinalgExt/Utils/EncodingUtils.cpp
@@ -11,11 +11,21 @@
 namespace IREE {
 namespace LinalgExt {
 
-MaterializeEncodingInfo
-chooseEncodingInfoForMatmul(EncodingUser user, EncodingRole role,
-                            MatmulTileParams tileParams) {
-  // Start dim of the MxK (LHS), KxN (RHS), or MxN (RESULT) 2D matrix.
-  int64_t matmulDimBase = 0;
+bool isMatmulEncodingUser(EncodingUser user) {
+  switch (user) {
+  case EncodingUser::MATMUL_F32F32F32:
+  case EncodingUser::MATMUL_F16F16F32:
+  case EncodingUser::MATMUL_F16F16F16:
+  case EncodingUser::MATMUL_BF16BF16F32:
+  case EncodingUser::MATMUL_BF16BF16BF16:
+  case EncodingUser::MATMUL_I8I8I32:
+    return true;
+  default:
+    return false;
+  }
+}
+
+bool isBatchMatmulEncodingUser(EncodingUser user) {
   switch (user) {
   case EncodingUser::BATCH_MATMUL_F32F32F32:
   case EncodingUser::BATCH_MATMUL_F16F16F32:
@@ -23,11 +33,17 @@
   case EncodingUser::BATCH_MATMUL_BF16BF16F32:
   case EncodingUser::BATCH_MATMUL_BF16BF16BF16:
   case EncodingUser::BATCH_MATMUL_I8I8I32:
-    matmulDimBase = 1;
-    break;
+    return true;
   default:
-    break;
+    return false;
   }
+}
+
+MaterializeEncodingInfo
+chooseEncodingInfoForMatmul(EncodingUser user, EncodingRole role,
+                            MatmulTileParams tileParams) {
+  // Start dim of the MxK (LHS), KxN (RHS), or MxN (RESULT) 2D matrix.
+  int64_t matmulDimBase = isBatchMatmulEncodingUser(user) ? 1 : 0;
 
   MaterializeEncodingInfo encodingInfo;
   encodingInfo.innerDimsPos = {matmulDimBase, matmulDimBase + 1};
diff --git a/llvm-external-projects/iree-dialects/test/Dialect/iree_linalg_ext/materialize_encoding.mlir b/llvm-external-projects/iree-dialects/test/Dialect/iree_linalg_ext/materialize_encoding.mlir
index 7f757aa..d600287 100644
--- a/llvm-external-projects/iree-dialects/test/Dialect/iree_linalg_ext/materialize_encoding.mlir
+++ b/llvm-external-projects/iree-dialects/test/Dialect/iree_linalg_ext/materialize_encoding.mlir
@@ -240,3 +240,101 @@
 //      CHECK:   %[[UNPACK_DEST:.+]] = tensor.empty(%[[D0]], %[[D1]], %[[D2]]) : tensor<?x?x?xf32>
 //      CHECK:   %[[UNPACK:.+]] = tensor.unpack %[[PACK]] inner_dims_pos = [1, 2] inner_tiles = [8, 8] into %[[UNPACK_DEST]]
 //      CHECK:   return %[[UNPACK]]
+
+// -----
+
+func.func @pack_batch_matmul(%arg0 : tensor<128x80x32xf32>, %arg1 : tensor<128x32x320xf32>, %arg2 : tensor<128x80x320xf32>) -> tensor<128x80x320xf32> {
+  %0 = iree_linalg_ext.set_encoding %arg0 : tensor<128x80x32xf32> -> tensor<128x80x32xf32, #iree_linalg_ext.encoding<user = BATCH_MATMUL_F32F32F32, role = LHS>>
+  %1 = iree_linalg_ext.set_encoding %arg1 : tensor<128x32x320xf32> -> tensor<128x32x320xf32, #iree_linalg_ext.encoding<user = BATCH_MATMUL_F32F32F32, role = RHS>>
+  %2 = iree_linalg_ext.set_encoding %arg2 : tensor<128x80x320xf32> -> tensor<128x80x320xf32, #iree_linalg_ext.encoding<user = BATCH_MATMUL_F32F32F32, role = RESULT>>
+  %3 = linalg.batch_matmul ins(%0, %1 : tensor<128x80x32xf32, #iree_linalg_ext.encoding<user = BATCH_MATMUL_F32F32F32, role = LHS>>, tensor<128x32x320xf32, #iree_linalg_ext.encoding<user = BATCH_MATMUL_F32F32F32, role = RHS>>)
+      outs(%2 : tensor<128x80x320xf32, #iree_linalg_ext.encoding<user = BATCH_MATMUL_F32F32F32, role = RESULT>>) -> tensor<128x80x320xf32, #iree_linalg_ext.encoding<user = BATCH_MATMUL_F32F32F32, role = RESULT>>
+  %4 = iree_linalg_ext.unset_encoding %3 : tensor<128x80x320xf32, #iree_linalg_ext.encoding<user = BATCH_MATMUL_F32F32F32, role = RESULT>> -> tensor<128x80x320xf32>
+  return %4 : tensor<128x80x320xf32>
+}
+//      CHECK: func @pack_batch_matmul(
+// CHECK-SAME:     %[[ARG0:[a-zA-Z0-9]+]]: tensor<128x80x32xf32>
+// CHECK-SAME:     %[[ARG1:[a-zA-Z0-9]+]]: tensor<128x32x320xf32>
+// CHECK-SAME:     %[[ARG2:[a-zA-Z0-9]+]]: tensor<128x80x320xf32>
+//      CHECK:   %[[PACK_LHS:.+]] = tensor.pack
+// CHECK-SAME:     %[[ARG0]]
+//      CHECK:   %[[PACK_RHS:.+]] = tensor.pack
+// CHECK-SAME:     %[[ARG1]]
+//      CHECK:   %[[PACK_RESULT:.+]] = tensor.pack
+// CHECK-SAME:     %[[ARG2]]
+//      CHECK:   %[[BATCH_MMT4D:.+]] = linalg.batch_mmt4d
+// CHECK-SAME:       ins(%[[PACK_LHS]], %[[PACK_RHS]] :
+// CHECK-SAME:       outs(%[[PACK_RESULT]] :
+//      CHECK:   %[[UNPACK:.+]] = tensor.unpack %[[BATCH_MMT4D]]
+//      CHECK:   return %[[UNPACK]]
+
+// -----
+
+func.func @pack_batch_matmul_dynamic(%arg0 : tensor<?x?x?xf32>, %arg1 : tensor<?x?x?xf32>, %arg2 : tensor<?x?x?xf32>) -> tensor<?x?x?xf32> {
+  %0 = iree_linalg_ext.set_encoding %arg0 : tensor<?x?x?xf32> -> tensor<?x?x?xf32, #iree_linalg_ext.encoding<user = BATCH_MATMUL_F32F32F32, role = LHS>>
+  %1 = iree_linalg_ext.set_encoding %arg1 : tensor<?x?x?xf32> -> tensor<?x?x?xf32, #iree_linalg_ext.encoding<user = BATCH_MATMUL_F32F32F32, role = RHS>>
+  %2 = iree_linalg_ext.set_encoding %arg2 : tensor<?x?x?xf32> -> tensor<?x?x?xf32, #iree_linalg_ext.encoding<user = BATCH_MATMUL_F32F32F32, role = RESULT>>
+  %3 = linalg.batch_matmul ins(%0, %1 : tensor<?x?x?xf32, #iree_linalg_ext.encoding<user = BATCH_MATMUL_F32F32F32, role = LHS>>, tensor<?x?x?xf32, #iree_linalg_ext.encoding<user = BATCH_MATMUL_F32F32F32, role = RHS>>)
+      outs(%2 : tensor<?x?x?xf32, #iree_linalg_ext.encoding<user = BATCH_MATMUL_F32F32F32, role = RESULT>>) -> tensor<?x?x?xf32, #iree_linalg_ext.encoding<user = BATCH_MATMUL_F32F32F32, role = RESULT>>
+  %4 = iree_linalg_ext.unset_encoding %3 : tensor<?x?x?xf32, #iree_linalg_ext.encoding<user = BATCH_MATMUL_F32F32F32, role = RESULT>> -> tensor<?x?x?xf32>
+  return %4 : tensor<?x?x?xf32>
+}
+//      CHECK: func @pack_batch_matmul_dynamic(
+// CHECK-SAME:     %[[ARG0:[a-zA-Z0-9]+]]: tensor<?x?x?xf32>
+// CHECK-SAME:     %[[ARG1:[a-zA-Z0-9]+]]: tensor<?x?x?xf32>
+// CHECK-SAME:     %[[ARG2:[a-zA-Z0-9]+]]: tensor<?x?x?xf32>
+//      CHECK:   %[[PACK_LHS:.+]] = tensor.pack
+// CHECK-SAME:     %[[ARG0]]
+//      CHECK:   %[[PACK_RHS:.+]] = tensor.pack
+// CHECK-SAME:     %[[ARG1]]
+//      CHECK:   %[[PACK_RESULT:.+]] = tensor.pack
+// CHECK-SAME:     %[[ARG2]]
+//      CHECK:   %[[BATCH_MMT4D:.+]] = linalg.batch_mmt4d
+// CHECK-SAME:       ins(%[[PACK_LHS]], %[[PACK_RHS]] :
+// CHECK-SAME:       outs(%[[PACK_RESULT]] :
+//      CHECK:   %[[UNPACK:.+]] = tensor.unpack %[[BATCH_MMT4D]]
+//      CHECK:   return %[[UNPACK]]
+
+// -----
+
+func.func @pack_batch_matmul_fill_dynamic(%arg0 : tensor<?x?x?xf32>, %arg1 : tensor<?x?x?xf32>) -> tensor<?x?x?xf32> {
+  %c0 = arith.constant 0 : index
+  %c1 = arith.constant 1 : index
+  %c2 = arith.constant 2 : index
+  %cst = arith.constant 0.0 : f32
+  %d0 = tensor.dim %arg0, %c0 : tensor<?x?x?xf32>
+  %d1 = tensor.dim %arg0, %c1 : tensor<?x?x?xf32>
+  %d2 = tensor.dim %arg1, %c2 : tensor<?x?x?xf32>
+  %0 = iree_linalg_ext.set_encoding %arg0 : tensor<?x?x?xf32> -> tensor<?x?x?xf32, #iree_linalg_ext.encoding<user = BATCH_MATMUL_F32F32F32, role = LHS>>
+  %1 = iree_linalg_ext.set_encoding %arg1 : tensor<?x?x?xf32> -> tensor<?x?x?xf32, #iree_linalg_ext.encoding<user = BATCH_MATMUL_F32F32F32, role = RHS>>
+  %2 = tensor.empty(%d0, %d1, %d2) : tensor<?x?x?xf32, #iree_linalg_ext.encoding<user = BATCH_MATMUL_F32F32F32, role = RESULT>>
+  %3 = linalg.fill ins(%cst : f32) outs(%2 : tensor<?x?x?xf32, #iree_linalg_ext.encoding<user = BATCH_MATMUL_F32F32F32, role = RESULT>>)
+      -> tensor<?x?x?xf32, #iree_linalg_ext.encoding<user = BATCH_MATMUL_F32F32F32, role = RESULT>>
+  %4 = linalg.batch_matmul ins(%0, %1 : tensor<?x?x?xf32, #iree_linalg_ext.encoding<user = BATCH_MATMUL_F32F32F32, role = LHS>>, tensor<?x?x?xf32, #iree_linalg_ext.encoding<user = BATCH_MATMUL_F32F32F32, role = RHS>>)
+      outs(%3 : tensor<?x?x?xf32, #iree_linalg_ext.encoding<user = BATCH_MATMUL_F32F32F32, role = RESULT>>) -> tensor<?x?x?xf32, #iree_linalg_ext.encoding<user = BATCH_MATMUL_F32F32F32, role = RESULT>>
+  %5 = iree_linalg_ext.unset_encoding %4 : tensor<?x?x?xf32, #iree_linalg_ext.encoding<user = BATCH_MATMUL_F32F32F32, role = RESULT>> -> tensor<?x?x?xf32>
+  return %5 : tensor<?x?x?xf32>
+}
+//  CHECK-DAG: #[[MAP0:.+]] = affine_map<()[s0] -> (s0 ceildiv 8)>
+//  CHECK-DAG: #[[MAP1:.+]] = affine_map<()[s0] -> (s0 ceildiv 4)>
+//      CHECK: func @pack_batch_matmul_fill_dynamic(
+// CHECK-SAME:     %[[ARG0:[a-zA-Z0-9]+]]: tensor<?x?x?xf32>
+// CHECK-SAME:     %[[ARG1:[a-zA-Z0-9]+]]: tensor<?x?x?xf32>
+//  CHECK-DAG:   %[[C0:.+]] = arith.constant 0 : index
+//  CHECK-DAG:   %[[C1:.+]] = arith.constant 1 : index
+//  CHECK-DAG:   %[[C2:.+]] = arith.constant 2 : index
+//  CHECK-DAG:   %[[D0:.+]] = tensor.dim %[[ARG0]], %[[C0]]
+//  CHECK-DAG:   %[[D1:.+]] = tensor.dim %[[ARG0]], %[[C1]]
+//  CHECK-DAG:   %[[D2:.+]] = tensor.dim %[[ARG1]], %[[C2]]
+//  CHECK-DAG:   %[[OUT_D1:.+]] = affine.apply #[[MAP0]]()[%[[D1]]]
+//  CHECK-DAG:   %[[OUT_D2:.+]] = affine.apply #[[MAP0]]()[%[[D2]]]
+//  CHECK-DAG:   %[[PACK_LHS:.+]] = tensor.pack %[[ARG0]]
+//  CHECK-DAG:   %[[PACK_RHS:.+]] = tensor.pack %[[ARG1]]
+//  CHECK-DAG:   %[[EMPTY:.+]] = tensor.empty(%[[D0]], %[[OUT_D1]], %[[OUT_D2]]) : tensor<?x?x?x8x8xf32>
+//      CHECK:   %[[FILL:.+]] = linalg.fill
+// CHECK-SAME:     outs(%[[EMPTY]] : tensor<?x?x?x8x8xf32>)
+//      CHECK:   %[[BATCH_MMT4D:.+]] = linalg.batch_mmt4d
+// CHECK-SAME:       ins(%[[PACK_LHS]], %[[PACK_RHS]] :
+// CHECK-SAME:       outs(%[[FILL]] :
+//      CHECK:   %[[UNPACK:.+]] = tensor.unpack %[[BATCH_MMT4D]]
+//      CHECK:   return %[[UNPACK]]