[Codegen][Tuner] expose python binding to query target info (#21782)
Motivated by issue #2048, this PR exposes Python bindings for querying
the relevant GPU target info (architecture, subgroup size choices, and
workgroup limits), which the tuner will use for constraint generation.
---------
Signed-off-by: Bangtian Liu <liubangtian@gmail.com>
diff --git a/compiler/bindings/c/iree/compiler/dialects/iree_gpu.h b/compiler/bindings/c/iree/compiler/dialects/iree_gpu.h
index c6afe11..a8bc342 100644
--- a/compiler/bindings/c/iree/compiler/dialects/iree_gpu.h
+++ b/compiler/bindings/c/iree/compiler/dialects/iree_gpu.h
@@ -144,6 +144,18 @@
MLIR_CAPI_EXPORTED ireeGPUMMASingleSubgroupLayout
ireeGPUGetSingleSubgroupLayout(MlirAttribute attr, uint32_t fragment);
+// Plain-data summary of a GPU target, as returned by
+// `ireeHALExecutableTargetAttrGetGPUTargetInfo`.
+struct ireeGPUTargetInfo {
+ MlirIdentifier arch; // Target architecture name, e.g., "gfx942".
+ MlirAttribute subgroupSizeChoices; // Array attr of subgroup size choices.
+ MlirAttribute maxWorkgroupSizes; // Array attr: max threads per X/Y/Z dimension.
+ int64_t maxThreadCountPerWorkgroup; // Max total threads per workgroup.
+ int64_t maxWorkgroupMemoryBytes; // Max workgroup memory, in bytes.
+};
+
+// Queries GPU target info from the given `ExecutableTargetAttr` attribute.
+// `attr` must be a non-null `ExecutableTargetAttr`. If no GPU target attribute
+// can be resolved from it, the returned struct is zero-initialized: `arch` and
+// the attribute handles are null and the numeric limits are 0.
+MLIR_CAPI_EXPORTED ireeGPUTargetInfo
+ireeHALExecutableTargetAttrGetGPUTargetInfo(MlirAttribute attr);
+
#ifdef __cplusplus
}
#endif
diff --git a/compiler/bindings/python/IREECompilerDialectsModule.cpp b/compiler/bindings/python/IREECompilerDialectsModule.cpp
index f7a4380..27bb8cc 100644
--- a/compiler/bindings/python/IREECompilerDialectsModule.cpp
+++ b/compiler/bindings/python/IREECompilerDialectsModule.cpp
@@ -505,6 +505,38 @@
});
//===-------------------------------------------------------------------===//
+ // Binding to query target info
+ //===-------------------------------------------------------------------===//
+
+ // Read-only Python view of `ireeGPUTargetInfo`. The C API returns a
+ // zero-initialized struct when no GPU target can be resolved, so accessors
+ // that hand MLIR handles to other APIs check for null first and fall back to
+ // an empty value instead of dereferencing a null handle.
+ py::class_<ireeGPUTargetInfo>(iree_gpu_module, "TargetInfo")
+     .def_prop_ro("arch",
+                  [](const ireeGPUTargetInfo &self) -> std::string {
+                    // Null identifier: no GPU target info was found.
+                    if (!self.arch.ptr) {
+                      return std::string();
+                    }
+                    MlirStringRef strRef = mlirIdentifierStr(self.arch);
+                    return std::string(strRef.data, strRef.length);
+                  })
+     .def_prop_ro("subgroup_size_choices",
+                  [](const ireeGPUTargetInfo &self) -> std::vector<int64_t> {
+                    if (mlirAttributeIsNull(self.subgroupSizeChoices)) {
+                      return {};
+                    }
+                    return getIntArrayAttrValues(self.subgroupSizeChoices);
+                  })
+     .def_prop_ro("max_thread_count_per_workgroup",
+                  [](const ireeGPUTargetInfo &self) -> int64_t {
+                    return self.maxThreadCountPerWorkgroup;
+                  })
+     .def_prop_ro("max_workgroup_sizes",
+                  [](const ireeGPUTargetInfo &self) -> std::vector<int64_t> {
+                    if (mlirAttributeIsNull(self.maxWorkgroupSizes)) {
+                      return {};
+                    }
+                    return getIntArrayAttrValues(self.maxWorkgroupSizes);
+                  })
+     .def_prop_ro("max_workgroup_memory_bytes",
+                  [](const ireeGPUTargetInfo &self) -> int64_t {
+                    return self.maxWorkgroupMemoryBytes;
+                  });
+
+ // Binds the C API entry point directly; its `ireeGPUTargetInfo` result is
+ // surfaced to Python through the `TargetInfo` class registered above.
+ iree_gpu_module.def(
+ "get_gpu_target_info", &ireeHALExecutableTargetAttrGetGPUTargetInfo,
+ "Extracts GPU target information from an executable target attribute.",
+ py::arg("executable_target_attr"));
+
+ //===-------------------------------------------------------------------===//
// Binding to utility function getSingleSubgroupLayout
//===-------------------------------------------------------------------===//
py::class_<ireeGPUMMASingleSubgroupLayout>(iree_gpu_module,
@@ -592,12 +624,7 @@
});
iree_codegen_module.def(
- "get_attention_op_detail",
- [](MlirAffineMap q, MlirAffineMap k, MlirAffineMap v, MlirAffineMap o) {
- ireeCodegenAttentionOpDetail result =
- ireeCodegenGetAttentionOpDetail(q, k, v, o);
- return result;
- },
+ "get_attention_op_detail", &ireeCodegenGetAttentionOpDetail,
"Infers the structure of an attention operation from affine indexing "
"maps.",
py::arg("q"), py::arg("k"), py::arg("v"), py::arg("o"));
diff --git a/compiler/bindings/python/test/ir/dialects_test.py b/compiler/bindings/python/test/ir/dialects_test.py
index 86ad071..4f5060e 100644
--- a/compiler/bindings/python/test/ir/dialects_test.py
+++ b/compiler/bindings/python/test/ir/dialects_test.py
@@ -391,3 +391,68 @@
assert compilation_info is not None
assert compilation_info.lowering_config == lowering_config
assert compilation_info.translation_info == translation_info
+
+
+@run
+def gpu_target_info_attribute_parsing():
+    # Parses an executable variant whose target config embeds an
+    # `#iree_gpu.target` attribute, then checks that each `TargetInfo`
+    # property reports the value written in that attribute.
+    mlir_string = """
+ hal.executable private @main_dispatch_0 {
+ hal.executable.variant public @rocm_hsaco_fb
+ target(<"rocm", "rocm-hsaco-fb",
+ {
+ abi = "hip",
+ iree_codegen.target_info = #iree_gpu.target<
+ arch = "gfx942",
+ features = "",
+ wgp = <
+ compute = fp64,
+ storage = b64,
+ subgroup = none,
+ dot = none,
+ mma = [<MFMA_F32_16x16x4_F32>],
+ subgroup_size_choices = [32, 64],
+ max_workgroup_sizes = [256, 512, 1024],
+ max_thread_count_per_workgroup = 1024,
+ max_workgroup_memory_bytes = 65536,
+ max_workgroup_counts = [256, 512, 1024]
+ >
+ >
+ }>
+ ) {
+ }
+ }
+ """
+
+    module = ir.Module.parse(mlir_string)
+    variant_op_list = iree_codegen.get_executable_variant_ops(module)
+    assert len(variant_op_list) == 1, "Expect one executable variant op"
+    variant_op = variant_op_list[0]
+    executable_variant_op = variant_op.opview
+    # The variant's `target` attribute is the input to the new binding.
+    target = executable_variant_op.target
+    gpu_target_info = iree_gpu.get_gpu_target_info(target)
+
+    arch = gpu_target_info.arch
+    assert arch == "gfx942", f"Expected arch 'gfx942', got '{arch}'"
+
+    subgroup_size_choices = gpu_target_info.subgroup_size_choices
+    assert subgroup_size_choices == [
+        32,
+        64,
+    ], f"Expected subgroup_size_choices [32, 64], got {subgroup_size_choices}"
+
+    max_thread_count = gpu_target_info.max_thread_count_per_workgroup
+    assert (
+        max_thread_count == 1024
+    ), f"Expected max_thread_count_per_workgroup 1024, got {max_thread_count}"
+
+    max_memory_bytes = gpu_target_info.max_workgroup_memory_bytes
+    assert (
+        max_memory_bytes == 65536
+    ), f"Expected max_workgroup_memory_bytes 65536, got {max_memory_bytes}"
+
+    max_workgroup_sizes = gpu_target_info.max_workgroup_sizes
+    assert max_workgroup_sizes == [
+        256,
+        512,
+        1024,
+    ], f"Expected max_workgroup_sizes [256, 512, 1024], got {max_workgroup_sizes}"
diff --git a/compiler/src/iree/compiler/API/Internal/BUILD.bazel b/compiler/src/iree/compiler/API/Internal/BUILD.bazel
index 0daba3f..fcd4afc 100644
--- a/compiler/src/iree/compiler/API/Internal/BUILD.bazel
+++ b/compiler/src/iree/compiler/API/Internal/BUILD.bazel
@@ -156,6 +156,7 @@
deps = [
"//compiler/bindings/c:headers",
"//compiler/src/iree/compiler/Codegen/Dialect/GPU/IR:IREEGPUDialect",
+ "//compiler/src/iree/compiler/Codegen/Utils",
"@llvm-project//mlir:CAPIIR",
"@llvm-project//mlir:CAPIIRHeaders",
"@llvm-project//mlir:IR",
diff --git a/compiler/src/iree/compiler/API/Internal/CMakeLists.txt b/compiler/src/iree/compiler/API/Internal/CMakeLists.txt
index 3e77c4a..7a34949 100644
--- a/compiler/src/iree/compiler/API/Internal/CMakeLists.txt
+++ b/compiler/src/iree/compiler/API/Internal/CMakeLists.txt
@@ -135,6 +135,7 @@
MLIRCAPIIR
MLIRIR
iree::compiler::Codegen::Dialect::GPU::IR::IREEGPUDialect
+ iree::compiler::Codegen::Utils
iree::compiler::bindings::c::headers
PUBLIC
)
diff --git a/compiler/src/iree/compiler/API/Internal/IREEGPUDialectCAPI.cpp b/compiler/src/iree/compiler/API/Internal/IREEGPUDialectCAPI.cpp
index 7810202..be04348 100644
--- a/compiler/src/iree/compiler/API/Internal/IREEGPUDialectCAPI.cpp
+++ b/compiler/src/iree/compiler/API/Internal/IREEGPUDialectCAPI.cpp
@@ -10,6 +10,7 @@
#include "iree/compiler/Codegen/Dialect/GPU/IR/GPULoweringConfigUtils.h"
#include "iree/compiler/Codegen/Dialect/GPU/IR/IREEGPUAttrs.h"
#include "iree/compiler/Codegen/Dialect/GPU/IR/IREEGPUEnums.h"
+#include "iree/compiler/Codegen/Utils/GPUUtils.h"
#include "iree/compiler/dialects/iree_gpu.h"
#include "mlir-c/BuiltinAttributes.h"
#include "mlir-c/IR.h"
@@ -387,3 +388,37 @@
result.element = wrap(builder.getI64ArrayAttr(layout.element));
return result;
}
+
+// Extracts the GPU target description carried by a HAL executable target.
+//
+// `attr` must be a non-null `ExecutableTargetAttr`. When no GPU target
+// attribute can be resolved from it, a zero-initialized `ireeGPUTargetInfo`
+// is returned (null `arch` and attribute handles, zero limits).
+ireeGPUTargetInfo
+ireeHALExecutableTargetAttrGetGPUTargetInfo(MlirAttribute attr) {
+  assert(!mlirAttributeIsNull(attr) && "attr cannot be null");
+  auto executableTargetAttr =
+      llvm::cast<mlir::iree_compiler::IREE::HAL::ExecutableTargetAttr>(
+          unwrap(attr));
+
+  ireeGPUTargetInfo targetInfo = {};
+  mlir::MLIRContext *context = executableTargetAttr.getContext();
+  mlir::iree_compiler::IREE::GPU::TargetAttr gpuTargetAttr =
+      mlir::iree_compiler::getGPUTargetAttr(context, executableTargetAttr);
+
+  // No GPU target available: hand back the zero-initialized struct.
+  if (!gpuTargetAttr) {
+    return targetInfo;
+  }
+
+  targetInfo.arch =
+      wrap(mlir::StringAttr::get(context, gpuTargetAttr.getArch()));
+  mlir::iree_compiler::IREE::GPU::TargetWgpAttr wgpAttr =
+      gpuTargetAttr.getWgp();
+  // Only attributes are created here, so a plain `Builder` is enough; there
+  // is no need to construct an `OpBuilder` and slice it down.
+  mlir::Builder builder(context);
+
+  targetInfo.subgroupSizeChoices =
+      wrap(builder.getI32ArrayAttr(wgpAttr.getSubgroupSizeChoices()));
+  targetInfo.maxWorkgroupSizes =
+      wrap(builder.getI32ArrayAttr(wgpAttr.getMaxWorkgroupSizes()));
+
+  targetInfo.maxThreadCountPerWorkgroup =
+      wgpAttr.getMaxThreadCountPerWorkgroup();
+  targetInfo.maxWorkgroupMemoryBytes = wgpAttr.getMaxWorkgroupMemoryBytes();
+
+  return targetInfo;
+}
diff --git a/compiler/src/iree/compiler/API/api_exports.c b/compiler/src/iree/compiler/API/api_exports.c
index fd10f4a..1ac9ab3 100644
--- a/compiler/src/iree/compiler/API/api_exports.c
+++ b/compiler/src/iree/compiler/API/api_exports.c
@@ -113,6 +113,7 @@
extern void ireeGPUReorderWorkgroupsStrategyAttrGet();
extern void ireeGPUReorderWorkgroupsStrategyAttrGetTypeID();
extern void ireeGPUReorderWorkgroupsStrategyAttrGetValue();
+extern void ireeHALExecutableTargetAttrGetGPUTargetInfo();
extern void ireeMlirLspServerRunMain();
extern void ireeOptRunMain();
extern void ireeReduceRunMain();
@@ -1027,6 +1028,7 @@
x += (uintptr_t)&ireeGPUReorderWorkgroupsStrategyAttrGet;
x += (uintptr_t)&ireeGPUReorderWorkgroupsStrategyAttrGetTypeID;
x += (uintptr_t)&ireeGPUReorderWorkgroupsStrategyAttrGetValue;
+ x += (uintptr_t)&ireeHALExecutableTargetAttrGetGPUTargetInfo;
x += (uintptr_t)&ireeMlirLspServerRunMain;
x += (uintptr_t)&ireeOptRunMain;
x += (uintptr_t)&ireeReduceRunMain;
diff --git a/compiler/src/iree/compiler/API/api_exports.def b/compiler/src/iree/compiler/API/api_exports.def
index fe42d24..807184f 100644
--- a/compiler/src/iree/compiler/API/api_exports.def
+++ b/compiler/src/iree/compiler/API/api_exports.def
@@ -103,6 +103,7 @@
ireeGPUReorderWorkgroupsStrategyAttrGet
ireeGPUReorderWorkgroupsStrategyAttrGetTypeID
ireeGPUReorderWorkgroupsStrategyAttrGetValue
+ ireeHALExecutableTargetAttrGetGPUTargetInfo
ireeMlirLspServerRunMain
ireeOptRunMain
ireeReduceRunMain
diff --git a/compiler/src/iree/compiler/API/api_exports.ld b/compiler/src/iree/compiler/API/api_exports.ld
index c930996..c4d1d7f 100644
--- a/compiler/src/iree/compiler/API/api_exports.ld
+++ b/compiler/src/iree/compiler/API/api_exports.ld
@@ -104,6 +104,7 @@
ireeGPUReorderWorkgroupsStrategyAttrGet;
ireeGPUReorderWorkgroupsStrategyAttrGetTypeID;
ireeGPUReorderWorkgroupsStrategyAttrGetValue;
+ ireeHALExecutableTargetAttrGetGPUTargetInfo;
ireeMlirLspServerRunMain;
ireeOptRunMain;
ireeReduceRunMain;
diff --git a/compiler/src/iree/compiler/API/api_exports.macos.lst b/compiler/src/iree/compiler/API/api_exports.macos.lst
index e07c70c..3bfa0e3 100644
--- a/compiler/src/iree/compiler/API/api_exports.macos.lst
+++ b/compiler/src/iree/compiler/API/api_exports.macos.lst
@@ -102,6 +102,7 @@
_ireeGPUReorderWorkgroupsStrategyAttrGet
_ireeGPUReorderWorkgroupsStrategyAttrGetTypeID
_ireeGPUReorderWorkgroupsStrategyAttrGetValue
+_ireeHALExecutableTargetAttrGetGPUTargetInfo
_ireeMlirLspServerRunMain
_ireeOptRunMain
_ireeReduceRunMain