add decompose complex ops pass to torch_to_iree (#14992)
diff --git a/compiler/plugins/input/Torch/torch-iree/InputConversion/Passes.cpp b/compiler/plugins/input/Torch/torch-iree/InputConversion/Passes.cpp index c3fde5a..9d44ec4 100644 --- a/compiler/plugins/input/Torch/torch-iree/InputConversion/Passes.cpp +++ b/compiler/plugins/input/Torch/torch-iree/InputConversion/Passes.cpp
@@ -33,6 +33,10 @@ // backends. We do this first as it tends to involve pattern-matching against // constants, (e.g. dimensions which must be constant in a ranked programming // model) and those constants get somewhat obscured by TorchToArith. + llvm::ArrayRef<std::string> emptyArrayRef; + + pm.addNestedPass<func::FuncOp>( + torch::Torch::createDecomposeComplexOpsPass(emptyArrayRef)); pm.addNestedPass<func::FuncOp>(torch::createConvertTorchToTMTensorPass()); pm.addNestedPass<func::FuncOp>(torch::createConvertTorchToLinalgPass()); pm.addNestedPass<func::FuncOp>(torch::createConvertTorchToSCFPass());
diff --git a/compiler/plugins/input/Torch/torch-iree/InputConversion/test/torch_to_iree.mlir b/compiler/plugins/input/Torch/torch-iree/InputConversion/test/torch_to_iree.mlir index b9655dc..e89142b 100644 --- a/compiler/plugins/input/Torch/torch-iree/InputConversion/test/torch_to_iree.mlir +++ b/compiler/plugins/input/Torch/torch-iree/InputConversion/test/torch_to_iree.mlir
@@ -27,3 +27,28 @@ util.global private @_params.classifier.weight {noinline} : tensor<30x20xf32> util.global private @_params.classifier.bias {noinline} : tensor<30xf32> } + +// ----- + +// Verify we can decompose complex ops +// CHECK: func @main(%arg0: tensor<2x3x4xf32>) -> (tensor<2x3x4xf32>, tensor<2x3x4xf32>) +// CHECK: tensor.empty +module { + func.func @main(%arg0: !torch.vtensor<[2,3,4],f32>) -> (!torch.vtensor<[2,3,4],f32>, !torch.vtensor<[2,3,4],f32>) { + %int2 = torch.constant.int 2 + %int3 = torch.constant.int 3 + %int4 = torch.constant.int 4 + %0 = torch.prim.ListConstruct %int2, %int3, %int4 : (!torch.int, !torch.int, !torch.int) -> !torch.list<int> + %int12 = torch.constant.int 12 + %int4_0 = torch.constant.int 4 + %int1 = torch.constant.int 1 + %1 = torch.prim.ListConstruct %int12, %int4_0, %int1 : (!torch.int, !torch.int, !torch.int) -> !torch.list<int> + %none = torch.constant.none + %none_1 = torch.constant.none %cpu = torch.constant.device "cpu" + %false = torch.constant.bool false + %2 = torch.aten.empty_strided %0, %1, %none, %none_1, %cpu, %false : !torch.list<int>, !torch.list<int>, !torch.none, !torch.none, !torch.Device, !torch.bool -> !torch.vtensor<[2,3,4],f32> + %false_2 = torch.constant.bool false + %3 = torch.aten.copy %arg0, %2, %false_2 : !torch.vtensor<[2,3,4],f32>, !torch.vtensor<[2,3,4],f32>, !torch.bool -> !torch.vtensor<[2,3,4],f32> + return %3, %3 : !torch.vtensor<[2,3,4],f32>, !torch.vtensor<[2,3,4],f32> + } +}