| // Copyright 2023 The IREE Authors |
| // |
| // Licensed under the Apache License v2.0 with LLVM Exceptions. |
| // See https://llvm.org/LICENSE.txt for license information. |
| // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception |
| |
| #include "iree-dialects/Dialect/LinalgExt/IR/LinalgExtDialect.h" |
| #include "iree-dialects/Dialect/LinalgExt/IR/LinalgExtOps.h" |
| #include "iree-dialects/Dialect/LinalgExt/Passes/PassDetail.h" |
| #include "iree-dialects/Dialect/LinalgExt/Passes/Passes.h" |
| #include "iree-dialects/Dialect/LinalgExt/Utils/Utils.h" |
| #include "mlir/Dialect/Arith/IR/Arith.h" |
| #include "mlir/Dialect/Linalg/IR/Linalg.h" |
| #include "mlir/Dialect/Math/IR/Math.h" |
| #include "mlir/Dialect/Tensor/IR/Tensor.h" |
| #include "mlir/Dialect/Tensor/Utils/Utils.h" |
| #include "mlir/IR/Builders.h" |
| #include "mlir/Pass/Pass.h" |
| |
| namespace mlir { |
| namespace iree_compiler { |
| namespace IREE { |
| namespace LinalgExt { |
| |
| namespace { |
| |
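/// Builds the iterator types and indexing maps shared by the generic ops
/// below. As an illustrative example (rank and dim are hypothetical), for
/// inputRank = 2 and dim = 1 with allParallel = false this produces:
///
///   iterator types: ["parallel", "reduction"]
///   indexing maps:  [(d0, d1) -> (d0, d1),   // identity, for the input
///                    (d0, d1) -> (d0)]       // rank-reducing, for the init
///
/// With allParallel = true every iterator is "parallel"; the elementwise
/// callers then append a second identity map for their N-D output.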
| std::tuple<SmallVector<utils::IteratorType>, SmallVector<AffineMap>> |
| computeIteratorTypesAndIndexingMaps(int64_t inputRank, int64_t dim, |
| OpBuilder &builder, |
| bool allParallel = false) { |
| SmallVector<utils::IteratorType> iteratorTypes(inputRank, |
| utils::IteratorType::parallel); |
| if (!allParallel) |
| iteratorTypes[dim] = utils::IteratorType::reduction; |
| auto identityMap = |
| AffineMap::getMultiDimIdentityMap(inputRank, builder.getContext()); |
| SmallVector<AffineExpr, 2> affineExprs; |
  for (int64_t i = 0; i < inputRank; i++) {
| if (i != dim) |
| affineExprs.push_back(mlir::getAffineDimExpr(i, builder.getContext())); |
| } |
| auto reductionMap = |
| AffineMap::get(inputRank, 0, affineExprs, builder.getContext()); |
| SmallVector<AffineMap> indexingMaps{identityMap, reductionMap}; |
| return std::make_tuple(iteratorTypes, indexingMaps); |
| } |
| |
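/// Emits a linalg.generic that reduces `input` into `output` along `dim`,
/// combining elements with the binary op T. `output` must be pre-filled with
/// the identity of T (a large negative value for max, zero for add). A
/// sketch of the emitted IR for T = arith::MaxFOp on a hypothetical
/// tensor<2x16xf32> with dim = 1 (SSA names are illustrative only):
///
///   %max = linalg.generic
///       {indexing_maps = [affine_map<(d0, d1) -> (d0, d1)>,
///                         affine_map<(d0, d1) -> (d0)>],
///        iterator_types = ["parallel", "reduction"]}
///       ins(%input : tensor<2x16xf32>) outs(%init : tensor<2xf32>) {
///   ^bb0(%in: f32, %acc: f32):
///     %0 = arith.maxf %in, %acc : f32
///     linalg.yield %0 : f32
///   } -> tensor<2xf32>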
| template <typename T> |
| static Value reduce(Value input, Value output, int64_t dim, Location loc, |
| OpBuilder &builder) { |
| auto inputType = input.getType().cast<ShapedType>(); |
| ArrayRef<int64_t> inputShape = inputType.getShape(); |
| int64_t inputRank = inputShape.size(); |
| auto [iteratorTypes, indexingMaps] = |
| computeIteratorTypesAndIndexingMaps(inputRank, dim, builder); |
| auto genericOp = builder.create<linalg::GenericOp>( |
| loc, output.getType(), input, output, indexingMaps, iteratorTypes, |
| [&](OpBuilder &b, Location loc, ValueRange args) { |
| Value result = b.create<T>(loc, args[0], args[1]); |
| b.create<linalg::YieldOp>(loc, result); |
| }); |
| return genericOp.getResult(0); |
| } |
| |
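/// Emits an elementwise linalg.generic computing exp(input - max), where the
/// (N-1)-D `max` is broadcast back along `dim` via the rank-reducing map. A
/// sketch for a hypothetical tensor<2x16xf32> with dim = 1 (SSA names are
/// illustrative only):
///
///   %z = linalg.generic
///       {indexing_maps = [affine_map<(d0, d1) -> (d0, d1)>,
///                         affine_map<(d0, d1) -> (d0)>,
///                         affine_map<(d0, d1) -> (d0, d1)>],
///        iterator_types = ["parallel", "parallel"]}
///       ins(%input, %max : tensor<2x16xf32>, tensor<2xf32>)
///       outs(%out : tensor<2x16xf32>) {
///   ^bb0(%x: f32, %m: f32, %o: f32):
///     %d = arith.subf %x, %m : f32
///     %e = math.exp %d : f32
///     linalg.yield %e : f32
///   } -> tensor<2x16xf32>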
| static Value subtractAndExp(Value input, Value max, Value output, int64_t dim, |
| Location loc, OpBuilder &builder) { |
| auto inputType = input.getType().cast<ShapedType>(); |
| ArrayRef<int64_t> inputShape = inputType.getShape(); |
| int64_t inputRank = inputShape.size(); |
| auto [iteratorTypes, indexingMaps] = |
| computeIteratorTypesAndIndexingMaps(inputRank, dim, builder, true); |
| indexingMaps.push_back(indexingMaps[0]); |
| auto genericOp = builder.create<linalg::GenericOp>( |
| loc, input.getType(), ValueRange{input, max}, output, indexingMaps, |
| iteratorTypes, [&](OpBuilder &b, Location loc, ValueRange args) { |
| Value diff = b.create<arith::SubFOp>(loc, args[0], args[1]); |
| Value result = b.create<math::ExpOp>(loc, diff); |
| b.create<linalg::YieldOp>(loc, result); |
| }); |
| return genericOp.getResult(0); |
| } |
| |
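/// Emits the final elementwise linalg.generic dividing `numerator` by the
/// broadcast (N-1)-D `denominator`. The structure mirrors subtractAndExp
/// above: identity maps for the numerator and the output, the rank-reducing
/// map for the denominator, and an arith.divf in the body.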
| static Value computeSoftmax(Value numerator, Value denominator, Value output, |
| int64_t dim, Location loc, OpBuilder &builder) { |
| auto inputType = numerator.getType().cast<ShapedType>(); |
| ArrayRef<int64_t> inputShape = inputType.getShape(); |
| int64_t inputRank = inputShape.size(); |
| auto [iteratorTypes, indexingMaps] = |
| computeIteratorTypesAndIndexingMaps(inputRank, dim, builder, true); |
| indexingMaps.push_back(indexingMaps[0]); |
| auto genericOp = builder.create<linalg::GenericOp>( |
| loc, numerator.getType(), ValueRange{numerator, denominator}, output, |
| indexingMaps, iteratorTypes, |
| [&](OpBuilder &b, Location loc, ValueRange args) { |
| Value result = b.create<arith::DivFOp>(loc, args[0], args[1]); |
| b.create<linalg::YieldOp>(loc, result); |
| }); |
| return genericOp.getResult(0); |
| } |
| |
/// Given an N-dimensional tensor x, this function converts
/// softmax(x) to the following sequence of operations:
///
/// 1. Compute the max of x along dimension d. This results
///    in an (N-1)-dimensional tensor m.
///      m = max(x, dim = d)
///
/// 2. Subtract m from x (broadcast along d) and exponentiate.
///    This results in an N-dimensional tensor z.
///      z = exp(x - m)
///
/// 3. Compute the sum of z along dimension d. This results in
///    an (N-1)-dimensional tensor l.
///      l = sum(z, dim = d)
///
/// 4. Divide z by l (broadcast along d). This gives the
///    N-dimensional softmax.
///      softmax = z / l
///
static LogicalResult convertSoftmaxToGenerics(func::FuncOp funcOp) {
| IRRewriter rewriter(funcOp.getContext()); |
| SmallVector<Operation *> toDelete; |
| funcOp.walk([&](IREE::LinalgExt::SoftmaxOp softmaxOp) { |
| OpBuilder::InsertionGuard guard(rewriter); |
| rewriter.setInsertionPoint(softmaxOp); |
| Location loc = softmaxOp.getLoc(); |
| Value input = softmaxOp.input(); |
| ShapedType inputType = input.getType().cast<ShapedType>(); |
| Type elementType = inputType.getElementType(); |
| int64_t reductionDim = softmaxOp.getDimension(); |
| SmallVector<OpFoldResult> dims = |
| tensor::createDimValues(rewriter, loc, input); |
| Value outputNd = rewriter.create<tensor::EmptyOp>(loc, dims, elementType); |
| dims.erase(dims.begin() + reductionDim); |
    // Compute the max along the reduction dimension. The accumulator is
    // seeded with a large negative value that stands in for -infinity.
| Value output = rewriter.create<tensor::EmptyOp>(loc, dims, elementType); |
| Value largeNegative = rewriter.create<arith::ConstantOp>( |
| loc, rewriter.getFloatAttr(elementType, -1.0e30)); |
| Value negativeInit = |
| rewriter.create<linalg::FillOp>(loc, Value{largeNegative}, output) |
| .result(); |
| Value max = |
| reduce<arith::MaxFOp>(input, negativeInit, reductionDim, loc, rewriter); |
| // Subtract max from input and exponentiate |
| Value numerator = |
| subtractAndExp(input, max, outputNd, reductionDim, loc, rewriter); |
| // Compute sum along dim |
| Value zero = rewriter.create<arith::ConstantOp>( |
| loc, rewriter.getZeroAttr(elementType)); |
| Value zeroInit = |
| rewriter.create<linalg::FillOp>(loc, Value{zero}, output).result(); |
| Value denominator = |
| reduce<arith::AddFOp>(numerator, zeroInit, reductionDim, loc, rewriter); |
| // Compute softmax |
| Value result = computeSoftmax(numerator, denominator, outputNd, |
| reductionDim, loc, rewriter); |
| softmaxOp.getResult()[0].replaceAllUsesWith(result); |
| // Delete the op after the walk. |
| toDelete.push_back(softmaxOp.getOperation()); |
| return WalkResult::advance(); |
| }); |
| for (Operation *op : toDelete) { |
| rewriter.eraseOp(op); |
| } |
| return success(); |
| } |
| |
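/// Pass that decomposes every LinalgExt softmax op within a function into a
/// sequence of linalg.generic ops, per convertSoftmaxToGenerics above.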
| struct DecomposeSoftmaxPass : DecomposeSoftmaxBase<DecomposeSoftmaxPass> { |
| void getDependentDialects(DialectRegistry ®istry) const override { |
| registry |
| .insert<linalg::LinalgDialect, IREE::LinalgExt::IREELinalgExtDialect>(); |
| } |
  void runOnOperation() override {
    if (failed(convertSoftmaxToGenerics(getOperation())))
      return signalPassFailure();
  }
| }; |
| |
| } // namespace |
| |
| std::unique_ptr<Pass> createDecomposeSoftmaxPass() { |
| return std::make_unique<DecomposeSoftmaxPass>(); |
| } |
| |
| } // namespace LinalgExt |
| } // namespace IREE |
| } // namespace iree_compiler |
| } // namespace mlir |