[iree.build] Implement iree-compile action. (#18993)
This covers a lot of ground to actually get the full compile pipeline
working:
* Adds dep metadata and uses it to plumb input type to the compiler.
* Fully generalizes arg parsing and management.
* Adds a skeleton TargetMachine (used to drive the compiler) and teaches
it how to handle/pass-through standard IREE flags for simple single
device configurations. (Complex cases will be handled with some kind of
machine spec file and additional compiler entry-points for setting raw
target info).
* Plumbs through both in process and out of process compiler invocation.
* Robustifies the iree.compiler.api so that it can't do out of order
destruction during process cleanup.
Much more to do:
* I've got enough now to add a torch.export action to iree-turbine. That
will be run against the out of process pool.
* Quite a bit of compiler invocation ergonomics needed (reproducers,
error/output handling, etc).
* Need to take a pass through and make pretty console reporting/logging
throughout.
---------
Signed-off-by: Stella Laurenzo <stellaraccident@gmail.com>
diff --git a/compiler/bindings/python/CMakeLists.txt b/compiler/bindings/python/CMakeLists.txt
index bc8119f..76caebb 100644
--- a/compiler/bindings/python/CMakeLists.txt
+++ b/compiler/bindings/python/CMakeLists.txt
@@ -249,11 +249,15 @@
SOURCES
__init__.py
__main__.py
+ args.py
+ compile_actions.py
executor.py
lang.py
main.py
+ metadata.py
net_actions.py
onnx_actions.py
+ target_machine.py
)
add_mlir_python_modules(IREECompilerBuildPythonModules
diff --git a/compiler/bindings/python/iree/build/__init__.py b/compiler/bindings/python/iree/build/__init__.py
index 95ee3f7..3cc054e 100644
--- a/compiler/bindings/python/iree/build/__init__.py
+++ b/compiler/bindings/python/iree/build/__init__.py
@@ -8,5 +8,7 @@
from iree.build.lang import *
from iree.build.main import *
+
+from iree.build.compile_actions import *
from iree.build.net_actions import *
from iree.build.onnx_actions import *
diff --git a/compiler/bindings/python/iree/build/args.py b/compiler/bindings/python/iree/build/args.py
new file mode 100644
index 0000000..9a2482a
--- /dev/null
+++ b/compiler/bindings/python/iree/build/args.py
@@ -0,0 +1,187 @@
+# Copyright 2024 The IREE Authors
+#
+# Licensed under the Apache License v2.0 with LLVM Exceptions.
+# See https://llvm.org/LICENSE.txt for license information.
+# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+
+from typing import Callable, Generator, TypeVar
+
+import argparse
+import contextlib
+import functools
+import inspect
+import threading
+
+from typing import Callable
+
+_locals = threading.local()
+_ALL_ARG_REGISTRARS: list[Callable[[argparse.ArgumentParser], None]] = []
+_ALL_ARG_HANDLERS: list[Callable[[argparse.Namespace], None]] = []
+
+
+def register_arg_parser_callback(registrar: Callable[[argparse.ArgumentParser], None]):
+ """Decorator that adds a global argument registration callback.
+
+ This callback will be invoked when a new ArgumentParser is constructed.
+ """
+ _ALL_ARG_REGISTRARS.append(registrar)
+ return registrar
+
+
+def register_arg_handler_callback(handler: Callable[[argparse.Namespace], None]):
+ """Decorator that registers a handler to be run on global arguments at startup."""
+ _ALL_ARG_HANDLERS.append(handler)
+ return handler
+
+
+def configure_arg_parser(p: argparse.ArgumentParser):
+ """Invokes all callbacks from `register_arg_parser_callback` on the parser."""
+ for callback in _ALL_ARG_REGISTRARS:
+ callback(p)
+
+
+def run_global_arg_handlers(ns: argparse.Namespace):
+ """Invokes all global argument handlers."""
+ for h in _ALL_ARG_HANDLERS:
+ h(ns)
+
+
+@contextlib.contextmanager
+def argument_namespace_context(ns: argparse.Namespace):
+ """Establish that given namespace as the current namespace for this thread.
+
+ Note that as a thread local, this does not propagate to child threads or
+ sub-processes. This means that all argument management must be done during
+ action setup and action invocations will not typically have access to args.
+ """
+ if not hasattr(_locals, "arg_ns_stack"):
+ _locals.arg_ns_stack = []
+ _locals.arg_ns_stack.append(ns)
+ try:
+ yield ns
+ finally:
+ _locals.arg_ns_stack.pop()
+
+
+def current_args_namespace() -> argparse.Namespace:
+ try:
+ return _locals.arg_ns_stack[-1]
+ except (AttributeError, IndexError):
+ raise AssertionError(
+ "No current argument namespace: Is it possible you are trying to resolve "
+ "arguments from another thread or process"
+ )
+
+
+_Decorated = TypeVar("_Decorated", bound=Callable)
+
+
+def expand_cl_arg_defaults(wrapped: _Decorated) -> _Decorated:
+ sig = inspect.signature(wrapped)
+
+ def wrapper(*args, **kwargs):
+ args_ns = current_args_namespace()
+ bound = sig.bind(*args, **kwargs)
+ bound.apply_defaults()
+
+ def filter(arg):
+ if isinstance(arg, ClArgRef):
+ return arg.resolve(args_ns)
+ return arg
+
+ new_args = [filter(arg) for arg in bound.args]
+ new_kwargs = {k: filter(v) for k, v in bound.kwargs.items()}
+ return wrapped(*new_args, **new_kwargs)
+
+ functools.update_wrapper(wrapper, wrapped)
+ return wrapper
+
+
+class ClArgRef:
+ """Used in default values of function arguments to indicate that the default should
+ be derived from an argument reference.
+
+ Actually defining the argument must be done elsewhere.
+
+ See `cl_arg_ref()` for canonical use.
+ """
+
+ def __init__(self, dest: str):
+ self.dest = dest
+
+ def resolve(self, arg_namespace: argparse.Namespace):
+ try:
+ return getattr(arg_namespace, self.dest)
+ except AttributeError as e:
+ raise RuntimeError(
+ f"Unable to resolve command line argument '{self.dest}' in namespace"
+ ) from e
+
+
+def cl_arg_ref(dest: str):
+ """Used as a default value for functions wrapped in @expand_cl_defaults to indicate
+ that an argument must come from the command line environment.
+
+ Note that this does not have a typing annotation, allowing the argument to be
+ annotated with a type, assuming that resolution will happen dynamically in some
+ fashion.
+ """
+ return ClArgRef(dest)
+
+
+class ClArg(ClArgRef):
+ """Used in default values of function arguments to indicate that an argument needs
+ to be defined and referenced.
+
+ This is used in user-defined entry points, and the executor has special logic to
+ collect all needed arguments automatically.
+
+ See `cl_arg()` for canonical use.
+ """
+
+ def __init__(self, name, dest: str, **add_argument_kw):
+ super().__init__(dest)
+ self.name = name
+ self.add_argument_kw = add_argument_kw
+
+ def define_arg(self, parser: argparse.ArgumentParser):
+ parser.add_argument(f"--{self.name}", dest=self.dest, **self.add_argument_kw)
+
+
+def cl_arg(name: str, *, action=None, default=None, type=None, help=None):
+ """Used to define or reference a command-line argument from within actions
+ and entry-points.
+
+ Keywords have the same interpretation as `ArgumentParser.add_argument()`.
+
+ Any ClArg set as a default value for an argument to an `entrypoint` will be
+ added to the global argument parser. Any particular argument name can only be
+ registered once and must not conflict with a built-in command line option.
+ The implication of this is that for single-use arguments, the `=cl_arg(...)`
+ can just be added as a default argument. Otherwise, for shared arguments,
+ it should be created at the module level and referenced.
+
+ When called, any entrypoint arguments that do not have an explicit keyword
+ set will get their value from the command line environment.
+
+ Note that this does not have a typing annotation, allowing the argument to be
+ annotated with a type, assuming that resolution will happen dynamically in some
+ fashion.
+ """
+ if name.startswith("-"):
+ raise ValueError("cl_arg name must not be prefixed with dashes")
+ dest = name.replace("-", "_")
+ return ClArg(name, action=action, default=default, type=type, dest=dest, help=help)
+
+
+def extract_cl_arg_defs(callable: Callable) -> Generator[ClArg, None, None]:
+ """Extracts all `ClArg` default values from a callable.
+
+ This is used in order to eagerly register argument definitions for some set
+ of functions.
+ """
+ sig = inspect.signature(callable)
+ for p in sig.parameters.values():
+ def_value = p.default
+ if isinstance(def_value, ClArg):
+ yield def_value
diff --git a/compiler/bindings/python/iree/build/compile_actions.py b/compiler/bindings/python/iree/build/compile_actions.py
new file mode 100644
index 0000000..2c2056c
--- /dev/null
+++ b/compiler/bindings/python/iree/build/compile_actions.py
@@ -0,0 +1,208 @@
+# Copyright 2024 The IREE Authors
+#
+# Licensed under the Apache License v2.0 with LLVM Exceptions.
+# See https://llvm.org/LICENSE.txt for license information.
+# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+
+import argparse
+import shlex
+
+import iree.compiler.api as compiler_api
+import iree.compiler.tools as compiler_tools
+
+from iree.build.args import (
+ expand_cl_arg_defaults,
+ register_arg_handler_callback,
+ register_arg_parser_callback,
+ cl_arg_ref,
+)
+
+from iree.build.executor import (
+ BuildAction,
+ BuildContext,
+ BuildFile,
+ BuildFileLike,
+ FileNamespace,
+)
+
+from iree.build.metadata import CompileSourceMeta
+from iree.build.target_machine import compute_target_machines_from_flags
+
+__all__ = [
+ "compile",
+]
+
+
+@register_arg_parser_callback
+def _(p: argparse.ArgumentParser):
+ g = p.add_argument_group(
+ title="IREE Compiler Options",
+ description="Global options controlling invocation of iree-compile",
+ )
+ g.add_argument(
+ "--iree-compile-out-of-process",
+ action=argparse.BooleanOptionalAction,
+ help="Invokes iree-compiler as an out of process executable (the default is to "
+ "invoke it in-process via API bindings). This can make debugging somewhat "
+ "easier and also grants access to global command line options that may not "
+ "otherwise be available.",
+ )
+ g.add_argument(
+ "--iree-compile-extra-args",
+ help="Extra arguments to pass to iree-compile. When running in-process, these "
+ "will be passed as globals to the library and effect all compilation in the "
+ "process. These are split with shlex rules.",
+ )
+
+
+@register_arg_handler_callback
+def _(ns: argparse.Namespace):
+ in_process = not ns.iree_compile_out_of_process
+ extra_args_str = ns.iree_compile_extra_args
+ if in_process and extra_args_str:
+ # TODO: This is very unsafe. If called multiple times (i.e. in a library)
+ # or with illegal arguments, the program will abort. The safe way to do
+ # this is to spawn one child process and route all "in process" compilation
+ # there. This would allow explicit control of startup/shutdown and would
+ # provide isolation in the event of a compiler crash. It is still important
+ # for a single process to handle all compilation activities since this
+ # allows global compiler resources (like threads) to be pooled and not
+ # saturate the machine resources.
+ extra_args_list = shlex.split(extra_args_str)
+ compiler_api._initializeGlobalCL("unused_prog_name", *extra_args_list)
+
+
+class CompilerInvocation:
+ @expand_cl_arg_defaults
+ def __init__(
+ self,
+ *,
+ input_file: BuildFile,
+ output_file: BuildFile,
+ out_of_process: bool = cl_arg_ref("iree_compile_out_of_process"),
+ extra_args_str=cl_arg_ref("iree_compile_extra_args"),
+ ):
+ self.input_file = input_file
+ self.output_file = output_file
+ # We manage most flags as keyword values that can have at most one
+ # setting.
+ self.kw_flags: dict[str, str | None] = {}
+ # Flags can also be set free-form. These are always added to the command
+ # line after the kw_flags.
+ self.extra_flags: list[str] = []
+ self.out_of_process = out_of_process
+
+ if extra_args_str:
+ self.extra_args = shlex.split(extra_args_str)
+ else:
+ self.extra_args = []
+
+ def run(self):
+ raw_flags: list[str] = []
+
+ # Set any defaults derived from the input_file metadata. These are set
+ # first because they can be overriden by explicit flag settings.
+ meta = CompileSourceMeta.get(self.input_file)
+ raw_flags.append(f"--iree-input-type={meta.input_type}")
+
+ # Process kw_flags.
+ for key, value in self.kw_flags.items():
+ if value is None:
+ raw_flags.append(f"--{key}")
+ else:
+ raw_flags.append(f"--{key}={value}")
+
+ # Process extra_flags.
+ for raw in self.extra_flags:
+ raw_flags.append(raw)
+
+ if self.out_of_process:
+ self.run_out_of_process(raw_flags)
+ else:
+ self.run_inprocess(raw_flags)
+
+ def run_inprocess(self, flags: list[str]):
+ with compiler_api.Session() as session:
+ session.set_flags(*flags)
+ with compiler_api.Invocation(session) as inv, compiler_api.Source.open_file(
+ session, str(self.input_file.get_fs_path())
+ ) as source, compiler_api.Output.open_file(
+ str(self.output_file.get_fs_path())
+ ) as output:
+ inv.enable_console_diagnostics()
+ inv.parse_source(source)
+ if not inv.execute():
+ raise RuntimeError("COMPILE FAILED (TODO)")
+ inv.output_vm_bytecode(output)
+ output.keep()
+
+ def run_out_of_process(self, flags: list[str]):
+ # TODO: This Python executable wrapper is really long in the tooth. We should
+ # just invoke iree-compile directly (which would also let us have a flag for
+ # the path to it).
+ all_extra_args = self.extra_args + flags
+ compiler_tools.compile_file(
+ str(self.input_file.get_fs_path()),
+ output_file=str(self.output_file.get_fs_path()),
+ extra_args=self.extra_args + flags,
+ )
+
+
+def compile(
+ *,
+ name: str,
+ source: BuildFileLike,
+ target_default: bool = True,
+) -> tuple[BuildFile]:
+ """Invokes iree-compile on a source file, producing binaries for one or more target
+ machines.
+
+ Args:
+ name: The logical name of the compilation command. This is used as the stem
+ for multiple kinds of output files.
+ source: Input source file.
+ target_default: Whether to use command line arguments to compute a target
+ machine configuration (default True). This would be set to False to explicitly
+ depend on target information contained in the source file and not require
+ any target flags passed to the build tool.
+ """
+ context = BuildContext.current()
+ input_file = context.file(source)
+ if target_default:
+ # Compute the target machines from flags and create one compilation for each.
+ tms = compute_target_machines_from_flags()
+ output_files: list[BuildFile] = []
+ for tm in tms:
+ output_file = context.allocate_file(
+ f"{name}_{tm.target_spec}.vmfb", namespace=FileNamespace.BIN
+ )
+ inv = CompilerInvocation(input_file=input_file, output_file=output_file)
+ inv.extra_flags.extend(tm.flag_list)
+ CompileAction(
+ inv,
+ desc=f"Compiling {input_file} (for {tm.target_spec})",
+ executor=context.executor,
+ )
+ output_files.append(output_file)
+ return output_files
+ else:
+ # The compilation is self contained, so just directly compile it.
+ output_file = context.allocate_file(f"{name}.vmfb", namespace=FileNamespace.BIN)
+ inv = CompilerInvocation(input_file=input_file, output_file=output_file)
+ CompileAction(
+ inv,
+ desc=f"Compiling {name}",
+ executor=context.executor,
+ )
+ return output_file
+
+
+class CompileAction(BuildAction):
+ def __init__(self, inv: CompilerInvocation, **kwargs):
+ super().__init__(**kwargs)
+ self.inv = inv
+ self.inv.output_file.deps.add(self)
+ self.deps.add(self.inv.input_file)
+
+ def _invoke(self):
+ self.inv.run()
diff --git a/compiler/bindings/python/iree/build/executor.py b/compiler/bindings/python/iree/build/executor.py
index 35b9173..1d580bd 100644
--- a/compiler/bindings/python/iree/build/executor.py
+++ b/compiler/bindings/python/iree/build/executor.py
@@ -4,20 +4,23 @@
# See https://llvm.org/LICENSE.txt for license information.
# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
-from typing import Callable, Collection, Generator, IO
+from typing import Callable, Collection, IO, Type, TypeVar
import abc
-import argparse
import concurrent.futures
import enum
-import inspect
import multiprocessing
-import sys
import time
import traceback
from pathlib import Path
import threading
+from iree.build.args import (
+ current_args_namespace,
+ expand_cl_arg_defaults,
+ extract_cl_arg_defs,
+)
+
_locals = threading.local()
@@ -52,24 +55,6 @@
return f"{prefix}/{suffix}"
-class ClArg:
- def __init__(self, name, dest: str, **add_argument_kw):
- self.name = name
- self.dest = dest
- self.add_argument_kw = add_argument_kw
-
- def define_arg(self, parser: argparse.ArgumentParser):
- parser.add_argument(f"--{self.name}", dest=self.dest, **self.add_argument_kw)
-
- def resolve(self, arg_namespace: argparse.Namespace):
- try:
- return getattr(arg_namespace, self.dest)
- except AttributeError as e:
- raise RuntimeError(
- f"Unable to resolve command line argument '{self.dest}' in namespace"
- ) from e
-
-
class Entrypoint:
def __init__(
self,
@@ -79,17 +64,12 @@
):
self.name = name
self.description = description
- self._wrapped = wrapped
-
- def cl_args(self) -> Generator[ClArg, None, None]:
- sig = inspect.signature(self._wrapped)
- for p in sig.parameters.values():
- def_value = p.default
- if isinstance(def_value, ClArg):
- yield def_value
+ self.cl_arg_defs = list(extract_cl_arg_defs(wrapped))
+ self._wrapped = expand_cl_arg_defaults(wrapped)
def __call__(self, *args, **kwargs):
parent_context = BuildContext.current()
+ args_ns = current_args_namespace()
bep = BuildEntrypoint(
join_namespace(parent_context.path, self.name),
parent_context.executor,
@@ -97,18 +77,7 @@
)
parent_context.executor.entrypoints.append(bep)
with bep:
- sig = inspect.signature(self._wrapped)
- bound = sig.bind(*args, **kwargs)
- bound.apply_defaults()
-
- def filter(arg):
- if isinstance(arg, ClArg):
- return arg.resolve(parent_context.executor.args_namespace)
- return arg
-
- new_args = [filter(arg) for arg in bound.args]
- new_kwargs = {k: filter(v) for k, v in bound.kwargs.items()}
- results = self._wrapped(*new_args, **new_kwargs)
+ results = self._wrapped(*args, **kwargs)
if results is not None:
files = bep.files(results)
bep.deps.update(files)
@@ -119,15 +88,12 @@
class Executor:
"""Executor that all build contexts share."""
- def __init__(
- self, output_dir: Path, args_namespace: argparse.Namespace, stderr: IO
- ):
+ def __init__(self, output_dir: Path, stderr: IO):
self.output_dir = output_dir
self.verbose_level = 0
# Keyed by path
self.all: dict[str, "BuildContext" | "BuildFile"] = {}
self.entrypoints: list["BuildEntrypoint"] = []
- self.args_namespace = args_namespace
self.stderr = stderr
BuildContext("", self)
@@ -189,6 +155,41 @@
scheduler.shutdown()
+BuildMetaType = TypeVar("BuildMetaType", bound="BuildMeta")
+
+
+class BuildMeta:
+ """Base class for typed metadata that can be set on a BuildDependency.
+
+ This is an open namespace where each sub-class must have a unique key as the class
+ level attribute `KEY`.
+ """
+
+ def __init__(self):
+ key = getattr(self, "KEY", None)
+ assert isinstance(key, str), "BuildMeta.KEY must be a str"
+
+ @classmethod
+ def get(cls: Type[BuildMetaType], dep: "BuildDependency") -> BuildMetaType:
+ """Gets a metadata instance of this type from a dependency.
+
+ If it does not yet exist, returns the value of `create_default()`, which
+ by default returns a new instance (which is set on the dep).
+ """
+ key = getattr(cls, "KEY", None)
+ assert isinstance(key, str), f"{cls.__name__}.KEY must be a str"
+ instance = dep._metadata.get(key)
+ if instance is None:
+ instance = cls.create_default()
+ dep._metadata[key] = instance
+ return instance
+
+ @classmethod
+ def create_default(cls) -> "BuildMeta":
+ """Creates a default instance."""
+ return cls()
+
+
class BuildDependency:
"""Base class of entities that can act as a build dependency."""
@@ -205,6 +206,9 @@
self.start_time: float | None = None
self.finish_time: float | None = None
+ # Metadata.
+ self._metadata: dict[str, BuildMeta] = {}
+
@property
def is_scheduled(self) -> bool:
return self.future is not None
@@ -287,8 +291,11 @@
def __repr__(self):
return f"Action[{type(self).__name__}]('{self.desc}')"
- @abc.abstractmethod
def invoke(self):
+ self._invoke()
+
+ @abc.abstractmethod
+ def _invoke(self):
...
@@ -317,7 +324,8 @@
"""
if not path.startswith("/"):
path = join_namespace(self.path, path)
- return BuildFile(executor=self.executor, path=path, namespace=namespace)
+ build_file = BuildFile(executor=self.executor, path=path, namespace=namespace)
+ return build_file
def file(self, file: str | BuildFile) -> BuildFile:
"""Accesses a BuildFile by either string (path) or BuildFile.
@@ -374,9 +382,6 @@
existing = stack.pop()
assert existing is self, "Unbalanced BuildContext enter/exit"
- def populate_arg_parser(self, parser: argparse.ArgumentParser):
- ...
-
class BuildEntrypoint(BuildContext):
def __init__(self, path: str, executor: Executor, entrypoint: Entrypoint):
@@ -506,7 +511,16 @@
return dep
print(f"Scheduling action: {dep}", file=self.stderr)
- dep.start(self.thread_pool_executor.submit(invoke))
+ if dep.concurrnecy == ActionConcurrency.NONE:
+ invoke()
+ elif dep.concurrnecy == ActionConcurrency.THREAD:
+ dep.start(self.thread_pool_executor.submit(invoke))
+ elif dep.concurrnecy == ActionConcurrency.PROCESS:
+ dep.start(self.process_pool_executor.submit(invoke))
+ else:
+ raise AssertionError(
+ f"Unhandled ActionConcurrency value: {dep.concurrnecy}"
+ )
else:
# Not schedulable. Just mark it as done.
dep.start(concurrent.futures.Future())
diff --git a/compiler/bindings/python/iree/build/lang.py b/compiler/bindings/python/iree/build/lang.py
index 5cb8779..60f4c9f 100644
--- a/compiler/bindings/python/iree/build/lang.py
+++ b/compiler/bindings/python/iree/build/lang.py
@@ -9,7 +9,8 @@
import argparse
import functools
-from iree.build.executor import ClArg, Entrypoint
+from iree.build.args import cl_arg # Export as part of the public API
+from iree.build.executor import Entrypoint
__all__ = [
"cl_arg",
@@ -28,25 +29,3 @@
target = Entrypoint(f.__name__, f, description=description)
functools.wraps(target, f)
return target
-
-
-def cl_arg(name: str, *, action=None, default=None, type=None, help=None):
- """Used to define or reference a command-line argument from within actions
- and entry-points.
-
- Keywords have the same interpretation as `ArgumentParser.add_argument()`.
-
- Any ClArg set as a default value for an argument to an `entrypoint` will be
- added to the global argument parser. Any particular argument name can only be
- registered once and must not conflict with a built-in command line option.
- The implication of this is that for single-use arguments, the `=cl_arg(...)`
- can just be added as a default argument. Otherwise, for shared arguments,
- it should be created at the module level and referenced.
-
- When called, any entrypoint arguments that do not have an explicit keyword
- set will get their value from the command line environment.
- """
- if name.startswith("-"):
- raise ValueError("cl_arg name must not be prefixed with dashes")
- dest = name.replace("-", "_")
- return ClArg(name, action=action, default=default, type=type, dest=dest, help=help)
diff --git a/compiler/bindings/python/iree/build/main.py b/compiler/bindings/python/iree/build/main.py
index 272533c..6d40d18 100644
--- a/compiler/bindings/python/iree/build/main.py
+++ b/compiler/bindings/python/iree/build/main.py
@@ -12,7 +12,12 @@
from pathlib import Path
import sys
-from iree.build.executor import BuildEntrypoint, Entrypoint, Executor
+from iree.build.args import (
+ argument_namespace_context,
+ configure_arg_parser,
+ run_global_arg_handlers,
+)
+from iree.build.executor import BuildEntrypoint, BuildFile, Entrypoint, Executor
__all__ = [
"iree_build_main",
@@ -117,6 +122,7 @@
help="Paths of actions to build (default to top-level actions)",
)
+ configure_arg_parser(p)
self._define_action_arguments(p)
self.args = self.arg_parser.parse_args(args)
@@ -126,7 +132,7 @@
def _define_action_arguments(self, p: argparse.ArgumentParser):
user_group = p.add_argument_group("Action defined options")
for ep in self.top_module.entrypoints.values():
- for cl_arg in ep.cl_args():
+ for cl_arg in ep.cl_arg_defs:
cl_arg.define_arg(user_group)
def _resolve_module_arguments(
@@ -169,15 +175,17 @@
return rem_args, top_module
def _create_executor(self) -> Executor:
- executor = Executor(self.args.output_dir, self.args, stderr=self.stderr)
+ executor = Executor(self.args.output_dir, stderr=self.stderr)
executor.analyze(*self.top_module.entrypoints.values())
return executor
def run(self):
- command = self.args.command
- if command is None:
- command = self.build_command
- command()
+ with argument_namespace_context(self.args):
+ run_global_arg_handlers(self.args)
+ command = self.args.command
+ if command is None:
+ command = self.build_command
+ command()
def build_command(self):
executor = self._create_executor()
@@ -203,7 +211,9 @@
for build_action in build_actions:
if isinstance(build_action, BuildEntrypoint):
for output in build_action.outputs:
- print(f"{output.get_fs_path()}", file=self.stdout)
+ print(output.get_fs_path(), file=self.stdout)
+ elif isinstance(build_action, BuildFile):
+ print(build_action.get_fs_path(), file=self.stdout)
def list_command(self):
executor = self._create_executor()
diff --git a/compiler/bindings/python/iree/build/metadata.py b/compiler/bindings/python/iree/build/metadata.py
new file mode 100644
index 0000000..91767dd
--- /dev/null
+++ b/compiler/bindings/python/iree/build/metadata.py
@@ -0,0 +1,37 @@
+# Copyright 2024 The IREE Authors
+#
+# Licensed under the Apache License v2.0 with LLVM Exceptions.
+# See https://llvm.org/LICENSE.txt for license information.
+# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+
+"""Common `BuildMeta` subclasses for built-in actions.
+
+These are maintained here purely as an aid to avoiding circular dependencies.
+Typically, in out of tree actions, they would just be inlined into the implementation
+file.
+"""
+
+from .executor import BuildMeta
+
+
+class CompileSourceMeta(BuildMeta):
+ """CompileSourceMeta tracks source level properties that can influence compilation.
+
+ This meta can be set on any dependency that ultimately is used as a source to a
+ `compile` action.
+ """
+
+ # Slots in this case simply will catch attempts to set undefined attributes.
+ __slots__ = [
+ "input_type",
+ ]
+ KEY = "iree.compile.source"
+
+ def __init__(self):
+ super().__init__()
+
+ # The value to the --iree-input-type= flag for this source file.
+ self.input_type: str = "auto"
+
+ def __repr__(self):
+ return f"CompileSourceMeta(input_type={self.input_type})"
diff --git a/compiler/bindings/python/iree/build/net_actions.py b/compiler/bindings/python/iree/build/net_actions.py
index 1d2e158..da74d9a 100644
--- a/compiler/bindings/python/iree/build/net_actions.py
+++ b/compiler/bindings/python/iree/build/net_actions.py
@@ -30,7 +30,7 @@
self.url = url
self.output_file = output_file
- def invoke(self):
+ def _invoke(self):
path = self.output_file.get_fs_path()
self.executor.write_status(f"Fetching URL: {self.url} -> {path}")
try:
diff --git a/compiler/bindings/python/iree/build/onnx_actions.py b/compiler/bindings/python/iree/build/onnx_actions.py
index 5220890..83ec3a1 100644
--- a/compiler/bindings/python/iree/build/onnx_actions.py
+++ b/compiler/bindings/python/iree/build/onnx_actions.py
@@ -5,6 +5,7 @@
# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
from iree.build.executor import BuildAction, BuildContext, BuildFile, BuildFileLike
+from iree.build.metadata import CompileSourceMeta
__all__ = [
"onnx_import",
@@ -30,7 +31,7 @@
input_file=input_file,
output_file=processed_file,
executor=context.executor,
- desc=f"Upgrading ONNX {name}",
+ desc=f"Upgrading ONNX {input_file} -> {processed_file}",
deps=[
input_file,
],
@@ -41,14 +42,13 @@
ImportOnnxAction(
input_file=input_file,
output_file=output_file,
- desc=f"Importing ONNX {name}",
+ desc=f"Importing ONNX {name} -> {output_file}",
executor=context.executor,
deps=[
input_file,
],
)
- output_file.deps.add(processed_file)
return output_file
@@ -57,9 +57,11 @@
super().__init__(**kwargs)
self.input_file = input_file
self.output_file = output_file
+ self.deps.add(self.input_file)
output_file.deps.add(self)
+ CompileSourceMeta.get(output_file).input_type = "onnx"
- def invoke(self):
+ def _invoke(self):
import onnx
input_path = self.input_file.get_fs_path()
@@ -75,9 +77,11 @@
super().__init__(**kwargs)
self.input_file = input_file
self.output_file = output_file
+ self.deps.add(input_file)
output_file.deps.add(self)
+ CompileSourceMeta.get(output_file).input_type = "onnx"
- def invoke(self):
+ def _invoke(self):
import iree.compiler.tools.import_onnx.__main__ as m
args = m.parse_arguments(
diff --git a/compiler/bindings/python/iree/build/target_machine.py b/compiler/bindings/python/iree/build/target_machine.py
new file mode 100644
index 0000000..c8eca7d
--- /dev/null
+++ b/compiler/bindings/python/iree/build/target_machine.py
@@ -0,0 +1,190 @@
+# Copyright 2024 The IREE Authors
+#
+# Licensed under the Apache License v2.0 with LLVM Exceptions.
+# See https://llvm.org/LICENSE.txt for license information.
+# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+
+"""Handles the messy affair of deriving options for targeting machines."""
+
+import argparse
+
+from iree.build.args import (
+ expand_cl_arg_defaults,
+ register_arg_parser_callback,
+ cl_arg_ref,
+)
+
+
+class TargetMachine:
+ def __init__(
+ self,
+ target_spec: str,
+ *,
+ iree_compile_device_type: str | None = None,
+ extra_flags: list[str] | None = None,
+ ):
+ self.target_spec = target_spec
+ self.iree_compile_device_type = iree_compile_device_type
+ self.extra_flags = extra_flags
+
+ @property
+ def flag_list(self) -> list[str]:
+ if self.iree_compile_device_type is not None:
+ # This is just a hard-coded machine model using a single IREE device
+ # type alias in the default configuration.
+ return [f"--iree-hal-target-device={self.iree_compile_device_type}"] + (
+ self.extra_flags or []
+ )
+ raise RuntimeError(f"Cannot compute iree-compile flags for: {self}")
+
+ def __repr__(self):
+ r = f"TargetMachine({self.target_spec}, "
+ if self.iree_compile_device_type is not None:
+ r += f"iree_compile_device_type='{self.iree_compile_device_type}', "
+ if self.extra_flags:
+ r += f"extra_flags={self.extra_flags}, "
+ r += ")"
+ return r
+
+
+################################################################################
+# Handling of --iree-hal-target-device from flags
+################################################################################
+
+
+HAL_TARGET_DEVICES_FROM_FLAGS_HANDLERS = {}
+
+
+def handle_hal_target_devices_from_flags(*mnemonics: str):
+ def decorator(f):
+ for mn in mnemonics:
+ HAL_TARGET_DEVICES_FROM_FLAGS_HANDLERS[mn] = f
+ return f
+
+ return decorator
+
+
+def handle_unknown_hal_target_device(mnemonic: str) -> list[TargetMachine]:
+ return [TargetMachine(mnemonic, iree_compile_device_type=mnemonic)]
+
+
+@handle_hal_target_devices_from_flags("amdgpu", "hip")
+@expand_cl_arg_defaults
+def amdgpu_hal_target_from_flags(
+ mnemonic: str, *, amdgpu_target=cl_arg_ref("iree_amdgpu_target")
+) -> list[TargetMachine]:
+ if not amdgpu_target:
+ raise RuntimeError(
+ "No AMDGPU targets specified. Pass a chip to target as "
+ "--iree-amdgpu-target=gfx..."
+ )
+ return [
+ TargetMachine(
+ f"amdgpu-{amdgpu_target}",
+ iree_compile_device_type="amdgpu",
+ extra_flags=[f"--iree-hip-target={amdgpu_target}"],
+ )
+ ]
+
+
+@handle_hal_target_devices_from_flags("llvm-cpu", "cpu")
+@expand_cl_arg_defaults
+def cpu_hal_target_from_flags(
+ mnemonic: str,
+ *,
+ cpu=cl_arg_ref("iree_llvmcpu_target_cpu"),
+ features=cl_arg_ref("iree_llvmcpu_target_cpu_features"),
+) -> list[TargetMachine]:
+ target_spec = "cpu"
+ extra_flags = []
+ if cpu:
+ target_spec += f"-{cpu}"
+ extra_flags.append(f"--iree-llvmcpu-target-cpu={cpu}")
+ if features:
+ target_spec += f":{features}"
+ extra_flags.append(f"--iree-llvmcpu-target-cpu-features={features}")
+
+ return [
+ TargetMachine(
+ f"cpu-{cpu or 'generic'}",
+ iree_compile_device_type="llvm-cpu",
+ extra_flags=extra_flags,
+ )
+ ]
+
+
+################################################################################
+# Flag definition
+################################################################################
+
+
+@register_arg_parser_callback
+def _(p: argparse.ArgumentParser):
+ g = p.add_argument_group(
+ title="IREE Target Machine Options",
+ description="Global options controlling invocation of iree-compile",
+ )
+ g.add_argument(
+ "--iree-hal-target-device",
+ help="Compiles with a single machine model and a single specified device"
+ " (mutually exclusive with other ways to set the machine target). This "
+ "emulates the simple case of device targeting if invoking `iree-compile` "
+ "directly and is mostly a pass-through which also enforces other flags "
+ "depending on the value given. Supported options (or any supported by the "
+ "compiler): "
+ f"{', '.join(HAL_TARGET_DEVICES_FROM_FLAGS_HANDLERS.keys() - 'default')}",
+ nargs="*",
+ )
+
+ hip_g = p.add_argument_group(
+ title="IREE AMDGPU Target Options",
+ description="Options controlling explicit targeting of AMDGPU devices",
+ )
+ hip_g.add_argument(
+ "--iree-amdgpu-target",
+ "--iree-hip-target",
+ help="AMDGPU target selection (i.e. 'gfxYYYY')",
+ )
+
+ cpu_g = p.add_argument_group(
+ title="IREE CPU Target Options",
+ description="These are mostly pass-through. See `iree-compile --help` for "
+ "full information. Advanced usage will require an explicit machine config "
+ "file",
+ )
+ cpu_g.add_argument(
+ "--iree-llvmcpu-target-cpu",
+ help="'generic', 'host', or an explicit CPU name. See iree-compile help.",
+ )
+ cpu_g.add_argument(
+ "--iree-llvmcpu-target-cpu-features",
+ help="Comma separated list of '+' prefixed CPU features. See iree-compile help.",
+ )
+
+
+################################################################################
+# Global flag dispatch
+################################################################################
+
+
+@expand_cl_arg_defaults
+def compute_target_machines_from_flags(
+ *,
+ explicit_hal_target_devices: list[str]
+ | None = cl_arg_ref("iree_hal_target_device"),
+) -> list[TargetMachine]:
+ if explicit_hal_target_devices is not None:
+ # Most basic default case for setting up compilation.
+ machines = []
+ for explicit_hal_target_device in explicit_hal_target_devices:
+ handler = (
+ HAL_TARGET_DEVICES_FROM_FLAGS_HANDLERS.get(explicit_hal_target_device)
+ or handle_unknown_hal_target_device
+ )
+ machines.extend(handler(explicit_hal_target_device))
+ return machines
+
+ raise RuntimeError(
+ "iree-compile target information is required but none was provided. "
+ "See flags: --iree-hal-target-device"
+ )
diff --git a/compiler/bindings/python/iree/compiler/api/ctypes_dl.py b/compiler/bindings/python/iree/compiler/api/ctypes_dl.py
index 784fc61..26f7228 100644
--- a/compiler/bindings/python/iree/compiler/api/ctypes_dl.py
+++ b/compiler/bindings/python/iree/compiler/api/ctypes_dl.py
@@ -189,6 +189,11 @@
return len(view) >= 4 and view[:4].hex() == "4d4cef52"
+class SessionObject:
+ def close(self):
+ ...
+
+
class Session:
def __init__(self):
self._global_init = _global_init
@@ -197,12 +202,27 @@
# its ownership of it, so we must cache the new Python-level MLIRContext
# so its lifetime extends at least to our own.
self._owned_context = None
+ self._dependents: set[SessionObject] = set()
def __del__(self):
- _dylib.ireeCompilerSessionDestroy(self._session_p)
+ self.close()
+
+ def close(self):
+ if self._session_p:
+ for dep in list(self._dependents):
+ dep.close()
+ _dylib.ireeCompilerSessionDestroy(self._session_p)
+ self._session_p = c_void_p()
+
+ def __enter__(self) -> "Session":
+ return self
+
+ def __exit__(self, exc_type, exc_val, exc_tb):
+ self.close()
@property
def context(self):
+ assert self._session_p, "Session is closed"
if self._owned_context is None:
from .. import ir
@@ -218,9 +238,11 @@
return self._owned_context
def invocation(self) -> "Invocation":
+ assert self._session_p, "Session is closed"
return Invocation(self)
def get_flags(self, non_default_only: bool = False) -> Sequence[str]:
+ assert self._session_p, "Session is closed"
results = []
@_GET_FLAG_CALLBACK
@@ -235,6 +257,7 @@
return results
def set_flags(self, *flags: str):
+ assert self._session_p, "Session is closed"
argv_type = c_char_p * len(flags)
argv = argv_type(*[flag.encode("UTF-8") for flag in flags])
_handle_error(
@@ -257,6 +280,12 @@
self._local_dylib.ireeCompilerOutputDestroy(self._output_p)
self._output_p = None
+ def __enter__(self) -> "Invocation":
+ return self
+
+ def __exit__(self, exc_type, exc_val, exc_tb):
+ self.close()
+
@staticmethod
def open_file(file_path: str) -> "Output":
output_p = c_void_p()
@@ -300,11 +329,12 @@
return pointer
-class Source:
+class Source(SessionObject):
"""Wraps an iree_compiler_source_t."""
- def __init__(self, session: c_void_p, source_p: c_void_p, backing_ref):
- self._session: c_void_p = session # Keeps ref alive.
+ def __init__(self, session: Session, source_p: c_void_p, backing_ref):
+ self._session: Session | None = session # Keeps ref alive.
+ self._session._dependents.add(self)
self._source_p: c_void_p = source_p
self._backing_ref = backing_ref
self._local_dylib = _dylib
@@ -318,7 +348,14 @@
self._source_p = c_void_p()
self._local_dylib.ireeCompilerSourceDestroy(s)
self._backing_ref = None
- self._session = c_void_p()
+ self._session._dependents.remove(self)
+ self._session = None
+
+ def __enter__(self) -> "Invocation":
+ return self
+
+ def __exit__(self, exc_type, exc_val, exc_tb):
+ self.close()
def __repr__(self):
return f"<Source {self._source_p}>"
@@ -362,9 +399,10 @@
IREE_COMPILER_PIPELINE_PRECOMPILE = 2
-class Invocation:
+class Invocation(SessionObject):
def __init__(self, session: Session):
- self._session = session
+ self._session: Session | None = session
+ self._session._dependents.add(self)
self._inv_p = _dylib.ireeCompilerInvocationCreate(self._session._session_p)
self._sources: list[Source] = []
self._local_dylib = _dylib
@@ -387,6 +425,14 @@
for s in self._sources:
s.close()
self._sources.clear()
+ self._session._dependents.remove(self)
+ self._session = None
+
+ def __enter__(self) -> "Invocation":
+ return self
+
+ def __exit__(self, exc_type, exc_val, exc_tb):
+ self.close()
def enable_console_diagnostics(self):
_dylib.ireeCompilerInvocationEnableConsoleDiagnostics(self._inv_p)
diff --git a/compiler/bindings/python/test/build_api/mnist_builder.py b/compiler/bindings/python/test/build_api/mnist_builder.py
index 54d6e30..31c4845 100644
--- a/compiler/bindings/python/test/build_api/mnist_builder.py
+++ b/compiler/bindings/python/test/build_api/mnist_builder.py
@@ -23,7 +23,10 @@
name="mnist.mlir",
source="mnist.onnx",
)
- return "mnist.mlir"
+ return compile(
+ name="mnist",
+ source="mnist.mlir",
+ )
if __name__ == "__main__":
diff --git a/compiler/bindings/python/test/build_api/mnist_builder_test.py b/compiler/bindings/python/test/build_api/mnist_builder_test.py
index 02a9140..fb8cefb 100644
--- a/compiler/bindings/python/test/build_api/mnist_builder_test.py
+++ b/compiler/bindings/python/test/build_api/mnist_builder_test.py
@@ -16,6 +16,11 @@
THIS_DIR = Path(__file__).resolve().parent
+DEFAULT_TARGET_ARGS = [
+ "--iree-hal-target-device=cpu",
+ "--iree-llvmcpu-target-cpu=host",
+]
+
class MnistBuilderTest(unittest.TestCase):
def setUp(self):
@@ -39,14 +44,14 @@
"--output-dir",
str(self.output_path),
]
+ + DEFAULT_TARGET_ARGS
).decode()
print("OUTPUT:", output)
output_paths = output.splitlines()
- self.assertEqual(len(output_paths), 1)
+ self.assertEqual(len(output_paths), 1, msg=f"Found {output_paths}")
output_path = Path(output_paths[0])
self.assertTrue(output_path.is_relative_to(self.output_path))
- contents = output_path.read_text()
- self.assertIn("module", contents)
+ self.assertIn("mnist_cpu-host.vmfb", output_paths[0])
# Tests that invoking via the build module itself works
# python {path to py file}
@@ -59,22 +64,24 @@
"--output-dir",
str(self.output_path),
]
+ + DEFAULT_TARGET_ARGS
).decode()
print("OUTPUT:", output)
output_paths = output.splitlines()
- self.assertEqual(len(output_paths), 1)
+ self.assertEqual(len(output_paths), 1, msg=f"Found {output_paths}")
+ self.assertIn("mnist_cpu-host.vmfb", output_paths[0])
def testListCommand(self):
mod = load_build_module(THIS_DIR / "mnist_builder.py")
out_file = io.StringIO()
- iree_build_main(mod, args=["--list"], stdout=out_file)
+ iree_build_main(mod, args=["--list"] + DEFAULT_TARGET_ARGS, stdout=out_file)
output = out_file.getvalue().strip()
self.assertEqual(output, "mnist")
def testListAllCommand(self):
mod = load_build_module(THIS_DIR / "mnist_builder.py")
out_file = io.StringIO()
- iree_build_main(mod, args=["--list-all"], stdout=out_file)
+ iree_build_main(mod, args=["--list-all"] + DEFAULT_TARGET_ARGS, stdout=out_file)
output = out_file.getvalue().splitlines()
self.assertIn("mnist", output)
self.assertIn("mnist/mnist.onnx", output)
@@ -92,11 +99,24 @@
args=[
"--mnist-onnx-url",
"https://github.com/iree-org/doesnotexist",
- ],
+ ]
+ + DEFAULT_TARGET_ARGS,
stdout=out_file,
stderr=err_file,
)
+ def testBuildNonDefaultSubTarget(self):
+ mod = load_build_module(THIS_DIR / "mnist_builder.py")
+ out_file = io.StringIO()
+ iree_build_main(
+ mod, args=["mnist/mnist.mlir"] + DEFAULT_TARGET_ARGS, stdout=out_file
+ )
+ output = out_file.getvalue().strip()
+ self.assertIn("genfiles/mnist/mnist.mlir", output)
+ output_path = Path(output)
+ contents = output_path.read_text()
+ self.assertIn("module", contents)
+
if __name__ == "__main__":
try: