Integrate LLVM at llvm/llvm-project@259cd6f89377

Updates LLVM usage to match
[259cd6f89377](https://github.com/llvm/llvm-project/commit/259cd6f89377)

PiperOrigin-RevId: 412425117
diff --git a/SUBMODULE_VERSIONS.txt b/SUBMODULE_VERSIONS.txt
index 2512ea9..78c52ff 100644
--- a/SUBMODULE_VERSIONS.txt
+++ b/SUBMODULE_VERSIONS.txt
@@ -4,7 +4,7 @@
 aa533abfd4232b01f9e57041d70114d5a77e6de0 third_party/googletest
 88b845dee001723c4a0db1fe5477de735b6d3bb0 third_party/liburing
 acd6f6f014c25e46363e718381e0b35205df2d83 third_party/libyaml
-43dc6d5d57d7e24d6d965ceac9fa9d292322d922 third_party/llvm-project
+259cd6f89377fdc17aabd381204c5bfe2ce15209 third_party/llvm-project
 5087ffb61af04fa8e35aa045a60b3356a8f69ca5 third_party/mlir-hlo
 3f701faace7addc75d16dea8a6cd769fa5b3f260 third_party/musl
 4c7697dbe973ed01ae6fbec37d186ebd05982e1f third_party/pybind11
diff --git a/iree/compiler/Codegen/Common/BUILD b/iree/compiler/Codegen/Common/BUILD
index 59f0b99..c4efb22 100644
--- a/iree/compiler/Codegen/Common/BUILD
+++ b/iree/compiler/Codegen/Common/BUILD
@@ -23,6 +23,7 @@
     tblgen = "@llvm-project//mlir:mlir-tblgen",
     td_file = "FoldTensorExtractOp.td",
     deps = [
+        "@llvm-project//mlir:BufferizationOpsTdFiles",
         "@llvm-project//mlir:MemRefOpsTdFiles",
         "@llvm-project//mlir:OpBaseTdFiles",
         "@llvm-project//mlir:TensorOpsTdFiles",
@@ -65,11 +66,13 @@
         "//llvm-external-projects/iree-dialects:IREELinalgExtTransforms",
         "@llvm-project//llvm:Support",
         "@llvm-project//mlir:Affine",
+        "@llvm-project//mlir:AffineBufferizableOpInterfaceImpl",
         "@llvm-project//mlir:AffineUtils",
         "@llvm-project//mlir:Analysis",
         "@llvm-project//mlir:ArithBufferizableOpInterfaceImpl",
         "@llvm-project//mlir:ArithmeticDialect",
         "@llvm-project//mlir:BufferizableOpInterface",
+        "@llvm-project//mlir:BufferizationDialect",
         "@llvm-project//mlir:CFGTransforms",
         "@llvm-project//mlir:ComprehensiveBufferize",
         "@llvm-project//mlir:DialectUtils",
diff --git a/iree/compiler/Codegen/Common/CMakeLists.txt b/iree/compiler/Codegen/Common/CMakeLists.txt
index 5980b33..65786c0 100644
--- a/iree/compiler/Codegen/Common/CMakeLists.txt
+++ b/iree/compiler/Codegen/Common/CMakeLists.txt
@@ -45,6 +45,7 @@
     IREELinalgExtPasses
     LLVMSupport
     MLIRAffine
+    MLIRAffineBufferizableOpInterfaceImpl
     MLIRAffineUtils
     MLIRAnalysis
     MLIRArithBufferizableOpInterfaceImpl
diff --git a/iree/compiler/Codegen/Common/FoldTensorExtractOp.td b/iree/compiler/Codegen/Common/FoldTensorExtractOp.td
index 84d8be1..98e86f6 100644
--- a/iree/compiler/Codegen/Common/FoldTensorExtractOp.td
+++ b/iree/compiler/Codegen/Common/FoldTensorExtractOp.td
@@ -7,12 +7,13 @@
 #ifndef IREE_COMPILER_CODEGEN_COMMON_FOLDTENSOREXTRACTOP
 #define IREE_COMPILER_CODEGEN_COMMON_FOLDTENSOREXTRACTOP
 
+include "mlir/Dialect/Bufferization/IR/BufferizationOps.td"
 include "mlir/Dialect/MemRef/IR/MemRefOps.td"
 include "mlir/Dialect/Tensor/IR/TensorOps.td"
 
 // Canonicalize unnecessary tensor_load when the load is used just for
 // an extract
-def : Pat<(Tensor_ExtractOp (TensorLoadOp $value), $indices),
+def : Pat<(Tensor_ExtractOp (Bufferization_ToTensorOp $value), $indices),
           (LoadOp $value, $indices)>;
 
 #endif // IREE_COMPILER_CODEGEN_COMMON_FOLDTENSOREXTRACTOP
diff --git a/iree/compiler/Codegen/Common/FoldTensorExtractOpPass.cpp b/iree/compiler/Codegen/Common/FoldTensorExtractOpPass.cpp
index ac4b3e7..959b32b 100644
--- a/iree/compiler/Codegen/Common/FoldTensorExtractOpPass.cpp
+++ b/iree/compiler/Codegen/Common/FoldTensorExtractOpPass.cpp
@@ -5,6 +5,7 @@
 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
 #include "iree/compiler/Codegen/PassDetail.h"
 #include "iree/compiler/Codegen/Passes.h"
+#include "mlir/Dialect/Bufferization/IR/Bufferization.h"
 #include "mlir/Dialect/MemRef/IR/MemRef.h"
 #include "mlir/Dialect/StandardOps/IR/Ops.h"
 #include "mlir/Dialect/Tensor/IR/Tensor.h"
@@ -35,13 +36,13 @@
 ///
 /// On LLVM side, the `std.constant` is handled by the
 /// `TensorConstantBufferizePass`, which creates a global object of `memref`
-/// type. To get the tensor back you get a tensor.load. If the above
-/// canonicalization pattern didnt exist, then a tensor.load would not be
+/// type. To get the tensor back you get a to_tensor. If the above
+/// canonicalization pattern didnt exist, then a to_tensor would not be
 /// needed.
 ///
 /// This pass is specifically undoing the canonicalization by folding
 ///
-/// (tensor_extract (tensor_load (get_global_memref:$value), $indices) to
+/// (tensor_extract (to_tensor (get_global_memref:$value), $indices) to
 ///
 /// (load $value, $indices)
 ///
diff --git a/iree/compiler/Codegen/Common/IREEComprehensiveBufferizePass.cpp b/iree/compiler/Codegen/Common/IREEComprehensiveBufferizePass.cpp
index c2bb7d6..08a6c78 100644
--- a/iree/compiler/Codegen/Common/IREEComprehensiveBufferizePass.cpp
+++ b/iree/compiler/Codegen/Common/IREEComprehensiveBufferizePass.cpp
@@ -30,10 +30,12 @@
 #include "llvm/Support/ErrorHandling.h"
 #include "mlir/Analysis/SliceAnalysis.h"
 #include "mlir/Dialect/Arithmetic/IR/Arithmetic.h"
+#include "mlir/Dialect/Linalg/ComprehensiveBufferize/AffineInterfaceImpl.h"
 #include "mlir/Dialect/Linalg/ComprehensiveBufferize/ArithInterfaceImpl.h"
 #include "mlir/Dialect/Linalg/ComprehensiveBufferize/BufferizableOpInterface.h"
 #include "mlir/Dialect/Linalg/ComprehensiveBufferize/ComprehensiveBufferize.h"
 #include "mlir/Dialect/Linalg/ComprehensiveBufferize/LinalgInterfaceImpl.h"
+#include "mlir/Dialect/Linalg/ComprehensiveBufferize/ModuleBufferization.h"
 #include "mlir/Dialect/Linalg/ComprehensiveBufferize/SCFInterfaceImpl.h"
 #include "mlir/Dialect/Linalg/ComprehensiveBufferize/TensorInterfaceImpl.h"
 #include "mlir/Dialect/Linalg/ComprehensiveBufferize/VectorInterfaceImpl.h"
@@ -206,11 +208,10 @@
 ///   DispatchTensorStoreOp to the InitTensorOp must have bufferized in-place.
 struct StoreTensorOpAnchoredInitTensorEliminationStep
     : public InitTensorEliminationStep {
-  LogicalResult run(FuncOp funcOp, BufferizationAliasInfo &aliasInfo,
-                    DominanceInfo &domInfo,
+  LogicalResult run(FuncOp funcOp, BufferizationState &state,
                     SmallVector<Operation *> &newOps) override {
     return eliminateInitTensors(
-        funcOp, aliasInfo, domInfo,
+        funcOp, state,
         /*anchorMatchFunc=*/
         [&](OpOperand &operand) {
           return isa<IREE::Flow::DispatchTensorStoreOp>(operand.getOwner());
@@ -251,7 +252,7 @@
 
     // TODO: Find a better place to register external models.
     // Registers operations of other dialects.
-    linalg::comprehensive_bufferize::
+    linalg::comprehensive_bufferize::affine_ext::
         registerBufferizableOpInterfaceExternalModels(registry);
     linalg::comprehensive_bufferize::arith_ext::
         registerBufferizableOpInterfaceExternalModels(registry);
diff --git a/iree/compiler/Codegen/Common/LinalgBufferizePass.cpp b/iree/compiler/Codegen/Common/LinalgBufferizePass.cpp
index 3238c69..684bcff 100644
--- a/iree/compiler/Codegen/Common/LinalgBufferizePass.cpp
+++ b/iree/compiler/Codegen/Common/LinalgBufferizePass.cpp
@@ -52,6 +52,7 @@
 #include "llvm/ADT/SetVector.h"
 #include "llvm/ADT/TypeSwitch.h"
 #include "mlir/Analysis/SliceAnalysis.h"
+#include "mlir/Dialect/Bufferization/IR/Bufferization.h"
 #include "mlir/Dialect/Linalg/IR/LinalgOps.h"
 #include "mlir/Dialect/MemRef/IR/MemRef.h"
 #include "mlir/Dialect/MemRef/Transforms/Passes.h"
@@ -632,8 +633,8 @@
   OpBuilder::InsertionGuard g(b);
   b.setInsertionPointAfter(constantOp);
   auto memrefType = getMemrefTypeForTensor(tensorType);
-  Value memref =
-      b.create<memref::BufferCastOp>(constantOp.getLoc(), memrefType, result);
+  Value memref = b.create<bufferization::ToMemrefOp>(constantOp.getLoc(),
+                                                     memrefType, result);
   bvm.map(result, memref);
   return success();
 }
@@ -883,7 +884,8 @@
  public:
   LinalgBufferizePass(WorkgroupMemoryAllocationFn fn) : allocationFn(fn) {}
   void getDependentDialects(DialectRegistry &registry) const override {
-    registry.insert<IREE::Util::UtilDialect, linalg::LinalgDialect,
+    registry.insert<mlir::bufferization::BufferizationDialect,
+                    IREE::Util::UtilDialect, linalg::LinalgDialect,
                     memref::MemRefDialect, scf::SCFDialect, StandardOpsDialect,
                     mlir::math::MathDialect, mlir::arith::ArithmeticDialect>();
   }
diff --git a/iree/compiler/Codegen/Common/test/fold_tensor_extract_op.mlir b/iree/compiler/Codegen/Common/test/fold_tensor_extract_op.mlir
index 92f2e96..798f5f7 100644
--- a/iree/compiler/Codegen/Common/test/fold_tensor_extract_op.mlir
+++ b/iree/compiler/Codegen/Common/test/fold_tensor_extract_op.mlir
@@ -4,7 +4,7 @@
 {
   %c1 = arith.constant 1 : index
   %c2 = arith.constant 2 : index
-  %0 = memref.tensor_load %arg0 : memref<2x3xi32>
+  %0 = bufferization.to_tensor %arg0 : memref<2x3xi32>
   %1 = tensor.extract %0[%c1, %c2] : tensor<2x3xi32>
   return %1 : i32
 }
diff --git a/iree/compiler/Codegen/Common/test/linalg_bufferize.mlir b/iree/compiler/Codegen/Common/test/linalg_bufferize.mlir
index 1f7e67b..4f6c87e 100644
--- a/iree/compiler/Codegen/Common/test/linalg_bufferize.mlir
+++ b/iree/compiler/Codegen/Common/test/linalg_bufferize.mlir
@@ -1057,7 +1057,7 @@
 }
 // CHECK-LABEL: func @constant()
 //       CHECK:   %[[CST:.+]] = arith.constant {{.+}} : tensor<2x2x3xi32>
-//       CHECK:   %[[MEMREF:.+]] = memref.buffer_cast %[[CST]] : memref<2x2x3xi32>
+//       CHECK:   %[[MEMREF:.+]] = bufferization.to_memref %[[CST]] : memref<2x2x3xi32>
 //       CHECK:   %[[RESULT:.+]] = hal.interface.binding.subspan @io::@ret0
 //       CHECK:   linalg.copy(%[[MEMREF]], %[[RESULT]])
 
@@ -1103,7 +1103,7 @@
 }
 // CHECK-LABEL: func @rhs_non_splat_constant
 //   CHECK-DAG:   %[[CONSTANT:.+]] = arith.constant {{.+}} : tensor<3x5xf32>
-//   CHECK-DAG:   %[[RHS:.+]] = memref.buffer_cast %[[CONSTANT]]
+//   CHECK-DAG:   %[[RHS:.+]] = bufferization.to_memref %[[CONSTANT]]
 //   CHECK-DAG:   %[[LHS_INPUT:.+]] = hal.interface.binding.subspan @io::@arg0[%{{.+}}] : memref<1x5x3x1xf32>
 //   CHECK-DAG:   %[[RETURN:.+]] = hal.interface.binding.subspan @io::@ret0[%{{.+}}] : memref<5x5xf32>
 //       CHECK:   %[[LHS:.+]] = memref.collapse_shape %[[LHS_INPUT]]
@@ -1451,8 +1451,8 @@
 
 //       CHECK-DAG: %[[CST1:.+]] = arith.constant dense<-2147483648> : tensor<i32>
 //       CHECK-DAG: %[[CST5:.+]] = arith.constant dense<[1, 2, 3, 4, 5]> : tensor<5xi32>
-//       CHECK: %[[CAST1:.+]] = memref.buffer_cast %[[CST1]] : memref<i32>
-//       CHECK: %[[CAST5:.+]] = memref.buffer_cast %[[CST5]] : memref<5xi32>
+//       CHECK: %[[CAST1:.+]] = bufferization.to_memref %[[CST1]] : memref<i32>
+//       CHECK: %[[CAST5:.+]] = bufferization.to_memref %[[CST5]] : memref<5xi32>
 //       CHECK: %[[INPUT:.+]] = hal.interface.binding.subspan @io::@ro0[%c0] : memref<5xf32>
 //       CHECK: %[[OUTPUT:.+]] = hal.interface.binding.subspan @io::@wo1[%c0] : memref<i32>
 //       CHECK: linalg.copy(%[[CAST1]], %[[OUTPUT]])
diff --git a/iree/compiler/Codegen/LLVMCPU/test/materialize_launch_configuration.mlir b/iree/compiler/Codegen/LLVMCPU/test/materialize_launch_configuration.mlir
index fa20497..e08abdf 100644
--- a/iree/compiler/Codegen/LLVMCPU/test/materialize_launch_configuration.mlir
+++ b/iree/compiler/Codegen/LLVMCPU/test/materialize_launch_configuration.mlir
@@ -591,8 +591,8 @@
         %c32 = arith.constant 32 : index
         %cst = arith.constant dense<[1.000000e+00, 0.707106769, 6.12323426E-17, -0.707106769]> : tensor<4xf32>
         %cst_0 = arith.constant dense<[-0.000000e+00, -0.707106769, -1.000000e+00, -0.707106769]> : tensor<4xf32>
-        %0 = memref.buffer_cast %cst_0 : memref<4xf32>
-        %1 = memref.buffer_cast %cst : memref<4xf32>
+        %0 = bufferization.to_memref %cst_0 : memref<4xf32>
+        %1 = bufferization.to_memref %cst : memref<4xf32>
         %2 = hal.interface.binding.subspan @io::@s0b0_rw_external[%c0] : memref<64x128x32xf32>
         %3 = hal.interface.binding.subspan @io::@s0b1_rw_external[%c0] : memref<64x128x32xf32>
         %workgroup_id_x = hal.interface.workgroup.id[0] : index
@@ -1029,7 +1029,7 @@
               %15 = affine.min affine_map<(d0) -> (2, -d0 + 7)>(%arg0)
               %16 = affine.min affine_map<(d0) -> (-d0 + 7, 2)>(%arg0)
               %17 = linalg.init_tensor [1, %16, %c7, %c64] : tensor<1x?x?x?xf32>
-              %18 = linalg.fill(%cst, %17) {__internal_linalg_transform__ = "workgroup"} : f32, tensor<1x?x?x?xf32> -> tensor<1x?x?x?xf32> 
+              %18 = linalg.fill(%cst, %17) {__internal_linalg_transform__ = "workgroup"} : f32, tensor<1x?x?x?xf32> -> tensor<1x?x?x?xf32>
               %19 = linalg.depthwise_conv_2d_nhwc_hwc {__internal_linalg_transform__ = "workgroup", dilations = dense<1> : tensor<2xi64>, strides = dense<1> : tensor<2xi64>} ins(%12, %14 : tensor<1x?x?x?xf32>, tensor<5x5x?xf32>) outs(%18 : tensor<1x?x?x?xf32>) -> tensor<1x?x?x?xf32>
               flow.dispatch.tensor.store %19, %2, offsets = [0, %arg0, %arg1, %arg2], sizes = [1, %15, %c7, %c64], strides = [1, 1, 1, 1] : tensor<1x?x?x?xf32> -> !flow.dispatch.tensor<writeonly:1x7x7x576xf32>
             }
@@ -1336,7 +1336,7 @@
             %11 = affine.min affine_map<(d0)[s0] -> (-d0 + 384, s0)>(%arg0)[%workgroup_size_y]
             %12 = affine.min affine_map<(d0)[s0] -> (-d0 + 128, s0)>(%arg1)[%workgroup_size_x]
             %13 = linalg.init_tensor [%11, %12] : tensor<?x?xf32>
-            %14 = linalg.fill(%cst, %13) : f32, tensor<?x?xf32> -> tensor<?x?xf32> 
+            %14 = linalg.fill(%cst, %13) : f32, tensor<?x?xf32> -> tensor<?x?xf32>
             %15 = linalg.matmul ins(%8, %10 : tensor<?x512xf32>, tensor<512x?xf32>) outs(%14 : tensor<?x?xf32>) -> tensor<?x?xf32>
             flow.dispatch.tensor.store %15, %2, offsets = [%arg0, %arg1], sizes = [%7, %9], strides = [1, 1] : tensor<?x?xf32> -> !flow.dispatch.tensor<writeonly:384x128xf32>
           }
diff --git a/iree/compiler/Codegen/LLVMGPU/test/gpu_set_num_workgroups.mlir b/iree/compiler/Codegen/LLVMGPU/test/gpu_set_num_workgroups.mlir
index 03bc55c..9124f75 100644
--- a/iree/compiler/Codegen/LLVMGPU/test/gpu_set_num_workgroups.mlir
+++ b/iree/compiler/Codegen/LLVMGPU/test/gpu_set_num_workgroups.mlir
@@ -239,7 +239,7 @@
             %9 = affine.apply affine_map<(d0) -> (d0 + 4)>(%arg0)
             %10 = affine.apply affine_map<(d0) -> (d0 + 3)>(%arg1)
             %11 = memref.subview %1[%9, %10] [%4, %7] [1, 1] : memref<?x?xi32> to memref<?x?xi32, affine_map<(d0, d1)[s0, s1] -> (d0 * s1 + s0 + d1)>>
-            linalg.copy(%8, %11) : memref<?x?xi32, affine_map<(d0, d1)[s0, s1] -> (d0 * s1 + s0 + d1)>>, memref<?x?xi32, affine_map<(d0, d1)[s0, s1] -> (d0 * s1 + s0 + d1)>> 
+            linalg.copy(%8, %11) : memref<?x?xi32, affine_map<(d0, d1)[s0, s1] -> (d0 * s1 + s0 + d1)>>, memref<?x?xi32, affine_map<(d0, d1)[s0, s1] -> (d0 * s1 + s0 + d1)>>
           }
         }
         return
