Add matmul tile and vectorize lowering stratgy to LinalgToLLVM passes (#2608)

diff --git a/iree/compiler/Conversion/LinalgToLLVM/BUILD b/iree/compiler/Conversion/LinalgToLLVM/BUILD
index fcd9476..c6cafc1 100644
--- a/iree/compiler/Conversion/LinalgToLLVM/BUILD
+++ b/iree/compiler/Conversion/LinalgToLLVM/BUILD
@@ -22,12 +22,14 @@
     name = "LinalgToLLVM",
     srcs = [
         "ConvertToLLVM.cpp",
+        "MatMulVectorization.cpp",
         "Passes.cpp",
     ],
     hdrs = [
         "Passes.h",
     ],
     deps = [
+        "//iree/compiler/Conversion/CodegenUtils",
         "//iree/compiler/Conversion/HLOToLinalg",
         "//iree/compiler/Dialect/HAL/IR",
         "//iree/compiler/Dialect/HAL/IR:HALDialect",
@@ -45,6 +47,7 @@
         "@llvm-project//mlir:StandardOps",
         "@llvm-project//mlir:StandardOpsTransforms",
         "@llvm-project//mlir:Transforms",
+        "@llvm-project//mlir:VectorOps",
         "@llvm-project//mlir:VectorToLLVM",
         "@llvm-project//mlir:VectorToSCF",
     ],
diff --git a/iree/compiler/Conversion/LinalgToLLVM/CMakeLists.txt b/iree/compiler/Conversion/LinalgToLLVM/CMakeLists.txt
index bc31e4e..fddc144 100644
--- a/iree/compiler/Conversion/LinalgToLLVM/CMakeLists.txt
+++ b/iree/compiler/Conversion/LinalgToLLVM/CMakeLists.txt
@@ -21,6 +21,7 @@
     "Passes.h"
   SRCS
     "ConvertToLLVM.cpp"
+    "MatMulVectorization.cpp"
     "Passes.cpp"
   DEPS
     MLIRAffineToStandard
@@ -34,8 +35,10 @@
     MLIRStandardOpsTransforms
     MLIRStandardToLLVM
     MLIRTransforms
+    MLIRVector
     MLIRVectorToLLVM
     MLIRVectorToSCF
+    iree::compiler::Conversion::CodegenUtils
     iree::compiler::Conversion::HLOToLinalg
     iree::compiler::Dialect::HAL::IR
     iree::compiler::Dialect::HAL::IR::HALDialect
diff --git a/iree/compiler/Conversion/LinalgToLLVM/ConvertToLLVM.cpp b/iree/compiler/Conversion/LinalgToLLVM/ConvertToLLVM.cpp
index 76fb2c5..777fb02 100644
--- a/iree/compiler/Conversion/LinalgToLLVM/ConvertToLLVM.cpp
+++ b/iree/compiler/Conversion/LinalgToLLVM/ConvertToLLVM.cpp
@@ -28,6 +28,7 @@
 #include "mlir/Dialect/LLVMIR/LLVMDialect.h"
 #include "mlir/Dialect/StandardOps/IR/Ops.h"
 #include "mlir/Dialect/StandardOps/Transforms/Passes.h"
+#include "mlir/Dialect/Vector/VectorOps.h"
 #include "mlir/Pass/Pass.h"
 
 namespace mlir {
@@ -308,6 +309,17 @@
 }  // namespace
 
 void ConvertToLLVMPass::runOnOperation() {
+  // Vector -> Vector transformation is needed before we do any conversion to
+  // LLVM.
+  {
+    OwningRewritePatternList patterns;
+    vector::populateVectorToVectorCanonicalizationPatterns(patterns,
+                                                           &getContext());
+    vector::populateVectorSlicesLoweringPatterns(patterns, &getContext());
+    vector::populateVectorContractLoweringPatterns(patterns, &getContext());
+    applyPatternsAndFoldGreedily(getOperation(), patterns);
+  }
+  //
   auto module = getOperation();
 
   LLVMTypeConverter converter(&getContext());
diff --git a/iree/compiler/Conversion/LinalgToLLVM/MatMulVectorization.cpp b/iree/compiler/Conversion/LinalgToLLVM/MatMulVectorization.cpp
new file mode 100644
index 0000000..a36358d
--- /dev/null
+++ b/iree/compiler/Conversion/LinalgToLLVM/MatMulVectorization.cpp
@@ -0,0 +1,98 @@
+// Copyright 2020 Google LLC
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//      https://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+#include "iree/compiler/Conversion/CodegenUtils/MatmulCodegenStrategy.h"
+#include "mlir/IR/Builders.h"
+#include "mlir/Pass/Pass.h"
+#include "mlir/Pass/PassRegistry.h"
+
+namespace mlir {
+namespace iree_compiler {
+
+static llvm::cl::opt<int> l1TileSize(
+    "iree-codegen-linalg-to-llvm-matmul-l1-tile-size",
+    llvm::cl::desc("Specify the size of L1 tile for matmul vector lowering"),
+    llvm::cl::init(4));
+
+static llvm::cl::opt<int> l2TileSize(
+    "iree-codegen-linalg-to-llvm-matmul-l2-tile-size",
+    llvm::cl::desc("Specify the size of L2 tile for matmul vector lowering"),
+    llvm::cl::init(32));
+
+static llvm::cl::opt<int> l3TileSize(
+    "iree-codegen-linalg-to-llvm-matmul-l3-tile-size",
+    llvm::cl::desc("Specify the size of L3 tile for matmul vector lowering"),
+    llvm::cl::init(64));
+
+static llvm::cl::opt<bool> unrollVectorTransfer(
+    "iree-codegen-linalg-to-llvm-matmul-unroll-vector-transfer",
+    llvm::cl::desc("If true vector transfers operation loop get unrolled."),
+    llvm::cl::init(true));
+
+static llvm::cl::opt<std::string> vectorOpLowering(
+    "iree-codegen-linalg-to-llvm-matmul-vector-op-lowerig",
+    llvm::cl::desc(
+        "Select the vector operation for lowering linalg.matmul, options : "
+        "{'outer_product', 'vector_contract', 'matrix_internsics'}"),
+    llvm::cl::init("outer_product"));
+
+namespace {
+struct MatMulTileAndVectorizePass
+    : PassWrapper<MatMulTileAndVectorizePass, FunctionPass> {
+  void runOnFunction() override;
+};
+}  // namespace
+
+void MatMulTileAndVectorizePass::runOnFunction() {
+  FuncOp fn = getFunction();
+  MatmulCodegenStrategy strategy;
+  strategy
+      .tile<linalg::MatmulOp>(linalg::LinalgTilingOptions().setTileSizes(
+          {l3TileSize, l3TileSize, l3TileSize}))
+      .tile<linalg::MatmulOp>(linalg::LinalgTilingOptions().setTileSizes(
+          {l2TileSize, l2TileSize, l2TileSize}))
+      .tile<linalg::MatmulOp>(linalg::LinalgTilingOptions().setTileSizes(
+          {l1TileSize, l1TileSize, l1TileSize}))
+      .vectorize<linalg::MatmulOp>()
+      .setVectorTransferToSCFOptions(
+          VectorTransferToSCFOptions().setUnroll(unrollVectorTransfer));
+  if (vectorOpLowering == "outer_product") {
+    strategy.setVectorTransformsOptions(
+        vector::VectorTransformsOptions().setVectorTransformsOptions(
+            vector::VectorContractLowering::OuterProduct));
+  } else if (vectorOpLowering == "vector_contract") {
+    strategy.setVectorTransformsOptions(
+        vector::VectorTransformsOptions().setVectorTransformsOptions(
+            vector::VectorContractLowering::OuterProduct));
+  } else if (vectorOpLowering == "matrix_internsics") {
+    strategy.setVectorTransformsOptions(
+        vector::VectorTransformsOptions().setVectorTransformsOptions(
+            vector::VectorContractLowering::OuterProduct));
+  } else {
+    signalPassFailure();
+  }
+  strategy.setDefaultCPULowering();
+  strategy.transform(fn);
+}
+
+std::unique_ptr<FunctionPass> createMatMulTileAndVectorizePass() {
+  return std::make_unique<MatMulTileAndVectorizePass>();
+}
+
+static PassRegistration<MatMulTileAndVectorizePass> pass(
+    "iree-codegen-linalg-to-llvm-matmul-vectorization-pass",
+    "Tile and vectorize linalg.matmul operation",
+    [] { return std::make_unique<MatMulTileAndVectorizePass>(); });
+
+}  // namespace iree_compiler
+}  // namespace mlir
diff --git a/iree/compiler/Conversion/LinalgToLLVM/Passes.cpp b/iree/compiler/Conversion/LinalgToLLVM/Passes.cpp
index 8c8eb21..8631cdf 100644
--- a/iree/compiler/Conversion/LinalgToLLVM/Passes.cpp
+++ b/iree/compiler/Conversion/LinalgToLLVM/Passes.cpp
@@ -25,6 +25,8 @@
 namespace iree_compiler {
 
 void addLinalgToLLVMPasses(OpPassManager &passManager) {
+  // Linalg -> Vectors Ops.
+  passManager.addPass(createMatMulTileAndVectorizePass());
   // Linalg -> SCF
   passManager.addPass(createConvertLinalgToLoopsPass());
   passManager.addPass(createCanonicalizerPass());
diff --git a/iree/compiler/Conversion/LinalgToLLVM/Passes.h b/iree/compiler/Conversion/LinalgToLLVM/Passes.h
index fdad0e6..2a4db8c 100644
--- a/iree/compiler/Conversion/LinalgToLLVM/Passes.h
+++ b/iree/compiler/Conversion/LinalgToLLVM/Passes.h
@@ -20,6 +20,9 @@
 namespace mlir {
 namespace iree_compiler {
 
+/// Converts linalg::MatmulOp into LLVM dialect
+std::unique_ptr<FunctionPass> createMatMulTileAndVectorizePass();
+
 /// Pass to perform final conversion to LLVM dialect.
 std::unique_ptr<OperationPass<ModuleOp>> createConvertToLLVMPass();
 
diff --git a/iree/compiler/Conversion/LinalgToLLVM/test/convert_to_llvm.mlir b/iree/compiler/Conversion/LinalgToLLVM/test/convert_to_llvm.mlir
index d362688..0a8034f 100644
--- a/iree/compiler/Conversion/LinalgToLLVM/test/convert_to_llvm.mlir
+++ b/iree/compiler/Conversion/LinalgToLLVM/test/convert_to_llvm.mlir
@@ -1,12 +1,14 @@
 // RUN: iree-opt -iree-codegen-convert-to-llvm -split-input-file %s | IreeFileCheck %s
 
-func @convert_dynamic_shape() {
+func @convert_dynamic_shape() -> f32 {
+  %c0 = constant 0 : index
   %0 = iree.placeholder for "interface buffer" {binding = @legacy_io::@arg0} : memref<?x?xf32>
   %1 = hal.interface.load.constant offset = 0 : index
   %2 = hal.interface.load.constant offset = 1 : index
   %3 = shapex.make_ranked_shape %1, %2 : (index, index) -> !shapex.ranked_shape<[?,?]>
   %6 = shapex.tie_shape %0, %3 : memref<?x?xf32>, !shapex.ranked_shape<[?,?]>
-  return
+  %7 = load %6[%c0, %c0] : memref<?x?xf32>
+  return %7 : f32
 }
 hal.interface @legacy_io attributes {push_constants = 2 : i32, sym_visibility = "private"} {
     hal.interface.binding @arg0, set=0, binding=0, type="StorageBuffer", access="Read"
diff --git a/iree/compiler/Conversion/LinalgToLLVM/test/matmul_vectorization.mlir b/iree/compiler/Conversion/LinalgToLLVM/test/matmul_vectorization.mlir
new file mode 100644
index 0000000..7d47f47
--- /dev/null
+++ b/iree/compiler/Conversion/LinalgToLLVM/test/matmul_vectorization.mlir
@@ -0,0 +1,31 @@
+// RUN: iree-opt --iree-codegen-linalg-to-llvm-matmul-vectorization-pass -split-input-file %s | IreeFileCheck %s
+
+// CHECK-LABEL: func @matmul_128x128x128
+// CHECK-SAME: (%[[ARG0:.+]]: memref<128x128xf32>, %[[ARG1:.+]]: memref<128x128xf32>, %[[ARG2:.+]]: memref<128x128xf32>)
+func @matmul_128x128x128(%arg0 : memref<128x128xf32>, %arg1: memref<128x128xf32>, %arg2: memref<128x128xf32>) {
+    linalg.matmul %arg0, %arg1, %arg2 : (memref<128x128xf32>, memref<128x128xf32>, memref<128x128xf32>)
+    return
+}
+// CHECK: %[[L3END:.+]] = constant 128 : index
+// CHECK: %[[L3STEP:.+]] = constant 64 : index
+// CHECK: %[[L1STEP:.+]] = constant 4 : index
+// CHECK: %[[L2STEP:.+]] = constant 32 : index
+// CHECK: %[[START:.+]] = constant 0 : index
+// CHECK: scf.for %[[IL3:.+]] = %[[START]] to %[[L3END]] step %[[L3STEP]]
+// CHECK: scf.for %[[JL3:.+]] = %[[START]] to %[[L3END]] step %[[L3STEP]]
+// CHECK: scf.for %[[KL3:.+]] = %[[START]] to %[[L3END]] step %[[L3STEP]]
+// CHECK: %[[ARG0_TILE_L3:.+]] = subview %[[ARG0]][%[[IL3]], %[[KL3]]] [64, 64] [1, 1] : memref<128x128xf32> to memref<64x64xf32
+// CHECK: %[[ARG1_TILE_L3:.+]] = subview %[[ARG1]][%[[KL3]], %[[JL3]]] [64, 64] [1, 1] : memref<128x128xf32> to memref<64x64xf32
+// CHECK: %[[ARG2_TILE_L3:.+]] = subview %[[ARG2]][%[[IL3]], %[[JL3]]] [64, 64] [1, 1] : memref<128x128xf32> to memref<64x64xf32
+// CHECK: scf.for %[[IL2:.+]] = %[[START]] to %[[L3STEP]] step %[[L2STEP]]
+// CHECK: scf.for %[[JL2:.+]] = %[[START]] to %[[L3STEP]] step %[[L2STEP]]
+// CHECK: scf.for %[[KL2:.+]] = %[[START]] to %[[L3STEP]] step %[[L2STEP]]
+// CHECK: %[[ARG0_TILE_L2:.+]] = subview %[[ARG0_TILE_L3]][%[[IL2]], %[[KL2]]] [32, 32] [1, 1] : memref<64x64xf32
+// CHECK: %[[ARG1_TILE_L2:.+]] = subview %[[ARG1_TILE_L3]][%[[KL2]], %[[JL2]]] [32, 32] [1, 1] : memref<64x64xf32
+// CHECK: %[[ARG2_TILE_L2:.+]] = subview %[[ARG2_TILE_L3]][%[[IL2]], %[[JL2]]] [32, 32] [1, 1] : memref<64x64xf32
+// CHECK: scf.for %[[IL1:.+]] = %[[START]] to %[[L2STEP]] step %[[L1STEP]]
+// CHECK: scf.for %[[JL1:.+]] = %[[START]] to %[[L2STEP]] step %[[L1STEP]]
+// CHECK: scf.for %[[KL1:.+]] = %[[START]] to %[[L2STEP]] step %[[L1STEP]]
+// CHECK: %[[ARG0_TILE_L1:.+]] = subview %[[ARG0_TILE_L2]][%[[IL1]], %[[KL1]]] [4, 4] [1, 1] : memref<32x32xf32
+// CHECK: %[[ARG1_TILE_L1:.+]] = subview %[[ARG1_TILE_L2]][%[[KL1]], %[[JL1]]] [4, 4] [1, 1] : memref<32x32xf32
+// CHECK: %[[ARG2_TILE_L1:.+]] = subview %[[ARG2_TILE_L2]][%[[IL1]], %[[JL1]]] [4, 4] [1, 1] : memref<32x32xf32