blob: 13588a28be458271bb916e99af6a66bdb098fa07 [file]
// Copyright 2021 The IREE Authors
//
// Licensed under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
#include "iree/compiler/Codegen/LLVMGPU/ConvertToLLVM.h"
#include "iree/compiler/Codegen/PassDetail.h"
#include "iree/compiler/Codegen/Passes.h"
#include "iree/compiler/Codegen/Utils/Utils.h"
#include "iree/compiler/Dialect/IREE/IR/IREEOps.h"
#include "mlir/Conversion/GPUToNVVM/GPUToNVVMPass.h"
#include "mlir/Conversion/LLVMCommon/ConversionTarget.h"
#include "mlir/Conversion/LLVMCommon/LoweringOptions.h"
#include "mlir/Conversion/LLVMCommon/TypeConverter.h"
#include "mlir/Conversion/MathToLLVM/MathToLLVM.h"
#include "mlir/Conversion/MemRefToLLVM/MemRefToLLVM.h"
#include "mlir/Conversion/StandardToLLVM/ConvertStandardToLLVM.h"
#include "mlir/Conversion/VectorToLLVM/ConvertVectorToLLVM.h"
#include "mlir/Dialect/GPU/Passes.h"
#include "mlir/Dialect/LLVMIR/NVVMDialect.h"
#include "mlir/Dialect/StandardOps/IR/Ops.h"
#include "mlir/Dialect/Vector/VectorOps.h"
#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
#include "mlir/Transforms/Passes.h"
namespace mlir {
namespace iree_compiler {
namespace {
/// Lowers GPU device operations in a module to their LLVM / NVVM dialect
/// equivalents.
///
/// Only device code is handled; the pass is not meant to run on GPU host code.
///
/// The work is done in three phases, in this order:
///   1. greedy vector-to-vector lowerings, together with math scalarization
///      and shared-memory alloc rewriting;
///   2. greedy GPU dialect rewrites;
///   3. a partial dialect conversion to LLVM/NVVM using a type converter whose
///      index type is 64 bits wide.
/// Phases 1 and 2 run first because they create ops that still need to be
/// lowered further, which a single conversion pass cannot do.
struct ConvertToNVVMPass : public ConvertToNVVMBase<ConvertToNVVMPass> {
  void getDependentDialects(DialectRegistry &registry) const override {
    registry.insert<LLVM::LLVMDialect, NVVM::NVVMDialect>();
  }

  void runOnOperation() override {
    ModuleOp moduleOp = getOperation();

    // Device-side index computations are lowered with a 64-bit index type.
    LowerToLLVMOptions lowerOptions(&getContext(), DataLayout(moduleOp));
    lowerOptions.overrideIndexBitwidth(64);
    LLVMTypeConverter typeConverter(&getContext(), lowerOptions);

    applyVectorLowerings(moduleOp);
    applyGpuRewrites(moduleOp);
    convertToLLVMAndNVVM(moduleOp, typeConverter);
  }

 private:
  /// Phase 1: math scalarization, shared-memory alloc rewriting and
  /// vector -> vector lowerings (contractions lowered as outer products,
  /// transposes, transfers), applied greedily to `moduleOp`.
  void applyVectorLowerings(ModuleOp moduleOp) {
    OwningRewritePatternList vectorPatterns(&getContext());
    populateScalarizeMathOps(vectorPatterns);
    populateConvertSharedMemoryAllocOps(vectorPatterns);
    vector::populateVectorToVectorCanonicalizationPatterns(vectorPatterns);
    auto contractOptions =
        vector::VectorTransformsOptions().setVectorTransformsOptions(
            vector::VectorContractLowering::OuterProduct);
    vector::populateVectorContractLoweringPatterns(vectorPatterns,
                                                   contractOptions);
    mlir::vector::populateVectorTransposeLoweringPatterns(vectorPatterns);
    mlir::vector::populateVectorTransferLoweringPatterns(vectorPatterns);
    // The greedy driver's result is intentionally discarded.
    (void)applyPatternsAndFoldGreedily(moduleOp, std::move(vectorPatterns));
  }

  /// Phase 2: applies the upstream GPU dialect rewrite patterns greedily.
  void applyGpuRewrites(ModuleOp moduleOp) {
    OwningRewritePatternList gpuPatterns(&getContext());
    populateGpuRewritePatterns(gpuPatterns);
    // The greedy driver's result is intentionally discarded.
    (void)applyPatternsAndFoldGreedily(moduleOp, std::move(gpuPatterns));
  }

  /// Phase 3: partial conversion of `moduleOp` to the LLVM and NVVM dialects
  /// using `typeConverter`. Signals pass failure if the conversion fails.
  void convertToLLVMAndNVVM(ModuleOp moduleOp,
                            LLVMTypeConverter &typeConverter) {
    OwningRewritePatternList llvmPatterns(&getContext());
    // NOTE(review): the trailing `false` appears to select the non-ROCm
    // (NVVM) flavor of the shared LLVMGPU patterns; confirm in
    // ConvertToLLVM.h.
    populateLLVMConversionPatterns(&getContext(), llvmPatterns, typeConverter,
                                   false);
    populateMathToLLVMConversionPatterns(typeConverter, llvmPatterns);
    populateMemRefToLLVMConversionPatterns(typeConverter, llvmPatterns);
    populateStdToLLVMConversionPatterns(typeConverter, llvmPatterns);
    populateVectorToLLVMConversionPatterns(typeConverter, llvmPatterns);
    populateGpuToNVVMConversionPatterns(typeConverter, llvmPatterns);
    populateShapeToLLVMConversionPatterns(&getContext(), &typeConverter,
                                          llvmPatterns);

    LLVMConversionTarget target(getContext());
    // NOTE(review): populateStdToLLVMConversionPatterns may already register
    // the FuncOp conversion; confirm whether this explicit call is redundant.
    populateStdToLLVMFuncOpConversionPattern(typeConverter, llvmPatterns);
    configureGpuToNVVMConversionLegality(target);
    // Entry-point functions are reported illegal, so the conversion must
    // rewrite them; every other FuncOp is reported legal.
    target.addDynamicallyLegalOp<FuncOp>(
        [](FuncOp funcOp) { return !isEntryPoint(funcOp); });

    if (failed(applyPartialConversion(moduleOp, target,
                                      std::move(llvmPatterns)))) {
      signalPassFailure();
    }
  }
};
} // anonymous namespace
/// Creates a new instance of the pass that lowers GPU device code in a module
/// to the LLVM and NVVM dialects. The caller takes ownership of the pass.
std::unique_ptr<OperationPass<ModuleOp>> createConvertToNVVMPass() {
  std::unique_ptr<OperationPass<ModuleOp>> pass =
      std::make_unique<ConvertToNVVMPass>();
  return pass;
}
} // namespace iree_compiler
} // namespace mlir