@@ -322,8 +322,8 @@
         %c32 = arith.constant 32 : index
         %cst = arith.constant dense<[1.000000e+00, 0.707106769, 6.12323426E-17, -0.707106769]> : tensor<4xf32>
         %cst_0 = arith.constant dense<[-0.000000e+00, -0.707106769, -1.000000e+00, -0.707106769]> : tensor<4xf32>
-        %0 = memref.buffer_cast %cst_0 : memref<4xf32>
-        %1 = memref.buffer_cast %cst : memref<4xf32>
+        %0 = bufferization.to_memref %cst_0 : memref<4xf32>
+        %1 = bufferization.to_memref %cst : memref<4xf32>
         %2 = hal.interface.binding.subspan @io::@s0b0_rw_external[%c0] : memref<64x128x32xf32>
         %3 = hal.interface.binding.subspan @io::@s0b1_rw_external[%c0] : memref<64x128x32xf32>
         %workgroup_id_x = hal.interface.workgroup.id[0] : index
diff --git a/iree/compiler/Codegen/SPIRV/BUILD b/iree/compiler/Codegen/SPIRV/BUILD
index 4a6e3dc..64a03d1 100644
--- a/iree/compiler/Codegen/SPIRV/BUILD
+++ b/iree/compiler/Codegen/SPIRV/BUILD
@@ -54,6 +54,7 @@
         "@llvm-project//mlir:ArithmeticDialect",
         "@llvm-project//mlir:ArithmeticToSPIRV",
         "@llvm-project//mlir:ArithmeticTransforms",
