[GPU] Teach GPUApplyTilingLevel PartialReduction tiling (#19682)

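Adds a PartialReduction case to GPUApplyTilingLevel. Ops annotated with a
partial_reduction lowering config are now tiled with the
PartialReductionOuterReduction strategy of scf::SCFTilingOptions, with the
same fusion restrictions as Reduction tiling: no pad fusion, no
destination-operand fusion, and no yielded producer replacements. The new
level is exposed through the pass's tiling-level option as partial_reduction.

A rough sketch of the expected structure (based on the new test below; SSA
names are illustrative and slicing details are elided): tiling a dynamic
row-sum with partial_reduction = [0, 8] turns

  %5 = linalg.generic {iterator_types = ["parallel", "reduction"]}
         ins(%in : tensor<?x?xf32>) outs(%fill : tensor<?xf32>) { ... }

into a serial loop that accumulates into a tensor<?x8xf32> partial result,
followed by a final reduction over the extra tile dimension:

  %acc_init = linalg.fill ... -> tensor<?x8xf32>
  %partial = scf.for %iv = %c0 to %dim1 step %c8
      iter_args(%acc = %acc_init) -> (tensor<?x8xf32>) {
    // The tiled body reduces element-wise into the accumulator slice, so
    // both dimensions become parallel.
    %p = linalg.generic {iterator_types = ["parallel", "parallel"]}
           ins(%in_slice : tensor<?x?xf32>)
           outs(%acc_slice : tensor<?x?xf32>) { ... }
    scf.yield ...
  }
  // A final reduction merges the partial results.
  %result = linalg.reduce ins(%partial : tensor<?x8xf32>)
                          outs(%fill : tensor<?xf32>) dimensions = [1] { ... }
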
diff --git a/compiler/src/iree/compiler/Codegen/Common/GPU/GPUApplyTilingLevel.cpp b/compiler/src/iree/compiler/Codegen/Common/GPU/GPUApplyTilingLevel.cpp
index b6f58fe..62a8ac2 100644
--- a/compiler/src/iree/compiler/Codegen/Common/GPU/GPUApplyTilingLevel.cpp
+++ b/compiler/src/iree/compiler/Codegen/Common/GPU/GPUApplyTilingLevel.cpp
@@ -279,6 +279,12 @@
       tilingOptions.setMapping(llvm::to_vector(llvm::reverse(mapping)));
     }
 
+    if (tilingLevel == IREE::GPU::TilingLevel::PartialReduction) {
+      tilingOptions.setReductionTilingStrategy(
+          scf::SCFTilingOptions::ReductionTilingStrategy::
+              PartialReductionOuterReduction);
+    }
+
     scf::SCFTileAndFuseOptions tileAndFuseOptions;
     tileAndFuseOptions.setTilingOptions(tilingOptions);
 
@@ -288,6 +294,7 @@
         -> std::optional<scf::SCFTileAndFuseOptions::ControlFnResult> {
       Operation *owner = originalProducer.getOwner();
       if (tilingLevel == IREE::GPU::TilingLevel::Reduction ||
+          tilingLevel == IREE::GPU::TilingLevel::PartialReduction ||
           tilingLevel == IREE::GPU::TilingLevel::Subgroup) {
         // Do not fuse pad in reduction and subgroup tiling. We instead fuse
         // pad without zero slice guard as a cleanup pattern.
@@ -298,7 +305,8 @@
       bool yieldProducerReplacement = false;
       // We dont want this for reduction tiling as it can lead to large tensors
       // being yielded.
-      if (tilingLevel != IREE::GPU::TilingLevel::Reduction)
+      if (tilingLevel != IREE::GPU::TilingLevel::Reduction &&
+          tilingLevel != IREE::GPU::TilingLevel::PartialReduction)
         yieldProducerReplacement = yieldReplacementsFor.contains(owner);
       bool shouldFuse = false;
       if (auto tilingOwner = dyn_cast<TilingInterface>(owner)) {
@@ -306,7 +314,8 @@
       }
       // Do not fuse destination operands for reduction tiling.
       if (isDestinationOperand &&
-          tilingLevel == IREE::GPU::TilingLevel::Reduction) {
+          (tilingLevel == IREE::GPU::TilingLevel::Reduction ||
+           tilingLevel == IREE::GPU::TilingLevel::PartialReduction)) {
         shouldFuse = false;
       }
       if (shouldFuse) {
@@ -395,7 +404,8 @@
 
   if (tilingLevel != IREE::GPU::TilingLevel::Reduction &&
       tilingLevel != IREE::GPU::TilingLevel::Thread &&
-      tilingLevel != IREE::GPU::TilingLevel::Subgroup) {
+      tilingLevel != IREE::GPU::TilingLevel::Subgroup &&
+      tilingLevel != IREE::GPU::TilingLevel::PartialReduction) {
     funcOp.emitError() << "unsupported tiling level: "
                        << IREE::GPU::stringifyEnum(tilingLevel) << "\n";
     return signalPassFailure();
diff --git a/compiler/src/iree/compiler/Codegen/Common/GPU/Passes.td b/compiler/src/iree/compiler/Codegen/Common/GPU/Passes.td
index 3a71759..7b9f96f 100644
--- a/compiler/src/iree/compiler/Codegen/Common/GPU/Passes.td
+++ b/compiler/src/iree/compiler/Codegen/Common/GPU/Passes.td
@@ -233,6 +233,8 @@
            [{llvm::cl::values(
               clEnumValN(IREE::GPU::TilingLevel::Reduction, "reduction",
                          "Tile and fuse all annotated ops to serial loops"),
+              clEnumValN(IREE::GPU::TilingLevel::PartialReduction, "partial_reduction",
+                         "Tile and fuse all annotated ops to partial reduction loops"),
               clEnumValN(IREE::GPU::TilingLevel::Thread, "thread",
                          "Tile and fuse all annotated ops to threads"),
               clEnumValN(IREE::GPU::TilingLevel::Subgroup, "subgroup",
diff --git a/compiler/src/iree/compiler/Codegen/Common/GPU/test/gpu_apply_tiling_level.mlir b/compiler/src/iree/compiler/Codegen/Common/GPU/test/gpu_apply_tiling_level.mlir
index f47a1ec..ee0c9a2 100644
--- a/compiler/src/iree/compiler/Codegen/Common/GPU/test/gpu_apply_tiling_level.mlir
+++ b/compiler/src/iree/compiler/Codegen/Common/GPU/test/gpu_apply_tiling_level.mlir
@@ -2,6 +2,7 @@
 // RUN: iree-opt --split-input-file --mlir-print-local-scope --pass-pipeline="builtin.module(func.func(iree-codegen-gpu-apply-tiling-level{allow-zero-slices=true}, canonicalize, cse))" %s | FileCheck %s --check-prefix=NOZERO
 // RUN: iree-opt --split-input-file --mlir-print-local-scope --pass-pipeline="builtin.module(func.func(iree-codegen-gpu-apply-tiling-level{tiling-level=thread}, canonicalize, cse))" %s | FileCheck %s --check-prefix=THREAD
 // RUN: iree-opt --split-input-file --mlir-print-local-scope --pass-pipeline="builtin.module(func.func(iree-codegen-gpu-apply-tiling-level{tiling-level=subgroup}, canonicalize, cse))" %s | FileCheck %s --check-prefix=SUBGROUP
+// RUN: iree-opt --split-input-file --mlir-print-local-scope --pass-pipeline="builtin.module(func.func(iree-codegen-gpu-apply-tiling-level{tiling-level=partial_reduction}, canonicalize, cse))" %s | FileCheck %s --check-prefix=PARTRED
 
 #config = #iree_gpu.lowering_config<{thread = [2, 16], subgroup = [2, 16]}>
 #map = affine_map<(d0, d1) -> (d0, d1)>
@@ -698,3 +699,42 @@
 //       THREAD:     %[[SLICE:.+]] = tensor.extract_slice %{{.*}}[%[[APPLY0]], %[[APPLY1]]] [20, 4] [1, 1]
 //       THREAD:     %[[EXPAND:.+]] = tensor.expand_shape %[[SLICE]] {{\[\[}}0, 1, 2], [3, 4]] output_shape [1, 2, 10, 1, 4]
 //       THREAD:     linalg.exp {{.*}} ins(%[[EXPAND]]
+
+// -----
+
+// Partial reduction tiling tests
+#config = #iree_gpu.lowering_config<{partial_reduction = [0, 8]}>
+#map1 = affine_map<(d0, d1) -> (d0, d1)>
+#map2 = affine_map<(d0, d1) -> (d0)>
+func.func @partial_reduction(%3: tensor<?x?xf32>) -> tensor<?xf32> {
+  %cst = arith.constant 0.000000e+00 : f32
+  %c0  = arith.constant 0 : index
+  %par_dim = tensor.dim %3, %c0 : tensor<?x?xf32>
+  %empty = tensor.empty(%par_dim) : tensor<?xf32>
+  %4 = linalg.fill ins(%cst : f32) outs(%empty : tensor<?xf32>) -> tensor<?xf32>
+  %5 = linalg.generic {
+    indexing_maps = [#map1, #map2],
+    iterator_types = ["parallel", "reduction"]
+    } ins(%3 : tensor<?x?xf32>) outs(%4 : tensor<?xf32>) attrs =  {lowering_config = #config} {
+  ^bb0(%in: f32, %out: f32):
+    %7 = arith.addf %in, %out : f32
+    linalg.yield %7 : f32
+  } -> tensor<?xf32>
+  return %5 : tensor<?xf32>
+}
+
+// We only check that the correct tiling implementation was used. We do not
+// check that the tiling itself is correct; that is covered by the partial
+// reduction tiling unit tests.
+// PARTRED-LABEL: func.func @partial_reduction
+//   PARTRED-DAG:   %[[DIM0:.+]] = tensor.dim %{{.*}}, %c0
+//   PARTRED-DAG:   %[[DIM1:.+]] = tensor.dim %{{.*}}, %c1
+//   PARTRED-DAG:   %[[FULL:.+]] = linalg.fill {{.*}} tensor<?xf32>
+//   PARTRED-DAG:   %[[PART:.+]] = linalg.fill {{.*}} tensor<?x8xf32>
+//       PARTRED:   %[[OUT:.+]] = scf.for %{{.*}} = %c0 to %[[DIM1]] step %c8 iter_args(%{{.*}} = %[[PART]])
+//       PARTRED:     linalg.generic
+//  PARTRED-SAME:       iterator_types = ["parallel", "parallel"]
+//  PARTRED-SAME:       ins(%{{.*}} : tensor<?x?xf32>) outs(%{{.*}} : tensor<?x?xf32>)
+//       PARTRED:   scf.yield
+//       PARTRED:   linalg.reduce ins(%[[OUT]] : tensor<?x8xf32>)
+//  PARTRED-SAME:                 outs(%[[FULL]] : tensor<?xf32>)