Add matcher to detect softmax op and add pass to raise op (#12084)

diff --git a/compiler/src/iree/compiler/Dialect/Flow/Transforms/BUILD b/compiler/src/iree/compiler/Dialect/Flow/Transforms/BUILD
index 01caff8..4dd7d58 100644
--- a/compiler/src/iree/compiler/Dialect/Flow/Transforms/BUILD
+++ b/compiler/src/iree/compiler/Dialect/Flow/Transforms/BUILD
@@ -57,6 +57,7 @@
         "OutlineDispatchRegions.cpp",
         "PassDetail.h",
         "Passes.cpp",
+        "RaiseSpecialOps.cpp",
         "RegionOpUtils.cpp",
         "SetEncoding.cpp",
         "SplitReduction.cpp",
@@ -83,6 +84,7 @@
         "//compiler/src/iree/compiler/Dialect/Util/IR",
         "//compiler/src/iree/compiler/Dialect/Util/Transforms",
         "//compiler/src/iree/compiler/Utils",
+        "//llvm-external-projects/iree-dialects:IREEDialectsTransforms",
         "//llvm-external-projects/iree-dialects:IREELinalgExtDialect",
         "//llvm-external-projects/iree-dialects:IREELinalgExtPasses",
         "//llvm-external-projects/iree-dialects:IREELinalgExtTransformOps",
diff --git a/compiler/src/iree/compiler/Dialect/Flow/Transforms/CMakeLists.txt b/compiler/src/iree/compiler/Dialect/Flow/Transforms/CMakeLists.txt
index cfc5586..d00acc8 100644
--- a/compiler/src/iree/compiler/Dialect/Flow/Transforms/CMakeLists.txt
+++ b/compiler/src/iree/compiler/Dialect/Flow/Transforms/CMakeLists.txt
@@ -55,6 +55,7 @@
     "OutlineDispatchRegions.cpp"
     "PassDetail.h"
     "Passes.cpp"
+    "RaiseSpecialOps.cpp"
     "RegionOpUtils.cpp"
     "SetEncoding.cpp"
     "SplitReduction.cpp"
@@ -64,6 +65,7 @@
     "VerifyInputLegality.cpp"
   DEPS
     ::PassesIncGen
+    IREEDialectsTransforms
     IREELinalgExtDialect
     IREELinalgExtPasses
     IREELinalgExtTransformOps
diff --git a/compiler/src/iree/compiler/Dialect/Flow/Transforms/Passes.h b/compiler/src/iree/compiler/Dialect/Flow/Transforms/Passes.h
index 52dfd51..c750b1b 100644
--- a/compiler/src/iree/compiler/Dialect/Flow/Transforms/Passes.h
+++ b/compiler/src/iree/compiler/Dialect/Flow/Transforms/Passes.h
@@ -186,6 +186,10 @@
 std::unique_ptr<OperationPass<mlir::ModuleOp>>
 createDeduplicateExecutablesPass();
 
+// Create a pass to raise a sequence of ops to a higher-level LinalgExt
+// representation.
+std::unique_ptr<Pass> createRaiseSpecialOps();
+
 // Create a pass to split reduction dimension.
 std::unique_ptr<Pass> createSplitReductionPass();
 
diff --git a/compiler/src/iree/compiler/Dialect/Flow/Transforms/Passes.td b/compiler/src/iree/compiler/Dialect/Flow/Transforms/Passes.td
index 4efbfcc..926bb37 100644
--- a/compiler/src/iree/compiler/Dialect/Flow/Transforms/Passes.td
+++ b/compiler/src/iree/compiler/Dialect/Flow/Transforms/Passes.td
@@ -229,6 +229,12 @@
   let constructor = "mlir::iree_compiler::IREE::Flow::createDumpDispatchGraphPass()";
 }
 
