Add a second code sequence to match softmax (#12104)

diff --git a/compiler/src/iree/compiler/Dialect/Flow/Transforms/RaiseSpecialOps.cpp b/compiler/src/iree/compiler/Dialect/Flow/Transforms/RaiseSpecialOps.cpp
index a71b92b..8181f2b 100644
--- a/compiler/src/iree/compiler/Dialect/Flow/Transforms/RaiseSpecialOps.cpp
+++ b/compiler/src/iree/compiler/Dialect/Flow/Transforms/RaiseSpecialOps.cpp
@@ -35,20 +35,27 @@
   void runOnOperation() override {
     SmallVector<std::pair<linalg::LinalgOp, Value>> softmaxRoots;
     getOperation()->walk([&](linalg::LinalgOp op) {
-      StructuredOpMatcher reduction, fill, leading, trailing;
-      transform_ext::StructuredOpMatcher fillMinusInf;
-      transform_ext::StructuredOpMatcher maxReduction;
-      transform_ext::StructuredOpMatcher sub;
-      transform_ext::StructuredOpMatcher expOperand;
-      transform_ext::StructuredOpMatcher fillzero;
-      transform_ext::StructuredOpMatcher sum;
-      transform_ext::StructuredOpMatcher divOperand;
-      transform_ext::StructuredOpMatcher softmaxroot;
-      makeSoftmaxMatcher(fillMinusInf, maxReduction, sub, expOperand, fillzero,
-                         sum, divOperand, softmaxroot);
-      if (matchPattern(op, softmaxroot)) {
-        Value src = maxReduction.getCaptured()->getOperand(0);
-        softmaxRoots.push_back(std::make_pair(op, src));
+      {
+        transform_ext::StructuredOpMatcher fillMinusInf;
+        transform_ext::StructuredOpMatcher maxReduction;
+        transform_ext::StructuredOpMatcher maybeBroadcastMax;
+        transform_ext::StructuredOpMatcher sub;
+        transform_ext::StructuredOpMatcher broadcastedSub;
+        transform_ext::StructuredOpMatcher expOperand;
+        transform_ext::StructuredOpMatcher fillzero;
+        transform_ext::StructuredOpMatcher sum;
+        transform_ext::StructuredOpMatcher maybeBroadcastSum;
+        transform_ext::StructuredOpMatcher rcpOperand;
+        transform_ext::StructuredOpMatcher matmulOperand;
+        transform_ext::StructuredOpMatcher divOperand;
+        transform_ext::StructuredOpMatcher softmaxroot;
+        makeSoftmaxMatcher(fillMinusInf, maxReduction, sub, expOperand,
+                           fillzero, sum, rcpOperand, matmulOperand, divOperand,
+                           softmaxroot);
+        if (matchPattern(op, softmaxroot)) {
+          Value src = maxReduction.getCaptured()->getOperand(0);
+          softmaxRoots.push_back(std::make_pair(op, src));
+        }
       }
     });
     for (std::pair<linalg::LinalgOp, Value> softmax : softmaxRoots) {
diff --git a/compiler/src/iree/compiler/Dialect/Flow/Transforms/test/raise_special_ops.mlir b/compiler/src/iree/compiler/Dialect/Flow/Transforms/test/raise_special_ops.mlir
index b93e9e8..5300cd7 100644
--- a/compiler/src/iree/compiler/Dialect/Flow/Transforms/test/raise_special_ops.mlir
+++ b/compiler/src/iree/compiler/Dialect/Flow/Transforms/test/raise_special_ops.mlir
@@ -52,3 +52,59 @@
   } -> tensor<?x?x?xf32>
   return %10 : tensor<?x?x?xf32>
 }
+
+// CHECK-LABEL: @softmax_no_rcp
+//  CHECK-SAME: %[[ARG:.+]]: tensor<10x4096x4096xf16>
+//       CHECK:   %[[E:.+]] = tensor.empty() : tensor<10x4096x4096xf16>
+//       CHECK:   %[[S:.+]] = iree_linalg_ext.softmax dimension(2) ins(%[[ARG]] : tensor<10x4096x4096xf16>) outs(%[[E]] : tensor<10x4096x4096xf16>) -> tensor<10x4096x4096xf16>
+//       CHECK:   return %[[S]] : tensor<10x4096x4096xf16>
+func.func @softmax_no_rcp(%src : tensor<10x4096x4096xf16>) -> (tensor<10x4096x4096xf16>) {
+  %cst_158 = arith.constant -6.550400e+04 : f16
+  %cst_121 = arith.constant 0.000000e+00 : f16
+  %224 = tensor.empty() : tensor<10x4096xf16>
+  %216 = tensor.empty() : tensor<10x4096x4096xf16>
+  %225 = linalg.fill ins(%cst_158 : f16) outs(%224 : tensor<10x4096xf16>) -> tensor<10x4096xf16>
+  %226 = linalg.generic {indexing_maps = [affine_map<(d0, d1, d2) -> (d0, d1, d2)>, affine_map<(d0, d1, d2) -> (d0, d1)>], iterator_types = ["parallel", "parallel", "reduction"]} ins(%src : tensor<10x4096x4096xf16>) outs(%225 : tensor<10x4096xf16>) {
+  ^bb0(%in: f16, %out: f16):
+    %5290 = arith.maxf %in, %out : f16
+    linalg.yield %5290 : f16
+  } -> tensor<10x4096xf16>
+  %227 = linalg.generic
+  {indexing_maps = [
+    affine_map<(d0, d1, d2) -> (d0, d1, d2)>,
+    affine_map<(d0, d1, d2) -> (d0, d1)>,
+    affine_map<(d0, d1, d2) -> (d0, d1, d2)>],
+    iterator_types = ["parallel", "parallel", "parallel"]}
+    ins(%src, %226 : tensor<10x4096x4096xf16>, tensor<10x4096xf16>) outs(%216 : tensor<10x4096x4096xf16>) {
+  ^bb0(%in: f16, %in_1572: f16, %out: f16):
+    %5290 = arith.subf %in, %in_1572 : f16
+    linalg.yield %5290 : f16
+  } -> tensor<10x4096x4096xf16>
+  %228 = linalg.generic {indexing_maps = [affine_map<(d0, d1, d2) -> (d0, d1, d2)>, affine_map<(d0, d1, d2) -> (d0, d1, d2)>], iterator_types = ["parallel", "parallel", "parallel"]} ins(%227 : tensor<10x4096x4096xf16>) outs(%216 : tensor<10x4096x4096xf16>) {
+  ^bb0(%in: f16, %out: f16):
+    %5290 = math.exp %in : f16
+    linalg.yield %5290 : f16
+  } -> tensor<10x4096x4096xf16>
+  %229 = tensor.empty() : tensor<10x4096xf16>
+  %230 = linalg.fill ins(%cst_121 : f16) outs(%229 : tensor<10x4096xf16>) -> tensor<10x4096xf16>
+  %231 = linalg.generic 
+  {indexing_maps = [affine_map<(d0, d1, d2) -> (d0, d1, d2)>, affine_map<(d0, d1, d2) -> (d0, d1)>],
+  iterator_types = ["parallel", "parallel", "reduction"]}
+  ins(%228 : tensor<10x4096x4096xf16>) outs(%230 : tensor<10x4096xf16>) {
+  ^bb0(%in: f16, %out: f16):
+    %5290 = arith.addf %in, %out : f16
+    linalg.yield %5290 : f16
+  } -> tensor<10x4096xf16>
+  %232 = linalg.generic 
+  {indexing_maps = [
+    affine_map<(d0, d1, d2) -> (d0, d1, d2)>,
+    affine_map<(d0, d1, d2) -> (d0, d1)>,
+    affine_map<(d0, d1, d2) -> (d0, d1, d2)>],
+    iterator_types = ["parallel", "parallel", "parallel"]}
+    ins(%228, %231 : tensor<10x4096x4096xf16>, tensor<10x4096xf16>) outs(%216 : tensor<10x4096x4096xf16>) {
+  ^bb0(%in: f16, %in_1572: f16, %out: f16):
+    %5290 = arith.divf %in, %in_1572 : f16
+    linalg.yield %5290 : f16
+  } -> tensor<10x4096x4096xf16>
+  return %232 : tensor<10x4096x4096xf16>
+}
diff --git a/llvm-external-projects/iree-dialects/include/iree-dialects/Transforms/TransformMatchers.h b/llvm-external-projects/iree-dialects/include/iree-dialects/Transforms/TransformMatchers.h
index 61870c4..9afc0f6 100644
--- a/llvm-external-projects/iree-dialects/include/iree-dialects/Transforms/TransformMatchers.h
+++ b/llvm-external-projects/iree-dialects/include/iree-dialects/Transforms/TransformMatchers.h
@@ -237,6 +237,9 @@
     });
   }
 
