[Codegen][Tuner] expose python binding to query target info (#21782)

Motivated by issue #2048, this PR exposes Python bindings for querying GPU
target info (arch, subgroup size choices, and the workgroup size,
thread-count, and memory limits), which the tuner will use for constraint
generation.

---------

Signed-off-by: Bangtian Liu <liubangtian@gmail.com>
diff --git a/compiler/bindings/c/iree/compiler/dialects/iree_gpu.h b/compiler/bindings/c/iree/compiler/dialects/iree_gpu.h
index c6afe11..a8bc342 100644
--- a/compiler/bindings/c/iree/compiler/dialects/iree_gpu.h
+++ b/compiler/bindings/c/iree/compiler/dialects/iree_gpu.h
@@ -144,6 +144,18 @@
 MLIR_CAPI_EXPORTED ireeGPUMMASingleSubgroupLayout
 ireeGPUGetSingleSubgroupLayout(MlirAttribute attr, uint32_t fragment);
 
+struct ireeGPUTargetInfo {
+  MlirIdentifier arch;                // Target arch name, e.g., "gfx942".
+  MlirAttribute subgroupSizeChoices;  // I32 array of subgroup size choices.
+  MlirAttribute maxWorkgroupSizes;    // I32 array: max threads per X/Y/Z dim.
+  int64_t maxThreadCountPerWorkgroup; // Max threads per workgroup.
+  int64_t maxWorkgroupMemoryBytes;    // Max workgroup memory, in bytes.
+};
+
+// Returns GPU target info for `attr`, which must be an `ExecutableTargetAttr`.
+MLIR_CAPI_EXPORTED ireeGPUTargetInfo
+ireeHALExecutableTargetAttrGetGPUTargetInfo(MlirAttribute attr);
+
 #ifdef __cplusplus
 }
 #endif
diff --git a/compiler/bindings/python/IREECompilerDialectsModule.cpp b/compiler/bindings/python/IREECompilerDialectsModule.cpp
index f7a4380..27bb8cc 100644
--- a/compiler/bindings/python/IREECompilerDialectsModule.cpp
+++ b/compiler/bindings/python/IREECompilerDialectsModule.cpp
@@ -505,6 +505,38 @@
           });
 
   //===-------------------------------------------------------------------===//
+  // Binding to query target info
+  //===-------------------------------------------------------------------===//
+
+  py::class_<ireeGPUTargetInfo>(iree_gpu_module, "TargetInfo")
+      .def_prop_ro("arch",
+                   [](const ireeGPUTargetInfo &self) -> std::string {
+                     MlirStringRef strRef = mlirIdentifierStr(self.arch);
+                     return std::string(strRef.data, strRef.length);
+                   })
+      .def_prop_ro("subgroup_size_choices",
+                   [](const ireeGPUTargetInfo &self) -> std::vector<int64_t> {
+                     return getIntArrayAttrValues(self.subgroupSizeChoices);
+                   })
+      .def_prop_ro("max_thread_count_per_workgroup",
+                   [](const ireeGPUTargetInfo &self) -> int64_t {
+                     return self.maxThreadCountPerWorkgroup;
+                   })
+      .def_prop_ro("max_workgroup_sizes",
+                   [](const ireeGPUTargetInfo &self) -> std::vector<int64_t> {
+                     return getIntArrayAttrValues(self.maxWorkgroupSizes);
+                   })
+      .def_prop_ro("max_workgroup_memory_bytes",
+                   [](const ireeGPUTargetInfo &self) -> int64_t {
+                     return self.maxWorkgroupMemoryBytes;
+                   });
+
+  iree_gpu_module.def(
+      "get_gpu_target_info", &ireeHALExecutableTargetAttrGetGPUTargetInfo,
+      "Extracts GPU target information from an executable target attribute.",
+      py::arg("executable_target_attr"));
+
+  //===-------------------------------------------------------------------===//
   // Binding to utility function getSingleSubgroupLayout
   //===-------------------------------------------------------------------===//
   py::class_<ireeGPUMMASingleSubgroupLayout>(iree_gpu_module,
@@ -592,12 +624,7 @@
       });
 
   iree_codegen_module.def(
-      "get_attention_op_detail",
-      [](MlirAffineMap q, MlirAffineMap k, MlirAffineMap v, MlirAffineMap o) {
-        ireeCodegenAttentionOpDetail result =
-            ireeCodegenGetAttentionOpDetail(q, k, v, o);
-        return result;
-      },
+      "get_attention_op_detail", &ireeCodegenGetAttentionOpDetail,
       "Infers the structure of an attention operation from affine indexing "
       "maps.",
       py::arg("q"), py::arg("k"), py::arg("v"), py::arg("o"));
