Merge pull request #5671 from ThomasRaoux/handle_vectorization

Prepare the SPIR-V and CUDA backends to handle relaxed vectorization
diff --git a/iree/compiler/Conversion/LinalgToNVVM/VectorizationPass.cpp b/iree/compiler/Conversion/LinalgToNVVM/VectorizationPass.cpp
index 69971dd..7030e23 100644
--- a/iree/compiler/Conversion/LinalgToNVVM/VectorizationPass.cpp
+++ b/iree/compiler/Conversion/LinalgToNVVM/VectorizationPass.cpp
@@ -98,6 +98,18 @@
     });
 
     {
+      // Lower transfer op to canonical form.
+      OwningRewritePatternList lowerTransferOpPatterns(funcOp.getContext());
+      vector::populateVectorToVectorCanonicalizationPatterns(
+          lowerTransferOpPatterns);
+      vector::populateVectorToVectorTransformationPatterns(
+          lowerTransferOpPatterns);
+      vector::populateVectorTransferLoweringPatterns(lowerTransferOpPatterns);
+      (void)applyPatternsAndFoldGreedily(funcOp,
+                                         std::move(lowerTransferOpPatterns));
+    }
+
+    {
       // Step 2. Unroll the vetors to native size and canonicalize.
       OwningRewritePatternList vectorUnrollPatterns(context);
       populateVectorUnrollPatterns(vectorUnrollPatterns);
diff --git a/iree/compiler/Conversion/LinalgToSPIRV/TileAndVectorizeInOneWorkgroupPass.cpp b/iree/compiler/Conversion/LinalgToSPIRV/TileAndVectorizeInOneWorkgroupPass.cpp
index 8fc4463..3be9b52 100644
--- a/iree/compiler/Conversion/LinalgToSPIRV/TileAndVectorizeInOneWorkgroupPass.cpp
+++ b/iree/compiler/Conversion/LinalgToSPIRV/TileAndVectorizeInOneWorkgroupPass.cpp
@@ -323,29 +323,102 @@
       vector::UnrollVectorOptions().setNativeShapeFn(getNativeVectorSize));
 }
 
