| // Copyright 2020 The IREE Authors |
| // |
| // Licensed under the Apache License v2.0 with LLVM Exceptions. |
| // See https://llvm.org/LICENSE.txt for license information. |
| // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception |
| |
| #include "iree_tf_compiler/TF/Passes.h" |
| #include "mlir-hlo/Dialect/mhlo/IR/chlo_ops.h" |
| #include "mlir-hlo/Dialect/mhlo/IR/hlo_ops.h" |
| #include "mlir-hlo/Dialect/mhlo/transforms/rewriters.h" |
| #include "mlir/Dialect/Func/IR/FuncOps.h" |
| #include "mlir/Dialect/Linalg/IR/Linalg.h" |
| #include "mlir/Dialect/Shape/IR/Shape.h" |
| #include "mlir/Dialect/Tensor/IR/Tensor.h" |
| #include "mlir/IR/BuiltinOps.h" |
| #include "mlir/Pass/Pass.h" |
| #include "mlir/Support/LLVM.h" |
| #include "mlir/Transforms/GreedyPatternRewriteDriver.h" |
| #include "tensorflow/compiler/mlir/tensorflow/ir/tf_device.h" |
| #include "tensorflow/compiler/mlir/tensorflow/ir/tf_executor.h" |
| #include "tensorflow/compiler/mlir/tensorflow/ir/tf_ops.h" |
| #include "tensorflow/compiler/mlir/tensorflow/ir/tf_saved_model.h" |
| #include "tensorflow/compiler/mlir/tensorflow/transforms/lower_tf.h" |
| #include "tensorflow/compiler/mlir/xla/transforms/passes.h" |
| |
| namespace mlir { |
| namespace iree_integrations { |
| namespace TF { |
| |
| // This is a customized version of the TF to XLA lowering in: |
| // tensorflow/compiler/mlir/xla/transforms/legalize_tf.cc |
| // It does not require the same number of options as we can hardcode as the pass |
| // the IREE requires. |
| class ConvertToMHLOPass |
| : public PassWrapper<ConvertToMHLOPass, OperationPass<func::FuncOp>> { |
| void getDependentDialects(DialectRegistry ®istry) const override { |
| registry.insert<mlir::linalg::LinalgDialect, mlir::TF::TensorFlowDialect, |
| mlir::tf_executor::TensorFlowExecutorDialect, |
| mlir::tf_device::TensorFlowDeviceDialect, |
| mlir::tf_saved_model::TensorFlowSavedModelDialect, |
| chlo::ChloDialect, mhlo::MhloDialect, shape::ShapeDialect, |
| mlir::arith::ArithmeticDialect, func::FuncDialect>(); |
| } |
| |
| StringRef getArgument() const override { return "iree-tf-convert-to-mhlo"; } |
| |
| StringRef getDescription() const override { |
| return "Converts from TensorFlow to the XLA MHLO dialect"; |
| } |
| |
| public: |
| ConvertToMHLOPass() = default; |
| ConvertToMHLOPass(const ConvertToMHLOPass &) {} |
| |
| void runOnOperation() override { |
| auto op = getOperation(); |
| MLIRContext *context = op.getContext(); |
| |
| // Lower TF Patterns must be separate from canonocalization patterns as |
| // they are sometimes inversions of eachother. |
| RewritePatternSet lowerTfPatterns(&getContext()); |
| mlir::TF::PopulateTFLoweringBeforeHLOPatterns(context, &lowerTfPatterns); |
| |
| RewritePatternSet canonicalizePatterns(&getContext()); |
| for (auto op : context->getRegisteredOperations()) { |
| op.getCanonicalizationPatterns(canonicalizePatterns, context); |
| } |
| |
| RewritePatternSet patterns(&getContext()); |
| // Note that the `OperationConverter` orders patterns lexicographically by: |
| // 1) Ascending legalization depth (i.e., minimum number of patterns |
| // necessary to arrive at conversion target). |
| // 2) Descending pattern benefit. |
| // 3) Order of patterns in `RewritePatternSet`. |
| |
| // Add TF->HLO legalization patterns. |
| mhlo::PopulateLegalizeTfPatterns(context, &patterns); |
| |
| // IREE Direct TF lowerings. |
| populateDirectLoweringPatterns(context, patterns); |
| |
| // TF::PopulateLoweringTFPatterns(context, &patterns); |
| |
| // ConstantLike op is convenient to create splat constants, but is |
| // canonicalized to plain HLO constant if statically shaped. Add the |
| // canonicalization pattern to pattern list to enable multi-hop lowering. |
| chlo::ConstantLikeOp::getCanonicalizationPatterns(patterns, context); |
| |
| ConversionTarget target(*context); |
| target.addLegalDialect<chlo::ChloDialect>(); |
| target.addLegalDialect<linalg::LinalgDialect>(); |
| target.addLegalDialect<mhlo::MhloDialect>(); |
| target.addLegalDialect<mlir::func::FuncDialect, |
| mlir::arith::ArithmeticDialect>(); |
| target.addLegalDialect<shape::ShapeDialect>(); |
| target.addLegalDialect<tensor::TensorDialect>(); |
| target.addLegalOp<mlir::func::CallOp>(); |
| target.addLegalOp<mlir::tensor::CastOp>(); |
| target.addLegalOp<mlir::tensor::DimOp>(); |
| |
| // TODO(suderman): Enable logicistic op for lowering once the op is |
| // supported in IREE. Also, remove the numerically unstable ConvertSigmoidOp |
| // pattern in the legalize-tf pass. |
| target.addIllegalOp<mhlo::LogisticOp>(); |
| |
| // In general, IREE does not support DynamicBroadcastInDim ops that do not |
| // resolve to a static form. This excludes any TF2XLA expansions which |
| // we ultimately lack a linalg lowering for. Matches the corresponding |
| // condition in legalize_to_linalg.cc for this op. |
| target.addDynamicallyLegalOp<mhlo::DynamicBroadcastInDimOp>( |
| [](mhlo::DynamicBroadcastInDimOp op) { |
| if (auto t = op.operand() |
| .getType() |
| .template dyn_cast<RankedTensorType>()) { |
| if (t.hasStaticShape()) { |
| return true; |
| } |
| } |
| return false; |
| }); |
| |
| DenseSet<Operation *> prevUnconvertedOps; |
| DenseSet<Operation *> unconvertedOps; |
| |
| FrozenRewritePatternSet frozenPatterns(std::move(patterns)); |
| FrozenRewritePatternSet frozenCanonicalizePatterns( |
| std::move(canonicalizePatterns)); |
| FrozenRewritePatternSet frozenTfPatterns(std::move(lowerTfPatterns)); |
| while (true) { |
| if (failed( |
| applyPatternsAndFoldGreedily(op, frozenCanonicalizePatterns))) { |
| return signalPassFailure(); |
| } |
| |
| if (failed(applyPatternsAndFoldGreedily(op, frozenTfPatterns))) { |
| return signalPassFailure(); |
| } |
| |
| if (failed(applyPartialConversion(op, target, frozenPatterns, |
| &unconvertedOps))) { |
| return signalPassFailure(); |
| } |
| |
| if (prevUnconvertedOps == unconvertedOps) break; |
| prevUnconvertedOps = std::move(unconvertedOps); |
| } |
| } |
| |
| private: |
| Option<bool> allow_partial_conversion_{ |
| *this, "allow-partial-conversion", |
| llvm::cl::desc("Allow operations that can't be legalized."), |
| llvm::cl::init(false)}; |
| Option<bool> legalize_chlo_{ |
| *this, "legalize-chlo", |
| llvm::cl::desc( |
| "Also legalizes intermediate chlo ops to hlo (default true)"), |
| llvm::cl::init(false)}; |
| Option<bool> use_tf2xla_fallback_{ |
| *this, "use-tf2xla-fallback", |
| llvm::cl::desc( |
| "Also use TF2XLA fallback for legalization (default false)"), |
| llvm::cl::init(false)}; |
| Option<std::string> device_type_{ |
| *this, "device-type", |
| llvm::cl::desc( |
| "The device type used by TF2XLA fallback. Must be specified if " |
| "use-tf2xla-fallback is true, otherwise not used."), |
| llvm::cl::init("INVALID_DEVICE_TYPE")}; |
| }; |
| |
| std::unique_ptr<Pass> createConvertToMHLOPass() { |
| return std::make_unique<ConvertToMHLOPass>(); |
| } |
| |
| static PassRegistration<ConvertToMHLOPass> pass; |
| |
| } // namespace TF |
| } // namespace iree_integrations |
| } // namespace mlir |