Add VectorToSPIRV patterns to ConvertToSPIRVPass (#3593)
diff --git a/iree/compiler/Conversion/LinalgToSPIRV/BUILD b/iree/compiler/Conversion/LinalgToSPIRV/BUILD
index cd0d823..f5fe4d9 100644
--- a/iree/compiler/Conversion/LinalgToSPIRV/BUILD
+++ b/iree/compiler/Conversion/LinalgToSPIRV/BUILD
@@ -76,6 +76,7 @@
"@llvm-project//mlir:Support",
"@llvm-project//mlir:Transforms",
"@llvm-project//mlir:VectorOps",
+ "@llvm-project//mlir:VectorToSPIRV",
"@org_tensorflow//tensorflow/compiler/mlir/hlo",
"@org_tensorflow//tensorflow/compiler/mlir/hlo:legalize_to_linalg",
],
diff --git a/iree/compiler/Conversion/LinalgToSPIRV/CMakeLists.txt b/iree/compiler/Conversion/LinalgToSPIRV/CMakeLists.txt
index a4b11b3..79578dd 100644
--- a/iree/compiler/Conversion/LinalgToSPIRV/CMakeLists.txt
+++ b/iree/compiler/Conversion/LinalgToSPIRV/CMakeLists.txt
@@ -62,6 +62,7 @@
MLIRSupport
MLIRTransforms
MLIRVector
+ MLIRVectorToSPIRV
iree::compiler::Conversion::CodegenUtils
iree::compiler::Conversion::HLOToHLO
iree::compiler::Conversion::HLOToLinalg
diff --git a/iree/compiler/Conversion/LinalgToSPIRV/ConvertToSPIRVPass.cpp b/iree/compiler/Conversion/LinalgToSPIRV/ConvertToSPIRVPass.cpp
index 93b8809..508d0dc 100644
--- a/iree/compiler/Conversion/LinalgToSPIRV/ConvertToSPIRVPass.cpp
+++ b/iree/compiler/Conversion/LinalgToSPIRV/ConvertToSPIRVPass.cpp
@@ -30,6 +30,7 @@
#include "mlir/Conversion/GPUToSPIRV/ConvertGPUToSPIRV.h"
#include "mlir/Conversion/SCFToSPIRV/SCFToSPIRV.h"
#include "mlir/Conversion/StandardToSPIRV/ConvertStandardToSPIRV.h"
+#include "mlir/Conversion/VectorToSPIRV/ConvertVectorToSPIRV.h"
#include "mlir/Dialect/Linalg/IR/LinalgOps.h"
#include "mlir/Dialect/SPIRV/SPIRVDialect.h"
#include "mlir/Dialect/SPIRV/SPIRVLowering.h"
@@ -402,6 +403,8 @@
patterns);
// Pull in standard patterns to convert arithmetic ops and others.
populateStandardToSPIRVPatterns(context, typeConverter, patterns);
+ // Pull in vector patterns to convert vector ops.
+ mlir::populateVectorToSPIRVPatterns(context, typeConverter, patterns);
// Pull in builtin func to spv.func conversion.
populateBuiltinFuncToSPIRVPatterns(context, typeConverter, patterns);
auto &cooperativeMatrixAnalysis = getAnalysis<CooperativeMatrixAnalysis>();