diff --git a/compiler/bindings/python/test/ir/dialects_test.py b/compiler/bindings/python/test/ir/dialects_test.py
index 86ad071..4f5060e 100644
--- a/compiler/bindings/python/test/ir/dialects_test.py
+++ b/compiler/bindings/python/test/ir/dialects_test.py
@@ -391,3 +391,68 @@
     assert compilation_info is not None
     assert compilation_info.lowering_config == lowering_config
     assert compilation_info.translation_info == translation_info
+
+
+@run
+def gpu_target_info_attribute_parsing():
+    mlir_string = """
+    hal.executable private @main_dispatch_0 {
+        hal.executable.variant public @rocm_hsaco_fb
+            target(<"rocm", "rocm-hsaco-fb",
+                {
+                abi = "hip",
+                iree_codegen.target_info = #iree_gpu.target<
+                    arch = "gfx942",
+                    features = "",
+                    wgp = <
+                    compute = fp64,
+                    storage = b64,
+                    subgroup = none,
+                    dot = none,
+                    mma = [<MFMA_F32_16x16x4_F32>],
+                    subgroup_size_choices = [32, 64],
+                    max_workgroup_sizes = [256, 512, 1024],
+                    max_thread_count_per_workgroup = 1024,
+                    max_workgroup_memory_bytes = 65536,
+                    max_workgroup_counts = [256, 512, 1024]
+                    >
+                >
+                }>
+            ) {
+        }
+    }
+    """
+    # Parse the IR, then check each TargetInfo property against the attr above.
+    module = ir.Module.parse(mlir_string)
+    variant_op_list = iree_codegen.get_executable_variant_ops(module)
+    assert len(variant_op_list) == 1, "Expect one executable variant op"
+    variant_op = variant_op_list[0]
+    executable_variant_op = variant_op.opview
+    target = executable_variant_op.target
+    gpu_target_info = iree_gpu.get_gpu_target_info(target)
+
+    arch = gpu_target_info.arch
+    assert arch == "gfx942", f"Expected arch 'gfx942', got '{arch}'"
+
+    subgroup_size_choices = gpu_target_info.subgroup_size_choices
+    assert subgroup_size_choices == [
+        32,
+        64,
+    ], f"Expected subgroup_size_choices [32, 64], got {subgroup_size_choices}"
+
+    max_thread_count = gpu_target_info.max_thread_count_per_workgroup
+    assert (
+        max_thread_count == 1024
+    ), f"Expected max_thread_count_per_workgroup 1024, got {max_thread_count}"
+
+    max_memory_bytes = gpu_target_info.max_workgroup_memory_bytes
+    assert (
+        max_memory_bytes == 65536
+    ), f"Expected max_workgroup_memory_bytes 65536, got {max_memory_bytes}"
+
+    max_workgroup_sizes = gpu_target_info.max_workgroup_sizes
+    assert max_workgroup_sizes == [
+        256,
+        512,
+        1024,
+    ], f"Expected max_workgroup_sizes [256, 512, 1024], got {max_workgroup_sizes}"
diff --git a/compiler/src/iree/compiler/API/Internal/BUILD.bazel b/compiler/src/iree/compiler/API/Internal/BUILD.bazel
index 0daba3f..fcd4afc 100644
--- a/compiler/src/iree/compiler/API/Internal/BUILD.bazel
+++ b/compiler/src/iree/compiler/API/Internal/BUILD.bazel
@@ -156,6 +156,7 @@
     deps = [
         "//compiler/bindings/c:headers",
         "//compiler/src/iree/compiler/Codegen/Dialect/GPU/IR:IREEGPUDialect",
+        "//compiler/src/iree/compiler/Codegen/Utils",
         "@llvm-project//mlir:CAPIIR",
         "@llvm-project//mlir:CAPIIRHeaders",
         "@llvm-project//mlir:IR",
diff --git a/compiler/src/iree/compiler/API/Internal/CMakeLists.txt b/compiler/src/iree/compiler/API/Internal/CMakeLists.txt
index 3e77c4a..7a34949 100644
--- a/compiler/src/iree/compiler/API/Internal/CMakeLists.txt
+++ b/compiler/src/iree/compiler/API/Internal/CMakeLists.txt
@@ -135,6 +135,7 @@
     MLIRCAPIIR
     MLIRIR
     iree::compiler::Codegen::Dialect::GPU::IR::IREEGPUDialect
+    iree::compiler::Codegen::Utils
     iree::compiler::bindings::c::headers
   PUBLIC
 )
