Update matmul tensorcore strategy to properly trigger pipelining with… (#13194)

… mma.sync

---------

Co-authored-by: Quentin Colombet <quentin.colombet@gmail.com>
diff --git a/compiler/src/iree/compiler/Codegen/LLVMGPU/test/set_transform_strategy.mlir b/compiler/src/iree/compiler/Codegen/LLVMGPU/test/set_transform_strategy.mlir
index 5e1e1b2..42d1892 100644
--- a/compiler/src/iree/compiler/Codegen/LLVMGPU/test/set_transform_strategy.mlir
+++ b/compiler/src/iree/compiler/Codegen/LLVMGPU/test/set_transform_strategy.mlir
@@ -1,4 +1,21 @@
 // RUN: iree-opt %s --split-input-file --pass-pipeline="builtin.module(hal.executable(hal.executable.variant(iree-llvmgpu-lower-executable-target{test-lowering-configuration})))" --iree-codegen-llvmgpu-enable-transform-dialect-jit --iree-codegen-llvmgpu-enable-transform-dialect-matmul-tensorcore-strategy | FileCheck %s
+// Check that setting the command line options affects the transform
+// strategy as expected.
+// RUN: iree-opt %s --split-input-file --pass-pipeline="builtin.module(hal.executable(hal.executable.variant(iree-llvmgpu-lower-executable-target{test-lowering-configuration})))" --iree-codegen-llvmgpu-enable-transform-dialect-jit --iree-codegen-llvmgpu-enable-transform-dialect-matmul-tensorcore-strategy \
+// RUN: -td-matmul-strategy-blk-size-x=256 \
+// RUN: -td-matmul-strategy-blk-size-y=64 \
+// RUN: -td-matmul-strategy-blk-size-z=1 \
+// RUN: -td-matmul-strategy-reduc-size=8 \
+// RUN: -td-matmul-strategy-num-threads-x=32 \
+// RUN: -td-matmul-strategy-num-threads-y=4 \
+// RUN: -td-matmul-strategy-num-threads-z=1 \
+// RUN: -td-matmul-strategy-num-warps-x=1 \
+// RUN: -td-matmul-strategy-num-warps-y=4 \
+// RUN: -td-matmul-strategy-num-warps-z=1 \
+// RUN: -td-matmul-strategy-use-async-copies=true \
+// RUN: -td-matmul-strategy-use-mma-sync=true \
+// RUN: -td-matmul-strategy-pipeline-depth=5 \
+// RUN: | FileCheck --check-prefix WITH_OPTIONS %s
 
 hal.executable @matmul {
 hal.executable.variant public @cuda_nvptx_fb, target = <"cuda", "cuda-nvptx-fb", {target_arch = "sm_80"}> {
@@ -103,6 +120,60 @@
 // CHECK: transform.iree.apply_patterns %{{.*}} {canonicalization, cse, fold_memref_aliases, licm, tiling_canonicalization} : (!pdl.operation) -> ()
 
 
+// WITH_OPTIONS-LABEL: func @matmul
+
+// WITH_OPTIONS: transform.sequence  failures(propagate) {
+// WITH_OPTIONS: transform.iree.match_callback failures(propagate) "matmul"
+// Tile sizes are set by td-matmul-strategy-blk-size-XX.
+// WITH_OPTIONS: transform.iree.tile_to_forall_and_workgroup_count_region %{{.*}} num_threads [] tile_sizes [256, 64](mapping = [#gpu.block<y>, #gpu.block<x>])
+// WITH_OPTIONS: transform.structured.fuse_into_containing_op
+// The tiling is affected by td-matmul-strategy-reduc-size: 8.
+// WITH_OPTIONS: transform.structured.tile %{{.*}}[0, 0, 8]
+// WITH_OPTIONS: transform.structured.pad %{{.*}} {pack_paddings = [1, 1, 1], padding_dimensions = [0, 1, 2], padding_values = [0.000000e+00 : f32, 0.000000e+00 : f32, 0.000000e+00 : f32]}
+// WITH_OPTIONS: transform.structured.hoist_pad %{{.*}} by 1 loops
+// WITH_OPTIONS: transform.structured.insert_slice_to_copy %{{.*}} : (!pdl.operation) -> !pdl.operation
+// WITH_OPTIONS: transform.structured.tile_to_forall_op %{{.*}}   num_threads [64, 2] tile_sizes [](mapping = [#gpu.linear<x>, #gpu.linear<y>])
+// WITH_OPTIONS:   transform.scf.take_assumed_branch %{{.*}} take_else_branch : (!pdl.operation) -> ()
+// WITH_OPTIONS: transform.structured.tile_to_forall_op %{{.*}}   num_threads [8, 16] tile_sizes [](mapping = [#gpu.linear<y>, #gpu.linear<x>])
+// WITH_OPTIONS: transform.scf.take_assumed_branch %{{.*}} take_else_branch : (!pdl.operation) -> ()
+// WITH_OPTIONS: transform.structured.tile_to_forall_op %{{.*}}   num_threads [8, 16] tile_sizes [](mapping = [#gpu.linear<y>, #gpu.linear<x>])
+// WITH_OPTIONS: transform.structured.tile_to_forall_op %{{.*}}   num_threads [1, 4] tile_sizes [](mapping = [#gpu.warp<y>, #gpu.warp<x>])
+// WITH_OPTIONS: transform.structured.tile_to_forall_op %{{.*}}   num_threads [1, 4] tile_sizes [](mapping = [#gpu.warp<y>, #gpu.warp<x>])
+// WITH_OPTIONS: transform.structured.masked_vectorize %{{.*}} vector_sizes [4, 4]
+// WITH_OPTIONS: transform.structured.masked_vectorize %{{.*}} vector_sizes [1, 4]
+// WITH_OPTIONS: transform.structured.masked_vectorize %{{.*}} vector_sizes [32, 4]
+// WITH_OPTIONS: transform.vector.lower_masked_transfers %{{.*}} : (!pdl.operation) -> !pdl.operation
+// WITH_OPTIONS: transform.structured.vectorize %{{.*}}
+// WITH_OPTIONS: transform.iree.eliminate_empty_tensors %{{.*}}
+// WITH_OPTIONS: transform.iree.bufferize {target_gpu} %{{.*}} : (!pdl.operation) -> !pdl.operation
+// WITH_OPTIONS: transform.iree.erase_hal_descriptor_type_from_memref %{{.*}} : (!pdl.operation) -> ()
+// WITH_OPTIONS: transform.iree.forall_to_workgroup %{{.*}} : (!pdl.operation) -> ()
+// The workgroup dimensions are controlled by td-matmul-strategy-num-threads-XX.
+// The warp dimensions are controlled by td-matmul-strategy-num-warps-XX.
+// WITH_OPTIONS: transform.iree.map_nested_forall_to_gpu_threads %{{.*}} workgroup_dims = [32, 4, 1] warp_dims = [1, 4, 1] : (!pdl.operation) -> ()
+// WITH_OPTIONS: transform.iree.hoist_static_alloc %{{.*}} : (!pdl.operation) -> ()
+// WITH_OPTIONS: transform.iree.apply_patterns %{{.*}} {fold_memref_aliases} : (!pdl.operation) -> ()
+// WITH_OPTIONS: transform.iree.apply_patterns %{{.*}} {extract_address_computations} : (!pdl.operation) -> ()
+// The unroll attribute suffix should match td-matmul-strategy-use-mma-sync:
+// `mma_sync` when true, `wmma` when false.
+// WITH_OPTIONS: transform.iree.apply_patterns %{{.*}} {unroll_vectors_gpu_mma_sync} : (!pdl.operation) -> ()
+// WITH_OPTIONS: transform.structured.hoist_redundant_vector_transfers %{{.*}} : (!pdl.operation) -> !pdl.operation
+// WITH_OPTIONS: transform.iree.apply_buffer_optimizations %{{.*}} : (!pdl.operation) -> ()
+// The attribute should match td-matmul-strategy-use-mma-sync.
+// WITH_OPTIONS: transform.iree.vector.vector_to_mma_conversion %{{.*}} {use_mma_sync} : (!pdl.operation) -> ()
+// WITH_OPTIONS: transform.iree.apply_patterns %{{.*}} {fold_memref_aliases} : (!pdl.operation) -> ()
+// The multibuffer pass is only run when we set use-async-copies.
+// The factor should match td-matmul-strategy-pipeline-depth: 5.
+// WITH_OPTIONS: transform.memref.multibuffer %{{.*}} {factor = 5 : i64, skip_analysis} : (!transform.op<"memref.alloc">) -> !pdl.operation
+// WITH_OPTIONS: transform.vector.transfer_to_scf %{{.*}}   max_transfer_rank = 1 full_unroll = true : (!pdl.operation) -> !pdl.operation
+// The attribute should match td-matmul-strategy-use-mma-sync.
+// WITH_OPTIONS: transform.iree.create_async_groups %{{.*}} {use_mma_sync = true} : (!pdl.operation) -> ()
+// The depth should match td-matmul-strategy-pipeline-depth: 5.
+// WITH_OPTIONS: transform.iree.pipeline_shared_memory_copies %{{.*}} {depth = 5 : i64} : (!pdl.operation) -> !pdl.operation
+// WITH_OPTIONS: transform.vector.lower_masks %{{.*}} : (!pdl.operation) -> !pdl.operation
+// WITH_OPTIONS: transform.vector.materialize_masks %{{.*}} : (!pdl.operation) -> !pdl.operation
+// WITH_OPTIONS: transform.iree.apply_patterns %{{.*}} {canonicalization, cse, fold_memref_aliases, licm, tiling_canonicalization} : (!pdl.operation) -> ()
+
 // -----
 
 hal.executable @matmul {
diff --git a/compiler/src/iree/compiler/Codegen/TransformDialectStrategies/GPU/BUILD.bazel b/compiler/src/iree/compiler/Codegen/TransformDialectStrategies/GPU/BUILD.bazel
index ebf6b34..3061663 100644
--- a/compiler/src/iree/compiler/Codegen/TransformDialectStrategies/GPU/BUILD.bazel
+++ b/compiler/src/iree/compiler/Codegen/TransformDialectStrategies/GPU/BUILD.bazel
@@ -45,6 +45,7 @@
         "@llvm-project//mlir:LLVMDialect",
         "@llvm-project//mlir:MemRefDialect",
         "@llvm-project//mlir:MemRefTransformOps",
+        "@llvm-project//mlir:NVGPUDialect",
         "@llvm-project//mlir:PDLDialect",
         "@llvm-project//mlir:PDLInterpDialect",
         "@llvm-project//mlir:SCFDialect",
diff --git a/compiler/src/iree/compiler/Codegen/TransformDialectStrategies/GPU/CMakeLists.txt b/compiler/src/iree/compiler/Codegen/TransformDialectStrategies/GPU/CMakeLists.txt
index 5a85f60..16a7c0a 100644
--- a/compiler/src/iree/compiler/Codegen/TransformDialectStrategies/GPU/CMakeLists.txt
+++ b/compiler/src/iree/compiler/Codegen/TransformDialectStrategies/GPU/CMakeLists.txt
@@ -46,6 +46,7 @@
     MLIRLinalgTransformOps
     MLIRMemRefDialect
     MLIRMemRefTransformOps
+    MLIRNVGPUDialect
     MLIRPDLDialect
     MLIRPDLInterpDialect
     MLIRParser
diff --git a/compiler/src/iree/compiler/Codegen/TransformDialectStrategies/GPU/MatmulTensorCoreStrategy.cpp b/compiler/src/iree/compiler/Codegen/TransformDialectStrategies/GPU/MatmulTensorCoreStrategy.cpp
index 36eb155..2606da8 100644
--- a/compiler/src/iree/compiler/Codegen/TransformDialectStrategies/GPU/MatmulTensorCoreStrategy.cpp
+++ b/compiler/src/iree/compiler/Codegen/TransformDialectStrategies/GPU/MatmulTensorCoreStrategy.cpp
@@ -11,11 +11,14 @@
 #include "iree/compiler/Codegen/LLVMGPU/TransformExtensions/LLVMGPUExtensions.h"
 #include "iree/compiler/Codegen/TransformDialectStrategies/Common/Common.h"
 #include "iree/compiler/Codegen/TransformDialectStrategies/GPU/Common.h"
+#include "llvm/Support/CommandLine.h"
 #include "llvm/Support/Debug.h"
+#include "llvm/Support/raw_ostream.h"
 #include "mlir/Dialect/GPU/IR/GPUDialect.h"
 #include "mlir/Dialect/Linalg/TransformOps/LinalgTransformOps.h"
 #include "mlir/Dialect/MemRef/IR/MemRef.h"
 #include "mlir/Dialect/MemRef/TransformOps/MemRefTransformOps.h"
+#include "mlir/Dialect/NVGPU/IR/NVGPUDialect.h"
 #include "mlir/Dialect/SCF/IR/SCF.h"
 #include "mlir/Dialect/SCF/TransformOps/SCFTransformOps.h"
 #include "mlir/Dialect/Tensor/IR/Tensor.h"
@@ -63,6 +66,127 @@
 using iree_compiler::gpu::MatmulStrategy;
 using iree_compiler::gpu::scaleUpByBitWidth;
 
+/// Options to set the default values of the matmul strategy.
+
+/// Block tile size X, Y, Z.
+static llvm::cl::opt<int64_t> clBlockTileSizeX(
+    "td-matmul-strategy-blk-size-x",
+    llvm::cl::desc("block tile size for dim X (x,y,z) for the transform "
+                   "dialect matmul strategy"),
+    llvm::cl::init(128));
+static llvm::cl::opt<int64_t> clBlockTileSizeY(
+    "td-matmul-strategy-blk-size-y",
+    llvm::cl::desc("block tile size for dim Y (x,y,z) for the transform "
+                   "dialect matmul strategy"),
+    llvm::cl::init(128));
+static llvm::cl::opt<int64_t> clBlockTileSizeZ(
+    "td-matmul-strategy-blk-size-z",
+    llvm::cl::desc("block tile size for dim z (x,y,z) for the transform "
+                   "dialect matmul strategy"),
+    llvm::cl::init(1));
+
+static llvm::cl::opt<int64_t> clReductionTileSize(
+    "td-matmul-strategy-reduc-size",
+    llvm::cl::desc(
+        "reduction tile size for the transform dialect matmul strategy"),
+    llvm::cl::init(16));
+
+/// Number of threads X, Y, Z.
+static llvm::cl::opt<int64_t> clNumThreadsX(
+    "td-matmul-strategy-num-threads-x",
+    llvm::cl::desc("number of threads for dim X (x,y,z) for the transform "
+                   "dialect matmul strategy"),
+    llvm::cl::init(64));
+static llvm::cl::opt<int64_t> clNumThreadsY(
+    "td-matmul-strategy-num-threads-y",
+    llvm::cl::desc("number of threads for dim Y (x,y,z) for the transform "
+                   "dialect matmul strategy"),
+    llvm::cl::init(2));
+static llvm::cl::opt<int64_t> clNumThreadsZ(
+    "td-matmul-strategy-num-threads-z",
+    llvm::cl::desc("number of threads for dim z (x,y,z) for the transform "
+                   "dialect matmul strategy"),
+    llvm::cl::init(1));
+
+/// Number of warps X, Y, Z.
+static llvm::cl::opt<int64_t> clNumWarpsX(
+    "td-matmul-strategy-num-warps-x",
+    llvm::cl::desc("number of warps for dim X (x,y,z) for the transform "
+                   "dialect matmul strategy"),
+    llvm::cl::init(2));
+static llvm::cl::opt<int64_t> clNumWarpsY(
+    "td-matmul-strategy-num-warps-y",
+    llvm::cl::desc("number of warps for dim Y (x,y,z) for the transform "
+                   "dialect matmul strategy"),
+    llvm::cl::init(2));
+static llvm::cl::opt<int64_t> clNumWarpsZ(
+    "td-matmul-strategy-num-warps-z",
+    llvm::cl::desc("number of warps for dim z (x,y,z) for the transform "
+                   "dialect matmul strategy"),
+    llvm::cl::init(1));
+
+static llvm::cl::opt<bool> clUseAsyncCopies(
+    "td-matmul-strategy-use-async-copies",
+    llvm::cl::desc("use async copies in the transform dialect matmul strategy"),
+    llvm::cl::init(true));
+
+static llvm::cl::opt<bool> clUseMmaSync(
+    "td-matmul-strategy-use-mma-sync",
+    llvm::cl::desc("use mma sync for the transform dialect matmul strategy"),
+    llvm::cl::init(false));
+
+static llvm::cl::opt<int64_t> clPipelineDepth(
+    "td-matmul-strategy-pipeline-depth",
+    llvm::cl::desc("pipeline depth for the transform dialect matmul strategy"),
+    llvm::cl::init(3));
+
+void MatmulStrategy::initDefaultValues() {
+  blockTileSizes = {clBlockTileSizeX, clBlockTileSizeY, clBlockTileSizeZ};
+  reductionTileSize = clReductionTileSize;
+  numThreads = {clNumThreadsX, clNumThreadsY, clNumThreadsZ};
+  numWarps = {clNumWarpsX, clNumWarpsY, clNumWarpsZ};
+  useAsyncCopies = clUseAsyncCopies;
+  useMmaSync = clUseMmaSync;
+  pipelineDepth = clPipelineDepth;
+}
+
+LLVM_DUMP_METHOD void MatmulStrategy::dump() const { print(llvm::errs()); }
+
+void MatmulStrategy::print(llvm::raw_ostream &os) const {
+  os << "\n--- Matmul strategy ---\n";
+  os << "- block tile sizes: {";
+  bool isFirst = true;
+  for (int64_t blockTileSize : blockTileSizes) {
+    if (!isFirst) os << ", ";
+    os << blockTileSize;
+    isFirst = false;
+  }
+  os << "}\n";
+  os << "- reduction tile size: " << reductionTileSize << '\n';
+
+  os << "- number of threads: {";
+  isFirst = true;
+  for (int64_t numThreadsForDim : numThreads) {
+    if (!isFirst) os << ", ";
+    os << numThreadsForDim;
+    isFirst = false;
+  }
+  os << "}\n";
+
+  os << "- number of warps: {";
+  isFirst = true;
+  for (int64_t numWarpsForDim : numWarps) {
+    if (!isFirst) os << ", ";
+    os << numWarpsForDim;
+    isFirst = false;
+  }
+  os << "}\n";
+
+  os << "- use async copies: " << useAsyncCopies << '\n';
+  os << "- use mma sync: " << useMmaSync << '\n';
+  os << "- pipeline depth: " << pipelineDepth << '\n';
+}
+
 /// Build the transform IR to pad a matmul op `matmulOpH`.
 // TODO: Less hardcoded, more generalization, extract information from strategy.
 static Value buildPadMatmul(ImplicitLocOpBuilder &b, Value matmulOpH,
@@ -333,11 +457,17 @@
 static void buildPipelineSharedMemoryCopies(ImplicitLocOpBuilder &b,
                                             Value funcH,
                                             const MatmulStrategy &strategy) {
-  Value subgroupMmaOpH = b.create<transform::MatchOp>(
-      funcH, mlir::gpu::SubgroupMmaComputeOp::getOperationName());
+  Value computeOpH;
+  if (strategy.useMmaSync) {
+    computeOpH = b.create<transform::MatchOp>(
+        funcH, mlir::nvgpu::MmaSyncOp::getOperationName());
+  } else {
+    computeOpH = b.create<transform::MatchOp>(
+        funcH, mlir::gpu::SubgroupMmaComputeOp::getOperationName());
+  }
   // TODO: Better builder.
   Value forOpH = b.create<transform::GetParentForOp>(
-      pdl::OperationType::get(b.getContext()), subgroupMmaOpH);
+      pdl::OperationType::get(b.getContext()), computeOpH);
   // TODO: Better builder instead of setting post-hoc.
   auto pipelineOp = b.create<
       iree_compiler::IREE::transform_dialect::PipelineSharedMemoryCopiesOp>(
@@ -394,6 +524,10 @@
 
 void iree_compiler::gpu::buildMatmulTensorCoreStrategy(
     ImplicitLocOpBuilder &b, Value variantH, const MatmulStrategy &strategy) {
+  assert(strategy.totalNumThreads() ==
+             strategy.totalNumWarps() * kCudaWarpSize &&
+         "Number of threads specified by warps must match total number of "
+         "threads");
   // Step 1. Apply block-level part of the strategy, keeps everything fused.
   auto [fillH, matmulH, maybeTiledTrailingHBlock, forall] =
       buildMatmulStrategyBlockDistribution(b, variantH, strategy);
diff --git a/compiler/src/iree/compiler/Codegen/TransformDialectStrategies/GPU/MatmulTensorCoreStrategy.h b/compiler/src/iree/compiler/Codegen/TransformDialectStrategies/GPU/MatmulTensorCoreStrategy.h
index 51d579c..c6ca2cf 100644
--- a/compiler/src/iree/compiler/Codegen/TransformDialectStrategies/GPU/MatmulTensorCoreStrategy.h
+++ b/compiler/src/iree/compiler/Codegen/TransformDialectStrategies/GPU/MatmulTensorCoreStrategy.h
@@ -10,6 +10,10 @@
 #include "iree-dialects/Transforms/TransformMatchers.h"
 #include "mlir/Dialect/GPU/IR/GPUDialect.h"
 
+namespace llvm {
+class raw_ostream;
+}
+
 namespace mlir {
 namespace iree_compiler {
 namespace gpu {
@@ -68,20 +72,24 @@
 struct MatmulStrategy : StrategyBase {
   MatmulStrategy(MLIRContext *context,
                  const transform_ext::MatchedMatmulCaptures &captures)
-      : StrategyBase(context), captures(captures) {}
+      : StrategyBase(context), captures(captures) {
+    initDefaultValues();
+  }
 
   /// Constructor quantities.
   transform_ext::MatchedMatmulCaptures captures;
 
   /// Tile sizes for the workgroup / determines grid size for all known
-  /// reduction strategies.
-  SmallVector<int64_t> blockTileSizes = {128, 128, 1};
-  int64_t reductionTileSize = 16;
-  SmallVector<int64_t> numThreads = {64, 2, 1};
-  SmallVector<int64_t> numWarps = {2, 2, 1};
-  bool useAsyncCopies = true;
-  bool useMmaSync = false;
-  int64_t pipelineDepth = 3;
+  /// reduction strategies. The initial values are set by initDefaultValues();
+  SmallVector<int64_t> blockTileSizes;
+  int64_t reductionTileSize;
+  SmallVector<int64_t> numThreads;
+  SmallVector<int64_t> numWarps;
+  bool useAsyncCopies;
+  bool useMmaSync;
+  int64_t pipelineDepth;
+
+  void initDefaultValues();
 
   int64_t m() const {
     assert(captures.matmulOpSizes.size() == 3 && "need 3 sizes");
@@ -100,6 +108,11 @@
     for (auto v : numThreads) res *= v;
     return res;
   }
+  int64_t totalNumWarps() const {
+    int64_t res = 1;
+    for (auto v : numWarps) res *= v;
+    return res;
+  }
 
   int64_t lhsCopyVectorSize() const {
     if (k() % 4 == 0) return 4;
@@ -185,6 +198,9 @@
                        /*tileSizes=*/{},
                        /*threadMapping=*/{warpY(), warpX()}};
   }
+
+  void print(llvm::raw_ostream &os) const;
+  LLVM_DUMP_METHOD void dump() const;
 };
 
 void buildMatmulTensorCoreStrategy(ImplicitLocOpBuilder &b, Value variantH,