Extend MapNestedForallToGpuThreadsOp to support distribution to warpId (#12272)

Distribute scf.forall ops marked with the warpId (gpu.warp) mapping.
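
For example, with workgroup_size = [256, 1, 1] (8 warps of 32 threads), a
warp-mapped forall is rewritten by replacing its induction variable with
threadIdx.x / warpSize, roughly as in the new lit test below (value names
here are illustrative):

```mlir
// Before: bufferized scf.forall distributed over warps.
scf.forall (%arg0) in (%c8) {
  vector.transfer_write %cst, %subview[%arg0] {in_bounds = [true]}
    : vector<1xf16>, memref<1xf16, strided<[1], offset: ?>>
} {mapping = [#gpu.warp<x>]}

// After: the warp id is derived from threadIdx.x.
%tid_x = gpu.thread_id x
%c32 = arith.constant 32 : index
%warp_id = arith.divui %tid_x, %c32 : index
vector.transfer_write %cst, %subview[%warp_id] {in_bounds = [true]}
  : vector<1xf16>, memref<1xf16, strided<[1], offset: ?>>
```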

---------

Co-authored-by: Nicolas Vasilache <nicolas.vasilache@gmail.com>
diff --git a/compiler/src/iree/compiler/Codegen/LLVMGPU/TransformExtensions/LLVMGPUExtensions.cpp b/compiler/src/iree/compiler/Codegen/LLVMGPU/TransformExtensions/LLVMGPUExtensions.cpp
index 7141f1e..3cb94ac 100644
--- a/compiler/src/iree/compiler/Codegen/LLVMGPU/TransformExtensions/LLVMGPUExtensions.cpp
+++ b/compiler/src/iree/compiler/Codegen/LLVMGPU/TransformExtensions/LLVMGPUExtensions.cpp
@@ -27,6 +27,7 @@
 #include "mlir/Dialect/Vector/Transforms/VectorDistribution.h"
 #include "mlir/Dialect/Vector/Transforms/VectorRewritePatterns.h"
 #include "mlir/IR/Region.h"
+#include "mlir/IR/Visitors.h"
 #include "mlir/Transforms/GreedyPatternRewriteDriver.h"
 
 using namespace mlir;
@@ -120,6 +121,57 @@
     // TODO: should really be: exportOp.setWorkgroupSizeAttr(newAttr);
     exportOp->setAttr(exportOp.getWorkgroupSizeAttrName(), newAttr);
   }
+
+  // Map warp ids, but only when the x workgroup size is a multiple of the
+  // warp size. TODO: more advanced mechanism to linearize/delinearize the
+  // threadIds to warps.
+  SmallVector<DeviceMappingAttrInterface> warpMappingAttributes = {
+      gpu::GPUWarpMappingAttr::get(ctx, gpu::Warps::DimX),
+      gpu::GPUWarpMappingAttr::get(ctx, gpu::Warps::DimY),
+      gpu::GPUWarpMappingAttr::get(ctx, gpu::Warps::DimZ)};
+  if (diag.succeeded() && (workgroupSize[0] % kWarpSize == 0)) {
+    auto warpIdGenerator = [](RewriterBase &rewriter, scf::ForallOp forallOp,
+                              SmallVectorImpl<Value> &warpIds) {
+      Location loc = forallOp.getLoc();
+      IndexType indexType = rewriter.getIndexType();
+      Value threadIdX =
+          rewriter.create<gpu::ThreadIdOp>(loc, indexType, gpu::Dimension::x);
+      Value cstWarpSize =
+          rewriter.create<arith::ConstantIndexOp>(loc, kWarpSize);
+      Value warpIdX =
+          rewriter.create<arith::DivUIOp>(loc, threadIdX, cstWarpSize);
+      warpIds.assign(
+          {warpIdX,
+           rewriter.create<gpu::ThreadIdOp>(loc, indexType, gpu::Dimension::y),
+           rewriter.create<gpu::ThreadIdOp>(loc, indexType,
+                                            gpu::Dimension::z)});
+    };
+    SmallVector<int64_t> numWarps = {workgroupSize[0] / kWarpSize,
+                                     workgroupSize[1], workgroupSize[2]};
+    diag = mlir::transform::gpu::mapNestedForeachToThreadsImpl(
+        rewriter, target, numWarps, warpIdGenerator, true, transformOp,
+        warpMappingAttributes);
+  }
+
+  auto walkResult = target->walk([&warpMappingAttributes](
+                                     scf::ForallOp forallOp) -> WalkResult {
+    auto maybeMapping = forallOp.getMapping();
+    if (!maybeMapping) return WalkResult::advance();
+    for (Attribute attr : *maybeMapping) {
+      for (const auto &warpAttr : warpMappingAttributes) {
+        if (attr == warpAttr) {
+          forallOp->emitOpError("mapping failed: is the x workgroup size a "
+                                "multiple of the warp size?");
+          return WalkResult::interrupt();
+        }
+      }
+    }
+    return WalkResult::advance();
+  });
+  if (walkResult.wasInterrupted()) {
+    return emitDefaultDefiniteFailure(target);
+  }
+
   results.push_back(target);
   return diag;
 }
diff --git a/compiler/src/iree/compiler/Codegen/LLVMGPU/TransformExtensions/LLVMGPUExtensionsOps.td b/compiler/src/iree/compiler/Codegen/LLVMGPU/TransformExtensions/LLVMGPUExtensionsOps.td
index bb1b4fb..33584a9 100644
--- a/compiler/src/iree/compiler/Codegen/LLVMGPU/TransformExtensions/LLVMGPUExtensionsOps.td
+++ b/compiler/src/iree/compiler/Codegen/LLVMGPU/TransformExtensions/LLVMGPUExtensionsOps.td
@@ -23,6 +23,9 @@
   let description = [{
     Target the whole hal.executable_variant op and rewrite all scf.forall
     to distributed gpu.thread_id and translation_info attribute.
+
+    This op handles all scf.forall ops that use a gpu.thread or gpu.warp
+    mapping.
 
     The mapping of threads to gpu.thread_id is currently one-to-one and in order.
     Only **bufferized** scf.forall are currently supported.
diff --git a/compiler/src/iree/compiler/Codegen/LLVMGPU/test/BUILD b/compiler/src/iree/compiler/Codegen/LLVMGPU/test/BUILD
index 831ad64..616ad87 100644
--- a/compiler/src/iree/compiler/Codegen/LLVMGPU/test/BUILD
+++ b/compiler/src/iree/compiler/Codegen/LLVMGPU/test/BUILD
@@ -42,6 +42,7 @@
             "transform_dialect_vector_distribution.mlir",
             "transform_dialect_bufferize.mlir",
             "transform_dialect_promote_operands.mlir",
+            "transform_distribute_forall.mlir",
             "transform_gpu_pipelining.mlir",
             "transform_vector_to_mma.mlir",
             "transpose_pipeline_test.mlir",
diff --git a/compiler/src/iree/compiler/Codegen/LLVMGPU/test/CMakeLists.txt b/compiler/src/iree/compiler/Codegen/LLVMGPU/test/CMakeLists.txt
index da82228..b63550c 100644
--- a/compiler/src/iree/compiler/Codegen/LLVMGPU/test/CMakeLists.txt
+++ b/compiler/src/iree/compiler/Codegen/LLVMGPU/test/CMakeLists.txt
@@ -38,6 +38,7 @@
     "transform_dialect_bufferize.mlir"
     "transform_dialect_promote_operands.mlir"
     "transform_dialect_vector_distribution.mlir"
+    "transform_distribute_forall.mlir"
     "transform_gpu_pipelining.mlir"
     "transform_vector_to_mma.mlir"
     "transpose_pipeline_test.mlir"
diff --git a/compiler/src/iree/compiler/Codegen/LLVMGPU/test/transform_distribute_forall.mlir b/compiler/src/iree/compiler/Codegen/LLVMGPU/test/transform_distribute_forall.mlir
new file mode 100644
index 0000000..f77cda3
--- /dev/null
+++ b/compiler/src/iree/compiler/Codegen/LLVMGPU/test/transform_distribute_forall.mlir
@@ -0,0 +1,52 @@
+// RUN: iree-opt %s --pass-pipeline="builtin.module(hal.executable(iree-transform-dialect-interpreter,transform-dialect-drop-schedule))" | FileCheck %s
+
+#executable_target_cuda_nvptx_fb = #hal.executable.target<"cuda", "cuda-nvptx-fb", {target_arch = "sm_35"}>
+#map = affine_map<()[s0] -> (s0 * 8)>
+#map1 = affine_map<(d0) -> (d0)>
+#pipeline_layout = #hal.pipeline.layout<push_constants = 0, sets = [<0, bindings = [<0, storage_buffer, ReadOnly>, <1, storage_buffer>]>]>
+#translation = #iree_codegen.translation_info<TransformDialectCodegen>
+hal.executable private @distribute {
+  hal.executable.variant public @cuda_nvptx_fb, target = #executable_target_cuda_nvptx_fb {
+    hal.executable.export public @distribute ordinal(0) layout(#pipeline_layout) attributes {translation_info = #translation} {
+    ^bb0(%arg0: !hal.device, %arg1: index, %arg2: index):
+      %c1 = arith.constant 1 : index
+      hal.return %arg1, %c1, %c1 : index, index, index
+    }
+    builtin.module {
+// CHECK-LABEL: func.func @distribute
+      func.func @distribute() {
+        %cst_0 = arith.constant dense<0.000000e+00> : vector<1xf16>
+        %c250 = arith.constant 250 : index
+        %c8 = arith.constant 8 : index
+        %c0 = arith.constant 0 : index
+        %1 = hal.interface.binding.subspan set(0) binding(1) type(storage_buffer) alignment(64) offset(%c0) : memref<2xf16>
+        memref.assume_alignment %1, 64 : memref<2xf16>
+        %workgroup_id_x = hal.interface.workgroup.id[0] : index
+        %subview = memref.subview %1[%workgroup_id_x] [1] [1] : memref<2xf16> to memref<1xf16, strided<[1], offset: ?>>
+// CHECK: %[[C32:.+]] = arith.constant 32 : index
+// CHECK: %[[TX:.+]] = gpu.thread_id  x
+// CHECK: %[[COND:.*]] = arith.cmpi ult
+// CHECK: scf.if %[[COND]] {
+// CHECK:   vector.transfer_write %{{.*}}, %{{.*}}[%[[TX]]] {in_bounds = [true]} : vector<1xf16>, memref<1xf16, strided<[1], offset: ?>>
+        scf.forall (%arg0) in (%c250) {
+          vector.transfer_write %cst_0, %subview[%arg0]
+          {in_bounds = [true]} : vector<1xf16>, memref<1xf16, strided<[1], offset: ?>>
+        } {mapping = [#gpu.thread<x>]}
+// CHECK: %[[WX:.+]] = arith.divui %[[TX]], %[[C32]] : index
+// CHECK: vector.transfer_write %{{.*}}, %{{.*}}[%[[WX]]] {in_bounds = [true]} : vector<1xf16>, memref<1xf16, strided<[1], offset: ?>>
+        scf.forall (%arg0) in (%c8) {
+          vector.transfer_write %cst_0, %subview[%arg0]
+          {in_bounds = [true]} : vector<1xf16>, memref<1xf16, strided<[1], offset: ?>>
+        } {mapping = [#gpu.warp<x>]}
+        return
+      }
+      module {
+        transform.structured.canonicalized_sequence failures(propagate) {
+        ^bb0(%arg0: !pdl.operation):
+          %17 = transform.structured.match ops{["func.func"]} in %arg0 : (!pdl.operation) -> !pdl.operation
+          %18 = transform.iree.map_nested_forall_to_gpu_threads %17 {workgroup_size = [256, 1, 1]}
+        }
+      }
+    }
+  }
+}