Add software pipeline depth to lowering config (#9114)

diff --git a/compiler/src/iree/compiler/Codegen/Dialect/LoweringConfig.cpp b/compiler/src/iree/compiler/Codegen/Dialect/LoweringConfig.cpp
index df9ead9..d9a613d 100644
--- a/compiler/src/iree/compiler/Codegen/Dialect/LoweringConfig.cpp
+++ b/compiler/src/iree/compiler/Codegen/Dialect/LoweringConfig.cpp
@@ -75,12 +75,13 @@
 
 TranslationInfoAttr TranslationInfoAttr::get(
     MLIRContext *context, DispatchLoweringPassPipeline passPipeline,
-    ArrayRef<int64_t> workloadPerWorkgroup) {
+    ArrayRef<int64_t> workloadPerWorkgroup, unsigned softwarePipelineDepth) {
   auto pipelineAttr =
       DispatchLoweringPassPipelineAttr::get(context, passPipeline);
   ArrayAttr workloadPerWorkgroupAttr =
       getI64IntegerArrayAttr(context, workloadPerWorkgroup);
-  return get(context, pipelineAttr, workloadPerWorkgroupAttr);
+  return get(context, pipelineAttr, workloadPerWorkgroupAttr,
+             softwarePipelineDepth);
 }
 
 DispatchLoweringPassPipeline
@@ -95,7 +96,7 @@
 LogicalResult TranslationInfoAttr::verify(
     function_ref<InFlightDiagnostic()> emitError,
     IREE::Codegen::DispatchLoweringPassPipelineAttr passPipeline,
-    ArrayAttr workloadPerWorkgroup) {
+    ArrayAttr workloadPerWorkgroup, unsigned softwarePipelineDepth) {
   if (!passPipeline) {
     return emitError() << "missing pass pipeline specification";
   }
@@ -245,7 +246,8 @@
   }
   if (failed(TranslationInfoAttr::verify(
           emitError, translationInfo.getPassPipeline(),
-          translationInfo.getWorkloadPerWorkgroup()))) {
+          translationInfo.getWorkloadPerWorkgroup(),
+          translationInfo.getSoftwarePipelineDepth()))) {
     return failure();
   }
   if (workgroupSize) {
diff --git a/compiler/src/iree/compiler/Codegen/Dialect/LoweringConfig.h b/compiler/src/iree/compiler/Codegen/Dialect/LoweringConfig.h
index 3a6508e..af89154 100644
--- a/compiler/src/iree/compiler/Codegen/Dialect/LoweringConfig.h
+++ b/compiler/src/iree/compiler/Codegen/Dialect/LoweringConfig.h
@@ -84,7 +84,8 @@
 inline void setTranslationInfo(
     func::FuncOp entryPointFn,
     IREE::Codegen::DispatchLoweringPassPipeline passPipeline,
-    ArrayRef<int64_t> workloadPerWorkgroup, ArrayRef<int64_t> workgroupSize) {
+    ArrayRef<int64_t> workloadPerWorkgroup, ArrayRef<int64_t> workgroupSize,
+    unsigned softwarePipelineDepth = 0) {
   auto entryPointOp = getEntryPoint(entryPointFn);
   MLIRContext *context = entryPointFn.getContext();
   auto translationInfo = IREE::Codegen::TranslationInfoAttr::get(
@@ -117,12 +118,12 @@
 inline LogicalResult setOpConfigAndEntryPointFnTranslation(
     func::FuncOp entryPointFn, Operation *op, TileSizesListTypeRef tileSizes,
     IREE::Codegen::DispatchLoweringPassPipeline passPipeline,
-    ArrayRef<int64_t> workgroupSize = {}) {
+    ArrayRef<int64_t> workgroupSize = {}, unsigned softwarePipelineDepth = 0) {
   MLIRContext *context = entryPointFn.getContext();
   auto config = IREE::Codegen::LoweringConfigAttr::get(context, tileSizes);
   setLoweringConfig(op, config);
   auto translationInfo = IREE::Codegen::TranslationInfoAttr::get(
-      entryPointFn->getContext(), passPipeline);
+      entryPointFn->getContext(), passPipeline, {}, softwarePipelineDepth);
   setTranslationInfo(entryPointFn, translationInfo, workgroupSize);
   return success();
 }
diff --git a/compiler/src/iree/compiler/Codegen/Dialect/LoweringConfig.td b/compiler/src/iree/compiler/Codegen/Dialect/LoweringConfig.td
index 9ae145c..e5b117b 100644
--- a/compiler/src/iree/compiler/Codegen/Dialect/LoweringConfig.td
+++ b/compiler/src/iree/compiler/Codegen/Dialect/LoweringConfig.td
@@ -104,18 +104,22 @@
   }];
 
   let assemblyFormat = [{
-    `<` `` $passPipeline (`workload_per_wg` `=` $workloadPerWorkgroup^)? `>`
+    `<` `` $passPipeline (`workload_per_wg` `=` $workloadPerWorkgroup^)?
+    (`pipeline_depth` `=` $softwarePipelineDepth^)? `>`
   }];
 
   let parameters = (ins
     AttrParameter<"IREE::Codegen::DispatchLoweringPassPipelineAttr",
         "Name of the pipeline to be invoked on the translation unit.">:$passPipeline,
     DefaultValuedParameter<"ArrayAttr", "ArrayAttr::get($_ctx, {})",
-        "The workload mapped to a single workgroup">:$workloadPerWorkgroup
+        "The workload mapped to a single workgroup">:$workloadPerWorkgroup,
+    DefaultValuedParameter<"unsigned", "1",
+        "The software pipeline depth to be used">:$softwarePipelineDepth
   );
   let builders = [
     AttrBuilder<(ins "DispatchLoweringPassPipeline":$passPipeline,
-        CArg<"ArrayRef<int64_t>", "{}">:$workloadPerWorkgroup)>
+        CArg<"ArrayRef<int64_t>", "{}">:$workloadPerWorkgroup,
+        CArg<"unsigned", "0">:$softwarePipelineDepth)>
   ];
   let extraClassDeclaration = [{
     // Returns the lowering pass pipeline set.
diff --git a/compiler/src/iree/compiler/Codegen/LLVMGPU/KernelConfig.cpp b/compiler/src/iree/compiler/Codegen/LLVMGPU/KernelConfig.cpp
index c407baf..c6964d7 100644
--- a/compiler/src/iree/compiler/Codegen/LLVMGPU/KernelConfig.cpp
+++ b/compiler/src/iree/compiler/Codegen/LLVMGPU/KernelConfig.cpp
@@ -29,6 +29,11 @@
   std::array<int64_t, 3> tileSize;
   std::array<int64_t, 3> workgroupSize;
 };
+
+// Software pipeline depths used by the GPU matmul codegen pipelines.
+constexpr unsigned softwarePipelineDepthTensorCore = 4;
+// Simt codegen does not do software pipelining.
+constexpr unsigned softwarePipelineDepthSimt = 0;
 }  // namespace
 
 /// Return the best combination of tile size and wg size. It will then used to
