[Codegen] Avoid distributing unit-extent dimensions. (#18271)
The existing distribution to workgroups has logic to avoid distributing
dimensions with a unit trip count. This is easy to handle when using
`scf.forall`, since a post-tiling pattern can drop the loop dimensions
whose trip count is one.
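For example, the new pattern rewrites a distribution loop such as the
following, where the second distributed dimension has a unit trip count
(the IR below is an illustrative sketch, not taken from the added tests):

```mlir
// Before: the dimension mapped to workgroup_mapping<x> has trip count
// ceildiv(8 - 0, 8) = 1.
scf.forall (%i, %j) = (0, 0) to (128, 8) step (64, 8)
    shared_outs(%out = %init) -> (tensor<128x8xf32>) {
  // loop body elided
} {mapping = [#iree_codegen.workgroup_mapping<y>, #iree_codegen.workgroup_mapping<x>]}

// After: the unit dimension is dropped, its induction variable %j is
// replaced by the constant 0, and the remaining mapping is rebalanced so
// that it still starts at workgroup_mapping<x>.
scf.forall (%i) = (0) to (128) step (64)
    shared_outs(%out = %init) -> (tensor<128x8xf32>) {
  // loop body elided
} {mapping = [#iree_codegen.workgroup_mapping<x>]}
```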
Signed-off-by: MaheshRavishankar <mahesh.ravishankar@gmail.com>
diff --git a/compiler/src/iree/compiler/Codegen/Common/TileDispatchUsingForall.cpp b/compiler/src/iree/compiler/Codegen/Common/TileDispatchUsingForall.cpp
index 3eaa7b3..66cfa40 100644
--- a/compiler/src/iree/compiler/Codegen/Common/TileDispatchUsingForall.cpp
+++ b/compiler/src/iree/compiler/Codegen/Common/TileDispatchUsingForall.cpp
@@ -22,6 +22,8 @@
namespace mlir::iree_compiler {
+#define CEILDIV(a, b) (((a) + (b) - 1) / (b))
+
#define GEN_PASS_DEF_TILEANDDISTRIBUTETOWORKGROUPSUSINGFORALLOPPASS
#include "iree/compiler/Codegen/Common/Passes.h.inc"
@@ -119,6 +121,140 @@
return llvm::to_vector(llvm::reverse(mapping));
}
+//===---------------------------------------------------------------------===//
+// Post tiling cleanup patterns
+//===---------------------------------------------------------------------===//
+
+/// Prune the values corresponding to the dropped loops.
+static SmallVector<OpFoldResult>
+pruneDroppedLoops(ArrayRef<OpFoldResult> inputs,
+ const llvm::SmallDenseSet<int> &droppedLoops) {
+ SmallVector<OpFoldResult> prunedInputs;
+ for (auto [index, input] : llvm::enumerate(inputs)) {
+ if (droppedLoops.contains(index)) {
+ continue;
+ }
+ prunedInputs.push_back(input);
+ }
+ return prunedInputs;
+}
+
+/// Prune the mapping attributes corresponding to the dropped loops.
+/// Note that we can't simply drop them; the remaining attributes need to be
+/// rebalanced so that the workgroup mapping ids stay dense and ordered.
+/// For example, if the attribute list is
+///
+/// ```
+/// [workgroup_mapping<x>, workgroup_mapping<z:1>,
+/// workgroup_mapping<z>, workgroup_mapping<y>,
+/// workgroup_mapping<z:3>, workgroup_mapping<z:2>]
+/// ```
+///
+/// and the dropped loops are `{1, 3}`, then the new mapping should be
+///
+/// ```
+/// [workgroup_mapping<x>, workgroup_mapping<y>,
+/// workgroup_mapping<z:1>, workgroup_mapping<z>]
+/// ```
+static SmallVector<Attribute>
+pruneDroppedLoops(ArrayRef<Attribute> inputs,
+ const llvm::SmallDenseSet<int> &droppedLoops) {
+ SmallVector<IREE::Codegen::WorkgroupMappingAttr> droppedMappings;
+ SmallVector<Attribute> prunedAttrs;
+ for (auto [index, input] : llvm::enumerate(inputs)) {
+ if (droppedLoops.contains(index)) {
+ droppedMappings.push_back(
+ cast<IREE::Codegen::WorkgroupMappingAttr>(input));
+ } else {
+ prunedAttrs.push_back(input);
+ }
+ }
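+ // For each dropped mapping, shift down every retained mapping that is
+ // ordered after it so the retained mapping ids stay dense and ordered.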
+ for (auto droppedMapping : droppedMappings) {
+ for (auto [index, prunedAttr] : llvm::enumerate(prunedAttrs)) {
+ auto prunedMappingAttr =
+ cast<IREE::Codegen::WorkgroupMappingAttr>(prunedAttr);
+ if (droppedMapping < prunedMappingAttr) {
+ prunedAttrs[index] =
+ IREE::Codegen::WorkgroupMappingAttr::getAttributeFromMappingId(
+ prunedAttr.getContext(), prunedMappingAttr.getMappingId() - 1);
+ }
+ }
+ }
+ return prunedAttrs;
+}
+
+/// Find loop dimensions that have a unit trip count and drop them from the
+/// distributed dimensions.
+static LogicalResult dropUnitDistributedDims(RewriterBase &rewriter,
+ scf::ForallOp forallOp) {
+ SmallVector<OpFoldResult> mixedLbs = forallOp.getMixedLowerBound();
+ SmallVector<OpFoldResult> mixedUbs = forallOp.getMixedUpperBound();
+ SmallVector<OpFoldResult> mixedSteps = forallOp.getMixedStep();
+
+ // Find the index of loops to be dropped.
+ llvm::SmallDenseSet<int> droppedLoops;
+ for (auto [index, lb, ub, step] :
+ llvm::enumerate(mixedLbs, mixedUbs, mixedSteps)) {
+ if (!isa<Attribute>(lb) || !isa<Attribute>(ub) || !isa<Attribute>(step)) {
+ continue;
+ }
+ int64_t lbVal = getConstantIntValue(lb).value();
+ int64_t ubVal = getConstantIntValue(ub).value();
+ int64_t stepVal = getConstantIntValue(step).value();
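+ // The loop has a trip count of ceil((ub - lb) / step); a value of 1 means
+ // it executes exactly once and need not be distributed.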
+ if (CEILDIV(ubVal - lbVal, stepVal) == 1) {
+ droppedLoops.insert(index);
+ }
+ }
+ if (droppedLoops.empty()) {
+ return success();
+ }
+
+ OpBuilder::InsertionGuard g(rewriter);
+ rewriter.setInsertionPoint(forallOp);
+ SmallVector<OpFoldResult> newLbs =
+ pruneDroppedLoops(ArrayRef<OpFoldResult>(mixedLbs), droppedLoops);
+ SmallVector<OpFoldResult> newUbs =
+ pruneDroppedLoops(ArrayRef<OpFoldResult>(mixedUbs), droppedLoops);
+ SmallVector<OpFoldResult> newSteps =
+ pruneDroppedLoops(ArrayRef<OpFoldResult>(mixedSteps), droppedLoops);
+ std::optional<ArrayAttr> newMapping;
+ if (auto currMapping = forallOp.getMapping()) {
+ SmallVector<Attribute> newMappingAttrs =
+ pruneDroppedLoops(currMapping.value().getValue(), droppedLoops);
+ newMapping = rewriter.getArrayAttr(newMappingAttrs);
+ }
+
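+ // Distribution loops created by tiling are expected to have a zero lower
+ // bound, so the induction variable of each dropped unit-trip loop is
+ // replaced by the constant 0.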
+ Value zero = rewriter.create<arith::ConstantIndexOp>(forallOp.getLoc(), 0);
+ auto newForallOp = rewriter.create<scf::ForallOp>(
+ forallOp.getLoc(), newLbs, newUbs, newSteps, forallOp.getInits(),
+ newMapping, [](OpBuilder &, Location, ValueRange) {});
+
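+ // Map each old induction variable to either the constant zero (dropped
+ // dimensions) or the corresponding induction variable of the new loop.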
+ SmallVector<Value> argReplacements;
+ int newLoopBlockArgNum = 0;
+ auto newLoopBodyArgs = newForallOp.getInductionVars();
+ for (auto [index, oldBlockArg] :
+ llvm::enumerate(forallOp.getInductionVars())) {
+ if (droppedLoops.contains(index)) {
+ argReplacements.push_back(zero);
+ } else {
+ argReplacements.push_back(newLoopBodyArgs[newLoopBlockArgNum++]);
+ }
+ }
+ argReplacements.append(newForallOp.getRegionIterArgs().begin(),
+ newForallOp.getRegionIterArgs().end());
+
+ Block *oldLoopBody = forallOp.getBody();
+ Block *newLoopBody = newForallOp.getBody();
+ rewriter.mergeBlocks(oldLoopBody, newLoopBody, argReplacements);
+
+ rewriter.replaceOp(forallOp, newForallOp.getResults());
+ return success();
+}
+
+//===---------------------------------------------------------------------===//
+// Pass implementation.
+//===---------------------------------------------------------------------===//
+
void TileAndDistributeToWorkgroupsUsingForallOpPass::runOnOperation() {
auto funcOp = getOperation();
auto *context = &getContext();
@@ -180,6 +316,19 @@
}
std::swap(tileAndFuseResult->loops, tilingLoops);
}
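+ // Drop distributed dimensions with a unit trip count from the generated
+ // `scf.forall` before running the cleanup patterns below.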
+ if (!tilingLoops.empty()) {
+ if (tilingLoops.size() != 1 || !isa<scf::ForallOp>(tilingLoops[0])) {
+ funcOp.emitOpError(
+ "expected tiling to produce a single `scf.forall` loop");
+ return signalPassFailure();
+ }
+
+ auto forallOp = cast<scf::ForallOp>(tilingLoops[0]);
+ if (failed(dropUnitDistributedDims(rewriter, forallOp))) {
+ forallOp.emitOpError("failed to drop unit dimensions");
+ return signalPassFailure();
+ }
+ }
// Cleanup patterns for tile and distribute
{
@@ -191,6 +340,7 @@
context->getOrLoadDialect<IREE::LinalgExt::IREELinalgExtDialect>()
->getCanonicalizationPatterns(patterns);
memref::populateResolveRankedShapedTypeResultDimsPatterns(patterns);
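+ // Canonicalize the `scf.forall` ops produced by tiling, e.g. to fold away
+ // loops that are left with no distributed dimensions.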
+ scf::ForallOp::getCanonicalizationPatterns(patterns, context);
if (failed(applyPatternsAndFoldGreedily(funcOp, std::move(patterns)))) {
funcOp.emitOpError("tiling canonicalization failed");
return signalPassFailure();
diff --git a/compiler/src/iree/compiler/Codegen/Common/test/tile_and_distribute_workgroups_using_forall.mlir b/compiler/src/iree/compiler/Codegen/Common/test/tile_and_distribute_workgroups_using_forall.mlir
index cc6d728..fb68210 100644
--- a/compiler/src/iree/compiler/Codegen/Common/test/tile_and_distribute_workgroups_using_forall.mlir
+++ b/compiler/src/iree/compiler/Codegen/Common/test/tile_and_distribute_workgroups_using_forall.mlir
@@ -135,22 +135,6 @@
// -----
-func.func @gemm_unit_N(%arg0 : tensor<?x?xf32>, %arg1 : tensor<?x1xf32>,
- %arg2 : tensor<?x1xf32>) -> tensor<?x1xf32> {
- %0 = linalg.matmul {
- lowering_config = #iree_codegen.lowering_config<tile_sizes = [[64, 64, 64]]>}
- ins(%arg0, %arg1 : tensor<?x?xf32>, tensor<?x1xf32>)
- outs(%arg2 : tensor<?x1xf32>) -> tensor<?x1xf32>
- return %0 : tensor<?x1xf32>
-}
-// CHECK-LABEL: func.func @gemm_unit_N(
-// CHECK: %[[RESULT:.+]] = scf.forall (%[[IV0:[a-zA-Z0-9]+]])
-// CHECK: %[[MATMUL:.+]] = linalg.matmul
-// CHECK: scf.forall.in_parallel {
-// CHECK: tensor.parallel_insert_slice %[[MATMUL]] into %{{.+}}[%[[IV0]], 0] [%{{.+}}, 1]
-
-// -----
-
func.func @gemm_unit_M_unit_N(%arg0 : tensor<1x1xf32>, %arg1 : tensor<1x1xf32>,
%arg2 : tensor<1x1xf32>) -> tensor<1x1xf32> {
%0 = linalg.matmul {
@@ -351,3 +335,74 @@
// CHECK: scf.forall.in_parallel
// CHECK: tensor.parallel_insert_slice %[[MATMUL]]
// CHECK: return %[[RESULT]]
+
+// -----
+
+func.func @avoid_unit_range_distribute(
+ %arg0 : tensor<32x?x?x16x16xf16>, %arg1 : tensor<32x?x8x16x16xf16>,
+ %arg2 : tensor<32x?x16x8x16xf16>) -> tensor<32x?x16x8x16xf16> {
+ %0 = linalg.generic {
+ indexing_maps = [affine_map<(d0, d1, d2, d3, d4, d5, d6) -> (d0, d1, d5, d2, d6)>,
+ affine_map<(d0, d1, d2, d3, d4, d5, d6) -> (d0, d5, d3, d6, d4)>,
+ affine_map<(d0, d1, d2, d3, d4, d5, d6) -> (d0, d1, d2, d3, d4)>],
+ iterator_types = ["parallel", "parallel", "parallel", "parallel", "parallel", "reduction", "reduction"]}
+ ins(%arg0, %arg1 : tensor<32x?x?x16x16xf16>, tensor<32x?x8x16x16xf16>)
+ outs(%arg2 : tensor<32x?x16x8x16xf16>)
+ attrs = {lowering_config = #iree_codegen.lowering_config<tile_sizes = [[1, 1, 16, 1, 16, 1, 16]]>} {
+ ^bb0(%b0: f16, %b1: f16, %b2 : f16):
+ %1 = arith.mulf %b0, %b1 : f16
+ %2 = arith.addf %b2, %1 : f16
+ linalg.yield %2 : f16
+ } -> tensor<32x?x16x8x16xf16>
+ return %0 : tensor<32x?x16x8x16xf16>
+}
+// CHECK-LABEL: func @avoid_unit_range_distribute(
+// CHECK: scf.forall (%{{[a-zA-Z0-9]+}}, %{{[a-zA-Z0-9]+}}, %{{[a-zA-Z0-9]+}}) in (32, %{{.+}}, 8)
+// CHECK: mapping = [#iree_codegen.workgroup_mapping<z>, #iree_codegen.workgroup_mapping<y>, #iree_codegen.workgroup_mapping<x>]
+
+// -----
+
+// This just verifies that constant dim propagation works as expected after tiling.
+func.func @set_size_to_tilesize_when_divisible(
+ %arg0 : tensor<?x16x32x128xf16>, %arg1 : tensor<4096x32x128xf16>,
+ %arg2 : tensor<?x16x4096xf16>) -> tensor<?x16x4096xf16> {
+ %0 = linalg.generic {
+ indexing_maps = [affine_map<(d0, d1, d2, d3, d4) -> (d0, d1, d3, d4)>,
+ affine_map<(d0, d1, d2, d3, d4) -> (d2, d3, d4)>,
+ affine_map<(d0, d1, d2, d3, d4) -> (d0, d1, d2)>],
+ iterator_types = ["parallel", "parallel", "parallel", "reduction", "reduction"]}
+ ins(%arg0, %arg1 : tensor<?x16x32x128xf16>, tensor<4096x32x128xf16>)
+ outs(%arg2 : tensor<?x16x4096xf16>)
+ attrs = {lowering_config = #iree_codegen.lowering_config<tile_sizes = [[1, 16, 128, 1, 128]]>} {
+ ^bb0(%b0: f16, %b1: f16, %b2 : f16):
+ %1 = arith.mulf %b0, %b1 : f16
+ %2 = arith.addf %b2, %1 : f16
+ linalg.yield %2 : f16
+ } -> tensor<?x16x4096xf16>
+ return %0 : tensor<?x16x4096xf16>
+}
+// CHECK-LABEL: func @set_size_to_tilesize_when_divisible(
+// CHECK: scf.forall
+// CHECK: %[[GENERIC:.+]] = linalg.generic
+// CHECK: scf.forall.in_parallel
+// CHECK: tensor.parallel_insert_slice %[[GENERIC]]
+// CHECK-SAME: tensor<1x16x128xf16> into tensor<?x16x4096xf16>
+
+// -----
+
+// Verifies that no distribution loop (`scf.forall`) is created when every tiled loop has a unit trip count.
+func.func @generate_no_distribution(%arg0 : tensor<16xf16>) -> tensor<16xf16> {
+ %empty = tensor.empty() : tensor<16xf16>
+ %0 = linalg.generic {
+ indexing_maps = [affine_map<(d0) -> (d0)>, affine_map<(d0) -> (d0)>],
+ iterator_types = ["parallel"]}
+ ins(%arg0 : tensor<16xf16>) outs(%empty : tensor<16xf16>)
+ attrs = {lowering_config = #iree_codegen.lowering_config<tile_sizes = [[16]]>} {
+ ^bb0(%b0: f16, %b1: f16):
+ %1 = arith.mulf %b0, %b0 : f16
+ linalg.yield %1 : f16
+ } -> tensor<16xf16>
+ return %0 : tensor<16xf16>
+}
+// CHECK-LABEL: func @generate_no_distribution(
+// CHECK-NOT: scf.forall
diff --git a/compiler/src/iree/compiler/Codegen/Dialect/Codegen/IR/IREECodegenAttrs.cpp b/compiler/src/iree/compiler/Codegen/Dialect/Codegen/IR/IREECodegenAttrs.cpp
index b818ac8..5f4be34 100644
--- a/compiler/src/iree/compiler/Codegen/Dialect/Codegen/IR/IREECodegenAttrs.cpp
+++ b/compiler/src/iree/compiler/Codegen/Dialect/Codegen/IR/IREECodegenAttrs.cpp
@@ -382,6 +382,19 @@
return success();
}
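+// Converts a linearized mapping id back to a `workgroup_mapping` attribute:
+// ids 0, 1 and 2 map to `x`, `y` and `z`; larger ids map to `z:<id - 2>`.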
+WorkgroupMappingAttr
+WorkgroupMappingAttr::getAttributeFromMappingId(MLIRContext *context,
+ int64_t mappingId) {
+ int64_t linearizedDim = std::max<int64_t>(mappingId - 2, 0);
+ WorkgroupId id =
+ symbolizeWorkgroupId(std::min<uint64_t>(mappingId, 2)).value();
+ return WorkgroupMappingAttr::get(context, id, linearizedDim);
+}
+
+bool WorkgroupMappingAttr::operator<(const WorkgroupMappingAttr &rhs) const {
+ return getMappingId() < rhs.getMappingId();
+}
+
LogicalResult WorkgroupMappingAttr::verifyAttrList(
MLIRContext *context, function_ref<InFlightDiagnostic()> emitError,
ArrayRef<Attribute> attrs) {
@@ -404,14 +417,8 @@
mappingAttrs.push_back(typedAttr);
}
- llvm::sort(mappingAttrs, [](const IREE::Codegen::WorkgroupMappingAttr &lhs,
- const IREE::Codegen::WorkgroupMappingAttr &rhs) {
- if (lhs.getId() != rhs.getId()) {
- return lhs.getId() < rhs.getId();
- }
- assert(lhs.getId() == IREE::Codegen::WorkgroupId::IdZ);
- return lhs.getDelinearizedDim() < rhs.getDelinearizedDim();
- });
+ llvm::sort(mappingAttrs);
+
// First element has to be `workgroup_mapping<x>`.
if (mappingAttrs.front().getId() != IREE::Codegen::WorkgroupId::IdX) {
return emitError() << "missing `workgroup_mapping<x>`";
diff --git a/compiler/src/iree/compiler/Codegen/Dialect/Codegen/IR/IREECodegenAttrs.td b/compiler/src/iree/compiler/Codegen/Dialect/Codegen/IR/IREECodegenAttrs.td
index 95ea6e5..6190b6c 100644
--- a/compiler/src/iree/compiler/Codegen/Dialect/Codegen/IR/IREECodegenAttrs.td
+++ b/compiler/src/iree/compiler/Codegen/Dialect/Codegen/IR/IREECodegenAttrs.td
@@ -190,6 +190,15 @@
static LogicalResult verifyAttrList(::mlir::MLIRContext *context,
::llvm::function_ref<::mlir::InFlightDiagnostic ()> emitError,
ArrayRef<Attribute> attrs);
+
+ // Convert from mapping ID to attribute.
+ static ::mlir::iree_compiler::IREE::Codegen::WorkgroupMappingAttr
+ getAttributeFromMappingId(MLIRContext *context, int64_t mappingId);
+
+ // Less than operator for easy comparison.
+ bool operator<(
+ const ::mlir::iree_compiler::IREE::Codegen::WorkgroupMappingAttr &rhs)
+ const;
}];
let genVerifyDecl = 1;