[Codegen][GPU] Loosen dim mapping restrictions on forall fusion (#17612)
The `FuseForalls` pattern is currently restricted to forall loops with
identical dim mappings, but this is stricter than necessary. This PR
loosens the restriction to only require that the two forall loops have
the same kind of mapping (e.g., all thread or all warp mappings) and an
equivalent first dim mapping; foralls with empty mappings are rejected.
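
For example, the two `scf.forall` loops sketched below (abbreviated from
the new `forall_fuse_then_hoist_mixed_mappings` test case added in this
PR) both use thread mappings and share the same first dim mapping, so
they can now be fused even though one maps three dims and the other
maps two:

```mlir
%9 = scf.forall (%arg2, %arg3, %arg4) in (1, 64, 1) ... {
  ...
} {mapping = [#gpu.thread<linear_dim_0>, #gpu.thread<linear_dim_1>, #gpu.thread<linear_dim_2>]}
%11 = scf.forall (%arg2, %arg3) in (8, 8) ... {
  ...  // consumes a slice of %9
} {mapping = [#gpu.thread<linear_dim_0>, #gpu.thread<linear_dim_1>]}
```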
---------
Signed-off-by: Max Dawkins <max.dawkins@gmail.com>
diff --git a/compiler/src/iree/compiler/Codegen/Dialect/GPU/Transforms/Transforms.cpp b/compiler/src/iree/compiler/Codegen/Dialect/GPU/Transforms/Transforms.cpp
index 6955193..3afa4ea 100644
--- a/compiler/src/iree/compiler/Codegen/Dialect/GPU/Transforms/Transforms.cpp
+++ b/compiler/src/iree/compiler/Codegen/Dialect/GPU/Transforms/Transforms.cpp
@@ -56,27 +56,27 @@
 static FailureOr<SmallVector<scf::ForallOp>>
 getEquivalentMappingConsumerLoopNest(scf::ForallOp producer,
                                      scf::ForallOp consumer) {
-
-  auto checkMappingTypes = [&](ArrayAttr array) {
-    return llvm::all_of(array.getValue(),
-                        llvm::IsaPred<gpu::GPUThreadMappingAttr>) ||
-           llvm::all_of(array.getValue(),
-                        llvm::IsaPred<gpu::GPUWarpMappingAttr>);
+  auto checkMappingTypes = [&](ArrayRef<Attribute> array) {
+    return llvm::all_of(array, llvm::IsaPred<gpu::GPUThreadMappingAttr>) ||
+           llvm::all_of(array, llvm::IsaPred<gpu::GPUWarpMappingAttr>);
   };
-  ArrayAttr producerMapping = producer.getMappingAttr();
-  ArrayAttr consumerMapping = consumer.getMappingAttr();
+  ArrayRef<Attribute> producerMapping = producer.getMappingAttr().getValue();
+  ArrayRef<Attribute> consumerMapping = consumer.getMappingAttr().getValue();
-  if (producerMapping == consumerMapping &&
+  if (producerMapping.empty() || consumerMapping.empty()) {
+    return failure();
+  }
+
+  if (producerMapping.front() == consumerMapping.front() &&
       checkMappingTypes(producerMapping) &&
       checkMappingTypes(consumerMapping)) {
     return SmallVector<scf::ForallOp>({consumer});
   }
-  if (!llvm::all_of(producerMapping.getValue(),
+  if (!llvm::all_of(producerMapping,
                     llvm::IsaPred<gpu::GPUThreadMappingAttr>) ||
-      !llvm::all_of(consumerMapping.getValue(),
-                    llvm::IsaPred<IREE::GPU::LaneIdAttr>)) {
+      !llvm::all_of(consumerMapping, llvm::IsaPred<IREE::GPU::LaneIdAttr>)) {
     return failure();
   }
   auto outerWarpLoop = consumer->getParentOfType<scf::ForallOp>();
diff --git a/compiler/src/iree/compiler/Codegen/Dialect/GPU/Transforms/test/fuse_and_hoist_forall.mlir b/compiler/src/iree/compiler/Codegen/Dialect/GPU/Transforms/test/fuse_and_hoist_forall.mlir
index 43863fe..c22dd40 100644
--- a/compiler/src/iree/compiler/Codegen/Dialect/GPU/Transforms/test/fuse_and_hoist_forall.mlir
+++ b/compiler/src/iree/compiler/Codegen/Dialect/GPU/Transforms/test/fuse_and_hoist_forall.mlir
@@ -74,6 +74,64 @@
#map = affine_map<(d0) -> (d0 * 2)>
#map1 = affine_map<(d0) -> (d0 * 4)>
#map2 = affine_map<(d0)[s0] -> (d0 * 4 + s0)>
+#map3 = affine_map<(d0) -> (d0 * 16)>
+module {
+  func.func @forall_fuse_then_hoist_mixed_mappings() {
+    %c4 = arith.constant 4 : index
+    %c128 = arith.constant 128 : index
+    %c0 = arith.constant 0 : index
+    %cst = arith.constant dense<0.0> : tensor<4x128xf16>
+    %0 = hal.interface.binding.subspan set(0) binding(0) type(storage_buffer) alignment(64) offset(%c0) flags(ReadOnly) : !flow.dispatch.tensor<readonly:tensor<128x128xf16>>
+    %2 = hal.interface.binding.subspan set(0) binding(2) type(storage_buffer) alignment(64) offset(%c0) : !flow.dispatch.tensor<readwrite:tensor<128x128xf32>>
+    %3 = flow.dispatch.tensor.load %0, offsets = [0, 0], sizes = [128, 128], strides = [1, 1] : !flow.dispatch.tensor<readonly:tensor<128x128xf16>> -> tensor<128x128xf16>
+    %5 = flow.dispatch.tensor.load %2, offsets = [0, 0], sizes = [128, 128], strides = [1, 1] : !flow.dispatch.tensor<readwrite:tensor<128x128xf32>> -> tensor<128x128xf32>
+    %6 = tensor.empty() : tensor<128x4xf16>
+    %7 = tensor.empty() : tensor<4x128xf16>
+    %8 = scf.for %arg0 = %c0 to %c128 step %c4 iter_args(%arg1 = %5) -> (tensor<128x128xf32>) {
+      %9 = scf.forall (%arg2, %arg3, %arg4) in (1, 64, 1) shared_outs(%arg5 = %6) -> (tensor<128x4xf16>) {
+        %12 = affine.apply #map(%arg3)
+        %13 = affine.apply #map1(%arg4)
+        %14 = affine.apply #map(%arg3)
+        %15 = affine.apply #map2(%arg4)[%arg0]
+        %extracted_slice = tensor.extract_slice %3[%14, %15] [2, 4] [1, 1] : tensor<128x128xf16> to tensor<2x4xf16>
+        %extracted_slice_0 = tensor.extract_slice %arg5[%12, %13] [2, 4] [1, 1] : tensor<128x4xf16> to tensor<2x4xf16>
+        %16 = linalg.copy ins(%extracted_slice : tensor<2x4xf16>) outs(%extracted_slice_0 : tensor<2x4xf16>) -> tensor<2x4xf16>
+        scf.forall.in_parallel {
+          tensor.parallel_insert_slice %16 into %arg5[%12, %13] [2, 4] [1, 1] : tensor<2x4xf16> into tensor<128x4xf16>
+        }
+      } {mapping = [#gpu.thread<linear_dim_0>, #gpu.thread<linear_dim_1>, #gpu.thread<linear_dim_2>]}
+      %11 = scf.forall (%arg2, %arg3) in (8, 8) shared_outs(%arg4 = %arg1) -> (tensor<128x128xf32>) {
+        %12 = affine.apply #map3(%arg2)
+        %13 = affine.apply #map3(%arg3)
+        %extracted_slice = tensor.extract_slice %9[%12, 0] [16, 4] [1, 1] : tensor<128x4xf16> to tensor<16x4xf16>
+        %extracted_slice_0 = tensor.extract_slice %cst[0, %13] [4, 16] [1, 1] : tensor<4x128xf16> to tensor<4x16xf16>
+        %extracted_slice_1 = tensor.extract_slice %arg4[%12, %13] [16, 16] [1, 1] : tensor<128x128xf32> to tensor<16x16xf32>
+        %14 = linalg.matmul ins(%extracted_slice, %extracted_slice_0 : tensor<16x4xf16>, tensor<4x16xf16>) outs(%extracted_slice_1 : tensor<16x16xf32>) -> tensor<16x16xf32>
+        scf.forall.in_parallel {
+          tensor.parallel_insert_slice %14 into %arg4[%12, %13] [16, 16] [1, 1] : tensor<16x16xf32> into tensor<128x128xf32>
+        }
+      } {mapping = [#gpu.thread<linear_dim_0>, #gpu.thread<linear_dim_1>]}
+      scf.yield %11 : tensor<128x128xf32>
+    }
+    flow.dispatch.tensor.store %8, %2, offsets = [0, 0], sizes = [128, 128], strides = [1, 1] : tensor<128x128xf32> -> !flow.dispatch.tensor<readwrite:tensor<128x128xf32>>
+    return
+  }
+}
+
+// CHECK-LABEL: func @forall_fuse_then_hoist_mixed_mappings
+// CHECK: %[[OUTER_PARALLEL:.+]] = scf.forall
+// CHECK: %[[LOOP:.+]] = scf.for
+// CHECK: scf.yield {{.*}} : tensor<16x16xf32>
+// CHECK: scf.forall.in_parallel
+// CHECK-NEXT: tensor.parallel_insert_slice %[[LOOP]]
+// CHECK-NOT: scf.forall
+// CHECK: flow.dispatch.tensor.store %[[OUTER_PARALLEL]]
+
+// -----
+
+#map = affine_map<(d0) -> (d0 * 2)>
+#map1 = affine_map<(d0) -> (d0 * 4)>
+#map2 = affine_map<(d0)[s0] -> (d0 * 4 + s0)>
#map3 = affine_map<(d0)[s0] -> (d0 * 2 + s0)>
#map4 = affine_map<(d0) -> (d0 * 16)>
module {