[Common][TransformDialect] Fix the usage of num_threads for TileToForallAndWorkgroupCountRegionOp (#12438)
Prior to this patch, it was impossible to use
`TileToForallAndWorkgroupCountRegionOp` without specifying tile sizes:
the lowering of the `workgroup_count` op was simply skipped, leading to
compiler errors further down the pipeline.
This patch fixes that by teaching `lowerWorkgroupCountComputingRegion`
(the function that lowers `workgroup_count` ops) how to derive the
workgroup counts from `num_threads`.
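
To make the two lowering modes concrete, here is a minimal standalone C++
sketch of the arithmetic involved (a hypothetical helper, not the MLIR-based
implementation in the diff below; it assumes positive static sizes and
ignores the mapping permutation that the real code applies):

```cpp
#include <cstddef>
#include <cstdint>
#include <vector>

// Sketch: with tile_sizes, each tiled dim contributes
// ceilDiv(workload, tileSize) workgroups; with num_threads, the
// requested thread count becomes the workgroup count directly.
std::vector<int64_t> computeWorkgroupCounts(
    const std::vector<int64_t> &workload,
    const std::vector<int64_t> &tileSizesOrNumThreads, bool useNumThreads) {
  std::vector<int64_t> counts;
  for (std::size_t dim = 0; dim < tileSizesOrNumThreads.size(); ++dim) {
    int64_t v = tileSizesOrNumThreads[dim];
    if (v == 0) continue;  // A zero entry means the dim is not tiled.
    counts.push_back(useNumThreads ? v : (workload[dim] + v - 1) / v);
  }
  return counts;
}
```

In the new test below, `num_threads [32]` therefore makes the export's
workgroup count region return 32 directly, rather than computing it as a
ceiling division of the workload by a tile size.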
diff --git a/compiler/src/iree/compiler/Codegen/Common/TransformExtensions/CommonExtensions.cpp b/compiler/src/iree/compiler/Codegen/Common/TransformExtensions/CommonExtensions.cpp
index 65e2bcc..b28cb17 100644
--- a/compiler/src/iree/compiler/Codegen/Common/TransformExtensions/CommonExtensions.cpp
+++ b/compiler/src/iree/compiler/Codegen/Common/TransformExtensions/CommonExtensions.cpp
@@ -741,8 +741,8 @@
/// pdl::OperationType handles on the fly.
static LogicalResult lowerWorkgroupCountComputingRegion(
transform::TransformState &state, RewriterBase &rewriter, Location loc,
- HAL::ExecutableExportOp exportOp, ArrayRef<OpFoldResult> tileSizes,
- Optional<ArrayAttr> mapping) {
+ HAL::ExecutableExportOp exportOp, ArrayRef<OpFoldResult> numThreads,
+ ArrayRef<OpFoldResult> tileSizes, Optional<ArrayAttr> mapping) {
Region &r = exportOp.getWorkgroupCount();
if (!r.hasOneBlock()) {
return rewriter.notifyMatchFailure(exportOp,
@@ -760,31 +760,37 @@
auto workgroupCountOp = *workgroupCountOps.begin();
auto workload = workgroupCountOp.getOperands();
- SmallVector<OpFoldResult> unpackedTileSizes;
+ bool useNumThreads = !numThreads.empty();
+ ArrayRef<OpFoldResult> tileSizesOrNumThreads =
+ useNumThreads ? numThreads : tileSizes;
+ StringRef kindStr = useNumThreads ? "num thread" : "tile size";
+
+ SmallVector<OpFoldResult> unpackedTileSizesOrNumThreads;
int64_t numTiledDims = 0;
- for (auto ofr : tileSizes) {
+ for (auto ofr : tileSizesOrNumThreads) {
if (ofr.is<Value>() &&
ofr.get<Value>().getType().isa<pdl::OperationType>()) {
for (Operation *sizeProducer : state.getPayloadOps(ofr.get<Value>())) {
if (sizeProducer->getNumResults() != 1) {
- auto diag =
- mlir::emitDefiniteFailure(sizeProducer)
- << "the operation producing tile size must have one result";
+ auto diag = mlir::emitDefiniteFailure(sizeProducer)
+ << "the operation producing " << kindStr
+ << " must have one result";
diag.attachNote(loc) << "when applying this transform";
return diag;
}
- unpackedTileSizes.push_back(sizeProducer->getResult(0));
+ unpackedTileSizesOrNumThreads.push_back(sizeProducer->getResult(0));
}
} else {
- unpackedTileSizes.push_back(ofr);
+ unpackedTileSizesOrNumThreads.push_back(ofr);
}
- if (!isConstantIntValue(unpackedTileSizes.back(), 0)) ++numTiledDims;
+ if (!isConstantIntValue(unpackedTileSizesOrNumThreads.back(), 0))
+ ++numTiledDims;
}
- if (unpackedTileSizes.size() > workload.size()) {
+ if (unpackedTileSizesOrNumThreads.size() > workload.size()) {
return rewriter.notifyMatchFailure(
exportOp,
- "number of tile sizes overflow the dimension from the workload");
+ "number of " + kindStr + "s overflow the dimension from the workload");
}
// Generate permutation of tiled dims based on the specified mapping.
@@ -793,7 +799,8 @@
if (numTiledDims != mapping->size()) {
return rewriter.notifyMatchFailure(exportOp,
"number of mapping elements must "
- "match number of non-zero tile sizes");
+ "match number of non-zero " +
+ kindStr + "s");
}
for (DeviceMappingAttrInterface map : mapping.value())
mappingPermutation.push_back(map.getMappingId());
@@ -810,15 +817,20 @@
int64_t nextTiledDim = 0;
for (int64_t workgroupsDim : mappingPermutation) {
// Skip dims with tile size 0. These are not tiled.
- while (isConstantIntValue(unpackedTileSizes[nextTiledDim], 0))
+ while (isConstantIntValue(unpackedTileSizesOrNumThreads[nextTiledDim], 0))
++nextTiledDim;
- AffineExpr s0, s1;
- bindSymbols(rewriter.getContext(), s0, s1);
- auto m = AffineMap::get(0, 2, s0.ceilDiv(s1));
- workgroupCount[workgroupsDim] = makeComposedFoldedAffineApply(
- rewriter, loc, m,
- ArrayRef<OpFoldResult>{workload[nextTiledDim],
- unpackedTileSizes[nextTiledDim]});
+ if (useNumThreads) {
+ workgroupCount[workgroupsDim] =
+ unpackedTileSizesOrNumThreads[nextTiledDim];
+ } else {
+ AffineExpr s0, s1;
+ bindSymbols(rewriter.getContext(), s0, s1);
+ auto m = AffineMap::get(0, 2, s0.ceilDiv(s1));
+ workgroupCount[workgroupsDim] = makeComposedFoldedAffineApply(
+ rewriter, loc, m,
+ ArrayRef<OpFoldResult>{workload[nextTiledDim],
+ unpackedTileSizesOrNumThreads[nextTiledDim]});
+ }
++nextTiledDim;
}
@@ -882,8 +894,8 @@
/// regions are created by default in IREEs compilation flow.
IRRewriter rewriter(getContext());
if (failed(lowerWorkgroupCountComputingRegion(
- state, rewriter, getLoc(), exportOp.value(), getMixedTileSizes(),
- getMapping()))) {
+ state, rewriter, getLoc(), exportOp.value(), getMixedNumThreads(),
+ getMixedTileSizes(), getMapping()))) {
return mlir::emitDefiniteFailure(exportOp.value(),
"failed to lower workgroup count region");
}
diff --git a/compiler/src/iree/compiler/Codegen/LLVMCPU/test/BUILD b/compiler/src/iree/compiler/Codegen/LLVMCPU/test/BUILD
index b52c7ec..19d07e9 100644
--- a/compiler/src/iree/compiler/Codegen/LLVMCPU/test/BUILD
+++ b/compiler/src/iree/compiler/Codegen/LLVMCPU/test/BUILD
@@ -46,6 +46,7 @@
"synchronize_symbol_visibility.mlir",
"test_config_mmt4d.mlir",
"transform_dialect_bufferize.mlir",
+ "transform_dialect_iree_tile_to_forall.mlir",
"transpose_avx2_lowering.mlir",
"triple_tiling_expert_pipeline.mlir",
"unfused_fma.mlir",
diff --git a/compiler/src/iree/compiler/Codegen/LLVMCPU/test/CMakeLists.txt b/compiler/src/iree/compiler/Codegen/LLVMCPU/test/CMakeLists.txt
index c8dfb63..29b3167 100644
--- a/compiler/src/iree/compiler/Codegen/LLVMCPU/test/CMakeLists.txt
+++ b/compiler/src/iree/compiler/Codegen/LLVMCPU/test/CMakeLists.txt
@@ -41,6 +41,7 @@
"synchronize_symbol_visibility.mlir"
"test_config_mmt4d.mlir"
"transform_dialect_bufferize.mlir"
+ "transform_dialect_iree_tile_to_forall.mlir"
"transpose_avx2_lowering.mlir"
"triple_tiling_expert_pipeline.mlir"
"unfused_fma.mlir"
diff --git a/compiler/src/iree/compiler/Codegen/LLVMCPU/test/transform_dialect_iree_tile_to_forall.mlir b/compiler/src/iree/compiler/Codegen/LLVMCPU/test/transform_dialect_iree_tile_to_forall.mlir
new file mode 100644
index 0000000..94e984e
--- /dev/null
+++ b/compiler/src/iree/compiler/Codegen/LLVMCPU/test/transform_dialect_iree_tile_to_forall.mlir
@@ -0,0 +1,59 @@
+// RUN: iree-opt %s -iree-transform-dialect-interpreter -transform-dialect-drop-schedule --split-input-file | FileCheck %s
+
+// Check that we can specify `num_threads` when lowering
+// `workgroup_count_from_dag_root` using
+// `transform.iree.tile_to_forall_and_workgroup_count_region`
+
+
+#executable_target_embedded_elf_x86_64_ = #hal.executable.target<"llvm-cpu", "embedded-elf-x86_64", {cpu = "generic", cpu_features = "", data_layout = "e-m:e-p270:32:32-p271:32:32-p272:64:64-i64:64-f80:128-n8:16:32:64-S128", native_vector_size = 16 : index, target_triple = "x86_64-unknown-unknown-eabi-elf"}>
+#pipeline_layout = #hal.pipeline.layout<push_constants = 0, sets = [<0, bindings = [<0, storage_buffer, ReadOnly>, <1, storage_buffer, ReadOnly>, <2, storage_buffer>]>]>
+
+// Check that num_threads (32) is reflected in the map.
+// CHECK: #[[$NUM_THREADS_MAP:.*]] = affine_map<(d0) -> (d0 * 32)>
+
+hal.executable private @matmul_static_dispatch_0 {
+ hal.executable.variant public @embedded_elf_x86_64, target = #executable_target_embedded_elf_x86_64_ {
+
+ hal.executable.export public @matmul_static_dispatch_0_matmul_1024x4096x12345 ordinal(0) layout(#pipeline_layout) {
+ // Check that num_threads is reflected in the workgroup size.
+ // CHECK-LABEL: hal.executable.export public @matmul_static_dispatch_0_matmul_1024x4096x12345
+ // CHECK-DAG: %[[C32:.*]] = arith.constant 32 : index
+ // CHECK-DAG: %[[C1:.*]] = arith.constant 1 : index
+ // CHECK: hal.return %[[C32]], %[[C1]], %[[C1]] : index, index, index
+ ^bb0(%arg0: !hal.device, %arg1: index, %arg2: index, %arg3: index):
+ %x, %y, %z = flow.dispatch.workgroup_count_from_dag_root %arg1, %arg2, %arg3
+ hal.return %x, %y, %z : index, index, index
+ }
+
+ builtin.module {
+ func.func @matmul_static_dispatch_0_matmul_1024x4096x12345() {
+ // Check that the tiling matches num_threads.
+ // CHECK-LABEL: func.func @matmul_static_dispatch_0_matmul_1024x4096x12345
+ // CHECK: = scf.forall (%[[IV:.*]]) in (32) shared_outs(%{{.*}}) -> (tensor<1024x4096xf32>) {
+ // CHECK: %[[OFFSET:.*]] = affine.apply #[[$NUM_THREADS_MAP]](%[[IV]])
+ // CHECK: %extracted_slice = tensor.extract_slice %{{.*}}[%[[OFFSET]], 0] [32, 12345] [1, 1] : tensor<1024x12345xf32> to tensor<32x12345xf32>
+ %c0 = arith.constant 0 : index
+ %0 = hal.interface.binding.subspan set(0) binding(0) type(storage_buffer) alignment(64) offset(%c0) flags(ReadOnly) : !flow.dispatch.tensor<readonly:tensor<1024x12345xf32>>
+ %1 = hal.interface.binding.subspan set(0) binding(1) type(storage_buffer) alignment(64) offset(%c0) flags(ReadOnly) : !flow.dispatch.tensor<readonly:tensor<12345x4096xf32>>
+ %2 = hal.interface.binding.subspan set(0) binding(2) type(storage_buffer) alignment(64) offset(%c0) : !flow.dispatch.tensor<readwrite:tensor<1024x4096xf32>>
+ %3 = flow.dispatch.tensor.load %0, offsets = [0, 0], sizes = [1024, 12345], strides = [1, 1] : !flow.dispatch.tensor<readonly:tensor<1024x12345xf32>> -> tensor<1024x12345xf32>
+ %4 = flow.dispatch.tensor.load %1, offsets = [0, 0], sizes = [12345, 4096], strides = [1, 1] : !flow.dispatch.tensor<readonly:tensor<12345x4096xf32>> -> tensor<12345x4096xf32>
+ %5 = flow.dispatch.tensor.load %2, offsets = [0, 0], sizes = [1024, 4096], strides = [1, 1] : !flow.dispatch.tensor<readwrite:tensor<1024x4096xf32>> -> tensor<1024x4096xf32>
+ %6 = linalg.matmul ins(%3, %4 : tensor<1024x12345xf32>, tensor<12345x4096xf32>) outs(%5 : tensor<1024x4096xf32>) -> tensor<1024x4096xf32>
+ flow.dispatch.tensor.store %6, %2, offsets = [0, 0], sizes = [1024, 4096], strides = [1, 1] : tensor<1024x4096xf32> -> !flow.dispatch.tensor<readwrite:tensor<1024x4096xf32>>
+ return
+ }
+ }
+ }
+}
+
+transform.structured.canonicalized_sequence failures(propagate) {
+^bb1(%variant_op: !pdl.operation):
+ %original_matmul = transform.structured.match ops{["linalg.matmul"]} in %variant_op
+ : (!pdl.operation) -> !pdl.operation
+
+ %forall, %matmul =
+ transform.iree.tile_to_forall_and_workgroup_count_region %original_matmul
+ num_threads [32]
+ ( mapping = [#gpu.block<x>] )
+}