[Codegen][IGEMM] Fix and preserve padding dim order for convs (#21772)

Fix for the first case in https://github.com/iree-org/iree/issues/21660.

The previous logic tried to determine the position of each convolution
dimension and assign padding values accordingly, which is fragile because
convolutions in generic form can come in many different layouts. This PR
instead adds a mapping from the dimensions of a convolution to the
corresponding dimensions in the GEMM space, which is a more robust way to
determine which dimensions to pad on the convolution.
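
To make the idea concrete, below is a minimal, self-contained sketch (plain
C++, no MLIR types) of how such a conv-to-IGEMM dim map can drive the padding
assignment. The dim order, sizes, and container choices are illustrative
assumptions for an NHWC-like conv, not the exact IREE data structures or the
literal implementation in this PR.

```cpp
// Illustrative sketch only: conv dims n, oh, ow, oc, fh, fw, ic (0..6) map to
// IGEMM dims n, oh, ow, oc (parallel) and a single reduction K dim that
// fh, fw, and ic all collapse into.
#include <cstdint>
#include <cstdio>
#include <map>
#include <set>
#include <vector>

int main() {
  std::map<int64_t, int64_t> convToIgemmDimMap = {
      {0, 0}, {1, 1}, {2, 2}, {3, 3}, {4, 4}, {5, 4}, {6, 4}};
  std::vector<int64_t> bounds = {16, 38, 19, 64, 3 * 3 * 32}; // IGEMM bounds
  std::vector<int64_t> paddingSizes = {1, 1, 32, 64, 32};     // per IGEMM dim
  std::vector<int64_t> reductionTileSizes = {0, 0, 0, 0, 2};
  std::set<int64_t> inputChannelConvDims = {6};

  std::vector<int64_t> paddingConvSizes(convToIgemmDimMap.size(), 0);
  for (auto [convDim, igemmDim] : convToIgemmDimMap) {
    bool isReduction = reductionTileSizes[igemmDim] != 0;
    if (isReduction) {
      // Skip conv padding when the collapsed K size is already aligned, and
      // never pad filter loops; in those cases padding happens on the GEMM.
      if (bounds[igemmDim] % paddingSizes[igemmDim] == 0 ||
          !inputChannelConvDims.count(convDim))
        continue;
    }
    paddingConvSizes[convDim] = paddingSizes[igemmDim];
  }
  for (int64_t s : paddingConvSizes)
    printf("%lld ", (long long)s); // -> 1 1 32 64 0 0 0
  printf("\n");
  return 0;
}
```

With the map in hand, the order of the convolution dimensions no longer
matters: each conv dim simply inherits the padding size chosen for the IGEMM
dim it collapses into.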

---------

Signed-off-by: yzhang93 <zhyuhang88@gmail.com>
Signed-off-by: Max Dawkins <max.dawkins@gmail.com>
Co-authored-by: Max Dawkins <max.dawkins@gmail.com>
diff --git a/compiler/src/iree/compiler/Codegen/Dialect/GPU/TargetUtils/ConfigUtils.cpp b/compiler/src/iree/compiler/Codegen/Dialect/GPU/TargetUtils/ConfigUtils.cpp
index 8324471..3ac1751 100644
--- a/compiler/src/iree/compiler/Codegen/Dialect/GPU/TargetUtils/ConfigUtils.cpp
+++ b/compiler/src/iree/compiler/Codegen/Dialect/GPU/TargetUtils/ConfigUtils.cpp
@@ -206,75 +206,60 @@
 
 /// Helper function to get convolution padding sizes if possible.
 static std::optional<ArrayAttr> getPaddingConvSizes(
-    Builder &b, int64_t kSize, int64_t kPaddingSize,
+    Builder &b, const SmallVector<int64_t> &bounds,
+    const SmallVector<int64_t> &paddingSizes,
     const SmallVector<int64_t> &workgroupTileSizes,
-    const SmallVector<int64_t> &mDims, const SmallVector<int64_t> &nDims,
-    const SmallVector<int64_t> &batchDims,
-    std::optional<mlir::linalg::ConvolutionDimensions> &padConvDims) {
-  if (!padConvDims.has_value())
+    const SmallVector<int64_t> &reductionTileSizes,
+    std::optional<DenseMap<int64_t, AffineExpr>> &convToIgemmDimMap,
+    std::optional<mlir::linalg::ConvolutionDimensions> &convDims) {
+  if (!convToIgemmDimMap.has_value() || !convDims.has_value())
     return std::nullopt;
 
-  SmallVector<unsigned> batchAndImageDims;
-  mlir::linalg::ConvolutionDimensions convDims = padConvDims.value();
-  bool isBatchLast = !convDims.batch.empty() &&
-                     convDims.outputImage.back() < convDims.batch.front();
-  if (isBatchLast) {
-    batchAndImageDims.append(convDims.outputImage.begin(),
-                             convDims.outputImage.end());
-    batchAndImageDims.append(convDims.batch.begin(), convDims.batch.end());
-  } else {
-    batchAndImageDims.append(convDims.batch.begin(), convDims.batch.end());
-    batchAndImageDims.append(convDims.outputImage.begin(),
-                             convDims.outputImage.end());
-  }
-
-  SmallVector<unsigned> concatMDims, concatNDims;
-  bool isOutputChannelFirst =
-      convDims.outputChannel.back() < convDims.outputImage.front();
-  if (isOutputChannelFirst) {
-    concatMDims.append(convDims.outputChannel.begin(),
-                       convDims.outputChannel.end());
-    concatNDims = batchAndImageDims;
-  } else {
-    concatMDims = batchAndImageDims;
-    concatNDims.append(convDims.outputChannel.begin(),
-                       convDims.outputChannel.end());
-  }
-
-  // Verify that the number of M, N dimensions from IGEMM match the
-  // corresponding number of convolution dimensions.
-  if (concatMDims.size() != mDims.size() ||
-      concatNDims.size() != nDims.size() ||
-      convDims.depth.size() != batchDims.size()) {
-    return std::nullopt;
-  }
-
+  DenseMap<int64_t, AffineExpr> convToIgemmMap = convToIgemmDimMap.value();
   // Padding sizes for parallel dimensions are the same as workgroup tile
   // sizes.
-  int64_t totalNumDims = convDims.batch.size() + convDims.outputImage.size() +
-                         convDims.outputChannel.size() +
-                         convDims.filterLoop.size() +
-                         convDims.inputChannel.size() + convDims.depth.size();
-  SmallVector<int64_t> paddingConvSizes(totalNumDims, 0);
-  if (batchDims.size() != 0) {
-    for (auto [dim, bDim] : llvm::zip(convDims.depth, batchDims)) {
-      paddingConvSizes[dim] = workgroupTileSizes[bDim];
+  DenseSet<int64_t> paddedIGEMMDims;
+  DenseMap<int64_t, SmallVector<int64_t>> paddedReductionConvDims;
+  SetVector<int64_t> inputChannelDims(convDims->inputChannel.begin(),
+                                      convDims->inputChannel.end());
+  SmallVector<int64_t> paddingConvSizes(convToIgemmMap.size(), 0);
+  for (auto [convDim, IGEMMExpr] : convToIgemmMap) {
+    auto IGEMMDimExpr = cast<AffineDimExpr>(IGEMMExpr);
+    unsigned IGEMMPos = IGEMMDimExpr.getPosition();
+    if (reductionTileSizes[IGEMMPos] != 0) {
+      // For reduction dimensions, avoid setting padding on the convolution
+      // if the product of the corresponding conv sizes is already divisible
+      // by the padding size.
+      if (paddingSizes[IGEMMPos] &&
+          bounds[IGEMMPos] % paddingSizes[IGEMMPos] == 0) {
+        paddedIGEMMDims.insert(IGEMMPos);
+        continue;
+      }
+      // Only pad input channel dims. If we need to pad filter dims, then we
+      // would rather just do padding on the GEMM instead.
+      if (inputChannelDims.contains(convDim)) {
+        // Multiple input channel dims for a single IGEMMPos is not supported.
+        if (paddedIGEMMDims.contains(IGEMMPos)) {
+          return std::nullopt;
+        }
+        paddingConvSizes[convDim] = paddingSizes[IGEMMPos];
+        paddedIGEMMDims.insert(IGEMMPos);
+      }
+      continue;
     }
+    // Multiple padded parallel dims mapping to the same IGEMM dim is not
+    // supported.
+    if (workgroupTileSizes[IGEMMPos] != 0 &&
+        paddedIGEMMDims.contains(IGEMMPos)) {
+      return std::nullopt;
+    }
+    paddingConvSizes[convDim] = paddingSizes[IGEMMPos];
+    paddedIGEMMDims.insert(IGEMMPos);
   }
-  for (auto [dim, mDim] : llvm::zip(concatMDims, mDims))
-    paddingConvSizes[dim] = workgroupTileSizes[mDim];
-  for (auto [dim, nDim] : llvm::zip(concatNDims, nDims))
-    paddingConvSizes[dim] = workgroupTileSizes[nDim];
 
-  // To avoid over-padding, no padding for channel dimensions is needed if
-  // the product of reduction sizes is already multiples of k padding
-  // size. Otherwise, pad the innermost channel dimension.
-  // TODO (vivian): Padding the innermost channel dimension to a multiple
-  // of vector size may still be needed even if the K-dim is aligned, and
-  // this should be validated based on performance.
-  if (kSize % kPaddingSize != 0) {
-    int64_t innerChannelDim = convDims.inputChannel.back();
-    paddingConvSizes[innerChannelDim] = kPaddingSize;
+  // Ensure that all dimensions have been padded.
+  if (paddedIGEMMDims.size() != paddingSizes.size()) {
+    return std::nullopt;
   }
   return b.getI64ArrayAttr(paddingConvSizes);
 }
@@ -291,7 +276,9 @@
     SmallVector<int64_t> bounds, ArrayRef<AffineMap> maps,
     ArrayRef<Value> operands, IREE::GPU::TargetAttr target, bool useDirectLoad,
     bool scaled,
-    std::optional<mlir::linalg::ConvolutionDimensions> padConvDims = {}) {
+    std::optional<DenseMap<int64_t, AffineExpr>> convToIgemmDimMap =
+        std::nullopt,
+    std::optional<linalg::ConvolutionDimensions> convDims = std::nullopt) {
   if (target.getWgp().getMma().empty()) {
     return failure();
   }
@@ -537,9 +524,9 @@
 
     // Create `padding_conv` attribute when padding convolutions before IGEMM is
     // possible, otherwise fallback to pad IGEMM.
-    if (auto attr = getPaddingConvSizes(
-            b, bounds[innerKDim], paddingTileSizes[innerKDim],
-            workgroupTileSizes, mDims, nDims, batchDims, padConvDims)) {
+    if (auto attr = getPaddingConvSizes(b, bounds, paddingTileSizes,
+                                        workgroupTileSizes, reductionTileSizes,
+                                        convToIgemmDimMap, convDims)) {
       attrs.emplace_back(StringAttr::get(context, "padding_conv"), *attr);
     } else {
       attrs.emplace_back(StringAttr::get(context, "padding"),
@@ -580,15 +567,18 @@
       igemmGenericConvDetails->igemmLoopBounds;
   SmallVector<Value> igemmOperands = igemmGenericConvDetails->igemmOperands;
 
-  std::optional<mlir::linalg::ConvolutionDimensions> padConvDims;
-  if (padConv)
-    padConvDims = igemmGenericConvDetails->convDims;
+  std::optional<DenseMap<int64_t, AffineExpr>> convToIgemmDimMap;
+  std::optional<linalg::ConvolutionDimensions> convDims;
+  if (padConv) {
+    convDims = igemmGenericConvDetails->convDims;
+    convToIgemmDimMap = igemmGenericConvDetails->convToIgemmDimMap;
+  }
 
   SmallVector<int64_t> bounds = igemmLoopBounds;
   FailureOr<std::pair<LoweringConfigAttr, int64_t>> configAndWgSize =
       getMatmulOrIGEMMLoweringConfigAndWorkgroupSize(
           bounds, igemmContractionMaps, igemmOperands, target, useDirectLoad,
-          /*scaled*/ false, padConvDims);
+          /*scaled*/ false, convToIgemmDimMap, convDims);
   if (failed(configAndWgSize)) {
     return failure();
   }
diff --git a/compiler/src/iree/compiler/Codegen/LLVMGPU/test/ROCDL/config_igemm_tile_and_fuse.mlir b/compiler/src/iree/compiler/Codegen/LLVMGPU/test/ROCDL/config_igemm_tile_and_fuse.mlir
index 52a35a8..35eb91b 100644
--- a/compiler/src/iree/compiler/Codegen/LLVMGPU/test/ROCDL/config_igemm_tile_and_fuse.mlir
+++ b/compiler/src/iree/compiler/Codegen/LLVMGPU/test/ROCDL/config_igemm_tile_and_fuse.mlir
@@ -5,7 +5,7 @@
 // RUN: --iree-codegen-llvmgpu-use-igemm=true --iree-codegen-llvmgpu-igemm-pad-convolution=false --pass-pipeline="builtin.module(iree-llvmgpu-select-lowering-strategy)" %s | FileCheck %s --check-prefixes=CHECK,MI300X
 
 // RUN: iree-opt --mlir-print-local-scope --split-input-file --iree-gpu-test-target=gfx942 \
-// RUN: --iree-codegen-llvmgpu-use-igemm=true --iree-codegen-llvmgpu-igemm-pad-convolution=true --pass-pipeline="builtin.module(iree-llvmgpu-select-lowering-strategy)" %s | FileCheck %s --check-prefix=PAD-CONV
+// RUN: --iree-codegen-llvmgpu-use-igemm=true --iree-codegen-llvmgpu-igemm-pad-convolution=true --pass-pipeline="builtin.module(iree-llvmgpu-select-lowering-strategy)" %s | FileCheck %s --check-prefix=PAD-CONV-GFX942
 
 func.func @nhwc_conv_mfma() {
   %cst = arith.constant 0.000000e+00 : f32
@@ -110,7 +110,7 @@
 // MI300X-SAME:     subgroup = [1, 1, 1, 1, 0]
 // MI300X-SAME:     workgroup = [1, 1, 16, 64, 0]
 
-//    PAD-CONV:     padding_conv = [2, 1, 32, 64, 0, 0, 0]
+// PAD-CONV-GFX942:     padding_conv = [2, 1, 32, 64, 0, 0, 0]
 
 // -----
 
@@ -149,7 +149,7 @@
 // MI300X-SAME:     subgroup = [1, 1, 1, 1, 0]
 // MI300X-SAME:     workgroup = [1, 32, 1, 32, 0]
 
-//    PAD-CONV:     padding_conv = [1, 64, 2, 32, 0, 0, 0]
+// PAD-CONV-GFX942:     padding_conv = [1, 64, 2, 32, 0, 0, 0]
 
 // -----
 
@@ -188,7 +188,7 @@
 // MI300X-SAME:     subgroup = [1, 4, 1, 1, 0]
 // MI300X-SAME:     workgroup = [1, 4, 32, 32, 0]
 
-//    PAD-CONV:     padding_conv = [1, 8, 32, 32, 0, 0, 32]
+// PAD-CONV-GFX942:     padding_conv = [1, 8, 32, 32, 0, 0, 32]
 
 // -----
 
@@ -220,7 +220,7 @@
 //  CHECK-SAME:     subgroup = [1, 1, 1, 1, 0]
 //  CHECK-SAME:     workgroup = [16, 1, 1, 16, 0]
 
-//    PAD-CONV:     padding_conv = [16, 1, 1, 16, 0, 0, 0]
+// PAD-CONV-GFX942:     padding_conv = [16, 1, 1, 16, 0, 0, 0]
 
 // -----
 
@@ -258,4 +258,44 @@
 // MI300X-SAME:     subgroup = [1, 1, 0, 1, 0]
 // MI300X-SAME:     workgroup = [1, 32, 1, 32, 0]
 
-//    PAD-CONV:     padding_conv = [1, 32, 1, 64, 0, 0, 32]
+// PAD-CONV-GFX942:     padding_conv = [1, 32, 1, 64, 0, 0, 32]
+
+// -----
+
+#map = affine_map<(d0, d1, d2, d3, d4, d5) -> (d0, d1 + d4, d2, d5)>
+#map1 = affine_map<(d0, d1, d2, d3, d4, d5) -> (d3, d4, d5)>
+#map2 = affine_map<(d0, d1, d2, d3, d4, d5) -> (d0, d1, d2, d3)>
+module {
+  func.func @conv_nhwc_filter_5x1_unaligned(%arg0: tensor<16x42x19x64xbf16>, %arg1: tensor<64x5x64xbf16>, %arg2: tensor<16x38x19x64xf32>) -> tensor<16x38x19x64xf32> {
+    %0 = linalg.generic {indexing_maps = [#map, #map1, #map2], iterator_types = ["parallel", "parallel", "parallel", "parallel", "reduction", "reduction"]} ins(%arg0, %arg1 : tensor<16x42x19x64xbf16>, tensor<64x5x64xbf16>) outs(%arg2 : tensor<16x38x19x64xf32>) {
+    ^bb0(%in: bf16, %in_0: bf16, %out: f32):
+      %1 = arith.extf %in : bf16 to f32
+      %2 = arith.extf %in_0 : bf16 to f32
+      %3 = arith.mulf %1, %2 : f32
+      %4 = arith.addf %out, %3 : f32
+      linalg.yield %4 : f32
+    } -> tensor<16x38x19x64xf32>
+    return %0 : tensor<16x38x19x64xf32>
+  }
+}
+
+// CHECK-LABEL: func.func @conv_nhwc_filter_5x1_unaligned
+//  CHECK-SAME:   #iree_codegen.translation_info<pipeline = LLVMGPUTileAndFuse workgroup_size = [256, 1, 1] subgroup_size = 64
+//  CHECK-SAME:   #iree_gpu.pipeline_options<prefetch_shared_memory = true, no_reduce_shared_memory_bank_conflicts = false
+//  CHECK-SAME:   use_igemm_convolution = true
+
+//       CHECK:   linalg.generic {{.*}}lowering_config = #iree_gpu.lowering_config
+// GFX942-SAME:     mma_kind = #iree_gpu.mma_layout<MFMA_F32_16x16x16_BF16>
+// GFX942-SAME:     padding = [2, 2, 32, 64, 32]
+// GFX942-SAME:     promote_operands = [0, 1, 2]
+// GFX942-SAME:     reduction = [0, 0, 0, 0, 2]
+// GFX942-SAME:     subgroup = [2, 2, 2, 1, 0]
+// GFX942-SAME:     workgroup = [2, 2, 32, 64, 0]
+
+// MI300X-SAME:     padding = [1, 1, 32, 64, 32]
+// MI300X-SAME:     promote_operands = [0, 1, 2]
+// MI300X-SAME:     reduction = [0, 0, 0, 0, 2]
+// MI300X-SAME:     subgroup = [1, 1, 2, 1, 0]
+// MI300X-SAME:     workgroup = [1, 1, 32, 64, 0]
+
+// PAD-CONV-GFX942:     padding_conv = [2, 2, 32, 64, 0, 0]
diff --git a/compiler/src/iree/compiler/Dialect/LinalgExt/Utils/Utils.cpp b/compiler/src/iree/compiler/Dialect/LinalgExt/Utils/Utils.cpp
index 797726b..5d6aeaa 100644
--- a/compiler/src/iree/compiler/Dialect/LinalgExt/Utils/Utils.cpp
+++ b/compiler/src/iree/compiler/Dialect/LinalgExt/Utils/Utils.cpp
@@ -584,12 +584,17 @@
   auto inputMapGEMM =
       AffineMap::get(numParallelDims + numKDims, 0, inputDims, ctx);
 
-  // Prepare filter map.
+  // Prepare filter map and add mapping for reduction dimensions.
   int64_t currKPos = numParallelDims;
   SmallVector<AffineExpr> filterDims;
   for (const auto &[iter, indices] :
        llvm::zip_equal(filterIterators, filterReassocIndices)) {
     if (iter == reduction) {
+      for (int64_t reInd : indices) {
+        int64_t convDimIdx =
+            cast<AffineDimExpr>(filterMap.getResult(reInd)).getPosition();
+        convToIgemmDimMap[convDimIdx] = dims[currKPos];
+      }
       filterDims.push_back(dims[currKPos++]);
     } else {
       assert(iter == parallel && "expected a parallel dim");