Merge pull request #5671 from ThomasRaoux/handle_vectorization
Prepare handling for relaxed vectorization for SPIR-V and CUDA backends
diff --git a/iree/compiler/Conversion/LinalgToNVVM/VectorizationPass.cpp b/iree/compiler/Conversion/LinalgToNVVM/VectorizationPass.cpp
index 69971dd..7030e23 100644
--- a/iree/compiler/Conversion/LinalgToNVVM/VectorizationPass.cpp
+++ b/iree/compiler/Conversion/LinalgToNVVM/VectorizationPass.cpp
@@ -98,6 +98,18 @@
});
{
+ // Lower transfer op to canonical form.
+ OwningRewritePatternList lowerTransferOpPatterns(funcOp.getContext());
+ vector::populateVectorToVectorCanonicalizationPatterns(
+ lowerTransferOpPatterns);
+ vector::populateVectorToVectorTransformationPatterns(
+ lowerTransferOpPatterns);
+ vector::populateVectorTransferLoweringPatterns(lowerTransferOpPatterns);
+ (void)applyPatternsAndFoldGreedily(funcOp,
+ std::move(lowerTransferOpPatterns));
+ }
+
+ {
// Step 2. Unroll the vetors to native size and canonicalize.
OwningRewritePatternList vectorUnrollPatterns(context);
populateVectorUnrollPatterns(vectorUnrollPatterns);
diff --git a/iree/compiler/Conversion/LinalgToSPIRV/TileAndVectorizeInOneWorkgroupPass.cpp b/iree/compiler/Conversion/LinalgToSPIRV/TileAndVectorizeInOneWorkgroupPass.cpp
index 8fc4463..3be9b52 100644
--- a/iree/compiler/Conversion/LinalgToSPIRV/TileAndVectorizeInOneWorkgroupPass.cpp
+++ b/iree/compiler/Conversion/LinalgToSPIRV/TileAndVectorizeInOneWorkgroupPass.cpp
@@ -323,29 +323,102 @@
vector::UnrollVectorOptions().setNativeShapeFn(getNativeVectorSize));
}
+namespace {
+
+/// Works around SPIR-V backend limitations. The SPIR-V vectorization pass
+/// relies on unrolling to reduce instructions to a vector size we can convert
+/// to SPIR-V. When vectorization creates transposes, they block unrolling and
+/// result in large vectors we currently cannot lower. For now we always merge
+/// the transpose into the contract op so that it can be unrolled.
+// TODO(thomasraoux): Make transpose work with the current unrolling mechanism
+// or replace unrolling.
+class CombineContractTranspose final
+    : public OpRewritePattern<vector::ContractionOp> {
+ public:
+  using OpRewritePattern<vector::ContractionOp>::OpRewritePattern;
+
+  /// Folds vector.transpose producers of lhs/rhs into the contraction's
+  /// indexing maps. Fails to match when neither operand is a transpose.
+  LogicalResult matchAndRewrite(vector::ContractionOp op,
+                                PatternRewriter &rewriter) const override {
+    MLIRContext *ctx = op.getContext();
+    bool foundTranspose = false;
+    // Only lhs/rhs are folded: the result type is derived from acc, so
+    // swapping acc for its un-transposed source would change the result type.
+    std::array<Value, 2> sources = {op.lhs(), op.rhs()};
+    SmallVector<AffineMap> newMaps;
+    SmallVector<Value> newSources;
+    for (auto source : llvm::enumerate(sources)) {
+      auto map =
+          op.indexing_maps()[source.index()].cast<AffineMapAttr>().getValue();
+      auto transposeOp = source.value().getDefiningOp<vector::TransposeOp>();
+      if (!transposeOp) {
+        newSources.push_back(source.value());
+        newMaps.push_back(map);
+        continue;
+      }
+      SmallVector<int64_t, 3> perm;
+      transposeOp.getTransp(perm);
+      // Result i of the old map moves to position perm[i] of the new map.
+      SmallVector<AffineExpr> exprs(perm.size());
+      for (auto remap : llvm::enumerate(perm))
+        exprs[remap.value()] = map.getResult(remap.index());
+      newMaps.push_back(
+          AffineMap::get(map.getNumDims(), map.getNumSymbols(), exprs, ctx));
+      newSources.push_back(transposeOp.vector());
+      foundTranspose = true;
+    }
+    if (!foundTranspose) return failure();
+    newMaps.push_back(op.indexing_maps()[2].cast<AffineMapAttr>().getValue());
+    rewriter.replaceOpWithNewOp<vector::ContractionOp>(
+        op, newSources[0], newSources[1], op.acc(),
+        rewriter.getAffineMapArrayAttr(newMaps), op.iterator_types());
+    return success();
+  }
+};
+
+} // namespace
+
//====---------------------------------------------------------------------===//
// Vector patterns
//====---------------------------------------------------------------------===//
static void applyVectorTransformation(FuncOp funcOp) {
{
- OwningRewritePatternList vectorUnrollPatterns(funcOp.getContext());
- populateVectorUnrollPatterns(funcOp.getContext(), vectorUnrollPatterns);
- (void)applyPatternsAndFoldGreedily(funcOp, std::move(vectorUnrollPatterns));
+ {
+ OwningRewritePatternList lowerTransferOpPatterns(funcOp.getContext());
+ vector::populateVectorToVectorCanonicalizationPatterns(
+ lowerTransferOpPatterns);
+ vector::populateVectorToVectorTransformationPatterns(
+ lowerTransferOpPatterns);
+ vector::populateVectorTransferLoweringPatterns(lowerTransferOpPatterns);
+ lowerTransferOpPatterns.add<CombineContractTranspose>(
+ funcOp.getContext());
+ (void)applyPatternsAndFoldGreedily(funcOp,
+ std::move(lowerTransferOpPatterns));
+ }
+ {
+ OwningRewritePatternList vectorUnrollPatterns(funcOp.getContext());
+ populateVectorUnrollPatterns(funcOp.getContext(), vectorUnrollPatterns);
+ (void)applyPatternsAndFoldGreedily(funcOp,
+ std::move(vectorUnrollPatterns));
+ }
+ {
+ OwningRewritePatternList canonicalizationPatterns1(funcOp.getContext());
- OwningRewritePatternList canonicalizationPatterns1(funcOp.getContext());
- vector::populateVectorToVectorCanonicalizationPatterns(
- canonicalizationPatterns1);
- vector::populateVectorToVectorTransformationPatterns(
- canonicalizationPatterns1);
- vector::populateSplitVectorTransferPatterns(canonicalizationPatterns1);
- (void)applyPatternsAndFoldGreedily(funcOp,
- std::move(canonicalizationPatterns1));
+ vector::populateVectorToVectorTransformationPatterns(
+ canonicalizationPatterns1);
+ vector::populateVectorToVectorCanonicalizationPatterns(
+ canonicalizationPatterns1);
+ vector::populateSplitVectorTransferPatterns(canonicalizationPatterns1);
+ (void)applyPatternsAndFoldGreedily(funcOp,
+ std::move(canonicalizationPatterns1));
- OwningRewritePatternList canonicalizationPatterns2(funcOp.getContext());
- vector::populateVectorSlicesLoweringPatterns(canonicalizationPatterns2);
- (void)applyPatternsAndFoldGreedily(funcOp,
- std::move(canonicalizationPatterns2));
+ OwningRewritePatternList canonicalizationPatterns2(funcOp.getContext());
+ vector::populateVectorSlicesLoweringPatterns(canonicalizationPatterns2);
+ (void)applyPatternsAndFoldGreedily(funcOp,
+ std::move(canonicalizationPatterns2));
+ }
LLVM_DEBUG({
llvm::dbgs() << "--- After Vector Unroll ---\n";
funcOp.print(llvm::dbgs(), OpPrintingFlags().useLocalScope());