@@ -118,6 +123,7 @@
   auto setMatmulConfig =
       [&entryPoint, &op](int64_t tileX, int64_t tileY, int64_t tileK,
                          llvm::ArrayRef<int64_t> workgroupSize,
+                         unsigned softwarePipelineDepth,
                          IREE::Codegen::DispatchLoweringPassPipeline pipeline) {
         TileSizesListType tileSizes;
         unsigned numParallelLoops = op.getNumParallelLoops();
@@ -140,7 +146,8 @@
         tileSizes.emplace_back(
             std::move(workgroupTileSizes));  // Workgroup level.
         return setOpConfigAndEntryPointFnTranslation(entryPoint, op, tileSizes,
-                                                     pipeline, workgroupSize);
+                                                     pipeline, workgroupSize,
+                                                     softwarePipelineDepth);
       };
   // Infer the MxN size of the matmul based on operands and indexing maps.
   auto lhsShape =
@@ -196,17 +203,19 @@
         if (sizeK % config.tileSize[2] == 0 &&
             sizeN % config.tileSize[1] == 0 &&
             sizeM % config.tileSize[0] == 0) {
-          return setMatmulConfig(config.tileSize[0], config.tileSize[1],
-                                 config.tileSize[2], config.workgroupSize,
-                                 IREE::Codegen::DispatchLoweringPassPipeline::
-                                     LLVMGPUMatmulTensorCore);
+          return setMatmulConfig(
+              config.tileSize[0], config.tileSize[1], config.tileSize[2],
+              config.workgroupSize,
+              sizeK == config.tileSize[2] ? 1 : softwarePipelineDepthTensorCore,
+              IREE::Codegen::DispatchLoweringPassPipeline::
+                  LLVMGPUMatmulTensorCore);
         }
       }
     }
     // Special case for very small matrices.
     if (sizeM * sizeN <= cudaWarpSize) {
       return setMatmulConfig(
-          sizeN, sizeM, 4, {sizeM, sizeN, 1},
+          sizeN, sizeM, 4, {sizeM, sizeN, 1}, softwarePipelineDepthSimt,
           IREE::Codegen::DispatchLoweringPassPipeline::LLVMGPUMatmulSimt);
     }
     // simt matmul case
@@ -219,7 +228,7 @@
       if (sizeN % config.tileSize[1] == 0 && sizeM % config.tileSize[0] == 0) {
         return setMatmulConfig(
             config.tileSize[0], config.tileSize[1], config.tileSize[2],
-            config.workgroupSize,
+            config.workgroupSize, softwarePipelineDepthSimt,
             IREE::Codegen::DispatchLoweringPassPipeline::LLVMGPUMatmulSimt);
       }
     }
