blob: 9a2482a8af0d8e0f783b7a4dec8acc0b7b9760d2 [file] [log] [blame]
# Copyright 2024 The IREE Authors
#
# Licensed under the Apache License v2.0 with LLVM Exceptions.
# See https://llvm.org/LICENSE.txt for license information.
# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
from typing import Callable, Generator, TypeVar
import argparse
import contextlib
import functools
import inspect
import threading
from typing import Callable
# Per-thread storage. `argument_namespace_context` keeps its stack of active
# namespaces on this object under the `arg_ns_stack` attribute.
_locals = threading.local()
# Callbacks appended by `register_arg_parser_callback`; `configure_arg_parser`
# calls each one with the parser.
_ALL_ARG_REGISTRARS: list[Callable[[argparse.ArgumentParser], None]] = []
# Handlers appended by `register_arg_handler_callback`;
# `run_global_arg_handlers` calls each one with the namespace.
_ALL_ARG_HANDLERS: list[Callable[[argparse.Namespace], None]] = []
def register_arg_parser_callback(registrar: Callable[[argparse.ArgumentParser], None]):
    """Decorator that records a global argument registration callback.

    Each recorded callback is handed the parser when `configure_arg_parser`
    runs. The callable itself is returned unchanged, so the decorated name
    still refers to the original function.
    """
    _ALL_ARG_REGISTRARS.append(registrar)
    return registrar
def register_arg_handler_callback(handler: Callable[[argparse.Namespace], None]):
    """Decorator that records a handler for the global arguments.

    Each recorded handler is called with the namespace when
    `run_global_arg_handlers` runs. The callable itself is returned unchanged.
    """
    _ALL_ARG_HANDLERS.append(handler)
    return handler
def configure_arg_parser(p: argparse.ArgumentParser):
    """Applies every callback added via `register_arg_parser_callback` to `p`.

    Callbacks run in the order in which they were registered.
    """
    for registrar in _ALL_ARG_REGISTRARS:
        registrar(p)
def run_global_arg_handlers(ns: argparse.Namespace):
    """Calls every handler added via `register_arg_handler_callback` with `ns`.

    Handlers run in the order in which they were registered.
    """
    for handler in _ALL_ARG_HANDLERS:
        handler(ns)
@contextlib.contextmanager
def argument_namespace_context(ns: argparse.Namespace):
    """Establish the given namespace as the current namespace for this thread.

    The namespace is pushed onto a per-thread stack for the duration of the
    `with` block and popped again on exit, even if the block raises.

    Note that as a thread local, this does not propagate to child threads or
    sub-processes. This means that all argument management must be done during
    action setup and action invocations will not typically have access to args.
    """
    # The stack is created lazily the first time a thread enters a context.
    stack = getattr(_locals, "arg_ns_stack", None)
    if stack is None:
        stack = _locals.arg_ns_stack = []
    stack.append(ns)
    try:
        yield ns
    finally:
        _locals.arg_ns_stack.pop()
def current_args_namespace() -> argparse.Namespace:
    """Return the innermost namespace set by `argument_namespace_context`.

    Raises:
        AssertionError: if the calling thread has no active namespace, i.e.
            its stack was never created or is currently empty.
    """
    try:
        active = _locals.arg_ns_stack[-1]
    except (AttributeError, IndexError):
        # AttributeError: this thread never entered a context.
        # IndexError: every context on this thread has already exited.
        raise AssertionError(
            "No current argument namespace: Is it possible you are trying to resolve "
            "arguments from another thread or process"
        )
    return active
_Decorated = TypeVar("_Decorated", bound=Callable)


def expand_cl_arg_defaults(wrapped: _Decorated) -> _Decorated:
    """Decorator that replaces `ClArgRef` argument values with resolved values.

    On every call, the arguments are bound against the signature of `wrapped`
    with defaults applied. Each bound value that is a `ClArgRef` is replaced by
    the result of its `resolve()` against `current_args_namespace()`; all other
    values are passed through untouched.
    """
    signature = inspect.signature(wrapped)

    @functools.wraps(wrapped)
    def wrapper(*args, **kwargs):
        # The namespace is looked up before binding, so a call made without an
        # active namespace fails here regardless of the arguments passed.
        namespace = current_args_namespace()
        call = signature.bind(*args, **kwargs)
        call.apply_defaults()

        def resolve(value):
            if not isinstance(value, ClArgRef):
                return value
            return value.resolve(namespace)

        positional = [resolve(value) for value in call.args]
        keyword = {key: resolve(value) for key, value in call.kwargs.items()}
        return wrapped(*positional, **keyword)

    return wrapper
