Added a TF2XLA Lowering that includes XLA folders

Without Tensorflow folders we were not able to progressively lower TF to XLA in cases where constant evaluation is required (e.g. constant axis for tf.Concatenate). Including all TF folders is not an option due to kernel build times, so by creating a custom legalize-tf pass for XLA that includes the XLA folders and alternates folding/lowering, we can legalize these cases without the TF folders.

Includes re-enabling linspace for VMLA.
diff --git a/integrations/tensorflow/bindings/python/pyiree/tf/compiler/__init__.py b/integrations/tensorflow/bindings/python/pyiree/tf/compiler/__init__.py
index ba966db..14fb63f 100644
--- a/integrations/tensorflow/bindings/python/pyiree/tf/compiler/__init__.py
+++ b/integrations/tensorflow/bindings/python/pyiree/tf/compiler/__init__.py
@@ -81,7 +81,7 @@
     "canonicalize",
 
     # Legalize to XLA
-    "xla-legalize-tf{allow-partial-conversion=true}",
+    "iree-xla-legalize-tf",
     "canonicalize",
 
     # Now that the IR is starting to look nice, optimize global tensors.
diff --git a/integrations/tensorflow/compiler/BUILD b/integrations/tensorflow/compiler/BUILD
index adba6af..838e202 100644
--- a/integrations/tensorflow/compiler/BUILD
+++ b/integrations/tensorflow/compiler/BUILD
@@ -21,6 +21,7 @@
 cc_library(
     name = "tensorflow",
     srcs = [
+        "LegalizeTF.cpp",
         "Passes.cpp",
         "PropagateResourceCasts.cpp",
         "TFSavedModelLowerExportedFunctions.cpp",
@@ -46,10 +47,15 @@
         "@llvm-project//mlir:Pass",
         "@llvm-project//mlir:Shape",
         "@llvm-project//mlir:ShapeTransforms",
+        "@llvm-project//mlir:StandardOps",
         "@llvm-project//mlir:Support",
         "@llvm-project//mlir:TransformUtils",
+        "@org_tensorflow//tensorflow/compiler/mlir/hlo",
+        "@org_tensorflow//tensorflow/compiler/mlir/hlo:chlo_legalize_to_hlo",
         "@org_tensorflow//tensorflow/compiler/mlir/tensorflow",
+        "@org_tensorflow//tensorflow/compiler/mlir/tensorflow:lower_tf_lib",
         "@org_tensorflow//tensorflow/compiler/mlir/tensorflow:tensorflow_types",
+        "@org_tensorflow//tensorflow/compiler/mlir/xla:xla_legalize_tf",
     ],
     alwayslink = 1,
 )
