Attach Fusion interface to `linalg.softmax` (#18550)
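
This change attaches an external-model implementation of LinalgFusionOpInterface to `linalg.softmax`, reporting rank-matching identity indexing maps for its DPS inputs and inits, so dispatch-region formation can pull elementwise consumers (such as the `arith.truncf` generic in the new lit test) into the same `flow.dispatch.workgroups` region as the softmax. The CLIP golden dispatch counts in the regression workflow drop from 1225 to 1139 accordingly.

A minimal caller-side sketch of what the attached interface enables follows. It is not part of this patch: the include path and the `iree_compiler::IREE::LinalgExt` namespace are assumed from the adapter code below, and the helper `inspectFusionMaps` is hypothetical.

  // Hypothetical sketch: after the external model is attached, linalg.softmax
  // can be queried through LinalgFusionOpInterface like other structured ops.
  #include "iree/compiler/Dialect/LinalgExt/IR/LinalgExtInterfaces.h"
  #include "llvm/ADT/SmallVector.h"
  #include "mlir/IR/AffineMap.h"
  #include "mlir/IR/Operation.h"

  static void inspectFusionMaps(mlir::Operation *op) {
    // The dyn_cast now succeeds for linalg::SoftmaxOp because the LinalgExt
    // dialect registers SoftmaxFusionOpInterfaceAdapter as an external model.
    auto fusionOp = llvm::dyn_cast<
        mlir::iree_compiler::IREE::LinalgExt::LinalgFusionOpInterface>(op);
    if (!fusionOp)
      return;
    // Per the adapter, each DPS input and init of softmax gets an identity
    // map of matching rank, e.g. (d0, d1, d2) -> (d0, d1, d2) for 3-D tensors.
    llvm::SmallVector<mlir::AffineMap> operandMaps =
        fusionOp.getIndexingMapsForOperands();
    llvm::SmallVector<mlir::AffineMap> allMaps = fusionOp.getIndexingMapsArray();
    (void)operandMaps;
    (void)allMaps;
  }

The method names mirror those implemented by SoftmaxFusionOpInterfaceAdapter in LinalgExtDialect.cpp below, which is what the elementwise-fusion logic consults when deciding whether a consumer can join the softmax dispatch.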
diff --git a/.github/workflows/pkgci_regression_test.yml b/.github/workflows/pkgci_regression_test.yml
index 402158e..3cec71e 100644
--- a/.github/workflows/pkgci_regression_test.yml
+++ b/.github/workflows/pkgci_regression_test.yml
@@ -221,7 +221,7 @@
--goldentime-rocm-clip-ms 18.5 \
--goldentime-rocm-vae-ms 337.0 \
--goldendispatch-rocm-unet 1551 \
- --goldendispatch-rocm-clip 1225 \
+ --goldendispatch-rocm-clip 1139 \
--goldendispatch-rocm-vae 248 \
--goldensize-rocm-unet-bytes 2280000 \
--goldensize-rocm-clip-bytes 860000 \
@@ -242,7 +242,7 @@
--goldentime-rocm-clip-ms 15.5 \
--goldentime-rocm-vae-ms 80.0 \
--goldendispatch-rocm-unet 1551 \
- --goldendispatch-rocm-clip 1225 \
+ --goldendispatch-rocm-clip 1139 \
--goldendispatch-rocm-vae 248 \
--goldensize-rocm-unet-bytes 2270000 \
--goldensize-rocm-clip-bytes 860000 \
diff --git a/compiler/src/iree/compiler/Dialect/LinalgExt/IR/LinalgExtDialect.cpp b/compiler/src/iree/compiler/Dialect/LinalgExt/IR/LinalgExtDialect.cpp
index 5996310..5fbe89c 100644
--- a/compiler/src/iree/compiler/Dialect/LinalgExt/IR/LinalgExtDialect.cpp
+++ b/compiler/src/iree/compiler/Dialect/LinalgExt/IR/LinalgExtDialect.cpp
@@ -8,12 +8,14 @@
#include "iree/compiler/Dialect/LinalgExt/IR/LinalgExtInterfaces.h"
#include "iree/compiler/Dialect/LinalgExt/IR/LinalgExtOps.h"
+#include "llvm/ADT/STLExtras.h"
#include "llvm/ADT/SmallVector.h"
#include "llvm/ADT/TypeSwitch.h"
#include "llvm/Support/SourceMgr.h"
#include "mlir/Dialect/Linalg/IR/Linalg.h"
#include "mlir/IR/AffineMap.h"
#include "mlir/IR/Attributes.h"
+#include "mlir/IR/Builders.h"
#include "mlir/IR/DialectImplementation.h"
#include "mlir/IR/MLIRContext.h"
#include "mlir/IR/OpDefinition.h"
@@ -52,7 +54,6 @@
};
// Used to register the LinalgFusionOpInterface with the linalg ops.
-namespace {
template <typename ConcreteType>
struct LinalgFusionOpInterfaceAdapter
: public LinalgFusionOpInterface::ExternalModel<
@@ -103,6 +104,48 @@
return inputMaps;
}
};
+
+namespace {
+struct SoftmaxFusionOpInterfaceAdapter
+ : public LinalgFusionOpInterface::ExternalModel<
+ SoftmaxFusionOpInterfaceAdapter, linalg::SoftmaxOp> {
+public:
+ SmallVector<AffineMap> getIndexingMapsForOperands(mlir::Operation *op) const {
+ Builder b(op->getContext());
+ return llvm::to_vector(llvm::map_range(
+ llvm::cast<linalg::SoftmaxOp>(op).getDpsInputs(),
+ [&b](Value operand) -> AffineMap {
+ auto rank = cast<ShapedType>(operand.getType()).getRank();
+ return b.getMultiDimIdentityMap(rank);
+ }));
+ }
+
+ SmallVector<AffineMap> getIndexingMapsForResults(mlir::Operation *op) const {
+ Builder b(op->getContext());
+ return llvm::to_vector(llvm::map_range(
+ llvm::cast<linalg::SoftmaxOp>(op).getDpsInits(),
+ [&b](Value operand) -> AffineMap {
+ auto rank = cast<ShapedType>(operand.getType()).getRank();
+ return b.getMultiDimIdentityMap(rank);
+ }));
+ }
+
+ AffineMap getIndexingMapMatchingResult(mlir::Operation *op,
+ OpResult result) const {
+ return getIndexingMapsForResults(op)[result.getResultNumber()];
+ }
+
+ AffineMap getMatchingIndexingMap(mlir::Operation *op,
+ OpOperand *operand) const {
+ return getIndexingMapsForOperands(op)[operand->getOperandNumber()];
+ }
+
+ SmallVector<AffineMap> getIndexingMapsArray(mlir::Operation *op) const {
+ auto inputMaps = getIndexingMapsForOperands(op);
+ llvm::append_range(inputMaps, getIndexingMapsForResults(op));
+ return inputMaps;
+ }
+};
} // namespace
template <typename... Args>
@@ -125,6 +168,8 @@
registerOpsWithLinalgExtOpInterface<
#include "mlir/Dialect/Linalg/IR/LinalgStructuredOps.cpp.inc"
>(context);
+ linalg::SoftmaxOp::attachInterface<SoftmaxFusionOpInterfaceAdapter>(*context);
+
addInterfaces<IREELinalgExtInlinerInterface>();
addAttributes<
diff --git a/compiler/src/iree/compiler/DispatchCreation/test/dispatch_linalg_on_tensors_default.mlir b/compiler/src/iree/compiler/DispatchCreation/test/dispatch_linalg_on_tensors_default.mlir
index 6deeda3..dae758c 100644
--- a/compiler/src/iree/compiler/DispatchCreation/test/dispatch_linalg_on_tensors_default.mlir
+++ b/compiler/src/iree/compiler/DispatchCreation/test/dispatch_linalg_on_tensors_default.mlir
@@ -122,3 +122,26 @@
// CHECK: %[[GENERIC:.+]] = linalg.generic
// CHECK: flow.dispatch.tensor.store %[[GENERIC]]
// CHECK: util.return %[[DISPATCH1]]
+
+util.func @softmax(%arg0: tensor<2x16x32xf32>) -> tensor<2x16x32xf16> {
+ %empty0 = tensor.empty() : tensor<2x16x32xf32>
+ %empty1 = tensor.empty() : tensor<2x16x32xf16>
+ %1 = linalg.softmax dimension(2) ins(%arg0 : tensor<2x16x32xf32>) outs(%empty0 : tensor<2x16x32xf32>) -> tensor<2x16x32xf32>
+ %2 = linalg.generic {
+ indexing_maps = [affine_map<(d0, d1, d2) -> (d0, d1, d2)>, affine_map<(d0, d1, d2) -> (d0, d1, d2)>],
+ iterator_types = ["parallel", "parallel", "parallel"]}
+ ins(%1 : tensor<2x16x32xf32>) outs(%empty1 : tensor<2x16x32xf16>){
+ ^bb0(%in : f32, %out : f16):
+ %3 = arith.truncf %in : f32 to f16
+ linalg.yield %3 : f16
+ } -> tensor<2x16x32xf16>
+ util.return %2 : tensor<2x16x32xf16>
+}
+
+// CHECK-LABEL: util.func public @softmax
+// CHECK: %[[DISPATCH1:.+]] = flow.dispatch.workgroups
+// CHECK: %[[SOFTMAX:.+]] = linalg.softmax
+// CHECK: %[[GENERIC:.+]] = linalg.generic
+// CHECK-SAME: ins(%[[SOFTMAX]]
+// CHECK: flow.dispatch.tensor.store %[[GENERIC]]
+// CHECK: util.return %[[DISPATCH1]]