blob: 422c887211b6137e3941d9907e4174913cc57f6e [file] [log] [blame]
// Copyright 2020 Google LLC
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// https://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include <numeric>
#include "iree/compiler/Dialect/Flow/Transforms/Passes.h"
#include "llvm/ADT/STLExtras.h"
#include "llvm/Support/Casting.h"
#include "mlir-hlo/Dialect/mhlo/IR/chlo_ops.h"
#include "mlir-hlo/Dialect/mhlo/IR/hlo_ops.h"
#include "mlir-hlo/Dialect/mhlo/transforms/rewriters.h"
#include "mlir/Dialect/Shape/IR/Shape.h"
#include "mlir/Dialect/StandardOps/IR/Ops.h"
#include "mlir/Dialect/Tensor/IR/Tensor.h"
#include "mlir/IR/Attributes.h"
#include "mlir/IR/BuiltinTypes.h"
#include "mlir/IR/TypeUtilities.h"
#include "mlir/Pass/Pass.h"
#include "mlir/Support/LogicalResult.h"
#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
namespace mlir {
namespace iree_compiler {
namespace IREE {
namespace Flow {
namespace {
// Command-line switches for optional rewrites in this file. All default to on.

// Toggles moving mhlo.conv `padding` into an explicit pad op.
static llvm::cl::opt<bool> extractPadFromConv(
    "iree-flow-extract-pad-from-conv", llvm::cl::init(true),
    llvm::cl::desc("Extract padding attributes from conv op"));

// Toggles canonicalizing conv input/output feature ordering.
static llvm::cl::opt<bool> orderConvFeatures(
    "iree-flow-order-conv-features", llvm::cl::init(true),
    llvm::cl::desc("Guarantees input/output features ordered for conv kernel"));

// Toggles rewriting 1x1-filter convolutions as dot products.
static llvm::cl::opt<bool> conv1x1toDot(
    "iree-flow-1x1-conv-to-dot", llvm::cl::init(true),
    llvm::cl::desc("Rewrites mhlo.conv with 1x1 filter into mhlo.dot"));
/// Returns true when `attr` is a splat whose single integer value is zero.
/// Non-splat attributes are conservatively reported as "not all zero".
static bool isAllZero(DenseIntElementsAttr attr) {
  return attr.isSplat() && attr.getSplatValue<IntegerAttr>().getInt() == 0;
}
/// Returns true if `array` is exactly [0, 1, ..., size - 1], i.e. the identity
/// permutation. An empty array counts as an iota.
static bool isIota(ArrayRef<int64_t> array) {
  for (size_t pos = 0, end = array.size(); pos < end; ++pos) {
    if (array[pos] != static_cast<int64_t>(pos)) return false;
  }
  return true;
}
/// Returns true if the conv op carries a `padding` attribute containing at
/// least one non-zero entry; an absent attribute means "no padding".
static bool hasPadding(mhlo::ConvOp op) {
  Optional<DenseIntElementsAttr> padding = op.padding();
  if (!padding.hasValue()) return false;
  for (APInt amount : padding.getValue()) {
    if (!amount.isNullValue()) return true;
  }
  return false;
}
/// Builds a rank-1 `tensor<N x i64>` DenseIntElementsAttr holding `integers`
/// in order, where N is the number of integers.
static DenseIntElementsAttr make1DElementsAttr(PatternRewriter &rewriter,
                                               ArrayRef<int64_t> integers) {
  int64_t length = static_cast<int64_t>(integers.size());
  Type i64Type = rewriter.getIntegerType(64);
  return DenseIntElementsAttr::get(RankedTensorType::get({length}, i64Type),
                                   integers);
}
/// Rewrites `mhlo.log_plus_one(x)` as `mhlo.log(x + 1)`.
///
/// The constant `1` is materialized with the operand's own floating-point
/// element type. A hard-coded f32 attribute would not match f16/bf16/f64
/// operands. Only statically shaped float tensors are handled, because
/// DenseElementsAttr needs a static shape. Other operands (dynamic or unranked
/// shapes, complex elements) are left untouched.
class DecomposeLog1PPattern : public OpRewritePattern<mhlo::Log1pOp> {
 public:
  using OpRewritePattern<mhlo::Log1pOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(mhlo::Log1pOp op,
                                PatternRewriter &rewriter) const override {
    auto type = op.operand().getType().dyn_cast<RankedTensorType>();
    if (!type || !type.hasStaticShape()) return failure();
    Type elementType = type.getElementType();
    if (!elementType.isa<FloatType>()) return failure();

    Location loc = op.getLoc();
    DenseElementsAttr attr = DenseElementsAttr::get(
        type, rewriter.getFloatAttr(elementType, 1.0));
    auto one = rewriter.create<ConstantOp>(loc, attr);
    auto x = rewriter.create<mhlo::AddOp>(loc, op.operand(), one);
    rewriter.replaceOpWithNewOp<mhlo::LogOp>(op, x);
    return success();
  }
};
/// Rewrites `mhlo.exponential_minus_one(x)` as `mhlo.exp(x) - 1`.
///
/// The constant `1` is materialized with the operand's own floating-point
/// element type. A hard-coded f32 attribute would not match f16/bf16/f64
/// operands. Only statically shaped float tensors are handled, because
/// DenseElementsAttr needs a static shape. Other operands (dynamic or unranked
/// shapes, complex elements) are left untouched.
class DecomposeExpM1Pattern : public OpRewritePattern<mhlo::Expm1Op> {
 public:
  using OpRewritePattern<mhlo::Expm1Op>::OpRewritePattern;

