blob: 27977376aa0aa7be386801eb3f75392242e8b8fd [file] [log] [blame]
// Copyright 2021 The IREE Authors
//
// Licensed under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
#include "iree/compiler/Dialect/Modules/VMVX/Transforms/Passes.h"
#include <memory>
#include "iree-dialects/Dialect/LinalgExt/Transforms/Passes.h"
#include "iree/compiler/Codegen/Passes.h"
#include "iree/compiler/Dialect/HAL/Transforms/Passes.h"
#include "mlir/Conversion/AffineToStandard/AffineToStandard.h"
#include "mlir/Conversion/SCFToControlFlow/SCFToControlFlow.h"
#include "mlir/Conversion/VectorToSCF/VectorToSCF.h"
#include "mlir/Dialect/Affine/Passes.h"
#include "mlir/Dialect/Arithmetic/Transforms/Passes.h"
#include "mlir/Dialect/Func/Transforms/Passes.h"
#include "mlir/Dialect/Linalg/Passes.h"
#include "mlir/Dialect/MemRef/Transforms/Passes.h"
#include "mlir/Pass/PassRegistry.h"
#include "mlir/Transforms/Passes.h"
namespace mlir {
namespace iree_compiler {
namespace IREE {
namespace VMVX {
static void buildVectorVMVXTransformPassPipeline(OpPassManager &passManager) {
passManager.nest<ModuleOp>().nest<FuncOp>().addPass(
createTypePropagationPass());
passManager.addPass(createLLVMCPULowerExecutableTargetPass());
OpPassManager &nestedModulePM = passManager.nest<ModuleOp>();
// ---------------------------------------------------------------------------
// Linalg -> Vectors
// ---------------------------------------------------------------------------
// Tiling and distribution.
nestedModulePM.addNestedPass<FuncOp>(createCanonicalizerPass());
// TODO(#5925): This can also be modified to just use the dynamic pass
// pipeline like the CPU side.
// nestedModulePM.addNestedPass<FuncOp>(
// createLinalgTileAndVectorizeWorkgroupsPass());
// Linalg -> SCF.
nestedModulePM.addNestedPass<FuncOp>(
IREE::LinalgExt::createLinalgExtToLoopsPass());
nestedModulePM.addNestedPass<FuncOp>(createMemrefCopyToLinalgPass());
nestedModulePM.addNestedPass<FuncOp>(createConvertLinalgToLoopsPass());
nestedModulePM.addNestedPass<FuncOp>(createCanonicalizerPass());
nestedModulePM.addNestedPass<FuncOp>(createCSEPass());
nestedModulePM.addNestedPass<FuncOp>(createConvertVectorToSCFPass());
nestedModulePM.addNestedPass<FuncOp>(createCanonicalizerPass());
nestedModulePM.addNestedPass<FuncOp>(arith::createArithmeticExpandOpsPass());
nestedModulePM.addNestedPass<FuncOp>(memref::createExpandOpsPass());
// Handle tensor-type constants.
nestedModulePM.addPass(arith::createConstantBufferizePass());
nestedModulePM.addPass(createFoldTensorExtractOpPass());
// Flatten and cleanup memrefs.
nestedModulePM.addNestedPass<FuncOp>(memref::createFoldSubViewOpsPass());
nestedModulePM.addPass(createCanonicalizerPass());
nestedModulePM.addPass(createCSEPass());
nestedModulePM.addPass(createFlattenMemRefSubspanPass());
nestedModulePM.addPass(memref::createNormalizeMemRefsPass());
nestedModulePM.addNestedPass<FuncOp>(createAffineScalarReplacementPass());
nestedModulePM.addPass(createCanonicalizerPass());
}
static void buildLoopOptimizationVMVXTransformPassPipeline(
OpPassManager &passManager) {
OpPassManager &nestedModulePM = passManager.nest<ModuleOp>();
nestedModulePM.addNestedPass<FuncOp>(createLowerAffinePass());
nestedModulePM.addNestedPass<FuncOp>(createForOpCanonicalizationPass());
nestedModulePM.addNestedPass<FuncOp>(createLoopInvariantCodeMotionPass());
}
void buildVMVXTransformPassPipeline(OpPassManager &passManager) {
// ---------------------------------------------------------------------------
// Linalg -> Scalars/Vectors
// ---------------------------------------------------------------------------
buildVectorVMVXTransformPassPipeline(passManager);
passManager.addPass(createCanonicalizerPass());
passManager.addPass(createCSEPass());
// ---------------------------------------------------------------------------
// Standard/Vector/HAL/etc -> VMVX conversion
// ---------------------------------------------------------------------------
passManager.addNestedPass<mlir::ModuleOp>(createConversionPass());
passManager.nest<mlir::ModuleOp>().addNestedPass<FuncOp>(
memref::createFoldSubViewOpsPass());
passManager.addPass(createCanonicalizerPass());
passManager.addPass(createCSEPass());
// ---------------------------------------------------------------------------
// Cleanup and canonicalization
// ---------------------------------------------------------------------------
buildLoopOptimizationVMVXTransformPassPipeline(passManager);
passManager.addPass(createCanonicalizerPass());
passManager.addPass(createCSEPass());
}
void createVMVXTransformPassPipeline() {
PassPipelineRegistration<> transformPassPipeline(
"iree-vmvx-transformation-pipeline",
"Runs the full IREE VMVX dialect transformation pipeline",
[](OpPassManager &passManager) {
buildVMVXTransformPassPipeline(passManager);
});
}
} // namespace VMVX
} // namespace IREE
} // namespace iree_compiler
} // namespace mlir