Add several tile sizes and a separate tile size for small matrices. (#5050)
This improves the performance of MobileBert on Mali.
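The change replaces the single hard-coded Mali matmul tile size with a list of
TileWorkgroupSizePair candidates that are tried from best to worst; the first
candidate whose tile sizes evenly divide the static matmul dimensions wins, and
if none fits we fall back to the non-vectorized path. Below is a minimal
standalone sketch of that first-fit selection; pickTileSize and the m/n/k
parameters are illustrative stand-ins, not IREE APIs:

#include <array>
#include <cstdint>
#include <optional>
#include <vector>

// Mirrors the pair introduced in the diff: a workgroup-level tile size plus
// the workgroup (local) size used to distribute it across invocations.
struct TileWorkgroupSizePair {
  std::array<int64_t, 3> tileSize;       // {M, N, K} elements per workgroup.
  std::array<int64_t, 3> workgroupSize;  // {x, y, z} invocations.
};

// First-fit selection: keep the first candidate whose tile sizes evenly
// divide the matmul dimensions; std::nullopt means "take the non-vectorized
// path", matching the failure() fallback in the patch.
std::optional<TileWorkgroupSizePair> pickTileSize(
    const std::vector<TileWorkgroupSizePair> &candidates, int64_t m, int64_t n,
    int64_t k) {
  for (const TileWorkgroupSizePair &pair : candidates) {
    if (m % pair.tileSize[0] == 0 && n % pair.tileSize[1] == 0 &&
        k % pair.tileSize[2] == 0)
      return pair;
  }
  return std::nullopt;
}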
diff --git a/iree/compiler/Conversion/LinalgToSPIRV/KernelDispatchUtils.cpp b/iree/compiler/Conversion/LinalgToSPIRV/KernelDispatchUtils.cpp
index 6ce6ae2..0ef6877 100644
--- a/iree/compiler/Conversion/LinalgToSPIRV/KernelDispatchUtils.cpp
+++ b/iree/compiler/Conversion/LinalgToSPIRV/KernelDispatchUtils.cpp
@@ -70,6 +70,12 @@
std::array<int64_t, 3> numSubgroups = {1, 1, 1};
bool vectorize = false;
};
+
+struct TileWorkgroupSizePair {
+ // How many scalar elements each workgroup should handle along each dimension.
+ std::array<int64_t, 3> tileSize;
+ std::array<int64_t, 3> workgroupSize;
+};
} // namespace
/// For a given operation `op`, compute the following configurations according
@@ -86,12 +92,24 @@
return op.emitError("undefined launch config for tiled operation");
}
-static void getMaliBestMatMulTileSizes(Type elementType,
- SmallVectorImpl<int64_t> &tileSizes) {
+static void getMaliBestMatMulTileSizes(
+ Type elementType, SmallVectorImpl<TileWorkgroupSizePair> &tileSizes,
+ int64_t dstSize) {
+ const int64_t smallMatrixSizeThreshold = 512 * 512;
if (elementType.isF16()) {
- tileSizes.append({16, 64, 8});
+ // When the destination is at or below the threshold, prefer smaller tiles
+ // to increase parallelism.
+ // TODO: The threshold needs to be fine-tuned by exploring different matrix
+ // shapes.
+ if (dstSize <= smallMatrixSizeThreshold) {
+ tileSizes.push_back(TileWorkgroupSizePair({{16, 32, 8}, {8, 2, 1}}));
+ } else {
+ tileSizes.push_back(TileWorkgroupSizePair({{16, 64, 4}, {8, 2, 1}}));
+ tileSizes.push_back(TileWorkgroupSizePair({{8, 128, 4}, {8, 2, 1}}));
+ tileSizes.push_back(TileWorkgroupSizePair({{16, 32, 4}, {8, 2, 1}}));
+ }
} else {
- tileSizes.append({8, 64, 4});
+ tileSizes.push_back(TileWorkgroupSizePair({{8, 64, 4}, {16, 1, 1}}));
}
}
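For intuition on the threshold above: dstSize is the element count of the
result (M x N for matmul, B x M x N for batch matmul, as computed in the hunks
below), and 512 * 512 = 262,144. The destination shapes in this check are
hypothetical examples, not taken from the patch:

#include <cassert>
#include <cstdint>

int main() {
  const int64_t smallMatrixSizeThreshold = 512 * 512;  // 262,144 elements.
  // Hypothetical f16 matmul with a 256x512 destination: 131,072 elements,
  // at or below the threshold, so the single {16, 32, 8} candidate is used.
  assert(int64_t(256) * 512 <= smallMatrixSizeThreshold);
  // Hypothetical 1024x1024 destination: 1,048,576 elements, above the
  // threshold, so the three larger candidates are tried in order.
  assert(int64_t(1024) * 1024 > smallMatrixSizeThreshold);
  return 0;
}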
@@ -106,29 +124,34 @@
auto lhsType = op.inputs()[0].getType().cast<MemRefType>();
auto rhsType = op.inputs()[1].getType().cast<MemRefType>();
assert(lhsType.getElementType() == rhsType.getElementType());
- // Pick ideal tile size based on the type.
- SmallVector<int64_t, 4> workgroupLevelTs(1, 1);
- getMaliBestMatMulTileSizes(lhsType.getElementType(), workgroupLevelTs);
- // Fall back to the none vectorize path for cases we don't handle.
- if (!lhsType.hasStaticShape() || !rhsType.hasStaticShape() ||
- lhsType.getDimSize(1) % workgroupLevelTs[1] != 0 ||
- rhsType.getDimSize(2) % workgroupLevelTs[2] != 0 ||
- lhsType.getDimSize(2) % workgroupLevelTs[3] != 0) {
- return failure();
- }
+ if (!lhsType.hasStaticShape() || !rhsType.hasStaticShape()) return failure();
+ // Get a vector of candidate tile sizes, ordered from best to worst.
+ SmallVector<TileWorkgroupSizePair, 4> workgroupLevelTs;
+ int64_t dstSize =
+ lhsType.getDimSize(0) * lhsType.getDimSize(1) * rhsType.getDimSize(2);
+ getMaliBestMatMulTileSizes(lhsType.getElementType(), workgroupLevelTs,
+ dstSize);
+ for (TileWorkgroupSizePair pair : workgroupLevelTs) {
+ if (lhsType.getDimSize(1) % pair.tileSize[0] != 0 ||
+ rhsType.getDimSize(2) % pair.tileSize[1] != 0 ||
+ lhsType.getDimSize(2) % pair.tileSize[2] != 0) {
+ continue;
+ }
- workgroupSize[0] = targetEnv.getResourceLimits().subgroup_size().getInt();
- workgroupSize[1] = 1;
- workgroupSize[2] = 1;
- tileSizes.emplace_back(workgroupLevelTs);
- // No tiling at the subgroup level since this target doesn't use subgroup op
- // or shared memory.
- tileSizes.emplace_back();
- SmallVector<int64_t, 4> invocationLevelTs = {
- workgroupLevelTs[0], workgroupLevelTs[1],
- workgroupLevelTs[2] / workgroupSize[0], workgroupLevelTs[3]};
- tileSizes.emplace_back(invocationLevelTs);
- return success();
+ workgroupSize = pair.workgroupSize;
+ SmallVector<int64_t, 4> batchTs;
+ batchTs.append({1, pair.tileSize[0], pair.tileSize[1], pair.tileSize[2]});
+ tileSizes.emplace_back(batchTs);
+ // No tiling at the subgroup level since this target doesn't use subgroup
+ // ops or shared memory.
+ tileSizes.emplace_back();
+ SmallVector<int64_t, 4> invocationLevelTs = {
+ batchTs[0], batchTs[1] / workgroupSize[1],
+ batchTs[2] / workgroupSize[0], batchTs[3]};
+ tileSizes.emplace_back(invocationLevelTs);
+ return success();
+ }
+ return failure();
}
/// Launch config for `linalg.batchmatmul`.
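The hunk above emits three tiling levels for batch matmul: a workgroup level
that prepends a batch tile of 1 to the {M, N, K} tile, an empty subgroup level,
and an invocation level that splits the M tile across the workgroup's y
dimension and the N tile across its x dimension. A sketch of that arithmetic,
assuming the f16 large-matrix candidate {16, 64, 4} with workgroup size
{8, 2, 1} from the patch:

#include <array>
#include <cstdint>
#include <vector>

int main() {
  std::array<int64_t, 3> tileSize = {16, 64, 4};     // {M, N, K} per workgroup.
  std::array<int64_t, 3> workgroupSize = {8, 2, 1};  // {x, y, z} invocations.

  // Workgroup level: the batch dim is tiled to 1, then the {M, N, K} tile.
  std::vector<int64_t> batchTs = {1, tileSize[0], tileSize[1], tileSize[2]};

  // Invocation level: M is split across the y dimension and N across the x
  // dimension of the workgroup; the K tile stays whole within an invocation.
  std::vector<int64_t> invocationTs = {
      batchTs[0],                     // 1 batch element.
      batchTs[1] / workgroupSize[1],  // 16 / 2 = 8 rows per invocation.
      batchTs[2] / workgroupSize[0],  // 64 / 8 = 8 cols per invocation.
      batchTs[3]};                    // full K tile of 4.
  return 0;
}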
@@ -272,30 +295,34 @@
auto lhsType = op.inputs()[0].getType().cast<MemRefType>();
auto rhsType = op.inputs()[1].getType().cast<MemRefType>();
assert(lhsType.getElementType() == rhsType.getElementType());
+ // If the shape is not fully static, fall back to the non-vectorized path.
+ if (!lhsType.hasStaticShape() || !rhsType.hasStaticShape()) return failure();
// Pick ideal tile size based on the type.
- SmallVector<int64_t, 4> workgroupLevelTs;
- getMaliBestMatMulTileSizes(lhsType.getElementType(), workgroupLevelTs);
+ SmallVector<TileWorkgroupSizePair, 4> workgroupLevelTs;
+ int64_t dstSize = lhsType.getDimSize(0) * rhsType.getDimSize(1);
+ getMaliBestMatMulTileSizes(lhsType.getElementType(), workgroupLevelTs,
+ dstSize);
+ for (TileWorkgroupSizePair pair : workgroupLevelTs) {
+ if (lhsType.getDimSize(0) % pair.tileSize[0] != 0 ||
+ rhsType.getDimSize(1) % pair.tileSize[1] != 0 ||
+ lhsType.getDimSize(1) % pair.tileSize[2] != 0) {
+ continue;
+ }
- // Fall back to the none vectorize path for cases we don't handle.
- if (!lhsType.hasStaticShape() || !rhsType.hasStaticShape() ||
- lhsType.getDimSize(0) % workgroupLevelTs[0] != 0 ||
- rhsType.getDimSize(1) % workgroupLevelTs[1] != 0 ||
- lhsType.getDimSize(1) % workgroupLevelTs[2] != 0) {
- return failure();
+ workgroupSize = pair.workgroupSize;
+ SmallVector<int64_t, 4> matmulTS(pair.tileSize.begin(),
+ pair.tileSize.end());
+ tileSizes.emplace_back(matmulTS);
+ // No tiling at the subgroup level since this target doesn't use subgroup
+ // ops or shared memory.
+ tileSizes.emplace_back();
+ SmallVector<int64_t, 4> invocationLevelTs = {matmulTS[0] / workgroupSize[1],
+ matmulTS[1] / workgroupSize[0],
+ matmulTS[2]};
+ tileSizes.emplace_back(invocationLevelTs);
+ return success();
}
-
- workgroupSize[0] = targetEnv.getResourceLimits().subgroup_size().getInt();
- workgroupSize[1] = 1;
- workgroupSize[2] = 1;
- tileSizes.emplace_back(workgroupLevelTs);
- // No tiling at the subgroup level since this target doesn't use subgroup op
- // or shared memory.
- tileSizes.emplace_back();
- SmallVector<int64_t, 4> invocationLevelTs = {
- workgroupLevelTs[0], workgroupLevelTs[1] / workgroupSize[0],
- workgroupLevelTs[2]};
- tileSizes.emplace_back(invocationLevelTs);
- return success();
+ return failure();
}
template <>
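The plain matmul path derives its invocation-level tiles the same way, just
without the batch dimension (e.g. {16, 32, 8} with workgroup size {8, 2, 1}
yields {8, 4, 8} per invocation). To see the first-fit loop actually skip
candidates, walk the f16 large-matrix list by hand for a hypothetical
1024x96x1024 matmul:

#include <cassert>
#include <cstdint>

int main() {
  int64_t m = 1024, n = 96, k = 1024;
  // N = 96 rules out {16, 64, 4} (96 % 64 != 0) and {8, 128, 4} (96 % 128
  // != 0) ...
  assert(n % 64 != 0 && n % 128 != 0);
  // ... so the loop falls through to {16, 32, 4}, which divides all dims.
  assert(m % 16 == 0 && n % 32 == 0 && k % 4 == 0);
  return 0;
}

The hunk below then deletes the old local TileWorkgroupSizePair definition from
the convolution path: the struct now lives in the anonymous namespace at the
top of the file and is shared by the matmul, batch-matmul, and convolution
configurations.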
@@ -344,12 +371,6 @@
return success();
}
-struct TileWorkgroupSizePair {
- // How many scalar elements each workgroup should handle along each dimension.
- std::array<int64_t, 3> tileSize;
- std::array<int64_t, 3> workgroupSize;
-};
-
template <typename ConvOpTy>
static LogicalResult getMaliSpecificConfig(ConvOpTy op,
TileSizesListType &tileSizes,
diff --git a/iree/compiler/Conversion/LinalgToSPIRV/test/matmul_fused_vectorization.mlir b/iree/compiler/Conversion/LinalgToSPIRV/test/matmul_fused_vectorization.mlir
index 870461f..88251f4 100644
--- a/iree/compiler/Conversion/LinalgToSPIRV/test/matmul_fused_vectorization.mlir
+++ b/iree/compiler/Conversion/LinalgToSPIRV/test/matmul_fused_vectorization.mlir
@@ -120,7 +120,7 @@
// CHECK-COUNT-16: vector.transfer_write
// CHECK-COUNT-16: vector.transfer_read
// CHECK: %[[FOR_RES:.+]]:16 = scf.for
-// CHECK-COUNT-40: vector.transfer_read
+// CHECK-COUNT-16: vector.transfer_read
// CHECK-COUNT-64: vector.contract
// CHECK: scf.yield
// CHECK-COUNT-16: vector.transfer_write %[[FOR_RES]]