Add software pipeline depth to lowering config (#9114)
diff --git a/compiler/src/iree/compiler/Codegen/Dialect/LoweringConfig.cpp b/compiler/src/iree/compiler/Codegen/Dialect/LoweringConfig.cpp
index df9ead9..d9a613d 100644
--- a/compiler/src/iree/compiler/Codegen/Dialect/LoweringConfig.cpp
+++ b/compiler/src/iree/compiler/Codegen/Dialect/LoweringConfig.cpp
@@ -75,12 +75,13 @@
TranslationInfoAttr TranslationInfoAttr::get(
MLIRContext *context, DispatchLoweringPassPipeline passPipeline,
- ArrayRef<int64_t> workloadPerWorkgroup) {
+ ArrayRef<int64_t> workloadPerWorkgroup, unsigned softwarePipelineDepth) {
auto pipelineAttr =
DispatchLoweringPassPipelineAttr::get(context, passPipeline);
ArrayAttr workloadPerWorkgroupAttr =
getI64IntegerArrayAttr(context, workloadPerWorkgroup);
- return get(context, pipelineAttr, workloadPerWorkgroupAttr);
+ return get(context, pipelineAttr, workloadPerWorkgroupAttr,
+ softwarePipelineDepth);
}
DispatchLoweringPassPipeline
@@ -95,7 +96,7 @@
LogicalResult TranslationInfoAttr::verify(
function_ref<InFlightDiagnostic()> emitError,
IREE::Codegen::DispatchLoweringPassPipelineAttr passPipeline,
- ArrayAttr workloadPerWorkgroup) {
+ ArrayAttr workloadPerWorkgroup, unsigned softwarePipelineDepth) {
if (!passPipeline) {
return emitError() << "missing pass pipeline specification";
}
@@ -245,7 +246,8 @@
}
if (failed(TranslationInfoAttr::verify(
emitError, translationInfo.getPassPipeline(),
- translationInfo.getWorkloadPerWorkgroup()))) {
+ translationInfo.getWorkloadPerWorkgroup(),
+ translationInfo.getSoftwarePipelineDepth()))) {
return failure();
}
if (workgroupSize) {
diff --git a/compiler/src/iree/compiler/Codegen/Dialect/LoweringConfig.h b/compiler/src/iree/compiler/Codegen/Dialect/LoweringConfig.h
index 3a6508e..af89154 100644
--- a/compiler/src/iree/compiler/Codegen/Dialect/LoweringConfig.h
+++ b/compiler/src/iree/compiler/Codegen/Dialect/LoweringConfig.h
@@ -84,7 +84,8 @@
inline void setTranslationInfo(
func::FuncOp entryPointFn,
IREE::Codegen::DispatchLoweringPassPipeline passPipeline,
- ArrayRef<int64_t> workloadPerWorkgroup, ArrayRef<int64_t> workgroupSize) {
+ ArrayRef<int64_t> workloadPerWorkgroup, ArrayRef<int64_t> workgroupSize,
+ unsigned softwarePipelineDepth = 0) {
auto entryPointOp = getEntryPoint(entryPointFn);
MLIRContext *context = entryPointFn.getContext();
auto translationInfo = IREE::Codegen::TranslationInfoAttr::get(
@@ -117,12 +118,12 @@
inline LogicalResult setOpConfigAndEntryPointFnTranslation(
func::FuncOp entryPointFn, Operation *op, TileSizesListTypeRef tileSizes,
IREE::Codegen::DispatchLoweringPassPipeline passPipeline,
- ArrayRef<int64_t> workgroupSize = {}) {
+ ArrayRef<int64_t> workgroupSize = {}, unsigned softwarePipelineDepth = 0) {
MLIRContext *context = entryPointFn.getContext();
auto config = IREE::Codegen::LoweringConfigAttr::get(context, tileSizes);
setLoweringConfig(op, config);
auto translationInfo = IREE::Codegen::TranslationInfoAttr::get(
- entryPointFn->getContext(), passPipeline);
+ entryPointFn->getContext(), passPipeline, {}, softwarePipelineDepth);
setTranslationInfo(entryPointFn, translationInfo, workgroupSize);
return success();
}
diff --git a/compiler/src/iree/compiler/Codegen/Dialect/LoweringConfig.td b/compiler/src/iree/compiler/Codegen/Dialect/LoweringConfig.td
index 9ae145c..e5b117b 100644
--- a/compiler/src/iree/compiler/Codegen/Dialect/LoweringConfig.td
+++ b/compiler/src/iree/compiler/Codegen/Dialect/LoweringConfig.td
@@ -104,18 +104,22 @@
}];
let assemblyFormat = [{
- `<` `` $passPipeline (`workload_per_wg` `=` $workloadPerWorkgroup^)? `>`
+ `<` `` $passPipeline (`workload_per_wg` `=` $workloadPerWorkgroup^)?
+ (`pipeline_depth` `=` $softwarePipelineDepth^)? `>`
}];
let parameters = (ins
AttrParameter<"IREE::Codegen::DispatchLoweringPassPipelineAttr",
"Name of the pipeline to be invoked on the translation unit.">:$passPipeline,
DefaultValuedParameter<"ArrayAttr", "ArrayAttr::get($_ctx, {})",
- "The workload mapped to a single workgroup">:$workloadPerWorkgroup
+ "The workload mapped to a single workgroup">:$workloadPerWorkgroup,
+ DefaultValuedParameter<"unsigned", "1",
+ "The software pipeline depth to be used">:$softwarePipelineDepth
);
let builders = [
AttrBuilder<(ins "DispatchLoweringPassPipeline":$passPipeline,
- CArg<"ArrayRef<int64_t>", "{}">:$workloadPerWorkgroup)>
+ CArg<"ArrayRef<int64_t>", "{}">:$workloadPerWorkgroup,
+ CArg<"unsigned", "0">:$softwarePipelineDepth)>
];
let extraClassDeclaration = [{
// Returns the lowering pass pipeline set.
diff --git a/compiler/src/iree/compiler/Codegen/LLVMGPU/KernelConfig.cpp b/compiler/src/iree/compiler/Codegen/LLVMGPU/KernelConfig.cpp
index c407baf..c6964d7 100644
--- a/compiler/src/iree/compiler/Codegen/LLVMGPU/KernelConfig.cpp
+++ b/compiler/src/iree/compiler/Codegen/LLVMGPU/KernelConfig.cpp
@@ -29,6 +29,11 @@
std::array<int64_t, 3> tileSize;
std::array<int64_t, 3> workgroupSize;
};
+
+// Software pipeline depths
+constexpr unsigned softwarePipelineDepthTensorCore = 4;
+// Simt codegen does not do software pipelining.
+constexpr unsigned softwarePipelineDepthSimt = 0;
} // namespace
/// Return the best combination of tile size and wg size. It will then be used to
@@ -118,6 +123,7 @@
auto setMatmulConfig =
[&entryPoint, &op](int64_t tileX, int64_t tileY, int64_t tileK,
llvm::ArrayRef<int64_t> workgroupSize,
+ unsigned softwarePipelineDepth,
IREE::Codegen::DispatchLoweringPassPipeline pipeline) {
TileSizesListType tileSizes;
unsigned numParallelLoops = op.getNumParallelLoops();
@@ -140,7 +146,8 @@
tileSizes.emplace_back(
std::move(workgroupTileSizes)); // Workgroup level.
return setOpConfigAndEntryPointFnTranslation(entryPoint, op, tileSizes,
- pipeline, workgroupSize);
+ pipeline, workgroupSize,
+ softwarePipelineDepth);
};
// Infer the MxN size of the matmul based on operands and indexing maps.
auto lhsShape =
@@ -196,17 +203,19 @@
if (sizeK % config.tileSize[2] == 0 &&
sizeN % config.tileSize[1] == 0 &&
sizeM % config.tileSize[0] == 0) {
- return setMatmulConfig(config.tileSize[0], config.tileSize[1],
- config.tileSize[2], config.workgroupSize,
- IREE::Codegen::DispatchLoweringPassPipeline::
- LLVMGPUMatmulTensorCore);
+ return setMatmulConfig(
+ config.tileSize[0], config.tileSize[1], config.tileSize[2],
+ config.workgroupSize,
+ sizeK == config.tileSize[2] ? 1 : softwarePipelineDepthTensorCore,
+ IREE::Codegen::DispatchLoweringPassPipeline::
+ LLVMGPUMatmulTensorCore);
}
}
}
// Special case for very small matrices.
if (sizeM * sizeN <= cudaWarpSize) {
return setMatmulConfig(
- sizeN, sizeM, 4, {sizeM, sizeN, 1},
+ sizeN, sizeM, 4, {sizeM, sizeN, 1}, softwarePipelineDepthSimt,
IREE::Codegen::DispatchLoweringPassPipeline::LLVMGPUMatmulSimt);
}
// simt matmul case
@@ -219,7 +228,7 @@
if (sizeN % config.tileSize[1] == 0 && sizeM % config.tileSize[0] == 0) {
return setMatmulConfig(
config.tileSize[0], config.tileSize[1], config.tileSize[2],
- config.workgroupSize,
+ config.workgroupSize, softwarePipelineDepthSimt,
IREE::Codegen::DispatchLoweringPassPipeline::LLVMGPUMatmulSimt);
}
}
@@ -230,7 +239,7 @@
int64_t tileK = 4;
SmallVector<int64_t, 3> workgroupSize = {2 * cudaWarpSize, 1, 1};
return setMatmulConfig(
- tileX, tileY, tileK, workgroupSize,
+ tileX, tileY, tileK, workgroupSize, softwarePipelineDepthSimt,
IREE::Codegen::DispatchLoweringPassPipeline::LLVMGPUMatmulSimt);
}
diff --git a/compiler/src/iree/compiler/Codegen/LLVMGPU/LLVMGPULowerExecutableTarget.cpp b/compiler/src/iree/compiler/Codegen/LLVMGPU/LLVMGPULowerExecutableTarget.cpp
index eb89819..4f24c26 100644
--- a/compiler/src/iree/compiler/Codegen/LLVMGPU/LLVMGPULowerExecutableTarget.cpp
+++ b/compiler/src/iree/compiler/Codegen/LLVMGPU/LLVMGPULowerExecutableTarget.cpp
@@ -155,7 +155,9 @@
addGPUMatmulSimtPassPipeline(nestedModulePM);
break;
case IREE::Codegen::DispatchLoweringPassPipeline::LLVMGPUMatmulTensorCore:
- addGPUMatmulTensorCorePassPipeline(nestedModulePM);
+ addGPUMatmulTensorCorePassPipeline(
+ nestedModulePM,
+ translationInfo.getValue().getSoftwarePipelineDepth());
break;
default:
variantOp.emitOpError("Unsupported pipeline on GPU target.");
diff --git a/compiler/src/iree/compiler/Codegen/LLVMGPU/Passes.cpp b/compiler/src/iree/compiler/Codegen/LLVMGPU/Passes.cpp
index db61284..752d5cf 100644
--- a/compiler/src/iree/compiler/Codegen/LLVMGPU/Passes.cpp
+++ b/compiler/src/iree/compiler/Codegen/LLVMGPU/Passes.cpp
@@ -24,11 +24,6 @@
namespace mlir {
namespace iree_compiler {
-// TODO(thomasraoux): Add a new optional attribute to translate info.
-static llvm::cl::opt<unsigned> pipelineDepth("iree-codegen-cuda-pipeline-depth",
- llvm::cl::desc("Pipeline depth"),
- llvm::cl::init(4));
-
static llvm::cl::opt<unsigned> logSwizzleTile(
"iree-codegen-log-swizzle-tile", llvm::cl::desc("log swizzle tile value"),
llvm::cl::init(0));
@@ -109,7 +104,8 @@
pm.addNestedPass<func::FuncOp>(createGPUPipeliningPass());
}
-void addGPUMatmulTensorCorePassPipeline(OpPassManager &pm) {
+void addGPUMatmulTensorCorePassPipeline(OpPassManager &pm,
+ unsigned pipelineDepth) {
tileAndBufferize(pm);
// Distribute linalg onto warps within the workgroup.
diff --git a/compiler/src/iree/compiler/Codegen/LLVMGPU/Verifiers.cpp b/compiler/src/iree/compiler/Codegen/LLVMGPU/Verifiers.cpp
index e832517..da06e2e 100644
--- a/compiler/src/iree/compiler/Codegen/LLVMGPU/Verifiers.cpp
+++ b/compiler/src/iree/compiler/Codegen/LLVMGPU/Verifiers.cpp
@@ -96,6 +96,7 @@
ArrayRef<int64_t> workgroupSize) {
auto pipeline =
IREE::Codegen::DispatchLoweringPassPipeline::LLVMGPUMatmulTensorCore;
+ unsigned softwarePipelineDepth = translationInfo.getSoftwarePipelineDepth();
StringRef pipelineName = stringifyEnum(pipeline);
if (workgroupSize.empty()) {
return op->emitOpError("expected workgroup size for GPU pipelines");
@@ -153,7 +154,15 @@
return op->emitOpError("workgroup size is not 32 aligned for ")
<< pipelineName << ", got " << workgroupSize[0];
}
-
+ if (softwarePipelineDepth > 1 && firstLevelTileSizes[2] == lhsShape[1]) {
+ return op->emitError(
"Software pipelining is not supported when first level K tile "
"size is same as matrix reduction size.\nThis dispatch has\nK "
"tile: ")
<< firstLevelTileSizes[2]
<< "\nMatrix reduction size: " << lhsShape[1]
<< "\nPipeline depth: " << softwarePipelineDepth;
+ }
// Verify the workgroup.z component should always be 1
if (workgroupSize[2] != 1) {
return op->emitOpError("expected workgroup z component to be 1 for ")
diff --git a/compiler/src/iree/compiler/Codegen/LLVMGPU/test/nvvm_pipeline_test.mlir b/compiler/src/iree/compiler/Codegen/LLVMGPU/test/nvvm_pipeline_test.mlir
index 7c62e49..5698b63 100644
--- a/compiler/src/iree/compiler/Codegen/LLVMGPU/test/nvvm_pipeline_test.mlir
+++ b/compiler/src/iree/compiler/Codegen/LLVMGPU/test/nvvm_pipeline_test.mlir
@@ -1,5 +1,4 @@
-// RUN: iree-opt --split-input-file --pass-pipeline='hal.executable(hal.executable.variant(iree-codegen-linalg-to-nvvm-pipeline))' --iree-codegen-cuda-pipeline-depth=1 %s | FileCheck %s
-// RUN: iree-opt --split-input-file --pass-pipeline='hal.executable(hal.executable.variant(iree-codegen-linalg-to-nvvm-pipeline))' --iree-codegen-cuda-pipeline-depth=4 %s | FileCheck %s --check-prefix=CHECKP
+// RUN: iree-opt --split-input-file --pass-pipeline='hal.executable(hal.executable.variant(iree-codegen-linalg-to-nvvm-pipeline))' %s | FileCheck %s
// Verify that a simple element wise op gets lowered successfully all the way to
// nvvm/llvm dialect.
@@ -421,58 +420,40 @@
}
}
+// case with larger pipeline depth
// CHECK-LABEL: hal.executable public @mma_fused
// CHECK: hal.executable.variant public @cuda
// CHECK-NOT: llvm.store
// CHECK-COUNT-2: nvvm.cp.async.shared.global {{.*}}, {{.*}}, 16
// CHECK: nvvm.cp.async.commit.group
+// CHECK-COUNT-2: nvvm.cp.async.shared.global {{.*}}, {{.*}}, 16
+// CHECK: nvvm.cp.async.commit.group
+// CHECK-COUNT-2: nvvm.cp.async.shared.global {{.*}}, {{.*}}, 16
+// CHECK: nvvm.cp.async.commit.group
+// CHECK-COUNT-2: nvvm.cp.async.shared.global {{.*}}, {{.*}}, 16
+// CHECK: nvvm.cp.async.commit.group
// CHECK: llvm.br
-// CHECK: nvvm.cp.async.wait.group 0
+// CHECK: nvvm.cp.async.wait.group 3
// CHECK-COUNT-4: nvvm.wmma.load{{.*}} : (!llvm.ptr<f32, 3>) -> !llvm.struct<(i32, i32, i32, i32)
// CHECK-COUNT-2: nvvm.wmma.mma
// CHECK-COUNT-2: nvvm.cp.async.shared.global {{.*}}, {{.*}}, 16
// CHECK: nvvm.cp.async.commit.group
// CHECK: llvm.br
+// CHECK: nvvm.cp.async.wait.group 3
+// CHECK-COUNT-4: nvvm.wmma.load{{.*}} : (!llvm.ptr<f32, 3>) -> !llvm.struct<(i32, i32, i32, i32)
+// CHECK-COUNT-2: nvvm.wmma.mma
+// CHECK: nvvm.cp.async.wait.group 2
+// CHECK-COUNT-4: nvvm.wmma.load{{.*}} : (!llvm.ptr<f32, 3>) -> !llvm.struct<(i32, i32, i32, i32)
+// CHECK-COUNT-2: nvvm.wmma.mma
+// CHECK: nvvm.cp.async.wait.group 1
+// CHECK-COUNT-4: nvvm.wmma.load{{.*}} : (!llvm.ptr<f32, 3>) -> !llvm.struct<(i32, i32, i32, i32)
+// CHECK-COUNT-2: nvvm.wmma.mma
// CHECK: nvvm.cp.async.wait.group 0
// CHECK-COUNT-4: nvvm.wmma.load{{.*}} : (!llvm.ptr<f32, 3>) -> !llvm.struct<(i32, i32, i32, i32)
// CHECK-COUNT-2: nvvm.wmma.mma
// CHECK-COUNT-8: llvm.fadd
// CHECK-COUNT-1: nvvm.wmma.store {{.*}} : !llvm.ptr<f32>, f32, f32, f32, f32, f32, f32, f32, f32
-// case with larger pipeline depth
-// CHECKP-LABEL: hal.executable public @mma_fused
-// CHECKP: hal.executable.variant public @cuda
-// CHECKP-NOT: llvm.store
-// CHECKP-COUNT-2: nvvm.cp.async.shared.global {{.*}}, {{.*}}, 16
-// CHECKP: nvvm.cp.async.commit.group
-// CHECKP-COUNT-2: nvvm.cp.async.shared.global {{.*}}, {{.*}}, 16
-// CHECKP: nvvm.cp.async.commit.group
-// CHECKP-COUNT-2: nvvm.cp.async.shared.global {{.*}}, {{.*}}, 16
-// CHECKP: nvvm.cp.async.commit.group
-// CHECKP-COUNT-2: nvvm.cp.async.shared.global {{.*}}, {{.*}}, 16
-// CHECKP: nvvm.cp.async.commit.group
-// CHECKP: llvm.br
-// CHECKP: nvvm.cp.async.wait.group 3
-// CHECKP-COUNT-4: nvvm.wmma.load{{.*}} : (!llvm.ptr<f32, 3>) -> !llvm.struct<(i32, i32, i32, i32)
-// CHECKP-COUNT-2: nvvm.wmma.mma
-// CHECKP-COUNT-2: nvvm.cp.async.shared.global {{.*}}, {{.*}}, 16
-// CHECKP: nvvm.cp.async.commit.group
-// CHECKP: llvm.br
-// CHECKP: nvvm.cp.async.wait.group 3
-// CHECKP-COUNT-4: nvvm.wmma.load{{.*}} : (!llvm.ptr<f32, 3>) -> !llvm.struct<(i32, i32, i32, i32)
-// CHECKP-COUNT-2: nvvm.wmma.mma
-// CHECKP: nvvm.cp.async.wait.group 2
-// CHECKP-COUNT-4: nvvm.wmma.load{{.*}} : (!llvm.ptr<f32, 3>) -> !llvm.struct<(i32, i32, i32, i32)
-// CHECKP-COUNT-2: nvvm.wmma.mma
-// CHECKP: nvvm.cp.async.wait.group 1
-// CHECKP-COUNT-4: nvvm.wmma.load{{.*}} : (!llvm.ptr<f32, 3>) -> !llvm.struct<(i32, i32, i32, i32)
-// CHECKP-COUNT-2: nvvm.wmma.mma
-// CHECKP: nvvm.cp.async.wait.group 0
-// CHECKP-COUNT-4: nvvm.wmma.load{{.*}} : (!llvm.ptr<f32, 3>) -> !llvm.struct<(i32, i32, i32, i32)
-// CHECKP-COUNT-2: nvvm.wmma.mma
-// CHECKP-COUNT-8: llvm.fadd
-// CHECKP-COUNT-1: nvvm.wmma.store {{.*}} : !llvm.ptr<f32>, f32, f32, f32, f32, f32, f32, f32, f32
-
// -----
#executable_layout = #hal.executable.layout<push_constants = 0, sets = [
@@ -521,57 +502,40 @@
}
}
+// case with larger pipeline depth
// CHECK-LABEL: hal.executable public @mma_fused_fp16
// CHECK: hal.executable.variant public @cuda
// CHECK-NOT: llvm.store
// CHECK-COUNT-2: nvvm.cp.async.shared.global {{.*}}, {{.*}}, 16
// CHECK: nvvm.cp.async.commit.group
+// CHECK-COUNT-2: nvvm.cp.async.shared.global {{.*}}, {{.*}}, 16
+// CHECK: nvvm.cp.async.commit.group
+// CHECK-COUNT-2: nvvm.cp.async.shared.global {{.*}}, {{.*}}, 16
+// CHECK: nvvm.cp.async.commit.group
+// CHECK-COUNT-2: nvvm.cp.async.shared.global {{.*}}, {{.*}}, 16
+// CHECK: nvvm.cp.async.commit.group
// CHECK: llvm.br
-// CHECK: nvvm.cp.async.wait.group 0
+// CHECK: nvvm.cp.async.wait.group 3
// CHECK-COUNT-2: nvvm.wmma.load{{.*}} : (!llvm.ptr<f16, 3>) -> !llvm.struct<(vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>)
// CHECK-COUNT-1: nvvm.wmma.mma
// CHECK-COUNT-2: nvvm.cp.async.shared.global {{.*}}, {{.*}}, 16
// CHECK: nvvm.cp.async.commit.group
// CHECK: llvm.br
+// CHECK: nvvm.cp.async.wait.group 3
+// CHECK-COUNT-2: nvvm.wmma.load{{.*}} : (!llvm.ptr<f16, 3>) -> !llvm.struct<(vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>)
+// CHECK-COUNT-1: nvvm.wmma.mma
+// CHECK: nvvm.cp.async.wait.group 2
+// CHECK-COUNT-2: nvvm.wmma.load{{.*}} : (!llvm.ptr<f16, 3>) -> !llvm.struct<(vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>)
+// CHECK-COUNT-1: nvvm.wmma.mma
+// CHECK: nvvm.cp.async.wait.group 1
+// CHECK-COUNT-2: nvvm.wmma.load{{.*}} : (!llvm.ptr<f16, 3>) -> !llvm.struct<(vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>)
+// CHECK-COUNT-1: nvvm.wmma.mma
+// CHECK: nvvm.cp.async.wait.group 0
// CHECK-COUNT-2: nvvm.wmma.load{{.*}} : (!llvm.ptr<f16, 3>) -> !llvm.struct<(vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>)
// CHECK-COUNT-1: nvvm.wmma.mma
// CHECK-COUNT-4: llvm.fadd
// CHECK-COUNT-1: nvvm.wmma.store {{.*}} : !llvm.ptr<f16>, vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>
-// case with larger pipeline depth
-// CHECKP-LABEL: hal.executable public @mma_fused_fp16
-// CHECKP: hal.executable.variant public @cuda
-// CHECKP-NOT: llvm.store
-// CHECKP-COUNT-2: nvvm.cp.async.shared.global {{.*}}, {{.*}}, 16
-// CHECKP: nvvm.cp.async.commit.group
-// CHECKP-COUNT-2: nvvm.cp.async.shared.global {{.*}}, {{.*}}, 16
-// CHECKP: nvvm.cp.async.commit.group
-// CHECKP-COUNT-2: nvvm.cp.async.shared.global {{.*}}, {{.*}}, 16
-// CHECKP: nvvm.cp.async.commit.group
-// CHECKP-COUNT-2: nvvm.cp.async.shared.global {{.*}}, {{.*}}, 16
-// CHECKP: nvvm.cp.async.commit.group
-// CHECKP: llvm.br
-// CHECKP: nvvm.cp.async.wait.group 3
-// CHECKP-COUNT-2: nvvm.wmma.load{{.*}} : (!llvm.ptr<f16, 3>) -> !llvm.struct<(vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>)
-// CHECKP-COUNT-1: nvvm.wmma.mma
-// CHECKP-COUNT-2: nvvm.cp.async.shared.global {{.*}}, {{.*}}, 16
-// CHECKP: nvvm.cp.async.commit.group
-// CHECKP: llvm.br
-// CHECKP: nvvm.cp.async.wait.group 3
-// CHECKP-COUNT-2: nvvm.wmma.load{{.*}} : (!llvm.ptr<f16, 3>) -> !llvm.struct<(vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>)
-// CHECKP-COUNT-1: nvvm.wmma.mma
-// CHECKP: nvvm.cp.async.wait.group 2
-// CHECKP-COUNT-2: nvvm.wmma.load{{.*}} : (!llvm.ptr<f16, 3>) -> !llvm.struct<(vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>)
-// CHECKP-COUNT-1: nvvm.wmma.mma
-// CHECKP: nvvm.cp.async.wait.group 1
-// CHECKP-COUNT-2: nvvm.wmma.load{{.*}} : (!llvm.ptr<f16, 3>) -> !llvm.struct<(vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>)
-// CHECKP-COUNT-1: nvvm.wmma.mma
-// CHECKP: nvvm.cp.async.wait.group 0
-// CHECKP-COUNT-2: nvvm.wmma.load{{.*}} : (!llvm.ptr<f16, 3>) -> !llvm.struct<(vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>)
-// CHECKP-COUNT-1: nvvm.wmma.mma
-// CHECKP-COUNT-4: llvm.fadd
-// CHECKP-COUNT-1: nvvm.wmma.store {{.*}} : !llvm.ptr<f16>, vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>
-
// -----
#executable_layout = #hal.executable.layout<push_constants = 0, sets = [
@@ -620,56 +584,40 @@
}
}
}
+
+// case with larger pipeline depth
// CHECK-LABEL: hal.executable public @large_dot_general_dispatch_0
// CHECK: hal.executable.variant public @cuda
// CHECK-NOT: llvm.store
// CHECK-COUNT-2: nvvm.cp.async.shared.global {{.*}}, {{.*}}, 16
// CHECK: nvvm.cp.async.commit.group
+// CHECK-COUNT-2: nvvm.cp.async.shared.global {{.*}}, {{.*}}, 16
+// CHECK: nvvm.cp.async.commit.group
+// CHECK-COUNT-2: nvvm.cp.async.shared.global {{.*}}, {{.*}}, 16
+// CHECK: nvvm.cp.async.commit.group
+// CHECK-COUNT-2: nvvm.cp.async.shared.global {{.*}}, {{.*}}, 16
+// CHECK: nvvm.cp.async.commit.group
// CHECK: llvm.br
-// CHECK: nvvm.cp.async.wait.group 0
+// CHECK: nvvm.cp.async.wait.group 3
// CHECK-COUNT-4: nvvm.wmma.load{{.*}} : (!llvm.ptr<f32, 3>) -> !llvm.struct<(i32, i32, i32, i32)
// CHECK-COUNT-2: nvvm.wmma.mma
// CHECK-COUNT-2: nvvm.cp.async.shared.global {{.*}}, {{.*}}, 16
// CHECK: nvvm.cp.async.commit.group
// CHECK: llvm.br
+// CHECK: nvvm.cp.async.wait.group 3
+// CHECK-COUNT-4: nvvm.wmma.load{{.*}} : (!llvm.ptr<f32, 3>) -> !llvm.struct<(i32, i32, i32, i32)
+// CHECK-COUNT-2: nvvm.wmma.mma
+// CHECK: nvvm.cp.async.wait.group 2
+// CHECK-COUNT-4: nvvm.wmma.load{{.*}} : (!llvm.ptr<f32, 3>) -> !llvm.struct<(i32, i32, i32, i32)
+// CHECK-COUNT-2: nvvm.wmma.mma
+// CHECK: nvvm.cp.async.wait.group 1
+// CHECK-COUNT-4: nvvm.wmma.load{{.*}} : (!llvm.ptr<f32, 3>) -> !llvm.struct<(i32, i32, i32, i32)
+// CHECK-COUNT-2: nvvm.wmma.mma
// CHECK: nvvm.cp.async.wait.group 0
// CHECK-COUNT-4: nvvm.wmma.load{{.*}} : (!llvm.ptr<f32, 3>) -> !llvm.struct<(i32, i32, i32, i32)
// CHECK-COUNT-2: nvvm.wmma.mma
// CHECK-COUNT-1: nvvm.wmma.store {{.*}} : !llvm.ptr<f32>, f32, f32, f32, f32, f32, f32, f32, f32
-// case with larger pipeline depth
-// CHECKP-LABEL: hal.executable public @large_dot_general_dispatch_0
-// CHECKP: hal.executable.variant public @cuda
-// CHECKP-NOT: llvm.store
-// CHECKP-COUNT-2: nvvm.cp.async.shared.global {{.*}}, {{.*}}, 16
-// CHECKP: nvvm.cp.async.commit.group
-// CHECKP-COUNT-2: nvvm.cp.async.shared.global {{.*}}, {{.*}}, 16
-// CHECKP: nvvm.cp.async.commit.group
-// CHECKP-COUNT-2: nvvm.cp.async.shared.global {{.*}}, {{.*}}, 16
-// CHECKP: nvvm.cp.async.commit.group
-// CHECKP-COUNT-2: nvvm.cp.async.shared.global {{.*}}, {{.*}}, 16
-// CHECKP: nvvm.cp.async.commit.group
-// CHECKP: llvm.br
-// CHECKP: nvvm.cp.async.wait.group 3
-// CHECKP-COUNT-4: nvvm.wmma.load{{.*}} : (!llvm.ptr<f32, 3>) -> !llvm.struct<(i32, i32, i32, i32)
-// CHECKP-COUNT-2: nvvm.wmma.mma
-// CHECKP-COUNT-2: nvvm.cp.async.shared.global {{.*}}, {{.*}}, 16
-// CHECKP: nvvm.cp.async.commit.group
-// CHECKP: llvm.br
-// CHECKP: nvvm.cp.async.wait.group 3
-// CHECKP-COUNT-4: nvvm.wmma.load{{.*}} : (!llvm.ptr<f32, 3>) -> !llvm.struct<(i32, i32, i32, i32)
-// CHECKP-COUNT-2: nvvm.wmma.mma
-// CHECKP: nvvm.cp.async.wait.group 2
-// CHECKP-COUNT-4: nvvm.wmma.load{{.*}} : (!llvm.ptr<f32, 3>) -> !llvm.struct<(i32, i32, i32, i32)
-// CHECKP-COUNT-2: nvvm.wmma.mma
-// CHECKP: nvvm.cp.async.wait.group 1
-// CHECKP-COUNT-4: nvvm.wmma.load{{.*}} : (!llvm.ptr<f32, 3>) -> !llvm.struct<(i32, i32, i32, i32)
-// CHECKP-COUNT-2: nvvm.wmma.mma
-// CHECKP: nvvm.cp.async.wait.group 0
-// CHECKP-COUNT-4: nvvm.wmma.load{{.*}} : (!llvm.ptr<f32, 3>) -> !llvm.struct<(i32, i32, i32, i32)
-// CHECKP-COUNT-2: nvvm.wmma.mma
-// CHECKP-COUNT-1: nvvm.wmma.store {{.*}} : !llvm.ptr<f32>, f32, f32, f32, f32, f32, f32, f32, f32
-
// -----
#executable_layout = #hal.executable.layout<push_constants = 0, sets = [
@@ -722,12 +670,15 @@
// CHECK-COUNT-2: nvvm.cp.async.shared.global {{.*}}, {{.*}}, 16
// CHECK: nvvm.cp.async.commit.group
// CHECK: llvm.br
-// CHECK: nvvm.cp.async.wait.group 0
+// CHECK: nvvm.cp.async.wait.group 3
// CHECK-COUNT-4: nvvm.wmma.load{{.*}} : (!llvm.ptr<f32, 3>) -> !llvm.struct<(i32, i32, i32, i32)
// CHECK-COUNT-2: nvvm.wmma.mma
-// CHECK-COUNT-2: nvvm.cp.async.shared.global {{.*}}, {{.*}}, 16
-// CHECK: nvvm.cp.async.commit.group
-// CHECK: llvm.br
+// CHECK: nvvm.cp.async.wait.group 2
+// CHECK-COUNT-4: nvvm.wmma.load{{.*}} : (!llvm.ptr<f32, 3>) -> !llvm.struct<(i32, i32, i32, i32)
+// CHECK-COUNT-2: nvvm.wmma.mma
+// CHECK: nvvm.cp.async.wait.group 1
+// CHECK-COUNT-4: nvvm.wmma.load{{.*}} : (!llvm.ptr<f32, 3>) -> !llvm.struct<(i32, i32, i32, i32)
+// CHECK-COUNT-2: nvvm.wmma.mma
// CHECK: nvvm.cp.async.wait.group 0
// CHECK-COUNT-4: nvvm.wmma.load{{.*}} : (!llvm.ptr<f32, 3>) -> !llvm.struct<(i32, i32, i32, i32)
// CHECK-COUNT-2: nvvm.wmma.mma
diff --git a/compiler/src/iree/compiler/Codegen/Passes.h b/compiler/src/iree/compiler/Codegen/Passes.h
index 9799258..928e0c6 100644
--- a/compiler/src/iree/compiler/Codegen/Passes.h
+++ b/compiler/src/iree/compiler/Codegen/Passes.h
@@ -327,7 +327,8 @@
Operation *op, IREE::Codegen::LoweringConfigAttr loweringConfig,
IREE::Codegen::TranslationInfoAttr translationInfo,
ArrayRef<int64_t> workgroupSize);
-void addGPUMatmulTensorCorePassPipeline(OpPassManager &pm);
+void addGPUMatmulTensorCorePassPipeline(OpPassManager &pm,
+ unsigned pipelineDepth);
/// Simple lowering only distributute linalg ops on blocks and threads. This
/// will result in scalar operations. Expects pass manager to be a module-level
diff --git a/iree/test/e2e/matmul/generate_e2e_matmul_tests.py b/iree/test/e2e/matmul/generate_e2e_matmul_tests.py
index 62c7721..cb8b8b6 100644
--- a/iree/test/e2e/matmul/generate_e2e_matmul_tests.py
+++ b/iree/test/e2e/matmul/generate_e2e_matmul_tests.py
@@ -77,6 +77,7 @@
# Translation Info
dispatch_lowering_pass_pipeline: str
workload_per_wg: typing.List[int]
+ software_pipeline_depth: int
# Compilation info
workgroup_size: typing.List[int]
@@ -189,7 +190,8 @@
workload_per_wg=[
a for a in reversed(tile_workgroup_size_pair.tile_size[0:2])
],
- workgroup_size=tile_workgroup_size_pair.workgroup_size))
+ workgroup_size=tile_workgroup_size_pair.workgroup_size,
+ software_pipeline_depth=4))
return compilation_infos
@@ -355,7 +357,8 @@
compilation_info_string = (
f"#compilation{generate_function.compilation_index} = #iree_codegen.compilation_info<\n"
f" lowering_config = <tile_sizes = [{compilation_info.tile_sizes}]>,\n"
- f" translation_info = <{compilation_info.dispatch_lowering_pass_pipeline}>,\n"
+ f" translation_info = <{compilation_info.dispatch_lowering_pass_pipeline}\n"
+ f" pipeline_depth = {compilation_info.software_pipeline_depth}>,\n"
f" workgroup_size = {compilation_info.workgroup_size_str()}>\n")
compilation_info_attr = f"{{compilation_info = #compilation{generate_function.compilation_index}}} "
func_definition = func_definition + compilation_info_string