[Codegen][IGEMM] Do not pre-pad convs with CHW layout or small input channel size (#21839)

For CHW layout, there is no need to vectorize the channel dimension for
im2col ops. Pre-padding the input channel dimension can also cause
over-padding when the filter H/W sizes are large, since the reduction
dimensions are collapsed during the im2col transform.
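
The over-padding effect can be seen with a rough numeric sketch (the
sizes below are assumptions for illustration, not taken from this
patch): padding the collapsed IGEMM reduction (K) dimension directly
wastes far fewer elements than pre-padding the channel before im2col.

```cpp
#include <cstdint>

// Assumed sizes for illustration: CHW conv with C = 40, a 3x3 filter, and an
// IGEMM K padding multiple of 16.
constexpr int64_t roundUp(int64_t x, int64_t m) { return (x + m - 1) / m * m; }

constexpr int64_t C = 40, KH = 3, KW = 3, kPadMultiple = 16;
// Padding the collapsed K dim (C * KH * KW = 360 -> 368) adds 8 elements.
constexpr int64_t kPadWaste =
    roundUp(C * KH * KW, kPadMultiple) - C * KH * KW;
// Pre-padding C (40 -> 48) before im2col multiplies the waste by the filter
// area: 48 * 3 * 3 = 432, i.e. 72 extra elements along K.
constexpr int64_t prePadWaste =
    roundUp(C, kPadMultiple) * KH * KW - C * KH * KW;
static_assert(kPadWaste == 8 && prePadWaste == 72, "9x more over-padding");
```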

For HWC layout, when the input channel size is small relative to the
padding size (e.g., inputChannel = 3 and paddingSize = 32), pre-padding
may also cause a performance regression due to over-padding, so such
cases are removed from the pre-padding path.
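
Concretely, the check added to getPaddingConvSizes compares the
requested padding size against the input channel size; a standalone
sketch of that heuristic (using a hypothetical helper name and the
example sizes above) is:

```cpp
#include <cstdint>

// Hypothetical standalone helper mirroring the new check in
// getPaddingConvSizes: if the requested padding is more than 2x the input
// channel size, pre-padding the channel would mostly read zeros, so fall back
// to padding the IGEMM instead.
static bool skipPrePadForSmallChannel(int64_t paddingSize,
                                      int64_t inputChannelSize) {
  return paddingSize / inputChannelSize > 2;
}

// With inputChannel = 3 and paddingSize = 32, 32 / 3 = 10 > 2, so this conv
// keeps the `padding` (IGEMM) attribute rather than `padding_conv`.
```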

---------

Signed-off-by: yzhang93 <zhyuhang88@gmail.com>
diff --git a/compiler/src/iree/compiler/Codegen/Dialect/GPU/TargetUtils/ConfigUtils.cpp b/compiler/src/iree/compiler/Codegen/Dialect/GPU/TargetUtils/ConfigUtils.cpp
index 6909fbc..21e77c1 100644
--- a/compiler/src/iree/compiler/Codegen/Dialect/GPU/TargetUtils/ConfigUtils.cpp
+++ b/compiler/src/iree/compiler/Codegen/Dialect/GPU/TargetUtils/ConfigUtils.cpp
@@ -357,24 +357,32 @@
   return schedule;
 }
 
+struct ConvToIgemmInfo {
+  bool isInputChannelLast;
+  linalg::ConvolutionDimensions convDims;
+  DenseMap<int64_t, AffineExpr> convToIgemmDimMap;
+  DenseMap<int64_t, int64_t> inputChannelDimToSize;
+};
+
 /// Helper function to get convolution padding sizes if possible.
-static std::optional<ArrayAttr> getPaddingConvSizes(
-    Builder &b, const SmallVector<int64_t> &bounds,
-    const SmallVector<int64_t> &paddingSizes,
-    const SmallVector<int64_t> &workgroupTileSizes,
-    const SmallVector<int64_t> &reductionTileSizes,
-    std::optional<DenseMap<int64_t, AffineExpr>> &convToIgemmDimMap,
-    std::optional<mlir::linalg::ConvolutionDimensions> &convDims) {
-  if (!convToIgemmDimMap.has_value() || !convDims.has_value())
+static std::optional<ArrayAttr>
+getPaddingConvSizes(Builder &b, const SmallVector<int64_t> &bounds,
+                    const SmallVector<int64_t> &paddingSizes,
+                    const SmallVector<int64_t> &workgroupTileSizes,
+                    const SmallVector<int64_t> &reductionTileSizes,
+                    std::optional<ConvToIgemmInfo> &convToIgemmInfo) {
+  if (!convToIgemmInfo.has_value())
     return std::nullopt;
 
-  DenseMap<int64_t, AffineExpr> convToIgemmMap = convToIgemmDimMap.value();
+  DenseMap<int64_t, AffineExpr> convToIgemmMap =
+      convToIgemmInfo->convToIgemmDimMap;
   // Padding sizes for parallel dimensions are the same as workgroup tile
   // sizes.
   DenseSet<int64_t> paddedIGEMMDims;
   DenseMap<int64_t, SmallVector<int64_t>> paddedReductionConvDims;
-  SetVector<int64_t> inputChannelDims(convDims->inputChannel.begin(),
-                                      convDims->inputChannel.end());
+  linalg::ConvolutionDimensions convDims = convToIgemmInfo->convDims;
+  SetVector<int64_t> inputChannelDims(convDims.inputChannel.begin(),
+                                      convDims.inputChannel.end());
   SmallVector<int64_t> paddingConvSizes(convToIgemmMap.size(), 0);
   for (auto [convDim, IGEMMExpr] : convToIgemmMap) {
     auto IGEMMDimExpr = cast<AffineDimExpr>(IGEMMExpr);
@@ -391,8 +399,16 @@
       // Only pad input channel dims. If we need to pad filter dims, then we
       // would rather just do padding on the GEMM instead.
       if (inputChannelDims.contains(convDim)) {
-        // Multiple input channel dims for a single IGEMMPos is not supported.
-        if (paddedIGEMMDims.contains(IGEMMPos)) {
+        int64_t inputChannelSize =
+            convToIgemmInfo->inputChannelDimToSize[convDim];
+        bool isInputChannelSizeSmall =
+            (paddingSizes[IGEMMPos] / inputChannelSize > 2);
+        // The following cases are not supported:
+        // 1) Input channel is not the innermost dimension;
+        // 2) Input channel size is too small compared to padding size;
+        // 3) Multiple input channel dims for a single IGEMMPos.
+        if (!convToIgemmInfo->isInputChannelLast || isInputChannelSizeSmall ||
+            paddedIGEMMDims.contains(IGEMMPos)) {
           return std::nullopt;
         }
         paddingConvSizes[convDim] = paddingSizes[IGEMMPos];
@@ -429,9 +445,7 @@
     SmallVector<int64_t> bounds, ArrayRef<AffineMap> maps,
     ArrayRef<Value> operands, IREE::GPU::TargetAttr target, bool useDirectLoad,
     bool isGemm, bool scaled,
-    std::optional<DenseMap<int64_t, AffineExpr>> convToIgemmDimMap =
-        std::nullopt,
-    std::optional<linalg::ConvolutionDimensions> convDims = std::nullopt) {
+    std::optional<ConvToIgemmInfo> convToIgemmInfo = std::nullopt) {
   if (target.getWgp().getMma().empty()) {
     return failure();
   }
@@ -678,9 +692,9 @@
 
     // Create `padding_conv` attribute when padding convolutions before IGEMM is
     // possible, otherwise fallback to pad IGEMM.
-    if (auto attr = getPaddingConvSizes(b, bounds, paddingTileSizes,
-                                        workgroupTileSizes, reductionTileSizes,
-                                        convToIgemmDimMap, convDims)) {
+    if (auto attr =
+            getPaddingConvSizes(b, bounds, paddingTileSizes, workgroupTileSizes,
+                                reductionTileSizes, convToIgemmInfo)) {
       attrs.emplace_back(StringAttr::get(context, "padding_conv"), *attr);
     } else {
       attrs.emplace_back(StringAttr::get(context, "padding"),
@@ -715,24 +729,39 @@
     LDBG() << "Unsupported convolution type";
     return failure();
   }
+
+  ConvToIgemmInfo convToIgemmInfo;
+  if (padConv) {
+    auto inputType = llvm::cast<ShapedType>(op->getOperands()[0].getType());
+    ArrayRef<int64_t> inputShape = inputType.getShape();
+    AffineMap inputMap = linalgOp.getIndexingMapsArray()[0];
+    SmallVector<int64_t> inputChannelPos;
+    for (auto dim : igemmGenericConvDetails->convDims.inputChannel) {
+      for (auto [idx, e] : llvm::enumerate(inputMap.getResults())) {
+        if (e.isFunctionOfDim(dim)) {
+          convToIgemmInfo.inputChannelDimToSize[dim] = inputShape[idx];
+          inputChannelPos.push_back(idx);
+        }
+      }
+    }
+    llvm::sort(inputChannelPos);
+    convToIgemmInfo.isInputChannelLast =
+        inputChannelPos.back() == inputShape.size() - 1;
+    convToIgemmInfo.convDims = igemmGenericConvDetails->convDims;
+    convToIgemmInfo.convToIgemmDimMap =
+        igemmGenericConvDetails->convToIgemmDimMap;
+  }
+
   SmallVector<AffineMap> igemmContractionMaps =
       igemmGenericConvDetails->igemmContractionMaps;
   SmallVector<int64_t> igemmLoopBounds =
       igemmGenericConvDetails->igemmLoopBounds;
   SmallVector<Value> igemmOperands = igemmGenericConvDetails->igemmOperands;
-
-  std::optional<DenseMap<int64_t, AffineExpr>> convToIgemmDimMap;
-  std::optional<linalg::ConvolutionDimensions> convDims;
-  if (padConv) {
-    convDims = igemmGenericConvDetails->convDims;
-    convToIgemmDimMap = igemmGenericConvDetails->convToIgemmDimMap;
-  }
-
-  SmallVector<int64_t> bounds = igemmLoopBounds;
   FailureOr<std::pair<LoweringConfigAttr, int64_t>> configAndWgSize =
       getMatmulOrIGEMMLoweringConfigAndWorkgroupSize(
-          bounds, igemmContractionMaps, igemmOperands, target, useDirectLoad,
-          /*isGemm=*/false, /*scaled*/ false, convToIgemmDimMap, convDims);
+          igemmLoopBounds, igemmContractionMaps, igemmOperands, target,
+          useDirectLoad, /*isGemm=*/false,
+          /*scaled*/ false, convToIgemmInfo);
   if (failed(configAndWgSize)) {
     return failure();
   }
diff --git a/compiler/src/iree/compiler/Codegen/LLVMGPU/test/ROCDL/config_igemm_tile_and_fuse.mlir b/compiler/src/iree/compiler/Codegen/LLVMGPU/test/ROCDL/config_igemm_tile_and_fuse.mlir
index db1909e..1f33c99 100644
--- a/compiler/src/iree/compiler/Codegen/LLVMGPU/test/ROCDL/config_igemm_tile_and_fuse.mlir
+++ b/compiler/src/iree/compiler/Codegen/LLVMGPU/test/ROCDL/config_igemm_tile_and_fuse.mlir
@@ -195,7 +195,7 @@
 #map = affine_map<(d0, d1, d2, d3, d4, d5, d6) -> (d4, d1 + d5 * 2, d2 + d6 * 2, d3)>
 #map1 = affine_map<(d0, d1, d2, d3, d4, d5, d6) -> (d4, d5, d6, d0)>
 #map2 = affine_map<(d0, d1, d2, d3, d4, d5, d6) -> (d0, d1, d2, d3)>
-func.func @conv_chwn_chwf_unaligned(%arg0: tensor<16x193x129x40xbf16>, %arg1: tensor<16x96x64x40xbf16>, %arg2: tensor<40x3x3x40xf32>) -> tensor<40x3x3x40xf32> {
+func.func @conv_chwn_chwf_unaligned_batch(%arg0: tensor<16x193x129x40xbf16>, %arg1: tensor<16x96x64x40xbf16>, %arg2: tensor<40x3x3x40xf32>) -> tensor<40x3x3x40xf32> {
   %0 = linalg.generic {indexing_maps = [#map, #map1, #map2], iterator_types = ["parallel", "parallel", "parallel", "parallel", "reduction", "reduction", "reduction"]} ins(%arg0, %arg1 : tensor<16x193x129x40xbf16>, tensor<16x96x64x40xbf16>) outs(%arg2 : tensor<40x3x3x40xf32>) {
   ^bb0(%in: bf16, %in_0: bf16, %out: f32):
     %1 = arith.extf %in : bf16 to f32
@@ -207,7 +207,7 @@
   return %0 : tensor<40x3x3x40xf32>
 }
 
-// CHECK-LABEL: func.func @conv_chwn_chwf_unaligned
+// CHECK-LABEL: func.func @conv_chwn_chwf_unaligned_batch
 //  CHECK-SAME:   #iree_codegen.translation_info<pipeline = LLVMGPUTileAndFuse workgroup_size = [64, 1, 1] subgroup_size = 64
 //  CHECK-SAME:   #iree_gpu.pipeline_options<prefetch_shared_memory = true, no_reduce_shared_memory_bank_conflicts = false
 //  CHECK-SAME:   use_igemm_convolution = true
@@ -227,7 +227,7 @@
 #map = affine_map<(d0, d1, d2, d3, d4, d5, d6) -> (d0 + d4, d1 + d5, d2, d6)>
 #map1 = affine_map<(d0, d1, d2, d3, d4, d5, d6) -> (d2, d3, d4, d5, d6)>
 #map2 = affine_map<(d0, d1, d2, d3, d4, d5, d6) -> (d0, d1, d2, d3)>
-func.func @group_conv_unaligned(%arg0: tensor<61x93x16x56xbf16>, %arg1: tensor<16x56x3x3x56xbf16>, %arg2: tensor<59x91x16x56xf32>) -> tensor<59x91x16x56xf32> {
+func.func @group_conv_hwgc_gfhwc_unaligned(%arg0: tensor<61x93x16x56xbf16>, %arg1: tensor<16x56x3x3x56xbf16>, %arg2: tensor<59x91x16x56xf32>) -> tensor<59x91x16x56xf32> {
   %0 = linalg.generic {indexing_maps = [#map, #map1, #map2], iterator_types = ["parallel", "parallel", "parallel", "parallel", "reduction", "reduction", "reduction"]} ins(%arg0, %arg1 : tensor<61x93x16x56xbf16>, tensor<16x56x3x3x56xbf16>) outs(%arg2 : tensor<59x91x16x56xf32>) {
     ^bb0(%in: bf16, %in_4: bf16, %out: f32):
       %10 = arith.extf %in : bf16 to f32
@@ -239,7 +239,7 @@
   return %0 : tensor<59x91x16x56xf32>
 }
 
-// CHECK-LABEL: func.func @group_conv_unaligned
+// CHECK-LABEL: func.func @group_conv_hwgc_gfhwc_unaligned
 //  CHECK-SAME:   #iree_codegen.translation_info<pipeline = LLVMGPUTileAndFuse workgroup_size = [512, 1, 1] subgroup_size = 64
 //  CHECK-SAME:   #iree_gpu.pipeline_options<prefetch_shared_memory = true, no_reduce_shared_memory_bank_conflicts = false
 //  CHECK-SAME:   use_igemm_convolution = true
@@ -299,3 +299,45 @@
 // MI300X-SAME:     workgroup = [1, 2, 32, 32, 0]
 
 // PAD-CONV-GFX942:     padding_conv = [2, 2, 32, 64, 0, 0]
+
+// -----
+
+#map = affine_map<(d0, d1, d2, d3, d4, d5, d6) -> (d4, d1 + d5 * 2, d2 + d6 * 2, d3)>
+#map1 = affine_map<(d0, d1, d2, d3, d4, d5, d6) -> (d4, d5, d6, d0)>
+#map2 = affine_map<(d0, d1, d2, d3, d4, d5, d6) -> (d0, d1, d2, d3)>
+func.func @conv_chwn_chwf_no_pad_conv(%arg0: tensor<2x192x128x40xbf16>, %arg1: tensor<2x95x63x40xbf16>, %arg2: tensor<40x3x3x40xf32>) -> tensor<40x3x3x40xf32> {
+  %0 = linalg.generic {indexing_maps = [#map, #map1, #map2], iterator_types = ["parallel", "parallel", "parallel", "parallel", "reduction", "reduction", "reduction"]} ins(%arg0, %arg1 : tensor<2x192x128x40xbf16>, tensor<2x95x63x40xbf16>) outs(%arg2 : tensor<40x3x3x40xf32>) {
+  ^bb0(%in: bf16, %in_0: bf16, %out: f32):
+    %1 = arith.extf %in : bf16 to f32
+    %2 = arith.extf %in_0 : bf16 to f32
+    %3 = arith.mulf %1, %2 : f32
+    %4 = arith.addf %out, %3 : f32
+    linalg.yield %4 : f32
+  } -> tensor<40x3x3x40xf32>
+  return %0 : tensor<40x3x3x40xf32>
+}
+
+//         CHECK-LABEL:  func.func @conv_chwn_chwf_no_pad_conv
+//     PAD-CONV-GFX942:     padding = [16, 1, 1, 16, 16]
+// PAD-CONV-GFX942-NOT:     padding_conv
+
+// -----
+
+#map = affine_map<(d0, d1, d2, d3, d4, d5, d6) -> (d0, d1 + d4, d2 + d5, d6)>
+#map1 = affine_map<(d0, d1, d2, d3, d4, d5, d6) -> (d3, d4, d5, d6)>
+#map2 = affine_map<(d0, d1, d2, d3, d4, d5, d6) -> (d0, d1, d2, d3)>
+func.func @conv_nhwc_small_channel_no_pad_conv(%arg0: tensor<16x26x19x3xf16>, %arg1: tensor<287x3x3x3xf16>, %arg2: tensor<16x24x17x287xf32>) -> tensor<16x24x17x287xf32> {
+  %0 = linalg.generic {indexing_maps = [#map, #map1, #map2], iterator_types = ["parallel", "parallel", "parallel", "parallel", "reduction", "reduction", "reduction"]} ins(%arg0, %arg1 : tensor<16x26x19x3xf16>, tensor<287x3x3x3xf16>) outs(%arg2 : tensor<16x24x17x287xf32>) {
+  ^bb0(%in: f16, %in_0: f16, %out: f32):
+    %1 = arith.extf %in : f16 to f32
+    %2 = arith.extf %in_0 : f16 to f32
+    %3 = arith.mulf %1, %2 : f32
+    %4 = arith.addf %out, %3 : f32
+    linalg.yield %4 : f32
+  } -> tensor<16x24x17x287xf32>
+  return %0 : tensor<16x24x17x287xf32>
+}
+
+//         CHECK-LABEL:  func.func @conv_nhwc_small_channel_no_pad_conv
+//     PAD-CONV-GFX942:     padding = [1, 4, 32, 32, 32]
+// PAD-CONV-GFX942-NOT:     padding_conv