Add several tile sizes and add separate tile size for small matrices. (#5050)

This improves the performance of MobileBERT on Mali.
diff --git a/iree/compiler/Conversion/LinalgToSPIRV/KernelDispatchUtils.cpp b/iree/compiler/Conversion/LinalgToSPIRV/KernelDispatchUtils.cpp
index 6ce6ae2..0ef6877 100644
--- a/iree/compiler/Conversion/LinalgToSPIRV/KernelDispatchUtils.cpp
+++ b/iree/compiler/Conversion/LinalgToSPIRV/KernelDispatchUtils.cpp
@@ -70,6 +70,12 @@
   std::array<int64_t, 3> numSubgroups = {1, 1, 1};
   bool vectorize = false;
 };
+
+struct TileWorkgroupSizePair {
+  // How many scalar elements each workgroup should handle along each dimension.
+  std::array<int64_t, 3> tileSize;
+  std::array<int64_t, 3> workgroupSize;
+};
 }  // namespace
 
 /// For a given operation `op`, compute the following configurations according
@@ -86,12 +92,24 @@
   return op.emitError("undefined launch config for tiled operation");
 }
 
-static void getMaliBestMatMulTileSizes(Type elementType,
-                                       SmallVectorImpl<int64_t> &tileSizes) {
+static void getMaliBestMatMulTileSizes(
+    Type elementType, SmallVectorImpl<TileWorkgroupSizePair> &tileSizes,
+    int64_t dstSize) {
+  const int64_t smallMatrixSizeThreshold = 512 * 512;
   if (elementType.isF16()) {
-    tileSizes.append({16, 64, 8});
+    // When the destination is smaller than the threshold, we prefer smaller
+    // tiles to increase parallelism.
+    // TODO: The threshold needs to be fine tuned by doing exploration based on
+    // matrix shapes.
+    if (dstSize <= smallMatrixSizeThreshold) {
+      tileSizes.push_back(TileWorkgroupSizePair({{16, 32, 8}, {8, 2, 1}}));
+    } else {
+      tileSizes.push_back(TileWorkgroupSizePair({{16, 64, 4}, {8, 2, 1}}));
+      tileSizes.push_back(TileWorkgroupSizePair({{8, 128, 4}, {8, 2, 1}}));
+      tileSizes.push_back(TileWorkgroupSizePair({{16, 32, 4}, {8, 2, 1}}));
+    }
   } else {
-    tileSizes.append({8, 64, 4});
+    tileSizes.push_back(TileWorkgroupSizePair({{8, 64, 4}, {16, 1, 1}}));
   }
 }
 
@@ -106,29 +124,34 @@
   auto lhsType = op.inputs()[0].getType().cast<MemRefType>();
   auto rhsType = op.inputs()[1].getType().cast<MemRefType>();
   assert(lhsType.getElementType() == rhsType.getElementType());
-  // Pick ideal tile size based on the type.
-  SmallVector<int64_t, 4> workgroupLevelTs(1, 1);
-  getMaliBestMatMulTileSizes(lhsType.getElementType(), workgroupLevelTs);
-  // Fall back to the none vectorize path for cases we don't handle.
-  if (!lhsType.hasStaticShape() || !rhsType.hasStaticShape() ||
-      lhsType.getDimSize(1) % workgroupLevelTs[1] != 0 ||
-      rhsType.getDimSize(2) % workgroupLevelTs[2] != 0 ||
-      lhsType.getDimSize(2) % workgroupLevelTs[3] != 0) {
-    return failure();
-  }
+  if (!lhsType.hasStaticShape() || !rhsType.hasStaticShape()) return failure();
+  // Get the candidate tile sizes, ordered from best to worst.
+  SmallVector<TileWorkgroupSizePair, 4> workgroupLevelTs;
+  int64_t dstSize =
+      lhsType.getDimSize(0) * lhsType.getDimSize(1) * rhsType.getDimSize(2);
+  getMaliBestMatMulTileSizes(lhsType.getElementType(), workgroupLevelTs,
+                             dstSize);
+  for (TileWorkgroupSizePair pair : workgroupLevelTs) {
+    if (lhsType.getDimSize(1) % pair.tileSize[0] != 0 ||
+        rhsType.getDimSize(2) % pair.tileSize[1] != 0 ||
+        lhsType.getDimSize(2) % pair.tileSize[2] != 0) {
+      continue;
+    }
 
-  workgroupSize[0] = targetEnv.getResourceLimits().subgroup_size().getInt();
-  workgroupSize[1] = 1;
-  workgroupSize[2] = 1;
-  tileSizes.emplace_back(workgroupLevelTs);
-  // No tiling at the subgroup level since this target doesn't use subgroup op
-  // or shared memory.
-  tileSizes.emplace_back();
-  SmallVector<int64_t, 4> invocationLevelTs = {
-      workgroupLevelTs[0], workgroupLevelTs[1],
-      workgroupLevelTs[2] / workgroupSize[0], workgroupLevelTs[3]};
-  tileSizes.emplace_back(invocationLevelTs);
-  return success();
+    workgroupSize = pair.workgroupSize;
+    SmallVector<int64_t, 4> batchTs;
+    batchTs.append({1, pair.tileSize[0], pair.tileSize[1], pair.tileSize[2]});
+    tileSizes.emplace_back(batchTs);
+    // No tiling at the subgroup level since this target doesn't use subgroup op
+    // or shared memory.
+    tileSizes.emplace_back();
+    SmallVector<int64_t, 4> invocationLevelTs = {
+        batchTs[0], batchTs[1] / workgroupSize[1],
+        batchTs[2] / workgroupSize[0], batchTs[3]};
+    tileSizes.emplace_back(invocationLevelTs);
+    return success();
+  }
+  return failure();
 }
 
 /// Launch config for `linalg.batchmatmul`.
