[Codegen] Avoid distributing unit-extent dimensions. (#18271)

The existing distribution to workgroups has logic to avoid distributing
unit-trip-count dimensions. This is easy to handle when using `scf.forall`:
a cleanup pattern can drop the loop dimensions whose trip count is one.
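
For example (a hand-written sketch, not one of the tests below), a
distribution loop with a unit trip count along its second dimension

```mlir
scf.forall (%i, %j) = (0, 0) to (128, 16) step (64, 16) {
  // tiled computation
} {mapping = [#iree_codegen.workgroup_mapping<y>,
              #iree_codegen.workgroup_mapping<x>]}
```

is rewritten so that `%j` is replaced by the constant 0 and the remaining
mapping attributes are rebalanced:

```mlir
scf.forall (%i) = (0) to (128) step (64) {
  // tiled computation
} {mapping = [#iree_codegen.workgroup_mapping<x>]}
```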

Signed-off-by: MaheshRavishankar <mahesh.ravishankar@gmail.com>
diff --git a/compiler/src/iree/compiler/Codegen/Common/TileDispatchUsingForall.cpp b/compiler/src/iree/compiler/Codegen/Common/TileDispatchUsingForall.cpp
index 3eaa7b3..66cfa40 100644
--- a/compiler/src/iree/compiler/Codegen/Common/TileDispatchUsingForall.cpp
+++ b/compiler/src/iree/compiler/Codegen/Common/TileDispatchUsingForall.cpp
@@ -22,6 +22,8 @@
 
 namespace mlir::iree_compiler {
 
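+// Ceiling division; assumes a positive divisor.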
+#define CEILDIV(a, b) (((a) + (b) - 1) / (b))
+
 #define GEN_PASS_DEF_TILEANDDISTRIBUTETOWORKGROUPSUSINGFORALLOPPASS
 #include "iree/compiler/Codegen/Common/Passes.h.inc"
 
@@ -119,6 +121,140 @@
   return llvm::to_vector(llvm::reverse(mapping));
 }
 
+//===---------------------------------------------------------------------===//
+// Post tiling cleanup patterns
+//===---------------------------------------------------------------------===//
+
+/// Prune the values corresponding to the dropped loops.
+static SmallVector<OpFoldResult>
+pruneDroppedLoops(ArrayRef<OpFoldResult> inputs,
+                  const llvm::SmallDenseSet<int> &droppedLoops) {
+  SmallVector<OpFoldResult> prunedInputs;
+  for (auto [index, input] : llvm::enumerate(inputs)) {
+    if (droppedLoops.contains(index)) {
+      continue;
+    }
+    prunedInputs.push_back(input);
+  }
+  return prunedInputs;
+}
+
+/// Prune the mapping attributes corresponding to the dropped loops.
+/// Note that we can't just drop them; the remaining attributes need to be
+/// rebalanced so that the workgroup mapping IDs stay contiguous and ordered.
+/// For example, if the attribute list is
+///
+/// ```
+/// [workgroup_mapping<x>, workgroup_mapping<z:1>,
+///  workgroup_mapping<z>, workgroup_mapping<y>,
+///  workgroup_mapping<z:3>, workgroup_mapping<z:2>]
+/// ```
+///
+/// and the dropped loops are `{1, 3}`, then the new mapping should be
+///
+/// ```
+/// [workgroup_mapping<x>, workgroup_mapping<y>,
+///  workgroup_mapping<z:1>, workgroup_mapping<z>]
+/// ```
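+///
+/// The rebalancing relies on the total order of mapping IDs (x = 0, y = 1,
+/// z = 2, z:n = n + 2): for each dropped attribute, every surviving
+/// attribute with a larger ID is decremented by one, which keeps the
+/// surviving IDs contiguous.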
+static SmallVector<Attribute>
+pruneDroppedLoops(ArrayRef<Attribute> inputs,
+                  const llvm::SmallDenseSet<int> &droppedLoops) {
+  SmallVector<IREE::Codegen::WorkgroupMappingAttr> droppedMappings;
+  SmallVector<Attribute> prunedAttrs;
+  for (auto [index, input] : llvm::enumerate(inputs)) {
+    if (droppedLoops.contains(index)) {
+      droppedMappings.push_back(
+          cast<IREE::Codegen::WorkgroupMappingAttr>(input));
+    } else {
+      prunedAttrs.push_back(input);
+    }
+  }
+  for (auto droppedMapping : droppedMappings) {
+    for (auto [index, prunedAttr] : llvm::enumerate(prunedAttrs)) {
+      auto prunedMappingAttr =
+          cast<IREE::Codegen::WorkgroupMappingAttr>(prunedAttr);
+      if (droppedMapping < prunedMappingAttr) {
+        prunedAttrs[index] =
+            IREE::Codegen::WorkgroupMappingAttr::getAttributeFromMappingId(
+                prunedAttr.getContext(), prunedMappingAttr.getMappingId() - 1);
+      }
+    }
+  }
+  return prunedAttrs;
+}
+
+/// Find loop dimensions that have a unit trip count and drop them from the
+/// distributed dimensions.
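+/// For example, a dimension with lb = 0, ub = 16, step = 16 has trip count
+/// ceildiv(16 - 0, 16) = 1; it is dropped and its induction variable is
+/// replaced by the constant 0.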
+static LogicalResult dropUnitDistributedDims(RewriterBase &rewriter,
+                                             scf::ForallOp forallOp) {
+  SmallVector<OpFoldResult> mixedLbs = forallOp.getMixedLowerBound();
+  SmallVector<OpFoldResult> mixedUbs = forallOp.getMixedUpperBound();
+  SmallVector<OpFoldResult> mixedSteps = forallOp.getMixedStep();
+
+  // Find the index of loops to be dropped.
+  llvm::SmallDenseSet<int> droppedLoops;
+  for (auto [index, lb, ub, step] :
+       llvm::enumerate(mixedLbs, mixedUbs, mixedSteps)) {
+    if (!isa<Attribute>(lb) || !isa<Attribute>(ub) || !isa<Attribute>(step)) {
+      continue;
+    }
+    int64_t lbVal = getConstantIntValue(lb).value();
+    int64_t ubVal = getConstantIntValue(ub).value();
+    int64_t stepVal = getConstantIntValue(step).value();
+    if (CEILDIV(ubVal - lbVal, stepVal) == 1) {
+      droppedLoops.insert(index);
+    }
+  }
+  if (droppedLoops.empty()) {
+    return success();
+  }
+
+  OpBuilder::InsertionGuard g(rewriter);
+  rewriter.setInsertionPoint(forallOp);
+  SmallVector<OpFoldResult> newLbs =
+      pruneDroppedLoops(ArrayRef<OpFoldResult>(mixedLbs), droppedLoops);
+  SmallVector<OpFoldResult> newUbs =
+      pruneDroppedLoops(ArrayRef<OpFoldResult>(mixedUbs), droppedLoops);
+  SmallVector<OpFoldResult> newSteps =
+      pruneDroppedLoops(ArrayRef<OpFoldResult>(mixedSteps), droppedLoops);
+  std::optional<ArrayAttr> newMapping;
+  if (auto currMapping = forallOp.getMapping()) {
+    SmallVector<Attribute> newMappingAttrs =
+        pruneDroppedLoops(currMapping.value().getValue(), droppedLoops);
+    newMapping = rewriter.getArrayAttr(newMappingAttrs);
+  }
+
+  Value zero = rewriter.create<arith::ConstantIndexOp>(forallOp.getLoc(), 0);
+  auto newForallOp = rewriter.create<scf::ForallOp>(
+      forallOp.getLoc(), newLbs, newUbs, newSteps, forallOp.getInits(),
+      newMapping, [](OpBuilder &, Location, ValueRange) {});
+
+  SmallVector<Value> argReplacements;
+  int newLoopBlockArgNum = 0;
+  auto newLoopBodyArgs = newForallOp.getInductionVars();
+  for (auto [index, oldBlockArg] :
+       llvm::enumerate(forallOp.getInductionVars())) {
+    if (droppedLoops.contains(index)) {
+      argReplacements.push_back(zero);
+    } else {
+      argReplacements.push_back(newLoopBodyArgs[newLoopBlockArgNum++]);
+    }
+  }
+  argReplacements.append(newForallOp.getRegionIterArgs().begin(),
+                         newForallOp.getRegionIterArgs().end());
+
+  Block *oldLoopBody = forallOp.getBody();
+  Block *newLoopBody = newForallOp.getBody();
+  rewriter.mergeBlocks(oldLoopBody, newLoopBody, argReplacements);
+
+  rewriter.replaceOp(forallOp, newForallOp.getResults());
+  return success();
+}
+
+//===---------------------------------------------------------------------===//
+// Pass implementation.
+//===---------------------------------------------------------------------===//
+
 void TileAndDistributeToWorkgroupsUsingForallOpPass::runOnOperation() {
   auto funcOp = getOperation();
   auto *context = &getContext();
@@ -180,6 +316,19 @@
     }
     std::swap(tileAndFuseResult->loops, tilingLoops);
   }
+  if (!tilingLoops.empty()) {
+    if (tilingLoops.size() != 1 || !isa<scf::ForallOp>(tilingLoops[0])) {
+      funcOp.emitOpError(
+          "expected tiling to produce a single `scf.forall` loop");
+      return signalPassFailure();
+    }
+
+    auto forallOp = cast<scf::ForallOp>(tilingLoops[0]);
+    if (failed(dropUnitDistributedDims(rewriter, forallOp))) {
+      forallOp.emitOpError("failed to drop unit dimensions");
+      return signalPassFailure();
+    }
+  }
 
   // Cleanup patterns for tile and distribute
   {
@@ -191,6 +340,7 @@
     context->getOrLoadDialect<IREE::LinalgExt::IREELinalgExtDialect>()
         ->getCanonicalizationPatterns(patterns);
     memref::populateResolveRankedShapedTypeResultDimsPatterns(patterns);
+    scf::ForallOp::getCanonicalizationPatterns(patterns, context);
     if (failed(applyPatternsAndFoldGreedily(funcOp, std::move(patterns)))) {
       funcOp.emitOpError("tiling canonicalization failed");
       return signalPassFailure();
diff --git a/compiler/src/iree/compiler/Codegen/Common/test/tile_and_distribute_workgroups_using_forall.mlir b/compiler/src/iree/compiler/Codegen/Common/test/tile_and_distribute_workgroups_using_forall.mlir
index cc6d728..fb68210 100644
--- a/compiler/src/iree/compiler/Codegen/Common/test/tile_and_distribute_workgroups_using_forall.mlir
+++ b/compiler/src/iree/compiler/Codegen/Common/test/tile_and_distribute_workgroups_using_forall.mlir
@@ -135,22 +135,6 @@
 
 // -----
 
-func.func @gemm_unit_N(%arg0 : tensor<?x?xf32>, %arg1 : tensor<?x1xf32>,
-    %arg2 : tensor<?x1xf32>) -> tensor<?x1xf32> {
-  %0 = linalg.matmul {
-      lowering_config = #iree_codegen.lowering_config<tile_sizes = [[64, 64, 64]]>}
-      ins(%arg0, %arg1 : tensor<?x?xf32>, tensor<?x1xf32>)
-      outs(%arg2 : tensor<?x1xf32>) -> tensor<?x1xf32>
-  return %0 : tensor<?x1xf32>
-}
-// CHECK-LABEL: func.func @gemm_unit_N(
-//       CHECK:   %[[RESULT:.+]] = scf.forall (%[[IV0:[a-zA-Z0-9]+]])
-//       CHECK:     %[[MATMUL:.+]] = linalg.matmul
-//       CHECK:     scf.forall.in_parallel {
-//       CHECK:       tensor.parallel_insert_slice %[[MATMUL]] into %{{.+}}[%[[IV0]], 0] [%{{.+}}, 1]
-
-// -----
-
 func.func @gemm_unit_M_unit_N(%arg0 : tensor<1x1xf32>, %arg1 : tensor<1x1xf32>,
     %arg2 : tensor<1x1xf32>) -> tensor<1x1xf32> {
   %0 = linalg.matmul {
@@ -351,3 +335,74 @@
 //       CHECK:     scf.forall.in_parallel
 //       CHECK:       tensor.parallel_insert_slice %[[MATMUL]]
 //       CHECK:   return %[[RESULT]]
+
+// -----
+
+func.func @avoid_unit_range_distribute(
+    %arg0 : tensor<32x?x?x16x16xf16>, %arg1 : tensor<32x?x8x16x16xf16>,
+    %arg2 : tensor<32x?x16x8x16xf16>) -> tensor<32x?x16x8x16xf16> {
+   %0 = linalg.generic {
+      indexing_maps = [affine_map<(d0, d1, d2, d3, d4, d5, d6) -> (d0, d1, d5, d2, d6)>,
+                       affine_map<(d0, d1, d2, d3, d4, d5, d6) -> (d0, d5, d3, d6, d4)>,
+                       affine_map<(d0, d1, d2, d3, d4, d5, d6) -> (d0, d1, d2, d3, d4)>],
+      iterator_types = ["parallel", "parallel", "parallel", "parallel", "parallel", "reduction", "reduction"]}
+      ins(%arg0, %arg1 : tensor<32x?x?x16x16xf16>, tensor<32x?x8x16x16xf16>)
+      outs(%arg2 : tensor<32x?x16x8x16xf16>)
+      attrs =  {lowering_config = #iree_codegen.lowering_config<tile_sizes = [[1, 1, 16, 1, 16, 1, 16]]>} {
+    ^bb0(%b0: f16, %b1: f16, %b2 : f16):
+      %1 = arith.mulf %b0, %b1 : f16
+      %2 = arith.addf %b2, %1 : f16
+      linalg.yield %2 : f16
+  } -> tensor<32x?x16x8x16xf16>
+  return %0 : tensor<32x?x16x8x16xf16>
+}
+// CHECK-LABEL: func @avoid_unit_range_distribute(
+//       CHECK:   scf.forall (%{{[a-zA-Z0-9]+}}, %{{[a-zA-Z0-9]+}}, %{{[a-zA-Z0-9]+}}) in (32, %{{.+}}, 8)
+//       CHECK:   mapping = [#iree_codegen.workgroup_mapping<z>, #iree_codegen.workgroup_mapping<y>, #iree_codegen.workgroup_mapping<x>]
+
+// -----
+
+// This just verifies that constant dim propagation works as expected after tiling.
+func.func @set_size_to_tilesize_when_divisible(
+    %arg0 : tensor<?x16x32x128xf16>, %arg1 : tensor<4096x32x128xf16>,
+    %arg2 : tensor<?x16x4096xf16>) -> tensor<?x16x4096xf16> {
+   %0 = linalg.generic {
+      indexing_maps = [affine_map<(d0, d1, d2, d3, d4) -> (d0, d1, d3, d4)>,
+                       affine_map<(d0, d1, d2, d3, d4) -> (d2, d3, d4)>,
+                       affine_map<(d0, d1, d2, d3, d4) -> (d0, d1, d2)>],
+      iterator_types = ["parallel", "parallel", "parallel", "reduction", "reduction"]}
+      ins(%arg0, %arg1 : tensor<?x16x32x128xf16>, tensor<4096x32x128xf16>)
+      outs(%arg2 : tensor<?x16x4096xf16>)
+      attrs =  {lowering_config = #iree_codegen.lowering_config<tile_sizes = [[1, 16, 128, 1, 128]]>} {
+    ^bb0(%b0: f16, %b1: f16, %b2 : f16):
+      %1 = arith.mulf %b0, %b1 : f16
+      %2 = arith.addf %b2, %1 : f16
+      linalg.yield %2 : f16
+  } -> tensor<?x16x4096xf16>
+  return %0 : tensor<?x16x4096xf16>
+}
+// CHECK-LABEL: func @set_size_to_tilesize_when_divisible(
+//       CHECK:   scf.forall
+//       CHECK:     %[[GENERIC:.+]] = linalg.generic
+//       CHECK:     scf.forall.in_parallel
+//       CHECK:       tensor.parallel_insert_slice %[[GENERIC]]
+//  CHECK-SAME:           tensor<1x16x128xf16> into tensor<?x16x4096xf16>
+
+// -----
+
+// Verifies that distribution is elided entirely when every distributed
+// dimension has a unit trip count.
+func.func @generate_no_distribution(%arg0 : tensor<16xf16>) -> tensor<16xf16> {
+   %empty = tensor.empty() : tensor<16xf16>
+   %0 = linalg.generic {
+      indexing_maps = [affine_map<(d0) -> (d0)>, affine_map<(d0) -> (d0)>],
+      iterator_types = ["parallel"]}
+      ins(%arg0 : tensor<16xf16>) outs(%empty : tensor<16xf16>)
+      attrs =  {lowering_config = #iree_codegen.lowering_config<tile_sizes = [[16]]>} {
+    ^bb0(%b0: f16, %b1: f16):
+      %1 = arith.mulf %b0, %b0 : f16
+      linalg.yield %1 : f16
+  } -> tensor<16xf16>
+  return %0 : tensor<16xf16>
+}
+// CHECK-LABEL: func @generate_no_distribution(
+//   CHECK-NOT:   scf.forall
diff --git a/compiler/src/iree/compiler/Codegen/Dialect/Codegen/IR/IREECodegenAttrs.cpp b/compiler/src/iree/compiler/Codegen/Dialect/Codegen/IR/IREECodegenAttrs.cpp
index b818ac8..5f4be34 100644
--- a/compiler/src/iree/compiler/Codegen/Dialect/Codegen/IR/IREECodegenAttrs.cpp
+++ b/compiler/src/iree/compiler/Codegen/Dialect/Codegen/IR/IREECodegenAttrs.cpp
@@ -382,6 +382,19 @@
   return success();
 }
 
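+// Inverse of `getMappingId`: maps 0 -> x, 1 -> y, 2 -> z, and n >= 3 to
+// z:<n - 2>.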
+WorkgroupMappingAttr
+WorkgroupMappingAttr::getAttributeFromMappingId(MLIRContext *context,
+                                                int64_t mappingId) {
+  int64_t linearizedDim = std::max<int64_t>(mappingId - 2, 0);
+  WorkgroupId id =
+      symbolizeWorkgroupId(std::min<uint64_t>(mappingId, 2)).value();
+  return WorkgroupMappingAttr::get(context, id, linearizedDim);
+}
+
+bool WorkgroupMappingAttr::operator<(const WorkgroupMappingAttr &rhs) const {
+  return getMappingId() < rhs.getMappingId();
+}
+
 LogicalResult WorkgroupMappingAttr::verifyAttrList(
     MLIRContext *context, function_ref<InFlightDiagnostic()> emitError,
     ArrayRef<Attribute> attrs) {
@@ -404,14 +417,8 @@
     mappingAttrs.push_back(typedAttr);
   }
 
-  llvm::sort(mappingAttrs, [](const IREE::Codegen::WorkgroupMappingAttr &lhs,
-                              const IREE::Codegen::WorkgroupMappingAttr &rhs) {
-    if (lhs.getId() != rhs.getId()) {
-      return lhs.getId() < rhs.getId();
-    }
-    assert(lhs.getId() == IREE::Codegen::WorkgroupId::IdZ);
-    return lhs.getDelinearizedDim() < rhs.getDelinearizedDim();
-  });
+  llvm::sort(mappingAttrs);
+
   // First element has to be `workgroup_mapping<x>`.
   if (mappingAttrs.front().getId() != IREE::Codegen::WorkgroupId::IdX) {
     return emitError() << "missing `workgroup_mapping<x>`";
diff --git a/compiler/src/iree/compiler/Codegen/Dialect/Codegen/IR/IREECodegenAttrs.td b/compiler/src/iree/compiler/Codegen/Dialect/Codegen/IR/IREECodegenAttrs.td
index 95ea6e5..6190b6c 100644
--- a/compiler/src/iree/compiler/Codegen/Dialect/Codegen/IR/IREECodegenAttrs.td
+++ b/compiler/src/iree/compiler/Codegen/Dialect/Codegen/IR/IREECodegenAttrs.td
@@ -190,6 +190,15 @@
     static LogicalResult verifyAttrList(::mlir::MLIRContext *context,
       ::llvm::function_ref<::mlir::InFlightDiagnostic ()> emitError,
       ArrayRef<Attribute> attrs);
+
+    // Convert from mapping ID to attribute.
+    static ::mlir::iree_compiler::IREE::Codegen::WorkgroupMappingAttr
+    getAttributeFromMappingId(MLIRContext *context, int64_t mappingId);
+
+    // Less-than operator comparing linearized mapping IDs.
+    bool operator<(
+        const ::mlir::iree_compiler::IREE::Codegen::WorkgroupMappingAttr &rhs)
+        const;
   }];
 
   let genVerifyDecl = 1;