Remove the legacy Jax API from the iree repo. (#8031)

* Remove the legacy Jax API from the iree repo.

This is being landed in the new iree-jax repo, alongside a more modern API: https://github.com/google/iree-jax/pull/2
diff --git a/.github/workflows/build_package.yml b/.github/workflows/build_package.yml
index 724c4c3..8c096ba 100644
--- a/.github/workflows/build_package.yml
+++ b/.github/workflows/build_package.yml
@@ -211,18 +211,6 @@
           export CIBW_BEFORE_BUILD="python -m pip install --upgrade pip==21.3dev0 -f https://github.com/stellaraccident/pip/releases/tag/21.3dev20210925"
           python -m cibuildwheel --output-dir bindist ./main_checkout/llvm-external-projects/iree-compiler-api
 
-      # Pure python packages only need to do a minimal CMake configure and no
-      # actual building, aside from setting up sources. If there are multiples,
-      # it is fairly cheap to build them serially. Using cibuildwheel is a bit
-      # overkill for this, but we keep it the same as the others for maintenance
-      # value.
-      - name: Build pure python wheels
-        if: "matrix.build_package == 'py-pure-pkgs'"
-        shell: bash
-        run: |
-          python ./main_checkout/build_tools/github_actions/build_dist.py py-pure-pkgs
-          python -m pip wheel -w bindist --no-deps ./iree-install/python_packages/iree_jax
-
       # Compiler tools wheels are not python version specific, so just build
       # for one examplar python version.
       - name: Build XLA Compiler Tools wheels
diff --git a/.github/workflows/validate_and_publish_release.yml b/.github/workflows/validate_and_publish_release.yml
index 1236c48..a51386b 100644
--- a/.github/workflows/validate_and_publish_release.yml
+++ b/.github/workflows/validate_and_publish_release.yml
@@ -41,7 +41,7 @@
       - name: Install python packages
         id: install_python_packages
         run: |
-          python -m pip install -f file://$PWD/artifact/ iree-compiler-snapshot iree-jax-snapshot iree-runtime-snapshot iree-tools-tflite-snapshot iree-tools-tf-snapshot iree-tools-xla-snapshot
+          python -m pip install -f file://$PWD/artifact/ iree-compiler-snapshot iree-runtime-snapshot iree-tools-tflite-snapshot iree-tools-tf-snapshot iree-tools-xla-snapshot
       - name: Run iree-benchmark-module
         id: run_iree_benchmark_module
         run: ./bin/iree-benchmark-module --help
diff --git a/CMakeLists.txt b/CMakeLists.txt
index 3e67822..ac03a37 100644
--- a/CMakeLists.txt
+++ b/CMakeLists.txt
@@ -42,7 +42,6 @@
 option(IREE_BUILD_SAMPLES "Builds IREE sample projects." ON)
 option(IREE_BUILD_TRACY "Builds tracy server tools." OFF)
 
-option(IREE_BUILD_LEGACY_JAX "Builds the legacy JAX Python API" ON)
 option(IREE_BUILD_TENSORFLOW_ALL "Builds all TensorFlow compiler frontends." OFF)
 option(IREE_BUILD_TENSORFLOW_COMPILER "Builds TensorFlow compiler frontend." "${IREE_BUILD_TENSORFLOW_ALL}")
 option(IREE_BUILD_TFLITE_COMPILER "Builds the TFLite compiler frontend." "${IREE_BUILD_TENSORFLOW_ALL}")
diff --git a/bindings/python/CMakeLists.txt b/bindings/python/CMakeLists.txt
index 77bf1ce..545a800 100644
--- a/bindings/python/CMakeLists.txt
+++ b/bindings/python/CMakeLists.txt
@@ -11,8 +11,3 @@
 
 # Namespace packages.
 add_subdirectory(iree/runtime)
-
-if(IREE_BUILD_LEGACY_JAX)
-  message(STATUS "Building legacy JAX API")
-  add_subdirectory(iree/jax)
-endif()
diff --git a/bindings/python/iree/jax/CMakeLists.txt b/bindings/python/iree/jax/CMakeLists.txt
deleted file mode 100644
index 08bd99a..0000000
--- a/bindings/python/iree/jax/CMakeLists.txt
+++ /dev/null
@@ -1,31 +0,0 @@
-# Copyright 2021 The IREE Authors
-#
-# Licensed under the Apache License v2.0 with LLVM Exceptions.
-# See https://llvm.org/LICENSE.txt for license information.
-# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
-
-iree_py_library(
-  NAME
-    jax
-  SRCS
-    "__init__.py"
-    "frontend.py"
-)
-
-# Only enable the tests if the XLA compiler is built.
-if(${IREE_BUILD_XLA_COMPILER})
-iree_py_test(
-  NAME
-    frontend_test
-  SRCS
-    "frontend_test.py"
-)
-endif()
-
-iree_py_install_package(
-  COMPONENT IreePythonPackage-jax
-  PACKAGE_NAME iree_jax
-  MODULE_PATH iree/jax
-  ADDL_PACKAGE_FILES
-    ${CMAKE_CURRENT_SOURCE_DIR}/README.md
-)
diff --git a/bindings/python/iree/jax/README.md b/bindings/python/iree/jax/README.md
deleted file mode 100644
index cba2485..0000000
--- a/bindings/python/iree/jax/README.md
+++ /dev/null
@@ -1,106 +0,0 @@
-# IREE–JAX Frontend
-
-## Requirements
-
-A local JAX installation is necessary in addition to IREE's Python requirements:
-
-```shell
-python -m pip install jax jaxlib
-```
-
-## Just In Time Compilation with Runtime Bindings
-
-A just-in-time compilation decorator similar to `jax.jit` is provided by
-`iree.jax.jit`:
-
-```python
-import iree.jax
-
-import jax
-import jax.numpy as jnp
-
-# 'backend' is one of 'vmvx', 'llvmaot' and 'vulkan' and defaults to 'llvmaot'.
-@iree.jax.jit(backend="llvmaot")
-def linear_relu_layer(params, x):
-  w, b = params
-  return jnp.max(jnp.matmul(x, w) + b, 0)
-
-w = jnp.zeros((784, 128))
-b = jnp.zeros(128)
-x = jnp.zeros((1, 784))
-
-linear_relu_layer([w, b], x)
-```
-
-## Ahead of Time Compilation
-
-An ahead-of-time compilation function provides a lower-level API for compiling a
-function with a specific input signature without creating the runtime bindings
-for execution within Python. This is primarily useful for targeting other
-runtime environments like Android.
-
-### Example: Compile a MLP and run it on Android
-
-Install the Android NDK according to the
-[Android Getting Started](https://google.github.io/iree/get-started/getting-started-android-cmake)
-doc, and then ensure the following environment variable is set:
-
-```shell
-export ANDROID_NDK=# NDK install location
-```
-
-The code below assumes that you have `flax` installed.
-
-```python
-import iree.jax
-
-import jax
-import jax.numpy as jnp
-import flax
-from flax import linen as nn
-
-
-class MLP(nn.Module):
-
-  @nn.compact
-  def __call__(self, x):
-    x = x.reshape((x.shape[0], -1))  # Flatten.
-    x = nn.Dense(128)(x)
-    x = nn.relu(x)
-    x = nn.Dense(10)(x)
-    x = nn.log_softmax(x)
-    return x
-
-
-image = jnp.zeros((1, 28, 28, 1))
-params = MLP().init(jax.random.PRNGKey(0), image)["params"]
-
-apply_args = [{"params": params}, image]
-options = dict(target_backends=["dylib-llvm-aot"],
-               extra_args=["--iree-llvm-target-triple=aarch64-linux-android"])
-
-compiled_binary = iree.jax.aot(MLP().apply, *apply_args, **options)
-
-with open("/tmp/mlp_apply.vmfb", "wb") as f:
-  f.write(compiled_binary)
-```
-
-IREE doesn't provide installable tools for Android at this time, so they'll need
-to be built according to the
-[Android Getting Started](https://google.github.io/iree/get-started/getting-started-android-cmake).
-Afterward, the compiled `.vmfb` can be pushed to an Android device and executed
-using `iree-run-module`:
-
-```shell
-adb push /tmp/mlp_apply.vmfb /data/local/tmp/
-adb push ../iree-build-android/iree/tools/iree-run-module /data/local/tmp/
-adb shell /data/local/tmp/iree-run-module \
-  --driver=dylib \
-  --module_file=/data/local/tmp/mlp_apply.vmfb \
-  --entry_function=main \
-  --function_input=128xf32 \
-  --function_input=784x128xf32 \
-  --function_input=10xf32 \
-  --function_input=128x10xf32 \
-  --function_input=1x28x28x1xf32
-```
diff --git a/bindings/python/iree/jax/__init__.py b/bindings/python/iree/jax/__init__.py
deleted file mode 100644
index c724179..0000000
--- a/bindings/python/iree/jax/__init__.py
+++ /dev/null
@@ -1,7 +0,0 @@
-# Copyright 2021 The IREE Authors
-#
-# Licensed under the Apache License v2.0 with LLVM Exceptions.
-# See https://llvm.org/LICENSE.txt for license information.
-# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
-
-from .frontend import *
diff --git a/bindings/python/iree/jax/frontend.py b/bindings/python/iree/jax/frontend.py
deleted file mode 100644
index 8b225ee..0000000
--- a/bindings/python/iree/jax/frontend.py
+++ /dev/null
@@ -1,145 +0,0 @@
-# Copyright 2021 The IREE Authors
-#
-# Licensed under the Apache License v2.0 with LLVM Exceptions.
-# See https://llvm.org/LICENSE.txt for license information.
-# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
-
-import functools
-
-import iree.compiler.xla
-import iree.runtime
-
-try:
-  import jax
-except ModuleNotFoundError as e:
-  raise ModuleNotFoundError("iree.jax requires 'jax' and 'jaxlib' to be "
-                            "installed in your python environment.") from e
-
-# pytype thinks iree.jax is jax.
-# pytype: disable=module-attr
-
-__all__ = [
-    "aot",
-    "is_available",
-    "jit",
-]
-
-_BACKEND_TO_TARGETS = {
-    "vmvx": "vmvx",
-    "llvmaot": "dylib-llvm-aot",
-    "vulkan": "vulkan-spirv",
-}
-_BACKENDS = tuple(_BACKEND_TO_TARGETS.keys())
-
-
-def is_available():
-  """Determine if the IREE–XLA compiler are available for JAX."""
-  return iree.compiler.xla.is_available()
-
-
-def aot(function, *args, **options):
-  """Traces and compiles a function, flattening the input args.
-
-  This is intended to be a lower-level interface for compiling a JAX function to
-  IREE without setting up the runtime bindings to use it within Python. A common
-  usecase for this is compiling to Android (and similar targets).
-
-  Args:
-    function: The function to compile.
-    args: The inputs to trace and compile the function for.
-    **kwargs: Keyword args corresponding to xla.ImportOptions or CompilerOptions
-  """
-  xla_comp = jax.xla_computation(function)(*args)
-  hlo_proto = xla_comp.as_serialized_hlo_module_proto()
-  return iree.compiler.xla.compile_str(hlo_proto, **options)
-
-
-# A more JAX-native approach to jitting would be desireable here, however
-# implementing that reasonably would require using JAX internals, particularly
-# jax.linear_util.WrappedFun and helpers. The following is sufficient for many
-# usecases for the time being.
-
-
-class _JittedFunction:
-
-  def __init__(self, function, driver: str, **options):
-    self._function = function
-    self._driver_config = iree.runtime.Config(driver)
-    self._options = options
-    self._memoized_signatures = {}
-
-  def _get_signature(self, args_flat, in_tree):
-    args_flat = [iree.runtime.normalize_value(arg) for arg in args_flat]
-    return tuple((arg.shape, arg.dtype) for arg in args_flat) + (in_tree,)
-
-  def _wrap_and_compile(self, signature, args_flat, in_tree):
-    """Compiles the function for the given signature."""
-
-    def wrapped_function(*args_flat):
-      args, kwargs = jax.tree_unflatten(in_tree, args_flat)
-      return self._function(*args, **kwargs)
-
-    # Compile the wrapped_function to IREE.
-    vm_flatbuffer = aot(wrapped_function, *args_flat, **self._options)
-    vm_module = iree.runtime.VmModule.from_flatbuffer(vm_flatbuffer)
-    module = iree.runtime.load_vm_module(vm_module, config=self._driver_config)
-
-    # Get the output tree so it can be reconstructed from the outputs of the
-    # compiled module. Duplicating execution here isn't ideal, and could
-    # probably be avoided using internal APIs.
-    args, kwargs = jax.tree_unflatten(in_tree, args_flat)
-    _, out_tree = jax.tree_flatten(self._function(*args, **kwargs))
-
-    self._memoized_signatures[signature] = (module, out_tree)
-
-  def _get_compiled_artifacts(self, args, kwargs):
-    """Returns the binary, loaded runtime module and out_tree."""
-    args_flat, in_tree = jax.tree_flatten((args, kwargs))
-    signature = self._get_signature(args_flat, in_tree)
-
-    if signature not in self._memoized_signatures:
-      self._wrap_and_compile(signature, args_flat, in_tree)
-    return self._memoized_signatures[signature]
-
-  def __call__(self, *args, **kwargs):
-    """Executes the function on the provided inputs, compiling if necessary."""
-    args_flat, _ = jax.tree_flatten((args, kwargs))
-    # Use the uncompiled function if the inputs are being traced.
-    if any(issubclass(type(arg), jax.core.Tracer) for arg in args_flat):
-      return self._function(*args, **kwargs)
-
-    module, out_tree = self._get_compiled_artifacts(args, kwargs)
-    results = module.main(*args_flat)
-    if results is not None:
-      if not isinstance(results, tuple):
-        results = (results,)
-      return jax.tree_unflatten(out_tree, results)
-    else:
-      # Address IREE returning None instead of empty sequences.
-      if out_tree == jax.tree_flatten([])[-1]:
-        return []
-      elif out_tree == jax.tree_flatten(())[-1]:
-        return ()
-      else:
-        return results
-
-
-def jit(function=None, *, backend: str = "llvmaot", **options):
-  """Compiles a function to the specified IREE backend."""
-  if function is None:
-    # 'function' will be None if @jit() is called with parens (e.g. to specify a
-    # backend or **options). We return a partial function capturing these
-    # options, which python will then apply as a decorator, and execution will
-    # continue below.
-    return functools.partial(jit, backend=backend, **options)
-
-  # Parse the backend to more concrete compiler and runtime settings.
-  if backend not in _BACKENDS:
-    raise ValueError(
-        f"Expected backend to be one of {_BACKENDS}, but got '{backend}'")
-  target_backend = _BACKEND_TO_TARGETS[backend]
-  driver = iree.runtime.TARGET_BACKEND_TO_DRIVER[target_backend]
-  if "target_backends" not in options:
-    options["target_backends"] = (target_backend,)
-
-  return functools.wraps(function)(_JittedFunction(function, driver, **options))
diff --git a/bindings/python/iree/jax/frontend_test.py b/bindings/python/iree/jax/frontend_test.py
deleted file mode 100644
index 3bff19d..0000000
--- a/bindings/python/iree/jax/frontend_test.py
+++ /dev/null
@@ -1,171 +0,0 @@
-# Copyright 2021 The IREE Authors
-#
-# Licensed under the Apache License v2.0 with LLVM Exceptions.
-# See https://llvm.org/LICENSE.txt for license information.
-# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
-
-from absl.testing import absltest
-import iree.jax
-import iree.runtime
-import jax
-import jax.numpy as jnp
-import numpy as np
-
-# pytype thinks iree.jax is jax.
-# pytype: disable=module-attr
-
-TOLERANCE = {"rtol": 1e-6, "atol": 1e-6}
-
-
-def normal(shape):
-  return np.random.normal(0, 1, shape).astype(np.float32)
-
-
-class SqrtNode:
-
-  def __init__(self, x, y):
-    self.x = x
-    self.y = y
-
-  def apply(self, z):
-    return self.x * jnp.sqrt(self.y * z)
-
-  def tree_flatten(self):
-    return ((self.x, self.y), None)
-
-  @classmethod
-  def tree_unflatten(cls, aux_data, children):
-    return cls(*children)
-
-
-class SquareNode:
-
-  def __init__(self, x, y):
-    self.x = x
-    self.y = y
-
-  def apply(self, z):
-    return self.x * (self.y * z)**2
-
-  def tree_flatten(self):
-    return ((self.x, self.y), None)
-
-  @classmethod
-  def tree_unflatten(cls, aux_data, children):
-    return cls(*children)
-
-
-class JAXFrontendTest(absltest.TestCase):
-
-  def test_aot_pytree(self):
-
-    def pytree_func(params, x):
-      return jnp.max(jnp.matmul(x, params["w"]) + params["b"], 0)
-
-    trace_args = [
-        {
-            "w": jnp.zeros((32, 8)),
-            "b": jnp.zeros((8,))
-        },
-        jnp.zeros((1, 32)),
-    ]
-    binary = iree.jax.aot(pytree_func, *trace_args, target_backends=["vmvx"])
-
-  def test_jit_pytree_return(self):
-
-    @iree.jax.jit
-    def apply_sqrt(pytree):
-      return jax.tree_map(jnp.sqrt, pytree)
-
-    np.random.seed(0)
-    input_tree = {
-        "a": [
-            normal((2, 3)),
-            {
-                "b": normal(3)
-            },
-        ],
-        "c": (
-            {
-                "d": [normal(2), normal(3)]
-            },
-            (normal(1), normal(4)),
-        )
-    }
-
-    expected = jax.tree_map(jnp.sqrt, input_tree)
-    expected_arrays, expected_tree = jax.tree_flatten(expected)
-    result = apply_sqrt(input_tree)
-    result_arrays, result_tree = jax.tree_flatten(result)
-
-    self.assertEqual(expected_tree, result_tree)
-    for expected_array, result_array in zip(expected_arrays, result_arrays):
-      np.testing.assert_allclose(expected_array, result_array, **TOLERANCE)
-
-  def test_iree_jit_of_iree_jit(self):
-
-    @iree.jax.jit
-    def add(a, b):
-      return a + b
-
-    @iree.jax.jit
-    def mul_two(a):
-      return add(a, a)
-
-    self.assertEqual(mul_two(3), 6)
-
-  def test_jax_jit_of_iree_jit(self):
-
-    @iree.jax.jit
-    def add(a, b):
-      return a + b
-
-    @jax.jit
-    def mul_two(a):
-      return add(a, a)
-
-    self.assertEqual(mul_two(3), 6)
-
-  def test_iree_jit_of_jax_jit(self):
-
-    @jax.jit
-    def add(a, b):
-      return a + b
-
-    @iree.jax.jit
-    def mul_two(a):
-      return add(a, a)
-
-    self.assertEqual(mul_two(3), 6)
-
-  def test_iree_jit_of_empty_iree_jit(self):
-
-    @iree.jax.jit
-    def sqrt_four():
-      return jnp.sqrt(4)
-
-    @iree.jax.jit
-    def add_sqrt_four(a):
-      return a + sqrt_four()
-
-    self.assertEqual(add_sqrt_four(2), 4)
-
-  def test_jit_pytree_method(self):
-
-    @iree.jax.jit
-    def apply_node(node, z):
-      return node.apply(z)
-
-    expected_sqrt = apply_node._function(SqrtNode(2, 3), 4)
-    compiled_sqrt = apply_node(SqrtNode(2, 3), 4)
-    np.testing.assert_allclose(compiled_sqrt, expected_sqrt, **TOLERANCE)
-
-    expected_square = apply_node._function(SquareNode(2, 3), 4)
-    compiled_square = apply_node(SquareNode(2, 3), 4)
-    np.testing.assert_allclose(compiled_square, expected_square, **TOLERANCE)
-
-
-if __name__ == "__main__":
-  jax.tree_util.register_pytree_node_class(SqrtNode)
-  jax.tree_util.register_pytree_node_class(SquareNode)
-  absltest.main()
diff --git a/bindings/python/iree/jax/setup.py.in b/bindings/python/iree/jax/setup.py.in
deleted file mode 100644
index db31229..0000000
--- a/bindings/python/iree/jax/setup.py.in
+++ /dev/null
@@ -1,45 +0,0 @@
-#!/usr/bin/python3
-
-# Copyright 2021 The IREE Authors
-#
-# Licensed under the Apache License v2.0 with LLVM Exceptions.
-# See https://llvm.org/LICENSE.txt for license information.
-# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
-
-from distutils.command.install import install
-import os
-import platform
-from setuptools import setup, find_namespace_packages
-
-with open(os.path.join(os.path.dirname(__file__), "README.md"), "r") as f:
-  README = f.read()
-
-exe_suffix = ".exe" if platform.system() == "Windows" else ""
-
-setup(
-    name="iree-jax@IREE_RELEASE_PACKAGE_SUFFIX@",
-    version="@IREE_RELEASE_VERSION@",
-    author="The IREE Team",
-    author_email="iree-discuss@googlegroups.com",
-    license="Apache",
-    description="IREE JAX API",
-    long_description=README,
-    long_description_content_type="text/markdown",
-    url="https://github.com/google/iree",
-    classifiers=[
-        "Programming Language :: Python :: 3",
-        "License :: OSI Approved :: Apache License",
-        "Operating System :: OS Independent",
-        "Development Status :: 3 - Alpha",
-    ],
-    python_requires=">=3.7",
-    packages=find_namespace_packages(include=["iree.jax"]),
-    zip_safe=True,
-    install_requires = [
-      "jax",
-      "jaxlib",
-      "iree-compiler@IREE_RELEASE_PACKAGE_SUFFIX@==@IREE_RELEASE_VERSION@",
-      "iree-runtime@IREE_RELEASE_PACKAGE_SUFFIX@==@IREE_RELEASE_VERSION@",
-      "iree-tools-xla@IREE_RELEASE_PACKAGE_SUFFIX@==@IREE_RELEASE_VERSION@",
-    ],
-)
diff --git a/bindings/python/iree/jax/version.py.in b/bindings/python/iree/jax/version.py.in
deleted file mode 100644
index ed72ac3..0000000
--- a/bindings/python/iree/jax/version.py.in
+++ /dev/null
@@ -1,9 +0,0 @@
-# Copyright 2021 The IREE Authors
-#
-# Licensed under the Apache License v2.0 with LLVM Exceptions.
-# See https://llvm.org/LICENSE.txt for license information.
-# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
-
-PACKAGE_SUFFIX = "@IREE_RELEASE_PACKAGE_SUFFIX@"
-VERSION = "@IREE_RELEASE_VERSION@"
-REVISION = "@IREE_RELEASE_REVISION@"
diff --git a/build_tools/github_actions/build_dist.py b/build_tools/github_actions/build_dist.py
index f9e9545..ca72085 100644
--- a/build_tools/github_actions/build_dist.py
+++ b/build_tools/github_actions/build_dist.py
@@ -39,7 +39,6 @@
 
   python ./main_checkout/build_tools/github_actions/build_dist.py main-dist
   python ./main_checkout/build_tools/github_actions/build_dist.py py-runtime-pkg
-  python ./main_checkout/build_tools/github_actions/build_dist.py py-pure-pkgs
   python ./main_checkout/build_tools/github_actions/build_dist.py py-xla-compiler-tools-pkg
   python ./main_checkout/build_tools/github_actions/build_dist.py py-tflite-compiler-tools-pkg
   python ./main_checkout/build_tools/github_actions/build_dist.py py-tf-compiler-tools-pkg
@@ -168,48 +167,6 @@
       tf.add(os.path.join(INSTALL_DIR, entry), arcname=entry, recursive=True)
 
 
-def build_py_pure_pkgs():
-  """Performs a minimal build sufficient to produce pure python packages.
-
-  This installs the following packages:
-    - iree-install/python_packages/iree_jax
-
-  Since these are pure python packages, it is expected that they will be built
-  on a single examplar (i.e. Linux) distribution.
-  """
-  install_python_requirements()
-
-  # Clean up install and build trees.
-  shutil.rmtree(INSTALL_DIR, ignore_errors=True)
-  remove_cmake_cache()
-
-  # CMake configure.
-  print("*** Configuring ***")
-  subprocess.run([
-      sys.executable,
-      CMAKE_CI_SCRIPT,
-      f"-B{BUILD_DIR}",
-      f"-DCMAKE_INSTALL_PREFIX={INSTALL_DIR}",
-      f"-DCMAKE_BUILD_TYPE=Release",
-      f"-DIREE_BUILD_COMPILER=OFF",
-      f"-DIREE_BUILD_PYTHON_BINDINGS=ON",
-      f"-DIREE_BUILD_SAMPLES=OFF",
-      f"-DIREE_BUILD_TESTS=OFF",
-  ],
-                 check=True)
-
-  print("*** Building ***")
-  subprocess.run([
-      sys.executable,
-      CMAKE_CI_SCRIPT,
-      "--build",
-      BUILD_DIR,
-      "--target",
-      "install-IreePythonPackage-jax",
-  ],
-                 check=True)
-
-
 def build_py_runtime_pkg(instrumented: bool = False):
   """Builds the iree-install/python_packages/iree_runtime package.
 
@@ -420,8 +377,6 @@
   build_py_runtime_pkg()
 elif command == "instrumented-py-runtime-pkg":
   build_py_runtime_pkg(instrumented=True)
-elif command == "py-pure-pkgs":
-  build_py_pure_pkgs()
 elif command == "py-xla-compiler-tools-pkg":
   build_py_xla_compiler_tools_pkg()
 elif command == "py-tflite-compiler-tools-pkg":