Merge pull request #6114 from rsuderman:main-to-google

PiperOrigin-RevId: 377579548
diff --git a/SUBMODULE_VERSIONS.txt b/SUBMODULE_VERSIONS.txt
index 0920faf..0e47f9b 100644
--- a/SUBMODULE_VERSIONS.txt
+++ b/SUBMODULE_VERSIONS.txt
@@ -4,15 +4,15 @@
 4fb0ff7069bd88ee85902f4d0bb62794e5f6d021 third_party/flatcc
 b1fbd33c06cdb0024c67733c6fdec2009d17b384 third_party/googletest
 88b845dee001723c4a0db1fe5477de735b6d3bb0 third_party/liburing
-528af9c01dc4080331184ece3ced3be6beaad47f third_party/llvm-bazel
-c89dff5855bb32d47751cce087537c2b12a90f1b third_party/llvm-project
+12a96e87d2eb377b7884aae56773bca50f12e7c7 third_party/llvm-bazel
+b109172d993edacd9853a8bbb8128a94da014399 third_party/llvm-project
 108a78da82049553b41c7a0f5987c67d5006af8d third_party/mlir-emitc
-fe42a08fc93830f06150b93bcb764e2386ca3110 third_party/mlir-hlo
+8b3a75ea25ceca9070938ced4b5909042765ea0c third_party/mlir-hlo
 d8c7ee00a687ac369e62e2032514a93a9b413502 third_party/pybind11
 685f86471e9d26b3eb7676695a2e2cefb4551ae9 third_party/spirv_cross
 f8bf11a0253a32375c32cad92c841237b96696c0 third_party/spirv_headers
 b42009b3b9d4ca35bc703f5310eedc74f584be58 third_party/stblib
-8be2640d44c8b4012e7863c93eea3ccb14f75151 third_party/tensorflow
+1d497e6419ca142acd188632c228c11d7c074a34 third_party/tensorflow
 f03b677ffa0fd96fcf859c32e79b740fac7dd59e third_party/tracy
 9bd3f561bcee3f01d22912de10bb07ce4e23d378 third_party/vulkan_headers
 3528e2aed3e8808f33e1e7d63eeb1560456a605a third_party/vulkan_memory_allocator
diff --git a/integrations/tensorflow/e2e/keras/applications/BUILD b/integrations/tensorflow/e2e/keras/applications/BUILD
index 9fce0f0..cd360b9 100644
--- a/integrations/tensorflow/e2e/keras/applications/BUILD
+++ b/integrations/tensorflow/e2e/keras/applications/BUILD
@@ -103,8 +103,15 @@
     failing_configurations = [
         # Frequently OOMs
         {
-            "target_backends": "tflite",
-            "model": "VGG19",
+            "target_backends": [
+                "tflite",
+                "iree_llvmaot",
+                "iree_vulkan",
+            ],
+            "model": [
+                "VGG16",
+                "VGG19",
+            ],
         },
     ],
     matrix = {
diff --git a/integrations/tensorflow/e2e/keras/applications/CMakeLists.txt b/integrations/tensorflow/e2e/keras/applications/CMakeLists.txt
index 7760534..448e0e4 100644
--- a/integrations/tensorflow/e2e/keras/applications/CMakeLists.txt
+++ b/integrations/tensorflow/e2e/keras/applications/CMakeLists.txt
@@ -27,7 +27,12 @@
     "MobileNet;MobileNetV2;ResNet50;VGG16;VGG19"
     "tf;tflite;iree_llvmaot;iree_vulkan"
   FAILING_CONFIGURATIONS
+    ",,,VGG16,tflite"
     ",,,VGG19,tflite"
+    ",,,VGG16,iree_llvmaot"
+    ",,,VGG19,iree_llvmaot"
+    ",,,VGG16,iree_vulkan"
+    ",,,VGG19,iree_vulkan"
   LABELS
     "manual"
 )
diff --git a/integrations/tensorflow/iree_tf_compiler/BUILD b/integrations/tensorflow/iree_tf_compiler/BUILD
index 03f6647..793693e 100644
--- a/integrations/tensorflow/iree_tf_compiler/BUILD
+++ b/integrations/tensorflow/iree_tf_compiler/BUILD
@@ -103,6 +103,7 @@
         "//iree_tf_compiler/MHLO",
         "@llvm-project//llvm:Support",
         "@llvm-project//mlir:IR",
+        "@llvm-project//mlir:Parser",
         "@llvm-project//mlir:Pass",
         "@llvm-project//mlir:StandardOps",
         "@llvm-project//mlir:Support",
diff --git a/integrations/tensorflow/iree_tf_compiler/MHLO/BUILD b/integrations/tensorflow/iree_tf_compiler/MHLO/BUILD
index b9bf77a..986b490 100644
--- a/integrations/tensorflow/iree_tf_compiler/MHLO/BUILD
+++ b/integrations/tensorflow/iree_tf_compiler/MHLO/BUILD
@@ -37,9 +37,12 @@
         "@iree//iree/compiler/Dialect/Flow/IR",
         "@iree//iree/compiler/Dialect/Flow/Transforms",
         "@iree//iree/compiler/Dialect/IREE/IR",
+        "@iree//iree/compiler/Dialect/Shape/Conversion",
+        "@iree//iree/compiler/Dialect/Shape/Transforms",
         "@llvm-project//llvm:Support",
         "@llvm-project//mlir:IR",
         "@llvm-project//mlir:Pass",
+        "@llvm-project//mlir:SCFToStandard",
         "@llvm-project//mlir:SCFTransforms",
         "@llvm-project//mlir:Shape",
         "@llvm-project//mlir:ShapeTransforms",
diff --git a/integrations/tensorflow/iree_tf_compiler/TF/BUILD b/integrations/tensorflow/iree_tf_compiler/TF/BUILD
index a745d9d..fc63547 100644
--- a/integrations/tensorflow/iree_tf_compiler/TF/BUILD
+++ b/integrations/tensorflow/iree_tf_compiler/TF/BUILD
@@ -46,6 +46,8 @@
         "@iree//iree/compiler/Dialect/Shape/Transforms",
         "@llvm-project//llvm:Support",
         "@llvm-project//mlir:IR",
+        "@llvm-project//mlir:LinalgOps",
+        "@llvm-project//mlir:MemRefDialect",
         "@llvm-project//mlir:Pass",
         "@llvm-project//mlir:Shape",
         "@llvm-project//mlir:ShapeTransforms",
diff --git a/iree/compiler/Conversion/Common/BufferAllocViewCleanUpPass.cpp b/iree/compiler/Conversion/Common/BufferAllocViewCleanUpPass.cpp
index 6615a66..f5e75a5 100644
--- a/iree/compiler/Conversion/Common/BufferAllocViewCleanUpPass.cpp
+++ b/iree/compiler/Conversion/Common/BufferAllocViewCleanUpPass.cpp
@@ -44,14 +44,15 @@
 ///       !flow.dispatch.tensor<readonly:864xf32>
 ///   %0 = flow.dispatch.tensor.load %subspan :
 ///       !flow.dispatch.tensor<readonly:864xf32> -> tensor<864xf32>
-struct FoldReshapeIntoInterfaceTensorLoad
-    : OpRewritePattern<linalg::TensorReshapeOp> {
-  using OpRewritePattern::OpRewritePattern;
+template <typename TensorReshapeOp>
+struct FoldReshapeIntoInterfaceTensorLoad : OpRewritePattern<TensorReshapeOp> {
+  using OpRewritePattern<TensorReshapeOp>::OpRewritePattern;
 
-  LogicalResult matchAndRewrite(linalg::TensorReshapeOp reshapeOp,
+  LogicalResult matchAndRewrite(TensorReshapeOp reshapeOp,
                                 PatternRewriter &rewriter) const override {
     auto loadOp =
-        reshapeOp.src().getDefiningOp<IREE::Flow::DispatchTensorLoadOp>();
+        reshapeOp.src()
+            .template getDefiningOp<IREE::Flow::DispatchTensorLoadOp>();
     if (!loadOp) return failure();
 
     // Make sure we are loading the full incoming subspan. Otherwise we cannot
@@ -61,11 +62,14 @@
       return failure();
 
     auto subspanOp =
-        loadOp.source().getDefiningOp<IREE::HAL::InterfaceBindingSubspanOp>();
+        loadOp.source()
+            .template getDefiningOp<IREE::HAL::InterfaceBindingSubspanOp>();
     if (!subspanOp) return failure();
 
     auto newSubspanType = IREE::Flow::DispatchTensorType::get(
-        subspanOp.getType().cast<IREE::Flow::DispatchTensorType>().getAccess(),
+        subspanOp.getType()
+            .template cast<IREE::Flow::DispatchTensorType>()
+            .getAccess(),
         reshapeOp.getResultType());
 
     Value newSubspanOp = rewriter.create<IREE::HAL::InterfaceBindingSubspanOp>(
@@ -101,8 +105,10 @@
     : public PassWrapper<BufferAllocViewCleanUpPass, FunctionPass> {
   void runOnFunction() override {
     OwningRewritePatternList patterns(&getContext());
-    patterns.insert<FoldReshapeIntoInterfaceTensorLoad, RemoveDeadMemAllocs>(
-        &getContext());
+    patterns.insert<
+        FoldReshapeIntoInterfaceTensorLoad<linalg::TensorCollapseShapeOp>,
+        FoldReshapeIntoInterfaceTensorLoad<linalg::TensorExpandShapeOp>,
+        RemoveDeadMemAllocs>(&getContext());
     (void)applyPatternsAndFoldGreedily(getOperation(), std::move(patterns));
   }
 };
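
Note on the templating above: once FoldReshapeIntoInterfaceTensorLoad is parameterized over the reshape op, expressions such as reshapeOp.src() have types that depend on the template parameter, so member templates invoked on them need the `template` disambiguator (the `.template getDefiningOp<...>` spelling). The same disambiguator appears later in this patch in ConvertToFlowTensorOps.cpp as `->template getParentOfType<...>`. A minimal standalone sketch of the rule, using a hypothetical ValueLike class rather than the MLIR API:

    #include <iostream>

    // Hypothetical stand-in for a class exposing a member template, similar
    // in shape to mlir::Value::getDefiningOp<OpTy>().
    struct ValueLike {
      template <typename T>
      T *getDefiningOp() const { return nullptr; }
    };

    template <typename DepT>
    void run(DepT v) {
      // `v` has a dependent type, so the compiler cannot know that
      // getDefiningOp names a member template; without `template` the `<`
      // would parse as less-than. The disambiguator fixes the parse.
      auto *op = v.template getDefiningOp<int>();
      std::cout << (op ? "found" : "null") << "\n";  // prints "null"
    }

    int main() {
      run(ValueLike{});
      return 0;
    }
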
diff --git a/iree/compiler/Conversion/Common/LinalgBufferizePass.cpp b/iree/compiler/Conversion/Common/LinalgBufferizePass.cpp
index d0766db..ffbd37d 100644
--- a/iree/compiler/Conversion/Common/LinalgBufferizePass.cpp
+++ b/iree/compiler/Conversion/Common/LinalgBufferizePass.cpp
@@ -85,9 +85,14 @@
   if (!definingOp) return false;
   return TypeSwitch<Operation *, bool>(definingOp)
       .Case<ConstantOp>([&](ConstantOp constantOp) { return true; })
-      .Case<linalg::TensorReshapeOp>([&](linalg::TensorReshapeOp reshapeOp) {
-        return isFromReadOnlyTensor(reshapeOp.src());
-      })
+      .Case<linalg::TensorExpandShapeOp>(
+          [&](linalg::TensorExpandShapeOp reshapeOp) {
+            return isFromReadOnlyTensor(reshapeOp.src());
+          })
+      .Case<linalg::TensorCollapseShapeOp>(
+          [&](linalg::TensorCollapseShapeOp reshapeOp) {
+            return isFromReadOnlyTensor(reshapeOp.src());
+          })
       .Case<SubTensorOp>([&](SubTensorOp subTensorOp) {
         return isFromReadOnlyTensor(subTensorOp.source());
       })
@@ -108,7 +113,10 @@
   // TODO(ravishankarm): Maybe this is too aggressive, might have to switch this
   // to have a white-list instead of blacklist.
   for (Operation *user : op->getUsers()) {
-    if (isa<IREE::Flow::DispatchTensorStoreOp, linalg::TensorReshapeOp>(user))
+    if (isa<IREE::Flow::DispatchTensorStoreOp, linalg::TensorCollapseShapeOp>(
+            user) ||
+        isa<IREE::Flow::DispatchTensorStoreOp, linalg::TensorExpandShapeOp>(
+            user))
       return false;
   }
   return true;
@@ -249,7 +257,8 @@
         storeOp.getMixedStrides().empty())) {
     SmallVector<Value> mappedTensors = plan.getTensorsMappedToSameSet(value);
     for (auto v : mappedTensors) {
-      if (v.getDefiningOp<linalg::TensorReshapeOp>()) return false;
+      if (v.getDefiningOp<linalg::TensorCollapseShapeOp>()) return false;
+      if (v.getDefiningOp<linalg::TensorExpandShapeOp>()) return false;
     }
   }
 
@@ -482,8 +491,13 @@
         .Case<linalg::LinalgOp>([&](linalg::LinalgOp linalgOp) {
           return analyseLinalgOps(linalgOp, plan);
         })
-        .Case<linalg::TensorReshapeOp>(
-            [&](linalg::TensorReshapeOp tensorReshapeOp) {
+        .Case<linalg::TensorCollapseShapeOp>(
+            [&](linalg::TensorCollapseShapeOp tensorReshapeOp) {
+              return analyseSingleOperandResultOp(
+                  tensorReshapeOp.src(), tensorReshapeOp.result(), plan);
+            })
+        .Case<linalg::TensorExpandShapeOp>(
+            [&](linalg::TensorExpandShapeOp tensorReshapeOp) {
               return analyseSingleOperandResultOp(
                   tensorReshapeOp.src(), tensorReshapeOp.result(), plan);
             })
@@ -681,16 +695,31 @@
   return subview;
 }
 
-/// Gets the reverse of a `linalg.tensor_reshape` op to get a memref type that
-/// can be used for in-place computation of the result of a disaptch region.
-static Value getReverseOfReshapeOp(OpBuilder &b,
-                                   linalg::TensorReshapeOp reshapeOp,
-                                   Value resultBuffer) {
+/// Gets the reverse of a `linalg.tensor_expand_shape` op to get a memref type
+/// that can be used for in-place computation of the result of a dispatch
+/// region.
+static Value getReverseOfExpandShapeOp(OpBuilder &b,
+                                       linalg::TensorExpandShapeOp expandOp,
+                                       Value resultBuffer) {
   auto memrefType = getMemrefTypeForTensor(
-      reshapeOp.getSrcType(), {},
+      expandOp.getSrcType(), {},
       resultBuffer.getType().cast<MemRefType>().getMemorySpaceAsInt());
-  return b.create<linalg::ReshapeOp>(reshapeOp.getLoc(), memrefType,
-                                     resultBuffer, reshapeOp.reassociation());
+  return b.create<linalg::CollapseShapeOp>(
+      expandOp.getLoc(), memrefType, resultBuffer, expandOp.reassociation());
+}
+
+/// Gets the reverse of a `linalg.tensor_collapse_shape` op to get a memref type
+/// that can be used for in-place computation of the result of a dispatch
+/// region.
+static Value getReverseOfCollapseShapeOp(
+    OpBuilder &b, linalg::TensorCollapseShapeOp collapseOp,
+    Value resultBuffer) {
+  auto memrefType = getMemrefTypeForTensor(
+      collapseOp.getSrcType(), {},
+      resultBuffer.getType().cast<MemRefType>().getMemorySpaceAsInt());
+  return b.create<linalg::ExpandShapeOp>(collapseOp.getLoc(), memrefType,
+                                         resultBuffer,
+                                         collapseOp.reassociation());
 }
 
 /// Gets the reverse of a `tensor.cast` op to get a memref type that
@@ -763,9 +792,14 @@
             .Case<scf::ForOp, linalg::LinalgOp, SubTensorInsertOp,
                   vector::TransferWriteOp>(
                 [&](auto op) { return resultBuffer; })
-            .Case<linalg::TensorReshapeOp>(
-                [&](linalg::TensorReshapeOp reshapeOp) {
-                  return getReverseOfReshapeOp(b, reshapeOp, resultBuffer);
+            .Case<linalg::TensorExpandShapeOp>(
+                [&](linalg::TensorExpandShapeOp expandOp) {
+                  return getReverseOfExpandShapeOp(b, expandOp, resultBuffer);
+                })
+            .Case<linalg::TensorCollapseShapeOp>(
+                [&](linalg::TensorCollapseShapeOp collapseOp) {
+                  return getReverseOfCollapseShapeOp(b, collapseOp,
+                                                     resultBuffer);
                 })
             .Case<tensor::CastOp>([&](tensor::CastOp castOp) {
               return getReverseOfCastOp(b, castOp, resultBuffer);
@@ -807,10 +841,10 @@
                          loadOp.getMixedSizes(), loadOp.getMixedStrides());
 }
 
-/// Converts a `linalg.tensor_reshape` operation to a `linalg.reshape`
+/// Converts a `linalg.tensor_expand_shape` operation to a `linalg.expand_shape`
 /// operation with the result aliasing the buffer for the operand.
 static Value getAliasingBufferForResult(OpBuilder &b,
-                                        linalg::TensorReshapeOp op,
+                                        linalg::TensorExpandShapeOp op,
                                         BlockAndValueMapping &bvm) {
   Location loc = op.getLoc();
   Value srcTensor = op.src();
@@ -821,7 +855,27 @@
   MemRefType inputBufferType = inputBuffer.getType().cast<MemRefType>();
   auto reshapeResultType = getMemrefTypeForTensor(
       resultTensorType, {}, inputBufferType.getMemorySpaceAsInt());
-  Value bufferReshape = b.create<linalg::ReshapeOp>(
+  Value bufferReshape = b.create<linalg::ExpandShapeOp>(
+      loc, reshapeResultType, inputBuffer, op.reassociation());
+  return bufferReshape;
+}
+
+/// Converts a `linalg.tensor_collapse_shape` operation to a
+/// `linalg.collapse_shape` operation with the result aliasing the buffer for
+/// the operand.
+static Value getAliasingBufferForResult(OpBuilder &b,
+                                        linalg::TensorCollapseShapeOp op,
+                                        BlockAndValueMapping &bvm) {
+  Location loc = op.getLoc();
+  Value srcTensor = op.src();
+  RankedTensorType resultTensorType = op.getResultType();
+  Value inputBuffer = bvm.lookup(srcTensor);
+
+  // Create the reshape op.
+  MemRefType inputBufferType = inputBuffer.getType().cast<MemRefType>();
+  auto reshapeResultType = getMemrefTypeForTensor(
+      resultTensorType, {}, inputBufferType.getMemorySpaceAsInt());
+  Value bufferReshape = b.create<linalg::CollapseShapeOp>(
       loc, reshapeResultType, inputBuffer, op.reassociation());
   return bufferReshape;
 }
@@ -866,7 +920,12 @@
 static SmallVector<Value, 4> getAliasingBuffersForResults(
     OpBuilder &b, Operation *op, BlockAndValueMapping &bvm) {
   return TypeSwitch<Operation *, SmallVector<Value, 4>>(op)
-      .Case<IREE::Flow::DispatchTensorLoadOp, linalg::TensorReshapeOp,
+      .Case<IREE::Flow::DispatchTensorLoadOp, linalg::TensorCollapseShapeOp,
+            SubTensorOp, tensor::CastOp>(
+          [&](auto singleResultOp) -> SmallVector<Value, 4> {
+            return {getAliasingBufferForResult(b, singleResultOp, bvm)};
+          })
+      .Case<IREE::Flow::DispatchTensorLoadOp, linalg::TensorExpandShapeOp,
             SubTensorOp, tensor::CastOp>(
           [&](auto singleResultOp) -> SmallVector<Value, 4> {
             return {getAliasingBufferForResult(b, singleResultOp, bvm)};
@@ -1301,7 +1360,21 @@
           }
           return convertScfForOp(b, forOp, bvm, plan);
         })
-        .Case<IREE::Flow::DispatchTensorLoadOp, linalg::TensorReshapeOp,
+        .Case<IREE::Flow::DispatchTensorLoadOp, linalg::TensorCollapseShapeOp,
+              SubTensorOp, tensor::CastOp>([&](auto aliasingOp) {
+          auto aliasingBuffers =
+              getAliasingBuffersForResults(b, aliasingOp, bvm);
+          if (failed(getOrAllocateResultBuffers(
+                  b, aliasingOp, aliasingOp->getOperand(0), aliasingBuffers,
+                  bvm, plan, allocationFn))) {
+            return failure();
+          }
+          copyFromAliasingBufferToResultBuffer(
+              b, aliasingOp->getLoc(), aliasingOp->getOperand(0),
+              aliasingOp->getResult(0), aliasingBuffers, bvm, plan);
+          return success();
+        })
+        .Case<IREE::Flow::DispatchTensorLoadOp, linalg::TensorExpandShapeOp,
               SubTensorOp, tensor::CastOp>([&](auto aliasingOp) {
           auto aliasingBuffers =
               getAliasingBuffersForResults(b, aliasingOp, bvm);
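
Two of the new TypeSwitch cases above (in getAliasingBuffersForResults and in the large conversion switch) repeat DispatchTensorLoadOp, SubTensorOp, and tensor::CastOp in both the collapse and expand clauses. llvm::TypeSwitch tries its cases in order and stops at the first match, so the repeated ops are only ever handled by the first clause; the second clause exists solely to pick up linalg::TensorExpandShapeOp. A standalone sketch of that first-match behavior over a toy hierarchy (assumes an LLVM checkout for the headers):

    #include "llvm/ADT/TypeSwitch.h"
    #include "llvm/Support/Casting.h"
    #include <iostream>

    // Minimal LLVM-style RTTI so llvm::dyn_cast (and thus TypeSwitch) works.
    struct Shape {
      enum Kind { CollapseKind, ExpandKind } kind;
      explicit Shape(Kind k) : kind(k) {}
      virtual ~Shape() = default;
    };
    struct Collapse : Shape {
      Collapse() : Shape(CollapseKind) {}
      static bool classof(const Shape *s) { return s->kind == CollapseKind; }
    };
    struct Expand : Shape {
      Expand() : Shape(ExpandKind) {}
      static bool classof(const Shape *s) { return s->kind == ExpandKind; }
    };

    int main() {
      Expand e;
      // Cases are tested in order; the first clause that matches wins and
      // later clauses are never consulted for that value.
      int r = llvm::TypeSwitch<Shape *, int>(&e)
                  .Case<Collapse>([](Collapse *) { return 1; })
                  .Case<Expand>([](Expand *) { return 2; })
                  .Default([](Shape *) { return 0; });
      std::cout << r << "\n";  // prints 2
      return 0;
    }
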
diff --git a/iree/compiler/Conversion/Common/test/canonicalize_interface_load_store.mlir b/iree/compiler/Conversion/Common/test/canonicalize_interface_load_store.mlir
index ff8df97..2870012 100644
--- a/iree/compiler/Conversion/Common/test/canonicalize_interface_load_store.mlir
+++ b/iree/compiler/Conversion/Common/test/canonicalize_interface_load_store.mlir
@@ -10,8 +10,8 @@
   %2 = hal.interface.binding.subspan @interface_io::@ret0[%c0] : !flow.dispatch.tensor<writeonly:3x3x96xf32>
   // CHECK: %[[LOAD:.+]] = flow.dispatch.tensor.load %[[ARG]], {{.*}} : !flow.dispatch.tensor<readonly:3x3x96xf32> -> tensor<3x3x96xf32>
   %3 = flow.dispatch.tensor.load %1, offsets=[], sizes =[], strides=[] : !flow.dispatch.tensor<readonly:3x3x1x96xf32> -> tensor<3x3x1x96xf32>
-  %4 = linalg.tensor_reshape %3 [[0, 1, 2, 3]] : tensor<3x3x1x96xf32> into tensor<864xf32>
-  %5 = linalg.tensor_reshape %4 [[0, 1, 2]] : tensor<864xf32> into tensor<3x3x96xf32>
+  %4 = linalg.tensor_collapse_shape %3 [[0, 1, 2, 3]] : tensor<3x3x1x96xf32> into tensor<864xf32>
+  %5 = linalg.tensor_expand_shape %4 [[0, 1, 2]] : tensor<864xf32> into tensor<3x3x96xf32>
   //  CHECK: flow.dispatch.tensor.store %[[LOAD]], {{.*}}
   flow.dispatch.tensor.store %5, %2, offsets = [%c0, %c0, %c0], sizes = [%c1, %c1, %c1], strides = [%c1, %c1, %c1] : tensor<3x3x96xf32> -> !flow.dispatch.tensor<writeonly:3x3x96xf32>
   return
@@ -34,9 +34,10 @@
   %1 = hal.interface.binding.subspan @interface_io::@arg0[%c0] : !flow.dispatch.tensor<readonly:6x3x1x96xf32>
   %2 = hal.interface.binding.subspan @interface_io::@ret0[%c0] : !flow.dispatch.tensor<writeonly:3x3x96xf32>
   %3 = flow.dispatch.tensor.load %1, offsets = [%c3, %c0, %c0, %c0], sizes = [%c3, %c3, %c1, %c96], strides = [%c1, %c1, %c1, %c1] : !flow.dispatch.tensor<readonly:6x3x1x96xf32> -> tensor<3x3x1x96xf32>
-  // CHECK-COUNT-2: linalg.tensor_reshape
-  %4 = linalg.tensor_reshape %3 [[0, 1, 2, 3]] : tensor<3x3x1x96xf32> into tensor<864xf32>
-  %5 = linalg.tensor_reshape %4 [[0, 1, 2]] : tensor<864xf32> into tensor<3x3x96xf32>
+  // CHECK: linalg.tensor_collapse_shape
+  // CHECK: linalg.tensor_expand_shape
+  %4 = linalg.tensor_collapse_shape %3 [[0, 1, 2, 3]] : tensor<3x3x1x96xf32> into tensor<864xf32>
+  %5 = linalg.tensor_expand_shape %4 [[0, 1, 2]] : tensor<864xf32> into tensor<3x3x96xf32>
   flow.dispatch.tensor.store %5, %2, offsets = [%c0, %c0, %c0], sizes = [%c1, %c1, %c1], strides = [%c1, %c1, %c1] : tensor<3x3x96xf32> -> !flow.dispatch.tensor<writeonly:3x3x96xf32>
   return
 }
diff --git a/iree/compiler/Conversion/Common/test/linalg_bufferize.mlir b/iree/compiler/Conversion/Common/test/linalg_bufferize.mlir
index 0d88745..0ba2b4f 100644
--- a/iree/compiler/Conversion/Common/test/linalg_bufferize.mlir
+++ b/iree/compiler/Conversion/Common/test/linalg_bufferize.mlir
@@ -644,7 +644,7 @@
   %0 = hal.interface.binding.subspan @io::@arg0[%c0] : !flow.dispatch.tensor<readonly:12xi32>
   %1 = hal.interface.binding.subspan @io::@ret0[%c0] : !flow.dispatch.tensor<writeonly:3x4xi32>
   %2 = flow.dispatch.tensor.load %0, offsets = [], sizes = [], strides = [] : !flow.dispatch.tensor<readonly:12xi32> -> tensor<12xi32>
-  %3 = linalg.tensor_reshape %2 [[0, 1]] : tensor<12xi32> into tensor<3x4xi32>
+  %3 = linalg.tensor_expand_shape %2 [[0, 1]] : tensor<12xi32> into tensor<3x4xi32>
   flow.dispatch.tensor.store %3, %1, offsets = [], sizes = [], strides = [] : tensor<3x4xi32> -> !flow.dispatch.tensor<writeonly:3x4xi32>
   return
 }
@@ -655,7 +655,7 @@
 //       CHECK: func @reshape_simple()
 //   CHECK-DAG:   %[[ARG0:.+]] = hal.interface.binding.subspan @io::@arg0
 //   CHECK-DAG:   %[[RET0:.+]] = hal.interface.binding.subspan @io::@ret0
-//       CHECK:   %[[RESHAPE:.+]] = linalg.reshape %[[ARG0]] {{\[}}[0, 1]]
+//       CHECK:   %[[RESHAPE:.+]] = linalg.expand_shape %[[ARG0]] {{\[}}[0, 1]]
 //       CHECK:   linalg.copy(%[[RESHAPE]], %[[RET0]])
 
 // -----
@@ -669,7 +669,7 @@
   %0 = hal.interface.binding.subspan @io::@arg0[%c0] : !flow.dispatch.tensor<readonly:12xi32>
   %1 = hal.interface.binding.subspan @io::@ret0[%c0] : !flow.dispatch.tensor<writeonly:3x4xi32>
   %2 = flow.dispatch.tensor.load %0, offsets = [], sizes = [], strides = [] : !flow.dispatch.tensor<readonly:12xi32> -> tensor<12xi32>
-  %3 = linalg.tensor_reshape %2 [[0, 1]] : tensor<12xi32> into tensor<3x4xi32>
+  %3 = linalg.tensor_expand_shape %2 [[0, 1]] : tensor<12xi32> into tensor<3x4xi32>
   %4 = linalg.init_tensor [3, 4] : tensor<3x4xi32>
   %5 = linalg.generic {
     indexing_maps = [affine_map<(d0, d1) -> (d0, d1)>, affine_map<(d0, d1) -> (d0, d1)>],
@@ -690,7 +690,7 @@
 //       CHECK:   %[[C0:.+]] = constant 0
 //   CHECK-DAG:   %[[ARG0:.+]] = hal.interface.binding.subspan @io::@arg0[%[[C0]]] : memref<12xi32>
 //   CHECK-DAG:   %[[RET0:.+]] = hal.interface.binding.subspan @io::@ret0[%[[C0]]] : memref<3x4xi32>
-//       CHECK:   %[[RESHAPE:.+]] = linalg.reshape %[[ARG0]] {{\[}}[0, 1]]
+//       CHECK:   %[[RESHAPE:.+]] = linalg.expand_shape %[[ARG0]] {{\[}}[0, 1]]
 //       CHECK:   linalg.generic
 //  CHECK-SAME:     ins(%[[RESHAPE]] : memref<3x4xi32>)
 //  CHECK-SAME:     outs(%[[RET0]] : memref<3x4xi32>)
@@ -707,7 +707,7 @@
   %1 = hal.interface.binding.subspan @io::@ret0[%c0] : !flow.dispatch.tensor<writeonly:3x4xi32>
   %2 = hal.interface.binding.subspan @io::@ret1[%c0] : !flow.dispatch.tensor<writeonly:3x4xi32>
   %3 = flow.dispatch.tensor.load %0, offsets = [], sizes = [], strides = [] : !flow.dispatch.tensor<readonly:12xi32> -> tensor<12xi32>
-  %4 = linalg.tensor_reshape %3 [[0, 1]] : tensor<12xi32> into tensor<3x4xi32>
+  %4 = linalg.tensor_expand_shape %3 [[0, 1]] : tensor<12xi32> into tensor<3x4xi32>
   %5 = linalg.init_tensor [3, 4] : tensor<3x4xi32>
   %6 = linalg.generic {
     indexing_maps = [affine_map<(d0, d1) -> (d0, d1)>, affine_map<(d0, d1) -> (d0, d1)>],
@@ -731,7 +731,7 @@
 //   CHECK-DAG:   %[[ARG0:.+]] = hal.interface.binding.subspan @io::@arg0[%[[C0]]] : memref<12xi32>
 //   CHECK-DAG:   %[[RET0:.+]] = hal.interface.binding.subspan @io::@ret0[%[[C0]]] : memref<3x4xi32>
 //   CHECK-DAG:   %[[RET1:.+]] = hal.interface.binding.subspan @io::@ret1[%[[C0]]] : memref<3x4xi32>
-//       CHECK:   %[[RESHAPE:.+]] = linalg.reshape %[[ARG0]] {{\[}}[0, 1]]
+//       CHECK:   %[[RESHAPE:.+]] = linalg.expand_shape %[[ARG0]] {{\[}}[0, 1]]
 //       CHECK:   linalg.generic
 //  CHECK-SAME:     ins(%[[RESHAPE]] : memref<3x4xi32>)
 //  CHECK-SAME:     outs(%[[RET0]] : memref<3x4xi32>)
@@ -757,7 +757,7 @@
       %5 = addi %arg0, %arg0 : i32
       linalg.yield %5 : i32
     } -> tensor<3x4xi32>
-  %5 = linalg.tensor_reshape %4 [[0, 1]] : tensor<3x4xi32> into tensor<12xi32>
+  %5 = linalg.tensor_collapse_shape %4 [[0, 1]] : tensor<3x4xi32> into tensor<12xi32>
   flow.dispatch.tensor.store %5, %1, offsets = [], sizes = [], strides = [] : tensor<12xi32> -> !flow.dispatch.tensor<writeonly:12xi32>
   return
 }
@@ -769,7 +769,7 @@
 //       CHECK:   %[[C0:.+]] = constant 0
 //   CHECK-DAG:   %[[ARG0:.+]] = hal.interface.binding.subspan @io::@arg0[%[[C0]]] : memref<3x4xi32>
 //   CHECK-DAG:   %[[RET0:.+]] = hal.interface.binding.subspan @io::@ret0[%[[C0]]] : memref<12xi32>
-//       CHECK:   %[[RESHAPE:.+]] = linalg.reshape %[[RET0]] {{\[}}[0, 1]]
+//       CHECK:   %[[RESHAPE:.+]] = linalg.expand_shape %[[RET0]] {{\[}}[0, 1]]
 //       CHECK:   linalg.generic
 //  CHECK-SAME:     ins(%[[ARG0]] : memref<3x4xi32>)
 //  CHECK-SAME:     outs(%[[RESHAPE]] : memref<3x4xi32>)
@@ -786,7 +786,7 @@
   %1 = hal.interface.binding.subspan @io::@arg1[%c0] : !flow.dispatch.tensor<readonly:2x3xf32>
   %2 = hal.interface.binding.subspan @io::@ret0[%c0] : !flow.dispatch.tensor<writeonly:1x3xf32>
   %3 = flow.dispatch.tensor.load %0, offsets = [], sizes = [], strides = [] : !flow.dispatch.tensor<readonly:1x1x2xf32> -> tensor<1x1x2xf32>
-  %4 = linalg.tensor_reshape %3 [[0, 1], [2]] : tensor<1x1x2xf32> into tensor<1x2xf32>
+  %4 = linalg.tensor_collapse_shape %3 [[0, 1], [2]] : tensor<1x1x2xf32> into tensor<1x2xf32>
   %workgroup_size_x = hal.interface.workgroup.size[0] : index
   %workgroup_size_y = hal.interface.workgroup.size[1] : index
   %workgroup_id_x = hal.interface.workgroup.id[0] : index
@@ -819,7 +819,7 @@
 // CHECK-LABEL: func @dot_general_lowering()
 //   CHECK-DAG:   %[[LHS:.+]] = hal.interface.binding.subspan @io::@arg0
 //   CHECK-DAG:   %[[RHS:.+]] = hal.interface.binding.subspan @io::@arg1
-//   CHECK-DAG:   %[[RESHAPE_LHS:.+]] = linalg.reshape %[[LHS]]
+//   CHECK-DAG:   %[[RESHAPE_LHS:.+]] = linalg.collapse_shape %[[LHS]]
 //   CHECK-DAG:   %[[RETURN:.+]] = hal.interface.binding.subspan @io::@ret0
 //       CHECK:   scf.for %[[IV0:.+]] = {{.+}} {
 //       CHECK:     scf.for %[[IV1:.+]] = {{.+}} {
@@ -1063,7 +1063,7 @@
   %0 = hal.interface.binding.subspan @io::@arg0[%c0] : !flow.dispatch.tensor<readonly:1x5x3x1xf32>
   %1 = hal.interface.binding.subspan @io::@ret0[%c0] : !flow.dispatch.tensor<writeonly:5x5xf32>
   %2 = flow.dispatch.tensor.load %0, offsets = [], sizes = [], strides = [] : !flow.dispatch.tensor<readonly:1x5x3x1xf32> -> tensor<1x5x3x1xf32>
-  %3 = linalg.tensor_reshape %2 [[0, 1], [2, 3]] : tensor<1x5x3x1xf32> into tensor<5x3xf32>
+  %3 = linalg.tensor_collapse_shape %2 [[0, 1], [2, 3]] : tensor<1x5x3x1xf32> into tensor<5x3xf32>
   %workgroup_size_x = hal.interface.workgroup.size[0] : index
   %workgroup_size_y = hal.interface.workgroup.size[1] : index
   %workgroup_id_x = hal.interface.workgroup.id[0] : index
@@ -1097,7 +1097,7 @@
 //   CHECK-DAG:   %[[RHS:.+]] = memref.buffer_cast %[[CONSTANT]]
 //   CHECK-DAG:   %[[LHS_INPUT:.+]] = hal.interface.binding.subspan @io::@arg0[%{{.+}}] : memref<1x5x3x1xf32>
 //   CHECK-DAG:   %[[RETURN:.+]] = hal.interface.binding.subspan @io::@ret0[%{{.+}}] : memref<5x5xf32>
-//       CHECK:   %[[LHS:.+]] = linalg.reshape %[[LHS_INPUT]]
+//       CHECK:   %[[LHS:.+]] = linalg.collapse_shape %[[LHS_INPUT]]
 //       CHECK:   scf.for %[[IV0:.+]] =
 //       CHECK:     scf.for %[[IV1:.+]] =
 //   CHECK-DAG:       %[[LHS_SUBVIEW:.+]] = memref.subview %[[LHS]][%[[IV0]], 0]
@@ -1260,7 +1260,7 @@
   %0 = hal.interface.binding.subspan @io::@ro0[%c0] : !flow.dispatch.tensor<readonly:?x?xf32>
   %1 = hal.interface.binding.subspan @io::@wo0[%c0] : !flow.dispatch.tensor<writeonly:?xf32>
   %2 = flow.dispatch.tensor.load %0, offsets = [], sizes = [], strides = [] : !flow.dispatch.tensor<readonly:?x?xf32> -> tensor<?x?xf32>
-  %3 = linalg.tensor_reshape %2 [[0, 1]]
+  %3 = linalg.tensor_collapse_shape %2 [[0, 1]]
       : tensor<?x?xf32> into tensor<?xf32>
   %4 = memref.dim %3, %c0 : tensor<?xf32>
   %5 = linalg.init_tensor [%4] : tensor<?xf32>
@@ -1278,7 +1278,7 @@
 // CHECK-LABEL: func @reshape_read_only
 //   CHECK-DAG:   %[[INPUT:.+]] = hal.interface.binding.subspan @io::@ro0
 //   CHECK-DAG:   %[[OUTPUT:.+]] = hal.interface.binding.subspan @io::@wo0
-//       CHECK:   %[[RESHAPE:.+]] = linalg.reshape %[[INPUT]]
+//       CHECK:   %[[RESHAPE:.+]] = linalg.collapse_shape %[[INPUT]]
 //       CHECK:   linalg.generic
 //  CHECK-SAME:     ins(%[[RESHAPE]] : memref<?xf32>)
 //  CHECK-SAME:     outs(%[[OUTPUT]] : memref<?xf32>)
@@ -1800,7 +1800,7 @@
         linalg.yield %cst : f32
       } : tensor<2x?xf32> to tensor<4x4xf32>
       %15 = linalg.init_tensor [4, 4] : tensor<4x4xf32>
-      %16 = linalg.fill(%15, %cst) : tensor<4x4xf32>, f32 -> tensor<4x4xf32> 
+      %16 = linalg.fill(%15, %cst) : tensor<4x4xf32>, f32 -> tensor<4x4xf32>
       %17 = linalg.matmul {__internal_linalg_transform__ = "workgroup"} ins(%13, %14 : tensor<4x4xf32>, tensor<4x4xf32>) outs(%16 : tensor<4x4xf32>) -> tensor<4x4xf32>
       %18 = subtensor %17[0, 0] [%7, %9] [1, 1] : tensor<4x4xf32> to tensor<?x?xf32>
       flow.dispatch.tensor.store %18, %2, offsets = [%arg0, %arg1], sizes = [%11, %12], strides = [1, 1] : tensor<?x?xf32> -> !flow.dispatch.tensor<writeonly:?x?xf32>
@@ -1877,17 +1877,17 @@
         %13 = flow.dispatch.tensor.load %0, offsets = [0, %9, %11, 0], sizes = [1, %10, %12, 8], strides = [1, 1, 1, 1] : !flow.dispatch.tensor<readonly:1x225x225x8xf32> -> tensor<1x?x?x8xf32>
         %14 = flow.dispatch.tensor.load %1, offsets = [0, 0, 0, %arg2], sizes = [3, 3, 8, 4], strides = [1, 1, 1, 1] : !flow.dispatch.tensor<readonly:3x3x8x32xf32> -> tensor<3x3x8x4xf32>
         %15 = linalg.init_tensor [1, 16, 16, 4] : tensor<1x16x16x4xf32>
-        %16 = linalg.fill(%15, %cst) {__internal_linalg_transform__ = "workgroup"} : tensor<1x16x16x4xf32>, f32 -> tensor<1x16x16x4xf32> 
+        %16 = linalg.fill(%15, %cst) {__internal_linalg_transform__ = "workgroup"} : tensor<1x16x16x4xf32>, f32 -> tensor<1x16x16x4xf32>
         %17 = linalg.init_tensor [1, 16, 16, 3, 3, 8] : tensor<1x16x16x3x3x8xf32>
         %18 = linalg.generic {indexing_maps = [affine_map<(d0, d1, d2, d3, d4, d5) -> (d0, d1 * 2 + d3, d2 * 2 + d4, d5)>, affine_map<(d0, d1, d2, d3, d4, d5) -> (d0, d1, d2, d3, d4, d5)>], iterator_types = ["parallel", "parallel", "parallel", "parallel", "parallel", "parallel"]} ins(%13 : tensor<1x?x?x8xf32>) outs(%17 : tensor<1x16x16x3x3x8xf32>) {
         ^bb0(%arg3: f32, %arg4: f32):  // no predecessors
           linalg.yield %arg3 : f32
         } -> tensor<1x16x16x3x3x8xf32>
-        %19 = linalg.tensor_reshape %18 [[0, 1, 2], [3, 4, 5]] : tensor<1x16x16x3x3x8xf32> into tensor<256x72xf32>
-        %20 = linalg.tensor_reshape %14 [[0, 1, 2], [3]] : tensor<3x3x8x4xf32> into tensor<72x4xf32>
-        %21 = linalg.tensor_reshape %16 [[0, 1, 2], [3]] : tensor<1x16x16x4xf32> into tensor<256x4xf32>
+        %19 = linalg.tensor_collapse_shape %18 [[0, 1, 2], [3, 4, 5]] : tensor<1x16x16x3x3x8xf32> into tensor<256x72xf32>
+        %20 = linalg.tensor_collapse_shape %14 [[0, 1, 2], [3]] : tensor<3x3x8x4xf32> into tensor<72x4xf32>
+        %21 = linalg.tensor_collapse_shape %16 [[0, 1, 2], [3]] : tensor<1x16x16x4xf32> into tensor<256x4xf32>
         %22 = linalg.matmul ins(%19, %20 : tensor<256x72xf32>, tensor<72x4xf32>) outs(%21 : tensor<256x4xf32>) -> tensor<256x4xf32>
-        %23 = linalg.tensor_reshape %22 [[0, 1, 2], [3]] : tensor<256x4xf32> into tensor<1x16x16x4xf32>
+        %23 = linalg.tensor_expand_shape %22 [[0, 1, 2], [3]] : tensor<256x4xf32> into tensor<1x16x16x4xf32>
         %24 = tensor.cast %23 : tensor<1x16x16x4xf32> to tensor<1x?x?x?xf32>
         flow.dispatch.tensor.store %24, %2, offsets = [0, %arg0, %arg1, %arg2], sizes = [1, %c16, %c16, %c4], strides = [1, 1, 1, 1] : tensor<1x?x?x?xf32> -> !flow.dispatch.tensor<writeonly:1x112x112x32xf32>
       }
@@ -1917,9 +1917,9 @@
 //       CHECK:      linalg.generic
 //  CHECK-SAME:        ins(%[[ARG0_SV]]
 //  CHECK-SAME:        outs(%[[ALLOC_ARG0]]
-//   CHECK-DAG:      %[[ALLOC_ARG0_RESHAPE:.+]] = linalg.reshape %[[ALLOC_ARG0]]
-//   CHECK-DAG:      %[[ALLOC_ARG1_RESHAPE:.+]] = linalg.reshape %[[ALLOC_ARG1]]
-//   CHECK-DAG:      %[[ALLOC_RET0_RESHAPE:.+]] = linalg.reshape %[[ALLOC_RET0]]
+//   CHECK-DAG:      %[[ALLOC_ARG0_RESHAPE:.+]] = linalg.collapse_shape %[[ALLOC_ARG0]]
+//   CHECK-DAG:      %[[ALLOC_ARG1_RESHAPE:.+]] = linalg.collapse_shape %[[ALLOC_ARG1]]
+//   CHECK-DAG:      %[[ALLOC_RET0_RESHAPE:.+]] = linalg.collapse_shape %[[ALLOC_RET0]]
 //       CHECK:      linalg.matmul
 //  CHECK-SAME:        ins(%[[ALLOC_ARG0_RESHAPE]], %[[ALLOC_ARG1_RESHAPE]]
 //  CHECK-SAME:        outs(%[[ALLOC_RET0_RESHAPE]]
@@ -1948,8 +1948,8 @@
     %7 = flow.dispatch.tensor.load %0, offsets = [0, %arg0], sizes = [%d0, %6], strides = [1, 1] : !flow.dispatch.tensor<readonly:?x?xi32> -> tensor<?x?xi32>
     %9 = flow.dispatch.tensor.load %1, offsets = [0, %arg0], sizes = [%d0, %6], strides = [1, 1] : !flow.dispatch.tensor<readonly:?x?xi32> -> tensor<?x?xi32>
     %13 = linalg.init_tensor [%6] : tensor<?xi32>
-    %14 = linalg.fill(%13, %c-2147483648_i32) {__internal_linalg_transform__ = "workgroup", lowering.config = {tileSizes = [[128]]}} : tensor<?xi32>, i32 -> tensor<?xi32> 
-    %17 = linalg.fill(%13, %c0_i32) {__internal_linalg_transform__ = "workgroup", lowering.config = {tileSizes = [[128]]}} : tensor<?xi32>, i32 -> tensor<?xi32> 
+    %14 = linalg.fill(%13, %c-2147483648_i32) {__internal_linalg_transform__ = "workgroup", lowering.config = {tileSizes = [[128]]}} : tensor<?xi32>, i32 -> tensor<?xi32>
+    %17 = linalg.fill(%13, %c0_i32) {__internal_linalg_transform__ = "workgroup", lowering.config = {tileSizes = [[128]]}} : tensor<?xi32>, i32 -> tensor<?xi32>
     %18:2 = linalg.generic {indexing_maps = [affine_map<(d0, d1) -> (d1, d0)>, affine_map<(d0, d1) -> (d1, d0)>, affine_map<(d0, d1) -> (d0)>, affine_map<(d0, d1) -> (d0)>], iterator_types = ["parallel", "reduction"]} ins(%7, %9 : tensor<?x?xi32>, tensor<?x?xi32>) outs(%14, %17 : tensor<?xi32>, tensor<?xi32>) attrs =  {__internal_linalg_transform__ = "workgroup", lowering.config = {tileSizes = [[128]]}} {
     ^bb0(%arg1: i32, %arg2: i32, %arg3: i32, %arg4: i32):  // no predecessors
       %19 = cmpi sge, %arg1, %arg3 : i32
@@ -2028,7 +2028,7 @@
         %9 = affine.min #map2(%arg1)
         %10 = flow.dispatch.tensor.load %1, offsets = [0, %arg1], sizes = [144, %9], strides = [1, 1] : !flow.dispatch.tensor<readonly:144x370xf32> -> tensor<144x?xf32>
         %11 = linalg.init_tensor [%7, %9] : tensor<?x?xf32>
-        %12 = linalg.fill(%11, %cst) {__internal_linalg_transform__ = "workgroup", lowering.config = #config0} : tensor<?x?xf32>, f32 -> tensor<?x?xf32> 
+        %12 = linalg.fill(%11, %cst) {__internal_linalg_transform__ = "workgroup", lowering.config = #config0} : tensor<?x?xf32>, f32 -> tensor<?x?xf32>
         %13 = scf.for %arg2 = %c0 to %c250 step %c32 iter_args(%arg3 = %12) -> (tensor<?x?xf32>) {
           %14 = scf.for %arg4 = %c0 to %c370 step %c32 iter_args(%arg5 = %arg3) -> (tensor<?x?xf32>) {
             %15 = scf.for %arg6 = %c0 to %c144 step %c24 iter_args(%arg7 = %arg5) -> (tensor<?x?xf32>) {
diff --git a/iree/compiler/Conversion/HLOToLinalg/FusionOfTensorOps.cpp b/iree/compiler/Conversion/HLOToLinalg/FusionOfTensorOps.cpp
index d907171..e7a97ec 100644
--- a/iree/compiler/Conversion/HLOToLinalg/FusionOfTensorOps.cpp
+++ b/iree/compiler/Conversion/HLOToLinalg/FusionOfTensorOps.cpp
@@ -28,9 +28,10 @@
 
 namespace mlir {
 namespace iree_compiler {
-
 namespace {
 
+using linalg::LinalgOp;
+
 /// Pass to fuse linalg on tensor operations as well as fusion of hal.interface*
 /// operations with linalg.tensor_reshape operation.
 struct FusionOfTensorOpsPass
@@ -58,7 +59,7 @@
           if (!clEnableFusionWithReductionOps) {
             auto consumerOp = consumer.getOwner();
             if (isa<linalg::GenericOp, linalg::IndexedGenericOp>(consumerOp) &&
-                dyn_cast<linalg::LinalgOp>(consumerOp).getNumReductionLoops()) {
+                dyn_cast<LinalgOp>(consumerOp).getNumReductionLoops()) {
               return false;
             }
           }
@@ -77,8 +78,14 @@
     // to the consumer linalg op.
     linalg::ControlElementwiseOpsFusionFn foldReshapeBetweenLinalgFn =
         [](const OpResult &producer, const OpOperand &consumer) {
-          auto reshapeOp = producer.getDefiningOp<linalg::TensorReshapeOp>();
-          return reshapeOp.src().getDefiningOp<linalg::LinalgOp>() != nullptr;
+          auto collapseOp =
+              producer.getDefiningOp<linalg::TensorCollapseShapeOp>();
+          if (collapseOp)
+            return collapseOp.src().getDefiningOp<LinalgOp>() != nullptr;
+          auto expandOp = producer.getDefiningOp<linalg::TensorExpandShapeOp>();
+          if (expandOp)
+            return expandOp.src().getDefiningOp<LinalgOp>() != nullptr;
+          return false;
         };
     linalg::populateElementwiseOpsFusionPatterns(
         fusionPatterns,
@@ -92,7 +99,9 @@
     OwningRewritePatternList reshapeCanonicalizations(&getContext());
     linalg::populateFoldUnitDimsReshapeOpsByLinearizationPatterns(
         reshapeCanonicalizations);
-    linalg::TensorReshapeOp::getCanonicalizationPatterns(
+    linalg::TensorCollapseShapeOp::getCanonicalizationPatterns(
+        reshapeCanonicalizations, context);
+    linalg::TensorExpandShapeOp::getCanonicalizationPatterns(
         reshapeCanonicalizations, context);
     (void)applyPatternsAndFoldGreedily(op->getRegions(),
                                        std::move(reshapeCanonicalizations));
@@ -100,8 +109,10 @@
     // Push the remaining reshapes down the graphs.
     OwningRewritePatternList pushReshapePatterns(&getContext());
     linalg::populatePushReshapeOpsPatterns(pushReshapePatterns);
-    linalg::TensorReshapeOp::getCanonicalizationPatterns(pushReshapePatterns,
-                                                         context);
+    linalg::TensorCollapseShapeOp::getCanonicalizationPatterns(
+        pushReshapePatterns, context);
+    linalg::TensorExpandShapeOp::getCanonicalizationPatterns(
+        pushReshapePatterns, context);
     (void)applyPatternsAndFoldGreedily(op->getRegions(),
                                        std::move(pushReshapePatterns));
   }
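
The rewritten control function above handles collapse and expand producers with two identical branches, and now also returns false when the producer is neither (the old single-op version called .src() with no null check, which only worked if the callback was guaranteed a reshape producer). One possible consolidation, a sketch assuming the includes already present in this pass rather than a drop-in replacement:

    // Hypothetical helper: return the reshape source for either reshape op,
    // or a null Value when the op is not a reshape at all.
    static Value getReshapeSource(Operation *op) {
      if (auto collapseOp = dyn_cast<linalg::TensorCollapseShapeOp>(op))
        return collapseOp.src();
      if (auto expandOp = dyn_cast<linalg::TensorExpandShapeOp>(op))
        return expandOp.src();
      return Value();
    }

    linalg::ControlElementwiseOpsFusionFn foldReshapeBetweenLinalgFn =
        [](const OpResult &producer, const OpOperand &consumer) {
          // An OpResult always has a defining op.
          Value src = getReshapeSource(producer.getDefiningOp());
          return src && src.getDefiningOp<linalg::LinalgOp>() != nullptr;
        };
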
diff --git a/iree/compiler/Conversion/LinalgToLinalg/Conv2D1x1ToMatmul.cpp b/iree/compiler/Conversion/LinalgToLinalg/Conv2D1x1ToMatmul.cpp
index a2cdc1a..bb50042 100644
--- a/iree/compiler/Conversion/LinalgToLinalg/Conv2D1x1ToMatmul.cpp
+++ b/iree/compiler/Conversion/LinalgToLinalg/Conv2D1x1ToMatmul.cpp
@@ -62,18 +62,18 @@
     Value output = convOp.getOutput(0);
     auto loc = convOp.getLoc();
 
-    Value reshapedInput = rewriter.create<linalg::TensorReshapeOp>(
+    Value reshapedInput = rewriter.create<linalg::TensorCollapseShapeOp>(
         loc, reshapedInputType, input, reassociationIndices);
-    Value reshapedFilter = rewriter.create<linalg::TensorReshapeOp>(
+    Value reshapedFilter = rewriter.create<linalg::TensorCollapseShapeOp>(
         loc, reshapedFilterType, filter, reassociationIndices);
-    Value reshapedOutput = rewriter.create<linalg::TensorReshapeOp>(
+    Value reshapedOutput = rewriter.create<linalg::TensorCollapseShapeOp>(
         loc, reshapedOutputType, output, reassociationIndices);
 
     auto matmulResult = rewriter.create<linalg::MatmulOp>(
         loc, reshapedOutputType, ArrayRef<Value>{reshapedInput, reshapedFilter},
         ArrayRef<Value>{reshapedOutput});
 
-    auto reshapedResult = rewriter.create<linalg::TensorReshapeOp>(
+    auto reshapedResult = rewriter.create<linalg::TensorExpandShapeOp>(
         loc, outputShapeType, matmulResult.getResults()[0],
         reassociationIndices);
 
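Direction matters in the rewrite above: operands are collapsed on the way into linalg.matmul and the result is expanded back afterwards, all through one shared reassociation. A sketch of that reassociation for the shapes in the conv1x1_to_matmul.mlir test below (illustrative values, not the pass's exact variables; assumes the MLIR header providing ReassociationIndices):

    // Group N, H, W into one dimension and keep channels separate:
    //   input 1x4x5x2 -> 20x2, filter 1x1x2x7 -> 2x7, output 1x4x5x7 -> 20x7.
    SmallVector<ReassociationIndices> reassociationIndices = {{0, 1, 2}, {3}};
    // matmul: 20x2 * 2x7 -> 20x7, then tensor_expand_shape with the same
    // reassociation restores tensor<1x4x5x7xf32>.
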
diff --git a/iree/compiler/Conversion/LinalgToLinalg/Conv2DToImg2Col.cpp b/iree/compiler/Conversion/LinalgToLinalg/Conv2DToImg2Col.cpp
index 3ee6078..4862acc 100644
--- a/iree/compiler/Conversion/LinalgToLinalg/Conv2DToImg2Col.cpp
+++ b/iree/compiler/Conversion/LinalgToLinalg/Conv2DToImg2Col.cpp
@@ -131,14 +131,15 @@
         RankedTensorType::get({outputShape[1] * outputShape[2], outputShape[3]},
                               outputShapeType.getElementType());
 
-    Value reshapedImg2ColTensor = rewriter.create<linalg::TensorReshapeOp>(
-        loc, reshapedImg2ColTensorType, img2ColTensor.getResult(0),
-        img2ColTensorReassociationIndices);
+    Value reshapedImg2ColTensor =
+        rewriter.create<linalg::TensorCollapseShapeOp>(
+            loc, reshapedImg2ColTensorType, img2ColTensor.getResult(0),
+            img2ColTensorReassociationIndices);
 
-    Value reshapedFilter = rewriter.create<linalg::TensorReshapeOp>(
+    Value reshapedFilter = rewriter.create<linalg::TensorCollapseShapeOp>(
         loc, reshapedFilterType, filter, filterAndOutputReassociationIndices);
 
-    Value reshapedOutput = rewriter.create<linalg::TensorReshapeOp>(
+    Value reshapedOutput = rewriter.create<linalg::TensorCollapseShapeOp>(
         loc, reshapedOutputType, output, filterAndOutputReassociationIndices);
 
     auto matmulResult = rewriter.create<linalg::MatmulOp>(
@@ -146,7 +147,7 @@
         ArrayRef<Value>{reshapedImg2ColTensor, reshapedFilter},
         ArrayRef<Value>{reshapedOutput});
 
-    auto reshapedResult = rewriter.create<linalg::TensorReshapeOp>(
+    auto reshapedResult = rewriter.create<linalg::TensorExpandShapeOp>(
         loc, outputShapeType, matmulResult.getResults()[0],
         filterAndOutputReassociationIndices);
 
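Same pattern as the 1x1 case, with the im2col tensor taking the place of the input. The shape arithmetic behind the conv2d_to_img2col.mlir test below, as a runnable check (sizes taken from that test):

    #include <cstdio>

    int main() {
      // Assumed sizes from the test: output 1x14x14x16, filter 3x3x4x16.
      int N = 1, H = 14, W = 14;   // output batch and spatial dims
      int KH = 3, KW = 3, CI = 4;  // filter window and input channels
      int CO = 16;                 // output channels
      int rows = N * H * W;        // 196: one row per output position
      int cols = KH * KW * CI;     // 36: one column per filter element
      std::printf("img2col %dx%d, filter %dx%d, matmul %dx%d\n",
                  rows, cols, cols, CO, rows, CO);
      return 0;
    }
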
diff --git a/iree/compiler/Conversion/LinalgToLinalg/test/conv1x1_to_matmul.mlir b/iree/compiler/Conversion/LinalgToLinalg/test/conv1x1_to_matmul.mlir
index 067c66d..dfdd0c1 100644
--- a/iree/compiler/Conversion/LinalgToLinalg/test/conv1x1_to_matmul.mlir
+++ b/iree/compiler/Conversion/LinalgToLinalg/test/conv1x1_to_matmul.mlir
@@ -12,9 +12,9 @@
 // CHECK: %[[INPUT:.+]]: tensor<1x4x5x2xf32>
 // CHECK: %[[FILTER:.+]]: tensor<1x1x2x7xf32>
 // CHECK: %[[OUTPUT:.+]] = linalg.init_tensor [1, 4, 5, 7] : tensor<1x4x5x7xf32>
-// CHECK: %[[RESHAPED_INPUT:.+]] = linalg.tensor_reshape %[[INPUT]] {{\[}}[0, 1, 2], [3]] : tensor<1x4x5x2xf32> into tensor<20x2xf32>
-// CHECK: %[[RESHAPED_FILTER:.+]] = linalg.tensor_reshape %[[FILTER]] {{\[}}[0, 1, 2], [3]] : tensor<1x1x2x7xf32> into tensor<2x7xf32>
-// CHECK: %[[RESHAPED_OUTPUT:.+]] = linalg.tensor_reshape %[[OUTPUT]] {{\[}}[0, 1, 2], [3]] : tensor<1x4x5x7xf32> into tensor<20x7xf32>
+// CHECK: %[[RESHAPED_INPUT:.+]] = linalg.tensor_collapse_shape %[[INPUT]] {{\[}}[0, 1, 2], [3]] : tensor<1x4x5x2xf32> into tensor<20x2xf32>
+// CHECK: %[[RESHAPED_FILTER:.+]] = linalg.tensor_collapse_shape %[[FILTER]] {{\[}}[0, 1, 2], [3]] : tensor<1x1x2x7xf32> into tensor<2x7xf32>
+// CHECK: %[[RESHAPED_OUTPUT:.+]] = linalg.tensor_collapse_shape %[[OUTPUT]] {{\[}}[0, 1, 2], [3]] : tensor<1x4x5x7xf32> into tensor<20x7xf32>
 // CHECK: %[[MATMUL_RESULT:.+]] = linalg.matmul ins(%[[RESHAPED_INPUT]], %[[RESHAPED_FILTER]] : tensor<20x2xf32>, tensor<2x7xf32>) outs(%[[RESHAPED_OUTPUT]] : tensor<20x7xf32>)
-// CHECK: %[[RESULT:.+]] = linalg.tensor_reshape %[[MATMUL_RESULT]] {{\[}}[0, 1, 2], [3]] : tensor<20x7xf32> into tensor<1x4x5x7xf32>
+// CHECK: %[[RESULT:.+]] = linalg.tensor_expand_shape %[[MATMUL_RESULT]] {{\[}}[0, 1, 2], [3]] : tensor<20x7xf32> into tensor<1x4x5x7xf32>
 // CHECK: return %[[RESULT]]
diff --git a/iree/compiler/Conversion/LinalgToLinalg/test/conv2d_to_img2col.mlir b/iree/compiler/Conversion/LinalgToLinalg/test/conv2d_to_img2col.mlir
index 7100ec7..e406535 100644
--- a/iree/compiler/Conversion/LinalgToLinalg/test/conv2d_to_img2col.mlir
+++ b/iree/compiler/Conversion/LinalgToLinalg/test/conv2d_to_img2col.mlir
@@ -20,14 +20,14 @@
 //           CHECK-SAME: #[[MAP1]]
 //                CHECK: ^bb0(%[[IN_DATA:.+]]: f32, %[[OUT_DATA:.+]]: f32)
 //                CHECK: linalg.yield %[[IN_DATA]] : f32
-//      CHECK-DAG: %[[RESHAPED_INIT_COL_TENSOR:.+]] = linalg.tensor_reshape %[[COL_TENSOR]]
+//      CHECK-DAG: %[[RESHAPED_INIT_COL_TENSOR:.+]] = linalg.tensor_collapse_shape %[[COL_TENSOR]]
 //           CHECK-SAME: [0, 1, 2], [3, 4, 5]
 //           CHECK-SAME: tensor<1x14x14x3x3x4xf32> into tensor<196x36xf32>
-//      CHECK-DAG: %[[RESHAPED_FILTER:.+]] = linalg.tensor_reshape %[[FILTER]]
+//      CHECK-DAG: %[[RESHAPED_FILTER:.+]] = linalg.tensor_collapse_shape %[[FILTER]]
 //           CHECK-SAME: [0, 1, 2], [3]
 //           CHECK-SAME: tensor<3x3x4x16xf32> into tensor<36x16xf32>
-//      CHECK-DAG: %[[RESHAPED_OUTPUT:.+]] = linalg.tensor_reshape %[[OUTPUT]]
+//      CHECK-DAG: %[[RESHAPED_OUTPUT:.+]] = linalg.tensor_collapse_shape %[[OUTPUT]]
 //           CHECK-SAME: [0, 1, 2], [3]
 //      CHECK: %[[MATMUL_RESULT:.+]] = linalg.matmul ins(%[[RESHAPED_INIT_COL_TENSOR]], %[[RESHAPED_FILTER]] : tensor<196x36xf32>, tensor<36x16xf32>) outs(%[[RESHAPED_OUTPUT]] : tensor<196x16xf32>)
-//      CHECK: %[[RESULT:.+]] = linalg.tensor_reshape %[[MATMUL_RESULT]] {{\[}}[0, 1, 2], [3]] : tensor<196x16xf32> into tensor<1x14x14x16xf32>
+//      CHECK: %[[RESULT:.+]] = linalg.tensor_expand_shape %[[MATMUL_RESULT]] {{\[}}[0, 1, 2], [3]] : tensor<196x16xf32> into tensor<1x14x14x16xf32>
 //      CHECK: return %[[RESULT]]
diff --git a/iree/compiler/Conversion/LinalgToSPIRV/ConvertToGPUPass.cpp b/iree/compiler/Conversion/LinalgToSPIRV/ConvertToGPUPass.cpp
index d5d1837..36c496a 100644
--- a/iree/compiler/Conversion/LinalgToSPIRV/ConvertToGPUPass.cpp
+++ b/iree/compiler/Conversion/LinalgToSPIRV/ConvertToGPUPass.cpp
@@ -653,7 +653,7 @@
   // Reshape ops are treated legal since they just change the way the underlying
   // buffer is viewed. These are legalized downstream. They become no ops when
   // lowering to SPIR-V since the SPIR-V code uses linearized arrays.
-  target.addLegalOp<linalg::ReshapeOp>();
+  target.addLegalOp<linalg::CollapseShapeOp, linalg::ExpandShapeOp>();
   // Let the rest fall through.
   target.markUnknownOpDynamicallyLegal([](Operation *) { return true; });
 
diff --git a/iree/compiler/Conversion/LinalgToSPIRV/ConvertToSPIRVPass.cpp b/iree/compiler/Conversion/LinalgToSPIRV/ConvertToSPIRVPass.cpp
index b351349..130e1c6 100644
--- a/iree/compiler/Conversion/LinalgToSPIRV/ConvertToSPIRVPass.cpp
+++ b/iree/compiler/Conversion/LinalgToSPIRV/ConvertToSPIRVPass.cpp
@@ -313,9 +313,10 @@
   /// - tensor_to_memref can become a no-op since tensors are lowered to
   ///   !spv.array.
   /// - unrealized_conversion_cast with the same source and target type.
-  patterns
-      .insert<FoldAsNoOp<linalg::ReshapeOp>, FoldAsNoOp<memref::BufferCastOp>,
-              RemoveIdentityConversionCast>(typeConverter, context);
+  patterns.insert<
+      FoldAsNoOp<linalg::CollapseShapeOp>, FoldAsNoOp<linalg::ExpandShapeOp>,
+      FoldAsNoOp<memref::BufferCastOp>, RemoveIdentityConversionCast>(
+      typeConverter, context);
 
   std::unique_ptr<ConversionTarget> target =
       SPIRVConversionTarget::get(targetAttr);
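
FoldAsNoOp is now instantiated once per reshape op. Its definition lives elsewhere in this file; for orientation, a conversion pattern of this shape typically looks like the following sketch (an assumption about the implementation, using the dialect-conversion signature of this LLVM snapshot, not a verbatim copy):

    // Replace the op with its already-converted operands so it vanishes
    // during conversion; reshapes are no-ops once buffers are linearized.
    template <typename OpTy>
    struct FoldAsNoOp final : public OpConversionPattern<OpTy> {
      using OpConversionPattern<OpTy>::OpConversionPattern;
      LogicalResult matchAndRewrite(
          OpTy op, ArrayRef<Value> operands,
          ConversionPatternRewriter &rewriter) const override {
        rewriter.replaceOp(op, operands);
        return success();
      }
    };
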
diff --git a/iree/compiler/Conversion/LinalgToSPIRV/test/tile_and_vectorize_conv.mlir b/iree/compiler/Conversion/LinalgToSPIRV/test/tile_and_vectorize_conv.mlir
index 6d21f9c..d660072 100644
--- a/iree/compiler/Conversion/LinalgToSPIRV/test/tile_and_vectorize_conv.mlir
+++ b/iree/compiler/Conversion/LinalgToSPIRV/test/tile_and_vectorize_conv.mlir
@@ -109,8 +109,8 @@
         %0 = hal.interface.binding.subspan @io::@arg0[%c0] : memref<1x113x113x96xf32>
         %1 = hal.interface.binding.subspan @io::@arg1[%c0] : memref<3x3x1x96xf32>
         %2 = hal.interface.binding.subspan @io::@ret0[%c0] : memref<1x56x56x96xf32>
-        %3 = linalg.reshape %1 [[0, 1, 2, 3]] : memref<3x3x1x96xf32> into memref<864xf32>
-        %4 = linalg.reshape %3 [[0, 1, 2]] : memref<864xf32> into memref<3x3x96xf32>
+        %3 = linalg.collapse_shape %1 [[0, 1, 2, 3]] : memref<3x3x1x96xf32> into memref<864xf32>
+        %4 = linalg.expand_shape %3 [[0, 1, 2]] : memref<864xf32> into memref<3x3x96xf32>
         %workgroup_size_x = hal.interface.workgroup.size[0] : index
         %workgroup_size_y = hal.interface.workgroup.size[1] : index
         %workgroup_size_z = hal.interface.workgroup.size[2] : index
diff --git a/iree/compiler/Dialect/Flow/Transforms/ConvertToFlowTensorOps.cpp b/iree/compiler/Dialect/Flow/Transforms/ConvertToFlowTensorOps.cpp
index 3eb7143..6654559 100644
--- a/iree/compiler/Dialect/Flow/Transforms/ConvertToFlowTensorOps.cpp
+++ b/iree/compiler/Dialect/Flow/Transforms/ConvertToFlowTensorOps.cpp
@@ -130,13 +130,14 @@
 
 /// Converts linalg.tensor_reshape operations into flow.tensor.reshape
 /// operations.
+template <typename TensorReshapeOp>
 struct LinalgTensorReshapeToFlowTensorReshape
-    : public OpRewritePattern<linalg::TensorReshapeOp> {
-  using OpRewritePattern<linalg::TensorReshapeOp>::OpRewritePattern;
+    : public OpRewritePattern<TensorReshapeOp> {
+  using OpRewritePattern<TensorReshapeOp>::OpRewritePattern;
 
-  LogicalResult matchAndRewrite(linalg::TensorReshapeOp reshapeOp,
+  LogicalResult matchAndRewrite(TensorReshapeOp reshapeOp,
                                 PatternRewriter &rewriter) const override {
-    if (reshapeOp->getParentOfType<Flow::DispatchWorkgroupsOp>()) {
+    if (reshapeOp->template getParentOfType<Flow::DispatchWorkgroupsOp>()) {
       return failure();
     }
     SmallVector<SmallVector<Value>> outputShape;
@@ -263,9 +264,10 @@
     MLIRContext *context = funcOp->getContext();
     context->allowUnregisteredDialects(true);
     RewritePatternSet patterns(&getContext());
-    patterns.insert<LinalgTensorReshapeToFlowTensorReshape,
-                    SubTensorInsertToTensorUpdate, SubTensorToTensorSlice>(
-        context);
+    patterns.insert<
+        LinalgTensorReshapeToFlowTensorReshape<linalg::TensorCollapseShapeOp>,
+        LinalgTensorReshapeToFlowTensorReshape<linalg::TensorExpandShapeOp>,
+        SubTensorInsertToTensorUpdate, SubTensorToTensorSlice>(context);
     IREE::Flow::TensorReshapeOp::getCanonicalizationPatterns(patterns, context);
     if (failed(applyPatternsAndFoldGreedily(funcOp, std::move(patterns)))) {
       return signalPassFailure();
diff --git a/iree/compiler/Dialect/Flow/Transforms/DispatchLinalgOnTensors.cpp b/iree/compiler/Dialect/Flow/Transforms/DispatchLinalgOnTensors.cpp
index dea0b8b..f1a8875 100644
--- a/iree/compiler/Dialect/Flow/Transforms/DispatchLinalgOnTensors.cpp
+++ b/iree/compiler/Dialect/Flow/Transforms/DispatchLinalgOnTensors.cpp
@@ -261,7 +261,9 @@
 }
 
 static bool isAlwaysFusedIntoDispatchOp(Operation *op) {
-  return isDispatchableOp(op) && isa<linalg::TensorReshapeOp, SubTensorOp>(op);
+  return isDispatchableOp(op) &&
+         (isa<linalg::TensorCollapseShapeOp, SubTensorOp>(op) ||
+          isa<linalg::TensorExpandShapeOp, SubTensorOp>(op));
 }
 
 //===----------------------------------------------------------------------===//
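
llvm::isa accepts a variadic type list and succeeds if the operation matches any entry, so the two-clause disjunction above (and the similar one in LinalgBufferizePass.cpp) is equivalent to a single call; the repeated SubTensorOp/DispatchTensorStoreOp entries are redundant but harmless. An equivalent one-call form, as a sketch:

    static bool isAlwaysFusedIntoDispatchOp(Operation *op) {
      // isa<A, B, C>(op) is true if op is any of A, B, or C.
      return isDispatchableOp(op) &&
             isa<linalg::TensorCollapseShapeOp, linalg::TensorExpandShapeOp,
                 SubTensorOp>(op);
    }
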
diff --git a/iree/compiler/Dialect/Flow/Transforms/test/convert_to_flow_tensor_ops.mlir b/iree/compiler/Dialect/Flow/Transforms/test/convert_to_flow_tensor_ops.mlir
index 03b5ec5..3360b10 100644
--- a/iree/compiler/Dialect/Flow/Transforms/test/convert_to_flow_tensor_ops.mlir
+++ b/iree/compiler/Dialect/Flow/Transforms/test/convert_to_flow_tensor_ops.mlir
@@ -117,9 +117,9 @@
 func @tensor_reshape(%arg0 : tensor<?x4x?x5x?x6xf32>, %arg1 : tensor<20x?x40xf32>)
     -> (tensor<?x5x?xf32>, tensor<5x4x?x4x2x4x5xf32>)
 {
-  %0 = linalg.tensor_reshape %arg0 [[0, 1, 2], [3], [4, 5]]
+  %0 = linalg.tensor_collapse_shape %arg0 [[0, 1, 2], [3], [4, 5]]
       : tensor<?x4x?x5x?x6xf32> into tensor<?x5x?xf32>
-  %1 = linalg.tensor_reshape %arg1 [[0, 1], [2, 3], [4, 5, 6]]
+  %1 = linalg.tensor_expand_shape %arg1 [[0, 1], [2, 3], [4, 5, 6]]
       : tensor<20x?x40xf32> into tensor<5x4x?x4x2x4x5xf32>
   return %0, %1 : tensor<?x5x?xf32>, tensor<5x4x?x4x2x4x5xf32>
 }
diff --git a/iree/compiler/Dialect/Flow/Transforms/test/dispatch_linalg_on_tensors.mlir b/iree/compiler/Dialect/Flow/Transforms/test/dispatch_linalg_on_tensors.mlir
index 6266c07..ee0cdcd 100644
--- a/iree/compiler/Dialect/Flow/Transforms/test/dispatch_linalg_on_tensors.mlir
+++ b/iree/compiler/Dialect/Flow/Transforms/test/dispatch_linalg_on_tensors.mlir
@@ -212,7 +212,7 @@
 
 func @fuse_reshape_op(%arg0: tensor<?x?xf32>) -> tensor<?xf32>
 {
-  %0 = linalg.tensor_reshape %arg0 [[0, 1]] : tensor<?x?xf32> into tensor<?xf32>
+  %0 = linalg.tensor_collapse_shape %arg0 [[0, 1]] : tensor<?x?xf32> into tensor<?xf32>
   return %0 : tensor<?xf32>
 }
 //  CHECK-DAG: #[[MAP0:.+]] = affine_map<()[s0, s1] -> (s0 * s1)>
@@ -228,7 +228,7 @@
 // CHECK-NEXT:     %[[ARG1:.+]]: !flow.dispatch.tensor<readonly:?x?xf32>
 // CHECK-SAME:     %[[ARG2:.+]]: !flow.dispatch.tensor<writeonly:?xf32>
 //      CHECK:       %[[LOAD:.+]] = flow.dispatch.tensor.load %[[ARG1]], {{.*}}
-//      CHECK:       %[[RESHAPE:.+]] = linalg.tensor_reshape %[[LOAD]] {{\[}}[0, 1]]
+//      CHECK:       %[[RESHAPE:.+]] = linalg.tensor_collapse_shape %[[LOAD]] {{\[}}[0, 1]]
 //      CHECK:       flow.dispatch.tensor.store %[[RESHAPE]], %[[ARG2]], {{.*}}
 
 // -----
@@ -284,7 +284,7 @@
   %cst = constant 0.0 : f32
   %c0 = constant 0 : index
   %c1 = constant 1 : index
-  %0 = linalg.tensor_reshape %lhs [[0, 1]]
+  %0 = linalg.tensor_expand_shape %lhs [[0, 1]]
     : tensor<?xf32> into tensor<?x4xf32>
   %m = memref.dim %0, %c0 : tensor<?x4xf32>
   %n1 = memref.dim %rhs1, %c1 : tensor<4x?xf32>
@@ -554,14 +554,14 @@
 func @inline_dag_1(
     %arg0: tensor<?xf32>, %arg1 : tensor<1x?xf32>, %arg2 : tensor<i32>,
     %arg3 : index) -> tensor<?xf32> {
-  %0 = linalg.tensor_reshape %arg0 [[0, 1]] : tensor<?xf32> into tensor<1x?xf32>
+  %0 = linalg.tensor_expand_shape %arg0 [[0, 1]] : tensor<?xf32> into tensor<1x?xf32>
   %1 = subtensor %0[0, 20] [1, %arg3] [1, 1] : tensor<1x?xf32> to tensor<1x?xf32>
-  %2 = linalg.tensor_reshape %1 [[0, 1]] : tensor<1x?xf32> into tensor<?xf32>
-  %3 = linalg.tensor_reshape %arg1 [[0, 1]] : tensor<1x?xf32> into tensor<?xf32>
+  %2 = linalg.tensor_collapse_shape %1 [[0, 1]] : tensor<1x?xf32> into tensor<?xf32>
+  %3 = linalg.tensor_collapse_shape %arg1 [[0, 1]] : tensor<1x?xf32> into tensor<?xf32>
   %4 = subtensor %0[0, 10] [1, %arg3] [1, 1] : tensor<1x?xf32> to tensor<1x?xf32>
-  %5 = linalg.tensor_reshape %4 [[0, 1]] : tensor<1x?xf32> into tensor<?xf32>
+  %5 = linalg.tensor_collapse_shape %4 [[0, 1]] : tensor<1x?xf32> into tensor<?xf32>
   %6 = subtensor %0[0, 0] [1, %arg3] [1, 1] : tensor<1x?xf32> to tensor<1x?xf32>
-  %7 = linalg.tensor_reshape %6 [[0, 1]] : tensor<1x?xf32> into tensor<?xf32>
+  %7 = linalg.tensor_collapse_shape %6 [[0, 1]] : tensor<1x?xf32> into tensor<?xf32>
   %8 = linalg.init_tensor [%arg3] : tensor<?xf32>
   %9 = linalg.generic {
       indexing_maps = [#map0, #map1, #map1, #map1, #map1, #map1],
@@ -591,14 +591,14 @@
 //       CHECK:     %[[LEAF1:.+]] = flow.dispatch.tensor.load %[[ARG4]], {{.*}}
 //       CHECK:     %[[LEAF2:.+]] = flow.dispatch.tensor.load %[[ARG5]], {{.*}}
 //       CHECK:     %[[LEAF3:.+]] = flow.dispatch.tensor.load %[[ARG6]], {{.*}}
-//       CHECK:     %[[OP1:.+]] = linalg.tensor_reshape %[[LEAF2]]
-//       CHECK:     %[[OP2:.+]] = linalg.tensor_reshape %[[LEAF1]]
+//       CHECK:     %[[OP1:.+]] = linalg.tensor_expand_shape %[[LEAF2]]
+//       CHECK:     %[[OP2:.+]] = linalg.tensor_collapse_shape %[[LEAF1]]
 //       CHECK:     %[[OP3:.+]] = subtensor %[[OP1]][0, 0]
 //       CHECK:     %[[OP4:.+]] = subtensor %[[OP1]][0, 10]
 //       CHECK:     %[[OP5:.+]] = subtensor %[[OP1]][0, 20]
-//       CHECK:     %[[OP6:.+]] = linalg.tensor_reshape %[[OP3]]
-//       CHECK:     %[[OP7:.+]] = linalg.tensor_reshape %[[OP4]]
-//       CHECK:     %[[OP8:.+]] = linalg.tensor_reshape %[[OP5]]
+//       CHECK:     %[[OP6:.+]] = linalg.tensor_collapse_shape %[[OP3]]
+//       CHECK:     %[[OP7:.+]] = linalg.tensor_collapse_shape %[[OP4]]
+//       CHECK:     %[[OP8:.+]] = linalg.tensor_collapse_shape %[[OP5]]
 
 // -----
 
@@ -607,16 +607,16 @@
 func @inline_dag_2(
     %arg0: tensor<?xf32>, %arg1 : tensor<1x?xf32>, %arg2 : tensor<i32>,
     %arg3 : index) -> tensor<?xf32> {
-  %0 = linalg.tensor_reshape %arg0 [[0, 1]] : tensor<?xf32> into tensor<1x?xf32>
+  %0 = linalg.tensor_expand_shape %arg0 [[0, 1]] : tensor<?xf32> into tensor<1x?xf32>
   %1 = subtensor %0[0, 20] [1, %arg3] [1, 1] : tensor<1x?xf32> to tensor<1x?xf32>
-  %2 = linalg.tensor_reshape %arg1 [[0, 1]] : tensor<1x?xf32> into tensor<?xf32>
+  %2 = linalg.tensor_collapse_shape %arg1 [[0, 1]] : tensor<1x?xf32> into tensor<?xf32>
   br ^bb1
 ^bb1:
-  %3 = linalg.tensor_reshape %1 [[0, 1]] : tensor<1x?xf32> into tensor<?xf32>
+  %3 = linalg.tensor_collapse_shape %1 [[0, 1]] : tensor<1x?xf32> into tensor<?xf32>
   %4 = subtensor %0[0, 10] [1, %arg3] [1, 1] : tensor<1x?xf32> to tensor<1x?xf32>
-  %5 = linalg.tensor_reshape %4 [[0, 1]] : tensor<1x?xf32> into tensor<?xf32>
+  %5 = linalg.tensor_collapse_shape %4 [[0, 1]] : tensor<1x?xf32> into tensor<?xf32>
   %6 = subtensor %0[0, 0] [1, %arg3] [1, 1] : tensor<1x?xf32> to tensor<1x?xf32>
-  %7 = linalg.tensor_reshape %6 [[0, 1]] : tensor<1x?xf32> into tensor<?xf32>
+  %7 = linalg.tensor_collapse_shape %6 [[0, 1]] : tensor<1x?xf32> into tensor<?xf32>
   %8 = linalg.init_tensor [%arg3] : tensor<?xf32>
   %9 = linalg.generic {
       indexing_maps = [#map0, #map1, #map1, #map1, #map1, #map1],
@@ -646,14 +646,14 @@
 //       CHECK:     %[[LEAF1:.+]] = flow.dispatch.tensor.load %[[ARG4]], {{.*}}
 //       CHECK:     %[[LEAF2:.+]] = flow.dispatch.tensor.load %[[ARG5]], {{.*}}
 //       CHECK:     %[[LEAF3:.+]] = flow.dispatch.tensor.load %[[ARG6]], {{.*}}
-//       CHECK:     %[[OP1:.+]] = linalg.tensor_reshape %[[LEAF2]]
-//       CHECK:     %[[OP2:.+]] = linalg.tensor_reshape %[[LEAF1]]
+//       CHECK:     %[[OP1:.+]] = linalg.tensor_expand_shape %[[LEAF2]]
+//       CHECK:     %[[OP2:.+]] = linalg.tensor_collapse_shape %[[LEAF1]]
 //       CHECK:     %[[OP3:.+]] = subtensor %[[OP1]][0, 0]
 //       CHECK:     %[[OP4:.+]] = subtensor %[[OP1]][0, 10]
 //       CHECK:     %[[OP5:.+]] = subtensor %[[OP1]][0, 20]
-//       CHECK:     %[[OP6:.+]] = linalg.tensor_reshape %[[OP3]]
-//       CHECK:     %[[OP7:.+]] = linalg.tensor_reshape %[[OP4]]
-//       CHECK:     %[[OP8:.+]] = linalg.tensor_reshape %[[OP5]]
+//       CHECK:     %[[OP6:.+]] = linalg.tensor_collapse_shape %[[OP3]]
+//       CHECK:     %[[OP7:.+]] = linalg.tensor_collapse_shape %[[OP4]]
+//       CHECK:     %[[OP8:.+]] = linalg.tensor_collapse_shape %[[OP5]]
 
 // -----
 
diff --git a/iree/compiler/Dialect/Modules/VMVX/Conversion/StandardToVMVX/ConvertStandardToVMVX.cpp b/iree/compiler/Dialect/Modules/VMVX/Conversion/StandardToVMVX/ConvertStandardToVMVX.cpp
index 5ec9871..96e8fe1 100644
--- a/iree/compiler/Dialect/Modules/VMVX/Conversion/StandardToVMVX/ConvertStandardToVMVX.cpp
+++ b/iree/compiler/Dialect/Modules/VMVX/Conversion/StandardToVMVX/ConvertStandardToVMVX.cpp
@@ -68,7 +68,8 @@
                                     OwningRewritePatternList &patterns,
                                     TypeConverter &typeConverter) {
   // We type/shape erase memrefs as we lower so there is no need for reshapes.
-  patterns.insert<FoldAsNoOp<linalg::ReshapeOp>>(typeConverter, context);
+  patterns.insert<FoldAsNoOp<linalg::CollapseShapeOp>>(typeConverter, context);
+  patterns.insert<FoldAsNoOp<linalg::ExpandShapeOp>>(typeConverter, context);
 
   patterns.insert<RemoveIdentityConversionCast>(typeConverter, context);
 }
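
patterns.insert<> is variadic over pattern types that share constructor arguments (the ConvertToSPIRVPass hunk above uses exactly that form), so the two registrations here could equally be one call. A style note rather than a fix:

    patterns.insert<FoldAsNoOp<linalg::CollapseShapeOp>,
                    FoldAsNoOp<linalg::ExpandShapeOp>,
                    RemoveIdentityConversionCast>(typeConverter, context);
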
diff --git a/third_party/llvm-bazel b/third_party/llvm-bazel
index 528af9c..12a96e8 160000
--- a/third_party/llvm-bazel
+++ b/third_party/llvm-bazel
@@ -1 +1 @@
-Subproject commit 528af9c01dc4080331184ece3ced3be6beaad47f
+Subproject commit 12a96e87d2eb377b7884aae56773bca50f12e7c7
diff --git a/third_party/llvm-project b/third_party/llvm-project
index c89dff5..b109172 160000
--- a/third_party/llvm-project
+++ b/third_party/llvm-project
@@ -1 +1 @@
-Subproject commit c89dff5855bb32d47751cce087537c2b12a90f1b
+Subproject commit b109172d993edacd9853a8bbb8128a94da014399
diff --git a/third_party/mlir-hlo b/third_party/mlir-hlo
index fe42a08..8b3a75e 160000
--- a/third_party/mlir-hlo
+++ b/third_party/mlir-hlo
@@ -1 +1 @@
-Subproject commit fe42a08fc93830f06150b93bcb764e2386ca3110
+Subproject commit 8b3a75ea25ceca9070938ced4b5909042765ea0c
diff --git a/third_party/tensorflow b/third_party/tensorflow
index 8be2640..1d497e6 160000
--- a/third_party/tensorflow
+++ b/third_party/tensorflow
@@ -1 +1 @@
-Subproject commit 8be2640d44c8b4012e7863c93eea3ccb14f75151
+Subproject commit 1d497e6419ca142acd188632c228c11d7c074a34