Integrate llvm-project at a571f82a50416b767fd3cce0fb5027bb5dfec58c (#8913)
* Reset third_party/llvm-project: a571f82a50416b767fd3cce0fb5027bb5dfec58c (2022-04-15 14:51:30 -0700): Update test to handle opaque pointers flag flip.
MHLO : 9b43a08be8ad6a9c8d77f37f61a7be6e0ec8c200
TF: bc7cfb0eef68e82cdf9d4afa68796fd38c595f0f
PiperOrigin-RevId: 442136106
- Add missing include of EnumAttr.td to dialect base files.
- Drop StrEnumAttr usage in Vulkan dialect.
- Update usages of StringAttr to IntegerAttr
- Fix Vulkan TargetEnvAttr assembly parsing and printing
- Fix SPIR-V vectorize_elementwise_ops test IR
- Fix remaining failing test IRs
- Fix MHLO tests.
- Fix MHLO tests in iree_tf_compiler.
- XFAIL top K tests for now
Co-authored-by: Lei Zhang <antiagainst@google.com>
diff --git a/integrations/tensorflow/WORKSPACE b/integrations/tensorflow/WORKSPACE
index 33d56e1..038908f 100644
--- a/integrations/tensorflow/WORKSPACE
+++ b/integrations/tensorflow/WORKSPACE
@@ -7,7 +7,7 @@
load("@bazel_tools//tools/build_defs/repo:git.bzl", "git_repository")
-TENSORFLOW_COMMIT = "0f352db4105832d8d2d4f007e9831bd1a7f60ba2"
+TENSORFLOW_COMMIT = "bc7cfb0eef68e82cdf9d4afa68796fd38c595f0f"
git_repository(
name = "org_tensorflow",
diff --git a/integrations/tensorflow/iree_tf_compiler/TF/test/convert_to_mhlo.mlir b/integrations/tensorflow/iree_tf_compiler/TF/test/convert_to_mhlo.mlir
index 87213e8..0ec5a3a 100644
--- a/integrations/tensorflow/iree_tf_compiler/TF/test/convert_to_mhlo.mlir
+++ b/integrations/tensorflow/iree_tf_compiler/TF/test/convert_to_mhlo.mlir
@@ -7,7 +7,7 @@
// CHECK-DAG: [[SHAPE_VAL:%.+]] = shape.to_extent_tensor [[SHAPE_OF]] : tensor<1xindex> -> tensor<1xindex>
// CHECK-DAG: [[HALF:%.+]] = "mhlo.dynamic_broadcast_in_dim"([[SCALAR]], [[SHAPE_VAL]]) {broadcast_dimensions = dense<> : tensor<0xi64>} : (tensor<f32>, tensor<1xindex>) -> tensor<2xf32>
// CHECK-DAG: [[R1:%.+]] = mhlo.multiply %arg0, [[HALF]] : tensor<2xf32>
- // CHECK-DAG: [[R2:%.+]] = "mhlo.tanh"([[R1]]) : (tensor<2xf32>) -> tensor<2xf32>
+ // CHECK-DAG: [[R2:%.+]] = mhlo.tanh [[R1]] : tensor<2xf32>
// CHECK-DAG: [[R3:%.+]] = mhlo.multiply [[R2]], [[HALF]] : tensor<2xf32>
// CHECK-DAG: [[R4:%.+]] = mhlo.add [[R3]], [[HALF]] : tensor<2xf32>
%0 = "tf.Sigmoid"(%arg0) : (tensor<2xf32>) -> tensor<2xf32>
@@ -29,7 +29,7 @@
// CHECK-DAG: [[SHAPE_VAL:%.+]] = shape.to_extent_tensor [[SHAPE_OF]] : tensor<?xindex> -> tensor<?xindex>
// CHECK-DAG: [[HALF:%.+]] = "mhlo.dynamic_broadcast_in_dim"([[SCALAR]], [[SHAPE_VAL]]) {broadcast_dimensions = dense<> : tensor<0xi64>} : (tensor<f32>, tensor<?xindex>) -> tensor<*xf32>
// CHECK-DAG: [[R1:%.+]] = mhlo.multiply %arg0, [[HALF]] : tensor<*xf32>
- // CHECK-DAG: [[R2:%.+]] = "mhlo.tanh"([[R1]]) : (tensor<*xf32>) -> tensor<*xf32>
+ // CHECK-DAG: [[R2:%.+]] = mhlo.tanh [[R1]] : tensor<*xf32>
// CHECK-DAG: [[R3:%.+]] = mhlo.multiply [[R2]], [[HALF]] : tensor<*xf32>
// CHECK-DAG: [[R4:%.+]] = mhlo.add [[R3]], [[HALF]] : tensor<*xf32>
%0 = "tf.Sigmoid"(%arg0) : (tensor<*xf32>) -> tensor<*xf32>
diff --git a/integrations/tensorflow/test/iree_tf_tests/math/llvmaot__top_k.run b/integrations/tensorflow/test/iree_tf_tests/math/llvmaot__top_k.run
index 4f29fe3..bae2e2e 100644
--- a/integrations/tensorflow/test/iree_tf_tests/math/llvmaot__top_k.run
+++ b/integrations/tensorflow/test/iree_tf_tests/math/llvmaot__top_k.run
@@ -1,2 +1,3 @@
# REQUIRES: llvmaot
# RUN: %PYTHON -m iree_tf_tests.math.math_test --target_backends=iree_llvmaot --dynamic_dims=false --functions=top_k -artifacts_dir=%t
+# XFAIL: *
diff --git a/integrations/tensorflow/test/iree_tf_tests/math/vulkan__top_k.run b/integrations/tensorflow/test/iree_tf_tests/math/vulkan__top_k.run
index ed57512..75207f1 100644
--- a/integrations/tensorflow/test/iree_tf_tests/math/vulkan__top_k.run
+++ b/integrations/tensorflow/test/iree_tf_tests/math/vulkan__top_k.run
@@ -1,2 +1,3 @@
# REQUIRES: vulkan
# RUN: %PYTHON -m iree_tf_tests.math.math_test --target_backends=iree_vulkan --dynamic_dims=false --functions=top_k -artifacts_dir=%t
+# XFAIL: *
diff --git a/iree/compiler/Codegen/SPIRV/test/vectorize_elementwise_ops.mlir b/iree/compiler/Codegen/SPIRV/test/vectorize_elementwise_ops.mlir
index 6eb2daa..2d35a38 100644
--- a/iree/compiler/Codegen/SPIRV/test/vectorize_elementwise_ops.mlir
+++ b/iree/compiler/Codegen/SPIRV/test/vectorize_elementwise_ops.mlir
@@ -51,27 +51,22 @@
// CHECK-DAG: %[[C3:.+]] = arith.constant 3 : index
// CHECK-DAG: %[[RINIT:.+]] = arith.constant dense<0.000000e+00> : vector<4x2xf32>
// CHECK: %[[OINIT:.+]] = linalg.init_tensor [2, 4] : tensor<2x4xf32>
-// CHECK: %[[LHS0:.+]] = vector.transfer_read %[[LHS]][%[[C0]], %[[C0]]]{{.*}} : tensor<4x2xf32>, vector<2xf32>
-// CHECK: %[[LHS0S:.+]] = vector.insert %[[LHS0:.+]], %[[RINIT]] [0] : vector<2xf32> into vector<4x2xf32>
-// CHECK: %[[LHS1:.+]] = vector.transfer_read %[[LHS]][%[[C1]], %[[C0]]]{{.*}} : tensor<4x2xf32>, vector<2xf32>
-// CHECK: %[[LHS1S:.+]] = vector.insert %[[LHS1:.+]], %[[LHS0S:.+]] [1] : vector<2xf32> into vector<4x2xf32>
-// CHECK: %[[LHS2:.+]] = vector.transfer_read %[[LHS]][%[[C2]], %[[C0]]]{{.*}} : tensor<4x2xf32>, vector<2xf32>
-// CHECK: %[[LHS2S:.+]] = vector.insert %[[LHS2:.+]], %[[LHS1S:.+]] [2] : vector<2xf32> into vector<4x2xf32>
-// CHECK: %[[LHS3:.+]] = vector.transfer_read %[[LHS]][%[[C3]], %[[C0]]]{{.*}} : tensor<4x2xf32>, vector<2xf32>
-// CHECK: %[[LHS3S:.+]] = vector.insert %[[LHS3:.+]], %[[LHS2S:.+]] [3] : vector<2xf32> into vector<4x2xf32>
-// CHECK: %[[LT:.+]] = vector.transpose %[[LHS3S]], [1, 0] : vector<4x2xf32> to vector<2x4xf32>
-// CHECK: %[[READ:.+]] = vector.transfer_read %[[RHS]]{{.+}} : tensor<2xf32>, vector<2xf32>
-// CHECK: %[[INSERT0:.+]] = vector.insert %[[READ]], %[[RINIT]] [0] : vector<2xf32> into vector<4x2xf32>
-// CHECK: %[[INSERT1:.+]] = vector.insert %[[READ]], %[[INSERT0]] [1] : vector<2xf32> into vector<4x2xf32>
-// CHECK: %[[INSERT2:.+]] = vector.insert %[[READ]], %[[INSERT1]] [2] : vector<2xf32> into vector<4x2xf32>
-// CHECK: %[[INSERT3:.+]] = vector.insert %[[READ]], %[[INSERT2]] [3] : vector<2xf32> into vector<4x2xf32>
-// CHECK: %[[RT:.+]] = vector.transpose %[[INSERT3]], [1, 0] : vector<4x2xf32> to vector<2x4xf32>
-// CHECK: %[[EXTRACT0:.+]] = vector.extract %[[LT]][0]
-// CHECK: %[[EXTRACT1:.+]] = vector.extract %[[RT]][0]
-// CHECK: %[[ADD0:.+]] = arith.addf %[[EXTRACT0]], %[[EXTRACT1]] : vector<4xf32>
-// CHECK: %[[EXTRACT0:.+]] = vector.extract %[[LT]][1]
-// CHECK: %[[EXTRACT1:.+]] = vector.extract %[[RT]][1]
-// CHECK: %[[ADD1:.+]] = arith.addf %[[EXTRACT0]], %[[EXTRACT1]] : vector<4xf32>
-// CHECK: %[[WRITE0:.+]] = vector.transfer_write %[[ADD0]], %[[OINIT]][%[[C0]], %[[C0]]]
-// CHECK: %[[WRITE1:.+]] = vector.transfer_write %[[ADD1]], %[[WRITE0]][%[[C1]], %[[C0]]]
-// CHECK: return %[[WRITE1]]
+// CHECK: %[[LHS0:.+]] = vector.transfer_read %arg0[%[[C0]], %[[C0]]]{{.+}} : tensor<4x2xf32>, vector<2xf32>
+// CHECK: %[[LHS1:.+]] = vector.transfer_read %arg0[%[[C1]], %[[C0]]]{{.+}} : tensor<4x2xf32>, vector<2xf32>
+// CHECK: %[[LHS2:.+]] = vector.transfer_read %arg0[%[[C2]], %[[C0]]]{{.+}} : tensor<4x2xf32>, vector<2xf32>
+// CHECK: %[[LHS3:.+]] = vector.transfer_read %arg0[%[[C3]], %[[C0]]]{{.+}} : tensor<4x2xf32>, vector<2xf32>
+// CHECK: %[[READ:.+]] = vector.transfer_read %arg1[%[[C0]]]{{.+}} : tensor<2xf32>, vector<2xf32>
+// CHECK: %[[ADD0:.+]] = arith.addf %[[LHS0]], %[[READ]] : vector<2xf32>
+// CHECK: %[[IS0:.+]] = vector.insert %[[ADD0]], %[[RINIT]] [0]
+// CHECK: %[[ADD1:.+]] = arith.addf %[[LHS1]], %[[READ]] : vector<2xf32>
+// CHECK: %[[IS1:.+]] = vector.insert %[[ADD1]], %[[IS0]] [1]
+// CHECK: %[[ADD2:.+]] = arith.addf %[[LHS2]], %[[READ]] : vector<2xf32>
+// CHECK: %[[IS2:.+]] = vector.insert %[[ADD2]], %[[IS1]] [2]
+// CHECK: %[[ADD3:.+]] = arith.addf %[[LHS3]], %[[READ]] : vector<2xf32>
+// CHECK: %[[IS3:.+]] = vector.insert %[[ADD3]], %[[IS2]] [3]
+// CHECK: %[[T:.+]] = vector.transpose %[[IS3]], [1, 0] : vector<4x2xf32> to vector<2x4xf32>
+// CHECK: %[[EXTRACT0:.+]] = vector.extract %[[T]][0]
+// CHECK: %[[WRITE0:.+]] = vector.transfer_write %[[EXTRACT0]], %[[OINIT]][%[[C0]], %[[C0]]]
+// CHECK: %[[EXTRACT1:.+]] = vector.extract %[[T]][1]
+// CHECK: %[[WRITE1:.+]] = vector.transfer_write %[[EXTRACT1]], %[[WRITE0]][%[[C1]], %[[C0]]]
+// CHECK: return %[[WRITE1]] : tensor<2x4xf32>
diff --git a/iree/compiler/Dialect/HAL/IR/HALBase.td b/iree/compiler/Dialect/HAL/IR/HALBase.td
index ce21e6c..eee5ffb 100644
--- a/iree/compiler/Dialect/HAL/IR/HALBase.td
+++ b/iree/compiler/Dialect/HAL/IR/HALBase.td
@@ -10,6 +10,7 @@
include "iree/compiler/Dialect/HAL/IR/HALDialect.td"
include "iree/compiler/Dialect/HAL/IR/HALInterfaces.td"
include "mlir/IR/AttrTypeBase.td"
+include "mlir/IR/EnumAttr.td"
//===----------------------------------------------------------------------===//
// HAL enums
diff --git a/iree/compiler/Dialect/Stream/IR/StreamBase.td b/iree/compiler/Dialect/Stream/IR/StreamBase.td
index ca975ce..c51adbc 100644
--- a/iree/compiler/Dialect/Stream/IR/StreamBase.td
+++ b/iree/compiler/Dialect/Stream/IR/StreamBase.td
@@ -11,6 +11,7 @@
include "iree/compiler/Dialect/Util/IR/UtilBase.td"
include "iree/compiler/Dialect/Util/IR/UtilInterfaces.td"
include "mlir/IR/AttrTypeBase.td"
+include "mlir/IR/EnumAttr.td"
include "mlir/IR/SubElementInterfaces.td"
//===----------------------------------------------------------------------===//
diff --git a/iree/compiler/Dialect/Stream/Transforms/test/encode_device_tensors.mlir b/iree/compiler/Dialect/Stream/Transforms/test/encode_device_tensors.mlir
index 77fc1cc..e686b4b 100644
--- a/iree/compiler/Dialect/Stream/Transforms/test/encode_device_tensors.mlir
+++ b/iree/compiler/Dialect/Stream/Transforms/test/encode_device_tensors.mlir
@@ -26,9 +26,9 @@
builtin.module {
func.func @dispatch(%arg0: !stream.binding) {
%c0 = arith.constant 0 : index
+ // CHECK-DAG: %[[TILE_I8:.+]] = arith.constant dense<[0, 0, 1, 1]> : tensor<4xi8>
// CHECK-DAG: %[[BINDING:.+]] = stream.binding.subspan {{.+}} -> !flow.dispatch.tensor<writeonly:4xi8>
%binding = stream.binding.subspan %arg0[%c0] : !stream.binding -> !flow.dispatch.tensor<writeonly:4xi1>
- // CHECK-DAG: %[[TILE_I8:.+]] = arith.extui %cst : tensor<4xi1> to tensor<4xi8>
%cst = arith.constant dense<[false, false, true, true]> : tensor<4xi1>
// CHECK-NEXT: flow.dispatch.tensor.store %[[TILE_I8]], %[[BINDING]], {{.+}} : tensor<4xi8> -> !flow.dispatch.tensor<writeonly:4xi8>
flow.dispatch.tensor.store %cst, %binding, offsets = [0], sizes = [4], strides = [1] : tensor<4xi1> -> !flow.dispatch.tensor<writeonly:4xi1>
@@ -83,4 +83,4 @@
return
}
}
-}
\ No newline at end of file
+}
diff --git a/iree/compiler/Dialect/Util/IR/UtilBase.td b/iree/compiler/Dialect/Util/IR/UtilBase.td
index 3e8067e..ad55d67 100644
--- a/iree/compiler/Dialect/Util/IR/UtilBase.td
+++ b/iree/compiler/Dialect/Util/IR/UtilBase.td
@@ -7,6 +7,7 @@
#ifndef IREE_DIALECT_UTIL_IR_UTIL_BASE
#define IREE_DIALECT_UTIL_IR_UTIL_BASE
+include "mlir/IR/EnumAttr.td"
include "mlir/IR/OpBase.td"
//===----------------------------------------------------------------------===//
diff --git a/iree/compiler/Dialect/Vulkan/IR/VulkanAttributes.cpp b/iree/compiler/Dialect/Vulkan/IR/VulkanAttributes.cpp
index cad23d6..31b0326 100644
--- a/iree/compiler/Dialect/Vulkan/IR/VulkanAttributes.cpp
+++ b/iree/compiler/Dialect/Vulkan/IR/VulkanAttributes.cpp
@@ -10,6 +10,7 @@
#include "iree/compiler/Dialect/Vulkan/IR/VulkanTypes.h"
#include "mlir/IR/AttributeSupport.h"
#include "mlir/IR/Builders.h"
+#include "mlir/IR/BuiltinAttributes.h"
#include "mlir/IR/BuiltinTypes.h"
#include "mlir/IR/Diagnostics.h"
#include "mlir/IR/Location.h"
@@ -64,7 +65,7 @@
} // namespace detail
TargetEnvAttr TargetEnvAttr::get(Vulkan::Version version, uint32_t revision,
- ArrayRef<Vulkan::Extension> extensions,
+ ArrayRef<Extension> extensions,
spirv::Vendor vendorID,
spirv::DeviceType deviceType,
uint32_t deviceID,
@@ -73,7 +74,7 @@
llvm::SmallVector<Attribute, 0> extAttrs;
extAttrs.reserve(extensions.size());
for (auto ext : extensions) {
- extAttrs.push_back(builder.getStringAttr(Vulkan::stringifyExtension(ext)));
+ extAttrs.push_back(ExtensionAttr::get(builder.getContext(), ext));
}
return get(builder.getI32IntegerAttr(static_cast<uint32_t>(version)),
builder.getI32IntegerAttr(revision),
@@ -106,7 +107,7 @@
TargetEnvAttr::ext_iterator::ext_iterator(ArrayAttr::iterator it)
: llvm::mapped_iterator<ArrayAttr::iterator, Extension (*)(Attribute)>(
it, [](Attribute attr) {
- return *symbolizeExtension(attr.cast<StringAttr>().getValue());
+ return *symbolizeExtension(attr.cast<IntegerAttr>().getInt());
}) {}
TargetEnvAttr::ext_range TargetEnvAttr::getExtensions() {
@@ -141,12 +142,17 @@
if (!revision.getType().isInteger(32))
return emitError() << "expected 32-bit integer for revision";
- if (!llvm::all_of(extensions.getValue(), [](Attribute attr) {
- if (auto strAttr = attr.dyn_cast<StringAttr>())
- if (symbolizeExtension(strAttr.getValue())) return true;
- return false;
- }))
- return emitError() << "unknown extension in extension list";
+ for (Attribute attr : extensions.getValue()) {
+ auto intAttr = attr.dyn_cast<IntegerAttr>();
+ if (!intAttr || !intAttr.getType().isSignlessInteger()) {
+ return emitError() << "extension attribute '" << attr
+ << "' should be 32-bit signless integer";
+ }
+ if (!symbolizeExtension(intAttr.getInt())) {
+ return emitError() << "unknown extension '" << attr
+ << "' in extension list";
+ }
+ }
if (!capabilities.isa<CapabilitiesAttr>()) {
return emitError() << "expected vulkan::CapabilitiesAttr for capabilities";
diff --git a/iree/compiler/Dialect/Vulkan/IR/VulkanBase.td b/iree/compiler/Dialect/Vulkan/IR/VulkanBase.td
index 689cb08..a0e7497 100644
--- a/iree/compiler/Dialect/Vulkan/IR/VulkanBase.td
+++ b/iree/compiler/Dialect/Vulkan/IR/VulkanBase.td
@@ -8,6 +8,7 @@
#define IREE_DIALECT_VULKAN_BASE
include "mlir/IR/OpBase.td"
+include "mlir/IR/EnumAttr.td"
//===----------------------------------------------------------------------===//
// Vulkan dialect definition
@@ -45,9 +46,6 @@
class VK_IsKnownIntEnumCaseFor<string name> :
CPred<"::mlir::iree_compiler::IREE::Vulkan::symbolize" # name # "("
"$_self.cast<IntegerAttr>().getValue().getZExtValue()).hasValue()">;
-class VK_IsKnownStrEnumCaseFor<string name> :
- CPred<"::mlir::iree_compiler::IREE::Vulkan::symbolize" # name # "("
- "$_self.cast<StringAttr>().getValue()).hasValue()">;
// Wrapper over base I32BitEnumAttr to set common fields.
class VK_BitEnumAttr<string name, string description,
@@ -65,11 +63,11 @@
let cppNamespace = "::mlir::iree_compiler::IREE::Vulkan";
}
-// Wrapper over base StrEnumAttr to set common fields.
-class VK_StrEnumAttr<string name, string description,
- list<StrEnumAttrCase> cases> :
- StrEnumAttr<name, description, cases> {
- let predicate = And<[StrAttr.predicate, VK_IsKnownStrEnumCaseFor<name>]>;
+// Wrapper over base I32EnumAttr to set common fields for mimicing StrEnumAttr.
+class VK_EnumAttr<string name, string description,
+ list<I32EnumAttrCase> cases> :
+ EnumAttr<VK_Dialect, I32EnumAttr<name, description, cases>, name> {
+ let predicate = And<[StrAttr.predicate, VK_IsKnownIntEnumCaseFor<name>]>;
let cppNamespace = "::mlir::iree_compiler::IREE::Vulkan";
}
@@ -84,16 +82,16 @@
def VK_VersionAttr : VK_I32EnumAttr<"Version", "valid Vulkan version", [
VK_V_1_0, VK_V_1_1, VK_V_1_2]>;
-def VK_KHR_16bit_storage : StrEnumAttrCase<"VK_KHR_16bit_storage">;
-def VK_KHR_8bit_storage : StrEnumAttrCase<"VK_KHR_8bit_storage">;
-def VK_KHR_shader_float16_int8 : StrEnumAttrCase<"VK_KHR_shader_float16_int8">;
-def VK_KHR_spirv_1_4 : StrEnumAttrCase<"VK_KHR_spirv_1_4">;
-def VK_KHR_storage_buffer_storage_class : StrEnumAttrCase<"VK_KHR_storage_buffer_storage_class">;
-def VK_KHR_variable_pointers: StrEnumAttrCase<"VK_KHR_variable_pointers">;
-def VK_NV_cooperative_matrix : StrEnumAttrCase<"VK_NV_cooperative_matrix">;
+def VK_KHR_16bit_storage : I32EnumAttrCase<"VK_KHR_16bit_storage", 0>;
+def VK_KHR_8bit_storage : I32EnumAttrCase<"VK_KHR_8bit_storage", 1>;
+def VK_KHR_shader_float16_int8 : I32EnumAttrCase<"VK_KHR_shader_float16_int8", 2>;
+def VK_KHR_spirv_1_4 : I32EnumAttrCase<"VK_KHR_spirv_1_4", 3>;
+def VK_KHR_storage_buffer_storage_class : I32EnumAttrCase<"VK_KHR_storage_buffer_storage_class", 4>;
+def VK_KHR_variable_pointers: I32EnumAttrCase<"VK_KHR_variable_pointers", 5>;
+def VK_NV_cooperative_matrix : I32EnumAttrCase<"VK_NV_cooperative_matrix", 6>;
def VK_ExtensionAttr :
- VK_StrEnumAttr<"Extension", "supported Vulkan extension", [
+ VK_EnumAttr<"Extension", "supported Vulkan extension", [
VK_KHR_16bit_storage, VK_KHR_8bit_storage, VK_KHR_shader_float16_int8,
VK_KHR_spirv_1_4, VK_KHR_storage_buffer_storage_class,
VK_KHR_variable_pointers, VK_NV_cooperative_matrix
diff --git a/iree/compiler/Dialect/Vulkan/IR/VulkanDialect.cpp b/iree/compiler/Dialect/Vulkan/IR/VulkanDialect.cpp
index 63a72e9..ced154c 100644
--- a/iree/compiler/Dialect/Vulkan/IR/VulkanDialect.cpp
+++ b/iree/compiler/Dialect/Vulkan/IR/VulkanDialect.cpp
@@ -93,8 +93,9 @@
StringRef errorKeyword;
auto processExtension = [&](llvm::SMLoc loc, StringRef extension) {
- if (symbolizeExtension(extension)) {
- extensions.push_back(builder.getStringAttr(extension));
+ if (auto symbol = symbolizeExtension(extension)) {
+ extensions.push_back(builder.getI32IntegerAttr(
+ static_cast<uint32_t>(symbol.getValue())));
return success();
}
return errorloc = loc, errorKeyword = extension, failure();
@@ -190,7 +191,8 @@
<< stringifyVersion(targetEnv.getVersion()) << ", r("
<< targetEnv.getRevision() << "), [";
interleaveComma(targetEnv.getExtensionsAttr(), os, [&](Attribute attr) {
- os << attr.cast<StringAttr>().getValue();
+ os << stringifyExtension(
+ *symbolizeExtension(attr.cast<IntegerAttr>().getInt()));
});
printer << "], " << spirv::stringifyVendor(targetEnv.getVendorID());
printer << ":" << spirv::stringifyDeviceType(targetEnv.getDeviceType());
diff --git a/iree/compiler/Dialect/Vulkan/Utils/TargetTriple.cpp b/iree/compiler/Dialect/Vulkan/Utils/TargetTriple.cpp
index b486c8f..8db70c7 100644
--- a/iree/compiler/Dialect/Vulkan/Utils/TargetTriple.cpp
+++ b/iree/compiler/Dialect/Vulkan/Utils/TargetTriple.cpp
@@ -103,7 +103,7 @@
/// version. The GPU triple is a handy way to specify the target but we cannot
/// encode all the information in the triple.
void getExtensions(const TargetTriple &triple,
- llvm::SmallVectorImpl<Vulkan::Extension> &extensions) {
+ llvm::SmallVectorImpl<Extension> &extensions) {
// Mobile GPUs need to take Android version into consideration.
switch (triple.getArch()) {
case TargetTripleArch::Apple_M1: {
@@ -396,7 +396,7 @@
}
TargetEnvAttr TargetTriple::getTargetEnv(MLIRContext *context) const {
- SmallVector<Vulkan::Extension> extensions;
+ SmallVector<Extension> extensions;
getExtensions(*this, extensions);
return TargetEnvAttr::get(getVersion(*this), /*revision=*/0, extensions,
getVendor(*this), getDeviceType(*this),
diff --git a/iree/compiler/InputConversion/MHLO/test/convert_complex_to_real.mlir b/iree/compiler/InputConversion/MHLO/test/convert_complex_to_real.mlir
index a34b876..2be74e6 100644
--- a/iree/compiler/InputConversion/MHLO/test/convert_complex_to_real.mlir
+++ b/iree/compiler/InputConversion/MHLO/test/convert_complex_to_real.mlir
@@ -8,8 +8,8 @@
// CHECK-DAG: [[VAL0:%.+]] = mhlo.add %arg0, %arg2
// CHECK-DAG: [[VAL1:%.+]] = mhlo.add %arg1, %arg3
%4 = "mhlo.add"(%2, %3) : (tensor<2xcomplex<f32>>, tensor<2xcomplex<f32>>) -> (tensor<2xcomplex<f32>>)
- %5 = "mhlo.real"(%4) : (tensor<2xcomplex<f32>>) -> (tensor<2xf32>)
- %6 = "mhlo.imag"(%4) : (tensor<2xcomplex<f32>>) -> (tensor<2xf32>)
+ %5 = mhlo.real(%4) : (tensor<2xcomplex<f32>>) -> (tensor<2xf32>)
+ %6 = mhlo.imag(%4) : (tensor<2xcomplex<f32>>) -> (tensor<2xf32>)
// CHECK: return [[VAL0]], [[VAL1]]
return %5, %6 : tensor<2xf32>, tensor<2xf32>
@@ -23,8 +23,8 @@
// CHECK-DAG: [[VAL0:%.+]] = mhlo.subtract %arg0, %arg2
// CHECK-DAG: [[VAL1:%.+]] = mhlo.subtract %arg1, %arg3
%4 = "mhlo.subtract"(%2, %3) : (tensor<2xcomplex<f32>>, tensor<2xcomplex<f32>>) -> (tensor<2xcomplex<f32>>)
- %5 = "mhlo.real"(%4) : (tensor<2xcomplex<f32>>) -> (tensor<2xf32>)
- %6 = "mhlo.imag"(%4) : (tensor<2xcomplex<f32>>) -> (tensor<2xf32>)
+ %5 = mhlo.real(%4) : (tensor<2xcomplex<f32>>) -> (tensor<2xf32>)
+ %6 = mhlo.imag(%4) : (tensor<2xcomplex<f32>>) -> (tensor<2xf32>)
// CHECK: return [[VAL0]], [[VAL1]]
return %5, %6 : tensor<2xf32>, tensor<2xf32>
@@ -42,8 +42,8 @@
// CHECK-DAG: %[[VAL4:.+]] = chlo.broadcast_multiply %arg1, %arg2
// CHECK-DAG: %[[VAL5:.+]] = mhlo.add %[[VAL3]], %[[VAL4]]
%4 = "mhlo.multiply"(%2, %3) : (tensor<2xcomplex<f32>>, tensor<2xcomplex<f32>>) -> (tensor<2xcomplex<f32>>)
- %5 = "mhlo.real"(%4) : (tensor<2xcomplex<f32>>) -> (tensor<2xf32>)
- %6 = "mhlo.imag"(%4) : (tensor<2xcomplex<f32>>) -> (tensor<2xf32>)
+ %5 = mhlo.real(%4) : (tensor<2xcomplex<f32>>) -> (tensor<2xf32>)
+ %6 = mhlo.imag(%4) : (tensor<2xcomplex<f32>>) -> (tensor<2xf32>)
// CHECK: return %2, %5 : tensor<2xf32>, tensor<2xf32>
return %5, %6 : tensor<2xf32>, tensor<2xf32>
@@ -54,7 +54,7 @@
%2 = "mhlo.complex"(%arg0, %arg1) : (tensor<2xf32>, tensor<2xf32>) -> (tensor<2xcomplex<f32>>)
%3 = "mhlo.complex"(%arg2, %arg3) : (tensor<2xf32>, tensor<2xf32>) -> (tensor<2xcomplex<f32>>)
- // CHECK-DAG: %[[VAL0:.+]] = "mhlo.negate"(%arg3)
+ // CHECK-DAG: %[[VAL0:.+]] = mhlo.negate %arg3
// Compute the numerator's real component:
// numerator.real = lhs.real * rhs.real lhs.imag * rhs.imag
@@ -79,8 +79,8 @@
// CHECK-DAG: %[[VAL11:.+]] = chlo.broadcast_divide %[[VAL9]], %[[VAL6]]
%4 = "mhlo.divide"(%2, %3) : (tensor<2xcomplex<f32>>, tensor<2xcomplex<f32>>) -> (tensor<2xcomplex<f32>>)
- %5 = "mhlo.real"(%4) : (tensor<2xcomplex<f32>>) -> (tensor<2xf32>)
- %6 = "mhlo.imag"(%4) : (tensor<2xcomplex<f32>>) -> (tensor<2xf32>)
+ %5 = mhlo.real(%4) : (tensor<2xcomplex<f32>>) -> (tensor<2xf32>)
+ %6 = mhlo.imag(%4) : (tensor<2xcomplex<f32>>) -> (tensor<2xf32>)
// CHECK: return %[[VAL10]], %[[VAL11]]
return %5, %6 : tensor<2xf32>, tensor<2xf32>
@@ -93,8 +93,8 @@
// CHECK-DAG: %[[VAL0:.+]] = mhlo.multiply %arg0, %arg0
// CHECK-DAG: %[[VAL1:.+]] = mhlo.multiply %arg1, %arg1
// CHECK-DAG: %[[VAL2:.+]] = mhlo.add %[[VAL0]], %[[VAL1]]
- // CHECK-DAG: %[[VAL3:.+]] = "mhlo.sqrt"(%[[VAL2]])
- %1 = "mhlo.abs"(%0) : (tensor<2xcomplex<f32>>) -> (tensor<2xf32>)
+ // CHECK-DAG: %[[VAL3:.+]] = mhlo.sqrt %[[VAL2]]
+ %1 = mhlo.abs(%0) : (tensor<2xcomplex<f32>>) -> (tensor<2xf32>)
// CHECK: return %[[VAL3]]
return %1 : tensor<2xf32>
@@ -104,15 +104,15 @@
func.func @exp(%arg0 : tensor<2xf32>, %arg1 : tensor<2xf32>) -> (tensor<2xf32>, tensor<2xf32>) {
%0 = "mhlo.complex"(%arg0, %arg1) : (tensor<2xf32>, tensor<2xf32>) -> (tensor<2xcomplex<f32>>)
- // CHECK-DAG: %[[EXP:.+]] = "mhlo.exponential"(%arg0)
- // CHECK-DAG: %[[COS:.+]] = "mhlo.cosine"(%arg1)
- // CHECK-DAG: %[[SIN:.+]] = "mhlo.sine"(%arg1)
+ // CHECK-DAG: %[[EXP:.+]] = mhlo.exponential %arg0
+ // CHECK-DAG: %[[COS:.+]] = mhlo.cosine %arg1
+ // CHECK-DAG: %[[SIN:.+]] = mhlo.sine %arg1
// CHECK-DAG: %[[OUTR:.+]] = mhlo.multiply %[[COS]], %[[EXP]]
// CHECK-DAG: %[[OUTI:.+]] = mhlo.multiply %[[SIN]], %[[EXP]]
- %1 = "mhlo.exponential"(%0) : (tensor<2xcomplex<f32>>) -> (tensor<2xcomplex<f32>>)
+ %1 = mhlo.exponential %0 : tensor<2xcomplex<f32>>
- %2 = "mhlo.real"(%1) : (tensor<2xcomplex<f32>>) -> (tensor<2xf32>)
- %3 = "mhlo.imag"(%1) : (tensor<2xcomplex<f32>>) -> (tensor<2xf32>)
+ %2 = mhlo.real(%1) : (tensor<2xcomplex<f32>>) -> (tensor<2xf32>)
+ %3 = mhlo.imag(%1) : (tensor<2xcomplex<f32>>) -> (tensor<2xf32>)
// CHECK: %[[OUTR]], %[[OUTI]]
return %2, %3 : tensor<2xf32>, tensor<2xf32>
diff --git a/iree/compiler/InputConversion/MHLO/test/mhlo_to_mhlo_preprocessing.mlir b/iree/compiler/InputConversion/MHLO/test/mhlo_to_mhlo_preprocessing.mlir
index b8b2801..0a030df 100644
--- a/iree/compiler/InputConversion/MHLO/test/mhlo_to_mhlo_preprocessing.mlir
+++ b/iree/compiler/InputConversion/MHLO/test/mhlo_to_mhlo_preprocessing.mlir
@@ -12,7 +12,7 @@
-> (tensor<4x256xf32>) {
// CHECK-DAG: %[[EPS_BCAST:.+]] = mhlo.constant dense<1.001000e-05> : tensor<256xf32>
// CHECK-DAG: %[[VARIANCE_EPS:.+]] = mhlo.add %[[VARIANCE]], %[[EPS_BCAST]] : tensor<256xf32>
- // CHECK-DAG: %[[STDDEV:.+]] = "mhlo.sqrt"(%[[VARIANCE_EPS]]) : (tensor<256xf32>) -> tensor<256xf32>
+ // CHECK-DAG: %[[STDDEV:.+]] = mhlo.sqrt %[[VARIANCE_EPS]] : tensor<256xf32>
// CHECK-DAG: %[[STDDEV_BCAST:.+]] = "mhlo.broadcast_in_dim"(%[[STDDEV]]) {broadcast_dimensions = dense<1> : tensor<1xi64>} : (tensor<256xf32>) -> tensor<4x256xf32>
// CHECK-DAG: %[[SCALE_BCAST:.+]] = "mhlo.broadcast_in_dim"(%[[SCALE]]) {broadcast_dimensions = dense<1> : tensor<1xi64>} : (tensor<256xf32>) -> tensor<4x256xf32>
// CHECK-DAG: %[[OFFSET_BCAST:.+]] = "mhlo.broadcast_in_dim"(%[[OFFSET]]) {broadcast_dimensions = dense<1> : tensor<1xi64>} : (tensor<256xf32>) -> tensor<4x256xf32>
@@ -94,21 +94,21 @@
%1 = "mhlo.broadcast_in_dim"(%arg1) {broadcast_dimensions = dense<[]> : tensor<0xi64>} : (tensor<f32>) -> tensor<1x8x8x64xf32>
%2 = "mhlo.broadcast_in_dim"(%arg2) {broadcast_dimensions = dense<[]> : tensor<0xi64>} : (tensor<i32>) -> tensor<1x8x8x64xi32>
%3 = "mhlo.broadcast_in_dim"(%arg3) {broadcast_dimensions = dense<[]> : tensor<0xi64>} : (tensor<i32>) -> tensor<1x8x8x64xi32>
- %4 = "mhlo.add"(%0, %1) : (tensor<1x8x8x64xf32>, tensor<1x8x8x64xf32>) -> tensor<1x8x8x64xf32>
- %5 = "mhlo.atan2"(%0, %1) : (tensor<1x8x8x64xf32>, tensor<1x8x8x64xf32>) -> tensor<1x8x8x64xf32>
- %6 = "mhlo.divide"(%0, %1) : (tensor<1x8x8x64xf32>, tensor<1x8x8x64xf32>) -> tensor<1x8x8x64xf32>
- %7 = "mhlo.maximum"(%0, %1) : (tensor<1x8x8x64xf32>, tensor<1x8x8x64xf32>) -> tensor<1x8x8x64xf32>
- %8 = "mhlo.minimum"(%0, %1) : (tensor<1x8x8x64xf32>, tensor<1x8x8x64xf32>) -> tensor<1x8x8x64xf32>
- %9 = "mhlo.multiply"(%0, %1) : (tensor<1x8x8x64xf32>, tensor<1x8x8x64xf32>) -> tensor<1x8x8x64xf32>
- %10 = "mhlo.power"(%0, %1) : (tensor<1x8x8x64xf32>, tensor<1x8x8x64xf32>) -> tensor<1x8x8x64xf32>
- %11 = "mhlo.remainder"(%0, %1) : (tensor<1x8x8x64xf32>, tensor<1x8x8x64xf32>) -> tensor<1x8x8x64xf32>
- %12 = "mhlo.shift_left"(%0, %1) : (tensor<1x8x8x64xf32>, tensor<1x8x8x64xf32>) -> tensor<1x8x8x64xf32>
- %13 = "mhlo.shift_right_arithmetic"(%0, %1) : (tensor<1x8x8x64xf32>, tensor<1x8x8x64xf32>) -> tensor<1x8x8x64xf32>
- %14 = "mhlo.shift_right_logical"(%0, %1) : (tensor<1x8x8x64xf32>, tensor<1x8x8x64xf32>) -> tensor<1x8x8x64xf32>
- %15 = "mhlo.subtract"(%0, %1) : (tensor<1x8x8x64xf32>, tensor<1x8x8x64xf32>) -> tensor<1x8x8x64xf32>
- %16 = "mhlo.and"(%2, %3) : (tensor<1x8x8x64xi32>, tensor<1x8x8x64xi32>) -> tensor<1x8x8x64xi32>
- %17 = "mhlo.or"(%2, %3) : (tensor<1x8x8x64xi32>, tensor<1x8x8x64xi32>) -> tensor<1x8x8x64xi32>
- %18 = "mhlo.xor"(%2, %3) : (tensor<1x8x8x64xi32>, tensor<1x8x8x64xi32>) -> tensor<1x8x8x64xi32>
+ %4 = mhlo.add %0, %1 : tensor<1x8x8x64xf32>
+ %5 = mhlo.atan2 %0, %1 : tensor<1x8x8x64xf32>
+ %6 = mhlo.divide %0, %1 : tensor<1x8x8x64xf32>
+ %7 = mhlo.maximum %0, %1 : tensor<1x8x8x64xf32>
+ %8 = mhlo.minimum %0, %1 : tensor<1x8x8x64xf32>
+ %9 = mhlo.multiply %0, %1 : tensor<1x8x8x64xf32>
+ %10 = mhlo.power %0, %1 : tensor<1x8x8x64xf32>
+ %11 = mhlo.remainder %0, %1 : tensor<1x8x8x64xf32>
+ %12 = mhlo.shift_left %0, %1 : tensor<1x8x8x64xf32>
+ %13 = mhlo.shift_right_arithmetic %0, %1 : tensor<1x8x8x64xf32>
+ %14 = mhlo.shift_right_logical %0, %1 : tensor<1x8x8x64xf32>
+ %15 = mhlo.subtract %0, %1 : tensor<1x8x8x64xf32>
+ %16 = mhlo.and %2, %3 : tensor<1x8x8x64xi32>
+ %17 = mhlo.or %2, %3 : tensor<1x8x8x64xi32>
+ %18 = mhlo.xor %2, %3 : tensor<1x8x8x64xi32>
return %4, %5, %6, %7, %8, %9, %10, %11, %12, %13, %14, %15, %16, %17, %18 : tensor<1x8x8x64xf32>, tensor<1x8x8x64xf32>, tensor<1x8x8x64xf32>, tensor<1x8x8x64xf32>, tensor<1x8x8x64xf32>, tensor<1x8x8x64xf32>, tensor<1x8x8x64xf32>, tensor<1x8x8x64xf32>, tensor<1x8x8x64xf32>, tensor<1x8x8x64xf32>, tensor<1x8x8x64xf32>, tensor<1x8x8x64xf32>, tensor<1x8x8x64xi32>, tensor<1x8x8x64xi32>, tensor<1x8x8x64xi32>
}
@@ -132,7 +132,7 @@
// CHECK: %[[BCAST:.*]] = "mhlo.broadcast_in_dim"(%[[ATAN2]]) {broadcast_dimensions = dense<1> : tensor<1xi64>} : (tensor<3xf32>) -> tensor<4x3xf32>
%0 = "mhlo.broadcast_in_dim"(%arg0) {broadcast_dimensions = dense<[1]> : tensor<1xi64>} : (tensor<3xf32>) -> tensor<4x3xf32>
%1 = "mhlo.broadcast_in_dim"(%arg1) {broadcast_dimensions = dense<[1]> : tensor<1xi64>} : (tensor<3xf32>) -> tensor<4x3xf32>
- %2 = "mhlo.atan2"(%0, %1) : (tensor<4x3xf32>, tensor<4x3xf32>) -> tensor<4x3xf32>
+ %2 = mhlo.atan2 %0, %1 : tensor<4x3xf32>
// CHECK: return %[[BCAST]]
return %2 : tensor<4x3xf32>
}
@@ -145,7 +145,7 @@
// CHECK: %[[BCAST:.*]] = "mhlo.broadcast_in_dim"(%[[POWER]]) {broadcast_dimensions = dense<[1, 2]> : tensor<2xi64>} : (tensor<2x4xi32>) -> tensor<3x2x4xi32>
%0 = "mhlo.broadcast_in_dim"(%arg0) {broadcast_dimensions = dense<[1, 2]> : tensor<2xi64>} : (tensor<2x4xi32>) -> tensor<3x2x4xi32>
%1 = "mhlo.broadcast_in_dim"(%arg1) {broadcast_dimensions = dense<[1, 2]> : tensor<2xi64>} : (tensor<2x4xi32>) -> tensor<3x2x4xi32>
- %2 = "mhlo.power"(%0, %1) : (tensor<3x2x4xi32>, tensor<3x2x4xi32>) -> tensor<3x2x4xi32>
+ %2 = mhlo.power %0, %1 : tensor<3x2x4xi32>
// CHECK: return %[[BCAST]]
return %2 : tensor<3x2x4xi32>
}
@@ -154,46 +154,46 @@
// CHECK: @reorder_broadcast_in_dim_scalar_unary(%[[ARG0:.*]]: tensor<f32>)
func.func @reorder_broadcast_in_dim_scalar_unary(%arg0: tensor<f32>) -> (tensor<1x8x8x64xf32>, tensor<1x8x8x64xf32>, tensor<1x8x8x64xf32>, tensor<1x8x8x64xf32>, tensor<1x8x8x64xf32>, tensor<1x8x8x64xf32>, tensor<1x8x8x64xf32>, tensor<1x8x8x64xf32>, tensor<1x8x8x64xf32>, tensor<1x8x8x64xf32>, tensor<1x8x8x64xf32>, tensor<1x8x8x64xf32>, tensor<1x8x8x64xf32>) {
- // CHECK: %[[ABS:.*]] = "mhlo.abs"(%[[ARG0]]) : (tensor<f32>) -> tensor<f32>
+ // CHECK: %[[ABS:.*]] = mhlo.abs %[[ARG0]] : tensor<f32>
// CHECK: "mhlo.broadcast_in_dim"(%[[ABS]]) {broadcast_dimensions = dense<> : tensor<0xi64>} : (tensor<f32>) -> tensor<1x8x8x64xf32>
- // CHECK: %[[CEIL:.*]] = "mhlo.ceil"(%[[ARG0]]) : (tensor<f32>) -> tensor<f32>
+ // CHECK: %[[CEIL:.*]] = mhlo.ceil %[[ARG0]] : tensor<f32>
// CHECK: "mhlo.broadcast_in_dim"(%[[CEIL]]) {broadcast_dimensions = dense<> : tensor<0xi64>} : (tensor<f32>) -> tensor<1x8x8x64xf32>
- // CHECK: %[[COSINE:.*]] = "mhlo.cosine"(%[[ARG0]]) : (tensor<f32>) -> tensor<f32>
+ // CHECK: %[[COSINE:.*]] = mhlo.cosine %[[ARG0]] : tensor<f32>
// CHECK: "mhlo.broadcast_in_dim"(%[[COSINE]]) {broadcast_dimensions = dense<> : tensor<0xi64>} : (tensor<f32>) -> tensor<1x8x8x64xf32>
- // CHECK: %[[EXP:.*]] = "mhlo.exponential"(%[[ARG0]]) : (tensor<f32>) -> tensor<f32>
+ // CHECK: %[[EXP:.*]] = mhlo.exponential %[[ARG0]] : tensor<f32>
// CHECK: "mhlo.broadcast_in_dim"(%[[EXP]]) {broadcast_dimensions = dense<> : tensor<0xi64>} : (tensor<f32>) -> tensor<1x8x8x64xf32>
- // CHECK: %[[FLOOR:.*]] = "mhlo.floor"(%[[ARG0]]) : (tensor<f32>) -> tensor<f32>
+ // CHECK: %[[FLOOR:.*]] = mhlo.floor %[[ARG0]] : tensor<f32>
// CHECK: "mhlo.broadcast_in_dim"(%[[FLOOR]]) {broadcast_dimensions = dense<> : tensor<0xi64>} : (tensor<f32>) -> tensor<1x8x8x64xf32>
- // CHECK: %[[LOG:.*]] = "mhlo.log"(%[[ARG0]]) : (tensor<f32>) -> tensor<f32>
+ // CHECK: %[[LOG:.*]] = mhlo.log %[[ARG0]] : tensor<f32>
// CHECK: "mhlo.broadcast_in_dim"(%[[LOG]]) {broadcast_dimensions = dense<> : tensor<0xi64>} : (tensor<f32>) -> tensor<1x8x8x64xf32>
- // CHECK: %[[NEG:.*]] = "mhlo.negate"(%[[ARG0]]) : (tensor<f32>) -> tensor<f32>
+ // CHECK: %[[NEG:.*]] = mhlo.negate %[[ARG0]] : tensor<f32>
// CHECK: "mhlo.broadcast_in_dim"(%[[NEG]]) {broadcast_dimensions = dense<> : tensor<0xi64>} : (tensor<f32>) -> tensor<1x8x8x64xf32>
- // CHECK: %[[ROUND:.*]] = "mhlo.round_nearest_afz"(%[[ARG0]]) : (tensor<f32>) -> tensor<f32>
+ // CHECK: %[[ROUND:.*]] = mhlo.round_nearest_afz %[[ARG0]] : tensor<f32>
// CHECK: "mhlo.broadcast_in_dim"(%[[ROUND]]) {broadcast_dimensions = dense<> : tensor<0xi64>} : (tensor<f32>) -> tensor<1x8x8x64xf32>
- // CHECK: %[[RSQRT:.*]] = "mhlo.rsqrt"(%[[ARG0]]) : (tensor<f32>) -> tensor<f32>
+ // CHECK: %[[RSQRT:.*]] = mhlo.rsqrt %[[ARG0]] : tensor<f32>
// CHECK: "mhlo.broadcast_in_dim"(%[[RSQRT]]) {broadcast_dimensions = dense<> : tensor<0xi64>} : (tensor<f32>) -> tensor<1x8x8x64xf32>
- // CHECK: %[[SIGN:.*]] = "mhlo.sign"(%[[ARG0]]) : (tensor<f32>) -> tensor<f32>
+ // CHECK: %[[SIGN:.*]] = mhlo.sign %[[ARG0]] : tensor<f32>
// CHECK: "mhlo.broadcast_in_dim"(%[[SIGN]]) {broadcast_dimensions = dense<> : tensor<0xi64>} : (tensor<f32>) -> tensor<1x8x8x64xf32>
- // CHECK: %[[SINE:.*]] = "mhlo.sine"(%[[ARG0]]) : (tensor<f32>) -> tensor<f32>
+ // CHECK: %[[SINE:.*]] = mhlo.sine %[[ARG0]] : tensor<f32>
// CHECK: "mhlo.broadcast_in_dim"(%[[SINE]]) {broadcast_dimensions = dense<> : tensor<0xi64>} : (tensor<f32>) -> tensor<1x8x8x64xf32>
- // CHECK: %[[SQRT:.*]] = "mhlo.sqrt"(%[[ARG0]]) : (tensor<f32>) -> tensor<f32>
+ // CHECK: %[[SQRT:.*]] = mhlo.sqrt %[[ARG0]] : tensor<f32>
// CHECK: "mhlo.broadcast_in_dim"(%[[SQRT]]) {broadcast_dimensions = dense<> : tensor<0xi64>} : (tensor<f32>) -> tensor<1x8x8x64xf32>
- // CHECK: %[[TANH:.*]] = "mhlo.tanh"(%[[ARG0]]) : (tensor<f32>) -> tensor<f32>
+ // CHECK: %[[TANH:.*]] = mhlo.tanh %[[ARG0]] : tensor<f32>
// CHECK: "mhlo.broadcast_in_dim"(%[[TANH]]) {broadcast_dimensions = dense<> : tensor<0xi64>} : (tensor<f32>) -> tensor<1x8x8x64xf32>
%0 = "mhlo.broadcast_in_dim"(%arg0) {broadcast_dimensions = dense<[]> : tensor<0xi64>} : (tensor<f32>) -> tensor<1x8x8x64xf32>
- %1 = "mhlo.abs"(%0) : (tensor<1x8x8x64xf32>) -> tensor<1x8x8x64xf32>
- %2 = "mhlo.ceil"(%0) : (tensor<1x8x8x64xf32>) -> tensor<1x8x8x64xf32>
- %3 = "mhlo.cosine"(%0) : (tensor<1x8x8x64xf32>) -> tensor<1x8x8x64xf32>
- %4 = "mhlo.exponential"(%0) : (tensor<1x8x8x64xf32>) -> tensor<1x8x8x64xf32>
- %5 = "mhlo.floor"(%0) : (tensor<1x8x8x64xf32>) -> tensor<1x8x8x64xf32>
- %6 = "mhlo.log"(%0) : (tensor<1x8x8x64xf32>) -> tensor<1x8x8x64xf32>
- %7 = "mhlo.negate"(%0) : (tensor<1x8x8x64xf32>) -> tensor<1x8x8x64xf32>
- %8 = "mhlo.round_nearest_afz"(%0) : (tensor<1x8x8x64xf32>) -> tensor<1x8x8x64xf32>
- %9 = "mhlo.rsqrt"(%0) : (tensor<1x8x8x64xf32>) -> tensor<1x8x8x64xf32>
- %10 = "mhlo.sign"(%0) : (tensor<1x8x8x64xf32>) -> tensor<1x8x8x64xf32>
- %11 = "mhlo.sine"(%0) : (tensor<1x8x8x64xf32>) -> tensor<1x8x8x64xf32>
- %12 = "mhlo.sqrt"(%0) : (tensor<1x8x8x64xf32>) -> tensor<1x8x8x64xf32>
- %13 = "mhlo.tanh"(%0) : (tensor<1x8x8x64xf32>) -> tensor<1x8x8x64xf32>
+ %1 = mhlo.abs %0 : tensor<1x8x8x64xf32>
+ %2 = mhlo.ceil %0 : tensor<1x8x8x64xf32>
+ %3 = mhlo.cosine %0 : tensor<1x8x8x64xf32>
+ %4 = mhlo.exponential %0 : tensor<1x8x8x64xf32>
+ %5 = mhlo.floor %0 : tensor<1x8x8x64xf32>
+ %6 = mhlo.log %0 : tensor<1x8x8x64xf32>
+ %7 = mhlo.negate %0 : tensor<1x8x8x64xf32>
+ %8 = mhlo.round_nearest_afz %0 : tensor<1x8x8x64xf32>
+ %9 = mhlo.rsqrt %0 : tensor<1x8x8x64xf32>
+ %10 = mhlo.sign %0 : tensor<1x8x8x64xf32>
+ %11 = mhlo.sine %0 : tensor<1x8x8x64xf32>
+ %12 = mhlo.sqrt %0 : tensor<1x8x8x64xf32>
+ %13 = mhlo.tanh %0 : tensor<1x8x8x64xf32>
return %1, %2, %3, %4, %5, %6, %7, %8, %9, %10, %11, %12, %13: tensor<1x8x8x64xf32>, tensor<1x8x8x64xf32>, tensor<1x8x8x64xf32>, tensor<1x8x8x64xf32>, tensor<1x8x8x64xf32>, tensor<1x8x8x64xf32>, tensor<1x8x8x64xf32>, tensor<1x8x8x64xf32>, tensor<1x8x8x64xf32>, tensor<1x8x8x64xf32>, tensor<1x8x8x64xf32>, tensor<1x8x8x64xf32>, tensor<1x8x8x64xf32>
}
@@ -201,10 +201,10 @@
// CHECK: @reorder_broadcast_in_dim_1d_unary(%[[ARG0:.*]]: tensor<3xf32>) -> tensor<4x3xf32>
func.func @reorder_broadcast_in_dim_1d_unary(%arg0: tensor<3xf32>) -> tensor<4x3xf32> {
- // CHECK: %[[COS:.*]] = "mhlo.cosine"(%[[ARG0]]) : (tensor<3xf32>) -> tensor<3xf32>
+ // CHECK: %[[COS:.*]] = mhlo.cosine %[[ARG0]] : tensor<3xf32>
// CHECK: %[[BCAST:.*]] = "mhlo.broadcast_in_dim"(%[[COS]]) {broadcast_dimensions = dense<1> : tensor<1xi64>} : (tensor<3xf32>) -> tensor<4x3xf32>
%0 = "mhlo.broadcast_in_dim"(%arg0) {broadcast_dimensions = dense<[1]> : tensor<1xi64>} : (tensor<3xf32>) -> tensor<4x3xf32>
- %1 = "mhlo.cosine"(%0) : (tensor<4x3xf32>) -> tensor<4x3xf32>
+ %1 = mhlo.cosine %0 : tensor<4x3xf32>
// CHECK: return %[[BCAST]]
return %1 : tensor<4x3xf32>
}
@@ -213,10 +213,10 @@
// CHECK: @reorder_in_dim_2d_unary(%[[ARG0:.*]]: tensor<2x4xf32>) -> tensor<3x2x4xf32>
func.func @reorder_in_dim_2d_unary(%arg0: tensor<2x4xf32>) -> tensor<3x2x4xf32> {
- // CHECK: %[[LOG:.*]] = "mhlo.log"(%[[ARG0]]) : (tensor<2x4xf32>) -> tensor<2x4xf32>
+ // CHECK: %[[LOG:.*]] = mhlo.log %[[ARG0]] : tensor<2x4xf32>
// CHECK: %[[BCAST:.*]] = "mhlo.broadcast_in_dim"(%[[LOG]]) {broadcast_dimensions = dense<[1, 2]> : tensor<2xi64>} : (tensor<2x4xf32>) -> tensor<3x2x4xf32>
%0 = "mhlo.broadcast_in_dim"(%arg0) {broadcast_dimensions = dense<[1, 2]> : tensor<2xi64>} : (tensor<2x4xf32>) -> tensor<3x2x4xf32>
- %1 = "mhlo.log"(%0) : (tensor<3x2x4xf32>) -> tensor<3x2x4xf32>
+ %1 = mhlo.log %0 : tensor<3x2x4xf32>
// CHECK: return %[[BCAST]]
return %1 : tensor<3x2x4xf32>
}
@@ -225,13 +225,13 @@
// CHECK: @reorder_broadcast_in_dim_scalar_unary_diff_type(%[[ARG0:.*]]: tensor<complex<f32>>) -> (tensor<1x8x8x64xf32>, tensor<1x8x8x64xf32>)
func.func @reorder_broadcast_in_dim_scalar_unary_diff_type(%arg0: tensor<complex<f32>>) -> (tensor<1x8x8x64xf32>, tensor<1x8x8x64xf32>) {
- // CHECK: %[[REAL:.*]] = "mhlo.real"(%[[ARG0]]) : (tensor<complex<f32>>) -> tensor<f32>
+ // CHECK: %[[REAL:.*]] = mhlo.real(%[[ARG0]]) : (tensor<complex<f32>>) -> tensor<f32>
// CHECK: "mhlo.broadcast_in_dim"(%[[REAL]]) {broadcast_dimensions = dense<> : tensor<0xi64>} : (tensor<f32>) -> tensor<1x8x8x64xf32>
- // CHECK: %[[IMAG:.*]] = "mhlo.imag"(%[[ARG0]]) : (tensor<complex<f32>>) -> tensor<f32>
+ // CHECK: %[[IMAG:.*]] = mhlo.imag(%[[ARG0]]) : (tensor<complex<f32>>) -> tensor<f32>
// CHECK: "mhlo.broadcast_in_dim"(%[[IMAG]]) {broadcast_dimensions = dense<> : tensor<0xi64>} : (tensor<f32>) -> tensor<1x8x8x64xf32>
%0 = "mhlo.broadcast_in_dim"(%arg0) {broadcast_dimensions = dense<[]> : tensor<0xi64>} : (tensor<complex<f32>>) -> tensor<1x8x8x64xcomplex<f32>>
- %1 = "mhlo.real"(%0) : (tensor<1x8x8x64xcomplex<f32>>) -> tensor<1x8x8x64xf32>
- %2 = "mhlo.imag"(%0) : (tensor<1x8x8x64xcomplex<f32>>) -> tensor<1x8x8x64xf32>
+ %1 = mhlo.real(%0) : (tensor<1x8x8x64xcomplex<f32>>) -> tensor<1x8x8x64xf32>
+ %2 = mhlo.imag(%0) : (tensor<1x8x8x64xcomplex<f32>>) -> tensor<1x8x8x64xf32>
return %1, %2: tensor<1x8x8x64xf32>, tensor<1x8x8x64xf32>
}
@@ -296,8 +296,8 @@
// CHECK-LABEL: @mul_float_bool_cast
// CHECK: %[[ZERO:.+]] = mhlo.constant dense<0.000000e+00> : tensor<f32>
-// CHECK: %[[BTOF:.+]] = "mhlo.convert"(%arg0) : (tensor<?xi1>) -> tensor<?xf32>
-// CHECK: %[[FTOB:.+]] = "mhlo.convert"(%[[BTOF]]) : (tensor<?xf32>) -> tensor<?xi1>
+// CHECK: %[[BTOF:.+]] = mhlo.convert(%arg0) : (tensor<?xi1>) -> tensor<?xf32>
+// CHECK: %[[FTOB:.+]] = mhlo.convert(%[[BTOF]]) : (tensor<?xf32>) -> tensor<?xi1>
// CHECK: %[[SHP:.+]] = shape.shape_of %[[BTOF]] : tensor<?xf32> -> tensor<1xindex>
// CHECK: %[[BROADCAST:.+]] = "mhlo.dynamic_broadcast_in_dim"(%[[ZERO]], %[[SHP]]) {broadcast_dimensions = dense<> : tensor<0xi64>}
// CHECK: %[[SELECT:.+]] = "mhlo.select"(%[[FTOB]], %arg1, %[[BROADCAST]])
diff --git a/third_party/llvm-project b/third_party/llvm-project
index 2f076bf..a571f82 160000
--- a/third_party/llvm-project
+++ b/third_party/llvm-project
@@ -1 +1 @@
-Subproject commit 2f076bf64e534713888bf528652200136c503cb6
+Subproject commit a571f82a50416b767fd3cce0fb5027bb5dfec58c
diff --git a/third_party/mlir-hlo b/third_party/mlir-hlo
index a41d2b9..9b43a08 160000
--- a/third_party/mlir-hlo
+++ b/third_party/mlir-hlo
@@ -1 +1 @@
-Subproject commit a41d2b935f15dc98b0e0ec356dd57a96f32cf147
+Subproject commit 9b43a08be8ad6a9c8d77f37f61a7be6e0ec8c200