Use Black to format Python files (#14161)
Switch from yapf to Black to better align with the LLVM and broader
Python communities. I decided against Pyink: it seems much less
popular, and its formatting style differs from Black's beyond
indentation.
- Reformat all Python files outside of `third_party` with Black.
- Update the lint workflow to use Black. This only checks files
modified by the PR (see the sketch after the command below).
- Delete old dotfiles.
The command used to reformat all files at once:
```shell
fd -e py --exclude third_party | xargs black
```
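For the lint workflow, the same check is restricted to the Python files a
PR touches. A minimal sketch of that idea, assuming a diff against
`origin/main` (the actual workflow invocation may differ):
```shell
# Hypothetical sketch: check only Python files changed relative to origin/main.
git diff --name-only --diff-filter=d origin/main -- '*.py' | xargs black --check --diff
```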
To learn more about Black, see https://black.readthedocs.io/en/stable/
and https://github.com/psf/black.
In the next PR, once the commit SHA of this PR is finalized, I plan to
add this commit to `.git-blame-ignore-revs` to keep the blame history
clean.
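For local checkouts, git can be pointed at that file so `git blame`
skips the reformatting commit:
```shell
# Tell git blame to ignore commits listed in .git-blame-ignore-revs.
git config blame.ignoreRevsFile .git-blame-ignore-revs
```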
Issue: https://github.com/openxla/iree/issues/14135
diff --git a/build_tools/scripts/add_license_header.py b/build_tools/scripts/add_license_header.py
index eebc17e..305e3c5 100755
--- a/build_tools/scripts/add_license_header.py
+++ b/build_tools/scripts/add_license_header.py
@@ -27,163 +27,191 @@
"""
-class CommentSyntax(object):
- def __init__(self, start_comment, middle_comment=None, end_comment=""):
- self.start_comment = start_comment
- self.middle_comment = middle_comment if middle_comment else start_comment
- self.end_comment = end_comment
+class CommentSyntax(object):
+ def __init__(self, start_comment, middle_comment=None, end_comment=""):
+ self.start_comment = start_comment
+ self.middle_comment = middle_comment if middle_comment else start_comment
+ self.end_comment = end_comment
def comment_arg_parser(v):
- """Can be used to parse a comment syntax triple."""
- if v is None:
- return None
- if not isinstance(v, str):
- raise argparse.ArgumentTypeError("String expected")
- return CommentSyntax(*v.split(","))
+ """Can be used to parse a comment syntax triple."""
+ if v is None:
+ return None
+ if not isinstance(v, str):
+ raise argparse.ArgumentTypeError("String expected")
+ return CommentSyntax(*v.split(","))
def create_multikey(d):
- # pylint: disable=g-complex-comprehension
- return {k: v for keys, v in d.items() for k in keys}
+ # pylint: disable=g-complex-comprehension
+ return {k: v for keys, v in d.items() for k in keys}
-filename_to_comment = create_multikey({
- ("BUILD", "CMakeLists.txt"): CommentSyntax("#"),
-})
+filename_to_comment = create_multikey(
+ {
+ ("BUILD", "CMakeLists.txt"): CommentSyntax("#"),
+ }
+)
-ext_to_comment = create_multikey({
- (".bzl", ".cfg", ".cmake", ".overlay", ".py", ".sh", ".yml"):
- CommentSyntax("#"),
- (".cc", ".cpp", ".comp", ".fbs", ".h", ".hpp", ".inc", ".td"):
- CommentSyntax("//"),
- (".def",):
- CommentSyntax(";;"),
-})
+ext_to_comment = create_multikey(
+ {
+ (".bzl", ".cfg", ".cmake", ".overlay", ".py", ".sh", ".yml"): CommentSyntax(
+ "#"
+ ),
+ (".cc", ".cpp", ".comp", ".fbs", ".h", ".hpp", ".inc", ".td"): CommentSyntax(
+ "//"
+ ),
+ (".def",): CommentSyntax(";;"),
+ }
+)
def get_comment_syntax(args):
- """Deterime the comment syntax to use."""
- if args.comment:
- return args.comment
- basename = os.path.basename(args.filename)
- from_filename = filename_to_comment.get(basename)
- if from_filename:
- return from_filename
- _, ext = os.path.splitext(args.filename)
- return ext_to_comment.get(ext, args.default_comment)
+ """Deterime the comment syntax to use."""
+ if args.comment:
+ return args.comment
+ basename = os.path.basename(args.filename)
+ from_filename = filename_to_comment.get(basename)
+ if from_filename:
+ return from_filename
+ _, ext = os.path.splitext(args.filename)
+ return ext_to_comment.get(ext, args.default_comment)
def parse_arguments():
- """Parses command line arguments."""
- current_year = datetime.date.today().year
- parser = argparse.ArgumentParser()
- input_group = parser.add_mutually_exclusive_group()
- input_group.add_argument("infile",
- nargs="?",
- type=argparse.FileType("r", encoding="UTF-8"),
- help="Input file to format. Default: stdin",
- default=sys.stdin)
- parser.add_argument(
- "--filename",
- "--assume-filename",
- type=str,
- default=None,
- help=(
- "Filename to use for determining comment syntax. Default: actual name"
- "of input file."))
- parser.add_argument(
- "--year",
- "-y",
- help="Year to add copyright. Default: the current year ({})".format(
- current_year),
- default=current_year)
- parser.add_argument("--holder",
- help="Copyright holder. Default: The IREE Authors",
- default="The IREE Authors")
- parser.add_argument(
- "--quiet",
- help=("Don't raise a runtime error on encountering an unhandled filetype."
- "Useful for running across many files at once. Default: False"),
- action="store_true",
- default=False)
- output_group = parser.add_mutually_exclusive_group()
- output_group.add_argument("-o",
- "--outfile",
- "--output",
- help="File to send output. Default: stdout",
- type=argparse.FileType("w", encoding="UTF-8"),
- default=sys.stdout)
- output_group.add_argument("--in_place",
- "-i",
- action="store_true",
- help="Run formatting in place. Default: False",
- default=False)
- comment_group = parser.add_mutually_exclusive_group()
- comment_group.add_argument("--comment",
- "-c",
- type=comment_arg_parser,
- help="Override comment syntax.",
- default=None)
- comment_group.add_argument(
- "--default_comment",
- type=comment_arg_parser,
- help="Fallback comment syntax if filename is unknown. Default: None",
- default=None)
- args = parser.parse_args()
+ """Parses command line arguments."""
+ current_year = datetime.date.today().year
+ parser = argparse.ArgumentParser()
+ input_group = parser.add_mutually_exclusive_group()
+ input_group.add_argument(
+ "infile",
+ nargs="?",
+ type=argparse.FileType("r", encoding="UTF-8"),
+ help="Input file to format. Default: stdin",
+ default=sys.stdin,
+ )
+ parser.add_argument(
+ "--filename",
+ "--assume-filename",
+ type=str,
+ default=None,
+ help=(
+ "Filename to use for determining comment syntax. Default: actual name"
+ "of input file."
+ ),
+ )
+ parser.add_argument(
+ "--year",
+ "-y",
+ help="Year to add copyright. Default: the current year ({})".format(
+ current_year
+ ),
+ default=current_year,
+ )
+ parser.add_argument(
+ "--holder",
+ help="Copyright holder. Default: The IREE Authors",
+ default="The IREE Authors",
+ )
+ parser.add_argument(
+ "--quiet",
+ help=(
+ "Don't raise a runtime error on encountering an unhandled filetype."
+ "Useful for running across many files at once. Default: False"
+ ),
+ action="store_true",
+ default=False,
+ )
+ output_group = parser.add_mutually_exclusive_group()
+ output_group.add_argument(
+ "-o",
+ "--outfile",
+ "--output",
+ help="File to send output. Default: stdout",
+ type=argparse.FileType("w", encoding="UTF-8"),
+ default=sys.stdout,
+ )
+ output_group.add_argument(
+ "--in_place",
+ "-i",
+ action="store_true",
+ help="Run formatting in place. Default: False",
+ default=False,
+ )
+ comment_group = parser.add_mutually_exclusive_group()
+ comment_group.add_argument(
+ "--comment",
+ "-c",
+ type=comment_arg_parser,
+ help="Override comment syntax.",
+ default=None,
+ )
+ comment_group.add_argument(
+ "--default_comment",
+ type=comment_arg_parser,
+ help="Fallback comment syntax if filename is unknown. Default: None",
+ default=None,
+ )
+ args = parser.parse_args()
- if args.in_place and args.infile == sys.stdin:
- raise parser.error("Cannot format stdin in place")
+ if args.in_place and args.infile == sys.stdin:
+ raise parser.error("Cannot format stdin in place")
- if not args.filename and args.infile != sys.stdin:
- args.filename = args.infile.name
+ if not args.filename and args.infile != sys.stdin:
+ args.filename = args.infile.name
- return args
+ return args
def main(args):
- first_line = args.infile.readline()
- already_has_license = False
- shebang = ""
- content_lines = []
- if first_line.startswith("#!"):
- shebang = first_line
- else:
- content_lines = [first_line]
- content_lines.extend(args.infile.readlines())
- for line in content_lines:
- if COPYRIGHT_PATTERN.search(line):
- already_has_license = True
- break
- if already_has_license:
- header = shebang
- else:
- comment_syntax = get_comment_syntax(args)
- if not comment_syntax:
- if args.quiet:
- header = shebang
- else:
- raise ValueError("Could not determine comment syntax for " +
- args.filename)
+ first_line = args.infile.readline()
+ already_has_license = False
+ shebang = ""
+ content_lines = []
+ if first_line.startswith("#!"):
+ shebang = first_line
else:
- header = LICENSE_HEADER_FORMATTER.format(
- # Add a blank line between shebang and license.
- shebang=(shebang + "\n" if shebang else ""),
- start_comment=comment_syntax.start_comment,
- middle_comment=comment_syntax.middle_comment,
- # Add a blank line before the end comment.
- end_comment=("\n" + comment_syntax.end_comment
- if comment_syntax.end_comment else ""),
- year=args.year,
- holder=args.holder)
+ content_lines = [first_line]
+ content_lines.extend(args.infile.readlines())
+ for line in content_lines:
+ if COPYRIGHT_PATTERN.search(line):
+ already_has_license = True
+ break
+ if already_has_license:
+ header = shebang
+ else:
+ comment_syntax = get_comment_syntax(args)
+ if not comment_syntax:
+ if args.quiet:
+ header = shebang
+ else:
+ raise ValueError(
+ "Could not determine comment syntax for " + args.filename
+ )
+ else:
+ header = LICENSE_HEADER_FORMATTER.format(
+ # Add a blank line between shebang and license.
+ shebang=(shebang + "\n" if shebang else ""),
+ start_comment=comment_syntax.start_comment,
+ middle_comment=comment_syntax.middle_comment,
+ # Add a blank line before the end comment.
+ end_comment=(
+ "\n" + comment_syntax.end_comment
+ if comment_syntax.end_comment
+ else ""
+ ),
+ year=args.year,
+ holder=args.holder,
+ )
- # Have to open for write after we're done reading.
- if args.in_place:
- args.outfile = open(args.filename, "w", encoding="UTF-8")
- args.outfile.write(header)
- args.outfile.writelines(content_lines)
+ # Have to open for write after we're done reading.
+ if args.in_place:
+ args.outfile = open(args.filename, "w", encoding="UTF-8")
+ args.outfile.write(header)
+ args.outfile.writelines(content_lines)
if __name__ == "__main__":
- main(parse_arguments())
+ main(parse_arguments())
diff --git a/build_tools/scripts/check_path_lengths.py b/build_tools/scripts/check_path_lengths.py
index 645ba7d..42d95c2 100755
--- a/build_tools/scripts/check_path_lengths.py
+++ b/build_tools/scripts/check_path_lengths.py
@@ -30,70 +30,71 @@
def parse_arguments():
- parser = argparse.ArgumentParser(description="Path length checker")
- # The default limit was selected based on repository state when this script
- # was added. If the max path length decreases, consider lowering this too.
- parser.add_argument("--limit",
- help="Path length limit (inclusive)",
- type=int,
- default=75)
- parser.add_argument(
- "--include_tests",
- help=
- "Includes /test directories. False by default as these don't usually generate problematic files during the build",
- action="store_true",
- default=False)
- parser.add_argument("--verbose",
- help="Outputs detailed information about path lengths",
- action="store_true",
- default=False)
- args = parser.parse_args()
- return args
+ parser = argparse.ArgumentParser(description="Path length checker")
+ # The default limit was selected based on repository state when this script
+ # was added. If the max path length decreases, consider lowering this too.
+ parser.add_argument(
+ "--limit", help="Path length limit (inclusive)", type=int, default=75
+ )
+ parser.add_argument(
+ "--include_tests",
+ help="Includes /test directories. False by default as these don't usually generate problematic files during the build",
+ action="store_true",
+ default=False,
+ )
+ parser.add_argument(
+ "--verbose",
+ help="Outputs detailed information about path lengths",
+ action="store_true",
+ default=False,
+ )
+ args = parser.parse_args()
+ return args
def main(args):
- repo_root = pathlib.Path(__file__).parent.parent.parent
+ repo_root = pathlib.Path(__file__).parent.parent.parent
- # Just look at the compiler directory for now, since it has historically had
- # by far the longest paths.
- walk_root = os.path.join(repo_root, "compiler")
+ # Just look at the compiler directory for now, since it has historically had
+ # by far the longest paths.
+ walk_root = os.path.join(repo_root, "compiler")
- longest_path_length = -1
- long_paths = []
- short_paths = []
- for dirpath, dirnames, _ in os.walk(walk_root):
- # Don't descend into test directories, since they typically don't generate
- # object files or binaries that could trip up the build system.
- if not args.include_tests and "test" in dirnames:
- dirnames.remove("test")
+ longest_path_length = -1
+ long_paths = []
+ short_paths = []
+ for dirpath, dirnames, _ in os.walk(walk_root):
+ # Don't descend into test directories, since they typically don't generate
+ # object files or binaries that could trip up the build system.
+ if not args.include_tests and "test" in dirnames:
+ dirnames.remove("test")
- path = pathlib.Path(dirpath).relative_to(repo_root).as_posix()
- if len(path) > args.limit:
- long_paths.append(path)
+ path = pathlib.Path(dirpath).relative_to(repo_root).as_posix()
+ if len(path) > args.limit:
+ long_paths.append(path)
+ else:
+ short_paths.append(path)
+ longest_path_length = max(longest_path_length, len(path))
+ long_paths.sort(key=len)
+ short_paths.sort(key=len)
+
+ if args.verbose and short_paths:
+ print(f"These paths are shorter than the limit of {args.limit} characters:")
+ for path in short_paths:
+ print("{:3d}, {}".format(len(path), path))
+
+ if long_paths:
+ print(f"These paths are longer than the limit of {args.limit} characters:")
+ for path in long_paths:
+ print("{:3d}, {}".format(len(path), path))
+ print(
+ f"Error: {len(long_paths)} source paths are longer than {args.limit} characters."
+ )
+ print(" Long paths can be problematic when building on Windows.")
+ print(" Please look at the output above and trim the paths.")
+ sys.exit(1)
else:
- short_paths.append(path)
- longest_path_length = max(longest_path_length, len(path))
- long_paths.sort(key=len)
- short_paths.sort(key=len)
-
- if args.verbose and short_paths:
- print(f"These paths are shorter than the limit of {args.limit} characters:")
- for path in short_paths:
- print("{:3d}, {}".format(len(path), path))
-
- if long_paths:
- print(f"These paths are longer than the limit of {args.limit} characters:")
- for path in long_paths:
- print("{:3d}, {}".format(len(path), path))
- print(
- f"Error: {len(long_paths)} source paths are longer than {args.limit} characters."
- )
- print(" Long paths can be problematic when building on Windows.")
- print(" Please look at the output above and trim the paths.")
- sys.exit(1)
- else:
- print(f"All path lengths are under the limit of {args.limit} characters.")
+ print(f"All path lengths are under the limit of {args.limit} characters.")
if __name__ == "__main__":
- main(parse_arguments())
+ main(parse_arguments())
diff --git a/build_tools/scripts/download_file.py b/build_tools/scripts/download_file.py
index da1a1d3..ffa4220 100755
--- a/build_tools/scripts/download_file.py
+++ b/build_tools/scripts/download_file.py
@@ -25,84 +25,91 @@
def parse_arguments():
- """Parses command line arguments."""
- parser = argparse.ArgumentParser(
- description="Downloads a file from the web "
- "and decompresses it if necessary. NEVER Use this tool to download from "
- "untrusted sources, it doesn't unpack the file safely.")
- parser.add_argument("source_url",
- type=str,
- metavar="<source-url>",
- help="Source URL to download")
- parser.add_argument("-o",
- "--output",
- type=str,
- required=True,
- metavar="<output-file>",
- help="Output file path")
- parser.add_argument("--unpack",
- action='store_true',
- default=False,
- help="Unpack the downloaded file if it's an archive")
- parser.add_argument("--max-tries",
- metavar="<max-tries>",
- type=int,
- default=DEFAULT_MAX_TRIES,
- help="Number of tries before giving up")
- return parser.parse_args()
+ """Parses command line arguments."""
+ parser = argparse.ArgumentParser(
+ description="Downloads a file from the web "
+ "and decompresses it if necessary. NEVER Use this tool to download from "
+ "untrusted sources, it doesn't unpack the file safely."
+ )
+ parser.add_argument(
+ "source_url", type=str, metavar="<source-url>", help="Source URL to download"
+ )
+ parser.add_argument(
+ "-o",
+ "--output",
+ type=str,
+ required=True,
+ metavar="<output-file>",
+ help="Output file path",
+ )
+ parser.add_argument(
+ "--unpack",
+ action="store_true",
+ default=False,
+ help="Unpack the downloaded file if it's an archive",
+ )
+ parser.add_argument(
+ "--max-tries",
+ metavar="<max-tries>",
+ type=int,
+ default=DEFAULT_MAX_TRIES,
+ help="Number of tries before giving up",
+ )
+ return parser.parse_args()
def download_and_extract(source_url: str, output: str, unpack: bool):
- # Open the URL and get the file-like streaming object.
- with urllib.request.urlopen(source_url) as response:
- if response.status != 200:
- raise RuntimeError(
- f"Failed to download file with status {response.status} {response.msg}"
- )
+ # Open the URL and get the file-like streaming object.
+ with urllib.request.urlopen(source_url) as response:
+ if response.status != 200:
+ raise RuntimeError(
+ f"Failed to download file with status {response.status} {response.msg}"
+ )
- if unpack:
- if source_url.endswith(".tar.gz"):
- # Open tar.gz in the streaming mode.
- with tarfile.open(fileobj=response, mode="r|*") as tar_file:
- if os.path.exists(output):
- shutil.rmtree(output)
- os.makedirs(output)
- tar_file.extractall(output)
- return
- elif source_url.endswith(".gz"):
- # Open gzip from a file-like object, which will be in the streaming mode.
- with gzip.open(filename=response, mode="rb") as input_file:
- with open(output, "wb") as output_file:
- shutil.copyfileobj(input_file, output_file)
- return
+ if unpack:
+ if source_url.endswith(".tar.gz"):
+            # Open tar.gz in streaming mode.
+ with tarfile.open(fileobj=response, mode="r|*") as tar_file:
+ if os.path.exists(output):
+ shutil.rmtree(output)
+ os.makedirs(output)
+ tar_file.extractall(output)
+ return
+ elif source_url.endswith(".gz"):
+            # Open gzip from a file-like object, which will be in streaming mode.
+ with gzip.open(filename=response, mode="rb") as input_file:
+ with open(output, "wb") as output_file:
+ shutil.copyfileobj(input_file, output_file)
+ return
- # Fallback to download the file only.
- with open(output, "wb") as output_file:
- # Streaming copy.
- shutil.copyfileobj(response, output_file)
+ # Fallback to download the file only.
+ with open(output, "wb") as output_file:
+ # Streaming copy.
+ shutil.copyfileobj(response, output_file)
def main(args):
- output_dir = os.path.dirname(args.output)
+ output_dir = os.path.dirname(args.output)
- if not os.path.isdir(output_dir):
- os.makedirs(output_dir)
+ if not os.path.isdir(output_dir):
+ os.makedirs(output_dir)
- remaining_tries = args.max_tries
- while remaining_tries > 0:
- try:
- download_and_extract(args.source_url, args.output, args.unpack)
- break
- except (ConnectionResetError, ConnectionRefusedError,
- urllib.error.URLError):
- remaining_tries -= 1
- if remaining_tries == 0:
- raise
- else:
- logging.warning(f"Connection error, remaining {remaining_tries} tries",
- exc_info=True)
- time.sleep(RETRY_COOLDOWN_TIME)
+ remaining_tries = args.max_tries
+ while remaining_tries > 0:
+ try:
+ download_and_extract(args.source_url, args.output, args.unpack)
+ break
+ except (ConnectionResetError, ConnectionRefusedError, urllib.error.URLError):
+ remaining_tries -= 1
+ if remaining_tries == 0:
+ raise
+ else:
+ logging.warning(
+ f"Connection error, remaining {remaining_tries} tries",
+ exc_info=True,
+ )
+ time.sleep(RETRY_COOLDOWN_TIME)
if __name__ == "__main__":
- main(parse_arguments())
+ main(parse_arguments())
diff --git a/build_tools/scripts/generate_compilation_flagfile.py b/build_tools/scripts/generate_compilation_flagfile.py
index cf0cb13..adda56e 100755
--- a/build_tools/scripts/generate_compilation_flagfile.py
+++ b/build_tools/scripts/generate_compilation_flagfile.py
@@ -16,23 +16,24 @@
def parse_arguments():
- """Parses command line arguments."""
- parser = argparse.ArgumentParser()
- parser.add_argument("--output",
- type=str,
- required=True,
- help="output file to write to")
- parser.add_argument("compilation_flags",
- metavar="<compilation-flags>",
- nargs="*",
- help="list of compilation flags")
- return parser.parse_args()
+ """Parses command line arguments."""
+ parser = argparse.ArgumentParser()
+ parser.add_argument(
+ "--output", type=str, required=True, help="output file to write to"
+ )
+ parser.add_argument(
+ "compilation_flags",
+ metavar="<compilation-flags>",
+ nargs="*",
+ help="list of compilation flags",
+ )
+ return parser.parse_args()
def main(args):
- with open(args.output, "w") as f:
- f.write("\n".join(args.compilation_flags) + "\n")
+ with open(args.output, "w") as f:
+ f.write("\n".join(args.compilation_flags) + "\n")
if __name__ == "__main__":
- main(parse_arguments())
+ main(parse_arguments())
diff --git a/build_tools/scripts/generate_flagfile.py b/build_tools/scripts/generate_flagfile.py
index f0330e0..fb1effd 100755
--- a/build_tools/scripts/generate_flagfile.py
+++ b/build_tools/scripts/generate_flagfile.py
@@ -12,54 +12,67 @@
def parse_arguments():
- """Parses command line arguments."""
- parser = argparse.ArgumentParser()
- parser.add_argument("--module",
- type=str,
- required=True,
- metavar="<module>",
- help="The name of the module file")
- parser.add_argument("--device",
- type=str,
- required=True,
- metavar="<device>",
- help="The name of the HAL device")
- parser.add_argument("--function",
- type=str,
- required=True,
- metavar="<function>",
- help="The name of the entry function")
- parser.add_argument("--inputs",
- type=str,
- required=True,
- metavar="<inputs>",
- help="A list of comma-separated function inputs")
- parser.add_argument("--additional_args",
- type=str,
- required=True,
- metavar="<additional-cl-args>",
- help="Additional command-line arguments")
- parser.add_argument("-o",
- "--output",
- type=str,
- required=True,
- metavar="<output-file>",
- help="Output file to write to")
- return parser.parse_args()
+ """Parses command line arguments."""
+ parser = argparse.ArgumentParser()
+ parser.add_argument(
+ "--module",
+ type=str,
+ required=True,
+ metavar="<module>",
+ help="The name of the module file",
+ )
+ parser.add_argument(
+ "--device",
+ type=str,
+ required=True,
+ metavar="<device>",
+ help="The name of the HAL device",
+ )
+ parser.add_argument(
+ "--function",
+ type=str,
+ required=True,
+ metavar="<function>",
+ help="The name of the entry function",
+ )
+ parser.add_argument(
+ "--inputs",
+ type=str,
+ required=True,
+ metavar="<inputs>",
+ help="A list of comma-separated function inputs",
+ )
+ parser.add_argument(
+ "--additional_args",
+ type=str,
+ required=True,
+ metavar="<additional-cl-args>",
+ help="Additional command-line arguments",
+ )
+ parser.add_argument(
+ "-o",
+ "--output",
+ type=str,
+ required=True,
+ metavar="<output-file>",
+ help="Output file to write to",
+ )
+ return parser.parse_args()
def main(args):
- lines = [
- f"--device={args.device}", f"--module={args.module}",
- f"--function={args.function}"
- ]
- lines.extend([("--input=" + e) for e in args.inputs.split(",")])
- lines.extend(args.additional_args.split(";"))
- content = "\n".join(lines) + "\n"
+ lines = [
+ f"--device={args.device}",
+ f"--module={args.module}",
+ f"--function={args.function}",
+ ]
+ lines.extend([("--input=" + e) for e in args.inputs.split(",")])
+ lines.extend(args.additional_args.split(";"))
+ content = "\n".join(lines) + "\n"
- with open(args.output, "w") as f:
- f.writelines(content)
+ with open(args.output, "w") as f:
+ f.writelines(content)
if __name__ == "__main__":
- main(parse_arguments())
+ main(parse_arguments())
diff --git a/build_tools/scripts/generate_release_index.py b/build_tools/scripts/generate_release_index.py
index 0e7ea94..70a4eeb 100755
--- a/build_tools/scripts/generate_release_index.py
+++ b/build_tools/scripts/generate_release_index.py
@@ -19,63 +19,74 @@
def parse_arguments():
- parser = argparse.ArgumentParser()
- parser.add_argument("--repo",
- "--repository",
- default="openxla/iree",
- help="The GitHub repository to fetch releases from.")
- parser.add_argument(
- "--output",
- default="-",
- help="The file to write the HTML to or '-' for stdout (the default)")
- return parser.parse_args()
+ parser = argparse.ArgumentParser()
+ parser.add_argument(
+ "--repo",
+ "--repository",
+ default="openxla/iree",
+ help="The GitHub repository to fetch releases from.",
+ )
+ parser.add_argument(
+ "--output",
+ default="-",
+ help="The file to write the HTML to or '-' for stdout (the default)",
+ )
+ return parser.parse_args()
class ReleaseFetcher:
+ def __init__(self, repo, per_page=100):
+ self._session = requests.Session()
+ self._repo = repo
+ self._per_page = per_page
- def __init__(self, repo, per_page=100):
- self._session = requests.Session()
- self._repo = repo
- self._per_page = per_page
+ def get_all(self):
+ url = f"https://api.github.com/repos/{self._repo}/releases"
+ page = 1
- def get_all(self):
- url = f"https://api.github.com/repos/{self._repo}/releases"
- page = 1
-
- while True:
- response = self._session.get(url,
- params={
- "page": page,
- "per_page": self._per_page,
- })
- for release in response.json():
- yield release
- if "next" not in response.links:
- break
- page += 1
+ while True:
+ response = self._session.get(
+ url,
+ params={
+ "page": page,
+ "per_page": self._per_page,
+ },
+ )
+ for release in response.json():
+ yield release
+ if "next" not in response.links:
+ break
+ page += 1
def main(args):
- fetcher = ReleaseFetcher(repo=args.repo)
- with (sys.stdout if args.output == "-" else open(args.output, "w")) as f:
- f.write(
- textwrap.dedent("""\
+ fetcher = ReleaseFetcher(repo=args.repo)
+ with sys.stdout if args.output == "-" else open(args.output, "w") as f:
+ f.write(
+ textwrap.dedent(
+ """\
<!DOCTYPE html>
<html>
<body>
- """))
- for release in fetcher.get_all():
- if release["draft"]:
- continue
- for asset in release["assets"]:
- url = html.escape(asset['browser_download_url'])
- name = html.escape(asset['name'])
- f.write(f" <a href={url}>{name}</a><br />\n")
- f.write(textwrap.dedent("""\
+ """
+ )
+ )
+ for release in fetcher.get_all():
+ if release["draft"]:
+ continue
+ for asset in release["assets"]:
+ url = html.escape(asset["browser_download_url"])
+ name = html.escape(asset["name"])
+ f.write(f" <a href={url}>{name}</a><br />\n")
+ f.write(
+ textwrap.dedent(
+ """\
</body>
</html>
- """))
+ """
+ )
+ )
if __name__ == "__main__":
- main(parse_arguments())
+ main(parse_arguments())
diff --git a/build_tools/scripts/get_e2e_artifacts.py b/build_tools/scripts/get_e2e_artifacts.py
index 634ee31..8875438 100755
--- a/build_tools/scripts/get_e2e_artifacts.py
+++ b/build_tools/scripts/get_e2e_artifacts.py
@@ -29,153 +29,156 @@
from absl import flags
SUITE_NAME_TO_TARGET = {
- 'e2e_tests':
- '//integrations/tensorflow/e2e:e2e_tests',
- 'mobile_bert_squad_tests':
- '//integrations/tensorflow/e2e:mobile_bert_squad_tests',
- 'layers_tests':
- '//integrations/tensorflow/e2e/keras/layers:layers_tests',
- 'layers_dynamic_batch_tests':
- '//integrations/tensorflow/e2e/keras/layers:layers_dynamic_batch_tests',
- 'layers_training_tests':
- '//integrations/tensorflow/e2e/keras/layers:layers_training_tests',
- 'keyword_spotting_tests':
- '//integrations/tensorflow/e2e/keras:keyword_spotting_tests',
- 'keyword_spotting_internal_streaming_tests':
- '//integrations/tensorflow/e2e/keras:keyword_spotting_internal_streaming_tests',
- 'imagenet_non_hermetic_tests':
- '//integrations/tensorflow/e2e/keras/applications:imagenet_non_hermetic_tests',
- 'slim_vision_tests':
- '//integrations/tensorflow/e2e/slim_vision_models:slim_vision_tests',
+ "e2e_tests": "//integrations/tensorflow/e2e:e2e_tests",
+ "mobile_bert_squad_tests": "//integrations/tensorflow/e2e:mobile_bert_squad_tests",
+ "layers_tests": "//integrations/tensorflow/e2e/keras/layers:layers_tests",
+ "layers_dynamic_batch_tests": "//integrations/tensorflow/e2e/keras/layers:layers_dynamic_batch_tests",
+ "layers_training_tests": "//integrations/tensorflow/e2e/keras/layers:layers_training_tests",
+ "keyword_spotting_tests": "//integrations/tensorflow/e2e/keras:keyword_spotting_tests",
+ "keyword_spotting_internal_streaming_tests": "//integrations/tensorflow/e2e/keras:keyword_spotting_internal_streaming_tests",
+ "imagenet_non_hermetic_tests": "//integrations/tensorflow/e2e/keras/applications:imagenet_non_hermetic_tests",
+ "slim_vision_tests": "//integrations/tensorflow/e2e/slim_vision_models:slim_vision_tests",
}
-SUITES_HELP = [f'`{name}`' for name in SUITE_NAME_TO_TARGET]
+SUITES_HELP = [f"`{name}`" for name in SUITE_NAME_TO_TARGET]
SUITES_HELP = f'{", ".join(SUITES_HELP[:-1])} and {SUITES_HELP[-1]}'
FLAGS = flags.FLAGS
flags.DEFINE_bool(
- 'dry_run', False,
- 'Run without extracting files. Useful for quickly checking for artifact '
- 'collisions.')
+ "dry_run",
+ False,
+ "Run without extracting files. Useful for quickly checking for artifact "
+ "collisions.",
+)
flags.DEFINE_string(
- 'artifacts_dir', os.path.join(tempfile.gettempdir(), 'iree', 'modules'),
- 'Directory to transfer the benchmarking artifacts to. Defaults to '
- '/tmp/iree/modules/')
-flags.DEFINE_bool('run_test_suites', True, 'Run any specified test suites.')
-flags.DEFINE_list('test_suites', list(SUITE_NAME_TO_TARGET.keys()),
- f'Any combination of {SUITES_HELP}.')
+ "artifacts_dir",
+ os.path.join(tempfile.gettempdir(), "iree", "modules"),
+ "Directory to transfer the benchmarking artifacts to. Defaults to "
+ "/tmp/iree/modules/",
+)
+flags.DEFINE_bool("run_test_suites", True, "Run any specified test suites.")
+flags.DEFINE_list(
+ "test_suites",
+ list(SUITE_NAME_TO_TARGET.keys()),
+ f"Any combination of {SUITES_HELP}.",
+)
-EXPECTED_COLLISIONS = [
- '/tf_ref/', 'tf_input.mlir', 'iree_input.mlir', '/saved_model/'
-]
+EXPECTED_COLLISIONS = ["/tf_ref/", "tf_input.mlir", "iree_input.mlir", "/saved_model/"]
def _target_to_testlogs_path(target: str) -> str:
- """Convert target into the path where Bazel stores the artifacts we want."""
- return os.path.join('bazel-testlogs',
- target.replace('//', '').replace(':', os.sep))
+ """Convert target into the path where Bazel stores the artifacts we want."""
+ return os.path.join("bazel-testlogs", target.replace("//", "").replace(":", os.sep))
def _target_to_test_name(target: str, test_suite_path: str) -> str:
- """Get test_name from `suite_name_test_name__tf__backend_name`."""
- return target.split('__')[0].replace(f'{test_suite_path}_', '')
+ """Get test_name from `suite_name_test_name__tf__backend_name`."""
+ return target.split("__")[0].replace(f"{test_suite_path}_", "")
def get_test_paths_and_names(test_suite_path: str):
- """Get the paths Bazel stores test outputs in and the matching test names."""
- targets = utils.get_test_targets(test_suite_path)
- test_paths = [_target_to_testlogs_path(target) for target in targets]
- test_names = [
- _target_to_test_name(target, test_suite_path) for target in targets
- ]
- return test_paths, test_names
+ """Get the paths Bazel stores test outputs in and the matching test names."""
+ targets = utils.get_test_targets(test_suite_path)
+ test_paths = [_target_to_testlogs_path(target) for target in targets]
+ test_names = [_target_to_test_name(target, test_suite_path) for target in targets]
+ return test_paths, test_names
-def check_collision(filename: str, test_name: str, written_paths: Set[str],
- paths_to_tests: Dict[str, str]):
- """Check that we aren't overwriting files unless we expect to."""
- # Note: We can't use a check that the files have identical contents because
- # tf_input.mlir can have random numbers appended to its function names.
- # See https://github.com/openxla/iree/issues/3375
+def check_collision(
+ filename: str,
+ test_name: str,
+ written_paths: Set[str],
+ paths_to_tests: Dict[str, str],
+):
+ """Check that we aren't overwriting files unless we expect to."""
+ # Note: We can't use a check that the files have identical contents because
+ # tf_input.mlir can have random numbers appended to its function names.
+ # See https://github.com/openxla/iree/issues/3375
- expected_collision = any([name in filename for name in EXPECTED_COLLISIONS])
- if filename in written_paths and not expected_collision:
- raise ValueError(f'Collision found on {filename} between {test_name}.py '
- f'and {paths_to_tests[filename]}.py')
- else:
- written_paths.add(filename)
- paths_to_tests[filename] = test_name
+ expected_collision = any([name in filename for name in EXPECTED_COLLISIONS])
+ if filename in written_paths and not expected_collision:
+ raise ValueError(
+ f"Collision found on {filename} between {test_name}.py "
+ f"and {paths_to_tests[filename]}.py"
+ )
+ else:
+ written_paths.add(filename)
+ paths_to_tests[filename] = test_name
def update_path(archive_path: str):
- """Update the --module flag with the new location of the compiled.vmfb"""
- backend_path = archive_path.split('traces')[0] # 'ModuleName/backend_name'.
- compiled_path = os.path.join(FLAGS.artifacts_dir, backend_path,
- 'compiled.vmfb')
- flagfile_path = os.path.join(FLAGS.artifacts_dir, archive_path)
- for line in fileinput.input(files=[flagfile_path], inplace=True):
- if line.strip().startswith('--module'):
- print(f'--module={compiled_path}\n', end='')
- else:
- print(line, end='')
+ """Update the --module flag with the new location of the compiled.vmfb"""
+ backend_path = archive_path.split("traces")[0] # 'ModuleName/backend_name'.
+ compiled_path = os.path.join(FLAGS.artifacts_dir, backend_path, "compiled.vmfb")
+ flagfile_path = os.path.join(FLAGS.artifacts_dir, archive_path)
+ for line in fileinput.input(files=[flagfile_path], inplace=True):
+ if line.strip().startswith("--module"):
+ print(f"--module={compiled_path}\n", end="")
+ else:
+ print(line, end="")
-def extract_artifacts(test_path: str, test_name: str, written_paths: Set[str],
- paths_to_tests: Dict[str, str]):
- """Unzips all of the benchmarking artifacts for a given test and backend."""
- outputs = os.path.join(test_path, 'test.outputs', 'outputs.zip')
- if FLAGS.dry_run and not os.path.exists(outputs):
- # The artifacts may or may not be present on disk during a dry run. If they
- # are then we want to collision check them, but if they aren't that's fine.
- return
+def extract_artifacts(
+ test_path: str,
+ test_name: str,
+ written_paths: Set[str],
+ paths_to_tests: Dict[str, str],
+):
+ """Unzips all of the benchmarking artifacts for a given test and backend."""
+ outputs = os.path.join(test_path, "test.outputs", "outputs.zip")
+ if FLAGS.dry_run and not os.path.exists(outputs):
+ # The artifacts may or may not be present on disk during a dry run. If they
+ # are then we want to collision check them, but if they aren't that's fine.
+ return
- archive = zipfile.ZipFile(outputs)
- # Filter out directory names.
- filenames = [name for name in archive.namelist() if name[-1] != os.sep]
+ archive = zipfile.ZipFile(outputs)
+ # Filter out directory names.
+ filenames = [name for name in archive.namelist() if name[-1] != os.sep]
- for filename in filenames:
- # Check for collisions.
- check_collision(filename, test_name, written_paths, paths_to_tests)
+ for filename in filenames:
+ # Check for collisions.
+ check_collision(filename, test_name, written_paths, paths_to_tests)
- # Extract and update flagfile path.
- if not FLAGS.dry_run:
- archive.extract(filename, FLAGS.artifacts_dir)
- if filename.endswith('flagfile'):
- update_path(filename)
+ # Extract and update flagfile path.
+ if not FLAGS.dry_run:
+ archive.extract(filename, FLAGS.artifacts_dir)
+ if filename.endswith("flagfile"):
+ update_path(filename)
def main(argv):
- del argv # Unused.
+ del argv # Unused.
- print(
- "The bazel integrations build and tests are deprecated. This script "
- "may be reworked in the future. For the time being refer to "
- "https://github.com/openxla/iree/blob/main/docs/developers/developing_iree/e2e_benchmarking.md "
- "for information on how to run TensorFlow benchmarks.")
- exit(1)
+ print(
+ "The bazel integrations build and tests are deprecated. This script "
+ "may be reworked in the future. For the time being refer to "
+ "https://github.com/openxla/iree/blob/main/docs/developers/developing_iree/e2e_benchmarking.md "
+ "for information on how to run TensorFlow benchmarks."
+ )
+ exit(1)
- # Convert test suite shorthands to full test suite targets.
- test_suites = [SUITE_NAME_TO_TARGET[suite] for suite in FLAGS.test_suites]
+ # Convert test suite shorthands to full test suite targets.
+ test_suites = [SUITE_NAME_TO_TARGET[suite] for suite in FLAGS.test_suites]
- if FLAGS.run_test_suites:
- # Use bazel test to execute all of the test suites in parallel.
- command = ['bazel', 'test', *test_suites, '--color=yes']
- print(f'Running: `{" ".join(command)}`')
- if not FLAGS.dry_run:
- subprocess.run(command, check=True)
- print()
+ if FLAGS.run_test_suites:
+ # Use bazel test to execute all of the test suites in parallel.
+ command = ["bazel", "test", *test_suites, "--color=yes"]
+ print(f'Running: `{" ".join(command)}`')
+ if not FLAGS.dry_run:
+ subprocess.run(command, check=True)
+ print()
- written_paths = set()
- paths_to_tests = dict()
+ written_paths = set()
+ paths_to_tests = dict()
- for test_suite in test_suites:
- # Extract all of the artifacts for this test suite.
- test_paths, test_names = get_test_paths_and_names(test_suite)
- for i, (test_path, test_name) in enumerate(zip(test_paths, test_names)):
- print(f'\rTransfering {test_suite} {i + 1}/{len(test_paths)}', end='')
- extract_artifacts(test_path, test_name, written_paths, paths_to_tests)
- print('\n')
+ for test_suite in test_suites:
+ # Extract all of the artifacts for this test suite.
+ test_paths, test_names = get_test_paths_and_names(test_suite)
+ for i, (test_path, test_name) in enumerate(zip(test_paths, test_names)):
+ print(f"\rTransfering {test_suite} {i + 1}/{len(test_paths)}", end="")
+ extract_artifacts(test_path, test_name, written_paths, paths_to_tests)
+ print("\n")
-if __name__ == '__main__':
- app.run(main)
+if __name__ == "__main__":
+ app.run(main)
diff --git a/build_tools/scripts/git/check_submodule_init.py b/build_tools/scripts/git/check_submodule_init.py
index 611c32f..b878ef3 100644
--- a/build_tools/scripts/git/check_submodule_init.py
+++ b/build_tools/scripts/git/check_submodule_init.py
@@ -12,37 +12,47 @@
def run():
- parser = argparse.ArgumentParser()
- parser.add_argument(
- "--runtime_only",
- help=("Only check the initialization of the submodules for the"
- "runtime-dependent submodules. Default: False"),
- action="store_true",
- default=False)
- args = parser.parse_args()
- # No-op if we're not in a git repository.
- try:
- subprocess.check_call(['git', 'rev-parse', '--is-inside-work-tree'],
- stdout=subprocess.DEVNULL,
- stderr=subprocess.DEVNULL)
- except:
- return
+ parser = argparse.ArgumentParser()
+ parser.add_argument(
+ "--runtime_only",
+ help=(
+ "Only check the initialization of the submodules for the"
+ "runtime-dependent submodules. Default: False"
+ ),
+ action="store_true",
+ default=False,
+ )
+ args = parser.parse_args()
+ # No-op if we're not in a git repository.
+ try:
+ subprocess.check_call(
+ ["git", "rev-parse", "--is-inside-work-tree"],
+ stdout=subprocess.DEVNULL,
+ stderr=subprocess.DEVNULL,
+ )
+ except:
+ return
- output = os.popen("git submodule status")
- submodules = output.readlines()
+ output = os.popen("git submodule status")
+ submodules = output.readlines()
- runtime_submodules = pathlib.Path(__file__).with_name(
- "runtime_submodules.txt").read_text().split("\n")
+ runtime_submodules = (
+ pathlib.Path(__file__)
+ .with_name("runtime_submodules.txt")
+ .read_text()
+ .split("\n")
+ )
- for submodule in submodules:
- prefix = submodule.strip()[0]
- name = submodule.split()[1]
- if prefix == "-" and (not args.runtime_only or name in runtime_submodules):
- print(
- "The git submodule '%s' is not initialized. Please run `git submodule update --init`"
- % (name))
- sys.exit(1)
+ for submodule in submodules:
+ prefix = submodule.strip()[0]
+ name = submodule.split()[1]
+ if prefix == "-" and (not args.runtime_only or name in runtime_submodules):
+ print(
+ "The git submodule '%s' is not initialized. Please run `git submodule update --init`"
+ % (name)
+ )
+ sys.exit(1)
if __name__ == "__main__":
- run()
+ run()
diff --git a/build_tools/scripts/integrate/bump_llvm.py b/build_tools/scripts/integrate/bump_llvm.py
index 7dd770b..459da10 100755
--- a/build_tools/scripts/integrate/bump_llvm.py
+++ b/build_tools/scripts/integrate/bump_llvm.py
@@ -44,92 +44,103 @@
def main(args):
- if not args.disable_setup_remote:
- iree_utils.git_setup_remote(args.upstream_remote, args.upstream_repository)
+ if not args.disable_setup_remote:
+ iree_utils.git_setup_remote(args.upstream_remote, args.upstream_repository)
- iree_utils.git_check_porcelain()
- print(f"Fetching remote repository: {args.upstream_remote}")
- iree_utils.git_fetch(repository=args.upstream_remote)
+ iree_utils.git_check_porcelain()
+ print(f"Fetching remote repository: {args.upstream_remote}")
+ iree_utils.git_fetch(repository=args.upstream_remote)
- # If re-using a branch, make sure we are not on that branch.
- if args.reuse_branch:
- iree_utils.git_checkout("main")
+ # If re-using a branch, make sure we are not on that branch.
+ if args.reuse_branch:
+ iree_utils.git_checkout("main")
- # Create branch.
- branch_name = args.branch_name
- if not branch_name:
- branch_name = f"bump-llvm-{date.today().strftime('%Y%m%d')}"
- print(f"Creating branch {branch_name} (override with --branch-name=)")
- iree_utils.git_create_branch(branch_name,
- checkout=True,
- ref=f"{args.upstream_remote}/main",
- force=args.reuse_branch)
+ # Create branch.
+ branch_name = args.branch_name
+ if not branch_name:
+ branch_name = f"bump-llvm-{date.today().strftime('%Y%m%d')}"
+ print(f"Creating branch {branch_name} (override with --branch-name=)")
+ iree_utils.git_create_branch(
+ branch_name,
+ checkout=True,
+ ref=f"{args.upstream_remote}/main",
+ force=args.reuse_branch,
+ )
- # Reset the llvm-project submodule to track upstream.
- # This will discard any cherrypicks that may have been committed locally,
- # but the assumption is that if doing a main llvm version bump, the
- # cherrypicks will be incorporated at the new commit. If not, well, ymmv
- # and you will find out.
- iree_utils.git_submodule_set_origin(
- "third_party/llvm-project",
- url="https://github.com/iree-org/iree-llvm-fork.git",
- branch="--default")
+ # Reset the llvm-project submodule to track upstream.
+ # This will discard any cherrypicks that may have been committed locally,
+ # but the assumption is that if doing a main llvm version bump, the
+ # cherrypicks will be incorporated at the new commit. If not, well, ymmv
+ # and you will find out.
+ iree_utils.git_submodule_set_origin(
+ "third_party/llvm-project",
+ url="https://github.com/iree-org/iree-llvm-fork.git",
+ branch="--default",
+ )
- # Remove the branch pin file, reverting us to pure upstream.
- branch_pin_file = os.path.join(
- iree_utils.get_repo_root(),
- iree_modules.MODULE_INFOS["llvm-project"].branch_pin_file)
- if os.path.exists(branch_pin_file):
- os.remove(branch_pin_file)
+ # Remove the branch pin file, reverting us to pure upstream.
+ branch_pin_file = os.path.join(
+ iree_utils.get_repo_root(),
+ iree_modules.MODULE_INFOS["llvm-project"].branch_pin_file,
+ )
+ if os.path.exists(branch_pin_file):
+ os.remove(branch_pin_file)
- # Update the LLVM submodule.
- llvm_commit = args.llvm_commit
- print(f"Updating LLVM submodule to {llvm_commit}")
- llvm_root = iree_utils.get_submodule_root("llvm-project")
- iree_utils.git_fetch(repository="origin",
- ref="refs/heads/main",
- repo_dir=llvm_root)
- if llvm_commit == "HEAD":
- llvm_commit = "origin/main"
- iree_utils.git_reset(llvm_commit, repo_dir=llvm_root)
- llvm_commit, llvm_summary = iree_utils.git_current_commit(repo_dir=llvm_root)
- print(f"LLVM submodule reset to:\n {llvm_summary}\n")
+ # Update the LLVM submodule.
+ llvm_commit = args.llvm_commit
+ print(f"Updating LLVM submodule to {llvm_commit}")
+ llvm_root = iree_utils.get_submodule_root("llvm-project")
+ iree_utils.git_fetch(repository="origin", ref="refs/heads/main", repo_dir=llvm_root)
+ if llvm_commit == "HEAD":
+ llvm_commit = "origin/main"
+ iree_utils.git_reset(llvm_commit, repo_dir=llvm_root)
+ llvm_commit, llvm_summary = iree_utils.git_current_commit(repo_dir=llvm_root)
+ print(f"LLVM submodule reset to:\n {llvm_summary}\n")
- # Create a commit.
- print("Create commit...")
- iree_utils.git_create_commit(
- message=(f"Integrate llvm-project at {llvm_commit}\n\n"
- f"* Reset third_party/llvm-project: {llvm_summary}"),
- add_all=True)
+ # Create a commit.
+ print("Create commit...")
+ iree_utils.git_create_commit(
+ message=(
+ f"Integrate llvm-project at {llvm_commit}\n\n"
+ f"* Reset third_party/llvm-project: {llvm_summary}"
+ ),
+ add_all=True,
+ )
- # Push.
- print("Pushing...")
- iree_utils.git_push_branch(args.upstream_remote, branch_name)
+ # Push.
+ print("Pushing...")
+ iree_utils.git_push_branch(args.upstream_remote, branch_name)
def parse_arguments(argv):
- parser = argparse.ArgumentParser(description="IREE LLVM-bump-inator")
- parser.add_argument("--upstream-remote",
- help="Upstream remote",
- default="UPSTREAM_AUTOMATION")
- parser.add_argument("--upstream-repository",
- help="Upstream repository URL",
- default="git@github.com:openxla/iree.git")
- parser.add_argument("--disable-setup-remote",
- help="Disable remote setup",
- action="store_true",
- default=False)
- parser.add_argument("--llvm-commit", help="LLVM commit sha", default="HEAD")
- parser.add_argument("--branch-name",
- help="Integrate branch to create",
- default=None)
- parser.add_argument("--reuse-branch",
- help="Allow re-use of an existing branch",
- action="store_true",
- default=False)
- args = parser.parse_args(argv)
- return args
+ parser = argparse.ArgumentParser(description="IREE LLVM-bump-inator")
+ parser.add_argument(
+ "--upstream-remote", help="Upstream remote", default="UPSTREAM_AUTOMATION"
+ )
+ parser.add_argument(
+ "--upstream-repository",
+ help="Upstream repository URL",
+ default="git@github.com:openxla/iree.git",
+ )
+ parser.add_argument(
+ "--disable-setup-remote",
+ help="Disable remote setup",
+ action="store_true",
+ default=False,
+ )
+ parser.add_argument("--llvm-commit", help="LLVM commit sha", default="HEAD")
+ parser.add_argument(
+ "--branch-name", help="Integrate branch to create", default=None
+ )
+ parser.add_argument(
+ "--reuse-branch",
+ help="Allow re-use of an existing branch",
+ action="store_true",
+ default=False,
+ )
+ args = parser.parse_args(argv)
+ return args
if __name__ == "__main__":
- main(parse_arguments(sys.argv[1:]))
+ main(parse_arguments(sys.argv[1:]))
diff --git a/build_tools/scripts/integrate/iree_modules.py b/build_tools/scripts/integrate/iree_modules.py
index fec8ff3..5587333 100644
--- a/build_tools/scripts/integrate/iree_modules.py
+++ b/build_tools/scripts/integrate/iree_modules.py
@@ -6,40 +6,43 @@
class ModuleInfo:
-
- def __init__(self, *, name: str, path: str, branch_pin_file: str,
- default_repository_url: str, fork_repository_push: str,
- fork_repository_pull: str, branch_prefix: str):
- self.name = name
- self.path = path
- self.branch_pin_file = branch_pin_file
- self.default_repository_url = default_repository_url
- self.fork_repository_push = fork_repository_push
- self.fork_repository_pull = fork_repository_pull
- self.branch_prefix = branch_prefix
+ def __init__(
+ self,
+ *,
+ name: str,
+ path: str,
+ branch_pin_file: str,
+ default_repository_url: str,
+ fork_repository_push: str,
+ fork_repository_pull: str,
+ branch_prefix: str
+ ):
+ self.name = name
+ self.path = path
+ self.branch_pin_file = branch_pin_file
+ self.default_repository_url = default_repository_url
+ self.fork_repository_push = fork_repository_push
+ self.fork_repository_pull = fork_repository_pull
+ self.branch_prefix = branch_prefix
MODULE_INFOS = {
- "llvm-project":
- ModuleInfo(
- name="llvm-project",
- path="third_party/llvm-project",
- branch_pin_file="third_party/llvm-project.branch-pin",
- default_repository_url=
- "https://github.com/iree-org/iree-llvm-fork.git",
- fork_repository_push="git@github.com:iree-org/iree-llvm-fork.git",
- fork_repository_pull=
- "https://github.com/iree-org/iree-llvm-fork.git",
- branch_prefix="patched-llvm-project-",
- ),
- "stablehlo":
- ModuleInfo(
- name="stablehlo",
- path="third_party/stablehlo",
- branch_pin_file="third_party/stablehlo.branch-pin",
- default_repository_url="https://github.com/iree-org/stablehlo.git",
- fork_repository_push="git@github.com:iree-org/stablehlo.git",
- fork_repository_pull="https://github.com/iree-org/stablehlo.git",
- branch_prefix="patched-stablehlo-",
- )
+ "llvm-project": ModuleInfo(
+ name="llvm-project",
+ path="third_party/llvm-project",
+ branch_pin_file="third_party/llvm-project.branch-pin",
+ default_repository_url="https://github.com/iree-org/iree-llvm-fork.git",
+ fork_repository_push="git@github.com:iree-org/iree-llvm-fork.git",
+ fork_repository_pull="https://github.com/iree-org/iree-llvm-fork.git",
+ branch_prefix="patched-llvm-project-",
+ ),
+ "stablehlo": ModuleInfo(
+ name="stablehlo",
+ path="third_party/stablehlo",
+ branch_pin_file="third_party/stablehlo.branch-pin",
+ default_repository_url="https://github.com/iree-org/stablehlo.git",
+ fork_repository_push="git@github.com:iree-org/stablehlo.git",
+ fork_repository_pull="https://github.com/iree-org/stablehlo.git",
+ branch_prefix="patched-stablehlo-",
+ ),
}
diff --git a/build_tools/scripts/integrate/iree_utils.py b/build_tools/scripts/integrate/iree_utils.py
index 3a81ba8..21e5454 100644
--- a/build_tools/scripts/integrate/iree_utils.py
+++ b/build_tools/scripts/integrate/iree_utils.py
@@ -15,186 +15,200 @@
def get_repo_root() -> str:
- global _repo_root
- if _repo_root is None:
- _repo_root = os.getcwd()
- _validate_repo_root()
- return _repo_root
+ global _repo_root
+ if _repo_root is None:
+ _repo_root = os.getcwd()
+ _validate_repo_root()
+ return _repo_root
def get_submodule_root(submodule) -> str:
- path = os.path.join(get_repo_root(), "third_party", submodule)
- if not os.path.isdir(path):
- raise SystemExit(f"Could not find submodule: {path}")
- return path
+ path = os.path.join(get_repo_root(), "third_party", submodule)
+ if not os.path.isdir(path):
+ raise SystemExit(f"Could not find submodule: {path}")
+ return path
def _validate_repo_root():
- # Look for something we know is there.
- known_dir = os.path.join(_repo_root, "compiler")
- if not os.path.isdir(known_dir):
- raise SystemExit(f"ERROR: Must run from the iree repository root. "
- f"Actually in: {_repo_root}")
+ # Look for something we know is there.
+ known_dir = os.path.join(_repo_root, "compiler")
+ if not os.path.isdir(known_dir):
+ raise SystemExit(
+ f"ERROR: Must run from the iree repository root. "
+ f"Actually in: {_repo_root}"
+ )
def git_setup_remote(remote_alias, url, *, repo_dir=None):
- needs_create = False
- try:
- existing_url = git_exec(["remote", "get-url", remote_alias],
- capture_output=True,
- repo_dir=repo_dir,
- quiet=True)
- existing_url = existing_url.strip()
- if existing_url == url:
- return
- except subprocess.CalledProcessError:
- # Does not exist.
- needs_create = True
+ needs_create = False
+ try:
+ existing_url = git_exec(
+ ["remote", "get-url", remote_alias],
+ capture_output=True,
+ repo_dir=repo_dir,
+ quiet=True,
+ )
+ existing_url = existing_url.strip()
+ if existing_url == url:
+ return
+ except subprocess.CalledProcessError:
+ # Does not exist.
+ needs_create = True
- if needs_create:
- git_exec(["remote", "add", "--no-tags", remote_alias, url],
- repo_dir=repo_dir)
- else:
- git_exec(["remote", "set-url", remote_alias, url], repo_dir=repo_dir)
+ if needs_create:
+ git_exec(["remote", "add", "--no-tags", remote_alias, url], repo_dir=repo_dir)
+ else:
+ git_exec(["remote", "set-url", remote_alias, url], repo_dir=repo_dir)
def git_is_porcelain(*, repo_dir=None):
- output = git_exec(["status", "--porcelain", "--untracked-files=no"],
- capture_output=True,
- quiet=True,
- repo_dir=repo_dir).strip()
- return not bool(output)
+ output = git_exec(
+ ["status", "--porcelain", "--untracked-files=no"],
+ capture_output=True,
+ quiet=True,
+ repo_dir=repo_dir,
+ ).strip()
+ return not bool(output)
def git_check_porcelain(*, repo_dir=None):
- output = git_exec(["status", "--porcelain", "--untracked-files=no"],
- capture_output=True,
- quiet=True,
- repo_dir=repo_dir).strip()
- if output:
- actual_repo_dir = get_repo_root() if repo_dir is None else repo_dir
- raise SystemExit(f"ERROR: git directory {actual_repo_dir} is not clean. "
- f"Please stash changes:\n{output}")
+ output = git_exec(
+ ["status", "--porcelain", "--untracked-files=no"],
+ capture_output=True,
+ quiet=True,
+ repo_dir=repo_dir,
+ ).strip()
+ if output:
+ actual_repo_dir = get_repo_root() if repo_dir is None else repo_dir
+ raise SystemExit(
+ f"ERROR: git directory {actual_repo_dir} is not clean. "
+ f"Please stash changes:\n{output}"
+ )
def git_fetch(*, repository=None, ref=None, repo_dir=None):
- args = ["fetch"]
- if repository:
- args.append(repository)
- if ref is not None:
- args.append(ref)
- git_exec(args, repo_dir=repo_dir)
+ args = ["fetch"]
+ if repository:
+ args.append(repository)
+ if ref is not None:
+ args.append(ref)
+ git_exec(args, repo_dir=repo_dir)
def git_checkout(ref, *, repo_dir=None):
- git_exec(["checkout", ref], repo_dir=repo_dir)
+ git_exec(["checkout", ref], repo_dir=repo_dir)
-def git_create_branch(branch_name,
- *,
- checkout=True,
- ref=None,
- force=False,
- repo_dir=None):
- branch_args = ["branch"]
- if force:
- branch_args.append("-f")
- branch_args.append(branch_name)
- if ref is not None:
- branch_args.append(ref)
- git_exec(branch_args, repo_dir=repo_dir)
+def git_create_branch(
+ branch_name, *, checkout=True, ref=None, force=False, repo_dir=None
+):
+ branch_args = ["branch"]
+ if force:
+ branch_args.append("-f")
+ branch_args.append(branch_name)
+ if ref is not None:
+ branch_args.append(ref)
+ git_exec(branch_args, repo_dir=repo_dir)
- if checkout:
- git_exec(["checkout", branch_name], repo_dir=repo_dir)
+ if checkout:
+ git_exec(["checkout", branch_name], repo_dir=repo_dir)
def git_push_branch(repository, branch_name, *, force=False, repo_dir=None):
- push_args = ["push", "--set-upstream"]
- if force:
- push_args.append("-f")
- push_args.append(repository)
- push_args.append(f"{branch_name}:{branch_name}")
- git_exec(push_args, repo_dir=repo_dir)
+ push_args = ["push", "--set-upstream"]
+ if force:
+ push_args.append("-f")
+ push_args.append(repository)
+ push_args.append(f"{branch_name}:{branch_name}")
+ git_exec(push_args, repo_dir=repo_dir)
def git_branch_exists(branch_name, *, repo_dir=None):
- output = git_exec(["branch", "-l", branch_name],
- repo_dir=repo_dir,
- quiet=True,
- capture_output=True).strip()
- return bool(output)
+ output = git_exec(
+ ["branch", "-l", branch_name],
+ repo_dir=repo_dir,
+ quiet=True,
+ capture_output=True,
+ ).strip()
+ return bool(output)
def git_submodule_set_origin(path, *, url=None, branch=None, repo_dir=None):
- if url is not None:
- git_exec(["submodule", "set-url", "--", path, url], repo_dir=repo_dir)
+ if url is not None:
+ git_exec(["submodule", "set-url", "--", path, url], repo_dir=repo_dir)
- if branch is not None:
- try:
- if branch == "--default":
- git_exec(["submodule", "set-branch", "--default", "--", path],
- repo_dir=repo_dir)
- else:
- git_exec(["submodule", "set-branch", "--branch", branch, "--", path],
- repo_dir=repo_dir)
- except subprocess.CalledProcessError:
- # The set-branch command returns 0 on change and !0 on no change.
- # This is a bit unfortunate.
- ...
+ if branch is not None:
+ try:
+ if branch == "--default":
+ git_exec(
+ ["submodule", "set-branch", "--default", "--", path],
+ repo_dir=repo_dir,
+ )
+ else:
+ git_exec(
+ ["submodule", "set-branch", "--branch", branch, "--", path],
+ repo_dir=repo_dir,
+ )
+ except subprocess.CalledProcessError:
+ # The set-branch command returns 0 on change and !0 on no change.
+ # This is a bit unfortunate.
+ ...
def git_reset(ref, *, hard=True, repo_dir=None):
- args = ["reset"]
- if hard:
- args.append("--hard")
- args.append(ref)
- git_exec(args, repo_dir=repo_dir)
+ args = ["reset"]
+ if hard:
+ args.append("--hard")
+ args.append(ref)
+ git_exec(args, repo_dir=repo_dir)
def git_current_commit(*, repo_dir=None) -> Tuple[str, str]:
- output = git_exec(["log", "-n", "1", "--pretty=format:%H (%ci): %s"],
- capture_output=True,
- repo_dir=repo_dir,
- quiet=True)
- output = output.strip()
- parts = output.split(" ")
- # Return commit, full_summary
- return parts[0], output
+ output = git_exec(
+ ["log", "-n", "1", "--pretty=format:%H (%ci): %s"],
+ capture_output=True,
+ repo_dir=repo_dir,
+ quiet=True,
+ )
+ output = output.strip()
+ parts = output.split(" ")
+ # Return commit, full_summary
+ return parts[0], output
def git_create_commit(*, message, add_all=False, repo_dir=None):
- if add_all:
- git_exec(["add", "-A"], repo_dir=repo_dir)
- git_exec(["commit", "-m", message])
+ if add_all:
+ git_exec(["add", "-A"], repo_dir=repo_dir)
+ git_exec(["commit", "-m", message])
def git_ls_remote_branches(repository_url, *, filter=None, repo_dir=None):
- args = ["ls-remote", "-h", repository_url]
- if filter:
- args.extend(filter)
- output = git_exec(args, quiet=True, capture_output=True)
- lines = output.strip().splitlines(keepends=False)
+ args = ["ls-remote", "-h", repository_url]
+ if filter:
+ args.extend(filter)
+ output = git_exec(args, quiet=True, capture_output=True)
+ lines = output.strip().splitlines(keepends=False)
- # Format is <commit> refs/heads/branch_name
- def extract_branch(line):
- parts = re.split("\\s+", line)
- ref = parts[1]
- prefix = "refs/heads/"
- if ref.startswith(prefix):
- ref = ref[len(prefix):]
- return ref
+ # Format is <commit> refs/heads/branch_name
+ def extract_branch(line):
+ parts = re.split("\\s+", line)
+ ref = parts[1]
+ prefix = "refs/heads/"
+ if ref.startswith(prefix):
+ ref = ref[len(prefix) :]
+ return ref
- return [extract_branch(l) for l in lines]
+ return [extract_branch(l) for l in lines]
def git_exec(args, *, repo_dir=None, quiet=False, capture_output=False):
- full_args = ["git"] + args
- full_args_quoted = [shlex.quote(a) for a in full_args]
- if not repo_dir:
- repo_dir = get_repo_root()
- if not quiet:
- print(f" ++ EXEC: (cd {repo_dir} && {' '.join(full_args_quoted)})")
- if capture_output:
- return subprocess.check_output(full_args, cwd=repo_dir).decode("utf-8")
- else:
- subprocess.check_call(full_args, cwd=repo_dir)
+ full_args = ["git"] + args
+ full_args_quoted = [shlex.quote(a) for a in full_args]
+ if not repo_dir:
+ repo_dir = get_repo_root()
+ if not quiet:
+ print(f" ++ EXEC: (cd {repo_dir} && {' '.join(full_args_quoted)})")
+ if capture_output:
+ return subprocess.check_output(full_args, cwd=repo_dir).decode("utf-8")
+ else:
+ subprocess.check_call(full_args, cwd=repo_dir)
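For context, `git_exec` above is the single choke point for every git invocation in these scripts: it shell-quotes and echoes the command, defaults `repo_dir` to the repository root, and switches between `check_output` and `check_call` depending on whether the caller wants the output. A minimal, self-contained sketch of that pattern (the `rev-parse` usage at the bottom is a hypothetical example and assumes it runs inside a git checkout):

```python
# Minimal sketch of the git_exec wrapper pattern; not part of the patch.
import shlex
import subprocess


def run_git(args, *, repo_dir=".", quiet=False, capture_output=False):
    """Runs `git <args>` in repo_dir, echoing the command unless quiet."""
    full_args = ["git"] + args
    if not quiet:
        quoted = " ".join(shlex.quote(a) for a in full_args)
        print(f"  ++ EXEC: (cd {repo_dir} && {quoted})")
    if capture_output:
        return subprocess.check_output(full_args, cwd=repo_dir).decode("utf-8")
    subprocess.check_call(full_args, cwd=repo_dir)


if __name__ == "__main__":
    # Hypothetical usage, mirroring what git_current_commit does above.
    print(run_git(["rev-parse", "HEAD"], capture_output=True, quiet=True).strip())
```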
diff --git a/build_tools/scripts/integrate/patch_module.py b/build_tools/scripts/integrate/patch_module.py
index 2184bfe..fbe2230 100755
--- a/build_tools/scripts/integrate/patch_module.py
+++ b/build_tools/scripts/integrate/patch_module.py
@@ -32,78 +32,77 @@
def main(args):
- module_info = iree_modules.MODULE_INFOS.get(args.module)
- if not module_info:
- raise SystemExit(f"ERROR: Bad value for --module. Must be one of: "
- f"{', '.join(iree_modules.MODULE_INFOS.keys())}")
+ module_info = iree_modules.MODULE_INFOS.get(args.module)
+ if not module_info:
+ raise SystemExit(
+ f"ERROR: Bad value for --module. Must be one of: "
+ f"{', '.join(iree_modules.MODULE_INFOS.keys())}"
+ )
- if args.command == "patch":
- main_patch(args, module_info)
- else:
- raise SystemExit(
- f"ERROR: Unrecognized --command. Must be one of: patch, unpatch")
+ if args.command == "patch":
+ main_patch(args, module_info)
+ else:
+ raise SystemExit(
+ f"ERROR: Unrecognized --command. Must be one of: patch, unpatch"
+ )
def main_patch(args, module_info: iree_modules.ModuleInfo):
- module_root = os.path.join(iree_utils.get_repo_root(), module_info.path)
- setup_module_remotes(module_root, module_info)
+ module_root = os.path.join(iree_utils.get_repo_root(), module_info.path)
+ setup_module_remotes(module_root, module_info)
- branch_name = find_unused_branch_name(module_info)
- print(f"Allocated branch: {branch_name}")
- current_commit, summary = iree_utils.git_current_commit(repo_dir=module_root)
- print(f"Module is currently at: {summary}")
- print(
- f"*** Pushing branch {branch_name} to {module_info.fork_repository_push} ***"
- )
- print(f"(Please ignore any messages below about creating a PR)\n")
- iree_utils.git_exec([
- "push", PATCH_REMOTE_ALIAS, f"{current_commit}:refs/heads/{branch_name}"
- ],
- repo_dir=module_root)
- print(f"*** Branch {branch_name} pushed ***")
+ branch_name = find_unused_branch_name(module_info)
+ print(f"Allocated branch: {branch_name}")
+ current_commit, summary = iree_utils.git_current_commit(repo_dir=module_root)
+ print(f"Module is currently at: {summary}")
+ print(f"*** Pushing branch {branch_name} to {module_info.fork_repository_push} ***")
+ print(f"(Please ignore any messages below about creating a PR)\n")
+ iree_utils.git_exec(
+ ["push", PATCH_REMOTE_ALIAS, f"{current_commit}:refs/heads/{branch_name}"],
+ repo_dir=module_root,
+ )
+ print(f"*** Branch {branch_name} pushed ***")
- print(f"******* Congratulations *******")
- print(
- f"You have pushed your commits to {branch_name} on {module_info.fork_repository_push}."
- )
- print(
- f"Your main repository should now show that the submodule has been edited."
- )
- print(f"Make a commit, referencing the above branch cherry-picks and ")
- print(f"land the resulting PR.")
- print(f"You can push more commits to this module's patch branch via:")
- print(
- f" (cd {module_info.path} && git push {PATCH_REMOTE_ALIAS} HEAD:{branch_name})"
- )
+ print(f"******* Congratulations *******")
+ print(
+ f"You have pushed your commits to {branch_name} on {module_info.fork_repository_push}."
+ )
+ print(f"Your main repository should now show that the submodule has been edited.")
+ print(f"Make a commit, referencing the above branch cherry-picks and ")
+ print(f"land the resulting PR.")
+ print(f"You can push more commits to this module's patch branch via:")
+ print(
+ f" (cd {module_info.path} && git push {PATCH_REMOTE_ALIAS} HEAD:{branch_name})"
+ )
-def setup_module_remotes(module_root: str,
- module_info: iree_modules.ModuleInfo):
- iree_utils.git_setup_remote(PATCH_REMOTE_ALIAS,
- url=module_info.fork_repository_push,
- repo_dir=module_root)
+def setup_module_remotes(module_root: str, module_info: iree_modules.ModuleInfo):
+ iree_utils.git_setup_remote(
+ PATCH_REMOTE_ALIAS, url=module_info.fork_repository_push, repo_dir=module_root
+ )
def find_unused_branch_name(module_info: iree_modules.ModuleInfo):
- branch_base = f"{module_info.branch_prefix}{date.today().strftime('%Y%m%d')}"
- branch_name = branch_base
- existing_branches = iree_utils.git_ls_remote_branches(
- module_info.fork_repository_pull,
- filter=[f"refs/heads/{module_info.branch_prefix}*"])
- i = 1
- while branch_name in existing_branches:
- branch_name = f"{branch_base}.{i}"
- i += 1
- return branch_name
+ branch_base = f"{module_info.branch_prefix}{date.today().strftime('%Y%m%d')}"
+ branch_name = branch_base
+ existing_branches = iree_utils.git_ls_remote_branches(
+ module_info.fork_repository_pull,
+ filter=[f"refs/heads/{module_info.branch_prefix}*"],
+ )
+ i = 1
+ while branch_name in existing_branches:
+ branch_name = f"{branch_base}.{i}"
+ i += 1
+ return branch_name
def parse_arguments(argv):
- parser = argparse.ArgumentParser(description="IREE Submodule Patcher")
- parser.add_argument("--module", help="Submodule to operate on", default=None)
- parser.add_argument("--command", help="Command to execute", default="patch")
- args = parser.parse_args(argv)
- return args
+ parser = argparse.ArgumentParser(description="IREE Submodule Patcher")
+ parser.add_argument("--module", help="Submodule to operate on", default=None)
+ parser.add_argument("--command", help="Command to execute", default="patch")
+ args = parser.parse_args(argv)
+ return args
if __name__ == "__main__":
- main(parse_arguments(sys.argv[1:]))
+ main(parse_arguments(sys.argv[1:]))
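For reference, the allocation loop in `find_unused_branch_name` above starts from `<branch_prefix><YYYYMMDD>` and appends `.1`, `.2`, ... until the name is not already taken on the remote. A standalone sketch of just that logic (the `integrates/llvm-` prefix and the dates in the comments are made up for illustration):

```python
# Standalone sketch of find_unused_branch_name's naming scheme.
from datetime import date


def allocate_branch(prefix, existing_branches):
    base = f"{prefix}{date.today().strftime('%Y%m%d')}"
    name = base
    i = 1
    while name in existing_branches:
        name = f"{base}.{i}"
        i += 1
    return name


existing = set()
first = allocate_branch("integrates/llvm-", existing)  # e.g. integrates/llvm-20230630
existing.add(first)
print(allocate_branch("integrates/llvm-", existing))  # e.g. integrates/llvm-20230630.1
```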
diff --git a/build_tools/scripts/ir_to_markdown.py b/build_tools/scripts/ir_to_markdown.py
index 2642f42..476dff3 100644
--- a/build_tools/scripts/ir_to_markdown.py
+++ b/build_tools/scripts/ir_to_markdown.py
@@ -34,71 +34,74 @@
def parse_arguments():
- """Parses command line arguments."""
+ """Parses command line arguments."""
- parser = argparse.ArgumentParser()
- parser.add_argument(
- 'input_file_path',
- type=str,
- nargs='?',
- metavar="<input_file_path>",
- help='Input IR dump (.mlir from -mlir-print-ir-after-all)')
- parser.add_argument('-o,',
- '--output',
- type=str,
- required=True,
- metavar="<output>",
- help='Output file path (e.g. translation_ir.md)')
- # TODO(scotttodd): flags for original IR path and compilation command line
- # .md could then show original IR + flags -> output
- # TODO(scotttodd): flag for markdown flavor (mkdocs, github, etc.)
- # TODO(scotttodd): flag for diff view (correlate IR before and IR after)?
+ parser = argparse.ArgumentParser()
+ parser.add_argument(
+ "input_file_path",
+ type=str,
+ nargs="?",
+ metavar="<input_file_path>",
+ help="Input IR dump (.mlir from -mlir-print-ir-after-all)",
+ )
+ parser.add_argument(
+ "-o,",
+ "--output",
+ type=str,
+ required=True,
+ metavar="<output>",
+ help="Output file path (e.g. translation_ir.md)",
+ )
+ # TODO(scotttodd): flags for original IR path and compilation command line
+ # .md could then show original IR + flags -> output
+ # TODO(scotttodd): flag for markdown flavor (mkdocs, github, etc.)
+ # TODO(scotttodd): flag for diff view (correlate IR before and IR after)?
- return parser.parse_args()
+ return parser.parse_args()
def main(args):
- input_file_path = args.input_file_path
- output_file_path = args.output
- print("Converting input file '%s'" % (input_file_path))
- print(" into output file '%s'" % (output_file_path))
+ input_file_path = args.input_file_path
+ output_file_path = args.output
+ print("Converting input file '%s'" % (input_file_path))
+ print(" into output file '%s'" % (output_file_path))
- with open(input_file_path, "r") as input_file:
- with open(output_file_path, "w") as output_file:
+ with open(input_file_path, "r") as input_file:
+ with open(output_file_path, "w") as output_file:
+ # Iterate line by line through the input file, collecting text into
+ # blocks and writing them into the output file with markdown formatting
+ # as we go.
+ #
+ # Note: we could parse through and find/replace within the file using
+ # regex (or sed), but iterating this way is easier to understand and
+ # uses a predictable amount of memory.
- # Iterate line by line through the input file, collecting text into
- # blocks and writing them into the output file with markdown formatting
- # as we go.
- #
- # Note: we could parse through and find/replace within the file using
- # regex (or sed), but iterating this way is easier to understand and
- # uses a predictable amount of memory.
+ current_block_lines = []
+ dump_after_regex = re.compile(
+ MLIR_START_SEQUENCE + "\s(.*)\s" + MLIR_END_SEQUENCE
+ )
- current_block_lines = []
- dump_after_regex = re.compile(MLIR_START_SEQUENCE + "\s(.*)\s" +
- MLIR_END_SEQUENCE)
+ def finish_block():
+ nonlocal current_block_lines
+ if len(current_block_lines) != 0:
+ current_block_lines.append("```\n\n")
+ output_file.writelines(current_block_lines)
+ current_block_lines = []
- def finish_block():
- nonlocal current_block_lines
- if len(current_block_lines) != 0:
- current_block_lines.append("```\n\n")
- output_file.writelines(current_block_lines)
- current_block_lines = []
+ for input_line in input_file:
+ if input_line == "\n":
+ continue
- for input_line in input_file:
- if input_line == "\n":
- continue
+ if input_line.startswith(MLIR_START_SEQUENCE):
+ finish_block()
+ header_text = dump_after_regex.match(input_line).group(1)
+ current_block_lines.append("### " + header_text + "\n\n")
+ current_block_lines.append("```mlir\n")
+ else:
+ current_block_lines.append(input_line)
- if input_line.startswith(MLIR_START_SEQUENCE):
- finish_block()
- header_text = dump_after_regex.match(input_line).group(1)
- current_block_lines.append("### " + header_text + "\n\n")
- current_block_lines.append("```mlir\n")
- else:
- current_block_lines.append(input_line)
-
- finish_block()
+ finish_block()
-if __name__ == '__main__':
- main(parse_arguments())
+if __name__ == "__main__":
+ main(parse_arguments())
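The loop above turns each `IR Dump After <pass>` banner into a `###` heading followed by an `mlir` code fence, accumulating everything else into the current block. A toy run of that logic is sketched below; note that `MLIR_START_SEQUENCE` and `MLIR_END_SEQUENCE` are defined earlier in `ir_to_markdown.py` and are not visible in this hunk, so the values here are assumptions for illustration only:

```python
# Toy run of the block-splitting loop in ir_to_markdown.py.
import re

MLIR_START_SEQUENCE = "// -----//"  # assumed value, not shown in this hunk
MLIR_END_SEQUENCE = "//----- //"  # assumed value, not shown in this hunk
FENCE = "`" * 3  # avoids embedding a literal fence in this example

dump_after_regex = re.compile(
    re.escape(MLIR_START_SEQUENCE) + r"\s(.*)\s" + re.escape(MLIR_END_SEQUENCE)
)

sample = [
    "// -----// IR Dump After Canonicalizer //----- //",
    "func.func @main() { return }",
]
for line in sample:
    match = dump_after_regex.match(line)
    if match:
        print("### " + match.group(1) + "\n\n" + FENCE + "mlir")
    else:
        print(line)
print(FENCE)
```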
diff --git a/build_tools/scripts/local_web_server.py b/build_tools/scripts/local_web_server.py
index 835a760..a073273 100644
--- a/build_tools/scripts/local_web_server.py
+++ b/build_tools/scripts/local_web_server.py
@@ -20,47 +20,53 @@
class CORSHTTPRequestHandler(server.SimpleHTTPRequestHandler):
+ def __init__(self, *args, **kwargs):
+ # Include MIME types for files we expect to be serving.
+ # https://developer.mozilla.org/en-US/docs/Web/HTTP/Basics_of_HTTP/MIME_types/Common_types
+ self.extensions_map.update(
+ {
+ ".js": "application/javascript",
+ ".wasm": "application/wasm",
+ }
+ )
+ super().__init__(*args, **kwargs)
- def __init__(self, *args, **kwargs):
- # Include MIME types for files we expect to be serving.
- # https://developer.mozilla.org/en-US/docs/Web/HTTP/Basics_of_HTTP/MIME_types/Common_types
- self.extensions_map.update({
- ".js": "application/javascript",
- ".wasm": "application/wasm",
- })
- super().__init__(*args, **kwargs)
+ # Inspiration for this hack: https://stackoverflow.com/a/13354482
+ def end_headers(self):
+ self.send_cors_headers()
- # Inspiration for this hack: https://stackoverflow.com/a/13354482
- def end_headers(self):
- self.send_cors_headers()
+ server.SimpleHTTPRequestHandler.end_headers(self)
- server.SimpleHTTPRequestHandler.end_headers(self)
-
- def send_cors_headers(self):
- # Emscripten uses SharedArrayBuffer for its multithreading, which requires
- # Cross Origin Opener Policy and Cross Origin Embedder Policy headers:
- # * https://emscripten.org/docs/porting/pthreads.html
- # * https://developer.chrome.com/blog/enabling-shared-array-buffer/
- self.send_header("Cross-Origin-Embedder-Policy", "require-corp")
- self.send_header("Cross-Origin-Opener-Policy", "same-origin")
+ def send_cors_headers(self):
+ # Emscripten uses SharedArrayBuffer for its multithreading, which requires
+ # Cross Origin Opener Policy and Cross Origin Embedder Policy headers:
+ # * https://emscripten.org/docs/porting/pthreads.html
+ # * https://developer.chrome.com/blog/enabling-shared-array-buffer/
+ self.send_header("Cross-Origin-Embedder-Policy", "require-corp")
+ self.send_header("Cross-Origin-Opener-Policy", "same-origin")
-if __name__ == '__main__':
- import argparse
- parser = argparse.ArgumentParser()
- parser.add_argument('--directory',
- '-d',
- default=os.getcwd(),
- help='Specify alternative directory '
- '[default:current directory]')
- parser.add_argument('port',
- action='store',
- default=8000,
- type=int,
- nargs='?',
- help='Specify alternate port [default: 8000]')
- args = parser.parse_args()
+if __name__ == "__main__":
+ import argparse
- server.test(HandlerClass=partial(CORSHTTPRequestHandler,
- directory=args.directory),
- port=args.port)
+ parser = argparse.ArgumentParser()
+ parser.add_argument(
+ "--directory",
+ "-d",
+ default=os.getcwd(),
+ help="Specify alternative directory " "[default:current directory]",
+ )
+ parser.add_argument(
+ "port",
+ action="store",
+ default=8000,
+ type=int,
+ nargs="?",
+ help="Specify alternate port [default: 8000]",
+ )
+ args = parser.parse_args()
+
+ server.test(
+ HandlerClass=partial(CORSHTTPRequestHandler, directory=args.directory),
+ port=args.port,
+ )
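The `end_headers()` override above stamps the two cross-origin-isolation headers onto every response before delegating to the base class, which is what lets Emscripten's `SharedArrayBuffer`-based threading work when served locally. A trimmed-down sketch of the same pattern (the port and current-directory default mirror the argparse defaults above):

```python
# Minimal sketch of the CORSHTTPRequestHandler end_headers() override.
from functools import partial
from http import server


class CrossOriginIsolatedHandler(server.SimpleHTTPRequestHandler):
    def end_headers(self):
        # COOP/COEP headers required for SharedArrayBuffer.
        self.send_header("Cross-Origin-Embedder-Policy", "require-corp")
        self.send_header("Cross-Origin-Opener-Policy", "same-origin")
        super().end_headers()


if __name__ == "__main__":
    # server.test is the same helper local_web_server.py itself calls.
    server.test(
        HandlerClass=partial(CrossOriginIsolatedHandler, directory="."),
        port=8000,
    )
```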
diff --git a/build_tools/scripts/update_tflite_models.py b/build_tools/scripts/update_tflite_models.py
index e2ea887..2af25fb 100644
--- a/build_tools/scripts/update_tflite_models.py
+++ b/build_tools/scripts/update_tflite_models.py
@@ -26,45 +26,44 @@
import urllib
FLAGS = flags.FLAGS
-flags.DEFINE_string('file', '', 'file to update')
+flags.DEFINE_string("file", "", "file to update")
-file_dict = dict({
- "mobilenet_v1.tflite":
- "https://tfhub.dev/tensorflow/lite-model/mobilenet_v1_1.0_160/1/default/1?lite-format=tflite",
- "posenet_i8.tflite":
- "https://tfhub.dev/google/lite-model/movenet/singlepose/lightning/tflite/int8/4?lite-format=tflite",
- "posenet_i8_input.jpg":
- "https://github.com/tensorflow/examples/raw/master/lite/examples/pose_estimation/raspberry_pi/test_data/image3.jpeg"
-})
+file_dict = dict(
+ {
+ "mobilenet_v1.tflite": "https://tfhub.dev/tensorflow/lite-model/mobilenet_v1_1.0_160/1/default/1?lite-format=tflite",
+ "posenet_i8.tflite": "https://tfhub.dev/google/lite-model/movenet/singlepose/lightning/tflite/int8/4?lite-format=tflite",
+ "posenet_i8_input.jpg": "https://github.com/tensorflow/examples/raw/master/lite/examples/pose_estimation/raspberry_pi/test_data/image3.jpeg",
+ }
+)
BUCKET_NAME = "iree-model-artifacts"
FOLDER_NAME = "tflite-integration-tests"
def upload_model(source, destination, tmpfile):
- """Uploads a file to the bucket."""
- urllib.request.urlretrieve(source, tmpfile)
+ """Uploads a file to the bucket."""
+ urllib.request.urlretrieve(source, tmpfile)
- storage_client = storage.Client()
- bucket = storage_client.get_bucket(BUCKET_NAME)
- blob = bucket.blob("/".join([FOLDER_NAME, destination]))
- blob.upload_from_filename(tmpfile)
+ storage_client = storage.Client()
+ bucket = storage_client.get_bucket(BUCKET_NAME)
+ blob = bucket.blob("/".join([FOLDER_NAME, destination]))
+ blob.upload_from_filename(tmpfile)
def main(argv):
- tf = tempfile.NamedTemporaryFile()
+ tf = tempfile.NamedTemporaryFile()
- items = file_dict.items()
+ items = file_dict.items()
- if FLAGS.file in file_dict:
- items = [(FLAGS.file, file_dict[FLAGS.file])]
- elif FLAGS.file != "all":
- print('Unknown file to upload: ', "\"" + FLAGS.file + "\"")
- exit()
+ if FLAGS.file in file_dict:
+ items = [(FLAGS.file, file_dict[FLAGS.file])]
+ elif FLAGS.file != "all":
+ print("Unknown file to upload: ", '"' + FLAGS.file + '"')
+ exit()
- for dst, src in items:
- upload_model(src, dst, tf.name)
+ for dst, src in items:
+ upload_model(src, dst, tf.name)
-if __name__ == '__main__':
- app.run(main)
+if __name__ == "__main__":
+ app.run(main)
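`upload_model` above is a fetch-then-upload pipeline: retrieve the source URL into a temp file, then push it to the `gs://iree-model-artifacts/tflite-integration-tests` blob. A sketch of that flow with the GCS client stubbed out so it runs without credentials; the bucket and folder constants come from the file above, and the sample URL is the `posenet_i8_input.jpg` entry from `file_dict`:

```python
# Sketch of update_tflite_models.py's fetch-then-upload flow, with the
# google.cloud.storage call replaced by a stub for credential-free runs.
import tempfile
import urllib.request

BUCKET_NAME = "iree-model-artifacts"
FOLDER_NAME = "tflite-integration-tests"


def upload_model(source, destination, tmpfile, upload=None):
    urllib.request.urlretrieve(source, tmpfile)
    blob_name = "/".join([FOLDER_NAME, destination])
    if upload is None:
        # Stand-in for blob.upload_from_filename(tmpfile) in the real script.
        print(f"would upload {tmpfile} to gs://{BUCKET_NAME}/{blob_name}")
    else:
        upload(blob_name, tmpfile)


if __name__ == "__main__":
    with tempfile.NamedTemporaryFile() as tf:
        upload_model(
            "https://github.com/tensorflow/examples/raw/master/lite/examples/pose_estimation/raspberry_pi/test_data/image3.jpeg",
            "posenet_i8_input.jpg",
            tf.name,
        )
```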
diff --git a/build_tools/scripts/utils.py b/build_tools/scripts/utils.py
index 7713f5d..36b31ae 100644
--- a/build_tools/scripts/utils.py
+++ b/build_tools/scripts/utils.py
@@ -14,38 +14,46 @@
def create_markdown_table(rows: Sequence[Sequence[str]]):
- """Converts a 2D array to a Markdown table."""
- return '\n'.join([' | '.join(row) for row in rows])
+ """Converts a 2D array to a Markdown table."""
+ return "\n".join([" | ".join(row) for row in rows])
-def check_and_get_output_lines(command: Sequence[str],
- dry_run: bool = False,
- log_stderr: bool = True):
- print(f'Running: `{" ".join(command)}`')
- if dry_run:
- return None, None
- return subprocess.run(command, stdout=subprocess.PIPE, text=True,
- check=True).stdout.splitlines()
+def check_and_get_output_lines(
+ command: Sequence[str], dry_run: bool = False, log_stderr: bool = True
+):
+ print(f'Running: `{" ".join(command)}`')
+ if dry_run:
+ return None, None
+ return subprocess.run(
+ command, stdout=subprocess.PIPE, text=True, check=True
+ ).stdout.splitlines()
def get_test_targets(test_suite_path: str):
- """Returns a list of test targets for the given test suite."""
- # Check if the suite exists (which may not be true for failing suites).
- # We use two queries here because the return code for a failed query is
- # unfortunately the same as the return code for a bazel configuration error.
- target_dir = test_suite_path.split(':')[0]
- query = [
- 'bazel', 'query', '--ui_event_filters=-DEBUG',
- '--noshow_loading_progress', '--noshow_progress', f'{target_dir}/...'
- ]
- targets = check_and_get_output_lines(query)
- if test_suite_path not in targets:
- return []
+ """Returns a list of test targets for the given test suite."""
+ # Check if the suite exists (which may not be true for failing suites).
+ # We use two queries here because the return code for a failed query is
+ # unfortunately the same as the return code for a bazel configuration error.
+ target_dir = test_suite_path.split(":")[0]
+ query = [
+ "bazel",
+ "query",
+ "--ui_event_filters=-DEBUG",
+ "--noshow_loading_progress",
+ "--noshow_progress",
+ f"{target_dir}/...",
+ ]
+ targets = check_and_get_output_lines(query)
+ if test_suite_path not in targets:
+ return []
- query = [
- 'bazel', 'query', '--ui_event_filters=-DEBUG',
- '--noshow_loading_progress', '--noshow_progress',
- f'tests({test_suite_path})'
- ]
- tests = check_and_get_output_lines(query)
- return tests
+ query = [
+ "bazel",
+ "query",
+ "--ui_event_filters=-DEBUG",
+ "--noshow_loading_progress",
+ "--noshow_progress",
+ f"tests({test_suite_path})",
+ ]
+ tests = check_and_get_output_lines(query)
+ return tests
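`create_markdown_table` simply joins cells with ` | ` and rows with newlines, so the caller must supply the `---` separator row itself. A quick usage sketch (row contents are made up):

```python
# Usage sketch for create_markdown_table; the separator row is explicit.
def create_markdown_table(rows):
    return "\n".join([" | ".join(row) for row in rows])


print(
    create_markdown_table(
        [
            ["target", "status"],
            ["---", "---"],
            ["//tests:foo", "passed"],
        ]
    )
)
# target | status
# --- | ---
# //tests:foo | passed
```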