[GPU] Teach GPUApplyTilingLevel PartialReduction tiling (#19682)
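
This teaches the iree-codegen-gpu-apply-tiling-level pass about the
PartialReduction tiling level. When this level is requested, the pass
configures scf::SCFTilingOptions with the PartialReductionOuterReduction
strategy. PartialReduction also inherits the fusion restrictions already
applied to Reduction tiling: pads are not fused (a cleanup pattern fuses
them without the zero slice guard instead), producer replacements are not
yielded (to avoid yielding large tensors), and destination operands are
not fused. The new level is exposed on the pass as
tiling-level=partial_reduction and exercised by a new lit test.

Roughly, given a reduction annotated with partial_reduction = [0, 8]
(the IR below is an illustrative sketch, not pass output):

  %sum = linalg.generic {iterator_types = ["parallel", "reduction"]}
           ins(%in : tensor<?x?xf32>) outs(%init : tensor<?xf32>) ...

the strategy produces a serial loop that accumulates into a partial
result keeping the tiled reduction dimension, followed by a final merge:

  %part = scf.for %iv = %c0 to %dim step %c8
      iter_args(%acc = %pinit) -> (tensor<?x8xf32>) {
    // ... tiled, fully parallel linalg.generic accumulating into %acc ...
    scf.yield %next : tensor<?x8xf32>
  }
  %sum = linalg.reduce ins(%part : tensor<?x8xf32>)
                       outs(%init : tensor<?xf32>) dimensions = [1] ...
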
diff --git a/compiler/src/iree/compiler/Codegen/Common/GPU/GPUApplyTilingLevel.cpp b/compiler/src/iree/compiler/Codegen/Common/GPU/GPUApplyTilingLevel.cpp
index b6f58fe..62a8ac2 100644
--- a/compiler/src/iree/compiler/Codegen/Common/GPU/GPUApplyTilingLevel.cpp
+++ b/compiler/src/iree/compiler/Codegen/Common/GPU/GPUApplyTilingLevel.cpp
@@ -279,6 +279,12 @@
tilingOptions.setMapping(llvm::to_vector(llvm::reverse(mapping)));
}
+ if (tilingLevel == IREE::GPU::TilingLevel::PartialReduction) {
+ tilingOptions.setReductionTilingStrategy(
+ scf::SCFTilingOptions::ReductionTilingStrategy::
+ PartialReductionOuterReduction);
+ }
+
scf::SCFTileAndFuseOptions tileAndFuseOptions;
tileAndFuseOptions.setTilingOptions(tilingOptions);
@@ -288,6 +294,7 @@
-> std::optional<scf::SCFTileAndFuseOptions::ControlFnResult> {
Operation *owner = originalProducer.getOwner();
if (tilingLevel == IREE::GPU::TilingLevel::Reduction ||
+ tilingLevel == IREE::GPU::TilingLevel::PartialReduction ||
tilingLevel == IREE::GPU::TilingLevel::Subgroup) {
      // Do not fuse pad in reduction, partial reduction, or subgroup tiling.
      // We instead fuse pad without zero slice guard as a cleanup pattern.
@@ -298,7 +305,8 @@
bool yieldProducerReplacement = false;
    // We don't want this for reduction or partial reduction tiling, as it
    // can lead to large tensors being yielded.
- if (tilingLevel != IREE::GPU::TilingLevel::Reduction)
+ if (tilingLevel != IREE::GPU::TilingLevel::Reduction &&
+ tilingLevel != IREE::GPU::TilingLevel::PartialReduction)
yieldProducerReplacement = yieldReplacementsFor.contains(owner);
bool shouldFuse = false;
if (auto tilingOwner = dyn_cast<TilingInterface>(owner)) {
@@ -306,7 +314,8 @@
}
    // Do not fuse destination operands for (partial) reduction tiling.
if (isDestinationOperand &&
- tilingLevel == IREE::GPU::TilingLevel::Reduction) {
+ (tilingLevel == IREE::GPU::TilingLevel::Reduction ||
+ tilingLevel == IREE::GPU::TilingLevel::PartialReduction)) {
shouldFuse = false;
}
if (shouldFuse) {
@@ -395,7 +404,8 @@
if (tilingLevel != IREE::GPU::TilingLevel::Reduction &&
tilingLevel != IREE::GPU::TilingLevel::Thread &&
- tilingLevel != IREE::GPU::TilingLevel::Subgroup) {
+ tilingLevel != IREE::GPU::TilingLevel::Subgroup &&
+ tilingLevel != IREE::GPU::TilingLevel::PartialReduction) {
funcOp.emitError() << "unsupported tiling level: "
<< IREE::GPU::stringifyEnum(tilingLevel) << "\n";
return signalPassFailure();
diff --git a/compiler/src/iree/compiler/Codegen/Common/GPU/Passes.td b/compiler/src/iree/compiler/Codegen/Common/GPU/Passes.td
index 3a71759..7b9f96f 100644
--- a/compiler/src/iree/compiler/Codegen/Common/GPU/Passes.td
+++ b/compiler/src/iree/compiler/Codegen/Common/GPU/Passes.td
@@ -233,6 +233,8 @@
[{llvm::cl::values(
clEnumValN(IREE::GPU::TilingLevel::Reduction, "reduction",
"Tile and fuse all annotated ops to serial loops"),
+ clEnumValN(IREE::GPU::TilingLevel::PartialReduction, "partial_reduction",
+                 "Tile and fuse all annotated ops to partial reduction loops"),
clEnumValN(IREE::GPU::TilingLevel::Thread, "thread",
"Tile and fuse all annotated ops to threads"),
clEnumValN(IREE::GPU::TilingLevel::Subgroup, "subgroup",
diff --git a/compiler/src/iree/compiler/Codegen/Common/GPU/test/gpu_apply_tiling_level.mlir b/compiler/src/iree/compiler/Codegen/Common/GPU/test/gpu_apply_tiling_level.mlir
index f47a1ec..ee0c9a2 100644
--- a/compiler/src/iree/compiler/Codegen/Common/GPU/test/gpu_apply_tiling_level.mlir
+++ b/compiler/src/iree/compiler/Codegen/Common/GPU/test/gpu_apply_tiling_level.mlir
@@ -2,6 +2,7 @@
// RUN: iree-opt --split-input-file --mlir-print-local-scope --pass-pipeline="builtin.module(func.func(iree-codegen-gpu-apply-tiling-level{allow-zero-slices=true}, canonicalize, cse))" %s | FileCheck %s --check-prefix=NOZERO
// RUN: iree-opt --split-input-file --mlir-print-local-scope --pass-pipeline="builtin.module(func.func(iree-codegen-gpu-apply-tiling-level{tiling-level=thread}, canonicalize, cse))" %s | FileCheck %s --check-prefix=THREAD
// RUN: iree-opt --split-input-file --mlir-print-local-scope --pass-pipeline="builtin.module(func.func(iree-codegen-gpu-apply-tiling-level{tiling-level=subgroup}, canonicalize, cse))" %s | FileCheck %s --check-prefix=SUBGROUP
+// RUN: iree-opt --split-input-file --mlir-print-local-scope --pass-pipeline="builtin.module(func.func(iree-codegen-gpu-apply-tiling-level{tiling-level=partial_reduction}, canonicalize, cse))" %s | FileCheck %s --check-prefix=PARTRED
#config = #iree_gpu.lowering_config<{thread = [2, 16], subgroup = [2, 16]}>
#map = affine_map<(d0, d1) -> (d0, d1)>
@@ -698,3 +699,42 @@
// THREAD: %[[SLICE:.+]] = tensor.extract_slice %{{.*}}[%[[APPLY0]], %[[APPLY1]]] [20, 4] [1, 1]
// THREAD: %[[EXPAND:.+]] = tensor.expand_shape %[[SLICE]] {{\[\[}}0, 1, 2], [3, 4]] output_shape [1, 2, 10, 1, 4]
// THREAD: linalg.exp {{.*}} ins(%[[EXPAND]]
+
+// -----
+
+// Partial reduction tiling tests
+#config = #iree_gpu.lowering_config<{partial_reduction = [0, 8]}>
+#map1 = affine_map<(d0, d1) -> (d0, d1)>
+#map2 = affine_map<(d0, d1) -> (d0)>
+func.func @partial_reduction(%3: tensor<?x?xf32>) -> tensor<?xf32> {
+ %cst = arith.constant 0.000000e+00 : f32
+ %c0 = arith.constant 0 : index
+ %par_dim = tensor.dim %3, %c0 : tensor<?x?xf32>
+ %empty = tensor.empty(%par_dim) : tensor<?xf32>
+ %4 = linalg.fill ins(%cst : f32) outs(%empty : tensor<?xf32>) -> tensor<?xf32>
+ %5 = linalg.generic {
+ indexing_maps = [#map1, #map2],
+ iterator_types = ["parallel", "reduction"]
+ } ins(%3 : tensor<?x?xf32>) outs(%4 : tensor<?xf32>) attrs = {lowering_config = #config} {
+ ^bb0(%in: f32, %out: f32):
+ %7 = arith.addf %in, %out : f32
+ linalg.yield %7 : f32
+ } -> tensor<?xf32>
+ return %5 : tensor<?xf32>
+}
+
+// We only check that the partial reduction tiling strategy was applied; the
+// correctness of the tiling implementation itself is covered by the partial
+// reduction tiling unit tests.
+// PARTRED-LABEL: func.func @partial_reduction
+// PARTRED-DAG: %[[DIM0:.+]] = tensor.dim %{{.*}}, %c0
+// PARTRED-DAG: %[[DIM1:.+]] = tensor.dim %{{.*}}, %c1
+// PARTRED-DAG: %[[FULL:.+]] = linalg.fill {{.*}} tensor<?xf32>
+// PARTRED-DAG: %[[PART:.+]] = linalg.fill {{.*}} tensor<?x8xf32>
+// PARTRED: %[[OUT:.+]] = scf.for %{{.*}} = %c0 to %[[DIM1]] step %c8 iter_args(%{{.*}} = %[[PART]])
+// PARTRED: linalg.generic
+// PARTRED-SAME: iterator_types = ["parallel", "parallel"]
+// PARTRED-SAME: ins(%{{.*}} : tensor<?x?xf32>) outs(%{{.*}} : tensor<?x?xf32>)
+// PARTRED: scf.yield
+// PARTRED: linalg.reduce ins(%[[OUT]] : tensor<?x8xf32>)
+// PARTRED-SAME: outs(%[[FULL]] : tensor<?xf32>)
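
For reference, below is a hand-written reconstruction of the kind of IR the
PARTRED checks describe, using static shapes for readability. This is an
illustrative sketch rather than actual pass output: SSA names, affine maps,
and the boundary handling for sizes that do not divide the tile evenly will
differ.

  func.func @partial_reduction_static(%in: tensor<64x256xf32>) -> tensor<64xf32> {
    %cst = arith.constant 0.000000e+00 : f32
    %c0 = arith.constant 0 : index
    %c8 = arith.constant 8 : index
    %c256 = arith.constant 256 : index
    // Final (merged) accumulator, plus the partial accumulator that keeps
    // the tiled reduction dimension.
    %empty = tensor.empty() : tensor<64xf32>
    %full = linalg.fill ins(%cst : f32) outs(%empty : tensor<64xf32>) -> tensor<64xf32>
    %pempty = tensor.empty() : tensor<64x8xf32>
    %pinit = linalg.fill ins(%cst : f32) outs(%pempty : tensor<64x8xf32>) -> tensor<64x8xf32>
    // Serial loop over the reduction dimension; each iteration is fully
    // parallel because it reduces into the extra accumulator dimension.
    %part = scf.for %iv = %c0 to %c256 step %c8
        iter_args(%acc = %pinit) -> (tensor<64x8xf32>) {
      %slice = tensor.extract_slice %in[0, %iv] [64, 8] [1, 1]
          : tensor<64x256xf32> to tensor<64x8xf32>
      %sum = linalg.generic {
          indexing_maps = [affine_map<(d0, d1) -> (d0, d1)>,
                           affine_map<(d0, d1) -> (d0, d1)>],
          iterator_types = ["parallel", "parallel"]}
          ins(%slice : tensor<64x8xf32>) outs(%acc : tensor<64x8xf32>) {
      ^bb0(%a: f32, %b: f32):
        %0 = arith.addf %a, %b : f32
        linalg.yield %0 : f32
      } -> tensor<64x8xf32>
      scf.yield %sum : tensor<64x8xf32>
    }
    // Merge the partial results along the kept dimension.
    %res = linalg.reduce ins(%part : tensor<64x8xf32>)
        outs(%full : tensor<64xf32>) dimensions = [1]
      (%a: f32, %b: f32) {
        %0 = arith.addf %a, %b : f32
        linalg.yield %0 : f32
      }
    return %res : tensor<64xf32>
  }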