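[LLVMCPU] Re-enable vector masking for sub-byte element types (#15335)

Drop the `hasByteAlignedElementTypes` guard from the x86, RISC-V and
AArch64 (SVE) masking strategy selection in KernelDispatch.cpp and from
the `enableVectorMasking` computation in LLVMCPULowerExecutableTarget.cpp,
and remove the helper from Utils.cpp / Utils.h.

Fixes #15031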
diff --git a/compiler/src/iree/compiler/Codegen/LLVMCPU/KernelDispatch.cpp b/compiler/src/iree/compiler/Codegen/LLVMCPU/KernelDispatch.cpp
index 1d627bf..1d5ebde 100644
--- a/compiler/src/iree/compiler/Codegen/LLVMCPU/KernelDispatch.cpp
+++ b/compiler/src/iree/compiler/Codegen/LLVMCPU/KernelDispatch.cpp
@@ -191,11 +191,10 @@
auto targetAttr = IREE::HAL::ExecutableTargetAttr::lookup(linalgOp);
bool isLinalgGeneric = isa<linalg::GenericOp>(linalgOp.getOperation());
- bool isByteAligned = hasByteAlignedElementTypes(linalgOp);
// Default X86 specific strategy.
if (isX86(targetAttr)) {
- if (isLinalgGeneric && isByteAligned) {
+ if (isLinalgGeneric) {
return VectorPreProcStrategy::Masking;
}
@@ -212,7 +211,7 @@
// Default RISC-V specific strategies.
if (isRISCV(targetAttr)) {
- if (isLinalgGeneric && isByteAligned) {
+ if (isLinalgGeneric) {
return VectorPreProcStrategy::Masking;
}
@@ -223,7 +222,7 @@
// Default AArch64 specific strategies.
if (isAArch64(targetAttr)) {
- if (hasAnySVEFeature(targetAttr) && isByteAligned) {
+ if (hasAnySVEFeature(targetAttr)) {
return VectorPreProcStrategy::Masking;
}
if ((linalg::isElementwise(linalgOp) || isFullyDynamicOp(linalgOp)) &&
diff --git a/compiler/src/iree/compiler/Codegen/LLVMCPU/LLVMCPULowerExecutableTarget.cpp b/compiler/src/iree/compiler/Codegen/LLVMCPU/LLVMCPULowerExecutableTarget.cpp
index f731f8f..889074b 100644
--- a/compiler/src/iree/compiler/Codegen/LLVMCPU/LLVMCPULowerExecutableTarget.cpp
+++ b/compiler/src/iree/compiler/Codegen/LLVMCPU/LLVMCPULowerExecutableTarget.cpp
@@ -189,15 +189,9 @@
if (translationInfo.has_value()) {
auto target = variantOp.getTarget();
bool lowerToAVX2 = hasAVX2Feature(target);
- auto walkRes = moduleOp.walk([](linalg::LinalgOp linalgOp) {
- if (!hasByteAlignedElementTypes(linalgOp))
- return WalkResult::interrupt();
- return WalkResult::advance();
- });
- bool isByteAligned = !walkRes.wasInterrupted();
bool enableVectorMasking =
- isByteAligned && (isX86(target) || isRISCV(target) ||
- (isAArch64(target) && hasAnySVEFeature(target)));
+ isX86(target) || isRISCV(target) ||
+ (isAArch64(target) && hasAnySVEFeature(target));
bool enableMicrokernels = hasMicrokernels(target);
bool enableAArch64SSVE = isAArch64(target) && hasAnySVEFeature(target) &&
diff --git a/compiler/src/iree/compiler/Codegen/LLVMCPU/Utils.cpp b/compiler/src/iree/compiler/Codegen/LLVMCPU/Utils.cpp
index 767986e..a6e9a5f 100644
--- a/compiler/src/iree/compiler/Codegen/LLVMCPU/Utils.cpp
+++ b/compiler/src/iree/compiler/Codegen/LLVMCPU/Utils.cpp
@@ -100,14 +100,6 @@
return rootOperation;
}
-bool hasByteAlignedElementTypes(linalg::LinalgOp linalgOp) {
- return llvm::all_of(linalgOp->getOperands(), [](Value operand) {
- auto bitwidth =
- IREE::Util::getTypeBitWidth(getElementTypeOrSelf(operand.getType()));
- return bitwidth % 8 == 0;
- });
-}
-
void setSCFTileSizes(scf::SCFTilingOptions &options, TilingInterface consumerOp,
SmallVector<int64_t> tileSizes,
SmallVector<bool> tileScalableFlags) {
diff --git a/compiler/src/iree/compiler/Codegen/LLVMCPU/Utils.h b/compiler/src/iree/compiler/Codegen/LLVMCPU/Utils.h
index ee04d79..fb06643 100644
--- a/compiler/src/iree/compiler/Codegen/LLVMCPU/Utils.h
+++ b/compiler/src/iree/compiler/Codegen/LLVMCPU/Utils.h
@@ -53,10 +53,6 @@
/// to the end of the function is the root op.
FailureOr<Operation *> getRootOperation(ArrayRef<Operation *> computeOps);
-/// Returns true if all of the element types involved in the linalg op are byte
-/// aligned.
-bool hasByteAlignedElementTypes(linalg::LinalgOp linalgOp);
-
/// Sets the tile sizes of the SCFTilingOptions. If `tileScalableFlags` are
/// provided the corresponding tile size will be multiplied by a vector.vscale
/// op.