+def RaiseSpecialOps :
+    Pass<"iree-flow-raise-special-ops", ""> {
+  let summary = "raise special ops like softmax to the high level linalg.ext representation";
+  let constructor = "mlir::iree_compiler::IREE::Flow::createRaiseSpecialOps()";
+}
+
 def SplitReduction :
     Pass<"iree-flow-split-reduction-ops", ""> {
   let summary = "Split reduction dimension to increase parallelism.";
diff --git a/compiler/src/iree/compiler/Dialect/Flow/Transforms/RaiseSpecialOps.cpp b/compiler/src/iree/compiler/Dialect/Flow/Transforms/RaiseSpecialOps.cpp
new file mode 100644
index 0000000..a71b92b
--- /dev/null
+++ b/compiler/src/iree/compiler/Dialect/Flow/Transforms/RaiseSpecialOps.cpp
@@ -0,0 +1,74 @@
+// Copyright 2023 The IREE Authors
+//
+// Licensed under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+
+#include "iree-dialects/Dialect/LinalgExt/IR/LinalgExtDialect.h"
+#include "iree-dialects/Dialect/LinalgExt/IR/LinalgExtInterfaces.h"
+#include "iree-dialects/Dialect/LinalgExt/IR/LinalgExtOps.h"
+#include "iree-dialects/Dialect/LinalgTransform/StructuredTransformOpsExt.h"
+#include "iree-dialects/Transforms/TransformMatchers.h"
+#include "iree/compiler/Dialect/Flow/Transforms/PassDetail.h"
+#include "iree/compiler/Dialect/Flow/Transforms/Passes.h"
+#include "mlir/Dialect/Linalg/IR/Linalg.h"
+#include "mlir/Dialect/Tensor/IR/Tensor.h"
+#include "mlir/IR/PatternMatch.h"
+#include "mlir/Pass/Pass.h"
+#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
+
+using namespace mlir;
+using transform_ext::StructuredOpMatcher;
+
+namespace mlir {
+namespace iree_compiler {
+namespace IREE {
+namespace Flow {
+
+namespace {
+
+struct RaiseSpecialOpsPass : public RaiseSpecialOpsBase<RaiseSpecialOpsPass> {
+  void getDependentDialects(DialectRegistry &registry) const override {
+    registry.insert<IREE::LinalgExt::IREELinalgExtDialect>();
+  }
+
+  void runOnOperation() override {
+    SmallVector<std::pair<linalg::LinalgOp, Value>> softmaxRoots;
+    getOperation()->walk([&](linalg::LinalgOp op) {
+      StructuredOpMatcher reduction, fill, leading, trailing;
+      transform_ext::StructuredOpMatcher fillMinusInf;
+      transform_ext::StructuredOpMatcher maxReduction;
+      transform_ext::StructuredOpMatcher sub;
+      transform_ext::StructuredOpMatcher expOperand;
+      transform_ext::StructuredOpMatcher fillzero;
+      transform_ext::StructuredOpMatcher sum;
+      transform_ext::StructuredOpMatcher divOperand;
+      transform_ext::StructuredOpMatcher softmaxroot;
+      makeSoftmaxMatcher(fillMinusInf, maxReduction, sub, expOperand, fillzero,
+                         sum, divOperand, softmaxroot);
+      if (matchPattern(op, softmaxroot)) {
+        Value src = maxReduction.getCaptured()->getOperand(0);
+        softmaxRoots.push_back(std::make_pair(op, src));
+      }
+    });
+    for (std::pair<linalg::LinalgOp, Value> softmax : softmaxRoots) {
+      linalg::LinalgOp op = softmax.first;
+      Value src = softmax.second;
+      IRRewriter rewriter(op.getContext());
+      rewriter.setInsertionPoint(softmax.first);
+      rewriter.replaceOpWithNewOp<IREE::LinalgExt::SoftmaxOp>(
+          op, src, op.getDpsInitOperand(0)->get(), op.getNumLoops() - 1);
+    }
+  }
+};
+
+}  // namespace
+
+std::unique_ptr<Pass> createRaiseSpecialOps() {
+  return std::make_unique<RaiseSpecialOpsPass>();
+}
+
+}  // namespace Flow
+}  // namespace IREE
+}  // namespace iree_compiler
+}  // namespace mlir
diff --git a/compiler/src/iree/compiler/Dialect/Flow/Transforms/test/BUILD b/compiler/src/iree/compiler/Dialect/Flow/Transforms/test/BUILD
index 3b873ed..4b23af8 100644
--- a/compiler/src/iree/compiler/Dialect/Flow/Transforms/test/BUILD
+++ b/compiler/src/iree/compiler/Dialect/Flow/Transforms/test/BUILD
@@ -39,6 +39,7 @@
             "interchange_transpose_generic_ops.mlir",
             "optimize_numerics.mlir",
             "outline_dispatch_regions.mlir",
+            "raise_special_ops.mlir",
             "set_encoding.mlir",
             "strip_and_splat_constant_variables.mlir",
             "strip_signedness.mlir",
diff --git a/compiler/src/iree/compiler/Dialect/Flow/Transforms/test/CMakeLists.txt b/compiler/src/iree/compiler/Dialect/Flow/Transforms/test/CMakeLists.txt
index 589d684..7178f26 100644
--- a/compiler/src/iree/compiler/Dialect/Flow/Transforms/test/CMakeLists.txt
+++ b/compiler/src/iree/compiler/Dialect/Flow/Transforms/test/CMakeLists.txt
@@ -37,6 +37,7 @@
     "interchange_transpose_generic_ops.mlir"
     "optimize_numerics.mlir"
     "outline_dispatch_regions.mlir"
+    "raise_special_ops.mlir"
     "set_encoding.mlir"
     "strip_and_splat_constant_variables.mlir"
     "strip_signedness.mlir"
diff --git a/compiler/src/iree/compiler/Dialect/Flow/Transforms/test/raise_special_ops.mlir b/compiler/src/iree/compiler/Dialect/Flow/Transforms/test/raise_special_ops.mlir
new file mode 100644
index 0000000..b93e9e8
--- /dev/null
+++ b/compiler/src/iree/compiler/Dialect/Flow/Transforms/test/raise_special_ops.mlir
@@ -0,0 +1,54 @@
+// RUN: iree-opt --iree-flow-raise-special-ops -canonicalize %s | FileCheck %s
+
+// CHECK-LABEL: @softmax
+//  CHECK-SAME: %[[ARG:.+]]: tensor<?x?x?xf32>
+//       CHECK:   %[[E:.+]] = tensor.empty(%{{.*}}, %{{.*}}, %{{.*}}) : tensor<?x?x?xf32>
+//       CHECK:   %[[S:.+]] = iree_linalg_ext.softmax dimension(2) ins(%[[ARG]] : tensor<?x?x?xf32>) outs(%[[E]] : tensor<?x?x?xf32>) -> tensor<?x?x?xf32>
+//       CHECK:   return %[[S]] : tensor<?x?x?xf32>
+
+func.func @softmax(%src : tensor<?x?x?xf32>) -> (tensor<?x?x?xf32>) {
+  %cst = arith.constant 1.000000e+00 : f32
+  %cst_0 = arith.constant 0.000000e+00 : f32
+  %cst_1 = arith.constant -3.40282347E+38 : f32
+  %c_0_index = arith.constant 0 : index
+  %c_1_index = arith.constant 1 : index
+  %c_2_index = arith.constant 2 : index
+  %dim_0 = tensor.dim %src, %c_0_index : tensor<?x?x?xf32>
+  %dim_1 = tensor.dim %src, %c_1_index : tensor<?x?x?xf32>
+  %dim_2 = tensor.dim %src, %c_2_index : tensor<?x?x?xf32>
+  %1 = tensor.empty(%dim_0, %dim_1) : tensor<?x?xf32>
+  %2 = linalg.fill ins(%cst_1 : f32) outs(%1 : tensor<?x?xf32>) -> tensor<?x?xf32>
+  %3 = linalg.generic {indexing_maps = [affine_map<(d0, d1, d2) -> (d0, d1, d2)>, affine_map<(d0, d1, d2) -> (d0, d1)>], iterator_types = ["parallel", "parallel", "reduction"]} ins(%src : tensor<?x?x?xf32>) outs(%2 : tensor<?x?xf32>) {
+  ^bb0(%arg0: f32, %arg1: f32):
+    %11 = arith.maxf %arg0, %arg1 : f32
+    linalg.yield %11 : f32
+  } -> tensor<?x?xf32>
+  %4 = tensor.empty(%dim_0, %dim_1, %dim_2) : tensor<?x?x?xf32>
+  %5 = linalg.generic {indexing_maps = [affine_map<(d0, d1, d2) -> (d0, d1, d2)>, affine_map<(d0, d1, d2) -> (d0, d1)>, affine_map<(d0, d1, d2) -> (d0, d1, d2)>], iterator_types = ["parallel", "parallel", "parallel"]} ins(%src, %3 : tensor<?x?x?xf32>, tensor<?x?xf32>) outs(%4 : tensor<?x?x?xf32>) {
+  ^bb0(%arg0: f32, %arg1: f32, %arg2: f32):
+    %11 = arith.subf %arg0, %arg1 : f32
+    linalg.yield %11 : f32
+  } -> tensor<?x?x?xf32>
+  %6 = linalg.generic {indexing_maps = [affine_map<(d0, d1, d2) -> (d0, d1, d2)>, affine_map<(d0, d1, d2) -> (d0, d1, d2)>], iterator_types = ["parallel", "parallel", "parallel"]} ins(%5 : tensor<?x?x?xf32>) outs(%4 : tensor<?x?x?xf32>) {
+  ^bb0(%arg0: f32, %arg1: f32):
+    %11 = math.exp %arg0 : f32
+    linalg.yield %11 : f32
+  } -> tensor<?x?x?xf32>
+  %7 = linalg.fill ins(%cst_0 : f32) outs(%1 : tensor<?x?xf32>) -> tensor<?x?xf32>
+  %8 = linalg.generic {indexing_maps = [affine_map<(d0, d1, d2) -> (d0, d1, d2)>, affine_map<(d0, d1, d2) -> (d0, d1)>], iterator_types = ["parallel", "parallel", "reduction"]} ins(%6 : tensor<?x?x?xf32>) outs(%7 : tensor<?x?xf32>) {
+  ^bb0(%arg0: f32, %arg1: f32):
+    %11 = arith.addf %arg0, %arg1 : f32
+    linalg.yield %11 : f32
+  } -> tensor<?x?xf32>
+  %9 = linalg.generic {indexing_maps = [affine_map<(d0, d1) -> (d0, d1)>, affine_map<(d0, d1) -> (d0, d1)>], iterator_types = ["parallel", "parallel"]} ins(%8 : tensor<?x?xf32>) outs(%1 : tensor<?x?xf32>) {
+  ^bb0(%arg0: f32, %arg1: f32):
+    %11 = arith.divf %cst, %arg0 : f32
+    linalg.yield %11 : f32
+  } -> tensor<?x?xf32>
+  %10 = linalg.generic {indexing_maps = [affine_map<(d0, d1, d2) -> (d0, d1, d2)>, affine_map<(d0, d1, d2) -> (d0, d1)>, affine_map<(d0, d1, d2) -> (d0, d1, d2)>], iterator_types = ["parallel", "parallel", "parallel"]} ins(%6, %9 : tensor<?x?x?xf32>, tensor<?x?xf32>) outs(%4 : tensor<?x?x?xf32>) {
+  ^bb0(%arg0: f32, %arg1: f32, %arg2: f32):
+    %11 = arith.mulf %arg0, %arg1 : f32
+    linalg.yield %11 : f32
+  } -> tensor<?x?x?xf32>
+  return %10 : tensor<?x?x?xf32>
+}
diff --git a/llvm-external-projects/iree-dialects/BUILD b/llvm-external-projects/iree-dialects/BUILD
index 276494a..2c9b105 100644
--- a/llvm-external-projects/iree-dialects/BUILD
+++ b/llvm-external-projects/iree-dialects/BUILD
@@ -146,10 +146,12 @@
     deps = [
         "@llvm-project//llvm:Support",
         "@llvm-project//mlir:Analysis",
+        "@llvm-project//mlir:ArithDialect",
         "@llvm-project//mlir:DialectUtils",
         "@llvm-project//mlir:FuncDialect",
         "@llvm-project//mlir:IR",
         "@llvm-project//mlir:LinalgDialect",
+        "@llvm-project//mlir:MathDialect",
         "@llvm-project//mlir:Rewrite",
         "@llvm-project//mlir:SCFDialect",
         "@llvm-project//mlir:TensorDialect",
diff --git a/llvm-external-projects/iree-dialects/include/iree-dialects/Dialect/LinalgExt/IR/LinalgExtOps.td b/llvm-external-projects/iree-dialects/include/iree-dialects/Dialect/LinalgExt/IR/LinalgExtOps.td
index 4768ba1..fe5dbb9 100644
--- a/llvm-external-projects/iree-dialects/include/iree-dialects/Dialect/LinalgExt/IR/LinalgExtOps.td
+++ b/llvm-external-projects/iree-dialects/include/iree-dialects/Dialect/LinalgExt/IR/LinalgExtOps.td
@@ -1135,7 +1135,7 @@
   );
 
   let builders = [
-    OpBuilder<(ins "ValueRange":$inputs, "ValueRange":$outputs,
+    OpBuilder<(ins "Value":$inputs, "Value":$outputs,
       CArg<"int64_t", "0">:$dimension)>
   ];
 
diff --git a/llvm-external-projects/iree-dialects/include/iree-dialects/Transforms/TransformMatchers.h b/llvm-external-projects/iree-dialects/include/iree-dialects/Transforms/TransformMatchers.h
index 21f4060..61870c4 100644
--- a/llvm-external-projects/iree-dialects/include/iree-dialects/Transforms/TransformMatchers.h
+++ b/llvm-external-projects/iree-dialects/include/iree-dialects/Transforms/TransformMatchers.h
@@ -126,6 +126,25 @@
 /// Predicate tag indicating that the affine map is a projected permutation.
 struct IsProjectedPermutation {};
 
+/// Predicate tag indicating that the affine map is a projection of given
+/// dimension.
+struct IsProjected : public SingleValuePredicateParam<int64_t> {
+  using Base::Base;
+};
+/// Predicate tag indicating that the affine map is an identity.
+struct IsIdentity {};
+
+/// Predicate tag indicating that the operand is a special float constant.
+struct ConstantFloatMin {};
+struct ConstantFloatZero {};
+
+/// Predicate indicating that the operand is the same value as its producer's
+/// operand.
+struct SameOperandAsProducer
+    : public SingleValuePredicateParam<std::pair<int64_t, int64_t>> {
+  using Base::Base;
+};
+
 /// Indicates that the match optional. The matcher is still expected to run and
 /// capture if successful. The parameter can be set to false
 struct OptionalMatch : public SingleValuePredicateParam<bool> {
@@ -298,6 +317,9 @@
     return *this;
   }
 
+  StructuredOpMatcher &input(int64_t position,
+                             SameOperandAsProducer parentPosition);
+
   /// Adds a predicate checking that all input operands of the structured op
   /// have a permutation indexing map.
   StructuredOpMatcher &input(AllOperands tag, IsPermutation);
@@ -306,6 +328,21 @@
   /// have a projected permutation indexing map.
   StructuredOpMatcher &input(AllOperands tag, IsProjectedPermutation);
 
+  /// Adds a predicate checking that all input operands of the structured op
+  /// are projected along the given dimension.
+  StructuredOpMatcher &input(SmallVector<int64_t> &&positions, IsProjected dim);
+  StructuredOpMatcher &input(int64_t position, IsProjected dim) {
+    return input(SmallVector<int64_t>{position}, dim);
+  }
+
+  /// Adds a predicate checking that all input operands of the structured op
+  /// have identity indexing map.
+  StructuredOpMatcher &input(AllOperands tag, IsIdentity);
+  StructuredOpMatcher &input(SmallVector<int64_t> &&positions, IsIdentity);
+  StructuredOpMatcher &input(int64_t position, IsIdentity) {
+    return input(SmallVector<int64_t>{position}, IsIdentity());
+  }
+
   /// Adds a predicate checking that the bit width of the elemental type of the
   /// structured op input at the given position is equal to the given value.
   StructuredOpMatcher &input(int64_t position, ElementTypeBitWidth width);
@@ -314,6 +351,11 @@
   StructuredOpMatcher &input(int64_t position,
                              CaptureElementTypeBitWidth width);
 
+  /// Check if input is equal to a known constant.
+  // TODO: Support matching for constant ops.
+  StructuredOpMatcher &input(int64_t position, ConstantFloatMin);
+  StructuredOpMatcher &input(int64_t position, ConstantFloatZero);
+
   //===-------------------------------------------------------------------===//
   // Constraints on adjacent ops.
   //===-------------------------------------------------------------------===//
@@ -351,6 +393,14 @@
   /// have a projected permutation indexing map.
   StructuredOpMatcher &output(AllOperands tag, IsProjectedPermutation);
 
+  /// Adds a predicate checking that all output operands of the structured op
+  /// have an indexing map that is projected along the given dimension.
+  StructuredOpMatcher &output(AllOperands tag, IsProjected dim);
+
+  /// Adds a predicate checking that all output operands of the structured op
+  /// have identity indexing map.
+  StructuredOpMatcher &output(AllOperands tag, IsIdentity);
+
   /// Adds a predicate checking that the bit width of the elemental type of the
   /// structured op output at the given position is equal to the given value.
   StructuredOpMatcher &output(int64_t position, ElementTypeBitWidth width);
@@ -410,6 +460,28 @@
     return *this;
   }
 
+  //===-------------------------------------------------------------------===//
+  // Constraints on op region.
+  //===-------------------------------------------------------------------===//
+
+  /// Return true if the linalg op only contains a single op and the arguments
+  /// of that operation match the order of the linalg operands.
+  /// Example:
+  ///   linalg.generic
+  ///     ins(%0, %1 : tensor<?x?x?xf32>, tensor<?x?xf32>)
+  ///     outs(%2 : tensor<?x?x?xf32>) {
+  ///     ^bb0(%arg0: f32, %arg1: f32):
+  ///     %3 = arith.maxf %arg0, %arg1 : f32
+  ///     linalg.yield %3 : f32
+  ///   } -> tensor<?x?x?xf32>
+  template <typename OpType>
+  StructuredOpMatcher &singleOpWithCanonicaleArgs() {
+    return singleOpWithCanonicaleArgs(OpType::getOperationName());
+  }
+  StructuredOpMatcher &singleOpWithCanonicaleArgs(StringRef opname);
+  /// Check if the op is a linalg op with a single float reciprocal op.
+  StructuredOpMatcher &isFloatReciprocal();
+
 private:
   /// Checks that `matchers` captured all tilable ops nested in `parent` except
   /// for `linalgOp`. This is an implementation detail of allTilableOpsCaptured.
@@ -432,6 +504,10 @@
                         std::function<bool(Operation *)> matcher,
                         OptionalMatch optional);
 