  LogicalResult matchAndRewrite(mhlo::Expm1Op op,
                                PatternRewriter &rewriter) const override {
    auto type = op.operand().getType().dyn_cast<RankedTensorType>();
    if (!type || !type.hasStaticShape()) return failure();
    Type elementType = type.getElementType();
    if (!elementType.isa<FloatType>()) return failure();

    Location loc = op.getLoc();
    DenseElementsAttr attr = DenseElementsAttr::get(
        type, rewriter.getFloatAttr(elementType, 1.0));
    auto one = rewriter.create<ConstantOp>(loc, attr);
    auto x = rewriter.create<mhlo::ExpOp>(loc, op.operand());
    rewriter.replaceOpWithNewOp<mhlo::SubOp>(op, x, one);
    return success();
  }
};
/// Moves a non-zero `padding` attribute of an mhlo.conv into an explicit
/// zero-valued mhlo.pad on the lhs, then rebuilds the conv without padding.
///
/// Every other conv attribute (including `window_reversal`) is carried over
/// unchanged. The pattern bails out when:
///  - the lhs has a non-unit dilation (see TODO below),
///  - the lhs is unranked (getRank()/isDynamicDim() would assert), or
///  - any lhs dimension is dynamic (the pad is built with a static shape).
class ExtractConvOpPaddingAttributes : public OpRewritePattern<mhlo::ConvOp> {
 public:
  using OpRewritePattern<mhlo::ConvOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(mhlo::ConvOp op,
                                PatternRewriter &rewriter) const override {
    if (!hasPadding(op)) return failure();
    auto inputType = op.lhs().getType().cast<ShapedType>();
    if (!inputType.hasRank()) return failure();
    int64_t rank = inputType.getRank();

    // TODO(suderman): Add proper support for padding + dilation for codegen.
    // We can't extract padding if the left hand side has dilation.
    if (op.lhs_dilation().hasValue()) {
      for (auto val : op.lhs_dilation().getValue().getIntValues()) {
        if (val != 1) {
          return failure();
        }
      }
    }

    // Padding is specified per spatial dim; scatter it onto the full lhs rank
    // so batch/feature dims get zero padding.
    SmallVector<int64_t, 4> paddingLow, paddingHigh, interiorPadding, shape;
    paddingLow.append(rank, 0);
    paddingHigh.append(rank, 0);
    interiorPadding.append(rank, 0);
    for (auto iter :
         llvm::enumerate(op.dimension_numbers().input_spatial_dimensions())) {
      unsigned idx = iter.index();
      unsigned dim = iter.value().getZExtValue();
      paddingLow[dim] = op.paddingAttr().getValue<int64_t>({idx, 0});
      paddingHigh[dim] = op.paddingAttr().getValue<int64_t>({idx, 1});
    }
    for (unsigned i = 0; i < rank; ++i) {
      // mhlo.pad doesn't support dynamic shape.
      if (inputType.isDynamicDim(i)) return failure();
      // Keep sizes in int64_t; narrowing to int could truncate large dims.
      int64_t size = inputType.getDimSize(i);
      shape.push_back(size + paddingLow[i] + paddingHigh[i]);
    }

    auto toDenseAttr = [&rewriter](ArrayRef<int64_t> elements) {
      return DenseIntElementsAttr::get(
          RankedTensorType::get(elements.size(), rewriter.getIntegerType(64)),
          elements);
    };

    auto loc = op.getLoc();
    auto padResultType =
        RankedTensorType::get(shape, inputType.getElementType());
    Attribute zeroAttr = rewriter.getZeroAttr(
        RankedTensorType::get({}, inputType.getElementType()));
    auto zero = rewriter.create<ConstantOp>(loc, zeroAttr);
    auto padOp = rewriter.create<mhlo::PadOp>(
        loc, padResultType, op.lhs(), zero, toDenseAttr(paddingLow),
        toDenseAttr(paddingHigh), toDenseAttr(interiorPadding));
    auto resultType = op.getResult().getType();
    // Window reversal acts on the kernel and is independent of input padding,
    // so it must be preserved rather than dropped.
    auto newOp = rewriter.create<mhlo::ConvOp>(
        op.getLoc(), resultType, padOp.getResult(), op.rhs(),
        op.window_stridesAttr(), /*padding=*/nullptr, op.lhs_dilationAttr(),
        op.rhs_dilationAttr(), op.window_reversalAttr(),
        op.dimension_numbersAttr(), op.feature_group_countAttr(),
        op.batch_group_countAttr(), op.precision_configAttr());
    rewriter.replaceOp(op, newOp.getResult());
    return success();
  }
};
// Guarantee that the input dimensions are ordered batch, spatial_dims, feature
// dim. A transpose is inserted on the lhs and the input dimension numbers are
// rewritten to match; kernel and output dimension numbers are untouched.
class ReorderConvOpInputDimensions : public OpRewritePattern<mhlo::ConvOp> {
 public:
  using OpRewritePattern<mhlo::ConvOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(mhlo::ConvOp op,
                                PatternRewriter &rewriter) const override {
    auto lhsType = op.lhs().getType().cast<ShapedType>();
    // Check rank first: getShape() asserts on unranked types.
    if (!lhsType.hasRank()) {
      return failure();
    }
    auto lhsShape = lhsType.getShape();

    auto dimensionNumbers = op.dimension_numbers();
    auto inputSpatialDimensions = dimensionNumbers.input_spatial_dimensions();
    llvm::SmallVector<int64_t, 4> spatialDims;
    for (auto dim : inputSpatialDimensions) {
      spatialDims.push_back(dim.getSExtValue());
    }

    // Compute the permutation required to create a standard order.
    llvm::SmallVector<int64_t, 4> permutations;
    permutations.push_back(
        dimensionNumbers.input_batch_dimension().getValue().getSExtValue());
    permutations.append(spatialDims.begin(), spatialDims.end());
    permutations.push_back(
        dimensionNumbers.input_feature_dimension().getValue().getSExtValue());

    // If the permutation is iota then no reordering is required.
    if (isIota(permutations)) {
      return failure();
    }

    llvm::SmallVector<int64_t, 4> transposeShape;
    for (auto p : permutations) {
      transposeShape.push_back(lhsShape[p]);
    }

    auto transposed = rewriter.create<mhlo::TransposeOp>(
        op.getLoc(),
        RankedTensorType::get(transposeShape, lhsType.getElementType()),
        op.lhs(), rewriter.getI64TensorAttr(permutations));

    // After the transpose, spatial dims occupy [1, numSpatial] and the feature
    // dim is last.
    llvm::SmallVector<int64_t, 4> newSpatialDimensions(spatialDims.size());
    std::iota(newSpatialDimensions.begin(), newSpatialDimensions.end(), 1);

    auto newDimensionNumbers = mhlo::ConvDimensionNumbers::get(
        /*input_batch_dimension=*/rewriter.getI64IntegerAttr(0),
        /*input_feature_dimension=*/
        rewriter.getI64IntegerAttr(newSpatialDimensions.size() + 1),
        /*input_spatial_dimensions=*/
        rewriter.getI64TensorAttr(newSpatialDimensions),
        dimensionNumbers.kernel_input_feature_dimension(),
        dimensionNumbers.kernel_output_feature_dimension(),
        dimensionNumbers.kernel_spatial_dimensions(),
        dimensionNumbers.output_batch_dimension(),
        dimensionNumbers.output_feature_dimension(),
        dimensionNumbers.output_spatial_dimensions(), op.getContext());

    SmallVector<Value, 2> operands = {transposed, op.rhs()};
    auto newConv = rewriter.create<mhlo::ConvOp>(op.getLoc(), op.getType(),
                                                 operands, op.getAttrs());
    newConv.dimension_numbersAttr(newDimensionNumbers);
    rewriter.replaceOp(op, newConv.getResult());
    return success();
  }
};
/// Guarantees the kernel (rhs) dimensions of an mhlo.conv are ordered
/// spatial_dims..., input_feature, output_feature. It inserts an mhlo.transpose
/// on the kernel and rewrites only the kernel dimension numbers.
struct ReorderConvOpKernelDimensions : public OpRewritePattern<mhlo::ConvOp> {
  using OpRewritePattern::OpRewritePattern;

