Merge remote-tracking branch 'origin' into google-to-main
diff --git a/SUBMODULE_VERSIONS.txt b/SUBMODULE_VERSIONS.txt
index 9680b9d..2055f64 100644
--- a/SUBMODULE_VERSIONS.txt
+++ b/SUBMODULE_VERSIONS.txt
@@ -3,15 +3,15 @@
 32aabef14ccc0e3ac15602c99476d25dcd230d19 third_party/flatcc
 aa533abfd4232b01f9e57041d70114d5a77e6de0 third_party/googletest
 88b845dee001723c4a0db1fe5477de735b6d3bb0 third_party/liburing
-39c6fa385c29b90a97511d4aff7269e40ad210eb third_party/llvm-bazel
-4c4f1ae93ea7477ccb4772007fc78313f5a0644f third_party/llvm-project
+30724fe230e9af8e408560f23435d820b5c99924 third_party/llvm-bazel
+36111f28edb1182273c6409c3fb7808e0e9cbd60 third_party/llvm-project
 1a4dea1387e34538ba159a56204a6982e728e337 third_party/mlir-emitc
-82ecad259ef0566a4500fa59192457ff69f83fe6 third_party/mlir-hlo
+b6a8145dafcd927c0600af35d811aa5ef8297d6c third_party/mlir-hlo
 4c7697dbe973ed01ae6fbec37d186ebd05982e1f third_party/pybind11
 2e1b5fb39ebc2ef4cb77005f8267e4f3a6241ba1 third_party/spirv_cross
 f5417a4b6633c3217c9a1bc2f0c70b1454975ba7 third_party/spirv_headers
 b42009b3b9d4ca35bc703f5310eedc74f584be58 third_party/stblib
-ff964e92797b4d7e5c23a3143bfc74d1572e6c7b third_party/tensorflow
+3ef6fbfd02716f774024afae711383fdb0d8a30c third_party/tensorflow
 f03b677ffa0fd96fcf859c32e79b740fac7dd59e third_party/tracy
 9d10a96f2d57c3c37e167f2e73c9a31ac2e51fa5 third_party/vulkan_headers
 8d4a9e9174a9c6ad6a3a3ae981b915ef13fc12c4 third_party/vulkan_memory_allocator
diff --git a/integrations/tensorflow/e2e/mobile_bert_squad_test.py b/integrations/tensorflow/e2e/mobile_bert_squad_test.py
index bc11a57..99abbbb 100644
--- a/integrations/tensorflow/e2e/mobile_bert_squad_test.py
+++ b/integrations/tensorflow/e2e/mobile_bert_squad_test.py
@@ -26,10 +26,10 @@
 
 MAX_SEQ_LENGTH = 384  # Max input sequence length used in mobilebert_squad.
 
-FILE_NAME = 'mobilebert_squad_savedmodels.tar.gz'
+FILE_NAME = 'mobilebert_squad_savedmodels'
 MODEL_URL = posixpath.join(
-    'https://storage.googleapis.com/cloud-tpu-checkpoints/mobilebert/',
-    FILE_NAME)
+    f'https://storage.googleapis.com/cloud-tpu-checkpoints/mobilebert/{FILE_NAME}.tar.gz'
+)
 
 
 class MobileBertSquadTest(tf_test_utils.TracedModuleTestCase):
diff --git a/iree/compiler/Conversion/Common/LinalgBufferizePass.cpp b/iree/compiler/Conversion/Common/LinalgBufferizePass.cpp
index cf7381b..d8e0aff 100644
--- a/iree/compiler/Conversion/Common/LinalgBufferizePass.cpp
+++ b/iree/compiler/Conversion/Common/LinalgBufferizePass.cpp
@@ -53,6 +53,7 @@
 #include "mlir/Dialect/MemRef/IR/MemRef.h"
 #include "mlir/Dialect/MemRef/Transforms/Passes.h"
 #include "mlir/Dialect/SCF/SCF.h"
+#include "mlir/Dialect/Tensor/IR/Tensor.h"
 #include "mlir/Dialect/Vector/VectorOps.h"
 #include "mlir/IR/BuiltinTypes.h"
 #include "mlir/IR/Matchers.h"
@@ -195,8 +196,8 @@
       .Case<ConstantOp>([&](ConstantOp constantOp) { return true; })
       .Case<linalg::TensorCollapseShapeOp, linalg::TensorExpandShapeOp>(
           [&](auto op) { return isFromReadOnlyTensor(op.src(), plan); })
-      .Case<SubTensorOp>([&](SubTensorOp subTensorOp) {
-        return isFromReadOnlyTensor(subTensorOp.source(), plan);
+      .Case<tensor::ExtractSliceOp>([&](tensor::ExtractSliceOp sliceOp) {
+        return isFromReadOnlyTensor(sliceOp.source(), plan);
       })
       .Case<IREE::Flow::DispatchTensorLoadOp>(
           [&](IREE::Flow::DispatchTensorLoadOp loadOp) {
@@ -425,7 +426,7 @@
   return success();
 }
 
-static LogicalResult analyseSubTensorOp(SubTensorOp subTensorOp,
+static LogicalResult analyseSubTensorOp(tensor::ExtractSliceOp subTensorOp,
                                         BufferizationPlan &plan) {
   if (!canUsersHandleSubviews(subTensorOp)) {
     plan.insert(subTensorOp.source());
@@ -505,7 +506,7 @@
 ///   %result = scf.for %arg0 = ... iter_args(%arg1 = %init) {
-///     %st = subtensor %arg1[...]
+///     %st = tensor.extract_slice %arg1[...]
 ///
-///     %yieldVal = subtensor_insert %val, %arg1[...]
+///     %yieldVal = tensor.insert_slice %val, %arg1[...]
 ///     scf.yield %yieldVal
 ///   }
 ///
@@ -520,27 +521,29 @@
     auto isDestructiveUpdateUses = [&](OpOperand &use) -> bool {
       Operation *user = use.getOwner();
       return TypeSwitch<Operation *, bool>(user)
-          .Case<SubTensorOp>([&](SubTensorOp subTensorOp) {
-            return subTensorOp.source() == arg;
+          .Case<tensor::ExtractSliceOp>([&](tensor::ExtractSliceOp sliceOp) {
+            return sliceOp.source() == arg;
           })
-          .Case<SubTensorInsertOp>([&](SubTensorInsertOp subTensorInsertOp) {
-            return subTensorInsertOp.dest() == arg;
-          })
+          .Case<tensor::InsertSliceOp>(
+              [&](tensor::InsertSliceOp subTensorInsertOp) {
+                return subTensorInsertOp.dest() == arg;
+              })
           .Case<memref::DimOp, scf::YieldOp>([&](auto op) { return true; })
           .Default([&](Operation *op) { return false; });
     };
     if (llvm::all_of(arg.getUses(), isDestructiveUpdateUses)) {
       for (Operation *user : arg.getUsers()) {
         TypeSwitch<Operation *>(user)
-            .Case<SubTensorOp>([&](SubTensorOp subTensorOp) {
-              plan.unionSets(subTensorOp.source(), subTensorOp.result());
+            .Case<tensor::ExtractSliceOp>([&](tensor::ExtractSliceOp sliceOp) {
+              plan.unionSets(sliceOp.source(), sliceOp.result());
             })
-            .Case<SubTensorInsertOp>([&](SubTensorInsertOp subTensorInsertOp) {
-              if (!isFromReadOnlyTensor(subTensorInsertOp.source(), plan)) {
-                plan.unionSets(subTensorInsertOp.source(),
-                               subTensorInsertOp.dest());
-              }
-            })
+            .Case<tensor::InsertSliceOp>(
+                [&](tensor::InsertSliceOp subTensorInsertOp) {
+                  if (!isFromReadOnlyTensor(subTensorInsertOp.source(), plan)) {
+                    plan.unionSets(subTensorInsertOp.source(),
+                                   subTensorInsertOp.dest());
+                  }
+                })
             .Default([&](Operation *) {});
       }
     }
@@ -582,14 +585,15 @@
               return analyseSingleOperandResultOp(reshapeOp.src(),
                                                   reshapeOp.result(), plan);
             })
-        .Case<SubTensorOp>([&](SubTensorOp subTensorOp) {
-          return analyseSubTensorOp(subTensorOp, plan);
+        .Case<tensor::ExtractSliceOp>([&](tensor::ExtractSliceOp sliceOp) {
+          return analyseSubTensorOp(sliceOp, plan);
         })
-        .Case<SubTensorInsertOp>([&](SubTensorInsertOp subTensorInsertOp) {
-          return analyseDestructiveUpdateOp(
-              subTensorInsertOp, subTensorInsertOp.source(),
-              subTensorInsertOp.dest(), subTensorInsertOp.result(), plan);
-        })
+        .Case<tensor::InsertSliceOp>(
+            [&](tensor::InsertSliceOp subTensorInsertOp) {
+              return analyseDestructiveUpdateOp(
+                  subTensorInsertOp, subTensorInsertOp.source(),
+                  subTensorInsertOp.dest(), subTensorInsertOp.result(), plan);
+            })
         .Case<tensor::CastOp>([&](tensor::CastOp castOp) {
           return analyseSingleOperandResultOp(castOp.source(), castOp.dest(),
                                               plan);
@@ -681,9 +685,9 @@
     }
   } else if (auto loadOp = dyn_cast<IREE::Flow::DispatchTensorLoadOp>(op)) {
     dynamicDims = llvm::to_vector<4>(loadOp.sizes());
-  } else if (auto subTensorOp = dyn_cast<SubTensorOp>(op)) {
-    dynamicDims = llvm::to_vector<4>(subTensorOp.sizes());
-  } else if (auto subTensorInsertOp = dyn_cast<SubTensorInsertOp>(op)) {
+  } else if (auto sliceOp = dyn_cast<tensor::ExtractSliceOp>(op)) {
+    dynamicDims = llvm::to_vector<4>(sliceOp.sizes());
+  } else if (auto subTensorInsertOp = dyn_cast<tensor::InsertSliceOp>(op)) {
     dynamicDims = getDynamicDims(b, loc, subTensorInsertOp.dest());
   } else if (auto transferWriteOp = dyn_cast<vector::TransferWriteOp>(op)) {
     dynamicDims = getDynamicDims(b, loc, transferWriteOp.source());
@@ -791,12 +795,12 @@
                                           const BlockAndValueMapping &bvm) {
   SmallVector<Value, 4> operandsOfSubviewOp;
   auto op = cast<OffsetSizeAndStrideOpInterface>(storeOp);
-  Value target =
-      TypeSwitch<Operation *, Value>(op)
-          .Case<IREE::Flow::DispatchTensorStoreOp>(
-              [&](auto storeOp) { return storeOp.target(); })
-          .Case<SubTensorInsertOp>([&](auto storeOp) { return storeOp.dest(); })
-          .Default([](Operation *) { return nullptr; });
+  Value target = TypeSwitch<Operation *, Value>(op)
+                     .Case<IREE::Flow::DispatchTensorStoreOp>(
+                         [&](auto storeOp) { return storeOp.target(); })
+                     .Case<tensor::InsertSliceOp>(
+                         [&](auto storeOp) { return storeOp.dest(); })
+                     .Default([](Operation *) { return nullptr; });
   if (!target) return nullptr;
   operandsOfSubviewOp.push_back(bvm.lookup(target));
   operandsOfSubviewOp.append(op.offsets().begin(), op.offsets().end());
@@ -905,8 +909,8 @@
     Operation *op = it.first->getOwner();
     resultBuffer =
         TypeSwitch<Operation *, Value>(op)
-            .Case<scf::IfOp, scf::ForOp, linalg::LinalgOp, SubTensorInsertOp,
-                  vector::TransferWriteOp>(
+            .Case<scf::IfOp, scf::ForOp, linalg::LinalgOp,
+                  tensor::InsertSliceOp, vector::TransferWriteOp>(
                 [&](auto op) { return resultBuffer; })
             .Case<linalg::TensorCollapseShapeOp, linalg::TensorExpandShapeOp>(
                 [&](auto reshapeOp) {
@@ -980,7 +984,7 @@
 }
 
-/// Converts a `subtensor` operation to a `subview` operation.
+/// Converts a `tensor.extract_slice` operation to a `subview` operation.
-static Value getAliasingBufferForResult(OpBuilder &b, SubTensorOp op,
+static Value getAliasingBufferForResult(OpBuilder &b, tensor::ExtractSliceOp op,
                                         BlockAndValueMapping &bvm) {
   Location loc = op.getLoc();
   Value srcTensor = op.source();
@@ -1020,10 +1024,10 @@
 static SmallVector<Value, 4> getAliasingBuffersForResults(
     OpBuilder &b, Operation *op, BlockAndValueMapping &bvm) {
   return TypeSwitch<Operation *, SmallVector<Value, 4>>(op)
-      .Case<IREE::Flow::DispatchTensorLoadOp, SubTensorOp, tensor::CastOp>(
-          [&](auto singleResultOp) -> SmallVector<Value, 4> {
-            return {getAliasingBufferForResult(b, singleResultOp, bvm)};
-          })
+      .Case<IREE::Flow::DispatchTensorLoadOp, tensor::ExtractSliceOp,
+            tensor::CastOp>([&](auto singleResultOp) -> SmallVector<Value, 4> {
+        return {getAliasingBufferForResult(b, singleResultOp, bvm)};
+      })
       .Case<linalg::TensorCollapseShapeOp, linalg::TensorExpandShapeOp>(
           [&](auto reshapeOp) -> SmallVector<Value, 4> {
             return {getAliasingBufferForReshapeResult(b, reshapeOp, bvm)};
@@ -1238,12 +1242,12 @@
   return success();
 }
 
-/// Converts a `subtensor_insert` operation to buffers by
+/// Converts a `tensor.insert_slice` operation to buffers by
 /// - Allocating a buffer for the result (if needed), and copying the
 ///   destination value into this buffer.
 /// - Copying the source values into a subview of the result buffer.
 static LogicalResult convertSubTensorInsertOp(OpBuilder &b,
-                                              SubTensorInsertOp op,
+                                              tensor::InsertSliceOp op,
                                               BlockAndValueMapping &bvm,
                                               BufferizationPlan &plan) {
   Location loc = op.getLoc();
@@ -1519,20 +1523,19 @@
           return convertScfIfOp(b, ifOp, bvm, plan);
         })
         .Case<IREE::Flow::DispatchTensorLoadOp, linalg::TensorCollapseShapeOp,
-              linalg::TensorExpandShapeOp, SubTensorOp, tensor::CastOp>(
-            [&](auto aliasingOp) {
-              auto aliasingBuffers =
-                  getAliasingBuffersForResults(b, aliasingOp, bvm);
-              if (failed(getOrAllocateResultBuffers(b, aliasingOp,
-                                                    aliasingBuffers, bvm, plan,
-                                                    allocationFn))) {
-                return failure();
-              }
-              copyFromAliasingBufferToResultBuffer(
-                  b, aliasingOp->getLoc(), aliasingOp->getOperand(0),
-                  aliasingOp->getResult(0), aliasingBuffers, bvm, plan);
-              return success();
-            })
+              linalg::TensorExpandShapeOp, tensor::ExtractSliceOp,
+              tensor::CastOp>([&](auto aliasingOp) {
+          auto aliasingBuffers =
+              getAliasingBuffersForResults(b, aliasingOp, bvm);
+          if (failed(getOrAllocateResultBuffers(b, aliasingOp, aliasingBuffers,
+                                                bvm, plan, allocationFn))) {
+            return failure();
+          }
+          copyFromAliasingBufferToResultBuffer(
+              b, aliasingOp->getLoc(), aliasingOp->getOperand(0),
+              aliasingOp->getResult(0), aliasingBuffers, bvm, plan);
+          return success();
+        })
         .Case<linalg::PadTensorOp>([&](linalg::PadTensorOp padTensorOp) {
           if (failed(getOrAllocateResultBuffers(b, padTensorOp, bvm, plan,
                                                 allocationFn))) {
@@ -1547,13 +1550,14 @@
           }
           return convertAnyLinalgOp(b, linalgOp, bvm, plan, allocationFn);
         })
-        .Case<SubTensorInsertOp>([&](SubTensorInsertOp subTensorInsertOp) {
-          if (failed(getOrAllocateResultBuffers(b, subTensorInsertOp, bvm, plan,
-                                                allocationFn))) {
-            return failure();
-          }
-          return convertSubTensorInsertOp(b, subTensorInsertOp, bvm, plan);
-        })
+        .Case<tensor::InsertSliceOp>(
+            [&](tensor::InsertSliceOp subTensorInsertOp) {
+              if (failed(getOrAllocateResultBuffers(b, subTensorInsertOp, bvm,
+                                                    plan, allocationFn))) {
+                return failure();
+              }
+              return convertSubTensorInsertOp(b, subTensorInsertOp, bvm, plan);
+            })
         .Case<tensor::InsertOp>([&](tensor::InsertOp tensorInsertOp) {
           if (failed(getOrAllocateResultBuffers(b, tensorInsertOp, bvm, plan,
                                                 allocationFn))) {
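(Illustrative sketch, not part of the patch.) The LLVM bump replaces the standard-dialect SubTensorOp/SubTensorInsertOp with tensor::ExtractSliceOp/tensor::InsertSliceOp; the accessors (source(), dest(), result()) carry over unchanged, so the analysis TypeSwitches above translate mechanically. A minimal sketch of the idiom, assuming the headers included above and a hypothetical helper name:

// Hedged sketch: `touchesViaSlice` is a hypothetical helper, not code from the
// patch. It dispatches on the renamed tensor-dialect ops exactly as the updated
// pass does, assuming `using llvm::TypeSwitch;` and `using namespace mlir;` as
// in the surrounding file.
static bool touchesViaSlice(Operation *user, Value v) {
  return TypeSwitch<Operation *, bool>(user)
      // subtensor        -> tensor.extract_slice: reads from source().
      .Case<tensor::ExtractSliceOp>(
          [&](tensor::ExtractSliceOp sliceOp) { return sliceOp.source() == v; })
      // subtensor_insert -> tensor.insert_slice: writes into dest().
      .Case<tensor::InsertSliceOp>(
          [&](tensor::InsertSliceOp insertOp) { return insertOp.dest() == v; })
      .Default([](Operation *) { return false; });
}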
diff --git a/iree/compiler/Conversion/Common/test/linalg_bufferize.mlir b/iree/compiler/Conversion/Common/test/linalg_bufferize.mlir
index 27375ad..99bd30e 100644
--- a/iree/compiler/Conversion/Common/test/linalg_bufferize.mlir
+++ b/iree/compiler/Conversion/Common/test/linalg_bufferize.mlir
@@ -800,7 +800,7 @@
     %8 = muli %workgroup_size_x, %workgroup_count_x : index
     scf.for %arg1 = %7 to %c3 step %8 {
       %9 = affine.min affine_map<(d0)[s0] -> (s0, -d0 + 1)>(%arg0)[%workgroup_size_y]
-      %10 = subtensor %4[%arg0, 0] [%9, 2] [1, 1] : tensor<1x2xf32> to tensor<?x2xf32>
+      %10 = tensor.extract_slice %4[%arg0, 0] [%9, 2] [1, 1] : tensor<1x2xf32> to tensor<?x2xf32>
       %11 = affine.min affine_map<(d0)[s0] -> (s0, -d0 + 3)>(%arg1)[%workgroup_size_x]
       %12 = flow.dispatch.tensor.load %1, offsets = [%c0, %arg1], sizes = [%c2, %11], strides = [%c1, %c1] : !flow.dispatch.tensor<readonly:2x3xf32> -> tensor<2x?xf32>
       %13 = linalg.init_tensor [%9, %11] : tensor<?x?xf32>
@@ -842,7 +842,7 @@
   %4 = hal.interface.load.constant offset = 2 : index
   %5 = hal.interface.load.constant offset = 3 : index
   %6 = flow.dispatch.tensor.load %0, offsets = [], sizes = [], strides = [] : !flow.dispatch.tensor<readonly:?x?xi32> -> tensor<?x?xi32>
-  %7 = subtensor %6[%2, %3] [%4, %5] [1, 1] : tensor<?x?xi32> to tensor<?x?xi32>
+  %7 = tensor.extract_slice %6[%2, %3] [%4, %5] [1, 1] : tensor<?x?xi32> to tensor<?x?xi32>
   flow.dispatch.tensor.store %7, %1, offsets = [], sizes = [], strides = [] : tensor<?x?xi32> -> !flow.dispatch.tensor<writeonly:?x?xi32>
   return
 }
@@ -867,7 +867,7 @@
   %4 = hal.interface.load.constant offset = 2 : index
   %5 = hal.interface.load.constant offset = 3 : index
   %6 = flow.dispatch.tensor.load %0, offsets = [], sizes = [], strides = [] : !flow.dispatch.tensor<readonly:?x?x?xi32> -> tensor<?x?x?xi32>
-  %7 = subtensor %6[%2, %2, %3] [%4, 1, %5] [1, 1, 1] : tensor<?x?x?xi32> to tensor<?x?xi32>
+  %7 = tensor.extract_slice %6[%2, %2, %3] [%4, 1, %5] [1, 1, 1] : tensor<?x?x?xi32> to tensor<?x?xi32>
   flow.dispatch.tensor.store %7, %1, offsets = [], sizes = [], strides = [] : tensor<?x?xi32> -> !flow.dispatch.tensor<writeonly:?x?xi32>
   return
 }
@@ -895,8 +895,8 @@
   %7 = hal.interface.load.constant offset = 4 : index
   %8 = hal.interface.load.constant offset = 5 : index
   %9 = flow.dispatch.tensor.load %0, offsets = [], sizes = [], strides = [] : !flow.dispatch.tensor<readonly:?x?x?xi32> -> tensor<?x?x?xi32>
-  %10 = subtensor %9[%3, %4, %5] [%6, %7, %8] [1, 1, 1] : tensor<?x?x?xi32> to tensor<?x?x?xi32>
-  %11 = subtensor %9[%3, %4, %5] [%6, 1, %8] [1, 1, 1] : tensor<?x?x?xi32> to tensor<?x?xi32>
+  %10 = tensor.extract_slice %9[%3, %4, %5] [%6, %7, %8] [1, 1, 1] : tensor<?x?x?xi32> to tensor<?x?x?xi32>
+  %11 = tensor.extract_slice %9[%3, %4, %5] [%6, 1, %8] [1, 1, 1] : tensor<?x?x?xi32> to tensor<?x?xi32>
   flow.dispatch.tensor.store %10, %1, offsets = [], sizes = [], strides = [] : tensor<?x?x?xi32> -> !flow.dispatch.tensor<writeonly:?x?x?xi32>
   flow.dispatch.tensor.store %11, %2, offsets = [%3, %5], sizes = [%6, %8], strides = [1, 1] : tensor<?x?xi32> -> !flow.dispatch.tensor<writeonly:?x?xi32>
   return
@@ -946,7 +946,7 @@
   %0 = hal.interface.binding.subspan @io::@arg0[%c0] : !flow.dispatch.tensor<readonly:?x?xi32>
   %1 = hal.interface.binding.subspan @io::@ret0[%c0] : !flow.dispatch.tensor<writeonly:?x?xi32>
   %2 = flow.dispatch.tensor.load %0, offsets = [], sizes = [], strides = [] : !flow.dispatch.tensor<readonly:?x?xi32> -> tensor<?x?xi32>
-  %3 = subtensor %2[1, 0] [1, 4] [1, 1] : tensor<?x?xi32> to tensor<1x4xi32>
+  %3 = tensor.extract_slice %2[1, 0] [1, 4] [1, 1] : tensor<?x?xi32> to tensor<1x4xi32>
   flow.dispatch.tensor.store %3, %1, offsets = [], sizes = [], strides = [] : tensor<1x4xi32> -> !flow.dispatch.tensor<writeonly:?x?xi32>
   return
 }
@@ -972,7 +972,7 @@
   %4 = flow.dispatch.tensor.load %1, offsets = [], sizes = [], strides = [] : !flow.dispatch.tensor<readonly:?x?xi32> -> tensor<?x?xi32>
   %5 = memref.dim %3, %c0 : tensor<?x?xi32>
   %6 = memref.dim %3, %c1 : tensor<?x?xi32>
-  %7 = subtensor_insert %3 into %4[3, 4] [%5, %6] [1, 1] : tensor<?x?xi32> into tensor<?x?xi32>
+  %7 = tensor.insert_slice %3 into %4[3, 4] [%5, %6] [1, 1] : tensor<?x?xi32> into tensor<?x?xi32>
   flow.dispatch.tensor.store %7, %2, offsets = [], sizes = [], strides = [] : tensor<?x?xi32> -> !flow.dispatch.tensor<writeonly:?x?xi32>
   return
 }
@@ -1077,9 +1077,9 @@
     %7 = affine.apply affine_map<()[s0, s1] -> (s0 * s1)>()[%workgroup_count_x, %workgroup_size_x]
     scf.for %arg1 = %6 to %c5 step %7 {
       %8 = affine.min affine_map<(d0)[s0] -> (s0, -d0 + 5)>(%arg0)[%workgroup_size_y]
-      %9 = subtensor %3[%arg0, 0] [%8, 3] [1, 1] : tensor<5x3xf32> to tensor<?x3xf32>
+      %9 = tensor.extract_slice %3[%arg0, 0] [%8, 3] [1, 1] : tensor<5x3xf32> to tensor<?x3xf32>
       %10 = affine.min affine_map<(d0)[s0] -> (s0, -d0 + 5)>(%arg1)[%workgroup_size_x]
-      %11 = subtensor %cst[0, %arg1] [3, %10] [1, 1] : tensor<3x5xf32> to tensor<3x?xf32>
+      %11 = tensor.extract_slice %cst[0, %arg1] [3, %10] [1, 1] : tensor<3x5xf32> to tensor<3x?xf32>
       %12 = linalg.init_tensor [%8, %10] : tensor<?x?xf32>
       %13 = linalg.fill(%12, %cst_0) : tensor<?x?xf32>, f32 -> tensor<?x?xf32>
       %14 = linalg.matmul {__internal_linalg_transform__ = "workgroup"} ins(%9, %11 : tensor<?x3xf32>, tensor<3x?xf32>) outs(%13 : tensor<?x?xf32>) -> tensor<?x?xf32>
@@ -1211,16 +1211,16 @@
     scf.for %arg1 = %7 to %dim1 step %8 {
       %9 = affine.min affine_map<(d0)[s0, s1] -> (s0, -d0 + s1)>(%arg0)[%workgroup_size_y, %dim0]
       %10 = affine.min affine_map<(d0)[s0, s1] -> (s0, -d0 + s1)>(%arg1)[%workgroup_size_x, %dim1]
-      %11 = subtensor %2[%arg0, %arg1] [%9, %10] [1, 1] : tensor<?x?xf32> to tensor<?x?xf32>
+      %11 = tensor.extract_slice %2[%arg0, %arg1] [%9, %10] [1, 1] : tensor<?x?xf32> to tensor<?x?xf32>
       %12 = affine.min affine_map<(d0)[s0, s1] -> (s0, -d0 + s1)>(%arg0)[%workgroup_size_y, %dim0]
       %13 = affine.min affine_map<(d0)[s0, s1] -> (s0, -d0 + s1)>(%arg1)[%workgroup_size_x, %dim1]
-      %14 = subtensor %2[%arg0, %arg1] [%12, %13] [1, 1] : tensor<?x?xf32> to tensor<?x?xf32>
+      %14 = tensor.extract_slice %2[%arg0, %arg1] [%12, %13] [1, 1] : tensor<?x?xf32> to tensor<?x?xf32>
       %15 = affine.min affine_map<(d0)[s0, s1] -> (s0, -d0 + s1)>(%arg0)[%workgroup_size_y, %dim0]
       %16 = affine.min affine_map<(d0)[s0, s1] -> (s0, -d0 + s1)>(%arg1)[%workgroup_size_x, %dim1]
-      %17 = subtensor %4[%arg0, %arg1] [%15, %16] [1, 1] : tensor<?x?xf32> to tensor<?x?xf32>
+      %17 = tensor.extract_slice %4[%arg0, %arg1] [%15, %16] [1, 1] : tensor<?x?xf32> to tensor<?x?xf32>
       %18 = affine.min affine_map<(d0)[s0, s1] -> (s0, -d0 + s1)>(%arg0)[%workgroup_size_y, %dim0]
       %19 = affine.min affine_map<(d0)[s0, s1] -> (s0, -d0 + s1)>(%arg1)[%workgroup_size_x, %dim1]
-      %20 = subtensor %4[%arg0, %arg1] [%18, %19] [1, 1] : tensor<?x?xf32> to tensor<?x?xf32>
+      %20 = tensor.extract_slice %4[%arg0, %arg1] [%18, %19] [1, 1] : tensor<?x?xf32> to tensor<?x?xf32>
       %21 = affine.min affine_map<(d0)[s0, s1] -> (s0, -d0 + s1)>(%arg0)[%workgroup_size_y, %dim0]
       %22 = affine.min affine_map<(d0)[s0, s1] -> (s0, -d0 + s1)>(%arg1)[%workgroup_size_x, %dim1]
       %23 = linalg.init_tensor [%21, %22] : tensor<?x?xf32>
@@ -1498,7 +1498,7 @@
   %3 = flow.dispatch.tensor.load %1, offsets = [], sizes = [], strides = [] : !flow.dispatch.tensor<readwrite:?x?x?xf32> -> tensor<?x?x?xf32>
   %4 = memref.dim %3, %c1 : tensor<?x?x?xf32>
   %5 = memref.dim %3, %c2 : tensor<?x?x?xf32>
-  %6 = subtensor_insert %2 into %3[0, 0, 0] [1, %4, %5] [1, 1, 1] : tensor<?x?xf32> into tensor<?x?x?xf32>
+  %6 = tensor.insert_slice %2 into %3[0, 0, 0] [1, %4, %5] [1, 1, 1] : tensor<?x?xf32> into tensor<?x?x?xf32>
   flow.dispatch.tensor.store %6, %1, offsets = [], sizes = [], strides = [] : tensor<?x?x?xf32> -> !flow.dispatch.tensor<readwrite:?x?x?xf32>
   return
 }
@@ -1802,7 +1802,7 @@
       %15 = linalg.init_tensor [4, 4] : tensor<4x4xf32>
       %16 = linalg.fill(%15, %cst) : tensor<4x4xf32>, f32 -> tensor<4x4xf32>
       %17 = linalg.matmul {__internal_linalg_transform__ = "workgroup"} ins(%13, %14 : tensor<4x4xf32>, tensor<4x4xf32>) outs(%16 : tensor<4x4xf32>) -> tensor<4x4xf32>
-      %18 = subtensor %17[0, 0] [%7, %9] [1, 1] : tensor<4x4xf32> to tensor<?x?xf32>
+      %18 = tensor.extract_slice %17[0, 0] [%7, %9] [1, 1] : tensor<4x4xf32> to tensor<?x?xf32>
       flow.dispatch.tensor.store %18, %2, offsets = [%arg0, %arg1], sizes = [%11, %12], strides = [1, 1] : tensor<?x?xf32> -> !flow.dispatch.tensor<writeonly:?x?xf32>
     }
   }
@@ -2032,16 +2032,16 @@
             %15 = scf.for %arg6 = %c0 to %c144 step %c24 iter_args(%arg7 = %arg5) -> (tensor<?x?xf32>) {
               %16 = affine.min #map3(%arg2)
               %17 = affine.min #map4(%arg6)
-              %18 = subtensor %8[%arg2, %arg6] [%16, %17] [1, 1] : tensor<?x144xf32> to tensor<?x?xf32>
+              %18 = tensor.extract_slice %8[%arg2, %arg6] [%16, %17] [1, 1] : tensor<?x144xf32> to tensor<?x?xf32>
               %19 = affine.min #map5(%arg4)
-              %20 = subtensor %10[%arg6, %arg4] [%17, %19] [1, 1] : tensor<144x?xf32> to tensor<?x?xf32>
+              %20 = tensor.extract_slice %10[%arg6, %arg4] [%17, %19] [1, 1] : tensor<144x?xf32> to tensor<?x?xf32>
               %21 = memref.dim %arg7, %c0 : tensor<?x?xf32>
               %22 = affine.min #map6(%21, %arg2)
               %23 = memref.dim %arg7, %c1 : tensor<?x?xf32>
               %24 = affine.min #map6(%23, %arg4)
-              %25 = subtensor %arg7[%arg2, %arg4] [%22, %24] [1, 1] : tensor<?x?xf32> to tensor<?x?xf32>
+              %25 = tensor.extract_slice %arg7[%arg2, %arg4] [%22, %24] [1, 1] : tensor<?x?xf32> to tensor<?x?xf32>
               %26 = linalg.matmul {__internal_linalg_transform__ = "workgroup_l1_tile", lowering.config = #config1} ins(%18, %20 : tensor<?x?xf32>, tensor<?x?xf32>) outs(%25 : tensor<?x?xf32>) -> tensor<?x?xf32>
-              %27 = subtensor_insert %26 into %arg7[%arg2, %arg4] [%22, %24] [1, 1] : tensor<?x?xf32> into tensor<?x?xf32>
+              %27 = tensor.insert_slice %26 into %arg7[%arg2, %arg4] [%22, %24] [1, 1] : tensor<?x?xf32> into tensor<?x?xf32>
               scf.yield %27 : tensor<?x?xf32>
             }
             scf.yield %15 : tensor<?x?xf32>
@@ -2137,16 +2137,16 @@
             %15 = scf.for %arg6 = %c0 to %c144 step %c24 iter_args(%arg7 = %arg5) -> (tensor<?x?xf32>) {
               %16 = affine.min #map3(%arg2)
               %17 = affine.min #map4(%arg6)
-              %18 = subtensor %8[%arg2, %arg6] [%16, %17] [1, 1] : tensor<?x144xf32> to tensor<?x?xf32>
+              %18 = tensor.extract_slice %8[%arg2, %arg6] [%16, %17] [1, 1] : tensor<?x144xf32> to tensor<?x?xf32>
               %19 = affine.min #map5(%arg4)
-              %20 = subtensor %10[%arg6, %arg4] [%17, %19] [1, 1] : tensor<144x?xf32> to tensor<?x?xf32>
+              %20 = tensor.extract_slice %10[%arg6, %arg4] [%17, %19] [1, 1] : tensor<144x?xf32> to tensor<?x?xf32>
               %21 = memref.dim %arg7, %c0 : tensor<?x?xf32>
               %22 = affine.min #map6(%21, %arg2)
               %23 = memref.dim %arg7, %c1 : tensor<?x?xf32>
               %24 = affine.min #map6(%23, %arg4)
-              %25 = subtensor %arg7[%arg2, %arg4] [%22, %24] [1, 1] : tensor<?x?xf32> to tensor<?x?xf32>
+              %25 = tensor.extract_slice %arg7[%arg2, %arg4] [%22, %24] [1, 1] : tensor<?x?xf32> to tensor<?x?xf32>
               %26 = linalg.matmul {__internal_linalg_transform__ = "workgroup_l1_tile", lowering.config = #config1} ins(%18, %20 : tensor<?x?xf32>, tensor<?x?xf32>) outs(%25 : tensor<?x?xf32>) -> tensor<?x?xf32>
-              %27 = subtensor_insert %26 into %arg7[%arg2, %arg4] [%22, %24] [1, 1] : tensor<?x?xf32> into tensor<?x?xf32>
+              %27 = tensor.insert_slice %26 into %arg7[%arg2, %arg4] [%22, %24] [1, 1] : tensor<?x?xf32> into tensor<?x?xf32>
               scf.yield %27 : tensor<?x?xf32>
             }
             scf.yield %15 : tensor<?x?xf32>
@@ -2238,16 +2238,16 @@
             %15 = scf.for %arg6 = %c0 to %c144 step %c24 iter_args(%arg7 = %arg5) -> (tensor<?x?xf32>) {
               %16 = affine.min #map3(%arg2)
               %17 = affine.min #map4(%arg6)
-              %18 = subtensor %8[%arg2, %arg6] [%16, %17] [1, 1] : tensor<?x144xf32> to tensor<?x?xf32>
+              %18 = tensor.extract_slice %8[%arg2, %arg6] [%16, %17] [1, 1] : tensor<?x144xf32> to tensor<?x?xf32>
               %19 = affine.min #map5(%arg4)
-              %20 = subtensor %10[%arg6, %arg4] [%17, %19] [1, 1] : tensor<144x?xf32> to tensor<?x?xf32>
+              %20 = tensor.extract_slice %10[%arg6, %arg4] [%17, %19] [1, 1] : tensor<144x?xf32> to tensor<?x?xf32>
               %21 = memref.dim %arg7, %c0 : tensor<?x?xf32>
               %22 = affine.min #map6(%21, %arg2)
               %23 = memref.dim %arg7, %c1 : tensor<?x?xf32>
               %24 = affine.min #map6(%23, %arg4)
-              %25 = subtensor %arg7[%arg2, %arg4] [%22, %24] [1, 1] : tensor<?x?xf32> to tensor<?x?xf32>
+              %25 = tensor.extract_slice %arg7[%arg2, %arg4] [%22, %24] [1, 1] : tensor<?x?xf32> to tensor<?x?xf32>
               %26 = linalg.matmul {__internal_linalg_transform__ = "workgroup_l1_tile", lowering.config = #config1} ins(%18, %20 : tensor<?x?xf32>, tensor<?x?xf32>) outs(%25 : tensor<?x?xf32>) -> tensor<?x?xf32>
-              %27 = subtensor_insert %26 into %arg7[%arg2, %arg4] [%22, %24] [1, 1] : tensor<?x?xf32> into tensor<?x?xf32>
+              %27 = tensor.insert_slice %26 into %arg7[%arg2, %arg4] [%22, %24] [1, 1] : tensor<?x?xf32> into tensor<?x?xf32>
               scf.yield %27 : tensor<?x?xf32>
             }
             scf.yield %15 : tensor<?x?xf32>
diff --git a/iree/compiler/Conversion/LinalgToLLVM/PadLinalgWorkgroupTiles.cpp b/iree/compiler/Conversion/LinalgToLLVM/PadLinalgWorkgroupTiles.cpp
index 4fa3ba0..03981e2 100644
--- a/iree/compiler/Conversion/LinalgToLLVM/PadLinalgWorkgroupTiles.cpp
+++ b/iree/compiler/Conversion/LinalgToLLVM/PadLinalgWorkgroupTiles.cpp
@@ -217,7 +217,7 @@
                                      filledStaticResult.result()});
       SmallVector<OpFoldResult> offsets(2, rewriter.getI64IntegerAttr(0));
       SmallVector<OpFoldResult> strides(2, rewriter.getI64IntegerAttr(1));
-      rewriter.replaceOpWithNewOp<SubTensorOp>(
+      rewriter.replaceOpWithNewOp<tensor::ExtractSliceOp>(
           matmulOp, paddedMatmulOp->getResults()[0], offsets, sizes, strides);
     }
     return success();
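(Illustrative sketch, not part of the patch.) The replacement call above keeps the old SubTensorOp builder shape; only the op class changes. A minimal sketch, where `originalOp`, `paddedResult`, `tileM`, and `tileN` are placeholder names:

// Hedged sketch with placeholder names; tensor::ExtractSliceOp takes
// OpFoldResult (Attribute-or-Value) offsets/sizes/strides and infers the
// result type, just like the removed SubTensorOp builder.
SmallVector<OpFoldResult> offsets(2, rewriter.getI64IntegerAttr(0));
SmallVector<OpFoldResult> strides(2, rewriter.getI64IntegerAttr(1));
SmallVector<OpFoldResult> sizes = {rewriter.getIndexAttr(tileM),
                                   rewriter.getIndexAttr(tileN)};
rewriter.replaceOpWithNewOp<tensor::ExtractSliceOp>(
    originalOp, paddedResult, offsets, sizes, strides);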
diff --git a/iree/compiler/Conversion/LinalgToLLVM/test/pad_linalg_workgroup_tiles.mlir b/iree/compiler/Conversion/LinalgToLLVM/test/pad_linalg_workgroup_tiles.mlir
index 3b2ea0f..94fc2e7 100644
--- a/iree/compiler/Conversion/LinalgToLLVM/test/pad_linalg_workgroup_tiles.mlir
+++ b/iree/compiler/Conversion/LinalgToLLVM/test/pad_linalg_workgroup_tiles.mlir
@@ -30,14 +30,14 @@
         %13 = affine.min affine_map<(d0) -> (-d0 + 5, 64)>(%arg0)
         %14 = affine.min affine_map<(d0) -> (-d0 + 5, 64)>(%arg1)
         %15 = linalg.init_tensor [%13, %14] : tensor<?x?xf32>
-        %16 = linalg.fill(%15, %cst) {__internal_linalg_transform__ = "workgroup", lowering.config = #config0} : tensor<?x?xf32>, f32 -> tensor<?x?xf32> 
+        %16 = linalg.fill(%15, %cst) {__internal_linalg_transform__ = "workgroup", lowering.config = #config0} : tensor<?x?xf32>, f32 -> tensor<?x?xf32>
         %17 = linalg.matmul {__internal_linalg_transform__ = "workgroup", lowering.config = #config1} ins(%8, %10 : tensor<?x3xf32>, tensor<3x?xf32>) outs(%16 : tensor<?x?xf32>) -> tensor<?x?xf32>
         flow.dispatch.tensor.store %17, %2, offsets = [%arg0, %arg1], sizes = [%11, %12], strides = [1, 1] : tensor<?x?xf32> -> !flow.dispatch.tensor<writeonly:5x5xf32>
       }
     }
     return
   }
-    
+
   hal.interface @io attributes {sym_visibility = "private"} {
     hal.interface.binding @s0b0_ro_external, set=0, binding=0, type="StorageBuffer", access="Read"
     hal.interface.binding @s0b1_ro_external, set=0, binding=1, type="StorageBuffer", access="Read"
@@ -62,5 +62,5 @@
 //       CHECK: %[[PADDED_RESULT:.+]] = linalg.init_tensor [8, 8] : tensor<8x8xf32>
 //       CHECK: %[[PADDED_RESULT_0:.+]] = linalg.fill(%[[PADDED_RESULT]], %[[C0]]) : tensor<8x8xf32>
 //       CHECK: %[[MATMUL_RESULT:.+]] = linalg.matmul {{.*}} ins(%[[PADDED_LHS]], %[[PADDED_RHS]] : tensor<8x4xf32>, tensor<4x8xf32>) outs(%[[PADDED_RESULT_0]] : tensor<8x8xf32>) -> tensor<8x8xf32>
-//       CHECK: %[[CLIPED_RESULT:.+]] = subtensor %[[MATMUL_RESULT]][0, 0] [%[[LHS_TILE_SIZE]], %[[RHS_TILE_SIZE]]] [1, 1] : tensor<8x8xf32> to tensor<?x?xf32>
+//       CHECK: %[[CLIPED_RESULT:.+]] = tensor.extract_slice %[[MATMUL_RESULT]][0, 0] [%[[LHS_TILE_SIZE]], %[[RHS_TILE_SIZE]]] [1, 1] : tensor<8x8xf32> to tensor<?x?xf32>
 //       CHECK:  flow.dispatch.tensor.store %[[CLIPED_RESULT]], %[[RESULT]], offsets = [%{{.*}}, %{{.*}}], sizes = [%[[LHS_TILE_SIZE]], %[[RHS_TILE_SIZE]]], strides = [1, 1] : tensor<?x?xf32> -> !flow.dispatch.tensor<writeonly:5x5xf32>
diff --git a/iree/compiler/Conversion/LinalgToLLVM/test/tile_pad_and_vectorize_workgroups.mlir b/iree/compiler/Conversion/LinalgToLLVM/test/tile_pad_and_vectorize_workgroups.mlir
index 1cb8997..c29dc67 100644
--- a/iree/compiler/Conversion/LinalgToLLVM/test/tile_pad_and_vectorize_workgroups.mlir
+++ b/iree/compiler/Conversion/LinalgToLLVM/test/tile_pad_and_vectorize_workgroups.mlir
@@ -33,7 +33,7 @@
         %11 = affine.min #map3(%arg0)
         %12 = affine.min #map4(%arg1)
         %13 = linalg.init_tensor [%11, %12] : tensor<?x?xf32>
-        %14 = linalg.fill(%13, %cst) {__internal_linalg_transform__ = "workgroup", lowering.config = #config0} : tensor<?x?xf32>, f32 -> tensor<?x?xf32> 
+        %14 = linalg.fill(%13, %cst) {__internal_linalg_transform__ = "workgroup", lowering.config = #config0} : tensor<?x?xf32>, f32 -> tensor<?x?xf32>
         %15 = linalg.matmul {__internal_linalg_transform__ = "workgroup", lowering.config = #config1} ins(%8, %10 : tensor<?x383xf32>, tensor<383x?xf32>) outs(%14 : tensor<?x?xf32>) -> tensor<?x?xf32>
         flow.dispatch.tensor.store %15, %2, offsets = [%arg0, %arg1], sizes = [%7, %9], strides = [1, 1] : tensor<?x?xf32> -> !flow.dispatch.tensor<writeonly:383x513xf32>
       }
@@ -58,23 +58,23 @@
 //     CHECK: %[[DST:.+]] = hal.interface.binding.subspan @io::@s0b2_xw_external[%c0] : !flow.dispatch.tensor<writeonly:383x513xf32>
 //     CHECK: %[[LHS_WG_TILE:.+]] = flow.dispatch.tensor.load %[[LHS]]
 //     CHECK: %[[RHS_WG_TILE:.+]] = flow.dispatch.tensor.load %[[RHS]]
-//     CHECK: %[[DST_WG_TILE_INIT:.+]] = linalg.init_tensor 
+//     CHECK: %[[DST_WG_TILE_INIT:.+]] = linalg.init_tensor
 //     CHECK: %[[DST_WG_TILE_INIT_C0:.+]] = linalg.fill(%[[DST_WG_TILE_INIT]], %[[CST]])
 //     CHECK: {{.*}} = scf.for {{.*}} = %[[C0]] to %[[C383]] step %[[C32]] iter_args(%[[DST_WG_TILE_0:.+]] = %[[DST_WG_TILE_INIT_C0]])
 //     CHECK:    {{.*}} = scf.for {{.*}} = %[[C0]] to %[[C513]] step %[[C32]] iter_args(%[[DST_WG_TILE_1:.+]] = %[[DST_WG_TILE_0]])
 //     CHECK:       {{.*}} = scf.for {{.*}} = %[[C0]] to %[[C383]] step %[[C32]] iter_args(%[[DST_WG_TILE_2:.+]] = %[[DST_WG_TILE_1]])
-//     CHECK:           %[[LHS_L1_TILE:.+]] = subtensor %[[LHS_WG_TILE]]
-//     CHECK:           %[[RHS_L1_TILE:.+]] = subtensor %[[RHS_WG_TILE]]
-//     CHECK:           %[[DST_L1_TILE:.+]] = subtensor %[[DST_WG_TILE_2]]
+//     CHECK:           %[[LHS_L1_TILE:.+]] = tensor.extract_slice %[[LHS_WG_TILE]]
+//     CHECK:           %[[RHS_L1_TILE:.+]] = tensor.extract_slice %[[RHS_WG_TILE]]
+//     CHECK:           %[[DST_L1_TILE:.+]] = tensor.extract_slice %[[DST_WG_TILE_2]]
 //     CHECK:           %[[LHS_L1_TILE_PADDED:.+]] = linalg.pad_tensor %[[LHS_L1_TILE]]
 //     CHECK:           %[[RHS_L1_TILE_PADDED:.+]] = linalg.pad_tensor %[[RHS_L1_TILE]]
 //     CHECK:           %[[DST_L1_TILE_PADDED:.+]] = linalg.pad_tensor %[[DST_L1_TILE]]
 //     CHECK:           {{.*}} = scf.for {{.*}} = %[[C0]] to %[[C32]] step %[[C4]] iter_args(%[[DST_VEC_TILE_0:.+]] = %[[DST_L1_TILE_PADDED]])
 //     CHECK:              {{.*}} = scf.for {{.*}} = %[[C0]] to %[[C32]] step %[[C4]] iter_args(%[[DST_VEC_TILE_1:.+]] = %[[DST_VEC_TILE_0]])
 //     CHECK:                {{.*}} = scf.for {{.*}} = %[[C0]] to %[[C32]] step %[[C4]] iter_args(%[[DST_VEC_TILE_2:.+]] = %[[DST_VEC_TILE_1]])
-//     CHECK:                    %[[LHS_VEC_TILE:.+]] = subtensor %[[LHS_L1_TILE_PADDED]]
-//     CHECK:                    %[[RHS_VEC_TILE:.+]] = subtensor %[[RHS_L1_TILE_PADDED]]
-//     CHECK:                    %[[DST_VEC_TILE:.+]] = subtensor %[[DST_VEC_TILE_2]]
+//     CHECK:                    %[[LHS_VEC_TILE:.+]] = tensor.extract_slice %[[LHS_L1_TILE_PADDED]]
+//     CHECK:                    %[[RHS_VEC_TILE:.+]] = tensor.extract_slice %[[RHS_L1_TILE_PADDED]]
+//     CHECK:                    %[[DST_VEC_TILE:.+]] = tensor.extract_slice %[[DST_VEC_TILE_2]]
 //     CHECK:                    %[[LHS_VEC:.+]] = vector.transfer_read %[[LHS_VEC_TILE]]
 //     CHECK:                    %[[RHS_VEC:.+]] = vector.transfer_read %[[RHS_VEC_TILE]]
 //     CHECK:                    %[[DST_VEC:.+]] = vector.transfer_read %[[DST_VEC_TILE]]
diff --git a/iree/compiler/Conversion/Utils/Utils.cpp b/iree/compiler/Conversion/Utils/Utils.cpp
index 8145cfb..e1a1440 100644
--- a/iree/compiler/Conversion/Utils/Utils.cpp
+++ b/iree/compiler/Conversion/Utils/Utils.cpp
@@ -58,7 +58,7 @@
       view = viewOp.getViewSource();
       continue;
     }
-    if (auto subTensorOp = view.getDefiningOp<SubTensorOp>()) {
+    if (auto subTensorOp = view.getDefiningOp<tensor::ExtractSliceOp>()) {
       view = subTensorOp.source();
       continue;
     }
diff --git a/iree/compiler/Dialect/Flow/IR/BUILD b/iree/compiler/Dialect/Flow/IR/BUILD
index 0a263ec..e1cbe54 100644
--- a/iree/compiler/Dialect/Flow/IR/BUILD
+++ b/iree/compiler/Dialect/Flow/IR/BUILD
@@ -86,6 +86,7 @@
         "//iree/compiler/Dialect/Shape/IR:td_files",
         "@llvm-project//mlir:OpBaseTdFiles",
         "@llvm-project//mlir:StdOpsTdFiles",
+        "@llvm-project//mlir:ViewLikeInterfaceTdFiles",
         "@llvm-project//mlir:include/mlir/Interfaces/InferTypeOpInterface.td",
     ],
 )
@@ -110,6 +111,7 @@
         "//iree/compiler/Dialect/Shape/IR:td_files",
         "@llvm-project//mlir:OpBaseTdFiles",
         "@llvm-project//mlir:StdOpsTdFiles",
+        "@llvm-project//mlir:ViewLikeInterfaceTdFiles",
         "@llvm-project//mlir:include/mlir/Interfaces/InferTypeOpInterface.td",
     ],
 )
@@ -135,6 +137,7 @@
         "@llvm-project//mlir:OpBaseTdFiles",
         "@llvm-project//mlir:SideEffectTdFiles",
         "@llvm-project//mlir:StdOpsTdFiles",
+        "@llvm-project//mlir:ViewLikeInterfaceTdFiles",
         "@llvm-project//mlir:include/mlir/Interfaces/InferTypeOpInterface.td",
     ],
 )
@@ -156,6 +159,7 @@
         "@llvm-project//mlir:OpBaseTdFiles",
         "@llvm-project//mlir:SideEffectTdFiles",
         "@llvm-project//mlir:StdOpsTdFiles",
+        "@llvm-project//mlir:ViewLikeInterfaceTdFiles",
         "@llvm-project//mlir:include/mlir/Interfaces/InferTypeOpInterface.td",
     ],
 )
diff --git a/iree/compiler/Dialect/Flow/IR/FlowDialect.cpp b/iree/compiler/Dialect/Flow/IR/FlowDialect.cpp
index 9d703d1..fbe120f 100644
--- a/iree/compiler/Dialect/Flow/IR/FlowDialect.cpp
+++ b/iree/compiler/Dialect/Flow/IR/FlowDialect.cpp
@@ -10,6 +10,7 @@
 #include "iree/compiler/Dialect/Flow/IR/FlowTypes.h"
 #include "llvm/Support/SourceMgr.h"
 #include "mlir/Dialect/StandardOps/IR/Ops.h"
+#include "mlir/Dialect/Tensor/IR/Tensor.h"
 #include "mlir/IR/Attributes.h"
 #include "mlir/IR/DialectImplementation.h"
 #include "mlir/IR/OpImplementation.h"
@@ -69,6 +70,7 @@
 #include "iree/compiler/Dialect/Flow/IR/FlowOps.cpp.inc"
       >();
   context->getOrLoadDialect("shapex");
+  context->getOrLoadDialect<tensor::TensorDialect>();
 }
 
 Operation *FlowDialect::materializeConstant(OpBuilder &builder, Attribute value,
diff --git a/iree/compiler/Dialect/Flow/IR/FlowOpFolders.cpp b/iree/compiler/Dialect/Flow/IR/FlowOpFolders.cpp
index 8456560..125528a 100644
--- a/iree/compiler/Dialect/Flow/IR/FlowOpFolders.cpp
+++ b/iree/compiler/Dialect/Flow/IR/FlowOpFolders.cpp
@@ -446,7 +446,7 @@
         loadOp.strides().empty()) {
       return failure();
     }
-    rewriter.replaceOpWithNewOp<SubTensorOp>(
+    rewriter.replaceOpWithNewOp<tensor::ExtractSliceOp>(
         loadOp, loadOp.source(), loadOp.getMixedOffsets(),
         loadOp.getMixedSizes(), loadOp.getMixedStrides());
     return success();
diff --git a/iree/compiler/Dialect/Flow/Transforms/ConvertToFlowTensorOps.cpp b/iree/compiler/Dialect/Flow/Transforms/ConvertToFlowTensorOps.cpp
index 6654559..24bf115 100644
--- a/iree/compiler/Dialect/Flow/Transforms/ConvertToFlowTensorOps.cpp
+++ b/iree/compiler/Dialect/Flow/Transforms/ConvertToFlowTensorOps.cpp
@@ -12,6 +12,7 @@
 #include "mlir/Dialect/Linalg/IR/LinalgOps.h"
 #include "mlir/Dialect/MemRef/IR/MemRef.h"
 #include "mlir/Dialect/StandardOps/IR/Ops.h"
+#include "mlir/Dialect/Tensor/IR/Tensor.h"
 #include "mlir/IR/PatternMatch.h"
 #include "mlir/Pass/Pass.h"
 #include "mlir/Pass/PassManager.h"
@@ -160,10 +161,10 @@
 
-/// Convert subtensor insert operation flow.tensor.update where possible.
+/// Convert tensor.insert_slice operations to flow.tensor.update where possible.
 struct SubTensorInsertToTensorUpdate
-    : public OpRewritePattern<SubTensorInsertOp> {
-  using OpRewritePattern<SubTensorInsertOp>::OpRewritePattern;
+    : public OpRewritePattern<tensor::InsertSliceOp> {
+  using OpRewritePattern<tensor::InsertSliceOp>::OpRewritePattern;
 
-  LogicalResult matchAndRewrite(SubTensorInsertOp insertOp,
+  LogicalResult matchAndRewrite(tensor::InsertSliceOp insertOp,
                                 PatternRewriter &rewriter) const override {
     if (insertOp->getParentOfType<Flow::DispatchWorkgroupsOp>()) {
       return failure();
@@ -203,26 +204,27 @@
 };
 
-/// Convert subtensor operation to flow.tensor.slice where possible.
+/// Convert tensor.extract_slice operations to flow.tensor.slice where possible.
-struct SubTensorToTensorSlice : public OpRewritePattern<SubTensorOp> {
-  using OpRewritePattern<SubTensorOp>::OpRewritePattern;
+struct SubTensorToTensorSlice
+    : public OpRewritePattern<tensor::ExtractSliceOp> {
+  using OpRewritePattern<tensor::ExtractSliceOp>::OpRewritePattern;
 
-  LogicalResult matchAndRewrite(SubTensorOp subTensorOp,
+  LogicalResult matchAndRewrite(tensor::ExtractSliceOp sliceOp,
                                 PatternRewriter &rewriter) const override {
-    if (subTensorOp->getParentOfType<Flow::DispatchWorkgroupsOp>()) {
+    if (sliceOp->getParentOfType<Flow::DispatchWorkgroupsOp>()) {
       return failure();
     }
-    SmallVector<OpFoldResult, 4> offsets = subTensorOp.getMixedOffsets();
-    SmallVector<OpFoldResult, 4> sizes = subTensorOp.getMixedSizes();
-    SmallVector<OpFoldResult, 4> strides = subTensorOp.getMixedStrides();
-    ArrayRef<int64_t> srcShape = subTensorOp.getSourceType().getShape();
+    SmallVector<OpFoldResult, 4> offsets = sliceOp.getMixedOffsets();
+    SmallVector<OpFoldResult, 4> sizes = sliceOp.getMixedSizes();
+    SmallVector<OpFoldResult, 4> strides = sliceOp.getMixedStrides();
+    ArrayRef<int64_t> srcShape = sliceOp.getSourceType().getShape();
     if (!isOffsetSizeAndStrideMappableToFlow(offsets, sizes, strides,
                                              srcShape)) {
       return failure();
     }
-    Location loc = subTensorOp.getLoc();
+    Location loc = sliceOp.getLoc();
 
-    ShapedType sourceType = subTensorOp.getSourceType();
-    ShapedType resultType = subTensorOp.getType();
+    ShapedType sourceType = sliceOp.getSourceType();
+    ShapedType resultType = sliceOp.getType();
 
     // Handle rank reduced version.
     if (resultType.getRank() < sourceType.getRank()) {
@@ -235,17 +237,17 @@
     auto offsetVals = getAsValues(rewriter, loc, offsets);
     auto sizeVals = getAsValues(rewriter, loc, sizes);
     auto sourceDynamicDims =
-        getDynamicDimValues(rewriter, loc, subTensorOp.source());
+        getDynamicDimValues(rewriter, loc, sliceOp.source());
     auto resultDynamicDims = getDynamicValues(sizes);
     Value replacement = rewriter.create<TensorSliceOp>(
-        loc, resultType, subTensorOp.source(), sourceDynamicDims, offsetVals,
+        loc, resultType, sliceOp.source(), sourceDynamicDims, offsetVals,
         sizeVals, resultDynamicDims);
-    if (resultType.getRank() > subTensorOp.getType().getRank()) {
+    if (resultType.getRank() > sliceOp.getType().getRank()) {
       replacement = rewriter.create<IREE::Flow::TensorReshapeOp>(
-          loc, subTensorOp.getType(), replacement, resultDynamicDims,
+          loc, sliceOp.getType(), replacement, resultDynamicDims,
           resultDynamicDims);
     }
-    rewriter.replaceOp(subTensorOp, replacement);
+    rewriter.replaceOp(sliceOp, replacement);
     return success();
   }
 };
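(Illustrative sketch, not part of the patch.) For orientation, the general skeleton that patterns such as SubTensorToTensorSlice now follow, anchored on the renamed op; the class name and body below are placeholders:

// Hedged sketch: a skeletal OpRewritePattern over tensor::ExtractSliceOp that
// reads the mixed static/dynamic offsets, sizes, and strides the same way the
// old SubTensorOp pattern did. The class name and body are illustrative only.
struct ExtractSliceSketchPattern
    : public OpRewritePattern<tensor::ExtractSliceOp> {
  using OpRewritePattern<tensor::ExtractSliceOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(tensor::ExtractSliceOp sliceOp,
                                PatternRewriter &rewriter) const override {
    SmallVector<OpFoldResult, 4> offsets = sliceOp.getMixedOffsets();
    SmallVector<OpFoldResult, 4> sizes = sliceOp.getMixedSizes();
    SmallVector<OpFoldResult, 4> strides = sliceOp.getMixedStrides();
    // A real pattern would check whether these map onto a flow op and then
    // rewrite using sliceOp.source(); this sketch simply bails out.
    (void)offsets; (void)sizes; (void)strides;
    return rewriter.notifyMatchFailure(sliceOp, "sketch only");
  }
};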
diff --git a/iree/compiler/Dialect/Flow/Transforms/DestructiveUpdateUtils.cpp b/iree/compiler/Dialect/Flow/Transforms/DestructiveUpdateUtils.cpp
index a415d25..a90fbea 100644
--- a/iree/compiler/Dialect/Flow/Transforms/DestructiveUpdateUtils.cpp
+++ b/iree/compiler/Dialect/Flow/Transforms/DestructiveUpdateUtils.cpp
@@ -20,6 +20,7 @@
 #include "mlir/Dialect/MemRef/IR/MemRef.h"
 #include "mlir/Dialect/SCF/SCF.h"
 #include "mlir/Dialect/StandardOps/IR/Ops.h"
+#include "mlir/Dialect/Tensor/IR/Tensor.h"
 #include "mlir/IR/Dominance.h"
 #include "mlir/IR/PatternMatch.h"
 #include "mlir/Pass/Pass.h"
@@ -56,14 +57,15 @@
 // TODO(nicolasvasilache): Use some interface instead of op names directly.
 static bool hasDestructiveUpdateSubTensorUses(
     Value v, SpecialTerminatorOpCapture &capture) {
-  SmallVector<SubTensorOp, 4> reads;
-  SmallVector<SubTensorInsertOp, 4> writes;
+  SmallVector<tensor::ExtractSliceOp, 4> reads;
+  SmallVector<tensor::InsertSliceOp, 4> writes;
   for (auto &u : v.getUses()) {
-    if (auto subTensorOp = dyn_cast<SubTensorOp>(u.getOwner())) {
+    if (auto subTensorOp = dyn_cast<tensor::ExtractSliceOp>(u.getOwner())) {
       reads.push_back(subTensorOp);
       continue;
     }
-    if (auto subTensorInsertOp = dyn_cast<SubTensorInsertOp>(u.getOwner())) {
+    if (auto subTensorInsertOp =
+            dyn_cast<tensor::InsertSliceOp>(u.getOwner())) {
       writes.push_back(subTensorInsertOp);
       continue;
     }
@@ -195,7 +197,8 @@
 
-/// Convert `subtensor %t [offsets][sizes][strides] -> %st` to a
-/// flow.dispatch.tensor.load.
+/// Convert `tensor.extract_slice %t [offsets][sizes][strides] -> %st` to a
+/// flow.dispatch.tensor.load.
-static LogicalResult propagateSubTensorOp(OpBuilder &b, SubTensorOp op) {
+static LogicalResult propagateSubTensorOp(OpBuilder &b,
+                                          tensor::ExtractSliceOp op) {
   OpBuilder::InsertionGuard g(b);
   b.setInsertionPoint(op);
   auto loadOp = op.source().getDefiningOp<IREE::Flow::DispatchTensorLoadOp>();
@@ -224,7 +227,7 @@
 }
 
 static LogicalResult rewriteSubTensorInsertInPlace(OpBuilder &b,
-                                                   SubTensorInsertOp op,
+                                                   tensor::InsertSliceOp op,
                                                    Value target) {
   LLVM_DEBUG(llvm::dbgs() << "RewriteSubTensorInsertInPlace: "
                           << *(op.getOperation()) << "\n");
@@ -374,12 +377,14 @@
 
   // Try to rewrite inplace.
   if (failed(rewriteSubTensorInsertInPlace(
-          b, cast<SubTensorInsertOp>(capture.rootDestructiveUpdate), target))) {
+          b, cast<tensor::InsertSliceOp>(capture.rootDestructiveUpdate),
+          target))) {
     return failure();
   }
 
   if (scf::ForOp loopOp = dyn_cast<scf::ForOp>(outermostProducingOp))
-    loopOp.walk([&](SubTensorOp op) { (void)propagateSubTensorOp(b, op); });
+    loopOp.walk(
+        [&](tensor::ExtractSliceOp op) { (void)propagateSubTensorOp(b, op); });
 
   return success();
 }
@@ -420,7 +425,7 @@
                     capture.initValue = op.value();
                     Value sourceValue =
                         isADestructiveUpdatePattern(capture.initValue, capture);
-                    if (!sourceValue || !isa_and_nonnull<SubTensorInsertOp>(
+                    if (!sourceValue || !isa_and_nonnull<tensor::InsertSliceOp>(
                                             capture.rootDestructiveUpdate))
                       return WalkResult::advance();
                     if (failed(rewriteDestructiveUpdateInPlace(b, capture,
diff --git a/iree/compiler/Dialect/Flow/Transforms/DispatchLinalgOnTensors.cpp b/iree/compiler/Dialect/Flow/Transforms/DispatchLinalgOnTensors.cpp
index b74fff5..ff98366 100644
--- a/iree/compiler/Dialect/Flow/Transforms/DispatchLinalgOnTensors.cpp
+++ b/iree/compiler/Dialect/Flow/Transforms/DispatchLinalgOnTensors.cpp
@@ -20,6 +20,7 @@
 #include "mlir/Dialect/MemRef/IR/MemRef.h"
 #include "mlir/Dialect/SCF/SCF.h"
 #include "mlir/Dialect/StandardOps/IR/Ops.h"
+#include "mlir/Dialect/Tensor/IR/Tensor.h"
 #include "mlir/IR/Block.h"
 #include "mlir/IR/BlockAndValueMapping.h"
 #include "mlir/IR/Builders.h"
@@ -253,7 +254,7 @@
   // Linalg ops are marked dispatchable.
   if ((op->getDialect() !=
        op->getContext()->getLoadedDialect<linalg::LinalgDialect>()) &&
-      !isa<SubTensorOp, SubTensorInsertOp>(op)) {
+      !isa<tensor::ExtractSliceOp, tensor::InsertSliceOp>(op)) {
     return false;
   }
   return !isAlwaysClonedIntoDispatchOp(op);
@@ -261,8 +262,8 @@
 
 static bool isAlwaysFusedIntoDispatchOp(Operation *op) {
   return isDispatchableOp(op) &&
-         (isa<linalg::TensorCollapseShapeOp, SubTensorOp>(op) ||
-          isa<linalg::TensorExpandShapeOp, SubTensorOp>(op));
+         (isa<linalg::TensorCollapseShapeOp, tensor::ExtractSliceOp>(op) ||
+          isa<linalg::TensorExpandShapeOp, tensor::ExtractSliceOp>(op));
 }
 
 //===----------------------------------------------------------------------===//
@@ -535,7 +536,7 @@
 
     // TODO(antiagainst): use TiedOpInterface here instead of hardcoding ops
     // when it's available in MLIR core in some form.
-    if (auto insertOp = dyn_cast_or_null<SubTensorInsertOp>(tieOp)) {
+    if (auto insertOp = dyn_cast_or_null<tensor::InsertSliceOp>(tieOp)) {
       auto loadOp =
           insertOp.dest().getDefiningOp<IREE::Flow::DispatchTensorLoadOp>();
       if (!loadOp) return nullptr;
@@ -1224,7 +1225,7 @@
     // to guarantee type match during transformation. Later in destructive
-    // update subtensor_insert ops will be turned into flow dispatch output
+    // update tensor.insert_slice ops will be turned into flow dispatch output
     // store ops.
-    SubTensorInsertOp::getCanonicalizationPatterns(patterns, context);
+    tensor::InsertSliceOp::getCanonicalizationPatterns(patterns, context);
     (void)applyPatternsAndFoldGreedily(funcOp, std::move(patterns));
   }
 
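(Illustrative sketch, not part of the patch.) Usage note: only the op class supplying the canonicalization patterns changes; the greedy-rewrite hookup stays the same. A self-contained version, assuming `context` and `funcOp` as in the surrounding pass (the `patterns` declaration is added here for completeness):

// Hedged sketch; the declaration of `patterns` is an assumption for
// self-containment. The patch itself only swaps SubTensorInsertOp for
// tensor::InsertSliceOp as the provider of the canonicalization patterns.
RewritePatternSet patterns(context);
tensor::InsertSliceOp::getCanonicalizationPatterns(patterns, context);
(void)applyPatternsAndFoldGreedily(funcOp, std::move(patterns));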
diff --git a/iree/compiler/Dialect/Flow/Transforms/PadLinalgOps.cpp b/iree/compiler/Dialect/Flow/Transforms/PadLinalgOps.cpp
index f35380c..40c76b2 100644
--- a/iree/compiler/Dialect/Flow/Transforms/PadLinalgOps.cpp
+++ b/iree/compiler/Dialect/Flow/Transforms/PadLinalgOps.cpp
@@ -118,7 +118,7 @@
       SmallVector<OpFoldResult> strides(2, rewriter.getI64IntegerAttr(1));
       SmallVector<OpFoldResult> sizes = {rewriter.getIndexAttr(M),
                                          rewriter.getIndexAttr(N)};
-      rewriter.replaceOpWithNewOp<SubTensorOp>(
+      rewriter.replaceOpWithNewOp<tensor::ExtractSliceOp>(
           matmulOp, paddedMatmulOp->getResults()[0], offsets, sizes, strides);
     }
 
diff --git a/iree/compiler/Dialect/Flow/Transforms/PadTensorToSubTensorInsert.cpp b/iree/compiler/Dialect/Flow/Transforms/PadTensorToSubTensorInsert.cpp
index 8ded13c..b64621d 100644
--- a/iree/compiler/Dialect/Flow/Transforms/PadTensorToSubTensorInsert.cpp
+++ b/iree/compiler/Dialect/Flow/Transforms/PadTensorToSubTensorInsert.cpp
@@ -17,6 +17,7 @@
 #include "mlir/Dialect/Linalg/IR/LinalgOps.h"
 #include "mlir/Dialect/MemRef/IR/MemRef.h"
 #include "mlir/Dialect/StandardOps/IR/Ops.h"
+#include "mlir/Dialect/Tensor/IR/Tensor.h"
 #include "mlir/IR/Builders.h"
 #include "mlir/Pass/Pass.h"
 #include "mlir/Transforms/GreedyPatternRewriteDriver.h"
@@ -89,7 +90,7 @@
     Value fill =
         rewriter.create<linalg::FillOp>(loc, initTensor, yieldVal).getResult(0);
     SmallVector<OpFoldResult> strides(rank, rewriter.getI64IntegerAttr(1));
-    rewriter.replaceOpWithNewOp<SubTensorInsertOp>(
+    rewriter.replaceOpWithNewOp<tensor::InsertSliceOp>(
         padTensorOp, source, fill, lowPad, sourceShape, strides);
     return success();
   }
@@ -98,7 +99,8 @@
 struct PadTensorToSubTensorInsertPass
     : public PadTensorToSubTensorInsertBase<PadTensorToSubTensorInsertPass> {
   void getDependentDialects(DialectRegistry &registry) const override {
-    registry.insert<linalg::LinalgDialect, StandardOpsDialect>();
+    registry.insert<linalg::LinalgDialect, memref::MemRefDialect,
+                    StandardOpsDialect>();
   }
 
   void runOnOperation() override {
diff --git a/iree/compiler/Dialect/Flow/Transforms/test/convert_to_flow_tensor_ops.mlir b/iree/compiler/Dialect/Flow/Transforms/test/convert_to_flow_tensor_ops.mlir
index 3360b10..53c4a24 100644
--- a/iree/compiler/Dialect/Flow/Transforms/test/convert_to_flow_tensor_ops.mlir
+++ b/iree/compiler/Dialect/Flow/Transforms/test/convert_to_flow_tensor_ops.mlir
@@ -1,7 +1,7 @@
 // RUN: iree-opt -iree-flow-convert-to-flow-tensor-ops-pass -canonicalize -cse -split-input-file %s | IreeFileCheck %s
 
 func @subtensor1(%arg0 : tensor<5x24x48xf32>) -> tensor<4xf32> {
-  %0 = subtensor %arg0[2, 3, 4] [1, 1, 4] [1, 1, 1]
+  %0 = tensor.extract_slice %arg0[2, 3, 4] [1, 1, 4] [1, 1, 1]
       : tensor<5x24x48xf32> to tensor<4xf32>
   return %0 : tensor<4xf32>
 }
@@ -18,7 +18,7 @@
 // -----
 
 func @subtensor2(%arg0 : tensor<5x24x48xf32>) -> tensor<2x48xf32> {
-  %0 = subtensor %arg0[2, 3, 0] [1, 2, 48] [1, 1, 1]
+  %0 = tensor.extract_slice %arg0[2, 3, 0] [1, 2, 48] [1, 1, 1]
       : tensor<5x24x48xf32> to tensor<2x48xf32>
   return %0 : tensor<2x48xf32>
 }
@@ -36,47 +36,47 @@
 // -----
 
 func @subtensor3(%arg0 : tensor<5x24x48xf32>) -> tensor<2x24xf32> {
-  %0 = subtensor %arg0[2, 3, 0] [1, 2, 24] [1, 1, 1]
+  %0 = tensor.extract_slice %arg0[2, 3, 0] [1, 2, 24] [1, 1, 1]
       : tensor<5x24x48xf32> to tensor<2x24xf32>
   return %0 : tensor<2x24xf32>
 }
 // CHECK-LABEL: func @subtensor3
-//       CHECK:   subtensor
+//       CHECK:   tensor.extract_slice
 
 // -----
 
 func @subtensor4(%arg0 : tensor<5x24x48xf32>, %arg1 : index) -> tensor<2x24xf32> {
-  %0 = subtensor %arg0[2, 3, 0] [1, 2, 24] [1, %arg1, 1]
+  %0 = tensor.extract_slice %arg0[2, 3, 0] [1, 2, 24] [1, %arg1, 1]
       : tensor<5x24x48xf32> to tensor<2x24xf32>
   return %0 : tensor<2x24xf32>
 }
 // CHECK-LABEL: func @subtensor4
-//       CHECK:   subtensor
+//       CHECK:   tensor.extract_slice
 
 // -----
 
 func @subtensor5(%arg0 : tensor<5x24x48xf32>, %arg1 : index) -> tensor<2x48xf32> {
-  %0 = subtensor %arg0[2, %arg1, 0] [1, 2, 48] [1, 1, 1]
+  %0 = tensor.extract_slice %arg0[2, %arg1, 0] [1, 2, 48] [1, 1, 1]
       : tensor<5x24x48xf32> to tensor<2x48xf32>
   return %0 : tensor<2x48xf32>
 }
 // CHECK-LABEL: func @subtensor5
-//       CHECK:   subtensor
+//       CHECK:   tensor.extract_slice
 
 // -----
 
 func @subtensor6(%arg0 : tensor<5x24x48xf32>, %arg1 : index) -> tensor<?x48xf32> {
-  %0 = subtensor %arg0[2, 3, 0] [1, %arg1, 48] [1, 1, 1]
+  %0 = tensor.extract_slice %arg0[2, 3, 0] [1, %arg1, 48] [1, 1, 1]
       : tensor<5x24x48xf32> to tensor<?x48xf32>
   return %0 : tensor<?x48xf32>
 }
 // CHECK-LABEL: func @subtensor6
-//       CHECK:   subtensor
+//       CHECK:   tensor.extract_slice
 
 // -----
 
 func @subtensor7(%arg0 : tensor<5x?x48xf32>, %arg1 : index) -> tensor<2x48xf32> {
-  %0 = subtensor %arg0[2, 3, 0] [1, 2, 48] [1, 1, 1]
+  %0 = tensor.extract_slice %arg0[2, 3, 0] [1, 2, 48] [1, 1, 1]
       : tensor<5x?x48xf32> to tensor<2x48xf32>
   return %0 : tensor<2x48xf32>
 }
@@ -96,7 +96,7 @@
 // -----
 
 func @rank_reducing_subtensor(%arg0: tensor<?x513xi32>) -> tensor<513xi32> {
-  %0 = subtensor %arg0[4, 0] [1, 513] [1, 1] : tensor<?x513xi32> to tensor<513xi32>
+  %0 = tensor.extract_slice %arg0[4, 0] [1, 513] [1, 1] : tensor<?x513xi32> to tensor<513xi32>
   return %0 : tensor<513xi32>
 }
 // CHECK-LABEL: func @rank_reducing_subtensor
@@ -136,7 +136,7 @@
     (%arg0 : tensor<?x24x48xf32>, %arg1 : tensor<1x4x48xf32>) ->
     tensor<?x24x48xf32> {
   %c0 = constant 0 : index
-  %0 = subtensor_insert %arg1 into %arg0[4, 2, 0] [1, 4, 48] [1, 1, 1] :
+  %0 = tensor.insert_slice %arg1 into %arg0[4, 2, 0] [1, 4, 48] [1, 1, 1] :
       tensor<1x4x48xf32> into tensor<?x24x48xf32>
   return %0 : tensor<?x24x48xf32>
 }
@@ -156,7 +156,7 @@
     (%arg0 : tensor<?x24x48xf32>, %arg1 : tensor<4x48xf32>) ->
     tensor<?x24x48xf32> {
   %c0 = constant 0 : index
-  %0 = subtensor_insert %arg1 into %arg0[4, 2, 0] [1, 4, 48] [1, 1, 1] :
+  %0 = tensor.insert_slice %arg1 into %arg0[4, 2, 0] [1, 4, 48] [1, 1, 1] :
       tensor<4x48xf32> into tensor<?x24x48xf32>
   return %0 : tensor<?x24x48xf32>
 }
@@ -175,7 +175,7 @@
 
 func @rank_reducing_subtensor_insert_trailing_unit_dims
    (%arg0 : tensor<49x20xf32>, %arg1 : tensor<1x50x20x1xf32>) -> tensor<1x50x20x1xf32> {
-  %0 = subtensor_insert %arg0 into %arg1[0, 1, 0, 0] [1, 49, 20, 1] [1, 1, 1, 1] : tensor<49x20xf32> into tensor<1x50x20x1xf32>
+  %0 = tensor.insert_slice %arg0 into %arg1[0, 1, 0, 0] [1, 49, 20, 1] [1, 1, 1, 1] : tensor<49x20xf32> into tensor<1x50x20x1xf32>
   return %0 : tensor<1x50x20x1xf32>
 }
 // CHECK-LABEL: func @rank_reducing_subtensor_insert_trailing_unit_dims
@@ -188,7 +188,7 @@
 
 func @rank_reducing_subtensor_trailing_unit_dims
    (%arg0 : tensor<1x50x20x1xf32>) -> tensor<49x20xf32> {
-  %0 = subtensor %arg0[0, 1, 0, 0] [1, 49, 20, 1] [1, 1, 1, 1] : tensor<1x50x20x1xf32> to tensor<49x20xf32>
+  %0 = tensor.extract_slice %arg0[0, 1, 0, 0] [1, 49, 20, 1] [1, 1, 1, 1] : tensor<1x50x20x1xf32> to tensor<49x20xf32>
   return %0 : tensor<49x20xf32>
 }
 // CHECK-LABEL: func @rank_reducing_subtensor_trailing_unit_dims
diff --git a/iree/compiler/Dialect/Flow/Transforms/test/dispatch_linalg_on_tensors.mlir b/iree/compiler/Dialect/Flow/Transforms/test/dispatch_linalg_on_tensors.mlir
index 33bc432..d23bfde 100644
--- a/iree/compiler/Dialect/Flow/Transforms/test/dispatch_linalg_on_tensors.mlir
+++ b/iree/compiler/Dialect/Flow/Transforms/test/dispatch_linalg_on_tensors.mlir
@@ -483,7 +483,7 @@
   %cst = constant 0.000000e+00 : f32
   %0 = linalg.init_tensor [1, 225, 225, 3] : tensor<1x225x225x3xf32>
   %1 = linalg.fill(%0, %cst) : tensor<1x225x225x3xf32>, f32 -> tensor<1x225x225x3xf32>
-  %2 = subtensor_insert %arg0 into %1[0, 0, 0, 0] [1, 224, 224, 3] [1, 1, 1, 1] : tensor<1x224x224x3xf32> into tensor<1x225x225x3xf32>
+  %2 = tensor.insert_slice %arg0 into %1[0, 0, 0, 0] [1, 224, 224, 3] [1, 1, 1, 1] : tensor<1x224x224x3xf32> into tensor<1x225x225x3xf32>
   return %2 : tensor<1x225x225x3xf32>
 }
 
@@ -501,7 +501,7 @@
 // CHECK-NEXT:       (%[[SRC:.+]]: !flow.dispatch.tensor<readonly:1x224x224x3xf32>, %[[DST:.+]]: !flow.dispatch.tensor<readwrite:1x225x225x3xf32>) {
 // CHECK-NEXT:     %[[SRC_TENSOR:.+]] = flow.dispatch.tensor.load %[[SRC]], {{.*}} : !flow.dispatch.tensor<readonly:1x224x224x3xf32> -> tensor<1x224x224x3xf32>
 // CHECK-NEXT:     %[[DST_TENSOR:.+]] = flow.dispatch.tensor.load %[[DST]], {{.*}} : !flow.dispatch.tensor<readwrite:1x225x225x3xf32> -> tensor<1x225x225x3xf32>
-// CHECK-NEXT:     %[[INSERT:.+]] = subtensor_insert %[[SRC_TENSOR]] into %[[DST_TENSOR]][0, 0, 0, 0] [1, 224, 224, 3] [1, 1, 1, 1]
+// CHECK-NEXT:     %[[INSERT:.+]] = tensor.insert_slice %[[SRC_TENSOR]] into %[[DST_TENSOR]][0, 0, 0, 0] [1, 224, 224, 3] [1, 1, 1, 1]
 // CHECK-NEXT:     flow.dispatch.tensor.store %[[INSERT]], %[[DST]], {{.*}} : tensor<1x225x225x3xf32> -> !flow.dispatch.tensor<readwrite:1x225x225x3xf32>
 // CHECK-NEXT:     flow.return
 //
@@ -553,12 +553,12 @@
     %arg0: tensor<?xf32>, %arg1 : tensor<1x?xf32>, %arg2 : tensor<i32>,
     %arg3 : index) -> tensor<?xf32> {
   %0 = linalg.tensor_expand_shape %arg0 [[0, 1]] : tensor<?xf32> into tensor<1x?xf32>
-  %1 = subtensor %0[0, 20] [1, %arg3] [1, 1] : tensor<1x?xf32> to tensor<1x?xf32>
+  %1 = tensor.extract_slice %0[0, 20] [1, %arg3] [1, 1] : tensor<1x?xf32> to tensor<1x?xf32>
   %2 = linalg.tensor_collapse_shape %1 [[0, 1]] : tensor<1x?xf32> into tensor<?xf32>
   %3 = linalg.tensor_collapse_shape %arg1 [[0, 1]] : tensor<1x?xf32> into tensor<?xf32>
-  %4 = subtensor %0[0, 10] [1, %arg3] [1, 1] : tensor<1x?xf32> to tensor<1x?xf32>
+  %4 = tensor.extract_slice %0[0, 10] [1, %arg3] [1, 1] : tensor<1x?xf32> to tensor<1x?xf32>
   %5 = linalg.tensor_collapse_shape %4 [[0, 1]] : tensor<1x?xf32> into tensor<?xf32>
-  %6 = subtensor %0[0, 0] [1, %arg3] [1, 1] : tensor<1x?xf32> to tensor<1x?xf32>
+  %6 = tensor.extract_slice %0[0, 0] [1, %arg3] [1, 1] : tensor<1x?xf32> to tensor<1x?xf32>
   %7 = linalg.tensor_collapse_shape %6 [[0, 1]] : tensor<1x?xf32> into tensor<?xf32>
   %8 = linalg.init_tensor [%arg3] : tensor<?xf32>
   %9 = linalg.generic {
@@ -579,7 +579,7 @@
 }
 // CHECK-LABEL: func @inline_dag_1
 //   CHECK-NOT:   linalg.
-//   CHECK-NOT:   subtensor
+//   CHECK-NOT:   tensor.extract_slice
 //       CHECK:   flow.dispatch.workgroups
 //  CHECK-NEXT:     %[[ARG4:[a-zA-Z0-9_]+]]: !flow.dispatch.tensor<readonly:1x?xf32>
 //  CHECK-SAME:     %[[ARG5:[a-zA-Z0-9_]+]]: !flow.dispatch.tensor<readonly:?xf32>
@@ -591,9 +591,9 @@
 //       CHECK:     %[[LEAF3:.+]] = flow.dispatch.tensor.load %[[ARG6]], {{.*}}
 //       CHECK:     %[[OP1:.+]] = linalg.tensor_expand_shape %[[LEAF2]]
 //       CHECK:     %[[OP2:.+]] = linalg.tensor_collapse_shape %[[LEAF1]]
-//       CHECK:     %[[OP3:.+]] = subtensor %[[OP1]][0, 0]
-//       CHECK:     %[[OP4:.+]] = subtensor %[[OP1]][0, 10]
-//       CHECK:     %[[OP5:.+]] = subtensor %[[OP1]][0, 20]
+//       CHECK:     %[[OP3:.+]] = tensor.extract_slice %[[OP1]][0, 0]
+//       CHECK:     %[[OP4:.+]] = tensor.extract_slice %[[OP1]][0, 10]
+//       CHECK:     %[[OP5:.+]] = tensor.extract_slice %[[OP1]][0, 20]
 //       CHECK:     %[[OP6:.+]] = linalg.tensor_collapse_shape %[[OP3]]
 //       CHECK:     %[[OP7:.+]] = linalg.tensor_collapse_shape %[[OP4]]
 //       CHECK:     %[[OP8:.+]] = linalg.tensor_collapse_shape %[[OP5]]
@@ -606,14 +606,14 @@
     %arg0: tensor<?xf32>, %arg1 : tensor<1x?xf32>, %arg2 : tensor<i32>,
     %arg3 : index) -> tensor<?xf32> {
   %0 = linalg.tensor_expand_shape %arg0 [[0, 1]] : tensor<?xf32> into tensor<1x?xf32>
-  %1 = subtensor %0[0, 20] [1, %arg3] [1, 1] : tensor<1x?xf32> to tensor<1x?xf32>
+  %1 = tensor.extract_slice %0[0, 20] [1, %arg3] [1, 1] : tensor<1x?xf32> to tensor<1x?xf32>
   %2 = linalg.tensor_collapse_shape %arg1 [[0, 1]] : tensor<1x?xf32> into tensor<?xf32>
   br ^bb1
 ^bb1:
   %3 = linalg.tensor_collapse_shape %1 [[0, 1]] : tensor<1x?xf32> into tensor<?xf32>
-  %4 = subtensor %0[0, 10] [1, %arg3] [1, 1] : tensor<1x?xf32> to tensor<1x?xf32>
+  %4 = tensor.extract_slice %0[0, 10] [1, %arg3] [1, 1] : tensor<1x?xf32> to tensor<1x?xf32>
   %5 = linalg.tensor_collapse_shape %4 [[0, 1]] : tensor<1x?xf32> into tensor<?xf32>
-  %6 = subtensor %0[0, 0] [1, %arg3] [1, 1] : tensor<1x?xf32> to tensor<1x?xf32>
+  %6 = tensor.extract_slice %0[0, 0] [1, %arg3] [1, 1] : tensor<1x?xf32> to tensor<1x?xf32>
   %7 = linalg.tensor_collapse_shape %6 [[0, 1]] : tensor<1x?xf32> into tensor<?xf32>
   %8 = linalg.init_tensor [%arg3] : tensor<?xf32>
   %9 = linalg.generic {
@@ -646,9 +646,9 @@
 //       CHECK:     %[[LEAF3:.+]] = flow.dispatch.tensor.load %[[ARG6]], {{.*}}
 //       CHECK:     %[[OP1:.+]] = linalg.tensor_expand_shape %[[LEAF2]]
 //       CHECK:     %[[OP2:.+]] = linalg.tensor_collapse_shape %[[LEAF1]]
-//       CHECK:     %[[OP3:.+]] = subtensor %[[OP1]][0, 0]
-//       CHECK:     %[[OP4:.+]] = subtensor %[[OP1]][0, 10]
-//       CHECK:     %[[OP5:.+]] = subtensor %[[OP1]][0, 20]
+//       CHECK:     %[[OP3:.+]] = tensor.extract_slice %[[OP1]][0, 0]
+//       CHECK:     %[[OP4:.+]] = tensor.extract_slice %[[OP1]][0, 10]
+//       CHECK:     %[[OP5:.+]] = tensor.extract_slice %[[OP1]][0, 20]
 //       CHECK:     %[[OP6:.+]] = linalg.tensor_collapse_shape %[[OP3]]
 //       CHECK:     %[[OP7:.+]] = linalg.tensor_collapse_shape %[[OP4]]
 //       CHECK:     %[[OP8:.+]] = linalg.tensor_collapse_shape %[[OP5]]
@@ -667,7 +667,7 @@
   %251 = cmpi sgt, %250, %c0_i32 : i32
   %252 = select %251, %250, %c0_i32 : i32
   %253 = index_cast %252 : i32 to index
-  %254 = subtensor %245[%253] [9] [1] : tensor<18xi32> to tensor<9xi32>
+  %254 = tensor.extract_slice %245[%253] [9] [1] : tensor<18xi32> to tensor<9xi32>
   %255 = linalg.init_tensor [9] : tensor<9xi1>
   %256 = linalg.generic {
       indexing_maps = [affine_map<(d0) -> (d0)>, affine_map<(d0) -> (d0)>],
@@ -699,7 +699,7 @@
 //   CHECK-DAG:     %[[CMP2:.+]] = cmpi sgt, %[[SELECT1]], %[[C0]]
 //   CHECK-DAG:     %[[SELECT2:.+]] = select %[[CMP2]], %[[SELECT1]], %[[C0]]
 //   CHECK-DAG:     %[[INDEX_CAST:.+]] = index_cast %[[SELECT2]]
-//   CHECK-DAG:     subtensor %[[ARG3V]][%[[INDEX_CAST]]]
+//   CHECK-DAG:     tensor.extract_slice %[[ARG3V]][%[[INDEX_CAST]]]
 //       CHECK:     flow.return
 
 // -----
@@ -714,7 +714,7 @@
   %3 = cmpi sgt, %2, %c0_i32 : i32
   %4 = select %3, %2, %c0_i32 : i32
   %5 = index_cast %4 : i32 to index
-  %6 = subtensor %arg0[%5] [1] [1] : tensor<4xi32> to tensor<i32>
+  %6 = tensor.extract_slice %arg0[%5] [1] [1] : tensor<4xi32> to tensor<i32>
   br ^bb1
 ^bb1:  // pred: ^bb0
   %7 = linalg.init_tensor [] : tensor<i16>
@@ -744,7 +744,7 @@
 //       CHECK:     %[[OP4:.+]] = cmpi sgt, %[[OP3]], %[[C0]] : i32
 //       CHECK:     %[[OP5:.+]] = select %[[OP4]], %[[OP3]], %[[C0]] : i32
 //       CHECK:     %[[OP6:.+]] = index_cast %[[OP5]] : i32 to index
-//       CHECK:     %[[OP7:.+]] = subtensor %[[LEAF1]][%[[OP6]]] [1] [1] : tensor<4xi32> to tensor<i32>
+//       CHECK:     %[[OP7:.+]] = tensor.extract_slice %[[LEAF1]][%[[OP6]]] [1] [1] : tensor<4xi32> to tensor<i32>
 //       CHECK:     %[[RES:.+]] = linalg.generic
 //  CHECK-SAME:       ins(%[[OP7]] : tensor<i32>)
 //  CHECK-SAME:       outs(%[[INIT]] : tensor<i16>) {
@@ -842,7 +842,7 @@
   %9 = cmpi sgt, %8, %c0_i32 : i32
   %10 = select %9, %8, %c0_i32 : i32
   %11 = index_cast %10 : i32 to index
-  %12 = subtensor %arg0[%5, %11] [1, %arg3] [1, 1] : tensor<?x?xi32> to tensor<1x?xi32>
+  %12 = tensor.extract_slice %arg0[%5, %11] [1, %arg3] [1, 1] : tensor<?x?xi32> to tensor<1x?xi32>
   return %12 : tensor<1x?xi32>
 }
 // CHECK-LABEL: func @dynamic_slice(
@@ -864,7 +864,7 @@
 //   CHECK-DAG:     select
 //   CHECK-DAG:     index_cast
 //   CHECK-DAG:     index_cast
-//       CHECK:     subtensor
+//       CHECK:     tensor.extract_slice
 //       CHECK:     flow.return
 //       CHECK:   return %[[RESULT]]
 
@@ -895,4 +895,4 @@
 //       CHECK:        linalg.matmul
 //   CHECK-NOT:    linalg.fill
 //   CHECK-NOT:    linalg.matmul
-//       CHECK:    return
\ No newline at end of file
+//       CHECK:    return
diff --git a/iree/compiler/Dialect/Flow/Transforms/test/pad_linalg_ops.mlir b/iree/compiler/Dialect/Flow/Transforms/test/pad_linalg_ops.mlir
index cfb4b7a..96bba4e 100644
--- a/iree/compiler/Dialect/Flow/Transforms/test/pad_linalg_ops.mlir
+++ b/iree/compiler/Dialect/Flow/Transforms/test/pad_linalg_ops.mlir
@@ -14,7 +14,7 @@
 //       CHECK:      %[[PADDED_RESULT:.+]] = linalg.matmul
 //  CHECK-SAME:         ins(%[[PADDED_LHS]], %[[PADDED_RHS]] : tensor<12x20xf32>, tensor<20x16xf32>)
 //  CHECK-SAME:         outs(%[[PADDED_DST]] : tensor<12x16xf32>)
-//       CHECK:      %[[RESULT:.+]] = subtensor %[[PADDED_RESULT]][0, 0] [11, 13] [1, 1] : tensor<12x16xf32> to tensor<11x13xf32>
+//       CHECK:      %[[RESULT:.+]] = tensor.extract_slice %[[PADDED_RESULT]][0, 0] [11, 13] [1, 1] : tensor<12x16xf32> to tensor<11x13xf32>
 //       CHECK:      return %[[RESULT]] : tensor<11x13xf32>
 
 // -----
diff --git a/iree/compiler/Dialect/Flow/Transforms/test/pad_tensor_to_tensor.mlir b/iree/compiler/Dialect/Flow/Transforms/test/pad_tensor_to_tensor.mlir
index 840b982..e6ad623 100644
--- a/iree/compiler/Dialect/Flow/Transforms/test/pad_tensor_to_tensor.mlir
+++ b/iree/compiler/Dialect/Flow/Transforms/test/pad_tensor_to_tensor.mlir
@@ -29,7 +29,7 @@
 //   CHECK-DAG:   %[[RD1:.+]] = affine.apply #[[MAP1]]()[%[[ARG2]], %[[D1]]]
 //       CHECK:   %[[INIT:.+]] = linalg.init_tensor [%[[RD0]], %[[RD1]]]
 //       CHECK:   %[[FILL:.+]] = linalg.fill(%[[INIT]], %[[VAL]])
-//       CHECK:   %[[RESULT:.+]] = subtensor_insert %[[ARG0]] into %[[FILL]][4, %[[ARG2]]] [%[[D0]], %[[D1]]] [1, 1]
+//       CHECK:   %[[RESULT:.+]] = tensor.insert_slice %[[ARG0]] into %[[FILL]][4, %[[ARG2]]] [%[[D0]], %[[D1]]] [1, 1]
 //       CHECK:   return %[[RESULT]]
 
 // -----
@@ -54,5 +54,5 @@
 //   CHECK-DAG:   %[[VAL:.+]] = tensor.extract %[[ARG1]]
 //       CHECK:   %[[INIT:.+]] = linalg.init_tensor [18, 12]
 //       CHECK:   %[[FILL:.+]] = linalg.fill(%[[INIT]], %[[VAL]])
-//       CHECK:   %[[RESULT:.+]] = subtensor_insert %[[ARG0]] into %[[FILL]][4, 5] [12, 4] [1, 1]
+//       CHECK:   %[[RESULT:.+]] = tensor.insert_slice %[[ARG0]] into %[[FILL]][4, 5] [12, 4] [1, 1]
 //       CHECK:   return %[[RESULT]]
diff --git a/iree/compiler/Dialect/HAL/IR/BUILD b/iree/compiler/Dialect/HAL/IR/BUILD
index dbffc40..72e3926 100644
--- a/iree/compiler/Dialect/HAL/IR/BUILD
+++ b/iree/compiler/Dialect/HAL/IR/BUILD
@@ -132,6 +132,7 @@
         "//iree/compiler/Dialect/IREE/IR:td_files",
         "@llvm-project//mlir:OpBaseTdFiles",
         "@llvm-project//mlir:StdOpsTdFiles",
+        "@llvm-project//mlir:ViewLikeInterfaceTdFiles",
     ],
 )
 
@@ -156,6 +157,7 @@
         "@llvm-project//mlir:include/mlir/IR/SymbolInterfaces.td",
         "@llvm-project//mlir:OpBaseTdFiles",
         "@llvm-project//mlir:StdOpsTdFiles",
+        "@llvm-project//mlir:ViewLikeInterfaceTdFiles",
     ],
 )
 
@@ -218,6 +220,7 @@
         "//iree/compiler/Dialect/Shape/IR:td_files",
         "@llvm-project//mlir:include/mlir/IR/SymbolInterfaces.td",
         "@llvm-project//mlir:StdOpsTdFiles",
+        "@llvm-project//mlir:ViewLikeInterfaceTdFiles",
     ],
 )
 
diff --git a/iree/compiler/InputConversion/MHLO/MHLOToLinalgOnTensors.cpp b/iree/compiler/InputConversion/MHLO/MHLOToLinalgOnTensors.cpp
index 46a5d86..9f1a476 100644
--- a/iree/compiler/InputConversion/MHLO/MHLOToLinalgOnTensors.cpp
+++ b/iree/compiler/InputConversion/MHLO/MHLOToLinalgOnTensors.cpp
@@ -52,7 +52,7 @@
 //===----------------------------------------------------------------------===//
 
 namespace {
-/// Converts mhlo.concatenate operation to subtensor ops + subtensor_insert ops.
+/// Converts the mhlo.concatenate operation to tensor.extract_slice + tensor.insert_slice ops.
 struct ConcatenateOpConversion
     : public OpConversionPattern<mhlo::ConcatenateOp> {
   using OpConversionPattern<mhlo::ConcatenateOp>::OpConversionPattern;
@@ -93,8 +93,8 @@
     for (auto arg : args) {
       offsets[dim] = accBound;
       sizes[dim] = rewriter.create<memref::DimOp>(loc, arg, dim);
-      result = rewriter.create<SubTensorInsertOp>(loc, arg, result, offsets,
-                                                  sizes, strides);
+      result = rewriter.create<tensor::InsertSliceOp>(loc, arg, result, offsets,
+                                                      sizes, strides);
       accBound = rewriter.create<AddIOp>(loc, accBound, sizes[dim]);
     }
     rewriter.replaceOp(op, result);
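
A minimal sketch of the kind of IR this pattern produces after the rename, assuming a static concatenation of two operands along dimension 1; the shapes and value names here are illustrative only, not taken from the patch:

  %init = linalg.init_tensor [2, 8] : tensor<2x8xf32>
  %0 = tensor.insert_slice %arg0 into %init[0, 0] [2, 3] [1, 1]
      : tensor<2x3xf32> into tensor<2x8xf32>
  %1 = tensor.insert_slice %arg1 into %0[0, 3] [2, 5] [1, 1]
      : tensor<2x5xf32> into tensor<2x8xf32>

Each operand is inserted at an offset that accumulates the sizes of the preceding operands along the concatenation dimension, mirroring the accBound update in the loop above.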
diff --git a/third_party/llvm-bazel b/third_party/llvm-bazel
index 39c6fa3..30724fe 160000
--- a/third_party/llvm-bazel
+++ b/third_party/llvm-bazel
@@ -1 +1 @@
-Subproject commit 39c6fa385c29b90a97511d4aff7269e40ad210eb
+Subproject commit 30724fe230e9af8e408560f23435d820b5c99924
diff --git a/third_party/llvm-project b/third_party/llvm-project
index 4c4f1ae..36111f2 160000
--- a/third_party/llvm-project
+++ b/third_party/llvm-project
@@ -1 +1 @@
-Subproject commit 4c4f1ae93ea7477ccb4772007fc78313f5a0644f
+Subproject commit 36111f28edb1182273c6409c3fb7808e0e9cbd60
diff --git a/third_party/mlir-hlo b/third_party/mlir-hlo
index 82ecad2..b6a8145 160000
--- a/third_party/mlir-hlo
+++ b/third_party/mlir-hlo
@@ -1 +1 @@
-Subproject commit 82ecad259ef0566a4500fa59192457ff69f83fe6
+Subproject commit b6a8145dafcd927c0600af35d811aa5ef8297d6c
diff --git a/third_party/tensorflow b/third_party/tensorflow
index ff964e9..3ef6fbf 160000
--- a/third_party/tensorflow
+++ b/third_party/tensorflow
@@ -1 +1 @@
-Subproject commit ff964e92797b4d7e5c23a3143bfc74d1572e6c7b
+Subproject commit 3ef6fbfd02716f774024afae711383fdb0d8a30c