Add LinalgExt -> loops transform to cuda backend. (#6389)
This also enables mhlo.sort testing on cuda backend.
diff --git a/iree/compiler/Codegen/LLVMGPU/BUILD b/iree/compiler/Codegen/LLVMGPU/BUILD
index b348c5f..0b72c14 100644
--- a/iree/compiler/Codegen/LLVMGPU/BUILD
+++ b/iree/compiler/Codegen/LLVMGPU/BUILD
@@ -34,6 +34,7 @@
"//iree/compiler/Codegen/Utils",
"//iree/compiler/Dialect/HAL/IR",
"//iree/compiler/Dialect/IREE/IR",
+ "//iree/compiler/Dialect/LinalgExt/Transforms",
"//iree/compiler/Dialect/Shape/Transforms",
"@llvm-project//mlir:Affine",
"@llvm-project//mlir:AffineToStandard",
diff --git a/iree/compiler/Codegen/LLVMGPU/CMakeLists.txt b/iree/compiler/Codegen/LLVMGPU/CMakeLists.txt
index d02522f..23205b2 100644
--- a/iree/compiler/Codegen/LLVMGPU/CMakeLists.txt
+++ b/iree/compiler/Codegen/LLVMGPU/CMakeLists.txt
@@ -54,6 +54,7 @@
iree::compiler::Codegen::Utils
iree::compiler::Dialect::HAL::IR
iree::compiler::Dialect::IREE::IR
+ iree::compiler::Dialect::LinalgExt::Transforms
iree::compiler::Dialect::Shape::Transforms
tensorflow::mlir_hlo
PUBLIC
diff --git a/iree/compiler/Codegen/LLVMGPU/Passes.cpp b/iree/compiler/Codegen/LLVMGPU/Passes.cpp
index acaa233..1884407 100644
--- a/iree/compiler/Codegen/LLVMGPU/Passes.cpp
+++ b/iree/compiler/Codegen/LLVMGPU/Passes.cpp
@@ -7,6 +7,7 @@
#include "iree/compiler/Codegen/Passes.h"
#include "iree/compiler/Codegen/PassDetail.h"
+#include "iree/compiler/Dialect/LinalgExt/Transforms/Passes.h"
#include "iree/compiler/Dialect/Shape/Transforms/Passes.h"
#include "mlir/Conversion/AffineToStandard/AffineToStandard.h"
#include "mlir/Conversion/SCFToStandard/SCFToStandard.h"
@@ -75,6 +76,10 @@
pm.addNestedPass<ModuleOp>(createCanonicalizerPass());
pm.addNestedPass<ModuleOp>(createCSEPass());
+ // LinalgExt -> SCF
+ pm.nest<ModuleOp>().addNestedPass<FuncOp>(
+ linalg_ext::createLinalgExtToLoopsPass());
+
// Linalg -> SCF
pm.nest<ModuleOp>().addNestedPass<FuncOp>(createConvertLinalgToLoopsPass());
pm.nest<ModuleOp>().addNestedPass<FuncOp>(createCanonicalizerPass());
diff --git a/iree/test/e2e/xla_ops/BUILD b/iree/test/e2e/xla_ops/BUILD
index 2d089a6..0eb498c 100644
--- a/iree/test/e2e/xla_ops/BUILD
+++ b/iree/test/e2e/xla_ops/BUILD
@@ -65,6 +65,7 @@
"select.mlir",
"sine.mlir",
"slice.mlir",
+ "sort.mlir",
"sqrt.mlir",
"subtract.mlir",
"tanh.mlir",
@@ -76,7 +77,6 @@
exclude = [
"round.mlir",
"scatter.mlir",
- "sort.mlir",
],
),
compiler_flags = ["-iree-input-type=mhlo"],
diff --git a/iree/test/e2e/xla_ops/CMakeLists.txt b/iree/test/e2e/xla_ops/CMakeLists.txt
index 1d5bb4a..f5219f0 100644
--- a/iree/test/e2e/xla_ops/CMakeLists.txt
+++ b/iree/test/e2e/xla_ops/CMakeLists.txt
@@ -55,6 +55,7 @@
"select.mlir"
"sine.mlir"
"slice.mlir"
+ "sort.mlir"
"sqrt.mlir"
"subtract.mlir"
"tanh.mlir"