| #!/usr/bin/python3 |
| |
| # Copyright 2020 Google LLC |
| # |
| # Licensed under the Apache License, Version 2.0 (the "License"); |
| # you may not use this file except in compliance with the License. |
| # You may obtain a copy of the License at |
| # |
| # https://www.apache.org/licenses/LICENSE-2.0 |
| # |
| # Unless required by applicable law or agreed to in writing, software |
| # distributed under the License is distributed on an "AS IS" BASIS, |
| # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
| # See the License for the specific language governing permissions and |
| # limitations under the License. |
| |
| # Build platform specific wheel files for the pyiree.compiler2 package. |
| # Built artifacts are per-platform and are built out of the build tree. |
| |
| from distutils.command.install import install |
| import os |
| import platform |
| from setuptools import setup, find_namespace_packages |
| |
| README = r''' |
| # IREE Compiler Python Bindings |
| |
| Transitional note: These bindings are not complete yet and will ultimately |
| replace the `pyiree.compiler` and `pyiree.tf.compiler` packages. |
| |
| ## Core compiler |
| |
| ```py |
| from pyiree.compiler2 import * |
| |
| SIMPLE_MUL_ASM = """ |
| func @simple_mul(%arg0: tensor<4xf32>, %arg1: tensor<4xf32>) -> tensor<4xf32> |
| attributes { iree.module.export } { |
| %0 = "mhlo.multiply"(%arg0, %arg1) {name = "mul.1"} : (tensor<4xf32>, tensor<4xf32>) -> tensor<4xf32> |
| return %0 : tensor<4xf32> |
| } |
| """ |
| |
| # Also see compile_file() |
| # There are many keyword options available. |
| # See pyiree.compiler2.CompilerOptions |
| binary = compile_str(SIMPLE_MUL_ASM, target_backends=["vulkan-spirv"]) |
| ``` |
| |
| |
| ## TensorFlow compiler |
| |
| ```py |
| import tensorflow as tf |
| from pyiree.compiler2.tf import * |
| |
| class SimpleArithmeticModule(tf.Module): |
| |
| @tf.function(input_signature=[ |
| tf.TensorSpec([4], tf.float32), |
| tf.TensorSpec([4], tf.float32) |
| ]) |
| def simple_mul(self, a, b): |
| return a * b |
| |
| # Also see compile_saved_model to directly compile an on-disk saved model. |
| # There are many keyword options available. |
| # See: pyiree.compiler2.tf.ImportOptions |
| binary = compile_module( |
| SimpleArithmeticModule(), target_backends=["vulkan-spirv"]) |
| ``` |
| |
| ''' |
| |
# Suffix appended to bundled tool names: ".exe" on Windows, empty elsewhere.
exe_suffix = {"Windows": ".exe"}.get(platform.system(), "")
| |
# Force a platform-specific wheel.
# https://stackoverflow.com/questions/45150304
try:
    from wheel.bdist_wheel import bdist_wheel as _bdist_wheel
except ImportError:
    # `wheel` could not be imported, so no command override is defined.
    bdist_wheel = None
else:

    class bdist_wheel(_bdist_wheel):
        """`bdist_wheel` command that tags the wheel py3 / none / <platform>."""

        def finalize_options(self):
            """Run the base option handling, then mark the wheel as non-pure."""
            super().finalize_options()
            self.root_is_pure = False

        def get_tag(self):
            """Return the `(python, abi, platform)` tag triple for the wheel."""
            # We don't contain any python extensions so are version agnostic
            # but still want to be platform specific: only the platform part
            # of the inherited tag is kept.
            _, _, platform_tag = super().get_tag()
            return "py3", "none", platform_tag
| |
| |
# Force installation into platlib.
# This is a pure-python library that ships platform binaries, so it gets
# mis-detected as "pure", which fails audit. Normally the presence of an
# extension is what triggers a non-pure install; here it is forced.
class platlib_install(install):
    """`install` command that always targets the platform library directory."""

    def finalize_options(self):
        """Resolve options as usual, then point `install_lib` at platlib."""
        super().finalize_options()
        self.install_lib = self.install_platlib
| |
| |
# Custom command classes. The wheel override is registered only when it could
# be defined (`bdist_wheel` is None when the `wheel` import failed), so the
# command table never contains a non-class entry.
_cmdclass = {"install": platlib_install}
if bdist_wheel is not None:
    _cmdclass["bdist_wheel"] = bdist_wheel

# Package metadata and contents.
# NOTE(review): the "@...@" placeholders in `name` and `version` look like
# configure-time template substitutions -- confirm against the build system.
setup(
    name="iree-compiler@IREE_RELEASE_PACKAGE_SUFFIX@",
    version="@IREE_RELEASE_VERSION@",
    author="The IREE Team",
    author_email="iree-discuss@googlegroups.com",
    license="Apache",
    description="IREE Python Compiler API",
    long_description=README,
    long_description_content_type="text/markdown",
    url="https://github.com/google/iree",
    classifiers=[
        "Programming Language :: Python :: 3",
        # Must match a registered trove classifier exactly; package indexes
        # reject unknown classifiers on upload.
        "License :: OSI Approved :: Apache Software License",
        # NOTE(review): wheels are forced to be platform specific elsewhere in
        # this file; confirm "OS Independent" is the intended classifier.
        "Operating System :: OS Independent",
        "Development Status :: 3 - Alpha",
    ],
    python_requires=">=3.6",
    packages=find_namespace_packages(
        include=[
            "pyiree.compiler2",
            "pyiree.compiler2.*",
            "pyiree.tools.core",
        ]),
    # Ship the translation tool binary alongside the python sources.
    package_data={
        "pyiree.tools.core": [f"iree-translate{exe_suffix}",],
    },
    cmdclass=_cmdclass,
    zip_safe=False,  # This package is fine but not zipping is more versatile.
)