+namespace {
+
+/// Workaround for SPIR-V backend limitations. The SPIR-V vectorization pass
+/// relies on unrolling to reduce vector ops to a size we can convert to
+/// SPIR-V. Transposes created by vectorization block that unrolling and leave
+/// large vectors we currently cannot lower, so for now we always fold a
+/// transpose feeding the lhs/rhs of a contract op into its indexing maps.
+// TODO(thomasraoux): Make transpose work with the current unrolling mechanism
+// or replace unrolling.
+class CombineContractTranspose final
+    : public OpRewritePattern<vector::ContractionOp> {
+ public:
+  using OpRewritePattern<vector::ContractionOp>::OpRewritePattern;
+
+  LogicalResult matchAndRewrite(vector::ContractionOp op,
+                                PatternRewriter &rewriter) const override {
+    // Only lhs/rhs are folded: the result type must equal the acc type, so
+    // bypassing a transpose on acc would change the replacement's type.
+    MLIRContext *ctx = op.getContext();
+    bool foundTranspose = false;
+    std::array<Value, 2> sources = {op.lhs(), op.rhs()};
+    SmallVector<AffineMap> newMaps;
+    SmallVector<Value> newSources;
+    for (auto source : llvm::enumerate(sources)) {
+      auto map =
+          op.indexing_maps()[source.index()].cast<AffineMapAttr>().getValue();
+      auto transposeOp = source.value().getDefiningOp<vector::TransposeOp>();
+      if (!transposeOp) {
+        newSources.push_back(source.value());
+        newMaps.push_back(map);
+        continue;
+      }
+      // Result dim i of the transpose is source dim perm[i], so the expr that
+      // indexed result dim i now indexes source dim perm[i].
+      SmallVector<int64_t, 3> perm;
+      transposeOp.getTransp(perm);
+      SmallVector<AffineExpr> exprs(perm.size());
+      for (auto remap : llvm::enumerate(perm))
+        exprs[remap.value()] = map.getResult(remap.index());
+      newMaps.push_back(
+          AffineMap::get(map.getNumDims(), map.getNumSymbols(), exprs, ctx));
+      newSources.push_back(transposeOp.vector());
+      foundTranspose = true;
+    }
+    if (!foundTranspose) return failure();
+    // Keep the acc operand and its indexing map (index 2) unchanged.
+    newMaps.push_back(op.indexing_maps()[2].cast<AffineMapAttr>().getValue());
+    rewriter.replaceOpWithNewOp<vector::ContractionOp>(
+        op, newSources[0], newSources[1], op.acc(),
+        rewriter.getAffineMapArrayAttr(newMaps), op.iterator_types());
+    return success();
+  }
+};
+
+}  // namespace
+
 //====---------------------------------------------------------------------===//
 // Vector patterns
 //====---------------------------------------------------------------------===//
 
 static void applyVectorTransformation(FuncOp funcOp) {
   {
-    OwningRewritePatternList vectorUnrollPatterns(funcOp.getContext());
-    populateVectorUnrollPatterns(funcOp.getContext(), vectorUnrollPatterns);
-    (void)applyPatternsAndFoldGreedily(funcOp, std::move(vectorUnrollPatterns));
+    {
+      OwningRewritePatternList lowerTransferOpPatterns(funcOp.getContext());
+      vector::populateVectorToVectorCanonicalizationPatterns(
+          lowerTransferOpPatterns);
+      vector::populateVectorToVectorTransformationPatterns(
+          lowerTransferOpPatterns);
+      vector::populateVectorTransferLoweringPatterns(lowerTransferOpPatterns);
+      lowerTransferOpPatterns.add<CombineContractTranspose>(
+          funcOp.getContext());
+      (void)applyPatternsAndFoldGreedily(funcOp,
+                                         std::move(lowerTransferOpPatterns));
+    }
+    {
+      OwningRewritePatternList vectorUnrollPatterns(funcOp.getContext());
+      populateVectorUnrollPatterns(funcOp.getContext(), vectorUnrollPatterns);
+      (void)applyPatternsAndFoldGreedily(funcOp,
+                                         std::move(vectorUnrollPatterns));
+    }
+    {
+      OwningRewritePatternList canonicalizationPatterns1(funcOp.getContext());
 
-    OwningRewritePatternList canonicalizationPatterns1(funcOp.getContext());
-    vector::populateVectorToVectorCanonicalizationPatterns(
-        canonicalizationPatterns1);
-    vector::populateVectorToVectorTransformationPatterns(
-        canonicalizationPatterns1);
-    vector::populateSplitVectorTransferPatterns(canonicalizationPatterns1);
-    (void)applyPatternsAndFoldGreedily(funcOp,
-                                       std::move(canonicalizationPatterns1));
+      vector::populateVectorToVectorTransformationPatterns(
+          canonicalizationPatterns1);
+      vector::populateVectorToVectorCanonicalizationPatterns(
+          canonicalizationPatterns1);
+      vector::populateSplitVectorTransferPatterns(canonicalizationPatterns1);
+      (void)applyPatternsAndFoldGreedily(funcOp,
+                                         std::move(canonicalizationPatterns1));
 
-    OwningRewritePatternList canonicalizationPatterns2(funcOp.getContext());
-    vector::populateVectorSlicesLoweringPatterns(canonicalizationPatterns2);
-    (void)applyPatternsAndFoldGreedily(funcOp,
-                                       std::move(canonicalizationPatterns2));
+      OwningRewritePatternList canonicalizationPatterns2(funcOp.getContext());
+      vector::populateVectorSlicesLoweringPatterns(canonicalizationPatterns2);
+      (void)applyPatternsAndFoldGreedily(funcOp,
+                                         std::move(canonicalizationPatterns2));
+    }
     LLVM_DEBUG({
       llvm::dbgs() << "--- After Vector Unroll ---\n";
       funcOp.print(llvm::dbgs(), OpPrintingFlags().useLocalScope());