add decompose complex ops pass to torch_to_iree (#14992)
diff --git a/compiler/plugins/input/Torch/torch-iree/InputConversion/Passes.cpp b/compiler/plugins/input/Torch/torch-iree/InputConversion/Passes.cpp
index c3fde5a..9d44ec4 100644
--- a/compiler/plugins/input/Torch/torch-iree/InputConversion/Passes.cpp
+++ b/compiler/plugins/input/Torch/torch-iree/InputConversion/Passes.cpp
@@ -33,6 +33,10 @@
// backends. We do this first as it tends to involve pattern-matching against
// constants, (e.g. dimensions which must be constant in a ranked programming
// model) and those constants get somewhat obscured by TorchToArith.
+ llvm::ArrayRef<std::string> emptyArrayRef;
+
+ pm.addNestedPass<func::FuncOp>(
+ torch::Torch::createDecomposeComplexOpsPass(emptyArrayRef));
pm.addNestedPass<func::FuncOp>(torch::createConvertTorchToTMTensorPass());
pm.addNestedPass<func::FuncOp>(torch::createConvertTorchToLinalgPass());
pm.addNestedPass<func::FuncOp>(torch::createConvertTorchToSCFPass());
diff --git a/compiler/plugins/input/Torch/torch-iree/InputConversion/test/torch_to_iree.mlir b/compiler/plugins/input/Torch/torch-iree/InputConversion/test/torch_to_iree.mlir
index b9655dc..e89142b 100644
--- a/compiler/plugins/input/Torch/torch-iree/InputConversion/test/torch_to_iree.mlir
+++ b/compiler/plugins/input/Torch/torch-iree/InputConversion/test/torch_to_iree.mlir
@@ -27,3 +27,28 @@
util.global private @_params.classifier.weight {noinline} : tensor<30x20xf32>
util.global private @_params.classifier.bias {noinline} : tensor<30xf32>
}
+
+// -----
+
+// Verify we can decompose complex ops
+// CHECK: func @main(%arg0: tensor<2x3x4xf32>) -> (tensor<2x3x4xf32>, tensor<2x3x4xf32>)
+// CHECK: tensor.empty
+module {
+ func.func @main(%arg0: !torch.vtensor<[2,3,4],f32>) -> (!torch.vtensor<[2,3,4],f32>, !torch.vtensor<[2,3,4],f32>) {
+ %int2 = torch.constant.int 2
+ %int3 = torch.constant.int 3
+ %int4 = torch.constant.int 4
+ %0 = torch.prim.ListConstruct %int2, %int3, %int4 : (!torch.int, !torch.int, !torch.int) -> !torch.list<int>
+ %int12 = torch.constant.int 12
+ %int4_0 = torch.constant.int 4
+ %int1 = torch.constant.int 1
+ %1 = torch.prim.ListConstruct %int12, %int4_0, %int1 : (!torch.int, !torch.int, !torch.int) -> !torch.list<int>
+ %none = torch.constant.none
+ %none_1 = torch.constant.none %cpu = torch.constant.device "cpu"
+ %false = torch.constant.bool false
+ %2 = torch.aten.empty_strided %0, %1, %none, %none_1, %cpu, %false : !torch.list<int>, !torch.list<int>, !torch.none, !torch.none, !torch.Device, !torch.bool -> !torch.vtensor<[2,3,4],f32>
+ %false_2 = torch.constant.bool false
+ %3 = torch.aten.copy %arg0, %2, %false_2 : !torch.vtensor<[2,3,4],f32>, !torch.vtensor<[2,3,4],f32>, !torch.bool -> !torch.vtensor<[2,3,4],f32>
+ return %3, %3 : !torch.vtensor<[2,3,4],f32>, !torch.vtensor<[2,3,4],f32>
+ }
+}