diff --git a/compiler/src/iree/compiler/API/Internal/IREEGPUDialectCAPI.cpp b/compiler/src/iree/compiler/API/Internal/IREEGPUDialectCAPI.cpp
index 7810202..be04348 100644
--- a/compiler/src/iree/compiler/API/Internal/IREEGPUDialectCAPI.cpp
+++ b/compiler/src/iree/compiler/API/Internal/IREEGPUDialectCAPI.cpp
@@ -10,6 +10,7 @@
 #include "iree/compiler/Codegen/Dialect/GPU/IR/GPULoweringConfigUtils.h"
 #include "iree/compiler/Codegen/Dialect/GPU/IR/IREEGPUAttrs.h"
 #include "iree/compiler/Codegen/Dialect/GPU/IR/IREEGPUEnums.h"
+#include "iree/compiler/Codegen/Utils/GPUUtils.h"
 #include "iree/compiler/dialects/iree_gpu.h"
 #include "mlir-c/BuiltinAttributes.h"
 #include "mlir-c/IR.h"
@@ -387,3 +388,37 @@
   result.element = wrap(builder.getI64ArrayAttr(layout.element));
   return result;
 }
+
+// Fills ireeGPUTargetInfo from the target resolved by getGPUTargetAttr.
+ireeGPUTargetInfo
+ireeHALExecutableTargetAttrGetGPUTargetInfo(MlirAttribute attr) {
+  assert(!mlirAttributeIsNull(attr) && "attr cannot be null");
+  auto executableTargetAttr =
+      llvm::cast<mlir::iree_compiler::IREE::HAL::ExecutableTargetAttr>(
+          unwrap(attr));
+  mlir::MLIRContext *context = executableTargetAttr.getContext();
+  mlir::Builder builder(context);
+
+  // Non-null, empty defaults keep every field readable without a GPU target.
+  ireeGPUTargetInfo targetInfo = {};
+  targetInfo.arch = wrap(builder.getStringAttr(""));
+  targetInfo.subgroupSizeChoices = wrap(builder.getI32ArrayAttr({}));
+  targetInfo.maxWorkgroupSizes = wrap(builder.getI32ArrayAttr({}));
+
+  mlir::iree_compiler::IREE::GPU::TargetAttr gpuTargetAttr =
+      mlir::iree_compiler::getGPUTargetAttr(context, executableTargetAttr);
+  if (!gpuTargetAttr) {
+    return targetInfo; // No GPU target: report the empty defaults.
+  }
+  targetInfo.arch = wrap(builder.getStringAttr(gpuTargetAttr.getArch()));
+  mlir::iree_compiler::IREE::GPU::TargetWgpAttr wgpAttr =
+      gpuTargetAttr.getWgp();
+  targetInfo.subgroupSizeChoices =
+      wrap(builder.getI32ArrayAttr(wgpAttr.getSubgroupSizeChoices()));
+  targetInfo.maxWorkgroupSizes =
+      wrap(builder.getI32ArrayAttr(wgpAttr.getMaxWorkgroupSizes()));
+  targetInfo.maxThreadCountPerWorkgroup =
+      wgpAttr.getMaxThreadCountPerWorkgroup();
+  targetInfo.maxWorkgroupMemoryBytes = wgpAttr.getMaxWorkgroupMemoryBytes();
+  return targetInfo;
+}
diff --git a/compiler/src/iree/compiler/API/api_exports.c b/compiler/src/iree/compiler/API/api_exports.c
index fd10f4a..1ac9ab3 100644
--- a/compiler/src/iree/compiler/API/api_exports.c
+++ b/compiler/src/iree/compiler/API/api_exports.c
@@ -113,6 +113,7 @@
 extern void ireeGPUReorderWorkgroupsStrategyAttrGet();
 extern void ireeGPUReorderWorkgroupsStrategyAttrGetTypeID();
 extern void ireeGPUReorderWorkgroupsStrategyAttrGetValue();
