Plumb sort op support through LLVM CPU pipeline. (#6378)
Fixes https://github.com/google/iree/issues/6154
diff --git a/iree/compiler/Codegen/LLVMCPU/BUILD b/iree/compiler/Codegen/LLVMCPU/BUILD
index 746369e..6863ab2 100644
--- a/iree/compiler/Codegen/LLVMCPU/BUILD
+++ b/iree/compiler/Codegen/LLVMCPU/BUILD
@@ -36,6 +36,7 @@
"//iree/compiler/Dialect/HAL/IR",
"//iree/compiler/Dialect/HAL/IR:HALDialect",
"//iree/compiler/Dialect/IREE/IR",
+ "//iree/compiler/Dialect/LinalgExt/Transforms",
"//iree/compiler/Dialect/Shape/IR",
"//iree/compiler/Dialect/Shape/Transforms",
"@llvm-project//llvm:Support",
diff --git a/iree/compiler/Codegen/LLVMCPU/CMakeLists.txt b/iree/compiler/Codegen/LLVMCPU/CMakeLists.txt
index 78ced25..3f61527 100644
--- a/iree/compiler/Codegen/LLVMCPU/CMakeLists.txt
+++ b/iree/compiler/Codegen/LLVMCPU/CMakeLists.txt
@@ -60,6 +60,7 @@
iree::compiler::Dialect::HAL::IR
iree::compiler::Dialect::HAL::IR::HALDialect
iree::compiler::Dialect::IREE::IR
+ iree::compiler::Dialect::LinalgExt::Transforms
iree::compiler::Dialect::Shape::IR
iree::compiler::Dialect::Shape::Transforms
PUBLIC
diff --git a/iree/compiler/Codegen/LLVMCPU/Passes.cpp b/iree/compiler/Codegen/LLVMCPU/Passes.cpp
index c4bf585..2ec5db7 100644
--- a/iree/compiler/Codegen/LLVMCPU/Passes.cpp
+++ b/iree/compiler/Codegen/LLVMCPU/Passes.cpp
@@ -7,6 +7,7 @@
#include "iree/compiler/Codegen/Passes.h"
#include "iree/compiler/Codegen/PassDetail.h"
+#include "iree/compiler/Dialect/LinalgExt/Transforms/Passes.h"
#include "iree/compiler/Dialect/Shape/Transforms/Passes.h"
#include "mlir/Conversion/SCFToStandard/SCFToStandard.h"
#include "mlir/Dialect/Linalg/Passes.h"
@@ -73,6 +74,9 @@
static void addLowerToLLVMPasses(
OpPassManager &passManager,
const LLVMCPUCodegenPassPipelineOptions &options) {
+ // LinalgExt -> SCF
+ passManager.addNestedPass<FuncOp>(linalg_ext::createLinalgExtToLoopsPass());
+
// Linalg -> SCF
passManager.addNestedPass<FuncOp>(createConvertLinalgToLoopsPass());
passManager.addNestedPass<FuncOp>(createCanonicalizerPass());
diff --git a/iree/compiler/InputConversion/MHLO/Passes.cpp b/iree/compiler/InputConversion/MHLO/Passes.cpp
index 2866a4e..ebf3147 100644
--- a/iree/compiler/InputConversion/MHLO/Passes.cpp
+++ b/iree/compiler/InputConversion/MHLO/Passes.cpp
@@ -52,10 +52,6 @@
passManager.addNestedPass<FuncOp>(createMHLOToMHLOPreprocessingPass());
- // Perform initial cleanup.
- passManager.addNestedPass<FuncOp>(mlir::createCanonicalizerPass());
- passManager.addNestedPass<FuncOp>(mlir::createCSEPass());
-
// Legalize input types. We do this after flattening tuples so that we don't
// have to deal with them.
// TODO(nicolasvasilache): createLegalizeInputTypesPass is old and does not
@@ -63,8 +59,15 @@
// when using ops with regions such as scf.for and linalg.generic.
passManager.addPass(mlir::iree_compiler::createLegalizeInputTypesPass());
+ // Perform initial cleanup. createLegalizeInputTypes could rewrite types. In
+ // this context, some operations could be folded away.
+ passManager.addNestedPass<FuncOp>(mlir::createCanonicalizerPass());
+ passManager.addNestedPass<FuncOp>(mlir::createCSEPass());
+
// Convert to Linalg. After this point, MHLO will be eliminated.
passManager.addNestedPass<FuncOp>(
+ mlir::iree_compiler::createConvertAndDistributeMHLOToLinalgExtPass());
+ passManager.addNestedPass<FuncOp>(
mlir::iree_compiler::createMHLOToLinalgOnTensorsPass());
// Note that some MHLO ops are left by the above and must resolve via
diff --git a/iree/test/e2e/xla_ops/BUILD b/iree/test/e2e/xla_ops/BUILD
index 744266f..2d089a6 100644
--- a/iree/test/e2e/xla_ops/BUILD
+++ b/iree/test/e2e/xla_ops/BUILD
@@ -139,6 +139,7 @@
"select.mlir",
"sine.mlir",
"slice.mlir",
+ "sort.mlir",
"sqrt.mlir",
"subtract.mlir",
"tanh.mlir",
@@ -149,7 +150,6 @@
include = ["*.mlir"],
exclude = [
"round.mlir",
- "sort.mlir",
],
),
compiler_flags = ["-iree-input-type=mhlo"],
diff --git a/iree/test/e2e/xla_ops/CMakeLists.txt b/iree/test/e2e/xla_ops/CMakeLists.txt
index 4813895..1d5bb4a 100644
--- a/iree/test/e2e/xla_ops/CMakeLists.txt
+++ b/iree/test/e2e/xla_ops/CMakeLists.txt
@@ -121,6 +121,7 @@
"select.mlir"
"sine.mlir"
"slice.mlir"
+ "sort.mlir"
"sqrt.mlir"
"subtract.mlir"
"tanh.mlir"