# Lint-as: python3
"""Utilities for locating and invoking compiler tools."""

# Copyright 2020 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#      https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import importlib
import io
import os
import subprocess
import sys
import textwrap
import threading

from typing import List, Optional

# Public API of this module: the names exported by a star-import. Helpers
# prefixed with an underscore further down are implementation details.
__all__ = [
    "find_tool",
    "invoke_immediate",
    "invoke_pipeline",
    "get_tool_path",
    "CompilerToolError",
]

# In normal distribution circumstances, each named tool is associated with
# a python module that provides a `get_tool` function for getting its
# path. This dictionary maps the tool name to the module. It also defines the
# set of "known" tools: find_tool() rejects any name that is not a key here.
_TOOL_MODULE_MAP = {
    "iree-tf-import": "pyiree.tools.tf",
    "iree-translate": "pyiree.tools.core",
}

# Map of tool module to the name of the package that distributes it. Only
# used to tell the user what to install when the module cannot be imported.
# Every module named in _TOOL_MODULE_MAP needs an entry here, since
# find_tool() indexes this map unconditionally before importing the module.
_TOOL_MODULE_PACKAGES = {
    "pyiree.tools.tf": "pyiree_compiler_tf",
    "pyiree.tools.core": "pyiree_compiler_core",
}

# Environment variable holding directories to be searched for named tools.
# Delimited by os.pathsep. Directories listed here are searched before the
# tool's python module is consulted.
_TOOL_PATH_ENVVAR = "IREE_TOOL_PATH"


class CompilerToolError(Exception):
  """Compiler exception that preserves the command line and error output.

  The exception message contains the basename of the tool that was run, the
  stderr captured from it and the full command line it was invoked with.
  """

  def __init__(self, process: subprocess.CompletedProcess):
    """Builds the error message from a finished tool invocation.

    Args:
      process: The completed invocation. `process.args` must be a non-empty
        sequence whose first element is the executable; `process.stderr` is
        normally the captured bytes but any value is tolerated.
    """
    raw_errs = process.stderr
    if isinstance(raw_errs, (bytes, bytearray)):
      try:
        errs = raw_errs.decode("utf-8")
      except UnicodeDecodeError:
        # Not valid UTF-8: fall back to the bytes repr, best we can do.
        errs = str(raw_errs)
    else:
      # None (nothing captured), already-decoded text, or something exotic.
      errs = str(raw_errs)

    tool_name = os.path.basename(process.args[0])
    # Stringify each argument so path-like argv entries cannot turn error
    # reporting itself into a TypeError.
    command_line = " ".join(str(arg) for arg in process.args)
    super().__init__(f"Error invoking IREE compiler tool {tool_name}\n"
                     f"Diagnostics:\n{errs}\n\n"
                     f"Invoked with:\n  {command_line}")


def get_tool_path() -> List[str]:
  """Returns the directories configured for explicit tool lookup.

  Returns:
    The value of the tool path environment variable split on os.pathsep, or
    an empty list when the variable is unset or empty. Entries are returned
    verbatim, so empty entries (e.g. from a trailing separator) are kept.
  """
  raw_value = os.environ.get(_TOOL_PATH_ENVVAR)
  return raw_value.split(os.pathsep) if raw_value else []


