// Copyright 2022 The IREE Authors
//
// Licensed under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception

// Implements utilities for lowering the StableHLO dialect to the Linalg
// dialect.

#include "compiler/plugins/input/StableHLO/Conversion/LegalizeToLinalgUtils.h"

#include <algorithm>
#include <numeric>
#include <string>
#include <utility>

#include "mlir/Dialect/Bufferization/IR/Bufferization.h"
#include "mlir/Dialect/SparseTensor/IR/SparseTensor.h"
#include "mlir/Dialect/Tensor/IR/Tensor.h"
#include "mlir/IR/BuiltinTypes.h"
#include "stablehlo/dialect/ChloOps.h"
#include "stablehlo/dialect/StablehloOps.h"

namespace mlir::iree_compiler::stablehlo {
namespace {
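// Returns true if the first operand of `op` is a shaped type with an integer
// or index element type.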
bool hasIntegralShapeType(Operation *op) {
  auto stp = llvm::dyn_cast<ShapedType>(op->getOperand(0).getType());
  return stp && stp.getElementType().isIntOrIndex();
}

} // namespace

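// Returns `nLoops` iterator types: the leading loops are parallel and the
// trailing `nReduction` loops are reductions. For example,
// getParallelAndReductionIterators(3, 1) yields {parallel, parallel,
// reduction}.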
SmallVector<utils::IteratorType, 3>
getParallelAndReductionIterators(unsigned nLoops, unsigned nReduction) {
  SmallVector<utils::IteratorType, 3> res(nLoops - nReduction,
                                          utils::IteratorType::parallel);
  res.append(nReduction, utils::IteratorType::reduction);
  return res;
}

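// Returns `nParallelLoops` parallel iterator types with no reductions.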
SmallVector<utils::IteratorType, 3>
getNParallelLoopsAttrs(unsigned nParallelLoops) {
  return getParallelAndReductionIterators(nParallelLoops, 0);
}

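// Creates a bufferization.alloc_tensor op for a sparse result; the sparse
// encoding is carried on the tensor type itself.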
Value getEmptySparseTensor(OpBuilder &b, Location loc, ShapedType type,
                           ArrayRef<Value> dynSizes) {
  return b.create<bufferization::AllocTensorOp>(
      loc, llvm::cast<TensorType>(type), dynSizes,
      /*copy=*/Value(),
      /*memory_space=*/IntegerAttr());
}

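// Creates a tensor.empty op for the given shaped type and dynamic sizes,
// propagating any encoding attached to the ranked tensor type.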
Value getEmptyTensor(OpBuilder &b, Location loc, ShapedType type,
                     ArrayRef<Value> dynSizes) {
  return b.create<tensor::EmptyOp>(
      loc, type.getShape(), type.getElementType(), dynSizes,
      llvm::cast<RankedTensorType>(type).getEncoding());
}

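// Creates an empty tensor whose shape and encoding match `resultType`,
// reifying the op's result shape to supply any dynamic dimension sizes.
// Sparse-encoded results use bufferization.alloc_tensor; dense results use
// tensor.empty. Schematically, for a dense tensor<?x4xf32> result this emits
// roughly:
//   %c0 = arith.constant 0 : index
//   %d0 = tensor.extract %reified_shape[%c0] : tensor<2xindex>
//   %0 = tensor.empty(%d0) : tensor<?x4xf32>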
Value getEmptyTensorFor(OpBuilder &b, Location loc, ShapedType resultType,
                        Operation *op, ValueRange operands) {
  bool isSparse = sparse_tensor::getSparseTensorEncoding(resultType) != nullptr;
  // Collect the sizes of a ranked tensor to pass to the tensor initialization
  // operation; only the dynamic dimensions need explicit sizes.
  SmallVector<Value> sizes;
  if (resultType.hasRank() && !resultType.hasStaticShape()) {
    // Ask the op for its output shape.
    auto shapeSource = cast<InferShapedTypeOpInterface>(op);
    SmallVector<Value, 1> reifiedShapes;
    (void)shapeSource.reifyReturnTypeShapes(b, operands, reifiedShapes);
    assert(reifiedShapes.size() == 1 && "Expected one reified result");
    // Construct sizes for the required dimensions.
    for (const auto &en : llvm::enumerate(resultType.getShape())) {
      if (!ShapedType::isDynamic(en.value()))
        continue;
      sizes.push_back(b.create<tensor::ExtractOp>(
          loc, reifiedShapes[0],
          ValueRange{b.create<arith::ConstantIndexOp>(loc, en.index())}));
    }
  }
  return isSparse ? getEmptySparseTensor(b, loc, resultType, sizes)
                  : getEmptyTensor(b, loc, resultType, sizes);
}

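// Casts `value` to the shape of `targetType` while keeping its original
// element type, folding the tensor.cast away when it is a no-op.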
Value coerceTensorShape(OpBuilder &builder, Location loc,
                        TypedValue<ShapedType> value, ShapedType targetType) {
  return builder.createOrFold<tensor::CastOp>(
      loc, targetType.cloneWith(std::nullopt, value.getType().getElementType()),
      value);
}

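// Succeeds when every operand and result of `op` is a ranked tensor.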
LogicalResult verifyHloOpBufferOrTensorSemantics(Operation *op) {
  auto isRankedTensor = [](Value val) {
    return isa<RankedTensorType>(val.getType());
  };
  if (!llvm::all_of(op->getOperands(), isRankedTensor))
    return failure();
  return success(llvm::all_of(op->getResults(), isRankedTensor));
}

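// Fills `tensor` with zeros of its element type via linalg.fill, building an
// explicit complex zero when the element type is complex.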
Value fillTensorWithZeros(OpBuilder &builder, Location loc, Value tensor) {
  auto type = cast<ShapedType>(tensor.getType());
  Value zero;
  // Complex numbers are a special case.
  if (auto complexType = llvm::dyn_cast<ComplexType>(type.getElementType())) {
    auto zeroElement = builder.getZeroAttr(complexType.getElementType());
    auto zeroAttr = builder.getArrayAttr({zeroElement, zeroElement});
    zero = builder.create<complex::ConstantOp>(loc, complexType, zeroAttr);
  } else {
    auto zeroAttr = builder.getZeroAttr(type.getElementType());
    zero = builder.create<arith::ConstantOp>(loc, zeroAttr);
  }
  return builder.create<linalg::FillOp>(loc, zero, tensor).result();
}

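// For unary ops whose lowering produces elaborate code, wraps the scalar
// computation in a sparse_tensor.unary op when either the operand or the
// result carries a sparse encoding: the builder is moved into the op's
// "present" region and values[0] is remapped to that region's block argument.
// Returns the sparse_tensor.unary value, or a null Value when no wrapping is
// needed. Must be paired with postSparsify. Schematically (region syntax
// abbreviated):
//   %w = sparse_tensor.unary %v : f32 to f32
//          present = { ^bb0(%x: f32): <op body>; sparse_tensor.yield ... }
//          absent = {}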
Value preSparsify(Operation *op, llvm::SmallVector<Value, 2> &values, Type rtp,
                  OpBuilder *b) {
  // Apply to semi-ring operations that lower to elaborate code: any sign or
  // negate op, an integral abs op, or one of the CHLO unary ops listed below.
  // TODO(peiming, ajcbik): these can all potentially be optimized by applying
  // a value transform on the sparse_tensor.values memref.
  if (isa<mlir::stablehlo::SignOp, mlir::stablehlo::NegOp>(op) ||
      (isa<mlir::stablehlo::AbsOp>(op) && hasIntegralShapeType(op)) ||
      isa<chlo::AsinOp, chlo::AsinhOp, chlo::AtanOp, chlo::AtanhOp,
          chlo::BesselI1eOp, chlo::SinhOp, chlo::TanOp>(op)) {
    if (!sparse_tensor::getSparseTensorEncoding(op->getResult(0).getType()) &&
        !sparse_tensor::getSparseTensorEncoding(op->getOperand(0).getType()))
      return Value();
    Location loc = op->getLoc();
    auto semiring = b->create<sparse_tensor::UnaryOp>(loc, rtp, values[0]);
    Type itp = values[0].getType();
    Block *present = b->createBlock(&semiring.getPresentRegion(), {}, itp, loc);
    b->setInsertionPointToStart(present);
    values[0] = present->getArgument(0);
    return semiring;
  }
  return Value();
}

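// Counterpart of preSparsify: when `semiring` is non-null, yields `result`
// inside the "present" region, restores the insertion point after the
// sparse_tensor.unary op, and returns the semi-ring value; otherwise returns
// `result` unchanged.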
Value postSparsify(Operation *op, Value semiring, Value result, OpBuilder *b) {
  if (semiring) {
    b->create<sparse_tensor::YieldOp>(op->getLoc(), result);
    b->setInsertionPointAfter(semiring.getDefiningOp());
    return semiring;
  }
  return result;
}

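// Returns true if every operand of `op` is a rank-0 (scalar) tensor.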
bool allOperandsAreScalarTensors(Operation *op) {
  return llvm::all_of(op->getOperands(), [](Value operand) {
    auto operandTy = llvm::dyn_cast<ShapedType>(operand.getType());
    return operandTy && operandTy.getRank() == 0;
  });
}

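// Returns true if `op` is nested directly inside an op from the Linalg
// dialect, e.g. in the body of a linalg.generic.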
bool isInBodyOfLinalgOps(Operation *op) {
  auto *parentOp = op->getParentRegion()->getParentOp();
  return parentOp->getDialect() ==
         parentOp->getContext()->getLoadedDialect<linalg::LinalgDialect>();
}

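// Extracts the attribute's integer values into a vector of int64_t.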
SmallVector<int64_t> extract1DVector(DenseIntElementsAttr elements) {
  SmallVector<int64_t> ret;
  for (const APInt &element : elements) {
    ret.push_back(element.getLimitedValue());
  }
  return ret;
}

} // namespace mlir::iree_compiler::stablehlo