Reapply "[Codegen] Enable DMA by default for F16/BF16 Gemm on gfx950 (#24117)" (#24235) (#24373)

This reverts commit 75ffbc37144de79cc9428f97827251b2242b230f.

The previously reported numerical issues have now been resolved through
the following changes:
- https://github.com/iree-org/iree/pull/24241
- https://github.com/iree-org/iree/pull/24242
- https://github.com/iree-org/iree/pull/24254

---------

Signed-off-by: Yu-Zhewen <zhewenyu@amd.com>
diff --git a/compiler/src/iree/compiler/Codegen/Common/GPU/GPUConvertToCoalescedDMA.cpp b/compiler/src/iree/compiler/Codegen/Common/GPU/GPUConvertToCoalescedDMA.cpp
index 888f295..08c4d77 100644
--- a/compiler/src/iree/compiler/Codegen/Common/GPU/GPUConvertToCoalescedDMA.cpp
+++ b/compiler/src/iree/compiler/Codegen/Common/GPU/GPUConvertToCoalescedDMA.cpp
@@ -173,21 +173,6 @@
   return hasAMDGPUFatRawBufferAddressSpace(memrefType);
 }
 
-/// Check if the target architecture supports global load DMA.
-/// Returns true only for CDNA4+ (gfx950+) architectures.
-static bool targetSupportsGlobalLoadDMA(IREE::GPU::TargetAttr target) {
-  if (!target) {
-    return false;
-  }
-  FailureOr<amdgpu::Chipset> chipset = amdgpu::Chipset::parse(target.getArch());
-  if (failed(chipset)) {
-    return false;
-  }
-  // CDNA4 is gfx950+ (major=9, minor>=5). Other major versions (RDNA, etc.)
-  // do not support global load DMA.
-  return chipset->majorVersion == 9 && chipset->minorVersion >= 5;
-}
-
 /// Returns the subgroup size if the available elements are aligned to DMA
 /// transfer sizes, std::nullopt otherwise.
 static std::optional<int64_t>
diff --git a/compiler/src/iree/compiler/Codegen/Dialect/GPU/TargetUtils/ConfigUtils.cpp b/compiler/src/iree/compiler/Codegen/Dialect/GPU/TargetUtils/ConfigUtils.cpp
index 9ab5d4e..626c210 100644
--- a/compiler/src/iree/compiler/Codegen/Dialect/GPU/TargetUtils/ConfigUtils.cpp
+++ b/compiler/src/iree/compiler/Codegen/Dialect/GPU/TargetUtils/ConfigUtils.cpp
@@ -557,6 +557,42 @@
   return splitReductionTripCnt;
 }
 
+/// Returns true if direct load DMA should be rejected; the caller then falls
+/// back to stream copies.
+///
+/// Rejection cases:
+///   1. Target does not support DMA (requires gfx950+ / CDNA4+).
+///   2. Not a GEMM. TODO(#23907): support convolution.
+///   3. Data types are not f16 or bf16. TODO(#22119): support MXFP4.
+///   4. LHS transposed with RHS not transposed regresses. TODO(#24117).
+static bool shouldRejectDirectLoadDMA(IREE::GPU::TargetAttr target, bool isGemm,
+                                      Type lhsElemType, Type rhsElemType,
+                                      bool transposedLhs, bool transposedRhs) {
+  auto isF16OrBF16 = [](Type t) { return t.isF16() || t.isBF16(); };
+
+  // Case 1: DMA requires hardware support (gfx950+ / CDNA4+).
+  if (!targetSupportsGlobalLoadDMA(target)) {
+    return true;
+  }
+
+  // Case 2: Only GEMMs are supported currently.
+  if (!isGemm) {
+    return true;
+  }
+
+  // Case 3: Only f16/bf16 are supported currently; both operands must match.
+  if (!isF16OrBF16(lhsElemType) || !isF16OrBF16(rhsElemType)) {
+    return true;
+  }
+
+  // Case 4: LHS transposed with RHS not transposed regresses with DMA.
+  if (transposedLhs && !transposedRhs) {
+    return true;
+  }
+
+  return false;
+}
+
 /// Create a lowering config for matmul or IGEMM convolution based on iteration
 /// bounds and indexing maps for a given target. This function computes
 /// contraction dimensions and deduces an MMA intrinsic schedule to choose tile
