Fix distribution logic when number of parallel loops is greater than 3 (#18714)
Make the distribution logic for handling distribution of more than 3
loops more robust by avoiding use of tile sizes to figure out which
loops are distributed, but instead pass only loop ranges that are
guaranteed to be distributed. This also requires making the range passed
to these loops be the ranges of the tiled loops.
Fixes #18708
Signed-off-by: MaheshRavishankar <mahesh.ravishankar@gmail.com>
diff --git a/compiler/src/iree/compiler/Codegen/Common/TileAndDistributeToWorkgroupsPass.cpp b/compiler/src/iree/compiler/Codegen/Common/TileAndDistributeToWorkgroupsPass.cpp
index 85670e5..a20aaf8 100644
--- a/compiler/src/iree/compiler/Codegen/Common/TileAndDistributeToWorkgroupsPass.cpp
+++ b/compiler/src/iree/compiler/Codegen/Common/TileAndDistributeToWorkgroupsPass.cpp
@@ -372,7 +372,7 @@
auto linalgTilingOptions =
linalg::LinalgTilingOptions()
.setDistributionOptions(getIREELinalgLoopDistributionOptions(
- tileSizes, distributionMethodValue, maxWorkgroupParallelDims))
+ distributionMethodValue, maxWorkgroupParallelDims))
.setInterchange(llvm::map_to_vector(
interchange,
[](int64_t v) -> unsigned { return static_cast<unsigned>(v); }))
diff --git a/compiler/src/iree/compiler/Codegen/Common/TileDispatchUsingInterface.cpp b/compiler/src/iree/compiler/Codegen/Common/TileDispatchUsingInterface.cpp
index 5430127..9bb4192 100644
--- a/compiler/src/iree/compiler/Codegen/Common/TileDispatchUsingInterface.cpp
+++ b/compiler/src/iree/compiler/Codegen/Common/TileDispatchUsingInterface.cpp
@@ -291,10 +291,28 @@
IREETilingResult tilingResult;
tilingResult.tiledLoops.resize(numLoops, false);
- for (auto [index, tileSize] : llvm::enumerate(tileSizes)) {
- if (!isConstantIntValue(tileSize, 0)) {
- tilingResult.tiledLoops.set(index);
+ AffineExpr s0, s1, s2, s3; // lb, ub, step, tileSize
+ bindSymbols(rewriter.getContext(), s0, s1, s2, s3);
+ AffineExpr numTilesExprs = (s1 - s0).ceilDiv(s2 * s3);
+ for (auto [index, iteratorType, range, tileSize] :
+ llvm::enumerate(op.getLoopIteratorTypes(), iterationDomain, tileSizes)) {
+ // If distribution is specified, only parallel loops are tiled.
+ if (options.distribution && iteratorType != utils::IteratorType::parallel) {
+ continue;
}
+    // If tile size is 0, it isn't tiled.
+ if (isConstantIntValue(tileSize, 0)) {
+ continue;
+ }
+    // If number of tiles is statically known to be 1, the loop isn't tiled.
+ OpFoldResult numTiles = affine::makeComposedFoldedAffineApply(
+ rewriter, loc, numTilesExprs,
+ {range.offset, range.size, range.stride, tileSize});
+ if (isConstantIntValue(numTiles, 1)) {
+ continue;
+ }
+
+ tilingResult.tiledLoops.set(index);
}
if (!tilingResult.tiledLoops.any()) {
@@ -328,40 +346,30 @@
iterationDomain.size(), linalg::DistributionMethod::None);
SmallVector<linalg::ProcInfo> procInfo;
if (options.distribution) {
- SmallVector<utils::IteratorType> iteratorTypes =
- op.getLoopIteratorTypes();
-
- // The parallel loops that are tiled are partitionable loops.
SmallVector<Range> parallelLoopRanges;
- SmallVector<unsigned> partitionedLoopIds;
-
- AffineExpr s0, s1, s2, s3; // lb, ub, step, tileSize
- bindSymbols(rewriter.getContext(), s0, s1, s2, s3);
- AffineExpr numTilesExprs = (s1 - s0).ceilDiv(s2 * s3);
- for (auto [index, iteratorType] : llvm::enumerate(iteratorTypes)) {
- if (iteratorType != utils::IteratorType::parallel ||
- isConstantIntValue(tileSizes[index], 0)) {
- continue;
+ for (auto loopIdx : llvm::seq<unsigned>(0, numLoops)) {
+ if (tilingResult.tiledLoops.test(loopIdx)) {
+ AffineExpr s0, s1;
+ bindSymbols(rewriter.getContext(), s0, s1);
+ OpFoldResult parallelLoopStep = affine::makeComposedFoldedAffineApply(
+ rewriter, loc, s0 * s1,
+ {iterationDomain[loopIdx].stride, tileSizes[loopIdx]});
+ Range r = {iterationDomain[loopIdx].offset,
+ iterationDomain[loopIdx].size, parallelLoopStep};
+ parallelLoopRanges.emplace_back(std::move(r));
}
-
- OpFoldResult numTiles = affine::makeComposedFoldedAffineApply(
- rewriter, loc, numTilesExprs,
- {iterationDomain[index].offset, iterationDomain[index].size,
- iterationDomain[index].stride, tileSizes[index]});
- if (isConstantIntValue(numTiles, 1)) {
- continue;
- }
-
- parallelLoopRanges.push_back(iterationDomain[index]);
- partitionedLoopIds.push_back(index);
}
- // Query the callback to get the {procId, nprocs} to use.
procInfo =
options.distribution->procInfo(rewriter, loc, parallelLoopRanges);
- for (auto [index, loopIdx] : llvm::enumerate(partitionedLoopIds)) {
- distributionMethods[loopIdx] = procInfo[index].distributionMethod;
+ unsigned partitionedLoopIdx = 0;
+ for (auto loopIdx : llvm::seq<unsigned>(0, numLoops)) {
+ if (!tilingResult.tiledLoops.test(loopIdx)) {
+ continue;
+ }
+ distributionMethods[loopIdx] =
+ procInfo[partitionedLoopIdx++].distributionMethod;
}
}
@@ -443,7 +451,8 @@
worklist.pop_front();
for (OpOperand &operand : currOp->getOpOperands()) {
Operation *definingOp = operand.get().getDefiningOp();
- auto tilingInterfaceProducer = dyn_cast<TilingInterface>(definingOp);
+ auto tilingInterfaceProducer =
+ dyn_cast_or_null<TilingInterface>(definingOp);
if (!tilingInterfaceProducer || isa<tensor::PadOp>(definingOp) ||
producers.count(tilingInterfaceProducer)) {
continue;
diff --git a/compiler/src/iree/compiler/Codegen/Common/test/BUILD.bazel b/compiler/src/iree/compiler/Codegen/Common/test/BUILD.bazel
index f0ce080..284e6cf 100644
--- a/compiler/src/iree/compiler/Codegen/Common/test/BUILD.bazel
+++ b/compiler/src/iree/compiler/Codegen/Common/test/BUILD.bazel
@@ -65,6 +65,7 @@
"repeated_matcher_use.mlir",
"replace_slow_min_max_ops.mlir",
"test_partitionable_loops_interface.mlir",
+ "tile_and_distribute_to_workgroups_func_scope.mlir",
"tile_and_distribute_to_workgroups.mlir",
"tile_and_distribute_workgroups_using_forall.mlir",
"transform_buffer_opt.mlir",
diff --git a/compiler/src/iree/compiler/Codegen/Common/test/CMakeLists.txt b/compiler/src/iree/compiler/Codegen/Common/test/CMakeLists.txt
index ba2b67b..f75729f 100644
--- a/compiler/src/iree/compiler/Codegen/Common/test/CMakeLists.txt
+++ b/compiler/src/iree/compiler/Codegen/Common/test/CMakeLists.txt
@@ -62,6 +62,7 @@
"replace_slow_min_max_ops.mlir"
"test_partitionable_loops_interface.mlir"
"tile_and_distribute_to_workgroups.mlir"
+ "tile_and_distribute_to_workgroups_func_scope.mlir"
"tile_and_distribute_workgroups_using_forall.mlir"
"transform_buffer_opt.mlir"
"transform_copy_operand.mlir"
diff --git a/compiler/src/iree/compiler/Codegen/Common/test/tile_and_distribute_to_workgroups_func_scope.mlir b/compiler/src/iree/compiler/Codegen/Common/test/tile_and_distribute_to_workgroups_func_scope.mlir
new file mode 100644
index 0000000..27ff4a7
--- /dev/null
+++ b/compiler/src/iree/compiler/Codegen/Common/test/tile_and_distribute_to_workgroups_func_scope.mlir
@@ -0,0 +1,45 @@
+// RUN: iree-opt --pass-pipeline="builtin.module(func.func(iree-codegen-tile-and-distribute-to-workgroups{distribution-method=2}, canonicalize, cse))" --mlir-print-local-scope --split-input-file %s | FileCheck %s
+
+func.func @multiple_dim_distribute(%s0 : index, %s1 : index, %s2 : index, %s3 : index,
+ %arg0 : tensor<2x3x4x5xf32>) attributes {
+ translation_info = #iree_codegen.translation_info<LLVMGPUTileAndFuse workgroup_size = [32, 1, 1] subgroup_size = 32>} {
+ %c0 = arith.constant 0 : index
+ %result = hal.interface.binding.subspan layout(
+ <bindings = [#hal.pipeline.binding<storage_buffer, "ReadOnly|Indirect">,
+ #hal.pipeline.binding<storage_buffer, Indirect>], flags = Indirect>)
+ binding(0) alignment(64) offset(%c0) flags(Indirect)
+ : !flow.dispatch.tensor<writeonly:tensor<?x2x?x3x?x4x?x5xf32>>{%s0, %s1, %s2, %s3}
+ %35 = tensor.empty(%s0, %s1, %s2, %s3) : tensor<?x2x?x3x?x4x?x5xf32>
+ %36 = linalg.generic {
+ indexing_maps = [affine_map<(d0, d1, d2, d3, d4, d5, d6, d7) -> (d1, d3, d5, d7)>,
+ affine_map<(d0, d1, d2, d3, d4, d5, d6, d7) -> (d0, d1, d2, d3, d4, d5, d6, d7)>],
+ iterator_types = ["parallel", "parallel", "parallel", "parallel", "parallel", "parallel", "parallel", "parallel"]}
+ ins(%arg0 : tensor<2x3x4x5xf32>) outs(%35 : tensor<?x2x?x3x?x4x?x5xf32>)
+ attrs = {lowering_config = #iree_gpu.lowering_config<{thread = [1, 1, 1, 1, 1, 1, 1, 1], workgroup = [1, 2, 1, 4, 1, 4, 1, 1]}>} {
+ ^bb0(%in: f32, %out: f32):
+ linalg.yield %in : f32
+ } -> tensor<?x2x?x3x?x4x?x5xf32>
+ flow.dispatch.tensor.store %36, %result, offsets = [0, 0, 0, 0, 0, 0, 0, 0], sizes = [%s0, 2, %s1, 3, %s2, 4, %s3, 5], strides = [1, 1, 1, 1, 1, 1, 1, 1]
+ : tensor<?x2x?x3x?x4x?x5xf32> -> !flow.dispatch.tensor<writeonly:tensor<?x2x?x3x?x4x?x5xf32>>{%s0, %s1, %s2, %s3}
+ return
+}
+// CHECK-LABEL: func @multiple_dim_distribute(
+// CHECK-SAME: %[[S0:[a-zA-Z0-9]+]]: index,
+// CHECK-SAME: %[[S1:[a-zA-Z0-9]+]]: index,
+// CHECK-SAME: %[[S2:[a-zA-Z0-9]+]]: index,
+// CHECK-SAME: %[[S3:[a-zA-Z0-9]+]]: index,
+// CHECK-SAME: %[[INPUT:.+]]: tensor<2x3x4x5xf32>)
+// CHECK-DAG: %[[WG_ID_X:.+]] = hal.interface.workgroup.id[0]
+// CHECK-DAG: %[[WG_ID_Y:.+]] = hal.interface.workgroup.id[1]
+// CHECK-DAG: %[[WG_ID_Z:.+]] = hal.interface.workgroup.id[2]
+// CHECK-DAG: %[[EMPTY:.+]] = tensor.empty() : tensor<1x2x1x3x1x4x1x1xf32>
+// CHECK-DAG: %[[IN_SLICE:.+]] = tensor.extract_slice %[[INPUT]][0, 0, 0, %[[WG_ID_X]]] [2, 3, 4, 1]
+// CHECK: %[[GENERIC:.+]] = linalg.generic
+// CHECK-SAME: ins(%[[IN_SLICE]] :
+// CHECK-SAME: outs(%[[EMPTY]] :
+// CHECK-DAG: %[[WG_ID_Z_0:.+]] = affine.apply affine_map<()[s0, s1, s2] -> ((s1 floordiv s2) floordiv s0)>()[%[[S1]], %[[WG_ID_Z]], %[[S2]]]
+// CHECK-DAG: %[[WG_ID_Z_1:.+]] = affine.apply affine_map<()[s0, s1, s2] -> ((s1 floordiv s2) mod s0)>()[%[[S1]], %[[WG_ID_Z]], %[[S2]]]
+// CHECK-DAG: %[[WG_ID_Z_2:.+]] = affine.apply affine_map<()[s0, s1] -> (s0 mod s1)>()[%[[WG_ID_Z]], %[[S2]]]
+// CHECK: flow.dispatch.tensor.store %[[GENERIC]],
+// CHECK-SAME: offsets = [%[[WG_ID_Z_0]], 0, %[[WG_ID_Z_1]], 0, %[[WG_ID_Z_2]], 0, %[[WG_ID_Y]], %[[WG_ID_X]]]
+// CHECK-SAME: sizes = [1, 2, 1, 3, 1, 4, 1, 1]
diff --git a/compiler/src/iree/compiler/Codegen/Utils/Utils.cpp b/compiler/src/iree/compiler/Codegen/Utils/Utils.cpp
index 4ce88ae..0903573 100644
--- a/compiler/src/iree/compiler/Codegen/Utils/Utils.cpp
+++ b/compiler/src/iree/compiler/Codegen/Utils/Utils.cpp
@@ -704,17 +704,11 @@
}
linalg::LinalgLoopDistributionOptions getIREELinalgLoopDistributionOptions(
- const SmallVector<int64_t> &tileSizes,
linalg::DistributionMethod distributionMethod,
int32_t maxWorkgroupParallelDims) {
- return {[&tileSizes, distributionMethod,
+ return {[distributionMethod,
maxWorkgroupParallelDims](OpBuilder &builder, Location loc,
ArrayRef<Range> parallelLoopRanges) {
- SmallVector<int64_t> nonZeroTileSizes;
- for (int64_t size : tileSizes) {
- if (size != 0)
- nonZeroTileSizes.push_back(size);
- }
auto numParallelDims = parallelLoopRanges.size();
SmallVector<linalg::ProcInfo, 3> procInfo(numParallelDims);
@@ -729,11 +723,12 @@
OpFoldResult size = parallelLoopRanges[numParallelDims - dim - 1].size;
OpFoldResult offset =
parallelLoopRanges[numParallelDims - dim - 1].offset;
- AffineExpr d0, d1;
- int64_t tileSize = nonZeroTileSizes[numParallelDims - dim - 1];
- bindSymbols(builder.getContext(), d0, d1);
+ OpFoldResult step =
+ parallelLoopRanges[numParallelDims - dim - 1].stride;
+ AffineExpr d0, d1, d2;
+ bindSymbols(builder.getContext(), d0, d1, d2);
OpFoldResult numTiles = affine::makeComposedFoldedAffineApply(
- builder, loc, (d0 - d1).ceilDiv(tileSize), {size, offset});
+ builder, loc, (d1 - d0).ceilDiv(d2), {offset, size, step});
OpFoldResult dimValue;
if (dim == numParallelDims - 1)
dimValue = splitDim.value();
diff --git a/compiler/src/iree/compiler/Codegen/Utils/Utils.h b/compiler/src/iree/compiler/Codegen/Utils/Utils.h
index 1125957..3c19995 100644
--- a/compiler/src/iree/compiler/Codegen/Utils/Utils.h
+++ b/compiler/src/iree/compiler/Codegen/Utils/Utils.h
@@ -183,7 +183,6 @@
/// Returns the option that distributes the ops using the flow workgroup
/// ID/Count operations.
linalg::LinalgLoopDistributionOptions getIREELinalgLoopDistributionOptions(
- const SmallVector<int64_t> &tileSizes,
linalg::DistributionMethod distributionMethod,
int32_t maxWorkgroupParallelDims = kNumMaxParallelDims);