# blob: 4f9393729f684ec6a1919faf3a63403898f0a4c1 [file] [log] [blame]
# Lint-as: python3
"""XLA compiler interface."""
# Copyright 2020 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# TODO(#4131) python>=3.7: Use postponed type annotations.
from enum import Enum
import logging
import tempfile
from typing import List, Optional, Sequence, Set, Union
from .tools import find_tool, invoke_immediate, invoke_pipeline
from .core import CompilerOptions, DEFAULT_TESTING_BACKENDS, build_compile_command_line
__all__ = [
"compile_file",
"compile_str",
"is_available",
"DEFAULT_TESTING_BACKENDS",
"ImportOptions",
"ImportFormat",
]
_IMPORT_TOOL = "iree-import-xla"
def is_available():
  """Report whether the XLA import tool can be located.

  Returns:
    True when `find_tool` resolves the import tool. False, after logging a
    warning, when `find_tool` raises ValueError.
  """
  try:
    find_tool(_IMPORT_TOOL)
  except ValueError:
    logging.warning("Unable to find IREE tool %s", _IMPORT_TOOL)
    return False
  else:
    return True
class ImportFormat(Enum):
  """Format of the XLA proto being imported (binary or text)."""
  BINARY_PROTO = "binary_proto"
  TEXT_PROTO = "text_proto"

  @staticmethod
  def parse(spec: Union[str, "ImportFormat"]) -> "ImportFormat":
    """Coerces a format specifier into an ImportFormat member.

    Args:
      spec: Either an ImportFormat (returned unchanged) or the name of a
        member; the name is matched case-insensitively by upper-casing it.

    Returns:
      The matching ImportFormat member.

    Raises:
      ValueError: If the upper-cased name is not a member name.
    """
    if isinstance(spec, ImportFormat):
      return spec
    members = ImportFormat.__members__
    resolved = members.get(spec.upper())
    if resolved is None:
      valid_names = ", ".join(members)
      raise ValueError(
          f"For import_format= argument, expected one of: {valid_names}")
    return resolved
# TODO(#4131) python>=3.7: Consider using a dataclass.
class ImportOptions(CompilerOptions):
  """Backend compiler options extended with XLA import-stage settings."""

  def __init__(self,
               import_only: bool = False,
               import_format: Union[ImportFormat,
                                    str] = ImportFormat.BINARY_PROTO,
               import_extra_args: Sequence[str] = (),
               save_temp_iree_input: Optional[str] = None,
               **kwargs):
    """Builds the options from keyword arguments.

    Args:
      import_only: When true, compile_file/compile_str run just the import
        tool instead of the full import + compile pipeline.
      import_format: Format of the proto (text or binary), given either as an
        ImportFormat or as a name accepted by ImportFormat.parse.
      import_extra_args: Additional arguments appended to the import tool
        command line.
      save_temp_iree_input: Optional value for the import tool's
        --save-temp-iree-input flag, used to save the IR produced by the
        import (ready to be passed to IREE).
      **kwargs: Remaining keywords, forwarded to CompilerOptions.
    """
    super().__init__(**kwargs)
    # Normalize the format up front so the attribute is always an enum member.
    parsed_format = ImportFormat.parse(import_format)
    self.import_only = import_only
    self.import_format = parsed_format
    self.import_extra_args = import_extra_args
    self.save_temp_iree_input = save_temp_iree_input
def build_import_command_line(input_path: str,
                              options: ImportOptions) -> List[str]:
  """Builds a command line for invoking the import stage.

  Args:
    input_path: Path handed to the import tool as its input argument
      (compile_str passes "-").
    options: Import options.

  Returns:
    List of strings of command line: the tool path, the input path, the
    --xla-format flag, then optional output/save-temps flags, then any
    extra import arguments.

  Raises:
    ValueError: Propagated from find_tool when the import tool cannot be
      located (the same error is_available handles).
  """
  cl = [
      find_tool(_IMPORT_TOOL),
      input_path,
      f"--xla-format={options.import_format.value}",
  ]

  # The import tool writes the final output itself only in import-only mode.
  # A single guard suffices; re-checking output_file inside was redundant.
  if options.import_only and options.output_file:
    cl.append(f"-o={options.output_file}")

  # Save temps flags.
  if options.save_temp_iree_input:
    cl.append(f"--save-temp-iree-input={options.save_temp_iree_input}")

  # Extra args always go last.
  cl.extend(options.import_extra_args)
  return cl
def compile_file(xla_file_path: str, **kwargs):
  """Compiles an XLA protocol buffer stored on disk.

  With import_only set, only the import tool runs; otherwise the import tool
  and the compiler are run as a two-stage pipeline.

  Args:
    xla_file_path: Path to the XLA proto file.
    **kwargs: Keyword args corresponding to ImportOptions or CompilerOptions.

  Returns:
    The result of the tool invocation (a bytes-like object with the output),
    or None if output_file= was specified.
  """
  options = ImportOptions(**kwargs)
  import_cl = build_import_command_line(xla_file_path, options)

  if options.import_only:
    # Single-stage run: just the importer.
    output = invoke_immediate(import_cl)
  else:
    # Two-stage run: the compile stage is given "-" as its input path
    # (presumably stdin, fed by the import stage -- confirm in tools).
    compile_cl = build_compile_command_line("-", options)
    output = invoke_pipeline([import_cl, compile_cl])

  # When an output file was requested, nothing is returned in memory.
  return None if options.output_file else output
def compile_str(xla_content: Union[bytes, str], **kwargs):
  """Compiles XLA content held in memory.

  The content is supplied to the import tool via immediate_input, with "-"
  as the tool's input path. With import_only set, only the import tool runs;
  otherwise the import tool and the compiler run as a two-stage pipeline.

  Args:
    xla_content: Either bytes or str content. A str is UTF-8 encoded before
      use and is only accepted when the import format is TEXT_PROTO.
    **kwargs: Keyword args corresponding to ImportOptions or CompilerOptions.

  Returns:
    The result of the tool invocation (a bytes-like object with the output),
    or None if output_file= was specified.

  Raises:
    ValueError: If a str is passed while the import format is not TEXT_PROTO.
  """
  options = ImportOptions(**kwargs)

  payload = xla_content
  if isinstance(payload, str):
    if options.import_format != ImportFormat.TEXT_PROTO:
      raise ValueError("If passing a string, ImportFormat must be TEXT_PROTO")
    payload = payload.encode("utf-8")

  import_cl = build_import_command_line("-", options)
  if options.import_only:
    # Single-stage run: just the importer.
    output = invoke_immediate(import_cl, immediate_input=payload)
  else:
    # Two-stage run: importer output feeds the compile stage.
    compile_cl = build_compile_command_line("-", options)
    output = invoke_pipeline([import_cl, compile_cl], immediate_input=payload)

  # When an output file was requested, nothing is returned in memory.
  return None if options.output_file else output