Merge google -> main (#3050)

* 2730f436 Merge pull request #3048 from GMNGeoffrey:main-to-google
* 7f8f2c51 Opt-in to the global dialect registry
* 051a9e2f Synchronize submodules
* 11d565ee Integrate LLVM at llvm/llvm-project@ffd0b31c7cba
* 833983bc Synchronize submodules
* 5ae0d624 Integrate LLVM at llvm/llvm-project@1d3d9b9cd808
* 80c32bdf Synchronize submodules
* 731139cf Integrate LLVM at llvm/llvm-project@646f19bb9dc8
* de995a98 Synchronize submodules
* 703e3782 Integrate LLVM at llvm/llvm-project@bc3d4d9ed783
* 92689b8d Adds support for invoking a function through the java api
diff --git a/integrations/tensorflow/e2e/BUILD b/integrations/tensorflow/e2e/BUILD
index 0c2dbe3..e7e7454 100644
--- a/integrations/tensorflow/e2e/BUILD
+++ b/integrations/tensorflow/e2e/BUILD
@@ -71,6 +71,7 @@
     "dynamic_mlp_relu_test.py",
     "dynamic_mlp_test.py",
     "fill_test.py",  # TODO(jennik): Get this test working on IREE.
+    "finite_test.py",
     "mandelbrot_test.py",  # TODO(silvasean): Get this working on IREE.
     "matrix_ops_test.py",
     "range_test.py",
@@ -87,6 +88,7 @@
     "dynamic_mlp_relu_test.py",
     "dynamic_mlp_test.py",
     "fill_test.py",  # TODO(jennik): Get this test working on IREE.
+    "finite_test.py",
     "mandelbrot_test.py",  # TODO(silvasean): Get this working on IREE.
     "matrix_ops_test.py",
     "range_test.py",
diff --git a/integrations/tensorflow/e2e/finite_test.py b/integrations/tensorflow/e2e/finite_test.py
new file mode 100644
index 0000000..ff62f3a
--- /dev/null
+++ b/integrations/tensorflow/e2e/finite_test.py
@@ -0,0 +1,41 @@
+# Copyright 2020 Google LLC
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#      https://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import numpy as np
+from pyiree.tf.support import tf_test_utils
+import tensorflow.compat.v2 as tf
+
+
+class FiniteModule(tf.Module):
+
+  @tf.function(input_signature=[tf.TensorSpec([4], tf.float32)])
+  def finite(self, x):
+    return tf.math.is_finite(x)
+
+
+@tf_test_utils.compile_module(FiniteModule)
+class FiniteTest(tf_test_utils.TracedModuleTestCase):
+
+  def test_finite(self):
+
+    def finite(module):
+      module.finite(np.array([0.0, 1.2, -5.0, np.inf], dtype=np.float32))
+
+    self.compare_backends(finite)
+
+
+if __name__ == "__main__":
+  if hasattr(tf, "enable_v2_behavior"):
+    tf.enable_v2_behavior()
+  tf.test.main()
diff --git a/iree/compiler/Conversion/LinalgToSPIRV/LinalgTileAndFusePass.cpp b/iree/compiler/Conversion/LinalgToSPIRV/LinalgTileAndFusePass.cpp
index 3a69791..6aa0836 100644
--- a/iree/compiler/Conversion/LinalgToSPIRV/LinalgTileAndFusePass.cpp
+++ b/iree/compiler/Conversion/LinalgToSPIRV/LinalgTileAndFusePass.cpp
@@ -276,7 +276,7 @@
                          linalg::LinalgTilingOptions options,
                          ArrayRef<int64_t> workgroupSize,
                          PatternBenefit benefit = 1)
