Attach Fusion interface to `linalg.softmax` (#18550)

diff --git a/.github/workflows/pkgci_regression_test.yml b/.github/workflows/pkgci_regression_test.yml
index 402158e..3cec71e 100644
--- a/.github/workflows/pkgci_regression_test.yml
+++ b/.github/workflows/pkgci_regression_test.yml
@@ -221,7 +221,7 @@
             --goldentime-rocm-clip-ms 18.5 \
             --goldentime-rocm-vae-ms 337.0 \
             --goldendispatch-rocm-unet 1551 \
-            --goldendispatch-rocm-clip 1225 \
+            --goldendispatch-rocm-clip 1139 \
             --goldendispatch-rocm-vae 248 \
             --goldensize-rocm-unet-bytes 2280000  \
             --goldensize-rocm-clip-bytes 860000 \
@@ -242,7 +242,7 @@
             --goldentime-rocm-clip-ms 15.5 \
             --goldentime-rocm-vae-ms 80.0 \
             --goldendispatch-rocm-unet 1551 \
-            --goldendispatch-rocm-clip 1225 \
+            --goldendispatch-rocm-clip 1139 \
             --goldendispatch-rocm-vae 248 \
             --goldensize-rocm-unet-bytes 2270000 \
             --goldensize-rocm-clip-bytes 860000  \
diff --git a/compiler/src/iree/compiler/Dialect/LinalgExt/IR/LinalgExtDialect.cpp b/compiler/src/iree/compiler/Dialect/LinalgExt/IR/LinalgExtDialect.cpp
index 5996310..5fbe89c 100644
--- a/compiler/src/iree/compiler/Dialect/LinalgExt/IR/LinalgExtDialect.cpp
+++ b/compiler/src/iree/compiler/Dialect/LinalgExt/IR/LinalgExtDialect.cpp
@@ -8,12 +8,14 @@
 
 #include "iree/compiler/Dialect/LinalgExt/IR/LinalgExtInterfaces.h"
 #include "iree/compiler/Dialect/LinalgExt/IR/LinalgExtOps.h"
+#include "llvm/ADT/STLExtras.h"
 #include "llvm/ADT/SmallVector.h"
 #include "llvm/ADT/TypeSwitch.h"
 #include "llvm/Support/SourceMgr.h"
 #include "mlir/Dialect/Linalg/IR/Linalg.h"
 #include "mlir/IR/AffineMap.h"
 #include "mlir/IR/Attributes.h"
+#include "mlir/IR/Builders.h"
 #include "mlir/IR/DialectImplementation.h"
 #include "mlir/IR/MLIRContext.h"
 #include "mlir/IR/OpDefinition.h"
@@ -52,7 +54,6 @@
 };
 
 // Used to register the LinalgFusionOpInterface with the linalg ops.
-namespace {
 template <typename ConcreteType>
 struct LinalgFusionOpInterfaceAdapter
     : public LinalgFusionOpInterface::ExternalModel<
@@ -103,6 +104,48 @@
     return inputMaps;
   }
 };
