[Codegen][GPU] Rework scf.forall fusion to support different thread counts (#18280)
The current fusion pattern is restricted to cases where the thread count
of each loop being fused is statically the same. This changes the
pattern to instead generate an scf.for loop within the consumer loop and
map the producer loop to the iteration space of the consumer loop. This
will allow supporting dynamic and unaligned code generation.
diff --git a/compiler/src/iree/compiler/Codegen/Common/GPU/GPUApplyTilingLevel.cpp b/compiler/src/iree/compiler/Codegen/Common/GPU/GPUApplyTilingLevel.cpp
index 1c3fac3..9b016fb 100644
--- a/compiler/src/iree/compiler/Codegen/Common/GPU/GPUApplyTilingLevel.cpp
+++ b/compiler/src/iree/compiler/Codegen/Common/GPU/GPUApplyTilingLevel.cpp
@@ -124,7 +124,7 @@
}
}
}
- tilingOptions.setMapping(mapping);
+ tilingOptions.setMapping(llvm::to_vector(llvm::reverse(mapping)));
}
scf::SCFTileAndFuseOptions tileAndFuseOptions;
diff --git a/compiler/src/iree/compiler/Codegen/Common/GPU/test/gpu_apply_tiling_level.mlir b/compiler/src/iree/compiler/Codegen/Common/GPU/test/gpu_apply_tiling_level.mlir
index 170161b..1b9cc62 100644
--- a/compiler/src/iree/compiler/Codegen/Common/GPU/test/gpu_apply_tiling_level.mlir
+++ b/compiler/src/iree/compiler/Codegen/Common/GPU/test/gpu_apply_tiling_level.mlir
@@ -41,13 +41,13 @@
// THREAD: scf.forall ({{.*}}) = (0, 0) to (64, 256) step (2, 16)
// THREAD: linalg.generic {{.*}} ins(%{{.*}}: tensor<2x16xf32>, tensor<2x16xf32>)
// THREAD: scf.forall.in_parallel
-// THREAD: mapping = [#gpu.thread<linear_dim_0>, #gpu.thread<linear_dim_1>]
+// THREAD: mapping = [#gpu.thread<linear_dim_1>, #gpu.thread<linear_dim_0>]
// SUBGROUP-LABEL: func.func @add_tensor
// SUBGROUP: scf.forall ({{.*}}) = (0, 0) to (64, 256) step (2, 16)
// SUBGROUP: linalg.generic {{.*}} ins(%{{.*}}: tensor<2x16xf32>, tensor<2x16xf32>)
// SUBGROUP: scf.forall.in_parallel
-// SUBGROUP: mapping = [#gpu.warp<linear_dim_0>, #gpu.warp<linear_dim_1>]
+// SUBGROUP: mapping = [#gpu.warp<linear_dim_1>, #gpu.warp<linear_dim_0>]
// -----
@@ -138,13 +138,13 @@
// THREAD-LABEL: func.func @matmul_transpose_b
// THREAD: scf.forall ({{.*}}) in (64, 4)
// THREAD: linalg.copy
-// THREAD: mapping = [#gpu.thread<linear_dim_0>, #gpu.thread<linear_dim_1>]
+// THREAD: mapping = [#gpu.thread<linear_dim_1>, #gpu.thread<linear_dim_0>]
// THREAD: scf.forall ({{.*}}) in (64, 4)
// THREAD: linalg.copy
-// THREAD: mapping = [#gpu.thread<linear_dim_0>, #gpu.thread<linear_dim_1>]
+// THREAD: mapping = [#gpu.thread<linear_dim_1>, #gpu.thread<linear_dim_0>]
// THREAD: scf.forall ({{.*}}) = (0, 0) to (64, 64) step (4, 4)
// THREAD: linalg.matmul
-// THREAD: mapping = [#gpu.thread<linear_dim_0>, #gpu.thread<linear_dim_1>]
+// THREAD: mapping = [#gpu.thread<linear_dim_1>, #gpu.thread<linear_dim_0>]
// -----
@@ -310,7 +310,7 @@
// THREAD: scf.forall ({{.*}}) = (0, 0) to (64, 256) step (8, 4)
// THREAD: linalg.generic {{.*}} ins(%{{.*}}: tensor<8x4xf32>, tensor<8x4xf32>)
// THREAD: scf.forall.in_parallel
-// THREAD: mapping = [#gpu.thread<linear_dim_0>, #gpu.thread<linear_dim_1>]
+// THREAD: mapping = [#gpu.thread<linear_dim_1>, #gpu.thread<linear_dim_0>]
// -----
@@ -344,7 +344,7 @@
// THREAD: scf.forall ({{.*}}) = (0, 0, 0) to (2, 128, 8) step (1, 1, 4)
// THREAD: iree_linalg_ext.im2col {{.*}} ins(%{{.*}}: tensor<1x34x34x128xf16>) outs({{.*}}: tensor<1x1x4xf16>)
// THREAD: scf.forall.in_parallel
-// THREAD: mapping = [#gpu.thread<linear_dim_0>, #gpu.thread<linear_dim_1>, #gpu.thread<linear_dim_2>]
+// THREAD: mapping = [#gpu.thread<linear_dim_2>, #gpu.thread<linear_dim_1>, #gpu.thread<linear_dim_0>]
// -----
diff --git a/compiler/src/iree/compiler/Codegen/Dialect/GPU/TransformExtensions/IREEGPUExtensionsOps.td b/compiler/src/iree/compiler/Codegen/Dialect/GPU/TransformExtensions/IREEGPUExtensionsOps.td
index e1d9c4c..ceba3f0 100644
--- a/compiler/src/iree/compiler/Codegen/Dialect/GPU/TransformExtensions/IREEGPUExtensionsOps.td
+++ b/compiler/src/iree/compiler/Codegen/Dialect/GPU/TransformExtensions/IREEGPUExtensionsOps.td
@@ -198,6 +198,12 @@
`extract_slice` of the consumer. If specified, uses |address_space| for
the intermediate allocation.
+ The mapping attributes of both the producer and consumer `scf.forall` ops
+ must be in a relative descending order, for example:
+ [#gpu.thread<z>, #gpu.thread<y>, #gpu.thread<x>]
+ or
+ [#gpu.thread<linear_dim_1>, #gpu.thread<linear_dim_0>]
+
NOTE: This pattern implicitly REQUIRES that the resulting scf.forall
is capable of synchronizing all threads at the point of fusion (i.e.
inserting a barrier). This invalidates certain kinds of lowerings of
diff --git a/compiler/src/iree/compiler/Codegen/Dialect/GPU/TransformExtensions/test/transform_fuse_forall.mlir b/compiler/src/iree/compiler/Codegen/Dialect/GPU/TransformExtensions/test/transform_fuse_forall.mlir
index 4bf4fe1..b70768d 100644
--- a/compiler/src/iree/compiler/Codegen/Dialect/GPU/TransformExtensions/test/transform_fuse_forall.mlir
+++ b/compiler/src/iree/compiler/Codegen/Dialect/GPU/TransformExtensions/test/transform_fuse_forall.mlir
@@ -37,9 +37,9 @@
}
}
-// CHECK: #[[$MAP:.+]] = affine_map<(d0) -> (d0 * 16)>
-// CHECK: #[[$MAP1:.+]] = affine_map<(d0, d1) -> (d0 * 8 + d1)>
-// CHECK: #[[$MAP2:.+]] = affine_map<(d0) -> (d0 * 2)>
+// CHECK-DAG: #[[$MAP:.+]] = affine_map<(d0) -> (d0 * 16)>
+// CHECK-DAG: #[[$MAP2:.+]] = affine_map<(d0, d1, d2) -> (d0 + d1 * 8 + d2)>
+// CHECK-DAG: #[[$MAP3:.+]] = affine_map<(d0) -> (d0 * 2)>
// CHECK-LABEL: func @fuse_forall
// CHECK-SAME: %[[ARG0:[A-Za-z0-9]+]]: tensor<128x128xf32>
@@ -49,14 +49,18 @@
// CHECK: scf.forall (%[[IDX:.+]], %[[IDY:.+]]) in (8, 8) shared_outs(%[[INIT:.+]] = %[[EMPTY]]) -> (tensor<128x128xf32>) {
// CHECK-DAG: %[[OUTID0:.+]] = affine.apply #[[$MAP]](%[[IDX]])
// CHECK-DAG: %[[OUTID1:.+]] = affine.apply #[[$MAP]](%[[IDY]])
-// CHECK: %[[LINEARID:.+]] = affine.apply #[[$MAP1]](%[[IDX]], %[[IDY]])
-// CHECK: %[[IDS:.+]]:2 = affine.delinearize_index %[[LINEARID]] into (%c64, %c1) : index, index
-// CHECK: %[[INID0:.+]] = affine.apply #[[$MAP2]](%[[IDS]]#0)
-// CHECK: %[[INSLICE0:.+]] = tensor.extract_slice %[[ARG0]][%[[INID0]], %[[IDS]]#1] [2, 128] [1, 1] : tensor<128x128xf32> to tensor<2x128xf32>
-// CHECK: %[[INSLICE1:.+]] = tensor.extract_slice %[[EMPTY]][%[[INID0]], %[[IDS]]#1] [2, 128] [1, 1] : tensor<128x128xf32> to tensor<2x128xf32>
-// CHECK: %[[COPY:.+]] = linalg.copy ins(%[[INSLICE0]] : tensor<2x128xf32>) outs(%[[INSLICE1]] : tensor<2x128xf32>) -> tensor<2x128xf32>
-// CHECK: %[[INSERT:.+]] = tensor.insert_slice %[[COPY]] into %[[ALLOC]][%[[INID0]], %[[IDS]]#1] [2, 128] [1, 1]
-// CHECK: %[[SHUFFLE:.+]] = iree_gpu.barrier_region %[[INSERT]]
+
+// CHECK: %[[LOOP:.+]] = scf.for %[[I:.+]] = %c0 to %c64{{.*}} step %c64{{.*}} iter_args(%[[ITER:.+]] = %[[ALLOC]]) -> (tensor<128x128xf32>)
+// CHECK: %[[LINEARID:.+]] = affine.apply #[[$MAP2]](%[[I]], %[[IDX]], %[[IDY]])
+// CHECK: %[[IDS:.+]]:2 = affine.delinearize_index %[[LINEARID]] into (%c1, %c64) : index, index
+// CHECK: %[[INID0:.+]] = affine.apply #[[$MAP3]](%[[IDS]]#1)
+// CHECK: %[[INSLICE0:.+]] = tensor.extract_slice %[[ARG0]][%[[INID0]], %[[IDS]]#0] [2, 128] [1, 1] : tensor<128x128xf32> to tensor<2x128xf32>
+// CHECK: %[[INSLICE1:.+]] = tensor.extract_slice %[[ITER]][%[[INID0]], %[[IDS]]#0] [2, 128] [1, 1] : tensor<128x128xf32> to tensor<2x128xf32>
+// CHECK: %[[COPY:.+]] = linalg.copy ins(%[[INSLICE0]] : tensor<2x128xf32>) outs(%[[INSLICE1]] : tensor<2x128xf32>) -> tensor<2x128xf32>
+// CHECK: %[[INSERT:.+]] = tensor.insert_slice %[[COPY]] into %[[ITER]][%[[INID0]], %[[IDS]]#0] [2, 128] [1, 1]
+// CHECK: scf.yield %[[INSERT]]
+
+// CHECK: %[[SHUFFLE:.+]] = iree_gpu.barrier_region %[[LOOP]]
// CHECK: ^bb0(%[[INTERMEDIATE:.+]]: tensor<128x128xf32>):
// CHECK: %[[SLICE:.+]] = tensor.extract_slice %[[INTERMEDIATE]][%[[OUTID0]], %[[OUTID1]]] [16, 16] [1, 1] : tensor<128x128xf32> to tensor<16x16xf32>
// CHECK: iree_gpu.yield %[[SLICE]]
@@ -108,9 +112,9 @@
}
}
-// CHECK: #[[$MAP:.+]] = affine_map<(d0) -> (d0 * 16)>
-// CHECK: #[[$MAP1:.+]] = affine_map<(d0, d1) -> (d0 * 8 + d1)>
-// CHECK: #[[$MAP2:.+]] = affine_map<(d0) -> (d0 * 2)>
+// CHECK-DAG: #[[$MAP:.+]] = affine_map<(d0) -> (d0 * 16)>
+// CHECK-DAG: #[[$MAP1:.+]] = affine_map<(d0, d1) -> (d0 * 8 + d1)>
+// CHECK-DAG: #[[$MAP2:.+]] = affine_map<(d0) -> (d0 * 2)>
// CHECK-LABEL: func @fuse_forall
// CHECK-SAME: %[[ARG0:[A-Za-z0-9]+]]: tensor<128x128xf32>
@@ -118,8 +122,9 @@
// CHECK-DAG: %[[EMPTY:.+]] = tensor.empty() : tensor<128x128xf32>
// CHECK-DAG: %[[ALLOC:.+]] = bufferization.alloc_tensor() {memory_space = #gpu.address_space<workgroup>} : tensor<128x128xf32>
// CHECK: scf.forall (%[[IDX:.+]], %[[IDY:.+]]) in (8, 8) shared_outs(%[[INIT:.+]] = %[[EMPTY]]) -> (tensor<128x128xf32>) {
-// CHECK: %[[INSERT:.+]] = tensor.insert_slice %{{.*}} into %[[ALLOC]]
-// CHECK: %[[SHUFFLE:.+]] = iree_gpu.barrier_region %[[INSERT]]
+// CHECK: %[[LOOP:.+]] = scf.for {{.*}} iter_args(%[[INIT:.+]] = %[[ALLOC]])
+// CHECK: %[[INSERT:.+]] = tensor.insert_slice %{{.*}} into %[[INIT]]
+// CHECK: %[[SHUFFLE:.+]] = iree_gpu.barrier_region %[[LOOP]]
// CHECK: } : tensor<128x128xf32> -> tensor<16x16xf32>
// CHECK: } {mapping = [#gpu.warp<y>, #gpu.warp<x>]}
@@ -163,9 +168,9 @@
}
}
-// CHECK: #[[$MAP:.+]] = affine_map<(d0) -> (d0 * 16)>
-// CHECK: #[[$MAP1:.+]] = affine_map<(d0, d1) -> (d0 * 8 + d1)>
-// CHECK: #[[$MAP2:.+]] = affine_map<(d0) -> (d0 * 2)>
+// CHECK-DAG: #[[$MAP:.+]] = affine_map<(d0) -> (d0 * 16)>
+// CHECK-DAG: #[[$MAP1:.+]] = affine_map<(d0, d1) -> (d0 * 8 + d1)>
+// CHECK-DAG: #[[$MAP2:.+]] = affine_map<(d0) -> (d0 * 2)>
// CHECK-LABEL: func @fuse_forall_with_reshape
// CHECK-SAME: %[[ARG0:[A-Za-z0-9]+]]: tensor<128x128xf32>
@@ -173,8 +178,9 @@
// CHECK-DAG: %[[EMPTY:.+]] = tensor.empty() : tensor<128x128xf32>
// CHECK-DAG: %[[ALLOC:.+]] = bufferization.alloc_tensor() {memory_space = #gpu.address_space<workgroup>} : tensor<128x128xf32>
// CHECK: scf.forall (%[[IDX:.+]], %[[IDY:.+]]) in (8, 8) shared_outs(%[[INIT:.+]] = %[[EMPTY]]) -> (tensor<128x128xf32>) {
-// CHECK: %[[INSERT:.+]] = tensor.insert_slice %{{.*}} into %[[ALLOC]]
-// CHECK: %[[SHUFFLE:.+]] = iree_gpu.barrier_region %[[INSERT]]
+// CHECK: %[[LOOP:.+]] = scf.for {{.*}} iter_args(%[[INIT:.+]] = %[[ALLOC]])
+// CHECK: %[[INSERT:.+]] = tensor.insert_slice %{{.*}} into %[[INIT]]
+// CHECK: %[[SHUFFLE:.+]] = iree_gpu.barrier_region %[[LOOP]]
// CHECK: ^bb0(%[[INTERMEDIATE:.+]]: tensor<128x128xf32>):
// CHECK: %[[EXPAND:.+]] = tensor.expand_shape %[[INTERMEDIATE]] {{\[}}[0, 1], [2]{{\]}} output_shape [2, 64, 128]
// CHECK: %[[SLICE:.+]] = tensor.extract_slice %[[EXPAND]][0, %{{.*}}, %{{.*}}] [1, 16, 16] [1, 1, 1] : tensor<2x64x128xf32> to tensor<16x16xf32>
@@ -227,9 +233,9 @@
}
}
-// CHECK: #[[$MAP2:.+]] = affine_map<(d0, d1, d2) -> (d0 + d1 * 8 + d2 * 4)>
-// CHECK: #[[$MAP3:.+]] = affine_map<(d0, d1, d2, d3) -> (d0 + d1 * 4 + d2 * 32 + d3 * 16)>
-// CHECK: #[[$MAP4:.+]] = affine_map<(d0) -> (d0 * 2)>
+// CHECK-DAG: #[[$MAP2:.+]] = affine_map<(d0, d1, d2) -> (d0 + d1 * 8 + d2 * 4)>
+// CHECK-DAG: #[[$MAP4:.+]] = affine_map<(d0, d1, d2, d3, d4) -> (d0 + d1 + d2 * 4 + d3 * 32 + d4 * 16)>
+// CHECK-DAG: #[[$MAP5:.+]] = affine_map<(d0) -> (d0 * 2)>
// CHECK-LABEL: func @fuse_thread_forall_with_warp_and_lane
// CHECK-SAME: %[[ARG0:[A-Za-z0-9]+]]: tensor<128x128xf32>
@@ -238,12 +244,124 @@
// CHECK-DAG: %[[ALLOC:.+]] = bufferization.alloc_tensor() {memory_space = #gpu.address_space<workgroup>} : tensor<128x128xf32>
// CHECK: scf.forall (%[[W_IDX:.+]], %[[W_IDY:.+]]) in (2, 2) shared_outs(%[[INIT:.+]] = %[[EMPTY]]) -> (tensor<128x128xf32>) {
// CHECK: scf.forall (%[[L_IDX:.+]], %[[L_IDY:.+]]) in (4, 4) {{.*}} -> (tensor<64x64xf32>)
-// CHECK-DAG: %[[FLAT_ID:.+]] = affine.apply #[[$MAP3]](%[[L_IDY]], %[[L_IDX]], %[[W_IDX]], %[[W_IDY]])
-// CHECK-DAG: %[[IDS:.+]]:2 = affine.delinearize_index %[[FLAT_ID]] into (%c64, %c1) : index, index
-// CHECK-DAG: %[[IDX:.+]] = affine.apply #[[$MAP4]](%[[IDS]]#0)
-// CHECK: %[[COPY:.+]] = linalg.copy
-// CHECK: %[[INSERT:.+]] = tensor.insert_slice %[[COPY]] into %[[ALLOC]][%[[IDX]], %[[IDS]]#1] [2, 128]
-// CHECK: %[[SHUFFLE:.+]] = iree_gpu.barrier_region %[[INSERT]]
+
+// CHECK: %[[LOOP:.+]] = scf.for %[[I:.+]] = %c0 to %c64{{.*}} step %c64{{.*}} iter_args(%[[ITER:.+]] = %[[ALLOC]]) -> (tensor<128x128xf32>)
+// CHECK: %[[FLAT_ID:.+]] = affine.apply #[[$MAP4]](%[[I]], %[[L_IDY]], %[[L_IDX]], %[[W_IDX]], %[[W_IDY]])
+// CHECK: %[[IDS:.+]]:2 = affine.delinearize_index %[[FLAT_ID]] into (%c1, %c64) : index, index
+// CHECK: %[[IDX:.+]] = affine.apply #[[$MAP5]](%[[IDS]]#1)
+// CHECK: %[[COPY:.+]] = linalg.copy
+// CHECK: %[[INSERT:.+]] = tensor.insert_slice %[[COPY]] into %[[ITER]][%[[IDX]], %[[IDS]]#0] [2, 128]
+// CHECK: scf.yield %[[INSERT]]
+
+// CHECK: %[[SHUFFLE:.+]] = iree_gpu.barrier_region %[[LOOP]]
// CHECK: } : tensor<128x128xf32> -> tensor<16x16xf32>
// CHECK: } {mapping = [#iree_gpu.lane_id<1>, #iree_gpu.lane_id<0>]}
// CHECK: } {mapping = [#gpu.warp<y>, #gpu.warp<x>]}
+
+// -----
+
+#map = affine_map<(d0) -> (d0 * 4)>
+#map1 = affine_map<(d0) -> (d0 * 16)>
+module {
+ func.func @fuse_forall_different_thread_count(%arg0: tensor<128x128xf32>) -> tensor<128x128xf32> {
+ %0 = tensor.empty() : tensor<128x128xf32>
+ %2 = scf.forall (%arg5) in (32) shared_outs(%arg7 = %0) -> (tensor<128x128xf32>) {
+ %4 = affine.apply #map(%arg5)
+ %extracted_slice = tensor.extract_slice %arg0[%4, 0] [4, 128] [1, 1] : tensor<128x128xf32> to tensor<4x128xf32>
+ %extracted_slice_0 = tensor.extract_slice %arg7[%4, 0] [4, 128] [1, 1] : tensor<128x128xf32> to tensor<4x128xf32>
+ %5 = linalg.copy ins(%extracted_slice : tensor<4x128xf32>) outs(%extracted_slice_0 : tensor<4x128xf32>) -> tensor<4x128xf32>
+ scf.forall.in_parallel {
+ tensor.parallel_insert_slice %5 into %arg7[%4, 0] [4, 128] [1, 1] : tensor<4x128xf32> into tensor<128x128xf32>
+ }
+ } {mapping = [#gpu.thread<x>]}
+ %3 = scf.forall (%arg5, %arg6) in (8, 8) shared_outs(%arg7 = %0) -> (tensor<128x128xf32>) {
+ %6 = affine.apply #map1(%arg5)
+ %7 = affine.apply #map1(%arg6)
+ %extracted_slice_0 = tensor.extract_slice %2[%6, %7] [16, 16] [1, 1] : tensor<128x128xf32> to tensor<16x16xf32>
+ %extracted_slice_1 = tensor.extract_slice %arg7[%6, %7] [16, 16] [1, 1] : tensor<128x128xf32> to tensor<16x16xf32>
+ %8 = linalg.matmul ins(%extracted_slice_0, %extracted_slice_0 : tensor<16x16xf32>, tensor<16x16xf32>) outs(%extracted_slice_1 : tensor<16x16xf32>) -> tensor<16x16xf32>
+ scf.forall.in_parallel {
+ tensor.parallel_insert_slice %8 into %arg7[%6, %7] [16, 16] [1, 1] : tensor<16x16xf32> into tensor<128x128xf32>
+ }
+ } {mapping = [#gpu.thread<y>, #gpu.thread<x>]}
+ return %3 : tensor<128x128xf32>
+ }
+}
+
+module attributes { transform.with_named_sequence } {
+ transform.named_sequence @__transform_main(%root: !transform.any_op {transform.readonly}) {
+ %loops = transform.structured.match ops{["scf.forall"]} in %root : (!transform.any_op) -> !transform.any_op
+ %producer, %consumer = transform.split_handle %loops : (!transform.any_op) -> (!transform.any_op, !transform.any_op)
+ transform.iree.fuse_forall %producer into %consumer : (!transform.any_op, !transform.any_op) -> (!transform.any_op)
+ transform.yield
+ }
+}
+
+// CHECK: #[[$MAP1:.+]] = affine_map<(d0, d1) -> (d0 * 8 + d1)>
+
+// CHECK-LABEL: func @fuse_forall_different_thread_count
+// CHECK-SAME: %[[ARG0:[A-Za-z0-9]+]]: tensor<128x128xf32>
+
+// CHECK: %[[ALLOC:.+]] = bufferization.alloc_tensor() {memory_space = #gpu.address_space<workgroup>} : tensor<128x128xf32>
+// CHECK: scf.forall (%[[IDX:.+]], %[[IDY:.+]]) in (8, 8) {{.*}} -> (tensor<128x128xf32>) {
+// CHECK: %[[LINEARID:.+]] = affine.apply #[[$MAP1]](%[[IDX]], %[[IDY]])
+// CHECK: %[[LOOP:.+]] = scf.for %[[I:.+]] = %[[LINEARID]] to %c32{{.*}} step %c64{{.*}} iter_args(%[[ITER:.+]] = %[[ALLOC]])
+// CHECK: %[[IDS:.+]] = affine.delinearize_index %[[I]] into (%c32) : index
+// CHECK: scf.yield
+// CHECK: iree_gpu.barrier_region %[[LOOP]]
+// CHECK: } {mapping = [#gpu.thread<y>, #gpu.thread<x>]}
+
+// -----
+
+#map = affine_map<(d0) -> (d0 * 4)>
+#map1 = affine_map<(d0) -> (d0 * 16)>
+module {
+ func.func @fuse_forall_dynamic_thread_count(%arg0: tensor<128x128xf32>, %x: index, %y: index, %z: index) -> tensor<128x128xf32> {
+ %0 = tensor.empty() : tensor<128x128xf32>
+ %2 = scf.forall (%arg5, %arg6, %arg7) in (%x, %y, %z) shared_outs(%arg8 = %0) -> (tensor<128x128xf32>) {
+ %slice = tensor.extract_slice %arg0[%arg5, %arg6] [4, 128] [1, 1] : tensor<128x128xf32> to tensor<4x128xf32>
+ scf.forall.in_parallel {
+ tensor.parallel_insert_slice %slice into %arg8[%arg7, 0] [4, 128] [1, 1] : tensor<4x128xf32> into tensor<128x128xf32>
+ }
+ } {mapping = [#gpu.thread<z>, #gpu.thread<y>, #gpu.thread<x>]}
+ %3 = scf.forall (%arg5, %arg6) in (8, 8) shared_outs(%arg7 = %0) -> (tensor<128x128xf32>) {
+ %6 = affine.apply #map1(%arg5)
+ %7 = affine.apply #map1(%arg6)
+ %extracted_slice_0 = tensor.extract_slice %2[%6, %7] [16, 16] [1, 1] : tensor<128x128xf32> to tensor<16x16xf32>
+ %extracted_slice_1 = tensor.extract_slice %arg7[%6, %7] [16, 16] [1, 1] : tensor<128x128xf32> to tensor<16x16xf32>
+ %8 = linalg.matmul ins(%extracted_slice_0, %extracted_slice_0 : tensor<16x16xf32>, tensor<16x16xf32>) outs(%extracted_slice_1 : tensor<16x16xf32>) -> tensor<16x16xf32>
+ scf.forall.in_parallel {
+ tensor.parallel_insert_slice %8 into %arg7[%6, %7] [16, 16] [1, 1] : tensor<16x16xf32> into tensor<128x128xf32>
+ }
+ } {mapping = [#gpu.thread<y>, #gpu.thread<x>]}
+ return %3 : tensor<128x128xf32>
+ }
+}
+
+module attributes { transform.with_named_sequence } {
+ transform.named_sequence @__transform_main(%root: !transform.any_op {transform.readonly}) {
+ %loops = transform.structured.match ops{["scf.forall"]} in %root : (!transform.any_op) -> !transform.any_op
+ %producer, %consumer = transform.split_handle %loops : (!transform.any_op) -> (!transform.any_op, !transform.any_op)
+ transform.iree.fuse_forall %producer into %consumer : (!transform.any_op, !transform.any_op) -> (!transform.any_op)
+ transform.yield
+ }
+}
+
+// CHECK-DAG: #[[$MAP1:.+]] = affine_map<(d0, d1) -> (d0 * 8 + d1)>
+// CHECK-DAG: #[[$MAP3:.+]] = affine_map<()[s0, s1, s2] -> (s2 * (s0 * s1))>
+
+// CHECK-LABEL: func @fuse_forall_dynamic_thread_count
+// CHECK-SAME: %[[ARG0:[A-Za-z0-9]+]]: tensor<128x128xf32>
+// CHECK-SAME: %[[X:[A-Za-z0-9]+]]: index
+// CHECK-SAME: %[[Y:[A-Za-z0-9]+]]: index
+// CHECK-SAME: %[[Z:[A-Za-z0-9]+]]: index
+
+// CHECK: %[[ALLOC:.+]] = bufferization.alloc_tensor() {memory_space = #gpu.address_space<workgroup>} : tensor<128x128xf32>
+// CHECK: scf.forall (%[[IDX:.+]], %[[IDY:.+]]) in (8, 8) {{.*}} -> (tensor<128x128xf32>) {
+// CHECK-DAG: %[[LINEARID:.+]] = affine.apply #[[$MAP1]](%[[IDX]], %[[IDY]])
+// CHECK-DAG: %[[PRODCOUNT:.+]] = affine.apply #[[$MAP3]]()[%[[X]], %[[Y]], %[[Z]]]
+// CHECK: %[[LOOP:.+]] = scf.for %[[I:.+]] = %[[LINEARID]] to %[[PRODCOUNT]] step %c64{{.*}} iter_args(%[[ITER:.+]] = %[[ALLOC]])
+// CHECK: %[[IDS:.+]] = affine.delinearize_index %[[I]] into (%[[Z]], %[[Y]], %[[X]]) : index
+// CHECK: scf.yield
+// CHECK: iree_gpu.barrier_region %[[LOOP]]
+// CHECK: } {mapping = [#gpu.thread<y>, #gpu.thread<x>]}
diff --git a/compiler/src/iree/compiler/Codegen/Dialect/GPU/Transforms/Transforms.cpp b/compiler/src/iree/compiler/Codegen/Dialect/GPU/Transforms/Transforms.cpp
index 7086471..9e6ee98 100644
--- a/compiler/src/iree/compiler/Codegen/Dialect/GPU/Transforms/Transforms.cpp
+++ b/compiler/src/iree/compiler/Codegen/Dialect/GPU/Transforms/Transforms.cpp
@@ -12,6 +12,7 @@
#include "llvm/ADT/ArrayRef.h"
#include "llvm/ADT/MapVector.h"
#include "llvm/ADT/STLExtras.h"
+#include "llvm/ADT/SmallVectorExtras.h"
#include "llvm/Support/LogicalResult.h"
#include "llvm/Support/MathExtras.h"
#include "mlir/Dialect/Affine/IR/AffineOps.h"
@@ -22,6 +23,7 @@
#include "mlir/Dialect/Linalg/IR/LinalgInterfaces.h"
#include "mlir/Dialect/SCF/IR/SCF.h"
#include "mlir/Dialect/Tensor/IR/Tensor.h"
+#include "mlir/Dialect/Utils/StaticValueUtils.h"
#include "mlir/Dialect/Utils/StructuredOpsUtils.h"
#include "mlir/Dialect/Vector/IR/VectorOps.h"
#include "mlir/Dialect/Vector/Transforms/VectorRewritePatterns.h"
@@ -40,29 +42,14 @@
// Forall Fusion
//===---------------------------------------------------------------------===//
-static FailureOr<int64_t> getTripCount(scf::ForallOp loop) {
- ArrayRef<int64_t> lbs = loop.getStaticLowerBound();
- ArrayRef<int64_t> ubs = loop.getStaticUpperBound();
- ArrayRef<int64_t> steps = loop.getStaticStep();
-
- if (ShapedType::isDynamicShape(lbs) || ShapedType::isDynamicShape(ubs) ||
- ShapedType::isDynamicShape(steps)) {
- return failure();
- }
-
- int64_t tripCount = 1;
- for (auto [lb, ub, step] : llvm::zip_equal(lbs, ubs, steps)) {
- tripCount *= llvm::divideCeil((ub - lb), step);
- }
- return tripCount;
-}
-
static FailureOr<SmallVector<scf::ForallOp>>
getEquivalentMappingConsumerLoopNest(scf::ForallOp producer,
scf::ForallOp consumer) {
- auto checkMappingTypes = [&](ArrayRef<Attribute> array) {
- return llvm::all_of(array, llvm::IsaPred<gpu::GPUThreadMappingAttr>) ||
- llvm::all_of(array, llvm::IsaPred<gpu::GPUWarpMappingAttr>);
+ auto compareMappingTypes = [&](ArrayRef<Attribute> l, ArrayRef<Attribute> r) {
+ return (llvm::all_of(l, llvm::IsaPred<gpu::GPUThreadMappingAttr>) &&
+ llvm::all_of(r, llvm::IsaPred<gpu::GPUThreadMappingAttr>)) ||
+ (llvm::all_of(l, llvm::IsaPred<gpu::GPUWarpMappingAttr>) &&
+ llvm::all_of(r, llvm::IsaPred<gpu::GPUWarpMappingAttr>));
};
ArrayRef<Attribute> producerMapping = producer.getMappingAttr().getValue();
@@ -72,12 +59,34 @@
return failure();
}
- if (producerMapping.front() == consumerMapping.front() &&
- checkMappingTypes(producerMapping) &&
- checkMappingTypes(consumerMapping)) {
+ auto isDescendingRelativeIndices = [&](ArrayRef<Attribute> array) {
+ int64_t prev =
+ llvm::cast<DeviceMappingAttrInterface>(array[0]).getRelativeIndex();
+ for (Attribute attr : array.drop_front()) {
+ int64_t relativeIndex =
+ llvm::cast<DeviceMappingAttrInterface>(attr).getRelativeIndex();
+ if (relativeIndex != prev - 1) {
+ return false;
+ }
+ prev = relativeIndex;
+ }
+ return true;
+ };
+
+ // Require descending relative indices so that the linearization and
+ // delinearization done in subsequent steps are valid.
+ if (!isDescendingRelativeIndices(producerMapping) ||
+ !isDescendingRelativeIndices(consumerMapping)) {
+ return failure();
+ }
+
+ // If both loops share the same kind of mapping, return the sole consumer.
+ if (compareMappingTypes(producerMapping, consumerMapping)) {
return SmallVector<scf::ForallOp>({consumer});
}
+ // The only other supported case is fusing a thread mapped loop into a nest
+ // of a warp and lane forall.
if (!llvm::all_of(producerMapping,
llvm::IsaPred<gpu::GPUThreadMappingAttr>) ||
!llvm::all_of(consumerMapping, llvm::IsaPred<IREE::GPU::LaneIdAttr>)) {
@@ -91,64 +100,39 @@
return SmallVector<scf::ForallOp>({outerWarpLoop, consumer});
}
-static LogicalResult compareWorkerCounts(scf::ForallOp producer,
- ArrayRef<scf::ForallOp> consumers) {
- FailureOr<int64_t> producerTripCount = getTripCount(producer);
- if (failed(producerTripCount)) {
+static FailureOr<Value> createSharedAllocDestination(RewriterBase &rewriter,
+ scf::ForallOp forallOp) {
+ if (forallOp->getNumResults() != 1) {
return failure();
}
- int64_t consumerTotal = 1;
- for (auto consumer : consumers) {
- FailureOr<int64_t> consumerTripCount = getTripCount(consumer);
- if (failed(consumerTripCount)) {
- return failure();
- }
- consumerTotal *= *consumerTripCount;
- }
- if (*producerTripCount != consumerTotal) {
- return failure();
- }
- return success();
-}
-static LogicalResult
-replaceConsumerChain(RewriterBase &rewriter, Location loc, Value source,
- tensor::ParallelInsertSliceOp parallelInsert,
- SmallVector<Operation *> consumerChain) {
- auto extractSlice = cast<tensor::ExtractSliceOp>(consumerChain.back());
- OpBuilder::InsertionGuard g(rewriter);
- Value shuffleDest = parallelInsert.getDest();
- auto empty = shuffleDest.getDefiningOp<tensor::EmptyOp>();
+ auto empty = forallOp.getDpsInits()[0].getDefiningOp<tensor::EmptyOp>();
// Fail if the destination is not a `tensor.empty` op and cannot be trivially
// converted to a `bufferization.alloc_tensor`.
if (!empty) {
return failure();
}
- // Replace the destination with a `bufferization.alloc_tensor` op with
- // memory space `#gpu.address_space<workgroup>`.
- {
- OpBuilder::InsertionGuard g(rewriter);
- rewriter.setInsertionPoint(empty);
- Attribute sharedMemoryAddrSpace = gpu::AddressSpaceAttr::get(
- rewriter.getContext(), gpu::GPUDialect::getWorkgroupAddressSpace());
- auto allocTensor = rewriter.create<bufferization::AllocTensorOp>(
- empty->getLoc(), empty->getResultTypes()[0], empty.getDynamicSizes());
- allocTensor.setMemorySpaceAttr(sharedMemoryAddrSpace);
- shuffleDest = allocTensor.getResult();
- }
+ // Create a `bufferization.alloc_tensor` op with memory space
+ // `#gpu.address_space<workgroup>`.
+ OpBuilder::InsertionGuard g(rewriter);
+ rewriter.setInsertionPoint(empty);
+ Attribute sharedMemoryAddrSpace = gpu::AddressSpaceAttr::get(
+ rewriter.getContext(), gpu::GPUDialect::getWorkgroupAddressSpace());
+ auto allocTensor = rewriter.create<bufferization::AllocTensorOp>(
+ empty->getLoc(), empty->getResultTypes()[0], empty.getDynamicSizes());
+ allocTensor.setMemorySpaceAttr(sharedMemoryAddrSpace);
+ return allocTensor.getResult();
+}
- // Create an insert_slice for the result of the first forall op into the
- // shared memory alloc_tensor.
- SmallVector<OpFoldResult, 4> sourceOffsets = parallelInsert.getMixedOffsets();
- SmallVector<OpFoldResult, 4> sourceSizes = parallelInsert.getMixedSizes();
- SmallVector<OpFoldResult, 4> sourceStrides = parallelInsert.getMixedStrides();
- Value insertedSlice = rewriter.create<tensor::InsertSliceOp>(
- loc, parallelInsert.getSource(), shuffleDest, sourceOffsets, sourceSizes,
- sourceStrides);
+static void replaceConsumerChain(RewriterBase &rewriter, Location loc,
+ Value source, Value replacement,
+ SmallVector<Operation *> consumerChain) {
+ auto extractSlice = cast<tensor::ExtractSliceOp>(consumerChain.back());
+ OpBuilder::InsertionGuard g(rewriter);
auto barrierRegionOp = rewriter.create<IREE::GPU::BarrierRegionOp>(
- loc, extractSlice.getType(), insertedSlice);
+ loc, extractSlice.getType(), replacement);
rewriter.setInsertionPointToStart(barrierRegionOp.getBody());
auto terminator =
rewriter.create<IREE::GPU::YieldOp>(loc, extractSlice.getResult());
@@ -159,7 +143,6 @@
->replaceUsesOfWith(source, barrierRegionOp.getBody()->getArgument(0));
rewriter.replaceAllUsesExcept(extractSlice.getResult(), barrierRegionOp,
terminator);
- return success();
}
LogicalResult fuseForallIntoSlice(RewriterBase &rewriter,
@@ -191,6 +174,7 @@
});
};
+ // Verify that both loops are normalized.
if (!isAll(producer.getMixedStep(), 1) ||
!isAll(producer.getMixedLowerBound(), 0)) {
return failure();
@@ -205,54 +189,132 @@
rewriter.setInsertionPoint(slice);
- // Step 1. Compute the producer IDs in terms of the consumer IDs.
+ // Step 1. Get the destination of the producer loop as a shared memory
+ // allocation.
+ FailureOr<Value> sharedDest =
+ createSharedAllocDestination(rewriter, producer);
+ if (failed(sharedDest)) {
+ return failure();
+ }
+
+ // Step 2. Compute the producer IDs in terms of the consumer IDs.
+ // The producer IDs are computed as follows:
+ //
+ // producer = [p0, ..., pn] ∈ [0, ..., 0] to [P0, ..., Pn]
+ // consumer = [c0, ..., cn] ∈ [0, ..., 0] to [C0, ..., Cn]
+ //
+ // Not a real op
+ // |
+ // %ub = P0 * ... * Pn |
+ // %step = C0 * ... * Cn v
+ // %flatc = affine.linearize_index %c0, ..., %cn
+ // scf.for %id = %flatc to %ub step %step {
+ // %p:n = affine.delinearize_index %id into [%P0, ..., %Pn]
+ // ...
+ // }
+ //
+ // Note: We use 0 as the loop lower bound instead of the linearized consumer
+ // loop ID if possible to make later loop promotion patterns easier.
MLIRContext *context = rewriter.getContext();
Location loc = producer.getLoc();
+ // Compute the linearize consumer loop ID and total consumer loop worker
+ // count (C0 * ... * Cn).
AffineExpr d0, d1, d2;
bindDims(context, d0, d1, d2);
AffineExpr mulAdd = d0 * d1 + d2;
OpFoldResult linearId = rewriter.getIndexAttr(0);
+ OpFoldResult consumerWorkerCount = rewriter.getIndexAttr(1);
for (auto loop : *consumerLoopNest) {
for (auto [inductionVar, workerCount] :
llvm::zip_equal(getAsOpFoldResult(loop.getInductionVars()),
loop.getMixedUpperBound())) {
linearId = affine::makeComposedFoldedAffineApply(
rewriter, loc, mulAdd, {linearId, workerCount, inductionVar});
+ consumerWorkerCount = affine::makeComposedFoldedAffineApply(
+ rewriter, loc, d0 * d1, {consumerWorkerCount, workerCount});
}
}
- Value linearThreadIdVal =
+ // Compute the total producer loop worker count (P0 * ... * Pn).
+ Value linearConsumerIdVal =
getValueOrCreateConstantIndexOp(rewriter, loc, linearId);
- SmallVector<Value> ranges;
- for (auto workerCount : producer.getStaticUpperBound()) {
- ranges.push_back(rewriter.create<arith::ConstantIndexOp>(loc, workerCount));
+ SmallVector<Value> producerRanges;
+ OpFoldResult producerWorkerCount = rewriter.getIndexAttr(1);
+ for (auto workerCount : producer.getMixedUpperBound()) {
+ producerRanges.push_back(
+ getValueOrCreateConstantIndexOp(rewriter, loc, workerCount));
+ producerWorkerCount = affine::makeComposedFoldedAffineApply(
+ rewriter, loc, d0 * d1, {producerWorkerCount, workerCount});
}
- ValueRange newIds = rewriter
- .create<affine::AffineDelinearizeIndexOp>(
- loc, linearThreadIdVal, ranges)
- .getResults();
- // Step 2. Inline the region of the producer.
- SmallVector<Value> bbArgReplacements(newIds);
- bbArgReplacements.append(producer.getOutputs().begin(),
- producer.getOutputs().end());
+ std::optional<int64_t> staticProducerCount =
+ getConstantIntValue(producerWorkerCount);
+ std::optional<int64_t> staticConsumerCount =
+ getConstantIntValue(consumerWorkerCount);
+ bool perfectlyDivides =
+ staticConsumerCount && staticProducerCount &&
+ staticProducerCount.value() % staticConsumerCount.value() == 0;
+ // Step 3. Create the `scf.for` loop for the producer.
+ // If the consumer worker count perfectly divides the producer worker count,
+ // then we can use a lower bound of 0 and keep the loop bounds static.
+ Value lb = perfectlyDivides ? rewriter.create<arith::ConstantIndexOp>(loc, 0)
+ : linearConsumerIdVal;
+ Value ub =
+ getValueOrCreateConstantIndexOp(rewriter, loc, producerWorkerCount);
+ Value step =
+ getValueOrCreateConstantIndexOp(rewriter, loc, consumerWorkerCount);
+ auto newProducer =
+ rewriter.create<scf::ForOp>(loc, lb, ub, step, *sharedDest);
+ Block *loopBody = newProducer.getBody();
+
+ // Get the replacement IDs for the producer loop.
+ rewriter.setInsertionPointToStart(loopBody);
+ Value newFlatProducerId =
+ perfectlyDivides
+ ? affine::makeComposedAffineApply(
+ rewriter, loc, d0 + d1,
+ {newProducer.getInductionVar(), linearConsumerIdVal})
+ : newProducer.getInductionVar();
+
+ // We require a descending relative mapping, so delinearize in reverse order.
+ auto delinearize = rewriter.create<affine::AffineDelinearizeIndexOp>(
+ loc, newFlatProducerId, llvm::to_vector(llvm::reverse(producerRanges)));
+
+ SmallVector<Value> newBlockArgs =
+ llvm::map_to_vector(llvm::reverse(delinearize.getResults()),
+ [](OpResult r) -> Value { return r; });
+ newBlockArgs.append(newProducer.getRegionIterArgs().begin(),
+ newProducer.getRegionIterArgs().end());
+
+ // Step 4. Inline the region of the producer and replace the terminator.
scf::InParallelOp terminator = producer.getTerminator();
- rewriter.inlineBlockBefore(producer.getBody(), slice, bbArgReplacements);
+ rewriter.mergeBlocks(producer.getBody(), loopBody, newBlockArgs);
rewriter.setInsertionPointAfter(terminator);
auto parallelInsert =
cast<tensor::ParallelInsertSliceOp>(*terminator.getYieldingOps().begin());
- if (failed(replaceConsumerChain(rewriter, loc, producer.getResult(0),
- parallelInsert, consumerChain))) {
- return failure();
- }
-
+ // Create an insert_slice to yield from the loop body.
+ SmallVector<OpFoldResult, 4> sourceOffsets = parallelInsert.getMixedOffsets();
+ SmallVector<OpFoldResult, 4> sourceSizes = parallelInsert.getMixedSizes();
+ SmallVector<OpFoldResult, 4> sourceStrides = parallelInsert.getMixedStrides();
+ Value insertedSlice = rewriter.create<tensor::InsertSliceOp>(
+ loc, parallelInsert.getSource(), parallelInsert.getDest(),
+ parallelInsert.getMixedOffsets(), parallelInsert.getMixedSizes(),
+ parallelInsert.getMixedStrides());
+ rewriter.create<scf::YieldOp>(loc, insertedSlice);
rewriter.eraseOp(parallelInsert);
rewriter.eraseOp(terminator);
+
+ // Step 5. Replace the extract slice with a `barrier_region` op to indicate
+ // synchronization of the shared tensor.
+ rewriter.setInsertionPointAfter(newProducer);
+ replaceConsumerChain(rewriter, loc, producer.getResult(0),
+ newProducer.getResult(0), consumerChain);
+
rewriter.eraseOp(producer);
return success();
}
diff --git a/compiler/src/iree/compiler/Codegen/Dialect/GPU/Transforms/Transforms.h b/compiler/src/iree/compiler/Codegen/Dialect/GPU/Transforms/Transforms.h
index a246124..8bdcff7 100644
--- a/compiler/src/iree/compiler/Codegen/Dialect/GPU/Transforms/Transforms.h
+++ b/compiler/src/iree/compiler/Codegen/Dialect/GPU/Transforms/Transforms.h
@@ -38,6 +38,12 @@
/// the single consumer loop at the given |slice| within the consumer of the
/// producer. This is managed by inserting an `iree_gpu.barrier_region` at the
/// boundary to synchronize the workers at the fusion point.
+///
+/// The mapping attributes of both the producer and consumer `scf.forall` ops
+/// must be in a relative descending order, for example:
+/// [#gpu.thread<z>, #gpu.thread<y>, #gpu.thread<x>]
+/// or
+/// [#gpu.thread<linear_dim_1>, #gpu.thread<linear_dim_0>]
LogicalResult fuseForallIntoSlice(RewriterBase &rewriter,
scf::ForallOp producer,
scf::ForallOp consumer,
diff --git a/compiler/src/iree/compiler/Codegen/Dialect/GPU/Transforms/test/fuse_and_hoist_forall.mlir b/compiler/src/iree/compiler/Codegen/Dialect/GPU/Transforms/test/fuse_and_hoist_forall.mlir
index a50c0e3..7bb8f8f 100644
--- a/compiler/src/iree/compiler/Codegen/Dialect/GPU/Transforms/test/fuse_and_hoist_forall.mlir
+++ b/compiler/src/iree/compiler/Codegen/Dialect/GPU/Transforms/test/fuse_and_hoist_forall.mlir
@@ -23,7 +23,7 @@
scf.forall.in_parallel {
tensor.parallel_insert_slice %16 into %arg4[%12, %13] [2, 4] [1, 1] : tensor<2x4xf16> into tensor<128x4xf16>
}
- } {mapping = [#gpu.thread<linear_dim_0>, #gpu.thread<linear_dim_1>]}
+ } {mapping = [#gpu.thread<linear_dim_1>, #gpu.thread<linear_dim_0>]}
%10 = scf.forall (%arg2, %arg3) in (2, 32) shared_outs(%arg4 = %7) -> (tensor<4x128xf16>) {
%12 = affine.apply #map(%arg2)
%13 = affine.apply #map1(%arg3)
@@ -35,7 +35,7 @@
scf.forall.in_parallel {
tensor.parallel_insert_slice %16 into %arg4[%12, %13] [2, 4] [1, 1] : tensor<2x4xf16> into tensor<4x128xf16>
}
- } {mapping = [#gpu.thread<linear_dim_0>, #gpu.thread<linear_dim_1>]}
+ } {mapping = [#gpu.thread<linear_dim_1>, #gpu.thread<linear_dim_0>]}
%11 = scf.forall (%arg2, %arg3) in (8, 8) shared_outs(%arg4 = %arg1) -> (tensor<128x128xf32>) {
%12 = affine.apply #map4(%arg2)
%13 = affine.apply #map4(%arg3)
@@ -46,7 +46,7 @@
scf.forall.in_parallel {
tensor.parallel_insert_slice %14 into %arg4[%12, %13] [16, 16] [1, 1] : tensor<16x16xf32> into tensor<128x128xf32>
}
- } {mapping = [#gpu.thread<linear_dim_0>, #gpu.thread<linear_dim_1>]}
+ } {mapping = [#gpu.thread<linear_dim_1>, #gpu.thread<linear_dim_0>]}
scf.yield %11 : tensor<128x128xf32>
}
return %8 : tensor<128x128xf32>
@@ -85,7 +85,7 @@
scf.forall.in_parallel {
tensor.parallel_insert_slice %16 into %arg5[%12, %13] [2, 4] [1, 1] : tensor<2x4xf16> into tensor<128x4xf16>
}
- } {mapping = [#gpu.thread<linear_dim_0>, #gpu.thread<linear_dim_1>, #gpu.thread<linear_dim_2>]}
+ } {mapping = [#gpu.thread<linear_dim_2>, #gpu.thread<linear_dim_1>, #gpu.thread<linear_dim_0>]}
%11 = scf.forall (%arg2, %arg3) in (8, 8) shared_outs(%arg4 = %arg1) -> (tensor<128x128xf32>) {
%12 = affine.apply #map3(%arg2)
%13 = affine.apply #map3(%arg3)
@@ -96,7 +96,7 @@
scf.forall.in_parallel {
tensor.parallel_insert_slice %14 into %arg4[%12, %13] [16, 16] [1, 1] : tensor<16x16xf32> into tensor<128x128xf32>
}
- } {mapping = [#gpu.thread<linear_dim_0>, #gpu.thread<linear_dim_1>]}
+ } {mapping = [#gpu.thread<linear_dim_1>, #gpu.thread<linear_dim_0>]}
scf.yield %11 : tensor<128x128xf32>
}
return %8 : tensor<128x128xf32>
@@ -139,7 +139,7 @@
scf.forall.in_parallel {
tensor.parallel_insert_slice %16 into %arg4[%12, %13] [2, 4] [1, 1] : tensor<2x4xf16> into tensor<128x4xf16>
}
- } {mapping = [#gpu.thread<linear_dim_0>, #gpu.thread<linear_dim_1>]}
+ } {mapping = [#gpu.thread<linear_dim_1>, #gpu.thread<linear_dim_0>]}
%10 = scf.forall (%arg2, %arg3) in (2, 32) shared_outs(%arg4 = %7) -> (tensor<4x128xf16>) {
%12 = affine.apply #map(%arg2)
%13 = affine.apply #map1(%arg3)
@@ -151,7 +151,7 @@
scf.forall.in_parallel {
tensor.parallel_insert_slice %16 into %arg4[%12, %13] [2, 4] [1, 1] : tensor<2x4xf16> into tensor<4x128xf16>
}
- } {mapping = [#gpu.thread<linear_dim_0>, #gpu.thread<linear_dim_1>]}
+ } {mapping = [#gpu.thread<linear_dim_1>, #gpu.thread<linear_dim_0>]}
%11 = scf.forall (%arg2, %arg3) in (8, 8) shared_outs(%arg4 = %arg1) -> (tensor<128x128xf32>) {
%12 = affine.apply #map4(%arg2)
%13 = affine.apply #map4(%arg3)
@@ -162,7 +162,7 @@
scf.forall.in_parallel {
tensor.parallel_insert_slice %14 into %arg4[%12, %13] [16, 16] [1, 1] : tensor<16x16xf32> into tensor<128x128xf32>
}
- } {mapping = [#gpu.thread<linear_dim_0>, #gpu.thread<linear_dim_1>]}
+ } {mapping = [#gpu.thread<linear_dim_1>, #gpu.thread<linear_dim_0>]}
scf.yield %11 : tensor<128x128xf32>
}
return %8 : tensor<128x128xf32>
@@ -194,11 +194,11 @@
scf.forall.in_parallel {
tensor.parallel_insert_slice %16 into %arg7[%arg5, %arg6] [2, 4] [1, 1] : tensor<2x4xf16> into tensor<64x64xf16>
}
- } {mapping = [#gpu.thread<linear_dim_0>, #gpu.thread<linear_dim_1>]}
+ } {mapping = [#gpu.thread<linear_dim_1>, #gpu.thread<linear_dim_0>]}
scf.forall.in_parallel {
tensor.parallel_insert_slice %10 into %arg4[%arg2, %arg3] [64, 64] [1, 1] : tensor<64x64xf16> into tensor<128x128xf16>
}
- } {mapping = [#gpu.warp<linear_dim_0>, #gpu.warp<linear_dim_1>]}
+ } {mapping = [#gpu.warp<linear_dim_1>, #gpu.warp<linear_dim_0>]}
scf.yield %9 : tensor<128x128xf16>
}
%transpose = linalg.transpose ins(%8: tensor<128x128xf16>) outs(%empty: tensor<128x128xf16>) permutation = [1, 0]
@@ -234,11 +234,11 @@
scf.forall.in_parallel {
tensor.parallel_insert_slice %16 into %arg7[%arg5, %arg6] [2, 4] [1, 1] : tensor<2x4xf16> into tensor<64x64xf16>
}
- } {mapping = [#gpu.thread<linear_dim_0>, #gpu.thread<linear_dim_1>]}
+ } {mapping = [#gpu.thread<linear_dim_1>, #gpu.thread<linear_dim_0>]}
scf.forall.in_parallel {
tensor.parallel_insert_slice %10 into %arg4[%arg2, %arg3] [64, 64] [1, 1] : tensor<64x64xf16> into tensor<128x128xf16>
}
- } {mapping = [#gpu.warp<linear_dim_0>, #gpu.warp<linear_dim_1>]}
+ } {mapping = [#gpu.warp<linear_dim_1>, #gpu.warp<linear_dim_0>]}
scf.yield %9 : tensor<128x128xf16>
}
%transpose_input = linalg.transpose ins(%3: tensor<128x128xf16>) outs(%empty: tensor<128x128xf16>) permutation = [1, 0]