Reapply "[Codegen][GPU] Add range information to GPU dispatch IDs" (#19361) (#19372)

This reverts commit cb5be1dbd3560f692578c137eadbb413b41e44c7.

Compaled to the previous revision, this one works around a correctness
bug in dataflow analysis that's being fixed by removing the analysis
after SCF->CF.

---

First, this patch implements InferIntRangeInterface for
hal.interface.workgroup.{size,id,count} using a local upper_bound
attribute.

Then, it adds a -iree-codegen-gpu-propagate-dispatch-size-bounds pass
that adds these upper_bounds identifiers to the interface.workgroup
operations and to gpu.thread_id based on static information available
late in the codegen pipeline.

Then, it uses -optimize-int-arithmetic to optimize indexing after
-lower-affine, getting rid of a bunch of "if the input's negative" logic
that isn't actually needed in many of our kernels.

It also ensures that these upper_bound values propagate to LLVM.
diff --git a/compiler/src/iree/compiler/Codegen/Common/GPU/BUILD.bazel b/compiler/src/iree/compiler/Codegen/Common/GPU/BUILD.bazel
index c9e2363..128ffa9 100644
--- a/compiler/src/iree/compiler/Codegen/Common/GPU/BUILD.bazel
+++ b/compiler/src/iree/compiler/Codegen/Common/GPU/BUILD.bazel
@@ -73,6 +73,7 @@
         "GPUPatterns.cpp",
         "GPUPipelining.cpp",
         "GPUPromoteMatmulOperands.cpp",
+        "GPUPropagateDispatchSizeBounds.cpp",
         "GPUReduceBankConflicts.cpp",
         "GPUReuseSharedMemoryAllocs.cpp",
         "GPUTensorAlloc.cpp",
diff --git a/compiler/src/iree/compiler/Codegen/Common/GPU/CMakeLists.txt b/compiler/src/iree/compiler/Codegen/Common/GPU/CMakeLists.txt
index 2aeb9ad..97d3240 100644
--- a/compiler/src/iree/compiler/Codegen/Common/GPU/CMakeLists.txt
+++ b/compiler/src/iree/compiler/Codegen/Common/GPU/CMakeLists.txt
@@ -71,6 +71,7 @@
     "GPUPatterns.cpp"
     "GPUPipelining.cpp"
     "GPUPromoteMatmulOperands.cpp"
+    "GPUPropagateDispatchSizeBounds.cpp"
     "GPUReduceBankConflicts.cpp"
     "GPUReuseSharedMemoryAllocs.cpp"
     "GPUTensorAlloc.cpp"
diff --git a/compiler/src/iree/compiler/Codegen/Common/GPU/GPUPropagateDispatchSizeBounds.cpp b/compiler/src/iree/compiler/Codegen/Common/GPU/GPUPropagateDispatchSizeBounds.cpp
new file mode 100644
index 0000000..43aa70b
--- /dev/null
+++ b/compiler/src/iree/compiler/Codegen/Common/GPU/GPUPropagateDispatchSizeBounds.cpp
@@ -0,0 +1,103 @@
+// Copyright 2024 The IREE Authors
+//
+// Licensed under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+
+#include "iree/compiler/Codegen/Common/GPU/Passes.h"
+#include "iree/compiler/Codegen/Dialect/Codegen/IR/IREECodegenAttrs.h"
+#include "iree/compiler/Codegen/Utils/GPUUtils.h"
+#include "iree/compiler/Codegen/Utils/Utils.h"
+#include "iree/compiler/Dialect/HAL/IR/HALOps.h"
+#include "mlir/Dialect/GPU/IR/GPUDialect.h"
+#include "mlir/Interfaces/FunctionInterfaces.h"
+#include "mlir/Transforms/Passes.h"
+
+namespace mlir::iree_compiler {
+
+#define GEN_PASS_DEF_GPUPROPAGATEDISPATCHSIZEBOUNDSPASS
+#include "iree/compiler/Codegen/Common/GPU/Passes.h.inc"
+
+namespace {
+
+static void applyBounds(FunctionOpInterface funcOp,
+                        ArrayRef<int32_t> workgroupSizes,
+                        ArrayRef<int32_t> workgroupCounts) {
+  Builder b(funcOp->getContext());
+  funcOp->walk([&](Operation *op) {
+    TypeSwitch<Operation *>(op)
+        .Case([&](gpu::ThreadIdOp tidOp) {
+          tidOp.setUpperBoundAttr(b.getIndexAttr(
+              workgroupSizes[static_cast<uint32_t>(tidOp.getDimension())]));
+        })
+        .Case([&](IREE::HAL::InterfaceWorkgroupSizeOp wgSizeOp) {
+          wgSizeOp.setUpperBoundAttr(b.getIndexAttr(
+              workgroupSizes[wgSizeOp.getDimension().getZExtValue()]));
+        })
+        .Case([&](IREE::HAL::InterfaceWorkgroupIDOp wgIdOp) {
+          wgIdOp.setUpperBoundAttr(b.getIndexAttr(
+              workgroupCounts[wgIdOp.getDimension().getZExtValue()]));
+        })
+        .Case([&](IREE::HAL::InterfaceWorkgroupCountOp wgCountOp) {
+          wgCountOp.setUpperBoundAttr(b.getIndexAttr(
+              workgroupCounts[wgCountOp.getDimension().getZExtValue()]));
+        })
+        .Default([](Operation *) {});
+  });
+}
+
+struct GPUPropagateDispatchSizeBoundsPass final
+    : impl::GPUPropagateDispatchSizeBoundsPassBase<
+          GPUPropagateDispatchSizeBoundsPass> {
+  using Base::Base;
+
+  void runOnOperation() override {
+    FunctionOpInterface funcOp = getOperation();
+    IREE::GPU::TargetAttr target = getGPUTargetAttr(funcOp);
+    if (!target) {
+      funcOp.emitWarning("no known target attribute late in GPU codegen");
+      return;
+    }
+    SmallVector<int32_t, 3> workgroupSizes(
+        target.getWgp().getMaxWorkgroupSizes().asArrayRef());
+    SmallVector<int32_t, 3> workgroupCounts(
+        target.getWgp().getMaxWorkgroupCounts().asArrayRef());
+
+    std::optional<SmallVector<int64_t>> staticWorkgroupSize =
+        getWorkgroupSize(funcOp);
+
+    // Late in codegen, we've reconciled the workgroup size onto the export op.
+    if (std::optional<IREE::HAL::ExecutableExportOp> exportOp =
+            getEntryPoint(funcOp)) {
+      if (std::optional<ArrayAttr> exportWorkgroupSize =
+              exportOp->getWorkgroupSize()) {
+        staticWorkgroupSize =
+            llvm::map_to_vector(exportWorkgroupSize->getAsRange<IntegerAttr>(),
+                                [](IntegerAttr a) { return a.getInt(); });
+      }
+    }
+
+    if (staticWorkgroupSize) {
+      // Target info with no workgroup sizes gives a 0-length array, hence no
+      // zip_equal.
+      for (auto [size, staticSize] :
+           llvm::zip(workgroupSizes, *staticWorkgroupSize)) {
+        size = staticSize;
+      }
+    }
+    SmallVector<int64_t> staticWorkgroupCounts = getStaticNumWorkgroups(funcOp);
+    assert(staticWorkgroupCounts.size() <= 3 &&
+           "workgroup counts are 3D at most");
+    for (auto [count, staticCount] :
+         llvm::zip(workgroupCounts, staticWorkgroupCounts)) {
+      if (staticCount != ShapedType::kDynamic) {
+        count = staticCount;
+      }
+    }
+
+    applyBounds(funcOp, workgroupSizes, workgroupCounts);
+  }
+};
+} // namespace
+
+} // namespace mlir::iree_compiler
diff --git a/compiler/src/iree/compiler/Codegen/Common/GPU/Passes.td b/compiler/src/iree/compiler/Codegen/Common/GPU/Passes.td
index 7891309..b3fdd50 100644
--- a/compiler/src/iree/compiler/Codegen/Common/GPU/Passes.td
+++ b/compiler/src/iree/compiler/Codegen/Common/GPU/Passes.td
@@ -178,6 +178,11 @@
   ];
 }
 
