blob: fb3883583051965942ba09f97739881d82840246 [file] [log] [blame]
// Copyright 2021 The IREE Authors
//
// Licensed under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
#include "iree/compiler/Codegen/LLVMGPU/ConvertToLLVM.h"
#include "iree/compiler/Codegen/PassDetail.h"
#include "iree/compiler/Codegen/Passes.h"
#include "iree/compiler/Codegen/Utils/Utils.h"
#include "iree/compiler/Dialect/Util/IR/UtilOps.h"
#include "mlir/Conversion/ArithmeticToLLVM/ArithmeticToLLVM.h"
#include "mlir/Conversion/GPUToNVVM/GPUToNVVMPass.h"
#include "mlir/Conversion/LLVMCommon/ConversionTarget.h"
#include "mlir/Conversion/LLVMCommon/LoweringOptions.h"
#include "mlir/Conversion/LLVMCommon/TypeConverter.h"
#include "mlir/Conversion/MathToLLVM/MathToLLVM.h"
#include "mlir/Conversion/MemRefToLLVM/MemRefToLLVM.h"
#include "mlir/Conversion/StandardToLLVM/ConvertStandardToLLVM.h"
#include "mlir/Conversion/VectorToLLVM/ConvertVectorToLLVM.h"
#include "mlir/Dialect/GPU/Passes.h"
#include "mlir/Dialect/LLVMIR/NVVMDialect.h"
#include "mlir/Dialect/StandardOps/IR/Ops.h"
#include "mlir/Dialect/Vector/VectorOps.h"
#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
#include "mlir/Transforms/Passes.h"
namespace mlir {
namespace iree_compiler {
namespace {
/// A pass that replaces all occurrences of GPU device operations with their
/// corresponding NVVM equivalent.
///
/// This pass only handles device code and is not meant to be run on GPU host
/// code.
struct ConvertToNVVMPass : public ConvertToNVVMBase<ConvertToNVVMPass> {
  /// Registers the LLVM and NVVM dialects as dependencies of this pass.
  void getDependentDialects(DialectRegistry &registry) const override {
    registry.insert<LLVM::LLVMDialect, NVVM::NVVMDialect>();
  }

  /// Lowers the module in three stages:
  ///   1. greedy in-dialect rewrites (math scalarization, shared memory
  ///      allocations, HAL interface ops, vector -> vector lowerings),
  ///   2. greedy GPU rewrite patterns,
  ///   3. a partial conversion to the LLVM/NVVM dialects.
  /// The pass is marked as failed as soon as one of the stages fails.
  void runOnOperation() override {
    ModuleOp moduleOp = getOperation();
    auto *context = &getContext();

    // Device side index computations are performed on 64 bits.
    LowerToLLVMOptions llvmOptions(moduleOp.getContext(),
                                   DataLayout(moduleOp));
    llvmOptions.overrideIndexBitwidth(64);
    LLVMTypeConverter typeConverter(moduleOp.getContext(), llvmOptions);

    // Teach the type converter how to lower gpu::MMAMatrixType.
    typeConverter.addConversion([&](gpu::MMAMatrixType type) -> Type {
      return convertMMAToLLVMType(type);
    });

    // Stage 1: in-dialect lowering. These rewrites produce ops that themselves
    // need further lowering, which a single conversion pass does not support,
    // so they run ahead of the conversion to LLVM. This includes the
    // Vector -> Vector transformations.
    {
      RewritePatternSet inDialectPatterns(context);
      populateScalarizeMathOps(inDialectPatterns);
      populateConvertSharedMemoryAllocOps(inDialectPatterns);
      populateLowerHALInterfaceOp(inDialectPatterns);
      vector::populateVectorToVectorCanonicalizationPatterns(
          inDialectPatterns);
      vector::populateVectorBroadcastLoweringPatterns(inDialectPatterns);

      // Vector contractions are lowered using the outer product strategy.
      vector::VectorTransformsOptions contractOptions;
      contractOptions.setVectorTransformsOptions(
          vector::VectorContractLowering::OuterProduct);
      vector::populateVectorContractLoweringPatterns(inDialectPatterns,
                                                     contractOptions);

      vector::populateVectorMaskOpLoweringPatterns(inDialectPatterns);
      vector::populateVectorShapeCastLoweringPatterns(inDialectPatterns);
      vector::populateVectorTransposeLoweringPatterns(inDialectPatterns);
      vector::populateVectorTransferLoweringPatterns(inDialectPatterns);

      LogicalResult status =
          applyPatternsAndFoldGreedily(moduleOp, std::move(inDialectPatterns));
      if (failed(status)) {
        signalPassFailure();
        return;
      }
    }

    // Stage 2: GPU dialect rewrite patterns.
    {
      RewritePatternSet gpuPatterns(context);
      populateGpuRewritePatterns(gpuPatterns);

      LogicalResult status =
          applyPatternsAndFoldGreedily(moduleOp, std::move(gpuPatterns));
      if (failed(status)) {
        signalPassFailure();
        return;
      }
    }

    // Stage 3: partial conversion to the LLVM and NVVM dialects.
    {
      RewritePatternSet conversionPatterns(context);
      populateLLVMConversionPatterns(context, conversionPatterns,
                                     typeConverter);
      populateMathToLLVMConversionPatterns(typeConverter, conversionPatterns);
      populateMemRefToLLVMConversionPatterns(typeConverter,
                                             conversionPatterns);
      populateStdToLLVMConversionPatterns(typeConverter, conversionPatterns);
      arith::populateArithmeticToLLVMConversionPatterns(typeConverter,
                                                        conversionPatterns);
      populateVectorToLLVMConversionPatterns(typeConverter,
                                             conversionPatterns);
      populateGpuToNVVMConversionPatterns(typeConverter, conversionPatterns);
      populateGpuWMMAToNVVMConversionPatterns(typeConverter,
                                              conversionPatterns);
      populateStdToLLVMFuncOpConversionPattern(typeConverter,
                                               conversionPatterns);

      LLVMConversionTarget conversionTarget(*context);
      configureGpuToNVVMConversionLegality(conversionTarget);
      // A FuncOp is legal unless it is an entry point; entry points therefore
      // have to be rewritten by the conversion.
      conversionTarget.addDynamicallyLegalOp<FuncOp>(
          [&](FuncOp funcOp) { return !isEntryPoint(funcOp); });

      LogicalResult status = applyPartialConversion(
          moduleOp, conversionTarget, std::move(conversionPatterns));
      if (failed(status)) {
        signalPassFailure();
        return;
      }
    }
  }
};
} // anonymous namespace
/// Builds a new instance of the pass converting a module to the NVVM dialect.
std::unique_ptr<OperationPass<ModuleOp>> createConvertToNVVMPass() {
  std::unique_ptr<OperationPass<ModuleOp>> nvvmPass =
      std::make_unique<ConvertToNVVMPass>();
  return nvvmPass;
}
} // namespace iree_compiler
} // namespace mlir