Fix bug in tile_to_foreach_thread mapping computation (#11312)
The permutation of tile sizes based on `mapping` was computed
incorrectly.
diff --git a/compiler/src/iree/compiler/Codegen/Common/TransformExtensions/CommonExtensions.cpp b/compiler/src/iree/compiler/Codegen/Common/TransformExtensions/CommonExtensions.cpp
index 174b330..2ecd1a2 100644
--- a/compiler/src/iree/compiler/Codegen/Common/TransformExtensions/CommonExtensions.cpp
+++ b/compiler/src/iree/compiler/Codegen/Common/TransformExtensions/CommonExtensions.cpp
@@ -560,6 +560,7 @@
auto workload = workgroupCountOp.getOperands();
SmallVector<OpFoldResult> unpackedTileSizes;
+ int64_t numTiledDims = 0;
for (auto ofr : tileSizes) {
if (ofr.is<Value>() &&
ofr.get<Value>().getType().isa<pdl::OperationType>()) {
@@ -576,6 +577,7 @@
} else {
unpackedTileSizes.push_back(ofr);
}
+ if (!isConstantIntValue(unpackedTileSizes.back(), 0)) ++numTiledDims;
}
if (unpackedTileSizes.size() > workload.size()) {
@@ -584,33 +586,43 @@
"number of tile sizes overflow the dimension from the workload");
}
- SmallVector<OpFoldResult> workgroupCount, permutedWorkgroupCount;
+ // Generate permutation of tiled dims based on the specified mapping.
+ SmallVector<int64_t> mappingPermutation;
+ if (mapping.has_value()) {
+ if (numTiledDims != mapping->size()) {
+ return rewriter.notifyMatchFailure(exportOp,
+ "number of mapping elements must "
+ "match number of non-zero tile sizes");
+ }
+ for (DeviceMappingAttrInterface map : mapping.value())
+ mappingPermutation.push_back(map.getMappingId());
+ } else {
+ // No mapping specified: No permutation.
+ for (int64_t i = 0; i < numTiledDims; ++i) mappingPermutation.push_back(i);
+ }
+
+ // Compute number of workgroups.
+ SmallVector<OpFoldResult> workgroupCount(3, rewriter.getIndexAttr(1));
OpBuilder::InsertionGuard g(rewriter);
rewriter.setInsertionPoint(workgroupCountOp);
loc = workgroupCountOp.getLoc();
- for (auto tileSize : llvm::enumerate(unpackedTileSizes)) {
- if (isConstantIntValue(tileSize.value(), 0)) {
- workgroupCount.push_back(workload[tileSize.index()]);
- continue;
- }
+ int64_t nextTiledDim = 0;
+ for (int64_t workgroupsDim : mappingPermutation) {
+ // Skip dims with tile size 0. These are not tiled.
+ while (isConstantIntValue(unpackedTileSizes[nextTiledDim], 0))
+ ++nextTiledDim;
AffineExpr s0, s1;
bindSymbols(rewriter.getContext(), s0, s1);
auto m = AffineMap::get(0, 2, s0.ceilDiv(s1));
- OpFoldResult count = makeComposedFoldedAffineApply(
+ workgroupCount[workgroupsDim] = makeComposedFoldedAffineApply(
rewriter, loc, m,
- ArrayRef<OpFoldResult>{workload[tileSize.index()], tileSize.value()});
- workgroupCount.push_back(count);
+ ArrayRef<OpFoldResult>{workload[nextTiledDim],
+ unpackedTileSizes[nextTiledDim]});
+ ++nextTiledDim;
}
- // Make sure to fill unused dimensions with 1
- workgroupCount.resize(3, rewriter.getIndexAttr(1));
- permutedWorkgroupCount.resize(3, rewriter.getIndexAttr(1));
- int mappingId = 0;
- for (DeviceMappingAttrInterface map : mapping->getValue()) {
- permutedWorkgroupCount[map.getMappingId()] = workgroupCount[mappingId++];
- }
- rewriter.replaceOp(
- workgroupCountOp,
- getValueOrCreateConstantIndexOp(rewriter, loc, permutedWorkgroupCount));
+
+ rewriter.replaceOp(workgroupCountOp, getValueOrCreateConstantIndexOp(
+ rewriter, loc, workgroupCount));
return success();
}
diff --git a/tests/transform_dialect/cuda/vecadd2d.mlir b/tests/transform_dialect/cuda/vecadd2d.mlir
index 9ddd270..41b7c9a 100644
--- a/tests/transform_dialect/cuda/vecadd2d.mlir
+++ b/tests/transform_dialect/cuda/vecadd2d.mlir
@@ -40,6 +40,15 @@
// RUN: --iree-codegen-llvmgpu-use-transform-dialect=%p/vecadd2d_codegen_spec.mlir | \
// RUN: FileCheck %s --check-prefix=CHECK
+// RUN: iree-opt %s --iree-hal-target-backends=cuda \
+// RUN: --iree-abi-transformation-pipeline \
+// RUN: --iree-flow-transformation-pipeline \
+// RUN: --iree-stream-transformation-pipeline \
+// RUN: --iree-hal-configuration-pipeline | \
+// RUN: iree-opt --pass-pipeline='builtin.module(hal.executable(hal.executable.variant(iree-llvmgpu-lower-executable-target)))' \
+// RUN: --iree-codegen-llvmgpu-use-transform-dialect=%p/vecadd2d_codegen_spec_partial_tile.mlir | \
+// RUN: FileCheck %s --check-prefix=CHECK-PARTIAL-TILE
+
// RUN: iree-compile %s --iree-hal-target-backends=cuda \
// RUN: --iree-codegen-llvmgpu-use-transform-dialect=%p/vecadd2d_codegen_spec.mlir | \
// RUN: iree-run-module --entry_function=vecadd2d --device=cuda |\
@@ -56,6 +65,11 @@
// CHECK: %[[BLKX:.*]] = hal.interface.workgroup.id[0] : index
// CHECK: memref.subview %0[%[[BLKZ:.*]], %[[BLKX:.*]]]
+// CHECK-PARTIAL-TILE: hal.executable.export
+// CHECK-PARTIAL-TILE: bb0(%[[DEV:.*]]: !hal.device, %[[A1:.*]]: index, %[[A2:.*]]: index):
+// CHECK-PARTIAL-TILE: %[[c1:.*]] = arith.constant 1 : index
+// CHECK-PARTIAL-TILE: %[[dim:.*]] = affine.apply #map()[%[[A2]]]
+// CHECK-PARTIAL-TILE: hal.return %[[c1]], %[[c1]], %[[dim]] : index, index, index
// EXEC: EXEC @vecadd2d
// EXEC: result[0]: hal.buffer_view
diff --git a/tests/transform_dialect/cuda/vecadd2d_codegen_spec_partial_tile.mlir b/tests/transform_dialect/cuda/vecadd2d_codegen_spec_partial_tile.mlir
new file mode 100644
index 0000000..75e2dad
--- /dev/null
+++ b/tests/transform_dialect/cuda/vecadd2d_codegen_spec_partial_tile.mlir
@@ -0,0 +1,7 @@
+transform.structured.canonicalized_sequence failures(propagate) {
+^bb1(%variant_op: !pdl.operation):
+ %generics = transform.structured.match ops{["linalg.generic"]} in %variant_op
+ // Tile only one dimension, skip the other one.
+ transform.iree.tile_to_foreach_thread_and_workgroup_count_region %generics
+ tile_sizes [0, 3] ( mapping = [#gpu.block<z>])
+}