Adds CAPI interop to export a module from an invocation. (#14722)

Also adds Python API to use it and get a Python Operation.
diff --git a/compiler/bindings/python/iree/compiler/api/ctypes_dl.py b/compiler/bindings/python/iree/compiler/api/ctypes_dl.py
index 0275095..aeb0ef2 100644
--- a/compiler/bindings/python/iree/compiler/api/ctypes_dl.py
+++ b/compiler/bindings/python/iree/compiler/api/ctypes_dl.py
@@ -103,6 +103,11 @@
         c_bool,
         [c_void_p, c_void_p],
     )
+    _setsig(
+        _dylib.ireeCompilerInvocationExportStealModule,
+        c_void_p,
+        [c_void_p],
+    )
     # Source
     _setsig(_dylib.ireeCompilerSourceDestroy, None, [c_void_p])
     _setsig(
@@ -297,7 +302,7 @@
 
     @staticmethod
     def wrap_buffer(
-        session: Session, buffer, *, buffer_name: Optional[str] = None
+        session: Session, buffer, *, buffer_name: str = "source.mlir"
     ) -> "Source":
         view = memoryview(buffer)
         if not view.c_contiguous:
@@ -352,6 +357,20 @@
     def enable_console_diagnostics(self):
         _dylib.ireeCompilerInvocationEnableConsoleDiagnostics(self._inv_p)
 
+    def export_module(self):
+        """Exports the module."""
+        from .. import ir
+
+        if self._retained_module_op:
+            return self._retained_module_op
+        module_ptr = _dylib.ireeCompilerInvocationExportStealModule(self._inv_p)
+        if not module_ptr:
+            raise RuntimeError("Module is not available to export")
+        capsule = PyCapsule_New(module_ptr, MLIR_PYTHON_CAPSULE_OPERATION, None)
+        operation = ir.Operation._CAPICreate(capsule)
+        self._retained_module_op = operation
+        return operation
+
     def import_module(self, module_op) -> bool:
         self._retained_module_op = module_op
         # Import module.
diff --git a/compiler/bindings/python/test/api/api_test.py b/compiler/bindings/python/test/api/api_test.py
index b4de41f..ab70fcd 100644
--- a/compiler/bindings/python/test/api/api_test.py
+++ b/compiler/bindings/python/test/api/api_test.py
@@ -7,10 +7,12 @@
 # TODO: Upstream this to IREE.
 
 import platform
+
 if platform.system() == "Windows":
     print("WARNING: Test disabled on Windows due to suspected MSVC bug")
 else:
     from contextlib import closing
+    import os
     from pathlib import Path
     import tempfile
     import unittest
@@ -18,7 +20,6 @@
     from iree.compiler.api import *
     from iree.compiler import ir
 
-
     class DlFlagsTest(unittest.TestCase):
         def testDefaultFlags(self):
             session = Session()
@@ -47,12 +48,88 @@
             with self.assertRaises(ValueError):
                 session.set_flags("--does-not-exist=1")
 
-
     class DlInvocationTest(unittest.TestCase):
         def testCreate(self):
             session = Session()
             inv = session.invocation()
 
+        def testInputFile(self):
+            session = Session()
+            inv = session.invocation()
+            with tempfile.NamedTemporaryFile("w", delete=False) as tf:
+                tf.write("module {}")
+                tf.close()
+            try:
+                source = Source.open_file(session, tf.name)
+                inv.parse_source(source)
+            finally:
+                os.unlink(tf.name)
+            out = Output.open_membuffer()
+            inv.output_ir(out)
+            mem = out.map_memory()
+            self.assertIn(b"module", bytes(mem))
+            out.close()
+
+        def testInputBuffer(self):
+            session = Session()
+            inv = session.invocation()
+            source = Source.wrap_buffer(session, b"builtin.module {}")
+            inv.parse_source(source)
+            out = Output.open_membuffer()
+            inv.output_ir(out)
+            mem = out.map_memory()
+            self.assertIn(b"module", bytes(mem))
+            out.close()
+
+        def testOutputBytecode(self):
+            session = Session()
+            inv = session.invocation()
+            source = Source.wrap_buffer(session, b"builtin.module {}")
+            inv.parse_source(source)
+            out = Output.open_membuffer()
+            inv.output_ir_bytecode(out)
+            mem = out.map_memory()
+            self.assertIn(b"module", bytes(mem))
+            out.close()
+
+        def testExecutePassPipeline(self):
+            session = Session()
+            inv = session.invocation()
+            source = Source.wrap_buffer(
+                session,
+                b"""
+                builtin.module {
+                    func.func private @foobar() -> ()
+                }
+                """,
+            )
+            inv.parse_source(source)
+            inv.execute_text_pass_pipeline("symbol-dce")
+            out = Output.open_membuffer()
+            inv.output_ir(out)
+            mem = out.map_memory()
+            self.assertNotIn(b"func", bytes(mem))
+            out.close()
+
+        def testExecuteStdPipeline(self):
+            session = Session()
+            session.set_flags("--iree-hal-target-backends=vmvx")
+            inv = session.invocation()
+            source = Source.wrap_buffer(
+                session,
+                b"""
+                builtin.module {
+                    func.func @main(%arg0: i32) -> (i32) {
+                        return %arg0 : i32
+                    }
+                }
+                """,
+            )
+            inv.parse_source(source)
+            inv.execute()
+            out = Output.open_membuffer()
+            inv.output_vm_bytecode(out)
+            out.close()
 
     class DlOutputTest(unittest.TestCase):
         def testOpenMembuffer(self):
@@ -97,7 +174,6 @@
             finally:
                 Path(file_path).unlink()
 
-
     class DlInteropTest(unittest.TestCase):
         def testContextFromSession(self):
             s = Session()
@@ -123,53 +199,21 @@
                 print(contents)
                 self.assertIn('test.test = "working"', contents)
 
-
-    # TODO: Port these to the current API.
-    # class CompilerAPITest(unittest.TestCase):
-    #     def testCreate(self):
-    #         compiler = Compiler()
-
-    #     def testLoadFromBytes(self):
-    #         compiler = Compiler()
-    #         p = compiler.load_buffer("module {}".encode(), buffer_name="foobar")
-
-    #     def testPipelineClose(self):
-    #         compiler = Compiler()
-    #         p = compiler.load_buffer("module {}".encode(), buffer_name="foobar")
-    #         p.close()
-
-    #     def testLoadFromFile(self):
-    #         compiler = Compiler()
-    #         with tempfile.NamedTemporaryFile("w", delete=False) as tf:
-    #             tf.write("module {}")
-    #             tf.close()
-    #             p = compiler.load_file(tf.name)
-    #             p.close()
-
-    #     def testExecuteIR(self):
-    #         compiler = Compiler()
-    #         p = compiler.load_buffer("module {}".encode(), buffer_name="foobar")
-    #         p.execute()
-    #         with closing(compiler.open_output_membuffer()) as output:
-    #             p.output_ir(output)
-    #             ir_contents = bytes(output.map_memory())
-    #             print(ir_contents)
-    #             self.assertEqual(b"module {\n}", ir_contents)
-
-    #     def testExecuteVMFB(self):
-    #         compiler = Compiler()
-    #         compiler.set_flags("--iree-hal-target-backends=vmvx")
-    #         p = compiler.load_buffer(
-    #             "module {func.func @main(%arg0: i32) -> (i32) {return %arg0 : i32}}".encode(),
-    #             buffer_name="foobar",
-    #         )
-    #         p.execute()
-    #         with closing(compiler.open_output_membuffer()) as output:
-    #             p.output_vm_bytecode(output)
-    #             ir_contents = bytes(output.map_memory())
-    #             print(len(ir_contents))
-    #             self.assertGreater(len(ir_contents), 0)
-
+        def testExportModule(self):
+            s = Session()
+            with ir.Location.unknown(s.context):
+                source = Source.wrap_buffer(s, b"builtin.module {}")
+                inv = s.invocation()
+                self.assertTrue(inv.parse_source(source))
+                module_op = inv.export_module()
+                module_op.attributes["test.test"] = ir.Attribute.parse('"working"')
+                # Round-trip it back through an Output and verify that the attribute
+                # we set is still there.
+                output = Output.open_membuffer()
+                inv.output_ir(output)
+                contents = bytes(output.map_memory()).decode()
+                print(contents)
+                self.assertIn('test.test = "working"', contents)
 
     if __name__ == "__main__":
         unittest.main()
diff --git a/compiler/src/iree/compiler/API/Internal/Embed.cpp b/compiler/src/iree/compiler/API/Internal/Embed.cpp
index 36801a9..ce71aab 100644
--- a/compiler/src/iree/compiler/API/Internal/Embed.cpp
+++ b/compiler/src/iree/compiler/API/Internal/Embed.cpp
@@ -536,6 +536,7 @@
   bool initializeInvocation();
   std::unique_ptr<PassManager> createPassManager();
   bool parseSource(Source &source);
+  Operation *exportModule();
   bool importModule(Operation *inputModule, bool steal);
   bool runPipeline(enum iree_compiler_pipeline_t pipeline);
   bool runTextualPassPipeline(const char *textPassPipeline);
@@ -708,6 +709,13 @@
   return true;
 }
 