+  // Common util for constant matcher.
+  StructuredOpMatcher &input(int64_t position,
+                             std::function<bool(llvm::APFloat)> floatValueFn);
+
   /// Additional predicates to be checked on the structured op.
   SmallVector<PredicateFn> predicates;
 };
@@ -557,6 +633,24 @@
                           StructuredOpMatcher &trailing,
                           MatchedReductionCaptures &captures);
 
+/// Create a group of matchers for a sequence of operations matching exactly a
+/// softmax operation.
+///
+///  %red = reduce_max(%0)
+///  %sub = sub(%0, %red)
+///  %exp = exp(%sub)
+///  %sum = reduce_sum(%exp)
+///  %rec = reciprocal(%sum)
+///  %mul = mul(%exp, %rec)
+void makeSoftmaxMatcher(transform_ext::StructuredOpMatcher &fillMinusInf,
+                        transform_ext::StructuredOpMatcher &maxReduction,
+                        transform_ext::StructuredOpMatcher &sub,
+                        transform_ext::StructuredOpMatcher &expOperand,
+                        transform_ext::StructuredOpMatcher &fillzero,
+                        transform_ext::StructuredOpMatcher &sum,
+                        transform_ext::StructuredOpMatcher &divOperand,
+                        transform_ext::StructuredOpMatcher &softmaxroot);
+
 } // namespace transform_ext
 } // namespace mlir
 
