[CUDA] Re-organize the set of vector transformations to improve codegen (#6628)

Improve the order of transformations and enable transfer op
optimizations.
diff --git a/iree/compiler/Codegen/LLVMGPU/BUILD b/iree/compiler/Codegen/LLVMGPU/BUILD
index c2bddc1..fd00b70 100644
--- a/iree/compiler/Codegen/LLVMGPU/BUILD
+++ b/iree/compiler/Codegen/LLVMGPU/BUILD
@@ -20,6 +20,7 @@
         "LLVMGPULowerExecutableTarget.cpp",
         "LLVMGPURemoveTrivialLoops.cpp",
         "LLVMGPUTileAndDistribute.cpp",
+        "LLVMGPUVectorLowering.cpp",
         "LLVMGPUVectorization.cpp",
         "Passes.cpp",
     ],
diff --git a/iree/compiler/Codegen/LLVMGPU/CMakeLists.txt b/iree/compiler/Codegen/LLVMGPU/CMakeLists.txt
index 074b92c..20ae473 100644
--- a/iree/compiler/Codegen/LLVMGPU/CMakeLists.txt
+++ b/iree/compiler/Codegen/LLVMGPU/CMakeLists.txt
@@ -24,6 +24,7 @@
     "LLVMGPULowerExecutableTarget.cpp"
     "LLVMGPURemoveTrivialLoops.cpp"
     "LLVMGPUTileAndDistribute.cpp"
+    "LLVMGPUVectorLowering.cpp"
     "LLVMGPUVectorization.cpp"
     "Passes.cpp"
   DEPS
diff --git a/iree/compiler/Codegen/LLVMGPU/LLVMGPUVectorLowering.cpp b/iree/compiler/Codegen/LLVMGPU/LLVMGPUVectorLowering.cpp
new file mode 100644
index 0000000..aac5f1c
--- /dev/null
+++ b/iree/compiler/Codegen/LLVMGPU/LLVMGPUVectorLowering.cpp
@@ -0,0 +1,48 @@
+// Copyright 2021 The IREE Authors
+//
+// Licensed under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+
+#include "iree/compiler/Codegen/PassDetail.h"
+#include "iree/compiler/Codegen/Passes.h"
+#include "mlir/Conversion/VectorToSCF/VectorToSCF.h"
+#include "mlir/Dialect/MemRef/Transforms/Passes.h"
+#include "mlir/Dialect/Vector/VectorTransforms.h"
+#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
+#include "mlir/Transforms/Passes.h"
+
+namespace mlir {
+namespace iree_compiler {
+
+//====---------------------------------------------------------------------===//
+// Pass for late vector op lowering.
+//====---------------------------------------------------------------------===//
+
+namespace {
+struct LLVMGPUVectorLoweringPass  // FuncOp pass: late lowering of vector ops.
+    : public LLVMGPUVectorLoweringBase<LLVMGPUVectorLoweringPass> {
+  void getDependentDialects(DialectRegistry &registry) const override {
+    registry.insert<vector::VectorDialect>();  // Only dialect declared here.
+  }
+  void runOnOperation() override {
+    FuncOp funcOp = getOperation();
+    RewritePatternSet vectorToLoopsPatterns(&getContext());  // One shared set.
+    VectorTransferToSCFOptions vectorToSCFOptions;
+    vectorToSCFOptions.setUnroll(true);  // Presumably unrolls vs. loops; confirm.
+    populateVectorToSCFConversionPatterns(vectorToLoopsPatterns,
+                                          vectorToSCFOptions);
+    memref::populateFoldSubViewOpPatterns(vectorToLoopsPatterns);
+    vector::populateVectorTransferLoweringPatterns(vectorToLoopsPatterns);
+    (void)applyPatternsAndFoldGreedily(funcOp,  // Result deliberately ignored.
+                                       std::move(vectorToLoopsPatterns));
+  }
+};
+}  // namespace
+
+std::unique_ptr<OperationPass<FuncOp>> createLLVMGPUVectorLoweringPass() {
+  return std::make_unique<LLVMGPUVectorLoweringPass>();  // Fresh pass instance.
+}
+
+}  // namespace iree_compiler
+}  // namespace mlir
diff --git a/iree/compiler/Codegen/LLVMGPU/LLVMGPUVectorization.cpp b/iree/compiler/Codegen/LLVMGPU/LLVMGPUVectorization.cpp
index 7bef525..cdf55d6 100644
--- a/iree/compiler/Codegen/LLVMGPU/LLVMGPUVectorization.cpp
+++ b/iree/compiler/Codegen/LLVMGPU/LLVMGPUVectorization.cpp
@@ -99,17 +99,22 @@
       populateVectorizationPatterns(vectorizationPatterns);
       (void)applyPatternsAndFoldGreedily(funcOp,
                                          std::move(vectorizationPatterns));
