[CPU] Enable mmt4d distribution for large reduction size cases. (#16037)

diff --git a/compiler/src/iree/compiler/Codegen/LLVMCPU/KernelDispatch.cpp b/compiler/src/iree/compiler/Codegen/LLVMCPU/KernelDispatch.cpp
index 4b4004d..70439e2 100644
--- a/compiler/src/iree/compiler/Codegen/LLVMCPU/KernelDispatch.cpp
+++ b/compiler/src/iree/compiler/Codegen/LLVMCPU/KernelDispatch.cpp
@@ -335,7 +335,7 @@
   while (numWorkgroups > numWorkgroupsLimit && currDim > 0) {
     unsigned index = currDim - 1;
     int64_t currSize = distributedTileSizes[index];
-    if (ShapedType::isDynamic(workload[index]) ||
+    if (ShapedType::isDynamic(workload[index]) || currSize == 0 ||
         (maxTileSizes && currSize >= maxTileSizes.value()[index]) ||
         currSize >= workload[index]) {
       currDim--;
@@ -1160,6 +1160,7 @@
 
 static TileSizesListType getMmt4dTileSizes(linalg::LinalgOp op) {
   DistributionHeuristicConfig distConfig;
+  distConfig.allowIncompleteTile = true;
   distConfig.minTileSizes.resize(op.getNumLoops(), 0);
   distConfig.maxTileSizes.resize(op.getNumLoops(), 0);
 
@@ -1191,10 +1192,12 @@
   // guess a reasonable default for the reduction dimension size.
   int64_t reductionSize = ShapedType::isDynamic(K1) ? 1024 : K0 * K1;
   auto getMatmulTileSize = [](int64_t targetTileBytes, int bitWidth,
-                              int64_t reductionSize, int64_t Tile0Size) {
+                              int64_t reductionSize, int64_t tile0Size) {
     int64_t targetRhsTileElems = targetTileBytes * 8 / bitWidth;
     int64_t targetRhsTileNSize = targetRhsTileElems / reductionSize;
-    return llvm::divideCeil(targetRhsTileNSize, Tile0Size);
+    int64_t tileSize = llvm::divideCeil(targetRhsTileNSize, tile0Size);
+    tileSize = std::max<int64_t>(tileSize, 1);
+    return tileSize;
   };
   int64_t tileBytes =
       (M1 == 1 || N1 == 1) ? clNarrowMatmulTileBytes : clGeneralMatmulTileBytes;
@@ -1207,6 +1210,8 @@
               : getMatmulTileSize(tileBytes, rhsType.getElementTypeBitWidth(),
                                   reductionSize, N0);
 
+  SmallVector<int64_t> distTileSizes =
+      getDefaultDistributedLevelTileSizes(op, distConfig);
   SmallVector<int64_t> parallelTileSizes(op.getNumLoops(), 1);
   assert(parallelTileSizes.size() == mmt4dDimBase + 6);
   parallelTileSizes[mmt4dDimBase + 3] = M0;
@@ -1214,8 +1219,7 @@
   parallelTileSizes[mmt4dDimBase + 5] = K0;
   SmallVector<int64_t> reductionTileSizes;
   splitParallelAndReductionTiles(op, parallelTileSizes, reductionTileSizes);
-  return {getDefaultDistributedLevelTileSizes(op, distConfig),
-          parallelTileSizes, reductionTileSizes};
+  return {distTileSizes, parallelTileSizes, reductionTileSizes};
 }
 
 /// Sets the lowering configuration for dispatch region for linalg.mmt4d
diff --git a/compiler/src/iree/compiler/Codegen/LLVMCPU/test/select_x86_64_lowering_strategy.mlir b/compiler/src/iree/compiler/Codegen/LLVMCPU/test/select_x86_64_lowering_strategy.mlir
index c5fcc45..b0160f7 100644
--- a/compiler/src/iree/compiler/Codegen/LLVMCPU/test/select_x86_64_lowering_strategy.mlir
+++ b/compiler/src/iree/compiler/Codegen/LLVMCPU/test/select_x86_64_lowering_strategy.mlir
@@ -2142,6 +2142,42 @@
 
 // -----
 
+#executable_target_embedded_elf_x86_64_ = #hal.executable.target<"llvm-cpu", "embedded-elf-x86_64", {cpu = "cascadelake", cpu_features = "", data_layout = "e-m:e-p270:32:32-p271:32:32-p272:64:64-i64:64-f80:128-n8:16:32:64-S128", native_vector_size = 64 : index, target_triple = "x86_64-unknown-unknown-eabi-elf"}>
+#pipeline_layout = #hal.pipeline.layout<push_constants = 28, sets = [
+  #hal.descriptor_set.layout<0, bindings = [
+    #hal.descriptor_set.binding<0, storage_buffer, ReadOnly>,
+    #hal.descriptor_set.binding<1, storage_buffer>
+  ]>
+]>
+hal.executable private @mmt4d_with_large_reduction {
+  hal.executable.variant public @embedded_elf_x86_64 target(#executable_target_embedded_elf_x86_64_) {
+    hal.executable.export public @mmt4d_with_large_reduction ordinal(0) layout(#pipeline_layout)
+    builtin.module {
+      func.func @mmt4d_with_large_reduction() {
+        %c0 = arith.constant 0 : index
+        %cst = arith.constant 0.000000e+00 : f32
+        %0 = hal.interface.binding.subspan set(0) binding(0) type(storage_buffer) alignment(64) offset(%c0) flags(ReadOnly) : !flow.dispatch.tensor<readonly:tensor<7x18176x16x1xf32>>
+        %1 = hal.interface.binding.subspan set(0) binding(1) type(storage_buffer) alignment(64) offset(%c0) flags(ReadOnly) : !flow.dispatch.tensor<readonly:tensor<284x18176x16x1xf32>>
+        %2 = hal.interface.binding.subspan set(0) binding(2) type(storage_buffer) alignment(64) offset(%c0) : !flow.dispatch.tensor<writeonly:tensor<7x284x16x16xf32>>
+        %3 = flow.dispatch.tensor.load %0, offsets = [0, 0, 0, 0], sizes = [7, 18176, 16, 1], strides = [1, 1, 1, 1] : !flow.dispatch.tensor<readonly:tensor<7x18176x16x1xf32>> -> tensor<7x18176x16x1xf32>
+        %4 = flow.dispatch.tensor.load %1, offsets = [0, 0, 0, 0], sizes = [284, 18176, 16, 1], strides = [1, 1, 1, 1] : !flow.dispatch.tensor<readonly:tensor<284x18176x16x1xf32>> -> tensor<284x18176x16x1xf32>
+        %5 = tensor.empty() : tensor<7x284x16x16xf32>
+        %6 = linalg.fill ins(%cst : f32) outs(%5 : tensor<7x284x16x16xf32>) -> tensor<7x284x16x16xf32>
+        %7 = linalg.mmt4d ins(%3, %4 : tensor<7x18176x16x1xf32>, tensor<284x18176x16x1xf32>) outs(%6 : tensor<7x284x16x16xf32>) -> tensor<7x284x16x16xf32>
+        flow.dispatch.tensor.store %7, %2, offsets = [0, 0, 0, 0], sizes = [7, 284, 16, 16], strides = [1, 1, 1, 1] : tensor<7x284x16x16xf32> -> !flow.dispatch.tensor<writeonly:tensor<7x284x16x16xf32>>
+        return
+      }
+    }
+  }
+}
+
+//  CHECK-DAG: #[[CONFIG:.+]] = #iree_codegen.lowering_config<tile_sizes = {{\[}}[1, 1, 0, 0, 0, 0], [1, 1, 0, 16, 16, 0], [0, 0, 1, 0, 0, 1]]>
+//      CHECK: func.func @mmt4d_with_large_reduction()
+//      CHECK:   linalg.mmt4d
+// CHECK-SAME:     lowering_config = #[[CONFIG]]
+
+// -----
+
 hal.executable private @pad_only {
   hal.executable.variant public @embedded_elf_x86_64 target(<"llvm-cpu", "embedded-elf-x86_64", {
       cpu = "generic", cpu_features = "",