[Codegen] Lower `hal.interface.workgroup.size` in GPU codegen (#18145)

Cleanup related to https://github.com/iree-org/iree/issues/16554

---------

Signed-off-by: nithinsubbiah <nithinsubbiah@gmail.com>
diff --git a/compiler/src/iree/compiler/Codegen/LLVMGPU/ConvertToLLVM.cpp b/compiler/src/iree/compiler/Codegen/LLVMGPU/ConvertToLLVM.cpp
index 8319790..821623d 100644
--- a/compiler/src/iree/compiler/Codegen/LLVMGPU/ConvertToLLVM.cpp
+++ b/compiler/src/iree/compiler/Codegen/LLVMGPU/ConvertToLLVM.cpp
@@ -533,10 +533,12 @@
 }
 
 void populateLowerHALInterfaceOp(RewritePatternSet &patterns) {
-  patterns.insert<HALInterfaceWorkgroupOpsConverter<
-                      IREE::HAL::InterfaceWorkgroupIDOp, gpu::BlockIdOp>,
-                  HALInterfaceWorkgroupOpsConverter<
-                      IREE::HAL::InterfaceWorkgroupCountOp, gpu::GridDimOp>>(
+  patterns.add<HALInterfaceWorkgroupOpsConverter<
+                   IREE::HAL::InterfaceWorkgroupIDOp, gpu::BlockIdOp>,
+               HALInterfaceWorkgroupOpsConverter<
+                   IREE::HAL::InterfaceWorkgroupSizeOp, gpu::BlockDimOp>,
+               HALInterfaceWorkgroupOpsConverter<
+                   IREE::HAL::InterfaceWorkgroupCountOp, gpu::GridDimOp>>(
       patterns.getContext());
 }
 
diff --git a/compiler/src/iree/compiler/Codegen/LLVMGPU/test/convert_to_nvvm.mlir b/compiler/src/iree/compiler/Codegen/LLVMGPU/test/convert_to_nvvm.mlir
index 0bd40d0..ea5f33f 100644
--- a/compiler/src/iree/compiler/Codegen/LLVMGPU/test/convert_to_nvvm.mlir
+++ b/compiler/src/iree/compiler/Codegen/LLVMGPU/test/convert_to_nvvm.mlir
@@ -426,3 +426,33 @@
 //       CHECK:   %[[MASK_BIT:.+]] = llvm.icmp "sgt" {{.*}} : vector<1xi64>
 //       CHECK:   llvm.intr.masked.load %{{.*}}, %[[MASK_BIT]]
 //       CHECK:   llvm.intr.masked.store %{{.*}}, %[[MASK_BIT]]
+
+// -----
+// Test workgroup size lowering
+#pipeline_layout = #hal.pipeline.layout<push_constants = 0, sets = [
+  #hal.descriptor_set.layout<0, bindings = [
+    #hal.descriptor_set.binding<0, storage_buffer>,
+    #hal.descriptor_set.binding<1, storage_buffer>,
+    #hal.descriptor_set.binding<2, storage_buffer>
+  ]>
+]>
+hal.executable private @interface_wg_size {
+  hal.executable.variant @rocm target(<"cuda", "cuda-nvptx-fb">) {
+    hal.executable.export @interface_wg_size layout(#pipeline_layout) attributes {
+      workgroup_size = [32: index, 1: index, 1: index]
+    }
+    builtin.module attributes {} {
+      func.func @interface_wg_size() {
+        %c0 = arith.constant 0.0 : f32
+        %workgroup_size_x = hal.interface.workgroup.size[0] : index
+        %workgroup_size_y = hal.interface.workgroup.size[1] : index
+        %subspan = hal.interface.binding.subspan layout(#pipeline_layout) set(0) binding(0) : memref<64x64xf32>
+        memref.store %c0, %subspan[%workgroup_size_x, %workgroup_size_y] : memref<64x64xf32>
+        return
+      }
+    }
+  }
+}
+// CHECK-LABEL: llvm.func @interface_wg_size
+//       CHECK:   %[[WGDIMX:.+]] = nvvm.read.ptx.sreg.ntid.x
+//       CHECK:   %[[WGDIMY:.+]] = nvvm.read.ptx.sreg.ntid.y
diff --git a/compiler/src/iree/compiler/Codegen/LLVMGPU/test/convert_to_rocdl.mlir b/compiler/src/iree/compiler/Codegen/LLVMGPU/test/convert_to_rocdl.mlir
index b4dd4ca..d5c9931 100644
--- a/compiler/src/iree/compiler/Codegen/LLVMGPU/test/convert_to_rocdl.mlir
+++ b/compiler/src/iree/compiler/Codegen/LLVMGPU/test/convert_to_rocdl.mlir
@@ -129,3 +129,33 @@
 //       CHECK:   %[[MASK_BIT:.+]] = llvm.icmp "sgt" {{.*}} : vector<1xi64>
 //       CHECK:   llvm.intr.masked.load %{{.*}}, %[[MASK_BIT]]
 //       CHECK:   llvm.intr.masked.store %{{.*}}, %[[MASK_BIT]]
+
+// -----
+// Test workgroup size lowering
+#pipeline_layout = #hal.pipeline.layout<push_constants = 0, sets = [
+  #hal.descriptor_set.layout<0, bindings = [
+    #hal.descriptor_set.binding<0, storage_buffer>,
+    #hal.descriptor_set.binding<1, storage_buffer>,
+    #hal.descriptor_set.binding<2, storage_buffer>
+  ]>
+]>
+hal.executable private @interface_wg_size {
+  hal.executable.variant @rocm target(<"rocm", "rocm-hsaco-fb">) {
+    hal.executable.export @interface_wg_size layout(#pipeline_layout) attributes {
+      workgroup_size = [32: index, 1: index, 1: index]
+    }
+    builtin.module attributes {} {
+      func.func @interface_wg_size() {
+        %c0 = arith.constant 0.0 : f32
+        %workgroup_size_x = hal.interface.workgroup.size[0] : index
+        %workgroup_size_y = hal.interface.workgroup.size[1] : index
+        %subspan = hal.interface.binding.subspan layout(#pipeline_layout) set(0) binding(0) : memref<64x64xf32>
+        memref.store %c0, %subspan[%workgroup_size_x, %workgroup_size_y] : memref<64x64xf32>
+        return
+      }
+    }
+  }
+}
+// CHECK-LABEL: llvm.func @interface_wg_size
+//       CHECK:   %[[WGDIMX:.+]] = rocdl.workgroup.dim.x
+//       CHECK:   %[[WGDIMY:.+]] = rocdl.workgroup.dim.y
diff --git a/compiler/src/iree/compiler/Codegen/SPIRV/ConvertToSPIRVPass.cpp b/compiler/src/iree/compiler/Codegen/SPIRV/ConvertToSPIRVPass.cpp
index 442d530..0aad84d 100644
--- a/compiler/src/iree/compiler/Codegen/SPIRV/ConvertToSPIRVPass.cpp
+++ b/compiler/src/iree/compiler/Codegen/SPIRV/ConvertToSPIRVPass.cpp
@@ -302,10 +302,10 @@
   }
 };
 
