[Flow] Raise batch_matmul(a, transpose(b)) to batch_matmul_transpose_b (#14847)

Adds a similar raising pattern as that for matmul(a, transpose(b)).
diff --git a/compiler/src/iree/compiler/Dialect/Flow/Transforms/RaiseSpecialOps.cpp b/compiler/src/iree/compiler/Dialect/Flow/Transforms/RaiseSpecialOps.cpp
index 8d4c4d9..03fa3f8 100644
--- a/compiler/src/iree/compiler/Dialect/Flow/Transforms/RaiseSpecialOps.cpp
+++ b/compiler/src/iree/compiler/Dialect/Flow/Transforms/RaiseSpecialOps.cpp
@@ -28,23 +28,30 @@
 
 namespace {
 
-// Method to match a transpose operation.
-static bool match2DTranspose(linalg::LinalgOp genericOp) {
+// Method to match a transpose operation on the two most minor dimensions of the
+// specified rank.
+static bool matchInner2DTranspose(linalg::LinalgOp genericOp, unsigned rank) {
+  // Only makes sense for minimum rank 2.
+  if (rank < 2) {
+    return false;
+  }
   if (genericOp.getNumDpsInputs() != 1 || genericOp.getNumDpsInits() != 1) {
     return false;
   }
-  // Check only for 2D ops.
-  if (genericOp.getNumLoops() != 2 ||
+  // Check only for ops of the specified rank.
+  if (genericOp.getNumLoops() != rank ||
       genericOp.getNumLoops() != genericOp.getNumParallelLoops()) {
     return false;
   }
   // Check for transpose map.
-  AffineExpr d0, d1;
+  SmallVector<AffineExpr> exprList(rank);
   MLIRContext *context = genericOp.getContext();
-  bindDims(context, d0, d1);
+  bindDimsList(context, MutableArrayRef{exprList});
+  SmallVector<AffineExpr> transposeExprList(exprList);
+  std::swap(transposeExprList[rank - 1], transposeExprList[rank - 2]);
   SmallVector<AffineMap> expectedMaps = {
-      AffineMap::get(2, 0, {d0, d1}, context),
-      AffineMap::get(2, 0, {d1, d0}, context)};
+      AffineMap::get(rank, 0, exprList, context),
+      AffineMap::get(rank, 0, transposeExprList, context)};
   if (genericOp.getIndexingMapsArray() != expectedMaps) {
     return false;
   }
@@ -70,7 +77,21 @@
   }
   auto rhs = matmulOp.getDpsInputOperand(1);
   auto genericOp = rhs->get().getDefiningOp<linalg::GenericOp>();
-  if (genericOp && match2DTranspose(genericOp)) {
+  if (genericOp && matchInner2DTranspose(genericOp, 2)) {
+    return genericOp.getDpsInputOperand(0)->get();
+  }
+  return std::nullopt;
+}
+
+// Method to match a linalg.batch_matmul(a, linalg.transpose(b)). Returns `b` on
+// success.
+std::optional<Value> matchATransposeBBatchMatmul(linalg::LinalgOp bmmOp) {
+  if (!isa<linalg::BatchMatmulOp>(bmmOp.getOperation())) {
+    return std::nullopt;
+  }
+  auto rhs = bmmOp.getDpsInputOperand(1);
+  auto genericOp = rhs->get().getDefiningOp<linalg::GenericOp>();
+  if (genericOp && matchInner2DTranspose(genericOp, 3)) {
     return genericOp.getDpsInputOperand(0)->get();
   }
   return std::nullopt;
@@ -361,6 +382,8 @@
 
     SmallVector<std::pair<linalg::LinalgOp, Value>> softmaxRoots;
     SmallVector<std::pair<linalg::MatmulOp, Value>> transposeMatmulRoots;
+    SmallVector<std::pair<linalg::BatchMatmulOp, Value>>
+        transposeBatchMatmulRoots;
     SmallVector<std::pair<linalg::GenericOp, Value>> genericFills;
     getOperation()->walk([&](linalg::LinalgOp op) {
       {
@@ -376,6 +399,10 @@
           transposeMatmulRoots.push_back(std::make_pair(
               cast<linalg::MatmulOp>(op.getOperation()), newRhs.value()));
         }
+        if (std::optional<Value> newRhs = matchATransposeBBatchMatmul(op)) {
+          transposeBatchMatmulRoots.push_back(std::make_pair(
+              cast<linalg::BatchMatmulOp>(op.getOperation()), newRhs.value()));
+        }
         if (std::optional<Value> fillInput = matchGenericFill(op)) {
           genericFills.push_back(
               std::make_pair(cast<linalg::GenericOp>(op), fillInput.value()));
@@ -402,6 +429,17 @@
       rewriter.replaceOpWithNewOp<linalg::MatmulTransposeBOp>(
           matmulOp, ValueRange{lhs, newRhs}, ValueRange{init}, attrs);
     }
+    for (std::pair<linalg::BatchMatmulOp, Value> aTransposeBBatchMatmul :
+         transposeBatchMatmulRoots) {
+      auto bmmOp = aTransposeBBatchMatmul.first;
+      Value lhs = bmmOp.getDpsInputOperand(0)->get();
+      auto newRhs = aTransposeBBatchMatmul.second;
+      Value init = bmmOp.getDpsInitOperand(0)->get();
+      rewriter.setInsertionPoint(bmmOp);
+      SmallVector<NamedAttribute> attrs = getPrunedAttributeList(bmmOp);
+      rewriter.replaceOpWithNewOp<linalg::BatchMatmulTransposeBOp>(
+          bmmOp, ValueRange{lhs, newRhs}, ValueRange{init}, attrs);
+    }
     for (std::pair<linalg::GenericOp, Value> genericFill : genericFills) {
       auto genericOp = genericFill.first;
       Value fillInput = genericFill.second;
diff --git a/compiler/src/iree/compiler/Dialect/Flow/Transforms/test/raise_special_ops.mlir b/compiler/src/iree/compiler/Dialect/Flow/Transforms/test/raise_special_ops.mlir
index 54835d5..5e238aa 100644
--- a/compiler/src/iree/compiler/Dialect/Flow/Transforms/test/raise_special_ops.mlir
+++ b/compiler/src/iree/compiler/Dialect/Flow/Transforms/test/raise_special_ops.mlir
@@ -187,6 +187,30 @@
 //  CHECK-SAME:       ins(%[[ARG0]], %[[ARG1]] :
 //       CHECK:   return %[[RESULT]]
 
+func.func @aTransposeBBatchMatmul(%arg0 : tensor<5x10x20xf32>,
+    %arg1 : tensor<5x40x20xf32>) -> tensor<5x10x40xf32> {
+  %0 = tensor.empty() : tensor<5x20x40xf32>
+  %1 = linalg.generic {
+      indexing_maps = [affine_map<(d0, d1, d2) -> (d0, d1, d2)>, affine_map<(d0, d1, d2) -> (d0, d2, d1)>],
+      iterator_types = ["parallel", "parallel", "parallel"]}
+      ins(%arg1 : tensor<5x40x20xf32>) outs(%0 : tensor<5x20x40xf32>) {
+    ^bb0(%b0 : f32, %b1 : f32):
+      linalg.yield %b0 : f32
+  } -> tensor<5x20x40xf32>
+  %2 = tensor.empty() : tensor<5x10x40xf32>
+  %3 = arith.constant 0.0 : f32
+  %4 = linalg.fill ins(%3 : f32) outs(%2 : tensor<5x10x40xf32>) -> tensor<5x10x40xf32>
+  %5 = linalg.batch_matmul ins(%arg0, %1 : tensor<5x10x20xf32>, tensor<5x20x40xf32>)
+      outs(%4 : tensor<5x10x40xf32>) -> tensor<5x10x40xf32>
+  return %5 : tensor<5x10x40xf32>
+}
+// CHECK-LABEL: func @aTransposeBBatchMatmul
+//  CHECK-SAME:     %[[ARG0:.+]]: tensor<5x10x20xf32>
+//  CHECK-SAME:     %[[ARG1:.+]]: tensor<5x40x20xf32>
+//       CHECK:   %[[RESULT:.+]] = linalg.batch_matmul_transpose_b
+//  CHECK-SAME:       ins(%[[ARG0]], %[[ARG1]] :
+//       CHECK:   return %[[RESULT]]
+
 func.func @generic_fill(%arg0: tensor<?x?xf32>) -> tensor<1x1x?x?xf32> {
     %cst = arith.constant 0.000000e+00 : f32
     %c0 = arith.constant 0 : index