Support GEMM Pipelining *without* Epilogue Peeling (#10388)

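With epilogue-peeling=false, the trailing K-loop iterations stay inside the
mainloop and the global loads that would go out of bounds are predicated
instead of peeled: each nvgpu.device_async_copy is rewritten to its zfill
form, where the number of source elements is selected between the full
transfer size and 0 based on the loop predicate, so out-of-bounds lanes are
zero-filled rather than read.

A minimal sketch of the rewrite (illustrative MLIR with assumed SSA names;
the pass derives the actual operands from the existing copy op):

    // Inside the pipelined K loop, %pred guards the trailing iterations.
    %c0    = arith.constant 0 : index
    %c8    = arith.constant 8 : index
    %elems = arith.select %pred, %c8, %c0 : index
    %cp    = nvgpu.device_async_copy %gmem[%i, %j], %smem[%k, %x, %y], 8, %elems
               : memref<3456x2048xf16> to memref<4x32x40xf16, 3>

Operations without a predicated form are only let through when they are
side-effect free, are barrier/commit-group/ldmatrix ops, or are shared-memory
loads; anything else is rejected, since the un-peeled epilogue cannot be used
for it. The default (epilogue-peeling=true) keeps the existing peeled-epilogue
behavior.
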
diff --git a/compiler/src/iree/compiler/Codegen/Common/GPUPipelining.cpp b/compiler/src/iree/compiler/Codegen/Common/GPUPipelining.cpp
index 109c1d0..f2c9ae5 100644
--- a/compiler/src/iree/compiler/Codegen/Common/GPUPipelining.cpp
+++ b/compiler/src/iree/compiler/Codegen/Common/GPUPipelining.cpp
@@ -7,11 +7,13 @@
 #include "iree/compiler/Codegen/PassDetail.h"
 #include "iree/compiler/Codegen/Passes.h"
 #include "iree/compiler/Codegen/Utils/Utils.h"
+#include "mlir/Dialect/Arithmetic/IR/Arithmetic.h"
 #include "mlir/Dialect/GPU/IR/GPUDialect.h"
 #include "mlir/Dialect/NVGPU/IR/NVGPUDialect.h"
 #include "mlir/Dialect/SCF/Transforms/Transforms.h"
 #include "mlir/Dialect/Vector/IR/VectorOps.h"
 #include "mlir/Transforms/GreedyPatternRewriteDriver.h"
+#include "mlir/Transforms/SideEffectUtils.h"
 
 //====---------------------------------------------------------------------===//
 // Pass to pipeline copy to shared memory for matmul op.
@@ -23,6 +25,63 @@
 static const StringLiteral kPipeliningLoopMarker = "__pipelining_K_loop__";
 static const StringLiteral kPipeliningGlobalLoad = "__pipelining_global_load__";
 
+// Returns a new predicated operation to support an unpeeled epilogue. An
+// unpeeled epilogue handles the last iterations inside the mainloop, which
+// requires predicating operations, e.g., OOB global memory accesses. This
+// helper function predicates operations (where predication is available) and
+// checks that unpredicated operations are side-effect free and acceptable to
+// execute speculatively.
+static Operation* replaceOpWithPredicatedOp(Operation* op, Value pred,
+                                            PatternRewriter& rewriter) {
+  // Predication is only supported for AsyncCopyOp. Thus, for operations which
+  // are *not* AsyncCopyOp, additional checks are required before they can be
+  // issued speculatively.
+  if (!isa<nvgpu::DeviceAsyncCopyOp>(op)) {
+    // Return/execute the op if it is side-effect free.
+    if (mlir::isSideEffectFree(op)) return op;
+    // Return/execute the op if it is a barrier, commit group, or ldmatrix op.
+    if (isa<gpu::BarrierOp, nvgpu::DeviceAsyncCreateGroupOp, nvgpu::LdMatrixOp>(
+            op))
+      return op;
+    // Return/execute the op if it is a shared memory load.
+    if (auto loadOp = dyn_cast<vector::LoadOp>(op)) {
+      unsigned loadAddrSpace =
+          loadOp.getBase().getType().cast<MemRefType>().getMemorySpaceAsInt();
+      if (loadAddrSpace == gpu::GPUDialect::getWorkgroupAddressSpace())
+        return op;
+    }
+    // If we reach this point, the operation has no predication support and
+    // cannot be speculatively executed. Thus, the unpeeled epilogue is not
+    // supported.
+    assert(false &&
+           "Unpeeled epilogue not supported with a side-effecting op that "
+           "has no predication.");
+  }
+
+  // Replace the mainloop AsyncCopy with a predicated AsyncCopy (zfill).
+  auto asyncCopyOp = dyn_cast<nvgpu::DeviceAsyncCopyOp>(op);
+  auto loc = asyncCopyOp->getLoc();
+
+  // Create the srcElements value based on the predicate.
+  // The next few lines generate the following code:
+  //   srcElements = (pred) ? dstElements : 0;
+  Value dstElements =
+      rewriter.create<arith::ConstantOp>(loc, asyncCopyOp.getDstElementsAttr());
+  Value c0Index = rewriter.create<arith::ConstantIndexOp>(loc, 0);
+  auto srcElements =
+      rewriter.create<arith::SelectOp>(loc, pred, dstElements, c0Index);
+  auto asyncCopyZfillOp = rewriter.create<nvgpu::DeviceAsyncCopyOp>(
+      loc, nvgpu::DeviceAsyncTokenType::get(asyncCopyOp.getContext()),
+      asyncCopyOp.getDst(), asyncCopyOp.getDstIndices(), asyncCopyOp.getSrc(),
+      asyncCopyOp.getSrcIndices(), asyncCopyOp.getDstElements(), srcElements,
+      UnitAttr());
+
+  rewriter.eraseOp(asyncCopyOp);
+
+  // Return the newly created predicated AsyncCopyZfillOp.
+  return asyncCopyZfillOp;
+}
+
 /// Helper to recursively add operation dependencies within `block` to `dep`
 /// set.
 static void addDepOps(llvm::SmallDenseSet<Operation*>& dep, Operation* op,
@@ -84,7 +143,9 @@
 
 namespace {
 struct GPUPipeliningPass : public GPUPipeliningBase<GPUPipeliningPass> {
-  GPUPipeliningPass(unsigned depth) : depth(depth) {}
+  GPUPipeliningPass(bool epiloguePeeling, unsigned depth) : depth(depth) {
+    this->epiloguePeeling = epiloguePeeling;
+  }
   void runOnOperation() override {
     auto funcOp = getOperation();
     MLIRContext* context = &getContext();
@@ -142,6 +203,17 @@
     };
     options.getScheduleFn = getSchedule;
     options.annotateFn = setAnnotation;
+
+    // Use the un-peeled epilogue (i.e., epiloguePeeling=false) only when
+    // predication is available, i.e., for AsyncCopyOp.
+    if (!epiloguePeeling) {
+      options.peelEpilogue = false;
+      options.predicateFn = [](Operation* op, Value pred,
+                               PatternRewriter& rewriter) {
+        return replaceOpWithPredicatedOp(op, pred, rewriter);
+      };
+    }
+
     RewritePatternSet pipeliningPatterns(context);
     scf::populateSCFLoopPipeliningPatterns(pipeliningPatterns, options);
     if (failed(applyPatternsAndFoldGreedily(funcOp,
@@ -155,9 +227,14 @@
 };
 }  // namespace
 
+/// Pass options:
+/// epiloguePeeling - enable/disable epilogue peeling.
+/// true  : Peel the epilogue (no additional checks required).
+/// false : Try to use the un-peeled epilogue (only valid when predication is
+/// available).
 std::unique_ptr<OperationPass<func::FuncOp>> createGPUPipeliningPass(
-    unsigned depth) {
-  return std::make_unique<GPUPipeliningPass>(depth);
+    bool epiloguePeeling, unsigned depth) {
+  return std::make_unique<GPUPipeliningPass>(epiloguePeeling, depth);
 }
 
 }  // namespace iree_compiler
diff --git a/compiler/src/iree/compiler/Codegen/Common/test/BUILD b/compiler/src/iree/compiler/Codegen/Common/test/BUILD
index a3c1bfa..fab3c01 100644
--- a/compiler/src/iree/compiler/Codegen/Common/test/BUILD
+++ b/compiler/src/iree/compiler/Codegen/Common/test/BUILD
@@ -29,6 +29,7 @@
             "fold_affine_min_in_distributed_loops.mlir",
             "fold_tensor_extract_op.mlir",
             "forop_canonicalization.mlir",
+            "gpu_pipeline.mlir",
             "gpu_vectorization.mlir",
             "iree_comprehensive_bufferize.mlir",
             "pad_dynamic_alloc.mlir",
diff --git a/compiler/src/iree/compiler/Codegen/Common/test/CMakeLists.txt b/compiler/src/iree/compiler/Codegen/Common/test/CMakeLists.txt
index 684fc18..fdd1347 100644
--- a/compiler/src/iree/compiler/Codegen/Common/test/CMakeLists.txt
+++ b/compiler/src/iree/compiler/Codegen/Common/test/CMakeLists.txt
@@ -25,6 +25,7 @@
     "fold_affine_min_in_distributed_loops.mlir"
     "fold_tensor_extract_op.mlir"
     "forop_canonicalization.mlir"
+    "gpu_pipeline.mlir"
     "gpu_vectorization.mlir"
     "iree_comprehensive_bufferize.mlir"
     "pad_dynamic_alloc.mlir"
diff --git a/compiler/src/iree/compiler/Codegen/Common/test/gpu_pipeline.mlir b/compiler/src/iree/compiler/Codegen/Common/test/gpu_pipeline.mlir
new file mode 100644
index 0000000..55bff1e
--- /dev/null
+++ b/compiler/src/iree/compiler/Codegen/Common/test/gpu_pipeline.mlir
@@ -0,0 +1,54 @@
+// Test that the un-peeled epilogue generates AsyncCopyOp with zfill.
+// RUN: iree-opt --iree-gpu-pipelining=epilogue-peeling=false %s | FileCheck %s
+
+func.func @_matmul_f16_f16_dispatch_0_fill_3456x1024() {
+  %c2048 = arith.constant 2048 : index
+  %c32 = arith.constant 32 : index
+  %c16 = arith.constant 16 : index
+  %c0 = arith.constant 0 : index
+  %cst = arith.constant 0.000000e+00 : f16
+  %0 = gpu.subgroup_mma_constant_matrix %cst : !gpu.mma_matrix<16x16xf16, "COp">
+  %1 = gpu.thread_id  x
+  %2 = gpu.thread_id  y
+  %3 = gpu.thread_id  z
+  %4 = memref.alloc() : memref<4x32x40xf16, 3>
+  %5 = memref.alloc() : memref<4x32x40xf16, 3>
+  %6 = hal.interface.binding.subspan set(0) binding(0) type(storage_buffer) offset(%c0) alignment(64) : memref<3456x2048xf16>
+  memref.assume_alignment %6, 64 : memref<3456x2048xf16>
+  %7 = hal.interface.binding.subspan set(0) binding(1) type(storage_buffer) offset(%c0) alignment(64) : memref<2048x1024xf16>
+  memref.assume_alignment %7, 64 : memref<2048x1024xf16>
+  %8 = hal.interface.binding.subspan set(0) binding(2) type(storage_buffer) offset(%c0) alignment(64) : memref<3456x1024xf16>
+  memref.assume_alignment %8, 64 : memref<3456x1024xf16>
+  %workgroup_id_x = hal.interface.workgroup.id[0] : index
+  %workgroup_id_y = hal.interface.workgroup.id[1] : index
+  %9 = affine.apply affine_map<()[s0, s1, s2] -> (s1 * 16 + s2 * 32 + s0 floordiv 4)>()[%1, %2, %3]
+  %10 = affine.apply affine_map<()[s0] -> (s0 * 8 - (s0 floordiv 4) * 32)>()[%1]
+  %11 = scf.for %arg0 = %c0 to %c2048 step %c32 iter_args(%arg1 = %0) -> (!gpu.mma_matrix<16x16xf16, "COp">) {
+    gpu.barrier
+    %14 = affine.apply affine_map<()[s0, s1] -> (s0 + s1 * 8 - (s1 floordiv 4) * 32)>()[%arg0, %1]
+    %15 = affine.apply affine_map<()[s0, s1, s2, s3] -> (s1 * 16 + s2 * 32 + s3 * 32 + s0 floordiv 4)>()[%1, %2, %3, %workgroup_id_y]
+    %16 = affine.apply affine_map<(d0) -> ((d0 floordiv 32) mod 4)>(%arg0)
+    %17 = nvgpu.device_async_copy %6[%15, %14], %4[%16, %9, %10], 8 : memref<3456x2048xf16> to memref<4x32x40xf16, 3>
+    %18 = affine.apply affine_map<()[s0, s1, s2, s3] -> (s0 + s2 * 16 + s3 * 32 + s1 floordiv 4)>()[%arg0, %1, %2, %3]
+    %19 = affine.apply affine_map<()[s0, s1] -> (s0 * 8 + s1 * 32 - (s0 floordiv 4) * 32)>()[%1, %workgroup_id_x]
+    %20 = nvgpu.device_async_copy %7[%18, %19], %5[%16, %9, %10], 8 : memref<2048x1024xf16> to memref<4x32x40xf16, 3>
+    %21 = nvgpu.device_async_create_group %17, %20
+    nvgpu.device_async_wait %21
+    gpu.barrier
+    %22 = affine.apply affine_map<()[s0] -> (s0 * 16)>()[%2]
+    %23 = gpu.subgroup_mma_load_matrix %4[%16, %22, %c0] {leadDimension = 40 : index} : memref<4x32x40xf16, 3> -> !gpu.mma_matrix<16x16xf16, "AOp">
+    %24 = gpu.subgroup_mma_load_matrix %4[%16, %22, %c16] {leadDimension = 40 : index} : memref<4x32x40xf16, 3> -> !gpu.mma_matrix<16x16xf16, "AOp">
+    %25 = affine.apply affine_map<()[s0] -> ((s0 floordiv 32) * 16)>()[%1]
+    %26 = gpu.subgroup_mma_load_matrix %5[%16, %c0, %25] {leadDimension = 40 : index} : memref<4x32x40xf16, 3> -> !gpu.mma_matrix<16x16xf16, "BOp">
+    %27 = gpu.subgroup_mma_load_matrix %5[%16, %c16, %25] {leadDimension = 40 : index} : memref<4x32x40xf16, 3> -> !gpu.mma_matrix<16x16xf16, "BOp">
+    %28 = gpu.subgroup_mma_compute %23, %26, %arg1 : !gpu.mma_matrix<16x16xf16, "AOp">, !gpu.mma_matrix<16x16xf16, "BOp"> -> !gpu.mma_matrix<16x16xf16, "COp">
+    %29 = gpu.subgroup_mma_compute %24, %27, %28 : !gpu.mma_matrix<16x16xf16, "AOp">, !gpu.mma_matrix<16x16xf16, "BOp"> -> !gpu.mma_matrix<16x16xf16, "COp">
+    scf.yield %29 : !gpu.mma_matrix<16x16xf16, "COp">
+  }
+  %12 = affine.apply affine_map<()[s0, s1] -> (s0 * 16 + s1 * 32)>()[%2, %workgroup_id_y]
+  %13 = affine.apply affine_map<()[s0, s1] -> (s1 * 32 + (s0 floordiv 32) * 16)>()[%1, %workgroup_id_x]
+  gpu.subgroup_mma_store_matrix %11, %8[%12, %13] {leadDimension = 1024 : index} : !gpu.mma_matrix<16x16xf16, "COp">, memref<3456x1024xf16>
+  return
+}
+// CHECK-LABEL: func.func @_matmul_f16_f16_dispatch_0_fill_3456x1024
+// CHECK:  %[[CP_ID:.*]] = nvgpu.device_async_copy %[[GMEMPTR:.*]][%[[IDX:.*]]%[[IDY:.*]]], %[[SMEMPTR:.*]][%[[IDK_S:.*]]%[[IDX_S:.*]]%[[IDY_S:.*]]], 8, %[[PRED:.*]] : memref<3456x2048xf16> to memref<4x32x40xf16, 3>
\ No newline at end of file
diff --git a/compiler/src/iree/compiler/Codegen/LLVMGPU/Passes.cpp b/compiler/src/iree/compiler/Codegen/LLVMGPU/Passes.cpp
index e050776..6d31088 100644
--- a/compiler/src/iree/compiler/Codegen/LLVMGPU/Passes.cpp
+++ b/compiler/src/iree/compiler/Codegen/LLVMGPU/Passes.cpp
@@ -222,7 +222,7 @@
 
   // Pipeline memory operations.
   nestedModulePM.addNestedPass<func::FuncOp>(
-      createGPUPipeliningPass(pipelineDepth));
+      createGPUPipeliningPass(/*epiloguePeeling=*/true, pipelineDepth));
 }
 
 void addGPUTransposePassPipeline(OpPassManager &pm) {
diff --git a/compiler/src/iree/compiler/Codegen/Passes.h b/compiler/src/iree/compiler/Codegen/Passes.h
index b257681..689cc34 100644
--- a/compiler/src/iree/compiler/Codegen/Passes.h
+++ b/compiler/src/iree/compiler/Codegen/Passes.h
@@ -133,7 +133,7 @@
 
 /// Apply software pipelining.
 std::unique_ptr<OperationPass<func::FuncOp>> createGPUPipeliningPass(
-    unsigned depth = 1);
+    bool epiloguePeeling = true, unsigned depth = 1);
 
 /// Converts vector ops to gpu dialect.
 std::unique_ptr<OperationPass<func::FuncOp>> createWorkGroupSwizzle(
diff --git a/compiler/src/iree/compiler/Codegen/Passes.td b/compiler/src/iree/compiler/Codegen/Passes.td
index 80ac1da..b56de9e 100644
--- a/compiler/src/iree/compiler/Codegen/Passes.td
+++ b/compiler/src/iree/compiler/Codegen/Passes.td
@@ -163,6 +163,11 @@
 def GPUPipelining : Pass<"iree-gpu-pipelining", "func::FuncOp"> {
   let summary = "Pass to do software pipelining.";
   let constructor = "mlir::iree_compiler::createGPUPipeliningPass()";
+  let options = [
+    Option<"epiloguePeeling", "epilogue-peeling", "bool",
+            /*default=*/"true",
+           "Try to use un-peeling epilogue when false, peeled epilouge o.w.">,
+  ];
 }
 
 def WorkGroupSwizzle :
diff --git a/compiler/src/iree/compiler/Tools/init_mlir_dialects.h b/compiler/src/iree/compiler/Tools/init_mlir_dialects.h
index c044362..2562372 100644
--- a/compiler/src/iree/compiler/Tools/init_mlir_dialects.h
+++ b/compiler/src/iree/compiler/Tools/init_mlir_dialects.h
@@ -23,6 +23,7 @@
 #include "mlir/Dialect/MLProgram/IR/MLProgram.h"
 #include "mlir/Dialect/Math/IR/Math.h"
 #include "mlir/Dialect/MemRef/IR/MemRef.h"
+#include "mlir/Dialect/NVGPU/IR/NVGPUDialect.h"
 #include "mlir/Dialect/PDL/IR/PDL.h"
 #include "mlir/Dialect/PDLInterp/IR/PDLInterp.h"
 #include "mlir/Dialect/Quant/QuantOps.h"
@@ -48,6 +49,7 @@
                   cf::ControlFlowDialect,
                   bufferization::BufferizationDialect,
                   gpu::GPUDialect,
+                  nvgpu::NVGPUDialect,
                   LLVM::LLVMDialect,
                   linalg::LinalgDialect,
                   math::MathDialect,
diff --git a/tests/e2e/matmul/large_linalg_matmul.mlir b/tests/e2e/matmul/large_linalg_matmul.mlir
index 69ea6c7..ccd8d45 100644
--- a/tests/e2e/matmul/large_linalg_matmul.mlir
+++ b/tests/e2e/matmul/large_linalg_matmul.mlir
@@ -1,6 +1,10 @@
 // Test large aligned linalg matmul to make sure we go through the optimized
 // path for GPUs.
-func.func @large_aligned() {
+
+// Problem size      : 2048x512x1024
+// Input type        : F32
+// Accumulation type : F32
+func.func @matmul_2048x512x1024_f32_f32() {
   %lhs = util.unfoldable_constant dense<1.0> : tensor<2048x1024xf32>
   %rhs = util.unfoldable_constant dense<0.4> : tensor<1024x512xf32>
   %c0 = arith.constant 0.0 : f32
@@ -10,4 +14,19 @@
                     outs(%CC: tensor<2048x512xf32>) -> tensor<2048x512xf32>
   check.expect_almost_eq_const(%D, dense<409.596> : tensor<2048x512xf32>) : tensor<2048x512xf32>
   return
+}
+
+// Problem size      : 3456x1024x2048
+// Input type        : F16
+// Accumulation type : F16
+func.func @matmul_3456x1024x2048_f16_f16() {
+  %lhs = util.unfoldable_constant dense<1.00> : tensor<3456x2048xf16>
+  %rhs = util.unfoldable_constant dense<0.01> : tensor<2048x1024xf16>
+  %c0 = arith.constant 0.0 : f16
+  %init = linalg.init_tensor[3456, 1024] : tensor<3456x1024xf16>
+  %CC = linalg.fill ins(%c0 : f16) outs(%init : tensor<3456x1024xf16>) -> tensor<3456x1024xf16>
+  %D = linalg.matmul ins(%lhs, %rhs: tensor<3456x2048xf16>, tensor<2048x1024xf16>)
+                    outs(%CC: tensor<3456x1024xf16>) -> tensor<3456x1024xf16>
+  check.expect_almost_eq_const(%D, dense<20.2812> : tensor<3456x1024xf16>) : tensor<3456x1024xf16>
+  return
 }
\ No newline at end of file