| // Copyright 2021 The IREE Authors |
| // |
| // Licensed under the Apache License v2.0 with LLVM Exceptions. |
| // See https://llvm.org/LICENSE.txt for license information. |
| // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception |
| |
| #include "iree-dialects/Dialect/LinalgExt/IR/LinalgExtOps.h" |
| #include "iree-dialects/Dialect/LinalgExt/Transforms/Transforms.h" |
| #include "iree/compiler/Codegen/Dialect/LoweringConfig.h" |
| #include "iree/compiler/Codegen/LLVMGPU/KernelConfig.h" |
| #include "iree/compiler/Codegen/LLVMGPU/LLVMGPUUtils.h" |
| #include "iree/compiler/Codegen/PassDetail.h" |
| #include "iree/compiler/Codegen/Passes.h" |
| #include "iree/compiler/Codegen/Transforms/Transforms.h" |
| #include "iree/compiler/Codegen/Utils/MarkerUtils.h" |
| #include "iree/compiler/Dialect/Util/IR/UtilOps.h" |
| #include "mlir/Conversion/GPUToNVVM/GPUToNVVMPass.h" |
| #include "mlir/Conversion/StandardToLLVM/ConvertStandardToLLVM.h" |
| #include "mlir/Dialect/GPU/Passes.h" |
| #include "mlir/Dialect/LLVMIR/NVVMDialect.h" |
| #include "mlir/Dialect/StandardOps/IR/Ops.h" |
| #include "mlir/IR/Matchers.h" |
| #include "mlir/Support/MathExtras.h" |
| #include "mlir/Transforms/GreedyPatternRewriteDriver.h" |
| #include "mlir/Transforms/Passes.h" |
| |
| #define DEBUG_TYPE "iree-llvmgpu-tile-and-distribute" |
| |
| namespace mlir { |
| namespace iree_compiler { |
| |
| /// Patterns for workgroup-level tiling. Workgroup tiling is done at the flow |
| /// level, but there may be extra tiling needed for the reduction dimension. |
| /// Therefore we tile again here without distributing. |
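| /// The tile-size callback below keeps the already-distributed (partitioned) |
| /// loops untouched by forcing their tile size to zero, so only the remaining |
| /// reduction loops get wrapped in plain scf.for loops. As a rough sketch |
| /// (shapes and tile sizes are illustrative, not from a real configuration), a |
| /// linalg.matmul whose M and N dimensions were already distributed across |
| /// workgroups would end up with just its K loop tiled: |
| ///   scf.for %k = %c0 to %K step %c32 { linalg.matmul ... } |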
| static void populateTilingReductionPatterns( |
| OwningRewritePatternList &patterns) { |
| auto tileSizesFn = [&](OpBuilder &builder, |
| Operation *op) -> SmallVector<Value, 4> { |
| SmallVector<unsigned> partitionedLoops = getPartitionedLoops(op); |
| SmallVector<Value, 4> tileSizes = getTileSizes(builder, op, 0); |
| auto zero = builder.create<arith::ConstantIndexOp>(op->getLoc(), 0); |
| for (unsigned depth : partitionedLoops) { |
| if (depth < tileSizes.size()) { |
| tileSizes[depth] = zero; |
| } |
| } |
| return tileSizes; |
| }; |
| |
| auto tilingOptions = linalg::LinalgTilingOptions() |
| .setLoopType(linalg::LinalgTilingLoopType::Loops) |
| .setTileSizeComputationFunction(tileSizesFn); |
| MLIRContext *context = patterns.getContext(); |
| |
| linalg::LinalgTransformationFilter filter( |
| ArrayRef<StringAttr>{}, |
| StringAttr::get(context, getWorkgroupKTiledMarker())); |
| linalg::TilingPatterns<linalg::MatmulOp, linalg::BatchMatmulOp, |
| linalg::GenericOp>::insert(patterns, tilingOptions, |
| filter); |
| } |
| |
| /// Patterns for warp-level tiling. |
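| /// Warps are derived from the workgroup size, with the x dimension divided by |
| /// the warp size. Rough example (numbers are illustrative only): with |
| /// workgroupSize = {128, 2, 1} and kWarpSize = 32 this gives warpPerWorkgroup |
| /// = {4, 2, 1}, so a workload of {64, 32} per workgroup is split into |
| /// {64/4, 32/2} = {16, 16} tiles per warp before the sizes are reversed and |
| /// assigned to the partitioned loops. |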
| static void populateTilingToWarpPatterns( |
| OwningRewritePatternList &patterns, SmallVectorImpl<int64_t> &workgroupSize, |
| SmallVectorImpl<int64_t> &workloadPerWorkgroup) { |
| std::array<int64_t, 3> warpPerWorkgroup = { |
| workgroupSize[0] / kWarpSize, workgroupSize[1], workgroupSize[2]}; |
| |
| linalg::TileSizeComputationFunction getInnerTileSizeFn = |
| [&workloadPerWorkgroup, warpPerWorkgroup](OpBuilder &builder, |
| Operation *operation) { |
| SmallVector<Value, 4> tileSizesVal; |
| SmallVector<int64_t, 4> tileSizes; |
| for (auto workload : llvm::enumerate(workloadPerWorkgroup)) { |
| tileSizes.push_back(workload.value() / |
| warpPerWorkgroup[workload.index()]); |
| } |
| std::reverse(tileSizes.begin(), tileSizes.end()); |
| if (tileSizes.empty()) return SmallVector<Value, 4>(); |
| SmallVector<unsigned> partitionedLoops = getPartitionedLoops(operation); |
| unsigned maxDepth = partitionedLoops.back() + 1; |
| auto zero = |
| builder.create<arith::ConstantIndexOp>(operation->getLoc(), 0); |
| tileSizesVal.resize(maxDepth, zero); |
| size_t tileSizeIdx = 0; |
| for (unsigned depth : partitionedLoops) { |
| tileSizesVal[depth] = builder.create<arith::ConstantIndexOp>( |
| operation->getLoc(), tileSizes[tileSizeIdx++]); |
| if (tileSizeIdx == tileSizes.size()) break; |
| } |
| return tileSizesVal; |
| }; |
| auto getWarpProcInfoFn = [warpPerWorkgroup]( |
| OpBuilder &builder, Location loc, |
| ArrayRef<Range> parallelLoopRanges) { |
| return getSubgroupIdsAndCounts(builder, loc, parallelLoopRanges.size(), |
| warpPerWorkgroup); |
| }; |
| linalg::LinalgLoopDistributionOptions warpDistributionOptions; |
| warpDistributionOptions.procInfo = getWarpProcInfoFn; |
| warpDistributionOptions.distributionMethod = { |
| {linalg::DistributionMethod::Cyclic, linalg::DistributionMethod::Cyclic, |
| linalg::DistributionMethod::Cyclic}}; |
| |
| auto tilingOptions = linalg::LinalgTilingOptions() |
| .setLoopType(linalg::LinalgTilingLoopType::Loops) |
| .setTileSizeComputationFunction(getInnerTileSizeFn) |
| .setDistributionOptions(warpDistributionOptions); |
| MLIRContext *context = patterns.getContext(); |
| linalg::LinalgTransformationFilter filter( |
| {StringAttr::get(context, getWorkgroupKTiledMarker()), |
| StringAttr::get(context, getWorkgroupMemoryMarker())}, |
| StringAttr::get(context, getVectorizeMarker())); |
| filter.setMatchByDefault(); |
| linalg::TilingPatterns<linalg::MatmulOp, linalg::FillOp, linalg::CopyOp, |
| linalg::BatchMatmulOp, |
| linalg::GenericOp>::insert(patterns, tilingOptions, |
| filter); |
| } |
| |
| /// Patterns for thread-level tiling. |
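| /// This mirrors the warp-level tiling above but distributes to individual |
| /// threads: each dimension gets a tile size of workloadPerWorkgroup[i] / |
| /// workgroupSize[i]. Illustrative numbers only: workgroupSize = {32, 4, 1} |
| /// and workloadPerWorkgroup = {128, 32} yield per-thread tiles of {4, 8}, |
| /// which are then reversed and assigned to the partitioned loops. |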
| static void populateTilingToInvocationPatterns( |
| OwningRewritePatternList &patterns, SmallVectorImpl<int64_t> &workgroupSize, |
| SmallVectorImpl<int64_t> &workloadPerWorkgroup) { |
| linalg::TileSizeComputationFunction getInnerTileSizeFn = |
| [&](OpBuilder &builder, Operation *operation) { |
| SmallVector<Value, 4> tileSizesVal; |
| SmallVector<int64_t, 4> tileSizes; |
| for (auto workload : llvm::enumerate(workloadPerWorkgroup)) { |
| tileSizes.push_back(workload.value() / |
| workgroupSize[workload.index()]); |
| } |
| std::reverse(tileSizes.begin(), tileSizes.end()); |
| if (tileSizes.empty()) return SmallVector<Value, 4>(); |
| SmallVector<unsigned> partitionedLoops = getPartitionedLoops(operation); |
| unsigned maxDepth = partitionedLoops.back() + 1; |
| auto zero = |
| builder.create<arith::ConstantIndexOp>(operation->getLoc(), 0); |
| tileSizesVal.resize(maxDepth, zero); |
| size_t tileSizeIdx = 0; |
| for (unsigned depth : partitionedLoops) { |
| tileSizesVal[depth] = builder.create<arith::ConstantIndexOp>( |
| operation->getLoc(), tileSizes[tileSizeIdx++]); |
| if (tileSizeIdx == tileSizes.size()) break; |
| } |
| return tileSizesVal; |
| }; |
| |
| auto getThreadProcInfoFn = [&workgroupSize]( |
| OpBuilder &builder, Location loc, |
| ArrayRef<Range> parallelLoopRanges) { |
| return getGPUThreadIdsAndCounts(builder, loc, parallelLoopRanges.size(), |
| workgroupSize); |
| }; |
| linalg::LinalgLoopDistributionOptions invocationDistributionOptions; |
| invocationDistributionOptions.procInfo = getThreadProcInfoFn; |
| invocationDistributionOptions.distributionMethod = { |
| {linalg::DistributionMethod::Cyclic, linalg::DistributionMethod::Cyclic, |
| linalg::DistributionMethod::Cyclic}}; |
| |
| auto tilingOptions = |
| linalg::LinalgTilingOptions() |
| .setLoopType(linalg::LinalgTilingLoopType::Loops) |
| .setTileSizeComputationFunction(getInnerTileSizeFn) |
| .setDistributionOptions(invocationDistributionOptions); |
| |
| MLIRContext *context = patterns.getContext(); |
| linalg::LinalgTransformationFilter f( |
| {StringAttr::get(context, getWorkgroupKTiledMarker()), |
| StringAttr::get(context, getWorkgroupMemoryMarker())}, |
| StringAttr::get(context, getVectorizeMarker())); |
| f.addFilter([](Operation *op) { |
| // FFT doesn't support a second level of tiling yet. |
| return success(!isa<IREE::LinalgExt::FftOp>(op)); |
| }).setMatchByDefault(); |
| linalg::TilingPatterns< |
| linalg::MatmulOp, linalg::FillOp, linalg::CopyOp, linalg::BatchMatmulOp, |
| linalg::GenericOp, linalg::Conv2DNhwcHwcfOp, |
| linalg::DepthwiseConv2DNhwcHwcOp, linalg::DepthwiseConv2DNhwcHwcmOp, |
| linalg::PoolingNhwcMaxOp, linalg::PoolingNhwcMinOp, |
| linalg::PoolingNhwcSumOp>::insert(patterns, tilingOptions, f); |
| patterns.insert<IREE::LinalgExt::TiledOpInterfaceTilingPattern>( |
| context, tilingOptions, f); |
| } |
| |
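| /// Creates a linalg.copy from `src` to `dst` and tags it with the |
| /// copy-to-workgroup-memory marker so that barriers can be inserted around it |
| /// later in this pass. |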
| static LogicalResult copyToWorkgroupMemory(OpBuilder &b, Value src, Value dst) { |
| auto copyOp = b.create<linalg::CopyOp>(src.getLoc(), src, dst); |
| setMarker(copyOp, getCopyToWorkgroupMemoryMarker()); |
| return success(); |
| } |
| |
| static Optional<Value> allocateWorkgroupMemory( |
| OpBuilder &b, memref::SubViewOp subview, |
| ArrayRef<Value> boundingSubViewSize, DataLayout &layout) { |
| // In CUDA, workgroup memory is represented by a global variable. Create a |
| // global variable and a memref.GetGlobalOp at the beginning of the function |
| // to get the memref. |
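| // Illustrative result (element type, shape, and the numeric workgroup |
| // address space are examples only): |
| //   memref.global "private" @__shared_memory__ : memref<32x32xf32, 3> |
| //   ... |
| //   %0 = memref.get_global @__shared_memory__ : memref<32x32xf32, 3> |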
| OpBuilder::InsertionGuard guard(b); |
| FuncOp funcOp = subview->getParentOfType<FuncOp>(); |
| if (!funcOp) { |
| subview.emitError("expected op to be within std.func"); |
| return llvm::None; |
| } |
| ModuleOp moduleOp = funcOp->getParentOfType<ModuleOp>(); |
| SymbolTable symbolTable(moduleOp); |
| |
| // The bounding subview sizes are expected to be constant; they specify the |
| // shape of the allocation. |
| SmallVector<int64_t, 2> shape = llvm::to_vector<2>( |
| llvm::map_range(boundingSubViewSize, [](Value v) -> int64_t { |
| APInt value; |
| if (matchPattern(v, m_ConstantInt(&value))) return value.getSExtValue(); |
| return -1; |
| })); |
| if (llvm::any_of(shape, [](int64_t v) { return v == -1; })) return {}; |
| MemRefType allocType = |
| MemRefType::get(shape, subview.getType().getElementType(), {}, |
| gpu::GPUDialect::getWorkgroupAddressSpace()); |
| b.setInsertionPoint(&moduleOp.front()); |
| auto global = b.create<memref::GlobalOp>( |
| funcOp.getLoc(), "__shared_memory__", |
| /*sym_visibility=*/b.getStringAttr("private"), |
| /*type=*/allocType, |
| /*initial_value=*/ElementsAttr(), |
| /*constant=*/false, /*alignment=*/IntegerAttr()); |
| symbolTable.insert(global); |
| |
| b.setInsertionPointToStart(&(*funcOp.getBody().begin())); |
| Value buffer = b.create<memref::GetGlobalOp>(funcOp.getLoc(), global.type(), |
| global.getName()); |
| return buffer; |
| } |
| |
| static LogicalResult deallocateWorkgroupMemory(OpBuilder &b, Value buffer) { |
| // Nothing to do. |
| return success(); |
| } |
| |
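| /// Patterns promoting the LHS and RHS operands (operands 0 and 1) of matmul |
| /// and batch_matmul ops to workgroup memory, using the allocation and copy |
| /// callbacks above. Promoted ops move from the workgroup-K-tiled marker to |
| /// the workgroup-memory marker. |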
| static void populatePromotionPatterns(MLIRContext *context, |
| OwningRewritePatternList &patterns) { |
| patterns.insert<linalg::LinalgPromotionPattern<linalg::MatmulOp>, |
| linalg::LinalgPromotionPattern<linalg::BatchMatmulOp>>( |
| context, |
| linalg::LinalgPromotionOptions() |
| .setAllocationDeallocationFns(allocateWorkgroupMemory, |
| deallocateWorkgroupMemory) |
| .setCopyInOutFns(copyToWorkgroupMemory, copyToWorkgroupMemory) |
| .setOperandsToPromote({0, 1}) |
| .setUseFullTileBuffers({false, false}), |
| linalg::LinalgTransformationFilter( |
| {StringAttr::get(context, getWorkgroupKTiledMarker())}, |
| StringAttr::get(context, getWorkgroupMemoryMarker()))); |
| } |
| |
| namespace { |
| struct LLVMGPUTileAndDistributePass |
| : public LLVMGPUTileAndDistributeBase<LLVMGPUTileAndDistributePass> { |
| private: |
| // Distribute the workload to warps if true, otherwise distribute to threads. |
| bool distributeToWarp = false; |
| |
| public: |
| LLVMGPUTileAndDistributePass(bool distributeToWarp) |
| : distributeToWarp(distributeToWarp) {} |
| void getDependentDialects(DialectRegistry ®istry) const override { |
| registry.insert<AffineDialect, gpu::GPUDialect>(); |
| } |
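| // High-level flow of runOnOperation below: |
| //   1. Tile the reduction dimensions at the workgroup level. |
| //   2. If the workgroup has more than one warp, promote matmul operands to |
| //      workgroup memory and insert the required barriers around the copies. |
| //   3. Tile and distribute the remaining parallel dimensions to warps or |
| //      threads, depending on `distributeToWarp`. |
| //   Canonicalization patterns are applied after each step. |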
| void runOnOperation() override { |
| MLIRContext *context = &getContext(); |
| auto funcOp = getOperation(); |
| if (!isEntryPoint(funcOp)) return; |
| { |
| // Tile again at the workgroup level since the reduction dimensions were |
| // ignored. Dimensions already tiled will be ignored since we tile to the |
| // same size. |
| OwningRewritePatternList wgTilingPatterns(context); |
| populateTilingReductionPatterns(wgTilingPatterns); |
| if (failed(applyPatternsAndFoldGreedily(funcOp, |
| std::move(wgTilingPatterns)))) { |
| return signalPassFailure(); |
| } |
| } |
| |
| { |
| RewritePatternSet wgTilingCanonicalizationPatterns = |
| linalg::getLinalgTilingCanonicalizationPatterns(context); |
| populateAffineMinSCFCanonicalizationPattern( |
| wgTilingCanonicalizationPatterns); |
| if (failed(applyPatternsAndFoldGreedily( |
| funcOp, std::move(wgTilingCanonicalizationPatterns)))) { |
| return signalPassFailure(); |
| } |
| } |
| |
| LLVM_DEBUG({ |
| llvm::dbgs() << "After tile reductions:"; |
| funcOp.dump(); |
| }); |
| |
| auto workgroupSize = llvm::to_vector<4>(llvm::map_range( |
| getEntryPoint(funcOp).workgroup_size().getValue(), |
| [&](Attribute attr) { return attr.cast<IntegerAttr>().getInt(); })); |
| auto workloadPerWorkgroup = |
| getTranslationInfo(getEntryPoint(funcOp)).getWorkloadPerWorkgroupVals(); |
| |
| int64_t flatWorkgroupSize = |
| workgroupSize[0] * workgroupSize[1] * workgroupSize[2]; |
| // Only promote to workgroup memory if there are multiple warps. |
| if (flatWorkgroupSize > kWarpSize) { |
| OwningRewritePatternList promotionPatterns(&getContext()); |
| populatePromotionPatterns(context, promotionPatterns); |
| if (failed(applyPatternsAndFoldGreedily(funcOp, |
| std::move(promotionPatterns)))) { |
| return signalPassFailure(); |
| } |
| // Insert barriers before and after copies to workgroup memory, and skip |
| // inserting barriers between back-to-back copies to workgroup memory. |
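| // Illustrative placement (pseudo-IR, markers elided): |
| //   gpu.barrier |
| //   linalg.copy(%srcA, %dstA)  // copy to workgroup memory |
| //   linalg.copy(%srcB, %dstB)  // back-to-back copy, no barrier in between |
| //   gpu.barrier |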
| OpBuilder builder(&getContext()); |
| funcOp.walk([&builder](linalg::CopyOp copyOp) { |
| if (hasMarker(copyOp, getCopyToWorkgroupMemoryMarker())) { |
| Operation *prevOp = copyOp->getPrevNode(); |
| if (!prevOp || !hasMarker(prevOp, getCopyToWorkgroupMemoryMarker())) { |
| builder.setInsertionPoint(copyOp); |
| builder.create<gpu::BarrierOp>(copyOp.getLoc()); |
| } |
| Operation *nextOp = copyOp->getNextNode(); |
| if (!nextOp || !hasMarker(nextOp, getCopyToWorkgroupMemoryMarker())) { |
| builder.setInsertionPointAfter(copyOp); |
| builder.create<gpu::BarrierOp>(copyOp.getLoc()); |
| } |
| } |
| }); |
| } |
| |
| { |
| RewritePatternSet promotionCanonicalization = |
| linalg::getLinalgTilingCanonicalizationPatterns(context); |
| if (failed(applyPatternsAndFoldGreedily( |
| funcOp, std::move(promotionCanonicalization)))) { |
| return signalPassFailure(); |
| } |
| } |
| |
| LLVM_DEBUG({ |
| llvm::dbgs() << "After promotion:"; |
| funcOp.dump(); |
| }); |
| |
| if (distributeToWarp) { |
| // Apply last level of tiling and distribute to warps. |
| OwningRewritePatternList warpLevelTilingPatterns(context); |
| populateTilingToWarpPatterns(warpLevelTilingPatterns, workgroupSize, |
| workloadPerWorkgroup); |
| if (failed(applyPatternsAndFoldGreedily( |
| funcOp, std::move(warpLevelTilingPatterns)))) { |
| return signalPassFailure(); |
| } |
| |
| } else { |
| // Apply last level of tiling and distribute to threads. |
| OwningRewritePatternList threadLevelTilingPatterns(context); |
| populateTilingToInvocationPatterns(threadLevelTilingPatterns, |
| workgroupSize, workloadPerWorkgroup); |
| if (failed(applyPatternsAndFoldGreedily( |
| funcOp, std::move(threadLevelTilingPatterns)))) { |
| return signalPassFailure(); |
| } |
| } |
| { |
| // Apply canonicalization patterns. |
| RewritePatternSet threadTilingCanonicalizationPatterns = |
| linalg::getLinalgTilingCanonicalizationPatterns(context); |
| populateAffineMinSCFCanonicalizationPattern( |
| threadTilingCanonicalizationPatterns); |
| if (failed(applyPatternsAndFoldGreedily( |
| funcOp, std::move(threadTilingCanonicalizationPatterns)))) { |
| return signalPassFailure(); |
| } |
| } |
| |
| LLVM_DEBUG({ |
| llvm::dbgs() << "After tile and distribute to threads:"; |
| funcOp.dump(); |
| }); |
| } |
| }; |
| } // namespace |
| |
| std::unique_ptr<OperationPass<FuncOp>> createLLVMGPUTileAndDistribute( |
| bool distributeToWarp) { |
| return std::make_unique<LLVMGPUTileAndDistributePass>(distributeToWarp); |
| } |
| |
| } // namespace iree_compiler |
| } // namespace mlir |