Add a second code sequence to match softmax (#12104)
diff --git a/compiler/src/iree/compiler/Dialect/Flow/Transforms/RaiseSpecialOps.cpp b/compiler/src/iree/compiler/Dialect/Flow/Transforms/RaiseSpecialOps.cpp
index a71b92b..8181f2b 100644
--- a/compiler/src/iree/compiler/Dialect/Flow/Transforms/RaiseSpecialOps.cpp
+++ b/compiler/src/iree/compiler/Dialect/Flow/Transforms/RaiseSpecialOps.cpp
@@ -35,20 +35,27 @@
void runOnOperation() override {
SmallVector<std::pair<linalg::LinalgOp, Value>> softmaxRoots;
getOperation()->walk([&](linalg::LinalgOp op) {
- StructuredOpMatcher reduction, fill, leading, trailing;
- transform_ext::StructuredOpMatcher fillMinusInf;
- transform_ext::StructuredOpMatcher maxReduction;
- transform_ext::StructuredOpMatcher sub;
- transform_ext::StructuredOpMatcher expOperand;
- transform_ext::StructuredOpMatcher fillzero;
- transform_ext::StructuredOpMatcher sum;
- transform_ext::StructuredOpMatcher divOperand;
- transform_ext::StructuredOpMatcher softmaxroot;
- makeSoftmaxMatcher(fillMinusInf, maxReduction, sub, expOperand, fillzero,
- sum, divOperand, softmaxroot);
- if (matchPattern(op, softmaxroot)) {
- Value src = maxReduction.getCaptured()->getOperand(0);
- softmaxRoots.push_back(std::make_pair(op, src));
+ {
+ transform_ext::StructuredOpMatcher fillMinusInf;
+ transform_ext::StructuredOpMatcher maxReduction;
+ transform_ext::StructuredOpMatcher maybeBroadcastMax;
+ transform_ext::StructuredOpMatcher sub;
+ transform_ext::StructuredOpMatcher broadcastedSub;
+ transform_ext::StructuredOpMatcher expOperand;
+ transform_ext::StructuredOpMatcher fillzero;
+ transform_ext::StructuredOpMatcher sum;
+ transform_ext::StructuredOpMatcher maybeBroadcastSum;
+ transform_ext::StructuredOpMatcher rcpOperand;
+ transform_ext::StructuredOpMatcher matmulOperand;
+ transform_ext::StructuredOpMatcher divOperand;
+ transform_ext::StructuredOpMatcher softmaxroot;
+ makeSoftmaxMatcher(fillMinusInf, maxReduction, sub, expOperand,
+ fillzero, sum, rcpOperand, matmulOperand, divOperand,
+ softmaxroot);
+ if (matchPattern(op, softmaxroot)) {
+ Value src = maxReduction.getCaptured()->getOperand(0);
+ softmaxRoots.push_back(std::make_pair(op, src));
+ }
}
});
for (std::pair<linalg::LinalgOp, Value> softmax : softmaxRoots) {
diff --git a/compiler/src/iree/compiler/Dialect/Flow/Transforms/test/raise_special_ops.mlir b/compiler/src/iree/compiler/Dialect/Flow/Transforms/test/raise_special_ops.mlir
index b93e9e8..5300cd7 100644
--- a/compiler/src/iree/compiler/Dialect/Flow/Transforms/test/raise_special_ops.mlir
+++ b/compiler/src/iree/compiler/Dialect/Flow/Transforms/test/raise_special_ops.mlir
@@ -52,3 +52,59 @@
} -> tensor<?x?x?xf32>
return %10 : tensor<?x?x?xf32>
}
+
+// CHECK-LABEL: @softmax_no_rcp
+// CHECK-SAME: %[[ARG:.+]]: tensor<10x4096x4096xf16>
+// CHECK: %[[E:.+]] = tensor.empty() : tensor<10x4096x4096xf16>
+// CHECK: %[[S:.+]] = iree_linalg_ext.softmax dimension(2) ins(%[[ARG]] : tensor<10x4096x4096xf16>) outs(%[[E]] : tensor<10x4096x4096xf16>) -> tensor<10x4096x4096xf16>
+// CHECK: return %[[S]] : tensor<10x4096x4096xf16>
+func.func @softmax_no_rcp(%src : tensor<10x4096x4096xf16>) -> (tensor<10x4096x4096xf16>) {
+ %cst_158 = arith.constant -6.550400e+04 : f16
+ %cst_121 = arith.constant 0.000000e+00 : f16
+ %224 = tensor.empty() : tensor<10x4096xf16>
+ %216 = tensor.empty() : tensor<10x4096x4096xf16>
+ %225 = linalg.fill ins(%cst_158 : f16) outs(%224 : tensor<10x4096xf16>) -> tensor<10x4096xf16>
+ %226 = linalg.generic {indexing_maps = [affine_map<(d0, d1, d2) -> (d0, d1, d2)>, affine_map<(d0, d1, d2) -> (d0, d1)>], iterator_types = ["parallel", "parallel", "reduction"]} ins(%src : tensor<10x4096x4096xf16>) outs(%225 : tensor<10x4096xf16>) {
+ ^bb0(%in: f16, %out: f16):
+ %5290 = arith.maxf %in, %out : f16
+ linalg.yield %5290 : f16
+ } -> tensor<10x4096xf16>
+ %227 = linalg.generic
+ {indexing_maps = [
+ affine_map<(d0, d1, d2) -> (d0, d1, d2)>,
+ affine_map<(d0, d1, d2) -> (d0, d1)>,
+ affine_map<(d0, d1, d2) -> (d0, d1, d2)>],
+ iterator_types = ["parallel", "parallel", "parallel"]}
+ ins(%src, %226 : tensor<10x4096x4096xf16>, tensor<10x4096xf16>) outs(%216 : tensor<10x4096x4096xf16>) {
+ ^bb0(%in: f16, %in_1572: f16, %out: f16):
+ %5290 = arith.subf %in, %in_1572 : f16
+ linalg.yield %5290 : f16
+ } -> tensor<10x4096x4096xf16>
+ %228 = linalg.generic {indexing_maps = [affine_map<(d0, d1, d2) -> (d0, d1, d2)>, affine_map<(d0, d1, d2) -> (d0, d1, d2)>], iterator_types = ["parallel", "parallel", "parallel"]} ins(%227 : tensor<10x4096x4096xf16>) outs(%216 : tensor<10x4096x4096xf16>) {
+ ^bb0(%in: f16, %out: f16):
+ %5290 = math.exp %in : f16
+ linalg.yield %5290 : f16
+ } -> tensor<10x4096x4096xf16>
+ %229 = tensor.empty() : tensor<10x4096xf16>
+ %230 = linalg.fill ins(%cst_121 : f16) outs(%229 : tensor<10x4096xf16>) -> tensor<10x4096xf16>
+ %231 = linalg.generic
+ {indexing_maps = [affine_map<(d0, d1, d2) -> (d0, d1, d2)>, affine_map<(d0, d1, d2) -> (d0, d1)>],
+ iterator_types = ["parallel", "parallel", "reduction"]}
+ ins(%228 : tensor<10x4096x4096xf16>) outs(%230 : tensor<10x4096xf16>) {
+ ^bb0(%in: f16, %out: f16):
+ %5290 = arith.addf %in, %out : f16
+ linalg.yield %5290 : f16
+ } -> tensor<10x4096xf16>
+ %232 = linalg.generic
+ {indexing_maps = [
+ affine_map<(d0, d1, d2) -> (d0, d1, d2)>,
+ affine_map<(d0, d1, d2) -> (d0, d1)>,
+ affine_map<(d0, d1, d2) -> (d0, d1, d2)>],
+ iterator_types = ["parallel", "parallel", "parallel"]}
+ ins(%228, %231 : tensor<10x4096x4096xf16>, tensor<10x4096xf16>) outs(%216 : tensor<10x4096x4096xf16>) {
+ ^bb0(%in: f16, %in_1572: f16, %out: f16):
+ %5290 = arith.divf %in, %in_1572 : f16
+ linalg.yield %5290 : f16
+ } -> tensor<10x4096x4096xf16>
+ return %232 : tensor<10x4096x4096xf16>
+}
diff --git a/llvm-external-projects/iree-dialects/include/iree-dialects/Transforms/TransformMatchers.h b/llvm-external-projects/iree-dialects/include/iree-dialects/Transforms/TransformMatchers.h
index 61870c4..9afc0f6 100644
--- a/llvm-external-projects/iree-dialects/include/iree-dialects/Transforms/TransformMatchers.h
+++ b/llvm-external-projects/iree-dialects/include/iree-dialects/Transforms/TransformMatchers.h
@@ -237,6 +237,9 @@
});
}
+  /// Matches a structured operation if either pattern A or pattern B matches.
+ StructuredOpMatcher(StructuredOpMatcher &A, StructuredOpMatcher &B);
+
/// Matches the given operation, hook for `matchPattern`.
bool match(Operation *op);
@@ -515,6 +518,11 @@
/// Creates a matcher of an arbitrary structured op.
inline StructuredOpMatcher m_StructuredOp() { return StructuredOpMatcher(); }
+inline StructuredOpMatcher m_StructuredOp_Or(StructuredOpMatcher &A,
+ StructuredOpMatcher &B) {
+ return StructuredOpMatcher(A, B);
+}
+
/// Creates a matcher of a structured op with kinds provided as template
/// arguments.
template <typename... OpType>
@@ -633,21 +641,22 @@
StructuredOpMatcher &trailing,
MatchedReductionCaptures &captures);
-/// Create a group of matchers for a sequence of operations matching exactly a
-/// softmax operation.
+/// Create a group of matchers for two code sequences, either of which
+/// matches exactly a softmax operation.
///
/// %red = reduce_max(%0)
/// %sub = sub(%0, %red)
/// %exp = exp(%sub)
/// %sum = reduce_sum(%exp)
-/// %rec = reciprocal(%sum)
-/// %mul = mul(%exp, %rec)
+/// %div = div(%exp, %sum)
void makeSoftmaxMatcher(transform_ext::StructuredOpMatcher &fillMinusInf,
transform_ext::StructuredOpMatcher &maxReduction,
transform_ext::StructuredOpMatcher &sub,
transform_ext::StructuredOpMatcher &expOperand,
- transform_ext::StructuredOpMatcher &fillzero,
+ transform_ext::StructuredOpMatcher &fillZero,
transform_ext::StructuredOpMatcher &sum,
+ transform_ext::StructuredOpMatcher &rcpOperand,
+ transform_ext::StructuredOpMatcher &mulOperand,
transform_ext::StructuredOpMatcher &divOperand,
transform_ext::StructuredOpMatcher &softmaxroot);
diff --git a/llvm-external-projects/iree-dialects/lib/Transforms/TransformMatchers.cpp b/llvm-external-projects/iree-dialects/lib/Transforms/TransformMatchers.cpp
index 19a8893..560f83b 100644
--- a/llvm-external-projects/iree-dialects/lib/Transforms/TransformMatchers.cpp
+++ b/llvm-external-projects/iree-dialects/lib/Transforms/TransformMatchers.cpp
@@ -257,6 +257,30 @@
return *this;
}
+transform_ext::StructuredOpMatcher::StructuredOpMatcher(
+ StructuredOpMatcher &A, StructuredOpMatcher &B) {
+
+ predicates.push_back([&A, &B](linalg::LinalgOp linalgOp) -> bool {
+ LLVM_DEBUG(DBGS() << "start recursive lhs OR match {\n");
+ {
+ auto debugRAII = llvm::make_scope_exit(
+ [] { LLVM_DEBUG(DBGS() << "} end recursive match"); });
+ if (A.match(linalgOp))
+ return true;
+ }
+ LLVM_DEBUG(DBGS() << "start recursive rhs OR match {\n");
+ {
+ auto debugRAII = llvm::make_scope_exit(
+ [] { LLVM_DEBUG(DBGS() << "} end recursive match"); });
+ if (B.match(linalgOp))
+ return true;
+ }
+ return false;
+ });
+ recordNestedMatcher(A);
+ recordNestedMatcher(B);
+}
+
//===---------------------------------------------------------------------===//
// Constraints on input operands.
//===---------------------------------------------------------------------===//
@@ -903,11 +927,12 @@
transform_ext::StructuredOpMatcher &maxReduction,
transform_ext::StructuredOpMatcher &sub,
transform_ext::StructuredOpMatcher &expOperand,
- transform_ext::StructuredOpMatcher &fillzero,
+ transform_ext::StructuredOpMatcher &fillZero,
transform_ext::StructuredOpMatcher &sum,
+ transform_ext::StructuredOpMatcher &rcpOperand,
+ transform_ext::StructuredOpMatcher &mulOperand,
transform_ext::StructuredOpMatcher &divOperand,
transform_ext::StructuredOpMatcher &softmaxroot) {
-
fillMinusInf = m_StructuredOp<linalg::FillOp>().input(0, ConstantFloatMin());
maxReduction = transform_ext::m_StructuredOp<linalg::GenericOp>()
.singleOpWithCanonicaleArgs<arith::MaxFOp>()
@@ -941,7 +966,7 @@
.output(NumEqualsTo(1));
expOperand = expOperand.input(0, sub);
- fillzero = m_StructuredOp<linalg::FillOp>().input(0, ConstantFloatZero());
+ fillZero = m_StructuredOp<linalg::FillOp>().input(0, ConstantFloatZero());
sum = m_StructuredOp<linalg::GenericOp>()
.singleOpWithCanonicaleArgs<arith::AddFOp>()
// Only handle most inner reduction for now.
@@ -952,26 +977,40 @@
.output(AllOperands(), IsProjected(-1))
.output(NumEqualsTo(1));
sum = sum.input(0, expOperand);
- sum = sum.output(0, fillzero);
+ sum = sum.output(0, fillZero);
- divOperand = m_StructuredOp<linalg::GenericOp>()
+ rcpOperand = m_StructuredOp<linalg::GenericOp>()
.isFloatReciprocal()
.dim(AllDims(), utils::IteratorType::parallel)
.input(NumEqualsTo(1))
.input(AllOperands(), IsIdentity())
.output(AllOperands(), IsIdentity())
.output(NumEqualsTo(1));
- divOperand = divOperand.input(0, sum);
+ rcpOperand = rcpOperand.input(0, sum);
- softmaxroot = transform_ext::m_StructuredOp<linalg::GenericOp>()
- .singleOpWithCanonicaleArgs<arith::MulFOp>()
- .dim(AllDims(), utils::IteratorType::parallel)
- .input(NumEqualsTo(2))
- .input(0, IsIdentity())
- .input(1, IsProjected(-1))
- .output(NumEqualsTo(1))
- .output(AllOperands(), IsIdentity());
+ mulOperand = transform_ext::m_StructuredOp<linalg::GenericOp>()
+ .singleOpWithCanonicaleArgs<arith::MulFOp>()
+ .dim(AllDims(), utils::IteratorType::parallel)
+ .input(NumEqualsTo(2))
+ .input(0, IsIdentity())
+ .input(1, IsProjected(-1))
+ .output(NumEqualsTo(1))
+ .output(AllOperands(), IsIdentity());
- softmaxroot = softmaxroot.input(0, expOperand);
- softmaxroot = softmaxroot.input(1, divOperand);
+ mulOperand = mulOperand.input(0, expOperand);
+ mulOperand = mulOperand.input(1, rcpOperand);
+
+ divOperand = transform_ext::m_StructuredOp<linalg::GenericOp>()
+ .singleOpWithCanonicaleArgs<arith::DivFOp>()
+ .dim(AllDims(), utils::IteratorType::parallel)
+ .input(NumEqualsTo(2))
+ .input(0, IsIdentity())
+ .input(1, IsProjected(-1))
+ .output(NumEqualsTo(1))
+ .output(AllOperands(), IsIdentity());
+
+ divOperand = divOperand.input(0, expOperand);
+ divOperand = divOperand.input(1, sum);
+
+ softmaxroot = transform_ext::m_StructuredOp_Or(mulOperand, divOperand);
}