[LLVMGPU] Enable scf.forall distribution in the VectorDistribute pipeline (#19420)
Enables `scf.forall`-based workgroup distribution in the vector distribute pipeline.
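Previously, workgroup tiling on this pipeline emitted an `scf.for` nest over hand-linearized `hal.interface.workgroup.id` values, with a separate `ReorderWorkgroups` pass rewriting that arithmetic. Tiling now produces an `scf.forall` carrying `#iree_codegen.workgroup_mapping` attributes, and the reorder strategy is threaded through `tileAndDistributeToWorkgroup` (only the `None` and `Transpose` strategies are supported on the forall path). A minimal sketch of the shape change, with illustrative bounds rather than values taken from any real dispatch:

```mlir
// Before: workgroup ids materialized and linearized by hand;
// ReorderWorkgroups rewrote this arithmetic to transpose the grid.
%wg_y   = hal.interface.workgroup.id[1] : index
%wg_x   = hal.interface.workgroup.id[0] : index
%grid_x = hal.interface.workgroup.count[0] : index
%lin    = arith.muli %wg_y, %grid_x : index
%id     = arith.addi %lin, %wg_x : index

// After: a distributed scf.forall whose mapping attribute records which
// grid dimension each loop binds to; resolution happens later.
scf.forall (%i, %j) in (2048, 10240) {
  // ... tiled matmul body, including the scf.for K loop ...
} {mapping = [#iree_codegen.workgroup_mapping<y>, #iree_codegen.workgroup_mapping<x>]}
```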
diff --git a/compiler/src/iree/compiler/Codegen/LLVMGPU/Passes.cpp b/compiler/src/iree/compiler/Codegen/LLVMGPU/Passes.cpp
index e666e8c..4db5141 100644
--- a/compiler/src/iree/compiler/Codegen/LLVMGPU/Passes.cpp
+++ b/compiler/src/iree/compiler/Codegen/LLVMGPU/Passes.cpp
@@ -186,10 +186,15 @@
static void tileAndDistributeToWorkgroup(
OpPassManager &funcPassManager, bool useForall,
std::optional<ConvertToDestinationPassingStylePassOptions>
- convertToDpsOptions = ConvertToDestinationPassingStylePassOptions{}) {
+ convertToDpsOptions = ConvertToDestinationPassingStylePassOptions{},
+ ReorderWorkgroupsStrategy strategy = ReorderWorkgroupsStrategy::None) {
if (useForall) {
- funcPassManager.addPass(
- createTileAndDistributeToWorkgroupsUsingForallOpPass());
+ assert((strategy == ReorderWorkgroupsStrategy::None ||
+ strategy == ReorderWorkgroupsStrategy::Transpose) &&
+ "Only None and Transpose reorder strategies are supported with "
+ "forall distribution.");
+ funcPassManager.addPass(createTileAndDistributeToWorkgroupsWithReordering(
+ strategy == ReorderWorkgroupsStrategy::Transpose));
} else {
funcPassManager.addPass(createTileAndDistributeToWorkgroupsPass(
kNumMaxParallelDims,
@@ -786,12 +791,13 @@
void addGPUVectorDistributePassPipeline(OpPassManager &funcPassManager,
const GPUPipelineOptions &options,
bool usePadToModelSharedMemcpy) {
- tileAndDistributeToWorkgroup(funcPassManager, /*useForall=*/false);
ReorderWorkgroupsStrategy reorderStrategy =
getReorderWorkgroupsStrategy(options.reorderStrategy);
- funcPassManager.addPass(
- createReorderWorkgroups(reorderStrategy, canReorderWorkgroups));
+
+ tileAndDistributeToWorkgroup(funcPassManager, /*useForall=*/true,
+ /*convertToDpsOptions=*/std::nullopt,
+ /*reorderStrategy=*/reorderStrategy);
if (usePadToModelSharedMemcpy) {
funcPassManager.addPass(createLLVMGPUPromoteMatmulToFitMMAPass());
@@ -1234,6 +1240,7 @@
.addPass(createVerifyWorkgroupDistributionPass);
}
variantPassManager.addPass(createReconcileTranslationInfoPass());
+ variantPassManager.addPass(createLowerAffinePass());
variantPassManager.addPass(IREE::Util::createDropCompilerHintsPass());
addLowerToLLVMGPUPasses(variantPassManager.nest<ModuleOp>(),
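The extra `createLowerAffinePass` run is presumably needed because resolving the distributed `scf.forall` to `hal.interface.workgroup.id` values leaves behind `affine.apply` ops for the id/count arithmetic; this is an inference from the hunk, not stated in the patch. A hypothetical example of the rewrite the pass performs (the map and the grid width of 16 are illustrative):

```mlir
// Hypothetical affine.apply left behind once the forall is resolved:
%id = affine.apply affine_map<()[s0, s1] -> (s0 * 16 + s1)>()[%wg_y, %wg_x]

// --lower-affine expands it into plain arith ops:
%c16 = arith.constant 16 : index
%0   = arith.muli %wg_y, %c16 : index
%id  = arith.addi %0, %wg_x : index
```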
diff --git a/compiler/src/iree/compiler/Codegen/LLVMGPU/test/ROCDL/config_user_vector_distribute.mlir b/compiler/src/iree/compiler/Codegen/LLVMGPU/test/ROCDL/config_user_vector_distribute.mlir
index ce04036..d6b75cc 100644
--- a/compiler/src/iree/compiler/Codegen/LLVMGPU/test/ROCDL/config_user_vector_distribute.mlir
+++ b/compiler/src/iree/compiler/Codegen/LLVMGPU/test/ROCDL/config_user_vector_distribute.mlir
@@ -33,18 +33,16 @@
// OPT-OUT-LABEL: func.func @main_0_dispatch_0_matmul_transpose_b
// OPT-OUT: memref.alloc() : memref<128x32xf16, #gpu.address_space<workgroup>>
// OPT-OUT: memref.alloc() : memref<128x32xf16, #gpu.address_space<workgroup>>
- // OPT-OUT-DAG: %[[WG_Y:.+]] = hal.interface.workgroup.id[1] : index
- // OPT-OUT-DAG: %[[WG_X:.+]] = hal.interface.workgroup.id[0] : index
- // OPT-OUT-DAG: arith.muli %[[WG_Y]], %{{.+}} : index
- // OPT-OUT-DAG: arith.addi %{{.+}}, %[[WG_X]] : index
- // OPT-OUT: scf.for
+ // OPT-OUT: scf.forall
+ // OPT-OUT: scf.for
+ // OPT-OUT: } {mapping = [#iree_codegen.workgroup_mapping<x>, #iree_codegen.workgroup_mapping<y>]}
// OPT-IN-LABEL: func.func @main_0_dispatch_0_matmul_transpose_b
// OPT-IN: memref.alloc() : memref<128x32xf16, #gpu.address_space<workgroup>>
// OPT-IN: memref.alloc() : memref<128x32xf16, #gpu.address_space<workgroup>>
- // OPT-IN-DAG: %[[WG_Y:.+]] = hal.interface.workgroup.id[1] : index
- // OPT-IN-DAG: %[[WG_X:.+]] = hal.interface.workgroup.id[0] : index
- // OPT-IN: scf.for
+ // OPT-IN: scf.forall
+ // OPT-IN: scf.for
+ // OPT-IN: } {mapping = [#iree_codegen.workgroup_mapping<y>, #iree_codegen.workgroup_mapping<x>]}
func.func @main_0_dispatch_0_matmul_transpose_b_2048x10240x1280_f16xf16xf32()
attributes {translation_info = #iree_codegen.translation_info<pipeline = LLVMGPUVectorDistribute workgroup_size = [128, 2, 1] subgroup_size = 64, {
@@ -108,20 +106,16 @@
// OPT-OUT-LABEL: func.func @main_0_dispatch_0_matmul_transpose_b
// OPT-OUT: memref.alloc() : memref<128x36xf16, #gpu.address_space<workgroup>>
// OPT-OUT: memref.alloc() : memref<128x36xf16, #gpu.address_space<workgroup>>
- // OPT-OUT-DAG: %[[WG_Y:.+]] = hal.interface.workgroup.id[1] : index
- // OPT-OUT-DAG: %[[WG_X:.+]] = hal.interface.workgroup.id[0] : index
- // OPT-OUT-DAG: arith.muli %[[WG_Y]], %{{.+}} : index
- // OPT-OUT-DAG: arith.addi %{{.+}}, %[[WG_X]] : index
- // OPT-OUT: scf.for
+ // OPT-OUT: scf.forall
+ // OPT-OUT: scf.for
+ // OPT-OUT: } {mapping = [#iree_codegen.workgroup_mapping<x>, #iree_codegen.workgroup_mapping<y>]}
// OPT-IN-LABEL: func.func @main_0_dispatch_0_matmul_transpose_b
// OPT-IN: memref.alloc() : memref<128x36xf16, #gpu.address_space<workgroup>>
// OPT-IN: memref.alloc() : memref<128x36xf16, #gpu.address_space<workgroup>>
- // OPT-IN-DAG: %[[WG_Y:.+]] = hal.interface.workgroup.id[1] : index
- // OPT-IN-DAG: %[[WG_X:.+]] = hal.interface.workgroup.id[0] : index
- // OPT-IN-DAG: arith.muli %[[WG_Y]], %{{.+}} : index
- // OPT-IN-DAG: arith.addi %{{.+}}, %[[WG_X]] : index
- // OPT-IN: scf.for
+ // OPT-IN: scf.forall
+ // OPT-IN: scf.for
+ // OPT-IN: } {mapping = [#iree_codegen.workgroup_mapping<x>, #iree_codegen.workgroup_mapping<y>]}
func.func @main_0_dispatch_0_matmul_transpose_b_2048x10240x1280_f16xf16xf32()
attributes {translation_info = #iree_codegen.translation_info<pipeline = LLVMGPUVectorDistribute workgroup_size = [128, 2, 1] subgroup_size = 64, {
gpu_pipeline_options = #iree_gpu.pipeline_options<reorder_workgroups_strategy = <Transpose>> // enable the 'reorderWorkgroups' pass.
@@ -180,9 +174,9 @@
// OPT-OUT-LABEL: func.func @main_0_dispatch_0_matmul_transpose_b
// OPT-OUT: memref.alloc() : memref<128x36xf16, #gpu.address_space<workgroup>>
// OPT-OUT: memref.alloc() : memref<128x36xf16, #gpu.address_space<workgroup>>
- // OPT-OUT-DAG: hal.interface.workgroup.id[1] : index
- // OPT-OUT-DAG: hal.interface.workgroup.id[0] : index
- // OPT-OUT-NEXT: scf.for
+ // OPT-OUT: scf.forall
+ // OPT-OUT: scf.for
+ // OPT-OUT: } {mapping = [#iree_codegen.workgroup_mapping<y>, #iree_codegen.workgroup_mapping<x>]}
func.func @main_0_dispatch_0_matmul_transpose_b_2048x10240x1280_f16xf16xf32()
attributes {translation_info = #iree_codegen.translation_info<pipeline = LLVMGPUVectorDistribute workgroup_size = [128, 2, 1] subgroup_size = 64, {
gpu_pipeline_options = #iree_gpu.pipeline_options<reorder_workgroups_strategy = <None>> // Disable the 'reorderWorkgroups' pass.
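As the updated checks in this file show, the `Transpose` strategy is now encoded purely in the order of the `workgroup_mapping` attributes rather than in explicit `arith.muli`/`arith.addi` recomputation of the ids: `[y, x]` is the default assignment and `[x, y]` the transposed one. Schematically (the `%m_tiles`/`%n_tiles` bounds are placeholders, and which loop is M vs N is an assumption for illustration):

```mlir
// ReorderWorkgroupsStrategy::None — outer (M) tiles bind to y, inner (N) to x:
scf.forall (%i, %j) in (%m_tiles, %n_tiles) {
  // ...
} {mapping = [#iree_codegen.workgroup_mapping<y>, #iree_codegen.workgroup_mapping<x>]}

// ReorderWorkgroupsStrategy::Transpose — same loops, mapping order swapped:
scf.forall (%i, %j) in (%m_tiles, %n_tiles) {
  // ...
} {mapping = [#iree_codegen.workgroup_mapping<x>, #iree_codegen.workgroup_mapping<y>]}
```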
diff --git a/compiler/src/iree/compiler/Codegen/LLVMGPU/test/ROCDL/pipeline_vector_distribute_gfx942.mlir b/compiler/src/iree/compiler/Codegen/LLVMGPU/test/ROCDL/pipeline_vector_distribute_gfx942.mlir
index 0a27cb9..d93467a 100644
--- a/compiler/src/iree/compiler/Codegen/LLVMGPU/test/ROCDL/pipeline_vector_distribute_gfx942.mlir
+++ b/compiler/src/iree/compiler/Codegen/LLVMGPU/test/ROCDL/pipeline_vector_distribute_gfx942.mlir
@@ -1084,7 +1084,7 @@
// CHECK-LABEL: func.func @attention_multiple_m_transpose()
// CHECK: scf.for %{{.*}} = %c0 to %c4608 step %c64
-// CHECK-SAME: -> (vector<2x1x1xf32>, vector<2x1x1xf32>, vector<2x8x1x1x1x4xf32>)
+// CHECK-SAME: -> (vector<2x1x1xf32>, vector<2x1x1xf32>, vector<2x4x1x1x1x4xf32>)
// CHECK-COUNT-96: amdgpu.mfma {{.*}} {blocks = 1 : i32, k = 16 : i32, m = 16 : i32, n = 16 : i32} blgp = none : vector<4xf16>, vector<4xf16>, vector<4xf32>
// CHECK: scf.yield
@@ -1152,8 +1152,8 @@
// CHECK-LABEL: func.func @attention_mfma_32x32x8()
// CHECK: scf.for %{{.*}} = %c0 to %c4608 step %c32
-// CHECK-SAME: -> (vector<1x1x1xf32>, vector<1x1x1xf32>, vector<1x4x1x4x1x4xf32>)
-// CHECK-COUNT-32: amdgpu.mfma {{.*}} {blocks = 1 : i32, k = 8 : i32, m = 32 : i32, n = 32 : i32} blgp = none : vector<4xf16>, vector<4xf16>, vector<16xf32>
+// CHECK-SAME: -> (vector<1x1x1xf32>, vector<1x1x1xf32>, vector<1x2x1x4x1x4xf32>)
+// CHECK-COUNT-24: amdgpu.mfma {{.*}} {blocks = 1 : i32, k = 8 : i32, m = 32 : i32, n = 32 : i32} blgp = none : vector<4xf16>, vector<4xf16>, vector<16xf32>
// CHECK: scf.yield
// Check that we only use alloc for Q, K, and V. No shared memory for S is
@@ -1311,3 +1311,4 @@
// MEMORY-LABEL: func.func @attention_gather_k
// MEMORY-COUNT-3: memref.alloc
+// MEMORY-NOT: memref.alloc