@@ -230,7 +239,7 @@
   int64_t tileK = 4;
   SmallVector<int64_t, 3> workgroupSize = {2 * cudaWarpSize, 1, 1};
   return setMatmulConfig(
-      tileX, tileY, tileK, workgroupSize,
+      tileX, tileY, tileK, workgroupSize, softwarePipelineDepthSimt,
       IREE::Codegen::DispatchLoweringPassPipeline::LLVMGPUMatmulSimt);
 }
 
diff --git a/compiler/src/iree/compiler/Codegen/LLVMGPU/LLVMGPULowerExecutableTarget.cpp b/compiler/src/iree/compiler/Codegen/LLVMGPU/LLVMGPULowerExecutableTarget.cpp
index eb89819..4f24c26 100644
--- a/compiler/src/iree/compiler/Codegen/LLVMGPU/LLVMGPULowerExecutableTarget.cpp
+++ b/compiler/src/iree/compiler/Codegen/LLVMGPU/LLVMGPULowerExecutableTarget.cpp
@@ -155,7 +155,9 @@
         addGPUMatmulSimtPassPipeline(nestedModulePM);
         break;
       case IREE::Codegen::DispatchLoweringPassPipeline::LLVMGPUMatmulTensorCore:
-        addGPUMatmulTensorCorePassPipeline(nestedModulePM);
+        addGPUMatmulTensorCorePassPipeline(
+            nestedModulePM,
+            translationInfo.getValue().getSoftwarePipelineDepth());
         break;
       default:
         variantOp.emitOpError("Unsupported pipeline on GPU target.");
