[DT] Teach encoding about padding. (#17077)

The revision has four major commits. They are expected to be landed
together because we need all the pieces to make it work.

## Make encodings be able to carry padding semantics

https://github.com/openxla/iree/pull/17077/commits/7967faf4d2b442f7f1229d3801dbc95678ebe051
introduces a `round_dims_to` integer array on encodings. It holds the
values that the M, N, and K dimensions are padded to. This gives both the
host and the device a hint about the values that every dimension should
be aligned with. Eventually we should have a better way of propagating
this information between the host side and the device side. This
revision is a step towards that goal.

The commit adds an option to the SetEncoding pass. If `padFactor` is
set, `round_dims_to` is filled with that value, and the pass only
generates set_encoding ops; it does not generate
`iree_linalg_ext.upper_bound_tile_size` and `tensor.pad` ops.

## Teach Pack/UnPack Materialization Patterns about the new field

https://github.com/openxla/iree/pull/17077/commits/2652c028b586b984dca237cfc6cd53a1ffa5235e
teaches the materialization patterns to handle the new field. If the
field is set, the inner tile sizes can't be greater than the corresponding
`round_dims_to` values. Otherwise, the actual buffer size could
mismatch.

## Teach stream.tensor.sizeof to take encoding into account

https://github.com/openxla/iree/pull/17077/commits/365dc413250675ae8e49cfcaccc3e4341d1d3432
teaches stream.tensor.sizeof to calculate proper sizes based on
encodings. The encodings have `role`, `user_indexing_maps`, and
`round_dims_to`. So it is able to look at `role` and `user_indexing_maps`
to infer the contraction dimensions, and it pads each dimension to be
aligned with the corresponding value in `round_dims_to`. E.g.,

```mlir
#map = affine_map<(d0, d1, d2) -> (d0, d2)>
#map1 = affine_map<(d0, d1, d2) -> (d2, d1)>
#map2 = affine_map<(d0, d1, d2) -> (d0, d1)>
util.func public @sizeof_lhs_encoding_dynamic(%arg0: index, %arg1: index) -> index {
  %0 = stream.tensor.sizeof tensor<?x?xf32, #iree_linalg_ext.encoding<
    role = LHS,
    element_types = [f32, f32, f32],
    original_type = tensor<?x?xf32>,
    user_indexing_maps = [#map, #map1, #map2],
    round_dims_to = array<i64: 4, 8, 16>>>{%arg0, %arg1} : index
  util.return %0 : index
}
// CHECK-LABEL: @sizeof_lhs_encoding_dynamic
// CHECK-DAG:     %[[C4:.+]] = arith.constant 4 : index
// CHECK-DAG:     %[[C16:.+]] = arith.constant 16 : index
// CHECK:         %[[CEIL_DIV_D0:.+]] = arith.ceildivui %arg0, %[[C4]]
// CHECK:         %[[PAD_D0:.+]] = arith.muli %[[CEIL_DIV_D0]], %[[C4]]
// CHECK:         %[[CEIL_DIV_D1:.+]] = arith.ceildivui %arg1, %[[C16]]
// CHECK:         %[[PAD_D1:.+]] = arith.muli %[[CEIL_DIV_D1]], %[[C16]]
// CHECK:         %[[T0:.+]] = arith.muli %[[PAD_D0]], %[[C4]]
// CHECK:         %[[T1:.+]] = arith.muli %[[T0]], %[[PAD_D1]]
// CHECK:         return %[[T1]]
```

## Add e2e tests

https://github.com/openxla/iree/pull/17077/commits/5ca4a4b7a5a672e25d257cc9443c030735c305e5
adds a new test suite with
`--iree-global-opt-enable-early-materialization=false`, so we have
enough e2e test coverage for the new path.
diff --git a/compiler/src/iree/compiler/Codegen/Common/CPU/CPUMaterializeEncodingPass.cpp b/compiler/src/iree/compiler/Codegen/Common/CPU/CPUMaterializeEncodingPass.cpp
index b6d9c59..0499976 100644
--- a/compiler/src/iree/compiler/Codegen/Common/CPU/CPUMaterializeEncodingPass.cpp
+++ b/compiler/src/iree/compiler/Codegen/Common/CPU/CPUMaterializeEncodingPass.cpp
@@ -26,6 +26,8 @@
 #include "mlir/Transforms/DialectConversion.h"
 #include "mlir/Transforms/GreedyPatternRewriteDriver.h"
 
+#define DEBUG_TYPE "cpu-materialize-encoding"
+
 namespace mlir::iree_compiler {
 
 using namespace IREE::LinalgExt;
@@ -305,13 +307,23 @@
   return {};
 }
 
-static TileMxNxK chooseMatmulTile(ArrayRef<TileMxNxK> enumeratedTiles,
-                                  int64_t matmulNarrowM,
-                                  int64_t matmulNarrowN) {
+/// Returns the best TileMxNxK from the `enumeratedTiles` pool. If
+/// `hostDefinedUpperBound` is not empty, the chosen tile sizes cannot be
+/// greater than the corresponding values, given in {M, N, K} order.
+/// TODO(#16933): Remove `hostDefinedUpperBound` once we can propagate such
+/// information to host. For now, they are defined by host.
+static TileMxNxK
+chooseMatmulTile(ArrayRef<TileMxNxK> enumeratedTiles, int64_t matmulNarrowM,
+                 int64_t matmulNarrowN,
+                 ArrayRef<int64_t> hostDefinedUpperBound = {}) {
+  assert((hostDefinedUpperBound.empty() || hostDefinedUpperBound.size() == 3) &&
+         "expected hostDefinedUpperBound is empty or has upper bound for {M, "
+         "N, K}");
   // Handle narrow-N by transposing to reduce to narrow-M. Note: the
   // enumeratedTiles currently only enumerate narrow-M cases.
   if (matmulNarrowN && (!matmulNarrowM || matmulNarrowN < matmulNarrowM)) {
-    TileMxNxK tile = chooseMatmulTile(enumeratedTiles, matmulNarrowN, 0);
+    TileMxNxK tile = chooseMatmulTile(enumeratedTiles, matmulNarrowN, 0,
+                                      hostDefinedUpperBound);
     std::swap(tile.M, tile.N);
     return tile;
   }
@@ -342,7 +354,26 @@
   SmallVector<RatedTileMxNxK> ratedTiles;
   ratedTiles.reserve(enumeratedTiles.size());
   int64_t bestPaddingPenalty = INT64_MAX;
+  int64_t mUB = INT64_MAX;
+  int64_t nUB = INT64_MAX;
+  int64_t kUB = INT64_MAX;
+  if (!hostDefinedUpperBound.empty()) {
+    mUB = hostDefinedUpperBound[0];
+    nUB = hostDefinedUpperBound[1];
+    kUB = hostDefinedUpperBound[2];
+  }
   for (auto tile : enumeratedTiles) {
+    if (tile.M > mUB || tile.N > nUB || tile.K > kUB) {
+      LLVM_DEBUG(llvm::dbgs() << "[" << DEBUG_TYPE << "]: tile (";
+                 llvm::interleaveComma(
+                     ArrayRef<int64_t>{tile.M, tile.N, tile.K}, llvm::dbgs());
+                 llvm::dbgs()
+                 << ") is skipped because it is not valid for upper_bound (";
+                 llvm::interleaveComma(ArrayRef<int64_t>{mUB, nUB, kUB},
+                                       llvm::dbgs());
+                 llvm::dbgs() << ")\n");
+      continue;
+    }
     RatedTileMxNxK ratedTile(tile);
     ratedTile.paddingPenalty = 0;
     // If we are choosing a tile for a narrow-M case, we want to minimize
@@ -468,7 +499,9 @@
   // Choose a final matmul TileMxNxK from the above-enumarated tile shapes,
   // taking narrow dimensions into account.
   TileMxNxK chosenTileMxNxK =
-      chooseMatmulTile(enumeratedTileMxNxK, matmulNarrowM, matmulNarrowN);
+      chooseMatmulTile(enumeratedTileMxNxK, matmulNarrowM, matmulNarrowN,
+                       encoding.getRoundDimsToArray());
+
   // Map the matmul TileMxNxK to an actual tile shape for the tensor at hand,
   // based on its role in the matmul.
   auto rank = tensorType.getRank();
diff --git a/compiler/src/iree/compiler/Codegen/Common/CPU/test/llvmcpu_materialize_encoding.mlir b/compiler/src/iree/compiler/Codegen/Common/CPU/test/llvmcpu_materialize_encoding.mlir
index ae13563..31bbc0c 100644
--- a/compiler/src/iree/compiler/Codegen/Common/CPU/test/llvmcpu_materialize_encoding.mlir
+++ b/compiler/src/iree/compiler/Codegen/Common/CPU/test/llvmcpu_materialize_encoding.mlir
@@ -1,5 +1,36 @@
 // RUN: iree-opt --pass-pipeline="builtin.module(func.func(iree-codegen-cpu-materialize-encoding),canonicalize,cse)" --split-input-file %s | FileCheck %s
 
+
+func.func @set_encoding_with_padding_semantics_bf16_x86_64_avx512f() attributes {
+  hal.executable.target = #hal.executable.target<"xyz", "xyz", {target_triple="x86_64-xyz-xyz", cpu_features="+avx512f"}>
+}{
+  %c0 = arith.constant 0 : index
+  %0 = hal.interface.binding.subspan set(0) binding(0) type(storage_buffer) alignment(64) offset(%c0) flags(ReadOnly) : !flow.dispatch.tensor<readonly:tensor<1x1000xbf16>>
+  %1 = hal.interface.binding.subspan set(0) binding(1) type(storage_buffer) alignment(64) offset(%c0) : !flow.dispatch.tensor<writeonly:tensor<1x1000xbf16, #iree_linalg_ext.encoding<role =  LHS, element_types = [bf16, bf16, bf16], original_type = tensor<1x1000xbf16>, matmul_narrow_M = 1 : index, user_indexing_maps = [affine_map<(d0, d1, d2) -> (d0, d2)>, affine_map<(d0, d1, d2) -> (d2, d1)>, affine_map<(d0, d1, d2) -> (d0, d1)>], round_dims_to = array<i64: 16, 16, 16>>>>
+  %2 = flow.dispatch.tensor.load %0, offsets = [0, 0], sizes = [1, 1000], strides = [1, 1] : !flow.dispatch.tensor<readonly:tensor<1x1000xbf16>> -> tensor<1x1000xbf16>
+  %3 = iree_linalg_ext.set_encoding %2 : tensor<1x1000xbf16> -> tensor<1x1000xbf16, #iree_linalg_ext.encoding<role =  LHS, element_types = [bf16, bf16, bf16], original_type = tensor<1x1000xbf16>, matmul_narrow_M = 1 : index, user_indexing_maps = [affine_map<(d0, d1, d2) -> (d0, d2)>, affine_map<(d0, d1, d2) -> (d2, d1)>, affine_map<(d0, d1, d2) -> (d0, d1)>], round_dims_to = array<i64: 16, 16, 16>>>
+  flow.dispatch.tensor.store %3, %1, offsets = [0, 0], sizes = [1, 1000], strides = [1, 1] : tensor<1x1000xbf16, #iree_linalg_ext.encoding<role =  LHS, element_types = [bf16, bf16, bf16], original_type = tensor<1x1000xbf16>, matmul_narrow_M = 1  : index, user_indexing_maps = [affine_map<(d0, d1, d2) -> (d0, d2)>, affine_map<(d0, d1, d2) -> (d2, d1)>, affine_map<(d0, d1, d2) -> (d0, d1)>], round_dims_to = array<i64: 16, 16, 16>>> -> !flow.dispatch.tensor<writeonly:tensor<1x1000xbf16,  #iree_linalg_ext.encoding<role =  LHS, element_types = [bf16, bf16, bf16], original_type = tensor<1x1000xbf16>, matmul_narrow_M = 1 : index, user_indexing_maps = [affine_map<(d0, d1, d2) -> (d0, d2)>, affine_map<(d0, d1, d2) -> (d2, d1)>, affine_map<(d0, d1, d2) -> (d0, d1)>], round_dims_to = array<i64: 16, 16, 16>>>>
+  return
+}
+// This tests that
+//   1. The padding value is created for tensor.pack ops.
+//   2. The inner tile sizes are less than or equal to values in round_dims_to.
+//      We could choose 128 when it is a narrow matrix.
+// CHECK-LABEL: func.func @set_encoding_with_padding_semantics_bf16_x86_64_avx512f
+// CHECK-DAG:     %[[PAD:.+]] = arith.constant 0.000000e+00 : bf16
+// CHECK-DAG:     %[[IN_BINDING:.+]] = hal.interface.binding.subspan {{.+}} : !flow.dispatch.tensor<readonly:tensor<1x1000xbf16>>
+// CHECK-DAG:     %[[OUT_BINDING:.+]] = hal.interface.binding.subspan {{.+}} : !flow.dispatch.tensor<writeonly:tensor<1x1000x16x1xbf16>>
+// CHECK:         %[[SRC:.+]] = flow.dispatch.tensor.load %[[IN_BINDING]]
+// CHECK-DAG:     %[[INIT:.+]] = tensor.empty() : tensor<1x1000x16x1xbf16>
+// CHECK:         %[[PACK:.+]] = tensor.pack %[[SRC]] padding_value(%[[PAD]] : bf16)
+// CHECK-SAME:      outer_dims_perm = [0, 1]
+// CHECK-SAME:      inner_dims_pos = [0, 1]
+// CHECK-SAME:      inner_tiles = [16, 1]
+// CHECK-SAME:      into %[[INIT]] : tensor<1x1000xbf16> -> tensor<1x1000x16x1xbf16>
+// CHECK:         flow.dispatch.tensor.store %[[PACK]], %[[OUT_BINDING]]
+
+// -----
+
 #map = affine_map<(d0, d1, d2) -> (d0, d2)>
 #map1 = affine_map<(d0, d1, d2) -> (d2, d1)>
 #map2 = affine_map<(d0, d1, d2) -> (d0, d1)>
diff --git a/compiler/src/iree/compiler/Codegen/Common/MaterializeEncodingIntoPackUnPack.cpp b/compiler/src/iree/compiler/Codegen/Common/MaterializeEncodingIntoPackUnPack.cpp
index 250152e..177fb22 100644
--- a/compiler/src/iree/compiler/Codegen/Common/MaterializeEncodingIntoPackUnPack.cpp
+++ b/compiler/src/iree/compiler/Codegen/Common/MaterializeEncodingIntoPackUnPack.cpp
@@ -220,7 +220,13 @@
   if (!encoding) {
     return failure();
   }
-  std::optional<Value> paddingValue = getPaddingValue(source);
+  std::optional<Value> paddingValue;
+  if (encoding.getRoundDimsToArray().empty()) {
+    paddingValue = getPaddingValue(source);
+  } else {
+    paddingValue = rewriter.create<arith::ConstantOp>(
+        loc, rewriter.getZeroAttr(resultType.getElementType()));
+  }
   SmallVector<OpFoldResult> sourceDims =
       tensor::getMixedSizes(rewriter, loc, source);
   SmallVector<OpFoldResult> resultDims = tensor::PackOp::getResultShape(
diff --git a/compiler/src/iree/compiler/Dialect/LinalgExt/IR/LinalgExtBase.td b/compiler/src/iree/compiler/Dialect/LinalgExt/IR/LinalgExtBase.td
index 3d3dd36..2b34906 100644
--- a/compiler/src/iree/compiler/Dialect/LinalgExt/IR/LinalgExtBase.td
+++ b/compiler/src/iree/compiler/Dialect/LinalgExt/IR/LinalgExtBase.td
@@ -93,7 +93,9 @@
     // TODO(#15466): generalize matmul_narrow_{M,N} into a list?
     OptionalParameter<"IntegerAttr", "optional M narrow dimension size (only for contraction op user_indexing_maps)">:$matmul_narrow_M,
     OptionalParameter<"IntegerAttr", "optional N narrow dimension size (only for contraction op user_indexing_maps)">:$matmul_narrow_N,
-    OptionalParameter<"ArrayAttr", "Indexing maps of the operation using this tensor">:$user_indexing_maps
+    OptionalParameter<"ArrayAttr", "Indexing maps of the operation using this tensor">:$user_indexing_maps,
+    // TODO(hanchung): The round_dims_to parameter can be revisited. We explicitly map them to M,N,K dimension for now.
+    OptionalParameter<"DenseArrayAttr", "Values for padding M,N,K dimensions">:$round_dims_to
   );
 
   let builders = [
@@ -101,7 +103,8 @@
         "ArrayRef<Type>":$elemTypes, "Type":$origType,
         CArg<"std::optional<int64_t>", "{}">:$matmulNarrowM,
         CArg<"std::optional<int64_t>", "{}">:$matmulNarrowN,
-        CArg<"ArrayRef<AffineMap>", "{}">:$maps)>
+        CArg<"ArrayRef<AffineMap>", "{}">:$maps,
+        CArg<"ArrayRef<int64_t>", "{}">:$roundDimsTo)>
   ];
 
   let extraClassDeclaration = [{
@@ -111,6 +114,9 @@
     /// Given the dim position of the encoding `user_indexing_maps`, returns the
     /// matching index of the given encoding's tensor.
     unsigned mapDimToRoleIndex(int64_t dimPos);
+
+    /// Returns an integer array with values in `round_dims_to`.
+    ArrayRef<int64_t> getRoundDimsToArray();
   }];
 
   let genVerifyDecl = 0;
diff --git a/compiler/src/iree/compiler/Dialect/LinalgExt/IR/LinalgExtOps.cpp b/compiler/src/iree/compiler/Dialect/LinalgExt/IR/LinalgExtOps.cpp
index 55e4f3b..044a276 100644
--- a/compiler/src/iree/compiler/Dialect/LinalgExt/IR/LinalgExtOps.cpp
+++ b/compiler/src/iree/compiler/Dialect/LinalgExt/IR/LinalgExtOps.cpp
@@ -2894,16 +2894,20 @@
                                ArrayRef<Type> elemTypes, Type origType,
                                std::optional<int64_t> matmulNarrowM,
                                std::optional<int64_t> matmulNarrowN,
-                               ArrayRef<AffineMap> maps) {
+                               ArrayRef<AffineMap> maps,
+                               ArrayRef<int64_t> roundDimsTo) {
   Builder b(ctx);
   auto optionalToAttr = [&](std::optional<int64_t> x) {
     return x ? b.getIndexAttr(*x) : IntegerAttr();
   };
   auto roleAttr = EncodingRoleAttr::get(ctx, role);
   auto origTypeAttr = origType ? TypeAttr::get(origType) : TypeAttr();
+  auto roundDimsToAttr = roundDimsTo.empty()
+                             ? DenseI64ArrayAttr()
+                             : b.getDenseI64ArrayAttr(roundDimsTo);
   return get(ctx, roleAttr, b.getTypeArrayAttr(elemTypes), origTypeAttr,
              optionalToAttr(matmulNarrowM), optionalToAttr(matmulNarrowN),
-             b.getAffineMapArrayAttr(maps));
+             b.getAffineMapArrayAttr(maps), roundDimsToAttr);
 }
 
 AffineMap EncodingAttr::getMapForRole() {
@@ -2927,6 +2931,14 @@
   return idx.value();
 }
 
