Remove the legacy Jax API from the iree repo. (#8031)
* Remove the legacy Jax API from the iree repo.
This is being landed in the new iree-jax repo, alongside a more modern API: https://github.com/google/iree-jax/pull/2
diff --git a/.github/workflows/build_package.yml b/.github/workflows/build_package.yml
index 724c4c3..8c096ba 100644
--- a/.github/workflows/build_package.yml
+++ b/.github/workflows/build_package.yml
@@ -211,18 +211,6 @@
export CIBW_BEFORE_BUILD="python -m pip install --upgrade pip==21.3dev0 -f https://github.com/stellaraccident/pip/releases/tag/21.3dev20210925"
python -m cibuildwheel --output-dir bindist ./main_checkout/llvm-external-projects/iree-compiler-api
- # Pure python packages only need to do a minimal CMake configure and no
- # actual building, aside from setting up sources. If there are multiples,
- # it is fairly cheap to build them serially. Using cibuildwheel is a bit
- # overkill for this, but we keep it the same as the others for maintenance
- # value.
- - name: Build pure python wheels
- if: "matrix.build_package == 'py-pure-pkgs'"
- shell: bash
- run: |
- python ./main_checkout/build_tools/github_actions/build_dist.py py-pure-pkgs
- python -m pip wheel -w bindist --no-deps ./iree-install/python_packages/iree_jax
-
# Compiler tools wheels are not python version specific, so just build
# for one examplar python version.
- name: Build XLA Compiler Tools wheels
diff --git a/.github/workflows/validate_and_publish_release.yml b/.github/workflows/validate_and_publish_release.yml
index 1236c48..a51386b 100644
--- a/.github/workflows/validate_and_publish_release.yml
+++ b/.github/workflows/validate_and_publish_release.yml
@@ -41,7 +41,7 @@
- name: Install python packages
id: install_python_packages
run: |
- python -m pip install -f file://$PWD/artifact/ iree-compiler-snapshot iree-jax-snapshot iree-runtime-snapshot iree-tools-tflite-snapshot iree-tools-tf-snapshot iree-tools-xla-snapshot
+ python -m pip install -f file://$PWD/artifact/ iree-compiler-snapshot iree-runtime-snapshot iree-tools-tflite-snapshot iree-tools-tf-snapshot iree-tools-xla-snapshot
- name: Run iree-benchmark-module
id: run_iree_benchmark_module
run: ./bin/iree-benchmark-module --help
diff --git a/CMakeLists.txt b/CMakeLists.txt
index 3e67822..ac03a37 100644
--- a/CMakeLists.txt
+++ b/CMakeLists.txt
@@ -42,7 +42,6 @@
option(IREE_BUILD_SAMPLES "Builds IREE sample projects." ON)
option(IREE_BUILD_TRACY "Builds tracy server tools." OFF)
-option(IREE_BUILD_LEGACY_JAX "Builds the legacy JAX Python API" ON)
option(IREE_BUILD_TENSORFLOW_ALL "Builds all TensorFlow compiler frontends." OFF)
option(IREE_BUILD_TENSORFLOW_COMPILER "Builds TensorFlow compiler frontend." "${IREE_BUILD_TENSORFLOW_ALL}")
option(IREE_BUILD_TFLITE_COMPILER "Builds the TFLite compiler frontend." "${IREE_BUILD_TENSORFLOW_ALL}")
diff --git a/bindings/python/CMakeLists.txt b/bindings/python/CMakeLists.txt
index 77bf1ce..545a800 100644
--- a/bindings/python/CMakeLists.txt
+++ b/bindings/python/CMakeLists.txt
@@ -11,8 +11,3 @@
# Namespace packages.
add_subdirectory(iree/runtime)
-
-if(IREE_BUILD_LEGACY_JAX)
- message(STATUS "Building legacy JAX API")
- add_subdirectory(iree/jax)
-endif()
diff --git a/bindings/python/iree/jax/CMakeLists.txt b/bindings/python/iree/jax/CMakeLists.txt
deleted file mode 100644
index 08bd99a..0000000
--- a/bindings/python/iree/jax/CMakeLists.txt
+++ /dev/null
@@ -1,31 +0,0 @@
-# Copyright 2021 The IREE Authors
-#
-# Licensed under the Apache License v2.0 with LLVM Exceptions.
-# See https://llvm.org/LICENSE.txt for license information.
-# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
-
-iree_py_library(
- NAME
- jax
- SRCS
- "__init__.py"
- "frontend.py"
-)
-
-# Only enable the tests if the XLA compiler is built.
-if(${IREE_BUILD_XLA_COMPILER})
-iree_py_test(
- NAME
- frontend_test
- SRCS
- "frontend_test.py"
-)
-endif()
-
-iree_py_install_package(
- COMPONENT IreePythonPackage-jax
- PACKAGE_NAME iree_jax
- MODULE_PATH iree/jax
- ADDL_PACKAGE_FILES
- ${CMAKE_CURRENT_SOURCE_DIR}/README.md
-)
diff --git a/bindings/python/iree/jax/README.md b/bindings/python/iree/jax/README.md
deleted file mode 100644
index cba2485..0000000
--- a/bindings/python/iree/jax/README.md
+++ /dev/null
@@ -1,106 +0,0 @@
-# IREE–JAX Frontend
-
-## Requirements
-
-A local JAX installation is necessary in addition to IREE's Python requirements:
-
-```shell
-python -m pip install jax jaxlib
-```
-
-## Just In Time Compilation with Runtime Bindings
-
-A just-in-time compilation decorator similar to `jax.jit` is provided by
-`iree.jax.jit`:
-
-```python
-import iree.jax
-
-import jax
-import jax.numpy as jnp
-
-# 'backend' is one of 'vmvx', 'llvmaot' and 'vulkan' and defaults to 'llvmaot'.
-@iree.jax.jit(backend="llvmaot")
-def linear_relu_layer(params, x):
- w, b = params
- return jnp.max(jnp.matmul(x, w) + b, 0)
-
-w = jnp.zeros((784, 128))
-b = jnp.zeros(128)
-x = jnp.zeros((1, 784))
-
-linear_relu_layer([w, b], x)
-```
-
-## Ahead of Time Compilation
-
-An ahead-of-time compilation function provides a lower-level API for compiling a
-function with a specific input signature without creating the runtime bindings
-for execution within Python. This is primarily useful for targeting other
-runtime environments like Android.
-
-### Example: Compile a MLP and run it on Android
-
-Install the Android NDK according to the
-[Android Getting Started](https://google.github.io/iree/get-started/getting-started-android-cmake)
-doc, and then ensure the following environment variable is set:
-
-```shell
-export ANDROID_NDK=# NDK install location
-```
-
-The code below assumes that you have `flax` installed.
-
-```python
-import iree.jax
-
-import jax
-import jax.numpy as jnp
-import flax
-from flax import linen as nn
-
-
-class MLP(nn.Module):
-
- @nn.compact
- def __call__(self, x):
- x = x.reshape((x.shape[0], -1)) # Flatten.
- x = nn.Dense(128)(x)
- x = nn.relu(x)
- x = nn.Dense(10)(x)
- x = nn.log_softmax(x)
- return x
-
-
-image = jnp.zeros((1, 28, 28, 1))
-params = MLP().init(jax.random.PRNGKey(0), image)["params"]
-
-apply_args = [{"params": params}, image]
-options = dict(target_backends=["dylib-llvm-aot"],
- extra_args=["--iree-llvm-target-triple=aarch64-linux-android"])
-
-compiled_binary = iree.jax.aot(MLP().apply, *apply_args, **options)
-
-with open("/tmp/mlp_apply.vmfb", "wb") as f:
- f.write(compiled_binary)
-```
-
-IREE doesn't provide installable tools for Android at this time, so they'll need
-to be built according to the
-[Android Getting Started](https://google.github.io/iree/get-started/getting-started-android-cmake).
-Afterward, the compiled `.vmfb` can be pushed to an Android device and executed
-using `iree-run-module`:
-
-```shell
-adb push /tmp/mlp_apply.vmfb /data/local/tmp/
-adb push ../iree-build-android/iree/tools/iree-run-module /data/local/tmp/
-adb shell /data/local/tmp/iree-run-module \
- --driver=dylib \
- --module_file=/data/local/tmp/mlp_apply.vmfb \
- --entry_function=main \
- --function_input=128xf32 \
- --function_input=784x128xf32 \
- --function_input=10xf32 \
- --function_input=128x10xf32 \
- --function_input=1x28x28x1xf32
-```
diff --git a/bindings/python/iree/jax/__init__.py b/bindings/python/iree/jax/__init__.py
deleted file mode 100644
index c724179..0000000
--- a/bindings/python/iree/jax/__init__.py
+++ /dev/null
@@ -1,7 +0,0 @@
-# Copyright 2021 The IREE Authors
-#
-# Licensed under the Apache License v2.0 with LLVM Exceptions.
-# See https://llvm.org/LICENSE.txt for license information.
-# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
-
-from .frontend import *
diff --git a/bindings/python/iree/jax/frontend.py b/bindings/python/iree/jax/frontend.py
deleted file mode 100644
index 8b225ee..0000000
--- a/bindings/python/iree/jax/frontend.py
+++ /dev/null
@@ -1,145 +0,0 @@
-# Copyright 2021 The IREE Authors
-#
-# Licensed under the Apache License v2.0 with LLVM Exceptions.
-# See https://llvm.org/LICENSE.txt for license information.
-# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
-
-import functools
-
-import iree.compiler.xla
-import iree.runtime
-
-try:
- import jax
-except ModuleNotFoundError as e:
- raise ModuleNotFoundError("iree.jax requires 'jax' and 'jaxlib' to be "
- "installed in your python environment.") from e
-
-# pytype thinks iree.jax is jax.
-# pytype: disable=module-attr
-
-__all__ = [
- "aot",
- "is_available",
- "jit",
-]
-
-_BACKEND_TO_TARGETS = {
- "vmvx": "vmvx",
- "llvmaot": "dylib-llvm-aot",
- "vulkan": "vulkan-spirv",
-}
-_BACKENDS = tuple(_BACKEND_TO_TARGETS.keys())
-
-
-def is_available():
- """Determine if the IREE–XLA compiler are available for JAX."""
- return iree.compiler.xla.is_available()
-
-
-def aot(function, *args, **options):
- """Traces and compiles a function, flattening the input args.
-
- This is intended to be a lower-level interface for compiling a JAX function to
- IREE without setting up the runtime bindings to use it within Python. A common
- usecase for this is compiling to Android (and similar targets).
-
- Args:
- function: The function to compile.
- args: The inputs to trace and compile the function for.
- **kwargs: Keyword args corresponding to xla.ImportOptions or CompilerOptions
- """
- xla_comp = jax.xla_computation(function)(*args)
- hlo_proto = xla_comp.as_serialized_hlo_module_proto()
- return iree.compiler.xla.compile_str(hlo_proto, **options)
-
-
-# A more JAX-native approach to jitting would be desireable here, however
-# implementing that reasonably would require using JAX internals, particularly
-# jax.linear_util.WrappedFun and helpers. The following is sufficient for many
-# usecases for the time being.
-
-
-class _JittedFunction:
-
- def __init__(self, function, driver: str, **options):
- self._function = function
- self._driver_config = iree.runtime.Config(driver)
- self._options = options
- self._memoized_signatures = {}
-
- def _get_signature(self, args_flat, in_tree):
- args_flat = [iree.runtime.normalize_value(arg) for arg in args_flat]
- return tuple((arg.shape, arg.dtype) for arg in args_flat) + (in_tree,)
-
- def _wrap_and_compile(self, signature, args_flat, in_tree):
- """Compiles the function for the given signature."""
-
- def wrapped_function(*args_flat):
- args, kwargs = jax.tree_unflatten(in_tree, args_flat)
- return self._function(*args, **kwargs)
-
- # Compile the wrapped_function to IREE.
- vm_flatbuffer = aot(wrapped_function, *args_flat, **self._options)
- vm_module = iree.runtime.VmModule.from_flatbuffer(vm_flatbuffer)
- module = iree.runtime.load_vm_module(vm_module, config=self._driver_config)
-
- # Get the output tree so it can be reconstructed from the outputs of the
- # compiled module. Duplicating execution here isn't ideal, and could
- # probably be avoided using internal APIs.
- args, kwargs = jax.tree_unflatten(in_tree, args_flat)
- _, out_tree = jax.tree_flatten(self._function(*args, **kwargs))
-
- self._memoized_signatures[signature] = (module, out_tree)
-
- def _get_compiled_artifacts(self, args, kwargs):
- """Returns the binary, loaded runtime module and out_tree."""
- args_flat, in_tree = jax.tree_flatten((args, kwargs))
- signature = self._get_signature(args_flat, in_tree)
-
- if signature not in self._memoized_signatures:
- self._wrap_and_compile(signature, args_flat, in_tree)
- return self._memoized_signatures[signature]
-
- def __call__(self, *args, **kwargs):
- """Executes the function on the provided inputs, compiling if necessary."""
- args_flat, _ = jax.tree_flatten((args, kwargs))
- # Use the uncompiled function if the inputs are being traced.
- if any(issubclass(type(arg), jax.core.Tracer) for arg in args_flat):
- return self._function(*args, **kwargs)
-
- module, out_tree = self._get_compiled_artifacts(args, kwargs)
- results = module.main(*args_flat)
- if results is not None:
- if not isinstance(results, tuple):
- results = (results,)
- return jax.tree_unflatten(out_tree, results)
- else:
- # Address IREE returning None instead of empty sequences.
- if out_tree == jax.tree_flatten([])[-1]:
- return []
- elif out_tree == jax.tree_flatten(())[-1]:
- return ()
- else:
- return results
-
-
-def jit(function=None, *, backend: str = "llvmaot", **options):
- """Compiles a function to the specified IREE backend."""
- if function is None:
- # 'function' will be None if @jit() is called with parens (e.g. to specify a
- # backend or **options). We return a partial function capturing these
- # options, which python will then apply as a decorator, and execution will
- # continue below.
- return functools.partial(jit, backend=backend, **options)
-
- # Parse the backend to more concrete compiler and runtime settings.
- if backend not in _BACKENDS:
- raise ValueError(
- f"Expected backend to be one of {_BACKENDS}, but got '{backend}'")
- target_backend = _BACKEND_TO_TARGETS[backend]
- driver = iree.runtime.TARGET_BACKEND_TO_DRIVER[target_backend]
- if "target_backends" not in options:
- options["target_backends"] = (target_backend,)
-
- return functools.wraps(function)(_JittedFunction(function, driver, **options))
diff --git a/bindings/python/iree/jax/frontend_test.py b/bindings/python/iree/jax/frontend_test.py
deleted file mode 100644
index 3bff19d..0000000
--- a/bindings/python/iree/jax/frontend_test.py
+++ /dev/null
@@ -1,171 +0,0 @@
-# Copyright 2021 The IREE Authors
-#
-# Licensed under the Apache License v2.0 with LLVM Exceptions.
-# See https://llvm.org/LICENSE.txt for license information.
-# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
-
-from absl.testing import absltest
-import iree.jax
-import iree.runtime
-import jax
-import jax.numpy as jnp
-import numpy as np
-
-# pytype thinks iree.jax is jax.
-# pytype: disable=module-attr
-
-TOLERANCE = {"rtol": 1e-6, "atol": 1e-6}
-
-
-def normal(shape):
- return np.random.normal(0, 1, shape).astype(np.float32)
-
-
-class SqrtNode:
-
- def __init__(self, x, y):
- self.x = x
- self.y = y
-
- def apply(self, z):
- return self.x * jnp.sqrt(self.y * z)
-
- def tree_flatten(self):
- return ((self.x, self.y), None)
-
- @classmethod
- def tree_unflatten(cls, aux_data, children):
- return cls(*children)
-
-
-class SquareNode:
-
- def __init__(self, x, y):
- self.x = x
- self.y = y
-
- def apply(self, z):
- return self.x * (self.y * z)**2
-
- def tree_flatten(self):
- return ((self.x, self.y), None)
-
- @classmethod
- def tree_unflatten(cls, aux_data, children):
- return cls(*children)
-
-
-class JAXFrontendTest(absltest.TestCase):
-
- def test_aot_pytree(self):
-
- def pytree_func(params, x):
- return jnp.max(jnp.matmul(x, params["w"]) + params["b"], 0)
-
- trace_args = [
- {
- "w": jnp.zeros((32, 8)),
- "b": jnp.zeros((8,))
- },
- jnp.zeros((1, 32)),
- ]
- binary = iree.jax.aot(pytree_func, *trace_args, target_backends=["vmvx"])
-
- def test_jit_pytree_return(self):
-
- @iree.jax.jit
- def apply_sqrt(pytree):
- return jax.tree_map(jnp.sqrt, pytree)
-
- np.random.seed(0)
- input_tree = {
- "a": [
- normal((2, 3)),
- {
- "b": normal(3)
- },
- ],
- "c": (
- {
- "d": [normal(2), normal(3)]
- },
- (normal(1), normal(4)),
- )
- }
-
- expected = jax.tree_map(jnp.sqrt, input_tree)
- expected_arrays, expected_tree = jax.tree_flatten(expected)
- result = apply_sqrt(input_tree)
- result_arrays, result_tree = jax.tree_flatten(result)
-
- self.assertEqual(expected_tree, result_tree)
- for expected_array, result_array in zip(expected_arrays, result_arrays):
- np.testing.assert_allclose(expected_array, result_array, **TOLERANCE)
-
- def test_iree_jit_of_iree_jit(self):
-
- @iree.jax.jit
- def add(a, b):
- return a + b
-
- @iree.jax.jit
- def mul_two(a):
- return add(a, a)
-
- self.assertEqual(mul_two(3), 6)
-
- def test_jax_jit_of_iree_jit(self):
-
- @iree.jax.jit
- def add(a, b):
- return a + b
-
- @jax.jit
- def mul_two(a):
- return add(a, a)
-
- self.assertEqual(mul_two(3), 6)
-
- def test_iree_jit_of_jax_jit(self):
-
- @jax.jit
- def add(a, b):
- return a + b
-
- @iree.jax.jit
- def mul_two(a):
- return add(a, a)
-
- self.assertEqual(mul_two(3), 6)
-
- def test_iree_jit_of_empty_iree_jit(self):
-
- @iree.jax.jit
- def sqrt_four():
- return jnp.sqrt(4)
-
- @iree.jax.jit
- def add_sqrt_four(a):
- return a + sqrt_four()
-
- self.assertEqual(add_sqrt_four(2), 4)
-
- def test_jit_pytree_method(self):
-
- @iree.jax.jit
- def apply_node(node, z):
- return node.apply(z)
-
- expected_sqrt = apply_node._function(SqrtNode(2, 3), 4)
- compiled_sqrt = apply_node(SqrtNode(2, 3), 4)
- np.testing.assert_allclose(compiled_sqrt, expected_sqrt, **TOLERANCE)
-
- expected_square = apply_node._function(SquareNode(2, 3), 4)
- compiled_square = apply_node(SquareNode(2, 3), 4)
- np.testing.assert_allclose(compiled_square, expected_square, **TOLERANCE)
-
-
-if __name__ == "__main__":
- jax.tree_util.register_pytree_node_class(SqrtNode)
- jax.tree_util.register_pytree_node_class(SquareNode)
- absltest.main()
diff --git a/bindings/python/iree/jax/setup.py.in b/bindings/python/iree/jax/setup.py.in
deleted file mode 100644
index db31229..0000000
--- a/bindings/python/iree/jax/setup.py.in
+++ /dev/null
@@ -1,45 +0,0 @@
-#!/usr/bin/python3
-
-# Copyright 2021 The IREE Authors
-#
-# Licensed under the Apache License v2.0 with LLVM Exceptions.
-# See https://llvm.org/LICENSE.txt for license information.
-# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
-
-from distutils.command.install import install
-import os
-import platform
-from setuptools import setup, find_namespace_packages
-
-with open(os.path.join(os.path.dirname(__file__), "README.md"), "r") as f:
- README = f.read()
-
-exe_suffix = ".exe" if platform.system() == "Windows" else ""
-
-setup(
- name="iree-jax@IREE_RELEASE_PACKAGE_SUFFIX@",
- version="@IREE_RELEASE_VERSION@",
- author="The IREE Team",
- author_email="iree-discuss@googlegroups.com",
- license="Apache",
- description="IREE JAX API",
- long_description=README,
- long_description_content_type="text/markdown",
- url="https://github.com/google/iree",
- classifiers=[
- "Programming Language :: Python :: 3",
- "License :: OSI Approved :: Apache License",
- "Operating System :: OS Independent",
- "Development Status :: 3 - Alpha",
- ],
- python_requires=">=3.7",
- packages=find_namespace_packages(include=["iree.jax"]),
- zip_safe=True,
- install_requires = [
- "jax",
- "jaxlib",
- "iree-compiler@IREE_RELEASE_PACKAGE_SUFFIX@==@IREE_RELEASE_VERSION@",
- "iree-runtime@IREE_RELEASE_PACKAGE_SUFFIX@==@IREE_RELEASE_VERSION@",
- "iree-tools-xla@IREE_RELEASE_PACKAGE_SUFFIX@==@IREE_RELEASE_VERSION@",
- ],
-)
diff --git a/bindings/python/iree/jax/version.py.in b/bindings/python/iree/jax/version.py.in
deleted file mode 100644
index ed72ac3..0000000
--- a/bindings/python/iree/jax/version.py.in
+++ /dev/null
@@ -1,9 +0,0 @@
-# Copyright 2021 The IREE Authors
-#
-# Licensed under the Apache License v2.0 with LLVM Exceptions.
-# See https://llvm.org/LICENSE.txt for license information.
-# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
-
-PACKAGE_SUFFIX = "@IREE_RELEASE_PACKAGE_SUFFIX@"
-VERSION = "@IREE_RELEASE_VERSION@"
-REVISION = "@IREE_RELEASE_REVISION@"
diff --git a/build_tools/github_actions/build_dist.py b/build_tools/github_actions/build_dist.py
index f9e9545..ca72085 100644
--- a/build_tools/github_actions/build_dist.py
+++ b/build_tools/github_actions/build_dist.py
@@ -39,7 +39,6 @@
python ./main_checkout/build_tools/github_actions/build_dist.py main-dist
python ./main_checkout/build_tools/github_actions/build_dist.py py-runtime-pkg
- python ./main_checkout/build_tools/github_actions/build_dist.py py-pure-pkgs
python ./main_checkout/build_tools/github_actions/build_dist.py py-xla-compiler-tools-pkg
python ./main_checkout/build_tools/github_actions/build_dist.py py-tflite-compiler-tools-pkg
python ./main_checkout/build_tools/github_actions/build_dist.py py-tf-compiler-tools-pkg
@@ -168,48 +167,6 @@
tf.add(os.path.join(INSTALL_DIR, entry), arcname=entry, recursive=True)
-def build_py_pure_pkgs():
- """Performs a minimal build sufficient to produce pure python packages.
-
- This installs the following packages:
- - iree-install/python_packages/iree_jax
-
- Since these are pure python packages, it is expected that they will be built
- on a single examplar (i.e. Linux) distribution.
- """
- install_python_requirements()
-
- # Clean up install and build trees.
- shutil.rmtree(INSTALL_DIR, ignore_errors=True)
- remove_cmake_cache()
-
- # CMake configure.
- print("*** Configuring ***")
- subprocess.run([
- sys.executable,
- CMAKE_CI_SCRIPT,
- f"-B{BUILD_DIR}",
- f"-DCMAKE_INSTALL_PREFIX={INSTALL_DIR}",
- f"-DCMAKE_BUILD_TYPE=Release",
- f"-DIREE_BUILD_COMPILER=OFF",
- f"-DIREE_BUILD_PYTHON_BINDINGS=ON",
- f"-DIREE_BUILD_SAMPLES=OFF",
- f"-DIREE_BUILD_TESTS=OFF",
- ],
- check=True)
-
- print("*** Building ***")
- subprocess.run([
- sys.executable,
- CMAKE_CI_SCRIPT,
- "--build",
- BUILD_DIR,
- "--target",
- "install-IreePythonPackage-jax",
- ],
- check=True)
-
-
def build_py_runtime_pkg(instrumented: bool = False):
"""Builds the iree-install/python_packages/iree_runtime package.
@@ -420,8 +377,6 @@
build_py_runtime_pkg()
elif command == "instrumented-py-runtime-pkg":
build_py_runtime_pkg(instrumented=True)
-elif command == "py-pure-pkgs":
- build_py_pure_pkgs()
elif command == "py-xla-compiler-tools-pkg":
build_py_xla_compiler_tools_pkg()
elif command == "py-tflite-compiler-tools-pkg":