[Codegen] Lower `hal.interface.workgroup.size` in GPU codegen (#18145)
Cleanup related to https://github.com/iree-org/iree/issues/16554
---------
Signed-off-by: nithinsubbiah <nithinsubbiah@gmail.com>
diff --git a/compiler/src/iree/compiler/Codegen/LLVMGPU/ConvertToLLVM.cpp b/compiler/src/iree/compiler/Codegen/LLVMGPU/ConvertToLLVM.cpp
index 8319790..821623d 100644
--- a/compiler/src/iree/compiler/Codegen/LLVMGPU/ConvertToLLVM.cpp
+++ b/compiler/src/iree/compiler/Codegen/LLVMGPU/ConvertToLLVM.cpp
@@ -533,10 +533,12 @@
}
void populateLowerHALInterfaceOp(RewritePatternSet &patterns) {
- patterns.insert<HALInterfaceWorkgroupOpsConverter<
- IREE::HAL::InterfaceWorkgroupIDOp, gpu::BlockIdOp>,
- HALInterfaceWorkgroupOpsConverter<
- IREE::HAL::InterfaceWorkgroupCountOp, gpu::GridDimOp>>(
+ patterns.add<HALInterfaceWorkgroupOpsConverter< // One converter per HAL workgroup query op.
+ IREE::HAL::InterfaceWorkgroupIDOp, gpu::BlockIdOp>,
+ HALInterfaceWorkgroupOpsConverter<
+ IREE::HAL::InterfaceWorkgroupSizeOp, gpu::BlockDimOp>, // Newly registered: workgroup size queries.
+ HALInterfaceWorkgroupOpsConverter<
+ IREE::HAL::InterfaceWorkgroupCountOp, gpu::GridDimOp>>(
patterns.getContext());
}
diff --git a/compiler/src/iree/compiler/Codegen/LLVMGPU/test/convert_to_nvvm.mlir b/compiler/src/iree/compiler/Codegen/LLVMGPU/test/convert_to_nvvm.mlir
index 0bd40d0..ea5f33f 100644
--- a/compiler/src/iree/compiler/Codegen/LLVMGPU/test/convert_to_nvvm.mlir
+++ b/compiler/src/iree/compiler/Codegen/LLVMGPU/test/convert_to_nvvm.mlir
@@ -426,3 +426,33 @@
// CHECK: %[[MASK_BIT:.+]] = llvm.icmp "sgt" {{.*}} : vector<1xi64>
// CHECK: llvm.intr.masked.load %{{.*}}, %[[MASK_BIT]]
// CHECK: llvm.intr.masked.store %{{.*}}, %[[MASK_BIT]]
+
+// -----
+// Test that hal.interface.workgroup.size[0]/[1] lower to nvvm.read.ptx.sreg.ntid.x/.y.
+#pipeline_layout = #hal.pipeline.layout<push_constants = 0, sets = [
+ #hal.descriptor_set.layout<0, bindings = [
+ #hal.descriptor_set.binding<0, storage_buffer>,
+ #hal.descriptor_set.binding<1, storage_buffer>,
+ #hal.descriptor_set.binding<2, storage_buffer>
+ ]>
+]>
+hal.executable private @interface_wg_size {
+ hal.executable.variant @cuda target(<"cuda", "cuda-nvptx-fb">) {
+ hal.executable.export @interface_wg_size layout(#pipeline_layout) attributes {
+ workgroup_size = [32: index, 1: index, 1: index]
+ }
+ builtin.module attributes {} {
+ func.func @interface_wg_size() {
+ %c0 = arith.constant 0.0 : f32
+ %workgroup_size_x = hal.interface.workgroup.size[0] : index
+ %workgroup_size_y = hal.interface.workgroup.size[1] : index
+ %subspan = hal.interface.binding.subspan layout(#pipeline_layout) set(0) binding(0) : memref<64x64xf32>
+ // Both sizes are used as store indices so neither query is left unused.
+ memref.store %c0, %subspan[%workgroup_size_x, %workgroup_size_y] : memref<64x64xf32>
+ return
+ }
+ }
+ }
+}
+// CHECK-LABEL: llvm.func @interface_wg_size
+// CHECK: %[[WGDIMX:.+]] = nvvm.read.ptx.sreg.ntid.x
+// CHECK: %[[WGDIMY:.+]] = nvvm.read.ptx.sreg.ntid.y
diff --git a/compiler/src/iree/compiler/Codegen/LLVMGPU/test/convert_to_rocdl.mlir b/compiler/src/iree/compiler/Codegen/LLVMGPU/test/convert_to_rocdl.mlir
index b4dd4ca..d5c9931 100644
--- a/compiler/src/iree/compiler/Codegen/LLVMGPU/test/convert_to_rocdl.mlir
+++ b/compiler/src/iree/compiler/Codegen/LLVMGPU/test/convert_to_rocdl.mlir
@@ -129,3 +129,33 @@
// CHECK: %[[MASK_BIT:.+]] = llvm.icmp "sgt" {{.*}} : vector<1xi64>
// CHECK: llvm.intr.masked.load %{{.*}}, %[[MASK_BIT]]
// CHECK: llvm.intr.masked.store %{{.*}}, %[[MASK_BIT]]
+
+// -----
+// Test that hal.interface.workgroup.size[0]/[1] lower to rocdl.workgroup.dim.x/.y.
+#pipeline_layout = #hal.pipeline.layout<push_constants = 0, sets = [
+ #hal.descriptor_set.layout<0, bindings = [
+ #hal.descriptor_set.binding<0, storage_buffer>,
+ #hal.descriptor_set.binding<1, storage_buffer>,
+ #hal.descriptor_set.binding<2, storage_buffer>
+ ]>
+]>
+hal.executable private @interface_wg_size {
+ hal.executable.variant @rocm target(<"rocm", "rocm-hsaco-fb">) {
+ hal.executable.export @interface_wg_size layout(#pipeline_layout) attributes {
+ workgroup_size = [32: index, 1: index, 1: index]
+ }
+ builtin.module attributes {} {
+ func.func @interface_wg_size() {
+ %c0 = arith.constant 0.0 : f32
+ %workgroup_size_x = hal.interface.workgroup.size[0] : index
+ %workgroup_size_y = hal.interface.workgroup.size[1] : index
+ %subspan = hal.interface.binding.subspan layout(#pipeline_layout) set(0) binding(0) : memref<64x64xf32>
+ memref.store %c0, %subspan[%workgroup_size_x, %workgroup_size_y] : memref<64x64xf32> // both sizes are used as store indices so neither query is left unused
+ return
+ }
+ }
+ }
+}
+// CHECK-LABEL: llvm.func @interface_wg_size
+// CHECK: %[[WGDIMX:.+]] = rocdl.workgroup.dim.x
+// CHECK: %[[WGDIMY:.+]] = rocdl.workgroup.dim.y
diff --git a/compiler/src/iree/compiler/Codegen/SPIRV/ConvertToSPIRVPass.cpp b/compiler/src/iree/compiler/Codegen/SPIRV/ConvertToSPIRVPass.cpp
index 442d530..0aad84d 100644
--- a/compiler/src/iree/compiler/Codegen/SPIRV/ConvertToSPIRVPass.cpp
+++ b/compiler/src/iree/compiler/Codegen/SPIRV/ConvertToSPIRVPass.cpp
@@ -302,10 +302,10 @@
}
};
-/// A pattern to convert hal.interface.workgroup.id/count into corresponding
-/// SPIR-V Builtin ops.
+/// A pattern to convert hal.interface.workgroup.id/count/size into
+/// corresponding SPIR-V Builtin ops.
template <typename InterfaceOpTy, spirv::BuiltIn builtin>
-struct HALInterfaceWorkgroupIdAndCountConverter final
+struct HALInterfaceWorkgroupOpsConverter final
: OpConversionPattern<InterfaceOpTy> {
using OpConversionPattern<InterfaceOpTy>::OpConversionPattern;
@@ -656,10 +656,12 @@
// Add IREE HAL interface op conversions.
patterns.add<
HALInterfaceLoadConstantConverter,
- HALInterfaceWorkgroupIdAndCountConverter<
- IREE::HAL::InterfaceWorkgroupIDOp, spirv::BuiltIn::WorkgroupId>,
- HALInterfaceWorkgroupIdAndCountConverter<
- IREE::HAL::InterfaceWorkgroupCountOp, spirv::BuiltIn::NumWorkgroups>>(
+ HALInterfaceWorkgroupOpsConverter<IREE::HAL::InterfaceWorkgroupIDOp,
+ spirv::BuiltIn::WorkgroupId>,
+ HALInterfaceWorkgroupOpsConverter<IREE::HAL::InterfaceWorkgroupSizeOp,
+ spirv::BuiltIn::WorkgroupSize>,
+ HALInterfaceWorkgroupOpsConverter<IREE::HAL::InterfaceWorkgroupCountOp,
+ spirv::BuiltIn::NumWorkgroups>>(
typeConverter, context);
// Performs a prelimiary step to analyze all hal.interface.binding.subspan ops
diff --git a/compiler/src/iree/compiler/Codegen/SPIRV/test/convert_to_spirv.mlir b/compiler/src/iree/compiler/Codegen/SPIRV/test/convert_to_spirv.mlir
index b518367..9fcdab1 100644
--- a/compiler/src/iree/compiler/Codegen/SPIRV/test/convert_to_spirv.mlir
+++ b/compiler/src/iree/compiler/Codegen/SPIRV/test/convert_to_spirv.mlir
@@ -247,6 +247,44 @@
#hal.descriptor_set.binding<2, storage_buffer>
]>
]>
+hal.executable private @interface_wg_size {
+ hal.executable.variant @vulkan target(<"vulkan-spirv", "vulkan-spirv-fb">) {
+ hal.executable.export @interface_wg_size layout(#pipeline_layout) attributes {
+ workgroup_size = [32: index, 1: index, 1: index]
+ }
+ builtin.module attributes {spirv.target_env = #spirv.target_env<#spirv.vce<v1.3, [Int64, Shader], []>, #spirv.resource_limits<>>} {
+ func.func @interface_wg_size() {
+ %c0 = arith.constant 0.0 : f32
+ %workgroup_size_x = hal.interface.workgroup.size[0] : index
+ %workgroup_size_y = hal.interface.workgroup.size[1] : index
+ %subspan = hal.interface.binding.subspan layout(#pipeline_layout) set(0) binding(0) : memref<64x64xf32, #spirv.storage_class<StorageBuffer>>
+ memref.store %c0, %subspan[%workgroup_size_x, %workgroup_size_y] : memref<64x64xf32, #spirv.storage_class<StorageBuffer>>
+ return
+ }
+ }
+ }
+}
+
+// Each size query is expected to take the address of the WorkgroupSize builtin
+// variable, load from that address, and extract one component. The load
+// operand reuses the captured address so the chain is actually verified.
+// CHECK-LABEL: spirv.module
+// CHECK-DAG: spirv.GlobalVariable @[[WGSIZE:.+]] built_in("WorkgroupSize")
+// CHECK-DAG: spirv.GlobalVariable @[[BIND:.+]] bind(0, 0)
+// CHECK: %[[CST0:.+]] = spirv.Constant 0.000000e+00 : f32
+// CHECK: %[[ADDR1:.+]] = spirv.mlir.addressof @[[WGSIZE]]
+// CHECK: %[[VAL1:.+]] = spirv.Load "Input" %[[ADDR1]]
+// CHECK: %[[WGSIZEX:.+]] = spirv.CompositeExtract %[[VAL1]][0 : i32]
+// CHECK: %[[ADDR2:.+]] = spirv.mlir.addressof @[[WGSIZE]]
+// CHECK: %[[VAL2:.+]] = spirv.Load "Input" %[[ADDR2]]
+// CHECK: %[[WGSIZEY:.+]] = spirv.CompositeExtract %[[VAL2]][1 : i32]
+
+// -----
+
+#pipeline_layout = #hal.pipeline.layout<push_constants = 0, sets = [
+ #hal.descriptor_set.layout<0, bindings = [
+ #hal.descriptor_set.binding<0, storage_buffer>,
+ #hal.descriptor_set.binding<1, storage_buffer>,
+ #hal.descriptor_set.binding<2, storage_buffer>
+ ]>
+]>
hal.executable private @interface_wg_count {
hal.executable.variant @vulkan target(<"vulkan-spirv", "vulkan-spirv-fb">) {
hal.executable.export @interface_wg_count layout(#pipeline_layout) attributes {