def find_tool(exe_name: str) -> str:
  """Resolves the path of a known compiler tool executable.

  Lookup order:
    1. Every non-empty directory returned by get_tool_path(); the first one
       containing an executable file named `exe_name` wins.
    2. The python module registered for the tool, by calling its `get_tool`
       function.

  Args:
    exe_name: The name of the executable (extension-less).
  Returns:
    The path to the tool. For a tool path hit this is the directory entry
    joined with `exe_name`; otherwise it is whatever the tool module returned.
  Raises:
    ValueError: If the tool is not known, its module is not installed, or
      the module points at a file that is missing or not executable.
  """

  def is_executable_file(path):
    return os.path.isfile(path) and os.access(path, os.X_OK)

  if exe_name not in _TOOL_MODULE_MAP:
    raise ValueError(f"IREE compiler tool '{exe_name}' is not a known tool")

  # Explicitly configured directories take priority over installed modules.
  tool_path = get_tool_path()
  for path_entry in filter(None, tool_path):
    candidate_exe = os.path.join(path_entry, exe_name)
    if is_executable_file(candidate_exe):
      return candidate_exe

  # Fall back to the python module that ships the tool.
  tool_module_name = _TOOL_MODULE_MAP[exe_name]
  tool_module_package = _TOOL_MODULE_PACKAGES[tool_module_name]
  try:
    tool_module = importlib.import_module(tool_module_name)
  except ModuleNotFoundError:
    raise ValueError(
        f"IREE compiler tool '{exe_name}' is not installed (it should have been "
        f"found in the python module '{tool_module_name}', typically installed "
        f"via the package {tool_module_package}).\n\n"
        f"Either install the package or set the {_TOOL_PATH_ENVVAR} environment "
        f"variable to contain the path of the tool executable "
        f"(current {_TOOL_PATH_ENVVAR} = {repr(tool_path)})") from None

  # The module knows where its binary lives; verify before handing it out.
  candidate_exe = tool_module.get_tool(exe_name)
  if candidate_exe and is_executable_file(candidate_exe):
    return candidate_exe
  raise ValueError(
      f"IREE compiler tool '{exe_name}' was located in module "
      f"'{tool_module_name}' but the file was not found or not executable: "
      f"{candidate_exe}")


def invoke_immediate(command_line: List[str],
                     *,
                     input_file: Optional[str] = None,
                     immediate_input=None):
  """Runs a single tool to completion and returns its captured stdout.

  This is separate from invoke_pipeline as it is simpler and supports more
  complex input redirection: it is a thin wrapper around subprocess.run, so
  no threads or manual stream pumping are involved.

  Args:
    command_line: The argv list of the tool to run.
    input_file: Optional path of a file to connect to the tool's stdin. The
      file is opened in binary mode and closed once the tool has finished.
      Takes precedence over `immediate_input` when both are given.
    immediate_input: Optional data fed to the tool's stdin, forwarded as the
      `input` argument of subprocess.run. Ignored when `input_file` is set.
  Returns:
    The bytes the tool wrote to stdout.
  Raises:
    CompilerToolError: If the tool exits with a non-zero return code. In that
      case the captured stderr is carried by the exception and not echoed.
  """
  stderr_handle = sys.stderr
  # Always capture both output streams (equivalent to capture_output=True).
  run_args = {"stdout": subprocess.PIPE, "stderr": subprocess.PIPE}

  if input_file is not None:
    # The context manager guarantees the handle is released even if the
    # tool cannot be launched.
    with open(input_file, "rb") as input_file_handle:
      process = subprocess.run(command_line,
                               stdin=input_file_handle,
                               **run_args)
  else:
    if immediate_input is not None:
      run_args["input"] = immediate_input
    process = subprocess.run(command_line, **run_args)

  if process.returncode != 0:
    raise CompilerToolError(process)
  # On success, still surface anything the tool printed to stderr.
  _write_binary_stderr(stderr_handle, process.stderr)
  return process.stdout


