blob: 9a87970ed692661caf25d79361b452a228c440e0 [file] [log] [blame]
#!/usr/bin/python3
# Copyright 2020 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# Build platform specific wheel files for the iree-compiler package
# (pyiree.compiler2 and the pyiree.tools.core tool binaries).
# Built artifacts are per-platform and are built out of the build tree.
from distutils.command.install import install
import os
import platform
from setuptools import setup, find_namespace_packages
# Markdown long description for the package (passed to setup() as
# long_description with content type text/markdown). The embedded ```py
# examples must be valid, properly indented Python so they render usefully.
README = r'''
# IREE Compiler Python Bindings

Transitional note: These bindings are not complete yet and will ultimately
replace the `pyiree.compiler` and `pyiree.tf.compiler` packages.

## Core compiler

```py
from pyiree.compiler2 import *

SIMPLE_MUL_ASM = """
func @simple_mul(%arg0: tensor<4xf32>, %arg1: tensor<4xf32>) -> tensor<4xf32>
      attributes { iree.module.export } {
    %0 = "mhlo.multiply"(%arg0, %arg1) {name = "mul.1"} : (tensor<4xf32>, tensor<4xf32>) -> tensor<4xf32>
    return %0 : tensor<4xf32>
}
"""

# Also see compile_file()
# There are many keyword options available.
# See pyiree.compiler2.CompilerOptions
binary = compile_str(SIMPLE_MUL_ASM, target_backends=["vulkan-spirv"])
```

## TensorFlow compiler

```py
import tensorflow as tf
from pyiree.compiler2.tf import *

class SimpleArithmeticModule(tf.Module):

    @tf.function(input_signature=[
        tf.TensorSpec([4], tf.float32),
        tf.TensorSpec([4], tf.float32)
    ])
    def simple_mul(self, a, b):
        return a * b

# Also see compile_saved_model to directly compile an on-disk saved model.
# There are many keyword options available.
# See: pyiree.compiler2.tf.ImportOptions
binary = compile_module(
    SimpleArithmeticModule(), target_backends=["vulkan-spirv"])
```
'''
# File-name suffix for native executables: ".exe" on Windows, nothing elsewhere.
if platform.system() == "Windows":
    exe_suffix = ".exe"
else:
    exe_suffix = ""
# Force a platform specific wheel.
# https://stackoverflow.com/questions/45150304
try:
    from wheel.bdist_wheel import bdist_wheel as _bdist_wheel
except ImportError:
    # `wheel` is not installed, so there is no custom command to provide.
    bdist_wheel = None
else:

    class bdist_wheel(_bdist_wheel):
        """bdist_wheel variant that emits a py3-none-<platform> wheel."""

        def finalize_options(self):
            """Run the stock option finalization, then mark the root non-pure."""
            super().finalize_options()
            # Not pure: this wheel is meant to be platform specific.
            self.root_is_pure = False

        def get_tag(self):
            """Return the ('py3', 'none', platform) wheel tag triple.

            We don't contain any python extensions, so the wheel is agnostic
            to the interpreter version and ABI, but it still keeps the
            platform tag computed by the base class.
            """
            _python_tag, _abi_tag, plat = super().get_tag()
            return 'py3', 'none', plat
# Force installation into platlib.
# This package is pure Python but ships platform binaries, so it is
# mis-detected as "pure", which fails audit. Normally the presence of a
# compiled extension triggers a non-pure install; here the install command
# forces it by redirecting install_lib to install_platlib.
# NOTE(review): distutils is deprecated by PEP 632 and removed in Python 3.12;
# consider basing this on setuptools.command.install once the install
# behavior has been confirmed to stay the same.
class platlib_install(install):
    """`install` command that always targets the platform-specific lib dir."""

    def finalize_options(self):
        """Finalize options as usual, then point install_lib at platlib."""
        super().finalize_options()
        self.install_lib = self.install_platlib
# Custom command classes. bdist_wheel is only registered when the `wheel`
# package was importable above; mapping the command name to None would make
# setuptools' get_command_class() return None (shadowing any built-in or
# default bdist_wheel) and crash when it tries to instantiate it.
_CMDCLASS = {'install': platlib_install}
if bdist_wheel is not None:
    _CMDCLASS['bdist_wheel'] = bdist_wheel

setup(
    name="iree-compiler@IREE_RELEASE_PACKAGE_SUFFIX@",
    version="@IREE_RELEASE_VERSION@",
    author="The IREE Team",
    author_email="iree-discuss@googlegroups.com",
    license="Apache",
    description="IREE Python Compiler API",
    long_description=README,
    long_description_content_type="text/markdown",
    url="https://github.com/google/iree",
    classifiers=[
        "Programming Language :: Python :: 3",
        # Must be an exact entry from the PyPI trove classifier list, or the
        # upload is rejected.
        "License :: OSI Approved :: Apache Software License",
        "Operating System :: OS Independent",
        "Development Status :: 3 - Alpha",
    ],
    python_requires=">=3.6",
    packages=find_namespace_packages(include=[
        "pyiree.compiler2",
        "pyiree.compiler2.*",
        "pyiree.tools.core",
    ]),
    # Ship the prebuilt translator binary alongside the Python sources.
    package_data={
        "pyiree.tools.core": [f"iree-translate{exe_suffix}"],
    },
    cmdclass=_CMDCLASS,
    zip_safe=False,  # This package is fine but not zipping is more versatile.
)