Adds CAPI interop to export a module from an invocation. (#14722)
Also adds Python API to use it and get a Python Operation.
diff --git a/compiler/bindings/python/iree/compiler/api/ctypes_dl.py b/compiler/bindings/python/iree/compiler/api/ctypes_dl.py
index 0275095..aeb0ef2 100644
--- a/compiler/bindings/python/iree/compiler/api/ctypes_dl.py
+++ b/compiler/bindings/python/iree/compiler/api/ctypes_dl.py
@@ -103,6 +103,11 @@
c_bool,
[c_void_p, c_void_p],
)
+ _setsig(
+ _dylib.ireeCompilerInvocationExportStealModule,
+ c_void_p,
+ [c_void_p],
+ )
# Source
_setsig(_dylib.ireeCompilerSourceDestroy, None, [c_void_p])
_setsig(
@@ -297,7 +302,7 @@
@staticmethod
def wrap_buffer(
- session: Session, buffer, *, buffer_name: Optional[str] = None
+ session: Session, buffer, *, buffer_name: str = "source.mlir"
) -> "Source":
view = memoryview(buffer)
if not view.c_contiguous:
@@ -352,6 +357,20 @@
def enable_console_diagnostics(self):
_dylib.ireeCompilerInvocationEnableConsoleDiagnostics(self._inv_p)
+ def export_module(self):
+ """Exports the module."""
+ from .. import ir
+
+ if self._retained_module_op:
+ return self._retained_module_op
+ module_ptr = _dylib.ireeCompilerInvocationExportStealModule(self._inv_p)
+ if not module_ptr:
+ raise RuntimeError("Module is not available to export")
+ capsule = PyCapsule_New(module_ptr, MLIR_PYTHON_CAPSULE_OPERATION, None)
+ operation = ir.Operation._CAPICreate(capsule)
+ self._retained_module_op = operation
+ return operation
+
def import_module(self, module_op) -> bool:
self._retained_module_op = module_op
# Import module.
diff --git a/compiler/bindings/python/test/api/api_test.py b/compiler/bindings/python/test/api/api_test.py
index b4de41f..ab70fcd 100644
--- a/compiler/bindings/python/test/api/api_test.py
+++ b/compiler/bindings/python/test/api/api_test.py
@@ -7,10 +7,12 @@
# TODO: Upstream this to IREE.
import platform
+
if platform.system() == "Windows":
print("WARNING: Test disabled on Windows due to suspected MSVC bug")
else:
from contextlib import closing
+ import os
from pathlib import Path
import tempfile
import unittest
@@ -18,7 +20,6 @@
from iree.compiler.api import *
from iree.compiler import ir
-
class DlFlagsTest(unittest.TestCase):
def testDefaultFlags(self):
session = Session()
@@ -47,12 +48,88 @@
with self.assertRaises(ValueError):
session.set_flags("--does-not-exist=1")
-
class DlInvocationTest(unittest.TestCase):
def testCreate(self):
session = Session()
inv = session.invocation()
+ def testInputFile(self):
+ session = Session()
+ inv = session.invocation()
+ with tempfile.NamedTemporaryFile("w", delete=False) as tf:
+ tf.write("module {}")
+ tf.close()
+ try:
+ source = Source.open_file(session, tf.name)
+ inv.parse_source(source)
+ finally:
+ os.unlink(tf.name)
+ out = Output.open_membuffer()
+ inv.output_ir(out)
+ mem = out.map_memory()
+ self.assertIn(b"module", bytes(mem))
+ out.close()
+
+ def testInputBuffer(self):
+ session = Session()
+ inv = session.invocation()
+ source = Source.wrap_buffer(session, b"builtin.module {}")
+ inv.parse_source(source)
+ out = Output.open_membuffer()
+ inv.output_ir(out)
+ mem = out.map_memory()
+ self.assertIn(b"module", bytes(mem))
+ out.close()
+
+ def testOutputBytecode(self):
+ session = Session()
+ inv = session.invocation()
+ source = Source.wrap_buffer(session, b"builtin.module {}")
+ inv.parse_source(source)
+ out = Output.open_membuffer()
+ inv.output_ir_bytecode(out)
+ mem = out.map_memory()
+ self.assertIn(b"module", bytes(mem))
+ out.close()
+
+ def testExecutePassPipeline(self):
+ session = Session()
+ inv = session.invocation()
+ source = Source.wrap_buffer(
+ session,
+ b"""
+ builtin.module {
+ func.func private @foobar() -> ()
+ }
+ """,
+ )
+ inv.parse_source(source)
+ inv.execute_text_pass_pipeline("symbol-dce")
+ out = Output.open_membuffer()
+ inv.output_ir(out)
+ mem = out.map_memory()
+ self.assertNotIn(b"func", bytes(mem))
+ out.close()
+
+ def testExecuteStdPipeline(self):
+ session = Session()
+ session.set_flags("--iree-hal-target-backends=vmvx")
+ inv = session.invocation()
+ source = Source.wrap_buffer(
+ session,
+ b"""
+ builtin.module {
+ func.func @main(%arg0: i32) -> (i32) {
+ return %arg0 : i32
+ }
+ }
+ """,
+ )
+ inv.parse_source(source)
+ inv.execute()
+ out = Output.open_membuffer()
+ inv.output_vm_bytecode(out)
+ out.close()
class DlOutputTest(unittest.TestCase):
def testOpenMembuffer(self):
@@ -97,7 +174,6 @@
finally:
Path(file_path).unlink()
-
class DlInteropTest(unittest.TestCase):
def testContextFromSession(self):
s = Session()
@@ -123,53 +199,21 @@
print(contents)
self.assertIn('test.test = "working"', contents)
-
- # TODO: Port these to the current API.
- # class CompilerAPITest(unittest.TestCase):
- # def testCreate(self):
- # compiler = Compiler()
-
- # def testLoadFromBytes(self):
- # compiler = Compiler()
- # p = compiler.load_buffer("module {}".encode(), buffer_name="foobar")
-
- # def testPipelineClose(self):
- # compiler = Compiler()
- # p = compiler.load_buffer("module {}".encode(), buffer_name="foobar")
- # p.close()
-
- # def testLoadFromFile(self):
- # compiler = Compiler()
- # with tempfile.NamedTemporaryFile("w", delete=False) as tf:
- # tf.write("module {}")
- # tf.close()
- # p = compiler.load_file(tf.name)
- # p.close()
-
- # def testExecuteIR(self):
- # compiler = Compiler()
- # p = compiler.load_buffer("module {}".encode(), buffer_name="foobar")
- # p.execute()
- # with closing(compiler.open_output_membuffer()) as output:
- # p.output_ir(output)
- # ir_contents = bytes(output.map_memory())
- # print(ir_contents)
- # self.assertEqual(b"module {\n}", ir_contents)
-
- # def testExecuteVMFB(self):
- # compiler = Compiler()
- # compiler.set_flags("--iree-hal-target-backends=vmvx")
- # p = compiler.load_buffer(
- # "module {func.func @main(%arg0: i32) -> (i32) {return %arg0 : i32}}".encode(),
- # buffer_name="foobar",
- # )
- # p.execute()
- # with closing(compiler.open_output_membuffer()) as output:
- # p.output_vm_bytecode(output)
- # ir_contents = bytes(output.map_memory())
- # print(len(ir_contents))
- # self.assertGreater(len(ir_contents), 0)
-
+ def testExportModule(self):
+ s = Session()
+ with ir.Location.unknown(s.context):
+ source = Source.wrap_buffer(s, b"builtin.module {}")
+ inv = s.invocation()
+ self.assertTrue(inv.parse_source(source))
+ module_op = inv.export_module()
+ module_op.attributes["test.test"] = ir.Attribute.parse('"working"')
+ # Round-trip it back through an Output and verify that the attribute
+ # we set is still there.
+ output = Output.open_membuffer()
+ inv.output_ir(output)
+ contents = bytes(output.map_memory()).decode()
+ print(contents)
+ self.assertIn('test.test = "working"', contents)
if __name__ == "__main__":
unittest.main()
diff --git a/compiler/src/iree/compiler/API/Internal/Embed.cpp b/compiler/src/iree/compiler/API/Internal/Embed.cpp
index 36801a9..ce71aab 100644
--- a/compiler/src/iree/compiler/API/Internal/Embed.cpp
+++ b/compiler/src/iree/compiler/API/Internal/Embed.cpp
@@ -536,6 +536,7 @@
bool initializeInvocation();
std::unique_ptr<PassManager> createPassManager();
bool parseSource(Source &source);
+ Operation *exportModule();
bool importModule(Operation *inputModule, bool steal);
bool runPipeline(enum iree_compiler_pipeline_t pipeline);
bool runTextualPassPipeline(const char *textPassPipeline);
@@ -708,6 +709,13 @@
return true;
}
+Operation *Invocation::exportModule() {
+ if (!parsedModuleIsOwned)
+ return nullptr;
+ parsedModuleIsOwned = false;
+ return parsedModule;
+}
+
bool Invocation::runPipeline(enum iree_compiler_pipeline_t pipeline) {
auto passManager = createPassManager();
switch (pipeline) {
@@ -1377,3 +1385,8 @@
MlirOperation moduleOp) {
return unwrap(inv)->importModule(unwrap(moduleOp), /*steal=*/true);
}
+
+MlirOperation
+ireeCompilerInvocationExportStealModule(iree_compiler_invocation_t *inv) {
+ return wrap(unwrap(inv)->exportModule());
+}
diff --git a/compiler/src/iree/compiler/API/MLIRInterop.h b/compiler/src/iree/compiler/API/MLIRInterop.h
index 0a83670..a249854 100644
--- a/compiler/src/iree/compiler/API/MLIRInterop.h
+++ b/compiler/src/iree/compiler/API/MLIRInterop.h
@@ -65,6 +65,12 @@
ireeCompilerInvocationImportStealModule(iree_compiler_invocation_t *inv,
MlirOperation moduleOp);
+// Exports the owned module from the invocation, transferring ownership to the
+// caller.
+MLIR_CAPI_EXPORTED
+MlirOperation
+ireeCompilerInvocationExportStealModule(iree_compiler_invocation_t *inv);
+
#ifdef __cplusplus
}
#endif