| // Copyright 2023 The IREE Authors |
| // |
| // Licensed under the Apache License v2.0 with LLVM Exceptions. |
| // See https://llvm.org/LICENSE.txt for license information. |
| // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception |
| |
| #include "compiler/plugins/input/Torch/InputConversion/Passes.h" |
| #include "iree/compiler/Dialect/LinalgExt/IR/LinalgExtDialect.h" |
| #include "iree/compiler/PluginAPI/Client.h" |
| #include "mlir/Dialect/MLProgram/IR/MLProgram.h" |
| #include "mlir/IR/BuiltinOps.h" |
| #include "mlir/Pass/PassManager.h" |
| #include "torch-mlir-dialects/Dialect/TMTensor/IR/TMTensorDialect.h" |
| #include "torch-mlir/Conversion/Passes.h" |
| #include "torch-mlir/Conversion/TorchOnnxToTorch/Passes.h" |
| #include "torch-mlir/Dialect/Torch/IR/TorchDialect.h" |
| #include "torch-mlir/Dialect/TorchConversion/IR/TorchConversionDialect.h" |
| #include "torch-mlir/Dialect/TorchConversion/Transforms/Passes.h" |
| |
| namespace mlir::iree_compiler { |
| |
| namespace { |
| |
| struct TorchOptions { |
| bool strictSymbolicShapes = true; |
| bool decompose = true; |
| void bindOptions(OptionsBinder &binder) { |
| static llvm::cl::OptionCategory category("Torch Input"); |
| binder.opt<bool>( |
| "iree-torch-use-strict-symbolic-shapes", strictSymbolicShapes, |
| llvm::cl::cat(category), |
| llvm::cl::desc("Forces dynamic shapes to be treated as strict")); |
| binder.opt<bool>("iree-torch-decompose-complex-ops", decompose, |
| llvm::cl::cat(category), |
| llvm::cl::desc("Decompose complex torch operations.")); |
| } |
| }; |
| |
| // The torch plugin provides dialects, passes and opt-in options. |
| // Therefore, it is appropriate for default activation. |
| struct TorchSession |
| : public PluginSession<TorchSession, TorchOptions, |
| PluginActivationPolicy::DefaultActivated> { |
| static void registerPasses() { |
| mlir::torch::registerTorchPasses(); |
| mlir::torch::registerTorchConversionPasses(); |
| mlir::torch::registerConversionPasses(); |
| mlir::torch::onnx_c::registerTorchOnnxToTorchPasses(); |
| TorchInput::registerTMTensorConversionPasses(); |
| } |
| |
| void onRegisterDialects(DialectRegistry ®istry) override { |
| registry.insert<torch::Torch::TorchDialect>(); |
| registry.insert<torch::TorchConversion::TorchConversionDialect>(); |
| registry.insert<mlir::torch::TMTensor::TMTensorDialect>(); |
| registry.insert<mlir::ml_program::MLProgramDialect>(); |
| registry.insert<IREE::LinalgExt::IREELinalgExtDialect>(); |
| } |
| |
| bool extendCustomInputConversionPassPipeline( |
| OpPassManager &passManager, std::string_view typeMnemonic) override { |
| if (typeMnemonic == "onnx") { |
| // ONNX input is a pre-processing step to torch. |
| mlir::torch::Torch::TorchLoweringPipelineOptions torchOnnxPipelineOptions; |
| torchOnnxPipelineOptions.decompose = options.decompose; |
| torchOnnxPipelineOptions.backendLegalOps = |
| TorchInput::BackendLegalOps::get(); |
| mlir::torch::Torch::createTorchOnnxToTorchBackendPipeline( |
| passManager, torchOnnxPipelineOptions); |
| } |
| |
| if (typeMnemonic == "torch" || typeMnemonic == "onnx") { |
| TorchInput::TorchToIREELoweringPipelineOptions torchOptions; |
| torchOptions.strictSymbolicShapes = options.strictSymbolicShapes; |
| torchOptions.decompose = options.decompose; |
| TorchInput::createTorchToIREEPipeline(passManager, torchOptions); |
| return true; |
| } |
| |
| // TODO: Retire the tm_tensor input pipeline once we are fully switched |
| // to the 'torch' pipeline, which handles everything from the 'torch' |
| // dialect down (vs just 'tm_tensor' which was converting a couple of |
| // ops to linalg). |
| if (typeMnemonic == "tm_tensor") { |
| passManager.addNestedPass<func::FuncOp>( |
| TorchInput::createConvertTMTensorToLinalgExtPass()); |
| return true; |
| } |
| return false; |
| } |
| |
| void populateCustomInputConversionTypes(StringSet<> &typeMnemonics) override { |
| typeMnemonics.insert("tm_tensor"); |
| typeMnemonics.insert("torch"); |
| typeMnemonics.insert("onnx"); |
| } |
| |
| void populateDetectedCustomInputConversionTypes( |
| ModuleOp &module, StringSet<> &typeMnemonics) override { |
| auto *ctx = module.getContext(); |
| const Dialect *torchDialect = ctx->getLoadedDialect("torch"); |
| const Dialect *torchConversionDialect = ctx->getLoadedDialect("torch_c"); |
| const Dialect *tmTensorDialect = ctx->getLoadedDialect("tm_tensor"); |
| |
| bool hasTorch = false; |
| bool hasOnnx = false; |
| // TODO: Retire the tm_tensor input pipeline |
| bool hasTmTensor = false; |
| |
| module.walk([&](Operation *op) { |
| Dialect *d = op->getDialect(); |
| if (d == torchDialect || d == torchConversionDialect) { |
| hasTorch = true; |
| } else if (d == tmTensorDialect) { |
| hasTmTensor = true; |
| } |
| return WalkResult::advance(); |
| }); |
| |
| for (auto funcOp : module.getOps<func::FuncOp>()) { |
| if (funcOp->getAttrOfType<mlir::IntegerAttr>( |
| "torch.onnx_meta.opset_version")) { |
| hasOnnx = true; |
| break; |
| } |
| } |
| |
| // ONNX is considered a superset of Torch. It runs all of the Torch |
| // pipelines with an extra ONNX-specific preprocessing step. |
| if (hasOnnx) { |
| typeMnemonics.insert("onnx"); |
| } else if (hasTorch) { |
| typeMnemonics.insert("torch"); |
| } |
| |
| if (hasTmTensor) { |
| typeMnemonics.insert("tm_tensor"); |
| } |
| } |
| }; |
| |
| } // namespace |
| |
| } // namespace mlir::iree_compiler |
| |
// Expands the global flag-registration machinery for TorchOptions so its
// bindOptions() flags appear on the compiler command line.
IREE_DEFINE_COMPILER_OPTION_FLAGS(::mlir::iree_compiler::TorchOptions);

// C-ABI entry point the plugin loader resolves by name to register this
// plugin under the "input_torch" id. Returns true on success.
extern "C" bool iree_register_compiler_plugin_input_torch(
    mlir::iree_compiler::PluginRegistrar *registrar) {
  registrar->registerPlugin<::mlir::iree_compiler::TorchSession>("input_torch");
  return true;
}