[Flow] Raise batch_matmul(a, transpose(b)) to batch_matmul_transpose_b (#14847)
Adds a similar raising pattern as that for matmul(a, transpose(b)).
diff --git a/compiler/src/iree/compiler/Dialect/Flow/Transforms/RaiseSpecialOps.cpp b/compiler/src/iree/compiler/Dialect/Flow/Transforms/RaiseSpecialOps.cpp
index 8d4c4d9..03fa3f8 100644
--- a/compiler/src/iree/compiler/Dialect/Flow/Transforms/RaiseSpecialOps.cpp
+++ b/compiler/src/iree/compiler/Dialect/Flow/Transforms/RaiseSpecialOps.cpp
@@ -28,23 +28,30 @@
namespace {
-// Method to match a transpose operation.
-static bool match2DTranspose(linalg::LinalgOp genericOp) {
+// Method to match a transpose operation on the two most minor dimensions of the
+// specified rank.
+static bool matchInner2DTranspose(linalg::LinalgOp genericOp, unsigned rank) {
+ // Only makes sense for minimum rank 2.
+ if (rank < 2) {
+ return false;
+ }
if (genericOp.getNumDpsInputs() != 1 || genericOp.getNumDpsInits() != 1) {
return false;
}
- // Check only for 2D ops.
- if (genericOp.getNumLoops() != 2 ||
+ // Check only for ops of the specified rank.
+ if (genericOp.getNumLoops() != rank ||
genericOp.getNumLoops() != genericOp.getNumParallelLoops()) {
return false;
}
// Check for transpose map.
- AffineExpr d0, d1;
+ SmallVector<AffineExpr> exprList(rank);
MLIRContext *context = genericOp.getContext();
- bindDims(context, d0, d1);
+ bindDimsList(context, MutableArrayRef{exprList});
+ SmallVector<AffineExpr> transposeExprList(exprList);
+ std::swap(transposeExprList[rank - 1], transposeExprList[rank - 2]);
SmallVector<AffineMap> expectedMaps = {
- AffineMap::get(2, 0, {d0, d1}, context),
- AffineMap::get(2, 0, {d1, d0}, context)};
+ AffineMap::get(rank, 0, exprList, context),
+ AffineMap::get(rank, 0, transposeExprList, context)};
if (genericOp.getIndexingMapsArray() != expectedMaps) {
return false;
}
@@ -70,7 +77,21 @@
}
auto rhs = matmulOp.getDpsInputOperand(1);
auto genericOp = rhs->get().getDefiningOp<linalg::GenericOp>();
- if (genericOp && match2DTranspose(genericOp)) {
+ if (genericOp && matchInner2DTranspose(genericOp, 2)) {
+ return genericOp.getDpsInputOperand(0)->get();
+ }
+ return std::nullopt;
+}
+
+// Method to match a linalg.batch_matmul(a, linalg.transpose(b)). Returns `b` on
+// success.
+std::optional<Value> matchATransposeBBatchMatmul(linalg::LinalgOp bmmOp) {
+ if (!isa<linalg::BatchMatmulOp>(bmmOp.getOperation())) {
+ return std::nullopt;
+ }
+ auto rhs = bmmOp.getDpsInputOperand(1);
+ auto genericOp = rhs->get().getDefiningOp<linalg::GenericOp>();
+ if (genericOp && matchInner2DTranspose(genericOp, 3)) {
return genericOp.getDpsInputOperand(0)->get();
}
return std::nullopt;
@@ -361,6 +382,8 @@
SmallVector<std::pair<linalg::LinalgOp, Value>> softmaxRoots;
SmallVector<std::pair<linalg::MatmulOp, Value>> transposeMatmulRoots;
+ SmallVector<std::pair<linalg::BatchMatmulOp, Value>>
+ transposeBatchMatmulRoots;
SmallVector<std::pair<linalg::GenericOp, Value>> genericFills;
getOperation()->walk([&](linalg::LinalgOp op) {
{
@@ -376,6 +399,10 @@
transposeMatmulRoots.push_back(std::make_pair(
cast<linalg::MatmulOp>(op.getOperation()), newRhs.value()));
}
+ if (std::optional<Value> newRhs = matchATransposeBBatchMatmul(op)) {
+ transposeBatchMatmulRoots.push_back(std::make_pair(
+ cast<linalg::BatchMatmulOp>(op.getOperation()), newRhs.value()));
+ }
if (std::optional<Value> fillInput = matchGenericFill(op)) {
genericFills.push_back(
std::make_pair(cast<linalg::GenericOp>(op), fillInput.value()));
@@ -402,6 +429,17 @@
rewriter.replaceOpWithNewOp<linalg::MatmulTransposeBOp>(
matmulOp, ValueRange{lhs, newRhs}, ValueRange{init}, attrs);
}
+ for (std::pair<linalg::BatchMatmulOp, Value> aTransposeBBatchMatmul :
+ transposeBatchMatmulRoots) {
+ auto bmmOp = aTransposeBBatchMatmul.first;
+ Value lhs = bmmOp.getDpsInputOperand(0)->get();
+ auto newRhs = aTransposeBBatchMatmul.second;
+ Value init = bmmOp.getDpsInitOperand(0)->get();
+ rewriter.setInsertionPoint(bmmOp);
+ SmallVector<NamedAttribute> attrs = getPrunedAttributeList(bmmOp);
+ rewriter.replaceOpWithNewOp<linalg::BatchMatmulTransposeBOp>(
+ bmmOp, ValueRange{lhs, newRhs}, ValueRange{init}, attrs);
+ }
for (std::pair<linalg::GenericOp, Value> genericFill : genericFills) {
auto genericOp = genericFill.first;
Value fillInput = genericFill.second;
diff --git a/compiler/src/iree/compiler/Dialect/Flow/Transforms/test/raise_special_ops.mlir b/compiler/src/iree/compiler/Dialect/Flow/Transforms/test/raise_special_ops.mlir
index 54835d5..5e238aa 100644
--- a/compiler/src/iree/compiler/Dialect/Flow/Transforms/test/raise_special_ops.mlir
+++ b/compiler/src/iree/compiler/Dialect/Flow/Transforms/test/raise_special_ops.mlir
@@ -187,6 +187,30 @@
// CHECK-SAME: ins(%[[ARG0]], %[[ARG1]] :
// CHECK: return %[[RESULT]]
+func.func @aTransposeBBatchMatmul(%arg0 : tensor<5x10x20xf32>,
+ %arg1 : tensor<5x40x20xf32>) -> tensor<5x10x40xf32> {
+ %0 = tensor.empty() : tensor<5x20x40xf32>
+ %1 = linalg.generic {
+ indexing_maps = [affine_map<(d0, d1, d2) -> (d0, d1, d2)>, affine_map<(d0, d1, d2) -> (d0, d2, d1)>],
+ iterator_types = ["parallel", "parallel", "parallel"]}
+ ins(%arg1 : tensor<5x40x20xf32>) outs(%0 : tensor<5x20x40xf32>) {
+ ^bb0(%b0 : f32, %b1 : f32):
+ linalg.yield %b0 : f32
+ } -> tensor<5x20x40xf32>
+ %2 = tensor.empty() : tensor<5x10x40xf32>
+ %3 = arith.constant 0.0 : f32
+ %4 = linalg.fill ins(%3 : f32) outs(%2 : tensor<5x10x40xf32>) -> tensor<5x10x40xf32>
+ %5 = linalg.batch_matmul ins(%arg0, %1 : tensor<5x10x20xf32>, tensor<5x20x40xf32>)
+ outs(%4 : tensor<5x10x40xf32>) -> tensor<5x10x40xf32>
+ return %5 : tensor<5x10x40xf32>
+}
+// CHECK-LABEL: func @aTransposeBBatchMatmul
+// CHECK-SAME: %[[ARG0:.+]]: tensor<5x10x20xf32>
+// CHECK-SAME: %[[ARG1:.+]]: tensor<5x40x20xf32>
+// CHECK: %[[RESULT:.+]] = linalg.batch_matmul_transpose_b
+// CHECK-SAME: ins(%[[ARG0]], %[[ARG1]] :
+// CHECK: return %[[RESULT]]
+
func.func @generic_fill(%arg0: tensor<?x?xf32>) -> tensor<1x1x?x?xf32> {
%cst = arith.constant 0.000000e+00 : f32
%c0 = arith.constant 0 : index