[onnx] Add ONNX importer and iree-import-onnx tool to compiler package. (#15920)
* When building the torch frontend, we now also have access to the
upstream ONNX importer and include it here as part of our official API.
* Also added a custom `iree-import-onnx` tool and corresponding test.
* Added extras_require for `onnx` to setup.py (allowing it to be
installed as an optional dependency).
* Added a _package_test.py for the compiler package like the runtime has
and configured the CI to use it.
* Added a check to the release validation job.
* Includes a bump of torch-mlir to latest.
* May need to tweak some things in the input pipeline to get
iree-compile to work on this by default. Will do in a followup.
diff --git a/.github/workflows/ci.yml b/.github/workflows/ci.yml
index 2ab9ffe..b999a38 100644
--- a/.github/workflows/ci.yml
+++ b/.github/workflows/ci.yml
@@ -607,13 +607,21 @@
shell: bash
env:
packages: "iree-compiler"
- output_dir: "${{ github.workspace }}/bindist"
# Note when upgrading: Build just one Python version synced to our
# minimum.
override_python_versions: cp39-cp39
run: |
- ./build_tools/python_deploy/build_linux_packages.sh
-
+ output_dir="$PWD" ./build_tools/python_deploy/build_linux_packages.sh
+ - name: Validate compiler wheel (Linux)
+ shell: bash
+ run: |
+ pip install --upgrade pip
+ # Pre-fetch optional deps that iree-compiler needs (but we constrain that
+ # to not consult a package index).
+ pip install onnx>=1.15.0
+ pip install --no-index -f $PWD -v iree-compiler[onnx]
+ echo "Testing default compiler:"
+ python -m iree.compiler._package_test
asan:
needs: setup
if: contains(fromJson(needs.setup.outputs.enabled-jobs), 'asan')
diff --git a/.github/workflows/validate_and_publish_release.yml b/.github/workflows/validate_and_publish_release.yml
index 41d19b4..0d8d85a 100644
--- a/.github/workflows/validate_and_publish_release.yml
+++ b/.github/workflows/validate_and_publish_release.yml
@@ -41,7 +41,12 @@
- name: Install python packages
id: install_python_packages
run: |
- python -m pip install -f file://$PWD/artifact/ iree-compiler iree-runtime iree-tools-tflite iree-tools-tf
+ python -m pip install -f file://$PWD/artifact/ iree-compiler[onnx] iree-runtime iree-tools-tflite iree-tools-tf
+ - name: Validate IREE Compiler Package
+ id: validate_compiler_package
+ run: |
+ echo "Testing compiler package:"
+ python -m iree.compiler._package_test
- name: Validate IREE Runtime Package
id: validate_runtime_package
run: |
diff --git a/build_tools/scripts/check_tabs.sh b/build_tools/scripts/check_tabs.sh
index 72cdf49..f27cf16 100755
--- a/build_tools/scripts/check_tabs.sh
+++ b/build_tools/scripts/check_tabs.sh
@@ -21,6 +21,7 @@
# Symlinks make grep upset
"^integrations/tensorflow/iree-dialects$"
# Generated / Binary files
+ ".onnx"
".svg"
)
diff --git a/compiler/bindings/python/CMakeLists.txt b/compiler/bindings/python/CMakeLists.txt
index e47ff50..013edc4 100644
--- a/compiler/bindings/python/CMakeLists.txt
+++ b/compiler/bindings/python/CMakeLists.txt
@@ -100,19 +100,52 @@
DIALECT_NAME vm
)
+declare_mlir_python_sources(IREECompilerAPIPythonCore
+ ADD_TO_PARENT IREEPythonSources
+ ROOT_DIR "${CMAKE_CURRENT_SOURCE_DIR}/iree/compiler"
+ SOURCES
+ _package_test.py
+ api/__init__.py
+ api/ctypes_dl.py
+)
+
+# Note that some tools rely on optional features but we unconditionally
+# include them because they are referenced from console scripts and
+# other package metadata. They will detect mis-configuration and error
+# accordingly at runtime.
declare_mlir_python_sources(IREECompilerAPIPythonTools
+ ADD_TO_PARENT IREEPythonSources
ROOT_DIR "${CMAKE_CURRENT_SOURCE_DIR}/iree/compiler"
SOURCES
__init__.py
tf.py
tflite.py
- xla.py
- SOURCES_GLOB
- api/*.py
- tools/*.py
- tools/ir_tool/*.py
+ tools/__init__.py
+ tools/binaries.py
+ tools/core.py
+ tools/debugging.py
+ tools/tf.py
+ tools/tflite.py
+ tools/import_onnx/__main__.py
+ tools/ir_tool/__main__.py
+ tools/scripts/ireec/__main__.py
)
+# The Python bindings are monolithic and we don't have a good way for the
+# torch plugin to contribute Python sources, so we just gate it here
+# versus having more complicated indirection. May want to rethink this
+# if others need it.
+if(IREE_INPUT_TORCH)
+
+ declare_mlir_python_sources(IREEPythonSources.Torch.Importers
+ ADD_TO_PARENT IREEPythonSources
+ ROOT_DIR "${IREE_SOURCE_DIR}/third_party/torch-mlir/python/torch_mlir"
+ SOURCES
+ extras/onnx_importer.py
+ )
+
+endif()
+
################################################################################
# Extensions
################################################################################
diff --git a/compiler/bindings/python/iree/compiler/_package_test.py b/compiler/bindings/python/iree/compiler/_package_test.py
new file mode 100644
index 0000000..19088a9
--- /dev/null
+++ b/compiler/bindings/python/iree/compiler/_package_test.py
@@ -0,0 +1,51 @@
+# Copyright 2023 The IREE Authors
+#
+# Licensed under the Apache License v2.0 with LLVM Exceptions.
+# See https://llvm.org/LICENSE.txt for license information.
+# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+"""Runs on a pip installed runtime package and verifies it is setup properly."""
+
+from typing import Optional
+
+import subprocess
+from typing import List
+
+
+# Check tools.
+def check_tool(tool_name: str, args: List[str], find_line: Optional[str] = None):
+ print(f"**** Checking tool {tool_name} with args {args}")
+ output = subprocess.check_output([tool_name] + args).decode()
+ if find_line is not None:
+ output_lines = output.splitlines()
+ for line in output_lines:
+ if find_line in line:
+ print(f"Found output: {line.strip()}")
+ return
+ raise ValueError(
+ f"Did not find banner '{find_line}' for {tool_name}:\n{output}"
+ )
+
+
+# Verify version.
+import iree.compiler.version as v
+
+assert hasattr(v, "PACKAGE_SUFFIX")
+assert v.REVISIONS["IREE"]
+assert v.VERSION
+print("IREE version:", v.VERSION)
+
+check_tool("iree-compile", ["--help"], "IREE compilation driver")
+check_tool("iree-ir-tool", ["--help"], "IREE IR Tool")
+
+# ONNX dependent.
+onnx_available = False
+try:
+ import onnx
+
+ onnx_available = True
+except ModuleNotFoundError:
+ print("Not checking iree-import-onnx: onnx pip package not found")
+if onnx_available:
+ check_tool("iree-import-onnx", ["--help"], "IREE ONNX import tool")
+
+print("***** All done *****")
diff --git a/compiler/bindings/python/iree/compiler/tools/import_onnx/__main__.py b/compiler/bindings/python/iree/compiler/tools/import_onnx/__main__.py
new file mode 100644
index 0000000..9d1fb54
--- /dev/null
+++ b/compiler/bindings/python/iree/compiler/tools/import_onnx/__main__.py
@@ -0,0 +1,87 @@
+# Copyright 2023 The IREE Authors
+#
+# Licensed under the Apache License v2.0 with LLVM Exceptions.
+# See https://llvm.org/LICENSE.txt for license information.
+# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+
+"""Console tool for converting an ONNX proto to torch IR.
+
+Typically, when installed from a wheel, this can be invoked as:
+
+ iree-import-onnx some.pb
+
+Or from Python:
+
+ python -m iree.compiler.tools.import_onnx ...
+"""
+import argparse
+from pathlib import Path
+import sys
+
+try:
+ import onnx
+except ModuleNotFoundError as e:
+ raise ModuleNotFoundError(
+ f"iree-import-onnx requires that the `onnx` Python package is installed "
+ f"(typically `{sys.executable} -m pip install onnx`)"
+ ) from e
+
+try:
+ from ...extras import onnx_importer
+except ModuleNotFoundError as e:
+ raise ModuleNotFoundError(
+ "iree-import-onnx is only available if IREE was built with Torch support"
+ ) from e
+
+from ...ir import (
+ Context,
+)
+
+
+def main(args):
+ model_proto = load_onnx_model(args.input_file)
+ context = Context()
+ model_info = onnx_importer.ModelInfo(model_proto)
+ m = model_info.create_module(context=context)
+ imp = onnx_importer.NodeImporter.define_function(model_info.main_graph, m)
+ imp.import_all()
+ if not args.no_verify:
+ m.verify()
+
+ # TODO: This isn't very efficient output. If these files ever
+ # get large, enable bytecode and direct binary emission to save
+ # some copies.
+ if args.output_file and args.output_file != "-":
+ with open(args.output_file, "wt") as f:
+ print(m.get_asm(assume_verified=not args.no_verify), file=f)
+ else:
+ print(m.get_asm(assume_verified=not args.no_verify))
+
+
+def load_onnx_model(file_path: Path) -> onnx.ModelProto:
+ raw_model = onnx.load(file_path)
+ inferred_model = onnx.shape_inference.infer_shapes(raw_model)
+ return inferred_model
+
+
+def parse_arguments(argv=None):
+ parser = argparse.ArgumentParser(description="IREE ONNX import tool")
+ parser.add_argument("input_file", help="ONNX protobuf input", type=Path)
+ parser.add_argument(
+ "-o", dest="output_file", help="Output path (or '-' for stdout)"
+ )
+ parser.add_argument(
+ "--no-verify",
+ action="store_true",
+ help="Disable verification prior to printing",
+ )
+ args = parser.parse_args(argv)
+ return args
+
+
+def _cli_main():
+ sys.exit(main(parse_arguments()))
+
+
+if __name__ == "__main__":
+ _cli_main()
diff --git a/compiler/bindings/python/iree/compiler/xla.py b/compiler/bindings/python/iree/compiler/xla.py
deleted file mode 100644
index e66d5fe..0000000
--- a/compiler/bindings/python/iree/compiler/xla.py
+++ /dev/null
@@ -1,10 +0,0 @@
-# Copyright 2021 The IREE Authors
-#
-# Licensed under the Apache License v2.0 with LLVM Exceptions.
-# See https://llvm.org/LICENSE.txt for license information.
-# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
-
-import sys
-from .tools import xla
-
-sys.modules[__name__] = xla
diff --git a/compiler/bindings/python/test/tools/CMakeLists.txt b/compiler/bindings/python/test/tools/CMakeLists.txt
index 01456f9..4dd27a1 100644
--- a/compiler/bindings/python/test/tools/CMakeLists.txt
+++ b/compiler/bindings/python/test/tools/CMakeLists.txt
@@ -15,6 +15,15 @@
)
endif() # IREE_BUILD_BUNDLED_LLVM
+if(IREE_INPUT_TORCH)
+ iree_py_test(
+ NAME
+ import_onnx_test
+ SRCS
+ "import_onnx_test.py"
+ )
+endif()
+
iree_py_test(
NAME
ir_tool_test
diff --git a/compiler/bindings/python/test/tools/import_onnx_test.py b/compiler/bindings/python/test/tools/import_onnx_test.py
new file mode 100644
index 0000000..73eb56f
--- /dev/null
+++ b/compiler/bindings/python/test/tools/import_onnx_test.py
@@ -0,0 +1,58 @@
+# Copyright 2023 The IREE Authors
+#
+# Licensed under the Apache License v2.0 with LLVM Exceptions.
+# See https://llvm.org/LICENSE.txt for license information.
+# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+
+import os
+import sys
+import tempfile
+import unittest
+
+
+def run_tool(*argv: str):
+ try:
+ from iree.compiler.tools.import_onnx import __main__
+
+ args = __main__.parse_arguments(list(argv))
+ __main__.main(args)
+ except SystemExit as e:
+ if e.code != 0:
+ raise RuntimeError(f"Tool exited with code {e.code}")
+
+
+ONNX_FILE_PATH = os.path.join(os.path.dirname(__file__), "testdata", "LeakyReLU.onnx")
+
+
+class ImportOnnxTest(unittest.TestCase):
+ def setUp(self):
+ with tempfile.NamedTemporaryFile(delete=False) as f:
+ self.outputPath = f.name
+
+ def tearDown(self) -> None:
+ if os.path.exists(self.outputPath):
+ os.unlink(self.outputPath)
+
+ def testConsoleOutput(self):
+ # Just test that it doesn't crash: rely on the file test for verification.
+ run_tool(ONNX_FILE_PATH)
+
+ def testDisableVerify(self):
+ # Just test that the flag is accepted.
+ run_tool(ONNX_FILE_PATH, "--no-verify")
+
+ def testFileOutput(self):
+ run_tool(ONNX_FILE_PATH, "-o", self.outputPath)
+ with open(self.outputPath, "rt") as f:
+ contents = f.read()
+ self.assertIn("torch.operator", contents)
+
+
+if __name__ == "__main__":
+ try:
+ import onnx
+ except ModuleNotFoundError:
+ print(f"Skipping test {__file__} because Python dependency `onnx` is not found")
+ sys.exit(0)
+
+ unittest.main()
diff --git a/compiler/bindings/python/test/tools/testdata/LeakyReLU.onnx b/compiler/bindings/python/test/tools/testdata/LeakyReLU.onnx
new file mode 100644
index 0000000..f76bccb
--- /dev/null
+++ b/compiler/bindings/python/test/tools/testdata/LeakyReLU.onnx
@@ -0,0 +1,15 @@
+pytorch0.3:h
+"
+01" LeakyRelu*
+alpha
+×#< torch-jit-exportZ
+0
+
+
+
+b
+1
+
+
+
+B
\ No newline at end of file
diff --git a/compiler/bindings/python/test/tools/testdata/README.md b/compiler/bindings/python/test/tools/testdata/README.md
new file mode 100644
index 0000000..dcd5c4c
--- /dev/null
+++ b/compiler/bindings/python/test/tools/testdata/README.md
@@ -0,0 +1,10 @@
+# Importer test data.
+
+Most files have a generation script except for when it is expected that they
+will never change. Things in that category and break glass instructions to
+update:
+
+* LeakyReLU.onnx: Just a random single-op ONNX test to verify that the upstream
+ importer is wired properly. It should never need to be updated but if it
+ does, pretty much any single-op test case from the ONNX test suite will
+ suffice.
diff --git a/compiler/setup.py b/compiler/setup.py
index 01a347d..85b99f3 100644
--- a/compiler/setup.py
+++ b/compiler/setup.py
@@ -456,6 +456,7 @@
# TODO: We have renamed to iree-compile on 2022-03-18. Remove
# this alias once no longer needed.
"ireec = iree.compiler.tools.scripts.ireec.__main__:main",
+ "iree-import-onnx = iree.compiler.tools.import_onnx.__main__:_cli_main",
"iree-ir-tool = iree.compiler.tools.ir_tool.__main__:_cli_main",
],
},
@@ -463,4 +464,9 @@
"numpy",
"PyYAML",
],
+ extras_require={
+ "onnx": [
+ "onnx>=1.15.0",
+ ],
+ },
)