[CPU][ArmSME] Add convert-arith-to-arm-sme to the SME pipeline (#16409)
This fixes some breakages from #16350. It also re-enables scalable
vectorization in the SVE and SME tests, which would have caught these
breakages pre-merge.
diff --git a/compiler/src/iree/compiler/Codegen/LLVMCPU/BUILD.bazel b/compiler/src/iree/compiler/Codegen/LLVMCPU/BUILD.bazel
index 29f8e85..d88d21d 100644
--- a/compiler/src/iree/compiler/Codegen/LLVMCPU/BUILD.bazel
+++ b/compiler/src/iree/compiler/Codegen/LLVMCPU/BUILD.bazel
@@ -115,6 +115,7 @@
"@llvm-project//mlir:AffineUtils",
"@llvm-project//mlir:Analysis",
"@llvm-project//mlir:ArithDialect",
+ "@llvm-project//mlir:ArithToArmSME",
"@llvm-project//mlir:ArithToLLVM",
"@llvm-project//mlir:ArithTransforms",
"@llvm-project//mlir:ArmNeon2dToIntr",
diff --git a/compiler/src/iree/compiler/Codegen/LLVMCPU/CMakeLists.txt b/compiler/src/iree/compiler/Codegen/LLVMCPU/CMakeLists.txt
index 25bab50..048f11e 100644
--- a/compiler/src/iree/compiler/Codegen/LLVMCPU/CMakeLists.txt
+++ b/compiler/src/iree/compiler/Codegen/LLVMCPU/CMakeLists.txt
@@ -92,6 +92,7 @@
MLIRAffineUtils
MLIRAnalysis
MLIRArithDialect
+ MLIRArithToArmSME
MLIRArithToLLVM
MLIRArithTransforms
MLIRArmNeon2dToIntr
diff --git a/compiler/src/iree/compiler/Codegen/LLVMCPU/Passes.cpp b/compiler/src/iree/compiler/Codegen/LLVMCPU/Passes.cpp
index 238319e..2f05877 100644
--- a/compiler/src/iree/compiler/Codegen/LLVMCPU/Passes.cpp
+++ b/compiler/src/iree/compiler/Codegen/LLVMCPU/Passes.cpp
@@ -12,6 +12,7 @@
#include "iree/compiler/Codegen/LLVMCPU/Passes.h"
#include "llvm/ADT/TypeSwitch.h"
#include "llvm/Support/CommandLine.h"
+#include "mlir/Conversion/ArithToArmSME/ArithToArmSME.h"
#include "mlir/Conversion/ArmSMEToLLVM/ArmSMEToLLVM.h"
#include "mlir/Conversion/ArmSMEToSCF/ArmSMEToSCF.h"
#include "mlir/Conversion/ComplexToStandard/ComplexToStandard.h"
@@ -636,7 +637,9 @@
}
if (enableAArch64SME) {
- // Lower vector operations to Arm SME operations.
+ // (Arith, Vector) -> ArmSME
+ passManager.addNestedPass<func::FuncOp>(
+ mlir::createArithToArmSMEConversionPass());
passManager.addNestedPass<func::FuncOp>(
mlir::createConvertVectorToArmSMEPass());
passManager.addNestedPass<func::FuncOp>(
diff --git a/tests/e2e/matmul/BUILD.bazel b/tests/e2e/matmul/BUILD.bazel
index 9f0117d..60dc03d 100644
--- a/tests/e2e/matmul/BUILD.bazel
+++ b/tests/e2e/matmul/BUILD.bazel
@@ -30,6 +30,7 @@
compiler_flags = [
"--iree-opt-data-tiling=false",
"--iree-llvmcpu-enable-ukernels=none",
+ "--iree-llvmcpu-enable-scalable-vectorization",
],
generator = ":generate_e2e_matmul_tests",
generator_args = [
@@ -67,6 +68,7 @@
name = "e2e_matmul_arm_sme_nondt_%s_%s" % (dtype, size),
compiler_flags = [
"--iree-opt-data-tiling=false",
+ "--iree-llvmcpu-enable-scalable-vectorization",
],
generator = ":generate_e2e_matmul_tests",
generator_args = [
diff --git a/tests/e2e/matmul/CMakeLists.txt b/tests/e2e/matmul/CMakeLists.txt
index 6c03efa..7b515c5 100644
--- a/tests/e2e/matmul/CMakeLists.txt
+++ b/tests/e2e/matmul/CMakeLists.txt
@@ -27,6 +27,7 @@
"local-task"
COMPILER_FLAGS
"--iree-opt-data-tiling=false"
+ "--iree-llvmcpu-enable-scalable-vectorization"
TARGET_CPU_FEATURES_VARIANTS
"arm_64:sme:+sve,+sme"
)
@@ -48,6 +49,7 @@
"local-task"
COMPILER_FLAGS
"--iree-opt-data-tiling=false"
+ "--iree-llvmcpu-enable-scalable-vectorization"
TARGET_CPU_FEATURES_VARIANTS
"arm_64:sme:+sve,+sme"
)