[iree.build] Implement iree-compile action. (#18993)

This covers a lot of ground to actually get the full compile pipeline
working:

* Adds dep metadata and uses it to plumb input type to the compiler.
* Fully generalizes arg parsing and management.
* Adds a skeleton TargetMachine (used to drive the compiler) and teaches
it how to handle/pass-through standard IREE flags for simple single
device configurations. (Complex cases will be handled with some kind of
machine spec file and additional compiler entry-points for setting raw
target info).
* Plumbs through both in process and out of process compiler invocation.
* Robustifies the iree.compiler.api so that it can't do out of order
destruction during process cleanup.

Much more to do:

* I've got enough now to add a torch.export action to iree-turbine. That
will be run against the out of process pool.
* Quite a bit of compiler invocation ergonomics needed (reproducers,
error/output handling, etc).
* Need to take a pass through and make pretty console reporting/logging
throughout.

---------

Signed-off-by: Stella Laurenzo <stellaraccident@gmail.com>
diff --git a/compiler/bindings/python/CMakeLists.txt b/compiler/bindings/python/CMakeLists.txt
index bc8119f..76caebb 100644
--- a/compiler/bindings/python/CMakeLists.txt
+++ b/compiler/bindings/python/CMakeLists.txt
@@ -249,11 +249,15 @@
 SOURCES
   __init__.py
   __main__.py
+  args.py
+  compile_actions.py
   executor.py
   lang.py
   main.py
+  metadata.py
   net_actions.py
   onnx_actions.py
+  target_machine.py
 )
 
 add_mlir_python_modules(IREECompilerBuildPythonModules
diff --git a/compiler/bindings/python/iree/build/__init__.py b/compiler/bindings/python/iree/build/__init__.py
index 95ee3f7..3cc054e 100644
--- a/compiler/bindings/python/iree/build/__init__.py
+++ b/compiler/bindings/python/iree/build/__init__.py
@@ -8,5 +8,7 @@
 
 from iree.build.lang import *
 from iree.build.main import *
+
+from iree.build.compile_actions import *
 from iree.build.net_actions import *
 from iree.build.onnx_actions import *
diff --git a/compiler/bindings/python/iree/build/args.py b/compiler/bindings/python/iree/build/args.py
new file mode 100644
index 0000000..9a2482a
--- /dev/null
+++ b/compiler/bindings/python/iree/build/args.py
@@ -0,0 +1,187 @@
+# Copyright 2024 The IREE Authors
+#
+# Licensed under the Apache License v2.0 with LLVM Exceptions.
+# See https://llvm.org/LICENSE.txt for license information.
+# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+
+from typing import Callable, Generator, TypeVar
+
+import argparse
+import contextlib
+import functools
+import inspect
+import threading
+
+from typing import Callable
+
+_locals = threading.local()
+_ALL_ARG_REGISTRARS: list[Callable[[argparse.ArgumentParser], None]] = []
+_ALL_ARG_HANDLERS: list[Callable[[argparse.Namespace], None]] = []
+
+
+def register_arg_parser_callback(registrar: Callable[[argparse.ArgumentParser], None]):
+    """Decorator that adds a global argument registration callback.
+
+    This callback will be invoked when a new ArgumentParser is constructed.
+    """
+    _ALL_ARG_REGISTRARS.append(registrar)
+    return registrar
+
+
+def register_arg_handler_callback(handler: Callable[[argparse.Namespace], None]):
+    """Decorator that registers a handler to be run on global arguments at startup."""
+    _ALL_ARG_HANDLERS.append(handler)
+    return handler
+
+
+def configure_arg_parser(p: argparse.ArgumentParser):
+    """Invokes all callbacks from `register_arg_parser_callback` on the parser."""
+    for callback in _ALL_ARG_REGISTRARS:
+        callback(p)
+
+
+def run_global_arg_handlers(ns: argparse.Namespace):
+    """Invokes all global argument handlers."""
+    for h in _ALL_ARG_HANDLERS:
+        h(ns)
+
+
+@contextlib.contextmanager
+def argument_namespace_context(ns: argparse.Namespace):
+    """Establishes the given namespace as the current namespace for this thread.
+
+    Note that as a thread local, this does not propagate to child threads or
+    sub-processes. This means that all argument management must be done during
+    action setup; action invocations will not typically have access to args.
+    """
+    if not hasattr(_locals, "arg_ns_stack"):
+        _locals.arg_ns_stack = []
+    _locals.arg_ns_stack.append(ns)
+    try:
+        yield ns
+    finally:
+        _locals.arg_ns_stack.pop()
+
+
+def current_args_namespace() -> argparse.Namespace:
+    try:
+        return _locals.arg_ns_stack[-1]
+    except (AttributeError, IndexError):
+        raise AssertionError(
+            "No current argument namespace: Is it possible you are trying to resolve "
+            "arguments from another thread or process"
+        )
+
+
+_Decorated = TypeVar("_Decorated", bound=Callable)
+
+
+def expand_cl_arg_defaults(wrapped: _Decorated) -> _Decorated:
+    """Decorator that resolves `ClArgRef` argument values from the current args."""
+    sig = inspect.signature(wrapped)
+
+    @functools.wraps(wrapped)
+    def wrapper(*args, **kwargs):
+        args_ns = current_args_namespace()
+        # Apply defaults so that ClArgRef default values get resolved as well.
+        bound = sig.bind(*args, **kwargs)
+        bound.apply_defaults()
+
+        def resolve(arg):
+            return arg.resolve(args_ns) if isinstance(arg, ClArgRef) else arg
+
+        new_args = [resolve(arg) for arg in bound.args]
+        new_kwargs = {k: resolve(v) for k, v in bound.kwargs.items()}
+        return wrapped(*new_args, **new_kwargs)
+
+    return wrapper
+
+
+class ClArgRef:
+    """Used in default values of function arguments to indicate that the default should
+    be derived from an argument reference.
+
+    Actually defining the argument must be done elsewhere.
+
+    See `cl_arg_ref()` for canonical use.
+    """
+
+    def __init__(self, dest: str):
+        self.dest = dest
+
+    def resolve(self, arg_namespace: argparse.Namespace):
+        try:
+            return getattr(arg_namespace, self.dest)
+        except AttributeError as e:
+            raise RuntimeError(
+                f"Unable to resolve command line argument '{self.dest}' in namespace"
+            ) from e
+
+
+def cl_arg_ref(dest: str):
+    """Used as a default value for functions wrapped in @expand_cl_arg_defaults to
+    indicate that an argument must come from the command line environment.
+
+    Note that this does not have a typing annotation, allowing the argument to be
+    annotated with a type, assuming that resolution will happen dynamically in some
+    fashion.
+    """
+    return ClArgRef(dest)
+
+
+class ClArg(ClArgRef):
+    """Used in default values of function arguments to indicate that an argument needs
+    to be defined and referenced.
+
+    This is used in user-defined entry points, and the executor has special logic to
+    collect all needed arguments automatically.
+
+    See `cl_arg()` for canonical use.
+    """
+
+    def __init__(self, name, dest: str, **add_argument_kw):
+        super().__init__(dest)
+        self.name = name
+        self.add_argument_kw = add_argument_kw
+
+    def define_arg(self, parser: argparse.ArgumentParser):
+        parser.add_argument(f"--{self.name}", dest=self.dest, **self.add_argument_kw)
+
+
+def cl_arg(name: str, *, action=None, default=None, type=None, help=None):
+    """Used to define or reference a command-line argument from within actions
+    and entry-points.
+
+    Keywords have the same interpretation as `ArgumentParser.add_argument()`.
+
+    Any ClArg set as a default value for an argument to an `entrypoint` will be
+    added to the global argument parser. Any particular argument name can only be
+    registered once and must not conflict with a built-in command line option.
+    The implication of this is that for single-use arguments, the `=cl_arg(...)`
+    can just be added as a default argument. Otherwise, for shared arguments,
+    it should be created at the module level and referenced.
+
+    When called, any entrypoint arguments that do not have an explicit keyword
+    set will get their value from the command line environment.
+
+    Note that this does not have a typing annotation, allowing the argument to be
+    annotated with a type, assuming that resolution will happen dynamically in some
+    fashion.
+    """
+    if name.startswith("-"):
+        raise ValueError("cl_arg name must not be prefixed with dashes")
+    dest = name.replace("-", "_")
+    return ClArg(name, action=action, default=default, type=type, dest=dest, help=help)
+
+
+def extract_cl_arg_defs(callable: Callable) -> Generator[ClArg, None, None]:
+    """Yields each `ClArg` that is used as a parameter default of a callable.
+
+    This is used in order to eagerly register argument definitions for some set
+    of functions.
+    """
+    sig = inspect.signature(callable)
+    # Parameters without a default hold `inspect.Parameter.empty`, which is never
+    # a ClArg, so the isinstance check skips them.
+    defaults = (p.default for p in sig.parameters.values())
+    yield from (d for d in defaults if isinstance(d, ClArg))
diff --git a/compiler/bindings/python/iree/build/compile_actions.py b/compiler/bindings/python/iree/build/compile_actions.py
new file mode 100644
index 0000000..2c2056c
--- /dev/null
+++ b/compiler/bindings/python/iree/build/compile_actions.py
@@ -0,0 +1,208 @@
+# Copyright 2024 The IREE Authors
+#
+# Licensed under the Apache License v2.0 with LLVM Exceptions.
+# See https://llvm.org/LICENSE.txt for license information.
+# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+
+import argparse
+import shlex
+
+import iree.compiler.api as compiler_api
+import iree.compiler.tools as compiler_tools
+
+from iree.build.args import (
+    expand_cl_arg_defaults,
+    register_arg_handler_callback,
+    register_arg_parser_callback,
+    cl_arg_ref,
+)
+
+from iree.build.executor import (
+    BuildAction,
+    BuildContext,
+    BuildFile,
+    BuildFileLike,
+    FileNamespace,
+)
+
+from iree.build.metadata import CompileSourceMeta
+from iree.build.target_machine import compute_target_machines_from_flags
+
+__all__ = [
+    "compile",
+]
+
+
+@register_arg_parser_callback
+def _(p: argparse.ArgumentParser):
+    g = p.add_argument_group(
+        title="IREE Compiler Options",
+        description="Global options controlling invocation of iree-compile",
+    )
+    g.add_argument(
+        "--iree-compile-out-of-process",
+        action=argparse.BooleanOptionalAction,
+        help="Invokes iree-compile as an out of process executable (the default is to "
+        "invoke it in-process via API bindings). This can make debugging somewhat "
+        "easier and also grants access to global command line options that may not "
+        "otherwise be available.",
+    )
+    g.add_argument(
+        "--iree-compile-extra-args",
+        help="Extra arguments to pass to iree-compile. When running in-process, these "
+        "will be passed as globals to the library and affect all compilation in the "
+        "process. These are split with shlex rules.",
+    )
+
+
+@register_arg_handler_callback
+def _(ns: argparse.Namespace):
+    in_process = not ns.iree_compile_out_of_process
+    extra_args_str = ns.iree_compile_extra_args
+    if in_process and extra_args_str:
+        # TODO: This is very unsafe. If called multiple times (i.e. in a library)
+        # or with illegal arguments, the program will abort. The safe way to do
+        # this is to spawn one child process and route all "in process" compilation
+        # there. This would allow explicit control of startup/shutdown and would
+        # provide isolation in the event of a compiler crash. It is still important
+        # for a single process to handle all compilation activities since this
+        # allows global compiler resources (like threads) to be pooled and not
+        # saturate the machine resources.
+        extra_args_list = shlex.split(extra_args_str)
+        compiler_api._initializeGlobalCL("unused_prog_name", *extra_args_list)
+
+
+class CompilerInvocation:
+    """Describes and runs one iree-compile invocation for a single output file."""
+
+    @expand_cl_arg_defaults
+    def __init__(
+        self,
+        *,
+        input_file: BuildFile,
+        output_file: BuildFile,
+        out_of_process: bool = cl_arg_ref("iree_compile_out_of_process"),
+        extra_args_str=cl_arg_ref("iree_compile_extra_args"),
+    ):
+        self.input_file = input_file
+        self.output_file = output_file
+        # We manage most flags as keyword values that can have at most one
+        # setting.
+        self.kw_flags: dict[str, str | None] = {}
+        # Flags can also be set free-form. These are always added to the command
+        # line after the kw_flags.
+        self.extra_flags: list[str] = []
+        self.out_of_process = out_of_process
+        # User supplied extra args (split with shlex rules). Within this class
+        # they are only consumed by the out of process path.
+        self.extra_args = shlex.split(extra_args_str) if extra_args_str else []
+
+    def run(self):
+        """Assembles flags (input type, `kw_flags`, `extra_flags`) and compiles."""
+        # Set any defaults derived from the input_file metadata. These are set
+        # first because they can be overridden by explicit flag settings.
+        meta = CompileSourceMeta.get(self.input_file)
+        raw_flags: list[str] = [f"--iree-input-type={meta.input_type}"]
+
+        # Process kw_flags. A value of None emits a bare `--key` switch.
+        raw_flags.extend(
+            f"--{key}" if value is None else f"--{key}={value}"
+            for key, value in self.kw_flags.items()
+        )
+
+        # Process extra_flags.
+        raw_flags.extend(self.extra_flags)
+
+        # Dispatch on the configured execution mode.
+        if self.out_of_process:
+            self.run_out_of_process(raw_flags)
+        else:
+            self.run_inprocess(raw_flags)
+
+    def run_inprocess(self, flags: list[str]):
+        """Compiles via the in-process `iree.compiler.api` bindings."""
+        # The session, invocation, source and output are all scoped to this call.
+        with compiler_api.Session() as session:
+            session.set_flags(*flags)
+            with compiler_api.Invocation(session) as inv, compiler_api.Source.open_file(
+                session, str(self.input_file.get_fs_path())
+            ) as source, compiler_api.Output.open_file(
+                str(self.output_file.get_fs_path())
+            ) as output:
+                inv.enable_console_diagnostics()
+                inv.parse_source(source)
+                if not inv.execute():
+                    raise RuntimeError("COMPILE FAILED (TODO)")
+                inv.output_vm_bytecode(output)
+                output.keep()
+
+    def run_out_of_process(self, flags: list[str]):
+        """Runs `compiler_tools.compile_file` with the extra args followed by flags."""
+        # TODO: This Python executable wrapper is really long in the tooth. We should
+        # just invoke iree-compile directly (which would also let us have a flag for
+        # the path to it).
+        compiler_tools.compile_file(
+            str(self.input_file.get_fs_path()),
+            output_file=str(self.output_file.get_fs_path()),
+            extra_args=self.extra_args + flags,
+        )
+
+
+def compile(
+    *,
+    name: str,
+    source: BuildFileLike,
+    target_default: bool = True,
+) -> list[BuildFile] | BuildFile:
+    """Invokes iree-compile on a source file, producing binaries for one or more target
+    machines.
+
+    Args:
+      name: The logical name of the compilation command. This is used as the stem
+        for multiple kinds of output files.
+      source: Input source file.
+      target_default: Whether to use command line arguments to compute a target
+        machine configuration (default True). This would be set to False to explicitly
+        depend on target information contained in the source file and not require
+        any target flags passed to the build tool.
+
+    Returns:
+      A list with one output file per target machine when `target_default` is True,
+      otherwise the single output file.
+    """
+    context = BuildContext.current()
+    input_file = context.file(source)
+    if target_default:
+        # Compute the target machines from flags and create one compilation for each.
+        tms = compute_target_machines_from_flags()
+        output_files: list[BuildFile] = []
+        for tm in tms:
+            output_file = context.allocate_file(
+                f"{name}_{tm.target_spec}.vmfb", namespace=FileNamespace.BIN
+            )
+            inv = CompilerInvocation(input_file=input_file, output_file=output_file)
+            inv.extra_flags.extend(tm.flag_list)
+            CompileAction(
+                inv,
+                desc=f"Compiling {input_file} (for {tm.target_spec})",
+                executor=context.executor,
+            )
+            output_files.append(output_file)
+        return output_files
+    else:
+        # The compilation is self contained, so just directly compile it.
+        output_file = context.allocate_file(f"{name}.vmfb", namespace=FileNamespace.BIN)
+        inv = CompilerInvocation(input_file=input_file, output_file=output_file)
+        CompileAction(inv, desc=f"Compiling {name}", executor=context.executor)
+        return output_file
+
+
+class CompileAction(BuildAction):
+    def __init__(self, inv: CompilerInvocation, **kwargs):
+        super().__init__(**kwargs)
+        self.inv = inv
+        self.inv.output_file.deps.add(self)
+        self.deps.add(self.inv.input_file)
+
+    def _invoke(self):
+        self.inv.run()
diff --git a/compiler/bindings/python/iree/build/executor.py b/compiler/bindings/python/iree/build/executor.py
index 35b9173..1d580bd 100644
--- a/compiler/bindings/python/iree/build/executor.py
+++ b/compiler/bindings/python/iree/build/executor.py
@@ -4,20 +4,23 @@
 # See https://llvm.org/LICENSE.txt for license information.
 # SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
 
-from typing import Callable, Collection, Generator, IO
+from typing import Callable, Collection, IO, Type, TypeVar
 
 import abc
-import argparse
 import concurrent.futures
 import enum
-import inspect
 import multiprocessing
-import sys
 import time
 import traceback
 from pathlib import Path
 import threading
 
+from iree.build.args import (
+    current_args_namespace,
+    expand_cl_arg_defaults,
+    extract_cl_arg_defs,
+)
+
 _locals = threading.local()
 
 
@@ -52,24 +55,6 @@
     return f"{prefix}/{suffix}"
 
 
-class ClArg:
-    def __init__(self, name, dest: str, **add_argument_kw):
-        self.name = name
-        self.dest = dest
-        self.add_argument_kw = add_argument_kw
-
-    def define_arg(self, parser: argparse.ArgumentParser):
-        parser.add_argument(f"--{self.name}", dest=self.dest, **self.add_argument_kw)
-
-    def resolve(self, arg_namespace: argparse.Namespace):
-        try:
-            return getattr(arg_namespace, self.dest)
-        except AttributeError as e:
-            raise RuntimeError(
-                f"Unable to resolve command line argument '{self.dest}' in namespace"
-            ) from e
-
-
 class Entrypoint:
     def __init__(
         self,
@@ -79,17 +64,12 @@
     ):
         self.name = name
         self.description = description
-        self._wrapped = wrapped
-
-    def cl_args(self) -> Generator[ClArg, None, None]:
-        sig = inspect.signature(self._wrapped)
-        for p in sig.parameters.values():
-            def_value = p.default
-            if isinstance(def_value, ClArg):
-                yield def_value
+        self.cl_arg_defs = list(extract_cl_arg_defs(wrapped))
+        self._wrapped = expand_cl_arg_defaults(wrapped)
 
     def __call__(self, *args, **kwargs):
         parent_context = BuildContext.current()
+        args_ns = current_args_namespace()
         bep = BuildEntrypoint(
             join_namespace(parent_context.path, self.name),
             parent_context.executor,
@@ -97,18 +77,7 @@
         )
         parent_context.executor.entrypoints.append(bep)
         with bep:
-            sig = inspect.signature(self._wrapped)
-            bound = sig.bind(*args, **kwargs)
-            bound.apply_defaults()
-
-            def filter(arg):
-                if isinstance(arg, ClArg):
-                    return arg.resolve(parent_context.executor.args_namespace)
-                return arg
-
-            new_args = [filter(arg) for arg in bound.args]
-            new_kwargs = {k: filter(v) for k, v in bound.kwargs.items()}
-            results = self._wrapped(*new_args, **new_kwargs)
+            results = self._wrapped(*args, **kwargs)
             if results is not None:
                 files = bep.files(results)
                 bep.deps.update(files)
@@ -119,15 +88,12 @@
 class Executor:
     """Executor that all build contexts share."""
 
-    def __init__(
-        self, output_dir: Path, args_namespace: argparse.Namespace, stderr: IO
-    ):
+    def __init__(self, output_dir: Path, stderr: IO):
         self.output_dir = output_dir
         self.verbose_level = 0
         # Keyed by path
         self.all: dict[str, "BuildContext" | "BuildFile"] = {}
         self.entrypoints: list["BuildEntrypoint"] = []
-        self.args_namespace = args_namespace
         self.stderr = stderr
         BuildContext("", self)
 
@@ -189,6 +155,41 @@
             scheduler.shutdown()
 
 
+BuildMetaType = TypeVar("BuildMetaType", bound="BuildMeta")
+
+
+class BuildMeta:
+    """Base class for typed metadata that can be set on a BuildDependency.
+
+    This is an open namespace: each sub-class must set a unique class attribute `KEY`.
+    """
+
+    __slots__ = ()  # Lets `__slots__` in sub-classes reject undefined attributes.
+
+    def __init__(self):
+        key = getattr(self, "KEY", None)
+        assert isinstance(key, str), "BuildMeta.KEY must be a str"
+
+    @classmethod
+    def get(cls: Type[BuildMetaType], dep: "BuildDependency") -> BuildMetaType:
+        """Gets a metadata instance of this type from a dependency.
+
+        If it does not yet exist, returns the value of `create_default()`, which
+        by default returns a new instance (which is set on the dep).
+        """
+        key = getattr(cls, "KEY", None)
+        assert isinstance(key, str), f"{cls.__name__}.KEY must be a str"
+        instance = dep._metadata.get(key)
+        if instance is None:
+            instance = dep._metadata[key] = cls.create_default()
+        return instance
+
+    @classmethod
+    def create_default(cls: Type[BuildMetaType]) -> BuildMetaType:
+        """Creates a default instance."""
+        return cls()
+
+
 class BuildDependency:
     """Base class of entities that can act as a build dependency."""
 
@@ -205,6 +206,9 @@
         self.start_time: float | None = None
         self.finish_time: float | None = None
 
+        # Metadata.
+        self._metadata: dict[str, BuildMeta] = {}
+
     @property
     def is_scheduled(self) -> bool:
         return self.future is not None
@@ -287,8 +291,11 @@
     def __repr__(self):
         return f"Action[{type(self).__name__}]('{self.desc}')"
 
-    @abc.abstractmethod
     def invoke(self):
+        self._invoke()
+
+    @abc.abstractmethod
+    def _invoke(self):
         ...
 
 
@@ -317,7 +324,8 @@
         """
         if not path.startswith("/"):
             path = join_namespace(self.path, path)
-        return BuildFile(executor=self.executor, path=path, namespace=namespace)
+        build_file = BuildFile(executor=self.executor, path=path, namespace=namespace)
+        return build_file
 
     def file(self, file: str | BuildFile) -> BuildFile:
         """Accesses a BuildFile by either string (path) or BuildFile.
@@ -374,9 +382,6 @@
         existing = stack.pop()
         assert existing is self, "Unbalanced BuildContext enter/exit"
 
-    def populate_arg_parser(self, parser: argparse.ArgumentParser):
-        ...
-
 
 class BuildEntrypoint(BuildContext):
     def __init__(self, path: str, executor: Executor, entrypoint: Entrypoint):
@@ -506,7 +511,16 @@
                 return dep
 
             print(f"Scheduling action: {dep}", file=self.stderr)
-            dep.start(self.thread_pool_executor.submit(invoke))
+            if dep.concurrnecy == ActionConcurrency.NONE:
+                invoke()
+            elif dep.concurrnecy == ActionConcurrency.THREAD:
+                dep.start(self.thread_pool_executor.submit(invoke))
+            elif dep.concurrnecy == ActionConcurrency.PROCESS:
+                dep.start(self.process_pool_executor.submit(invoke))
+            else:
+                raise AssertionError(
+                    f"Unhandled ActionConcurrency value: {dep.concurrnecy}"
+                )
         else:
             # Not schedulable. Just mark it as done.
             dep.start(concurrent.futures.Future())
diff --git a/compiler/bindings/python/iree/build/lang.py b/compiler/bindings/python/iree/build/lang.py
index 5cb8779..60f4c9f 100644
--- a/compiler/bindings/python/iree/build/lang.py
+++ b/compiler/bindings/python/iree/build/lang.py
@@ -9,7 +9,8 @@
 import argparse
 import functools
 
-from iree.build.executor import ClArg, Entrypoint
+from iree.build.args import cl_arg  # Export as part of the public API
+from iree.build.executor import Entrypoint
 
 __all__ = [
     "cl_arg",
@@ -28,25 +29,3 @@
     target = Entrypoint(f.__name__, f, description=description)
     functools.wraps(target, f)
     return target
-
-
-def cl_arg(name: str, *, action=None, default=None, type=None, help=None):
-    """Used to define or reference a command-line argument from within actions
-    and entry-points.
-
-    Keywords have the same interpretation as `ArgumentParser.add_argument()`.
-
-    Any ClArg set as a default value for an argument to an `entrypoint` will be
-    added to the global argument parser. Any particular argument name can only be
-    registered once and must not conflict with a built-in command line option.
-    The implication of this is that for single-use arguments, the `=cl_arg(...)`
-    can just be added as a default argument. Otherwise, for shared arguments,
-    it should be created at the module level and referenced.
-
-    When called, any entrypoint arguments that do not have an explicit keyword
-    set will get their value from the command line environment.
-    """
-    if name.startswith("-"):
-        raise ValueError("cl_arg name must not be prefixed with dashes")
-    dest = name.replace("-", "_")
-    return ClArg(name, action=action, default=default, type=type, dest=dest, help=help)
diff --git a/compiler/bindings/python/iree/build/main.py b/compiler/bindings/python/iree/build/main.py
index 272533c..6d40d18 100644
--- a/compiler/bindings/python/iree/build/main.py
+++ b/compiler/bindings/python/iree/build/main.py
@@ -12,7 +12,12 @@
 from pathlib import Path
 import sys
 
-from iree.build.executor import BuildEntrypoint, Entrypoint, Executor
+from iree.build.args import (
+    argument_namespace_context,
+    configure_arg_parser,
+    run_global_arg_handlers,
+)
+from iree.build.executor import BuildEntrypoint, BuildFile, Entrypoint, Executor
 
 __all__ = [
     "iree_build_main",
@@ -117,6 +122,7 @@
             help="Paths of actions to build (default to top-level actions)",
         )
 
+        configure_arg_parser(p)
         self._define_action_arguments(p)
         self.args = self.arg_parser.parse_args(args)
 
@@ -126,7 +132,7 @@
     def _define_action_arguments(self, p: argparse.ArgumentParser):
         user_group = p.add_argument_group("Action defined options")
         for ep in self.top_module.entrypoints.values():
-            for cl_arg in ep.cl_args():
+            for cl_arg in ep.cl_arg_defs:
                 cl_arg.define_arg(user_group)
 
     def _resolve_module_arguments(
@@ -169,15 +175,17 @@
         return rem_args, top_module
 
     def _create_executor(self) -> Executor:
-        executor = Executor(self.args.output_dir, self.args, stderr=self.stderr)
+        executor = Executor(self.args.output_dir, stderr=self.stderr)
         executor.analyze(*self.top_module.entrypoints.values())
         return executor
 
     def run(self):
-        command = self.args.command
-        if command is None:
-            command = self.build_command
-        command()
+        with argument_namespace_context(self.args):
+            run_global_arg_handlers(self.args)
+            command = self.args.command
+            if command is None:
+                command = self.build_command
+            command()
 
     def build_command(self):
         executor = self._create_executor()
@@ -203,7 +211,9 @@
         for build_action in build_actions:
             if isinstance(build_action, BuildEntrypoint):
                 for output in build_action.outputs:
-                    print(f"{output.get_fs_path()}", file=self.stdout)
+                    print(output.get_fs_path(), file=self.stdout)
+            elif isinstance(build_action, BuildFile):
+                print(build_action.get_fs_path(), file=self.stdout)
 
     def list_command(self):
         executor = self._create_executor()
diff --git a/compiler/bindings/python/iree/build/metadata.py b/compiler/bindings/python/iree/build/metadata.py
new file mode 100644
index 0000000..91767dd
--- /dev/null
+++ b/compiler/bindings/python/iree/build/metadata.py
@@ -0,0 +1,37 @@
+# Copyright 2024 The IREE Authors
+#
+# Licensed under the Apache License v2.0 with LLVM Exceptions.
+# See https://llvm.org/LICENSE.txt for license information.
+# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+
+"""Common `BuildMeta` subclasses for built-in actions.
+
+These are maintained here purely as an aid to avoiding circular dependencies.
+Typically, in out of tree actions, they would just be inlined into the implementation
+file.
+"""
+
+from .executor import BuildMeta
+
+
+class CompileSourceMeta(BuildMeta):
+    """CompileSourceMeta tracks source level properties that can influence compilation.
+
+    This meta can be set on any dependency that ultimately is used as a source to a
+    `compile` action.
+    """
+
+    # Slots catch attempts to set undefined attributes, but only when every base
+    # class also declares `__slots__` (otherwise instances still get a `__dict__`).
+    __slots__ = [
+        "input_type",
+    ]
+    KEY = "iree.compile.source"
+
+    def __init__(self):
+        super().__init__()
+        # The value to the --iree-input-type= flag for this source file.
+        self.input_type: str = "auto"
+
+    def __repr__(self):
+        return f"CompileSourceMeta(input_type={self.input_type})"
diff --git a/compiler/bindings/python/iree/build/net_actions.py b/compiler/bindings/python/iree/build/net_actions.py
index 1d2e158..da74d9a 100644
--- a/compiler/bindings/python/iree/build/net_actions.py
+++ b/compiler/bindings/python/iree/build/net_actions.py
@@ -30,7 +30,7 @@
         self.url = url
         self.output_file = output_file
 
-    def invoke(self):
+    def _invoke(self):
         path = self.output_file.get_fs_path()
         self.executor.write_status(f"Fetching URL: {self.url} -> {path}")
         try:
diff --git a/compiler/bindings/python/iree/build/onnx_actions.py b/compiler/bindings/python/iree/build/onnx_actions.py
index 5220890..83ec3a1 100644
--- a/compiler/bindings/python/iree/build/onnx_actions.py
+++ b/compiler/bindings/python/iree/build/onnx_actions.py
@@ -5,6 +5,7 @@
 # SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
 
 from iree.build.executor import BuildAction, BuildContext, BuildFile, BuildFileLike
+from iree.build.metadata import CompileSourceMeta
 
 __all__ = [
     "onnx_import",
@@ -30,7 +31,7 @@
             input_file=input_file,
             output_file=processed_file,
             executor=context.executor,
-            desc=f"Upgrading ONNX {name}",
+            desc=f"Upgrading ONNX {input_file} -> {processed_file}",
             deps=[
                 input_file,
             ],
@@ -41,14 +42,13 @@
     ImportOnnxAction(
         input_file=input_file,
         output_file=output_file,
-        desc=f"Importing ONNX {name}",
+        desc=f"Importing ONNX {name} -> {output_file}",
         executor=context.executor,
         deps=[
             input_file,
         ],
     )
 
-    output_file.deps.add(processed_file)
     return output_file
 
 
@@ -57,9 +57,11 @@
         super().__init__(**kwargs)
         self.input_file = input_file
         self.output_file = output_file
+        self.deps.add(self.input_file)
         output_file.deps.add(self)
+        CompileSourceMeta.get(output_file).input_type = "onnx"
 
-    def invoke(self):
+    def _invoke(self):
         import onnx
 
         input_path = self.input_file.get_fs_path()
@@ -75,9 +77,11 @@
         super().__init__(**kwargs)
         self.input_file = input_file
         self.output_file = output_file
+        self.deps.add(input_file)
         output_file.deps.add(self)
+        CompileSourceMeta.get(output_file).input_type = "onnx"
 
-    def invoke(self):
+    def _invoke(self):
         import iree.compiler.tools.import_onnx.__main__ as m
 
         args = m.parse_arguments(
diff --git a/compiler/bindings/python/iree/build/target_machine.py b/compiler/bindings/python/iree/build/target_machine.py
new file mode 100644
index 0000000..c8eca7d
--- /dev/null
+++ b/compiler/bindings/python/iree/build/target_machine.py
@@ -0,0 +1,190 @@
+# Copyright 2024 The IREE Authors
+#
+# Licensed under the Apache License v2.0 with LLVM Exceptions.
+# See https://llvm.org/LICENSE.txt for license information.
+# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+
+"""Handles the messy affair of deriving options for targeting machines."""
+
+import argparse
+
+from iree.build.args import (
+    expand_cl_arg_defaults,
+    register_arg_parser_callback,
+    cl_arg_ref,
+)
+
+
+class TargetMachine:
+    """One compilation target: a spec name plus the iree-compile flags for it."""
+
+    def __init__(
+        self,
+        target_spec: str,
+        *,
+        iree_compile_device_type: str | None = None,
+        extra_flags: list[str] | None = None,
+    ):
+        self.target_spec = target_spec
+        self.iree_compile_device_type = iree_compile_device_type
+        self.extra_flags = extra_flags
+
+    @property
+    def flag_list(self) -> list[str]:
+        """Flags for iree-compile; raises RuntimeError if no device type is set."""
+        if self.iree_compile_device_type is None:
+            raise RuntimeError(f"Cannot compute iree-compile flags for: {self}")
+        # Hard-coded machine model: one IREE device type alias, default config.
+        device_flag = f"--iree-hal-target-device={self.iree_compile_device_type}"
+        return [device_flag, *(self.extra_flags or [])]
+
+    def __repr__(self):
+        parts = [str(self.target_spec)]
+        if self.iree_compile_device_type is not None:
+            parts.append(f"iree_compile_device_type='{self.iree_compile_device_type}'")
+        if self.extra_flags:
+            parts.append(f"extra_flags={self.extra_flags}")
+        return f"TargetMachine({', '.join(parts)})"
+
+
+################################################################################
+# Handling of --iree-hal-target-device from flags
+################################################################################
+
+
+HAL_TARGET_DEVICES_FROM_FLAGS_HANDLERS = {}
+
+
+def handle_hal_target_devices_from_flags(*mnemonics: str):
+    # Decorator factory: registers the decorated function for each mnemonic.
+    def register(handler):
+        HAL_TARGET_DEVICES_FROM_FLAGS_HANDLERS.update(dict.fromkeys(mnemonics, handler))
+        return handler
+
+    return register
+
+
+def handle_unknown_hal_target_device(mnemonic: str) -> list[TargetMachine]:
+    return [TargetMachine(mnemonic, iree_compile_device_type=mnemonic)]  # Pass-through.
+
+
+@handle_hal_target_devices_from_flags("amdgpu", "hip")
+@expand_cl_arg_defaults
+def amdgpu_hal_target_from_flags(
+    mnemonic: str, *, amdgpu_target=cl_arg_ref("iree_amdgpu_target")
+) -> list[TargetMachine]:
+    """Builds one AMDGPU TargetMachine from `amdgpu_target`; raises if it is unset."""
+    if amdgpu_target:
+        machine = TargetMachine(
+            f"amdgpu-{amdgpu_target}",
+            iree_compile_device_type="amdgpu",
+            extra_flags=[f"--iree-hip-target={amdgpu_target}"],
+        )
+        return [machine]
+    raise RuntimeError(
+        "No AMDGPU targets specified. Pass a chip to target as "
+        "--iree-amdgpu-target=gfx..."
+    )
+
+
+@handle_hal_target_devices_from_flags("llvm-cpu", "cpu")
+@expand_cl_arg_defaults
+def cpu_hal_target_from_flags(
+    mnemonic: str,
+    *,
+    cpu=cl_arg_ref("iree_llvmcpu_target_cpu"),
+    features=cl_arg_ref("iree_llvmcpu_target_cpu_features"),
+) -> list[TargetMachine]:
+    """Builds one llvm-cpu TargetMachine, forwarding `cpu`/`features` when set.
+
+    Only `cpu` names the spec ('generic' if unset); `features` is flag-only.
+    """
+    extra_flags = []
+    if cpu:
+        extra_flags.append(f"--iree-llvmcpu-target-cpu={cpu}")
+    if features:
+        extra_flags.append(f"--iree-llvmcpu-target-cpu-features={features}")
+    return [
+        TargetMachine(
+            f"cpu-{cpu or 'generic'}",
+            iree_compile_device_type="llvm-cpu",
+            extra_flags=extra_flags,
+        )
+    ]
+
+
+################################################################################
+# Flag definition
+################################################################################
+
+
+@register_arg_parser_callback
+def _(p: argparse.ArgumentParser):
+    """Registers the target-machine flags (HAL device, AMDGPU, CPU) on `p`."""
+    g = p.add_argument_group(
+        title="IREE Target Machine Options",
+        description="Global options controlling invocation of iree-compile",
+    )
+    g.add_argument(
+        "--iree-hal-target-device",
+        help="Compiles with a single machine model and a single specified device"
+        " (mutually exclusive with other ways to set the machine target). This "
+        "emulates the simple case of device targeting if invoking `iree-compile` "
+        "directly and is mostly a pass-through which also enforces other flags "
+        "depending on the value given. Supported options (or any supported by the "
+        "compiler): "
+        + ", ".join(sorted(set(HAL_TARGET_DEVICES_FROM_FLAGS_HANDLERS) - {"default"})),
+        nargs="*",
+    )
+    hip_g = p.add_argument_group(
+        title="IREE AMDGPU Target Options",
+        description="Options controlling explicit targeting of AMDGPU devices",
+    )
+    hip_g.add_argument(
+        "--iree-amdgpu-target",
+        "--iree-hip-target",
+        help="AMDGPU target selection (i.e. 'gfxYYYY')",
+    )
+
+    cpu_g = p.add_argument_group(
+        title="IREE CPU Target Options",
+        description="These are mostly pass-through. See `iree-compile --help` for "
+        "full information. Advanced usage will require an explicit machine config "
+        "file",
+    )
+    cpu_g.add_argument(
+        "--iree-llvmcpu-target-cpu",
+        help="'generic', 'host', or an explicit CPU name. See iree-compile help.",
+    )
+    cpu_g.add_argument(
+        "--iree-llvmcpu-target-cpu-features",
+        help="Comma separated list of '+' prefixed CPU features. See iree-compile help.",
+    )
+
+
+################################################################################
+# Global flag dispatch
+################################################################################
+
+
+@expand_cl_arg_defaults
+def compute_target_machines_from_flags(
+    *,
+    explicit_hal_target_devices: list[str]
+    | None = cl_arg_ref("iree_hal_target_device"),
+) -> list[TargetMachine]:
+    """Maps each requested HAL target device to its TargetMachine(s).
+
+    Raises RuntimeError if none was requested (flag absent or given no values).
+    """
+    if not explicit_hal_target_devices:
+        raise RuntimeError(
+            "iree-compile target information is required but none was provided. "
+            "See flags: --iree-hal-target-device"
+        )
+    machines = []
+    for device in explicit_hal_target_devices:
+        # Devices without a dedicated handler use the generic fallback handler.
+        handler = HAL_TARGET_DEVICES_FROM_FLAGS_HANDLERS.get(device)
+        machines.extend((handler or handle_unknown_hal_target_device)(device))
+    return machines
diff --git a/compiler/bindings/python/iree/compiler/api/ctypes_dl.py b/compiler/bindings/python/iree/compiler/api/ctypes_dl.py
index 784fc61..26f7228 100644
--- a/compiler/bindings/python/iree/compiler/api/ctypes_dl.py
+++ b/compiler/bindings/python/iree/compiler/api/ctypes_dl.py
@@ -189,6 +189,11 @@
     return len(view) >= 4 and view[:4].hex() == "4d4cef52"
 
 
+class SessionObject:
+    def close(self):
+        """No-op hook; Session.close() calls this on each registered dependent."""
+
+
 class Session:
     def __init__(self):
         self._global_init = _global_init
@@ -197,12 +202,27 @@
         # its ownership of it, so we must cache the new Python-level MLIRContext
         # so its lifetime extends at least to our own.
         self._owned_context = None
+        self._dependents: set[SessionObject] = set()
 
     def __del__(self):
-        _dylib.ireeCompilerSessionDestroy(self._session_p)
+        self.close()
+
+    def close(self):
+        if self._session_p:  # Reset to a null pointer below, so re-close is a no-op.
+            for dep in list(self._dependents):  # Copy: set may change while closing.
+                dep.close()
+            _dylib.ireeCompilerSessionDestroy(self._session_p)
+            self._session_p = c_void_p()
+
+    def __enter__(self) -> "Session":
+        return self
+
+    def __exit__(self, exc_type, exc_val, exc_tb):
+        self.close()
 
     @property
     def context(self):
+        assert self._session_p, "Session is closed"
         if self._owned_context is None:
             from .. import ir
 
@@ -218,9 +238,11 @@
         return self._owned_context
 
     def invocation(self) -> "Invocation":
+        assert self._session_p, "Session is closed"
         return Invocation(self)
 
     def get_flags(self, non_default_only: bool = False) -> Sequence[str]:
+        assert self._session_p, "Session is closed"
         results = []
 
         @_GET_FLAG_CALLBACK
@@ -235,6 +257,7 @@
         return results
 
     def set_flags(self, *flags: str):
+        assert self._session_p, "Session is closed"
         argv_type = c_char_p * len(flags)
         argv = argv_type(*[flag.encode("UTF-8") for flag in flags])
         _handle_error(
@@ -257,6 +280,12 @@
             self._local_dylib.ireeCompilerOutputDestroy(self._output_p)
             self._output_p = None
 
+    def __enter__(self) -> "Output":
+        return self  # Closed again by __exit__.
+
+    def __exit__(self, exc_type, exc_val, exc_tb):
+        self.close()
+
     @staticmethod
     def open_file(file_path: str) -> "Output":
         output_p = c_void_p()
@@ -300,11 +329,12 @@
         return pointer
 
 
-class Source:
+class Source(SessionObject):
     """Wraps an iree_compiler_source_t."""
 
-    def __init__(self, session: c_void_p, source_p: c_void_p, backing_ref):
-        self._session: c_void_p = session  # Keeps ref alive.
+    def __init__(self, session: Session, source_p: c_void_p, backing_ref):
+        self._session: Session | None = session  # Keeps ref alive.
+        self._session._dependents.add(self)
         self._source_p: c_void_p = source_p
         self._backing_ref = backing_ref
         self._local_dylib = _dylib
@@ -318,7 +348,14 @@
             self._source_p = c_void_p()
             self._local_dylib.ireeCompilerSourceDestroy(s)
             self._backing_ref = None
-            self._session = c_void_p()
+            self._session._dependents.remove(self)
+            self._session = None
+
+    def __enter__(self) -> "Source":
+        return self  # Closed again by __exit__.
+
+    def __exit__(self, exc_type, exc_val, exc_tb):
+        self.close()
 
     def __repr__(self):
         return f"<Source {self._source_p}>"
@@ -362,9 +399,10 @@
     IREE_COMPILER_PIPELINE_PRECOMPILE = 2
 
 
-class Invocation:
+class Invocation(SessionObject):
     def __init__(self, session: Session):
-        self._session = session
+        self._session: Session | None = session
+        self._session._dependents.add(self)
         self._inv_p = _dylib.ireeCompilerInvocationCreate(self._session._session_p)
         self._sources: list[Source] = []
         self._local_dylib = _dylib
@@ -387,6 +425,14 @@
             for s in self._sources:
                 s.close()
             self._sources.clear()
+            self._session._dependents.remove(self)
+            self._session = None
+
+    def __enter__(self) -> "Invocation":
+        return self
+
+    def __exit__(self, exc_type, exc_val, exc_tb):
+        self.close()
 
     def enable_console_diagnostics(self):
         _dylib.ireeCompilerInvocationEnableConsoleDiagnostics(self._inv_p)
diff --git a/compiler/bindings/python/test/build_api/mnist_builder.py b/compiler/bindings/python/test/build_api/mnist_builder.py
index 54d6e30..31c4845 100644
--- a/compiler/bindings/python/test/build_api/mnist_builder.py
+++ b/compiler/bindings/python/test/build_api/mnist_builder.py
@@ -23,7 +23,10 @@
         name="mnist.mlir",
         source="mnist.onnx",
     )
-    return "mnist.mlir"
+    return compile(
+        name="mnist",
+        source="mnist.mlir",
+    )
 
 
 if __name__ == "__main__":
diff --git a/compiler/bindings/python/test/build_api/mnist_builder_test.py b/compiler/bindings/python/test/build_api/mnist_builder_test.py
index 02a9140..fb8cefb 100644
--- a/compiler/bindings/python/test/build_api/mnist_builder_test.py
+++ b/compiler/bindings/python/test/build_api/mnist_builder_test.py
@@ -16,6 +16,11 @@
 
 THIS_DIR = Path(__file__).resolve().parent
 
+DEFAULT_TARGET_ARGS = [
+    "--iree-hal-target-device=cpu",
+    "--iree-llvmcpu-target-cpu=host",
+]
+
 
 class MnistBuilderTest(unittest.TestCase):
     def setUp(self):
@@ -39,14 +44,14 @@
                 "--output-dir",
                 str(self.output_path),
             ]
+            + DEFAULT_TARGET_ARGS
         ).decode()
         print("OUTPUT:", output)
         output_paths = output.splitlines()
-        self.assertEqual(len(output_paths), 1)
+        self.assertEqual(len(output_paths), 1, msg=f"Found {output_paths}")
         output_path = Path(output_paths[0])
         self.assertTrue(output_path.is_relative_to(self.output_path))
-        contents = output_path.read_text()
-        self.assertIn("module", contents)
+        self.assertIn("mnist_cpu-host.vmfb", output_paths[0])
 
     # Tests that invoking via the build module itself works
     #   python {path to py file}
@@ -59,22 +64,24 @@
                 "--output-dir",
                 str(self.output_path),
             ]
+            + DEFAULT_TARGET_ARGS
         ).decode()
         print("OUTPUT:", output)
         output_paths = output.splitlines()
-        self.assertEqual(len(output_paths), 1)
+        self.assertEqual(len(output_paths), 1, msg=f"Found {output_paths}")
+        self.assertIn("mnist_cpu-host.vmfb", output_paths[0])
 
     def testListCommand(self):
         mod = load_build_module(THIS_DIR / "mnist_builder.py")
         out_file = io.StringIO()
-        iree_build_main(mod, args=["--list"], stdout=out_file)
+        iree_build_main(mod, args=["--list"] + DEFAULT_TARGET_ARGS, stdout=out_file)
         output = out_file.getvalue().strip()
         self.assertEqual(output, "mnist")
 
     def testListAllCommand(self):
         mod = load_build_module(THIS_DIR / "mnist_builder.py")
         out_file = io.StringIO()
-        iree_build_main(mod, args=["--list-all"], stdout=out_file)
+        iree_build_main(mod, args=["--list-all"] + DEFAULT_TARGET_ARGS, stdout=out_file)
         output = out_file.getvalue().splitlines()
         self.assertIn("mnist", output)
         self.assertIn("mnist/mnist.onnx", output)
@@ -92,11 +99,24 @@
                 args=[
                     "--mnist-onnx-url",
                     "https://github.com/iree-org/doesnotexist",
-                ],
+                ]
+                + DEFAULT_TARGET_ARGS,
                 stdout=out_file,
                 stderr=err_file,
             )
 
+    def testBuildNonDefaultSubTarget(self):
+        # Building an intermediate target prints the path of that genfile.
+        mod = load_build_module(THIS_DIR / "mnist_builder.py")
+        out_file = io.StringIO()
+        iree_build_main(
+            mod, args=["mnist/mnist.mlir"] + DEFAULT_TARGET_ARGS, stdout=out_file
+        )
+        output_path = Path(out_file.getvalue().strip())
+        # Compare in POSIX form so the check also holds with '\\' separators.
+        self.assertIn("genfiles/mnist/mnist.mlir", output_path.as_posix())
+        self.assertIn("module", output_path.read_text())
+
 
 if __name__ == "__main__":
     try: