Merge pull request #6616 from NatashaKnk:main-to-google

PiperOrigin-RevId: 388278366
diff --git a/SUBMODULE_VERSIONS.txt b/SUBMODULE_VERSIONS.txt
index 47dfaa1..0605398 100644
--- a/SUBMODULE_VERSIONS.txt
+++ b/SUBMODULE_VERSIONS.txt
@@ -4,13 +4,13 @@
 aa533abfd4232b01f9e57041d70114d5a77e6de0 third_party/googletest
 88b845dee001723c4a0db1fe5477de735b6d3bb0 third_party/liburing
 acd6f6f014c25e46363e718381e0b35205df2d83 third_party/libyaml
-6a7a2ee8161da84d9a58a88b497b0b47c8df99f3 third_party/llvm-project
-be5715be382c6955aba0533a6dd5e50d79734bd0 third_party/mlir-hlo
+18ec93d9e60c687bfb2b39269f7f81d47b71a179 third_party/llvm-project
+7e00576f3f1e0a935b74a7320de40e56d3d87c11 third_party/mlir-hlo
 4c7697dbe973ed01ae6fbec37d186ebd05982e1f third_party/pybind11
 2e1b5fb39ebc2ef4cb77005f8267e4f3a6241ba1 third_party/spirv_cross
 f5417a4b6633c3217c9a1bc2f0c70b1454975ba7 third_party/spirv_headers
 b42009b3b9d4ca35bc703f5310eedc74f584be58 third_party/stblib
-faa23e33c73721e8b1e90010866e358b922f5d15 third_party/tensorflow
+0c225ef7450aea9b57cf3ece1e89c93f42a247af third_party/tensorflow
 50f7deb1a389bd3785c12fbe0be74128343f11f7 third_party/tracy
 9d10a96f2d57c3c37e167f2e73c9a31ac2e51fa5 third_party/vulkan_headers
 8d4a9e9174a9c6ad6a3a3ae981b915ef13fc12c4 third_party/vulkan_memory_allocator
diff --git a/integrations/tensorflow/iree_tf_compiler/TF/test/strip_asserts.mlir b/integrations/tensorflow/iree_tf_compiler/TF/test/strip_asserts.mlir
index 043c292..66dc284 100644
--- a/integrations/tensorflow/iree_tf_compiler/TF/test/strip_asserts.mlir
+++ b/integrations/tensorflow/iree_tf_compiler/TF/test/strip_asserts.mlir
@@ -1,4 +1,4 @@
-// RUN: iree-tf-opt -split-input-file -verify-diagnostics -pass-pipeline='func(iree-tf-strip-asserts)' %s | IreeFileCheck %s
+// RUN: iree-tf-opt -split-input-file -verify-diagnostics -pass-pipeline='builtin.func(iree-tf-strip-asserts)' %s | IreeFileCheck %s
 
 // CHECK-LABEL: @asserts
 // CHECK-NOT: tf.Assert
diff --git a/integrations/tensorflow/iree_tf_compiler/TF/test/strip_metadata.mlir b/integrations/tensorflow/iree_tf_compiler/TF/test/strip_metadata.mlir
index 3d7e2d8..7c1b696 100644
--- a/integrations/tensorflow/iree_tf_compiler/TF/test/strip_metadata.mlir
+++ b/integrations/tensorflow/iree_tf_compiler/TF/test/strip_metadata.mlir
@@ -1,4 +1,4 @@
-// RUN: iree-tf-opt -split-input-file -verify-diagnostics -pass-pipeline='iree-tf-strip-module-metadata,func(iree-tf-strip-function-metadata)' %s | IreeFileCheck %s
+// RUN: iree-tf-opt -split-input-file -verify-diagnostics -pass-pipeline='iree-tf-strip-module-metadata,builtin.func(iree-tf-strip-function-metadata)' %s | IreeFileCheck %s
 
 // CHECK-LABEL: @tf_module
 // CHECK-NOT: attributes