+ArrayRef<int64_t> EncodingAttr::getRoundDimsToArray() {
+  // An unset `round_dims_to` is reported as an empty array.
+  if (auto dims = getRoundDimsTo()) {
+    return dims.cast<DenseI64ArrayAttr>().asArrayRef();
+  }
+  return {};
+}
+
 //===---------------------------------------------------------------------===//
 // LinalgExt Dialect Helpers
 //===---------------------------------------------------------------------===//
diff --git a/compiler/src/iree/compiler/Dialect/Stream/Transforms/test/encode_host_tensors.mlir b/compiler/src/iree/compiler/Dialect/Stream/Transforms/test/encode_host_tensors.mlir
index cb1aca7..8416e8a 100644
--- a/compiler/src/iree/compiler/Dialect/Stream/Transforms/test/encode_host_tensors.mlir
+++ b/compiler/src/iree/compiler/Dialect/Stream/Transforms/test/encode_host_tensors.mlir
@@ -21,6 +21,67 @@
 
 // -----
 
+#map = affine_map<(d0, d1, d2) -> (d0, d2)>
+#map1 = affine_map<(d0, d1, d2) -> (d2, d1)>
+#map2 = affine_map<(d0, d1, d2) -> (d0, d1)>
+util.func public @sizeof_lhs_encoding_dynamic(%arg0: index, %arg1: index) -> index {
+  %0 = stream.tensor.sizeof tensor<?x?xf32, #iree_linalg_ext.encoding<role = LHS, element_types = [f32, f32, f32], original_type = tensor<?x?xf32>, user_indexing_maps = [#map, #map1, #map2], round_dims_to = array<i64: 4, 8, 16>>>{%arg0, %arg1} : index
+  util.return %0 : index
+}
+// CHECK-LABEL: @sizeof_lhs_encoding_dynamic
+// CHECK-DAG:     %[[C4:.+]] = arith.constant 4 : index
+// CHECK-DAG:     %[[C16:.+]] = arith.constant 16 : index
+// CHECK:         %[[CEIL_DIV_D0:.+]] = arith.ceildivui %arg0, %[[C4]]
+// CHECK:         %[[PAD_D0:.+]] = arith.muli %[[CEIL_DIV_D0]], %[[C4]]
+// CHECK:         %[[CEIL_DIV_D1:.+]] = arith.ceildivui %arg1, %[[C16]]
+// CHECK:         %[[PAD_D1:.+]] = arith.muli %[[CEIL_DIV_D1]], %[[C16]]
+// CHECK:         %[[T0:.+]] = arith.muli %[[PAD_D0]], %[[C4]]
+// CHECK:         %[[T1:.+]] = arith.muli %[[T0]], %[[PAD_D1]]
+// CHECK:         return %[[T1]]
+
+// -----
+
+#map = affine_map<(d0, d1, d2) -> (d0, d2)>
+#map1 = affine_map<(d0, d1, d2) -> (d2, d1)>
+#map2 = affine_map<(d0, d1, d2) -> (d0, d1)>
+util.func public @sizeof_rhs_encoding_dynamic(%arg0: index, %arg1: index) -> index {
+  %0 = stream.tensor.sizeof tensor<?x?xf32, #iree_linalg_ext.encoding<role = RHS, element_types = [f32, f32, f32], original_type = tensor<?x?xf32>, user_indexing_maps = [#map, #map1, #map2], round_dims_to = array<i64: 4, 8, 16>>>{%arg0, %arg1} : index
+  util.return %0 : index
+}
+// CHECK-LABEL: @sizeof_rhs_encoding_dynamic
+// CHECK-DAG:     %[[C4:.+]] = arith.constant 4 : index
+// CHECK-DAG:     %[[C8:.+]] = arith.constant 8 : index
+// CHECK-DAG:     %[[C16:.+]] = arith.constant 16 : index
+// CHECK:         %[[CEIL_DIV_D1:.+]] = arith.ceildivui %arg1, %[[C8]]
+// CHECK:         %[[PAD_D1:.+]] = arith.muli %[[CEIL_DIV_D1]], %[[C8]]
+// CHECK:         %[[CEIL_DIV_D0:.+]] = arith.ceildivui %arg0, %[[C16]]
+// CHECK:         %[[PAD_D0:.+]] = arith.muli %[[CEIL_DIV_D0]], %[[C16]]
+// CHECK:         %[[T0:.+]] = arith.muli %[[PAD_D0]], %[[C4]]
+// CHECK:         %[[T1:.+]] = arith.muli %[[T0]], %[[PAD_D1]]
+// CHECK:         return %[[T1]]
+
+// -----
+
+#map = affine_map<(d0, d1, d2) -> (d0, d2)>
+#map1 = affine_map<(d0, d1, d2) -> (d2, d1)>
+#map2 = affine_map<(d0, d1, d2) -> (d0, d1)>
+util.func public @sizeof_result_encoding_dynamic(%arg0: index, %arg1: index) -> index {
+  %0 = stream.tensor.sizeof tensor<?x?xf32, #iree_linalg_ext.encoding<role = RESULT, element_types = [f32, f32, f32], original_type = tensor<?x?xf32>, user_indexing_maps = [#map, #map1, #map2], round_dims_to = array<i64: 4, 8, 16>>>{%arg0, %arg1} : index
+  util.return %0 : index
+}
+// CHECK-LABEL: @sizeof_result_encoding_dynamic
+// CHECK-DAG:     %[[C4:.+]] = arith.constant 4 : index
+// CHECK-DAG:     %[[C8:.+]] = arith.constant 8 : index
+// CHECK:         %[[CEIL_DIV_D0:.+]] = arith.ceildivui %arg0, %[[C4]]
+// CHECK:         %[[PAD_D0:.+]] = arith.muli %[[CEIL_DIV_D0]], %[[C4]]
+// CHECK:         %[[CEIL_DIV_D1:.+]] = arith.ceildivui %arg1, %[[C8]]
+// CHECK:         %[[PAD_D1:.+]] = arith.muli %[[CEIL_DIV_D1]], %[[C8]]
+// CHECK:         %[[T0:.+]] = arith.muli %[[PAD_D0]], %[[C4]]
+// CHECK:         %[[T1:.+]] = arith.muli %[[T0]], %[[PAD_D1]]
+// CHECK:         return %[[T1]]
+
+// -----
+
 // CHECK-LABEL: @denseTensorEmpty
 util.func public @denseTensorEmpty(%arg0: index, %arg1: index) -> !stream.resource<*> {
   // CHECK: %[[RET:.+]] = stream.async.alloca : !stream.resource<*>{%arg1}
diff --git a/compiler/src/iree/compiler/GlobalOptimization/Passes.cpp b/compiler/src/iree/compiler/GlobalOptimization/Passes.cpp
index a6dc62b..94dae26 100644
--- a/compiler/src/iree/compiler/GlobalOptimization/Passes.cpp
+++ b/compiler/src/iree/compiler/GlobalOptimization/Passes.cpp
@@ -163,7 +163,10 @@
   if (transformOptions.options.dataTiling) {
     // TODO(hanchung): Make data-tiling passes be FunctionOpInterface pass, so
     // we can use `FunctionLikNest` here.
-    mainPassManager.addPass(createSetEncodingPass());
+    // TODO(hanchung): Make it controllable through flags. It is fine for now
+    // because it is an experimental path.
+    const int64_t kPadFactor = clEnableEarlyMaterialization ? 0 : 16;
+    mainPassManager.addPass(createSetEncodingPass(kPadFactor));
     if (clEnableEarlyMaterialization) {
       mainPassManager.addPass(createMaterializeHomogeneousEncodingsPass());
     }
diff --git a/compiler/src/iree/compiler/GlobalOptimization/Passes.h b/compiler/src/iree/compiler/GlobalOptimization/Passes.h
index 856a35e..e2f79ba 100644
--- a/compiler/src/iree/compiler/GlobalOptimization/Passes.h
+++ b/compiler/src/iree/compiler/GlobalOptimization/Passes.h
@@ -108,8 +108,10 @@
 std::unique_ptr<InterfacePass<mlir::FunctionOpInterface>>
 createRemoveZeroExtentTensorsPass();
 
-/// Sets encoding for tensors to allow tiled execution of operations.
-std::unique_ptr<Pass> createSetEncodingPass();
+/// Sets encoding for tensors to allow tiled execution of operations. If
+/// `padFactor` is set to non-zero, the padding size hints will be attached to
+/// encodings. This makes the host and device agree on the same padding sizes.
+std::unique_ptr<Pass> createSetEncodingPass(int64_t padFactor = 0);
 
 /// Simplifies tensor pack/unpack ops to reshape ops.
 std::unique_ptr<Pass> createSimplifyPackUnpackPass();
diff --git a/compiler/src/iree/compiler/GlobalOptimization/Passes.td b/compiler/src/iree/compiler/GlobalOptimization/Passes.td
index 5b4c01a..cf5e85b 100644
--- a/compiler/src/iree/compiler/GlobalOptimization/Passes.td
+++ b/compiler/src/iree/compiler/GlobalOptimization/Passes.td
@@ -127,6 +127,13 @@
 def SetEncoding : Pass<"iree-global-opt-set-encoding", ""> {
   let summary = "Introduces tensor encoding for compute operations";
   let constructor = "mlir::iree_compiler::GlobalOptimization::createSetEncodingPass()";
+  let options = [  // Adjacent literals concatenate verbatim; keep the spaces.
+    Option<"padFactor", "pad-factor", "int64_t", /*default=*/"0",
+           "The padding sizes hint will be attached to encodings if it is set "
+           "to non-zero. Otherwise, it creates "
+           "iree_linalg_ext.upper_bound_tile_size and relies on backends to "
+           "resolve them.">,
+  ];
 }
 
 def SimplifyPackUnpack : Pass<"iree-global-opt-simplify-pack-unpack", ""> {
diff --git a/compiler/src/iree/compiler/GlobalOptimization/SetEncoding.cpp b/compiler/src/iree/compiler/GlobalOptimization/SetEncoding.cpp
index 7b3380d..5e3760e 100644
--- a/compiler/src/iree/compiler/GlobalOptimization/SetEncoding.cpp
+++ b/compiler/src/iree/compiler/GlobalOptimization/SetEncoding.cpp
@@ -272,9 +272,13 @@
 
 namespace {
 
-struct setContractionOpEncoding
+class setContractionOpEncoding
     : public OpInterfaceRewritePattern<linalg::LinalgOp> {
+public:
   using OpInterfaceRewritePattern<linalg::LinalgOp>::OpInterfaceRewritePattern;
+  explicit setContractionOpEncoding(MLIRContext *ctx, int64_t factor)
+      : OpInterfaceRewritePattern<linalg::LinalgOp>(ctx), padFactor(factor) {}
+
   LogicalResult matchAndRewrite(linalg::LinalgOp linalgOp,
                                 PatternRewriter &rewriter) const override {
     if (!linalgOp.hasPureTensorSemantics()) {
@@ -320,12 +324,27 @@
 
     Location loc = linalgOp.getLoc();
     SmallVector<AffineMap> maps = linalgOp.getIndexingMapsArray();
-    Value encodedLhs = padAndSetEncoding(rewriter, loc, lhs, EncodingRole::LHS,
-                                         elemTypes, narrowSizes, maps);
-    Value encodedRhs = padAndSetEncoding(rewriter, loc, rhs, EncodingRole::RHS,
-                                         elemTypes, narrowSizes, maps);
-    Value encodedOut = padAndSetEncoding(
-        rewriter, loc, out, EncodingRole::RESULT, elemTypes, narrowSizes, maps);
+    Value encodedLhs, encodedRhs, encodedOut;
+
+    if (!padFactor) {
+      encodedLhs = padAndSetEncoding(rewriter, loc, lhs, EncodingRole::LHS,
+                                     elemTypes, narrowSizes, maps);
+      encodedRhs = padAndSetEncoding(rewriter, loc, rhs, EncodingRole::RHS,
+                                     elemTypes, narrowSizes, maps);
+      encodedOut = padAndSetEncoding(rewriter, loc, out, EncodingRole::RESULT,
+                                     elemTypes, narrowSizes, maps);
+    } else {
+      auto setEncodingWrapper = [&](Value src, EncodingRole role) -> Value {
+        SmallVector<int64_t> roundDimsTo(linalgOp.getNumLoops(), padFactor);
+        auto encoding = EncodingAttr::get(
+            linalgOp.getContext(), role, elemTypes, src.getType(),
+            narrowSizes.M, narrowSizes.N, maps, roundDimsTo);
+        return setEncoding(rewriter, loc, src, encoding);
+      };
+      encodedLhs = setEncodingWrapper(lhs, EncodingRole::LHS);
+      encodedRhs = setEncodingWrapper(rhs, EncodingRole::RHS);
+      encodedOut = setEncodingWrapper(out, EncodingRole::RESULT);
+    }
     Value opTiled = clone(rewriter, linalgOp, encodedOut.getType(),
                           ValueRange{encodedLhs, encodedRhs, encodedOut})
                         ->getResult(0);
@@ -344,6 +363,9 @@
     rewriter.replaceOp(linalgOp, result);
     return success();
   }
+
+private:
+  int64_t padFactor = 0;
 };
 
 /// Pattern to fold a `linalg.fill` -> `iree_linalg_ext.set_encoding`
@@ -376,6 +398,7 @@
   void getDependentDialects(DialectRegistry &registry) const override {
     registry.insert<IREE::LinalgExt::IREELinalgExtDialect>();
   }
+  explicit SetEncodingPass(int64_t factor) { this->padFactor.setValue(factor); }
 
   void runOnOperation() override;
 };
@@ -385,7 +408,7 @@
   MLIRContext *context = &getContext();
   {
     RewritePatternSet patterns(context);
-    patterns.insert<setContractionOpEncoding>(context);
+    patterns.insert<setContractionOpEncoding>(context, padFactor);
     linalg::FillOp::getCanonicalizationPatterns(patterns, context);
     patterns.insert<FoldFillWithSetEncoding>(context);
     memref::populateResolveRankedShapedTypeResultDimsPatterns(patterns);
@@ -396,8 +419,8 @@
   }
 }
 
-std::unique_ptr<Pass> createSetEncodingPass() {
-  return std::make_unique<SetEncodingPass>();
+std::unique_ptr<Pass> createSetEncodingPass(int64_t padFactor) {
+  return std::make_unique<SetEncodingPass>(padFactor);
 }
 
 } // namespace mlir::iree_compiler::GlobalOptimization
diff --git a/compiler/src/iree/compiler/GlobalOptimization/test/set_encoding.mlir b/compiler/src/iree/compiler/GlobalOptimization/test/set_encoding.mlir
index 2d95dcd..2abc02f 100644
--- a/compiler/src/iree/compiler/GlobalOptimization/test/set_encoding.mlir
+++ b/compiler/src/iree/compiler/GlobalOptimization/test/set_encoding.mlir
@@ -1,4 +1,5 @@
 // RUN: iree-opt --iree-global-opt-set-encoding --cse --split-input-file %s | FileCheck %s
+// RUN: iree-opt --iree-global-opt-set-encoding="pad-factor=16" --cse --split-input-file %s | FileCheck %s --check-prefix=PAD-WITHIN-ENCODING
 
 util.func public @matmul_f32f32f32(%arg0 : tensor<100x250xf32>, %arg1 : tensor<250x500xf32>,
     %arg2 : tensor<100x500xf32>) -> tensor<100x500xf32> {
@@ -44,6 +45,21 @@
 //      CHECK:   %[[RESULT_PADDED:.+]] = iree_linalg_ext.unset_encoding %[[MATMUL]]
 //      CHECK:   %[[RESULT:.+]] = tensor.extract_slice %[[RESULT_PADDED]][0, 0] [100, 500] [1, 1]
 //      CHECK:   util.return %[[RESULT]]
+// The only difference with `pad-factor` being set is creating pad ops or not.
+// Having a single test for now is okay, others are covered in the other path.
+// PAD-WITHIN-ENCODING-DAG: #[[MAP1:.+]] = affine_map<(d0, d1, d2) -> (d0, d2)>
+// PAD-WITHIN-ENCODING-DAG: #[[MAP2:.+]] = affine_map<(d0, d1, d2) -> (d2, d1)>
+// PAD-WITHIN-ENCODING-DAG: #[[MAP3:.+]] = affine_map<(d0, d1, d2) -> (d0, d1)>
+// PAD-WITHIN-ENCODING:      util.func public @matmul_f32f32f32(
+// PAD-WITHIN-ENCODING-SAME:   %[[ARG0:[a-zA-Z0-9]+]]
+// PAD-WITHIN-ENCODING-SAME:   %[[ARG1:[a-zA-Z0-9]+]]
+// PAD-WITHIN-ENCODING-SAME:   %[[ARG2:[a-zA-Z0-9]+]]
+// PAD-WITHIN-ENCODING:        %[[LHS:.+]] = iree_linalg_ext.set_encoding %[[ARG0]]
+// PAD-WITHIN-ENCODING-SAME:     tensor<100x250xf32, #iree_linalg_ext.encoding<role = LHS, element_types = [f32, f32, f32], original_type = tensor<100x250xf32>, user_indexing_maps = [#[[MAP1]], #[[MAP2]], #[[MAP3]]], round_dims_to = array<i64: 16, 16, 16>>>
+// PAD-WITHIN-ENCODING:        %[[RHS:.+]] = iree_linalg_ext.set_encoding %[[ARG1]]
+// PAD-WITHIN-ENCODING-SAME:     tensor<250x500xf32, #iree_linalg_ext.encoding<role = RHS, element_types = [f32, f32, f32], original_type = tensor<250x500xf32>, user_indexing_maps = [#[[MAP1]], #[[MAP2]], #[[MAP3]]], round_dims_to = array<i64: 16, 16, 16>>>
+// PAD-WITHIN-ENCODING:        %[[LHS:.+]] = iree_linalg_ext.set_encoding %[[ARG2]]
+// PAD-WITHIN-ENCODING-SAME:     tensor<100x500xf32, #iree_linalg_ext.encoding<role = RESULT, element_types = [f32, f32, f32], original_type = tensor<100x500xf32>, user_indexing_maps = [#[[MAP1]], #[[MAP2]], #[[MAP3]]], round_dims_to = array<i64: 16, 16, 16>>>
 
 // -----
 
diff --git a/compiler/src/iree/compiler/Utils/BUILD.bazel b/compiler/src/iree/compiler/Utils/BUILD.bazel
index c094aa1..a3e0869 100644
--- a/compiler/src/iree/compiler/Utils/BUILD.bazel
+++ b/compiler/src/iree/compiler/Utils/BUILD.bazel
@@ -49,6 +49,8 @@
     ],
     deps = [
         "//compiler/src/iree/compiler/Dialect/Util/IR",
+        #TODO(hanchung): Move encodings to Util/, so it does not need the dep.
+        "//compiler/src/iree/compiler/Dialect/LinalgExt/IR",
         "//runtime/src/iree/base",
         "//runtime/src/iree/base/internal/flatcc:building",
         "//runtime/src/iree/base/internal/flatcc:debugging",
diff --git a/compiler/src/iree/compiler/Utils/CMakeLists.txt b/compiler/src/iree/compiler/Utils/CMakeLists.txt
index 042552a..74c0d84 100644
--- a/compiler/src/iree/compiler/Utils/CMakeLists.txt
+++ b/compiler/src/iree/compiler/Utils/CMakeLists.txt
@@ -57,6 +57,7 @@
     iree::base::internal::flatcc::building
     iree::base::internal::flatcc::debugging
     iree::base::internal::flatcc::parsing
+    iree::compiler::Dialect::LinalgExt::IR
     iree::compiler::Dialect::Util::IR
   PUBLIC
 )
diff --git a/compiler/src/iree/compiler/Utils/ElementPackingUtils.cpp b/compiler/src/iree/compiler/Utils/ElementPackingUtils.cpp
index 2d24d0b..07d2f40 100644
--- a/compiler/src/iree/compiler/Utils/ElementPackingUtils.cpp
+++ b/compiler/src/iree/compiler/Utils/ElementPackingUtils.cpp
@@ -6,6 +6,7 @@
 
 #include "iree/compiler/Utils/ElementPackingUtils.h"
 
+#include "iree/compiler/Dialect/LinalgExt/IR/LinalgExtOps.h"
 #include "iree/compiler/Dialect/Util/IR/UtilTypes.h"
 #include "llvm/Support/CommandLine.h"
 #include "llvm/Support/MathExtras.h"
@@ -60,15 +61,54 @@
   if (!needToPackSubByteElementBitWidth(elementBits)) {
     staticCount *= IREE::Util::getRoundedElementByteWidth(alignedElementType);
   }
+
+  // TODO: Should we use makeComposedFoldedAffineApply here, so the
+  // index computation can be much simpler?
+  SmallVector<int64_t> paddedShape(shapedType.getShape());
+  SmallVector<Value> paddedDynamicDims(dynamicDims.begin(), dynamicDims.end());
+  auto encoding = IREE::LinalgExt::getEncodingAttr(shapedType);
+  if (encoding && !encoding.getRoundDimsToArray().empty()) {
+    auto roundDimsTo = encoding.getRoundDimsToArray();
+    FailureOr<linalg::ContractionDimensions> cDims =
+        IREE::LinalgExt::getEncodingContractionDims(encoding);
+    auto indexingMap = encoding.getMapForRole();
+    auto pad = [&](int dim, int value) {
+      std::optional<unsigned> maybeMappedDim =
+          indexingMap.getResultPosition(builder.getAffineDimExpr(dim));
+      if (!maybeMappedDim) {
+        return;
+      }
+      unsigned mappedDim = maybeMappedDim.value();
+      if (shapedType.isDynamicDim(mappedDim)) {
+        auto alignment = builder.create<arith::ConstantIndexOp>(loc, value);
+        paddedDynamicDims[mappedDim] = builder.create<arith::CeilDivUIOp>(
+            loc, paddedDynamicDims[mappedDim], alignment);
+        paddedDynamicDims[mappedDim] = builder.create<arith::MulIOp>(
+            loc, paddedDynamicDims[mappedDim], alignment);
+      } else {
+        paddedShape[mappedDim] = llvm::alignTo(paddedShape[mappedDim], value);
+      }
+    };
+    for (auto m : cDims->m) {
+      pad(m, roundDimsTo[0]);
+    }
+    for (auto n : cDims->n) {
+      pad(n, roundDimsTo[1]);
+    }
+    for (auto k : cDims->k) {
+      pad(k, roundDimsTo[2]);
+    }
+  }
+
   for (unsigned i = 0; i < shapedType.getRank(); ++i) {
     if (!shapedType.isDynamicDim(i))
-      staticCount *= shapedType.getDimSize(i);
+      staticCount *= paddedShape[i];
   }
 
   // Scale by dynamic dims, if present.
   auto value =
       builder.create<arith::ConstantIndexOp>(loc, staticCount).getResult();
-  for (auto dim : dynamicDims) {
+  for (auto dim : paddedDynamicDims) {
     value = builder.createOrFold<arith::MulIOp>(loc, value, dim);
   }
   // Sub-byte packing requires putting multiple elements in the same byte.
@@ -77,7 +117,7 @@
     unsigned byteElements = 8 / elementBits;
     // Perform some basic sanity check to make sure the total count is byte
     // aligned for fully static shapes.
-    if (dynamicDims.empty() && (staticCount * elementBits) % 8 != 0) {
+    if (paddedDynamicDims.empty() && (staticCount * elementBits) % 8 != 0) {
       return nullptr;
     }
     auto divisor = builder.create<arith::ConstantIndexOp>(loc, byteElements);
diff --git a/tests/e2e/matmul/BUILD.bazel b/tests/e2e/matmul/BUILD.bazel
index 9a5d57f..c8511d1 100644
--- a/tests/e2e/matmul/BUILD.bazel
+++ b/tests/e2e/matmul/BUILD.bazel
@@ -193,6 +193,83 @@
     "large",
 ]]
 
+# LLVMCPU, data-tiling, data-tiling + ukernels + late materialization.
+[iree_generated_e2e_runner_test(
+    name = "e2e_matmul_cpu_experimental_dt%s_%s_%s_%s" % (
+        ("_uk" if use_uk else ""),
+        lhs_rhs_type,
+        acc_type,
+        size,
+    ),
+    compiler_flags = [
+        "--iree-opt-data-tiling",
+        "--iree-global-opt-enable-early-materialization=false",
+    ] + ["--iree-llvmcpu-enable-ukernels=%s" % ("all" if use_uk else "none")],
+    generator = ":generate_e2e_matmul_tests",
+    generator_args = [
+        "--lhs_rhs_type=%s" % lhs_rhs_type,
+        "--acc_type=%s" % acc_type,
+        "--shapes=%s" % size,
+    ],
+    tags = ([
+        # "--shapes=large" can cause timeouts on sanitizers.
+        "noasan",
+        "notsan",
+    ] if size == "large" else []) + ([
+        # "--shapes=large" can cause timeouts on RISC-V emulator.
+        # f16/bf16 trigger internal LLVM assertion errors on riscv and wasm.
+        "noriscv",
+        "nowasm",
+    ] if (lhs_rhs_type == "f16" or lhs_rhs_type == "bf16") else []),
+    target_backends_and_drivers = [
+        ("llvm-cpu", "local-task"),
+    ],
+    target_cpu_features_variants = ["default"] +
+                                   ([
+                                       "arm_64:dotprod:+dotprod",
+                                       "arm_64:i8mm:+i8mm",
+                                       "x86_64:avx512vnni:" + ",".join(X86_64_AVX512_VNNI),
+                                   ] if lhs_rhs_type == "i8" and acc_type == "i32" else [
+                                       "x86_64:avx2:" + ",".join(X86_64_AVX2),
+                                       "x86_64:avx512:" + ",".join(X86_64_AVX512),
+                                   ] if lhs_rhs_type == "f32" and acc_type == "f32" else [
+                                       "x86_64:avx2:" + ",".join(X86_64_AVX2),
+                                       "x86_64:avx512:" + ",".join(X86_64_AVX512),
+                                       "arm_64:fullfp16:+fullfp16",
+                                   ] if lhs_rhs_type == "f16" and acc_type == "f16" else [
+                                       "x86_64:avx2:" + ",".join(X86_64_AVX2),
+                                       "x86_64:avx512:" + ",".join(X86_64_AVX512),
+                                       "arm_64:fp16fml:+fp16fml",
+                                   ] if lhs_rhs_type == "f16" and acc_type == "f32" else [
+                                       "x86_64:avx2:" + ",".join(X86_64_AVX2),
+                                       "x86_64:avx512:" + ",".join(X86_64_AVX512),
+                                       "x86_64:avx512bf16:" + ",".join(X86_64_AVX512_BF16),
+                                       "arm_64:bf16:+bf16",
+                                   ] if lhs_rhs_type == "bf16" and acc_type == "bf16" else [
+                                       "x86_64:avx2:" + ",".join(X86_64_AVX2),
+                                       "x86_64:avx512:" + ",".join(X86_64_AVX512),
+                                       "x86_64:avx512bf16:" + ",".join(X86_64_AVX512_BF16),
+                                       "arm_64:bf16:+bf16",
+                                   ] if lhs_rhs_type == "bf16" and acc_type == "f32" else []),
+    test_runner = "//tools/testing/e2e:iree-e2e-matmul-test",
+    test_type = "matmul",
+) for use_uk in [
+    False,
+    True,
+] for (lhs_rhs_type, acc_type) in (
+    [
+        ("i8", "i32"),
+        ("f32", "f32"),
+        ("f16", "f16"),
+        ("f16", "f32"),
+        ("bf16", "bf16"),
+        ("bf16", "f32"),
+    ]
+) for size in [
+    "small",
+    "large",
+]]
+
 ###########################################################################
 ##
 ## VMVX backend
diff --git a/tests/e2e/matmul/CMakeLists.txt b/tests/e2e/matmul/CMakeLists.txt
index df94499..f5fa735 100644
--- a/tests/e2e/matmul/CMakeLists.txt
+++ b/tests/e2e/matmul/CMakeLists.txt
@@ -800,6 +800,766 @@
 
 iree_generated_e2e_runner_test(
   NAME
+    e2e_matmul_cpu_experimental_dt_i8_i32_small
+  TEST_TYPE
+    matmul
+  GENERATOR
+    "generate_e2e_matmul_tests.py"
+  GENERATOR_ARGS
+    "--lhs_rhs_type=i8"
+    "--acc_type=i32"
+    "--shapes=small"
+  TEST_RUNNER
+    iree_tools_testing_e2e_iree-e2e-matmul-test
+  TARGET_BACKENDS
+    "llvm-cpu"
+  DRIVERS
+    "local-task"
+  COMPILER_FLAGS
+    "--iree-opt-data-tiling"
+    "--iree-global-opt-enable-early-materialization=false"
+    "--iree-llvmcpu-enable-ukernels=none"
+  LABELS
+
+  TARGET_CPU_FEATURES_VARIANTS
+    "default"
+    "arm_64:dotprod:+dotprod"
+    "arm_64:i8mm:+i8mm"
+    "x86_64:avx512vnni:+avx,+avx2,+fma,+f16c,+avx512f,+avx512vl,+avx512cd,+avx512bw,+avx512dq,+avx512vnni"
+)
+
+iree_generated_e2e_runner_test(
+  NAME
+    e2e_matmul_cpu_experimental_dt_i8_i32_large
+  TEST_TYPE
+    matmul
+  GENERATOR
+    "generate_e2e_matmul_tests.py"
+  GENERATOR_ARGS
+    "--lhs_rhs_type=i8"
+    "--acc_type=i32"
+    "--shapes=large"
+  TEST_RUNNER
+    iree_tools_testing_e2e_iree-e2e-matmul-test
+  TARGET_BACKENDS
+    "llvm-cpu"
+  DRIVERS
+    "local-task"
+  COMPILER_FLAGS
+    "--iree-opt-data-tiling"
+    "--iree-global-opt-enable-early-materialization=false"
+    "--iree-llvmcpu-enable-ukernels=none"
+  LABELS
+    "noasan"
+    "notsan"
+  TARGET_CPU_FEATURES_VARIANTS
+    "default"
+    "arm_64:dotprod:+dotprod"
+    "arm_64:i8mm:+i8mm"
+    "x86_64:avx512vnni:+avx,+avx2,+fma,+f16c,+avx512f,+avx512vl,+avx512cd,+avx512bw,+avx512dq,+avx512vnni"
+)
+
+iree_generated_e2e_runner_test(
+  NAME
+    e2e_matmul_cpu_experimental_dt_f32_f32_small
+  TEST_TYPE
+    matmul
+  GENERATOR
+    "generate_e2e_matmul_tests.py"
+  GENERATOR_ARGS
+    "--lhs_rhs_type=f32"
+    "--acc_type=f32"
+    "--shapes=small"
+  TEST_RUNNER
+    iree_tools_testing_e2e_iree-e2e-matmul-test
+  TARGET_BACKENDS
+    "llvm-cpu"
+  DRIVERS
+    "local-task"
+  COMPILER_FLAGS
+    "--iree-opt-data-tiling"
+    "--iree-global-opt-enable-early-materialization=false"
+    "--iree-llvmcpu-enable-ukernels=none"
+  LABELS
+
+  TARGET_CPU_FEATURES_VARIANTS
+    "default"
+    "x86_64:avx2:+avx,+avx2,+fma,+f16c"
+    "x86_64:avx512:+avx,+avx2,+fma,+f16c,+avx512f,+avx512vl,+avx512cd,+avx512bw,+avx512dq"
+)
+
+iree_generated_e2e_runner_test(
+  NAME
+    e2e_matmul_cpu_experimental_dt_f32_f32_large
+  TEST_TYPE
+    matmul
+  GENERATOR
+    "generate_e2e_matmul_tests.py"
+  GENERATOR_ARGS
+    "--lhs_rhs_type=f32"
+    "--acc_type=f32"
+    "--shapes=large"
+  TEST_RUNNER
+    iree_tools_testing_e2e_iree-e2e-matmul-test
+  TARGET_BACKENDS
+    "llvm-cpu"
+  DRIVERS
+    "local-task"
+  COMPILER_FLAGS
+    "--iree-opt-data-tiling"
+    "--iree-global-opt-enable-early-materialization=false"
+    "--iree-llvmcpu-enable-ukernels=none"
+  LABELS
+    "noasan"
+    "notsan"
+  TARGET_CPU_FEATURES_VARIANTS
+    "default"
+    "x86_64:avx2:+avx,+avx2,+fma,+f16c"
+    "x86_64:avx512:+avx,+avx2,+fma,+f16c,+avx512f,+avx512vl,+avx512cd,+avx512bw,+avx512dq"
+)
+
+iree_generated_e2e_runner_test(
+  NAME
+    e2e_matmul_cpu_experimental_dt_f16_f16_small
+  TEST_TYPE
+    matmul
+  GENERATOR
+    "generate_e2e_matmul_tests.py"
+  GENERATOR_ARGS
+    "--lhs_rhs_type=f16"
+    "--acc_type=f16"
+    "--shapes=small"
+  TEST_RUNNER
+    iree_tools_testing_e2e_iree-e2e-matmul-test
+  TARGET_BACKENDS
+    "llvm-cpu"
+  DRIVERS
+    "local-task"
+  COMPILER_FLAGS
+    "--iree-opt-data-tiling"
+    "--iree-global-opt-enable-early-materialization=false"
+    "--iree-llvmcpu-enable-ukernels=none"
+  LABELS
+    "noriscv"
+    "nowasm"
+  TARGET_CPU_FEATURES_VARIANTS
+    "default"
+    "x86_64:avx2:+avx,+avx2,+fma,+f16c"
+    "x86_64:avx512:+avx,+avx2,+fma,+f16c,+avx512f,+avx512vl,+avx512cd,+avx512bw,+avx512dq"
+    "arm_64:fullfp16:+fullfp16"
+)
+
+iree_generated_e2e_runner_test(
+  NAME
+    e2e_matmul_cpu_experimental_dt_f16_f16_large
+  TEST_TYPE
+    matmul
+  GENERATOR
+    "generate_e2e_matmul_tests.py"
+  GENERATOR_ARGS
+    "--lhs_rhs_type=f16"
+    "--acc_type=f16"
+    "--shapes=large"
+  TEST_RUNNER
+    iree_tools_testing_e2e_iree-e2e-matmul-test
+  TARGET_BACKENDS
+    "llvm-cpu"
+  DRIVERS
+    "local-task"
+  COMPILER_FLAGS
+    "--iree-opt-data-tiling"
+    "--iree-global-opt-enable-early-materialization=false"
+    "--iree-llvmcpu-enable-ukernels=none"
+  LABELS
+    "noasan"
+    "notsan"
+    "noriscv"
+    "nowasm"
+  TARGET_CPU_FEATURES_VARIANTS
+    "default"
+    "x86_64:avx2:+avx,+avx2,+fma,+f16c"
+    "x86_64:avx512:+avx,+avx2,+fma,+f16c,+avx512f,+avx512vl,+avx512cd,+avx512bw,+avx512dq"
+    "arm_64:fullfp16:+fullfp16"
+)
+
+iree_generated_e2e_runner_test(
+  NAME
+    e2e_matmul_cpu_experimental_dt_f16_f32_small
+  TEST_TYPE
+    matmul
+  GENERATOR
+    "generate_e2e_matmul_tests.py"
+  GENERATOR_ARGS
+    "--lhs_rhs_type=f16"
+    "--acc_type=f32"
+    "--shapes=small"
+  TEST_RUNNER
+    iree_tools_testing_e2e_iree-e2e-matmul-test
+  TARGET_BACKENDS
+    "llvm-cpu"
+  DRIVERS
+    "local-task"
+  COMPILER_FLAGS
+    "--iree-opt-data-tiling"
+    "--iree-global-opt-enable-early-materialization=false"
+    "--iree-llvmcpu-enable-ukernels=none"
+  LABELS
+    "noriscv"
+    "nowasm"
+  TARGET_CPU_FEATURES_VARIANTS
+    "default"
+    "x86_64:avx2:+avx,+avx2,+fma,+f16c"
+    "x86_64:avx512:+avx,+avx2,+fma,+f16c,+avx512f,+avx512vl,+avx512cd,+avx512bw,+avx512dq"
+    "arm_64:fp16fml:+fp16fml"
+)
+
+iree_generated_e2e_runner_test(
+  NAME
+    e2e_matmul_cpu_experimental_dt_f16_f32_large
+  TEST_TYPE
+    matmul
+  GENERATOR
+    "generate_e2e_matmul_tests.py"
+  GENERATOR_ARGS
+    "--lhs_rhs_type=f16"
+    "--acc_type=f32"
+    "--shapes=large"
+  TEST_RUNNER
+    iree_tools_testing_e2e_iree-e2e-matmul-test
+  TARGET_BACKENDS
+    "llvm-cpu"
+  DRIVERS
+    "local-task"
+  COMPILER_FLAGS
+    "--iree-opt-data-tiling"
+    "--iree-global-opt-enable-early-materialization=false"
+    "--iree-llvmcpu-enable-ukernels=none"
+  LABELS
+    "noasan"
+    "notsan"
+    "noriscv"
+    "nowasm"
+  TARGET_CPU_FEATURES_VARIANTS
+    "default"
+    "x86_64:avx2:+avx,+avx2,+fma,+f16c"
+    "x86_64:avx512:+avx,+avx2,+fma,+f16c,+avx512f,+avx512vl,+avx512cd,+avx512bw,+avx512dq"
+    "arm_64:fp16fml:+fp16fml"
+)
+
+iree_generated_e2e_runner_test(
+  NAME
+    e2e_matmul_cpu_experimental_dt_bf16_bf16_small
+  TEST_TYPE
+    matmul
+  GENERATOR
+    "generate_e2e_matmul_tests.py"
+  GENERATOR_ARGS
+    "--lhs_rhs_type=bf16"
+    "--acc_type=bf16"
+    "--shapes=small"
+  TEST_RUNNER
+    iree_tools_testing_e2e_iree-e2e-matmul-test
+  TARGET_BACKENDS
+    "llvm-cpu"
+  DRIVERS
+    "local-task"
+  COMPILER_FLAGS
+    "--iree-opt-data-tiling"
+    "--iree-global-opt-enable-early-materialization=false"
+    "--iree-llvmcpu-enable-ukernels=none"
+  LABELS
+    "noriscv"
+    "nowasm"
+  TARGET_CPU_FEATURES_VARIANTS
+    "default"
+    "x86_64:avx2:+avx,+avx2,+fma,+f16c"
+    "x86_64:avx512:+avx,+avx2,+fma,+f16c,+avx512f,+avx512vl,+avx512cd,+avx512bw,+avx512dq"
+    "x86_64:avx512bf16:+avx,+avx2,+fma,+f16c,+avx512f,+avx512vl,+avx512cd,+avx512bw,+avx512dq,+avx512bf16"
+    "arm_64:bf16:+bf16"
+)
+
+iree_generated_e2e_runner_test(
+  NAME
+    e2e_matmul_cpu_experimental_dt_bf16_bf16_large
+  TEST_TYPE
+    matmul
+  GENERATOR
+    "generate_e2e_matmul_tests.py"
+  GENERATOR_ARGS
+    "--lhs_rhs_type=bf16"
+    "--acc_type=bf16"
+    "--shapes=large"
+  TEST_RUNNER
+    iree_tools_testing_e2e_iree-e2e-matmul-test
+  TARGET_BACKENDS
+    "llvm-cpu"
+  DRIVERS
+    "local-task"
+  COMPILER_FLAGS
+    "--iree-opt-data-tiling"
+    "--iree-global-opt-enable-early-materialization=false"
+    "--iree-llvmcpu-enable-ukernels=none"
+  LABELS
+    "noasan"
+    "notsan"
+    "noriscv"
+    "nowasm"
+  TARGET_CPU_FEATURES_VARIANTS
+    "default"
+    "x86_64:avx2:+avx,+avx2,+fma,+f16c"
+    "x86_64:avx512:+avx,+avx2,+fma,+f16c,+avx512f,+avx512vl,+avx512cd,+avx512bw,+avx512dq"
+    "x86_64:avx512bf16:+avx,+avx2,+fma,+f16c,+avx512f,+avx512vl,+avx512cd,+avx512bw,+avx512dq,+avx512bf16"
+    "arm_64:bf16:+bf16"
+)
+
+iree_generated_e2e_runner_test(
+  NAME
+    e2e_matmul_cpu_experimental_dt_bf16_f32_small
+  TEST_TYPE
+    matmul
+  GENERATOR
+    "generate_e2e_matmul_tests.py"
+  GENERATOR_ARGS
+    "--lhs_rhs_type=bf16"
+    "--acc_type=f32"
+    "--shapes=small"
+  TEST_RUNNER
+    iree_tools_testing_e2e_iree-e2e-matmul-test
+  TARGET_BACKENDS
+    "llvm-cpu"
+  DRIVERS
+    "local-task"
+  COMPILER_FLAGS
+    "--iree-opt-data-tiling"
+    "--iree-global-opt-enable-early-materialization=false"
+    "--iree-llvmcpu-enable-ukernels=none"
+  LABELS
+    "noriscv"
+    "nowasm"
+  TARGET_CPU_FEATURES_VARIANTS
+    "default"
+    "x86_64:avx2:+avx,+avx2,+fma,+f16c"
+    "x86_64:avx512:+avx,+avx2,+fma,+f16c,+avx512f,+avx512vl,+avx512cd,+avx512bw,+avx512dq"
+    "x86_64:avx512bf16:+avx,+avx2,+fma,+f16c,+avx512f,+avx512vl,+avx512cd,+avx512bw,+avx512dq,+avx512bf16"
+    "arm_64:bf16:+bf16"
+)
+
+iree_generated_e2e_runner_test(
+  NAME
+    e2e_matmul_cpu_experimental_dt_bf16_f32_large
+  TEST_TYPE
+    matmul
+  GENERATOR
+    "generate_e2e_matmul_tests.py"
+  GENERATOR_ARGS
+    "--lhs_rhs_type=bf16"
+    "--acc_type=f32"
+    "--shapes=large"
+  TEST_RUNNER
+    iree_tools_testing_e2e_iree-e2e-matmul-test
+  TARGET_BACKENDS
+    "llvm-cpu"
+  DRIVERS
+    "local-task"
+  COMPILER_FLAGS
+    "--iree-opt-data-tiling"
+    "--iree-global-opt-enable-early-materialization=false"
+    "--iree-llvmcpu-enable-ukernels=none"
+  LABELS
+    "noasan"
+    "notsan"
+    "noriscv"
+    "nowasm"
+  TARGET_CPU_FEATURES_VARIANTS
+    "default"
+    "x86_64:avx2:+avx,+avx2,+fma,+f16c"
+    "x86_64:avx512:+avx,+avx2,+fma,+f16c,+avx512f,+avx512vl,+avx512cd,+avx512bw,+avx512dq"
+    "x86_64:avx512bf16:+avx,+avx2,+fma,+f16c,+avx512f,+avx512vl,+avx512cd,+avx512bw,+avx512dq,+avx512bf16"
+    "arm_64:bf16:+bf16"
+)
+
+iree_generated_e2e_runner_test(
+  NAME
+    e2e_matmul_cpu_experimental_dt_uk_i8_i32_small
+  TEST_TYPE
+    matmul
+  GENERATOR
+    "generate_e2e_matmul_tests.py"
+  GENERATOR_ARGS
+    "--lhs_rhs_type=i8"
+    "--acc_type=i32"
+    "--shapes=small"
+  TEST_RUNNER
+    iree_tools_testing_e2e_iree-e2e-matmul-test
+  TARGET_BACKENDS
+    "llvm-cpu"
+  DRIVERS
+    "local-task"
+  COMPILER_FLAGS
+    "--iree-opt-data-tiling"
+    "--iree-global-opt-enable-early-materialization=false"
+    "--iree-llvmcpu-enable-ukernels=all"
+  LABELS
+
+  TARGET_CPU_FEATURES_VARIANTS
+    "default"
+    "arm_64:dotprod:+dotprod"
+    "arm_64:i8mm:+i8mm"
+    "x86_64:avx512vnni:+avx,+avx2,+fma,+f16c,+avx512f,+avx512vl,+avx512cd,+avx512bw,+avx512dq,+avx512vnni"
+)
+
+iree_generated_e2e_runner_test(
+  NAME
+    e2e_matmul_cpu_experimental_dt_uk_i8_i32_large
+  TEST_TYPE
+    matmul
+  GENERATOR
+    "generate_e2e_matmul_tests.py"
+  GENERATOR_ARGS
+    "--lhs_rhs_type=i8"
+    "--acc_type=i32"
+    "--shapes=large"
+  TEST_RUNNER
+    iree_tools_testing_e2e_iree-e2e-matmul-test
+  TARGET_BACKENDS
+    "llvm-cpu"
+  DRIVERS
+    "local-task"
+  COMPILER_FLAGS
+    "--iree-opt-data-tiling"
+    "--iree-global-opt-enable-early-materialization=false"
+    "--iree-llvmcpu-enable-ukernels=all"
+  LABELS
+    "noasan"
+    "notsan"
+  TARGET_CPU_FEATURES_VARIANTS
+    "default"
+    "arm_64:dotprod:+dotprod"
+    "arm_64:i8mm:+i8mm"
+    "x86_64:avx512vnni:+avx,+avx2,+fma,+f16c,+avx512f,+avx512vl,+avx512cd,+avx512bw,+avx512dq,+avx512vnni"
+)
+
+iree_generated_e2e_runner_test(
+  NAME
+    e2e_matmul_cpu_experimental_dt_uk_f32_f32_small
+  TEST_TYPE
+    matmul
+  GENERATOR
+    "generate_e2e_matmul_tests.py"
+  GENERATOR_ARGS
+    "--lhs_rhs_type=f32"
+    "--acc_type=f32"
+    "--shapes=small"
+  TEST_RUNNER
+    iree_tools_testing_e2e_iree-e2e-matmul-test
+  TARGET_BACKENDS
+    "llvm-cpu"
+  DRIVERS
+    "local-task"
+  COMPILER_FLAGS
+    "--iree-opt-data-tiling"
+    "--iree-global-opt-enable-early-materialization=false"
+    "--iree-llvmcpu-enable-ukernels=all"
+  LABELS
+
+  TARGET_CPU_FEATURES_VARIANTS
+    "default"
+    "x86_64:avx2:+avx,+avx2,+fma,+f16c"
+    "x86_64:avx512:+avx,+avx2,+fma,+f16c,+avx512f,+avx512vl,+avx512cd,+avx512bw,+avx512dq"
+)
+
+iree_generated_e2e_runner_test(
+  NAME
+    e2e_matmul_cpu_experimental_dt_uk_f32_f32_large
+  TEST_TYPE
+    matmul
+  GENERATOR
+    "generate_e2e_matmul_tests.py"
+  GENERATOR_ARGS
+    "--lhs_rhs_type=f32"
+    "--acc_type=f32"
+    "--shapes=large"
+  TEST_RUNNER
+    iree_tools_testing_e2e_iree-e2e-matmul-test
+  TARGET_BACKENDS
+    "llvm-cpu"
+  DRIVERS
+    "local-task"
+  COMPILER_FLAGS
+    "--iree-opt-data-tiling"
+    "--iree-global-opt-enable-early-materialization=false"
+    "--iree-llvmcpu-enable-ukernels=all"
+  LABELS
+    "noasan"
+    "notsan"
+  TARGET_CPU_FEATURES_VARIANTS
+    "default"
+    "x86_64:avx2:+avx,+avx2,+fma,+f16c"
+    "x86_64:avx512:+avx,+avx2,+fma,+f16c,+avx512f,+avx512vl,+avx512cd,+avx512bw,+avx512dq"
+)
+
+iree_generated_e2e_runner_test(
+  NAME
+    e2e_matmul_cpu_experimental_dt_uk_f16_f16_small
+  TEST_TYPE
+    matmul
+  GENERATOR
+    "generate_e2e_matmul_tests.py"
+  GENERATOR_ARGS
+    "--lhs_rhs_type=f16"
+    "--acc_type=f16"
+    "--shapes=small"
+  TEST_RUNNER
+    iree_tools_testing_e2e_iree-e2e-matmul-test
+  TARGET_BACKENDS
+    "llvm-cpu"
+  DRIVERS
+    "local-task"
+  COMPILER_FLAGS
+    "--iree-opt-data-tiling"
+    "--iree-global-opt-enable-early-materialization=false"
+    "--iree-llvmcpu-enable-ukernels=all"
+  LABELS
+    "noriscv"
+    "nowasm"
+  TARGET_CPU_FEATURES_VARIANTS
+    "default"
+    "x86_64:avx2:+avx,+avx2,+fma,+f16c"
+    "x86_64:avx512:+avx,+avx2,+fma,+f16c,+avx512f,+avx512vl,+avx512cd,+avx512bw,+avx512dq"
+    "arm_64:fullfp16:+fullfp16"
+)
+
+iree_generated_e2e_runner_test(
+  NAME
+    e2e_matmul_cpu_experimental_dt_uk_f16_f16_large
+  TEST_TYPE
+    matmul
+  GENERATOR
+    "generate_e2e_matmul_tests.py"
+  GENERATOR_ARGS
+    "--lhs_rhs_type=f16"
+    "--acc_type=f16"
+    "--shapes=large"
+  TEST_RUNNER
+    iree_tools_testing_e2e_iree-e2e-matmul-test
+  TARGET_BACKENDS
+    "llvm-cpu"
+  DRIVERS
+    "local-task"
+  COMPILER_FLAGS
+    "--iree-opt-data-tiling"
+    "--iree-global-opt-enable-early-materialization=false"
+    "--iree-llvmcpu-enable-ukernels=all"
+  LABELS
+    "noasan"
+    "notsan"
+    "noriscv"
+    "nowasm"
+  TARGET_CPU_FEATURES_VARIANTS
+    "default"
+    "x86_64:avx2:+avx,+avx2,+fma,+f16c"
+    "x86_64:avx512:+avx,+avx2,+fma,+f16c,+avx512f,+avx512vl,+avx512cd,+avx512bw,+avx512dq"
+    "arm_64:fullfp16:+fullfp16"
+)
+
+iree_generated_e2e_runner_test(
+  NAME
+    e2e_matmul_cpu_experimental_dt_uk_f16_f32_small
+  TEST_TYPE
+    matmul
+  GENERATOR
+    "generate_e2e_matmul_tests.py"
+  GENERATOR_ARGS
+    "--lhs_rhs_type=f16"
+    "--acc_type=f32"
+    "--shapes=small"
+  TEST_RUNNER
+    iree_tools_testing_e2e_iree-e2e-matmul-test
+  TARGET_BACKENDS
+    "llvm-cpu"
+  DRIVERS
+    "local-task"
+  COMPILER_FLAGS
+    "--iree-opt-data-tiling"
+    "--iree-global-opt-enable-early-materialization=false"
+    "--iree-llvmcpu-enable-ukernels=all"
+  LABELS
+    "noriscv"
+    "nowasm"
+  TARGET_CPU_FEATURES_VARIANTS
+    "default"
+    "x86_64:avx2:+avx,+avx2,+fma,+f16c"
+    "x86_64:avx512:+avx,+avx2,+fma,+f16c,+avx512f,+avx512vl,+avx512cd,+avx512bw,+avx512dq"
+    "arm_64:fp16fml:+fp16fml"
+)
+
+iree_generated_e2e_runner_test(
+  NAME
+    e2e_matmul_cpu_experimental_dt_uk_f16_f32_large
+  TEST_TYPE
+    matmul
+  GENERATOR
+    "generate_e2e_matmul_tests.py"
+  GENERATOR_ARGS
+    "--lhs_rhs_type=f16"
+    "--acc_type=f32"
+    "--shapes=large"
+  TEST_RUNNER
+    iree_tools_testing_e2e_iree-e2e-matmul-test
+  TARGET_BACKENDS
+    "llvm-cpu"
+  DRIVERS
+    "local-task"
+  COMPILER_FLAGS
+    "--iree-opt-data-tiling"
+    "--iree-global-opt-enable-early-materialization=false"
+    "--iree-llvmcpu-enable-ukernels=all"
+  LABELS
+    "noasan"
+    "notsan"
+    "noriscv"
+    "nowasm"
+  TARGET_CPU_FEATURES_VARIANTS
+    "default"
+    "x86_64:avx2:+avx,+avx2,+fma,+f16c"
+    "x86_64:avx512:+avx,+avx2,+fma,+f16c,+avx512f,+avx512vl,+avx512cd,+avx512bw,+avx512dq"
+    "arm_64:fp16fml:+fp16fml"
+)
+
+iree_generated_e2e_runner_test(
+  NAME
+    e2e_matmul_cpu_experimental_dt_uk_bf16_bf16_small
+  TEST_TYPE
+    matmul
+  GENERATOR
+    "generate_e2e_matmul_tests.py"
+  GENERATOR_ARGS
+    "--lhs_rhs_type=bf16"
+    "--acc_type=bf16"
+    "--shapes=small"
+  TEST_RUNNER
+    iree_tools_testing_e2e_iree-e2e-matmul-test
+  TARGET_BACKENDS
+    "llvm-cpu"
+  DRIVERS
+    "local-task"
+  COMPILER_FLAGS
+    "--iree-opt-data-tiling"
+    "--iree-global-opt-enable-early-materialization=false"
+    "--iree-llvmcpu-enable-ukernels=all"
+  LABELS
+    "noriscv"
+    "nowasm"
+  TARGET_CPU_FEATURES_VARIANTS
+    "default"
+    "x86_64:avx2:+avx,+avx2,+fma,+f16c"
+    "x86_64:avx512:+avx,+avx2,+fma,+f16c,+avx512f,+avx512vl,+avx512cd,+avx512bw,+avx512dq"
+    "x86_64:avx512bf16:+avx,+avx2,+fma,+f16c,+avx512f,+avx512vl,+avx512cd,+avx512bw,+avx512dq,+avx512bf16"
+    "arm_64:bf16:+bf16"
+)
+
+iree_generated_e2e_runner_test(
+  NAME
+    e2e_matmul_cpu_experimental_dt_uk_bf16_bf16_large
+  TEST_TYPE
+    matmul
+  GENERATOR
+    "generate_e2e_matmul_tests.py"
+  GENERATOR_ARGS
+    "--lhs_rhs_type=bf16"
+    "--acc_type=bf16"
+    "--shapes=large"
+  TEST_RUNNER
+    iree_tools_testing_e2e_iree-e2e-matmul-test
+  TARGET_BACKENDS
+    "llvm-cpu"
+  DRIVERS
+    "local-task"
+  COMPILER_FLAGS
+    "--iree-opt-data-tiling"
+    "--iree-global-opt-enable-early-materialization=false"
+    "--iree-llvmcpu-enable-ukernels=all"
+  LABELS
+    "noasan"
+    "notsan"
+    "noriscv"
+    "nowasm"
+  TARGET_CPU_FEATURES_VARIANTS
+    "default"
+    "x86_64:avx2:+avx,+avx2,+fma,+f16c"
+    "x86_64:avx512:+avx,+avx2,+fma,+f16c,+avx512f,+avx512vl,+avx512cd,+avx512bw,+avx512dq"
+    "x86_64:avx512bf16:+avx,+avx2,+fma,+f16c,+avx512f,+avx512vl,+avx512cd,+avx512bw,+avx512dq,+avx512bf16"
+    "arm_64:bf16:+bf16"
+)
+
+iree_generated_e2e_runner_test(
+  NAME
+    e2e_matmul_cpu_experimental_dt_uk_bf16_f32_small
+  TEST_TYPE
+    matmul
+  GENERATOR
+    "generate_e2e_matmul_tests.py"
+  GENERATOR_ARGS
+    "--lhs_rhs_type=bf16"
+    "--acc_type=f32"
+    "--shapes=small"
+  TEST_RUNNER
+    iree_tools_testing_e2e_iree-e2e-matmul-test
+  TARGET_BACKENDS
+    "llvm-cpu"
+  DRIVERS
+    "local-task"
+  COMPILER_FLAGS
+    "--iree-opt-data-tiling"
+    "--iree-global-opt-enable-early-materialization=false"
+    "--iree-llvmcpu-enable-ukernels=all"
+  LABELS
+    "noriscv"
+    "nowasm"
+  TARGET_CPU_FEATURES_VARIANTS
+    "default"
+    "x86_64:avx2:+avx,+avx2,+fma,+f16c"
+    "x86_64:avx512:+avx,+avx2,+fma,+f16c,+avx512f,+avx512vl,+avx512cd,+avx512bw,+avx512dq"
+    "x86_64:avx512bf16:+avx,+avx2,+fma,+f16c,+avx512f,+avx512vl,+avx512cd,+avx512bw,+avx512dq,+avx512bf16"
+    "arm_64:bf16:+bf16"
+)
+
+iree_generated_e2e_runner_test(
+  NAME
+    e2e_matmul_cpu_experimental_dt_uk_bf16_f32_large
+  TEST_TYPE
+    matmul
+  GENERATOR
+    "generate_e2e_matmul_tests.py"
+  GENERATOR_ARGS
+    "--lhs_rhs_type=bf16"
+    "--acc_type=f32"
+    "--shapes=large"
+  TEST_RUNNER
+    iree_tools_testing_e2e_iree-e2e-matmul-test
+  TARGET_BACKENDS
+    "llvm-cpu"
+  DRIVERS
+    "local-task"
+  COMPILER_FLAGS
+    "--iree-opt-data-tiling"
+    "--iree-global-opt-enable-early-materialization=false"
+    "--iree-llvmcpu-enable-ukernels=all"
+  LABELS
+    "noasan"
+    "notsan"
+    "noriscv"
+    "nowasm"
+  TARGET_CPU_FEATURES_VARIANTS
+    "default"
+    "x86_64:avx2:+avx,+avx2,+fma,+f16c"
+    "x86_64:avx512:+avx,+avx2,+fma,+f16c,+avx512f,+avx512vl,+avx512cd,+avx512bw,+avx512dq"
+    "x86_64:avx512bf16:+avx,+avx2,+fma,+f16c,+avx512f,+avx512vl,+avx512cd,+avx512bw,+avx512dq,+avx512bf16"
+    "arm_64:bf16:+bf16"
+)
+
+iree_generated_e2e_runner_test(
+  NAME
     e2e_matmul_vmvx_dt_uk_i8_small
   TEST_TYPE
     matmul