+extern void ireeHALExecutableTargetAttrGetGPUTargetInfo();
 extern void ireeMlirLspServerRunMain();
 extern void ireeOptRunMain();
 extern void ireeReduceRunMain();
@@ -1027,6 +1028,7 @@
   x += (uintptr_t)&ireeGPUReorderWorkgroupsStrategyAttrGet;
   x += (uintptr_t)&ireeGPUReorderWorkgroupsStrategyAttrGetTypeID;
   x += (uintptr_t)&ireeGPUReorderWorkgroupsStrategyAttrGetValue;
+  x += (uintptr_t)&ireeHALExecutableTargetAttrGetGPUTargetInfo;
   x += (uintptr_t)&ireeMlirLspServerRunMain;
   x += (uintptr_t)&ireeOptRunMain;
   x += (uintptr_t)&ireeReduceRunMain;
diff --git a/compiler/src/iree/compiler/API/api_exports.def b/compiler/src/iree/compiler/API/api_exports.def
index fe42d24..807184f 100644
--- a/compiler/src/iree/compiler/API/api_exports.def
+++ b/compiler/src/iree/compiler/API/api_exports.def
@@ -103,6 +103,7 @@
   ireeGPUReorderWorkgroupsStrategyAttrGet
   ireeGPUReorderWorkgroupsStrategyAttrGetTypeID
   ireeGPUReorderWorkgroupsStrategyAttrGetValue
+  ireeHALExecutableTargetAttrGetGPUTargetInfo
   ireeMlirLspServerRunMain
   ireeOptRunMain
   ireeReduceRunMain
diff --git a/compiler/src/iree/compiler/API/api_exports.ld b/compiler/src/iree/compiler/API/api_exports.ld
index c930996..c4d1d7f 100644
--- a/compiler/src/iree/compiler/API/api_exports.ld
+++ b/compiler/src/iree/compiler/API/api_exports.ld
@@ -104,6 +104,7 @@
     ireeGPUReorderWorkgroupsStrategyAttrGet;
     ireeGPUReorderWorkgroupsStrategyAttrGetTypeID;
     ireeGPUReorderWorkgroupsStrategyAttrGetValue;
+    ireeHALExecutableTargetAttrGetGPUTargetInfo;
     ireeMlirLspServerRunMain;
     ireeOptRunMain;
     ireeReduceRunMain;
diff --git a/compiler/src/iree/compiler/API/api_exports.macos.lst b/compiler/src/iree/compiler/API/api_exports.macos.lst
index e07c70c..3bfa0e3 100644
--- a/compiler/src/iree/compiler/API/api_exports.macos.lst
+++ b/compiler/src/iree/compiler/API/api_exports.macos.lst
@@ -102,6 +102,7 @@
 _ireeGPUReorderWorkgroupsStrategyAttrGet
 _ireeGPUReorderWorkgroupsStrategyAttrGetTypeID
 _ireeGPUReorderWorkgroupsStrategyAttrGetValue
+_ireeHALExecutableTargetAttrGetGPUTargetInfo
 _ireeMlirLspServerRunMain
 _ireeOptRunMain
 _ireeReduceRunMain