Add matmul tile and vectorize lowering stratgy to LinalgToLLVM passes (#2608)
diff --git a/iree/compiler/Conversion/LinalgToLLVM/BUILD b/iree/compiler/Conversion/LinalgToLLVM/BUILD
index fcd9476..c6cafc1 100644
--- a/iree/compiler/Conversion/LinalgToLLVM/BUILD
+++ b/iree/compiler/Conversion/LinalgToLLVM/BUILD
@@ -22,12 +22,14 @@
name = "LinalgToLLVM",
srcs = [
"ConvertToLLVM.cpp",
+ "MatMulVectorization.cpp",
"Passes.cpp",
],
hdrs = [
"Passes.h",
],
deps = [
+ "//iree/compiler/Conversion/CodegenUtils",
"//iree/compiler/Conversion/HLOToLinalg",
"//iree/compiler/Dialect/HAL/IR",
"//iree/compiler/Dialect/HAL/IR:HALDialect",
@@ -45,6 +47,7 @@
"@llvm-project//mlir:StandardOps",
"@llvm-project//mlir:StandardOpsTransforms",
"@llvm-project//mlir:Transforms",
+ "@llvm-project//mlir:VectorOps",
"@llvm-project//mlir:VectorToLLVM",
"@llvm-project//mlir:VectorToSCF",
],
diff --git a/iree/compiler/Conversion/LinalgToLLVM/CMakeLists.txt b/iree/compiler/Conversion/LinalgToLLVM/CMakeLists.txt
index bc31e4e..fddc144 100644
--- a/iree/compiler/Conversion/LinalgToLLVM/CMakeLists.txt
+++ b/iree/compiler/Conversion/LinalgToLLVM/CMakeLists.txt
@@ -21,6 +21,7 @@
"Passes.h"
SRCS
"ConvertToLLVM.cpp"
+ "MatMulVectorization.cpp"
"Passes.cpp"
DEPS
MLIRAffineToStandard
@@ -34,8 +35,10 @@
MLIRStandardOpsTransforms
MLIRStandardToLLVM
MLIRTransforms
+ MLIRVector
MLIRVectorToLLVM
MLIRVectorToSCF
+ iree::compiler::Conversion::CodegenUtils
iree::compiler::Conversion::HLOToLinalg
iree::compiler::Dialect::HAL::IR
iree::compiler::Dialect::HAL::IR::HALDialect
diff --git a/iree/compiler/Conversion/LinalgToLLVM/ConvertToLLVM.cpp b/iree/compiler/Conversion/LinalgToLLVM/ConvertToLLVM.cpp
index 76fb2c5..777fb02 100644
--- a/iree/compiler/Conversion/LinalgToLLVM/ConvertToLLVM.cpp
+++ b/iree/compiler/Conversion/LinalgToLLVM/ConvertToLLVM.cpp
@@ -28,6 +28,7 @@
#include "mlir/Dialect/LLVMIR/LLVMDialect.h"
#include "mlir/Dialect/StandardOps/IR/Ops.h"
#include "mlir/Dialect/StandardOps/Transforms/Passes.h"
+#include "mlir/Dialect/Vector/VectorOps.h"
#include "mlir/Pass/Pass.h"
namespace mlir {
@@ -308,6 +309,17 @@
} // namespace
void ConvertToLLVMPass::runOnOperation() {
+ // Vector -> Vector transformation is needed before we do any conversion to
+ // LLVM.
+ {
+ OwningRewritePatternList patterns;
+ vector::populateVectorToVectorCanonicalizationPatterns(patterns,
+ &getContext());
+ vector::populateVectorSlicesLoweringPatterns(patterns, &getContext());
+ vector::populateVectorContractLoweringPatterns(patterns, &getContext());
+ applyPatternsAndFoldGreedily(getOperation(), patterns);
+ }
+ //
auto module = getOperation();
LLVMTypeConverter converter(&getContext());
diff --git a/iree/compiler/Conversion/LinalgToLLVM/MatMulVectorization.cpp b/iree/compiler/Conversion/LinalgToLLVM/MatMulVectorization.cpp
new file mode 100644
index 0000000..a36358d
--- /dev/null
+++ b/iree/compiler/Conversion/LinalgToLLVM/MatMulVectorization.cpp
@@ -0,0 +1,98 @@
+// Copyright 2020 Google LLC
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// https://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+#include "iree/compiler/Conversion/CodegenUtils/MatmulCodegenStrategy.h"
+#include "mlir/IR/Builders.h"
+#include "mlir/Pass/Pass.h"
+#include "mlir/Pass/PassRegistry.h"
+
+namespace mlir {
+namespace iree_compiler {
+
+static llvm::cl::opt<int> l1TileSize(
+ "iree-codegen-linalg-to-llvm-matmul-l1-tile-size",
+ llvm::cl::desc("Specify the size of L1 tile for matmul vector lowering"),
+ llvm::cl::init(4));
+
+static llvm::cl::opt<int> l2TileSize(
+ "iree-codegen-linalg-to-llvm-matmul-l2-tile-size",
+ llvm::cl::desc("Specify the size of L2 tile for matmul vector lowering"),
+ llvm::cl::init(32));
+
+static llvm::cl::opt<int> l3TileSize(
+ "iree-codegen-linalg-to-llvm-matmul-l3-tile-size",
+ llvm::cl::desc("Specify the size of L3 tile for matmul vector lowering"),
+ llvm::cl::init(64));
+
+static llvm::cl::opt<bool> unrollVectorTransfer(
+ "iree-codegen-linalg-to-llvm-matmul-unroll-vector-transfer",
+ llvm::cl::desc("If true vector transfers operation loop get unrolled."),
+ llvm::cl::init(true));
+
+static llvm::cl::opt<std::string> vectorOpLowering(
+ "iree-codegen-linalg-to-llvm-matmul-vector-op-lowerig",
+ llvm::cl::desc(
+ "Select the vector operation for lowering linalg.matmul, options : "
+ "{'outer_product', 'vector_contract', 'matrix_internsics'}"),
+ llvm::cl::init("outer_product"));
+
+namespace {
+struct MatMulTileAndVectorizePass
+ : PassWrapper<MatMulTileAndVectorizePass, FunctionPass> {
+ void runOnFunction() override;
+};
+} // namespace
+
+void MatMulTileAndVectorizePass::runOnFunction() {
+ FuncOp fn = getFunction();
+ MatmulCodegenStrategy strategy;
+ strategy
+ .tile<linalg::MatmulOp>(linalg::LinalgTilingOptions().setTileSizes(
+ {l3TileSize, l3TileSize, l3TileSize}))
+ .tile<linalg::MatmulOp>(linalg::LinalgTilingOptions().setTileSizes(
+ {l2TileSize, l2TileSize, l2TileSize}))
+ .tile<linalg::MatmulOp>(linalg::LinalgTilingOptions().setTileSizes(
+ {l1TileSize, l1TileSize, l1TileSize}))
+ .vectorize<linalg::MatmulOp>()
+ .setVectorTransferToSCFOptions(
+ VectorTransferToSCFOptions().setUnroll(unrollVectorTransfer));
+ if (vectorOpLowering == "outer_product") {
+ strategy.setVectorTransformsOptions(
+ vector::VectorTransformsOptions().setVectorTransformsOptions(
+ vector::VectorContractLowering::OuterProduct));
+ } else if (vectorOpLowering == "vector_contract") {
+ strategy.setVectorTransformsOptions(
+ vector::VectorTransformsOptions().setVectorTransformsOptions(
+ vector::VectorContractLowering::OuterProduct));
+ } else if (vectorOpLowering == "matrix_internsics") {
+ strategy.setVectorTransformsOptions(
+ vector::VectorTransformsOptions().setVectorTransformsOptions(
+ vector::VectorContractLowering::OuterProduct));
+ } else {
+ signalPassFailure();
+ }
+ strategy.setDefaultCPULowering();
+ strategy.transform(fn);
+}
+
+std::unique_ptr<FunctionPass> createMatMulTileAndVectorizePass() {
+ return std::make_unique<MatMulTileAndVectorizePass>();
+}
+
+static PassRegistration<MatMulTileAndVectorizePass> pass(
+ "iree-codegen-linalg-to-llvm-matmul-vectorization-pass",
+ "Tile and vectorize linalg.matmul operation",
+ [] { return std::make_unique<MatMulTileAndVectorizePass>(); });
+
+} // namespace iree_compiler
+} // namespace mlir
diff --git a/iree/compiler/Conversion/LinalgToLLVM/Passes.cpp b/iree/compiler/Conversion/LinalgToLLVM/Passes.cpp
index 8c8eb21..8631cdf 100644
--- a/iree/compiler/Conversion/LinalgToLLVM/Passes.cpp
+++ b/iree/compiler/Conversion/LinalgToLLVM/Passes.cpp
@@ -25,6 +25,8 @@
namespace iree_compiler {
void addLinalgToLLVMPasses(OpPassManager &passManager) {
+ // Linalg -> Vectors Ops.
+ passManager.addPass(createMatMulTileAndVectorizePass());
// Linalg -> SCF
passManager.addPass(createConvertLinalgToLoopsPass());
passManager.addPass(createCanonicalizerPass());
diff --git a/iree/compiler/Conversion/LinalgToLLVM/Passes.h b/iree/compiler/Conversion/LinalgToLLVM/Passes.h
index fdad0e6..2a4db8c 100644
--- a/iree/compiler/Conversion/LinalgToLLVM/Passes.h
+++ b/iree/compiler/Conversion/LinalgToLLVM/Passes.h
@@ -20,6 +20,9 @@
namespace mlir {
namespace iree_compiler {
+/// Converts linalg::MatmulOp into LLVM dialect
+std::unique_ptr<FunctionPass> createMatMulTileAndVectorizePass();
+
/// Pass to perform final conversion to LLVM dialect.
std::unique_ptr<OperationPass<ModuleOp>> createConvertToLLVMPass();
diff --git a/iree/compiler/Conversion/LinalgToLLVM/test/convert_to_llvm.mlir b/iree/compiler/Conversion/LinalgToLLVM/test/convert_to_llvm.mlir
index d362688..0a8034f 100644
--- a/iree/compiler/Conversion/LinalgToLLVM/test/convert_to_llvm.mlir
+++ b/iree/compiler/Conversion/LinalgToLLVM/test/convert_to_llvm.mlir
@@ -1,12 +1,14 @@
// RUN: iree-opt -iree-codegen-convert-to-llvm -split-input-file %s | IreeFileCheck %s
-func @convert_dynamic_shape() {
+func @convert_dynamic_shape() -> f32 {
+ %c0 = constant 0 : index
%0 = iree.placeholder for "interface buffer" {binding = @legacy_io::@arg0} : memref<?x?xf32>
%1 = hal.interface.load.constant offset = 0 : index
%2 = hal.interface.load.constant offset = 1 : index
%3 = shapex.make_ranked_shape %1, %2 : (index, index) -> !shapex.ranked_shape<[?,?]>
%6 = shapex.tie_shape %0, %3 : memref<?x?xf32>, !shapex.ranked_shape<[?,?]>
- return
+ %7 = load %6[%c0, %c0] : memref<?x?xf32>
+ return %7 : f32
}
hal.interface @legacy_io attributes {push_constants = 2 : i32, sym_visibility = "private"} {
hal.interface.binding @arg0, set=0, binding=0, type="StorageBuffer", access="Read"
diff --git a/iree/compiler/Conversion/LinalgToLLVM/test/matmul_vectorization.mlir b/iree/compiler/Conversion/LinalgToLLVM/test/matmul_vectorization.mlir
new file mode 100644
index 0000000..7d47f47
--- /dev/null
+++ b/iree/compiler/Conversion/LinalgToLLVM/test/matmul_vectorization.mlir
@@ -0,0 +1,31 @@
+// RUN: iree-opt --iree-codegen-linalg-to-llvm-matmul-vectorization-pass -split-input-file %s | IreeFileCheck %s
+
+// CHECK-LABEL: func @matmul_128x128x128
+// CHECK-SAME: (%[[ARG0:.+]]: memref<128x128xf32>, %[[ARG1:.+]]: memref<128x128xf32>, %[[ARG2:.+]]: memref<128x128xf32>)
+func @matmul_128x128x128(%arg0 : memref<128x128xf32>, %arg1: memref<128x128xf32>, %arg2: memref<128x128xf32>) {
+ linalg.matmul %arg0, %arg1, %arg2 : (memref<128x128xf32>, memref<128x128xf32>, memref<128x128xf32>)
+ return
+}
+// CHECK: %[[L3END:.+]] = constant 128 : index
+// CHECK: %[[L3STEP:.+]] = constant 64 : index
+// CHECK: %[[L1STEP:.+]] = constant 4 : index
+// CHECK: %[[L2STEP:.+]] = constant 32 : index
+// CHECK: %[[START:.+]] = constant 0 : index
+// CHECK: scf.for %[[IL3:.+]] = %[[START]] to %[[L3END]] step %[[L3STEP]]
+// CHECK: scf.for %[[JL3:.+]] = %[[START]] to %[[L3END]] step %[[L3STEP]]
+// CHECK: scf.for %[[KL3:.+]] = %[[START]] to %[[L3END]] step %[[L3STEP]]
+// CHECK: %[[ARG0_TILE_L3:.+]] = subview %[[ARG0]][%[[IL3]], %[[KL3]]] [64, 64] [1, 1] : memref<128x128xf32> to memref<64x64xf32
+// CHECK: %[[ARG1_TILE_L3:.+]] = subview %[[ARG1]][%[[KL3]], %[[JL3]]] [64, 64] [1, 1] : memref<128x128xf32> to memref<64x64xf32
+// CHECK: %[[ARG2_TILE_L3:.+]] = subview %[[ARG2]][%[[IL3]], %[[JL3]]] [64, 64] [1, 1] : memref<128x128xf32> to memref<64x64xf32
+// CHECK: scf.for %[[IL2:.+]] = %[[START]] to %[[L3STEP]] step %[[L2STEP]]
+// CHECK: scf.for %[[JL2:.+]] = %[[START]] to %[[L3STEP]] step %[[L2STEP]]
+// CHECK: scf.for %[[KL2:.+]] = %[[START]] to %[[L3STEP]] step %[[L2STEP]]
+// CHECK: %[[ARG0_TILE_L2:.+]] = subview %[[ARG0_TILE_L3]][%[[IL2]], %[[KL2]]] [32, 32] [1, 1] : memref<64x64xf32
+// CHECK: %[[ARG1_TILE_L2:.+]] = subview %[[ARG1_TILE_L3]][%[[KL2]], %[[JL2]]] [32, 32] [1, 1] : memref<64x64xf32
+// CHECK: %[[ARG2_TILE_L2:.+]] = subview %[[ARG2_TILE_L3]][%[[IL2]], %[[JL2]]] [32, 32] [1, 1] : memref<64x64xf32
+// CHECK: scf.for %[[IL1:.+]] = %[[START]] to %[[L2STEP]] step %[[L1STEP]]
+// CHECK: scf.for %[[JL1:.+]] = %[[START]] to %[[L2STEP]] step %[[L1STEP]]
+// CHECK: scf.for %[[KL1:.+]] = %[[START]] to %[[L2STEP]] step %[[L1STEP]]
+// CHECK: %[[ARG0_TILE_L1:.+]] = subview %[[ARG0_TILE_L2]][%[[IL1]], %[[KL1]]] [4, 4] [1, 1] : memref<32x32xf32
+// CHECK: %[[ARG1_TILE_L1:.+]] = subview %[[ARG1_TILE_L2]][%[[KL1]], %[[JL1]]] [4, 4] [1, 1] : memref<32x32xf32
+// CHECK: %[[ARG2_TILE_L1:.+]] = subview %[[ARG2_TILE_L2]][%[[IL1]], %[[JL1]]] [4, 4] [1, 1] : memref<32x32xf32