Merge pull request #3533

d9d1dd6 Synchronize submodules
298c725 Integrate TF at tensorflow/tensorflow@2d6bdab
503b47a Merge pull request #3526 from rsuderman:main-to-google
5dde072 Synchronize submodules
cea398a Integrate TF at tensorflow/tensorflow@de10689
0a32dad Synchronize submodules
21bdab3 Integrate LLVM at llvm/llvm-project@f402e68
06daf26 Add folder for mhlo::pad
05c8e7c Synchronize submodules
8f24911 Integrate LLVM at llvm/llvm-project@b740899
diff --git a/.github/workflows/lint.yml b/.github/workflows/lint.yml
index ad9635a..d5d5bb0 100644
--- a/.github/workflows/lint.yml
+++ b/.github/workflows/lint.yml
@@ -59,9 +59,10 @@
         run: |
           git diff -U0 "${BASE_REF?}" | python3 third_party/format_diff/format_diff.py yapf -i
           git diff --exit-code
-      - name: Instructions for fixing these linting errors
+      - name: Instructions for fixing the above linting errors
+        if: ${{ failure() }}
         run: |
-          printf "If the lint above failed it can be fixed by running\n"
+          printf "You can fix the lint errors above by running\n"
           printf "  git diff -U0 "${BASE_REF?}" | python3 third_party/format_diff/format_diff.py yapf -i\n"
 
   clang-format:
diff --git a/build_tools/cmake/iree_copts.cmake b/build_tools/cmake/iree_copts.cmake
index 9d99d9b..74a29c4 100644
--- a/build_tools/cmake/iree_copts.cmake
+++ b/build_tools/cmake/iree_copts.cmake
@@ -39,7 +39,7 @@
 )
 
 if(${IREE_ENABLE_RUNTIME_TRACING})