def invoke_pipeline(command_lines: List[List[str]]):
  """Invokes a pipeline of commands.

  The first stage of the pipeline will have its stdin set to DEVNULL and each
  subsequent stdin will derive from the prior stdout. The final stdout will
  be accumulated and returned. All stderr contents are accumulated and printed
  to stderr on completion or the first failing stage of the pipeline will have
  an exception raised with its stderr output.

  Each stage is serviced by its own thread that drains the stage's stderr
  (and stdout for the final stage) while waiting for the process to exit.

  Args:
    command_lines: One argv list per stage, in pipeline order.
  Returns:
    The bytes written to stdout by the final stage.
  Raises:
    CompilerToolError: For the first stage (in pipeline order) that exited
      with a non-zero return code.
  """
  stages = []
  prev_out = subprocess.DEVNULL
  stderr_handle = sys.stderr
  last_index = len(command_lines) - 1

  # Create all stages.
  try:
    for i, command_line in enumerate(command_lines):
      process = subprocess.Popen(command_line,
                                 stdin=prev_out,
                                 stdout=subprocess.PIPE,
                                 stderr=subprocess.PIPE)
      # Only the final stage has its stdout captured by this process.
      stages.append(_PipelineStage(process, i == last_index))
      # The child now owns its end of the previous pipe. Drop our copy so
      # that an upstream stage sees a broken pipe (instead of blocking
      # forever on a full pipe) if this consumer exits early.
      if prev_out is not subprocess.DEVNULL:
        prev_out.close()
      prev_out = process.stdout
  except BaseException:
    # A stage failed to launch: do not leave the earlier ones running.
    for stage in stages:
      stage.process.kill()
      stage.process.wait()
      stage.process.stdout.close()
      stage.process.stderr.close()
    raise

  # Start stages.
  for stage in stages:
    stage.start()

  # Join.
  for stage in stages:
    stage.join()

  # Check for errors.
  for stage in stages:
    assert stage.completed
    if stage.completed.returncode != 0:
      raise CompilerToolError(stage.completed)

  # Print any stderr output.
  for stage in stages:
    _write_binary_stderr(stderr_handle, stage.errs)
  return stages[-1].outs


class _PipelineStage(threading.Thread):
  """Thread that services one pipeline process until it exits.

  While running, the stage drains the process's stderr (and its stdout when
  `capture_output` is set) on helper threads, waits for the process to finish
  and then publishes the outcome in `completed`, `outs` and `errs`.
  """

  def __init__(self, process, capture_output):
    """Wraps an already-launched process.

    Args:
      process: The subprocess.Popen object to service; its stdout and stderr
        must be pipes.
      capture_output: Whether this stage should read the process's stdout.
    """
    super().__init__()
    self.process = process
    self.capture_output = capture_output
    # Populated by run() once the process has exited.
    self.completed: Optional[subprocess.CompletedProcess] = None
    self.outs = None
    self.errs = None

  def pump_stderr(self):
    """Reads the process's stderr to EOF into `errs`."""
    self.errs = self.process.stderr.read()

  def pump_stdout(self):
    """Reads the process's stdout to EOF into `outs`."""
    self.outs = self.process.stdout.read()

  def run(self):
    """Drains the pipes, waits for exit and records the completed process."""
    # stderr is always drained; stdout only when this stage captures it.
    pumps = [threading.Thread(target=self.pump_stderr)]
    if self.capture_output:
      pumps.append(threading.Thread(target=self.pump_stdout))
    for pump in pumps:
      pump.start()

    proc = self.process
    proc.wait()
    for pump in pumps:
      pump.join()

    self.completed = subprocess.CompletedProcess(args=proc.args,
                                                 returncode=proc.returncode,
                                                 stdout=self.outs,
                                                 stderr=self.errs)
    proc.stderr.close()
    proc.stdout.close()


def _write_binary_stderr(out_handle, contents):
  """Writes captured binary tool output to a text or binary stream.

  Fast-paths buffered text-io (which stderr is by default) by writing the raw
  bytes to its underlying binary buffer, while allowing a full decode for
  text streams without a buffer and a plain write for binary streams.

  Args:
    out_handle: The destination stream (typically sys.stderr).
    contents: The bytes to emit. Empty or missing contents are a no-op.
  """
  if not contents:
    return
  if hasattr(out_handle, "buffer"):
    # Push out any text still pending in the text layer first; otherwise it
    # would reach the underlying buffer after the bytes written below.
    out_handle.flush()
    out_handle.buffer.write(contents)
  elif isinstance(out_handle, io.TextIOBase):
    # Tool diagnostics are not guaranteed to be valid UTF-8; substitute
    # rather than fail while merely echoing them.
    out_handle.write(contents.decode("utf-8", errors="replace"))
  else:
    out_handle.write(contents)