class ClArgRef:
    """Default-value placeholder that refers to a command line argument.

    Used in default values of function arguments to indicate that the default
    should be derived from an argument reference. An instance only records the
    namespace attribute to read; actually defining the argument must be done
    elsewhere.

    See `cl_arg_ref()` for canonical use.
    """

    def __init__(self, dest: str):
        # Name of the attribute to read from the argument namespace.
        self.dest = dest

    def resolve(self, arg_namespace: argparse.Namespace):
        """Return the value stored under `self.dest` in `arg_namespace`.

        Raises:
            RuntimeError: if the namespace has no such attribute. The original
                `AttributeError` is attached as the cause.
        """
        try:
            value = getattr(arg_namespace, self.dest)
        except AttributeError as err:
            raise RuntimeError(
                f"Unable to resolve command line argument '{self.dest}' in namespace"
            ) from err
        return value
def cl_arg_ref(dest: str):
    """Create a reference to an already-defined command line argument.

    Used as a default value for functions wrapped in `@expand_cl_arg_defaults`
    to indicate that an argument must come from the command line environment.

    Note that this does not have a typing annotation, allowing the argument to be
    annotated with a type, assuming that resolution will happen dynamically in some
    fashion.

    Args:
        dest: Name of the attribute to read from the argument namespace.

    Returns:
        A `ClArgRef` for `dest`.
    """
    return ClArgRef(dest)
class ClArg(ClArgRef):
    """Default-value placeholder that both defines and references an argument.

    Used in default values of function arguments to indicate that an argument
    needs to be defined and referenced.

    This is used in user-defined entry points, and the executor has special logic to
    collect all needed arguments automatically.

    See `cl_arg()` for canonical use.
    """

    def __init__(self, name: str, dest: str, **add_argument_kw):
        """Record how the argument is to be defined.

        Args:
            name: Option name without leading dashes; `define_arg` registers it
                as `--{name}`.
            dest: Name of the namespace attribute the parsed value is stored in.
            **add_argument_kw: Extra keywords forwarded to
                `ArgumentParser.add_argument()` by `define_arg`.
        """
        super().__init__(dest)
        self.name = name
        self.add_argument_kw = add_argument_kw

    def define_arg(self, parser: argparse.ArgumentParser):
        """Add this argument to `parser` as the option `--{name}`."""
        add_kw = dict(self.add_argument_kw)
        # `cl_arg()` always forwards a `type` keyword, even when it is None.
        # Actions that take no value (e.g. `store_true`, `store_false`,
        # `count`) reject the `type` keyword outright and raise TypeError.
        # None is already argparse's default for actions that do accept it, so
        # dropping a None `type` never changes their behavior.
        if add_kw.get("type") is None:
            add_kw.pop("type", None)
        parser.add_argument(f"--{self.name}", dest=self.dest, **add_kw)
def cl_arg(name: str, *, action=None, default=None, type=None, help=None):
    """Define or reference a command-line argument from actions and entry-points.

    Keywords have the same interpretation as `ArgumentParser.add_argument()`.

    Any ClArg set as a default value for an argument to an `entrypoint` will be
    added to the global argument parser. Any particular argument name can only be
    registered once and must not conflict with a built-in command line option.
    The implication of this is that for single-use arguments, the `=cl_arg(...)`
    can just be added as a default argument. Otherwise, for shared arguments,
    it should be created at the module level and referenced.

    When called, any entrypoint arguments that do not have an explicit keyword
    set will get their value from the command line environment.

    Note that this does not have a typing annotation, allowing the argument to be
    annotated with a type, assuming that resolution will happen dynamically in some
    fashion.

    Raises:
        ValueError: if `name` starts with a dash.
    """
    if name.startswith("-"):
        raise ValueError("cl_arg name must not be prefixed with dashes")
    # Keywords handed on to `ArgumentParser.add_argument()` via `ClArg`.
    options = {
        "action": action,
        "default": default,
        "type": type,
        "help": help,
    }
    # The namespace attribute is the option name with dashes turned into
    # underscores.
    return ClArg(name, dest=name.replace("-", "_"), **options)
def extract_cl_arg_defs(callable: Callable) -> Generator[ClArg, None, None]:
    """Yield every `ClArg` used as a parameter default of `callable`.

    This is used in order to eagerly register argument definitions for some set
    of functions. Parameters whose default is not a `ClArg` are skipped.
    """
    for param in inspect.signature(callable).parameters.values():
        if isinstance(param.default, ClArg):
            yield param.default