Adjust default mmt4d tiling (#7071)

* Adjust default mmt4d tiling

- Tile to M0=4, N0=4, K0=1.
- Add microbenchmark for that case.

* Set M0, N0, K0 from the operand shapes
diff --git a/iree/compiler/Codegen/LLVMCPU/KernelDispatch.cpp b/iree/compiler/Codegen/LLVMCPU/KernelDispatch.cpp
index 89d61d0..e9d69ee 100644
--- a/iree/compiler/Codegen/LLVMCPU/KernelDispatch.cpp
+++ b/iree/compiler/Codegen/LLVMCPU/KernelDispatch.cpp
@@ -183,23 +183,33 @@
       return SmallVector<int64_t>(mmt4dWorkgroupTileSizes.begin(),
                                   mmt4dWorkgroupTileSizes.end());
     }
-    return {64, 32};
+    return {48, 32};
   };
 
   auto getL1TileSizes = [&]() -> SmallVector<int64_t> {
+    auto lhsShape = getUntiledShape(mmt4dOp.inputs()[0]);
+    auto rhsShape = getUntiledShape(mmt4dOp.inputs()[1]);
+    int M0 = lhsShape[2];
+    int N0 = rhsShape[2];
+    int K0 = lhsShape[3];
     if (!mmt4dL1TileSizes.empty()) {
       return SmallVector<int64_t>(mmt4dL1TileSizes.begin(),
                                   mmt4dL1TileSizes.end());
     }
-    return {1, 1, 4, 4, 1, 4};
+    return {1, 1, 1, M0, N0, K0};
   };
 
   auto getVectorSizes = [&]() -> SmallVector<int64_t> {
+    auto lhsShape = getUntiledShape(mmt4dOp.inputs()[0]);
+    auto rhsShape = getUntiledShape(mmt4dOp.inputs()[1]);
+    int M0 = lhsShape[2];
+    int N0 = rhsShape[2];
+    int K0 = lhsShape[3];
     if (!mmt4dVectorSizes.empty()) {
       return SmallVector<int64_t>(mmt4dVectorSizes.begin(),
                                   mmt4dVectorSizes.end());
     }
-    return {1, 1, 4, 4, 1, 4};
+    return {1, 1, 1, M0, N0, K0};
   };
 
   SmallVector<int64_t, 4> nativeVectorSize = getVectorSizes();
diff --git a/iree/test/microbenchmarks/linalg_mmt4d.mlir b/iree/test/microbenchmarks/linalg_mmt4d.mlir
index c8a9d10..543d6bd 100644
--- a/iree/test/microbenchmarks/linalg_mmt4d.mlir
+++ b/iree/test/microbenchmarks/linalg_mmt4d.mlir
@@ -10,10 +10,18 @@
     return %0 : tensor<384x512xf32>
 }
 
-func @mmt4d_384x384x512() -> tensor<96x128x4x4xf32> {
-    %lhs = util.unfoldable_constant dense<1.0> : tensor<96x96x4x4xf32>
-    %rhs = util.unfoldable_constant dense<1.0> : tensor<128x96x4x4xf32>
+func @mmt4d_384x384x512_4x1x4() -> tensor<96x128x4x4xf32> {
+    %lhs = util.unfoldable_constant dense<1.0> : tensor<96x384x4x1xf32>
+    %rhs = util.unfoldable_constant dense<1.0> : tensor<128x384x4x1xf32>
     %dst = util.unfoldable_constant dense<1.0> : tensor<96x128x4x4xf32>
-    %0 = linalg.mmt4d ins(%lhs, %rhs : tensor<96x96x4x4xf32>, tensor<128x96x4x4xf32>) outs(%dst : tensor<96x128x4x4xf32>) -> tensor<96x128x4x4xf32>
+    %0 = linalg.mmt4d ins(%lhs, %rhs : tensor<96x384x4x1xf32>, tensor<128x384x4x1xf32>) outs(%dst : tensor<96x128x4x4xf32>) -> tensor<96x128x4x4xf32>
     return %0 : tensor<96x128x4x4xf32>
 }
+
+func @mmt4d_384x384x512_8x1x8() -> tensor<48x64x8x8xf32> {
+    %lhs = util.unfoldable_constant dense<1.0> : tensor<48x384x8x1xf32>
+    %rhs = util.unfoldable_constant dense<1.0> : tensor<64x384x8x1xf32>
+    %dst = util.unfoldable_constant dense<1.0> : tensor<48x64x8x8xf32>
+    %0 = linalg.mmt4d ins(%lhs, %rhs : tensor<48x384x8x1xf32>, tensor<64x384x8x1xf32>) outs(%dst : tensor<48x64x8x8xf32>) -> tensor<48x64x8x8xf32>
+    return %0 : tensor<48x64x8x8xf32>
+}