blob: 0b235ee39eb2e6522a93c200ebe7c9ab83f60dd5 [file] [log] [blame]
// Copyright 2021 The IREE Authors
//
// Licensed under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
#include "iree-dialects-c/Dialects.h"
#include "iree-dialects-c/Utils.h"
#include "mlir-c/Bindings/Python/Interop.h"
#include "mlir-c/BuiltinAttributes.h"
#include "mlir-c/Diagnostics.h"
#include "mlir-c/Registration.h"
#include "mlir/Bindings/Python/PybindAdaptors.h"
namespace py = pybind11;
using namespace mlir::python::adaptors;
PYBIND11_MODULE(_ireeDialects, m) {
m.doc() = "iree-dialects main python extension";
auto irModule = py::module::import(MAKE_MLIR_PYTHON_QUALNAME("ir"));
auto typeClass = irModule.attr("Type");
//===--------------------------------------------------------------------===//
// Utils
//===--------------------------------------------------------------------===//
m.def(
"lookup_nearest_symbol_from",
[](MlirOperation fromOp, MlirAttribute symbol) {
if (!mlirAttributeIsASymbolRef(symbol)) {
throw std::invalid_argument("expected a SymbolRefAttr");
}
return ireeLookupNearestSymbolFrom(fromOp, symbol);
},
py::arg("fromOp"), py::arg("symbol"));
// TODO: Upstream this into the main Python bindings.
m.def(
"emit_error",
[](MlirLocation loc, std::string message) {
mlirEmitError(loc, message.c_str());
},
py::arg("loc"), py::arg("message"));
//===--------------------------------------------------------------------===//
// IREEDialect
//===--------------------------------------------------------------------===//
auto iree_m = m.def_submodule("iree");
iree_m.def(
"register_dialect",
[](MlirContext context, bool load) {
MlirDialectHandle handle = mlirGetDialectHandle__iree__();
mlirDialectHandleRegisterDialect(handle, context);
if (load) {
mlirDialectHandleLoadDialect(handle, context);
}
},
py::arg("context") = py::none(), py::arg("load") = true);
//===--------------------------------------------------------------------===//
// IREEPyDMDialect
//===--------------------------------------------------------------------===//
auto iree_pydm_m = m.def_submodule("iree_pydm");
iree_pydm_m.def(
"register_dialect",
[](MlirContext context, bool load) {
MlirDialectHandle handle = mlirGetDialectHandle__iree_pydm__();
mlirDialectHandleRegisterDialect(handle, context);
if (load) {
mlirDialectHandleLoadDialect(handle, context);
}
},
py::arg("context") = py::none(), py::arg("load") = true);
iree_pydm_m.def(
"build_lower_to_iree_pass_pipeline",
[](MlirPassManager passManager) {
MlirOpPassManager opPassManager =
mlirPassManagerGetAsOpPassManager(passManager);
mlirIREEPyDMBuildLowerToIREEPassPipeline(opPassManager);
},
py::arg("pass_manager"));
#define DEFINE_IREEPYDM_NULLARY_TYPE(Name) \
mlir_type_subclass(iree_pydm_m, #Name "Type", mlirTypeIsAIREEPyDM##Name, \
typeClass) \
.def_classmethod( \
"get", \
[](py::object cls, MlirContext context) { \
return cls(mlirIREEPyDM##Name##TypeGet(context)); \
}, \
py::arg("cls"), py::arg("context") = py::none());
DEFINE_IREEPYDM_NULLARY_TYPE(Bool)
DEFINE_IREEPYDM_NULLARY_TYPE(Bytes)
DEFINE_IREEPYDM_NULLARY_TYPE(ExceptionResult)
DEFINE_IREEPYDM_NULLARY_TYPE(FreeVarRef)
DEFINE_IREEPYDM_NULLARY_TYPE(Integer)
DEFINE_IREEPYDM_NULLARY_TYPE(List)
DEFINE_IREEPYDM_NULLARY_TYPE(None)
DEFINE_IREEPYDM_NULLARY_TYPE(Real)
DEFINE_IREEPYDM_NULLARY_TYPE(Str)
DEFINE_IREEPYDM_NULLARY_TYPE(Tuple)
DEFINE_IREEPYDM_NULLARY_TYPE(Type)
mlir_type_subclass(iree_pydm_m, "ObjectType", mlirTypeIsAIREEPyDMObject,
typeClass)
.def_classmethod(
"get",
[](py::object cls, MlirContext context) {
return cls(mlirIREEPyDMObjectTypeGet(context, {nullptr}));
},
py::arg("cls"), py::arg("context") = py::none())
.def_classmethod(
"get_typed",
[](py::object cls, MlirType type) {
if (!mlirTypeIsAIREEPyDMPrimitiveType(type)) {
throw std::invalid_argument(
"expected a primitive type when constructing object");
}
MlirContext context = mlirTypeGetContext(type);
return cls(mlirIREEPyDMObjectTypeGet(context, type));
},
py::arg("cls"), py::arg("type"));
}