[LLVMGPU] Fix MMA schedule validation for unaligned shapes (#17317)

Schedule validation was too permissive when alignment is not required: the divisibility checks only ran for aligned problems, so unaligned shapes skipped validation entirely. This PR tightens the heuristic by rounding each unaligned problem dimension up to a multiple of the intrinsic size and validating the schedule against the rounded sizes.
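
For reference, a minimal standalone sketch of the tightened check (simplified stand-in structs rather than the real `GPUMatmulShapeType`/`GPUMMASchedule`; the schedule numbers are illustrative values consistent with the new test below, not taken verbatim from the compiler):

```cpp
#include <cassert>
#include <cstdint>

namespace {

int64_t divideCeil(int64_t num, int64_t den) { return (num + den - 1) / den; }

// Loose stand-ins for GPUMatmulShapeType / GPUMMASchedule; only the fields
// used by the check are modeled.
struct Problem { int64_t mSize, nSize, kSize; };
struct Schedule {
  int64_t mSize, nSize, kSize;                 // intrinsic shape (M, N, K)
  int64_t mTileCount, nTileCount, kTileCount;  // tiles per subgroup
  int64_t mWarpCount, nWarpCount;              // subgroups per workgroup
};

bool isValidSchedule(const Problem &p, const Schedule &s, bool mustBeAligned) {
  // For unaligned problems, round each dimension up to the next multiple of
  // the intrinsic size, then require the rounded size to be divisible by the
  // full per-workgroup tile.
  int64_t m = mustBeAligned ? p.mSize : divideCeil(p.mSize, s.mSize) * s.mSize;
  int64_t n = mustBeAligned ? p.nSize : divideCeil(p.nSize, s.nSize) * s.nSize;
  int64_t k = mustBeAligned ? p.kSize : divideCeil(p.kSize, s.kSize) * s.kSize;
  return m % (s.mSize * s.mTileCount * s.mWarpCount) == 0 &&
         n % (s.nSize * s.nTileCount * s.nWarpCount) == 0 &&
         k % (s.kSize * s.kTileCount) == 0;
}

} // namespace

int main() {
  // Shapes from the new unaligned_m_batch_matmul_64x72x1280x1280 test with the
  // MFMA_F16_16x16x16_F32 intrinsic: M = 72 rounds up to 80.
  Problem p{/*m=*/72, /*n=*/1280, /*k=*/1280};

  // Counts consistent with the expected tile_sizes [1, 16, 128, 128] and
  // subgroup_m_count = 1, subgroup_n_count = 4.
  Schedule fits{16, 16, 16, /*mTile=*/1, /*nTile=*/2, /*kTile=*/8,
                /*mWarp=*/1, /*nWarp=*/4};
  assert(isValidSchedule(p, fits, /*mustBeAligned=*/false));  // 80 % 16 == 0

  // A schedule needing an M tile of 32 is now rejected (80 % 32 != 0); before
  // this change the check was skipped entirely when mustBeAligned was false.
  Schedule tooBig = fits;
  tooBig.mTileCount = 2;
  assert(!isValidSchedule(p, tooBig, /*mustBeAligned=*/false));
  return 0;
}
```
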
diff --git a/compiler/src/iree/compiler/Codegen/Common/GPU/GPUHeuristics.cpp b/compiler/src/iree/compiler/Codegen/Common/GPU/GPUHeuristics.cpp
index 7102221..70ff8ed 100644
--- a/compiler/src/iree/compiler/Codegen/Common/GPU/GPUHeuristics.cpp
+++ b/compiler/src/iree/compiler/Codegen/Common/GPU/GPUHeuristics.cpp
@@ -30,12 +30,24 @@
 }
 
 bool isValidSchedule(const GPUMatmulShapeType &problem,
-                     const GPUMMASchedule &schedule) {
-  bool isValidM = (problem.mSize % (schedule.mSize * schedule.mTileCount *
-                                    schedule.mWarpCount)) == 0;
-  bool isValidN = (problem.nSize % (schedule.nSize * schedule.nTileCount *
-                                    schedule.nWarpCount)) == 0;
-  bool isValidK = (problem.kSize % (schedule.kSize * schedule.kTileCount)) == 0;
+                     const GPUMMASchedule &schedule, const bool mustBeAligned) {
+  auto alignedMSize =
+      mustBeAligned
+          ? problem.mSize
+          : llvm::divideCeil(problem.mSize, schedule.mSize) * schedule.mSize;
+  auto alignedNSize =
+      mustBeAligned
+          ? problem.nSize
+          : llvm::divideCeil(problem.nSize, schedule.nSize) * schedule.nSize;
+  auto alignedKSize =
+      mustBeAligned
+          ? problem.kSize
+          : llvm::divideCeil(problem.kSize, schedule.kSize) * schedule.kSize;
+  bool isValidM = (alignedMSize % (schedule.mSize * schedule.mTileCount *
+                                   schedule.mWarpCount)) == 0;
+  bool isValidN = (alignedNSize % (schedule.nSize * schedule.nTileCount *
+                                   schedule.nWarpCount)) == 0;
+  bool isValidK = (alignedKSize % (schedule.kSize * schedule.kTileCount)) == 0;
   return isValidN && isValidM && isValidK;
 }
 
@@ -49,7 +61,7 @@
   int64_t rhsBitwidth =
       intrinsics[schedule.index].bType.getIntOrFloatBitWidth();
 
-  while ((!isValidSchedule(problem, schedule) && mustBeAligned) ||
+  while (!isValidSchedule(problem, schedule, mustBeAligned) ||
          calculateSharedMemoryUsedInBytes(schedule, lhsBitwidth, rhsBitwidth) >
              sharedMemLimitInBytes) {
     LLVM_DEBUG({
diff --git a/compiler/src/iree/compiler/Codegen/LLVMGPU/test/ROCDL/config_vector_distribute.mlir b/compiler/src/iree/compiler/Codegen/LLVMGPU/test/ROCDL/config_vector_distribute.mlir
index d278be2..d72b7da 100644
--- a/compiler/src/iree/compiler/Codegen/LLVMGPU/test/ROCDL/config_vector_distribute.mlir
+++ b/compiler/src/iree/compiler/Codegen/LLVMGPU/test/ROCDL/config_vector_distribute.mlir
@@ -229,6 +229,35 @@
 
 // -----
 
+// CHECK:      #[[$TILE_SIZES:.+]] = #iree_codegen.lowering_config<tile_sizes = {{\[}}[1, 16, 128, 128]{{\]}}
+// CHECK:      #iree_codegen.translation_info<LLVMGPUPadAndVectorDistribute
+// CHECK-SAME: mma_schedule = #iree_gpu.mma_schedule
+// CHECK-SAME:   intrinsic =  #iree_gpu.mma_layout<MFMA_F16_16x16x16_F32>
+// CHECK-SAME:   subgroup_m_count = 1, subgroup_n_count = 4
+
+#executable_target_rocm_hsaco_fb = #hal.executable.target<"rocm", "rocm-hsaco-fb", {mma_intrinsics = [#iree_gpu.mma_layout<MFMA_F16_16x16x16_F32>, #iree_gpu.mma_layout<MFMA_F16_32x32x8_F32>], target_arch = "gfx940"}>
+module {
+  func.func @unaligned_m_batch_matmul_64x72x1280x1280() attributes {hal.executable.target = #executable_target_rocm_hsaco_fb} {
+    %cst = arith.constant 0.000000e+00 : f16
+    %c0 = arith.constant 0 : index
+    %0 = hal.interface.binding.subspan set(0) binding(0) type(storage_buffer) alignment(64) offset(%c0) flags(ReadOnly) : !flow.dispatch.tensor<readonly:tensor<64x72x1280xf16>>
+    %1 = hal.interface.binding.subspan set(0) binding(1) type(storage_buffer) alignment(64) offset(%c0) flags(ReadOnly) : !flow.dispatch.tensor<readonly:tensor<64x1280x1280xf16>>
+    %2 = hal.interface.binding.subspan set(0) binding(2) type(storage_buffer) alignment(64) offset(%c0) : !flow.dispatch.tensor<writeonly:tensor<64x72x1280xf16>>
+    %3 = flow.dispatch.tensor.load %0, offsets = [0, 0, 0], sizes = [64, 72, 1280], strides = [1, 1, 1] : !flow.dispatch.tensor<readonly:tensor<64x72x1280xf16>> -> tensor<64x72x1280xf16>
+    %4 = flow.dispatch.tensor.load %1, offsets = [0, 0, 0], sizes = [64, 1280, 1280], strides = [1, 1, 1] : !flow.dispatch.tensor<readonly:tensor<64x1280x1280xf16>> -> tensor<64x1280x1280xf16>
+    %5 = tensor.empty() : tensor<64x72x1280xf16>
+    %6 = linalg.fill ins(%cst : f16) outs(%5 : tensor<64x72x1280xf16>) -> tensor<64x72x1280xf16>
+    %7 = linalg.batch_matmul ins(%3, %4 : tensor<64x72x1280xf16>, tensor<64x1280x1280xf16>) outs(%6 : tensor<64x72x1280xf16>) -> tensor<64x72x1280xf16>
+    flow.dispatch.tensor.store %7, %2, offsets = [0, 0, 0], sizes = [64, 72, 1280], strides = [1, 1, 1] : tensor<64x72x1280xf16> -> !flow.dispatch.tensor<writeonly:tensor<64x72x1280xf16>>
+    return
+  }
+}
+// CHECK-LABEL: func.func @unaligned_m_batch_matmul_64x72x1280x1280()
+// CHECK:         linalg.batch_matmul
+// CHECK-SAME:      lowering_config = #[[$TILE_SIZES]]
+
+// -----
+
 #executable_target_rocm_hsaco_fb = #hal.executable.target<"rocm", "rocm-hsaco-fb", {mma_intrinsics = [#iree_gpu.mma_layout<MFMA_F16_16x16x16_F32>, #iree_gpu.mma_layout<MFMA_F16_32x32x8_F32>], target_arch = "gfx940"}>
 module {
   func.func @narrow_n_batch_matmul_64x968x4x320_f16() attributes {hal.executable.target = #executable_target_rocm_hsaco_fb} {