[LLVMCPU] Drop unit dims on memory transfers (#13340)
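This change enables the drop-unit-dims patterns in OptimizeVectorTransferPass by default, and the LLVMCPU pipeline now runs the pass a second time after bufferization so that unit dimensions on memref-based transfer_read/transfer_write ops are folded away. Roughly, the rewrite applied by mlir::vector::populateVectorTransferDropUnitDimsPatterns looks like the sketch below (illustrative IR only; exact subview layouts and attributes may differ):

    // Before: a transfer with a leading unit dimension.
    %v = vector.transfer_read %m[%c0, %c0], %pad {in_bounds = [true, true]}
           : memref<1x8xf32>, vector<1x8xf32>

    // After: rank-reducing subview, lower-rank transfer, vector.shape_cast back.
    %sv = memref.subview %m[0, 0] [1, 8] [1, 1]
            : memref<1x8xf32> to memref<8xf32>
    %r  = vector.transfer_read %sv[%c0], %pad {in_bounds = [true]}
            : memref<8xf32>, vector<8xf32>
    %v  = vector.shape_cast %r : vector<8xf32> to vector<1x8xf32>

The vector.shape_cast produced by this rewrite is why the option stays disabled in the SPIR-V pipelines below (see #14191).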
diff --git a/compiler/src/iree/compiler/Codegen/Common/CommonPasses.h b/compiler/src/iree/compiler/Codegen/Common/CommonPasses.h
index ac8dc27..f2189b2 100644
--- a/compiler/src/iree/compiler/Codegen/Common/CommonPasses.h
+++ b/compiler/src/iree/compiler/Codegen/Common/CommonPasses.h
@@ -153,7 +153,7 @@
/// Pass to optimize vector transfer_read and transfer_write.
std::unique_ptr<OperationPass<func::FuncOp>> createOptimizeVectorTransferPass(
- bool flatten = false);
+ bool flatten = false, bool dropUnitDims = true);
/// Pad dynamic alloc op to convert them into static one.
std::unique_ptr<OperationPass<func::FuncOp>> createPadDynamicAlloc();
diff --git a/compiler/src/iree/compiler/Codegen/Common/OptimizeVectorTransferPass.cpp b/compiler/src/iree/compiler/Codegen/Common/OptimizeVectorTransferPass.cpp
index 3b02ad4..165620b 100644
--- a/compiler/src/iree/compiler/Codegen/Common/OptimizeVectorTransferPass.cpp
+++ b/compiler/src/iree/compiler/Codegen/Common/OptimizeVectorTransferPass.cpp
@@ -88,7 +88,8 @@
struct OptimizeVectorTransferPass
: public OptimizeVectorTransferBase<OptimizeVectorTransferPass> {
- OptimizeVectorTransferPass(bool flatten) : flatten(flatten) {}
+ OptimizeVectorTransferPass(bool flatten, bool dropUnitDims)
+ : flatten(flatten), dropUnitDims(dropUnitDims) {}
void runOnOperation() override {
func::FuncOp funcOp = getOperation();
// Generate vector.shape_cast for dropping leading one dimensions in vector
@@ -125,10 +126,20 @@
}
}
+ // TODO(#14191): SPIR-V can't handle the vector.shape_cast ops created when
+ // dropping unit dims, so this option is disabled in the SPIR-V pipelines.
+ // The option should go away once all backend issues have been resolved.
+ if (dropUnitDims) {
+ RewritePatternSet patterns(&getContext());
+ mlir::vector::populateVectorTransferDropUnitDimsPatterns(patterns);
+ if (failed(applyPatternsAndFoldGreedily(funcOp, std::move(patterns)))) {
+ return signalPassFailure();
+ }
+ }
+
// Second stage of patterns to flatten transfer ops.
if (flatten) {
RewritePatternSet patterns(&getContext());
- mlir::vector::populateVectorTransferDropUnitDimsPatterns(patterns);
mlir::vector::populateFlattenVectorTransferPatterns(patterns);
if (failed(applyPatternsAndFoldGreedily(funcOp, std::move(patterns)))) {
return signalPassFailure();
@@ -151,13 +162,14 @@
private:
bool flatten;
+ bool dropUnitDims;
};
} // namespace
std::unique_ptr<OperationPass<func::FuncOp>> createOptimizeVectorTransferPass(
- bool flatten) {
- return std::make_unique<OptimizeVectorTransferPass>(flatten);
+ bool flatten, bool dropUnitDims) {
+ return std::make_unique<OptimizeVectorTransferPass>(flatten, dropUnitDims);
}
} // namespace iree_compiler
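With the unit-dim patterns pulled into their own stage above, the flattening stage only needs mlir::vector::populateFlattenVectorTransferPatterns, which collapses contiguous row-major transfers into 1-D form. A minimal sketch of that rewrite, assuming a contiguous memref (illustrative IR only):

    // Before: a 2-D transfer over contiguous row-major memory.
    %v = vector.transfer_read %m[%c0, %c0], %pad {in_bounds = [true, true]}
           : memref<4x8xf32>, vector<4x8xf32>

    // After: collapse the memref, do a 1-D transfer, then shape_cast back.
    %flat = memref.collapse_shape %m [[0, 1]] : memref<4x8xf32> into memref<32xf32>
    %r    = vector.transfer_read %flat[%c0], %pad {in_bounds = [true]}
              : memref<32xf32>, vector<32xf32>
    %v    = vector.shape_cast %r : vector<32xf32> to vector<4x8xf32>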
diff --git a/compiler/src/iree/compiler/Codegen/LLVMCPU/Passes.cpp b/compiler/src/iree/compiler/Codegen/LLVMCPU/Passes.cpp
index 481fc33..22e5601 100644
--- a/compiler/src/iree/compiler/Codegen/LLVMCPU/Passes.cpp
+++ b/compiler/src/iree/compiler/Codegen/LLVMCPU/Passes.cpp
@@ -552,12 +552,16 @@
nestedModulePM.addNestedPass<func::FuncOp>(createCSEPass());
}
- nestedModulePM.addNestedPass<func::FuncOp>(createCSEPass());
- nestedModulePM.addNestedPass<func::FuncOp>(createCanonicalizerPass());
+ // Eliminate redundant transfer_read/write to avoid stack allocations.
nestedModulePM.addNestedPass<func::FuncOp>(
createOptimizeVectorTransferPass(/*flatten=*/true));
+
addBufferizePasses(nestedModulePM);
+ // Perform memref-based transfer_read/write optimizations.
+ nestedModulePM.addNestedPass<func::FuncOp>(
+ createOptimizeVectorTransferPass(/*flatten=*/false));
+
// Run IREE specific passes before vector lowering expert.
nestedModulePM.addNestedPass<func::FuncOp>(
createRemoveSingleIterationLoopPass());
diff --git a/compiler/src/iree/compiler/Codegen/LLVMCPU/test/pad_pipeline_tests.mlir b/compiler/src/iree/compiler/Codegen/LLVMCPU/test/pad_pipeline_tests.mlir
index 182021c..1c89a6a 100644
--- a/compiler/src/iree/compiler/Codegen/LLVMCPU/test/pad_pipeline_tests.mlir
+++ b/compiler/src/iree/compiler/Codegen/LLVMCPU/test/pad_pipeline_tests.mlir
@@ -217,4 +217,5 @@
// CHECK: scf.yield
// CHECK: scf.yield
// CHECK: scf.yield
-// CHECK-COUNT-7: vector.store %{{.+}}, %[[OUTPUT_SUBVIEW_0]]
+// CHECK: %[[OUTPUT_SUBVIEW_1:.+]] = memref.subview %[[OUTPUT_SUBVIEW_0]]
+// CHECK-COUNT-7: vector.store %{{.+}}, %[[OUTPUT_SUBVIEW_1]]
diff --git a/compiler/src/iree/compiler/Codegen/Passes.td b/compiler/src/iree/compiler/Codegen/Passes.td
index a4bfa03..08f6c79 100644
--- a/compiler/src/iree/compiler/Codegen/Passes.td
+++ b/compiler/src/iree/compiler/Codegen/Passes.td
@@ -236,7 +236,9 @@
let constructor = "mlir::iree_compiler::createOptimizeVectorTransferPass()";
let options = [
Option<"optionFlatten", "flatten", "bool", "false",
- "Flatten the vector type of vector transfers where possible (contiguous row-major data).">
+ "Flatten the vector type of vector transfers where possible (contiguous row-major data).">,
+ Option<"optionDropUnitDims", "drop-unit-dims", "bool", /*default=*/"true",
+ "Drop unit dims in vector transfers where possible (might generate vector.shape_cast).">,
];
let dependentDialects = [
"memref::MemRefDialect"
diff --git a/compiler/src/iree/compiler/Codegen/SPIRV/Passes.cpp b/compiler/src/iree/compiler/Codegen/SPIRV/Passes.cpp
index b220e3e..16e2321 100644
--- a/compiler/src/iree/compiler/Codegen/SPIRV/Passes.cpp
+++ b/compiler/src/iree/compiler/Codegen/SPIRV/Passes.cpp
@@ -208,14 +208,16 @@
pm.addPass(createSPIRVVectorizeLoadStore());
// Perform various vector-level cross-op optimizations like load-store
// forwarding, shape casting and casting op cancelling.
- pm.addNestedPass<func::FuncOp>(createOptimizeVectorTransferPass());
+ pm.addNestedPass<func::FuncOp>(createOptimizeVectorTransferPass(
+ /*flatten=*/false, /*dropUnitDims=*/false));
pm.addNestedPass<func::FuncOp>(createSPIRVBreakDownLargeVectorPass());
// Perform optimizations that need to work across the scf.for region boundary.
pm.addNestedPass<func::FuncOp>(createForOpCanonicalizationPass());
pm.addPass(createCanonicalizerPass());
pm.addPass(createCSEPass());
- pm.addNestedPass<func::FuncOp>(createOptimizeVectorTransferPass());
+ pm.addNestedPass<func::FuncOp>(createOptimizeVectorTransferPass(
+ /*flatten=*/false, /*dropUnitDims=*/false));
// Turn multi-dimensional memrefs into one-dimensional ones. This is needed for SPIR-V
// because we don't use upstream memref descriptors.
@@ -311,8 +313,8 @@
// Perform various vector-level cross-op optimizations like load-store
// forwarding, shape casting and casting op cancelling.
- nestedModulePM.addNestedPass<func::FuncOp>(
- createOptimizeVectorTransferPass());
+ nestedModulePM.addNestedPass<func::FuncOp>(createOptimizeVectorTransferPass(
+ /*flatten=*/false, /*dropUnitDims=*/false));
}
void addSPIRVCooperativeMatrixVectorizePassPipeline(OpPassManager &pm,
@@ -370,8 +372,8 @@
// Perform various vector-level cross-op optimizations like load-store
// forwarding, shape casting and casting op cancelling.
- nestedModulePM.addNestedPass<func::FuncOp>(
- createOptimizeVectorTransferPass());
+ nestedModulePM.addNestedPass<func::FuncOp>(createOptimizeVectorTransferPass(
+ /*flatten=*/false, /*dropUnitDims=*/false));
// Folding subview ops is required for converting vector transfer ops into SPIR-V
// cooperative ops in the next step.
@@ -445,10 +447,12 @@
// to hoisting. Because this is before folding all memref subview ops away, we
// still have subview ops using the same indices, which allows for transfer
// read/write forwarding.
- nestedPM.addNestedPass<func::FuncOp>(createOptimizeVectorTransferPass());
+ nestedPM.addNestedPass<func::FuncOp>(createOptimizeVectorTransferPass(
+ /*flatten=*/false, /*dropUnitDims=*/false));
nestedPM.addNestedPass<func::FuncOp>(memref::createFoldMemRefAliasOpsPass());
- nestedPM.addNestedPass<func::FuncOp>(createOptimizeVectorTransferPass());
+ nestedPM.addNestedPass<func::FuncOp>(createOptimizeVectorTransferPass(
+ /*flatten=*/false, /*dropUnitDims=*/false));
// Hoist loop invariant code to avoid pipelining it.
nestedPM.addNestedPass<func::FuncOp>(createLoopInvariantCodeMotionPass());
@@ -506,8 +510,8 @@
// Perform various vector-level cross-op optimizations like load-store
// forwarding, shape casting and casting op cancelling.
- nestedModulePM.addNestedPass<func::FuncOp>(
- createOptimizeVectorTransferPass());
+ nestedModulePM.addNestedPass<func::FuncOp>(createOptimizeVectorTransferPass(
+ /*flatten=*/false, /*dropUnitDims=*/false));
// Simplify the IR for vector distribution.
nestedModulePM.addNestedPass<func::FuncOp>(
@@ -567,8 +571,8 @@
// Perform various vector-level cross-op optimizations like load-store
// forwarding, shape casting and casting op cancelling.
- nestedModulePM.addNestedPass<func::FuncOp>(
- createOptimizeVectorTransferPass());
+ nestedModulePM.addNestedPass<func::FuncOp>(createOptimizeVectorTransferPass(
+ /*flatten=*/false, /*dropUnitDims=*/false));
}
void addSPIRVTransformDialectPassPipeline(OpPassManager &pm) {