-      : Base(context, options,
+      : Base(context, options.setDistributionOptions(matmulDistributionOptions),
              linalg::LinalgMarker(
                  ArrayRef<Identifier>(),
                  Identifier::get(getWorkgroupMarker(), context)),
diff --git a/iree/compiler/Dialect/VMLA/Conversion/HLOToVMLA/ConvertHLOToVMLA.cpp b/iree/compiler/Dialect/VMLA/Conversion/HLOToVMLA/ConvertHLOToVMLA.cpp
index 48c6fb7..6e3fdd0 100644
--- a/iree/compiler/Dialect/VMLA/Conversion/HLOToVMLA/ConvertHLOToVMLA.cpp
+++ b/iree/compiler/Dialect/VMLA/Conversion/HLOToVMLA/ConvertHLOToVMLA.cpp
@@ -684,6 +684,26 @@
   TypeConverter &typeConverter;
 };
 
+struct FiniteOpConversion : public OpConversionPattern<mhlo::IsFiniteOp> {
+  FiniteOpConversion(MLIRContext *context, TypeConverter &typeConverter)
+      : OpConversionPattern(context), typeConverter(typeConverter) {}
+  // Reads from the converted operand; the original is still tensor-typed.
+  LogicalResult matchAndRewrite(
+      mhlo::IsFiniteOp srcOp, ArrayRef<Value> rawOperands,
+      ConversionPatternRewriter &rewriter) const override {
+    auto input_type =
+        srcOp.getOperand().getType().cast<ShapedType>().getElementType();
+    auto dst = VMLAConversionTarget::allocateOutputBuffer(
+        srcOp.getLoc(), srcOp.getResult(), typeConverter, rewriter);
+    rewriter.createOrFold<IREE::VMLA::FiniteOp>(
+        srcOp.getLoc(), rawOperands[0], dst, TypeAttr::get(input_type));
+    rewriter.replaceOp(srcOp, {dst});
+    return success();
+  }
+
+  TypeConverter &typeConverter;
+};
+
 struct ConvertOpConversion : public OpConversionPattern<mhlo::ConvertOp> {
   ConvertOpConversion(MLIRContext *context, TypeConverter &typeConverter)
       : OpConversionPattern(context), typeConverter(typeConverter) {}
@@ -810,6 +830,7 @@
       context, typeConverter);
 
   patterns.insert<CompareOpConversion>(context, typeConverter);
+  patterns.insert<FiniteOpConversion>(context, typeConverter);
 
   // Ops that are only used for type information that we erase. We can elide
   // these entirely by just passing on their input values.
diff --git a/iree/compiler/Dialect/VMLA/Conversion/HLOToVMLA/test/math_ops.mlir b/iree/compiler/Dialect/VMLA/Conversion/HLOToVMLA/test/math_ops.mlir
index a30898d..ec2a349 100644
--- a/iree/compiler/Dialect/VMLA/Conversion/HLOToVMLA/test/math_ops.mlir
+++ b/iree/compiler/Dialect/VMLA/Conversion/HLOToVMLA/test/math_ops.mlir
@@ -33,3 +33,15 @@
   // CHECK-NEXT: return %[[BUF]]
   return %0 : tensor<4xf32>
 }
+
+// -----
+
+// CHECK-LABEL: @finite
+func @finite(%arg0 : tensor<4xf32>) -> tensor<4xi1> attributes { sym_visibility = "private" } {
+  // CHECK-NEXT: %[[BUF_SZ:.+]] = constant 4
+  // CHECK-NEXT: %[[BUF:.+]] = vmla.buffer.alloc byte_length = %[[BUF_SZ]] : !vmla.buffer
+  // CHECK-NEXT: vmla.finite %arg0, out %[[BUF]] : f32
+  %0 = "mhlo.is_finite"(%arg0) : (tensor<4xf32>) -> tensor<4xi1>
+  // CHECK-NEXT: return %[[BUF]]
+  return %0 : tensor<4xi1>
+}
diff --git a/iree/compiler/Dialect/VMLA/Conversion/VMLAToVM/ConvertVMLAToVM.cpp b/iree/compiler/Dialect/VMLA/Conversion/VMLAToVM/ConvertVMLAToVM.cpp
index ee1f2ea..42c080d 100644
--- a/iree/compiler/Dialect/VMLA/Conversion/VMLAToVM/ConvertVMLAToVM.cpp
+++ b/iree/compiler/Dialect/VMLA/Conversion/VMLAToVM/ConvertVMLAToVM.cpp
@@ -277,6 +277,7 @@
 
   VMLA_TYPED_IMPORT_OP(IREE::VMLA::CmpOp, "vmla.cmp");
   VMLA_SIZED_IMPORT_OP(IREE::VMLA::SelectOp, "vmla.select");
+  VMLA_TYPED_IMPORT_OP(IREE::VMLA::FiniteOp, "vmla.finite");
 
   VMLA_SIZED_IMPORT_OP(IREE::VMLA::CopyOp, "vmla.copy");
   VMLA_SIZED_IMPORT_OP(IREE::VMLA::TransposeOp, "vmla.transpose");