diff --git a/compiler/src/iree/compiler/Codegen/LLVMGPU/Passes.cpp b/compiler/src/iree/compiler/Codegen/LLVMGPU/Passes.cpp
index db61284..752d5cf 100644
--- a/compiler/src/iree/compiler/Codegen/LLVMGPU/Passes.cpp
+++ b/compiler/src/iree/compiler/Codegen/LLVMGPU/Passes.cpp
@@ -24,11 +24,6 @@
 namespace mlir {
 namespace iree_compiler {
 
-// TODO(thomasraoux): Add a new optional attribute to translate info.
-static llvm::cl::opt<unsigned> pipelineDepth("iree-codegen-cuda-pipeline-depth",
-                                             llvm::cl::desc("Pipeline depth"),
-                                             llvm::cl::init(4));
-
 static llvm::cl::opt<unsigned> logSwizzleTile(
     "iree-codegen-log-swizzle-tile", llvm::cl::desc("log swizzle tile value"),
     llvm::cl::init(0));
@@ -109,7 +104,8 @@
   pm.addNestedPass<func::FuncOp>(createGPUPipeliningPass());
 }
 
-void addGPUMatmulTensorCorePassPipeline(OpPassManager &pm) {
+void addGPUMatmulTensorCorePassPipeline(OpPassManager &pm,
+                                        unsigned pipelineDepth) {
   tileAndBufferize(pm);
 
   // Distribute linalg onto warps within the workgroup.
diff --git a/compiler/src/iree/compiler/Codegen/LLVMGPU/Verifiers.cpp b/compiler/src/iree/compiler/Codegen/LLVMGPU/Verifiers.cpp
index e832517..da06e2e 100644
--- a/compiler/src/iree/compiler/Codegen/LLVMGPU/Verifiers.cpp
+++ b/compiler/src/iree/compiler/Codegen/LLVMGPU/Verifiers.cpp
@@ -96,6 +96,7 @@
     ArrayRef<int64_t> workgroupSize) {
   auto pipeline =
       IREE::Codegen::DispatchLoweringPassPipeline::LLVMGPUMatmulTensorCore;
+  unsigned softwarePipelineDepth = translationInfo.getSoftwarePipelineDepth();
   StringRef pipelineName = stringifyEnum(pipeline);
   if (workgroupSize.empty()) {
     return op->emitOpError("expected workgroup size for GPU pipelines");
@@ -153,7 +154,15 @@
     return op->emitOpError("workgroup size is not 32 aligned for ")
            << pipelineName << ", got " << workgroupSize[0];
   }
-
+  if (softwarePipelineDepth > 1 && firstLevelTileSizes[2] == lhsShape[1]) {
+    // Software pipelining needs more than one K iteration to overlap.
+    return op->emitOpError(
+               "software pipelining is not supported when the first level K "
+               "tile size equals the matrix reduction size; K tile: ")
+           << firstLevelTileSizes[2]
+           << ", matrix reduction size: " << lhsShape[1]
+           << ", pipeline depth: " << softwarePipelineDepth;
+  }
   // Verify the workgroup.z component should always be 1
   if (workgroupSize[2] != 1) {
     return op->emitOpError("expected workgroup z component to be 1 for ")
diff --git a/compiler/src/iree/compiler/Codegen/LLVMGPU/test/nvvm_pipeline_test.mlir b/compiler/src/iree/compiler/Codegen/LLVMGPU/test/nvvm_pipeline_test.mlir
index 7c62e49..5698b63 100644
--- a/compiler/src/iree/compiler/Codegen/LLVMGPU/test/nvvm_pipeline_test.mlir
+++ b/compiler/src/iree/compiler/Codegen/LLVMGPU/test/nvvm_pipeline_test.mlir
@@ -1,5 +1,4 @@
-// RUN: iree-opt --split-input-file --pass-pipeline='hal.executable(hal.executable.variant(iree-codegen-linalg-to-nvvm-pipeline))' --iree-codegen-cuda-pipeline-depth=1 %s | FileCheck %s
-// RUN: iree-opt --split-input-file --pass-pipeline='hal.executable(hal.executable.variant(iree-codegen-linalg-to-nvvm-pipeline))' --iree-codegen-cuda-pipeline-depth=4 %s | FileCheck %s --check-prefix=CHECKP
+// RUN: iree-opt --split-input-file --pass-pipeline='hal.executable(hal.executable.variant(iree-codegen-linalg-to-nvvm-pipeline))' %s | FileCheck %s
 
 // Verify that a simple element wise op gets lowered succefully all the way to
 // nvvm/llvm dialect.
@@ -421,58 +420,40 @@
 }
 }
 
+// Case with software pipelining (default tensor core pipeline depth = 4).
 //     CHECK-LABEL: hal.executable public @mma_fused
 //           CHECK:   hal.executable.variant public @cuda
 //       CHECK-NOT:   llvm.store
 //   CHECK-COUNT-2:   nvvm.cp.async.shared.global {{.*}}, {{.*}}, 16
 //           CHECK:   nvvm.cp.async.commit.group
+//   CHECK-COUNT-2:   nvvm.cp.async.shared.global {{.*}}, {{.*}}, 16
+//           CHECK:   nvvm.cp.async.commit.group
+//   CHECK-COUNT-2:   nvvm.cp.async.shared.global {{.*}}, {{.*}}, 16
+//           CHECK:   nvvm.cp.async.commit.group
+//   CHECK-COUNT-2:   nvvm.cp.async.shared.global {{.*}}, {{.*}}, 16
+//           CHECK:   nvvm.cp.async.commit.group
 //           CHECK:   llvm.br
-//           CHECK:   nvvm.cp.async.wait.group 0
+//           CHECK:   nvvm.cp.async.wait.group 3
 //   CHECK-COUNT-4:   nvvm.wmma.load{{.*}} : (!llvm.ptr<f32, 3>) -> !llvm.struct<(i32, i32, i32, i32)
 //   CHECK-COUNT-2:   nvvm.wmma.mma
 //   CHECK-COUNT-2:   nvvm.cp.async.shared.global {{.*}}, {{.*}}, 16
 //           CHECK:   nvvm.cp.async.commit.group
 //           CHECK:   llvm.br
+//           CHECK:   nvvm.cp.async.wait.group 3
+//   CHECK-COUNT-4:   nvvm.wmma.load{{.*}} : (!llvm.ptr<f32, 3>) -> !llvm.struct<(i32, i32, i32, i32)
+//   CHECK-COUNT-2:   nvvm.wmma.mma
+//           CHECK:   nvvm.cp.async.wait.group 2
+//   CHECK-COUNT-4:   nvvm.wmma.load{{.*}} : (!llvm.ptr<f32, 3>) -> !llvm.struct<(i32, i32, i32, i32)
+//   CHECK-COUNT-2:   nvvm.wmma.mma
+//           CHECK:   nvvm.cp.async.wait.group 1
+//   CHECK-COUNT-4:   nvvm.wmma.load{{.*}} : (!llvm.ptr<f32, 3>) -> !llvm.struct<(i32, i32, i32, i32)
+//   CHECK-COUNT-2:   nvvm.wmma.mma
 //           CHECK:   nvvm.cp.async.wait.group 0
 //   CHECK-COUNT-4:   nvvm.wmma.load{{.*}} : (!llvm.ptr<f32, 3>) -> !llvm.struct<(i32, i32, i32, i32)
 //   CHECK-COUNT-2:   nvvm.wmma.mma
 //   CHECK-COUNT-8:   llvm.fadd
 //   CHECK-COUNT-1:   nvvm.wmma.store {{.*}} : !llvm.ptr<f32>, f32, f32, f32, f32, f32, f32, f32, f32
 
-// case with larger pipeline depth
-//     CHECKP-LABEL: hal.executable public @mma_fused
-//           CHECKP:   hal.executable.variant public @cuda
-//       CHECKP-NOT:   llvm.store
-//   CHECKP-COUNT-2:   nvvm.cp.async.shared.global {{.*}}, {{.*}}, 16
-//           CHECKP:   nvvm.cp.async.commit.group
-//   CHECKP-COUNT-2:   nvvm.cp.async.shared.global {{.*}}, {{.*}}, 16
-//           CHECKP:   nvvm.cp.async.commit.group
-//   CHECKP-COUNT-2:   nvvm.cp.async.shared.global {{.*}}, {{.*}}, 16
-//           CHECKP:   nvvm.cp.async.commit.group
-//   CHECKP-COUNT-2:   nvvm.cp.async.shared.global {{.*}}, {{.*}}, 16
-//           CHECKP:   nvvm.cp.async.commit.group
-//           CHECKP:   llvm.br
-//           CHECKP:   nvvm.cp.async.wait.group 3
-//   CHECKP-COUNT-4:   nvvm.wmma.load{{.*}} : (!llvm.ptr<f32, 3>) -> !llvm.struct<(i32, i32, i32, i32)
-//   CHECKP-COUNT-2:   nvvm.wmma.mma
-//   CHECKP-COUNT-2:   nvvm.cp.async.shared.global {{.*}}, {{.*}}, 16
-//           CHECKP:   nvvm.cp.async.commit.group
-//           CHECKP:   llvm.br
-//           CHECKP:   nvvm.cp.async.wait.group 3
-//   CHECKP-COUNT-4:   nvvm.wmma.load{{.*}} : (!llvm.ptr<f32, 3>) -> !llvm.struct<(i32, i32, i32, i32)
-//   CHECKP-COUNT-2:   nvvm.wmma.mma
-//           CHECKP:   nvvm.cp.async.wait.group 2
-//   CHECKP-COUNT-4:   nvvm.wmma.load{{.*}} : (!llvm.ptr<f32, 3>) -> !llvm.struct<(i32, i32, i32, i32)
-//   CHECKP-COUNT-2:   nvvm.wmma.mma
-//           CHECKP:   nvvm.cp.async.wait.group 1
-//   CHECKP-COUNT-4:   nvvm.wmma.load{{.*}} : (!llvm.ptr<f32, 3>) -> !llvm.struct<(i32, i32, i32, i32)
-//   CHECKP-COUNT-2:   nvvm.wmma.mma
-//           CHECKP:   nvvm.cp.async.wait.group 0
-//   CHECKP-COUNT-4:   nvvm.wmma.load{{.*}} : (!llvm.ptr<f32, 3>) -> !llvm.struct<(i32, i32, i32, i32)
-//   CHECKP-COUNT-2:   nvvm.wmma.mma
-//   CHECKP-COUNT-8:   llvm.fadd
-//   CHECKP-COUNT-1:   nvvm.wmma.store {{.*}} : !llvm.ptr<f32>, f32, f32, f32, f32, f32, f32, f32, f32
-
 // -----
 
 #executable_layout = #hal.executable.layout<push_constants = 0, sets = [
@@ -521,57 +502,40 @@
 }
 }
 