+  /// Matches a structured operation if either of the patterns A or B matches.
+  StructuredOpMatcher(StructuredOpMatcher &A, StructuredOpMatcher &B);
+
   /// Matches the given operation, hook for `matchPattern`.
   bool match(Operation *op);
 
@@ -515,6 +518,11 @@
 /// Creates a matcher of an arbitrary structured op.
 inline StructuredOpMatcher m_StructuredOp() { return StructuredOpMatcher(); }
 
+inline StructuredOpMatcher m_StructuredOp_Or(StructuredOpMatcher &A,
+                                             StructuredOpMatcher &B) {
+  return StructuredOpMatcher(A, B);
+}
+
 /// Creates a matcher of a structured op with kinds provided as template
 /// arguments.
 template <typename... OpType>
@@ -633,21 +641,22 @@
                           StructuredOpMatcher &trailing,
                           MatchedReductionCaptures &captures);
 
-/// Create a group of matchers for a sequence of operations matching exactly a
-/// softmax operation.
+/// Create a group of matchers for a code sequence of operations matching
+/// exactly a softmax operation, in either its reciprocal-multiply or divide form.
 ///
 ///  %red = reduce_max(%0)
 ///  %sub = sub(%0, %red)
 ///  %exp = exp(%sub)
 ///  %sum = reduce_sum(%exp)
