[python] Add an iree.build package with API/tooling for program building. (#18630)
This is the first step of providing a unified user-oriented build tool
for IREE export and compilation. This initial version supports the CLI
environment, network/fetch actions, and ONNX import+upgrade. A follow-on
step will build out the compiler invocation rules, specifically focusing
on getting option, device, and parameter handling and manifest/metadata
propagated properly.
The intent is that downstream users of the IREE/Turbine tools can have
one-stop modules to efficiently export and compile arbitrarily
complicated pipelines without reinventing the wheel or needing to do the
advanced compiler juggling that comes with some of the more complicated
(and performant) workflows.
As a by-product this will standardize a directory layout and
manifest/metadata to accompany it for pipeline construction. The intent
is that this will simplify serving tooling, which currently has to get a
number of out-of-band parameters and files put together in the right way
and matching how they were compiled.
---------
Signed-off-by: Stella Laurenzo <stellaraccident@gmail.com>
diff --git a/.gitignore b/.gitignore
index b61afb4..58f7831 100644
--- a/.gitignore
+++ b/.gitignore
@@ -22,6 +22,8 @@
build/
build-*/
Testing/
+# Include iree.build package (its directory name would otherwise match the
+# `build/` ignore pattern above).
+!compiler/bindings/python/iree/build/
# Bazel artifacts
**/bazel-*
diff --git a/compiler/bindings/python/CMakeLists.txt b/compiler/bindings/python/CMakeLists.txt
index 0a0b364..bc8119f 100644
--- a/compiler/bindings/python/CMakeLists.txt
+++ b/compiler/bindings/python/CMakeLists.txt
@@ -235,6 +235,37 @@
################################################################################
+# iree.build package
+# This is a pure Python part of the namespace, not rooted under iree.compiler
+# like the above. It is only using the same build support for compatibility
+# with the existing development flow.
+# If the build system for Python code is ever redone, this can just be
+# a source namespace in the project definition.
+################################################################################
+
+# The iree.build package.
+# Declares the pure Python sources that make up iree.build, rooted at iree/build.
+declare_mlir_python_sources(IREECompilerBuildPythonPackage
+ROOT_DIR "${CMAKE_CURRENT_SOURCE_DIR}/iree/build"
+SOURCES
+ __init__.py
+ __main__.py
+ executor.py
+ lang.py
+ main.py
+ net_actions.py
+ onnx_actions.py
+)
+
+# Places the declared sources under the iree/build prefix of both the build
+# and install Python trees.
+add_mlir_python_modules(IREECompilerBuildPythonModules
+ ROOT_PREFIX "${_PYTHON_BUILD_PREFIX}/iree/build"
+ INSTALL_PREFIX "${_PYTHON_INSTALL_PREFIX}/iree/build"
+ DECLARED_SOURCES
+ IREECompilerBuildPythonPackage
+)
+
+# Building the main compiler Python modules target also builds iree.build.
+add_dependencies(IREECompilerPythonModules IREECompilerBuildPythonModules)
+
+################################################################################
# Tools linked against the shared CAPI library
################################################################################
diff --git a/compiler/bindings/python/iree/build/__init__.py b/compiler/bindings/python/iree/build/__init__.py
new file mode 100644
index 0000000..95ee3f7
--- /dev/null
+++ b/compiler/bindings/python/iree/build/__init__.py
@@ -0,0 +1,12 @@
+# Copyright 2024 The IREE Authors
+#
+# Licensed under the Apache License v2.0 with LLVM Exceptions.
+# See https://llvm.org/LICENSE.txt for license information.
+# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+
+import argparse
+
+from iree.build.lang import *
+from iree.build.main import *
+from iree.build.net_actions import *
+from iree.build.onnx_actions import *
diff --git a/compiler/bindings/python/iree/build/__main__.py b/compiler/bindings/python/iree/build/__main__.py
new file mode 100644
index 0000000..bbaa9e1
--- /dev/null
+++ b/compiler/bindings/python/iree/build/__main__.py
@@ -0,0 +1,11 @@
+# Copyright 2024 The IREE Authors
+#
+# Licensed under the Apache License v2.0 with LLVM Exceptions.
+# See https://llvm.org/LICENSE.txt for license information.
+# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+
+from .main import CliMain
+
+
+# Entry point for `python -m iree.build`: no module is passed, so CliMain
+# resolves the build module from the command line arguments.
+if __name__ == "__main__":
+ CliMain().run()
diff --git a/compiler/bindings/python/iree/build/executor.py b/compiler/bindings/python/iree/build/executor.py
new file mode 100644
index 0000000..35b9173
--- /dev/null
+++ b/compiler/bindings/python/iree/build/executor.py
@@ -0,0 +1,517 @@
+# Copyright 2024 The IREE Authors
+#
+# Licensed under the Apache License v2.0 with LLVM Exceptions.
+# See https://llvm.org/LICENSE.txt for license information.
+# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+
+from typing import Callable, Collection, Generator, IO
+
+import abc
+import argparse
+import concurrent.futures
+import enum
+import inspect
+import multiprocessing
+import sys
+import time
+import traceback
+from pathlib import Path
+import threading
+
+_locals = threading.local()
+
+
+class FileNamespace(enum.StrEnum):
+ """Namespaces for build files; each maps to its own root directory."""
+
+ # Transient generated files go into the GEN namespace. These are typically
+ # not packaged for distribution.
+ GEN = enum.auto()
+
+ # Distributable parameter files.
+ PARAMS = enum.auto()
+
+ # Distributable, platform-neutral binaries.
+ BIN = enum.auto()
+
+ # Distributable, platform specific binaries.
+ PLATFORM_BIN = enum.auto()
+
+
+# Maps each FileNamespace to a callable that, given an executor, returns the
+# root directory for that namespace beneath the executor's output_dir.
+FileNamespaceToPath = {
+ FileNamespace.GEN: lambda executor: executor.output_dir / "genfiles",
+ FileNamespace.PARAMS: lambda executor: executor.output_dir / "params",
+ FileNamespace.BIN: lambda executor: executor.output_dir / "bin",
+ # TODO: This isn't right. Need to resolve platform dynamically.
+ FileNamespace.PLATFORM_BIN: lambda executor: executor.output_dir / "platform",
+}
+
+
+def join_namespace(prefix: str, suffix: str) -> str:
+ """Concatenates `prefix` and `suffix` with a '/' separator.
+
+ An empty `prefix` denotes the root namespace, in which case `suffix` is
+ returned unchanged.
+ """
+ return f"{prefix}/{suffix}" if prefix else suffix
+
+
+class ClArg:
+ """Describes a command line argument that entrypoints declare and resolve.
+
+ `name` is the flag name without leading dashes, `dest` is the attribute
+ name used in the parsed namespace, and all remaining keywords are
+ forwarded verbatim to `add_argument()`.
+ """
+
+ def __init__(self, name, dest: str, **add_argument_kw):
+ self.name = name
+ self.dest = dest
+ self.add_argument_kw = add_argument_kw
+
+ def define_arg(self, parser: argparse.ArgumentParser):
+ """Registers this argument on `parser` as the `--{name}` option."""
+ parser.add_argument(f"--{self.name}", dest=self.dest, **self.add_argument_kw)
+
+ def resolve(self, arg_namespace: argparse.Namespace):
+ """Returns the parsed value stored under `dest` in `arg_namespace`.
+
+ Raises RuntimeError if the namespace has no such attribute.
+ """
+ try:
+ return getattr(arg_namespace, self.dest)
+ except AttributeError as e:
+ raise RuntimeError(
+ f"Unable to resolve command line argument '{self.dest}' in namespace"
+ ) from e
+
+
+class Entrypoint:
+ """A named, callable wrapper around a user function defining a build graph.
+
+ Calling it evaluates the wrapped function inside a new BuildEntrypoint
+ context nested under the currently active BuildContext.
+ """
+
+ def __init__(
+ self,
+ name: str,
+ wrapped: Callable,
+ description: str | None = None,
+ ):
+ self.name = name
+ self.description = description
+ self._wrapped = wrapped
+
+ def cl_args(self) -> Generator[ClArg, None, None]:
+ """Yields every ClArg used as a parameter default of the wrapped function."""
+ sig = inspect.signature(self._wrapped)
+ for p in sig.parameters.values():
+ def_value = p.default
+ if isinstance(def_value, ClArg):
+ yield def_value
+
+ def __call__(self, *args, **kwargs):
+ """Evaluates the wrapped function within a fresh BuildEntrypoint context.
+
+ Requires an active BuildContext. Any argument whose bound value is a
+ ClArg is replaced by its value from the executor's args namespace. A
+ non-None result is resolved to BuildFiles, which are recorded as both
+ deps and outputs of the entrypoint context.
+ """
+ parent_context = BuildContext.current()
+ bep = BuildEntrypoint(
+ join_namespace(parent_context.path, self.name),
+ parent_context.executor,
+ self,
+ )
+ parent_context.executor.entrypoints.append(bep)
+ with bep:
+ sig = inspect.signature(self._wrapped)
+ bound = sig.bind(*args, **kwargs)
+ # Fill in parameter defaults so that ClArg defaults become visible.
+ bound.apply_defaults()
+
+ # NOTE: this local helper shadows the `filter` builtin within __call__.
+ def filter(arg):
+ if isinstance(arg, ClArg):
+ return arg.resolve(parent_context.executor.args_namespace)
+ return arg
+
+ new_args = [filter(arg) for arg in bound.args]
+ new_kwargs = {k: filter(v) for k, v in bound.kwargs.items()}
+ results = self._wrapped(*new_args, **new_kwargs)
+ if results is not None:
+ # Results may be a single path/BuildFile or a collection of them.
+ files = bep.files(results)
+ bep.deps.update(files)
+ bep.outputs.extend(files)
+ return files
+
+
+class Executor:
+ """Executor that all build contexts share."""
+
+ def __init__(
+ self, output_dir: Path, args_namespace: argparse.Namespace, stderr: IO
+ ):
+ self.output_dir = output_dir
+ self.verbose_level = 0
+ # Keyed by path
+ self.all: dict[str, "BuildContext" | "BuildFile"] = {}
+ self.entrypoints: list["BuildEntrypoint"] = []
+ self.args_namespace = args_namespace
+ self.stderr = stderr
+ # Constructing a BuildContext registers it in self.all; this creates the
+ # root context at the empty path.
+ BuildContext("", self)
+
+ def check_path_not_exists(self, path: str, for_entity):
+ """Raises RuntimeError if an entity is already registered at `path`.
+
+ The error message includes the definition stack of the existing entity.
+ """
+ existing = self.all.get(path)
+ if existing is not None:
+ formatted_stack = "".join(traceback.format_list(existing.def_stack))
+ raise RuntimeError(
+ f"Cannot add {for_entity} because an entity with that name was "
+ f"already defined at:\n{formatted_stack}"
+ )
+
+ def get_context(self, path: str) -> "BuildContext":
+ """Returns the BuildContext at `path`; raises RuntimeError if missing or not a context."""
+ existing = self.all.get(path)
+ if existing is None:
+ raise RuntimeError(f"Context at path {path} not found")
+ if not isinstance(existing, BuildContext):
+ raise RuntimeError(
+ f"Entity at path {path} is not a context. It is: {existing}"
+ )
+ return existing
+
+ def get_file(self, path: str) -> "BuildFile":
+ """Returns the BuildFile at `path`; raises RuntimeError if missing or not a file."""
+ existing = self.all.get(path)
+ if existing is None:
+ raise RuntimeError(f"File at path {path} not found")
+ if not isinstance(existing, BuildFile):
+ raise RuntimeError(
+ f"Entity at path {path} is not a file. It is: {existing}"
+ )
+ return existing
+
+ def write_status(self, message: str):
+ """Prints a status line to the executor's stderr stream."""
+ print(message, file=self.stderr)
+
+ def get_root(self, namespace: FileNamespace) -> Path:
+ """Returns the root directory for files in `namespace`."""
+ return FileNamespaceToPath[namespace](self)
+
+ def analyze(self, *entrypoints: Entrypoint):
+ """Analyzes all entrypoints building the graph."""
+ for entrypoint in entrypoints:
+ if self.verbose_level > 1:
+ self.write_status(f"Analyzing entrypoint {entrypoint.name}")
+ # Entrypoints are called without arguments under the root context, so
+ # every parameter takes its default (ClArg defaults are resolved
+ # by Entrypoint.__call__).
+ with self.get_context("") as context:
+ entrypoint()
+
+ def build(self, *initial_deps: "BuildDependency"):
+ """Transitively builds the given deps."""
+ scheduler = Scheduler(stderr=self.stderr)
+ success = False
+ try:
+ for d in initial_deps:
+ scheduler.add_initial_dep(d)
+ scheduler.build()
+ success = True
+ finally:
+ if not success:
+ print("Waiting for background tasks to complete...", file=self.stderr)
+ scheduler.shutdown()
+
+
+class BuildDependency:
+ """Base class of entities that can act as a build dependency."""
+
+ def __init__(
+ self, *, executor: Executor, deps: set["BuildDependency"] | None = None
+ ):
+ self.executor = executor
+ # Dependencies that this entity depends on.
+ self.deps: set[BuildDependency] = set()
+ if deps:
+ self.deps.update(deps)
+
+ # Scheduling state.
+ self.future: concurrent.futures.Future | None = None
+ self.start_time: float | None = None
+ self.finish_time: float | None = None
+
+ @property
+ def is_scheduled(self) -> bool:
+ """True once a future has been attached via `start()`."""
+ return self.future is not None
+
+ @property
+ def execution_time(self) -> float:
+ """Seconds spent so far: 0.0 if unstarted, elapsed time if still running."""
+ if self.start_time is None:
+ return 0.0
+ if self.finish_time is None:
+ return time.time() - self.start_time
+ return self.finish_time - self.start_time
+
+ def start(self, future: concurrent.futures.Future):
+ """Attaches `future` and records the start time. May only be called once."""
+ assert not self.is_scheduled, f"Cannot start an already scheduled dep: {self}"
+ self.future = future
+ self.start_time = time.time()
+
+ def finish(self):
+ """Records the finish time and resolves the future with this dep as its result."""
+ assert self.is_scheduled, "Cannot finish an unstarted dep"
+ self.finish_time = time.time()
+ self.future.set_result(self)
+
+
+class BuildFile(BuildDependency):
+ """Generated file in the build tree."""
+
+ def __init__(
+ self,
+ *,
+ executor: Executor,
+ path: str,
+ namespace: FileNamespace = FileNamespace.GEN,
+ deps: set[BuildDependency] | None = None,
+ ):
+ super().__init__(executor=executor, deps=deps)
+ # Stack at the definition site (innermost two frames dropped); reported by
+ # Executor.check_path_not_exists when a later entity reuses this path.
+ self.def_stack = traceback.extract_stack()[0:-2]
+ self.executor = executor
+ self.path = path
+ self.namespace = namespace
+ # Set of build files that must be made available to any transitive user
+ # of this build file at runtime.
+ self.runfiles: set["BuildFile"] = set()
+
+ # Register under the path, rejecting duplicates.
+ executor.check_path_not_exists(path, self)
+ executor.all[path] = self
+
+ def get_fs_path(self) -> Path:
+ """Returns the filesystem path of this file under its namespace root.
+
+ Creates the parent directory as a side effect if it does not exist.
+ """
+ path = self.executor.get_root(self.namespace) / self.path
+ path.parent.mkdir(parents=True, exist_ok=True)
+ return path
+
+ def __repr__(self):
+ return f"BuildFile[{self.namespace}]({self.path})"
+
+
+class ActionConcurrency(enum.StrEnum):
+ """Concurrency modes that a BuildAction can be declared with."""
+
+ THREAD = enum.auto()
+ PROCESS = enum.auto()
+ NONE = enum.auto()
+
+
+class BuildAction(BuildDependency, abc.ABC):
+ """An action that must be carried out.
+
+ Subclasses implement `invoke()` to perform the work. `desc` is the human
+ readable description returned by `__str__` and embedded in `__repr__`;
+ `concurrency` records the ActionConcurrency requested for the action.
+ """
+
+ def __init__(
+ self,
+ *,
+ desc: str,
+ executor: Executor,
+ concurrency: ActionConcurrency = ActionConcurrency.THREAD,
+ deps: set[BuildDependency] | None = None,
+ ):
+ super().__init__(executor=executor, deps=deps)
+ self.desc = desc
+ self.concurrency = concurrency
+
+ def __str__(self):
+ return self.desc
+
+ def __repr__(self):
+ return f"Action[{type(self).__name__}]('{self.desc}')"
+
+ @abc.abstractmethod
+ def invoke(self):
+ """Performs the action. Must be implemented by subclasses."""
+ ...
+
+
+class BuildContext(BuildDependency):
+ """Manages a build graph under construction."""
+
+ def __init__(self, path: str, executor: Executor):
+ super().__init__(executor=executor)
+ # Stack at the definition site, used when reporting duplicate paths.
+ self.def_stack = traceback.extract_stack()[0:-2]
+ self.executor = executor
+ self.path = path
+ executor.check_path_not_exists(path, self)
+ executor.all[path] = self
+ self.analyzed = False
+
+ def __repr__(self):
+ return f"{type(self).__name__}(path='{self.path}')"
+
+ def allocate_file(
+ self, path: str, namespace: FileNamespace = FileNamespace.GEN
+ ) -> BuildFile:
+ """Allocates a file in the build tree with local path |path|.
+
+ If |path| is absolute (starts with '/'), then it is used as-is. Otherwise,
+ it is joined with the path of this context.
+ """
+ if not path.startswith("/"):
+ path = join_namespace(self.path, path)
+ return BuildFile(executor=self.executor, path=path, namespace=namespace)
+
+ def file(self, file: str | BuildFile) -> BuildFile:
+ """Accesses a BuildFile by either string (path) or BuildFile.
+
+ It must already exist.
+ """
+ if isinstance(file, BuildFile):
+ return file
+ path = file
+ # Relative paths are resolved against this context's path.
+ if not path.startswith("/"):
+ path = join_namespace(self.path, path)
+ existing = self.executor.all.get(path)
+ if not isinstance(existing, BuildFile):
+ # List every known file path in the error to aid debugging.
+ all_files = [
+ f.path for f in self.executor.all.values() if isinstance(f, BuildFile)
+ ]
+ all_files_lines = "\n ".join(all_files)
+ raise RuntimeError(
+ f"File with path '{path}' is not known in the build graph. Available:\n"
+ f" {all_files_lines}"
+ )
+ return existing
+
+ def files(
+ self, files: str | BuildFile | Collection[str | BuildFile]
+ ) -> list[BuildFile]:
+ """Accesses a collection of files (or single) as a list of BuildFiles."""
+ if isinstance(files, (str, BuildFile)):
+ return [self.file(files)]
+ return [self.file(f) for f in files]
+
+ @staticmethod
+ def current() -> "BuildContext":
+ """Returns the innermost context entered on the calling thread.
+
+ Raises RuntimeError if no context is active on this thread.
+ """
+ try:
+ return _locals.context_stack[-1]
+ except (AttributeError, IndexError):
+ raise RuntimeError(
+ "The current code can only be evaluated within an active BuildContext"
+ )
+
+ def __enter__(self) -> "BuildContext":
+ # The context stack is thread-local and created lazily on first use.
+ try:
+ stack = _locals.context_stack
+ except AttributeError:
+ stack = _locals.context_stack = []
+ stack.append(self)
+ return self
+
+ def __exit__(self, *args):
+ try:
+ stack = _locals.context_stack
+ except AttributeError:
+ raise AssertionError("BuildContext exit without enter")
+ existing = stack.pop()
+ assert existing is self, "Unbalanced BuildContext enter/exit"
+
+ def populate_arg_parser(self, parser: argparse.ArgumentParser):
+ # No-op placeholder in this class.
+ ...
+
+
+class BuildEntrypoint(BuildContext):
+ """BuildContext created for one evaluation of an Entrypoint."""
+
+ def __init__(self, path: str, executor: Executor, entrypoint: Entrypoint):
+ super().__init__(path, executor)
+ self.entrypoint = entrypoint
+ # Files returned by the entrypoint function (filled in by Entrypoint.__call__).
+ self.outputs: list[BuildFile] = []
+
+
+class Scheduler:
+ """Holds resources related to scheduling."""
+
+ def __init__(self, stderr: IO):
+ self.stderr = stderr
+
+ # Inverted producer-consumer graph nodes mapping a producer dep to
+ # all deps which directly depend on it and will be unblocked by it
+ # being satisfied.
+ self.producer_graph: dict[BuildDependency, list[BuildDependency]] = {}
+
+ # Set of build dependencies that have been scheduled. These will all
+ # have a future set on them prior to adding to the set.
+ self.in_flight_deps: set[BuildDependency] = set()
+
+ self.thread_pool_executor = concurrent.futures.ThreadPoolExecutor(
+ max_workers=10, thread_name_prefix="iree.build"
+ )
+ self.process_pool_executor = concurrent.futures.ProcessPoolExecutor(
+ max_workers=10, mp_context=multiprocessing.get_context("spawn")
+ )
+
+ def shutdown(self):
+ """Shuts down both pools, cancelling futures that have not started."""
+ self.thread_pool_executor.shutdown(cancel_futures=True)
+ self.process_pool_executor.shutdown(cancel_futures=True)
+
+ def add_initial_dep(self, initial_dep: BuildDependency):
+ """Adds a requested dep and, transitively, everything it depends on."""
+ assert isinstance(initial_dep, BuildDependency)
+ if initial_dep in self.producer_graph:
+ # Already in the graph.
+ return
+
+ # At this point nothing depends on this initial dep, so just note it
+ # as producing nothing.
+ self.producer_graph[initial_dep] = []
+
+ # Adds a dep requested by some top-level caller.
+ stack: set[BuildDependency] = set()
+ stack.add(initial_dep)
+ for producer_dep in initial_dep.deps:
+ self._add_dep(producer_dep, initial_dep, stack)
+
+ def _add_dep(
+ self,
+ producer_dep: BuildDependency,
+ consumer_dep: BuildDependency,
+ stack: set[BuildDependency],
+ ):
+ """Records that `consumer_dep` is unblocked by `producer_dep`, recursively.
+
+ `stack` holds the deps on the path from the initial dep to the consumer;
+ encountering the producer in it raises RuntimeError (circular dependency).
+ """
+ if producer_dep in stack:
+ raise RuntimeError(
+ f"Circular dependency: '{producer_dep}' depends on itself: {stack}"
+ )
+ plist = self.producer_graph.get(producer_dep)
+ if plist is None:
+ plist = []
+ self.producer_graph[producer_dep] = plist
+ plist.append(consumer_dep)
+ # Copy so that sibling branches do not see each other's path entries.
+ next_stack = set(stack)
+ next_stack.add(producer_dep)
+ if producer_dep.deps:
+ # Intermediate dep.
+ for next_dep in producer_dep.deps:
+ self._add_dep(next_dep, producer_dep, next_stack)
+
+ def build(self):
+ """Runs the graph to completion, servicing it until no producers remain."""
+ # Build all deps until the graph is satisfied.
+ # Schedule any deps that have no dependencies to start things off.
+ for eligible_dep in self.producer_graph.keys():
+ if len(eligible_dep.deps) == 0:
+ self._schedule_action(eligible_dep)
+ self.in_flight_deps.add(eligible_dep)
+
+ while self.producer_graph:
+ print(
+ f"Servicing {len(self.producer_graph)} outstanding tasks",
+ file=self.stderr,
+ )
+ self._service_graph()
+
+ def _service_graph(self):
+ """Performs one round of servicing the producer graph.
+
+ Harvests in-flight deps whose futures have completed, removes them from
+ the producer graph, schedules consumers whose deps have all finished,
+ and then blocks until at least one in-flight future completes.
+
+ An exception raised by a completed action propagates from here.
+ """
+ completed_deps: set[BuildDependency] = set()
+ # Non-blocking poll (timeout=0) for what is already done. Results are
+ # read outside of any exception handler so that an action failing with
+ # TimeoutError is reported rather than mistaken for "nothing ready".
+ done_futures, _ = concurrent.futures.wait(
+ (d.future for d in self.in_flight_deps), timeout=0
+ )
+ for completed_fut in done_futures:
+ completed_dep = completed_fut.result()
+ assert isinstance(completed_dep, BuildDependency)
+ print(f"Completed {completed_dep}", file=self.stderr)
+ completed_deps.add(completed_dep)
+
+ # Purge done from in-flight list.
+ self.in_flight_deps.difference_update(completed_deps)
+
+ # Schedule any available.
+ for completed_dep in completed_deps:
+ ready_list = self.producer_graph.get(completed_dep)
+ if ready_list is None:
+ continue
+ del self.producer_graph[completed_dep]
+ for ready_dep in ready_list:
+ # A consumer may have several producers: only schedule it once
+ # every one of them has finished. It is listed under each of its
+ # producers, so it is revisited when the remaining ones complete.
+ if not all(
+ d.future is not None and d.future.done() for d in ready_dep.deps
+ ):
+ continue
+ self._schedule_action(ready_dep)
+ self.in_flight_deps.add(ready_dep)
+
+ # Do a blocking wait for at least one ready.
+ concurrent.futures.wait(
+ (d.future for d in self.in_flight_deps),
+ return_when=concurrent.futures.FIRST_COMPLETED,
+ )
+
+ def _schedule_action(self, dep: BuildDependency):
+ """Starts `dep` unless it has already been scheduled.
+
+ BuildActions are submitted to the thread pool; any other kind of dep has
+ nothing to run and is completed immediately.
+ """
+ if dep.is_scheduled:
+ return
+ if isinstance(dep, BuildAction):
+
+ # The future must resolve to the dep itself (see _service_graph).
+ def invoke():
+ dep.invoke()
+ return dep
+
+ print(f"Scheduling action: {dep}", file=self.stderr)
+ # NOTE: the action's requested concurrency is not consulted here; every
+ # action is run on the thread pool.
+ dep.start(self.thread_pool_executor.submit(invoke))
+ else:
+ # Not schedulable. Just mark it as done.
+ dep.start(concurrent.futures.Future())
+ dep.finish()
+
+
+# Type aliases.
+BuildFileLike = BuildFile | str
diff --git a/compiler/bindings/python/iree/build/lang.py b/compiler/bindings/python/iree/build/lang.py
new file mode 100644
index 0000000..5cb8779
--- /dev/null
+++ b/compiler/bindings/python/iree/build/lang.py
@@ -0,0 +1,52 @@
+# Copyright 2024 The IREE Authors
+#
+# Licensed under the Apache License v2.0 with LLVM Exceptions.
+# See https://llvm.org/LICENSE.txt for license information.
+# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+
+from typing import Callable
+
+import argparse
+import functools
+
+from iree.build.executor import ClArg, Entrypoint
+
+__all__ = [
+ "cl_arg",
+ "entrypoint",
+]
+
+
+def entrypoint(
+ f=None,
+ *,
+ description: str | None = None,
+):
+ """Function decorator to turn it into a build entrypoint.
+
+ Usable bare (`@entrypoint`) or with keywords
+ (`@entrypoint(description=...)`). Returns an `Entrypoint` named after the
+ decorated function.
+ """
+ if f is None:
+ # Called with keywords only: return a decorator awaiting the function.
+ return functools.partial(entrypoint, description=description)
+ target = Entrypoint(f.__name__, f, description=description)
+ # Copy __name__, __doc__, etc. from the function onto the Entrypoint.
+ # (functools.wraps merely builds a decorator and has no effect when its
+ # result is discarded, so update_wrapper is the call needed here.)
+ functools.update_wrapper(target, f)
+ return target
+
+
+def cl_arg(name: str, *, action=None, default=None, type=None, help=None):
+ """Used to define or reference a command-line argument from within actions
+ and entry-points.
+
+ Keywords have the same interpretation as `ArgumentParser.add_argument()`.
+ Keywords left as None are not forwarded to `add_argument()`.
+
+ Any ClArg set as a default value for an argument to an `entrypoint` will be
+ added to the global argument parser. Any particular argument name can only be
+ registered once and must not conflict with a built-in command line option.
+ The implication of this is that for single-use arguments, the `=cl_arg(...)`
+ can just be added as a default argument. Otherwise, for shared arguments,
+ it should be created at the module level and referenced.
+
+ When called, any entrypoint arguments that do not have an explicit keyword
+ set will get their value from the command line environment.
+
+ Raises ValueError if `name` starts with a dash.
+ """
+ if name.startswith("-"):
+ raise ValueError("cl_arg name must not be prefixed with dashes")
+ dest = name.replace("-", "_")
+ # Only forward keywords that were actually provided: argparse actions such
+ # as "store_true" reject keywords they do not declare (e.g. `type`), and an
+ # explicit default=None would override an action's own default.
+ provided = {"action": action, "default": default, "type": type, "help": help}
+ add_argument_kw = {k: v for k, v in provided.items() if v is not None}
+ return ClArg(name, dest=dest, **add_argument_kw)
diff --git a/compiler/bindings/python/iree/build/main.py b/compiler/bindings/python/iree/build/main.py
new file mode 100644
index 0000000..272533c
--- /dev/null
+++ b/compiler/bindings/python/iree/build/main.py
@@ -0,0 +1,244 @@
+# Copyright 2024 The IREE Authors
+#
+# Licensed under the Apache License v2.0 with LLVM Exceptions.
+# See https://llvm.org/LICENSE.txt for license information.
+# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+
+from typing import Any, IO
+
+import argparse
+import importlib
+import importlib.util
+from pathlib import Path
+import sys
+
+from iree.build.executor import BuildEntrypoint, Entrypoint, Executor
+
+__all__ = [
+ "iree_build_main",
+ "load_build_module",
+]
+
+
+def iree_build_main(
+ module="__main__",
+ args: list[str] | None = None,
+ stdout: IO | None = None,
+ stderr: IO | None = None,
+):
+ """Runs the iree.build command line driver against `module`.
+
+ Intended to be called from the bottom of a module that declares build
+ entrypoints so that the module builds itself when executed:
+
+ .. code-block:: python
+ from iree.build import *
+
+ if __name__ == "__main__":
+ iree_build_main()
+
+ All arguments are forwarded to `CliMain`.
+ """
+ CliMain(module=module, args=args, stdout=stdout, stderr=stderr).run()
+
+
+def load_build_module(module_path: Path | str):
+ """Loads a build module by path, evaling and returning it.
+
+ The file is executed as a module named `__iree_build__`.
+
+ Raises ImportError if no import spec or loader can be determined for
+ `module_path` (for example, an unrecognized file extension).
+ """
+ spec = importlib.util.spec_from_file_location("__iree_build__", module_path)
+ # spec_from_file_location returns None when no loader handles the path.
+ if spec is None or spec.loader is None:
+ raise ImportError(f"Cannot load build module from path '{module_path}'")
+ mod = importlib.util.module_from_spec(spec)
+ spec.loader.exec_module(mod)
+ return mod
+
+
+class CliMain:
+ """Composes command line programs."""
+
+ def __init__(
+ self,
+ *,
+ args: list[str] | None = None,
+ module=None,
+ stdout: IO | None = None,
+ stderr: IO | None = None,
+ ):
+ """Builds the argument parser and parses `args`.
+
+ Args:
+ args: Command line arguments; defaults to `sys.argv[1:]`.
+ module: Module object or importable module name holding the build
+ entrypoints. If None, the module is resolved from leading
+ command line arguments.
+ stdout: Stream for command output; defaults to `sys.stdout`.
+ stderr: Stream for status and errors; defaults to `sys.stderr`.
+ """
+ self.stdout = stdout if stdout is not None else sys.stdout
+ self.stderr = stderr if stderr is not None else sys.stderr
+ if args is None:
+ args = sys.argv[1:]
+ if isinstance(module, str):
+ # import_module returns the named (sub)module itself, whereas
+ # __import__("a.b") would return the top-level package "a".
+ module = importlib.import_module(module)
+
+ p = self.arg_parser = argparse.ArgumentParser(
+ description="IREE program build driver"
+ )
+ if module is None:
+ # Consumes the module-selection arguments; the rest are parsed below.
+ args, self.top_module = self._resolve_module_arguments(args)
+ else:
+ self.top_module = ModuleWrapper(module)
+
+ p.add_argument(
+ "--output-dir",
+ type=Path,
+ default=Path.cwd(),
+ help="Output directory for the build tree (defaults to current directory)",
+ )
+
+ cmd_group_desc = p.add_argument_group(
+ title="Build command",
+ description="Selects a build sub-command to invoke (default '--build')",
+ )
+ # Each command flag stores the bound method to run in `args.command`.
+ cmd_group = cmd_group_desc.add_mutually_exclusive_group()
+ cmd_group.add_argument(
+ "--build",
+ dest="command",
+ action="store_const",
+ const=self.build_command,
+ help="Executes build actions",
+ )
+ cmd_group.add_argument(
+ "--list",
+ dest="command",
+ action="store_const",
+ const=self.list_command,
+ help="Lists top level build actions",
+ )
+
+ cmd_group.add_argument(
+ "--list-all",
+ dest="command",
+ action="store_const",
+ const=self.list_all_command,
+ help="Lists all build actions",
+ )
+
+ p.add_argument(
+ "action_path",
+ nargs="*",
+ help="Paths of actions to build (default to top-level actions)",
+ )
+
+ # Options declared by the entrypoints of the loaded module.
+ self._define_action_arguments(p)
+ self.args = self.arg_parser.parse_args(args)
+
+ def abort(self):
+ """Terminates the process with exit status 1."""
+ sys.exit(1)
+
+ def _define_action_arguments(self, p: argparse.ArgumentParser):
+ """Adds the ClArgs declared by all top module entrypoints to `p`."""
+ user_group = p.add_argument_group("Action defined options")
+ for ep in self.top_module.entrypoints.values():
+ for cl_arg in ep.cl_args():
+ cl_arg.define_arg(user_group)
+
+ def _resolve_module_arguments(
+ self, args: list[str]
+ ) -> tuple[list[str], "ModuleWrapper"]:
+ """Parses the build module selection out of `args` and loads it.
+
+ Returns the arguments not consumed here along with the loaded module.
+ """
+ p = argparse.ArgumentParser(
+ add_help=False,
+ usage="python -m iree.build [-m] build_module [... additional module specific options ...]",
+ prog="python -m iree.build",
+ )
+ # Invoked as a standalone tool: need the user to specify the
+ # module.
+ p.add_argument(
+ "-m",
+ dest="parse_as_module",
+ action="store_true",
+ help="Interpret the build definitions argument as a module (vs a file)",
+ )
+ p.add_argument(
+ "build_module",
+ help="The Python file or module from which to load build definitions",
+ )
+
+ # Unrecognized arguments are returned for the main parser to handle.
+ bootstrap_args, rem_args = p.parse_known_args(args)
+ # Resolve from arguments.
+ is_module = bootstrap_args.parse_as_module or _is_module_like_str(
+ bootstrap_args.build_module
+ )
+ if is_module:
+ try:
+ top_module = ModuleWrapper.load_module(bootstrap_args.build_module)
+ except ModuleNotFoundError as e:
+ print(
+ f"ERROR: Module '{bootstrap_args.build_module}' not found: {e}",
+ file=self.stderr,
+ )
+ self.abort()
+ else:
+ top_module = ModuleWrapper.load_py_file(bootstrap_args.build_module)
+ return rem_args, top_module
+
+ def _create_executor(self) -> Executor:
+ """Creates an Executor and analyzes all entrypoints of the top module."""
+ executor = Executor(self.args.output_dir, self.args, stderr=self.stderr)
+ executor.analyze(*self.top_module.entrypoints.values())
+ return executor
+
+ def run(self):
+ """Runs the selected command, defaulting to the build command."""
+ command = self.args.command
+ if command is None:
+ command = self.build_command
+ command()
+
+ def build_command(self):
+ """Builds the requested action paths (all entrypoints if none are given).
+
+ Afterwards prints the filesystem path of each output of every built
+ entrypoint to stdout. An unknown action path is reported on stderr and
+ aborts.
+ """
+ executor = self._create_executor()
+
+ if not self.args.action_path:
+ # Default to all.
+ build_actions = list(executor.entrypoints)
+ else:
+ # Look up each requested and add it.
+ build_actions = []
+ for action_path in self.args.action_path:
+ try:
+ build_actions.append(executor.all[action_path])
+ except KeyError:
+ all_paths = "\n".join(executor.all.keys())
+ print(
+ f"ERROR: Action '{action_path}' not found. Available: \n{all_paths}",
+ file=self.stderr,
+ )
+ self.abort()
+ executor.build(*build_actions)
+
+ for build_action in build_actions:
+ if isinstance(build_action, BuildEntrypoint):
+ for output in build_action.outputs:
+ print(f"{output.get_fs_path()}", file=self.stdout)
+
+ def list_command(self):
+ """Prints the path of every analyzed entrypoint to stdout."""
+ executor = self._create_executor()
+ for ep in executor.entrypoints:
+ print(ep.path, file=self.stdout)
+
+ def list_all_command(self):
+ """Prints every registered path to stdout, skipping the empty root path."""
+ executor = self._create_executor()
+ for name in executor.all.keys():
+ if name:
+ print(name, file=self.stdout)
+
+
+class ModuleWrapper:
+ """Wraps a raw, loaded module with access to discovered details."""
+
+ def __init__(self, mod):
+ self.mod = mod
+ # Entrypoints found in the module, keyed by attribute name.
+ self.entrypoints = self._collect_entrypoints()
+
+ @staticmethod
+ def load_module(module_name: str) -> "ModuleWrapper":
+ """Imports `module_name` and wraps the resulting module."""
+ return ModuleWrapper(importlib.import_module(module_name))
+
+ @staticmethod
+ def load_py_file(module_path: Path | str) -> "ModuleWrapper":
+ """Loads the Python file at `module_path` and wraps the resulting module."""
+ return ModuleWrapper(load_build_module(str(module_path)))
+
+ def _collect_entrypoints(self) -> dict[str, Entrypoint]:
+ """Returns all module attributes that are Entrypoint instances."""
+ results: dict[str, Entrypoint] = {}
+ for attr_name, attr_value in self.mod.__dict__.items():
+ if isinstance(attr_value, Entrypoint):
+ results[attr_name] = attr_value
+ return results
+
+
+def _is_module_like_str(s: str) -> bool:
+ """Returns True if `s` contains no path separator and has no `.py` suffix."""
+ return not ("/" in s or "\\" in s or s.endswith(".py"))
diff --git a/compiler/bindings/python/iree/build/net_actions.py b/compiler/bindings/python/iree/build/net_actions.py
new file mode 100644
index 0000000..1d2e158
--- /dev/null
+++ b/compiler/bindings/python/iree/build/net_actions.py
@@ -0,0 +1,39 @@
+# Copyright 2024 The IREE Authors
+#
+# Licensed under the Apache License v2.0 with LLVM Exceptions.
+# See https://llvm.org/LICENSE.txt for license information.
+# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+
+import urllib.error
+import urllib.request
+
+from iree.build.executor import BuildAction, BuildContext, BuildFile
+
+__all__ = [
+ "fetch_http",
+]
+
+
+def fetch_http(*, name: str, url: str) -> BuildFile:
+ """Declares a file `name` in the current context that is fetched from `url`.
+
+ Must be called within an active BuildContext. Returns the allocated
+ BuildFile, which depends on the fetch action.
+ """
+ context = BuildContext.current()
+ output_file = context.allocate_file(name)
+ action = FetchHttpAction(
+ url=url, output_file=output_file, desc=f"Fetch {url}", executor=context.executor
+ )
+ output_file.deps.add(action)
+ return output_file
+
+
+class FetchHttpAction(BuildAction):
+ """Action that downloads `url` to the path of `output_file`."""
+
+ def __init__(self, url: str, output_file: BuildFile, **kwargs):
+ super().__init__(**kwargs)
+ self.url = url
+ self.output_file = output_file
+
+ def invoke(self):
+ """Retrieves the URL into the output file.
+
+ Raises IOError if the fetch fails with an HTTPError.
+ """
+ path = self.output_file.get_fs_path()
+ self.executor.write_status(f"Fetching URL: {self.url} -> {path}")
+ try:
+ urllib.request.urlretrieve(self.url, str(path))
+ except urllib.error.HTTPError as e:
+ # `from None` drops the original traceback chain from the report.
+ raise IOError(f"Failed to fetch URL '{self.url}': {e}") from None
diff --git a/compiler/bindings/python/iree/build/onnx_actions.py b/compiler/bindings/python/iree/build/onnx_actions.py
new file mode 100644
index 0000000..5220890
--- /dev/null
+++ b/compiler/bindings/python/iree/build/onnx_actions.py
@@ -0,0 +1,90 @@
+# Copyright 2024 The IREE Authors
+#
+# Licensed under the Apache License v2.0 with LLVM Exceptions.
+# See https://llvm.org/LICENSE.txt for license information.
+# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+
+from iree.build.executor import BuildAction, BuildContext, BuildFile, BuildFileLike
+
+__all__ = [
+ "onnx_import",
+]
+
+
+def onnx_import(
+ *,
+ # Name of the rule and output of the final artifact.
+ name: str,
+ # Source onnx file.
+ source: BuildFileLike,
+ upgrade: bool = True,
+) -> BuildFile:
+ """Declares actions that import an ONNX file into the build file `name`.
+
+ Must be called within an active BuildContext, and `source` must already
+ exist in the build graph. When `upgrade` is true, the source is first run
+ through an upgrade action producing `{name}__upgrade.onnx`, and the import
+ consumes that file instead of the original.
+
+ Returns the BuildFile for the imported output.
+ """
+ context = BuildContext.current()
+ input_file = context.file(source)
+ output_file = context.allocate_file(name)
+
+ # Chain through an upgrade if requested.
+ if upgrade:
+ processed_file = context.allocate_file(f"{name}__upgrade.onnx")
+ UpgradeOnnxAction(
+ input_file=input_file,
+ output_file=processed_file,
+ executor=context.executor,
+ desc=f"Upgrading ONNX {name}",
+ deps=[
+ input_file,
+ ],
+ )
+ input_file = processed_file
+
+ # Import.
+ ImportOnnxAction(
+ input_file=input_file,
+ output_file=output_file,
+ desc=f"Importing ONNX {name}",
+ executor=context.executor,
+ deps=[
+ input_file,
+ ],
+ )
+
+ # `input_file` is the upgraded file when upgrading and the original source
+ # otherwise, so this is well-defined for both values of `upgrade`.
+ output_file.deps.add(input_file)
+ return output_file
+
+
+class UpgradeOnnxAction(BuildAction):
+ """Action that converts an ONNX model file via the ONNX version converter."""
+
+ def __init__(self, input_file: BuildFile, output_file: BuildFile, **kwargs):
+ super().__init__(**kwargs)
+ self.input_file = input_file
+ self.output_file = output_file
+ # The output file is produced by this action.
+ output_file.deps.add(self)
+
+ def invoke(self):
+ # Imported lazily so that onnx is only required when the action runs.
+ import onnx
+
+ input_path = self.input_file.get_fs_path()
+ output_path = self.output_file.get_fs_path()
+
+ original_model = onnx.load_model(str(input_path))
+ # 17 is the target version passed to the converter.
+ converted_model = onnx.version_converter.convert_version(original_model, 17)
+ onnx.save(converted_model, str(output_path))
+
+
+class ImportOnnxAction(BuildAction):
+ """Action that runs the `iree.compiler.tools.import_onnx` tool on a file."""
+
+ def __init__(self, input_file: BuildFile, output_file: BuildFile, **kwargs):
+ super().__init__(**kwargs)
+ self.input_file = input_file
+ self.output_file = output_file
+ # The output file is produced by this action.
+ output_file.deps.add(self)
+
+ def invoke(self):
+ # Imported lazily, when the action runs.
+ import iree.compiler.tools.import_onnx.__main__ as m
+
+ # Equivalent to invoking the tool as: <input> -o <output>
+ args = m.parse_arguments(
+ [
+ str(self.input_file.get_fs_path()),
+ "-o",
+ str(self.output_file.get_fs_path()),
+ ]
+ )
+ m.main(args)
diff --git a/compiler/bindings/python/test/CMakeLists.txt b/compiler/bindings/python/test/CMakeLists.txt
index 6f6cdb9..ca809a8 100644
--- a/compiler/bindings/python/test/CMakeLists.txt
+++ b/compiler/bindings/python/test/CMakeLists.txt
@@ -5,6 +5,7 @@
# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
add_subdirectory(api)
+add_subdirectory(build_api)
add_subdirectory(extras)
add_subdirectory(ir)
add_subdirectory(tools)
diff --git a/compiler/bindings/python/test/build_api/CMakeLists.txt b/compiler/bindings/python/test/build_api/CMakeLists.txt
new file mode 100644
index 0000000..b8bd817
--- /dev/null
+++ b/compiler/bindings/python/test/build_api/CMakeLists.txt
@@ -0,0 +1,15 @@
+# Copyright 2024 The IREE Authors
+#
+# Licensed under the Apache License v2.0 with LLVM Exceptions.
+# See https://llvm.org/LICENSE.txt for license information.
+# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+
+# The mnist builder depends on onnx, which needs the torch input support.
+# The test is therefore only registered when IREE_INPUT_TORCH is enabled.
+if(IREE_INPUT_TORCH)
+ iree_py_test(
+ NAME
+ mnist_builder_test
+ SRCS
+ "mnist_builder_test.py"
+ )
+endif()
diff --git a/compiler/bindings/python/test/build_api/mnist_builder.py b/compiler/bindings/python/test/build_api/mnist_builder.py
new file mode 100644
index 0000000..54d6e30
--- /dev/null
+++ b/compiler/bindings/python/test/build_api/mnist_builder.py
@@ -0,0 +1,30 @@
+# Copyright 2024 The IREE Authors
+#
+# Licensed under the Apache License v2.0 with LLVM Exceptions.
+# See https://llvm.org/LICENSE.txt for license information.
+# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+
+from iree.build import *
+
+
+# Build entrypoint: fetches the mnist ONNX file from `url` and imports it.
+# The URL can be overridden on the command line via `--mnist-onnx-url`.
+# Returns the name of the imported output, "mnist.mlir".
+@entrypoint(description="Compiles an mnist model")
+def mnist(
+    url=cl_arg(
+        "mnist-onnx-url",
+        default="https://github.com/onnx/models/raw/main/validated/vision/classification/mnist/model/mnist-12.onnx",
+        help="URL from which to download mnist",
+    ),
+):
+    onnx_name = "mnist.onnx"
+    mlir_name = "mnist.mlir"
+
+    # Step 1: download the ONNX model.
+    fetch_http(url=url, name=onnx_name)
+    # Step 2: import the downloaded file.
+    onnx_import(source=onnx_name, name=mlir_name)
+    return mlir_name
+
+
+# Lets this file be run directly as a script (`python mnist_builder.py ...`)
+# in addition to being loaded through `python -m iree.build`.
+if __name__ == "__main__":
+ iree_build_main()
diff --git a/compiler/bindings/python/test/build_api/mnist_builder_test.py b/compiler/bindings/python/test/build_api/mnist_builder_test.py
new file mode 100644
index 0000000..02a9140
--- /dev/null
+++ b/compiler/bindings/python/test/build_api/mnist_builder_test.py
@@ -0,0 +1,108 @@
+# Copyright 2024 The IREE Authors
+#
+# Licensed under the Apache License v2.0 with LLVM Exceptions.
+# See https://llvm.org/LICENSE.txt for license information.
+# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+
+import io
+from pathlib import Path
+import re
+import subprocess
+import unittest
+import tempfile
+import sys
+
+from iree.build import *
+
+THIS_DIR = Path(__file__).resolve().parent
+
+
+class MnistBuilderTest(unittest.TestCase):
+    # Exercises the mnist_builder.py sample both out of process (through the
+    # command line) and in process (load_build_module + iree_build_main).
+
+    def setUp(self):
+        # Each test gets its own scratch directory. Registering the removal
+        # with addCleanup (rather than calling __enter__/__exit__ by hand from
+        # setUp/tearDown) means it also runs if setUp fails after this point.
+        self._temp_dir = tempfile.TemporaryDirectory(ignore_cleanup_errors=True)
+        self.addCleanup(self._temp_dir.cleanup)
+        self.output_path = Path(self._temp_dir.name)
+
+    # Tests that invoking via the tool works:
+    #   python -m iree.build {path to py file}
+    # We execute this out of process in order to verify the full flow.
+    def testBuildEntrypoint(self):
+        output = subprocess.check_output(
+            [
+                sys.executable,
+                "-m",
+                "iree.build",
+                str(THIS_DIR / "mnist_builder.py"),
+                "--output-dir",
+                str(self.output_path),
+            ]
+        ).decode()
+        print("OUTPUT:", output)
+        # Expect exactly one stdout line: a path under the output directory
+        # whose contents mention "module".
+        output_paths = output.splitlines()
+        self.assertEqual(len(output_paths), 1)
+        output_path = Path(output_paths[0])
+        self.assertTrue(output_path.is_relative_to(self.output_path))
+        contents = output_path.read_text()
+        self.assertIn("module", contents)
+
+    # Tests that invoking via the build module itself works
+    #   python {path to py file}
+    # We execute this out of process in order to verify the full flow.
+    def testTargetModuleEntrypoint(self):
+        output = subprocess.check_output(
+            [
+                sys.executable,
+                str(THIS_DIR / "mnist_builder.py"),
+                "--output-dir",
+                str(self.output_path),
+            ]
+        ).decode()
+        print("OUTPUT:", output)
+        output_paths = output.splitlines()
+        self.assertEqual(len(output_paths), 1)
+
+    # Tests that `--list` prints only the entrypoint name.
+    def testListCommand(self):
+        mod = load_build_module(THIS_DIR / "mnist_builder.py")
+        out_file = io.StringIO()
+        iree_build_main(mod, args=["--list"], stdout=out_file)
+        output = out_file.getvalue().strip()
+        self.assertEqual(output, "mnist")
+
+    # Tests that `--list-all` output includes the entrypoint and the fetched
+    # file nested under it.
+    def testListAllCommand(self):
+        mod = load_build_module(THIS_DIR / "mnist_builder.py")
+        out_file = io.StringIO()
+        iree_build_main(mod, args=["--list-all"], stdout=out_file)
+        output = out_file.getvalue().splitlines()
+        self.assertIn("mnist", output)
+        self.assertIn("mnist/mnist.onnx", output)
+
+    # Tests that the URL passed via `--mnist-onnx-url` is the one that gets
+    # fetched: a URL that does not exist must fail with IOError.
+    def testActionCLArg(self):
+        mod = load_build_module(THIS_DIR / "mnist_builder.py")
+        out_file = io.StringIO()
+        err_file = io.StringIO()
+        with self.assertRaisesRegex(
+            IOError,
+            re.escape("Failed to fetch URL 'https://github.com/iree-org/doesnotexist'"),
+        ):
+            iree_build_main(
+                mod,
+                args=[
+                    "--mnist-onnx-url",
+                    "https://github.com/iree-org/doesnotexist",
+                ],
+                stdout=out_file,
+                stderr=err_file,
+            )
+
+
+if __name__ == "__main__":
+ # When the `onnx` Python package is not installed, print a notice and exit
+ # with status 0 so the run is reported as skipped rather than failed.
+ try:
+ import onnx
+ except ModuleNotFoundError:
+ print(f"Skipping test {__file__} because Python dependency `onnx` is not found")
+ sys.exit(0)
+
+ unittest.main()
diff --git a/compiler/setup.py b/compiler/setup.py
index 8204e82..bf5d734 100644
--- a/compiler/setup.py
+++ b/compiler/setup.py
@@ -454,6 +454,7 @@
packages=packages,
entry_points={
"console_scripts": [
+ "iree-build = iree.build.__main__:main",
"iree-compile = iree.compiler.tools.scripts.iree_compile.__main__:main",
"iree-import-onnx = iree.compiler.tools.import_onnx.__main__:_cli_main",
"iree-ir-tool = iree.compiler.tools.ir_tool.__main__:_cli_main",