+// Case with software pipelining (default tensor core pipeline depth = 4).
 //     CHECK-LABEL: hal.executable public @mma_fused_fp16
 //           CHECK:   hal.executable.variant public @cuda
 //       CHECK-NOT:   llvm.store
 //   CHECK-COUNT-2:   nvvm.cp.async.shared.global {{.*}}, {{.*}}, 16
 //           CHECK:   nvvm.cp.async.commit.group
+//   CHECK-COUNT-2:   nvvm.cp.async.shared.global {{.*}}, {{.*}}, 16
+//           CHECK:   nvvm.cp.async.commit.group
+//   CHECK-COUNT-2:   nvvm.cp.async.shared.global {{.*}}, {{.*}}, 16
+//           CHECK:   nvvm.cp.async.commit.group
+//   CHECK-COUNT-2:   nvvm.cp.async.shared.global {{.*}}, {{.*}}, 16
+//           CHECK:   nvvm.cp.async.commit.group
 //           CHECK:   llvm.br
-//           CHECK:   nvvm.cp.async.wait.group 0
+//           CHECK:   nvvm.cp.async.wait.group 3
 //   CHECK-COUNT-2:   nvvm.wmma.load{{.*}} : (!llvm.ptr<f16, 3>) -> !llvm.struct<(vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>)
 //   CHECK-COUNT-1:   nvvm.wmma.mma
 //   CHECK-COUNT-2:   nvvm.cp.async.shared.global {{.*}}, {{.*}}, 16
 //           CHECK:   nvvm.cp.async.commit.group
 //           CHECK:   llvm.br
