[CPU] Minor clean-up and fixes for mmt4d code generation (#15380)
Minor clean-up changes, plus fixes to the tiling config that were exposed
while looking at mmt4d code generation.
diff --git a/compiler/src/iree/compiler/Codegen/Common/GenericVectorization.cpp b/compiler/src/iree/compiler/Codegen/Common/GenericVectorization.cpp
index b164dfc..a50130c 100644
--- a/compiler/src/iree/compiler/Codegen/Common/GenericVectorization.cpp
+++ b/compiler/src/iree/compiler/Codegen/Common/GenericVectorization.cpp
@@ -198,7 +198,7 @@
if (vectorizePadding && enableVectorMasking && isa<tensor::PadOp>(op))
candidates.push_back(op);
});
- for (auto op : candidates) {
+ for (Operation *op : candidates) {
SmallVector<int64_t> vectorSizes;
SmallVector<bool> scalableVecDims;
if (auto linalgOp = dyn_cast<linalg::LinalgOp>(op)) {
diff --git a/compiler/src/iree/compiler/Codegen/Common/TileSizeSelection.cpp b/compiler/src/iree/compiler/Codegen/Common/TileSizeSelection.cpp
index 10ed570..e29a09b 100644
--- a/compiler/src/iree/compiler/Codegen/Common/TileSizeSelection.cpp
+++ b/compiler/src/iree/compiler/Codegen/Common/TileSizeSelection.cpp
@@ -61,22 +61,32 @@
auto [parallelCommonSizes, parallelCommonScalableFlags] =
getVectorCommonParallelSizes();
auto [reductionSizes, reductionScalableFlags] = getVectorReductionSizes();
- auto [parallelInnerSizes, parallelInnerScalableFlags] =
- getVectorInnerParallelSizes();
+ SizesAndScalableFlags parallelInnerTiles;
+ if (hasVectorInnerParallelLevel()) {
+ parallelInnerTiles = getVectorInnerParallelSizes();
+ }
+
for (int i = 0; i < numDims; ++i) {
- unsigned nonZeroCnt = llvm::count(
- ArrayRef<bool>{
- !!parallelCommonSizes[i] || parallelCommonScalableFlags[i],
- !!reductionSizes[i] || reductionScalableFlags[i],
- !!parallelInnerSizes[i] || parallelInnerScalableFlags[i]},
- true);
+ SmallVector<bool> dimSizes;
+ dimSizes.push_back(!!parallelCommonSizes[i] ||
+ parallelCommonScalableFlags[i]);
+ dimSizes.push_back(!!reductionSizes[i] || reductionScalableFlags[i]);
+ if (hasVectorInnerParallelLevel())
+ dimSizes.push_back(!!parallelInnerTiles.first[i] ||
+ parallelInnerTiles.second[i]);
+
+ unsigned nonZeroCnt = llvm::count(dimSizes, true);
assert(nonZeroCnt <= 1 && "expected one tile size at most to be non-zero");
(void)nonZeroCnt;
- vectorSizes[i] =
- parallelCommonSizes[i] ^ reductionSizes[i] ^ parallelInnerSizes[i];
- scalableFlags[i] = parallelCommonScalableFlags[i] ||
- reductionScalableFlags[i] ||
- parallelInnerScalableFlags[i];
+
+ vectorSizes[i] = parallelCommonSizes[i] ^ reductionSizes[i];
+ if (hasVectorInnerParallelLevel())
+ vectorSizes[i] ^= parallelInnerTiles.first[i];
+
+ scalableFlags[i] =
+ parallelCommonScalableFlags[i] || reductionScalableFlags[i];
+ if (hasVectorInnerParallelLevel())
+ scalableFlags[i] |= parallelInnerTiles.second[i];
}
return std::make_pair(vectorSizes, scalableFlags);
diff --git a/compiler/src/iree/compiler/Codegen/Common/TileSizeSelection.h b/compiler/src/iree/compiler/Codegen/Common/TileSizeSelection.h
index a086ff8..2f2a9ab 100644
--- a/compiler/src/iree/compiler/Codegen/Common/TileSizeSelection.h
+++ b/compiler/src/iree/compiler/Codegen/Common/TileSizeSelection.h
@@ -72,6 +72,10 @@
return getActualLevel(VectorCommonParallelTiles);
}
+ /// Returns true if the tiling configuration has a vector inner parallel
+ /// level, i.e. when it has more than three tiling levels.
+ bool hasVectorInnerParallelLevel() { return getNumTilingLevels() > 3; }
+
/// Returns the tiling level for vector inner parallel dimensions.
unsigned getVectorInnerParallelLevel() {
return getActualLevel(VectorInnerParallelTiles);
diff --git a/compiler/src/iree/compiler/Codegen/LLVMCPU/KernelDispatch.cpp b/compiler/src/iree/compiler/Codegen/LLVMCPU/KernelDispatch.cpp
index a56deaa..3c6f93e 100644
--- a/compiler/src/iree/compiler/Codegen/LLVMCPU/KernelDispatch.cpp
+++ b/compiler/src/iree/compiler/Codegen/LLVMCPU/KernelDispatch.cpp
@@ -54,7 +54,7 @@
llvm::cl::desc("disable thread distribution in codegen"),
llvm::cl::init(false));
-static llvm::cl::list<int> mmt4dDistributionTileSizes(
+static llvm::cl::list<int> clMmt4dDistributionTileSizes(
"iree-codegen-llvm-mmt4d-distribution-tile-sizes",
llvm::cl::desc("linalg.mmt4d distribution tile size"),
llvm::cl::ZeroOrMore);
@@ -1176,9 +1176,9 @@
linalg::Mmt4DOp mmt4dOp) {
assert(!getLoweringConfig(mmt4dOp) && "expected lowering_config is not set");
auto getDistTileSizes = [&]() -> SmallVector<int64_t> {
- if (!mmt4dDistributionTileSizes.empty()) {
- return SmallVector<int64_t>(mmt4dDistributionTileSizes.begin(),
- mmt4dDistributionTileSizes.end());
+ if (!clMmt4dDistributionTileSizes.empty()) {
+ return SmallVector<int64_t>(clMmt4dDistributionTileSizes.begin(),
+ clMmt4dDistributionTileSizes.end());
}
unsigned numLoops = mmt4dOp.getNumLoops();
SmallVector<int64_t> minTileSizes(numLoops, 0);
@@ -1193,6 +1193,10 @@
};
auto getL1TileSizes = [&]() -> SmallVector<int64_t> {
+ if (!mmt4dL1TileSizes.empty()) {
+ return SmallVector<int64_t>(mmt4dL1TileSizes.begin(),
+ mmt4dL1TileSizes.end());
+ }
auto lhsShape =
llvm::cast<ShapedType>(mmt4dOp.getInputs()[0].getType()).getShape();
auto rhsShape =
@@ -1200,10 +1204,7 @@
int M0 = lhsShape[2];
int N0 = rhsShape[2];
int K0 = lhsShape[3];
- if (!mmt4dL1TileSizes.empty()) {
- return SmallVector<int64_t>(mmt4dL1TileSizes.begin(),
- mmt4dL1TileSizes.end());
- }
+
return {1, 1, 1, M0, N0, K0};
};
@@ -1215,6 +1216,10 @@
TileSizesListType tileSizes = {getDistTileSizes(), parallelTileSizes,
reductionTileSizes};
+ LLVM_DEBUG(KD_DBGS() << "Parallel tile sizes: " << parallelTileSizes << "\n");
+ LLVM_DEBUG(KD_DBGS() << "Reduction tile sizes: " << reductionTileSizes
+ << "\n");
+
return setOpConfigAndEntryPointFnTranslation(
entryPointFn, mmt4dOp, tileSizes,
DispatchLoweringPassPipeline::Mmt4dTilingExpert);
@@ -1227,13 +1232,13 @@
assert(!getLoweringConfig(batchMmt4dOp) &&
"expected lowering_config is not set");
auto getDistTileSizes = [&]() -> SmallVector<int64_t> {
- if (!mmt4dDistributionTileSizes.empty()) {
+ if (!clMmt4dDistributionTileSizes.empty()) {
SmallVector<int64_t> tileSizes;
- // If mmt4dDistributionTileSizes is set, tile batch dim to 1 + the
+ // If clMmt4dDistributionTileSizes is set, tile batch dim to 1 + the
// specified mmt4d tile sizes.
tileSizes.push_back(1);
- tileSizes.append(mmt4dDistributionTileSizes.begin(),
- mmt4dDistributionTileSizes.end());
+ tileSizes.append(clMmt4dDistributionTileSizes.begin(),
+ clMmt4dDistributionTileSizes.end());
return tileSizes;
}
unsigned numLoops = batchMmt4dOp.getNumLoops();
diff --git a/compiler/src/iree/compiler/Codegen/LLVMCPU/LLVMCPUMmt4dVectorLowering.cpp b/compiler/src/iree/compiler/Codegen/LLVMCPU/LLVMCPUMmt4dVectorLowering.cpp
index a058ec7..b1b8a9c 100644
--- a/compiler/src/iree/compiler/Codegen/LLVMCPU/LLVMCPUMmt4dVectorLowering.cpp
+++ b/compiler/src/iree/compiler/Codegen/LLVMCPU/LLVMCPUMmt4dVectorLowering.cpp
@@ -6,9 +6,6 @@
#include "iree/compiler/Codegen/LLVMCPU/PassDetail.h"
#include "iree/compiler/Codegen/LLVMCPU/Passes.h"
-#include "iree/compiler/Codegen/Transforms/Transforms.h"
-#include "iree/compiler/Codegen/Utils/MarkerUtils.h"
-#include "iree/compiler/Codegen/Utils/Utils.h"
#include "llvm/Support/Debug.h"
#include "mlir/Dialect/Vector/Transforms/LoweringPatterns.h"
#include "mlir/Dialect/Vector/Transforms/VectorRewritePatterns.h"