[LLVMGPU] Fix MMA schedule validation for unaligned shapes (#17317)
Schedule validation was skipped entirely when alignment was not required, which made the heuristic too permissive for unaligned problem sizes. This PR extends `isValidSchedule` to cover the unaligned case as well: each problem dimension is first rounded up to the next multiple of the corresponding intrinsic size, and the usual divisibility checks against the schedule's tile and warp counts are then applied to the padded sizes.
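To make the new check concrete, here is a minimal standalone sketch (hypothetical names, not the actual `GPUHeuristics.cpp` code) of the padded M-dimension check. The values mirror the new test case below (M = 72 with the 16-wide MFMA_F16_16x16x16_F32 intrinsic); the tile and warp counts of 1 are illustrative assumptions:

    #include <cstdint>
    #include <iostream>

    // Round x up to the next multiple of m, mirroring
    // llvm::divideCeil(x, m) * m in the patch.
    static int64_t alignTo(int64_t x, int64_t m) {
      return (x + m - 1) / m * m;
    }

    int main() {
      int64_t problemM = 72;   // M from the test case below
      int64_t intrinsicM = 16; // MFMA_F16_16x16x16_F32 M size
      int64_t mTileCount = 1;  // assumed schedule values
      int64_t mWarpCount = 1;

      // Unaligned path: pad M up to the intrinsic size first,
      // then run the usual divisibility check on the padded size.
      int64_t alignedM = alignTo(problemM, intrinsicM); // 72 -> 80
      bool isValidM = alignedM % (intrinsicM * mTileCount * mWarpCount) == 0;

      std::cout << "alignedM=" << alignedM           // 80
                << " isValidM=" << isValidM << "\n"; // 1
      return 0;
    }

Previously the `&& mustBeAligned` in the while condition short-circuited validation for unaligned problems, so no divisibility constraint was enforced at all; with this change the same constraints apply to the padded sizes, while `mustBeAligned` schedules keep exactly the behavior they had before (the padding is a no-op for them).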
diff --git a/compiler/src/iree/compiler/Codegen/Common/GPU/GPUHeuristics.cpp b/compiler/src/iree/compiler/Codegen/Common/GPU/GPUHeuristics.cpp
index 7102221..70ff8ed 100644
--- a/compiler/src/iree/compiler/Codegen/Common/GPU/GPUHeuristics.cpp
+++ b/compiler/src/iree/compiler/Codegen/Common/GPU/GPUHeuristics.cpp
@@ -30,12 +30,24 @@
}

bool isValidSchedule(const GPUMatmulShapeType &problem,
- const GPUMMASchedule &schedule) {
- bool isValidM = (problem.mSize % (schedule.mSize * schedule.mTileCount *
- schedule.mWarpCount)) == 0;
- bool isValidN = (problem.nSize % (schedule.nSize * schedule.nTileCount *
- schedule.nWarpCount)) == 0;
- bool isValidK = (problem.kSize % (schedule.kSize * schedule.kTileCount)) == 0;
+ const GPUMMASchedule &schedule, const bool mustBeAligned) {
+ auto alignedMSize =
+ mustBeAligned
+ ? problem.mSize
+ : llvm::divideCeil(problem.mSize, schedule.mSize) * schedule.mSize;
+ auto alignedNSize =
+ mustBeAligned
+ ? problem.nSize
+ : llvm::divideCeil(problem.nSize, schedule.nSize) * schedule.nSize;
+ auto alignedKSize =
+ mustBeAligned
+ ? problem.kSize
+ : llvm::divideCeil(problem.kSize, schedule.kSize) * schedule.kSize;
+ bool isValidM = (alignedMSize % (schedule.mSize * schedule.mTileCount *
+ schedule.mWarpCount)) == 0;
+ bool isValidN = (alignedNSize % (schedule.nSize * schedule.nTileCount *
+ schedule.nWarpCount)) == 0;
+ bool isValidK = (alignedKSize % (schedule.kSize * schedule.kTileCount)) == 0;
return isValidN && isValidM && isValidK;
}

@@ -49,7 +61,7 @@
int64_t rhsBitwidth =
intrinsics[schedule.index].bType.getIntOrFloatBitWidth();

- while ((!isValidSchedule(problem, schedule) && mustBeAligned) ||
+ while (!isValidSchedule(problem, schedule, mustBeAligned) ||
calculateSharedMemoryUsedInBytes(schedule, lhsBitwidth, rhsBitwidth) >
sharedMemLimitInBytes) {
LLVM_DEBUG({
diff --git a/compiler/src/iree/compiler/Codegen/LLVMGPU/test/ROCDL/config_vector_distribute.mlir b/compiler/src/iree/compiler/Codegen/LLVMGPU/test/ROCDL/config_vector_distribute.mlir
index d278be2..d72b7da 100644
--- a/compiler/src/iree/compiler/Codegen/LLVMGPU/test/ROCDL/config_vector_distribute.mlir
+++ b/compiler/src/iree/compiler/Codegen/LLVMGPU/test/ROCDL/config_vector_distribute.mlir
@@ -229,6 +229,35 @@

// -----

+// CHECK: #[[$TILE_SIZES:.+]] = #iree_codegen.lowering_config<tile_sizes = {{\[}}[1, 16, 128, 128]{{\]}}
+// CHECK: #iree_codegen.translation_info<LLVMGPUPadAndVectorDistribute
+// CHECK-SAME: mma_schedule = #iree_gpu.mma_schedule
+// CHECK-SAME: intrinsic = #iree_gpu.mma_layout<MFMA_F16_16x16x16_F32>
+// CHECK-SAME: subgroup_m_count = 1, subgroup_n_count = 4
+
+#executable_target_rocm_hsaco_fb = #hal.executable.target<"rocm", "rocm-hsaco-fb", {mma_intrinsics = [#iree_gpu.mma_layout<MFMA_F16_16x16x16_F32>, #iree_gpu.mma_layout<MFMA_F16_32x32x8_F32>], target_arch = "gfx940"}>
+module {
+ func.func @unaligned_m_batch_matmul_64x72x1280x1280() attributes {hal.executable.target = #executable_target_rocm_hsaco_fb} {
+ %cst = arith.constant 0.000000e+00 : f16
+ %c0 = arith.constant 0 : index
+ %0 = hal.interface.binding.subspan set(0) binding(0) type(storage_buffer) alignment(64) offset(%c0) flags(ReadOnly) : !flow.dispatch.tensor<readonly:tensor<64x72x1280xf16>>
+ %1 = hal.interface.binding.subspan set(0) binding(1) type(storage_buffer) alignment(64) offset(%c0) flags(ReadOnly) : !flow.dispatch.tensor<readonly:tensor<64x1280x1280xf16>>
+ %2 = hal.interface.binding.subspan set(0) binding(2) type(storage_buffer) alignment(64) offset(%c0) : !flow.dispatch.tensor<writeonly:tensor<64x72x1280xf16>>
+ %3 = flow.dispatch.tensor.load %0, offsets = [0, 0, 0], sizes = [64, 72, 1280], strides = [1, 1, 1] : !flow.dispatch.tensor<readonly:tensor<64x72x1280xf16>> -> tensor<64x72x1280xf16>
+ %4 = flow.dispatch.tensor.load %1, offsets = [0, 0, 0], sizes = [64, 1280, 1280], strides = [1, 1, 1] : !flow.dispatch.tensor<readonly:tensor<64x1280x1280xf16>> -> tensor<64x1280x1280xf16>
+ %5 = tensor.empty() : tensor<64x72x1280xf16>
+ %6 = linalg.fill ins(%cst : f16) outs(%5 : tensor<64x72x1280xf16>) -> tensor<64x72x1280xf16>
+ %7 = linalg.batch_matmul ins(%3, %4 : tensor<64x72x1280xf16>, tensor<64x1280x1280xf16>) outs(%6 : tensor<64x72x1280xf16>) -> tensor<64x72x1280xf16>
+ flow.dispatch.tensor.store %7, %2, offsets = [0, 0, 0], sizes = [64, 72, 1280], strides = [1, 1, 1] : tensor<64x72x1280xf16> -> !flow.dispatch.tensor<writeonly:tensor<64x72x1280xf16>>
+ return
+ }
+}
+// CHECK-LABEL: func.func @unaligned_m_batch_matmul_64x72x1280x1280()
+// CHECK: linalg.batch_matmul
+// CHECK-SAME: lowering_config = #[[$TILE_SIZES]]
+
+// -----
+
#executable_target_rocm_hsaco_fb = #hal.executable.target<"rocm", "rocm-hsaco-fb", {mma_intrinsics = [#iree_gpu.mma_layout<MFMA_F16_16x16x16_F32>, #iree_gpu.mma_layout<MFMA_F16_32x32x8_F32>], target_arch = "gfx940"}>
module {
func.func @narrow_n_batch_matmul_64x968x4x320_f16() attributes {hal.executable.target = #executable_target_rocm_hsaco_fb} {