blob: 42b2ef3cdec3576b9bc952172e3b5938df6073d7 [file] [log] [blame]
// Copyright 2021 Google LLC
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// https://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "iree/compiler/Conversion/LinalgToLLVMGPU/Passes.h"
#include "iree/compiler/Conversion/Common/Passes.h"
#include "iree/compiler/Dialect/Shape/Transforms/Passes.h"
#include "mlir/Conversion/SCFToStandard/SCFToStandard.h"
#include "mlir/Conversion/StandardToSPIRV/StandardToSPIRVPass.h"
#include "mlir/Dialect/Linalg/Passes.h"
#include "mlir/Pass/PassManager.h"
#include "mlir/Pass/PassOptions.h"
#include "mlir/Pass/PassRegistry.h"
#include "mlir/Transforms/Passes.h"
namespace mlir {
namespace iree_compiler {
static void addLinalgToLLVMGPUPasses(OpPassManager &pm, bool useROCM) {
//===--------------------------------------------------------------------===//
// Initial clean up.
//===--------------------------------------------------------------------===//
pm.addNestedPass<ModuleOp>(createCanonicalizerPass());
pm.addNestedPass<ModuleOp>(createCSEPass());
// Distribute linalg onto threads within the workgroup.
pm.addPass(createTileAndDistributeToThreads());
pm.addNestedPass<ModuleOp>(createCanonicalizerPass());
pm.addNestedPass<ModuleOp>(createCSEPass());
pm.nest<ModuleOp>().addNestedPass<FuncOp>(
createRemoveSingleIterationLoopPass());
// Linalg -> vector
pm.nest<ModuleOp>().addNestedPass<FuncOp>(createVectorizationPass());
pm.nest<ModuleOp>().addNestedPass<FuncOp>(createCanonicalizerPass());
pm.nest<ModuleOp>().addNestedPass<FuncOp>(createCSEPass());
pm.addNestedPass<ModuleOp>(createLowerAffinePass());
pm.addNestedPass<ModuleOp>(createCanonicalizerPass());
pm.addNestedPass<ModuleOp>(createCSEPass());
// TODO: This currently maps to a single thread. We should share Tile and
// distribute with other GPU backends.
// Linalg -> SCF
pm.nest<ModuleOp>().addNestedPass<FuncOp>(createConvertLinalgToLoopsPass());
pm.nest<ModuleOp>().addNestedPass<FuncOp>(createCanonicalizerPass());
pm.nest<ModuleOp>().addNestedPass<FuncOp>(createCSEPass());
// SCF -> STD
pm.nest<ModuleOp>().addNestedPass<FuncOp>(createLowerToCFGPass());
pm.nest<ModuleOp>().addNestedPass<FuncOp>(createCanonicalizerPass());
pm.nest<ModuleOp>().addNestedPass<FuncOp>(createCSEPass());
pm.addNestedPass<ModuleOp>(createLowerAffinePass());
// Strip out the debug info for the kernel as CUDA driver doesn't diggest PTX
// debug info well.
pm.addNestedPass<ModuleOp>(createStripDebugInfoPass());
if (useROCM) {
// convert to ROCDL.
pm.addNestedPass<ModuleOp>(createConvertToROCDLPass());
} else {
// convert to NVVM.
pm.addNestedPass<ModuleOp>(createConvertToNVVMPass());
}
}
void buildLLVMGPUTransformPassPipeline(OpPassManager &pm, bool useROCM) {
OpPassManager &nestedModulePM = pm.nest<ModuleOp>();
nestedModulePM.addPass(createInlinerPass());
WorkgroupMemoryAllocationFn allocationFn =
[](OpBuilder &builder, Location loc, ArrayRef<int64_t> staticShape,
Type elementType, ArrayRef<Value> dynamicSizes) {
MemRefType allocType = MemRefType::get(staticShape, elementType, {}, 3);
return builder.create<memref::AllocOp>(loc, allocType, dynamicSizes);
};
addLinalgBufferizePasses(nestedModulePM, allocationFn);
//===--------------------------------------------------------------------===//
// Convert Linalg ops to LLVM+NVVM/ROCDL ops.
//
// Post-conditions:
// - All Linalg/Loops/GPU/Affine/Standard ops are converted away.
// - The module contains the final llvm.module ready to be serialized.
//===--------------------------------------------------------------------===//
addLinalgToLLVMGPUPasses(pm, useROCM);
}
static PassPipelineRegistration<> linalgToNVVMPipeline(
"iree-codegen-linalg-to-nvvm-pipeline",
"Runs the progressive lowering pipeline from Linalg to NVVM",
[](OpPassManager &passManager) {
addLinalgToLLVMGPUPasses(passManager, false);
});
static PassPipelineRegistration<> linalgToROCDLPipeline(
"iree-codegen-linalg-to-rocdl-pipeline",
"Runs the progressive lowering pipeline from Linalg to ROCDL",
[](OpPassManager &passManager) {
addLinalgToLLVMGPUPasses(passManager, true);
});
static PassPipelineRegistration<> hloToLinalgNVVMPipeline(
"iree-codegen-hlo-to-nvvm-pipeline",
"Runs the progressive lowering pipeline from XLA HLO to Linalg to "
"NVVM",
[](OpPassManager &passManager) {
buildLLVMGPUTransformPassPipeline(passManager, false);
});
static PassPipelineRegistration<> hloToLinalgROCDLPipeline(
"iree-codegen-hlo-to-rocdl-pipeline",
"Runs the progressive lowering pipeline from XLA HLO to Linalg to "
"ROCDL",
[](OpPassManager &passManager) {
buildLLVMGPUTransformPassPipeline(passManager, true);
});
} // namespace iree_compiler
} // namespace mlir