| // Copyright 2021 The IREE Authors |
| // |
| // Licensed under the Apache License v2.0 with LLVM Exceptions. |
| // See https://llvm.org/LICENSE.txt for license information. |
| // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception |
| |
| #include "iree/compiler/Codegen/LLVMGPU/ConvertToLLVM.h" |
| #include "iree/compiler/Codegen/PassDetail.h" |
| #include "iree/compiler/Codegen/Passes.h" |
| #include "iree/compiler/Codegen/Utils/Utils.h" |
| #include "iree/compiler/Dialect/Util/IR/UtilOps.h" |
| #include "mlir/Conversion/ArithmeticToLLVM/ArithmeticToLLVM.h" |
| #include "mlir/Conversion/GPUToNVVM/GPUToNVVMPass.h" |
| #include "mlir/Conversion/LLVMCommon/ConversionTarget.h" |
| #include "mlir/Conversion/LLVMCommon/LoweringOptions.h" |
| #include "mlir/Conversion/LLVMCommon/TypeConverter.h" |
| #include "mlir/Conversion/MathToLLVM/MathToLLVM.h" |
| #include "mlir/Conversion/MemRefToLLVM/MemRefToLLVM.h" |
| #include "mlir/Conversion/StandardToLLVM/ConvertStandardToLLVM.h" |
| #include "mlir/Conversion/VectorToLLVM/ConvertVectorToLLVM.h" |
| #include "mlir/Dialect/GPU/Passes.h" |
| #include "mlir/Dialect/LLVMIR/NVVMDialect.h" |
| #include "mlir/Dialect/StandardOps/IR/Ops.h" |
| #include "mlir/Dialect/Vector/VectorOps.h" |
| #include "mlir/Transforms/GreedyPatternRewriteDriver.h" |
| #include "mlir/Transforms/Passes.h" |
| |
| namespace mlir { |
| namespace iree_compiler { |
| |
| namespace { |
| |
| /// A pass that replaces all occurrences of GPU device operations with their |
| /// corresponding NVVM equivalent. |
| /// |
| /// This pass only handles device code and is not meant to be run on GPU host |
| /// code. |
| struct ConvertToNVVMPass : public ConvertToNVVMBase<ConvertToNVVMPass> { |
| void getDependentDialects(DialectRegistry ®istry) const override { |
| registry.insert<LLVM::LLVMDialect, NVVM::NVVMDialect>(); |
| } |
| void runOnOperation() override { |
| ModuleOp m = getOperation(); |
| |
| /// Customize the bitwidth used for the device side index computations. |
| LowerToLLVMOptions options(m.getContext(), DataLayout(m)); |
| options.overrideIndexBitwidth(64); |
| LLVMTypeConverter converter(m.getContext(), options); |
| // Lowering for MMAMatrixType. |
| converter.addConversion([&](gpu::MMAMatrixType type) -> Type { |
| return convertMMAToLLVMType(type); |
| }); |
| // Apply in-dialect lowering first. In-dialect lowering will replace ops |
| // which need to be lowered further, which is not supported by a single |
| // conversion pass. |
| // Run Vector -> Vector transformations ahead of conversion to LLVM. |
| { |
| RewritePatternSet patterns(&getContext()); |
| populateScalarizeMathOps(patterns); |
| populateConvertSharedMemoryAllocOps(patterns); |
| populateLowerHALInterfaceOp(patterns); |
| vector::populateVectorToVectorCanonicalizationPatterns(patterns); |
| vector::populateVectorBroadcastLoweringPatterns(patterns); |
| vector::populateVectorContractLoweringPatterns( |
| patterns, |
| vector::VectorTransformsOptions().setVectorTransformsOptions( |
| vector::VectorContractLowering::OuterProduct)); |
| vector::populateVectorMaskOpLoweringPatterns(patterns); |
| vector::populateVectorShapeCastLoweringPatterns(patterns); |
| vector::populateVectorTransposeLoweringPatterns(patterns); |
| vector::populateVectorTransferLoweringPatterns(patterns); |
| if (failed(applyPatternsAndFoldGreedily(m, std::move(patterns)))) { |
| return signalPassFailure(); |
| } |
| } |
| { |
| RewritePatternSet patterns(&getContext()); |
| populateGpuRewritePatterns(patterns); |
| if (failed(applyPatternsAndFoldGreedily(m, std::move(patterns)))) { |
| return signalPassFailure(); |
| } |
| } |
| { |
| RewritePatternSet llvmPatterns(&getContext()); |
| populateLLVMConversionPatterns(&getContext(), llvmPatterns, converter); |
| populateMathToLLVMConversionPatterns(converter, llvmPatterns); |
| populateMemRefToLLVMConversionPatterns(converter, llvmPatterns); |
| populateStdToLLVMConversionPatterns(converter, llvmPatterns); |
| arith::populateArithmeticToLLVMConversionPatterns(converter, |
| llvmPatterns); |
| populateVectorToLLVMConversionPatterns(converter, llvmPatterns); |
| populateGpuToNVVMConversionPatterns(converter, llvmPatterns); |
| populateGpuWMMAToNVVMConversionPatterns(converter, llvmPatterns); |
| LLVMConversionTarget target(getContext()); |
| populateStdToLLVMFuncOpConversionPattern(converter, llvmPatterns); |
| configureGpuToNVVMConversionLegality(target); |
| target.addDynamicallyLegalOp<FuncOp>([&](FuncOp funcOp) { |
| if (isEntryPoint(funcOp)) return false; |
| return true; |
| }); |
| if (failed(applyPartialConversion(m, target, std::move(llvmPatterns)))) { |
| signalPassFailure(); |
| } |
| } |
| } |
| }; |
| |
| } // anonymous namespace |
| |
| std::unique_ptr<OperationPass<ModuleOp>> createConvertToNVVMPass() { |
| return std::make_unique<ConvertToNVVMPass>(); |
| } |
| |
| } // namespace iree_compiler |
| } // namespace mlir |