diff --git a/llvm-external-projects/iree-dialects/lib/Dialect/LinalgExt/IR/LinalgExtOps.cpp b/llvm-external-projects/iree-dialects/lib/Dialect/LinalgExt/IR/LinalgExtOps.cpp
index c088685..dfd53c5 100644
--- a/llvm-external-projects/iree-dialects/lib/Dialect/LinalgExt/IR/LinalgExtOps.cpp
+++ b/llvm-external-projects/iree-dialects/lib/Dialect/LinalgExt/IR/LinalgExtOps.cpp
@@ -2844,6 +2844,12 @@
       .reifyResultShapes(b, reifiedReturnShapes);
 }
 
+void SoftmaxOp::build(OpBuilder &builder, OperationState &state, Value source,
+                      Value output, int64_t dimension) {
+  build(builder, state, TypeRange({output.getType()}), ValueRange(source),
+        ValueRange(output), dimension);
+}
+
 //===----------------------------------------------------------------------===//
 // AttentionOp
 //===----------------------------------------------------------------------===//
diff --git a/llvm-external-projects/iree-dialects/lib/Transforms/CMakeLists.txt b/llvm-external-projects/iree-dialects/lib/Transforms/CMakeLists.txt
index eb77eea..5b233cf 100644
--- a/llvm-external-projects/iree-dialects/lib/Transforms/CMakeLists.txt
+++ b/llvm-external-projects/iree-dialects/lib/Transforms/CMakeLists.txt
@@ -9,9 +9,11 @@
   # TODO: break dialect dependency by implementing the transformation separately
   # and registering it.
   MLIRAsyncDialect
