Open source documentation for e2e benchmarking (#3275)

Adds documentation for benchmarking TensorFlow models 
with IREE and TFLite on desktop and Android.
diff --git a/docs/developing_iree/e2e_benchmarking.md b/docs/developing_iree/e2e_benchmarking.md
new file mode 100644
index 0000000..42bc69c
--- /dev/null
+++ b/docs/developing_iree/e2e_benchmarking.md
@@ -0,0 +1,355 @@
+# Benchmark IREE and TFLite
+
+We use our end-to-end TensorFlow integration tests to test compilation and
+numerical accuracy, and to generate compilation and benchmarking artifacts.
+This allows us to validate that our benchmarks are behaving as we expect them
+to, and to run them using valid inputs for each model.
+
+This guide assumes that you can run the TensorFlow integration tests. See
+[this doc](https://google.github.io/iree/developing-iree/tensorflow-integrations)
+for more information. That doc also covers writing new tests, which you'll need
+to do if you'd like to benchmark a new TensorFlow model.
+
+## 1. Run IREE's E2E TensorFlow tests to generate the benchmarking artifacts
+
+The following command compiles and tests all of our passing, non-manual targets:
+
+```shell
+bazel test //integrations/tensorflow/e2e/...
+```
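+
+If you only need the artifacts for a single module, you can run its individual
+test target instead. For example (the exact target name below is an
+assumption; check the `BUILD` file under `integrations/tensorflow/e2e/` if it
+differs):
+
+```shell
+# A minimal sketch: generate artifacts for just MatrixOpsStaticModule.
+bazel test //integrations/tensorflow/e2e:matrix_ops_static_test
+```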
+
+Running the above command populates a directory `/tmp/iree/modules/` with the
+compilation artifacts needed to benchmark each TensorFlow model in our tests.
+Each test/module has a folder with the following artifacts (filtered to only
+include those relevant for benchmarking):
+
+```shell
+# Example for a generic module `ModuleName`:
+/tmp/iree/modules/ModuleName
+  ├── iree_vmla  # Or any other IREE backend.
+  │   ├── compiled.vmfb
+  │   │   # A flatbuffer containing IREE's compiled code.
+  │   └── traces
+  │       # Directory with a trace for each unittest in the module's test file.
+  │       ├── traced_function_1
+  │       │   # Directory storing logs and serialization for a specific trace.
+  │       │   └── flagfile
+  │       │       # An Abseil flagfile containing arguments
+  │       │       # iree-benchmark-module needs to benchmark this trace.
+  │       ├── traced_function_2
+  │       └── ...
+  └── tflite
+      ├── module_method_1.tflite
+      │   # A method on ModuleName compiled to bytes with TFLite, which can
+      │   # be used by TFLite's benchmark_model binary.
+      ├── module_method_2.tflite
+      ├── ...
+      └── traces
+          ├── traced_function_1
+          │   └── graph_path
+          │       # In general, a trace's name does not have to match the name
+          │       # of the method(s) on the tf.Module that it calls. This file
+          │       # points to the correct module_method_*.tflite graph file
+          │       # for TFLite's benchmark_model to use.
+          ├── traced_function_2
+          └── ...
+
+# Example for MatrixOpsStaticModule:
+/tmp/iree/modules/MatrixOpsStaticModule
+  ├── iree_llvmjit
+  │   ├── compiled.vmfb
+  │   └── traces
+  │       ├── basic_matmul
+  │       │   └── flagfile
+  │       ├── matmul_broadcast_singleton_dimension
+  │       │   └── flagfile
+  │       ├── matmul_lhs_batch
+  │       │   └── flagfile
+  │       └── matmul_rhs_batch
+  │           └── flagfile
+  ├── iree_vmla
+  │   ├── compiled.vmfb
+  │   └── traces  # ...same as iree_llvmjit/traces above.
+  ├── iree_vulkan
+  │   ├── compiled.vmfb
+  │   └── traces  # ...same as iree_llvmjit/traces above.
+  └── tflite
+      ├── basic_matmul.tflite
+      ├── matmul_broadcast_singleton_dimension.tflite
+      ├── matmul_lhs_batch.tflite
+      ├── matmul_rhs_batch.tflite
+      └── traces
+          ├── basic_matmul
+          │   └── graph_path
+          ├── matmul_broadcast_singleton_dimension
+          │   └── graph_path
+          ├── matmul_lhs_batch
+          │   └── graph_path
+          └── matmul_rhs_batch
+              └── graph_path
+```
+
+### Optional: Compile the Keras Applications Vision tests
+
+The vision tests take a while to run, so we exclude them from the CI and
+wildcard expansion. They can be run by invoking the following test suite:
+
+```shell
+bazel test //integrations/tensorflow/e2e/keras:vision_external_tests
+```
+
+The previous command compiles `MobileNet`, `MobileNetV2` and `ResNet50` with
+`cifar10` and `imagenet` weights for all backends. The artifacts generated by
+this test suite differ slightly from those above in that they are organized
+under `/tmp/iree/modules/ModelName/Dataset/backends` instead of just under
+`/tmp/iree/modules/ModelName/backends`.
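+
+For example, after running the vision test suite you can list the artifacts
+generated for `ResNet50` with `cifar10` weights (illustrative; the exact
+contents depend on which backends were run):
+
+```shell
+ls /tmp/iree/modules/ResNet50/cifar10/
+# Expect one directory per backend, e.g. iree_llvmjit, iree_vmla, iree_vulkan.
+```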
+
+## 2. Benchmarking IREE on desktop
+
+### 2.1 Optional: Build the `iree-benchmark-module` binary
+
+This step is optional, but building the binary once lets you run the benchmarks
+without invoking `bazel` each time.
+
+```shell
+bazel build -c opt //iree/tools:iree-benchmark-module
+```
+
+This creates `bazel-bin/iree/tools/iree-benchmark-module`. The rest of the guide
+will use this binary, but you could also use
+`bazel run iree/tools:iree-benchmark-module` in its place if you prefer.
+
+### 2.2 Benchmark the model on IREE
+
+The E2E tests generate a flagfile with all of the information that
+`iree-benchmark-module` needs to benchmark each trace. Namely, it provides the
+following flags:
+
+| Flag              | Description                                      |
+|-------------------|--------------------------------------------------|
+| --input_file      | Absolute path to the IREE compiled VM flatbuffer |
+| --inputs          | A comma-delimited string of input tensors        |
+| --driver          | The backend driver to use for the benchmark      |
+| --entry_function  | The method on the TensorFlow module to benchmark |
+
+You can find the flagfile to benchmark a specific TensorFlow module on a
+specific IREE backend and trace at the following path:
+
+```shell
+/tmp/iree/modules/ModuleName/backend/traces/trace_name/flagfile
+```
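+
+You can `cat` a flagfile to see exactly what will be passed to
+`iree-benchmark-module`. The output described below is illustrative; the actual
+values depend on the module, backend and trace:
+
+```shell
+cat /tmp/iree/modules/MatrixOpsStaticModule/iree_vmla/traces/matmul_lhs_batch/flagfile
+# Expect one flag per line, e.g. --driver=vmla,
+# --entry_function=matmul_lhs_batch, --input_file pointing at compiled.vmfb
+# and --inputs describing the input tensors.
+```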
+
+For example, if we wanted to benchmark a static left-hand-side batched matmul
+using `MatrixOpsStaticModule` on VMLA we would run the following command:
+
+```shell
+./bazel-bin/iree/tools/iree-benchmark-module \
+  --flagfile="/tmp/iree/modules/MatrixOpsStaticModule/iree_vmla/traces/matmul_lhs_batch/flagfile"
+```
+
+If you ran the Keras Applications vision test suite, then you'll be able to
+benchmark `ResNet50`, `MobileNet` or `MobileNetV2` with `cifar10` or `imagenet`
+weights. For example:
+
+```shell
+./bazel-bin/iree/tools/iree-benchmark-module \
+  --flagfile="/tmp/iree/modules/ResNet50/cifar10/iree_vmla/traces/predict/flagfile"
+```
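+
+Flags passed on the command line after `--flagfile` take precedence over the
+values in the flagfile, which is useful if an artifact has moved. For example,
+to point at a compiled flatbuffer in a custom (hypothetical) location:
+
+```shell
+./bazel-bin/iree/tools/iree-benchmark-module \
+  --flagfile="/tmp/iree/modules/ResNet50/cifar10/iree_vmla/traces/predict/flagfile" \
+  --input_file="/path/to/custom/compiled.vmfb"
+```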
+
+## 3. Benchmarking TFLite on desktop
+
+### 3.1 Build TFLite's `benchmark_model` binary
+
+```shell
+# Enter the TensorFlow Bazel workspace.
+cd third_party/tensorflow/
+
+# Build the benchmark_model binary without RUY...
+bazel build --copt=-mavx2 -c opt \
+  //tensorflow/lite/tools/benchmark:benchmark_model
+
+# ...or build the benchmark_model binary with RUY. This will overwrite the
+# previous binary unless you move it.
+bazel build --copt=-mavx2 -c opt \
+  --define=tflite_with_ruy=true \
+  //tensorflow/lite/tools/benchmark:benchmark_model
+
+# The binary can now be found in the following directory:
+ls bazel-bin/tensorflow/lite/tools/benchmark/
+```
+
+### 3.2 Benchmark the model on TFLite
+
+TFLite doesn't support flagfiles, so we need to pass the path to the `.tflite`
+graph file manually, reading it from the trace's `graph_path` file with `cat`.
+TFLite will generate fake inputs for the model.
+
+Using `MatrixOpsStaticModule`'s left-hand-side batched matmul again as an
+example we can run the benchmark as follows:
+
+```shell
+# Run within `third_party/tensorflow/`.
+./bazel-bin/tensorflow/lite/tools/benchmark/benchmark_model \
+  --graph=$(cat "/tmp/iree/modules/MatrixOpsStaticModule/tflite/traces/matmul_lhs_batch/graph_path") \
+  --warmup_runs=1 \
+  --num_threads=1 \
+  --num_runs=100 \
+  --enable_op_profiling=true
+```
+
+## 4. Benchmarking IREE on Android
+
+### 4.1 Prepare the benchmarking tools
+
+IREE only supports compiling to Android with CMake. Documentation on setting up
+your environment to cross-compile to Android can be found
+[here](https://google.github.io/iree/get-started/getting-started-android-cmake).
+
+```shell
+# After following the instructions above up to 'Build all targets', the
+# iree-benchmark-module binary should be in the following directory:
+ls build-android/iree/tools/
+
+# Copy the benchmarking binary to the phone.
+adb push build-android/iree/tools/iree-benchmark-module /data/local/tmp
+
+# Allow executing the benchmarking binary as a program.
+adb shell chmod +x /data/local/tmp/iree-benchmark-module
+```
+
+### 4.2 Push IREE's compilation / benchmarking artifacts to the device
+
+In this example we'll only copy over the files we need to benchmark a single
+module on a single backend, but you can easily copy all of the modules over
+as well.
+
+Using `MatrixOpsStaticModule`'s left-hand-side batched matmul again as an
+example:
+
+```shell
+# Make a directory on the device for the module/backend pair we want to benchmark.
+adb shell mkdir -p /data/local/tmp/MatrixOpsStaticModule/iree_vmla/
+
+# Transfer the files.
+adb push /tmp/iree/modules/MatrixOpsStaticModule/iree_vmla/* \
+  /data/local/tmp/MatrixOpsStaticModule/iree_vmla/
+```
+
+### 4.3 Benchmark the module
+
+```shell
+adb shell /data/local/tmp/iree-benchmark-module \
+  --flagfile="/data/local/tmp/MatrixOpsStaticModule/iree_vmla/traces/matmul_lhs_batch/flagfile"
+  --input_file="/data/local/tmp/MatrixOpsStaticModule/iree_vmla/compiled.vmfb"
+```
+
+Note: Because the flagfile uses absolute paths, the `--input_file` flag must be
+specified manually if the location of the compiled flatbuffer (`compiled.vmfb`)
+changes. The flagfile can still take care of specifying the input data, driver
+and entry function, however.
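+
+If you're unsure which values need to be overridden, you can inspect the
+flagfile that was pushed to the device:
+
+```shell
+# --input_file will still point at the original host path under
+# /tmp/iree/modules/, which is why it is overridden above.
+adb shell cat /data/local/tmp/MatrixOpsStaticModule/iree_vmla/traces/matmul_lhs_batch/flagfile
+```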
+
+## 5. Benchmarking TFLite on Android
+
+### 5.1 Prepare the benchmarking tools
+
+There are three options for getting TFLite's `benchmark_model` binary for
+Android.
+
+The first two are to build it directly, either in a
+[`docker` container](https://www.tensorflow.org/lite/guide/build_android#set_up_build_environment_using_docker)
+or
+[in your own environment](https://www.tensorflow.org/lite/guide/build_android#set_up_build_environment_without_docker).
+Assuming you can build TensorFlow for Android, you can configure the TFLite
+`benchmark_model` binary in the following ways:
+
+```shell
+# Build the benchmark_model binary without any add-ons.
+bazel build -c opt \
+  --config=android_arm64 \
+  --cxxopt='--std=c++17' \
+  //tensorflow/lite/tools/benchmark:benchmark_model
+
+# Copy the benchmarking binary to phone and allow execution.
+adb push bazel-bin/tensorflow/lite/tools/benchmark/benchmark_model \
+  /data/local/tmp
+adb shell chmod +x /data/local/tmp/benchmark_model
+```
+
+```shell
+# Build the benchmark_model binary with RUY.
+bazel build -c opt \
+  --config=android_arm64 \
+  --cxxopt='--std=c++17' \
+  --define=tflite_with_ruy=true \
+  --copt=-DTFLITE_WITH_RUY_GEMV \
+  //tensorflow/lite/tools/benchmark:benchmark_model
+
+# Rename the binary for comparison with the standard benchmark_model.
+mv bazel-bin/tensorflow/lite/tools/benchmark/benchmark_model \
+  bazel-bin/tensorflow/lite/tools/benchmark/benchmark_model_plus_ruy
+adb push bazel-bin/tensorflow/lite/tools/benchmark/benchmark_model_plus_ruy \
+  /data/local/tmp/
+adb shell chmod +x /data/local/tmp/benchmark_model_plus_ruy
+```
+
+```shell
+# Build the benchmark_model binary with flex.
+bazel build -c opt \
+  --config=android_arm64 \
+  --cxxopt='--std=c++17' \
+  //tensorflow/lite/tools/benchmark:benchmark_model_plus_flex
+
+# Copy the benchmarking binary to phone and allow execution.
+adb push bazel-bin/tensorflow/lite/tools/benchmark/benchmark_model_plus_flex \
+  /data/local/tmp
+adb shell chmod +x /data/local/tmp/benchmark_model_plus_flex
+```
+
+Alternatively, you can download and install the
+[Android Benchmark App](https://www.tensorflow.org/lite/performance/measurement#android_benchmark_app).
+If you choose to install the app, you'll have to modify the benchmarking
+commands below slightly, as shown in
+[this example](https://www.tensorflow.org/lite/performance/measurement#run_benchmark).
+
+### 5.2 Run the benchmark
+
+```shell
+# Make a directory on the device and copy the data over to it.
+adb shell mkdir -p /data/local/tmp/MatrixOpsStaticModule/tflite
+adb push /tmp/iree/modules/MatrixOpsStaticModule/tflite/* \
+  /data/local/tmp/MatrixOpsStaticModule/tflite/
+```
+
+```shell
+# Benchmark with TFLite.
+adb shell taskset f0 /data/local/tmp/benchmark_model \
+  --graph=/data/local/tmp/MatrixOpsStaticModule/tflite/matmul_lhs_batch.tflite \
+  --warmup_runs=1 \
+  --num_threads=1 \
+  --num_runs=10 \
+  --enable_op_profiling=true
+```
+
+```shell
+# Benchmark with TFLite + RUY.
+adb shell taskset f0 /data/local/tmp/benchmark_model_plus_ruy \
+  --graph=/data/local/tmp/MatrixOpsStaticModule/tflite/matmul_lhs_batch.tflite \
+  --warmup_runs=1 \
+  --num_threads=1 \
+  --num_runs=10 \
+  --enable_op_profiling=true
+```
+
+```shell
+# Benchmark with TFLite + Flex.
+adb shell taskset f0 /data/local/tmp/benchmark_model_plus_flex \
+  --graph=/data/local/tmp/MatrixOpsStaticModule/tflite/matmul_lhs_batch.tflite \
+  --warmup_runs=1 \
+  --num_threads=1 \
+  --num_runs=10 \
+  --enable_op_profiling=true
+```
+
+Note: You will have to manually specify the TFLite graph that you want to
+benchmark, as the path stored in the `graph_path` file points to the graph's
+location on your host machine rather than on the device. The name of the
+`.tflite` graph that you need to benchmark _may_ be different from the name of
+the trace that you want to benchmark, but you can use `cat` on the `graph_path`
+file to verify the correct `.tflite` filename if you're unsure.
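+
+For example, to check which `.tflite` file the `matmul_lhs_batch` trace points
+to, run the following on your host machine (the output shown is illustrative):
+
+```shell
+cat /tmp/iree/modules/MatrixOpsStaticModule/tflite/traces/matmul_lhs_batch/graph_path
+# /tmp/iree/modules/MatrixOpsStaticModule/tflite/matmul_lhs_batch.tflite
+```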
diff --git a/integrations/tensorflow/e2e/README.md b/integrations/tensorflow/e2e/README.md
index 9a97e02..6f77741 100644
--- a/integrations/tensorflow/e2e/README.md
+++ b/integrations/tensorflow/e2e/README.md
@@ -16,8 +16,8 @@
 ## Vulkan Setup
 
 If you do not have your environment setup to use IREE with Vulkan (see
-[this doc](https://google.github.io/iree/get-started/generic-vulkan-env-setup)), 
-then you can run the manual test targets with 
+[this doc](https://google.github.io/iree/get-started/generic-vulkan-env-setup)),
+then you can run the manual test targets with
 `--target_backends=tf,iree_vmla,iree_llvmjit` (that is, by omitting
 `iree_vulkan` from the list of backends to run the tests on).
 
@@ -153,25 +153,37 @@
 can be changed via the `--artifacts_dir` flag. The generated directory structure
 for each module is as follows:
 
-```
+```shell
 /tmp/iree/modules/ModuleName
-├── tf_input.mlir        # MLIR for ModuleName in TF's input dialect
-├── iree_input.mlir      # tf_input.mlir translated to IREE MLIR
-├── iree_backend_name    # e.g. iree_vmla, iree_llvmjit or iree_vulkan
-│   ├── compiled.vmfb    # flatbuffer of ModuleName compiled to this backend
-│   └── traces
-│       ├── trace_1      # Directory storing logs and serialization for each trace
-│       │   └── log.txt  # A more detailed version of the test logs
-│       └── trace_2
-│           └── log.txt
-├── tflite               # If TFLite supports compiling ModuleName
-│   ├── method_1.tflite  # Methods on ModuleName compiled to bytes with TFLite
-│   ├── method_2.tflite
-│   └── traces
-│       └── ...
-└── tf_ref               # Directory storing the tensorflow reference traces
-    └── traces
-        └── ...
+  ├── tf_input.mlir
+  │   # MLIR for ModuleName in TF's input dialect.
+  ├── iree_input.mlir
+  │   # tf_input.mlir translated to IREE MLIR.
+  ├── iree_vmla
+  │   # Or any other IREE backend.
+  │   ├── compiled.vmfb
+  │   │   # A flatbuffer containing IREE's compiled code.
+  │   └── traces
+  │       # Directory with a trace for each unittest in the module's test file.
+  │       ├── trace_function_1
+  │       │   # Directory storing logs and serialization for a specific trace.
+  │       │   ├── flagfile
+  │       │   │   # An Abseil flagfile containing arguments
+  │       │   │   # iree-benchmark-module needs to benchmark this trace.
+  │       │   └── log.txt
+  │       │       # A more detailed version of the test logs.
+  │       ├── trace_function_2
+  │       └── ...
+  ├── tflite  # If TFLite supports compiling ModuleName.
+  │   ├── method_1.tflite
+  │   │   # A method on ModuleName compiled to bytes with TFLite, which can
+  │   │   # be ingested by TFLite's benchmark_model binary.
+  │   ├── method_2.tflite
+  │   └── traces
+  │       └── ...
+  └── tf_ref  # Directory storing the TensorFlow reference traces.
+      └── traces
+          └── ...
 ```
 
 Traces for a particular test can be loaded via the `Trace.load(trace_dir)`
@@ -190,36 +202,11 @@
 
 ## Benchmarking E2E Modules
 
-Abseil flagfiles containing all of the data that `iree-benchmark-module` needs
-to run are generated for each `Trace` in our E2E tests. This allows for any
-module we test to be easily benchmarked on valid inputs. The process for
-benchmarking a vision model can thus be reduced to the following:
-
-```shell
-# Generate benchmarking artifacts for all vision models:
-bazel test integrations/tensorflow/e2e/keras:vision_external_tests
-
-# Benchmark ResNet50 with cifar10 weights on vmla:
-bazel run iree/tools:iree-benchmark-module -- \
-  --flagfile=/tmp/iree/modules/ResNet50/cifar10/iree_vmla/traces/predict/flagfile
-
-# Benchmark ResNet50 with cifar10 weights on llvmjit:
-bazel run iree/tools:iree-benchmark-module -- \
-  --flagfile=/tmp/iree/modules/ResNet50/cifar10/iree_llvmjit/traces/predict/flagfile
-```
-
-Duplicate flags provided after the flagfile will take precedence. For example:
-
-```shell
-bazel run iree/tools:iree-benchmark-module -- \
-  --flagfile=/tmp/iree/modules/ResNet50/cifar10/iree_llvmjit/traces/predict/flagfile  \
-  --input_file=/path/to/custom/compiled.vmfb
-```
-
-Currently, this only supports benchmarking the first module call in a trace. We
-plan to extend this to support benchmarking all of the calls in the trace, and
-also plan to support verifying outputs during the warm-up phase of the
-benchmark.
+We use our end-to-end TensorFlow integration tests to generate tested
+compilation and benchmarking artifacts. This allows us to validate that our
+benchmarks are behaving as we expect them to, and to run them using valid inputs
+for each model. An overview of how to run benchmarks on IREE and TFLite can be
+found in [this doc](TODO(meadowlark)).
 
 ## Debugging Tests
 
diff --git a/integrations/tensorflow/e2e/matrix_ops_static_test.py b/integrations/tensorflow/e2e/matrix_ops_static_test.py
index 14bd361..82fa482 100644
--- a/integrations/tensorflow/e2e/matrix_ops_static_test.py
+++ b/integrations/tensorflow/e2e/matrix_ops_static_test.py
@@ -18,33 +18,38 @@
 from pyiree.tf.support import tf_utils
 import tensorflow.compat.v2 as tf
 
+LEFT_DIM = 64
+INNER_DIM = 32
+RIGHT_DIM = 16
+BATCH_DIM = 256
+
 
 class MatrixOpsStaticModule(tf.Module):
 
   @tf.function(input_signature=[
-      tf.TensorSpec([4, 2], tf.float32),
-      tf.TensorSpec([2, 4], tf.float32),
+      tf.TensorSpec([LEFT_DIM, INNER_DIM], tf.float32),
+      tf.TensorSpec([INNER_DIM, RIGHT_DIM], tf.float32),
   ])
   def basic_matmul(self, lhs, rhs):
     return tf.matmul(lhs, rhs)
 
   @tf.function(input_signature=[
-      tf.TensorSpec([3, 4, 2], tf.float32),
-      tf.TensorSpec([2, 4], tf.float32),
+      tf.TensorSpec([BATCH_DIM, LEFT_DIM, INNER_DIM], tf.float32),
+      tf.TensorSpec([INNER_DIM, RIGHT_DIM], tf.float32),
   ])
   def matmul_lhs_batch(self, lhs, rhs):
     return tf.matmul(lhs, rhs)
 
   @tf.function(input_signature=[
-      tf.TensorSpec([4, 2], tf.float32),
-      tf.TensorSpec([3, 2, 4], tf.float32),
+      tf.TensorSpec([LEFT_DIM, INNER_DIM], tf.float32),
+      tf.TensorSpec([BATCH_DIM, INNER_DIM, RIGHT_DIM], tf.float32),
   ])
   def matmul_rhs_batch(self, lhs, rhs):
     return tf.matmul(lhs, rhs)
 
   @tf.function(input_signature=[
-      tf.TensorSpec([1, 4, 2], tf.float32),
-      tf.TensorSpec([3, 2, 4], tf.float32),
+      tf.TensorSpec([1, LEFT_DIM, INNER_DIM], tf.float32),
+      tf.TensorSpec([BATCH_DIM, INNER_DIM, RIGHT_DIM], tf.float32),
   ])
   def matmul_broadcast_singleton_dimension(self, lhs, rhs):
     return tf.matmul(lhs, rhs)
@@ -56,7 +61,8 @@
   def test_basic_matmul(self):
 
     def basic_matmul(module):
-      module.basic_matmul(tf_utils.uniform([4, 2]), tf_utils.uniform([2, 4]))
+      module.basic_matmul(tf_utils.uniform([LEFT_DIM, INNER_DIM]),
+                          tf_utils.uniform([INNER_DIM, RIGHT_DIM]))
 
     self.compare_backends(basic_matmul)
 
@@ -64,7 +70,8 @@
 
     def matmul_lhs_batch(module):
       module.matmul_lhs_batch(
-          tf_utils.uniform([3, 4, 2]), tf_utils.uniform([2, 4]))
+          tf_utils.uniform([BATCH_DIM, LEFT_DIM, INNER_DIM]),
+          tf_utils.uniform([INNER_DIM, RIGHT_DIM]))
 
     self.compare_backends(matmul_lhs_batch)
 
@@ -72,7 +79,8 @@
 
     def matmul_rhs_batch(module):
       module.matmul_rhs_batch(
-          tf_utils.uniform([4, 2]), tf_utils.uniform([3, 2, 4]))
+          tf_utils.uniform([LEFT_DIM, INNER_DIM]),
+          tf_utils.uniform([BATCH_DIM, INNER_DIM, RIGHT_DIM]))
 
     self.compare_backends(matmul_rhs_batch)
 
@@ -80,7 +88,8 @@
 
     def matmul_broadcast_singleton_dimension(module):
       module.matmul_broadcast_singleton_dimension(
-          tf_utils.uniform([1, 4, 2]), tf_utils.uniform([3, 2, 4]))
+          tf_utils.uniform([1, LEFT_DIM, INNER_DIM]),
+          tf_utils.uniform([BATCH_DIM, INNER_DIM, RIGHT_DIM]))
 
     self.compare_backends(matmul_broadcast_singleton_dimension)
 
diff --git a/scripts/prepare_doc_publication.py b/scripts/prepare_doc_publication.py
index d4c7d37..94c094d 100755
--- a/scripts/prepare_doc_publication.py
+++ b/scripts/prepare_doc_publication.py
@@ -61,7 +61,8 @@
     'getting_started_python.md': 'Python',
     'milestones.md': 'Short-term Focus Areas',
     'design_roadmap.md': 'Long-term Design Roadmap',
-    'tensorflow_integrations.md': 'TensorFlow Integrations and Benchmarking',
+    'tensorflow_integrations.md': 'TensorFlow Integrations',
+    'e2e_benchmarking.md': 'Benchmarking TensorFlow with IREE and TFLite',
 }
 
 # A dictionary containing source file to permanent link mappings.
@@ -116,7 +117,8 @@
     'testing_guide.md': 3,
     'benchmarking.md': 4,
     'tensorflow_integrations.md': 5,
-    'repository_management.md': 6,
+    'e2e_benchmarking.md': 6,
+    'repository_management.md': 7,
 
     # Within 'Using IREE' use explicit ordering.
     'using_colab.md': 1,