blob: a0682b5c6e4190bd9179d5e17bd5b7c21e81bfc7 [file] [log] [blame]
// Copyright 2022 The IREE Authors
//
// Licensed under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
#include "compiler/plugins/input/Torch/InputConversion/Passes.h"
#include "iree/compiler/Dialect/Util/IR/UtilOps.h"
#include "mlir/Dialect/MemRef/Transforms/Passes.h"
#include "mlir/Pass/PassManager.h"
#include "mlir/Transforms/Passes.h"
#include "torch-mlir/Conversion/TorchConversionToMLProgram/TorchConversionToMLProgram.h"
#include "torch-mlir/Conversion/TorchToArith/TorchToArith.h"
#include "torch-mlir/Conversion/TorchToLinalg/TorchToLinalg.h"
#include "torch-mlir/Conversion/TorchToSCF/TorchToSCF.h"
#include "torch-mlir/Conversion/TorchToTMTensor/TorchToTMTensor.h"
#include "torch-mlir/Conversion/TorchToTensor/TorchToTensor.h"
#include "torch-mlir/Dialect/Torch/Transforms/Passes.h"
#include "torch-mlir/Dialect/TorchConversion/Transforms/Passes.h"
namespace mlir::iree_compiler::TorchInput {
namespace {
// Tablegen-generated pass registration boilerplate. The GEN_PASS_REGISTRATION
// section of Passes.h.inc presumably provides registerPasses(), which is
// invoked from registerTMTensorConversionPasses() below. Kept in an anonymous
// namespace so the generated symbols stay local to this translation unit.
#define GEN_PASS_REGISTRATION
#include "compiler/plugins/input/Torch/InputConversion/Passes.h.inc" // IWYU pragma: export
} // namespace
// Populates |pm| with the lowering pipeline that takes IR satisfying the
// Torch backend contract (torch dialect) down to legal IREE input:
// linalg-on-tensors plus arith/scf/tensor and IREE util dialects.
// Pass ordering here is load-bearing; see the inline notes before reordering.
void createTorchToIREEPipeline(
    OpPassManager &pm, const TorchToIREELoweringPipelineOptions &options) {
  // This pipeline adapted from
  // createTorchBackendToLinalgOnTensorsBackendPipeline. Keep in sync with
  // additions there. Lower to linalg + guards which is the input to codegen
  // backends. We do this first as it tends to involve pattern-matching against
  // constants, (e.g. dimensions which must be constant in a ranked programming
  // model) and those constants get somewhat obscured by TorchToArith.
  llvm::ArrayRef<std::string> emptyArrayRef;
  // Dynamic shape bindings add a lot of structure to the IR which we prefer to
  // leverage and eliminate prior to any other activity, so do this first.
  pm.addNestedPass<func::FuncOp>(createBindSymbolicShapesPass());
  if (options.strictSymbolicShapes) {
    pm.addNestedPass<func::FuncOp>(createSetStrictSymbolicShapesPass());
    // Run canonicalization in case any previously non-strict dynamic code can
    // now be simplified.
    pm.addNestedPass<func::FuncOp>(createCanonicalizerPass());
  }
  // Quantization-related rewrites run before the generic decomposition so
  // their patterns still see the original (un-decomposed) op forms.
  // NOTE(review): the exact interplay of these quant passes with
  // DecomposeComplexOps is assumed from their placement — confirm upstream
  // before reordering.
  pm.addNestedPass<func::FuncOp>(createBitCastQuantTensorPass());
  pm.addNestedPass<func::FuncOp>(
      torch::Torch::createReduceOpVariantsPass(llvm::StringRef()));
  pm.addNestedPass<func::FuncOp>(
      mlir::torch::TorchConversion::createConvertCustomQuantOpPass());
  // Decompose with the default (empty) backend-legal-ops list.
  pm.addNestedPass<func::FuncOp>(
      torch::Torch::createDecomposeComplexOpsPass(emptyArrayRef));
  pm.addNestedPass<func::FuncOp>(torch::Torch::createFuseQuantizedOpsPass());
  pm.addNestedPass<func::FuncOp>(createCanonicalizerPass());
  pm.addNestedPass<func::FuncOp>(torch::Torch::createScalarizeShapesPass());
  // Conversion cascade: torch -> TMTensor -> LinalgExt, then torch ->
  // tensor/linalg/scf/arith, and finally the module-level torch_c conversion
  // to ml_program.
  pm.addNestedPass<func::FuncOp>(torch::createConvertTorchToTMTensorPass());
  pm.addNestedPass<func::FuncOp>(
      TorchInput::createConvertTMTensorToLinalgExtPass());
  pm.addNestedPass<func::FuncOp>(torch::createConvertTorchToTensorPass());
  pm.addNestedPass<func::FuncOp>(torch::createConvertTorchToLinalgPass());
  pm.addNestedPass<func::FuncOp>(createCSEPass());
  pm.addNestedPass<func::FuncOp>(torch::createConvertTorchToSCFPass());
  pm.addNestedPass<func::FuncOp>(torch::createConvertTorchToArithPass());
  pm.addPass(torch::createConvertTorchConversionToMLProgramPass());
  pm.addNestedPass<func::FuncOp>(memref::createExpandOpsPass());
  // Clean up any non-canonical code introduced above.
  pm.addNestedPass<func::FuncOp>(createCanonicalizerPass());
  // Resolve `dim` ops on tensors (which currently live in the `memref`
  // dialect for some reason -- we don't have memrefs at this level).
  pm.addNestedPass<func::FuncOp>(
      memref::createResolveShapedTypeResultDimsPass());
  // The resolution of `dim` ops tends to create identical ops. CSE them.
  pm.addNestedPass<func::FuncOp>(createCSEPass());
  // Regular function calls in torch have to be inlined presently. In the
  // future, we would like to support async invocation, which will operate
  // differently and would not be subject to inlining.
  pm.addPass(mlir::createInlinerPass());
  // Convert func ops to IREE util funcs; note that subsequent nested passes
  // target IREE::Util::FuncOp rather than func::FuncOp.
  pm.addPass(createFuncConversionPass());
  pm.addNestedPass<IREE::Util::FuncOp>(createCanonicalizerPass());
  // Drop symbols left unreferenced after inlining/func conversion.
  pm.addPass(createSymbolDCEPass());
  // Finish the type conversion from `torch` types to the types of the
  // linalg-on-tensors backend contract.
  pm.addNestedPass<IREE::Util::FuncOp>(
      torch::TorchConversion::createFinalizingBackendTypeConversionPass());
}
void registerTMTensorConversionPasses() {
// Generated.
registerPasses();
mlir::PassPipelineRegistration<TorchToIREELoweringPipelineOptions>(
"torch-to-iree",
"Pipeline to lower from the Torch backend contract to legal IREE input.",
createTorchToIREEPipeline);
}
} // namespace mlir::iree_compiler::TorchInput