| // Copyright 2021 The IREE Authors |
| // |
| // Licensed under the Apache License v2.0 with LLVM Exceptions. |
| // See https://llvm.org/LICENSE.txt for license information. |
| // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception |
| |
| #include "iree/compiler/Codegen/LLVMGPU/ConvertToLLVM.h" |
| #include "iree/compiler/Codegen/PassDetail.h" |
| #include "iree/compiler/Codegen/Passes.h" |
| #include "iree/compiler/Codegen/Utils/Utils.h" |
| #include "iree/compiler/Dialect/Util/IR/UtilOps.h" |
| #include "mlir/Conversion/ArithmeticToLLVM/ArithmeticToLLVM.h" |
| #include "mlir/Conversion/GPUToROCDL/GPUToROCDLPass.h" |
| #include "mlir/Conversion/LLVMCommon/ConversionTarget.h" |
| #include "mlir/Conversion/LLVMCommon/LoweringOptions.h" |
| #include "mlir/Conversion/LLVMCommon/TypeConverter.h" |
| #include "mlir/Conversion/MathToLLVM/MathToLLVM.h" |
| #include "mlir/Conversion/MemRefToLLVM/MemRefToLLVM.h" |
| #include "mlir/Conversion/StandardToLLVM/ConvertStandardToLLVM.h" |
| #include "mlir/Conversion/VectorToLLVM/ConvertVectorToLLVM.h" |
| #include "mlir/Dialect/GPU/Passes.h" |
| #include "mlir/Dialect/LLVMIR/ROCDLDialect.h" |
| #include "mlir/Dialect/StandardOps/IR/Ops.h" |
| #include "mlir/Dialect/Vector/VectorOps.h" |
| #include "mlir/Transforms/GreedyPatternRewriteDriver.h" |
| #include "mlir/Transforms/Passes.h" |
| |
| namespace mlir { |
| namespace iree_compiler { |
| |
| namespace { |
| |
| /// A pass that replaces all occurrences of GPU device operations with their |
| /// corresponding ROCDL equivalent. |
| /// |
| /// This pass only handles device code and is not meant to be run on GPU host |
| /// code. |
| struct ConvertToROCDLPass : public ConvertToROCDLBase<ConvertToROCDLPass> { |
| void getDependentDialects(DialectRegistry ®istry) const override { |
| registry.insert<LLVM::LLVMDialect, ROCDL::ROCDLDialect>(); |
| } |
| void runOnOperation() override { |
| ModuleOp m = getOperation(); |
| |
| /// Customize the bitwidth used for the device side index computations. |
| LowerToLLVMOptions options(m.getContext(), DataLayout(m)); |
| options.overrideIndexBitwidth(64); |
| LLVMTypeConverter converter(m.getContext(), options); |
| // Apply in-dialect lowering first. In-dialect lowering will replace ops |
| // which need to be lowered further, which is not supported by a single |
| // conversion pass. |
| // Run Vector -> Vector transformations ahead of conversion to LLVM. |
| { |
| RewritePatternSet patterns(&getContext()); |
| populateScalarizeMathOps(patterns); |
| populateConvertSharedMemoryAllocOps(patterns); |
| populateLowerHALInterfaceOp(patterns); |
| vector::populateVectorToVectorCanonicalizationPatterns(patterns); |
| vector::populateVectorBroadcastLoweringPatterns(patterns); |
| vector::populateVectorContractLoweringPatterns( |
| patterns, |
| vector::VectorTransformsOptions().setVectorTransformsOptions( |
| vector::VectorContractLowering::OuterProduct)); |
| vector::populateVectorMaskOpLoweringPatterns(patterns); |
| vector::populateVectorShapeCastLoweringPatterns(patterns); |
| mlir::vector::populateVectorTransposeLoweringPatterns(patterns); |
| vector::populateVectorTransferLoweringPatterns(patterns); |
| if (failed(applyPatternsAndFoldGreedily(m, std::move(patterns)))) { |
| return signalPassFailure(); |
| } |
| } |
| { |
| RewritePatternSet patterns(&getContext()); |
| populateGpuRewritePatterns(patterns); |
| if (failed(applyPatternsAndFoldGreedily(m, std::move(patterns)))) { |
| return signalPassFailure(); |
| } |
| } |
| { |
| RewritePatternSet llvmPatterns(&getContext()); |
| populateLLVMConversionPatterns(&getContext(), llvmPatterns, converter); |
| populateMathToLLVMConversionPatterns(converter, llvmPatterns); |
| populateMemRefToLLVMConversionPatterns(converter, llvmPatterns); |
| populateStdToLLVMConversionPatterns(converter, llvmPatterns); |
| arith::populateArithmeticToLLVMConversionPatterns(converter, |
| llvmPatterns); |
| populateVectorToLLVMConversionPatterns(converter, llvmPatterns); |
| populateGpuToROCDLConversionPatterns(converter, llvmPatterns, |
| gpu::amd::Runtime::Unknown); |
| LLVMConversionTarget target(getContext()); |
| populateStdToLLVMFuncOpConversionPattern(converter, llvmPatterns); |
| configureGpuToROCDLConversionLegality(target); |
| target.addDynamicallyLegalOp<FuncOp>([&](FuncOp funcOp) { |
| if (isEntryPoint(funcOp)) return false; |
| return true; |
| }); |
| if (failed(applyPartialConversion(m, target, std::move(llvmPatterns)))) |
| signalPassFailure(); |
| } |
| } |
| }; |
| |
| } // anonymous namespace |
| |
| std::unique_ptr<OperationPass<ModuleOp>> createConvertToROCDLPass() { |
| return std::make_unique<ConvertToROCDLPass>(); |
| } |
| |
| } // namespace iree_compiler |
| } // namespace mlir |