diff --git a/integrations/tensorflow/compiler/LegalizeTF.cpp b/integrations/tensorflow/compiler/LegalizeTF.cpp
new file mode 100644
index 0000000..9531521
--- /dev/null
+++ b/integrations/tensorflow/compiler/LegalizeTF.cpp
@@ -0,0 +1,127 @@
+// Copyright 2020 Google LLC
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//      https://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include "mlir/Dialect/Shape/IR/Shape.h"
+#include "mlir/Dialect/StandardOps/IR/Ops.h"
+#include "mlir/Pass/Pass.h"
+#include "mlir/Support/LLVM.h"
+#include "tensorflow/compiler/mlir/hlo/include/mlir-hlo/Dialect/mhlo/IR/chlo_ops.h"
+#include "tensorflow/compiler/mlir/hlo/include/mlir-hlo/Dialect/mhlo/IR/hlo_ops.h"
+#include "tensorflow/compiler/mlir/hlo/include/mlir-hlo/Dialect/mhlo/transforms/rewriters.h"
+#include "tensorflow/compiler/mlir/tensorflow/transforms/lower_tf.h"
+#include "tensorflow/compiler/mlir/xla/transforms/passes.h"
+
+namespace mlir {
+namespace mhlo {
+namespace {
+
+// This is a customizer version of the TF to XLA lowering in:
+//    tensorflow/compiler/mlir/xla/transforms/legalize_tf.cc
+// It does not require the same number of options as we can hardcode as the pass
+// the IREE requires.
+class LegalizeTF : public PassWrapper<LegalizeTF, FunctionPass> {
+  void getDependentDialects(DialectRegistry &registry) const override {
+    registry.insert<chlo::HloClientDialect, mhlo::MhloDialect,
+                    shape::ShapeDialect, StandardOpsDialect>();
+  }
+
+ public:
+  LegalizeTF() = default;
+  LegalizeTF(const LegalizeTF &) {}
+
+  /// Performs the lowering to XLA dialect.
+  void runOnFunction() override {
+    auto op = getFunction();
+    MLIRContext *context = op.getContext();
+    OwningRewritePatternList canonicalizePatterns;
+    for (auto *op : context->getRegisteredOperations())
+      op->getCanonicalizationPatterns(canonicalizePatterns, context);
+
+    OwningRewritePatternList patterns;
+    // Note that the `OperationConverter` orders patterns lexicographically by:
+    // 1) Ascending legalization depth (i.e., minimum number of patterns
+    // necessary to arrive at conversion target).
+    // 2) Descending pattern benefit.
+    // 3) Order of patterns in `OwningRewritePatternList`.
+
+    // Add TF->HLO legalization patterns.
+    PopulateLegalizeTfPatterns(context, &patterns);
+
+    // Add TF->TF lowering patterns.
+    TF::PopulateLoweringTFPatterns(context, &patterns);
+
+    // Populate with CHLO->HLO lowerings to account for TF ops legalized to
+    // CHLO first.
+    chlo::PopulateLegalizeChloToHloPatterns(context, &patterns);
+
+    // ConstantLike op is convenient to create splat constants, but is
+    // canonicalized to plain HLO constant if statically shaped. Add the
+    // canonicalization pattern to pattern list to enable multi-hop lowering.
+    chlo::ConstantLikeOp::getCanonicalizationPatterns(patterns, context);
+
+    ConversionTarget target(*context);
+    target.addIllegalDialect<chlo::HloClientDialect>();
+    target.addLegalDialect<MhloDialect>();
+    target.addLegalDialect<StandardOpsDialect>();
+    target.addLegalDialect<shape::ShapeDialect>();
+    target.addLegalOp<CallOp>();
+    target.addLegalOp<TensorCastOp>();
+
+    DenseSet<Operation *> prevUnconvertedOps;
+    DenseSet<Operation *> unconvertedOps;
+
+    while (true) {
+      if (failed(
+              applyPartialConversion(op, target, patterns, &unconvertedOps))) {
+        return signalPassFailure();
+      }
+
+      if (prevUnconvertedOps == unconvertedOps) break;
+
+      prevUnconvertedOps = std::move(unconvertedOps);
+      if (failed(applyPatternsAndFoldGreedily(op, canonicalizePatterns))) {
+        return signalPassFailure();
+      }
+    }
+  }
+
+ private:
+  Option<bool> allow_partial_conversion_{
+      *this, "allow-partial-conversion",
+      llvm::cl::desc("Allow operations that can't be legalized."),
+      llvm::cl::init(false)};
+  Option<bool> legalize_chlo_{
+      *this, "legalize-chlo",
+      llvm::cl::desc(
+          "Also legalizes intermediate chlo ops to hlo (default true)"),
+      llvm::cl::init(true)};
+  Option<bool> use_tf2xla_fallback_{
+      *this, "use-tf2xla-fallback",
+      llvm::cl::desc(
+          "Also use TF2XLA fallback for legalization (default false)"),
+      llvm::cl::init(false)};
+  Option<std::string> device_type_{
+      *this, "device-type",
+      llvm::cl::desc(
+          "The device type used by TF2XLA fallback. Must be specified if "
+          "use-tf2xla-fallback is true, otherwise not used."),
+      llvm::cl::init("INVALID_DEVICE_TYPE")};
+};
+
+static PassRegistration<LegalizeTF> pass(
+    "iree-xla-legalize-tf", "Legalize from TensorFlow to the XLA dialect");
+
+}  // namespace
+}  // namespace mhlo
+}  // namespace mlir
diff --git a/integrations/tensorflow/compiler/test/legalize-tf.mlir b/integrations/tensorflow/compiler/test/legalize-tf.mlir
new file mode 100644
index 0000000..e10a92a
--- /dev/null
+++ b/integrations/tensorflow/compiler/test/legalize-tf.mlir
@@ -0,0 +1,54 @@
+// RUN: iree-tf-opt -iree-xla-legalize-tf -split-input-file <%s | IreeFileCheck %s
+
+// CHECK-LABEL: func @f
+func @f(%arg0: tensor<f32>, %arg1: tensor<f32>) -> (tensor<3xf32>) {
+  // CHECK: [[VAL0:%.+]] = mhlo.constant dense<2.000000e+00>
+  // CHECK: [[VAL1:%.+]] = mhlo.constant dense<1.000000e+00>
+  %0 = "tf.Const"() {value = dense<true> : tensor<i1>} : () -> tensor<i1>
+  %1 = "tf.Const"() {value = dense<-1> : tensor<i32>} : () -> tensor<i32>
+  %2 = "tf.Const"() {value = dense<3> : tensor<i32>} : () -> tensor<i32>
+  %3 = "tf.Const"() {value = dense<1> : tensor<i64>} : () -> tensor<i64>
+  %4 = "tf.Const"() {value = dense<0> : tensor<i32>} : () -> tensor<i32>
+  %5 = "tf.Const"() {value = dense<1> : tensor<1xi32>} : () -> tensor<1xi32>
+  %6 = "tf.Const"() {value = dense<1> : tensor<i32>} : () -> tensor<i32>
+  %7 = "tf.Const"() {value = dense<0> : tensor<1xi32>} : () -> tensor<1xi32>
+  %8 = "tf.GreaterEqual"(%2, %4) {device = ""} : (tensor<i32>, tensor<i32>) -> tensor<i1>
+  %9 = "tf.StridedSlice"(%5, %7, %5, %5) {begin_mask = 0 : i64, device = "", ellipsis_mask = 0 : i64, end_mask = 0 : i64, new_axis_mask = 0 : i64, shrink_axis_mask = 1 : i64} : (tensor<1xi32>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>) -> tensor<i32>
+  %10 = "tf.SelectV2"(%0, %4, %9) {device = ""} : (tensor<i1>, tensor<i32>, tensor<i32>) -> tensor<i32>
+  %11 = "tf.Range"(%4, %9, %6) {device = ""} : (tensor<i32>, tensor<i32>, tensor<i32>) -> tensor<1xi32>
+  %12 = "tf.Equal"(%10, %11) {device = "", incompatible_shape_error = true} : (tensor<i32>, tensor<1xi32>) -> tensor<1xi1>
+  %13 = "tf.SelectV2"(%12, %2, %5) {device = ""} : (tensor<1xi1>, tensor<i32>, tensor<1xi32>) -> tensor<1xi32>
+  %14 = "tf.Sub"(%2, %6) {device = ""} : (tensor<i32>, tensor<i32>) -> tensor<i32>
+  %15 = "tf.Maximum"(%14, %6) {device = ""} : (tensor<i32>, tensor<i32>) -> tensor<i32>
+  %16 = "tf.Cast"(%15) {Truncate = false, device = ""} : (tensor<i32>) -> tensor<f32>
+  %17 = "tf.SelectV2"(%8, %15, %1) {device = ""} : (tensor<i1>, tensor<i32>, tensor<i32>) -> tensor<i32>
+  %18 = "tf.Cast"(%17) {Truncate = false, device = ""} : (tensor<i32>) -> tensor<i64>
+  %19 = "tf.Range"(%3, %18, %3) {device = ""} : (tensor<i64>, tensor<i64>, tensor<i64>) -> tensor<1xi64>
+  %20 = "tf.Cast"(%19) {Truncate = false, device = ""} : (tensor<1xi64>) -> tensor<1xf32>
+
+  // CHECK: [[VAL2:%.+]] = "mhlo.reshape"(%arg0)
+  %21 = "tf.ExpandDims"(%arg0, %4) {device = ""} : (tensor<f32>, tensor<i32>) -> tensor<1xf32>
+
+  // CHECK: [[VAL3:%.+]] = "mhlo.reshape"(%arg1)
+  %22 = "tf.ExpandDims"(%arg1, %4) {device = ""} : (tensor<f32>, tensor<i32>) -> tensor<1xf32>
+
+  // CHECK: [[VAL4:%.+]] = mhlo.subtract [[VAL3]], [[VAL2]]
+  %23 = "tf.Sub"(%22, %21) {device = ""} : (tensor<1xf32>, tensor<1xf32>) -> tensor<1xf32>
+
+  // CHECK: [[VAL5:%.+]] = mhlo.divide [[VAL4]], [[VAL0]]
+  %24 = "tf.RealDiv"(%23, %16) {device = ""} : (tensor<1xf32>, tensor<f32>) -> tensor<1xf32>
+
+  // CHECK: [[VAL6:%.+]] = mhlo.multiply [[VAL5]], [[VAL1]]
+  %25 = "tf.Mul"(%24, %20) {device = ""} : (tensor<1xf32>, tensor<1xf32>) -> tensor<1xf32>
+
+  // CHECK: [[VAL7:%.+]] = mhlo.add [[VAL2]], [[VAL6]]
+  %26 = "tf.AddV2"(%21, %25) {device = ""} : (tensor<1xf32>, tensor<1xf32>) -> tensor<1xf32>
+
+  // CHECK: [[VAL8:%.+]] = "mhlo.concatenate"([[VAL2]], [[VAL7]], [[VAL3]]) {dimension = 0 : i64}
+  %27 = "tf.ConcatV2"(%21, %26, %22, %10) {device = ""} : (tensor<1xf32>, tensor<1xf32>, tensor<1xf32>, tensor<i32>) -> tensor<3xf32>
+  %28 = "tf.Slice"(%27, %7, %13) {device = ""} : (tensor<3xf32>, tensor<1xi32>, tensor<1xi32>) -> tensor<3xf32>
+  %29 = "tf.Identity"(%28) {device = ""} : (tensor<3xf32>) -> tensor<3xf32>
+
+  // CHECK: return [[VAL8]]
+  return %29 : tensor<3xf32>
+}
diff --git a/integrations/tensorflow/e2e/BUILD b/integrations/tensorflow/e2e/BUILD
index ff88406..bc8a998 100644
--- a/integrations/tensorflow/e2e/BUILD
+++ b/integrations/tensorflow/e2e/BUILD
@@ -53,7 +53,6 @@
 # backends.
 # keep sorted
 SPECIAL_CASES = [
-    "linspace_test.py",
     "mobile_bert_squad_test.py",
 ]
 
