|  | // Copyright 2020 Google LLC | 
|  | // | 
|  | // Licensed under the Apache License, Version 2.0 (the "License"); | 
|  | // you may not use this file except in compliance with the License. | 
|  | // You may obtain a copy of the License at | 
|  | // | 
|  | //      https://www.apache.org/licenses/LICENSE-2.0 | 
|  | // | 
|  | // Unless required by applicable law or agreed to in writing, software | 
|  | // distributed under the License is distributed on an "AS IS" BASIS, | 
|  | // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | 
|  | // See the License for the specific language governing permissions and | 
|  | // limitations under the License. | 
|  |  | 
|  | #include "llvm/Support/FormatVariadic.h" | 
|  | #include "mlir/Dialect/Shape/IR/Shape.h" | 
|  | #include "mlir/Dialect/StandardOps/IR/Ops.h" | 
|  | #include "mlir/Pass/Pass.h" | 
|  | #include "mlir/Support/LLVM.h" | 
|  | #include "tensorflow/compiler/mlir/hlo/include/mlir-hlo/Dialect/mhlo/IR/chlo_ops.h" | 
|  | #include "tensorflow/compiler/mlir/hlo/include/mlir-hlo/Dialect/mhlo/IR/hlo_ops.h" | 
|  | #include "tensorflow/compiler/mlir/hlo/include/mlir-hlo/Dialect/mhlo/transforms/rewriters.h" | 
|  | #include "tensorflow/compiler/mlir/tensorflow/transforms/lower_tf.h" | 
|  | #include "tensorflow/compiler/mlir/xla/transforms/passes.h" | 
|  |  | 
|  | namespace mlir { | 
|  | namespace iree_compiler { | 
|  | namespace { | 
|  |  | 
|  | class CheckNoTensorflow : public PassWrapper<CheckNoTensorflow, FunctionPass> { | 
|  | public: | 
|  | void getDependentDialects(DialectRegistry ®istry) const override { | 
|  | registry.insert<chlo::HloClientDialect, mhlo::MhloDialect, | 
|  | shape::ShapeDialect, StandardOpsDialect>(); | 
|  | } | 
|  |  | 
|  | CheckNoTensorflow() = default; | 
|  | CheckNoTensorflow(const CheckNoTensorflow &) {} | 
|  |  | 
|  | /// Validates that no TensorFlow frontends ops are in the function. | 
|  | void runOnFunction() override { | 
|  | auto op = getFunction(); | 
|  | auto context = op.getContext(); | 
|  |  | 
|  | Dialect *dialect = context->getLoadedDialect("tf"); | 
|  | DenseSet<Operation *> illegalOps; | 
|  | op.walk([&](Operation *op) { | 
|  | if (op->getDialect() == dialect) { | 
|  | illegalOps.insert(op); | 
|  | } | 
|  | }); | 
|  |  | 
|  | if (!illegalOps.empty()) { | 
|  | emitLegalizationErrors(op, illegalOps); | 
|  | return signalPassFailure(); | 
|  | } | 
|  | } | 
|  |  | 
|  | // Emits debug information which includes the number of ops of each type which | 
|  | // failed to legalize. | 
|  | void emitLegalizationErrors(Operation *op, | 
|  | const DenseSet<Operation *> &nonlegalizedOps) { | 
|  | // Print op errors for each of the TensorFlow ops that still remain. | 
|  | std::map<StringRef, int> opNameCounts; | 
|  | for (Operation *nonlegalizedOp : nonlegalizedOps) { | 
|  | StringRef opName = nonlegalizedOp->getName().getStringRef(); | 
|  | opNameCounts[opName]++; | 
|  | nonlegalizedOp->emitOpError() | 
|  | << ": unlegalized TensorFlow op still exists"; | 
|  | } | 
|  |  | 
|  | std::vector<std::string> errorMessages; | 
|  | errorMessages.reserve(opNameCounts.size()); | 
|  | for (const auto &opInfo : opNameCounts) { | 
|  | errorMessages.push_back( | 
|  | llvm::formatv("\t{0} (count: {1})", opInfo.first, opInfo.second)); | 
|  | } | 
|  | Location loc = op->getLoc(); | 
|  | emitError(loc) << "The following Tensorflow operations still remain: \n" | 
|  | << llvm::join(errorMessages, "\n") << "\n"; | 
|  | } | 
|  | }; | 
|  |  | 
|  | static PassRegistration<CheckNoTensorflow> pass( | 
|  | "iree-check-no-tf", "Check that no TensorFlow frontend ops remain"); | 
|  | }  // namespace | 
|  |  | 
|  | std::unique_ptr<OperationPass<FuncOp>> createCheckNoTF() { | 
|  | return std::make_unique<CheckNoTensorflow>(); | 
|  | } | 
|  |  | 
|  | }  // namespace iree_compiler | 
|  | }  // namespace mlir |