Plumb e2e tensor.(un)pack support through VMVX and microkernels.  (#12133)

This adds support for bufferizing tensor.pack/unpack ops into
iree_linalg_ext.pack/unpack ops. The VMVX backend and the microkernels
need a memref version of the pack/unpack ops at a certain stage of their
pipelines. This lets the VMVX backend use the upstream ops and
transforms until bufferization, and then reuse the existing LinalgExt op
definitions and the LowerToLoops transform to handle the rest of the
pipeline.
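
As a sketch of the rewrite (the tensor-level op is taken from the
bufferization lit tests in this change; the LinalgExt form approximates
the printed output, with %in/%out standing in for the bufferized source
and destination buffers, and memory layouts elided):

  %4 = tensor.pack %3 padding_value(%c0_i32 : i32)
      inner_dims_pos = [0, 1] inner_tiles = [3, 3] into %2
      : tensor<4x4xi32> -> tensor<2x2x3x3xi32>

bufferizes into the memref-level op

  iree_linalg_ext.pack %in padding_value(%c0_i32 : i32)
      inner_dims_pos = [0, 1] inner_tiles = [3, 3] into %out
      : (memref<4x4xi32> memref<2x2x3x3xi32>)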

It also adds support for tensor.pack ops with dynamic inner tile sizes
on the llvm-cpu backend. Note that this case is not vectorized.
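
For example, this pack with runtime-valued inner tiles (adapted from the
re-enabled fully_dynamic_pack_simple e2e test below; %iree_input and
%init are set up as in that test) now compiles on llvm-cpu through the
non-vectorized path:

  %c2 = util.unfoldable_constant 2 : index
  %pack = tensor.pack %iree_input inner_dims_pos = [0, 1]
      inner_tiles = [%c2, %c2] into %init
      : tensor<?x?xi32> -> tensor<?x?x?x?xi32>
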
diff --git a/compiler/src/iree/compiler/Codegen/Common/IREEComprehensiveBufferizePass.cpp b/compiler/src/iree/compiler/Codegen/Common/IREEComprehensiveBufferizePass.cpp
index 3f651a3..9d5b3cc 100644
--- a/compiler/src/iree/compiler/Codegen/Common/IREEComprehensiveBufferizePass.cpp
+++ b/compiler/src/iree/compiler/Codegen/Common/IREEComprehensiveBufferizePass.cpp
@@ -9,6 +9,8 @@
 // Wrapper pass to use MLIR's One-Shot Bufferize pass.
 //
 //===----------------------------------------------------------------------===//
+
+#include "iree-dialects/Dialect/LinalgExt/IR/LinalgExtDialect.h"
 #include "iree/compiler/Codegen/Interfaces/BufferizationInterfaces.h"
 #include "iree/compiler/Codegen/PassDetail.h"
 #include "iree/compiler/Codegen/Passes.h"
@@ -74,6 +76,7 @@
                 bufferization::BufferizationDialect,
                 func::FuncDialect,
                 IREE::Flow::FlowDialect,
+                IREE::LinalgExt::IREELinalgExtDialect,
                 IREE::Util::UtilDialect,
                 linalg::LinalgDialect,
                 memref::MemRefDialect,
diff --git a/compiler/src/iree/compiler/Codegen/Common/test/iree_comprehensive_bufferize.mlir b/compiler/src/iree/compiler/Codegen/Common/test/iree_comprehensive_bufferize.mlir
index 62e0e10..e578c1d 100644
--- a/compiler/src/iree/compiler/Codegen/Common/test/iree_comprehensive_bufferize.mlir
+++ b/compiler/src/iree/compiler/Codegen/Common/test/iree_comprehensive_bufferize.mlir
@@ -2334,7 +2334,7 @@
 
 // -----
 
-func.func @pack() {
+func.func @iree_linalg_ext_pack() {
   %c0 = arith.constant 0 : index
   %c0_i32 = arith.constant 0 : i32
   %0 = hal.interface.binding.subspan set(0) binding(0) type(storage_buffer) alignment(64) offset(%c0) : !flow.dispatch.tensor<readonly:tensor<4x4xi32>>
@@ -2345,7 +2345,7 @@
   flow.dispatch.tensor.store %4, %1, offsets = [0, 0, 0, 0], sizes = [2, 2, 3, 3], strides = [1, 1, 1, 1] : tensor<2x2x3x3xi32> -> !flow.dispatch.tensor<writeonly:tensor<2x2x3x3xi32>>
   return
 }
-// CHECK: func.func @pack
+// CHECK: func.func @iree_linalg_ext_pack
 // CHECK-DAG:  %[[PAD:.+]] = arith.constant 0 : i32
 // CHECK-DAG:  %[[IN:.+]] = hal.interface.binding.subspan set(0) binding(0) type(storage_buffer) alignment(64) offset(%c0) : memref<4x4xi32, #hal.descriptor_type<storage_buffer>>
 // CHECK-DAG:  %[[OUT:.+]] = hal.interface.binding.subspan set(0) binding(1) type(storage_buffer) alignment(64) offset(%c0) : memref<2x2x3x3xi32, #hal.descriptor_type<storage_buffer>>
@@ -2355,7 +2355,7 @@
 
 // -----
 
-func.func @unpack() {
+func.func @iree_linalg_ext_unpack() {
   %c0 = arith.constant 0 : index
   %0 = hal.interface.binding.subspan set(0) binding(0) type(storage_buffer) alignment(64) offset(%c0) : !flow.dispatch.tensor<readonly:tensor<2x2x2x2xi32>>
   %1 = hal.interface.binding.subspan set(0) binding(1) type(storage_buffer) alignment(64) offset(%c0) : !flow.dispatch.tensor<writeonly:tensor<4x4xi32>>
@@ -2365,7 +2365,7 @@
   flow.dispatch.tensor.store %4, %1, offsets = [0, 0], sizes = [4, 4], strides = [1, 1] : tensor<4x4xi32> -> !flow.dispatch.tensor<writeonly:tensor<4x4xi32>>
   return
 }
-// CHECK: func.func @unpack
+// CHECK: func.func @iree_linalg_ext_unpack
 // CHECK-DAG:  %[[IN:.+]] = hal.interface.binding.subspan set(0) binding(0) type(storage_buffer) alignment(64) offset(%c0) : memref<2x2x2x2xi32, #hal.descriptor_type<storage_buffer>>
 // CHECK-DAG:  %[[OUT:.+]] = hal.interface.binding.subspan set(0) binding(1) type(storage_buffer) alignment(64) offset(%c0) : memref<4x4xi32, #hal.descriptor_type<storage_buffer>>
 // CHECK:      iree_linalg_ext.unpack %[[IN]]
@@ -2373,7 +2373,7 @@
 
 // -----
 
-func.func @unpack_fully_dynamic() {
+func.func @iree_linalg_ext_unpack_fully_dynamic() {
   %c0 = arith.constant 0 : index
   %inner_d0 = util.unfoldable_constant 2 : index
   %0 = hal.interface.binding.subspan set(0) binding(0) type(storage_buffer) alignment(64) offset(%c0) : !flow.dispatch.tensor<readonly:tensor<2x2x2x2xi32>>
@@ -2385,12 +2385,71 @@
   return
 }
 
-// CHECK:      func.func @unpack_fully_dynamic
+// CHECK:      func.func @iree_linalg_ext_unpack_fully_dynamic
 // CHECK-DAG:  %[[D:.+]] = util.optimization_barrier %c2 : index
 // CHECK:      iree_linalg_ext.unpack
 // CHECK-SAME:   inner_dims_pos = [0, 1] inner_tiles = [%[[D]], %[[D]]]
 
 // -----
+
+func.func @tensor_pack() {
+  %c0 = arith.constant 0 : index
+  %c0_i32 = arith.constant 0 : i32
+  %0 = hal.interface.binding.subspan set(0) binding(0) type(storage_buffer) alignment(64) offset(%c0) : !flow.dispatch.tensor<readonly:tensor<4x4xi32>>
+  %1 = hal.interface.binding.subspan set(0) binding(1) type(storage_buffer) alignment(64) offset(%c0) : !flow.dispatch.tensor<writeonly:tensor<2x2x3x3xi32>>
+  %2 = flow.dispatch.tensor.load %1, offsets = [0, 0, 0, 0], sizes = [2, 2, 3, 3], strides = [1, 1, 1, 1] : !flow.dispatch.tensor<writeonly:tensor<2x2x3x3xi32>> -> tensor<2x2x3x3xi32>
+  %3 = flow.dispatch.tensor.load %0, offsets = [0, 0], sizes = [4, 4], strides = [1, 1] : !flow.dispatch.tensor<readonly:tensor<4x4xi32>> -> tensor<4x4xi32>
+  %4 = tensor.pack %3 padding_value(%c0_i32 : i32) inner_dims_pos = [0, 1] inner_tiles = [3, 3] into %2 : tensor<4x4xi32> -> tensor<2x2x3x3xi32>
+  flow.dispatch.tensor.store %4, %1, offsets = [0, 0, 0, 0], sizes = [2, 2, 3, 3], strides = [1, 1, 1, 1] : tensor<2x2x3x3xi32> -> !flow.dispatch.tensor<writeonly:tensor<2x2x3x3xi32>>
+  return
+}
+// CHECK: func.func @tensor_pack
+// CHECK-DAG:  %[[PAD:.+]] = arith.constant 0 : i32
+// CHECK-DAG:  %[[IN:.+]] = hal.interface.binding.subspan set(0) binding(0) type(storage_buffer) alignment(64) offset(%c0) : memref<4x4xi32, #hal.descriptor_type<storage_buffer>>
+// CHECK-DAG:  %[[OUT:.+]] = hal.interface.binding.subspan set(0) binding(1) type(storage_buffer) alignment(64) offset(%c0) : memref<2x2x3x3xi32, #hal.descriptor_type<storage_buffer>>
+// CHECK:      iree_linalg_ext.pack %[[IN]]
+// CHECK-SAME:   padding_value(%[[PAD]] : i32)
+// CHECK-SAME:   inner_dims_pos = [0, 1] inner_tiles = [3, 3] into %[[OUT]]
+
+// -----
+
+func.func @tensor_unpack() {
+  %c0 = arith.constant 0 : index
+  %0 = hal.interface.binding.subspan set(0) binding(0) type(storage_buffer) alignment(64) offset(%c0) : !flow.dispatch.tensor<readonly:tensor<2x2x2x2xi32>>
+  %1 = hal.interface.binding.subspan set(0) binding(1) type(storage_buffer) alignment(64) offset(%c0) : !flow.dispatch.tensor<writeonly:tensor<4x4xi32>>
+  %2 = flow.dispatch.tensor.load %1, offsets = [0, 0], sizes = [4, 4], strides = [1, 1] : !flow.dispatch.tensor<writeonly:tensor<4x4xi32>> -> tensor<4x4xi32>
+  %3 = flow.dispatch.tensor.load %0, offsets = [0, 0, 0, 0], sizes = [2, 2, 2, 2], strides = [1, 1, 1, 1] : !flow.dispatch.tensor<readonly:tensor<2x2x2x2xi32>> -> tensor<2x2x2x2xi32>
+  %4 = tensor.unpack %3 inner_dims_pos = [0, 1] inner_tiles = [2, 2] into %2 : tensor<2x2x2x2xi32> -> tensor<4x4xi32>
+  flow.dispatch.tensor.store %4, %1, offsets = [0, 0], sizes = [4, 4], strides = [1, 1] : tensor<4x4xi32> -> !flow.dispatch.tensor<writeonly:tensor<4x4xi32>>
+  return
+}
+// CHECK: func.func @tensor_unpack
+// CHECK-DAG:  %[[IN:.+]] = hal.interface.binding.subspan set(0) binding(0) type(storage_buffer) alignment(64) offset(%c0) : memref<2x2x2x2xi32, #hal.descriptor_type<storage_buffer>>
+// CHECK-DAG:  %[[OUT:.+]] = hal.interface.binding.subspan set(0) binding(1) type(storage_buffer) alignment(64) offset(%c0) : memref<4x4xi32, #hal.descriptor_type<storage_buffer>>
+// CHECK:      iree_linalg_ext.unpack %[[IN]]
+// CHECK-SAME:   inner_dims_pos = [0, 1] inner_tiles = [2, 2] into %[[OUT]]
+
+// -----
+
+func.func @tensor_unpack_fully_dynamic() {
+  %c0 = arith.constant 0 : index
+  %inner_d0 = util.unfoldable_constant 2 : index
+  %0 = hal.interface.binding.subspan set(0) binding(0) type(storage_buffer) alignment(64) offset(%c0) : !flow.dispatch.tensor<readonly:tensor<2x2x2x2xi32>>
+  %1 = hal.interface.binding.subspan set(0) binding(1) type(storage_buffer) alignment(64) offset(%c0) : !flow.dispatch.tensor<writeonly:tensor<4x4xi32>>
+  %2 = flow.dispatch.tensor.load %1, offsets = [0, 0], sizes = [4, 4], strides = [1, 1] : !flow.dispatch.tensor<writeonly:tensor<4x4xi32>> -> tensor<4x4xi32>
+  %3 = flow.dispatch.tensor.load %0, offsets = [0, 0, 0, 0], sizes = [2, 2, %inner_d0, %inner_d0], strides = [1, 1, 1, 1] : !flow.dispatch.tensor<readonly:tensor<2x2x2x2xi32>> -> tensor<2x2x?x?xi32>
+  %4 = tensor.unpack %3 inner_dims_pos = [0, 1] inner_tiles = [%inner_d0, %inner_d0] into %2 : tensor<2x2x?x?xi32> -> tensor<4x4xi32>
+  flow.dispatch.tensor.store %4, %1, offsets = [0, 0], sizes = [4, 4], strides = [1, 1] : tensor<4x4xi32> -> !flow.dispatch.tensor<writeonly:tensor<4x4xi32>>
+  return
+}
+
+// CHECK:      func.func @tensor_unpack_fully_dynamic
+// CHECK-DAG:  %[[D:.+]] = util.optimization_barrier %c2 : index
+// CHECK:      iree_linalg_ext.unpack
+// CHECK-SAME:   inner_dims_pos = [0, 1] inner_tiles = [%[[D]], %[[D]]]
+
+// -----
+
 module {
   func.func @reduction_ew() {
     %c5120 = arith.constant 5120 : index
diff --git a/compiler/src/iree/compiler/Codegen/Interfaces/BUILD b/compiler/src/iree/compiler/Codegen/Interfaces/BUILD
index 5bec368..77331e5 100644
--- a/compiler/src/iree/compiler/Codegen/Interfaces/BUILD
+++ b/compiler/src/iree/compiler/Codegen/Interfaces/BUILD
@@ -89,6 +89,7 @@
         "@llvm-project//mlir:MemRefDialect",
         "@llvm-project//mlir:SCFTransforms",
         "@llvm-project//mlir:Support",
+        "@llvm-project//mlir:TensorDialect",
         "@llvm-project//mlir:TensorTransforms",
         "@llvm-project//mlir:VectorTransforms",
     ],
diff --git a/compiler/src/iree/compiler/Codegen/Interfaces/BufferizationInterfaces.cpp b/compiler/src/iree/compiler/Codegen/Interfaces/BufferizationInterfaces.cpp
index c9d7060..9416cb1 100644
--- a/compiler/src/iree/compiler/Codegen/Interfaces/BufferizationInterfaces.cpp
+++ b/compiler/src/iree/compiler/Codegen/Interfaces/BufferizationInterfaces.cpp
@@ -20,6 +20,7 @@
 #include "mlir/Dialect/Linalg/Transforms/BufferizableOpInterfaceImpl.h"
 #include "mlir/Dialect/MemRef/IR/MemRef.h"
 #include "mlir/Dialect/SCF/Transforms/BufferizableOpInterfaceImpl.h"
+#include "mlir/Dialect/Tensor/IR/Tensor.h"
 #include "mlir/Dialect/Tensor/Transforms/BufferizableOpInterfaceImpl.h"
 #include "mlir/Dialect/Vector/Transforms/BufferizableOpInterfaceImpl.h"
 #include "mlir/Support/LLVM.h"
@@ -352,6 +353,133 @@
   }
 };
 
+/// Returns the buffers of the source and destination for pack and unpack ops.
+/// Returns failure if the buffers cannot be found.
+template <typename OpTy>
+static FailureOr<std::pair<Value, Value>> getSourceAndDestFromPackUnPackOp(
+    RewriterBase &rewriter, OpTy op, const BufferizationOptions &options) {
+  static_assert(llvm::is_one_of<OpTy, tensor::PackOp, tensor::UnPackOp>::value);
+  Value source;
+  auto maybeBuffer = getBuffer(rewriter, op.getSource(), options);
+  if (failed(maybeBuffer)) return failure();
+  source = *maybeBuffer;
+
+  Value dest;
+  AnalysisState analysisState(options);
+  SmallVector<OpOperand *> aliasingOpOperands =
+      analysisState.getAliasingOpOperands(op->getOpResult(0));
+  assert(aliasingOpOperands.size() == 1 && "expected 1 OpOperand");
+  FailureOr<Value> resultBuffer =
+      getBuffer(rewriter, aliasingOpOperands.front()->get(), options);
+  if (failed(resultBuffer)) return failure();
+  dest = *resultBuffer;
+  return std::make_pair(source, dest);
+}
+
+static LogicalResult bufferizePackOp(RewriterBase &rewriter, tensor::PackOp op,
+                                     const BufferizationOptions &options) {
+  // Take a guard before anything else.
+  OpBuilder::InsertionGuard g(rewriter);
+  rewriter.setInsertionPoint(op);
+
+  auto maybeSrcAndDest =
+      getSourceAndDestFromPackUnPackOp(rewriter, op, options);
+  if (failed(maybeSrcAndDest)) return failure();
+  auto [source, dest] = *maybeSrcAndDest;
+
+  // Set insertion point now that potential alloc/dealloc are introduced.
+  rewriter.setInsertionPoint(op);
+  rewriter.create<IREE::LinalgExt::PackOp>(
+      op.getLoc(), source, dest, op.getInnerDimsPos(), op.getMixedTiles(),
+      op.getPaddingValue(), op.getOuterDimsPerm());
+
+  // Replace the results of the old op with the new output buffers.
+  bufferization::replaceOpWithBufferizedValues(rewriter, op, dest);
+
+  return success();
+}
+
+static LogicalResult bufferizeUnPackOp(RewriterBase &rewriter,
+                                       tensor::UnPackOp op,
+                                       const BufferizationOptions &options) {
+  // Take a guard before anything else.
+  OpBuilder::InsertionGuard g(rewriter);
+  rewriter.setInsertionPoint(op);
+
+  auto maybeSrcAndDest =
+      getSourceAndDestFromPackUnPackOp(rewriter, op, options);
+  if (failed(maybeSrcAndDest)) return failure();
+  auto [source, dest] = *maybeSrcAndDest;
+
+  // Set insertion point now that potential alloc/dealloc are introduced.
+  rewriter.setInsertionPoint(op);
+  rewriter.create<IREE::LinalgExt::UnPackOp>(
+      op.getLoc(), source, dest, op.getInnerDimsPos(), op.getMixedTiles(),
+      op.getOuterDimsPerm());
+
+  // Replace the results of the old op with the new output buffers.
+  bufferization::replaceOpWithBufferizedValues(rewriter, op, dest);
+
+  return success();
+}
+
+template <typename OpTy>
+struct PackUnPackOpInterface
+    : public BufferizableOpInterface::ExternalModel<PackUnPackOpInterface<OpTy>,
+                                                    OpTy> {
+  bool bufferizesToMemoryRead(Operation *op, OpOperand &opOperand,
+                              const AnalysisState &state) const {
+    return true;
+  }
+
+  bool bufferizesToMemoryWrite(Operation *op, OpOperand &opOperand,
+                               const AnalysisState &state) const {
+    // Operand is written to if it has an aliasing OpResult.
+    auto dpsOp = cast<DestinationStyleOpInterface>(op);
+    return dpsOp.isDpsInit(&opOperand);
+  }
+
+  SmallVector<OpOperand *> getAliasingOpOperand(
+      Operation *op, OpResult opResult, const AnalysisState &state) const {
+    auto dpsOp = cast<DestinationStyleOpInterface>(op);
+    return {dpsOp.getDpsInitOperand(opResult.getResultNumber())};
+  }
+
+  SmallVector<OpResult> getAliasingOpResult(Operation *op, OpOperand &opOperand,
+                                            const AnalysisState &state) const {
+    auto dpsOp = cast<DestinationStyleOpInterface>(op);
+
+    // The i-th "out" tensor may alias with the i-th OpResult.
+    if (dpsOp.isDpsInit(&opOperand)) return {dpsOp.getTiedOpResult(&opOperand)};
+    return {};
+  }
+
+  bufferization::AliasingOpResultList getAliasingOpResults(
+      Operation *op, OpOperand &opOperand, const AnalysisState &state) const {
+    auto dpsOp = cast<DestinationStyleOpInterface>(op);
+
+    // The i-th "out" tensor may alias with the i-th OpResult.
+    if (dpsOp.isDpsInit(&opOperand)) return {dpsOp.getTiedOpResult(&opOperand)};
+    return {};
+  }
+
+  bufferization::BufferRelation bufferRelation(
+      Operation *op, OpResult opResult, const AnalysisState &state) const {
+    return bufferization::BufferRelation::Equivalent;
+  }
+
+  LogicalResult bufferize(Operation *op, RewriterBase &rewriter,
+                          const BufferizationOptions &options) const {
+    return TypeSwitch<Operation *, LogicalResult>(op)
+        .template Case<tensor::PackOp>(
+            [&](auto pack) { return bufferizePackOp(rewriter, pack, options); })
+        .template Case<tensor::UnPackOp>([&](auto unpack) {
+          return bufferizeUnPackOp(rewriter, unpack, options);
+        })
+        .Default([](auto) { return failure(); });
+  }
+};
+
 //===----------------------------------------------------------------------===//
 // IREE specific post analysis transformations.
 //===----------------------------------------------------------------------===//
@@ -461,6 +589,12 @@
     IREE::LinalgExt::AttentionOp::attachInterface<
         LinalgExtOpInterface<IREE::LinalgExt::AttentionOp>>(*ctx);
   });
+  registry.addExtension(+[](MLIRContext *ctx, tensor::TensorDialect *dialect) {
+    tensor::PackOp::attachInterface<PackUnPackOpInterface<tensor::PackOp>>(
+        *ctx);
+    tensor::UnPackOp::attachInterface<PackUnPackOpInterface<tensor::UnPackOp>>(
+        *ctx);
+  });
 }
 
 }  // namespace iree_compiler
diff --git a/compiler/src/iree/compiler/Codegen/Interfaces/CMakeLists.txt b/compiler/src/iree/compiler/Codegen/Interfaces/CMakeLists.txt
index fa3d7c6..86d3d19 100644
--- a/compiler/src/iree/compiler/Codegen/Interfaces/CMakeLists.txt
+++ b/compiler/src/iree/compiler/Codegen/Interfaces/CMakeLists.txt
@@ -58,6 +58,7 @@
     MLIRMemRefDialect
     MLIRSCFTransforms
     MLIRSupport
+    MLIRTensorDialect
     MLIRTensorTransforms
     MLIRVectorTransforms
     iree::compiler::Dialect::Flow::IR
diff --git a/llvm-external-projects/iree-dialects/lib/Dialect/LinalgExt/IR/LinalgExtOps.cpp b/llvm-external-projects/iree-dialects/lib/Dialect/LinalgExt/IR/LinalgExtOps.cpp
index dfd53c5..99dd29a 100644
--- a/llvm-external-projects/iree-dialects/lib/Dialect/LinalgExt/IR/LinalgExtOps.cpp
+++ b/llvm-external-projects/iree-dialects/lib/Dialect/LinalgExt/IR/LinalgExtOps.cpp
@@ -1703,7 +1703,11 @@
   SmallVector<int64_t> staticTileSizes;
   SmallVector<Value> dynamicTileSizes;
   dispatchIndexOpFoldResults(innerTiles, dynamicTileSizes, staticTileSizes);
-  build(builder, state, output.getType(), source, output,
+  SmallVector<Type> resultType;
+  auto outputType = output.getType();
+  if (outputType.isa<RankedTensorType>())
+    resultType.push_back(outputType);
+  build(builder, state, resultType, source, output,
         outerDimsPerm.empty() ? nullptr
                               : builder.getDenseI64ArrayAttr(outerDimsPerm),
         builder.getDenseI64ArrayAttr(innerDimsPos), dynamicTileSizes,
@@ -2125,7 +2129,11 @@
   SmallVector<int64_t> staticTileSizes;
   SmallVector<Value> dynamicTileSizes;
   dispatchIndexOpFoldResults(innerTiles, dynamicTileSizes, staticTileSizes);
-  build(builder, state, output.getType(), source, output,
+  SmallVector<Type> resultType;
+  auto outputType = output.getType();
+  if (outputType.isa<RankedTensorType>())
+    resultType.push_back(outputType);
+  build(builder, state, resultType, source, output,
         outerDimsPerm.empty() ? nullptr
                               : builder.getDenseI64ArrayAttr(outerDimsPerm),
         builder.getDenseI64ArrayAttr(innerDimsPos), dynamicTileSizes,
diff --git a/tests/e2e/tensor_ops/BUILD b/tests/e2e/tensor_ops/BUILD
index 74e1797..4c7b6a8 100644
--- a/tests/e2e/tensor_ops/BUILD
+++ b/tests/e2e/tensor_ops/BUILD
@@ -57,13 +57,13 @@
         # keep sorted
         [
             "extract_slice.mlir",
+            "pack.mlir",
             "tensor_insert_slice.mlir",
+            "unpack.mlir",
         ],
         include = ["*.mlir"],
         exclude = [
-            "pack.mlir",
             "tensor_cast.mlir",
-            "unpack.mlir",
         ],
     ),
     driver = "local-task",
@@ -71,6 +71,26 @@
 )
 
 iree_check_single_backend_test_suite(
+    name = "check_vmvx_ukernel_local-task",
+    srcs = [
+        "pack.mlir",
+        "unpack.mlir",
+    ],
+    compiler_flags = [
+        "--iree-vmvx-enable-microkernels",
+        # Some testcases have linalg.generic ops with multiple ops in the body.
+        # Unless we opt out here, DecomposeLinalgGenericPass splits those into
+        # smaller linalg.generic ops with only one op in the body. This
+        # results in the creation of temporary buffers between the split
+        # linalg.generic ops, causing:
+        # > error: failed to legalize operation 'memref.alloca' that was explicitly marked illegal
+        "--iree-vmvx-enable-microkernels-decompose-linalg-generic=false",
+    ],
+    driver = "local-task",
+    target_backend = "vmvx",
+)
+
+iree_check_single_backend_test_suite(
     name = "check_cuda",
     srcs = enforce_glob(
         # keep sorted
diff --git a/tests/e2e/tensor_ops/CMakeLists.txt b/tests/e2e/tensor_ops/CMakeLists.txt
index f992180..fed57de 100644
--- a/tests/e2e/tensor_ops/CMakeLists.txt
+++ b/tests/e2e/tensor_ops/CMakeLists.txt
@@ -45,7 +45,9 @@
     check_vmvx_local-task
   SRCS
     "extract_slice.mlir"
+    "pack.mlir"
     "tensor_insert_slice.mlir"
+    "unpack.mlir"
   TARGET_BACKEND
     "vmvx"
   DRIVER
@@ -54,6 +56,21 @@
 
 iree_check_single_backend_test_suite(
   NAME
+    check_vmvx_ukernel_local-task
+  SRCS
+    "pack.mlir"
+    "unpack.mlir"
+  TARGET_BACKEND
+    "vmvx"
+  DRIVER
+    "local-task"
+  COMPILER_FLAGS
+    "--iree-vmvx-enable-microkernels"
+    "--iree-vmvx-enable-microkernels-decompose-linalg-generic=false"
+)
+
+iree_check_single_backend_test_suite(
+  NAME
     check_cuda
   SRCS
     "extract_slice.mlir"
diff --git a/tests/e2e/tensor_ops/pack.mlir b/tests/e2e/tensor_ops/pack.mlir
index 6569471..1adb0e5 100644
--- a/tests/e2e/tensor_ops/pack.mlir
+++ b/tests/e2e/tensor_ops/pack.mlir
@@ -438,61 +438,59 @@
   return
 }
 
-// TODO(hanchung): Enable the test once we have better ideas about supporting
-// dynamic inner tiling sizes. We are not able to vectorized this case today.
-// func.func @fully_dynamic_pack_simple() {
-//   %iree_input = flow.tensor.constant dense<[
-//     [0, 1, 2, 3],
-//     [4, 5, 6, 7],
-//     [8, 9, 10, 11],
-//     [12, 13, 14, 15]]> : tensor<4x4xi32> -> tensor<?x?xi32>
-//   %c0 = arith.constant 0 : index
-//   %c1 = arith.constant 1 : index
-//   %c2 = util.unfoldable_constant 2 : index
-//   %in_d0 = tensor.dim %iree_input, %c0 : tensor<?x?xi32>
-//   %in_d1 = tensor.dim %iree_input, %c1 : tensor<?x?xi32>
-//   %out_d0 = arith.ceildivui %in_d0, %c2 : index
-//   %out_d1 = arith.ceildivui %in_d1, %c2 : index
-//   %init = tensor.empty(%out_d0, %out_d1, %c2, %c2) : tensor<?x?x?x?xi32>
-//   %pack = tensor.pack %iree_input inner_dims_pos = [0, 1] inner_tiles = [%c2, %c2] into %init
-//       : tensor<?x?xi32> -> tensor<?x?x?x?xi32>
-//   %cast = tensor.cast %pack : tensor<?x?x?x?xi32> to tensor<2x2x2x2xi32>
-//   check.expect_eq_const(%cast, dense<[[[[0, 1], [4, 5]], [[2, 3], [6, 7]]], [[[8, 9], [12, 13]], [[10 ,11], [14, 15]]]]> : tensor<2x2x2x2xi32>) : tensor<2x2x2x2xi32>
-//   return
-// }
-//
-// func.func @fully_dynamic_pack_pad_transpose_inner_and_outer_dims_large() {
-//   %d0 = util.unfoldable_constant 100 : index
-//   %d1 = util.unfoldable_constant 250 : index
-//   %source = call @generate_2D_source(%d0, %d1) : (index, index) -> tensor<?x?xi32>
-//   %padding_value = arith.constant 42 : i32
-//
-//   %c16 = util.unfoldable_constant 16 : index
-//   %c32 = util.unfoldable_constant 32 : index
-//   %tiled_d0 = arith.ceildivui %d0, %c32 : index
-//   %tiled_d1 = arith.ceildivui %d1, %c16 : index
-//   %init_pack = tensor.empty(%tiled_d1, %tiled_d0, %c16, %c32) : tensor<?x?x?x?xi32>
-//   %pack = tensor.pack %source padding_value(%padding_value : i32)
-//       outer_dims_perm = [1, 0] inner_dims_pos = [1, 0] inner_tiles = [%c16, %c32] into %init_pack
-//       : tensor<?x?xi32> -> tensor<?x?x?x?xi32>
-//   %cast_pack = tensor.cast %pack : tensor<?x?x?x?xi32> to tensor<16x4x16x32xi32>
-//
-//   %c100 = arith.constant 100 : index
-//   %c250 = arith.constant 250 : index
-//   %source2 = call @generate_2D_source(%c100, %c250) : (index, index) -> tensor<?x?xi32>
-//   %static_source = tensor.cast %source2 : tensor<?x?xi32> to tensor<100x250xi32>
-//
-//   %pad = tensor.pad %static_source low[0, 0] high[28, 6] {
-//     ^bb0(%b0 : index, %b1 : index):
-//       tensor.yield %padding_value : i32
-//   } : tensor<100x250xi32> to tensor<128x256xi32>
-//   %reshape = tensor.expand_shape %pad [[0, 1], [2, 3]] : tensor<128x256xi32> into tensor<4x32x16x16xi32>
-//   %init_transpose = tensor.empty() : tensor<16x4x16x32xi32>
-//   %transpose = linalg.transpose
-//     ins(%reshape : tensor<4x32x16x16xi32>)
-//     outs(%init_transpose : tensor<16x4x16x32xi32>)
-//     permutation = [2, 0, 3, 1]
-//
-//   check.expect_eq(%cast_pack, %transpose) : tensor<16x4x16x32xi32>
-//   return
-// }
+func.func @fully_dynamic_pack_simple() {
+  %iree_input = flow.tensor.constant dense<[
+    [0, 1, 2, 3],
+    [4, 5, 6, 7],
+    [8, 9, 10, 11],
+    [12, 13, 14, 15]]> : tensor<4x4xi32> -> tensor<?x?xi32>
+  %c0 = arith.constant 0 : index
+  %c1 = arith.constant 1 : index
+  %c2 = util.unfoldable_constant 2 : index
+  %in_d0 = tensor.dim %iree_input, %c0 : tensor<?x?xi32>
+  %in_d1 = tensor.dim %iree_input, %c1 : tensor<?x?xi32>
+  %out_d0 = arith.ceildivui %in_d0, %c2 : index
+  %out_d1 = arith.ceildivui %in_d1, %c2 : index
+  %init = tensor.empty(%out_d0, %out_d1, %c2, %c2) : tensor<?x?x?x?xi32>
+  %pack = tensor.pack %iree_input inner_dims_pos = [0, 1] inner_tiles = [%c2, %c2] into %init
+      : tensor<?x?xi32> -> tensor<?x?x?x?xi32>
+  %cast = tensor.cast %pack : tensor<?x?x?x?xi32> to tensor<2x2x2x2xi32>
+  check.expect_eq_const(%cast, dense<[[[[0, 1], [4, 5]], [[2, 3], [6, 7]]], [[[8, 9], [12, 13]], [[10 ,11], [14, 15]]]]> : tensor<2x2x2x2xi32>) : tensor<2x2x2x2xi32>
+  return
+}
+
+func.func @fully_dynamic_pack_pad_transpose_inner_and_outer_dims_large() {
+  %d0 = util.unfoldable_constant 100 : index
+  %d1 = util.unfoldable_constant 250 : index
+  %source = call @generate_2D_source(%d0, %d1) : (index, index) -> tensor<?x?xi32>
+  %padding_value = arith.constant 42 : i32
+
+  %c16 = util.unfoldable_constant 16 : index
+  %c32 = util.unfoldable_constant 32 : index
+  %tiled_d0 = arith.ceildivui %d0, %c32 : index
+  %tiled_d1 = arith.ceildivui %d1, %c16 : index
+  %init_pack = tensor.empty(%tiled_d1, %tiled_d0, %c16, %c32) : tensor<?x?x?x?xi32>
+  %pack = tensor.pack %source padding_value(%padding_value : i32)
+      outer_dims_perm = [1, 0] inner_dims_pos = [1, 0] inner_tiles = [%c16, %c32] into %init_pack
+      : tensor<?x?xi32> -> tensor<?x?x?x?xi32>
+  %cast_pack = tensor.cast %pack : tensor<?x?x?x?xi32> to tensor<16x4x16x32xi32>
+
+  %c100 = arith.constant 100 : index
+  %c250 = arith.constant 250 : index
+  %source2 = call @generate_2D_source(%c100, %c250) : (index, index) -> tensor<?x?xi32>
+  %static_source = tensor.cast %source2 : tensor<?x?xi32> to tensor<100x250xi32>
+
+  %pad = tensor.pad %static_source low[0, 0] high[28, 6] {
+    ^bb0(%b0 : index, %b1 : index):
+      tensor.yield %padding_value : i32
+  } : tensor<100x250xi32> to tensor<128x256xi32>
+  %reshape = tensor.expand_shape %pad [[0, 1], [2, 3]] : tensor<128x256xi32> into tensor<4x32x16x16xi32>
+  %init_transpose = tensor.empty() : tensor<16x4x16x32xi32>
+  %transpose = linalg.transpose
+    ins(%reshape : tensor<4x32x16x16xi32>)
+    outs(%init_transpose : tensor<16x4x16x32xi32>)
+    permutation = [2, 0, 3, 1]
+
+  check.expect_eq(%cast_pack, %transpose) : tensor<16x4x16x32xi32>
+  return
+}