// Copyright 2021 The IREE Authors
//
// Licensed under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
#include "iree/compiler/Codegen/Common/GPU/GPUPatterns.h"
#include "iree/compiler/Codegen/Common/Transforms.h"
#include "iree/compiler/Codegen/Dialect/GPU/IR/IREEGPUDialect.h"
#include "iree/compiler/Codegen/LLVMGPU/ConvertToLLVM.h"
#include "iree/compiler/Codegen/LLVMGPU/Passes.h"
#include "iree/compiler/Codegen/Utils/GPUUtils.h"
#include "mlir/Conversion/ArithToLLVM/ArithToLLVM.h"
#include "mlir/Conversion/ComplexToLLVM/ComplexToLLVM.h"
#include "mlir/Conversion/ControlFlowToLLVM/ControlFlowToLLVM.h"
#include "mlir/Conversion/FuncToLLVM/ConvertFuncToLLVM.h"
#include "mlir/Conversion/GPUToNVVM/GPUToNVVMPass.h"
#include "mlir/Conversion/LLVMCommon/ConversionTarget.h"
#include "mlir/Conversion/LLVMCommon/LoweringOptions.h"
#include "mlir/Conversion/LLVMCommon/TypeConverter.h"
#include "mlir/Conversion/MathToLLVM/MathToLLVM.h"
#include "mlir/Conversion/MemRefToLLVM/MemRefToLLVM.h"
#include "mlir/Conversion/NVGPUToNVVM/NVGPUToNVVM.h"
#include "mlir/Conversion/NVVMToLLVM/NVVMToLLVM.h"
#include "mlir/Conversion/VectorToLLVM/ConvertVectorToLLVM.h"
#include "mlir/Dialect/Affine/IR/AffineOps.h"
#include "mlir/Dialect/Arith/Transforms/Passes.h"
#include "mlir/Dialect/GPU/Transforms/Passes.h"
#include "mlir/Dialect/LLVMIR/NVVMDialect.h"
#include "mlir/Dialect/MemRef/Transforms/Transforms.h"
#include "mlir/Dialect/NVGPU/IR/NVGPUDialect.h"
#include "mlir/Dialect/Vector/IR/VectorOps.h"
#include "mlir/Dialect/Vector/Transforms/LoweringPatterns.h"
#include "mlir/Transforms/GreedyPatternRewriteDriver.h"

namespace mlir::iree_compiler {

#define GEN_PASS_DEF_CONVERTTONVVMPASS
#include "iree/compiler/Codegen/LLVMGPU/Passes.h.inc"

namespace {

/// A pass that replaces all occurrences of GPU device operations with their
/// corresponding NVVM equivalents.
///
/// This pass only handles device code and is not meant to be run on GPU host
/// code.
struct ConvertToNVVMPass final
: impl::ConvertToNVVMPassBase<ConvertToNVVMPass> {
using impl::ConvertToNVVMPassBase<ConvertToNVVMPass>::ConvertToNVVMPassBase;

  void getDependentDialects(DialectRegistry &registry) const override {
registry
.insert<gpu::GPUDialect, IREE::GPU::IREEGPUDialect, LLVM::LLVMDialect,
NVVM::NVVMDialect, affine::AffineDialect>();
}
void runOnOperation() override {
ModuleOp m = getOperation();
    // Customize the bitwidth used for device-side index computations.
LowerToLLVMOptions options(m.getContext(), DataLayout(m));
options.overrideIndexBitwidth(64);
LLVMTypeConverter converter(m.getContext(), options);
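    // Map GPU address spaces to NVVM memory spaces: global memory is address
    // space 1 and workgroup (shared) memory is address space 3 in NVVM, while
    // private memory uses the default address space 0.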
populateGpuMemorySpaceAttributeConversions(
converter, [](gpu::AddressSpace space) -> unsigned {
switch (space) {
case gpu::AddressSpace::Global:
return static_cast<unsigned>(
NVVM::NVVMMemorySpace::kGlobalMemorySpace);
case gpu::AddressSpace::Workgroup:
return static_cast<unsigned>(
NVVM::NVVMMemorySpace::kSharedMemorySpace);
case gpu::AddressSpace::Private:
return 0;
}
llvm_unreachable("unknown address space enum value");
return 0;
});
// Lowering for MMAMatrixType.
converter.addConversion([&](gpu::MMAMatrixType type) -> Type {
return convertMMAToLLVMType(type);
});
    // Convert the dummy nvgpu async copy token type to i32.
converter.addConversion([&](nvgpu::DeviceAsyncTokenType type) -> Type {
return converter.convertType(IntegerType::get(type.getContext(), 32));
});
    // Apply in-dialect lowering first. In-dialect lowering replaces ops with
    // forms that still need further lowering; a single conversion pass cannot
    // handle such multi-step rewrites.
    // Run Vector -> Vector transformations ahead of the conversion to LLVM.
{
RewritePatternSet patterns(&getContext());
populateVectorToSCFConversionPatterns(
patterns, VectorTransferToSCFOptions().enableFullUnroll());
populateDropSharedMemoryDeallocOpPatterns(patterns);
populateScalarizeMathOps(patterns);
populateConvertSharedMemoryAllocOps(patterns);
vector::populateVectorToVectorCanonicalizationPatterns(patterns);
vector::populateVectorBroadcastLoweringPatterns(patterns);
vector::populateVectorContractLoweringPatterns(
patterns,
vector::VectorTransformsOptions().setVectorTransformsOptions(
vector::VectorContractLowering::OuterProduct));
vector::populateVectorMaskOpLoweringPatterns(patterns);
      // We currently always use 64-bit indices, so make the bit width of the
      // mask compare consistent with that.
vector::populateVectorMaskMaterializationPatterns(
patterns, /*force32BitVectorIndices=*/false);
vector::populateVectorShapeCastLoweringPatterns(patterns);
      // TODO: It is doubtful that the "default" option does what one wants
      // here; it is likely better to use something else.
vector::populateVectorTransposeLoweringPatterns(
patterns, vector::VectorTransformsOptions());
vector::populateVectorTransferLoweringPatterns(patterns);
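      // Expand arith extf/truncf on bf16 into explicit bit manipulation on
      // the underlying integer representation.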
arith::populateExpandBFloat16Patterns(patterns);
if (failed(applyPatternsAndFoldGreedily(m, std::move(patterns)))) {
return signalPassFailure();
}
}
{
RewritePatternSet patterns(&getContext());
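      // Expand GPU ops that have in-dialect lowerings, e.g. gpu.all_reduce.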
populateGpuRewritePatterns(patterns);
if (failed(applyPatternsAndFoldGreedily(m, std::move(patterns)))) {
return signalPassFailure();
}
}
{
      // Convert arith::maximumf/minimumf ops on older GPUs, since their
      // default lowering generates slow code for them.
// TODO: Remove this once the lowering in LLVM is fixed
// (https://github.com/llvm/llvm-project/issues/64606).
std::optional<int> cc = getGPUTargetAttr(m).getCUDAComputeCapability();
if (!cc || cc.value() < 80) {
RewritePatternSet patterns(&getContext());
populateReplaceSlowMinMaxOpsPatterns(patterns);
if (failed(applyPatternsAndFoldGreedily(m, std::move(patterns)))) {
return signalPassFailure();
}
}
}
{
RewritePatternSet llvmPatterns(&getContext());
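      // Collect the conversion patterns for the final lowering to the LLVM
      // and NVVM dialects: HAL interface ops, memref descriptors, control
      // flow, arithmetic, vector, and GPU/NVGPU device ops.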
populateLowerHALInterfaceOp(llvmPatterns);
populateLLVMConversionPatterns(&getContext(), llvmPatterns, converter);
populateComplexToLLVMConversionPatterns(converter, llvmPatterns);
populateMathToLLVMConversionPatterns(converter, llvmPatterns);
memref::populateExpandStridedMetadataPatterns(llvmPatterns);
populateFinalizeMemRefToLLVMConversionPatterns(converter, llvmPatterns);
populateFuncToLLVMConversionPatterns(converter, llvmPatterns);
cf::populateControlFlowToLLVMConversionPatterns(converter, llvmPatterns);
arith::populateArithToLLVMConversionPatterns(converter, llvmPatterns);
populateVectorToLLVMConversionPatterns(converter, llvmPatterns);
populateGpuToNVVMConversionPatterns(converter, llvmPatterns);
populateNVGPUToNVVMConversionPatterns(converter, llvmPatterns);
populateGpuWMMAToNVVMConversionPatterns(converter, llvmPatterns);
      // Target specification.
LLVMConversionTarget target(getContext());
target.addIllegalOp<func::FuncOp>();
target.addLegalDialect<::mlir::LLVM::LLVMDialect>();
target.addLegalDialect<::mlir::NVVM::NVVMDialect>();
target.addIllegalDialect<gpu::GPUDialect>();
target.addIllegalOp<
LLVM::CopySignOp, LLVM::CosOp, LLVM::ExpOp, LLVM::Exp2Op,
LLVM::FAbsOp, LLVM::FCeilOp, LLVM::FFloorOp, LLVM::FRemOp,
LLVM::LogOp, LLVM::Log10Op, LLVM::Log2Op, LLVM::PowOp,
LLVM::RoundEvenOp, LLVM::RoundOp, LLVM::SinOp, LLVM::SqrtOp>();
// TODO: Remove once we support replacing non-root ops.
target.addLegalOp<gpu::YieldOp, gpu::GPUModuleOp>();
if (failed(applyPartialConversion(m, target, std::move(llvmPatterns)))) {
signalPassFailure();
}
}
    // Convert NVVM ops that require inline PTX assembly to LLVM inline
    // assembly ops.
{
RewritePatternSet patterns(&getContext());
populateNVVMToLLVMConversionPatterns(patterns);
if (failed(applyPatternsAndFoldGreedily(m, std::move(patterns)))) {
return signalPassFailure();
}
}
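    // Pack the statically sized shared memory allocations into a single
    // dynamically sized shared memory buffer, rewriting accesses with the
    // corresponding offsets.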
ConvertToDynamicSharedMemory(m);
}
};

} // namespace

} // namespace mlir::iree_compiler