| // Copyright 2020 Google LLC |
| // |
| // Licensed under the Apache License, Version 2.0 (the "License"); |
| // you may not use this file except in compliance with the License. |
| // You may obtain a copy of the License at |
| // |
| // https://www.apache.org/licenses/LICENSE-2.0 |
| // |
| // Unless required by applicable law or agreed to in writing, software |
| // distributed under the License is distributed on an "AS IS" BASIS, |
| // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
| // See the License for the specific language governing permissions and |
| // limitations under the License. |
| |
| #include "llvm/ADT/STLExtras.h" |
| #include "llvm/Support/Casting.h" |
| #include "mlir/Dialect/StandardOps/IR/Ops.h" |
| #include "mlir/IR/Attributes.h" |
| #include "mlir/IR/StandardTypes.h" |
| #include "mlir/Pass/Pass.h" |
| #include "mlir/Support/LogicalResult.h" |
| #include "tensorflow/compiler/mlir/xla/ir/hlo_ops.h" |
| #include "tensorflow/compiler/mlir/xla/transforms/rewriters.h" |
| |
| namespace mlir { |
| namespace iree_compiler { |
| namespace IREE { |
| namespace Flow { |
| |
| namespace { |
| |
| static llvm::cl::opt<bool> extractPadFromConv( |
| "iree-extract-pad-from-conv", |
| llvm::cl::desc("Extract padding attributes from conv op"), |
| llvm::cl::init(false)); |
| |
| static bool isAllZero(DenseIntElementsAttr attr) { |
| if (!attr.isSplat()) return false; |
| return attr.getSplatValue<IntegerAttr>().getInt() == 0; |
| } |
| |
| class ExtractConvOpPaddingAttributes |
| : public OpRewritePattern<xla_hlo::ConvOp> { |
| public: |
| using OpRewritePattern<xla_hlo::ConvOp>::OpRewritePattern; |
| |
| LogicalResult matchAndRewrite(xla_hlo::ConvOp op, |
| PatternRewriter &rewriter) const override { |
| if (!op.padding()) return failure(); |
| auto inputType = op.lhs().getType().cast<ShapedType>(); |
| int rank = inputType.getRank(); |
| SmallVector<int64_t, 4> paddingLow, paddingHigh, interiorPadding, shape; |
| paddingLow.append(rank, 0); |
| paddingHigh.append(rank, 0); |
| interiorPadding.append(rank, 0); |
| for (auto iter : |
| llvm::enumerate(op.dimension_numbers().input_spatial_dimensions())) { |
| unsigned idx = iter.index(); |
| unsigned dim = iter.value().getZExtValue(); |
| paddingLow[dim] = op.paddingAttr().getValue<int64_t>({idx, 0}); |
| paddingHigh[dim] = op.paddingAttr().getValue<int64_t>({idx, 1}); |
| } |
| for (unsigned i = 0; i < rank; ++i) { |
| // xla_hlo.pad doesn't support dynamic shape. |
| if (inputType.isDynamicDim(i)) return failure(); |
| int size = inputType.getShape()[i]; |
| shape.push_back(size + paddingLow[i] + paddingHigh[i]); |
| } |
| |
| auto toDenseAttr = [&rewriter](ArrayRef<int64_t> elements) { |
| return DenseIntElementsAttr::get( |
| RankedTensorType::get(elements.size(), rewriter.getIntegerType(64)), |
| elements); |
| }; |
| |
| auto loc = op.getLoc(); |
| auto padResultType = |
| RankedTensorType::get(shape, inputType.getElementType()); |
| Attribute zeroAttr = rewriter.getZeroAttr( |
| RankedTensorType::get({}, inputType.getElementType())); |
| auto zero = rewriter.create<ConstantOp>(loc, zeroAttr); |
| auto padOp = rewriter.create<xla_hlo::PadOp>( |
| loc, padResultType, op.lhs(), zero, toDenseAttr(paddingLow), |
| toDenseAttr(paddingHigh), toDenseAttr(interiorPadding)); |
| auto resultType = op.getResult().getType(); |
| auto newOp = rewriter.create<xla_hlo::ConvOp>( |
| op.getLoc(), resultType, padOp.getResult(), op.rhs(), |
| op.window_stridesAttr(), /*padding=*/nullptr, op.lhs_dilationAttr(), |
| op.rhs_dilationAttr(), op.dimension_numbersAttr(), |
| op.feature_group_countAttr(), op.batch_group_countAttr(), |
| op.precision_configAttr()); |
| rewriter.replaceOp(op, newOp.getResult()); |
| return success(); |
| } |
| }; |
| |
| class ExtractReduceWindowOpPaddingAttributes |
| : public OpRewritePattern<xla_hlo::ReduceWindowOp> { |
| public: |
| using OpRewritePattern<xla_hlo::ReduceWindowOp>::OpRewritePattern; |
| |
| LogicalResult matchAndRewrite(xla_hlo::ReduceWindowOp op, |
| PatternRewriter &rewriter) const override { |
| if (!op.padding()) return failure(); |
| if (op.base_dilations() || op.window_dilations()) return failure(); |
| if (isAllZero(op.paddingAttr())) return failure(); |
| |
| auto inputType = op.operand().getType().cast<ShapedType>(); |
| int rank = inputType.getRank(); |
| SmallVector<int64_t, 4> paddingLow, paddingHigh, interiorPadding, shape; |
| for (unsigned i = 0; i < rank; ++i) { |
| // xla_hlo.pad doesn't support dynamic shape. |
| if (inputType.isDynamicDim(i)) return failure(); |
| interiorPadding.push_back(0); |
| paddingLow.push_back(op.paddingAttr().getValue<int64_t>({i, 0})); |
| paddingHigh.push_back(op.paddingAttr().getValue<int64_t>({i, 1})); |
| int size = inputType.getShape()[i]; |
| shape.push_back(size + paddingLow.back() + paddingHigh.back()); |
| } |
| |
| auto toDenseAttr = [&rewriter](ArrayRef<int64_t> elements) { |
| return DenseIntElementsAttr::get( |
| RankedTensorType::get(elements.size(), rewriter.getIntegerType(64)), |
| elements); |
| }; |
| |
| auto loc = op.getLoc(); |
| auto padResultType = |
| RankedTensorType::get(shape, inputType.getElementType()); |
| auto padOp = rewriter.create<xla_hlo::PadOp>( |
| loc, padResultType, op.operand(), op.init_value(), |
| toDenseAttr(paddingLow), toDenseAttr(paddingHigh), |
| toDenseAttr(interiorPadding)); |
| auto newOp = rewriter.create<xla_hlo::ReduceWindowOp>( |
| loc, op.getResult().getType(), padOp, op.init_value(), |
| op.window_dimensions(), op.window_stridesAttr(), |
| op.base_dilationsAttr(), op.window_dilationsAttr(), |
| /*padding=*/nullptr); |
| rewriter.inlineRegionBefore(op.body(), newOp.body(), newOp.body().begin()); |
| rewriter.replaceOp(op, newOp.getResult()); |
| return success(); |
| } |
| }; |
| |
| // Adjust the shape of depthwise_conv filter where is applied by xla_hlo. |
| class AdjustDepthwiseFilterShape : public OpRewritePattern<xla_hlo::ConvOp> { |
| public: |
| using OpRewritePattern<xla_hlo::ConvOp>::OpRewritePattern; |
| |
| LogicalResult matchAndRewrite(xla_hlo::ConvOp op, |
| PatternRewriter &rewriter) const override { |
| const auto featureInDim = |
| op.dimension_numbers().kernel_input_feature_dimension().getInt(); |
| const auto featureOutDim = |
| op.dimension_numbers().kernel_output_feature_dimension().getInt(); |
| const auto &kernelShape = op.rhs().getType().cast<ShapedType>().getShape(); |
| if (kernelShape[featureInDim] != 1) return failure(); |
| |
| const auto groupCount = op.feature_group_count().getZExtValue(); |
| if (groupCount == 1) return failure(); |
| if (kernelShape[featureOutDim] % groupCount != 0) return failure(); |
| |
| SmallVector<int64_t, 4> newShape(kernelShape.begin(), kernelShape.end()); |
| newShape[featureInDim] = groupCount; |
| newShape[featureOutDim] /= groupCount; |
| auto loc = op.getLoc(); |
| auto elemType = op.rhs().getType().cast<ShapedType>().getElementType(); |
| auto reshapeOp = rewriter.create<xla_hlo::ReshapeOp>( |
| loc, RankedTensorType::get(newShape, elemType), op.rhs()); |
| auto resultType = op.getResult().getType(); |
| SmallVector<Value, 2> operands = {op.lhs(), reshapeOp.getResult()}; |
| auto newOp = rewriter.create<xla_hlo::ConvOp>(op.getLoc(), resultType, |
| operands, op.getAttrs()); |
| rewriter.replaceOp(op, newOp.getResult()); |
| return success(); |
| } |
| }; |
| |
| struct HLOToHLOPreprocessing |
| : public PassWrapper<HLOToHLOPreprocessing, FunctionPass> { |
| void runOnFunction() override { |
| MLIRContext *context = &getContext(); |
| OwningRewritePatternList patterns; |
| xla_hlo::PopulateUnfuseBatchNormPatterns(context, &patterns); |
| // Note that various input modalities may do their own legalization of |
| // CHLO. Converting here allows IREE to accept CHLO dialect regardless of |
| // whether it was legalized away at a higher level. |
| xla_chlo::PopulateLegalizeChloToHloPatterns(context, &patterns); |
| patterns.insert<ExtractReduceWindowOpPaddingAttributes, |
| AdjustDepthwiseFilterShape>(context); |
| if (extractPadFromConv) { |
| patterns.insert<ExtractConvOpPaddingAttributes>(context); |
| } |
| applyPatternsAndFoldGreedily(getOperation(), patterns); |
| } |
| }; |
| |
| } // namespace |
| |
| std::unique_ptr<OperationPass<FuncOp>> createHLOPreprocessingPass() { |
| return std::make_unique<HLOToHLOPreprocessing>(); |
| } |
| |
| static PassRegistration<HLOToHLOPreprocessing> legalize_pass( |
| "iree-flow-hlo-to-hlo-preprocessing", |
| "Apply hlo to hlo transformations for some hlo ops"); |
| |
| } // namespace Flow |
| } // namespace IREE |
| } // namespace iree_compiler |
| } // namespace mlir |