[Codegen][LLVMGPU] Remove LLVMGPUWarpReduction pipeline (#21821)

The necessary changes to the LLVMGPUVectorDistribute pipeline and all tests have already landed, so this removal should go in cleanly as is.

Signed-off-by: James Newling <james.newling@gmail.com>
diff --git a/compiler/src/iree/compiler/Codegen/Dialect/Codegen/IR/IREECodegenAttrs.td b/compiler/src/iree/compiler/Codegen/Dialect/Codegen/IR/IREECodegenAttrs.td
index ce58994..17e44e5 100644
--- a/compiler/src/iree/compiler/Codegen/Dialect/Codegen/IR/IREECodegenAttrs.td
+++ b/compiler/src/iree/compiler/Codegen/Dialect/Codegen/IR/IREECodegenAttrs.td
@@ -42,14 +42,12 @@
: I32EnumAttrCase<"LLVMGPUVectorize", 103>;
def LLVMGPU_TransposeSharedMem
: I32EnumAttrCase<"LLVMGPUTransposeSharedMem", 104>;
-def LLVMGPU_WarpReduction
- : I32EnumAttrCase<"LLVMGPUWarpReduction", 105>;
def LLVMGPU_VectorDistribute
- : I32EnumAttrCase<"LLVMGPUVectorDistribute", 106>;
+ : I32EnumAttrCase<"LLVMGPUVectorDistribute", 105>;
def LLVMGPU_WinogradVectorize
- : I32EnumAttrCase<"LLVMGPUWinogradVectorize", 107>;
+ : I32EnumAttrCase<"LLVMGPUWinogradVectorize", 106>;
def LLVMGPU_TileAndFuse
- : I32EnumAttrCase<"LLVMGPUTileAndFuse", 108>;
+ : I32EnumAttrCase<"LLVMGPUTileAndFuse", 107>;
def SPIRV_BaseLowering
: I32EnumAttrCase<"SPIRVBaseLowering", 200>;
@@ -88,7 +86,7 @@
// LLVMGPU CodeGen pipelines
LLVMGPU_Default, LLVMGPU_BaseLowering, LLVMGPU_SimpleDistribute,
- LLVMGPU_Vectorize, LLVMGPU_TransposeSharedMem, LLVMGPU_WarpReduction,
+ LLVMGPU_Vectorize, LLVMGPU_TransposeSharedMem,
LLVMGPU_VectorDistribute, LLVMGPU_WinogradVectorize, LLVMGPU_TileAndFuse,
// SPIR-V CodeGen pipelines
diff --git a/compiler/src/iree/compiler/Codegen/LLVMGPU/KernelConfig.cpp b/compiler/src/iree/compiler/Codegen/LLVMGPU/KernelConfig.cpp
index dae9254..7510180 100644
--- a/compiler/src/iree/compiler/Codegen/LLVMGPU/KernelConfig.cpp
+++ b/compiler/src/iree/compiler/Codegen/LLVMGPU/KernelConfig.cpp
@@ -2494,230 +2494,6 @@
return true;
}
-//====---------------------------------------------------------------------===//
-// Warp Reduction Pipeline Configuration
-//====---------------------------------------------------------------------===//
-
-/// Set the configuration for reductions that can be mapped to warp reductions.
-static LogicalResult
-setWarpReductionConfig(IREE::GPU::TargetAttr target,
- mlir::FunctionOpInterface entryPoint,
- linalg::LinalgOp op) {
- if (!target.supportsSubgroupShuffle())
- return failure();
-
- SmallVector<unsigned> parallelDims;
- SmallVector<unsigned> reductionDims;
- op.getParallelDims(parallelDims);
- op.getReductionDims(reductionDims);
-
- SmallVector<int64_t> bounds = op.getStaticLoopRanges();
- int64_t numParallelDims = op.getNumParallelLoops();
-
- if (reductionDims.empty())
- return failure();
-
- // Make sure reduction dimensions are static and innermost ones.
- int64_t numDynamicReductionDims = 0;
- for (unsigned dim : reductionDims) {
- if (ShapedType::isDynamic(bounds[dim])) {
- numDynamicReductionDims++;
- }
- if (dim < numParallelDims) {
- return failure();
- }
- }
- int numDynamicDims = llvm::count_if(bounds, ShapedType::isDynamic);
-
- // Distribution of multi-dim masked writes currently aren't fully supported.
- if (numDynamicReductionDims > 1) {
- return failure();
- }
-
- if (op.getRegionOutputArgs().size() != 1)
- return failure();
-
- // Only support projected permutation, this could be extended to projected
- // permutated with broadcast.
- if (llvm::any_of(op.getDpsInputOperands(), [&](OpOperand *input) {
- return !op.getMatchingIndexingMap(input).isProjectedPermutation();
- }))
- return failure();
-
- bool foundSingleReductionOutput = false;
- for (auto [index, initOpOperand] : llvm::enumerate(op.getDpsInitsMutable())) {
- // Only single combiner operations are supported for now.
- SmallVector<Operation *> combinerOps;
- if (matchReduction(op.getRegionOutputArgs(), index, combinerOps) &&
- combinerOps.size() == 1) {
- if (foundSingleReductionOutput)
- return failure();
- foundSingleReductionOutput = true;
- continue;
- }
- if (!op.getMatchingIndexingMap(&initOpOperand).isIdentity())
- return failure();
- }
- if (!foundSingleReductionOutput)
- return failure();
-
- SmallVector<int64_t> workgroupTileSizes(op.getNumParallelLoops(), 1);
-
- int64_t reductionSize = 1;
- for (int64_t dim : reductionDims)
- reductionSize *= bounds[dim];
-
- int64_t subgroupSize = 0;
- for (int s : target.getWgp().getSubgroupSizeChoices().asArrayRef()) {
- if (reductionSize % s == 0) {
- subgroupSize = s;
- break;
- }
- }
- if (subgroupSize == 0)
- return failure();
-
- // Without any bounds on dynamic dims, we need specialization to
- // get peak performance. For now, just use the warp size.
- if (numDynamicDims > 0) {
- SmallVector<int64_t> reductionTileSizes(op.getNumLoops(), 0);
- int64_t preferredSubgroupSize = target.getPreferredSubgroupSize();
- // We should set the subgroup size on:
- // Priority 1: The innermost reduction dimension with static shapes.
- // Priority 2: If there's no reduction dimension with static shapes
- // then the innermost reduction dim.
- unsigned lastNonDynamicReductionDim = reductionDims.back();
- if (reductionDims.size() > 1) {
- for (unsigned dim : reductionDims) {
- if (ShapedType::isDynamic(bounds[dim])) {
- reductionTileSizes[dim] = 1;
- } else {
- lastNonDynamicReductionDim = dim;
- }
- }
- }
- reductionTileSizes[lastNonDynamicReductionDim] = preferredSubgroupSize;
- TileSizesListType tileSizes;
- tileSizes.emplace_back(std::move(workgroupTileSizes)); // Workgroup level
- tileSizes.emplace_back(std::move(reductionTileSizes)); // Reduction level
- std::array<int64_t, 3> workgroupSize = {preferredSubgroupSize, 1, 1};
- if (failed(setOpConfigAndEntryPointFnTranslation(
- entryPoint, op, tileSizes, CodeGenPipeline::LLVMGPUWarpReduction,
- workgroupSize, preferredSubgroupSize))) {
- return failure();
- }
- return success();
- }
-
- const Type elementType =
- llvm::cast<ShapedType>(op.getDpsInitOperand(0)->get().getType())
- .getElementType();
- if (!elementType.isIntOrFloat())
- return failure();
- unsigned bitWidth = elementType.getIntOrFloatBitWidth();
- // Reduction distribution only supports 8/16/32 bit types now.
- if (bitWidth != 32 && bitWidth != 16 && bitWidth != 8)
- return failure();
-
- const unsigned largestLoadSizeInBits = 128;
- unsigned vectorSize = largestLoadSizeInBits / bitWidth;
- while ((reductionSize / vectorSize) % subgroupSize != 0)
- vectorSize /= 2;
-
- // Deduce the workgroup size we should use for reduction. Currently a
- // workgroup processes all elements in reduction dimensions. Need to make sure
- // the workgroup size we use can divide the total reduction size, and it's
- // also within hardware limitations.
- const int64_t maxWorkgroupSize = 1024;
- int64_t groupSize = reductionSize / vectorSize;
- if (groupSize > maxWorkgroupSize) {
- groupSize = llvm::APIntOps::GreatestCommonDivisor(
- {64, uint64_t(groupSize)}, {64, uint64_t(maxWorkgroupSize)})
- .getZExtValue();
- }
-
- // Then we need to strike a balance--
- // 1) parallel dimensions are distributed to workgroups. If there are many
- // workgroups dispatched, we'd want to have each GPU core hosting multiple
- // of them for occupancy.
- // 2) we want each thread to read quite a few 128-bit vectors for better
- // memory cache behavior.
- // Both means we cannot use a too large workgroup size.
-
- std::optional<int64_t> parallelSize = 1;
- for (int64_t dim : parallelDims) {
- if (ShapedType::isDynamic(bounds[dim])) {
- parallelSize = std::nullopt;
- break;
- }
- *parallelSize *= bounds[dim];
- }
- // Total parallel size that can fill the GPU with enough workgorups.
- // TODO: query from the target device; roughly 2x hardware compute unit.
- const int parallelThreshold = 256;
- // How many 128-bit vectors each thread should at least read.
- const int targetVectorCount = 8;
- while (parallelSize && *parallelSize > parallelThreshold &&
- (groupSize / 2) % subgroupSize == 0 &&
- reductionSize / (groupSize * vectorSize) < targetVectorCount) {
- // Use less subgroups per workgroup..
- groupSize /= 2;
- // in order to host more workgroups per hardware compute unit.
- *parallelSize /= 2;
- }
-
- // Current warp reduction pattern is a two step butterfly warp reduce.
- // First, do warp reductions along multiple subgroups.
- // Second, reduce results from multiple subgroups using single warp reduce.
- // The final warp reduce requires subgroup count <= subgroup size to work.
- if ((groupSize / subgroupSize) > subgroupSize)
- return failure();
-
- // With just one subgroup per workgroup, make each subgroup do more work and
- // process a few reductions (rows) along the last parallel dimension.
- //
- // TODO: This is enabled for matvec on ROCm for now. We should
- // validate this strategy and extend to more linalg generics and to CUDA.
- if (isROCmBackend(target) && ShapedType::isStaticShape(bounds) &&
- isMatvecLike(op)) {
- int64_t parallelIdx = *llvm::find_if(
- parallelDims, [&](int64_t currIdx) { return bounds[currIdx] != 1; });
- int64_t parallelBound = bounds[parallelIdx];
- int64_t numParallelReductions = 1;
- const int64_t maxParallelFactor = groupSize / 4;
- for (int64_t parallelFactor = 2; (parallelFactor < maxParallelFactor) &&
- (parallelBound % parallelFactor == 0) &&
- (parallelBound > parallelFactor);
- parallelFactor *= 2) {
- numParallelReductions = parallelFactor;
- }
- workgroupTileSizes[parallelIdx] = numParallelReductions;
- }
-
- std::array<int64_t, 3> workgroupSize = {groupSize, 1, 1};
- SmallVector<int64_t> reductionTileSizes(op.getNumLoops(), 0);
- int64_t remainingGroupSize = groupSize;
- for (int i = reductionDims.size() - 1; i >= 0; --i) {
- int64_t dim = reductionDims[i];
- int64_t bound = bounds[dim];
- if (i == reductionDims.size() - 1)
- bound /= vectorSize;
- APInt size = llvm::APIntOps::GreatestCommonDivisor(
- {64, uint64_t(remainingGroupSize)}, {64, uint64_t(bound)});
- reductionTileSizes[dim] = size.getSExtValue();
- if (i == reductionDims.size() - 1)
- reductionTileSizes[dim] *= vectorSize;
- remainingGroupSize /= size.getSExtValue();
- }
- TileSizesListType tileSizes;
- tileSizes.emplace_back(std::move(workgroupTileSizes)); // Workgroup level
- tileSizes.emplace_back(std::move(reductionTileSizes)); // Reduction level
- return setOpConfigAndEntryPointFnTranslation(
- entryPoint, op, tileSizes, CodeGenPipeline::LLVMGPUWarpReduction,
- workgroupSize, subgroupSize);
- return success();
-}
-
static bool hasTwoOrThreeLoopsInfo(linalg::LinalgOp linalgOp) {
return linalgOp.getNumParallelLoops() >= 2 &&
linalgOp.getNumParallelLoops() <= 3;
@@ -3083,10 +2859,6 @@
LDBG() << "Vector Distribution Subgroup Reduction Config";
return success();
}
- if (succeeded(setWarpReductionConfig(target, entryPointFn, linalgOp))) {
- LDBG() << "Warp Reduction Config";
- return success();
- }
if (succeeded(setConvolutionConfig(target, entryPointFn, linalgOp, 16))) {
LDBG() << "Convolution Config";
return success();
diff --git a/compiler/src/iree/compiler/Codegen/LLVMGPU/LLVMGPULowerExecutableTarget.cpp b/compiler/src/iree/compiler/Codegen/LLVMGPU/LLVMGPULowerExecutableTarget.cpp
index d5d2d1c..7c6f87c 100644
--- a/compiler/src/iree/compiler/Codegen/LLVMGPU/LLVMGPULowerExecutableTarget.cpp
+++ b/compiler/src/iree/compiler/Codegen/LLVMGPU/LLVMGPULowerExecutableTarget.cpp
@@ -120,9 +120,6 @@
case IREE::Codegen::DispatchLoweringPassPipeline::LLVMGPUVectorDistribute:
addGPUVectorDistributePassPipeline(pipeline, pipelineOptions, forROCDL);
break;
- case IREE::Codegen::DispatchLoweringPassPipeline::LLVMGPUWarpReduction:
- addGPUWarpReductionPassPipeline(pipeline, forROCDL);
- break;
case IREE::Codegen::DispatchLoweringPassPipeline::LLVMGPUTileAndFuse:
addGPUTileAndFusePassPipeline(pipeline, pipelineOptions, forROCDL);
break;
diff --git a/compiler/src/iree/compiler/Codegen/LLVMGPU/Passes.cpp b/compiler/src/iree/compiler/Codegen/LLVMGPU/Passes.cpp
index 54e4fee..a67547d 100644
--- a/compiler/src/iree/compiler/Codegen/LLVMGPU/Passes.cpp
+++ b/compiler/src/iree/compiler/Codegen/LLVMGPU/Passes.cpp
@@ -875,56 +875,6 @@
funcPassManager.addPass(createCSEPass());
}
-void addGPUWarpReductionPassPipeline(OpPassManager &funcPassManager,
- bool forROCDL) {
- tileAndDistributeToWorkgroup(
- funcPassManager, /*useForall=*/clDistributeToWorkgroupsUsingForall);
- funcPassManager.addPass(createRematerializeParallelOpsPass());
- funcPassManager.addPass(createConfigTrackingCanonicalizerPass());
- funcPassManager.addPass(createGPUTileReductionPass());
- funcPassManager.addPass(createConfigTrackingCanonicalizerPass());
- funcPassManager.addPass(createCSEPass());
- funcPassManager.addPass(createPropagateDispatchSizeBoundsPass());
-
- // Linalg -> vector
- {
- GenericVectorizationPassOptions options;
- options.enableVectorMasking = true;
- options.useConfiguredVectorSizes = false;
- options.vectorizePadding = true;
- options.vectorizeGatherAccesses = true;
- options.enableCleanup = false;
- options.generateContract = false;
- funcPassManager.addPass(createGenericVectorizationPass(options));
- funcPassManager.addPass(createOptimizeTensorInsertExtractSlicesPass());
- funcPassManager.addPass(createCanonicalizerPass());
- funcPassManager.addPass(createCSEPass());
- }
- funcPassManager.addPass(createIREELoopInvariantCodeMotionPass());
- funcPassManager.addPass(createCanonicalizerPass());
- funcPassManager.addPass(createCSEPass());
-
- addBufferizePasses(funcPassManager);
-
- funcPassManager.addPass(memref::createFoldMemRefAliasOpsPass());
- funcPassManager.addPass(createOptimizeVectorTransferPass());
- funcPassManager.addPass(createOptimizeTensorInsertExtractSlicesPass());
- funcPassManager.addPass(createIREELoopInvariantCodeMotionPass());
- funcPassManager.addPass(createCanonicalizerPass());
- funcPassManager.addPass(createCSEPass());
- funcPassManager.addPass(createForOpCanonicalizationPass());
- funcPassManager.addPass(createCanonicalizerPass());
-
- // vector -> simt gpu + vector
- VectorReductionToGPUPassOptions options;
- options.expandSubgroupReduction = !forROCDL;
- funcPassManager.addPass(createVectorReductionToGPUPass(options));
- funcPassManager.addPass(createCanonicalizerPass());
- funcPassManager.addPass(createCSEPass());
- funcPassManager.addPass(affine::createLoopCoalescingPass());
- funcPassManager.addPass(createCanonicalizerPass());
-}
-
void addGPUSimpleDistributePassPipeline(OpPassManager &funcPassManager) {
tileAndBufferize(funcPassManager);
diff --git a/compiler/src/iree/compiler/Codegen/LLVMGPU/Passes.h b/compiler/src/iree/compiler/Codegen/LLVMGPU/Passes.h
index 1578a9b..682b5b9 100644
--- a/compiler/src/iree/compiler/Codegen/LLVMGPU/Passes.h
+++ b/compiler/src/iree/compiler/Codegen/LLVMGPU/Passes.h
@@ -60,10 +60,6 @@
const GPUPipelineOptions &options,
bool forROCDL);
-/// Lowering reductions to warp reductions.
-void addGPUWarpReductionPassPipeline(OpPassManager &funcPassManager,
- bool forROCDL = true);
-
/// Default pass pipeline on GPU, currently used only for the ukernel path.
void addGPUDefaultPassPipeline(OpPassManager &funcPassManager,
const GPUPipelineOptions &options);
diff --git a/compiler/src/iree/compiler/Codegen/LLVMGPU/ROCDLLowerExecutableTarget.cpp b/compiler/src/iree/compiler/Codegen/LLVMGPU/ROCDLLowerExecutableTarget.cpp
index eb644c1..5fec4e9 100644
--- a/compiler/src/iree/compiler/Codegen/LLVMGPU/ROCDLLowerExecutableTarget.cpp
+++ b/compiler/src/iree/compiler/Codegen/LLVMGPU/ROCDLLowerExecutableTarget.cpp
@@ -69,9 +69,6 @@
case CodeGenPipeline::LLVMGPUBaseLowering:
addGPUBaseLoweringPassPipeline(pipeline);
break;
- case CodeGenPipeline::LLVMGPUWarpReduction:
- addGPUWarpReductionPassPipeline(pipeline, /*forROCDL=*/true);
- break;
case CodeGenPipeline::LLVMGPUTileAndFuse:
addGPUTileAndFusePassPipeline(pipeline, pipelineOptions,
/*forROCDL=*/true);