Switch to using tensor.pack/unpack ops for data-tiling (#12247)

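This switches the data-tiling path from the IREE::LinalgExt pack/unpack ops
to the upstream tensor.pack/tensor.unpack ops:

- MaterializeEncoding now lowers set_encoding/unset_encoding to
  tensor.pack/tensor.unpack, using tensor::PackOp::inferPackedType and
  tensor::PackOp::getResultShape to compute packed types and shapes.
- TileAndDistributeToWorkgroups reads materialization info from tensor.pack
  root ops (getSourceRank/getSourceType instead of the LinalgExt accessors).
- ConvertToDestinationPassingStyle gains a second walk that handles
  tensor.unpack; the IREE::LinalgExt walk is kept for now, with a TODO to
  retire it.
- The fold-into-pack/unpack patterns now come from
  tensor::populateFoldIntoPackAndUnpackPatterns, which adds
  MLIRTensorTransforms build dependencies to LLVMCPU and VMVX.
- In the CPU data-tiling pipeline, VectorizePackUnPackOps now runs before
  LinalgExtVectorization, which also drops a redundant vector.broadcast in
  the elem_pack pipeline test.

For reference, the textual form of the ops changes as sketched below (example
lifted from the updated data_tiling_pipeline test); note that discardable
attributes such as lowering_config now print in the trailing attribute
dictionary rather than before the operand:

  // Before (IREE::LinalgExt op):
  %6 = iree_linalg_ext.pack %4 inner_dims_pos = [0, 1] inner_tiles = [8, 1] into %5
      : (tensor<128x384xf32> tensor<16x384x8x1xf32>) -> tensor<16x384x8x1xf32>

  // After (upstream tensor op):
  %6 = tensor.pack %4 inner_dims_pos = [0, 1] inner_tiles = [8, 1] into %5
      : tensor<128x384xf32> -> tensor<16x384x8x1xf32>
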
benchmarks: comp-stats
diff --git a/compiler/src/iree/compiler/Codegen/Common/ConvertToDestinationPassingStylePass.cpp b/compiler/src/iree/compiler/Codegen/Common/ConvertToDestinationPassingStylePass.cpp
index 2d80fbc..affad4d 100644
--- a/compiler/src/iree/compiler/Codegen/Common/ConvertToDestinationPassingStylePass.cpp
+++ b/compiler/src/iree/compiler/Codegen/Common/ConvertToDestinationPassingStylePass.cpp
@@ -471,6 +471,7 @@
 /// of tensor.unpack op for more details.
 static LogicalResult replaceUnpackEmptyWithAllocTensor(OpBuilder &b,
                                                        func::FuncOp funcOp) {
+  // TODO(hanchung): retire IREE::LinalgExt version.
   funcOp.walk([&](IREE::LinalgExt::UnPackOp unpackOp) {
     if (!unpackOp->hasOneUse() ||
         !isa<tensor::ExtractSliceOp>(*(unpackOp->user_begin()))) {
@@ -485,6 +486,22 @@
         emptyOp.getLoc(), emptyOp.getType(), emptyOp.getDynamicSizes());
     emptyOp.replaceAllUsesWith(allocTensor.getResult());
   });
+
+  funcOp.walk([&](tensor::UnPackOp unpackOp) {
+    if (!unpackOp->hasOneUse() ||
+        !isa<tensor::ExtractSliceOp>(*(unpackOp->user_begin()))) {
+      return;
+    }
+    auto emptyOp = unpackOp.getDest().getDefiningOp<tensor::EmptyOp>();
+    if (!emptyOp) return;
+
+    OpBuilder::InsertionGuard g(b);
+    b.setInsertionPointAfter(emptyOp);
+    auto allocTensor = b.create<bufferization::AllocTensorOp>(
+        emptyOp.getLoc(), emptyOp.getType(), emptyOp.getDynamicSizes());
+    emptyOp.replaceAllUsesWith(allocTensor.getResult());
+  });
+
   return success();
 }
 
diff --git a/compiler/src/iree/compiler/Codegen/Common/MaterializeEncodingIntoPackUnPack.cpp b/compiler/src/iree/compiler/Codegen/Common/MaterializeEncodingIntoPackUnPack.cpp
index 4ec5a6e..289f7a5 100644
--- a/compiler/src/iree/compiler/Codegen/Common/MaterializeEncodingIntoPackUnPack.cpp
+++ b/compiler/src/iree/compiler/Codegen/Common/MaterializeEncodingIntoPackUnPack.cpp
@@ -56,9 +56,10 @@
   auto innerTileSizes = getInnerTileSizesOfr(
       builder, loc, boundTensorType, *encodingInfo, materializeEncodingValueFn);
   if (failed(innerTileSizes)) return failure();
-  SmallVector<OpFoldResult> convertedTargetShape = PackOp::getResultShape(
-      builder, loc, targetShape, *innerTileSizes, encodingInfo->innerDimsPos,
-      encodingInfo->outerDimsPerm);
+  SmallVector<OpFoldResult> convertedTargetShape =
+      tensor::PackOp::getResultShape(builder, loc, targetShape, *innerTileSizes,
+                                     encodingInfo->innerDimsPos,
+                                     encodingInfo->outerDimsPerm);
   return convertedTargetShape;
 }
 
diff --git a/compiler/src/iree/compiler/Codegen/Common/TileAndDistributeToWorkgroupsPass.cpp b/compiler/src/iree/compiler/Codegen/Common/TileAndDistributeToWorkgroupsPass.cpp
index 47df80d..a410923 100644
--- a/compiler/src/iree/compiler/Codegen/Common/TileAndDistributeToWorkgroupsPass.cpp
+++ b/compiler/src/iree/compiler/Codegen/Common/TileAndDistributeToWorkgroupsPass.cpp
@@ -112,9 +112,9 @@
   return success();
 }
 
-/// Get the materialization information from a `iree_linalg_ext.pack` operation.
+/// Get the materialization information from a `tensor.pack` operation.
 static FailureOr<IREE::LinalgExt::MaterializeEncodingInfo>
-getMaterializationInfo(IREE::LinalgExt::PackOp packOp) {
+getMaterializationInfo(tensor::PackOp packOp) {
   IREE::LinalgExt::MaterializeEncodingInfo encodingInfo;
   SmallVector<OpFoldResult> mixedTileSizes = packOp.getMixedTiles();
   encodingInfo.innerTileSizes.reserve(mixedTileSizes.size());
@@ -128,7 +128,7 @@
   }
   encodingInfo.innerDimsPos = llvm::to_vector(packOp.getInnerDimsPos());
   encodingInfo.outerDimsPerm = llvm::to_vector(packOp.getOuterDimsPerm());
-  encodingInfo.srcRank = packOp.getInputRank();
+  encodingInfo.srcRank = packOp.getSourceRank();
   return encodingInfo;
 }
 
@@ -277,11 +277,10 @@
     auto innerTileSizes = getInnerTileSizesOfr(rewriter, loc, inputType,
                                                materializeEncodingInfo, {});
     if (failed(innerTileSizes)) return failure();
-    SmallVector<OpFoldResult> resultShape =
-        IREE::LinalgExt::PackOp::getResultShape(
-            rewriter, loc, getAsOpFoldResult(workload), *innerTileSizes,
-            materializeEncodingInfo.innerDimsPos,
-            materializeEncodingInfo.outerDimsPerm);
+    SmallVector<OpFoldResult> resultShape = tensor::PackOp::getResultShape(
+        rewriter, loc, getAsOpFoldResult(workload), *innerTileSizes,
+        materializeEncodingInfo.innerDimsPos,
+        materializeEncodingInfo.outerDimsPerm);
     resultShape.resize(materializeEncodingInfo.srcRank);
 
     rewriter
@@ -351,14 +350,13 @@
       patterns.insert<LowerDispatchWorkgroupCountForDagRootOp>(
           context, tileSizes, staticLoopRanges, interchange,
           partitionableLoops);
-      if (auto packRootOp =
-              dyn_cast_or_null<IREE::LinalgExt::PackOp>(dispatchRootOp)) {
+      if (auto packRootOp = dyn_cast_or_null<tensor::PackOp>(dispatchRootOp)) {
         FailureOr<IREE::LinalgExt::MaterializeEncodingInfo> encodingInfo =
             getMaterializationInfo(packRootOp);
         if (failed(encodingInfo)) {
           return signalPassFailure();
         }
-        auto tensorType = packRootOp.getInputType().cast<RankedTensorType>();
+        auto tensorType = packRootOp.getSourceType();
         // The LowerDispatchWorkgroupCountFromSetEncodingOp pattern is going to
         // call materializeEncodingValueFn, passing it a tensor type, expecting
         // that tensor type to have a TensorEncodingAttr. The problem is that
diff --git a/compiler/src/iree/compiler/Codegen/Common/test/convert_to_destination_passing_style.mlir b/compiler/src/iree/compiler/Codegen/Common/test/convert_to_destination_passing_style.mlir
index c78f23b..2ac2eb1 100644
--- a/compiler/src/iree/compiler/Codegen/Common/test/convert_to_destination_passing_style.mlir
+++ b/compiler/src/iree/compiler/Codegen/Common/test/convert_to_destination_passing_style.mlir
@@ -821,7 +821,7 @@
       %16 = affine.apply affine_map<(d0)[s0] -> (d0 floordiv s0)>(%arg1)[%8#1]
       %17 = flow.dispatch.tensor.load %3, offsets = [%15, %16, 0, 0], sizes = [%c1, %c1, %8#0, %8#1], strides = [1, 1, 1, 1] : !flow.dispatch.tensor<readonly:tensor<?x?x?x?xi32>>{%6, %7, %5#0, %5#1} -> tensor<?x?x?x?xi32>
       %18 = tensor.empty(%8#0, %8#1) : tensor<?x?xi32>
-      %19 = iree_linalg_ext.unpack %17 inner_dims_pos = [0, 1] inner_tiles = [%8#0, %8#1] into %18 : (tensor<?x?x?x?xi32> tensor<?x?xi32>) -> tensor<?x?xi32>
+      %19 = tensor.unpack %17 inner_dims_pos = [0, 1] inner_tiles = [%8#0, %8#1] into %18 : tensor<?x?x?x?xi32> -> tensor<?x?xi32>
       %extracted_slice = tensor.extract_slice %19[%13, %14] [1, 1] [1, 1] : tensor<?x?xi32> to tensor<1x1xi32>
       %cast = tensor.cast %extracted_slice : tensor<1x1xi32> to tensor<?x?xi32>
       flow.dispatch.tensor.store %cast, %4, offsets = [%arg0, %arg1], sizes = [%c1, %c1], strides = [1, 1] : tensor<?x?xi32> -> !flow.dispatch.tensor<writeonly:tensor<1x1xi32>>
@@ -831,7 +831,7 @@
 }
 // CHECK-LABEL: func.func @non_perfect_tiling_unpack
 // CHECK:         %[[ALLOC:.+]] = bufferization.alloc_tensor
-// CHECK:         %[[UNPACK:.+]] = iree_linalg_ext.unpack
+// CHECK:         %[[UNPACK:.+]] = tensor.unpack
 // CHECK-SAME:      into %[[ALLOC]]
 // CHECK:         %[[SLICE:.+]] = tensor.extract_slice %[[UNPACK]]
 
diff --git a/compiler/src/iree/compiler/Codegen/Common/test/tile_and_distribute_to_workgroups.mlir b/compiler/src/iree/compiler/Codegen/Common/test/tile_and_distribute_to_workgroups.mlir
index 68117fd..dee7671 100644
--- a/compiler/src/iree/compiler/Codegen/Common/test/tile_and_distribute_to_workgroups.mlir
+++ b/compiler/src/iree/compiler/Codegen/Common/test/tile_and_distribute_to_workgroups.mlir
@@ -1915,9 +1915,9 @@
         %2 = flow.dispatch.tensor.load %0, offsets = [0, 0], sizes = [100, 250], strides = [1, 1]
             : !flow.dispatch.tensor<readonly:tensor<100x250xf32>> -> tensor<100x250xf32>
         %3 = tensor.empty() : tensor<14x64x8x4xf32>
-        %4 = iree_linalg_ext.pack {lowering_config = #iree_codegen.lowering_config<tile_sizes = [[64, 64]]>} %2
-            padding_value(%cst : f32) inner_dims_pos = [0, 1] inner_tiles = [8, 4] into %3
-            : (tensor<100x250xf32> tensor<14x64x8x4xf32>) -> tensor<14x64x8x4xf32>
+        %4 = tensor.pack %2 padding_value(%cst : f32) inner_dims_pos = [0, 1] inner_tiles = [8, 4] into %3
+            {lowering_config = #iree_codegen.lowering_config<tile_sizes = [[64, 64]]>}
+            : tensor<100x250xf32> -> tensor<14x64x8x4xf32>
         flow.dispatch.tensor.store %4, %1, offsets = [0, 0, 0, 0], sizes = [14, 64, 8, 4], strides = [1, 1, 1, 1]
             : tensor<14x64x8x4xf32> -> !flow.dispatch.tensor<writeonly:tensor<14x64x8x4xf32>>
         return
@@ -1960,9 +1960,9 @@
         %2 = flow.dispatch.tensor.load %0, offsets = [0, 0], sizes = [250, 500], strides = [1, 1]
             : !flow.dispatch.tensor<readonly:tensor<250x500xf32>> -> tensor<250x500xf32>
         %3 = tensor.empty() : tensor<64x64x8x4xf32>
-        %4 = iree_linalg_ext.pack {lowering_config = #iree_codegen.lowering_config<tile_sizes = [[64, 64]]>} %2
-            padding_value(%cst : f32) outer_dims_perm = [1, 0] inner_dims_pos = [1, 0] inner_tiles = [8, 4] into %3
-            : (tensor<250x500xf32> tensor<64x64x8x4xf32>) -> tensor<64x64x8x4xf32>
+        %4 = tensor.pack %2 padding_value(%cst : f32) outer_dims_perm = [1, 0] inner_dims_pos = [1, 0] inner_tiles = [8, 4] into %3
+            {lowering_config = #iree_codegen.lowering_config<tile_sizes = [[64, 64]]>}
+            : tensor<250x500xf32> -> tensor<64x64x8x4xf32>
         flow.dispatch.tensor.store %4, %1, offsets = [0, 0, 0, 0], sizes = [64, 64, 8, 4], strides = [1, 1, 1, 1]
             : tensor<64x64x8x4xf32> -> !flow.dispatch.tensor<writeonly:tensor<64x64x8x4xf32>>
         return
@@ -2018,9 +2018,9 @@
         %15 = affine.apply affine_map<()[s0] -> (s0 ceildiv 8)>()[%13]
         %16 = affine.apply affine_map<()[s0] -> (s0 ceildiv 4)>()[%14]
         %17 = tensor.empty(%15, %16) : tensor<?x?x8x4xf32>
-        %18 = iree_linalg_ext.pack {lowering_config = #iree_codegen.lowering_config<tile_sizes = [[64, 64]]>} %12
-            padding_value(%cst : f32) inner_dims_pos = [0, 1] inner_tiles = [8, 4] into %17
-            : (tensor<?x?xf32> tensor<?x?x8x4xf32>) -> tensor<?x?x8x4xf32>
+        %18 = tensor.pack %12 padding_value(%cst : f32) inner_dims_pos = [0, 1] inner_tiles = [8, 4] into %17
+            {lowering_config = #iree_codegen.lowering_config<tile_sizes = [[64, 64]]>}
+            : tensor<?x?xf32> -> tensor<?x?x8x4xf32>
         %19 = affine.apply affine_map<()[s0] -> (s0 ceildiv 8)>()[%6]
         %20 = affine.apply affine_map<()[s0] -> (s0 ceildiv 4)>()[%7]
         flow.dispatch.tensor.store %18, %11, offsets = [0, 0, 0, 0], sizes = [%19, %20, 8, 4], strides = [1, 1, 1, 1]
@@ -2070,8 +2070,9 @@
         %9 = hal.interface.binding.subspan set(0) binding(1) type(storage_buffer) alignment(64) offset(%c131072) : !flow.dispatch.tensor<writeonly:tensor<?x?xi32>>{%6, %7}
         %10 = flow.dispatch.tensor.load %8, offsets = [0, 0, 0, 0], sizes = [%4, %5, 32, 16], strides = [1, 1, 1, 1] : !flow.dispatch.tensor<readonly:tensor<?x?x32x16xi32>>{%4, %5} -> tensor<?x?x32x16xi32>
         %11 = tensor.empty(%6, %7) : tensor<?x?xi32>
-        %12 = iree_linalg_ext.unpack {lowering_config = #iree_codegen.lowering_config<tile_sizes = [[64, 64]]>}
-          %10 inner_dims_pos = [0, 1] inner_tiles = [32, 16] into %11 : (tensor<?x?x32x16xi32> tensor<?x?xi32>) -> tensor<?x?xi32>
+        %12 = tensor.unpack %10 inner_dims_pos = [0, 1] inner_tiles = [32, 16] into %11
+          {lowering_config = #iree_codegen.lowering_config<tile_sizes = [[64, 64]]>}
+          : tensor<?x?x32x16xi32> -> tensor<?x?xi32>
         flow.dispatch.tensor.store %12, %9, offsets = [0, 0], sizes = [%6, %7], strides = [1, 1] : tensor<?x?xi32> -> !flow.dispatch.tensor<writeonly:tensor<?x?xi32>>{%6, %7}
         return
       }
@@ -2081,7 +2082,7 @@
 // CHECK-LABEL: func.func @dynamic_unpack
 // CHECK:         scf.for
 // CHECK:           scf.for
-// CHECK:             iree_linalg_ext.unpack
+// CHECK:             tensor.unpack
 
 // -----
 
@@ -2115,8 +2116,9 @@
         %9 = hal.interface.binding.subspan set(0) binding(1) type(storage_buffer) alignment(64) offset(%c131072) : !flow.dispatch.tensor<writeonly:tensor<?x?xi32>>{%6, %7}
         %10 = flow.dispatch.tensor.load %8, offsets = [0, 0, 0, 0], sizes = [%4, %5, %c32, %c16], strides = [1, 1, 1, 1] : !flow.dispatch.tensor<readonly:tensor<?x?x?x?xi32>>{%4, %5, %c32, %c16} -> tensor<?x?x?x?xi32>
         %11 = tensor.empty(%6, %7) : tensor<?x?xi32>
-        %12 = iree_linalg_ext.unpack {lowering_config = #iree_codegen.lowering_config<tile_sizes = [[64, 64]]>}
-          %10 inner_dims_pos = [0, 1] inner_tiles = [%c32, %c16] into %11 : (tensor<?x?x?x?xi32> tensor<?x?xi32>) -> tensor<?x?xi32>
+        %12 = tensor.unpack %10 inner_dims_pos = [0, 1] inner_tiles = [%c32, %c16] into %11
+          {lowering_config = #iree_codegen.lowering_config<tile_sizes = [[64, 64]]>}
+          : tensor<?x?x?x?xi32> -> tensor<?x?xi32>
         flow.dispatch.tensor.store %12, %9, offsets = [0, 0], sizes = [%6, %7], strides = [1, 1] : tensor<?x?xi32> -> !flow.dispatch.tensor<writeonly:tensor<?x?xi32>>{%6, %7}
         return
       }
@@ -2126,7 +2128,7 @@
 // CHECK-LABEL: func.func @dynamic_unpack_dynamic_tile
 // CHECK:         scf.for
 // CHECK:           scf.for
-// CHECK:             iree_linalg_ext.unpack
+// CHECK:             tensor.unpack
 
 // -----
 
@@ -2149,7 +2151,7 @@
         %1 = hal.interface.binding.subspan set(0) binding(1) type(storage_buffer) alignment(64) offset(%c0) : !flow.dispatch.tensor<writeonly:tensor<128x384xf32>>
         %2 = flow.dispatch.tensor.load %0, offsets = [0, 0, 0, 0], sizes = [16, 48, 8, 8], strides = [1, 1, 1, 1] : !flow.dispatch.tensor<readonly:tensor<16x48x8x8xf32>> -> tensor<16x48x8x8xf32>
         %3 = tensor.empty() : tensor<128x384xf32>
-        %4 = iree_linalg_ext.unpack {lowering_config = #iree_codegen.lowering_config<tile_sizes = [[64, 64]]>} %2 inner_dims_pos = [0, 1] inner_tiles = [8, 8] into %3 : (tensor<16x48x8x8xf32> tensor<128x384xf32>) -> tensor<128x384xf32>
+        %4 = tensor.unpack %2 inner_dims_pos = [0, 1] inner_tiles = [8, 8] into %3 {lowering_config = #iree_codegen.lowering_config<tile_sizes = [[64, 64]]>} : tensor<16x48x8x8xf32> -> tensor<128x384xf32>
         %5 = linalg.generic {indexing_maps = [affine_map<(d0, d1) -> (d0, d1)>, affine_map<(d0, d1) -> (d0, d1)>], iterator_types = ["parallel", "parallel"]} ins(%4 : tensor<128x384xf32>) outs(%3 : tensor<128x384xf32>) {
         ^bb0(%in: f32, %out: f32):
           %6 = arith.addf %in, %in : f32
@@ -2165,7 +2167,7 @@
 // CHECK-LABEL: func.func @unpack_elem
 // CHECK:         scf.for
 // CHECK:           scf.for
-// CHECK:             iree_linalg_ext.unpack
+// CHECK:             tensor.unpack
 // CHECK:             linalg.generic
 
 // -----
@@ -2205,7 +2207,7 @@
         %13 = tensor.empty() : tensor<12544x16xi32>
         %14 = tensor.empty() : tensor<12544x16xi32>
         %15:2 = vmvx.query_tile_sizes sizes(%c12544, %c16) flags(1245184) -> index, index
-        %16 = iree_linalg_ext.unpack {lowering_config = #iree_codegen.lowering_config<tile_sizes = [[16, 16]]>} %10 inner_dims_pos = [0, 1] inner_tiles = [%15#0, %15#1] into %14 : (tensor<?x?x?x?xi32> tensor<12544x16xi32>) -> tensor<12544x16xi32>
+        %16 = tensor.unpack %10 inner_dims_pos = [0, 1] inner_tiles = [%15#0, %15#1] into %14 {lowering_config = #iree_codegen.lowering_config<tile_sizes = [[16, 16]]>} : tensor<?x?x?x?xi32> -> tensor<12544x16xi32>
         %17 = linalg.generic {indexing_maps = [affine_map<(d0, d1) -> (d1)>, affine_map<(d0, d1) -> (d0, d1)>, affine_map<(d0, d1) -> (d0)>, affine_map<(d0, d1) -> (d1)>, affine_map<(d0, d1) -> (d0, d1)>], iterator_types = ["parallel", "parallel"]} ins(%cst, %16, %11, %12 : tensor<16xi32>, tensor<12544x16xi32>, tensor<12544xi32>, tensor<16xi32>) outs(%13 : tensor<12544x16xi32>) {
         ^bb0(%in: i32, %in_0: i32, %in_1: i32, %in_2: i32, %out: i32):
           %18 = arith.muli %in_1, %c-30_i32 : i32
@@ -2225,7 +2227,7 @@
 // CHECK-LABEL: func.func @dynamic_unpack_fusion
 // CHECK:         scf.for
 // CHECK:           scf.for
-// CHECK:             iree_linalg_ext.unpack
+// CHECK:             tensor.unpack
 // CHECK:             tensor.extract_slice
 // CHECK:             linalg.generic
 
@@ -2273,7 +2275,7 @@
           linalg.yield %23, %25 : f32, f32
         } -> (tensor<384x512xf32>, tensor<384x512xf32>)
         %17 = tensor.empty() : tensor<48x512x8x1xf32>
-        %18 = iree_linalg_ext.pack {encoding = #iree_linalg_ext.encoding<MATMUL_F32F32F32_LHS>, lowering_config = #iree_codegen.lowering_config<tile_sizes = [[8, 64]]>} %16#0 inner_dims_pos = [0, 1] inner_tiles = [8, 1] into %17 : (tensor<384x512xf32> tensor<48x512x8x1xf32>) -> tensor<48x512x8x1xf32>
+        %18 = tensor.pack %16#0 inner_dims_pos = [0, 1] inner_tiles = [8, 1] into %17 {lowering_config = #iree_codegen.lowering_config<tile_sizes = [[8, 64]]>} : tensor<384x512xf32> -> tensor<48x512x8x1xf32>
         flow.dispatch.tensor.store %18, %6, offsets = [0, 0, 0, 0], sizes = [48, 512, 8, 1], strides = [1, 1, 1, 1] : tensor<48x512x8x1xf32> -> !flow.dispatch.tensor<writeonly:tensor<48x512x8x1xf32>>
         flow.dispatch.tensor.store %16#0, %7, offsets = [0, 0], sizes = [384, 512], strides = [1, 1] : tensor<384x512xf32> -> !flow.dispatch.tensor<writeonly:tensor<384x512xf32>>
         flow.dispatch.tensor.store %16#1, %8, offsets = [0, 0], sizes = [384, 512], strides = [1, 1] : tensor<384x512xf32> -> !flow.dispatch.tensor<writeonly:tensor<384x512xf32>>
@@ -2286,7 +2288,7 @@
 // CHECK:         scf.for
 // CHECK:           scf.for
 // CHECK:             %[[ELEM:.+]]:2 = linalg.generic
-// CHECK:             %[[PACK:.+]] = iree_linalg_ext.pack
+// CHECK:             %[[PACK:.+]] = tensor.pack
 // CHECK-DAG:         flow.dispatch.tensor.store %[[PACK]], {{.*}} sizes = [8, 64, 8, 1]
 // CHECK-DAG:         flow.dispatch.tensor.store %[[ELEM]]#0, {{.*}} sizes = [64, 64]
 // CHECK-DAG:         flow.dispatch.tensor.store %[[ELEM]]#1, {{.*}} sizes = [64, 64]
diff --git a/compiler/src/iree/compiler/Codegen/LLVMCPU/BUILD b/compiler/src/iree/compiler/Codegen/LLVMCPU/BUILD
index 5efb99a..2be2eba 100644
--- a/compiler/src/iree/compiler/Codegen/LLVMCPU/BUILD
+++ b/compiler/src/iree/compiler/Codegen/LLVMCPU/BUILD
@@ -102,6 +102,7 @@
         "@llvm-project//mlir:SCFToControlFlow",
         "@llvm-project//mlir:SCFTransforms",
         "@llvm-project//mlir:TensorDialect",
+        "@llvm-project//mlir:TensorTransforms",
         "@llvm-project//mlir:TosaDialect",
         "@llvm-project//mlir:TosaToArith",
         "@llvm-project//mlir:TransformDialect",
diff --git a/compiler/src/iree/compiler/Codegen/LLVMCPU/CMakeLists.txt b/compiler/src/iree/compiler/Codegen/LLVMCPU/CMakeLists.txt
index 91a49c1..8375bfa 100644
--- a/compiler/src/iree/compiler/Codegen/LLVMCPU/CMakeLists.txt
+++ b/compiler/src/iree/compiler/Codegen/LLVMCPU/CMakeLists.txt
@@ -80,6 +80,7 @@
     MLIRSCFToControlFlow
     MLIRSCFTransforms
     MLIRTensorDialect
+    MLIRTensorTransforms
     MLIRTosaDialect
     MLIRTosaToArith
     MLIRTransformDialect
diff --git a/compiler/src/iree/compiler/Codegen/LLVMCPU/LLVMCPUMaterializeEncodingPass.cpp b/compiler/src/iree/compiler/Codegen/LLVMCPU/LLVMCPUMaterializeEncodingPass.cpp
index 23ca04c..21152ea 100644
--- a/compiler/src/iree/compiler/Codegen/LLVMCPU/LLVMCPUMaterializeEncodingPass.cpp
+++ b/compiler/src/iree/compiler/Codegen/LLVMCPU/LLVMCPUMaterializeEncodingPass.cpp
@@ -14,6 +14,7 @@
 #include "mlir/Dialect/Affine/IR/AffineOps.h"
 #include "mlir/Dialect/Arith/IR/Arith.h"
 #include "mlir/Dialect/MemRef/Transforms/Passes.h"
+#include "mlir/Dialect/Tensor/Transforms/Transforms.h"
 #include "mlir/Transforms/DialectConversion.h"
 #include "mlir/Transforms/GreedyPatternRewriteDriver.h"
 
@@ -144,7 +145,7 @@
   // dims ops.
   {
     RewritePatternSet patterns(context);
-    populateFoldIntoPackAndUnpackOpsPatterns(patterns);
+    tensor::populateFoldIntoPackAndUnpackPatterns(patterns);
     memref::populateResolveRankedShapeTypeResultDimsPatterns(patterns);
     if (failed(applyPatternsAndFoldGreedily(operation, std::move(patterns)))) {
       operation.emitOpError("folding patterns failed");
diff --git a/compiler/src/iree/compiler/Codegen/LLVMCPU/Passes.cpp b/compiler/src/iree/compiler/Codegen/LLVMCPU/Passes.cpp
index b5c41f5..341a56a 100644
--- a/compiler/src/iree/compiler/Codegen/LLVMCPU/Passes.cpp
+++ b/compiler/src/iree/compiler/Codegen/LLVMCPU/Passes.cpp
@@ -670,9 +670,9 @@
   addTileAndDistributePasses(passManager);
   OpPassManager &nestedModulePM = passManager.nest<ModuleOp>();
   nestedModulePM.addNestedPass<func::FuncOp>(
-      IREE::LinalgExt::createLinalgExtVectorizationPass());
-  nestedModulePM.addNestedPass<func::FuncOp>(
       createVectorizePackUnPackOpsPass());
+  nestedModulePM.addNestedPass<func::FuncOp>(
+      IREE::LinalgExt::createLinalgExtVectorizationPass());
   addBufferizePasses(nestedModulePM);
   nestedModulePM.addNestedPass<func::FuncOp>(
       createSplitFullPartialTransferPass("linalg-copy"));
diff --git a/compiler/src/iree/compiler/Codegen/LLVMCPU/test/data_tiling_pipeline.mlir b/compiler/src/iree/compiler/Codegen/LLVMCPU/test/data_tiling_pipeline.mlir
index 38812ee..884790a 100644
--- a/compiler/src/iree/compiler/Codegen/LLVMCPU/test/data_tiling_pipeline.mlir
+++ b/compiler/src/iree/compiler/Codegen/LLVMCPU/test/data_tiling_pipeline.mlir
@@ -22,7 +22,7 @@
           linalg.yield %7 : f32
         } -> tensor<128x384xf32>
         %5 = tensor.empty() : tensor<16x384x8x1xf32>
-        %6 = iree_linalg_ext.pack %4 inner_dims_pos = [0, 1] inner_tiles = [8, 1] into %5 : (tensor<128x384xf32> tensor<16x384x8x1xf32>) -> tensor<16x384x8x1xf32>
+        %6 = tensor.pack %4 inner_dims_pos = [0, 1] inner_tiles = [8, 1] into %5 : tensor<128x384xf32> -> tensor<16x384x8x1xf32>
         flow.dispatch.tensor.store %6, %1, offsets = [0, 0, 0, 0], sizes = [16, 384, 8, 1], strides = [1, 1, 1, 1] : tensor<16x384x8x1xf32> -> !flow.dispatch.tensor<writeonly:tensor<16x384x8x1xf32>>
         return
       }
@@ -32,5 +32,4 @@
 // CHECK: func.func @elem_pack
 // CHECK:   %[[READ:.+]] = vector.transfer_read
 // CHECK:   %[[ADD:.+]] = arith.addf %[[READ]], %[[READ]]
-// CHECK:   %[[BCAST:.+]] = vector.broadcast %[[ADD]]
-// CHECK:   vector.transfer_write %[[BCAST]], %{{.+}}
+// CHECK:   vector.transfer_write %[[ADD]], %{{.+}}
diff --git a/compiler/src/iree/compiler/Codegen/LLVMCPU/test/materialize_aarch64_launch_configuration.mlir b/compiler/src/iree/compiler/Codegen/LLVMCPU/test/materialize_aarch64_launch_configuration.mlir
index 0d931c9..29049c4 100644
--- a/compiler/src/iree/compiler/Codegen/LLVMCPU/test/materialize_aarch64_launch_configuration.mlir
+++ b/compiler/src/iree/compiler/Codegen/LLVMCPU/test/materialize_aarch64_launch_configuration.mlir
@@ -355,7 +355,7 @@
         %1 = hal.interface.binding.subspan set(0) binding(1) type(storage_buffer) alignment(64) offset(%c0) : !flow.dispatch.tensor<writeonly:tensor<4x48x8x1xf32>>
         %2 = flow.dispatch.tensor.load %0, offsets = [0, 0], sizes = [20, 40], strides = [1, 1] : !flow.dispatch.tensor<readonly:tensor<20x40xf32>> -> tensor<20x40xf32>
         %3 = tensor.empty() : tensor<4x48x8x1xf32>
-        %4 = iree_linalg_ext.pack %2 padding_value(%cst : f32) inner_dims_pos = [0, 1] inner_tiles = [8, 1] into %3 : (tensor<20x40xf32> tensor<4x48x8x1xf32>) -> tensor<4x48x8x1xf32>
+        %4 = tensor.pack %2 padding_value(%cst : f32) inner_dims_pos = [0, 1] inner_tiles = [8, 1] into %3 : tensor<20x40xf32> -> tensor<4x48x8x1xf32>
         flow.dispatch.tensor.store %4, %1, offsets = [0, 0, 0, 0], sizes = [4, 48, 8, 1], strides = [1, 1, 1, 1] : tensor<4x48x8x1xf32> -> !flow.dispatch.tensor<writeonly:tensor<4x48x8x1xf32>>
         return
       }
@@ -366,7 +366,7 @@
 //  CHECK-DAG: #[[TRANSLATION:.+]] = #iree_codegen.translation_info<CPUDataTiling>
 //      CHECK: hal.executable.export public @pack
 // CHECK-SAME:     translation_info = #[[TRANSLATION]]
-//      CHECK:   iree_linalg_ext.pack
+//      CHECK:   tensor.pack
 // CHECK-SAME:       lowering_config = #[[CONFIG]]
 
 // -----
@@ -400,7 +400,7 @@
         %9 = hal.interface.binding.subspan set(0) binding(1) type(storage_buffer) alignment(64) offset(%c131072) : !flow.dispatch.tensor<writeonly:tensor<?x?xi32>>{%6, %7}
         %10 = flow.dispatch.tensor.load %8, offsets = [0, 0, 0, 0], sizes = [%4, %5, 32, 16], strides = [1, 1, 1, 1] : !flow.dispatch.tensor<readonly:tensor<?x?x32x16xi32>>{%4, %5} -> tensor<?x?x32x16xi32>
         %11 = tensor.empty(%6, %7) : tensor<?x?xi32>
-        %12 = iree_linalg_ext.unpack %10 inner_dims_pos = [0, 1] inner_tiles = [32, 16] into %11 : (tensor<?x?x32x16xi32> tensor<?x?xi32>) -> tensor<?x?xi32>
+        %12 = tensor.unpack %10 inner_dims_pos = [0, 1] inner_tiles = [32, 16] into %11 : tensor<?x?x32x16xi32> -> tensor<?x?xi32>
         flow.dispatch.tensor.store %12, %9, offsets = [0, 0], sizes = [%6, %7], strides = [1, 1] : tensor<?x?xi32> -> !flow.dispatch.tensor<writeonly:tensor<?x?xi32>>{%6, %7}
         return
       }
@@ -411,5 +411,5 @@
 //  CHECK-DAG: #[[TRANSLATION:.+]] = #iree_codegen.translation_info<CPUDataTiling>
 //      CHECK: hal.executable.export public @unpack_outer_dynamic
 // CHECK-SAME:     translation_info = #[[TRANSLATION]]
-//      CHECK:   iree_linalg_ext.unpack
+//      CHECK:   tensor.unpack
 // CHECK-SAME:       lowering_config = #[[CONFIG]]
diff --git a/compiler/src/iree/compiler/Codegen/LLVMCPU/test/materialize_encoding.mlir b/compiler/src/iree/compiler/Codegen/LLVMCPU/test/materialize_encoding.mlir
index 0ef9a3f..ac523a8 100644
--- a/compiler/src/iree/compiler/Codegen/LLVMCPU/test/materialize_encoding.mlir
+++ b/compiler/src/iree/compiler/Codegen/LLVMCPU/test/materialize_encoding.mlir
@@ -41,7 +41,7 @@
 //  CHECK-SAME:       !flow.dispatch.tensor<writeonly:tensor<?x?x8x4xf32>>{%[[TILED_OUTD0]], %[[TILED_OUTD1]]}
 //       CHECK:   %[[INPUT:.+]] = flow.dispatch.tensor.load %[[INPUT_BINDING]]
 //       CHECK:   %[[EMPTY:.+]] = tensor.empty(%[[TILED_OUTD0]], %[[TILED_OUTD1]])
-//       CHECK:   %[[PACK:.+]] = iree_linalg_ext.pack
+//       CHECK:   %[[PACK:.+]] = tensor.pack
 //  CHECK-SAME:       %[[INPUT]] padding_value(%[[CST]] : f32)
 //  CHECK-SAME:       inner_dims_pos = [0, 1] inner_tiles = [8, 4] into %[[EMPTY]]
 //       CHECK:   flow.dispatch.tensor.store %[[PACK]], %[[OUTPUT_BINDING]]
@@ -86,7 +86,7 @@
 //       CHECK:   %[[INPUT:.+]] = flow.dispatch.tensor.load %[[INPUT_BINDING]]
 //  CHECK-SAME:       offsets = [0, 0, 0, 0], sizes = [%[[TILED_D0]], %[[TILED_D1]], 8, 4], strides = [1, 1, 1, 1]
 //       CHECK:   %[[EMPTY:.+]] = tensor.empty(%[[OUTD0]], %[[OUTD1]])
-//       CHECK:   %[[UNPACK:.+]] = iree_linalg_ext.unpack %[[INPUT]]
+//       CHECK:   %[[UNPACK:.+]] = tensor.unpack %[[INPUT]]
 //  CHECK-SAME:       inner_dims_pos = [0, 1] inner_tiles = [8, 4] into %[[EMPTY]]
 //   CHECK-DAG:   flow.dispatch.tensor.store %[[UNPACK]], %[[OUTPUT_BINDING]]
 
diff --git a/compiler/src/iree/compiler/Codegen/LLVMCPU/test/materialize_vmvx_launch_configuration.mlir b/compiler/src/iree/compiler/Codegen/LLVMCPU/test/materialize_vmvx_launch_configuration.mlir
index 2c4747d..62e9e3f 100644
--- a/compiler/src/iree/compiler/Codegen/LLVMCPU/test/materialize_vmvx_launch_configuration.mlir
+++ b/compiler/src/iree/compiler/Codegen/LLVMCPU/test/materialize_vmvx_launch_configuration.mlir
@@ -212,7 +212,7 @@
         %9 = hal.interface.binding.subspan set(0) binding(1) type(storage_buffer) alignment(64) offset(%c131072) : !flow.dispatch.tensor<writeonly:tensor<?x?xi32>>{%6, %7}
         %10 = flow.dispatch.tensor.load %8, offsets = [0, 0, 0, 0], sizes = [%4, %5, 32, 16], strides = [1, 1, 1, 1] : !flow.dispatch.tensor<readonly:tensor<?x?x32x16xi32>>{%4, %5} -> tensor<?x?x32x16xi32>
         %11 = tensor.empty(%6, %7) : tensor<?x?xi32>
-        %12 = iree_linalg_ext.unpack %10 inner_dims_pos = [0, 1] inner_tiles = [32, 16] into %11 : (tensor<?x?x32x16xi32> tensor<?x?xi32>) -> tensor<?x?xi32>
+        %12 = tensor.unpack %10 inner_dims_pos = [0, 1] inner_tiles = [32, 16] into %11 : tensor<?x?x32x16xi32> -> tensor<?x?xi32>
         flow.dispatch.tensor.store %12, %9, offsets = [0, 0], sizes = [%6, %7], strides = [1, 1] : tensor<?x?xi32> -> !flow.dispatch.tensor<writeonly:tensor<?x?xi32>>{%6, %7}
         return
       }
@@ -223,5 +223,5 @@
 //  CHECK-DAG: #[[TRANSLATION:.+]] = #iree_codegen.translation_info<VMVXDefault>
 //      CHECK: hal.executable.export public @unpack_outer_dynamic
 // CHECK-SAME:     translation_info = #[[TRANSLATION]]
-//      CHECK:   iree_linalg_ext.unpack
+//      CHECK:   tensor.unpack
 // CHECK-SAME:       lowering_config = #[[CONFIG]]
diff --git a/compiler/src/iree/compiler/Codegen/LLVMCPU/test/materialize_x86_64_launch_configuration.mlir b/compiler/src/iree/compiler/Codegen/LLVMCPU/test/materialize_x86_64_launch_configuration.mlir
index 6f91452..07fae72 100644
--- a/compiler/src/iree/compiler/Codegen/LLVMCPU/test/materialize_x86_64_launch_configuration.mlir
+++ b/compiler/src/iree/compiler/Codegen/LLVMCPU/test/materialize_x86_64_launch_configuration.mlir
@@ -1450,7 +1450,7 @@
         %1 = hal.interface.binding.subspan set(0) binding(1) type(storage_buffer) alignment(64) offset(%c0) : !flow.dispatch.tensor<writeonly:tensor<2x48x16x1xf32>>
         %2 = flow.dispatch.tensor.load %0, offsets = [0, 0], sizes = [20, 40], strides = [1, 1] : !flow.dispatch.tensor<readonly:tensor<20x40xf32>> -> tensor<20x40xf32>
         %3 = tensor.empty() : tensor<2x48x16x1xf32>
-        %4 = iree_linalg_ext.pack %2 padding_value(%cst : f32) inner_dims_pos = [0, 1] inner_tiles = [16, 1] into %3 : (tensor<20x40xf32> tensor<2x48x16x1xf32>) -> tensor<2x48x16x1xf32>
+        %4 = tensor.pack %2 padding_value(%cst : f32) inner_dims_pos = [0, 1] inner_tiles = [16, 1] into %3 : tensor<20x40xf32> -> tensor<2x48x16x1xf32>
         flow.dispatch.tensor.store %4, %1, offsets = [0, 0, 0, 0], sizes = [2, 48, 16, 1], strides = [1, 1, 1, 1] : tensor<2x48x16x1xf32> -> !flow.dispatch.tensor<writeonly:tensor<2x48x16x1xf32>>
         return
       }
@@ -1461,7 +1461,7 @@
 //  CHECK-DAG: #[[TRANSLATION:.+]] = #iree_codegen.translation_info<CPUDataTiling>
 //      CHECK: hal.executable.export public @pack
 // CHECK-SAME:     translation_info = #[[TRANSLATION]]
-//      CHECK:   iree_linalg_ext.pack
+//      CHECK:   tensor.pack
 // CHECK-SAME:       lowering_config = #[[CONFIG]]
 
 // -----
diff --git a/compiler/src/iree/compiler/Codegen/VMVX/BUILD b/compiler/src/iree/compiler/Codegen/VMVX/BUILD
index f585bd1..800c519 100644
--- a/compiler/src/iree/compiler/Codegen/VMVX/BUILD
+++ b/compiler/src/iree/compiler/Codegen/VMVX/BUILD
@@ -49,6 +49,7 @@
         "@llvm-project//mlir:MemRefDialect",
         "@llvm-project//mlir:MemRefTransforms",
         "@llvm-project//mlir:Pass",
+        "@llvm-project//mlir:TensorTransforms",
         "@llvm-project//mlir:Transforms",
     ],
 )
diff --git a/compiler/src/iree/compiler/Codegen/VMVX/CMakeLists.txt b/compiler/src/iree/compiler/Codegen/VMVX/CMakeLists.txt
index 776cabc..b4faa02 100644
--- a/compiler/src/iree/compiler/Codegen/VMVX/CMakeLists.txt
+++ b/compiler/src/iree/compiler/Codegen/VMVX/CMakeLists.txt
@@ -34,6 +34,7 @@
     MLIRMemRefDialect
     MLIRMemRefTransforms
     MLIRPass
+    MLIRTensorTransforms
     MLIRTransforms
     iree::builtins::ukernel::exported_bits
     iree::compiler::Codegen::Common
diff --git a/compiler/src/iree/compiler/Codegen/VMVX/VMVXMaterializeEncodingPass.cpp b/compiler/src/iree/compiler/Codegen/VMVX/VMVXMaterializeEncodingPass.cpp
index 0af5408..2c25454 100644
--- a/compiler/src/iree/compiler/Codegen/VMVX/VMVXMaterializeEncodingPass.cpp
+++ b/compiler/src/iree/compiler/Codegen/VMVX/VMVXMaterializeEncodingPass.cpp
@@ -17,6 +17,7 @@
 #include "mlir/Dialect/Affine/IR/AffineOps.h"
 #include "mlir/Dialect/Arith/IR/Arith.h"
 #include "mlir/Dialect/MemRef/Transforms/Passes.h"
+#include "mlir/Dialect/Tensor/Transforms/Transforms.h"
 #include "mlir/Transforms/DialectConversion.h"
 #include "mlir/Transforms/GreedyPatternRewriteDriver.h"
 
@@ -53,9 +54,10 @@
 struct VMVXMaterializeEncodingPass
     : public VMVXMaterializeEncodingBase<VMVXMaterializeEncodingPass> {
   void getDependentDialects(DialectRegistry &registry) const override {
-    registry.insert<arith::ArithDialect, AffineDialect, IREE::Flow::FlowDialect,
-                    IREE::LinalgExt::IREELinalgExtDialect,
-                    IREE::VMVX::VMVXDialect>();
+    registry
+        .insert<arith::ArithDialect, AffineDialect, tensor::TensorDialect,
+                IREE::Flow::FlowDialect, IREE::LinalgExt::IREELinalgExtDialect,
+                IREE::VMVX::VMVXDialect>();
   }
   void runOnOperation() override;
 };
@@ -101,7 +103,7 @@
   // dims ops.
   {
     RewritePatternSet patterns(context);
-    populateFoldIntoPackAndUnpackOpsPatterns(patterns);
+    tensor::populateFoldIntoPackAndUnpackPatterns(patterns);
     memref::populateResolveRankedShapeTypeResultDimsPatterns(patterns);
     if (failed(applyPatternsAndFoldGreedily(operation, std::move(patterns)))) {
       operation.emitOpError("folding patterns failed");
diff --git a/llvm-external-projects/iree-dialects/lib/Dialect/LinalgExt/Passes/MaterializeEncoding.cpp b/llvm-external-projects/iree-dialects/lib/Dialect/LinalgExt/Passes/MaterializeEncoding.cpp
index 90977cb..37f2a24 100644
--- a/llvm-external-projects/iree-dialects/lib/Dialect/LinalgExt/Passes/MaterializeEncoding.cpp
+++ b/llvm-external-projects/iree-dialects/lib/Dialect/LinalgExt/Passes/MaterializeEncoding.cpp
@@ -12,6 +12,7 @@
 #include "mlir/Dialect/Linalg/IR/Linalg.h"
 #include "mlir/Dialect/MemRef/Transforms/Passes.h"
 #include "mlir/Dialect/Tensor/IR/Tensor.h"
+#include "mlir/Dialect/Tensor/Transforms/Transforms.h"
 #include "mlir/IR/PatternMatch.h"
 #include "mlir/Transforms/DialectConversion.h"
 #include "mlir/Transforms/GreedyPatternRewriteDriver.h"
@@ -47,10 +48,10 @@
   if (failed(materializeEncodingInfo)) {
     return tensorType;
   }
-  return PackOp::getPackedType(tensorType,
-                               materializeEncodingInfo->innerTileSizes,
-                               materializeEncodingInfo->innerDimsPos,
-                               materializeEncodingInfo->outerDimsPerm)
+  return tensor::PackOp::inferPackedType(
+             tensorType, materializeEncodingInfo->innerTileSizes,
+             materializeEncodingInfo->innerDimsPos,
+             materializeEncodingInfo->outerDimsPerm)
       .cast<RankedTensorType>();
 }
 
@@ -116,7 +117,7 @@
 /// Utility method to convert from `set_encoding` op to `pack` operation.
 /// For now this takes a `paddingValue` as input. The source is also taken
 /// as input so that these could be used with `OpConversionPatterns`.
-static FailureOr<PackOp> lowerSetEncodingOpToPackOp(
+static FailureOr<tensor::PackOp> lowerSetEncodingOpToPackOp(
     RewriterBase &rewriter, SetEncodingOp encodingOp, Value source,
     MaterializeEncodingFn materializeEncodingFn,
     MaterializeEncodingValueFn materializeEncodingValueFn) {
@@ -139,14 +140,14 @@
   Optional<TensorEncoding> encoding = getEncoding(resultType);
   if (!encoding)
     return failure();
-  SmallVector<OpFoldResult> resultDims =
-      PackOp::getResultShape(rewriter, loc, sourceDims, *innerTileSizesOfr,
-                             materializeEncodingInfo->innerDimsPos,
-                             materializeEncodingInfo->outerDimsPerm);
+  SmallVector<OpFoldResult> resultDims = tensor::PackOp::getResultShape(
+      rewriter, loc, sourceDims, *innerTileSizesOfr,
+      materializeEncodingInfo->innerDimsPos,
+      materializeEncodingInfo->outerDimsPerm);
   auto emptyOp = rewriter.create<tensor::EmptyOp>(loc, resultDims,
                                                   resultType.getElementType());
   Optional<Value> paddingValue = getPaddingValue(source);
-  auto packOp = rewriter.create<PackOp>(
+  auto packOp = rewriter.create<tensor::PackOp>(
       loc, source, emptyOp, materializeEncodingInfo->innerDimsPos,
       *innerTileSizesOfr, paddingValue, materializeEncodingInfo->outerDimsPerm);
   // As we rewrite the SetEncoding and its old result tensor, which used to hold
@@ -165,7 +166,7 @@
 /// Utility method to convert from `set_encoding` op to `pack` operation.
 /// The source is taken as input so that these could be used with
 /// `OpConversionPatterns`.
-static FailureOr<UnPackOp> lowerUnsetEncodingToUnpackOp(
+static FailureOr<tensor::UnPackOp> lowerUnsetEncodingToUnpackOp(
     RewriterBase &rewriter, UnsetEncodingOp encodingOp, Value packedValue,
     MaterializeEncodingFn materializeEncodingFn,
     MaterializeEncodingValueFn materializeEncodingValueFn) {
@@ -188,7 +189,7 @@
     return rewriter.notifyMatchFailure(
         encodingOp, "failed to generate runtime tile size query");
   }
-  return rewriter.create<UnPackOp>(
+  return rewriter.create<tensor::UnPackOp>(
       loc, packedValue, emptyOp, materializeEncodingInfo->innerDimsPos,
       *innerTileSizesOfr, materializeEncodingInfo->outerDimsPerm);
 }
@@ -303,7 +304,7 @@
     if (failed(packOp))
       return rewriter.notifyMatchFailure(encodingOp,
                                          "failed to convert to pack op");
-    rewriter.replaceOp(encodingOp, packOp->getResults());
+    rewriter.replaceOp(encodingOp, packOp->getResult());
     return success();
   }
 };
@@ -327,7 +328,7 @@
     if (failed(unpackOp))
       return rewriter.notifyMatchFailure(encodingOp,
                                          "failed to convert to unpack op");
-    rewriter.replaceOp(encodingOp, unpackOp->getResults());
+    rewriter.replaceOp(encodingOp, unpackOp->getResult());
     return success();
   }
 };
@@ -402,10 +403,11 @@
       return signalPassFailure();
   }
 
-  // Add patterns to fold pack/unpack ops with pad/extract_slice ops.
+  // Add patterns to fold tensor.pack/unpack ops with tensor.pad/extract_slice
+  // ops.
   {
     RewritePatternSet patterns(context);
-    populateFoldIntoPackAndUnpackOpsPatterns(patterns);
+    tensor::populateFoldIntoPackAndUnpackPatterns(patterns);
     if (failed(
             applyPatternsAndFoldGreedily(getOperation(), std::move(patterns))))
       return signalPassFailure();
diff --git a/llvm-external-projects/iree-dialects/test/Dialect/iree_linalg_ext/materialize_encoding.mlir b/llvm-external-projects/iree-dialects/test/Dialect/iree_linalg_ext/materialize_encoding.mlir
index 0a069d4..9dec445 100644
--- a/llvm-external-projects/iree-dialects/test/Dialect/iree_linalg_ext/materialize_encoding.mlir
+++ b/llvm-external-projects/iree-dialects/test/Dialect/iree_linalg_ext/materialize_encoding.mlir
@@ -16,10 +16,10 @@
 //  CHECK-DAG:   %[[OUTER_D0:.+]] = affine.apply #[[MAP0]]()[%[[D0]]]
 //  CHECK-DAG:   %[[OUTER_D1:.+]] = affine.apply #[[MAP1]]()[%[[D1]]]
 //      CHECK:   %[[PACK_DEST:.+]] = tensor.empty(%[[OUTER_D0]], %[[OUTER_D1]]) : tensor<?x?x8x4xf32>
-//      CHECK:   %[[PACK:.+]] = iree_linalg_ext.pack
+//      CHECK:   %[[PACK:.+]] = tensor.pack
 // CHECK-SAME:     %[[ARG0]] inner_dims_pos = [0, 1] inner_tiles = [8, 4] into %[[PACK_DEST]]
 //      CHECK:   %[[UNPACK_DEST:.+]] = tensor.empty(%[[D0]], %[[D1]]) : tensor<?x?xf32>
-//      CHECK:   %[[UNPACK:.+]] = iree_linalg_ext.unpack %[[PACK]] inner_dims_pos = [0, 1] inner_tiles = [8, 4] into %[[UNPACK_DEST]]
+//      CHECK:   %[[UNPACK:.+]] = tensor.unpack %[[PACK]] inner_dims_pos = [0, 1] inner_tiles = [8, 4] into %[[UNPACK_DEST]]
 //      CHECK:   return %[[UNPACK]]
 
 // -----
@@ -30,9 +30,9 @@
   return %1 : tensor<?x?xf32>
 }
 // CHECK-LABEL: func @pack_unpack_gemm_rhs(
-//       CHECK:   linalg_ext.pack
+//       CHECK:   tensor.pack
 //  CHECK-SAME:     outer_dims_perm = [1, 0] inner_dims_pos = [1, 0] inner_tiles = [8, 4]
-//       CHECK:   linalg_ext.unpack %{{.+}} outer_dims_perm = [1, 0] inner_dims_pos = [1, 0] inner_tiles = [8, 4]
+//       CHECK:   tensor.unpack %{{.+}} outer_dims_perm = [1, 0] inner_dims_pos = [1, 0] inner_tiles = [8, 4]
 
 // -----
 
@@ -42,9 +42,9 @@
   return %1 : tensor<?x?xf32>
 }
 // CHECK-LABEL: func @pack_unpack_gemm_result(
-//       CHECK:   linalg_ext.pack
+//       CHECK:   tensor.pack
 //  CHECK-SAME:     inner_dims_pos = [0, 1] inner_tiles = [8, 8]
-//       CHECK:   linalg_ext.unpack %{{.+}} inner_dims_pos = [0, 1] inner_tiles = [8, 8]
+//       CHECK:   tensor.unpack %{{.+}} inner_dims_pos = [0, 1] inner_tiles = [8, 8]
 
 // -----
 
@@ -77,21 +77,21 @@
 // CHECK-SAME:     %[[ARG2:.+]]: tensor<100x500xf32>
 //      CHECK:   %[[CST:.+]] = arith.constant 0.0
 //      CHECK:   %[[INIT_LHS:.+]] = tensor.empty() : tensor<13x63x8x4xf32>
-//      CHECK:   %[[PACK_LHS:.+]] = iree_linalg_ext.pack
+//      CHECK:   %[[PACK_LHS:.+]] = tensor.pack
 // CHECK-SAME:     %[[ARG0]] padding_value(%[[CST]] : f32)
 // CHECK-SAME:       into %[[INIT_LHS]]
 //      CHECK:   %[[INIT_RHS:.+]] = tensor.empty() : tensor<63x63x8x4xf32>
-//      CHECK:   %[[PACK_RHS:.+]] = iree_linalg_ext.pack
+//      CHECK:   %[[PACK_RHS:.+]] = tensor.pack
 // CHECK-SAME:     %[[ARG1]] padding_value(%[[CST]] : f32)
 // CHECK-SAME:       into %[[INIT_RHS]]
 //      CHECK:   %[[INIT_RESULT:.+]] = tensor.empty() : tensor<13x63x8x8xf32>
-//      CHECK:   %[[PACK_RESULT:.+]] = iree_linalg_ext.pack
+//      CHECK:   %[[PACK_RESULT:.+]] = tensor.pack
 // CHECK-SAME:     %[[ARG2]] padding_value(%[[CST]] : f32)
 // CHECK-SAME:       into %[[INIT_RESULT]]
 //      CHECK:   %[[MMT4D:.+]] = linalg.mmt4d
 // CHECK-SAME:       ins(%[[PACK_LHS]], %[[PACK_RHS]] :
 // CHECK-SAME:       outs(%[[PACK_RESULT]] :
-//      CHECK:   %[[UNPACK:.+]] = iree_linalg_ext.unpack %[[MMT4D]]
+//      CHECK:   %[[UNPACK:.+]] = tensor.unpack %[[MMT4D]]
 //      CHECK:   return %[[UNPACK]]
 
 // -----
@@ -111,16 +111,16 @@
 // CHECK-SAME:     %[[ARG0:[a-zA-Z0-9]+]]: tensor<?x?xf32>
 // CHECK-SAME:     %[[ARG1:[a-zA-Z0-9]+]]: tensor<?x?xf32>
 // CHECK-SAME:     %[[ARG2:[a-zA-Z0-9]+]]: tensor<?x?xf32>
-//      CHECK:   %[[PACK_LHS:.+]] = iree_linalg_ext.pack
+//      CHECK:   %[[PACK_LHS:.+]] = tensor.pack
 // CHECK-SAME:     %[[ARG0]]
-//      CHECK:   %[[PACK_RHS:.+]] = iree_linalg_ext.pack
+//      CHECK:   %[[PACK_RHS:.+]] = tensor.pack
 // CHECK-SAME:     %[[ARG1]]
-//      CHECK:   %[[PACK_RESULT:.+]] = iree_linalg_ext.pack
+//      CHECK:   %[[PACK_RESULT:.+]] = tensor.pack
 // CHECK-SAME:     %[[ARG2]]
 //      CHECK:   %[[MMT4D:.+]] = linalg.mmt4d
 // CHECK-SAME:       ins(%[[PACK_LHS]], %[[PACK_RHS]] :
 // CHECK-SAME:       outs(%[[PACK_RESULT]] :
-//      CHECK:   %[[UNPACK:.+]] = iree_linalg_ext.unpack %[[MMT4D]]
+//      CHECK:   %[[UNPACK:.+]] = tensor.unpack %[[MMT4D]]
 //      CHECK:   return %[[UNPACK]]
 
 // -----
@@ -152,8 +152,8 @@
 //  CHECK-DAG:   %[[D1:.+]] = tensor.dim %[[ARG1]], %[[C1]]
 //  CHECK-DAG:   %[[OUT_D0:.+]] = affine.apply #[[MAP0]]()[%[[D0]]]
 //  CHECK-DAG:   %[[OUT_D1:.+]] = affine.apply #[[MAP0]]()[%[[D1]]]
-//  CHECK-DAG:   %[[PACK_LHS:.+]] = iree_linalg_ext.pack {{.*}}%[[ARG0]]
-//      CHECK:   %[[PACK_RHS:.+]] = iree_linalg_ext.pack
+//  CHECK-DAG:   %[[PACK_LHS:.+]] = tensor.pack {{.*}}%[[ARG0]]
+//      CHECK:   %[[PACK_RHS:.+]] = tensor.pack
 // CHECK-SAME:     %[[ARG1]]
 //  CHECK-DAG:   %[[EMPTY:.+]] = tensor.empty(%[[OUT_D0]], %[[OUT_D1]]) : tensor<?x?x8x8xf32>
 //      CHECK:   %[[FILL:.+]] = linalg.fill
@@ -161,5 +161,5 @@
 //      CHECK:   %[[MMT4D:.+]] = linalg.mmt4d
 // CHECK-SAME:       ins(%[[PACK_LHS]], %[[PACK_RHS]] :
 // CHECK-SAME:       outs(%[[FILL]] :
-//      CHECK:   %[[UNPACK:.+]] = iree_linalg_ext.unpack %[[MMT4D]]
+//      CHECK:   %[[UNPACK:.+]] = tensor.unpack %[[MMT4D]]
 //      CHECK:   return %[[UNPACK]]