[LLVMGPU] Combine parallel and reduction padding in LLVMGPUPadAndVectorDistribute (#18771)

Since https://github.com/iree-org/iree/pull/18748, tensor.pad ops can be
fused in during tiling. This patch combines the separate parallel and
reduction padding steps into a single pass that pads all dimensions at
once; the resulting pads are later fused during tiling.
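
For reference, here is a minimal standalone sketch (plain C++, no MLIR dependencies) of the combined padding computation introduced below: each loop dimension is padded to a multiple of the largest tile size assigned to it at either the workgroup or reduction tiling level, mirroring the `std::max({wgTile, redTile, padMultiple})` loop in the updated pass. The function name and the sample tile sizes are illustrative only, not part of the patch.

```cpp
// Illustrative sketch: derive per-dimension pad multiples by taking, for each
// loop dimension, the largest tile size from the workgroup (parallel) and
// reduction tiling levels. A tile size of 0 means "not tiled at this level".
#include <algorithm>
#include <cassert>
#include <cstddef>
#include <cstdint>
#include <iostream>
#include <vector>

std::vector<int64_t> combinePadMultiples(const std::vector<int64_t> &wgTiles,
                                         const std::vector<int64_t> &redTiles) {
  assert(wgTiles.size() == redTiles.size() && "tile lists must match in rank");
  std::vector<int64_t> padToMultipleOf(wgTiles.size(), 1);
  for (std::size_t dim = 0; dim < wgTiles.size(); ++dim) {
    // std::max with the initial 1 keeps untiled dimensions unpadded.
    padToMultipleOf[dim] =
        std::max({wgTiles[dim], redTiles[dim], padToMultipleOf[dim]});
  }
  return padToMultipleOf;
}

int main() {
  // Hypothetical batch_matmul loops (B, M, N, K): workgroup tiling covers the
  // parallel dims (M=64, N=128), reduction tiling covers K=16.
  std::vector<int64_t> wgTiles = {1, 64, 128, 0};
  std::vector<int64_t> redTiles = {0, 0, 0, 16};
  for (int64_t m : combinePadMultiples(wgTiles, redTiles))
    std::cout << m << " "; // prints: 1 64 128 16
  std::cout << "\n";
  return 0;
}
```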
diff --git a/compiler/src/iree/compiler/Codegen/LLVMGPU/LLVMGPUPromoteMatmulToFitMMA.cpp b/compiler/src/iree/compiler/Codegen/LLVMGPU/LLVMGPUPromoteMatmulToFitMMA.cpp
index dbcc5b1..2421494 100644
--- a/compiler/src/iree/compiler/Codegen/LLVMGPU/LLVMGPUPromoteMatmulToFitMMA.cpp
+++ b/compiler/src/iree/compiler/Codegen/LLVMGPU/LLVMGPUPromoteMatmulToFitMMA.cpp
@@ -27,25 +27,18 @@
 public:
   using impl::LLVMGPUPromoteMatmulToFitMMAPassBase<
       LLVMGPUPromoteMatmulToFitMMAPass>::LLVMGPUPromoteMatmulToFitMMAPassBase;
-  explicit LLVMGPUPromoteMatmulToFitMMAPass(
-      const LLVMGPUMatmulPadOption &option) {
-    this->targetDimensions.setValue(option);
-  }
   void getDependentDialects(DialectRegistry &registry) const override {
     registry.insert<tensor::TensorDialect, linalg::LinalgDialect>();
   }
 
   void padWithZeroValue(RewriterBase &rewriter, linalg::LinalgOp op,
-                        ArrayRef<int64_t> paddingDims,
-                        ArrayRef<int64_t> padToMultipleOf, bool noFold) const {
-    assert(paddingDims.size() == padToMultipleOf.size() &&
-           "invalid pad multiples for padding dimensions");
-
+                        ArrayRef<int64_t> padToMultipleOf) const {
     LLVM_DEBUG(llvm::dbgs() << "candidate: " << op << "\n");
     OpBuilder::InsertionGuard guard(rewriter);
     rewriter.setInsertionPointAfter(op);
 
-    SmallVector<bool> nofoldFlags(op.getNumDpsInputs(), noFold);
+    SmallVector<int64_t> paddingDims =
+        llvm::to_vector(llvm::seq<int64_t>(padToMultipleOf.size()));
 
     SmallVector<Attribute> paddingValueAttributes;
     for (auto &operand : op->getOpOperands()) {
@@ -58,7 +51,6 @@
             .setPaddingDimensions(paddingDims)
             .setPaddingValues(paddingValueAttributes)
             .setPadToMultipleOf(padToMultipleOf)
-            .setNofoldFlags(nofoldFlags)
             .setCopyBackOp(linalg::LinalgPaddingOptions::CopyBackOp::None);
 
     FailureOr<linalg::LinalgOp> result =
@@ -72,26 +64,6 @@
     MLIRContext *ctx = &getContext();
     auto funcOp = getOperation();
 
-    // Preserve the innermost tensor.pad ops (i.e., pad for reduction dims), so
-    // we can kick canonicalization patterns to fold outer tensor.pad ops away.
-    bool noFold = false;
-    utils::IteratorType targetIterType = utils::IteratorType::parallel;
-    switch (targetDimensions) {
-    case LLVMGPUMatmulPadOption::ParallelDims:
-      LLVM_DEBUG(llvm::dbgs() << "padding parallel dims\n");
-      targetIterType = utils::IteratorType::parallel;
-      noFold = false;
-      break;
-    case LLVMGPUMatmulPadOption::ReductionDims:
-      LLVM_DEBUG(llvm::dbgs() << "padding reduction dims\n");
-      targetIterType = utils::IteratorType::reduction;
-      noFold = true;
-      break;
-    default: // Unreachable.
-      assert(false);
-      break;
-    };
-
     SmallVector<linalg::LinalgOp> candidates;
     funcOp->walk([&](linalg::LinalgOp op) {
       if (linalg::isaContractionOpInterface(op)) {
@@ -101,46 +73,27 @@
 
     IRRewriter rewriter(ctx);
     for (linalg::LinalgOp op : candidates) {
-      SmallVector<int64_t> padMultiples(op.getNumLoops(), 1);
       auto config = dyn_cast_or_null<IREE::GPU::LoweringConfigAttr>(
           getLoweringConfig(op));
-      if (config) {
-        switch (targetDimensions) {
-        case LLVMGPUMatmulPadOption::ParallelDims:
-          padMultiples = config.getStaticTilingLevelSizes(
-              static_cast<unsigned>(IREE::GPU::TilingLevel::Workgroup), op);
-          break;
-        case LLVMGPUMatmulPadOption::ReductionDims:
-          padMultiples = config.getStaticTilingLevelSizes(
-              static_cast<unsigned>(IREE::GPU::TilingLevel::Reduction), op);
-          break;
-        default:
-          assert(false && "Unexpected target dimensions");
-          break;
-        }
+      if (!config) {
+        continue;
       }
 
-      // Populate padding dimensions.
-      SmallVector<int64_t> paddingDimensions;
-      for (auto [idx, iter] : llvm::enumerate(op.getIteratorTypesArray())) {
-        if (iter == targetIterType) {
-          paddingDimensions.push_back(idx);
-        }
-      }
+      SmallVector<int64_t> wgTiles = config.getStaticTilingLevelSizes(
+          static_cast<unsigned>(IREE::GPU::TilingLevel::Workgroup), op);
+      SmallVector<int64_t> redTiles = config.getStaticTilingLevelSizes(
+          static_cast<unsigned>(IREE::GPU::TilingLevel::Reduction), op);
 
-      // Populate tile sizes. We pad to multiples of workgroup/reduction
-      // tile sizes based on the selected target tiling dimensions.
-      // This pass is ran after the select target tiling is done to pad
-      // all dimensions to the select tile sizes.
-      SmallVector<int64_t> padToMultipleOf;
-      for (int64_t dim : paddingDimensions) {
-        if (padMultiples[dim] != 0) {
-          padToMultipleOf.push_back(padMultiples[dim]);
-        }
+      // Pad each dimension to a multiple of the largest applicable tile size.
+      SmallVector<int64_t> padToMultipleOf(op.getNumLoops(), 1);
+      for (auto [wgTile, redTile, padMultiple] :
+           llvm::zip_equal(wgTiles, redTiles, padToMultipleOf)) {
+        padMultiple = std::max({wgTile, redTile, padMultiple});
       }
+      SmallVector<int64_t> paddingDimensions =
+          llvm::to_vector(llvm::seq<int64_t>(op.getNumLoops()));
 
-      padWithZeroValue(rewriter, op, paddingDimensions, padToMultipleOf,
-                       noFold);
+      padWithZeroValue(rewriter, op, padToMultipleOf);
     }
 
     {
@@ -156,58 +109,8 @@
         return signalPassFailure();
       }
     }
-
-    // XXX(hanchung): This is needed for pad op fusion, which will remove
-    // outer pad ops. I.e., it mainly wants to remove first pad op in the
-    // pad->extract_slice->pad chain, while the canonicalization pattern can
-    // only recognize slice->pad->slice->pad.
-    {
-      SmallVector<tensor::PadOp> padOps;
-      funcOp.walk([&](tensor::PadOp op) { padOps.push_back(op); });
-      for (auto op : padOps) {
-        auto srcExtractSliceOp =
-            op.getSource().getDefiningOp<tensor::ExtractSliceOp>();
-        if (!srcExtractSliceOp) {
-          continue;
-        }
-        auto producerPadOp =
-            srcExtractSliceOp.getSource().getDefiningOp<tensor::PadOp>();
-        if (!producerPadOp) {
-          continue;
-        }
-        auto src = producerPadOp.getSource()
-                       .getDefiningOp<IREE::Flow::DispatchTensorLoadOp>();
-        if (!src) {
-          continue;
-        }
-
-        rewriter.setInsertionPointAfter(src);
-        SmallVector<OpFoldResult> sizes =
-            tensor::getMixedSizes(rewriter, op.getLoc(), src);
-        SmallVector<OpFoldResult> offsets(sizes.size(),
-                                          rewriter.getIndexAttr(0));
-        SmallVector<OpFoldResult> strides(sizes.size(),
-                                          rewriter.getIndexAttr(1));
-        auto extractSliceOp = rewriter.create<tensor::ExtractSliceOp>(
-            op.getLoc(), src.getResult(), offsets, sizes, strides);
-        rewriter.startOpModification(op);
-        producerPadOp.getSourceMutable().assign(extractSliceOp.getResult());
-        rewriter.finalizeOpModification(op);
-      }
-
-      RewritePatternSet patterns(ctx);
-      tensor::PadOp::getCanonicalizationPatterns(patterns, ctx);
-      if (failed(applyPatternsAndFoldGreedily(funcOp, std::move(patterns)))) {
-        return signalPassFailure();
-      }
-    }
   }
 };
 } // namespace
 
-std::unique_ptr<InterfacePass<mlir::FunctionOpInterface>>
-createLLVMGPUPromoteMatmulToFitMMAPass(LLVMGPUMatmulPadOption option) {
-  return std::make_unique<LLVMGPUPromoteMatmulToFitMMAPass>(option);
-}
-
 } // namespace mlir::iree_compiler
diff --git a/compiler/src/iree/compiler/Codegen/LLVMGPU/Passes.cpp b/compiler/src/iree/compiler/Codegen/LLVMGPU/Passes.cpp
index 76b1af3..51fcc6b 100644
--- a/compiler/src/iree/compiler/Codegen/LLVMGPU/Passes.cpp
+++ b/compiler/src/iree/compiler/Codegen/LLVMGPU/Passes.cpp
@@ -858,25 +858,20 @@
   funcPassManager.addPass(createCSEPass());
 
   if (usePadToModelSharedMemcpy) {
-    LLVMGPUMatmulPadOption option = LLVMGPUMatmulPadOption::ParallelDims;
-    funcPassManager.addPass(createLLVMGPUPromoteMatmulToFitMMAPass(option));
+    funcPassManager.addPass(createLLVMGPUPromoteMatmulToFitMMAPass());
   }
 
   // Tile to reduction loops.
   {
     GPUApplyTilingLevelPassOptions options;
     options.tilingLevel = IREE::GPU::TilingLevel::Reduction;
+    options.allowZeroSlices = true;
     funcPassManager.addPass(createGPUApplyTilingLevelPass(options));
     funcPassManager.addPass(affine::createLoopCoalescingPass());
     funcPassManager.addPass(createCanonicalizerPass());
     funcPassManager.addPass(createCSEPass());
   }
 
-  if (usePadToModelSharedMemcpy) {
-    LLVMGPUMatmulPadOption option = LLVMGPUMatmulPadOption::ReductionDims;
-    funcPassManager.addPass(createLLVMGPUPromoteMatmulToFitMMAPass(option));
-  }
-
   funcPassManager.addPass(IREE::LinalgExt::createDecomposeAttentionPass());
   funcPassManager.addPass(createCanonicalizerPass());
   funcPassManager.addPass(createCSEPass());
diff --git a/compiler/src/iree/compiler/Codegen/LLVMGPU/Passes.h b/compiler/src/iree/compiler/Codegen/LLVMGPU/Passes.h
index c118177..d932564 100644
--- a/compiler/src/iree/compiler/Codegen/LLVMGPU/Passes.h
+++ b/compiler/src/iree/compiler/Codegen/LLVMGPU/Passes.h
@@ -103,10 +103,6 @@
 // Wrappers that not use tablegen options.
 //------------------------------------------------------------------------------
 
-enum class LLVMGPUMatmulPadOption { ParallelDims, ReductionDims };
-std::unique_ptr<InterfacePass<mlir::FunctionOpInterface>>
-createLLVMGPUPromoteMatmulToFitMMAPass(LLVMGPUMatmulPadOption option);
-
 enum class GPUTensorCoreType {
   WMMA = 0,
   MMA_SYNC = 1,
diff --git a/compiler/src/iree/compiler/Codegen/LLVMGPU/Passes.td b/compiler/src/iree/compiler/Codegen/LLVMGPU/Passes.td
index ef51a6a..815a82f 100644
--- a/compiler/src/iree/compiler/Codegen/LLVMGPU/Passes.td
+++ b/compiler/src/iree/compiler/Codegen/LLVMGPU/Passes.td
@@ -105,19 +105,6 @@
 def LLVMGPUPromoteMatmulToFitMMAPass :
     InterfacePass<"iree-llvmgpu-promote-matmul-to-fit-mma", "mlir::FunctionOpInterface"> {
   let summary = "Pass to promote contraction ops to fit mma shapes";
-  let options = [
-    Option<"targetDimensions", "target-dimensions", "mlir::iree_compiler::LLVMGPUMatmulPadOption",
-           /*default=*/"mlir::iree_compiler::LLVMGPUMatmulPadOption::ParallelDims",
-           "Select the strategy to control how multi_reduction is lowered.",
-           [{::llvm::cl::values(
-            clEnumValN(mlir::iree_compiler::LLVMGPUMatmulPadOption::ParallelDims,
-                       "parallel",
-                       "Pad all the parallel dims for contraction ops."),
-            clEnumValN(mlir::iree_compiler::LLVMGPUMatmulPadOption::ReductionDims,
-                       "reduction",
-                       "Pad all the reduction dims for contraction ops.")
-        )}]>
-  ];
 }
 
 def LLVMGPUSelectLoweringStrategyPass :
diff --git a/compiler/src/iree/compiler/Codegen/LLVMGPU/test/ROCDL/pipeline_vector_distribute_gfx940.mlir b/compiler/src/iree/compiler/Codegen/LLVMGPU/test/ROCDL/pipeline_vector_distribute_gfx940.mlir
index 610e114..d21faf8 100644
--- a/compiler/src/iree/compiler/Codegen/LLVMGPU/test/ROCDL/pipeline_vector_distribute_gfx940.mlir
+++ b/compiler/src/iree/compiler/Codegen/LLVMGPU/test/ROCDL/pipeline_vector_distribute_gfx940.mlir
@@ -511,7 +511,7 @@
 // CHECK:         %[[RHS_LOAD:.+]] = vector.transfer_read %[[RHS_GLOBAL_SUB]]{{.+}} {in_bounds = [true, false, false]}
 // CHECK:         vector.transfer_write %[[LHS_LOAD]], %[[LHS_SHARED]]
 // CHECK:         vector.transfer_write %[[RHS_LOAD]], %[[RHS_SHARED]]
-// CHECK:         %[[RES:.+]] scf.for {{.*}} = %c0 to %c1265 step %c16 iter_args({{.*}}) -> (vector<1x1x1x1x1x1x1x4x1xf16>)
+// CHECK:         %[[RES:.+]] scf.for {{.*}} = %c0 to %c1280 step %c16 iter_args({{.*}}) -> (vector<1x1x1x1x1x1x1x4x1xf16>)
 // CHECK-DAG:       %[[LHS_GLOBAL_SUB:.+]] = memref.subview %[[LHS_GLOBAL]]
 // CHECK-DAG:       %[[RHS_GLOBAL_SUB:.+]] = memref.subview %[[RHS_GLOBAL]]
 // CHECK:           %[[LHS_LOAD:.+]] = vector.transfer_read %[[LHS_GLOBAL_SUB]]
@@ -581,9 +581,11 @@
 // CHECK-SAME:        memref<196x16x24xf32
 // CHECK-SAME:        vector<1x1x1xf32>
 // RHS
+// The dynamic dimension should be removed after:
+// https://github.com/llvm/llvm-project/pull/112236
 // CHECK:             vector.transfer_read
-// CHECK-SAME:        in_bounds = [true, true, false]
-// CHECK-SAME:        memref<1x8x24xf32
+// CHECK-SAME:        in_bounds = [true, false, false]
+// CHECK-SAME:        memref<1x?x24xf32
 // CHECK-SAME:        vector<1x1x2xf32>
 // CHECK:           scf.yield
 // OUTPUT
diff --git a/compiler/src/iree/compiler/Codegen/LLVMGPU/test/promote_matmul_to_fit_mma.mlir b/compiler/src/iree/compiler/Codegen/LLVMGPU/test/promote_matmul_to_fit_mma.mlir
index bda4836..21bc2fc 100644
--- a/compiler/src/iree/compiler/Codegen/LLVMGPU/test/promote_matmul_to_fit_mma.mlir
+++ b/compiler/src/iree/compiler/Codegen/LLVMGPU/test/promote_matmul_to_fit_mma.mlir
@@ -1,5 +1,4 @@
-// RUN: iree-opt --split-input-file --pass-pipeline="builtin.module(func.func(iree-llvmgpu-promote-matmul-to-fit-mma{target-dimensions=parallel}))"  %s | FileCheck %s --check-prefixes=ALL,PARALLEL
-// RUN: iree-opt --split-input-file --pass-pipeline="builtin.module(func.func(iree-llvmgpu-promote-matmul-to-fit-mma{target-dimensions=reduction}))" %s | FileCheck %s --check-prefixes=ALL,REDUCTION
+// RUN: iree-opt --split-input-file --pass-pipeline="builtin.module(func.func(iree-llvmgpu-promote-matmul-to-fit-mma))"  %s | FileCheck %s
 
 #pipeline_layout = #hal.pipeline.layout<bindings = [
   #hal.pipeline.binding<storage_buffer>,
@@ -34,114 +33,20 @@
   flow.dispatch.tensor.store %11, %2, offsets = [%workgroup_id_z, %3, %4], sizes = [1, %5, %6], strides = [1, 1, 1] : tensor<1x?x?xf16> -> !flow.dispatch.tensor<writeonly:tensor<64x968x1281xf16>>
   return
 }
-// ALL-LABEL:     func.func @batch_matmul_f16
-// ALL:             %[[LHS_HANDLE:.+]] = hal.interface.binding.subspan layout({{.+}}) binding(0) alignment(64) offset(%c0) flags(ReadOnly) : !flow.dispatch.tensor<readonly:tensor<64x968x1281xf16>>
-// ALL:             %[[RHS_HANDLE:.+]] = hal.interface.binding.subspan layout({{.+}}) binding(1) alignment(64) offset(%c0) flags(ReadOnly) : !flow.dispatch.tensor<readonly:tensor<64x1281x1281xf16>>
-// ALL:             %[[OUT_HANDLE:.+]] = hal.interface.binding.subspan layout({{.+}}) binding(2) alignment(64) offset(%c0) : !flow.dispatch.tensor<writeonly:tensor<64x968x1281xf16>>
-// ALL-DAG:         %[[LHS:.+]] = flow.dispatch.tensor.load %[[LHS_HANDLE]]
-// ALL-DAG:         %[[RHS:.+]] = flow.dispatch.tensor.load %[[RHS_HANDLE]]
-// PARALLEL:        %[[PADDED_LHS:.+]] = tensor.pad %[[LHS]]
-// PARALLEL:        } : tensor<1x?x1281xf16> to tensor<1x64x1281xf16>
-// PARALLEL:        %[[PADDED_RHS:.+]] = tensor.pad %[[RHS]]
-// PARALLEL:        } : tensor<1x1281x?xf16> to tensor<1x1281x128xf16>
-// PARALLEL:        %[[FILL:.+]] = linalg.fill
-// PARALLEL:        %[[GEMM:.+]] = linalg.batch_matmul
-// PARALLEL-SAME:     ins(%[[PADDED_LHS]], %[[PADDED_RHS]]
-// PARALLEL-SAME:     outs(%[[FILL]]
+// CHECK-LABEL:     func.func @batch_matmul_f16
+// CHECK:             %[[LHS_HANDLE:.+]] = hal.interface.binding.subspan layout({{.+}}) binding(0) alignment(64) offset(%c0) flags(ReadOnly) : !flow.dispatch.tensor<readonly:tensor<64x968x1281xf16>>
+// CHECK:             %[[RHS_HANDLE:.+]] = hal.interface.binding.subspan layout({{.+}}) binding(1) alignment(64) offset(%c0) flags(ReadOnly) : !flow.dispatch.tensor<readonly:tensor<64x1281x1281xf16>>
+// CHECK:             %[[OUT_HANDLE:.+]] = hal.interface.binding.subspan layout({{.+}}) binding(2) alignment(64) offset(%c0) : !flow.dispatch.tensor<writeonly:tensor<64x968x1281xf16>>
+// CHECK-DAG:         %[[LHS:.+]] = flow.dispatch.tensor.load %[[LHS_HANDLE]]
+// CHECK-DAG:         %[[RHS:.+]] = flow.dispatch.tensor.load %[[RHS_HANDLE]]
+// CHECK:             %[[PADDED_LHS:.+]] = tensor.pad %[[LHS]]
+// CHECK:             } : tensor<1x?x1281xf16> to tensor<1x64x1296xf16>
+// CHECK:             %[[PADDED_RHS:.+]] = tensor.pad %[[RHS]]
+// CHECK:             } : tensor<1x1281x?xf16> to tensor<1x1296x128xf16>
+// CHECK:             %[[FILL:.+]] = linalg.fill
+// CHECK:             %[[GEMM:.+]] = linalg.batch_matmul
+// CHECK-SAME:          ins(%[[PADDED_LHS]], %[[PADDED_RHS]]
+// CHECK-SAME:          outs(%[[FILL]]
 
-// The reduction dim is not tiled in the test case, so it pads it to the
-// matmul intrinsic k.
-// REDUCTION-DAG:   %[[FILL_DEST:.+]] = flow.dispatch.tensor.load %[[OUT_HANDLE]]
-// REDUCTION:       %[[FILL:.+]] = linalg.fill ins(%{{.+}}) outs(%[[FILL_DEST]]
-// REDUCTION:       %[[PADDED_LHS:.+]] = tensor.pad %[[LHS]]
-// REDUCTION:       } : tensor<1x?x1281xf16> to tensor<1x?x1296xf16>
-// REDUCTION:       %[[PADDED_RHS:.+]] = tensor.pad %[[RHS]]
-// REDUCTION:       } : tensor<1x1281x?xf16> to tensor<1x1296x?xf16>
-// REDUCTION:       %[[GEMM:.+]] = linalg.batch_matmul
-// REDUCTION-SAME:    ins(%[[PADDED_LHS]], %[[PADDED_RHS]]
-// REDUCTION-SAME:    outs(%[[FILL]]
-
-// ALL:             %[[OUT_SLICE:.+]] = tensor.extract_slice %[[GEMM]]
-// ALL:             flow.dispatch.tensor.store %[[OUT_SLICE]], %[[OUT_HANDLE]]
-
-// -----
-
-#pipeline_layout = #hal.pipeline.layout<bindings = [
-  #hal.pipeline.binding<storage_buffer>,
-  #hal.pipeline.binding<storage_buffer>,
-  #hal.pipeline.binding<storage_buffer>
-]>
-#map = affine_map<()[s0] -> (s0 * 64)>
-#map1 = affine_map<()[s0] -> (s0 * 128)>
-#map2 = affine_map<()[s0] -> (s0 * -64 + 968, 64)>
-#map3 = affine_map<()[s0] -> (s0 * -128 + 1281, 128)>
-#map4 = affine_map<()[s0] -> (-s0 + 64)>
-#map5 = affine_map<()[s0] -> (-s0 + 128)>
-#map6 = affine_map<(d0) -> (-d0 + 1281, 64)>
-func.func @batch_matmul_pad_reduction_after_tiling() {
-  %c64 = arith.constant 64 : index
-  %c1281 = arith.constant 1281 : index
-  %c2 = arith.constant 2 : index
-  %c1 = arith.constant 1 : index
-  %cst = arith.constant 0.000000e+00 : f16
-  %c0 = arith.constant 0 : index
-  %0 = hal.interface.binding.subspan layout(#pipeline_layout) binding(0) alignment(64) offset(%c0) flags(ReadOnly) : !flow.dispatch.tensor<readonly:tensor<64x968x1281xf16>>
-  %1 = hal.interface.binding.subspan layout(#pipeline_layout) binding(1) alignment(64) offset(%c0) flags(ReadOnly) : !flow.dispatch.tensor<readonly:tensor<64x1281x1281xf16>>
-  %2 = hal.interface.binding.subspan layout(#pipeline_layout) binding(2) alignment(64) offset(%c0) : !flow.dispatch.tensor<writeonly:tensor<64x968x1281xf16>>
-  %workgroup_id_z = hal.interface.workgroup.id[2] : index
-  %workgroup_id_y = hal.interface.workgroup.id[1] : index
-  %3 = affine.apply #map()[%workgroup_id_y]
-  %workgroup_id_x = hal.interface.workgroup.id[0] : index
-  %4 = affine.apply #map1()[%workgroup_id_x]
-  %5 = affine.min #map2()[%workgroup_id_y]
-  %6 = affine.min #map3()[%workgroup_id_x]
-  %7 = flow.dispatch.tensor.load %0, offsets = [%workgroup_id_z, %3, 0], sizes = [1, %5, 1281], strides = [1, 1, 1] : !flow.dispatch.tensor<readonly:tensor<64x968x1281xf16>> -> tensor<1x?x1281xf16>
-  %dim = tensor.dim %7, %c1 : tensor<1x?x1281xf16>
-  %8 = flow.dispatch.tensor.load %1, offsets = [%workgroup_id_z, 0, %4], sizes = [1, 1281, %6], strides = [1, 1, 1] : !flow.dispatch.tensor<readonly:tensor<64x1281x1281xf16>> -> tensor<1x1281x?xf16>
-  %dim_0 = tensor.dim %8, %c2 : tensor<1x1281x?xf16>
-  %9 = affine.apply #map4()[%5]
-  %padded = tensor.pad %7 low[0, 0, 0] high[0, %9, 0] {
-  ^bb0(%arg0: index, %arg1: index, %arg2: index):
-    tensor.yield %cst : f16
-  } : tensor<1x?x1281xf16> to tensor<1x64x1281xf16>
-  %10 = affine.apply #map5()[%6]
-  %padded_2 = tensor.pad %8 low[0, 0, 0] high[0, 0, %10] {
-  ^bb0(%arg0: index, %arg1: index, %arg2: index):
-    tensor.yield %cst : f16
-  } : tensor<1x1281x?xf16> to tensor<1x1281x128xf16>
-  %11 = tensor.empty() : tensor<1x64x128xf16>
-  %12 = linalg.fill ins(%cst : f16) outs(%11 : tensor<1x64x128xf16>) -> tensor<1x64x128xf16>
-  %13 = scf.for %arg0 = %c0 to %c1281 step %c64 iter_args(%arg1 = %12) -> (tensor<1x64x128xf16>) {
-    %14 = affine.min #map6(%arg0)
-    %extracted_slice_4 = tensor.extract_slice %padded[0, 0, %arg0] [1, 64, %14] [1, 1, 1] : tensor<1x64x1281xf16> to tensor<1x64x?xf16>
-    %extracted_slice_5 = tensor.extract_slice %padded_2[0, %arg0, 0] [1, %14, 128] [1, 1, 1] : tensor<1x1281x128xf16> to tensor<1x?x128xf16>
-    %15 = linalg.batch_matmul ins(%extracted_slice_4, %extracted_slice_5 : tensor<1x64x?xf16>, tensor<1x?x128xf16>) outs(%arg1 : tensor<1x64x128xf16>) -> tensor<1x64x128xf16>
-    scf.yield %15 : tensor<1x64x128xf16>
-  }
-  %extracted_slice_3 = tensor.extract_slice %13[0, 0, 0] [1, %5, %6] [1, 1, 1] : tensor<1x64x128xf16> to tensor<1x?x?xf16>
-  flow.dispatch.tensor.store %extracted_slice_3, %2, offsets = [%workgroup_id_z, %3, %4], sizes = [1, %5, %6], strides = [1, 1, 1] : tensor<1x?x?xf16> -> !flow.dispatch.tensor<writeonly:tensor<64x968x1281xf16>>
-  return
-}
-// The padding on parallel dims is a nop because they are already padded. Skip
-// the check for the testcase.
-// ALL-LABEL:     func.func @batch_matmul_pad_reduction_after_tiling
-// ALL:             %[[LHS_HANDLE:.+]] = hal.interface.binding.subspan layout({{.+}}) binding(0) alignment(64) offset(%c0) flags(ReadOnly) : !flow.dispatch.tensor<readonly:tensor<64x968x1281xf16>>
-// ALL:             %[[RHS_HANDLE:.+]] = hal.interface.binding.subspan layout({{.+}}) binding(1) alignment(64) offset(%c0) flags(ReadOnly) : !flow.dispatch.tensor<readonly:tensor<64x1281x1281xf16>>
-// ALL:             %[[OUT_HANDLE:.+]] = hal.interface.binding.subspan layout({{.+}}) binding(2) alignment(64) offset(%c0) : !flow.dispatch.tensor<writeonly:tensor<64x968x1281xf16>>
-// ALL-DAG:         %[[LHS:.+]] = flow.dispatch.tensor.load %[[LHS_HANDLE]]
-// ALL-DAG:         %[[RHS:.+]] = flow.dispatch.tensor.load %[[RHS_HANDLE]]
-// REDUCTION:       %[[INIT:.+]] = tensor.empty() : tensor<1x64x128xf16>
-// REDUCTION:       %[[FILL:.+]] = linalg.fill ins(%{{.+}}) outs(%[[INIT]]
-// REDUCTION:       %[[RES:.+]] = scf.for {{.+}} iter_args(%[[ITER:.+]] = %[[FILL]])
-// REDUCTION:         %[[LHS_SLICE:.+]] = tensor.extract_slice %[[LHS]]
-// REDUCTION:         %[[PADDED_LHS:.+]] = tensor.pad %[[LHS_SLICE]]
-// REDUCTION:         } : tensor<1x?x?xf16> to tensor<1x64x64xf16>
-// REDUCTION:         %[[RHS_SLICE:.+]] = tensor.extract_slice %[[RHS]]
-// REDUCTION:         %[[PADDED_RHS:.+]] = tensor.pad %[[RHS_SLICE]]
-// REDUCTION:         } : tensor<1x?x?xf16> to tensor<1x64x128xf16>
-// REDUCTION:         %[[GEMM:.+]] = linalg.batch_matmul
-// REDUCTION-SAME:      ins(%[[PADDED_LHS]], %[[PADDED_RHS]]
-// REDUCTION-SAME:      outs(%[[ITER]]
-// REDUCTION:         scf.yield %[[GEMM]]
-// REDUCTION:       %[[OUT_SLICE:.+]] = tensor.extract_slice %[[RES]]
-// REDUCTION:       flow.dispatch.tensor.store %[[OUT_SLICE]], %[[OUT_HANDLE]]
+// CHECK:             %[[OUT_SLICE:.+]] = tensor.extract_slice %[[GEMM]]
+// CHECK:             flow.dispatch.tensor.store %[[OUT_SLICE]], %[[OUT_HANDLE]]