[CPU] Retire CPUDoubleTilingPadExpert pipeline. (#15931)

- Remove the expert and related methods.
- Rename `setMatmulNoPadRootConfig` to `setMatmulRootConfig`.
- Replace preset `CPUDoubleTilingPadExpert` compilation_info with
`CPUDoubleTilingExpert`.
- Delete the `e2e_matmul_nondt_f32_small_no_padding` test suite. It was
added when we were enabling the pad expert.
- Inline `getNoPadTilingExpert` because it is only used by
`setMatmulRootConfig`.
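
For downstream users carrying preset configurations, the migration is mechanical: any `compilation_info` that names the retired pipeline switches to `CPUDoubleTilingExpert`. A sketch adapted from the test updates below (tile sizes are illustrative, taken from one of the updated tests):

```mlir
// Before: preset config pinned to the retired pad expert.
#compilation = #iree_codegen.compilation_info<
    lowering_config = <tile_sizes = [[64, 64, 0], [32, 32, 0], [0, 0, 32], [0, 0, 0]]>,
    translation_info = <CPUDoubleTilingPadExpert>>

// After: the same lowering_config, now driven by CPUDoubleTilingExpert.
#compilation = #iree_codegen.compilation_info<
    lowering_config = <tile_sizes = [[64, 64, 0], [32, 32, 0], [0, 0, 32], [0, 0, 0]]>,
    translation_info = <CPUDoubleTilingExpert>>
```
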
diff --git a/build_tools/cmake/test_riscv.sh b/build_tools/cmake/test_riscv.sh
index 53fdec4..7d6375e 100755
--- a/build_tools/cmake/test_riscv.sh
+++ b/build_tools/cmake/test_riscv.sh
@@ -97,7 +97,6 @@
   "iree/tests/e2e/stablehlo_ops/check_llvm-cpu_local-task_dot.mlir"
   "iree/tests/e2e/matmul/e2e_matmul_direct_i8_small_llvm-cpu_local-task"
   "iree/tests/e2e/matmul/e2e_matmul_direct_f32_small_llvm-cpu_local-task"
-  "iree/tests/e2e/matmul/e2e_matmul_direct_f32_small_no_padding_llvm-cpu_local-task"
   "iree/tests/e2e/regression/check_regression_llvm-cpu_strided_slice.mlir"
 )
 
diff --git a/compiler/src/iree/compiler/Codegen/Common/test/materialize_user_configs.mlir b/compiler/src/iree/compiler/Codegen/Common/test/materialize_user_configs.mlir
index 022db7f..c27b743 100644
--- a/compiler/src/iree/compiler/Codegen/Common/test/materialize_user_configs.mlir
+++ b/compiler/src/iree/compiler/Codegen/Common/test/materialize_user_configs.mlir
@@ -2,7 +2,7 @@
 
 #compilation = #iree_codegen.compilation_info<
     lowering_config = <tile_sizes = [[64, 64, 0], [32, 32, 0], [0, 0, 32], [0, 0, 0]]>,
-    translation_info  = <CPUDoubleTilingPadExpert>>
+    translation_info  = <CPUDoubleTilingExpert>>
 #pipeline_layout = #hal.pipeline.layout<push_constants = 0, sets = [
   #hal.descriptor_set.layout<0, bindings = [
     #hal.descriptor_set.binding<0, storage_buffer>,
@@ -39,7 +39,7 @@
   }
 }
 //  CHECK-DAG: #[[CONFIG:.+]] = #iree_codegen.lowering_config<tile_sizes = {{\[}}[64, 64, 0], [32, 32, 0], [0, 0, 32], [0, 0, 0]]>
-//  CHECK-DAG: #[[TRANSLATION:.+]] = #iree_codegen.translation_info<CPUDoubleTilingPadExpert>
+//  CHECK-DAG: #[[TRANSLATION:.+]] = #iree_codegen.translation_info<CPUDoubleTilingExpert>
 //      CHECK: hal.executable.export
 // CHECK-SAME:     translation_info = #[[TRANSLATION]]
 //      CHECK: func.func @preset_config
diff --git a/compiler/src/iree/compiler/Codegen/Dialect/IREECodegenAttrs.td b/compiler/src/iree/compiler/Codegen/Dialect/IREECodegenAttrs.td
index 40a722d..2666814 100644
--- a/compiler/src/iree/compiler/Codegen/Dialect/IREECodegenAttrs.td
+++ b/compiler/src/iree/compiler/Codegen/Dialect/IREECodegenAttrs.td
@@ -15,18 +15,16 @@
     : I32EnumAttrCase<"CPUDefault", 0>;
 def CPU_DoubleTilingExpert
     : I32EnumAttrCase<"CPUDoubleTilingExpert", 1>;
-def CPU_DoubleTilingPadExpert
-    : I32EnumAttrCase<"CPUDoubleTilingPadExpert", 2>;
 def CPU_DoubleTilingPeelingExpert
-    : I32EnumAttrCase<"CPUDoubleTilingPeelingExpert", 3>;
+    : I32EnumAttrCase<"CPUDoubleTilingPeelingExpert", 2>;
 def CPU_ConvTileAndDecomposeExpert
-    : I32EnumAttrCase<"CPUConvTileAndDecomposeExpert", 4>;
+    : I32EnumAttrCase<"CPUConvTileAndDecomposeExpert", 3>;
 def CPU_Mmt4dTilingExpert
-    : I32EnumAttrCase<"Mmt4dTilingExpert", 5>;
+    : I32EnumAttrCase<"Mmt4dTilingExpert", 4>;
 def CPU_BufferOpsTileAndVectorize
-    : I32EnumAttrCase<"CPUBufferOpsTileAndVectorize", 6>;
+    : I32EnumAttrCase<"CPUBufferOpsTileAndVectorize", 5>;
 def CPU_DataTiling
-    : I32EnumAttrCase<"CPUDataTiling", 7>;
+    : I32EnumAttrCase<"CPUDataTiling", 6>;
 
 def LLVMGPU_Default
     : I32EnumAttrCase<"LLVMGPUDefault", 100>;
@@ -76,7 +74,7 @@
   "DispatchLoweringPassPipeline",
   "identifier for pass pipeline use to lower dispatch region", [
     // CPU CodeGen pipelines
-    CPU_Default, CPU_DoubleTilingExpert, CPU_DoubleTilingPadExpert,
+    CPU_Default, CPU_DoubleTilingExpert,
     CPU_DoubleTilingPeelingExpert, CPU_ConvTileAndDecomposeExpert,
     CPU_Mmt4dTilingExpert, CPU_BufferOpsTileAndVectorize,
     CPU_DataTiling,
diff --git a/compiler/src/iree/compiler/Codegen/LLVMCPU/KernelDispatch.cpp b/compiler/src/iree/compiler/Codegen/LLVMCPU/KernelDispatch.cpp
index 577de33..dfece91 100644
--- a/compiler/src/iree/compiler/Codegen/LLVMCPU/KernelDispatch.cpp
+++ b/compiler/src/iree/compiler/Codegen/LLVMCPU/KernelDispatch.cpp
@@ -76,14 +76,6 @@
                    "in data-tiled matmuls (mmt4d)."),
     llvm::cl::init(64 * 1024));
 
-// TODO(hanchung): Remove the flag. This is the flag for fastly falling back to
-// the previous snapshot.
-
-static llvm::cl::opt<bool>
-    enableVectorPadding("iree-codegen-enable-vector-padding",
-                        llvm::cl::desc("Enable padding for vectorization"),
-                        llvm::cl::init(true));
-
 static llvm::cl::opt<bool>
     enableVectorPeeling("iree-codegen-enable-vector-peeling",
                         llvm::cl::desc("Enable peeling for vectorization"),
@@ -100,9 +92,6 @@
 // Encodes the pre-processing strategy to be applied on a Linalg operation
 // before vectorization.
 enum class VectorPreProcStrategy {
-  // Pad vector dimensions of tensors so that they are multiple of the vector
-  // length.
-  Padding,
   // Peel iterations from the vector dimensions so that they become multiple of
   // the vector length.
   Peeling,
@@ -118,9 +107,6 @@
 static llvm::raw_ostream &operator<<(llvm::raw_ostream &os,
                                      const VectorPreProcStrategy &strategy) {
   switch (strategy) {
-  case VectorPreProcStrategy::Padding:
-    os << "Padding";
-    break;
   case VectorPreProcStrategy::Peeling:
     os << "Peeling";
     break;
@@ -183,7 +169,7 @@
                       [](int64_t size) { return ShapedType::isDynamic(size); });
 }
 
-/// Returns the vectorization pre-processing strategy (padding, peeling) for the
+/// Returns the vectorization pre-processing strategy (peeling, masking) for the
 /// given LinalgOp, depending on the op traits and the target architecture.
 static VectorPreProcStrategy
 getVectorPreProcStrategy(linalg::LinalgOp linalgOp) {
@@ -597,9 +583,9 @@
 /// hints != 1, it will try to find the tile sizes which are multipliers of the
 /// hints.
 ///
-/// TODO(hanchung): Remove `allowIncompleteTile` option after codegen can handle
-/// padding/peeling for all the kernels. Allowing incomplete tile is critical
-/// for odd shapes (e.g., some dim sizes could be prime number).
+/// TODO(hanchung): Remove `allowIncompleteTile` option after codegen can
+/// vectorize all the ops. Allowing incomplete tile is critical for odd shapes
+/// (e.g., some dim sizes could be prime number).
 static SmallVector<int64_t> getDefaultDistributedLevelTileSizes(
     ArrayRef<unsigned> partitionableLoops, ArrayRef<int64_t> lbs,
     ArrayRef<int64_t> ubs, ArrayRef<int64_t> minTileSizes,
@@ -791,7 +777,7 @@
     parallelTileSizes[index] = std::min(parallelTileSizes[index], size);
   }
 
-  // TODO(hanchung): Make logic more heuristic. Padding hurts performance a lot
+  // TODO(hanchung): Make logic more heuristic. Peeling hurts performance a lot
   // if the dim size is small (e.g., K=24).
   int64_t numTilingDims = vecTileSizes.size();
   SmallVector<int64_t> reductionTileSizes(numTilingDims - 1, 0);
@@ -817,19 +803,12 @@
       DispatchLoweringPassPipeline::CPUDoubleTilingPeelingExpert);
 }
 
-static DispatchLoweringPassPipeline
-getNoPadTilingExpert(VectorPreProcStrategy strategy) {
-  if (strategy == VectorPreProcStrategy::Peeling) {
-    return DispatchLoweringPassPipeline::CPUDoubleTilingPeelingExpert;
-  }
-  return DispatchLoweringPassPipeline::CPUDoubleTilingExpert;
-}
-
-static LogicalResult setMatmulNoPadRootConfig(
-    func::FuncOp entryPointFn, linalg::ContractionOpInterface op,
-    const TileSizesListTypeRef inputTileSizes,
-    const ScalableTileFlagsListTypeRef inputScalableTileFlags, int vectorSize,
-    VectorPreProcStrategy vecPreProcStrategy) {
+static LogicalResult
+setMatmulRootConfig(func::FuncOp entryPointFn,
+                    linalg::ContractionOpInterface op,
+                    const TileSizesListTypeRef inputTileSizes,
+                    const ScalableTileFlagsListTypeRef inputScalableTileFlags,
+                    int vectorSize, VectorPreProcStrategy vecPreProcStrategy) {
   auto linalgOp = cast<linalg::LinalgOp>(op.getOperation());
   SmallVector<int64_t> shape = linalgOp.getStaticLoopRanges();
 
@@ -899,15 +878,17 @@
   // No scalable inner parallel dims.
   newScalableTileFlags.emplace_back(numTilingDims, false);
 
-  LLVM_DEBUG(
-      KD_DBGS() << "Final tile sizes for non-padding contraction: "
-                << newTileSizes << "\n"
-                << "Final tile scalable flags for no-padding contraction: "
-                << newScalableTileFlags << "\n");
+  LLVM_DEBUG(KD_DBGS() << "Final tile sizes for contraction: " << newTileSizes
+                       << "\n"
+                       << "Final tile scalable flags for contraction: "
+                       << newScalableTileFlags << "\n");
 
-  return setOpConfigAndEntryPointFnTranslation(
-      entryPointFn, op, newTileSizes, newScalableTileFlags,
-      getNoPadTilingExpert(vecPreProcStrategy));
+  auto pipeline = DispatchLoweringPassPipeline::CPUDoubleTilingExpert;
+  if (vecPreProcStrategy == VectorPreProcStrategy::Peeling) {
+    pipeline = DispatchLoweringPassPipeline::CPUDoubleTilingPeelingExpert;
+  }
+  return setOpConfigAndEntryPointFnTranslation(entryPointFn, op, newTileSizes,
+                                               newScalableTileFlags, pipeline);
 }
 
 /// Returns default hard-coded vector sizes for a given target. No smartness
@@ -1146,9 +1127,6 @@
     minTileSizes.push_back(minTileSize);
   }
 
-  // There are hard-coded configurations in DoubleTilingPadExpert, so it only
-  // works for linalg.matmul cases. We can relax it once we have better
-  // scheduling, e.g., transform dialect.
   SmallVector<int64_t> distTileSizes;
   auto vecPreProcStrategy = getVectorPreProcStrategy(linalgOp);
   bool usePeelingPipeline =
@@ -1203,9 +1181,8 @@
   TileSizesListType tileSizes = {distTileSizes, vecTileSizes};
   ScalableTileFlagsListType scalableTileFlags = {distScalableTileFlags,
                                                  vecScalableFlags};
-  return setMatmulNoPadRootConfig(entryPointFn, contractionOp, tileSizes,
-                                  scalableTileFlags, vectorSize,
-                                  vecPreProcStrategy);
+  return setMatmulRootConfig(entryPointFn, contractionOp, tileSizes,
+                             scalableTileFlags, vectorSize, vecPreProcStrategy);
 }
 
 static TileSizesListType getMmt4dTileSizes(linalg::LinalgOp op) {
diff --git a/compiler/src/iree/compiler/Codegen/LLVMCPU/LLVMCPULowerExecutableTarget.cpp b/compiler/src/iree/compiler/Codegen/LLVMCPU/LLVMCPULowerExecutableTarget.cpp
index 47b4433..0f4e41e 100644
--- a/compiler/src/iree/compiler/Codegen/LLVMCPU/LLVMCPULowerExecutableTarget.cpp
+++ b/compiler/src/iree/compiler/Codegen/LLVMCPU/LLVMCPULowerExecutableTarget.cpp
@@ -175,12 +175,6 @@
                                      enableVectorMasking, lowerToAVX2);
     break;
   }
-  case IREE::Codegen::DispatchLoweringPassPipeline::CPUDoubleTilingPadExpert: {
-    TilingConfig tilingConfig = getTilingConfigForPipeline(moduleOp);
-    addDoubleTilingPadExpertPassPipeline(pipeline, tilingConfig,
-                                         enableVectorMasking);
-    break;
-  }
   case IREE::Codegen::DispatchLoweringPassPipeline::
       CPUDoubleTilingPeelingExpert: {
     TilingConfig tilingConfig = getTilingConfigForPipeline(moduleOp);
diff --git a/compiler/src/iree/compiler/Codegen/LLVMCPU/LLVMCPUSelectLoweringStrategy.cpp b/compiler/src/iree/compiler/Codegen/LLVMCPU/LLVMCPUSelectLoweringStrategy.cpp
index 2f6c24a..3d85e16 100644
--- a/compiler/src/iree/compiler/Codegen/LLVMCPU/LLVMCPUSelectLoweringStrategy.cpp
+++ b/compiler/src/iree/compiler/Codegen/LLVMCPU/LLVMCPUSelectLoweringStrategy.cpp
@@ -100,7 +100,6 @@
   LogicalResult verificationStatus = success();
   switch (translationInfo.value().getDispatchLoweringPassPipeline()) {
   case IREE::Codegen::DispatchLoweringPassPipeline::CPUDoubleTilingExpert:
-  case IREE::Codegen::DispatchLoweringPassPipeline::CPUDoubleTilingPadExpert:
     verificationStatus =
         verifyLoweringConfiguration(moduleOp, translationInfo.value(),
                                     verifyDoubleTilingExpertPassPipelineConfig);
diff --git a/compiler/src/iree/compiler/Codegen/LLVMCPU/Passes.cpp b/compiler/src/iree/compiler/Codegen/LLVMCPU/Passes.cpp
index 1934cb9..b0fbf23 100644
--- a/compiler/src/iree/compiler/Codegen/LLVMCPU/Passes.cpp
+++ b/compiler/src/iree/compiler/Codegen/LLVMCPU/Passes.cpp
@@ -125,16 +125,10 @@
 
   // Verify that the translation info is using the right pipeline.
   if (translationInfo.getDispatchLoweringPassPipeline() !=
-          IREE::Codegen::DispatchLoweringPassPipeline::CPUDoubleTilingExpert &&
-      translationInfo.getDispatchLoweringPassPipeline() !=
-          IREE::Codegen::DispatchLoweringPassPipeline::
-              CPUDoubleTilingPadExpert) {
+      IREE::Codegen::DispatchLoweringPassPipeline::CPUDoubleTilingExpert) {
     return op->emitOpError("expected pipeline in translation_info to be ")
            << stringifyEnum(IREE::Codegen::DispatchLoweringPassPipeline::
-                                CPUDoubleTilingExpert)
-           << " or "
-           << stringifyEnum(IREE::Codegen::DispatchLoweringPassPipeline::
-                                CPUDoubleTilingPadExpert);
+                                CPUDoubleTilingExpert);
   }
 
   if (tilingConfig.getNumTilingLevels() == 6) {
@@ -324,49 +318,6 @@
             mlir::arm_sme::ArmStreamingMode::StreamingLocally));
 }
 
-void addDoubleTilingPadExpertPassPipeline(OpPassManager &passManager,
-                                          TilingConfig &tilingConfig,
-                                          bool enableVectorMasking) {
-  addTileAndDistributePasses(passManager);
-
-  OpPassManager &nestedModulePM = passManager.nest<ModuleOp>();
-  nestedModulePM.addNestedPass<func::FuncOp>(createLLVMCPUTileAndFusePass(
-      tilingConfig.getVectorCommonParallelLevel()));
-  nestedModulePM.addNestedPass<func::FuncOp>(
-      createLLVMCPUTensorPadPass(LLVMCPUTensorPadOption::ParallelDims));
-
-  nestedModulePM.addNestedPass<func::FuncOp>(
-      createLLVMCPUTilePass(tilingConfig.getVectorReductionLevel()));
-  nestedModulePM.addNestedPass<func::FuncOp>(
-      createLLVMCPUTensorPadPass(LLVMCPUTensorPadOption::ReductionDims));
-
-  {
-    GenericVectorizationPassOptions options;
-    options.enableVectorMasking = enableVectorMasking;
-    options.vectorizePadding = true;
-    options.vectorizeGatherAccesses = true;
-    nestedModulePM.addNestedPass<func::FuncOp>(
-        createGenericVectorizationPass(options));
-    nestedModulePM.addNestedPass<func::FuncOp>(
-        createHoistRedundantVectorTransfersPass());
-    nestedModulePM.addNestedPass<func::FuncOp>(createCanonicalizerPass());
-    nestedModulePM.addNestedPass<func::FuncOp>(createCSEPass());
-  }
-
-  addCPUBufferizePasses(nestedModulePM);
-
-  // Run IREE specific passes before vector lowering expert.
-  nestedModulePM.addNestedPass<func::FuncOp>(
-      createRemoveSingleIterationLoopPass());
-
-  {
-    LLVMCPUVectorLoweringPassOptions options;
-    options.splitVectorTransfersTo = "linalg-copy";
-    nestedModulePM.addNestedPass<func::FuncOp>(
-        createLLVMCPUVectorLoweringPass(options));
-  }
-}
-
 void addMultiTilingExpertPassPipeline(
     OpPassManager &passManager, TilingConfig &tilingConfig, bool enablePeeling,
     bool enableVectorMasking, bool lowerToAVX2, bool enableAArch64SSVE) {
diff --git a/compiler/src/iree/compiler/Codegen/LLVMCPU/Passes.h b/compiler/src/iree/compiler/Codegen/LLVMCPU/Passes.h
index 394f8e2..75b82de 100644
--- a/compiler/src/iree/compiler/Codegen/LLVMCPU/Passes.h
+++ b/compiler/src/iree/compiler/Codegen/LLVMCPU/Passes.h
@@ -135,10 +135,6 @@
                                                bool enableVectorMasking,
                                                bool enableAArch64SSVE = false);
 
-void addDoubleTilingPadExpertPassPipeline(OpPassManager &passManager,
-                                          TilingConfig &tilingConfig,
-                                          bool enableVectorMasking);
-
 /// Populates the passes needed to multi level tile, fuse and vectorize
 /// lowering of linalg ops on tensors to vectors operations.
 void addMmt4dTilingExpertPassPipeline(OpPassManager &passManager,
diff --git a/compiler/src/iree/compiler/Codegen/LLVMCPU/test/illegal_configuration.mlir b/compiler/src/iree/compiler/Codegen/LLVMCPU/test/illegal_configuration.mlir
index ad90f6c..98bd3e1 100644
--- a/compiler/src/iree/compiler/Codegen/LLVMCPU/test/illegal_configuration.mlir
+++ b/compiler/src/iree/compiler/Codegen/LLVMCPU/test/illegal_configuration.mlir
@@ -145,37 +145,6 @@
 
 // -----
 
-// The constraints of CPUDoubleTilingPadExpert is as same as
-// CPUDoubleTilingExpert, checking one test it enough.
-#config = #iree_codegen.lowering_config<tile_sizes = [[64, 64], [8, 32, 16], [0, 0, 16], [0, 0, 0]]>
-#translation = #iree_codegen.translation_info<CPUDoubleTilingPadExpert>
-#pipeline_layout = #hal.pipeline.layout<push_constants = 0, sets = [
-  #hal.descriptor_set.layout<0, bindings = [
-    #hal.descriptor_set.binding<0, storage_buffer>,
-    #hal.descriptor_set.binding<1, storage_buffer>,
-    #hal.descriptor_set.binding<2, storage_buffer>
-  ]>
-]>
-hal.executable private @matmul_tensors {
-  hal.executable.variant @llvm target(#hal.executable.target<"llvm-cpu", "embedded-elf-x86_64", {}>) {
-    hal.executable.export @illegal layout(#pipeline_layout) attributes {translation_info = #translation}
-    builtin.module {
-      func.func @illegal() {
-        %c0 = arith.constant 0 : index
-        %lhs = hal.interface.binding.subspan set(0) binding(0) type(storage_buffer) : memref<4x8xf32>
-        %rhs = hal.interface.binding.subspan set(0) binding(1) type(storage_buffer) : memref<8x16xf32>
-        %result = hal.interface.binding.subspan set(0) binding(2) type(storage_buffer) : memref<4x16xf32>
-        // expected-error @+1 {{expected only parallel dims to be set in the second tiling level, got 2-th tile size set}}
-        linalg.matmul {lowering_config = #config} ins(%lhs, %rhs : memref<4x8xf32>, memref<8x16xf32>)
-          outs(%result: memref<4x16xf32>)
-        return
-      }
-    }
-  }
-}
-
-// -----
-
 #config = #iree_codegen.lowering_config<tile_sizes = [[0, 7, 7, 64, 0, 0, 0], [6, 1, 7, 32, 0, 0, 0], [0, 0, 0, 0, 3, 3, 4], [0, 0, 0, 0, 0, 0, 0]]>
 #translation = #iree_codegen.translation_info<CPUConvTileAndDecomposeExpert>
 #pipeline_layout = #hal.pipeline.layout<push_constants = 0, sets = [
diff --git a/compiler/src/iree/compiler/Codegen/LLVMCPU/test/pipeline_tests.mlir b/compiler/src/iree/compiler/Codegen/LLVMCPU/test/pipeline_tests.mlir
index 7a9ab67..cac9fb0 100644
--- a/compiler/src/iree/compiler/Codegen/LLVMCPU/test/pipeline_tests.mlir
+++ b/compiler/src/iree/compiler/Codegen/LLVMCPU/test/pipeline_tests.mlir
@@ -65,108 +65,6 @@
 
 // -----
 
-// Checks that the ops are padded and vectorized. The test sets tiling sizes to
-// be non-divisible by problem sizes. If padding and vectorizing are kicked in,
-// vector ops will be generated.
-#config = #iree_codegen.lowering_config<tile_sizes = [[65, 65], [8, 32, 0], [0, 0, 16], [0, 0, 0]]>
-#translation = #iree_codegen.translation_info<CPUDoubleTilingPadExpert>
-#pipeline_layout = #hal.pipeline.layout<push_constants = 0, sets = [
-  #hal.descriptor_set.layout<0, bindings = [
-    #hal.descriptor_set.binding<0, storage_buffer>,
-    #hal.descriptor_set.binding<1, storage_buffer>,
-    #hal.descriptor_set.binding<2, storage_buffer>
-  ]>
-]>
-hal.executable private @preset_pad_config_matmul  {
-  hal.executable.variant @system_elf_x86_64 target(<"llvm-cpu", "system-elf-x86_64">) {
-    hal.executable.export @preset_pad_config_matmul layout(#pipeline_layout) attributes {translation_info = #translation} {
-    ^bb0(%arg0: !hal.device, %arg1: index, %arg2 : index, %arg3 : index):
-      %x, %y, %z = flow.dispatch.workgroup_count_from_dag_root %arg1, %arg2, %arg3
-      hal.return %x, %y, %z : index, index, index
-    }
-    builtin.module {
-      func.func @preset_pad_config_matmul() {
-        %cst = arith.constant 0.000000e+00 : f32
-        %c0 = arith.constant 0 : index
-        %0 = hal.interface.binding.subspan set(0) binding(0) type(storage_buffer) alignment(64) offset(%c0) : !flow.dispatch.tensor<readonly:tensor<128x49xf32>>
-        %1 = hal.interface.binding.subspan set(0) binding(1) type(storage_buffer) alignment(64) offset(%c0) : !flow.dispatch.tensor<readonly:tensor<49x512xf32>>
-        %2 = hal.interface.binding.subspan set(0) binding(2) type(storage_buffer) alignment(64) offset(%c0) : !flow.dispatch.tensor<writeonly:tensor<128x512xf32>>
-        %3 = flow.dispatch.tensor.load %0, offsets = [0, 0], sizes = [128, 49], strides = [1, 1] : !flow.dispatch.tensor<readonly:tensor<128x49xf32>> -> tensor<128x49xf32>
-        %4 = flow.dispatch.tensor.load %1, offsets = [0, 0], sizes = [49, 512], strides = [1, 1] : !flow.dispatch.tensor<readonly:tensor<49x512xf32>> -> tensor<49x512xf32>
-        %5 = tensor.empty() : tensor<128x512xf32>
-        %6 = linalg.fill ins(%cst : f32) outs(%5 : tensor<128x512xf32>) -> tensor<128x512xf32>
-        %7 = linalg.matmul {lowering_config = #config}
-          ins(%3, %4 : tensor<128x49xf32>, tensor<49x512xf32>)
-          outs(%6 : tensor<128x512xf32>) -> tensor<128x512xf32>
-        flow.dispatch.tensor.store %7, %2, offsets = [0, 0], sizes = [128, 512], strides = [1, 1] : tensor<128x512xf32> -> !flow.dispatch.tensor<writeonly:tensor<128x512xf32>>
-        return
-      }
-    }
-  }
-}
-// CHECK-LABEL: func.func @preset_pad_config_matmul
-//       CHECK:     vector.fma
-
-// -----
-
-// Checks that the ops are padded and vectorized. The test sets tiling sizes to
-// be non-divisible by problem sizes. If padding and vectorizing are kicked in,
-// vector ops will be generated.
-#config = #iree_codegen.lowering_config<tile_sizes = [[192, 128, 0], [8, 32, 0], [0, 0, 16], [0, 0, 0]]>
-#translation = #iree_codegen.translation_info<CPUDoubleTilingPadExpert>
-#pipeline_layout = #hal.pipeline.layout<push_constants = 0, sets = [
-  #hal.descriptor_set.layout<0, bindings = [
-    #hal.descriptor_set.binding<0, storage_buffer>,
-    #hal.descriptor_set.binding<1, storage_buffer>,
-    #hal.descriptor_set.binding<2, storage_buffer>
-  ]>
-]>
-hal.executable private @preset_pad_config_dynamic_matmul  {
-  hal.executable.variant @system_elf_x86_64 target(<"llvm-cpu", "system-elf-x86_64">) {
-    hal.executable.export @preset_pad_config_dynamic_matmul layout(#pipeline_layout) attributes {translation_info = #translation} {
-    ^bb0(%arg0: !hal.device, %arg1: index, %arg2: index, %arg3: index, %arg4: index):
-      %x, %y, %z = flow.dispatch.workgroup_count_from_slice %arg1, %arg2, %arg3, %arg4
-      hal.return %x, %y, %z : index, index, index
-    }
-    builtin.module {
-      func.func @preset_pad_config_dynamic_matmul() {
-        %cst = arith.constant 0.000000e+00 : f32
-        %c0 = arith.constant 0 : index
-        %0 = hal.interface.constant.load[0] : i32
-        %1 = hal.interface.constant.load[1] : i32
-        %2 = hal.interface.constant.load[2] : i32
-        %3 = hal.interface.constant.load[3] : i32
-        %4 = arith.index_castui %0 : i32 to index
-        %5 = arith.index_castui %1 : i32 to index
-        %6 = arith.index_castui %2 : i32 to index
-        %7 = arith.index_castui %3 : i32 to index
-        %8 = flow.dispatch.workload.ordinal %4, 0 : index
-        %9 = flow.dispatch.workload.ordinal %5, 1 : index
-        %10 = flow.dispatch.workload.ordinal %6, 2 : index
-        %11 = flow.dispatch.workload.ordinal %7, 3 : index
-        %12 = hal.interface.binding.subspan set(0) binding(0) type(storage_buffer) alignment(64) offset(%c0) flags(ReadOnly) : !flow.dispatch.tensor<readonly:tensor<?x?xf32>>{%10, %8}
-        %13 = hal.interface.binding.subspan set(0) binding(1) type(storage_buffer) alignment(64) offset(%c0) flags(ReadOnly) : !flow.dispatch.tensor<readonly:tensor<?x?xf32>>{%9, %11}
-        %14 = hal.interface.binding.subspan set(0) binding(2) type(storage_buffer) alignment(64) offset(%c0) : !flow.dispatch.tensor<writeonly:tensor<?x?xf32>>{%10, %11}
-        %15 = flow.dispatch.tensor.load %12, offsets = [0, 0], sizes = [%10, %8], strides = [1, 1] : !flow.dispatch.tensor<readonly:tensor<?x?xf32>>{%10, %8} -> tensor<?x?xf32>
-        %16 = flow.dispatch.tensor.load %13, offsets = [0, 0], sizes = [%9, %11], strides = [1, 1] : !flow.dispatch.tensor<readonly:tensor<?x?xf32>>{%9, %11} -> tensor<?x?xf32>
-        %17 = tensor.empty(%10, %11) : tensor<?x?xf32>
-        %18 = linalg.fill ins(%cst : f32) outs(%17 : tensor<?x?xf32>) -> tensor<?x?xf32>
-        %19 = linalg.matmul {lowering_config = #config} ins(%15, %16 : tensor<?x?xf32>, tensor<?x?xf32>) outs(%18 : tensor<?x?xf32>) -> tensor<?x?xf32>
-        flow.dispatch.tensor.store %19, %14, offsets = [0, 0], sizes = [%10, %11], strides = [1, 1] : tensor<?x?xf32> -> !flow.dispatch.tensor<writeonly:tensor<?x?xf32>>{%10, %11}
-        return
-      }
-    }
-  }
-}
-// Checks that the bounded stack allocation are created.
-// CHECK-LABEL: func.func @preset_pad_config_dynamic_matmul
-//   CHECK-DAG:   memref.alloca() {{.+}} memref<8x16xf32>
-//   CHECK-DAG:   memref.alloca() {{.+}} memref<16x32xf32>
-//   CHECK-DAG:   memref.alloca() {{.+}} memref<8x32xf32>
-//       CHECK:     vector.fma
-
-// -----
-
 #executable_target_embedded_elf_x86_64_ = #hal.executable.target<"llvm-cpu", "embedded-elf-x86_64", {
   cpu_features = "",
   data_layout = "e-m:e-p270:32:32-p271:32:32-p272:64:64-i64:64-f80:128-n8:16:32:64-S128",
diff --git a/tests/e2e/matmul/BUILD.bazel b/tests/e2e/matmul/BUILD.bazel
index e051136..0ebebf8 100644
--- a/tests/e2e/matmul/BUILD.bazel
+++ b/tests/e2e/matmul/BUILD.bazel
@@ -185,23 +185,6 @@
     "large",
 ]]
 
-# Some e2e testing for --iree-codegen-enable-vector-padding=false.
-iree_generated_trace_runner_test(
-    name = "e2e_matmul_nondt_f32_small_no_padding",
-    compiler_flags = [
-        "--iree-codegen-enable-vector-padding=false",
-    ],
-    generator = ":generate_e2e_matmul_tests",
-    generator_args = [
-        "--lhs_rhs_type=f32",
-        "--shapes=small",
-    ],
-    target_backends_and_drivers = [
-        ("llvm-cpu", "local-task"),
-    ],
-    trace_runner = "//tools:iree-e2e-matmul-test",
-)
-
 ###########################################################################
 ##
 ## VMVX backend
diff --git a/tests/e2e/matmul/CMakeLists.txt b/tests/e2e/matmul/CMakeLists.txt
index 15e324e..8ffd49b 100644
--- a/tests/e2e/matmul/CMakeLists.txt
+++ b/tests/e2e/matmul/CMakeLists.txt
@@ -742,24 +742,6 @@
 
 iree_generated_trace_runner_test(
   NAME
-    e2e_matmul_nondt_f32_small_no_padding
-  GENERATOR
-    "generate_e2e_matmul_tests.py"
-  GENERATOR_ARGS
-    "--lhs_rhs_type=f32"
-    "--shapes=small"
-  TRACE_RUNNER
-    iree-e2e-matmul-test
-  TARGET_BACKENDS
-    "llvm-cpu"
-  DRIVERS
-    "local-task"
-  COMPILER_FLAGS
-    "--iree-codegen-enable-vector-padding=false"
-)
-
-iree_generated_trace_runner_test(
-  NAME
     e2e_matmul_dt_uk_i8_small
   GENERATOR
     "generate_e2e_matmul_tests.py"
diff --git a/tests/e2e/regression/lowering_config.mlir b/tests/e2e/regression/lowering_config.mlir
index 42a1329..259eb01 100644
--- a/tests/e2e/regression/lowering_config.mlir
+++ b/tests/e2e/regression/lowering_config.mlir
@@ -1,12 +1,12 @@
 #compilation0 = #iree_codegen.compilation_info<
     lowering_config = <tile_sizes = [[32, 32], [8, 8, 0], [0, 0, 8], [0, 0, 0]]>,
-    translation_info = <CPUDoubleTilingPadExpert>>
+    translation_info = <CPUDoubleTilingExpert>>
 #compilation1 = #iree_codegen.compilation_info<
     lowering_config = <tile_sizes = [[64, 64], [4, 4, 0], [0, 0, 4], [0, 0, 0]]>,
-    translation_info = <CPUDoubleTilingPadExpert>>
+    translation_info = <CPUDoubleTilingExpert>>
 #compilation2 = #iree_codegen.compilation_info<
     lowering_config = <tile_sizes = [{sizes=[32, 64], interchange=[1,0]}, [8, 32, 0], [0, 0, 8], [0, 0, 0]]>,
-    translation_info = <CPUDoubleTilingPadExpert>>
+    translation_info = <CPUDoubleTilingExpert>>
 
 func.func @lowering_config_test() {
   %a = util.unfoldable_constant dense<1.0> : tensor<128x256xf32>