Account for all the element types to determine vector sizes. (#8552)

The assumption was that all the element types have the same bitwidth.
However, there are cases where the element types do not match, e.g.,
matmul i8 x i8 -> i32. This caused overly large tiling sizes to be
selected, which triggered heavy optimization work in LLVM. This commit
chooses the smallest vector size across all the element types.
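
As a minimal sketch of the new selection (assuming getVectorSize divides the
target's native vector size in bytes by the element byte width; the helper
name lanesFor is illustrative, not the exact code):

    #include <algorithm>
    #include <cstdint>

    // Assumed behavior of getVectorSize: lanes = native vector bytes
    // divided by the element byte width.
    int64_t lanesFor(int64_t nativeVectorBytes, int64_t elemBits) {
      return nativeVectorBytes / (elemBits / 8);
    }

    int main() {
      const int64_t nativeVectorBytes = 16;           // e.g. 128-bit NEON
      int64_t lhs = lanesFor(nativeVectorBytes, 8);   // i8  -> 16 lanes
      int64_t rhs = lanesFor(nativeVectorBytes, 8);   // i8  -> 16 lanes
      int64_t res = lanesFor(nativeVectorBytes, 32);  // i32 -> 4 lanes
      // Before: only the LHS type was considered -> 16 lanes.
      // After: take the smallest across all element types -> 4 lanes.
      return std::min({lhs, rhs, res});  // 4
    }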

This also updates the first-level tiling logic to follow what we've
done for generic ops.
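
For a concrete trace of that per-dimension logic on the i8 x i8 -> i32 matmul
(lane counts as in the sketch above; this condenses the
getMinTilingSizesForEachDim helper added in the diff):

    // linalg.matmul indexing maps: lhs (m, k), rhs (k, n), result (m, n).
    // The fastest-varying dim of each operand is the last map result:
    //   lhs (i8,  16 lanes) -> minTileSizes[k] = max(1, 16) = 16
    //   rhs (i8,  16 lanes) -> minTileSizes[n] = max(1, 16) = 16
    //   res (i32,  4 lanes) -> minTileSizes[n] = max(16, 4) = 16
    // m is never the fastest-varying dim here, so minTileSizes[m] stays 1.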

This reduces compilation time from hours to 5 minutes for
mobilebert-baseline-tf2-quant.mlir when targeting ARM.

Fixes https://github.com/google/iree/issues/8540
diff --git a/iree/compiler/Codegen/LLVMCPU/KernelDispatch.cpp b/iree/compiler/Codegen/LLVMCPU/KernelDispatch.cpp
index 726cdd8..db883f2 100644
--- a/iree/compiler/Codegen/LLVMCPU/KernelDispatch.cpp
+++ b/iree/compiler/Codegen/LLVMCPU/KernelDispatch.cpp
@@ -99,6 +99,33 @@
   return getVectorSize(entryPointFn, byteWidth);
 }
 
+/// Returns the minimum tiling sizes for each dimension. A dimension may be
+/// accessed with different element types, so the tiling sizes are determined
+/// by looking at all the operands.
+static SmallVector<int64_t> getMinTilingSizesForEachDim(FuncOp entryPointFn,
+                                                        linalg::LinalgOp op) {
+  unsigned numLoops = op.getNumLoops();
+  SmallVector<int64_t> minTileSizes(numLoops, 1);
+  auto inputOutputOpOperands = op.getInputAndOutputOperands();
+  for (auto map : llvm::enumerate(op.getIndexingMaps())) {
+    // Check the fastest-varying dimension of the operand. Set the minimum
+    // tile size of the corresponding loop to the operand's vector size.
+    if (map.value().getNumResults() == 0) continue;
+    auto fastestVaryingDimExpr =
+        map.value().getResults().back().dyn_cast<AffineDimExpr>();
+    if (!fastestVaryingDimExpr) continue;
+    unsigned fastestVaryingDim = fastestVaryingDimExpr.getPosition();
+
+    // If the indexing map has results, the operand has to be a shaped type.
+    auto operandType =
+        inputOutputOpOperands[map.index()]->get().getType().cast<ShapedType>();
+    minTileSizes[fastestVaryingDim] =
+        std::max<int64_t>(minTileSizes[fastestVaryingDim],
+                          getVectorSize(entryPointFn, operandType));
+  }
+  return minTileSizes;
+}
+
 /// Returns the type length in bytes. Looks through all the interface binding
 /// ops to see the ABI types and guess-timates the type size to use. This is
 /// used to convert the vector size in bytes to vector size in number of
@@ -409,11 +436,20 @@
     FuncOp entryPointFn, linalg::ContractionOpInterface contractionOp,
     ArrayRef<LoopTilingAndDistributionInfo> tiledLoops) {
   auto linalgOp = cast<linalg::LinalgOp>(contractionOp.getOperation());
+  // Consider all element types and use the smallest vector size. The tiling
+  // sizes are chosen based on the vector size.
   auto lhsShapedType = contractionOp.lhs().getType().cast<ShapedType>();
+  auto rhsShapedType = contractionOp.rhs().getType().cast<ShapedType>();
+  auto resShapedType =
+      linalgOp.getOutputOperand(0)->get().getType().cast<ShapedType>();
+  int64_t vectorSize = getVectorSize(entryPointFn, lhsShapedType);
+  vectorSize = std::min(vectorSize, getVectorSize(entryPointFn, rhsShapedType));
+  vectorSize = std::min(vectorSize, getVectorSize(entryPointFn, resShapedType));
+
   // Use the default distribution for the matmul loops.
   unsigned numLoops = linalgOp.getNumLoops();
-  int64_t vectorSize = getVectorSize(entryPointFn, lhsShapedType);
-  SmallVector<int64_t> minTileSizes(numLoops, vectorSize);
+  SmallVector<int64_t> minTileSizes =
+      getMinTilingSizesForEachDim(entryPointFn, linalgOp);
   SmallVector<int64_t> maxTileSizes(numLoops, defaultWorkgroupTileSize);
   if (numLoops > 3) {
     minTileSizes[0] = 1;
@@ -539,25 +575,9 @@
   unsigned numLoops = genericOp.getNumLoops();
   if (numLoops == 0) return success();
 
-  SmallVector<int64_t> minTileSizes(numLoops, 1),
-      maxTileSizes(numLoops, defaultWorkgroupTileSize);
-  auto inputOutputOpOperands = genericOp.getInputAndOutputOperands();
-  for (auto map : llvm::enumerate(genericOp.getIndexingMaps())) {
-    // Check the fastest varying dimension of the operand. Set the vector size
-    // of the corresponding loop to the vector size.
-    if (map.value().getNumResults() == 0) continue;
-    auto fastestVaryingDimExpr =
-        map.value().getResults().back().dyn_cast<AffineDimExpr>();
-    if (!fastestVaryingDimExpr) continue;
-    unsigned fastestVaryingDim = fastestVaryingDimExpr.getPosition();
-
-    // If the indexing map has result it has to be a shaped type.
-    auto operandType =
-        inputOutputOpOperands[map.index()]->get().getType().cast<ShapedType>();
-    minTileSizes[fastestVaryingDim] =
-        std::max<int64_t>(minTileSizes[fastestVaryingDim],
-                          getVectorSize(entryPointFn, operandType));
-  }
+  SmallVector<int64_t> minTileSizes =
+      getMinTilingSizesForEachDim(entryPointFn, genericOp);
+  SmallVector<int64_t> maxTileSizes(numLoops, defaultWorkgroupTileSize);
   if (llvm::all_of(minTileSizes, [](int64_t vs) { return vs == 1; })) {
     // Nothing to vectorize just lower to loops.
     return success();
diff --git a/iree/compiler/Codegen/LLVMCPU/test/materialize_launch_configuration.mlir b/iree/compiler/Codegen/LLVMCPU/test/materialize_launch_configuration.mlir
index 3445c60..4f219c4 100644
--- a/iree/compiler/Codegen/LLVMCPU/test/materialize_launch_configuration.mlir
+++ b/iree/compiler/Codegen/LLVMCPU/test/materialize_launch_configuration.mlir
@@ -718,7 +718,7 @@
   }
 }
 
-//   CHECK-DAG: #[[CONFIG:.+]] = #iree_codegen.lowering_config<tile_sizes = {{\[}}[28, 8, 0], [4, 4, 60], [4, 4, 4]{{\]}}>
+//   CHECK-DAG: #[[CONFIG:.+]] = #iree_codegen.lowering_config<tile_sizes = {{\[}}[49, 8, 0], [7, 4, 60], [4, 4, 4]{{\]}}>
 //   CHECK-DAG: #[[TRANSLATION:.+]] = #iree_codegen.translation_info<CPUTileFuseAndVectorize>
 //       CHECK: hal.executable.entry_point public @matmul_static
 //  CHECK-SAME:     translation_info = #[[TRANSLATION]]
@@ -936,7 +936,7 @@
     #hal.descriptor_set.binding<2, storage_buffer>
   ]>
 ]>
-hal.executable private @matmul_i8_i8_i32  {
+hal.executable private @matmul_x86_i8_i8_i32  {
   hal.executable.variant public @embedded_elf_x86_64, target = #hal.executable.target<
     "llvm",
     "embedded-elf-x86_64", {
@@ -944,9 +944,9 @@
       native_vector_size = 4 : index,
       target_triple = "x86_64-unknown-unknown-eabi-elf"
     }> {
-    hal.executable.entry_point public @matmul_i8_i8_i32 layout(#executable_layout)
+    hal.executable.entry_point public @matmul_x86_i8_i8_i32 layout(#executable_layout)
     builtin.module {
-      func @matmul_i8_i8_i32() {
+      func @matmul_x86_i8_i8_i32() {
         %c0 = arith.constant 0 : index
         %M = hal.interface.constant.load[0] : index
         %N = hal.interface.constant.load[1] : index
@@ -974,7 +974,57 @@
 
 //  CHECK-DAG: #[[CONFIG:.+]] = #iree_codegen.lowering_config<tile_sizes = {{\[}}[64, 64, 0], [8, 32, 0], [0, 0, 16]{{\]}}>
 //  CHECK-DAG: #[[TRANSLATION:.+]] = #iree_codegen.translation_info<CPUDoubleTilingExpert>
-//      CHECK: hal.executable.entry_point public @matmul_i8_i8_i32
+//      CHECK: hal.executable.entry_point public @matmul_x86_i8_i8_i32
+// CHECK-SAME:     translation_info = #[[TRANSLATION]]
+//      CHECK:   linalg.matmul
+// CHECK-SAME:       lowering_config = #[[CONFIG]]
+
+// -----
+
+#executable_layout = #hal.executable.layout<push_constants = 0, sets = [
+  #hal.descriptor_set.layout<0, bindings = [
+    #hal.descriptor_set.binding<0, storage_buffer>,
+    #hal.descriptor_set.binding<1, storage_buffer>,
+    #hal.descriptor_set.binding<2, storage_buffer>
+  ]>
+]>
+hal.executable private @matmul_aarch_i8_i8_i32  {
+  hal.executable.variant public @system_elf_arm_64, target = <"llvm", "system-elf-arm_64", {
+    data_layout = "e-m:e-i8:8:32-i16:16:32-i64:64-i128:128-n32:64-S128",
+    native_vector_size = 16 : index,
+    target_triple = "aarch64-none-linux-android30"
+  }> {
+    hal.executable.entry_point public @matmul_aarch_i8_i8_i32 layout(#executable_layout)
+    builtin.module {
+      func @matmul_aarch_i8_i8_i32() {
+        %c0 = arith.constant 0 : index
+        %M = hal.interface.constant.load[0] : index
+        %N = hal.interface.constant.load[1] : index
+        %K = hal.interface.constant.load[2] : index
+        %lhs_binding = hal.interface.binding.subspan set(0) binding(0) type(storage_buffer) offset(%c0) alignment(32)
+            : !flow.dispatch.tensor<readonly:?x?xi8>{%M, %K}
+        %rhs_binding = hal.interface.binding.subspan set(0) binding(1) type(storage_buffer) offset(%c0) alignment(32)
+            : !flow.dispatch.tensor<readonly:?x?xi8>{%K, %N}
+        %result_binding = hal.interface.binding.subspan set(0) binding(2) type(storage_buffer) offset(%c0) alignment(32)
+            : !flow.dispatch.tensor<readwrite:?x?xi32>{%M, %N}
+        %lhs = flow.dispatch.tensor.load %lhs_binding, offsets = [0, 0], sizes = [%M, %K], strides = [1, 1]
+            : !flow.dispatch.tensor<readonly:?x?xi8>{%M, %K} -> tensor<?x?xi8>
+        %rhs = flow.dispatch.tensor.load %rhs_binding, offsets = [0, 0], sizes = [%K, %N], strides = [1, 1]
+            : !flow.dispatch.tensor<readonly:?x?xi8>{%K, %N} -> tensor<?x?xi8>
+        %init = flow.dispatch.tensor.load %result_binding, offsets = [0, 0], sizes = [%M, %N], strides = [1, 1]
+            : !flow.dispatch.tensor<readwrite:?x?xi32>{%M, %N} -> tensor<?x?xi32>
+        %gemm = linalg.matmul ins(%lhs, %rhs : tensor<?x?xi8>, tensor<?x?xi8>) outs(%init : tensor<?x?xi32>) -> tensor<?x?xi32>
+        flow.dispatch.tensor.store %gemm, %result_binding, offsets = [0, 0], sizes = [%M, %N], strides = [1, 1]
+            : tensor<?x?xi32> -> !flow.dispatch.tensor<readwrite:?x?xi32>{%M, %N}
+        return
+      }
+    }
+  }
+}
+
+//  CHECK-DAG: #[[CONFIG:.+]] = #iree_codegen.lowering_config<tile_sizes = {{\[}}[64, 64, 0], [16, 4, 64], [4, 4, 4]]>
+//  CHECK-DAG: #[[TRANSLATION:.+]] = #iree_codegen.translation_info<CPUTileFuseAndVectorize>
+//      CHECK: hal.executable.entry_point public @matmul_aarch_i8_i8_i32
 // CHECK-SAME:     translation_info = #[[TRANSLATION]]
 //      CHECK:   linalg.matmul
 // CHECK-SAME:       lowering_config = #[[CONFIG]]
@@ -1118,7 +1168,7 @@
     }
   }
 }
-//  CHECK-DAG: #[[CONFIG:.+]] = #iree_codegen.lowering_config<tile_sizes = {{\[}}[3, 7, 0], [3, 7, 0], [0, 0, 16]]>
+//  CHECK-DAG: #[[CONFIG:.+]] = #iree_codegen.lowering_config<tile_sizes = {{\[}}[11, 7, 0], [1, 7, 0], [0, 0, 16]]>
 //  CHECK-DAG: #[[TRANSLATION:.+]] = #iree_codegen.translation_info<CPUDoubleTilingExpert>
 //      CHECK: hal.executable.entry_point public @matmul_odd
 // CHECK-SAME:       translation_info = #[[TRANSLATION]]