-  set (CMAKE_EXE_LINKER_FLAGS -ldl)
+  string (APPEND CMAKE_EXE_LINKER_FLAGS " -ldl")  # Leading space keeps -ldl separate from any existing flags.
 endif()
 
 iree_select_compiler_opts(IREE_DEFAULT_COPTS
diff --git a/docs/developing_iree/e2e_benchmarking.md b/docs/developing_iree/e2e_benchmarking.md
index 3fe8604..96523c4 100644
--- a/docs/developing_iree/e2e_benchmarking.md
+++ b/docs/developing_iree/e2e_benchmarking.md
@@ -178,12 +178,19 @@
 # Enter the TensorFlow Bazel workspace.
 cd third_party/tensorflow/
 
-# Build the benchmark_model binary without RUY...
+# Build the benchmark_model binary.
 bazel build --copt=-mavx2 -c opt \
   //tensorflow/lite/tools/benchmark:benchmark_model
 
-# ...or build the benchmark_model binary with RUY. This will overwrite the
+# By default, TFLite/x86 uses various matrix multiplication libraries.
+# It is possible to force it to only use Ruy for all matrix multiplications.
+# That is the default on ARM but not on x86. This will overwrite the
 # previous binary unless you move it.
+#
+# Note that Ruy takes care of -mavx2 and other AVX extensions internally,
+# so passing this flag here isn't going to make a difference to
+# matrix multiplications. However, the rest of TFLite's kernels outside
+# of ruy will still benefit from -mavx2.
 bazel build --copt=-mavx2 -c opt \
   --define=tflite_with_ruy=true \
   //tensorflow/lite/tools/benchmark:benchmark_model
@@ -274,6 +281,9 @@
 
 ```shell
 # Build the benchmark_model binary without any add-ons.
+# Note that unlike TFLite/x86, TFLite/ARM uses Ruy by default for all
+# matrix multiplications (no need to pass tflite_with_ruy), except for some
+# matrix*vector products. Below we show how to force using Ruy for those too.
 bazel build -c opt \
   --config=android_arm64 \
   --cxxopt='--std=c++17' \
@@ -286,20 +296,21 @@
 ```
 
 ```shell
-# Build the benchmark_model binary with ruy.
-bazel build --copt=-mavx2 -c opt \
+# Build the benchmark_model binary using ruy even for matrix*vector
+# products. This is only worth trying in models that are heavy on matrix*vector
+# shapes, typically LSTMs and other RNNs.
+bazel build -c opt \
   --config=android_arm64 \
   --cxxopt='--std=c++17' \
-  --define=tflite_with_ruy=true \
   --copt=-DTFLITE_WITH_RUY_GEMV \
   //tensorflow/lite/tools/benchmark:benchmark_model
 
 # Rename the binary for comparison with the standard benchmark_model.
 mv bazel-bin/tensorflow/lite/tools/benchmark/benchmark_model \
-  bazel-bin/tensorflow/lite/tools/benchmark/benchmark_model_plus_ruy
-adb push bazel-bin/tensorflow/lite/tools/benchmark/benchmark_model_plus_ruy \
+  bazel-bin/tensorflow/lite/tools/benchmark/benchmark_model_plus_ruy_gemv
+adb push bazel-bin/tensorflow/lite/tools/benchmark/benchmark_model_plus_ruy_gemv \
   /data/local/tmp/
-adb shell chmod +x /data/local/tmp/benchmark_model_plus_ruy
+adb shell chmod +x /data/local/tmp/benchmark_model_plus_ruy_gemv
 ```
 
 ```shell
@@ -336,17 +347,15 @@
   --warmup_runs=1 \
   --num_threads=1 \
   --num_runs=10 \
-  --enable_op_profiling=true
 ```
 
 ```shell
-# Benchmark with TFLite + RUY.
-adb shell taskset f0 /data/local/tmp/benchmark_model_plus_ruy \
+# Benchmark with TFLite + RUY GEMV
+adb shell taskset f0 /data/local/tmp/benchmark_model_plus_ruy_gemv \
   --graph=/data/local/tmp/MatrixOpsStaticModule/tflite/matmul_lhs_batch.tflite \
   --warmup_runs=1 \
   --num_threads=1 \
   --num_runs=10 \
-  --enable_op_profiling=true
 ```
 
 ```shell
@@ -356,7 +365,6 @@
   --warmup_runs=1 \
   --num_threads=1 \
   --num_runs=10 \
-  --enable_op_profiling=true
 ```
 
 ```shell
@@ -366,7 +374,6 @@
   --warmup_runs=1 \
   --num_threads=1 \
   --num_runs=10 \
-  --enable_op_profiling=true \
   --use_gpu=true
 ```
 
@@ -384,3 +391,30 @@
 name of the `.tflite` graph that you need to benchmark _may_ be different from
 the name of the trace that you want to benchmark, but you can use `cat` on
 the `graph_path` file to verify the correct `.tflite` filename if you're unsure.
+
+### Profile
+
+There are 2 profilers built into TFLite's `benchmark_model` program. Both of them impact latencies, so they should only be used to get a breakdown of the relative time spent in each operator type; they should not be enabled when measuring latency.
+
+The first is `enable_op_profiling`. It's based on timestamps before and after each op. It's a runtime commandline flag taken by `benchmark_model`. Example:
+
+```
+adb shell taskset f0 /data/local/tmp/benchmark_model \
+  --graph=/data/local/tmp/MatrixOpsStaticModule/tflite/matmul_lhs_batch.tflite \
+  --warmup_runs=1 \
+  --num_threads=1 \
+  --num_runs=10 \
+  --enable_op_profiling=true
+```
+
+The second is `ruy_profiler`. Despite its name, it's available regardless of whether `ruy` is used for the matrix multiplications. It's a sampling profiler, which allows it to provide more detailed information, particularly on matrix multiplications. It's a build-time switch:
+
+```
+bazel build \
+  --define=ruy_profiler=true \
+  -c opt \
+  --config=android_arm64 \
+  //tensorflow/lite/tools/benchmark:benchmark_model
+```
+
+The binary thus built can be run like above, no commandline flag needed.
diff --git a/iree/compiler/Conversion/LinalgToSPIRV/ConvertToSPIRVPass.cpp b/iree/compiler/Conversion/LinalgToSPIRV/ConvertToSPIRVPass.cpp
index 899a485..93b8809 100644
--- a/iree/compiler/Conversion/LinalgToSPIRV/ConvertToSPIRVPass.cpp
+++ b/iree/compiler/Conversion/LinalgToSPIRV/ConvertToSPIRVPass.cpp
@@ -204,7 +204,7 @@
     if (vecType.getRank() != 2) return failure();
     // TODO(thomasraoux): use coloumn major operand when TransfertRead +
     // TransposeOp.
-    if (!op.permutation_map().isIdentity()) return failure();
+    if (!op.permutation_map().isMinorIdentity()) return failure();
     if (op.masked() &&
         llvm::any_of(op.masked()->template cast<ArrayAttr>(),
                      [](mlir::Attribute maskedDim) {
diff --git a/iree/compiler/Conversion/LinalgToSPIRV/test/convert_to_spirv.mlir b/iree/compiler/Conversion/LinalgToSPIRV/test/convert_to_spirv.mlir
index 52aeff7..ac83d70 100644
--- a/iree/compiler/Conversion/LinalgToSPIRV/test/convert_to_spirv.mlir
+++ b/iree/compiler/Conversion/LinalgToSPIRV/test/convert_to_spirv.mlir
@@ -104,3 +104,41 @@
     return
   }
 }
+
+// -----
+
+#map0 = affine_map<(d0, d1) -> (d1)>
+#map1 = affine_map<(d0, d1, d2) -> (d0, d2)>
+#map2 = affine_map<(d0, d1, d2) -> (d2, d1)>
+#map3 = affine_map<(d0, d1, d2) -> (d0, d1)>
+
+module attributes {gpu.container_module, spv.target_env = #spv.target_env<#spv.vce<v1.0, [Shader, CooperativeMatrixNV, Int8, Float16, StorageUniform16, StorageBuffer8BitAccess, Float16Buffer], [SPV_KHR_storage_buffer_storage_class, SPV_NV_cooperative_matrix, SPV_KHR_8bit_storage, SPV_KHR_16bit_storage]>, {max_compute_workgroup_invocations = 128 : i32, max_compute_workgroup_size = dense<[128, 128, 64]> : vector<3xi32>}>} {
+  func @kernel_matmul_vector_memref(%arg0: memref<4096x256xvector<4xi32>>, %arg1: memref<4096x256xvector<4xi32>>, %arg2: memref<4096x1024xvector<4xi32>>) attributes {spv.entry_point_abi = {local_size = dense<[32, 1, 1]> : vector<3xi32>}} {
+    %c32 = constant 32 : index
+    %c4096 = constant 4096 : index
+    %c0 = constant 0 : index
+    %cst = constant dense<0> : vector<4xi32>
+    // CHECK: %[[C:.+]] = spv.CooperativeMatrixLoadNV %{{.*}}, %{{.*}}, %{{.*}}
+    %4 = vector.transfer_read %arg2[%c0, %c0], %cst : memref<4096x1024xvector<4xi32>>, vector<16x16xi32>
+    // CHECK: %[[ACC:.+]] = spv.Variable : !spv.ptr<!spv.coopmatrix<16x16xi32, Subgroup>, Function>
+    // CHECK: spv.loop {
+      // CHECK: spv.Branch ^[[BB:.+]](%{{.*}}, %[[C]] : i32, !spv.coopmatrix<16x16xi32, Subgroup>)
+      // CHECK: ^[[BB]](%{{.*}}: i32, %[[C1:.+]]: !spv.coopmatrix<16x16xi32, Subgroup>)
+    %5 = scf.for %arg3 = %c0 to %c4096 step %c32 iter_args(%arg4 = %4) -> (vector<16x16xi32>) {
+      // CHECK: %[[A:.+]] = spv.CooperativeMatrixLoadNV %{{.*}}, %{{.*}}, %{{.*}}
+      %6 = vector.transfer_read %arg0[%c0, %arg3], %cst : memref<4096x256xvector<4xi32>>, vector<16x32xi8>
+      // CHECK: %[[B:.+]] = spv.CooperativeMatrixLoadNV %{{.*}}, %{{.*}}, %{{.*}}
+      %7 = vector.transfer_read %arg1[%arg3, %c0], %cst : memref<4096x256xvector<4xi32>>, vector<32x16xi8>
+      // CHECK: %[[R:.+]] = spv.CooperativeMatrixMulAddNV %[[A]], %[[B]], %[[C1]]
+      %8 = vector.contract {indexing_maps = [#map1, #map2, #map3], iterator_types = ["parallel", "parallel", "reduction"]} %6, %7, %arg4 : vector<16x32xi8>, vector<32x16xi8> into vector<16x16xi32>
+      // CHECK: spv.Store "Function" %[[ACC]], %[[R]] : !spv.coopmatrix<16x16xi32, Subgroup>
+      // CHECK: spv.Branch ^[[BB]](%{{.*}}, %[[R]] : i32, !spv.coopmatrix<16x16xi32, Subgroup>)
+      scf.yield %8 : vector<16x16xi32>
+    }
+    // CHECK: %[[ACCv:.+]] = spv.Load "Function" %[[ACC]] : !spv.coopmatrix<16x16xi32, Subgroup>
+    // CHECK: spv.CooperativeMatrixStoreNV %{{.*}}, %[[ACCv]], %{{.*}}, %{{.*}}
+    vector.transfer_write %5, %arg2[%c0, %c0] : vector<16x16xi32>, memref<4096x1024xvector<4xi32>>
+    return
+  }
+}
+