Select matmul tile sizes based on static shape information (#5653)

For a given matmul shape this selects the largest integer multiple of the vector size that fits in the workgroup or L1 sizes.
diff --git a/iree/compiler/Conversion/LinalgToLLVM/BUILD b/iree/compiler/Conversion/LinalgToLLVM/BUILD
index ffe19c3..8480630 100644
--- a/iree/compiler/Conversion/LinalgToLLVM/BUILD
+++ b/iree/compiler/Conversion/LinalgToLLVM/BUILD
@@ -63,6 +63,7 @@
         "//iree/compiler/Conversion/HLOToHLO",
         "//iree/compiler/Conversion/LinalgToLLVM:FoldTensorExtractOpIncGen",
         "//iree/compiler/Conversion/VectorToLLVM",
+        "//iree/compiler/Dialect/Flow/IR",
         "//iree/compiler/Dialect/HAL/IR",
         "//iree/compiler/Dialect/HAL/IR:HALDialect",
         "//iree/compiler/Dialect/IREE/IR",
diff --git a/iree/compiler/Conversion/LinalgToLLVM/CMakeLists.txt b/iree/compiler/Conversion/LinalgToLLVM/CMakeLists.txt
index ee47ddb..3118cab 100644
--- a/iree/compiler/Conversion/LinalgToLLVM/CMakeLists.txt
+++ b/iree/compiler/Conversion/LinalgToLLVM/CMakeLists.txt
@@ -69,6 +69,7 @@
     iree::compiler::Conversion::HLOToHLO
     iree::compiler::Conversion::LinalgToLLVM::FoldTensorExtractOpIncGen
     iree::compiler::Conversion::VectorToLLVM
+    iree::compiler::Dialect::Flow::IR
     iree::compiler::Dialect::HAL::IR
     iree::compiler::Dialect::HAL::IR::HALDialect
     iree::compiler::Dialect::IREE::IR
diff --git a/iree/compiler/Conversion/LinalgToLLVM/KernelDispatch.cpp b/iree/compiler/Conversion/LinalgToLLVM/KernelDispatch.cpp
index 37d9421..fef25c2 100644
--- a/iree/compiler/Conversion/LinalgToLLVM/KernelDispatch.cpp
+++ b/iree/compiler/Conversion/LinalgToLLVM/KernelDispatch.cpp
@@ -15,9 +15,11 @@
 
 #include "iree/compiler/Conversion/LinalgToLLVM/KernelDispatch.h"
 
+#include "iree/compiler/Dialect/Flow/IR/FlowOps.h"
 #include "llvm/Support/CommandLine.h"
 #include "mlir/Dialect/Linalg/IR/LinalgInterfaces.h"
 #include "mlir/Dialect/Linalg/IR/LinalgOps.h"
+#include "mlir/Dialect/MemRef/IR/MemRef.h"
 #include "mlir/Dialect/StandardOps/IR/Ops.h"
 #include "mlir/IR/Operation.h"
 
@@ -42,7 +44,7 @@
     llvm::cl::desc(
         "linalg.matmul tile size for L1 spliting of M, N, K dimension"),
     llvm::cl::init(32));
-static llvm::cl::opt<int> matmulL2TileSize(
+static llvm::cl::opt<int> matmulVectorSize(
     "iree-codegen-llvm-matmul-vector-size",
     llvm::cl::desc("linalg.matmul vector tile size"), llvm::cl::init(4));
 
@@ -71,15 +73,69 @@
 llvm::SmallVector<int64_t, 4> getTileSizes(Operation *op) {
   if (auto contractionOp = dyn_cast<linalg::ContractionOpInterface>(op)) {
     if (contractionOp.isRowMajorMatmul()) {
+      int mWorkgroupSize = matmulWorkgroupTileSize;
+      int nWorkgroupSize = matmulWorkgroupTileSize;
+      int mL1TileSize = matmulL1TileSize;
+      int nL1TileSize = matmulL1TileSize;
+      int kL1TileSize = matmulL1TileSize;
+      if (auto matmulOp = dyn_cast<linalg::MatmulOp>(op)) {
+        // Returns the original problem size before tiling.
+        auto getOriginalOperandShape = [](Value operand) {
+          if (auto dispatchLoadOp =
+                  operand.getDefiningOp<IREE::Flow::DispatchTensorLoadOp>()) {
+            return dispatchLoadOp.source()
+                .getType()
+                .cast<IREE::Flow::DispatchTensorType>()
+                .getShape();
+          }
+          if (auto operandParent = operand.getDefiningOp<memref::SubViewOp>()) {
+            return operandParent.source()
+                .getType()
+                .cast<ShapedType>()
+                .getShape();
+          }
+          if (auto operandParent = operand.getDefiningOp<SubTensorOp>()) {
+            return operandParent.source()
+                .getType()
+                .cast<ShapedType>()
+                .getShape();
+          }
+          if (auto operandParent = operand.getDefiningOp<memref::AllocaOp>()) {
+            return operandParent.getType().cast<ShapedType>().getShape();
+          }
+          return ArrayRef<int64_t>{};
+        };
+
+        auto lhsShape = getOriginalOperandShape(matmulOp.inputs()[0]);
+        auto rhsShape = getOriginalOperandShape(matmulOp.inputs()[1]);
+
+        if (!lhsShape.empty() && !rhsShape.empty()) {
+          // Find largest tile size that is a multiple of the vector size.
+          auto getTileSize = [](int dim, int maxSize) {
+            for (int i = std::min(maxSize, dim); i > 0; --i) {
+              if (dim % i == 0 && i % matmulVectorSize == 0) {
+                return i;
+              }
+            }
+            return maxSize;
+          };
+          mWorkgroupSize = getTileSize(lhsShape[0], mWorkgroupSize);
+          nWorkgroupSize = getTileSize(rhsShape[1], nWorkgroupSize);
+          mL1TileSize = getTileSize(mWorkgroupSize, mL1TileSize);
+          nL1TileSize = getTileSize(nWorkgroupSize, nL1TileSize);
+          kL1TileSize = getTileSize(rhsShape[0], kL1TileSize);
+        }
+      }
+
       switch (tilingLevel) {
         case TilingLevel::WorkGroupTiles: {
-          return {matmulWorkgroupTileSize, matmulWorkgroupTileSize};
+          return {mWorkgroupSize, nWorkgroupSize};
         }
         case TilingLevel::Level1Tiles: {
-          return {matmulL1TileSize, matmulL1TileSize, matmulL1TileSize};
+          return {mL1TileSize, nL1TileSize, kL1TileSize};
         }
         case TilingLevel::Level2Tiles: {
-          return {matmulL2TileSize, matmulL2TileSize, matmulL2TileSize};
+          return {matmulVectorSize, matmulVectorSize, matmulVectorSize};
         }
       }
     }