Add python registration support for IREE dialects. (#13549)
Unlike the previous mechanism, this will automatically register all
known dialects with any context created from python.
Fixes #13477 after #13395 removed an API that was load bearing to a
customer use case.
diff --git a/compiler/bindings/python/CMakeLists.txt b/compiler/bindings/python/CMakeLists.txt
index c7a791b..0c03446 100644
--- a/compiler/bindings/python/CMakeLists.txt
+++ b/compiler/bindings/python/CMakeLists.txt
@@ -31,6 +31,7 @@
# to be fixed to capture the correct include directories in that macro.
include_directories(
"${IREE_SOURCE_DIR}/compiler/src"
+ "${IREE_SOURCE_DIR}/compiler/bindings/c"
"${IREE_SOURCE_DIR}/llvm-external-projects/iree-dialects/include"
"${IREE_SOURCE_DIR}/third_party/llvm-project/mlir/include"
"${IREE_SOURCE_DIR}/third_party/mlir-hlo/include"
@@ -61,12 +62,30 @@
)
################################################################################
+# Extensions
+################################################################################
+
+declare_mlir_python_sources(IREECompilerPythonExtensions)
+
+declare_mlir_python_extension(IREECompilerPythonExtensions.Registration
+ MODULE_NAME _site_initialize_0
+ ADD_TO_PARENT IREECompilerPythonExtensions
+ SOURCES
+ IREECompilerRegistration.cpp
+ EMBED_CAPI_LINK_LIBS
+ iree_compiler_API_SharedImpl
+ PRIVATE_LINK_LIBS
+ LLVMSupport
+)
+
+################################################################################
# Generate packages and shared library
################################################################################
set(_SOURCE_COMPONENTS
# Local sources.
IREECompilerAPIPythonTools
+ IREECompilerPythonExtensions
MLIRPythonSources.Core
diff --git a/compiler/bindings/python/IREECompilerRegistration.cpp b/compiler/bindings/python/IREECompilerRegistration.cpp
new file mode 100644
index 0000000..d06ac13
--- /dev/null
+++ b/compiler/bindings/python/IREECompilerRegistration.cpp
@@ -0,0 +1,20 @@
+// Copyright 2023 The IREE Authors
+//
+// Licensed under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+
+#include "iree/compiler/API/MLIRInterop.h"
+#include "mlir-c/IR.h"
+#include "mlir/Bindings/Python/PybindAdaptors.h"
+
+namespace py = pybind11;
+using namespace mlir::python::adaptors;
+
+PYBIND11_MODULE(_site_initialize_0, m) {
+ m.doc() = "iree-compile registration";
+
+ m.def("register_dialects", [](MlirDialectRegistry registry) {
+ ireeCompilerRegisterDialects(registry);
+ });
+}
diff --git a/compiler/bindings/python/test/CMakeLists.txt b/compiler/bindings/python/test/CMakeLists.txt
index e5080d3..d3f6293 100644
--- a/compiler/bindings/python/test/CMakeLists.txt
+++ b/compiler/bindings/python/test/CMakeLists.txt
@@ -4,4 +4,5 @@
# See https://llvm.org/LICENSE.txt for license information.
# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+add_subdirectory(ir)
add_subdirectory(tools)
diff --git a/compiler/bindings/python/test/ir/CMakeLists.txt b/compiler/bindings/python/test/ir/CMakeLists.txt
new file mode 100644
index 0000000..c49e26d
--- /dev/null
+++ b/compiler/bindings/python/test/ir/CMakeLists.txt
@@ -0,0 +1,13 @@
+# Copyright 2023 The IREE Authors
+#
+# Licensed under the Apache License v2.0 with LLVM Exceptions.
+# See https://llvm.org/LICENSE.txt for license information.
+# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+
+iree_py_test(
+ NAME
+ registration_test
+ SRCS
+ "registration_test.py"
+)
+
diff --git a/compiler/bindings/python/test/ir/registration_test.py b/compiler/bindings/python/test/ir/registration_test.py
new file mode 100644
index 0000000..d2bf693
--- /dev/null
+++ b/compiler/bindings/python/test/ir/registration_test.py
@@ -0,0 +1,20 @@
+# Copyright 2023 The IREE Authors
+#
+# Licensed under the Apache License v2.0 with LLVM Exceptions.
+# See https://llvm.org/LICENSE.txt for license information.
+# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+
+from iree.compiler import ir
+
+# Just a simple test that dialects have been registered properly on the
+# context.
+with ir.Context() as ctx:
+ input_module = ir.Module.parse(r"""
+ builtin.module {
+ func.func @simple_mul(%arg0: tensor<4xf32>, %arg1: tensor<4xf32>) -> tensor<4xf32> {
+ %0 = arith.mulf %arg0, %arg1 : tensor<4xf32>
+ return %0 : tensor<4xf32>
+ }
+ }
+ """)
+ print(input_module)
diff --git a/compiler/setup.py b/compiler/setup.py
index f52b20f..df9e1e9 100644
--- a/compiler/setup.py
+++ b/compiler/setup.py
@@ -408,6 +408,7 @@
# it also needs to be enabled on the build side.
# CMakeExtension("iree.compiler._mlir_libs._mlirHlo"),
CMakeExtension("iree.compiler._mlir_libs._mlirLinalgPasses"),
+ CMakeExtension("iree.compiler._mlir_libs._site_initialize_0"),
],
cmdclass={
"build": CustomBuild,
diff --git a/compiler/src/iree/compiler/API/Internal/Embed.cpp b/compiler/src/iree/compiler/API/Internal/Embed.cpp
index f6615e7..086e320 100644
--- a/compiler/src/iree/compiler/API/Internal/Embed.cpp
+++ b/compiler/src/iree/compiler/API/Internal/Embed.cpp
@@ -1155,6 +1155,12 @@
// Unstable MLIRInterop.h helpers
//===----------------------------------------------------------------------===//
+void ireeCompilerRegisterDialects(MlirDialectRegistry registry) {
+ mlir::DialectRegistry *cppRegistry = unwrap(registry);
+ mlir::iree_compiler::registerAllDialects(*cppRegistry);
+ mlir::iree_compiler::registerLLVMIRTranslations(*cppRegistry);
+}
+
MlirContext ireeCompilerSessionGetContext(iree_compiler_session_t *session) {
return wrap(&unwrap(session)->context);
}
diff --git a/compiler/src/iree/compiler/API/MLIRInterop.h b/compiler/src/iree/compiler/API/MLIRInterop.h
index 2a18d04d..121ec5b 100644
--- a/compiler/src/iree/compiler/API/MLIRInterop.h
+++ b/compiler/src/iree/compiler/API/MLIRInterop.h
@@ -14,6 +14,7 @@
#define IREE_COMPILER_API_MLIR_INTEROP_H
#include "iree/compiler/embedding_api.h"
+#include "mlir-c/IR.h"
#include "mlir-c/Pass.h"
#include "mlir-c/Support.h"
@@ -21,6 +22,10 @@
extern "C" {
#endif
+// Registers all dialects and extensions known to the IREE compiler.
+MLIR_CAPI_EXPORTED void ireeCompilerRegisterDialects(
+ MlirDialectRegistry registry);
+
// Gets the MlirContext that the session manages. The context is owned by the
// session and valid until it is destroyed.
MLIR_CAPI_EXPORTED MlirContext
diff --git a/compiler/src/iree/compiler/API/api_exports.c b/compiler/src/iree/compiler/API/api_exports.c
index 0682ea1..c182ce7 100644
--- a/compiler/src/iree/compiler/API/api_exports.c
+++ b/compiler/src/iree/compiler/API/api_exports.c
@@ -37,6 +37,7 @@
extern void ireeCompilerOutputOpenFile();
extern void ireeCompilerOutputOpenMembuffer();
extern void ireeCompilerOutputWrite();
+extern void ireeCompilerRegisterDialects();
extern void ireeCompilerRunLldMain();
extern void ireeCompilerRunMain();
extern void ireeCompilerSessionCreate();
@@ -627,6 +628,7 @@
x += (uintptr_t)&ireeCompilerOutputOpenFile;
x += (uintptr_t)&ireeCompilerOutputOpenMembuffer;
x += (uintptr_t)&ireeCompilerOutputWrite;
+ x += (uintptr_t)&ireeCompilerRegisterDialects;
x += (uintptr_t)&ireeCompilerRunLldMain;
x += (uintptr_t)&ireeCompilerRunMain;
x += (uintptr_t)&ireeCompilerSessionCreate;
diff --git a/compiler/src/iree/compiler/API/api_exports.def b/compiler/src/iree/compiler/API/api_exports.def
index a40a157..f61d4a9 100644
--- a/compiler/src/iree/compiler/API/api_exports.def
+++ b/compiler/src/iree/compiler/API/api_exports.def
@@ -29,6 +29,7 @@
ireeCompilerOutputOpenFile
ireeCompilerOutputOpenMembuffer
ireeCompilerOutputWrite
+ ireeCompilerRegisterDialects
ireeCompilerRunLldMain
ireeCompilerRunMain
ireeCompilerSessionCreate
diff --git a/compiler/src/iree/compiler/API/api_exports.ld b/compiler/src/iree/compiler/API/api_exports.ld
index 3351f3a..6f40c43 100644
--- a/compiler/src/iree/compiler/API/api_exports.ld
+++ b/compiler/src/iree/compiler/API/api_exports.ld
@@ -30,6 +30,7 @@
ireeCompilerOutputOpenFile;
ireeCompilerOutputOpenMembuffer;
ireeCompilerOutputWrite;
+ ireeCompilerRegisterDialects;
ireeCompilerRunLldMain;
ireeCompilerRunMain;
ireeCompilerSessionCreate;
diff --git a/compiler/src/iree/compiler/API/api_exports.macos.lst b/compiler/src/iree/compiler/API/api_exports.macos.lst
index 5456c1d..872ad1c 100644
--- a/compiler/src/iree/compiler/API/api_exports.macos.lst
+++ b/compiler/src/iree/compiler/API/api_exports.macos.lst
@@ -28,6 +28,7 @@
_ireeCompilerOutputOpenFile
_ireeCompilerOutputOpenMembuffer
_ireeCompilerOutputWrite
+_ireeCompilerRegisterDialects
_ireeCompilerRunLldMain
_ireeCompilerRunMain
_ireeCompilerSessionCreate