@@ -572,7 +608,7 @@
 getMatmulOrIGEMMLoweringConfigAndWorkgroupSize(
     ArrayRef<int64_t> bounds, ArrayRef<AffineMap> maps,
     ArrayRef<Value> operands, IREE::GPU::TargetAttr target, bool isGemm,
-    bool scaled, bool useDirectLoad, int64_t prefetchNumStages,
+    bool scaled, bool &useDirectLoad, int64_t prefetchNumStages,
     int64_t splitReductionTripCnt, bool hasExistingAccumulator = false,
     std::optional<ConvToIgemmInfo> convToIgemmInfo = std::nullopt) {
   if (target.getWgp().getMma().empty()) {
@@ -750,13 +786,11 @@
                              lhsScaleType,
                              rhsScaleType};
 
-  // TODO(#22119): We don't use global load DMA for scaled matmuls, because
-  // compilation doesn't support it. Once this is fixed, we should use global
-  // load DMA here when possible.
   Location loc = operands[0].getLoc();
-  if (scaled && useDirectLoad) {
-    mlir::emitWarning(loc) << "direct load (global load DMA) is not yet "
-                              "supported for scaled matmuls, ignoring";
+  if (useDirectLoad &&
+      shouldRejectDirectLoadDMA(target, isGemm, lhsElemType, rhsElemType,
+                                transposedLhs, transposedRhs)) {
+    LDBG() << "overriding direct load DMA, falling back to stream copies";
     useDirectLoad = false;
   }
 
@@ -882,7 +916,7 @@
-    // Apply XOR swizzle for BF16 DMA operands whose reduction dim is
+    // Apply XOR swizzle for F16/BF16 DMA operands whose reduction dim is
     // innermost (contiguous reads) to avoid LDS bank conflicts.
     // TODO(#24255): Fix untuned swizzle logic for DMA.
-    if (lhsElemType.isBF16() && !transposedLhs) {
+    if (!transposedLhs) {
       FailureOr<Attribute> lhsSwizzleAttr = getXorShuffleAttr(
           context, lhsAttr, target, kind, schedule->kTileSizes, kMMAOperandLhs,
           /*skipUntunedFallback=*/true);
@@ -890,7 +924,7 @@
         lhsAttr = *lhsSwizzleAttr;
       }
     }
-    if (rhsElemType.isBF16() && transposedRhs) {
+    if (transposedRhs) {
       FailureOr<Attribute> rhsSwizzleAttr = getXorShuffleAttr(
           context, rhsAttr, target, kind, schedule->kTileSizes, kMMAOperandRhs,
           /*skipUntunedFallback=*/true);
diff --git a/compiler/src/iree/compiler/Codegen/LLVMGPU/KernelConfig.cpp b/compiler/src/iree/compiler/Codegen/LLVMGPU/KernelConfig.cpp
index 8a5f83d..1f28eff 100644
--- a/compiler/src/iree/compiler/Codegen/LLVMGPU/KernelConfig.cpp
+++ b/compiler/src/iree/compiler/Codegen/LLVMGPU/KernelConfig.cpp
@@ -106,7 +106,7 @@
 static llvm::cl::opt<bool>
     clUseDirectLoad("iree-llvmgpu-use-direct-load",
                     llvm::cl::desc("Use global load DMA for direct load ops."),
-                    llvm::cl::Hidden, llvm::cl::init(false));
+                    llvm::cl::Hidden, llvm::cl::init(true));
 
 static llvm::cl::opt<bool> clDirectConvolution(
     "iree-codegen-llvmgpu-use-direct-convolution",
diff --git a/compiler/src/iree/compiler/Codegen/LLVMGPU/test/ROCDL/config_tile_and_fuse_gfx950.mlir b/compiler/src/iree/compiler/Codegen/LLVMGPU/test/ROCDL/config_tile_and_fuse_gfx950.mlir
index dc91220..2673f4c 100644
--- a/compiler/src/iree/compiler/Codegen/LLVMGPU/test/ROCDL/config_tile_and_fuse_gfx950.mlir
+++ b/compiler/src/iree/compiler/Codegen/LLVMGPU/test/ROCDL/config_tile_and_fuse_gfx950.mlir
@@ -1,11 +1,11 @@
 // RUN: iree-opt --mlir-print-local-scope --split-input-file --iree-gpu-test-target=gfx950 \
 // RUN: --iree-codegen-llvmgpu-use-tile-and-fuse-matmul=true --iree-codegen-llvmgpu-test-tile-and-fuse-vectorize=true \
-// RUN: --iree-codegen-llvmgpu-use-igemm=false \
+// RUN: --iree-codegen-llvmgpu-use-igemm=false --iree-llvmgpu-use-direct-load=false \
 // RUN: --pass-pipeline="builtin.module(iree-llvmgpu-select-lowering-strategy)" %s | FileCheck %s
 
 // RUN: iree-opt --mlir-print-local-scope --split-input-file --iree-gpu-test-target=gfx950 \
 // RUN: --iree-codegen-llvmgpu-use-tile-and-fuse-matmul=true --iree-codegen-llvmgpu-test-tile-and-fuse-vectorize=true \
-// RUN: --iree-codegen-llvmgpu-use-igemm=false \
+// RUN: --iree-codegen-llvmgpu-use-igemm=false --iree-llvmgpu-use-direct-load=false \
 // RUN: --pass-pipeline="builtin.module(iree-llvmgpu-select-lowering-strategy)" \
 // RUN: --remarks-filter=".*" %s 2>&1 | FileCheck %s --check-prefix=CHECK-REMARKS
 
@@ -36,11 +36,6 @@
 // RUN: --iree-codegen-llvmgpu-use-igemm=true --iree-codegen-llvmgpu-test-tile-and-fuse-vectorize=true \
 // RUN: --pass-pipeline="builtin.module(iree-llvmgpu-select-lowering-strategy)" %s | FileCheck %s --check-prefix=IGEMM
 
-// RUN: iree-opt --mlir-print-local-scope --split-input-file --iree-gpu-test-target=gfx950 \
-// RUN: --iree-codegen-llvmgpu-use-igemm=true --iree-codegen-llvmgpu-test-tile-and-fuse-vectorize=true \
-// RUN: --iree-llvmgpu-use-direct-load=true \
-// RUN: --pass-pipeline="builtin.module(iree-llvmgpu-select-lowering-strategy)" %s \
-// RUN: | FileCheck %s --check-prefix=IGEMM-DIRECT-LOAD
 
 #lhs_map = affine_map<(M, N, Ko, Kb) -> (M, Ko, Kb)>
 #rhs_map = affine_map<(M, N, Ko, Kb) -> (N, Ko, Kb)>
@@ -598,21 +593,32 @@
 
 // -----
 
-// BF16 1x1 conv with DMA. The MMA intrinsic (MFMA_F32_32x32x8_BF16) is not in
-// the tuned swizzle table, so no XOR swizzle should be applied -- only plain
-// use_global_load_dma.
-func.func @conv_bf16_no_untuned_swizzle(
+// BF16 1x1 conv (preprocessed to fold unit spatial dims) with DMA. The MMA intrinsic
+// (MFMA_F32_32x32x8_BF16) is not in the tuned swizzle table, so no XOR
+// swizzle should be applied -- only plain use_global_load_dma.
+func.func @conv_1x1_bf16_no_untuned_swizzle(
     %arg0: tensor<16x96x64x40xbf16>,
-    %arg1: tensor<40x1x1x40xbf16>) -> tensor<16x96x64x40xf32> {
+    %arg1: tensor<40x40xbf16>) -> tensor<16x96x64x40xf32> {
   %cst = arith.constant 0.000000e+00 : f32
   %empty = tensor.empty() : tensor<16x96x64x40xf32>
   %fill = linalg.fill ins(%cst : f32) outs(%empty : tensor<16x96x64x40xf32>) -> tensor<16x96x64x40xf32>
-  %result = linalg.conv_2d_nhwc_fhwc {dilations = dense<1> : tensor<2xi64>, strides = dense<1> : tensor<2xi64>}
-    ins(%arg0, %arg1 : tensor<16x96x64x40xbf16>, tensor<40x1x1x40xbf16>)
-    outs(%fill : tensor<16x96x64x40xf32>) -> tensor<16x96x64x40xf32>
+  %result = linalg.generic {
+      indexing_maps = [affine_map<(d0, d1, d2, d3, d4) -> (d0, d1, d2, d4)>,
+                       affine_map<(d0, d1, d2, d3, d4) -> (d3, d4)>,
+                       affine_map<(d0, d1, d2, d3, d4) -> (d0, d1, d2, d3)>],
+      iterator_types = ["parallel", "parallel", "parallel", "parallel", "reduction"]
+  } ins(%arg0, %arg1 : tensor<16x96x64x40xbf16>, tensor<40x40xbf16>)
+    outs(%fill : tensor<16x96x64x40xf32>) {
+  ^bb0(%in: bf16, %in_1: bf16, %out: f32):
+    %0 = arith.extf %in : bf16 to f32
+    %1 = arith.extf %in_1 : bf16 to f32
+    %2 = arith.mulf %0, %1 : f32
+    %3 = arith.addf %out, %2 : f32
+    linalg.yield %3 : f32
+  } -> tensor<16x96x64x40xf32>
   return %result : tensor<16x96x64x40xf32>
 }
 
-// IGEMM-DIRECT-LOAD-LABEL: func.func @conv_bf16_no_untuned_swizzle
-// IGEMM-DIRECT-LOAD:       linalg.conv_2d_nhwc_fhwc {
-// IGEMM-DIRECT-LOAD-SAME:    promotion_types = [#iree_gpu.use_global_load_dma, #iree_gpu.use_global_load_dma]
+// CHECK-DIRECT-LOAD-LABEL: func.func @conv_1x1_bf16_no_untuned_swizzle
+// CHECK-DIRECT-LOAD:       linalg.generic
+// CHECK-DIRECT-LOAD-SAME:    promotion_types = [#iree_gpu.use_global_load_dma, #iree_gpu.use_global_load_dma]
diff --git a/compiler/src/iree/compiler/Codegen/Utils/GPUUtils.cpp b/compiler/src/iree/compiler/Codegen/Utils/GPUUtils.cpp
index e1adfa5..dc5e343 100644
--- a/compiler/src/iree/compiler/Codegen/Utils/GPUUtils.cpp
+++ b/compiler/src/iree/compiler/Codegen/Utils/GPUUtils.cpp
@@ -756,7 +756,7 @@
                    int operandIndex) {
   SmallVector<Type> elementTypes;
   intrinsic.getElementTypes(elementTypes);
-  assert(operandIndex > 0 && "operand index must be positive");
+  assert(operandIndex >= 0 && "operand index must be non-negative");
   return elementTypes[operandIndex].getIntOrFloatBitWidth();
 }
 
@@ -819,6 +819,8 @@
   }
   if (auto mma = dyn_cast<IREE::GPU::MMAAttr>(intrinsic)) {
     switch (mma.getIntrinsic()) {
+    case IREE::GPU::MMAIntrinsic::MFMA_F32_16x16x32_F16:
+    case IREE::GPU::MMAIntrinsic::MFMA_F32_32x32x16_F16:
     case IREE::GPU::MMAIntrinsic::MFMA_F32_16x16x32_BF16:
     case IREE::GPU::MMAIntrinsic::MFMA_F32_32x32x16_BF16:
       return XorShuffleParams({/*rowElems=*/64,
@@ -1309,6 +1311,18 @@
   return getGPUTargetAttr(op->getContext(),
                           IREE::HAL::ExecutableTargetAttr::lookup(op));
 }
+
+bool targetSupportsGlobalLoadDMA(IREE::GPU::TargetAttr target) {
+  if (!target) {
+    return false; // No target attribute: report no DMA support.
+  }
+  FailureOr<amdgpu::Chipset> chipset = amdgpu::Chipset::parse(target.getArch());
+  if (failed(chipset)) {
+    return false; // Arch string is not a parsable AMDGPU chipset.
+  }
+  return chipset->majorVersion == 9 && chipset->minorVersion >= 5; // gfx950+
+}
+
 void addConfigGPUTarget(MLIRContext *context,
                         IREE::GPU::TargetAttr gpuTargetAttr,
                         SmallVectorImpl<NamedAttribute> &config) {
diff --git a/compiler/src/iree/compiler/Codegen/Utils/GPUUtils.h b/compiler/src/iree/compiler/Codegen/Utils/GPUUtils.h
index 631e020..8037d8e 100644
--- a/compiler/src/iree/compiler/Codegen/Utils/GPUUtils.h
+++ b/compiler/src/iree/compiler/Codegen/Utils/GPUUtils.h
@@ -330,6 +330,10 @@
                                        IREE::HAL::ExecutableTargetAttr attr);
 IREE::GPU::TargetAttr getGPUTargetAttr(Operation *op);
 
+/// Check if the target architecture supports global load DMA.
+/// Returns true only for CDNA4+ (gfx950+) architectures.
+bool targetSupportsGlobalLoadDMA(IREE::GPU::TargetAttr target);
+
 // Methods to retrieve information association with `configuration` field
 // of `hal.executable.target` attribute used commonly in GPU codegen pipelines.
 std::optional<int64_t> getConfigWavesPerEu(DictionaryAttr targetAttr);