-    }
-    // TODO: This should be a folding of Add into Contract in core but while
-    // they live in different dialects, it is not possible without unnatural
-    // dependencies.
-    funcOp.walk([&](Operation *op) {
-      if (auto contract = canonicalizeContractionAdd(op))
-        op->replaceAllUsesWith(contract);
-    });
+      // TODO: This should be a folding of Add into Contract in core but while
+      // they live in different dialects, it is not possible without unnatural
+      // dependencies.
+      funcOp.walk([&](Operation *op) {
+        if (auto contract = canonicalizeContractionAdd(op))
+          op->replaceAllUsesWith(contract);
+      });
 
+      RewritePatternSet vectorUnrollPatterns(context);
+      populateVectorUnrollPatterns(vectorUnrollPatterns);
+      (void)applyPatternsAndFoldGreedily(funcOp,
+                                         std::move(vectorUnrollPatterns));
+      linalg::hoistRedundantVectorTransfers(funcOp);
+    }
     {
-      // Lower transfer op to canonical form.
+      // Step 2. Lower transfer op to canonical form.
       RewritePatternSet lowerTransferOpPatterns(funcOp.getContext());
       vector::populateVectorToVectorCanonicalizationPatterns(
           lowerTransferOpPatterns);
@@ -126,24 +131,21 @@
       (void)applyPatternsAndFoldGreedily(funcOp,
                                          std::move(vectorUnrollPatterns));
 
-      RewritePatternSet canonicalizationPatterns1(funcOp.getContext());
+      RewritePatternSet canonicalizationPatterns(funcOp.getContext());
       vector::populateVectorToVectorCanonicalizationPatterns(
-          canonicalizationPatterns1);
+          canonicalizationPatterns);
       (void)applyPatternsAndFoldGreedily(funcOp,
-                                         std::move(canonicalizationPatterns1));
-
-      linalg::hoistRedundantVectorTransfers(funcOp);
+                                         std::move(canonicalizationPatterns));
     }
     {
-      // Step 3. Canonicalize the transfer ops generated.
-      RewritePatternSet vectorToLoopsPatterns(context);
-      VectorTransferToSCFOptions vectorToSCFOptions;
-      vectorToSCFOptions.setUnroll(true);
-      populateVectorToSCFConversionPatterns(vectorToLoopsPatterns,
-                                            vectorToSCFOptions);
-      memref::populateFoldSubViewOpPatterns(vectorToLoopsPatterns);
+      // Step 3. Lower contract op to outer product.
+      RewritePatternSet contractLoweringPatterns(funcOp.getContext());
+      vector::populateVectorContractLoweringPatterns(
+          contractLoweringPatterns,
+          vector::VectorTransformsOptions().setVectorTransformsOptions(
+              vector::VectorContractLowering::OuterProduct));
       (void)applyPatternsAndFoldGreedily(funcOp,
-                                         std::move(vectorToLoopsPatterns));
+                                         std::move(contractLoweringPatterns));
     }
   }
 };
diff --git a/iree/compiler/Codegen/LLVMGPU/Passes.cpp b/iree/compiler/Codegen/LLVMGPU/Passes.cpp
index ec13ad3..c64e6fa 100644
--- a/iree/compiler/Codegen/LLVMGPU/Passes.cpp
+++ b/iree/compiler/Codegen/LLVMGPU/Passes.cpp
@@ -49,6 +49,7 @@
   pm.addNestedPass<FuncOp>(createLLVMGPUVectorizationPass());
   pm.addNestedPass<FuncOp>(createCanonicalizerPass());
   pm.addNestedPass<FuncOp>(createCSEPass());
