Extend MapNestedForallToGpuThreadsOp to support distribution to warpId (#12272)
Distribute scf.forall ops that are marked with a warpId (gpu.warp) mapping.
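A minimal sketch of the intended rewrite (identifiers are illustrative, assuming the default warp size of 32 and a workgroup of 256 threads along x): an scf.forall carrying a #gpu.warp<x> mapping has its induction variable replaced by a warp id delinearized from threadIdx.x, mirroring the new lit test below.

```mlir
// Before: a bufferized scf.forall distributed over warps.
scf.forall (%warp) in (%c8) {
  vector.transfer_write %v, %buf[%warp] {in_bounds = [true]}
      : vector<1xf16>, memref<8xf16>
} {mapping = [#gpu.warp<x>]}

// After: the warp id is derived from threadIdx.x / warpSize.
%tid_x = gpu.thread_id x
%c32 = arith.constant 32 : index
%warp_x = arith.divui %tid_x, %c32 : index
vector.transfer_write %v, %buf[%warp_x] {in_bounds = [true]}
    : vector<1xf16>, memref<8xf16>
```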
---------
Co-authored-by: Nicolas Vasilache <nicolas.vasilache@gmail.com>
diff --git a/compiler/src/iree/compiler/Codegen/LLVMGPU/TransformExtensions/LLVMGPUExtensions.cpp b/compiler/src/iree/compiler/Codegen/LLVMGPU/TransformExtensions/LLVMGPUExtensions.cpp
index 7141f1e..3cb94ac 100644
--- a/compiler/src/iree/compiler/Codegen/LLVMGPU/TransformExtensions/LLVMGPUExtensions.cpp
+++ b/compiler/src/iree/compiler/Codegen/LLVMGPU/TransformExtensions/LLVMGPUExtensions.cpp
@@ -27,6 +27,7 @@
#include "mlir/Dialect/Vector/Transforms/VectorDistribution.h"
#include "mlir/Dialect/Vector/Transforms/VectorRewritePatterns.h"
#include "mlir/IR/Region.h"
+#include "mlir/IR/Visitors.h"
#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
using namespace mlir;
@@ -120,6 +121,57 @@
// TODO: should really be: exportOp.setWorkgroupSizeAttr(newAttr);
exportOp->setAttr(exportOp.getWorkgroupSizeAttrName(), newAttr);
}
+
+ // Map warp IDs, but only if the workgroup size along x is a multiple of
+ // the warp size.
+ // TODO: more advanced mechanism to linearize/delinearize the threadIds to
+ // warps.
+ SmallVector<DeviceMappingAttrInterface> warpMappingAttributes = {
+ gpu::GPUWarpMappingAttr::get(ctx, gpu::Warps::DimX),
+ gpu::GPUWarpMappingAttr::get(ctx, gpu::Warps::DimY),
+ gpu::GPUWarpMappingAttr::get(ctx, gpu::Warps::DimZ)};
+ if (diag.succeeded() && (workgroupSize[0] % kWarpSize == 0)) {
+ auto warpIdGenerator = [](RewriterBase &rewriter, scf::ForallOp forallOp,
+ SmallVectorImpl<Value> &warpIds) {
+ Location loc = forallOp.getLoc();
+ IndexType indexType = rewriter.getIndexType();
+ Value threadIdX =
+ rewriter.create<gpu::ThreadIdOp>(loc, indexType, gpu::Dimension::x);
+ Value cstWarpSize =
+ rewriter.create<arith::ConstantIndexOp>(loc, kWarpSize);
+ Value warpIdX =
+ rewriter.create<arith::DivUIOp>(loc, threadIdX, cstWarpSize);
+ warpIds.assign(
+ {warpIdX,
+ rewriter.create<gpu::ThreadIdOp>(loc, indexType, gpu::Dimension::y),
+ rewriter.create<gpu::ThreadIdOp>(loc, indexType,
+ gpu::Dimension::z)});
+ };
+ SmallVector<int64_t> numWarps = {workgroupSize[0] / kWarpSize,
+ workgroupSize[1], workgroupSize[2]};
+ diag = mlir::transform::gpu::mapNestedForeachToThreadsImpl(
+ rewriter, target, numWarps, warpIdGenerator, true, transformOp,
+ warpMappingAttributes);
+ }
+
+ auto walkResult = target->walk([&warpMappingAttributes](
+ scf::ForallOp forallOp) -> WalkResult {
+ auto maybeMapping = forallOp.getMapping();
+ if (!maybeMapping) return WalkResult::advance();
+ for (Attribute attr : *maybeMapping) {
+ for (const auto &warpAttr : warpMappingAttributes) {
+ if (attr == warpAttr) {
+ forallOp->emitOpError(
+ "mapping failed: is the workgroup size along x a multiple of the warp size?");
+ return WalkResult::interrupt();
+ }
+ }
+ }
+ return WalkResult::advance();
+ });
+ if (walkResult.wasInterrupted()) {
+ return emitDefaultDefiniteFailure(target);
+ }
+
results.push_back(target);
return diag;
}
diff --git a/compiler/src/iree/compiler/Codegen/LLVMGPU/TransformExtensions/LLVMGPUExtensionsOps.td b/compiler/src/iree/compiler/Codegen/LLVMGPU/TransformExtensions/LLVMGPUExtensionsOps.td
index bb1b4fb..33584a9 100644
--- a/compiler/src/iree/compiler/Codegen/LLVMGPU/TransformExtensions/LLVMGPUExtensionsOps.td
+++ b/compiler/src/iree/compiler/Codegen/LLVMGPU/TransformExtensions/LLVMGPUExtensionsOps.td
@@ -23,6 +23,9 @@
let description = [{
Target the whole hal.executable_variant op and rewrite all scf.forall
to distributed gpu.thread_id and translation_info attribute.
+
+ This op handles all scf.forall ops that use a gpu.thread or gpu.warp
+ mapping.
The mapping of threads to gpu.thread_id is currently one-to-one and in order.
Only **bufferized** scf.forall are currently supported.
diff --git a/compiler/src/iree/compiler/Codegen/LLVMGPU/test/BUILD b/compiler/src/iree/compiler/Codegen/LLVMGPU/test/BUILD
index 831ad64..616ad87 100644
--- a/compiler/src/iree/compiler/Codegen/LLVMGPU/test/BUILD
+++ b/compiler/src/iree/compiler/Codegen/LLVMGPU/test/BUILD
@@ -42,6 +42,7 @@
"transform_dialect_vector_distribution.mlir",
"transform_dialect_bufferize.mlir",
"transform_dialect_promote_operands.mlir",
+ "transform_distribute_forall.mlir",
"transform_gpu_pipelining.mlir",
"transform_vector_to_mma.mlir",
"transpose_pipeline_test.mlir",
diff --git a/compiler/src/iree/compiler/Codegen/LLVMGPU/test/CMakeLists.txt b/compiler/src/iree/compiler/Codegen/LLVMGPU/test/CMakeLists.txt
index da82228..b63550c 100644
--- a/compiler/src/iree/compiler/Codegen/LLVMGPU/test/CMakeLists.txt
+++ b/compiler/src/iree/compiler/Codegen/LLVMGPU/test/CMakeLists.txt
@@ -38,6 +38,7 @@
"transform_dialect_bufferize.mlir"
"transform_dialect_promote_operands.mlir"
"transform_dialect_vector_distribution.mlir"
+ "transform_distribute_forall.mlir"
"transform_gpu_pipelining.mlir"
"transform_vector_to_mma.mlir"
"transpose_pipeline_test.mlir"
diff --git a/compiler/src/iree/compiler/Codegen/LLVMGPU/test/transform_distribute_forall.mlir b/compiler/src/iree/compiler/Codegen/LLVMGPU/test/transform_distribute_forall.mlir
new file mode 100644
index 0000000..f77cda3
--- /dev/null
+++ b/compiler/src/iree/compiler/Codegen/LLVMGPU/test/transform_distribute_forall.mlir
@@ -0,0 +1,52 @@
+// RUN: iree-opt %s --pass-pipeline="builtin.module(hal.executable(iree-transform-dialect-interpreter,transform-dialect-drop-schedule))" | FileCheck %s
+
+#executable_target_cuda_nvptx_fb = #hal.executable.target<"cuda", "cuda-nvptx-fb", {target_arch = "sm_35"}>
+#map = affine_map<()[s0] -> (s0 * 8)>
+#map1 = affine_map<(d0) -> (d0)>
+#pipeline_layout = #hal.pipeline.layout<push_constants = 0, sets = [<0, bindings = [<0, storage_buffer, ReadOnly>, <1, storage_buffer>]>]>
+#translation = #iree_codegen.translation_info<TransformDialectCodegen>
+hal.executable private @distribute {
+ hal.executable.variant public @cuda_nvptx_fb, target = #executable_target_cuda_nvptx_fb {
+ hal.executable.export public @distribute ordinal(0) layout(#pipeline_layout) attributes {translation_info = #translation} {
+ ^bb0(%arg0: !hal.device, %arg1: index, %arg2: index):
+ %c1 = arith.constant 1 : index
+ hal.return %arg1, %c1, %c1 : index, index, index
+ }
+ builtin.module {
+// CHECK-LABEL: func.func @distribute
+ func.func @distribute() {
+ %cst_0 = arith.constant dense<0.000000e+00> : vector<1xf16>
+ %c250 = arith.constant 250 : index
+ %c256 = arith.constant 256 : index
+ %c0 = arith.constant 0 : index
+ %1 = hal.interface.binding.subspan set(0) binding(1) type(storage_buffer) alignment(64) offset(%c0) : memref<2xf16>
+ memref.assume_alignment %1, 64 : memref<2xf16>
+ %workgroup_id_x = hal.interface.workgroup.id[0] : index
+ %subview = memref.subview %1[%workgroup_id_x] [1] [1] : memref<2xf16> to memref<1xf16, strided<[1], offset: ?>>
+// CHECK: %[[C32:.+]] = arith.constant 32 : index
+// CHECK: %[[TX:.+]] = gpu.thread_id x
+// CHECK: %[[COND:.*]] = arith.cmpi ult
+// CHECK: scf.if %[[COND]] {
+// CHECK: vector.transfer_write %{{.*}}, %{{.*}}[%[[TX]]] {in_bounds = [true]} : vector<1xf16>, memref<1xf16, strided<[1], offset: ?>>
+ scf.forall (%arg0) in (%c250) {
+ vector.transfer_write %cst_0, %subview[%arg0]
+ {in_bounds = [true]} : vector<1xf16>, memref<1xf16, strided<[1], offset: ?>>
+ } {mapping = [#gpu.thread<x>]}
+// CHECK: %[[WX:.+]] = arith.divui %[[TX]], %[[C32]] : index
+// CHECK: vector.transfer_write %{{.*}}, %{{.*}}[%[[WX]]] {in_bounds = [true]} : vector<1xf16>, memref<1xf16, strided<[1], offset: ?>>
+ scf.forall (%arg0) in (%c256) {
+ vector.transfer_write %cst_0, %subview[%arg0]
+ {in_bounds = [true]} : vector<1xf16>, memref<1xf16, strided<[1], offset: ?>>
+ } {mapping = [#gpu.warp<x>]}
+ return
+ }
+ module {
+ transform.structured.canonicalized_sequence failures(propagate) {
+ ^bb0(%arg0: !pdl.operation):
+ %17 = transform.structured.match ops{["func.func"]} in %arg0 : (!pdl.operation) -> !pdl.operation
+ %18 = transform.iree.map_nested_forall_to_gpu_threads %17 {workgroup_size = [256, 1, 1]}
+ }
+ }
+ }
+}
+}