// blob: 2b74488c9ef329c23b17f42a141d3d2e13e7343e [file] [log] [blame]
// Copyright 2022 The IREE Authors
//
// Licensed under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
// Utils for lowering of the StableHLO dialect to the Linalg dialect.
#ifndef IREE_COMPILER_PLUGINS_INPUT_STABLEHLO_CONVERSION_LEGALIZE_TO_LINALG_UTILS_H_
#define IREE_COMPILER_PLUGINS_INPUT_STABLEHLO_CONVERSION_LEGALIZE_TO_LINALG_UTILS_H_
#include <algorithm>
#include <numeric>
#include <optional>
#include <string>
#include <utility>
#include "compiler/plugins/input/StableHLO/Conversion/MapStableHLOToScalarOp.h"
#include "llvm/ADT/STLExtras.h"
#include "llvm/ADT/SmallVector.h"
#include "llvm/ADT/StringSet.h"
#include "mlir/Dialect/Linalg/IR/Linalg.h"
#include "mlir/Dialect/Linalg/Utils/Utils.h"
#include "mlir/Dialect/Tensor/IR/Tensor.h"
#include "mlir/IR/AffineExpr.h"
#include "mlir/IR/Attributes.h"
#include "mlir/IR/Builders.h"
#include "mlir/IR/BuiltinAttributes.h"
#include "mlir/IR/BuiltinTypes.h"
#include "mlir/IR/Location.h"
#include "mlir/IR/MLIRContext.h"
#include "mlir/IR/Operation.h"
#include "mlir/IR/OperationSupport.h"
#include "mlir/IR/TypeUtilities.h"
#include "mlir/Support/LLVM.h"
#include "mlir/Support/LogicalResult.h"
#include "mlir/Transforms/DialectConversion.h"
#include "stablehlo/dialect/StablehloOps.h"
namespace mlir::iree_compiler::stablehlo {
/// Returns a vector of `nLoops` iterator types. All entries are "parallel"
/// except the last `nReduction` entries, which are "reduction".
SmallVector<utils::IteratorType, 3>
getParallelAndReductionIterators(unsigned nLoops, unsigned nReduction);
/// Returns a vector that contains `nParallelLoops` "parallel" iterator types.
SmallVector<utils::IteratorType, 3>
getNParallelLoopsAttrs(unsigned nParallelLoops);
/// Generates an init sparse tensor of the given |type|, using |b| to build it
/// at |loc|.
/// NOTE(review): |dynSizes| presumably supplies the sizes of the dynamic
/// dimensions of |type| — confirm against the definition.
Value getEmptySparseTensor(OpBuilder &b, Location loc, ShapedType type,
ArrayRef<Value> dynSizes);
/// Generates a tensor.empty op of the given |type|, using |b| to build it at
/// |loc|, and returns the resulting value.
/// NOTE(review): |dynSizes| presumably supplies the sizes of the dynamic
/// dimensions of |type| — confirm against the definition.
Value getEmptyTensor(OpBuilder &b, Location loc, ShapedType type,
ArrayRef<Value> dynSizes);
/// Generates an empty tensor for the result of the operation |op|, which could
/// be a dense tensor or a sparse tensor. |resultType| is the type of the tensor
/// to create.
/// NOTE(review): |op| and |operands| are presumably used to derive the sizes of
/// any dynamic dimensions — confirm against the definition.
Value getEmptyTensorFor(OpBuilder &b, Location loc, ShapedType resultType,
Operation *op, ValueRange operands);
/// Ensures the tensor |value| has the same shape as |targetType| and returns
/// the resulting value. Only the shape is considered; the element type is not
/// part of the comparison.
/// NOTE(review): how the shape is coerced is decided by the definition, which
/// is not visible in this header.
Value coerceTensorShape(OpBuilder &builder, Location loc,
TypedValue<ShapedType> value, ShapedType targetType);
/// Verifies |op|'s semantics by checking if all operands and results have
/// ranked tensor types.
/// NOTE(review): the original comment read "ranged"; assumed to be a typo for
/// "ranked" — confirm against the definition.
LogicalResult verifyHloOpBufferOrTensorSemantics(Operation *op);
/// Fills |tensor| with a zero constant of the matching type and returns the
/// new value; callers should use the returned value rather than |tensor|.
/// NOTE(review): "matching type" presumably means the element type of
/// |tensor| — confirm against the definition.
Value fillTensorWithZeros(OpBuilder &builder, Location loc, Value tensor);
/// Sparsifies a (block of) operation(s) that cannot be handled directly
/// by the sparse compiler but has well-known semi-ring semantics.
///
/// This yields something of the following form:
///
///   %result = sparse_tensor.unary %values[0]
///     present={
///       ^bb1(%val):
///         ... codegen proceeds here using %val ....
///         sparse_tensor.yield
///     }
///     absent={}
///   linalg.yield %result
///
/// NOTE(review): |values| is taken by mutable reference, so it may be updated
/// in place (presumably so later codegen uses %val) — confirm against the
/// definition. |rtp| is presumably the result type of the generated op.
Value preSparsify(Operation *op, llvm::SmallVector<Value, 2> &values, Type rtp,
OpBuilder *b);
/// Finalizes sparse semi-ring construction.
/// NOTE(review): |semiring| is presumably the value returned by preSparsify
/// and |result| the value computed inside its region — confirm against the
/// definition.
Value postSparsify(Operation *op, Value semiring, Value result, OpBuilder *b);
/// Returns true if all operands of |op| are tensors with rank 0.
bool allOperandsAreScalarTensors(Operation *op);
/// Returns true if the parent op of |op| is a linalg op.
/// NOTE(review): whether only the immediate parent is inspected is decided by
/// the definition, which is not visible in this header.
bool isInBodyOfLinalgOps(Operation *op);
/// Extracts the integer values from the attribute |elements| and returns them
/// as a vector of int64_t.
/// NOTE(review): the name suggests |elements| is expected to be 1-D — confirm
/// against the definition.
SmallVector<int64_t> extract1DVector(DenseIntElementsAttr elements);
/// Returns true if the given |values| is a splat of the given |queryValue|.
inline bool isSplatValue(const ArrayRef<int64_t> &values, int64_t queryValue) {
for (auto value : values) {
if (value != queryValue) {
return false;
}
}
return true;
}
/// Returns true if |attr| is a splat whose single value, read as uint64_t,
/// equals |value|.
inline bool isSplatValue(DenseIntElementsAttr attr, uint64_t value) {
  // A non-splat attribute never matches; the splat value is only read once
  // the attribute is known to be a splat.
  if (!attr.isSplat()) {
    return false;
  }
  const uint64_t splat = attr.getSplatValue<uint64_t>();
  return splat == value;
}
} // namespace mlir::iree_compiler::stablehlo
#endif // IREE_COMPILER_PLUGINS_INPUT_STABLEHLO_CONVERSION_LEGALIZE_TO_LINALG_UTILS_H_