Select matmul tile sizes based on static shape information (#5653)
For a given matmul shape this selects the largest integer multiple of the vector size that fits in the workgroup or L1 sizes.
diff --git a/iree/compiler/Conversion/LinalgToLLVM/BUILD b/iree/compiler/Conversion/LinalgToLLVM/BUILD
index ffe19c3..8480630 100644
--- a/iree/compiler/Conversion/LinalgToLLVM/BUILD
+++ b/iree/compiler/Conversion/LinalgToLLVM/BUILD
@@ -63,6 +63,7 @@
"//iree/compiler/Conversion/HLOToHLO",
"//iree/compiler/Conversion/LinalgToLLVM:FoldTensorExtractOpIncGen",
"//iree/compiler/Conversion/VectorToLLVM",
+ "//iree/compiler/Dialect/Flow/IR",
"//iree/compiler/Dialect/HAL/IR",
"//iree/compiler/Dialect/HAL/IR:HALDialect",
"//iree/compiler/Dialect/IREE/IR",
diff --git a/iree/compiler/Conversion/LinalgToLLVM/CMakeLists.txt b/iree/compiler/Conversion/LinalgToLLVM/CMakeLists.txt
index ee47ddb..3118cab 100644
--- a/iree/compiler/Conversion/LinalgToLLVM/CMakeLists.txt
+++ b/iree/compiler/Conversion/LinalgToLLVM/CMakeLists.txt
@@ -69,6 +69,7 @@
iree::compiler::Conversion::HLOToHLO
iree::compiler::Conversion::LinalgToLLVM::FoldTensorExtractOpIncGen
iree::compiler::Conversion::VectorToLLVM
+ iree::compiler::Dialect::Flow::IR
iree::compiler::Dialect::HAL::IR
iree::compiler::Dialect::HAL::IR::HALDialect
iree::compiler::Dialect::IREE::IR
diff --git a/iree/compiler/Conversion/LinalgToLLVM/KernelDispatch.cpp b/iree/compiler/Conversion/LinalgToLLVM/KernelDispatch.cpp
index 37d9421..fef25c2 100644
--- a/iree/compiler/Conversion/LinalgToLLVM/KernelDispatch.cpp
+++ b/iree/compiler/Conversion/LinalgToLLVM/KernelDispatch.cpp
@@ -15,9 +15,11 @@
#include "iree/compiler/Conversion/LinalgToLLVM/KernelDispatch.h"
+#include "iree/compiler/Dialect/Flow/IR/FlowOps.h"
#include "llvm/Support/CommandLine.h"
#include "mlir/Dialect/Linalg/IR/LinalgInterfaces.h"
#include "mlir/Dialect/Linalg/IR/LinalgOps.h"
+#include "mlir/Dialect/MemRef/IR/MemRef.h"
#include "mlir/Dialect/StandardOps/IR/Ops.h"
#include "mlir/IR/Operation.h"
@@ -42,7 +44,7 @@
llvm::cl::desc(
"linalg.matmul tile size for L1 spliting of M, N, K dimension"),
llvm::cl::init(32));
-static llvm::cl::opt<int> matmulL2TileSize(
+static llvm::cl::opt<int> matmulVectorSize(
"iree-codegen-llvm-matmul-vector-size",
llvm::cl::desc("linalg.matmul vector tile size"), llvm::cl::init(4));
@@ -71,15 +73,69 @@
llvm::SmallVector<int64_t, 4> getTileSizes(Operation *op) {
if (auto contractionOp = dyn_cast<linalg::ContractionOpInterface>(op)) {
if (contractionOp.isRowMajorMatmul()) {
+ int mWorkgroupSize = matmulWorkgroupTileSize;
+ int nWorkgroupSize = matmulWorkgroupTileSize;
+ int mL1TileSize = matmulL1TileSize;
+ int nL1TileSize = matmulL1TileSize;
+ int kL1TileSize = matmulL1TileSize;
+ if (auto matmulOp = dyn_cast<linalg::MatmulOp>(op)) {
+ // Returns the original problem size before tiling.
+ auto getOriginalOperandShape = [](Value operand) {
+ if (auto dispatchLoadOp =
+ operand.getDefiningOp<IREE::Flow::DispatchTensorLoadOp>()) {
+ return dispatchLoadOp.source()
+ .getType()
+ .cast<IREE::Flow::DispatchTensorType>()
+ .getShape();
+ }
+ if (auto operandParent = operand.getDefiningOp<memref::SubViewOp>()) {
+ return operandParent.source()
+ .getType()
+ .cast<ShapedType>()
+ .getShape();
+ }
+ if (auto operandParent = operand.getDefiningOp<SubTensorOp>()) {
+ return operandParent.source()
+ .getType()
+ .cast<ShapedType>()
+ .getShape();
+ }
+ if (auto operandParent = operand.getDefiningOp<memref::AllocaOp>()) {
+ return operandParent.getType().cast<ShapedType>().getShape();
+ }
+ return ArrayRef<int64_t>{};
+ };
+
+ auto lhsShape = getOriginalOperandShape(matmulOp.inputs()[0]);
+ auto rhsShape = getOriginalOperandShape(matmulOp.inputs()[1]);
+
+ if (!lhsShape.empty() && !rhsShape.empty()) {
+ // Find largest tile size that is a multiple of the vector size.
+ auto getTileSize = [](int dim, int maxSize) {
+ for (int i = std::min(maxSize, dim); i > 0; --i) {
+ if (dim % i == 0 && i % matmulVectorSize == 0) {
+ return i;
+ }
+ }
+ return maxSize;
+ };
+ mWorkgroupSize = getTileSize(lhsShape[0], mWorkgroupSize);
+ nWorkgroupSize = getTileSize(rhsShape[1], nWorkgroupSize);
+ mL1TileSize = getTileSize(mWorkgroupSize, mL1TileSize);
+ nL1TileSize = getTileSize(nWorkgroupSize, nL1TileSize);
+ kL1TileSize = getTileSize(rhsShape[0], kL1TileSize);
+ }
+ }
+
switch (tilingLevel) {
case TilingLevel::WorkGroupTiles: {
- return {matmulWorkgroupTileSize, matmulWorkgroupTileSize};
+ return {mWorkgroupSize, nWorkgroupSize};
}
case TilingLevel::Level1Tiles: {
- return {matmulL1TileSize, matmulL1TileSize, matmulL1TileSize};
+ return {mL1TileSize, nL1TileSize, kL1TileSize};
}
case TilingLevel::Level2Tiles: {
- return {matmulL2TileSize, matmulL2TileSize, matmulL2TileSize};
+ return {matmulVectorSize, matmulVectorSize, matmulVectorSize};
}
}
}