[GPU] Add workgroupMemoryBankCount parameter to TargetWgpAttr (#23273)

This is part of a series of PR's implementing support for XOR swizzles
in IREE. We require the LDS bank count to figure out XOR swizzle
parameters.

See PR: https://github.com/iree-org/iree/pull/23175

---------

Signed-off-by: Muzammiluddin Syed <muzasyed@amd.com>
diff --git a/compiler/plugins/target/ROCM/test/target_device_features.mlir b/compiler/plugins/target/ROCM/test/target_device_features.mlir
index 052eaf1..8274c85 100644
--- a/compiler/plugins/target/ROCM/test/target_device_features.mlir
+++ b/compiler/plugins/target/ROCM/test/target_device_features.mlir
@@ -72,6 +72,7 @@
 // GFX942-SAME:         subgroup_size_choices = [64], max_workgroup_sizes = [1024, 1024, 1024],
 // GFX942-SAME:         max_thread_count_per_workgroup = 1024, max_workgroup_memory_bytes = 65536,
 // GFX942-SAME:         max_workgroup_counts = [2147483647, 2147483647, 2147483647],
+// GFX942:              workgroup_memory_bank_count = 32
 // MI300X: chip = <wgp_count = 304, sku = "mi300x", memory_bandwidth_tbps = 5.300000e+00 : f32, perf_tflops = {fp16 = 1.307400e+03 : f32, fp32 = 1.634000e+02 : f32, fp8 = 2.614900e+03 : f32, int8 = 2.614900e+03 : f32}>>
 // MI300A: chip = <wgp_count = 228, sku = "mi300a", memory_bandwidth_tbps = 5.300000e+00 : f32, perf_tflops = {fp16 = 980.599975 : f32, fp32 = 1.226000e+02 : f32, fp8 = 1.961200e+03 : f32, int8 = 1.961200e+03 : f32}>>
 // MI308X: chip = <wgp_count = 80, sku = "mi308x", memory_bandwidth_tbps = 5.300000e+00 : f32, perf_tflops = {fp16 = 1.884000e+02 : f32, fp32 = 2.900000e+01 : f32, fp8 = 1.768000e+02 : f32, int8 = 1.768000e+02 : f32}>>
@@ -82,38 +83,46 @@
 // GFX950-SAME:         scaled_mma = [<intrinsic = MFMA_SCALE_F32_16x16x128_B32, lhs_elem_type = f8E8M0FNU, rhs_elem_type = f8E8M0FNU, acc_elem_type = f32>, <intrinsic = MFMA_SCALE_F32_16x16x128_B32, lhs_elem_type = f8E5M2, rhs_elem_type = f8E5M2, acc_elem_type = f32>, <intrinsic = MFMA_SCALE_F32_16x16x128_B32, lhs_elem_type = f8E5M2FNUZ, rhs_elem_type = f8E5M2FNUZ, acc_elem_type = f32>, <intrinsic = MFMA_SCALE_F32_16x16x128_B32, lhs_elem_type = f8E4M3FN, rhs_elem_type = f8E4M3FN, acc_elem_type = f32>, <intrinsic = MFMA_SCALE_F32_16x16x128_B32, lhs_elem_type = f8E4M3FNUZ, rhs_elem_type = f8E4M3FNUZ, acc_elem_type = f32>, <intrinsic = MFMA_SCALE_F32_16x16x128_B32, lhs_elem_type = f4E2M1FN, rhs_elem_type = f4E2M1FN, acc_elem_type = f32>, <intrinsic = MFMA_SCALE_F32_32x32x64_B32, lhs_elem_type = f8E8M0FNU, rhs_elem_type = f8E8M0FNU, acc_elem_type = f32>, <intrinsic = MFMA_SCALE_F32_32x32x64_B32, lhs_elem_type = f8E5M2, rhs_elem_type = f8E5M2, acc_elem_type = f32>, <intrinsic = MFMA_SCALE_F32_32x32x64_B32, lhs_elem_type = f8E5M2FNUZ, rhs_elem_type = f8E5M2FNUZ, acc_elem_type = f32>, <intrinsic = MFMA_SCALE_F32_32x32x64_B32, lhs_elem_type = f8E4M3FN, rhs_elem_type = f8E4M3FN, acc_elem_type = f32>, <intrinsic = MFMA_SCALE_F32_32x32x64_B32, lhs_elem_type = f8E4M3FNUZ, rhs_elem_type = f8E4M3FNUZ, acc_elem_type = f32>, <intrinsic = MFMA_SCALE_F32_32x32x64_B32, lhs_elem_type = f4E2M1FN, rhs_elem_type = f4E2M1FN, acc_elem_type = f32>],
 // GFX950-SAME:         subgroup_size_choices = [64],
 // GFX950-SAME:         max_workgroup_memory_bytes = 163840,
+// GFX950:              workgroup_memory_bank_count = 64
 // MI350X: chip = <wgp_count = 256, sku = "mi350x", memory_bandwidth_tbps = 8.000000e+00 : f32, perf_tflops = {fp16 = 2.300000e+03 : f32, fp32 = 1.442000e+02 : f32, fp4 = 9.200000e+03 : f32, fp6 = 9.200000e+03 : f32, fp8 = 4.600000e+03 : f32, int8 = 4.600000e+03 : f32}>>
 // MI355X: chip = <wgp_count = 256, sku = "mi355x", memory_bandwidth_tbps = 8.000000e+00 : f32, perf_tflops = {fp16 = 2.500000e+03 : f32, fp32 = 1.573000e+02 : f32, fp4 = 1.000000e+04 : f32, fp6 = 1.000000e+04 : f32, fp8 = 5.000000e+03 : f32, int8 = 5.000000e+03 : f32}>>
 
 // GFX1100: target_info = #iree_gpu.target<arch = "gfx1100",
 // GFX1100-SAME:        mma = [<WMMAR3_F32_16x16x16_F16>, <WMMAR3_F16_16x16x16_F16>, <WMMAR3_F32_16x16x16_BF16>, <WMMAR3_BF16_16x16x16_BF16>, <WMMAR3_I32_16x16x16_I8>]
 // GFX1100-SAME:        subgroup_size_choices = [32, 64]
+// GFX1100:             workgroup_memory_bank_count = 64
 
 // GFX1101: target_info = #iree_gpu.target<arch = "gfx1101",
 // GFX1101-SAME:        mma = [<WMMAR3_F32_16x16x16_F16>, <WMMAR3_F16_16x16x16_F16>, <WMMAR3_F32_16x16x16_BF16>, <WMMAR3_BF16_16x16x16_BF16>, <WMMAR3_I32_16x16x16_I8>]
 // GFX1101-SAME:        subgroup_size_choices = [32, 64]
+// GFX1101:             workgroup_memory_bank_count = 64
 
 // GFX1103: target_info = #iree_gpu.target<arch = "gfx1103",
 // GFX1103-SAME:        mma = [<WMMAR3_F32_16x16x16_F16>, <WMMAR3_F16_16x16x16_F16>, <WMMAR3_F32_16x16x16_BF16>, <WMMAR3_BF16_16x16x16_BF16>, <WMMAR3_I32_16x16x16_I8>]
 // GFX1103-SAME:        subgroup_size_choices = [32, 64]
+// GFX1103:             workgroup_memory_bank_count = 64
 
 // GFX1150: target_info = #iree_gpu.target<arch = "gfx1150",
 // GFX1150-SAME:        mma = [<WMMAR3_F32_16x16x16_F16>, <WMMAR3_F16_16x16x16_F16>, <WMMAR3_F32_16x16x16_BF16>, <WMMAR3_BF16_16x16x16_BF16>, <WMMAR3_I32_16x16x16_I8>]
 // GFX1150-SAME:        subgroup_size_choices = [32, 64]
+// GFX1150:             workgroup_memory_bank_count = 64
 
 // GFX1151: target_info = #iree_gpu.target<arch = "gfx1151",
 // GFX1151-SAME:        mma = [<WMMAR3_F32_16x16x16_F16>, <WMMAR3_F16_16x16x16_F16>, <WMMAR3_F32_16x16x16_BF16>, <WMMAR3_BF16_16x16x16_BF16>, <WMMAR3_I32_16x16x16_I8>]
 // GFX1151-SAME:        subgroup_size_choices = [32, 64]
+// GFX1151:             workgroup_memory_bank_count = 64
 
 // GFX1200: target_info = #iree_gpu.target<arch = "gfx1200",
 // GFX1200-SAME:        mma = [<WMMAR4_F32_16x16x16_F16>, <WMMAR4_F16_16x16x16_F16>, <WMMAR4_F32_16x16x16_BF16>, <WMMAR4_BF16_16x16x16_BF16>, <WMMAR4_F32_16x16x16_F8E5M2>, <WMMAR4_F32_16x16x16_F8E5M2_F8E4M3FN>, <WMMAR4_F32_16x16x16_F8E4M3FN>, <WMMAR4_F32_16x16x16_F8E4M3FN_F8E5M2>,  <WMMAR4_I32_16x16x16_I8>]
 // GFX1200-SAME:        subgroup_size_choices = [32, 64]
+// GFX1200:             workgroup_memory_bank_count = 64
 //
 // RX9060XT: chip = <wgp_count = 16, sku = "rx9060xt", memory_bandwidth_tbps = 3.200000e-01 : f32, perf_tflops = {fp16 = 1.030000e+02 : f32, fp32 = 2.560000e+01 : f32, fp8 = 2.050000e+02 : f32, int8 = 2.050000e+02 : f32}>>
 
 // GFX1201: target_info = #iree_gpu.target<arch = "gfx1201",
 // GFX1201-SAME:        mma = [<WMMAR4_F32_16x16x16_F16>, <WMMAR4_F16_16x16x16_F16>, <WMMAR4_F32_16x16x16_BF16>, <WMMAR4_BF16_16x16x16_BF16>, <WMMAR4_F32_16x16x16_F8E5M2>, <WMMAR4_F32_16x16x16_F8E5M2_F8E4M3FN>, <WMMAR4_F32_16x16x16_F8E4M3FN>, <WMMAR4_F32_16x16x16_F8E4M3FN_F8E5M2>,  <WMMAR4_I32_16x16x16_I8>]
 // GFX1201-SAME:        subgroup_size_choices = [32, 64]
+// GFX1201:             workgroup_memory_bank_count = 64
 //
 // RX9070XT: chip = <wgp_count = 32, sku = "rx9070xt", memory_bandwidth_tbps = 6.400000e-01 : f32, perf_tflops = {fp16 = 1.950000e+02 : f32, fp32 = 4.870000e+01 : f32, fp8 = 3.890000e+02 : f32, int8 = 3.890000e+02 : f32}>>
 // RX9070:   chip = <wgp_count = 28, sku = "rx9070", memory_bandwidth_tbps = 6.400000e-01 : f32, perf_tflops = {fp16 = 1.450000e+02 : f32, fp32 = 3.610000e+01 : f32, fp8 = 2.890000e+02 : f32, int8 = 2.890000e+02 : f32}>>
@@ -130,6 +139,7 @@
 // GFX1250-SAME:               <WMMA_F16_16x16x128_F8E5M2>, <WMMA_F16_16x16x128_F8E5M2_F8E4M3FN>, <WMMA_F16_16x16x128_F8E4M3FN>, <WMMA_F16_16x16x128_F8E4M3FN_F8E5M2>]
 // GFX1250-SAME:        subgroup_size_choices = [32]
 // GFX1250-SAME:        max_load_instruction_bits = 128, simds_per_wgp = 4
+// GFX1250:             workgroup_memory_bank_count = 64
 
 stream.executable public @reduce_dispatch {
   stream.executable.export @reduce_dispatch workgroups() -> (index, index, index) {
diff --git a/compiler/src/iree/compiler/Codegen/Dialect/GPU/IR/IREEGPUAttrs.td b/compiler/src/iree/compiler/Codegen/Dialect/GPU/IR/IREEGPUAttrs.td
index c1379d8..68130ff 100644
--- a/compiler/src/iree/compiler/Codegen/Dialect/GPU/IR/IREEGPUAttrs.td
+++ b/compiler/src/iree/compiler/Codegen/Dialect/GPU/IR/IREEGPUAttrs.td
@@ -754,6 +754,8 @@
     OptionalParameter<"std::optional<int32_t>">:$vgpr_space_bits,
     // Available bit-widths for direct load from global to LDS memory.
     OptionalParameter<"DenseI64ArrayAttr">:$dma_sizes,
+    // Number of banks in LDS.
+    OptionalParameter<"std::optional<int32_t>">:$workgroup_memory_bank_count,
 
     // An optional extra dict
     // This field allows to inject more features/limits not supported in the
diff --git a/compiler/src/iree/compiler/Codegen/Dialect/GPU/IR/test/target_attrs.mlir b/compiler/src/iree/compiler/Codegen/Dialect/GPU/IR/test/target_attrs.mlir
index b2204f5..cde8c25 100644
--- a/compiler/src/iree/compiler/Codegen/Dialect/GPU/IR/test/target_attrs.mlir
+++ b/compiler/src/iree/compiler/Codegen/Dialect/GPU/IR/test/target_attrs.mlir
@@ -13,7 +13,8 @@
   // CHECK-SAME: max_thread_count_per_workgroup = 1024,
   // CHECK-SAME: max_workgroup_memory_bytes = 65536,
   // CHECK-SAME: max_workgroup_counts = [2147483647, 2147483647, 2147483647],
-  // CHECK-SAME: dma_sizes = [32, 128]>
+  // CHECK-SAME: dma_sizes = [32, 128],
+  // CHECK-SAME: workgroup_memory_bank_count = 32>
   wgp = #iree_gpu.target_wgp<
     compute = fp16|fp32|int8, storage = b16|b32,
     subgroup = shuffle|arithmetic, dot = dp4xi8toi32,
@@ -23,7 +24,8 @@
     max_thread_count_per_workgroup = 1024,
     max_workgroup_memory_bytes = 65536,
     max_workgroup_counts = [2147483647, 2147483647, 2147483647],
-    dma_sizes = [32, 128]
+    dma_sizes = [32, 128],
+    workgroup_memory_bank_count = 32
   >
 } { return }
 
diff --git a/compiler/src/iree/compiler/Codegen/Dialect/GPU/TargetUtils/KnownTargets.cpp b/compiler/src/iree/compiler/Codegen/Dialect/GPU/TargetUtils/KnownTargets.cpp
index 4093391..6746b03 100644
--- a/compiler/src/iree/compiler/Codegen/Dialect/GPU/TargetUtils/KnownTargets.cpp
+++ b/compiler/src/iree/compiler/Codegen/Dialect/GPU/TargetUtils/KnownTargets.cpp
@@ -57,6 +57,7 @@
   std::optional<int32_t> simdsPerWgp;
   std::optional<int32_t> vgprSpaceBits;
   std::optional<ArrayRef<int64_t>> dmaSizes;
+  std::optional<int32_t> workgroupMemoryBankCount;
 };
 
 // Chip level feature/limit details
@@ -150,7 +151,7 @@
       wgp->maxThreadSize, wgp->maxWorkgroupMemoryBytes,
       DenseI32ArrayAttr::get(context, wgp->maxWorkgroupCounts),
       wgp->maxLoadInstructionBits, wgp->simdsPerWgp, wgp->vgprSpaceBits,
-      dmaSizesAttr, DictionaryAttr{});
+      dmaSizesAttr, wgp->workgroupMemoryBankCount, DictionaryAttr{});
 
   TargetChipAttr targetChip;
   if (details.chip) {
@@ -252,7 +253,8 @@
       /*maxLoadInstructionBits=*/128,
       /*simdsPerWgp=*/4,
       /*vgprSpaceBits=*/512 * 32,
-      /*dmaSizes=*/ArrayRef<int64_t>(cdna4DMASizes)};
+      /*dmaSizes=*/ArrayRef<int64_t>(cdna4DMASizes),
+      /*workgroupMemoryBankCount=*/64};
   return &cdna4Wgp;
 }
 
