[Codegen][Nearly NFC] Move PropagateDispatchSizeBounds to Common/ (#19650)
As part of refactoring the single-iteration loop remover to use the
ValueBoundsOpInterface, I'll want to annotate workgroup IDs and counts
with their `upper_bound`s before removing those loops, instead of having
the pass look that information up from context through a custom function.
Since other code, like the CPU backend, also uses the single-iteration
loop remover and workgroup counts (but not, I think, workitems), I've
generalized the annotation pass to not require a GPU target, falling
back to not adding a bound when one can't be inferred. I've also moved
the pass into Codegen/Common (and renamed it accordingly) so it can be
called from non-GPU flows.
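
For reference, a minimal sketch of what the annotation looks like (illustrative
only; the concrete bound values here are hypothetical, though the `upper_bound`
syntax matches the tests in this change):

    func.func @example() {
      // With statically known workgroup counts [32, 8, 1], the pass attaches
      // the bounds directly to the ops:
      %workgroup_id_x = hal.interface.workgroup.id[0] upper_bound 32 : index
      %workgroup_count_x = hal.interface.workgroup.count[0] upper_bound 32 : index
      // When no bound can be inferred (e.g. dynamic counts on a non-GPU
      // target), the op is simply left unannotated:
      %workgroup_id_y = hal.interface.workgroup.id[1] : index
      return
    }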
diff --git a/compiler/src/iree/compiler/Codegen/Common/BUILD.bazel b/compiler/src/iree/compiler/Codegen/Common/BUILD.bazel
index 8582cf9..940477d 100644
--- a/compiler/src/iree/compiler/Codegen/Common/BUILD.bazel
+++ b/compiler/src/iree/compiler/Codegen/Common/BUILD.bazel
@@ -138,6 +138,7 @@
"PassUtils.cpp",
"Passes.cpp",
"PolynomialApproximationPass.cpp",
+ "PropagateDispatchSizeBounds.cpp",
"PropagateReshapesByExpansion.cpp",
"ReconcileTranslationInfo.cpp",
"RematerializeParallelOps.cpp",
diff --git a/compiler/src/iree/compiler/Codegen/Common/CMakeLists.txt b/compiler/src/iree/compiler/Codegen/Common/CMakeLists.txt
index 1dd9f91..d6e2528 100644
--- a/compiler/src/iree/compiler/Codegen/Common/CMakeLists.txt
+++ b/compiler/src/iree/compiler/Codegen/Common/CMakeLists.txt
@@ -130,6 +130,7 @@
"PassUtils.cpp"
"Passes.cpp"
"PolynomialApproximationPass.cpp"
+ "PropagateDispatchSizeBounds.cpp"
"PropagateReshapesByExpansion.cpp"
"ReconcileTranslationInfo.cpp"
"RematerializeParallelOps.cpp"
diff --git a/compiler/src/iree/compiler/Codegen/Common/GPU/BUILD.bazel b/compiler/src/iree/compiler/Codegen/Common/GPU/BUILD.bazel
index f4f8c41..c35b511 100644
--- a/compiler/src/iree/compiler/Codegen/Common/GPU/BUILD.bazel
+++ b/compiler/src/iree/compiler/Codegen/Common/GPU/BUILD.bazel
@@ -72,7 +72,6 @@
"GPUPatterns.cpp",
"GPUPipelining.cpp",
"GPUPromoteMatmulOperands.cpp",
- "GPUPropagateDispatchSizeBounds.cpp",
"GPUReduceBankConflicts.cpp",
"GPUReuseSharedMemoryAllocs.cpp",
"GPUTensorAlloc.cpp",
diff --git a/compiler/src/iree/compiler/Codegen/Common/GPU/CMakeLists.txt b/compiler/src/iree/compiler/Codegen/Common/GPU/CMakeLists.txt
index 4c1422e..51576eb 100644
--- a/compiler/src/iree/compiler/Codegen/Common/GPU/CMakeLists.txt
+++ b/compiler/src/iree/compiler/Codegen/Common/GPU/CMakeLists.txt
@@ -70,7 +70,6 @@
"GPUPatterns.cpp"
"GPUPipelining.cpp"
"GPUPromoteMatmulOperands.cpp"
- "GPUPropagateDispatchSizeBounds.cpp"
"GPUReduceBankConflicts.cpp"
"GPUReuseSharedMemoryAllocs.cpp"
"GPUTensorAlloc.cpp"
diff --git a/compiler/src/iree/compiler/Codegen/Common/GPU/GPUPropagateDispatchSizeBounds.cpp b/compiler/src/iree/compiler/Codegen/Common/GPU/GPUPropagateDispatchSizeBounds.cpp
deleted file mode 100644
index 43aa70b..0000000
--- a/compiler/src/iree/compiler/Codegen/Common/GPU/GPUPropagateDispatchSizeBounds.cpp
+++ /dev/null
@@ -1,103 +0,0 @@
-// Copyright 2024 The IREE Authors
-//
-// Licensed under the Apache License v2.0 with LLVM Exceptions.
-// See https://llvm.org/LICENSE.txt for license information.
-// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
-
-#include "iree/compiler/Codegen/Common/GPU/Passes.h"
-#include "iree/compiler/Codegen/Dialect/Codegen/IR/IREECodegenAttrs.h"
-#include "iree/compiler/Codegen/Utils/GPUUtils.h"
-#include "iree/compiler/Codegen/Utils/Utils.h"
-#include "iree/compiler/Dialect/HAL/IR/HALOps.h"
-#include "mlir/Dialect/GPU/IR/GPUDialect.h"
-#include "mlir/Interfaces/FunctionInterfaces.h"
-#include "mlir/Transforms/Passes.h"
-
-namespace mlir::iree_compiler {
-
-#define GEN_PASS_DEF_GPUPROPAGATEDISPATCHSIZEBOUNDSPASS
-#include "iree/compiler/Codegen/Common/GPU/Passes.h.inc"
-
-namespace {
-
-static void applyBounds(FunctionOpInterface funcOp,
- ArrayRef<int32_t> workgroupSizes,
- ArrayRef<int32_t> workgroupCounts) {
- Builder b(funcOp->getContext());
- funcOp->walk([&](Operation *op) {
- TypeSwitch<Operation *>(op)
- .Case([&](gpu::ThreadIdOp tidOp) {
- tidOp.setUpperBoundAttr(b.getIndexAttr(
- workgroupSizes[static_cast<uint32_t>(tidOp.getDimension())]));
- })
- .Case([&](IREE::HAL::InterfaceWorkgroupSizeOp wgSizeOp) {
- wgSizeOp.setUpperBoundAttr(b.getIndexAttr(
- workgroupSizes[wgSizeOp.getDimension().getZExtValue()]));
- })
- .Case([&](IREE::HAL::InterfaceWorkgroupIDOp wgIdOp) {
- wgIdOp.setUpperBoundAttr(b.getIndexAttr(
- workgroupCounts[wgIdOp.getDimension().getZExtValue()]));
- })
- .Case([&](IREE::HAL::InterfaceWorkgroupCountOp wgCountOp) {
- wgCountOp.setUpperBoundAttr(b.getIndexAttr(
- workgroupCounts[wgCountOp.getDimension().getZExtValue()]));
- })
- .Default([](Operation *) {});
- });
-}
-
-struct GPUPropagateDispatchSizeBoundsPass final
- : impl::GPUPropagateDispatchSizeBoundsPassBase<
- GPUPropagateDispatchSizeBoundsPass> {
- using Base::Base;
-
- void runOnOperation() override {
- FunctionOpInterface funcOp = getOperation();
- IREE::GPU::TargetAttr target = getGPUTargetAttr(funcOp);
- if (!target) {
- funcOp.emitWarning("no known target attribute late in GPU codegen");
- return;
- }
- SmallVector<int32_t, 3> workgroupSizes(
- target.getWgp().getMaxWorkgroupSizes().asArrayRef());
- SmallVector<int32_t, 3> workgroupCounts(
- target.getWgp().getMaxWorkgroupCounts().asArrayRef());
-
- std::optional<SmallVector<int64_t>> staticWorkgroupSize =
- getWorkgroupSize(funcOp);
-
- // Late in codegen, we've reconciled the workgroup size onto the export op.
- if (std::optional<IREE::HAL::ExecutableExportOp> exportOp =
- getEntryPoint(funcOp)) {
- if (std::optional<ArrayAttr> exportWorkgroupSize =
- exportOp->getWorkgroupSize()) {
- staticWorkgroupSize =
- llvm::map_to_vector(exportWorkgroupSize->getAsRange<IntegerAttr>(),
- [](IntegerAttr a) { return a.getInt(); });
- }
- }
-
- if (staticWorkgroupSize) {
- // Target info with no workgroup sizes gives a 0-length array, hence no
- // zip_equal.
- for (auto [size, staticSize] :
- llvm::zip(workgroupSizes, *staticWorkgroupSize)) {
- size = staticSize;
- }
- }
- SmallVector<int64_t> staticWorkgroupCounts = getStaticNumWorkgroups(funcOp);
- assert(staticWorkgroupCounts.size() <= 3 &&
- "workgroup counts are 3D at most");
- for (auto [count, staticCount] :
- llvm::zip(workgroupCounts, staticWorkgroupCounts)) {
- if (staticCount != ShapedType::kDynamic) {
- count = staticCount;
- }
- }
-
- applyBounds(funcOp, workgroupSizes, workgroupCounts);
- }
-};
-} // namespace
-
-} // namespace mlir::iree_compiler
diff --git a/compiler/src/iree/compiler/Codegen/Common/GPU/Passes.td b/compiler/src/iree/compiler/Codegen/Common/GPU/Passes.td
index 24552cb..3a71759 100644
--- a/compiler/src/iree/compiler/Codegen/Common/GPU/Passes.td
+++ b/compiler/src/iree/compiler/Codegen/Common/GPU/Passes.td
@@ -180,11 +180,6 @@
];
}
-def GPUPropagateDispatchSizeBoundsPass :
- InterfacePass<"iree-codegen-gpu-propagate-dispatch-size-bounds", "mlir::FunctionOpInterface"> {
- let summary = "Pass to annotate workitem and workgroup IDs with known bounds";
-}
-
def GPUReduceBankConflictsPass :
InterfacePass<"iree-codegen-gpu-reduce-bank-conflicts", "mlir::FunctionOpInterface"> {
let summary = "Pass to try to reduce the number of bank conflicts by padding memref.alloc ops.";
diff --git a/compiler/src/iree/compiler/Codegen/Common/GPU/test/BUILD.bazel b/compiler/src/iree/compiler/Codegen/Common/GPU/test/BUILD.bazel
index 2f3b092..a43b450 100644
--- a/compiler/src/iree/compiler/Codegen/Common/GPU/test/BUILD.bazel
+++ b/compiler/src/iree/compiler/Codegen/Common/GPU/test/BUILD.bazel
@@ -38,7 +38,6 @@
"gpu_pad_operands.mlir",
"gpu_pipeline.mlir",
"gpu_promote_matmul_operands.mlir",
- "gpu_propagate_dispatch_size_bounds.mlir",
"gpu_reorder_workgroups_static.mlir",
"gpu_reorder_workgroups.mlir",
"gpu_reuse_shared_memory_allocs.mlir",
diff --git a/compiler/src/iree/compiler/Codegen/Common/GPU/test/CMakeLists.txt b/compiler/src/iree/compiler/Codegen/Common/GPU/test/CMakeLists.txt
index 50be391..8efb734 100644
--- a/compiler/src/iree/compiler/Codegen/Common/GPU/test/CMakeLists.txt
+++ b/compiler/src/iree/compiler/Codegen/Common/GPU/test/CMakeLists.txt
@@ -34,7 +34,6 @@
"gpu_pad_operands.mlir"
"gpu_pipeline.mlir"
"gpu_promote_matmul_operands.mlir"
- "gpu_propagate_dispatch_size_bounds.mlir"
"gpu_reorder_workgroups.mlir"
"gpu_reorder_workgroups_static.mlir"
"gpu_reuse_shared_memory_allocs.mlir"
diff --git a/compiler/src/iree/compiler/Codegen/Common/GPU/test/gpu_propagate_dispatch_size_bounds.mlir b/compiler/src/iree/compiler/Codegen/Common/GPU/test/gpu_propagate_dispatch_size_bounds.mlir
deleted file mode 100644
index f26f2c5..0000000
--- a/compiler/src/iree/compiler/Codegen/Common/GPU/test/gpu_propagate_dispatch_size_bounds.mlir
+++ /dev/null
@@ -1,122 +0,0 @@
-// RUN: iree-opt %s --split-input-file \
-// RUN: --pass-pipeline="builtin.module(hal.executable(hal.executable.variant(builtin.module(func.func(iree-codegen-gpu-propagate-dispatch-size-bounds)))))" \
-// RUN: | FileCheck %s
-
-// Note: not the real target definition, missing types
-#executable_target = #hal.executable.target<"rocm", "rocm-hsaco-fb", {iree.gpu.target = #iree_gpu.target<arch = "gfx1100", features = "",
- wgp = <compute = fp32,
- storage = b32,
- subgroup = arithmetic,
- dot = none, mma = [],
- subgroup_size_choices = [32, 64],
- max_workgroup_sizes = [1024, 1024, 1024],
- max_thread_count_per_workgroup = 1024,
- max_workgroup_memory_bytes = 65536,
- max_workgroup_counts = [2147483647, 2147483647, 2147483647]>>}>
-#pipeline_layout = #hal.pipeline.layout<bindings = [#hal.pipeline.binding<storage_buffer>]>
-
-hal.executable private @static {
- hal.executable.variant public @rocm_hsaco_fb target(#executable_target) {
- hal.executable.export public @static ordinal(0) layout(#pipeline_layout) attributes {workgroup_size = [64 : index, 2 : index, 1 : index]} {
- ^bb0(%arg0: !hal.device):
- %c32 = arith.constant 32 : index
- %c8 = arith.constant 8 : index
- %c1 = arith.constant 1 : index
- hal.return %c32, %c8, %c1 : index, index, index
- }
- builtin.module {
-// CHECK-LABEL: func.func @static
- func.func @static() {
-// CHECK: gpu.thread_id x upper_bound 64
-// CHECK: gpu.thread_id y upper_bound 2
-// CHECK: gpu.thread_id z upper_bound 1
- %thread_id_x = gpu.thread_id x
- %thread_id_y = gpu.thread_id y
- %thread_id_z = gpu.thread_id z
-
-// CHECK: hal.interface.workgroup.size[0] upper_bound 64
-// CHECK: hal.interface.workgroup.size[1] upper_bound 2
-// CHECK: hal.interface.workgroup.size[2] upper_bound 1
- %workgroup_size_x = hal.interface.workgroup.size[0] : index
- %workgroup_size_y = hal.interface.workgroup.size[1] : index
- %workgroup_size_z = hal.interface.workgroup.size[2] : index
-
-// CHECK: hal.interface.workgroup.id[0] upper_bound 32
-// CHECK: hal.interface.workgroup.id[1] upper_bound 8
-// CHECK: hal.interface.workgroup.id[2] upper_bound 1
- %workgroup_id_x = hal.interface.workgroup.id[0] : index
- %workgroup_id_y = hal.interface.workgroup.id[1] : index
- %workgroup_id_z = hal.interface.workgroup.id[2] : index
-
-// CHECK: hal.interface.workgroup.count[0] upper_bound 32
-// CHECK: hal.interface.workgroup.count[1] upper_bound 8
-// CHECK: hal.interface.workgroup.count[2] upper_bound 1
- %workgroup_conut_x = hal.interface.workgroup.count[0] : index
- %workgroup_count_y = hal.interface.workgroup.count[1] : index
- %workgroup_count_z = hal.interface.workgroup.count[2] : index
-
- return
- }
- }
- }
-}
-
-// -----
-
-#executable_target = #hal.executable.target<"rocm", "rocm-hsaco-fb",
- {iree.gpu.target = #iree_gpu.target<arch = "gfx1100", features = "",
- wgp = <compute = fp32,
- storage = b32,
- subgroup = arithmetic,
- dot = none, mma = [],
- subgroup_size_choices = [32, 64],
- max_workgroup_sizes = [1024, 1024, 1024],
- max_thread_count_per_workgroup = 1024,
- max_workgroup_memory_bytes = 65536,
- max_workgroup_counts = [2147483647, 2147483647, 2147483647]>>}>
-#pipeline_layout = #hal.pipeline.layout<bindings = [#hal.pipeline.binding<storage_buffer>]>
-
-hal.executable private @dynamic {
- hal.executable.variant public @rocm_hsaco_fb target(#executable_target) {
- hal.executable.export public @dynamic ordinal(0) layout(#pipeline_layout) {
- ^bb0(%arg0: !hal.device, %arg1: index, %arg2: index):
- %count_x = affine.apply affine_map<()[s0] -> (s0 ceildiv 32)>()[%arg1]
- %count_y = affine.apply affine_map<()[s0] -> (s0 ceildiv 8)>()[%arg2]
- %count_z = arith.constant 1 : index
- hal.return %count_x, %count_y, %count_z : index, index, index
- }
- builtin.module {
- func.func @dynamic() {
-// CHECK: gpu.thread_id x upper_bound 1024
-// CHECK: gpu.thread_id y upper_bound 1024
-// CHECK: gpu.thread_id z upper_bound 1024
- %thread_id_x = gpu.thread_id x
- %thread_id_y = gpu.thread_id y
- %thread_id_z = gpu.thread_id z
-
-// CHECK: hal.interface.workgroup.size[0] upper_bound 1024
-// CHECK: hal.interface.workgroup.size[1] upper_bound 1024
-// CHECK: hal.interface.workgroup.size[2] upper_bound 1024
- %workgroup_size_x = hal.interface.workgroup.size[0] : index
- %workgroup_size_y = hal.interface.workgroup.size[1] : index
- %workgroup_size_z = hal.interface.workgroup.size[2] : index
-
-// CHECK: hal.interface.workgroup.id[0] upper_bound 2147483647
-// CHECK: hal.interface.workgroup.id[1] upper_bound 2147483647
-// CHECK: hal.interface.workgroup.id[2] upper_bound 1
- %workgroup_id_x = hal.interface.workgroup.id[0] : index
- %workgroup_id_y = hal.interface.workgroup.id[1] : index
- %workgroup_id_z = hal.interface.workgroup.id[2] : index
-
-// CHECK: hal.interface.workgroup.count[0] upper_bound 2147483647
-// CHECK: hal.interface.workgroup.count[1] upper_bound 2147483647
-// CHECK: hal.interface.workgroup.count[2] upper_bound 1
- %workgroup_conut_x = hal.interface.workgroup.count[0] : index
- %workgroup_count_y = hal.interface.workgroup.count[1] : index
- %workgroup_count_z = hal.interface.workgroup.count[2] : index
-
- return
- }
- }
- }
-}
diff --git a/compiler/src/iree/compiler/Codegen/Common/Passes.td b/compiler/src/iree/compiler/Codegen/Common/Passes.td
index 1854279..7188de2 100644
--- a/compiler/src/iree/compiler/Codegen/Common/Passes.td
+++ b/compiler/src/iree/compiler/Codegen/Common/Passes.td
@@ -531,6 +531,11 @@
let summary = "Convert math operations to their polynomial approximation";
}
+def PropagateDispatchSizeBoundsPass :
+ InterfacePass<"iree-codegen-propagate-dispatch-size-bounds", "mlir::FunctionOpInterface"> {
+ let summary = "Pass to annotate workitem and workgroup IDs with known bounds";
+}
+
def PropagateReshapesByExpansionPass :
Pass<"iree-codegen-propagate-reshapes-by-expansion", ""> {
let summary = "Propagates reshaping operations by expansion.";
diff --git a/compiler/src/iree/compiler/Codegen/Common/PropagateDispatchSizeBounds.cpp b/compiler/src/iree/compiler/Codegen/Common/PropagateDispatchSizeBounds.cpp
new file mode 100644
index 0000000..1ca6ecd
--- /dev/null
+++ b/compiler/src/iree/compiler/Codegen/Common/PropagateDispatchSizeBounds.cpp
@@ -0,0 +1,127 @@
+// Copyright 2024 The IREE Authors
+//
+// Licensed under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+
+#include "iree/compiler/Codegen/Common/Passes.h"
+#include "iree/compiler/Codegen/Dialect/Codegen/IR/IREECodegenAttrs.h"
+#include "iree/compiler/Codegen/Utils/GPUUtils.h"
+#include "iree/compiler/Codegen/Utils/Utils.h"
+#include "iree/compiler/Dialect/HAL/IR/HALOps.h"
+#include "mlir/Dialect/GPU/IR/GPUDialect.h"
+#include "mlir/Interfaces/FunctionInterfaces.h"
+#include "mlir/Transforms/Passes.h"
+
+namespace mlir::iree_compiler {
+
+#define GEN_PASS_DEF_PROPAGATEDISPATCHSIZEBOUNDSPASS
+#include "iree/compiler/Codegen/Common/Passes.h.inc"
+
+namespace {
+
+static void applyBounds(FunctionOpInterface funcOp,
+ ArrayRef<std::optional<int64_t>> workgroupSizes,
+ ArrayRef<std::optional<int64_t>> workgroupCounts) {
+ Builder b(funcOp->getContext());
+ funcOp->walk([&](Operation *op) {
+ TypeSwitch<Operation *>(op)
+ .Case([&](gpu::ThreadIdOp tidOp) {
+ std::optional<int64_t> bound =
+ workgroupSizes[static_cast<uint32_t>(tidOp.getDimension())];
+ if (bound) {
+ tidOp.setUpperBoundAttr(b.getIndexAttr(*bound));
+ }
+ })
+ .Case([&](gpu::BlockDimOp blockDimOp) {
+ std::optional<int64_t> bound =
+ workgroupSizes[static_cast<int32_t>(blockDimOp.getDimension())];
+ if (bound) {
+ blockDimOp.setUpperBoundAttr(b.getIndexAttr(*bound));
+ }
+ })
+ .Case([&](IREE::HAL::InterfaceWorkgroupSizeOp wgSizeOp) {
+ std::optional<int64_t> bound =
+ workgroupSizes[wgSizeOp.getDimension().getZExtValue()];
+ if (bound) {
+ wgSizeOp.setUpperBoundAttr(b.getIndexAttr(*bound));
+ }
+ })
+ .Case([&](IREE::HAL::InterfaceWorkgroupIDOp wgIdOp) {
+ std::optional<int64_t> bound =
+ workgroupCounts[wgIdOp.getDimension().getZExtValue()];
+ if (bound) {
+ wgIdOp.setUpperBoundAttr(b.getIndexAttr(*bound));
+ }
+ })
+ .Case([&](IREE::HAL::InterfaceWorkgroupCountOp wgCountOp) {
+ std::optional<int64_t> bound =
+ workgroupCounts[wgCountOp.getDimension().getZExtValue()];
+ if (bound) {
+ wgCountOp.setUpperBoundAttr(b.getIndexAttr(*bound));
+ }
+ })
+ .Default([](Operation *) {});
+ });
+}
+
+struct PropagateDispatchSizeBoundsPass final
+ : impl::PropagateDispatchSizeBoundsPassBase<
+ PropagateDispatchSizeBoundsPass> {
+ using Base::Base;
+
+ void runOnOperation() override {
+ FunctionOpInterface funcOp = getOperation();
+ SmallVector<std::optional<int64_t>, 3> workgroupSizes(3, std::nullopt);
+ SmallVector<std::optional<int64_t>, 3> workgroupCounts(3, std::nullopt);
+
+ IREE::GPU::TargetAttr target = getGPUTargetAttr(funcOp);
+ if (target) {
+ ArrayRef<int32_t> targetWorkgroupSizes =
+ target.getWgp().getMaxWorkgroupSizes().asArrayRef();
+ ArrayRef<int32_t> targetWorkgroupCounts =
+ target.getWgp().getMaxWorkgroupCounts().asArrayRef();
+ llvm::transform(targetWorkgroupSizes, workgroupSizes.begin(),
+ [](int32_t x) { return std::optional<int64_t>{x}; });
+ llvm::transform(targetWorkgroupCounts, workgroupCounts.begin(),
+ [](int32_t x) { return std::optional<int64_t>{x}; });
+ }
+
+ std::optional<SmallVector<int64_t>> staticWorkgroupSize =
+ getWorkgroupSize(funcOp);
+
+ // Late in codegen, we've reconciled the workgroup size onto the export op.
+ if (std::optional<IREE::HAL::ExecutableExportOp> exportOp =
+ getEntryPoint(funcOp)) {
+ if (std::optional<ArrayAttr> exportWorkgroupSize =
+ exportOp->getWorkgroupSize()) {
+ staticWorkgroupSize =
+ llvm::map_to_vector(exportWorkgroupSize->getAsRange<IntegerAttr>(),
+ [](IntegerAttr a) { return a.getInt(); });
+ }
+ }
+
+ if (staticWorkgroupSize) {
+ // Target info with no workgroup sizes gives a 0-length array, hence no
+ // zip_equal.
+ for (auto [size, staticSize] :
+ llvm::zip(workgroupSizes, *staticWorkgroupSize)) {
+ size = staticSize;
+ }
+ }
+ SmallVector<int64_t> staticWorkgroupCounts = getStaticNumWorkgroups(funcOp);
+ assert(staticWorkgroupCounts.size() <= 3 &&
+ "workgroup counts are 3D at most");
+ for (auto [count, staticCount] :
+ llvm::zip(workgroupCounts, staticWorkgroupCounts)) {
+ if (staticCount != ShapedType::kDynamic) {
+ count = staticCount;
+ }
+ }
+
+ applyBounds(funcOp, workgroupSizes, workgroupCounts);
+ }
+};
+} // namespace
+
+} // namespace mlir::iree_compiler
diff --git a/compiler/src/iree/compiler/Codegen/Common/test/BUILD.bazel b/compiler/src/iree/compiler/Codegen/Common/test/BUILD.bazel
index edbb5d8..43a4079 100644
--- a/compiler/src/iree/compiler/Codegen/Common/test/BUILD.bazel
+++ b/compiler/src/iree/compiler/Codegen/Common/test/BUILD.bazel
@@ -70,6 +70,7 @@
"optimize_tensor_insert_extract_slices.mlir",
"pad_dynamic_alloc.mlir",
"polynomial_approximation.mlir",
+ "propagate_dispatch_size_bounds.mlir",
"propagate_reshapes_by_expansion.mlir",
"reconcile_translation_info.mlir",
"reductions.mlir",
diff --git a/compiler/src/iree/compiler/Codegen/Common/test/CMakeLists.txt b/compiler/src/iree/compiler/Codegen/Common/test/CMakeLists.txt
index e240940..cfeef07 100644
--- a/compiler/src/iree/compiler/Codegen/Common/test/CMakeLists.txt
+++ b/compiler/src/iree/compiler/Codegen/Common/test/CMakeLists.txt
@@ -66,6 +66,7 @@
"optimize_tensor_insert_extract_slices.mlir"
"pad_dynamic_alloc.mlir"
"polynomial_approximation.mlir"
+ "propagate_dispatch_size_bounds.mlir"
"propagate_reshapes_by_expansion.mlir"
"reconcile_translation_info.mlir"
"reductions.mlir"
diff --git a/compiler/src/iree/compiler/Codegen/Common/test/propagate_dispatch_size_bounds.mlir b/compiler/src/iree/compiler/Codegen/Common/test/propagate_dispatch_size_bounds.mlir
new file mode 100644
index 0000000..eb85adc
--- /dev/null
+++ b/compiler/src/iree/compiler/Codegen/Common/test/propagate_dispatch_size_bounds.mlir
@@ -0,0 +1,204 @@
+// RUN: iree-opt %s --split-input-file \
+// RUN: --pass-pipeline="builtin.module(hal.executable(hal.executable.variant(builtin.module(func.func(iree-codegen-propagate-dispatch-size-bounds)))))" \
+// RUN: | FileCheck %s
+
+// Note: not the real target definition, missing types
+#executable_target = #hal.executable.target<"rocm", "rocm-hsaco-fb", {iree.gpu.target = #iree_gpu.target<arch = "gfx1100", features = "",
+ wgp = <compute = fp32,
+ storage = b32,
+ subgroup = arithmetic,
+ dot = none, mma = [],
+ subgroup_size_choices = [32, 64],
+ max_workgroup_sizes = [1024, 1024, 1024],
+ max_thread_count_per_workgroup = 1024,
+ max_workgroup_memory_bytes = 65536,
+ max_workgroup_counts = [2147483647, 2147483647, 2147483647]>>}>
+#pipeline_layout = #hal.pipeline.layout<bindings = [#hal.pipeline.binding<storage_buffer>]>
+
+hal.executable private @static {
+ hal.executable.variant public @rocm_hsaco_fb target(#executable_target) {
+ hal.executable.export public @static ordinal(0) layout(#pipeline_layout) attributes {workgroup_size = [64 : index, 2 : index, 1 : index]} {
+ ^bb0(%arg0: !hal.device):
+ %c32 = arith.constant 32 : index
+ %c8 = arith.constant 8 : index
+ %c1 = arith.constant 1 : index
+ hal.return %c32, %c8, %c1 : index, index, index
+ }
+ builtin.module {
+// CHECK-LABEL: func.func @static()
+ func.func @static() {
+// CHECK: gpu.thread_id x upper_bound 64
+// CHECK: gpu.thread_id y upper_bound 2
+// CHECK: gpu.thread_id z upper_bound 1
+ %thread_id_x = gpu.thread_id x
+ %thread_id_y = gpu.thread_id y
+ %thread_id_z = gpu.thread_id z
+
+// CHECK: gpu.block_dim x upper_bound 64
+// CHECK: gpu.block_dim y upper_bound 2
+// CHECK: gpu.block_dim z upper_bound 1
+ %block_dim_x = gpu.block_dim x
+ %block_dim_y = gpu.block_dim y
+ %block_dim_z = gpu.block_dim z
+
+// CHECK: hal.interface.workgroup.size[0] upper_bound 64
+// CHECK: hal.interface.workgroup.size[1] upper_bound 2
+// CHECK: hal.interface.workgroup.size[2] upper_bound 1
+ %workgroup_size_x = hal.interface.workgroup.size[0] : index
+ %workgroup_size_y = hal.interface.workgroup.size[1] : index
+ %workgroup_size_z = hal.interface.workgroup.size[2] : index
+
+// CHECK: hal.interface.workgroup.id[0] upper_bound 32
+// CHECK: hal.interface.workgroup.id[1] upper_bound 8
+// CHECK: hal.interface.workgroup.id[2] upper_bound 1
+ %workgroup_id_x = hal.interface.workgroup.id[0] : index
+ %workgroup_id_y = hal.interface.workgroup.id[1] : index
+ %workgroup_id_z = hal.interface.workgroup.id[2] : index
+
+// CHECK: hal.interface.workgroup.count[0] upper_bound 32
+// CHECK: hal.interface.workgroup.count[1] upper_bound 8
+// CHECK: hal.interface.workgroup.count[2] upper_bound 1
+ %workgroup_count_x = hal.interface.workgroup.count[0] : index
+ %workgroup_count_y = hal.interface.workgroup.count[1] : index
+ %workgroup_count_z = hal.interface.workgroup.count[2] : index
+
+ return
+ }
+ }
+ }
+}
+
+// -----
+
+#executable_target = #hal.executable.target<"rocm", "rocm-hsaco-fb",
+ {iree.gpu.target = #iree_gpu.target<arch = "gfx1100", features = "",
+ wgp = <compute = fp32,
+ storage = b32,
+ subgroup = arithmetic,
+ dot = none, mma = [],
+ subgroup_size_choices = [32, 64],
+ max_workgroup_sizes = [1024, 1024, 1024],
+ max_thread_count_per_workgroup = 1024,
+ max_workgroup_memory_bytes = 65536,
+ max_workgroup_counts = [2147483647, 2147483647, 2147483647]>>}>
+#pipeline_layout = #hal.pipeline.layout<bindings = [#hal.pipeline.binding<storage_buffer>]>
+
+hal.executable private @dynamic {
+ hal.executable.variant public @rocm_hsaco_fb target(#executable_target) {
+ hal.executable.export public @dynamic ordinal(0) layout(#pipeline_layout) {
+ ^bb0(%arg0: !hal.device, %arg1: index, %arg2: index):
+ %count_x = affine.apply affine_map<()[s0] -> (s0 ceildiv 32)>()[%arg1]
+ %count_y = affine.apply affine_map<()[s0] -> (s0 ceildiv 8)>()[%arg2]
+ %count_z = arith.constant 1 : index
+ hal.return %count_x, %count_y, %count_z : index, index, index
+ }
+ builtin.module {
+// CHECK-LABEL: func.func @dynamic()
+ func.func @dynamic() {
+// CHECK: gpu.thread_id x upper_bound 1024
+// CHECK: gpu.thread_id y upper_bound 1024
+// CHECK: gpu.thread_id z upper_bound 1024
+ %thread_id_x = gpu.thread_id x
+ %thread_id_y = gpu.thread_id y
+ %thread_id_z = gpu.thread_id z
+
+// CHECK: hal.interface.workgroup.size[0] upper_bound 1024
+// CHECK: hal.interface.workgroup.size[1] upper_bound 1024
+// CHECK: hal.interface.workgroup.size[2] upper_bound 1024
+ %workgroup_size_x = hal.interface.workgroup.size[0] : index
+ %workgroup_size_y = hal.interface.workgroup.size[1] : index
+ %workgroup_size_z = hal.interface.workgroup.size[2] : index
+
+// CHECK: hal.interface.workgroup.id[0] upper_bound 2147483647
+// CHECK: hal.interface.workgroup.id[1] upper_bound 2147483647
+// CHECK: hal.interface.workgroup.id[2] upper_bound 1
+ %workgroup_id_x = hal.interface.workgroup.id[0] : index
+ %workgroup_id_y = hal.interface.workgroup.id[1] : index
+ %workgroup_id_z = hal.interface.workgroup.id[2] : index
+
+// CHECK: hal.interface.workgroup.count[0] upper_bound 2147483647
+// CHECK: hal.interface.workgroup.count[1] upper_bound 2147483647
+// CHECK: hal.interface.workgroup.count[2] upper_bound 1
+ %workgroup_count_x = hal.interface.workgroup.count[0] : index
+ %workgroup_count_y = hal.interface.workgroup.count[1] : index
+ %workgroup_count_z = hal.interface.workgroup.count[2] : index
+
+ return
+ }
+ }
+ }
+}
+
+// -----
+
+#executable_target = #hal.executable.target<"llvm-cpu", "embedded-elf-x86_64", {cpu_features = "+avx512f", data_layout = "e-m:e-p270:32:32-p271:32:32-p272:64:64-i64:64-f80:128-n8:16:32:64-S128", native_vector_size = 16 : index, target_triple = "x86_64-unknown-linux-gnu"}>
+#pipeline_layout = #hal.pipeline.layout<bindings = [#hal.pipeline.binding<storage_buffer>]>
+
+hal.executable private @static_cpu {
+ hal.executable.variant public @embedded_elf_x86_64 target(#executable_target) {
+ hal.executable.export public @static_cpu ordinal(0) layout(#pipeline_layout) attributes {workgroup_size = [64 : index, 2 : index, 1 : index]} {
+ ^bb0(%arg0: !hal.device):
+ %c32 = arith.constant 32 : index
+ %c8 = arith.constant 8 : index
+ %c1 = arith.constant 1 : index
+ hal.return %c32, %c8, %c1 : index, index, index
+ }
+ builtin.module {
+// CHECK-LABEL: func.func @static_cpu()
+ func.func @static_cpu() {
+// CHECK: hal.interface.workgroup.id[0] upper_bound 32
+// CHECK: hal.interface.workgroup.id[1] upper_bound 8
+// CHECK: hal.interface.workgroup.id[2] upper_bound 1
+ %workgroup_id_x = hal.interface.workgroup.id[0] : index
+ %workgroup_id_y = hal.interface.workgroup.id[1] : index
+ %workgroup_id_z = hal.interface.workgroup.id[2] : index
+
+// CHECK: hal.interface.workgroup.count[0] upper_bound 32
+// CHECK: hal.interface.workgroup.count[1] upper_bound 8
+// CHECK: hal.interface.workgroup.count[2] upper_bound 1
+ %workgroup_count_x = hal.interface.workgroup.count[0] : index
+ %workgroup_count_y = hal.interface.workgroup.count[1] : index
+ %workgroup_count_z = hal.interface.workgroup.count[2] : index
+
+ return
+ }
+ }
+ }
+}
+
+// -----
+
+#executable_target = #hal.executable.target<"llvm-cpu", "embedded-elf-x86_64", {cpu_features = "+avx512f", data_layout = "e-m:e-p270:32:32-p271:32:32-p272:64:64-i64:64-f80:128-n8:16:32:64-S128", native_vector_size = 16 : index, target_triple = "x86_64-unknown-linux-gnu"}>
+#pipeline_layout = #hal.pipeline.layout<bindings = [#hal.pipeline.binding<storage_buffer>]>
+
+hal.executable private @dynamic_cpu {
+ hal.executable.variant public @embedded_elf_x86_64 target(#executable_target) {
+ hal.executable.export public @dynamic_cpu ordinal(0) layout(#pipeline_layout) {
+ ^bb0(%arg0: !hal.device, %arg1: index, %arg2: index):
+ %count_x = affine.apply affine_map<()[s0] -> (s0 ceildiv 32)>()[%arg1]
+ %count_y = affine.apply affine_map<()[s0] -> (s0 ceildiv 8)>()[%arg2]
+ %count_z = arith.constant 1 : index
+ hal.return %count_x, %count_y, %count_z : index, index, index
+ }
+ builtin.module {
+// CHECK-LABEL: @dynamic_cpu()
+ func.func @dynamic_cpu() {
+// CHECK: hal.interface.workgroup.id[0] : index
+// CHECK: hal.interface.workgroup.id[1] : index
+// CHECK: hal.interface.workgroup.id[2] upper_bound 1 : index
+ %workgroup_id_x = hal.interface.workgroup.id[0] : index
+ %workgroup_id_y = hal.interface.workgroup.id[1] : index
+ %workgroup_id_z = hal.interface.workgroup.id[2] : index
+
+// CHECK: hal.interface.workgroup.count[0] : index
+// CHECK: hal.interface.workgroup.count[1] : index
+// CHECK: hal.interface.workgroup.count[2] upper_bound 1 : index
+ %workgroup_count_x = hal.interface.workgroup.count[0] : index
+ %workgroup_count_y = hal.interface.workgroup.count[1] : index
+ %workgroup_count_z = hal.interface.workgroup.count[2] : index
+
+ return
+ }
+ }
+ }
+}
diff --git a/compiler/src/iree/compiler/Codegen/LLVMGPU/Passes.cpp b/compiler/src/iree/compiler/Codegen/LLVMGPU/Passes.cpp
index d0a269e..c3dda57 100644
--- a/compiler/src/iree/compiler/Codegen/LLVMGPU/Passes.cpp
+++ b/compiler/src/iree/compiler/Codegen/LLVMGPU/Passes.cpp
@@ -1053,7 +1053,7 @@
.addPass(createLLVMGPUVectorLoweringPass)
.addPass(createExpandGPUOpsPass)
// Expose workitem and workgroup counts to range inference later.
- .addPass(createGPUPropagateDispatchSizeBoundsPass);
+ .addPass(createPropagateDispatchSizeBoundsPass);
// This pass needs to run before SCF -> CF.
addLowerAndOptimizeAddressComputationPasses(funcPassManager);
diff --git a/compiler/src/iree/compiler/Codegen/SPIRV/Passes.cpp b/compiler/src/iree/compiler/Codegen/SPIRV/Passes.cpp
index 511dbe7..2a13558 100644
--- a/compiler/src/iree/compiler/Codegen/SPIRV/Passes.cpp
+++ b/compiler/src/iree/compiler/Codegen/SPIRV/Passes.cpp
@@ -227,7 +227,7 @@
/// Adds passes to perform the final SPIR-V conversion.
static void addSPIRVLoweringPasses(OpPassManager &modulePassManager) {
FunctionLikeNest(modulePassManager)
- .addPass(createGPUPropagateDispatchSizeBoundsPass)
+ .addPass(createPropagateDispatchSizeBoundsPass)
.addPass(createCanonicalizerPass)
.addPass(createCSEPass)
.addPass(createLowerAffinePass)