diff --git a/integrations/tensorflow/iree_tf_compiler/TFL/test/convert_metadata.mlir b/integrations/tensorflow/iree_tf_compiler/TFL/test/convert_metadata.mlir
index 5f49f81..b1e356a 100644
--- a/integrations/tensorflow/iree_tf_compiler/TFL/test/convert_metadata.mlir
+++ b/integrations/tensorflow/iree_tf_compiler/TFL/test/convert_metadata.mlir
@@ -1,4 +1,4 @@
-// RUN: iree-opt-tflite -split-input-file -pass-pipeline='iree-tflite-convert-module-metadata,func(iree-tflite-convert-function-metadata)' %s | IreeFileCheck %s
+// RUN: iree-opt-tflite -split-input-file -pass-pipeline='iree-tflite-convert-module-metadata,builtin.func(iree-tflite-convert-function-metadata)' %s | IreeFileCheck %s
 
 module attributes {tfl.schema_version = 3 : i32} {
   // CHECK: func @main(
diff --git a/integrations/tensorflow/iree_tf_compiler/TFL/test/strip_metadata.mlir b/integrations/tensorflow/iree_tf_compiler/TFL/test/strip_metadata.mlir
index 1c9504b..3ae5455 100644
--- a/integrations/tensorflow/iree_tf_compiler/TFL/test/strip_metadata.mlir
+++ b/integrations/tensorflow/iree_tf_compiler/TFL/test/strip_metadata.mlir
@@ -1,4 +1,4 @@
-// RUN: iree-opt-tflite -split-input-file -verify-diagnostics -pass-pipeline='iree-tflite-strip-module-metadata,func(iree-tflite-strip-function-metadata)' %s | IreeFileCheck %s
+// RUN: iree-opt-tflite -split-input-file -verify-diagnostics -pass-pipeline='iree-tflite-strip-module-metadata,builtin.func(iree-tflite-strip-function-metadata)' %s | IreeFileCheck %s
 
 // CHECK-LABEL: module {
 // CHECK-NOT: tf.schema_version
diff --git a/iree/compiler/Codegen/LLVMCPU/test/matmul_vectorization.mlir b/iree/compiler/Codegen/LLVMCPU/test/matmul_vectorization.mlir
index bfe3481..bcbd154 100644
--- a/iree/compiler/Codegen/LLVMCPU/test/matmul_vectorization.mlir
+++ b/iree/compiler/Codegen/LLVMCPU/test/matmul_vectorization.mlir
@@ -1,5 +1,5 @@
-// RUN: iree-opt -pass-pipeline="hal.executable(hal.executable.variant(iree-llvmcpu-lower-executable-target{use-lowering-pipeline='func(iree-llvmcpu-vectorization)'}))" -split-input-file %s | IreeFileCheck %s
-// RUN: iree-opt -pass-pipeline="hal.executable(hal.executable.variant(iree-llvmcpu-lower-executable-target{use-lowering-pipeline='func(iree-llvmcpu-vectorization{promote-workgroup-to-full-tiles}),cse'}))" -split-input-file %s | IreeFileCheck %s -check-prefix=CHECK-PROMOTED
+// RUN: iree-opt -pass-pipeline="hal.executable(hal.executable.variant(iree-llvmcpu-lower-executable-target{use-lowering-pipeline='builtin.func(iree-llvmcpu-vectorization)'}))" -split-input-file %s | IreeFileCheck %s
+// RUN: iree-opt -pass-pipeline="hal.executable(hal.executable.variant(iree-llvmcpu-lower-executable-target{use-lowering-pipeline='builtin.func(iree-llvmcpu-vectorization{promote-workgroup-to-full-tiles}),cse'}))" -split-input-file %s | IreeFileCheck %s -check-prefix=CHECK-PROMOTED
 
 #config = {nativeVectorSize = [4, 4, 4], tileSizes = [[64, 64], [32, 32, 32], [4, 4, 4]]}
 hal.executable @dynamic_matmul attributes {sym_visibility = "private"} {
diff --git a/iree/compiler/Codegen/SPIRV/BUILD b/iree/compiler/Codegen/SPIRV/BUILD
index e77a60c..5ec14cf 100644
--- a/iree/compiler/Codegen/SPIRV/BUILD
+++ b/iree/compiler/Codegen/SPIRV/BUILD
@@ -55,7 +55,9 @@
         "@llvm-project//mlir:IR",
         "@llvm-project//mlir:LinalgOps",
         "@llvm-project//mlir:LinalgTransforms",
+        "@llvm-project//mlir:MathToSPIRV",
         "@llvm-project//mlir:MemRefDialect",
+        "@llvm-project//mlir:MemRefToSPIRV",
         "@llvm-project//mlir:MemRefTransforms",
         "@llvm-project//mlir:Pass",
         "@llvm-project//mlir:SCFDialect",
diff --git a/iree/compiler/Codegen/SPIRV/CMakeLists.txt b/iree/compiler/Codegen/SPIRV/CMakeLists.txt
index 7e9070c..31c4888 100644
--- a/iree/compiler/Codegen/SPIRV/CMakeLists.txt
+++ b/iree/compiler/Codegen/SPIRV/CMakeLists.txt
@@ -43,7 +43,9 @@
     MLIRIR
     MLIRLinalg
     MLIRLinalgTransforms
+    MLIRMathToSPIRV
     MLIRMemRef
+    MLIRMemRefToSPIRV
     MLIRMemRefTransforms
     MLIRPass
     MLIRSCF
diff --git a/iree/compiler/Codegen/SPIRV/ConvertToSPIRVPass.cpp b/iree/compiler/Codegen/SPIRV/ConvertToSPIRVPass.cpp
index 7c3cbb7..c3aa239 100644
--- a/iree/compiler/Codegen/SPIRV/ConvertToSPIRVPass.cpp
+++ b/iree/compiler/Codegen/SPIRV/ConvertToSPIRVPass.cpp
@@ -25,6 +25,8 @@
 #include "llvm/ADT/STLExtras.h"
 #include "llvm/Support/FormatVariadic.h"
 #include "mlir/Conversion/GPUToSPIRV/GPUToSPIRV.h"
+#include "mlir/Conversion/MathToSPIRV/MathToSPIRV.h"
+#include "mlir/Conversion/MemRefToSPIRV/MemRefToSPIRV.h"
 #include "mlir/Conversion/SCFToSPIRV/SCFToSPIRV.h"
 #include "mlir/Conversion/StandardToSPIRV/StandardToSPIRV.h"
 #include "mlir/Conversion/TosaToStandard/TosaToStandard.h"
@@ -315,8 +317,12 @@
   // TODO(antiagainst): Use a lowering that uses specific SPIRV intrinsics.
   tosa::populateTosaRescaleToStandardConversionPatterns(&patterns);
 
-  // Pull in standard patterns to convert arithmetic ops and others.
+  // Pull in MemRef patterns to convert load/store ops.
+  populateMemRefToSPIRVPatterns(typeConverter, patterns);
+
+  // Pull in standard/math patterns to convert arithmetic ops and others.
   populateStandardToSPIRVPatterns(typeConverter, patterns);
+  populateMathToSPIRVPatterns(typeConverter, patterns);
 
   // Pull in standard patterns to convert tensor operations to SPIR-V. These are
   // primarily used to handle tensor-type constants and contain a
diff --git a/iree/compiler/Codegen/SPIRV/test/vector_to_cooperative_matrix.mlir b/iree/compiler/Codegen/SPIRV/test/vector_to_cooperative_matrix.mlir
index 33ff001..53861fa 100644
--- a/iree/compiler/Codegen/SPIRV/test/vector_to_cooperative_matrix.mlir
+++ b/iree/compiler/Codegen/SPIRV/test/vector_to_cooperative_matrix.mlir
@@ -42,7 +42,7 @@
     %c0_i8 = constant 0 : i8
     // CHECK: %[[C:.+]] = spv.CooperativeMatrixLoadNV
     %4 = vector.transfer_read %arg2[%c0, %c0], %c0_i32 {in_bounds = [true, true]} : memref<4096x4096xi32>, vector<16x16xi32>
-    // CHECK: %[[INIT:.+]] = unrealized_conversion_cast %[[C]] : !spv.coopmatrix<16x16xi32, Subgroup> to vector<16x16xi32>
+    // CHECK: %[[INIT:.+]] = builtin.unrealized_conversion_cast %[[C]] : !spv.coopmatrix<16x16xi32, Subgroup> to vector<16x16xi32>
     // CHECK: %[[LOOP:.+]] = scf.for
     // CHECK-SAME: iter_args(%[[ARG:.+]] = %[[INIT]])
     %5 = scf.for %arg3 = %c0 to %c4096 step %c32 iter_args(%arg4 = %4) -> (vector<16x16xi32>) {
@@ -50,14 +50,14 @@
       %6 = vector.transfer_read %arg0[%c0, %arg3], %c0_i8 {in_bounds = [true, true]} : memref<4096x4096xi8>, vector<16x32xi8>
       // CHECK: %[[B:.+]] = spv.CooperativeMatrixLoadNV
       %7 = vector.transfer_read %arg1[%arg3, %c0], %c0_i8 {in_bounds = [true, true]} : memref<4096x4096xi8>, vector<32x16xi8>
-      // CHECK: %[[C1:.+]] = unrealized_conversion_cast %[[ARG]] : vector<16x16xi32> to !spv.coopmatrix<16x16xi32, Subgroup>
+      // CHECK: %[[C1:.+]] = builtin.unrealized_conversion_cast %[[ARG]] : vector<16x16xi32> to !spv.coopmatrix<16x16xi32, Subgroup>
       // CHECK: %[[R:.+]] = spv.CooperativeMatrixMulAddNV %[[A]], %[[B]], %[[C1]]
       %8 = vector.contract {indexing_maps = [#map1, #map2, #map3], iterator_types = ["parallel", "parallel", "reduction"]} %6, %7, %arg4 : vector<16x32xi8>, vector<32x16xi8> into vector<16x16xi32>
-      // CHECK: %[[YIELD:.+]] = unrealized_conversion_cast %[[R]] : !spv.coopmatrix<16x16xi32, Subgroup> to vector<16x16xi32>
+      // CHECK: %[[YIELD:.+]] = builtin.unrealized_conversion_cast %[[R]] : !spv.coopmatrix<16x16xi32, Subgroup> to vector<16x16xi32>
       // CHECK: scf.yield %[[YIELD]]
       scf.yield %8 : vector<16x16xi32>
     }
-    // CHECK: %[[ACCv:.+]] = unrealized_conversion_cast %[[LOOP]] : vector<16x16xi32> to !spv.coopmatrix<16x16xi32, Subgroup>
+    // CHECK: %[[ACCv:.+]] = builtin.unrealized_conversion_cast %[[LOOP]] : vector<16x16xi32> to !spv.coopmatrix<16x16xi32, Subgroup>
     // CHECK: spv.CooperativeMatrixStoreNV %{{.*}}, %[[ACCv]], %{{.*}}, %{{.*}}
     vector.transfer_write %5, %arg2[%c0, %c0] : vector<16x16xi32>, memref<4096x4096xi32>
     return
diff --git a/iree/compiler/Dialect/Flow/Transforms/test/inject_dispatch_tracing.mlir b/iree/compiler/Dialect/Flow/Transforms/test/inject_dispatch_tracing.mlir
index 855df02..5993201 100644
--- a/iree/compiler/Dialect/Flow/Transforms/test/inject_dispatch_tracing.mlir
+++ b/iree/compiler/Dialect/Flow/Transforms/test/inject_dispatch_tracing.mlir
@@ -1,4 +1,4 @@
-// RUN: iree-opt -split-input-file -pass-pipeline='func(iree-flow-inject-dispatch-tracing)' %s | IreeFileCheck %s
+// RUN: iree-opt -split-input-file -pass-pipeline='builtin.func(iree-flow-inject-dispatch-tracing)' %s | IreeFileCheck %s
 
 // CHECK-LABEL: func @singleDispatch
 // CHECK-SAME: (%[[ARG0:.+]]: tensor<4xf32>)
diff --git a/iree/compiler/Dialect/Flow/Transforms/test/insert_constant_clones.mlir b/iree/compiler/Dialect/Flow/Transforms/test/insert_constant_clones.mlir
index 6e399f0..b05fed4 100644
--- a/iree/compiler/Dialect/Flow/Transforms/test/insert_constant_clones.mlir
+++ b/iree/compiler/Dialect/Flow/Transforms/test/insert_constant_clones.mlir
@@ -1,4 +1,4 @@
-// RUN: iree-opt -split-input-file -pass-pipeline='func(iree-flow-insert-constant-clones)' %s | IreeFileCheck %s
+// RUN: iree-opt -split-input-file -pass-pipeline='builtin.func(iree-flow-insert-constant-clones)' %s | IreeFileCheck %s
 
 // CHECK-LABEL: @function_return
 func @function_return() -> (tensor<8xf32>, i32) {
diff --git a/iree/compiler/Dialect/Flow/Transforms/test/promote_i1_to_i8.mlir b/iree/compiler/Dialect/Flow/Transforms/test/promote_i1_to_i8.mlir
index 99f962a..b351f8a 100644
--- a/iree/compiler/Dialect/Flow/Transforms/test/promote_i1_to_i8.mlir
+++ b/iree/compiler/Dialect/Flow/Transforms/test/promote_i1_to_i8.mlir
@@ -1,5 +1,5 @@
 
-// RUN: iree-opt -split-input-file -pass-pipeline='func(iree-flow-promote-i1-to-i8)' %s | IreeFileCheck %s
+// RUN: iree-opt -split-input-file -pass-pipeline='builtin.func(iree-flow-promote-i1-to-i8)' %s | IreeFileCheck %s
 
 // CHECK: #[[$MAP:.+]] = affine_map<(d0) -> (d0)>
 
diff --git a/iree/compiler/Dialect/HAL/IR/HALOps.cpp b/iree/compiler/Dialect/HAL/IR/HALOps.cpp
index d6a0731..e172be6 100644
--- a/iree/compiler/Dialect/HAL/IR/HALOps.cpp
+++ b/iree/compiler/Dialect/HAL/IR/HALOps.cpp
@@ -1477,8 +1477,11 @@
   return push_constantsAttr() == other.push_constantsAttr() &&
          bindings.size() == otherBindings.size() &&
          llvm::all_of(llvm::zip(bindings, otherBindings), [](auto bindings) {
-           return OperationEquivalence::isEquivalentTo(std::get<0>(bindings),
-                                                       std::get<1>(bindings));
+           return OperationEquivalence::isEquivalentTo(
+               std::get<0>(bindings), std::get<1>(bindings),
+               OperationEquivalence::exactValueMatch,
+               OperationEquivalence::exactValueMatch,
+               OperationEquivalence::Flags::IgnoreLocations);
          });
 }
 
diff --git a/iree/compiler/Dialect/HAL/Target/LLVM/LLVMIRPasses.cpp b/iree/compiler/Dialect/HAL/Target/LLVM/LLVMIRPasses.cpp
index cd0f42e..128f5d5 100644
--- a/iree/compiler/Dialect/HAL/Target/LLVM/LLVMIRPasses.cpp
+++ b/iree/compiler/Dialect/HAL/Target/LLVM/LLVMIRPasses.cpp
@@ -27,7 +27,7 @@
 namespace HAL {
 
 static llvm::CodeGenOpt::Level passBuilderOptLevelToCodeGenOptLevel(
-    const llvm::PassBuilder::OptimizationLevel &level) {
+    const llvm::OptimizationLevel &level) {
   switch (level.getSpeedupLevel()) {
     case 0:
       return llvm::CodeGenOpt::None;
@@ -87,7 +87,7 @@
     case SanitizerKind::kAddress: {
       passBuilder.registerOptimizerLastEPCallback(
           [](llvm::ModulePassManager &modulePassManager,
-             llvm::PassBuilder::OptimizationLevel Level) {
+             llvm::OptimizationLevel Level) {
             bool compileKernel = false;
             bool recover = false;
             bool useAfterScope = true;
@@ -105,7 +105,7 @@
     } break;
   }
 
-  if (options.optLevel != llvm::PassBuilder::OptimizationLevel::O0) {
+  if (options.optLevel != llvm::OptimizationLevel::O0) {
     llvm::ModulePassManager modulePassManager;
     modulePassManager =
         passBuilder.buildPerModuleDefaultPipeline(options.optLevel);
diff --git a/iree/compiler/Dialect/HAL/Target/LLVM/LLVMTargetOptions.cpp b/iree/compiler/Dialect/HAL/Target/LLVM/LLVMTargetOptions.cpp
index 7c9cff3..3272866 100644
--- a/iree/compiler/Dialect/HAL/Target/LLVM/LLVMTargetOptions.cpp
+++ b/iree/compiler/Dialect/HAL/Target/LLVM/LLVMTargetOptions.cpp
@@ -45,7 +45,7 @@
 
   // LLVM -O3.
   // TODO(benvanik): add an option for this.
-  targetOptions.optLevel = llvm::PassBuilder::OptimizationLevel::O3;
+  targetOptions.optLevel = llvm::OptimizationLevel::O3;
   targetOptions.options.FloatABIType = llvm::FloatABI::Hard;
 
   return targetOptions;
diff --git a/iree/compiler/Dialect/HAL/Target/LLVM/LLVMTargetOptions.h b/iree/compiler/Dialect/HAL/Target/LLVM/LLVMTargetOptions.h
index fc3d362..0a0124f 100644
--- a/iree/compiler/Dialect/HAL/Target/LLVM/LLVMTargetOptions.h
+++ b/iree/compiler/Dialect/HAL/Target/LLVM/LLVMTargetOptions.h
@@ -29,7 +29,7 @@
   std::string targetCPUFeatures;
 
   llvm::PipelineTuningOptions pipelineTuningOptions;
-  llvm::PassBuilder::OptimizationLevel optLevel;
+  llvm::OptimizationLevel optLevel;
   llvm::TargetOptions options;
 
   // Include debug information in output files (PDB, DWARF, etc).
diff --git a/iree/compiler/Dialect/HAL/Target/TargetBackend.cpp b/iree/compiler/Dialect/HAL/Target/TargetBackend.cpp
index 2d0ce2e..f7d0626 100644
--- a/iree/compiler/Dialect/HAL/Target/TargetBackend.cpp
+++ b/iree/compiler/Dialect/HAL/Target/TargetBackend.cpp
@@ -86,7 +86,10 @@
       if (auto targetOp = targetSymbolMap[symbolName]) {
         if (symbolOp.getVisibility() == SymbolTable::Visibility::Private) {
           // Private symbols can be safely folded into duplicates or renamed.
-          if (OperationEquivalence::isEquivalentTo(targetOp, op)) {
+          if (OperationEquivalence::isEquivalentTo(
+                  targetOp, op, OperationEquivalence::exactValueMatch,
+                  OperationEquivalence::exactValueMatch,
+                  OperationEquivalence::Flags::IgnoreLocations)) {
             // Optimization: skip over duplicate private symbols.
             // We could let CSE do this later, but we may as well check here.
             continue;
diff --git a/iree/compiler/Dialect/Modules/VMVX/Conversion/HALToVMVX/test/interface_ops.mlir b/iree/compiler/Dialect/Modules/VMVX/Conversion/HALToVMVX/test/interface_ops.mlir
index 27f1d28..b5b65e4 100644
--- a/iree/compiler/Dialect/Modules/VMVX/Conversion/HALToVMVX/test/interface_ops.mlir
+++ b/iree/compiler/Dialect/Modules/VMVX/Conversion/HALToVMVX/test/interface_ops.mlir
@@ -28,10 +28,10 @@
   %c1 = constant 1 : index
   %0 = memref.get_global @__constant_5xi32 : memref<5xi32>
   //      CHECK: %[[BINDING0_RAW:.+]] = iree.list.get %[[BINDINGS]][%c0] : !iree.list<memref<?xi8>>
-  // CHECK-NEXT: %[[BINDING0:.+]] = unrealized_conversion_cast %[[BINDING0_RAW]] : memref<?xi8> to memref<5xf32>
+  // CHECK-NEXT: %[[BINDING0:.+]] = builtin.unrealized_conversion_cast %[[BINDING0_RAW]] : memref<?xi8> to memref<5xf32>
   %1 = hal.interface.binding.subspan @io::@s0b0_ro_external[%c0] : memref<5xf32>
   //      CHECK: %[[BINDING1_RAW:.+]] = iree.list.get %[[BINDINGS]][%c1] : !iree.list<memref<?xi8>>
-  // CHECK-NEXT: %[[BINDING1:.+]] = unrealized_conversion_cast %[[BINDING1_RAW]] : memref<?xi8> to memref<5xi32>
+  // CHECK-NEXT: %[[BINDING1:.+]] = builtin.unrealized_conversion_cast %[[BINDING1_RAW]] : memref<?xi8> to memref<5xi32>
   %2 = hal.interface.binding.subspan @io::@s0b1_xw_external[%c0] : memref<5xi32>
   %workgroup_size_x = hal.interface.workgroup.size[0] : index
   %workgroup_id_x = hal.interface.workgroup.id[0] : index
diff --git a/iree/compiler/InputConversion/MHLO/test/broadcasting.mlir b/iree/compiler/InputConversion/MHLO/test/broadcasting.mlir
index c23d9d8..ed0272e 100644
--- a/iree/compiler/InputConversion/MHLO/test/broadcasting.mlir
+++ b/iree/compiler/InputConversion/MHLO/test/broadcasting.mlir
@@ -408,7 +408,7 @@
 // -----
 // CHECK-LABEL: @fallbackDynamicReshape
 func @fallbackDynamicReshape(%arg0 : tensor<4x?x3x?xui32>, %arg1 : tensor<5xindex>) -> tensor<12x?x?x1x?xui32> {
-  // CHECK: %[[INPUT:.*]] = unrealized_conversion_cast %arg0 : tensor<4x?x3x?xui32> to tensor<4x?x3x?xi32>
+  // CHECK: %[[INPUT:.*]] = builtin.unrealized_conversion_cast %arg0 : tensor<4x?x3x?xui32> to tensor<4x?x3x?xi32>
   // CHECK-DAG: %[[C1:.*]] = constant 1 : index
   // CHECK-DAG: %[[RESULT_D1:.*]] = tensor.extract %arg1[%[[C1]]] : tensor<5xindex>
   // CHECK-DAG: %[[C2:.*]] = constant 2 : index
@@ -421,7 +421,7 @@
   // CHECK-DAG: %[[ARG_D3:.*]] = tensor.dim %[[INPUT]], %[[INDEX3]] : tensor<4x?x3x?xi32>
   // CHECK-DAG: %[[RESULT:.*]] = flow.tensor.reshape %[[INPUT]] : tensor<4x?x3x?xi32>{%[[ARG_D1]], %[[ARG_D3]]} -> tensor<12x?x?x1x?xi32>{%[[RESULT_D1]], %[[RESULT_D2]], %[[RESULT_D4]]}
   %0 = "mhlo.dynamic_reshape"(%arg0, %arg1) : (tensor<4x?x3x?xui32>, tensor<5xindex>) -> tensor<12x?x?x1x?xui32>
-  // CHECK: %[[UNCONVERTED_RESULT:.*]] = unrealized_conversion_cast %[[RESULT]] : tensor<12x?x?x1x?xi32> to tensor<12x?x?x1x?xui32>
+  // CHECK: %[[UNCONVERTED_RESULT:.*]] = builtin.unrealized_conversion_cast %[[RESULT]] : tensor<12x?x?x1x?xi32> to tensor<12x?x?x1x?xui32>
   // CHECK: return %[[UNCONVERTED_RESULT]]
   return %0 : tensor<12x?x?x1x?xui32>
 }
diff --git a/iree/compiler/InputConversion/MHLO/test/legalize_input_types.mlir b/iree/compiler/InputConversion/MHLO/test/legalize_input_types.mlir
index fdeafe1..f1e06b4 100644
--- a/iree/compiler/InputConversion/MHLO/test/legalize_input_types.mlir
+++ b/iree/compiler/InputConversion/MHLO/test/legalize_input_types.mlir
@@ -88,7 +88,7 @@
 }
 
 // -----
-// expected-error@+1 {{'func' op unable to legalize type of input 0}}
+// expected-error@+1 {{'builtin.func' op unable to legalize type of input 0}}
 func @tensorUnrankedArg(%arg0 : tensor<*xi64>) -> tensor<*xi64> {
   return %arg0 : tensor<*xi64>
 }
diff --git a/third_party/llvm-project b/third_party/llvm-project
index 6a7a2ee..18ec93d 160000
--- a/third_party/llvm-project
+++ b/third_party/llvm-project
@@ -1 +1 @@
-Subproject commit 6a7a2ee8161da84d9a58a88b497b0b47c8df99f3
+Subproject commit 18ec93d9e60c687bfb2b39269f7f81d47b71a179
diff --git a/third_party/mlir-hlo b/third_party/mlir-hlo
index be5715b..7e00576 160000
--- a/third_party/mlir-hlo
+++ b/third_party/mlir-hlo
@@ -1 +1 @@
-Subproject commit be5715be382c6955aba0533a6dd5e50d79734bd0
+Subproject commit 7e00576f3f1e0a935b74a7320de40e56d3d87c11
diff --git a/third_party/tensorflow b/third_party/tensorflow
index faa23e3..0c225ef 160000
--- a/third_party/tensorflow
+++ b/third_party/tensorflow
@@ -1 +1 @@
-Subproject commit faa23e33c73721e8b1e90010866e358b922f5d15
+Subproject commit 0c225ef7450aea9b57cf3ece1e89c93f42a247af