@@ -297,7 +299,8 @@
       /*maxLoadInstructionBits=*/128,
       /*simdsPerWgp=*/4,
       /*vgprSpaceBits=*/512 * 32,
-      /*dmaSizes=*/ArrayRef<int64_t>(cdna3DMASizes)};
+      /*dmaSizes=*/ArrayRef<int64_t>(cdna3DMASizes),
+      /*workgroupMemoryBankCount=*/32};
   return &cdna3Wgp;
 }
 
@@ -329,7 +332,9 @@
                                       {0x7fffffff, 0x7fffffff, 0x7fffffff},
                                       /*maxLoadInstructionBits=*/128,
                                       /*simdsPerWgp=*/4,
-                                      /*vgprSpaceBits=*/256 * 32};
+                                      /*vgprSpaceBits=*/256 * 32,
+                                      /*dmaSizes=*/std::nullopt,
+                                      /*workgroupMemoryBankCount=*/32};
   return &cdna2Wgp;
 }
 
@@ -354,7 +359,9 @@
                                       {0x7fffffff, 0x7fffffff, 0x7fffffff},
                                       /*maxLoadInstructionBits=*/128,
                                       /*simdsPerWgp=*/4,
-                                      /*vgprSpaceBits=*/256 * 32};
+                                      /*vgprSpaceBits=*/256 * 32,
+                                      /*dmaSizes=*/std::nullopt,
+                                      /*workgroupMemoryBankCount=*/32};
   return &cdna1Wgp;
 }
 