+  MLIRArithDialect
   MLIRFuncDialect
   MLIRLinalgDialect
   MLIRLinalgTransforms
+  MLIRMathDialect
 
   DEPENDS
   mlir-headers
diff --git a/llvm-external-projects/iree-dialects/lib/Transforms/TransformMatchers.cpp b/llvm-external-projects/iree-dialects/lib/Transforms/TransformMatchers.cpp
index 711d1a7..19a8893 100644
--- a/llvm-external-projects/iree-dialects/lib/Transforms/TransformMatchers.cpp
+++ b/llvm-external-projects/iree-dialects/lib/Transforms/TransformMatchers.cpp
@@ -7,8 +7,12 @@
 #include "iree-dialects/Transforms/TransformMatchers.h"
 
 #include "mlir/Analysis/SliceAnalysis.h"
+#include "mlir/Dialect/Arith/IR/Arith.h"
 #include "mlir/Dialect/Func/IR/FuncOps.h"
 #include "mlir/Dialect/Linalg/IR/Linalg.h"
+#include "mlir/Dialect/Math/IR/Math.h"
+#include "mlir/Dialect/SCF/IR/SCF.h"
+#include "mlir/Dialect/Tensor/IR/Tensor.h"
 #include "mlir/Dialect/Utils/StructuredOpsUtils.h"
 #include "llvm/ADT/STLExtras.h"
 #include "llvm/ADT/ScopeExit.h"
