blob: f51fcd66399aa0423900f5a0b6c82a685ede5edc [file] [log] [blame]
// Copyright 2019 The IREE Authors
//
// Licensed under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
// Implements logic for lowering StableHLO convolution ops to Linalg dialect.
#include "compiler/plugins/input/StableHLO/Conversion/LegalizeToLinalgUtils.h"
#include "compiler/plugins/input/StableHLO/Conversion/Rewriters.h"
#include "mlir/Dialect/Linalg/IR/Linalg.h"
#include "mlir/IR/BuiltinTypes.h"
#include "mlir/Support/LogicalResult.h"
#include "mlir/Transforms/DialectConversion.h"
#include "stablehlo/dialect/StablehloOps.h"
namespace mlir::iree_compiler::stablehlo {
namespace {
/// Applies the convolution's edge padding and lhs (input) dilation to `input`
/// by materializing a stablehlo.pad op. `dimMappings` maps each spatial index
/// in the attributes to the corresponding dimension of `input`. Returns
/// `input` unchanged when both padding and dilation are trivial.
Value applyConvolutionPadding(Location loc, Value input,
                              DenseIntElementsAttr padding,
                              std::optional<ArrayRef<int64_t>> lhsDilation,
                              llvm::ArrayRef<int64_t> dimMappings,
                              OpBuilder &rewriter) {
  SmallVector<int64_t> dilationVals;
  if (lhsDilation.has_value())
    dilationVals = llvm::to_vector(*lhsDilation);

  const bool hasPadding = padding && !isSplatValue(padding, 0);
  const bool hasDilation = lhsDilation && !hlo::isSplatArray(dilationVals, 1);
  // Nothing to do when there is neither padding nor input dilation.
  if (!hasPadding && !hasDilation)
    return input;

  auto inputType = cast<ShapedType>(input.getType());
  const int64_t rank = inputType.getRank();

  // Translate window padding into low/high edge padding per dimension.
  SmallVector<int64_t, 8> edgeLow(rank, 0);
  SmallVector<int64_t, 8> edgeHigh(rank, 0);
  if (padding) {
    // The padding attribute holds a (low, high) pair per spatial dimension;
    // the batch and feature dimensions carry no padding.
    assert(rank * 2 == padding.size() + 4 &&
           "There should be 2 padding values per dimension, i.e low and high.");
    auto padVals = padding.getValues<int64_t>();
    for (int64_t i : llvm::seq<int64_t>(0, padding.size() / 2)) {
      const int64_t dim = dimMappings[i];
      edgeLow[dim] = padVals[i * 2];
      edgeHigh[dim] = padVals[i * 2 + 1];
    }
  }

  // Input dilation is expressed as interior padding.
  SmallVector<int64_t, 8> interior(rank, 0);
  if (lhsDilation) {
    assert(rank == static_cast<int64_t>(dilationVals.size()) + 2);
    for (int64_t i : llvm::seq<int64_t>(0, dilationVals.size()))
      interior[dimMappings[i]] = dilationVals[i] - 1;
  }

  // Build a rank-0 tensor holding zero of the input element type to pad with.
  Value zero;
  if (auto complexType = dyn_cast<ComplexType>(inputType.getElementType())) {
    auto zeroElement = rewriter.getZeroAttr(complexType.getElementType());
    auto zeroAttr = rewriter.getArrayAttr({zeroElement, zeroElement});
    zero = rewriter.create<complex::ConstantOp>(loc, complexType, zeroAttr);
    zero = rewriter.create<tensor::FromElementsOp>(
        loc, RankedTensorType::get({}, complexType), zero);
  } else {
    zero = rewriter.create<arith::ConstantOp>(
        loc, rewriter.getZeroAttr(
                 RankedTensorType::get({}, inputType.getElementType())));
  }

  return rewriter.create<mlir::stablehlo::PadOp>(loc, input, zero, edgeLow,
                                                 edgeHigh, interior);
}
/// Reverses `filter` along every kernel spatial dimension flagged in the
/// convolution's window_reversal attribute. Returns the filter unchanged when
/// the attribute is absent.
Value applyConvolutionReversal(Location loc, OpBuilder &b,
                               mlir::stablehlo::ConvolutionOp op,
                               Value filter) {
  std::optional reversals = op.getWindowReversal();
  if (!reversals.has_value())
    return filter;

  ArrayRef<int64_t> spatialDims =
      op.getDimensionNumbers().getKernelSpatialDimensions();
  llvm::SmallVector<int64_t> dimsToReverse;
  for (auto [idx, isReversed] : llvm::enumerate(*reversals)) {
    if (isReversed)
      dimsToReverse.push_back(spatialDims[idx]);
  }
  return b.create<mlir::stablehlo::ReverseOp>(
      loc, filter, b.getDenseI64ArrayAttr(dimsToReverse));
}
/// Returns true if the given `dimensionNumbers` from a stablehlo.convolution
/// op follows a canonical form:
///
/// * Input dimensions have order: (batch_count, spatial_dims,
///   input_channel_count).
/// * Filter dimensions have order: (spatial_dims, input_channel_count,
///   output_channel_count).
/// * Output dimensions have order: (batch_count, spatial_dims,
///   output_channel_count).
bool hasCanonicalDimensionNumbers(
    mlir::stablehlo::ConvDimensionNumbersAttr dimensionNumbers) {
  ArrayRef<int64_t> inputSpatial = dimensionNumbers.getInputSpatialDimensions();
  ArrayRef<int64_t> kernelSpatial =
      dimensionNumbers.getKernelSpatialDimensions();
  ArrayRef<int64_t> outputSpatial =
      dimensionNumbers.getOutputSpatialDimensions();
  const int64_t inputSpatialRank = inputSpatial.size();
  const int64_t kernelSpatialRank = kernelSpatial.size();
  const int64_t outputSpatialRank = outputSpatial.size();

  // Input layout must be (batch, spatial..., feature).
  if (dimensionNumbers.getInputBatchDimension() != 0 ||
      dimensionNumbers.getInputFeatureDimension() != inputSpatialRank + 1)
    return false;

  // Kernel layout must be (spatial..., input_feature, output_feature).
  if (dimensionNumbers.getKernelInputFeatureDimension() != kernelSpatialRank ||
      dimensionNumbers.getKernelOutputFeatureDimension() !=
          kernelSpatialRank + 1)
    return false;

  // Output layout must be (batch, spatial..., feature).
  if (dimensionNumbers.getOutputBatchDimension() != 0 ||
      dimensionNumbers.getOutputFeatureDimension() != outputSpatialRank + 1)
    return false;

  // All operands must agree on the number of spatial dimensions.
  if (inputSpatialRank != outputSpatialRank ||
      inputSpatialRank != kernelSpatialRank)
    return false;

  // Spatial dims must be consecutive and in order: 1..rank for input/output
  // (after the batch dim), 0..rank-1 for the kernel.
  for (int64_t i = 0; i < inputSpatialRank; ++i) {
    if (inputSpatial[i] != i + 1 || outputSpatial[i] != i + 1 ||
        kernelSpatial[i] != i)
      return false;
  }
  return true;
}
/// Converts stablehlo.conv operation to linalg named op. This only covers
/// normal convolution cases. The op must have canonical dimension numbers.
/// Depthwise convolution and pointwise convolution are not handled in the
/// conversion.
struct NormalConvolutionOpConversion final
    : OpConversionPattern<mlir::stablehlo::ConvolutionOp> {
  using OpConversionPattern::OpConversionPattern;
  LogicalResult
  matchAndRewrite(mlir::stablehlo::ConvolutionOp op, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    // Bail out on non-canonical layouts and on grouped convolutions; those
    // cases are left to the other patterns in this file.
    if (!hasCanonicalDimensionNumbers(op.getDimensionNumbers())) {
      return failure();
    }
    if (op.getFeatureGroupCount() != 1u)
      return failure();
    if (op.getBatchGroupCount() != 1u)
      return failure();
    Location loc = op.getLoc();
    Value input = adaptor.getLhs();
    Value filter = adaptor.getRhs();
    // Fold any window reversal into the filter up front; the named linalg
    // conv ops used below have no reversal semantics of their own.
    filter = applyConvolutionReversal(loc, rewriter, op, filter);
    auto resultType = dyn_cast_or_null<ShapedType>(
        getTypeConverter()->convertType(op.getResult().getType()));
    if (!resultType) {
      return rewriter.notifyMatchFailure(op, "type conversion failed");
    }
    int64_t rank = resultType.getRank();
    // Immediately emit an EmptyOp for output tensors with zero dimension.
    if (llvm::is_contained(resultType.getShape(), 0)) {
      rewriter.replaceOpWithNewOp<tensor::EmptyOp>(op, resultType.getShape(),
                                                   resultType.getElementType());
      return success();
    }
    // The output shape is N spatial_dims F. A dynamic batch size is taken
    // from the input and a dynamic feature count from the filter; dynamic
    // spatial dims are rejected below.
    SmallVector<Value, 8> dynSizes;
    if (resultType.isDynamicDim(0)) {
      dynSizes.push_back(rewriter.create<tensor::DimOp>(loc, input, 0));
    }
    for (int64_t i = 1, e = rank - 1; i < e; ++i) {
      if (resultType.isDynamicDim(i)) {
        return rewriter.notifyMatchFailure(
            op, "expected output spatial dims to be static shapes");
      }
    }
    if (resultType.isDynamicDim(rank - 1)) {
      dynSizes.push_back(rewriter.create<tensor::DimOp>(loc, filter, rank - 1));
    }
    // Named conv ops accumulate into their init operand, so start from a
    // zero-filled tensor.
    Value emptyTensor = rewriter.create<tensor::EmptyOp>(
        loc, resultType.getShape(), resultType.getElementType(), dynSizes);
    Value zeroTensor = fillTensorWithZeros(rewriter, loc, emptyTensor);
    linalg::LinalgOp res;
    // Strides/dilations stay null when absent; the linalg builders treat a
    // null attribute as all-ones.
    Attribute strides;
    if (auto s = op.getWindowStrides())
      strides = rewriter.getI64TensorAttr(*s);
    Attribute dilations;
    if (auto d = op.getRhsDilation())
      dilations = rewriter.getI64TensorAttr(*d);
    // Apply padding and input dilation.
    llvm::SmallVector<int64_t> spatialDimMapping(rank - 2);
    std::iota(spatialDimMapping.begin(), spatialDimMapping.end(), 1);
    input = applyConvolutionPadding(loc, input, op.getPaddingAttr(),
                                    op.getLhsDilation(), spatialDimMapping,
                                    rewriter);
    // Dispatch on output rank: rank 2 is the degenerate zero-spatial-dim case
    // (a plain matmul); ranks 3/4/5 map to 1-D/2-D/3-D NWC/NHWC/NDHWC convs.
    switch (rank) {
    case 2: {
      res = rewriter.create<linalg::MatmulOp>(
          loc, resultType, ValueRange{input, filter}, ValueRange{zeroTensor},
          linalg::getPrunedAttributeList(op));
      break;
    }
    case 3: {
      res = rewriter.create<linalg::Conv1DNwcWcfOp>(
          loc, resultType, ValueRange{input, filter}, ValueRange{zeroTensor},
          strides, dilations, linalg::getPrunedAttributeList(op));
      break;
    }
    case 4: {
      res = rewriter.create<linalg::Conv2DNhwcHwcfOp>(
          loc, resultType, ValueRange{input, filter}, ValueRange{zeroTensor},
          strides, dilations, linalg::getPrunedAttributeList(op));
      break;
    }
    case 5: {
      res = rewriter.create<linalg::Conv3DNdhwcDhwcfOp>(
          loc, resultType, ValueRange{input, filter}, ValueRange{zeroTensor},
          strides, dilations, linalg::getPrunedAttributeList(op));
      break;
    }
    default: {
      return rewriter.notifyMatchFailure(op, "expected 1/2/3D conv op");
    }
    }
    rewriter.replaceOp(op, res.getOperation()->getResults());
    return success();
  }
};
/// Handles all possible inputs for the mlir::stablehlo::ConvolutionOp
struct ConvolutionOpGeneralConversion final
    : OpConversionPattern<mlir::stablehlo::ConvolutionOp> {
  using OpConversionPattern::OpConversionPattern;
  /// This lowering proceeds with the following steps:
  /// 1. Handle padding and dilation of the input
  /// 2. Handle padding and dilation of the window
  /// 3. Handle reversal of the window
  /// 4. If feature_group_count != 1:
  ///    - Reshape the input feature dimension, kernel output feature
  ///      dimension, and output feature dimension.
  ///    - Create the AffineExpr for the new dimension
  ///    - Conceptually, this splits the input feature and both output feature
  ///      dimensions and computes sets of convolutions with these partial
  ///      views of the values as if they were multiple convolutions combined
  ///      in a batch.
  /// 5: If batch_group_count != 1:
  ///    - Reshape the input batch dimension, kernel output feature dimension,
  ///      and output feature dimension.
  ///    - Create the AffineExpr for the new dimension
  ///    - Conceptually, this splits the input batch and both output feature
  ///      dimensions and computes sets of convolutions with these partial
  ///      views of the values as if they were multiple convolutions combined
  ///      in a batch.
  /// 6. For all dimensions not newly created by a reshape, create the
  ///    appropriate parallel and reduction dimensions to create a convolution.
  /// 7. Create the linalg.generic that computes the multiply-add
  /// 8. Reshape the output to the original shape if it was reshaped by the
  ///    feature or group count attributes.
  LogicalResult
  matchAndRewrite(mlir::stablehlo::ConvolutionOp op, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    Location loc = op.getLoc();
    MLIRContext *ctx = op.getContext();
    auto resultType = dyn_cast_or_null<ShapedType>(
        getTypeConverter()->convertType(op.getResult().getType()));
    if (!resultType) {
      return rewriter.notifyMatchFailure(op, "type conversion failed");
    }
    // Mutable copy of the result shape; group handling below may insert and
    // shrink dimensions in it.
    auto reshapedResultShape = resultType.getShape().vec();
    if (!resultType.hasStaticShape())
      return failure();
    // Immediately emit an EmptyOp for output tensors with zero dimension.
    if (llvm::is_contained(reshapedResultShape, 0)) {
      rewriter.replaceOpWithNewOp<tensor::EmptyOp>(op, reshapedResultShape,
                                                   resultType.getElementType());
      return success();
    }
    mlir::stablehlo::ConvDimensionNumbersAttr dimensionNumbers =
        op.getDimensionNumbers();
    int64_t inputBatchDimension = dimensionNumbers.getInputBatchDimension();
    int64_t inputFeatureDimension = dimensionNumbers.getInputFeatureDimension();
    ArrayRef<int64_t> inputSpatialDimensions =
        dimensionNumbers.getInputSpatialDimensions();
    int64_t kernelInputFeatureDimension =
        dimensionNumbers.getKernelInputFeatureDimension();
    int64_t kernelOutputFeatureDimension =
        dimensionNumbers.getKernelOutputFeatureDimension();
    ArrayRef<int64_t> kernelSpatialDimensions =
        dimensionNumbers.getKernelSpatialDimensions();
    int64_t outputBatchDimension = dimensionNumbers.getOutputBatchDimension();
    int64_t outputFeatureDimension =
        dimensionNumbers.getOutputFeatureDimension();
    ArrayRef<int64_t> outputSpatialDimensions =
        dimensionNumbers.getOutputSpatialDimensions();
    size_t featureGroupCount = op.getFeatureGroupCount();
    size_t batchGroupCount = op.getBatchGroupCount();
    if (op.getFeatureGroupCount() != 1 && op.getBatchGroupCount() != 1) {
      return rewriter.notifyMatchFailure(
          op, "only one of feature and batch group counts can be non-one");
    }
    // Decompose the convolution into an initial padding
    Value modifiedLhs = applyConvolutionPadding(
        op.getLoc(), adaptor.getLhs(), adaptor.getPaddingAttr(),
        adaptor.getLhsDilation(),
        op.getDimensionNumbers().getInputSpatialDimensions(), rewriter);
    // The kernel has no edge padding; only rhs dilation applies to it.
    Value modifiedRhs = applyConvolutionPadding(
        op.getLoc(), adaptor.getRhs(), nullptr, adaptor.getRhsDilation(),
        op.getDimensionNumbers().getKernelSpatialDimensions(), rewriter);
    modifiedRhs = applyConvolutionReversal(loc, rewriter, op, modifiedRhs);
    // Non-one values for feature or batch group counts will result in reshaped
    // inputs and outputs. These mappings are used to keep track of the new
    // index after reshaping has possibly inserted new dimensions.
    auto paddedLhsType = cast<ShapedType>(modifiedLhs.getType());
    auto paddedRhsType = cast<ShapedType>(modifiedRhs.getType());
    SmallVector<int64_t> lhsIndexMapping(paddedLhsType.getRank());
    std::iota(lhsIndexMapping.begin(), lhsIndexMapping.end(), 0);
    SmallVector<int64_t> rhsIndexMapping(paddedRhsType.getRank());
    std::iota(rhsIndexMapping.begin(), rhsIndexMapping.end(), 0);
    SmallVector<int64_t> resultIndexMapping(resultType.getRank());
    std::iota(resultIndexMapping.begin(), resultIndexMapping.end(), 0);
    // Shifts every mapping entry at index >= `offset` up by one, accounting
    // for a dimension newly inserted at `offset`.
    auto updateDimMappingFromOffset =
        [](llvm::SmallVectorImpl<int64_t> &mapping, int64_t offset) {
          for (auto &mappingElt : llvm::drop_begin(mapping, offset)) {
            mappingElt += 1;
          }
        };
    // The rest of this code prepares the inputs and a single linalg::GenericOp
    // to execute the convolution. The final linalg::GenericOp will be iterated
    // through based on the following eventual maps.
    SmallVector<AffineExpr, 2> srcExprs(paddedLhsType.getRank());
    SmallVector<AffineExpr, 2> windowExprs(paddedRhsType.getRank());
    SmallVector<AffineExpr, 2> dstExprs(reshapedResultShape.size());
    int64_t nextDim = 0;
    int64_t rank = resultType.getRank();
    // Splits `oldShape[reshapedDim]` into (factor, oldShape[reshapedDim] /
    // factor), writing the expanded shape into `newShape`.
    auto reshapeShapeVector = [](llvm::ArrayRef<int64_t> oldShape,
                                 llvm::SmallVectorImpl<int64_t> &newShape,
                                 int64_t reshapedDim, int64_t factor) {
      newShape.reserve(oldShape.size() + 1);
      for (int64_t i : llvm::seq<int64_t>(0, oldShape.size())) {
        if (i == reshapedDim) {
          newShape.push_back(factor);
          newShape.push_back(oldShape[reshapedDim] / factor);
        } else {
          newShape.push_back(oldShape[i]);
        }
      }
    };
    // If batch or feature count groupings exist, represent this through
    // reshaping the input to have an additional dimension that these groupings
    // exist along, and reduce in that dimension
    SmallVector<utils::IteratorType, 3> iterationLoops;
    if (featureGroupCount != 1) {
      AffineExpr parallelDim = mlir::getAffineDimExpr(nextDim++, ctx);
      iterationLoops.push_back(utils::IteratorType::parallel);
      // Reshape LHS
      {
        srcExprs.insert(srcExprs.begin() + inputFeatureDimension, parallelDim);
        auto prevDimsRef = paddedLhsType.getShape();
        llvm::SmallVector<int64_t> newShape;
        reshapeShapeVector(prevDimsRef, newShape, inputFeatureDimension,
                           featureGroupCount);
        updateDimMappingFromOffset(lhsIndexMapping, inputFeatureDimension);
        modifiedLhs = rewriter.create<mlir::stablehlo::ReshapeOp>(
            loc,
            RankedTensorType::get(newShape, paddedLhsType.getElementType()),
            modifiedLhs);
      }
      // Reshape RHS
      {
        windowExprs.insert(windowExprs.begin() + kernelOutputFeatureDimension,
                           parallelDim);
        auto prevDimsRef = paddedRhsType.getShape();
        llvm::SmallVector<int64_t> newShape;
        reshapeShapeVector(prevDimsRef, newShape, kernelOutputFeatureDimension,
                           featureGroupCount);
        updateDimMappingFromOffset(rhsIndexMapping,
                                   kernelOutputFeatureDimension);
        modifiedRhs = rewriter.create<mlir::stablehlo::ReshapeOp>(
            loc,
            RankedTensorType::get(newShape, paddedRhsType.getElementType()),
            modifiedRhs);
      }
      // Prepare reshaped output shape
      {
        dstExprs.insert(dstExprs.begin() + outputFeatureDimension, parallelDim);
        updateDimMappingFromOffset(resultIndexMapping, outputFeatureDimension);
        reshapedResultShape.insert(reshapedResultShape.begin() +
                                       outputFeatureDimension,
                                   featureGroupCount);
        reshapedResultShape[outputFeatureDimension + 1] /= featureGroupCount;
      }
    }
    if (batchGroupCount != 1) {
      iterationLoops.push_back(utils::IteratorType::parallel);
      AffineExpr parallelDim = mlir::getAffineDimExpr(nextDim++, ctx);
      // Reshape LHS
      {
        srcExprs.insert(srcExprs.begin() + inputBatchDimension, parallelDim);
        ArrayRef<int64_t> prevDimsRef = paddedLhsType.getShape();
        llvm::SmallVector<int64_t> newShape;
        reshapeShapeVector(prevDimsRef, newShape, inputBatchDimension,
                           batchGroupCount);
        updateDimMappingFromOffset(lhsIndexMapping, inputBatchDimension);
        modifiedLhs = rewriter.create<mlir::stablehlo::ReshapeOp>(
            op.getLoc(),
            RankedTensorType::get(newShape, paddedLhsType.getElementType()),
            modifiedLhs);
      }
      // Reshape RHS
      {
        windowExprs.insert(windowExprs.begin() + kernelOutputFeatureDimension,
                           parallelDim);
        ArrayRef<int64_t> prevDimsRef = paddedRhsType.getShape();
        llvm::SmallVector<int64_t> newShape;
        reshapeShapeVector(prevDimsRef, newShape, kernelOutputFeatureDimension,
                           batchGroupCount);
        updateDimMappingFromOffset(rhsIndexMapping,
                                   kernelOutputFeatureDimension);
        modifiedRhs = rewriter.create<mlir::stablehlo::ReshapeOp>(
            op.getLoc(),
            RankedTensorType::get(newShape, paddedRhsType.getElementType()),
            modifiedRhs);
      }
      // Prepare reshaped output shape
      {
        // NOTE: use the (possibly shifted) mapped index here, since the
        // feature-group branch above may already have inserted a dimension.
        int64_t outputFeatureDim = resultIndexMapping[outputFeatureDimension];
        dstExprs.insert(dstExprs.begin() + outputFeatureDim, parallelDim);
        updateDimMappingFromOffset(resultIndexMapping, outputFeatureDimension);
        reshapedResultShape.insert(
            reshapedResultShape.begin() + outputFeatureDim, batchGroupCount);
        reshapedResultShape[outputFeatureDim + 1] /= batchGroupCount;
      }
    }
    // Handle input feature dimension
    {
      iterationLoops.push_back(utils::IteratorType::reduction);
      AffineExpr inputFeatureDim = mlir::getAffineDimExpr(nextDim++, ctx);
      srcExprs[lhsIndexMapping[inputFeatureDimension]] = inputFeatureDim;
      windowExprs[rhsIndexMapping[kernelInputFeatureDimension]] =
          inputFeatureDim;
    }
    // Handle output feature dimension
    {
      iterationLoops.push_back(utils::IteratorType::parallel);
      AffineExpr outputFeatureDim = mlir::getAffineDimExpr(nextDim++, ctx);
      dstExprs[resultIndexMapping[outputFeatureDimension]] = outputFeatureDim;
      windowExprs[rhsIndexMapping[kernelOutputFeatureDimension]] =
          outputFeatureDim;
    }
    // Handle spatial Dimensions: each spatial axis contributes one parallel
    // (output position) and one reduction (window position) dimension, with
    // the input accessed at `output * stride + window`.
    int64_t numSpatialDims = rank - 2;
    for (int64_t i = 0; i < numSpatialDims; ++i) {
      iterationLoops.push_back(utils::IteratorType::parallel);
      iterationLoops.push_back(utils::IteratorType::reduction);
      AffineExpr dim0 = mlir::getAffineDimExpr(nextDim++, ctx);
      AffineExpr dim1 = mlir::getAffineDimExpr(nextDim++, ctx);
      AffineExpr stride = dim0;
      if (op.getWindowStrides().has_value())
        stride = stride * op.getWindowStrides().value()[i];
      AffineExpr srcExpr = stride + dim1;
      srcExprs[lhsIndexMapping[inputSpatialDimensions[i]]] = srcExpr;
      dstExprs[resultIndexMapping[outputSpatialDimensions[i]]] = dim0;
      windowExprs[rhsIndexMapping[kernelSpatialDimensions[i]]] = dim1;
    }
    // Handle batch dimension
    {
      iterationLoops.push_back(utils::IteratorType::parallel);
      AffineExpr batchDim = mlir::getAffineDimExpr(nextDim++, ctx);
      srcExprs[lhsIndexMapping[inputBatchDimension]] = batchDim;
      dstExprs[resultIndexMapping[outputBatchDimension]] = batchDim;
    }
    // Finally, create the computation
    auto inferredMaps =
        AffineMap::inferFromExprList({srcExprs, windowExprs, dstExprs}, ctx);
    Value emptyTensor = rewriter.create<tensor::EmptyOp>(
        loc, reshapedResultShape, resultType.getElementType());
    Value zeroTensor = fillTensorWithZeros(rewriter, loc, emptyTensor);
    Value convolved =
        rewriter
            .create<linalg::GenericOp>(
                loc,
                /*resultTensors=*/
                llvm::ArrayRef<Type>(zeroTensor.getType()),
                /*inputs=*/
                llvm::ArrayRef<Value>({modifiedLhs, modifiedRhs}),
                /*outputs=*/llvm::ArrayRef<Value>(zeroTensor), inferredMaps,
                iterationLoops,
                /*bodyBuild=*/
                [&](OpBuilder &nestedBuilder, Location nestedLoc, ValueRange) {
                  ImplicitLocOpBuilder builder(nestedLoc, nestedBuilder);
                  // Reuse the multiply-accumulate region of a named conv op;
                  // all layout information lives in the indexing maps above.
                  linalg::Conv2DOp::regionBuilder(
                      builder, *builder.getInsertionBlock(), {});
                },
                linalg::getPrunedAttributeList(op))
            .getResult(0);
    // Collapse any group dimensions back down to the original result shape.
    rewriter.replaceOpWithNewOp<mlir::stablehlo::ReshapeOp>(op, resultType,
                                                            convolved);
    return success();
  }
};
/// Converts stablehlo.convolution operation to
/// linalg.depthwise_conv_2d_input_nhwc_filter_hwcf op or
/// depthwise_conv_2d_input_nhwc_filter_hwc op.
struct DepthwiseConvolutionOpConversion final
    : OpConversionPattern<mlir::stablehlo::ConvolutionOp> {
  using OpConversionPattern::OpConversionPattern;
  LogicalResult
  matchAndRewrite(mlir::stablehlo::ConvolutionOp op, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    if (op.getBatchGroupCount() != 1)
      return failure();
    // Fall into the normal convolution cases.
    if (op.getFeatureGroupCount() == 1)
      return failure();
    const mlir::stablehlo::ConvDimensionNumbersAttr &dimensionNumbers =
        op.getDimensionNumbers();
    const int64_t spatialRank =
        dimensionNumbers.getInputSpatialDimensions().size();
    if (spatialRank == 0 || spatialRank > 3) {
      return rewriter.notifyMatchFailure(op, "only support up to 3D for now");
    }
    // Make sure that this is depthwise convolution: the feature group count
    // must equal the input feature count (one group per input channel).
    int64_t inputFeatureDim = dimensionNumbers.getInputFeatureDimension();
    int64_t inputFeatureCount =
        cast<ShapedType>(op.getLhs().getType()).getDimSize(inputFeatureDim);
    if (static_cast<int64_t>(op.getFeatureGroupCount()) != inputFeatureCount) {
      return rewriter.notifyMatchFailure(op, "not depth-wise convolution");
    }
    // Make sure that this convolution has a canonical form.
    if (!hasCanonicalDimensionNumbers(dimensionNumbers)) {
      return rewriter.notifyMatchFailure(op, "does not have canonical form");
    }
    // Default missing strides/dilations to splat 1 so the named depthwise
    // ops always receive explicit attributes.
    Attribute windowStrides;
    if (op.getWindowStrides()) {
      windowStrides = rewriter.getI64TensorAttr(op.getWindowStrides().value());
    } else {
      windowStrides = SplatElementsAttr::get(
          VectorType::get({spatialRank}, rewriter.getI64Type()),
          rewriter.getI64IntegerAttr(1));
    }
    Attribute rhsDilation;
    if (op.getRhsDilation()) {
      rhsDilation = rewriter.getI64TensorAttr(op.getRhsDilation().value());
    } else {
      rhsDilation = SplatElementsAttr::get(
          VectorType::get({spatialRank}, rewriter.getI64Type()),
          rewriter.getI64IntegerAttr(1));
    }
    Location loc = op.getLoc();
    Value input = adaptor.getLhs();
    Value filter = adaptor.getRhs();
    auto resultType = dyn_cast_or_null<RankedTensorType>(
        getTypeConverter()->convertType(op.getResult().getType()));
    if (!resultType) {
      return rewriter.notifyMatchFailure(op, "type conversion failed");
    }
    if (!resultType.hasStaticShape()) {
      return rewriter.notifyMatchFailure(op,
                                         "expected output has static shapes");
    }
    // Immediately emit an EmptyOp for output tensors with zero dimension.
    if (llvm::is_contained(resultType.getShape(), 0)) {
      rewriter.replaceOpWithNewOp<tensor::EmptyOp>(op, resultType.getShape(),
                                                   resultType.getElementType());
      return success();
    }
    // Apply padding and input dilation.
    llvm::SmallVector<int64_t> spatialDimMapping(spatialRank);
    std::iota(spatialDimMapping.begin(), spatialDimMapping.end(), 1);
    input = applyConvolutionPadding(loc, input, op.getPaddingAttr(),
                                    op.getLhsDilation(), spatialDimMapping,
                                    rewriter);
    auto filterDims =
        llvm::to_vector(cast<ShapedType>(op.getRhs().getType()).getShape());
    // Builds reassociation indices that collapse the last two dimensions of
    // `v` into one (used to fold away the channel-multiplier dimension).
    auto getReassociationIndicesToCollapseLastTwoDims = [](Value v) {
      SmallVector<ReassociationIndices> reassociations;
      int64_t rank = cast<ShapedType>(v.getType()).getRank();
      for (int64_t i = 0; i < rank - 1; ++i)
        reassociations.emplace_back(1, i);
      reassociations.back().push_back(rank - 1);
      return reassociations;
    };
    int64_t kernelInputFeatureDimension =
        dimensionNumbers.getKernelInputFeatureDimension();
    int64_t kernelOutputFeatureDimension =
        dimensionNumbers.getKernelOutputFeatureDimension();
    if (filterDims[kernelInputFeatureDimension] *
            filterDims[kernelOutputFeatureDimension] !=
        static_cast<int64_t>(op.getFeatureGroupCount())) {
      // For cases where channel multiplier != 1
      // Reshaping filter shape
      //   [filter_height, filter_width, 1, kernel-output-feature].
      // to
      //   [filter_height, filter_width, feature_group_count,
      //      kernel-output-feature/feature_group_count ]
      SmallVector<int64_t> reshapedFilterDims;
      reshapedFilterDims.assign(filterDims.begin(), filterDims.end());
      Value reshapedFilter = filter;
      if (filterDims[kernelInputFeatureDimension] == 1) {
        reshapedFilterDims[kernelInputFeatureDimension] =
            op.getFeatureGroupCount();
        reshapedFilterDims[kernelOutputFeatureDimension] /=
            op.getFeatureGroupCount();
        auto reshapedFilterType = RankedTensorType::get(
            reshapedFilterDims,
            cast<ShapedType>(op.getRhs().getType()).getElementType());
        reshapedFilter = rewriter.create<mlir::stablehlo::ReshapeOp>(
            loc, reshapedFilterType, filter);
      }
      // Expand the output: the last (feature) dim becomes
      // (feature / multiplier, multiplier) so the *Wcm/*Hwcm/*Dhwcm named
      // ops can write one value per channel-multiplier.
      ArrayRef<int64_t> outputDims = resultType.getShape();
      int64_t channelMultiplier = reshapedFilterDims.back();
      SmallVector<int64_t> reshapedOutputDims;
      reshapedOutputDims.assign(outputDims.begin(), outputDims.end());
      reshapedOutputDims.push_back(channelMultiplier);
      reshapedOutputDims[reshapedOutputDims.size() - 2] /= channelMultiplier;
      Value emptyTensor = rewriter.create<tensor::EmptyOp>(
          loc, reshapedOutputDims, resultType.getElementType());
      Value zeroTensor = fillTensorWithZeros(rewriter, loc, emptyTensor);
      auto reshapedOutputType = RankedTensorType::get(
          reshapedOutputDims, resultType.getElementType());
      Value conv;
      switch (spatialRank) {
      case 1: {
        conv =
            rewriter
                .create<linalg::DepthwiseConv1DNwcWcmOp>(
                    loc, reshapedOutputType, ValueRange{input, reshapedFilter},
                    ValueRange{zeroTensor}, windowStrides, rhsDilation,
                    linalg::getPrunedAttributeList(op))
                .getResult(0);
        break;
      }
      case 2: {
        conv =
            rewriter
                .create<linalg::DepthwiseConv2DNhwcHwcmOp>(
                    loc, reshapedOutputType, ValueRange{input, reshapedFilter},
                    ValueRange{zeroTensor}, windowStrides, rhsDilation,
                    linalg::getPrunedAttributeList(op))
                .getResult(0);
        break;
      }
      case 3: {
        conv =
            rewriter
                .create<linalg::DepthwiseConv3DNdhwcDhwcmOp>(
                    loc, reshapedOutputType, ValueRange{input, reshapedFilter},
                    ValueRange{zeroTensor}, windowStrides, rhsDilation,
                    linalg::getPrunedAttributeList(op))
                .getResult(0);
        break;
      }
      default:
        // spatialRank was range-checked above.
        llvm_unreachable("Unhandled case");
      }
      // Create a Linalg reshape op that converts the output from 5 dimensions
      // into 4 dimensions (by collapsing the last two dimensions). This is
      // needed because linalg.depthwise_conv_2d_input_nhwc_filter_hwcf returns
      // 5 dimensions for the output.
      rewriter.replaceOpWithNewOp<tensor::CollapseShapeOp>(
          op, resultType, conv,
          getReassociationIndicesToCollapseLastTwoDims(conv));
    } else {
      // For cases where channel multiplier == 1
      Value emptyTensor = rewriter.create<tensor::EmptyOp>(
          loc, resultType.getShape(), resultType.getElementType());
      Value zeroTensor = fillTensorWithZeros(rewriter, loc, emptyTensor);
      // Create a Linalg reshape op that converts the filter from 4 dimensions
      // into 3 dimensions (by dropping the unit dimension). This is needed
      // because linalg.depthwise_conv_2d_input_nhwc_filter_hwc expects 3
      // dimensions for the filter.
      filterDims[filterDims.size() - 2] =
          static_cast<int64_t>(op.getFeatureGroupCount());
      filterDims.pop_back();
      RankedTensorType filterShape =
          RankedTensorType::get(filterDims, op.getType().getElementType());
      Value reshapedFilter = rewriter.create<tensor::CollapseShapeOp>(
          loc, filterShape, filter,
          getReassociationIndicesToCollapseLastTwoDims(filter));
      switch (spatialRank) {
      case 1:
        rewriter.replaceOpWithNewOp<linalg::DepthwiseConv1DNwcWcOp>(
            op, resultType, ValueRange{input, reshapedFilter},
            ValueRange{zeroTensor}, windowStrides, rhsDilation,
            linalg::getPrunedAttributeList(op));
        break;
      case 2:
        rewriter.replaceOpWithNewOp<linalg::DepthwiseConv2DNhwcHwcOp>(
            op, resultType, ValueRange{input, reshapedFilter},
            ValueRange{zeroTensor}, windowStrides, rhsDilation,
            linalg::getPrunedAttributeList(op));
        break;
      case 3:
        rewriter.replaceOpWithNewOp<linalg::DepthwiseConv3DNdhwcDhwcOp>(
            op, resultType, ValueRange{input, reshapedFilter},
            ValueRange{zeroTensor}, windowStrides, rhsDilation,
            linalg::getPrunedAttributeList(op));
        break;
      }
    }
    return success();
  }
};
} // namespace
namespace detail {
/// Registers all convolution lowering patterns on `patterns`.
void populateStableHloConvolutionToLinalgConversionPatterns(
    MLIRContext *context, TypeConverter &typeConverter,
    RewritePatternSet *patterns) {
  // Give the specialized (normal/depthwise) lowerings priority over the
  // fully general fallback pattern.
  const PatternBenefit specializedBenefit(2);
  patterns->add<NormalConvolutionOpConversion>(typeConverter, context,
                                               specializedBenefit);
  patterns->add<DepthwiseConvolutionOpConversion>(typeConverter, context,
                                                  specializedBenefit);
  patterns->add<ConvolutionOpGeneralConversion>(typeConverter, context);
}
} // namespace detail
} // namespace mlir::iree_compiler::stablehlo