blob: 18260c2da4aa66f03f3ad025416e03b27b5f2505 [file] [log] [blame]
# Lint-as: python3
"""Utilities for locating and invoking compiler tools."""
# Copyright 2020 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import importlib
import io
import os
import subprocess
import sys
import textwrap
import threading
from typing import List, Optional
__all__ = [
"find_tool",
"invoke_immediate",
"invoke_pipeline",
"get_tool_path",
"CompilerToolError",
]
# In normal distribution circumstances, each named tool is associated with
# a python module that provides a `get_tool` function for getting its absolute
# path. This dictionary maps the tool name to the module.
_TOOL_MODULE_MAP = {
"iree-tf-import": "pyiree.tools.tf",
"iree-translate": "pyiree.tools.core",
}
# Map of tool module to package name as distributed to archives (used for
# error messages).
_TOOL_MODULE_PACKAGES = {
"pyiree.tools.core": "google-iree-tools-core",
"pyiree.tools.tf": "google-iree-tools-tf",
}
# Environment variable holding directories to be searched for named tools.
# Delimitted by os.pathsep.
_TOOL_PATH_ENVVAR = "IREE_TOOL_PATH"
class CompilerToolError(Exception):
"""Compiler exception that preserves the command line and error output."""
def __init__(self, process: subprocess.CompletedProcess):
try:
errs = process.stderr.decode("utf-8")
except:
errs = str(process.stderr) # Decode error or other: best we can do.
tool_name = os.path.basename(process.args[0])
super().__init__(f"Error invoking IREE compiler tool {tool_name}\n"
f"Diagnostics:\n{errs}\n\n"
f"Invoked with:\n {' '.join(process.args)}")
def get_tool_path() -> List[str]:
"""Returns list of paths to search for tools."""
list_str = os.environ.get(_TOOL_PATH_ENVVAR)
if not list_str:
return []
return list_str.split(os.pathsep)
def find_tool(exe_name: str) -> str:
"""Finds a tool by its (extension-less) executable name.
Args:
exe_name: The name of the executable (extension-less).
Returns:
An absolute path to the tool.
Raises:
ValueError: If the tool is not known or not found.
"""
if exe_name not in _TOOL_MODULE_MAP:
raise ValueError(f"IREE compiler tool '{exe_name}' is not a known tool")
# First search an explicit tool path.
tool_path = get_tool_path()
for path_entry in tool_path:
if not path_entry:
continue
candidate_exe = os.path.join(path_entry, exe_name)
if os.path.isfile(candidate_exe) and os.access(candidate_exe, os.X_OK):
return candidate_exe
# Attempt to load the tool module.
tool_module_name = _TOOL_MODULE_MAP[exe_name]
tool_module_package = _TOOL_MODULE_PACKAGES[tool_module_name]
try:
tool_module = importlib.import_module(tool_module_name)
except ModuleNotFoundError:
raise ValueError(
f"IREE compiler tool '{exe_name}' is not installed (it should have been "
f"found in the python module '{tool_module_name}', typically installed "
f"via the package {tool_module_package}).\n\n"
f"Either install the package or set the {_TOOL_PATH_ENVVAR} environment "
f"variable to contain the path of the tool executable "
f"(current {_TOOL_PATH_ENVVAR} = {repr(tool_path)})") from None
# Ask the module for its tool.
candidate_exe = tool_module.get_tool(exe_name)
if (not candidate_exe or not os.path.isfile(candidate_exe) or
not os.access(candidate_exe, os.X_OK)):
raise ValueError(
f"IREE compiler tool '{exe_name}' was located in module "
f"'{tool_module_name}' but the file was not found or not executable: "
f"{candidate_exe}")
return candidate_exe
def invoke_immediate(command_line: List[str],
*,
input_file: Optional[str] = None,
immediate_input=None):
"""Invokes an immediate command.
This is separate from invoke_pipeline as it is simpler and supports more
complex input redirection, using recommended facilities for sub-processes
(less magic).
Note that this differs from the usual way of using subprocess.run or
subprocess.Popen().communicate() because we need to pump all of the error
streams individually and only pump pipes not connected to a different stage.
Uses threads to pump everything that is required.
"""
run_args = {}
input_file_handle = None
stderr_handle = sys.stderr
try:
# Redirect input.
if input_file is not None:
input_file_handle = open(input_file, "rb")
run_args["stdin"] = input_file_handle
elif immediate_input is not None:
run_args["input"] = immediate_input
# Capture output.
# TODO(#4131) python>=3.7: Use capture_output=True.
run_args["stdout"] = subprocess.PIPE
run_args["stderr"] = subprocess.PIPE
process = subprocess.run(command_line, **run_args)
if process.returncode != 0:
raise CompilerToolError(process)
# Emit stderr contents.
_write_binary_stderr(stderr_handle, process.stderr)
return process.stdout
finally:
if input_file_handle:
input_file_handle.close()
def invoke_pipeline(command_lines: List[List[str]]):
"""Invoke a pipeline of commands.
The first stage of the pipeline will have its stdin set to DEVNULL and each
subsequent stdin will derive from the prior stdout. The final stdout will
be accumulated and returned. All stderr contents are accumulated and printed
to stderr on completion or the first failing stage of the pipeline will have
an exception raised with its stderr output.
"""
stages = []
prev_out = subprocess.DEVNULL
stderr_handle = sys.stderr
# Create all stages.
for i in range(len(command_lines)):
command_line = command_lines[i]
popen_args = {
"stdin": prev_out,
"stdout": subprocess.PIPE,
"stderr": subprocess.PIPE,
}
process = subprocess.Popen(command_line, **popen_args)
prev_out = process.stdout
capture_output = (i == (len(command_lines) - 1))
stages.append(_PipelineStage(process, capture_output))
# Start stages.
for stage in stages:
stage.start()
# Join.
for stage in stages:
stage.join()
# Check for errors.
for stage in stages:
assert stage.completed
if stage.completed.returncode != 0:
raise CompilerToolError(stage.completed)
# Print any stderr output.
for stage in stages:
_write_binary_stderr(stderr_handle, stage.errs)
return stages[-1].outs
class _PipelineStage(threading.Thread):
"""Wraps a process and pumps its handles, waiting for completion."""
def __init__(self, process, capture_output):
super().__init__()
self.process = process
self.capture_output = capture_output
self.completed: Optional[subprocess.CompletedProcess] = None
self.outs = None
self.errs = None
def pump_stderr(self):
self.errs = self.process.stderr.read()
def pump_stdout(self):
self.outs = self.process.stdout.read()
def run(self):
stderr_thread = threading.Thread(target=self.pump_stderr)
stderr_thread.start()
if self.capture_output:
stdout_thread = threading.Thread(target=self.pump_stdout)
stdout_thread.start()
self.process.wait()
stderr_thread.join()
if self.capture_output:
stdout_thread.join()
self.completed = subprocess.CompletedProcess(self.process.args,
self.process.returncode,
self.outs, self.errs)
self.process.stderr.close()
self.process.stdout.close()
def _write_binary_stderr(out_handle, contents):
# Fast-paths buffered text-io (which stderr is by default) while allowing
# full decode for non buffered and binary io.
if hasattr(out_handle, "buffer"):
out_handle.buffer.write(contents)
elif isinstance(out_handle, io.TextIOBase):
out_handle.write(contents.decode("utf-8"))
else:
out_handle.write(contents)