diff --git a/iree/compiler/Dialect/VMLA/IR/VMLAOps.td b/iree/compiler/Dialect/VMLA/IR/VMLAOps.td
index 8b5c425..4a10c08 100644
--- a/iree/compiler/Dialect/VMLA/IR/VMLAOps.td
+++ b/iree/compiler/Dialect/VMLA/IR/VMLAOps.td
@@ -186,6 +186,16 @@
   }];
 }
 
+def VMLA_FiniteOp : VMLA_Op<"finite", [VMLA_OpInterface]> {
+  let arguments = (ins
+    VMLA_Buffer:$src,
+    VMLA_Buffer:$dst,
+    VMLA_FloatTypeAttr:$element_type
+  );
+
+  let assemblyFormat = "$src`,` `out` $dst attr-dict `:` $element_type";
+}
+
 //===----------------------------------------------------------------------===//
 // VMLA Ops: shape/structure
 //===----------------------------------------------------------------------===//
diff --git a/iree/compiler/Dialect/VMLA/vmla.imports.mlir b/iree/compiler/Dialect/VMLA/vmla.imports.mlir
index 5f2177d..73f41f5 100644
--- a/iree/compiler/Dialect/VMLA/vmla.imports.mlir
+++ b/iree/compiler/Dialect/VMLA/vmla.imports.mlir
@@ -95,6 +95,8 @@
 vm.import @select.x16(%cond : !vm.ref<!vmla.buffer>, %lhs : !vm.ref<!vmla.buffer>, %rhs : !vm.ref<!vmla.buffer>, %dst : !vm.ref<!vmla.buffer>)
 vm.import @select.x32(%cond : !vm.ref<!vmla.buffer>, %lhs : !vm.ref<!vmla.buffer>, %rhs : !vm.ref<!vmla.buffer>, %dst : !vm.ref<!vmla.buffer>)
 
+vm.import @finite.f32(%src : !vm.ref<!vmla.buffer>, %dst : !vm.ref<!vmla.buffer>)
+
 //===----------------------------------------------------------------------===//
 // VMLA Ops: shape/structure
 //===----------------------------------------------------------------------===//
diff --git a/iree/hal/vmla/op_kernels.h b/iree/hal/vmla/op_kernels.h
index bdacd77..be14cf3 100644
--- a/iree/hal/vmla/op_kernels.h
+++ b/iree/hal/vmla/op_kernels.h
@@ -120,6 +120,12 @@
                         absl::Span<T> dst_buffer);
 };
 