@@ -102,6 +101,7 @@
     "einsum_vector_test.py",
     "fft_test.py",  # TODO(natashaknk): Get this working after kernel is in.
     "fill_test.py",  # TODO(jennik): Get this test working on IREE.
+    "linspace_test.py",  # TODO(https://github.com/google/iree/issues/1521)
     "logical_ops_test.py",
     "mandelbrot_test.py",  # TODO(silvasean): Get this working on IREE.
     "matrix_ops_dynamic_test.py",
@@ -123,6 +123,7 @@
     "einsum_vector_test.py",
     "fft_test.py",  # TODO(natashaknk): Get this working after kernel is in.
     "fill_test.py",  # TODO(jennik): Get this test working on IREE.
+    "linspace_test.py",  # TODO(https://github.com/google/iree/issues/1521)
     "logical_ops_test.py",
     "mandelbrot_test.py",  # TODO(silvasean): Get this working on IREE.
     "matrix_ops_dynamic_test.py",
@@ -192,47 +193,6 @@
     ],
 )
 
-# Special cases.
-
-# linspace_test passes internally, but fails in the OSS CI, so it needs
-# a "nokokoro" tag.
-iree_e2e_test_suite(
-    # TODO(#2082): `linspace_test.py` fails in the `bazel-tensorflow` image.
-    name = "linspace_tests",
-    backends_to_srcs = {
-        "tf": ["linspace_test.py"],
-        "iree_vmla": ["linspace_test.py"],
-    },
-    reference_backend = "tf",
-    tags = [
-        "manual",
-        "nokokoro",
-    ],
-    deps = INTREE_TENSORFLOW_PY_DEPS + NUMPY_DEPS + [
-        "//integrations/tensorflow/bindings/python/pyiree/tf/support",
-    ],
-)
-
-# TODO(laurenzo): Re-enable iree_vulkan once dynamic-slice is implemented
-# See https://github.com/google/iree/issues/1521
-iree_e2e_test_suite(
-    name = "linspace_tests_failing",
-    backends_to_srcs = {
-        "iree_llvmjit": ["linspace_test.py"],
-        "iree_vulkan": ["linspace_test.py"],
-    },
-    reference_backend = "tf",
-    tags = [
-        "failing",
-        "manual",
-        "nokokoro",
-        "notap",
-    ],
-    deps = INTREE_TENSORFLOW_PY_DEPS + NUMPY_DEPS + [
-        "//integrations/tensorflow/bindings/python/pyiree/tf/support",
-    ],
-)
-
 iree_e2e_test_suite(
     name = "mobile_bert_squad_tests",
     size = "enormous",