+Operation *Invocation::exportModule() {
+  if (!parsedModuleIsOwned)
+    return nullptr;
+  parsedModuleIsOwned = false;
+  return parsedModule;
+}
+
 bool Invocation::runPipeline(enum iree_compiler_pipeline_t pipeline) {
   auto passManager = createPassManager();
   switch (pipeline) {
@@ -1377,3 +1385,8 @@
                                              MlirOperation moduleOp) {
   return unwrap(inv)->importModule(unwrap(moduleOp), /*steal=*/true);
 }
+
+MlirOperation
+ireeCompilerInvocationExportStealModule(iree_compiler_invocation_t *inv) {
+  return wrap(unwrap(inv)->exportModule());
+}
diff --git a/compiler/src/iree/compiler/API/MLIRInterop.h b/compiler/src/iree/compiler/API/MLIRInterop.h
index 0a83670..a249854 100644
--- a/compiler/src/iree/compiler/API/MLIRInterop.h
+++ b/compiler/src/iree/compiler/API/MLIRInterop.h
@@ -65,6 +65,12 @@
 ireeCompilerInvocationImportStealModule(iree_compiler_invocation_t *inv,
                                         MlirOperation moduleOp);
 
+// Exports the owned module from the invocation, transferring ownership to the
+// caller.
+MLIR_CAPI_EXPORTED
+MlirOperation
+ireeCompilerInvocationExportStealModule(iree_compiler_invocation_t *inv);
+
 #ifdef __cplusplus
 }
 #endif