[GPU] Add workgroupMemoryBankCount parameter to TargetWgpAttr (#23273)
This is part of a series of PR's implementing support for XOR swizzles
in IREE. We require the LDS bank count to figure out XOR swizzle
parameters.
See PR: https://github.com/iree-org/iree/pull/23175
---------
Signed-off-by: Muzammiluddin Syed <muzasyed@amd.com>
diff --git a/compiler/plugins/target/ROCM/test/target_device_features.mlir b/compiler/plugins/target/ROCM/test/target_device_features.mlir
index 052eaf1..8274c85 100644
--- a/compiler/plugins/target/ROCM/test/target_device_features.mlir
+++ b/compiler/plugins/target/ROCM/test/target_device_features.mlir
@@ -72,6 +72,7 @@
// GFX942-SAME: subgroup_size_choices = [64], max_workgroup_sizes = [1024, 1024, 1024],
// GFX942-SAME: max_thread_count_per_workgroup = 1024, max_workgroup_memory_bytes = 65536,
// GFX942-SAME: max_workgroup_counts = [2147483647, 2147483647, 2147483647],
+// GFX942: workgroup_memory_bank_count = 32
// MI300X: chip = <wgp_count = 304, sku = "mi300x", memory_bandwidth_tbps = 5.300000e+00 : f32, perf_tflops = {fp16 = 1.307400e+03 : f32, fp32 = 1.634000e+02 : f32, fp8 = 2.614900e+03 : f32, int8 = 2.614900e+03 : f32}>>
// MI300A: chip = <wgp_count = 228, sku = "mi300a", memory_bandwidth_tbps = 5.300000e+00 : f32, perf_tflops = {fp16 = 980.599975 : f32, fp32 = 1.226000e+02 : f32, fp8 = 1.961200e+03 : f32, int8 = 1.961200e+03 : f32}>>
// MI308X: chip = <wgp_count = 80, sku = "mi308x", memory_bandwidth_tbps = 5.300000e+00 : f32, perf_tflops = {fp16 = 1.884000e+02 : f32, fp32 = 2.900000e+01 : f32, fp8 = 1.768000e+02 : f32, int8 = 1.768000e+02 : f32}>>
@@ -82,38 +83,46 @@
// GFX950-SAME: scaled_mma = [<intrinsic = MFMA_SCALE_F32_16x16x128_B32, lhs_elem_type = f8E8M0FNU, rhs_elem_type = f8E8M0FNU, acc_elem_type = f32>, <intrinsic = MFMA_SCALE_F32_16x16x128_B32, lhs_elem_type = f8E5M2, rhs_elem_type = f8E5M2, acc_elem_type = f32>, <intrinsic = MFMA_SCALE_F32_16x16x128_B32, lhs_elem_type = f8E5M2FNUZ, rhs_elem_type = f8E5M2FNUZ, acc_elem_type = f32>, <intrinsic = MFMA_SCALE_F32_16x16x128_B32, lhs_elem_type = f8E4M3FN, rhs_elem_type = f8E4M3FN, acc_elem_type = f32>, <intrinsic = MFMA_SCALE_F32_16x16x128_B32, lhs_elem_type = f8E4M3FNUZ, rhs_elem_type = f8E4M3FNUZ, acc_elem_type = f32>, <intrinsic = MFMA_SCALE_F32_16x16x128_B32, lhs_elem_type = f4E2M1FN, rhs_elem_type = f4E2M1FN, acc_elem_type = f32>, <intrinsic = MFMA_SCALE_F32_32x32x64_B32, lhs_elem_type = f8E8M0FNU, rhs_elem_type = f8E8M0FNU, acc_elem_type = f32>, <intrinsic = MFMA_SCALE_F32_32x32x64_B32, lhs_elem_type = f8E5M2, rhs_elem_type = f8E5M2, acc_elem_type = f32>, <intrinsic = MFMA_SCALE_F32_32x32x64_B32, lhs_elem_type = f8E5M2FNUZ, rhs_elem_type = f8E5M2FNUZ, acc_elem_type = f32>, <intrinsic = MFMA_SCALE_F32_32x32x64_B32, lhs_elem_type = f8E4M3FN, rhs_elem_type = f8E4M3FN, acc_elem_type = f32>, <intrinsic = MFMA_SCALE_F32_32x32x64_B32, lhs_elem_type = f8E4M3FNUZ, rhs_elem_type = f8E4M3FNUZ, acc_elem_type = f32>, <intrinsic = MFMA_SCALE_F32_32x32x64_B32, lhs_elem_type = f4E2M1FN, rhs_elem_type = f4E2M1FN, acc_elem_type = f32>],
// GFX950-SAME: subgroup_size_choices = [64],
// GFX950-SAME: max_workgroup_memory_bytes = 163840,
+// GFX950: workgroup_memory_bank_count = 64
// MI350X: chip = <wgp_count = 256, sku = "mi350x", memory_bandwidth_tbps = 8.000000e+00 : f32, perf_tflops = {fp16 = 2.300000e+03 : f32, fp32 = 1.442000e+02 : f32, fp4 = 9.200000e+03 : f32, fp6 = 9.200000e+03 : f32, fp8 = 4.600000e+03 : f32, int8 = 4.600000e+03 : f32}>>
// MI355X: chip = <wgp_count = 256, sku = "mi355x", memory_bandwidth_tbps = 8.000000e+00 : f32, perf_tflops = {fp16 = 2.500000e+03 : f32, fp32 = 1.573000e+02 : f32, fp4 = 1.000000e+04 : f32, fp6 = 1.000000e+04 : f32, fp8 = 5.000000e+03 : f32, int8 = 5.000000e+03 : f32}>>
// GFX1100: target_info = #iree_gpu.target<arch = "gfx1100",
// GFX1100-SAME: mma = [<WMMAR3_F32_16x16x16_F16>, <WMMAR3_F16_16x16x16_F16>, <WMMAR3_F32_16x16x16_BF16>, <WMMAR3_BF16_16x16x16_BF16>, <WMMAR3_I32_16x16x16_I8>]
// GFX1100-SAME: subgroup_size_choices = [32, 64]
+// GFX1100: workgroup_memory_bank_count = 64
// GFX1101: target_info = #iree_gpu.target<arch = "gfx1101",
// GFX1101-SAME: mma = [<WMMAR3_F32_16x16x16_F16>, <WMMAR3_F16_16x16x16_F16>, <WMMAR3_F32_16x16x16_BF16>, <WMMAR3_BF16_16x16x16_BF16>, <WMMAR3_I32_16x16x16_I8>]
// GFX1101-SAME: subgroup_size_choices = [32, 64]
+// GFX1101: workgroup_memory_bank_count = 64
// GFX1103: target_info = #iree_gpu.target<arch = "gfx1103",
// GFX1103-SAME: mma = [<WMMAR3_F32_16x16x16_F16>, <WMMAR3_F16_16x16x16_F16>, <WMMAR3_F32_16x16x16_BF16>, <WMMAR3_BF16_16x16x16_BF16>, <WMMAR3_I32_16x16x16_I8>]
// GFX1103-SAME: subgroup_size_choices = [32, 64]
+// GFX1103: workgroup_memory_bank_count = 64
// GFX1150: target_info = #iree_gpu.target<arch = "gfx1150",
// GFX1150-SAME: mma = [<WMMAR3_F32_16x16x16_F16>, <WMMAR3_F16_16x16x16_F16>, <WMMAR3_F32_16x16x16_BF16>, <WMMAR3_BF16_16x16x16_BF16>, <WMMAR3_I32_16x16x16_I8>]
// GFX1150-SAME: subgroup_size_choices = [32, 64]
+// GFX1150: workgroup_memory_bank_count = 64
// GFX1151: target_info = #iree_gpu.target<arch = "gfx1151",
// GFX1151-SAME: mma = [<WMMAR3_F32_16x16x16_F16>, <WMMAR3_F16_16x16x16_F16>, <WMMAR3_F32_16x16x16_BF16>, <WMMAR3_BF16_16x16x16_BF16>, <WMMAR3_I32_16x16x16_I8>]
// GFX1151-SAME: subgroup_size_choices = [32, 64]
+// GFX1151: workgroup_memory_bank_count = 64
// GFX1200: target_info = #iree_gpu.target<arch = "gfx1200",
// GFX1200-SAME: mma = [<WMMAR4_F32_16x16x16_F16>, <WMMAR4_F16_16x16x16_F16>, <WMMAR4_F32_16x16x16_BF16>, <WMMAR4_BF16_16x16x16_BF16>, <WMMAR4_F32_16x16x16_F8E5M2>, <WMMAR4_F32_16x16x16_F8E5M2_F8E4M3FN>, <WMMAR4_F32_16x16x16_F8E4M3FN>, <WMMAR4_F32_16x16x16_F8E4M3FN_F8E5M2>, <WMMAR4_I32_16x16x16_I8>]
// GFX1200-SAME: subgroup_size_choices = [32, 64]
+// GFX1200: workgroup_memory_bank_count = 64
//
// RX9060XT: chip = <wgp_count = 16, sku = "rx9060xt", memory_bandwidth_tbps = 3.200000e-01 : f32, perf_tflops = {fp16 = 1.030000e+02 : f32, fp32 = 2.560000e+01 : f32, fp8 = 2.050000e+02 : f32, int8 = 2.050000e+02 : f32}>>
// GFX1201: target_info = #iree_gpu.target<arch = "gfx1201",
// GFX1201-SAME: mma = [<WMMAR4_F32_16x16x16_F16>, <WMMAR4_F16_16x16x16_F16>, <WMMAR4_F32_16x16x16_BF16>, <WMMAR4_BF16_16x16x16_BF16>, <WMMAR4_F32_16x16x16_F8E5M2>, <WMMAR4_F32_16x16x16_F8E5M2_F8E4M3FN>, <WMMAR4_F32_16x16x16_F8E4M3FN>, <WMMAR4_F32_16x16x16_F8E4M3FN_F8E5M2>, <WMMAR4_I32_16x16x16_I8>]
// GFX1201-SAME: subgroup_size_choices = [32, 64]
+// GFX1201: workgroup_memory_bank_count = 64
//
// RX9070XT: chip = <wgp_count = 32, sku = "rx9070xt", memory_bandwidth_tbps = 6.400000e-01 : f32, perf_tflops = {fp16 = 1.950000e+02 : f32, fp32 = 4.870000e+01 : f32, fp8 = 3.890000e+02 : f32, int8 = 3.890000e+02 : f32}>>
// RX9070: chip = <wgp_count = 28, sku = "rx9070", memory_bandwidth_tbps = 6.400000e-01 : f32, perf_tflops = {fp16 = 1.450000e+02 : f32, fp32 = 3.610000e+01 : f32, fp8 = 2.890000e+02 : f32, int8 = 2.890000e+02 : f32}>>
@@ -130,6 +139,7 @@
// GFX1250-SAME: <WMMA_F16_16x16x128_F8E5M2>, <WMMA_F16_16x16x128_F8E5M2_F8E4M3FN>, <WMMA_F16_16x16x128_F8E4M3FN>, <WMMA_F16_16x16x128_F8E4M3FN_F8E5M2>]
// GFX1250-SAME: subgroup_size_choices = [32]
// GFX1250-SAME: max_load_instruction_bits = 128, simds_per_wgp = 4
+// GFX1250: workgroup_memory_bank_count = 64
stream.executable public @reduce_dispatch {
stream.executable.export @reduce_dispatch workgroups() -> (index, index, index) {
diff --git a/compiler/src/iree/compiler/Codegen/Dialect/GPU/IR/IREEGPUAttrs.td b/compiler/src/iree/compiler/Codegen/Dialect/GPU/IR/IREEGPUAttrs.td
index c1379d8..68130ff 100644
--- a/compiler/src/iree/compiler/Codegen/Dialect/GPU/IR/IREEGPUAttrs.td
+++ b/compiler/src/iree/compiler/Codegen/Dialect/GPU/IR/IREEGPUAttrs.td
@@ -754,6 +754,8 @@
OptionalParameter<"std::optional<int32_t>">:$vgpr_space_bits,
// Available bit-widths for direct load from global to LDS memory.
OptionalParameter<"DenseI64ArrayAttr">:$dma_sizes,
+ // Number of banks in LDS.
+ OptionalParameter<"std::optional<int32_t>">:$workgroup_memory_bank_count,
// An optional extra dict
// This field allows to inject more features/limits not supported in the
diff --git a/compiler/src/iree/compiler/Codegen/Dialect/GPU/IR/test/target_attrs.mlir b/compiler/src/iree/compiler/Codegen/Dialect/GPU/IR/test/target_attrs.mlir
index b2204f5..cde8c25 100644
--- a/compiler/src/iree/compiler/Codegen/Dialect/GPU/IR/test/target_attrs.mlir
+++ b/compiler/src/iree/compiler/Codegen/Dialect/GPU/IR/test/target_attrs.mlir
@@ -13,7 +13,8 @@
// CHECK-SAME: max_thread_count_per_workgroup = 1024,
// CHECK-SAME: max_workgroup_memory_bytes = 65536,
// CHECK-SAME: max_workgroup_counts = [2147483647, 2147483647, 2147483647],
- // CHECK-SAME: dma_sizes = [32, 128]>
+ // CHECK-SAME: dma_sizes = [32, 128],
+ // CHECK-SAME: workgroup_memory_bank_count = 32>
wgp = #iree_gpu.target_wgp<
compute = fp16|fp32|int8, storage = b16|b32,
subgroup = shuffle|arithmetic, dot = dp4xi8toi32,
@@ -23,7 +24,8 @@
max_thread_count_per_workgroup = 1024,
max_workgroup_memory_bytes = 65536,
max_workgroup_counts = [2147483647, 2147483647, 2147483647],
- dma_sizes = [32, 128]
+ dma_sizes = [32, 128],
+ workgroup_memory_bank_count = 32
>
} { return }
diff --git a/compiler/src/iree/compiler/Codegen/Dialect/GPU/TargetUtils/KnownTargets.cpp b/compiler/src/iree/compiler/Codegen/Dialect/GPU/TargetUtils/KnownTargets.cpp
index 4093391..6746b03 100644
--- a/compiler/src/iree/compiler/Codegen/Dialect/GPU/TargetUtils/KnownTargets.cpp
+++ b/compiler/src/iree/compiler/Codegen/Dialect/GPU/TargetUtils/KnownTargets.cpp
@@ -57,6 +57,7 @@
std::optional<int32_t> simdsPerWgp;
std::optional<int32_t> vgprSpaceBits;
std::optional<ArrayRef<int64_t>> dmaSizes;
+ std::optional<int32_t> workgroupMemoryBankCount;
};
// Chip level feature/limit details
@@ -150,7 +151,7 @@
wgp->maxThreadSize, wgp->maxWorkgroupMemoryBytes,
DenseI32ArrayAttr::get(context, wgp->maxWorkgroupCounts),
wgp->maxLoadInstructionBits, wgp->simdsPerWgp, wgp->vgprSpaceBits,
- dmaSizesAttr, DictionaryAttr{});
+ dmaSizesAttr, wgp->workgroupMemoryBankCount, DictionaryAttr{});
TargetChipAttr targetChip;
if (details.chip) {
@@ -252,7 +253,8 @@
/*maxLoadInstructionBits=*/128,
/*simdsPerWgp=*/4,
/*vgprSpaceBits=*/512 * 32,
- /*dmaSizes=*/ArrayRef<int64_t>(cdna4DMASizes)};
+ /*dmaSizes=*/ArrayRef<int64_t>(cdna4DMASizes),
+ /*workgroupMemoryBankCount=*/64};
return &cdna4Wgp;
}
@@ -297,7 +299,8 @@
/*maxLoadInstructionBits=*/128,
/*simdsPerWgp=*/4,
/*vgprSpaceBits=*/512 * 32,
- /*dmaSizes=*/ArrayRef<int64_t>(cdna3DMASizes)};
+ /*dmaSizes=*/ArrayRef<int64_t>(cdna3DMASizes),
+ /*workgroupMemoryBankCount=*/32};
return &cdna3Wgp;
}
@@ -329,7 +332,9 @@
{0x7fffffff, 0x7fffffff, 0x7fffffff},
/*maxLoadInstructionBits=*/128,
/*simdsPerWgp=*/4,
- /*vgprSpaceBits=*/256 * 32};
+ /*vgprSpaceBits=*/256 * 32,
+ /*dmaSizes=*/std::nullopt,
+ /*workgroupMemoryBankCount=*/32};
return &cdna2Wgp;
}
@@ -354,7 +359,9 @@
{0x7fffffff, 0x7fffffff, 0x7fffffff},
/*maxLoadInstructionBits=*/128,
/*simdsPerWgp=*/4,
- /*vgprSpaceBits=*/256 * 32};
+ /*vgprSpaceBits=*/256 * 32,
+ /*dmaSizes=*/std::nullopt,
+ /*workgroupMemoryBankCount=*/32};
return &cdna1Wgp;
}
@@ -385,7 +392,9 @@
{0x7fffffff, 0x7fffffff, 0x7fffffff},
/*maxLoadInstructionBits=*/128,
/*simdsPerWgp=*/4,
- /*vgprSpaceBits=*/256 * 32};
+ /*vgprSpaceBits=*/256 * 32,
+ /*dmaSizes=*/std::nullopt,
+ /*workgroupMemoryBankCount=*/64};
return &rdna4Wgp;
}
@@ -413,7 +422,9 @@
{0x7fffffff, 0x7fffffff, 0x7fffffff},
/*maxLoadInstructionBits=*/128,
/*simdsPerWgp=*/4,
- /*vgprSpaceBits=*/256 * 32};
+ /*vgprSpaceBits=*/256 * 32,
+ /*dmaSizes=*/std::nullopt,
+ /*workgroupMemoryBankCount=*/64};
return &rdna3Wgp;
}
@@ -497,7 +508,9 @@
{0x7fffffff, 0x7fffffff, 0x7fffffff},
/*maxLoadInstructionBits=*/128,
/*simdsPerWgp=*/4,
- /*vgprSpaceBits=*/256 * 32};
+ /*vgprSpaceBits=*/256 * 32,
+ /*dmaSizes=*/std::nullopt,
+ /*workgroupMemoryBankCount=*/64};
return &gfx1250Wgp;
}