Add a second code sequence to match softmax (#12104)
diff --git a/compiler/src/iree/compiler/Dialect/Flow/Transforms/RaiseSpecialOps.cpp b/compiler/src/iree/compiler/Dialect/Flow/Transforms/RaiseSpecialOps.cpp index a71b92b..8181f2b 100644 --- a/compiler/src/iree/compiler/Dialect/Flow/Transforms/RaiseSpecialOps.cpp +++ b/compiler/src/iree/compiler/Dialect/Flow/Transforms/RaiseSpecialOps.cpp
@@ -35,20 +35,27 @@ void runOnOperation() override { SmallVector<std::pair<linalg::LinalgOp, Value>> softmaxRoots; getOperation()->walk([&](linalg::LinalgOp op) { - StructuredOpMatcher reduction, fill, leading, trailing; - transform_ext::StructuredOpMatcher fillMinusInf; - transform_ext::StructuredOpMatcher maxReduction; - transform_ext::StructuredOpMatcher sub; - transform_ext::StructuredOpMatcher expOperand; - transform_ext::StructuredOpMatcher fillzero; - transform_ext::StructuredOpMatcher sum; - transform_ext::StructuredOpMatcher divOperand; - transform_ext::StructuredOpMatcher softmaxroot; - makeSoftmaxMatcher(fillMinusInf, maxReduction, sub, expOperand, fillzero, - sum, divOperand, softmaxroot); - if (matchPattern(op, softmaxroot)) { - Value src = maxReduction.getCaptured()->getOperand(0); - softmaxRoots.push_back(std::make_pair(op, src)); + { + transform_ext::StructuredOpMatcher fillMinusInf; + transform_ext::StructuredOpMatcher maxReduction; + transform_ext::StructuredOpMatcher maybeBroadcastMax; + transform_ext::StructuredOpMatcher sub; + transform_ext::StructuredOpMatcher broadcastedSub; + transform_ext::StructuredOpMatcher expOperand; + transform_ext::StructuredOpMatcher fillzero; + transform_ext::StructuredOpMatcher sum; + transform_ext::StructuredOpMatcher maybeBroadcastSum; + transform_ext::StructuredOpMatcher rcpOperand; + transform_ext::StructuredOpMatcher matmulOperand; + transform_ext::StructuredOpMatcher divOperand; + transform_ext::StructuredOpMatcher softmaxroot; + makeSoftmaxMatcher(fillMinusInf, maxReduction, sub, expOperand, + fillzero, sum, rcpOperand, matmulOperand, divOperand, + softmaxroot); + if (matchPattern(op, softmaxroot)) { + Value src = maxReduction.getCaptured()->getOperand(0); + softmaxRoots.push_back(std::make_pair(op, src)); + } } }); for (std::pair<linalg::LinalgOp, Value> softmax : softmaxRoots) {
diff --git a/compiler/src/iree/compiler/Dialect/Flow/Transforms/test/raise_special_ops.mlir b/compiler/src/iree/compiler/Dialect/Flow/Transforms/test/raise_special_ops.mlir index b93e9e8..5300cd7 100644 --- a/compiler/src/iree/compiler/Dialect/Flow/Transforms/test/raise_special_ops.mlir +++ b/compiler/src/iree/compiler/Dialect/Flow/Transforms/test/raise_special_ops.mlir
@@ -52,3 +52,59 @@ } -> tensor<?x?x?xf32> return %10 : tensor<?x?x?xf32> } + +// CHECK-LABEL: @softmax_no_rcp +// CHECK-SAME: %[[ARG:.+]]: tensor<10x4096x4096xf16> +// CHECK: %[[E:.+]] = tensor.empty() : tensor<10x4096x4096xf16> +// CHECK: %[[S:.+]] = iree_linalg_ext.softmax dimension(2) ins(%[[ARG]] : tensor<10x4096x4096xf16>) outs(%[[E]] : tensor<10x4096x4096xf16>) -> tensor<10x4096x4096xf16> +// CHECK: return %[[S]] : tensor<10x4096x4096xf16> +func.func @softmax_no_rcp(%src : tensor<10x4096x4096xf16>) -> (tensor<10x4096x4096xf16>) { + %cst_158 = arith.constant -6.550400e+04 : f16 + %cst_121 = arith.constant 0.000000e+00 : f16 + %224 = tensor.empty() : tensor<10x4096xf16> + %216 = tensor.empty() : tensor<10x4096x4096xf16> + %225 = linalg.fill ins(%cst_158 : f16) outs(%224 : tensor<10x4096xf16>) -> tensor<10x4096xf16> + %226 = linalg.generic {indexing_maps = [affine_map<(d0, d1, d2) -> (d0, d1, d2)>, affine_map<(d0, d1, d2) -> (d0, d1)>], iterator_types = ["parallel", "parallel", "reduction"]} ins(%src : tensor<10x4096x4096xf16>) outs(%225 : tensor<10x4096xf16>) { + ^bb0(%in: f16, %out: f16): + %5290 = arith.maxf %in, %out : f16 + linalg.yield %5290 : f16 + } -> tensor<10x4096xf16> + %227 = linalg.generic + {indexing_maps = [ + affine_map<(d0, d1, d2) -> (d0, d1, d2)>, + affine_map<(d0, d1, d2) -> (d0, d1)>, + affine_map<(d0, d1, d2) -> (d0, d1, d2)>], + iterator_types = ["parallel", "parallel", "parallel"]} + ins(%src, %226 : tensor<10x4096x4096xf16>, tensor<10x4096xf16>) outs(%216 : tensor<10x4096x4096xf16>) { + ^bb0(%in: f16, %in_1572: f16, %out: f16): + %5290 = arith.subf %in, %in_1572 : f16 + linalg.yield %5290 : f16 + } -> tensor<10x4096x4096xf16> + %228 = linalg.generic {indexing_maps = [affine_map<(d0, d1, d2) -> (d0, d1, d2)>, affine_map<(d0, d1, d2) -> (d0, d1, d2)>], iterator_types = ["parallel", "parallel", "parallel"]} ins(%227 : tensor<10x4096x4096xf16>) outs(%216 : tensor<10x4096x4096xf16>) { + ^bb0(%in: f16, %out: f16): + %5290 = math.exp %in : f16 + linalg.yield %5290 : f16 + } -> tensor<10x4096x4096xf16> + %229 = tensor.empty() : tensor<10x4096xf16> + %230 = linalg.fill ins(%cst_121 : f16) outs(%229 : tensor<10x4096xf16>) -> tensor<10x4096xf16> + %231 = linalg.generic + {indexing_maps = [affine_map<(d0, d1, d2) -> (d0, d1, d2)>, affine_map<(d0, d1, d2) -> (d0, d1)>], + iterator_types = ["parallel", "parallel", "reduction"]} + ins(%228 : tensor<10x4096x4096xf16>) outs(%230 : tensor<10x4096xf16>) { + ^bb0(%in: f16, %out: f16): + %5290 = arith.addf %in, %out : f16 + linalg.yield %5290 : f16 + } -> tensor<10x4096xf16> + %232 = linalg.generic + {indexing_maps = [ + affine_map<(d0, d1, d2) -> (d0, d1, d2)>, + affine_map<(d0, d1, d2) -> (d0, d1)>, + affine_map<(d0, d1, d2) -> (d0, d1, d2)>], + iterator_types = ["parallel", "parallel", "parallel"]} + ins(%228, %231 : tensor<10x4096x4096xf16>, tensor<10x4096xf16>) outs(%216 : tensor<10x4096x4096xf16>) { + ^bb0(%in: f16, %in_1572: f16, %out: f16): + %5290 = arith.divf %in, %in_1572 : f16 + linalg.yield %5290 : f16 + } -> tensor<10x4096x4096xf16> + return %232 : tensor<10x4096x4096xf16> +}
diff --git a/llvm-external-projects/iree-dialects/include/iree-dialects/Transforms/TransformMatchers.h b/llvm-external-projects/iree-dialects/include/iree-dialects/Transforms/TransformMatchers.h index 61870c4..9afc0f6 100644 --- a/llvm-external-projects/iree-dialects/include/iree-dialects/Transforms/TransformMatchers.h +++ b/llvm-external-projects/iree-dialects/include/iree-dialects/Transforms/TransformMatchers.h
@@ -237,6 +237,9 @@ }); } + /// Matches a structured operation if either patterns A or B match. + StructuredOpMatcher(StructuredOpMatcher &A, StructuredOpMatcher &B); + /// Matches the given operation, hook for `matchPattern`. bool match(Operation *op); @@ -515,6 +518,11 @@ /// Creates a matcher of an arbitrary structured op. inline StructuredOpMatcher m_StructuredOp() { return StructuredOpMatcher(); } +inline StructuredOpMatcher m_StructuredOp_Or(StructuredOpMatcher &A, + StructuredOpMatcher &B) { + return StructuredOpMatcher(A, B); +} + /// Creates a matcher of a structured op with kinds provided as template /// arguments. template <typename... OpType> @@ -633,21 +641,22 @@ StructuredOpMatcher &trailing, MatchedReductionCaptures &captures); -/// Create a group of matchers for a sequence of operations matching exactly a -/// softmax operation. +/// Create a group of matchers for a different code sequence of operations +/// matching exactly a softmax operation. /// /// %red = reduce_max(%0) /// %sub = sub(%0, %red) /// %exp = exp(%sub) /// %sum = reduce_sum(%exp) -/// %rec = reciprocal(%sum) -/// %mul = mul(%exp, %rec) +/// %mul = div(%exp, %%sum) void makeSoftmaxMatcher(transform_ext::StructuredOpMatcher &fillMinusInf, transform_ext::StructuredOpMatcher &maxReduction, transform_ext::StructuredOpMatcher &sub, transform_ext::StructuredOpMatcher &expOperand, - transform_ext::StructuredOpMatcher &fillzero, + transform_ext::StructuredOpMatcher &fillZero, transform_ext::StructuredOpMatcher &sum, + transform_ext::StructuredOpMatcher &rcpOperand, + transform_ext::StructuredOpMatcher &mulOperand, transform_ext::StructuredOpMatcher &divOperand, transform_ext::StructuredOpMatcher &softmaxroot);
diff --git a/llvm-external-projects/iree-dialects/lib/Transforms/TransformMatchers.cpp b/llvm-external-projects/iree-dialects/lib/Transforms/TransformMatchers.cpp index 19a8893..560f83b 100644 --- a/llvm-external-projects/iree-dialects/lib/Transforms/TransformMatchers.cpp +++ b/llvm-external-projects/iree-dialects/lib/Transforms/TransformMatchers.cpp
@@ -257,6 +257,30 @@ return *this; } +transform_ext::StructuredOpMatcher::StructuredOpMatcher( + StructuredOpMatcher &A, StructuredOpMatcher &B) { + + predicates.push_back([&A, &B](linalg::LinalgOp linalgOp) -> bool { + LLVM_DEBUG(DBGS() << "start recursive lhs OR match {\n"); + { + auto debugRAII = llvm::make_scope_exit( + [] { LLVM_DEBUG(DBGS() << "} end recursive match"); }); + if (A.match(linalgOp)) + return true; + } + LLVM_DEBUG(DBGS() << "start recursive rhs OR match {\n"); + { + auto debugRAII = llvm::make_scope_exit( + [] { LLVM_DEBUG(DBGS() << "} end recursive match"); }); + if (B.match(linalgOp)) + return true; + } + return false; + }); + recordNestedMatcher(A); + recordNestedMatcher(B); +} + //===---------------------------------------------------------------------===// // Constraints on input operands. //===---------------------------------------------------------------------===// @@ -903,11 +927,12 @@ transform_ext::StructuredOpMatcher &maxReduction, transform_ext::StructuredOpMatcher &sub, transform_ext::StructuredOpMatcher &expOperand, - transform_ext::StructuredOpMatcher &fillzero, + transform_ext::StructuredOpMatcher &fillZero, transform_ext::StructuredOpMatcher &sum, + transform_ext::StructuredOpMatcher &rcpOperand, + transform_ext::StructuredOpMatcher &mulOperand, transform_ext::StructuredOpMatcher &divOperand, transform_ext::StructuredOpMatcher &softmaxroot) { - fillMinusInf = m_StructuredOp<linalg::FillOp>().input(0, ConstantFloatMin()); maxReduction = transform_ext::m_StructuredOp<linalg::GenericOp>() .singleOpWithCanonicaleArgs<arith::MaxFOp>() @@ -941,7 +966,7 @@ .output(NumEqualsTo(1)); expOperand = expOperand.input(0, sub); - fillzero = m_StructuredOp<linalg::FillOp>().input(0, ConstantFloatZero()); + fillZero = m_StructuredOp<linalg::FillOp>().input(0, ConstantFloatZero()); sum = m_StructuredOp<linalg::GenericOp>() .singleOpWithCanonicaleArgs<arith::AddFOp>() // Only handle most inner reduction for now. @@ -952,26 +977,40 @@ .output(AllOperands(), IsProjected(-1)) .output(NumEqualsTo(1)); sum = sum.input(0, expOperand); - sum = sum.output(0, fillzero); + sum = sum.output(0, fillZero); - divOperand = m_StructuredOp<linalg::GenericOp>() + rcpOperand = m_StructuredOp<linalg::GenericOp>() .isFloatReciprocal() .dim(AllDims(), utils::IteratorType::parallel) .input(NumEqualsTo(1)) .input(AllOperands(), IsIdentity()) .output(AllOperands(), IsIdentity()) .output(NumEqualsTo(1)); - divOperand = divOperand.input(0, sum); + rcpOperand = rcpOperand.input(0, sum); - softmaxroot = transform_ext::m_StructuredOp<linalg::GenericOp>() - .singleOpWithCanonicaleArgs<arith::MulFOp>() - .dim(AllDims(), utils::IteratorType::parallel) - .input(NumEqualsTo(2)) - .input(0, IsIdentity()) - .input(1, IsProjected(-1)) - .output(NumEqualsTo(1)) - .output(AllOperands(), IsIdentity()); + mulOperand = transform_ext::m_StructuredOp<linalg::GenericOp>() + .singleOpWithCanonicaleArgs<arith::MulFOp>() + .dim(AllDims(), utils::IteratorType::parallel) + .input(NumEqualsTo(2)) + .input(0, IsIdentity()) + .input(1, IsProjected(-1)) + .output(NumEqualsTo(1)) + .output(AllOperands(), IsIdentity()); - softmaxroot = softmaxroot.input(0, expOperand); - softmaxroot = softmaxroot.input(1, divOperand); + mulOperand = mulOperand.input(0, expOperand); + mulOperand = mulOperand.input(1, rcpOperand); + + divOperand = transform_ext::m_StructuredOp<linalg::GenericOp>() + .singleOpWithCanonicaleArgs<arith::DivFOp>() + .dim(AllDims(), utils::IteratorType::parallel) + .input(NumEqualsTo(2)) + .input(0, IsIdentity()) + .input(1, IsProjected(-1)) + .output(NumEqualsTo(1)) + .output(AllOperands(), IsIdentity()); + + divOperand = divOperand.input(0, expOperand); + divOperand = divOperand.input(1, sum); + + softmaxroot = transform_ext::m_StructuredOp_Or(mulOperand, divOperand); }