+struct Finite {
+  template <typename T>
+  static Status Execute(absl::Span<const T> src_buffer,
+                        absl::Span<bool> dst_buffer);
+};
+
 struct Transpose {
   template <typename T>
   static Status Execute(absl::Span<const T> src_buffer,
diff --git a/iree/hal/vmla/op_kernels_generic.h b/iree/hal/vmla/op_kernels_generic.h
index 1d08230..99607f8 100644
--- a/iree/hal/vmla/op_kernels_generic.h
+++ b/iree/hal/vmla/op_kernels_generic.h
@@ -15,6 +15,8 @@
 #ifndef IREE_HAL_VMLA_OP_KERNELS_GENERIC_H_
 #define IREE_HAL_VMLA_OP_KERNELS_GENERIC_H_
 
+#include <cmath>
+
 #include "absl/container/flat_hash_set.h"
 #include "absl/container/inlined_vector.h"
 #include "absl/types/span.h"
@@ -220,6 +222,15 @@
 }
 
 template <typename T>
+Status Finite::Execute(absl::Span<const T> src_buffer,
+                       absl::Span<bool> dst_buffer) {
+  for (size_t i = 0; i < dst_buffer.size(); ++i) {
+    dst_buffer[i] = std::isfinite(src_buffer[i]);
+  }
+  return OkStatus();
+}
+
+template <typename T>
 Status Transpose::Execute(absl::Span<const T> src_buffer,
                           absl::Span<T> dst_buffer, ShapeSpan src_shape,
                           absl::Span<const int32_t> perm) {
diff --git a/iree/hal/vmla/vmla_module.cc b/iree/hal/vmla/vmla_module.cc
index 6354eaf..3f42662 100644
--- a/iree/hal/vmla/vmla_module.cc
+++ b/iree/hal/vmla/vmla_module.cc
@@ -415,6 +415,13 @@
   IREE_VMLA_SELECT_OP(SelectX16, uint16_t);
   IREE_VMLA_SELECT_OP(SelectX32, uint32_t);
 
+#define IREE_VMLA_UNARY_PREDICATE_OP(name, kernel, type)            \
+  Status name(vm::ref<Buffer> src, vm::ref<Buffer> dst) {           \
+    IREE_TRACE_SCOPE0("VMLAModuleState::" #name);                   \
+    return kernel::Execute<type>(src->As<type>(), dst->As<bool>()); \
+  }
+  IREE_VMLA_UNARY_PREDICATE_OP(FiniteF32, kernels::Finite, float);
+
   //===--------------------------------------------------------------------===//
   // VMLA Ops: shape/structure
   //===--------------------------------------------------------------------===//
@@ -931,6 +938,7 @@
     vm::MakeNativeFunction("clamp.f32", &VMLAModuleState::ClampF32),
     vm::MakeNativeFunction("floor.f32", &VMLAModuleState::FloorF32),
     vm::MakeNativeFunction("ceil.f32", &VMLAModuleState::CeilF32),
+    vm::MakeNativeFunction("finite.f32", &VMLAModuleState::FiniteF32),
 
     vm::MakeNativeFunction("convert.i8.i16", &VMLAModuleState::ConvertI8I16),
     vm::MakeNativeFunction("convert.i8.i32", &VMLAModuleState::ConvertI8I32),
diff --git a/iree/test/e2e/xla_ops/dot_general.mlir b/iree/test/e2e/xla_ops/dot_general.mlir
index 16073eb..a3f29a2 100644
--- a/iree/test/e2e/xla_ops/dot_general.mlir
+++ b/iree/test/e2e/xla_ops/dot_general.mlir
@@ -79,3 +79,19 @@
           [15.0, 30.0, 45.0, 60.0]]]> : tensor<2x2x4xf32>) : tensor<2x2x4xf32>
   return
 }
+
+func @large_dot_general() attributes { iree.module.export } {
+  %lhs = iree.unfoldable_constant dense<1.0> : tensor<4x32x1024xf32>
+  %rhs = iree.unfoldable_constant dense<0.4> : tensor<4x1024x64xf32>
+  %res = "mhlo.dot_general"(%lhs, %rhs) {
+           dot_dimension_numbers = {
+             lhs_batching_dimensions = dense<0> : tensor<1xi64>,
+             lhs_contracting_dimensions = dense<2> : tensor<1xi64>,
+             rhs_batching_dimensions = dense<0> : tensor<1xi64>,
+             rhs_contracting_dimensions = dense<1> : tensor<1xi64>
+           },
+           precision_config = ["DEFAULT", "DEFAULT"]
+         } : (tensor<4x32x1024xf32>, tensor<4x1024x64xf32>) -> tensor<4x32x64xf32>
+  check.expect_almost_eq_const(%res, dense<409.596> : tensor<4x32x64xf32>) : tensor<4x32x64xf32>
+  return
+}
diff --git a/iree/test/e2e/xla_ops/finite.mlir b/iree/test/e2e/xla_ops/finite.mlir
new file mode 100644
index 0000000..68d8168
--- /dev/null
+++ b/iree/test/e2e/xla_ops/finite.mlir
@@ -0,0 +1,11 @@
+func @f32() attributes { iree.module.export } {
+  %0 = iree.unfoldable_constant dense<[1.0, 6.0, -6.0, 0.0]> : tensor<4xf32>
+  %1 = iree.unfoldable_constant dense<[0.0, 2.0, 3.0, 4.0]> : tensor<4xf32>
+  %2 = "mhlo.divide"(%0, %1) : (tensor<4xf32>, tensor<4xf32>) -> tensor<4xf32>
+  %result = "mhlo.is_finite"(%2) : (tensor<4xf32>) -> tensor<4xi1>
+  %c0 = iree.unfoldable_constant dense<0> : tensor<4xi8>
+  %c1 = iree.unfoldable_constant dense<1> : tensor<4xi8>
+  %output = "mhlo.select"(%result, %c1, %c0) : (tensor<4xi1>, tensor<4xi8>, tensor<4xi8>) -> tensor<4xi8>
+  check.expect_eq_const(%output, dense<[0, 1, 1, 1]> : tensor<4xi8>) : tensor<4xi8>
+  return
+}