-/// A pattern to convert hal.interface.workgroup.id/count into corresponding
-/// SPIR-V Builtin ops.
+/// A pattern to convert hal.interface.workgroup.id/count/size into
+/// corresponding SPIR-V Builtin ops.
 template <typename InterfaceOpTy, spirv::BuiltIn builtin>
-struct HALInterfaceWorkgroupIdAndCountConverter final
+struct HALInterfaceWorkgroupOpsConverter final
     : OpConversionPattern<InterfaceOpTy> {
   using OpConversionPattern<InterfaceOpTy>::OpConversionPattern;
 
@@ -656,10 +656,12 @@
   // Add IREE HAL interface op conversions.
   patterns.add<
       HALInterfaceLoadConstantConverter,
-      HALInterfaceWorkgroupIdAndCountConverter<
-          IREE::HAL::InterfaceWorkgroupIDOp, spirv::BuiltIn::WorkgroupId>,
-      HALInterfaceWorkgroupIdAndCountConverter<
-          IREE::HAL::InterfaceWorkgroupCountOp, spirv::BuiltIn::NumWorkgroups>>(
+      HALInterfaceWorkgroupOpsConverter<IREE::HAL::InterfaceWorkgroupIDOp,
+                                        spirv::BuiltIn::WorkgroupId>,
+      HALInterfaceWorkgroupOpsConverter<IREE::HAL::InterfaceWorkgroupSizeOp,
+                                        spirv::BuiltIn::WorkgroupSize>,
+      HALInterfaceWorkgroupOpsConverter<IREE::HAL::InterfaceWorkgroupCountOp,
+                                        spirv::BuiltIn::NumWorkgroups>>(
       typeConverter, context);
 
   // Performs a prelimiary step to analyze all hal.interface.binding.subspan ops
diff --git a/compiler/src/iree/compiler/Codegen/SPIRV/test/convert_to_spirv.mlir b/compiler/src/iree/compiler/Codegen/SPIRV/test/convert_to_spirv.mlir
index b518367..9fcdab1 100644
--- a/compiler/src/iree/compiler/Codegen/SPIRV/test/convert_to_spirv.mlir
+++ b/compiler/src/iree/compiler/Codegen/SPIRV/test/convert_to_spirv.mlir
@@ -247,6 +247,44 @@
     #hal.descriptor_set.binding<2, storage_buffer>
   ]>
 ]>
