| // Copyright 2020 The IREE Authors |
| // |
| // Licensed under the Apache License v2.0 with LLVM Exceptions. |
| // See https://llvm.org/LICENSE.txt for license information. |
| // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception |
| |
| #include <numeric> |
| #include <random> |
| |
| #include "iree/compiler/InputConversion/MHLO/PassDetail.h" |
| #include "iree/compiler/InputConversion/MHLO/Passes.h" |
| #include "llvm/ADT/STLExtras.h" |
| #include "llvm/Support/Casting.h" |
| #include "mlir-hlo/Dialect/mhlo/IR/chlo_ops.h" |
| #include "mlir-hlo/Dialect/mhlo/IR/hlo_ops.h" |
| #include "mlir-hlo/Dialect/mhlo/transforms/rewriters.h" |
| #include "mlir/Dialect/Shape/IR/Shape.h" |
| #include "mlir/Dialect/StandardOps/IR/Ops.h" |
| #include "mlir/Dialect/Tensor/IR/Tensor.h" |
| #include "mlir/IR/Attributes.h" |
| #include "mlir/IR/BuiltinTypes.h" |
| #include "mlir/IR/ImplicitLocOpBuilder.h" |
| #include "mlir/IR/TypeUtilities.h" |
| #include "mlir/Pass/Pass.h" |
| #include "mlir/Support/LogicalResult.h" |
| #include "mlir/Transforms/GreedyPatternRewriteDriver.h" |
| |
| namespace mlir { |
| namespace iree_compiler { |
| |
| namespace { |
| |
| /// Returns true if the given `attr` is a splat of the given `value`. |
| static bool isSplatValue(DenseIntElementsAttr attr, uint64_t value) { |
| return attr.isSplat() && attr.getSplatValue<uint64_t>() == value; |
| } |
| |
| static bool isAllZero(DenseIntElementsAttr attr) { |
| return isSplatValue(attr, 0); |
| } |
| |
| static bool isIota(ArrayRef<int64_t> array) { |
| for (auto it : llvm::enumerate(array)) { |
| if (it.index() != it.value()) { |
| return false; |
| } |
| } |
| return true; |
| } |
| |
| /// Returns true if the conv op has padding attribute, and that it has |
| /// non-zero entries. |
| static bool hasPadding(mhlo::ConvOp op) { |
| Optional<DenseIntElementsAttr> padding = op.padding(); |
| if (!padding) return false; |
| return llvm::any_of(padding.getValue(), |
| [](APInt v) -> bool { return !v.isNullValue(); }); |
| } |
| |
| static DenseIntElementsAttr make1DElementsAttr(OpBuilder &b, |
| ArrayRef<int64_t> integers) { |
| auto type = RankedTensorType::get({static_cast<int64_t>(integers.size())}, |
| b.getIntegerType(64)); |
| return DenseIntElementsAttr::get(type, integers); |
| } |
| |
| static DenseIntElementsAttr make1DElementsAttr(OpBuilder &b, int64_t start, |
| int64_t num) { |
| return make1DElementsAttr( |
| b, llvm::to_vector<4>(llvm::seq<int64_t>(start, start + num))); |
| } |
| |
| static Value getF32Const(ImplicitLocOpBuilder b, ArrayRef<int64_t> shapes, |
| ArrayRef<float> values) { |
| RankedTensorType ty = RankedTensorType::get(shapes, b.getF32Type()); |
| return b.create<mhlo::ConstOp>(DenseFPElementsAttr::get(ty, values)) |
| .getResult(); |
| } |
| |
| static Value getF32SplatConst(ImplicitLocOpBuilder b, ArrayRef<int64_t> shapes, |
| float value) { |
| return getF32Const(b, shapes, {value}); |
| } |
| |
| class DecomposeLog1PPattern : public OpRewritePattern<mhlo::Log1pOp> { |
| public: |
| using OpRewritePattern<mhlo::Log1pOp>::OpRewritePattern; |
| |
| LogicalResult matchAndRewrite(mhlo::Log1pOp op, |
| PatternRewriter &rewriter) const override { |
| Location loc = op.getLoc(); |
| auto type = op.operand().getType().cast<TensorType>(); |
| DenseElementsAttr attr = |
| DenseElementsAttr::get(type, rewriter.getF32FloatAttr(1.0)); |
| auto one = rewriter.create<ConstantOp>(loc, attr); |
| auto x = rewriter.create<mhlo::AddOp>(loc, op.operand(), one); |
| rewriter.replaceOpWithNewOp<mhlo::LogOp>(op, x); |
| return success(); |
| } |
| }; |
| |
| class DecomposeExpM1Pattern : public OpRewritePattern<mhlo::Expm1Op> { |
| public: |
| using OpRewritePattern<mhlo::Expm1Op>::OpRewritePattern; |
| |
| LogicalResult matchAndRewrite(mhlo::Expm1Op op, |
| PatternRewriter &rewriter) const override { |
| Location loc = op.getLoc(); |
| auto type = op.operand().getType().cast<TensorType>(); |
| DenseElementsAttr attr = |
| DenseElementsAttr::get(type, rewriter.getF32FloatAttr(1.0)); |
| auto one = rewriter.create<ConstantOp>(loc, attr); |
| auto x = rewriter.create<mhlo::ExpOp>(loc, op.operand()); |
| rewriter.replaceOpWithNewOp<mhlo::SubOp>(op, x, one); |
| return success(); |
| } |
| }; |
| |
| class ExtractConvOpPaddingAttributes : public OpRewritePattern<mhlo::ConvOp> { |
| public: |
| using OpRewritePattern<mhlo::ConvOp>::OpRewritePattern; |
| |
| LogicalResult matchAndRewrite(mhlo::ConvOp op, |
| PatternRewriter &rewriter) const override { |
| if (!hasPadding(op)) return failure(); |
| auto inputType = op.lhs().getType().cast<ShapedType>(); |
| int rank = inputType.getRank(); |
| |
| // TODO(suderman): Add proper support for padding + dilation for codegen. |
| // We can't extract padding if the left hand side has dilation. |
| if (op.lhs_dilation().hasValue()) { |
| for (auto val : op.lhs_dilation().getValue().getValues<APInt>()) { |
| if (val != 1) { |
| return failure(); |
| } |
| } |
| } |
| |
| SmallVector<int64_t, 4> paddingLow, paddingHigh, interiorPadding, shape; |
| paddingLow.append(rank, 0); |
| paddingHigh.append(rank, 0); |
| interiorPadding.append(rank, 0); |
| for (auto iter : |
| llvm::enumerate(op.dimension_numbers().getInputSpatialDimensions())) { |
| unsigned idx = iter.index(); |
| unsigned dim = iter.value(); |
| paddingLow[dim] = op.paddingAttr().getValue<int64_t>({idx, 0}); |
| paddingHigh[dim] = op.paddingAttr().getValue<int64_t>({idx, 1}); |
| } |
| for (unsigned i = 0; i < rank; ++i) { |
| // mhlo.pad doesn't support dynamic shape. |
| if (inputType.isDynamicDim(i)) return failure(); |
| int size = inputType.getShape()[i]; |
| shape.push_back(size + paddingLow[i] + paddingHigh[i]); |
| } |
| |
| auto toDenseAttr = [&rewriter](ArrayRef<int64_t> elements) { |
| return DenseIntElementsAttr::get( |
| RankedTensorType::get(elements.size(), rewriter.getIntegerType(64)), |
| elements); |
| }; |
| |
| auto loc = op.getLoc(); |
| auto padResultType = |
| RankedTensorType::get(shape, inputType.getElementType()); |
| Attribute zeroAttr = rewriter.getZeroAttr( |
| RankedTensorType::get({}, inputType.getElementType())); |
| auto zero = rewriter.create<ConstantOp>(loc, zeroAttr); |
| auto padOp = rewriter.create<mhlo::PadOp>( |
| loc, padResultType, op.lhs(), zero, toDenseAttr(paddingLow), |
| toDenseAttr(paddingHigh), toDenseAttr(interiorPadding)); |
| auto resultType = op.getResult().getType(); |
| auto newOp = rewriter.create<mhlo::ConvOp>( |
| op.getLoc(), resultType, padOp.getResult(), op.rhs(), |
| op.window_stridesAttr(), /*padding=*/nullptr, op.lhs_dilationAttr(), |
| op.rhs_dilationAttr(), /*window_reversal=*/nullptr, |
| op.dimension_numbersAttr(), op.feature_group_countAttr(), |
| op.batch_group_countAttr(), op.precision_configAttr()); |
| rewriter.replaceOp(op, newOp.getResult()); |
| return success(); |
| } |
| }; |
| |
| // Guarantee that the input dimensions are ordered batch, spatial_dims, feature |
| // dim. |
| class ReorderConvOpInputDimensions : public OpRewritePattern<mhlo::ConvOp> { |
| public: |
| using OpRewritePattern<mhlo::ConvOp>::OpRewritePattern; |
| LogicalResult matchAndRewrite(mhlo::ConvOp op, |
| PatternRewriter &rewriter) const override { |
| auto lhsType = op.lhs().getType().cast<ShapedType>(); |
| auto lhsShape = lhsType.getShape(); |
| if (!lhsType.hasRank()) { |
| return failure(); |
| } |
| |
| auto dimensionNumbers = op.dimension_numbers(); |
| auto spatialDims = dimensionNumbers.getInputSpatialDimensions(); |
| |
| // Compute the permutation required to create a standard order. |
| llvm::SmallVector<int64_t, 4> permutations; |
| permutations.push_back(dimensionNumbers.getInputBatchDimension()); |
| permutations.append(spatialDims.begin(), spatialDims.end()); |
| permutations.push_back(dimensionNumbers.getInputFeatureDimension()); |
| |
| // If the permutation is iota then no reordering is required. |
| if (isIota(permutations)) { |
| return failure(); |
| } |
| |
| llvm::SmallVector<int64_t, 4> transposeShape; |
| for (auto p : permutations) { |
| transposeShape.push_back(lhsShape[p]); |
| } |
| |
| auto transposed = rewriter.create<mhlo::TransposeOp>( |
| op.getLoc(), |
| RankedTensorType::get(transposeShape, lhsType.getElementType()), |
| op.lhs(), rewriter.getI64TensorAttr(permutations)); |
| |
| llvm::SmallVector<int64_t, 4> newSpatialDimensions(spatialDims.size()); |
| std::iota(newSpatialDimensions.begin(), newSpatialDimensions.end(), 1); |
| |
| auto newDimensionNumbers = mhlo::ConvDimensionNumbersAttr::get( |
| op.getContext(), |
| /*input_batch_dimension=*/0, |
| /*input_feature_dimension=*/newSpatialDimensions.size() + 1, |
| /*input_spatial_dimensions=*/newSpatialDimensions, |
| dimensionNumbers.getKernelInputFeatureDimension(), |
| dimensionNumbers.getKernelOutputFeatureDimension(), |
| dimensionNumbers.getKernelSpatialDimensions(), |
| dimensionNumbers.getOutputBatchDimension(), |
| dimensionNumbers.getOutputFeatureDimension(), |
| dimensionNumbers.getOutputSpatialDimensions()); |
| |
| SmallVector<Value, 2> operands = {transposed, op.rhs()}; |
| auto newConv = rewriter.create<mhlo::ConvOp>(op.getLoc(), op.getType(), |
| operands, op->getAttrs()); |
| newConv.dimension_numbersAttr(newDimensionNumbers); |
| rewriter.replaceOp(op, newConv.getResult()); |
| |
| return success(); |
| } |
| }; |
| |
| struct ReorderConvOpKernelDimensions : public OpRewritePattern<mhlo::ConvOp> { |
| using OpRewritePattern::OpRewritePattern; |
| LogicalResult matchAndRewrite(mhlo::ConvOp op, |
| PatternRewriter &rewriter) const override { |
| auto kernel = op.rhs(); |
| auto kernelType = kernel.getType().cast<ShapedType>(); |
| if (!kernelType.hasRank()) return failure(); |
| auto kernelShape = kernelType.getShape(); |
| |
| auto dimensionNumbers = op.dimension_numbers(); |
| |
| auto spatialDims = dimensionNumbers.getKernelSpatialDimensions(); |
| |
| auto inputFeatureDimension = |
| dimensionNumbers.getKernelInputFeatureDimension(); |
| auto outputFeatureDimension = |
| dimensionNumbers.getKernelOutputFeatureDimension(); |
| |
| // Compute the permutation for the transpose. |
| llvm::SmallVector<int64_t, 4> permutation(spatialDims.begin(), |
| spatialDims.end()); |
| permutation.push_back(inputFeatureDimension); |
| permutation.push_back(outputFeatureDimension); |
| |
| // If the permutation is iota, then no transpose is required. |
| if (isIota(permutation)) return failure(); |
| |
| llvm::SmallVector<int64_t, 4> transposeShape; |
| for (auto perm : permutation) { |
| transposeShape.push_back(kernelShape[perm]); |
| } |
| |
| llvm::SmallVector<int64_t, 4> newSpatialDimensions(spatialDims.size()); |
| std::iota(newSpatialDimensions.begin(), newSpatialDimensions.end(), 0); |
| |
| auto transposeKernel = rewriter.create<mhlo::TransposeOp>( |
| op.getLoc(), |
| RankedTensorType::get(transposeShape, kernelType.getElementType()), |
| kernel, rewriter.getI64TensorAttr(permutation)); |
| |
| auto newDimensionNumbers = mhlo::ConvDimensionNumbersAttr::get( |
| op.getContext(), dimensionNumbers.getInputBatchDimension(), |
| dimensionNumbers.getInputFeatureDimension(), |
| dimensionNumbers.getInputSpatialDimensions(), |
| /*kernel_input_feature_dimension=*/ |
| newSpatialDimensions.size(), |
| /*kernel_output_feature_dimension=*/ |
| newSpatialDimensions.size() + 1, newSpatialDimensions, |
| dimensionNumbers.getOutputBatchDimension(), |
| dimensionNumbers.getOutputFeatureDimension(), |
| dimensionNumbers.getOutputSpatialDimensions()); |
| |
| SmallVector<Value, 2> operands = {op.lhs(), transposeKernel}; |
| mhlo::ConvOp newConv = rewriter.create<mhlo::ConvOp>( |
| op.getLoc(), op.getType(), operands, op->getAttrs()); |
| newConv.dimension_numbersAttr(newDimensionNumbers); |
| |
| rewriter.replaceOp(op, {newConv.getResult()}); |
| return success(); |
| } |
| }; |
| |
| // Guarantee that the output dimensions are ordered batch, spatial_dims, feature |
| // dim. |
| class ReorderConvOpOutputDimensions : public OpRewritePattern<mhlo::ConvOp> { |
| public: |
| using OpRewritePattern<mhlo::ConvOp>::OpRewritePattern; |
| LogicalResult matchAndRewrite(mhlo::ConvOp op, |
| PatternRewriter &rewriter) const override { |
| auto resultType = op.getType().cast<ShapedType>(); |
| auto resultShape = resultType.getShape(); |
| if (!resultType.hasRank()) { |
| return failure(); |
| } |
| |
| auto dimensionNumbers = op.dimension_numbers(); |
| auto spatialDims = dimensionNumbers.getOutputSpatialDimensions(); |
| |
| // Compute the permutation to transpose to an ordered output. |
| llvm::SmallVector<int64_t, 4> permutation; |
| permutation.push_back(dimensionNumbers.getOutputBatchDimension()); |
| permutation.append(spatialDims.begin(), spatialDims.end()); |
| permutation.push_back(dimensionNumbers.getOutputFeatureDimension()); |
| |
| // If the permutation is iota then no reordering is required. |
| if (isIota(permutation)) { |
| return failure(); |
| } |
| |
| // Compute what the new conv shape should be. |
| llvm::SmallVector<int64_t, 4> convShape; |
| for (auto p : permutation) { |
| convShape.push_back(resultShape[p]); |
| } |
| |
| // Compute the inverse transpose to unordered and ordered output. |
| llvm::SmallVector<int64_t, 4> invertPermutation(permutation.size()); |
| for (auto it : llvm::enumerate(permutation)) { |
| invertPermutation[it.value()] = it.index(); |
| } |
| |
| llvm::SmallVector<int64_t, 4> newSpatialDimensions(spatialDims.size()); |
| std::iota(newSpatialDimensions.begin(), newSpatialDimensions.end(), 1); |
| |
| auto newDimensionNumbers = mhlo::ConvDimensionNumbersAttr::get( |
| op.getContext(), dimensionNumbers.getInputBatchDimension(), |
| dimensionNumbers.getInputFeatureDimension(), |
| dimensionNumbers.getInputSpatialDimensions(), |
| dimensionNumbers.getKernelInputFeatureDimension(), |
| dimensionNumbers.getKernelOutputFeatureDimension(), |
| dimensionNumbers.getKernelSpatialDimensions(), |
| /*output_batch_dimension=*/0, |
| /*output_feature_dimension=*/newSpatialDimensions.size() + 1, |
| /*output_spatial_dimensions=*/newSpatialDimensions); |
| |
| SmallVector<Value, 2> operands = {op.lhs(), op.rhs()}; |
| auto newConv = rewriter.create<mhlo::ConvOp>( |
| op.getLoc(), |
| RankedTensorType::get(convShape, resultType.getElementType()), operands, |
| op->getAttrs()); |
| newConv.dimension_numbersAttr(newDimensionNumbers); |
| |
| auto transposed = rewriter.create<mhlo::TransposeOp>( |
| op.getLoc(), resultType, newConv, |
| rewriter.getI64TensorAttr(invertPermutation)); |
| |
| rewriter.replaceOp(op, transposed.getResult()); |
| return success(); |
| } |
| }; |
| |
| class ExtractReduceWindowOpPaddingAttributes |
| : public OpRewritePattern<mhlo::ReduceWindowOp> { |
| public: |
| using OpRewritePattern<mhlo::ReduceWindowOp>::OpRewritePattern; |
| |
| LogicalResult matchAndRewrite(mhlo::ReduceWindowOp op, |
| PatternRewriter &rewriter) const override { |
| if (!op.padding()) return failure(); |
| |
| if ((op.base_dilations() && !isSplatValue(*op.base_dilations(), 1)) || |
| (op.window_dilations() && !isSplatValue(*op.window_dilations(), 1))) { |
| return failure(); |
| } |
| if (isAllZero(op.paddingAttr())) return failure(); |
| |
| // All inputs must be of the same static shape, since |
| // mhlo.pad doesn't support dynamic shape. |
| for (Type inputType : op.inputs().getType()) { |
| if (!inputType.cast<ShapedType>().hasStaticShape()) return failure(); |
| } |
| ArrayRef<int64_t> inputShape = |
| op.inputs()[0].getType().cast<ShapedType>().getShape(); |
| |
| int rank = inputShape.size(); |
| SmallVector<int64_t, 4> paddingLow, paddingHigh, interiorPadding, shape; |
| for (unsigned i = 0; i < rank; ++i) { |
| interiorPadding.push_back(0); |
| paddingLow.push_back(op.paddingAttr().getValue<int64_t>({i, 0})); |
| paddingHigh.push_back(op.paddingAttr().getValue<int64_t>({i, 1})); |
| int size = inputShape[i]; |
| shape.push_back(size + paddingLow.back() + paddingHigh.back()); |
| } |
| |
| auto toDenseAttr = [&rewriter](ArrayRef<int64_t> elements) { |
| return DenseIntElementsAttr::get( |
| RankedTensorType::get(elements.size(), rewriter.getIntegerType(64)), |
| elements); |
| }; |
| |
| SmallVector<Value> padOps; |
| padOps.reserve(op.inputs().size()); |
| auto loc = op.getLoc(); |
| for (auto it : llvm::zip(op.inputs(), op.init_values())) { |
| Value input = std::get<0>(it); |
| Value initValue = std::get<1>(it); |
| auto inputType = input.getType().cast<ShapedType>(); |
| auto padResultType = |
| RankedTensorType::get(shape, inputType.getElementType()); |
| auto padOp = rewriter.create<mhlo::PadOp>( |
| loc, padResultType, input, initValue, toDenseAttr(paddingLow), |
| toDenseAttr(paddingHigh), toDenseAttr(interiorPadding)); |
| padOps.push_back(padOp); |
| } |
| auto newOp = rewriter.create<mhlo::ReduceWindowOp>( |
| loc, op.getResultTypes(), padOps, op.init_values(), |
| op.window_dimensions(), op.window_stridesAttr(), |
| op.base_dilationsAttr(), op.window_dilationsAttr(), |
| /*padding=*/nullptr); |
| rewriter.inlineRegionBefore(op.body(), newOp.body(), newOp.body().begin()); |
| rewriter.replaceOp(op, newOp.getResults()); |
| return success(); |
| } |
| }; |
| |
| // Adjust the shape of depthwise_conv filter where is applied by mhlo. |
| class AdjustDepthwiseFilterShape : public OpRewritePattern<mhlo::ConvOp> { |
| public: |
| using OpRewritePattern<mhlo::ConvOp>::OpRewritePattern; |
| |
| LogicalResult matchAndRewrite(mhlo::ConvOp op, |
| PatternRewriter &rewriter) const override { |
| int64_t featureInDim = |
| op.dimension_numbers().getKernelInputFeatureDimension(); |
| int64_t featureOutDim = |
| op.dimension_numbers().getKernelOutputFeatureDimension(); |
| const auto &kernelShape = op.rhs().getType().cast<ShapedType>().getShape(); |
| if (kernelShape[featureInDim] != 1) return failure(); |
| |
| const auto groupCount = op.feature_group_count(); |
| if (groupCount == 1) return failure(); |
| if (kernelShape[featureOutDim] % groupCount != 0) return failure(); |
| |
| SmallVector<int64_t, 4> newShape(kernelShape.begin(), kernelShape.end()); |
| newShape[featureInDim] = groupCount; |
| newShape[featureOutDim] /= groupCount; |
| auto loc = op.getLoc(); |
| auto elemType = op.rhs().getType().cast<ShapedType>().getElementType(); |
| auto reshapeOp = rewriter.create<mhlo::ReshapeOp>( |
| loc, RankedTensorType::get(newShape, elemType), op.rhs()); |
| auto resultType = op.getResult().getType(); |
| SmallVector<Value, 2> operands = {op.lhs(), reshapeOp.getResult()}; |
| auto newOp = rewriter.create<mhlo::ConvOp>(op.getLoc(), resultType, |
| operands, op->getAttrs()); |
| rewriter.replaceOp(op, newOp.getResult()); |
| return success(); |
| } |
| }; |
| |
| bool isConsecutive(ArrayRef<int64_t> array) { |
| for (int i = 1; i < array.size(); ++i) { |
| if (array[i] - array[i - 1] != 1) return false; |
| } |
| return true; |
| } |
| |
| SmallVector<int64_t> extract1DVector(DenseIntElementsAttr elements) { |
| SmallVector<int64_t> ret; |
| for (const APInt &element : elements) { |
| ret.push_back(element.getLimitedValue()); |
| } |
| return ret; |
| } |
| |
| // Rewrites mhlo.dot_general so lhs contraction dimensions are innermost and rhs |
| // contraction dimensions are dims right after batch dimension. The pattern |
| // inserts transposes so the dot_general always has the form: |
| // {batch_dims, parallel_dims, contraction_dims}. |
| // {batch_dims, contraction_dims, parallel_dims} |
| class TransposeGenericDotGeneral : public OpRewritePattern<mhlo::DotGeneralOp> { |
| public: |
| using OpRewritePattern<mhlo::DotGeneralOp>::OpRewritePattern; |
| |
| Value TransposeIfNonConsecutive(OpBuilder b, Location loc, Value src, |
| ArrayRef<int64_t> targetOrder) const { |
| if (isConsecutive(targetOrder)) return src; |
| auto type = src.getType().cast<RankedTensorType>(); |
| SmallVector<int64_t, 4> transposeShape; |
| for (auto i : targetOrder) { |
| transposeShape.push_back(type.getDimSize(i)); |
| } |
| return b.create<mhlo::TransposeOp>( |
| loc, RankedTensorType::get(transposeShape, type.getElementType()), src, |
| b.getI64TensorAttr(targetOrder)); |
| } |
| |
| LogicalResult matchAndRewrite(mhlo::DotGeneralOp op, |
| PatternRewriter &rewriter) const override { |
| auto lhsShapeType = op.lhs().getType().dyn_cast<RankedTensorType>(); |
| auto rhsShapeType = op.rhs().getType().dyn_cast<RankedTensorType>(); |
| auto resultType = op.getResult().getType().dyn_cast<RankedTensorType>(); |
| if (!lhsShapeType || !rhsShapeType || !resultType) return failure(); |
| |
| SmallVector<int64_t> lhsTargetOrder, rhsTargetOrder; |
| mhlo::DotDimensionNumbersAttr dimNumbers = op.dot_dimension_numbers(); |
| auto lhsBatchingDims = dimNumbers.getLhsBatchingDimensions(); |
| auto lhsContractingDims = dimNumbers.getLhsContractingDimensions(); |
| SmallVector<bool> isLhsParallel(lhsShapeType.getRank(), true); |
| for (auto i : lhsBatchingDims) { |
| lhsTargetOrder.push_back(i); |
| isLhsParallel[i] = false; |
| } |
| for (auto i : lhsContractingDims) { |
| isLhsParallel[i] = false; |
| } |
| for (int64_t i = 0, e = lhsShapeType.getRank(); i < e; ++i) { |
| if (isLhsParallel[i]) { |
| lhsTargetOrder.push_back(i); |
| } |
| } |
| for (auto i : lhsContractingDims) { |
| lhsTargetOrder.push_back(i); |
| } |
| |
| SmallVector<bool> isRhsParallel(rhsShapeType.getRank(), true); |
| auto rhsBatchingDims = dimNumbers.getRhsBatchingDimensions(); |
| auto rhsContractingDims = dimNumbers.getRhsContractingDimensions(); |
| for (auto i : rhsBatchingDims) { |
| rhsTargetOrder.push_back(i); |
| isRhsParallel[i] = false; |
| } |
| for (auto i : rhsContractingDims) { |
| rhsTargetOrder.push_back(i); |
| isRhsParallel[i] = false; |
| } |
| for (int64_t i = 0, e = rhsShapeType.getRank(); i < e; ++i) { |
| if (isRhsParallel[i]) { |
| rhsTargetOrder.push_back(i); |
| } |
| } |
| |
| Value lhs = TransposeIfNonConsecutive(rewriter, op.getLoc(), op.lhs(), |
| lhsTargetOrder); |
| Value rhs = TransposeIfNonConsecutive(rewriter, op.getLoc(), op.rhs(), |
| rhsTargetOrder); |
| if (lhs == op.lhs() && rhs == op.rhs()) return failure(); |
| |
| int64_t numLhsContractionDims = lhsContractingDims.size(); |
| int64_t lhsContractionBase = lhsShapeType.getRank() - numLhsContractionDims; |
| int64_t rhsContractionBase = rhsBatchingDims.size(); |
| int64_t numRhsContractionDims = |
| rhsContractionBase + rhsContractingDims.size(); |
| auto lhsBatchingDimsAttr = |
| llvm::to_vector<4>(llvm::seq<int64_t>(0, lhsBatchingDims.size())); |
| auto rhsBatchingDimsAttr = |
| llvm::to_vector<4>(llvm::seq<int64_t>(0, rhsBatchingDims.size())); |
| auto lhsContractingDimsAttr = llvm::to_vector<4>( |
| llvm::seq<int64_t>(lhsContractionBase, lhsShapeType.getRank())); |
| auto rhsContractingDimsAttr = llvm::to_vector<4>( |
| llvm::seq<int64_t>(rhsContractionBase, numRhsContractionDims)); |
| auto dimensionNumbers = mhlo::DotDimensionNumbersAttr::get( |
| rewriter.getContext(), lhsBatchingDimsAttr, rhsBatchingDimsAttr, |
| lhsContractingDimsAttr, rhsContractingDimsAttr); |
| |
| Value result = rewriter.create<mhlo::DotGeneralOp>( |
| op.getLoc(), op.getType(), lhs, rhs, dimensionNumbers, |
| op.precision_configAttr()); |
| rewriter.replaceOp(op, result); |
| return success(); |
| } |
| }; |
| |
| // Rewrite mhlo.dot_general to operate on rank-3 tensors when reduction dims are |
| // in consecutive order and not spliting the domain. This pattern inserts |
| // reshapes to collapse consecutive reduction and parallel dims to always |
| // generate a rank-3 dot_general op. |
| class RankReducedDotGeneral : public OpRewritePattern<mhlo::DotGeneralOp> { |
| public: |
| using OpRewritePattern<mhlo::DotGeneralOp>::OpRewritePattern; |
| |
| LogicalResult matchAndRewrite(mhlo::DotGeneralOp op, |
| PatternRewriter &rewriter) const override { |
| auto lhsShapeType = op.lhs().getType().dyn_cast<ShapedType>(); |
| auto rhsShapeType = op.rhs().getType().dyn_cast<ShapedType>(); |
| auto resultType = op.getResult().getType().dyn_cast<ShapedType>(); |
| |
| if (!lhsShapeType || !rhsShapeType || !resultType) return failure(); |
| if (!lhsShapeType.hasStaticShape() || !rhsShapeType.hasStaticShape()) |
| return failure(); |
| if (resultType.getRank() <= 3) return failure(); |
| |
| mhlo::DotDimensionNumbersAttr dimNumbers = op.dot_dimension_numbers(); |
| auto lhsBatchingDims = |
| llvm::to_vector<4>(dimNumbers.getLhsBatchingDimensions()); |
| auto rhsBatchingDims = |
| llvm::to_vector<4>(dimNumbers.getRhsBatchingDimensions()); |
| auto lhsContractingDims = |
| llvm::to_vector<4>(dimNumbers.getLhsContractingDimensions()); |
| auto rhsContractingDims = |
| llvm::to_vector<4>(dimNumbers.getRhsContractingDimensions()); |
| |
| if (lhsBatchingDims.empty() || rhsBatchingDims.empty()) return failure(); |
| |
| llvm::sort(lhsBatchingDims); |
| llvm::sort(lhsContractingDims); |
| llvm::sort(rhsBatchingDims); |
| llvm::sort(rhsContractingDims); |
| |
| auto isDomainSplit = [](ArrayRef<int64_t> shape, |
| ArrayRef<int64_t> batchingDims, |
| ArrayRef<int64_t> contractingDims) { |
| // Batching and contracting are contiguous. |
| if ((contractingDims.front() - batchingDims.back()) == 1) return false; |
| // Contracting dims are inner most. |
| if (contractingDims.back() == (shape.size() - 1)) return false; |
| return true; |
| }; |
| |
| if (!isConsecutive(lhsBatchingDims) || !isConsecutive(lhsContractingDims) || |
| !isConsecutive(rhsBatchingDims) || !isConsecutive(rhsContractingDims)) |
| return failure(); |
| |
| if (isDomainSplit(lhsShapeType.getShape(), lhsBatchingDims, |
| lhsContractingDims) || |
| isDomainSplit(rhsShapeType.getShape(), rhsBatchingDims, |
| rhsContractingDims)) |
| return failure(); |
| |
| // Collapsing shape into a rank-3 tensor, returns newCollabsedShape |
| // contraction and parallel dim indices. |
| auto computeCollapsedShape = [](ArrayRef<int64_t> shape, |
| ArrayRef<int64_t> batchingDims, |
| ArrayRef<int64_t> contractingDims) { |
| auto newRank = |
| shape.size() - batchingDims.size() - contractingDims.size() + 2; |
| auto batchingSize = std::accumulate( |
| batchingDims.begin(), batchingDims.end(), 1, |
| [shape](const int64_t accum, const int64_t index) -> int64_t { |
| return accum * shape[index]; |
| }); |
| auto contractingSize = std::accumulate( |
| contractingDims.begin(), contractingDims.end(), 1, |
| [shape](const int64_t accum, const int64_t index) -> int64_t { |
| return accum * shape[index]; |
| }); |
| |
| int parallelDimIndex, contractingDimIndex, parallelDimSize = 1; |
| if (contractingDims.front() - batchingDims.back() > 1) { |
| parallelDimIndex = 1; |
| contractingDimIndex = 2; |
| for (int i = batchingDims.back() + 1; i < contractingDims.front(); |
| ++i) { |
| parallelDimSize *= shape[i]; |
| } |
| } else { |
| contractingDimIndex = 1; |
| parallelDimIndex = 2; |
| for (int i = contractingDims.back() + 1; i < shape.size(); ++i) { |
| parallelDimSize *= shape[i]; |
| } |
| } |
| llvm::SmallVector<int64_t, 4> newShape(newRank); |
| newShape[0] = batchingSize; |
| newShape[contractingDimIndex] = contractingSize; |
| newShape[parallelDimIndex] = parallelDimSize; |
| return std::make_tuple(newShape, contractingDimIndex, parallelDimIndex); |
| }; |
| |
| int lhsContractingDimIndex, rhsContractingDimIndex, lhsParallelDimIndex, |
| rhsParallelDimIndex; |
| SmallVector<int64_t, 4> lhsNewShape, rhsNewShape; |
| std::tie(lhsNewShape, lhsContractingDimIndex, lhsParallelDimIndex) = |
| computeCollapsedShape(lhsShapeType.getShape(), lhsBatchingDims, |
| lhsContractingDims); |
| |
| std::tie(rhsNewShape, rhsContractingDimIndex, rhsParallelDimIndex) = |
| computeCollapsedShape(rhsShapeType.getShape(), rhsBatchingDims, |
| rhsContractingDims); |
| SmallVector<int64_t, 4> resultNewShape = {lhsNewShape[0], |
| lhsNewShape[lhsParallelDimIndex], |
| rhsNewShape[rhsParallelDimIndex]}; |
| Type dotGeneralResultType = |
| RankedTensorType::get(resultNewShape, resultType.getElementType()); |
| |
| auto loc = op.getLoc(); |
| Value reshapedLhs = rewriter.create<mhlo::ReshapeOp>( |
| loc, RankedTensorType::get(lhsNewShape, lhsShapeType.getElementType()), |
| op.lhs()); |
| Value reshapedRhs = rewriter.create<mhlo::ReshapeOp>( |
| loc, RankedTensorType::get(rhsNewShape, rhsShapeType.getElementType()), |
| op.rhs()); |
| auto dimensionNumbers = mhlo::DotDimensionNumbersAttr::get( |
| rewriter.getContext(), |
| /*lhs_batching_dimensions=*/{0}, |
| /*rhs_batching_dimensions=*/{0}, |
| /*lhs_contracting_dimensions=*/{lhsContractingDimIndex}, |
| /*rhs_contracting_dimensions=*/ |
| {rhsContractingDimIndex}); |
| Value dotGeneralResult = rewriter.create<mhlo::DotGeneralOp>( |
| loc, dotGeneralResultType, reshapedLhs, reshapedRhs, dimensionNumbers, |
| op.precision_configAttr()); |
| |
| Value result = |
| rewriter.create<mhlo::ReshapeOp>(loc, resultType, dotGeneralResult); |
| rewriter.replaceOp(op, result); |
| |
| return success(); |
| } |
| }; |
| |
| // Generates Gaussian noise with uniform random generator based on Box-Muller |
| // transform. |
| class ExpandRngNormal : public OpRewritePattern<mhlo::RngNormalOp> { |
| public: |
| using OpRewritePattern<mhlo::RngNormalOp>::OpRewritePattern; |
| |
| LogicalResult matchAndRewrite(mhlo::RngNormalOp op, |
| PatternRewriter &rewriter) const override { |
| auto resTy = op.getType().dyn_cast<RankedTensorType>(); |
| // We can support static shapes, but it's easier to implement Box-Muller |
| // transform if we know the number of elements. |
| if (!resTy || !resTy.hasStaticShape()) return failure(); |
| |
| // The algorithm requires even numbers and will generate pairs. |
| auto numElems = resTy.getNumElements(); |
| if (numElems & 1) numElems++; |
| auto halfNumElems = numElems / 2; |
| |
| ImplicitLocOpBuilder b(op.getLoc(), rewriter); |
| |
| // Explicitly set the seed to 0, so we have stateless generator. This is not |
| // a hard limit. Random generator is still a new topic, and we start with |
| // stateless random generator. |
| std::mt19937 rng{0}; |
| std::uniform_real_distribution<> runif(0.0, 1.0); |
| SmallVector<float> sqrtValues(halfNumElems), cosValues(halfNumElems), |
| sinValues(halfNumElems); |
| for (auto i : llvm::seq<unsigned>(0, numElems / 2)) { |
| constexpr float kEpsilon = std::numeric_limits<float>::epsilon(); |
| constexpr float kTwoPi = 2.0 * M_PI; |
| float u1, u2; |
| do { |
| u1 = runif(rng); |
| u2 = runif(rng); |
| } while (u1 <= kEpsilon); |
| sqrtValues[i] = -2.0 * log(u1); |
| cosValues[i] = cos(kTwoPi * u2); |
| sinValues[i] = sin(kTwoPi * u2); |
| } |
| |
| // mag = sigma * sqrt(-2.0 * log(u1)); |
| Value mag = getF32Const(b, /*shapes=*/{halfNumElems}, sqrtValues); |
| Value sigma = b.create<mhlo::BroadcastOp>( |
| mag.getType(), op.sigma(), make1DElementsAttr(b, halfNumElems)); |
| mag = b.create<mhlo::MulOp>(sigma, b.create<mhlo::SqrtOp>(mag)); |
| |
| // z0 = mag * cos(two_pi * u2) + mu; |
| // z1 = mag * sin(two_pi * u2) + mu; |
| Value mu = b.create<mhlo::BroadcastOp>(mag.getType(), op.mu(), |
| make1DElementsAttr(b, halfNumElems)); |
| Value z0 = getF32Const(b, /*shapes=*/{halfNumElems}, cosValues); |
| z0 = b.create<mhlo::MulOp>(mag, z0); |
| z0 = b.create<mhlo::AddOp>(z0, mu); |
| Value z1 = getF32Const(b, /*shapes=*/{halfNumElems}, sinValues); |
| z1 = b.create<mhlo::MulOp>(mag, z1); |
| z1 = b.create<mhlo::AddOp>(z1, mu); |
| |
| Value res = b.create<mhlo::ConcatenateOp>(ValueRange{z0, z1}, |
| b.getI64IntegerAttr(0)); |
| if (numElems != resTy.getNumElements()) { |
| OpFoldResult zero = b.getIndexAttr(0); |
| OpFoldResult one = b.getIndexAttr(1); |
| OpFoldResult size = b.getIndexAttr(resTy.getNumElements()); |
| res = b.create<tensor::ExtractSliceOp>(res, zero, size, one); |
| } |
| if (resTy.getRank() != 1) { |
| res = b.create<mhlo::ReshapeOp>(resTy, res); |
| } |
| rewriter.replaceOp(op, res); |
| return success(); |
| } |
| }; |
| |
| // clang-format off |
| // |
| // Reorder BroadcastInDimOp and N-ary elementwise op. |
| // |
| // Rewrites the following pattern (take binary elementwise op as example) |
| // |
| // %bcastx = "mhlo.broadcast_in_dim"(%x) {broadcast_dimensions = %[[BCAST_DIMS]]} : (%[[SHAPE_BEFORE_BCAST]]) -> %[[SHAPE_AFTER_BCAST]] |
| // %bcasty = "mhlo.broadcast_in_dim"(%y) {broadcast_dimensions = %[[BCAST_DIMS]]} : (%[[SHAPE_BEFORE_BCAST]]) -> %[[SHAPE_AFTER_BCAST]] |
| // %result = "BinaryElementwiseOpT"(%bcastx, %bcasty) : (%[[SHAPE_AFTER_BCAST]], %[[SHAPE_AFTER_BCAST]]) -> %[[SHAPE_AFTER_BCAST]] |
| // |
| // into |
| // |
| // %z = "BinaryElementwiseOpT"(%x, %y) : (%[[SHAPE_BEFORE_BCAST]], %[[SHAPE_BEFORE_BCAST]]) -> %[[SHAPE_BEFORE_BCAST]] |
| // %result = "mhlo.broadcast_in_dim"(%z) {broadcast_dimensions = %[[BCAST_DIMS]]} : (%[[SHAPE_BEFORE_BCAST]]) -> %[[SHAPE_AFTER_BCAST]] |
| // |
| // clang-format on |
| template <typename ElementwiseOpT> |
| class ReorderBroadcastInDimOpAndElementwiseOp |
| : public OpRewritePattern<ElementwiseOpT> { |
| public: |
| using OpRewritePattern<ElementwiseOpT>::OpRewritePattern; |
| |
| LogicalResult matchAndRewrite(ElementwiseOpT op, |
| PatternRewriter &rewriter) const override { |
| Operation *operation = op.getOperation(); |
| assert(operation->getNumOperands() >= 1 && operation->getNumResults() == 1); |
| |
| // Verify if all operands are from BroadcastInDimOp and its |
| // broadcast_dimensions is the same. |
| llvm::SmallVector<mhlo::BroadcastInDimOp, 2> bcastOps; |
| for (auto operand : operation->getOperands()) { |
| if (auto bcastOp = operand.getDefiningOp<mhlo::BroadcastInDimOp>()) { |
| bcastOps.push_back(bcastOp); |
| } else { |
| return failure(); |
| } |
| } |
| |
| if (llvm::any_of(bcastOps, [&bcastOps](mhlo::BroadcastInDimOp bcastOp) { |
| return bcastOp.broadcast_dimensions() != |
| bcastOps[0].broadcast_dimensions(); |
| })) { |
| return failure(); |
| } |
| |
| // Verify if all operands of BroadcastInDimOp are of same type and have |
| // static shape. |
| auto bcastOperandType = |
| bcastOps[0].operand().getType().template dyn_cast<ShapedType>(); |
| llvm::SmallVector<Value, 2> bcastOperands; |
| for (auto bcastOp : bcastOps) { |
| auto bcastOperand = bcastOp.operand(); |
| auto type = bcastOperand.getType().template dyn_cast<ShapedType>(); |
| if (!type || !type.hasStaticShape() || type != bcastOperandType) { |
| return failure(); |
| } |
| bcastOperands.push_back(bcastOperand); |
| } |
| |
| // Some elementwise ops, mhlo::RealOp for example, do not have |
| // SameOperandsAndResultType trait, so resultType might be different |
| // from bcastOperandType. |
| auto elementType = getElementTypeOrSelf(op.getResult()); |
| auto resultShape = bcastOperandType.getShape(); |
| auto resultType = RankedTensorType::get(resultShape, elementType); |
| |
| Value result = |
| rewriter.create<ElementwiseOpT>(op.getLoc(), resultType, bcastOperands); |
| rewriter.replaceOpWithNewOp<mhlo::BroadcastInDimOp>( |
| op, op.getType(), result, bcastOps[0].broadcast_dimensions()); |
| |
| for (auto bcastOp : bcastOps) { |
| if (bcastOp.getOperation()->use_empty()) { |
| rewriter.eraseOp(bcastOp); |
| } |
| } |
| |
| return success(); |
| } |
| }; |
| |
| struct MHLOToMHLOPreprocessingPass |
| : public MHLOToMHLOPreprocessingBase<MHLOToMHLOPreprocessingPass> { |
| void getDependentDialects(DialectRegistry ®istry) const override { |
| registry.insert<shape::ShapeDialect, mhlo::MhloDialect, |
| tensor::TensorDialect>(); |
| } |
| |
| void runOnOperation() override { |
| MLIRContext *context = &getContext(); |
| ConversionTarget conversionTarget(*context); |
| OwningRewritePatternList conversionPatterns(&getContext()); |
| // Note that various input modalities may do their own legalization of |
| // CHLO. Converting here allows IREE to accept CHLO dialect regardless of |
| // whether it was legalized away at a higher level. |
| // chlo::PopulateLegalizeChloToHloPatterns(context, &conversionPatterns); |
| conversionTarget.addLegalDialect< |
| shape::ShapeDialect, chlo::HloClientDialect, mhlo::MhloDialect, |
| mlir::StandardOpsDialect, mlir::tensor::TensorDialect>(); |
| // conversionTarget.addIllegalDialect<chlo::HloClientDialect>(); |
| if (failed(applyPartialConversion(getOperation(), conversionTarget, |
| std::move(conversionPatterns)))) { |
| return signalPassFailure(); |
| } |
| |
| OwningRewritePatternList patterns(&getContext()); |
| // TODO: Remove once we have a general contraction to matmul pass. |
| mhlo::PopulateEinsumToDotGeneralPatterns(context, &patterns); |
| mhlo::PopulateUnfuseBatchNormPatterns(context, &patterns); |
| mhlo::PopulateComplexLoweringPatterns(context, &patterns); |
| mhlo::PopulateGatherToTorchIndexSelectPatterns(context, &patterns); |
| patterns.insert<ExtractReduceWindowOpPaddingAttributes, |
| AdjustDepthwiseFilterShape, DecomposeLog1PPattern, |
| DecomposeExpM1Pattern, ExpandRngNormal>(context); |
| |
| // dot_general canoncalization patterns. |
| mhlo::PopulateGeneralDotOpLoweringPatterns(&patterns, context); |
| patterns.insert<RankReducedDotGeneral, TransposeGenericDotGeneral>(context); |
| |
| // Unary elementwise op. |
| patterns.insert< |
| ReorderBroadcastInDimOpAndElementwiseOp<mhlo::AbsOp>, |
| ReorderBroadcastInDimOpAndElementwiseOp<mhlo::CeilOp>, |
| ReorderBroadcastInDimOpAndElementwiseOp<mhlo::ConvertOp>, |
| ReorderBroadcastInDimOpAndElementwiseOp<mhlo::ClzOp>, |
| ReorderBroadcastInDimOpAndElementwiseOp<mhlo::CosOp>, |
| ReorderBroadcastInDimOpAndElementwiseOp<mhlo::ExpOp>, |
| ReorderBroadcastInDimOpAndElementwiseOp<mhlo::Expm1Op>, |
| ReorderBroadcastInDimOpAndElementwiseOp<mhlo::FloorOp>, |
| ReorderBroadcastInDimOpAndElementwiseOp<mhlo::ImagOp>, |
| ReorderBroadcastInDimOpAndElementwiseOp<mhlo::IsFiniteOp>, |
| ReorderBroadcastInDimOpAndElementwiseOp<mhlo::LogOp>, |
| ReorderBroadcastInDimOpAndElementwiseOp<mhlo::Log1pOp>, |
| ReorderBroadcastInDimOpAndElementwiseOp<mhlo::LogisticOp>, |
| ReorderBroadcastInDimOpAndElementwiseOp<mhlo::NotOp>, |
| ReorderBroadcastInDimOpAndElementwiseOp<mhlo::NegOp>, |
| ReorderBroadcastInDimOpAndElementwiseOp<mhlo::PopulationCountOp>, |
| ReorderBroadcastInDimOpAndElementwiseOp<mhlo::RealOp>, |
| ReorderBroadcastInDimOpAndElementwiseOp<mhlo::RoundOp>, |
| ReorderBroadcastInDimOpAndElementwiseOp<mhlo::RsqrtOp>, |
| ReorderBroadcastInDimOpAndElementwiseOp<mhlo::SignOp>, |
| ReorderBroadcastInDimOpAndElementwiseOp<mhlo::SinOp>, |
| ReorderBroadcastInDimOpAndElementwiseOp<mhlo::SqrtOp>, |
| ReorderBroadcastInDimOpAndElementwiseOp<mhlo::TanhOp>>(context); |
| // Binary elementwise op. |
| patterns.insert< |
| ReorderBroadcastInDimOpAndElementwiseOp<mhlo::AddOp>, |
| ReorderBroadcastInDimOpAndElementwiseOp<mhlo::Atan2Op>, |
| ReorderBroadcastInDimOpAndElementwiseOp<mhlo::ComplexOp>, |
| ReorderBroadcastInDimOpAndElementwiseOp<mhlo::DivOp>, |
| ReorderBroadcastInDimOpAndElementwiseOp<mhlo::MaxOp>, |
| ReorderBroadcastInDimOpAndElementwiseOp<mhlo::MinOp>, |
| ReorderBroadcastInDimOpAndElementwiseOp<mhlo::MulOp>, |
| ReorderBroadcastInDimOpAndElementwiseOp<mhlo::PowOp>, |
| ReorderBroadcastInDimOpAndElementwiseOp<mhlo::RemOp>, |
| ReorderBroadcastInDimOpAndElementwiseOp<mhlo::ShiftLeftOp>, |
| ReorderBroadcastInDimOpAndElementwiseOp<mhlo::ShiftRightArithmeticOp>, |
| ReorderBroadcastInDimOpAndElementwiseOp<mhlo::ShiftRightLogicalOp>, |
| ReorderBroadcastInDimOpAndElementwiseOp<mhlo::SubOp>, |
| ReorderBroadcastInDimOpAndElementwiseOp<mhlo::AndOp>, |
| ReorderBroadcastInDimOpAndElementwiseOp<mhlo::OrOp>, |
| ReorderBroadcastInDimOpAndElementwiseOp<mhlo::XorOp>>(context); |
| if (extractPadFromConv) { |
| patterns.insert<ExtractConvOpPaddingAttributes>(context); |
| } |
| if (orderConvFeatures) { |
| patterns.insert<ReorderConvOpInputDimensions>(context); |
| patterns.insert<ReorderConvOpKernelDimensions>(context); |
| patterns.insert<ReorderConvOpOutputDimensions>(context); |
| } |
| (void)applyPatternsAndFoldGreedily(getOperation(), std::move(patterns)); |
| } |
| }; |
| |
| } // namespace |
| |
| std::unique_ptr<OperationPass<FuncOp>> createMHLOToMHLOPreprocessingPass() { |
| return std::make_unique<MHLOToMHLOPreprocessingPass>(); |
| } |
| |
| } // namespace iree_compiler |
| } // namespace mlir |