Added ComplexToStandardPass to the LLVM compilation pipelines (#12273)
ComplexToStandard handles the decomposition of many complex
floating-point operations. It would be preferable to perform these
decompositions as part of ConvertToLLVM; however, some of them
require the PolynomialApproximation pass and therefore must be
performed earlier in the pipeline.
diff --git a/compiler/src/iree/compiler/Codegen/LLVMCPU/BUILD b/compiler/src/iree/compiler/Codegen/LLVMCPU/BUILD
index 744c688..5ebf191 100644
--- a/compiler/src/iree/compiler/Codegen/LLVMCPU/BUILD
+++ b/compiler/src/iree/compiler/Codegen/LLVMCPU/BUILD
@@ -74,6 +74,7 @@
"@llvm-project//mlir:ArmNeonDialect",
"@llvm-project//mlir:BufferizationDialect",
"@llvm-project//mlir:ComplexToLLVM",
+ "@llvm-project//mlir:ComplexToStandard",
"@llvm-project//mlir:ControlFlowToLLVM",
"@llvm-project//mlir:DialectUtils",
"@llvm-project//mlir:FuncDialect",
diff --git a/compiler/src/iree/compiler/Codegen/LLVMCPU/CMakeLists.txt b/compiler/src/iree/compiler/Codegen/LLVMCPU/CMakeLists.txt
index 948f734..693c113 100644
--- a/compiler/src/iree/compiler/Codegen/LLVMCPU/CMakeLists.txt
+++ b/compiler/src/iree/compiler/Codegen/LLVMCPU/CMakeLists.txt
@@ -54,6 +54,7 @@
MLIRArmNeonDialect
MLIRBufferizationDialect
MLIRComplexToLLVM
+ MLIRComplexToStandard
MLIRControlFlowToLLVM
MLIRFuncDialect
MLIRFuncToLLVM
diff --git a/compiler/src/iree/compiler/Codegen/LLVMCPU/Passes.cpp b/compiler/src/iree/compiler/Codegen/LLVMCPU/Passes.cpp
index 91bb0d7..18dd53f 100644
--- a/compiler/src/iree/compiler/Codegen/LLVMCPU/Passes.cpp
+++ b/compiler/src/iree/compiler/Codegen/LLVMCPU/Passes.cpp
@@ -16,6 +16,7 @@
#include "iree/compiler/Codegen/Utils/Utils.h"
#include "llvm/ADT/TypeSwitch.h"
#include "llvm/Support/CommandLine.h"
+#include "mlir/Conversion/ComplexToStandard/ComplexToStandard.h"
#include "mlir/Conversion/ReconcileUnrealizedCasts/ReconcileUnrealizedCasts.h"
#include "mlir/Conversion/SCFToControlFlow/SCFToControlFlow.h"
#include "mlir/Dialect/Arith/Transforms/Passes.h"
@@ -717,6 +718,9 @@
passManager.addPass(arith::createConstantBufferizePass());
passManager.addPass(createFoldTensorExtractOpPass());
+ // Handle complex operation conversion.
+ passManager.addPass(createConvertComplexToStandardPass());
+
// math dialect elementry functions -> polynomial form.
passManager.addNestedPass<func::FuncOp>(createPolynomialApproximationPass());
diff --git a/compiler/src/iree/compiler/Codegen/LLVMGPU/BUILD b/compiler/src/iree/compiler/Codegen/LLVMGPU/BUILD
index 2f41fe4..4710e4e 100644
--- a/compiler/src/iree/compiler/Codegen/LLVMGPU/BUILD
+++ b/compiler/src/iree/compiler/Codegen/LLVMGPU/BUILD
@@ -62,6 +62,7 @@
"@llvm-project//mlir:ArithToLLVM",
"@llvm-project//mlir:ArithTransforms",
"@llvm-project//mlir:BufferizationDialect",
+ "@llvm-project//mlir:ComplexToStandard",
"@llvm-project//mlir:ControlFlowToLLVM",
"@llvm-project//mlir:FuncDialect",
"@llvm-project//mlir:FuncToLLVM",
diff --git a/compiler/src/iree/compiler/Codegen/LLVMGPU/CMakeLists.txt b/compiler/src/iree/compiler/Codegen/LLVMGPU/CMakeLists.txt
index c4da51c..f40341f 100644
--- a/compiler/src/iree/compiler/Codegen/LLVMGPU/CMakeLists.txt
+++ b/compiler/src/iree/compiler/Codegen/LLVMGPU/CMakeLists.txt
@@ -48,6 +48,7 @@
MLIRArithToLLVM
MLIRArithTransforms
MLIRBufferizationDialect
+ MLIRComplexToStandard
MLIRControlFlowToLLVM
MLIRFuncDialect
MLIRFuncToLLVM
diff --git a/compiler/src/iree/compiler/Codegen/LLVMGPU/Passes.cpp b/compiler/src/iree/compiler/Codegen/LLVMGPU/Passes.cpp
index b13a5b9..37d0ab3 100644
--- a/compiler/src/iree/compiler/Codegen/LLVMGPU/Passes.cpp
+++ b/compiler/src/iree/compiler/Codegen/LLVMGPU/Passes.cpp
@@ -11,6 +11,7 @@
#include "iree/compiler/Codegen/PassDetail.h"
#include "iree/compiler/Codegen/Utils/MarkerUtils.h"
#include "mlir/Conversion/AffineToStandard/AffineToStandard.h"
+#include "mlir/Conversion/ComplexToStandard/ComplexToStandard.h"
#include "mlir/Conversion/GPUToNVVM/GPUToNVVMPass.h"
#include "mlir/Conversion/SCFToControlFlow/SCFToControlFlow.h"
#include "mlir/Conversion/VectorToGPU/VectorToGPU.h"
@@ -403,6 +404,9 @@
pm.addNestedPass<func::FuncOp>(createCanonicalizerPass());
pm.addNestedPass<func::FuncOp>(createCSEPass());
+ // Handle complex operation conversion.
+ pm.addPass(createConvertComplexToStandardPass());
+
// math dialect elementry functions -> polynomial form.
pm.addNestedPass<func::FuncOp>(createPolynomialApproximationPass());
diff --git a/tests/e2e/xla_ops/BUILD b/tests/e2e/xla_ops/BUILD
index 41d711e..3c90766 100644
--- a/tests/e2e/xla_ops/BUILD
+++ b/tests/e2e/xla_ops/BUILD
@@ -32,6 +32,7 @@
"broadcast_in_dim.mlir",
"clamp.mlir",
"compare.mlir",
+ "complex.mlir",
"concatenate.mlir",
"constant.mlir",
"convert.mlir",
@@ -113,6 +114,7 @@
"broadcast_in_dim.mlir",
"clamp.mlir",
"compare.mlir",
+ "complex.mlir",
"concatenate.mlir",
"constant.mlir",
"convert.mlir",
@@ -193,6 +195,7 @@
"broadcast_in_dim.mlir",
"clamp.mlir",
"compare.mlir",
+ "complex.mlir",
"concatenate.mlir",
"constant.mlir",
"convert.mlir",
@@ -262,6 +265,7 @@
"broadcast_in_dim.mlir",
"clamp.mlir",
"compare.mlir",
+ "complex.mlir",
"concatenate.mlir",
"constant.mlir",
"convert.mlir",
@@ -333,6 +337,7 @@
"broadcast_in_dim.mlir",
"clamp.mlir",
"compare.mlir",
+ "complex.mlir",
"concatenate.mlir",
"constant.mlir",
"convert.mlir",
@@ -405,6 +410,7 @@
"broadcast_in_dim.mlir",
"clamp.mlir",
"compare.mlir",
+ "complex.mlir",
"concatenate.mlir",
"constant.mlir",
"convert.mlir",
diff --git a/tests/e2e/xla_ops/CMakeLists.txt b/tests/e2e/xla_ops/CMakeLists.txt
index 20958d0..7bb80c2 100644
--- a/tests/e2e/xla_ops/CMakeLists.txt
+++ b/tests/e2e/xla_ops/CMakeLists.txt
@@ -23,6 +23,7 @@
"broadcast_in_dim.mlir"
"clamp.mlir"
"compare.mlir"
+ "complex.mlir"
"concatenate.mlir"
"constant.mlir"
"convert.mlir"
@@ -98,6 +99,7 @@
"broadcast_in_dim.mlir"
"clamp.mlir"
"compare.mlir"
+ "complex.mlir"
"concatenate.mlir"
"constant.mlir"
"convert.mlir"
@@ -173,6 +175,7 @@
"broadcast_in_dim.mlir"
"clamp.mlir"
"compare.mlir"
+ "complex.mlir"
"concatenate.mlir"
"constant.mlir"
"convert.mlir"
@@ -241,6 +244,7 @@
"broadcast_in_dim.mlir"
"clamp.mlir"
"compare.mlir"
+ "complex.mlir"
"concatenate.mlir"
"constant.mlir"
"convert.mlir"
@@ -308,6 +312,7 @@
"broadcast_in_dim.mlir"
"clamp.mlir"
"compare.mlir"
+ "complex.mlir"
"concatenate.mlir"
"constant.mlir"
"convert.mlir"
@@ -374,6 +379,7 @@
"broadcast_in_dim.mlir"
"clamp.mlir"
"compare.mlir"
+ "complex.mlir"
"concatenate.mlir"
"constant.mlir"
"convert.mlir"
@@ -448,6 +454,7 @@
"broadcast_in_dim.mlir"
# "clamp.mlir" # TODO(#10906): fix (i8/i16?)
# "compare.mlir" # TODO(#10906): fix (i8/i16?)
+ "complex.mlir"
"concatenate.mlir"
"constant.mlir"
# "convert.mlir" # TODO(#10906): fix (i8/i16?)
diff --git a/tests/e2e/xla_ops/complex.mlir b/tests/e2e/xla_ops/complex.mlir
new file mode 100644
index 0000000..63898c5
--- /dev/null
+++ b/tests/e2e/xla_ops/complex.mlir
@@ -0,0 +1,23 @@
+func.func @math_sin() { // e2e check of mhlo.sine applied to a tensor<4xcomplex<f32>>.
+ %real = util.unfoldable_constant dense<[0., 1., 1., -1.]> : tensor<4xf32>
+ %imag = util.unfoldable_constant dense<[0., 1., -1., 1.]> : tensor<4xf32>
+ %complex = "mhlo.complex"(%real, %imag) : (tensor<4xf32>, tensor<4xf32>) -> tensor<4xcomplex<f32>> // z = 0+0i, 1+1i, 1-1i, -1+1i
+ %result = "mhlo.sine"(%complex) : (tensor<4xcomplex<f32>>) -> tensor<4xcomplex<f32>>
+ %result_real = "mhlo.real"(%result) : (tensor<4xcomplex<f32>>) -> tensor<4xf32> // split the result so each part is compared as plain f32
+ %result_imag = "mhlo.imag"(%result) : (tensor<4xcomplex<f32>>) -> tensor<4xf32>
+ check.expect_almost_eq_const(%result_real, dense<[0., 1.29846, 1.29846, -1.29846]> : tensor<4xf32>) : tensor<4xf32> // expected Re(sin z) = sin(re) * cosh(im)
+ check.expect_almost_eq_const(%result_imag, dense<[0., 0.634964, -0.634964, 0.634964]> : tensor<4xf32>) : tensor<4xf32> // expected Im(sin z) = cos(re) * sinh(im)
+ return
+}
+
+func.func @math_exp() { // e2e check of mhlo.exponential applied to a tensor<4xcomplex<f32>>.
+ %real = util.unfoldable_constant dense<[0., 1., 1., -1.]> : tensor<4xf32>
+ %imag = util.unfoldable_constant dense<[0., 1., -1., 1.]> : tensor<4xf32>
+ %complex = "mhlo.complex"(%real, %imag) : (tensor<4xf32>, tensor<4xf32>) -> tensor<4xcomplex<f32>> // z = 0+0i, 1+1i, 1-1i, -1+1i
+ %result = "mhlo.exponential"(%complex) : (tensor<4xcomplex<f32>>) -> tensor<4xcomplex<f32>>
+ %result_real = "mhlo.real"(%result) : (tensor<4xcomplex<f32>>) -> tensor<4xf32> // split the result so each part is compared as plain f32
+ %result_imag = "mhlo.imag"(%result) : (tensor<4xcomplex<f32>>) -> tensor<4xf32>
+ check.expect_almost_eq_const(%result_real, dense<[1., 1.46869, 1.46869, 0.19876]> : tensor<4xf32>) : tensor<4xf32> // expected Re(exp z) = exp(re) * cos(im)
+ check.expect_almost_eq_const(%result_imag, dense<[0., 2.28735, -2.28735, 0.30956]> : tensor<4xf32>) : tensor<4xf32> // expected Im(exp z) = exp(re) * sin(im)
+ return
+}