+        "@llvm-project//mlir:BufferizationDialect",
         "@llvm-project//mlir:DialectUtils",
         "@llvm-project//mlir:GPUDialect",
         "@llvm-project//mlir:GPUToSPIRV",
diff --git a/iree/compiler/Codegen/SPIRV/ConvertToSPIRVPass.cpp b/iree/compiler/Codegen/SPIRV/ConvertToSPIRVPass.cpp
index e5f5546..e3f2a48 100644
--- a/iree/compiler/Codegen/SPIRV/ConvertToSPIRVPass.cpp
+++ b/iree/compiler/Codegen/SPIRV/ConvertToSPIRVPass.cpp
@@ -34,6 +34,7 @@
 #include "mlir/Conversion/StandardToSPIRV/StandardToSPIRV.h"
 #include "mlir/Conversion/TosaToStandard/TosaToStandard.h"
 #include "mlir/Conversion/VectorToSPIRV/VectorToSPIRV.h"
+#include "mlir/Dialect/Bufferization/IR/Bufferization.h"
 #include "mlir/Dialect/Linalg/IR/LinalgOps.h"
 #include "mlir/Dialect/MemRef/IR/MemRef.h"
 #include "mlir/Dialect/SPIRV/IR/SPIRVDialect.h"
@@ -391,7 +392,7 @@
   /// - unrealized_conversion_cast with the same source and target type.
   patterns.insert<
       FoldAsNoOp<memref::CollapseShapeOp>, FoldAsNoOp<memref::ExpandShapeOp>,
