[CPU] Minor clean-up and fixes for mmt4d code generation (#15380)

Just minor clean-up changes and fixes in tiling config exposed when
looking at mmt4d code generation.
diff --git a/compiler/src/iree/compiler/Codegen/Common/GenericVectorization.cpp b/compiler/src/iree/compiler/Codegen/Common/GenericVectorization.cpp
index b164dfc..a50130c 100644
--- a/compiler/src/iree/compiler/Codegen/Common/GenericVectorization.cpp
+++ b/compiler/src/iree/compiler/Codegen/Common/GenericVectorization.cpp
@@ -198,7 +198,7 @@
     if (vectorizePadding && enableVectorMasking && isa<tensor::PadOp>(op))
       candidates.push_back(op);
   });
-  for (auto op : candidates) {
+  for (Operation *op : candidates) {
     SmallVector<int64_t> vectorSizes;
     SmallVector<bool> scalableVecDims;
     if (auto linalgOp = dyn_cast<linalg::LinalgOp>(op)) {
diff --git a/compiler/src/iree/compiler/Codegen/Common/TileSizeSelection.cpp b/compiler/src/iree/compiler/Codegen/Common/TileSizeSelection.cpp
index 10ed570..e29a09b 100644
--- a/compiler/src/iree/compiler/Codegen/Common/TileSizeSelection.cpp
+++ b/compiler/src/iree/compiler/Codegen/Common/TileSizeSelection.cpp
@@ -61,22 +61,32 @@
   auto [parallelCommonSizes, parallelCommonScalableFlags] =
       getVectorCommonParallelSizes();
   auto [reductionSizes, reductionScalableFlags] = getVectorReductionSizes();
-  auto [parallelInnerSizes, parallelInnerScalableFlags] =
-      getVectorInnerParallelSizes();
+  SizesAndScalableFlags parallelInnerTiles;
+  if (hasVectorInnerParallelLevel()) {
+    parallelInnerTiles = getVectorInnerParallelSizes();
+  }
+
   for (int i = 0; i < numDims; ++i) {
-    unsigned nonZeroCnt = llvm::count(
-        ArrayRef<bool>{
-            !!parallelCommonSizes[i] || parallelCommonScalableFlags[i],
-            !!reductionSizes[i] || reductionScalableFlags[i],
-            !!parallelInnerSizes[i] || parallelInnerScalableFlags[i]},
-        true);
+    SmallVector<bool> dimSizes;
+    dimSizes.push_back(!!parallelCommonSizes[i] ||
+                       parallelCommonScalableFlags[i]);
+    dimSizes.push_back(!!reductionSizes[i] || reductionScalableFlags[i]);
+    if (hasVectorInnerParallelLevel())
+      dimSizes.push_back(!!parallelInnerTiles.first[i] ||
+                         parallelInnerTiles.second[i]);
+
+    unsigned nonZeroCnt = llvm::count(dimSizes, true);
     assert(nonZeroCnt <= 1 && "expected one tile size at most to be non-zero");
     (void)nonZeroCnt;
-    vectorSizes[i] =
-        parallelCommonSizes[i] ^ reductionSizes[i] ^ parallelInnerSizes[i];
-    scalableFlags[i] = parallelCommonScalableFlags[i] ||
-                       reductionScalableFlags[i] ||
-                       parallelInnerScalableFlags[i];
+
+    vectorSizes[i] = parallelCommonSizes[i] ^ reductionSizes[i];
+    if (hasVectorInnerParallelLevel())
+      vectorSizes[i] ^= parallelInnerTiles.first[i];
+
+    scalableFlags[i] =
+        parallelCommonScalableFlags[i] || reductionScalableFlags[i];
+    if (hasVectorInnerParallelLevel())
+      scalableFlags[i] |= parallelInnerTiles.second[i];
   }
 
   return std::make_pair(vectorSizes, scalableFlags);
diff --git a/compiler/src/iree/compiler/Codegen/Common/TileSizeSelection.h b/compiler/src/iree/compiler/Codegen/Common/TileSizeSelection.h
index a086ff8..2f2a9ab 100644
--- a/compiler/src/iree/compiler/Codegen/Common/TileSizeSelection.h
+++ b/compiler/src/iree/compiler/Codegen/Common/TileSizeSelection.h
@@ -72,6 +72,10 @@
     return getActualLevel(VectorCommonParallelTiles);
   }
 
+  /// Returns true if the tiling configuration has vector inner parallel
+  /// dimensions
+  bool hasVectorInnerParallelLevel() { return getNumTilingLevels() > 3; }
+
   /// Returns the tiling level for vector inner parallel dimensions.
   unsigned getVectorInnerParallelLevel() {
     return getActualLevel(VectorInnerParallelTiles);
diff --git a/compiler/src/iree/compiler/Codegen/LLVMCPU/KernelDispatch.cpp b/compiler/src/iree/compiler/Codegen/LLVMCPU/KernelDispatch.cpp
index a56deaa..3c6f93e 100644
--- a/compiler/src/iree/compiler/Codegen/LLVMCPU/KernelDispatch.cpp
+++ b/compiler/src/iree/compiler/Codegen/LLVMCPU/KernelDispatch.cpp
@@ -54,7 +54,7 @@
     llvm::cl::desc("disable thread distribution in codegen"),
     llvm::cl::init(false));
 
-static llvm::cl::list<int> mmt4dDistributionTileSizes(
+static llvm::cl::list<int> clMmt4dDistributionTileSizes(
     "iree-codegen-llvm-mmt4d-distribution-tile-sizes",
     llvm::cl::desc("linalg.mmt4d distribution tile size"),
     llvm::cl::ZeroOrMore);
@@ -1176,9 +1176,9 @@
                                    linalg::Mmt4DOp mmt4dOp) {
   assert(!getLoweringConfig(mmt4dOp) && "expected lowering_config is not set");
   auto getDistTileSizes = [&]() -> SmallVector<int64_t> {
-    if (!mmt4dDistributionTileSizes.empty()) {
-      return SmallVector<int64_t>(mmt4dDistributionTileSizes.begin(),
-                                  mmt4dDistributionTileSizes.end());
+    if (!clMmt4dDistributionTileSizes.empty()) {
+      return SmallVector<int64_t>(clMmt4dDistributionTileSizes.begin(),
+                                  clMmt4dDistributionTileSizes.end());
     }
     unsigned numLoops = mmt4dOp.getNumLoops();
     SmallVector<int64_t> minTileSizes(numLoops, 0);
@@ -1193,6 +1193,10 @@
   };
 
   auto getL1TileSizes = [&]() -> SmallVector<int64_t> {
+    if (!mmt4dL1TileSizes.empty()) {
+      return SmallVector<int64_t>(mmt4dL1TileSizes.begin(),
+                                  mmt4dL1TileSizes.end());
+    }
     auto lhsShape =
         llvm::cast<ShapedType>(mmt4dOp.getInputs()[0].getType()).getShape();
     auto rhsShape =
@@ -1200,10 +1204,7 @@
     int M0 = lhsShape[2];
     int N0 = rhsShape[2];
     int K0 = lhsShape[3];
-    if (!mmt4dL1TileSizes.empty()) {
-      return SmallVector<int64_t>(mmt4dL1TileSizes.begin(),
-                                  mmt4dL1TileSizes.end());
-    }
+
     return {1, 1, 1, M0, N0, K0};
   };
 
@@ -1215,6 +1216,10 @@
   TileSizesListType tileSizes = {getDistTileSizes(), parallelTileSizes,
                                  reductionTileSizes};
 
+  LLVM_DEBUG(KD_DBGS() << "Parallel tile sizes: " << parallelTileSizes << "\n");
+  LLVM_DEBUG(KD_DBGS() << "Reduction tile sizes: " << reductionTileSizes
+                       << "\n");
+
   return setOpConfigAndEntryPointFnTranslation(
       entryPointFn, mmt4dOp, tileSizes,
       DispatchLoweringPassPipeline::Mmt4dTilingExpert);
@@ -1227,13 +1232,13 @@
   assert(!getLoweringConfig(batchMmt4dOp) &&
          "expected lowering_config is not set");
   auto getDistTileSizes = [&]() -> SmallVector<int64_t> {
-    if (!mmt4dDistributionTileSizes.empty()) {
+    if (!clMmt4dDistributionTileSizes.empty()) {
       SmallVector<int64_t> tileSizes;
-      // If mmt4dDistributionTileSizes is set, tile batch dim to 1 + the
+      // If clMmt4dDistributionTileSizes is set, tile batch dim to 1 + the
       // specified mmt4d tile sizes.
       tileSizes.push_back(1);
-      tileSizes.append(mmt4dDistributionTileSizes.begin(),
-                       mmt4dDistributionTileSizes.end());
+      tileSizes.append(clMmt4dDistributionTileSizes.begin(),
+                       clMmt4dDistributionTileSizes.end());
       return tileSizes;
     }
     unsigned numLoops = batchMmt4dOp.getNumLoops();
diff --git a/compiler/src/iree/compiler/Codegen/LLVMCPU/LLVMCPUMmt4dVectorLowering.cpp b/compiler/src/iree/compiler/Codegen/LLVMCPU/LLVMCPUMmt4dVectorLowering.cpp
index a058ec7..b1b8a9c 100644
--- a/compiler/src/iree/compiler/Codegen/LLVMCPU/LLVMCPUMmt4dVectorLowering.cpp
+++ b/compiler/src/iree/compiler/Codegen/LLVMCPU/LLVMCPUMmt4dVectorLowering.cpp
@@ -6,9 +6,6 @@
 
 #include "iree/compiler/Codegen/LLVMCPU/PassDetail.h"
 #include "iree/compiler/Codegen/LLVMCPU/Passes.h"
-#include "iree/compiler/Codegen/Transforms/Transforms.h"
-#include "iree/compiler/Codegen/Utils/MarkerUtils.h"
-#include "iree/compiler/Codegen/Utils/Utils.h"
 #include "llvm/Support/Debug.h"
 #include "mlir/Dialect/Vector/Transforms/LoweringPatterns.h"
 #include "mlir/Dialect/Vector/Transforms/VectorRewritePatterns.h"