  LogicalResult matchAndRewrite(mhlo::ConvOp op,
                                PatternRewriter &rewriter) const override {
    Value kernel = op.rhs();
    auto kernelType = kernel.getType().cast<ShapedType>();
    if (!kernelType.hasRank()) return failure();
    ArrayRef<int64_t> kernelShape = kernelType.getShape();
    auto dimNumbers = op.dimension_numbers();

    // Target order: every kernel spatial dim, then input feature, then output
    // feature.
    llvm::SmallVector<int64_t, 4> perm;
    for (APInt spatialDim : dimNumbers.kernel_spatial_dimensions()) {
      perm.push_back(spatialDim.getSExtValue());
    }
    const size_t numSpatialDims = perm.size();
    perm.push_back(dimNumbers.kernel_input_feature_dimension().getInt());
    perm.push_back(dimNumbers.kernel_output_feature_dimension().getInt());

    // Already canonical: nothing to rewrite.
    if (isIota(perm)) return failure();

    llvm::SmallVector<int64_t, 4> transposedShape;
    for (int64_t srcDim : perm) transposedShape.push_back(kernelShape[srcDim]);

    auto transposedKernel = rewriter.create<mhlo::TransposeOp>(
        op.getLoc(),
        RankedTensorType::get(transposedShape, kernelType.getElementType()),
        kernel, rewriter.getI64TensorAttr(perm));

    // Post-transpose, the spatial dims occupy [0, numSpatialDims).
    llvm::SmallVector<int64_t, 4> canonicalSpatialDims(numSpatialDims);
    std::iota(canonicalSpatialDims.begin(), canonicalSpatialDims.end(), 0);

    auto newDimNumbers = mhlo::ConvDimensionNumbers::get(
        dimNumbers.input_batch_dimension(),
        dimNumbers.input_feature_dimension(),
        dimNumbers.input_spatial_dimensions(),
        /*kernel_input_feature_dimension=*/
        rewriter.getI64IntegerAttr(numSpatialDims),
        /*kernel_output_feature_dimension=*/
        rewriter.getI64IntegerAttr(numSpatialDims + 1),
        /*kernel_spatial_dimensions=*/
        rewriter.getI64TensorAttr(canonicalSpatialDims),
        dimNumbers.output_batch_dimension(),
        dimNumbers.output_feature_dimension(),
        dimNumbers.output_spatial_dimensions(), op.getContext());

    SmallVector<Value, 2> newOperands = {op.lhs(), transposedKernel};
    mhlo::ConvOp newConv = rewriter.create<mhlo::ConvOp>(
        op.getLoc(), op.getType(), newOperands, op.getAttrs());
    newConv.dimension_numbersAttr(newDimNumbers);
    rewriter.replaceOp(op, {newConv.getResult()});
    return success();
  }
};
// Guarantee that the output dimensions are ordered batch, spatial_dims, feature
// dim. The conv is rebuilt to produce the canonical layout, and a transpose by
// the inverse permutation restores the originally requested result layout.
class ReorderConvOpOutputDimensions : public OpRewritePattern<mhlo::ConvOp> {
 public:
  using OpRewritePattern<mhlo::ConvOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(mhlo::ConvOp op,
                                PatternRewriter &rewriter) const override {
    auto resultType = op.getType().cast<ShapedType>();
    // Check rank first: getShape() asserts on unranked types.
    if (!resultType.hasRank()) {
      return failure();
    }
    auto resultShape = resultType.getShape();

    auto dimensionNumbers = op.dimension_numbers();
    auto outputSpatialDimensions = dimensionNumbers.output_spatial_dimensions();
    llvm::SmallVector<int64_t, 4> spatialDims;
    for (auto dim : outputSpatialDimensions) {
      spatialDims.push_back(dim.getSExtValue());
    }

    // Compute the permutation to transpose to an ordered output.
    llvm::SmallVector<int64_t, 4> permutation;
    permutation.push_back(
        dimensionNumbers.output_batch_dimension().getValue().getSExtValue());
    permutation.append(spatialDims.begin(), spatialDims.end());
    permutation.push_back(
        dimensionNumbers.output_feature_dimension().getValue().getSExtValue());

    // If the permutation is iota then no reordering is required.
    if (isIota(permutation)) {
      return failure();
    }

    // Compute what the new conv shape should be.
    llvm::SmallVector<int64_t, 4> convShape;
    for (auto p : permutation) {
      convShape.push_back(resultShape[p]);
    }

    // The inverse permutation maps the ordered conv result back to the
    // original (unordered) output layout.
    llvm::SmallVector<int64_t, 4> invertPermutation(permutation.size());
    for (auto it : llvm::enumerate(permutation)) {
      invertPermutation[it.value()] = it.index();
    }

    llvm::SmallVector<int64_t, 4> newSpatialDimensions(spatialDims.size());
    std::iota(newSpatialDimensions.begin(), newSpatialDimensions.end(), 1);

    auto newDimensionNumbers = mhlo::ConvDimensionNumbers::get(
        dimensionNumbers.input_batch_dimension(),
        dimensionNumbers.input_feature_dimension(),
        dimensionNumbers.input_spatial_dimensions(),
        dimensionNumbers.kernel_input_feature_dimension(),
        dimensionNumbers.kernel_output_feature_dimension(),
        dimensionNumbers.kernel_spatial_dimensions(),
        /*output_batch_dimension=*/rewriter.getI64IntegerAttr(0),
        /*output_feature_dimension=*/
        rewriter.getI64IntegerAttr(newSpatialDimensions.size() + 1),
        /*output_spatial_dimensions=*/
        rewriter.getI64TensorAttr(newSpatialDimensions), op.getContext());

    SmallVector<Value, 2> operands = {op.lhs(), op.rhs()};
    auto newConv = rewriter.create<mhlo::ConvOp>(
        op.getLoc(),
        RankedTensorType::get(convShape, resultType.getElementType()), operands,
        op.getAttrs());
    newConv.dimension_numbersAttr(newDimensionNumbers);

    auto transposed = rewriter.create<mhlo::TransposeOp>(
        op.getLoc(), resultType, newConv,
        rewriter.getI64TensorAttr(invertPermutation));
    rewriter.replaceOp(op, transposed.getResult());
    return success();
  }
};
/// Moves a non-zero `padding` attribute of an mhlo.reduce_window into an
/// explicit mhlo.pad of the operand. The pad value is the reduction's
/// `init_value`. The reduce_window is then rebuilt without padding, reusing
/// the original body region.
///
/// Bails out on base/window dilations, on unranked operands (getRank() would
/// assert), and on dynamic dims (the pad is built with a static shape).
class ExtractReduceWindowOpPaddingAttributes
    : public OpRewritePattern<mhlo::ReduceWindowOp> {
 public:
  using OpRewritePattern<mhlo::ReduceWindowOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(mhlo::ReduceWindowOp op,
                                PatternRewriter &rewriter) const override {
    if (!op.padding()) return failure();
    if (op.base_dilations() || op.window_dilations()) return failure();
    if (isAllZero(op.paddingAttr())) return failure();

    auto inputType = op.operand().getType().cast<ShapedType>();
    if (!inputType.hasRank()) return failure();
    int64_t rank = inputType.getRank();

    SmallVector<int64_t, 4> paddingLow, paddingHigh, interiorPadding, shape;
    for (unsigned i = 0; i < rank; ++i) {
      // mhlo.pad doesn't support dynamic shape.
      if (inputType.isDynamicDim(i)) return failure();
      interiorPadding.push_back(0);
      paddingLow.push_back(op.paddingAttr().getValue<int64_t>({i, 0}));
      paddingHigh.push_back(op.paddingAttr().getValue<int64_t>({i, 1}));
      // Keep sizes in int64_t; narrowing to int could truncate large dims.
      int64_t size = inputType.getDimSize(i);
      shape.push_back(size + paddingLow.back() + paddingHigh.back());
    }

    auto toDenseAttr = [&rewriter](ArrayRef<int64_t> elements) {
      return DenseIntElementsAttr::get(
          RankedTensorType::get(elements.size(), rewriter.getIntegerType(64)),
          elements);
    };

    auto loc = op.getLoc();
    auto padResultType =
        RankedTensorType::get(shape, inputType.getElementType());
    auto padOp = rewriter.create<mhlo::PadOp>(
        loc, padResultType, op.operand(), op.init_value(),
        toDenseAttr(paddingLow), toDenseAttr(paddingHigh),
        toDenseAttr(interiorPadding));
    auto newOp = rewriter.create<mhlo::ReduceWindowOp>(
        loc, op.getResult().getType(), padOp, op.init_value(),
        op.window_dimensions(), op.window_stridesAttr(),
        op.base_dilationsAttr(), op.window_dilationsAttr(),
        /*padding=*/nullptr);
    rewriter.inlineRegionBefore(op.body(), newOp.body(), newOp.body().begin());
    rewriter.replaceOp(op, newOp.getResult());
    return success();
  }
};
// Rewrites an n-d (n, d1, d2, d3, ..., ci) * (1, 1, 1, ..., ci, co)
// as (n * d1 * d2 * d3 * ..., ci) . (ci, co), then reshapes back.
//
// This is only equivalent to the conv when the conv meets all of these:
//  - ungrouped (feature_group_count == batch_group_count == 1);
//  - unpadded;
//  - unit-stride with unit lhs/rhs dilation;
//  - input and output both laid out (n, spatial..., feature), with spatial
//    dims listed in the same order;
//  - all shapes static (mhlo.reshape needs static shapes).
// Anything else is left untouched.
class Lower1x1ConvolutionToDotOp : public OpRewritePattern<mhlo::ConvOp> {
 public:
  using OpRewritePattern<mhlo::ConvOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(mhlo::ConvOp op,
                                PatternRewriter &rewriter) const override {
    // Only 1x1 convolution with no feature/batch grouping will match.
    if (op.feature_group_count() != 1 || op.batch_group_count() != 1) {
      return failure();
    }
    // Padding changes the output spatial extent, which a reshape of the dot
    // result cannot reproduce.
    if (hasPadding(op)) return failure();

    Value input = op.lhs();
    Value filter = op.rhs();
    Value output = op.getResult();
    auto inputShapeType = input.getType().dyn_cast_or_null<RankedTensorType>();
    auto filterShapeType =
        filter.getType().dyn_cast_or_null<RankedTensorType>();
    auto outputShapeType =
        output.getType().dyn_cast_or_null<RankedTensorType>();
    if (!inputShapeType || !filterShapeType || !outputShapeType) {
      return failure();
    }
    // A dynamic size would poison the row-count product below, and
    // mhlo.reshape requires static shapes.
    if (!inputShapeType.hasStaticShape() || !filterShapeType.hasStaticShape() ||
        !outputShapeType.hasStaticShape()) {
      return failure();
    }

    auto dimNumbers = op.dimension_numbers();
    auto inputShape = inputShapeType.getShape();
    auto filterShape = filterShapeType.getShape();
    const int64_t inputRank = inputShapeType.getRank();
    const int64_t filterRank = filterShapeType.getRank();
    const int64_t outputRank = outputShapeType.getRank();

    auto inputBatchDim = dimNumbers.input_batch_dimension().getInt();
    auto inputFeatureDim = dimNumbers.input_feature_dimension().getInt();
    auto kernelInputFeatureDim =
        dimNumbers.kernel_input_feature_dimension().getInt();
    auto kernelOutputFeatureDim =
        dimNumbers.kernel_output_feature_dimension().getInt();
    auto outputBatchDim = dimNumbers.output_batch_dimension().getInt();
    auto outputFeatureDim = dimNumbers.output_feature_dimension().getInt();

    // Match input (n, d1, d2, ..., ci) format
    if (inputFeatureDim != inputRank - 1 || inputBatchDim != 0) {
      return failure();
    }
    // Match filter (k1, k2, ..., ci, co) format
    if (kernelInputFeatureDim != filterRank - 2 ||
        kernelOutputFeatureDim != filterRank - 1) {
      return failure();
    }
    // Match output (n, d1, d2, ..., co) format. Spatial dims must be listed in
    // the same order as the input's so that the row-major reshape of the dot
    // result lines up with the conv output.
    if (outputFeatureDim != outputRank - 1 || outputBatchDim != 0) {
      return failure();
    }
    if (dimNumbers.output_spatial_dimensions() !=
        dimNumbers.input_spatial_dimensions()) {
      return failure();
    }

    // Check 1x1x... kernel spatial size.
    for (auto dim : dimNumbers.kernel_spatial_dimensions()) {
      if (filterShape[dim.getZExtValue()] != 1) return failure();
    }

    // Check strides and lhs/rhs dilations are all ones.
    if (op.window_strides()) {
      for (auto stride : op.window_strides()->getValues<int64_t>()) {
        if (stride != 1) return failure();
      }
    }
    if (op.lhs_dilation()) {
      for (auto dilation : op.lhs_dilation()->getValues<int64_t>()) {
        if (dilation != 1) return failure();
      }
    }
    if (op.rhs_dilation()) {
      for (auto dilation : op.rhs_dilation()->getValues<int64_t>()) {
        if (dilation != 1) return failure();
      }
    }

    // Batch and every spatial dim collapse into the dot's row dimension.
    int64_t rowCount = inputShape[0];
    for (auto dim : dimNumbers.input_spatial_dimensions()) {
      rowCount *= inputShape[dim.getZExtValue()];
    }

    Type reshapedInputType =
        RankedTensorType::get({rowCount, inputShape[inputFeatureDim]},
                              inputShapeType.getElementType());
    Type reshapedFilterType =
        RankedTensorType::get({filterShape[kernelInputFeatureDim],
                               filterShape[kernelOutputFeatureDim]},
                              filterShapeType.getElementType());
    Type dotResultType = RankedTensorType::get(
        {rowCount, filterShape[kernelOutputFeatureDim]},
        outputShapeType.getElementType());

    Value reshapedInput =
        rewriter.create<mhlo::ReshapeOp>(op.getLoc(), reshapedInputType, input);
    Value reshapedFilter = rewriter.create<mhlo::ReshapeOp>(
        op.getLoc(), reshapedFilterType, filter);
    Value dotResult = rewriter.create<mhlo::DotOp>(
        op.getLoc(), dotResultType, reshapedInput, reshapedFilter,
        rewriter.getStrArrayAttr({"HIGHEST", "HIGHEST"}));
    Value reshapedResult = rewriter.create<mhlo::ReshapeOp>(
        op.getLoc(), outputShapeType, dotResult);
    rewriter.replaceOp(op, reshapedResult);
    return success();
  }
};
// Reshapes the filter of a grouped ("depthwise") mhlo.conv whose kernel input
// feature dim has size 1 and whose kernel output feature dim is divisible by
// feature_group_count G. The kernel [..., 1, ..., G*M, ...] is reshaped to
// [..., G, ..., M, ...]; every conv attribute is kept unchanged.
//
// Only ranked, statically shaped kernels are handled. getShape() asserts on
// unranked types, mhlo.reshape needs a static result shape, and a dynamic
// (negative) size would make the divisibility arithmetic meaningless.
class AdjustDepthwiseFilterShape : public OpRewritePattern<mhlo::ConvOp> {
 public:
  using OpRewritePattern<mhlo::ConvOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(mhlo::ConvOp op,
                                PatternRewriter &rewriter) const override {
    auto kernelType = op.rhs().getType().dyn_cast<RankedTensorType>();
    if (!kernelType || !kernelType.hasStaticShape()) return failure();

    const auto featureInDim =
        op.dimension_numbers().kernel_input_feature_dimension().getInt();
    const auto featureOutDim =
        op.dimension_numbers().kernel_output_feature_dimension().getInt();
    ArrayRef<int64_t> kernelShape = kernelType.getShape();
    if (kernelShape[featureInDim] != 1) return failure();

    // Use a signed count so the arithmetic below stays in int64_t rather than
    // silently converting shape sizes to unsigned.
    const int64_t groupCount =
        static_cast<int64_t>(op.feature_group_count());
    if (groupCount == 1) return failure();
    if (kernelShape[featureOutDim] % groupCount != 0) return failure();

    SmallVector<int64_t, 4> newShape(kernelShape.begin(), kernelShape.end());
    newShape[featureInDim] = groupCount;
    newShape[featureOutDim] /= groupCount;

    auto loc = op.getLoc();
    auto elemType = kernelType.getElementType();
    auto reshapeOp = rewriter.create<mhlo::ReshapeOp>(
        loc, RankedTensorType::get(newShape, elemType), op.rhs());
    auto resultType = op.getResult().getType();
    SmallVector<Value, 2> operands = {op.lhs(), reshapeOp.getResult()};
    auto newOp = rewriter.create<mhlo::ConvOp>(op.getLoc(), resultType,
                                               operands, op.getAttrs());
    rewriter.replaceOp(op, newOp.getResult());
    return success();
  }
};
// Rewrites rank-3 mhlo.dot_general so lhs contraction dimension is
// inner most (2) and rhs contraction dimension is dim right after batch
// dimension. The pattern inserts transposes so the dot_general always has the
// form: {batch_dim, parallel, contraction}.{batch_dim, contraction, parallel}
//
// Requires rank-3 lhs, rhs and result, and exactly one batching dim (at
// position 0) plus one contracting dim per operand.
class TransposeRank3GenericDotGeneral
    : public OpRewritePattern<mhlo::DotGeneralOp> {
 public:
  using OpRewritePattern<mhlo::DotGeneralOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(mhlo::DotGeneralOp op,
                                PatternRewriter &rewriter) const override {
    auto lhsShapeType = op.lhs().getType().dyn_cast<RankedTensorType>();
    auto rhsShapeType = op.rhs().getType().dyn_cast<RankedTensorType>();
    auto resultType = op.getResult().getType().dyn_cast<RankedTensorType>();
    if (!lhsShapeType || !rhsShapeType || !resultType) return failure();
    // The transposes below index dims 0..2 of each operand, so the operands
    // (not just the result) must be rank 3.
    if (lhsShapeType.getRank() != 3 || rhsShapeType.getRank() != 3 ||
        resultType.getRank() != 3)
      return failure();

    mhlo::DotDimensionNumbers dimNumbers = op.dot_dimension_numbers();
    // Exactly one batching and one contracting dim per side; the dereferences
    // below are invalid on empty lists.
    if (dimNumbers.lhs_batching_dimensions().size() != 1 ||
        dimNumbers.rhs_batching_dimensions().size() != 1 ||
        dimNumbers.lhs_contracting_dimensions().size() != 1 ||
        dimNumbers.rhs_contracting_dimensions().size() != 1)
      return failure();

    auto firstDim = [](DenseIntElementsAttr dims) -> int64_t {
      return (*dims.int_value_begin()).getSExtValue();
    };
    int64_t lhsBatchDim = firstDim(dimNumbers.lhs_batching_dimensions());
    int64_t rhsBatchDim = firstDim(dimNumbers.rhs_batching_dimensions());
    int64_t lhsContractionDim =
        firstDim(dimNumbers.lhs_contracting_dimensions());
    int64_t rhsContractionDim =
        firstDim(dimNumbers.rhs_contracting_dimensions());

    // Target form:
    //   lhs : {batch_dim, parallel, contraction}
    //   rhs : {batch_dim, contraction, parallel}
    // Only operands whose batch dim is already leading are handled.
    if (lhsBatchDim != 0 || rhsBatchDim != 0) return failure();
    // No transposes are needed.
    if (lhsContractionDim == 2 && rhsContractionDim == 1) return failure();

    Value lhs = op.lhs(), rhs = op.rhs();
    // lhs is {batch_dim, contraction, parallel}: swap the last two dims.
    // Transposes keep the operand's element type, which may differ from the
    // result's (e.g. mixed-precision dots).
    if (lhsContractionDim == 1) {
      Type transposedType = RankedTensorType::get(
          {lhsShapeType.getDimSize(0), lhsShapeType.getDimSize(2),
           lhsShapeType.getDimSize(1)},
          lhsShapeType.getElementType());
      lhs = rewriter.create<mhlo::TransposeOp>(
          op.getLoc(), transposedType, lhs,
          make1DElementsAttr(rewriter, {0, 2, 1}));
    }
    // rhs is {batch_dim, parallel, contraction}: swap the last two dims.
    if (rhsContractionDim == 2) {
      Type transposedType = RankedTensorType::get(
          {rhsShapeType.getDimSize(0), rhsShapeType.getDimSize(2),
           rhsShapeType.getDimSize(1)},
          rhsShapeType.getElementType());
      rhs = rewriter.create<mhlo::TransposeOp>(
          op.getLoc(), transposedType, rhs,
          make1DElementsAttr(rewriter, {0, 2, 1}));
    }

    auto dimensionNumbers = mhlo::DotDimensionNumbers::get(
        /*lhs_batching_dimensions=*/make1DElementsAttr(rewriter, {0}),
        /*rhs_batching_dimensions=*/make1DElementsAttr(rewriter, {0}),
        /*lhs_contracting_dimensions=*/make1DElementsAttr(rewriter, {2}),
        /*rhs_contracting_dimensions=*/
        make1DElementsAttr(rewriter, {1}), rewriter.getContext());
    Value result = rewriter.create<mhlo::DotGeneralOp>(
        op.getLoc(), op.getType(), lhs, rhs, dimensionNumbers,
        op.precision_configAttr());
    rewriter.replaceOp(op, result);
    return success();
  }
};
// Rewrite mhlo.dot_general to operate on rank-3 tensors when reduction dims are
// in consecutive order and not splitting the domain. This pattern inserts
// reshapes to collapse consecutive reduction and parallel dims to always
// generate a rank-3 dot_general op.
//
// Each operand must be laid out as
//   {batch..., contracting..., parallel...} or
//   {batch..., parallel..., contracting...}
// with the batching dims leading (starting at dim 0) and each group contiguous.
// The batching and contracting dim lists must already be in ascending order.
// This keeps the i-th lhs dim paired with the i-th rhs dim after collapsing,
// and keeps the result's batch order equal to the collapsed order.
class RankReducedDotGeneral : public OpRewritePattern<mhlo::DotGeneralOp> {
 public:
  using OpRewritePattern<mhlo::DotGeneralOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(mhlo::DotGeneralOp op,
                                PatternRewriter &rewriter) const override {
    auto lhsShapeType = op.lhs().getType().dyn_cast<ShapedType>();
    auto rhsShapeType = op.rhs().getType().dyn_cast<ShapedType>();
    auto resultType = op.getResult().getType().dyn_cast<ShapedType>();
    if (!lhsShapeType || !rhsShapeType || !resultType) return failure();
    if (!lhsShapeType.hasStaticShape() || !rhsShapeType.hasStaticShape())
      return failure();
    // getRank() asserts on unranked types.
    if (!resultType.hasRank() || resultType.getRank() <= 3) return failure();

    mhlo::DotDimensionNumbers dimNumbers = op.dot_dimension_numbers();
    auto toInt64Vector = [](DenseIntElementsAttr attr) {
      return llvm::to_vector<4>(
          llvm::map_range(attr, [](APInt v) { return v.getSExtValue(); }));
    };
    // Deliberately NOT sorted: sorting one side independently of the other
    // would break the lhs/rhs dim pairing and reorder result batch dims.
    auto lhsBatchingDims = toInt64Vector(dimNumbers.lhs_batching_dimensions());
    auto rhsBatchingDims = toInt64Vector(dimNumbers.rhs_batching_dimensions());
    auto lhsContractingDims =
        toInt64Vector(dimNumbers.lhs_contracting_dimensions());
    auto rhsContractingDims =
        toInt64Vector(dimNumbers.rhs_contracting_dimensions());

    if (lhsBatchingDims.empty() || rhsBatchingDims.empty() ||
        lhsContractingDims.empty() || rhsContractingDims.empty())
      return failure();
    // The collapsed layout puts batching first, so no dim may precede it.
    if (lhsBatchingDims.front() != 0 || rhsBatchingDims.front() != 0)
      return failure();

    // True when the dims strictly increase by one, i.e. they are ascending
    // and contiguous.
    auto isConsecutive = [](ArrayRef<int64_t> array) {
      for (size_t i = 1; i < array.size(); ++i) {
        if (array[i] - array[i - 1] != 1) return false;
      }
      return true;
    };

    auto isDomainSplit = [](ArrayRef<int64_t> shape,
                            ArrayRef<int64_t> batchingDims,
                            ArrayRef<int64_t> contractingDims) {
      // Batching and contracting are contiguous.
      if ((contractingDims.front() - batchingDims.back()) == 1) return false;
      // Contracting dims are inner most.
      if (contractingDims.back() == static_cast<int64_t>(shape.size()) - 1)
        return false;
      return true;
    };

    if (!isConsecutive(lhsBatchingDims) || !isConsecutive(lhsContractingDims) ||
        !isConsecutive(rhsBatchingDims) || !isConsecutive(rhsContractingDims))
      return failure();

    if (isDomainSplit(lhsShapeType.getShape(), lhsBatchingDims,
                      lhsContractingDims) ||
        isDomainSplit(rhsShapeType.getShape(), rhsBatchingDims,
                      rhsContractingDims))
      return failure();

    // Collapses `shape` into a rank-3 tensor. Returns the collapsed shape
    // together with the indices of the contracting and parallel dims in it.
    auto computeCollapsedShape = [](ArrayRef<int64_t> shape,
                                    ArrayRef<int64_t> batchingDims,
                                    ArrayRef<int64_t> contractingDims) {
      // Seed with int64_t so the running product is not truncated to int.
      int64_t batchingSize = std::accumulate(
          batchingDims.begin(), batchingDims.end(), int64_t{1},
          [shape](int64_t accum, int64_t index) -> int64_t {
            return accum * shape[index];
          });
      int64_t contractingSize = std::accumulate(
          contractingDims.begin(), contractingDims.end(), int64_t{1},
          [shape](int64_t accum, int64_t index) -> int64_t {
            return accum * shape[index];
          });
      int parallelDimIndex, contractingDimIndex;
      int64_t parallelDimSize = 1;
      if (contractingDims.front() - batchingDims.back() > 1) {
        // {batch..., parallel..., contracting...}
        parallelDimIndex = 1;
        contractingDimIndex = 2;
        for (int64_t i = batchingDims.back() + 1; i < contractingDims.front();
             ++i) {
          parallelDimSize *= shape[i];
        }
      } else {
        // {batch..., contracting..., parallel...}
        contractingDimIndex = 1;
        parallelDimIndex = 2;
        for (int64_t i = contractingDims.back() + 1,
                     e = static_cast<int64_t>(shape.size());
             i < e; ++i) {
          parallelDimSize *= shape[i];
        }
      }
      // Always rank 3, no matter how many parallel dims were folded together.
      llvm::SmallVector<int64_t, 4> newShape(3);
      newShape[0] = batchingSize;
      newShape[contractingDimIndex] = contractingSize;
      newShape[parallelDimIndex] = parallelDimSize;
      return std::make_tuple(newShape, contractingDimIndex, parallelDimIndex);
    };

    int lhsContractingDimIndex, rhsContractingDimIndex, lhsParallelDimIndex,
        rhsParallelDimIndex;
    SmallVector<int64_t, 4> lhsNewShape, rhsNewShape;
    std::tie(lhsNewShape, lhsContractingDimIndex, lhsParallelDimIndex) =
        computeCollapsedShape(lhsShapeType.getShape(), lhsBatchingDims,
                              lhsContractingDims);
    std::tie(rhsNewShape, rhsContractingDimIndex, rhsParallelDimIndex) =
        computeCollapsedShape(rhsShapeType.getShape(), rhsBatchingDims,
                              rhsContractingDims);

    SmallVector<int64_t, 4> resultNewShape = {lhsNewShape[0],
                                              lhsNewShape[lhsParallelDimIndex],
                                              rhsNewShape[rhsParallelDimIndex]};
    Type dotGeneralResultType =
        RankedTensorType::get(resultNewShape, resultType.getElementType());

    auto loc = op.getLoc();
    Value reshapedLhs = rewriter.create<mhlo::ReshapeOp>(
        loc, RankedTensorType::get(lhsNewShape, lhsShapeType.getElementType()),
        op.lhs());
    Value reshapedRhs = rewriter.create<mhlo::ReshapeOp>(
        loc, RankedTensorType::get(rhsNewShape, rhsShapeType.getElementType()),
        op.rhs());
    auto dimensionNumbers = mhlo::DotDimensionNumbers::get(
        /*lhs_batching_dimensions=*/make1DElementsAttr(rewriter, {0}),
        /*rhs_batching_dimensions=*/make1DElementsAttr(rewriter, {0}),
        /*lhs_contracting_dimensions=*/
        make1DElementsAttr(rewriter, {lhsContractingDimIndex}),
        /*rhs_contracting_dimensions=*/
        make1DElementsAttr(rewriter, {rhsContractingDimIndex}),
        rewriter.getContext());
    Value dotGeneralResult = rewriter.create<mhlo::DotGeneralOp>(
        loc, dotGeneralResultType, reshapedLhs, reshapedRhs, dimensionNumbers,
        op.precision_configAttr());

    Value result =
        rewriter.create<mhlo::ReshapeOp>(loc, resultType, dotGeneralResult);
    rewriter.replaceOp(op, result);
    return success();
  }
};
// clang-format off
//
// Reorder BroadcastInDimOp and N-ary elementwise op.
//
// Rewrites the following pattern (take binary elementwise op as example)
//
// %bcastx = "mhlo.broadcast_in_dim"(%x) {broadcast_dimensions = %[[BCAST_DIMS]]} : (%[[SHAPE_BEFORE_BCAST]]) -> %[[SHAPE_AFTER_BCAST]]
// %bcasty = "mhlo.broadcast_in_dim"(%y) {broadcast_dimensions = %[[BCAST_DIMS]]} : (%[[SHAPE_BEFORE_BCAST]]) -> %[[SHAPE_AFTER_BCAST]]
// %result = "BinaryElementwiseOpT"(%bcastx, %bcasty) : (%[[SHAPE_AFTER_BCAST]], %[[SHAPE_AFTER_BCAST]]) -> %[[SHAPE_AFTER_BCAST]]
//
// into
//
// %z = "BinaryElementwiseOpT"(%x, %y) : (%[[SHAPE_BEFORE_BCAST]], %[[SHAPE_BEFORE_BCAST]]) -> %[[SHAPE_BEFORE_BCAST]]
// %result = "mhlo.broadcast_in_dim"(%z) {broadcast_dimensions = %[[BCAST_DIMS]]} : (%[[SHAPE_BEFORE_BCAST]]) -> %[[SHAPE_AFTER_BCAST]]
//
// clang-format on
//
// Broadcasts left without uses by the rewrite are erased. One broadcast may
// feed several operands (e.g. add(%b, %b)), so each distinct op is erased at
// most once.
template <typename ElementwiseOpT>
class ReorderBroadcastInDimOpAndElementwiseOp
    : public OpRewritePattern<ElementwiseOpT> {
 public:
  using OpRewritePattern<ElementwiseOpT>::OpRewritePattern;

  LogicalResult matchAndRewrite(ElementwiseOpT op,
                                PatternRewriter &rewriter) const override {
    Operation *operation = op.getOperation();
    assert(operation->getNumOperands() >= 1 && operation->getNumResults() == 1);

    // Verify if all operands are from BroadcastInDimOp and its
    // broadcast_dimensions is the same.
    llvm::SmallVector<mhlo::BroadcastInDimOp, 2> bcastOps;
    for (auto operand : operation->getOperands()) {
      if (auto bcastOp = operand.getDefiningOp<mhlo::BroadcastInDimOp>()) {
        bcastOps.push_back(bcastOp);
      } else {
        return failure();
      }
    }

    if (llvm::any_of(bcastOps, [&bcastOps](mhlo::BroadcastInDimOp bcastOp) {
          return bcastOp.broadcast_dimensions() !=
                 bcastOps[0].broadcast_dimensions();
        })) {
      return failure();
    }

    // Verify if all operands of BroadcastInDimOp are of same type and have
    // static shape.
    auto bcastOperandType =
        bcastOps[0].operand().getType().template dyn_cast<ShapedType>();
    llvm::SmallVector<Value, 2> bcastOperands;
    for (auto bcastOp : bcastOps) {
      auto bcastOperand = bcastOp.operand();
      auto type = bcastOperand.getType().template dyn_cast<ShapedType>();
      if (!type || !type.hasStaticShape() || type != bcastOperandType) {
        return failure();
      }
      bcastOperands.push_back(bcastOperand);
    }

    // Some elementwise ops, mhlo::RealOp for example, do not have
    // SameOperandsAndResultType trait, so resultType might be different
    // from bcastOperandType.
    auto elementType = getElementTypeOrSelf(op.getResult());
    auto resultShape = bcastOperandType.getShape();
    auto resultType = RankedTensorType::get(resultShape, elementType);

    // Collect the distinct broadcast ops before anything is erased, so no
    // handle to an already-erased op is ever dereferenced.
    llvm::SmallVector<Operation *, 2> uniqueBcastOps;
    for (auto bcastOp : bcastOps) {
      Operation *bcast = bcastOp.getOperation();
      if (!llvm::is_contained(uniqueBcastOps, bcast)) {
        uniqueBcastOps.push_back(bcast);
      }
    }

    Value result =
        rewriter.create<ElementwiseOpT>(op.getLoc(), resultType, bcastOperands);
    rewriter.replaceOpWithNewOp<mhlo::BroadcastInDimOp>(
        op, op.getType(), result, bcastOps[0].broadcast_dimensions());

    for (Operation *bcast : uniqueBcastOps) {
      if (bcast->use_empty()) {
        rewriter.eraseOp(bcast);
      }
    }
    return success();
  }
};
struct HLOToHLOPreprocessing
: public PassWrapper<HLOToHLOPreprocessing, FunctionPass> {
void getDependentDialects(DialectRegistry &registry) const override {
registry.insert<shape::ShapeDialect, mhlo::MhloDialect,
tensor::TensorDialect>();
}
void runOnFunction() override {
MLIRContext *context = &getContext();
ConversionTarget conversionTarget(*context);
OwningRewritePatternList conversionPatterns;
// Note that various input modalities may do their own legalization of
// CHLO. Converting here allows IREE to accept CHLO dialect regardless of
// whether it was legalized away at a higher level.
chlo::PopulateLegalizeChloToHloPatterns(context, &conversionPatterns);
conversionTarget.addLegalDialect<shape::ShapeDialect, mhlo::MhloDialect,
mlir::StandardOpsDialect,
mlir::tensor::TensorDialect>();
conversionTarget.addIllegalDialect<chlo::HloClientDialect>();
if (failed(applyPartialConversion(getFunction(), conversionTarget,
std::move(conversionPatterns)))) {
return signalPassFailure();
}
OwningRewritePatternList patterns;
mhlo::PopulateUnfuseBatchNormPatterns(context, &patterns);
mhlo::PopulateComplexLoweringPatterns(context, &patterns);
mhlo::PopulateGatherToTorchIndexSelectPatterns(context, &patterns);
patterns.insert<ExtractReduceWindowOpPaddingAttributes,
AdjustDepthwiseFilterShape, DecomposeLog1PPattern,
DecomposeExpM1Pattern>(context);
// dot_general canoncalization patterns.
mhlo::PopulateGeneralDotOpLoweringPatterns(&patterns, context);
patterns.insert<RankReducedDotGeneral, TransposeRank3GenericDotGeneral>(
context);
// Unary elementwise op.
patterns.insert<
ReorderBroadcastInDimOpAndElementwiseOp<mhlo::AbsOp>,
ReorderBroadcastInDimOpAndElementwiseOp<mhlo::CeilOp>,
ReorderBroadcastInDimOpAndElementwiseOp<mhlo::ConvertOp>,
ReorderBroadcastInDimOpAndElementwiseOp<mhlo::ClzOp>,
ReorderBroadcastInDimOpAndElementwiseOp<mhlo::CosOp>,
ReorderBroadcastInDimOpAndElementwiseOp<mhlo::ExpOp>,
ReorderBroadcastInDimOpAndElementwiseOp<mhlo::Expm1Op>,
ReorderBroadcastInDimOpAndElementwiseOp<mhlo::FloorOp>,
ReorderBroadcastInDimOpAndElementwiseOp<mhlo::ImagOp>,
ReorderBroadcastInDimOpAndElementwiseOp<mhlo::IsFiniteOp>,
ReorderBroadcastInDimOpAndElementwiseOp<mhlo::LogOp>,
ReorderBroadcastInDimOpAndElementwiseOp<mhlo::Log1pOp>,
ReorderBroadcastInDimOpAndElementwiseOp<mhlo::LogisticOp>,
ReorderBroadcastInDimOpAndElementwiseOp<mhlo::NotOp>,
ReorderBroadcastInDimOpAndElementwiseOp<mhlo::NegOp>,
ReorderBroadcastInDimOpAndElementwiseOp<mhlo::PopulationCountOp>,
ReorderBroadcastInDimOpAndElementwiseOp<mhlo::RealOp>,
ReorderBroadcastInDimOpAndElementwiseOp<mhlo::RoundOp>,
ReorderBroadcastInDimOpAndElementwiseOp<mhlo::RsqrtOp>,
ReorderBroadcastInDimOpAndElementwiseOp<mhlo::SignOp>,
ReorderBroadcastInDimOpAndElementwiseOp<mhlo::SinOp>,
ReorderBroadcastInDimOpAndElementwiseOp<mhlo::SqrtOp>,
ReorderBroadcastInDimOpAndElementwiseOp<mhlo::TanhOp>>(context);
// Binary elementwise op.
patterns.insert<
ReorderBroadcastInDimOpAndElementwiseOp<mhlo::AddOp>,
ReorderBroadcastInDimOpAndElementwiseOp<mhlo::Atan2Op>,
ReorderBroadcastInDimOpAndElementwiseOp<mhlo::ComplexOp>,
ReorderBroadcastInDimOpAndElementwiseOp<mhlo::DivOp>,
ReorderBroadcastInDimOpAndElementwiseOp<mhlo::MaxOp>,
ReorderBroadcastInDimOpAndElementwiseOp<mhlo::MinOp>,
ReorderBroadcastInDimOpAndElementwiseOp<mhlo::MulOp>,
ReorderBroadcastInDimOpAndElementwiseOp<mhlo::PowOp>,
ReorderBroadcastInDimOpAndElementwiseOp<mhlo::RemOp>,
ReorderBroadcastInDimOpAndElementwiseOp<mhlo::ShiftLeftOp>,
ReorderBroadcastInDimOpAndElementwiseOp<mhlo::ShiftRightArithmeticOp>,
ReorderBroadcastInDimOpAndElementwiseOp<mhlo::ShiftRightLogicalOp>,
ReorderBroadcastInDimOpAndElementwiseOp<mhlo::SubOp>,
ReorderBroadcastInDimOpAndElementwiseOp<mhlo::AndOp>,
ReorderBroadcastInDimOpAndElementwiseOp<mhlo::OrOp>,
ReorderBroadcastInDimOpAndElementwiseOp<mhlo::XorOp>>(context);
if (extractPadFromConv) {
patterns.insert<ExtractConvOpPaddingAttributes>(context);
}
if (orderConvFeatures) {
patterns.insert<ReorderConvOpInputDimensions>(context);
patterns.insert<ReorderConvOpKernelDimensions>(context);
patterns.insert<ReorderConvOpOutputDimensions>(context);
}
if (conv1x1toDot) {
patterns.insert<Lower1x1ConvolutionToDotOp>(context);
}
(void)applyPatternsAndFoldGreedily(getOperation(), std::move(patterns));
}
};
} // namespace
std::unique_ptr<OperationPass<FuncOp>> createHLOPreprocessingPass() {
return std::make_unique<HLOToHLOPreprocessing>();
}
static PassRegistration<HLOToHLOPreprocessing> legalize_pass(
"iree-flow-hlo-to-hlo-preprocessing",
"Apply hlo to hlo transformations for some hlo ops");
} // namespace Flow
} // namespace IREE
} // namespace iree_compiler
} // namespace mlir