Remove experimental/dispatch_profiler. (#17287)

This code hasn't been touched in ~8 months and the workflow has been
failing for ~4 months:
https://github.com/iree-org/iree/actions/workflows/run_iree_dispatch_profiler.yml.
The code will still be useful to reference in git history, but deleting
it will keep the repository tidier and save us ~10 minutes of CI time
per day across large Linux build machines and a2-highgpu-1g (A100, which
is expensive).

Separately, someone who manages the
https://storage.googleapis.com/dispatch-profiler-artifacts cloud bucket
may want to check if there are ongoing storage costs there (may want to
delete existing files or add a TTL).

See also the discussion [here on
Discord](https://discord.com/channels/689900678990135345/689906000043573354/1237060498831315024).
diff --git a/.github/CODEOWNERS b/.github/CODEOWNERS
index 094306c..a2ae5da 100644
--- a/.github/CODEOWNERS
+++ b/.github/CODEOWNERS
@@ -32,7 +32,6 @@
 # Experimental
 # It's experimental, but we still don't want any old directory added here.
 /experimental/ @benvanik @stellaraccident
-/experimental/dispatch_profiler/ @manishucsd
 /experimental/rocm/ @benvanik
 /experimental/web/ @ScottTodd
 /experimental/webgpu/ @benvanik @ScottTodd
diff --git a/.github/workflows/run_iree_dispatch_profiler.yml b/.github/workflows/run_iree_dispatch_profiler.yml
deleted file mode 100644
index 7335df0..0000000
--- a/.github/workflows/run_iree_dispatch_profiler.yml
+++ /dev/null
@@ -1,79 +0,0 @@
-# Copyright 2023 The IREE Authors
-#
-# Licensed under the Apache License v2.0 with LLVM Exceptions.
-# See https://llvm.org/LICENSE.txt for license information.
-# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
-#
-# Workflow for running IREE Dispatch Profiler
-# (https://github.com/iree-org/iree/tree/main/experimental/dispatch_profiler).
-#
-# IREE Dispatch Profiler validates functional correctness and performs profiling
-# on matmul, conv, and other ops on IREE CUDA and other back-ends.
-#
-# It uploads results to this a publicly accessible GCS bucket:
-# https://storage.googleapis.com/dispatch-profiler-artifacts
-
-name: dispatch_profiler
-
-on:
-  schedule:
-    # Run daily at 2:00 PM (14:00) every weekday (Monday to Friday)
-    - cron: "0 14 * * 1-5"
-  workflow_dispatch:
-
-concurrency:
-  # A PR number if a pull request and otherwise the commit hash. This cancels
-  # queued and in-progress runs for the same PR (presubmit) or commit
-  # (postsubmit). The workflow name is prepended to avoid conflicts between
-  # different workflows.
-  group: ${{ github.workflow }}-${{ github.event.number || github.sha }}
-  cancel-in-progress: true
-
-jobs:
-  setup:
-    uses: ./.github/workflows/setup.yml
-
-  build_all:
-    needs: setup
-    if: contains(fromJson(needs.setup.outputs.enabled-jobs), 'build_all')
-    uses: ./.github/workflows/build_all.yml
-    with:
-      runner-group: ${{ needs.setup.outputs.runner-group }}
-      runner-env: ${{ needs.setup.outputs.runner-env }}
-      write-caches: ${{ needs.setup.outputs.write-caches }}
-
-  profile_cuda:
-    needs: [setup, build_all]
-    if: contains(fromJson(needs.setup.outputs.enabled-jobs), 'profile_cuda')
-    runs-on:
-      - self-hosted # must come first
-      - runner-group=${{ needs.setup.outputs.runner-group }}
-      - environment=${{ needs.setup.outputs.runner-env }}
-      - machine-type=a2-highgpu-1g
-    env:
-      RESULTS_DIR: dispatch-profiler-results-cuda
-      GCS_UPLOAD_PARENT_DIR: "gs://dispatch-profiler-artifacts/cuda"
-      GCS_UPLOAD_DIR_NAME: ${{ needs.setup.outputs.artifact-upload-dir }}
-      INSTALL_DIR: ${{ needs.build_all.outputs.install-dir }}
-      INSTALL_DIR_ARCHIVE: ${{ needs.build_all.outputs.install-dir-archive }}
-      INSTALL_DIR_GCS_ARTIFACT: ${{ needs.build_all.outputs.install-dir-gcs-artifact }}
-      IREE_SHA: ${{ github.sha }}
-    steps:
-      - name: "Checking out repository"
-        uses: actions/checkout@8f4b7f84864484a7bf31766abe9204da3cbe65b3 # v3.5.0
-      - name: "Downloading install dir archive"
-        run: gcloud storage cp "${INSTALL_DIR_GCS_ARTIFACT}" "${INSTALL_DIR_ARCHIVE}"
-      - name: "Extracting install directory"
-        run: tar -vxf "${INSTALL_DIR_ARCHIVE}"
-      - name: "Running IREE Dispatch Profiler on CUDA"
-        run: |
-          mkdir "${RESULTS_DIR}"
-          ./build_tools/github_actions/docker_run.sh \
-            --gpus all \
-            gcr.io/iree-oss/nvidia-bleeding-edge@sha256:81b3b5485f962c978bb7e5b2a6ded44ae4ef432048cafffe2b74fcf6dbe1bbca \
-            ./experimental/dispatch_profiler/profile_all.sh "${INSTALL_DIR}/bin" \
-            "${RESULTS_DIR}"
-      - name: "Uploading results"
-        run: |
-          GCS_ARTIFACT_DIR="$(date +'%Y-%m-%d').sha_${IREE_SHA}.timestamp_$(date +'%s')"
-          gcloud storage cp "${RESULTS_DIR}/**" "${GCS_UPLOAD_PARENT_DIR}/${GCS_ARTIFACT_DIR}/"
diff --git a/experimental/dispatch_profiler/README.md b/experimental/dispatch_profiler/README.md
deleted file mode 100644
index c09152d..0000000
--- a/experimental/dispatch_profiler/README.md
+++ /dev/null
@@ -1,173 +0,0 @@
-# IREE Dispatch Profiler
-
-The IREE Dispatch Profiler is a Python-based tool designed to achieve two primary objectives: functional verification and performance profiling for individual dispatches, such as matrix multiplication, batch matrix multiplication, and convolutions. This tool ensures that performance optimizations maintain functionality and provides a convenient way to quantitatively measure performance. Additionally, the tool offers dispatch generation and compilation capabilities. In summary, the IREE dispatch profiler accomplishes the following:
-
-- Auto-generation of MLIR dispatches (e.g., matmul, batch_matmul, convolutions, fused dispatches).
-- Compilation of generated MLIR dispatches into binaries (vmfb).
-- Functional verification against Python-based reference implementations.
-- Performance profiling and reporting.
-
-## Definitions
-
-- Operation: An operation structure captures and refers to the functional description of an operation. For example, a Matmul operation includes the datatype, layout, and matrix multiplication problem shape.
-- Tuning Configuration: Tuning configurations are attributes applied to the IREE compilation flow that can alter the performance of the compiled dispatch without affecting its functionality.
-- Dispatch: A dispatch is a combination of an operation and its corresponding tuning configuration.
-
-## Auto-generation of MLIR Dispatches
-
-IREE dispatch profiler provides [`generator.py`](generator.py) that can be used to generate dispatches. Please find a sample run below:
-
-```bash
-$ python3 dispatch_profiler/generator.py --generated-dir </path/to/create/`generated`/dir>
-[Generating]: ./generated/linalg/matmul/matmul_128x128x256_f16t_f16t_f16t/matmul_128x128x256_f16t_f16t_f16t.mlir
-    Emitting tuning configuration : tile_config_128x128_64x4_tensorcore_mmasync
-    Emitting tuning configuration : tile_config_128x128_32x5_tensorcore_mmasync
-    Emitting tuning configuration : tile_config_128x64_32x5_tensorcore_mmasync
-    Emitting tuning configuration : tile_config_64x64_64x5_tensorcore_mmasync
-    Emitting tuning configuration : tile_config_64x64_32x10_tensorcore_mmasync
-    ...
-```
-
-This creates a `generated` folder containing dispatches organized in folders as `mlir_dialect/operation_name/`. The folder includes an .mlir file with all the dispatches for an operation.
-
-The `generator.py` script serves as a generator for implemented operation data types, using a predefined list of problem shapes. You can also provide specific matrix multiplication shapes of interest. Examples are provided below.
-
-#### Generating user-specified matmul shape `768x512x1024`
-
-```bash
-python3 ../iree/experimental/dispatch_profiler/generator.py --generated-dir </path/to/create/`generated`/dir> --problem-m=768 --problem-n=512 --problem-k=1024
-...
-[Generating]: ./generated/linalg/matmul/matmul_768x512x1024_f16t_f16t_f16t/matmul_768x512x1024_f16t_f16t_f16t.mlir
-[Generating]: ./generated/linalg/matmul/matmul_768x512x1024_f32t_f32t_f32t/matmul_768x512x1024_f32t_f32t_f32t.mlir
-...
-```
-
-#### Generate a user-specified sweep of matmul shapes
-
-Generate matmuls where M ranges from 64 to 1024 in increments of 128, N varies from 64 to 1024 in steps of 128, and K is fixed at 4096.
-
-```bash
-$ python3 ../iree/experimental/dispatch_profiler/generator.py --generated-dir </path/to/create/`generated`/dir> --problem-m=64:1024:128 --problem-n=64:1024:128 --problem-k=4096
-...
-```
-
-## Compilation of generated MLIR dispatches into binaries (vmfb)
-
-IREE dispatch profiler provies `compile.py` that trigges `iree-compile` with appropiate compilation flags. The output of `iree-compile` vmfb files are placed in `mlir_dialect/operation_path/operation_name.mlir`. The `compiler.py` uses all the possible cpus on your machine to compile all different generated mlir source files.
-
-```bash
-python3 ../iree/experimental/dispatch_profiler/compile.py --build-dir </path/to/iree/build/dir> --generated-dir </path/to/create/`generated`/dir>
-```
-
-Compiles all the generated source mlir dispatches. One can check the generated dispatched folder to find the vmfb files.
-
-```bash
-$ ls ./generated/linalg/matmul/matmul_64x64x4096_f16t_f16t_f16t/
-iree_compile_cmd_stdout.mlir  matmul_64x64x4096_f16t_f16t_f16t.mlir  matmul_64x64x4096_f16t_f16t_f16t_profile.vmfb  matmul_64x64x4096_f16t_f16t_f16t_verify.vmfb
-```
-
-## Functional verification and performance profiling
-
-The tool provides [`profiler.py`](profiler.py) script which can be used to trigger both verification and profiler for all the compiled dispatches. Please find some example profiling commandlines below:
-
-### Functional verification and performance profiling of a _single_ dispatch
-
-```
-$ python3 profiler.py --build-dir </path/to/iree/build/dir> --generated-dir </path/to/create/`generated`/dir> --dispatches=matmul_3456x1024x2048_f16t_f16t_f16t_tile_config_128x128_32x5_tensorcore_mmasync --verification-enabled=true --profiling-enabled=true
----------------------------------------------------------------- 
-Dispatch      : matmul_3456x1024x2048_f16t_f16t_f16t_tile_config_128x128_32x5_tensorcore_mmasync
-Provider      : IREE Codegen
-OpKind        : OperationKind.Matmul
-Operation     : matmul_3456x1024x2048_f16t_f16t_f16t
-Configuration : tile_config_128x128_32x5_tensorcore_mmasync
-Arguments     : --batch_count=1 --m=3456 --n=1024 --k=2048 --lhs=f16t --rhs=f16t --result=f16t
-                --split_k_mode=N/A --split_k_slices=N/A
-Verification  : SUCCESS
-Runtime(ms)   : 0.062
-GFLOPs        : 233798.62
-```
-
-### Performance profiling _single_ dispatch
-
-Verification, particularly for large matrix multiplications, can be time-consuming when using a CPU-based numpy reference. To prioritize profiling speed and when functional correctness is assured, disable verification using `--verification-enabled=false`.
-
-```bash
-python3 profiler.py --build-dir </path/to/iree/build/dir> --generated-dir </path/to/create/`generated`/dir> --dispatches=matmul_3456x1024x2048_f16t_f16t_f16t_tile_config_128x128_32x5_tensorcore_mmasync --verification-enabled=false --profiling-enabled=true
-```
-
-### Performance profile _single_ operation and _sweep_ tunning configurations
-
-The `--dispatch` option accepts a comma-separated list of regex patterns to profile all tuning configurations generated for a operation. The command-line argument is formatted as `--dispatch=<regex>,<regex>`. Additionally, you can export the profiled output to a CSV file for further analysis using `--output=<filepath>`.
-
-```bash
-$ python3 profiler.py --build-dir </path/to/iree/build/dir> --generated-dir </path/to/create/`generated`/dir> --dispatches=matmul_3456x1024x2048_f16t_f16t_f16t_*_tensorcore_mmasync --verification-enabled=false --profiling-enabled=true --output=data.csv
----------------------------------------------------------------- 
-Dispatch      : matmul_3456x1024x2048_f16t_f16t_f16t_tile_config_128x256_32x3_tensorcore_mmasync
-Provider      : IREE Codegen
-OpKind        : OperationKind.Matmul
-Operation     : matmul_3456x1024x2048_f16t_f16t_f16t
-Configuration : tile_config_128x256_32x3_tensorcore_mmasync
-Arguments     : --batch_count=1 --m=3456 --n=1024 --k=2048 --lhs=f16t --rhs=f16t --result=f16t
-                --split_k_mode=N/A --split_k_slices=N/A
-Verification  : Not verified
-Runtime(ms)   : 0.062
-GFLOPs        : 233798.62
----------------------------------------------------------------- 
-Dispatch      : matmul_3456x1024x2048_f16t_f16t_f16t_tile_config_128x128_64x4_tensorcore_mmasync
-Provider      : IREE Codegen
-OpKind        : OperationKind.Matmul
-Operation     : matmul_3456x1024x2048_f16t_f16t_f16t
-Configuration : tile_config_128x128_64x4_tensorcore_mmasync
-Arguments     : --batch_count=1 --m=3456 --n=1024 --k=2048 --lhs=f16t --rhs=f16t --result=f16t
-                --split_k_mode=N/A --split_k_slices=N/A
-Verification  : Not verified
-Runtime(ms)   : 0.064
-GFLOPs        : 226492.42
----------------------------------------------------------------- 
-...
-----------------------------------------------------------------
-Dispatch      : matmul_3456x1024x2048_f16t_f16t_f16t_tile_config_64x64_32x10_tensorcore_mmasync
-Provider      : IREE Codegen
-OpKind        : OperationKind.Matmul
-Operation     : matmul_3456x1024x2048_f16t_f16t_f16t
-Configuration : tile_config_64x64_32x10_tensorcore_mmasync
-Arguments     : --batch_count=1 --m=3456 --n=1024 --k=2048 --lhs=f16t --rhs=f16t --result=f16t
-                --split_k_mode=N/A --split_k_slices=N/A
-Verification  : Not verified
-Runtime(ms)   : 0.103
-GFLOPs        : 140733.15
-
-Writing performance report to data.csv
-
-```
-
-### Performance profiling a large matmul targetting _F16_ and _F32_ datatype
-
-Another example showcasing the use of `--dispatch` to profile a matmul_3456x1024x2048 targetting F16 and F32 NVIDIA A100 Tensor Cores.
-
-```bash
-$ python3 profiler.py --build-dir </path/to/iree/build/dir> --generated-dir </path/to/create/`generated`/dir>  --dispatches=matmul_3456x1024x2048_f16t_f16t_f16t_tile_config_128x128_32x5_tensorcore_mmasync,matmul_3456x1024x2048_f32t_f32t_f32t_tile_config_128x128_16x5_tensorcore_mmasync 
----------------------------------------------------------------- 
-Dispatch      : matmul_3456x1024x2048_f16t_f16t_f16t_tile_config_128x128_32x5_tensorcore_mmasync
-Provider      : IREE Codegen
-OpKind        : OperationKind.Matmul
-Operation     : matmul_3456x1024x2048_f16t_f16t_f16t
-Configuration : tile_config_128x128_32x5_tensorcore_mmasync
-Arguments     : --batch_count=1 --m=3456 --n=1024 --k=2048 --lhs=f16t --rhs=f16t --result=f16t
-                --split_k_mode=N/A --split_k_slices=N/A
-Verification  : SUCCESS
-Runtime(ms)   : 0.062
-GFLOPs        : 233798.62
----------------------------------------------------------------- 
-Dispatch      : matmul_3456x1024x2048_f32t_f32t_f32t_tile_config_128x128_16x5_tensorcore_mmasync
-Provider      : IREE Codegen
-OpKind        : OperationKind.Matmul
-Operation     : matmul_3456x1024x2048_f32t_f32t_f32t
-Configuration : tile_config_128x128_16x5_tensorcore_mmasync
-Arguments     : --batch_count=1 --m=3456 --n=1024 --k=2048 --lhs=f32t --rhs=f32t --result=f32t
-                --split_k_mode=N/A --split_k_slices=N/A
-Verification  : SUCCESS
-Runtime(ms)   : 0.122
-GFLOPs        : 118815.69
-----------------------------------------------------------------
-```
diff --git a/experimental/dispatch_profiler/batch_matmul.py b/experimental/dispatch_profiler/batch_matmul.py
deleted file mode 100644
index 9b8c401..0000000
--- a/experimental/dispatch_profiler/batch_matmul.py
+++ /dev/null
@@ -1,264 +0,0 @@
-# Copyright 2022 The IREE Authors
-#
-# Licensed under the Apache License v2.0 with LLVM Exceptions.
-# See https://llvm.org/LICENSE.txt for license information.
-# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
-from library import *
-from dispatch import *
-from matmul import MatmulOperation, MatmulCompilationInfo, CudaMatmulGenerator
-
-
-class BatchMatmulOperation(MatmulOperation):
-    """Data structure to describe a batch matrix multiplication operation."""
-
-    def __init__(self, bmm_shape, lhs, rhs, result):
-        assert len(bmm_shape) == 4, "Batch matmul shape must be 4D"
-        super().__init__(
-            bmm_shape[1:], lhs, rhs, result, bmm_shape[0], 1, OperationKind.BatchMatmul
-        )
-
-    def name(self):
-        return (
-            f"{OperationKindNames[self.operation_kind]}_"
-            f"{self.batch_count}x{self.M}x{self.N}x{self.K}_"
-            f"{DataTypeName[self.lhs.datatype]}{ShortLayoutTypeName[self.lhs.layout]}_"
-            f"{DataTypeName[self.rhs.datatype]}{ShortLayoutTypeName[self.rhs.layout]}_"
-            f"{DataTypeName[self.result.datatype]}{ShortLayoutTypeName[self.result.layout]}"
-        )
-
-    def lhs_npy_shape(self):
-        return f"{self.batch_count}x{super().lhs_npy_shape()}"
-
-    def rhs_npy_shape(self):
-        return f"{self.batch_count}x{super().rhs_npy_shape()}"
-
-    def result_npy_shape(self):
-        return f"{self.batch_count}x{super().result_npy_shape()}"
-
-
-class EmitLinalgBatchMatmulDispatch:
-    """Emitters for the `linalg.batch_matmul` dispatch."""
-
-    def __init__(self):
-        self.mlir_dialect = MlirDialect.Linalg
-
-        self.linalg_row_row_matmul_template = """
-// Dispatch linalg.batch_matmul row-row layout 
-func.func @${operation_name}_${compilation_info_name}(
-  %lhs: tensor<${batch_count}x${problem_m}x${problem_k}x${datatype_lhs}>,
-  %rhs: tensor<${batch_count}x${problem_k}x${problem_n}x${datatype_rhs}>) -> tensor<${batch_count}x${problem_m}x${problem_n}x${datatype_result}>
-{
-  %c0 = arith.constant 0.0 : ${datatype_result}
-  %init = tensor.empty() : tensor<${batch_count}x${problem_m}x${problem_n}x${datatype_result}>
-  %inital_result = linalg.fill ins(%c0 : ${datatype_result}) outs(%init : tensor<${batch_count}x${problem_m}x${problem_n}x${datatype_result}>) -> tensor<${batch_count}x${problem_m}x${problem_n}x${datatype_result}>
-  %result = linalg.batch_matmul ${compilation_info_attribute} 
-                     ins(%lhs, %rhs: tensor<${batch_count}x${problem_m}x${problem_k}x${datatype_lhs}>, tensor<${batch_count}x${problem_k}x${problem_n}x${datatype_rhs}>)
-                     outs(%inital_result: tensor<${batch_count}x${problem_m}x${problem_n}x${datatype_result}>) -> tensor<${batch_count}x${problem_m}x${problem_n}x${datatype_result}>
-  return %result : tensor<${batch_count}x${problem_m}x${problem_n}x${datatype_result}>
-}
-"""
-
-    def emit(self, dispatch):
-        """Emit the matmul operation in the MLIR dialect for a single compilation info"""
-        compilation_info_attribute_template = (
-            """{compilation_info = #${compilation_info_name}}"""
-        )
-        compilation_info_attribute_str = SubstituteTemplate(
-            compilation_info_attribute_template,
-            {"compilation_info_name": dispatch.configuration.name()},
-        )
-        compilation_info_attribute = (
-            compilation_info_attribute_str
-            if dispatch.configuration.config_type != CompilationConfigType.Default
-            else ""
-        )
-
-        values = {
-            "operation_name": dispatch.operation.name(),
-            "compilation_info_attribute": compilation_info_attribute,
-            "batch_count": str(dispatch.operation.batch_count),
-            "problem_m": str(dispatch.operation.M),
-            "problem_n": str(dispatch.operation.N),
-            "problem_k": str(dispatch.operation.K),
-            "datatype_lhs": DataTypeName[dispatch.operation.lhs.datatype],
-            "datatype_rhs": DataTypeName[dispatch.operation.rhs.datatype],
-            "datatype_result": DataTypeName[dispatch.operation.result.datatype],
-            "compilation_info_name": dispatch.configuration.name(),
-        }
-
-        return SubstituteTemplate(self.linalg_row_row_matmul_template, values)
-
-
-class ReferenceBatchMatmulOp(ReferenceOpInterface):
-    """Reference implementation for the batch matmul operation in numpy."""
-
-    def __init__(self, bmm_operation, op_reference_cache_path, dist_lhs, dist_rhs):
-        self.bmm_operation = bmm_operation
-        self.op_reference_cache_path = op_reference_cache_path
-
-        if not self.op_reference_cache_path.exists():
-            self.op_reference_cache_path.mkdir()
-
-        # Problem shape.
-        self.batch_count = bmm_operation.batch_count
-        self.M = bmm_operation.M
-        self.N = bmm_operation.N
-        self.K = bmm_operation.K
-
-        # Data type for the input and result matrices.
-        self.dtype_lhs = DataTypeNumPyTag[bmm_operation.lhs.datatype]
-        self.dtype_rhs = DataTypeNumPyTag[bmm_operation.rhs.datatype]
-        self.dtype_result = DataTypeNumPyTag[bmm_operation.result.datatype]
-
-        # Distribution of the input tensors.
-        self.dist_lhs = dist_lhs
-        self.dist_rhs = dist_rhs
-
-        # Filename for the left hand side input tensor.
-        self.filename_lhs = (
-            "batch_count{batch_count}xm{problem_m}xk{problem_k}_"
-            "{tensor_description}_{dist}_lhs.npy".format(
-                batch_count=self.batch_count,
-                problem_m=self.M,
-                problem_k=self.K,
-                tensor_description=self.bmm_operation.lhs.name(),
-                dist=DistributionName[self.dist_lhs],
-            )
-        )
-
-        # Filename for the right hand side input tensor.
-        self.filename_rhs = (
-            "batch_count{batch_count}xk{problem_k}xn{problem_n}_"
-            "{tensor_description}_{dist}_rhs.npy".format(
-                batch_count=self.batch_count,
-                problem_k=self.K,
-                problem_n=self.N,
-                tensor_description=self.bmm_operation.rhs.name(),
-                dist=DistributionName[self.dist_rhs],
-            )
-        )
-
-        # Filename for the reference result tensor.
-        self.filename_reference_result = (
-            "batch_count{batch_count}xm{problem_m}xn{problem_n}_"
-            "{tensor_description}_reference_result.npy".format(
-                batch_count=self.batch_count,
-                problem_m=self.M,
-                problem_n=self.N,
-                tensor_description=self.bmm_operation.result.name(),
-            )
-        )
-
-        # Filepath for input and output files.
-        self.filepath_lhs = self.op_reference_cache_path.joinpath(self.filename_lhs)
-        self.filepath_rhs = self.op_reference_cache_path.joinpath(self.filename_rhs)
-        self.filepath_reference_result = self.op_reference_cache_path.joinpath(
-            self.filename_reference_result
-        )
-
-    def get_input_filepaths(self):
-        """Returns the list of input file paths."""
-        return [self.filepath_lhs, self.filepath_rhs]
-
-    def get_output_filepaths(self):
-        """Returns the list of expected output file paths."""
-        return [self.filepath_reference_result]
-
-    def __call__(self):
-        """Generates input data, runs reference numpy.matmul, and save npy files to the output directory."""
-        # Generate the input data as np.array for the matmul operation.
-        lhs_np_array = get_np_array(
-            self.bmm_operation.lhs, (self.batch_count, self.M, self.K), self.dist_lhs
-        )
-        rhs_np_array = get_np_array(
-            self.bmm_operation.rhs, (self.batch_count, self.K, self.N), self.dist_rhs
-        )
-
-        # Run the reference np.matmul and generate result np.array.
-        result = np.matmul(lhs_np_array, rhs_np_array)
-
-        # Save the input data as np.array for the matmul operation.
-        np.save(self.filepath_lhs, np.array(lhs_np_array, dtype=self.dtype_lhs))
-        np.save(self.filepath_rhs, np.array(rhs_np_array, dtype=self.dtype_rhs))
-
-        # Save the expected result as an np.array.
-        np.save(
-            self.filepath_reference_result, np.array(result, dtype=self.dtype_result)
-        )
-
-
-##############################################################################
-class CudaBatchMatmulGenerator(CudaMatmulGenerator):
-    """Batch matmul dispatch generator class."""
-
-    def __init__(self, args):
-        """Initializes the batch matmul dispatch generator."""
-        super().__init__(args)
-
-        # Predefined batch matmul problem shapes.
-        self.batch_matmul_shapes = [[16, 512, 64, 512]]
-
-        # Batch matmul dispatches collection.
-        self.dispatches_collection_list = []
-
-    def _append_matmul_dispatch_collection(
-        self, bmm_shapes, data_type, configuration_list
-    ):
-        """Update the batch matmul dispatch collection with the given configuration list."""
-
-        # Create dispatches collection for each problem shape with the configuration list.
-        for bmm_shape in bmm_shapes:
-            operation = BatchMatmulOperation(
-                bmm_shape,
-                TensorDescription(data_type[0], LayoutType.RowMajor),
-                TensorDescription(data_type[1], LayoutType.RowMajor),
-                TensorDescription(data_type[2], LayoutType.RowMajor),
-            )
-
-            # Filter out configurations that are not supported by LLVM GPU CUDA backend.
-            supported_configuration_list = self._cuda_supported_configuration_list(
-                operation, configuration_list
-            )
-
-            # Add default configuration if enabled.
-            if self.args.default_config:
-                supported_configuration_list.append(
-                    MatmulCompilationInfo(
-                        [], [], OperationKind.BatchMatmul, CompilationConfigType.Default
-                    )
-                )
-
-            # Append the dispatches collection.
-            self.dispatches_collection_list.append(
-                DispatchCollection(operation, supported_configuration_list)
-            )
-
-    def _cuda_matmul_tensor_cores_f16(self):
-        """Appends a list of matmul dispatches for GPU TensorCore F16 data type."""
-        configuration_list = self._get_matmul_custom_compilation_info_list(
-            self.tile_descriptions_tensor_cores_f16,
-            self.translation_infos,
-            OperationKind.BatchMatmul,
-        )
-        data_type = [DataType.f16, DataType.f16, DataType.f16]
-        self._append_matmul_dispatch_collection(
-            self.batch_matmul_shapes, data_type, configuration_list
-        )
-
-    def _cuda_matmul_tensor_cores_f32(self):
-        """Appends a list of matmul dispatches for GPU TensorCore F32 data type."""
-        configuration_list = self._get_matmul_custom_compilation_info_list(
-            self.tile_descriptions_tensor_cores_f32,
-            self.translation_infos,
-            OperationKind.BatchMatmul,
-        )
-        data_type = [DataType.f32, DataType.f32, DataType.f32]
-        self._append_matmul_dispatch_collection(
-            self.batch_matmul_shapes, data_type, configuration_list
-        )
-
-    def generate(self):
-        """Generates a list of matmul operations."""
-        self._cuda_matmul_tensor_cores_f16()
-        self._cuda_matmul_tensor_cores_f32()
-        return self.dispatches_collection_list
diff --git a/experimental/dispatch_profiler/compile.py b/experimental/dispatch_profiler/compile.py
deleted file mode 100644
index 12a67bb..0000000
--- a/experimental/dispatch_profiler/compile.py
+++ /dev/null
@@ -1,65 +0,0 @@
-# Copyright 2023 The IREE Authors
-#
-# Licensed under the Apache License v2.0 with LLVM Exceptions.
-# See https://llvm.org/LICENSE.txt for license information.
-# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
-
-import argparse, os
-
-from library import *
-from manifest import *
-from launchers import *
-from concurrent.futures import ThreadPoolExecutor
-from options import parse_compile_arguments
-
-###############################################################################
-# Compile main : The main entry point for the compile tool.
-# This tool compiles IREE-compiled MLIR operations for a given backend device,
-###############################################################################
-
-if __name__ == "__main__":
-    ###############################################################################
-    # Parse command line arguments
-    ###############################################################################
-    parser = argparse.ArgumentParser(
-        description="IREE Python compile tool for launching iree-compile for verification and "
-        "profiling. Issues iree-compile for a given backend device and iree-compile "
-        "flags. Uses ThreadPoolExecutor to launch multiple iree-compile processes "
-        "in parallel."
-    )
-
-    args = parse_compile_arguments(parser)
-    ###############################################################################
-
-    # Manifests metadata for a group of accompanying operations and configurations.
-    manifest = Manifest(args)
-    manifest.load()
-
-    # Try and use all CPUs to launch iree-compile in parallel.
-    cpu_count = os.cpu_count()
-    if args.num_cpu > 0:
-        cpu_count = min(cpu_count, args.num_cpu)
-
-    # For all the operations in the manifest, issue iree-compile for verification
-    # and profiling in parallel using ThreadPoolExecutor and cpu_count threads.
-    cmds = []
-    with ThreadPoolExecutor(max_workers=cpu_count) as executor:
-        # For all the operations in the manifest compile, verify, and profile.
-        for _, dispatch_collection_list in manifest.dispatch_collection_map.items():
-            for dispatch_collection in dispatch_collection_list:
-                # Create an instance of operation_launcher.
-                operation = dispatch_collection.operation
-                operation_launcher = IreeToolsLauncher(args, operation)
-                for configuration in dispatch_collection.configuration_list:
-                    for compile_mode in [
-                        CompilationMode.Profile,
-                        CompilationMode.Verify,
-                    ]:
-                        cmds.append(
-                            executor.submit(
-                                operation_launcher.iree_compile, compile_mode
-                            )
-                        )
-
-    # Wait for all the commands to complete.
-    results = [cmd.result() for cmd in cmds]
diff --git a/experimental/dispatch_profiler/dispatch.py b/experimental/dispatch_profiler/dispatch.py
deleted file mode 100644
index c003271..0000000
--- a/experimental/dispatch_profiler/dispatch.py
+++ /dev/null
@@ -1,63 +0,0 @@
-# Copyright 2023 The IREE Authors
-#
-# Licensed under the Apache License v2.0 with LLVM Exceptions.
-# See https://llvm.org/LICENSE.txt for license information.
-# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
-
-from library import *
-
-
-################################################################################
-class Dispatch:
-    """
-    Dispatch: A combination of an operation and a configuration is launched by
-      the dispatch profiler for verification and performance profiling. Note that
-      a dispatch is not a MLIR operation it is binary executable that is launched
-      by the profiler. Additionaly, the goal of the tool is to also profile the
-      performance of the fusions and a dispatch for fusion is a combination of
-      multiple operations glued together and compiled into a single dispatch.
-    """
-
-    def __init__(self, operation, configuration):
-        self.operation = operation
-        self.configuration = configuration
-        self.is_fused_dispatch = False
-
-    def name(self):
-        return f"{self.operation.name()}_{self.configuration.name()}"
-
-
-################################################################################
-class DispatchCollection:
-    """
-    DispatchCollection: A collection of dispatches that only vary in their
-      configurations but not in their operations. For example, a collection
-      of matmul dispatches with different tile sizes.
-
-      We can emit a single MLIR file for all the dispatches in a collection
-      and compile with single run of iree-compile and them into a single executable
-    """
-
-    def __init__(self, operation, configuration_list):
-        self.operation = operation
-        self.configuration_list = configuration_list
-
-    def get_dispatches(self):
-        """Returns a list of dispatches in the collection."""
-        dispatches = []
-        for configuration in self.configuration_list:
-            dispatches.append(Dispatch(self.operation, configuration))
-        return dispatches
-
-    def append(self, dispatch):
-        """Appends a dispatch to the collection."""
-        if dispatch.operation != self.operation:
-            raise ValueError(
-                f"operation {self.operation.name()} does not match the dispatch "
-                f"collection operation name {dispatch.operation.name()}."
-            )
-        self.configuration_list.append(dispatch.configuration)
-
-    def num_of_dispatches(self):
-        """Returns number of dispatches in the collection."""
-        return len(self.configuration_list)
diff --git a/experimental/dispatch_profiler/generator.py b/experimental/dispatch_profiler/generator.py
deleted file mode 100644
index 59e6b95..0000000
--- a/experimental/dispatch_profiler/generator.py
+++ /dev/null
@@ -1,30 +0,0 @@
-# Copyright 2023 The IREE Authors
-#
-# Licensed under the Apache License v2.0 with LLVM Exceptions.
-# See https://llvm.org/LICENSE.txt for license information.
-# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
-
-import argparse
-from library import *
-from matmul import *
-from manifest import *
-from options import parse_generator_arguments
-
-###############################################################################
-
-if __name__ == "__main__":
-    parser = argparse.ArgumentParser(
-        description="Generates MLIR operations for "
-        "verification and profiling of IREE compiled dispatches."
-    )
-
-    args = parse_generator_arguments(parser)
-
-    # Manifest dispatches for a group of accompanying operations and configurations.
-    manifest = Manifest(args)
-
-    # Load all the pre-defined dispatches in a manifest.
-    manifest.initialize()
-
-    # Emit the dispatches in MLIR source files.
-    manifest.emit()
diff --git a/experimental/dispatch_profiler/launchers.py b/experimental/dispatch_profiler/launchers.py
deleted file mode 100644
index e97ecf2..0000000
--- a/experimental/dispatch_profiler/launchers.py
+++ /dev/null
@@ -1,238 +0,0 @@
-# Copyright 2023 The IREE Authors
-#
-# Licensed under the Apache License v2.0 with LLVM Exceptions.
-# See https://llvm.org/LICENSE.txt for license information.
-# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
-
-from library import *
-from matmul import ReferenceMatmulOp
-from batch_matmul import ReferenceBatchMatmulOp
-from pathlib import Path
-import subprocess
-
-
-class IreeToolsLauncher:
-    """Launcher for IREE tools."""
-
-    def __init__(self, args, operation):
-        self.operation = operation
-
-        self.generated_path = Path(args.generated_dir, "generated", args.mlir_dialect)
-
-        self.args = args
-        self.benchmark_dispatch_repeat_count = args.batch_size
-        self.batch_size = args.batch_size
-
-        # paths to source dispatch mlir, compiled vmfb, and logs.
-        self.operation_path = self.generated_path.joinpath(
-            OperationKindNames[operation.operation_kind], operation.name()
-        )
-
-        self.source_mlir_file = self.operation_path.joinpath(
-            operation.name()
-        ).with_suffix(".mlir")
-
-        # path to cached numpy refernece input and expected output files.
-        self.op_reference_cache_path = Path(
-            args.generated_dir, "generated", "reference_cache", operation.name()
-        )
-
-        if not self.op_reference_cache_path.exists():
-            self.op_reference_cache_path.mkdir(parents=True, exist_ok=True)
-
-        # path to iree-compile tool. (for compiling the input mlir file to vmfb)
-        self.iree_compile_path = Path(args.iree_bin_dir, "iree-compile")
-
-        # path to iree-benchmark-module tool. (for performance benchmarking and profiling)
-        self.iree_benchmark_module_path = Path(
-            args.iree_bin_dir, "iree-benchmark-module"
-        )
-
-        # path to iree-run-module tool. (for verification)
-        self.iree_run_module_path = Path(args.iree_bin_dir, "iree-run-module")
-
-        # output vmfb files for verification and profiling.
-        vmfb_filename = f"{operation.name()}"
-
-        if operation.operation_kind == OperationKind.SplitkMatmul:
-            split_k_suffix = "_".join(["split_k_slice", str(operation.split_k_slices)])
-            vmfb_filename = f"{vmfb_filename}_{split_k_suffix}"
-
-        self.vmfb_verify_filepath = self.operation_path.joinpath(
-            self.operation.name()
-        ).with_name(f"{vmfb_filename}_verify.vmfb")
-        self.vmfb_profile_filepath = self.operation_path.joinpath(
-            self.operation.name()
-        ).with_name(f"{vmfb_filename}_profile.vmfb")
-
-        # reference implementation for the operation_kind.
-        self.reference_impl_map = {
-            OperationKind.Matmul: ReferenceMatmulOp,
-            OperationKind.SplitkMatmul: ReferenceMatmulOp,
-            OperationKind.BatchMatmul: ReferenceBatchMatmulOp,
-        }
-
-    def iree_compile(self, compilation_mode):
-        """Compiles the input mlir file to vmfb file."""
-
-        benchmark_dispatch_repeat_count = (
-            self.benchmark_dispatch_repeat_count
-            if compilation_mode == CompilationMode.Profile
-            else 1
-        )
-        vmfb_filepath = (
-            self.vmfb_profile_filepath
-            if compilation_mode == CompilationMode.Profile
-            else self.vmfb_verify_filepath
-        )
-
-        # Base iree-compile commandline
-        cmd = [
-            f"{self.iree_compile_path}",
-            f"{self.source_mlir_file}",
-            "-o",
-            f"{vmfb_filepath}",
-        ]
-
-        # General compilation options
-        cmd += [f"--iree-hal-target-backends={self.args.device}"]
-
-        if self.args.device == "cuda":
-            cmd += [f"--iree-hal-cuda-llvm-target-arch={self.args.cuda_arch}"]
-        if self.operation.operation_kind == OperationKind.SplitkMatmul:
-            cmd += [
-                f"--iree-flow-split-matmul-reduction={self.operation.split_k_slices}"
-            ]
-        if self.args.use_mma_sync:
-            cmd += [f"--iree-codegen-llvmgpu-use-mma-sync"]
-        if self.args.use_wmma:
-            cmd += [f"--iree-codegen-llvmgpu-use-wmma"]
-
-        # Compilation options for profiling
-        cmd += [
-            f"--iree-hal-benchmark-dispatch-repeat-count={benchmark_dispatch_repeat_count}"
-        ]
-
-        # Appends print ir options at the end of the command line.
-        if self.args.mlir_print_ir_after_all:
-            cmd += [f"--mlir-print-ir-after-all"]
-
-        if not vmfb_filepath.exists() or self.args.force_compile:
-            complie_mode_str = CompilationModeNames[compilation_mode]
-
-            print(f"[Compiling ({complie_mode_str})] {' '.join(cmd)}")
-
-            iree_compile_stdout_filepath = self.operation_path.joinpath(
-                "iree_compile_cmd_stdout.mlir"
-            )
-
-            with open(iree_compile_stdout_filepath, "w") as fp:
-                subprocess.run(cmd, stderr=fp)
-
-        elif self.args.verbose:
-            print(
-                f"Skipping compilation of operation: {vmfb_filepath} since it already exists."
-            )
-
-    def verify(self, configuration):
-        """Verifies the operation with a given configuration."""
-        # First compile the operation to a vmfb file.
-        self.iree_compile(CompilationMode.Verify)
-
-        # Verify using random data distribution.
-        reference_run = self.reference_impl_map[self.operation.operation_kind](
-            self.operation,
-            self.op_reference_cache_path,
-            Distribution.Random,
-            Distribution.Random,
-        )
-
-        if not reference_run.is_cached():
-            reference_run()
-
-        # Commandline `iree-run-module` for verification.
-        cmd = [
-            f"{self.iree_run_module_path}",
-            f"--module={self.vmfb_verify_filepath}",
-            f"--device={self.args.device}",
-        ]
-
-        # Operation-specific verification command-line.
-        cmd.append(f"--function={self.operation.name()}_{configuration.name()}")
-        for input_file_path in reference_run.get_input_filepaths():
-            cmd.append(f"--input=@{input_file_path}")
-        for output_file_path in reference_run.get_output_filepaths():
-            cmd.append(f"--expected_output=@{output_file_path}")
-
-        # Print the command if verbose.
-        if self.args.verbose:
-            print(f"[Verification] {' '.join(cmd)}")
-
-        # Launch verification.
-        cmd_output = subprocess.check_output(cmd, text=True)
-
-        # Save the verification command and the output, only if requested
-        # (file writing could slow down the verification).
-        if self.args.save_cmds:
-            filepath = self.operation_path.joinpath("iree_run_module.stdout")
-            with open(filepath, "w") as fp:
-                fp.write(f"[Command] $ {' '.join(cmd)}\n")
-                fp.write(cmd_output)
-
-        # Parse the verification output.
-        m = re.search(r"\[(?P<verification_result>[a-zA-Z]+)\]", cmd_output)
-        if m is None:
-            raise ValueError(
-                f"Failed to parse verification output by iree-run-module: {cmd_output}"
-            )
-        verification_result = m.group("verification_result")
-
-        if self.args.verbose or verification_result != "SUCCESS":
-            print(cmd_output)
-
-        return verification_result
-
-    def profile(self, configuration):
-        """Profiles the operation with a given configuration."""
-        # First compile the operation to a vmfb file.
-        self.iree_compile(CompilationMode.Profile)
-
-        # Commandline `iree-benchmark-module` for profiling.
-        cmd = [
-            f"{self.iree_benchmark_module_path}",
-            f"--module={self.vmfb_profile_filepath}",
-            f"--device={self.args.device}",
-        ]
-
-        # Profiling specific flags.
-        cmd += [f"--benchmark_repetitions={self.args.benchmark_repetitions}"]
-        cmd += [f"--batch_size={self.batch_size}"]
-
-        # Operation-specific profiling command-line.
-        cmd += [f"--function={self.operation.name()}_{configuration.name()}"]
-        cmd += [f"--input={self.operation.lhs_npy_shape()}"]
-        cmd += [f"--input={self.operation.rhs_npy_shape()}"]
-
-        # Print the command if verbose.
-        if self.args.verbose:
-            print(f"[Profiling] {' '.join(cmd)}")
-
-        # Launch profiling.
-        cmd_output = subprocess.check_output(cmd, text=True, stderr=subprocess.STDOUT)
-
-        # Save the profiling command and the output, only if requested
-        # (file writing could slow down the profiling).
-        if self.args.save_cmds:
-            filepath = self.operation_path.joinpath("iree_benchmark_module.stdout")
-            with open(filepath, "w") as fp:
-                fp.write(f"[Command] $ {' '.join(cmd)}\n")
-                fp.write(cmd_output)
-
-        # Parse the profiling output.
-        m = re.search(r"real_time_median\s+(?P<runtime>\d+.\d+)\s+ms", cmd_output)
-        if m is None:
-            raise ValueError(
-                f"Failed to parse runtime from benchmark result: {cmd_output}"
-            )
-        runtime_in_ms = float(m.group("runtime"))
-        return runtime_in_ms
diff --git a/experimental/dispatch_profiler/library.py b/experimental/dispatch_profiler/library.py
deleted file mode 100644
index 2a7f345..0000000
--- a/experimental/dispatch_profiler/library.py
+++ /dev/null
@@ -1,354 +0,0 @@
-# Copyright 2023 The IREE Authors
-#
-# Licensed under the Apache License v2.0 with LLVM Exceptions.
-# See https://llvm.org/LICENSE.txt for license information.
-# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
-
-import enum, re
-from enum import auto
-import numpy as np
-from abc import ABC, abstractmethod
-from collections import namedtuple
-
-###################################################################################################
-# This file contains library of enumerations and classes used to build operation descritpions.
-# The operation descriptions are used to generate MLIR source files, performance tuning configuration,
-# reference implementations, and numpy input/output files.
-
-# The file is organized as follows:
-# 1. Enumerated `Type`s grouped together for categories, For e.g. [Arch]Type, [Data]Type etc.
-# 2. Dictonaries `Names` mapping the enumeration values to their string names.
-#    For e.g. [Arch]TypeNames, [Data]TypeNames etc.
-# 3. `Tags` for each enumeration value to be used in the generated MLIR source files.
-#    For e.g. [TranslationInfo]Tags
-###################################################################################################
-
-
-# Architecure types
-###################################################################################################
-class ArchType(enum.Enum):
-    Cpu = auto()
-    Gpu = auto()
-
-
-ArchTypeNames = {
-    ArchType.Cpu: "cpu",
-    ArchType.Gpu: "gpu",
-}
-
-
-class GpuArchType(enum.Enum):
-    nvptx = auto()
-    rocm = auto()
-    spirv = auto()
-
-
-GpuArchTypeNames = {
-    GpuArchType.nvptx: "nvptx",
-    GpuArchType.rocm: "rocm",
-    GpuArchType.spirv: "spirv",
-}
-
-
-# Operation kinds
-###################################################################################################
-class OperationKind(enum.Enum):
-    Matmul = auto()
-    BatchMatmul = auto()
-    SplitkMatmul = auto()
-    Conv2d = auto()
-
-
-OperationKindNames = {
-    OperationKind.Matmul: "matmul",
-    OperationKind.SplitkMatmul: "matmul_splitk",
-    OperationKind.BatchMatmul: "batch_matmul",
-    OperationKind.Conv2d: "conv2d",
-}
-
-
-# MLIR dialects
-###################################################################################################
-class MlirDialect(enum.Enum):
-    Linalg = auto()
-    Mhlo = auto()
-
-
-MlirDialectNames = {
-    MlirDialect.Linalg: "linalg",
-    MlirDialect.Mhlo: "mhlo",
-}
-
-
-# Compilation modes (verification or benchmarking/profiling)
-###################################################################################################
-class CompilationMode(enum.Enum):
-    Verify = auto()
-    Profile = auto()
-
-
-CompilationModeNames = {
-    CompilationMode.Verify: "verify",
-    CompilationMode.Profile: "profile",
-}
-
-
-class CompilationConfigType(enum.Enum):
-    Default = auto()
-    Custom = auto()
-
-
-CompilationConfigTypeName = {
-    CompilationConfigType.Default: "default",
-    CompilationConfigType.Custom: "custom",
-}
-
-
-# Enumerations for data types and layouts
-###################################################################################################
-class DataType(enum.Enum):
-    b1 = auto()
-    u4 = auto()
-    u8 = auto()
-    u16 = auto()
-    u32 = auto()
-    u64 = auto()
-    s4 = auto()
-    s8 = auto()
-    s16 = auto()
-    s32 = auto()
-    s64 = auto()
-    e4m3 = auto()
-    e5m2 = auto()
-    f16 = auto()
-    bf16 = auto()
-    f32 = auto()
-    tf32 = auto()
-    f64 = auto()
-    invalid = auto()
-
-
-DataTypeName = {
-    DataType.b1: "b1",
-    DataType.u4: "u4",
-    DataType.u8: "u8",
-    DataType.u16: "u16",
-    DataType.u32: "u32",
-    DataType.u64: "u64",
-    DataType.s4: "s4",
-    DataType.s8: "s8",
-    DataType.s16: "s16",
-    DataType.s32: "s32",
-    DataType.s64: "s64",
-    DataType.e4m3: "e4m3",
-    DataType.e5m2: "e5m2",
-    DataType.f16: "f16",
-    DataType.bf16: "bf16",
-    DataType.f32: "f32",
-    DataType.tf32: "tf32",
-    DataType.f64: "f64",
-}
-
-DataTypeNumPyTag = {
-    DataType.f16: np.float16,
-    DataType.f32: np.float32,
-}
-
-DataTypeSizeInBits = {
-    DataType.b1: 1,
-    DataType.u4: 4,
-    DataType.u8: 8,
-    DataType.u16: 16,
-    DataType.u32: 32,
-    DataType.u64: 64,
-    DataType.s4: 4,
-    DataType.s8: 8,
-    DataType.s16: 16,
-    DataType.s32: 32,
-    DataType.s64: 64,
-    DataType.e4m3: 8,
-    DataType.e5m2: 8,
-    DataType.f16: 16,
-    DataType.bf16: 16,
-    DataType.f32: 32,
-    DataType.tf32: 32,
-    DataType.f64: 64,
-}
-
-
-class LayoutType(enum.Enum):
-    ColumnMajor = auto()
-    RowMajor = auto()
-    NHWC = auto()
-    NCWH = auto()
-
-
-# cuBLAS/cuDNN layout type names convention is followed for the layout names.
-# https://docs.nvidia.com/cuda/cublas/index.html#cublasoperation-t
-ShortLayoutTypeName = {
-    LayoutType.ColumnMajor: "n",
-    LayoutType.RowMajor: "t",
-    LayoutType.NHWC: "nhwc",
-    LayoutType.NCWH: "ncwh",
-}
-
-
-# Compilation pipelines/translation info.
-###################################################################################################
-class TranslationInfo(enum.Enum):
-    LLVMGPUMatmulSIMT = auto()
-    LLVMGPUMatmulTensorCore = auto()
-    LLVMGPUMatmulTensorCoreMmaSync = auto()
-
-
-TranslationInfoTag = {
-    TranslationInfo.LLVMGPUMatmulSIMT: "LLVMGPUMatmulSIMT",
-    TranslationInfo.LLVMGPUMatmulTensorCore: "LLVMGPUMatmulTensorCore",
-    TranslationInfo.LLVMGPUMatmulTensorCoreMmaSync: "LLVMGPUMatmulTensorCoreMmaSync",
-}
-
-TranslationInfoName = {
-    TranslationInfo.LLVMGPUMatmulSIMT: "simt_ffma",
-    TranslationInfo.LLVMGPUMatmulTensorCore: "tensorcore_wmma",
-    TranslationInfo.LLVMGPUMatmulTensorCoreMmaSync: "tensorcore_mmasync",
-}
-
-
-# Distribution of values in a tensor.
-###################################################################################################
-class Distribution(enum.Enum):
-    Empty = auto()
-    Zeros = auto()
-    Ones = auto()
-    Sequential = auto()
-    Identity = auto()
-    Random = auto()
-
-
-DistributionName = {
-    Distribution.Empty: "empty",
-    Distribution.Zeros: "zeros",
-    Distribution.Ones: "ones",
-    Distribution.Sequential: "seq",
-    Distribution.Identity: "identity",
-    Distribution.Random: "random",
-}
-
-###################################################################################################
-# The next part of this file contains the data structures for describing a tensor, tiles etc that
-# are built using the above enumerations. These data structures are used to create compose bigger
-# data structures that describe an operation or a sequence of operations, along with compilation
-# pipeling to form a collection of dispatches to profiled.
-###################################################################################################
-
-
-class TensorDescription:
-    """A class for tensor description."""
-
-    def __init__(self, datatype, layout):
-        self.datatype = datatype
-        self.layout = layout
-
-    def name(self):
-        return "%s%s" % (DataTypeName[self.datatype], ShortLayoutTypeName[self.layout])
-
-
-class TileDescription:
-    """A class for tile description."""
-
-    def __init__(self, threadblock_shape, stages, block_dim):
-        self.threadblock_shape = threadblock_shape  # in number of elements in M, N, K
-        self.stages = stages  # number of shared memory stages in tile K
-        self.block_dim = block_dim  # block dimension in number of threads in x, y, z
-
-    def name(self):
-        return "%dx%d_%dx%d" % (
-            self.threadblock_shape[0],
-            self.threadblock_shape[1],
-            self.threadblock_shape[2],
-            self.stages,
-        )
-
-
-###################################################################################################
-# The following part contains utility functions for which are used by the profiler tool.
-# These function may be moved as the need for create a proper structure for the
-# functionality they provide becomes apparent and necessary as we move forward.
-###################################################################################################
-def get_np_array(tensor_description, shape, dist):
-    """Returns a numpy array based on the distribution and shape."""
-    # Fix the seed for reproducibility.
-    np.random.seed(42)
-
-    # Generate the numpy array based on the distribution.
-    if dist == Distribution.Empty:
-        return np.empty(shape)
-    elif dist == Distribution.Zeros:
-        return np.zeros(shape)
-    elif dist == Distribution.Ones:
-        return np.ones(shape)
-    elif dist == Distribution.Sequential:
-        return np.arange(np.prod(shape)).reshape(shape)
-    elif dist == Distribution.Identity:
-        return np.eye(shape[0], shape[1])
-    elif dist == Distribution.Random:
-        if tensor_description.datatype == DataType.s8:
-            return np.random.randint(-2, 3, shape)
-        elif tensor_description.datatype == DataType.u8:
-            return np.random.randint(0, 4, shape)
-        elif (
-            tensor_description.datatype == DataType.f16
-            or tensor_description.datatype == DataType.bf16
-        ):
-            return np.random.randint(-3, 4, shape)
-        elif tensor_description.datatype == DataType.f32:
-            return np.random.randint(-7, 8, shape)
-
-
-###################################################################################################
-def SubstituteTemplate(template, values):
-    """Substitutes values into a template string."""
-    text = template
-    for key, value in values.items():
-        regex = "\\$\\{%s\\}" % key
-        newtext = re.sub(regex, value, text)
-        text = newtext
-    return text
-
-
-###################################################################################################
-class ReferenceOpInterface(ABC):
-    """Interface for reference implementations."""
-
-    @abstractmethod
-    def get_input_filepaths(self):
-        """Returns the list of inputs."""
-        pass
-
-    @abstractmethod
-    def get_output_filepaths(self):
-        """Returns the list of outputs/."""
-        pass
-
-    @abstractmethod
-    def __call__(self):
-        """Runs the reference implementation."""
-        pass
-
-    def is_cached(self):
-        """Returns whether the reference run is cached."""
-
-        # Returns False if any of the reference input are missing.
-        for input_filepath in self.get_input_filepaths():
-            if not input_filepath.exists():
-                return False
-
-        # Returns False if any of the reference output are missing.
-        for output_filepath in self.get_output_filepaths():
-            if not output_filepath.exists():
-                return False
-
-        # Returns True if all the reference inputs and outputs are cached.
-        return True
-
-    ###################################################################################################
diff --git a/experimental/dispatch_profiler/manifest.py b/experimental/dispatch_profiler/manifest.py
deleted file mode 100644
index 9254c3c..0000000
--- a/experimental/dispatch_profiler/manifest.py
+++ /dev/null
@@ -1,264 +0,0 @@
-# Copyright 2022 The IREE Authors
-#
-# Licensed under the Apache License v2.0 with LLVM Exceptions.
-# See https://llvm.org/LICENSE.txt for license information.
-# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
-
-import enum, shutil, pickle
-from library import *
-from matmul import *
-from batch_matmul import *
-from split_k_matmul import *
-from pathlib import Path
-
-
-###############################################################################
-class EmitSourceMLIR:
-    """Emitters for the operation MLIR source files."""
-
-    def __init__(self, operation_path, dispatch_collection):
-        self.operation_path = operation_path
-        self.dispatch_collection = dispatch_collection
-        self.operation = dispatch_collection.operation
-        self.operation_kind = self.operation.operation_kind
-        self.configuration_list = dispatch_collection.configuration_list
-        self.operation_filepath = self.operation_path.joinpath(
-            self.operation.name()
-        ).with_suffix(".mlir")
-
-        mlir_configuration_emitter = {
-            OperationKind.Matmul: EmitMatmulCompilationInfo,
-            OperationKind.SplitkMatmul: EmitMatmulCompilationInfo,
-            OperationKind.BatchMatmul: EmitMatmulCompilationInfo,
-        }
-        self.configuration_emitter = mlir_configuration_emitter[self.operation_kind]()
-
-        mlir_dispatch_emitter = {
-            OperationKind.Matmul: EmitLinalgMatmulDispatch,
-            OperationKind.SplitkMatmul: EmitLinalgMatmulDispatch,
-            OperationKind.BatchMatmul: EmitLinalgBatchMatmulDispatch,
-        }
-        self.dispatch_emitter = mlir_dispatch_emitter[self.operation_kind]()
-
-    def __enter__(self):
-        self.operation_file = open(self.operation_filepath, "w")
-        self.operation_file.write(f"// Finename: {self.operation_filepath}")
-
-        # Emit all the configuration attribute tags.
-        for configuration in self.configuration_list:
-            self.operation_file.write(self.configuration_emitter.emit(configuration))
-        return self
-
-    def emit(self):
-        """Emit the op func.func for each dispatch (operation + configuration)"""
-        for dispatch in self.dispatch_collection.get_dispatches():
-            print(
-                f"    Emitting tuning configuration : {dispatch.configuration.name()}"
-            )
-            self.operation_file.write(self.dispatch_emitter.emit(dispatch))
-
-    def __exit__(self, exc_type, exc_value, traceback):
-        self.operation_file.close()
-
-
-###############################################################################
-class Manifest:
-    """Manifest collects, filters, and stores dispatches in a data structure.
-    Manifest organizes the dispatches in a dictionary.
-    Usage:
-     1. Create a manifest object with the command line arguments.
-     2(a). Generate dispatches, append them in the manifest, and
-           serialize them into a file.
-     2(b). Load dispatches from a serialized file.
-
-     ```python
-     # generator.py usage:
-     manifest = Manifest(args)
-     manifest.initialize()
-
-     # compile.py or profile.py usage:
-     manifest = Manifest(args)
-     manifest.load()
-     ```
-    """
-
-    def __init__(self, args):
-        self.args = args
-
-        # Dictionary of operation kind to a list of dispatch collections. We
-        # initialize the dictionary during the generation of dispatches and
-        # serialize it to a file. The serialized file is used to load the
-        # dispatches for compilation and profiling.
-        # Datatype: OperationKind -> [DispatchCollection]
-        self.dispatch_collection_map = {}
-
-        # For operation kind-based filtering of dispatches.
-        self.operation_kind_enabled = []
-
-        # For name-based filtering of dispatches.
-        self.dispatch_names = []
-        self.ignore_dispatch_names = []
-
-        if args.operation_kind == "all":
-            self.operation_kind_enabled = []
-        else:
-            operations_kind_list = [
-                OperationKind.Matmul,
-                OperationKind.SplitkMatmul,
-                OperationKind.BatchMatmul,
-            ]
-            self.operation_kind_enabled = [
-                x
-                for x in operations_kind_list
-                if OperationKindNames[x] in args.operation_kind.split(",")
-            ]
-
-        if args.dispatches == "all":
-            self.dispatch_names = []
-        else:
-            self.dispatch_names = [x for x in args.dispatches.split(",") if x != ""]
-
-        # Paths to the generated directory (e.g. `./generated/linalg`).
-        self.generated_path = Path(
-            self.args.generated_dir, "generated", self.args.mlir_dialect
-        )
-
-        # Create the directories in self.generated_path, if it does not exist.
-        if not self.generated_path.exists():
-            self.generated_path.mkdir(parents=True, exist_ok=True)
-
-        # Path to the serialized file.
-        self.serialized_file_path = self.generated_path.joinpath("serialized_file.pkl")
-
-    def _filter_string_matches(self, filter_string, haystack):
-        """Returns true if all substrings appear in the haystack in order"""
-        substrings = filter_string.split("*")
-        for sub in substrings:
-            idx = haystack.find(sub)
-            if idx < 0:
-                return False
-            haystack = haystack[idx + len(sub) :]
-        return True
-
-    def is_enabled(self, dispatch):
-        """Rerturns true if pass through filters based various criteria."""
-
-        # Get the operation and configuration from the dispatch.
-        operation = dispatch.operation
-        configuration = dispatch.configuration
-
-        # If the operation is not in the enabled list, return False.
-        enabled = True
-
-        # If operation_kind filter is enabled and the \
-        # operation_kind in not in the enabled list, return False.
-        if (
-            len(self.operation_kind_enabled)
-            and operation.operation_kind not in self.operation_kind_enabled
-        ):
-            enabled = False
-
-        # If dispatch name-based filter regex is enabled match the \
-        # dispatch name (operation+configuration) against all regexs \
-        # in self.dispatch_names.
-        if len(self.dispatch_names):
-            name = dispatch.name()
-            enabled = False
-
-            # compare against each regex included in self.dispatch_names.
-            for substr_to_match in self.dispatch_names:
-                if self._filter_string_matches(substr_to_match, name):
-                    enabled = True
-                    break
-
-        # Return the result of the filter.
-        return enabled
-
-    def append_dispatch_collection(self, dispatch_collection):
-        """Appends one instance of DispatchCollection to the manifest."""
-        operation_kind = dispatch_collection.operation.operation_kind
-        if operation_kind not in self.dispatch_collection_map.keys():
-            self.dispatch_collection_map[operation_kind] = []
-
-        # Get all the dispatches from the dispatch_collection.
-        dispatches = dispatch_collection.get_dispatches()
-
-        # Filter dispatches based on the filter criteria.
-        filtered_dispatch_collection = DispatchCollection(
-            dispatch_collection.operation, []
-        )
-        for dispatch in dispatches:
-            if self.is_enabled(dispatch):
-                filtered_dispatch_collection.append(dispatch)
-
-        # Only append the filtered_dispatch_collection if it has an unfiltered configuration.
-        if len(filtered_dispatch_collection.configuration_list):
-            self.dispatch_collection_map[operation_kind].append(
-                filtered_dispatch_collection
-            )
-
-    def append(self, dispatch_collection_list):
-        """Appends one instance of DispatchCollection to the manifest."""
-        for dispatch_collection in dispatch_collection_list:
-            self.append_dispatch_collection(dispatch_collection)
-
-    def initialize(self):
-        """Initialize the mainfest object by generating dispatches for supported operations."""
-        self.append(CudaMatmulGenerator(self.args).generate())
-        self.append(CudaSplitKMatmulGenerator(self.args).generate())
-        self.append(CudaBatchMatmulGenerator(self.args).generate())
-
-        # Serialize the initialized mainfest state.
-        self.dump()
-
-    def dump(self):
-        """Serialize (dump) the self.dispatch_collection_map to a pickle file."""
-        with open(self.serialized_file_path, "wb") as f:
-            pickle.dump(self.dispatch_collection_map, f)
-
-    def load(self):
-        """Deserialize (load) the self.dispatch_collection_map from a pickle file."""
-        if not self.serialized_file_path.exists():
-            raise ValueError(f"Could not find : {self.serialized_file_path}")
-
-        with open(self.serialized_file_path, "rb") as load_file:
-            self.dispatch_collection_map = pickle.load(load_file)
-
-    def emit(self):
-        """Emits the operations in the Manifest to the build directory as MLIR source files.
-        The operations are emitted in the dialect specified by the `mlir_dialect` flag.
-        """
-
-        # For each operation_kind create a directory and emit the operations with
-        # all the configurations in the configuration_list into their seperate directories.
-        for (
-            operation_kind,
-            dispatch_collection_list,
-        ) in self.dispatch_collection_map.items():
-            operation_kind_path = self.generated_path.joinpath(
-                OperationKindNames[operation_kind]
-            )
-
-            # If the operation_kind_path does not exists, create it.
-            if not operation_kind_path.exists():
-                operation_kind_path.mkdir(parents=True, exist_ok=True)
-
-            for dispatch_collection in dispatch_collection_list:
-                operation_path = operation_kind_path.joinpath(
-                    dispatch_collection.operation.name()
-                )
-
-                # If the operation_path does not exists, create it.
-                if not operation_path.exists():
-                    operation_path.mkdir()
-
-                with EmitSourceMLIR(
-                    operation_path, dispatch_collection
-                ) as emit_mlir_source:
-                    mlir_file_path = operation_path.joinpath(
-                        dispatch_collection.operation.name()
-                    ).with_suffix(".mlir")
-                    print(f"[Generating]: {mlir_file_path}")
-
-                    # Emit mlir source file for the dispatch_collection.operation with all the configurations
-                    emit_mlir_source.emit()
diff --git a/experimental/dispatch_profiler/matmul.py b/experimental/dispatch_profiler/matmul.py
deleted file mode 100644
index 902708d..0000000
--- a/experimental/dispatch_profiler/matmul.py
+++ /dev/null
@@ -1,658 +0,0 @@
-# Copyright 2023 The IREE Authors
-#
-# Licensed under the Apache License v2.0 with LLVM Exceptions.
-# See https://llvm.org/LICENSE.txt for license information.
-# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
-
-import enum, shutil, functools, operator, collections, subprocess
-from library import *
-from dispatch import *
-from options import get_cmd_line_argument_list
-
-
-################################################################################
-class MatmulOperation:
-    """Data structure to describe a matrix multiplication operation.
-    This includes the shape, datatype, and layout of the operands. This data
-    structure is *independent* of the compilation* and tiling configuration.
-    It "mostly" contains the parameter that changes the functionality of matmul
-    operation. The only exception is the split_k_slices parameter, which is
-    changes the performance of the matmul operation and not the functionality.
-    """
-
-    def __init__(
-        self,
-        matmul_shape,
-        lhs,
-        rhs,
-        result,
-        batch_count=1,
-        split_k_slices=1,
-        operation_kind=OperationKind.Matmul,
-    ):
-        """Initializes a matrix multiplication operation.
-        Matrix-multiple operation: `result[M, N] = lhs[M, K] * rhs[K, N]`
-        matmul_shape: A tuple representing the matrix multiplication problem shape
-          in the format (M, N, K), where M is the number of rows in the lhs matrix,
-          N is the number of columns in the rhs matrix, and K is the number of columns
-          in the lhs matrix and rows in the rhs matrix.
-        lhs: A TensorDescription object representing the left-hand-side matrix operand.
-        rhs: A TensorDescription object representing the right-hand-side matrix operand.
-        result: A TensorDescription object representing the result matrix operand.
-        """
-
-        # Parameters that change the matmul operation *functionally*.
-        self.operation_kind = operation_kind
-        self.matmul_shape = matmul_shape
-        self.M = matmul_shape[0]
-        self.N = matmul_shape[1]
-        self.K = matmul_shape[2]
-        self.batch_count = batch_count
-        self.lhs = lhs  # TensorDescription
-        self.rhs = rhs  # TensorDescription
-        self.result = result  # TensorDescription
-
-        # Parameters that change the matmul operation *performance*.
-        self.split_k_slices = split_k_slices
-
-    def __eq__(self, other):
-        """Returns true if the matmul operation is *functionally* the same."""
-        return (
-            self.matmul_shape == other.matmul_shape
-            and self.lhs == other.lhs
-            and self.rhs == other.rhs
-            and self.result == other.result
-            and self.batch_count == other.batch_count
-        )
-
-    def name(self):
-        """Procedurally generated name for the matmul operation.
-        The name uniquely identifies a matmul operation with matmul shape,
-        lhs dataype and layout, rhs datatype and layout, and result
-        datatype and layout.
-        """
-        return (
-            f"{OperationKindNames[self.operation_kind]}_"
-            f"{self.M}x{self.N}x{self.K}_"
-            f"{DataTypeName[self.lhs.datatype]}{ShortLayoutTypeName[self.lhs.layout]}_"
-            f"{DataTypeName[self.rhs.datatype]}{ShortLayoutTypeName[self.rhs.layout]}_"
-            f"{DataTypeName[self.result.datatype]}{ShortLayoutTypeName[self.result.layout]}"
-        )
-
-    def get_argument_dict(self):
-        """Returns the dictionary of matmul arguments (shape, datatypes, split_k_slices)."""
-        split_k_mode = (
-            "parallel" if self.operation_kind == OperationKind.SplitkMatmul else "N/A"
-        )
-        split_k_slices = (
-            self.split_k_slices
-            if self.operation_kind == OperationKind.SplitkMatmul
-            else "N/A"
-        )
-        return {
-            "batch_count": self.batch_count,
-            "m": self.M,
-            "n": self.N,
-            "k": self.K,
-            "lhs": self.lhs.name(),
-            "rhs": self.rhs.name(),
-            "result": self.result.name(),
-            "split_k_mode": split_k_mode,
-            "split_k_slices": split_k_slices,
-        }
-
-    def get_dict_entry(self):
-        """Returns the dictionary of matmul operation summary."""
-        dict_entry = {
-            "op_kind": OperationKindNames[self.operation_kind],
-            "Operation": self.name(),
-            "bytes": self.bytes(),
-            "flops": self.flops(),
-        }
-        dict_entry.update(self.get_argument_dict())
-        return dict_entry
-
-    def lhs_npy_shape(self):
-        """Returns the shape of the lhs numpy array as a string in the format "MxKxDataType"."""
-        return f"{self.M}x{self.K}x{DataTypeName[self.lhs.datatype]}"
-
-    def rhs_npy_shape(self):
-        """Returns the shape of the rhs numpy array as a string in the format "KxNxDataType"."""
-        return f"{self.K}x{self.N}x{DataTypeName[self.rhs.datatype]}"
-
-    def result_npy_shape(self):
-        """Returns the shape of the result numpy array as a string in the format "MxNxDataType"."""
-        return f"{self.M}x{self.N}x{DataTypeName[self.result.datatype]}"
-
-    def bytes(self):
-        """Returns the number of bytes read/written by the matmul operation."""
-        bytes = (
-            (DataTypeSizeInBits[self.lhs.datatype] * self.M // 8) * self.K
-            + (DataTypeSizeInBits[self.rhs.datatype] * self.K // 8) * self.N
-            + (DataTypeSizeInBits[self.result.datatype] * self.M // 8) * self.N
-        )
-        return bytes * self.batch_count
-
-    def flops(self):
-        """Returns the number of floating point operations performed by the matmul operation."""
-        return 2 * self.M * self.N * self.K * self.batch_count
-
-
-##############################################################################
-class MatmulCompilationInfo:
-    """Data structure strictly describes the compilation passes and the tiling configurations.
-    For a matrix multiplication operation, compilation passes and tiling configuration
-    influences the performance of the compiled matmul operation, but the functionality.
-    This data structure should be independent of the matmul operation functionality.
-
-    Any change in this data structure should not affect the functionality of the matmul operation, i.e.,
-    we should be able to use the same reference results for a matrix operation compiled with different
-    compilation info.
-    """
-
-    def __init__(
-        self,
-        tile_description,
-        translation_info,
-        operation_kind=OperationKind.Matmul,
-        config_type=CompilationConfigType.Custom,
-    ):
-        self.tile_description = tile_description  # TileDescription
-        self.translation_info = translation_info  # TranslationInfo
-        self.operation_kind = operation_kind  # OperationKind
-        self.config_type = config_type  # CompilationConfigType
-
-    def get_dict_entry(self):
-        """Returns the dictionary entry for the matmul compilation info."""
-        if self.config_type == CompilationConfigType.Default:
-            return {
-                "Tile config": "Default",
-                "Core class": "Default",
-                "Instruction class": "Default",
-            }
-
-        translation_info_name = TranslationInfoName[self.translation_info]
-        return {
-            "Tile config": self.tile_description.name(),
-            "Core class": translation_info_name.split("_")[0],
-            "Instruction class": translation_info_name.split("_")[1],
-        }
-
-    def name(self):
-        """Procedurally generated name for the matmul compilation info."""
-        if self.config_type == CompilationConfigType.Default:
-            return "tile_config_default"
-
-        return "tile_config_{tbm}x{tbn}_{tbk}x{stages}_{translation_info}".format(
-            tbm=self.tile_description.threadblock_shape[0],
-            tbn=self.tile_description.threadblock_shape[1],
-            tbk=self.tile_description.threadblock_shape[2],
-            stages=self.tile_description.stages,
-            translation_info=TranslationInfoName[self.translation_info],
-        )
-
-
-################################################################################
-class EmitMatmulCompilationInfo:
-    """Emitters for the matmul compilation info."""
-
-    def __init__(self):
-        # matmul compilation info template
-        self.matmul_compilation_info_template = """
-// matmul compilation info (tile configuration, translation info, workgroup size)
-#${compilation_info_name} = #iree_codegen.compilation_info<
-  lowering_config = <tile_sizes = [[${threadblock_shape_m}, ${threadblock_shape_n}, ${threadblock_shape_k}]]>,
-  translation_info = <${translation_info} pipeline_depth = ${stages}>,
-  workgroup_size = [${block_dim_x}, ${block_dim_y}, ${block_dim_z}]
->
-"""
-        # batch matmul and split-k matmul compilation info template
-        self.batch_matmul_compilation_info_template = """
-// batch matmul compilation info (tile configuration, translation info, workgroup size)
-#${compilation_info_name} = #iree_codegen.compilation_info<
-  lowering_config = <tile_sizes = [[1, ${threadblock_shape_m}, ${threadblock_shape_n}, ${threadblock_shape_k}]]>,
-  translation_info = <${translation_info} pipeline_depth = ${stages}>,
-  workgroup_size = [${block_dim_x}, ${block_dim_y}, ${block_dim_z}]
->
-"""
-
-    def emit(self, compilation_info):
-        """Emits the matmul compilation info as a string."""
-        if compilation_info.config_type == CompilationConfigType.Default:
-            return ""
-
-        values = {
-            "compilation_info_name": compilation_info.name(),
-            "translation_info": TranslationInfoTag[compilation_info.translation_info],
-            "threadblock_shape_m": str(
-                compilation_info.tile_description.threadblock_shape[0]
-            ),
-            "threadblock_shape_n": str(
-                compilation_info.tile_description.threadblock_shape[1]
-            ),
-            "threadblock_shape_k": str(
-                compilation_info.tile_description.threadblock_shape[2]
-            ),
-            "stages": str(compilation_info.tile_description.stages),
-            "block_dim_x": str(compilation_info.tile_description.block_dim[0]),
-            "block_dim_y": str(compilation_info.tile_description.block_dim[1]),
-            "block_dim_z": str(compilation_info.tile_description.block_dim[2]),
-        }
-
-        # linalg.matmul (without split-k) compilation info template.
-        compilation_info_template = self.matmul_compilation_info_template
-
-        # linalg.batch_matmul and linalg.matmul (with split-k) have different
-        # compilation info template from the linalg.matmul (without split-k).
-        if (
-            compilation_info.operation_kind == OperationKind.BatchMatmul
-            or compilation_info.operation_kind == OperationKind.SplitkMatmul
-        ):
-            compilation_info_template = self.batch_matmul_compilation_info_template
-
-        return SubstituteTemplate(compilation_info_template, values)
-
-
-###############################################################################
-class EmitLinalgMatmulDispatch:
-    """Emitters for the `linalg.matmul` dispatch."""
-
-    def __init__(self):
-        self.mlir_dialect = MlirDialect.Linalg
-
-        # linalg.matmul mlir template
-        self.linalg_row_row_matmul_template = """
-// Dispatch linalg.matmul row-row layout
-func.func @${operation_name}_${compilation_info_name}(
-  %lhs: tensor<${problem_m}x${problem_k}x${datatype_lhs}>,
-  %rhs: tensor<${problem_k}x${problem_n}x${datatype_rhs}>) -> tensor<${problem_m}x${problem_n}x${datatype_result}>
-{
-  %c0 = arith.constant 0.0 : ${datatype_result}
-  %init = tensor.empty() : tensor<${problem_m}x${problem_n}x${datatype_result}>
-  %inital_result = linalg.fill ins(%c0 : ${datatype_result}) outs(%init : tensor<${problem_m}x${problem_n}x${datatype_result}>) -> tensor<${problem_m}x${problem_n}x${datatype_result}>
-  %result = linalg.matmul ${compilation_info_attribute}
-                     ins(%lhs, %rhs: tensor<${problem_m}x${problem_k}x${datatype_lhs}>, tensor<${problem_k}x${problem_n}x${datatype_rhs}>)
-                     outs(%inital_result: tensor<${problem_m}x${problem_n}x${datatype_result}>) -> tensor<${problem_m}x${problem_n}x${datatype_result}>
-  return %result : tensor<${problem_m}x${problem_n}x${datatype_result}>
-}
-"""
-
-    def emit(self, matmul_dispatch):
-        """Emit the matmul operation in the MLIR dialect for a single compilation info"""
-        compilation_info_attribute_template = (
-            """{compilation_info = #${compilation_info_name}}"""
-        )
-        compilation_info_attribute_str = SubstituteTemplate(
-            compilation_info_attribute_template,
-            {"compilation_info_name": matmul_dispatch.configuration.name()},
-        )
-        compilation_info_attribute = (
-            compilation_info_attribute_str
-            if matmul_dispatch.configuration.config_type
-            != CompilationConfigType.Default
-            else ""
-        )
-
-        values = {
-            "operation_name": matmul_dispatch.operation.name(),
-            "compilation_info_attribute": compilation_info_attribute,
-            "problem_m": str(matmul_dispatch.operation.M),
-            "problem_n": str(matmul_dispatch.operation.N),
-            "problem_k": str(matmul_dispatch.operation.K),
-            "datatype_lhs": DataTypeName[matmul_dispatch.operation.lhs.datatype],
-            "datatype_rhs": DataTypeName[matmul_dispatch.operation.rhs.datatype],
-            "datatype_result": DataTypeName[matmul_dispatch.operation.result.datatype],
-            "compilation_info_name": matmul_dispatch.configuration.name(),
-        }
-
-        return SubstituteTemplate(self.linalg_row_row_matmul_template, values)
-
-
-###############################################################################
-class ReferenceMatmulOp(ReferenceOpInterface):
-    """Reference implementation for the matmul operation in numpy."""
-
-    def __init__(self, matmul_operation, op_reference_cache_path, dist_lhs, dist_rhs):
-        self.matmul_operation = matmul_operation
-        self.op_reference_cache_path = op_reference_cache_path
-
-        # Problem shape.
-        self.M = matmul_operation.M
-        self.N = matmul_operation.N
-        self.K = matmul_operation.K
-
-        # Data type for the input and result matrices.
-        self.dtype_lhs = DataTypeNumPyTag[matmul_operation.lhs.datatype]
-        self.dtype_rhs = DataTypeNumPyTag[matmul_operation.rhs.datatype]
-        self.dtype_result = DataTypeNumPyTag[matmul_operation.result.datatype]
-
-        # Distribution of the input tensors.
-        self.dist_lhs = dist_lhs
-        self.dist_rhs = dist_rhs
-
-        # Filename for the left hand side input tensor.
-        self.filename_lhs = (
-            "m{problem_m}xk{problem_k}_"
-            "{tensor_description}_{dist}_lhs.npy".format(
-                problem_m=self.M,
-                problem_k=self.K,
-                tensor_description=self.matmul_operation.lhs.name(),
-                dist=DistributionName[self.dist_lhs],
-            )
-        )
-
-        # Filename for the right hand side input tensor.
-        self.filename_rhs = (
-            "k{problem_k}xn{problem_n}_"
-            "{tensor_description}_{dist}_rhs.npy".format(
-                problem_k=self.K,
-                problem_n=self.N,
-                tensor_description=self.matmul_operation.rhs.name(),
-                dist=DistributionName[self.dist_rhs],
-            )
-        )
-
-        # Filename for the reference result tensor.
-        self.filename_reference_result = (
-            "m{problem_m}xn{problem_n}_"
-            "{tensor_description}_reference_result.npy".format(
-                problem_m=self.M,
-                problem_n=self.N,
-                tensor_description=self.matmul_operation.result.name(),
-            )
-        )
-
-        # Filepath for input and output files.
-        self.filepath_lhs = self.op_reference_cache_path.joinpath(self.filename_lhs)
-        self.filepath_rhs = self.op_reference_cache_path.joinpath(self.filename_rhs)
-        self.filepath_reference_result = self.op_reference_cache_path.joinpath(
-            self.filename_reference_result
-        )
-
-    def get_input_filepaths(self):
-        """Returns the list of input file paths."""
-        return [self.filepath_lhs, self.filepath_rhs]
-
-    def get_output_filepaths(self):
-        """Returns the list of expected output file paths."""
-        return [self.filepath_reference_result]
-
-    def __call__(self):
-        """Generates input data, runs reference numpy.matmul, and save npy files to the output directory."""
-        # Generate the input data as np.array for the matmul operation.
-        lhs_np_array = get_np_array(
-            self.matmul_operation.lhs, (self.M, self.K), self.dist_lhs
-        )
-        rhs_np_array = get_np_array(
-            self.matmul_operation.rhs, (self.K, self.N), self.dist_rhs
-        )
-
-        # Run the reference np.matmul and generate result np.array.
-        result = np.matmul(lhs_np_array, rhs_np_array)
-
-        # Save the input data as np.array for the matmul operation.
-        np.save(self.filepath_lhs, np.array(lhs_np_array, dtype=self.dtype_lhs))
-        np.save(self.filepath_rhs, np.array(rhs_np_array, dtype=self.dtype_rhs))
-
-        # Save the expected result as an np.array.
-        np.save(
-            self.filepath_reference_result, np.array(result, dtype=self.dtype_result)
-        )
-
-
-class CudaMatmulDispatchChecker:
-    """Given a matmul dispatch, checks if the dispatch is supported by the target GPU."""
-
-    def __init__(self, args):
-        self.args = args
-
-        # CUDA shared memory capacity per SM in KB.
-        self.sharedMemPerSm = {
-            "sm_80": 163,  # 1KB is reserved for the driver.
-            "sm_86": 99,  # 1KB is reserved for the driver
-        }
-
-        self.cuda_arch = self.args.cuda_arch
-        self.cuda_smem_capacity_in_bytes = self.sharedMemPerSm[self.cuda_arch] << 10
-
-    def _is_tile_aligned_shape(self, dispatch):
-        """Checks if the given dispatch is valid for CUDA."""
-        matmul_shape = dispatch.operation.matmul_shape
-        threadblock_shape = dispatch.configuration.tile_description.threadblock_shape
-        if len(matmul_shape) != len(threadblock_shape):
-            raise ValueError(
-                "Problem shape and threadblock shape must have the same rank."
-            )
-        is_aligned = all(a % b == 0 for a, b in zip(matmul_shape, threadblock_shape))
-        return is_aligned
-
-    def _cuda_smem_required_in_bytes(self, dispatch):
-        """Returns size bytes of shared memory required for a given cuda dispatch."""
-        threadblock_shape = dispatch.configuration.tile_description.threadblock_shape
-        num_stages = dispatch.configuration.tile_description.stages
-        tile_shape_lhs = threadblock_shape[0] * threadblock_shape[2]
-        tile_shape_rhs = threadblock_shape[2] * threadblock_shape[1]
-        return (
-            (
-                tile_shape_lhs * DataTypeSizeInBits[dispatch.operation.lhs.datatype]
-                + tile_shape_rhs * DataTypeSizeInBits[dispatch.operation.rhs.datatype]
-            )
-            * num_stages
-        ) // 8
-
-    def _is_problem_k_divisible_by_split_k(self, dispatch):
-        """Checks if the given dispatch is valid for CUDA."""
-        return dispatch.operation.K % dispatch.operation.split_k_slices == 0
-
-    def _is_cuda_smem_avialable(self, dispatch):
-        """Checks if the given dispatch is valid for CUDA."""
-        return (
-            self._cuda_smem_required_in_bytes(dispatch)
-            <= self.cuda_smem_capacity_in_bytes
-        )
-
-    def is_valid(self, dispatch):
-        """Checks if the given dispatch is valid for CUDA."""
-        if not self._is_tile_aligned_shape(dispatch):
-            if self.args.verbose:
-                print(f"[Warning]: {dispatch.name()} is not aligned is being skipped.")
-            return False
-        if not self._is_cuda_smem_avialable(dispatch):
-            if self.args.verbose:
-                print(
-                    f"[Warning]: {dispatch.name()} requires {self._cuda_smem_required_in_bytes(dispatch)} "
-                    f"bytes of shared memory, which is larger than the {self.cuda_arch} capacity "
-                    f"{self.cuda_smem_capacity_in_bytes} bytes."
-                )
-            return False
-        if (dispatch.operation.split_k_slices > 1) and (
-            not self._is_problem_k_divisible_by_split_k(dispatch)
-        ):
-            if self.args.verbose:
-                print(
-                    f"[Warning]: {dispatch.name()} problem k is not divisible by {dispatch.operation.split_k_slices} "
-                    f"split-k slices, which is not supported on LLVM GPU CUDA backend."
-                )
-            return False
-        return True
-
-
-class CudaMatmulGenerator:
-    """Matmul dispatch generator class.
-    Generates a list of pre-defined matmul operations with resonable tuning cofigurations.
-    The generator function are seperated based on the target backend and the data type.
-    Please see example `MatmulGenerator._cuda_matmul_tensor_cores_f16` for cuda target
-    backend and f16 data type."""
-
-    def __init__(self, args):
-        """Initializes the matmul generator."""
-        self.args = args
-        self.translation_infos = [
-            # TranslationInfo.LLVMGPUMatmulSimt,  # CUDA Core (SMIT)
-            # TranslationInfo.LLVMGPUMatmulTensorCore, # Tensor Core (WMMA)
-            TranslationInfo.LLVMGPUMatmulTensorCoreMmaSync,  # Tensor Core (MMA.SYNC)
-        ]
-
-        # List of pre-defined threadblock tile shapes for Tensor Core.
-        self.tile_descriptions_tensor_cores_f16 = [
-            TileDescription([256, 128, 32], 3, [64, 4, 1]),
-            TileDescription([128, 256, 32], 3, [128, 2, 1]),
-            TileDescription([128, 128, 64], 4, [64, 2, 1]),
-            TileDescription([128, 128, 32], 5, [64, 2, 1]),
-            TileDescription([128, 64, 32], 5, [64, 2, 1]),
-            TileDescription([64, 64, 64], 5, [64, 2, 1]),
-            TileDescription([64, 64, 32], 10, [64, 2, 1]),
-        ]
-
-        self.tile_descriptions_tensor_cores_f32 = [
-            TileDescription([128, 256, 16], 3, [128, 2, 1]),
-            TileDescription([256, 128, 16], 3, [64, 4, 1]),
-            TileDescription([128, 128, 16], 5, [64, 2, 1]),
-            TileDescription([128, 128, 32], 3, [64, 2, 1]),
-            TileDescription([128, 128, 32], 4, [64, 2, 1]),
-            TileDescription([128, 64, 32], 3, [64, 2, 1]),
-            TileDescription([128, 64, 16], 5, [64, 2, 1]),
-            TileDescription([64, 64, 32], 3, [64, 2, 1]),
-            TileDescription([64, 64, 16], 10, [64, 2, 1]),
-        ]
-
-        # Create a list of matmul problem and initialize with some *default* shapes.
-        self.matmul_shapes = [[256, 512, 128], [2560, 2560, 2560], [3456, 1024, 2048]]
-
-        # Append matmul problem with *user* provided shapes.
-        self.add_cmd_line_shapes()
-
-        # Matmul dispatches collection.
-        self.dispatches_collection_list = []
-
-    def add_cmd_line_shapes(self):
-        """Adds matmul shapes from command line arguments."""
-
-        m_list = get_cmd_line_argument_list(self.args.problem_m)
-        n_list = get_cmd_line_argument_list(self.args.problem_n)
-        k_list = get_cmd_line_argument_list(self.args.problem_k)
-
-        # If no command line matmul problem shapes are provided, only
-        # use the default shapes.
-        if len(m_list) == 0 and len(n_list) == 0 and len(k_list) == 0:
-            return
-
-        # If any of the command line matmul problem shapes are provided,
-        # set the default shapes to empty problem dimension.
-        if len(m_list) == 0:
-            m_list = [256]
-        if len(n_list) == 0:
-            n_list = [256]
-        if len(k_list) == 0:
-            k_list = [256]
-
-        # Append the command line matmul problem shapes with user-proivded
-        # matmul problem shapes.
-        for m in m_list:
-            for n in n_list:
-                for k in k_list:
-                    self.matmul_shapes.append([m, n, k])
-
-    def _cuda_supported_configuration_list(self, operation, configuration_list):
-        """Returns a list of supported configurations for CUDA."""
-        supported_configuration_list = []
-        dispatch_checker = CudaMatmulDispatchChecker(self.args)
-        for configuration in configuration_list:
-            if not dispatch_checker.is_valid(Dispatch(operation, configuration)):
-                continue
-            supported_configuration_list.append(configuration)
-
-        # Return the supported configuration list.
-        return supported_configuration_list
-
-    def _get_matmul_custom_compilation_info_list(
-        self, tile_descriptions, translation_infos, operation_kind
-    ):
-        """Creates a *custom* list of matmul compilation info."""
-        configuration_list = []
-        for tile_description in tile_descriptions:
-            for translation_info in translation_infos:
-                configuration_list.append(
-                    MatmulCompilationInfo(
-                        tile_description,
-                        translation_info,
-                        operation_kind,
-                        CompilationConfigType.Custom,
-                    )
-                )
-        return configuration_list
-
-    def _append_matmul_dispatch_collection(
-        self, matmul_shapes, data_type, configuration_list
-    ):
-        """Appends the matmul dispatches collection with the given configuration list."""
-
-        # Create dispatches collection for each matmul_shape x configuration list..
-        for matmul_shape in matmul_shapes:
-            operation = MatmulOperation(
-                matmul_shape,
-                TensorDescription(data_type[0], LayoutType.RowMajor),
-                TensorDescription(data_type[1], LayoutType.RowMajor),
-                TensorDescription(data_type[2], LayoutType.RowMajor),
-            )
-
-            # Filter out configurations that are not supported by LLVM GPU CUDA backend.
-            supported_configuration_list = self._cuda_supported_configuration_list(
-                operation, configuration_list
-            )
-
-            # Add default configuration if enabled.
-            if self.args.default_config:
-                supported_configuration_list.append(
-                    MatmulCompilationInfo(
-                        [], [], OperationKind.Matmul, CompilationConfigType.Default
-                    )
-                )
-
-            # Append the dispatch collection.
-            self.dispatches_collection_list.append(
-                DispatchCollection(operation, supported_configuration_list)
-            )
-
-    def _cuda_matmul_tensor_cores_f16(self):
-        """Appends dispatches for TensorCore with F16 input, F16 accum, F16 output."""
-        configuration_list = self._get_matmul_custom_compilation_info_list(
-            self.tile_descriptions_tensor_cores_f16,
-            self.translation_infos,
-            OperationKind.Matmul,
-        )
-        data_type = [DataType.f16, DataType.f16, DataType.f16]
-        self._append_matmul_dispatch_collection(
-            self.matmul_shapes, data_type, configuration_list
-        )
-
-    def _cuda_matmul_tensor_cores_f32(self):
-        """Appends dispatches for TensorCore with F32 input, F32 accum, F32 output."""
-        configuration_list = self._get_matmul_custom_compilation_info_list(
-            self.tile_descriptions_tensor_cores_f32,
-            self.translation_infos,
-            OperationKind.Matmul,
-        )
-        data_type = [DataType.f32, DataType.f32, DataType.f32]
-        self._append_matmul_dispatch_collection(
-            self.matmul_shapes, data_type, configuration_list
-        )
-
-    def _cuda_matmul_tensor_cores_mixed_precision(self):
-        """Appends dispatches for TensorCore with F16 input, F32 accum, F32 output."""
-        configuration_list = self._get_matmul_custom_compilation_info_list(
-            self.tile_descriptions_tensor_cores_f16,
-            self.translation_infos,
-            OperationKind.Matmul,
-        )
-        data_type = [DataType.f16, DataType.f16, DataType.f32]
-        self._append_matmul_dispatch_collection(
-            self.matmul_shapes, data_type, configuration_list
-        )
-
-    def generate(self):
-        """Generates a list of matmul operations."""
-        self._cuda_matmul_tensor_cores_f16()
-        self._cuda_matmul_tensor_cores_f32()
-        self._cuda_matmul_tensor_cores_mixed_precision()
-        return self.dispatches_collection_list
diff --git a/experimental/dispatch_profiler/options.py b/experimental/dispatch_profiler/options.py
deleted file mode 100644
index b53d183..0000000
--- a/experimental/dispatch_profiler/options.py
+++ /dev/null
@@ -1,333 +0,0 @@
-# Copyright 2023 The IREE Authors
-#
-# Licensed under the Apache License v2.0 with LLVM Exceptions.
-# See https://llvm.org/LICENSE.txt for license information.
-# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
-
-import argparse
-
-###############################################################################
-#                    Options ohh! too main options
-###############################################################################
-# This file organizes the plenty of options that once can pass to the profiler
-# tool scripts for generating, compiling, verifying, and profiling IREE-compiled
-# MLIR operations.
-#
-# The options are organized into groups: typical, compilation, iree-compile,
-# verification, profiling, performance-reporting. Note that there is a function
-# of each group.
-###############################################################################
-
-
-def add_typical_arguments(parser):
-    """Adds typical command line arguments to the parser."""
-    parser.add_argument(
-        "--iree-bin-dir",
-        default="./tools",
-        help="Directory containing IREE binaries, "
-        "e.g. iree-compile, iree-benchmark-module, "
-        "iree-run-module",
-    )
-    parser.add_argument(
-        "--generated-dir",
-        default=".",
-        help="The dispatch profiler scripts generate "
-        "mlir dispatches, compiled vmfbs, and reference_chache "
-        "containing golden npy files in the generated-dir",
-    )
-    parser.add_argument(
-        "--operation-kind",
-        "--op-kind",
-        dest="operation_kind",
-        default="all",
-        help="Specifies the operation kinds to generate.",
-        choices=["matmul", "conv2d", "all"],
-    )
-    parser.add_argument(
-        "--dispatches",
-        default="",
-        help="Comma delimited list to filter dispatches by name. "
-        "A dispatch is a combination of operation and tuning "
-        "configuration.",
-    )
-    parser.add_argument(
-        "--mlir-dialect",
-        default="linalg",
-        help="MLIR dialect entry point at which operation is emitter.",
-        choices=["linalg"],
-    )
-    parser.add_argument(
-        "--verbose",
-        action="store_true",
-        help="Prints verbose output and commands executed.",
-    )
-    parser.add_argument(
-        "--dry-run",
-        action="store_true",
-        help="Prints commands that will be executed without actually "
-        "executing them.",
-    )
-    parser.add_argument(
-        "--default-config",
-        action="store_true",
-        help="Adds a dispatch without a pre-defined "
-        "tuning configuration. This dispatch will use "
-        "default configuration from KernelsConfig.cpp.",
-    )
-
-
-def add_compilation_arguments(parser):
-    """Adds compilation (not part of iree-compile) command line arguments to the parser."""
-    compilation_parser = parser.add_argument_group(
-        "Compilation", "Compilation related options."
-    )
-    compilation_parser.add_argument(
-        "--num-cpu",
-        "-j",
-        dest="num_cpu",
-        type=int,
-        default=-1,
-        help="Number of cpu threads to use for compilation.",
-    )
-    compilation_parser.add_argument(
-        "--force-compile",
-        action="store_true",
-        help="Force re-compilation of the operation even " "if .vmfb file is present.",
-    )
-
-
-def add_iree_compile_arguments(parser):
-    """Adds iree-compile command line arguments to the parser."""
-    iree_compile_parser = parser.add_argument_group(
-        "iree-compile", "iree-compile related options."
-    )
-
-    iree_compile_parser.add_argument(
-        "--iree-hal-target-backends",
-        "--device",
-        dest="device",
-        default="cuda",
-        help="Target backends for executable compilation. ",
-        choices=["cuda", "vulkan", "cpu"],
-    )
-    iree_compile_parser.add_argument(
-        "--iree-hal-cuda-llvm-target-arch",
-        "--cuda-arch",
-        dest="cuda_arch",
-        default="sm_80",
-        help="Target architecture for the CUDA backend. ",
-        choices=["sm_50", "sm_60", "sm_75", "sm_80", "sm_86"],
-    )
-    iree_compile_parser.add_argument(
-        "--iree-hal-benchmark-dispatch-repeat-count",
-        "--batch-size",
-        dest="batch_size",
-        default=100,
-        help="Number of times dispatch is launched in a loop to "
-        "amortize the launch overhead. This argument is used for "
-        "iree-compile and iree-benchamrk-module. The value used by "
-        "iree-compile and iree-benchamrk-module should be the same.",
-    )
-    iree_compile_parser.add_argument(
-        "--iree-flow-split-matmul-reduction",
-        "--split-k-slices",
-        dest="split_k_slices",
-        default="",
-        help="Number of slices to split the reduction K-dimension.",
-    )
-    iree_compile_parser.add_argument(
-        "--iree-codegen-llvmgpu-use-mma-sync",
-        "--use-mma-sync",
-        dest="use_mma_sync",
-        action="store_true",
-        help="Use mma.sync instructions.",
-    )
-    iree_compile_parser.add_argument(
-        "--iree-codegen-llvmgpu-use-wmma",
-        "--use-wmma",
-        dest="use_wmma",
-        action="store_true",
-        help="Use wmma instructions.",
-    )
-    iree_compile_parser.add_argument(
-        "--mlir-print-ir-after-all",
-        "--print-ir-after-all",
-        dest="mlir_print_ir_after_all",
-        action="store_true",
-        help="Prints IR after all transformations and dumps a "
-        "file print_ir_after_*.mlir file.",
-    )
-
-
-def add_verification_arguments(parser):
-    """Adds verification related arguments to the parser."""
-    verification_parser = parser.add_argument_group(
-        "Verification", "Verification related options."
-    )
-
-    verification_parser.add_argument(
-        "--verification-enabled", default="True", type=str, help="Verify the operation."
-    )
-    verification_parser.add_argument(
-        "--verification-providers",
-        default="numpy",
-        choices=["numpy"],
-        help="Comma delimited list of verification providers.",
-    )
-
-
-def add_profiling_arguments(parser):
-    """Adds profiling related arguments to the parser."""
-    profiling_parser = parser.add_argument_group(
-        "Profiling", "Profiling (iree-benchmark-module) related options."
-    )
-
-    profiling_parser.add_argument(
-        "--profiling-enabled",
-        "--benchmark",
-        default="True",
-        type=str,
-        help="Benchmark the operation.",
-    )
-    profiling_parser.add_argument(
-        "--benchmark-repetitions",
-        default=5,
-        type=int,
-        help="Number of times benchmark is repeated "
-        "and min, max, median, and average runtimes/gflops are "
-        "reported.",
-    )
-
-
-def add_performance_report_arguments(parser):
-    """Adds performance report related arguments to the parser."""
-
-    performance_report_parser = parser.add_argument_group(
-        "Performance Report", "Performance report related options."
-    )
-
-    performance_report_parser.add_argument(
-        "--output", default="", help="Path to output file for csv readable results."
-    )
-    performance_report_parser.add_argument(
-        "--append",
-        action="store_true",
-        help="Appends the results to existing file. "
-        "o.w., the existing file is overwritten.",
-    )
-    performance_report_parser.add_argument(
-        "--tags",
-        default="",
-        help="Inserts leading columns in output table "
-        "and uniform values for each column. Useful for "
-        "generating pivot tables.",
-    )
-
-
-def add_matmul_arguments(parser):
-    """Adds matmul related arguments to the parser."""
-
-    matmul_parser = parser.add_argument_group(
-        "Matmul", "Matrix-multiplication related options."
-    )
-    matmul_parser.add_argument(
-        "--problem-m",
-        default="",
-        help="M dimension of the matrix. "
-        "--problem-m==<value>,<value_start:value_end:increment>*",
-    )
-    matmul_parser.add_argument(
-        "--problem-n",
-        default="",
-        help="N dimension of the matrix."
-        "--problem-n==<value>,<value_start:value_end:increment>*",
-    )
-    matmul_parser.add_argument(
-        "--problem-k",
-        default="",
-        help="K dimension of the matrix."
-        "--problem-k==<value>,<value_start:value_end:increment>*",
-    )
-
-
-###############################################################################
-# Parser all the arguments for a script function:
-# parse_generator_arguments() for generator.py
-# parse_profiler_arguments() for profiler.py
-###############################################################################
-
-
-def parse_generator_arguments(parser):
-    """Adds and parse all the arguments for the *generator.py* script."""
-    add_typical_arguments(parser)
-    add_matmul_arguments(parser)
-    add_iree_compile_arguments(parser)
-    args = parser.parse_args()
-    return args
-
-
-def parse_compile_arguments(parser):
-    """Adds and parse all the arguments for the *compile.py* script."""
-    add_typical_arguments(parser)
-    add_compilation_arguments(parser)
-    add_iree_compile_arguments(parser)
-    args = parser.parse_args()
-    return args
-
-
-def parse_profiler_arguments(parser):
-    """Adds and parse all the arguments for the *profiler.py* script."""
-    add_typical_arguments(parser)
-    add_compilation_arguments(parser)
-    add_iree_compile_arguments(parser)
-    add_verification_arguments(parser)
-    add_profiling_arguments(parser)
-    add_performance_report_arguments(parser)
-
-    # Additional arguments for the profiler.
-    parser.add_argument(
-        "--save-cmds",
-        action="store_true",
-        help="Saves commands and their output that are executed "
-        "by the profiler in a file.",
-    )
-
-    args = parser.parse_args()
-
-    # Boolenize the string arguments from command line. For these args, it makes easier
-    # to read and convey the meaning. The boolean arguments below are specified as:
-    # `--argument=<true|false>`
-    args.verification_enabled = (
-        False if args.verification_enabled in ["False", "false", "0"] else True
-    )
-
-    args.profiling_enabled = (
-        False if args.profiling_enabled in ["False", "false", "0"] else True
-    )
-
-    return args
-
-
-###############################################################################
-# Helper functions for parsing command line arguments.
-###############################################################################
-def get_cmd_line_argument_ranges(arg):
-    """Returns a list of values generated by range of the form start:end:increment."""
-    if not arg:
-        return []
-    if ":" not in arg:
-        return [int(arg)]
-    range_elements = arg.split(":")
-    start = int(range_elements[0])
-    end = int(range_elements[1])
-    increment = int(range_elements[2]) if len(range_elements) == 3 else 1
-    return range(start, end, increment)
-
-
-def get_cmd_line_argument_list(arg):
-    """Returns a list of values generated by comma delimited string."""
-    values = arg.split(",")
-    range_list = []
-    for val in values:
-        range_list += get_cmd_line_argument_ranges(val)
-    return range_list
diff --git a/experimental/dispatch_profiler/performance_report.py b/experimental/dispatch_profiler/performance_report.py
deleted file mode 100644
index e924baf..0000000
--- a/experimental/dispatch_profiler/performance_report.py
+++ /dev/null
@@ -1,158 +0,0 @@
-# Copyright 2023 The IREE Authors
-#
-# Licensed under the Apache License v2.0 with LLVM Exceptions.
-# See https://llvm.org/LICENSE.txt for license information.
-# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
-
-import csv, textwrap
-import numpy as np
-from collections import namedtuple
-from pathlib import Path
-
-
-class PerformanceResult:
-    """Performance result of a single run."""
-
-    def __init__(self, operation, configuration, verification_result, runtime):
-        self.operation = operation
-        self.configuration = configuration
-        self.verification_result = verification_result
-        self.runtime = runtime  # in milliseconds
-        self.gflops = float(self.operation.flops()) / self.runtime / 1.0e6
-
-    def print(self):
-        """Prints the performance result to the console."""
-        runtime = str(self.runtime) if self.runtime != -1.0 else "Not profiled"
-        gflops = (
-            str(float(round(self.gflops, 2)))
-            if self.runtime != -1.0
-            else "Not profiled"
-        )
-
-        print("---------------------------------------------------------------- ")
-        print(
-            f'Dispatch      : {"_".join([self.operation.name(), self.configuration.name()])}'
-        )
-        print(f"Provider      : IREE Codegen")
-        print(f"OpKind        : {self.operation.operation_kind}")
-        print(f"Operation     : {self.operation.name()}")
-        print(f"Configuration : {self.configuration.name()}")
-        # Operation specific arguments.
-        arg_str = " ".join(
-            [
-                f"--{key}={value}"
-                for key, value in self.operation.get_argument_dict().items()
-            ]
-        )
-        wrapped_arg_str = textwrap.fill(
-            arg_str, width=80, subsequent_indent="                "
-        )
-        print(f"Arguments     : {wrapped_arg_str}")
-        print(f"Verification  : {self.verification_result}")
-        print(f"Runtime(ms)   : {runtime}")
-        print(f"GFLOPs        : {gflops}")
-
-    def get_dict_entry(self):
-        """Returns a dictionary with the performance result."""
-        runtime = self.runtime if self.runtime != -1.0 else ""
-        gflops = float(round(self.gflops, 2)) if self.runtime != -1.0 else "Not run"
-        dict_entry = {
-            "Provider": "IREE Codegen",
-            "Verification": self.verification_result,
-            "Runtime(ms)": runtime,
-            "GFLOPs": gflops,
-        }
-
-        # Add the operation specific arguments.
-        dict_entry.update(self.operation.get_dict_entry())
-
-        # Add the configuration specific arguments.
-        dict_entry.update(self.configuration.get_dict_entry())
-
-        return dict_entry
-
-
-class PerformanceReport:
-    """Performance report class is used to store the performance results of multiple runs.
-    The report can be written to a csv file."""
-
-    def __init__(self, args):
-        self.args = args
-
-        # Data members extracted from the args.
-        self.output_file_path = None
-        if args.output != "":
-            self.output_file_path = Path(args.output)
-
-        # List of PerformanceResult.
-        self.perf_result_vector = []
-
-        # Additional tags to add to the csv report file. \
-        # Useful for generating pivot tables.
-        self.tags = []
-        if args.tags != "":
-            self.tags = args.tags.split(",")
-
-        # Boolen to check if the header is written to the csv file.
-        self.is_header_written = False
-
-        # If the args.output set, open the file and write the header.
-        self.open_mode = "a" if self.args.append else "w"
-        if self.output_file_path:
-            self.csv_file = open(self.output_file_path, self.open_mode)
-
-    def __del__(self):
-        """If the args.output set, close the file."""
-        if self.output_file_path:
-            print("Writing performance report to %s" % self.output_file_path)
-            self.csv_file.close()
-
-    def write_csv_header(self, operation, configuration):
-        """Write the header to the csv file."""
-
-        # Create and write the header.
-        operation_specific_header = list(operation.get_dict_entry().keys())
-        configuration_specific_header = list(configuration.get_dict_entry().keys())
-        performance_header = ["Verification", "Runtime(ms)", "GFLOPs"]
-        csv_header = (
-            operation_specific_header
-            + configuration_specific_header
-            + performance_header
-        )
-        csv_header = ["Provider"] + csv_header
-
-        # If tags are present, add the tags.keys() to the begining of the csv header.
-        if len(self.tags):
-            tag_header = [tag.split(":")[0] for tag in self.tags]
-            csv_header = tag_header + csv_header
-
-        # Create the csv dictionary writer.
-        self.csv_writer = csv.DictWriter(self.csv_file, fieldnames=csv_header)
-
-        # Write the header if the file is being created.
-        if self.open_mode == "w":
-            self.csv_writer.writeheader()
-
-    def append_perf_result(self, performance_result):
-        """Appends a performance result to the report.
-        Additionaly, if args.output set, write the csv_row entry."""
-        self.perf_result_vector.append(performance_result)
-
-        if self.output_file_path:
-            # Write the header if not written.
-            if not self.is_header_written:
-                self.write_csv_header(
-                    performance_result.operation, performance_result.configuration
-                )
-                self.is_header_written = True
-
-            # Create the row entries for performance result.
-            csv_dict_row = performance_result.get_dict_entry()
-
-            # Create the row entries for tags.
-            for tag in self.tags:
-                tag_key, tag_value = tag.split(":")
-                csv_dict_row[tag_key] = tag_value
-
-            # Write the row.
-            self.csv_writer.writerow(csv_dict_row)
diff --git a/experimental/dispatch_profiler/profile_all.sh b/experimental/dispatch_profiler/profile_all.sh
deleted file mode 100755
index 2daa14e..0000000
--- a/experimental/dispatch_profiler/profile_all.sh
+++ /dev/null
@@ -1,65 +0,0 @@
-#!/bin/bash
-
-# Copyright 2023 The IREE Authors
-#
-# Licensed under the Apache License v2.0 with LLVM Exceptions.
-# See https://llvm.org/LICENSE.txt for license information.
-# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
-
-# Sets up a venv suitable for running IREE Dispatch Profiler and executes
-# a suite of runs. This is invoked by a Github workflow and can be invoked
-# locally.
-#
-# Recommend getting default 'python' to be python 3. For example on Debian:
-#   sudo update-alternatives --install /usr/bin/python python /usr/bin/python3 1
-# Or launch with python=/some/path
-#
-# Arg 1: The directory where iree-compile, iree-benchmark-module, etc. are
-#        located. If unset, uses IREE Dispatch Profiler defaults.
-# Arg 2: The directory where output is written. If unset, uses
-#        `dispatch_profiler_output` in current working directory.
-
-set -euo pipefail
-
-TD="$(cd $(dirname $0) && pwd)"
-
-PYTHON="${PYTHON:-python3}"
-
-DISPATCH_PROFILER_IREE_BIN_DIR=${1:-""}
-if [[ -z "${DISPATCH_PROFILER_IREE_BIN_DIR}" ]]; then
-  DISPATCH_PROFILER_IREE_BIN_DIR_FLAG=""
-else
-  DISPATCH_PROFILER_IREE_BIN_DIR_FLAG="--iree-bin-dir=${DISPATCH_PROFILER_IREE_BIN_DIR}"
-fi
-
-DISPATCH_PROFILER_OUTPUT_DIR="${2:-"dispatch_profiler_output"}"
-DISPATCH_PROFILER_GENERATED_DIR="."
-VENV_DIR="dispatch-profiler.venv"
-
-echo "Setting up venv dir: ${VENV_DIR}"
-echo "Python: ${PYTHON}"
-echo "Python version: $("${PYTHON}" --version)"
-echo "Dispatch Profiler IREE bin dir flag: ${DISPATCH_PROFILER_IREE_BIN_DIR_FLAG}"
-echo "Dispatch Profiler output dir: ${DISPATCH_PROFILER_OUTPUT_DIR}"
-echo "Dispatch profiler generated dir: ${DISPATCH_PROFILER_GENERATED_DIR}"
-
-${PYTHON} -m venv "${VENV_DIR}"
-source "${VENV_DIR}/bin/activate"
-
-# Upgrade pip and install requirements. 'python' is used here in order to
-# reference to the python executable from the venv.
-python -m pip install --upgrade pip
-python -m pip install --upgrade -r "${TD}/requirements.txt"
-
-mkdir -p "${DISPATCH_PROFILER_OUTPUT_DIR}"
-
-python "${TD}/generator.py" \
-  --generated-dir "${DISPATCH_PROFILER_GENERATED_DIR}"
-python "${TD}/compile.py" \
-  ${DISPATCH_PROFILER_IREE_BIN_DIR_FLAG} \
-  --generated-dir "${DISPATCH_PROFILER_GENERATED_DIR}"
-python "${TD}/profiler.py" \
-  ${DISPATCH_PROFILER_IREE_BIN_DIR_FLAG} \
-  --generated-dir "${DISPATCH_PROFILER_GENERATED_DIR}" \
-  --dispatches="matmul_3456x1024x2048_f16t_f16t_f16t_tile_config_128x128_32x5_tensorcore_mmasync,matmul_3456x1024x2048_f32t_f32t_f32t_tile_config_128x128_16x5_tensorcore_mmasync" \
-  --output "${DISPATCH_PROFILER_OUTPUT_DIR}/matmul_perf_tensor_core_a100.csv"
diff --git a/experimental/dispatch_profiler/profiler.py b/experimental/dispatch_profiler/profiler.py
deleted file mode 100644
index f9e460f..0000000
--- a/experimental/dispatch_profiler/profiler.py
+++ /dev/null
@@ -1,98 +0,0 @@
-# Copyright 2023 The IREE Authors
-#
-# Licensed under the Apache License v2.0 with LLVM Exceptions.
-# See https://llvm.org/LICENSE.txt for license information.
-# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
-
-import argparse
-
-from library import *
-from matmul import *
-from batch_matmul import *
-from manifest import *
-from performance_report import *
-from launchers import *
-from options import parse_profiler_arguments
-
-###############################################################################
-# Profiler main : The main entry point for the profiler tool.
-###############################################################################
-# This tool compiles, verifies, and profiles IREE-compiled MLIR operations for
-# a given backend device, compiler flags, and tuning configuration.
-#
-# The dispatch profiler tool is organized based on below defintions:
-# Operation: A MLIR operation that is generated or consumed by the
-#       dispatch_profiler. For example, linalg.matmul, linalg.conv2d, etc.
-# Configuration: A set of compile parameters that are used by iree-compile the
-#       to choose a compilation pipeline (e.g. LLVMGPUTensorCore,
-#       LLVMGPUTensorCoreMmaSync, LLVGPUCPU, etc.), performance tuning parameters
-#       (e.g. workgroup size, tile size etc.).
-# Dispatch: A combination of an operation and a configuration is launched by the
-#       dispatch profiler for verification and performance profiling. Note that
-#       a dispatch is not a MLIR operation it is binary executable that is launched
-#       by the profiler. Additionaly, the goal of the tool is to also profile the
-#       performance of the fusions and a dispatch for fusion is a combination of
-#       multiple operations glued together and compiled into a single dispatch.
-###############################################################################
-
-if __name__ == "__main__":
-    ###############################################################################
-    # Parse command line arguments
-    ###############################################################################
-    parser = argparse.ArgumentParser(
-        description="IREE Python profiler tool for "
-        "verifcation and performance profiling tool for IREE-compiled "
-        "MLIR operations."
-    )
-
-    args = parse_profiler_arguments(parser)
-    ###############################################################################
-
-    # Create manifest object and load dispatches.
-    manifest = Manifest(args)
-    manifest.load()
-
-    # Performance report
-    perf_report = PerformanceReport(args)
-
-    # For all the operations in the manifest compile (if needed), verify, and profile.
-    for _, dispatch_collection_list in manifest.dispatch_collection_map.items():
-        for dispatch_collection in dispatch_collection_list:
-            operation = dispatch_collection.operation
-            # Select and create an instance of operation_launcher for the operation.
-            operation_launcher = IreeToolsLauncher(args, operation)
-            for configuration in dispatch_collection.configuration_list:
-                # Create a dispatch object.
-                dispatch = Dispatch(operation, configuration)
-
-                # Skip the dispatch if filter returns false.
-                if not manifest.is_enabled(dispatch):
-                    continue
-
-                # If dry run is enabled, skip the dispatch.
-                if args.dry_run:
-                    print(f"[Dry run] : {dispatch.name()}")
-                    continue
-
-                # Initialize verification and profiling results.
-                verification_result = (
-                    "Not verified" if not args.verification_enabled else "Failed"
-                )
-                runtime = -1.0
-
-                # Launch the operation dispatches for verification and profiling.
-                if args.verification_enabled:
-                    verification_result = operation_launcher.verify(configuration)
-                if args.profiling_enabled:
-                    runtime = operation_launcher.profile(configuration)
-
-                # Create performance result.
-                result = PerformanceResult(
-                    operation, configuration, verification_result, runtime
-                )
-
-                # Print the performance result.
-                result.print()
-
-                # Append the performance result to the performance report.
-                perf_report.append_perf_result(result)
diff --git a/experimental/dispatch_profiler/requirements.txt b/experimental/dispatch_profiler/requirements.txt
deleted file mode 100644
index 296d654..0000000
--- a/experimental/dispatch_profiler/requirements.txt
+++ /dev/null
@@ -1 +0,0 @@
-numpy
\ No newline at end of file
diff --git a/experimental/dispatch_profiler/split_k_matmul.py b/experimental/dispatch_profiler/split_k_matmul.py
deleted file mode 100644
index 02f462e..0000000
--- a/experimental/dispatch_profiler/split_k_matmul.py
+++ /dev/null
@@ -1,92 +0,0 @@
-# Copyright 2023 The IREE Authors
-#
-# Licensed under the Apache License v2.0 with LLVM Exceptions.
-# See https://llvm.org/LICENSE.txt for license information.
-# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
-
-from library import *
-from dispatch import *
-from matmul import MatmulOperation, MatmulCompilationInfo, CudaMatmulGenerator
-
-
-class CudaSplitKMatmulGenerator(CudaMatmulGenerator):
-    """SplitK Matmul dispatch generator class."""
-
-    def __init__(self, args):
-        """Initializes the splitK matmul generator."""
-        super().__init__(args)
-
-        # Predefined matmul shapes for splitK matmul.
-        self.matmul_shapes = [[128, 128, 12288]]
-
-        # Predefined split_k_slices list for splitK matmul.
-        self.split_k_slices = [2, 4, 16, 18]
-
-        # SplitK matmul dispatches collection list.
-        self.dispatches_collection_list = []
-
-    def _append_matmul_dispatch_collection(
-        self, matmul_shapes, split_k_slices, data_type, configuration_list
-    ):
-        """Appends the split-k matmul dispatches collection with the given configuration list."""
-
-        # Create dispatches collection for each matmul_shape x split_k_slice x configuration list.
-        for matmul_shape in matmul_shapes:
-            for split_k_slice in split_k_slices:
-                operation = MatmulOperation(
-                    matmul_shape,
-                    TensorDescription(data_type[0], LayoutType.RowMajor),
-                    TensorDescription(data_type[1], LayoutType.RowMajor),
-                    TensorDescription(data_type[2], LayoutType.RowMajor),
-                    1,  # batch_count
-                    split_k_slice,
-                    OperationKind.SplitkMatmul,
-                )
-
-                # Filter out configurations that are not supported by LLVM GPU CUDA backend.
-                supported_configuration_list = self._cuda_supported_configuration_list(
-                    operation, configuration_list
-                )
-
-                # Add default configuration if enabled.
-                if self.args.default_config:
-                    supported_configuration_list.append(
-                        MatmulCompilationInfo(
-                            [], [], OperationKind.Matmul, CompilationConfigType.Default
-                        )
-                    )
-
-                # Append the dispatch collection.
-                self.dispatches_collection_list.append(
-                    DispatchCollection(operation, supported_configuration_list)
-                )
-
-    def _cuda_matmul_tensor_cores_f16(self):
-        """Appends a list of matmul split-k dispatches for GPU TensorCore F16 data type."""
-        configuration_list = self._get_matmul_custom_compilation_info_list(
-            self.tile_descriptions_tensor_cores_f16,
-            self.translation_infos,
-            OperationKind.SplitkMatmul,
-        )
-        data_type = [DataType.f16, DataType.f16, DataType.f16]
-        self._append_matmul_dispatch_collection(
-            self.matmul_shapes, self.split_k_slices, data_type, configuration_list
-        )
-
-    def _cuda_matmul_tensor_cores_f32(self):
-        """Appends a list of matmul split-k dispatches for GPU TensorCore F32 data type."""
-        configuration_list = self._get_matmul_custom_compilation_info_list(
-            self.tile_descriptions_tensor_cores_f32,
-            self.translation_infos,
-            OperationKind.SplitkMatmul,
-        )
-        data_type = [DataType.f32, DataType.f32, DataType.f32]
-        self._append_matmul_dispatch_collection(
-            self.matmul_shapes, self.split_k_slices, data_type, configuration_list
-        )
-
-    def generate(self):
-        """Generates a list of split-k matmul operations."""
-        self._cuda_matmul_tensor_cores_f16()
-        self._cuda_matmul_tensor_cores_f32()
-        return self.dispatches_collection_list