+hal.executable private @interface_wg_size {
+  hal.executable.variant @vulkan target(<"vulkan-spirv", "vulkan-spirv-fb">) {
+    hal.executable.export @interface_wg_size layout(#pipeline_layout) attributes {
+      workgroup_size = [32: index, 1: index, 1: index]
+    }
+    builtin.module attributes {spirv.target_env = #spirv.target_env<#spirv.vce<v1.3, [Int64, Shader], []>, #spirv.resource_limits<>>} {
+      func.func @interface_wg_size() {
+        %c0 = arith.constant 0.0 : f32
+        %workgroup_size_x = hal.interface.workgroup.size[0] : index
+        %workgroup_size_y = hal.interface.workgroup.size[1] : index
+        %subspan = hal.interface.binding.subspan layout(#pipeline_layout) set(0) binding(0) : memref<64x64xf32, #spirv.storage_class<StorageBuffer>>
+        memref.store %c0, %subspan[%workgroup_size_x, %workgroup_size_y] : memref<64x64xf32, #spirv.storage_class<StorageBuffer>>
+        return
+      }
+    }
+  }
+}
+
+// CHECK-LABEL: spirv.module
+//   CHECK-DAG:   spirv.GlobalVariable @[[WGSIZE:.+]] built_in("WorkgroupSize")
+//   CHECK-DAG:   spirv.GlobalVariable @[[BIND:.+]] bind(0, 0)
+//       CHECK:     %[[CST0:.+]] = spirv.Constant 0.000000e+00 : f32
+//       CHECK:     %[[ADDR1:.+]] = spirv.mlir.addressof @[[WGSIZE]]
+//       CHECK:     %[[VAL1:.+]] = spirv.Load "Input" %[[ADDR1:.+]]
+//       CHECK:     %[[WGSIZEX:.+]] = spirv.CompositeExtract %[[VAL1]][0 : i32]
+//       CHECK:     %[[ADDR2:.+]] = spirv.mlir.addressof @[[WGSIZE]]
+//       CHECK:     %[[VAL2:.+]] = spirv.Load "Input" %[[ADDR2:.+]]
+//       CHECK:     %[[WGSIZEY:.+]] = spirv.CompositeExtract %[[VAL2]][1 : i32]
+
+// -----
+
+#pipeline_layout = #hal.pipeline.layout<push_constants = 0, sets = [
+  #hal.descriptor_set.layout<0, bindings = [
+    #hal.descriptor_set.binding<0, storage_buffer>,
+    #hal.descriptor_set.binding<1, storage_buffer>,
+    #hal.descriptor_set.binding<2, storage_buffer>
+  ]>
+]>
 hal.executable private @interface_wg_count {
   hal.executable.variant @vulkan target(<"vulkan-spirv", "vulkan-spirv-fb">) {
     hal.executable.export @interface_wg_count layout(#pipeline_layout) attributes {