[Codegen] Remove WarpReduction from ROCDL pipeline (#21795)

The LLVMGPUWarpReduction pipeline is being removed entirely. This change deletes
the setWarpReductionConfig heuristic and its call site from the ROCDL kernel
configuration, so ops that previously selected LLVMGPUWarpReduction now fall
through to the TileAndFuse configuration path, and it removes the corresponding
lit test (pipeline_warp_reduction.mlir) from the Bazel and CMake test lists.

Signed-off-by: James Newling <james.newling@gmail.com>
diff --git a/compiler/src/iree/compiler/Codegen/LLVMGPU/ROCDLKernelConfig.cpp b/compiler/src/iree/compiler/Codegen/LLVMGPU/ROCDLKernelConfig.cpp
index c93076b..32ad469 100644
--- a/compiler/src/iree/compiler/Codegen/LLVMGPU/ROCDLKernelConfig.cpp
+++ b/compiler/src/iree/compiler/Codegen/LLVMGPU/ROCDLKernelConfig.cpp
@@ -10,11 +10,9 @@
 #include "iree/compiler/Codegen/Dialect/Codegen/IR/IREECodegenAttrs.h"
 #include "iree/compiler/Codegen/Dialect/GPU/IR/IREEGPUAttrs.h"
 #include "iree/compiler/Codegen/Dialect/GPU/TargetUtils/ConfigUtils.h"
-#include "iree/compiler/Codegen/Interfaces/PartitionableLoopsInterface.h"
 #include "iree/compiler/Codegen/Utils/GPUUtils.h"
 #include "iree/compiler/Codegen/Utils/Utils.h"
 #include "iree/compiler/Dialect/HAL/IR/HALOps.h"
-#include "llvm/ADT/APInt.h"
 #include "llvm/ADT/STLExtras.h"
 #include "mlir/Analysis/SliceAnalysis.h"
 #include "mlir/Dialect/Linalg/IR/Linalg.h"
@@ -62,210 +60,6 @@
   return true;
 }
 
-static LogicalResult
-setWarpReductionConfig(IREE::GPU::TargetAttr target,
-                       mlir::FunctionOpInterface entryPoint,
-                       linalg::LinalgOp op) {
-  if (!target.supportsSubgroupShuffle())
-    return failure();
-
-  SmallVector<unsigned> parallelDims;
-  SmallVector<unsigned> reductionDims;
-  op.getParallelDims(parallelDims);
-  op.getReductionDims(reductionDims);
-
-  SmallVector<int64_t> bounds = op.getStaticLoopRanges();
-  int64_t numParallelDims = op.getNumParallelLoops();
-
-  if (reductionDims.empty())
-    return failure();
-
-  // Make sure reduction dimensions are static and innermost ones.
-  int64_t numDynamicReductionDims = 0;
-  for (unsigned dim : reductionDims) {
-    if (ShapedType::isDynamic(bounds[dim])) {
-      numDynamicReductionDims++;
-    }
-    if (dim < numParallelDims) {
-      return failure();
-    }
-  }
-
-  // Distribution of multi-dim masked writes currently aren't fully supported.
-  if (numDynamicReductionDims > 1) {
-    return failure();
-  }
-
-  if (op.getRegionOutputArgs().size() != 1)
-    return failure();
-
-  // Only support projected permutation, this could be extended to projected
-  // permutated with broadcast.
-  if (llvm::any_of(op.getDpsInputOperands(), [&](OpOperand *input) {
-        return !op.getMatchingIndexingMap(input).isProjectedPermutation();
-      }))
-    return failure();
-
-  bool foundSingleReductionOutput = false;
-  for (auto [index, initOpOperand] : llvm::enumerate(op.getDpsInitsMutable())) {
-    // Only single combiner operations are supported for now.
-    SmallVector<Operation *> combinerOps;
-    if (matchReduction(op.getRegionOutputArgs(), index, combinerOps) &&
-        combinerOps.size() == 1) {
-      if (foundSingleReductionOutput)
-        return failure();
-      foundSingleReductionOutput = true;
-      continue;
-    }
-    if (!op.getMatchingIndexingMap(&initOpOperand).isIdentity())
-      return failure();
-  }
-  if (!foundSingleReductionOutput)
-    return failure();
-
-  // Tile all the parallel dimension to 1.
-  SmallVector<unsigned> partitionedLoops =
-      cast<PartitionableLoopsInterface>(op.getOperation())
-          .getPartitionableLoops(kNumMaxParallelDims);
-  size_t numLoops = partitionedLoops.empty() ? 0 : partitionedLoops.back() + 1;
-  SmallVector<int64_t> workgroupTileSizes(numLoops, 1);
-
-  // Without any bounds on dynamic reduction dims, we need specialization to
-  // get peak performance. For now, just use the warp size.
-  if (numDynamicReductionDims) {
-    SmallVector<int64_t> reductionTileSizes(op.getNumLoops(), 0);
-    int64_t preferredSubgroupSize = target.getPreferredSubgroupSize();
-    reductionTileSizes[reductionDims[0]] = preferredSubgroupSize;
-    TileSizesListType tileSizes;
-    tileSizes.emplace_back(std::move(workgroupTileSizes)); // Workgroup level
-    tileSizes.emplace_back(std::move(reductionTileSizes)); // Reduction level
-    std::array<int64_t, 3> workgroupSize = {preferredSubgroupSize, 1, 1};
-    if (failed(setOpConfigAndEntryPointFnTranslation(
-            entryPoint, op, tileSizes, CodeGenPipeline::LLVMGPUWarpReduction,
-            workgroupSize))) {
-      return failure();
-    }
-    return success();
-  }
-
-  int64_t reductionSize = 1;
-  for (int64_t dim : reductionDims)
-    reductionSize *= bounds[dim];
-
-  int64_t subgroupSize = 0;
-  for (int s : target.getWgp().getSubgroupSizeChoices().asArrayRef()) {
-    if (reductionSize % s == 0) {
-      subgroupSize = s;
-      break;
-    }
-  }
-  if (subgroupSize == 0)
-    return failure();
-
-  const Type elementType =
-      cast<ShapedType>(op.getDpsInitOperand(0)->get().getType())
-          .getElementType();
-  if (!elementType.isIntOrFloat())
-    return failure();
-  unsigned bitWidth = elementType.getIntOrFloatBitWidth();
-  // Reduction distribution only supports 8/16/32 bit types now.
-  if (bitWidth != 32 && bitWidth != 16 && bitWidth != 8)
-    return failure();
-
-  const unsigned largestLoadSizeInBits = 128;
-  unsigned vectorSize = largestLoadSizeInBits / bitWidth;
-  while ((reductionSize / vectorSize) % subgroupSize != 0)
-    vectorSize /= 2;
-
-  // Deduce the workgroup size we should use for reduction. Currently a
-  // workgroup processes all elements in reduction dimensions. Need to make sure
-  // the workgroup size we use can divide the total reduction size, and it's
-  // also within hardware limitations.
-  const int64_t maxWorkgroupSize = 1024;
-  int64_t groupSize = reductionSize / vectorSize;
-  if (groupSize > maxWorkgroupSize) {
-    groupSize = llvm::APIntOps::GreatestCommonDivisor(
-                    {64, uint64_t(groupSize)}, {64, uint64_t(maxWorkgroupSize)})
-                    .getZExtValue();
-  }
-
-  // Then we need to strike a balance--
-  // 1) parallel dimensions are distributed to workgroups. If there are many
-  //    workgroups dispatched, we'd want to have each GPU core hosting multiple
-  //    of them for occupancy.
-  // 2) we want each thread to read quite a few 128-bit vectors for better
-  //    memory cache behavior.
-  // Both means we cannot use a too large workgroup size.
-
-  std::optional<int64_t> parallelSize = 1;
-  for (int64_t dim : parallelDims) {
-    if (ShapedType::isDynamic(bounds[dim])) {
-      parallelSize = std::nullopt;
-      break;
-    }
-    *parallelSize *= bounds[dim];
-  }
-  // Total parallel size that can fill the GPU with enough workgorups.
-  // TODO: query from the target device; roughly 2x hardware compute unit.
-  const int parallelThreshold = 256;
-  // How many 128-bit vectors each thread should at least read.
-  const int targetVectorCount = 8;
-  while (parallelSize && *parallelSize > parallelThreshold &&
-         (groupSize / 2) % subgroupSize == 0 &&
-         reductionSize / (groupSize * vectorSize) < targetVectorCount) {
-    // Use less subgroups per workgroup..
-    groupSize /= 2;
-    // in order to host more workgroups per hardware compute unit.
-    *parallelSize /= 2;
-  }
-
-  // Current warp reduction pattern is a two step butterfly warp reduce.
-  // First, do warp reductions along multiple subgroups.
-  // Second, reduce results from multiple subgroups using single warp reduce.
-  // The final warp reduce requires subgroup count <= subgroup size to work.
-  if ((groupSize / subgroupSize) > subgroupSize)
-    return failure();
-
-  // With just one subgroup per workgroup, make each subgroup do more work and
-  // process a few reductions (rows) along the last parallel dimension.
-  if (llvm::none_of(bounds, ShapedType::isDynamic) && isMatvecLike(op)) {
-    int64_t lastParallelBound = bounds[parallelDims.back()];
-    int64_t numParallelReductions = 1;
-    const int64_t maxParallelFactor = groupSize / 4;
-    for (int64_t parallelFactor = 2;
-         (parallelFactor < maxParallelFactor) &&
-         (lastParallelBound % parallelFactor == 0) &&
-         (lastParallelBound > parallelFactor);
-         parallelFactor *= 2) {
-      numParallelReductions = parallelFactor;
-    }
-    workgroupTileSizes.back() = numParallelReductions;
-  }
-
-  std::array<int64_t, 3> workgroupSize = {groupSize, 1, 1};
-  SmallVector<int64_t> reductionTileSizes(op.getNumLoops(), 0);
-  int64_t remainingGroupSize = groupSize;
-  for (int i = reductionDims.size() - 1; i >= 0; --i) {
-    int64_t dim = reductionDims[i];
-    int64_t bound = bounds[dim];
-    if (i == reductionDims.size() - 1)
-      bound /= vectorSize;
-    APInt size = llvm::APIntOps::GreatestCommonDivisor(
-        {64, uint64_t(remainingGroupSize)}, {64, uint64_t(bound)});
-    reductionTileSizes[dim] = size.getSExtValue();
-    if (i == reductionDims.size() - 1)
-      reductionTileSizes[dim] *= vectorSize;
-    remainingGroupSize /= size.getSExtValue();
-  }
-  TileSizesListType tileSizes;
-  tileSizes.emplace_back(std::move(workgroupTileSizes)); // Workgroup level
-  tileSizes.emplace_back(std::move(reductionTileSizes)); // Reduction level
-  return setOpConfigAndEntryPointFnTranslation(
-      entryPoint, op, tileSizes, CodeGenPipeline::LLVMGPUWarpReduction,
-      workgroupSize, subgroupSize);
-  return success();
-}
-
 //===----------------------------------------------------------------------===//
 // Root Configuration
 //===----------------------------------------------------------------------===//
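For reference, the core sizing arithmetic of the heuristic deleted above can be
reproduced standalone. This is a minimal sketch, not IREE code: the inputs (a
1280-element f32 reduction and {64, 32} as the subgroup-size choices) are
assumptions picked to mirror the second deleted test further down.

```cpp
// Standalone sketch of the sizing arithmetic performed by the removed
// setWarpReductionConfig. All input values are illustrative assumptions.
#include <cstdint>
#include <cstdio>
#include <numeric>
#include <vector>

int main() {
  const int64_t reductionSize = 1280;  // assumed product of reduction bounds
  const unsigned bitWidth = 32;        // assumed f32 element type
  const std::vector<int64_t> subgroupChoices = {64, 32}; // assumed choices

  // Pick the first subgroup size that evenly divides the reduction.
  int64_t subgroupSize = 0;
  for (int64_t s : subgroupChoices)
    if (reductionSize % s == 0) { subgroupSize = s; break; }
  if (subgroupSize == 0) return 1; // the heuristic bailed out here

  // Start from 128-bit loads and halve until the per-thread element count
  // divides evenly across one subgroup.
  int64_t vectorSize = 128 / bitWidth;
  while ((reductionSize / vectorSize) % subgroupSize != 0)
    vectorSize /= 2;

  // One workgroup covers the whole reduction; cap its size at the hardware
  // limit by falling back to the greatest common divisor with 1024.
  const int64_t maxWorkgroupSize = 1024;
  int64_t groupSize = reductionSize / vectorSize;
  if (groupSize > maxWorkgroupSize)
    groupSize = std::gcd(groupSize, maxWorkgroupSize);

  std::printf("subgroup=%lld vector=%lld group=%lld\n",
              (long long)subgroupSize, (long long)vectorSize,
              (long long)groupSize);
  return 0;
}
```

With these inputs it prints subgroup=64 vector=4 group=320, i.e. one 320-thread
workgroup covering a 1280-element row with 4-wide (128-bit) f32 loads.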
@@ -287,9 +81,6 @@
             target, entryPointFn, computeOp))) {
       return success();
     }
-    if (succeeded(setWarpReductionConfig(target, entryPointFn, linalgOp))) {
-      return success();
-    }
     // TODO: Add configurations for matmul here too.
     if (succeeded(IREE::GPU::setTileAndFuseLoweringConfig(target, entryPointFn,
                                                           computeOp))) {
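The occupancy trade-off described in the deleted comments (use fewer subgroups
per workgroup so more workgroups fit on each compute unit, balanced against
each thread reading at least a few 128-bit vectors) can likewise be sketched in
isolation. All inputs below are assumed values for illustration, not taken from
a real dispatch.

```cpp
// Sketch of the removed occupancy-balancing loop. Halving continues while
// (a) there are more than ~256 parallel workgroups to host, (b) the smaller
// group still splits into whole subgroups, and (c) each thread would otherwise
// read fewer than 8 of the 128-bit vectors.
#include <cstdint>
#include <cstdio>

int main() {
  int64_t parallelSize = 4096;   // assumed static parallel extent
  int64_t reductionSize = 8192;  // assumed static reduction extent
  int64_t groupSize = 1024, subgroupSize = 64, vectorSize = 4; // assumed
  const int parallelThreshold = 256; // ~2x compute units (TODO in the source)
  const int targetVectorCount = 8;   // 128-bit vectors per thread

  while (parallelSize > parallelThreshold &&
         (groupSize / 2) % subgroupSize == 0 &&
         reductionSize / (groupSize * vectorSize) < targetVectorCount) {
    groupSize /= 2;     // fewer subgroups per workgroup...
    parallelSize /= 2;  // ...to host more workgroups per compute unit.
  }
  std::printf("groupSize=%lld\n", (long long)groupSize); // prints 256 here
  return 0;
}
```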
diff --git a/compiler/src/iree/compiler/Codegen/LLVMGPU/test/ROCDL/BUILD.bazel b/compiler/src/iree/compiler/Codegen/LLVMGPU/test/ROCDL/BUILD.bazel
index 4743106..b3529db 100644
--- a/compiler/src/iree/compiler/Codegen/LLVMGPU/test/ROCDL/BUILD.bazel
+++ b/compiler/src/iree/compiler/Codegen/LLVMGPU/test/ROCDL/BUILD.bazel
@@ -39,7 +39,6 @@
             "pipeline_vector_distribute_reduction_gfx942.mlir",
             "pipeline_vector_distribute_gfx950.mlir",
             "pipeline_vector_distribute_gfx1100.mlir",
-            "pipeline_warp_reduction.mlir",
         ],
         include = ["*.mlir"],
     ),
diff --git a/compiler/src/iree/compiler/Codegen/LLVMGPU/test/ROCDL/CMakeLists.txt b/compiler/src/iree/compiler/Codegen/LLVMGPU/test/ROCDL/CMakeLists.txt
index 9e5c557..0655424 100644
--- a/compiler/src/iree/compiler/Codegen/LLVMGPU/test/ROCDL/CMakeLists.txt
+++ b/compiler/src/iree/compiler/Codegen/LLVMGPU/test/ROCDL/CMakeLists.txt
@@ -35,7 +35,6 @@
     "pipeline_vector_distribute_gfx942.mlir"
     "pipeline_vector_distribute_gfx950.mlir"
     "pipeline_vector_distribute_reduction_gfx942.mlir"
-    "pipeline_warp_reduction.mlir"
   TOOLS
     FileCheck
     iree-opt
diff --git a/compiler/src/iree/compiler/Codegen/LLVMGPU/test/ROCDL/pipeline_warp_reduction.mlir b/compiler/src/iree/compiler/Codegen/LLVMGPU/test/ROCDL/pipeline_warp_reduction.mlir
deleted file mode 100644
index 6e82966..0000000
--- a/compiler/src/iree/compiler/Codegen/LLVMGPU/test/ROCDL/pipeline_warp_reduction.mlir
+++ /dev/null
@@ -1,87 +0,0 @@
-// RUN: iree-opt --split-input-file --iree-gpu-test-target=gfx942 --pass-pipeline="builtin.module(hal.executable(hal.executable.variant(builtin.module(iree-codegen-rocdl-configuration-pipeline), iree-codegen-linalg-to-rocdl-pipeline2)))" %s | FileCheck %s
-
-#pipeline_layout = #hal.pipeline.layout<bindings = [
-  #hal.pipeline.binding<storage_buffer>,
-  #hal.pipeline.binding<storage_buffer>
-]>
-hal.executable private @warp_reduction {
-  hal.executable.variant @rocm target(<"rocm", "rocm-hsaco-fb">) {
-    hal.executable.export public @warp_reduction ordinal(0) layout(#pipeline_layout) count(%arg0: !hal.device, %arg1: index, %arg2: index) -> (index, index, index) {
-      %x, %y, %z = iree_tensor_ext.dispatch.workgroup_count_from_dag_root(%arg1, %arg2)
-      hal.return %x, %y, %z : index, index, index
-    }
-    builtin.module {
-      func.func @warp_reduction() {
-        %c0 = arith.constant 0 : index
-        %cst = arith.constant 0.000000e+00 : f32
-        %0 = hal.interface.binding.subspan layout(#pipeline_layout) binding(0) alignment(64) offset(%c0) : !iree_tensor_ext.dispatch.tensor<readonly:tensor<2x512xf32>>
-        %1 = hal.interface.binding.subspan layout(#pipeline_layout) binding(1) alignment(64) offset(%c0) : !iree_tensor_ext.dispatch.tensor<writeonly:tensor<2xf32>>
-        %2 = iree_tensor_ext.dispatch.tensor.load %0, offsets = [0, 0], sizes = [2, 512], strides = [1, 1] : !iree_tensor_ext.dispatch.tensor<readonly:tensor<2x512xf32>> -> tensor<2x512xf32>
-        %3 = tensor.empty() : tensor<2xf32>
-        %4 = linalg.fill ins(%cst : f32) outs(%3 : tensor<2xf32>) -> tensor<2xf32>
-        %5 = linalg.generic {
-          indexing_maps = [affine_map<(d0, d1) -> (d0, d1)>, affine_map<(d0, d1) -> (d0)>],
-          iterator_types = ["parallel", "reduction"]
-        } ins(%2 : tensor<2x512xf32>) outs(%4 : tensor<2xf32>) {
-        ^bb0(%arg0: f32, %arg1: f32):
-          %6 = arith.addf %arg1, %arg0 : f32
-          linalg.yield %6 : f32
-        } -> tensor<2xf32>
-        iree_tensor_ext.dispatch.tensor.store %5, %1, offsets = [0], sizes = [2], strides = [1] : tensor<2xf32> -> !iree_tensor_ext.dispatch.tensor<writeonly:tensor<2xf32>>
-        return
-      }
-    }
-  }
-}
-
-//   CHECK-LABEL: llvm.func @warp_reduction
-// CHECK-COUNT-6:   rocdl.update.dpp
-
-// -----
-
-#pipeline_layout = #hal.pipeline.layout<bindings = [
-  #hal.pipeline.binding<storage_buffer>,
-  #hal.pipeline.binding<storage_buffer>,
-  #hal.pipeline.binding<storage_buffer>
-]>
-hal.executable public @main_dispatch_517 {
-  hal.executable.variant public @rocm target(<"rocm", "rocm-hsaco-fb">) {
-    hal.executable.export public @warp_reduction_large_vector ordinal(0) layout(#pipeline_layout) count(%arg0: !hal.device) -> (index, index, index) {
-      %x, %y, %z = iree_tensor_ext.dispatch.workgroup_count_from_slice()
-      hal.return %x, %y, %z : index, index, index
-    }
-    builtin.module {
-      func.func @warp_reduction_large_vector() {
-        %cst = arith.constant 0.000000e+00 : f32
-        %c128 = arith.constant 128 : index
-        %c0 = arith.constant 0 : index
-        %c394240 = arith.constant 394240 : index
-        %0 = hal.interface.binding.subspan layout(#pipeline_layout) binding(0) alignment(64) offset(%c128) flags(ReadOnly) : !iree_tensor_ext.dispatch.tensor<readonly:tensor<1x1280xf32>>
-        %1 = hal.interface.binding.subspan layout(#pipeline_layout) binding(1) alignment(64) offset(%c0) flags(ReadOnly) : !iree_tensor_ext.dispatch.tensor<readonly:tensor<1280x1280xf32>>
-        %2 = hal.interface.binding.subspan layout(#pipeline_layout) binding(2) alignment(64) offset(%c394240) : !iree_tensor_ext.dispatch.tensor<writeonly:tensor<1x1280xf32>>
-        %3 = iree_tensor_ext.dispatch.tensor.load %0, offsets = [0, 0], sizes = [1, 1280], strides = [1, 1] : !iree_tensor_ext.dispatch.tensor<readonly:tensor<1x1280xf32>> -> tensor<1x1280xf32>
-        %4 = iree_tensor_ext.dispatch.tensor.load %1, offsets = [0, 0], sizes = [1280, 1280], strides = [1, 1] : !iree_tensor_ext.dispatch.tensor<readonly:tensor<1280x1280xf32>> -> tensor<1280x1280xf32>
-        %5 = tensor.empty() : tensor<1x1280xf32>
-        %6 = linalg.fill ins(%cst : f32) outs(%5 : tensor<1x1280xf32>) -> tensor<1x1280xf32>
-        %7 = linalg.matmul
-          indexing_maps = [
-            affine_map<(d0, d1, d2) -> (d0, d2)>,
-            affine_map<(d0, d1, d2) -> (d1, d2)>,
-            affine_map<(d0, d1, d2) -> (d0, d1)>
-          ]
-          ins(%3, %4 : tensor<1x1280xf32>, tensor<1280x1280xf32>) outs(%6 : tensor<1x1280xf32>) -> tensor<1x1280xf32>
-        iree_tensor_ext.dispatch.tensor.store %7, %2, offsets = [0, 0], sizes = [1, 1280], strides = [1, 1] : tensor<1x1280xf32> -> !iree_tensor_ext.dispatch.tensor<writeonly:tensor<1x1280xf32>>
-        return
-      }
-    }
-  }
-}
-
-// Each workgroup (5x64 threads) handles a shape of 64x1280 (parallel x reduction).
-// So we are seeing:
-// 6 dpp ops to reduce within warps
-// => 64 * (6) = 384 dpp operations.
-// TODO: we probably need to revisit the configuration heuristics here.
-
-//     CHECK-LABEL: llvm.func @warp_reduction_large_vector
-// CHECK-COUNT-384:   rocdl.update.dpp
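
A note on the deleted CHECK-COUNT-384: per the test's own comment, each
within-wave butterfly reduction on a 64-lane wave costs log2(64) = 6
rocdl.update.dpp steps, and the kernel performs 64 such wave-level reductions,
giving 64 * 6 = 384. A small sketch of that arithmetic (it only counts exchange
steps; it does not model the actual DPP lowering):

```cpp
// Count the butterfly exchange steps in a 64-lane wave reduction and scale by
// the 64 wave-level reductions the deleted test expected.
#include <cstdio>

int main() {
  const int waveSize = 64;
  int steps = 0;
  // Offsets 1, 2, 4, 8, 16, 32: each exchange halves the live partial sums.
  for (int offset = 1; offset < waveSize; offset *= 2)
    ++steps;
  const int numWaveReductions = 64; // from the deleted test's comment
  std::printf("dpp per reduce = %d, total = %d\n",
              steps, steps * numWaveReductions); // 6 and 384
  return 0;
}
```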