[LLVMGPU] Introduce a pass that pad matmul to fit mma shapes. (#17225)

There are two modes in the pass: one pads the parallel dimensions and
the other pads the reduction dimensions. The padded sizes are inferred
from producers that implement ValueBoundsOpInterface, i.e., operands
are padded up to the last tiling sizes.
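
For reference, the new lit test exercises both modes through the
target-dimensions option (these invocations mirror the RUN lines in
promote_matmul_to_fit_mma.mlir; input.mlir stands in for the file
under test):

  iree-opt --split-input-file \
    --pass-pipeline="builtin.module(func.func(iree-llvmgpu-promote-matmul-to-fit-mma{target-dimensions=parallel}))" input.mlir
  iree-opt --split-input-file \
    --pass-pipeline="builtin.module(func.func(iree-llvmgpu-promote-matmul-to-fit-mma{target-dimensions=reduction}))" input.mlir

The default is the parallel mode (LLVMGPUMatmulPadOption::ParallelDims),
matching the declaration in Passes.h.
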
diff --git a/compiler/src/iree/compiler/Codegen/LLVMGPU/BUILD.bazel b/compiler/src/iree/compiler/Codegen/LLVMGPU/BUILD.bazel
index 66e1e82..5eac141 100644
--- a/compiler/src/iree/compiler/Codegen/LLVMGPU/BUILD.bazel
+++ b/compiler/src/iree/compiler/Codegen/LLVMGPU/BUILD.bazel
@@ -98,6 +98,7 @@
         "LLVMGPULowerExecutableTarget.cpp",
         "LLVMGPUPackSharedMemoryAlloc.cpp",
         "LLVMGPUPrefetching.cpp",
+        "LLVMGPUPromoteMatmulToFitMMA.cpp",
         "LLVMGPUSelectLoweringStrategy.cpp",
         "LLVMGPUTensorCoreVectorization.cpp",
         "LLVMGPUTensorPad.cpp",
diff --git a/compiler/src/iree/compiler/Codegen/LLVMGPU/CMakeLists.txt b/compiler/src/iree/compiler/Codegen/LLVMGPU/CMakeLists.txt
index f9634d6..87c0e9f 100644
--- a/compiler/src/iree/compiler/Codegen/LLVMGPU/CMakeLists.txt
+++ b/compiler/src/iree/compiler/Codegen/LLVMGPU/CMakeLists.txt
@@ -83,6 +83,7 @@
     "LLVMGPULowerExecutableTarget.cpp"
     "LLVMGPUPackSharedMemoryAlloc.cpp"
     "LLVMGPUPrefetching.cpp"
+    "LLVMGPUPromoteMatmulToFitMMA.cpp"
     "LLVMGPUSelectLoweringStrategy.cpp"
     "LLVMGPUTensorCoreVectorization.cpp"
     "LLVMGPUTensorPad.cpp"
diff --git a/compiler/src/iree/compiler/Codegen/LLVMGPU/LLVMGPUPromoteMatmulToFitMMA.cpp b/compiler/src/iree/compiler/Codegen/LLVMGPU/LLVMGPUPromoteMatmulToFitMMA.cpp
new file mode 100644
index 0000000..01b6722
--- /dev/null
+++ b/compiler/src/iree/compiler/Codegen/LLVMGPU/LLVMGPUPromoteMatmulToFitMMA.cpp
@@ -0,0 +1,177 @@
+// Copyright 2024 The IREE Authors
+//
+// Licensed under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+
+#include "iree/compiler/Codegen/LLVMGPU/PassDetail.h"
+#include "iree/compiler/Codegen/LLVMGPU/Passes.h"
+#include "iree/compiler/Dialect/Flow/IR/FlowOps.h"
+#include "mlir/Dialect/Linalg/IR/Linalg.h"
+#include "mlir/Dialect/Linalg/Transforms/Transforms.h"
+#include "mlir/Dialect/MemRef/Transforms/Transforms.h"
+#include "mlir/IR/PatternMatch.h"
+#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
+
+#define DEBUG_TYPE "iree-llvmgpu-promote-matmul-to-fit-mma"
+
+namespace mlir::iree_compiler {
+#define GEN_PASS_DECL_LLVMGPUPROMOTEMATMULTOFITMMA
+#include "iree/compiler/Codegen/LLVMGPU/Passes.h.inc"
+namespace {
+
+class LLVMGPUPromoteMatmulToFitMMAPass
+    : public LLVMGPUPromoteMatmulToFitMMABase<
+          LLVMGPUPromoteMatmulToFitMMAPass> {
+public:
+  explicit LLVMGPUPromoteMatmulToFitMMAPass(
+      const LLVMGPUMatmulPadOption &option) {
+    this->targetDimensions.setValue(option);
+  }
+  void getDependentDialects(DialectRegistry &registry) const override {
+    registry.insert<tensor::TensorDialect, linalg::LinalgDialect>();
+  }
+
+  void padWithZeroValue(RewriterBase &rewriter, linalg::LinalgOp op,
+                        utils::IteratorType targetIterType, bool nofold) const {
+    LLVM_DEBUG(llvm::dbgs() << "candidate: " << op << "\n");
+    OpBuilder::InsertionGuard guard(rewriter);
+    rewriter.setInsertionPointAfter(op);
+
+    SmallVector<int64_t> paddingDims;
+    for (auto [index, iterType] : llvm::enumerate(op.getIteratorTypesArray())) {
+      if (iterType == targetIterType) {
+        paddingDims.push_back(index);
+      }
+    }
+
+    SmallVector<bool> packPaddings(op.getNumDpsInputs(), nofold);
+
+    // A multiple of 1 is enough: the operands will essentially be padded to
+    // the corresponding tile sizes, which should be multiples of the MMA shapes.
+    SmallVector<int64_t> padToMultipleOf(paddingDims.size(), 1);
+    SmallVector<Attribute> paddingValueAttributes;
+    for (auto &operand : op->getOpOperands()) {
+      auto elemType = getElementTypeOrSelf(operand.get().getType());
+      paddingValueAttributes.push_back(rewriter.getZeroAttr(elemType));
+    }
+
+    auto options =
+        linalg::LinalgPaddingOptions()
+            .setPaddingDimensions(paddingDims)
+            .setPaddingValues(paddingValueAttributes)
+            .setPadToMultipleOf(padToMultipleOf)
+            .setPackPaddings(packPaddings)
+            .setCopyBackOp(linalg::LinalgPaddingOptions::CopyBackOp::None);
+
+    FailureOr<linalg::LinalgOp> result =
+        linalg::padAndHoistLinalgOp(rewriter, op, options);
+    if (failed(result)) {
+      LLVM_DEBUG(llvm::dbgs() << "failed to pad op " << op << "\n");
+    }
+  }
+
+  void runOnOperation() override {
+    MLIRContext *ctx = &getContext();
+    auto funcOp = getOperation();
+
+    // Preserve the innermost tensor.pad ops (i.e., the pads for reduction dims),
+    // so canonicalization patterns can fold the outer tensor.pad ops away.
+    bool nofold = false;
+    utils::IteratorType targetIterType = utils::IteratorType::parallel;
+    switch (targetDimensions) {
+    case LLVMGPUMatmulPadOption::ParallelDims:
+      LLVM_DEBUG(llvm::dbgs() << "padding parallel dims\n");
+      targetIterType = utils::IteratorType::parallel;
+      nofold = false;
+      break;
+    case LLVMGPUMatmulPadOption::ReductionDims:
+      LLVM_DEBUG(llvm::dbgs() << "padding reduction dims\n");
+      targetIterType = utils::IteratorType::reduction;
+      nofold = true;
+      break;
+    default: // Unreachable.
+      assert(false);
+      break;
+    }
+
+    SmallVector<linalg::LinalgOp> candidates;
+    funcOp->walk([&](linalg::LinalgOp op) {
+      if (linalg::isaContractionOpInterface(op)) {
+        candidates.push_back(op);
+      }
+    });
+
+    IRRewriter rewriter(ctx);
+    for (auto op : candidates) {
+      padWithZeroValue(rewriter, op, targetIterType, nofold);
+    }
+
+    {
+      RewritePatternSet patterns(ctx);
+      linalg::populateSwapExtractSliceWithFillPatterns(patterns);
+      linalg::FillOp::getCanonicalizationPatterns(patterns, ctx);
+      memref::populateResolveRankedShapedTypeResultDimsPatterns(patterns);
+      ctx->getLoadedDialect<tensor::TensorDialect>()
+          ->getCanonicalizationPatterns(patterns);
+      tensor::PadOp::getCanonicalizationPatterns(patterns, ctx);
+      if (failed(applyPatternsAndFoldGreedily(funcOp, std::move(patterns)))) {
+        LLVM_DEBUG(llvm::dbgs() << "----- cleanup failed -----\n");
+        return signalPassFailure();
+      }
+    }
+
+    // XXX(hanchung): This is needed for pad op fusion, which will remove the
+    // outer pad ops. I.e., it mainly wants to remove the first pad op in the
+    // pad->extract_slice->pad chain, while the canonicalization pattern can
+    // only recognize the slice->pad->slice->pad chain.
+    {
+      SmallVector<tensor::PadOp> padOps;
+      funcOp.walk([&](tensor::PadOp op) { padOps.push_back(op); });
+      for (auto op : padOps) {
+        auto srcExtractSliceOp =
+            op.getSource().getDefiningOp<tensor::ExtractSliceOp>();
+        if (!srcExtractSliceOp) {
+          continue;
+        }
+        auto producerPadOp =
+            srcExtractSliceOp.getSource().getDefiningOp<tensor::PadOp>();
+        if (!producerPadOp) {
+          continue;
+        }
+        auto src = producerPadOp.getSource()
+                       .getDefiningOp<IREE::Flow::DispatchTensorLoadOp>();
+        if (!src) {
+          continue;
+        }
+
+        rewriter.setInsertionPointAfter(src);
+        SmallVector<OpFoldResult> sizes =
+            tensor::getMixedSizes(rewriter, op.getLoc(), src);
+        SmallVector<OpFoldResult> offsets(sizes.size(),
+                                          rewriter.getIndexAttr(0));
+        SmallVector<OpFoldResult> strides(sizes.size(),
+                                          rewriter.getIndexAttr(1));
+        auto extractSliceOp = rewriter.create<tensor::ExtractSliceOp>(
+            op.getLoc(), src.getResult(), offsets, sizes, strides);
+        rewriter.startOpModification(op);
+        op.getSourceMutable().assign(extractSliceOp.getResult());
+        rewriter.finalizeOpModification(op);
+      }
+
+      RewritePatternSet patterns(ctx);
+      tensor::PadOp::getCanonicalizationPatterns(patterns, ctx);
+      if (failed(applyPatternsAndFoldGreedily(funcOp, std::move(patterns)))) {
+        return signalPassFailure();
+      }
+    }
+  }
+};
+} // namespace
+
+std::unique_ptr<InterfacePass<mlir::FunctionOpInterface>>
+createLLVMGPUPromoteMatmulToFitMMAPass(LLVMGPUMatmulPadOption option) {
+  return std::make_unique<LLVMGPUPromoteMatmulToFitMMAPass>(option);
+}
+
+} // namespace mlir::iree_compiler
diff --git a/compiler/src/iree/compiler/Codegen/LLVMGPU/PassDetail.h b/compiler/src/iree/compiler/Codegen/LLVMGPU/PassDetail.h
index 1d830a2..f8cb91b 100644
--- a/compiler/src/iree/compiler/Codegen/LLVMGPU/PassDetail.h
+++ b/compiler/src/iree/compiler/Codegen/LLVMGPU/PassDetail.h
@@ -7,6 +7,7 @@
 #ifndef IREE_COMPILER_CODEGEN_LLVMGPU_PASS_DETAIL_H_
 #define IREE_COMPILER_CODEGEN_LLVMGPU_PASS_DETAIL_H_
 
+#include "iree/compiler/Codegen/LLVMGPU/Passes.h"
 #include "iree/compiler/Dialect/HAL/IR/HALOps.h"
 #include "mlir/Dialect/Affine/IR/AffineOps.h"
 #include "mlir/Dialect/MemRef/IR/MemRef.h"
diff --git a/compiler/src/iree/compiler/Codegen/LLVMGPU/Passes.h b/compiler/src/iree/compiler/Codegen/LLVMGPU/Passes.h
index f99cfec..b16382a 100644
--- a/compiler/src/iree/compiler/Codegen/LLVMGPU/Passes.h
+++ b/compiler/src/iree/compiler/Codegen/LLVMGPU/Passes.h
@@ -107,6 +107,12 @@
 std::unique_ptr<InterfacePass<mlir::FunctionOpInterface>>
 createLLVMGPUPrefetchSharedMemoryPass();
 
+/// Pass to pad contraction ops so that their shapes fit the MMA shapes.
+enum class LLVMGPUMatmulPadOption { ParallelDims, ReductionDims };
+std::unique_ptr<InterfacePass<mlir::FunctionOpInterface>>
+createLLVMGPUPromoteMatmulToFitMMAPass(
+    LLVMGPUMatmulPadOption option = LLVMGPUMatmulPadOption::ParallelDims);
+
 enum class GPUTensorCoreType {
   WMMA = 0,
   MMA_SYNC = 1,
diff --git a/compiler/src/iree/compiler/Codegen/LLVMGPU/Passes.td b/compiler/src/iree/compiler/Codegen/LLVMGPU/Passes.td
index cb3349e..b4176a5 100644
--- a/compiler/src/iree/compiler/Codegen/LLVMGPU/Passes.td
+++ b/compiler/src/iree/compiler/Codegen/LLVMGPU/Passes.td
@@ -78,6 +78,25 @@
   let constructor = "mlir::iree_compiler::createLLVMGPUPrefetchSharedMemoryPass()";
 }
 
+def LLVMGPUPromoteMatmulToFitMMA :
+    InterfacePass<"iree-llvmgpu-promote-matmul-to-fit-mma", "mlir::FunctionOpInterface"> {
+  let summary = "Pass to promote contraction ops to fit mma shapes";
+  let constructor = "mlir::iree_compiler::createLLVMGPUPromoteMatmulToFitMMAPass()";
+  let options = [
+    Option<"targetDimensions", "target-dimensions", "mlir::iree_compiler::LLVMGPUMatmulPadOption",
+           /*default=*/"mlir::iree_compiler::LLVMGPUMatmulPadOption::ParallelDims",
+           "Select the strategy to control how multi_reduction is lowered.",
+           [{::llvm::cl::values(
+            clEnumValN(mlir::iree_compiler::LLVMGPUMatmulPadOption::ParallelDims,
+                       "parallel",
+                       "Pad all the parallel dims for contraction ops."),
+            clEnumValN(mlir::iree_compiler::LLVMGPUMatmulPadOption::ReductionDims,
+                       "reduction",
+                       "Pad all the reduction dims for contraction ops.")
+        )}]>
+  ];
+}
+
 def LLVMGPUSelectLoweringStrategy :
     Pass<"iree-llvmgpu-select-lowering-strategy", "ModuleOp"> {
   let summary = "Select a IREE::HAL::DispatchLoweringPassPipeline for lowering the target variant";
diff --git a/compiler/src/iree/compiler/Codegen/LLVMGPU/test/BUILD.bazel b/compiler/src/iree/compiler/Codegen/LLVMGPU/test/BUILD.bazel
index a7b682b..a91c931 100644
--- a/compiler/src/iree/compiler/Codegen/LLVMGPU/test/BUILD.bazel
+++ b/compiler/src/iree/compiler/Codegen/LLVMGPU/test/BUILD.bazel
@@ -55,6 +55,7 @@
             "pack_pipeline_test.mlir",
             "pack_shared_memory_alloc.mlir",
             "prefetch_shared_memory.mlir",
+            "promote_matmul_to_fit_mma.mlir",
             "tensor_pad.mlir",
             "tensorcore_vectorization.mlir",
             "transform_dialect_bufferize.mlir",
diff --git a/compiler/src/iree/compiler/Codegen/LLVMGPU/test/CMakeLists.txt b/compiler/src/iree/compiler/Codegen/LLVMGPU/test/CMakeLists.txt
index 567d359..2ddaeea 100644
--- a/compiler/src/iree/compiler/Codegen/LLVMGPU/test/CMakeLists.txt
+++ b/compiler/src/iree/compiler/Codegen/LLVMGPU/test/CMakeLists.txt
@@ -42,6 +42,7 @@
     "pack_pipeline_test.mlir"
     "pack_shared_memory_alloc.mlir"
     "prefetch_shared_memory.mlir"
+    "promote_matmul_to_fit_mma.mlir"
     "reduction_pipeline_cuda.mlir"
     "reduction_pipeline_rocm.mlir"
     "reduction_pipeline_transform_cuda.mlir"
diff --git a/compiler/src/iree/compiler/Codegen/LLVMGPU/test/promote_matmul_to_fit_mma.mlir b/compiler/src/iree/compiler/Codegen/LLVMGPU/test/promote_matmul_to_fit_mma.mlir
new file mode 100644
index 0000000..888602b
--- /dev/null
+++ b/compiler/src/iree/compiler/Codegen/LLVMGPU/test/promote_matmul_to_fit_mma.mlir
@@ -0,0 +1,138 @@
+// RUN: iree-opt --split-input-file --pass-pipeline="builtin.module(func.func(iree-llvmgpu-promote-matmul-to-fit-mma{target-dimensions=parallel}))"  %s | FileCheck %s --check-prefixes=ALL,PARALLEL
+// RUN: iree-opt --split-input-file --pass-pipeline="builtin.module(func.func(iree-llvmgpu-promote-matmul-to-fit-mma{target-dimensions=reduction}))" %s | FileCheck %s --check-prefixes=ALL,REDUCTION
+
+#map = affine_map<()[s0] -> (s0 * 64)>
+#map1 = affine_map<()[s0] -> (s0 * 128)>
+#map2 = affine_map<()[s0] -> (s0 * -64 + 968, 64)>
+#map3 = affine_map<()[s0] -> (s0 * -128 + 1281, 128)>
+#map4 = affine_map<()[s0] -> (-s0 + 64)>
+#map5 = affine_map<()[s0] -> (-s0 + 128)>
+func.func @batch_matmul_f16() {
+  %cst = arith.constant 0.000000e+00 : f16
+  %c0 = arith.constant 0 : index
+  %0 = hal.interface.binding.subspan set(0) binding(0) type(storage_buffer) alignment(64) offset(%c0) flags(ReadOnly) : !flow.dispatch.tensor<readonly:tensor<64x968x1281xf16>>
+  %1 = hal.interface.binding.subspan set(0) binding(1) type(storage_buffer) alignment(64) offset(%c0) flags(ReadOnly) : !flow.dispatch.tensor<readonly:tensor<64x1281x1281xf16>>
+  %2 = hal.interface.binding.subspan set(0) binding(2) type(storage_buffer) alignment(64) offset(%c0) : !flow.dispatch.tensor<writeonly:tensor<64x968x1281xf16>>
+  %workgroup_id_z = hal.interface.workgroup.id[2] : index
+  %workgroup_id_y = hal.interface.workgroup.id[1] : index
+  %3 = affine.apply #map()[%workgroup_id_y]
+  %workgroup_id_x = hal.interface.workgroup.id[0] : index
+  %4 = affine.apply #map1()[%workgroup_id_x]
+  %5 = affine.min #map2()[%workgroup_id_y]
+  %6 = affine.min #map3()[%workgroup_id_x]
+  %7 = flow.dispatch.tensor.load %2, offsets = [%workgroup_id_z, %3, %4], sizes = [1, %5, %6], strides = [1, 1, 1] : !flow.dispatch.tensor<writeonly:tensor<64x968x1281xf16>> -> tensor<1x?x?xf16>
+  %8 = flow.dispatch.tensor.load %0, offsets = [%workgroup_id_z, %3, 0], sizes = [1, %5, 1281], strides = [1, 1, 1] : !flow.dispatch.tensor<readonly:tensor<64x968x1281xf16>> -> tensor<1x?x1281xf16>
+  %9 = flow.dispatch.tensor.load %1, offsets = [%workgroup_id_z, 0, %4], sizes = [1, 1281, %6], strides = [1, 1, 1] : !flow.dispatch.tensor<readonly:tensor<64x1281x1281xf16>> -> tensor<1x1281x?xf16>
+  %10 = linalg.fill ins(%cst : f16) outs(%7 : tensor<1x?x?xf16>) -> tensor<1x?x?xf16>
+  %11 = linalg.batch_matmul ins(%8, %9 : tensor<1x?x1281xf16>, tensor<1x1281x?xf16>) outs(%10 : tensor<1x?x?xf16>) -> tensor<1x?x?xf16>
+  flow.dispatch.tensor.store %11, %2, offsets = [%workgroup_id_z, %3, %4], sizes = [1, %5, %6], strides = [1, 1, 1] : tensor<1x?x?xf16> -> !flow.dispatch.tensor<writeonly:tensor<64x968x1281xf16>>
+  return
+}
+// ALL-LABEL:     func.func @batch_matmul_f16
+// ALL:             %[[LHS_HANDLE:.+]] = hal.interface.binding.subspan set(0) binding(0) type(storage_buffer) alignment(64) offset(%c0) flags(ReadOnly) : !flow.dispatch.tensor<readonly:tensor<64x968x1281xf16>>
+// ALL:             %[[RHS_HANDLE:.+]] = hal.interface.binding.subspan set(0) binding(1) type(storage_buffer) alignment(64) offset(%c0) flags(ReadOnly) : !flow.dispatch.tensor<readonly:tensor<64x1281x1281xf16>>
+// ALL:             %[[OUT_HANDLE:.+]] = hal.interface.binding.subspan set(0) binding(2) type(storage_buffer) alignment(64) offset(%c0) : !flow.dispatch.tensor<writeonly:tensor<64x968x1281xf16>>
+// ALL-DAG:         %[[LHS:.+]] = flow.dispatch.tensor.load %[[LHS_HANDLE]]
+// ALL-DAG:         %[[RHS:.+]] = flow.dispatch.tensor.load %[[RHS_HANDLE]]
+// PARALLEL:        %[[PADDED_LHS:.+]] = tensor.pad %[[LHS]]
+// PARALLEL:        } : tensor<1x?x1281xf16> to tensor<1x64x1281xf16>
+// PARALLEL:        %[[PADDED_RHS:.+]] = tensor.pad %[[RHS]]
+// PARALLEL:        } : tensor<1x1281x?xf16> to tensor<1x1281x128xf16>
+// PARALLEL:        %[[FILL:.+]] = linalg.fill
+// PARALLEL:        %[[GEMM:.+]] = linalg.batch_matmul
+// PARALLEL-SAME:     ins(%[[PADDED_LHS]], %[[PADDED_RHS]]
+// PARALLEL-SAME:     outs(%[[FILL]]
+
+// The reduction dim is not tiled in this test case, so the operands are
+// padded to the same shape.
+// REDUCTION-DAG:   %[[FILL_DEST:.+]] = flow.dispatch.tensor.load %[[OUT_HANDLE]]
+// REDUCTION:       %[[FILL:.+]] = linalg.fill ins(%{{.+}}) outs(%[[FILL_DEST]]
+// REDUCTION:       %[[PADDED_LHS:.+]] = tensor.pad %[[LHS]]
+// REDUCTION:       } : tensor<1x?x1281xf16> to tensor<1x?x1281xf16>
+// REDUCTION:       %[[PADDED_RHS:.+]] = tensor.pad %[[RHS]]
+// REDUCTION:       } : tensor<1x1281x?xf16> to tensor<1x1281x?xf16>
+// REDUCTION:       %[[GEMM:.+]] = linalg.batch_matmul
+// REDUCTION-SAME:    ins(%[[PADDED_LHS]], %[[PADDED_RHS]]
+// REDUCTION-SAME:    outs(%[[FILL]]
+
+// ALL:             %[[OUT_SLICE:.+]] = tensor.extract_slice %[[GEMM]]
+// ALL:             flow.dispatch.tensor.store %[[OUT_SLICE]], %[[OUT_HANDLE]]
+
+// -----
+
+#map = affine_map<()[s0] -> (s0 * 64)>
+#map1 = affine_map<()[s0] -> (s0 * 128)>
+#map2 = affine_map<()[s0] -> (s0 * -64 + 968, 64)>
+#map3 = affine_map<()[s0] -> (s0 * -128 + 1281, 128)>
+#map4 = affine_map<()[s0] -> (-s0 + 64)>
+#map5 = affine_map<()[s0] -> (-s0 + 128)>
+#map6 = affine_map<(d0) -> (-d0 + 1281, 64)>
+func.func @batch_matmul_pad_reduction_after_tiling() {
+  %c64 = arith.constant 64 : index
+  %c1281 = arith.constant 1281 : index
+  %c2 = arith.constant 2 : index
+  %c1 = arith.constant 1 : index
+  %cst = arith.constant 0.000000e+00 : f16
+  %c0 = arith.constant 0 : index
+  %0 = hal.interface.binding.subspan set(0) binding(0) type(storage_buffer) alignment(64) offset(%c0) flags(ReadOnly) : !flow.dispatch.tensor<readonly:tensor<64x968x1281xf16>>
+  %1 = hal.interface.binding.subspan set(0) binding(1) type(storage_buffer) alignment(64) offset(%c0) flags(ReadOnly) : !flow.dispatch.tensor<readonly:tensor<64x1281x1281xf16>>
+  %2 = hal.interface.binding.subspan set(0) binding(2) type(storage_buffer) alignment(64) offset(%c0) : !flow.dispatch.tensor<writeonly:tensor<64x968x1281xf16>>
+  %workgroup_id_z = hal.interface.workgroup.id[2] : index
+  %workgroup_id_y = hal.interface.workgroup.id[1] : index
+  %3 = affine.apply #map()[%workgroup_id_y]
+  %workgroup_id_x = hal.interface.workgroup.id[0] : index
+  %4 = affine.apply #map1()[%workgroup_id_x]
+  %5 = affine.min #map2()[%workgroup_id_y]
+  %6 = affine.min #map3()[%workgroup_id_x]
+  %7 = flow.dispatch.tensor.load %0, offsets = [%workgroup_id_z, %3, 0], sizes = [1, %5, 1281], strides = [1, 1, 1] : !flow.dispatch.tensor<readonly:tensor<64x968x1281xf16>> -> tensor<1x?x1281xf16>
+  %dim = tensor.dim %7, %c1 : tensor<1x?x1281xf16>
+  %extracted_slice = tensor.extract_slice %7[0, 0, 0] [1, %dim, 1281] [1, 1, 1] : tensor<1x?x1281xf16> to tensor<1x?x1281xf16>
+  %8 = flow.dispatch.tensor.load %1, offsets = [%workgroup_id_z, 0, %4], sizes = [1, 1281, %6], strides = [1, 1, 1] : !flow.dispatch.tensor<readonly:tensor<64x1281x1281xf16>> -> tensor<1x1281x?xf16>
+  %dim_0 = tensor.dim %8, %c2 : tensor<1x1281x?xf16>
+  %extracted_slice_1 = tensor.extract_slice %8[0, 0, 0] [1, 1281, %dim_0] [1, 1, 1] : tensor<1x1281x?xf16> to tensor<1x1281x?xf16>
+  %9 = affine.apply #map4()[%5]
+  %padded = tensor.pad %extracted_slice low[0, 0, 0] high[0, %9, 0] {
+  ^bb0(%arg0: index, %arg1: index, %arg2: index):
+    tensor.yield %cst : f16
+  } : tensor<1x?x1281xf16> to tensor<1x64x1281xf16>
+  %10 = affine.apply #map5()[%6]
+  %padded_2 = tensor.pad %extracted_slice_1 low[0, 0, 0] high[0, 0, %10] {
+  ^bb0(%arg0: index, %arg1: index, %arg2: index):
+    tensor.yield %cst : f16
+  } : tensor<1x1281x?xf16> to tensor<1x1281x128xf16>
+  %11 = tensor.empty() : tensor<1x64x128xf16>
+  %12 = linalg.fill ins(%cst : f16) outs(%11 : tensor<1x64x128xf16>) -> tensor<1x64x128xf16>
+  %13 = scf.for %arg0 = %c0 to %c1281 step %c64 iter_args(%arg1 = %12) -> (tensor<1x64x128xf16>) {
+    %14 = affine.min #map6(%arg0)
+    %extracted_slice_4 = tensor.extract_slice %padded[0, 0, %arg0] [1, 64, %14] [1, 1, 1] : tensor<1x64x1281xf16> to tensor<1x64x?xf16>
+    %extracted_slice_5 = tensor.extract_slice %padded_2[0, %arg0, 0] [1, %14, 128] [1, 1, 1] : tensor<1x1281x128xf16> to tensor<1x?x128xf16>
+    %15 = linalg.batch_matmul ins(%extracted_slice_4, %extracted_slice_5 : tensor<1x64x?xf16>, tensor<1x?x128xf16>) outs(%arg1 : tensor<1x64x128xf16>) -> tensor<1x64x128xf16>
+    scf.yield %15 : tensor<1x64x128xf16>
+  }
+  %extracted_slice_3 = tensor.extract_slice %13[0, 0, 0] [1, %5, %6] [1, 1, 1] : tensor<1x64x128xf16> to tensor<1x?x?xf16>
+  flow.dispatch.tensor.store %extracted_slice_3, %2, offsets = [%workgroup_id_z, %3, %4], sizes = [1, %5, %6], strides = [1, 1, 1] : tensor<1x?x?xf16> -> !flow.dispatch.tensor<writeonly:tensor<64x968x1281xf16>>
+  return
+}
+// Padding the parallel dims is a no-op because they are already padded. Skip
+// the PARALLEL checks for this test case.
+// ALL-LABEL:     func.func @batch_matmul_pad_reduction_after_tiling
+// ALL:             %[[LHS_HANDLE:.+]] = hal.interface.binding.subspan set(0) binding(0) type(storage_buffer) alignment(64) offset(%c0) flags(ReadOnly) : !flow.dispatch.tensor<readonly:tensor<64x968x1281xf16>>
+// ALL:             %[[RHS_HANDLE:.+]] = hal.interface.binding.subspan set(0) binding(1) type(storage_buffer) alignment(64) offset(%c0) flags(ReadOnly) : !flow.dispatch.tensor<readonly:tensor<64x1281x1281xf16>>
+// ALL:             %[[OUT_HANDLE:.+]] = hal.interface.binding.subspan set(0) binding(2) type(storage_buffer) alignment(64) offset(%c0) : !flow.dispatch.tensor<writeonly:tensor<64x968x1281xf16>>
+// ALL-DAG:         %[[LHS:.+]] = flow.dispatch.tensor.load %[[LHS_HANDLE]]
+// ALL-DAG:         %[[RHS:.+]] = flow.dispatch.tensor.load %[[RHS_HANDLE]]
+// REDUCTION:       %[[INIT:.+]] = tensor.empty() : tensor<1x64x128xf16>
+// REDUCTION:       %[[FILL:.+]] = linalg.fill ins(%{{.+}}) outs(%[[INIT]]
+// REDUCTION:       %[[RES:.+]] = scf.for {{.+}} iter_args(%[[ITER:.+]] = %[[FILL]])
+// REDUCTION:         %[[LHS_SLICE:.+]] = tensor.extract_slice %[[LHS]]
+// REDUCTION:         %[[PADDED_LHS:.+]] = tensor.pad %[[LHS_SLICE]]
+// REDUCTION:         } : tensor<1x?x?xf16> to tensor<1x64x64xf16>
+// REDUCTION:         %[[RHS_SLICE:.+]] = tensor.extract_slice %[[RHS]]
+// REDUCTION:         %[[PADDED_RHS:.+]] = tensor.pad %[[RHS_SLICE]]
+// REDUCTION:         } : tensor<1x?x?xf16> to tensor<1x64x128xf16>
+// REDUCTION:         %[[GEMM:.+]] = linalg.batch_matmul
+// REDUCTION-SAME:      ins(%[[PADDED_LHS]], %[[PADDED_RHS]]
+// REDUCTION-SAME:      outs(%[[ITER]]
+// REDUCTION:         scf.yield %[[GEMM]]
+// REDUCTION:       %[[OUT_SLICE:.+]] = tensor.extract_slice %[[RES]]
+// REDUCTION:       flow.dispatch.tensor.store %[[OUT_SLICE]], %[[OUT_HANDLE]]