Added a pass to validate TF has been lowered away (#3572)

Added a pass that verifies that all TF has been lowered away. This guarantees
we terminate early when a TF operation should no longer exist.
diff --git a/integrations/tensorflow/compiler/BUILD b/integrations/tensorflow/compiler/BUILD
index 838e202..fba7d6d 100644
--- a/integrations/tensorflow/compiler/BUILD
+++ b/integrations/tensorflow/compiler/BUILD
@@ -21,6 +21,7 @@
 cc_library(
     name = "tensorflow",
     srcs = [
+        "CheckNoTF.cpp",
         "LegalizeTF.cpp",
         "Passes.cpp",
         "PropagateResourceCasts.cpp",
diff --git a/integrations/tensorflow/compiler/CheckNoTF.cpp b/integrations/tensorflow/compiler/CheckNoTF.cpp
new file mode 100644
index 0000000..246f28a
--- /dev/null
+++ b/integrations/tensorflow/compiler/CheckNoTF.cpp
@@ -0,0 +1,93 @@
+// Copyright 2020 Google LLC
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//      https://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include "llvm/Support/FormatVariadic.h"
+#include "mlir/Dialect/Shape/IR/Shape.h"
+#include "mlir/Dialect/StandardOps/IR/Ops.h"
+#include "mlir/Pass/Pass.h"
+#include "mlir/Support/LLVM.h"
+#include "tensorflow/compiler/mlir/hlo/include/mlir-hlo/Dialect/mhlo/IR/chlo_ops.h"
+#include "tensorflow/compiler/mlir/hlo/include/mlir-hlo/Dialect/mhlo/IR/hlo_ops.h"
+#include "tensorflow/compiler/mlir/hlo/include/mlir-hlo/Dialect/mhlo/transforms/rewriters.h"
+#include "tensorflow/compiler/mlir/tensorflow/transforms/lower_tf.h"
+#include "tensorflow/compiler/mlir/xla/transforms/passes.h"
+
+namespace mlir {
+namespace iree_compiler {
+namespace {
+
+class CheckNoTensorflow : public PassWrapper<CheckNoTensorflow, FunctionPass> {
+ public:
+  void getDependentDialects(DialectRegistry &registry) const override {
+    registry.insert<chlo::HloClientDialect, mhlo::MhloDialect,
+                    shape::ShapeDialect, StandardOpsDialect>();
+  }
+
+  CheckNoTensorflow() = default;
+  CheckNoTensorflow(const CheckNoTensorflow &) {}
+
+  /// Validates that no TensorFlow frontends ops are in the function.
+  void runOnFunction() override {
+    auto op = getFunction();
+    auto context = op.getContext();
+
+    Dialect *dialect = context->getLoadedDialect("tf");
+    DenseSet<Operation *> illegalOps;
+    op.walk([&](Operation *op) {
+      if (op->getDialect() == dialect) {
+        illegalOps.insert(op);
+      }
+    });
+
+    if (!illegalOps.empty()) {
+      emitLegalizationErrors(op, illegalOps);
+      return signalPassFailure();
+    }
+  }
+
+  // Emits debug information which includes the number of ops of each type which
+  // failed to legalize.
+  void emitLegalizationErrors(Operation *op,
+                              const DenseSet<Operation *> &nonlegalizedOps) {
+    // Print op errors for each of the TensorFlow ops that still remain.
+    std::map<StringRef, int> opNameCounts;
+    for (Operation *nonlegalizedOp : nonlegalizedOps) {
+      StringRef opName = nonlegalizedOp->getName().getStringRef();
+      opNameCounts[opName]++;
+      nonlegalizedOp->emitOpError()
+          << ": unlegalized TensorFlow op still exists";
+    }
+
+    std::vector<std::string> errorMessages;
+    errorMessages.reserve(opNameCounts.size());
+    for (const auto &opInfo : opNameCounts) {
+      errorMessages.push_back(
+          llvm::formatv("\t{0} (count: {1})", opInfo.first, opInfo.second));
+    }
+    Location loc = op->getLoc();
+    emitError(loc) << "The following Tensorflow operations still remain: \n"
+                   << llvm::join(errorMessages, "\n") << "\n";
+  }
+};
+
+static PassRegistration<CheckNoTensorflow> pass(
+    "iree-check-no-tf", "Check that no TensorFlow frontend ops remain");
+}  // namespace
+
+std::unique_ptr<OperationPass<FuncOp>> createCheckNoTF() {
+  return std::make_unique<CheckNoTensorflow>();
+}
+
+}  // namespace iree_compiler
+}  // namespace mlir
diff --git a/integrations/tensorflow/compiler/Passes.cpp b/integrations/tensorflow/compiler/Passes.cpp
index b4b6c47..db2e9d9 100644
--- a/integrations/tensorflow/compiler/Passes.cpp
+++ b/integrations/tensorflow/compiler/Passes.cpp
@@ -64,6 +64,11 @@
   // - It removes tf_saved_model.semantics from the module, which we can only
   //   do at the very end.
   pm.addPass(createTFSavedModelLowerExportedFunctions());
+
+  ////////////////////////////////////////////////////////////////////////////
+  // Validate that all Tensorflow has been legalized away.
+  ////////////////////////////////////////////////////////////////////////////
+  pm.addPass(createCheckNoTF());
 }
 
 static mlir::PassPipelineRegistration<> pipeline(
diff --git a/integrations/tensorflow/compiler/Passes.h b/integrations/tensorflow/compiler/Passes.h
index e3016a3..8b3b219 100644
--- a/integrations/tensorflow/compiler/Passes.h
+++ b/integrations/tensorflow/compiler/Passes.h
@@ -40,6 +40,9 @@
 // Push resource casts forward to better propagate resource related shapes.
 std::unique_ptr<OperationPass<ModuleOp>> createPropagateResourceCasts();
 
+// Validates whether any Tensorflow operations remain.
+std::unique_ptr<OperationPass<FuncOp>> createCheckNoTF();
+
 // Create a single pipeline that will run all the needed IREE-specific TF import
 // passes in the right order.
 void createIreeTfImportPipeline(OpPassManager &pm);
diff --git a/integrations/tensorflow/compiler/test/check-no-tf.mlir b/integrations/tensorflow/compiler/test/check-no-tf.mlir
new file mode 100644
index 0000000..a7c49e6
--- /dev/null
+++ b/integrations/tensorflow/compiler/test/check-no-tf.mlir
@@ -0,0 +1,28 @@
+// RUN: iree-tf-opt %s -iree-check-no-tf -split-input-file -verify-diagnostics
+
+// CHECK-LABEL: func @f
+func @f() -> (tensor<i32>) {
+  // CHECK: [[VAL0:%.+]] = mhlo.constant dense<3>
+  %0 = mhlo.constant dense<3> : tensor<i32>
+  return %0 : tensor<i32>
+}
+
+// -----
+
+// expected-error@+3 {{'tf.Const' op : unlegalized TensorFlow op still exists}}
+// expected-error@below {{The following Tensorflow operations still remain}}
+func @f() -> (tensor<i32>) {
+  %0 = "tf.Const"() {value = dense<-1> : tensor<i32>} : () -> tensor<i32>
+  return %0 : tensor<i32>
+}
+
+// -----
+
+// expected-error@+4 {{'tf.Const' op : unlegalized TensorFlow op still exists}}
+// expected-error@+4 {{'tf.Add' op : unlegalized TensorFlow op still exists}}
+// expected-error@below {{The following Tensorflow operations still remain}}
+func @f(%arg0 : tensor<i32>) -> (tensor<i32>) {
+  %0 = "tf.Const"() {value = dense<-1> : tensor<i32>} : () -> tensor<i32>
+  %1 = "tf.Add"(%arg0, %0) : (tensor<i32>, tensor<i32>) -> tensor<i32>
+  return %1 : tensor<i32>
+}
diff --git a/integrations/tensorflow/compiler/test/saved_model_adopt_exports.py b/integrations/tensorflow/compiler/test/saved_model_adopt_exports.py
index 256ce89..aff5dd8 100644
--- a/integrations/tensorflow/compiler/test/saved_model_adopt_exports.py
+++ b/integrations/tensorflow/compiler/test/saved_model_adopt_exports.py
@@ -23,6 +23,7 @@
 SAVED_MODEL_IMPORT_PASSES = [
     "tf-executor-graph-pruning",
     "tf-standard-pipeline",
+    "iree-xla-legalize-tf",
     "iree-tf-import-pipeline",
     "canonicalize",
 ]
@@ -114,8 +115,8 @@
 # CHECK: attributes
 # CHECK-SAME: iree.module.export
 # CHECK-SAME: iree.reflection = {abi = "sip", abiv = 1 : i32, sip = "I1!R1!"}
-# CHECK-DAG:   [[CONST_2xf32:%.+]] = "tf.Const"() {value = dense<[0.000000e+00, 1.000000e+00]> : tensor<2xf32>} : () -> tensor<2xf32>
-# CHECK-DAG:   [[CONST_3xf32:%.+]] = "tf.Const"() {value = dense<[0.000000e+00, 1.000000e+00, 2.000000e+00]> : tensor<3xf32>} : () -> tensor<3xf32>
+# CHECK-DAG:   [[CONST_2xf32:%.+]] = mhlo.constant dense<[0.000000e+00, 1.000000e+00]>
+# CHECK-DAG:   [[CONST_3xf32:%.+]] = mhlo.constant dense<[0.000000e+00, 1.000000e+00, 2.000000e+00]>
 # CHECK-DAG:   flow.variable.store [[CONST_2xf32]], @v : tensor<2xf32>
 # CHECK-DAG:   flow.variable.store [[CONST_3xf32]], @v : tensor<3xf32>
 # CHECK: FINISH_TEST