blob: c5b95f796b80ee5c8c9f5c1d2f381f60948c765d [file]
// Copyright 2022 The IREE Authors
//
// Licensed under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
#include "mlir/Transforms/Passes.h"
#include "iree/compiler/Codegen/Common/CPU/Passes.h"
#include "iree/compiler/Codegen/Common/PassUtils.h"
#include "iree/compiler/Codegen/Common/Passes.h"
#include "iree/compiler/Codegen/Dialect/Codegen/IR/IREECodegenAttrs.h"
#include "iree/compiler/Codegen/VMVX/Passes.h"
#include "iree/compiler/Dialect/LinalgExt/Transforms/Passes.h"
#include "llvm/Support/Casting.h"
#include "mlir/Dialect/Func/IR/FuncOps.h"
#include "mlir/Pass/PassManager.h"
namespace mlir::iree_compiler {
/// Command line options used purely for development purposes. Not to be relied
/// on in any way.
///
/// `--iree-vmvx-skip-intermediate-roundings` (bool, initialized to true).
/// Within this file the value is only read when building the default VMVX
/// pipeline with ukernels enabled, where it is forwarded as the argument of
/// createCPULowerToUKernelsPass().
static llvm::cl::opt<bool> clSkipIntermediateRoundings(
"iree-vmvx-skip-intermediate-roundings",
llvm::cl::desc(
"Allow skipping intermediate roundings. For example, in f16 matmul "
"kernels on targets with only f32 arithmetic, we have to perform each "
"multiply-accumulate in f32, and if this flag is false, then we have "
"to round those f32 accumulators to the nearest f16 every time, which "
"is slow."),
llvm::cl::init(true));
/// `--iree-vmvx-enable-ukernels-decompose-linalg-generic` (bool, initialized
/// to true). Development-only flag. Within this file it gates whether
/// createDecomposeLinalgGenericPass() is added to the default VMVX pipeline;
/// it only has an effect when that pipeline is built with ukernels enabled.
static llvm::cl::opt<bool> clEnableUKernelsDecomposeLinalgGeneric(
"iree-vmvx-enable-ukernels-decompose-linalg-generic",
llvm::cl::desc("Enables decomposition of linalg.generic ops when "
"ukernels are enabled (experimental)"),
llvm::cl::init(true));
/// Appends the tile-and-distribute stage of the VMVX pipeline to
/// `funcPassManager`. The passes are added in a fixed order; callers are
/// expected to invoke this before any bufferization passes are appended.
static void addTileAndDistributePasses(OpPassManager &funcPassManager) {
  OpPassManager &pm = funcPassManager;

  // Workgroup tiling/distribution (forall-based), then the dispatch tensor
  // load/store bufferization and result layout combination passes.
  pm.addPass(createTileAndDistributeToWorkgroupsUsingForallOpPass());
  pm.addPass(createBufferizeDispatchTensorLoadStorePass());
  pm.addPass(createCombineResultLayoutTransformationPass());

  // Canonicalization and CSE cleanup after distribution.
  pm.addPass(createConfigTrackingCanonicalizerPass());
  pm.addPass(createCSEPass());

  // tensor.pad handling.
  pm.addPass(createFuseTensorPadWithConsumerPass());
  pm.addPass(createConcretizePadResultShapePass());

  // LinalgExt decompositions (attention, Winograd transforms).
  pm.addPass(IREE::LinalgExt::createDecomposeAttentionPass());
  pm.addPass(IREE::LinalgExt::createDecomposeWinogradTransformPass());
}
/// Populates `funcPassManager` with the default VMVX lowering pipeline:
/// tile-and-distribute, optional ukernel lowering, bufferization, loop cleanup
/// and (optionally) buffer-level microkernel lowering.
///
/// \param funcPassManager pass manager that receives the passes.
/// \param enableUKernels  when true, the ukernel-related passes are included.
void addVMVXDefaultPassPipeline(OpPassManager &funcPassManager,
                                bool enableUKernels) {
  addTileAndDistributePasses(funcPassManager);

  if (enableUKernels) {
    funcPassManager.addPass(createCPUPrepareUkernelsPass());
    funcPassManager.addPass(
        createCPULowerToUKernelsPass(clSkipIntermediateRoundings));

    // Tensor-level micro-kernel optimizations.
    // Note that this must be done post-tiling because it changes the
    // structure of the dispatch region such that tiling is not always
    // possible.
    if (clEnableUKernelsDecomposeLinalgGeneric) {
      funcPassManager.addPass(createDecomposeLinalgGenericPass());
    }
  }

  // Lower to buffers.
  addCPUBufferizePasses(funcPassManager);

  // Bufferization may leave unused loops behind; clean them up.
  funcPassManager.addPass(createPropagateDispatchSizeBoundsPass());
  funcPassManager.addPass(createRemoveSingleIterationLoopPass());

  // Without ukernels there is nothing left to add.
  if (!enableUKernels) {
    return;
  }

  // Convert buffer-level microkernels.
  funcPassManager.addPass(createVMVXLowerLinalgMicrokernelsPass());
}
/// Pipeline-builder callback for the VMVX pipeline attribute (registered in
/// registerCodegenVMVXPasses()).
///
/// Asserts that `pipelineAttr` is an IREE::Codegen::VMVXPipelineAttr. Emits an
/// error at an unknown location and returns failure when `options` is null or
/// is not a VMVXCodegenPipelineOptions; otherwise populates `pm` with the
/// default VMVX pipeline using the options' `enableUKernels` setting and
/// returns success.
static LogicalResult buildVMVXPipeline(Attribute pipelineAttr,
                                       OpPassManager &pm,
                                       const CodegenPipelineOptions *options) {
  assert(isa<IREE::Codegen::VMVXPipelineAttr>(pipelineAttr) &&
         "unexpected VMVX pipeline attr");

  // Missing options and options of the wrong dynamic type are reported with
  // the same diagnostic, so both cases collapse into one null check.
  const auto *vmvxOptions =
      options ? dyn_cast<VMVXCodegenPipelineOptions>(options) : nullptr;
  if (vmvxOptions == nullptr) {
    return emitError(UnknownLoc::get(pipelineAttr.getContext()))
           << "VMVX pipeline expects VMVXCodegenPipelineOptions";
  }

  addVMVXDefaultPassPipeline(pm, vmvxOptions->enableUKernels);
  return success();
}
/// Populates `modulePassManager` with the VMVX configuration pipeline:
/// common preprocessing, user config materialization, device encoding
/// materialization, and finally VMVX lowering strategy selection.
void buildVMVXConfigurationPassPipeline(OpPassManager &modulePassManager) {
  // Function-level preprocessing shared across targets.
  {
    FunctionLikeNest preprocessingNest(modulePassManager);
    addCommonTargetExecutablePreprocessingPasses(preprocessingNest);
  }

  modulePassManager.addPass(createMaterializeUserConfigsPass());

  {
    FunctionLikeNest encodingNest(modulePassManager);
    encodingNest.addPass(createMaterializeDeviceEncodingPass);
    // TODO: Remove the following pass and plumb support for
    // #hal.descriptor_type memory space through the stack.
    encodingNest.addPass(createEraseHALDescriptorTypeFromMemRefPass);
  }

  modulePassManager.addPass(createVMVXSelectLoweringStrategyPass());
}
/// Populates `modulePassManager` with the VMVX lowering pipeline: the
/// function-level executable target lowering pass, followed by the
/// module-level translation info reconciliation and workgroup count hint
/// resolution passes.
void buildVMVXLoweringPassPipeline(OpPassManager &modulePassManager) {
  {
    FunctionLikeNest loweringNest(modulePassManager);
    loweringNest.addPass(createVMVXLowerExecutableTargetPass);
  }

  modulePassManager.addPass(createReconcileTranslationInfoPass());
  modulePassManager.addPass(createResolveWorkgroupCountHintsPass());
}
/// Populates `modulePassManager` with the VMVX executable linking pipeline.
///
/// NOTE: this runs on the top-level program module containing all
/// hal.executable ops.
void buildVMVXLinkingPassPipeline(OpPassManager &modulePassManager) {
  // Link together executables. This may produce some IR duplication.
  modulePassManager.addPass(createVMVXLinkExecutablesPass());

  // Cleanup IR duplication.
  modulePassManager.addNestedPass<IREE::HAL::ExecutableOp>(
      mlir::createCanonicalizerPass());

  // Assign final executable constant ordinals; this pass runs on each
  // executable variant nested inside each executable.
  OpPassManager &executablePassManager =
      modulePassManager.nest<IREE::HAL::ExecutableOp>();
  executablePassManager.addNestedPass<IREE::HAL::ExecutableVariantOp>(
      createVMVXAssignConstantOrdinalsPass());
}
//===---------------------------------------------------------------------===//
// Register VMVX Passes
//===---------------------------------------------------------------------===//
// Pulls the pass registration section of Passes.h.inc into this translation
// unit. The anonymous namespace keeps the included declarations file-local.
// NOTE(review): the `registerPasses()` call in registerCodegenVMVXPasses() is
// presumably provided by this include (it is marked "Generated." there) —
// confirm against the generated file.
namespace {
#define GEN_PASS_REGISTRATION
#include "iree/compiler/Codegen/VMVX/Passes.h.inc"
} // namespace
/// Registers everything VMVX codegen exposes: the generated passes, the VMVX
/// pipeline builder callback, and three named pass pipelines (configuration,
/// lowering and linking).
void registerCodegenVMVXPasses() {
  // Generated.
  registerPasses();

  // Hook the VMVX pipeline attribute up to its builder.
  IREE::Codegen::registerVMVXPipelineBuilder(buildVMVXPipeline);

  // Function-local statics: each pipeline registration object is constructed
  // once, on the first call.
  static PassPipelineRegistration<> configurationPipeline(
      "iree-codegen-vmvx-configuration-pipeline",
      "Runs the VMVX codegen configuration pipeline",
      [](OpPassManager &pm) { buildVMVXConfigurationPassPipeline(pm); });

  static PassPipelineRegistration<> loweringPipeline(
      "iree-codegen-vmvx-lowering-pipeline",
      "Runs the VMVX codegen lowering pipeline",
      [](OpPassManager &pm) { buildVMVXLoweringPassPipeline(pm); });

  static PassPipelineRegistration<> linkingPipeline(
      "iree-codegen-vmvx-linking-pipeline",
      "Runs the VMVX HAL executable linking pipeline",
      [](OpPassManager &pm) { buildVMVXLinkingPassPipeline(pm); });
}
} // namespace mlir::iree_compiler