+
+namespace {
+/// External model giving linalg.softmax identity indexing maps per operand.
+struct SoftmaxFusionOpInterfaceAdapter
+    : public LinalgFusionOpInterface::ExternalModel<
+          SoftmaxFusionOpInterfaceAdapter, linalg::SoftmaxOp> {
+  /// Identity map with as many dims as the rank of `value`'s shaped type.
+  static AffineMap getIdentityMap(Value value) {
+    auto rank = cast<ShapedType>(value.getType()).getRank();
+    return Builder(value.getContext()).getMultiDimIdentityMap(rank);
+  }
+
+  /// Identity maps for the DPS inputs only (inits are not included).
+  SmallVector<AffineMap> getIndexingMapsForOperands(mlir::Operation *op) const {
+    return llvm::to_vector(llvm::map_range(
+        llvm::cast<linalg::SoftmaxOp>(op).getDpsInputs(), getIdentityMap));
+  }
+
+  /// Identity maps for the DPS inits.
+  SmallVector<AffineMap> getIndexingMapsForResults(mlir::Operation *op) const {
+    return llvm::to_vector(llvm::map_range(
+        llvm::cast<linalg::SoftmaxOp>(op).getDpsInits(), getIdentityMap));
+  }
+
+  AffineMap getIndexingMapMatchingResult(mlir::Operation *op,
+                                         OpResult result) const {
+    return getIndexingMapsForResults(op)[result.getResultNumber()];
+  }
+
+  /// Operand numbers cover inputs and inits, so index the combined map list.
+  AffineMap getMatchingIndexingMap(mlir::Operation *op,
+                                   OpOperand *operand) const {
+    return getIndexingMapsArray(op)[operand->getOperandNumber()];
+  }
+
+  /// Input maps followed by init maps.
+  SmallVector<AffineMap> getIndexingMapsArray(mlir::Operation *op) const {
+    auto inputMaps = getIndexingMapsForOperands(op);
+    llvm::append_range(inputMaps, getIndexingMapsForResults(op));
+    return inputMaps;
+  }
+};
 } // namespace
 
 template <typename... Args>
@@ -125,6 +168,8 @@
   registerOpsWithLinalgExtOpInterface<
 #include "mlir/Dialect/Linalg/IR/LinalgStructuredOps.cpp.inc"
       >(context);
+  linalg::SoftmaxOp::attachInterface<SoftmaxFusionOpInterfaceAdapter>(*context);
+
   addInterfaces<IREELinalgExtInlinerInterface>();
 
   addAttributes<
diff --git a/compiler/src/iree/compiler/DispatchCreation/test/dispatch_linalg_on_tensors_default.mlir b/compiler/src/iree/compiler/DispatchCreation/test/dispatch_linalg_on_tensors_default.mlir
index 6deeda3..dae758c 100644
--- a/compiler/src/iree/compiler/DispatchCreation/test/dispatch_linalg_on_tensors_default.mlir
+++ b/compiler/src/iree/compiler/DispatchCreation/test/dispatch_linalg_on_tensors_default.mlir
@@ -122,3 +122,26 @@
 //       CHECK:     %[[GENERIC:.+]] = linalg.generic
 //       CHECK:     flow.dispatch.tensor.store %[[GENERIC]]
 //       CHECK:   util.return %[[DISPATCH1]]
+
+util.func @softmax(%arg0: tensor<2x16x32xf32>) -> tensor<2x16x32xf16> {
+  %empty0 = tensor.empty() : tensor<2x16x32xf32>
+  %empty1 = tensor.empty() : tensor<2x16x32xf16>
+  %1 = linalg.softmax dimension(2) ins(%arg0 : tensor<2x16x32xf32>) outs(%empty0 : tensor<2x16x32xf32>) -> tensor<2x16x32xf32>
+  %2 = linalg.generic {
+    indexing_maps = [affine_map<(d0, d1, d2) -> (d0, d1, d2)>, affine_map<(d0, d1, d2) -> (d0, d1, d2)>],
+    iterator_types = ["parallel", "parallel", "parallel"]}
+    ins(%1 : tensor<2x16x32xf32>) outs(%empty1 : tensor<2x16x32xf16>){
+    ^bb0(%in : f32, %out : f16):
+      %3 = arith.truncf %in : f32 to f16  // elementwise consumer of the softmax result
+      linalg.yield %3 : f16
+    } -> tensor<2x16x32xf16>
+  util.return %2 : tensor<2x16x32xf16>
+}
+
+// CHECK-LABEL: util.func public @softmax
+//       CHECK:   %[[DISPATCH1:.+]] = flow.dispatch.workgroups
+//       CHECK:     %[[SOFTMAX:.+]] = linalg.softmax
+//       CHECK:     %[[GENERIC:.+]] = linalg.generic
+//  CHECK-SAME:       ins(%[[SOFTMAX]]
+//       CHECK:     flow.dispatch.tensor.store %[[GENERIC]]
+//       CHECK:   util.return %[[DISPATCH1]]