+def GPUPropagateDispatchSizeBoundsPass :
+    InterfacePass<"iree-codegen-gpu-propagate-dispatch-size-bounds", "mlir::FunctionOpInterface"> {
+  let summary = "Pass to annotate workitem and workgroup IDs with known bounds";
+}
+
 def GPUReduceBankConflictsPass :
     InterfacePass<"iree-codegen-gpu-reduce-bank-conflicts", "mlir::FunctionOpInterface"> {
   let summary = "Pass to try to reduce the number of bank conflicts by padding memref.alloc ops.";
diff --git a/compiler/src/iree/compiler/Codegen/Common/GPU/test/BUILD.bazel b/compiler/src/iree/compiler/Codegen/Common/GPU/test/BUILD.bazel
index 41afbb6..dc8e6a1 100644
--- a/compiler/src/iree/compiler/Codegen/Common/GPU/test/BUILD.bazel
+++ b/compiler/src/iree/compiler/Codegen/Common/GPU/test/BUILD.bazel
@@ -41,6 +41,7 @@
             "gpu_pad_operands.mlir",
             "gpu_pipeline.mlir",
             "gpu_promote_matmul_operands.mlir",
+            "gpu_propagate_dispatch_size_bounds.mlir",
             "gpu_reorder_workgroups_static.mlir",
             "gpu_reorder_workgroups.mlir",
             "gpu_reuse_shared_memory_allocs.mlir",
diff --git a/compiler/src/iree/compiler/Codegen/Common/GPU/test/CMakeLists.txt b/compiler/src/iree/compiler/Codegen/Common/GPU/test/CMakeLists.txt
index ad86649..4dc0f28 100644
--- a/compiler/src/iree/compiler/Codegen/Common/GPU/test/CMakeLists.txt
+++ b/compiler/src/iree/compiler/Codegen/Common/GPU/test/CMakeLists.txt
@@ -37,6 +37,7 @@
     "gpu_pad_operands.mlir"
     "gpu_pipeline.mlir"
     "gpu_promote_matmul_operands.mlir"
+    "gpu_propagate_dispatch_size_bounds.mlir"
     "gpu_reorder_workgroups.mlir"
     "gpu_reorder_workgroups_static.mlir"
     "gpu_reuse_shared_memory_allocs.mlir"
diff --git a/compiler/src/iree/compiler/Codegen/Common/GPU/test/gpu_propagate_dispatch_size_bounds.mlir b/compiler/src/iree/compiler/Codegen/Common/GPU/test/gpu_propagate_dispatch_size_bounds.mlir
new file mode 100644
index 0000000..f26f2c5
--- /dev/null
+++ b/compiler/src/iree/compiler/Codegen/Common/GPU/test/gpu_propagate_dispatch_size_bounds.mlir
@@ -0,0 +1,122 @@
+// RUN: iree-opt %s --split-input-file \
+// RUN:     --pass-pipeline="builtin.module(hal.executable(hal.executable.variant(builtin.module(func.func(iree-codegen-gpu-propagate-dispatch-size-bounds)))))" \
+// RUN:  | FileCheck %s
+
+// Note: not the real target definition, missing types
+#executable_target = #hal.executable.target<"rocm", "rocm-hsaco-fb", {iree.gpu.target = #iree_gpu.target<arch = "gfx1100", features = "",
+  wgp = <compute =  fp32,
+    storage =  b32,
+    subgroup =  arithmetic,
+    dot =  none, mma = [],
+    subgroup_size_choices = [32, 64],
+    max_workgroup_sizes = [1024, 1024, 1024],
+    max_thread_count_per_workgroup = 1024,
+    max_workgroup_memory_bytes = 65536,
+    max_workgroup_counts = [2147483647, 2147483647, 2147483647]>>}>
+#pipeline_layout = #hal.pipeline.layout<bindings = [#hal.pipeline.binding<storage_buffer>]>
+
+hal.executable private @static {
+  hal.executable.variant public @rocm_hsaco_fb target(#executable_target) {
+    hal.executable.export public @static ordinal(0) layout(#pipeline_layout) attributes {workgroup_size = [64 : index, 2 : index, 1 : index]} {
+    ^bb0(%arg0: !hal.device):
+      %c32 = arith.constant 32 : index
+      %c8 = arith.constant 8 : index
+      %c1 = arith.constant 1 : index
+      hal.return %c32, %c8, %c1 : index, index, index
+    }
+    builtin.module {
+// CHECK-LABEL: func.func @static
+      func.func @static() {
+// CHECK: gpu.thread_id x upper_bound 64
+// CHECK: gpu.thread_id y upper_bound 2
+// CHECK: gpu.thread_id z upper_bound 1
+        %thread_id_x = gpu.thread_id x
+        %thread_id_y = gpu.thread_id y
+        %thread_id_z = gpu.thread_id z
+
+// CHECK: hal.interface.workgroup.size[0] upper_bound 64
+// CHECK: hal.interface.workgroup.size[1] upper_bound 2
+// CHECK: hal.interface.workgroup.size[2] upper_bound 1
+        %workgroup_size_x = hal.interface.workgroup.size[0] : index
+        %workgroup_size_y = hal.interface.workgroup.size[1] : index
+        %workgroup_size_z = hal.interface.workgroup.size[2] : index
+
+// CHECK: hal.interface.workgroup.id[0] upper_bound 32
+// CHECK: hal.interface.workgroup.id[1] upper_bound 8
+// CHECK: hal.interface.workgroup.id[2] upper_bound 1
+        %workgroup_id_x = hal.interface.workgroup.id[0] : index
+        %workgroup_id_y = hal.interface.workgroup.id[1] : index
+        %workgroup_id_z = hal.interface.workgroup.id[2] : index
+
+// CHECK: hal.interface.workgroup.count[0] upper_bound 32
+// CHECK: hal.interface.workgroup.count[1] upper_bound 8
+// CHECK: hal.interface.workgroup.count[2] upper_bound 1
+        %workgroup_conut_x = hal.interface.workgroup.count[0] : index
+        %workgroup_count_y = hal.interface.workgroup.count[1] : index
+        %workgroup_count_z = hal.interface.workgroup.count[2] : index
+
+        return
+      }
+    }
+  }
+}
+
+// -----
+
+#executable_target = #hal.executable.target<"rocm", "rocm-hsaco-fb",
+  {iree.gpu.target = #iree_gpu.target<arch = "gfx1100", features = "",
+  wgp = <compute =  fp32,
+    storage =  b32,
+    subgroup = arithmetic,
+    dot =  none, mma = [],
+    subgroup_size_choices = [32, 64],
+    max_workgroup_sizes = [1024, 1024, 1024],
+    max_thread_count_per_workgroup = 1024,
+    max_workgroup_memory_bytes = 65536,
+    max_workgroup_counts = [2147483647, 2147483647, 2147483647]>>}>
+#pipeline_layout = #hal.pipeline.layout<bindings = [#hal.pipeline.binding<storage_buffer>]>
+
+hal.executable private @dynamic {
+  hal.executable.variant public @rocm_hsaco_fb target(#executable_target) {
+    hal.executable.export public @dynamic ordinal(0) layout(#pipeline_layout) {
+      ^bb0(%arg0: !hal.device, %arg1: index, %arg2: index):
+      %count_x = affine.apply affine_map<()[s0] -> (s0 ceildiv 32)>()[%arg1]
+      %count_y = affine.apply affine_map<()[s0] -> (s0 ceildiv 8)>()[%arg2]
+      %count_z = arith.constant 1 : index
+      hal.return %count_x, %count_y, %count_z : index, index, index
+    }
+    builtin.module {
+      func.func @dynamic() {
+// CHECK: gpu.thread_id x upper_bound 1024
+// CHECK: gpu.thread_id y upper_bound 1024
+// CHECK: gpu.thread_id z upper_bound 1024
+        %thread_id_x = gpu.thread_id x
+        %thread_id_y = gpu.thread_id y
+        %thread_id_z = gpu.thread_id z
+
+// CHECK: hal.interface.workgroup.size[0] upper_bound 1024
+// CHECK: hal.interface.workgroup.size[1] upper_bound 1024
+// CHECK: hal.interface.workgroup.size[2] upper_bound 1024
+        %workgroup_size_x = hal.interface.workgroup.size[0] : index
+        %workgroup_size_y = hal.interface.workgroup.size[1] : index
+        %workgroup_size_z = hal.interface.workgroup.size[2] : index
+
+// CHECK: hal.interface.workgroup.id[0] upper_bound 2147483647
+// CHECK: hal.interface.workgroup.id[1] upper_bound 2147483647
+// CHECK: hal.interface.workgroup.id[2] upper_bound 1
+        %workgroup_id_x = hal.interface.workgroup.id[0] : index
+        %workgroup_id_y = hal.interface.workgroup.id[1] : index
+        %workgroup_id_z = hal.interface.workgroup.id[2] : index
+
+// CHECK: hal.interface.workgroup.count[0] upper_bound 2147483647
+// CHECK: hal.interface.workgroup.count[1] upper_bound 2147483647
+// CHECK: hal.interface.workgroup.count[2] upper_bound 1
+        %workgroup_conut_x = hal.interface.workgroup.count[0] : index
+        %workgroup_count_y = hal.interface.workgroup.count[1] : index
+        %workgroup_count_z = hal.interface.workgroup.count[2] : index
+
+        return
+      }
+    }
+  }
+}
diff --git a/compiler/src/iree/compiler/Codegen/LLVMGPU/ConvertToLLVM.cpp b/compiler/src/iree/compiler/Codegen/LLVMGPU/ConvertToLLVM.cpp
index c056d44..1441f95 100644
--- a/compiler/src/iree/compiler/Codegen/LLVMGPU/ConvertToLLVM.cpp
+++ b/compiler/src/iree/compiler/Codegen/LLVMGPU/ConvertToLLVM.cpp
@@ -505,7 +505,10 @@
     int32_t index = static_cast<int32_t>(op.getDimension().getSExtValue());
     std::array<gpu::Dimension, 3> dimAttr{gpu::Dimension::x, gpu::Dimension::y,
                                           gpu::Dimension::z};