@@ -318,6 +322,76 @@
   return *this;
 }
 
+/// Helper to check if the map is an identity map with a projected dim.
+static bool isProjectedMap(AffineMap map, int64_t projectedDim) {
+  if (!map.isProjectedPermutation())
+    return false;
+  int64_t dimCounter = 0;
+  for (unsigned i = 0, e = map.getNumResults(); i < e; i++) {
+    // Skip the project dim.
+    if (dimCounter == projectedDim)
+      dimCounter++;
+    if (map.getDimPosition(i) != dimCounter++) {
+      return false;
+    }
+  }
+  return true;
+}
+
+transform_ext::StructuredOpMatcher &
+transform_ext::StructuredOpMatcher::input(SmallVector<int64_t> &&positions,
+                                          IsProjected dim) {
+  predicates.push_back([=](linalg::LinalgOp linalgOp) -> bool {
+    LLVM_DEBUG(DBGS() << "operands ";
+               llvm::interleaveComma(positions, llvm::dbgs());
+               llvm::dbgs() << " have a permutation maps with " << dim.value
+                            << " projected\n");
+    int64_t updatedDim =
+        dim.value >= 0 ? dim.value : linalgOp.getNumLoops() + dim.value;
+    for (int64_t position : positions) {
+      OpOperand *operand = linalgOp.getDpsInputOperand(position);
+      if (!isProjectedMap(linalgOp.getMatchingIndexingMap(operand), updatedDim))
+        return false;
+    }
+    return true;
+  });
+  return *this;
+}
+
+transform_ext::StructuredOpMatcher &
+transform_ext::StructuredOpMatcher::input(AllOperands tag, IsIdentity) {
+  predicates.push_back([=](linalg::LinalgOp linalgOp) -> bool {
+    LLVM_DEBUG(DBGS() << "all input operands have identity maps");
+    // all_of with a lambda requires const-casting dance, so using a loop.
+    for (OpOperand *operand : linalgOp.getDpsInputOperands()) {
+      if (!linalgOp.getMatchingIndexingMap(operand).isIdentity())
+        return false;
+    }
+    return true;
+  });
+  return *this;
+}
+
+transform_ext::StructuredOpMatcher &
+transform_ext::StructuredOpMatcher::input(SmallVector<int64_t> &&positions,
+                                          IsIdentity) {
+  predicates.push_back([=](linalg::LinalgOp linalgOp) -> bool {
+    LLVM_DEBUG(DBGS() << "input operands ";
+               llvm::interleaveComma(positions, llvm::dbgs());
+               llvm::dbgs() << " have identity maps");
+    // all_of with a lambda requires const-casting dance, so using a loop.
+    for (int64_t position : positions) {
+      int64_t updatedPosition =
+          position >= 0 ? position : linalgOp.getNumDpsInputs() + position;
+      OpOperand *operand = linalgOp.getDpsInputOperand(updatedPosition);
+      if (!linalgOp.getMatchingIndexingMap(operand).isIdentity())
+        return false;
+    }
+    return true;
+  });
+  return *this;
+}
+
 transform_ext::StructuredOpMatcher &
 transform_ext::StructuredOpMatcher::input(int64_t position,
                                           ElementTypeBitWidth width) {
@@ -368,6 +442,72 @@
   return *this;
 }
 
