Merge pull request #4120 from google:main-to-google
PiperOrigin-RevId: 346216570
diff --git a/CMakeLists.txt b/CMakeLists.txt
index b37f0e0..9f08720 100644
--- a/CMakeLists.txt
+++ b/CMakeLists.txt
@@ -46,6 +46,7 @@
option(IREE_BUILD_PYTHON_BINDINGS "Builds the IREE python bindings" OFF)
option(IREE_BUILD_JAVA_BINDINGS "Builds the IREE java bindings." OFF)
option(IREE_BUILD_EXPERIMENTAL "Builds experimental projects." OFF)
+option(IREE_BUILD_TENSORFLOW_COMPILER "Builds TensorFlow compiler." OFF)
set(IREE_HAL_DRIVERS_TO_BUILD "all"
CACHE STRING "Semicolon-separated list of HAL drivers to build, or \"all\".")
@@ -268,6 +269,18 @@
include(iree_setup_toolchain)
#-------------------------------------------------------------------------------
+# Configure python early if there are any features that need it.
+# Note that doing this early ensures that dependencies that make incidental
+# use of Python (such as LLVM) resolve the same version.
+#-------------------------------------------------------------------------------
+
+if(${IREE_BUILD_COMPILER} OR
+ ${IREE_BUILD_PYTHON_BINDINGS} OR
+ ${IREE_BUILD_TENSORFLOW_COMPILER})
+ find_package(Python3 COMPONENTS Interpreter REQUIRED)
+endif()
+
+#-------------------------------------------------------------------------------
# MLIR/LLVM Dependency
# We treat the LLVM dependency specially because we support several different
# ways to use it:
@@ -371,22 +384,28 @@
endif()
#-------------------------------------------------------------------------------
-# Non-LLVM Dependencies
+# Python bindings.
#-------------------------------------------------------------------------------
-# Use the FindPython functions before any of our dependencies do. See
-# https://pybind11.readthedocs.io/en/stable/faq.html#inconsistent-detection-of-python-version-in-cmake-and-pybind11
-# If one dependency finds Python 2 (the default),
-# any others that try to find Python 3 will fail.
-# (Also come on, it's $CURRENT_YEAR - please just use Python 3 already.)
-if(${IREE_BUILD_COMPILER} OR ${IREE_BUILD_PYTHON_BINDINGS})
- find_package(Python3 COMPONENTS Interpreter REQUIRED)
-endif()
if(${IREE_BUILD_PYTHON_BINDINGS})
# Note: Optional because python libs can be manually specified.
find_package(Python3 COMPONENTS Interpreter Development REQUIRED)
endif()
+#-------------------------------------------------------------------------------
+# Bazel setup (conditional on whether features need it)
+# Depends on python configuration.
+#-------------------------------------------------------------------------------
+
+if(${IREE_BUILD_TENSORFLOW_COMPILER})
+ include(configure_bazel)
+ iree_configure_bazel()
+endif()
+
+#-------------------------------------------------------------------------------
+# Other dependencies.
+#-------------------------------------------------------------------------------
+
include(external_cc_library)
include(flatbuffer_c_library)
@@ -516,6 +535,10 @@
add_subdirectory(experimental)
endif()
+if(${IREE_BUILD_TENSORFLOW_COMPILER})
+ add_subdirectory(integrations/tensorflow)
+endif()
+
if(${IREE_BUILD_PYTHON_BINDINGS})
iree_complete_py_extension_link_options()
endif()
diff --git a/bindings/python/CMakeLists.txt b/bindings/python/CMakeLists.txt
index 3febc3b..90a844c 100644
--- a/bindings/python/CMakeLists.txt
+++ b/bindings/python/CMakeLists.txt
@@ -17,4 +17,24 @@
set(PYBIND_COPTS "-fexceptions")
set(PYBIND_EXTENSION_COPTS "-fvisibility=hidden")
-add_subdirectory(pyiree)
+# Generated setup scripts.
+# TODO: Make the version configurable.
+set(IREE_PYTHON_VERSION "0.1a1")
+configure_file(setup.py setup.py COPYONLY)
+configure_file(setup_compiler.py.in setup_compiler.py)
+configure_file(setup_runtime.py.in setup_runtime.py)
+configure_file(setup_tools_core.py.in setup_tools_core.py)
+configure_file(setup_tools_tf.py.in setup_tools_tf.py)
+
+# Namespace packages.
+add_subdirectory(pyiree/common) # Deprecated
+add_subdirectory(pyiree/compiler2)
+add_subdirectory(pyiree/rt)
+
+if(${IREE_BUILD_COMPILER})
+  add_subdirectory(pyiree/compiler)  # Deprecated
+  add_subdirectory(pyiree/tools/core)
+endif()
+
+# Tests.
+add_subdirectory(tests)
diff --git a/bindings/python/README.md b/bindings/python/README.md
index 4307803..18b7b44 100644
--- a/bindings/python/README.md
+++ b/bindings/python/README.md
@@ -1,19 +1,43 @@
-# IREE Python Sandbox
+# IREE Python API
-This directory contains various integration-oriented Python utilities that are
-not intended to be a public API. They are, however, useful for lower level
-compiler interop work. And of course, they are useful since we presently lack a
-real API :)
+Top-level packages:
-We're still untangling build support, jupyter integration, etc for OSS builds.
-Stand by.
+* `pyiree.compiler2` : Main compiler API (soon to be renamed to 'compiler').
+* `pyiree.rt` : Runtime components for executing binaries.
+* `pyiree.tools.core` : Core tools for executing the compiler.
+* `pyiree.tools.tf` : TensorFlow compiler tools (if enabled).
-## Issues:
+Deprecated packages:
-* This is called `pyiree` vs `iree` to avoid pythonpath collisions that tend
- to arise when an iree directory is inside of an iree directory.
-* The above could be solved in the bazel build by making iree/bindings/python
- its own sub-workspace.
-* However, doing so presently breaks both flatbuffer and tablegen generation
- because of fixes needed to those build rules so that they are sub-worksapce
- aware.
+* `pyiree.compiler`
+* `pyiree.common`
+* `pyiree.tf.compiler`
+
+## Installing
+
+First perform a normal CMake build with the following options:
+
+* `-DIREE_BUILD_PYTHON_BINDINGS=ON` : Enables Python Bindings
+* `-DIREE_BUILD_TENSORFLOW_COMPILER=ON` (optional) : Enables building the
+ TensorFlow compilers (note: this requires additional dependencies; see the
+ overall build docs).
+
+Then from the build directory, run:
+
+```shell
+# Install into a local installation or virtualenv.
+python bindings/python/setup.py install
+
+# Build wheels.
+python bindings/python/setup.py bdist_wheel
+```
+
+## Development mode
+
+For development, just set your `PYTHONPATH` environment variable to the
+`bindings/python` directory in your CMake build dir.
+
+## Run tests
+
+Tests under `bindings/python/tests` can be run directly once installed.
+Additional tests under `integrations/tensorflow/e2e` will be runnable soon.
diff --git a/bindings/python/pyiree/compiler2/.skip_bazel_to_cmake b/bindings/python/pyiree/compiler2/.skip_bazel_to_cmake
new file mode 100644
index 0000000..e69de29
--- /dev/null
+++ b/bindings/python/pyiree/compiler2/.skip_bazel_to_cmake
diff --git a/bindings/python/pyiree/CMakeLists.txt b/bindings/python/pyiree/compiler2/CMakeLists.txt
similarity index 76%
copy from bindings/python/pyiree/CMakeLists.txt
copy to bindings/python/pyiree/compiler2/CMakeLists.txt
index 55ac43d..c9a4ef8 100644
--- a/bindings/python/pyiree/CMakeLists.txt
+++ b/bindings/python/pyiree/compiler2/CMakeLists.txt
@@ -12,9 +12,15 @@
# See the License for the specific language governing permissions and
# limitations under the License.
-add_subdirectory(common)
-add_subdirectory(rt)
+# Static and generated files.
+configure_file(README.md README.md COPYONLY)
-if(${IREE_BUILD_COMPILER})
- add_subdirectory(compiler)
-endif()
+iree_py_library(
+ NAME
+ compiler2
+ SRCS
+ "__init__.py"
+ "core.py"
+ "tf.py"
+ "tools.py"
+)
diff --git a/bindings/python/pyiree/compiler2/README.md b/bindings/python/pyiree/compiler2/README.md
new file mode 100644
index 0000000..f9d0e5b
--- /dev/null
+++ b/bindings/python/pyiree/compiler2/README.md
@@ -0,0 +1,46 @@
+# IREE Compiler Python Bindings
+
+Transitional note: These bindings are not complete yet and will ultimately
+replace the `pyiree.compiler` and `pyiree.tf.compiler` packages.
+
+## Core compiler
+
+```py
+from pyiree.compiler2 import *
+
+SIMPLE_MUL_ASM = """
+func @simple_mul(%arg0: tensor<4xf32>, %arg1: tensor<4xf32>) -> tensor<4xf32>
+ attributes { iree.module.export } {
+ %0 = "mhlo.multiply"(%arg0, %arg1) {name = "mul.1"} : (tensor<4xf32>, tensor<4xf32>) -> tensor<4xf32>
+ return %0 : tensor<4xf32>
+}
+"""
+
+# Also see compile_file()
+# There are many keyword options available.
+# See pyiree.compiler2.CompilerOptions
+binary = compile_str(SIMPLE_MUL_ASM, target_backends=["vulkan-spirv"])
+```
+
+
+## TensorFlow compiler
+
+```py
+import tensorflow as tf
+from pyiree.compiler2.tf import *
+
+class SimpleArithmeticModule(tf.Module):
+
+ @tf.function(input_signature=[
+ tf.TensorSpec([4], tf.float32),
+ tf.TensorSpec([4], tf.float32)
+ ])
+ def simple_mul(self, a, b):
+ return a * b
+
+# Also see compile_saved_model to directly compile an on-disk saved model.
+# There are many keyword options available.
+# See: pyiree.compiler2.tf.ImportOptions
+binary = compile_module(
+ SimpleArithmeticModule(), target_backends=["vulkan-spirv"])
+```
diff --git a/packaging/python/dummy_exclude_from_package.py b/bindings/python/pyiree/compiler2/__init__.py
similarity index 87%
rename from packaging/python/dummy_exclude_from_package.py
rename to bindings/python/pyiree/compiler2/__init__.py
index 0ca47f8..7ba6cb7 100644
--- a/packaging/python/dummy_exclude_from_package.py
+++ b/bindings/python/pyiree/compiler2/__init__.py
@@ -1,3 +1,4 @@
+# Lint-as: python3
# Copyright 2020 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
@@ -11,3 +12,6 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
+
+from .core import *
+from .tools import CompilerToolError
diff --git a/bindings/python/pyiree/compiler2/core.py b/bindings/python/pyiree/compiler2/core.py
new file mode 100644
index 0000000..cdc7900
--- /dev/null
+++ b/bindings/python/pyiree/compiler2/core.py
@@ -0,0 +1,191 @@
+# Lint-as: python3
+"""Core compiler interface."""
+
+# Copyright 2020 Google LLC
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# https://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from enum import Enum
+import subprocess
+from typing import Any, Dict, List, Optional, Sequence, Union
+
+from .tools import *
+
+__all__ = [
+ "DEFAULT_TESTING_BACKENDS",
+ "compile_file",
+ "compile_str",
+ "CompilerOptions",
+ "OutputFormat",
+]
+
+# Default testing backend for invoking the compiler.
+DEFAULT_TESTING_BACKENDS = ["vmla"]
+
+
+class OutputFormat(Enum):
+ """The output format of the compiler."""
+ FLATBUFFER_BINARY = "flatbuffer-binary"
+ FLATBUFFER_TEXT = "flatbuffer-text"
+ MLIR_TEXT = "mlir-text"
+
+ @staticmethod
+ def parse(spec: Union[str, "OutputFormat"]) -> "OutputFormat":
+ """Parses or returns an OutputFormat.
+
+ Args:
+ spec: An OutputFormat instance or the case-insensitive name of one of
+ the enum values.
+ Returns:
+ An OutputFormat instance.
+ """
+ if isinstance(spec, OutputFormat):
+ return spec
+ spec = spec.upper()
+ if spec not in OutputFormat.__members__:
+ raise ValueError(f"For output_format= argument, expected one of: "
+ f"{', '.join(OutputFormat.__members__.keys())}")
+ return OutputFormat[spec]
+
+
+class CompilerOptions:
+ """Options to the compiler backend.
+
+ Keyword options:
+ output_file: Optionally save the compiled binary to a file instead of
+ returning it.
+ target_backends: List of str names of target backends to compile into
+ the binary. The resulting binary will run on targets that match one
+ or more of the compiled backends.
+ output_format: Override the output format. See the OutputFormat enum.
+ Values can either be an enum value or a case-insensitive name of
+ the option. Typically used for debugging
+ extra_args: Optional list of additional arguments to pass to the compiler.
+ Example: ["--print-ir-after-all"]
+ optimize: Whether to apply some default high level optimizations (default
+ True).
+ strip_debug_ops: Whether to strip high level operations used to aid
+ debugging.
+ strip_source_map: Whether to strip source map information (used to generate
+ better errors).
+ strip_symbols: Whether to strip extra symbols not needed for execution
+ (but which may aid debugging).
+ crash_reproducer_path: File name to output an MLIR crash dump to if there
+ is a compiler failure.
+ enable_benchmark: Whether to generate instrumented binaries suitable
+ for benchmarking.
+ """
+
+ def __init__(self,
+ *,
+ output_file: Optional[str] = None,
+ target_backends: Sequence[str] = (),
+ output_format: Union[OutputFormat,
+ str] = OutputFormat.FLATBUFFER_BINARY,
+ extra_args: Sequence[str] = (),
+ optimize: bool = True,
+ strip_debug_ops: bool = False,
+ strip_source_map: bool = False,
+ strip_symbols: bool = False,
+ crash_reproducer_path: Optional[str] = None,
+ enable_benchmark: bool = False):
+ self.output_file = output_file
+ self.target_backends = target_backends
+ self.output_format = OutputFormat.parse(output_format)
+ self.extra_args = extra_args
+ self.optimize = optimize
+ self.strip_debug_ops = strip_debug_ops
+ self.strip_source_map = strip_source_map
+ self.strip_symbols = strip_symbols
+ self.crash_reproducer_path = crash_reproducer_path
+ self.enable_benchmark = enable_benchmark
+
+
+def build_compile_command_line(input_file: str,
+ options: CompilerOptions) -> List[str]:
+ """Builds a command line for invoking the compiler.
+
+ Args:
+ input_file: The input file name.
+ options: Compiler options.
+ Returns:
+ List of strings of command line.
+ """
+ iree_translate = find_tool("iree-translate")
+ if not options.target_backends:
+ raise ValueError("Expected a non-empty list for 'target_backends'")
+
+ cl = [
+ iree_translate,
+ input_file,
+ f"--iree-vm-bytecode-module-output-format={options.output_format.value}",
+ f"--iree-hal-target-backends={','.join(options.target_backends)}",
+ ]
+
+ # Output file.
+ if options.output_file:
+ cl.append(f"-o={options.output_file}")
+
+ # Translation to perform.
+ cl.append("--iree-mlir-to-vm-bytecode-module" if not options.enable_benchmark
+ else "--iree-mlir-to-executable-benchmark-vm-module")
+
+ # Other options to set if specified.
+ if options.strip_debug_ops:
+ cl.append("--iree-vm-bytecode-module-strip-debug-ops")
+ if options.strip_source_map:
+ cl.append("--iree-vm-bytecode-module-strip-source-map")
+ if options.strip_symbols:
+ cl.append("--iree-vm-bytecode-module-strip-symbols")
+ if options.crash_reproducer_path:
+ cl.append(
+ f"--pass-pipeline-crash-reproducer={options.crash_reproducer_path}")
+
+ cl.extend(options.extra_args)
+ return cl
+
+
+def compile_file(input_file: str, **kwargs):
+ """Invokes the IREE compiler on an input file.
+
+ Args:
+ input_file: File containing MLIR assembly to compile.
+ **kwargs: Keyword arguments corresponding to CompilerOptions.
+ Returns:
+ Either a byte buffer of the compiled content or None if output_file
+ was specified in the options.
+ """
+ options = CompilerOptions(**kwargs)
+ cl = build_compile_command_line(input_file, options)
+ result = invoke_immediate(cl)
+ if options.output_file:
+ return None
+ return result
+
+
+def compile_str(input_str: str, **kwargs):
+ """Invokes the IREE compiler with an input string.
+
+ Args:
+ input_str: MLIR assembly to parse/compile.
+ **kwargs: Keyword arguments corresponding to CompilerOptions.
+ Returns:
+ Either a byte buffer of the compiled content or None if output_file
+ was specified in the options.
+ """
+ options = CompilerOptions(**kwargs)
+ cl = build_compile_command_line("-", options)
+ result = invoke_immediate(cl, immediate_input=input_str.encode("utf-8"))
+ if options.output_file:
+ return None
+ return result
diff --git a/bindings/python/pyiree/compiler2/setup.py b/bindings/python/pyiree/compiler2/setup.py
new file mode 100644
index 0000000..e69de29
--- /dev/null
+++ b/bindings/python/pyiree/compiler2/setup.py
diff --git a/bindings/python/pyiree/compiler2/tf.py b/bindings/python/pyiree/compiler2/tf.py
new file mode 100644
index 0000000..c9608af
--- /dev/null
+++ b/bindings/python/pyiree/compiler2/tf.py
@@ -0,0 +1,188 @@
+# Lint-as: python3
+"""TensorFlow compiler interface."""
+
+# Copyright 2020 Google LLC
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# https://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from enum import Enum
+import logging
+import tempfile
+from typing import List, Optional, Sequence, Set, Union
+
+from .tools import find_tool, invoke_immediate, invoke_pipeline
+from .core import CompilerOptions, DEFAULT_TESTING_BACKENDS, build_compile_command_line
+
+__all__ = [
+ "compile_saved_model",
+ "compile_module",
+ "is_available",
+ "DEFAULT_TESTING_BACKENDS",
+ "ImportOptions",
+ "ImportType",
+]
+
+_TF_IMPORT_TOOL = "iree-tf-import"
+
+
+def is_available():
+ """Determine if TensorFlow and the compiler are available."""
+ try:
+ import tensorflow as tf
+ except ModuleNotFoundError:
+ logging.warn("Unable to import tensorflow")
+ return False
+ try:
+ find_tool(_TF_IMPORT_TOOL)
+ except ValueError:
+ logging.warning("Unable to find IREE tool %s", _TF_IMPORT_TOOL)
+ return False
+ return True
+
+
+class ImportType(Enum):
+ """Import type of the model."""
+ OBJECT_GRAPH = "savedmodel_v2"
+ V2 = "savedmodel_v2"
+ SIGNATURE_DEF = "savedmodel_v1"
+ V1 = "savedmodel_v1"
+
+ @staticmethod
+ def parse(spec: Union[str, "ImportType"]) -> "ImportType":
+ """Parses or returns an ImportType.
+
+ Args:
+ spec: An ImportType instance or the case-insensitive name of one of
+ the enum values.
+ Returns:
+ An ImportType instance.
+ """
+ if isinstance(spec, ImportType):
+ return spec
+ spec = spec.upper()
+ if spec not in ImportType.__members__:
+ raise ValueError(f"For import_type= argument, expected one of: "
+ f"{', '.join(ImportType.__members__.keys())}")
+ return ImportType[spec]
+
+
+class ImportOptions(CompilerOptions):
+ """Import options layer on top of the backend compiler options."""
+
+ def __init__(self,
+ exported_names: Sequence[str] = (),
+ import_only: bool = False,
+ import_type: Union[ImportType, str] = ImportType.OBJECT_GRAPH,
+ saved_model_tags: Set[str] = set(),
+ import_extra_args: Sequence[str] = (),
+ **kwargs):
+ """Initialize options from keywords.
+
+ Args:
+ exported_names: Optional sequence representing the exported names to
+ keep (object graph/v2 models only).
+ import_only: Only import the module. If True, the result will be textual
+ MLIR that can be further fed to the IREE compiler. If False (default),
+ the result will be the fully compiled IREE binary. In both cases,
+ bytes-like output is returned. Note that if the output_file= is
+ specified and import_only=True, then the MLIR form will be written to
+ the output file.
+ import_type: Type of import to perform. See ImportType enum.
+ saved_model_tags: Set of tags to export (signature def/v1 saved models
+ only).
+ import_extra_args: Extra arguments to pass to the iree-tf-import tool.
+ """
+ super().__init__(**kwargs)
+ self.exported_names = exported_names
+ self.import_only = import_only
+ self.import_type = ImportType.parse(import_type)
+ self.saved_model_tags = saved_model_tags
+ self.import_extra_args = import_extra_args
+
+
+def build_import_command_line(input_path: str,
+ options: ImportOptions) -> List[str]:
+ """Builds a command line for invoking the import stage.
+
+ Args:
+ input_path: The input path.
+ options: Import options.
+ Returns:
+ List of strings of command line.
+ """
+ tf_import = find_tool(_TF_IMPORT_TOOL)
+ cl = [
+ tf_import,
+ input_path,
+ f"--tf-import-type={options.import_type.value}",
+ f"--tf-savedmodel-exported-names={','.join(options.exported_names)}",
+ f"--tf-savedmodel-tags={','.join(options.saved_model_tags)}",
+ ]
+ if options.import_only and options.output_file:
+ # Import stage directly outputs.
+ if options.output_file:
+ cl.append(f"-o={options.output_file}")
+ cl.extend(options.import_extra_args)
+ return cl
+
+
+def compile_saved_model(saved_model_dir: str, **kwargs):
+ """Compiles an on-disk saved model to an IREE binary.
+
+ Args:
+ saved_model_dir: Path to directory where the model was saved.
+ **kwargs: Keyword args corresponding to ImportOptions or CompilerOptions.
+ Returns:
+ A bytes-like object with the compiled output or None if output_file=
+ was specified.
+ """
+ options = ImportOptions(**kwargs)
+ import_cl = build_import_command_line(saved_model_dir, options)
+ if options.import_only:
+ # One stage tool pipeline.
+ result = invoke_immediate(import_cl)
+ if options.output_file:
+ return None
+ return result
+
+ # Full compilation pipeline.
+ compile_cl = build_compile_command_line("-", options)
+ result = invoke_pipeline([import_cl, compile_cl])
+ if options.output_file:
+ return None
+ return result
+
+
+def compile_module(module, saved_model_dir: Optional[str] = None, **kwargs):
+ """Compiles a tf.Module to an IREE binary (by saving to disk).
+
+ Args:
+ module: The tf.Module instance to convert to MLIR
+ saved_model_dir: Optional path to save the tf.Module to. The module will not
+ be persisted on disk outside of this call if this is not provided.
+ **kwargs: Keyword args corresponding to ImportOptions or CompilerOptions.
+ Returns:
+ Same as compile_saved_model().
+ """
+
+ def do_it(saved_model_dir):
+ import tensorflow as tf
+ options = tf.saved_model.SaveOptions(save_debug_info=True)
+ tf.saved_model.save(module, saved_model_dir, options=options)
+ return compile_saved_model(saved_model_dir, **kwargs)
+
+ if saved_model_dir:
+ return do_it(saved_model_dir)
+ else:
+ with tempfile.TemporaryDirectory(suffix=".sm") as td:
+ return do_it(td)
diff --git a/bindings/python/pyiree/compiler2/tools.py b/bindings/python/pyiree/compiler2/tools.py
new file mode 100644
index 0000000..480600b
--- /dev/null
+++ b/bindings/python/pyiree/compiler2/tools.py
@@ -0,0 +1,254 @@
+# Lint-as: python3
+"""Utilities for locating and invoking compiler tools."""
+
+# Copyright 2020 Google LLC
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# https://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import importlib
+import io
+import os
+import subprocess
+import sys
+import textwrap
+import threading
+
+from typing import List, Optional
+
+__all__ = [
+ "find_tool",
+ "invoke_immediate",
+ "invoke_pipeline",
+ "get_tool_path",
+ "CompilerToolError",
+]
+
+# In normal distribution circumstances, each named tool is associated with
+# a python module that provides a `get_tool` function for getting its absolute
+# path. This dictionary maps the tool name to the module.
+_TOOL_MODULE_MAP = {
+ "iree-tf-import": "pyiree.tools.tf",
+ "iree-translate": "pyiree.tools.core",
+}
+
+# Map of tool module to package name as distributed to archives (used for
+# error messages).
+_TOOL_MODULE_PACKAGES = {
+ "pyiree.tools.core": "google-iree-tools-core",
+ "pyiree.tools.tf": "google-iree-tools-tf",
+}
+
+# Environment variable holding directories to be searched for named tools.
+# Delimited by os.pathsep.
+_TOOL_PATH_ENVVAR = "IREE_TOOL_PATH"
+
+
+class CompilerToolError(Exception):
+ """Compiler exception that preserves the command line and error output."""
+
+ def __init__(self, process: subprocess.CompletedProcess):
+ try:
+ errs = process.stderr.decode("utf-8")
+ except:
+ errs = str(process.stderr) # Decode error or other: best we can do.
+
+ tool_name = os.path.basename(process.args[0])
+ super().__init__(f"Error invoking IREE compiler tool {tool_name}\n"
+ f"Diagnostics:\n{errs}\n\n"
+ f"Invoked with:\n {' '.join(process.args)}")
+
+
+def get_tool_path() -> List[str]:
+ """Returns list of paths to search for tools."""
+ list_str = os.environ.get(_TOOL_PATH_ENVVAR)
+ if not list_str:
+ return []
+ return list_str.split(os.pathsep)
+
+
+def find_tool(exe_name: str) -> str:
+ """Finds a tool by its (extension-less) executable name.
+
+ Args:
+ exe_name: The name of the executable (extension-less).
+ Returns:
+ An absolute path to the tool.
+ Raises:
+ ValueError: If the tool is not known or not found.
+ """
+ if exe_name not in _TOOL_MODULE_MAP:
+ raise ValueError(f"IREE compiler tool '{exe_name}' is not a known tool")
+ # First search an explicit tool path.
+ tool_path = get_tool_path()
+ for path_entry in tool_path:
+ if not path_entry:
+ continue
+ candidate_exe = os.path.join(path_entry, exe_name)
+ if os.path.isfile(candidate_exe) and os.access(candidate_exe, os.X_OK):
+ return candidate_exe
+
+ # Attempt to load the tool module.
+ tool_module_name = _TOOL_MODULE_MAP[exe_name]
+ tool_module_package = _TOOL_MODULE_PACKAGES[tool_module_name]
+ try:
+ tool_module = importlib.import_module(tool_module_name)
+ except ModuleNotFoundError:
+ raise ValueError(
+ f"IREE compiler tool '{exe_name}' is not installed (it should have been "
+ f"found in the python module '{tool_module_name}', typically installed "
+ f"via the package {tool_module_package}).\n\n"
+ f"Either install the package or set the {_TOOL_PATH_ENVVAR} environment "
+ f"variable to contain the path of the tool executable "
+ f"(current {_TOOL_PATH_ENVVAR} = {repr(tool_path)})") from None
+
+ # Ask the module for its tool.
+ candidate_exe = tool_module.get_tool(exe_name)
+ if (not candidate_exe or not os.path.isfile(candidate_exe) or
+ not os.access(candidate_exe, os.X_OK)):
+ raise ValueError(
+ f"IREE compiler tool '{exe_name}' was located in module "
+ f"'{tool_module_name}' but the file was not found or not executable: "
+ f"{candidate_exe}")
+ return candidate_exe
+
+
+def invoke_immediate(command_line: List[str],
+ *,
+ input_file: Optional[str] = None,
+ immediate_input=None):
+ """Invokes an immediate command.
+
+ This is separate from invoke_pipeline as it is simpler and supports more
+ complex input redirection, using recommended facilities for sub-processes
+ (less magic).
+
+ Note that this differs from the usual way of using subprocess.run or
+ subprocess.Popen().communicate() because we need to pump all of the error
+ streams individually and only pump pipes not connected to a different stage.
+ Uses threads to pump everything that is required.
+ """
+ run_args = {}
+ input_file_handle = None
+ stderr_handle = sys.stderr
+ try:
+ # Redirect input.
+ if input_file is not None:
+ input_file_handle = open(input_file, "rb")
+ run_args["stdin"] = input_file_handle
+ elif immediate_input is not None:
+ run_args["input"] = immediate_input
+
+ # Capture output.
+ # Upgrade note: Python >= 3.7 can just use capture_output=True
+ run_args["stdout"] = subprocess.PIPE
+ run_args["stderr"] = subprocess.PIPE
+ process = subprocess.run(command_line, **run_args)
+ if process.returncode != 0:
+ raise CompilerToolError(process)
+ # Emit stderr contents.
+ _write_binary_stderr(stderr_handle, process.stderr)
+ return process.stdout
+ finally:
+ if input_file_handle:
+ input_file_handle.close()
+
+
+def invoke_pipeline(command_lines: List[List[str]]):
+ """Invoke a pipeline of commands.
+
+ The first stage of the pipeline will have its stdin set to DEVNULL and each
+ subsequent stdin will derive from the prior stdout. The final stdout will
+ be accumulated and returned. All stderr contents are accumulated and printed
+ to stderr on completion or the first failing stage of the pipeline will have
+ an exception raised with its stderr output.
+ """
+ stages = []
+ prev_out = subprocess.DEVNULL
+ stderr_handle = sys.stderr
+
+ # Create all stages.
+ for i in range(len(command_lines)):
+ command_line = command_lines[i]
+ popen_args = {
+ "stdin": prev_out,
+ "stdout": subprocess.PIPE,
+ "stderr": subprocess.PIPE,
+ }
+ process = subprocess.Popen(command_line, **popen_args)
+ prev_out = process.stdout
+ capture_output = (i == (len(command_lines) - 1))
+ stages.append(_PipelineStage(process, capture_output))
+
+ # Start stages.
+ for stage in stages:
+ stage.start()
+
+ # Join.
+ for stage in stages:
+ stage.join()
+
+ # Check for errors.
+ for stage in stages:
+ assert stage.completed
+ if stage.completed.returncode != 0:
+ raise CompilerToolError(stage.completed)
+
+ # Print any stderr output.
+ for stage in stages:
+ _write_binary_stderr(stderr_handle, stage.errs)
+ return stages[-1].outs
+
+
+class _PipelineStage(threading.Thread):
+ """Wraps a process and pumps its handles, waiting for completion."""
+
+ def __init__(self, process, capture_output):
+ super().__init__()
+ self.process = process
+ self.capture_output = capture_output
+ self.completed: Optional[subprocess.CompletedProcess] = None
+ self.outs = None
+ self.errs = None
+
+ def pump_stderr(self):
+ self.errs = self.process.stderr.read()
+
+ def pump_stdout(self):
+ self.outs = self.process.stdout.read()
+
+ def run(self):
+ stderr_thread = threading.Thread(target=self.pump_stderr)
+ stderr_thread.start()
+ if self.capture_output:
+ stdout_thread = threading.Thread(target=self.pump_stdout)
+ stdout_thread.start()
+ self.process.wait()
+ stderr_thread.join()
+ if self.capture_output:
+ stdout_thread.join()
+ self.completed = subprocess.CompletedProcess(self.process.args,
+ self.process.returncode,
+ self.outs, self.errs)
+ self.process.stderr.close()
+ self.process.stdout.close()
+
+
+def _write_binary_stderr(out_handle, contents):
+ # Fast-paths buffered text-io (which stderr is by default) while allowing
+ # full decode for non buffered and binary io.
+ if hasattr(out_handle, "buffer"):
+ out_handle.buffer.write(contents)
+ elif isinstance(out_handle, io.TextIOBase):
+ out_handle.write(contents.decode("utf-8"))
+ else:
+ out_handle.write(contents)
diff --git a/bindings/python/pyiree/rt/CMakeLists.txt b/bindings/python/pyiree/rt/CMakeLists.txt
index 7828073..ef0f5bb 100644
--- a/bindings/python/pyiree/rt/CMakeLists.txt
+++ b/bindings/python/pyiree/rt/CMakeLists.txt
@@ -12,6 +12,9 @@
# See the License for the specific language governing permissions and
# limitations under the License.
+# Static and generated files.
+configure_file(README.md README.md COPYONLY)
+
iree_pyext_library(
NAME
PyExtRtLib
diff --git a/bindings/python/pyiree/rt/README.md b/bindings/python/pyiree/rt/README.md
new file mode 100644
index 0000000..44b94e1
--- /dev/null
+++ b/bindings/python/pyiree/rt/README.md
@@ -0,0 +1,4 @@
+# IREE Python Runtime Components
+
+This package provides an API for running compiled IREE binaries and interfacing
+with the hardware-abstraction-layer.
diff --git a/bindings/python/pyiree/tools/core/.skip_bazel_to_cmake b/bindings/python/pyiree/tools/core/.skip_bazel_to_cmake
new file mode 100644
index 0000000..e69de29
--- /dev/null
+++ b/bindings/python/pyiree/tools/core/.skip_bazel_to_cmake
diff --git a/bindings/python/pyiree/CMakeLists.txt b/bindings/python/pyiree/tools/core/CMakeLists.txt
similarity index 77%
copy from bindings/python/pyiree/CMakeLists.txt
copy to bindings/python/pyiree/tools/core/CMakeLists.txt
index 55ac43d..9e4364c 100644
--- a/bindings/python/pyiree/CMakeLists.txt
+++ b/bindings/python/pyiree/tools/core/CMakeLists.txt
@@ -12,9 +12,17 @@
# See the License for the specific language governing permissions and
# limitations under the License.
-add_subdirectory(common)
-add_subdirectory(rt)
+iree_py_library(
+ NAME
+ core
+ SRCS
+ "__init__.py"
+)
if(${IREE_BUILD_COMPILER})
- add_subdirectory(compiler)
+ iree_symlink_tool(
+ TARGET core
+ FROM_TOOL_TARGET iree_tools_iree-translate
+ TO_EXE_NAME iree-translate
+ )
endif()
diff --git a/bindings/python/pyiree/CMakeLists.txt b/bindings/python/pyiree/tools/core/__init__.py
similarity index 64%
copy from bindings/python/pyiree/CMakeLists.txt
copy to bindings/python/pyiree/tools/core/__init__.py
index 55ac43d..4ce9c14 100644
--- a/bindings/python/pyiree/CMakeLists.txt
+++ b/bindings/python/pyiree/tools/core/__init__.py
@@ -1,3 +1,6 @@
+# Lint-as: python3
+"""Core tools."""
+
# Copyright 2020 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
@@ -12,9 +15,15 @@
# See the License for the specific language governing permissions and
# limitations under the License.
-add_subdirectory(common)
-add_subdirectory(rt)
+from typing import Optional
-if(${IREE_BUILD_COMPILER})
- add_subdirectory(compiler)
-endif()
+import os
+import platform
+
+
+def get_tool(exe_name: str) -> Optional[str]:
+ if platform.system() == "Windows":
+ exe_name = exe_name + ".exe"
+ this_path = os.path.dirname(__file__)
+ tool_path = os.path.join(this_path, exe_name)
+ return tool_path
diff --git a/bindings/python/setup.py b/bindings/python/setup.py
new file mode 100644
index 0000000..3a669cc
--- /dev/null
+++ b/bindings/python/setup.py
@@ -0,0 +1,41 @@
+#!/usr/bin/python3
+
+# Copyright 2020 Google LLC
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# https://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+# Build wheel files for all of the pyiree packages by running each sub setup.
+# Built artifacts are per-platform and build out of the build tree.
+
+import os
+import subprocess
+import sys
+
+# Make this setup position independent and make it not conflict with
+# parallel scripts.
+this_dir = os.path.abspath(os.path.dirname(__file__))
+
+
+def run_sub_setup(name):
+ sub_path = os.path.join(this_dir, f"{name}.py")
+ args = [sys.executable, sub_path] + sys.argv[1:]
+ print(f"##### Running sub setup: {' '.join(args)}")
+ subprocess.check_call(args)
+ print("")
+
+
+run_sub_setup("setup_compiler")
+run_sub_setup("setup_runtime")
+run_sub_setup("setup_tools_core")
+if os.path.exists(os.path.join(this_dir, "pyiree/tools/tf")):
+ run_sub_setup("setup_tools_tf")
diff --git a/bindings/python/setup_compiler.py.in b/bindings/python/setup_compiler.py.in
new file mode 100644
index 0000000..f8d637d
--- /dev/null
+++ b/bindings/python/setup_compiler.py.in
@@ -0,0 +1,57 @@
+#!/usr/bin/python3
+
+# Copyright 2020 Google LLC
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# https://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+# Build platform specific wheel files for the pyiree.compiler2 package.
+# Built artifacts are per-platform and build out of the build tree.
+
+import os
+from setuptools import setup, find_namespace_packages
+
+# Make this setup position independent and make it not conflict with
+# parallel scripts.
+this_dir = os.path.abspath(os.path.dirname(__file__))
+setup_dir = os.path.join(this_dir, "setupbuild", "compiler")
+os.makedirs(setup_dir, exist_ok=True)
+os.chdir(setup_dir)
+
+def read(fname):
+ return open(os.path.join(this_dir, fname), "rt").read()
+
+
+setup(
+ name="google-iree-compiler",
+ version="@IREE_PYTHON_VERSION@",
+ author="The IREE Team",
+ author_email="iree-discuss@googlegroups.com",
+ license="Apache",
+ description="IREE Python Compiler API",
+ long_description=read("pyiree/compiler2/README.md"),
+ long_description_content_type="text/markdown",
+ url="https://github.com/google/iree",
+ classifiers=[
+ "Programming Language :: Python :: 3",
+ "License :: OSI Approved :: Apache License",
+ "Operating System :: OS Independent",
+ "Development Status :: 3 - Alpha",
+ ],
+ python_requires=">=3.6",
+ package_dir={"": this_dir},
+ packages=find_namespace_packages(
+ where=this_dir,
+ include=["pyiree.compiler2", "pyiree.compiler2.*"],
+ exclude=["*.CMakeFiles"]),
+ zip_safe=False, # This package is fine but not zipping is more versatile.
+)
diff --git a/bindings/python/setup_runtime.py.in b/bindings/python/setup_runtime.py.in
new file mode 100644
index 0000000..5893017
--- /dev/null
+++ b/bindings/python/setup_runtime.py.in
@@ -0,0 +1,66 @@
+#!/usr/bin/python3
+
+# Copyright 2020 Google LLC
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# https://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+# Build platform specific wheel files for the pyiree.rt package.
+# Built artifacts are per-platform and build out of the build tree.
+
+import os
+from setuptools import setup, find_namespace_packages, Extension
+import sysconfig
+
+# Make this setup position independent and make it not conflict with
+# parallel scripts.
+this_dir = os.path.abspath(os.path.dirname(__file__))
+setup_dir = os.path.join(this_dir, "setupbuild", "runtime")
+os.makedirs(setup_dir, exist_ok=True)
+os.chdir(setup_dir)
+
+
+def read(fname):
+ return open(os.path.join(this_dir, fname), "rt").read()
+
+
+setup(
+ name="google-iree-runtime",
+ version="@IREE_PYTHON_VERSION@",
+ author="The IREE Team",
+ author_email="iree-discuss@googlegroups.com",
+ license="Apache",
+ description="IREE Python Runtime Components",
+ long_description=read("pyiree/rt/README.md"),
+ long_description_content_type="text/markdown",
+ url="https://github.com/google/iree",
+ classifiers=[
+ "Programming Language :: Python :: 3",
+ "License :: OSI Approved :: Apache License",
+ "Operating System :: OS Independent",
+ "Development Status :: 3 - Alpha",
+ ],
+ python_requires=">=3.6",
+ package_dir={"": this_dir},
+ packages=find_namespace_packages(where=this_dir,
+ include=["pyiree.rt", "pyiree.rt.*"],
+ exclude=["*.CMakeFiles"]),
+ ext_modules=[
+ Extension(name="pyiree.rt.binding", sources=[]),
+ ],
+ # Matching the native extension as a data file keeps setuptools from
+ # "building" it (i.e. turning it into a static binary).
+ package_data={
+ "": [f"*{sysconfig.get_config_var('EXT_SUFFIX')}"],
+ },
+ zip_safe=False,
+)
diff --git a/bindings/python/setup_tools_core.py.in b/bindings/python/setup_tools_core.py.in
new file mode 100644
index 0000000..55104f5
--- /dev/null
+++ b/bindings/python/setup_tools_core.py.in
@@ -0,0 +1,80 @@
+#!/usr/bin/python3
+
+# Copyright 2020 Google LLC
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# https://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+# Build platform specific wheel files for the pyiree.tools.core package.
+# Built artifacts are per-platform and build out of the build tree.
+
+import os
+import platform
+from setuptools import setup, find_namespace_packages
+
+# Make this setup position independent and make it not conflict with
+# parallel scripts.
+this_dir = os.path.abspath(os.path.dirname(__file__))
+setup_dir = os.path.join(this_dir, "setupbuild", "tools_core")
+os.makedirs(setup_dir, exist_ok=True)
+os.chdir(setup_dir)
+
+exe_suffix = ".exe" if platform.system() == "Windows" else ""
+
+
+def read(fname):
+ return open(os.path.join(this_dir, fname), "rt").read()
+
+
+# Force platform specific wheel.
+# https://stackoverflow.com/questions/45150304
+try:
+ from wheel.bdist_wheel import bdist_wheel as _bdist_wheel
+
+ class bdist_wheel(_bdist_wheel):
+
+ def finalize_options(self):
+ _bdist_wheel.finalize_options(self)
+ self.root_is_pure = False
+except ImportError:
+ bdist_wheel = None
+
+setup(
+ name="google-iree-tools-core",
+ version="@IREE_PYTHON_VERSION@",
+ author="The IREE Team",
+ author_email="iree-discuss@googlegroups.com",
+ license="Apache",
+ description="IREE Python Core Tools Binaries",
+ long_description=
+ "Package containing platform-specific binaries for core compiler tools",
+ long_description_content_type="text/plain",
+ url="https://github.com/google/iree",
+ classifiers=[
+ "Programming Language :: Python :: 3",
+ "License :: OSI Approved :: Apache License",
+ "Operating System :: OS Independent",
+ "Development Status :: 3 - Alpha",
+ ],
+ python_requires=">=3.6",
+ package_dir={"": this_dir},
+ packages=find_namespace_packages(where=this_dir,
+ include=["pyiree.tools.core"],
+ exclude=["*.CMakeFiles"]),
+ # Matching the native extension as a data file keeps setuptools from
+ # "building" it (i.e. turning it into a static binary).
+ package_data={
+ "pyiree.tools.core": [f"iree-translate{exe_suffix}",],
+ },
+ cmdclass={'bdist_wheel': bdist_wheel},
+ zip_safe=False,
+)
diff --git a/bindings/python/setup_tools_tf.py.in b/bindings/python/setup_tools_tf.py.in
new file mode 100644
index 0000000..a1b292e
--- /dev/null
+++ b/bindings/python/setup_tools_tf.py.in
@@ -0,0 +1,81 @@
+#!/usr/bin/python3
+
+# Copyright 2020 Google LLC
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# https://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+# Build platform specific wheel files for the pyiree.tools.tf package.
+# Built artifacts are per-platform and build out of the build tree.
+
+import os
+import platform
+from setuptools import setup, find_namespace_packages
+
+# Make this setup position independent and make it not conflict with
+# parallel scripts.
+this_dir = os.path.abspath(os.path.dirname(__file__))
+setup_dir = os.path.join(this_dir, "setupbuild", "tools_tf")
+os.makedirs(setup_dir, exist_ok=True)
+os.chdir(setup_dir)
+
+exe_suffix = ".exe" if platform.system() == "Windows" else ""
+
+
+def read(fname):
+ return open(os.path.join(this_dir, fname), "rt").read()
+
+
+# Force platform specific wheel.
+# https://stackoverflow.com/questions/45150304
+try:
+ from wheel.bdist_wheel import bdist_wheel as _bdist_wheel
+
+ class bdist_wheel(_bdist_wheel):
+
+ def finalize_options(self):
+ _bdist_wheel.finalize_options(self)
+ self.root_is_pure = False
+except ImportError:
+ bdist_wheel = None
+
+setup(
+ name="google-iree-tools-tf",
+ version="@IREE_PYTHON_VERSION@",
+ author="The IREE Team",
+ author_email="iree-discuss@googlegroups.com",
+ license="Apache",
+ description="IREE Python TensorFlow Tools Binaries",
+ long_description=
+ "Package containing platform-specific binaries for TensorFlow "
+ "compiler tools",
+ long_description_content_type="text/plain",
+ url="https://github.com/google/iree",
+ classifiers=[
+ "Programming Language :: Python :: 3",
+ "License :: OSI Approved :: Apache License",
+ "Operating System :: OS Independent",
+ "Development Status :: 3 - Alpha",
+ ],
+ python_requires=">=3.6",
+ package_dir={"": this_dir},
+ packages=find_namespace_packages(where=this_dir,
+ include=["pyiree.tools.tf"],
+ exclude=["*.CMakeFiles"]),
+ # Matching the native extension as a data file keeps setuptools from
+ # "building" it (i.e. turning it into a static binary).
+ package_data={
+ "pyiree.tools.tf": [f"iree-tf-import{exe_suffix}",],
+ },
+ cmdclass={'bdist_wheel': bdist_wheel},
+ zip_safe=False,
+)
diff --git a/bindings/python/tests/.skip_bazel_to_cmake b/bindings/python/tests/.skip_bazel_to_cmake
new file mode 100644
index 0000000..e69de29
--- /dev/null
+++ b/bindings/python/tests/.skip_bazel_to_cmake
diff --git a/bindings/python/pyiree/CMakeLists.txt b/bindings/python/tests/CMakeLists.txt
similarity index 78%
rename from bindings/python/pyiree/CMakeLists.txt
rename to bindings/python/tests/CMakeLists.txt
index 55ac43d..701bb72 100644
--- a/bindings/python/pyiree/CMakeLists.txt
+++ b/bindings/python/tests/CMakeLists.txt
@@ -12,9 +12,16 @@
# See the License for the specific language governing permissions and
# limitations under the License.
-add_subdirectory(common)
-add_subdirectory(rt)
+iree_py_test(
+ NAME
+ compiler_core_test
+ SRCS
+ "compiler_core_test.py"
+)
-if(${IREE_BUILD_COMPILER})
- add_subdirectory(compiler)
-endif()
+iree_py_test(
+ NAME
+ compiler_tf_test
+ SRCS
+ "compiler_tf_test.py"
+)
diff --git a/bindings/python/tests/README.md b/bindings/python/tests/README.md
new file mode 100644
index 0000000..a0a288e
--- /dev/null
+++ b/bindings/python/tests/README.md
@@ -0,0 +1,8 @@
+# Python API Tests
+
+These tests are run in an environment where all available Python bindings
+are set up on the `PYTHONPATH`. Each will internally skip itself if optional
+components are not available.
+
+Note that IREE compiler tool locations can be overridden by specifying the
+`IREE_TOOL_PATH` environment variable.
diff --git a/bindings/python/tests/compiler_core_test.py b/bindings/python/tests/compiler_core_test.py
new file mode 100644
index 0000000..54c15bd
--- /dev/null
+++ b/bindings/python/tests/compiler_core_test.py
@@ -0,0 +1,138 @@
+# Lint as: python3
+# Copyright 2020 Google LLC
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# https://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import contextlib
+import logging
+import os
+import io
+import tempfile
+import unittest
+
+from pyiree import compiler2 as compiler
+
+SIMPLE_MUL_ASM = """
+func @simple_mul(%arg0: tensor<4xf32>, %arg1: tensor<4xf32>) -> tensor<4xf32>
+ attributes { iree.module.export } {
+ %0 = "mhlo.multiply"(%arg0, %arg1) {name = "mul.1"} : (tensor<4xf32>, tensor<4xf32>) -> tensor<4xf32>
+ return %0 : tensor<4xf32>
+}
+"""
+
+
+class CompilerTest(unittest.TestCase):
+
+ def testNoTargetBackends(self):
+ with self.assertRaisesRegex(
+ ValueError, "Expected a non-empty list for 'target_backends'"):
+ binary = compiler.compile_str(SIMPLE_MUL_ASM)
+
+ def testCompileStr(self):
+ binary = compiler.compile_str(
+ SIMPLE_MUL_ASM, target_backends=compiler.DEFAULT_TESTING_BACKENDS)
+ logging.info("Flatbuffer size = %d", len(binary))
+ self.assertTrue(binary)
+
+ def testCompileInputFile(self):
+ with tempfile.NamedTemporaryFile("wt", delete=False) as f:
+ try:
+ f.write(SIMPLE_MUL_ASM)
+ f.close()
+ binary = compiler.compile_file(
+ f.name, target_backends=compiler.DEFAULT_TESTING_BACKENDS)
+ finally:
+ os.remove(f.name)
+ logging.info("Flatbuffer size = %d", len(binary))
+ self.assertIn(b"simple_mul", binary)
+
+ def testCompileOutputFile(self):
+ with tempfile.NamedTemporaryFile("wt", delete=False) as f:
+ try:
+ f.close()
+ output = compiler.compile_str(
+ SIMPLE_MUL_ASM,
+ output_file=f.name,
+ target_backends=compiler.DEFAULT_TESTING_BACKENDS)
+ self.assertIsNone(output)
+
+ with open(f.name, "rb") as f_read:
+ binary = f_read.read()
+ finally:
+ os.remove(f.name)
+ logging.info("Flatbuffer size = %d", len(binary))
+ self.assertIn(b"simple_mul", binary)
+
+ def testOutputFbText(self):
+ text = compiler.compile_str(
+ SIMPLE_MUL_ASM,
+ output_format=compiler.OutputFormat.FLATBUFFER_TEXT,
+ target_backends=compiler.DEFAULT_TESTING_BACKENDS).decode("utf-8")
+ # Just check for an arbitrary JSON-tag.
+ self.assertIn('"exported_functions"', text)
+
+ def testBadOutputFormat(self):
+ with self.assertRaisesRegex(
+ ValueError, "For output_format= argument, expected one of: "
+ "FLATBUFFER_BINARY, FLATBUFFER_TEXT, MLIR_TEXT"):
+ _ = compiler.compile_str(
+ SIMPLE_MUL_ASM,
+ output_format="foobar",
+ target_backends=compiler.DEFAULT_TESTING_BACKENDS)
+
+ def testOutputFbTextParsed(self):
+ text = compiler.compile_str(
+ SIMPLE_MUL_ASM,
+ output_format='flatbuffer_text',
+ target_backends=compiler.DEFAULT_TESTING_BACKENDS).decode("utf-8")
+ # Just check for an arbitrary JSON-tag.
+ self.assertIn('"exported_functions"', text)
+
+ def testOutputMlirText(self):
+ text = compiler.compile_str(
+ SIMPLE_MUL_ASM,
+ output_format=compiler.OutputFormat.MLIR_TEXT,
+ target_backends=compiler.DEFAULT_TESTING_BACKENDS).decode("utf-8")
+ # Just check for a textual op name.
+ self.assertIn("vm.module", text)
+
+ def testExtraArgsStderr(self):
+ # pass-timing is not special: it just does something and emits to stderr.
+ with io.StringIO() as buf, contextlib.redirect_stderr(buf):
+ compiler.compile_str(SIMPLE_MUL_ASM,
+ extra_args=["--pass-timing"],
+ target_backends=compiler.DEFAULT_TESTING_BACKENDS)
+ stderr = buf.getvalue()
+ self.assertIn("Pass execution timing report", stderr)
+
+ def testAllOptions(self):
+ binary = compiler.compile_str(
+ SIMPLE_MUL_ASM,
+ optimize=False,
+ strip_debug_ops=True,
+ strip_source_map=True,
+ strip_symbols=True,
+ crash_reproducer_path="foobar.txt",
+ enable_benchmark=True,
+ target_backends=compiler.DEFAULT_TESTING_BACKENDS)
+
+ def testException(self):
+ with self.assertRaisesRegex(compiler.CompilerToolError, "Invoked with"):
+ _ = compiler.compile_str(
+ "I'm a little teapot but not a valid program",
+ target_backends=compiler.DEFAULT_TESTING_BACKENDS)
+
+
+if __name__ == "__main__":
+ logging.basicConfig(level=logging.DEBUG)
+ unittest.main()
diff --git a/bindings/python/tests/compiler_tf_test.py b/bindings/python/tests/compiler_tf_test.py
new file mode 100644
index 0000000..41fc94e
--- /dev/null
+++ b/bindings/python/tests/compiler_tf_test.py
@@ -0,0 +1,89 @@
+# Copyright 2020 Google LLC
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# https://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import logging
+import os
+import sys
+import tempfile
+import unittest
+
+# TODO: No idea why pytype cannot find names from this module.
+# pytype: disable=name-error
+from pyiree.compiler2.tf import *
+
+if not is_available():
+ print(f"Skipping test {__file__} because the IREE TensorFlow compiler "
+ f"is not installed")
+ sys.exit(0)
+
+import tensorflow as tf
+
+
+class SimpleArithmeticModule(tf.Module):
+
+ @tf.function(input_signature=[
+ tf.TensorSpec([4], tf.float32),
+ tf.TensorSpec([4], tf.float32)
+ ])
+ def simple_mul(self, a, b):
+ return a * b
+
+ @tf.function(input_signature=[
+ tf.TensorSpec([128, 3072], tf.float32),
+ tf.TensorSpec([3072, 256], tf.float32),
+ ])
+ def simple_matmul(self, a, b):
+ return tf.matmul(a, b)
+
+
+# TODO(laurenzo): More test cases needed (may need additional files).
+# Specifically, figure out how to test v1 models.
+class TfCompilerTest(unittest.TestCase):
+
+ def testImportSavedModel(self):
+ import_mlir = compile_saved_model(self.smdir,
+ import_only=True).decode("utf-8")
+ self.assertIn("func @simple_matmul", import_mlir)
+
+ def testCompileSavedModel(self):
+ binary = compile_saved_model(self.smdir,
+ target_backends=DEFAULT_TESTING_BACKENDS)
+ logging.info("Compiled len: %d", len(binary))
+ self.assertIn(b"simple_matmul", binary)
+ self.assertIn(b"simple_mul", binary)
+
+ def testCompileModule(self):
+ binary = compile_module(self.m, target_backends=DEFAULT_TESTING_BACKENDS)
+ logging.info("Compiled len: %d", len(binary))
+ self.assertIn(b"simple_matmul", binary)
+ self.assertIn(b"simple_mul", binary)
+
+ @classmethod
+ def setUpClass(cls):
+ cls.m = SimpleArithmeticModule()
+ cls.tempdir = tempfile.TemporaryDirectory()
+ cls.smdir = os.path.join(cls.tempdir.name, "arith.sm")
+ tf.saved_model.save(
+ cls.m,
+ cls.smdir,
+ options=tf.saved_model.SaveOptions(save_debug_info=True))
+
+ @classmethod
+ def tearDownClass(cls):
+ cls.tempdir.cleanup()
+
+
+if __name__ == "__main__":
+ logging.basicConfig(level=logging.DEBUG)
+ unittest.main()
diff --git a/build_tools/bazel/build_tensorflow.sh b/build_tools/bazel/build_tensorflow.sh
index cc9a407..055e586 100755
--- a/build_tools/bazel/build_tensorflow.sh
+++ b/build_tools/bazel/build_tensorflow.sh
@@ -84,7 +84,7 @@
--nosystem_rc --nohome_rc --noworkspace_rc \
--bazelrc=build_tools/bazel/iree.bazelrc \
query \
- //integrations/... + //colab/... + //packaging/... | \
+ //integrations/... + //colab/... | \
xargs bazel \
--nosystem_rc --nohome_rc --noworkspace_rc \
--bazelrc=build_tools/bazel/iree.bazelrc \
diff --git a/build_tools/cmake/bazel.bat.in b/build_tools/cmake/bazel.bat.in
new file mode 100644
index 0000000..93da790
--- /dev/null
+++ b/build_tools/cmake/bazel.bat.in
@@ -0,0 +1,17 @@
+@echo off
+REM Copyright 2020 Google LLC
+REM
+REM Licensed under the Apache License, Version 2.0 (the "License");
+REM you may not use this file except in compliance with the License.
+REM You may obtain a copy of the License at
+REM
+REM https://www.apache.org/licenses/LICENSE-2.0
+REM
+REM Unless required by applicable law or agreed to in writing, software
+REM distributed under the License is distributed on an "AS IS" BASIS,
+REM WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+REM See the License for the specific language governing permissions and
+REM limitations under the License.
+
+cd /d "@_bazel_src_root@"
+@IREE_BAZEL_EXECUTABLE@ @_bazel_startup_options@ %* || exit /b
diff --git a/packaging/python/dummy_exclude_from_package.py b/build_tools/cmake/bazel.sh.in
similarity index 84%
copy from packaging/python/dummy_exclude_from_package.py
copy to build_tools/cmake/bazel.sh.in
index 0ca47f8..ed2466f 100644
--- a/packaging/python/dummy_exclude_from_package.py
+++ b/build_tools/cmake/bazel.sh.in
@@ -1,3 +1,4 @@
+#!/bin/bash
# Copyright 2020 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
@@ -11,3 +12,6 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
+
+cd "@_bazel_src_root@"
+exec '@IREE_BAZEL_EXECUTABLE@' @_bazel_startup_options@ "$@"
diff --git a/build_tools/cmake/configure_bazel.cmake b/build_tools/cmake/configure_bazel.cmake
new file mode 100644
index 0000000..528d244
--- /dev/null
+++ b/build_tools/cmake/configure_bazel.cmake
@@ -0,0 +1,158 @@
+# Copyright 2020 Google LLC
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# https://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+set(IREE_BAZEL_EXECUTABLE "bazel"
+ CACHE STRING "Bazel executable to use for bazel builds")
+
+# iree_configure_bazel
+#
+# Configures the CMake binary directory to also contain a bazel build root.
+# The following files will be created:
+# bazel (shell script): Shell wrapper to invoke bazel
+# bazel.bat: Windows batch file to invoke bazel
+# bazelrc: The bazelrc to use for the build
+# bazel-out/: Bazel output directory
+# bazel-bin/: Symlink to the bin directory appropriate for the build mode
+#
+# Variables will be set in the parent scope:
+# IREE_BAZEL_WRAPPER: Executable wrapper to invoke to run bazel
+# IREE_BAZEL_BIN: Path to the bazel-bin directory
+function(iree_configure_bazel)
+  # Bazel's output base is placed in the CMake build tree; the bazel workspace
+  # root is the top-level CMake source directory.
+  set(_bazel_output_base "${CMAKE_BINARY_DIR}/bazel-out")
+  set(_bazel_src_root "${CMAKE_SOURCE_DIR}")
+
+  # Use the utility to emit _bazelrc_file configuration options.
+  set(_bazelrc_file "${CMAKE_BINARY_DIR}/bazelrc")
+  execute_process(
+    RESULT_VARIABLE RC
+    COMMAND
+      "${Python3_EXECUTABLE}"
+      "${_bazel_src_root}/configure_bazel.py"
+      "${_bazelrc_file}"
+  )
+  if(NOT RC EQUAL 0)
+    message(FATAL_ERROR "Error running ${_bazel_src_root}/configure_bazel.py script")
+  endif()
+
+  # Now append an import of the project-wide bazelrc file to the generated
+  # bazelrc.
+  file(APPEND "${_bazelrc_file}" "
+import ${_bazel_src_root}/build_tools/bazel/iree.bazelrc
+")
+
+  # Note that we do allow a .bazelrc in the home directory (otherwise we
+  # would have --nohome_rc). This is mainly about disabling interference from
+  # interactive builds in the workspace.
+  set(_bazel_startup_options "--nosystem_rc --noworkspace_rc '--bazelrc=${_bazelrc_file}' '--output_base=${_bazel_output_base}'")
+
+  # And emit scripts to delegate to bazel. The templates only use @VAR@
+  # placeholders, so @ONLY keeps any literal shell "${...}" text from being
+  # substituted by CMake.
+  set(IREE_BAZEL_WRAPPER "${CMAKE_BINARY_DIR}/bazel")
+  configure_file(
+    "${CMAKE_CURRENT_SOURCE_DIR}/build_tools/cmake/bazel.sh.in"
+    "${IREE_BAZEL_WRAPPER}"
+    @ONLY
+  )
+  configure_file(
+    "${CMAKE_CURRENT_SOURCE_DIR}/build_tools/cmake/bazel.bat.in"
+    "${IREE_BAZEL_WRAPPER}.bat"
+    @ONLY
+  )
+  if(NOT WIN32)
+    # The generated shell wrapper must be executable to be launched below.
+    execute_process(
+      COMMAND chmod a+x "${IREE_BAZEL_WRAPPER}"
+    )
+  endif()
+
+  # Now ready to start bazel and ask it things.
+  message(STATUS "Detecting bazel version...")
+  execute_process(
+    RESULT_VARIABLE RC
+    OUTPUT_VARIABLE BAZEL_RELEASE
+    OUTPUT_STRIP_TRAILING_WHITESPACE
+    COMMAND
+      "${IREE_BAZEL_WRAPPER}" info release
+  )
+  if(NOT RC EQUAL 0)
+    message(FATAL_ERROR "Failed to launch bazel using wrapper ${IREE_BAZEL_WRAPPER}. Inspect that script and ensure bazel is installed properly.")
+  endif()
+  execute_process(
+    RESULT_VARIABLE RC
+    OUTPUT_VARIABLE IREE_BAZEL_BIN
+    OUTPUT_STRIP_TRAILING_WHITESPACE
+    COMMAND
+      "${IREE_BAZEL_WRAPPER}" info bazel-bin
+  )
+  if(NOT RC EQUAL 0)
+    message(FATAL_ERROR "Failed to run 'info bazel-bin' via ${IREE_BAZEL_WRAPPER}. Inspect that script and ensure bazel is installed properly.")
+  endif()
+  message(STATUS "Found bazel ${BAZEL_RELEASE}, bin directory: ${IREE_BAZEL_BIN}")
+  message(STATUS "Bazel wrapper script generated at: ${IREE_BAZEL_WRAPPER}")
+
+  # Build automation will use the IREE_BAZEL_BIN variable, but also drop a
+  # convenience symlink, since that is what people expect.
+  # And since bazel isn't nice enough to create it...
+  # Use `cmake -E create_symlink` rather than `ln -sf`: on a re-configure the
+  # link already exists and points at a directory, in which case `ln -sf`
+  # follows it and creates a stray link inside the real bin directory instead
+  # of replacing the link itself.
+  if(NOT WIN32)
+    execute_process(
+      RESULT_VARIABLE RC
+      COMMAND
+        "${CMAKE_COMMAND}" -E create_symlink
+          "${IREE_BAZEL_BIN}" "${CMAKE_CURRENT_BINARY_DIR}/bazel-bin"
+    )
+    if(NOT RC EQUAL 0)
+      message(WARNING "Failed to create convenience bazel-bin symlink")
+    endif()
+  endif()
+
+  # Export the results to the caller.
+  set(IREE_BAZEL_WRAPPER "${IREE_BAZEL_WRAPPER}" PARENT_SCOPE)
+  set(IREE_BAZEL_BIN "${IREE_BAZEL_BIN}" PARENT_SCOPE)
+endfunction()
+
+# iree_add_bazel_invocation
+#
+# Adds a target to perform a bazel invocation, building a list of targets
+# and exporting pseudo targets for some results of the build.
+#
+# Parameters:
+# INVOCATION_TARGET: The target name for the custom invocation target.
+# BAZEL_TARGETS: List of bazel targets to build.
+# EXECUTABLE_PATHS: Paths under bazel-bin for executables. An equivalent
+# CMake imported executable target will be created for each by replacing
+# the "/" with "_".
+function(iree_add_bazel_invocation)
+  # One-value keyword: INVOCATION_TARGET.
+  # Multi-value keywords: BAZEL_TARGETS, EXECUTABLE_PATHS.
+  cmake_parse_arguments(ARG
+    ""
+    "INVOCATION_TARGET"
+    "BAZEL_TARGETS;EXECUTABLE_PATHS"
+    ${ARGN}
+  )
+
+  # Custom target that runs `bazel build` on the requested targets through the
+  # wrapper script. VERBATIM makes argument escaping identical on every
+  # platform and generator.
+  add_custom_target(${ARG_INVOCATION_TARGET}
+    USES_TERMINAL
+    COMMAND ${CMAKE_COMMAND} -E echo
+        "Starting bazel build of targets ${ARG_BAZEL_TARGETS}"
+    COMMAND "${IREE_BAZEL_WRAPPER}" build ${ARG_BAZEL_TARGETS}
+    COMMAND ${CMAKE_COMMAND} -E echo "Bazel build complete."
+    VERBATIM
+  )
+
+  # Create an imported executable target for each binary path.
+  # Since the bazel directory namespace lines up with the cmake namespace,
+  # generate a cmake target name for each.
+  foreach(_executable_path ${ARG_EXECUTABLE_PATHS})
+    string(REPLACE "/" "_" _executable_target "${_executable_path}")
+    message(STATUS "Add bazel executable target ${_executable_target}")
+    add_executable(${_executable_target} IMPORTED GLOBAL)
+    set_target_properties(${_executable_target}
+      PROPERTIES IMPORTED_LOCATION
+        "${IREE_BAZEL_BIN}/${_executable_path}${CMAKE_EXECUTABLE_SUFFIX}"
+    )
+    # Tie the imported executable to the bazel invocation target.
+    add_dependencies(${_executable_target} ${ARG_INVOCATION_TARGET})
+  endforeach()
+endfunction()
diff --git a/build_tools/cmake/iree_cross_compile.cmake b/build_tools/cmake/iree_cross_compile.cmake
index f749f90..581f3d2 100644
--- a/build_tools/cmake/iree_cross_compile.cmake
+++ b/build_tools/cmake/iree_cross_compile.cmake
@@ -87,6 +87,7 @@
iree_to_bool(_CONFIG_BUILD_PYTHON_BINDINGS "${IREE_${CONFIG_NAME}_BUILD_PYTHON_BINDINGS}")
iree_to_bool(_CONFIG_BUILD_JAVA_BINDINGS "${IREE_${CONFIG_NAME}_BUILD_JAVA_BINDINGS}")
iree_to_bool(_CONFIG_BUILD_EXPERIMENTAL "${IREE_${CONFIG_NAME}_BUILD_EXPERIMENTAL}")
+ iree_to_bool(_CONFIG_BUILD_TENSORFLOW_COMPILER "${IREE_${CONFIG_NAME}_BUILD_TENSORFLOW_COMPILER}")
# Escape semicolons in the targets list so that CMake doesn't expand them to
# spaces.
@@ -119,6 +120,7 @@
-DIREE_BUILD_PYTHON_BINDINGS=${_CONFIG_BUILD_PYTHON_BINDINGS}
-DIREE_BUILD_JAVA_BINDINGS=${_CONFIG_BUILD_JAVA_BINDINGS}
-DIREE_BUILD_EXPERIMENTAL=${_CONFIG_BUILD_EXPERIMENTAL}
+ -DIREE_BUILD_TENSORFLOW_COMPILER=${_CONFIG_BUILD_TENSORFLOW_COMPILER}
# LINT.ThenChange(
# https://github.com/google/iree/tree/main/CMakeLists.txt:iree_options,
# https://github.com/google/iree/tree/main/build_tools/cmake/iree_cross_compile.cmake:iree_cross_compile_options,
diff --git a/build_tools/cmake/iree_macros.cmake b/build_tools/cmake/iree_macros.cmake
index 04466b5..8f4e5d2 100644
--- a/build_tools/cmake/iree_macros.cmake
+++ b/build_tools/cmake/iree_macros.cmake
@@ -265,6 +265,49 @@
endfunction()
#-------------------------------------------------------------------------------
+# Tool symlinks
+#-------------------------------------------------------------------------------
+
+# iree_symlink_tool
+#
+# Adds a post-build command to TARGET which symlinks a tool from elsewhere
+# (FROM_TOOL_TARGET) to a local file name (TO_EXE_NAME) in the current
+# binary directory. CMAKE_EXECUTABLE_SUFFIX is appended to TO_EXE_NAME.
+#
+# Parameters:
+# TARGET: Local target to which to add the symlink command (e.g. an
+#     iree_py_library). It is prefixed with the current package name.
+# FROM_TOOL_TARGET: Target of the tool executable that is the source of the
+#     link.
+# TO_EXE_NAME: The executable name to output in the current binary dir.
+function(iree_symlink_tool)
+  cmake_parse_arguments(
+    ARG
+    ""
+    "TARGET;FROM_TOOL_TARGET;TO_EXE_NAME"
+    ""
+    ${ARGN}
+  )
+
+  # Qualify TARGET with the package name to obtain the real CMake target name.
+  iree_package_name(_PACKAGE_NAME)
+  set(_TARGET "${_PACKAGE_NAME}_${ARG_TARGET}")
+  set(_LINK_PATH
+    "${CMAKE_CURRENT_BINARY_DIR}/${ARG_TO_EXE_NAME}${CMAKE_EXECUTABLE_SUFFIX}")
+
+  add_custom_command(
+    TARGET "${_TARGET}"
+    POST_BUILD
+    BYPRODUCTS "${_LINK_PATH}"
+    COMMAND
+      ${CMAKE_COMMAND} -E create_symlink
+      "$<TARGET_FILE:${ARG_FROM_TOOL_TARGET}>" "${_LINK_PATH}"
+    VERBATIM
+  )
+endfunction()
+
+
+#-------------------------------------------------------------------------------
# Tests
#-------------------------------------------------------------------------------
diff --git a/build_tools/cmake/iree_multipy.cmake b/build_tools/cmake/iree_multipy.cmake
index 1ac1028..3456fad 100644
--- a/build_tools/cmake/iree_multipy.cmake
+++ b/build_tools/cmake/iree_multipy.cmake
@@ -23,12 +23,20 @@
# Note that this is using the pybind11 configuration vars, which creates
# a fragile dependency. It would be better to derive these locally.
if(Python3_FOUND)
- set(IREE_MULTIPY_DEFAULT_EXECUTABLE "${PYTHON_EXECUTABLE}" CACHE INTERNAL "Python executable" )
- set(IREE_MULTIPY_DEFAULT_INCLUDE_DIRS "${PYTHON_INCLUDE_DIRS}" CACHE INTERNAL "Python include dirs" )
- set(IREE_MULTIPY_DEFAULT_LIBRARIES "${PYTHON_LIBRARIES}" CACHE INTERNAL "Python libraries")
- set(IREE_MULTIPY_DEFAULT_PREFIX "${PYTHON_MODULE_PREFIX}" CACHE INTERNAL "Python module prefix")
- set(IREE_MULTIPY_DEFAULT_SUFFIX "${PYTHON_MODULE_SUFFIX}" CACHE INTERNAL "Python module suffix")
- set(IREE_MULTIPY_DEFAULT_EXTENSION "${PYTHON_MODULE_EXTENSION}" CACHE INTERNAL "Python module extension")
+ set(IREE_MULTIPY_DEFAULT_EXECUTABLE "${Python3_EXECUTABLE}" CACHE INTERNAL "Python executable" )
+ set(IREE_MULTIPY_DEFAULT_INCLUDE_DIRS "${Python3_INCLUDE_DIRS}" CACHE INTERNAL "Python include dirs" )
+ set(IREE_MULTIPY_DEFAULT_LIBRARIES "${Python3_LIBRARIES}" CACHE INTERNAL "Python libraries")
+ set(IREE_MULTIPY_DEFAULT_PREFIX "${Python3_MODULE_PREFIX}" CACHE INTERNAL "Python module prefix")
+ set(IREE_MULTIPY_DEFAULT_SUFFIX "${Python3_MODULE_SUFFIX}" CACHE INTERNAL "Python module suffix")
+  # CMake 3.19 and thereabouts defines Python3_SOABI, but query EXT_SUFFIX
+  # ourselves for compatibility.
+ execute_process(
+ OUTPUT_VARIABLE _FOUND_DEFAULT_EXTENSION
+ OUTPUT_STRIP_TRAILING_WHITESPACE
+ COMMAND
+ "${Python3_EXECUTABLE}" -c "import sysconfig;print(sysconfig.get_config_var('EXT_SUFFIX'))"
+ )
+ set(IREE_MULTIPY_DEFAULT_EXTENSION "${_FOUND_DEFAULT_EXTENSION}" CACHE INTERNAL "Python module extension")
endif()
if(IREE_MULTIPY_VERSIONS)
@@ -51,13 +59,13 @@
# Check for required settings.
if(NOT IREE_MULTIPY_${V}_INCLUDE_DIRS)
- message(FATAL " MULTIPY version ${V}: No IREE_MULTIPY_${VER}_EXECUTABLE var")
+ message(FATAL_ERROR " MULTIPY version ${V}: No IREE_MULTIPY_${VER}_EXECUTABLE var")
endif()
if(NOT IREE_MULTIPY_${V}_INCLUDE_DIRS)
- message(FATAL " MULTIPY version ${V}: No IREE_MULTIPY_${VER}_INCLUDE_DIRS var")
+ message(FATAL_ERROR " MULTIPY version ${V}: No IREE_MULTIPY_${VER}_INCLUDE_DIRS var")
endif()
if(NOT IREE_MULTIPY_${V}_EXTENSION)
- message(FATAL " MULTIPY version ${V}: No IREE_MULTIPY_${VER}_EXTENSION var")
+ message(FATAL_ERROR " MULTIPY version ${V}: No IREE_MULTIPY_${VER}_EXTENSION var")
endif()
endforeach()
endfunction()
@@ -237,14 +245,20 @@
iree_package_name(_PACKAGE_NAME)
set(_NAME "${_PACKAGE_NAME}_${ARG_NAME}")
- # Add path to each source file
- list(TRANSFORM ARG_SRCS PREPEND "${CMAKE_CURRENT_SOURCE_DIR}/")
-
add_custom_target(${_NAME} ALL
- COMMAND ${CMAKE_COMMAND} -E copy ${ARG_SRCS} "${CMAKE_CURRENT_BINARY_DIR}/"
DEPENDS ${ARG_DEPS}
)
+ # Symlink each file as its own target.
+ foreach(SRC_FILE ${ARG_SRCS})
+ add_custom_command(
+ TARGET ${_NAME}
+ COMMAND ${CMAKE_COMMAND} -E create_symlink
+ "${CMAKE_CURRENT_SOURCE_DIR}/${SRC_FILE}" "${CMAKE_CURRENT_BINARY_DIR}/${SRC_FILE}"
+ BYPRODUCTS "${CMAKE_CURRENT_BINARY_DIR}/${SRC_FILE}"
+ )
+ endforeach()
+
# Add PYEXT_DEPS.
if(${ARG_PYEXT_DEPS})
foreach(V ${IREE_MULTIPY_VERSIONS_EFFECTIVE})
diff --git a/configure_bazel.py b/configure_bazel.py
index f0bac93..18be4a4 100644
--- a/configure_bazel.py
+++ b/configure_bazel.py
@@ -49,7 +49,7 @@
user_site = subprocess.check_output(
[sys.executable, "-m", "site", "--user-site"]).decode("utf-8").strip()
print("Found user site directory:", user_site)
- except OSError:
+ except subprocess.CalledProcessError:
print("Could not resolve user site directory")
return
print("build --action_env PYTHONPATH=\"{}\"".format(
@@ -57,7 +57,10 @@
file=bazelrc)
-local_bazelrc = os.path.join(os.path.dirname(__file__), "configured.bazelrc")
+if len(sys.argv) > 1:
+ local_bazelrc = sys.argv[1]
+else:
+ local_bazelrc = os.path.join(os.path.dirname(__file__), "configured.bazelrc")
with open(local_bazelrc, "wt") as bazelrc:
write_platform(bazelrc)
write_python_bin(bazelrc)
diff --git a/docs/get_started/cmake_options_and_variables.md b/docs/get_started/cmake_options_and_variables.md
index 18fff58..ad523a8 100644
--- a/docs/get_started/cmake_options_and_variables.md
+++ b/docs/get_started/cmake_options_and_variables.md
@@ -105,6 +105,15 @@
`MLIR_DIR` needs to be passed and that LLVM needs to be compiled with
`LLVM_ENABLE_RTTI` set to `ON`.
+#### `IREE_BUILD_TENSORFLOW_COMPILER`:BOOL
+
+Enables building of the TensorFlow to IREE compiler under
+`integrations/tensorflow`, including some native binaries and Python packages.
+Note that TensorFlow's build system is bazel-based and this will require a
+functioning `bazel` installation. A `bazel` wrapper script will be emitted
+in your build directory and `bazel-bin` link will point to artifacts. This
+can be used to manually invoke additional bazel actions if desired.
+
## MLIR-specific CMake Options and Variables
#### `MLIR_DIR`:STRING
diff --git a/docs/get_started/getting_started_python.md b/docs/get_started/getting_started_python.md
index af8a25b..955a179 100644
--- a/docs/get_started/getting_started_python.md
+++ b/docs/get_started/getting_started_python.md
@@ -1,11 +1,14 @@
# Getting Started with Python
-IREE has Python bindings geared towards lower level compiler interop that are
-not intended to be a public API, and integration with Python frontends such as
-TensorFlow.
+  NOTE: IREE's Python API is currently being reworked. Some of these
+  instructions may be in a state of flux as they document the end state.
-We do not yet provide a pip package for easy installation, so to use IREE's
-Python bindings you must build from source.
+IREE has two primary Python APIs:
+
+* Compiler API: `pyiree.compiler2`, `pyiree.compiler2.tf`
+* Runtime API: `pyiree.tf`
+
+There are additional ancillary modules that are not part of the public API.
## Prerequisites
@@ -13,22 +16,41 @@
[getting started guides](../get-started) for instructions.
> Note:<br>
-> Support is best with Bazel.
-> For CMake (excluding TensorFlow), set the `IREE_BUILD_PYTHON_BINDINGS` option.
+> Support is only complete with CMake.
+
+Minimally, the following CMake flags must be specified:
+
+* `-DIREE_BUILD_PYTHON_BINDINGS=ON`
+* `-DIREE_BUILD_TENSORFLOW_COMPILER=ON` : Optional. Also builds the
+ TensorFlow compiler integration.
+
+If building any parts of TensorFlow, you must have a working `bazel` command
+on your path. See the `.bazelversion` file at the root of the project for the
+recommended version.
## Python Setup
Install a recent version of [Python 3](https://www.python.org/downloads/) and
[pip](https://pip.pypa.io/en/stable/installing/), if needed.
+(Recommended) Set up a virtual environment (use your preferred mechanism):
+
+```shell
+# Note that venv is only available in python3 and is therefore a good check
+# that you are in fact running a python3 binary.
+python -m venv .venv
+source .venv/bin/activate
+# When done: run 'deactivate'
+```
+
Install packages:
```shell
-$ python3 -m pip install --upgrade pip
-$ python3 -m pip install numpy
+$ python -m pip install --upgrade pip
+$ python -m pip install numpy absl-py
# If using the TensorFlow integration
-$ python3 -m pip install tf-nightly
+$ python -m pip install tf-nightly
```
## Running Python Tests
@@ -40,29 +62,60 @@
$ ctest -L bindings/python
```
-To run tests for core Python bindings built with Bazel:
-
-```shell
-$ bazel test bindings/python/...
-```
-
To run tests for the TensorFlow integration, which include end-to-end backend
comparison tests:
```shell
-# Exclude tests that are skipped in the Kokoro CI
-$ bazel test \
- --build_tag_filters="-nokokoro" \
- --test_tag_filters="-nokokoro" \
- --define=iree_tensorflow=true \
- integrations/tensorflow/...
+cd build
+# TODO: Revisit once more patches land.
+ctest -L integrations/tensorflow/e2e
+
+# Or run individually as:
+export PYTHONPATH=bindings/python # In build dir
+python integrations/tensorflow/e2e/simple_arithmetic_test.py \
+ --target_backends=iree_vmla --artifacts_dir=/tmp/artifacts
```
+
## Using Colab
-See
-[start_colab_kernel.py](https://github.com/google/iree/blob/main/colab/start_colab_kernel.py)
-and [Using Colab](../using_iree/using_colab.md) for setup instructions, then
-take a look through the
-[Colab directory](https://github.com/google/iree/tree/main/colab) for some
-sample notebooks.
+There are some sample colabs in the `colab` folder. If you have built the
+project with CMake/ninja and set your `PYTHONPATH` to the `bindings/python`
+directory in the build dir (or installed per below), you should be able to
+start a kernel by following the stock instructions at
+https://colab.research.google.com/ .
+
+
+## Installing and Packaging
+
+There is a `setup.py` in the `bindings/python` directory under the build dir.
+To install into your (hopefully isolated) virtual env:
+
+```shell
+python bindings/python/setup.py install
+```
+
+To create wheels (platform dependent and locked to your Python version
+without further config):
+
+```shell
+python bindings/python/setup.py bdist_wheel
+```
+
+Note that it is often helpful to differentiate between the environment used to
+build and the one used to install. While this is just "normal" python
+knowledge, here is an incantation to do so:
+
+```shell
+# From parent/build environment.
+python -m pip freeze > /tmp/requirements.txt
+deactivate # If already in an environment
+
+# Enter new scratch environment.
+python -m venv ./.venv-scratch
+source ./.venv-scratch/bin/activate
+python -m pip install -r /tmp/requirements.txt
+
+# Install IREE into the new environment.
+python bindings/python/setup.py install
+```
diff --git a/experimental/ModelBuilder/test/CMakeLists.txt b/experimental/ModelBuilder/test/CMakeLists.txt
index 09ab2f2..7af3c3b 100644
--- a/experimental/ModelBuilder/test/CMakeLists.txt
+++ b/experimental/ModelBuilder/test/CMakeLists.txt
@@ -43,7 +43,6 @@
"TestDotProdJIT.cpp"
DEPS
LLVMSupport
- MLIRAllDialects
MLIREDSC
MLIRIR
MLIRSCFTransforms
@@ -59,9 +58,9 @@
SRCS
"TestVectorTransfersJIT.cpp"
DEPS
- runtime-support.so
LLVMSupport
- MLIRAllDialects
+    # TODO(thomasraoux): Fix dependency to shared library.
+    # runtime-support.so
MLIREDSC
MLIRIR
MLIRSCFTransforms
@@ -77,11 +76,9 @@
SRCS
"TestMNISTJIT.cpp"
DEPS
- MLIRAllDialects
MLIREDSC
MLIRIR
MLIRSCFTransforms
- MLIRmlir_runner_utils
experimental::ModelBuilder
experimental::ModelBuilder::ModelRunner
)
@@ -95,7 +92,6 @@
"TestSimpleJIT.cpp"
DEPS
LLVMSupport
- MLIRAllDialects
MLIREDSC
MLIRIR
MLIRSCFTransforms
@@ -112,17 +108,11 @@
"TestSimpleJITVulkan.cpp"
DEPS
LLVMSupport
- MLIRAllDialects
MLIRIR
MLIRParser
MLIRSPIRV
- MLIRmlir_runner_utils
experimental::ModelBuilder
experimental::ModelBuilder::ModelRunner
- iree::base::initializer
- iree::hal::llvmjit::llvmjit_driver_module
- iree::hal::vmla::vmla_driver_module
- iree::hal::vulkan::vulkan_driver_module
vulkan-runtime-wrappers
)
@@ -135,7 +125,6 @@
"TestMatMulVulkan.cpp"
DEPS
LLVMSupport
- MLIRAllDialects
MLIRExecutionEngine
MLIRGPU
MLIRGPUToSPIRVTransforms
@@ -153,15 +142,11 @@
MLIRTargetLLVMIR
MLIRTransformUtils
MLIRVectorToLLVM
- MLIRmlir_runner_utils
experimental::ModelBuilder
experimental::ModelBuilder::ModelRunner
experimental::ModelBuilder::VulkanLaunchWrapper
- iree::base::initializer
iree::compiler::Conversion::LinalgToSPIRV
- iree::hal::llvmjit::llvmjit_driver_module
- iree::hal::vmla::vmla_driver_module
- iree::hal::vulkan::vulkan_driver_module
+ iree::tools::init_mlir_passes_and_dialects
vulkan-runtime-wrappers
)
@@ -174,7 +159,6 @@
"TestVectorToGPU.cpp"
DEPS
LLVMSupport
- MLIRAllDialects
MLIRExecutionEngine
MLIRGPU
MLIRGPUToVulkanTransforms
@@ -189,16 +173,12 @@
MLIRStandardToSPIRVTransforms
MLIRTransformUtils
MLIRVector
- MLIRmlir_runner_utils
experimental::ModelBuilder
experimental::ModelBuilder::ModelRunner
experimental::ModelBuilder::VulkanLaunchWrapper
- iree::base::initializer
iree::compiler::Conversion::CodegenUtils
iree::compiler::Conversion::LinalgToSPIRV
- iree::hal::llvmjit::llvmjit_driver_module
- iree::hal::vmla::vmla_driver_module
- iree::hal::vulkan::vulkan_driver_module
+ iree::tools::init_mlir_passes_and_dialects
vulkan-runtime-wrappers
)
@@ -211,7 +191,6 @@
"BenchMatMulVectorGPU.cpp"
DEPS
LLVMSupport
- MLIRAllDialects
MLIRExecutionEngine
MLIRGPU
MLIRGPUToVulkanTransforms
@@ -226,16 +205,12 @@
MLIRStandardToSPIRVTransforms
MLIRTransformUtils
MLIRVector
- MLIRmlir_runner_utils
experimental::ModelBuilder
experimental::ModelBuilder::ModelRunner
experimental::ModelBuilder::VulkanLaunchWrapper
- iree::base::initializer
iree::compiler::Conversion::CodegenUtils
iree::compiler::Conversion::LinalgToSPIRV
- iree::hal::llvmjit::llvmjit_driver_module
- iree::hal::vmla::vmla_driver_module
- iree::hal::vulkan::vulkan_driver_module
+ iree::tools::init_mlir_passes_and_dialects
vulkan-runtime-wrappers
)
@@ -247,7 +222,6 @@
SRCS
"TestSimpleMLIR.cpp"
DEPS
- MLIRAllDialects
experimental::ModelBuilder
experimental::ModelBuilder::ModelRunner
)
@@ -260,7 +234,6 @@
SRCS
"BenchMatVecVectorJIT.cpp"
DEPS
- MLIRAllDialects
MLIREDSC
MLIRIR
benchmark
@@ -276,7 +249,6 @@
SRCS
"BenchMatMulVectorJIT.cpp"
DEPS
- MLIRAllDialects
MLIREDSC
MLIRIR
benchmark
@@ -292,7 +264,6 @@
SRCS
"BenchMatMulVectorColumnMajorLLVMIntrinsicsJIT.cpp"
DEPS
- MLIRAllDialects
MLIREDSC
MLIRIR
benchmark
diff --git a/integrations/tensorflow/CMakeLists.txt b/integrations/tensorflow/CMakeLists.txt
new file mode 100644
index 0000000..6194ee3
--- /dev/null
+++ b/integrations/tensorflow/CMakeLists.txt
@@ -0,0 +1,36 @@
+# Copyright 2020 Google LLC
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# https://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+# TensorFlow builds through bazel, and IREE maintains all of its TensorFlow
+# dependent code under this directory tree. The CMake support is limited to
+# compiler binaries and python bindings.
+#
+# Bazel is a beast that likes to be the center of the universe. There is some
+# fragility in delegating to it in this fashion.
+#
+# If this directory is included, then building TensorFlow is assumed (the
+# config option happens at the higher level).
+
+iree_add_bazel_invocation(
+ INVOCATION_TARGET
+ integrations_iree_tensorflow_importers
+ BAZEL_TARGETS
+ //integrations/tensorflow/compiler:iree-tf-import
+ EXECUTABLE_PATHS
+ integrations/tensorflow/compiler/iree-tf-import
+)
+
+if(${IREE_BUILD_PYTHON_BINDINGS})
+ add_subdirectory(bindings/python)
+endif()
diff --git a/bindings/python/pyiree/CMakeLists.txt b/integrations/tensorflow/bindings/python/CMakeLists.txt
similarity index 62%
copy from bindings/python/pyiree/CMakeLists.txt
copy to integrations/tensorflow/bindings/python/CMakeLists.txt
index 55ac43d..31b61a2 100644
--- a/bindings/python/pyiree/CMakeLists.txt
+++ b/integrations/tensorflow/bindings/python/CMakeLists.txt
@@ -12,9 +12,11 @@
# See the License for the specific language governing permissions and
# limitations under the License.
-add_subdirectory(common)
-add_subdirectory(rt)
+# Adds source subdirectory `dir`, placing its binary directory under the main
+# python bindings tree ("${CMAKE_BINARY_DIR}/bindings/python/<dir>").
+function(_add_overlay_subdirectory dir)
+  set(_MAIN_PYTHON_DIR "${CMAKE_BINARY_DIR}/bindings/python")
+  add_subdirectory("${dir}" "${_MAIN_PYTHON_DIR}/${dir}")
+endfunction()
-if(${IREE_BUILD_COMPILER})
- add_subdirectory(compiler)
-endif()
+_add_overlay_subdirectory(pyiree/tools/tf)
diff --git a/integrations/tensorflow/bindings/python/pyiree/tools/tf/.skip_bazel_to_cmake b/integrations/tensorflow/bindings/python/pyiree/tools/tf/.skip_bazel_to_cmake
new file mode 100644
index 0000000..e69de29
--- /dev/null
+++ b/integrations/tensorflow/bindings/python/pyiree/tools/tf/.skip_bazel_to_cmake
diff --git a/bindings/python/pyiree/CMakeLists.txt b/integrations/tensorflow/bindings/python/pyiree/tools/tf/CMakeLists.txt
similarity index 70%
copy from bindings/python/pyiree/CMakeLists.txt
copy to integrations/tensorflow/bindings/python/pyiree/tools/tf/CMakeLists.txt
index 55ac43d..a10869b 100644
--- a/bindings/python/pyiree/CMakeLists.txt
+++ b/integrations/tensorflow/bindings/python/pyiree/tools/tf/CMakeLists.txt
@@ -12,9 +12,17 @@
# See the License for the specific language governing permissions and
# limitations under the License.
-add_subdirectory(common)
-add_subdirectory(rt)
+iree_py_library(
+ NAME
+ tf
+ SRCS
+ "__init__.py"
+ DEPS
+ integrations_iree_tensorflow_importers
+)
-if(${IREE_BUILD_COMPILER})
- add_subdirectory(compiler)
-endif()
+iree_symlink_tool(
+ TARGET tf
+ FROM_TOOL_TARGET integrations_tensorflow_compiler_iree-tf-import
+ TO_EXE_NAME iree-tf-import
+)
diff --git a/bindings/python/pyiree/CMakeLists.txt b/integrations/tensorflow/bindings/python/pyiree/tools/tf/__init__.py
similarity index 63%
copy from bindings/python/pyiree/CMakeLists.txt
copy to integrations/tensorflow/bindings/python/pyiree/tools/tf/__init__.py
index 55ac43d..ceb37ea 100644
--- a/bindings/python/pyiree/CMakeLists.txt
+++ b/integrations/tensorflow/bindings/python/pyiree/tools/tf/__init__.py
@@ -1,3 +1,6 @@
+# Lint-as: python3
+"""TensorFlow tools."""
+
# Copyright 2020 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
@@ -12,9 +15,15 @@
# See the License for the specific language governing permissions and
# limitations under the License.
-add_subdirectory(common)
-add_subdirectory(rt)
+from typing import Optional
-if(${IREE_BUILD_COMPILER})
- add_subdirectory(compiler)
-endif()
+import os
+import platform
+
+
+def get_tool(exe_name: str) -> Optional[str]:
+  """Returns the path of `exe_name` beside this module ('.exe' on Windows)."""
+  suffix = ".exe" if platform.system() == "Windows" else ""
+  package_dir = os.path.dirname(__file__)
+  # The path is built unconditionally; the file's existence is not checked.
+  return os.path.join(package_dir, exe_name + suffix)
diff --git a/iree/compiler/Conversion/Common/VectorTransferOptimization.cpp b/iree/compiler/Conversion/Common/VectorTransferOptimization.cpp
index 91ef5a0..85aa505 100644
--- a/iree/compiler/Conversion/Common/VectorTransferOptimization.cpp
+++ b/iree/compiler/Conversion/Common/VectorTransferOptimization.cpp
@@ -12,17 +12,57 @@
// See the License for the specific language governing permissions and
// limitations under the License.
+#include "mlir/Dialect/StandardOps/IR/Ops.h"
+#include "mlir/Dialect/Vector/VectorOps.h"
#include "mlir/Dialect/Vector/VectorTransforms.h"
#include "mlir/Pass/Pass.h"
namespace mlir {
namespace iree_compiler {
+// Returns true if every use of `op` is a StoreOp/transfer_write, directly or
+// through SubViewOps whose own uses all satisfy the same condition. On success
+// the users are appended to `uses`, with a SubViewOp's users placed before the
+// SubViewOp itself; on failure `uses` is left unchanged.
+static bool allUsesAreStores(Operation* op, std::vector<Operation*>& uses) {
+ std::vector<Operation*> opUses;
+ for (OpOperand& use : op->getUses()) {
+ Operation* useOp = use.getOwner();
+ if (isa<vector::TransferWriteOp, StoreOp>(useOp) ||
+ (isa<SubViewOp>(useOp) && allUsesAreStores(useOp, opUses))) {
+ opUses.push_back(useOp);
+ continue;
+ }
+ return false;
+ }
+ uses.insert(uses.end(), opUses.begin(), opUses.end());
+ return true;
+}
+
+// Erases temporary allocations that are never read from: when an AllocOp in
+// `funcOp` passes allUsesAreStores, its collected users and then it are erased.
+static void eraseDeadAllocAndStores(FuncOp funcOp) {
+ std::vector<Operation*> opToErase;
+ funcOp.walk([&](AllocOp op) {
+ if (allUsesAreStores(op, opToErase)) {
+ opToErase.push_back(op.getOperation());
+ }
+ });
+ for (Operation* op : opToErase) {
+ op->erase();
+ }
+}
+
namespace {
struct VectorTransferOptimizationPass
: public PassWrapper<VectorTransferOptimizationPass, FunctionPass> {
- void runOnFunction() override { vector::transferOpflowOpt(getFunction()); }
+ void runOnFunction() override {
+ vector::transferOpflowOpt(getFunction());
+ // Delete potential dead alloc and associated ops after store to load
+ // forwarding.
+ eraseDeadAllocAndStores(getFunction());
+ }
};
} // namespace
diff --git a/iree/compiler/Conversion/HLOToLinalg/BUILD b/iree/compiler/Conversion/HLOToLinalg/BUILD
index a1f9de8..56a034d 100644
--- a/iree/compiler/Conversion/HLOToLinalg/BUILD
+++ b/iree/compiler/Conversion/HLOToLinalg/BUILD
@@ -19,11 +19,35 @@
)
cc_library(
- name = "HLOToLinalg",
+ name = "HLOToLinalgOnTensors",
srcs = [
"FusionOfTensorOps.cpp",
- "HLOToLinalgOnBuffers.cpp",
"HLOToLinalgOnTensors.cpp",
+ ],
+ hdrs = [
+ "HLOToLinalgOnTensorPasses.h",
+ ],
+ deps = [
+ "//iree/compiler/Dialect/HAL/IR",
+ "//iree/compiler/Dialect/HAL/IR:HALDialect",
+ "@llvm-project//mlir:Affine",
+ "@llvm-project//mlir:IR",
+ "@llvm-project//mlir:LinalgOps",
+ "@llvm-project//mlir:LinalgTransforms",
+ "@llvm-project//mlir:Pass",
+ "@llvm-project//mlir:StandardOps",
+ "@llvm-project//mlir:Support",
+ "@llvm-project//mlir:Transforms",
+ "@org_tensorflow//tensorflow/compiler/mlir/hlo",
+ "@org_tensorflow//tensorflow/compiler/mlir/hlo:legalize_to_linalg",
+ "@org_tensorflow//tensorflow/compiler/mlir/hlo:map_lmhlo_to_scalar_op",
+ ],
+)
+
+cc_library(
+ name = "HLOToLinalg",
+ srcs = [
+ "HLOToLinalgOnBuffers.cpp",
"Passes.cpp",
"ResolveShapeOps.cpp",
],
@@ -31,6 +55,7 @@
"Passes.h",
],
deps = [
+ ":HLOToLinalgOnTensors",
"//iree/compiler/Conversion/CodegenUtils",
"//iree/compiler/Conversion/HLOToHLO",
"//iree/compiler/Dialect/HAL/IR",
@@ -38,6 +63,7 @@
"//iree/compiler/Dialect/IREE/IR",
"//iree/compiler/Dialect/Shape/IR",
"@llvm-project//llvm:Support",
+ "@llvm-project//mlir:Affine",
"@llvm-project//mlir:DialectUtils",
"@llvm-project//mlir:IR",
"@llvm-project//mlir:LinalgOps",
@@ -48,7 +74,6 @@
"@llvm-project//mlir:Support",
"@llvm-project//mlir:Transforms",
"@org_tensorflow//tensorflow/compiler/mlir/hlo",
- "@org_tensorflow//tensorflow/compiler/mlir/hlo:legalize_to_linalg",
"@org_tensorflow//tensorflow/compiler/mlir/hlo:map_lmhlo_to_scalar_op",
],
)
diff --git a/iree/compiler/Conversion/HLOToLinalg/CMakeLists.txt b/iree/compiler/Conversion/HLOToLinalg/CMakeLists.txt
index 5f24d65..8984a07 100644
--- a/iree/compiler/Conversion/HLOToLinalg/CMakeLists.txt
+++ b/iree/compiler/Conversion/HLOToLinalg/CMakeLists.txt
@@ -16,17 +16,40 @@
iree_cc_library(
NAME
+ HLOToLinalgOnTensors
+ HDRS
+ "HLOToLinalgOnTensorPasses.h"
+ SRCS
+ "FusionOfTensorOps.cpp"
+ "HLOToLinalgOnTensors.cpp"
+ DEPS
+ MLIRAffine
+ MLIRIR
+ MLIRLinalg
+ MLIRLinalgTransforms
+ MLIRPass
+ MLIRStandard
+ MLIRSupport
+ MLIRTransforms
+ iree::compiler::Dialect::HAL::IR
+ iree::compiler::Dialect::HAL::IR::HALDialect
+ tensorflow::mlir_hlo
+ PUBLIC
+)
+
+iree_cc_library(
+ NAME
HLOToLinalg
HDRS
"Passes.h"
SRCS
- "FusionOfTensorOps.cpp"
"HLOToLinalgOnBuffers.cpp"
- "HLOToLinalgOnTensors.cpp"
"Passes.cpp"
"ResolveShapeOps.cpp"
DEPS
+ ::HLOToLinalgOnTensors
LLVMSupport
+ MLIRAffine
MLIRIR
MLIRLinalg
MLIRLinalgTransforms
diff --git a/iree/compiler/Conversion/HLOToLinalg/FusionOfTensorOps.cpp b/iree/compiler/Conversion/HLOToLinalg/FusionOfTensorOps.cpp
index f2f7f3f..05a6c87 100644
--- a/iree/compiler/Conversion/HLOToLinalg/FusionOfTensorOps.cpp
+++ b/iree/compiler/Conversion/HLOToLinalg/FusionOfTensorOps.cpp
@@ -20,9 +20,10 @@
//
//===----------------------------------------------------------------------===//
-#include "iree/compiler/Conversion/HLOToLinalg/Passes.h"
+#include "iree/compiler/Conversion/HLOToLinalg/HLOToLinalgOnTensorPasses.h"
#include "iree/compiler/Dialect/HAL/IR/HALDialect.h"
#include "iree/compiler/Dialect/HAL/IR/HALOps.h"
+#include "mlir/Dialect/Affine/IR/AffineOps.h"
#include "mlir/Dialect/Linalg/IR/LinalgOps.h"
#include "mlir/Dialect/Linalg/Passes.h"
#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
@@ -67,7 +68,8 @@
struct FusionOfTensorOpsPass
: public PassWrapper<FusionOfTensorOpsPass, OperationPass<>> {
void getDependentDialects(DialectRegistry ®istry) const override {
- registry.insert<linalg::LinalgDialect, IREE::HAL::HALDialect>();
+ registry
+ .insert<AffineDialect, IREE::HAL::HALDialect, linalg::LinalgDialect>();
}
void runOnOperation() override {
diff --git a/iree/compiler/Conversion/HLOToLinalg/HLOToLinalgOnTensorPasses.h b/iree/compiler/Conversion/HLOToLinalg/HLOToLinalgOnTensorPasses.h
new file mode 100644
index 0000000..5057bdc
--- /dev/null
+++ b/iree/compiler/Conversion/HLOToLinalg/HLOToLinalgOnTensorPasses.h
@@ -0,0 +1,48 @@
+// Copyright 2020 Google LLC
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// https://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+//===- HLOToLinalgOnTensorPasses.h - Passes to convert from HLO To Linalg -===//
+//
+// IREE specific passes used in the HLO to Linalg on tensors conversion as well
+// as fusion.
+//
+//===----------------------------------------------------------------------===//
+#ifndef IREE_COMPILER_CONVERSION_HLOTOLINALGONTENSORS_PASSES_H_
+#define IREE_COMPILER_CONVERSION_HLOTOLINALGONTENSORS_PASSES_H_
+
+#include "mlir/IR/BuiltinDialect.h"
+#include "mlir/Pass/Pass.h"
+
+namespace mlir {
+namespace iree_compiler {
+
+/// Creates a pass to fuse operations on tensors.
+std::unique_ptr<Pass> createFusionOfTensorOpsPass();
+
+/// Creates XLA-HLO to Linalg on tensors transformation pass.
+std::unique_ptr<OperationPass<FuncOp>> createHLOToLinalgOnTensorsPass();
+
+/// Populates the patterns that convert from MHLO to Linalg on tensors. Imports
+/// patterns from XLA, as well as some IREE specific modifications.
+void populateHLOToLinalgOnTensorsConversionPatterns(
+ MLIRContext *context, OwningRewritePatternList &patterns);
+
+/// Adds passes to `pm` that convert from MHLO to Linalg on tensors and fuse
+/// the converted operations.
+void addHLOToLinalgOnTensorsPasses(OpPassManager &pm);
+
+} // namespace iree_compiler
+} // namespace mlir
+
+#endif // IREE_COMPILER_CONVERSION_HLOTOLINALGONTENSORS_PASSES_H_
diff --git a/iree/compiler/Conversion/HLOToLinalg/HLOToLinalgOnTensors.cpp b/iree/compiler/Conversion/HLOToLinalg/HLOToLinalgOnTensors.cpp
index a58a877..d86b947 100644
--- a/iree/compiler/Conversion/HLOToLinalg/HLOToLinalgOnTensors.cpp
+++ b/iree/compiler/Conversion/HLOToLinalg/HLOToLinalgOnTensors.cpp
@@ -21,7 +21,7 @@
//===----------------------------------------------------------------------===//
#include <memory>
-#include "iree/compiler/Conversion/HLOToLinalg/Passes.h"
+#include "iree/compiler/Conversion/HLOToLinalg/HLOToLinalgOnTensorPasses.h"
#include "mlir-hlo/Dialect/mhlo/IR/hlo_ops.h"
#include "mlir-hlo/Dialect/mhlo/transforms/rewriters.h"
#include "mlir/Dialect/Linalg/IR/LinalgOps.h"
diff --git a/iree/compiler/Conversion/HLOToLinalg/Passes.cpp b/iree/compiler/Conversion/HLOToLinalg/Passes.cpp
index 2bb1e30..ca0842b 100644
--- a/iree/compiler/Conversion/HLOToLinalg/Passes.cpp
+++ b/iree/compiler/Conversion/HLOToLinalg/Passes.cpp
@@ -15,6 +15,7 @@
#include "iree/compiler/Conversion/HLOToLinalg/Passes.h"
#include "iree/compiler/Conversion/HLOToHLO/Passes.h"
+#include "iree/compiler/Conversion/HLOToLinalg/HLOToLinalgOnTensorPasses.h"
#include "mlir/Dialect/Linalg/Passes.h"
#include "mlir/Pass/PassManager.h"
#include "mlir/Transforms/Passes.h"
@@ -23,12 +24,16 @@
namespace iree_compiler {
void addHLOToLinalgOnBuffersPasses(OpPassManager &pm) {
+ addHLOToLinalgOnTensorsPasses(pm);
+ pm.addNestedPass<FuncOp>(createHLOToLinalgOnBuffersPass());
+}
+
+void addHLOToLinalgOnTensorsPasses(OpPassManager &pm) {
pm.addNestedPass<FuncOp>(createHLOToLinalgOnTensorsPass());
pm.addNestedPass<FuncOp>(createLinalgFoldUnitExtentDimsPass());
pm.addNestedPass<FuncOp>(createCanonicalizerPass());
pm.addNestedPass<FuncOp>(createFusionOfTensorOpsPass());
pm.addNestedPass<FuncOp>(createCSEPass());
- pm.addNestedPass<FuncOp>(createHLOToLinalgOnBuffersPass());
}
static PassPipelineRegistration<> hloToLinalgOnBuffersPipeline(
diff --git a/iree/compiler/Conversion/HLOToLinalg/Passes.h b/iree/compiler/Conversion/HLOToLinalg/Passes.h
index 4c048b0..2119e42 100644
--- a/iree/compiler/Conversion/HLOToLinalg/Passes.h
+++ b/iree/compiler/Conversion/HLOToLinalg/Passes.h
@@ -27,24 +27,13 @@
namespace mlir {
namespace iree_compiler {
-/// Creates a pass to fuse operations on tensors.
-std::unique_ptr<Pass> createFusionOfTensorOpsPass();
-
/// Creates XLA-HLO to Linalg on buffers transformation pass.
std::unique_ptr<OperationPass<FuncOp>> createHLOToLinalgOnBuffersPass();
-/// Creates XLA-HLO to Linalg on tensors transformation pass.
-std::unique_ptr<OperationPass<FuncOp>> createHLOToLinalgOnTensorsPass();
-
/// Resolves shape related ops (std.dim, shapex.tie_shape, etc.) by tracing
/// them back to the original HAL interface bindings.
std::unique_ptr<OperationPass<FuncOp>> createResolveShapeOpsPass();
-/// Populates the patterns that convert from XLA to Linalg on tensors. Imports
-/// patterns from XLA, as well as some IREE specific modifications.
-void populateHLOToLinalgOnTensorsConversionPatterns(
- MLIRContext *context, OwningRewritePatternList &patterns);
-
/// Populates the patterns that convert from XLA to Linalg on buffers. Currently
/// only implements conversions when the XLA op is the only op XLA op in the
/// dispatch region.
diff --git a/iree/compiler/Conversion/LinalgToSPIRV/KernelDispatchUtils.cpp b/iree/compiler/Conversion/LinalgToSPIRV/KernelDispatchUtils.cpp
index 25d53b6..7caca66 100644
--- a/iree/compiler/Conversion/LinalgToSPIRV/KernelDispatchUtils.cpp
+++ b/iree/compiler/Conversion/LinalgToSPIRV/KernelDispatchUtils.cpp
@@ -597,6 +597,16 @@
DISPATCH(vector::TransferWriteOp)
#undef DISPATCH
+
+ if (op->hasTrait<OpTrait::ElementwiseMappable>() &&
+ op->getNumResults() == 1) {
+ if (auto vecType = op->getResultTypes()[0].dyn_cast<VectorType>()) {
+ // Map elementwise ops to vec4.
+ SmallVector<int64_t, 4> nativeSize(vecType.getRank() - 1, 1);
+ nativeSize.push_back(4);
+ return nativeSize;
+ }
+ }
return llvm::None;
}
diff --git a/iree/compiler/Conversion/LinalgToSPIRV/LinalgTileAndFusePass.cpp b/iree/compiler/Conversion/LinalgToSPIRV/LinalgTileAndFusePass.cpp
index 864f2d4..7837c17 100644
--- a/iree/compiler/Conversion/LinalgToSPIRV/LinalgTileAndFusePass.cpp
+++ b/iree/compiler/Conversion/LinalgToSPIRV/LinalgTileAndFusePass.cpp
@@ -304,7 +304,8 @@
patterns.insert<linalg::LinalgTilingPattern<linalg::MatmulOp>,
linalg::LinalgTilingPattern<linalg::FillOp>,
- linalg::LinalgTilingPattern<linalg::BatchMatmulOp>>(
+ linalg::LinalgTilingPattern<linalg::BatchMatmulOp>,
+ linalg::LinalgTilingPattern<linalg::GenericOp>>(
context, tilingOptions,
getLinalgMatchAndReplaceMarker(
{getWorkgroupMemoryMarker(), getWorkgroupMarker()},
@@ -326,7 +327,8 @@
OwningRewritePatternList &patterns) {
patterns.insert<linalg::LinalgVectorizationPattern<linalg::MatmulOp>,
linalg::LinalgVectorizationPattern<linalg::BatchMatmulOp>,
- linalg::LinalgVectorizationPattern<linalg::FillOp>>(
+ linalg::LinalgVectorizationPattern<linalg::FillOp>,
+ linalg::LinalgVectorizationPattern<linalg::GenericOp>>(
context,
linalg::LinalgMarker(Identifier::get(getVectorizeMarker(), context)));
}
diff --git a/iree/compiler/Conversion/LinalgToSPIRV/VectorToGPUPass.cpp b/iree/compiler/Conversion/LinalgToSPIRV/VectorToGPUPass.cpp
index 03cdd3e..bce7f04 100644
--- a/iree/compiler/Conversion/LinalgToSPIRV/VectorToGPUPass.cpp
+++ b/iree/compiler/Conversion/LinalgToSPIRV/VectorToGPUPass.cpp
@@ -293,6 +293,44 @@
}
};
+// Lower elementwise operations on N-D vectors to the same operation on 1-D
+// vectors (via vector.shape_cast), which can be natively supported.
+class ElementwiseLowering : public RewritePattern {
+ public:
+ ElementwiseLowering(MLIRContext *context)
+ : RewritePattern(0, MatchAnyOpTypeTag()) {}
+
+ LogicalResult matchAndRewrite(Operation *op,
+ PatternRewriter &rewriter) const override {
+ if (!op->hasTrait<OpTrait::ElementwiseMappable>() ||
+ op->getNumResults() != 1)
+ return failure();
+ auto vecType = op->getResultTypes()[0].dyn_cast<VectorType>();
+ if (!vecType || vecType.getRank() == 1) return failure();
+
+ SmallVector<Value, 4> newOperands;
+ for (Value operand : op->getOperands()) {
+ if (auto opVecType = operand.getType().dyn_cast<VectorType>()) {
+ auto newType = VectorType::get({opVecType.getNumElements()},
+ opVecType.getElementType());
+ newOperands.push_back(rewriter.create<vector::ShapeCastOp>(
+ op->getLoc(), newType, operand));
+ } else {
+ newOperands.push_back(operand);
+ }
+ }
+ OperationState state(op->getLoc(), op->getName());
+ state.addAttributes(op->getAttrs());
+ state.addOperands(newOperands);
+ state.addTypes({VectorType::get({vecType.getNumElements()},
+ vecType.getElementType())});
+ Operation *newOp = rewriter.createOperation(state);
+ rewriter.replaceOpWithNewOp<vector::ShapeCastOp>(op, vecType,
+ newOp->getResult(0));
+ return success();
+ }
+};
+
// Lower ExtractStridedSliceOp to an ExtractOp instruction that can be natively
// converted to SPIR-V. Add a BroadcastOp to keep the type consistent, we expect
// the Broadcast to be removed by canonicalization.
@@ -325,7 +363,8 @@
MLIRContext *context) {
OwningRewritePatternList patterns;
patterns.insert<VectorContractLowering, VectorTransferReadToLoad,
- VectorTransferWriteToStore, ExtractStridedLowering>(context);
+ VectorTransferWriteToStore, ExtractStridedLowering,
+ ElementwiseLowering>(context);
applyPatternsAndFoldGreedily(funcOp, std::move(patterns));
}
diff --git a/iree/compiler/Conversion/LinalgToSPIRV/test/dead_alloc.mlir b/iree/compiler/Conversion/LinalgToSPIRV/test/dead_alloc.mlir
new file mode 100644
index 0000000..433b836
--- /dev/null
+++ b/iree/compiler/Conversion/LinalgToSPIRV/test/dead_alloc.mlir
@@ -0,0 +1,20 @@
+// RUN: iree-opt -iree-codegen-optimize-vector-transfer %s | IreeFileCheck %s
+
+module {
+ func @dead_alloc() {
+ %0 = alloc() : memref<8x64xf32, 3>
+ %1 = subview %0[0, 0] [8, 4] [1, 1] : memref<8x64xf32, 3> to
+ memref<8x4xf32, affine_map<(d0, d1) -> (d0 * 64 + d1)>, 3>
+ %c0 = constant 0 : index
+ %cst_0 = constant dense<0.000000e+00> : vector<1x4xf32>
+ vector.transfer_write %cst_0, %1[%c0, %c0] {masked = [false, false]} :
+ vector<1x4xf32>, memref<8x4xf32, affine_map<(d0, d1) -> (d0 * 64 + d1)>, 3>
+ return
+ }
+}
+
+// CHECK-LABEL: func @dead_alloc
+// CHECK-NOT: alloc
+// CHECK-NOT: subview
+// CHECK-NOT: vector.transfer_write
+// CHECK: return
diff --git a/iree/compiler/Conversion/LinalgToSPIRV/test/pipeline_test.mlir b/iree/compiler/Conversion/LinalgToSPIRV/test/pipeline_test.mlir
index 6c9a4e3..9b2873e 100644
--- a/iree/compiler/Conversion/LinalgToSPIRV/test/pipeline_test.mlir
+++ b/iree/compiler/Conversion/LinalgToSPIRV/test/pipeline_test.mlir
@@ -96,3 +96,67 @@
// CHECK-COUNT-12: spv.Load "StorageBuffer" %{{.*}} : vector<4xf32>
// CHECK-COUNT-32: spv.FMul %{{.*}}, %{{.*}} : vector<4xf32>
// CHECK-COUNT-8: spv.Store "StorageBuffer" %{{.*}}, %{{.*}} : vector<4xf32>
+
+// -----
+
+module attributes {
+ spv.target_env =
+ #spv.target_env<#spv.vce<v1.3,
+ [Shader, Float64, Float16, Int64, Int16, Int8, StorageBuffer16BitAccess,
+ StorageUniform16, StoragePushConstant16, StorageBuffer8BitAccess,
+ UniformAndStorageBuffer8BitAccess, StoragePushConstant8, GroupNonUniform,
+ GroupNonUniformVote, GroupNonUniformArithmetic, GroupNonUniformBallot,
+ GroupNonUniformShuffle, GroupNonUniformShuffleRelative, VariablePointers,
+ VariablePointersStorageBuffer],
+ [SPV_KHR_16bit_storage, SPV_KHR_8bit_storage,
+ SPV_KHR_storage_buffer_storage_class, SPV_KHR_variable_pointers]>,
+ ARM:IntegratedGPU,
+ {max_compute_shared_memory_size = 32768 : i32,
+ max_compute_workgroup_invocations = 512 : i32,
+ max_compute_workgroup_size = dense<512> : vector<3xi32>,
+ subgroup_size = 16 : i32}>} {
+ func @matmul_add_fused() attributes {hal.num_workgroups_fn = @matmul_add_fused__num_workgroups__} {
+ %cst = constant 0.000000e+00 : f32
+ %0 = iree.placeholder for "interface buffer" {binding = @legacy_io::@ret0, operand_result_index = 3 : i32} : memref<1024x256xf32>
+ %1 = iree.placeholder for "interface buffer" {binding = @legacy_io::@arg0, operand_result_index = 0 : i32} : memref<1024x512xf32>
+ %2 = iree.placeholder for "interface buffer" {binding = @legacy_io::@arg1, operand_result_index = 1 : i32} : memref<512x256xf32>
+ %3 = iree.placeholder for "interface buffer" {binding = @legacy_io::@arg2, operand_result_index = 2 : i32} : memref<1024x256xf32>
+ %4 = alloc() : memref<1024x256xf32>
+ linalg.fill(%4, %cst) : memref<1024x256xf32>, f32
+ linalg.matmul ins(%1, %2 : memref<1024x512xf32>, memref<512x256xf32>)
+ outs(%4 : memref<1024x256xf32>)
+ linalg.generic {indexing_maps = [affine_map<(d0, d1) -> (d0, d1)>,
+ affine_map<(d0, d1) -> (d0, d1)>, affine_map<(d0, d1) -> (d0, d1)>],
+ iterator_types = ["parallel", "parallel"]}
+ ins(%4, %3 : memref<1024x256xf32>, memref<1024x256xf32>)
+ outs(%0 : memref<1024x256xf32>) {
+ ^bb0(%arg0: f32, %arg1: f32, %arg2: f32): // no predecessors
+ %5 = addf %arg0, %arg1 : f32
+ linalg.yield %5 : f32
+ }
+ return
+ }
+ func @matmul_add_fused__num_workgroups__
+ (!shapex.ranked_shape<[4096, 4096]>, !shapex.ranked_shape<[4096, 4096]>,
+ !shapex.ranked_shape<[4096, 4096]>) -> (index, index, index)
+ attributes {sym_visibility = "private"}
+ hal.interface @legacy_io attributes {sym_visibility = "private"} {
+ hal.interface.binding @arg0, set=0, binding=0, type="StorageBuffer", access="Read"
+ hal.interface.binding @arg1, set=0, binding=1, type="StorageBuffer", access="Read"
+ hal.interface.binding @arg2, set=0, binding=2, type="StorageBuffer", access="Read"
+ hal.interface.binding @ret0, set=0, binding=3, type="StorageBuffer", access="Write|Discard"
+ }
+}
+
+// CHECK-LABEL: spv.func @matmul_add_fused
+// CHECK-NOT: spv.Store "StorageBuffer"
+// CHECK-NOT: spv.Load "StorageBuffer"
+// CHECK: spv.loop
+// CHECK-COUNT-12: spv.Load "StorageBuffer" %{{.*}} : vector<4xf32>
+// CHECK-COUNT-32: spv.FMul %{{.*}}, %{{.*}} : vector<4xf32>
+// CHECK: spv.mlir.merge
+// CHECK-COUNT-8: spv.Load "StorageBuffer" %{{.*}} : vector<4xf32>
+// CHECK-NOT: spv.Load "StorageBuffer"
+// CHECK-NOT: spv.Store "StorageBuffer"
+// CHECK-COUNT-8: spv.FAdd %{{.*}}, %{{.*}} : vector<4xf32>
+// CHECK-COUNT-8: spv.Store "StorageBuffer" %{{.*}}, %{{.*}} : vector<4xf32>
diff --git a/iree/compiler/Conversion/init_conversions.h b/iree/compiler/Conversion/init_conversions.h
index 05bcdc4..6157e00 100644
--- a/iree/compiler/Conversion/init_conversions.h
+++ b/iree/compiler/Conversion/init_conversions.h
@@ -17,6 +17,7 @@
#include "iree/compiler/Conversion/CodegenUtils/ForOpCanonicalization.h"
#include "iree/compiler/Conversion/HLOToHLO/Passes.h"
+#include "iree/compiler/Conversion/HLOToLinalg/HLOToLinalgOnTensorPasses.h"
#include "iree/compiler/Conversion/HLOToLinalg/Passes.h"
#include "iree/compiler/Conversion/LinalgToLLVM/Passes.h"
#include "iree/compiler/Conversion/LinalgToSPIRV/Passes.h"
diff --git a/iree/compiler/Dialect/Flow/IR/FlowOps.cpp b/iree/compiler/Dialect/Flow/IR/FlowOps.cpp
index cc4e1eb..53b3703 100644
--- a/iree/compiler/Dialect/Flow/IR/FlowOps.cpp
+++ b/iree/compiler/Dialect/Flow/IR/FlowOps.cpp
@@ -17,6 +17,7 @@
#include "iree/compiler/Dialect/Flow/IR/FlowOpUtils.h"
#include "iree/compiler/Dialect/IREE/IR/IREETypes.h"
#include "llvm/ADT/StringExtras.h"
+#include "mlir/Dialect/StandardOps/IR/Ops.h"
#include "mlir/IR/Attributes.h"
#include "mlir/IR/BlockAndValueMapping.h"
#include "mlir/IR/Builders.h"
@@ -28,6 +29,7 @@
#include "mlir/IR/SymbolTable.h"
#include "mlir/IR/TypeUtilities.h"
#include "mlir/Support/LogicalResult.h"
+#include "mlir/Transforms/RegionUtils.h"
namespace mlir {
namespace iree_compiler {
@@ -277,27 +279,45 @@
// flow.dispatch.region
//===----------------------------------------------------------------------===//
+/// Inlines operation |op| into |dispatchRegionOp|. Operands of |op|, and values
+/// captured implicitly by its regions, that are not already in |map| are added
+/// as operands of the dispatch region and mapped to new block arguments.
+static Operation *inlineOpIntoDispatchRegion(OpBuilder &builder,
+ DispatchRegionOp dispatchRegionOp,
+ Operation *op,
+ BlockAndValueMapping &map) {
+ llvm::SetVector<Value> capturedInputs(op->getOperands().begin(),
+ op->getOperands().end());
+ getUsedValuesDefinedAbove(op->getRegions(), capturedInputs);
+ Block *block = builder.getInsertionBlock();
+ for (Value capturedInput : capturedInputs) {
+ if (map.contains(capturedInput)) continue;
+ dispatchRegionOp.getOperation()->insertOperands(
+ dispatchRegionOp.getOperation()->getNumOperands(), {capturedInput});
+ Value newBlockArgument = block->addArgument(capturedInput.getType());
+ map.map(capturedInput, newBlockArgument);
+ }
+
+ return builder.clone(*op, map);
+}
+
llvm::Optional<std::pair<DispatchRegionOp, Operation *>>
DispatchRegionOp::formFromAnchorOp(Value workload, Operation *anchorOp,
OpBuilder &builder) {
builder.setInsertionPoint(anchorOp);
auto loc = anchorOp->getLoc();
// Map anchor into new dispatch region.
- llvm::SmallVector<Value, 4> capturedInputs(anchorOp->getOperands());
auto drOp = builder.create<DispatchRegionOp>(
loc, llvm::to_vector<1>(anchorOp->getResultTypes()), workload,
- capturedInputs);
+ ArrayRef<Value>());
auto *drBlock = new Block();
drOp.body().push_back(drBlock);
BlockAndValueMapping mapping;
- for (Value capturedInput : capturedInputs) {
- auto blockArg = drBlock->addArgument(capturedInput.getType());
- mapping.map(capturedInput, blockArg);
- }
-
- // Create new body.
builder.setInsertionPointToEnd(drBlock);
- auto *newAnchorOp = builder.clone(*anchorOp, mapping);
+ Operation *newAnchorOp =
+ inlineOpIntoDispatchRegion(builder, drOp, anchorOp, mapping);
+
+ // Insert terminator
builder.create<IREE::Flow::ReturnOp>(loc, newAnchorOp->getResults());
// Replace anchor uses with region result.
@@ -366,16 +386,8 @@
origOpResultValues.push_back(mapping.lookupOrNull(result));
}
- // Add arguments for any op arguments that need to be captured.
- for (Value newArgument : origOp->getOperands()) {
- if (mapping.contains(newArgument)) continue;
- getOperation()->insertOperands(getNumOperands(), {newArgument});
- Value newBlockArgument = block.addArgument(newArgument.getType());
- mapping.map(newArgument, newBlockArgument);
- }
-
- // Clone the op.
- Operation *inlinedOp = builder.clone(*origOp, mapping);
+ Operation *inlinedOp =
+ inlineOpIntoDispatchRegion(builder, *this, origOp, mapping);
// Replace any results from the orig with results from the clone.
for (unsigned i = 0, e = origOp->getNumResults(); i < e; ++i) {
diff --git a/iree/compiler/Dialect/Flow/Transforms/BUILD b/iree/compiler/Dialect/Flow/Transforms/BUILD
index 0a7d576..b1121e9 100644
--- a/iree/compiler/Dialect/Flow/Transforms/BUILD
+++ b/iree/compiler/Dialect/Flow/Transforms/BUILD
@@ -48,6 +48,7 @@
],
deps = [
"//iree/base:signature_mangle",
+ "//iree/compiler/Conversion/HLOToLinalg:HLOToLinalgOnTensors",
"//iree/compiler/Dialect/Flow/Analysis",
"//iree/compiler/Dialect/Flow/Conversion",
"//iree/compiler/Dialect/Flow/Conversion/HLOToFlow",
@@ -61,6 +62,7 @@
"//iree/compiler/Utils",
"@llvm-project//llvm:Support",
"@llvm-project//mlir:IR",
+ "@llvm-project//mlir:LinalgOps",
"@llvm-project//mlir:Pass",
"@llvm-project//mlir:Shape",
"@llvm-project//mlir:ShapeTransforms",
diff --git a/iree/compiler/Dialect/Flow/Transforms/CMakeLists.txt b/iree/compiler/Dialect/Flow/Transforms/CMakeLists.txt
index 9fd3368..88e17ed 100644
--- a/iree/compiler/Dialect/Flow/Transforms/CMakeLists.txt
+++ b/iree/compiler/Dialect/Flow/Transforms/CMakeLists.txt
@@ -44,6 +44,7 @@
DEPS
LLVMSupport
MLIRIR
+ MLIRLinalg
MLIRPass
MLIRShape
MLIRShapeOpsTransforms
@@ -52,6 +53,7 @@
MLIRTransformUtils
MLIRTransforms
iree::base::signature_mangle
+ iree::compiler::Conversion::HLOToLinalg::HLOToLinalgOnTensors
iree::compiler::Dialect::Flow::Analysis
iree::compiler::Dialect::Flow::Conversion
iree::compiler::Dialect::Flow::Conversion::HLOToFlow
diff --git a/iree/compiler/Dialect/Flow/Transforms/DispatchConfig.cpp b/iree/compiler/Dialect/Flow/Transforms/DispatchConfig.cpp
index 1c59969..3581d9b 100644
--- a/iree/compiler/Dialect/Flow/Transforms/DispatchConfig.cpp
+++ b/iree/compiler/Dialect/Flow/Transforms/DispatchConfig.cpp
@@ -20,6 +20,7 @@
#include "llvm/Support/CommandLine.h"
#include "llvm/Support/Debug.h"
#include "mlir-hlo/Dialect/mhlo/IR/hlo_ops.h"
+#include "mlir/Dialect/Linalg/IR/LinalgOps.h"
#include "mlir/Dialect/StandardOps/IR/Ops.h"
#define DEBUG_TYPE "iree-detail"
@@ -129,7 +130,7 @@
}
bool OpDispatchPolicy::isViewModificationOp(Operation *op) {
- return isa<mhlo::ReshapeOp>(op);
+ return isa<mhlo::ReshapeOp, linalg::TensorReshapeOp>(op);
}
int OpDispatchPolicy::getAnchorBenefit(Operation *op) {
@@ -221,8 +222,9 @@
// TODO(b/144530470): replace with tablegen attributes/interfaces.
bool OpDispatchPolicy::isUnsupportedFusionOp(Operation *op) {
- return isa<mhlo::ConcatenateOp, mhlo::ConvOp, mhlo::PadOp, mhlo::ReduceOp,
- mhlo::ReduceWindowOp, mhlo::TorchIndexSelectOp>(op) ||
+ return isa<linalg::IndexedGenericOp, linalg::GenericOp, mhlo::ConcatenateOp,
+ mhlo::ConvOp, mhlo::PadOp, mhlo::ReduceOp, mhlo::ReduceWindowOp,
+ mhlo::TorchIndexSelectOp>(op) ||
(!clEnableConsumerOnlyFusion &&
isa<mhlo::DotOp, mhlo::DotGeneralOp>(op)) ||
isRootOnlyOp(op);
diff --git a/iree/compiler/Dialect/Flow/Transforms/IdentifyDispatchRegions2.cpp b/iree/compiler/Dialect/Flow/Transforms/IdentifyDispatchRegions2.cpp
index f604b39..73e6b6e 100644
--- a/iree/compiler/Dialect/Flow/Transforms/IdentifyDispatchRegions2.cpp
+++ b/iree/compiler/Dialect/Flow/Transforms/IdentifyDispatchRegions2.cpp
@@ -20,10 +20,12 @@
#include "iree/compiler/Dialect/Flow/Transforms/DispatchConfig.h"
#include "iree/compiler/Dialect/Flow/Utils/WorkloadUtils.h"
#include "llvm/ADT/MapVector.h"
+#include "llvm/ADT/SetVector.h"
#include "llvm/Support/Debug.h"
#include "mlir/IR/BlockAndValueMapping.h"
#include "mlir/IR/Builders.h"
#include "mlir/Pass/Pass.h"
+#include "mlir/Transforms/RegionUtils.h"
#define DEBUG_TYPE "iree-dispatch"
@@ -271,6 +273,18 @@
}
}
}
+ // Check for values that are used in the region of the op but captured from
+ // outside the region.
+ llvm::SetVector<Value> capturedValues;
+ getUsedValuesDefinedAbove(inlinedOp->getRegions(), capturedValues);
+ for (Value capturedValue : capturedValues) {
+ if (Operation *definingOp = capturedValue.getDefiningOp()) {
+ if (!lastOperandDef ||
+ lastOperandDef.getValue()->isBeforeInBlock(definingOp)) {
+ lastOperandDef = definingOp;
+ }
+ }
+ }
// If the last operand def is already before the dispatch region, there is
// nothing to do.
if (!lastOperandDef ||
diff --git a/iree/compiler/Dialect/Flow/Transforms/Passes.cpp b/iree/compiler/Dialect/Flow/Transforms/Passes.cpp
index d012ba5..6732b18 100644
--- a/iree/compiler/Dialect/Flow/Transforms/Passes.cpp
+++ b/iree/compiler/Dialect/Flow/Transforms/Passes.cpp
@@ -16,13 +16,21 @@
#include <memory>
+#include "iree/compiler/Conversion/HLOToLinalg/HLOToLinalgOnTensorPasses.h"
#include "iree/compiler/Dialect/Shape/Conversion/Passes.h"
#include "iree/compiler/Dialect/Shape/Transforms/Passes.h"
#include "mlir-hlo/Dialect/mhlo/transforms/passes.h"
#include "mlir/Dialect/Shape/Transforms/Passes.h"
+#include "mlir/Pass/PassOptions.h"
#include "mlir/Pass/PassRegistry.h"
#include "mlir/Transforms/Passes.h"
+static llvm::cl::opt<bool> clEnableLinalgOnTensors(
+ "iree-enable-linalg-on-tensors",
+ llvm::cl::desc(
+ "Enable use of Linalg on tensors for dispatch region creation"),
+ llvm::cl::init(false));
+
namespace mlir {
namespace iree_compiler {
namespace IREE {
@@ -150,6 +158,10 @@
passManager.addNestedPass<FuncOp>(
IREE::Flow::createPrePartitioningConversionPass());
+ if (clEnableLinalgOnTensors) {
+ addHLOToLinalgOnTensorsPasses(passManager);
+ }
+
// First perform module-level analysis that following passes will use to query
// per-function dispatchability information. We run this first so that it only
// needs to run once and will be cached for all of the following passes.
diff --git a/iree/compiler/Dialect/Flow/Transforms/RematerializeDispatchConstants.cpp b/iree/compiler/Dialect/Flow/Transforms/RematerializeDispatchConstants.cpp
index 431f0e3..79cd715 100644
--- a/iree/compiler/Dialect/Flow/Transforms/RematerializeDispatchConstants.cpp
+++ b/iree/compiler/Dialect/Flow/Transforms/RematerializeDispatchConstants.cpp
@@ -46,6 +46,8 @@
return true;
} else if (auto value = constantOp.getValue().dyn_cast<DenseElementsAttr>()) {
return value.isSplat();
+ } else if (constantOp.getType().isIntOrFloat()) {
+ return true;
}
// Assume anything unshaped is small. This may not always be true in custom
diff --git a/iree/compiler/Dialect/Flow/Transforms/test/identify_dispatch_regions2_linalg.mlir b/iree/compiler/Dialect/Flow/Transforms/test/identify_dispatch_regions2_linalg.mlir
new file mode 100644
index 0000000..2c82ed7
--- /dev/null
+++ b/iree/compiler/Dialect/Flow/Transforms/test/identify_dispatch_regions2_linalg.mlir
@@ -0,0 +1,50 @@
+// RUN: iree-opt -split-input-file -iree-flow-dispatchability-analysis -iree-flow-identify-dispatch-regions2 %s | IreeFileCheck %s
+
+func @constant_capture(%arg0 : tensor<10x20xf32>) -> tensor<10x20xf32> {
+ %cst1 = constant 1.0 : f32
+ %cst2 = constant dense<2.0> : tensor<10x20xf32>
+ %cst3 = constant dense<
+ [1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0, 9.0, 10.0]> : tensor<10xf32>
+ %0 = linalg.generic
+ {indexing_maps = [affine_map<(d0, d1) -> (d0, d1)>,
+ affine_map<(d0, d1) -> (d0, d1)>,
+ affine_map<(d0, d1) -> (d0)>,
+ affine_map<(d0, d1) -> (d0, d1)>],
+ iterator_types = ["parallel", "parallel"]}
+ ins(%arg0, %cst2, %cst3
+ : tensor<10x20xf32>, tensor<10x20xf32>, tensor<10xf32>) {
+ ^bb0(%arg1 : f32, %arg2 : f32, %arg3 : f32):
+ %1 = addf %arg1, %cst1 : f32
+ %2 = mulf %1, %arg2 : f32
+ %3 = addf %2, %arg3 : f32
+ linalg.yield %3 : f32
+ } -> tensor<10x20xf32>
+ return %0 : tensor<10x20xf32>
+}
+// CHECK: func @constant_capture
+// CHECK-SAME: %[[ARG0:[a-zA-Z0-9_]+]]: tensor<10x20xf32>
+// CHECK-DAG: %[[CST1:.+]] = constant 1.000000e+00 : f32
+// CHECK-DAG: %[[CST2:.+]] = constant dense<2.000000e+00> : tensor<10x20xf32>
+// CHECK-DAG: %[[CST3:.+]] = constant dense<[1.000000e+00, 2.000000e+00,
+// CHECK-SAME: 3.000000e+00, 4.000000e+00, 5.000000e+00, 6.000000e+00,
+// CHECK-SAME: 7.000000e+00, 8.000000e+00, 9.000000e+00, 1.000000e+01]>
+// CHECK: %[[RESULT:.+]] = flow.dispatch.region[%{{.+}} : index](
+// CHECK-SAME: %[[ARG1:[a-zA-Z0-9_]+]] = %[[ARG0]]
+// CHECK-SAME: %[[ARG2:[a-zA-Z0-9_]+]] = %[[CST2]]
+// CHECK-SAME: %[[ARG3:[a-zA-Z0-9_]+]] = %[[CST3]]
+// CHECK-SAME: %[[ARG4:[a-zA-Z0-9_]+]] = %[[CST1]]
+// CHECK-SAME: ) -> tensor<10x20xf32> {
+// CHECK: %[[RETURN:.+]] = linalg.generic
+// CHECK-SAME: ins(%[[ARG1]], %[[ARG2]], %[[ARG3]]
+// CHECK-SAME: ) {
+// CHECK-NEXT: ^{{[a-zA-Z0-9]+}}(
+// CHECK-SAME: %[[ARG5:.[a-zA-Z0-9_]+]]: f32
+// CHECK-SAME: %[[ARG6:.[a-zA-Z0-9_]+]]: f32
+// CHECK-SAME: %[[ARG7:.[a-zA-Z0-9_]+]]: f32)
+// CHECK: %[[T0:.+]] = addf %[[ARG5]], %[[ARG4]]
+// CHECK: %[[T1:.+]] = mulf %[[T0]], %[[ARG6]]
+// CHECK: %[[T2:.+]] = addf %[[T1]], %[[ARG7]]
+// CHECK: linalg.yield %[[T2]]
+// CHECK: }
+// CHECK: flow.return %[[RETURN]]
+// CHECK: }
diff --git a/iree/compiler/Dialect/Flow/Transforms/test/rematerialize_dispatch_constants.mlir b/iree/compiler/Dialect/Flow/Transforms/test/rematerialize_dispatch_constants.mlir
index 45a0195..0562442 100644
--- a/iree/compiler/Dialect/Flow/Transforms/test/rematerialize_dispatch_constants.mlir
+++ b/iree/compiler/Dialect/Flow/Transforms/test/rematerialize_dispatch_constants.mlir
@@ -65,3 +65,59 @@
}
return %0 : tensor<4x4xf32>
}
+
+// -----
+
+func @constant_capture(%arg0: tensor<10x20xf32>) -> tensor<10x20xf32> {
+ %c200 = constant 200 : index
+ %cst = constant 1.000000e+00 : f32
+ %cst_0 = constant dense<2.000000e+00> : tensor<10x20xf32>
+ %cst_1 = constant dense<
+ [1.000000e+00, 2.000000e+00, 3.000000e+00, 4.000000e+00, 5.000000e+00,
+ 6.000000e+00, 7.000000e+00, 8.000000e+00, 9.000000e+00, 1.000000e+01]>
+ : tensor<10xf32>
+ %0 = flow.dispatch.region[%c200 : index]
+ (%arg1 = %arg0 : tensor<10x20xf32>, %arg2 = %cst_0 : tensor<10x20xf32>,
+ %arg3 = %cst_1 : tensor<10xf32>, %arg4 = %cst : f32) -> tensor<10x20xf32> {
+ %1 = linalg.generic
+ {indexing_maps = [affine_map<(d0, d1) -> (d0, d1)>,
+ affine_map<(d0, d1) -> (d0, d1)>,
+ affine_map<(d0, d1) -> (d0)>,
+ affine_map<(d0, d1) -> (d0, d1)>],
+ iterator_types = ["parallel", "parallel"]}
+ ins(%arg1, %arg2, %arg3
+ : tensor<10x20xf32>, tensor<10x20xf32>, tensor<10xf32>) {
+ ^bb0(%arg5: f32, %arg6: f32, %arg7: f32): // no predecessors
+ %2 = addf %arg5, %arg4 : f32
+ %3 = mulf %2, %arg6 : f32
+ %4 = addf %3, %arg7 : f32
+ linalg.yield %4 : f32
+ } -> tensor<10x20xf32>
+ flow.return %1 : tensor<10x20xf32>
+ }
+ return %0 : tensor<10x20xf32>
+}
+
+// CHECK-LABEL: func @constant_capture
+// CHECK-SAME: %[[ARG0:[a-zA-Z0-9_]+]]: tensor<10x20xf32>
+// CHECK: %[[CST:.+]] = constant dense<[1.000000e+00, 2.000000e+00,
+// CHECK-SAME: 3.000000e+00, 4.000000e+00, 5.000000e+00, 6.000000e+00,
+// CHECK-SAME: 7.000000e+00, 8.000000e+00, 9.000000e+00, 1.000000e+01]>
+// CHECK: flow.dispatch.region
+// CHECK-SAME: %[[ARG1:[a-zA-Z0-9_]+]] = %[[ARG0]]
+// CHECK-SAME: %[[ARG2:[a-zA-Z0-9_]+]] = %[[CST]]
+// CHECK-DAG: %[[CST_0:.+]] = constant 1.000000e+00 : f32
+// CHECK-DAG: %[[CST_1:.+]] = constant dense<2.000000e+00> : tensor<10x20xf32>
+// CHECK: %[[RESULT:.+]] = linalg.generic
+// CHECK-SAME: ins(%[[ARG1]], %[[CST_1]], %[[ARG2]]
+// CHECK-SAME: ) {
+// CHECK: ^{{[a-zA-Z0-9_]+}}(
+// CHECK-SAME: %[[ARG3:.[a-zA-Z0-9_]+]]: f32
+// CHECK-SAME: %[[ARG4:.[a-zA-Z0-9_]+]]: f32
+// CHECK-SAME: %[[ARG5:.[a-zA-Z0-9_]+]]: f32)
+// CHECK: %[[T0:.+]] = addf %[[ARG3]], %[[CST_0]]
+// CHECK: %[[T1:.+]] = mulf %[[T0]], %[[ARG4]]
+// CHECK: %[[T2:.+]] = addf %[[T1]], %[[ARG5]]
+// CHECK: linalg.yield %[[T2]]
+// CHECK: }
+// CHECK: flow.return %[[RESULT]]
diff --git a/iree/compiler/Dialect/Flow/Utils/DispatchUtils.cpp b/iree/compiler/Dialect/Flow/Utils/DispatchUtils.cpp
index 162c682..7fc8240 100644
--- a/iree/compiler/Dialect/Flow/Utils/DispatchUtils.cpp
+++ b/iree/compiler/Dialect/Flow/Utils/DispatchUtils.cpp
@@ -33,11 +33,11 @@
// TODO(benvanik): replace with op dispatchability interface to allow dialects
// to opt into dispatch.
auto dialectNamespace = op->getDialect()->getNamespace();
- return dialectNamespace == mhlo::MhloDialect::getDialectNamespace() ||
+ return dialectNamespace == FlowDialect::getDialectNamespace() ||
+ dialectNamespace == linalg::LinalgDialect::getDialectNamespace() ||
+ dialectNamespace == mhlo::MhloDialect::getDialectNamespace() ||
dialectNamespace == mlir::StandardOpsDialect::getDialectNamespace() ||
- dialectNamespace == FlowDialect::getDialectNamespace() ||
- dialectNamespace == ShapeDialect::getDialectNamespace() ||
- dialectNamespace == linalg::LinalgDialect::getDialectNamespace();
+ dialectNamespace == ShapeDialect::getDialectNamespace();
}
namespace {
diff --git a/iree/compiler/Dialect/HAL/IR/HALOps.cpp b/iree/compiler/Dialect/HAL/IR/HALOps.cpp
index 463b49e..9bf155b 100644
--- a/iree/compiler/Dialect/HAL/IR/HALOps.cpp
+++ b/iree/compiler/Dialect/HAL/IR/HALOps.cpp
@@ -1356,8 +1356,9 @@
}
OptionalParseResult parseResult = parser.parseOptionalRegion(*body);
- if (parseResult.hasValue() && failed(*parseResult))
+ if (parseResult.hasValue() && failed(*parseResult)) {
return failure();
+ }
// Ensure that this module has a valid terminator.
ExecutableTargetOp::ensureTerminator(*body, parser.getBuilder(),
@@ -1410,8 +1411,9 @@
return failure();
}
OptionalParseResult parseResult = parser.parseOptionalRegion(*body);
- if (parseResult.hasValue() && failed(*parseResult))
+ if (parseResult.hasValue() && failed(*parseResult)) {
return failure();
+ }
// Ensure that this module has a valid terminator.
ExecutableBinaryOp::ensureTerminator(*body, parser.getBuilder(),
diff --git a/iree/compiler/Dialect/VM/Target/Bytecode/ConstantEncoder.cpp b/iree/compiler/Dialect/VM/Target/Bytecode/ConstantEncoder.cpp
index 64e1079..faf41fe 100644
--- a/iree/compiler/Dialect/VM/Target/Bytecode/ConstantEncoder.cpp
+++ b/iree/compiler/Dialect/VM/Target/Bytecode/ConstantEncoder.cpp
@@ -106,6 +106,19 @@
return flatbuffers_uint8_vec_end(fbb);
}
+static flatbuffers_uint8_vec_ref_t serializeConstantF16Array(
+ DenseFPElementsAttr attr, FlatbufferBuilder &fbb) {
+ flatbuffers_uint8_vec_start(fbb);
+ uint8_t *bytePtr = flatbuffers_uint8_vec_extend(
+ fbb, attr.getNumElements() * sizeof(uint16_t));
+ uint16_t *nativePtr = reinterpret_cast<uint16_t *>(bytePtr);
+ for (const APFloat &value : attr.getFloatValues()) {
+ *(nativePtr++) =
+ value.bitcastToAPInt().extractBitsAsZExtValue(16, 0) & UINT16_MAX;
+ }
+ return flatbuffers_uint8_vec_end(fbb);
+}
+
flatbuffers_uint8_vec_ref_t serializeConstant(Location loc,
ElementsAttr elementsAttr,
FlatbufferBuilder &fbb) {
@@ -126,6 +139,8 @@
}
} else if (auto attr = elementsAttr.dyn_cast<DenseFPElementsAttr>()) {
switch (attr.getType().getElementTypeBitWidth()) {
+ case 16:
+ return serializeConstantF16Array(attr, fbb);
case 32:
return serializeConstantF32Array(attr, fbb);
case 64:
diff --git a/iree/compiler/Dialect/VM/Target/Bytecode/test/constant_encoding.mlir b/iree/compiler/Dialect/VM/Target/Bytecode/test/constant_encoding.mlir
index fc4d1a2..78298bb 100644
--- a/iree/compiler/Dialect/VM/Target/Bytecode/test/constant_encoding.mlir
+++ b/iree/compiler/Dialect/VM/Target/Bytecode/test/constant_encoding.mlir
@@ -47,4 +47,15 @@
// CHECK-NEXT: 63
// CHECK-NEXT: ]
vm.rodata @splat_float32s dense<1.000000e+00> : tensor<3xf32>
+
+ // CHECK: "data": [
+ // CHECK-NEXT: 0,
+ // CHECK-NEXT: 60,
+ // CHECK-NEXT: 0,
+ // CHECK-NEXT: 64,
+ // CHECK-NEXT: 0,
+ // CHECK-NEXT: 66
+ // CHECK-NEXT: ]
+ vm.rodata @dense_float16s dense<[1.000000e+00, 2.000000e+00, 3.000000e+00]> : tensor<3xf16>
+
}
diff --git a/iree/test/e2e/models/edge_detection.mlir b/iree/test/e2e/models/edge_detection.mlir
index bb8b47a..c290e84 100644
--- a/iree/test/e2e/models/edge_detection.mlir
+++ b/iree/test/e2e/models/edge_detection.mlir
@@ -1,6 +1,8 @@
// RUN: iree-run-mlir -export-all -iree-hal-target-backends=vmla %s -function-input="1x128x128x1xf32" | IreeFileCheck %s
// RUN: [[ $IREE_LLVMJIT_DISABLE == 1 ]] || (iree-run-mlir -export-all -iree-hal-target-backends=llvm-ir %s -function-input="1x128x128x1xf32" | IreeFileCheck %s)
// RUN: [[ $IREE_VULKAN_DISABLE == 1 ]] || (iree-run-mlir -export-all -iree-hal-target-backends=vulkan-spirv %s -function-input="1x128x128x1xf32" | IreeFileCheck %s)
+// RUN: [[ $IREE_LLVMJIT_DISABLE == 1 ]] || (iree-run-mlir -export-all -iree-hal-target-backends=llvm-ir -iree-enable-linalg-on-tensors %s -function-input="1x128x128x1xf32" | IreeFileCheck %s)
+// RUN: [[ $IREE_VULKAN_DISABLE == 1 ]] || (iree-run-mlir -export-all -iree-hal-target-backends=vulkan-spirv -iree-enable-linalg-on-tensors %s -function-input="1x128x128x1xf32" | IreeFileCheck %s)
// Image edge detection module generated by iree/colab/edge_detection.ipynb.
//
diff --git a/iree/test/e2e/models/fragment_000.mlir b/iree/test/e2e/models/fragment_000.mlir
index dd9f902..8328f86 100644
--- a/iree/test/e2e/models/fragment_000.mlir
+++ b/iree/test/e2e/models/fragment_000.mlir
@@ -1,6 +1,8 @@
// RUN: iree-run-mlir -export-all -iree-hal-target-backends=vmla %s -function-input="f32=0" -function-input="5x1xf32=[1][-2][-3][4][-5]" -function-input="f32=1" -function-input="5x5xf32=[3.46499 -7.64389 -5.72249 5.98053 17.6892][2.9707 -6.20734 -4.25962 4.76055 13.8784][2.47641 -4.77079 -2.79675 3.54056 10.0675][1.98212 -3.33424 -1.33388 2.32058 6.25666][1.48783 -1.8977 0.12899 1.1006 2.4458]" -function-input="5xf32=0 0 0 0 0" | IreeFileCheck %s --implicit-check-not="[" --implicit-check-not="]"
// RUN: [[ $IREE_LLVMJIT_DISABLE == 1 ]] || (iree-run-mlir -export-all -iree-hal-target-backends=llvm-ir %s -function-input="f32=0" -function-input="5x1xf32=[1][-2][-3][4][-5]" -function-input="f32=1" -function-input="5x5xf32=[3.46499 -7.64389 -5.72249 5.98053 17.6892][2.9707 -6.20734 -4.25962 4.76055 13.8784][2.47641 -4.77079 -2.79675 3.54056 10.0675][1.98212 -3.33424 -1.33388 2.32058 6.25666][1.48783 -1.8977 0.12899 1.1006 2.4458]" -function-input="5xf32=0 0 0 0 0" | IreeFileCheck %s --implicit-check-not="[" --implicit-check-not="]")
// RUN: [[ $IREE_VULKAN_DISABLE == 1 ]] || (iree-run-mlir -export-all -iree-hal-target-backends=vulkan-spirv %s -function-input="f32=0" -function-input="5x1xf32=[1][-2][-3][4][-5]" -function-input="f32=1" -function-input="5x5xf32=[3.46499 -7.64389 -5.72249 5.98053 17.6892][2.9707 -6.20734 -4.25962 4.76055 13.8784][2.47641 -4.77079 -2.79675 3.54056 10.0675][1.98212 -3.33424 -1.33388 2.32058 6.25666][1.48783 -1.8977 0.12899 1.1006 2.4458]" -function-input="5xf32=0 0 0 0 0" | IreeFileCheck %s --implicit-check-not="[" --implicit-check-not="]")
+// RUN: [[ $IREE_LLVMJIT_DISABLE == 1 ]] || (iree-run-mlir -export-all -iree-hal-target-backends=llvm-ir -iree-enable-linalg-on-tensors %s -function-input="f32=0" -function-input="5x1xf32=[1][-2][-3][4][-5]" -function-input="f32=1" -function-input="5x5xf32=[3.46499 -7.64389 -5.72249 5.98053 17.6892][2.9707 -6.20734 -4.25962 4.76055 13.8784][2.47641 -4.77079 -2.79675 3.54056 10.0675][1.98212 -3.33424 -1.33388 2.32058 6.25666][1.48783 -1.8977 0.12899 1.1006 2.4458]" -function-input="5xf32=0 0 0 0 0" | IreeFileCheck %s --implicit-check-not="[" --implicit-check-not="]")
+// RUN: [[ $IREE_VULKAN_DISABLE == 1 ]] || (iree-run-mlir -export-all -iree-hal-target-backends=vulkan-spirv -iree-enable-linalg-on-tensors %s -function-input="f32=0" -function-input="5x1xf32=[1][-2][-3][4][-5]" -function-input="f32=1" -function-input="5x5xf32=[3.46499 -7.64389 -5.72249 5.98053 17.6892][2.9707 -6.20734 -4.25962 4.76055 13.8784][2.47641 -4.77079 -2.79675 3.54056 10.0675][1.98212 -3.33424 -1.33388 2.32058 6.25666][1.48783 -1.8977 0.12899 1.1006 2.4458]" -function-input="5xf32=0 0 0 0 0" | IreeFileCheck %s --implicit-check-not="[" --implicit-check-not="]")
// CHECK-LABEL: EXEC @main_entry_dispatch_3
func @main_entry_dispatch_3(
diff --git a/iree/test/e2e/models/fullyconnected.mlir b/iree/test/e2e/models/fullyconnected.mlir
index b2ba5f4..ab9c466 100644
--- a/iree/test/e2e/models/fullyconnected.mlir
+++ b/iree/test/e2e/models/fullyconnected.mlir
@@ -1,6 +1,8 @@
// RUN: iree-run-mlir -export-all %s -iree-hal-target-backends=vmla -function-input="1x5xf32=1,-2,-3,4,-5" -function-input="1x5x3x1xf32=15,14,13,12,11,10,9,8,7,6,5,4,3,2,1" | IreeFileCheck %s
// RUN: [[ $IREE_LLVMJIT_DISABLE == 1 ]] || (iree-run-mlir -export-all %s -iree-hal-target-backends=llvm-ir -function-input="1x5xf32=1,-2,-3,4,-5" -function-input="1x5x3x1xf32=15,14,13,12,11,10,9,8,7,6,5,4,3,2,1" -iree-enable-consumer-only-fusion | IreeFileCheck %s)
// RUN: [[ $IREE_VULKAN_DISABLE == 1 ]] || (iree-run-mlir -export-all %s -iree-hal-target-backends=vulkan-spirv -function-input="1x5xf32=1,-2,-3,4,-5" -function-input="1x5x3x1xf32=15,14,13,12,11,10,9,8,7,6,5,4,3,2,1" | IreeFileCheck %s)
+// RUN: [[ $IREE_LLVMJIT_DISABLE == 1 ]] || (iree-run-mlir -export-all %s -iree-hal-target-backends=llvm-ir -function-input="1x5xf32=1,-2,-3,4,-5" -function-input="1x5x3x1xf32=15,14,13,12,11,10,9,8,7,6,5,4,3,2,1" -iree-enable-linalg-on-tensors | IreeFileCheck %s)
+// RUN: [[ $IREE_VULKAN_DISABLE == 1 ]] || (iree-run-mlir -export-all %s -iree-hal-target-backends=vulkan-spirv -function-input="1x5xf32=1,-2,-3,4,-5" -function-input="1x5x3x1xf32=15,14,13,12,11,10,9,8,7,6,5,4,3,2,1" -iree-enable-linalg-on-tensors | IreeFileCheck %s)
// CHECK-LABEL: EXEC @main
func @main(%arg0: tensor<1x5xf32>, %arg1: tensor<1x5x3x1xf32>) -> tuple<tensor<5x1x5xf32>>
diff --git a/iree/test/e2e/models/mnist_fake_weights.mlir b/iree/test/e2e/models/mnist_fake_weights.mlir
index 8404532..380899d 100644
--- a/iree/test/e2e/models/mnist_fake_weights.mlir
+++ b/iree/test/e2e/models/mnist_fake_weights.mlir
@@ -3,6 +3,8 @@
// RUN: iree-run-mlir -export-all -iree-hal-target-backends=vmla %s -function-input="1x28x28x1xf32" | IreeFileCheck %s
// RUN: [[ $IREE_LLVMJIT_DISABLE == 1 ]] || (iree-run-mlir -export-all -iree-hal-target-backends=llvm-ir %s -function-input="1x28x28x1xf32" | IreeFileCheck %s)
// RUN: [[ $IREE_VULKAN_DISABLE == 1 ]] || (iree-run-mlir -export-all -iree-hal-target-backends=vulkan-spirv %s -function-input="1x28x28x1xf32" | IreeFileCheck %s)
+// RUN: [[ $IREE_LLVMJIT_DISABLE == 1 ]] || (iree-run-mlir -export-all -iree-hal-target-backends=llvm-ir -iree-enable-linalg-on-tensors %s -function-input="1x28x28x1xf32" | IreeFileCheck %s)
+// RUN: [[ $IREE_VULKAN_DISABLE == 1 ]] || (iree-run-mlir -export-all -iree-hal-target-backends=vulkan-spirv -iree-enable-linalg-on-tensors %s -function-input="1x28x28x1xf32" | IreeFileCheck %s)
module {
flow.variable @"__iree_flow___sm_node17__model.layer-1.kernel" dense<1.000000e+00> : tensor<784x128xf32> attributes {noinline, sym_visibility = "private"}
diff --git a/iree/test/e2e/models/unidirectional_lstm.mlir b/iree/test/e2e/models/unidirectional_lstm.mlir
index 510eb4a..302f79e 100644
--- a/iree/test/e2e/models/unidirectional_lstm.mlir
+++ b/iree/test/e2e/models/unidirectional_lstm.mlir
@@ -3,6 +3,8 @@
// RUN: iree-run-mlir %s -iree-hal-target-backends=vmla -function-input="1x5xf32=[0 1 0 3 4]" -function-input="1x5x2x2xf32=[1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20]" | IreeFileCheck %s --implicit-check-not="[" --implicit-check-not="]"
// RUN: [[ $IREE_LLVMJIT_DISABLE == 1 ]] || (iree-run-mlir %s -iree-hal-target-backends=llvm-ir -function-input="1x5xf32=[0 1 0 3 4]" -function-input="1x5x2x2xf32=[1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20]" | IreeFileCheck %s --implicit-check-not="[" --implicit-check-not="]")
// RUN: [[ $IREE_VULKAN_DISABLE == 1 ]] || (iree-run-mlir %s -iree-hal-target-backends=vulkan-spirv -function-input="1x5xf32=[0 1 0 3 4]" -function-input="1x5x2x2xf32=[1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20]" | IreeFileCheck %s --implicit-check-not="[" --implicit-check-not="]")
+// RUN: [[ $IREE_LLVMJIT_DISABLE == 1 ]] || (iree-run-mlir %s -iree-hal-target-backends=llvm-ir -function-input="1x5xf32=[0 1 0 3 4]" -function-input="1x5x2x2xf32=[1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20]" -iree-enable-linalg-on-tensors | IreeFileCheck %s --implicit-check-not="[" --implicit-check-not="]")
+// RUN: [[ $IREE_VULKAN_DISABLE == 1 ]] || (iree-run-mlir %s -iree-hal-target-backends=vulkan-spirv -function-input="1x5xf32=[0 1 0 3 4]" -function-input="1x5x2x2xf32=[1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20]" -iree-enable-linalg-on-tensors | IreeFileCheck %s --implicit-check-not="[" --implicit-check-not="]")
// Exported via the XLA HLO Importer
// The resulting MLIR was modified by hand by changing all large constants to be
diff --git a/iree/tools/run_lit.sh b/iree/tools/run_lit.sh
index 704c7cc..78cbf66 100755
--- a/iree/tools/run_lit.sh
+++ b/iree/tools/run_lit.sh
@@ -17,9 +17,16 @@
set -o pipefail
if [ -z "${RUNFILES_DIR}" ]; then
- # Some versions of bazel do not set RUNFILES_DIR. Instead they just cd
- # into the directory.
- RUNFILES_DIR="$PWD"
+ if [ -f "CMakeCache.txt" ]; then
+ # If running under CMake/CTest in the build directory, just scope to the
+ # iree directory to avoid blowing up the search through things like
+ # bazel out directories and the like.
+ RUNFILES_DIR="$PWD/iree"
+ else
+ # Some versions of bazel do not set RUNFILES_DIR. Instead they just cd
+ # into the directory.
+ RUNFILES_DIR="$PWD"
+ fi
fi
# Detect whether cygwin/msys2 paths need to be translated.
diff --git a/packaging/python/.gitignore b/packaging/python/.gitignore
deleted file mode 100644
index b357fe2..0000000
--- a/packaging/python/.gitignore
+++ /dev/null
@@ -1,4 +0,0 @@
-build/
-dist/
-*.egg-info
-
diff --git a/packaging/python/BUILD.bazel b/packaging/python/BUILD.bazel
deleted file mode 100644
index 4e48edd..0000000
--- a/packaging/python/BUILD.bazel
+++ /dev/null
@@ -1,39 +0,0 @@
-# Copyright 2020 Google LLC
-#
-# Licensed under the Apache License, Version 2.0 (the "License");
-# you may not use this file except in compliance with the License.
-# You may obtain a copy of the License at
-#
-# https://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-# limitations under the License.
-
-package(
- features = ["layering_check"],
- licenses = ["notice"],
-)
-
-# This is a dummy binary that has the side-effect of building all of the TF
-# python bindings. It is used to build wheel files.
-py_binary(
- name = "all_pyiree_packages",
- srcs = ["dummy_exclude_from_package.py"],
- legacy_create_init = False,
- main = "dummy_exclude_from_package.py",
- python_version = "PY3",
- tags = [
- # Do not build with ... expansion
- "manual",
- ],
- deps = [
- "//bindings/python:pathsetup", # build_cleaner: keep
- "//bindings/python/pyiree/compiler", # build_cleaner: keep
- "//bindings/python/pyiree/rt", # build_cleaner: keep
- "//integrations/tensorflow/bindings/python/pyiree/tf/compiler", # build_cleaner: keep
- "//integrations/tensorflow/bindings/python/pyiree/tf/support", # build_cleaner: keep
- ],
-)
diff --git a/packaging/python/README.md b/packaging/python/README.md
deleted file mode 100644
index 3d387a2..0000000
--- a/packaging/python/README.md
+++ /dev/null
@@ -1,87 +0,0 @@
-# Python packaging scripts.
-
-Note that packages will be placed in `packaging/python/dist` with the canonical
-instructions. However, the setup scripts can be run from anywhere and will
-create `build` and `dist` directories where run. Wheels can be installed with
-`pip3 install --user dist/*.whl`.
-
-## Building core wheels with CMake
-
-Most of IREE is built/packaged with CMake. For the parts that build with CMake,
-this is preferred.
-
-Canonical instructions follow:
-
-### Linux
-
-```shell
-export LDFLAGS=-fuse-ld=/usr/bin/ld.lld
-export PYIREE_CMAKE_BUILD_ROOT="${HOME?}/build-iree-release"
-export IREE_SRC="${HOME?}/src/iree"
-rm -Rf "${PYIREE_CMAKE_BUILD_ROOT?}"; mkdir -p "${PYIREE_CMAKE_BUILD_ROOT?}"
-cmake -GNinja -B"${PYIREE_CMAKE_BUILD_ROOT?}" -H"${IREE_SRC}" \
- -DCMAKE_BUILD_TYPE=Release \
- -DCMAKE_C_COMPILER=clang -DCMAKE_CXX_COMPILER=clang++ \
- -DIREE_BUILD_PYTHON_BINDINGS=ON -DIREE_BUILD_SAMPLES=OFF
-(cd "${PYIREE_CMAKE_BUILD_ROOT?}" && ninja)
-(cd "${IREE_SRC?}/packaging/python" && (
- rm -Rf build;
- python3 setup_compiler.py bdist_wheel;
- rm -Rf build;
- python3 setup_rt.py bdist_wheel))
-```
-
-## Building IREE/TensorFlow wheels
-
-If building TensorFlow integration wheels, then this must be done via Bazel. In
-this case, it can be easiest to just package everything from a Bazel build to
-avoid multiple steps.
-
-Canonical instructions follow:
-
-### Env Setup
-
-```shell
-IREE_SRC=$HOME/src/iree
-export PYIREE_BAZEL_BUILD_ROOT="$IREE_SRC/bazel-bin"
-if which cygpath; then
- export PYIREE_BAZEL_BUILD_ROOT="$(cygpath -w "$PYIREE_BAZEL_BUILD_ROOT")"
-fi
-```
-
-### Building:
-
-Optionally add: `--define=PYIREE_TF_DISABLE_KERNELS=1` to build a 'thin' (less
-functional) version without TensorFlow kernels. This should not be done for
-released binaries but can help while developing.
-
-Note that bazel does not always build properly named artifacts. See the tool
-`hack_python_package_from_runfiles.py` to extract and fixup artifacts from a
-bazel-bin directory. If using this mechanism, then the environment variable
-`PYIREE_PYTHON_ROOT` should be set to a suitable temp directory.
-
-```shell
-cd $IREE_SRC
-bazel build -c opt \
- //packaging/python:all_pyiree_packages
-```
-
-# Packaging
-
-```shell
-(cd $IREE_SRC/packaging/python && (
- rm -Rf build;
- python3 setup_tf.py bdist_wheel))
-```
-
-```shell
-(cd $IREE_SRC/packaging/python && (
- rm -Rf build;
- python3 setup_compiler.py bdist_wheel))
-```
-
-```shell
-(cd $IREE_SRC/packaging/python && (
- rm -Rf build;
- python3 setup_rt.py bdist_wheel))
-```
diff --git a/packaging/python/__init__.py b/packaging/python/__init__.py
deleted file mode 100644
index 8b13789..0000000
--- a/packaging/python/__init__.py
+++ /dev/null
@@ -1 +0,0 @@
-
diff --git a/packaging/python/common_setup.py b/packaging/python/common_setup.py
deleted file mode 100644
index 727ce24..0000000
--- a/packaging/python/common_setup.py
+++ /dev/null
@@ -1,132 +0,0 @@
-# Copyright 2020 Google LLC
-#
-# Licensed under the Apache License, Version 2.0 (the "License");
-# you may not use this file except in compliance with the License.
-# You may obtain a copy of the License at
-#
-# https://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-# limitations under the License.
-
-import os
-import platform
-import setuptools
-import sys
-import sysconfig
-from datetime import date
-
-
-def get_exe_suffix():
- if platform.system() == "Windows":
- return ".exe"
- else:
- return ""
-
-
-def get_package_dir(prefix=("bindings", "python")):
- explicit_root = os.environ.get("PYIREE_PYTHON_ROOT")
- if explicit_root:
- return explicit_root
-
- # Use env variables based on build system type.
- cmake_build_root = os.environ.get("PYIREE_CMAKE_BUILD_ROOT")
- bazel_build_root = os.environ.get("PYIREE_BAZEL_BUILD_ROOT")
-
- if cmake_build_root and bazel_build_root:
- print("ERROR: Both PYIREE_CMAKE_BUILD_ROOT and PYIREE_BAZEL_BUILD_ROOT"
- " cannot be set at the same time")
- sys.exit(1)
-
- if cmake_build_root:
- print("Using CMake build root:", cmake_build_root)
- pkg_dir = os.path.join(cmake_build_root, *prefix)
- elif bazel_build_root:
- print("Using Bazel build root:", bazel_build_root)
- if not os.path.isdir(bazel_build_root):
- print("ERROR: Could not find bazel-bin:", bazel_build_root)
- sys.exit(1)
- # Find the path to the runfiles of the built target:
- # //bindings/python/packaging:all_pyiree_packages
- runfiles_dir = os.path.join(
- bazel_build_root, "packaging", "python",
- "all_pyiree_packages%s.runfiles" % (get_exe_suffix(),))
- if not os.path.isdir(runfiles_dir):
- print("ERROR: Could not find build target 'all_pyiree_packages':",
- runfiles_dir)
- print("Make sure to build target",
- "//packaging/python:all_pyiree_packages")
- sys.exit(1)
- # And finally seek into the corresponding path in the runfiles dir.
- # Aren't bazel paths fun???
- # Note that the "iree_core" path segment corresponds to the workspace name.
- pkg_dir = os.path.join(runfiles_dir, "iree_core", *prefix)
- else:
- print("ERROR: No build directory specified. Set one of these variables:")
- print(" PYIREE_CMAKE_BUILD_ROOT=/path/to/cmake/build")
- sys.exit(1)
-
- if not os.path.exists(pkg_dir):
- print("ERROR: Package path does not exist:", pkg_dir)
- sys.exit(1)
- return pkg_dir
-
-
-def get_default_date_version():
- today = date.today()
- return today.strftime("%Y%m%d")
-
-
-def get_setup_defaults(sub_project, description, package_dir=None):
- if not package_dir:
- package_dir = get_package_dir()
- return {
- "name": "google-iree-%s" % (sub_project,),
- "version": get_default_date_version(),
- "author": "The IREE Team at Google",
- "author_email": "iree-discuss@googlegroups.com",
- "description": description,
- "long_description": description,
- "long_description_content_type": "text/plain",
- "url": "https://github.com/google/iree",
- "package_dir": {
- "": package_dir,
- },
- "classifiers": [
- "Programming Language :: Python :: 3",
- "License :: OSI Approved :: Apache License",
- "Operating System :: OS Independent",
- "Development Status :: 3 - Alpha",
- ],
- "python_requires": ">=3.6",
- }
-
-
-def setup(**kwargs):
- # See: https://stackoverflow.com/q/45150304
- try:
- from wheel.bdist_wheel import bdist_wheel as _bdist_wheel
-
- class bdist_wheel(_bdist_wheel):
-
- def finalize_options(self):
- _bdist_wheel.finalize_options(self)
- self.root_is_pure = False
- except ImportError:
- bdist_wheel = None
-
- # Need to include platform specific extensions binaries:
- # Windows: .pyd
- # macOS: .dylib
- # Other: .so
- # Unfortunately, bazel is imprecise and scatters .so files around, so
- # need to be specific.
- package_data = {
- "": ["*%s" % (sysconfig.get_config_var("EXT_SUFFIX"),)],
- }
- setuptools.setup(package_data=package_data,
- cmdclass={"bdist_wheel": bdist_wheel},
- **kwargs)
diff --git a/packaging/python/hack_python_package_from_runfiles.py b/packaging/python/hack_python_package_from_runfiles.py
deleted file mode 100644
index 61c2d6e..0000000
--- a/packaging/python/hack_python_package_from_runfiles.py
+++ /dev/null
@@ -1,102 +0,0 @@
-#!/usr/bin/python
-# Copyright 2020 Google LLC
-#
-# Licensed under the Apache License, Version 2.0 (the "License");
-# you may not use this file except in compliance with the License.
-# You may obtain a copy of the License at
-#
-# https://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-# limitations under the License.
-
-# Given a runfiles directory from a bazel build, does surgery to extract
-# a usable python package directory. In addition to the bazel directory
-# structure being unnecessarily obtuse, it is also really hard to actually
-# name files correctly. This affects python extension modules which must be
-# named with a specific extension suffix. Bazel is extremely unflexible and
-# we patch around it with this script. For the record, there are various ways
-# to write custom rules to do this more natively, but it is all complicated
-# and needless complexity. We opt for a script that is at least readable by
-# mere mortals and in one place.
-# Usage:
-# ./this_script <dest_dir> <path to bazel-bin>
-
-import os
-import platform
-import shutil
-import sys
-import sysconfig
-
-FILE_NAME_MAP = {
- "binding.so": "binding{}".format(sysconfig.get_config_var("EXT_SUFFIX")),
- "binding.pyd": False,
- "binding.dylib": False,
-}
-
-
-def get_exe_suffix():
- if platform.system() == "Windows":
- return ".exe"
- else:
- return ""
-
-
-def copy_prefix(dest_dir, runfiles_dir, prefix):
- # And finally seek into the corresponding path in the runfiles dir.
- # Aren't bazel paths fun???
- # Note that the "iree_core" path segment corresponds to the workspace name.
- pkg_dir = os.path.join(runfiles_dir, "iree_core", *prefix)
- if not os.path.exists(pkg_dir):
- return
- dest_dir = os.path.join(dest_dir)
- for root, dirs, files in os.walk(pkg_dir):
- assert root.startswith(pkg_dir)
- dest_prefix = root[len(pkg_dir):]
- if dest_prefix.startswith(os.path.sep):
- dest_prefix = dest_prefix[1:]
- local_dest_dir = os.path.join(dest_dir, dest_prefix)
- os.makedirs(local_dest_dir, exist_ok=True)
- for file in files:
- copy_file(os.path.join(root, file), local_dest_dir)
-
-
-def copy_file(src_file, dst_dir):
- basename = os.path.basename(src_file)
- dst_file = os.path.join(dst_dir, basename)
- mapped_name = FILE_NAME_MAP.get(basename)
- if mapped_name is False:
- # Skip.
- return
- elif mapped_name is not None:
- dst_file = os.path.join(dst_dir, mapped_name)
- shutil.copyfile(src_file, dst_file, follow_symlinks=True)
-
-
-def main():
- # Parse args.
- dest_dir = sys.argv[1]
- bazel_bin = sys.argv[2] if len(sys.argv) > 2 else os.path.join(
- os.path.dirname(__file__), "..", "..", "bazel-bin")
-
- # Find the path to the runfiles of the built target:
- # //bindings/python/packaging:all_pyiree_packages
- runfiles_dir = os.path.join(
- bazel_bin, "packaging", "python",
- "all_pyiree_packages%s.runfiles" % (get_exe_suffix(),))
- if not os.path.isdir(runfiles_dir):
- print("ERROR: Could not find build target 'all_pyiree_packages':",
- runfiles_dir)
- print("Make sure to build target", "//packaging/python:all_pyiree_packages")
- sys.exit(1)
-
- copy_prefix(dest_dir, runfiles_dir, ("bindings", "python"))
- copy_prefix(dest_dir, runfiles_dir,
- ("integrations", "tensorflow", "bindings", "python"))
-
-
-if __name__ == "__main__":
- main()
diff --git a/packaging/python/setup_compiler.py b/packaging/python/setup_compiler.py
deleted file mode 100644
index 57f4479..0000000
--- a/packaging/python/setup_compiler.py
+++ /dev/null
@@ -1,51 +0,0 @@
-#!/usr/bin/python3
-
-# Copyright 2020 Google LLC
-#
-# Licensed under the Apache License, Version 2.0 (the "License");
-# you may not use this file except in compliance with the License.
-# You may obtain a copy of the License at
-#
-# https://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-# limitations under the License.
-
-# Build platform specific wheel files for the pyiree.rt package.
-# Built artifacts are per-platform and build out of the build tree.
-# Usage:
-# ------
-# Windows with CMake:
-# export CMAKE_BUILD_ROOT='D:\src\build-iree' # Must be native path
-# python ./setup_compiler.py bdist_wheel
-
-import os
-import setuptools
-import sys
-
-# Ensure that path starts here for execution as a script.
-sys.path.insert(0, os.path.dirname(__file__))
-import common_setup
-
-
-def run():
- packages = setuptools.find_namespace_packages(
- common_setup.get_package_dir(),
- include=["pyiree.compiler", "pyiree.compiler.*"],
- exclude=["*.CMakeFiles"])
- print("Found packages:", packages)
- setup_kwargs = common_setup.get_setup_defaults(
- sub_project="compiler", description="IREE Generic Compiler")
- common_setup.setup(packages=packages,
- ext_modules=[
- setuptools.Extension(name="pyiree.compiler.binding",
- sources=[]),
- ],
- **setup_kwargs)
-
-
-if __name__ == "__main__":
- run()
diff --git a/packaging/python/setup_rt.py b/packaging/python/setup_rt.py
deleted file mode 100644
index c9fc948..0000000
--- a/packaging/python/setup_rt.py
+++ /dev/null
@@ -1,47 +0,0 @@
-#!/usr/bin/python3
-
-# Copyright 2020 Google LLC
-#
-# Licensed under the Apache License, Version 2.0 (the "License");
-# you may not use this file except in compliance with the License.
-# You may obtain a copy of the License at
-#
-# https://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-# limitations under the License.
-
-# Build platform specific wheel files for the pyiree.rt package.
-# Built artifacts are per-platform and build out of the build tree.
-
-import os
-import setuptools
-import sys
-
-# Ensure that path starts here for execution as a script.
-sys.path.insert(0, os.path.dirname(__file__))
-import common_setup
-
-
-def run():
- packages = setuptools.find_namespace_packages(
- common_setup.get_package_dir(),
- include=["pyiree.rt", "pyiree.rt.*"],
- exclude=["*.CMakeFiles"])
- print("Found packages:", packages)
- setup_kwargs = common_setup.get_setup_defaults(
- sub_project="rt",
- description="IREE Runtime Components (for executing compiled programs)")
- common_setup.setup(packages=packages,
- ext_modules=[
- setuptools.Extension(name="pyiree.rt.binding",
- sources=[]),
- ],
- **setup_kwargs)
-
-
-if __name__ == "__main__":
- run()
diff --git a/packaging/python/setup_tf.py b/packaging/python/setup_tf.py
deleted file mode 100644
index 252c756..0000000
--- a/packaging/python/setup_tf.py
+++ /dev/null
@@ -1,59 +0,0 @@
-#!/usr/bin/python3
-
-# Copyright 2020 Google LLC
-#
-# Licensed under the Apache License, Version 2.0 (the "License");
-# you may not use this file except in compliance with the License.
-# You may obtain a copy of the License at
-#
-# https://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-# limitations under the License.
-
-# Build platform specific wheel files for the pyiree.tf packages.
-# Built artifacts are per-platform and build out of the build tree.
-
-import os
-import platform
-import setuptools
-import sys
-
-# Ensure that path starts here for execution as a script.
-sys.path.insert(0, os.path.dirname(__file__))
-import common_setup
-
-
-def run():
- package_dir = common_setup.get_package_dir(prefix=("integrations",
- "tensorflow", "bindings",
- "python"))
- packages = setuptools.find_namespace_packages(package_dir,
- include=[
- "pyiree.tf.compiler",
- "pyiree.tf.compiler.*",
- "pyiree.tf.support",
- "pyiree.tf.support.*"
- ],
- exclude=["*.CMakeFiles"])
- print("Found packages:", packages)
- if not packages:
- print("ERROR: Did not find packages under", package_dir)
- sys.exit(1)
- setup_kwargs = common_setup.get_setup_defaults(
- sub_project="tf",
- description="IREE TensorFlow Compiler",
- package_dir=package_dir)
- common_setup.setup(packages=packages,
- ext_modules=[
- setuptools.Extension(name="pyiree.tf.compiler.binding",
- sources=[]),
- ],
- **setup_kwargs)
-
-
-if __name__ == "__main__":
- run()