[LLVMGPU] Pad to intrinsic shape in LLVMGPUPadAndVectorDistribute pipeline (#18632)

This patch makes the LLVMGPUPromoteToFitMMA pass pad to a multiple of the
intrinsic shape, instead of using a pad-to multiple of 1.

Fixes https://github.com/iree-org/iree/issues/18602
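
For intuition, padding to a multiple of the intrinsic/tile size just rounds each padded dimension up to the next multiple, whereas a multiple of 1 leaves the dimension unchanged. Below is a minimal standalone sketch of that arithmetic (the `roundUpToMultiple` helper is illustrative only, not an IREE API); the 1281 -> 1296 case matches the reduction-dim lit test updated in this patch.

```cpp
#include <cstdint>
#include <iostream>
#include <utility>
#include <vector>

// Round `size` up to the next multiple of `multiple`. With multiple == 1 this
// is a no-op, which is effectively what the pass did before this patch.
static int64_t roundUpToMultiple(int64_t size, int64_t multiple) {
  return ((size + multiple - 1) / multiple) * multiple;
}

int main() {
  // {dimension size, pad-to multiple} pairs; 1281 padded to a multiple of the
  // reduction tile size 16 yields 1296, as checked in the updated test.
  std::vector<std::pair<int64_t, int64_t>> cases = {
      {1281, 16}, {24, 8}, {16, 16}};
  for (auto [size, multiple] : cases)
    std::cout << size << " -> " << roundUpToMultiple(size, multiple) << "\n";
  return 0;
}
```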
diff --git a/compiler/src/iree/compiler/Codegen/LLVMGPU/LLVMGPUPromoteMatmulToFitMMA.cpp b/compiler/src/iree/compiler/Codegen/LLVMGPU/LLVMGPUPromoteMatmulToFitMMA.cpp
index 26b96d7..b11573b 100644
--- a/compiler/src/iree/compiler/Codegen/LLVMGPU/LLVMGPUPromoteMatmulToFitMMA.cpp
+++ b/compiler/src/iree/compiler/Codegen/LLVMGPU/LLVMGPUPromoteMatmulToFitMMA.cpp
@@ -36,23 +36,17 @@
   }
 
   void padWithZeroValue(RewriterBase &rewriter, linalg::LinalgOp op,
-                        utils::IteratorType targetIterType, bool nofold) const {
+                        ArrayRef<int64_t> paddingDims,
+                        ArrayRef<int64_t> padToMultipleOf, bool noFold) const {
+    assert(paddingDims.size() == padToMultipleOf.size() &&
+           "invalid pad multiples for padding dimensions");
+
     LLVM_DEBUG(llvm::dbgs() << "candidate: " << op << "\n");
     OpBuilder::InsertionGuard guard(rewriter);
     rewriter.setInsertionPointAfter(op);
 
-    SmallVector<int64_t> paddingDims;
-    for (auto [index, iterType] : llvm::enumerate(op.getIteratorTypesArray())) {
-      if (iterType == targetIterType) {
-        paddingDims.push_back(index);
-      }
-    }
+    SmallVector<bool> packPaddings(op.getNumDpsInputs(), noFold);
 
-    SmallVector<bool> packPaddings(op.getNumDpsInputs(), nofold);
-
-    // One is enough because they will essentially be padded to corresponding
-    // tile sizes, which should be multiple of MMA shapes.
-    SmallVector<int64_t> padToMultipleOf(paddingDims.size(), 1);
     SmallVector<Attribute> paddingValueAttributes;
     for (auto &operand : op->getOpOperands()) {
       auto elemType = getElementTypeOrSelf(operand.get().getType());
@@ -80,18 +74,18 @@
 
     // Preserve the innermost tensor.pad ops (i.e., pad for reduction dims), so
     // we can kick canonicalization patterns to fold outer tensor.pad ops away.
-    bool nofold = false;
+    bool noFold = false;
     utils::IteratorType targetIterType = utils::IteratorType::parallel;
     switch (targetDimensions) {
     case LLVMGPUMatmulPadOption::ParallelDims:
       LLVM_DEBUG(llvm::dbgs() << "padding parallel dims\n");
       targetIterType = utils::IteratorType::parallel;
-      nofold = false;
+      noFold = false;
       break;
     case LLVMGPUMatmulPadOption::ReductionDims:
       LLVM_DEBUG(llvm::dbgs() << "padding reduction dims\n");
       targetIterType = utils::IteratorType::reduction;
-      nofold = true;
+      noFold = true;
       break;
     default: // Unreachable.
       assert(false);
@@ -106,8 +100,47 @@
     });
 
     IRRewriter rewriter(ctx);
-    for (auto op : candidates) {
-      padWithZeroValue(rewriter, op, targetIterType, nofold);
+    for (linalg::LinalgOp op : candidates) {
+      SmallVector<int64_t> padMultiples(op.getNumLoops(), 1);
+      auto config = dyn_cast_or_null<IREE::GPU::LoweringConfigAttr>(
+          getLoweringConfig(op));
+      if (config) {
+        switch (targetDimensions) {
+        case LLVMGPUMatmulPadOption::ParallelDims:
+          padMultiples = config.getStaticTilingLevelSizes(
+              static_cast<unsigned>(IREE::GPU::TilingLevel::Workgroup), op);
+          break;
+        case LLVMGPUMatmulPadOption::ReductionDims:
+          padMultiples = config.getStaticTilingLevelSizes(
+              static_cast<unsigned>(IREE::GPU::TilingLevel::Reduction), op);
+          break;
+        default:
+          assert(false && "Unexpected target dimensions");
+          break;
+        }
+      }
+
+      // Populate padding dimensions.
+      SmallVector<int64_t> paddingDimensions;
+      for (auto [idx, iter] : llvm::enumerate(op.getIteratorTypesArray())) {
+        if (iter == targetIterType) {
+          paddingDimensions.push_back(idx);
+        }
+      }
+
+      // Populate pad-to multiples from the workgroup/reduction tile sizes,
+      // depending on the selected target dimensions. This pass runs after
+      // tile sizes have been selected, so all dimensions are padded to the
+      // selected tile sizes.
+      SmallVector<int64_t> padToMultipleOf;
+      for (int64_t dim : paddingDimensions) {
+        if (padMultiples[dim] != 0) {
+          padToMultipleOf.push_back(padMultiples[dim]);
+        }
+      }
+
+      padWithZeroValue(rewriter, op, paddingDimensions, padToMultipleOf,
+                       noFold);
     }
 
     {
diff --git a/compiler/src/iree/compiler/Codegen/LLVMGPU/test/ROCDL/pipeline_vector_distribute_gfx940.mlir b/compiler/src/iree/compiler/Codegen/LLVMGPU/test/ROCDL/pipeline_vector_distribute_gfx940.mlir
index 6ba9ba8..49e20f4 100644
--- a/compiler/src/iree/compiler/Codegen/LLVMGPU/test/ROCDL/pipeline_vector_distribute_gfx940.mlir
+++ b/compiler/src/iree/compiler/Codegen/LLVMGPU/test/ROCDL/pipeline_vector_distribute_gfx940.mlir
@@ -484,6 +484,67 @@
 
 // -----
 
+#config = #iree_gpu.lowering_config<{workgroup = [1, 16, 32, 0], reduction = [0, 0, 0, 8]}>
+#translation = #iree_codegen.translation_info<LLVMGPUPadAndVectorDistribute workgroup_size = [128, 1, 1] subgroup_size = 64, {mma_schedule = #iree_gpu.mma_schedule<intrinsic = #iree_gpu.mma_layout<MFMA_F32_16x16x4_F32>, subgroup_m_count = 1, subgroup_n_count = 2>}>
+
+#pipeline_layout = #hal.pipeline.layout<bindings = [
+  #hal.pipeline.binding<storage_buffer>,
+  #hal.pipeline.binding<storage_buffer>,
+  #hal.pipeline.binding<storage_buffer>
+]>
+
+hal.executable public @pad_batch_matmul {
+  hal.executable.variant public @rocm_hsaco_fb target(#hal.executable.target<"rocm", "rocm-hsaco-fb">) {
+    hal.executable.export public @pad_batch_matmul ordinal(0) layout(#pipeline_layout) {
+    ^bb0(%arg0: !hal.device):
+      %x, %y, %z = flow.dispatch.workgroup_count_from_slice
+      hal.return %x, %y, %z : index, index, index
+    }
+    builtin.module {
+      func.func @pad_batch_matmul() attributes {translation_info = #translation} {
+        %cst = arith.constant 0.000000e+00 : f32
+        %c0 = arith.constant 0 : index
+        %0 = hal.interface.binding.subspan layout(#pipeline_layout) binding(0) alignment(64) offset(%c0) : !flow.dispatch.tensor<readonly:tensor<196x16x24xf32>>
+        %1 = hal.interface.binding.subspan layout(#pipeline_layout) binding(1) alignment(64) offset(%c0) : !flow.dispatch.tensor<readonly:tensor<196x24x24xf32>>
+        %2 = hal.interface.binding.subspan layout(#pipeline_layout) binding(2) alignment(64) offset(%c0) : !flow.dispatch.tensor<writeonly:tensor<196x16x24xf32>>
+        %3 = flow.dispatch.tensor.load %0, offsets = [0, 0, 0], sizes = [196, 16, 24], strides = [1, 1, 1] : !flow.dispatch.tensor<readonly:tensor<196x16x24xf32>> -> tensor<196x16x24xf32>
+        %4 = flow.dispatch.tensor.load %1, offsets = [0, 0, 0], sizes = [196, 24, 24], strides = [1, 1, 1] : !flow.dispatch.tensor<readonly:tensor<196x24x24xf32>> -> tensor<196x24x24xf32>
+        %5 = tensor.empty() : tensor<196x16x24xf32>
+        %6 = linalg.fill {lowering_config = #config} ins(%cst : f32) outs(%5 : tensor<196x16x24xf32>) -> tensor<196x16x24xf32>
+        %7 = linalg.batch_matmul {lowering_config = #config} ins(%3, %4 : tensor<196x16x24xf32>, tensor<196x24x24xf32>) outs(%6 : tensor<196x16x24xf32>) -> tensor<196x16x24xf32>
+        flow.dispatch.tensor.store %7, %2, offsets = [0, 0, 0], sizes = [196, 16, 24], strides = [1, 1, 1] : tensor<196x16x24xf32> -> !flow.dispatch.tensor<writeonly:tensor<196x16x24xf32>>
+        return
+      }
+    }
+  }
+}
+
+// This test checks that we can handle an unaligned batch matmul whose sizes
+// are smaller than the chosen tile sizes. We mainly want to make sure this
+// example compiles. We also check that the correct transfer_read/transfer_write
+// ops are produced with in_bounds attrs for the padded dimensions.
+
+// CHECK-LABEL:   @pad_batch_matmul
+// CHECK:           scf.for
+// LHS
+// CHECK:             vector.transfer_read
+// CHECK-SAME:        in_bounds = [true, true, true]
+// CHECK-SAME:        memref<196x16x24xf32
+// CHECK-SAME:        vector<1x1x1xf32>
+// RHS
+// CHECK:             vector.transfer_read
+// CHECK-SAME:        in_bounds = [true, true, false]
+// CHECK-SAME:        memref<1x8x24xf32
+// CHECK-SAME:        vector<1x1x2xf32>
+// CHECK:           scf.yield
+// OUTPUT
+// CHECK:           vector.transfer_write
+// CHECK-SAME:      in_bounds = [true, true, false]
+// CHECK-SAME:      vector<1x4x1xf32>
+// CHECK-SAME:      memref<1x16x24xf32
+
+// -----
+
 // This test ensures that the contraction schedule generation does not only work on contraction ops,
 // but is also compatible with the transfer_read layout anchors.
 // Currently the transfer_read layout anchors expect WorkgroupSize % (WgTileSize / numelPerThread) == 0.
diff --git a/compiler/src/iree/compiler/Codegen/LLVMGPU/test/promote_matmul_to_fit_mma.mlir b/compiler/src/iree/compiler/Codegen/LLVMGPU/test/promote_matmul_to_fit_mma.mlir
index e2ef4ee..bda4836 100644
--- a/compiler/src/iree/compiler/Codegen/LLVMGPU/test/promote_matmul_to_fit_mma.mlir
+++ b/compiler/src/iree/compiler/Codegen/LLVMGPU/test/promote_matmul_to_fit_mma.mlir
@@ -12,6 +12,7 @@
 #map3 = affine_map<()[s0] -> (s0 * -128 + 1281, 128)>
 #map4 = affine_map<()[s0] -> (-s0 + 64)>
 #map5 = affine_map<()[s0] -> (-s0 + 128)>
+#config = #iree_gpu.lowering_config<{workgroup = [1, 16, 16, 0], reduction = [0, 0, 0, 16]}>
 func.func @batch_matmul_f16() {
   %cst = arith.constant 0.000000e+00 : f16
   %c0 = arith.constant 0 : index
@@ -29,7 +30,7 @@
   %8 = flow.dispatch.tensor.load %0, offsets = [%workgroup_id_z, %3, 0], sizes = [1, %5, 1281], strides = [1, 1, 1] : !flow.dispatch.tensor<readonly:tensor<64x968x1281xf16>> -> tensor<1x?x1281xf16>
   %9 = flow.dispatch.tensor.load %1, offsets = [%workgroup_id_z, 0, %4], sizes = [1, 1281, %6], strides = [1, 1, 1] : !flow.dispatch.tensor<readonly:tensor<64x1281x1281xf16>> -> tensor<1x1281x?xf16>
   %10 = linalg.fill ins(%cst : f16) outs(%7 : tensor<1x?x?xf16>) -> tensor<1x?x?xf16>
-  %11 = linalg.batch_matmul ins(%8, %9 : tensor<1x?x1281xf16>, tensor<1x1281x?xf16>) outs(%10 : tensor<1x?x?xf16>) -> tensor<1x?x?xf16>
+  %11 = linalg.batch_matmul {lowering_config = #config} ins(%8, %9 : tensor<1x?x1281xf16>, tensor<1x1281x?xf16>) outs(%10 : tensor<1x?x?xf16>) -> tensor<1x?x?xf16>
   flow.dispatch.tensor.store %11, %2, offsets = [%workgroup_id_z, %3, %4], sizes = [1, %5, %6], strides = [1, 1, 1] : tensor<1x?x?xf16> -> !flow.dispatch.tensor<writeonly:tensor<64x968x1281xf16>>
   return
 }
@@ -48,14 +49,14 @@
 // PARALLEL-SAME:     ins(%[[PADDED_LHS]], %[[PADDED_RHS]]
 // PARALLEL-SAME:     outs(%[[FILL]]
 
-// The reduction dim is not tiled in the test case, so it pads it to the same
-// shape.
+// The reduction dim is not tiled in the test case, so it is padded to a
+// multiple of the matmul intrinsic k.
 // REDUCTION-DAG:   %[[FILL_DEST:.+]] = flow.dispatch.tensor.load %[[OUT_HANDLE]]
 // REDUCTION:       %[[FILL:.+]] = linalg.fill ins(%{{.+}}) outs(%[[FILL_DEST]]
 // REDUCTION:       %[[PADDED_LHS:.+]] = tensor.pad %[[LHS]]
-// REDUCTION:       } : tensor<1x?x1281xf16> to tensor<1x?x1281xf16>
+// REDUCTION:       } : tensor<1x?x1281xf16> to tensor<1x?x1296xf16>
 // REDUCTION:       %[[PADDED_RHS:.+]] = tensor.pad %[[RHS]]
-// REDUCTION:       } : tensor<1x1281x?xf16> to tensor<1x1281x?xf16>
+// REDUCTION:       } : tensor<1x1281x?xf16> to tensor<1x1296x?xf16>
 // REDUCTION:       %[[GEMM:.+]] = linalg.batch_matmul
 // REDUCTION-SAME:    ins(%[[PADDED_LHS]], %[[PADDED_RHS]]
 // REDUCTION-SAME:    outs(%[[FILL]]