-///  %rec = reciprocal(%sum)
-///  %mul = mul(%exp, %rec)
+///  %div = div(%exp, %sum)
 void makeSoftmaxMatcher(transform_ext::StructuredOpMatcher &fillMinusInf,
                         transform_ext::StructuredOpMatcher &maxReduction,
                         transform_ext::StructuredOpMatcher &sub,
                         transform_ext::StructuredOpMatcher &expOperand,
-                        transform_ext::StructuredOpMatcher &fillzero,
+                        transform_ext::StructuredOpMatcher &fillZero,
                         transform_ext::StructuredOpMatcher &sum,
+                        transform_ext::StructuredOpMatcher &rcpOperand,
+                        transform_ext::StructuredOpMatcher &mulOperand,
                         transform_ext::StructuredOpMatcher &divOperand,
                         transform_ext::StructuredOpMatcher &softmaxroot);
 
diff --git a/llvm-external-projects/iree-dialects/lib/Transforms/TransformMatchers.cpp b/llvm-external-projects/iree-dialects/lib/Transforms/TransformMatchers.cpp
index 19a8893..560f83b 100644
--- a/llvm-external-projects/iree-dialects/lib/Transforms/TransformMatchers.cpp
+++ b/llvm-external-projects/iree-dialects/lib/Transforms/TransformMatchers.cpp
@@ -257,6 +257,30 @@
   return *this;
 }
 
+transform_ext::StructuredOpMatcher::StructuredOpMatcher(
+    StructuredOpMatcher &A, StructuredOpMatcher &B) {
+
+  predicates.push_back([&A, &B](linalg::LinalgOp linalgOp) -> bool {
+    LLVM_DEBUG(DBGS() << "start recursive lhs OR match {\n");
+    {
+      auto debugRAII = llvm::make_scope_exit(
+          [] { LLVM_DEBUG(DBGS() << "} end recursive match"); });
+      if (A.match(linalgOp))
+        return true;
+    }
+    LLVM_DEBUG(DBGS() << "start recursive rhs OR match {\n");
+    {
+      auto debugRAII = llvm::make_scope_exit(
+          [] { LLVM_DEBUG(DBGS() << "} end recursive match"); });
+      if (B.match(linalgOp))
+        return true;
+    }
+    return false;
+  });
+  recordNestedMatcher(A);
+  recordNestedMatcher(B);
+}
+
 //===---------------------------------------------------------------------===//
 // Constraints on input operands.
 //===---------------------------------------------------------------------===//
