|  | // Copyright 2020 The IREE Authors | 
|  | // | 
|  | // Licensed under the Apache License v2.0 with LLVM Exceptions. | 
|  | // See https://llvm.org/LICENSE.txt for license information. | 
|  | // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception | 
|  |  | 
|  | #include "iree_tf_compiler/TF/Passes.h" | 
|  | #include "mlir-hlo/Dialect/mhlo/IR/chlo_ops.h" | 
|  | #include "mlir-hlo/Dialect/mhlo/IR/hlo_ops.h" | 
|  | #include "mlir-hlo/Dialect/mhlo/transforms/rewriters.h" | 
|  | #include "mlir/Dialect/Func/IR/FuncOps.h" | 
|  | #include "mlir/Dialect/Linalg/IR/Linalg.h" | 
|  | #include "mlir/Dialect/Shape/IR/Shape.h" | 
|  | #include "mlir/Dialect/Tensor/IR/Tensor.h" | 
|  | #include "mlir/IR/BuiltinOps.h" | 
|  | #include "mlir/Pass/Pass.h" | 
|  | #include "mlir/Support/LLVM.h" | 
|  | #include "mlir/Transforms/GreedyPatternRewriteDriver.h" | 
|  | #include "tensorflow/compiler/mlir/tensorflow/ir/tf_device.h" | 
|  | #include "tensorflow/compiler/mlir/tensorflow/ir/tf_executor.h" | 
|  | #include "tensorflow/compiler/mlir/tensorflow/ir/tf_ops.h" | 
|  | #include "tensorflow/compiler/mlir/tensorflow/ir/tf_saved_model.h" | 
|  | #include "tensorflow/compiler/mlir/tensorflow/transforms/lower_tf.h" | 
|  | #include "tensorflow/compiler/mlir/xla/transforms/passes.h" | 
|  |  | 
|  | namespace mlir { | 
|  | namespace iree_integrations { | 
|  | namespace TF { | 
|  |  | 
|  | // This is a customized version of the TF to XLA lowering in: | 
|  | //    tensorflow/compiler/mlir/xla/transforms/legalize_tf.cc | 
|  | // It does not require the same number of options as we can hardcode as the pass | 
|  | // the IREE requires. | 
|  | class ConvertToMHLOPass | 
|  | : public PassWrapper<ConvertToMHLOPass, OperationPass<FuncOp>> { | 
|  | void getDependentDialects(DialectRegistry ®istry) const override { | 
|  | registry | 
|  | .insert<mlir::linalg::LinalgDialect, mlir::TF::TensorFlowDialect, | 
|  | mlir::tf_executor::TensorFlowExecutorDialect, | 
|  | mlir::tf_device::TensorFlowDeviceDialect, | 
|  | mlir::tf_saved_model::TensorFlowSavedModelDialect, | 
|  | chlo::HloClientDialect, mhlo::MhloDialect, shape::ShapeDialect, | 
|  | mlir::arith::ArithmeticDialect, func::FuncDialect>(); | 
|  | } | 
|  |  | 
|  | StringRef getArgument() const override { return "iree-tf-convert-to-mhlo"; } | 
|  |  | 
|  | StringRef getDescription() const override { | 
|  | return "Converts from TensorFlow to the XLA MHLO dialect"; | 
|  | } | 
|  |  | 
|  | public: | 
|  | ConvertToMHLOPass() = default; | 
|  | ConvertToMHLOPass(const ConvertToMHLOPass &) {} | 
|  |  | 
|  | void runOnOperation() override { | 
|  | auto op = getOperation(); | 
|  | MLIRContext *context = op.getContext(); | 
|  |  | 
|  | // Lower TF Patterns must be separate from canonocalization patterns as | 
|  | // they are sometimes inversions of eachother. | 
|  | RewritePatternSet lowerTfPatterns(&getContext()); | 
|  | mlir::TF::PopulateTFLoweringBeforeHLOPatterns(context, &lowerTfPatterns); | 
|  |  | 
|  | RewritePatternSet canonicalizePatterns(&getContext()); | 
|  | for (auto op : context->getRegisteredOperations()) { | 
|  | op.getCanonicalizationPatterns(canonicalizePatterns, context); | 
|  | } | 
|  |  | 
|  | RewritePatternSet patterns(&getContext()); | 
|  | // Note that the `OperationConverter` orders patterns lexicographically by: | 
|  | // 1) Ascending legalization depth (i.e., minimum number of patterns | 
|  | // necessary to arrive at conversion target). | 
|  | // 2) Descending pattern benefit. | 
|  | // 3) Order of patterns in `RewritePatternSet`. | 
|  |  | 
|  | // Add TF->HLO legalization patterns. | 
|  | mhlo::PopulateLegalizeTfPatterns(context, &patterns); | 
|  |  | 
|  | // IREE Direct TF lowerings. | 
|  | populateDirectLoweringPatterns(context, patterns); | 
|  |  | 
|  | // TF::PopulateLoweringTFPatterns(context, &patterns); | 
|  |  | 
|  | // ConstantLike op is convenient to create splat constants, but is | 
|  | // canonicalized to plain HLO constant if statically shaped. Add the | 
|  | // canonicalization pattern to pattern list to enable multi-hop lowering. | 
|  | chlo::ConstantLikeOp::getCanonicalizationPatterns(patterns, context); | 
|  |  | 
|  | ConversionTarget target(*context); | 
|  | target.addLegalDialect<chlo::HloClientDialect>(); | 
|  | target.addLegalDialect<linalg::LinalgDialect>(); | 
|  | target.addLegalDialect<mhlo::MhloDialect>(); | 
|  | target.addLegalDialect<mlir::func::FuncDialect, | 
|  | mlir::arith::ArithmeticDialect>(); | 
|  | target.addLegalDialect<shape::ShapeDialect>(); | 
|  | target.addLegalDialect<tensor::TensorDialect>(); | 
|  | target.addLegalOp<mlir::func::CallOp>(); | 
|  | target.addLegalOp<mlir::tensor::CastOp>(); | 
|  | target.addLegalOp<mlir::tensor::DimOp>(); | 
|  |  | 
|  | // TODO(suderman): Enable logicistic op for lowering once the op is | 
|  | // supported in IREE. Also, remove the numerically unstable ConvertSigmoidOp | 
|  | // pattern in the legalize-tf pass. | 
|  | target.addIllegalOp<mhlo::LogisticOp>(); | 
|  |  | 
|  | // In general, IREE does not support DynamicBroadcastInDim ops that do not | 
|  | // resolve to a static form. This excludes any TF2XLA expansions which | 
|  | // we ultimately lack a linalg lowering for. Matches the corresponding | 
|  | // condition in legalize_to_linalg.cc for this op. | 
|  | target.addDynamicallyLegalOp<mhlo::DynamicBroadcastInDimOp>( | 
|  | [](mhlo::DynamicBroadcastInDimOp op) { | 
|  | if (auto t = op.operand() | 
|  | .getType() | 
|  | .template dyn_cast<RankedTensorType>()) { | 
|  | if (t.hasStaticShape()) { | 
|  | return true; | 
|  | } | 
|  | } | 
|  | return false; | 
|  | }); | 
|  |  | 
|  | DenseSet<Operation *> prevUnconvertedOps; | 
|  | DenseSet<Operation *> unconvertedOps; | 
|  |  | 
|  | FrozenRewritePatternSet frozenPatterns(std::move(patterns)); | 
|  | FrozenRewritePatternSet frozenCanonicalizePatterns( | 
|  | std::move(canonicalizePatterns)); | 
|  | FrozenRewritePatternSet frozenTfPatterns(std::move(lowerTfPatterns)); | 
|  | while (true) { | 
|  | if (failed( | 
|  | applyPatternsAndFoldGreedily(op, frozenCanonicalizePatterns))) { | 
|  | return signalPassFailure(); | 
|  | } | 
|  |  | 
|  | if (failed(applyPatternsAndFoldGreedily(op, frozenTfPatterns))) { | 
|  | return signalPassFailure(); | 
|  | } | 
|  |  | 
|  | if (failed(applyPartialConversion(op, target, frozenPatterns, | 
|  | &unconvertedOps))) { | 
|  | return signalPassFailure(); | 
|  | } | 
|  |  | 
|  | if (prevUnconvertedOps == unconvertedOps) break; | 
|  | prevUnconvertedOps = std::move(unconvertedOps); | 
|  | } | 
|  | } | 
|  |  | 
|  | private: | 
|  | Option<bool> allow_partial_conversion_{ | 
|  | *this, "allow-partial-conversion", | 
|  | llvm::cl::desc("Allow operations that can't be legalized."), | 
|  | llvm::cl::init(false)}; | 
|  | Option<bool> legalize_chlo_{ | 
|  | *this, "legalize-chlo", | 
|  | llvm::cl::desc( | 
|  | "Also legalizes intermediate chlo ops to hlo (default true)"), | 
|  | llvm::cl::init(false)}; | 
|  | Option<bool> use_tf2xla_fallback_{ | 
|  | *this, "use-tf2xla-fallback", | 
|  | llvm::cl::desc( | 
|  | "Also use TF2XLA fallback for legalization (default false)"), | 
|  | llvm::cl::init(false)}; | 
|  | Option<std::string> device_type_{ | 
|  | *this, "device-type", | 
|  | llvm::cl::desc( | 
|  | "The device type used by TF2XLA fallback. Must be specified if " | 
|  | "use-tf2xla-fallback is true, otherwise not used."), | 
|  | llvm::cl::init("INVALID_DEVICE_TYPE")}; | 
|  | }; | 
|  |  | 
|  | std::unique_ptr<Pass> createConvertToMHLOPass() { | 
|  | return std::make_unique<ConvertToMHLOPass>(); | 
|  | } | 
|  |  | 
|  | static PassRegistration<ConvertToMHLOPass> pass; | 
|  |  | 
|  | }  // namespace TF | 
|  | }  // namespace iree_integrations | 
|  | }  // namespace mlir |