[CUDA] Add support for lowering config coming from IR (#7082)

To enable search we support having the lowering config come from the
higher-level IR: when a compute op is already annotated with a
lowering.config attribute, that configuration is used directly and the
backend heuristics are bypassed.
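
For example, a tool driving the search can annotate a matmul with the
desired configuration directly in the IR (snippet lifted from the test
case added below; the concrete tile sizes and workgroup size are just one
sample point, not a recommended configuration):

    %17 = linalg.matmul {lowering.config = {
              passPipeline = "LLVMGPUMatmulSimt",
              tileSizes = [[32, 256, 64], [], [4, 16]],
              workgroupSize = [16, 8, 1]}}
          ins(%8, %10 : tensor<?x256xf32>, tensor<256x?xf32>)
          outs(%16 : tensor<?x?xf32>) -> tensor<?x?xf32>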
diff --git a/iree/compiler/Codegen/LLVMGPU/KernelConfig.cpp b/iree/compiler/Codegen/LLVMGPU/KernelConfig.cpp
index 5d2ddb0..cb554f2 100644
--- a/iree/compiler/Codegen/LLVMGPU/KernelConfig.cpp
+++ b/iree/compiler/Codegen/LLVMGPU/KernelConfig.cpp
@@ -221,7 +221,41 @@
       IREE::HAL::DispatchLoweringPassPipeline::LLVMGPUVectorize, workgroupSize);
 }
 
+/// Propagate the configuration annotated in the incoming IR.
+static LogicalResult setUserConfig(FuncOp entryPointFn, Operation *computeOp,
+                                   IREE::HAL::LoweringConfig config) {
+  IREE::HAL::DispatchLoweringPassPipeline passPipeline =
+      IREE::HAL::DispatchLoweringPassPipeline::LLVMGPUVectorize;
+  if (auto setPassPipeline = getLoweringPassPipeline(config)) {
+    passPipeline = setPassPipeline.getValue();
+  }
+  SmallVector<int64_t, 4> workgroupSize;
+  if (auto workgroupSizeAttr = config.workgroupSize()) {
+    workgroupSize = llvm::to_vector<4>(
+        llvm::map_range(workgroupSizeAttr, [](Attribute intAttr) {
+          return intAttr.cast<IntegerAttr>().getInt();
+        }));
+  }
+  if (failed(setOpConfigAndEntryPointFnTranslation(
+          entryPointFn, computeOp, config, passPipeline, workgroupSize))) {
+    return failure();
+  }
+  // Reset the op's configuration to drop the pass pipeline and workgroup size
+  // info; that information now lives on the entry point, not the op.
+  auto resetConfig = IREE::HAL::LoweringConfig::get(
+      config.tileSizes(), config.nativeVectorSize(),
+      /*passPipeline=*/nullptr,
+      /*workgroupSize=*/nullptr, computeOp->getContext());
+  setLoweringConfig(computeOp, resetConfig);
+  return success();
+}
+
 static LogicalResult setRootConfig(FuncOp entryPointFn, Operation *computeOp) {
+  if (IREE::HAL::LoweringConfig config = getLoweringConfig(computeOp)) {
+    // If the op already carries a lowering config coming from the IR, use it
+    // and bypass the heuristics.
+    return setUserConfig(entryPointFn, computeOp, config);
+  }
   if (auto linalgOp = dyn_cast<linalg::LinalgOp>(computeOp)) {
     if (linalg::isaContractionOpInterface(linalgOp) &&
         linalgOp.getNumParallelLoops() >= 2) {
diff --git a/iree/compiler/Codegen/LLVMGPU/test/gpu_set_num_workgroups.mlir b/iree/compiler/Codegen/LLVMGPU/test/gpu_set_num_workgroups.mlir
index 3665903..07cd247 100644
--- a/iree/compiler/Codegen/LLVMGPU/test/gpu_set_num_workgroups.mlir
+++ b/iree/compiler/Codegen/LLVMGPU/test/gpu_set_num_workgroups.mlir
@@ -364,3 +364,65 @@
 //       CHECK: func @static_3d_fft_stage3()
 //       CHECK:   linalg_ext.fft
 //  CHECK-SAME:     lowering.config = #[[CONFIG]]
+
+// -----
+
+hal.executable @user_config {
+hal.executable.variant public @cuda_nvptx_fb, target = #hal.executable.target<"cuda", "cuda-nvptx-fb"> {
+  hal.executable.entry_point public @_lowering_config_test_dispatch_1 attributes {interface = @io, ordinal = 0 : index}
+  builtin.module  {
+    func @_lowering_config_test_dispatch_1() {
+      %cst = constant 0.000000e+00 : f32
+      %c128 = constant 128 : index
+      %c1024 = constant 1024 : index
+      %c0 = constant 0 : index
+      %0 = hal.interface.binding.subspan @io::@s0b0_ro_external[%c0] : !flow.dispatch.tensor<readonly:128x256xf32>
+      %1 = hal.interface.binding.subspan @io::@s0b1_ro_external[%c0] : !flow.dispatch.tensor<readonly:256x1024xf32>
+      %2 = hal.interface.binding.subspan @io::@s0b2_xw_external[%c0] : !flow.dispatch.tensor<writeonly:128x1024xf32>
+      %workgroup_size_x = hal.interface.workgroup.size[0] : index
+      %workgroup_size_y = hal.interface.workgroup.size[1] : index
+      %workgroup_id_x = hal.interface.workgroup.id[0] : index
+      %workgroup_count_x = hal.interface.workgroup.count[0] : index
+      %workgroup_id_y = hal.interface.workgroup.id[1] : index
+      %workgroup_count_y = hal.interface.workgroup.count[1] : index
+      %3 = affine.apply affine_map<()[s0, s1] -> (s0 * s1)>()[%workgroup_id_y, %workgroup_size_y]
+      %4 = affine.apply affine_map<()[s0, s1] -> (s0 * s1)>()[%workgroup_count_y, %workgroup_size_y]
+      scf.for %arg0 = %3 to %c128 step %4 {
+        %5 = affine.apply affine_map<()[s0, s1] -> (s0 * s1)>()[%workgroup_id_x, %workgroup_size_x]
+        %6 = affine.apply affine_map<()[s0, s1] -> (s0 * s1)>()[%workgroup_count_x, %workgroup_size_x]
+        scf.for %arg1 = %5 to %c1024 step %6 {
+          %7 = affine.min affine_map<(d0)[s0] -> (s0, -d0 + 128)>(%arg0)[%workgroup_size_y]
+          %8 = flow.dispatch.tensor.load %0, offsets = [%arg0, 0], sizes = [%7, 256], strides = [1, 1] : !flow.dispatch.tensor<readonly:128x256xf32> -> tensor<?x256xf32>
+          %9 = affine.min affine_map<(d0)[s0] -> (s0, -d0 + 1024)>(%arg1)[%workgroup_size_x]
+          %10 = flow.dispatch.tensor.load %1, offsets = [0, %arg1], sizes = [256, %9], strides = [1, 1] : !flow.dispatch.tensor<readonly:256x1024xf32> -> tensor<256x?xf32>
+          %11 = affine.min affine_map<(d0)[s0] -> (s0, -d0 + 128)>(%arg0)[%workgroup_size_y]
+          %12 = affine.min affine_map<(d0)[s0] -> (s0, -d0 + 1024)>(%arg1)[%workgroup_size_x]
+          %13 = affine.min affine_map<(d0)[s0] -> (-d0 + 128, s0)>(%arg0)[%workgroup_size_y]
+          %14 = affine.min affine_map<(d0)[s0] -> (-d0 + 1024, s0)>(%arg1)[%workgroup_size_x]
+          %15 = linalg.init_tensor [%13, %14] : tensor<?x?xf32>
+          %16 = linalg.fill(%cst, %15) : f32, tensor<?x?xf32> -> tensor<?x?xf32>
+          %17 = linalg.matmul {__internal_linalg_transform__ = "workgroup", lowering.config = {passPipeline = "LLVMGPUMatmulSimt", tileSizes = [[32, 256, 64], [], [4, 16]], workgroupSize = [16, 8, 1]}} ins(%8, %10 : tensor<?x256xf32>, tensor<256x?xf32>) outs(%16 : tensor<?x?xf32>) -> tensor<?x?xf32>
+          flow.dispatch.tensor.store %17, %2, offsets = [%arg0, %arg1], sizes = [%11, %12], strides = [1, 1] : tensor<?x?xf32> -> !flow.dispatch.tensor<writeonly:128x1024xf32>
+        }
+      }
+      return
+    }
+    hal.interface private @io {
+      hal.interface.binding public @s0b0_ro_external, set=0, binding=0, type="StorageBuffer", access="Read"
+      hal.interface.binding public @s0b1_ro_external, set=0, binding=1, type="StorageBuffer", access="Read"
+      hal.interface.binding public @s0b2_xw_external, set=0, binding=2, type="StorageBuffer", access="Write|Discard"
+    }
+  }
+}
+}
+
+//  CHECK-DAG: #[[CONFIG:.+]] = {{{.*}}tileSizes = {{\[}}[32, 256, 64], [], [4, 16]{{\]}}}
+//      CHECK: hal.executable.entry_point public @_lowering_config_test_dispatch_1
+// CHECK-SAME:     passPipeline = "LLVMGPUMatmulSimt"
+// CHECK-SAME:     workloadPerWorkgroup = [256, 32]
+// CHECK-SAME:     workgroup_size = [16 : index, 8 : index, 1 : index]
+//      CHECK: func @_lowering_config_test_dispatch_1
+//      CHECK:   linalg.fill
+// CHECK-SAME:       lowering.config = #[[CONFIG]]
+//      CHECK:   linalg.matmul
+// CHECK-SAME:       lowering.config = #[[CONFIG]]