[codegen] Add max_workgroup_counts to TargetWgpAttr (#17771)

This commit adds a max_workgroup_counts field to the workgroup processor
information attribute and sets its values for the known targets. Some of
these values may be underestimates, as I was not able to locate
authoritative documentation for them.
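
For example, the workgroup processor attribute now round-trips with the
new field (abbreviated from the target_attrs.mlir test below; the mma
list is elided):

    wgp = #iree_gpu.target_wgp<
      compute = fp16|fp32|int8, storage = b16|b32,
      subgroup = shuffle|arithmetic, dot = dp4xi8toi32,
      mma = [...],
      subgroup_size_choices = [32, 64],
      max_workgroup_sizes = [1024, 1024, 1024],
      max_thread_count_per_workgroup = 1024,
      max_workgroup_memory_bytes = 65536,
      max_workgroup_counts = [2147483647, 2147483647, 2147483647]
    >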

This field is added so that we can annotate workgroup.id and
workgroup.count ops with upper bounds, enabling range inference and
strength reduction.
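
As a hedged sketch of the intended use (the annotation itself lands in
the follow-up PR; the `upper_bound` syntax here is illustrative, not the
final form):

    // Hypothetical: with max_workgroup_counts = [65535, 65535, 65535],
    // the workgroup ID along x is known to be < 65535, so 64-bit index
    // arithmetic derived from it can be range-analyzed and narrowed.
    %id_x = hal.interface.workgroup.id[0] upper_bound 65535 : index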

Note that in some cases (for instance, AMD) we give a
max_workgroup_counts value lower than what the hardware actually supports
because a grid dimension greater than int32_max would be sign-extended to
a negative number when widened to the 64-bit `index` type.
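
To illustrate the hazard (a minimal sketch, assuming the hardware value
is materialized as i32 and widened with arith.extsi):

    // 2^31 = 2147483648 has bit pattern 0x80000000, which as a signed
    // i32 is already -2147483648; extsi carries the sign bit into the
    // wider type.
    %dim = arith.constant -2147483648 : i32
    %wide = arith.extsi %dim : i32 to i64  // -2147483648, not 2^31

Capping the advertised counts at 2147483647 (int32_max) keeps every
in-range value non-negative after widening.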

(This PR is split out of #17707)

Signed-off-by: Krzysztof Drewniak <Krzysztof.Drewniak@amd.com>
diff --git a/compiler/plugins/target/MetalSPIRV/test/smoketest.mlir b/compiler/plugins/target/MetalSPIRV/test/smoketest.mlir
index d32ac8e..7a7af7c 100644
--- a/compiler/plugins/target/MetalSPIRV/test/smoketest.mlir
+++ b/compiler/plugins/target/MetalSPIRV/test/smoketest.mlir
@@ -6,7 +6,8 @@
       #hal.executable.target<"metal-spirv", "metal-msl-fb", {
         iree.gpu.target = #iree_gpu.target<arch = "", features = "spirv:v1.3,cap:Shader", wgp = <
           compute = fp32|int32, storage = b32, subgroup = none, dot = none, mma = [], subgroup_size_choices = [32],
-          max_workgroup_sizes = [128, 128, 64], max_thread_count_per_workgroup = 128, max_workgroup_memory_bytes = 16384>>
+          max_workgroup_sizes = [128, 128, 64], max_thread_count_per_workgroup = 128, max_workgroup_memory_bytes = 16384,
+          max_workgroup_counts = [65535, 65535, 65535]>>
       }>
     ]> : !hal.device
   ]
diff --git a/compiler/plugins/target/ROCM/test/target_device_features.mlir b/compiler/plugins/target/ROCM/test/target_device_features.mlir
index d21ca0a..8f6a88f 100644
--- a/compiler/plugins/target/ROCM/test/target_device_features.mlir
+++ b/compiler/plugins/target/ROCM/test/target_device_features.mlir
@@ -8,7 +8,8 @@
 // GFX942-SAME:         subgroup =  shuffle|arithmetic, dot =  dp4xi8toi32,
 // GFX942-SAME:         mma = [<MFMA_F32_16x16x4_F32>, <MFMA_F32_16x16x16_F16>, <MFMA_F32_32x32x8_F16>, <MFMA_F32_16x16x32_F8E4M3FNUZ>, <MFMA_I32_16x16x32_I8>, <MFMA_I32_32x32x16_I8>],
 // GFX942-SAME:         subgroup_size_choices = [64], max_workgroup_sizes = [1024, 1024, 1024],
-// GFX942-SAME:         max_thread_count_per_workgroup = 1024, max_workgroup_memory_bytes = 65536>,
+// GFX942-SAME:         max_thread_count_per_workgroup = 1024, max_workgroup_memory_bytes = 65536,
+// GFX942-SAME:         max_workgroup_counts = [2147483647, 2147483647, 2147483647]>,
 // GFX942-SAME: chip = <wgp_count = 304>>
 
 // GFX940: target = #iree_gpu.target<arch = "gfx940",
diff --git a/compiler/plugins/target/VulkanSPIRV/test/smoketest.mlir b/compiler/plugins/target/VulkanSPIRV/test/smoketest.mlir
index 6ef88a8..b7aeebb 100644
--- a/compiler/plugins/target/VulkanSPIRV/test/smoketest.mlir
+++ b/compiler/plugins/target/VulkanSPIRV/test/smoketest.mlir
@@ -6,7 +6,8 @@
       #hal.executable.target<"vulkan-spirv", "vulkan-spirv-fb", {
         iree.gpu.target = #iree_gpu.target<arch = "", features = "spirv:v1.3,cap:Shader", wgp = <
           compute = fp32|int32, storage = b32, subgroup = none, dot = none, mma = [], subgroup_size_choices = [32, 32],
-          max_workgroup_sizes = [128, 128, 64], max_thread_count_per_workgroup = 128, max_workgroup_memory_bytes = 16384>>
+          max_workgroup_sizes = [128, 128, 64], max_thread_count_per_workgroup = 128, max_workgroup_memory_bytes = 16384,
+          max_workgroup_counts = [65535, 65535, 65535]>>
       }>
     ]> : !hal.device
   ]
diff --git a/compiler/plugins/target/WebGPUSPIRV/test/smoketest.mlir b/compiler/plugins/target/WebGPUSPIRV/test/smoketest.mlir
index 69c5ceb..e098580 100644
--- a/compiler/plugins/target/WebGPUSPIRV/test/smoketest.mlir
+++ b/compiler/plugins/target/WebGPUSPIRV/test/smoketest.mlir
@@ -7,7 +7,8 @@
       #hal.executable.target<"webgpu-spirv", "webgpu-wgsl-fb", {
         iree.gpu.target = #iree_gpu.target<arch = "", features = "spirv:v1.0,cap:Shader,ext:SPV_KHR_storage_buffer_storage_class", wgp = <
           compute = fp32|int32, storage = b32, subgroup = none, dot = none, mma = [], subgroup_size_choices = [32],
-          max_workgroup_sizes = [128, 128, 64], max_thread_count_per_workgroup = 128, max_workgroup_memory_bytes = 16384>>
+          max_workgroup_sizes = [128, 128, 64], max_thread_count_per_workgroup = 128, max_workgroup_memory_bytes = 16384,
+          max_workgroup_counts = [65535, 65535, 65535]>>
       }>
     ]> : !hal.device
   ]
diff --git a/compiler/src/iree/compiler/Codegen/Dialect/GPU/IR/IREEGPUAttrs.td b/compiler/src/iree/compiler/Codegen/Dialect/GPU/IR/IREEGPUAttrs.td
index 6809bc6..4c453a4 100644
--- a/compiler/src/iree/compiler/Codegen/Dialect/GPU/IR/IREEGPUAttrs.td
+++ b/compiler/src/iree/compiler/Codegen/Dialect/GPU/IR/IREEGPUAttrs.td
@@ -290,6 +290,8 @@
     "uint32_t":$max_thread_count_per_workgroup,
     // The maximal number of shared memory bytes we can allocate per workgroup.
     "uint32_t":$max_workgroup_memory_bytes,
+    // The maximum number of workgroups per X/Y/Z dimension in a dispatch.
+    "DenseI32ArrayAttr":$max_workgroup_counts,
 
     // An optional extra dict
     // This field allows to inject more features/limits not supported in the
diff --git a/compiler/src/iree/compiler/Codegen/Dialect/GPU/IR/test/target_attrs.mlir b/compiler/src/iree/compiler/Codegen/Dialect/GPU/IR/test/target_attrs.mlir
index 47b44b0..abeecbd 100644
--- a/compiler/src/iree/compiler/Codegen/Dialect/GPU/IR/test/target_attrs.mlir
+++ b/compiler/src/iree/compiler/Codegen/Dialect/GPU/IR/test/target_attrs.mlir
@@ -11,7 +11,8 @@
   // CHECK-SAME: subgroup_size_choices = [32, 64],
   // CHECK-SAME: max_workgroup_sizes = [1024, 1024, 1024],
   // CHECK-SAME: max_thread_count_per_workgroup = 1024,
-  // CHECK-SAME: max_workgroup_memory_bytes = 65536>
+  // CHECK-SAME: max_workgroup_memory_bytes = 65536,
+  // CHECK-SAME: max_workgroup_counts = [2147483647, 2147483647, 2147483647]>
   wgp = #iree_gpu.target_wgp<
     compute = fp16|fp32|int8, storage = b16|b32,
     subgroup = shuffle|arithmetic, dot = dp4xi8toi32,
@@ -19,7 +20,8 @@
     subgroup_size_choices = [32, 64],
     max_workgroup_sizes = [1024, 1024, 1024],
     max_thread_count_per_workgroup = 1024,
-    max_workgroup_memory_bytes = 65536
+    max_workgroup_memory_bytes = 65536,
+    max_workgroup_counts = [2147483647, 2147483647, 2147483647]
   >
 } { return }
 
@@ -37,7 +39,8 @@
     subgroup_size_choices = [32],
     max_workgroup_sizes = [1024, 1024, 1024],
     max_thread_count_per_workgroup = 1024,
-    max_workgroup_memory_bytes = 65536
+    max_workgroup_memory_bytes = 65536,
+    max_workgroup_counts = [2147483647, 2147483647, 2147483647]
   >
 } { return }
 
@@ -67,7 +70,8 @@
       subgroup_size_choices = [32, 64],
       max_workgroup_sizes = [1024, 1024, 1024],
       max_thread_count_per_workgroup = 1024,
-      max_workgroup_memory_bytes = 65536>,
+      max_workgroup_memory_bytes = 65536,
+      max_workgroup_counts = [2147483647, 2147483647, 2147483647]>,
     chip = <wgp_count = 304>
   >
 } { return }
diff --git a/compiler/src/iree/compiler/Codegen/Dialect/GPU/TargetUtils/KnownTargets.cpp b/compiler/src/iree/compiler/Codegen/Dialect/GPU/TargetUtils/KnownTargets.cpp
index 6f32054..d6f3867 100644
--- a/compiler/src/iree/compiler/Codegen/Dialect/GPU/TargetUtils/KnownTargets.cpp
+++ b/compiler/src/iree/compiler/Codegen/Dialect/GPU/TargetUtils/KnownTargets.cpp
@@ -45,6 +45,7 @@
   std::array<int32_t, 3> maxWorkgroupSizes;
   uint32_t maxThreadSize;
   uint32_t maxWorkgroupMemoryBytes;
+  std::array<int32_t, 3> maxWorkgroupCounts;
 };
 
 // Chip level feature/limit details
@@ -106,7 +107,9 @@
       MMAOpsArrayAttr::get(context, mmaAttrs),
       DenseI32ArrayAttr::get(context, subgroupSizes),
       DenseI32ArrayAttr::get(context, wgp->maxWorkgroupSizes),
-      wgp->maxThreadSize, wgp->maxWorkgroupMemoryBytes, DictionaryAttr{});
+      wgp->maxThreadSize, wgp->maxWorkgroupMemoryBytes,
+      DenseI32ArrayAttr::get(context, wgp->maxWorkgroupCounts),
+      DictionaryAttr{});
 
   TargetChipAttr targetChip;
   if (details.chip)
@@ -118,6 +121,10 @@
 
 //===----------------------------------------------------------------------===//
 // Known AMD target details
+//
+// Note: the max workgroup counts are given as signed int32 max because MLIR's
+// `index` is signed and the workgroup ID is sign-extended, not zero-extended,
+// to 64 bits.
 //===----------------------------------------------------------------------===//
 
 const WgpDetails *getCDNA3WgpDetails() {
@@ -129,11 +136,17 @@
       MMAIntrinsic::MFMA_I32_16x16x32_I8,
       MMAIntrinsic::MFMA_I32_32x32x16_I8,
   };
-  static const WgpDetails cdna3Wgp = {
-      allComputeBits,   allStorageBits,          allSubgroupOps,
-      allDotProductOps, ARRAY_SIZE(cdna3MMAOps), cdna3MMAOps,
-      {64, 64},         {1024, 1024, 1024},      1024,
-      64 * 1024};
+  static const WgpDetails cdna3Wgp = {allComputeBits,
+                                      allStorageBits,
+                                      allSubgroupOps,
+                                      allDotProductOps,
+                                      ARRAY_SIZE(cdna3MMAOps),
+                                      cdna3MMAOps,
+                                      {64, 64},
+                                      {1024, 1024, 1024},
+                                      1024,
+                                      64 * 1024,
+                                      {0x7fffffff, 0x7fffffff, 0x7fffffff}};
   return &cdna3Wgp;
 }
 
@@ -142,11 +155,17 @@
       MMAIntrinsic::MFMA_F32_16x16x16_F16,
       MMAIntrinsic::MFMA_F32_32x32x8_F16,
   };
-  static const WgpDetails cdna2Wgp = {
-      allComputeBits,   allStorageBits,          allSubgroupOps,
-      allDotProductOps, ARRAY_SIZE(cdna2MMAOps), cdna2MMAOps,
-      {64, 64},         {1024, 1024, 1024},      1024,
-      64 * 1024};
+  static const WgpDetails cdna2Wgp = {allComputeBits,
+                                      allStorageBits,
+                                      allSubgroupOps,
+                                      allDotProductOps,
+                                      ARRAY_SIZE(cdna2MMAOps),
+                                      cdna2MMAOps,
+                                      {64, 64},
+                                      {1024, 1024, 1024},
+                                      1024,
+                                      64 * 1024,
+                                      {0x7fffffff, 0x7fffffff, 0x7fffffff}};
   return &cdna2Wgp;
 }
 
@@ -155,11 +174,17 @@
       MMAIntrinsic::MFMA_F32_16x16x16_F16,
       MMAIntrinsic::MFMA_F32_32x32x8_F16,
   };
-  static const WgpDetails cdna1Wgp = {
-      allComputeBits,   allStorageBits,          allSubgroupOps,
-      allDotProductOps, ARRAY_SIZE(cdna1MMAOps), cdna1MMAOps,
-      {64, 64},         {1024, 1024, 1024},      1024,
-      64 * 1024};
+  static const WgpDetails cdna1Wgp = {allComputeBits,
+                                      allStorageBits,
+                                      allSubgroupOps,
+                                      allDotProductOps,
+                                      ARRAY_SIZE(cdna1MMAOps),
+                                      cdna1MMAOps,
+                                      {64, 64},
+                                      {1024, 1024, 1024},
+                                      1024,
+                                      64 * 1024,
+                                      {0x7fffffff, 0x7fffffff, 0x7fffffff}};
   return &cdna1Wgp;
 }
 
@@ -168,27 +193,39 @@
       MMAIntrinsic::WMMA_F32_16x16x16_F16,
       MMAIntrinsic::WMMA_F16_16x16x16_F16,
   };
-  static const WgpDetails rdna3Wgp = {
-      allComputeBits,   allStorageBits,          allSubgroupOps,
-      allDotProductOps, ARRAY_SIZE(rdna3MMAOps), rdna3MMAOps,
-      {32, 64},         {1024, 1024, 1024},      1024,
-      64 * 1024};
+  static const WgpDetails rdna3Wgp = {allComputeBits,
+                                      allStorageBits,
+                                      allSubgroupOps,
+                                      allDotProductOps,
+                                      ARRAY_SIZE(rdna3MMAOps),
+                                      rdna3MMAOps,
+                                      {32, 64},
+                                      {1024, 1024, 1024},
+                                      1024,
+                                      64 * 1024,
+                                      {0x7fffffff, 0x7fffffff, 0x7fffffff}};
   return &rdna3Wgp;
 }
 
 const WgpDetails *getRDNA2WgpDetails() {
   static const WgpDetails rdna2Wgp = {
-      allComputeBits, allStorageBits,     allSubgroupOps, allDotProductOps,
-      /*mmaCount=*/0, /*mmaOps=*/nullptr, {32, 64},       {1024, 1024, 1024},
-      1024,           64 * 1024};
+      allComputeBits,     allStorageBits,
+      allSubgroupOps,     allDotProductOps,
+      /*mmaCount=*/0,
+      /*mmaOps=*/nullptr, {32, 64},
+      {1024, 1024, 1024}, 1024,
+      64 * 1024,          {0x7fffffff, 0x7fffffff, 0x7fffffff}};
   return &rdna2Wgp;
 }
 
 const WgpDetails *getRDNA1WgpDetails() {
   static const WgpDetails rdna1Wgp = {
-      allComputeBits, allStorageBits,     allSubgroupOps, DotProductOps::None,
-      /*mmaCount=*/0, /*mmaOps=*/nullptr, {32, 64},       {1024, 1024, 1024},
-      1024,           64 * 1024};
+      allComputeBits,     allStorageBits,
+      allSubgroupOps,     DotProductOps::None,
+      /*mmaCount=*/0,
+      /*mmaOps=*/nullptr, {32, 64},
+      {1024, 1024, 1024}, 1024,
+      64 * 1024,          {0x7fffffff, 0x7fffffff, 0x7fffffff}};
   return &rdna1Wgp;
 }
 
@@ -281,7 +318,9 @@
   static const WgpDetails wgp = {
       computeBitwdiths,   allStorageBits,     allSubgroupOps,  allDotProductOps,
       /*mmaCount=*/0,     /*mmaOps=*/nullptr, {32, 32},
-      {1024, 1024, 1024}, 1024,               32 * 1024};
+      {1024, 1024, 1024}, 1024,               32 * 1024,
+      // Note: These values have not been checked and may be higher
+      {0xffff, 0xffff, 0xffff}};
   // clang-format on
 
   return TargetDetails{&wgp, nullptr};
@@ -302,7 +341,9 @@
   static const WgpDetails valhallWgp = {
       computeBitwdiths,   allStorageBits,     allSubgroupOps,  allDotProductOps,
       /*mmaCount=*/0,     /*mmaOps=*/nullptr, {16, 16},        {512, 512, 512},
-      512,                32 * 1024};
+      512,                32 * 1024,
+      // Note: These values have not been checked and may be higher
+      {0xffff, 0xffff, 0xffff}};
   // clang-format on
   return &valhallWgp;
 }
@@ -358,11 +399,17 @@
       MMAIntrinsic::WMMA_F32_16x16x16_F16,
       MMAIntrinsic::WMMA_F16_16x16x16_F16,
   };
-  static const WgpDetails ampereWgp = {
-      allComputeBits,   allStorageBits,     allSubgroupOps,
-      allDotProductOps, ARRAY_SIZE(mmaOps), mmaOps,
-      {32, 32},         {1024, 1024, 1024}, 1024,
-      163 * 1024};
+  static const WgpDetails ampereWgp = {allComputeBits,
+                                       allStorageBits,
+                                       allSubgroupOps,
+                                       allDotProductOps,
+                                       ARRAY_SIZE(mmaOps),
+                                       mmaOps,
+                                       {32, 32},
+                                       {1024, 1024, 1024},
+                                       1024,
+                                       163 * 1024,
+                                       {0x7fffffff, 0xffff, 0xffff}};
   return &ampereWgp;
 }
 
@@ -371,11 +418,17 @@
       MMAIntrinsic::WMMA_F32_16x16x16_F16,
       MMAIntrinsic::WMMA_F16_16x16x16_F16,
   };
-  static const WgpDetails turingWgp = {
-      allComputeBits,   allStorageBits,     allSubgroupOps,
-      allDotProductOps, ARRAY_SIZE(mmaOps), mmaOps,
-      {32, 32},         {1024, 1024, 1024}, 1024,
-      64 * 1024};
+  static const WgpDetails turingWgp = {allComputeBits,
+                                       allStorageBits,
+                                       allSubgroupOps,
+                                       allDotProductOps,
+                                       ARRAY_SIZE(mmaOps),
+                                       mmaOps,
+                                       {32, 32},
+                                       {1024, 1024, 1024},
+                                       1024,
+                                       64 * 1024,
+                                       {0x7fffffff, 0xffff, 0xffff}};
   return &turingWgp;
 }
 
@@ -388,7 +441,8 @@
   static const WgpDetails voltaWgp = {
       allComputeBits,     allStorageBits, allSubgroupOps, DotProductOps::None,
       ARRAY_SIZE(mmaOps), mmaOps,         {32, 32},       {1024, 1024, 1024},
-      1024,               96 * 1024};
+      1024,               96 * 1024,
+      {0x7fffffff, 0xffff, 0xffff}};
   // clang-format on
   return &voltaWgp;
 }
@@ -398,7 +452,8 @@
   static const WgpDetails pascalWgp = {
       allComputeBits, allStorageBits, allSubgroupOps, DotProductOps::None,
       0, nullptr, // Pascal does not have tensor core support.
-      {32, 32}, {1024, 1024, 1024}, 1024, 48 * 1024};
+      {32, 32}, {1024, 1024, 1024}, 1024, 48 * 1024,
+      {0x7fffffff, 0xffff, 0xffff}};
   // clang-format on
   return &pascalWgp;
 }
@@ -479,7 +534,9 @@
       computeBitwdiths,   storageBitwidths,   allSubgroupOps,
       allDotProductOps,   /*mmaCount=*/0,     /*mmaOps=*/nullptr,
       {64, 64},           {1024, 1024, 1024}, 1024,
-      32 * 1024};
+      32 * 1024,
+      // Note: These values have not been checked and may be higher
+      {0xffff, 0xffff, 0xffff}};
   // clang-format on
   return &adrenoWgp;
 }
@@ -545,7 +602,8 @@
       computeBitwdiths,    storageBitwidths,   SubgroupOps::None,
       DotProductOps::None, /*mmaCount=*/0,     /*mmaOps=*/nullptr,
       {64, 64},            {128, 128, 64},     128,
-      16 * 1024};
+      16 * 1024,
+      {0xffff, 0xffff, 0xffff}};
   // clang-format on
   return &androidWgp;
 }
@@ -645,7 +703,8 @@
       computeBitwdiths,    storageBitwidths,   SubgroupOps::None,
       DotProductOps::None, /*mmaCount=*/0,     /*mmaOps=*/nullptr,
       {32, 32},            {128, 128, 64},     128,
-      16 * 1024};
+      16 * 1024,
+      {0xffff, 0xffff, 0xffff}};
   // clang-format on
 
   return createTargetAttr(
diff --git a/compiler/src/iree/compiler/Codegen/LLVMGPU/test/ROCDL/config_vector_distribute.mlir b/compiler/src/iree/compiler/Codegen/LLVMGPU/test/ROCDL/config_vector_distribute.mlir
index 9f0e6cc..3450851 100644
--- a/compiler/src/iree/compiler/Codegen/LLVMGPU/test/ROCDL/config_vector_distribute.mlir
+++ b/compiler/src/iree/compiler/Codegen/LLVMGPU/test/ROCDL/config_vector_distribute.mlir
@@ -93,7 +93,8 @@
   subgroup = shuffle|arithmetic, dot = dp4xi8toi32,
   mma = [],
   subgroup_size_choices = [64], max_workgroup_sizes = [1024, 1024, 1024],
-  max_thread_count_per_workgroup = 1024, max_workgroup_memory_bytes = 65536>>
+  max_thread_count_per_workgroup = 1024, max_workgroup_memory_bytes = 65536,
+  max_workgroup_counts = [2147483647, 2147483647, 2147483647]>>
 #executable_target_rocm_hsaco_fb = #hal.executable.target<"rocm", "rocm-hsaco-fb", {iree.gpu.target = #target}>
 func.func @matmul_256x256x256() attributes {hal.executable.target = #executable_target_rocm_hsaco_fb} {
   %cst = arith.constant 0.000000e+00 : f32
diff --git a/compiler/src/iree/compiler/Codegen/SPIRV/test/config_default_linalg_ops.mlir b/compiler/src/iree/compiler/Codegen/SPIRV/test/config_default_linalg_ops.mlir
index e02c07d..13f86b7 100644
--- a/compiler/src/iree/compiler/Codegen/SPIRV/test/config_default_linalg_ops.mlir
+++ b/compiler/src/iree/compiler/Codegen/SPIRV/test/config_default_linalg_ops.mlir
@@ -10,7 +10,8 @@
   iree.gpu.target = #iree_gpu.target<arch = "", features = "spirv:v1.6,cap:Shader", wgp = <
     compute = fp32|int32, storage = b32, subgroup = none, dot = none, mma = [],
     subgroup_size_choices = [16], max_workgroup_sizes = [128, 128, 64],
-    max_thread_count_per_workgroup = 128, max_workgroup_memory_bytes = 16384>>
+    max_thread_count_per_workgroup = 128, max_workgroup_memory_bytes = 16384,
+    max_workgroup_counts = [65535, 65535, 65535]>>
 }>
 #map = affine_map<(d0, d1) -> (d0, d1)>
 func.func @copy_as_generic() attributes {hal.executable.target = #executable_target_vulkan_spirv_fb} {
@@ -44,7 +45,8 @@
   iree.gpu.target = #iree_gpu.target<arch = "", features = "spirv:v1.6,cap:Shader", wgp = <
     compute = fp32|int32, storage = b32, subgroup = none, dot = none, mma = [],
     subgroup_size_choices = [64], max_workgroup_sizes = [128, 128, 64],
-    max_thread_count_per_workgroup = 128, max_workgroup_memory_bytes = 16384>>
+    max_thread_count_per_workgroup = 128, max_workgroup_memory_bytes = 16384,
+    max_workgroup_counts = [65535, 65535, 65535]>>
 }>
 #map = affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)>
 func.func @copy() attributes {hal.executable.target = #executable_target_vulkan_spirv_fb} {
@@ -81,7 +83,8 @@
   iree.gpu.target = #iree_gpu.target<arch = "", features = "spirv:v1.6,cap:Shader", wgp = <
     compute = fp32|int32, storage = b32, subgroup = none, dot = none, mma = [],
     subgroup_size_choices = [32], max_workgroup_sizes = [128, 128, 64],
-    max_thread_count_per_workgroup = 128, max_workgroup_memory_bytes = 16384>>
+    max_thread_count_per_workgroup = 128, max_workgroup_memory_bytes = 16384,
+    max_workgroup_counts = [65535, 65535, 65535]>>
 }>
 func.func @avg_pool() attributes {hal.executable.target = #executable_target_vulkan_spirv_fb} {
   %c0 = arith.constant 0 : index
@@ -118,7 +121,8 @@
   iree.gpu.target = #iree_gpu.target<arch = "", features = "spirv:v1.6,cap:Shader", wgp = <
     compute = fp32|int32, storage = b32, subgroup = none, dot = none, mma = [],
     subgroup_size_choices = [4], max_workgroup_sizes = [128, 128, 64],
-    max_thread_count_per_workgroup = 128, max_workgroup_memory_bytes = 16384>>
+    max_thread_count_per_workgroup = 128, max_workgroup_memory_bytes = 16384,
+    max_workgroup_counts = [65535, 65535, 65535]>>
 }>
 #map = affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)>
 func.func @avg_pool() attributes {hal.executable.target = #executable_target_vulkan_spirv_fb} {
@@ -162,7 +166,8 @@
   iree.gpu.target = #iree_gpu.target<arch = "", features = "spirv:v1.6,cap:Shader", wgp = <
     compute = fp32|int32, storage = b32, subgroup = none, dot = none, mma = [],
     subgroup_size_choices = [32], max_workgroup_sizes = [128, 128, 64],
-    max_thread_count_per_workgroup = 128, max_workgroup_memory_bytes = 16384>>
+    max_thread_count_per_workgroup = 128, max_workgroup_memory_bytes = 16384,
+    max_workgroup_counts = [65535, 65535, 65535]>>
 }>
 func.func @max_pool() attributes {hal.executable.target = #executable_target_vulkan_spirv_fb} {
   %cst = arith.constant 0xFF800000 : f32
@@ -203,7 +208,8 @@
   iree.gpu.target = #iree_gpu.target<arch = "", features = "spirv:v1.6,cap:Shader", wgp = <
     compute = fp32|int32, storage = b32, subgroup = none, dot = none, mma = [],
     subgroup_size_choices = [32], max_workgroup_sizes = [128, 128, 64],
-    max_thread_count_per_workgroup = 128, max_workgroup_memory_bytes = 16384>>
+    max_thread_count_per_workgroup = 128, max_workgroup_memory_bytes = 16384,
+    max_workgroup_counts = [65535, 65535, 65535]>>
 }>
 #map = affine_map<(d0, d1) -> (d0, d1)>
 #map1 = affine_map<(d0, d1) -> (d1)>
@@ -244,7 +250,8 @@
   iree.gpu.target = #iree_gpu.target<arch = "", features = "spirv:v1.6,cap:Shader", wgp = <
     compute = fp32|int32, storage = b32, subgroup = none, dot = none, mma = [],
     subgroup_size_choices = [32], max_workgroup_sizes = [128, 128, 64],
-    max_thread_count_per_workgroup = 128, max_workgroup_memory_bytes = 16384>>
+    max_thread_count_per_workgroup = 128, max_workgroup_memory_bytes = 16384,
+    max_workgroup_counts = [65535, 65535, 65535]>>
 }>
 #map = affine_map<(d0, d1, d2, d3, d4) -> (d0, d1, d2, d3, d4)>
 func.func @dwconv_elementwise() attributes {hal.executable.target = #executable_target_vulkan_spirv_fb} {
@@ -292,7 +299,8 @@
   iree.gpu.target = #iree_gpu.target<arch = "", features = "spirv:v1.6,cap:Shader", wgp = <
     compute = fp32|int32, storage = b32, subgroup = none, dot = none, mma = [],
     subgroup_size_choices = [32], max_workgroup_sizes = [128, 128, 64],
-    max_thread_count_per_workgroup = 128, max_workgroup_memory_bytes = 16384>>
+    max_thread_count_per_workgroup = 128, max_workgroup_memory_bytes = 16384,
+    max_workgroup_counts = [65535, 65535, 65535]>>
 }>
 #map = affine_map<(d0, d1, d2) -> (d2, d0, d1)>
 #map1 = affine_map<(d0, d1, d2) -> (d0, d1)>
@@ -332,7 +340,8 @@
   iree.gpu.target = #iree_gpu.target<arch = "", features = "spirv:v1.6,cap:Shader", wgp = <
     compute = fp32|int32, storage = b32, subgroup = none, dot = none, mma = [],
     subgroup_size_choices = [32], max_workgroup_sizes = [128, 128, 64],
-    max_thread_count_per_workgroup = 128, max_workgroup_memory_bytes = 16384>>
+    max_thread_count_per_workgroup = 128, max_workgroup_memory_bytes = 16384,
+    max_workgroup_counts = [65535, 65535, 65535]>>
 }>
 #map = affine_map<(d0, d1) -> (d0, d1)>
 #map1 = affine_map<(d0, d1) -> (d0)>
@@ -381,7 +390,8 @@
   iree.gpu.target = #iree_gpu.target<arch = "", features = "spirv:v1.6,cap:Shader", wgp = <
     compute = fp32|int32, storage = b32, subgroup = none, dot = none, mma = [],
     subgroup_size_choices = [16], max_workgroup_sizes = [128, 128, 64],
-    max_thread_count_per_workgroup = 128, max_workgroup_memory_bytes = 16384>>
+    max_thread_count_per_workgroup = 128, max_workgroup_memory_bytes = 16384,
+    max_workgroup_counts = [65535, 65535, 65535]>>
 }>
 #map = affine_map<(d0, d1, d2, d3) -> (d0, d3, d1, d2)>
 #map1 = affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)>
@@ -418,7 +428,8 @@
   iree.gpu.target = #iree_gpu.target<arch = "", features = "spirv:v1.6,cap:Shader", wgp = <
     compute = fp32|int32, storage = b32, subgroup = none, dot = none, mma = [],
     subgroup_size_choices = [32], max_workgroup_sizes = [128, 128, 64],
-    max_thread_count_per_workgroup = 128, max_workgroup_memory_bytes = 16384>>
+    max_thread_count_per_workgroup = 128, max_workgroup_memory_bytes = 16384,
+    max_workgroup_counts = [65535, 65535, 65535]>>
 }>
 #map = affine_map<(d0, d1) -> (d0, d1)>
 #map1 = affine_map<(d0, d1) -> (d0)>
@@ -465,7 +476,8 @@
   iree.gpu.target = #iree_gpu.target<arch = "", features = "spirv:v1.6,cap:Shader", wgp = <
     compute = fp32|int32, storage = b32, subgroup = none, dot = none, mma = [],
     subgroup_size_choices = [32], max_workgroup_sizes = [128, 128, 64],
-    max_thread_count_per_workgroup = 128, max_workgroup_memory_bytes = 16384>>
+    max_thread_count_per_workgroup = 128, max_workgroup_memory_bytes = 16384,
+    max_workgroup_counts = [65535, 65535, 65535]>>
 }>
 #map = affine_map<(d0, d1) -> (d0, d1)>
 #map1 = affine_map<(d0, d1) -> (d0)>
@@ -512,7 +524,8 @@
   iree.gpu.target = #iree_gpu.target<arch = "", features = "spirv:v1.6,cap:Shader", wgp = <
     compute = fp32|int32, storage = b32, subgroup = none, dot = none, mma = [],
     subgroup_size_choices = [32], max_workgroup_sizes = [128, 128, 64],
-    max_thread_count_per_workgroup = 128, max_workgroup_memory_bytes = 16384>>
+    max_thread_count_per_workgroup = 128, max_workgroup_memory_bytes = 16384,
+    max_workgroup_counts = [65535, 65535, 65535]>>
 }>
 #map = affine_map<(d0, d1, d2, d3) -> ()>
 #map1 = affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)>
diff --git a/compiler/src/iree/compiler/Codegen/SPIRV/test/config_default_matmul.mlir b/compiler/src/iree/compiler/Codegen/SPIRV/test/config_default_matmul.mlir
index 9ff2c67..9370b6c 100644
--- a/compiler/src/iree/compiler/Codegen/SPIRV/test/config_default_matmul.mlir
+++ b/compiler/src/iree/compiler/Codegen/SPIRV/test/config_default_matmul.mlir
@@ -13,7 +13,8 @@
   iree.gpu.target = #iree_gpu.target<arch = "", features = "spirv:v1.6,cap:Shader", wgp = <
     compute = fp32|int32, storage = b32, subgroup = none, dot = none, mma = [],
     subgroup_size_choices = [32], max_workgroup_sizes = [128, 128, 64],
-    max_thread_count_per_workgroup = 128, max_workgroup_memory_bytes = 16384>>
+    max_thread_count_per_workgroup = 128, max_workgroup_memory_bytes = 16384,
+    max_workgroup_counts = [65535, 65535, 65535]>>
 }>
 func.func @batch_matmul_1x3x32() attributes {hal.executable.target = #executable_target_vulkan_spirv_fb} {
   %c0 = arith.constant 0 : index
@@ -55,7 +56,8 @@
   iree.gpu.target = #iree_gpu.target<arch = "", features = "spirv:v1.6,cap:Shader", wgp = <
     compute = fp32|int32, storage = b32, subgroup = none, dot = none, mma = [],
     subgroup_size_choices = [64], max_workgroup_sizes = [128, 128, 64],
-    max_thread_count_per_workgroup = 128, max_workgroup_memory_bytes = 16384>>
+    max_thread_count_per_workgroup = 128, max_workgroup_memory_bytes = 16384,
+    max_workgroup_counts = [65535, 65535, 65535]>>
 }>
 func.func @matmul_64x16xi8() attributes {hal.executable.target = #executable_target_vulkan_spirv_fb} {
   %c0 = arith.constant 0 : index
@@ -96,7 +98,8 @@
   iree.gpu.target = #iree_gpu.target<arch = "", features = "spirv:v1.6,cap:Shader", wgp = <
     compute = fp32|int64|int32, storage = b32, subgroup = none, dot = none, mma = [],
     subgroup_size_choices = [64], max_workgroup_sizes = [128, 128, 64],
-    max_thread_count_per_workgroup = 128, max_workgroup_memory_bytes = 16384>>
+    max_thread_count_per_workgroup = 128, max_workgroup_memory_bytes = 16384,
+    max_workgroup_counts = [65535, 65535, 65535]>>
 }>
 func.func @matmul_64x16xi64() attributes {hal.executable.target = #executable_target_vulkan_spirv_fb} {
   %c0 = arith.constant 0 : index
@@ -137,7 +140,8 @@
   iree.gpu.target = #iree_gpu.target<arch = "", features = "spirv:v1.6,cap:Shader", wgp = <
     compute = fp32|int32, storage = b32, subgroup = none, dot = none, mma = [],
     subgroup_size_choices = [64], max_workgroup_sizes = [128, 128, 64],
-    max_thread_count_per_workgroup = 128, max_workgroup_memory_bytes = 16384>>
+    max_thread_count_per_workgroup = 128, max_workgroup_memory_bytes = 16384,
+    max_workgroup_counts = [65535, 65535, 65535]>>
 }>
 #map = affine_map<(d0, d1) -> (d1)>
 #map1 = affine_map<(d0, d1) -> (d0, d1)>
@@ -189,7 +193,8 @@
   iree.gpu.target = #iree_gpu.target<arch = "", features = "spirv:v1.6,cap:Shader", wgp = <
     compute = fp32|int32, storage = b32, subgroup = none, dot = none, mma = [],
     subgroup_size_choices = [64], max_workgroup_sizes = [128, 128, 64],
-    max_thread_count_per_workgroup = 128, max_workgroup_memory_bytes = 16384>>
+    max_thread_count_per_workgroup = 128, max_workgroup_memory_bytes = 16384,
+    max_workgroup_counts = [65535, 65535, 65535]>>
 }>
 #map = affine_map<(d0, d1) -> (d1)>
 #map1 = affine_map<(d0, d1) -> (d0, d1)>
@@ -243,7 +248,8 @@
   iree.gpu.target = #iree_gpu.target<arch = "", features = "spirv:v1.6,cap:Shader", wgp = <
     compute = fp32|int32, storage = b32, subgroup = none, dot = none, mma = [],
     subgroup_size_choices = [32], max_workgroup_sizes = [128, 128, 64],
-    max_thread_count_per_workgroup = 128, max_workgroup_memory_bytes = 16384>>
+    max_thread_count_per_workgroup = 128, max_workgroup_memory_bytes = 16384,
+    max_workgroup_counts = [65535, 65535, 65535]>>
 }>
 #map = affine_map<(d0, d1) -> (d0, d1)>
 func.func @matmul_pointwise_256x1024() attributes {hal.executable.target = #executable_target_vulkan_spirv_fb} {
diff --git a/compiler/src/iree/compiler/Codegen/SPIRV/test/config_default_reduction.mlir b/compiler/src/iree/compiler/Codegen/SPIRV/test/config_default_reduction.mlir
index 09c4c36..0f31e2a 100644
--- a/compiler/src/iree/compiler/Codegen/SPIRV/test/config_default_reduction.mlir
+++ b/compiler/src/iree/compiler/Codegen/SPIRV/test/config_default_reduction.mlir
@@ -10,7 +10,8 @@
   iree.gpu.target = #iree_gpu.target<arch = "", features = "spirv:v1.6,cap:Shader", wgp = <
     compute = fp32|int32, storage = b32, subgroup = shuffle, dot = none, mma = [],
     subgroup_size_choices = [16], max_workgroup_sizes = [512, 512, 512],
-    max_thread_count_per_workgroup = 512, max_workgroup_memory_bytes = 16384>>
+    max_thread_count_per_workgroup = 512, max_workgroup_memory_bytes = 16384,
+    max_workgroup_counts = [65535, 65535, 65535]>>
 }>
 #map = affine_map<(d0, d1) -> (d0, d1)>
 #map1 = affine_map<(d0, d1) -> (d0)>
@@ -50,7 +51,8 @@
   iree.gpu.target = #iree_gpu.target<arch = "", features = "spirv:v1.6,cap:Shader", wgp = <
     compute = fp32|int32, storage = b32, subgroup = shuffle, dot = none, mma = [],
     subgroup_size_choices = [64], max_workgroup_sizes = [1024, 1024, 1024],
-    max_thread_count_per_workgroup = 1024, max_workgroup_memory_bytes = 16384>>
+    max_thread_count_per_workgroup = 1024, max_workgroup_memory_bytes = 16384,
+    max_workgroup_counts = [65535, 65535, 65535]>>
 }>
 #map = affine_map<(d0, d1, d2) -> (d0, d1, d2)>
 #map1 = affine_map<(d0, d1, d2) -> (d0, d1)>
@@ -97,7 +99,8 @@
   iree.gpu.target = #iree_gpu.target<arch = "", features = "spirv:v1.6,cap:Shader", wgp = <
     compute = fp32|int32, storage = b32, subgroup = shuffle, dot = none, mma = [],
     subgroup_size_choices = [64], max_workgroup_sizes = [1024, 1024, 1024],
-    max_thread_count_per_workgroup = 1024, max_workgroup_memory_bytes = 16384>>
+    max_thread_count_per_workgroup = 1024, max_workgroup_memory_bytes = 16384,
+    max_workgroup_counts = [65535, 65535, 65535]>>
 }>
 #map = affine_map<(d0, d1) -> (d0, d1)>
 #map1 = affine_map<(d0, d1) -> (d0)>
diff --git a/compiler/src/iree/compiler/Codegen/SPIRV/test/convert_gpu_target.mlir b/compiler/src/iree/compiler/Codegen/SPIRV/test/convert_gpu_target.mlir
index b1da56d..eeaa8fe 100644
--- a/compiler/src/iree/compiler/Codegen/SPIRV/test/convert_gpu_target.mlir
+++ b/compiler/src/iree/compiler/Codegen/SPIRV/test/convert_gpu_target.mlir
@@ -4,7 +4,8 @@
 hal.executable.variant public @vulkan_spirv_fb target(<"vulkan-spirv", "vulkan-spirv-fb", {
     iree.gpu.target = #iree_gpu.target<arch = "rdna3", features = "spirv:v1.6,cap:Shader",
       wgp = <compute = fp64|fp32|fp16|int64|int32|int16|int8, storage = b64|b32|b16|b8, subgroup = shuffle|arithmetic, dot = dp4xi8toi32, mma = [<WMMA_F32_16x16x16_F16>, <WMMA_F16_16x16x16_F16>],
-      subgroup_size_choices = [32, 64], max_workgroup_sizes = [1024, 1024, 1024], max_thread_count_per_workgroup = 1024, max_workgroup_memory_bytes = 65536>>}>) {
+      subgroup_size_choices = [32, 64], max_workgroup_sizes = [1024, 1024, 1024], max_thread_count_per_workgroup = 1024, max_workgroup_memory_bytes = 65536,
+      max_workgroup_counts = [2147483647, 2147483647, 2147483647]>>}>) {
   hal.executable.export public @dispatch ordinal(0) layout(#hal.pipeline.layout<push_constants = 0, sets = [<0, bindings = [<0, storage_buffer>]>]>) {
   ^bb0(%arg0: !hal.device):
     %x, %y, %z = flow.dispatch.workgroup_count_from_slice
diff --git a/compiler/src/iree/compiler/Codegen/SPIRV/test/emulate_i64.mlir b/compiler/src/iree/compiler/Codegen/SPIRV/test/emulate_i64.mlir
index a7a2f9d..eb1c281 100644
--- a/compiler/src/iree/compiler/Codegen/SPIRV/test/emulate_i64.mlir
+++ b/compiler/src/iree/compiler/Codegen/SPIRV/test/emulate_i64.mlir
@@ -13,7 +13,8 @@
   iree.gpu.target = #iree_gpu.target<arch = "", features = "spirv:v1.6,cap:Shader", wgp = <
     compute = fp32|int32, storage = b32, subgroup = none, dot = none, mma = [],
     subgroup_size_choices = [32], max_workgroup_sizes = [1024, 1024, 1024],
-    max_thread_count_per_workgroup = 1024, max_workgroup_memory_bytes = 65536>>
+    max_thread_count_per_workgroup = 1024, max_workgroup_memory_bytes = 65536,
+    max_workgroup_counts = [65535, 65535, 65535]>>
 }>
 func.func @buffer_types() attributes {hal.executable.target = #executable_target_vulkan_spirv_fb} {
   %c0 = arith.constant 0 : index
@@ -50,7 +51,8 @@
   iree.gpu.target = #iree_gpu.target<arch = "", features = "spirv:v1.6,cap:Shader", wgp = <
     compute = fp32|int32, storage = b32, subgroup = none, dot = none, mma = [],
     subgroup_size_choices = [32], max_workgroup_sizes = [1024, 1024, 1024],
-    max_thread_count_per_workgroup = 1024, max_workgroup_memory_bytes = 65536>>
+    max_thread_count_per_workgroup = 1024, max_workgroup_memory_bytes = 65536,
+    max_workgroup_counts = [65535, 65535, 65535]>>
 }>
 func.func @emulate_1d_vector() attributes {hal.executable.target = #executable_target_vulkan_spirv_fb} {
   %c95232 = arith.constant 95232 : index
@@ -103,7 +105,8 @@
   iree.gpu.target = #iree_gpu.target<arch = "", features = "spirv:v1.6,cap:Shader", wgp = <
     compute = fp32|int64|int32, storage = b32, subgroup = none, dot = none, mma = [],
     subgroup_size_choices = [32], max_workgroup_sizes = [1024, 1024, 1024],
-    max_thread_count_per_workgroup = 1024, max_workgroup_memory_bytes = 65536>>
+    max_thread_count_per_workgroup = 1024, max_workgroup_memory_bytes = 65536,
+    max_workgroup_counts = [65535, 65535, 65535]>>
 }>
 func.func @no_emulation() attributes {hal.executable.target = #executable_target_vulkan_spirv_fb} {
   %c0 = arith.constant 0 : index
diff --git a/compiler/src/iree/compiler/Codegen/SPIRV/test/illegal_configuration.mlir b/compiler/src/iree/compiler/Codegen/SPIRV/test/illegal_configuration.mlir
index ba502da..681cf99 100644
--- a/compiler/src/iree/compiler/Codegen/SPIRV/test/illegal_configuration.mlir
+++ b/compiler/src/iree/compiler/Codegen/SPIRV/test/illegal_configuration.mlir
@@ -14,7 +14,8 @@
   iree.gpu.target = #iree_gpu.target<arch = "", features = "spirv:v1.6,cap:Shader", wgp = <
     compute = fp32|int32, storage = b32, subgroup = none, dot = none, mma = [],
     subgroup_size_choices = [32], max_workgroup_sizes = [128, 128, 64],
-    max_thread_count_per_workgroup = 128, max_workgroup_memory_bytes = 16384>>
+    max_thread_count_per_workgroup = 128, max_workgroup_memory_bytes = 16384,
+    max_workgroup_counts = [65535, 65535, 65535]>>
 }>
 #translation = #iree_codegen.translation_info<SPIRVMatmulPromoteVectorize workgroup_size = [16, 8, 1], {pipeline_depth = 0 : i64, store_stage = 1 : i64}>
 #compilation = #iree_codegen.compilation_info<lowering_config = #config, translation_info = #translation>
@@ -42,7 +43,8 @@
   iree.gpu.target = #iree_gpu.target<arch = "", features = "spirv:v1.6,cap:Shader", wgp = <
     compute = fp32|int32, storage = b32, subgroup = none, dot = none, mma = [],
     subgroup_size_choices = [32], max_workgroup_sizes = [128, 128, 64],
-    max_thread_count_per_workgroup = 128, max_workgroup_memory_bytes = 16384>>
+    max_thread_count_per_workgroup = 128, max_workgroup_memory_bytes = 16384,
+    max_workgroup_counts = [65535, 65535, 65535]>>
 }>
 #translation = #iree_codegen.translation_info<SPIRVMatmulPromoteVectorize, {pipeline_depth = 0 : i64, store_stage = 1 : i64}>
 #compilation = #iree_codegen.compilation_info<lowering_config = #config, translation_info = #translation>
@@ -70,7 +72,8 @@
   iree.gpu.target = #iree_gpu.target<arch = "", features = "spirv:v1.6,cap:Shader", wgp = <
     compute = fp32|int32, storage = b32, subgroup = none, dot = none, mma = [],
     subgroup_size_choices = [32], max_workgroup_sizes = [128, 128, 64],
-    max_thread_count_per_workgroup = 128, max_workgroup_memory_bytes = 16384>>
+    max_thread_count_per_workgroup = 128, max_workgroup_memory_bytes = 16384,
+    max_workgroup_counts = [65535, 65535, 65535]>>
 }>
 #translation = #iree_codegen.translation_info<SPIRVMatmulPromoteVectorize workgroup_size = [16, 8, 128], {pipeline_depth = 0 : i64, store_stage = 1 : i64}>
 #compilation = #iree_codegen.compilation_info<lowering_config = #config, translation_info = #translation>
@@ -98,7 +101,8 @@
   iree.gpu.target = #iree_gpu.target<arch = "", features = "spirv:v1.6,cap:Shader", wgp = <
     compute = fp32|int32, storage = b32, subgroup = none, dot = none, mma = [],
     subgroup_size_choices = [32], max_workgroup_sizes = [128, 128, 64],
-    max_thread_count_per_workgroup = 128, max_workgroup_memory_bytes = 16384>>
+    max_thread_count_per_workgroup = 128, max_workgroup_memory_bytes = 16384,
+    max_workgroup_counts = [65535, 65535, 65535]>>
 }>
 #translation = #iree_codegen.translation_info<SPIRVMatmulPromoteVectorize workgroup_size = [32, 8, 1], {pipeline_depth = 0 : i64, store_stage = 1 : i64}>
 #compilation = #iree_codegen.compilation_info<lowering_config = #config, translation_info = #translation>
@@ -126,7 +130,8 @@
   iree.gpu.target = #iree_gpu.target<arch = "", features = "spirv:v1.6,cap:Shader", wgp = <
     compute = fp32|int32, storage = b32, subgroup = none, dot = none, mma = [],
     subgroup_size_choices = [32], max_workgroup_sizes = [128, 128, 64],
-    max_thread_count_per_workgroup = 128, max_workgroup_memory_bytes = 16384>>
+    max_thread_count_per_workgroup = 128, max_workgroup_memory_bytes = 16384,
+    max_workgroup_counts = [65535, 65535, 65535]>>
 }>
 #translation = #iree_codegen.translation_info<SPIRVMatmulPromoteVectorize workgroup_size = [8, 2, 1], {pipeline_depth = 0 : i64, store_stage = 1 : i64}>
 #compilation = #iree_codegen.compilation_info<lowering_config = #config, translation_info = #translation>
@@ -154,7 +159,8 @@
   iree.gpu.target = #iree_gpu.target<arch = "", features = "spirv:v1.6,cap:Shader", wgp = <
     compute = fp32|int32, storage = b32, subgroup = none, dot = none, mma = [],
     subgroup_size_choices = [32], max_workgroup_sizes = [128, 128, 64],
-    max_thread_count_per_workgroup = 128, max_workgroup_memory_bytes = 16384>>
+    max_thread_count_per_workgroup = 128, max_workgroup_memory_bytes = 16384,
+    max_workgroup_counts = [65535, 65535, 65535]>>
 }>
 #translation = #iree_codegen.translation_info<SPIRVMatmulPromoteVectorize workgroup_size = [15, 8, 1], {pipeline_depth = 0 : i64, store_stage = 1 : i64}>
 #compilation = #iree_codegen.compilation_info<lowering_config = #config, translation_info = #translation>
@@ -182,7 +188,8 @@
   iree.gpu.target = #iree_gpu.target<arch = "", features = "spirv:v1.6,cap:Shader", wgp = <
     compute = fp32|int32, storage = b32, subgroup = none, dot = none, mma = [],
     subgroup_size_choices = [32], max_workgroup_sizes = [128, 128, 64],
-    max_thread_count_per_workgroup = 128, max_workgroup_memory_bytes = 16384>>
+    max_thread_count_per_workgroup = 128, max_workgroup_memory_bytes = 16384,
+    max_workgroup_counts = [65535, 65535, 65535]>>
 }>
 #translation = #iree_codegen.translation_info<SPIRVMatmulPromoteVectorize workgroup_size = [16, 8, 1], {pipeline_depth = 0 : i64, store_stage = 1 : i64}>
 #compilation = #iree_codegen.compilation_info<lowering_config = #config, translation_info = #translation>
@@ -210,7 +217,8 @@
   iree.gpu.target = #iree_gpu.target<arch = "", features = "spirv:v1.6,cap:Shader", wgp = <
     compute = fp32|int32, storage = b32, subgroup = none, dot = none, mma = [],
     subgroup_size_choices = [32], max_workgroup_sizes = [128, 128, 64],
-    max_thread_count_per_workgroup = 128, max_workgroup_memory_bytes = 16384>>
+    max_thread_count_per_workgroup = 128, max_workgroup_memory_bytes = 16384,
+    max_workgroup_counts = [65535, 65535, 65535]>>
 }>
 #translation = #iree_codegen.translation_info<SPIRVMatmulPromoteVectorize workgroup_size = [16, 8, 1], {pipeline_depth = 0 : i64, store_stage = 1 : i64}>
 #compilation = #iree_codegen.compilation_info<lowering_config = #config, translation_info = #translation>
@@ -239,7 +247,8 @@
     compute = fp32|fp16|int32, storage = b32|b16, subgroup = none, dot = none,
     mma = [<WMMA_F32_16x16x16_F16>, <WMMA_F16_16x16x16_F16>],
     subgroup_size_choices = [32, 64], max_workgroup_sizes = [1024, 1024, 1024],
-    max_thread_count_per_workgroup = 1024, max_workgroup_memory_bytes = 65536>>
+    max_thread_count_per_workgroup = 1024, max_workgroup_memory_bytes = 65536,
+    max_workgroup_counts = [65535, 65535, 65535]>>
 }>
 #translation = #iree_codegen.translation_info<SPIRVCooperativeMatrixVectorize workgroup_size = [128, 2, 1] subgroup_size = 64, {pipeline_depth = 0 : i64, store_stage = 1 : i64}>
 #compilation = #iree_codegen.compilation_info<lowering_config = #config, translation_info = #translation>
@@ -274,7 +283,8 @@
     compute = fp32|fp16|int32, storage = b32|b16, subgroup = none, dot = none,
     mma = [<WMMA_F32_16x16x16_F16>, <WMMA_F16_16x16x16_F16>],
     subgroup_size_choices = [32, 64], max_workgroup_sizes = [1024, 1024, 1024],
-    max_thread_count_per_workgroup = 1024, max_workgroup_memory_bytes = 65536>>
+    max_thread_count_per_workgroup = 1024, max_workgroup_memory_bytes = 65536,
+    max_workgroup_counts = [65535, 65535, 65535]>>
 }>
 #translation = #iree_codegen.translation_info<SPIRVCooperativeMatrixVectorize workgroup_size = [128, 2, 1] subgroup_size = 64, {pipeline_depth = 0 : i64, store_stage = 1 : i64}>
 #compilation = #iree_codegen.compilation_info<lowering_config = #config, translation_info = #translation>
@@ -309,7 +319,8 @@
     compute = fp32|fp16|int32, storage = b32|b16, subgroup = none, dot = none,
     mma = [<WMMA_F32_16x16x16_F16>, <WMMA_F16_16x16x16_F16>],
     subgroup_size_choices = [32, 64], max_workgroup_sizes = [1024, 1024, 1024],
-    max_thread_count_per_workgroup = 1024, max_workgroup_memory_bytes = 65536>>
+    max_thread_count_per_workgroup = 1024, max_workgroup_memory_bytes = 65536,
+    max_workgroup_counts = [65535, 65535, 65535]>>
 }>
 #translation = #iree_codegen.translation_info<SPIRVCooperativeMatrixVectorize workgroup_size = [256, 4, 1] subgroup_size = 64, {pipeline_depth = 0 : i64, store_stage = 1 : i64}>
 #compilation = #iree_codegen.compilation_info<lowering_config = #config, translation_info = #translation>
@@ -344,7 +355,8 @@
     compute = fp32|fp16|int32, storage = b32|b16, subgroup = none, dot = none,
     mma = [<WMMA_F32_16x16x16_F16>, <WMMA_F16_16x16x16_F16>],
     subgroup_size_choices = [32, 64], max_workgroup_sizes = [1024, 1024, 1024],
-    max_thread_count_per_workgroup = 1024, max_workgroup_memory_bytes = 65536>>
+    max_thread_count_per_workgroup = 1024, max_workgroup_memory_bytes = 65536,
+    max_workgroup_counts = [65535, 65535, 65535]>>
 }>
 #translation = #iree_codegen.translation_info<SPIRVCooperativeMatrixVectorize workgroup_size = [64, 2, 1] subgroup_size = 64, {pipeline_depth = 0 : i64, store_stage = 1 : i64}>
 #compilation = #iree_codegen.compilation_info<lowering_config = #config, translation_info = #translation>
@@ -379,7 +391,8 @@
     compute = fp32|fp16|int32, storage = b32|b16, subgroup = none, dot = none,
     mma = [<WMMA_F32_16x16x16_F16>, <WMMA_F16_16x16x16_F16>],
     subgroup_size_choices = [32, 64], max_workgroup_sizes = [1024, 1024, 1024],
-    max_thread_count_per_workgroup = 1024, max_workgroup_memory_bytes = 65536>>
+    max_thread_count_per_workgroup = 1024, max_workgroup_memory_bytes = 65536,
+    max_workgroup_counts = [65535, 65535, 65535]>>
 }>
 #translation = #iree_codegen.translation_info<SPIRVCooperativeMatrixVectorize workgroup_size = [128, 4, 1] subgroup_size = 64, {pipeline_depth = 0 : i64, store_stage = 1 : i64}>
 #compilation = #iree_codegen.compilation_info<lowering_config = #config, translation_info = #translation>
@@ -413,7 +426,8 @@
   iree.gpu.target = #iree_gpu.target<arch = "", features = "spirv:v1.6,cap:Shader", wgp = <
     compute = fp32|int32, storage = b32, subgroup = none, dot = none, mma = [],
     subgroup_size_choices = [32], max_workgroup_sizes = [128, 128, 64],
-    max_thread_count_per_workgroup = 128, max_workgroup_memory_bytes = 16384>>
+    max_thread_count_per_workgroup = 128, max_workgroup_memory_bytes = 16384,
+    max_workgroup_counts = [65535, 65535, 65535]>>
 }>
 #map = affine_map<()[s0] -> (s0 * 4)>
 #map1 = affine_map<()[s0] -> (s0 * 16)>
@@ -472,7 +486,8 @@
   iree.gpu.target = #iree_gpu.target<arch = "", features = "spirv:v1.6,cap:Shader", wgp = <
     compute = fp32|int32, storage = b32, subgroup = none, dot = none, mma = [],
     subgroup_size_choices = [32], max_workgroup_sizes = [128, 128, 64],
-    max_thread_count_per_workgroup = 128, max_workgroup_memory_bytes = 16384>>
+    max_thread_count_per_workgroup = 128, max_workgroup_memory_bytes = 16384,
+    max_workgroup_counts = [65535, 65535, 65535]>>
 }>
 #map = affine_map<()[s0] -> (s0 * 4)>
 #map1 = affine_map<()[s0] -> (s0 * 16)>
@@ -531,7 +546,8 @@
   iree.gpu.target = #iree_gpu.target<arch = "", features = "spirv:v1.6,cap:Shader", wgp = <
     compute = fp32|int32, storage = b32, subgroup = none, dot = none, mma = [],
     subgroup_size_choices = [32], max_workgroup_sizes = [128, 128, 64],
-    max_thread_count_per_workgroup = 128, max_workgroup_memory_bytes = 16384>>
+    max_thread_count_per_workgroup = 128, max_workgroup_memory_bytes = 16384,
+    max_workgroup_counts = [65535, 65535, 65535]>>
 }>
 #map = affine_map<()[s0] -> (s0 * 4)>
 #map1 = affine_map<()[s0] -> (s0 * 16)>
@@ -590,7 +606,8 @@
   iree.gpu.target = #iree_gpu.target<arch = "", features = "spirv:v1.6,cap:Shader", wgp = <
     compute = fp32|int32, storage = b32, subgroup = none, dot = none, mma = [],
     subgroup_size_choices = [32], max_workgroup_sizes = [128, 128, 64],
-    max_thread_count_per_workgroup = 128, max_workgroup_memory_bytes = 16384>>
+    max_thread_count_per_workgroup = 128, max_workgroup_memory_bytes = 16384,
+    max_workgroup_counts = [65535, 65535, 65535]>>
 }>
 #translation = #iree_codegen.translation_info<SPIRVBaseVectorize workgroup_size = [32, 1, 1]>
 #compilation = #iree_codegen.compilation_info<lowering_config = #config, translation_info = #translation>
@@ -618,7 +635,8 @@
   iree.gpu.target = #iree_gpu.target<arch = "", features = "spirv:v1.6,cap:Shader", wgp = <
     compute = fp32|int32, storage = b32, subgroup = none, dot = none, mma = [],
     subgroup_size_choices = [32], max_workgroup_sizes = [128, 128, 64],
-    max_thread_count_per_workgroup = 128, max_workgroup_memory_bytes = 16384>>
+    max_thread_count_per_workgroup = 128, max_workgroup_memory_bytes = 16384,
+    max_workgroup_counts = [65535, 65535, 65535]>>
 }>
 #translation = #iree_codegen.translation_info<SPIRVBaseVectorize workgroup_size = [32, 1, 1]>
 #compilation = #iree_codegen.compilation_info<lowering_config = #config, translation_info = #translation>
diff --git a/compiler/src/iree/compiler/Codegen/SPIRV/test/lowering_reduction.mlir b/compiler/src/iree/compiler/Codegen/SPIRV/test/lowering_reduction.mlir
index 243a361..6d4d163 100644
--- a/compiler/src/iree/compiler/Codegen/SPIRV/test/lowering_reduction.mlir
+++ b/compiler/src/iree/compiler/Codegen/SPIRV/test/lowering_reduction.mlir
@@ -12,7 +12,8 @@
   iree.gpu.target = #iree_gpu.target<arch = "", features = "spirv:v1.6,cap:Shader", wgp = <
     compute = fp32|int32, storage = b32, subgroup = shuffle|arithmetic, dot = none, mma = [],
     subgroup_size_choices = [32], max_workgroup_sizes = [128, 128, 64],
-    max_thread_count_per_workgroup = 128, max_workgroup_memory_bytes = 65536>>
+    max_thread_count_per_workgroup = 128, max_workgroup_memory_bytes = 65536,
+    max_workgroup_counts = [65535, 65535, 65535]>>
 }>
 #map = affine_map<(d0, d1) -> (d0, d1)>
 #map1 = affine_map<(d0, d1) -> (d0)>
@@ -100,7 +101,8 @@
   iree.gpu.target = #iree_gpu.target<arch = "", features = "spirv:v1.6,cap:Shader", wgp = <
     compute = fp32|int32, storage = b32, subgroup = shuffle|arithmetic, dot = none, mma = [],
     subgroup_size_choices = [32], max_workgroup_sizes = [128, 128, 64],
-    max_thread_count_per_workgroup = 128, max_workgroup_memory_bytes = 65536>>
+    max_thread_count_per_workgroup = 128, max_workgroup_memory_bytes = 65536,
+    max_workgroup_counts = [65535, 65535, 65535]>>
 }>
 #map = affine_map<(d0, d1, d2) -> (d0, d1, d2)>
 #map1 = affine_map<(d0, d1, d2) -> (d0, d1)>
@@ -183,7 +185,8 @@
   iree.gpu.target = #iree_gpu.target<arch = "", features = "spirv:v1.6,cap:Shader", wgp = <
     compute = fp32|int32, storage = b32, subgroup = shuffle|arithmetic, dot = none, mma = [],
     subgroup_size_choices = [32], max_workgroup_sizes = [1024, 1024, 64],
-    max_thread_count_per_workgroup = 1024, max_workgroup_memory_bytes = 65536>>
+    max_thread_count_per_workgroup = 1024, max_workgroup_memory_bytes = 65536,
+    max_workgroup_counts = [65535, 65535, 65535]>>
 }>
 func.func @softmax() attributes {hal.executable.target = #executable_target_vulkan_spirv_fb} {
   %c0 = arith.constant 0 : index
@@ -290,7 +293,8 @@
   iree.gpu.target = #iree_gpu.target<arch = "", features = "spirv:v1.6,cap:Shader", wgp = <
     compute = fp32|fp16|int32, storage = b32|b16, subgroup = shuffle|arithmetic, dot = none, mma = [],
     subgroup_size_choices = [64], max_workgroup_sizes = [1024, 1024, 1024],
-    max_thread_count_per_workgroup = 1024, max_workgroup_memory_bytes = 65536>>
+    max_thread_count_per_workgroup = 1024, max_workgroup_memory_bytes = 65536,
+    max_workgroup_counts = [65535, 65535, 65535]>>
 }>
 func.func @dynamic_softmax() attributes {hal.executable.target = #executable_target_vulkan_spirv_fb} {
   %c32_i64 = arith.constant 32 : i64
diff --git a/compiler/src/iree/compiler/Codegen/SPIRV/test/map_memref_storage_class.mlir b/compiler/src/iree/compiler/Codegen/SPIRV/test/map_memref_storage_class.mlir
index 9c622d8..a2507c3 100644
--- a/compiler/src/iree/compiler/Codegen/SPIRV/test/map_memref_storage_class.mlir
+++ b/compiler/src/iree/compiler/Codegen/SPIRV/test/map_memref_storage_class.mlir
@@ -6,7 +6,7 @@
       compute = fp32|int32, storage = b32, subgroup = shuffle|arithmetic,
       dot = none, mma = [], subgroup_size_choices = [64],
       max_workgroup_sizes = [128, 128, 64], max_thread_count_per_workgroup = 128,
-      max_workgroup_memory_bytes = 16384>>}>
+      max_workgroup_memory_bytes = 16384, max_workgroup_counts = [65535, 65535, 65535]>>}>
 
 func.func @vulkan_client_api() attributes {hal.executable.target = #target} {
   %0 = "dialect.memref_producer"() : () -> (memref<?x8xf32, #hal.descriptor_type<uniform_buffer>>)
@@ -45,7 +45,7 @@
       compute = fp32|int32, storage = b32, subgroup = shuffle|arithmetic,
       dot = none, mma = [], subgroup_size_choices = [64],
       max_workgroup_sizes = [128, 128, 64], max_thread_count_per_workgroup = 128,
-      max_workgroup_memory_bytes = 16384>>}>
+      max_workgroup_memory_bytes = 16384, max_workgroup_counts = [65535, 65535, 65535]>>}>
 
 func.func @opencl_client_api() attributes {hal.executable.target = #target} {
   %0 = "dialect.memref_producer"() : () -> (memref<?x8xf32, #hal.descriptor_type<uniform_buffer>>)
diff --git a/compiler/src/iree/compiler/Codegen/SPIRV/test/pipeline_matvec.mlir b/compiler/src/iree/compiler/Codegen/SPIRV/test/pipeline_matvec.mlir
index aa5d7cb..aed1a32 100644
--- a/compiler/src/iree/compiler/Codegen/SPIRV/test/pipeline_matvec.mlir
+++ b/compiler/src/iree/compiler/Codegen/SPIRV/test/pipeline_matvec.mlir
@@ -16,7 +16,8 @@
     iree.gpu.target = #iree_gpu.target<arch = "", features = "spirv:v1.6,cap:Shader", wgp = <
       compute = fp32|fp16|int32, storage = b32|b16, subgroup = shuffle|arithmetic, dot = none, mma = [],
       subgroup_size_choices = [32], max_workgroup_sizes = [1024, 1024, 1024],
-      max_thread_count_per_workgroup = 1024, max_workgroup_memory_bytes = 65536>>
+      max_thread_count_per_workgroup = 1024, max_workgroup_memory_bytes = 65536,
+      max_workgroup_counts = [65535, 65535, 65535]>>
     }>) {
     hal.executable.export @i4_dequant_unit_matmul_f16 layout(#pipeline_layout) {
     ^bb0(%arg0: !hal.device):
@@ -125,7 +126,8 @@
     iree.gpu.target = #iree_gpu.target<arch = "", features = "spirv:v1.6,cap:Shader", wgp = <
       compute = fp32|fp16|int32, storage = b32|b16, subgroup = shuffle|arithmetic, dot = none, mma = [],
       subgroup_size_choices = [64], max_workgroup_sizes = [1024, 1024, 1024],
-      max_thread_count_per_workgroup = 1024, max_workgroup_memory_bytes = 65536>>
+      max_thread_count_per_workgroup = 1024, max_workgroup_memory_bytes = 65536,
+      max_workgroup_counts = [65535, 65535, 65535]>>
   }>) {
     hal.executable.export @i4_dequant_matvec_f16_subgroup_64 layout(#pipeline_layout) {
     ^bb0(%arg0: !hal.device):
diff --git a/compiler/src/iree/compiler/Codegen/SPIRV/test/vectorize_conv.mlir b/compiler/src/iree/compiler/Codegen/SPIRV/test/vectorize_conv.mlir
index 105012f..50aa57e 100644
--- a/compiler/src/iree/compiler/Codegen/SPIRV/test/vectorize_conv.mlir
+++ b/compiler/src/iree/compiler/Codegen/SPIRV/test/vectorize_conv.mlir
@@ -30,7 +30,8 @@
   iree.gpu.target = #iree_gpu.target<arch = "", features = "spirv:v1.6,cap:Shader", wgp = <
     compute = fp32|int32|int16|int8, storage = b32|b16|b8, subgroup = none, dot = dp4xi8toi32, mma = [],
     subgroup_size_choices = [64], max_workgroup_sizes = [1024, 1024, 1024],
-    max_thread_count_per_workgroup = 1024, max_workgroup_memory_bytes = 65536>>
+    max_thread_count_per_workgroup = 1024, max_workgroup_memory_bytes = 65536,
+    max_workgroup_counts = [65535, 65535, 65535]>>
 } {
   %c0 = arith.constant 0 : i32
   %i0 = arith.constant 0 : index
diff --git a/compiler/src/iree/compiler/Codegen/SPIRV/test/vectorize_matmul.mlir b/compiler/src/iree/compiler/Codegen/SPIRV/test/vectorize_matmul.mlir
index c2da6a5..03644e7 100644
--- a/compiler/src/iree/compiler/Codegen/SPIRV/test/vectorize_matmul.mlir
+++ b/compiler/src/iree/compiler/Codegen/SPIRV/test/vectorize_matmul.mlir
@@ -271,7 +271,8 @@
   iree.gpu.target = #iree_gpu.target<arch = "", features = "spirv:v1.6,cap:Shader", wgp = <
     compute = fp32|int32|int16|int8, storage = b32|b16|b8, subgroup = none, dot = dp4xi8toi32, mma = [],
     subgroup_size_choices = [64], max_workgroup_sizes = [1024, 1024, 1024],
-    max_thread_count_per_workgroup = 1024, max_workgroup_memory_bytes = 65536>>
+    max_thread_count_per_workgroup = 1024, max_workgroup_memory_bytes = 65536,
+    max_workgroup_counts = [65535, 65535, 65535]>>
 } {
   %c0 = arith.constant 0 : i32
   %i0 = arith.constant 0 : index
@@ -330,7 +331,8 @@
   iree.gpu.target = #iree_gpu.target<arch = "", features = "spirv:v1.6,cap:Shader", wgp = <
     compute = fp32|int32|int16|int8, storage = b32|b16|b8, subgroup = none, dot = dp4xi8toi32, mma = [],
     subgroup_size_choices = [64], max_workgroup_sizes = [1024, 1024, 1024],
-    max_thread_count_per_workgroup = 1024, max_workgroup_memory_bytes = 65536>>
+    max_thread_count_per_workgroup = 1024, max_workgroup_memory_bytes = 65536,
+    max_workgroup_counts = [65535, 65535, 65535]>>
 } {
   %c0 = arith.constant 0 : i32
   %i0 = arith.constant 0 : index
diff --git a/compiler/src/iree/compiler/Preprocessing/Common/test/pad_to_intrinsics_mfma.mlir b/compiler/src/iree/compiler/Preprocessing/Common/test/pad_to_intrinsics_mfma.mlir
index dc3e66d..84d110a 100644
--- a/compiler/src/iree/compiler/Preprocessing/Common/test/pad_to_intrinsics_mfma.mlir
+++ b/compiler/src/iree/compiler/Preprocessing/Common/test/pad_to_intrinsics_mfma.mlir
@@ -69,7 +69,8 @@
   subgroup = shuffle|arithmetic, dot = dp4xi8toi32,
   mma = [<MFMA_F32_32x32x8_F16>],
   subgroup_size_choices = [64], max_workgroup_sizes = [1024, 1024, 1024],
-  max_thread_count_per_workgroup = 1024, max_workgroup_memory_bytes = 65536>>
+  max_thread_count_per_workgroup = 1024, max_workgroup_memory_bytes = 65536,
+  max_workgroup_counts = [2147483647, 2147483647, 2147483647]>>
 #rocm_executable_target = #hal.executable.target<"rocm", "rocm-hsaco-fb", {iree.gpu.target = #target, ukernels = "none"}>
 
 // CHECK-LABEL: func.func @main2(
diff --git a/samples/custom_dispatch/vulkan/shaders/example.mlir b/samples/custom_dispatch/vulkan/shaders/example.mlir
index 69843aa..d9cb5e1 100644
--- a/samples/custom_dispatch/vulkan/shaders/example.mlir
+++ b/samples/custom_dispatch/vulkan/shaders/example.mlir
@@ -19,7 +19,8 @@
       compute = fp32|int32, storage = b32, subgroup = none,
       dot = none, mma = [], subgroup_size_choices = [64, 64],
       max_workgroup_sizes = [128, 128, 64], max_thread_count_per_workgroup = 128,
-      max_workgroup_memory_bytes = 16384>
+      max_workgroup_memory_bytes = 16384,
+      max_workgroup_counts = [65535, 65535, 65535]>
   >
 }>
 
diff --git a/samples/custom_dispatch/vulkan/shaders/example_inline.mlir b/samples/custom_dispatch/vulkan/shaders/example_inline.mlir
index 4157651..2882134 100644
--- a/samples/custom_dispatch/vulkan/shaders/example_inline.mlir
+++ b/samples/custom_dispatch/vulkan/shaders/example_inline.mlir
@@ -19,7 +19,8 @@
       compute = fp32|int32, storage = b32, subgroup = none,
       dot = none, mma = [], subgroup_size_choices = [64, 64],
       max_workgroup_sizes = [128, 128, 64], max_thread_count_per_workgroup = 128,
-      max_workgroup_memory_bytes = 16384>
+      max_workgroup_memory_bytes = 16384,
+      max_workgroup_counts = [65535, 65535, 65535]>
   >
 }>
 
diff --git a/samples/custom_dispatch/vulkan/shaders/example_transform.mlir b/samples/custom_dispatch/vulkan/shaders/example_transform.mlir
index 4bea02d..08c40a1 100644
--- a/samples/custom_dispatch/vulkan/shaders/example_transform.mlir
+++ b/samples/custom_dispatch/vulkan/shaders/example_transform.mlir
@@ -23,7 +23,8 @@
       compute = fp32|int32, storage = b32, subgroup = shuffle|arithmetic,
       dot = none, mma = [], subgroup_size_choices = [64, 64],
       max_workgroup_sizes = [128, 128, 64], max_thread_count_per_workgroup = 128,
-      max_workgroup_memory_bytes = 16384>
+      max_workgroup_memory_bytes = 16384,
+      max_workgroup_counts = [65535, 65535, 65535]>
   >
 }>
 
diff --git a/samples/custom_dispatch/vulkan/shaders/example_transform_spec.mlir b/samples/custom_dispatch/vulkan/shaders/example_transform_spec.mlir
index 5bcdafe..8e23206 100644
--- a/samples/custom_dispatch/vulkan/shaders/example_transform_spec.mlir
+++ b/samples/custom_dispatch/vulkan/shaders/example_transform_spec.mlir
@@ -12,7 +12,8 @@
       compute = fp32|int32, storage = b32, subgroup = shuffle|arithmetic,
       dot = none, mma = [], subgroup_size_choices = [64, 64],
       max_workgroup_sizes = [128, 128, 64], max_thread_count_per_workgroup = 128,
-      max_workgroup_memory_bytes = 16384>
+      max_workgroup_memory_bytes = 16384,
+      max_workgroup_counts = [65535, 65535, 65535]>
   >
 }>
 
diff --git a/samples/transform_dialect/example_module.mlir b/samples/transform_dialect/example_module.mlir
index 723bbbf..2fb3498 100644
--- a/samples/transform_dialect/example_module.mlir
+++ b/samples/transform_dialect/example_module.mlir
@@ -27,7 +27,7 @@
 
 #target = #iree_gpu.target<arch = "", features = "spirv:v1.3,cap:Shader", wgp = <
   compute = fp32|int32, storage = b32, subgroup = none, dot = none, mma = [], subgroup_size_choices = [64, 64],
-  max_workgroup_sizes = [128, 128, 64], max_thread_count_per_workgroup = 128, max_workgroup_memory_bytes = 16384>>
+  max_workgroup_sizes = [128, 128, 64], max_thread_count_per_workgroup = 128, max_workgroup_memory_bytes = 16384, max_workgroup_counts = [65535, 65535, 65535]>>
 
 #pipeline_layout_0 = #hal.pipeline.layout<push_constants = 0, sets = [<0, bindings = [<0, storage_buffer, ReadOnly>, <1, storage_buffer>]>]>
 #pipeline_layout_1 = #hal.pipeline.layout<push_constants = 0, sets = [<0, bindings = [<0, storage_buffer, ReadOnly>, <1, storage_buffer, ReadOnly>, <2, storage_buffer>]>]>