add decompose complex ops pass to torch_to_iree (#14992)

diff --git a/compiler/plugins/input/Torch/torch-iree/InputConversion/Passes.cpp b/compiler/plugins/input/Torch/torch-iree/InputConversion/Passes.cpp
index c3fde5a..9d44ec4 100644
--- a/compiler/plugins/input/Torch/torch-iree/InputConversion/Passes.cpp
+++ b/compiler/plugins/input/Torch/torch-iree/InputConversion/Passes.cpp
@@ -33,6 +33,10 @@
   // backends. We do this first as it tends to involve pattern-matching against
   // constants, (e.g. dimensions which must be constant in a ranked programming
   // model) and those constants get somewhat obscured by TorchToArith.
+  llvm::ArrayRef<std::string> emptyArrayRef;
+
+  pm.addNestedPass<func::FuncOp>(
+      torch::Torch::createDecomposeComplexOpsPass(emptyArrayRef));
   pm.addNestedPass<func::FuncOp>(torch::createConvertTorchToTMTensorPass());
   pm.addNestedPass<func::FuncOp>(torch::createConvertTorchToLinalgPass());
   pm.addNestedPass<func::FuncOp>(torch::createConvertTorchToSCFPass());
diff --git a/compiler/plugins/input/Torch/torch-iree/InputConversion/test/torch_to_iree.mlir b/compiler/plugins/input/Torch/torch-iree/InputConversion/test/torch_to_iree.mlir
index b9655dc..e89142b 100644
--- a/compiler/plugins/input/Torch/torch-iree/InputConversion/test/torch_to_iree.mlir
+++ b/compiler/plugins/input/Torch/torch-iree/InputConversion/test/torch_to_iree.mlir
@@ -27,3 +27,28 @@
   util.global private @_params.classifier.weight {noinline} : tensor<30x20xf32>
   util.global private @_params.classifier.bias {noinline} : tensor<30xf32>
 }
+
+// -----
+
+// Verify we can decompose complex ops
+// CHECK: func @main(%arg0: tensor<2x3x4xf32>) -> (tensor<2x3x4xf32>, tensor<2x3x4xf32>) 
+// CHECK: tensor.empty
+module {
+  func.func @main(%arg0: !torch.vtensor<[2,3,4],f32>) -> (!torch.vtensor<[2,3,4],f32>, !torch.vtensor<[2,3,4],f32>) {
+    %int2 = torch.constant.int 2
+    %int3 = torch.constant.int 3
+    %int4 = torch.constant.int 4
+    %0 = torch.prim.ListConstruct %int2, %int3, %int4 : (!torch.int, !torch.int, !torch.int) -> !torch.list<int>
+    %int12 = torch.constant.int 12
+    %int4_0 = torch.constant.int 4
+    %int1 = torch.constant.int 1
+    %1 = torch.prim.ListConstruct %int12, %int4_0, %int1 : (!torch.int, !torch.int, !torch.int) -> !torch.list<int>
+    %none = torch.constant.none
+    %none_1 = torch.constant.none    %cpu = torch.constant.device "cpu"
+    %false = torch.constant.bool false
+    %2 = torch.aten.empty_strided %0, %1, %none, %none_1, %cpu, %false : !torch.list<int>, !torch.list<int>, !torch.none, !torch.none, !torch.Device, !torch.bool -> !torch.vtensor<[2,3,4],f32>
+    %false_2 = torch.constant.bool false
+    %3 = torch.aten.copy %arg0, %2, %false_2 : !torch.vtensor<[2,3,4],f32>, !torch.vtensor<[2,3,4],f32>, !torch.bool -> !torch.vtensor<[2,3,4],f32>
+    return %3, %3 : !torch.vtensor<[2,3,4],f32>, !torch.vtensor<[2,3,4],f32>
+  }
+}