Add proper package configuration for iree.jax Python package. (#6214)
* Will set it up to build in releases in a followup.
diff --git a/bindings/python/CMakeLists.txt b/bindings/python/CMakeLists.txt
index f7b489c..b80db7b 100644
--- a/bindings/python/CMakeLists.txt
+++ b/bindings/python/CMakeLists.txt
@@ -11,15 +11,12 @@
# Namespace packages.
add_subdirectory(iree/runtime)
+add_subdirectory(iree/jax)
if(${IREE_BUILD_COMPILER})
add_subdirectory(iree/compiler)
add_subdirectory(iree/tools/core)
endif()
-if(${IREE_BUILD_XLA_COMPILER})
-add_subdirectory(iree/jax)
-endif()
-
# Tests.
add_subdirectory(tests)
diff --git a/bindings/python/iree/jax/CMakeLists.txt b/bindings/python/iree/jax/CMakeLists.txt
index e140f4a..08bd99a 100644
--- a/bindings/python/iree/jax/CMakeLists.txt
+++ b/bindings/python/iree/jax/CMakeLists.txt
@@ -12,9 +12,20 @@
"frontend.py"
)
+# Only enable the tests if the XLA compiler is built.
+if(${IREE_BUILD_XLA_COMPILER})
iree_py_test(
NAME
frontend_test
SRCS
"frontend_test.py"
)
+endif()
+
+iree_py_install_package(
+ COMPONENT IreePythonPackage-jax
+ PACKAGE_NAME iree_jax
+ MODULE_PATH iree/jax
+ ADDL_PACKAGE_FILES
+ ${CMAKE_CURRENT_SOURCE_DIR}/README.md
+)
diff --git a/bindings/python/iree/jax/setup.py.in b/bindings/python/iree/jax/setup.py.in
new file mode 100644
index 0000000..59ad912
--- /dev/null
+++ b/bindings/python/iree/jax/setup.py.in
@@ -0,0 +1,45 @@
+#!/usr/bin/python3
+
+# Copyright 2021 The IREE Authors
+#
+# Licensed under the Apache License v2.0 with LLVM Exceptions.
+# See https://llvm.org/LICENSE.txt for license information.
+# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+
+from distutils.command.install import install
+import os
+import platform
+from setuptools import setup, find_namespace_packages
+
+with open(os.path.join(os.path.dirname(__file__), "README.md"), "r") as f:
+ README = f.read()
+
+exe_suffix = ".exe" if platform.system() == "Windows" else ""
+
+setup(
+ name="iree-jax@IREE_RELEASE_PACKAGE_SUFFIX@",
+ version="@IREE_RELEASE_VERSION@",
+ author="The IREE Team",
+ author_email="iree-discuss@googlegroups.com",
+ license="Apache",
+ description="IREE JAX API",
+ long_description=README,
+ long_description_content_type="text/markdown",
+ url="https://github.com/google/iree",
+ classifiers=[
+ "Programming Language :: Python :: 3",
+ "License :: OSI Approved :: Apache License",
+ "Operating System :: OS Independent",
+ "Development Status :: 3 - Alpha",
+ ],
+ python_requires=">=3.6",
+ packages=find_namespace_packages(include=["iree.jax"]),
+ zip_safe=True,
+ install_requires = [
+ "jax",
+ "jaxlib",
+ "iree-compiler@IREE_RELEASE_PACKAGE_SUFFIX@==@IREE_RELEASE_VERSION@",
+ "iree-runtime@IREE_RELEASE_PACKAGE_SUFFIX@==@IREE_RELEASE_VERSION@",
+ "iree-tools-xla@IREE_RELEASE_PACKAGE_SUFFIX@==@IREE_RELEASE_VERSION@",
+ ],
+)
diff --git a/bindings/python/iree/jax/version.py.in b/bindings/python/iree/jax/version.py.in
new file mode 100644
index 0000000..ed72ac3
--- /dev/null
+++ b/bindings/python/iree/jax/version.py.in
@@ -0,0 +1,9 @@
+# Copyright 2021 The IREE Authors
+#
+# Licensed under the Apache License v2.0 with LLVM Exceptions.
+# See https://llvm.org/LICENSE.txt for license information.
+# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+
+PACKAGE_SUFFIX = "@IREE_RELEASE_PACKAGE_SUFFIX@"
+VERSION = "@IREE_RELEASE_VERSION@"
+REVISION = "@IREE_RELEASE_REVISION@"
diff --git a/build_tools/cmake/iree_python.cmake b/build_tools/cmake/iree_python.cmake
index b5a825c..355f5a3 100644
--- a/build_tools/cmake/iree_python.cmake
+++ b/build_tools/cmake/iree_python.cmake
@@ -38,7 +38,7 @@
cmake_parse_arguments(ARG
"AUGMENT_EXISTING_PACKAGE"
"COMPONENT;PACKAGE_NAME;MODULE_PATH"
- "DEPS"
+ "DEPS;ADDL_PACKAGE_FILES"
${ARGN})
set(_install_component ${ARG_COMPONENT})
set(_install_packages_dir "${CMAKE_INSTALL_PREFIX}/python_packages/${ARG_PACKAGE_NAME}")
@@ -50,6 +50,7 @@
install(
FILES
${CMAKE_CURRENT_BINARY_DIR}/setup.py
+ ${ARG_ADDL_PACKAGE_FILES}
COMPONENT ${_install_component}
DESTINATION "${_install_packages_dir}"
)