-      FoldAsNoOp<memref::BufferCastOp>, RemoveIdentityConversionCast>(
+      FoldAsNoOp<bufferization::ToMemrefOp>, RemoveIdentityConversionCast>(
       typeConverter, context);
 
   std::unique_ptr<ConversionTarget> target =
diff --git a/iree/compiler/Codegen/SPIRV/test/config_default_linalg_ext_ops.mlir b/iree/compiler/Codegen/SPIRV/test/config_default_linalg_ext_ops.mlir
index 179a92f..b92da86 100644
--- a/iree/compiler/Codegen/SPIRV/test/config_default_linalg_ext_ops.mlir
+++ b/iree/compiler/Codegen/SPIRV/test/config_default_linalg_ext_ops.mlir
@@ -188,8 +188,8 @@
         %c32 = arith.constant 32 : index
         %cst = arith.constant dense<[1.000000e+00, 0.707106769, 6.12323426E-17, -0.707106769]> : tensor<4xf32>
         %cst_0 = arith.constant dense<[-0.000000e+00, -0.707106769, -1.000000e+00, -0.707106769]> : tensor<4xf32>
-        %0 = memref.buffer_cast %cst_0 : memref<4xf32>
-        %1 = memref.buffer_cast %cst : memref<4xf32>
+        %0 = bufferization.to_memref %cst_0 : memref<4xf32>
+        %1 = bufferization.to_memref %cst : memref<4xf32>
         %2 = hal.interface.binding.subspan @io::@s0b0_rw_external[%c0] : memref<64x128x32xf32>
         %3 = hal.interface.binding.subspan @io::@s0b1_rw_external[%c0] : memref<64x128x32xf32>
         %workgroup_id_x = hal.interface.workgroup.id[0] : index
diff --git a/iree/compiler/Dialect/HAL/Transforms/BUILD b/iree/compiler/Dialect/HAL/Transforms/BUILD
index 341258a..05e662b 100644
--- a/iree/compiler/Dialect/HAL/Transforms/BUILD
+++ b/iree/compiler/Dialect/HAL/Transforms/BUILD
@@ -49,6 +49,7 @@
         "//iree/compiler/Dialect/Util/Transforms",
         "@llvm-project//llvm:Support",
         "@llvm-project//mlir:AffineToStandard",
+        "@llvm-project//mlir:BufferizationDialect",
         "@llvm-project//mlir:IR",
         "@llvm-project//mlir:Pass",
         "@llvm-project//mlir:StandardOps",
diff --git a/iree/compiler/Dialect/HAL/Transforms/TranslateExecutables.cpp b/iree/compiler/Dialect/HAL/Transforms/TranslateExecutables.cpp
index bc9ecc4..76d403f 100644
--- a/iree/compiler/Dialect/HAL/Transforms/TranslateExecutables.cpp
+++ b/iree/compiler/Dialect/HAL/Transforms/TranslateExecutables.cpp
@@ -12,6 +12,7 @@
 #include "iree/compiler/Dialect/HAL/Target/TargetBackend.h"
 #include "iree/compiler/Dialect/HAL/Target/TargetRegistry.h"
 #include "llvm/ADT/StringSet.h"
+#include "mlir/Dialect/Bufferization/IR/Bufferization.h"
 #include "mlir/IR/Attributes.h"
 #include "mlir/IR/Builders.h"
 #include "mlir/IR/Diagnostics.h"
@@ -44,6 +45,7 @@
 
   void getDependentDialects(DialectRegistry &registry) const override {
     registry.insert<IREE::HAL::HALDialect>();
+    registry.insert<bufferization::BufferizationDialect>();
     auto targetBackend = getTargetBackend(target);
     if (targetBackend) {
       targetBackend->getDependentDialects(registry);
@@ -102,6 +104,7 @@
 
   void getDependentDialects(DialectRegistry &registry) const override {
     registry.insert<IREE::HAL::HALDialect>();
+    registry.insert<bufferization::BufferizationDialect>();
     auto targetBackends = getTargetBackends(getRegisteredTargetBackends());
     for (auto &targetBackend : targetBackends) {
       targetBackend->getDependentDialects(registry);
diff --git a/iree/compiler/Dialect/Shape/IR/BUILD b/iree/compiler/Dialect/Shape/IR/BUILD
index 81a4410..ad79a52 100644
--- a/iree/compiler/Dialect/Shape/IR/BUILD
+++ b/iree/compiler/Dialect/Shape/IR/BUILD
@@ -69,6 +69,7 @@
         "@llvm-project//mlir:SideEffects",
         "@llvm-project//mlir:StandardOps",
         "@llvm-project//mlir:Support",
+        "@llvm-project//mlir:TensorDialect",
         "@llvm-project//mlir:Transforms",
         "@llvm-project//mlir:ViewLikeInterface",
     ],
diff --git a/iree/compiler/Dialect/Shape/IR/Builders.cpp b/iree/compiler/Dialect/Shape/IR/Builders.cpp
index 31cc942..ca4c6bf 100644
--- a/iree/compiler/Dialect/Shape/IR/Builders.cpp
+++ b/iree/compiler/Dialect/Shape/IR/Builders.cpp
@@ -10,6 +10,7 @@
 #include "iree/compiler/Dialect/Shape/IR/ShapeOps.h"
 #include "iree/compiler/Dialect/Shape/IR/ShapeTypes.h"
 #include "mlir/Dialect/StandardOps/IR/Ops.h"
+#include "mlir/Dialect/Tensor/IR/Tensor.h"
 #include "mlir/IR/Diagnostics.h"
 
 namespace mlir {
diff --git a/iree/compiler/Dialect/Shape/IR/Folders.cpp b/iree/compiler/Dialect/Shape/IR/Folders.cpp
index 59d3251..5322162 100644
--- a/iree/compiler/Dialect/Shape/IR/Folders.cpp
+++ b/iree/compiler/Dialect/Shape/IR/Folders.cpp
@@ -10,6 +10,7 @@
 #include "llvm/Support/Debug.h"
 #include "mlir/Dialect/MemRef/IR/MemRef.h"
 #include "mlir/Dialect/StandardOps/IR/Ops.h"
+#include "mlir/Dialect/Tensor/IR/Tensor.h"
 #include "mlir/IR/Matchers.h"
 #include "mlir/IR/PatternMatch.h"
 #include "mlir/Support/LogicalResult.h"
diff --git a/iree/compiler/Dialect/VM/Conversion/MemRefToVM/BUILD b/iree/compiler/Dialect/VM/Conversion/MemRefToVM/BUILD
index 236a670..b6c86ca 100644
--- a/iree/compiler/Dialect/VM/Conversion/MemRefToVM/BUILD
+++ b/iree/compiler/Dialect/VM/Conversion/MemRefToVM/BUILD
@@ -24,6 +24,7 @@
         "//iree/compiler/Dialect/VM/IR",
         "@llvm-project//mlir:Affine",
         "@llvm-project//mlir:ArithmeticDialect",
+        "@llvm-project//mlir:BufferizationDialect",
         "@llvm-project//mlir:IR",
         "@llvm-project//mlir:MemRefDialect",
         "@llvm-project//mlir:Pass",
diff --git a/iree/compiler/Dialect/VM/Conversion/MemRefToVM/ConvertMemRefToVM.cpp b/iree/compiler/Dialect/VM/Conversion/MemRefToVM/ConvertMemRefToVM.cpp
index 98cedac..defc7c7 100644
--- a/iree/compiler/Dialect/VM/Conversion/MemRefToVM/ConvertMemRefToVM.cpp
+++ b/iree/compiler/Dialect/VM/Conversion/MemRefToVM/ConvertMemRefToVM.cpp
@@ -12,6 +12,7 @@
 #include "iree/compiler/Dialect/VM/IR/VMOps.h"
 #include "mlir/Dialect/Affine/IR/AffineOps.h"
 #include "mlir/Dialect/Arithmetic/IR/Arithmetic.h"
+#include "mlir/Dialect/Bufferization/IR/Bufferization.h"
 #include "mlir/Dialect/MemRef/IR/MemRef.h"
 #include "mlir/Dialect/StandardOps/IR/Ops.h"
 #include "mlir/IR/Attributes.h"
@@ -233,7 +234,8 @@
     return llvm::None;
   });
 
-  patterns.insert<FoldAsNoOp<memref::BufferCastOp>>(typeConverter, context);
+  patterns.insert<FoldAsNoOp<bufferization::ToMemrefOp>>(typeConverter,
+                                                         context);
   patterns.insert<ConvertMemRefGlobalOp, ConvertMemRefGetGlobalOp,
                   ConvertMemRefLoadOp, ConvertMemRefStoreOp>(typeConverter,
                                                              context);
diff --git a/iree/tools/BUILD b/iree/tools/BUILD
index 87db68a..039fc0c 100644
--- a/iree/tools/BUILD
+++ b/iree/tools/BUILD
@@ -139,6 +139,7 @@
     deps = [
         "@llvm-project//mlir:Affine",
         "@llvm-project//mlir:AffineTransforms",
+        "@llvm-project//mlir:BufferizationDialect",
         "@llvm-project//mlir:ConversionPasses",
         "@llvm-project//mlir:GPUDialect",
         "@llvm-project//mlir:GPUToSPIRV",
diff --git a/iree/tools/init_mlir_dialects.h b/iree/tools/init_mlir_dialects.h
index 2e8148a..2ddcf82 100644
--- a/iree/tools/init_mlir_dialects.h
+++ b/iree/tools/init_mlir_dialects.h
@@ -13,6 +13,7 @@
 #define IREE_TOOLS_INIT_MLIR_DIALECTS_H_
 
 #include "mlir/Dialect/Affine/IR/AffineOps.h"
+#include "mlir/Dialect/Bufferization/IR/Bufferization.h"
 #include "mlir/Dialect/GPU/GPUDialect.h"
 #include "mlir/Dialect/LLVMIR/LLVMDialect.h"
 #include "mlir/Dialect/Linalg/IR/LinalgOps.h"
@@ -38,6 +39,7 @@
 inline void registerMlirDialects(DialectRegistry &registry) {
   // clang-format off
   registry.insert<AffineDialect,
+                  bufferization::BufferizationDialect,
                   gpu::GPUDialect,
                   LLVM::LLVMDialect,
                   linalg::LinalgDialect,
diff --git a/third_party/llvm-project b/third_party/llvm-project
index 43dc6d5..259cd6f 160000
--- a/third_party/llvm-project
+++ b/third_party/llvm-project
@@ -1 +1 @@
-Subproject commit 43dc6d5d57d7e24d6d965ceac9fa9d292322d922
+Subproject commit 259cd6f89377fdc17aabd381204c5bfe2ce15209