[python][tuner] Add bindings for lowering config (#19096)
Keeping it simple for now -- no attribute methods are exposed beyond
access to the attribute dictionary.
diff --git a/compiler/bindings/c/iree/compiler/dialects/iree_gpu.h b/compiler/bindings/c/iree/compiler/dialects/iree_gpu.h
index 7a596fe..53e2fbc 100644
--- a/compiler/bindings/c/iree/compiler/dialects/iree_gpu.h
+++ b/compiler/bindings/c/iree/compiler/dialects/iree_gpu.h
@@ -82,6 +82,17 @@
MLIR_CAPI_EXPORTED ireeGPUMMAInfo ireeGPUMMAAttrGetInfo(MlirAttribute attr);
+MLIR_CAPI_EXPORTED bool
+ireeAttributeIsAGPULoweringConfigAttr(MlirAttribute attr);
+
+MLIR_CAPI_EXPORTED MlirTypeID ireeGPULoweringConfigAttrGetTypeID(void);
+
+MLIR_CAPI_EXPORTED MlirAttribute ireeGPULoweringConfigAttrGet(
+ MlirContext mlirCtx, MlirAttribute attributesDictionary);
+
+MLIR_CAPI_EXPORTED MlirAttribute
+ireeGPULoweringConfigAttrGetAttributes(MlirAttribute attr);
+
#ifdef __cplusplus
}
#endif
diff --git a/compiler/bindings/python/IREECompilerDialectsModule.cpp b/compiler/bindings/python/IREECompilerDialectsModule.cpp
index d24075a..33f4681 100644
--- a/compiler/bindings/python/IREECompilerDialectsModule.cpp
+++ b/compiler/bindings/python/IREECompilerDialectsModule.cpp
@@ -115,6 +115,7 @@
//===-------------------------------------------------------------------===//
// GPUMMAIntrinsicAttr
//===-------------------------------------------------------------------===//
+
mlir_attribute_subclass(iree_gpu_module, "MMAIntrinsicAttr",
ireeAttributeIsAGPUMMAIntrinsicAttr,
ireeGPUMMAIntrinsicAttrGetTypeID)
@@ -138,6 +139,10 @@
return ireeGPUMMAAttrGet(mlirAttributeGetContext(self), value);
});
+ //===-------------------------------------------------------------------===//
+ // GPUMMAAttr
+ //===-------------------------------------------------------------------===//
+
mlir_attribute_subclass(iree_gpu_module, "MMAAttr",
ireeAttributeIsAGPUMMAAttr, ireeGPUMMAAttrGetTypeID)
.def_classmethod(
@@ -165,4 +170,22 @@
ireeGPUMMAInfo info = ireeGPUMMAAttrGetInfo(self);
return py::make_tuple(info.mElements, info.nElements, info.kElements);
});
+
+ //===-------------------------------------------------------------------===//
+ // GPULoweringConfigAttr
+ //===-------------------------------------------------------------------===//
+
+ mlir_attribute_subclass(iree_gpu_module, "LoweringConfigAttr",
+ ireeAttributeIsAGPULoweringConfigAttr,
+ ireeGPULoweringConfigAttrGetTypeID)
+ .def_classmethod(
+ "get",
+ [](const py::object &, MlirAttribute attributeDictionary,
+ MlirContext ctx) {
+ return ireeGPULoweringConfigAttrGet(ctx, attributeDictionary);
+ },
+ "cls"_a, "value"_a, "ctx"_a = py::none(),
+ "Gets a gpu.lowering_config from parameters.")
+ .def_property_readonly("attributes",
+ ireeGPULoweringConfigAttrGetAttributes);
}
diff --git a/compiler/bindings/python/test/ir/dialects_test.py b/compiler/bindings/python/test/ir/dialects_test.py
index e1b4121..249b397 100644
--- a/compiler/bindings/python/test/ir/dialects_test.py
+++ b/compiler/bindings/python/test/ir/dialects_test.py
@@ -131,3 +131,15 @@
assert K == 8
assert mma_intrinsic_attr.mma == mma_attr
+
+
+@lambda _: _()
+def lowering_config_attr():
+ with ir.Context() as ctx, ir.Location.unknown():
+ module = ir.Module.create()
+ with ir.InsertionPoint(module.body):
+ attributes = ir.DictAttr.get({"reduction": ir.ArrayAttr.get([])}, ctx)
+ lowering_config = iree_gpu.LoweringConfigAttr.get(attributes, ctx)
+ assert lowering_config is not None
+
+ assert lowering_config.attributes == attributes
diff --git a/compiler/src/iree/compiler/API/Internal/BUILD.bazel b/compiler/src/iree/compiler/API/Internal/BUILD.bazel
index 883f7fb..f26921e 100644
--- a/compiler/src/iree/compiler/API/Internal/BUILD.bazel
+++ b/compiler/src/iree/compiler/API/Internal/BUILD.bazel
@@ -139,5 +139,6 @@
"//compiler/src/iree/compiler/Codegen/Dialect/GPU/IR:IREEGPUDialect",
"@llvm-project//mlir:CAPIIR",
"@llvm-project//mlir:CAPIIRHeaders",
+ "@llvm-project//mlir:IR",
],
)
diff --git a/compiler/src/iree/compiler/API/Internal/CMakeLists.txt b/compiler/src/iree/compiler/API/Internal/CMakeLists.txt
index 3ee76d9..80fbef4 100644
--- a/compiler/src/iree/compiler/API/Internal/CMakeLists.txt
+++ b/compiler/src/iree/compiler/API/Internal/CMakeLists.txt
@@ -114,6 +114,7 @@
DEPS
IREELLVMIncludeSetup
MLIRCAPIIR
+ MLIRIR
iree::compiler::Codegen::Dialect::GPU::IR::IREEGPUDialect
iree::compiler::bindings::c::headers
PUBLIC
diff --git a/compiler/src/iree/compiler/API/Internal/IREEGPUDialectCAPI.cpp b/compiler/src/iree/compiler/API/Internal/IREEGPUDialectCAPI.cpp
index 15f517e..3badb3f 100644
--- a/compiler/src/iree/compiler/API/Internal/IREEGPUDialectCAPI.cpp
+++ b/compiler/src/iree/compiler/API/Internal/IREEGPUDialectCAPI.cpp
@@ -4,14 +4,17 @@
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+#include <cassert>
#include <cstdint>
#include <type_traits>
#include "iree/compiler/Codegen/Dialect/GPU/IR/IREEGPUAttrs.h"
#include "iree/compiler/Codegen/Dialect/GPU/IR/IREEGPUEnums.h"
#include "iree/compiler/dialects/iree_gpu.h"
+#include "mlir-c/BuiltinAttributes.h"
#include "mlir-c/IR.h"
#include "mlir/CAPI/IR.h"
#include "mlir/CAPI/Support.h"
+#include "mlir/IR/BuiltinAttributes.h"
bool ireeAttributeIsAGPUPipelineOptionsAttr(MlirAttribute attr) {
return llvm::isa<mlir::iree_compiler::IREE::GPU::GPUPipelineOptionsAttr>(
@@ -184,3 +187,29 @@
std::tie(info.mElements, info.nElements, info.kElements) = mma.getMNKShape();
return info;
}
+
+bool ireeAttributeIsAGPULoweringConfigAttr(MlirAttribute attr) {
+ return llvm::isa<mlir::iree_compiler::IREE::GPU::LoweringConfigAttr>(
+ unwrap(attr));
+}
+
+MlirTypeID ireeGPULoweringConfigAttrGetTypeID() {
+ return wrap(mlir::iree_compiler::IREE::GPU::LoweringConfigAttr::getTypeID());
+}
+
+MlirAttribute ireeGPULoweringConfigAttrGet(MlirContext mlirCtx,
+ MlirAttribute attributesDictionary) {
+ assert(mlirAttributeIsADictionary(attributesDictionary));
+ auto attributes =
+ llvm::cast<mlir::DictionaryAttr>(unwrap(attributesDictionary));
+ mlir::MLIRContext *ctx = unwrap(mlirCtx);
+ return wrap(
+ mlir::iree_compiler::IREE::GPU::LoweringConfigAttr::get(ctx, attributes));
+}
+
+MlirAttribute ireeGPULoweringConfigAttrGetAttributes(MlirAttribute attr) {
+ assert(ireeAttributeIsAGPULoweringConfigAttr(attr));
+ return wrap(llvm::cast<mlir::iree_compiler::IREE::GPU::LoweringConfigAttr>(
+ unwrap(attr))
+ .getAttributes());
+}