+//           CHECK:   nvvm.cp.async.wait.group 3
+//   CHECK-COUNT-2:   nvvm.wmma.load{{.*}} : (!llvm.ptr<f16, 3>) -> !llvm.struct<(vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>)
+//   CHECK-COUNT-1:   nvvm.wmma.mma
+//           CHECK:   nvvm.cp.async.wait.group 2
+//   CHECK-COUNT-2:   nvvm.wmma.load{{.*}} : (!llvm.ptr<f16, 3>) -> !llvm.struct<(vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>)
+//   CHECK-COUNT-1:   nvvm.wmma.mma
+//           CHECK:   nvvm.cp.async.wait.group 1
+//   CHECK-COUNT-2:   nvvm.wmma.load{{.*}} : (!llvm.ptr<f16, 3>) -> !llvm.struct<(vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>)
+//   CHECK-COUNT-1:   nvvm.wmma.mma
+//           CHECK:   nvvm.cp.async.wait.group 0
 //   CHECK-COUNT-2:   nvvm.wmma.load{{.*}} : (!llvm.ptr<f16, 3>) -> !llvm.struct<(vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>)
 //   CHECK-COUNT-1:   nvvm.wmma.mma
 //   CHECK-COUNT-4:   llvm.fadd
 //   CHECK-COUNT-1:   nvvm.wmma.store {{.*}} : !llvm.ptr<f16>, vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>
 
-// case with larger pipeline depth
-//     CHECKP-LABEL: hal.executable public @mma_fused_fp16
-//           CHECKP:   hal.executable.variant public @cuda
-//       CHECKP-NOT:   llvm.store
-//   CHECKP-COUNT-2:   nvvm.cp.async.shared.global {{.*}}, {{.*}}, 16
-//           CHECKP:   nvvm.cp.async.commit.group
-//   CHECKP-COUNT-2:   nvvm.cp.async.shared.global {{.*}}, {{.*}}, 16
-//           CHECKP:   nvvm.cp.async.commit.group
-//   CHECKP-COUNT-2:   nvvm.cp.async.shared.global {{.*}}, {{.*}}, 16
-//           CHECKP:   nvvm.cp.async.commit.group
-//   CHECKP-COUNT-2:   nvvm.cp.async.shared.global {{.*}}, {{.*}}, 16
-//           CHECKP:   nvvm.cp.async.commit.group
-//           CHECKP:   llvm.br
-//           CHECKP:   nvvm.cp.async.wait.group 3
-//   CHECKP-COUNT-2:   nvvm.wmma.load{{.*}} : (!llvm.ptr<f16, 3>) -> !llvm.struct<(vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>)
-//   CHECKP-COUNT-1:   nvvm.wmma.mma
-//   CHECKP-COUNT-2:   nvvm.cp.async.shared.global {{.*}}, {{.*}}, 16
-//           CHECKP:   nvvm.cp.async.commit.group
-//           CHECKP:   llvm.br
-//           CHECKP:   nvvm.cp.async.wait.group 3
-//   CHECKP-COUNT-2:   nvvm.wmma.load{{.*}} : (!llvm.ptr<f16, 3>) -> !llvm.struct<(vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>)
-//   CHECKP-COUNT-1:   nvvm.wmma.mma
-//           CHECKP:   nvvm.cp.async.wait.group 2
-//   CHECKP-COUNT-2:   nvvm.wmma.load{{.*}} : (!llvm.ptr<f16, 3>) -> !llvm.struct<(vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>)
-//   CHECKP-COUNT-1:   nvvm.wmma.mma
-//           CHECKP:   nvvm.cp.async.wait.group 1
-//   CHECKP-COUNT-2:   nvvm.wmma.load{{.*}} : (!llvm.ptr<f16, 3>) -> !llvm.struct<(vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>)
-//   CHECKP-COUNT-1:   nvvm.wmma.mma
-//           CHECKP:   nvvm.cp.async.wait.group 0
-//   CHECKP-COUNT-2:   nvvm.wmma.load{{.*}} : (!llvm.ptr<f16, 3>) -> !llvm.struct<(vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>)
-//   CHECKP-COUNT-1:   nvvm.wmma.mma
-//   CHECKP-COUNT-4:   llvm.fadd
-//   CHECKP-COUNT-1:   nvvm.wmma.store {{.*}} : !llvm.ptr<f16>, vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>
-
 // -----
 
 #executable_layout = #hal.executable.layout<push_constants = 0, sets = [
@@ -620,56 +584,40 @@
       }
     }
   }
+
+// Case with software pipelining (default tensor core pipeline depth = 4).
 //     CHECK-LABEL: hal.executable public @large_dot_general_dispatch_0
 //           CHECK:   hal.executable.variant public @cuda
 //       CHECK-NOT:   llvm.store
 //   CHECK-COUNT-2:   nvvm.cp.async.shared.global {{.*}}, {{.*}}, 16
 //           CHECK:   nvvm.cp.async.commit.group
+//   CHECK-COUNT-2:   nvvm.cp.async.shared.global {{.*}}, {{.*}}, 16
+//           CHECK:   nvvm.cp.async.commit.group
+//   CHECK-COUNT-2:   nvvm.cp.async.shared.global {{.*}}, {{.*}}, 16
+//           CHECK:   nvvm.cp.async.commit.group
+//   CHECK-COUNT-2:   nvvm.cp.async.shared.global {{.*}}, {{.*}}, 16
+//           CHECK:   nvvm.cp.async.commit.group
 //           CHECK:   llvm.br