+transform_ext::StructuredOpMatcher &
+transform_ext::StructuredOpMatcher::input(int64_t position, ConstantFloatMin) {
+  return input(position,
+               [](llvm::APFloat f) { return f.isLargest() && f.isNegative(); });
+}
+
+transform_ext::StructuredOpMatcher &
+transform_ext::StructuredOpMatcher::input(int64_t position, ConstantFloatZero) {
+  return input(position, [](llvm::APFloat f) { return f.isZero(); });
+}
+
+transform_ext::StructuredOpMatcher &transform_ext::StructuredOpMatcher::input(
+    int64_t position, std::function<bool(llvm::APFloat)> floatValueFn) {
+  predicates.push_back([=](linalg::LinalgOp linalgOp) -> bool {
+    LLVM_DEBUG(DBGS() << "input operands " << position
+                      << " is a special floating point constant");
+    int64_t updatedPosition =
+        position >= 0 ? position : linalgOp.getNumDpsInputs() + position;
+    if (0 > updatedPosition || updatedPosition >= linalgOp.getNumDpsInputs())
+      return false;
+    auto cstOp = linalgOp.getDpsInputOperand(updatedPosition)
+                     ->get()
+                     .getDefiningOp<arith::ConstantFloatOp>();
+    if (!cstOp)
+      return false;
+    return floatValueFn(cstOp.value());
+  });
+  return *this;
+}
+
+transform_ext::StructuredOpMatcher &transform_ext::StructuredOpMatcher::input(
+    int64_t position, SameOperandAsProducer producerPosition) {
+  predicates.push_back([=](linalg::LinalgOp linalgOp) -> bool {
+    LLVM_DEBUG(DBGS() << "input operand " << position
+                      << " is the same value as operand "
+                      << producerPosition.value.second << " of the input "
+                      << producerPosition.value.first << " producer");
+    // Negative positions count back from the end of the DPS input list.
+    int64_t updatedPosition =
+        position >= 0 ? position : linalgOp.getNumDpsInputs() + position;
+    if (0 > updatedPosition || updatedPosition >= linalgOp.getNumDpsInputs())
+      return false;
+    int64_t updatedProducerPosition =
+        producerPosition.value.first >= 0
+            ? producerPosition.value.first
+            : linalgOp.getNumDpsInputs() + producerPosition.value.first;
+    if (0 > updatedProducerPosition ||
+        updatedProducerPosition >= linalgOp.getNumDpsInputs())
+      return false;
+    Operation *producer = linalgOp.getDpsInputOperand(updatedProducerPosition)
+                              ->get()
+                              .getDefiningOp();
+    if (!producer)
+      return false;
+    int64_t producerOperandPos =
+        producerPosition.value.second >= 0
+            ? producerPosition.value.second
+            : producer->getNumOperands() + producerPosition.value.second;
+    // Bounds-check the producer operand index itself (not the producer op
+    // position) before dereferencing it.
+    if (0 > producerOperandPos ||
+        producerOperandPos >= producer->getNumOperands())
+      return false;
+    return producer->getOperand(producerOperandPos) ==
+           linalgOp.getDpsInputOperand(updatedPosition)->get();
+  });
+  return *this;
+}
+
 //===---------------------------------------------------------------------===//
 // Constraints on output operands.
 //===---------------------------------------------------------------------===//
@@ -428,6 +568,35 @@
 }
 
 transform_ext::StructuredOpMatcher &
