[spirv] NFC: Restructure the vector lowering pass (#15134)
This commit breaks the `SPIRVVectorLowering` pass into two parts, a
`SPIRVInitialVectorLowering` pass and a `SPIRVFinalVectorLowering` pass.
In the meanwhile, use the common hoisting pass for the hoisting.
Progress towards https://github.com/openxla/iree/issues/15083
diff --git a/compiler/src/iree/compiler/Codegen/SPIRV/BUILD.bazel b/compiler/src/iree/compiler/Codegen/SPIRV/BUILD.bazel
index edb009b..23ac1dc 100644
--- a/compiler/src/iree/compiler/Codegen/SPIRV/BUILD.bazel
+++ b/compiler/src/iree/compiler/Codegen/SPIRV/BUILD.bazel
@@ -61,14 +61,15 @@
"SPIRVDistribute.cpp",
"SPIRVEmulateI64.cpp",
"SPIRVEraseStorageBufferStaticShape.cpp",
+ "SPIRVFinalVectorLowering.cpp",
"SPIRVGeneralizeNamedOps.cpp",
+ "SPIRVInitialVectorLowering.cpp",
"SPIRVLowerExecutableTargetPass.cpp",
"SPIRVMapMemRefStorageClass.cpp",
"SPIRVTile.cpp",
"SPIRVTileAndDistribute.cpp",
"SPIRVTileAndPromote.cpp",
"SPIRVTileAndVectorizeToCooperativeOps.cpp",
- "SPIRVVectorLowering.cpp",
"SPIRVVectorToGPUSubgroupMMAOps.cpp",
"SPIRVVectorizeLoadStore.cpp",
"Utils.cpp",
diff --git a/compiler/src/iree/compiler/Codegen/SPIRV/CMakeLists.txt b/compiler/src/iree/compiler/Codegen/SPIRV/CMakeLists.txt
index e212853..e722a23 100644
--- a/compiler/src/iree/compiler/Codegen/SPIRV/CMakeLists.txt
+++ b/compiler/src/iree/compiler/Codegen/SPIRV/CMakeLists.txt
@@ -60,14 +60,15 @@
"SPIRVDistribute.cpp"
"SPIRVEmulateI64.cpp"
"SPIRVEraseStorageBufferStaticShape.cpp"
+ "SPIRVFinalVectorLowering.cpp"
"SPIRVGeneralizeNamedOps.cpp"
+ "SPIRVInitialVectorLowering.cpp"
"SPIRVLowerExecutableTargetPass.cpp"
"SPIRVMapMemRefStorageClass.cpp"
"SPIRVTile.cpp"
"SPIRVTileAndDistribute.cpp"
"SPIRVTileAndPromote.cpp"
"SPIRVTileAndVectorizeToCooperativeOps.cpp"
- "SPIRVVectorLowering.cpp"
"SPIRVVectorToGPUSubgroupMMAOps.cpp"
"SPIRVVectorizeLoadStore.cpp"
"Utils.cpp"
diff --git a/compiler/src/iree/compiler/Codegen/SPIRV/Passes.cpp b/compiler/src/iree/compiler/Codegen/SPIRV/Passes.cpp
index 129e429..f3efc4d 100644
--- a/compiler/src/iree/compiler/Codegen/SPIRV/Passes.cpp
+++ b/compiler/src/iree/compiler/Codegen/SPIRV/Passes.cpp
@@ -127,6 +127,14 @@
nestedModulePM.addPass(createCSEPass());
}
+/// Adds passes to lower vector ops to meet SPIR-V requirements.
+static void addSPIRVVectorLoweringPasses(OpPassManager &modulePM) {
+ modulePM.addNestedPass<func::FuncOp>(createSPIRVInitialVectorLoweringPass());
+ modulePM.addNestedPass<func::FuncOp>(
+ createHoistRedundantVectorTransfersPass());
+ modulePM.addNestedPass<func::FuncOp>(createSPIRVFinalVectorLoweringPass());
+}
+
static void
addSPIRVBufferizePasses(OpPassManager &passManager,
BufferizationOptions::AllocationFn allocationFn) {
@@ -298,7 +306,7 @@
nestedModulePM.addNestedPass<func::FuncOp>(
createGenericVectorizationPass(options));
}
- nestedModulePM.addNestedPass<func::FuncOp>(createSPIRVVectorLoweringPass());
+ addSPIRVVectorLoweringPasses(nestedModulePM);
nestedModulePM.addNestedPass<func::FuncOp>(createForOpCanonicalizationPass());
nestedModulePM.addPass(createCanonicalizerPass());
nestedModulePM.addPass(createCSEPass());
@@ -384,7 +392,7 @@
createSPIRVVectorToGPUSubgroupMMAOpsPass());
nestedModulePM.addPass(createCanonicalizerPass());
nestedModulePM.addPass(createCSEPass());
- nestedModulePM.addNestedPass<func::FuncOp>(createSPIRVVectorLoweringPass());
+ addSPIRVVectorLoweringPasses(nestedModulePM);
if (pipelineDepth > 0) {
PipeliningSchedulingStrategy schedule =
@@ -446,7 +454,7 @@
nestedPM.addNestedPass<func::FuncOp>(createGPUReduceSharedMemoryBankConflicts(
detail::bankConflictReductionPaddingBits));
- nestedPM.addNestedPass<func::FuncOp>(createSPIRVVectorLoweringPass());
+ addSPIRVVectorLoweringPasses(nestedPM);
nestedPM.addNestedPass<func::FuncOp>(createForOpCanonicalizationPass());
nestedPM.addPass(createCanonicalizerPass());
nestedPM.addPass(createCSEPass());
@@ -574,7 +582,7 @@
createConvertVectorReductionToGPUPass(getWarpSize));
// Perform normal vector unrolling and lowering transformations. This breaks
// vectors down to native machine size.
- nestedModulePM.addNestedPass<func::FuncOp>(createSPIRVVectorLoweringPass());
+ addSPIRVVectorLoweringPasses(nestedModulePM);
nestedModulePM.addPass(createCanonicalizerPass());
nestedModulePM.addPass(createCSEPass());
}
@@ -605,7 +613,7 @@
nestedModulePM.addNestedPass<func::FuncOp>(
createGenericVectorizationPass(options));
}
- nestedModulePM.addNestedPass<func::FuncOp>(createSPIRVVectorLoweringPass());
+ addSPIRVVectorLoweringPasses(nestedModulePM);
nestedModulePM.addNestedPass<func::FuncOp>(createForOpCanonicalizationPass());
nestedModulePM.addPass(createCanonicalizerPass());
nestedModulePM.addPass(createCSEPass());
@@ -629,7 +637,7 @@
// for SPIR-V.
auto &nestedModulePM = pm.nest<ModuleOp>();
nestedModulePM.addNestedPass<func::FuncOp>(createGenericVectorizationPass());
- nestedModulePM.addNestedPass<func::FuncOp>(createSPIRVVectorLoweringPass());
+ addSPIRVVectorLoweringPasses(nestedModulePM);
}
//===----------------------------------------------------------------------===//
diff --git a/compiler/src/iree/compiler/Codegen/SPIRV/Passes.h b/compiler/src/iree/compiler/Codegen/SPIRV/Passes.h
index faf2d9f..e762e11 100644
--- a/compiler/src/iree/compiler/Codegen/SPIRV/Passes.h
+++ b/compiler/src/iree/compiler/Codegen/SPIRV/Passes.h
@@ -136,8 +136,11 @@
/// having pointer bitcast.
std::unique_ptr<OperationPass<ModuleOp>> createSPIRVVectorizeLoadStore();
-/// Pass to vectorize Linalg ops with buffer semantics.
-std::unique_ptr<OperationPass<func::FuncOp>> createSPIRVVectorLoweringPass();
+/// Pass to lower vector ops to meet SPIR-V requirements.
+std::unique_ptr<OperationPass<func::FuncOp>>
+createSPIRVInitialVectorLoweringPass();
+std::unique_ptr<OperationPass<func::FuncOp>>
+createSPIRVFinalVectorLoweringPass();
/// Pass to do vectorization suitable for lowering to SPIR-V cooperative ops.
std::unique_ptr<OperationPass<func::FuncOp>>
diff --git a/compiler/src/iree/compiler/Codegen/SPIRV/Passes.td b/compiler/src/iree/compiler/Codegen/SPIRV/Passes.td
index 05ed844..e409e8b 100644
--- a/compiler/src/iree/compiler/Codegen/SPIRV/Passes.td
+++ b/compiler/src/iree/compiler/Codegen/SPIRV/Passes.td
@@ -112,9 +112,16 @@
"mlir::iree_compiler::createSPIRVTileToCooperativeOpsPass()";
}
-def SPIRVVectorLowering : Pass<"iree-spirv-vector-lowering", "func::FuncOp"> {
- let summary = "Vectorize Linalg ops with buffer semantics";
- let constructor = "mlir::iree_compiler::createSPIRVVectorLoweringPass()";
+def SPIRVInitialVectorLowering : Pass<
+ "iree-spirv-initial-vector-lowering", "func::FuncOp"> {
+ let summary = "Perform initial lowering of vectors ops to fit SPIR-V";
+ let constructor = "mlir::iree_compiler::createSPIRVInitialVectorLoweringPass()";
+}
+
+def SPIRVFinalVectorLowering : Pass<
+ "iree-spirv-final-vector-lowering", "func::FuncOp"> {
+ let summary = "Perform final lowering of vectors ops to fit SPIR-V";
+ let constructor = "mlir::iree_compiler::createSPIRVFinalVectorLoweringPass()";
}
def SPIRVVectorizeLoadStore :
diff --git a/compiler/src/iree/compiler/Codegen/SPIRV/SPIRVFinalVectorLowering.cpp b/compiler/src/iree/compiler/Codegen/SPIRV/SPIRVFinalVectorLowering.cpp
new file mode 100644
index 0000000..a6e0e81
--- /dev/null
+++ b/compiler/src/iree/compiler/Codegen/SPIRV/SPIRVFinalVectorLowering.cpp
@@ -0,0 +1,115 @@
+// Copyright 2023 The IREE Authors
+//
+// Licensed under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+
+//===- SPIRVFinalVectorLowering.cpp ---------------------------------------===//
+//
+// This pass hosts final steps towards lowering vectors ops to meet SPIR-V
+// requirements--it applies vector lowering patterns to convert vector ops
+// to more basic forms.
+//
+//===----------------------------------------------------------------------===//
+
+#include "iree/compiler/Codegen/Common/Passes.h"
+#include "iree/compiler/Codegen/SPIRV/PassDetail.h"
+#include "iree/compiler/Codegen/SPIRV/Passes.h"
+#include "iree/compiler/Codegen/Transforms/Transforms.h"
+#include "llvm/Support/Debug.h"
+#include "mlir/Dialect/SCF/IR/SCF.h"
+#include "mlir/Dialect/Vector/IR/VectorOps.h"
+#include "mlir/Dialect/Vector/Transforms/LoweringPatterns.h"
+#include "mlir/Dialect/Vector/Transforms/VectorRewritePatterns.h"
+#include "mlir/Dialect/Vector/Transforms/VectorTransforms.h"
+#include "mlir/IR/TypeUtilities.h"
+#include "mlir/Interfaces/VectorInterfaces.h"
+#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
+
+#define DEBUG_TYPE "iree-spirv-final-vector-lowering"
+
+namespace mlir {
+namespace iree_compiler {
+namespace {
+
+void debugPrint(func::FuncOp funcOp, const char *message) {
+ LLVM_DEBUG({
+ llvm::dbgs() << "//--- " << message << " ---//\n";
+ funcOp.print(llvm::dbgs(), OpPrintingFlags().useLocalScope());
+ llvm::dbgs() << "\n\n";
+ });
+}
+
+class SPIRVFinalVectorLoweringPass
+ : public SPIRVFinalVectorLoweringBase<SPIRVFinalVectorLoweringPass> {
+public:
+ void getDependentDialects(DialectRegistry ®istry) const override {
+ // vector.gather lowering patterns target scf ops.
+ registry.insert<scf::SCFDialect, vector::VectorDialect>();
+ }
+
+ void runOnOperation() override {
+ MLIRContext *context = &getContext();
+ func::FuncOp funcOp = getOperation();
+
+ // Lower vector transfer permutation map.
+ {
+ RewritePatternSet patterns(context);
+ vector::ExtractStridedSliceOp::getCanonicalizationPatterns(patterns,
+ context);
+ vector::populateVectorTransferPermutationMapLoweringPatterns(patterns);
+ if (failed(applyPatternsAndFoldGreedily(funcOp, std::move(patterns)))) {
+ return signalPassFailure();
+ }
+ }
+
+ debugPrint(funcOp, "after lowering transfer ops");
+
+ // Lower vector broadcast/transpose and contraction.
+ {
+ RewritePatternSet patterns(context);
+ auto options = vector::VectorTransformsOptions()
+ .setVectorTransformsOptions(
+ vector::VectorContractLowering::OuterProduct)
+ .setVectorTransposeLowering(
+ vector::VectorTransposeLowering::EltWise);
+ vector::populateVectorBroadcastLoweringPatterns(patterns);
+ vector::populateVectorContractLoweringPatterns(patterns, options);
+ vector::populateVectorMultiReductionLoweringPatterns(
+ patterns, vector::VectorMultiReductionLowering::InnerParallel);
+ vector::populateVectorTransposeLoweringPatterns(patterns, options);
+ vector::populateVectorGatherLoweringPatterns(patterns);
+ if (failed(applyPatternsAndFoldGreedily(funcOp, std::move(patterns)))) {
+ return signalPassFailure();
+ }
+ }
+
+ debugPrint(funcOp, "after lowering various vector ops");
+
+ // Run all sorts of canonicalization patterns to clean up again.
+ {
+ RewritePatternSet patterns(context);
+ vector::populateCastAwayVectorLeadingOneDimPatterns(patterns);
+ vector::InsertOp::getCanonicalizationPatterns(patterns, context);
+ vector::ExtractOp::getCanonicalizationPatterns(patterns, context);
+ vector::TransferReadOp::getCanonicalizationPatterns(patterns, context);
+ vector::TransferWriteOp::getCanonicalizationPatterns(patterns, context);
+ populateVectorTransferTensorSliceTransforms(patterns);
+ vector::ReductionOp::getCanonicalizationPatterns(patterns, context);
+ scf::IfOp::getCanonicalizationPatterns(patterns, context);
+ if (failed(applyPatternsAndFoldGreedily(funcOp, std::move(patterns)))) {
+ return signalPassFailure();
+ }
+ }
+ }
+};
+
+} // namespace
+
+std::unique_ptr<OperationPass<func::FuncOp>>
+createSPIRVFinalVectorLoweringPass() {
+ return std::make_unique<SPIRVFinalVectorLoweringPass>();
+}
+
+} // namespace iree_compiler
+} // namespace mlir
diff --git a/compiler/src/iree/compiler/Codegen/SPIRV/SPIRVVectorLowering.cpp b/compiler/src/iree/compiler/Codegen/SPIRV/SPIRVInitialVectorLowering.cpp
similarity index 75%
rename from compiler/src/iree/compiler/Codegen/SPIRV/SPIRVVectorLowering.cpp
rename to compiler/src/iree/compiler/Codegen/SPIRV/SPIRVInitialVectorLowering.cpp
index 3f6dff6..d25ec01 100644
--- a/compiler/src/iree/compiler/Codegen/SPIRV/SPIRVVectorLowering.cpp
+++ b/compiler/src/iree/compiler/Codegen/SPIRV/SPIRVInitialVectorLowering.cpp
@@ -4,29 +4,23 @@
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
-//===- SPIRVVectorLoweringPass.cpp
-//-------------------------------------------------===//
+//===- SPIRVInitialVectorLowering.cpp -------------------------------------===//
//
-// This pass vectorizes Linalg ops with buffer semantics.
+// This pass hosts initial steps towards lowering vectors ops to meet SPIR-V
+// requirements--it applies vector lowering patterns to unroll large n-D vectors
+// to 1-D ones that are directly in SPIR-V.
//
//===----------------------------------------------------------------------===//
-#include "iree-dialects/Dialect/LinalgExt/Passes/Passes.h"
-#include "iree-dialects/Dialect/LinalgExt/Transforms/Transforms.h"
#include "iree/compiler/Codegen/Common/Passes.h"
#include "iree/compiler/Codegen/SPIRV/PassDetail.h"
#include "iree/compiler/Codegen/SPIRV/Passes.h"
#include "iree/compiler/Codegen/SPIRV/Utils.h"
#include "iree/compiler/Codegen/Transforms/Transforms.h"
-#include "iree/compiler/Codegen/Utils/MarkerUtils.h"
-#include "iree/compiler/Codegen/Utils/Utils.h"
-#include "llvm/ADT/STLExtras.h"
#include "llvm/ADT/TypeSwitch.h"
#include "llvm/Support/Debug.h"
#include "mlir/Conversion/VectorToSPIRV/VectorToSPIRV.h"
#include "mlir/Dialect/Linalg/IR/Linalg.h"
-#include "mlir/Dialect/Linalg/Transforms/Hoisting.h"
-#include "mlir/Dialect/Linalg/Transforms/Transforms.h"
#include "mlir/Dialect/SCF/IR/SCF.h"
#include "mlir/Dialect/SPIRV/IR/SPIRVAttributes.h"
#include "mlir/Dialect/SPIRV/IR/TargetAndABI.h"
@@ -35,19 +29,24 @@
#include "mlir/Dialect/Vector/Transforms/LoweringPatterns.h"
#include "mlir/Dialect/Vector/Transforms/VectorRewritePatterns.h"
#include "mlir/Dialect/Vector/Transforms/VectorTransforms.h"
-#include "mlir/IR/Matchers.h"
#include "mlir/IR/TypeUtilities.h"
#include "mlir/Interfaces/VectorInterfaces.h"
-#include "mlir/Pass/Pass.h"
-#include "mlir/Support/LogicalResult.h"
#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
-#define DEBUG_TYPE "iree-spirv-vector-lowering"
+#define DEBUG_TYPE "iree-spirv-initial-vector-lowering"
namespace mlir {
namespace iree_compiler {
namespace {
+void debugPrint(func::FuncOp funcOp, const char *message) {
+ LLVM_DEBUG({
+ llvm::dbgs() << "//--- " << message << " ---//\n";
+ funcOp.print(llvm::dbgs(), OpPrintingFlags().useLocalScope());
+ llvm::dbgs() << "\n\n";
+ });
+}
+
int getComputeVectorSize(int64_t size) {
for (int i : {4, 3, 2}) {
if (size % i == 0)
@@ -286,13 +285,9 @@
return true;
}
-/// Vectorizes Linalg ops on buffer semantics.
-class SPIRVVectorLoweringPass
- : public SPIRVVectorLoweringBase<SPIRVVectorLoweringPass> {
+class SPIRVInitialLoweringPass
+ : public SPIRVInitialVectorLoweringBase<SPIRVInitialLoweringPass> {
public:
- SPIRVVectorLoweringPass() = default;
- SPIRVVectorLoweringPass(const SPIRVVectorLoweringPass &pass) = default;
-
void getDependentDialects(DialectRegistry ®istry) const override {
// vector.gather lowering patterns target scf ops.
registry.insert<linalg::LinalgDialect, vector::VectorDialect,
@@ -316,11 +311,7 @@
}
}
- LLVM_DEBUG({
- llvm::dbgs() << "--- After IREE tensor.pad vectorization ---\n";
- funcOp.print(llvm::dbgs(), OpPrintingFlags().useLocalScope());
- llvm::dbgs() << "\n\n";
- });
+ debugPrint(funcOp, "after vectorizing tensor.pad");
// Special peephole optimizations to clean up IR before further processing.
{
@@ -342,11 +333,7 @@
}
}
- LLVM_DEBUG({
- llvm::dbgs() << "--- After peephole optimization ---\n";
- funcOp.print(llvm::dbgs(), OpPrintingFlags().useLocalScope());
- llvm::dbgs() << "\n\n";
- });
+ debugPrint(funcOp, "after peephole optimization");
// High dimension contraction can appear after vectorizing ops like 1-D
// convolution. Those 1-D convolution ops typically have a leading unit
@@ -362,11 +349,7 @@
(void)vector::castAwayContractionLeadingOneDim(op, rewriter);
}
- LLVM_DEBUG({
- llvm::dbgs() << "--- After trimming contract leading unit dims ---\n";
- funcOp.print(llvm::dbgs(), OpPrintingFlags().useLocalScope());
- llvm::dbgs() << "\n\n";
- });
+ debugPrint(funcOp, "after trimming contract leading unit dims");
// Fold tensor.extract_slice/insert_slice ops into transfer ops. This helps
// to remove those tensor slice ops so that we can enable further vector op
@@ -382,11 +365,7 @@
}
}
- LLVM_DEBUG({
- llvm::dbgs() << "--- After folding tensor extract/insert slice ops ---\n";
- funcOp.print(llvm::dbgs(), OpPrintingFlags().useLocalScope());
- llvm::dbgs() << "\n\n";
- });
+ debugPrint(funcOp, "after folding tensor extract/insert slice ops");
// Lower vector.multi_dimension early if any operand is a transpose op.
// The lowering itself generates transpose ops. This helps to cancel
@@ -411,11 +390,7 @@
}
}
- LLVM_DEBUG({
- llvm::dbgs() << "--- After lowering multi_reduction ops ---\n";
- funcOp.print(llvm::dbgs(), OpPrintingFlags().useLocalScope());
- llvm::dbgs() << "\n\n";
- });
+ debugPrint(funcOp, "after lowering multi reduction ops");
// Prepare for SPIR-V integer dot product lowering.
if (emitIntegerDotProdOps) {
@@ -426,12 +401,7 @@
return signalPassFailure();
}
- LLVM_DEBUG({
- llvm::dbgs()
- << "--- After prepare for SPIR-V dot product lowering ---\n";
- funcOp.print(llvm::dbgs(), OpPrintingFlags().useLocalScope());
- llvm::dbgs() << "\n\n";
- });
+ debugPrint(funcOp, "after preparing for SPIR-V dot product lowering");
}
// Then unroll vectors to native vector size. We try to use 128-bit
@@ -444,11 +414,7 @@
}
}
- LLVM_DEBUG({
- llvm::dbgs() << "--- After unrolling vector ---\n";
- funcOp.print(llvm::dbgs(), OpPrintingFlags().useLocalScope());
- llvm::dbgs() << "\n\n";
- });
+ debugPrint(funcOp, "after unrolling vector ops");
// Lower reduction-unrolled vector contract ops. Such contract ops have
// their reduction dimensions all be one, so we can convert them into
@@ -470,11 +436,7 @@
}
}
- LLVM_DEBUG({
- llvm::dbgs() << "--- After lowering size-1 reduction contract ops ---\n";
- funcOp.print(llvm::dbgs(), OpPrintingFlags().useLocalScope());
- llvm::dbgs() << "\n\n";
- });
+ debugPrint(funcOp, "after lowering size-1 reduction contract ops");
// Now lower vector transpose given we have handled vector patterns that may
// generate transpose ops in previous steps. This converts transpose ops
@@ -491,11 +453,7 @@
}
}
- LLVM_DEBUG({
- llvm::dbgs() << "--- After lowering transpose ops ---\n";
- funcOp.print(llvm::dbgs(), OpPrintingFlags().useLocalScope());
- llvm::dbgs() << "\n\n";
- });
+ debugPrint(funcOp, "after lowering transpose ops");
// Next run canonicalization to cast away leading size-1 dimensions. They
// can be generated from vector unrolling and generally cause issues to
@@ -531,11 +489,7 @@
}
}
- LLVM_DEBUG({
- llvm::dbgs() << "--- After trimming leading unit dims ---\n";
- funcOp.print(llvm::dbgs(), OpPrintingFlags().useLocalScope());
- llvm::dbgs() << "\n\n";
- });
+ debugPrint(funcOp, "after trimming leading unit dims");
// Lower vector reduction to SPIR-V integer dot product.
if (emitIntegerDotProdOps) {
@@ -545,89 +499,16 @@
return signalPassFailure();
}
- LLVM_DEBUG({
- llvm::dbgs() << "--- After lowering to SPIR-V dot product ---\n";
- funcOp.print(llvm::dbgs(), OpPrintingFlags().useLocalScope());
- llvm::dbgs() << "\n\n";
- });
- }
-
- // Next perform hoisting. This would analyze transfer read/write ops into
- // tensors and hoist them out of loop nests. So after it we have
- // loop-carried vectors, not loop-carried tensors anymore.
- linalg::hoistRedundantVectorTransfersOnTensor(funcOp);
- linalg::hoistRedundantVectorTransfers(funcOp);
-
- LLVM_DEBUG({
- llvm::dbgs() << "--- After hoisting transfers ---\n";
- funcOp.print(llvm::dbgs(), OpPrintingFlags().useLocalScope());
- llvm::dbgs() << "\n\n";
- });
-
- // Lower vector transfer permutation map.
- {
- RewritePatternSet patterns(context);
- vector::ExtractStridedSliceOp::getCanonicalizationPatterns(patterns,
- context);
- vector::populateVectorTransferPermutationMapLoweringPatterns(patterns);
- if (failed(applyPatternsAndFoldGreedily(funcOp, std::move(patterns)))) {
- return signalPassFailure();
- }
- }
-
- LLVM_DEBUG({
- llvm::dbgs() << "--- After lowering transfer ops ---\n";
- funcOp.print(llvm::dbgs(), OpPrintingFlags().useLocalScope());
- llvm::dbgs() << "\n\n";
- });
-
- // Lower vector broadcast/transpose and contraction.
- {
- RewritePatternSet patterns(context);
- auto options = vector::VectorTransformsOptions()
- .setVectorTransformsOptions(
- vector::VectorContractLowering::OuterProduct)
- .setVectorTransposeLowering(
- vector::VectorTransposeLowering::EltWise);
- vector::populateVectorBroadcastLoweringPatterns(patterns);
- vector::populateVectorContractLoweringPatterns(patterns, options);
- vector::populateVectorMultiReductionLoweringPatterns(
- patterns, vector::VectorMultiReductionLowering::InnerParallel);
- vector::populateVectorTransposeLoweringPatterns(patterns, options);
- vector::populateVectorGatherLoweringPatterns(patterns);
- if (failed(applyPatternsAndFoldGreedily(funcOp, std::move(patterns)))) {
- return signalPassFailure();
- }
- }
-
- LLVM_DEBUG({
- llvm::dbgs() << "--- After lowering various vector ops ---\n";
- funcOp.print(llvm::dbgs(), OpPrintingFlags().useLocalScope());
- llvm::dbgs() << "\n\n";
- });
-
- // Run all sorts of canonicalization patterns to clean up again.
- {
- RewritePatternSet patterns(context);
- vector::populateCastAwayVectorLeadingOneDimPatterns(patterns);
- vector::InsertOp::getCanonicalizationPatterns(patterns, context);
- vector::ExtractOp::getCanonicalizationPatterns(patterns, context);
- vector::TransferReadOp::getCanonicalizationPatterns(patterns, context);
- vector::TransferWriteOp::getCanonicalizationPatterns(patterns, context);
- populateVectorTransferTensorSliceTransforms(patterns);
- vector::ReductionOp::getCanonicalizationPatterns(patterns, context);
- scf::IfOp::getCanonicalizationPatterns(patterns, context);
- if (failed(applyPatternsAndFoldGreedily(funcOp, std::move(patterns)))) {
- return signalPassFailure();
- }
+ debugPrint(funcOp, "after lowering to SPIR-V dot product");
}
}
};
} // namespace
-std::unique_ptr<OperationPass<func::FuncOp>> createSPIRVVectorLoweringPass() {
- return std::make_unique<SPIRVVectorLoweringPass>();
+std::unique_ptr<OperationPass<func::FuncOp>>
+createSPIRVInitialVectorLoweringPass() {
+ return std::make_unique<SPIRVInitialLoweringPass>();
}
} // namespace iree_compiler
diff --git a/compiler/src/iree/compiler/Codegen/SPIRV/test/tile_and_vectorize_batch_matmul.mlir b/compiler/src/iree/compiler/Codegen/SPIRV/test/tile_and_vectorize_batch_matmul.mlir
index 1fca89c..76efc45 100644
--- a/compiler/src/iree/compiler/Codegen/SPIRV/test/tile_and_vectorize_batch_matmul.mlir
+++ b/compiler/src/iree/compiler/Codegen/SPIRV/test/tile_and_vectorize_batch_matmul.mlir
@@ -1,5 +1,5 @@
// RUN: iree-opt --split-input-file \
-// RUN: --pass-pipeline='builtin.module(hal.executable(hal.executable.variant(builtin.module(func.func(iree-spirv-tile,canonicalize,cse,iree-codegen-generic-vectorization,iree-spirv-vector-lowering,canonicalize,cse)))))' \
+// RUN: --pass-pipeline='builtin.module(hal.executable(hal.executable.variant(builtin.module(func.func(iree-spirv-tile,canonicalize,cse,iree-codegen-generic-vectorization,iree-spirv-initial-vector-lowering,iree-codegen-hoist-redundant-vector-transfers,iree-spirv-final-vector-lowering,canonicalize,cse)))))' \
// RUN: %s | FileCheck %s
#config = #iree_codegen.lowering_config<tile_sizes = [[1, 8, 64], [1, 8, 4], [0, 0, 0, 4]]>
diff --git a/compiler/src/iree/compiler/Codegen/SPIRV/test/tile_and_vectorize_conv.mlir b/compiler/src/iree/compiler/Codegen/SPIRV/test/tile_and_vectorize_conv.mlir
index f460f90..dcf863b 100644
--- a/compiler/src/iree/compiler/Codegen/SPIRV/test/tile_and_vectorize_conv.mlir
+++ b/compiler/src/iree/compiler/Codegen/SPIRV/test/tile_and_vectorize_conv.mlir
@@ -1,5 +1,5 @@
// RUN: iree-opt --split-input-file \
-// RUN: --pass-pipeline='builtin.module(hal.executable(hal.executable.variant(builtin.module(func.func(iree-spirv-create-fast-slow-path,iree-spirv-tile,canonicalize,cse,iree-codegen-generic-vectorization,iree-spirv-vector-lowering,canonicalize,cse)))))' \
+// RUN: --pass-pipeline='builtin.module(hal.executable(hal.executable.variant(builtin.module(func.func(iree-spirv-create-fast-slow-path,iree-spirv-tile,canonicalize,cse,iree-codegen-generic-vectorization,iree-spirv-initial-vector-lowering,iree-codegen-hoist-redundant-vector-transfers,iree-spirv-final-vector-lowering,canonicalize,cse)))))' \
// RUN: %s | FileCheck %s
#config = #iree_codegen.lowering_config<tile_sizes = [[0, 4, 4, 16], [0, 2, 2, 4], [0, 0, 0, 0, 1, 1, 4], [0, 1, 0, 0]]>
diff --git a/compiler/src/iree/compiler/Codegen/SPIRV/test/tile_and_vectorize_matmul.mlir b/compiler/src/iree/compiler/Codegen/SPIRV/test/tile_and_vectorize_matmul.mlir
index 81e14fb..e82a5c3 100644
--- a/compiler/src/iree/compiler/Codegen/SPIRV/test/tile_and_vectorize_matmul.mlir
+++ b/compiler/src/iree/compiler/Codegen/SPIRV/test/tile_and_vectorize_matmul.mlir
@@ -1,5 +1,5 @@
// RUN: iree-opt --split-input-file \
-// RUN: --pass-pipeline='builtin.module(hal.executable(hal.executable.variant(builtin.module(func.func(iree-spirv-tile,canonicalize,cse,iree-codegen-generic-vectorization,iree-spirv-vector-lowering,canonicalize,cse)))))' %s | FileCheck %s
+// RUN: --pass-pipeline='builtin.module(hal.executable(hal.executable.variant(builtin.module(func.func(iree-spirv-tile,canonicalize,cse,iree-codegen-generic-vectorization,iree-spirv-initial-vector-lowering,iree-codegen-hoist-redundant-vector-transfers,iree-spirv-final-vector-lowering,canonicalize,cse)))))' %s | FileCheck %s
#config = #iree_codegen.lowering_config<tile_sizes = [[8, 64], [8, 4], [0, 0, 4]]>
#translation = #iree_codegen.translation_info<SPIRVBaseVectorize>
diff --git a/compiler/src/iree/compiler/Codegen/SPIRV/test/tile_and_vectorize_pooling.mlir b/compiler/src/iree/compiler/Codegen/SPIRV/test/tile_and_vectorize_pooling.mlir
index dcd40fd..6587c3a 100644
--- a/compiler/src/iree/compiler/Codegen/SPIRV/test/tile_and_vectorize_pooling.mlir
+++ b/compiler/src/iree/compiler/Codegen/SPIRV/test/tile_and_vectorize_pooling.mlir
@@ -1,5 +1,5 @@
// RUN: iree-opt --split-input-file \
-// RUN: --pass-pipeline='builtin.module(hal.executable(hal.executable.variant(builtin.module(func.func(iree-spirv-tile,canonicalize,cse,iree-codegen-generic-vectorization,iree-spirv-vector-lowering,canonicalize,cse)))))' \
+// RUN: --pass-pipeline='builtin.module(hal.executable(hal.executable.variant(builtin.module(func.func(iree-spirv-tile,canonicalize,cse,iree-codegen-generic-vectorization,iree-spirv-initial-vector-lowering,iree-codegen-hoist-redundant-vector-transfers,iree-spirv-final-vector-lowering,canonicalize,cse)))))' \
// RUN: %s | FileCheck %s
#config = #iree_codegen.lowering_config<tile_sizes = [[0, 2, 2, 8], [0, 1, 1, 4], [0, 0, 0, 0, 1, 1], [0, 1, 0, 0]]>
diff --git a/compiler/src/iree/compiler/Codegen/SPIRV/test/vectorize_conv.mlir b/compiler/src/iree/compiler/Codegen/SPIRV/test/vectorize_conv.mlir
index 2cf3d6b..60e1157 100644
--- a/compiler/src/iree/compiler/Codegen/SPIRV/test/vectorize_conv.mlir
+++ b/compiler/src/iree/compiler/Codegen/SPIRV/test/vectorize_conv.mlir
@@ -1,5 +1,5 @@
// RUN: iree-opt --split-input-file \
-// RUN: --pass-pipeline='builtin.module(func.func(iree-codegen-generic-vectorization,iree-spirv-vector-lowering))' \
+// RUN: --pass-pipeline='builtin.module(func.func(iree-codegen-generic-vectorization,iree-spirv-initial-vector-lowering,iree-codegen-hoist-redundant-vector-transfers,iree-spirv-final-vector-lowering))' \
// RUN: %s | FileCheck %s
func.func @ncw_conv_1d(%input: tensor<2x4x4xf32>, %filter: tensor<4x4x1xf32>, %init: tensor<2x4x4xf32>) -> tensor<2x4x4xf32> {
diff --git a/compiler/src/iree/compiler/Codegen/SPIRV/test/vectorize_elementwise_ops.mlir b/compiler/src/iree/compiler/Codegen/SPIRV/test/vectorize_elementwise_ops.mlir
index e4e3dd2..81ad85d 100644
--- a/compiler/src/iree/compiler/Codegen/SPIRV/test/vectorize_elementwise_ops.mlir
+++ b/compiler/src/iree/compiler/Codegen/SPIRV/test/vectorize_elementwise_ops.mlir
@@ -1,5 +1,5 @@
// RUN: iree-opt --split-input-file \
-// RUN: --pass-pipeline='builtin.module(func.func(iree-codegen-generic-vectorization,iree-spirv-vector-lowering))' \
+// RUN: --pass-pipeline='builtin.module(func.func(iree-codegen-generic-vectorization,iree-spirv-initial-vector-lowering,iree-codegen-hoist-redundant-vector-transfers,iree-spirv-final-vector-lowering))' \
// RUN: %s | FileCheck %s
func.func @add(%lhs: tensor<2x8xf32>, %rhs: tensor<2x8xf32>) -> tensor<2x8xf32> {
diff --git a/compiler/src/iree/compiler/Codegen/SPIRV/test/vectorize_gather.mlir b/compiler/src/iree/compiler/Codegen/SPIRV/test/vectorize_gather.mlir
index a3e3b92..a2336b3 100644
--- a/compiler/src/iree/compiler/Codegen/SPIRV/test/vectorize_gather.mlir
+++ b/compiler/src/iree/compiler/Codegen/SPIRV/test/vectorize_gather.mlir
@@ -1,5 +1,5 @@
// RUN: iree-opt --split-input-file \
-// RUN: --pass-pipeline='builtin.module(func.func(iree-codegen-generic-vectorization{vectorize-gather-accesses},iree-spirv-vector-lowering))' \
+// RUN: --pass-pipeline='builtin.module(func.func(iree-codegen-generic-vectorization{vectorize-gather-accesses},iree-spirv-initial-vector-lowering,iree-codegen-hoist-redundant-vector-transfers,iree-spirv-final-vector-lowering))' \
// RUN: %s | FileCheck %s
func.func @tensor_extract(%arg0: tensor<6x4xf32>, %arg1: tensor<6xi32>, %data: tensor<1x2x512xf32>, %init: tensor<6x4xf32>, %i : index, %j: index) -> tensor<6x4xf32> {
diff --git a/compiler/src/iree/compiler/Codegen/SPIRV/test/vectorize_matmul.mlir b/compiler/src/iree/compiler/Codegen/SPIRV/test/vectorize_matmul.mlir
index ab2f8b2..0cd0efd 100644
--- a/compiler/src/iree/compiler/Codegen/SPIRV/test/vectorize_matmul.mlir
+++ b/compiler/src/iree/compiler/Codegen/SPIRV/test/vectorize_matmul.mlir
@@ -1,5 +1,5 @@
// RUN: iree-opt --split-input-file \
-// RUN: --pass-pipeline='builtin.module(func.func(iree-codegen-generic-vectorization,iree-spirv-vector-lowering))' \
+// RUN: --pass-pipeline='builtin.module(func.func(iree-codegen-generic-vectorization,iree-spirv-initial-vector-lowering,iree-codegen-hoist-redundant-vector-transfers,iree-spirv-final-vector-lowering))' \
// RUN: %s | FileCheck %s
func.func @matmul_1x4x4(%lhs: tensor<1x4xf32>, %rhs: tensor<4x4xf32>, %init: tensor<1x4xf32>) -> tensor<1x4xf32> {
diff --git a/compiler/src/iree/compiler/Codegen/SPIRV/test/vectorize_reduction.mlir b/compiler/src/iree/compiler/Codegen/SPIRV/test/vectorize_reduction.mlir
index 917560f..5e6a71c 100644
--- a/compiler/src/iree/compiler/Codegen/SPIRV/test/vectorize_reduction.mlir
+++ b/compiler/src/iree/compiler/Codegen/SPIRV/test/vectorize_reduction.mlir
@@ -1,5 +1,5 @@
// RUN: iree-opt --split-input-file \
-// RUN: --pass-pipeline='builtin.module(func.func(iree-codegen-generic-vectorization,iree-spirv-vector-lowering))' \
+// RUN: --pass-pipeline='builtin.module(func.func(iree-codegen-generic-vectorization,iree-spirv-initial-vector-lowering,iree-codegen-hoist-redundant-vector-transfers,iree-spirv-final-vector-lowering))' \
// RUN: %s | FileCheck %s
func.func @reduce_outmost_dim(%input: tensor<4x1x4xf32>, %init: tensor<1x4xf32>) -> tensor<1x4xf32> {