@@ -272,30 +295,34 @@
   auto lhsType = op.inputs()[0].getType().cast<MemRefType>();
   auto rhsType = op.inputs()[1].getType().cast<MemRefType>();
   assert(lhsType.getElementType() == rhsType.getElementType());
+  // If the shape is not static, fall back to the non-vectorized path.
+  if (!lhsType.hasStaticShape() || !rhsType.hasStaticShape()) return failure();
   // Pick ideal tile size based on the type.
-  SmallVector<int64_t, 4> workgroupLevelTs;
-  getMaliBestMatMulTileSizes(lhsType.getElementType(), workgroupLevelTs);
+  SmallVector<TileWorkgroupSizePair, 4> workgroupLevelTs;
+  int64_t dstSize = lhsType.getDimSize(0) * rhsType.getDimSize(1);
+  getMaliBestMatMulTileSizes(lhsType.getElementType(), workgroupLevelTs,
+                             dstSize);
+  for (TileWorkgroupSizePair pair : workgroupLevelTs) {
+    if (lhsType.getDimSize(0) % pair.tileSize[0] != 0 ||
+        rhsType.getDimSize(1) % pair.tileSize[1] != 0 ||
+        lhsType.getDimSize(1) % pair.tileSize[2] != 0) {
+      continue;
+    }
 
-  // Fall back to the none vectorize path for cases we don't handle.
-  if (!lhsType.hasStaticShape() || !rhsType.hasStaticShape() ||
-      lhsType.getDimSize(0) % workgroupLevelTs[0] != 0 ||
-      rhsType.getDimSize(1) % workgroupLevelTs[1] != 0 ||
-      lhsType.getDimSize(1) % workgroupLevelTs[2] != 0) {
-    return failure();
+    workgroupSize = pair.workgroupSize;
+    SmallVector<int64_t, 4> matmulTS(pair.tileSize.begin(),
+                                     pair.tileSize.end());
+    tileSizes.emplace_back(matmulTS);
+    // No tiling at the subgroup level since this target doesn't use subgroup op
+    // or shared memory.
+    tileSizes.emplace_back();
+    SmallVector<int64_t, 4> invocationLevelTs = {matmulTS[0] / workgroupSize[1],
+                                                 matmulTS[1] / workgroupSize[0],
+                                                 matmulTS[2]};
+    tileSizes.emplace_back(invocationLevelTs);
+    return success();
   }
-
-  workgroupSize[0] = targetEnv.getResourceLimits().subgroup_size().getInt();
-  workgroupSize[1] = 1;
-  workgroupSize[2] = 1;
-  tileSizes.emplace_back(workgroupLevelTs);
-  // No tiling at the subgroup level since this target doesn't use subgroup op
-  // or shared memory.
-  tileSizes.emplace_back();
-  SmallVector<int64_t, 4> invocationLevelTs = {
-      workgroupLevelTs[0], workgroupLevelTs[1] / workgroupSize[0],
-      workgroupLevelTs[2]};
-  tileSizes.emplace_back(invocationLevelTs);
-  return success();
+  return failure();
 }
 
 template <>
@@ -344,12 +371,6 @@
   return success();
 }
 
-struct TileWorkgroupSizePair {
-  // How many scalar elements each workgroup should handle along each dimension.
-  std::array<int64_t, 3> tileSize;
-  std::array<int64_t, 3> workgroupSize;
-};
-
 template <typename ConvOpTy>
 static LogicalResult getMaliSpecificConfig(ConvOpTy op,
                                            TileSizesListType &tileSizes,
diff --git a/iree/compiler/Conversion/LinalgToSPIRV/test/matmul_fused_vectorization.mlir b/iree/compiler/Conversion/LinalgToSPIRV/test/matmul_fused_vectorization.mlir
index 870461f..88251f4 100644
--- a/iree/compiler/Conversion/LinalgToSPIRV/test/matmul_fused_vectorization.mlir
+++ b/iree/compiler/Conversion/LinalgToSPIRV/test/matmul_fused_vectorization.mlir
@@ -120,7 +120,7 @@
 //  CHECK-COUNT-16:   vector.transfer_write
 //  CHECK-COUNT-16:   vector.transfer_read
 //          CHECK:   %[[FOR_RES:.+]]:16 = scf.for
-// CHECK-COUNT-40:     vector.transfer_read
+// CHECK-COUNT-16:     vector.transfer_read
 // CHECK-COUNT-64:     vector.contract
 //      CHECK:         scf.yield
 //  CHECK-COUNT-16:    vector.transfer_write %[[FOR_RES]]