+  pm.addNestedPass<FuncOp>(createOptimizeVectorTransferPass());
 }
 
 void addGPUSimpleDistributePassPipeline(OpPassManager &pm) {
@@ -91,6 +92,7 @@
   pm.addNestedPass<FuncOp>(createCanonicalizerPass());
   pm.addNestedPass<FuncOp>(createCSEPass());
 
+  pm.addNestedPass<FuncOp>(createLLVMGPUVectorLoweringPass());
   pm.addPass(createLowerAffinePass());
 
   // Strip out the debug info for the kernel as CUDA driver doesn't diggest PTX
diff --git a/iree/compiler/Codegen/LLVMGPU/test/nvvm_pipeline_test.mlir b/iree/compiler/Codegen/LLVMGPU/test/nvvm_pipeline_test.mlir
index 4e2de4f..73bd058 100644
--- a/iree/compiler/Codegen/LLVMGPU/test/nvvm_pipeline_test.mlir
+++ b/iree/compiler/Codegen/LLVMGPU/test/nvvm_pipeline_test.mlir
@@ -101,6 +101,7 @@
 
 //   CHECK-LABEL: hal.executable @dot_dispatch_0
 //         CHECK:   hal.executable.variant @cuda
+//     CHECK-NOT:   llvm.store
 // CHECK-COUNT-2:   llvm.load {{.*}} : !llvm.ptr<vector<4xf32>>
 //         CHECK:   llvm.br
 // CHECK-COUNT-6:   llvm.load {{.*}} : !llvm.ptr<vector<4xf32>, 3>
diff --git a/iree/compiler/Codegen/LLVMGPU/test/vectorization.mlir b/iree/compiler/Codegen/LLVMGPU/test/vectorization.mlir
index f0574b6..480d0c4 100644
--- a/iree/compiler/Codegen/LLVMGPU/test/vectorization.mlir
+++ b/iree/compiler/Codegen/LLVMGPU/test/vectorization.mlir
@@ -50,7 +50,13 @@
   return
 }
 // CHECK-LABEL: func @add_dispatch_0()
-//       CHECK:   vector.transfer_read {{.*}} : memref<1024x1024x1024xf32>, vector<4xf32>
-//       CHECK:   vector.transfer_read {{.*}} : memref<1024x1024x1024xf32>, vector<4xf32>
-//       CHECK:   addf %{{.*}}, %{{.*}}  : vector<1x1x4xf32>
-//       CHECK:   vector.transfer_write {{.*}} : vector<4xf32>, memref<1024x1024x1024xf32>
+//       CHECK:   scf.for
+//       CHECK:     scf.for
+//       CHECK:       scf.for
+//       CHECK:         scf.for
+//       CHECK:           scf.for
+//       CHECK:             scf.for
+//       CHECK:               vector.transfer_read {{.*}} : memref<1x1x4xf32, {{.*}}>, vector<1x1x4xf32>
+//       CHECK:               vector.transfer_read {{.*}} : memref<1x1x4xf32, {{.*}}>, vector<1x1x4xf32>
+//       CHECK:               addf %{{.*}}, %{{.*}}  : vector<1x1x4xf32>
+//       CHECK:               vector.transfer_write {{.*}} : vector<1x1x4xf32>, memref<1x1x4xf32, {{.*}}>
diff --git a/iree/compiler/Codegen/Passes.h b/iree/compiler/Codegen/Passes.h
index fc69b16..6d9ae4b 100644
--- a/iree/compiler/Codegen/Passes.h
+++ b/iree/compiler/Codegen/Passes.h
@@ -233,6 +233,9 @@
 /// Convert Linalg ops to Vector.
 std::unique_ptr<OperationPass<FuncOp>> createLLVMGPUVectorizationPass();
 
+/// Lower vector ops before conversion to LLVM.
+std::unique_ptr<OperationPass<FuncOp>> createLLVMGPUVectorLoweringPass();
+
 //------------------------------------------------------------------------------
 // SPIRV Passes
 //------------------------------------------------------------------------------
diff --git a/iree/compiler/Codegen/Passes.td b/iree/compiler/Codegen/Passes.td
index 3716f0e..328515c 100644
--- a/iree/compiler/Codegen/Passes.td
+++ b/iree/compiler/Codegen/Passes.td
@@ -187,6 +187,12 @@
   let constructor = "mlir::iree_compiler::createLLVMGPUVectorizationPass()";
 }
 
+def LLVMGPUVectorLowering :
+    Pass<"iree-llvmgpu-vector-lowering", "FuncOp"> {
+  let summary = "Pass to lower Vector ops before conversion to LLVM.";
+  let constructor = "mlir::iree_compiler::createLLVMGPUVectorLoweringPass()";
+}
+
 //------------------------------------------------------------------------------
 // SPIRV
 //------------------------------------------------------------------------------