-//           CHECK:   nvvm.cp.async.wait.group 0
+//           CHECK:   nvvm.cp.async.wait.group 3
 //   CHECK-COUNT-4:   nvvm.wmma.load{{.*}} : (!llvm.ptr<f32, 3>) -> !llvm.struct<(i32, i32, i32, i32)
 //   CHECK-COUNT-2:   nvvm.wmma.mma
 //   CHECK-COUNT-2:   nvvm.cp.async.shared.global {{.*}}, {{.*}}, 16
 //           CHECK:   nvvm.cp.async.commit.group
 //           CHECK:   llvm.br
+//           CHECK:   nvvm.cp.async.wait.group 3
+//   CHECK-COUNT-4:   nvvm.wmma.load{{.*}} : (!llvm.ptr<f32, 3>) -> !llvm.struct<(i32, i32, i32, i32)
+//   CHECK-COUNT-2:   nvvm.wmma.mma
+//           CHECK:   nvvm.cp.async.wait.group 2
+//   CHECK-COUNT-4:   nvvm.wmma.load{{.*}} : (!llvm.ptr<f32, 3>) -> !llvm.struct<(i32, i32, i32, i32)
+//   CHECK-COUNT-2:   nvvm.wmma.mma
+//           CHECK:   nvvm.cp.async.wait.group 1
+//   CHECK-COUNT-4:   nvvm.wmma.load{{.*}} : (!llvm.ptr<f32, 3>) -> !llvm.struct<(i32, i32, i32, i32)
+//   CHECK-COUNT-2:   nvvm.wmma.mma
 //           CHECK:   nvvm.cp.async.wait.group 0
 //   CHECK-COUNT-4:   nvvm.wmma.load{{.*}} : (!llvm.ptr<f32, 3>) -> !llvm.struct<(i32, i32, i32, i32)
 //   CHECK-COUNT-2:   nvvm.wmma.mma
 //   CHECK-COUNT-1:   nvvm.wmma.store {{.*}} : !llvm.ptr<f32>, f32, f32, f32, f32, f32, f32, f32, f32
 
-// case with larger pipeline depth
-//     CHECKP-LABEL: hal.executable public @large_dot_general_dispatch_0
-//           CHECKP:   hal.executable.variant public @cuda
-//       CHECKP-NOT:   llvm.store
-//   CHECKP-COUNT-2:   nvvm.cp.async.shared.global {{.*}}, {{.*}}, 16
-//           CHECKP:   nvvm.cp.async.commit.group
-//   CHECKP-COUNT-2:   nvvm.cp.async.shared.global {{.*}}, {{.*}}, 16
-//           CHECKP:   nvvm.cp.async.commit.group
-//   CHECKP-COUNT-2:   nvvm.cp.async.shared.global {{.*}}, {{.*}}, 16
-//           CHECKP:   nvvm.cp.async.commit.group
-//   CHECKP-COUNT-2:   nvvm.cp.async.shared.global {{.*}}, {{.*}}, 16
-//           CHECKP:   nvvm.cp.async.commit.group
-//           CHECKP:   llvm.br
-//           CHECKP:   nvvm.cp.async.wait.group 3
-//   CHECKP-COUNT-4:   nvvm.wmma.load{{.*}} : (!llvm.ptr<f32, 3>) -> !llvm.struct<(i32, i32, i32, i32)
-//   CHECKP-COUNT-2:   nvvm.wmma.mma
-//   CHECKP-COUNT-2:   nvvm.cp.async.shared.global {{.*}}, {{.*}}, 16
-//           CHECKP:   nvvm.cp.async.commit.group
-//           CHECKP:   llvm.br
-//           CHECKP:   nvvm.cp.async.wait.group 3
-//   CHECKP-COUNT-4:   nvvm.wmma.load{{.*}} : (!llvm.ptr<f32, 3>) -> !llvm.struct<(i32, i32, i32, i32)
-//   CHECKP-COUNT-2:   nvvm.wmma.mma
-//           CHECKP:   nvvm.cp.async.wait.group 2
-//   CHECKP-COUNT-4:   nvvm.wmma.load{{.*}} : (!llvm.ptr<f32, 3>) -> !llvm.struct<(i32, i32, i32, i32)
-//   CHECKP-COUNT-2:   nvvm.wmma.mma
-//           CHECKP:   nvvm.cp.async.wait.group 1
-//   CHECKP-COUNT-4:   nvvm.wmma.load{{.*}} : (!llvm.ptr<f32, 3>) -> !llvm.struct<(i32, i32, i32, i32)
-//   CHECKP-COUNT-2:   nvvm.wmma.mma
-//           CHECKP:   nvvm.cp.async.wait.group 0
-//   CHECKP-COUNT-4:   nvvm.wmma.load{{.*}} : (!llvm.ptr<f32, 3>) -> !llvm.struct<(i32, i32, i32, i32)
-//   CHECKP-COUNT-2:   nvvm.wmma.mma
-//   CHECKP-COUNT-1:   nvvm.wmma.store {{.*}} : !llvm.ptr<f32>, f32, f32, f32, f32, f32, f32, f32, f32
-
 // -----
 
 #executable_layout = #hal.executable.layout<push_constants = 0, sets = [
