Reapply "[Codegen][GPU] Add range information to GPU dispatch IDs" (#19361) (#19372)
This reverts commit cb5be1dbd3560f692578c137eadbb413b41e44c7.
Compared to the previous revision, this one works around a correctness
bug in dataflow analysis that's being fixed by removing the analysis
after SCF->CF.
---
First, this patch implements InferIntRangeInterface for
hal.interface.workgroup.{size,id,count} using a local upper_bound
attribute.
Then, it adds a -iree-codegen-gpu-propagate-dispatch-size-bounds pass
that adds these upper_bounds identifiers to the interface.workgroup
operations and to gpu.thread_id based on static information available
late in the codegen pipeline.
Then, it uses -optimize-int-arithmetic to optimize indexing after
-lower-affine, getting rid of a bunch of "if the input's negative" logic
that isn't actually needed in many of our kernels.
It also ensures that these upper_bound values propagate to LLVM.
diff --git a/compiler/src/iree/compiler/Codegen/Common/GPU/BUILD.bazel b/compiler/src/iree/compiler/Codegen/Common/GPU/BUILD.bazel
index c9e2363..128ffa9 100644
--- a/compiler/src/iree/compiler/Codegen/Common/GPU/BUILD.bazel
+++ b/compiler/src/iree/compiler/Codegen/Common/GPU/BUILD.bazel
@@ -73,6 +73,7 @@
"GPUPatterns.cpp",
"GPUPipelining.cpp",
"GPUPromoteMatmulOperands.cpp",
+ "GPUPropagateDispatchSizeBounds.cpp",
"GPUReduceBankConflicts.cpp",
"GPUReuseSharedMemoryAllocs.cpp",
"GPUTensorAlloc.cpp",
diff --git a/compiler/src/iree/compiler/Codegen/Common/GPU/CMakeLists.txt b/compiler/src/iree/compiler/Codegen/Common/GPU/CMakeLists.txt
index 2aeb9ad..97d3240 100644
--- a/compiler/src/iree/compiler/Codegen/Common/GPU/CMakeLists.txt
+++ b/compiler/src/iree/compiler/Codegen/Common/GPU/CMakeLists.txt
@@ -71,6 +71,7 @@
"GPUPatterns.cpp"
"GPUPipelining.cpp"
"GPUPromoteMatmulOperands.cpp"
+ "GPUPropagateDispatchSizeBounds.cpp"
"GPUReduceBankConflicts.cpp"
"GPUReuseSharedMemoryAllocs.cpp"
"GPUTensorAlloc.cpp"
diff --git a/compiler/src/iree/compiler/Codegen/Common/GPU/GPUPropagateDispatchSizeBounds.cpp b/compiler/src/iree/compiler/Codegen/Common/GPU/GPUPropagateDispatchSizeBounds.cpp
new file mode 100644
index 0000000..43aa70b
--- /dev/null
+++ b/compiler/src/iree/compiler/Codegen/Common/GPU/GPUPropagateDispatchSizeBounds.cpp
@@ -0,0 +1,154 @@
+// Copyright 2024 The IREE Authors
+//
+// Licensed under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+
+#include <array>
+#include <optional>
+
+#include "iree/compiler/Codegen/Common/GPU/Passes.h"
+#include "iree/compiler/Codegen/Dialect/Codegen/IR/IREECodegenAttrs.h"
+#include "iree/compiler/Codegen/Utils/GPUUtils.h"
+#include "iree/compiler/Codegen/Utils/Utils.h"
+#include "iree/compiler/Dialect/HAL/IR/HALOps.h"
+#include "mlir/Dialect/GPU/IR/GPUDialect.h"
+#include "mlir/Interfaces/FunctionInterfaces.h"
+#include "mlir/Transforms/Passes.h"
+
+namespace mlir::iree_compiler {
+
+#define GEN_PASS_DEF_GPUPROPAGATEDISPATCHSIZEBOUNDSPASS
+#include "iree/compiler/Codegen/Common/GPU/Passes.h.inc"
+
+namespace {
+
+/// Upper bounds for the x, y and z dimensions of a dispatch. An empty entry
+/// means that no bound is known for that dimension.
+using DimBounds = std::array<std::optional<int64_t>, 3>;
+
+/// Walks `funcOp` and attaches an `upper_bound` attribute to every
+/// gpu.thread_id and hal.interface.workgroup.{size,id,count} op whose
+/// dimension has a known bound. Thread IDs and workgroup sizes are bounded by
+/// `workgroupSizes`; workgroup IDs and counts by `workgroupCounts`. Ops whose
+/// dimension has no known bound are left untouched.
+static void applyBounds(FunctionOpInterface funcOp,
+                        const DimBounds &workgroupSizes,
+                        const DimBounds &workgroupCounts) {
+  Builder b(funcOp->getContext());
+  // Returns the attribute to attach for `dim`, or a null attribute when there
+  // is nothing to attach. The range check keeps an unexpected dimension from
+  // reading past the end of `bounds`.
+  auto getBoundAttr = [&](const DimBounds &bounds,
+                          uint64_t dim) -> IntegerAttr {
+    if (dim >= bounds.size() || !bounds[dim]) {
+      return IntegerAttr();
+    }
+    return b.getIndexAttr(*bounds[dim]);
+  };
+  funcOp->walk([&](Operation *op) {
+    TypeSwitch<Operation *>(op)
+        .Case([&](gpu::ThreadIdOp tidOp) {
+          if (IntegerAttr bound = getBoundAttr(
+                  workgroupSizes,
+                  static_cast<uint32_t>(tidOp.getDimension()))) {
+            tidOp.setUpperBoundAttr(bound);
+          }
+        })
+        .Case([&](IREE::HAL::InterfaceWorkgroupSizeOp wgSizeOp) {
+          if (IntegerAttr bound = getBoundAttr(
+                  workgroupSizes, wgSizeOp.getDimension().getZExtValue())) {
+            wgSizeOp.setUpperBoundAttr(bound);
+          }
+        })
+        .Case([&](IREE::HAL::InterfaceWorkgroupIDOp wgIdOp) {
+          if (IntegerAttr bound = getBoundAttr(
+                  workgroupCounts, wgIdOp.getDimension().getZExtValue())) {
+            wgIdOp.setUpperBoundAttr(bound);
+          }
+        })
+        .Case([&](IREE::HAL::InterfaceWorkgroupCountOp wgCountOp) {
+          if (IntegerAttr bound = getBoundAttr(
+                  workgroupCounts, wgCountOp.getDimension().getZExtValue())) {
+            wgCountOp.setUpperBoundAttr(bound);
+          }
+        })
+        .Default([](Operation *) {});
+  });
+}
+
+/// Annotates thread ID and workgroup ID/size/count ops with upper bounds taken
+/// from the target's limits, tightened by the static workgroup size and
+/// static workgroup counts when those are known.
+struct GPUPropagateDispatchSizeBoundsPass final
+    : impl::GPUPropagateDispatchSizeBoundsPassBase<
+          GPUPropagateDispatchSizeBoundsPass> {
+  using Base::Base;
+
+  void runOnOperation() override {
+    FunctionOpInterface funcOp = getOperation();
+    IREE::GPU::TargetAttr target = getGPUTargetAttr(funcOp);
+    if (!target) {
+      funcOp.emitWarning("no known target attribute late in GPU codegen");
+      return;
+    }
+
+    // Seed the bounds with the limits advertised by the target. Target info
+    // with no workgroup sizes gives a 0-length array, hence no zip_equal; the
+    // corresponding bounds then stay empty.
+    DimBounds workgroupSizes;
+    DimBounds workgroupCounts;
+    ArrayRef<int32_t> maxWorkgroupSizes =
+        target.getWgp().getMaxWorkgroupSizes().asArrayRef();
+    ArrayRef<int32_t> maxWorkgroupCounts =
+        target.getWgp().getMaxWorkgroupCounts().asArrayRef();
+    for (auto [size, maxSize] :
+         llvm::zip(workgroupSizes, maxWorkgroupSizes)) {
+      size = maxSize;
+    }
+    for (auto [count, maxCount] :
+         llvm::zip(workgroupCounts, maxWorkgroupCounts)) {
+      count = maxCount;
+    }
+
+    std::optional<SmallVector<int64_t>> staticWorkgroupSize =
+        getWorkgroupSize(funcOp);
+
+    // Late in codegen, we've reconciled the workgroup size onto the export op.
+    if (std::optional<IREE::HAL::ExecutableExportOp> exportOp =
+            getEntryPoint(funcOp)) {
+      if (std::optional<ArrayAttr> exportWorkgroupSize =
+              exportOp->getWorkgroupSize()) {
+        staticWorkgroupSize =
+            llvm::map_to_vector(exportWorkgroupSize->getAsRange<IntegerAttr>(),
+                                [](IntegerAttr a) { return a.getInt(); });
+      }
+    }
+
+    // A static workgroup size is exact, so it replaces the target's maximum
+    // and is used even when the target advertised none.
+    if (staticWorkgroupSize) {
+      for (auto [size, staticSize] :
+           llvm::zip(workgroupSizes, *staticWorkgroupSize)) {
+        size = staticSize;
+      }
+    }
+
+    SmallVector<int64_t> staticWorkgroupCounts = getStaticNumWorkgroups(funcOp);
+    assert(staticWorkgroupCounts.size() <= 3 &&
+           "workgroup counts are 3D at most");
+    // Dynamic counts carry no information; those dimensions keep whatever
+    // the target provided.
+    for (auto [count, staticCount] :
+         llvm::zip(workgroupCounts, staticWorkgroupCounts)) {
+      if (staticCount != ShapedType::kDynamic) {
+        count = staticCount;
+      }
+    }
+
+    applyBounds(funcOp, workgroupSizes, workgroupCounts);
+  }
+};
+} // namespace
+
+} // namespace mlir::iree_compiler
diff --git a/compiler/src/iree/compiler/Codegen/Common/GPU/Passes.td b/compiler/src/iree/compiler/Codegen/Common/GPU/Passes.td
index 7891309..b3fdd50 100644
--- a/compiler/src/iree/compiler/Codegen/Common/GPU/Passes.td
+++ b/compiler/src/iree/compiler/Codegen/Common/GPU/Passes.td
@@ -178,6 +178,11 @@
];
}
+def GPUPropagateDispatchSizeBoundsPass :
+ InterfacePass<"iree-codegen-gpu-propagate-dispatch-size-bounds", "mlir::FunctionOpInterface"> {
+ let summary = "Pass to annotate workitem and workgroup IDs with known bounds";
+}
+
def GPUReduceBankConflictsPass :
InterfacePass<"iree-codegen-gpu-reduce-bank-conflicts", "mlir::FunctionOpInterface"> {
let summary = "Pass to try to reduce the number of bank conflicts by padding memref.alloc ops.";
diff --git a/compiler/src/iree/compiler/Codegen/Common/GPU/test/BUILD.bazel b/compiler/src/iree/compiler/Codegen/Common/GPU/test/BUILD.bazel
index 41afbb6..dc8e6a1 100644
--- a/compiler/src/iree/compiler/Codegen/Common/GPU/test/BUILD.bazel
+++ b/compiler/src/iree/compiler/Codegen/Common/GPU/test/BUILD.bazel
@@ -41,6 +41,7 @@
"gpu_pad_operands.mlir",
"gpu_pipeline.mlir",
"gpu_promote_matmul_operands.mlir",
+ "gpu_propagate_dispatch_size_bounds.mlir",
"gpu_reorder_workgroups_static.mlir",
"gpu_reorder_workgroups.mlir",
"gpu_reuse_shared_memory_allocs.mlir",
diff --git a/compiler/src/iree/compiler/Codegen/Common/GPU/test/CMakeLists.txt b/compiler/src/iree/compiler/Codegen/Common/GPU/test/CMakeLists.txt
index ad86649..4dc0f28 100644
--- a/compiler/src/iree/compiler/Codegen/Common/GPU/test/CMakeLists.txt
+++ b/compiler/src/iree/compiler/Codegen/Common/GPU/test/CMakeLists.txt
@@ -37,6 +37,7 @@
"gpu_pad_operands.mlir"
"gpu_pipeline.mlir"
"gpu_promote_matmul_operands.mlir"
+ "gpu_propagate_dispatch_size_bounds.mlir"
"gpu_reorder_workgroups.mlir"
"gpu_reorder_workgroups_static.mlir"
"gpu_reuse_shared_memory_allocs.mlir"
diff --git a/compiler/src/iree/compiler/Codegen/Common/GPU/test/gpu_propagate_dispatch_size_bounds.mlir b/compiler/src/iree/compiler/Codegen/Common/GPU/test/gpu_propagate_dispatch_size_bounds.mlir
new file mode 100644
index 0000000..f26f2c5
--- /dev/null
+++ b/compiler/src/iree/compiler/Codegen/Common/GPU/test/gpu_propagate_dispatch_size_bounds.mlir
@@ -0,0 +1,122 @@
+// RUN: iree-opt %s --split-input-file \
+// RUN: --pass-pipeline="builtin.module(hal.executable(hal.executable.variant(builtin.module(func.func(iree-codegen-gpu-propagate-dispatch-size-bounds)))))" \
+// RUN: | FileCheck %s
+
+// Note: not the real target definition, missing types
+#executable_target = #hal.executable.target<"rocm", "rocm-hsaco-fb", {iree.gpu.target = #iree_gpu.target<arch = "gfx1100", features = "",
+ wgp = <compute = fp32,
+ storage = b32,
+ subgroup = arithmetic,
+ dot = none, mma = [],
+ subgroup_size_choices = [32, 64],
+ max_workgroup_sizes = [1024, 1024, 1024],
+ max_thread_count_per_workgroup = 1024,
+ max_workgroup_memory_bytes = 65536,
+ max_workgroup_counts = [2147483647, 2147483647, 2147483647]>>}>
+#pipeline_layout = #hal.pipeline.layout<bindings = [#hal.pipeline.binding<storage_buffer>]>
+
+hal.executable private @static {
+ hal.executable.variant public @rocm_hsaco_fb target(#executable_target) {
+ hal.executable.export public @static ordinal(0) layout(#pipeline_layout) attributes {workgroup_size = [64 : index, 2 : index, 1 : index]} {
+ ^bb0(%arg0: !hal.device):
+ %c32 = arith.constant 32 : index
+ %c8 = arith.constant 8 : index
+ %c1 = arith.constant 1 : index
+ hal.return %c32, %c8, %c1 : index, index, index
+ }
+ builtin.module {
+// CHECK-LABEL: func.func @static
+ func.func @static() {
+// CHECK: gpu.thread_id x upper_bound 64
+// CHECK: gpu.thread_id y upper_bound 2
+// CHECK: gpu.thread_id z upper_bound 1
+ %thread_id_x = gpu.thread_id x
+ %thread_id_y = gpu.thread_id y
+ %thread_id_z = gpu.thread_id z
+
+// CHECK: hal.interface.workgroup.size[0] upper_bound 64
+// CHECK: hal.interface.workgroup.size[1] upper_bound 2
+// CHECK: hal.interface.workgroup.size[2] upper_bound 1
+ %workgroup_size_x = hal.interface.workgroup.size[0] : index
+ %workgroup_size_y = hal.interface.workgroup.size[1] : index
+ %workgroup_size_z = hal.interface.workgroup.size[2] : index
+
+// CHECK: hal.interface.workgroup.id[0] upper_bound 32
+// CHECK: hal.interface.workgroup.id[1] upper_bound 8
+// CHECK: hal.interface.workgroup.id[2] upper_bound 1
+ %workgroup_id_x = hal.interface.workgroup.id[0] : index
+ %workgroup_id_y = hal.interface.workgroup.id[1] : index
+ %workgroup_id_z = hal.interface.workgroup.id[2] : index
+
+// CHECK: hal.interface.workgroup.count[0] upper_bound 32
+// CHECK: hal.interface.workgroup.count[1] upper_bound 8
+// CHECK: hal.interface.workgroup.count[2] upper_bound 1
+ %workgroup_conut_x = hal.interface.workgroup.count[0] : index
+ %workgroup_count_y = hal.interface.workgroup.count[1] : index
+ %workgroup_count_z = hal.interface.workgroup.count[2] : index
+
+ return
+ }
+ }
+ }
+}
+
+// -----
+
+#executable_target = #hal.executable.target<"rocm", "rocm-hsaco-fb",
+ {iree.gpu.target = #iree_gpu.target<arch = "gfx1100", features = "",
+ wgp = <compute = fp32,
+ storage = b32,
+ subgroup = arithmetic,
+ dot = none, mma = [],
+ subgroup_size_choices = [32, 64],
+ max_workgroup_sizes = [1024, 1024, 1024],
+ max_thread_count_per_workgroup = 1024,
+ max_workgroup_memory_bytes = 65536,
+ max_workgroup_counts = [2147483647, 2147483647, 2147483647]>>}>
+#pipeline_layout = #hal.pipeline.layout<bindings = [#hal.pipeline.binding<storage_buffer>]>
+
+hal.executable private @dynamic {
+ hal.executable.variant public @rocm_hsaco_fb target(#executable_target) {
+ hal.executable.export public @dynamic ordinal(0) layout(#pipeline_layout) {
+ ^bb0(%arg0: !hal.device, %arg1: index, %arg2: index):
+ %count_x = affine.apply affine_map<()[s0] -> (s0 ceildiv 32)>()[%arg1]
+ %count_y = affine.apply affine_map<()[s0] -> (s0 ceildiv 8)>()[%arg2]
+ %count_z = arith.constant 1 : index
+ hal.return %count_x, %count_y, %count_z : index, index, index
+ }
+ builtin.module {
+ func.func @dynamic() {
+// CHECK: gpu.thread_id x upper_bound 1024
+// CHECK: gpu.thread_id y upper_bound 1024
+// CHECK: gpu.thread_id z upper_bound 1024
+ %thread_id_x = gpu.thread_id x
+ %thread_id_y = gpu.thread_id y
+ %thread_id_z = gpu.thread_id z
+
+// CHECK: hal.interface.workgroup.size[0] upper_bound 1024
+// CHECK: hal.interface.workgroup.size[1] upper_bound 1024
+// CHECK: hal.interface.workgroup.size[2] upper_bound 1024
+ %workgroup_size_x = hal.interface.workgroup.size[0] : index
+ %workgroup_size_y = hal.interface.workgroup.size[1] : index
+ %workgroup_size_z = hal.interface.workgroup.size[2] : index
+
+// CHECK: hal.interface.workgroup.id[0] upper_bound 2147483647
+// CHECK: hal.interface.workgroup.id[1] upper_bound 2147483647
+// CHECK: hal.interface.workgroup.id[2] upper_bound 1
+ %workgroup_id_x = hal.interface.workgroup.id[0] : index
+ %workgroup_id_y = hal.interface.workgroup.id[1] : index
+ %workgroup_id_z = hal.interface.workgroup.id[2] : index
+
+// CHECK: hal.interface.workgroup.count[0] upper_bound 2147483647
+// CHECK: hal.interface.workgroup.count[1] upper_bound 2147483647
+// CHECK: hal.interface.workgroup.count[2] upper_bound 1
+ %workgroup_conut_x = hal.interface.workgroup.count[0] : index
+ %workgroup_count_y = hal.interface.workgroup.count[1] : index
+ %workgroup_count_z = hal.interface.workgroup.count[2] : index
+
+ return
+ }
+ }
+ }
+}
diff --git a/compiler/src/iree/compiler/Codegen/LLVMGPU/ConvertToLLVM.cpp b/compiler/src/iree/compiler/Codegen/LLVMGPU/ConvertToLLVM.cpp
index c056d44..1441f95 100644
--- a/compiler/src/iree/compiler/Codegen/LLVMGPU/ConvertToLLVM.cpp
+++ b/compiler/src/iree/compiler/Codegen/LLVMGPU/ConvertToLLVM.cpp
@@ -505,7 +505,10 @@
int32_t index = static_cast<int32_t>(op.getDimension().getSExtValue());
std::array<gpu::Dimension, 3> dimAttr{gpu::Dimension::x, gpu::Dimension::y,
gpu::Dimension::z};
- rewriter.replaceOpWithNewOp<NewOpTy>(op, op.getType(), dimAttr[index]);
+ NewOpTy newOp =
+ rewriter.replaceOpWithNewOp<NewOpTy>(op, op.getType(), dimAttr[index]);
+ if (IntegerAttr bound = op.getUpperBoundAttr())
+ newOp.setUpperBoundAttr(bound);
return success();
}
};
diff --git a/compiler/src/iree/compiler/Codegen/LLVMGPU/Passes.cpp b/compiler/src/iree/compiler/Codegen/LLVMGPU/Passes.cpp
index 53e49ef..f8ebe1c 100644
--- a/compiler/src/iree/compiler/Codegen/LLVMGPU/Passes.cpp
+++ b/compiler/src/iree/compiler/Codegen/LLVMGPU/Passes.cpp
@@ -1067,7 +1067,13 @@
.addPass(createCSEPass)
// Hoist the resulting decompositions.
.addPass(createIREELoopInvariantCodeMotionPass)
- .addPass(createLowerAffinePass);
+ .addPass(affine::createAffineExpandIndexOpsPass)
+ .addPass(createLowerAffinePass)
+ .addPass(IREE::Util::createOptimizeIntArithmeticPass)
+ // Do another round of LICM now that we've lowered and optimized
+ // arithmetic
+ .addPass(createCSEPass)
+ .addPass(createIREELoopInvariantCodeMotionPass);
}
static void addLowerToLLVMGPUPasses(OpPassManager &modulePassManager,
@@ -1103,7 +1109,9 @@
FunctionLikeNest funcPassManager(modulePassManager);
funcPassManager.addPass(createFoldTensorExtractOpPass)
.addPass(createLLVMGPUVectorLoweringPass)
- .addPass(createExpandGPUOpsPass);
+ .addPass(createExpandGPUOpsPass)
+ // Expose workitem and workgroup counts to range inference later.
+ .addPass(createGPUPropagateDispatchSizeBoundsPass);
// This pass needs to run before SCF -> CF.
addLowerAndOptimizeAddressComputationPasses(funcPassManager);
@@ -1130,9 +1138,7 @@
.addPass(memref::createExpandStridedMetadataPass)
.addPass(createEmulateNarrowTypePass)
.addPass(affine::createAffineExpandIndexOpsPass)
- .addPass(createLowerAffinePass)
- .addPass(createCanonicalizerPass)
- .addPass(createCSEPass);
+ .addPass(createLowerAffinePass);
// Strip out the debug info for the kernel.
modulePassManager.addPass(createStripDebugInfoPass());
diff --git a/compiler/src/iree/compiler/Codegen/LLVMGPU/test/nvvm_extract_address_computation.mlir b/compiler/src/iree/compiler/Codegen/LLVMGPU/test/nvvm_extract_address_computation.mlir
index 6c1c5e1..ba6b5da7 100644
--- a/compiler/src/iree/compiler/Codegen/LLVMGPU/test/nvvm_extract_address_computation.mlir
+++ b/compiler/src/iree/compiler/Codegen/LLVMGPU/test/nvvm_extract_address_computation.mlir
@@ -40,7 +40,7 @@
// CHECK-DAG: %[[C8192:.*]] = llvm.mlir.constant(8192 : index) : i64
//
// Match the interesting special registers.
-// CHECK-DAG: %[[TID_Y:.*]] = nvvm.read.ptx.sreg.tid.y : i32
+// CHECK-DAG: %[[TID_Y:.*]] = nvvm.read.ptx.sreg.tid.y range <i32, 0, 2> : i32
// CHECK-DAG: %[[TID_Y_EXT:.*]] = llvm.sext %[[TID_Y]] : i32 to i64
// CHECK-DAG: %[[LANEID:.*]] = nvvm.read.ptx.sreg.laneid range <i32, 0, 32> : i32
// CHECK-DAG: %[[LANEID_EXT:.*]] = llvm.sext %[[LANEID]] : i32 to i64
diff --git a/compiler/src/iree/compiler/Codegen/SPIRV/Passes.cpp b/compiler/src/iree/compiler/Codegen/SPIRV/Passes.cpp
index ea0aa9f..511dbe7 100644
--- a/compiler/src/iree/compiler/Codegen/SPIRV/Passes.cpp
+++ b/compiler/src/iree/compiler/Codegen/SPIRV/Passes.cpp
@@ -227,9 +227,11 @@
/// Adds passes to perform the final SPIR-V conversion.
static void addSPIRVLoweringPasses(OpPassManager &modulePassManager) {
FunctionLikeNest(modulePassManager)
+ .addPass(createGPUPropagateDispatchSizeBoundsPass)
.addPass(createCanonicalizerPass)
.addPass(createCSEPass)
.addPass(createLowerAffinePass)
+ .addPass(IREE::Util::createOptimizeIntArithmeticPass)
// Lower ApplyScale before the i64 Emulation Pass so that new 64-bit ops
// are also emulated if not supported by the target.
diff --git a/compiler/src/iree/compiler/Dialect/HAL/IR/BUILD.bazel b/compiler/src/iree/compiler/Dialect/HAL/IR/BUILD.bazel
index d9d6a92..3f80245 100644
--- a/compiler/src/iree/compiler/Dialect/HAL/IR/BUILD.bazel
+++ b/compiler/src/iree/compiler/Dialect/HAL/IR/BUILD.bazel
@@ -35,6 +35,7 @@
"//compiler/src/iree/compiler/Dialect/Util/IR:td_files",
"@llvm-project//mlir:BuiltinDialectTdFiles",
"@llvm-project//mlir:FuncTdFiles",
+ "@llvm-project//mlir:InferIntRangeInterfaceTdFiles",
"@llvm-project//mlir:InferTypeOpInterfaceTdFiles",
"@llvm-project//mlir:OpBaseTdFiles",
"@llvm-project//mlir:ViewLikeInterfaceTdFiles",
@@ -81,6 +82,7 @@
"@llvm-project//mlir:FuncDialect",
"@llvm-project//mlir:FunctionInterfaces",
"@llvm-project//mlir:IR",
+ "@llvm-project//mlir:InferIntRangeInterface",
"@llvm-project//mlir:InferTypeOpInterface",
"@llvm-project//mlir:MemRefDialect",
"@llvm-project//mlir:Parser",
diff --git a/compiler/src/iree/compiler/Dialect/HAL/IR/CMakeLists.txt b/compiler/src/iree/compiler/Dialect/HAL/IR/CMakeLists.txt
index 8378551..846bcf0 100644
--- a/compiler/src/iree/compiler/Dialect/HAL/IR/CMakeLists.txt
+++ b/compiler/src/iree/compiler/Dialect/HAL/IR/CMakeLists.txt
@@ -45,6 +45,7 @@
MLIRFuncDialect
MLIRFunctionInterfaces
MLIRIR
+ MLIRInferIntRangeInterface
MLIRInferTypeOpInterface
MLIRMemRefDialect
MLIRParser
diff --git a/compiler/src/iree/compiler/Dialect/HAL/IR/HALOps.cpp b/compiler/src/iree/compiler/Dialect/HAL/IR/HALOps.cpp
index 7210d40..cb5bb41 100644
--- a/compiler/src/iree/compiler/Dialect/HAL/IR/HALOps.cpp
+++ b/compiler/src/iree/compiler/Dialect/HAL/IR/HALOps.cpp
@@ -19,6 +19,7 @@
#include "mlir/IR/SymbolTable.h"
#include "mlir/IR/TypeUtilities.h"
#include "mlir/Interfaces/FunctionImplementation.h"
+#include "mlir/Interfaces/InferIntRangeInterface.h"
#include "mlir/Interfaces/InferTypeOpInterface.h"
namespace mlir::iree_compiler::IREE::HAL {
@@ -2084,24 +2085,59 @@
}
}
+// Sets `result`'s range to [minimum, upperBound + minimum - 1], or to
+// [minimum, signed max] if unbounded. Minimum: 0 if ID-like, 1 if count-like.
+static void setResultRangesForInterfaceWorkgroupOp(
+ Value result, const std::optional<APInt> &upperBound,
+ SetIntRangeFn setResultRanges, int64_t minimum) {
+ unsigned width = ConstantIntRanges::getStorageBitwidth(result.getType());
+ if (!upperBound.has_value()) {
+ setResultRanges(
+ result, ConstantIntRanges::fromSigned(APInt(width, minimum),
+ APInt::getSignedMaxValue(width)));
+ return;
+ }
+ setResultRanges(result,
+ ConstantIntRanges::fromUnsigned(APInt(width, minimum),
+ *upperBound + minimum - 1));
+}
+
void InterfaceWorkgroupIDOp::getAsmResultNames(
function_ref<void(Value, StringRef)> setNameFn) {
getAsmResultNamesForInterfaceWorkgroupOp("workgroup_id_", getDimension(),
getResult(), setNameFn);
}
+void InterfaceWorkgroupIDOp::inferResultRanges(
+ ArrayRef<ConstantIntRanges> argRanges, SetIntRangeFn setResultRanges) {
+ setResultRangesForInterfaceWorkgroupOp(getResult(), getUpperBound(),
+ setResultRanges, /*minimum=*/0);
+}
+
void InterfaceWorkgroupCountOp::getAsmResultNames(
function_ref<void(Value, StringRef)> setNameFn) {
getAsmResultNamesForInterfaceWorkgroupOp("workgroup_count_", getDimension(),
getResult(), setNameFn);
}
+void InterfaceWorkgroupCountOp::inferResultRanges(
+ ArrayRef<ConstantIntRanges> argRanges, SetIntRangeFn setResultRanges) {
+ setResultRangesForInterfaceWorkgroupOp(getResult(), getUpperBound(),
+ setResultRanges, /*minimum=*/1);
+}
+
void InterfaceWorkgroupSizeOp::getAsmResultNames(
function_ref<void(Value, StringRef)> setNameFn) {
getAsmResultNamesForInterfaceWorkgroupOp("workgroup_size_", getDimension(),
getResult(), setNameFn);
}
+void InterfaceWorkgroupSizeOp::inferResultRanges(
+ ArrayRef<ConstantIntRanges> argRanges, SetIntRangeFn setResultRanges) {
+ setResultRangesForInterfaceWorkgroupOp(getResult(), getUpperBound(),
+ setResultRanges, /*minimum=*/1);
+}
+
//===----------------------------------------------------------------------===//
// hal.fence.*
//===----------------------------------------------------------------------===//
diff --git a/compiler/src/iree/compiler/Dialect/HAL/IR/HALOps.td b/compiler/src/iree/compiler/Dialect/HAL/IR/HALOps.td
index 16f1ead..d51e430 100644
--- a/compiler/src/iree/compiler/Dialect/HAL/IR/HALOps.td
+++ b/compiler/src/iree/compiler/Dialect/HAL/IR/HALOps.td
@@ -3029,9 +3029,28 @@
let opDocGroup = OpGroupInterfaceOps in {
-def HAL_InterfaceWorkgroupIDOp : HAL_PureOp<"interface.workgroup.id", [
- DeclareOpInterfaceMethods<OpAsmOpInterface, ["getAsmResultNames"]>,
-]> {
+class HAL_InterfaceWorkgroupOp<string mnemonic, list<Trait> traits = []>
+ : HAL_PureOp<mnemonic, !listconcat(traits, [
+ DeclareOpInterfaceMethods<OpAsmOpInterface, ["getAsmResultNames"]>,
+ DeclareOpInterfaceMethods<InferIntRangeInterface, ["inferResultRanges"]>])> {
+ let arguments = (ins
+ IndexAttr:$dimension,
+ OptionalAttr<IndexAttr>:$upper_bound);
+ let results = (outs HAL_Dim:$result);
+
+ let builders = [
+ OpBuilder<(ins "unsigned":$dim),
+ [{
+ build($_builder, $_state, $_builder.getIndexType(), $_builder.getIndexAttr(dim), ::mlir::IntegerAttr{});
+ }]>,
+ ];
+
+ let assemblyFormat = [{
+ `[` $dimension `]` (`upper_bound` $upper_bound^)? attr-dict `:` type($result)
+ }];
+}
+
+def HAL_InterfaceWorkgroupIDOp : HAL_InterfaceWorkgroupOp<"interface.workgroup.id"> {
let summary = [{returns the index of the current workgroup in the grid}];
let description = [{
The global workgroup ID of the current tile in the range of
@@ -3046,25 +3065,9 @@
%z = hal.interface.workgroup.id[2] : index
```
}];
-
- let arguments = (ins IndexAttr:$dimension);
- let results = (outs HAL_Dim:$result);
-
- let builders = [
- OpBuilder<(ins "unsigned":$dim),
- [{
- build($_builder, $_state, $_builder.getIndexType(), $_builder.getIndexAttr(dim));
- }]>,
- ];
-
- let assemblyFormat = [{
- `[` $dimension `]` attr-dict `:` type($result)
- }];
}
-def HAL_InterfaceWorkgroupCountOp : HAL_PureOp<"interface.workgroup.count", [
- DeclareOpInterfaceMethods<OpAsmOpInterface, ["getAsmResultNames"]>,
-]> {
+def HAL_InterfaceWorkgroupCountOp : HAL_InterfaceWorkgroupOp<"interface.workgroup.count"> {
let summary = [{returns the total workgroup count of the grid}];
let description = [{
The total number of workgroups along each dimension in the dispatch grid.
@@ -3081,24 +3084,9 @@
```
}];
- let arguments = (ins IndexAttr:$dimension);
- let results = (outs HAL_Dim:$result);
-
- let builders = [
- OpBuilder<(ins "unsigned":$dim),
- [{
- build($_builder, $_state, $_builder.getIndexType(), $_builder.getIndexAttr(dim));
- }]>,
- ];
-
- let assemblyFormat = [{
- `[` $dimension `]` attr-dict `:` type($result)
- }];
}
-def HAL_InterfaceWorkgroupSizeOp : HAL_PureOp<"interface.workgroup.size", [
- DeclareOpInterfaceMethods<OpAsmOpInterface, ["getAsmResultNames"]>,
-]> {
+def HAL_InterfaceWorkgroupSizeOp : HAL_InterfaceWorkgroupOp<"interface.workgroup.size"> {
let summary = [{returns the size of each workgroup in invocations}];
let description = [{
The number of local invocations within the current workgroup along each
@@ -3114,20 +3102,6 @@
%z = hal.interface.workgroup.size[2] : index
```
}];
-
- let arguments = (ins IndexAttr:$dimension);
- let results = (outs HAL_Dim:$result);
-
- let builders = [
- OpBuilder<(ins "unsigned":$dim),
- [{
- build($_builder, $_state, $_builder.getIndexType(), $_builder.getIndexAttr(dim));
- }]>,
- ];
-
- let assemblyFormat = [{
- `[` $dimension `]` attr-dict `:` type($result)
- }];
}
def HAL_InterfaceConstantLoadOp : HAL_PureOp<"interface.constant.load"> {
diff --git a/compiler/src/iree/compiler/Dialect/HAL/Transforms/MaterializeInterfaces.cpp b/compiler/src/iree/compiler/Dialect/HAL/Transforms/MaterializeInterfaces.cpp
index 9f3bee7..d830c07 100644
--- a/compiler/src/iree/compiler/Dialect/HAL/Transforms/MaterializeInterfaces.cpp
+++ b/compiler/src/iree/compiler/Dialect/HAL/Transforms/MaterializeInterfaces.cpp
@@ -514,7 +514,8 @@
LogicalResult matchAndRewrite(SrcOp op,
PatternRewriter &rewriter) const override {
rewriter.replaceOpWithNewOp<DstOp>(op, op.getResult().getType(),
- op.getDimensionAttr());
+ op.getDimensionAttr(),
+ /*upper_bound=*/nullptr);
return success();
}
};