[CUDA] Add support for config lowering coming from IR (#7082)
To enable configuration search, support taking the lowering configuration
directly from the higher-level IR instead of deriving it from heuristics.
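
For example, a compute op in the dispatch region can carry a lowering.config
attribute that names the pass pipeline, tile sizes, and workgroup size. The
snippet below is abridged from the lit test added in this change (the full op
also carries an __internal_linalg_transform__ marker):

  %17 = linalg.matmul {lowering.config = {passPipeline = "LLVMGPUMatmulSimt",
            tileSizes = [[32, 256, 64], [], [4, 16]],
            workgroupSize = [16, 8, 1]}}
        ins(%8, %10 : tensor<?x256xf32>, tensor<256x?xf32>)
        outs(%16 : tensor<?x?xf32>) -> tensor<?x?xf32>

When such an annotation is present, setRootConfig uses it as-is and bypasses
the built-in heuristics.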
diff --git a/iree/compiler/Codegen/LLVMGPU/KernelConfig.cpp b/iree/compiler/Codegen/LLVMGPU/KernelConfig.cpp
index 5d2ddb0..cb554f2 100644
--- a/iree/compiler/Codegen/LLVMGPU/KernelConfig.cpp
+++ b/iree/compiler/Codegen/LLVMGPU/KernelConfig.cpp
@@ -221,7 +221,41 @@
IREE::HAL::DispatchLoweringPassPipeline::LLVMGPUVectorize, workgroupSize);
}
+/// Propagate the configuration annotated in the incoming IR.
+static LogicalResult setUserConfig(FuncOp entryPointFn, Operation *computeOp,
+ IREE::HAL::LoweringConfig config) {
+ IREE::HAL::DispatchLoweringPassPipeline passPipeline =
+ IREE::HAL::DispatchLoweringPassPipeline::LLVMGPUVectorize;
+ if (auto setPassPipeline = getLoweringPassPipeline(config)) {
+ passPipeline = setPassPipeline.getValue();
+ }
+ SmallVector<int64_t, 4> workgroupSize;
+ if (auto workgroupSizeAttr = config.workgroupSize()) {
+ workgroupSize = llvm::to_vector<4>(
+ llvm::map_range(workgroupSizeAttr, [](Attribute intAttr) {
+ return intAttr.cast<IntegerAttr>().getInt();
+ }));
+ }
+ if (failed(setOpConfigAndEntryPointFnTranslation(
+ entryPointFn, computeOp, config, passPipeline, workgroupSize))) {
+ return failure();
+ }
+ // Reset the op configuration to drop the pass-pipeline and workgroup size
+ // info. The op does not carry that information anymore.
+ auto resetConfig = IREE::HAL::LoweringConfig::get(
+ config.tileSizes(), config.nativeVectorSize(),
+ /*passPipeline =*/nullptr,
+ /*workgroupSize =*/nullptr, computeOp->getContext());
+ setLoweringConfig(computeOp, resetConfig);
+ return success();
+}
+
static LogicalResult setRootConfig(FuncOp entryPointFn, Operation *computeOp) {
+ if (IREE::HAL::LoweringConfig config = getLoweringConfig(computeOp)) {
+  // If the op already has a lowering config set in the IR, use it and
+  // bypass the heuristics.
+ return setUserConfig(entryPointFn, computeOp, config);
+ }
if (auto linalgOp = dyn_cast<linalg::LinalgOp>(computeOp)) {
if (linalg::isaContractionOpInterface(linalgOp) &&
linalgOp.getNumParallelLoops() >= 2) {
diff --git a/iree/compiler/Codegen/LLVMGPU/test/gpu_set_num_workgroups.mlir b/iree/compiler/Codegen/LLVMGPU/test/gpu_set_num_workgroups.mlir
index 3665903..07cd247 100644
--- a/iree/compiler/Codegen/LLVMGPU/test/gpu_set_num_workgroups.mlir
+++ b/iree/compiler/Codegen/LLVMGPU/test/gpu_set_num_workgroups.mlir
@@ -364,3 +364,65 @@
// CHECK: func @static_3d_fft_stage3()
// CHECK: linalg_ext.fft
// CHECK-SAME: lowering.config = #[[CONFIG]]
+
+// -----
+
+hal.executable @user_config {
+hal.executable.variant public @cuda_nvptx_fb, target = #hal.executable.target<"cuda", "cuda-nvptx-fb"> {
+ hal.executable.entry_point public @_lowering_config_test_dispatch_1 attributes {interface = @io, ordinal = 0 : index}
+ builtin.module {
+ func @_lowering_config_test_dispatch_1() {
+ %cst = constant 0.000000e+00 : f32
+ %c128 = constant 128 : index
+ %c1024 = constant 1024 : index
+ %c0 = constant 0 : index
+ %0 = hal.interface.binding.subspan @io::@s0b0_ro_external[%c0] : !flow.dispatch.tensor<readonly:128x256xf32>
+ %1 = hal.interface.binding.subspan @io::@s0b1_ro_external[%c0] : !flow.dispatch.tensor<readonly:256x1024xf32>
+ %2 = hal.interface.binding.subspan @io::@s0b2_xw_external[%c0] : !flow.dispatch.tensor<writeonly:128x1024xf32>
+ %workgroup_size_x = hal.interface.workgroup.size[0] : index
+ %workgroup_size_y = hal.interface.workgroup.size[1] : index
+ %workgroup_id_x = hal.interface.workgroup.id[0] : index
+ %workgroup_count_x = hal.interface.workgroup.count[0] : index
+ %workgroup_id_y = hal.interface.workgroup.id[1] : index
+ %workgroup_count_y = hal.interface.workgroup.count[1] : index
+ %3 = affine.apply affine_map<()[s0, s1] -> (s0 * s1)>()[%workgroup_id_y, %workgroup_size_y]
+ %4 = affine.apply affine_map<()[s0, s1] -> (s0 * s1)>()[%workgroup_count_y, %workgroup_size_y]
+ scf.for %arg0 = %3 to %c128 step %4 {
+ %5 = affine.apply affine_map<()[s0, s1] -> (s0 * s1)>()[%workgroup_id_x, %workgroup_size_x]
+ %6 = affine.apply affine_map<()[s0, s1] -> (s0 * s1)>()[%workgroup_count_x, %workgroup_size_x]
+ scf.for %arg1 = %5 to %c1024 step %6 {
+ %7 = affine.min affine_map<(d0)[s0] -> (s0, -d0 + 128)>(%arg0)[%workgroup_size_y]
+ %8 = flow.dispatch.tensor.load %0, offsets = [%arg0, 0], sizes = [%7, 256], strides = [1, 1] : !flow.dispatch.tensor<readonly:128x256xf32> -> tensor<?x256xf32>
+ %9 = affine.min affine_map<(d0)[s0] -> (s0, -d0 + 1024)>(%arg1)[%workgroup_size_x]
+ %10 = flow.dispatch.tensor.load %1, offsets = [0, %arg1], sizes = [256, %9], strides = [1, 1] : !flow.dispatch.tensor<readonly:256x1024xf32> -> tensor<256x?xf32>
+ %11 = affine.min affine_map<(d0)[s0] -> (s0, -d0 + 128)>(%arg0)[%workgroup_size_y]
+ %12 = affine.min affine_map<(d0)[s0] -> (s0, -d0 + 1024)>(%arg1)[%workgroup_size_x]
+ %13 = affine.min affine_map<(d0)[s0] -> (-d0 + 128, s0)>(%arg0)[%workgroup_size_y]
+ %14 = affine.min affine_map<(d0)[s0] -> (-d0 + 1024, s0)>(%arg1)[%workgroup_size_x]
+ %15 = linalg.init_tensor [%13, %14] : tensor<?x?xf32>
+ %16 = linalg.fill(%cst, %15) : f32, tensor<?x?xf32> -> tensor<?x?xf32>
+ %17 = linalg.matmul {__internal_linalg_transform__ = "workgroup", lowering.config = {passPipeline = "LLVMGPUMatmulSimt", tileSizes = [[32, 256, 64], [], [4, 16]], workgroupSize = [16, 8, 1]}} ins(%8, %10 : tensor<?x256xf32>, tensor<256x?xf32>) outs(%16 : tensor<?x?xf32>) -> tensor<?x?xf32>
+ flow.dispatch.tensor.store %17, %2, offsets = [%arg0, %arg1], sizes = [%11, %12], strides = [1, 1] : tensor<?x?xf32> -> !flow.dispatch.tensor<writeonly:128x1024xf32>
+ }
+ }
+ return
+ }
+ hal.interface private @io {
+ hal.interface.binding public @s0b0_ro_external, set=0, binding=0, type="StorageBuffer", access="Read"
+ hal.interface.binding public @s0b1_ro_external, set=0, binding=1, type="StorageBuffer", access="Read"
+ hal.interface.binding public @s0b2_xw_external, set=0, binding=2, type="StorageBuffer", access="Write|Discard"
+ }
+ }
+}
+}
+
+// CHECK-DAG: #[[CONFIG:.+]] = {{{.*}}tileSizes = {{\[}}[32, 256, 64], [], [4, 16]{{\]}}}
+// CHECK: hal.executable.entry_point public @_lowering_config_test_dispatch_1
+// CHECK-SAME: passPipeline = "LLVMGPUMatmulSimt"
+// CHECK-SAME: workloadPerWorkgroup = [256, 32]
+// CHECK-SAME: workgroup_size = [16 : index, 8 : index, 1 : index]
+// CHECK: func @_lowering_config_test_dispatch_1
+// CHECK: linalg.fill
+// CHECK-SAME: lowering.config = #[[CONFIG]]
+// CHECK: linalg.matmul
+// CHECK-SAME: lowering.config = #[[CONFIG]]