[python][tuner] Add bindings for `iree_codegen.compilation_info` (#19129)
diff --git a/compiler/bindings/c/iree/compiler/dialects/iree_codegen.h b/compiler/bindings/c/iree/compiler/dialects/iree_codegen.h
index 357ac87..029e672 100644
--- a/compiler/bindings/c/iree/compiler/dialects/iree_codegen.h
+++ b/compiler/bindings/c/iree/compiler/dialects/iree_codegen.h
@@ -50,6 +50,22 @@
MLIR_CAPI_EXPORTED ireeCodegenTranslationInfoParameters
ireeCodegenTranslationInfoAttrGetParameters(MlirAttribute attr);
+MLIR_CAPI_EXPORTED bool
+ireeAttributeIsACodegenCompilationInfoAttr(MlirAttribute attr);
+
+MLIR_CAPI_EXPORTED MlirTypeID ireeCodegenCompilationInfoAttrGetTypeID(void);
+
+struct ireeCodegenCompilationInfoParameters {
+ MlirAttribute loweringConfig;
+ MlirAttribute translationInfo;
+};
+
+MLIR_CAPI_EXPORTED MlirAttribute ireeCodegenCompilationInfoAttrGet(
+ MlirContext mlirCtx, ireeCodegenCompilationInfoParameters parameters);
+
+MLIR_CAPI_EXPORTED ireeCodegenCompilationInfoParameters
+ireeCodegenCompilationInfoAttrGetParameters(MlirAttribute attr);
+
#ifdef __cplusplus
}
#endif
diff --git a/compiler/bindings/python/IREECompilerDialectsModule.cpp b/compiler/bindings/python/IREECompilerDialectsModule.cpp
index 1f8be3f..7ece224 100644
--- a/compiler/bindings/python/IREECompilerDialectsModule.cpp
+++ b/compiler/bindings/python/IREECompilerDialectsModule.cpp
@@ -83,8 +83,7 @@
"cls"_a, "pass_pipeline"_a, "codegen_spec"_a = py::none(),
"workgroup_size"_a = py::none(), "subgroup_size"_a = py::none(),
"configuration"_a = py::none(), py::kw_only(), "ctx"_a = py::none(),
- "Gets an #iree_codegen.translation_info from "
- "parameters.")
+ "Gets an #iree_codegen.translation_info from parameters.")
.def_property_readonly(
"pass_pipeline",
[](MlirAttribute self) -> MlirAttribute {
@@ -124,6 +123,37 @@
return parameters.configuration;
});
+ //===-------------------------------------------------------------------===//
+ // CodegenCompilationInfoAttr
+ //===-------------------------------------------------------------------===//
+
+ mlir_attribute_subclass(iree_codegen_module, "CompilationInfoAttr",
+ ireeAttributeIsACodegenCompilationInfoAttr,
+ ireeCodegenCompilationInfoAttrGetTypeID)
+ .def_classmethod(
+ "get",
+ [](const py::object &, MlirAttribute loweringConfig,
+ MlirAttribute translationInfo, MlirContext ctx) {
+ ireeCodegenCompilationInfoParameters parameters = {};
+ parameters.loweringConfig = loweringConfig;
+ parameters.translationInfo = translationInfo;
+ return ireeCodegenCompilationInfoAttrGet(ctx, parameters);
+ },
+ "cls"_a, "lowering_config"_a, "translation_info"_a,
+ "ctx"_a = py::none(),
+ "Gets an #iree_codegen.compilation_info from parameters.")
+ .def_property_readonly(
+ "lowering_config",
+ [](MlirAttribute self) -> MlirAttribute {
+ auto parameters = ireeCodegenCompilationInfoAttrGetParameters(self);
+ return parameters.loweringConfig;
+ })
+ .def_property_readonly(
+ "translation_info", [](MlirAttribute self) -> MlirAttribute {
+ auto parameters = ireeCodegenCompilationInfoAttrGetParameters(self);
+ return parameters.translationInfo;
+ });
+
//===--------------------------------------------------------------------===//
auto iree_gpu_module =
diff --git a/compiler/bindings/python/test/ir/dialects_test.py b/compiler/bindings/python/test/ir/dialects_test.py
index 378d6ca..381ecea 100644
--- a/compiler/bindings/python/test/ir/dialects_test.py
+++ b/compiler/bindings/python/test/ir/dialects_test.py
@@ -215,3 +215,20 @@
assert lowering_config is not None
assert lowering_config.attributes == attributes
+
+
+@run
+def compilation_info():
+ attributes = ir.DictAttr.get({"reduction": ir.ArrayAttr.get([])})
+ lowering_config = iree_gpu.LoweringConfigAttr.get(attributes)
+ pipeline_attr = iree_codegen.DispatchLoweringPassPipelineAttr.get(
+ iree_codegen.DispatchLoweringPassPipeline.None_
+ )
+ translation_info = iree_codegen.TranslationInfoAttr.get(pipeline_attr)
+
+ compilation_info = iree_codegen.CompilationInfoAttr.get(
+ lowering_config, translation_info
+ )
+ assert compilation_info is not None
+ assert compilation_info.lowering_config == lowering_config
+ assert compilation_info.translation_info == translation_info
diff --git a/compiler/src/iree/compiler/API/Internal/IREECodegenDialectCAPI.cpp b/compiler/src/iree/compiler/API/Internal/IREECodegenDialectCAPI.cpp
index 13f2257..c295d48 100644
--- a/compiler/src/iree/compiler/API/Internal/IREECodegenDialectCAPI.cpp
+++ b/compiler/src/iree/compiler/API/Internal/IREECodegenDialectCAPI.cpp
@@ -9,15 +9,20 @@
#include <optional>
#include <type_traits>
#include "iree/compiler/Codegen/Dialect/Codegen/IR/IREECodegenAttrs.h"
+#include "iree/compiler/Codegen/Dialect/Codegen/IR/IREECodegenInterfaces.h"
#include "iree/compiler/dialects/iree_codegen.h"
#include "mlir-c/BuiltinAttributes.h"
#include "mlir-c/IR.h"
#include "mlir/CAPI/IR.h"
#include "mlir/CAPI/Support.h"
+#include "mlir/IR/Attributes.h"
#include "mlir/IR/BuiltinAttributes.h"
+#include "mlir/IR/MLIRContext.h"
+using mlir::iree_compiler::IREE::Codegen::CompilationInfoAttr;
using mlir::iree_compiler::IREE::Codegen::DispatchLoweringPassPipeline;
using mlir::iree_compiler::IREE::Codegen::DispatchLoweringPassPipelineAttr;
+using mlir::iree_compiler::IREE::Codegen::LoweringConfigAttrInterface;
using mlir::iree_compiler::IREE::Codegen::TranslationInfoAttr;
bool ireeAttributeIsACodegenDispatchLoweringPassPipelineAttr(
@@ -109,3 +114,38 @@
return parameters;
}
+
+bool ireeAttributeIsACodegenCompilationInfoAttr(MlirAttribute attr) {
+ return llvm::isa<CompilationInfoAttr>(unwrap(attr));
+}
+
+MlirTypeID ireeCodegenCompilationInfoAttrGetTypeID() {
+ return wrap(CompilationInfoAttr::getTypeID());
+}
+
+MlirAttribute ireeCodegenCompilationInfoAttrGet(
+ MlirContext mlirCtx, ireeCodegenCompilationInfoParameters parameters) {
+ assert(!mlirAttributeIsNull(parameters.loweringConfig) &&
+ "Invalid lowering config attr");
+ assert(
+ !mlirAttributeIsNull(parameters.translationInfo) &&
+ ireeAttributeIsACodegenTranslationInfoAttr(parameters.translationInfo) &&
+ "Invalid translation info attr");
+
+ auto loweringConfig = llvm::cast<LoweringConfigAttrInterface>(
+ unwrap(parameters.loweringConfig));
+ auto translationInfo =
+ llvm::cast<TranslationInfoAttr>(unwrap(parameters.translationInfo));
+
+ mlir::MLIRContext *ctx = unwrap(mlirCtx);
+ return wrap(CompilationInfoAttr::get(ctx, loweringConfig, translationInfo));
+}
+
+ireeCodegenCompilationInfoParameters
+ireeCodegenCompilationInfoAttrGetParameters(MlirAttribute attr) {
+ auto compilationInfo = llvm::cast<CompilationInfoAttr>(unwrap(attr));
+ ireeCodegenCompilationInfoParameters parameters = {};
+ parameters.loweringConfig = wrap(compilationInfo.getLoweringConfig());
+ parameters.translationInfo = wrap(compilationInfo.getTranslationInfo());
+ return parameters;
+}
diff --git a/compiler/src/iree/compiler/API/api_exports.c b/compiler/src/iree/compiler/API/api_exports.c
index 9105596..ffb8f08 100644
--- a/compiler/src/iree/compiler/API/api_exports.c
+++ b/compiler/src/iree/compiler/API/api_exports.c
@@ -10,6 +10,7 @@
#include <stdint.h>
+extern void ireeAttributeIsACodegenCompilationInfoAttr();
extern void ireeAttributeIsACodegenDispatchLoweringPassPipelineAttr();
extern void ireeAttributeIsACodegenTranslationInfoAttr();
extern void ireeAttributeIsAGPULoweringConfigAttr();
@@ -17,6 +18,9 @@
extern void ireeAttributeIsAGPUMMAIntrinsicAttr();
extern void ireeAttributeIsAGPUPipelineOptionsAttr();
extern void ireeAttributeIsAGPUReorderWorkgroupsStrategyAttr();
+extern void ireeCodegenCompilationInfoAttrGet();
+extern void ireeCodegenCompilationInfoAttrGetParameters();
+extern void ireeCodegenCompilationInfoAttrGetTypeID();
extern void ireeCodegenDispatchLoweringPassPipelineAttrGet();
extern void ireeCodegenDispatchLoweringPassPipelineAttrGetTypeID();
extern void ireeCodegenDispatchLoweringPassPipelineAttrGetValue();
@@ -868,6 +872,7 @@
uintptr_t __iree_compiler_hidden_force_extern() {
uintptr_t x = 0;
+ x += (uintptr_t)&ireeAttributeIsACodegenCompilationInfoAttr;
x += (uintptr_t)&ireeAttributeIsACodegenDispatchLoweringPassPipelineAttr;
x += (uintptr_t)&ireeAttributeIsACodegenTranslationInfoAttr;
x += (uintptr_t)&ireeAttributeIsAGPULoweringConfigAttr;
@@ -875,6 +880,9 @@
x += (uintptr_t)&ireeAttributeIsAGPUMMAIntrinsicAttr;
x += (uintptr_t)&ireeAttributeIsAGPUPipelineOptionsAttr;
x += (uintptr_t)&ireeAttributeIsAGPUReorderWorkgroupsStrategyAttr;
+ x += (uintptr_t)&ireeCodegenCompilationInfoAttrGet;
+ x += (uintptr_t)&ireeCodegenCompilationInfoAttrGetParameters;
+ x += (uintptr_t)&ireeCodegenCompilationInfoAttrGetTypeID;
x += (uintptr_t)&ireeCodegenDispatchLoweringPassPipelineAttrGet;
x += (uintptr_t)&ireeCodegenDispatchLoweringPassPipelineAttrGetTypeID;
x += (uintptr_t)&ireeCodegenDispatchLoweringPassPipelineAttrGetValue;
diff --git a/compiler/src/iree/compiler/API/api_exports.def b/compiler/src/iree/compiler/API/api_exports.def
index 9844a9f..ed5e12c 100644
--- a/compiler/src/iree/compiler/API/api_exports.def
+++ b/compiler/src/iree/compiler/API/api_exports.def
@@ -1,5 +1,6 @@
; Generated by generate_exports.py: Do not edit.
EXPORTS
+ ireeAttributeIsACodegenCompilationInfoAttr
ireeAttributeIsACodegenDispatchLoweringPassPipelineAttr
ireeAttributeIsACodegenTranslationInfoAttr
ireeAttributeIsAGPULoweringConfigAttr
@@ -7,6 +8,9 @@
ireeAttributeIsAGPUMMAIntrinsicAttr
ireeAttributeIsAGPUPipelineOptionsAttr
ireeAttributeIsAGPUReorderWorkgroupsStrategyAttr
+ ireeCodegenCompilationInfoAttrGet
+ ireeCodegenCompilationInfoAttrGetParameters
+ ireeCodegenCompilationInfoAttrGetTypeID
ireeCodegenDispatchLoweringPassPipelineAttrGet
ireeCodegenDispatchLoweringPassPipelineAttrGetTypeID
ireeCodegenDispatchLoweringPassPipelineAttrGetValue
diff --git a/compiler/src/iree/compiler/API/api_exports.ld b/compiler/src/iree/compiler/API/api_exports.ld
index 0c36027..0808927 100644
--- a/compiler/src/iree/compiler/API/api_exports.ld
+++ b/compiler/src/iree/compiler/API/api_exports.ld
@@ -1,6 +1,7 @@
# Generated by generate_exports.py: Do not edit.
VER_0 {
global:
+ ireeAttributeIsACodegenCompilationInfoAttr;
ireeAttributeIsACodegenDispatchLoweringPassPipelineAttr;
ireeAttributeIsACodegenTranslationInfoAttr;
ireeAttributeIsAGPULoweringConfigAttr;
@@ -8,6 +9,9 @@
ireeAttributeIsAGPUMMAIntrinsicAttr;
ireeAttributeIsAGPUPipelineOptionsAttr;
ireeAttributeIsAGPUReorderWorkgroupsStrategyAttr;
+ ireeCodegenCompilationInfoAttrGet;
+ ireeCodegenCompilationInfoAttrGetParameters;
+ ireeCodegenCompilationInfoAttrGetTypeID;
ireeCodegenDispatchLoweringPassPipelineAttrGet;
ireeCodegenDispatchLoweringPassPipelineAttrGetTypeID;
ireeCodegenDispatchLoweringPassPipelineAttrGetValue;
diff --git a/compiler/src/iree/compiler/API/api_exports.macos.lst b/compiler/src/iree/compiler/API/api_exports.macos.lst
index d0683e8..11169bf 100644
--- a/compiler/src/iree/compiler/API/api_exports.macos.lst
+++ b/compiler/src/iree/compiler/API/api_exports.macos.lst
@@ -1,4 +1,5 @@
# Generated by generate_exports.py: Do not edit.
+_ireeAttributeIsACodegenCompilationInfoAttr
_ireeAttributeIsACodegenDispatchLoweringPassPipelineAttr
_ireeAttributeIsACodegenTranslationInfoAttr
_ireeAttributeIsAGPULoweringConfigAttr
@@ -6,6 +7,9 @@
_ireeAttributeIsAGPUMMAIntrinsicAttr
_ireeAttributeIsAGPUPipelineOptionsAttr
_ireeAttributeIsAGPUReorderWorkgroupsStrategyAttr
+_ireeCodegenCompilationInfoAttrGet
+_ireeCodegenCompilationInfoAttrGetParameters
+_ireeCodegenCompilationInfoAttrGetTypeID
_ireeCodegenDispatchLoweringPassPipelineAttrGet
_ireeCodegenDispatchLoweringPassPipelineAttrGetTypeID
_ireeCodegenDispatchLoweringPassPipelineAttrGetValue