Add matmul tile and vectorize lowering strategy to LinalgToLLVM passes (#2608)
diff --git a/iree/compiler/Conversion/LinalgToLLVM/BUILD b/iree/compiler/Conversion/LinalgToLLVM/BUILD index fcd9476..c6cafc1 100644 --- a/iree/compiler/Conversion/LinalgToLLVM/BUILD +++ b/iree/compiler/Conversion/LinalgToLLVM/BUILD
@@ -22,12 +22,14 @@ name = "LinalgToLLVM", srcs = [ "ConvertToLLVM.cpp", + "MatMulVectorization.cpp", "Passes.cpp", ], hdrs = [ "Passes.h", ], deps = [ + "//iree/compiler/Conversion/CodegenUtils", "//iree/compiler/Conversion/HLOToLinalg", "//iree/compiler/Dialect/HAL/IR", "//iree/compiler/Dialect/HAL/IR:HALDialect", @@ -45,6 +47,7 @@ "@llvm-project//mlir:StandardOps", "@llvm-project//mlir:StandardOpsTransforms", "@llvm-project//mlir:Transforms", + "@llvm-project//mlir:VectorOps", "@llvm-project//mlir:VectorToLLVM", "@llvm-project//mlir:VectorToSCF", ],
diff --git a/iree/compiler/Conversion/LinalgToLLVM/CMakeLists.txt b/iree/compiler/Conversion/LinalgToLLVM/CMakeLists.txt index bc31e4e..fddc144 100644 --- a/iree/compiler/Conversion/LinalgToLLVM/CMakeLists.txt +++ b/iree/compiler/Conversion/LinalgToLLVM/CMakeLists.txt
@@ -21,6 +21,7 @@ "Passes.h" SRCS "ConvertToLLVM.cpp" + "MatMulVectorization.cpp" "Passes.cpp" DEPS MLIRAffineToStandard @@ -34,8 +35,10 @@ MLIRStandardOpsTransforms MLIRStandardToLLVM MLIRTransforms + MLIRVector MLIRVectorToLLVM MLIRVectorToSCF + iree::compiler::Conversion::CodegenUtils iree::compiler::Conversion::HLOToLinalg iree::compiler::Dialect::HAL::IR iree::compiler::Dialect::HAL::IR::HALDialect
diff --git a/iree/compiler/Conversion/LinalgToLLVM/ConvertToLLVM.cpp b/iree/compiler/Conversion/LinalgToLLVM/ConvertToLLVM.cpp index 76fb2c5..777fb02 100644 --- a/iree/compiler/Conversion/LinalgToLLVM/ConvertToLLVM.cpp +++ b/iree/compiler/Conversion/LinalgToLLVM/ConvertToLLVM.cpp
@@ -28,6 +28,7 @@ #include "mlir/Dialect/LLVMIR/LLVMDialect.h" #include "mlir/Dialect/StandardOps/IR/Ops.h" #include "mlir/Dialect/StandardOps/Transforms/Passes.h" +#include "mlir/Dialect/Vector/VectorOps.h" #include "mlir/Pass/Pass.h" namespace mlir { @@ -308,6 +309,17 @@ } // namespace void ConvertToLLVMPass::runOnOperation() { + // Vector -> Vector transformation is needed before we do any conversion to + // LLVM. + { + OwningRewritePatternList patterns; + vector::populateVectorToVectorCanonicalizationPatterns(patterns, + &getContext()); + vector::populateVectorSlicesLoweringPatterns(patterns, &getContext()); + vector::populateVectorContractLoweringPatterns(patterns, &getContext()); + applyPatternsAndFoldGreedily(getOperation(), patterns); + } + // auto module = getOperation(); LLVMTypeConverter converter(&getContext());
diff --git a/iree/compiler/Conversion/LinalgToLLVM/MatMulVectorization.cpp b/iree/compiler/Conversion/LinalgToLLVM/MatMulVectorization.cpp new file mode 100644 index 0000000..a36358d --- /dev/null +++ b/iree/compiler/Conversion/LinalgToLLVM/MatMulVectorization.cpp
@@ -0,0 +1,127 @@
+// Copyright 2020 Google LLC
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//      https://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+#include "iree/compiler/Conversion/CodegenUtils/MatmulCodegenStrategy.h"
+#include "mlir/IR/Builders.h"
+#include "mlir/Pass/Pass.h"
+#include "mlir/Pass/PassRegistry.h"
+
+namespace mlir {
+namespace iree_compiler {
+
+// Tile size used by the third (last) tiling step in runOnFunction.
+static llvm::cl::opt<int> l1TileSize(
+    "iree-codegen-linalg-to-llvm-matmul-l1-tile-size",
+    llvm::cl::desc("Specify the size of L1 tile for matmul vector lowering"),
+    llvm::cl::init(4));
+
+// Tile size used by the second tiling step in runOnFunction.
+static llvm::cl::opt<int> l2TileSize(
+    "iree-codegen-linalg-to-llvm-matmul-l2-tile-size",
+    llvm::cl::desc("Specify the size of L2 tile for matmul vector lowering"),
+    llvm::cl::init(32));
+
+// Tile size used by the first tiling step in runOnFunction.
+static llvm::cl::opt<int> l3TileSize(
+    "iree-codegen-linalg-to-llvm-matmul-l3-tile-size",
+    llvm::cl::desc("Specify the size of L3 tile for matmul vector lowering"),
+    llvm::cl::init(64));
+
+// Forwarded to VectorTransferToSCFOptions::setUnroll.
+static llvm::cl::opt<bool> unrollVectorTransfer(
+    "iree-codegen-linalg-to-llvm-matmul-unroll-vector-transfer",
+    llvm::cl::desc("If true vector transfers operation loop get unrolled."),
+    llvm::cl::init(true));
+
+// Name of the vector.contract lowering to use; parseContractLowering lists the
+// accepted values.
+static llvm::cl::opt<std::string> vectorOpLowering(
+    "iree-codegen-linalg-to-llvm-matmul-vector-op-lowering",
+    llvm::cl::desc(
+        "Select the vector operation for lowering linalg.matmul, options : "
+        "{'outer_product', 'vector_contract', 'matrix_intrinsics'}"),
+    llvm::cl::init("outer_product"));
+
+// The flag above was first published with the misspelling "lowerig"; keep that
+// spelling working as an alias so existing command lines do not break.
+static llvm::cl::alias vectorOpLoweringMisspelled(
+    "iree-codegen-linalg-to-llvm-matmul-vector-op-lowerig",
+    llvm::cl::desc("Deprecated misspelled alias of "
+                   "-iree-codegen-linalg-to-llvm-matmul-vector-op-lowering"),
+    llvm::cl::aliasopt(vectorOpLowering));
+
+namespace {
+// Function pass that tiles linalg.matmul three times and vectorizes it.
+struct MatMulTileAndVectorizePass
+    : PassWrapper<MatMulTileAndVectorizePass, FunctionPass> {
+  void runOnFunction() override;
+};
+
+// Maps a lowering name to a vector::VectorContractLowering value, or returns
+// llvm::None when the name is not recognized. "matrix_internsics" is the
+// original misspelling of "matrix_intrinsics" and is still accepted.
+llvm::Optional<vector::VectorContractLowering> parseContractLowering(
+    const std::string &name) {
+  if (name == "outer_product")
+    return vector::VectorContractLowering::OuterProduct;
+  if (name == "vector_contract") return vector::VectorContractLowering::Dot;
+  if (name == "matrix_intrinsics" || name == "matrix_internsics")
+    return vector::VectorContractLowering::Matmul;
+  return llvm::None;
+}
+}  // namespace
+
+void MatMulTileAndVectorizePass::runOnFunction() {
+  FuncOp fn = getFunction();
+
+  // Reject an unknown lowering up front so that the function is left untouched
+  // when the pass fails.
+  const std::string &loweringName = vectorOpLowering;
+  llvm::Optional<vector::VectorContractLowering> contractLowering =
+      parseContractLowering(loweringName);
+  if (!contractLowering) {
+    fn.emitError("unsupported matmul vector op lowering '")
+        << loweringName << "'";
+    return signalPassFailure();
+  }
+
+  // Tile three times (L3, then L2, then L1 sizes) and vectorize the result.
+  MatmulCodegenStrategy strategy;
+  strategy
+      .tile<linalg::MatmulOp>(linalg::LinalgTilingOptions().setTileSizes(
+          {l3TileSize, l3TileSize, l3TileSize}))
+      .tile<linalg::MatmulOp>(linalg::LinalgTilingOptions().setTileSizes(
+          {l2TileSize, l2TileSize, l2TileSize}))
+      .tile<linalg::MatmulOp>(linalg::LinalgTilingOptions().setTileSizes(
+          {l1TileSize, l1TileSize, l1TileSize}))
+      .vectorize<linalg::MatmulOp>()
+      .setVectorTransferToSCFOptions(
+          VectorTransferToSCFOptions().setUnroll(unrollVectorTransfer));
+  strategy.setVectorTransformsOptions(
+      vector::VectorTransformsOptions().setVectorTransformsOptions(
+          *contractLowering));
+  strategy.setDefaultCPULowering();
+  strategy.transform(fn);
+}
+
+std::unique_ptr<FunctionPass> createMatMulTileAndVectorizePass() {
+  return std::make_unique<MatMulTileAndVectorizePass>();
+}
+
+static PassRegistration<MatMulTileAndVectorizePass> pass(
+    "iree-codegen-linalg-to-llvm-matmul-vectorization-pass",
+    "Tile and vectorize linalg.matmul operation",
+    [] { return std::make_unique<MatMulTileAndVectorizePass>(); });
+
+}  // namespace iree_compiler
+}  // namespace mlir
diff --git a/iree/compiler/Conversion/LinalgToLLVM/Passes.cpp b/iree/compiler/Conversion/LinalgToLLVM/Passes.cpp index 8c8eb21..8631cdf 100644 --- a/iree/compiler/Conversion/LinalgToLLVM/Passes.cpp +++ b/iree/compiler/Conversion/LinalgToLLVM/Passes.cpp
@@ -25,6 +25,8 @@ namespace iree_compiler { void addLinalgToLLVMPasses(OpPassManager &passManager) { + // Linalg -> Vectors Ops. + passManager.addPass(createMatMulTileAndVectorizePass()); // Linalg -> SCF passManager.addPass(createConvertLinalgToLoopsPass()); passManager.addPass(createCanonicalizerPass());
diff --git a/iree/compiler/Conversion/LinalgToLLVM/Passes.h b/iree/compiler/Conversion/LinalgToLLVM/Passes.h index fdad0e6..2a4db8c 100644 --- a/iree/compiler/Conversion/LinalgToLLVM/Passes.h +++ b/iree/compiler/Conversion/LinalgToLLVM/Passes.h
@@ -20,6 +20,9 @@ namespace mlir { namespace iree_compiler { +/// Pass to tile and vectorize linalg::MatmulOp (Linalg -> Vector ops). +std::unique_ptr<FunctionPass> createMatMulTileAndVectorizePass(); + /// Pass to perform final conversion to LLVM dialect. std::unique_ptr<OperationPass<ModuleOp>> createConvertToLLVMPass();
diff --git a/iree/compiler/Conversion/LinalgToLLVM/test/convert_to_llvm.mlir b/iree/compiler/Conversion/LinalgToLLVM/test/convert_to_llvm.mlir index d362688..0a8034f 100644 --- a/iree/compiler/Conversion/LinalgToLLVM/test/convert_to_llvm.mlir +++ b/iree/compiler/Conversion/LinalgToLLVM/test/convert_to_llvm.mlir
@@ -1,12 +1,14 @@ // RUN: iree-opt -iree-codegen-convert-to-llvm -split-input-file %s | IreeFileCheck %s -func @convert_dynamic_shape() { +func @convert_dynamic_shape() -> f32 { + %c0 = constant 0 : index %0 = iree.placeholder for "interface buffer" {binding = @legacy_io::@arg0} : memref<?x?xf32> %1 = hal.interface.load.constant offset = 0 : index %2 = hal.interface.load.constant offset = 1 : index %3 = shapex.make_ranked_shape %1, %2 : (index, index) -> !shapex.ranked_shape<[?,?]> %6 = shapex.tie_shape %0, %3 : memref<?x?xf32>, !shapex.ranked_shape<[?,?]> - return + %7 = load %6[%c0, %c0] : memref<?x?xf32> + return %7 : f32 } hal.interface @legacy_io attributes {push_constants = 2 : i32, sym_visibility = "private"} { hal.interface.binding @arg0, set=0, binding=0, type="StorageBuffer", access="Read"
diff --git a/iree/compiler/Conversion/LinalgToLLVM/test/matmul_vectorization.mlir b/iree/compiler/Conversion/LinalgToLLVM/test/matmul_vectorization.mlir new file mode 100644 index 0000000..7d47f47 --- /dev/null +++ b/iree/compiler/Conversion/LinalgToLLVM/test/matmul_vectorization.mlir
@@ -0,0 +1,31 @@ +// RUN: iree-opt --iree-codegen-linalg-to-llvm-matmul-vectorization-pass -split-input-file %s | IreeFileCheck %s + +// CHECK-LABEL: func @matmul_128x128x128 +// CHECK-SAME: (%[[ARG0:.+]]: memref<128x128xf32>, %[[ARG1:.+]]: memref<128x128xf32>, %[[ARG2:.+]]: memref<128x128xf32>) +func @matmul_128x128x128(%arg0 : memref<128x128xf32>, %arg1: memref<128x128xf32>, %arg2: memref<128x128xf32>) { + linalg.matmul %arg0, %arg1, %arg2 : (memref<128x128xf32>, memref<128x128xf32>, memref<128x128xf32>) + return +} +// CHECK: %[[L3END:.+]] = constant 128 : index +// CHECK: %[[L3STEP:.+]] = constant 64 : index +// CHECK: %[[L1STEP:.+]] = constant 4 : index +// CHECK: %[[L2STEP:.+]] = constant 32 : index +// CHECK: %[[START:.+]] = constant 0 : index +// CHECK: scf.for %[[IL3:.+]] = %[[START]] to %[[L3END]] step %[[L3STEP]] +// CHECK: scf.for %[[JL3:.+]] = %[[START]] to %[[L3END]] step %[[L3STEP]] +// CHECK: scf.for %[[KL3:.+]] = %[[START]] to %[[L3END]] step %[[L3STEP]] +// CHECK: %[[ARG0_TILE_L3:.+]] = subview %[[ARG0]][%[[IL3]], %[[KL3]]] [64, 64] [1, 1] : memref<128x128xf32> to memref<64x64xf32 +// CHECK: %[[ARG1_TILE_L3:.+]] = subview %[[ARG1]][%[[KL3]], %[[JL3]]] [64, 64] [1, 1] : memref<128x128xf32> to memref<64x64xf32 +// CHECK: %[[ARG2_TILE_L3:.+]] = subview %[[ARG2]][%[[IL3]], %[[JL3]]] [64, 64] [1, 1] : memref<128x128xf32> to memref<64x64xf32 +// CHECK: scf.for %[[IL2:.+]] = %[[START]] to %[[L3STEP]] step %[[L2STEP]] +// CHECK: scf.for %[[JL2:.+]] = %[[START]] to %[[L3STEP]] step %[[L2STEP]] +// CHECK: scf.for %[[KL2:.+]] = %[[START]] to %[[L3STEP]] step %[[L2STEP]] +// CHECK: %[[ARG0_TILE_L2:.+]] = subview %[[ARG0_TILE_L3]][%[[IL2]], %[[KL2]]] [32, 32] [1, 1] : memref<64x64xf32 +// CHECK: %[[ARG1_TILE_L2:.+]] = subview %[[ARG1_TILE_L3]][%[[KL2]], %[[JL2]]] [32, 32] [1, 1] : memref<64x64xf32 +// CHECK: %[[ARG2_TILE_L2:.+]] = subview %[[ARG2_TILE_L3]][%[[IL2]], %[[JL2]]] [32, 32] [1, 1] : memref<64x64xf32 +// CHECK: scf.for %[[IL1:.+]] = %[[START]] to 
%[[L2STEP]] step %[[L1STEP]] +// CHECK: scf.for %[[JL1:.+]] = %[[START]] to %[[L2STEP]] step %[[L1STEP]] +// CHECK: scf.for %[[KL1:.+]] = %[[START]] to %[[L2STEP]] step %[[L1STEP]] +// CHECK: %[[ARG0_TILE_L1:.+]] = subview %[[ARG0_TILE_L2]][%[[IL1]], %[[KL1]]] [4, 4] [1, 1] : memref<32x32xf32 +// CHECK: %[[ARG1_TILE_L1:.+]] = subview %[[ARG1_TILE_L2]][%[[KL1]], %[[JL1]]] [4, 4] [1, 1] : memref<32x32xf32 +// CHECK: %[[ARG2_TILE_L1:.+]] = subview %[[ARG2_TILE_L2]][%[[IL1]], %[[JL1]]] [4, 4] [1, 1] : memref<32x32xf32