Use coalesce loops (#17314)

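This adds a `coalesce-loops` option to the
`iree-codegen-gpu-tensor-tile-to-serial-loops` pass and enables it in the
LLVMGPU vector distribution pipeline. When the option is set, the serial loops
produced by reduction tiling are collapsed into a single scf.for using
mlir::coalesceLoops: the loops are normalized to step 1 and replaced by one
loop whose trip count is the product of the individual trip counts, with the
original induction variables recovered through division/remainder arithmetic.
A rough sketch of the intended effect on the IR (illustrative only; the
bounds, constants, and body are made up):

  // Before coalescing: serial loops left behind by reduction tiling
  // (constant definitions omitted).
  scf.for %i = %c0 to %c3 step %c1 {
    scf.for %j = %c0 to %c16 step %c1 {
      // ... body using %i and %j ...
    }
  }
  // After coalescing: a single normalized loop of 3 * 16 = 48 iterations;
  // %i and %j are rebuilt from %iv with div/rem arithmetic inside the body.
  scf.for %iv = %c0 to %c48 step %c1 {
    // ... body using the recovered indices ...
  }

A single serial loop lets later stages such as prefetching work on the whole
reduction iteration space, as the updated pipeline_vector_distribute checks
below suggest.
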
diff --git a/compiler/src/iree/compiler/Codegen/Common/GPU/GPUTensorTile.cpp b/compiler/src/iree/compiler/Codegen/Common/GPU/GPUTensorTile.cpp
index b09e576..4d54ddd 100644
--- a/compiler/src/iree/compiler/Codegen/Common/GPU/GPUTensorTile.cpp
+++ b/compiler/src/iree/compiler/Codegen/Common/GPU/GPUTensorTile.cpp
@@ -37,10 +37,11 @@
 public:
   TileConsumerAndFuseInputProducer(MLIRContext *context,
                                    LinalgTransformationFilter filter,
-                                   bool fuseInputProducer,
+                                   bool fuseInputProducer, bool coalesceLoops,
                                    PatternBenefit benefit = 1)
       : OpInterfaceRewritePattern<TilingInterface>(context, benefit),
-        filter(std::move(filter)), fuseInputProducer(fuseInputProducer) {}
+        filter(std::move(filter)), fuseInputProducer(fuseInputProducer),
+        coalesceLoops(coalesceLoops) {}
 
   LogicalResult matchAndRewrite(TilingInterface op,
                                 PatternRewriter &rewriter) const override {
@@ -69,7 +70,7 @@
     // producer linalg.fill op. It implicitly assumes that the leading
     // dimensions of different linalg ops match, which is the current status;
     // but may not hold true in the long term.
-    tileSizes.resize(op.getLoopIteratorTypes().size());
+    tileSizes.resize(op.getLoopIteratorTypes().size(), 0);
 
     if (llvm::all_of(tileSizes, [](int64_t s) { return s == 0; })) {
       return failure();
@@ -88,6 +89,17 @@
     rewriter.replaceOp(op, tilingResult->replacements);
     filter.replaceLinalgTransformationFilter(rewriter,
                                              tilingResult->tiledOps.front());
+
+    if (coalesceLoops && tilingResult->loops.size() > 1) {
+      SmallVector<scf::ForOp> loops = llvm::map_to_vector(
+          tilingResult->loops, [](LoopLikeOpInterface loop) {
+            return cast<scf::ForOp>(loop.getOperation());
+          });
+      if (failed(mlir::coalesceLoops(rewriter, loops))) {
+        return failure();
+      }
+    }
+
     return success();
   }
 
@@ -161,13 +173,14 @@
 
   LinalgTransformationFilter filter;
   bool fuseInputProducer;
+  bool coalesceLoops;
 };
 
 /// Patterns for workgroup level tiling. Workgroup tiling is done at the flow
 /// level but we may have extra tiling for the reduction dimension. Therefore we
 /// tile again without distributing.
 static void populateTilingPatterns(RewritePatternSet &patterns,
-                                   bool fuseInputProducer) {
+                                   bool fuseInputProducer, bool coalesceLoops) {
   MLIRContext *context = patterns.getContext();
 
   LinalgTransformationFilter filter(
@@ -176,20 +189,21 @@
       StringAttr::get(context, getWorkgroupKTiledMarker()));
   filter.setMatchByDefault();
 
-  patterns.add<TileConsumerAndFuseInputProducer>(context, filter,
-                                                 fuseInputProducer);
+  patterns.add<TileConsumerAndFuseInputProducer>(
+      context, filter, fuseInputProducer, coalesceLoops);
 }
 
 } // namespace
 
 LogicalResult tileReductionToSerialLoops(mlir::FunctionOpInterface funcOp,
-                                         bool fuseInputProducer) {
+                                         bool fuseInputProducer,
+                                         bool coalesceLoops) {
   {
     // Tile again at the workgroup level since reduction dimensions were
     // ignored. Dimensions already tiled will be ignored since we tile to the
     // same size.
     RewritePatternSet wgTilingPatterns(funcOp.getContext());
-    populateTilingPatterns(wgTilingPatterns, fuseInputProducer);
+    populateTilingPatterns(wgTilingPatterns, fuseInputProducer, coalesceLoops);
     if (failed(applyPatternsAndFoldGreedily(funcOp,
                                             std::move(wgTilingPatterns)))) {
       return failure();
@@ -348,7 +362,8 @@
 
     // Tile to serial loops to the wg tile size to handle reductions and other
     // dimensions that have not been distributed.
-    if (failed(tileReductionToSerialLoops(funcOp)))
+    if (failed(tileReductionToSerialLoops(funcOp, /*fuseInputProducer=*/false,
+                                          /*coalesceLoops=*/false)))
       return signalPassFailure();
 
     LLVM_DEBUG({
diff --git a/compiler/src/iree/compiler/Codegen/Common/GPU/GPUTensorTileToSerialLoops.cpp b/compiler/src/iree/compiler/Codegen/Common/GPU/GPUTensorTileToSerialLoops.cpp
index 871ea08..e89cbf2 100644
--- a/compiler/src/iree/compiler/Codegen/Common/GPU/GPUTensorTileToSerialLoops.cpp
+++ b/compiler/src/iree/compiler/Codegen/Common/GPU/GPUTensorTileToSerialLoops.cpp
@@ -17,13 +17,18 @@
 namespace {
 struct GPUTensorTileToSerialLoopsPass final
     : impl::GPUTensorTileToSerialLoopsPassBase<GPUTensorTileToSerialLoopsPass> {
+  using Base::Base;
+
   void runOnOperation() override {
     // Tile reductions based on the annotated tiling configuration.
     if (failed(tileReductionToSerialLoops(getOperation(),
-                                          /*fuseInputProducer=*/true))) {
+                                          /*fuseInputProducer=*/true,
+                                          coalesceLoops))) {
       return signalPassFailure();
     }
   }
 };
+
 } // namespace
+
 } // namespace mlir::iree_compiler
diff --git a/compiler/src/iree/compiler/Codegen/Common/GPU/Passes.h b/compiler/src/iree/compiler/Codegen/Common/GPU/Passes.h
index c220c32..267ad45 100644
--- a/compiler/src/iree/compiler/Codegen/Common/GPU/Passes.h
+++ b/compiler/src/iree/compiler/Codegen/Common/GPU/Passes.h
@@ -50,7 +50,9 @@
 /// loops without distribution. If `fuseInputProducer` is true, input producers
 /// will be fused into the serial loop.
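+/// If `coalesceLoops` is set, the generated loops are coalesced into one loop.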
 LogicalResult tileReductionToSerialLoops(mlir::FunctionOpInterface funcOp,
-                                         bool fuseInputProducer = false);
+                                         bool fuseInputProducer = false,
+                                         bool coalesceLoops = false);
 
 /// Swizzles the workgroup order in `funcOp` according to the `swizzleLogTile`
 /// size. `swizzleLogTile` of 0 disables any swizzling.
diff --git a/compiler/src/iree/compiler/Codegen/Common/GPU/Passes.td b/compiler/src/iree/compiler/Codegen/Common/GPU/Passes.td
index 85fd630..079d26e 100644
--- a/compiler/src/iree/compiler/Codegen/Common/GPU/Passes.td
+++ b/compiler/src/iree/compiler/Codegen/Common/GPU/Passes.td
@@ -128,6 +128,10 @@
     InterfacePass<"iree-codegen-gpu-tensor-tile-to-serial-loops", "mlir::FunctionOpInterface"> {
   let summary = "Pass to tile reduction dimensions for certain GPU ops";
   let dependentDialects = ["::mlir::scf::SCFDialect"];
+  let options = [
+    Option<"coalesceLoops", "coalesce-loops", "bool", /*default=*/"false",
+           "Collapse the loops that are generated to a single loops">,
+  ];
 }
 
 def GPUTilePass : InterfacePass<"iree-codegen-gpu-tile", "mlir::FunctionOpInterface"> {
diff --git a/compiler/src/iree/compiler/Codegen/Common/GPU/test/gpu_tensor_alloc.mlir b/compiler/src/iree/compiler/Codegen/Common/GPU/test/gpu_tensor_alloc.mlir
index 57c73a3..b6d5d8f 100644
--- a/compiler/src/iree/compiler/Codegen/Common/GPU/test/gpu_tensor_alloc.mlir
+++ b/compiler/src/iree/compiler/Codegen/Common/GPU/test/gpu_tensor_alloc.mlir
@@ -1,4 +1,5 @@
 // RUN: iree-opt %s --allow-unregistered-dialect --split-input-file --pass-pipeline="builtin.module(func.func(iree-codegen-gpu-tensor-tile-to-serial-loops,iree-codegen-gpu-tensor-alloc))" | FileCheck %s
+// RUN: iree-opt %s --allow-unregistered-dialect --split-input-file --pass-pipeline="builtin.module(func.func(iree-codegen-gpu-tensor-tile-to-serial-loops{coalesce-loops},iree-codegen-gpu-tensor-alloc))" | FileCheck %s --check-prefix=COALESCE_LOOPS
 
 func.func @matmul_2048x512x1024() {
   %c0 = arith.constant 0 : index
@@ -194,3 +195,29 @@
 //  CHECK-SAME:         outs(%[[ARG2]] : tensor<32x128xf32>)
 //       CHECK:       scf.yield
 //       CHECK:     scf.yield
+
+// -----
+
+func.func @conv() attributes {translation_info = #iree_codegen.translation_info<LLVMGPUVectorDistribute workgroup_size = [256, 1, 1] subgroup_size = 64, {mma_schedule = #iree_gpu.mma_schedule<intrinsic = #iree_gpu.mma_layout<MFMA_F16_16x16x16_F32>, subgroup_m_count = 1, subgroup_n_count = 4>}>} {
+  %cst = arith.constant 0.000000e+00 : f32
+  %c0 = arith.constant 0 : index
+  %0 = hal.interface.binding.subspan set(0) binding(0) type(storage_buffer) alignment(64) offset(%c0) flags(ReadOnly) : !flow.dispatch.tensor<readonly:tensor<2x34x34x1280xf16>>
+  %1 = hal.interface.binding.subspan set(0) binding(1) type(storage_buffer) alignment(64) offset(%c0) flags(ReadOnly) : !flow.dispatch.tensor<readonly:tensor<3x3x1280x1280xf16>>
+  %2 = hal.interface.binding.subspan set(0) binding(2) type(storage_buffer) alignment(64) offset(%c0) : !flow.dispatch.tensor<writeonly:tensor<2x32x32x1280xf32>>
+  %workgroup_id_z = hal.interface.workgroup.id[2] : index
+  %workgroup_id_y = hal.interface.workgroup.id[1] : index
+  %workgroup_id_x = hal.interface.workgroup.id[0] : index
+  %3 = affine.apply affine_map<()[s0] -> (s0 * 256)>()[%workgroup_id_x]
+  %4 = flow.dispatch.tensor.load %2, offsets = [%workgroup_id_z, %workgroup_id_y, 0, %3], sizes = [1, 1, 32, 256], strides = [1, 1, 1, 1] : !flow.dispatch.tensor<writeonly:tensor<2x32x32x1280xf32>> -> tensor<1x1x32x256xf32>
+  %5 = flow.dispatch.tensor.load %0, offsets = [%workgroup_id_z, %workgroup_id_y, 0, 0], sizes = [1, 3, 34, 1280], strides = [1, 1, 1, 1] : !flow.dispatch.tensor<readonly:tensor<2x34x34x1280xf16>> -> tensor<1x3x34x1280xf16>
+  %6 = flow.dispatch.tensor.load %1, offsets = [0, 0, 0, %3], sizes = [3, 3, 1280, 256], strides = [1, 1, 1, 1] : !flow.dispatch.tensor<readonly:tensor<3x3x1280x1280xf16>> -> tensor<3x3x1280x256xf16>
+  %7 = linalg.fill {lowering_config = #iree_codegen.lowering_config<tile_sizes = [[1, 1, 32, 256, 1, 1, 32]]>} ins(%cst : f32) outs(%4 : tensor<1x1x32x256xf32>) -> tensor<1x1x32x256xf32>
+  %8 = linalg.conv_2d_nhwc_hwcf {dilations = dense<1> : vector<2xi64>, lowering_config = #iree_codegen.lowering_config<tile_sizes = [[1, 1, 32, 256, 1, 1, 32]]>, strides = dense<1> : vector<2xi64>} ins(%5, %6 : tensor<1x3x34x1280xf16>, tensor<3x3x1280x256xf16>) outs(%7 : tensor<1x1x32x256xf32>) -> tensor<1x1x32x256xf32>
+  flow.dispatch.tensor.store %8, %2, offsets = [%workgroup_id_z, %workgroup_id_y, 0, %3], sizes = [1, 1, 32, 256], strides = [1, 1, 1, 1] : tensor<1x1x32x256xf32> -> !flow.dispatch.tensor<writeonly:tensor<2x32x32x1280xf32>>
+  return
+}
+// Check that loop coalescing works.
+//     COALESCE_LOOPS: func.func @conv()
+//     COALESCE_LOOPS:   scf.for
+// COALESCE_LOOPS-NOT:   scf.for
+//     COALESCE_LOOPS:   return
diff --git a/compiler/src/iree/compiler/Codegen/LLVMGPU/Passes.cpp b/compiler/src/iree/compiler/Codegen/LLVMGPU/Passes.cpp
index 6aed3fc..5e70de6 100644
--- a/compiler/src/iree/compiler/Codegen/LLVMGPU/Passes.cpp
+++ b/compiler/src/iree/compiler/Codegen/LLVMGPU/Passes.cpp
@@ -570,7 +570,12 @@
   }
 
   // Problem specific (reduction) tiling.
-  funcPassManager.addPass(createGPUTensorTileToSerialLoopsPass());
+  {
+    GPUTensorTileToSerialLoopsPassOptions tensorTileToSerialLoopsPassOptions;
+    tensorTileToSerialLoopsPassOptions.coalesceLoops = true;
+    funcPassManager.addPass(createGPUTensorTileToSerialLoopsPass(
+        tensorTileToSerialLoopsPassOptions));
+  }
 
   if (usePadToModelSharedMemcpy) {
     LLVMGPUMatmulPadOption option = LLVMGPUMatmulPadOption::ReductionDims;
diff --git a/compiler/src/iree/compiler/Codegen/LLVMGPU/test/ROCDL/pipeline_vector_distribute.mlir b/compiler/src/iree/compiler/Codegen/LLVMGPU/test/ROCDL/pipeline_vector_distribute.mlir
index aaab32f..444e13d 100644
--- a/compiler/src/iree/compiler/Codegen/LLVMGPU/test/ROCDL/pipeline_vector_distribute.mlir
+++ b/compiler/src/iree/compiler/Codegen/LLVMGPU/test/ROCDL/pipeline_vector_distribute.mlir
@@ -170,7 +170,7 @@
 //     CHECK-SAME:     translation_info = #[[TRANSLATION]]
-// This has more than 2 iteartions. So we have prefetching enabled for this case. Due to
-// prefetching, we have one iteration peeled of so upper bound is 2048 - 128 = 1920.
-//          CHECK:   scf.for {{.*}} = %c0 to %c1920 step %c128 iter_args(%[[ARG:.+]] = {{.*}}) -> (vector<4x1x1x1x4x1xf16>)
+// This has more than 2 iterations. So we have prefetching enabled for this case. Due to
+// prefetching, one iteration is peeled off the coalesced loop, so the upper bound is 2048/128 - 1 = 15.
+//          CHECK:   scf.for {{.*}} = %c0 to %c15 step %c1 iter_args(%[[ARG:.+]] = {{.*}}) -> (vector<4x1x1x1x4x1xf16>)
 //          CHECK:     arith.extf %[[ARG]] : vector<4x1x1x1x4x1xf16> to vector<4x1x1x1x4x1xf32>
 // CHECK-COUNT-32:     amdgpu.mfma {{.*}} {blocks = 1 : i32, k = 16 : i32, m = 16 : i32, n = 16 : i32} blgp =  none : vector<4xf16>, vector<4xf16>, vector<4xf32>
 //          CHECK:     %[[TRUNC:.+]] = arith.truncf %{{.*}} : vector<4x1x1x1x4x1xf32> to vector<4x1x1x1x4x1xf16>
@@ -217,16 +217,12 @@
 }
 
 //    CHECK-LABEL: func.func @conv_nhwc
-//          CHECK:   scf.for {{.*}} = %c0 to %c3
-//          CHECK:     scf.for {{.*}} = %c0 to %c3
-// This has more than 2 iteartions. So we have prefetching enabled for this case. Due to
-// prefetching, we have one iteration peeled of so upper bound is 768 - 32 = 736.
-//          CHECK:       scf.for {{.*}} = %c0 to %c736 step %c32 iter_args(%[[ARG:.+]] = {{.*}}) -> (vector<2x4x1x1x4x1xf32>)
-// CHECK-COUNT-16:         amdgpu.mfma {{.*}} {blocks = 1 : i32, k = 16 : i32, m = 16 : i32, n = 16 : i32} blgp =  none : vector<4xf16>, vector<4xf16>, vector<4xf32>
-//          CHECK:         scf.yield
-// CHECK-COUNT-16:       amdgpu.mfma
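+// The serial loops are coalesced into one loop of 3 * 3 * (768/32) = 216 iterations;
+// prefetching peels one iteration off, so the upper bound is 215.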
+//          CHECK:   scf.for {{.*}} = %c0 to %c215 step %c1 iter_args(%[[ARG:.+]] = {{.*}}) -> (vector<2x4x1x1x4x1xf32>)
+// CHECK-COUNT-16:     amdgpu.mfma {{.*}} {blocks = 1 : i32, k = 16 : i32, m = 16 : i32, n = 16 : i32} blgp =  none : vector<4xf16>, vector<4xf16>, vector<4xf32>
 //          CHECK:     scf.yield
-//          CHECK:   scf.yield
+// CHECK-COUNT-16:   amdgpu.mfma
 //  CHECK-COUNT-8:   vector.transfer_write {{.+}} : vector<4x1xf32>, memref<2x256x512x256xf32, #hal.descriptor_type<storage_buffer>>
 
 // -----
@@ -300,7 +294,7 @@
 //    CHECK-LABEL: func.func @generic_2x1024x20x64x1280_f16
-// This has more than 2 iteartions. So we have prefetching enabled for this case. Due to
-// prefetching, we have one iteration peeled of so upper bound is 1280 - 128 = 1152.
-//          CHECK:   scf.for {{.*}} = %c0 to %c1152 step %c128 iter_args({{.*}}) -> (vector<2x2x1x1x4x1xf16>)
+// This has more than 2 iterations. So we have prefetching enabled for this case. Due to
+// prefetching, one iteration is peeled off the coalesced loop, so the upper bound is 1280/128 - 1 = 9.
+//          CHECK:   scf.for {{.*}} = %c0 to %c9 step %c1 iter_args({{.*}}) -> (vector<2x2x1x1x4x1xf16>)
 // Each subgroup handles 2 * 2 tiles, and for each tile we accumulate 8 times
 // along the K dimension. So in total 32 mfma ops.
 // CHECK-COUNT-32:     amdgpu.mfma {{.*}} {blocks = 1 : i32, k = 16 : i32, m = 16 : i32, n = 16 : i32} blgp =  none : vector<4xf16>, vector<4xf16>, vector<4xf32>
@@ -420,7 +414,7 @@
 // CHECK:         %[[RHS_LOAD:.+]] = vector.transfer_read %[[RHS_GLOBAL_SUB]]{{.+}} {in_bounds = [true, false, false]}
 // CHECK:         vector.transfer_write %[[LHS_LOAD]], %[[LHS_SHARED]]
 // CHECK:         vector.transfer_write %[[RHS_LOAD]], %[[RHS_SHARED]]
-// CHECK:         %[[RES:.+]] scf.for {{.*}} = %c0 to %c1265 step %c16 iter_args({{.*}}) -> (vector<1x1x1x1x1x1x1x4x1xf16>)
+// CHECK:         %[[RES:.+]] scf.for {{.*}} = %c0 to %c80 step %c1 iter_args({{.*}}) -> (vector<1x1x1x1x1x1x1x4x1xf16>)
 // CHECK-DAG:       %[[LHS_GLOBAL_SUB:.+]] = memref.subview %[[LHS_GLOBAL]]
 // CHECK-DAG:       %[[RHS_GLOBAL_SUB:.+]] = memref.subview %[[RHS_GLOBAL]]
 // CHECK:           %[[LHS_LOAD:.+]] = vector.transfer_read %[[LHS_GLOBAL_SUB]]