Merge google -> main (#3050)
* 2730f436 Merge pull request #3048 from GMNGeoffrey:main-to-google
* 7f8f2c51 Opt-in to the global dialect registry
* 051a9e2f Synchronize submodules
* 11d565ee Integrate LLVM at llvm/llvm-project@ffd0b31c7cba
* 833983bc Synchronize submodules
* 5ae0d624 Integrate LLVM at llvm/llvm-project@1d3d9b9cd808
* 80c32bdf Synchronize submodules
* 731139cf Integrate LLVM at llvm/llvm-project@646f19bb9dc8
* de995a98 Synchronize submodules
* 703e3782 Integrate LLVM at llvm/llvm-project@bc3d4d9ed783
* 92689b8d Adds support for invoking a function through the java api
diff --git a/integrations/tensorflow/e2e/BUILD b/integrations/tensorflow/e2e/BUILD
index 0c2dbe3..e7e7454 100644
--- a/integrations/tensorflow/e2e/BUILD
+++ b/integrations/tensorflow/e2e/BUILD
@@ -71,6 +71,7 @@
"dynamic_mlp_relu_test.py",
"dynamic_mlp_test.py",
"fill_test.py", # TODO(jennik): Get this test working on IREE.
+ "finite_test.py",
"mandelbrot_test.py", # TODO(silvasean): Get this working on IREE.
"matrix_ops_test.py",
"range_test.py",
@@ -87,6 +88,7 @@
"dynamic_mlp_relu_test.py",
"dynamic_mlp_test.py",
"fill_test.py", # TODO(jennik): Get this test working on IREE.
+ "finite_test.py",
"mandelbrot_test.py", # TODO(silvasean): Get this working on IREE.
"matrix_ops_test.py",
"range_test.py",
diff --git a/integrations/tensorflow/e2e/finite_test.py b/integrations/tensorflow/e2e/finite_test.py
new file mode 100644
index 0000000..ff62f3a
--- /dev/null
+++ b/integrations/tensorflow/e2e/finite_test.py
@@ -0,0 +1,41 @@
+# Copyright 2020 Google LLC
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# https://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import numpy as np
+from pyiree.tf.support import tf_test_utils
+import tensorflow.compat.v2 as tf
+
+
+class FiniteModule(tf.Module):
+ # Module exposing tf.math.is_finite on a fixed-shape float32 input.
+
+ @tf.function(input_signature=[tf.TensorSpec([4], tf.float32)])
+ def finite(self, x):
+ # x is declared as a 4-element float32 tensor by the input signature above.
+ return tf.math.is_finite(x)
+
+
+@tf_test_utils.compile_module(FiniteModule)
+class FiniteTest(tf_test_utils.TracedModuleTestCase):
+ # Calls FiniteModule.finite on a fixed input and hands the traced call to
+ # compare_backends.
+
+ def test_finite(self):
+
+ def finite(module):
+ # Three finite values plus +inf; NaN and -inf are not exercised here.
+ module.finite(np.array([0.0, 1.2, -5.0, np.inf], dtype=np.float32))
+
+ self.compare_backends(finite)
+
+
+if __name__ == "__main__":
+ # enable_v2_behavior is only called when the imported tf module provides it.
+ if hasattr(tf, "enable_v2_behavior"):
+ tf.enable_v2_behavior()
+ tf.test.main()
diff --git a/iree/compiler/Conversion/LinalgToSPIRV/LinalgTileAndFusePass.cpp b/iree/compiler/Conversion/LinalgToSPIRV/LinalgTileAndFusePass.cpp
index 3a69791..6aa0836 100644
--- a/iree/compiler/Conversion/LinalgToSPIRV/LinalgTileAndFusePass.cpp
+++ b/iree/compiler/Conversion/LinalgToSPIRV/LinalgTileAndFusePass.cpp
@@ -276,7 +276,7 @@
linalg::LinalgTilingOptions options,
ArrayRef<int64_t> workgroupSize,
PatternBenefit benefit = 1)
- : Base(context, options,
+ : Base(context, options.setDistributionOptions(matmulDistributionOptions),
linalg::LinalgMarker(
ArrayRef<Identifier>(),
Identifier::get(getWorkgroupMarker(), context)),
diff --git a/iree/compiler/Dialect/VMLA/Conversion/HLOToVMLA/ConvertHLOToVMLA.cpp b/iree/compiler/Dialect/VMLA/Conversion/HLOToVMLA/ConvertHLOToVMLA.cpp
index 48c6fb7..6e3fdd0 100644
--- a/iree/compiler/Dialect/VMLA/Conversion/HLOToVMLA/ConvertHLOToVMLA.cpp
+++ b/iree/compiler/Dialect/VMLA/Conversion/HLOToVMLA/ConvertHLOToVMLA.cpp
@@ -684,6 +684,26 @@
TypeConverter &typeConverter;
};
+// Lowers mhlo.is_finite to vmla.finite: allocates an output buffer for the
+// result and emits a vmla.finite op that writes into it, tagged with the
+// element type of the source operand.
+struct FiniteOpConversion : public OpConversionPattern<mhlo::IsFiniteOp> {
+  FiniteOpConversion(MLIRContext *context, TypeConverter &typeConverter)
+      : OpConversionPattern(context), typeConverter(typeConverter) {}
+
+  LogicalResult matchAndRewrite(
+      mhlo::IsFiniteOp srcOp, ArrayRef<Value> rawOperands,
+      ConversionPatternRewriter &rewriter) const override {
+    // The element type is taken from the original operand's shaped type.
+    auto input_type =
+        srcOp.getOperand().getType().cast<ShapedType>().getElementType();
+    auto dst = VMLAConversionTarget::allocateOutputBuffer(
+        srcOp.getLoc(), srcOp.getResult(), typeConverter, rewriter);
+    // Build the new op from the operand already remapped by the conversion
+    // framework (rawOperands), not from the original operand of srcOp.
+    rewriter.createOrFold<IREE::VMLA::FiniteOp>(
+        srcOp.getLoc(), rawOperands[0], dst, TypeAttr::get(input_type));
+    rewriter.replaceOp(srcOp, {dst});
+    return success();
+  }
+
+  TypeConverter &typeConverter;
+};
+
struct ConvertOpConversion : public OpConversionPattern<mhlo::ConvertOp> {
ConvertOpConversion(MLIRContext *context, TypeConverter &typeConverter)
: OpConversionPattern(context), typeConverter(typeConverter) {}
@@ -810,6 +830,7 @@
context, typeConverter);
patterns.insert<CompareOpConversion>(context, typeConverter);
+ patterns.insert<FiniteOpConversion>(context, typeConverter);
// Ops that are only used for type information that we erase. We can elide
// these entirely by just passing on their input values.
diff --git a/iree/compiler/Dialect/VMLA/Conversion/HLOToVMLA/test/math_ops.mlir b/iree/compiler/Dialect/VMLA/Conversion/HLOToVMLA/test/math_ops.mlir
index a30898d..ec2a349 100644
--- a/iree/compiler/Dialect/VMLA/Conversion/HLOToVMLA/test/math_ops.mlir
+++ b/iree/compiler/Dialect/VMLA/Conversion/HLOToVMLA/test/math_ops.mlir
@@ -33,3 +33,15 @@
// CHECK-NEXT: return %[[BUF]]
return %0 : tensor<4xf32>
}
+
+// -----
+
+// Lowering of mhlo.is_finite on tensor<4xf32>: expects a buffer allocation
+// sized by a constant matching 4, then a vmla.finite tagged with f32 writing
+// into that buffer, which is also the returned value.
+// CHECK-LABEL: @finite
+func @finite(%arg0 : tensor<4xf32>) -> tensor<4xi1> attributes { sym_visibility = "private" } {
+ // CHECK-NEXT: %[[BUF_SZ:.+]] = constant 4
+ // CHECK-NEXT: %[[BUF:.+]] = vmla.buffer.alloc byte_length = %[[BUF_SZ]] : !vmla.buffer
+ // CHECK-NEXT: vmla.finite %arg0, out %[[BUF]] : f32
+ %0 = "mhlo.is_finite"(%arg0) : (tensor<4xf32>) -> tensor<4xi1>
+ // CHECK-NEXT: return %[[BUF]]
+ return %0 : tensor<4xi1>
+}
diff --git a/iree/compiler/Dialect/VMLA/Conversion/VMLAToVM/ConvertVMLAToVM.cpp b/iree/compiler/Dialect/VMLA/Conversion/VMLAToVM/ConvertVMLAToVM.cpp
index ee1f2ea..42c080d 100644
--- a/iree/compiler/Dialect/VMLA/Conversion/VMLAToVM/ConvertVMLAToVM.cpp
+++ b/iree/compiler/Dialect/VMLA/Conversion/VMLAToVM/ConvertVMLAToVM.cpp
@@ -277,6 +277,7 @@
VMLA_TYPED_IMPORT_OP(IREE::VMLA::CmpOp, "vmla.cmp");
VMLA_SIZED_IMPORT_OP(IREE::VMLA::SelectOp, "vmla.select");
+ VMLA_TYPED_IMPORT_OP(IREE::VMLA::FiniteOp, "vmla.finite");
VMLA_SIZED_IMPORT_OP(IREE::VMLA::CopyOp, "vmla.copy");
VMLA_SIZED_IMPORT_OP(IREE::VMLA::TransposeOp, "vmla.transpose");
diff --git a/iree/compiler/Dialect/VMLA/IR/VMLAOps.td b/iree/compiler/Dialect/VMLA/IR/VMLAOps.td
index 8b5c425..4a10c08 100644
--- a/iree/compiler/Dialect/VMLA/IR/VMLAOps.td
+++ b/iree/compiler/Dialect/VMLA/IR/VMLAOps.td
@@ -186,6 +186,16 @@
}];
}
+// Finiteness predicate op. Takes a source buffer, a destination buffer and a
+// float element type attribute; printed as
+// `<src>, out <dst> : <element_type>` after the op name.
+def VMLA_FiniteOp : VMLA_Op<"finite", [VMLA_OpInterface]> {
+ let arguments = (ins
+ VMLA_Buffer:$src,
+ VMLA_Buffer:$dst,
+ VMLA_FloatTypeAttr:$element_type
+ );
+
+ let assemblyFormat = "$src`,` `out` $dst attr-dict `:` $element_type";
+}
+
//===----------------------------------------------------------------------===//
// VMLA Ops: shape/structure
//===----------------------------------------------------------------------===//
diff --git a/iree/compiler/Dialect/VMLA/vmla.imports.mlir b/iree/compiler/Dialect/VMLA/vmla.imports.mlir
index 5f2177d..73f41f5 100644
--- a/iree/compiler/Dialect/VMLA/vmla.imports.mlir
+++ b/iree/compiler/Dialect/VMLA/vmla.imports.mlir
@@ -95,6 +95,8 @@
vm.import @select.x16(%cond : !vm.ref<!vmla.buffer>, %lhs : !vm.ref<!vmla.buffer>, %rhs : !vm.ref<!vmla.buffer>, %dst : !vm.ref<!vmla.buffer>)
vm.import @select.x32(%cond : !vm.ref<!vmla.buffer>, %lhs : !vm.ref<!vmla.buffer>, %rhs : !vm.ref<!vmla.buffer>, %dst : !vm.ref<!vmla.buffer>)
+// f32 finiteness import: takes a source buffer ref and a destination buffer ref.
+vm.import @finite.f32(%src : !vm.ref<!vmla.buffer>, %dst : !vm.ref<!vmla.buffer>)
+
//===----------------------------------------------------------------------===//
// VMLA Ops: shape/structure
//===----------------------------------------------------------------------===//
diff --git a/iree/hal/vmla/op_kernels.h b/iree/hal/vmla/op_kernels.h
index bdacd77..be14cf3 100644
--- a/iree/hal/vmla/op_kernels.h
+++ b/iree/hal/vmla/op_kernels.h
@@ -120,6 +120,12 @@
absl::Span<T> dst_buffer);
};
+// Finiteness predicate kernel: consumes a span of T and writes its results
+// into a span of bool.
+struct Finite {
+ template <typename T>
+ static Status Execute(absl::Span<const T> src_buffer,
+ absl::Span<bool> dst_buffer);
+};
+
struct Transpose {
template <typename T>
static Status Execute(absl::Span<const T> src_buffer,
diff --git a/iree/hal/vmla/op_kernels_generic.h b/iree/hal/vmla/op_kernels_generic.h
index 1d08230..99607f8 100644
--- a/iree/hal/vmla/op_kernels_generic.h
+++ b/iree/hal/vmla/op_kernels_generic.h
@@ -15,6 +15,8 @@
#ifndef IREE_HAL_VMLA_OP_KERNELS_GENERIC_H_
#define IREE_HAL_VMLA_OP_KERNELS_GENERIC_H_
+#include <cmath>
+
#include "absl/container/flat_hash_set.h"
#include "absl/container/inlined_vector.h"
#include "absl/types/span.h"
@@ -220,6 +222,15 @@
}
+// Stores std::isfinite(src_buffer[i]) into dst_buffer[i] for every index of
+// dst_buffer. src_buffer is indexed with the same positions and its size is
+// not checked here.
template <typename T>
+Status Finite::Execute(absl::Span<const T> src_buffer,
+                       absl::Span<bool> dst_buffer) {
+  for (size_t index = 0, count = dst_buffer.size(); index < count; ++index) {
+    const T &value = src_buffer[index];
+    dst_buffer[index] = std::isfinite(value);
+  }
+  return OkStatus();
+}
+
+template <typename T>
Status Transpose::Execute(absl::Span<const T> src_buffer,
absl::Span<T> dst_buffer, ShapeSpan src_shape,
absl::Span<const int32_t> perm) {
diff --git a/iree/hal/vmla/vmla_module.cc b/iree/hal/vmla/vmla_module.cc
index 6354eaf..3f42662 100644
--- a/iree/hal/vmla/vmla_module.cc
+++ b/iree/hal/vmla/vmla_module.cc
@@ -415,6 +415,13 @@
IREE_VMLA_SELECT_OP(SelectX16, uint16_t);
IREE_VMLA_SELECT_OP(SelectX32, uint32_t);
+// Defines a method `name` that runs `kernel::Execute<type>` with `src` viewed
+// as `type` elements and `dst` viewed as bool elements, inside a trace scope
+// labelled "VMLAModuleState::<name>".
+#define IREE_VMLA_UNARY_PREDICATE_OP(name, kernel, type) \
+ Status name(vm::ref<Buffer> src, vm::ref<Buffer> dst) { \
+ IREE_TRACE_SCOPE0("VMLAModuleState::" #name); \
+ return kernel::Execute<type>(src->As<type>(), dst->As<bool>()); \
+ }
+ IREE_VMLA_UNARY_PREDICATE_OP(FiniteF32, kernels::Finite, float);
+
//===--------------------------------------------------------------------===//
// VMLA Ops: shape/structure
//===--------------------------------------------------------------------===//
@@ -931,6 +938,7 @@
vm::MakeNativeFunction("clamp.f32", &VMLAModuleState::ClampF32),
vm::MakeNativeFunction("floor.f32", &VMLAModuleState::FloorF32),
vm::MakeNativeFunction("ceil.f32", &VMLAModuleState::CeilF32),
+ vm::MakeNativeFunction("finite.f32", &VMLAModuleState::FiniteF32),
vm::MakeNativeFunction("convert.i8.i16", &VMLAModuleState::ConvertI8I16),
vm::MakeNativeFunction("convert.i8.i32", &VMLAModuleState::ConvertI8I32),
diff --git a/iree/test/e2e/xla_ops/dot_general.mlir b/iree/test/e2e/xla_ops/dot_general.mlir
index 16073eb..a3f29a2 100644
--- a/iree/test/e2e/xla_ops/dot_general.mlir
+++ b/iree/test/e2e/xla_ops/dot_general.mlir
@@ -79,3 +79,19 @@
[15.0, 30.0, 45.0, 60.0]]]> : tensor<2x2x4xf32>) : tensor<2x2x4xf32>
return
}
+
+// Batched dot_general (batch dim 0) contracting lhs dim 2 with rhs dim 1, both
+// of size 1024. Every lhs element is 1.0 and every rhs element is 0.4, so each
+// output element is a sum of 1024 products of 1.0 and 0.4; the result is
+// compared approximately against 409.596.
+func @large_dot_general() attributes { iree.module.export } {
+ %lhs = iree.unfoldable_constant dense<1.0> : tensor<4x32x1024xf32>
+ %rhs = iree.unfoldable_constant dense<0.4> : tensor<4x1024x64xf32>
+ %res = "mhlo.dot_general"(%lhs, %rhs) {
+ dot_dimension_numbers = {
+ lhs_batching_dimensions = dense<0> : tensor<1xi64>,
+ lhs_contracting_dimensions = dense<2> : tensor<1xi64>,
+ rhs_batching_dimensions = dense<0> : tensor<1xi64>,
+ rhs_contracting_dimensions = dense<1> : tensor<1xi64>
+ },
+ precision_config = ["DEFAULT", "DEFAULT"]
+ } : (tensor<4x32x1024xf32>, tensor<4x1024x64xf32>) -> tensor<4x32x64xf32>
+ check.expect_almost_eq_const(%res, dense<409.596> : tensor<4x32x64xf32>) : tensor<4x32x64xf32>
+ return
+}
diff --git a/iree/test/e2e/xla_ops/finite.mlir b/iree/test/e2e/xla_ops/finite.mlir
new file mode 100644
index 0000000..68d8168
--- /dev/null
+++ b/iree/test/e2e/xla_ops/finite.mlir
@@ -0,0 +1,11 @@
+// Divides [1, 6, -6, 0] by [0, 2, 3, 4] so that only the first quotient (1/0)
+// is non-finite, then maps the i1 result of mhlo.is_finite to i8 values 1/0
+// through mhlo.select before comparing against [0, 1, 1, 1].
+func @f32() attributes { iree.module.export } {
+ %0 = iree.unfoldable_constant dense<[1.0, 6.0, -6.0, 0.0]> : tensor<4xf32>
+ %1 = iree.unfoldable_constant dense<[0.0, 2.0, 3.0, 4.0]> : tensor<4xf32>
+ %2 = "mhlo.divide"(%0, %1) : (tensor<4xf32>, tensor<4xf32>) -> tensor<4xf32>
+ %result = "mhlo.is_finite"(%2) : (tensor<4xf32>) -> tensor<4xi1>
+ %c0 = iree.unfoldable_constant dense<0> : tensor<4xi8>
+ %c1 = iree.unfoldable_constant dense<1> : tensor<4xi8>
+ %output = "mhlo.select"(%result, %c1, %c0) : (tensor<4xi1>, tensor<4xi8>, tensor<4xi8>) -> tensor<4xi8>
+ check.expect_eq_const(%output, dense<[0, 1, 1, 1]> : tensor<4xi8>) : tensor<4xi8>
+ return
+}