@@ -903,11 +927,12 @@
     transform_ext::StructuredOpMatcher &maxReduction,
     transform_ext::StructuredOpMatcher &sub,
     transform_ext::StructuredOpMatcher &expOperand,
-    transform_ext::StructuredOpMatcher &fillzero,
+    transform_ext::StructuredOpMatcher &fillZero,
     transform_ext::StructuredOpMatcher &sum,
+    transform_ext::StructuredOpMatcher &rcpOperand,
+    transform_ext::StructuredOpMatcher &mulOperand,
     transform_ext::StructuredOpMatcher &divOperand,
     transform_ext::StructuredOpMatcher &softmaxroot) {
-
   fillMinusInf = m_StructuredOp<linalg::FillOp>().input(0, ConstantFloatMin());
   maxReduction = transform_ext::m_StructuredOp<linalg::GenericOp>()
                      .singleOpWithCanonicaleArgs<arith::MaxFOp>()
@@ -941,7 +966,7 @@
                    .output(NumEqualsTo(1));
   expOperand = expOperand.input(0, sub);
 
-  fillzero = m_StructuredOp<linalg::FillOp>().input(0, ConstantFloatZero());
+  fillZero = m_StructuredOp<linalg::FillOp>().input(0, ConstantFloatZero());
   sum = m_StructuredOp<linalg::GenericOp>()
             .singleOpWithCanonicaleArgs<arith::AddFOp>()
             // Only handle most inner reduction for now.
@@ -952,26 +977,40 @@
             .output(AllOperands(), IsProjected(-1))
             .output(NumEqualsTo(1));
   sum = sum.input(0, expOperand);
-  sum = sum.output(0, fillzero);
+  sum = sum.output(0, fillZero);
 
-  divOperand = m_StructuredOp<linalg::GenericOp>()
+  rcpOperand = m_StructuredOp<linalg::GenericOp>()
                    .isFloatReciprocal()
                    .dim(AllDims(), utils::IteratorType::parallel)
                    .input(NumEqualsTo(1))
                    .input(AllOperands(), IsIdentity())
                    .output(AllOperands(), IsIdentity())
                    .output(NumEqualsTo(1));
-  divOperand = divOperand.input(0, sum);
+  rcpOperand = rcpOperand.input(0, sum);
 
-  softmaxroot = transform_ext::m_StructuredOp<linalg::GenericOp>()
-                    .singleOpWithCanonicaleArgs<arith::MulFOp>()
-                    .dim(AllDims(), utils::IteratorType::parallel)
-                    .input(NumEqualsTo(2))
-                    .input(0, IsIdentity())
-                    .input(1, IsProjected(-1))
-                    .output(NumEqualsTo(1))
-                    .output(AllOperands(), IsIdentity());
+  mulOperand = transform_ext::m_StructuredOp<linalg::GenericOp>()
+                   .singleOpWithCanonicaleArgs<arith::MulFOp>()
+                   .dim(AllDims(), utils::IteratorType::parallel)
+                   .input(NumEqualsTo(2))
+                   .input(0, IsIdentity())
+                   .input(1, IsProjected(-1))
+                   .output(NumEqualsTo(1))
+                   .output(AllOperands(), IsIdentity());
 
-  softmaxroot = softmaxroot.input(0, expOperand);
-  softmaxroot = softmaxroot.input(1, divOperand);
+  mulOperand = mulOperand.input(0, expOperand);
+  mulOperand = mulOperand.input(1, rcpOperand);
+
+  divOperand = transform_ext::m_StructuredOp<linalg::GenericOp>()
+                   .singleOpWithCanonicaleArgs<arith::DivFOp>()
+                   .dim(AllDims(), utils::IteratorType::parallel)
+                   .input(NumEqualsTo(2))
+                   .input(0, IsIdentity())
+                   .input(1, IsProjected(-1))
+                   .output(NumEqualsTo(1))
+                   .output(AllOperands(), IsIdentity());
+
+  divOperand = divOperand.input(0, expOperand);
+  divOperand = divOperand.input(1, sum);
+
+  softmaxroot = transform_ext::m_StructuredOp_Or(mulOperand, divOperand);
 }