Integrate LLVM at llvm/llvm-project@c6a8bec4c578

Updates LLVM usage to match
[c6a8bec4c578](https://github.com/llvm/llvm-project/commit/c6a8bec4c578)

PiperOrigin-RevId: 415785157
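
Notable upstream API changes picked up by this integrate (as reflected in the
hunks below):

* The Linalg reshape ops moved to the tensor dialect:
  `linalg::TensorCollapseShapeOp` / `linalg::TensorExpandShapeOp` become
  `tensor::CollapseShapeOp` / `tensor::ExpandShapeOp`, and the assembly forms
  `linalg.tensor_collapse_shape` / `linalg.tensor_expand_shape` become
  `tensor.collapse_shape` / `tensor.expand_shape`.
* `applyMapToValues` is no longer under the `linalg::` namespace.
* `populateGpuToROCDLConversionPatterns` takes an additional
  `gpu::amd::Runtime` argument (`Runtime::Unknown` is passed here).
* `reifyResultShapes` on the reshape ops is now reached by casting to
  `ReifyRankedShapedTypeOpInterface`.

The rename itself is mechanical. A minimal sketch, using the pattern from
`DirectLoweringPatterns.cpp` below (`op`, `resultType`, and
`reassociationIndices` are placeholders standing in for the surrounding
rewrite code):

```c++
// The op is now provided by mlir/Dialect/Tensor/IR/Tensor.h rather than
// the Linalg dialect headers.
// Before this integrate:
//   rewriter.replaceOpWithNewOp<linalg::TensorCollapseShapeOp>(
//       op, resultType, op.input(), *reassociationIndices);
// After: same operands, tensor-dialect op.
rewriter.replaceOpWithNewOp<tensor::CollapseShapeOp>(
    op, resultType, op.input(), *reassociationIndices);
```
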
diff --git a/SUBMODULE_VERSIONS.txt b/SUBMODULE_VERSIONS.txt
index 08ffcc2..7b22a43 100644
--- a/SUBMODULE_VERSIONS.txt
+++ b/SUBMODULE_VERSIONS.txt
@@ -4,7 +4,7 @@
 aa533abfd4232b01f9e57041d70114d5a77e6de0 third_party/googletest
 88b845dee001723c4a0db1fe5477de735b6d3bb0 third_party/liburing
 f8f760f7387d2cc56a2fc7b1be313a3bf3f7f58c third_party/libyaml
-def8b952ebc00b0ad0fa4196ae27b11a385087ad third_party/llvm-project
+c6a8bec4c578a7c12e4458b161fce7b1704804a2 third_party/llvm-project
 8c636d9692e2a50eb03d1e0a9809ffde90dbd2c2 third_party/mlir-hlo
 3f701faace7addc75d16dea8a6cd769fa5b3f260 third_party/musl
 59aa99860c60bd171b9565e9920f125fdb749267 third_party/pybind11
diff --git a/integrations/tensorflow/iree_tf_compiler/TF/DirectLoweringPatterns.cpp b/integrations/tensorflow/iree_tf_compiler/TF/DirectLoweringPatterns.cpp
index 87fd143..7cdeb83 100644
--- a/integrations/tensorflow/iree_tf_compiler/TF/DirectLoweringPatterns.cpp
+++ b/integrations/tensorflow/iree_tf_compiler/TF/DirectLoweringPatterns.cpp
@@ -47,7 +47,7 @@
           op, "could not compute reassociation indices");
     }
 
-    rewriter.replaceOpWithNewOp<linalg::TensorCollapseShapeOp>(
+    rewriter.replaceOpWithNewOp<tensor::CollapseShapeOp>(
         op, resultType, op.input(), *reassociationIndices);
     return success();
   }
@@ -99,7 +99,7 @@
           op, "could not compute reassociation indices");
     }
 
-    rewriter.replaceOpWithNewOp<linalg::TensorExpandShapeOp>(
+    rewriter.replaceOpWithNewOp<tensor::ExpandShapeOp>(
         op, expandedType, op.input(), *reassociationIndices);
     return success();
   }
diff --git a/integrations/tensorflow/iree_tf_compiler/TF/test/direct_lowering.mlir b/integrations/tensorflow/iree_tf_compiler/TF/test/direct_lowering.mlir
index 5316387..40c241c 100644
--- a/integrations/tensorflow/iree_tf_compiler/TF/test/direct_lowering.mlir
+++ b/integrations/tensorflow/iree_tf_compiler/TF/test/direct_lowering.mlir
@@ -2,7 +2,7 @@
 
 // CHECK-LABEL: @expand_dims
 func @expand_dims(%arg0: tensor<?x?x?xf32>) -> tensor<?x?x?x1xf32> {
-  // CHECK: %[[R:.*]] = linalg.tensor_expand_shape %arg0 {{\[}}[0], [1], [2, 3]] : tensor<?x?x?xf32> into tensor<?x?x?x1xf32>
+  // CHECK: %[[R:.*]] = tensor.expand_shape %arg0 {{\[}}[0], [1], [2, 3]] : tensor<?x?x?xf32> into tensor<?x?x?x1xf32>
   // CHECK: return %[[R]]
   %axis = "tf.Const"() {value = dense<3> : tensor<i32>} : () -> (tensor<i32>)
   %0 = "tf.ExpandDims"(%arg0, %axis) : (tensor<?x?x?xf32>, tensor<i32>) -> (tensor<?x?x?x1xf32>)
@@ -23,7 +23,7 @@
 // -----
 // CHECK-LABEL: @squeeze
 func @squeeze(%arg0 : tensor<?x1x1x1001xf32>) -> tensor<?x1001xf32> {
-  // CHECK: %[[R:.*]] = linalg.tensor_collapse_shape %arg0 {{\[}}[0], [1, 2, 3]] : tensor<?x1x1x1001xf32> into tensor<?x1001xf32>
+  // CHECK: %[[R:.*]] = tensor.collapse_shape %arg0 {{\[}}[0], [1, 2, 3]] : tensor<?x1x1x1001xf32> into tensor<?x1001xf32>
   // CHECK: return %[[R]]
   %0 = "tf.Squeeze"(%arg0) {device = "", squeeze_dims = [1, 2]} : (tensor<?x1x1x1001xf32>) -> tensor<?x1001xf32>
   return %0 : tensor<?x1001xf32>
diff --git a/iree/compiler/Codegen/Common/BufferizationAnalysis.cpp b/iree/compiler/Codegen/Common/BufferizationAnalysis.cpp
index 8cdce87..7664a5f 100644
--- a/iree/compiler/Codegen/Common/BufferizationAnalysis.cpp
+++ b/iree/compiler/Codegen/Common/BufferizationAnalysis.cpp
@@ -56,8 +56,8 @@
   // TODO(ravishankarm): Maybe this is too aggressive, might have to switch this
   // to have a white-list instead of blacklist.
   for (Operation *user : op->getUsers()) {
-    if (isa<IREE::Flow::DispatchTensorStoreOp, linalg::TensorCollapseShapeOp,
-            linalg::TensorExpandShapeOp>(user)) {
+    if (isa<IREE::Flow::DispatchTensorStoreOp, tensor::CollapseShapeOp,
+            tensor::ExpandShapeOp>(user)) {
       return false;
     }
   }
@@ -82,7 +82,7 @@
   return TypeSwitch<Operation *, bool>(definingOp)
       .Case<arith::ConstantOp>(
           [&](arith::ConstantOp constantOp) { return true; })
-      .Case<linalg::TensorCollapseShapeOp, linalg::TensorExpandShapeOp>(
+      .Case<tensor::CollapseShapeOp, tensor::ExpandShapeOp>(
           [&](auto op) { return isFromReadOnlyTensor(op.src(), plan); })
       .Case<tensor::ExtractSliceOp>([&](tensor::ExtractSliceOp sliceOp) {
         return isFromReadOnlyTensor(sliceOp.source(), plan);
@@ -159,7 +159,7 @@
       // reshapes not working well together, but there is no comment about why
       // this was added with the change that added this.
       Operation *op = v.getDefiningOp();
-      if (op && isa<linalg::TensorCollapseShapeOp, linalg::TensorExpandShapeOp>(
+      if (op && isa<tensor::CollapseShapeOp, tensor::ExpandShapeOp>(
                     v.getDefiningOp())) {
         return false;
       }
@@ -519,7 +519,7 @@
             [&](IREE::LinalgExt::LinalgExtOp linalgExtOp) {
               return analyseLinalgExtOps(linalgExtOp, plan);
             })
-        .Case<linalg::TensorCollapseShapeOp, linalg::TensorExpandShapeOp>(
+        .Case<tensor::CollapseShapeOp, tensor::ExpandShapeOp>(
             [&](auto reshapeOp) {
               return analyseSingleOperandResultOp(reshapeOp.src(),
                                                   reshapeOp.result(), plan);
diff --git a/iree/compiler/Codegen/Common/CleanupBufferAllocViewPass.cpp b/iree/compiler/Codegen/Common/CleanupBufferAllocViewPass.cpp
index 5484102..2d8b899 100644
--- a/iree/compiler/Codegen/Common/CleanupBufferAllocViewPass.cpp
+++ b/iree/compiler/Codegen/Common/CleanupBufferAllocViewPass.cpp
@@ -117,10 +117,9 @@
     : public CleanupBufferAllocViewBase<CleanupBufferAllocViewPass> {
   void runOnOperation() override {
     OwningRewritePatternList patterns(&getContext());
-    patterns.insert<
-        FoldReshapeIntoInterfaceTensorLoad<linalg::TensorCollapseShapeOp>,
-        FoldReshapeIntoInterfaceTensorLoad<linalg::TensorExpandShapeOp>,
-        RemoveDeadMemAllocs>(&getContext());
+    patterns.insert<FoldReshapeIntoInterfaceTensorLoad<tensor::CollapseShapeOp>,
+                    FoldReshapeIntoInterfaceTensorLoad<tensor::ExpandShapeOp>,
+                    RemoveDeadMemAllocs>(&getContext());
     if (failed(applyPatternsAndFoldGreedily(getOperation(),
                                             std::move(patterns)))) {
       return signalPassFailure();
diff --git a/iree/compiler/Codegen/Common/ConvertToDestinationPassingStylePass.cpp b/iree/compiler/Codegen/Common/ConvertToDestinationPassingStylePass.cpp
index 1ca48ce..e2d2f40 100644
--- a/iree/compiler/Codegen/Common/ConvertToDestinationPassingStylePass.cpp
+++ b/iree/compiler/Codegen/Common/ConvertToDestinationPassingStylePass.cpp
@@ -118,8 +118,8 @@
 static Value getReverseOfReshapeOp(OpBuilder &b, TensorReshapeOpTy reshapeOp,
                                    Value resultBuffer) {
   using ReverseReshapeOpTy = typename std::conditional<
-      std::is_same<TensorReshapeOpTy, linalg::TensorCollapseShapeOp>::value,
-      linalg::TensorExpandShapeOp, linalg::TensorCollapseShapeOp>::type;
+      std::is_same<TensorReshapeOpTy, tensor::CollapseShapeOp>::value,
+      tensor::ExpandShapeOp, tensor::CollapseShapeOp>::type;
   return b.create<ReverseReshapeOpTy>(reshapeOp.getLoc(),
                                       reshapeOp.getSrcType(), resultBuffer,
                                       reshapeOp.reassociation());
@@ -244,7 +244,7 @@
               }
               return nullptr;
             })
-            .Case<linalg::TensorCollapseShapeOp, linalg::TensorExpandShapeOp>(
+            .Case<tensor::CollapseShapeOp, tensor::ExpandShapeOp>(
                 [&](auto reshapeOp) {
                   return getReverseOfReshapeOp(b, reshapeOp, resultBuffer);
                 })
diff --git a/iree/compiler/Codegen/Common/LinalgBufferizePass.cpp b/iree/compiler/Codegen/Common/LinalgBufferizePass.cpp
index bdd81a9..5f3a666 100644
--- a/iree/compiler/Codegen/Common/LinalgBufferizePass.cpp
+++ b/iree/compiler/Codegen/Common/LinalgBufferizePass.cpp
@@ -280,7 +280,7 @@
       reshapeOp.getSrcType(), {},
       resultBuffer.getType().cast<MemRefType>().getMemorySpace());
   using ReverseReshapeOpTy = typename std::conditional<
-      std::is_same<TensorReshapeOpTy, linalg::TensorCollapseShapeOp>::value,
+      std::is_same<TensorReshapeOpTy, tensor::CollapseShapeOp>::value,
       memref::ExpandShapeOp, memref::CollapseShapeOp>::type;
   return b.create<ReverseReshapeOpTy>(reshapeOp.getLoc(), memrefType,
                                       resultBuffer, reshapeOp.reassociation());
@@ -365,7 +365,7 @@
                   IREE::LinalgExt::LinalgExtOp, tensor::InsertSliceOp,
                   vector::TransferWriteOp>(
                 [&](auto op) { return resultBuffer; })
-            .Case<linalg::TensorCollapseShapeOp, linalg::TensorExpandShapeOp>(
+            .Case<tensor::CollapseShapeOp, tensor::ExpandShapeOp>(
                 [&](auto reshapeOp) {
                   return getReverseOfReshapeOp(b, reshapeOp, resultBuffer);
                 })
@@ -431,7 +431,7 @@
   auto reshapeResultType = getMemrefTypeForTensor(
       resultTensorType, {}, inputBufferType.getMemorySpace());
   using ReshapeOpTy = typename std::conditional<
-      std::is_same<TensorReshapeOpTy, linalg::TensorCollapseShapeOp>::value,
+      std::is_same<TensorReshapeOpTy, tensor::CollapseShapeOp>::value,
       memref::CollapseShapeOp, memref::ExpandShapeOp>::type;
   Value bufferReshape = b.create<ReshapeOpTy>(loc, reshapeResultType,
                                               inputBuffer, op.reassociation());
@@ -472,7 +472,7 @@
             tensor::CastOp>([&](auto singleResultOp) -> SmallVector<Value, 4> {
         return {getAliasingBufferForResult(b, singleResultOp, bvm)};
       })
-      .Case<linalg::TensorCollapseShapeOp, linalg::TensorExpandShapeOp>(
+      .Case<tensor::CollapseShapeOp, tensor::ExpandShapeOp>(
           [&](auto reshapeOp) -> SmallVector<Value, 4> {
             return {getAliasingBufferForReshapeResult(b, reshapeOp, bvm)};
           })
@@ -943,8 +943,8 @@
           }
           return convertScfIfOp(b, ifOp, bvm, plan);
         })
-        .Case<IREE::Flow::DispatchTensorLoadOp, linalg::TensorCollapseShapeOp,
-              linalg::TensorExpandShapeOp, tensor::ExtractSliceOp,
+        .Case<IREE::Flow::DispatchTensorLoadOp, tensor::CollapseShapeOp,
+              tensor::ExpandShapeOp, tensor::ExtractSliceOp,
               tensor::CastOp>([&](auto aliasingOp) {
           auto aliasingBuffers =
               getAliasingBuffersForResults(b, aliasingOp, bvm);
diff --git a/iree/compiler/Codegen/Common/SetNumWorkgroupsPass.cpp b/iree/compiler/Codegen/Common/SetNumWorkgroupsPass.cpp
index bb4dc4b..a1a5507 100644
--- a/iree/compiler/Codegen/Common/SetNumWorkgroupsPass.cpp
+++ b/iree/compiler/Codegen/Common/SetNumWorkgroupsPass.cpp
@@ -129,7 +129,7 @@
         Value one = b.create<arith::ConstantIndexOp>(loc, 1);
         std::array<Value, 3> returnValues = {one, one, one};
         for (auto ts : llvm::enumerate(currWorkloadPerWorkgroup)) {
-          returnValues[ts.index()] = linalg::applyMapToValues(
+          returnValues[ts.index()] = applyMapToValues(
               b, loc,
               AffineMap::get(0, 1,
                              b.getAffineSymbolExpr(0).ceilDiv(ts.value())),
diff --git a/iree/compiler/Codegen/Common/test/canonicalize_interface_load_store.mlir b/iree/compiler/Codegen/Common/test/canonicalize_interface_load_store.mlir
index 6db17a7..f9fe1a8 100644
--- a/iree/compiler/Codegen/Common/test/canonicalize_interface_load_store.mlir
+++ b/iree/compiler/Codegen/Common/test/canonicalize_interface_load_store.mlir
@@ -10,8 +10,8 @@
   %2 = hal.interface.binding.subspan @interface_io::@ret0[%c0] : !flow.dispatch.tensor<writeonly:3x3x96xf32>
   // CHECK: %[[LOAD:.+]] = flow.dispatch.tensor.load %[[ARG]], {{.*}} : !flow.dispatch.tensor<readonly:3x3x96xf32> -> tensor<3x3x96xf32>
   %3 = flow.dispatch.tensor.load %1, offsets=[], sizes =[], strides=[] : !flow.dispatch.tensor<readonly:3x3x1x96xf32> -> tensor<3x3x1x96xf32>
-  %4 = linalg.tensor_collapse_shape %3 [[0, 1, 2, 3]] : tensor<3x3x1x96xf32> into tensor<864xf32>
-  %5 = linalg.tensor_expand_shape %4 [[0, 1, 2]] : tensor<864xf32> into tensor<3x3x96xf32>
+  %4 = tensor.collapse_shape %3 [[0, 1, 2, 3]] : tensor<3x3x1x96xf32> into tensor<864xf32>
+  %5 = tensor.expand_shape %4 [[0, 1, 2]] : tensor<864xf32> into tensor<3x3x96xf32>
   //  CHECK: flow.dispatch.tensor.store %[[LOAD]], {{.*}}
   flow.dispatch.tensor.store %5, %2, offsets = [%c0, %c0, %c0], sizes = [%c1, %c1, %c1], strides = [%c1, %c1, %c1] : tensor<3x3x96xf32> -> !flow.dispatch.tensor<writeonly:3x3x96xf32>
   return
@@ -34,10 +34,10 @@
   %1 = hal.interface.binding.subspan @interface_io::@arg0[%c0] : !flow.dispatch.tensor<readonly:6x3x1x96xf32>
   %2 = hal.interface.binding.subspan @interface_io::@ret0[%c0] : !flow.dispatch.tensor<writeonly:3x3x96xf32>
   %3 = flow.dispatch.tensor.load %1, offsets = [%c3, %c0, %c0, %c0], sizes = [%c3, %c3, %c1, %c96], strides = [%c1, %c1, %c1, %c1] : !flow.dispatch.tensor<readonly:6x3x1x96xf32> -> tensor<3x3x1x96xf32>
-  // CHECK: linalg.tensor_collapse_shape
-  // CHECK: linalg.tensor_expand_shape
-  %4 = linalg.tensor_collapse_shape %3 [[0, 1, 2, 3]] : tensor<3x3x1x96xf32> into tensor<864xf32>
-  %5 = linalg.tensor_expand_shape %4 [[0, 1, 2]] : tensor<864xf32> into tensor<3x3x96xf32>
+  // CHECK: tensor.collapse_shape
+  // CHECK: tensor.expand_shape
+  %4 = tensor.collapse_shape %3 [[0, 1, 2, 3]] : tensor<3x3x1x96xf32> into tensor<864xf32>
+  %5 = tensor.expand_shape %4 [[0, 1, 2]] : tensor<864xf32> into tensor<3x3x96xf32>
   flow.dispatch.tensor.store %5, %2, offsets = [%c0, %c0, %c0], sizes = [%c1, %c1, %c1], strides = [%c1, %c1, %c1] : tensor<3x3x96xf32> -> !flow.dispatch.tensor<writeonly:3x3x96xf32>
   return
 }
@@ -59,10 +59,10 @@
   %1 = hal.interface.binding.subspan @interface_io::@arg0[%c0] : !flow.dispatch.tensor<readonly:?x?x96xf32>{%dim0, %dim1}
   %2 = hal.interface.binding.subspan @interface_io::@ret0[%c0] : !flow.dispatch.tensor<writeonly:?x12x8xf32>{%dim2}
   %3 = flow.dispatch.tensor.load %1, offsets=[], sizes =[], strides=[] : !flow.dispatch.tensor<readonly:?x?x96xf32> -> tensor<?x?x96xf32>
-  // CHECK: linalg.tensor_collapse_shape
-  // CHECK: linalg.tensor_expand_shape
-  %4 = linalg.tensor_collapse_shape %3 [[0, 1], [2]] : tensor<?x?x96xf32> into tensor<?x96xf32>
-  %5 = linalg.tensor_expand_shape %4 [[0], [1, 2]] : tensor<?x96xf32> into tensor<?x12x8xf32>
+  // CHECK: tensor.collapse_shape
+  // CHECK: tensor.expand_shape
+  %4 = tensor.collapse_shape %3 [[0, 1], [2]] : tensor<?x?x96xf32> into tensor<?x96xf32>
+  %5 = tensor.expand_shape %4 [[0], [1, 2]] : tensor<?x96xf32> into tensor<?x12x8xf32>
   flow.dispatch.tensor.store %5, %2, offsets = [%c0, %c0, %c0], sizes = [%c1, %c1, %c1], strides = [%c1, %c1, %c1] : tensor<?x12x8xf32> -> !flow.dispatch.tensor<writeonly:?x12x8xf32>
   return
 }
diff --git a/iree/compiler/Codegen/Common/test/convert_to_destination_passing_style.mlir b/iree/compiler/Codegen/Common/test/convert_to_destination_passing_style.mlir
index 1712106..595a797 100644
--- a/iree/compiler/Codegen/Common/test/convert_to_destination_passing_style.mlir
+++ b/iree/compiler/Codegen/Common/test/convert_to_destination_passing_style.mlir
@@ -171,7 +171,7 @@
   %0 = hal.interface.binding.subspan @io::@arg0[%c0] : !flow.dispatch.tensor<readonly:12xi32>
   %1 = hal.interface.binding.subspan @io::@ret0[%c0] : !flow.dispatch.tensor<writeonly:3x4xi32>
   %2 = flow.dispatch.tensor.load %0, offsets = [], sizes = [], strides = [] : !flow.dispatch.tensor<readonly:12xi32> -> tensor<12xi32>
-  %3 = linalg.tensor_expand_shape %2 [[0, 1]] : tensor<12xi32> into tensor<3x4xi32>
+  %3 = tensor.expand_shape %2 [[0, 1]] : tensor<12xi32> into tensor<3x4xi32>
   flow.dispatch.tensor.store %3, %1, offsets = [], sizes = [], strides = [] : tensor<3x4xi32> -> !flow.dispatch.tensor<writeonly:3x4xi32>
   return
 }
@@ -183,7 +183,7 @@
 //  CHECK-DAG:   %[[ARG0:.+]] = hal.interface.binding.subspan @io::@arg0
 //  CHECK-DAG:   %[[RET0:.+]] = hal.interface.binding.subspan @io::@ret0
 //      CHECK:   %[[SOURCE:.+]] = flow.dispatch.tensor.load %[[ARG0]]
-//      CHECK:   %[[RESHAPE:.+]] = linalg.tensor_expand_shape %[[SOURCE]]
+//      CHECK:   %[[RESHAPE:.+]] = tensor.expand_shape %[[SOURCE]]
 //      CHECK:   flow.dispatch.tensor.store %[[RESHAPE]], %[[RET0]]
 
 // -----
@@ -197,7 +197,7 @@
   %0 = hal.interface.binding.subspan @io::@arg0[%c0] : !flow.dispatch.tensor<readonly:12xi32>
   %1 = hal.interface.binding.subspan @io::@ret0[%c0] : !flow.dispatch.tensor<writeonly:3x4xi32>
   %2 = flow.dispatch.tensor.load %0, offsets = [], sizes = [], strides = [] : !flow.dispatch.tensor<readonly:12xi32> -> tensor<12xi32>
-  %3 = linalg.tensor_expand_shape %2 [[0, 1]] : tensor<12xi32> into tensor<3x4xi32>
+  %3 = tensor.expand_shape %2 [[0, 1]] : tensor<12xi32> into tensor<3x4xi32>
   %4 = linalg.init_tensor [3, 4] : tensor<3x4xi32>
   %5 = linalg.generic {
     indexing_maps = [affine_map<(d0, d1) -> (d0, d1)>, affine_map<(d0, d1) -> (d0, d1)>],
@@ -219,7 +219,7 @@
 //  CHECK-DAG:   %[[RET0:.+]] = hal.interface.binding.subspan @io::@ret0
 //      CHECK:   %[[TARGET:.+]] = flow.dispatch.tensor.load %[[RET0]]
 //      CHECK:   %[[SOURCE:.+]] = flow.dispatch.tensor.load %[[ARG0]]
-//      CHECK:   %[[RESHAPE:.+]] = linalg.tensor_expand_shape %[[SOURCE]]
+//      CHECK:   %[[RESHAPE:.+]] = tensor.expand_shape %[[SOURCE]]
 //      CHECK:   %[[GENERIC:.+]] = linalg.generic
 // CHECK-SAME:       ins(%[[RESHAPE]]
 // CHECK-SAME:       outs(%[[TARGET]]
@@ -237,7 +237,7 @@
   %1 = hal.interface.binding.subspan @io::@ret0[%c0] : !flow.dispatch.tensor<writeonly:3x4xi32>
   %2 = hal.interface.binding.subspan @io::@ret1[%c0] : !flow.dispatch.tensor<writeonly:3x4xi32>
   %3 = flow.dispatch.tensor.load %0, offsets = [], sizes = [], strides = [] : !flow.dispatch.tensor<readonly:12xi32> -> tensor<12xi32>
-  %4 = linalg.tensor_expand_shape %3 [[0, 1]] : tensor<12xi32> into tensor<3x4xi32>
+  %4 = tensor.expand_shape %3 [[0, 1]] : tensor<12xi32> into tensor<3x4xi32>
   %5 = linalg.init_tensor [3, 4] : tensor<3x4xi32>
   %6 = linalg.generic {
     indexing_maps = [affine_map<(d0, d1) -> (d0, d1)>, affine_map<(d0, d1) -> (d0, d1)>],
@@ -262,7 +262,7 @@
 //  CHECK-DAG:   %[[TARGET:.+]] = flow.dispatch.tensor.load %[[RET0]]
 //  CHECK-DAG:   %[[RET1:.+]] = hal.interface.binding.subspan @io::@ret1
 //      CHECK:   %[[SOURCE:.+]] = flow.dispatch.tensor.load %[[ARG0]]
-//      CHECK:   %[[RESHAPE:.+]] = linalg.tensor_expand_shape %[[SOURCE]]
+//      CHECK:   %[[RESHAPE:.+]] = tensor.expand_shape %[[SOURCE]]
 //      CHECK:   %[[GENERIC:.+]] = linalg.generic
 // CHECK-SAME:       ins(%[[RESHAPE]]
 // CHECK-SAME:       outs(%[[TARGET]]
@@ -289,7 +289,7 @@
       %5 = arith.addi %arg0, %arg0 : i32
       linalg.yield %5 : i32
     } -> tensor<3x4xi32>
-  %5 = linalg.tensor_collapse_shape %4 [[0, 1]] : tensor<3x4xi32> into tensor<12xi32>
+  %5 = tensor.collapse_shape %4 [[0, 1]] : tensor<3x4xi32> into tensor<12xi32>
   flow.dispatch.tensor.store %5, %1, offsets = [], sizes = [], strides = [] : tensor<12xi32> -> !flow.dispatch.tensor<writeonly:12xi32>
   return
 }
@@ -302,11 +302,11 @@
 //  CHECK-DAG:   %[[RET0:.+]] = hal.interface.binding.subspan @io::@ret0
 //  CHECK-DAG:   %[[SOURCE:.+]] = flow.dispatch.tensor.load %[[ARG0]]
 //  CHECK-DAG:   %[[TARGET:.+]] = flow.dispatch.tensor.load %[[RET0]]
-//  CHECK-DAG:   %[[RESHAPE_EXPAND:.+]] = linalg.tensor_expand_shape %[[TARGET]] {{\[}}[0, 1]{{\]}}
+//  CHECK-DAG:   %[[RESHAPE_EXPAND:.+]] = tensor.expand_shape %[[TARGET]] {{\[}}[0, 1]{{\]}}
 //      CHECK:   %[[GENERIC:.+]] = linalg.generic
 // CHECK-SAME:       ins(%[[SOURCE]]
 // CHECK-SAME:       outs(%[[RESHAPE_EXPAND]]
-//      CHECK:   %[[RESHAPE_COLLAPSE:.+]] = linalg.tensor_collapse_shape %[[GENERIC]] {{\[}}[0, 1]{{\]}}
+//      CHECK:   %[[RESHAPE_COLLAPSE:.+]] = tensor.collapse_shape %[[GENERIC]] {{\[}}[0, 1]{{\]}}
 //      CHECK:   flow.dispatch.tensor.store %[[RESHAPE_COLLAPSE]], %[[RET0]]
 
 // -----
diff --git a/iree/compiler/Codegen/Common/test/linalg_bufferize.mlir b/iree/compiler/Codegen/Common/test/linalg_bufferize.mlir
index 0ee94d6..e2484ab 100644
--- a/iree/compiler/Codegen/Common/test/linalg_bufferize.mlir
+++ b/iree/compiler/Codegen/Common/test/linalg_bufferize.mlir
@@ -644,7 +644,7 @@
   %0 = hal.interface.binding.subspan @io::@arg0[%c0] : !flow.dispatch.tensor<readonly:12xi32>
   %1 = hal.interface.binding.subspan @io::@ret0[%c0] : !flow.dispatch.tensor<writeonly:3x4xi32>
   %2 = flow.dispatch.tensor.load %0, offsets = [], sizes = [], strides = [] : !flow.dispatch.tensor<readonly:12xi32> -> tensor<12xi32>
-  %3 = linalg.tensor_expand_shape %2 [[0, 1]] : tensor<12xi32> into tensor<3x4xi32>
+  %3 = tensor.expand_shape %2 [[0, 1]] : tensor<12xi32> into tensor<3x4xi32>
   flow.dispatch.tensor.store %3, %1, offsets = [], sizes = [], strides = [] : tensor<3x4xi32> -> !flow.dispatch.tensor<writeonly:3x4xi32>
   return
 }
@@ -669,7 +669,7 @@
   %0 = hal.interface.binding.subspan @io::@arg0[%c0] : !flow.dispatch.tensor<readonly:12xi32>
   %1 = hal.interface.binding.subspan @io::@ret0[%c0] : !flow.dispatch.tensor<writeonly:3x4xi32>
   %2 = flow.dispatch.tensor.load %0, offsets = [], sizes = [], strides = [] : !flow.dispatch.tensor<readonly:12xi32> -> tensor<12xi32>
-  %3 = linalg.tensor_expand_shape %2 [[0, 1]] : tensor<12xi32> into tensor<3x4xi32>
+  %3 = tensor.expand_shape %2 [[0, 1]] : tensor<12xi32> into tensor<3x4xi32>
   %4 = linalg.init_tensor [3, 4] : tensor<3x4xi32>
   %5 = linalg.generic {
     indexing_maps = [affine_map<(d0, d1) -> (d0, d1)>, affine_map<(d0, d1) -> (d0, d1)>],
@@ -707,7 +707,7 @@
   %1 = hal.interface.binding.subspan @io::@ret0[%c0] : !flow.dispatch.tensor<writeonly:3x4xi32>
   %2 = hal.interface.binding.subspan @io::@ret1[%c0] : !flow.dispatch.tensor<writeonly:3x4xi32>
   %3 = flow.dispatch.tensor.load %0, offsets = [], sizes = [], strides = [] : !flow.dispatch.tensor<readonly:12xi32> -> tensor<12xi32>
-  %4 = linalg.tensor_expand_shape %3 [[0, 1]] : tensor<12xi32> into tensor<3x4xi32>
+  %4 = tensor.expand_shape %3 [[0, 1]] : tensor<12xi32> into tensor<3x4xi32>
   %5 = linalg.init_tensor [3, 4] : tensor<3x4xi32>
   %6 = linalg.generic {
     indexing_maps = [affine_map<(d0, d1) -> (d0, d1)>, affine_map<(d0, d1) -> (d0, d1)>],
@@ -757,7 +757,7 @@
       %5 = arith.addi %arg0, %arg0 : i32
       linalg.yield %5 : i32
     } -> tensor<3x4xi32>
-  %5 = linalg.tensor_collapse_shape %4 [[0, 1]] : tensor<3x4xi32> into tensor<12xi32>
+  %5 = tensor.collapse_shape %4 [[0, 1]] : tensor<3x4xi32> into tensor<12xi32>
   flow.dispatch.tensor.store %5, %1, offsets = [], sizes = [], strides = [] : tensor<12xi32> -> !flow.dispatch.tensor<writeonly:12xi32>
   return
 }
@@ -786,7 +786,7 @@
   %1 = hal.interface.binding.subspan @io::@arg1[%c0] : !flow.dispatch.tensor<readonly:2x3xf32>
   %2 = hal.interface.binding.subspan @io::@ret0[%c0] : !flow.dispatch.tensor<writeonly:1x3xf32>
   %3 = flow.dispatch.tensor.load %0, offsets = [], sizes = [], strides = [] : !flow.dispatch.tensor<readonly:1x1x2xf32> -> tensor<1x1x2xf32>
-  %4 = linalg.tensor_collapse_shape %3 [[0, 1], [2]] : tensor<1x1x2xf32> into tensor<1x2xf32>
+  %4 = tensor.collapse_shape %3 [[0, 1], [2]] : tensor<1x1x2xf32> into tensor<1x2xf32>
   %workgroup_size_x = hal.interface.workgroup.size[0] : index
   %workgroup_size_y = hal.interface.workgroup.size[1] : index
   %workgroup_id_x = hal.interface.workgroup.id[0] : index
@@ -1072,7 +1072,7 @@
   %0 = hal.interface.binding.subspan @io::@arg0[%c0] : !flow.dispatch.tensor<readonly:1x5x3x1xf32>
   %1 = hal.interface.binding.subspan @io::@ret0[%c0] : !flow.dispatch.tensor<writeonly:5x5xf32>
   %2 = flow.dispatch.tensor.load %0, offsets = [], sizes = [], strides = [] : !flow.dispatch.tensor<readonly:1x5x3x1xf32> -> tensor<1x5x3x1xf32>
-  %3 = linalg.tensor_collapse_shape %2 [[0, 1], [2, 3]] : tensor<1x5x3x1xf32> into tensor<5x3xf32>
+  %3 = tensor.collapse_shape %2 [[0, 1], [2, 3]] : tensor<1x5x3x1xf32> into tensor<5x3xf32>
   %workgroup_size_x = hal.interface.workgroup.size[0] : index
   %workgroup_size_y = hal.interface.workgroup.size[1] : index
   %workgroup_id_x = hal.interface.workgroup.id[0] : index
@@ -1284,7 +1284,7 @@
   %0 = hal.interface.binding.subspan @io::@ro0[%c0] : !flow.dispatch.tensor<readonly:?x?xf32>{%dim0, %dim1}
   %1 = hal.interface.binding.subspan @io::@wo0[%c0] : !flow.dispatch.tensor<writeonly:?xf32>{%dim2}
   %2 = flow.dispatch.tensor.load %0, offsets = [], sizes = [], strides = [] : !flow.dispatch.tensor<readonly:?x?xf32> -> tensor<?x?xf32>
-  %3 = linalg.tensor_collapse_shape %2 [[0, 1]]
+  %3 = tensor.collapse_shape %2 [[0, 1]]
       : tensor<?x?xf32> into tensor<?xf32>
   %4 = tensor.dim %3, %c0 : tensor<?xf32>
   %5 = linalg.init_tensor [%4] : tensor<?xf32>
@@ -1919,11 +1919,11 @@
         ^bb0(%arg3: f32, %arg4: f32):  // no predecessors
           linalg.yield %arg3 : f32
         } -> tensor<1x16x16x3x3x8xf32>
-        %19 = linalg.tensor_collapse_shape %18 [[0, 1, 2], [3, 4, 5]] : tensor<1x16x16x3x3x8xf32> into tensor<256x72xf32>
-        %20 = linalg.tensor_collapse_shape %14 [[0, 1, 2], [3]] : tensor<3x3x8x4xf32> into tensor<72x4xf32>
-        %21 = linalg.tensor_collapse_shape %16 [[0, 1, 2], [3]] : tensor<1x16x16x4xf32> into tensor<256x4xf32>
+        %19 = tensor.collapse_shape %18 [[0, 1, 2], [3, 4, 5]] : tensor<1x16x16x3x3x8xf32> into tensor<256x72xf32>
+        %20 = tensor.collapse_shape %14 [[0, 1, 2], [3]] : tensor<3x3x8x4xf32> into tensor<72x4xf32>
+        %21 = tensor.collapse_shape %16 [[0, 1, 2], [3]] : tensor<1x16x16x4xf32> into tensor<256x4xf32>
         %22 = linalg.matmul ins(%19, %20 : tensor<256x72xf32>, tensor<72x4xf32>) outs(%21 : tensor<256x4xf32>) -> tensor<256x4xf32>
-        %23 = linalg.tensor_expand_shape %22 [[0, 1, 2], [3]] : tensor<256x4xf32> into tensor<1x16x16x4xf32>
+        %23 = tensor.expand_shape %22 [[0, 1, 2], [3]] : tensor<256x4xf32> into tensor<1x16x16x4xf32>
         %24 = tensor.cast %23 : tensor<1x16x16x4xf32> to tensor<1x?x?x?xf32>
         flow.dispatch.tensor.store %24, %2, offsets = [0, %arg0, %arg1, %arg2], sizes = [1, %c16, %c16, %c4], strides = [1, 1, 1, 1] : tensor<1x?x?x?xf32> -> !flow.dispatch.tensor<writeonly:1x112x112x32xf32>
       }
diff --git a/iree/compiler/Codegen/LLVMGPU/ConvertToROCDL.cpp b/iree/compiler/Codegen/LLVMGPU/ConvertToROCDL.cpp
index 8e5795f..9d1b8c1 100644
--- a/iree/compiler/Codegen/LLVMGPU/ConvertToROCDL.cpp
+++ b/iree/compiler/Codegen/LLVMGPU/ConvertToROCDL.cpp
@@ -85,7 +85,8 @@
       arith::populateArithmeticToLLVMConversionPatterns(converter,
                                                         llvmPatterns);
       populateVectorToLLVMConversionPatterns(converter, llvmPatterns);
-      populateGpuToROCDLConversionPatterns(converter, llvmPatterns);
+      populateGpuToROCDLConversionPatterns(converter, llvmPatterns,
+                                           gpu::amd::Runtime::Unknown);
       LLVMConversionTarget target(getContext());
       populateStdToLLVMFuncOpConversionPattern(converter, llvmPatterns);
       configureGpuToROCDLConversionLegality(target);
diff --git a/iree/compiler/Dialect/Flow/Transforms/ConvertConv2D1x1ToMatmulPass.cpp b/iree/compiler/Dialect/Flow/Transforms/ConvertConv2D1x1ToMatmulPass.cpp
index ccf1568..74faa10 100644
--- a/iree/compiler/Dialect/Flow/Transforms/ConvertConv2D1x1ToMatmulPass.cpp
+++ b/iree/compiler/Dialect/Flow/Transforms/ConvertConv2D1x1ToMatmulPass.cpp
@@ -70,18 +70,18 @@
     Value output = convOp.getOutputOperand(0)->get();
     auto loc = convOp.getLoc();
 
-    Value reshapedInput = rewriter.create<linalg::TensorCollapseShapeOp>(
+    Value reshapedInput = rewriter.create<tensor::CollapseShapeOp>(
         loc, reshapedInputType, input, reassociationIndices);
-    Value reshapedFilter = rewriter.create<linalg::TensorCollapseShapeOp>(
+    Value reshapedFilter = rewriter.create<tensor::CollapseShapeOp>(
         loc, reshapedFilterType, filter, reassociationIndices);
-    Value reshapedOutput = rewriter.create<linalg::TensorCollapseShapeOp>(
+    Value reshapedOutput = rewriter.create<tensor::CollapseShapeOp>(
         loc, reshapedOutputType, output, reassociationIndices);
 
     auto matmulResult = rewriter.create<linalg::MatmulOp>(
         loc, reshapedOutputType, ArrayRef<Value>{reshapedInput, reshapedFilter},
         ArrayRef<Value>{reshapedOutput});
 
-    auto reshapedResult = rewriter.create<linalg::TensorExpandShapeOp>(
+    auto reshapedResult = rewriter.create<tensor::ExpandShapeOp>(
         loc, outputShapeType, matmulResult.getResults()[0],
         reassociationIndices);
 
diff --git a/iree/compiler/Dialect/Flow/Transforms/ConvertConv2DToImg2ColPass.cpp b/iree/compiler/Dialect/Flow/Transforms/ConvertConv2DToImg2ColPass.cpp
index 353bca7..9e7cff4 100644
--- a/iree/compiler/Dialect/Flow/Transforms/ConvertConv2DToImg2ColPass.cpp
+++ b/iree/compiler/Dialect/Flow/Transforms/ConvertConv2DToImg2ColPass.cpp
@@ -139,15 +139,14 @@
         RankedTensorType::get({outputShape[1] * outputShape[2], outputShape[3]},
                               outputShapeType.getElementType());
 
-    Value reshapedImg2ColTensor =
-        rewriter.create<linalg::TensorCollapseShapeOp>(
-            loc, reshapedImg2ColTensorType, img2ColTensor.getResult(0),
-            img2ColTensorReassociationIndices);
+    Value reshapedImg2ColTensor = rewriter.create<tensor::CollapseShapeOp>(
+        loc, reshapedImg2ColTensorType, img2ColTensor.getResult(0),
+        img2ColTensorReassociationIndices);
 
-    Value reshapedFilter = rewriter.create<linalg::TensorCollapseShapeOp>(
+    Value reshapedFilter = rewriter.create<tensor::CollapseShapeOp>(
         loc, reshapedFilterType, filter, filterAndOutputReassociationIndices);
 
-    Value reshapedOutput = rewriter.create<linalg::TensorCollapseShapeOp>(
+    Value reshapedOutput = rewriter.create<tensor::CollapseShapeOp>(
         loc, reshapedOutputType, output, filterAndOutputReassociationIndices);
 
     auto matmulResult = rewriter.create<linalg::MatmulOp>(
@@ -155,7 +154,7 @@
         ArrayRef<Value>{reshapedImg2ColTensor, reshapedFilter},
         ArrayRef<Value>{reshapedOutput});
 
-    auto reshapedResult = rewriter.create<linalg::TensorExpandShapeOp>(
+    auto reshapedResult = rewriter.create<tensor::ExpandShapeOp>(
         loc, outputShapeType, matmulResult.getResults()[0],
         filterAndOutputReassociationIndices);
 
@@ -308,14 +307,13 @@
          transposedOutputTensorShape[2] * transposedOutputTensorShape[3]},
         outputTensorType.getElementType());
 
-    Value reshapedImg2ColTensor =
-        rewriter.create<linalg::TensorCollapseShapeOp>(
-            loc, reshapedImg2ColTensorType, img2ColTensor.getResult(0),
-            img2ColTensorReassociationIndices);
-    Value reshapedFilterTensor = rewriter.create<linalg::TensorCollapseShapeOp>(
+    Value reshapedImg2ColTensor = rewriter.create<tensor::CollapseShapeOp>(
+        loc, reshapedImg2ColTensorType, img2ColTensor.getResult(0),
+        img2ColTensorReassociationIndices);
+    Value reshapedFilterTensor = rewriter.create<tensor::CollapseShapeOp>(
         loc, reshapedFilterTensorType, transposedFilter,
         filterReassociationIndice);
-    Value reshapedoutputTensor = rewriter.create<linalg::TensorCollapseShapeOp>(
+    Value reshapedoutputTensor = rewriter.create<tensor::CollapseShapeOp>(
         loc, reshapedOutputTensorType, transposedOutputTensor,
         outputReassociationIndice);
 
@@ -327,10 +325,9 @@
     SmallVector<ReassociationIndices> batchMatVecReassociationIndice = {{0, 1},
                                                                         {2, 3}};
 
-    Value batchMatVecResultReshaped =
-        rewriter.create<linalg::TensorExpandShapeOp>(
-            loc, transposedOutputTensor.getType(),
-            batchMatVecResult.getResult(0), batchMatVecReassociationIndice);
+    Value batchMatVecResultReshaped = rewriter.create<tensor::ExpandShapeOp>(
+        loc, transposedOutputTensor.getType(), batchMatVecResult.getResult(0),
+        batchMatVecReassociationIndice);
 
     auto transposedResult =
         transposeOperand(batchMatVecResultReshaped, {0, 2, 3, 1});
diff --git a/iree/compiler/Dialect/Flow/Transforms/ConvertLinalgMatmulToMmt4D.cpp b/iree/compiler/Dialect/Flow/Transforms/ConvertLinalgMatmulToMmt4D.cpp
index bf1e315..729937b 100644
--- a/iree/compiler/Dialect/Flow/Transforms/ConvertLinalgMatmulToMmt4D.cpp
+++ b/iree/compiler/Dialect/Flow/Transforms/ConvertLinalgMatmulToMmt4D.cpp
@@ -44,7 +44,7 @@
       RankedTensorType::get(targetShape, inputType.getElementType());
   std::array<ReassociationIndices, 2> expandIndices = {
       ReassociationIndices{0, 1}, ReassociationIndices{2, 3}};
-  Value reshapedOperand = rewriter.create<linalg::TensorExpandShapeOp>(
+  Value reshapedOperand = rewriter.create<tensor::ExpandShapeOp>(
       loc, targetType, input, expandIndices);
   return reshapedOperand;
 }
@@ -103,7 +103,7 @@
       RankedTensorType::get(targetShape, inputType.getElementType());
   std::array<ReassociationIndices, 2> collapseIndices = {
       ReassociationIndices{0, 1}, ReassociationIndices{2, 3}};
-  Value reshapedOperand = rewriter.create<linalg::TensorCollapseShapeOp>(
+  Value reshapedOperand = rewriter.create<tensor::CollapseShapeOp>(
       loc, targetType, input, collapseIndices);
   return reshapedOperand;
 }
@@ -368,8 +368,9 @@
     // Canonicalization.
     {
       OwningRewritePatternList patterns(&getContext());
-      linalg::TensorExpandShapeOp::getCanonicalizationPatterns(patterns,
-                                                               context);
+      tensor::ExpandShapeOp::getCanonicalizationPatterns(patterns, context);
+      linalg::InitTensorOp::getCanonicalizationPatterns(patterns, context);
+      linalg::FillOp::getCanonicalizationPatterns(patterns, context);
       patterns.insert<FoldFillGenericOpPattern>(context);
       if (failed(applyPatternsAndFoldGreedily(getOperation(),
                                               std::move(patterns)))) {
diff --git a/iree/compiler/Dialect/Flow/Transforms/ConvertLinalgTensorOps.cpp b/iree/compiler/Dialect/Flow/Transforms/ConvertLinalgTensorOps.cpp
index 3e09e34..2bd32bf 100644
--- a/iree/compiler/Dialect/Flow/Transforms/ConvertLinalgTensorOps.cpp
+++ b/iree/compiler/Dialect/Flow/Transforms/ConvertLinalgTensorOps.cpp
@@ -13,6 +13,7 @@
 #include "mlir/Dialect/StandardOps/IR/Ops.h"
 #include "mlir/Dialect/Tensor/IR/Tensor.h"
 #include "mlir/IR/PatternMatch.h"
+#include "mlir/Interfaces/InferTypeOpInterface.h"
 #include "mlir/Pass/Pass.h"
 #include "mlir/Pass/PassManager.h"
 #include "mlir/Support/LLVM.h"
@@ -51,7 +52,10 @@
       return failure();
     }
     SmallVector<SmallVector<Value>> outputShape;
-    if (failed(reshapeOp.reifyResultShapes(rewriter, outputShape))) {
+    ReifyRankedShapedTypeOpInterface reifyShapedTypeInterface =
+        cast<ReifyRankedShapedTypeOpInterface>(reshapeOp.getOperation());
+    if (failed(reifyShapedTypeInterface.reifyResultShapes(rewriter,
+                                                          outputShape))) {
       return failure();
     }
     SmallVector<Value> outputDynamicShapes;
diff --git a/iree/compiler/Dialect/Flow/Transforms/ConvertToFlow.cpp b/iree/compiler/Dialect/Flow/Transforms/ConvertToFlow.cpp
index 993871f..156b9ed 100644
--- a/iree/compiler/Dialect/Flow/Transforms/ConvertToFlow.cpp
+++ b/iree/compiler/Dialect/Flow/Transforms/ConvertToFlow.cpp
@@ -52,7 +52,10 @@
       return failure();
     }
     SmallVector<SmallVector<Value>> outputShape;
-    if (failed(reshapeOp.reifyResultShapes(rewriter, outputShape))) {
+    ReifyRankedShapedTypeOpInterface reifyShapedTypeInterface =
+        cast<ReifyRankedShapedTypeOpInterface>(reshapeOp.getOperation());
+    if (failed(reifyShapedTypeInterface.reifyResultShapes(rewriter,
+                                                          outputShape))) {
       return failure();
     }
     SmallVector<Value> outputDynamicShapes;
@@ -105,10 +108,10 @@
     context->allowUnregisteredDialects(true);
     RewritePatternSet patterns(&getContext());
 
-    patterns.insert<
-        LinalgTensorReshapeToFlowTensorReshape<linalg::TensorCollapseShapeOp>,
-        LinalgTensorReshapeToFlowTensorReshape<linalg::TensorExpandShapeOp>>(
-        context);
+    patterns
+        .insert<LinalgTensorReshapeToFlowTensorReshape<tensor::CollapseShapeOp>,
+                LinalgTensorReshapeToFlowTensorReshape<tensor::ExpandShapeOp>>(
+            context);
     populateTensorToFlowPatternsBeforeDispatchFormation(context, patterns);
     IREE::Flow::TensorReshapeOp::getCanonicalizationPatterns(patterns, context);
 
diff --git a/iree/compiler/Dialect/Flow/Transforms/DispatchLinalgOnTensors.cpp b/iree/compiler/Dialect/Flow/Transforms/DispatchLinalgOnTensors.cpp
index 6cd504c..a2f5564 100644
--- a/iree/compiler/Dialect/Flow/Transforms/DispatchLinalgOnTensors.cpp
+++ b/iree/compiler/Dialect/Flow/Transforms/DispatchLinalgOnTensors.cpp
@@ -162,9 +162,9 @@
 /// Operations that are cloned into dispatch regions formed with other
 /// operations as roots.
 static bool isClonableIntoDispatchOp(Operation *op) {
-  if (isa<arith::IndexCastOp, linalg::InitTensorOp,
-          linalg::TensorCollapseShapeOp, linalg::TensorExpandShapeOp,
-          tensor::ExtractOp, tensor::ExtractSliceOp>(op)) {
+  if (isa<arith::IndexCastOp, linalg::InitTensorOp, tensor::CollapseShapeOp,
+          tensor::ExpandShapeOp, tensor::ExtractOp, tensor::ExtractSliceOp>(
+          op)) {
     return true;
   }
   if (auto constantOp = dyn_cast<arith::ConstantOp>(op)) {
diff --git a/iree/compiler/Dialect/Flow/Transforms/FusionOfTensorOps.cpp b/iree/compiler/Dialect/Flow/Transforms/FusionOfTensorOps.cpp
index f6f902b..9ac58a5 100644
--- a/iree/compiler/Dialect/Flow/Transforms/FusionOfTensorOps.cpp
+++ b/iree/compiler/Dialect/Flow/Transforms/FusionOfTensorOps.cpp
@@ -112,11 +112,10 @@
     // to the consumer linalg op.
     linalg::ControlElementwiseOpsFusionFn foldReshapeBetweenLinalgFn =
         [](const OpResult &producer, const OpOperand &consumer) {
-          auto collapseOp =
-              producer.getDefiningOp<linalg::TensorCollapseShapeOp>();
+          auto collapseOp = producer.getDefiningOp<tensor::CollapseShapeOp>();
           if (collapseOp)
             return collapseOp.src().getDefiningOp<LinalgOp>() != nullptr;
-          auto expandOp = producer.getDefiningOp<linalg::TensorExpandShapeOp>();
+          auto expandOp = producer.getDefiningOp<tensor::ExpandShapeOp>();
           if (expandOp)
             return expandOp.src().getDefiningOp<LinalgOp>() != nullptr;
           return false;
@@ -135,10 +134,14 @@
     OwningRewritePatternList reshapeCanonicalizations(&getContext());
     linalg::populateFoldUnitDimsReshapeOpsByLinearizationPatterns(
         reshapeCanonicalizations);
-    linalg::TensorCollapseShapeOp::getCanonicalizationPatterns(
+    tensor::CollapseShapeOp::getCanonicalizationPatterns(
         reshapeCanonicalizations, context);
-    linalg::TensorExpandShapeOp::getCanonicalizationPatterns(
-        reshapeCanonicalizations, context);
+    tensor::ExpandShapeOp::getCanonicalizationPatterns(reshapeCanonicalizations,
+                                                       context);
+    linalg::InitTensorOp::getCanonicalizationPatterns(reshapeCanonicalizations,
+                                                      context);
+    linalg::FillOp::getCanonicalizationPatterns(reshapeCanonicalizations,
+                                                context);
     if (failed(applyPatternsAndFoldGreedily(
             op->getRegions(), std::move(reshapeCanonicalizations)))) {
       return signalPassFailure();
@@ -147,10 +150,13 @@
     // Push the remaining reshapes down the graphs.
     OwningRewritePatternList pushReshapePatterns(&getContext());
     linalg::populatePushReshapeOpsPatterns(pushReshapePatterns);
-    linalg::TensorCollapseShapeOp::getCanonicalizationPatterns(
-        pushReshapePatterns, context);
-    linalg::TensorExpandShapeOp::getCanonicalizationPatterns(
-        pushReshapePatterns, context);
+    tensor::CollapseShapeOp::getCanonicalizationPatterns(pushReshapePatterns,
+                                                         context);
+    tensor::ExpandShapeOp::getCanonicalizationPatterns(pushReshapePatterns,
+                                                       context);
+    linalg::InitTensorOp::getCanonicalizationPatterns(pushReshapePatterns,
+                                                      context);
+    linalg::FillOp::getCanonicalizationPatterns(pushReshapePatterns, context);
     if (failed(applyPatternsAndFoldGreedily(op->getRegions(),
                                             std::move(pushReshapePatterns)))) {
       return signalPassFailure();
diff --git a/iree/compiler/Dialect/Flow/Transforms/PadTensorToSubTensorInsert.cpp b/iree/compiler/Dialect/Flow/Transforms/PadTensorToSubTensorInsert.cpp
index 9667be2..7d28c51 100644
--- a/iree/compiler/Dialect/Flow/Transforms/PadTensorToSubTensorInsert.cpp
+++ b/iree/compiler/Dialect/Flow/Transforms/PadTensorToSubTensorInsert.cpp
@@ -80,7 +80,7 @@
       };
       expr = addValueOrAttr(expr, lowPad[dim]);
       expr = addValueOrAttr(expr, highPad[dim]);
-      Value v = linalg::applyMapToValues(
+      Value v = applyMapToValues(
           rewriter, loc, AffineMap::get(1, numSymbols, expr), mapValues)[0];
       if (auto cst = v.getDefiningOp<arith::ConstantOp>()) {
         outputShape.push_back(cst.getValue());
diff --git a/iree/compiler/Dialect/Flow/Transforms/test/conv1x1_to_matmul.mlir b/iree/compiler/Dialect/Flow/Transforms/test/conv1x1_to_matmul.mlir
index 72312f2..fc26e2b 100644
--- a/iree/compiler/Dialect/Flow/Transforms/test/conv1x1_to_matmul.mlir
+++ b/iree/compiler/Dialect/Flow/Transforms/test/conv1x1_to_matmul.mlir
@@ -12,9 +12,9 @@
 // CHECK: %[[INPUT:.+]]: tensor<1x4x5x2xf32>
 // CHECK: %[[FILTER:.+]]: tensor<1x1x2x7xf32>
 // CHECK: %[[OTUPUT:.+]] = linalg.init_tensor [1, 4, 5, 7] : tensor<1x4x5x7xf32>
-// CHECK: %[[RESHAPED_INPUT:.+]] = linalg.tensor_collapse_shape %[[INPUT]] {{\[}}[0, 1, 2], [3]] : tensor<1x4x5x2xf32> into tensor<20x2xf32>
-// CHECK: %[[RESHAPED_FILTER:.+]] = linalg.tensor_collapse_shape %[[FILTER]] {{\[}}[0, 1, 2], [3]] : tensor<1x1x2x7xf32> into tensor<2x7xf32>
-// CHECK: %[[RESHAPED_OUTPUT:.+]] = linalg.tensor_collapse_shape %[[OTUPUT]] {{\[}}[0, 1, 2], [3]] : tensor<1x4x5x7xf32> into tensor<20x7xf32>
+// CHECK: %[[RESHAPED_INPUT:.+]] = tensor.collapse_shape %[[INPUT]] {{\[}}[0, 1, 2], [3]] : tensor<1x4x5x2xf32> into tensor<20x2xf32>
+// CHECK: %[[RESHAPED_FILTER:.+]] = tensor.collapse_shape %[[FILTER]] {{\[}}[0, 1, 2], [3]] : tensor<1x1x2x7xf32> into tensor<2x7xf32>
+// CHECK: %[[RESHAPED_OUTPUT:.+]] = tensor.collapse_shape %[[OTUPUT]] {{\[}}[0, 1, 2], [3]] : tensor<1x4x5x7xf32> into tensor<20x7xf32>
 // CHECK: %[[MATMUL_RESULT:.+]] = linalg.matmul ins(%[[RESHAPED_INPUT]], %[[RESHAPED_FILTER]] : tensor<20x2xf32>, tensor<2x7xf32>) outs(%[[RESHAPED_OUTPUT]] : tensor<20x7xf32>)
-// CHECK: %[[RESULT:.+]] = linalg.tensor_expand_shape %[[MATMUL_RESULT]] {{\[}}[0, 1, 2], [3]] : tensor<20x7xf32> into tensor<1x4x5x7xf32>
+// CHECK: %[[RESULT:.+]] = tensor.expand_shape %[[MATMUL_RESULT]] {{\[}}[0, 1, 2], [3]] : tensor<20x7xf32> into tensor<1x4x5x7xf32>
 // CHECK: return %[[RESULT]]
diff --git a/iree/compiler/Dialect/Flow/Transforms/test/conv2d_to_img2col.mlir b/iree/compiler/Dialect/Flow/Transforms/test/conv2d_to_img2col.mlir
index 268e87a..3c86ba8 100644
--- a/iree/compiler/Dialect/Flow/Transforms/test/conv2d_to_img2col.mlir
+++ b/iree/compiler/Dialect/Flow/Transforms/test/conv2d_to_img2col.mlir
@@ -19,16 +19,16 @@
 //           CHECK-SAME: #[[MAP1]]
 //                CHECK: ^bb0(%[[IN_DATA:.+]]: f32, %[[OUT_DATA:.+]]: f32)
 //                CHECK: linalg.yield %[[IN_DATA]] : f32
-//      CHECK-DAG: %[[RESHAPED_INIT_COL_TENSOR:.+]] = linalg.tensor_collapse_shape %[[COL_TENSOR]]
+//      CHECK-DAG: %[[RESHAPED_INIT_COL_TENSOR:.+]] = tensor.collapse_shape %[[COL_TENSOR]]
 //           CHECK-SAME: [0, 1, 2], [3, 4, 5]
 //           CHECK-SAME: tensor<1x14x14x3x3x4xf32> into tensor<196x36xf32>
-//      CHECK-DAG: %[[RESHAPED_FILTER:.+]] = linalg.tensor_collapse_shape %[[FILTER]]
+//      CHECK-DAG: %[[RESHAPED_FILTER:.+]] = tensor.collapse_shape %[[FILTER]]
 //           CHECK-SAME: [0, 1, 2], [3]
 //           CHECK-SAME: tensor<3x3x4x16xf32> into tensor<36x16xf32>
-//      CHECK-DAG: %[[RESHAPED_OUTPUT:.+]] = linalg.tensor_collapse_shape %[[OUTPUT]]
+//      CHECK-DAG: %[[RESHAPED_OUTPUT:.+]] = tensor.collapse_shape %[[OUTPUT]]
 //           CHECK-SAME: [0, 1, 2], [3]
 //      CHECK: %[[MATMUL_RESULT:.+]] = linalg.matmul ins(%[[RESHAPED_INIT_COL_TENSOR]], %[[RESHAPED_FILTER]] : tensor<196x36xf32>, tensor<36x16xf32>) outs(%[[RESHAPED_OUTPUT]] : tensor<196x16xf32>)
-//      CHECK: %[[RESULT:.+]] = linalg.tensor_expand_shape %[[MATMUL_RESULT]] {{\[}}[0, 1, 2], [3]] : tensor<196x16xf32> into tensor<1x14x14x16xf32>
+//      CHECK: %[[RESULT:.+]] = tensor.expand_shape %[[MATMUL_RESULT]] {{\[}}[0, 1, 2], [3]] : tensor<196x16xf32> into tensor<1x14x14x16xf32>
 //      CHECK: return %[[RESULT]]
 
 // -----
@@ -84,14 +84,14 @@
 // CHECK-NEXT:      ^bb0(%{{.*}}: f32, %{{.*}}: f32):
 // CHECK-NEXT:         linalg.yield
 // CHECK-NEXT:    } -> tensor<1x16x112x112x3x3xf32>
-//      CHECK: %[[COL_TENSOR_R:.+]] = linalg.tensor_collapse_shape %[[COL_TENSOR]]
+//      CHECK: %[[COL_TENSOR_R:.+]] = tensor.collapse_shape %[[COL_TENSOR]]
 // CHECK-SAME:    tensor<1x16x112x112x3x3xf32> into tensor<16x12544x9xf32>
-//      CHECK: %[[FILTER_T_R:.+]] = linalg.tensor_collapse_shape %[[FILTER_T]]
+//      CHECK: %[[FILTER_T_R:.+]] = tensor.collapse_shape %[[FILTER_T]]
 // CHECK-SAME:    tensor<16x3x3xf32> into tensor<16x9xf32>
-//      CHECK: %[[OUTPUT_T_R:.+]] = linalg.tensor_collapse_shape %[[OUTPUT_T]]
+//      CHECK: %[[OUTPUT_T_R:.+]] = tensor.collapse_shape %[[OUTPUT_T]]
 // CHECK-SAME:    tensor<1x16x112x112xf32> into tensor<16x12544xf32>
 //      CHECK: %[[BMV_RESULT:.+]] = linalg.batch_matvec ins(%[[COL_TENSOR_R]], %[[FILTER_T_R]] : tensor<16x12544x9xf32>, tensor<16x9xf32>) outs(%[[OUTPUT_T_R]] : tensor<16x12544xf32>) -> tensor<16x12544xf32>
-//      CHECK: %[[RESULT_R:.+]] = linalg.tensor_expand_shape %[[BMV_RESULT]]
+//      CHECK: %[[RESULT_R:.+]] = tensor.expand_shape %[[BMV_RESULT]]
 // CHECK-SAME:    tensor<16x12544xf32> into tensor<1x16x112x112xf32>
 //      CHECK: %[[RESULT_INIT:.+]] = linalg.init_tensor [1, 112, 112, 16] : tensor<1x112x112x16xf32>
 //      CHECK: %[[RESULT:.+]] = linalg.generic
diff --git a/iree/compiler/Dialect/Flow/Transforms/test/convert_linalg_tensor_ops_before.mlir b/iree/compiler/Dialect/Flow/Transforms/test/convert_linalg_tensor_ops_before.mlir
index 1fe2725..957bbb8 100644
--- a/iree/compiler/Dialect/Flow/Transforms/test/convert_linalg_tensor_ops_before.mlir
+++ b/iree/compiler/Dialect/Flow/Transforms/test/convert_linalg_tensor_ops_before.mlir
@@ -3,9 +3,9 @@
 func @tensor_reshape(%arg0 : tensor<?x4x?x5x?x6xf32>, %arg1 : tensor<20x?x40xf32>)
     -> (tensor<?x5x?xf32>, tensor<5x4x?x4x2x4x5xf32>)
 {
-  %0 = linalg.tensor_collapse_shape %arg0 [[0, 1, 2], [3], [4, 5]]
+  %0 = tensor.collapse_shape %arg0 [[0, 1, 2], [3], [4, 5]]
       : tensor<?x4x?x5x?x6xf32> into tensor<?x5x?xf32>
-  %1 = linalg.tensor_expand_shape %arg1 [[0, 1], [2, 3], [4, 5, 6]]
+  %1 = tensor.expand_shape %arg1 [[0, 1], [2, 3], [4, 5, 6]]
       : tensor<20x?x40xf32> into tensor<5x4x?x4x2x4x5xf32>
   return %0, %1 : tensor<?x5x?xf32>, tensor<5x4x?x4x2x4x5xf32>
 }
diff --git a/iree/compiler/Dialect/Flow/Transforms/test/dispatch_linalg_on_tensors.mlir b/iree/compiler/Dialect/Flow/Transforms/test/dispatch_linalg_on_tensors.mlir
index fa7e415..5ddf2a7 100644
--- a/iree/compiler/Dialect/Flow/Transforms/test/dispatch_linalg_on_tensors.mlir
+++ b/iree/compiler/Dialect/Flow/Transforms/test/dispatch_linalg_on_tensors.mlir
@@ -265,7 +265,7 @@
   %cst = arith.constant 0.0 : f32
   %c0 = arith.constant 0 : index
   %c1 = arith.constant 1 : index
-  %0 = linalg.tensor_expand_shape %lhs [[0, 1]]
+  %0 = tensor.expand_shape %lhs [[0, 1]]
     : tensor<?xf32> into tensor<?x4xf32>
   %m = tensor.dim %0, %c0 : tensor<?x4xf32>
   %n1 = tensor.dim %rhs1, %c1 : tensor<4x?xf32>
@@ -543,14 +543,14 @@
 func @inline_dag_1(
     %arg0: tensor<?xf32>, %arg1 : tensor<1x?xf32>, %arg2 : tensor<i32>,
     %arg3 : index) -> tensor<?xf32> {
-  %0 = linalg.tensor_expand_shape %arg0 [[0, 1]] : tensor<?xf32> into tensor<1x?xf32>
+  %0 = tensor.expand_shape %arg0 [[0, 1]] : tensor<?xf32> into tensor<1x?xf32>
   %1 = tensor.extract_slice %0[0, 20] [1, %arg3] [1, 1] : tensor<1x?xf32> to tensor<1x?xf32>
-  %2 = linalg.tensor_collapse_shape %1 [[0, 1]] : tensor<1x?xf32> into tensor<?xf32>
-  %3 = linalg.tensor_collapse_shape %arg1 [[0, 1]] : tensor<1x?xf32> into tensor<?xf32>
+  %2 = tensor.collapse_shape %1 [[0, 1]] : tensor<1x?xf32> into tensor<?xf32>
+  %3 = tensor.collapse_shape %arg1 [[0, 1]] : tensor<1x?xf32> into tensor<?xf32>
   %4 = tensor.extract_slice %0[0, 10] [1, %arg3] [1, 1] : tensor<1x?xf32> to tensor<1x?xf32>
-  %5 = linalg.tensor_collapse_shape %4 [[0, 1]] : tensor<1x?xf32> into tensor<?xf32>
+  %5 = tensor.collapse_shape %4 [[0, 1]] : tensor<1x?xf32> into tensor<?xf32>
   %6 = tensor.extract_slice %0[0, 0] [1, %arg3] [1, 1] : tensor<1x?xf32> to tensor<1x?xf32>
-  %7 = linalg.tensor_collapse_shape %6 [[0, 1]] : tensor<1x?xf32> into tensor<?xf32>
+  %7 = tensor.collapse_shape %6 [[0, 1]] : tensor<1x?xf32> into tensor<?xf32>
   %8 = linalg.init_tensor [%arg3] : tensor<?xf32>
   %9 = linalg.generic {
       indexing_maps = [#map0, #map1, #map1, #map1, #map1, #map1],
@@ -580,14 +580,14 @@
 //       CHECK:     %[[LEAF1:.+]] = flow.dispatch.tensor.load %[[ARG4]], {{.*}}
 //       CHECK:     %[[LEAF2:.+]] = flow.dispatch.tensor.load %[[ARG5]], {{.*}}
 //       CHECK:     %[[LEAF3:.+]] = flow.dispatch.tensor.load %[[ARG6]], {{.*}}
-//       CHECK:     %[[OP1:.+]] = linalg.tensor_expand_shape %[[LEAF2]]
-//       CHECK:     %[[OP2:.+]] = linalg.tensor_collapse_shape %[[LEAF1]]
+//       CHECK:     %[[OP1:.+]] = tensor.expand_shape %[[LEAF2]]
+//       CHECK:     %[[OP2:.+]] = tensor.collapse_shape %[[LEAF1]]
 //       CHECK:     %[[OP3:.+]] = tensor.extract_slice %[[OP1]][0, 0]
 //       CHECK:     %[[OP4:.+]] = tensor.extract_slice %[[OP1]][0, 10]
 //       CHECK:     %[[OP5:.+]] = tensor.extract_slice %[[OP1]][0, 20]
-//       CHECK:     %[[OP6:.+]] = linalg.tensor_collapse_shape %[[OP3]]
-//       CHECK:     %[[OP7:.+]] = linalg.tensor_collapse_shape %[[OP4]]
-//       CHECK:     %[[OP8:.+]] = linalg.tensor_collapse_shape %[[OP5]]
+//       CHECK:     %[[OP6:.+]] = tensor.collapse_shape %[[OP3]]
+//       CHECK:     %[[OP7:.+]] = tensor.collapse_shape %[[OP4]]
+//       CHECK:     %[[OP8:.+]] = tensor.collapse_shape %[[OP5]]
 
 // -----
 
@@ -596,16 +596,16 @@
 func @inline_dag_2(
     %arg0: tensor<?xf32>, %arg1 : tensor<1x?xf32>, %arg2 : tensor<i32>,
     %arg3 : index) -> tensor<?xf32> {
-  %0 = linalg.tensor_expand_shape %arg0 [[0, 1]] : tensor<?xf32> into tensor<1x?xf32>
+  %0 = tensor.expand_shape %arg0 [[0, 1]] : tensor<?xf32> into tensor<1x?xf32>
   %1 = tensor.extract_slice %0[0, 20] [1, %arg3] [1, 1] : tensor<1x?xf32> to tensor<1x?xf32>
-  %2 = linalg.tensor_collapse_shape %arg1 [[0, 1]] : tensor<1x?xf32> into tensor<?xf32>
+  %2 = tensor.collapse_shape %arg1 [[0, 1]] : tensor<1x?xf32> into tensor<?xf32>
   br ^bb1
 ^bb1:
-  %3 = linalg.tensor_collapse_shape %1 [[0, 1]] : tensor<1x?xf32> into tensor<?xf32>
+  %3 = tensor.collapse_shape %1 [[0, 1]] : tensor<1x?xf32> into tensor<?xf32>
   %4 = tensor.extract_slice %0[0, 10] [1, %arg3] [1, 1] : tensor<1x?xf32> to tensor<1x?xf32>
-  %5 = linalg.tensor_collapse_shape %4 [[0, 1]] : tensor<1x?xf32> into tensor<?xf32>
+  %5 = tensor.collapse_shape %4 [[0, 1]] : tensor<1x?xf32> into tensor<?xf32>
   %6 = tensor.extract_slice %0[0, 0] [1, %arg3] [1, 1] : tensor<1x?xf32> to tensor<1x?xf32>
-  %7 = linalg.tensor_collapse_shape %6 [[0, 1]] : tensor<1x?xf32> into tensor<?xf32>
+  %7 = tensor.collapse_shape %6 [[0, 1]] : tensor<1x?xf32> into tensor<?xf32>
   %8 = linalg.init_tensor [%arg3] : tensor<?xf32>
   %9 = linalg.generic {
       indexing_maps = [#map0, #map1, #map1, #map1, #map1, #map1],
@@ -635,14 +635,14 @@
 //       CHECK:     %[[LEAF1:.+]] = flow.dispatch.tensor.load %[[ARG4]], {{.*}}
 //       CHECK:     %[[LEAF2:.+]] = flow.dispatch.tensor.load %[[ARG5]], {{.*}}
 //       CHECK:     %[[LEAF3:.+]] = flow.dispatch.tensor.load %[[ARG6]], {{.*}}
-//       CHECK:     %[[OP1:.+]] = linalg.tensor_expand_shape %[[LEAF2]]
-//       CHECK:     %[[OP2:.+]] = linalg.tensor_collapse_shape %[[LEAF1]]
+//       CHECK:     %[[OP1:.+]] = tensor.expand_shape %[[LEAF2]]
+//       CHECK:     %[[OP2:.+]] = tensor.collapse_shape %[[LEAF1]]
 //       CHECK:     %[[OP3:.+]] = tensor.extract_slice %[[OP1]][0, 0]
 //       CHECK:     %[[OP4:.+]] = tensor.extract_slice %[[OP1]][0, 10]
 //       CHECK:     %[[OP5:.+]] = tensor.extract_slice %[[OP1]][0, 20]
-//       CHECK:     %[[OP6:.+]] = linalg.tensor_collapse_shape %[[OP3]]
-//       CHECK:     %[[OP7:.+]] = linalg.tensor_collapse_shape %[[OP4]]
-//       CHECK:     %[[OP8:.+]] = linalg.tensor_collapse_shape %[[OP5]]
+//       CHECK:     %[[OP6:.+]] = tensor.collapse_shape %[[OP3]]
+//       CHECK:     %[[OP7:.+]] = tensor.collapse_shape %[[OP4]]
+//       CHECK:     %[[OP8:.+]] = tensor.collapse_shape %[[OP5]]
 
 // -----
 
diff --git a/iree/compiler/Dialect/Flow/Transforms/test/matmul_to_mmt4d.mlir b/iree/compiler/Dialect/Flow/Transforms/test/matmul_to_mmt4d.mlir
index 1001bb0..19caa34 100644
--- a/iree/compiler/Dialect/Flow/Transforms/test/matmul_to_mmt4d.mlir
+++ b/iree/compiler/Dialect/Flow/Transforms/test/matmul_to_mmt4d.mlir
@@ -8,11 +8,11 @@
 // CHECK-DAG:#[[MAP1:.+]] = affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)>
 // CHECK-DAG:#[[MAP2:.+]] = affine_map<(d0, d1, d2, d3) -> (d1, d3, d0, d2)>
 //      CHECK: @check_mmt4d_f32_static_nopad(%[[LHS:.+]]: tensor<24x8xf32>, %[[RHS:.+]]: tensor<8x32xf32>, %[[DST:.+]]: tensor<24x32xf32>)
-//      CHECK: %[[LHS4D:.+]] = linalg.tensor_expand_shape %[[LHS]]
+//      CHECK: %[[LHS4D:.+]] = tensor.expand_shape %[[LHS]]
 // CHECK-SAME:   tensor<24x8xf32> into tensor<3x8x4x2xf32>
-//      CHECK: %[[RHS4D:.+]] = linalg.tensor_expand_shape %[[RHS]]
+//      CHECK: %[[RHS4D:.+]] = tensor.expand_shape %[[RHS]]
 // CHECK-SAME:   tensor<8x32xf32> into tensor<4x2x8x4xf32>
-//      CHECK: %[[DST4D:.+]] = linalg.tensor_expand_shape %[[DST]]
+//      CHECK: %[[DST4D:.+]] = tensor.expand_shape %[[DST]]
 // CHECK-SAME:   tensor<24x32xf32> into tensor<3x8x8x4xf32>
 //      CHECK: %[[LHS4DT_INIT:.+]] = linalg.init_tensor [3, 4, 8, 2] : tensor<3x4x8x2xf32>
 //      CHECK: %[[LHS4DT:.+]] = linalg.generic
@@ -47,7 +47,7 @@
 // CHECK-NEXT:    ^bb0(%{{.*}}: f32, %{{.*}}: f32):
 // CHECK-NEXT:           linalg.yield %arg3 : f32
 // CHECK-NEXT:    } -> tensor<3x8x8x4xf32>
-//      CHECK: %[[RESULT:.+]] = linalg.tensor_collapse_shape %[[MMT4DT]]
+//      CHECK: %[[RESULT:.+]] = tensor.collapse_shape %[[MMT4DT]]
 // CHECK-SAME:    tensor<3x8x8x4xf32> into tensor<24x32xf32>
 //      CHECK: return %[[RESULT]] : tensor<24x32xf32>
 
@@ -64,9 +64,9 @@
 // CHECK-DAG:#[[MAP2:.+]] = affine_map<(d0, d1, d2, d3) -> (d1, d3, d0, d2)>
 //      CHECK: @check_mmt4d_with_init_tensor_and_fill(%[[LHS:.+]]: tensor<24x8xf32>, %[[RHS:.+]]: tensor<8x32xf32>)
 //      CHECK: %[[ZERO:.+]] = arith.constant 0.000000e+00 : f32
-//      CHECK: %[[LHS4D:.+]] = linalg.tensor_expand_shape %[[LHS]]
+//      CHECK: %[[LHS4D:.+]] = tensor.expand_shape %[[LHS]]
 // CHECK-SAME:   tensor<24x8xf32> into tensor<3x8x4x2xf32>
-//      CHECK: %[[RHS4D:.+]] = linalg.tensor_expand_shape %[[RHS]]
+//      CHECK: %[[RHS4D:.+]] = tensor.expand_shape %[[RHS]]
 // CHECK-SAME:   tensor<8x32xf32> into tensor<4x2x8x4xf32>
 //      CHECK: %[[DST_INIT:.+]] = linalg.init_tensor [3, 8, 8, 4] : tensor<3x8x8x4xf32>
 //      CHECK: [[DST:.+]] linalg.fill(%[[ZERO:.+]], %[[DST_INIT]])
@@ -84,15 +84,15 @@
 //      CHECK: tensor<5x2xi8> to tensor<6x4xi8>
 //      CHECK: %[[ACCPAD:.+]] = linalg.pad_tensor %[[ACC]] low[0, 0] high[5, 2]
 //      CHECK: tensor<3x2xi32> to tensor<8x4xi32>
-//      CHECK: %[[LHS4D:.+]] = linalg.tensor_expand_shape %[[LHSPAD]]
+//      CHECK: %[[LHS4D:.+]] = tensor.expand_shape %[[LHSPAD]]
 // CHECK-SAME: tensor<8x6xi8> into tensor<1x8x3x2xi8>
-//      CHECK: %[[RHS4D:.+]] = linalg.tensor_expand_shape %[[RHSPAD]]
+//      CHECK: %[[RHS4D:.+]] = tensor.expand_shape %[[RHSPAD]]
 // CHECK-SAME: tensor<6x4xi8> into tensor<3x2x1x4xi8>
-//      CHECK: %[[ACC4D:.+]] = linalg.tensor_expand_shape %[[ACCPAD]]
+//      CHECK: %[[ACC4D:.+]] = tensor.expand_shape %[[ACCPAD]]
 // CHECK-SAME: tensor<8x4xi32> into tensor<1x8x1x4xi32>
 //  ... After the above padding, the IR reduces to the same pattern already
 //  ... checked in the testcases above, so we skip re-checking it here.
-//      CHECK: %[[RESPAD:.+]] = linalg.tensor_collapse_shape
+//      CHECK: %[[RESPAD:.+]] = tensor.collapse_shape
 // CHECK-SAME: tensor<1x8x1x4xi32> into tensor<8x4xi32>
 //      CHECK: %[[RES:.+]] = tensor.extract_slice %[[RESPAD]][0, 0] [3, 2] [1, 1]
 // CHECK-SAME: tensor<8x4xi32> to tensor<3x2xi32>
@@ -113,15 +113,15 @@
 //      CHECK: tensor<?x?xi8> to tensor<?x?xi8>
 //      CHECK: %[[ACCPAD:.+]] = linalg.pad_tensor %[[ACC]] low[0, 0] high[
 //      CHECK: tensor<?x?xi32> to tensor<?x?xi32>
-//      CHECK: %[[LHS4D:.+]] = linalg.tensor_expand_shape %[[LHSPAD]]
+//      CHECK: %[[LHS4D:.+]] = tensor.expand_shape %[[LHSPAD]]
 // CHECK-SAME: tensor<?x?xi8> into tensor<?x8x?x2xi8>
-//      CHECK: %[[RHS4D:.+]] = linalg.tensor_expand_shape %[[RHSPAD]]
+//      CHECK: %[[RHS4D:.+]] = tensor.expand_shape %[[RHSPAD]]
 // CHECK-SAME: tensor<?x?xi8> into tensor<?x2x?x4xi8>
-//      CHECK: %[[ACC4D:.+]] = linalg.tensor_expand_shape %[[ACCPAD]]
+//      CHECK: %[[ACC4D:.+]] = tensor.expand_shape %[[ACCPAD]]
 // CHECK-SAME: tensor<?x?xi32> into tensor<?x8x?x4xi32>
 //  ... After the above padding, the IR reduces to the same pattern already
 //  ... checked in the testcases above, so we skip re-checking it here.
-//      CHECK: %[[RESPAD:.+]] = linalg.tensor_collapse_shape
+//      CHECK: %[[RESPAD:.+]] = tensor.collapse_shape
 // CHECK-SAME: tensor<?x8x?x4xi32> into tensor<?x?xi32>
 //      CHECK: %[[RES:.+]] = tensor.extract_slice %[[RESPAD]][0, 0] [{{.*}}] [1, 1]
 // CHECK-SAME: tensor<?x?xi32> to tensor<?x?xi32>
diff --git a/iree/compiler/InputConversion/MHLO/ConvertMHLOToLinalgExt.cpp b/iree/compiler/InputConversion/MHLO/ConvertMHLOToLinalgExt.cpp
index df0ff46..511e3ff 100644
--- a/iree/compiler/InputConversion/MHLO/ConvertMHLOToLinalgExt.cpp
+++ b/iree/compiler/InputConversion/MHLO/ConvertMHLOToLinalgExt.cpp
@@ -257,7 +257,7 @@
     map.emplace_back(1, indicesRank - 1);
     auto resultType = RankedTensorType::get({batchSize, shape.back()},
                                             indicesType.getElementType());
-    indices = b.create<linalg::TensorCollapseShapeOp>(resultType, indices, map);
+    indices = b.create<tensor::CollapseShapeOp>(resultType, indices, map);
 
     auto updateShape = updatesType.getShape().drop_front(shape.size() - 1);
     SmallVector<int64_t> collapsedUpdateShape = {batchSize};
@@ -269,7 +269,7 @@
     for (auto i : llvm::seq<int64_t>(indicesRank - 1, updatesType.getRank())) {
       map.emplace_back(1, i);
     }
-    updates = b.create<linalg::TensorCollapseShapeOp>(resultType, updates, map);
+    updates = b.create<tensor::CollapseShapeOp>(resultType, updates, map);
 
     return success();
   }
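
Note on the two builder call sites above: only the op class name changes in
this integrate; the (resultType, source, reassociation) builder arguments
carry over unchanged from linalg::TensorCollapseShapeOp to
tensor::CollapseShapeOp. A minimal sketch, not part of this patch, with a
hypothetical helper name and the reassociation type assumed to be
ReassociationIndices:

  #include "mlir/Dialect/Tensor/IR/Tensor.h"
  #include "mlir/Dialect/Utils/ReshapeOpsUtils.h"
  #include "mlir/IR/ImplicitLocOpBuilder.h"

  using namespace mlir;

  // Collapse a rank-2 source into a rank-1 result, mirroring the
  // b.create<tensor::CollapseShapeOp>(...) calls in the hunk above.
  static Value collapseToRank1(ImplicitLocOpBuilder &b, Value source,
                               RankedTensorType resultType) {
    SmallVector<ReassociationIndices> reassociation;
    reassociation.push_back({0, 1});  // dims 0 and 1 fold into result dim 0
    return b.create<tensor::CollapseShapeOp>(resultType, source, reassociation);
  }
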
diff --git a/iree/compiler/InputConversion/MHLO/LegalizeInputTypes.cpp b/iree/compiler/InputConversion/MHLO/LegalizeInputTypes.cpp
index b6bec24..fa18aa4 100644
--- a/iree/compiler/InputConversion/MHLO/LegalizeInputTypes.cpp
+++ b/iree/compiler/InputConversion/MHLO/LegalizeInputTypes.cpp
@@ -81,8 +81,8 @@
                                       BlockAndValueMapping &mapping,
                                       OpBuilder &builder) {
   if (isa<linalg::LinalgDialect>(oldOp->getDialect()) &&
-      !isa<linalg::TensorCollapseShapeOp>(oldOp) &&
-      !isa<linalg::TensorExpandShapeOp>(oldOp)) {
+      !isa<tensor::CollapseShapeOp>(oldOp) &&
+      !isa<tensor::ExpandShapeOp>(oldOp)) {
     // Currently we assume all Linalg structured ops only contain valid types.
     builder.clone(*oldOp, mapping);
     return success();
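
Context for the hunk above: once the reshape ops move to the tensor dialect
they can no longer pass the linalg-dialect check, so the rewritten exclusions
act as a transitional safeguard rather than a behavior change. A minimal
sketch of the resulting predicate, with a hypothetical helper name and
assuming the same headers as the surrounding file:

  // True for linalg structured ops that can be cloned as-is; the
  // (now tensor-dialect) reshape ops are excluded explicitly.
  static bool isCloneableLinalgOp(mlir::Operation *oldOp) {
    return llvm::isa<mlir::linalg::LinalgDialect>(oldOp->getDialect()) &&
           !llvm::isa<mlir::tensor::CollapseShapeOp,
                      mlir::tensor::ExpandShapeOp>(oldOp);
  }
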
diff --git a/iree/compiler/InputConversion/MHLO/test/convert_mhlo_to_linalg_ext.mlir b/iree/compiler/InputConversion/MHLO/test/convert_mhlo_to_linalg_ext.mlir
index 10115da..bba02f6 100644
--- a/iree/compiler/InputConversion/MHLO/test/convert_mhlo_to_linalg_ext.mlir
+++ b/iree/compiler/InputConversion/MHLO/test/convert_mhlo_to_linalg_ext.mlir
@@ -336,9 +336,9 @@
 // CHECK:         %[[ARG0:[a-zA-Z0-9]+]]
 // CHECK:         %[[ARG1:[a-zA-Z0-9]+]]
 // CHECK:         %[[ARG2:[a-zA-Z0-9]+]]
-// CHECK:         %[[COLLAPSED_INDICES:.+]] = linalg.tensor_collapse_shape
+// CHECK:         %[[COLLAPSED_INDICES:.+]] = tensor.collapse_shape
 // CHECK-SAME:        %[[ARG1]] {{\[}}[0, 1], [2]] : tensor<3x4x1xi32> into tensor<12x1xi32>
-// CHECK:         %[[COLLAPSED_UPDATES:.+]] = linalg.tensor_collapse_shape
+// CHECK:         %[[COLLAPSED_UPDATES:.+]] = tensor.collapse_shape
 // CHECK-SAME:        %[[ARG2]] {{\[}}[0, 1]] : tensor<3x4xi32> into tensor<12xi32>
 // CHECK:         %[[SCATTER:.+]] = iree_linalg_ext.scatter
 // CHECK-SAME:       ins(%[[COLLAPSED_UPDATES]], %[[COLLAPSED_INDICES]] : tensor<12xi32>, tensor<12x1xi32>)
@@ -369,9 +369,9 @@
 // CHECK:         %[[ARG0:[a-zA-Z0-9]+]]
 // CHECK:         %[[ARG1:[a-zA-Z0-9]+]]
 // CHECK:         %[[ARG2:[a-zA-Z0-9]+]]
-// CHECK:         %[[COLLAPSED_INDICES:.+]] = linalg.tensor_collapse_shape
+// CHECK:         %[[COLLAPSED_INDICES:.+]] = tensor.collapse_shape
 // CHECK-SAME:        %[[ARG1]] {{\[}}[0, 1], [2]] : tensor<?x3x2xi32> into tensor<?x2xi32>
-// CHECK:         %[[COLLAPSED_UPDATES:.+]] = linalg.tensor_collapse_shape
+// CHECK:         %[[COLLAPSED_UPDATES:.+]] = tensor.collapse_shape
 // CHECK-SAME:        %[[ARG2]] {{\[}}[0, 1], [2]] : tensor<?x3x512xi32> into tensor<?x512xi32>
 // CHECK:         %[[SCATTER:.+]] = iree_linalg_ext.scatter
 // CHECK-SAME:        ins(%[[COLLAPSED_UPDATES]], %[[COLLAPSED_INDICES]] : tensor<?x512xi32>, tensor<?x2xi32>)
diff --git a/iree/compiler/InputConversion/MHLO/test/legalize_input_types.mlir b/iree/compiler/InputConversion/MHLO/test/legalize_input_types.mlir
index 313df38..2467f56 100644
--- a/iree/compiler/InputConversion/MHLO/test/legalize_input_types.mlir
+++ b/iree/compiler/InputConversion/MHLO/test/legalize_input_types.mlir
@@ -149,9 +149,9 @@
 // CHECK-LABEL: func @linalg_non_structured_op
 // CHECK-SAME:    (%arg0: tensor<9xi32>) -> tensor<1x9xi32>
 func @linalg_non_structured_op(%arg0: tensor<9xi64>) -> tensor<1x9xi64> {
-  // CHECK:       %[[RES:.+]] = linalg.tensor_expand_shape %arg0 {{\[}}[0, 1]] : tensor<9xi32> into tensor<1x9xi32>
+  // CHECK:       %[[RES:.+]] = tensor.expand_shape %arg0 {{\[}}[0, 1]] : tensor<9xi32> into tensor<1x9xi32>
   // CHECK:       return %[[RES:.+]] : tensor<1x9xi32>
-  %0 = linalg.tensor_expand_shape %arg0 [[0, 1]] : tensor<9xi64> into tensor<1x9xi64>
+  %0 = tensor.expand_shape %arg0 [[0, 1]] : tensor<9xi64> into tensor<1x9xi64>
   return %0 : tensor<1x9xi64>
 }
 
diff --git a/iree/test/e2e/xla_ops/scatter.mlir b/iree/test/e2e/xla_ops/scatter.mlir
index 259a78e..2614c28 100644
--- a/iree/test/e2e/xla_ops/scatter.mlir
+++ b/iree/test/e2e/xla_ops/scatter.mlir
@@ -108,7 +108,7 @@
      %1 = arith.index_cast %0 : index to i32
      linalg.yield %1 : i32
       } -> tensor<1400xi32>
-  %indices_reshaped = linalg.tensor_expand_shape %indices [[0, 1]] :
+  %indices_reshaped = tensor.expand_shape %indices [[0, 1]] :
       tensor<1400xi32> into tensor<1400x1xi32>
   %result = "mhlo.scatter"(%original, %indices_reshaped, %update)({
     ^bb0(%arg3 : tensor<i32>, %arg4 : tensor<i32>):
@@ -139,7 +139,7 @@
         %1 = arith.index_cast %0 : index to i32
         linalg.yield %1 : i32
       } -> tensor<200xi32>
-  %indices_reshaped = linalg.tensor_expand_shape %indices [[0, 1]] :
+  %indices_reshaped = tensor.expand_shape %indices [[0, 1]] :
       tensor<200xi32> into tensor<200x1xi32>
   %result = "mhlo.scatter"(%original, %indices_reshaped, %update)({
     ^bb0(%arg3 : tensor<i32>, %arg4 : tensor<i32>):
diff --git a/iree/tools/BUILD b/iree/tools/BUILD
index 4be1159..4a29ef1 100644
--- a/iree/tools/BUILD
+++ b/iree/tools/BUILD
@@ -160,6 +160,7 @@
         "@llvm-project//mlir:Shape",
         "@llvm-project//mlir:StandardOps",
         "@llvm-project//mlir:StandardToSPIRV",
+        "@llvm-project//mlir:TensorInferTypeOpInterfaceImpl",
         "@llvm-project//mlir:Transforms",
         "@llvm-project//mlir:VectorOps",
     ],
diff --git a/iree/tools/init_mlir_dialects.h b/iree/tools/init_mlir_dialects.h
index 2ddcf82..fc8d911 100644
--- a/iree/tools/init_mlir_dialects.h
+++ b/iree/tools/init_mlir_dialects.h
@@ -25,6 +25,7 @@
 #include "mlir/Dialect/Shape/IR/Shape.h"
 #include "mlir/Dialect/StandardOps/IR/Ops.h"
 #include "mlir/Dialect/Tensor/IR/Tensor.h"
+#include "mlir/Dialect/Tensor/IR/TensorInferTypeOpInterfaceImpl.h"
 #include "mlir/Dialect/Tosa/IR/TosaOps.h"
 #include "mlir/Dialect/Vector/VectorOps.h"
 #include "mlir/IR/Dialect.h"
@@ -55,6 +56,7 @@
                   tosa::TosaDialect,
                   shape::ShapeDialect>();
   // clang-format on
+  tensor::registerInferTypeOpInterfaceExternalModels(registry);
 
 #ifdef IREE_HAVE_EMITC_DIALECT
   registry.insert<emitc::EmitCDialect>();
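
The BUILD and registration hunks above go together: the new
TensorInferTypeOpInterfaceImpl dependency provides external-model
implementations of the type-inference interfaces for the relocated reshape
ops, and those models only take effect once registered. A minimal sketch,
assuming a standalone tool with a hypothetical registerMyDialects entry
point:

  #include "mlir/Dialect/Tensor/IR/Tensor.h"
  #include "mlir/Dialect/Tensor/IR/TensorInferTypeOpInterfaceImpl.h"
  #include "mlir/IR/Dialect.h"

  // Register the tensor dialect plus the external models added above;
  // without the second call, tensor.expand_shape / tensor.collapse_shape
  // lack the interface support they had as linalg ops.
  void registerMyDialects(mlir::DialectRegistry &registry) {
    registry.insert<mlir::tensor::TensorDialect>();
    mlir::tensor::registerInferTypeOpInterfaceExternalModels(registry);
  }
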
diff --git a/third_party/llvm-project b/third_party/llvm-project
index def8b95..c6a8bec 160000
--- a/third_party/llvm-project
+++ b/third_party/llvm-project
@@ -1 +1 @@
-Subproject commit def8b952ebc00b0ad0fa4196ae27b11a385087ad
+Subproject commit c6a8bec4c578a7c12e4458b161fce7b1704804a2