@@ -385,7 +392,9 @@
                                       {0x7fffffff, 0x7fffffff, 0x7fffffff},
                                       /*maxLoadInstructionBits=*/128,
                                       /*simdsPerWgp=*/4,
-                                      /*vgprSpaceBits=*/256 * 32};
+                                      /*vgprSpaceBits=*/256 * 32,
+                                      /*dmaSizes=*/std::nullopt,
+                                      /*workgroupMemoryBankCount=*/64};
   return &rdna4Wgp;
 }
 
@@ -413,7 +422,9 @@
                                       {0x7fffffff, 0x7fffffff, 0x7fffffff},
                                       /*maxLoadInstructionBits=*/128,
                                       /*simdsPerWgp=*/4,
-                                      /*vgprSpaceBits=*/256 * 32};
+                                      /*vgprSpaceBits=*/256 * 32,
+                                      /*dmaSizes=*/std::nullopt,
+                                      /*workgroupMemoryBankCount=*/64};
   return &rdna3Wgp;
 }
 
@@ -497,7 +508,9 @@
                                         {0x7fffffff, 0x7fffffff, 0x7fffffff},
                                         /*maxLoadInstructionBits=*/128,
                                         /*simdsPerWgp=*/4,
-                                        /*vgprSpaceBits=*/256 * 32};
+                                        /*vgprSpaceBits=*/256 * 32,
+                                        /*dmaSizes=*/std::nullopt,
+                                        /*workgroupMemoryBankCount=*/64};
   return &gfx1250Wgp;
 }