+transform_ext::StructuredOpMatcher::output(AllOperands tag, IsProjected dim) {
+  predicates.push_back([=](linalg::LinalgOp linalgOp) -> bool {
+    LLVM_DEBUG(DBGS() << "all output operands have a maps with projected");
+    int64_t updatedDim =
+        dim.value >= 0 ? dim.value : linalgOp.getNumLoops() + dim.value;
+    // all_of with a lambda requires const-casting dance, so using a loop.
+    for (OpOperand *operand : linalgOp.getDpsInitOperands()) {
+      if (!isProjectedMap(linalgOp.getMatchingIndexingMap(operand), updatedDim))
+        return false;
+    }
+    return true;
+  });
+  return *this;
+}
+
+transform_ext::StructuredOpMatcher &
+transform_ext::StructuredOpMatcher::output(AllOperands tag, IsIdentity) {
+  predicates.push_back([=](linalg::LinalgOp linalgOp) -> bool {
+    LLVM_DEBUG(DBGS() << "all output operands have identity permutation maps");
+    for (OpOperand *operand : linalgOp.getDpsInitOperands()) {
+      if (!linalgOp.getMatchingIndexingMap(operand).isIdentity())
+        return false;
+    }
+    return true;
+  });
+  return *this;
+}
+
+transform_ext::StructuredOpMatcher &
 transform_ext::StructuredOpMatcher::output(int64_t position,
                                            ElementTypeBitWidth width) {
   predicates.push_back([=](linalg::LinalgOp linalgOp) -> bool {
@@ -546,6 +715,62 @@
   return numTilableOps == matched.size();
 }
 
+//===-------------------------------------------------------------------===//
+// Constraints on op region.
+//===-------------------------------------------------------------------===//
+
+transform_ext::StructuredOpMatcher &
+transform_ext::StructuredOpMatcher::singleOpWithCanonicaleArgs(
+    StringRef opcode) {
+  predicates.push_back([=](linalg::LinalgOp linalgOp) {
+    if (linalgOp.getBlock()->getOperations().size() != 2)
+      return false;
+    Operation *innerOp = &(*linalgOp.getBlock()->getOperations().begin());
+    if (innerOp->getName().getStringRef() != opcode ||
+        innerOp->getNumResults() != 1)
+      return false;
+    Operation *yieldOp = linalgOp.getBlock()->getTerminator();
+    if (yieldOp->getNumOperands() != 1)
+      return false;
+    if (yieldOp->getOperand(0).getDefiningOp() != innerOp)
+      return false;
+    for (auto [index, operand] : llvm::enumerate(innerOp->getOperands())) {
+      auto arg = dyn_cast<BlockArgument>(operand);
+      if (!arg || arg.getParentBlock() != linalgOp.getBlock() ||
+          arg.getArgNumber() != index)
+        return false;
+    }
+    return true;
+  });
+  return *this;
+}
+
+transform_ext::StructuredOpMatcher &
+transform_ext::StructuredOpMatcher::isFloatReciprocal() {
+  predicates.push_back([=](linalg::LinalgOp linalgOp) {
+    LLVM_DEBUG(DBGS() << "op region represents a reciprocal operation");
+    if (linalgOp.getBlock()->getOperations().size() != 2)
+      return false;
+    Operation *innerOp = &(*linalgOp.getBlock()->getOperations().begin());
+    if (!isa<arith::DivFOp>(innerOp) || innerOp->getNumResults() != 1)
+      return false;
+    Operation *yieldOp = linalgOp.getBlock()->getTerminator();
+    if (yieldOp->getNumOperands() != 1)
+      return false;
+    if (yieldOp->getOperand(0).getDefiningOp() != innerOp)
+      return false;
+    auto cst = innerOp->getOperand(0).getDefiningOp<arith::ConstantFloatOp>();
+    if (!cst || cst.value().convertToDouble() != 1.0)
+      return false;
+    auto arg = dyn_cast<BlockArgument>(innerOp->getOperand(1));
+    if (!arg || arg.getParentBlock() != linalgOp.getBlock() ||
+        arg.getArgNumber() != 0)
+      return false;
+    return true;
+  });
+  return *this;
+}
+
 //===---------------------------------------------------------------------===//
 // MatchCallbackResult.
 //===---------------------------------------------------------------------===//
@@ -672,3 +897,81 @@
   reduction = reduction.result(0, HasAnyUse(), trailing, OptionalMatch())
                   .allTilableOpsCaptured<func::FuncOp>();
 }
+
+void transform_ext::makeSoftmaxMatcher(
+    transform_ext::StructuredOpMatcher &fillMinusInf,
+    transform_ext::StructuredOpMatcher &maxReduction,
+    transform_ext::StructuredOpMatcher &sub,
+    transform_ext::StructuredOpMatcher &expOperand,
+    transform_ext::StructuredOpMatcher &fillzero,
+    transform_ext::StructuredOpMatcher &sum,
+    transform_ext::StructuredOpMatcher &divOperand,
+    transform_ext::StructuredOpMatcher &softmaxroot) {
+
+  fillMinusInf = m_StructuredOp<linalg::FillOp>().input(0, ConstantFloatMin());
+  maxReduction = transform_ext::m_StructuredOp<linalg::GenericOp>()
+                     .singleOpWithCanonicaleArgs<arith::MaxFOp>()
+                     // Only handle most inner reduction for now.
+                     .dim(-1, utils::IteratorType::reduction)
+                     .dim(AllDimsExcept({-1}), utils::IteratorType::parallel)
+                     .input(NumEqualsTo(1))
+                     .input(AllOperands(), IsIdentity())
+                     .output(NumEqualsTo(1))
+                     .output(AllOperands(), IsProjected(-1));
+  maxReduction = maxReduction.output(0, fillMinusInf);
+
+  sub = transform_ext::m_StructuredOp<linalg::GenericOp>()
+            .singleOpWithCanonicaleArgs<arith::SubFOp>()
+            .dim(AllDims(), utils::IteratorType::parallel)
+            .input(NumEqualsTo(2))
+            .input(0, IsIdentity())
+            .input(1, IsProjected(-1))
+            .output(NumEqualsTo(1))
+            .output(AllOperands(), IsIdentity());
+  sub =
+      sub.input(0, SameOperandAsProducer(std::pair<int64_t, int64_t>({1, 0})));
+  sub = sub.input(1, maxReduction);
+
+  expOperand = m_StructuredOp<linalg::GenericOp>()
+                   .singleOpWithCanonicaleArgs<math::ExpOp>()
+                   .dim(AllDims(), utils::IteratorType::parallel)
+                   .input(NumEqualsTo(1))
+                   .input(AllOperands(), IsIdentity())
+                   .output(AllOperands(), IsIdentity())
+                   .output(NumEqualsTo(1));
+  expOperand = expOperand.input(0, sub);
+
+  fillzero = m_StructuredOp<linalg::FillOp>().input(0, ConstantFloatZero());
+  sum = m_StructuredOp<linalg::GenericOp>()
+            .singleOpWithCanonicaleArgs<arith::AddFOp>()
+            // Only handle most inner reduction for now.
+            .dim(-1, utils::IteratorType::reduction)
+            .dim(AllDimsExcept({-1}), utils::IteratorType::parallel)
+            .input(NumEqualsTo(1))
+            .input(AllOperands(), IsIdentity())
+            .output(AllOperands(), IsProjected(-1))
+            .output(NumEqualsTo(1));
+  sum = sum.input(0, expOperand);
+  sum = sum.output(0, fillzero);
+
+  divOperand = m_StructuredOp<linalg::GenericOp>()
+                   .isFloatReciprocal()
+                   .dim(AllDims(), utils::IteratorType::parallel)
+                   .input(NumEqualsTo(1))
+                   .input(AllOperands(), IsIdentity())
+                   .output(AllOperands(), IsIdentity())
+                   .output(NumEqualsTo(1));
+  divOperand = divOperand.input(0, sum);
+
+  softmaxroot = transform_ext::m_StructuredOp<linalg::GenericOp>()
+                    .singleOpWithCanonicaleArgs<arith::MulFOp>()
+                    .dim(AllDims(), utils::IteratorType::parallel)
+                    .input(NumEqualsTo(2))
+                    .input(0, IsIdentity())
+                    .input(1, IsProjected(-1))
+                    .output(NumEqualsTo(1))
+                    .output(AllOperands(), IsIdentity());
+
+  softmaxroot = softmaxroot.input(0, expOperand);
+  softmaxroot = softmaxroot.input(1, divOperand);
+}