[gpu] Use clustered gpu.subgroup_reduce for nested layout distribution (#18515)

MLIR now supports expressing a subgroup reduction that operates on
several "clusters" of lanes in parallel, so it is no longer necessary
to build the series of shuffles by hand.
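
As a sketch (mirroring the updated test expectations in this diff), a
reduction over clusters of 4 lanes spaced 16 apart is now a single op;
%v and %r are placeholder SSA names:

  %r = gpu.subgroup_reduce maximumf %v cluster(size = 4, stride = 16) : (f32) -> f32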

It has been verified that, at least when the upstream lowering patterns
are used, the resulting sequence of shuffles is identical to what the
old code produced.
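
For instance, lowering the cluster(size = 4, stride = 16) reduction
above should produce the same two xor shuffles the old code emitted
(roughly; %c16, %c32, and %c64 are i32 constants for the shuffle masks
and the subgroup width):

  %s0, %p0 = gpu.shuffle xor %v, %c16, %c64 : f32
  %m0 = arith.maximumf %v, %s0 : f32
  %s1, %p1 = gpu.shuffle xor %m0, %c32, %c64 : f32
  %r = arith.maximumf %m0, %s1 : f32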

This commit also adds a new pass, ExpandGPUOps, which uses the upstream
patterns to expand these clustered reduction ops, and adds it to the
LLVMGPU pass pipeline.
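
Assuming the usual pass registration, the new pass should also be
runnable in isolation with something like:

  iree-opt --pass-pipeline='builtin.module(func.func(iree-codegen-expand-gpu-ops))' input.mlir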

Resolves #18142.

Signed-off-by: Andrea Faulds <andrea.faulds@amd.com>
diff --git a/compiler/src/iree/compiler/Codegen/Common/GPU/BUILD.bazel b/compiler/src/iree/compiler/Codegen/Common/GPU/BUILD.bazel
index 296b316..00b304e 100644
--- a/compiler/src/iree/compiler/Codegen/Common/GPU/BUILD.bazel
+++ b/compiler/src/iree/compiler/Codegen/Common/GPU/BUILD.bazel
@@ -50,6 +50,7 @@
     name = "CommonGPUPasses",
     srcs = [
         "AMDGPUDistributeContract.cpp",
+        "ExpandGPUOps.cpp",
         "GPUApplyTilingLevel.cpp",
         "GPUCheckResourceUsage.cpp",
         "GPUCombineValueBarriers.cpp",
diff --git a/compiler/src/iree/compiler/Codegen/Common/GPU/CMakeLists.txt b/compiler/src/iree/compiler/Codegen/Common/GPU/CMakeLists.txt
index e078969..e93f404 100644
--- a/compiler/src/iree/compiler/Codegen/Common/GPU/CMakeLists.txt
+++ b/compiler/src/iree/compiler/Codegen/Common/GPU/CMakeLists.txt
@@ -49,6 +49,7 @@
     "Passes.h"
   SRCS
     "AMDGPUDistributeContract.cpp"
+    "ExpandGPUOps.cpp"
     "GPUApplyTilingLevel.cpp"
     "GPUCheckResourceUsage.cpp"
     "GPUCombineValueBarriers.cpp"
diff --git a/compiler/src/iree/compiler/Codegen/Common/GPU/ExpandGPUOps.cpp b/compiler/src/iree/compiler/Codegen/Common/GPU/ExpandGPUOps.cpp
new file mode 100644
index 0000000..cd657e9
--- /dev/null
+++ b/compiler/src/iree/compiler/Codegen/Common/GPU/ExpandGPUOps.cpp
@@ -0,0 +1,50 @@
+// Copyright 2024 The IREE Authors
+//
+// Licensed under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+
+#include "iree/compiler/Codegen/Common/GPU/Passes.h"
+#include "iree/compiler/Codegen/Dialect/Codegen/IR/IREECodegenAttrs.h"
+#include "iree/compiler/Codegen/Utils/GPUUtils.h"
+#include "mlir/Dialect/GPU/IR/GPUDialect.h"
+#include "mlir/Dialect/GPU/Transforms/Passes.h"
+#include "mlir/Interfaces/FunctionInterfaces.h"
+#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
+
+#define DEBUG_TYPE "iree-codegen-expand-gpu-ops"
+
+namespace mlir::iree_compiler {
+
+#define GEN_PASS_DEF_EXPANDGPUOPSPASS
+#include "iree/compiler/Codegen/Common/GPU/Passes.h.inc"
+
+namespace {
+
+struct ExpandGPUOpsPass final : impl::ExpandGPUOpsPassBase<ExpandGPUOpsPass> {
+  void runOnOperation() override {
+    FunctionOpInterface funcOp = getOperation();
+    MLIRContext *ctx = &getContext();
+
+    std::optional<int> subgroupSize = getGPUSubgroupSize(funcOp);
+    if (!subgroupSize) {
+      funcOp->emitOpError("missing subgroup size");
+      return signalPassFailure();
+    }
+
+    RewritePatternSet patterns(ctx);
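+    // Break-down runs at a higher benefit so that reductions wider than the
+    // shuffle bitwidth are split before being lowered to shuffles.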
+    populateGpuBreakDownSubgroupReducePatterns(
+        patterns, /*maxShuffleBitwidth=*/32, PatternBenefit(2));
+    populateGpuLowerClusteredSubgroupReduceToShufflePatterns(
+        patterns, *subgroupSize, /*shuffleBitwidth=*/32, PatternBenefit(1));
+    if (failed(applyPatternsAndFoldGreedily(funcOp, std::move(patterns)))) {
+      return signalPassFailure();
+    }
+  }
+};
+
+} // namespace
+
+} // namespace mlir::iree_compiler
diff --git a/compiler/src/iree/compiler/Codegen/Common/GPU/GPUNestedLayoutDistributionPatterns.cpp b/compiler/src/iree/compiler/Codegen/Common/GPU/GPUNestedLayoutDistributionPatterns.cpp
index ec86bf2..e36ad99 100644
--- a/compiler/src/iree/compiler/Codegen/Common/GPU/GPUNestedLayoutDistributionPatterns.cpp
+++ b/compiler/src/iree/compiler/Codegen/Common/GPU/GPUNestedLayoutDistributionPatterns.cpp
@@ -17,6 +17,7 @@
 #include "mlir/Dialect/Affine/IR/AffineOps.h"
 #include "mlir/Dialect/Affine/Utils.h"
 #include "mlir/Dialect/Arith/IR/Arith.h"
+#include "mlir/Dialect/GPU/Transforms/Utils.h"
 #include "mlir/Dialect/Utils/IndexingUtils.h"
 #include "mlir/Dialect/Vector/IR/VectorOps.h"
 #include "mlir/IR/AffineExpr.h"
@@ -512,10 +513,19 @@
       Value extracted = rewriter.create<vector::ExtractOp>(loc, flat, i);
 
       // Reduce across all reduction dimensions 1-by-1.
-      for (unsigned i = 0; i < reductionMask.size(); ++i) {
+      for (unsigned i = 0, e = reductionMask.size(); i != e; ++i) {
         if (reductionMask[i]) {
-          extracted = doPackedThreadReductionOnDim(rewriter, layout, extracted,
-                                                   kind, i);
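+          // getShuffleWidth gives the number of lanes in each cluster;
+          // getShuffleOffset gives the stride between lanes of a cluster.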
+          int64_t offset = getShuffleOffset(layout, i);
+          int64_t width = getShuffleWidth(layout, i);
+          assert(offset <= std::numeric_limits<uint32_t>::max() &&
+                 width <= std::numeric_limits<uint32_t>::max());
+
+          extracted = rewriter.create<gpu::SubgroupReduceOp>(
+              loc, extracted, combiningKindToAllReduce(kind),
+              /*uniform=*/false, /*cluster_size=*/width,
+              /*cluster_stride=*/offset);
         }
       }
 
@@ -525,25 +533,6 @@
     return res;
   }
 
-  Value doPackedThreadReductionOnDim(RewriterBase &rewriter,
-                                     NestedLayoutAttr layout, Value val,
-                                     vector::CombiningKind kind,
-                                     int64_t dim) const {
-    Location loc = val.getLoc();
-    int64_t offset = getShuffleOffset(layout, dim);
-    int64_t width = getShuffleWidth(layout, dim);
-
-    for (int i = offset; i < offset * width; i <<= 1) {
-      auto shuffleOp = rewriter.create<gpu::ShuffleOp>(
-          loc, val, i, subgroupSize, gpu::ShuffleMode::XOR);
-      val =
-          makeArithReduction(rewriter, loc, kind, shuffleOp.getShuffleResult(),
-                             val, nullptr, nullptr);
-    }
-
-    return val;
-  }
-
   int64_t subgroupSize;
   int64_t maxBitsPerShuffle;
 };
diff --git a/compiler/src/iree/compiler/Codegen/Common/GPU/Passes.td b/compiler/src/iree/compiler/Codegen/Common/GPU/Passes.td
index 209d3e2..1b5136d 100644
--- a/compiler/src/iree/compiler/Codegen/Common/GPU/Passes.td
+++ b/compiler/src/iree/compiler/Codegen/Common/GPU/Passes.td
@@ -266,4 +266,12 @@
   ];
 }
 
+def ExpandGPUOpsPass :
+    InterfacePass<"iree-codegen-expand-gpu-ops", "mlir::FunctionOpInterface"> {
+  let summary = "Expands high-level GPU ops, such as clustered gpu.subgroup_reduce.";
+  let dependentDialects = [
+    "::mlir::gpu::GPUDialect"
+  ];
+}
+
 #endif // IREE_CODEGEN_COMMON_GPU_PASSES
diff --git a/compiler/src/iree/compiler/Codegen/Common/GPU/test/gpu_nested_layout_vector_distribution.mlir b/compiler/src/iree/compiler/Codegen/Common/GPU/test/gpu_nested_layout_vector_distribution.mlir
index 9876b00..3b9b016 100644
--- a/compiler/src/iree/compiler/Codegen/Common/GPU/test/gpu_nested_layout_vector_distribution.mlir
+++ b/compiler/src/iree/compiler/Codegen/Common/GPU/test/gpu_nested_layout_vector_distribution.mlir
@@ -963,19 +963,13 @@
 }
 
 // CHECK-LABEL: func @mfma_16x16x16_out_reduced_dim1
-// CHECK-DAG: %[[C16:.*]] = arith.constant 16 : i32
-// CHECK-DAG: %[[C32:.*]] = arith.constant 32 : i32
-// CHECK-DAG: %[[C64:.*]] = arith.constant 64 : i32
 // CHECK-DAG: %[[IDENTITY:.*]] = arith.constant dense<0xFF800000> : vector<2x1x1xf32>
 // CHECK-DAG: %[[DARG0:.*]] = iree_vector_ext.to_simt %{{.*}} : vector<32x32xf32> -> vector<2x2x1x1x1x4xf32>
 // CHECK-DAG: %[[DARG1:.*]] = iree_vector_ext.to_simt %{{.*}} : vector<32xf32> -> vector<2x1x1xf32>
 // Local reduction
 // CHECK: vector.multi_reduction <maximumf>, %[[DARG0]], %[[IDENTITY]] [1, 3, 5] : vector<2x2x1x1x1x4xf32> to vector<2x1x1xf32>
 // Global reduction
-// CHECK: gpu.shuffle  xor %{{.*}}, %[[C16]], %[[C64]] : f32
-// CHECK: gpu.shuffle  xor %{{.*}}, %[[C32]], %[[C64]] : f32
-// CHECK: gpu.shuffle  xor %{{.*}}, %[[C16]], %[[C64]] : f32
-// CHECK: gpu.shuffle  xor %{{.*}}, %[[C32]], %[[C64]] : f32
+// CHECK: gpu.subgroup_reduce maximumf %{{.*}} cluster(size = 4, stride = 16) : (f32) -> f32
 // Accumulator reduction
 // CHECK: %[[ACC_REDUC:.+]] = arith.maximumf %{{.*}}, %[[DARG1]] : vector<2x1x1xf32>
 // CHECK: iree_vector_ext.to_simd %[[ACC_REDUC]] : vector<2x1x1xf32> -> vector<32xf32>
@@ -1012,11 +1006,9 @@
 }
 
 // CHECK-LABEL: func @mfma_32x32x8_out_reduced_dim1
-// CHECK-DAG: %[[C32:.*]] = arith.constant 32 : i32
-// CHECK-DAG: %[[C64:.*]] = arith.constant 64 : i32
 // Local reduction
 // CHECK: vector.multi_reduction <maximumf>, %{{.*}}, %{{.*}} [1, 3, 5] : vector<1x4x1x1x1x4xf32> to vector<1x1x1xf32>
 // Global reduction
-// CHECK: gpu.shuffle  xor %{{.*}}, %[[C32]], %[[C64]] : f32
+// CHECK: gpu.subgroup_reduce maximumf %{{.*}} cluster(size = 2, stride = 32) : (f32) -> f32
 // Accumulator reduction
 // CHECK: arith.maximumf %{{.*}}, %{{.*}} : vector<1x1x1xf32>
diff --git a/compiler/src/iree/compiler/Codegen/LLVMGPU/Passes.cpp b/compiler/src/iree/compiler/Codegen/LLVMGPU/Passes.cpp
index 15518e9..ec6e710 100644
--- a/compiler/src/iree/compiler/Codegen/LLVMGPU/Passes.cpp
+++ b/compiler/src/iree/compiler/Codegen/LLVMGPU/Passes.cpp
@@ -1061,7 +1061,8 @@
 
   FunctionLikeNest funcPassManager(modulePassManager);
   funcPassManager.addPass(createFoldTensorExtractOpPass)
-      .addPass(createLLVMGPUVectorLoweringPass);
+      .addPass(createLLVMGPUVectorLoweringPass)
+      .addPass(createExpandGPUOpsPass);
 
   // This pass needs to run before SCF -> CF.
   addLowerAndOptimizeAddressComputationPasses(funcPassManager);
diff --git a/compiler/src/iree/compiler/Codegen/Utils/GPUUtils.cpp b/compiler/src/iree/compiler/Codegen/Utils/GPUUtils.cpp
index 041784a..5eb4519 100644
--- a/compiler/src/iree/compiler/Codegen/Utils/GPUUtils.cpp
+++ b/compiler/src/iree/compiler/Codegen/Utils/GPUUtils.cpp
@@ -535,9 +535,8 @@
   return identity;
 }
 
-/// Return a matching GPU reduction operations.
-static gpu::AllReduceOperation
-combiningKindToAllReduce(vector::CombiningKind kind) {
+/// Returns the matching GPU reduction operation.
+gpu::AllReduceOperation combiningKindToAllReduce(vector::CombiningKind kind) {
   switch (kind) {
 #define MAP_CASE(X)                                                            \
   case vector::CombiningKind::X:                                               \
diff --git a/compiler/src/iree/compiler/Codegen/Utils/GPUUtils.h b/compiler/src/iree/compiler/Codegen/Utils/GPUUtils.h
index e089b00..cdbc297 100644
--- a/compiler/src/iree/compiler/Codegen/Utils/GPUUtils.h
+++ b/compiler/src/iree/compiler/Codegen/Utils/GPUUtils.h
@@ -9,6 +9,7 @@
 
 #include "iree/compiler/Codegen/Dialect/GPU/IR/IREEGPUAttrs.h"
 #include "iree/compiler/Dialect/HAL/IR/HALTypes.h"
+#include "mlir/Dialect/GPU/IR/GPUDialect.h"
 #include "mlir/Dialect/Linalg/Utils/Utils.h"
 #include "mlir/Dialect/MemRef/IR/MemRef.h"
 #include "mlir/Dialect/SCF/Transforms/TileUsingInterface.h"
@@ -156,6 +157,11 @@
 /// Emit identity constant based on combiningKind and type.
 Value getCombiningIdentityValue(Location loc, OpBuilder &builder,
                                 vector::CombiningKind kind, Type identityType);
+
+/// Returns the matching GPU reduction operation.
+mlir::gpu::AllReduceOperation
+combiningKindToAllReduce(vector::CombiningKind kind);
+
 //===----------------------------------------------------------------------===//
 // GPU CodeGen op filter
 //===----------------------------------------------------------------------===//