Relax block size constraints: dynamic reduction-v3 @peak (#11547)
This is a cherry-pick of #11530 onto main.
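For reference, this is the kind of fully dynamic-shaped reduction the transform-dialect path now accepts (a minimal sketch adapted from the updated reduction_v3.mlir test below; the linalg.generic iterator_types and body are not shown in the diff and are assumed here to be the usual f32 add-reduction):

!in_tensor_t = tensor<?x?xf32>
!out_tensor_t = tensor<?xf32>

func.func @reduce(%arg : !in_tensor_t) -> (!out_tensor_t) {
  %c0 = arith.constant 0 : index
  %cst = arith.constant -0.000000e+00 : f32

  // Both dimensions are dynamic; the output size is taken from the input.
  %d0 = tensor.dim %arg, %c0 : !in_tensor_t
  %0 = tensor.empty(%d0) : !out_tensor_t
  %1 = linalg.fill ins(%cst : f32) outs(%0 : !out_tensor_t) -> !out_tensor_t
  %2 = linalg.generic {
    indexing_maps = [affine_map<(d0, d1) -> (d0, d1)>,
                     affine_map<(d0, d1) -> (d0)>],
    iterator_types = ["parallel", "reduction"]}
    ins(%arg : !in_tensor_t) outs(%1 : !out_tensor_t) {
      ^bb0(%in: f32, %out: f32):
        // Assumed reduction body: accumulate along the second dimension.
        %3 = arith.addf %in, %out : f32
        linalg.yield %3 : f32
    } -> !out_tensor_t
  return %2 : !out_tensor_t
}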
diff --git a/compiler/src/iree/compiler/Codegen/Common/TransformExtensions/CommonExtensions.cpp b/compiler/src/iree/compiler/Codegen/Common/TransformExtensions/CommonExtensions.cpp
index 1e1e953..c8bed8d 100644
--- a/compiler/src/iree/compiler/Codegen/Common/TransformExtensions/CommonExtensions.cpp
+++ b/compiler/src/iree/compiler/Codegen/Common/TransformExtensions/CommonExtensions.cpp
@@ -294,17 +294,7 @@
if (foreachThreadOp.getNumThreads().size() > 3)
return foreachThreadOp->emitError(
"scf.foreach_thread with rank > 3 does not lower to workgroup");
- if (llvm::any_of(foreachThreadOp.getNumThreads(), [](Value v) {
- return !v.getDefiningOp<arith::ConstantIndexOp>();
- })) {
- return foreachThreadOp->emitError(
- "unsupported dynamic workgroup_count atm --- need to slice out "
- "workgroup_count computation into ExecutableExport::workgroup_count. "
- "This region may require arbitrary computations and cannot magically "
- "match what the `stream.cmd.dispatch` has already imposed on us at a "
- "distance. For now we must specify the number of values properly "
- "when applying the topLevel tile_to_foreach_thread_op");
- }
+
if (!foreachThreadOp.getMapping().has_value())
return foreachThreadOp->emitError("mapping must be present");
SmallVector<Attribute> blockMapping =
@@ -335,9 +325,6 @@
};
SmallVector<Value> gridDimValues = scf::ForeachThreadOp::getValuesSortedByKey(
blockMapping, numBlocks, comparator);
- SmallVector<int64_t> gridDims;
- for (Value v : gridDimValues)
- gridDims.push_back(v.getDefiningOp<arith::ConstantIndexOp>().value());
// Step 3. Outline the compute workload region and set up the workload
// operands, if this has not been done already.
@@ -352,6 +339,18 @@
// the flow level and explicitly match the ops we want to fuse.
// Once fusion is customizable enough in perpetuity, we can retire this.
if (exportOp.getWorkgroupCount().empty()) {
+ if (llvm::any_of(foreachThreadOp.getNumThreads(), [](Value v) {
+ return !v.getDefiningOp<arith::ConstantIndexOp>();
+ })) {
+ return foreachThreadOp->emitError(
+ "unsupported dynamic workgroup_count atm --- need to slice out "
+ "workgroup_count computation into ExecutableExport::workgroup_count."
+ "\nThis region may require arbitrary computations and cannot "
+ "magically match what the `stream.cmd.dispatch` has already imposed "
+ "on us at a distance."
+ "\nFor now we must specify the number of values properly when "
+ "applying the topLevel tile_to_foreach_thread_op");
+ }
if (failed(populateWorkgroupCountComputingRegion(rewriter, foreachThreadOp,
exportOp))) {
return foreachThreadOp->emitOpError(
diff --git a/tests/transform_dialect/cuda/reduction_v3.mlir b/tests/transform_dialect/cuda/reduction_v3.mlir
index 2ede7af..9f58ad3 100644
--- a/tests/transform_dialect/cuda/reduction_v3.mlir
+++ b/tests/transform_dialect/cuda/reduction_v3.mlir
@@ -1,10 +1,12 @@
-!in_tensor_t = tensor<33x?xf32>
-!out_tensor_t = tensor<33xf32>
+!in_tensor_t = tensor<?x?xf32>
+!out_tensor_t = tensor<?xf32>
func.func @reduce(%arg : !in_tensor_t) -> (!out_tensor_t) {
+ %c0 = arith.constant 0 : index
%cst = arith.constant -0.000000e+00 : f32
-
- %0 = tensor.empty() : !out_tensor_t
+
+ %d0 = tensor.dim %arg, %c0 : !in_tensor_t
+ %0 = tensor.empty(%d0) : !out_tensor_t
%1 = linalg.fill ins(%cst : f32) outs(%0 : !out_tensor_t) -> !out_tensor_t
%2 = linalg.generic {
indexing_maps = [affine_map<(d0, d1) -> (d0, d1)>,
@@ -29,7 +31,7 @@
// RUN: iree-compile %s --iree-hal-target-backends=cuda \
// RUN: --iree-codegen-llvmgpu-use-transform-dialect=%p/reduction_v3_codegen_spec.mlir | \
-// RUN: iree-run-module --entry_function=reduce --device=cuda --function_input="33x1024xf32=1" |\
+// RUN: iree-run-module --entry_function=reduce --device=cuda --function_input="123x4567xf32=1" |\
// RUN: FileCheck %s --check-prefix=EXEC
// CHECK-DAG: %[[C0:.*]] = arith.constant 0 : index
@@ -40,7 +42,7 @@
// CHECK: %[[SHMEM_VIEW_EXPANDED:.*]] = memref.subview %[[SHMEM_ALLOC]][0, %[[TIDX]]]{{.*}}to memref<1x1xf32, strided<[1024, 1], offset: ?>, 3>
// Local per-thread scf.for-based reduction.
// CHECK: scf.for
- // CHECK: vector.transfer_read %{{.*}} {in_bounds = [true]} : memref<33x?xf32>, vector<1xf32>
+ // CHECK: vector.transfer_read %{{.*}} {in_bounds = [true]} : memref<?x?xf32>, vector<1xf32>
// CHECK: arith.addf {{.*}} : vector<1xf32>
// CHECK: scf.yield %{{.*}} : vector<1xf32>
@@ -57,6 +59,6 @@
// CHECK: vector.transfer_write %[[RES_VEC]]
// CHECK: gpu.barrier
-// only checking the first 6 of 33
+// only checking the first 6 of 123
// EXEC: result[0]: hal.buffer_view
-// EXEC-NEXT: 33xf32=1024 1024 1024 1024 1024 1024
+// EXEC-NEXT: 123xf32=4567 4567 4567 4567 4567 4567
diff --git a/tests/transform_dialect/cuda/reduction_v3_codegen_spec.mlir b/tests/transform_dialect/cuda/reduction_v3_codegen_spec.mlir
index 48421ce..93107bb 100644
--- a/tests/transform_dialect/cuda/reduction_v3_codegen_spec.mlir
+++ b/tests/transform_dialect/cuda/reduction_v3_codegen_spec.mlir
@@ -1,6 +1,6 @@
// RUN: iree-opt %s
-transform.structured.canonicalized_sequence failures(suppress) {
+transform.structured.canonicalized_sequence failures(propagate) {
^bb1(%variant_op: !pdl.operation):
%fill = transform.structured.match ops{["linalg.fill"]} in %variant_op
%reduction = transform.structured.match ops{["linalg.generic"]} in %variant_op
@@ -13,7 +13,7 @@
transform.structured.fuse_into_containing_op %fill into %foreach_thread_grid
// Step 2. Split the reduction to get meatier parallelism.
- // This also parallelizes to threads
+ // This also parallelizes to threads.
// ===========================================================================
%foreach_thread, %block_more_parallel_fill_op_2, %block_more_parallel_op_2, %block_combiner_op_2 =
transform.structured.tile_reduction_using_foreach_thread %grid_reduction
@@ -24,7 +24,7 @@
// block_combiner_op_2 op is [parallel, reduction] of 1x384 that cannot fuse.
// map the 1-dim to threadIdx.y to trigger mapping of the reduction to
// threadIdx.x via predication via `if (x==0)`.
- transform.structured.tile_to_foreach_thread_op %block_combiner_op_2 tile_sizes [1]
+ transform.structured.tile_to_foreach_thread_op %block_combiner_op_2 num_threads [1]
( mapping = [#gpu.thread<y>] )
// Step 3. Rank-reduce and vectorize.