@@ -722,12 +670,15 @@
 //   CHECK-COUNT-2:   nvvm.cp.async.shared.global {{.*}}, {{.*}}, 16
 //           CHECK:   nvvm.cp.async.commit.group
 //           CHECK:   llvm.br
-//           CHECK:   nvvm.cp.async.wait.group 0
+//           CHECK:   nvvm.cp.async.wait.group 3
 //   CHECK-COUNT-4:   nvvm.wmma.load{{.*}} : (!llvm.ptr<f32, 3>) -> !llvm.struct<(i32, i32, i32, i32)
 //   CHECK-COUNT-2:   nvvm.wmma.mma
-//   CHECK-COUNT-2:   nvvm.cp.async.shared.global {{.*}}, {{.*}}, 16
-//           CHECK:   nvvm.cp.async.commit.group
-//           CHECK:   llvm.br
+//           CHECK:   nvvm.cp.async.wait.group 2
+//   CHECK-COUNT-4:   nvvm.wmma.load{{.*}} : (!llvm.ptr<f32, 3>) -> !llvm.struct<(i32, i32, i32, i32)
+//   CHECK-COUNT-2:   nvvm.wmma.mma
+//           CHECK:   nvvm.cp.async.wait.group 1
+//   CHECK-COUNT-4:   nvvm.wmma.load{{.*}} : (!llvm.ptr<f32, 3>) -> !llvm.struct<(i32, i32, i32, i32)
+//   CHECK-COUNT-2:   nvvm.wmma.mma
 //           CHECK:   nvvm.cp.async.wait.group 0
 //   CHECK-COUNT-4:   nvvm.wmma.load{{.*}} : (!llvm.ptr<f32, 3>) -> !llvm.struct<(i32, i32, i32, i32)
 //   CHECK-COUNT-2:   nvvm.wmma.mma
diff --git a/compiler/src/iree/compiler/Codegen/Passes.h b/compiler/src/iree/compiler/Codegen/Passes.h
index 9799258..928e0c6 100644
--- a/compiler/src/iree/compiler/Codegen/Passes.h
+++ b/compiler/src/iree/compiler/Codegen/Passes.h
@@ -327,7 +327,8 @@
     Operation *op, IREE::Codegen::LoweringConfigAttr loweringConfig,
     IREE::Codegen::TranslationInfoAttr translationInfo,
     ArrayRef<int64_t> workgroupSize);
-void addGPUMatmulTensorCorePassPipeline(OpPassManager &pm);
+void addGPUMatmulTensorCorePassPipeline(OpPassManager &pm,
+                                        unsigned pipelineDepth);
 
 /// Simple lowering only distributute linalg ops on blocks and threads. This
 /// will result in scalar operations. Expects pass manager to be a module-level
diff --git a/iree/test/e2e/matmul/generate_e2e_matmul_tests.py b/iree/test/e2e/matmul/generate_e2e_matmul_tests.py
index 62c7721..cb8b8b6 100644
--- a/iree/test/e2e/matmul/generate_e2e_matmul_tests.py
+++ b/iree/test/e2e/matmul/generate_e2e_matmul_tests.py
@@ -77,6 +77,7 @@
   # Translation Info
   dispatch_lowering_pass_pipeline: str
   workload_per_wg: typing.List[int]
+  software_pipeline_depth: int
   # Compilation info
   workgroup_size: typing.List[int]
 
@@ -189,7 +190,8 @@
             workload_per_wg=[
                 a for a in reversed(tile_workgroup_size_pair.tile_size[0:2])
             ],
-            workgroup_size=tile_workgroup_size_pair.workgroup_size))
+            workgroup_size=tile_workgroup_size_pair.workgroup_size,
+            software_pipeline_depth=4))
   return compilation_infos
 
 
@@ -355,7 +357,8 @@
     compilation_info_string = (
         f"#compilation{generate_function.compilation_index} = #iree_codegen.compilation_info<\n"
         f"  lowering_config = <tile_sizes = [{compilation_info.tile_sizes}]>,\n"
-        f"  translation_info = <{compilation_info.dispatch_lowering_pass_pipeline}>,\n"
+        f"  translation_info = <{compilation_info.dispatch_lowering_pass_pipeline}\n"
+        f"  pipeline_depth = {compilation_info.software_pipeline_depth}>,\n"
         f"  workgroup_size = {compilation_info.workgroup_size_str()}>\n")
     compilation_info_attr = f"{{compilation_info = #compilation{generate_function.compilation_index}}} "
     func_definition = func_definition + compilation_info_string