-    rewriter.replaceOpWithNewOp<NewOpTy>(op, op.getType(), dimAttr[index]);
+    NewOpTy newOp =
+        rewriter.replaceOpWithNewOp<NewOpTy>(op, op.getType(), dimAttr[index]);
+    if (IntegerAttr bound = op.getUpperBoundAttr())
+      newOp.setUpperBoundAttr(bound);
     return success();
   }
 };
diff --git a/compiler/src/iree/compiler/Codegen/LLVMGPU/Passes.cpp b/compiler/src/iree/compiler/Codegen/LLVMGPU/Passes.cpp
index 53e49ef..f8ebe1c 100644
--- a/compiler/src/iree/compiler/Codegen/LLVMGPU/Passes.cpp
+++ b/compiler/src/iree/compiler/Codegen/LLVMGPU/Passes.cpp
@@ -1067,7 +1067,13 @@
       .addPass(createCSEPass)
       // Hoist the resulting decompositions.
       .addPass(createIREELoopInvariantCodeMotionPass)
-      .addPass(createLowerAffinePass);
+      .addPass(affine::createAffineExpandIndexOpsPass)
+      .addPass(createLowerAffinePass)
+      .addPass(IREE::Util::createOptimizeIntArithmeticPass)
+      // Do another round of LICM now that we've lowered and optimized
+      // arithmetic
+      .addPass(createCSEPass)
+      .addPass(createIREELoopInvariantCodeMotionPass);
 }
 
 static void addLowerToLLVMGPUPasses(OpPassManager &modulePassManager,
@@ -1103,7 +1109,9 @@
   FunctionLikeNest funcPassManager(modulePassManager);
   funcPassManager.addPass(createFoldTensorExtractOpPass)
       .addPass(createLLVMGPUVectorLoweringPass)
-      .addPass(createExpandGPUOpsPass);
+      .addPass(createExpandGPUOpsPass)
+      // Expose workitem and workgroup counts to range inference later.
+      .addPass(createGPUPropagateDispatchSizeBoundsPass);
 
   // This pass needs to run before SCF -> CF.
   addLowerAndOptimizeAddressComputationPasses(funcPassManager);
@@ -1130,9 +1138,7 @@
       .addPass(memref::createExpandStridedMetadataPass)
       .addPass(createEmulateNarrowTypePass)
       .addPass(affine::createAffineExpandIndexOpsPass)
-      .addPass(createLowerAffinePass)
-      .addPass(createCanonicalizerPass)
-      .addPass(createCSEPass);
+      .addPass(createLowerAffinePass);
 
   // Strip out the debug info for the kernel.
   modulePassManager.addPass(createStripDebugInfoPass());
diff --git a/compiler/src/iree/compiler/Codegen/LLVMGPU/test/nvvm_extract_address_computation.mlir b/compiler/src/iree/compiler/Codegen/LLVMGPU/test/nvvm_extract_address_computation.mlir
index 6c1c5e1..ba6b5da7 100644
--- a/compiler/src/iree/compiler/Codegen/LLVMGPU/test/nvvm_extract_address_computation.mlir
+++ b/compiler/src/iree/compiler/Codegen/LLVMGPU/test/nvvm_extract_address_computation.mlir
@@ -40,7 +40,7 @@
 // CHECK-DAG: %[[C8192:.*]] = llvm.mlir.constant(8192 : index) : i64
 //
 // Match the interesting special registers.
-// CHECK-DAG: %[[TID_Y:.*]] = nvvm.read.ptx.sreg.tid.y : i32
+// CHECK-DAG: %[[TID_Y:.*]] = nvvm.read.ptx.sreg.tid.y range <i32, 0, 2> : i32
 // CHECK-DAG: %[[TID_Y_EXT:.*]] = llvm.sext %[[TID_Y]] : i32 to i64
 // CHECK-DAG: %[[LANEID:.*]] = nvvm.read.ptx.sreg.laneid range <i32, 0, 32> : i32
 // CHECK-DAG: %[[LANEID_EXT:.*]] = llvm.sext %[[LANEID]] : i32 to i64
diff --git a/compiler/src/iree/compiler/Codegen/SPIRV/Passes.cpp b/compiler/src/iree/compiler/Codegen/SPIRV/Passes.cpp
index ea0aa9f..511dbe7 100644
--- a/compiler/src/iree/compiler/Codegen/SPIRV/Passes.cpp
+++ b/compiler/src/iree/compiler/Codegen/SPIRV/Passes.cpp
@@ -227,9 +227,11 @@
 /// Adds passes to perform the final SPIR-V conversion.
 static void addSPIRVLoweringPasses(OpPassManager &modulePassManager) {
   FunctionLikeNest(modulePassManager)
+      .addPass(createGPUPropagateDispatchSizeBoundsPass)
       .addPass(createCanonicalizerPass)
       .addPass(createCSEPass)
       .addPass(createLowerAffinePass)
+      .addPass(IREE::Util::createOptimizeIntArithmeticPass)
 
       // Lower ApplyScale before the i64 Emulation Pass so that new 64-bit ops
       // are also emulated if not supported by the target.
diff --git a/compiler/src/iree/compiler/Dialect/HAL/IR/BUILD.bazel b/compiler/src/iree/compiler/Dialect/HAL/IR/BUILD.bazel
index d9d6a92..3f80245 100644
--- a/compiler/src/iree/compiler/Dialect/HAL/IR/BUILD.bazel
+++ b/compiler/src/iree/compiler/Dialect/HAL/IR/BUILD.bazel
@@ -35,6 +35,7 @@
         "//compiler/src/iree/compiler/Dialect/Util/IR:td_files",
         "@llvm-project//mlir:BuiltinDialectTdFiles",
         "@llvm-project//mlir:FuncTdFiles",
+        "@llvm-project//mlir:InferIntRangeInterfaceTdFiles",
         "@llvm-project//mlir:InferTypeOpInterfaceTdFiles",
         "@llvm-project//mlir:OpBaseTdFiles",
         "@llvm-project//mlir:ViewLikeInterfaceTdFiles",
@@ -81,6 +82,7 @@
         "@llvm-project//mlir:FuncDialect",
         "@llvm-project//mlir:FunctionInterfaces",
         "@llvm-project//mlir:IR",
+        "@llvm-project//mlir:InferIntRangeInterface",
         "@llvm-project//mlir:InferTypeOpInterface",
         "@llvm-project//mlir:MemRefDialect",
         "@llvm-project//mlir:Parser",
diff --git a/compiler/src/iree/compiler/Dialect/HAL/IR/CMakeLists.txt b/compiler/src/iree/compiler/Dialect/HAL/IR/CMakeLists.txt
index 8378551..846bcf0 100644
--- a/compiler/src/iree/compiler/Dialect/HAL/IR/CMakeLists.txt
+++ b/compiler/src/iree/compiler/Dialect/HAL/IR/CMakeLists.txt
@@ -45,6 +45,7 @@
     MLIRFuncDialect
     MLIRFunctionInterfaces
     MLIRIR
+    MLIRInferIntRangeInterface
     MLIRInferTypeOpInterface
     MLIRMemRefDialect
     MLIRParser
diff --git a/compiler/src/iree/compiler/Dialect/HAL/IR/HALOps.cpp b/compiler/src/iree/compiler/Dialect/HAL/IR/HALOps.cpp
index 7210d40..cb5bb41 100644
--- a/compiler/src/iree/compiler/Dialect/HAL/IR/HALOps.cpp
+++ b/compiler/src/iree/compiler/Dialect/HAL/IR/HALOps.cpp
@@ -19,6 +19,7 @@
 #include "mlir/IR/SymbolTable.h"
 #include "mlir/IR/TypeUtilities.h"
 #include "mlir/Interfaces/FunctionImplementation.h"
+#include "mlir/Interfaces/InferIntRangeInterface.h"
 #include "mlir/Interfaces/InferTypeOpInterface.h"
 
 namespace mlir::iree_compiler::IREE::HAL {
@@ -2084,24 +2085,59 @@
   }
 }
 
+// Minimum is the smallest possible result we could get. It's 0 for ID-like
+// operations and 1 for count-like ones.
+static void setResultRangesForInterfaceWorkgroupOp(
+    Value result, const std::optional<APInt> &upperBound,
+    SetIntRangeFn setResultRanges, int64_t minimum) {
+  unsigned width = ConstantIntRanges::getStorageBitwidth(result.getType());
+  if (!upperBound.has_value()) {
+    setResultRanges(
+        result, ConstantIntRanges::fromSigned(APInt(width, minimum),
+                                              APInt::getSignedMaxValue(width)));
+    return;
+  }
+  setResultRanges(result,
+                  ConstantIntRanges::fromUnsigned(APInt(width, minimum),
+                                                  *upperBound + minimum - 1));
+}
+
 void InterfaceWorkgroupIDOp::getAsmResultNames(
     function_ref<void(Value, StringRef)> setNameFn) {
   getAsmResultNamesForInterfaceWorkgroupOp("workgroup_id_", getDimension(),
                                            getResult(), setNameFn);
 }
 
+void InterfaceWorkgroupIDOp::inferResultRanges(
+    ArrayRef<ConstantIntRanges> argRanges, SetIntRangeFn setResultRanges) {
+  setResultRangesForInterfaceWorkgroupOp(getResult(), getUpperBound(),
+                                         setResultRanges, /*minimum=*/0);
+}
+
 void InterfaceWorkgroupCountOp::getAsmResultNames(
     function_ref<void(Value, StringRef)> setNameFn) {
   getAsmResultNamesForInterfaceWorkgroupOp("workgroup_count_", getDimension(),
                                            getResult(), setNameFn);
 }
 
+void InterfaceWorkgroupCountOp::inferResultRanges(
+    ArrayRef<ConstantIntRanges> argRanges, SetIntRangeFn setResultRanges) {
+  setResultRangesForInterfaceWorkgroupOp(getResult(), getUpperBound(),
+                                         setResultRanges, /*minimum=*/1);
+}
+
 void InterfaceWorkgroupSizeOp::getAsmResultNames(
     function_ref<void(Value, StringRef)> setNameFn) {
   getAsmResultNamesForInterfaceWorkgroupOp("workgroup_size_", getDimension(),
                                            getResult(), setNameFn);
 }
 
+void InterfaceWorkgroupSizeOp::inferResultRanges(
+    ArrayRef<ConstantIntRanges> argRanges, SetIntRangeFn setResultRanges) {
+  setResultRangesForInterfaceWorkgroupOp(getResult(), getUpperBound(),
+                                         setResultRanges, /*minimum=*/1);
+}
+
 //===----------------------------------------------------------------------===//
 // hal.fence.*
 //===----------------------------------------------------------------------===//
diff --git a/compiler/src/iree/compiler/Dialect/HAL/IR/HALOps.td b/compiler/src/iree/compiler/Dialect/HAL/IR/HALOps.td
index 16f1ead..d51e430 100644
--- a/compiler/src/iree/compiler/Dialect/HAL/IR/HALOps.td
+++ b/compiler/src/iree/compiler/Dialect/HAL/IR/HALOps.td
@@ -3029,9 +3029,28 @@
 
 let opDocGroup = OpGroupInterfaceOps in {
 
-def HAL_InterfaceWorkgroupIDOp : HAL_PureOp<"interface.workgroup.id", [
-  DeclareOpInterfaceMethods<OpAsmOpInterface, ["getAsmResultNames"]>,
-]> {
+class HAL_InterfaceWorkgroupOp<string mnemonic, list<Trait> traits = []>
+  : HAL_PureOp<mnemonic, !listconcat(traits, [
+      DeclareOpInterfaceMethods<OpAsmOpInterface, ["getAsmResultNames"]>,
+      DeclareOpInterfaceMethods<InferIntRangeInterface, ["inferResultRanges"]>])> {
+  let arguments = (ins
+    IndexAttr:$dimension,
+    OptionalAttr<IndexAttr>:$upper_bound);
+  let results = (outs HAL_Dim:$result);
+
+  let builders = [
+    OpBuilder<(ins "unsigned":$dim),
+    [{
+      build($_builder, $_state, $_builder.getIndexType(), $_builder.getIndexAttr(dim), ::mlir::IntegerAttr{});
+    }]>,
+  ];
+
+  let assemblyFormat = [{
+    `[` $dimension `]` (`upper_bound` $upper_bound^)? attr-dict `:` type($result)
+  }];
+}
+
+def HAL_InterfaceWorkgroupIDOp : HAL_InterfaceWorkgroupOp<"interface.workgroup.id"> {
   let summary = [{returns the index of the current workgroup in the grid}];
   let description = [{
     The global workgroup ID of the current tile in the range of
@@ -3046,25 +3065,9 @@
     %z = hal.interface.workgroup.id[2] : index
     ```
   }];
-
-  let arguments = (ins IndexAttr:$dimension);
-  let results = (outs HAL_Dim:$result);
-
-  let builders = [
-    OpBuilder<(ins "unsigned":$dim),
-    [{
-      build($_builder, $_state, $_builder.getIndexType(), $_builder.getIndexAttr(dim));
-    }]>,
-  ];
-
-  let assemblyFormat = [{
-    `[` $dimension `]` attr-dict `:` type($result)
-  }];
 }
 
-def HAL_InterfaceWorkgroupCountOp : HAL_PureOp<"interface.workgroup.count", [
-  DeclareOpInterfaceMethods<OpAsmOpInterface, ["getAsmResultNames"]>,
-]> {
+def HAL_InterfaceWorkgroupCountOp : HAL_InterfaceWorkgroupOp<"interface.workgroup.count"> {
   let summary = [{returns the total workgroup count of the grid}];
   let description = [{
     The total number of workgroups along each dimension in the dispatch grid.
@@ -3081,24 +3084,9 @@
     ```
   }];
 
-  let arguments = (ins IndexAttr:$dimension);
-  let results = (outs HAL_Dim:$result);
-
-  let builders = [
-    OpBuilder<(ins "unsigned":$dim),
-    [{
-      build($_builder, $_state, $_builder.getIndexType(), $_builder.getIndexAttr(dim));
-    }]>,
-  ];
-
-  let assemblyFormat = [{
-    `[` $dimension `]` attr-dict `:` type($result)
-  }];
 }
 
-def HAL_InterfaceWorkgroupSizeOp : HAL_PureOp<"interface.workgroup.size", [
-  DeclareOpInterfaceMethods<OpAsmOpInterface, ["getAsmResultNames"]>,
-]> {
+def HAL_InterfaceWorkgroupSizeOp : HAL_InterfaceWorkgroupOp<"interface.workgroup.size"> {
   let summary = [{returns the size of each workgroup in invocations}];
   let description = [{
     The number of local invocations within the current workgroup along each
@@ -3114,20 +3102,6 @@
     %z = hal.interface.workgroup.size[2] : index
     ```
   }];
-
-  let arguments = (ins IndexAttr:$dimension);
-  let results = (outs HAL_Dim:$result);
-
-  let builders = [
-    OpBuilder<(ins "unsigned":$dim),
-    [{
-      build($_builder, $_state, $_builder.getIndexType(), $_builder.getIndexAttr(dim));
-    }]>,
-  ];
-
-  let assemblyFormat = [{
-    `[` $dimension `]` attr-dict `:` type($result)
-  }];
 }
 
 def HAL_InterfaceConstantLoadOp : HAL_PureOp<"interface.constant.load"> {
diff --git a/compiler/src/iree/compiler/Dialect/HAL/Transforms/MaterializeInterfaces.cpp b/compiler/src/iree/compiler/Dialect/HAL/Transforms/MaterializeInterfaces.cpp
index 9f3bee7..d830c07 100644
--- a/compiler/src/iree/compiler/Dialect/HAL/Transforms/MaterializeInterfaces.cpp
+++ b/compiler/src/iree/compiler/Dialect/HAL/Transforms/MaterializeInterfaces.cpp
@@ -514,7 +514,8 @@
   LogicalResult matchAndRewrite(SrcOp op,
                                 PatternRewriter &rewriter) const override {
     rewriter.replaceOpWithNewOp<DstOp>(op, op.getResult().getType(),
-                                       op.getDimensionAttr());
+                                       op.getDimensionAttr(),
+                                       /*upper_bound=*/nullptr);
     return success();
   }
 };