Add matcher to detect softmax op and add pass to raise op (#12084)
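This change adds transform_ext structured-op matchers that recognize the max / sub / exp / sum / reciprocal / mul chain of linalg ops produced for softmax, and a new Flow pass (--iree-flow-raise-special-ops) that rewrites the matched chain into a single iree_linalg_ext.softmax op.

A minimal sketch of what the raising produces, assuming a hypothetical statically shaped input (the function name, SSA names, and 2x4 shape below are illustrative only; the lit test added in this change covers the dynamic-shape case):

func.func @softmax_raised(%input : tensor<2x4xf32>) -> tensor<2x4xf32> {
  // The whole chain of linalg.generic ops collapses into one LinalgExt op;
  // the softmax is taken along the innermost dimension (dimension(1) here).
  %empty = tensor.empty() : tensor<2x4xf32>
  %softmax = iree_linalg_ext.softmax dimension(1)
      ins(%input : tensor<2x4xf32>) outs(%empty : tensor<2x4xf32>) -> tensor<2x4xf32>
  return %softmax : tensor<2x4xf32>
}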
diff --git a/compiler/src/iree/compiler/Dialect/Flow/Transforms/BUILD b/compiler/src/iree/compiler/Dialect/Flow/Transforms/BUILD
index 01caff8..4dd7d58 100644
--- a/compiler/src/iree/compiler/Dialect/Flow/Transforms/BUILD
+++ b/compiler/src/iree/compiler/Dialect/Flow/Transforms/BUILD
@@ -57,6 +57,7 @@
         "OutlineDispatchRegions.cpp",
         "PassDetail.h",
         "Passes.cpp",
+        "RaiseSpecialOps.cpp",
        "RegionOpUtils.cpp",
        "SetEncoding.cpp",
        "SplitReduction.cpp",
@@ -83,6 +84,7 @@
        "//compiler/src/iree/compiler/Dialect/Util/IR",
        "//compiler/src/iree/compiler/Dialect/Util/Transforms",
        "//compiler/src/iree/compiler/Utils",
+        "//llvm-external-projects/iree-dialects:IREEDialectsTransforms",
        "//llvm-external-projects/iree-dialects:IREELinalgExtDialect",
        "//llvm-external-projects/iree-dialects:IREELinalgExtPasses",
        "//llvm-external-projects/iree-dialects:IREELinalgExtTransformOps",
diff --git a/compiler/src/iree/compiler/Dialect/Flow/Transforms/CMakeLists.txt b/compiler/src/iree/compiler/Dialect/Flow/Transforms/CMakeLists.txt
index cfc5586..d00acc8 100644
--- a/compiler/src/iree/compiler/Dialect/Flow/Transforms/CMakeLists.txt
+++ b/compiler/src/iree/compiler/Dialect/Flow/Transforms/CMakeLists.txt
@@ -55,6 +55,7 @@
     "OutlineDispatchRegions.cpp"
    "PassDetail.h"
    "Passes.cpp"
+    "RaiseSpecialOps.cpp"
    "RegionOpUtils.cpp"
    "SetEncoding.cpp"
    "SplitReduction.cpp"
@@ -64,6 +65,7 @@
    "VerifyInputLegality.cpp"
  DEPS
    ::PassesIncGen
+    IREEDialectsTransforms
    IREELinalgExtDialect
    IREELinalgExtPasses
    IREELinalgExtTransformOps
diff --git a/compiler/src/iree/compiler/Dialect/Flow/Transforms/Passes.h b/compiler/src/iree/compiler/Dialect/Flow/Transforms/Passes.h
index 52dfd51..c750b1b 100644
--- a/compiler/src/iree/compiler/Dialect/Flow/Transforms/Passes.h
+++ b/compiler/src/iree/compiler/Dialect/Flow/Transforms/Passes.h
@@ -186,6 +186,10 @@
 std::unique_ptr<OperationPass<mlir::ModuleOp>>
 createDeduplicateExecutablesPass();
 
+// Create a pass to raise a sequence of ops to a higher-level LinalgExt
+// representation.
+std::unique_ptr<Pass> createRaiseSpecialOps();
+
 // Create a pass to split reduction dimension.
 std::unique_ptr<Pass> createSplitReductionPass();
diff --git a/compiler/src/iree/compiler/Dialect/Flow/Transforms/Passes.td b/compiler/src/iree/compiler/Dialect/Flow/Transforms/Passes.td
index 4efbfcc..926bb37 100644
--- a/compiler/src/iree/compiler/Dialect/Flow/Transforms/Passes.td
+++ b/compiler/src/iree/compiler/Dialect/Flow/Transforms/Passes.td
@@ -229,6 +229,12 @@
   let constructor = "mlir::iree_compiler::IREE::Flow::createDumpDispatchGraphPass()";
 }
 
+def RaiseSpecialOps :
+    Pass<"iree-flow-raise-special-ops", ""> {
+  let summary = "Raise special ops like softmax to the high-level LinalgExt representation";
+  let constructor = "mlir::iree_compiler::IREE::Flow::createRaiseSpecialOps()";
+}
+
 def SplitReduction :
     Pass<"iree-flow-split-reduction-ops", ""> {
   let summary = "Split reduction dimension to increase parallelism.";
diff --git a/compiler/src/iree/compiler/Dialect/Flow/Transforms/RaiseSpecialOps.cpp b/compiler/src/iree/compiler/Dialect/Flow/Transforms/RaiseSpecialOps.cpp
new file mode 100644
index 0000000..a71b92b
--- /dev/null
+++ b/compiler/src/iree/compiler/Dialect/Flow/Transforms/RaiseSpecialOps.cpp
@@ -0,0 +1,74 @@
+// Copyright 2023 The IREE Authors
+//
+// Licensed under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+
+#include "iree-dialects/Dialect/LinalgExt/IR/LinalgExtDialect.h"
+#include "iree-dialects/Dialect/LinalgExt/IR/LinalgExtInterfaces.h"
+#include "iree-dialects/Dialect/LinalgExt/IR/LinalgExtOps.h"
+#include "iree-dialects/Dialect/LinalgTransform/StructuredTransformOpsExt.h"
+#include "iree-dialects/Transforms/TransformMatchers.h"
+#include "iree/compiler/Dialect/Flow/Transforms/PassDetail.h"
+#include "iree/compiler/Dialect/Flow/Transforms/Passes.h"
+#include "mlir/Dialect/Linalg/IR/Linalg.h"
+#include "mlir/Dialect/Tensor/IR/Tensor.h"
+#include "mlir/IR/PatternMatch.h"
+#include "mlir/Pass/Pass.h"
+#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
+
+using namespace mlir;
+using transform_ext::StructuredOpMatcher;
+
+namespace mlir {
+namespace iree_compiler {
+namespace IREE {
+namespace Flow {
+
+namespace {
+
+struct RaiseSpecialOpsPass : public RaiseSpecialOpsBase<RaiseSpecialOpsPass> {
+  void getDependentDialects(DialectRegistry &registry) const override {
+    registry.insert<IREE::LinalgExt::IREELinalgExtDialect>();
+  }
+
+  void runOnOperation() override {
+    SmallVector<std::pair<linalg::LinalgOp, Value>> softmaxRoots;
+    getOperation()->walk([&](linalg::LinalgOp op) {
+      transform_ext::StructuredOpMatcher fillMinusInf;
+      transform_ext::StructuredOpMatcher maxReduction;
+      transform_ext::StructuredOpMatcher sub;
+      transform_ext::StructuredOpMatcher expOperand;
+      transform_ext::StructuredOpMatcher fillzero;
+      transform_ext::StructuredOpMatcher sum;
+      transform_ext::StructuredOpMatcher divOperand;
+      transform_ext::StructuredOpMatcher softmaxroot;
+      makeSoftmaxMatcher(fillMinusInf, maxReduction, sub, expOperand, fillzero,
+                         sum, divOperand, softmaxroot);
+      if (matchPattern(op, softmaxroot)) {
+        Value src = maxReduction.getCaptured()->getOperand(0);
+        softmaxRoots.push_back(std::make_pair(op, src));
+      }
+    });
+    for (std::pair<linalg::LinalgOp, Value> softmax : softmaxRoots) {
+      linalg::LinalgOp op = softmax.first;
+      Value src = softmax.second;
+      IRRewriter rewriter(op.getContext());
+      rewriter.setInsertionPoint(op);
+      rewriter.replaceOpWithNewOp<IREE::LinalgExt::SoftmaxOp>(
+          op, src, op.getDpsInitOperand(0)->get(), op.getNumLoops() - 1);
+    }
+  }
+};
+
+} // namespace
+
+std::unique_ptr<Pass> createRaiseSpecialOps() {
+  return std::make_unique<RaiseSpecialOpsPass>();
+}
+
+} // namespace Flow
+} // namespace IREE
+} // namespace iree_compiler
+} // namespace mlir
diff --git a/compiler/src/iree/compiler/Dialect/Flow/Transforms/test/BUILD b/compiler/src/iree/compiler/Dialect/Flow/Transforms/test/BUILD
index 3b873ed..4b23af8 100644
--- a/compiler/src/iree/compiler/Dialect/Flow/Transforms/test/BUILD
+++ b/compiler/src/iree/compiler/Dialect/Flow/Transforms/test/BUILD
@@ -39,6 +39,7 @@
             "interchange_transpose_generic_ops.mlir",
            "optimize_numerics.mlir",
            "outline_dispatch_regions.mlir",
+            "raise_special_ops.mlir",
            "set_encoding.mlir",
            "strip_and_splat_constant_variables.mlir",
            "strip_signedness.mlir",
diff --git a/compiler/src/iree/compiler/Dialect/Flow/Transforms/test/CMakeLists.txt b/compiler/src/iree/compiler/Dialect/Flow/Transforms/test/CMakeLists.txt
index 589d684..7178f26 100644
--- a/compiler/src/iree/compiler/Dialect/Flow/Transforms/test/CMakeLists.txt
+++ b/compiler/src/iree/compiler/Dialect/Flow/Transforms/test/CMakeLists.txt
@@ -37,6 +37,7 @@
     "interchange_transpose_generic_ops.mlir"
    "optimize_numerics.mlir"
    "outline_dispatch_regions.mlir"
+    "raise_special_ops.mlir"
    "set_encoding.mlir"
    "strip_and_splat_constant_variables.mlir"
    "strip_signedness.mlir"
diff --git a/compiler/src/iree/compiler/Dialect/Flow/Transforms/test/raise_special_ops.mlir b/compiler/src/iree/compiler/Dialect/Flow/Transforms/test/raise_special_ops.mlir
new file mode 100644
index 0000000..b93e9e8
--- /dev/null
+++ b/compiler/src/iree/compiler/Dialect/Flow/Transforms/test/raise_special_ops.mlir
@@ -0,0 +1,54 @@
+// RUN: iree-opt --iree-flow-raise-special-ops -canonicalize %s | FileCheck %s
+
+// CHECK-LABEL: @softmax
+// CHECK-SAME:  %[[ARG:.+]]: tensor<?x?x?xf32>
+// CHECK:       %[[E:.+]] = tensor.empty(%{{.*}}, %{{.*}}, %{{.*}}) : tensor<?x?x?xf32>
+// CHECK:       %[[S:.+]] = iree_linalg_ext.softmax dimension(2) ins(%[[ARG]] : tensor<?x?x?xf32>) outs(%[[E]] : tensor<?x?x?xf32>) -> tensor<?x?x?xf32>
+// CHECK:       return %[[S]] : tensor<?x?x?xf32>
+
+func.func @softmax(%src : tensor<?x?x?xf32>) -> (tensor<?x?x?xf32>) {
+  %cst = arith.constant 1.000000e+00 : f32
+  %cst_0 = arith.constant 0.000000e+00 : f32
+  %cst_1 = arith.constant -3.40282347E+38 : f32
+  %c_0_index = arith.constant 0 : index
+  %c_1_index = arith.constant 1 : index
+  %c_2_index = arith.constant 2 : index
+  %dim_0 = tensor.dim %src, %c_0_index : tensor<?x?x?xf32>
+  %dim_1 = tensor.dim %src, %c_1_index : tensor<?x?x?xf32>
+  %dim_2 = tensor.dim %src, %c_2_index : tensor<?x?x?xf32>
+  %1 = tensor.empty(%dim_0, %dim_1) : tensor<?x?xf32>
+  %2 = linalg.fill ins(%cst_1 : f32) outs(%1 : tensor<?x?xf32>) -> tensor<?x?xf32>
+  %3 = linalg.generic {indexing_maps = [affine_map<(d0, d1, d2) -> (d0, d1, d2)>, affine_map<(d0, d1, d2) -> (d0, d1)>], iterator_types = ["parallel", "parallel", "reduction"]} ins(%src : tensor<?x?x?xf32>) outs(%2 : tensor<?x?xf32>) {
+  ^bb0(%arg0: f32, %arg1: f32):
+    %11 = arith.maxf %arg0, %arg1 : f32
+    linalg.yield %11 : f32
+  } -> tensor<?x?xf32>
+  %4 = tensor.empty(%dim_0, %dim_1, %dim_2) : tensor<?x?x?xf32>
+  %5 = linalg.generic {indexing_maps = [affine_map<(d0, d1, d2) -> (d0, d1, d2)>, affine_map<(d0, d1, d2) -> (d0, d1)>, affine_map<(d0, d1, d2) -> (d0, d1, d2)>], iterator_types = ["parallel", "parallel", "parallel"]} ins(%src, %3 : tensor<?x?x?xf32>, tensor<?x?xf32>) outs(%4 : tensor<?x?x?xf32>) {
+  ^bb0(%arg0: f32, %arg1: f32, %arg2: f32):
+    %11 = arith.subf %arg0, %arg1 : f32
+    linalg.yield %11 : f32
+  } -> tensor<?x?x?xf32>
+  %6 = linalg.generic {indexing_maps = [affine_map<(d0, d1, d2) -> (d0, d1, d2)>, affine_map<(d0, d1, d2) -> (d0, d1, d2)>], iterator_types = ["parallel", "parallel", "parallel"]} ins(%5 : tensor<?x?x?xf32>) outs(%4 : tensor<?x?x?xf32>) {
+  ^bb0(%arg0: f32, %arg1: f32):
+    %11 = math.exp %arg0 : f32
+    linalg.yield %11 : f32
+  } -> tensor<?x?x?xf32>
+  %7 = linalg.fill ins(%cst_0 : f32) outs(%1 : tensor<?x?xf32>) -> tensor<?x?xf32>
+  %8 = linalg.generic {indexing_maps = [affine_map<(d0, d1, d2) -> (d0, d1, d2)>, affine_map<(d0, d1, d2) -> (d0, d1)>], iterator_types = ["parallel", "parallel", "reduction"]} ins(%6 : tensor<?x?x?xf32>) outs(%7 : tensor<?x?xf32>) {
+  ^bb0(%arg0: f32, %arg1: f32):
+    %11 = arith.addf %arg0, %arg1 : f32
+    linalg.yield %11 : f32
+  } -> tensor<?x?xf32>
+  %9 = linalg.generic {indexing_maps = [affine_map<(d0, d1) -> (d0, d1)>, affine_map<(d0, d1) -> (d0, d1)>], iterator_types = ["parallel", "parallel"]} ins(%8 : tensor<?x?xf32>) outs(%1 : tensor<?x?xf32>) {
+  ^bb0(%arg0: f32, %arg1: f32):
+    %11 = arith.divf %cst, %arg0 : f32
+    linalg.yield %11 : f32
+  } -> tensor<?x?xf32>
+  %10 = linalg.generic {indexing_maps = [affine_map<(d0, d1, d2) -> (d0, d1, d2)>, affine_map<(d0, d1, d2) -> (d0, d1)>, affine_map<(d0, d1, d2) -> (d0, d1, d2)>], iterator_types = ["parallel", "parallel", "parallel"]} ins(%6, %9 : tensor<?x?x?xf32>, tensor<?x?xf32>) outs(%4 : tensor<?x?x?xf32>) {
+  ^bb0(%arg0: f32, %arg1: f32, %arg2: f32):
+    %11 = arith.mulf %arg0, %arg1 : f32
+    linalg.yield %11 : f32
+  } -> tensor<?x?x?xf32>
+  return %10 : tensor<?x?x?xf32>
+}
diff --git a/llvm-external-projects/iree-dialects/BUILD b/llvm-external-projects/iree-dialects/BUILD
index 276494a..2c9b105 100644
--- a/llvm-external-projects/iree-dialects/BUILD
+++ b/llvm-external-projects/iree-dialects/BUILD
@@ -146,10 +146,12 @@
     deps = [
         "@llvm-project//llvm:Support",
        "@llvm-project//mlir:Analysis",
+        "@llvm-project//mlir:ArithDialect",
        "@llvm-project//mlir:DialectUtils",
        "@llvm-project//mlir:FuncDialect",
        "@llvm-project//mlir:IR",
        "@llvm-project//mlir:LinalgDialect",
+        "@llvm-project//mlir:MathDialect",
        "@llvm-project//mlir:Rewrite",
        "@llvm-project//mlir:SCFDialect",
        "@llvm-project//mlir:TensorDialect",
diff --git a/llvm-external-projects/iree-dialects/include/iree-dialects/Dialect/LinalgExt/IR/LinalgExtOps.td b/llvm-external-projects/iree-dialects/include/iree-dialects/Dialect/LinalgExt/IR/LinalgExtOps.td
index 4768ba1..fe5dbb9 100644
--- a/llvm-external-projects/iree-dialects/include/iree-dialects/Dialect/LinalgExt/IR/LinalgExtOps.td
+++ b/llvm-external-projects/iree-dialects/include/iree-dialects/Dialect/LinalgExt/IR/LinalgExtOps.td
@@ -1135,7 +1135,7 @@
   );

  let builders = [
-    OpBuilder<(ins "ValueRange":$inputs, "ValueRange":$outputs,
+    OpBuilder<(ins "Value":$inputs, "Value":$outputs,
      CArg<"int64_t", "0">:$dimension)>
  ];
diff --git a/llvm-external-projects/iree-dialects/include/iree-dialects/Transforms/TransformMatchers.h b/llvm-external-projects/iree-dialects/include/iree-dialects/Transforms/TransformMatchers.h
index 21f4060..61870c4 100644
--- a/llvm-external-projects/iree-dialects/include/iree-dialects/Transforms/TransformMatchers.h
+++ b/llvm-external-projects/iree-dialects/include/iree-dialects/Transforms/TransformMatchers.h
@@ -126,6 +126,25 @@
 /// Predicate tag indicating that the affine map is a projected permutation.
 struct IsProjectedPermutation {};

+/// Predicate tag indicating that the affine map is a projection of the given
+/// dimension.
+struct IsProjected : public SingleValuePredicateParam<int64_t> {
+  using Base::Base;
+};
+
+/// Predicate tag indicating that the affine map is an identity.
+struct IsIdentity {};
+
+/// Predicate tags indicating that the operand is a special float constant.
+struct ConstantFloatMin {};
+struct ConstantFloatZero {};
+
+/// Predicate indicating that the operand is the same value as its producer's
+/// operand.
+struct SameOperandAsProducer
+    : public SingleValuePredicateParam<std::pair<int64_t, int64_t>> {
+  using Base::Base;
+};
+
 /// Indicates that the match optional. The matcher is still expected to run and
 /// capture if successful. The parameter can be set to false
 struct OptionalMatch : public SingleValuePredicateParam<bool> {
@@ -298,6 +317,9 @@
     return *this;
  }

+  /// Adds a predicate checking that the input operand at `position` is the
+  /// same value as an operand of its producer (see SameOperandAsProducer).
+  StructuredOpMatcher &input(int64_t position,
+                             SameOperandAsProducer producerPosition);
+
  /// Adds a predicate checking that all input operands of the structured op
  /// have a permutation indexing map.
  StructuredOpMatcher &input(AllOperands tag, IsPermutation);
@@ -306,6 +328,21 @@
  /// have a projected permutation indexing map.
  StructuredOpMatcher &input(AllOperands tag, IsProjectedPermutation);

+  /// Adds a predicate checking that the input operands at the given positions
+  /// are projected along the given dimension.
+  StructuredOpMatcher &input(SmallVector<int64_t> &&positions, IsProjected dim);
+  StructuredOpMatcher &input(int64_t position, IsProjected dim) {
+    return input(SmallVector<int64_t>{position}, dim);
+  }
+
+  /// Adds a predicate checking that the given input operands of the structured
+  /// op have an identity indexing map.
+  StructuredOpMatcher &input(AllOperands tag, IsIdentity);
+  StructuredOpMatcher &input(SmallVector<int64_t> &&positions, IsIdentity);
+  StructuredOpMatcher &input(int64_t position, IsIdentity) {
+    return input(SmallVector<int64_t>{position}, IsIdentity());
+  }
+
  /// Adds a predicate checking that the bit width of the elemental type of the
  /// structured op input at the given position is equal to the given value.
  StructuredOpMatcher &input(int64_t position, ElementTypeBitWidth width);
@@ -314,6 +351,11 @@
  StructuredOpMatcher &input(int64_t position,
                             CaptureElementTypeBitWidth width);

+  /// Adds a predicate checking that the input operand at the given position is
+  /// a known constant float (largest negative value / zero, respectively).
+  // TODO: Support matching for constant ops.
+  StructuredOpMatcher &input(int64_t position, ConstantFloatMin);
+  StructuredOpMatcher &input(int64_t position, ConstantFloatZero);
+
  //===-------------------------------------------------------------------===//
  // Constraints on adjacent ops.
  //===-------------------------------------------------------------------===//
@@ -351,6 +393,14 @@
  /// have a projected permutation indexing map.
  StructuredOpMatcher &output(AllOperands tag, IsProjectedPermutation);

+  /// Adds a predicate checking that all output operands of the structured op
+  /// are projected along the given dimension.
+  StructuredOpMatcher &output(AllOperands tag, IsProjected dim);
+
+  /// Adds a predicate checking that all output operands of the structured op
+  /// have an identity indexing map.
+  StructuredOpMatcher &output(AllOperands tag, IsIdentity);
+
  /// Adds a predicate checking that the bit width of the elemental type of the
  /// structured op output at the given position is equal to the given value.
  StructuredOpMatcher &output(int64_t position, ElementTypeBitWidth width);
@@ -410,6 +460,28 @@
    return *this;
  }

+  //===-------------------------------------------------------------------===//
+  // Constraints on op region.
+  //===-------------------------------------------------------------------===//
+
+  /// Adds a predicate checking that the op region contains a single op of the
+  /// given kind (plus the terminator) and that the arguments of that op match
+  /// the order of the linalg operands.
+  /// Example:
+  ///   linalg.generic
+  ///       ins(%0, %1 : tensor<?x?x?xf32>, tensor<?x?xf32>)
+  ///       outs(%2 : tensor<?x?xf32>) {
+  ///   ^bb0(%arg0: f32, %arg1: f32):
+  ///     %3 = arith.maxf %arg0, %arg1 : f32
+  ///     linalg.yield %3 : f32
+  ///   } -> tensor<?x?xf32>
+  template <typename OpType>
+  StructuredOpMatcher &singleOpWithCanonicaleArgs() {
+    return singleOpWithCanonicaleArgs(OpType::getOperationName());
+  }
+  StructuredOpMatcher &singleOpWithCanonicaleArgs(StringRef opname);
+
+  /// Adds a predicate checking that the op is a linalg op whose region is a
+  /// single float reciprocal op.
+  StructuredOpMatcher &isFloatReciprocal();
+
 private:
  /// Checks that `matchers` captured all tilable ops nested in `parent` except
  /// for `linalgOp`. This is an implementation detail of allTilableOpsCaptured.
@@ -432,6 +504,10 @@
                    std::function<bool(Operation *)> matcher,
                    OptionalMatch optional);

+  /// Common utility for the constant float matchers.
+  StructuredOpMatcher &input(int64_t position,
+                             std::function<bool(llvm::APFloat)> floatValueFn);
+
  /// Additional predicates to be checked on the structured op.
  SmallVector<PredicateFn> predicates;
 };
@@ -557,6 +633,24 @@
                           StructuredOpMatcher &trailing,
                           MatchedReductionCaptures &captures);

+/// Creates a group of matchers for a sequence of operations matching exactly a
+/// softmax operation:
+///
+///   %red = reduce_max(%0)
+///   %sub = sub(%0, %red)
+///   %exp = exp(%sub)
+///   %sum = reduce_sum(%exp)
+///   %rec = reciprocal(%sum)
+///   %mul = mul(%exp, %rec)
+void makeSoftmaxMatcher(transform_ext::StructuredOpMatcher &fillMinusInf,
+                        transform_ext::StructuredOpMatcher &maxReduction,
+                        transform_ext::StructuredOpMatcher &sub,
+                        transform_ext::StructuredOpMatcher &expOperand,
+                        transform_ext::StructuredOpMatcher &fillzero,
+                        transform_ext::StructuredOpMatcher &sum,
+                        transform_ext::StructuredOpMatcher &divOperand,
+                        transform_ext::StructuredOpMatcher &softmaxroot);
+
 } // namespace transform_ext
 } // namespace mlir
diff --git a/llvm-external-projects/iree-dialects/lib/Dialect/LinalgExt/IR/LinalgExtOps.cpp b/llvm-external-projects/iree-dialects/lib/Dialect/LinalgExt/IR/LinalgExtOps.cpp
index c088685..dfd53c5 100644
--- a/llvm-external-projects/iree-dialects/lib/Dialect/LinalgExt/IR/LinalgExtOps.cpp
+++ b/llvm-external-projects/iree-dialects/lib/Dialect/LinalgExt/IR/LinalgExtOps.cpp
@@ -2844,6 +2844,12 @@
       .reifyResultShapes(b, reifiedReturnShapes);
 }

+void SoftmaxOp::build(OpBuilder &builder, OperationState &state, Value source,
+                      Value output, int64_t dimension) {
+  build(builder, state, TypeRange({output.getType()}), ValueRange(source),
+        ValueRange(output), dimension);
+}
+
 //===----------------------------------------------------------------------===//
 // AttentionOp
 //===----------------------------------------------------------------------===//
diff --git a/llvm-external-projects/iree-dialects/lib/Transforms/CMakeLists.txt b/llvm-external-projects/iree-dialects/lib/Transforms/CMakeLists.txt
index eb77eea..5b233cf 100644
--- a/llvm-external-projects/iree-dialects/lib/Transforms/CMakeLists.txt
+++ b/llvm-external-projects/iree-dialects/lib/Transforms/CMakeLists.txt
@@ -9,9 +9,11 @@
   # TODO: break dialect dependency by implementing the transformation separately
  # and registering it.
  MLIRAsyncDialect
+  MLIRArithDialect
  MLIRFuncDialect
  MLIRLinalgDialect
  MLIRLinalgTransforms
+  MLIRMathDialect

  DEPENDS
  mlir-headers
diff --git a/llvm-external-projects/iree-dialects/lib/Transforms/TransformMatchers.cpp b/llvm-external-projects/iree-dialects/lib/Transforms/TransformMatchers.cpp
index 711d1a7..19a8893 100644
--- a/llvm-external-projects/iree-dialects/lib/Transforms/TransformMatchers.cpp
+++ b/llvm-external-projects/iree-dialects/lib/Transforms/TransformMatchers.cpp
@@ -7,8 +7,12 @@
 #include "iree-dialects/Transforms/TransformMatchers.h"

 #include "mlir/Analysis/SliceAnalysis.h"
+#include "mlir/Dialect/Arith/IR/Arith.h"
 #include "mlir/Dialect/Func/IR/FuncOps.h"
 #include "mlir/Dialect/Linalg/IR/Linalg.h"
+#include "mlir/Dialect/Math/IR/Math.h"
+#include "mlir/Dialect/SCF/IR/SCF.h"
+#include "mlir/Dialect/Tensor/IR/Tensor.h"
 #include "mlir/Dialect/Utils/StructuredOpsUtils.h"
 #include "llvm/ADT/STLExtras.h"
 #include "llvm/ADT/ScopeExit.h"
@@ -318,6 +322,76 @@
   return *this;
 }

+/// Helper to check if the map is an identity map with a projected dim.
+static bool isProjectedMap(AffineMap map, int64_t projectedDim) {
+  if (!map.isProjectedPermutation())
+    return false;
+  int64_t dimCounter = 0;
+  for (unsigned i = 0, e = map.getNumResults(); i < e; i++) {
+    // Skip the projected dim.
+    if (dimCounter == projectedDim)
+      dimCounter++;
+    if (map.getDimPosition(i) != dimCounter++) {
+      return false;
+    }
+  }
+  return true;
+}
+
+transform_ext::StructuredOpMatcher &
+transform_ext::StructuredOpMatcher::input(SmallVector<int64_t> &&positions,
+                                          IsProjected dim) {
+  predicates.push_back([=](linalg::LinalgOp linalgOp) -> bool {
+    LLVM_DEBUG(DBGS() << "operands ";
+               llvm::interleaveComma(positions, llvm::dbgs());
+               llvm::dbgs() << " have permutation maps with dim " << dim.value
+                            << " projected\n");
+    int64_t updatedDim =
+        dim.value >= 0 ? dim.value : linalgOp.getNumLoops() + dim.value;
+    for (int64_t position : positions) {
+      OpOperand *operand = linalgOp.getDpsInputOperand(position);
+      if (!isProjectedMap(linalgOp.getMatchingIndexingMap(operand), updatedDim))
+        return false;
+    }
+    return true;
+  });
+  return *this;
+}
+
+transform_ext::StructuredOpMatcher &
+transform_ext::StructuredOpMatcher::input(AllOperands tag, IsIdentity) {
+  predicates.push_back([=](linalg::LinalgOp linalgOp) -> bool {
+    LLVM_DEBUG(DBGS() << "all input operands have identity maps");
+    // all_of with a lambda requires const-casting dance, so using a loop.
+    for (OpOperand *operand : linalgOp.getDpsInputOperands()) {
+      if (!linalgOp.getMatchingIndexingMap(operand).isIdentity())
+        return false;
+    }
+    return true;
+  });
+  return *this;
+}
+
+transform_ext::StructuredOpMatcher &
+transform_ext::StructuredOpMatcher::input(SmallVector<int64_t> &&positions,
+                                          IsIdentity) {
+  predicates.push_back([=](linalg::LinalgOp linalgOp) -> bool {
+    LLVM_DEBUG(DBGS() << "input operands ";
+               llvm::interleaveComma(positions, llvm::dbgs());
+               llvm::dbgs() << " have identity maps");
+    // all_of with a lambda requires const-casting dance, so using a loop.
+    for (int64_t position : positions) {
+      int64_t updatedPosition =
+          position >= 0 ? position : linalgOp.getNumDpsInputs() + position;
+      OpOperand *operand = linalgOp.getDpsInputOperand(updatedPosition);
+      if (!linalgOp.getMatchingIndexingMap(operand).isIdentity())
+        return false;
+    }
+    return true;
+  });
+  return *this;
+}
+
 transform_ext::StructuredOpMatcher &
 transform_ext::StructuredOpMatcher::input(int64_t position,
                                           ElementTypeBitWidth width) {
@@ -368,6 +442,72 @@
   return *this;
 }

+transform_ext::StructuredOpMatcher &
+transform_ext::StructuredOpMatcher::input(int64_t position, ConstantFloatMin) {
+  return input(position,
+               [](llvm::APFloat f) { return f.isLargest() && f.isNegative(); });
+}
+
+transform_ext::StructuredOpMatcher &
+transform_ext::StructuredOpMatcher::input(int64_t position,
+                                          ConstantFloatZero) {
+  return input(position, [](llvm::APFloat f) { return f.isZero(); });
+}
+
+transform_ext::StructuredOpMatcher &transform_ext::StructuredOpMatcher::input(
+    int64_t position, std::function<bool(llvm::APFloat)> floatValueFn) {
+  predicates.push_back([=](linalg::LinalgOp linalgOp) -> bool {
+    LLVM_DEBUG(DBGS() << "input operand " << position
+                      << " is a special floating point constant");
+    int64_t updatedPosition =
+        position >= 0 ? position : linalgOp.getNumDpsInputs() + position;
+    if (0 > updatedPosition || updatedPosition >= linalgOp.getNumDpsInputs())
+      return false;
+    auto cstOp = linalgOp.getDpsInputOperand(updatedPosition)
+                     ->get()
+                     .getDefiningOp<arith::ConstantFloatOp>();
+    if (!cstOp)
+      return false;
+    return floatValueFn(cstOp.value());
+  });
+  return *this;
+}
+
+transform_ext::StructuredOpMatcher &transform_ext::StructuredOpMatcher::input(
+    int64_t position, SameOperandAsProducer producerPosition) {
+  predicates.push_back([=](linalg::LinalgOp linalgOp) -> bool {
+    LLVM_DEBUG(DBGS() << "input operand " << position
+                      << " is the same value as operand "
+                      << producerPosition.value.second
+                      << " of the producer of input "
+                      << producerPosition.value.first);
+    int64_t updatedPosition =
+        position >= 0 ? position : linalgOp.getNumDpsInputs() + position;
+    if (0 > updatedPosition || updatedPosition >= linalgOp.getNumDpsInputs())
+      return false;
+    int64_t updatedProducerPosition =
+        producerPosition.value.first >= 0
+            ? producerPosition.value.first
+            : linalgOp.getNumDpsInputs() + producerPosition.value.first;
+    if (0 > updatedProducerPosition ||
+        updatedProducerPosition >= linalgOp.getNumDpsInputs())
+      return false;
+    Operation *producer = linalgOp.getDpsInputOperand(updatedProducerPosition)
+                              ->get()
+                              .getDefiningOp();
+    if (!producer)
+      return false;
+    int64_t producerOperandPos =
+        producerPosition.value.second >= 0
+            ? producerPosition.value.second
+            : producer->getNumOperands() + producerPosition.value.second;
+    if (0 > producerOperandPos ||
+        producerOperandPos >= producer->getNumOperands())
+      return false;
+    return producer->getOperand(producerOperandPos) ==
+           linalgOp.getDpsInputOperand(updatedPosition)->get();
+  });
+  return *this;
+}
+
 //===---------------------------------------------------------------------===//
 // Constraints on output operands.
 //===---------------------------------------------------------------------===//
@@ -428,6 +568,35 @@
 }

 transform_ext::StructuredOpMatcher &
+transform_ext::StructuredOpMatcher::output(AllOperands tag, IsProjected dim) {
+  predicates.push_back([=](linalg::LinalgOp linalgOp) -> bool {
+    LLVM_DEBUG(DBGS() << "all output operands have projected maps");
+    int64_t updatedDim =
+        dim.value >= 0 ? dim.value : linalgOp.getNumLoops() + dim.value;
+    // all_of with a lambda requires const-casting dance, so using a loop.
+    for (OpOperand *operand : linalgOp.getDpsInitOperands()) {
+      if (!isProjectedMap(linalgOp.getMatchingIndexingMap(operand), updatedDim))
+        return false;
+    }
+    return true;
+  });
+  return *this;
+}
+
+transform_ext::StructuredOpMatcher &
+transform_ext::StructuredOpMatcher::output(AllOperands tag, IsIdentity) {
+  predicates.push_back([=](linalg::LinalgOp linalgOp) -> bool {
+    LLVM_DEBUG(DBGS() << "all output operands have identity permutation maps");
+    for (OpOperand *operand : linalgOp.getDpsInitOperands()) {
+      if (!linalgOp.getMatchingIndexingMap(operand).isIdentity())
+        return false;
+    }
+    return true;
+  });
+  return *this;
+}
+
+transform_ext::StructuredOpMatcher &
 transform_ext::StructuredOpMatcher::output(int64_t position,
                                            ElementTypeBitWidth width) {
   predicates.push_back([=](linalg::LinalgOp linalgOp) -> bool {
@@ -546,6 +715,62 @@
   return numTilableOps == matched.size();
 }

+//===-------------------------------------------------------------------===//
+// Constraints on op region.
+//===-------------------------------------------------------------------===//
+
+transform_ext::StructuredOpMatcher &
+transform_ext::StructuredOpMatcher::singleOpWithCanonicaleArgs(
+    StringRef opcode) {
+  predicates.push_back([=](linalg::LinalgOp linalgOp) {
+    if (linalgOp.getBlock()->getOperations().size() != 2)
+      return false;
+    Operation *innerOp = &(*linalgOp.getBlock()->getOperations().begin());
+    if (innerOp->getName().getStringRef() != opcode ||
+        innerOp->getNumResults() != 1)
+      return false;
+    Operation *yieldOp = linalgOp.getBlock()->getTerminator();
+    if (yieldOp->getNumOperands() != 1)
+      return false;
+    if (yieldOp->getOperand(0).getDefiningOp() != innerOp)
+      return false;
+    for (auto [index, operand] : llvm::enumerate(innerOp->getOperands())) {
+      auto arg = dyn_cast<BlockArgument>(operand);
+      if (!arg || arg.getParentBlock() != linalgOp.getBlock() ||
+          arg.getArgNumber() != index)
+        return false;
+    }
+    return true;
+  });
+  return *this;
+}
+
+transform_ext::StructuredOpMatcher &
+transform_ext::StructuredOpMatcher::isFloatReciprocal() {
+  predicates.push_back([=](linalg::LinalgOp linalgOp) {
+    LLVM_DEBUG(DBGS() << "op region represents a reciprocal operation");
+    if (linalgOp.getBlock()->getOperations().size() != 2)
+      return false;
+    Operation *innerOp = &(*linalgOp.getBlock()->getOperations().begin());
+    if (!isa<arith::DivFOp>(innerOp) || innerOp->getNumResults() != 1)
+      return false;
+    Operation *yieldOp = linalgOp.getBlock()->getTerminator();
+    if (yieldOp->getNumOperands() != 1)
+      return false;
+    if (yieldOp->getOperand(0).getDefiningOp() != innerOp)
+      return false;
+    auto cst = innerOp->getOperand(0).getDefiningOp<arith::ConstantFloatOp>();
+    if (!cst || cst.value().convertToDouble() != 1.0)
+      return false;
+    auto arg = dyn_cast<BlockArgument>(innerOp->getOperand(1));
+    if (!arg || arg.getParentBlock() != linalgOp.getBlock() ||
+        arg.getArgNumber() != 0)
+      return false;
+    return true;
+  });
+  return *this;
+}
+
 //===---------------------------------------------------------------------===//
 // MatchCallbackResult.
 //===---------------------------------------------------------------------===//
@@ -672,3 +897,81 @@
   reduction = reduction.result(0, HasAnyUse(), trailing, OptionalMatch())
                   .allTilableOpsCaptured<func::FuncOp>();
 }
+
+void transform_ext::makeSoftmaxMatcher(
+    transform_ext::StructuredOpMatcher &fillMinusInf,
+    transform_ext::StructuredOpMatcher &maxReduction,
+    transform_ext::StructuredOpMatcher &sub,
+    transform_ext::StructuredOpMatcher &expOperand,
+    transform_ext::StructuredOpMatcher &fillzero,
+    transform_ext::StructuredOpMatcher &sum,
+    transform_ext::StructuredOpMatcher &divOperand,
+    transform_ext::StructuredOpMatcher &softmaxroot) {
+
+  fillMinusInf = m_StructuredOp<linalg::FillOp>().input(0, ConstantFloatMin());
+  maxReduction = transform_ext::m_StructuredOp<linalg::GenericOp>()
+                     .singleOpWithCanonicaleArgs<arith::MaxFOp>()
+                     // Only handle the innermost reduction for now.
+                     .dim(-1, utils::IteratorType::reduction)
+                     .dim(AllDimsExcept({-1}), utils::IteratorType::parallel)
+                     .input(NumEqualsTo(1))
+                     .input(AllOperands(), IsIdentity())
+                     .output(NumEqualsTo(1))
+                     .output(AllOperands(), IsProjected(-1));
+  maxReduction = maxReduction.output(0, fillMinusInf);
+
+  sub = transform_ext::m_StructuredOp<linalg::GenericOp>()
+            .singleOpWithCanonicaleArgs<arith::SubFOp>()
+            .dim(AllDims(), utils::IteratorType::parallel)
+            .input(NumEqualsTo(2))
+            .input(0, IsIdentity())
+            .input(1, IsProjected(-1))
+            .output(NumEqualsTo(1))
+            .output(AllOperands(), IsIdentity());
+  sub =
+      sub.input(0, SameOperandAsProducer(std::pair<int64_t, int64_t>({1, 0})));
+  sub = sub.input(1, maxReduction);
+
+  expOperand = m_StructuredOp<linalg::GenericOp>()
+                   .singleOpWithCanonicaleArgs<math::ExpOp>()
+                   .dim(AllDims(), utils::IteratorType::parallel)
+                   .input(NumEqualsTo(1))
+                   .input(AllOperands(), IsIdentity())
+                   .output(AllOperands(), IsIdentity())
+                   .output(NumEqualsTo(1));
+  expOperand = expOperand.input(0, sub);
+
+  fillzero = m_StructuredOp<linalg::FillOp>().input(0, ConstantFloatZero());
+  sum = m_StructuredOp<linalg::GenericOp>()
+            .singleOpWithCanonicaleArgs<arith::AddFOp>()
+            // Only handle the innermost reduction for now.
+            .dim(-1, utils::IteratorType::reduction)
+            .dim(AllDimsExcept({-1}), utils::IteratorType::parallel)
+            .input(NumEqualsTo(1))
+            .input(AllOperands(), IsIdentity())
+            .output(AllOperands(), IsProjected(-1))
+            .output(NumEqualsTo(1));
+  sum = sum.input(0, expOperand);
+  sum = sum.output(0, fillzero);
+
+  divOperand = m_StructuredOp<linalg::GenericOp>()
+                   .isFloatReciprocal()
+                   .dim(AllDims(), utils::IteratorType::parallel)
+                   .input(NumEqualsTo(1))
+                   .input(AllOperands(), IsIdentity())
+                   .output(AllOperands(), IsIdentity())
+                   .output(NumEqualsTo(1));
+  divOperand = divOperand.input(0, sum);
+
+  softmaxroot = transform_ext::m_StructuredOp<linalg::GenericOp>()
+                    .singleOpWithCanonicaleArgs<arith::MulFOp>()
+                    .dim(AllDims(), utils::IteratorType::parallel)
+                    .input(NumEqualsTo(2))
+                    .input(0, IsIdentity())
+                    .input(1, IsProjected(-1))
+                    .output(NumEqualsTo(1))
+                    .output(AllOperands(), IsIdentity());
+  softmaxroot = softmaxroot.input(0, expOperand);
+  softmaxroot = softmaxroot.input(1, divOperand);
+}