Add software pipeline depth to lowering config (#9114)
diff --git a/compiler/src/iree/compiler/Codegen/Dialect/LoweringConfig.cpp b/compiler/src/iree/compiler/Codegen/Dialect/LoweringConfig.cpp index df9ead9..d9a613d 100644 --- a/compiler/src/iree/compiler/Codegen/Dialect/LoweringConfig.cpp +++ b/compiler/src/iree/compiler/Codegen/Dialect/LoweringConfig.cpp
@@ -75,12 +75,13 @@ TranslationInfoAttr TranslationInfoAttr::get( MLIRContext *context, DispatchLoweringPassPipeline passPipeline, - ArrayRef<int64_t> workloadPerWorkgroup) { + ArrayRef<int64_t> workloadPerWorkgroup, unsigned softwarePipelineDepth) { auto pipelineAttr = DispatchLoweringPassPipelineAttr::get(context, passPipeline); ArrayAttr workloadPerWorkgroupAttr = getI64IntegerArrayAttr(context, workloadPerWorkgroup); - return get(context, pipelineAttr, workloadPerWorkgroupAttr); + return get(context, pipelineAttr, workloadPerWorkgroupAttr, + softwarePipelineDepth); } DispatchLoweringPassPipeline @@ -95,7 +96,7 @@ LogicalResult TranslationInfoAttr::verify( function_ref<InFlightDiagnostic()> emitError, IREE::Codegen::DispatchLoweringPassPipelineAttr passPipeline, - ArrayAttr workloadPerWorkgroup) { + ArrayAttr workloadPerWorkgroup, unsigned softwarePipelineDepth) { if (!passPipeline) { return emitError() << "missing pass pipeline specification"; } @@ -245,7 +246,8 @@ } if (failed(TranslationInfoAttr::verify( emitError, translationInfo.getPassPipeline(), - translationInfo.getWorkloadPerWorkgroup()))) { + translationInfo.getWorkloadPerWorkgroup(), + translationInfo.getSoftwarePipelineDepth()))) { return failure(); } if (workgroupSize) {
diff --git a/compiler/src/iree/compiler/Codegen/Dialect/LoweringConfig.h b/compiler/src/iree/compiler/Codegen/Dialect/LoweringConfig.h index 3a6508e..af89154 100644 --- a/compiler/src/iree/compiler/Codegen/Dialect/LoweringConfig.h +++ b/compiler/src/iree/compiler/Codegen/Dialect/LoweringConfig.h
@@ -84,7 +84,8 @@ inline void setTranslationInfo( func::FuncOp entryPointFn, IREE::Codegen::DispatchLoweringPassPipeline passPipeline, - ArrayRef<int64_t> workloadPerWorkgroup, ArrayRef<int64_t> workgroupSize) { + ArrayRef<int64_t> workloadPerWorkgroup, ArrayRef<int64_t> workgroupSize, + unsigned softwarePipelineDepth = 0) { auto entryPointOp = getEntryPoint(entryPointFn); MLIRContext *context = entryPointFn.getContext(); auto translationInfo = IREE::Codegen::TranslationInfoAttr::get( @@ -117,12 +118,12 @@ inline LogicalResult setOpConfigAndEntryPointFnTranslation( func::FuncOp entryPointFn, Operation *op, TileSizesListTypeRef tileSizes, IREE::Codegen::DispatchLoweringPassPipeline passPipeline, - ArrayRef<int64_t> workgroupSize = {}) { + ArrayRef<int64_t> workgroupSize = {}, unsigned softwarePipelineDepth = 0) { MLIRContext *context = entryPointFn.getContext(); auto config = IREE::Codegen::LoweringConfigAttr::get(context, tileSizes); setLoweringConfig(op, config); auto translationInfo = IREE::Codegen::TranslationInfoAttr::get( - entryPointFn->getContext(), passPipeline); + entryPointFn->getContext(), passPipeline, {}, softwarePipelineDepth); setTranslationInfo(entryPointFn, translationInfo, workgroupSize); return success(); }
diff --git a/compiler/src/iree/compiler/Codegen/Dialect/LoweringConfig.td b/compiler/src/iree/compiler/Codegen/Dialect/LoweringConfig.td index 9ae145c..e5b117b 100644 --- a/compiler/src/iree/compiler/Codegen/Dialect/LoweringConfig.td +++ b/compiler/src/iree/compiler/Codegen/Dialect/LoweringConfig.td
@@ -104,18 +104,22 @@ }]; let assemblyFormat = [{ - `<` `` $passPipeline (`workload_per_wg` `=` $workloadPerWorkgroup^)? `>` + `<` `` $passPipeline (`workload_per_wg` `=` $workloadPerWorkgroup^)? + (`pipeline_depth` `=` $softwarePipelineDepth^)? `>` }]; let parameters = (ins AttrParameter<"IREE::Codegen::DispatchLoweringPassPipelineAttr", "Name of the pipeline to be invoked on the translation unit.">:$passPipeline, DefaultValuedParameter<"ArrayAttr", "ArrayAttr::get($_ctx, {})", - "The workload mapped to a single workgroup">:$workloadPerWorkgroup + "The workload mapped to a single workgroup">:$workloadPerWorkgroup, + DefaultValuedParameter<"unsigned", "1", + "The software pipeline depth to be used">:$softwarePipelineDepth ); let builders = [ AttrBuilder<(ins "DispatchLoweringPassPipeline":$passPipeline, - CArg<"ArrayRef<int64_t>", "{}">:$workloadPerWorkgroup)> + CArg<"ArrayRef<int64_t>", "{}">:$workloadPerWorkgroup, + CArg<"unsigned", "0">:$softwarePipelineDepth)> ]; let extraClassDeclaration = [{ // Returns the lowering pass pipeline set.
diff --git a/compiler/src/iree/compiler/Codegen/LLVMGPU/KernelConfig.cpp b/compiler/src/iree/compiler/Codegen/LLVMGPU/KernelConfig.cpp index c407baf..c6964d7 100644 --- a/compiler/src/iree/compiler/Codegen/LLVMGPU/KernelConfig.cpp +++ b/compiler/src/iree/compiler/Codegen/LLVMGPU/KernelConfig.cpp
@@ -29,6 +29,11 @@ std::array<int64_t, 3> tileSize; std::array<int64_t, 3> workgroupSize; }; + +// Software pipeline depths +constexpr unsigned softwarePipelineDepthTensorCore = 4; +// Simt codegen does not do software pipelining. +constexpr unsigned softwarePipelineDepthSimt = 0; } // namespace /// Return the best combination of tile size and wg size. It will then used to @@ -118,6 +123,7 @@ auto setMatmulConfig = [&entryPoint, &op](int64_t tileX, int64_t tileY, int64_t tileK, llvm::ArrayRef<int64_t> workgroupSize, + unsigned softwarePipelineDepth, IREE::Codegen::DispatchLoweringPassPipeline pipeline) { TileSizesListType tileSizes; unsigned numParallelLoops = op.getNumParallelLoops(); @@ -140,7 +146,8 @@ tileSizes.emplace_back( std::move(workgroupTileSizes)); // Workgroup level. return setOpConfigAndEntryPointFnTranslation(entryPoint, op, tileSizes, - pipeline, workgroupSize); + pipeline, workgroupSize, + softwarePipelineDepth); }; // Infer the MxN size of the matmul based on operands and indexing maps. auto lhsShape = @@ -196,17 +203,19 @@ if (sizeK % config.tileSize[2] == 0 && sizeN % config.tileSize[1] == 0 && sizeM % config.tileSize[0] == 0) { - return setMatmulConfig(config.tileSize[0], config.tileSize[1], - config.tileSize[2], config.workgroupSize, - IREE::Codegen::DispatchLoweringPassPipeline:: - LLVMGPUMatmulTensorCore); + return setMatmulConfig( + config.tileSize[0], config.tileSize[1], config.tileSize[2], + config.workgroupSize, + sizeK == config.tileSize[2] ? 1 : softwarePipelineDepthTensorCore, + IREE::Codegen::DispatchLoweringPassPipeline:: + LLVMGPUMatmulTensorCore); } } } // Special case for very small matrices. if (sizeM * sizeN <= cudaWarpSize) { return setMatmulConfig( - sizeN, sizeM, 4, {sizeM, sizeN, 1}, + sizeN, sizeM, 4, {sizeM, sizeN, 1}, softwarePipelineDepthSimt, IREE::Codegen::DispatchLoweringPassPipeline::LLVMGPUMatmulSimt); } // simt matmul case @@ -219,7 +228,7 @@ if (sizeN % config.tileSize[1] == 0 && sizeM % config.tileSize[0] == 0) { return setMatmulConfig( config.tileSize[0], config.tileSize[1], config.tileSize[2], - config.workgroupSize, + config.workgroupSize, softwarePipelineDepthSimt, IREE::Codegen::DispatchLoweringPassPipeline::LLVMGPUMatmulSimt); } } @@ -230,7 +239,7 @@ int64_t tileK = 4; SmallVector<int64_t, 3> workgroupSize = {2 * cudaWarpSize, 1, 1}; return setMatmulConfig( - tileX, tileY, tileK, workgroupSize, + tileX, tileY, tileK, workgroupSize, softwarePipelineDepthSimt, IREE::Codegen::DispatchLoweringPassPipeline::LLVMGPUMatmulSimt); }
diff --git a/compiler/src/iree/compiler/Codegen/LLVMGPU/LLVMGPULowerExecutableTarget.cpp b/compiler/src/iree/compiler/Codegen/LLVMGPU/LLVMGPULowerExecutableTarget.cpp index eb89819..4f24c26 100644 --- a/compiler/src/iree/compiler/Codegen/LLVMGPU/LLVMGPULowerExecutableTarget.cpp +++ b/compiler/src/iree/compiler/Codegen/LLVMGPU/LLVMGPULowerExecutableTarget.cpp
@@ -155,7 +155,9 @@ addGPUMatmulSimtPassPipeline(nestedModulePM); break; case IREE::Codegen::DispatchLoweringPassPipeline::LLVMGPUMatmulTensorCore: - addGPUMatmulTensorCorePassPipeline(nestedModulePM); + addGPUMatmulTensorCorePassPipeline( + nestedModulePM, + translationInfo.getValue().getSoftwarePipelineDepth()); break; default: variantOp.emitOpError("Unsupported pipeline on GPU target.");
diff --git a/compiler/src/iree/compiler/Codegen/LLVMGPU/Passes.cpp b/compiler/src/iree/compiler/Codegen/LLVMGPU/Passes.cpp index db61284..752d5cf 100644 --- a/compiler/src/iree/compiler/Codegen/LLVMGPU/Passes.cpp +++ b/compiler/src/iree/compiler/Codegen/LLVMGPU/Passes.cpp
@@ -24,11 +24,6 @@ namespace mlir { namespace iree_compiler { -// TODO(thomasraoux): Add a new optional attribute to translate info. -static llvm::cl::opt<unsigned> pipelineDepth("iree-codegen-cuda-pipeline-depth", - llvm::cl::desc("Pipeline depth"), - llvm::cl::init(4)); - static llvm::cl::opt<unsigned> logSwizzleTile( "iree-codegen-log-swizzle-tile", llvm::cl::desc("log swizzle tile value"), llvm::cl::init(0)); @@ -109,7 +104,8 @@ pm.addNestedPass<func::FuncOp>(createGPUPipeliningPass()); } -void addGPUMatmulTensorCorePassPipeline(OpPassManager &pm) { +void addGPUMatmulTensorCorePassPipeline(OpPassManager &pm, + unsigned pipelineDepth) { tileAndBufferize(pm); // Distribute linalg onto warps within the workgroup.
diff --git a/compiler/src/iree/compiler/Codegen/LLVMGPU/Verifiers.cpp b/compiler/src/iree/compiler/Codegen/LLVMGPU/Verifiers.cpp index e832517..da06e2e 100644 --- a/compiler/src/iree/compiler/Codegen/LLVMGPU/Verifiers.cpp +++ b/compiler/src/iree/compiler/Codegen/LLVMGPU/Verifiers.cpp
@@ -96,6 +96,7 @@ ArrayRef<int64_t> workgroupSize) { auto pipeline = IREE::Codegen::DispatchLoweringPassPipeline::LLVMGPUMatmulTensorCore; + unsigned softwarePipelinedepth = translationInfo.getSoftwarePipelineDepth(); StringRef pipelineName = stringifyEnum(pipeline); if (workgroupSize.empty()) { return op->emitOpError("expected workgroup size for GPU pipelines"); @@ -153,7 +154,15 @@ return op->emitOpError("workgroup size is not 32 aligned for ") << pipelineName << ", got " << workgroupSize[0]; } - + if (softwarePipelinedepth > 1 && firstLevelTileSizes[2] == lhsShape[1]) { + return op->emitError( + "Software pipelining is not supported when first level K tile " + "size is same as matrix reduction size.\n This dispatch has\nk " + "tile: ") + << firstLevelTileSizes[2] + << "\nMatrix reduction size: " << lhsShape[1] + << "\nPipelinedepth: " << softwarePipelinedepth; + } // Verify the workgroup.z component should always be 1 if (workgroupSize[2] != 1) { return op->emitOpError("expected workgroup z component to be 1 for ")
diff --git a/compiler/src/iree/compiler/Codegen/LLVMGPU/test/nvvm_pipeline_test.mlir b/compiler/src/iree/compiler/Codegen/LLVMGPU/test/nvvm_pipeline_test.mlir index 7c62e49..5698b63 100644 --- a/compiler/src/iree/compiler/Codegen/LLVMGPU/test/nvvm_pipeline_test.mlir +++ b/compiler/src/iree/compiler/Codegen/LLVMGPU/test/nvvm_pipeline_test.mlir
@@ -1,5 +1,4 @@ -// RUN: iree-opt --split-input-file --pass-pipeline='hal.executable(hal.executable.variant(iree-codegen-linalg-to-nvvm-pipeline))' --iree-codegen-cuda-pipeline-depth=1 %s | FileCheck %s -// RUN: iree-opt --split-input-file --pass-pipeline='hal.executable(hal.executable.variant(iree-codegen-linalg-to-nvvm-pipeline))' --iree-codegen-cuda-pipeline-depth=4 %s | FileCheck %s --check-prefix=CHECKP +// RUN: iree-opt --split-input-file --pass-pipeline='hal.executable(hal.executable.variant(iree-codegen-linalg-to-nvvm-pipeline))' %s | FileCheck %s // Verify that a simple element wise op gets lowered succefully all the way to // nvvm/llvm dialect. @@ -421,58 +420,40 @@ } } +// case with larger pipeline depth // CHECK-LABEL: hal.executable public @mma_fused // CHECK: hal.executable.variant public @cuda // CHECK-NOT: llvm.store // CHECK-COUNT-2: nvvm.cp.async.shared.global {{.*}}, {{.*}}, 16 // CHECK: nvvm.cp.async.commit.group +// CHECK-COUNT-2: nvvm.cp.async.shared.global {{.*}}, {{.*}}, 16 +// CHECK: nvvm.cp.async.commit.group +// CHECK-COUNT-2: nvvm.cp.async.shared.global {{.*}}, {{.*}}, 16 +// CHECK: nvvm.cp.async.commit.group +// CHECK-COUNT-2: nvvm.cp.async.shared.global {{.*}}, {{.*}}, 16 +// CHECK: nvvm.cp.async.commit.group // CHECK: llvm.br -// CHECK: nvvm.cp.async.wait.group 0 +// CHECK: nvvm.cp.async.wait.group 3 // CHECK-COUNT-4: nvvm.wmma.load{{.*}} : (!llvm.ptr<f32, 3>) -> !llvm.struct<(i32, i32, i32, i32) // CHECK-COUNT-2: nvvm.wmma.mma // CHECK-COUNT-2: nvvm.cp.async.shared.global {{.*}}, {{.*}}, 16 // CHECK: nvvm.cp.async.commit.group // CHECK: llvm.br +// CHECK: nvvm.cp.async.wait.group 3 +// CHECK-COUNT-4: nvvm.wmma.load{{.*}} : (!llvm.ptr<f32, 3>) -> !llvm.struct<(i32, i32, i32, i32) +// CHECK-COUNT-2: nvvm.wmma.mma +// CHECK: nvvm.cp.async.wait.group 2 +// CHECK-COUNT-4: nvvm.wmma.load{{.*}} : (!llvm.ptr<f32, 3>) -> !llvm.struct<(i32, i32, i32, i32) +// CHECK-COUNT-2: nvvm.wmma.mma +// CHECK: nvvm.cp.async.wait.group 1 +// CHECK-COUNT-4: nvvm.wmma.load{{.*}} : (!llvm.ptr<f32, 3>) -> !llvm.struct<(i32, i32, i32, i32) +// CHECK-COUNT-2: nvvm.wmma.mma // CHECK: nvvm.cp.async.wait.group 0 // CHECK-COUNT-4: nvvm.wmma.load{{.*}} : (!llvm.ptr<f32, 3>) -> !llvm.struct<(i32, i32, i32, i32) // CHECK-COUNT-2: nvvm.wmma.mma // CHECK-COUNT-8: llvm.fadd // CHECK-COUNT-1: nvvm.wmma.store {{.*}} : !llvm.ptr<f32>, f32, f32, f32, f32, f32, f32, f32, f32 -// case with larger pipeline depth -// CHECKP-LABEL: hal.executable public @mma_fused -// CHECKP: hal.executable.variant public @cuda -// CHECKP-NOT: llvm.store -// CHECKP-COUNT-2: nvvm.cp.async.shared.global {{.*}}, {{.*}}, 16 -// CHECKP: nvvm.cp.async.commit.group -// CHECKP-COUNT-2: nvvm.cp.async.shared.global {{.*}}, {{.*}}, 16 -// CHECKP: nvvm.cp.async.commit.group -// CHECKP-COUNT-2: nvvm.cp.async.shared.global {{.*}}, {{.*}}, 16 -// CHECKP: nvvm.cp.async.commit.group -// CHECKP-COUNT-2: nvvm.cp.async.shared.global {{.*}}, {{.*}}, 16 -// CHECKP: nvvm.cp.async.commit.group -// CHECKP: llvm.br -// CHECKP: nvvm.cp.async.wait.group 3 -// CHECKP-COUNT-4: nvvm.wmma.load{{.*}} : (!llvm.ptr<f32, 3>) -> !llvm.struct<(i32, i32, i32, i32) -// CHECKP-COUNT-2: nvvm.wmma.mma -// CHECKP-COUNT-2: nvvm.cp.async.shared.global {{.*}}, {{.*}}, 16 -// CHECKP: nvvm.cp.async.commit.group -// CHECKP: llvm.br -// CHECKP: nvvm.cp.async.wait.group 3 -// CHECKP-COUNT-4: nvvm.wmma.load{{.*}} : (!llvm.ptr<f32, 3>) -> !llvm.struct<(i32, i32, i32, i32) -// CHECKP-COUNT-2: nvvm.wmma.mma -// CHECKP: nvvm.cp.async.wait.group 2 -// CHECKP-COUNT-4: nvvm.wmma.load{{.*}} : (!llvm.ptr<f32, 3>) -> !llvm.struct<(i32, i32, i32, i32) -// CHECKP-COUNT-2: nvvm.wmma.mma -// CHECKP: nvvm.cp.async.wait.group 1 -// CHECKP-COUNT-4: nvvm.wmma.load{{.*}} : (!llvm.ptr<f32, 3>) -> !llvm.struct<(i32, i32, i32, i32) -// CHECKP-COUNT-2: nvvm.wmma.mma -// CHECKP: nvvm.cp.async.wait.group 0 -// CHECKP-COUNT-4: nvvm.wmma.load{{.*}} : (!llvm.ptr<f32, 3>) -> !llvm.struct<(i32, i32, i32, i32) -// CHECKP-COUNT-2: nvvm.wmma.mma -// CHECKP-COUNT-8: llvm.fadd -// CHECKP-COUNT-1: nvvm.wmma.store {{.*}} : !llvm.ptr<f32>, f32, f32, f32, f32, f32, f32, f32, f32 - // ----- #executable_layout = #hal.executable.layout<push_constants = 0, sets = [ @@ -521,57 +502,40 @@ } } +// case with larger pipeline depth // CHECK-LABEL: hal.executable public @mma_fused_fp16 // CHECK: hal.executable.variant public @cuda // CHECK-NOT: llvm.store // CHECK-COUNT-2: nvvm.cp.async.shared.global {{.*}}, {{.*}}, 16 // CHECK: nvvm.cp.async.commit.group +// CHECK-COUNT-2: nvvm.cp.async.shared.global {{.*}}, {{.*}}, 16 +// CHECK: nvvm.cp.async.commit.group +// CHECK-COUNT-2: nvvm.cp.async.shared.global {{.*}}, {{.*}}, 16 +// CHECK: nvvm.cp.async.commit.group +// CHECK-COUNT-2: nvvm.cp.async.shared.global {{.*}}, {{.*}}, 16 +// CHECK: nvvm.cp.async.commit.group // CHECK: llvm.br -// CHECK: nvvm.cp.async.wait.group 0 +// CHECK: nvvm.cp.async.wait.group 3 // CHECK-COUNT-2: nvvm.wmma.load{{.*}} : (!llvm.ptr<f16, 3>) -> !llvm.struct<(vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>) // CHECK-COUNT-1: nvvm.wmma.mma // CHECK-COUNT-2: nvvm.cp.async.shared.global {{.*}}, {{.*}}, 16 // CHECK: nvvm.cp.async.commit.group // CHECK: llvm.br +// CHECK: nvvm.cp.async.wait.group 3 +// CHECK-COUNT-2: nvvm.wmma.load{{.*}} : (!llvm.ptr<f16, 3>) -> !llvm.struct<(vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>) +// CHECK-COUNT-1: nvvm.wmma.mma +// CHECK: nvvm.cp.async.wait.group 2 +// CHECK-COUNT-2: nvvm.wmma.load{{.*}} : (!llvm.ptr<f16, 3>) -> !llvm.struct<(vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>) +// CHECK-COUNT-1: nvvm.wmma.mma +// CHECK: nvvm.cp.async.wait.group 1 +// CHECK-COUNT-2: nvvm.wmma.load{{.*}} : (!llvm.ptr<f16, 3>) -> !llvm.struct<(vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>) +// CHECK-COUNT-1: nvvm.wmma.mma +// CHECK: nvvm.cp.async.wait.group 0 // CHECK-COUNT-2: nvvm.wmma.load{{.*}} : (!llvm.ptr<f16, 3>) -> !llvm.struct<(vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>) // CHECK-COUNT-1: nvvm.wmma.mma // CHECK-COUNT-4: llvm.fadd // CHECK-COUNT-1: nvvm.wmma.store {{.*}} : !llvm.ptr<f16>, vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16> -// case with larger pipeline depth -// CHECKP-LABEL: hal.executable public @mma_fused_fp16 -// CHECKP: hal.executable.variant public @cuda -// CHECKP-NOT: llvm.store -// CHECKP-COUNT-2: nvvm.cp.async.shared.global {{.*}}, {{.*}}, 16 -// CHECKP: nvvm.cp.async.commit.group -// CHECKP-COUNT-2: nvvm.cp.async.shared.global {{.*}}, {{.*}}, 16 -// CHECKP: nvvm.cp.async.commit.group -// CHECKP-COUNT-2: nvvm.cp.async.shared.global {{.*}}, {{.*}}, 16 -// CHECKP: nvvm.cp.async.commit.group -// CHECKP-COUNT-2: nvvm.cp.async.shared.global {{.*}}, {{.*}}, 16 -// CHECKP: nvvm.cp.async.commit.group -// CHECKP: llvm.br -// CHECKP: nvvm.cp.async.wait.group 3 -// CHECKP-COUNT-2: nvvm.wmma.load{{.*}} : (!llvm.ptr<f16, 3>) -> !llvm.struct<(vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>) -// CHECKP-COUNT-1: nvvm.wmma.mma -// CHECKP-COUNT-2: nvvm.cp.async.shared.global {{.*}}, {{.*}}, 16 -// CHECKP: nvvm.cp.async.commit.group -// CHECKP: llvm.br -// CHECKP: nvvm.cp.async.wait.group 3 -// CHECKP-COUNT-2: nvvm.wmma.load{{.*}} : (!llvm.ptr<f16, 3>) -> !llvm.struct<(vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>) -// CHECKP-COUNT-1: nvvm.wmma.mma -// CHECKP: nvvm.cp.async.wait.group 2 -// CHECKP-COUNT-2: nvvm.wmma.load{{.*}} : (!llvm.ptr<f16, 3>) -> !llvm.struct<(vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>) -// CHECKP-COUNT-1: nvvm.wmma.mma -// CHECKP: nvvm.cp.async.wait.group 1 -// CHECKP-COUNT-2: nvvm.wmma.load{{.*}} : (!llvm.ptr<f16, 3>) -> !llvm.struct<(vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>) -// CHECKP-COUNT-1: nvvm.wmma.mma -// CHECKP: nvvm.cp.async.wait.group 0 -// CHECKP-COUNT-2: nvvm.wmma.load{{.*}} : (!llvm.ptr<f16, 3>) -> !llvm.struct<(vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>) -// CHECKP-COUNT-1: nvvm.wmma.mma -// CHECKP-COUNT-4: llvm.fadd -// CHECKP-COUNT-1: nvvm.wmma.store {{.*}} : !llvm.ptr<f16>, vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16> - // ----- #executable_layout = #hal.executable.layout<push_constants = 0, sets = [ @@ -620,56 +584,40 @@ } } } + +// case with larger pipeline depth // CHECK-LABEL: hal.executable public @large_dot_general_dispatch_0 // CHECK: hal.executable.variant public @cuda // CHECK-NOT: llvm.store // CHECK-COUNT-2: nvvm.cp.async.shared.global {{.*}}, {{.*}}, 16 // CHECK: nvvm.cp.async.commit.group +// CHECK-COUNT-2: nvvm.cp.async.shared.global {{.*}}, {{.*}}, 16 +// CHECK: nvvm.cp.async.commit.group +// CHECK-COUNT-2: nvvm.cp.async.shared.global {{.*}}, {{.*}}, 16 +// CHECK: nvvm.cp.async.commit.group +// CHECK-COUNT-2: nvvm.cp.async.shared.global {{.*}}, {{.*}}, 16 +// CHECK: nvvm.cp.async.commit.group // CHECK: llvm.br -// CHECK: nvvm.cp.async.wait.group 0 +// CHECK: nvvm.cp.async.wait.group 3 // CHECK-COUNT-4: nvvm.wmma.load{{.*}} : (!llvm.ptr<f32, 3>) -> !llvm.struct<(i32, i32, i32, i32) // CHECK-COUNT-2: nvvm.wmma.mma // CHECK-COUNT-2: nvvm.cp.async.shared.global {{.*}}, {{.*}}, 16 // CHECK: nvvm.cp.async.commit.group // CHECK: llvm.br +// CHECK: nvvm.cp.async.wait.group 3 +// CHECK-COUNT-4: nvvm.wmma.load{{.*}} : (!llvm.ptr<f32, 3>) -> !llvm.struct<(i32, i32, i32, i32) +// CHECK-COUNT-2: nvvm.wmma.mma +// CHECK: nvvm.cp.async.wait.group 2 +// CHECK-COUNT-4: nvvm.wmma.load{{.*}} : (!llvm.ptr<f32, 3>) -> !llvm.struct<(i32, i32, i32, i32) +// CHECK-COUNT-2: nvvm.wmma.mma +// CHECK: nvvm.cp.async.wait.group 1 +// CHECK-COUNT-4: nvvm.wmma.load{{.*}} : (!llvm.ptr<f32, 3>) -> !llvm.struct<(i32, i32, i32, i32) +// CHECK-COUNT-2: nvvm.wmma.mma // CHECK: nvvm.cp.async.wait.group 0 // CHECK-COUNT-4: nvvm.wmma.load{{.*}} : (!llvm.ptr<f32, 3>) -> !llvm.struct<(i32, i32, i32, i32) // CHECK-COUNT-2: nvvm.wmma.mma // CHECK-COUNT-1: nvvm.wmma.store {{.*}} : !llvm.ptr<f32>, f32, f32, f32, f32, f32, f32, f32, f32 -// case with larger pipeline depth -// CHECKP-LABEL: hal.executable public @large_dot_general_dispatch_0 -// CHECKP: hal.executable.variant public @cuda -// CHECKP-NOT: llvm.store -// CHECKP-COUNT-2: nvvm.cp.async.shared.global {{.*}}, {{.*}}, 16 -// CHECKP: nvvm.cp.async.commit.group -// CHECKP-COUNT-2: nvvm.cp.async.shared.global {{.*}}, {{.*}}, 16 -// CHECKP: nvvm.cp.async.commit.group -// CHECKP-COUNT-2: nvvm.cp.async.shared.global {{.*}}, {{.*}}, 16 -// CHECKP: nvvm.cp.async.commit.group -// CHECKP-COUNT-2: nvvm.cp.async.shared.global {{.*}}, {{.*}}, 16 -// CHECKP: nvvm.cp.async.commit.group -// CHECKP: llvm.br -// CHECKP: nvvm.cp.async.wait.group 3 -// CHECKP-COUNT-4: nvvm.wmma.load{{.*}} : (!llvm.ptr<f32, 3>) -> !llvm.struct<(i32, i32, i32, i32) -// CHECKP-COUNT-2: nvvm.wmma.mma -// CHECKP-COUNT-2: nvvm.cp.async.shared.global {{.*}}, {{.*}}, 16 -// CHECKP: nvvm.cp.async.commit.group -// CHECKP: llvm.br -// CHECKP: nvvm.cp.async.wait.group 3 -// CHECKP-COUNT-4: nvvm.wmma.load{{.*}} : (!llvm.ptr<f32, 3>) -> !llvm.struct<(i32, i32, i32, i32) -// CHECKP-COUNT-2: nvvm.wmma.mma -// CHECKP: nvvm.cp.async.wait.group 2 -// CHECKP-COUNT-4: nvvm.wmma.load{{.*}} : (!llvm.ptr<f32, 3>) -> !llvm.struct<(i32, i32, i32, i32) -// CHECKP-COUNT-2: nvvm.wmma.mma -// CHECKP: nvvm.cp.async.wait.group 1 -// CHECKP-COUNT-4: nvvm.wmma.load{{.*}} : (!llvm.ptr<f32, 3>) -> !llvm.struct<(i32, i32, i32, i32) -// CHECKP-COUNT-2: nvvm.wmma.mma -// CHECKP: nvvm.cp.async.wait.group 0 -// CHECKP-COUNT-4: nvvm.wmma.load{{.*}} : (!llvm.ptr<f32, 3>) -> !llvm.struct<(i32, i32, i32, i32) -// CHECKP-COUNT-2: nvvm.wmma.mma -// CHECKP-COUNT-1: nvvm.wmma.store {{.*}} : !llvm.ptr<f32>, f32, f32, f32, f32, f32, f32, f32, f32 - // ----- #executable_layout = #hal.executable.layout<push_constants = 0, sets = [ @@ -722,12 +670,15 @@ // CHECK-COUNT-2: nvvm.cp.async.shared.global {{.*}}, {{.*}}, 16 // CHECK: nvvm.cp.async.commit.group // CHECK: llvm.br -// CHECK: nvvm.cp.async.wait.group 0 +// CHECK: nvvm.cp.async.wait.group 3 // CHECK-COUNT-4: nvvm.wmma.load{{.*}} : (!llvm.ptr<f32, 3>) -> !llvm.struct<(i32, i32, i32, i32) // CHECK-COUNT-2: nvvm.wmma.mma -// CHECK-COUNT-2: nvvm.cp.async.shared.global {{.*}}, {{.*}}, 16 -// CHECK: nvvm.cp.async.commit.group -// CHECK: llvm.br +// CHECK: nvvm.cp.async.wait.group 2 +// CHECK-COUNT-4: nvvm.wmma.load{{.*}} : (!llvm.ptr<f32, 3>) -> !llvm.struct<(i32, i32, i32, i32) +// CHECK-COUNT-2: nvvm.wmma.mma +// CHECK: nvvm.cp.async.wait.group 1 +// CHECK-COUNT-4: nvvm.wmma.load{{.*}} : (!llvm.ptr<f32, 3>) -> !llvm.struct<(i32, i32, i32, i32) +// CHECK-COUNT-2: nvvm.wmma.mma // CHECK: nvvm.cp.async.wait.group 0 // CHECK-COUNT-4: nvvm.wmma.load{{.*}} : (!llvm.ptr<f32, 3>) -> !llvm.struct<(i32, i32, i32, i32) // CHECK-COUNT-2: nvvm.wmma.mma
diff --git a/compiler/src/iree/compiler/Codegen/Passes.h b/compiler/src/iree/compiler/Codegen/Passes.h index 9799258..928e0c6 100644 --- a/compiler/src/iree/compiler/Codegen/Passes.h +++ b/compiler/src/iree/compiler/Codegen/Passes.h
@@ -327,7 +327,8 @@ Operation *op, IREE::Codegen::LoweringConfigAttr loweringConfig, IREE::Codegen::TranslationInfoAttr translationInfo, ArrayRef<int64_t> workgroupSize); -void addGPUMatmulTensorCorePassPipeline(OpPassManager &pm); +void addGPUMatmulTensorCorePassPipeline(OpPassManager &pm, + unsigned pipelineDepth); /// Simple lowering only distributute linalg ops on blocks and threads. This /// will result in scalar operations. Expects pass manager to be a module-level
diff --git a/iree/test/e2e/matmul/generate_e2e_matmul_tests.py b/iree/test/e2e/matmul/generate_e2e_matmul_tests.py index 62c7721..cb8b8b6 100644 --- a/iree/test/e2e/matmul/generate_e2e_matmul_tests.py +++ b/iree/test/e2e/matmul/generate_e2e_matmul_tests.py
@@ -77,6 +77,7 @@ # Translation Info dispatch_lowering_pass_pipeline: str workload_per_wg: typing.List[int] + software_pipeline_depth: int # Compilation info workgroup_size: typing.List[int] @@ -189,7 +190,8 @@ workload_per_wg=[ a for a in reversed(tile_workgroup_size_pair.tile_size[0:2]) ], - workgroup_size=tile_workgroup_size_pair.workgroup_size)) + workgroup_size=tile_workgroup_size_pair.workgroup_size, + software_pipeline_depth=4)) return compilation_infos @@ -355,7 +357,8 @@ compilation_info_string = ( f"#compilation{generate_function.compilation_index} = #iree_codegen.compilation_info<\n" f" lowering_config = <tile_sizes = [{compilation_info.tile_sizes}]>,\n" - f" translation_info = <{compilation_info.dispatch_lowering_pass_pipeline}>,\n" + f" translation_info = <{compilation_info.dispatch_lowering_pass_pipeline}\n" + f" pipeline_depth = {compilation_info.software_pipeline_depth}>,\n" f" workgroup_size = {compilation_info.workgroup_size_str()}>\n") compilation_info_attr = f"{{compilation_info = #compilation{generate_function.compilation_index}}} " func_definition = func_definition + compilation_info_string