Move GPU ukernel selection to KernelConfig (#19440)

This moves the logic deciding whether an op should be a ukernel out of
the GPULowerToUKernels pass, into KernelConfig.

So KernelConfig decides whether the op should be a ukernel, and encodes
that into the resulting `lowering_config`, in a new parameter, that is a
new attribute, UKernelSpecAttr. That attribute is directly modeled after
the equivalent C++ data structure that we have had in LowerToUKernels
passes, `FnNameAndDefAttrs`, which it replaces. If the attribute is
present, it means that the op was selected for ukernel lowering, with
the fields telling the ukernel name and some function definition
attributes (to import any dependencies, such as the `rocm` module for
runtime support symbols).

All the details about supplying the ukernel bitcode in a
`hal.executable.object` are also moved there, becoming a side effect of
`KernelConfig`.

The GPULowerToUKernels becomes much simpler, since all the
decision-making was already done for it. It just looks at the
`LoweringConfigAttr` and if it's there, it performs the requested
lowering.

The motivation for this split is that we need to know in KernelConfig
whether it's going to be a ukernel, because ops that will get lowered to
a ukernel require a different configuration. The important example for
us is `multi_mma`, which in the ukernel case needs to avoid
reduction-dimension tiling to 1 so that the ukernel gets to see the
reduction loop.

A few simplifications arise already in the current argmax ukernel logic,
confirming that this was the right design choice: the old ukernel's
matching logic was checking that the distribution tile sizes matched
what the ukernel could handle; now that is turned upside down: the
ukernel matching happens as a helper within KernelConfig where we know
we are setting the appropriate tile sizes on purpose.

Another nice improvement is that this puts just enough distance between
ukernel selection (which creates the `hal.executable.object`) and
ukernel lowering, that we are able to insert
`HoistExecutableObjectsPass` in between, simplifying the ukernel
lowering as it doesn't need to worry anymore about preserving the
`hal.executable.object`.

---------

Signed-off-by: Benoit Jacob <jacob.benoit.1@gmail.com>
diff --git a/compiler/plugins/target/ROCM/test/BUILD.bazel b/compiler/plugins/target/ROCM/test/BUILD.bazel
index f0521a0..2a71f59 100644
--- a/compiler/plugins/target/ROCM/test/BUILD.bazel
+++ b/compiler/plugins/target/ROCM/test/BUILD.bazel
@@ -15,8 +15,9 @@
 iree_lit_test_suite(
     name = "lit",
     srcs = [
+        "config_ukernel_argmax_gfx908.mlir",
+        "config_ukernel_argmax_gfx942.mlir",
         "default_tuning_specs_amdgpu.mlir",
-        "gpu_lower_to_ukernels.mlir",
         "lowering_strategy_from_tuning_spec.mlir",
         "ukernel_pipeline_transform.mlir",
     ],
diff --git a/compiler/plugins/target/ROCM/test/CMakeLists.txt b/compiler/plugins/target/ROCM/test/CMakeLists.txt
index 36d9ba6..bab8858 100644
--- a/compiler/plugins/target/ROCM/test/CMakeLists.txt
+++ b/compiler/plugins/target/ROCM/test/CMakeLists.txt
@@ -14,8 +14,9 @@
   NAME
     lit
   SRCS
+    "config_ukernel_argmax_gfx908.mlir"
+    "config_ukernel_argmax_gfx942.mlir"
     "default_tuning_specs_amdgpu.mlir"
-    "gpu_lower_to_ukernels.mlir"
     "lowering_strategy_from_tuning_spec.mlir"
     "ukernel_pipeline_transform.mlir"
   TOOLS
diff --git a/compiler/plugins/target/ROCM/test/config_ukernel_argmax_gfx908.mlir b/compiler/plugins/target/ROCM/test/config_ukernel_argmax_gfx908.mlir
new file mode 100644
index 0000000..ba12bf5
--- /dev/null
+++ b/compiler/plugins/target/ROCM/test/config_ukernel_argmax_gfx908.mlir
@@ -0,0 +1,30 @@
+// RUN: iree-opt --split-input-file --iree-gpu-test-target=gfx908 --pass-pipeline='builtin.module(iree-llvmgpu-select-lowering-strategy)' %s | FileCheck %s
+
+// gfx908 a.k.a. CDNA1 is used here as an example of a GPU target that we don't have ukernels for.
+// No need to add many ukernels here, just a quick check that we correctly do not select a ukernel.
+
+func.func @argmax_2d_f32i64(%arg0 : tensor<1x?xf32>) -> tensor<1xi64> attributes {
+  hal.executable.target = #hal.executable.target<"rocm", "rocm-hsaco-fb", {ukernels = "all"}>
+} {
+  %c0_i64 = arith.constant 0 : i64
+  %cst = arith.constant 0xFF800000 : f32
+  %0 = tensor.empty() : tensor<1xi64>
+  %1 = linalg.fill ins(%c0_i64 : i64) outs(%0 : tensor<1xi64>) -> tensor<1xi64>
+  %2 = tensor.empty() : tensor<1xf32>
+  %3 = linalg.fill ins(%cst : f32) outs(%2 : tensor<1xf32>) -> tensor<1xf32>
+  %4:2 = linalg.generic {indexing_maps = [affine_map<(d0, d1) -> (d0, d1)>, affine_map<(d0, d1) -> (d0)>, affine_map<(d0, d1) -> (d0)>], iterator_types = ["parallel", "reduction"]} ins(%arg0 : tensor<1x?xf32>) outs(%3, %1 : tensor<1xf32>, tensor<1xi64>) {
+  ^bb0(%in: f32, %out: f32, %out_0: i64):
+    %5 = linalg.index 1 : index
+    %6 = arith.index_cast %5 : index to i64
+    %7 = arith.maximumf %in, %out : f32
+    %8 = arith.cmpf ogt, %in, %out : f32
+    %9 = arith.select %8, %6, %out_0 : i64
+    linalg.yield %7, %9 : f32, i64
+  } -> (tensor<1xf32>, tensor<1xi64>)
+  return %4#1 : tensor<1xi64>
+}
+
+//   CHECK-NOT: lowering_config<{{.*}}ukernel
+// CHECK-LABEL: func @argmax_2d_f32i64(
+//       CHECK: linalg.generic
+//   CHECK-NOT: hal.executable.objects
diff --git a/compiler/plugins/target/ROCM/test/config_ukernel_argmax_gfx942.mlir b/compiler/plugins/target/ROCM/test/config_ukernel_argmax_gfx942.mlir
new file mode 100644
index 0000000..4a7da4b
--- /dev/null
+++ b/compiler/plugins/target/ROCM/test/config_ukernel_argmax_gfx942.mlir
@@ -0,0 +1,242 @@
+// RUN: iree-opt --split-input-file --iree-gpu-test-target=gfx942 --pass-pipeline='builtin.module(iree-llvmgpu-select-lowering-strategy)' %s | FileCheck %s
+
+func.func @argmax_2d_f32i64(%arg0 : tensor<1x?xf32>) -> tensor<1xi64> attributes {
+  hal.executable.target = #hal.executable.target<"rocm", "rocm-hsaco-fb", {ukernels = "all"}>
+} {
+  %c0_i64 = arith.constant 0 : i64
+  %cst = arith.constant 0xFF800000 : f32
+  %0 = tensor.empty() : tensor<1xi64>
+  %1 = linalg.fill ins(%c0_i64 : i64) outs(%0 : tensor<1xi64>) -> tensor<1xi64>
+  %2 = tensor.empty() : tensor<1xf32>
+  %3 = linalg.fill ins(%cst : f32) outs(%2 : tensor<1xf32>) -> tensor<1xf32>
+  %4:2 = linalg.generic {indexing_maps = [affine_map<(d0, d1) -> (d0, d1)>, affine_map<(d0, d1) -> (d0)>, affine_map<(d0, d1) -> (d0)>], iterator_types = ["parallel", "reduction"]} ins(%arg0 : tensor<1x?xf32>) outs(%3, %1 : tensor<1xf32>, tensor<1xi64>) {
+  ^bb0(%in: f32, %out: f32, %out_0: i64):
+    %5 = linalg.index 1 : index
+    %6 = arith.index_cast %5 : index to i64
+    %7 = arith.maximumf %in, %out : f32
+    %8 = arith.cmpf ogt, %in, %out : f32
+    %9 = arith.select %8, %6, %out_0 : i64
+    linalg.yield %7, %9 : f32, i64
+  } -> (tensor<1xf32>, tensor<1xi64>)
+  return %4#1 : tensor<1xi64>
+}
+
+// CHECK-LABEL: func @argmax_2d_f32i64(
+//       CHECK: linalg.generic
+//  CHECK-SAME: hal.executable.objects = [
+//  CEHCK-SAME:   #hal.executable.object<{path = "iree_uk_amdgpu_argmax_f32i64.gfx942.bc", data = dense_resource<iree_uk_amdgpu_argmax_f32i64.gfx942.bc> : vector<{{[0-9]+}}xi8>}>]
+//  CHECK-SAME:   #iree_gpu.lowering_config<{{.*}}ukernel = #iree_gpu.ukernel_spec<name = "iree_uk_amdgpu_argmax_f32i64", def_attrs = {vm.import.module = "rocm"}>
+
+// -----
+
+func.func @argmax_4d_unit_parallel_f32i64(%arg0 : tensor<1x1x1x?xf32>) -> tensor<1x1x1xi64> attributes {
+  hal.executable.target = #hal.executable.target<"rocm", "rocm-hsaco-fb", {ukernels = "all"}>
+} {
+  %c0_i64 = arith.constant 0 : i64
+  %cst = arith.constant 0xFF800000 : f32
+  %0 = tensor.empty() : tensor<1x1x1xi64>
+  %1 = linalg.fill ins(%c0_i64 : i64) outs(%0 : tensor<1x1x1xi64>) -> tensor<1x1x1xi64>
+  %2 = tensor.empty() : tensor<1x1x1xf32>
+  %3 = linalg.fill ins(%cst : f32) outs(%2 : tensor<1x1x1xf32>) -> tensor<1x1x1xf32>
+  %4:2 = linalg.generic {indexing_maps = [affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)>, affine_map<(d0, d1, d2, d3) -> (d0, d1, d2)>, affine_map<(d0, d1, d2, d3) -> (d0, d1, d2)>], iterator_types = ["parallel", "parallel", "parallel", "reduction"]} ins(%arg0 : tensor<1x1x1x?xf32>) outs(%3, %1 : tensor<1x1x1xf32>, tensor<1x1x1xi64>) {
+  ^bb0(%in: f32, %out: f32, %out_0: i64):
+    %5 = linalg.index 1 : index
+    %6 = arith.index_cast %5 : index to i64
+    %7 = arith.maximumf %in, %out : f32
+    %8 = arith.cmpf ogt, %in, %out : f32
+    %9 = arith.select %8, %6, %out_0 : i64
+    linalg.yield %7, %9 : f32, i64
+  } -> (tensor<1x1x1xf32>, tensor<1x1x1xi64>)
+  return %4#1 : tensor<1x1x1xi64>
+}
+
+// CHECK-LABEL: func @argmax_4d_unit_parallel_f32i64(
+//       CHECK: linalg.generic
+//  CHECK-SAME: hal.executable.objects = [
+//  CEHCK-SAME:   #hal.executable.object<{path = "iree_uk_amdgpu_argmax_f32i64.gfx942.bc", data = dense_resource<iree_uk_amdgpu_argmax_f32i64.gfx942.bc> : vector<{{[0-9]+}}xi8>}>]
+//  CHECK-SAME:   #iree_gpu.lowering_config<{{.*}}ukernel = #iree_gpu.ukernel_spec<name = "iree_uk_amdgpu_argmax_f32i64", def_attrs = {vm.import.module = "rocm"}>
+
+// -----
+
+func.func @argmax_none_ukernel_enabled(%arg0 : tensor<1x?xf32>) -> tensor<1xi64> attributes {
+  hal.executable.target = #hal.executable.target<"rocm", "rocm-hsaco-fb", {ukernels = "none"}>
+} {
+  %c0_i64 = arith.constant 0 : i64
+  %cst = arith.constant 0xFF800000 : f32
+  %0 = tensor.empty() : tensor<1xi64>
+  %1 = linalg.fill ins(%c0_i64 : i64) outs(%0 : tensor<1xi64>) -> tensor<1xi64>
+  %2 = tensor.empty() : tensor<1xf32>
+  %3 = linalg.fill ins(%cst : f32) outs(%2 : tensor<1xf32>) -> tensor<1xf32>
+  %4:2 = linalg.generic {indexing_maps = [affine_map<(d0, d1) -> (d0, d1)>, affine_map<(d0, d1) -> (d0)>, affine_map<(d0, d1) -> (d0)>], iterator_types = ["parallel", "reduction"]} ins(%arg0 : tensor<1x?xf32>) outs(%3, %1 : tensor<1xf32>, tensor<1xi64>) {
+  ^bb0(%in: f32, %out: f32, %out_0: i64):
+    %5 = linalg.index 1 : index
+    %6 = arith.index_cast %5 : index to i64
+    %7 = arith.maximumf %in, %out : f32
+    %8 = arith.cmpf ogt, %in, %out : f32
+    %9 = arith.select %8, %6, %out_0 : i64
+    linalg.yield %7, %9 : f32, i64
+  } -> (tensor<1xf32>, tensor<1xi64>)
+  return %4#1 : tensor<1xi64>
+}
+
+// CHECK-LABEL: func @argmax_none_ukernel_enabled(
+//       CHECK: linalg.generic
+//   CHECK-NOT: hal.executable.objects
+//   CHECK-NOT: iree_gpu.ukernel_spec
+
+// -----
+
+func.func @argmax_only_argmax_ukernel_enabled(%arg0 : tensor<1x?xf32>) -> tensor<1xi64> attributes {
+  hal.executable.target = #hal.executable.target<"rocm", "rocm-hsaco-fb", {ukernels = "argmax"}>
+} {
+  %c0_i64 = arith.constant 0 : i64
+  %cst = arith.constant 0xFF800000 : f32
+  %0 = tensor.empty() : tensor<1xi64>
+  %1 = linalg.fill ins(%c0_i64 : i64) outs(%0 : tensor<1xi64>) -> tensor<1xi64>
+  %2 = tensor.empty() : tensor<1xf32>
+  %3 = linalg.fill ins(%cst : f32) outs(%2 : tensor<1xf32>) -> tensor<1xf32>
+  %4:2 = linalg.generic {indexing_maps = [affine_map<(d0, d1) -> (d0, d1)>, affine_map<(d0, d1) -> (d0)>, affine_map<(d0, d1) -> (d0)>], iterator_types = ["parallel", "reduction"]} ins(%arg0 : tensor<1x?xf32>) outs(%3, %1 : tensor<1xf32>, tensor<1xi64>) {
+  ^bb0(%in: f32, %out: f32, %out_0: i64):
+    %5 = linalg.index 1 : index
+    %6 = arith.index_cast %5 : index to i64
+    %7 = arith.maximumf %in, %out : f32
+    %8 = arith.cmpf ogt, %in, %out : f32
+    %9 = arith.select %8, %6, %out_0 : i64
+    linalg.yield %7, %9 : f32, i64
+  } -> (tensor<1xf32>, tensor<1xi64>)
+  return %4#1 : tensor<1xi64>
+}
+
+// CHECK-LABEL: func @argmax_only_argmax_ukernel_enabled(
+//       CHECK: linalg.generic
+//  CHECK-SAME: hal.executable.objects = [
+//  CHECK-SAME:   #hal.executable.object<{path = "iree_uk_amdgpu_argmax_f32i64.gfx942.bc", data = dense_resource<iree_uk_amdgpu_argmax_f32i64.gfx942.bc> : vector<{{[0-9]+}}xi8>}>]
+//  CHECK-SAME:   #iree_gpu.lowering_config<{{.*}}ukernel = #iree_gpu.ukernel_spec<name = "iree_uk_amdgpu_argmax_f32i64", def_attrs = {vm.import.module = "rocm"}>
+
+// -----
+
+func.func @argmax_only_foo_argmax_bar_ukernel_enabled(%arg0 : tensor<1x?xf32>) -> tensor<1xi64> attributes {
+  hal.executable.target = #hal.executable.target<"rocm", "rocm-hsaco-fb", {ukernels = "foo,argmax,bar"}>
+} {
+  %c0_i64 = arith.constant 0 : i64
+  %cst = arith.constant 0xFF800000 : f32
+  %0 = tensor.empty() : tensor<1xi64>
+  %1 = linalg.fill ins(%c0_i64 : i64) outs(%0 : tensor<1xi64>) -> tensor<1xi64>
+  %2 = tensor.empty() : tensor<1xf32>
+  %3 = linalg.fill ins(%cst : f32) outs(%2 : tensor<1xf32>) -> tensor<1xf32>
+  %4:2 = linalg.generic {indexing_maps = [affine_map<(d0, d1) -> (d0, d1)>, affine_map<(d0, d1) -> (d0)>, affine_map<(d0, d1) -> (d0)>], iterator_types = ["parallel", "reduction"]} ins(%arg0 : tensor<1x?xf32>) outs(%3, %1 : tensor<1xf32>, tensor<1xi64>) {
+  ^bb0(%in: f32, %out: f32, %out_0: i64):
+    %5 = linalg.index 1 : index
+    %6 = arith.index_cast %5 : index to i64
+    %7 = arith.maximumf %in, %out : f32
+    %8 = arith.cmpf ogt, %in, %out : f32
+    %9 = arith.select %8, %6, %out_0 : i64
+    linalg.yield %7, %9 : f32, i64
+  } -> (tensor<1xf32>, tensor<1xi64>)
+  return %4#1 : tensor<1xi64>
+}
+
+// CHECK-LABEL: func @argmax_only_foo_argmax_bar_ukernel_enabled(
+//       CHECK: linalg.generic
+//  CHECK-SAME: hal.executable.objects = [
+//  CHECK-SAME:   #hal.executable.object<{path = "iree_uk_amdgpu_argmax_f32i64.gfx942.bc", data = dense_resource<iree_uk_amdgpu_argmax_f32i64.gfx942.bc> : vector<{{[0-9]+}}xi8>}>]
+//  CHECK-SAME:   #iree_gpu.lowering_config<{{.*}}ukernel = #iree_gpu.ukernel_spec<name = "iree_uk_amdgpu_argmax_f32i64", def_attrs = {vm.import.module = "rocm"}>
+
+// -----
+
+func.func @argmax_only_foo_ukernel_enabled(%arg0 : tensor<1x?xf32>) -> tensor<1xi64> attributes {
+  hal.executable.target = #hal.executable.target<"rocm", "rocm-hsaco-fb", {ukernels = "foo"}>
+} {
+  %c0_i64 = arith.constant 0 : i64
+  %cst = arith.constant 0xFF800000 : f32
+  %0 = tensor.empty() : tensor<1xi64>
+  %1 = linalg.fill ins(%c0_i64 : i64) outs(%0 : tensor<1xi64>) -> tensor<1xi64>
+  %2 = tensor.empty() : tensor<1xf32>
+  %3 = linalg.fill ins(%cst : f32) outs(%2 : tensor<1xf32>) -> tensor<1xf32>
+  %4:2 = linalg.generic {indexing_maps = [affine_map<(d0, d1) -> (d0, d1)>, affine_map<(d0, d1) -> (d0)>, affine_map<(d0, d1) -> (d0)>], iterator_types = ["parallel", "reduction"]} ins(%arg0 : tensor<1x?xf32>) outs(%3, %1 : tensor<1xf32>, tensor<1xi64>) {
+  ^bb0(%in: f32, %out: f32, %out_0: i64):
+    %5 = linalg.index 1 : index
+    %6 = arith.index_cast %5 : index to i64
+    %7 = arith.maximumf %in, %out : f32
+    %8 = arith.cmpf ogt, %in, %out : f32
+    %9 = arith.select %8, %6, %out_0 : i64
+    linalg.yield %7, %9 : f32, i64
+  } -> (tensor<1xf32>, tensor<1xi64>)
+  return %4#1 : tensor<1xi64>
+}
+
+// CHECK-LABEL: func @argmax_only_foo_ukernel_enabled(
+//       CHECK: linalg.generic
+//   CHECK-NOT: hal.executable.objects
+//   CHECK-NOT: iree_gpu.ukernel_spec
+
+// -----
+
+// Currently we do only handle -Inf case as initial values.
+func.func @argmax_2d_f32i64_not_neg_inf_init(%arg0 : tensor<1x?xf32>) -> tensor<1xi64> attributes {
+  hal.executable.target = #hal.executable.target<"rocm", "rocm-hsaco-fb", {ukernels = "all"}>
+} {
+  %c0_i64 = arith.constant 0 : i64
+  %cst = arith.constant 0.0 : f32
+  %0 = tensor.empty() : tensor<1xi64>
+  %1 = linalg.fill ins(%c0_i64 : i64) outs(%0 : tensor<1xi64>) -> tensor<1xi64>
+  %2 = tensor.empty() : tensor<1xf32>
+  %3 = linalg.fill ins(%cst : f32) outs(%2 : tensor<1xf32>) -> tensor<1xf32>
+  %4:2 = linalg.generic {indexing_maps = [affine_map<(d0, d1) -> (d0, d1)>, affine_map<(d0, d1) -> (d0)>, affine_map<(d0, d1) -> (d0)>], iterator_types = ["parallel", "reduction"]} ins(%arg0 : tensor<1x?xf32>) outs(%3, %1 : tensor<1xf32>, tensor<1xi64>) {
+  ^bb0(%in: f32, %out: f32, %out_0: i64):
+    %5 = linalg.index 1 : index
+    %6 = arith.index_cast %5 : index to i64
+    %7 = arith.maximumf %in, %out : f32
+    %8 = arith.cmpf ogt, %in, %out : f32
+    %9 = arith.select %8, %6, %out_0 : i64
+    linalg.yield %7, %9 : f32, i64
+  } -> (tensor<1xf32>, tensor<1xi64>)
+  return %4#1 : tensor<1xi64>
+}
+
+//   CHECK-NOT: lowering_config<{{.*}}ukernel
+// CHECK-LABEL: func @argmax_2d_f32i64_not_neg_inf_init(
+//       CHECK: linalg.generic
+//   CHECK-NOT: hal.executable.objects
+
+// -----
+
+// Test user-provided bitcode in the source IR.
+
+func.func @argmax_2d_f32i64_custom_bitcode(%arg0 : tensor<1x?xf32>) -> tensor<1xi64> attributes {
+  hal.executable.target = #hal.executable.target<"rocm", "rocm-hsaco-fb", {ukernels = "all"}>,
+  // Dummy bitcode with an unusual length of 12. The first 4 bytes are the .bc file format signature.
+  hal.executable.objects = [
+    #hal.executable.object<{
+      path = "iree_uk_amdgpu_argmax_f32i64.gfx942.bc",
+      data = dense<[66, 67, -64, -34, 1, 35, 69, 103, -119, -85, -51, -17]> : tensor<12xi8>
+    }>
+  ]
+} {
+  %c0_i64 = arith.constant 0 : i64
+  %cst = arith.constant 0xFF800000 : f32
+  %0 = tensor.empty() : tensor<1xi64>
+  %1 = linalg.fill ins(%c0_i64 : i64) outs(%0 : tensor<1xi64>) -> tensor<1xi64>
+  %2 = tensor.empty() : tensor<1xf32>
+  %3 = linalg.fill ins(%cst : f32) outs(%2 : tensor<1xf32>) -> tensor<1xf32>
+  %4:2 = linalg.generic {indexing_maps = [affine_map<(d0, d1) -> (d0, d1)>, affine_map<(d0, d1) -> (d0)>, affine_map<(d0, d1) -> (d0)>], iterator_types = ["parallel", "reduction"]} ins(%arg0 : tensor<1x?xf32>) outs(%3, %1 : tensor<1xf32>, tensor<1xi64>) {
+  ^bb0(%in: f32, %out: f32, %out_0: i64):
+    %5 = linalg.index 1 : index
+    %6 = arith.index_cast %5 : index to i64
+    %7 = arith.maximumf %in, %out : f32
+    %8 = arith.cmpf ogt, %in, %out : f32
+    %9 = arith.select %8, %6, %out_0 : i64
+    linalg.yield %7, %9 : f32, i64
+  } -> (tensor<1xf32>, tensor<1xi64>)
+  return %4#1 : tensor<1xi64>
+}
+
+// CHECK-LABEL: func @argmax_2d_f32i64_custom_bitcode(
+//       CHECK:   linalg.generic
+//  CHECK-SAME:     hal.executable.objects = [
+//  CHECK-SAME:       #hal.executable.object<{
+//  CHECK-SAME:         path = "iree_uk_amdgpu_argmax_f32i64.gfx942.bc",
+//  CHECK-SAME:         data = dense<[66, 67, -64, -34, 1, 35, 69, 103, -119, -85, -51, -17]> : tensor<12xi8>
+//  CHECK-SAME:       }>
+//  CHECK-SAME:     ]
+//  CHECK-SAME: #iree_gpu.lowering_config<{{.*}}ukernel = #iree_gpu.ukernel_spec<name = "iree_uk_amdgpu_argmax_f32i64", def_attrs = {vm.import.module = "rocm"}>
diff --git a/compiler/plugins/target/ROCM/test/gpu_lower_to_ukernels.mlir b/compiler/plugins/target/ROCM/test/gpu_lower_to_ukernels.mlir
deleted file mode 100644
index 177bd0b..0000000
--- a/compiler/plugins/target/ROCM/test/gpu_lower_to_ukernels.mlir
+++ /dev/null
@@ -1,333 +0,0 @@
-// RUN: iree-opt --split-input-file --iree-gpu-test-target=gfx942 --pass-pipeline="builtin.module(func.func(iree-codegen-gpu-lower-to-ukernels,cse,canonicalize))" %s | FileCheck %s
-// RUN: iree-opt --split-input-file --iree-gpu-test-target=gfx908 --pass-pipeline="builtin.module(func.func(iree-codegen-gpu-lower-to-ukernels,cse,canonicalize))" %s | FileCheck %s --check-prefix=CDNA1
-
-func.func @argmax_2d_f32i64(%arg0 : tensor<1x?xf32>) -> tensor<1xi64> attributes {
-  hal.executable.target = #hal.executable.target<"rocm", "rocm-hsaco-fb", {ukernels = "all"}>
-} {
-  %c0_i64 = arith.constant 0 : i64
-  %cst = arith.constant 0xFF800000 : f32
-  %0 = tensor.empty() : tensor<1xi64>
-  %1 = linalg.fill ins(%c0_i64 : i64) outs(%0 : tensor<1xi64>) -> tensor<1xi64>
-  %2 = tensor.empty() : tensor<1xf32>
-  %3 = linalg.fill ins(%cst : f32) outs(%2 : tensor<1xf32>) -> tensor<1xf32>
-  %4:2 = linalg.generic {indexing_maps = [affine_map<(d0, d1) -> (d0, d1)>, affine_map<(d0, d1) -> (d0)>, affine_map<(d0, d1) -> (d0)>], iterator_types = ["parallel", "reduction"]} ins(%arg0 : tensor<1x?xf32>) outs(%3, %1 : tensor<1xf32>, tensor<1xi64>) {
-  ^bb0(%in: f32, %out: f32, %out_0: i64):
-    %5 = linalg.index 1 : index
-    %6 = arith.index_cast %5 : index to i64
-    %7 = arith.maximumf %in, %out : f32
-    %8 = arith.cmpf ogt, %in, %out : f32
-    %9 = arith.select %8, %6, %out_0 : i64
-    linalg.yield %7, %9 : f32, i64
-  } -> (tensor<1xf32>, tensor<1xi64>)
-  return %4#1 : tensor<1xi64>
-}
-
-//CHECK-LABEL: func @argmax_2d_f32i64(
-// CHECK-SAME:     %[[ARG0:[a-zA-Z0-9]+]]: tensor<1x?xf32>
-//  CHECK-DAG:   %[[C1_index:.+]] = arith.constant 1 : index
-//  CHECK-DAG:   %[[C0_i64:.+]] = arith.constant 0
-//  CHECK-DAG:   %[[FILL:.+]] = linalg.fill ins(%[[C0_i64]]
-//      CHECK:   %[[MICRO_KERNEL:.+]] = iree_codegen.ukernel.generic {hal.executable.objects = [{{.*}}]} "iree_uk_amdgpu_argmax_f32i64"
-// CHECK-SAME:       ins(%[[ARG0]] :
-// CHECK-SAME:       outs(%[[FILL]] :
-//      CHECK:   return %[[MICRO_KERNEL]]
-
-// -----
-
-func.func @argmax_4d_unit_parallel_f32i64(%arg0 : tensor<1x1x1x?xf32>) -> tensor<1x1x1xi64> attributes {
-  hal.executable.target = #hal.executable.target<"rocm", "rocm-hsaco-fb", {ukernels = "all"}>
-} {
-  %c0_i64 = arith.constant 0 : i64
-  %cst = arith.constant 0xFF800000 : f32
-  %0 = tensor.empty() : tensor<1x1x1xi64>
-  %1 = linalg.fill ins(%c0_i64 : i64) outs(%0 : tensor<1x1x1xi64>) -> tensor<1x1x1xi64>
-  %2 = tensor.empty() : tensor<1x1x1xf32>
-  %3 = linalg.fill ins(%cst : f32) outs(%2 : tensor<1x1x1xf32>) -> tensor<1x1x1xf32>
-  %4:2 = linalg.generic {indexing_maps = [affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)>, affine_map<(d0, d1, d2, d3) -> (d0, d1, d2)>, affine_map<(d0, d1, d2, d3) -> (d0, d1, d2)>], iterator_types = ["parallel", "parallel", "parallel", "reduction"]} ins(%arg0 : tensor<1x1x1x?xf32>) outs(%3, %1 : tensor<1x1x1xf32>, tensor<1x1x1xi64>) {
-  ^bb0(%in: f32, %out: f32, %out_0: i64):
-    %5 = linalg.index 1 : index
-    %6 = arith.index_cast %5 : index to i64
-    %7 = arith.maximumf %in, %out : f32
-    %8 = arith.cmpf ogt, %in, %out : f32
-    %9 = arith.select %8, %6, %out_0 : i64
-    linalg.yield %7, %9 : f32, i64
-  } -> (tensor<1x1x1xf32>, tensor<1x1x1xi64>)
-  return %4#1 : tensor<1x1x1xi64>
-}
-
-//      CHECK-LABEL: func @argmax_4d_unit_parallel_f32i64(
-//      CHECK: iree_codegen.ukernel.generic
-//      CHECK-NOT: linalg.generic
-
-// -----
-
-func.func @argmax_2d_non_unit_parallel_f32i64(%arg0 : tensor<4x?xf32>) -> tensor<4xi64> attributes {
-  hal.executable.target = #hal.executable.target<"rocm", "rocm-hsaco-fb", {ukernels = "all"}>
-} {
-  %c0_i64 = arith.constant 0 : i64
-  %cst = arith.constant 0xFF800000 : f32
-  %0 = tensor.empty() : tensor<4xi64>
-  %1 = linalg.fill ins(%c0_i64 : i64) outs(%0 : tensor<4xi64>) -> tensor<4xi64>
-  %2 = tensor.empty() : tensor<4xf32>
-  %3 = linalg.fill ins(%cst : f32) outs(%2 : tensor<4xf32>) -> tensor<4xf32>
-  %4:2 = linalg.generic {indexing_maps = [affine_map<(d0, d1) -> (d0, d1)>, affine_map<(d0, d1) -> (d0)>, affine_map<(d0, d1) -> (d0)>], iterator_types = ["parallel", "reduction"]} ins(%arg0 : tensor<4x?xf32>) outs(%3, %1 : tensor<4xf32>, tensor<4xi64>) {
-  ^bb0(%in: f32, %out: f32, %out_0: i64):
-    %5 = linalg.index 1 : index
-    %6 = arith.index_cast %5 : index to i64
-    %7 = arith.maximumf %in, %out : f32
-    %8 = arith.cmpf ogt, %in, %out : f32
-    %9 = arith.select %8, %6, %out_0 : i64
-    linalg.yield %7, %9 : f32, i64
-  } -> (tensor<4xf32>, tensor<4xi64>)
-  return %4#1 : tensor<4xi64>
-}
-
-//      CHECK-LABEL: func @argmax_2d_non_unit_parallel_f32i64(
-//      CHECK-NOT: iree_codegen.ukernel.generic
-//      CHECK: linalg.generic
-
-// -----
-
-func.func @argmax_2d_dyn_parallel_f32i64(%arg0 : tensor<?x?xf32>) -> tensor<?xi64> attributes {
-  hal.executable.target = #hal.executable.target<"rocm", "rocm-hsaco-fb", {ukernels = "all"}>
-} {
-  %c0 = arith.constant 0 : index
-  %c0_i64 = arith.constant 0 : i64
-  %cst = arith.constant 0xFF800000 : f32
-  %dim = tensor.dim %arg0, %c0 : tensor<?x?xf32>
-  %0 = tensor.empty(%dim) : tensor<?xi64>
-  %1 = linalg.fill ins(%c0_i64 : i64) outs(%0 : tensor<?xi64>) -> tensor<?xi64>
-  %2 = tensor.empty(%dim) : tensor<?xf32>
-  %3 = linalg.fill ins(%cst : f32) outs(%2 : tensor<?xf32>) -> tensor<?xf32>
-  %4:2 = linalg.generic {indexing_maps = [affine_map<(d0, d1) -> (d0, d1)>, affine_map<(d0, d1) -> (d0)>, affine_map<(d0, d1) -> (d0)>], iterator_types = ["parallel", "reduction"]} ins(%arg0 : tensor<?x?xf32>) outs(%3, %1 : tensor<?xf32>, tensor<?xi64>) {
-  ^bb0(%in: f32, %out: f32, %out_0: i64):
-    %5 = linalg.index 1 : index
-    %6 = arith.index_cast %5 : index to i64
-    %7 = arith.maximumf %in, %out : f32
-    %8 = arith.cmpf ogt, %in, %out : f32
-    %9 = arith.select %8, %6, %out_0 : i64
-    linalg.yield %7, %9 : f32, i64
-  } -> (tensor<?xf32>, tensor<?xi64>)
-  return %4#1 : tensor<?xi64>
-}
-
-//      CHECK-LABEL: func @argmax_2d_dyn_parallel_f32i64(
-//      CHECK-NOT: iree_codegen.ukernel.generic
-//      CHECK: linalg.generic
-
-// -----
-
-func.func @argmax_none_ukernel_enabled(%arg0 : tensor<1x?xf32>) -> tensor<1xi64> attributes {
-  hal.executable.target = #hal.executable.target<"rocm", "rocm-hsaco-fb", {ukernels = "none"}>
-} {
-  %c0_i64 = arith.constant 0 : i64
-  %cst = arith.constant 0xFF800000 : f32
-  %0 = tensor.empty() : tensor<1xi64>
-  %1 = linalg.fill ins(%c0_i64 : i64) outs(%0 : tensor<1xi64>) -> tensor<1xi64>
-  %2 = tensor.empty() : tensor<1xf32>
-  %3 = linalg.fill ins(%cst : f32) outs(%2 : tensor<1xf32>) -> tensor<1xf32>
-  %4:2 = linalg.generic {indexing_maps = [affine_map<(d0, d1) -> (d0, d1)>, affine_map<(d0, d1) -> (d0)>, affine_map<(d0, d1) -> (d0)>], iterator_types = ["parallel", "reduction"]} ins(%arg0 : tensor<1x?xf32>) outs(%3, %1 : tensor<1xf32>, tensor<1xi64>) {
-  ^bb0(%in: f32, %out: f32, %out_0: i64):
-    %5 = linalg.index 1 : index
-    %6 = arith.index_cast %5 : index to i64
-    %7 = arith.maximumf %in, %out : f32
-    %8 = arith.cmpf ogt, %in, %out : f32
-    %9 = arith.select %8, %6, %out_0 : i64
-    linalg.yield %7, %9 : f32, i64
-  } -> (tensor<1xf32>, tensor<1xi64>)
-  return %4#1 : tensor<1xi64>
-}
-
-//      CHECK-LABEL: func @argmax_none_ukernel_enabled(
-//      CHECK-NOT: iree_codegen.ukernel.generic
-//      CHECK: linalg.generic
-
-// -----
-
-func.func @argmax_only_argmax_ukernel_enabled(%arg0 : tensor<1x?xf32>) -> tensor<1xi64> attributes {
-  hal.executable.target = #hal.executable.target<"rocm", "rocm-hsaco-fb", {ukernels = "argmax"}>
-} {
-  %c0_i64 = arith.constant 0 : i64
-  %cst = arith.constant 0xFF800000 : f32
-  %0 = tensor.empty() : tensor<1xi64>
-  %1 = linalg.fill ins(%c0_i64 : i64) outs(%0 : tensor<1xi64>) -> tensor<1xi64>
-  %2 = tensor.empty() : tensor<1xf32>
-  %3 = linalg.fill ins(%cst : f32) outs(%2 : tensor<1xf32>) -> tensor<1xf32>
-  %4:2 = linalg.generic {indexing_maps = [affine_map<(d0, d1) -> (d0, d1)>, affine_map<(d0, d1) -> (d0)>, affine_map<(d0, d1) -> (d0)>], iterator_types = ["parallel", "reduction"]} ins(%arg0 : tensor<1x?xf32>) outs(%3, %1 : tensor<1xf32>, tensor<1xi64>) {
-  ^bb0(%in: f32, %out: f32, %out_0: i64):
-    %5 = linalg.index 1 : index
-    %6 = arith.index_cast %5 : index to i64
-    %7 = arith.maximumf %in, %out : f32
-    %8 = arith.cmpf ogt, %in, %out : f32
-    %9 = arith.select %8, %6, %out_0 : i64
-    linalg.yield %7, %9 : f32, i64
-  } -> (tensor<1xf32>, tensor<1xi64>)
-  return %4#1 : tensor<1xi64>
-}
-
-//      CDNA2-LABEL: func @argmax_only_argmax_ukernel_enabled(
-//      CDNA2: iree_codegen.ukernel.generic
-//      CDNA2-NOT: linalg.generic
-
-// -----
-
-func.func @argmax_only_foo_argmax_bar_ukernel_enabled(%arg0 : tensor<1x?xf32>) -> tensor<1xi64> attributes {
-  hal.executable.target = #hal.executable.target<"rocm", "rocm-hsaco-fb", {ukernels = "foo,argmax,bar"}>
-} {
-  %c0_i64 = arith.constant 0 : i64
-  %cst = arith.constant 0xFF800000 : f32
-  %0 = tensor.empty() : tensor<1xi64>
-  %1 = linalg.fill ins(%c0_i64 : i64) outs(%0 : tensor<1xi64>) -> tensor<1xi64>
-  %2 = tensor.empty() : tensor<1xf32>
-  %3 = linalg.fill ins(%cst : f32) outs(%2 : tensor<1xf32>) -> tensor<1xf32>
-  %4:2 = linalg.generic {indexing_maps = [affine_map<(d0, d1) -> (d0, d1)>, affine_map<(d0, d1) -> (d0)>, affine_map<(d0, d1) -> (d0)>], iterator_types = ["parallel", "reduction"]} ins(%arg0 : tensor<1x?xf32>) outs(%3, %1 : tensor<1xf32>, tensor<1xi64>) {
-  ^bb0(%in: f32, %out: f32, %out_0: i64):
-    %5 = linalg.index 1 : index
-    %6 = arith.index_cast %5 : index to i64
-    %7 = arith.maximumf %in, %out : f32
-    %8 = arith.cmpf ogt, %in, %out : f32
-    %9 = arith.select %8, %6, %out_0 : i64
-    linalg.yield %7, %9 : f32, i64
-  } -> (tensor<1xf32>, tensor<1xi64>)
-  return %4#1 : tensor<1xi64>
-}
-
-//      CHECK-LABEL: func @argmax_only_foo_argmax_bar_ukernel_enabled(
-//      CHECK: iree_codegen.ukernel.generic
-//      CHECK-NOT: linalg.generic
-
-//      CDNA2-LABEL: func @argmax_only_foo_argmax_bar_ukernel_enabled(
-
-// -----
-
-func.func @argmax_only_foo_ukernel_enabled(%arg0 : tensor<1x?xf32>) -> tensor<1xi64> attributes {
-  hal.executable.target = #hal.executable.target<"rocm", "rocm-hsaco-fb", {ukernels = "foo"}>
-} {
-  %c0_i64 = arith.constant 0 : i64
-  %cst = arith.constant 0xFF800000 : f32
-  %0 = tensor.empty() : tensor<1xi64>
-  %1 = linalg.fill ins(%c0_i64 : i64) outs(%0 : tensor<1xi64>) -> tensor<1xi64>
-  %2 = tensor.empty() : tensor<1xf32>
-  %3 = linalg.fill ins(%cst : f32) outs(%2 : tensor<1xf32>) -> tensor<1xf32>
-  %4:2 = linalg.generic {indexing_maps = [affine_map<(d0, d1) -> (d0, d1)>, affine_map<(d0, d1) -> (d0)>, affine_map<(d0, d1) -> (d0)>], iterator_types = ["parallel", "reduction"]} ins(%arg0 : tensor<1x?xf32>) outs(%3, %1 : tensor<1xf32>, tensor<1xi64>) {
-  ^bb0(%in: f32, %out: f32, %out_0: i64):
-    %5 = linalg.index 1 : index
-    %6 = arith.index_cast %5 : index to i64
-    %7 = arith.maximumf %in, %out : f32
-    %8 = arith.cmpf ogt, %in, %out : f32
-    %9 = arith.select %8, %6, %out_0 : i64
-    linalg.yield %7, %9 : f32, i64
-  } -> (tensor<1xf32>, tensor<1xi64>)
-  return %4#1 : tensor<1xi64>
-}
-
-//      CHECK-LABEL: func @argmax_only_foo_ukernel_enabled(
-//      CHECK-NOT: iree_codegen.ukernel.generic
-//      CHECK: linalg.generic
-
-// -----
-
-// Currently we do only handle -Inf case as initial values.
-func.func @argmax_2d_f32i64_not_neg_inf_init(%arg0 : tensor<1x?xf32>) -> tensor<1xi64> attributes {
-  hal.executable.target = #hal.executable.target<"rocm", "rocm-hsaco-fb", {ukernels = "all"}>
-} {
-  %c0_i64 = arith.constant 0 : i64
-  %cst = arith.constant 0.0 : f32
-  %0 = tensor.empty() : tensor<1xi64>
-  %1 = linalg.fill ins(%c0_i64 : i64) outs(%0 : tensor<1xi64>) -> tensor<1xi64>
-  %2 = tensor.empty() : tensor<1xf32>
-  %3 = linalg.fill ins(%cst : f32) outs(%2 : tensor<1xf32>) -> tensor<1xf32>
-  %4:2 = linalg.generic {indexing_maps = [affine_map<(d0, d1) -> (d0, d1)>, affine_map<(d0, d1) -> (d0)>, affine_map<(d0, d1) -> (d0)>], iterator_types = ["parallel", "reduction"]} ins(%arg0 : tensor<1x?xf32>) outs(%3, %1 : tensor<1xf32>, tensor<1xi64>) {
-  ^bb0(%in: f32, %out: f32, %out_0: i64):
-    %5 = linalg.index 1 : index
-    %6 = arith.index_cast %5 : index to i64
-    %7 = arith.maximumf %in, %out : f32
-    %8 = arith.cmpf ogt, %in, %out : f32
-    %9 = arith.select %8, %6, %out_0 : i64
-    linalg.yield %7, %9 : f32, i64
-  } -> (tensor<1xf32>, tensor<1xi64>)
-  return %4#1 : tensor<1xi64>
-}
-
-//      CHECK-LABEL: func @argmax_2d_f32i64_not_neg_inf_init(
-//      CHECK-NOT: iree_codegen.ukernel.generic
-//      CHECK: linalg.generic
-
-// -----
-
-// TODO: No technical reason this architecture is not supported.
-//       Currently just picking out popular chips to support,
-//       to minimize compile time and space.
-
-func.func @argmax_ukernel_unsupported_arch(%arg0 : tensor<1x?xf32>) -> tensor<1xi64> attributes {
-  hal.executable.target = #hal.executable.target<"rocm", "rocm-hsaco-fb", {ukernels = "all"}>
-} {
-  %c0_i64 = arith.constant 0 : i64
-  %cst = arith.constant 0xFF800000 : f32
-  %0 = tensor.empty() : tensor<1xi64>
-  %1 = linalg.fill ins(%c0_i64 : i64) outs(%0 : tensor<1xi64>) -> tensor<1xi64>
-  %2 = tensor.empty() : tensor<1xf32>
-  %3 = linalg.fill ins(%cst : f32) outs(%2 : tensor<1xf32>) -> tensor<1xf32>
-  %4:2 = linalg.generic {indexing_maps = [affine_map<(d0, d1) -> (d0, d1)>, affine_map<(d0, d1) -> (d0)>, affine_map<(d0, d1) -> (d0)>], iterator_types = ["parallel", "reduction"]} ins(%arg0 : tensor<1x?xf32>) outs(%3, %1 : tensor<1xf32>, tensor<1xi64>) {
-  ^bb0(%in: f32, %out: f32, %out_0: i64):
-    %5 = linalg.index 1 : index
-    %6 = arith.index_cast %5 : index to i64
-    %7 = arith.maximumf %in, %out : f32
-    %8 = arith.cmpf ogt, %in, %out : f32
-    %9 = arith.select %8, %6, %out_0 : i64
-    linalg.yield %7, %9 : f32, i64
-  } -> (tensor<1xf32>, tensor<1xi64>)
-  return %4#1 : tensor<1xi64>
-}
-
-//      CDNA1-LABEL: func @argmax_ukernel_unsupported_arch(
-//      CDNA1-NOT: iree_codegen.ukernel.generic
-//      CDNA1: linalg.generic
-
-// -----
-
-// Test user-provided bitcode in the source IR.
-
-func.func @argmax_2d_f32i64(%arg0 : tensor<1x?xf32>) -> tensor<1xi64> attributes {
-  hal.executable.target = #hal.executable.target<"rocm", "rocm-hsaco-fb", {ukernels = "all"}>,
-  // Dummy bitcode with an unusual length of 12. The first 4 bytes are the .bc file format signature.
-  hal.executable.objects = [
-    #hal.executable.object<{
-      path = "iree_uk_amdgpu_argmax_f32i64.gfx942.bc",
-      data = dense<[66, 67, -64, -34, 1, 35, 69, 103, -119, -85, -51, -17]> : tensor<12xi8>
-    }>
-  ]
-} {
-  %c0_i64 = arith.constant 0 : i64
-  %cst = arith.constant 0xFF800000 : f32
-  %0 = tensor.empty() : tensor<1xi64>
-  %1 = linalg.fill ins(%c0_i64 : i64) outs(%0 : tensor<1xi64>) -> tensor<1xi64>
-  %2 = tensor.empty() : tensor<1xf32>
-  %3 = linalg.fill ins(%cst : f32) outs(%2 : tensor<1xf32>) -> tensor<1xf32>
-  %4:2 = linalg.generic {indexing_maps = [affine_map<(d0, d1) -> (d0, d1)>, affine_map<(d0, d1) -> (d0)>, affine_map<(d0, d1) -> (d0)>], iterator_types = ["parallel", "reduction"]} ins(%arg0 : tensor<1x?xf32>) outs(%3, %1 : tensor<1xf32>, tensor<1xi64>) {
-  ^bb0(%in: f32, %out: f32, %out_0: i64):
-    %5 = linalg.index 1 : index
-    %6 = arith.index_cast %5 : index to i64
-    %7 = arith.maximumf %in, %out : f32
-    %8 = arith.cmpf ogt, %in, %out : f32
-    %9 = arith.select %8, %6, %out_0 : i64
-    linalg.yield %7, %9 : f32, i64
-  } -> (tensor<1xf32>, tensor<1xi64>)
-  return %4#1 : tensor<1xi64>
-}
-
-//CHECK-LABEL: func @argmax_2d_f32i64(
-// CHECK-SAME:     %[[ARG0:[a-zA-Z0-9]+]]: tensor<1x?xf32>
-//  CHECK-DAG:   %[[C1_index:.+]] = arith.constant 1 : index
-//  CHECK-DAG:   %[[C0_i64:.+]] = arith.constant 0
-//  CHECK-DAG:   %[[FILL:.+]] = linalg.fill ins(%[[C0_i64]]
-//      CHECK:   %[[MICRO_KERNEL:.+]] = iree_codegen.ukernel.generic {
-// CHECK-SAME:     hal.executable.objects = [
-// CHECK-SAME:       #hal.executable.object<{
-// CHECK-SAME:         path = "iree_uk_amdgpu_argmax_f32i64.gfx942.bc",
-// CHECK-SAME:         data = dense<[66, 67, -64, -34, 1, 35, 69, 103, -119, -85, -51, -17]> : tensor<12xi8>
-// CHECK-SAME:       }>
-// CHECK-SAME:     ]} "iree_uk_amdgpu_argmax_f32i64"
-// CHECK-SAME:       ins(%[[ARG0]] :
-// CHECK-SAME:       outs(%[[FILL]] :
-//      CHECK:   return %[[MICRO_KERNEL]]
diff --git a/compiler/plugins/target/ROCM/test/ukernel_pipeline_transform.mlir b/compiler/plugins/target/ROCM/test/ukernel_pipeline_transform.mlir
index 26ce4c8..15e5169 100644
--- a/compiler/plugins/target/ROCM/test/ukernel_pipeline_transform.mlir
+++ b/compiler/plugins/target/ROCM/test/ukernel_pipeline_transform.mlir
@@ -44,7 +44,7 @@
 //       CHECK: #[[$TRANSLATION:.+]] = #iree_codegen.translation_info<pipeline = LLVMGPUDefault workgroup_size = [32, 1, 1]>
 //       CHECK: func.func @argmax_1d_f16i64()
 //  CHECK-SAME:     translation_info = #[[$TRANSLATION]]
-//       CHECK:   iree_codegen.ukernel.generic {hal.executable.objects = [{{.*}}]} "iree_uk_amdgpu_argmax_f16i64"
+//       CHECK:   iree_codegen.ukernel.generic "iree_uk_amdgpu_argmax_f16i64"
 
 // -----
 
@@ -94,7 +94,7 @@
 // CHECK-SAME:     translation_info = #[[$TRANSLATION]]
 //      CHECK:   %[[SUBVIEW:.*]] = memref.subview{{.*}} memref<16x?xf32
 // CHECK-SAME:        to memref<1x?xf32
-//      CHECK:   iree_codegen.ukernel.generic {hal.executable.objects = [{{.*}}]} "iree_uk_amdgpu_argmax_f32i64" ins(%[[SUBVIEW]]
+//      CHECK:   iree_codegen.ukernel.generic "iree_uk_amdgpu_argmax_f32i64" ins(%[[SUBVIEW]]
 
 // -----
 
diff --git a/compiler/src/iree/compiler/Codegen/Common/GPU/GPULowerToUKernels.cpp b/compiler/src/iree/compiler/Codegen/Common/GPU/GPULowerToUKernels.cpp
index c9ff4b8..796138d 100644
--- a/compiler/src/iree/compiler/Codegen/Common/GPU/GPULowerToUKernels.cpp
+++ b/compiler/src/iree/compiler/Codegen/Common/GPU/GPULowerToUKernels.cpp
@@ -5,12 +5,11 @@
 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
 
 #include "iree/compiler/Codegen/Common/GPU/Passes.h"
+#include "iree/compiler/Codegen/Dialect/Codegen/IR/IREECodegenAttrs.h"
 #include "iree/compiler/Codegen/Dialect/Codegen/IR/IREECodegenDialect.h"
 #include "iree/compiler/Codegen/Dialect/Codegen/IR/UKernelOps.h"
-#include "iree/compiler/Codegen/Utils/GPUUtils.h"
+#include "iree/compiler/Codegen/Dialect/GPU/IR/GPULoweringConfigUtils.h"
 #include "iree/compiler/Codegen/Utils/Utils.h"
-#include "iree/compiler/Utils/EmbeddedDataDirectory.h"
-#include "llvm/Support/FormatVariadic.h"
 #include "mlir/Dialect/Linalg/IR/Linalg.h"
 #include "mlir/Dialect/Linalg/Utils/Utils.h"
 #include "mlir/Dialect/Tensor/IR/Tensor.h"
@@ -27,114 +26,12 @@
 
 namespace {
 
-// Returns a ExecutableObjectAttr carrying the bitcode for the given ukernel.
-//
-// First tries finding the bitcode in the input `sourceExecutableObjects`, which
-// must be an array of ExecutableObjectAttr's and is typically coming from a
-// hal.executable.objects array attribute in the source IR, which is the
-// mechanism by which source programs may provide their own ukernel bitcode.
-//
-// If no matching bitcode was found in `sourceExecutableObjects`, this function
-// will then search in bitcode files that we have embedded as static data.
-static IREE::HAL::ExecutableObjectAttr
-getUKernelBitcode(OpBuilder &builder,
-                  IREE::HAL::ExecutableTargetAttr execTarget,
-                  ArrayAttr sourceExecutableObjects, StringRef ukernelName) {
-  IREE::GPU::TargetAttr gpuTarget = getGPUTargetAttr(execTarget);
-  if (!gpuTarget) {
-    return {};
-  }
-  StringRef gpuArch = gpuTarget.getArch();
-  std::string bitcodeFilename = llvm::formatv("{}.{}.bc", ukernelName, gpuArch);
-
-  // Early-return if the source executable.objects already contain an object
-  // with the expected file name. This happens with user-provided bitcode in the
-  // source IR.
-  if (sourceExecutableObjects) {
-    for (Attribute a : sourceExecutableObjects) {
-      if (auto object = dyn_cast<IREE::HAL::ExecutableObjectAttr>(a)) {
-        if (object.getPath() == bitcodeFilename) {
-          return object;
-        }
-      }
-    }
-  }
-
-  // No user-provided bitcode, so we search our embedded bitcode files in the
-  // EmbeddedDataDirectory singleton.
-  std::optional<StringRef> bitcode;
-  EmbeddedDataDirectory::withGlobal([&](EmbeddedDataDirectory &dir) {
-    bitcode = dir.getFile(bitcodeFilename);
-  });
-  if (!bitcode) {
-    return {};
-  }
-  MLIRContext *context = builder.getContext();
-  auto blob = HeapAsmResourceBlob::allocateAndCopyInferAlign(
-      ArrayRef<char>(bitcode->data(), bitcode->size()));
-  auto bitcodeDenseAttr = DenseI8ResourceElementsAttr::get(
-      VectorType::get({static_cast<int64_t>(bitcode->size())},
-                      builder.getI8Type()),
-      bitcodeFilename, std::move(blob));
-  return IREE::HAL::ExecutableObjectAttr::get(
-      context, StringAttr::get(context, bitcodeFilename),
-      cast<IREE::Util::SerializableAttrInterface>(bitcodeDenseAttr));
-}
-
-// Walks parents ops from `op` to return the nearest hal.executable.objects
-// array attribute. If the parent hal.executable.variant is reached, its objects
-// attribute is returned.
-// Adapted from ExecutableTargetAttr::lookup.
-static ArrayAttr lookUpExecutableObjects(Operation *op) {
-  MLIRContext *context = op->getContext();
-  auto attrId = StringAttr::get(context, "hal.executable.objects");
-  while (op) {
-    // Take directly from the enclosing variant.
-    if (auto variantOp = dyn_cast<IREE::HAL::ExecutableVariantOp>(op)) {
-      if (std::optional<ArrayAttr> objects = variantOp.getObjects()) {
-        return *objects;
-      }
-    }
-    // Take from op attributes.
-    if (auto attr = op->getAttrOfType<ArrayAttr>(attrId)) {
-      return attr;
-    }
-    // Continue walk.
-    op = op->getParentOp();
-  }
-  return {};
-}
-
-/// Holds a function name and attributes.
-struct FnNameAndDefAttrs {
-  std::string name;
-  SmallVector<NamedAttribute> defAttrs;
-  explicit operator bool() const { return !name.empty(); }
-};
-
-/// Returns the function name and attributes to use for a ukernel with given
-/// `name` and `suffix` on the target described by `targetAttr`.
-static FnNameAndDefAttrs
-getFnNameAndDefAttrs(const char *name, std::string &suffix,
-                     RewriterBase &rewriter,
-                     IREE::HAL::ExecutableTargetAttr targetAttr) {
-  FnNameAndDefAttrs result;
-  if (isROCMBackend(targetAttr)) {
-    result.name = llvm::formatv("iree_uk_amdgpu_{}_{}", name, suffix);
-    result.defAttrs.emplace_back(rewriter.getStringAttr("vm.import.module"),
-                                 rewriter.getStringAttr("rocm"));
-  }
-  return result;
-}
-
 /// Matches generic that represent argmax and check if
 /// we have the ukernel that matches it shape constraint, and types.
 /// If we do, then we convert into iree_codegen.ukernel.argmax operation,
 /// that is later lowered into a call to the microkernel.
 static FailureOr<IREE::Codegen::UKernelOpInterface>
 matchArgmaxDAGForUKernel(RewriterBase &rewriter, linalg::GenericOp op) {
-  auto targetAttr = IREE::HAL::ExecutableTargetAttr::lookup(op);
-  const char ukernelName[] = "argmax";
   Value input = op.getDpsInputOperand(0)->get();
   auto inputType = cast<ShapedType>(input.getType());
   Value index = op.getDpsInitOperand(1)->get();
@@ -142,41 +39,16 @@
   std::string suffix;
   llvm::raw_string_ostream(suffix)
       << inputType.getElementType() << indexType.getElementType();
-  FnNameAndDefAttrs fn =
-      getFnNameAndDefAttrs(ukernelName, suffix, rewriter, targetAttr);
-  if (!fn) {
-    return rewriter.notifyMatchFailure(op, "no ukernels on this backend");
+  auto loweringConfig = getLoweringConfig<IREE::GPU::LoweringConfigAttr>(op);
+  if (!loweringConfig) {
+    return rewriter.notifyMatchFailure(op, "no lowering_config on this op");
+  }
+  IREE::GPU::UKernelSpecAttr ukernelAttr =
+      IREE::GPU::getUkernelSpec(loweringConfig);
+  if (!ukernelAttr) {
+    return rewriter.notifyMatchFailure(op, "no ukernel selected for this op");
   }
 
-  if (!hasUkernel(targetAttr, ukernelName)) {
-    return rewriter.notifyMatchFailure(op, "ukernel not enabled");
-  }
-
-  // Currently only support argmax where parallel dims are 1.
-  // Tiling pipeline is also set to tile all parallel dims to 1, and
-  // reduction dim to be size of whole reduction problem. Which allow
-  // this constraint to be true for a lot of argmax variances.
-  // TODO: Support multi-row or grid-strided argmax ukernel.
-  SmallVector<int64_t, 4> bounds = op.getStaticLoopRanges();
-  SmallVector<unsigned> parallelDims;
-  op.getParallelDims(parallelDims);
-  int64_t parallelSize = 1;
-  for (int64_t dim : parallelDims) {
-    if (ShapedType::isDynamic(bounds[dim])) {
-      return failure();
-    }
-    parallelSize *= bounds[dim];
-  }
-  if (parallelSize != 1) {
-    return failure();
-  }
-  auto execTarget = IREE::HAL::ExecutableTargetAttr::lookup(op);
-  ArrayAttr sourceExecutableObjects = lookUpExecutableObjects(op);
-  IREE::HAL::ExecutableObjectAttr bitcodeObject =
-      getUKernelBitcode(rewriter, execTarget, sourceExecutableObjects, fn.name);
-  if (!bitcodeObject) {
-    return rewriter.notifyMatchFailure(op, "no ukernel bitcode for this op");
-  }
   Location loc = op.getLoc();
   // Currently only support 1D reduction, where reduc is on fastest dim.
   // Tiling argmax ukernel is also set to enforce this structure.
@@ -184,13 +56,9 @@
   Value reductionDimSize =
       rewriter.create<tensor::DimOp>(loc, input, kReductionDim);
   auto genericMicroKernelOp = rewriter.create<IREE::Codegen::UKernelGenericOp>(
-      loc, indexType, fn.name, ValueRange{input}, index,
-      ValueRange{reductionDimSize},
-      /*fn_def_attrs=*/rewriter.getDictionaryAttr(fn.defAttrs),
+      loc, indexType, ukernelAttr.getName(), ValueRange{input}, index,
+      ValueRange{reductionDimSize}, ukernelAttr.getDefAttrs(),
       /*strided_outer_dims=*/rewriter.getIndexAttr(0));
-  genericMicroKernelOp->setAttr(
-      "hal.executable.objects",
-      ArrayAttr::get(rewriter.getContext(), bitcodeObject));
   return cast<IREE::Codegen::UKernelOpInterface>(
       genericMicroKernelOp.getOperation());
 }
diff --git a/compiler/src/iree/compiler/Codegen/Common/GPU/Passes.td b/compiler/src/iree/compiler/Codegen/Common/GPU/Passes.td
index b3fdd50..2c25e02 100644
--- a/compiler/src/iree/compiler/Codegen/Common/GPU/Passes.td
+++ b/compiler/src/iree/compiler/Codegen/Common/GPU/Passes.td
@@ -107,7 +107,7 @@
 
 def GPULowerToUKernelsPass :
     Pass<"iree-codegen-gpu-lower-to-ukernels", ""> {
-  let summary = "Lower suitable ops to microkernels.";
+  let summary = "Lower suitable ops to previously-selected microkernels";
   let dependentDialects = [
     "::mlir::iree_compiler::IREE::Codegen::IREECodegenDialect",
     "::mlir::iree_compiler::IREE::GPU::IREEGPUDialect",
diff --git a/compiler/src/iree/compiler/Codegen/Common/GPU/test/BUILD.bazel b/compiler/src/iree/compiler/Codegen/Common/GPU/test/BUILD.bazel
index dc8e6a1..030e6f4 100644
--- a/compiler/src/iree/compiler/Codegen/Common/GPU/test/BUILD.bazel
+++ b/compiler/src/iree/compiler/Codegen/Common/GPU/test/BUILD.bazel
@@ -31,6 +31,7 @@
             "gpu_greedily_distribute_to_threads.mlir",
             "gpu_infer_memory_space.mlir",
             "gpu_combine_value_barriers.mlir",
+            "gpu_lower_to_ukernels.mlir",
             "gpu_materialize_encoding_gfx908.mlir",
             "gpu_materialize_encoding_gfx90a.mlir",
             "gpu_materialize_encoding_gfx942.mlir",
diff --git a/compiler/src/iree/compiler/Codegen/Common/GPU/test/CMakeLists.txt b/compiler/src/iree/compiler/Codegen/Common/GPU/test/CMakeLists.txt
index 4dc0f28..6d1f540 100644
--- a/compiler/src/iree/compiler/Codegen/Common/GPU/test/CMakeLists.txt
+++ b/compiler/src/iree/compiler/Codegen/Common/GPU/test/CMakeLists.txt
@@ -26,6 +26,7 @@
     "gpu_generalize_named_ops.mlir"
     "gpu_greedily_distribute_to_threads.mlir"
     "gpu_infer_memory_space.mlir"
+    "gpu_lower_to_ukernels.mlir"
     "gpu_materialize_encoding_gfx1100.mlir"
     "gpu_materialize_encoding_gfx908.mlir"
     "gpu_materialize_encoding_gfx90a.mlir"
diff --git a/compiler/src/iree/compiler/Codegen/Common/GPU/test/gpu_lower_to_ukernels.mlir b/compiler/src/iree/compiler/Codegen/Common/GPU/test/gpu_lower_to_ukernels.mlir
new file mode 100644
index 0000000..6a13468
--- /dev/null
+++ b/compiler/src/iree/compiler/Codegen/Common/GPU/test/gpu_lower_to_ukernels.mlir
@@ -0,0 +1,72 @@
+// RUN: iree-opt --split-input-file --pass-pipeline="builtin.module(func.func(iree-codegen-gpu-lower-to-ukernels,cse,canonicalize))" %s | FileCheck %s
+
+#config = #iree_gpu.lowering_config<{ukernel = #iree_gpu.ukernel_spec<name = "some_ukernel", def_attrs = {vm.import.module = "rocm"}>}>
+func.func @argmax_f32i64_with_selected_ukernel(%arg0 : tensor<1x?xf32>) -> tensor<1xi64> attributes {
+  hal.executable.target = #hal.executable.target<"rocm", "rocm-hsaco-fb", {ukernels = "all"}>
+} {
+  %c0_i64 = arith.constant 0 : i64
+  %cst = arith.constant 0xFF800000 : f32
+  %0 = tensor.empty() : tensor<1xi64>
+  %1 = linalg.fill ins(%c0_i64 : i64) outs(%0 : tensor<1xi64>) -> tensor<1xi64>
+  %2 = tensor.empty() : tensor<1xf32>
+  %3 = linalg.fill ins(%cst : f32) outs(%2 : tensor<1xf32>) -> tensor<1xf32>
+  %4:2 = linalg.generic {
+        indexing_maps = [affine_map<(d0, d1) -> (d0, d1)>, affine_map<(d0, d1) -> (d0)>, affine_map<(d0, d1) -> (d0)>],
+        iterator_types = ["parallel", "reduction"]
+      }
+      ins(%arg0 : tensor<1x?xf32>) outs(%3, %1 : tensor<1xf32>, tensor<1xi64>)
+      attrs = {
+        // The lowering_config.ukernel is what is essential to the lowering.
+        lowering_config = #config} {
+  ^bb0(%in: f32, %out: f32, %out_0: i64):
+    %5 = linalg.index 1 : index
+    %6 = arith.index_cast %5 : index to i64
+    %7 = arith.maximumf %in, %out : f32
+    %8 = arith.cmpf ogt, %in, %out : f32
+    %9 = arith.select %8, %6, %out_0 : i64
+    linalg.yield %7, %9 : f32, i64
+  } -> (tensor<1xf32>, tensor<1xi64>)
+  return %4#1 : tensor<1xi64>
+}
+
+//CHECK-LABEL: func @argmax_f32i64_with_selected_ukernel(
+// CHECK-SAME:     %[[ARG0:[a-zA-Z0-9]+]]: tensor<1x?xf32>
+//  CHECK-DAG:   %[[C1_index:.+]] = arith.constant 1 : index
+//  CHECK-DAG:   %[[C0_i64:.+]] = arith.constant 0
+//  CHECK-DAG:   %[[FILL:.+]] = linalg.fill ins(%[[C0_i64]]
+//      CHECK:   %[[MICRO_KERNEL:.+]] = iree_codegen.ukernel.generic
+//  CHECK-SAME:      "some_ukernel"
+// CHECK-SAME:       ins(%[[ARG0]] :
+// CHECK-SAME:       outs(%[[FILL]] :
+//      CHECK:   return %[[MICRO_KERNEL]]
+
+// -----
+
+func.func @argmax_f32i64_without_selected_ukernel(%arg0 : tensor<1x?xf32>) -> tensor<1xi64> attributes {
+  hal.executable.target = #hal.executable.target<"rocm", "rocm-hsaco-fb", {ukernels = "all"}>
+} {
+  %c0_i64 = arith.constant 0 : i64
+  %cst = arith.constant 0xFF800000 : f32
+  %0 = tensor.empty() : tensor<1xi64>
+  %1 = linalg.fill ins(%c0_i64 : i64) outs(%0 : tensor<1xi64>) -> tensor<1xi64>
+  %2 = tensor.empty() : tensor<1xf32>
+  %3 = linalg.fill ins(%cst : f32) outs(%2 : tensor<1xf32>) -> tensor<1xf32>
+  %4:2 = linalg.generic {
+        indexing_maps = [affine_map<(d0, d1) -> (d0, d1)>, affine_map<(d0, d1) -> (d0)>, affine_map<(d0, d1) -> (d0)>],
+        iterator_types = ["parallel", "reduction"]
+      }
+      ins(%arg0 : tensor<1x?xf32>) outs(%3, %1 : tensor<1xf32>, tensor<1xi64>) {
+  ^bb0(%in: f32, %out: f32, %out_0: i64):
+    %5 = linalg.index 1 : index
+    %6 = arith.index_cast %5 : index to i64
+    %7 = arith.maximumf %in, %out : f32
+    %8 = arith.cmpf ogt, %in, %out : f32
+    %9 = arith.select %8, %6, %out_0 : i64
+    linalg.yield %7, %9 : f32, i64
+  } -> (tensor<1xf32>, tensor<1xi64>)
+  return %4#1 : tensor<1xi64>
+}
+
+//CHECK-LABEL: func @argmax_f32i64_without_selected_ukernel(
+//      CHECK-NOT: iree_codegen.ukernel.generic
+//      CHECK: linalg.generic
diff --git a/compiler/src/iree/compiler/Codegen/Dialect/GPU/IR/GPULoweringConfigUtils.cpp b/compiler/src/iree/compiler/Codegen/Dialect/GPU/IR/GPULoweringConfigUtils.cpp
index 6957caf..8ebfba9 100644
--- a/compiler/src/iree/compiler/Codegen/Dialect/GPU/IR/GPULoweringConfigUtils.cpp
+++ b/compiler/src/iree/compiler/Codegen/Dialect/GPU/IR/GPULoweringConfigUtils.cpp
@@ -145,4 +145,9 @@
   return getIntegerVector(array);
 }
 
+IREE::GPU::UKernelSpecAttr
+getUkernelSpec(IREE::GPU::LoweringConfigAttr config) {
+  return config.getAttributes().getAs<IREE::GPU::UKernelSpecAttr>("ukernel");
+}
+
 } // namespace mlir::iree_compiler::IREE::GPU
diff --git a/compiler/src/iree/compiler/Codegen/Dialect/GPU/IR/GPULoweringConfigUtils.h b/compiler/src/iree/compiler/Codegen/Dialect/GPU/IR/GPULoweringConfigUtils.h
index c1188b7..5bebb64 100644
--- a/compiler/src/iree/compiler/Codegen/Dialect/GPU/IR/GPULoweringConfigUtils.h
+++ b/compiler/src/iree/compiler/Codegen/Dialect/GPU/IR/GPULoweringConfigUtils.h
@@ -59,6 +59,8 @@
 /// Helper to retrieve  list of operand to pad.
 std::optional<SmallVector<int64_t>> getPaddingList(LoweringConfigAttr config);
 
+IREE::GPU::UKernelSpecAttr getUkernelSpec(IREE::GPU::LoweringConfigAttr config);
+
 } // namespace mlir::iree_compiler::IREE::GPU
 
 #endif // IREE_COMPILER_CODEGEN_DIALECT_GPU_IR_GPULOWERINGCONFIGUTILS_H_
diff --git a/compiler/src/iree/compiler/Codegen/Dialect/GPU/IR/IREEGPUAttrs.td b/compiler/src/iree/compiler/Codegen/Dialect/GPU/IR/IREEGPUAttrs.td
index a239af3..0b1e32f 100644
--- a/compiler/src/iree/compiler/Codegen/Dialect/GPU/IR/IREEGPUAttrs.td
+++ b/compiler/src/iree/compiler/Codegen/Dialect/GPU/IR/IREEGPUAttrs.td
@@ -520,6 +520,25 @@
   }];
 }
 
+//===---------------------------------------------------------------------===//
+// iree_gpu.ukernel_spec
+//===---------------------------------------------------------------------===//
+
+def IREEGPU_UKernelSpecAttr  :
+    AttrDef<IREEGPU_Dialect, "UKernelSpec", []> {
+  let mnemonic = "ukernel_spec";
+  let summary = "An attribute specifying a ukernel that an op can lower to.";
+  let description = [{
+    An attribute that can be applied to any operation to specify that it has
+    been match with a ukernel that is a legal lowering for it.
+  }];
+  let assemblyFormat = "`<` struct(params) `>`";
+  let parameters = (ins
+       "StringAttr":$name,
+       "DictionaryAttr":$def_attrs
+  );
+}
+
 //===----------------------------------------------------------------------===//
 // GPU Pipeline Options
 //===----------------------------------------------------------------------===//
diff --git a/compiler/src/iree/compiler/Codegen/LLVMGPU/BUILD.bazel b/compiler/src/iree/compiler/Codegen/LLVMGPU/BUILD.bazel
index 73e0397..a5c1bce 100644
--- a/compiler/src/iree/compiler/Codegen/LLVMGPU/BUILD.bazel
+++ b/compiler/src/iree/compiler/Codegen/LLVMGPU/BUILD.bazel
@@ -147,6 +147,7 @@
         "//compiler/src/iree/compiler/Dialect/Flow/IR",
         "//compiler/src/iree/compiler/Dialect/Flow/Transforms",
         "//compiler/src/iree/compiler/Dialect/HAL/IR",
+        "//compiler/src/iree/compiler/Dialect/HAL/Transforms",
         "//compiler/src/iree/compiler/Dialect/LinalgExt/IR",
         "//compiler/src/iree/compiler/Dialect/LinalgExt/Transforms",
         "//compiler/src/iree/compiler/Dialect/LinalgExt/Utils",
diff --git a/compiler/src/iree/compiler/Codegen/LLVMGPU/CMakeLists.txt b/compiler/src/iree/compiler/Codegen/LLVMGPU/CMakeLists.txt
index b33641b..5c20621 100644
--- a/compiler/src/iree/compiler/Codegen/LLVMGPU/CMakeLists.txt
+++ b/compiler/src/iree/compiler/Codegen/LLVMGPU/CMakeLists.txt
@@ -190,6 +190,7 @@
     iree::compiler::Dialect::Flow::IR
     iree::compiler::Dialect::Flow::Transforms
     iree::compiler::Dialect::HAL::IR
+    iree::compiler::Dialect::HAL::Transforms
     iree::compiler::Dialect::LinalgExt::IR
     iree::compiler::Dialect::LinalgExt::Transforms
     iree::compiler::Dialect::LinalgExt::Utils
diff --git a/compiler/src/iree/compiler/Codegen/LLVMGPU/KernelConfig.cpp b/compiler/src/iree/compiler/Codegen/LLVMGPU/KernelConfig.cpp
index cb22b59..ee4614d 100644
--- a/compiler/src/iree/compiler/Codegen/LLVMGPU/KernelConfig.cpp
+++ b/compiler/src/iree/compiler/Codegen/LLVMGPU/KernelConfig.cpp
@@ -10,6 +10,7 @@
 #include <numeric>
 #include <optional>
 
+#include "compiler/src/iree/compiler/Codegen/LLVMGPU/Utils/LLVMGPUSelectUKernels.h"
 #include "iree/compiler/Codegen/Common/GPU/GPUHeuristics.h"
 #include "iree/compiler/Codegen/Dialect/Codegen/IR/IREECodegenAttrs.h"
 #include "iree/compiler/Codegen/Dialect/GPU/IR/GPULoweringConfigUtils.h"
@@ -2042,28 +2043,15 @@
 /// Set the configuration for argmax when ukernels are enabled.
 /// Distribute all parallel dim across different workgroups, and only use single
 /// subgroup per workgroup.
-///
-/// TODO(bjacob): This is fragile, as we can't know yet if this argmax will be
-/// lowered to a ukernel. We need instead a config that works regardless of
-/// ukernels. For now, we use the looser condition that the argmax ukernel is
-/// enabled, a necessary but not sufficient condition for this particular op to
-/// lower to the ukernel. This is good enough for now for a couple of reasons:
-/// 1. Even if a argmax does not actually lower to a ukernel, this config should
-///    still work.
-/// 2. Ukernels are not enabled by default.
 static LogicalResult
 setArgmaxUkernelConfig(IREE::GPU::TargetAttr target,
                        mlir::FunctionOpInterface entryPoint,
                        linalg::GenericOp op) {
   // Checks if UKernels are enabled.
-  if (auto target = IREE::HAL::ExecutableTargetAttr::lookup(entryPoint)) {
-    if (!hasUkernel(target, "argmax")) {
-      return failure();
-    }
-  }
-
-  if (!target.supportsSubgroupShuffle())
+  IREE::GPU::UKernelSpecAttr ukernelSpec = selectUKernelForArgmax(op);
+  if (!ukernelSpec) {
     return failure();
+  }
 
   if (failed(isArgmaxOp(op)))
     return failure();
@@ -2094,26 +2082,35 @@
     return failure();
   }
 
-  // Tile all the parallel dimension to 1.
+  // Tile all the parallel dimension to 1. This is a requirement of the ukernel.
   SmallVector<unsigned> partitionedLoops =
       cast<PartitionableLoopsInterface>(op.getOperation())
           .getPartitionableLoops(kNumMaxParallelDims);
   size_t numLoops = partitionedLoops.empty() ? 0 : partitionedLoops.back() + 1;
   SmallVector<int64_t> workgroupTileSizes(numLoops, 1);
 
-  // Currently Argmax Ukernel let's every thread reduce reductionDim/WarpSize
+  // Currently Argmax Ukernel lets every thread reduce reductionDim/WarpSize
   // number of elements, and then it does a single step butterfly warp reduce.
   // Hence it expects workgroupSize to be warpSize(subgroupSize), and
   // reductionTileSize to be size of the reduction dim.
   SmallVector<int64_t> reductionTileSizes(op.getNumLoops(), 0);
   int64_t preferredSubgroupSize = target.getPreferredSubgroupSize();
   reductionTileSizes[reductionDims[0]] = preferredSubgroupSize;
-  TileSizesListType tileSizes;
-  tileSizes.emplace_back(std::move(workgroupTileSizes)); // Workgroup level
-  tileSizes.emplace_back(std::move(reductionTileSizes)); // Reduction level
   std::array<int64_t, 3> workgroupSize = {preferredSubgroupSize, 1, 1};
+
+  MLIRContext *context = op->getContext();
+  Builder b(context);
+  SmallVector<NamedAttribute, 2> attrs;
+  attrs.emplace_back(StringAttr::get(context, "workgroup"),
+                     b.getI64ArrayAttr(workgroupTileSizes));
+  attrs.emplace_back(StringAttr::get(context, "reduction"),
+                     b.getI64ArrayAttr(reductionTileSizes));
+  attrs.emplace_back(StringAttr::get(context, "ukernel"), ukernelSpec);
+  IREE::GPU::setPromotedOperandList(context, attrs, {0, 1});
+  auto configDict = DictionaryAttr::get(context, attrs);
+  auto loweringConfig = IREE::GPU::LoweringConfigAttr::get(context, configDict);
   if (failed(setOpConfigAndEntryPointFnTranslation(
-          entryPoint, op, tileSizes, CodeGenPipeline::LLVMGPUDefault,
+          entryPoint, op, loweringConfig, CodeGenPipeline::LLVMGPUDefault,
           workgroupSize))) {
     return failure();
   }
diff --git a/compiler/src/iree/compiler/Codegen/LLVMGPU/Passes.cpp b/compiler/src/iree/compiler/Codegen/LLVMGPU/Passes.cpp
index f8ebe1c..b6414e1 100644
--- a/compiler/src/iree/compiler/Codegen/LLVMGPU/Passes.cpp
+++ b/compiler/src/iree/compiler/Codegen/LLVMGPU/Passes.cpp
@@ -21,6 +21,7 @@
 #include "iree/compiler/Codegen/Utils/Utils.h"
 #include "iree/compiler/Dialect/Flow/IR/FlowOps.h"
 #include "iree/compiler/Dialect/HAL/IR/HALTypes.h"
+#include "iree/compiler/Dialect/HAL/Transforms/Passes.h"
 #include "iree/compiler/Dialect/Util/Transforms/Passes.h"
 #include "iree/compiler/Utils/PassUtils.h"
 #include "llvm/ADT/STLForwardCompat.h"
@@ -1197,6 +1198,10 @@
 
 void buildLLVMGPUCodegenPassPipeline(OpPassManager &variantPassManager,
                                      bool useROCM) {
+  // LLVMGPUSelectLoweringStrategyPass may have created ExecutableObjectAttr.
+  // Hoisting them now deduplicates them and ensures that rewrite patterns don't
+  // need to think about explicitly copying them over to new ops.
+  variantPassManager.addPass(IREE::HAL::createHoistExecutableObjectsPass());
   {
     OpPassManager &modulePassManager = variantPassManager.nest<ModuleOp>();
     modulePassManager.addPass(createLowerExecutableUsingTransformDialectPass());
diff --git a/compiler/src/iree/compiler/Codegen/LLVMGPU/Utils/BUILD.bazel b/compiler/src/iree/compiler/Codegen/LLVMGPU/Utils/BUILD.bazel
index 113c6d5..66bd982 100644
--- a/compiler/src/iree/compiler/Codegen/LLVMGPU/Utils/BUILD.bazel
+++ b/compiler/src/iree/compiler/Codegen/LLVMGPU/Utils/BUILD.bazel
@@ -17,10 +17,12 @@
 iree_compiler_cc_library(
     name = "Utils",
     srcs = [
+        "LLVMGPUSelectUKernels.cpp",
         "LLVMGPUUtils.cpp",
         "PrefetchSharedMemoryCopy.cpp",
     ],
     hdrs = [
+        "LLVMGPUSelectUKernels.h",
         "LLVMGPUUtils.h",
     ],
     deps = [
@@ -34,6 +36,7 @@
         "//compiler/src/iree/compiler/Codegen/Utils:VectorOpUtils",
         "//compiler/src/iree/compiler/Dialect/HAL/IR",
         "//compiler/src/iree/compiler/Dialect/LinalgExt/Utils",
+        "//compiler/src/iree/compiler/Utils",
         "@llvm-project//llvm:Support",
         "@llvm-project//mlir:AMDGPUDialect",
         "@llvm-project//mlir:AffineDialect",
@@ -42,6 +45,7 @@
         "@llvm-project//mlir:FunctionInterfaces",
         "@llvm-project//mlir:GPUDialect",
         "@llvm-project//mlir:IR",
+        "@llvm-project//mlir:LinalgDialect",
         "@llvm-project//mlir:MathDialect",
         "@llvm-project//mlir:MemRefDialect",
         "@llvm-project//mlir:NVGPUDialect",
diff --git a/compiler/src/iree/compiler/Codegen/LLVMGPU/Utils/CMakeLists.txt b/compiler/src/iree/compiler/Codegen/LLVMGPU/Utils/CMakeLists.txt
index 6b66e96..98ee940 100644
--- a/compiler/src/iree/compiler/Codegen/LLVMGPU/Utils/CMakeLists.txt
+++ b/compiler/src/iree/compiler/Codegen/LLVMGPU/Utils/CMakeLists.txt
@@ -14,8 +14,10 @@
   NAME
     Utils
   HDRS
+    "LLVMGPUSelectUKernels.h"
     "LLVMGPUUtils.h"
   SRCS
+    "LLVMGPUSelectUKernels.cpp"
     "LLVMGPUUtils.cpp"
     "PrefetchSharedMemoryCopy.cpp"
   DEPS
@@ -27,6 +29,7 @@
     MLIRFunctionInterfaces
     MLIRGPUDialect
     MLIRIR
+    MLIRLinalgDialect
     MLIRMathDialect
     MLIRMemRefDialect
     MLIRNVGPUDialect
@@ -45,6 +48,7 @@
     iree::compiler::Codegen::Utils::VectorOpUtils
     iree::compiler::Dialect::HAL::IR
     iree::compiler::Dialect::LinalgExt::Utils
+    iree::compiler::Utils
   PUBLIC
 )
 
diff --git a/compiler/src/iree/compiler/Codegen/LLVMGPU/Utils/LLVMGPUSelectUKernels.cpp b/compiler/src/iree/compiler/Codegen/LLVMGPU/Utils/LLVMGPUSelectUKernels.cpp
new file mode 100644
index 0000000..1940e8f
--- /dev/null
+++ b/compiler/src/iree/compiler/Codegen/LLVMGPU/Utils/LLVMGPUSelectUKernels.cpp
@@ -0,0 +1,152 @@
+// Copyright 2024 The IREE Authors
+//
+// Licensed under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+
+#include "iree/compiler/Codegen/LLVMGPU/Utils/LLVMGPUSelectUKernels.h"
+#include "iree/compiler/Codegen/Utils/GPUUtils.h"
+#include "iree/compiler/Codegen/Utils/Utils.h"
+#include "iree/compiler/Utils/EmbeddedDataDirectory.h"
+#include "llvm/Support/FormatVariadic.h"
+#include "mlir/Dialect/Linalg/IR/Linalg.h"
+#include "mlir/IR/AsmState.h"
+#include "mlir/IR/Attributes.h"
+#include "mlir/IR/BuiltinAttributes.h"
+
+namespace mlir::iree_compiler {
+
+namespace {
+
+constexpr StringLiteral executableObjectsAttrName = "hal.executable.objects";
+
+// Returns a ExecutableObjectAttr carrying the bitcode for the given ukernel.
+//
+// First tries finding the bitcode in the input `sourceExecutableObjects`, which
+// must be an array of ExecutableObjectAttr's and is typically coming from a
+// hal.executable.objects array attribute in the source IR, which is the
+// mechanism by which source programs may provide their own ukernel bitcode.
+//
+// If no matching bitcode was found in `sourceExecutableObjects`, this function
+// will then search in bitcode files that we have embedded as static data.
+static IREE::HAL::ExecutableObjectAttr
+getUKernelBitcode(MLIRContext *context,
+                  IREE::HAL::ExecutableTargetAttr execTarget,
+                  ArrayAttr sourceExecutableObjects, StringRef ukernelName) {
+  IREE::GPU::TargetAttr gpuTarget = getGPUTargetAttr(execTarget);
+  if (!gpuTarget) {
+    return {};
+  }
+  StringRef gpuArch = gpuTarget.getArch();
+  std::string bitcodeFilename = llvm::formatv("{}.{}.bc", ukernelName, gpuArch);
+
+  // Early-return if the source executable.objects already contain an object
+  // with the expected file name. This happens with user-provided bitcode in the
+  // source IR.
+  if (sourceExecutableObjects) {
+    for (Attribute a : sourceExecutableObjects) {
+      if (auto object = dyn_cast<IREE::HAL::ExecutableObjectAttr>(a)) {
+        if (object.getPath() == bitcodeFilename) {
+          return object;
+        }
+      }
+    }
+  }
+
+  // No user-provided bitcode, so we search our embedded bitcode files in the
+  // EmbeddedDataDirectory singleton.
+  std::optional<StringRef> bitcode;
+  EmbeddedDataDirectory::withGlobal([&](EmbeddedDataDirectory &dir) {
+    bitcode = dir.getFile(bitcodeFilename);
+  });
+  if (!bitcode) {
+    return {};
+  }
+  auto blob = HeapAsmResourceBlob::allocateAndCopyInferAlign(
+      ArrayRef<char>(bitcode->data(), bitcode->size()));
+  auto bitcodeDenseAttr = DenseI8ResourceElementsAttr::get(
+      VectorType::get({static_cast<int64_t>(bitcode->size())},
+                      IntegerType::get(context, 8)),
+      bitcodeFilename, std::move(blob));
+  return IREE::HAL::ExecutableObjectAttr::get(
+      context, StringAttr::get(context, bitcodeFilename),
+      cast<IREE::Util::SerializableAttrInterface>(bitcodeDenseAttr));
+}
+
+// Walks parents ops from `op` to return the nearest hal.executable.objects
+// array attribute. If the parent hal.executable.variant is reached, its objects
+// attribute is returned.
+// Adapted from ExecutableTargetAttr::lookup.
+static ArrayAttr lookUpExecutableObjects(Operation *op) {
+  MLIRContext *context = op->getContext();
+  auto attrId = StringAttr::get(context, executableObjectsAttrName);
+  while (op) {
+    // Take directly from the enclosing variant.
+    if (auto variantOp = dyn_cast<IREE::HAL::ExecutableVariantOp>(op)) {
+      if (std::optional<ArrayAttr> objects = variantOp.getObjects()) {
+        return *objects;
+      }
+    }
+    // Take from op attributes.
+    if (auto attr = op->getAttrOfType<ArrayAttr>(attrId)) {
+      return attr;
+    }
+    // Continue walk.
+    op = op->getParentOp();
+  }
+  return {};
+}
+
+/// Returns the function name and attributes to use for a ukernel with given
+/// `name` and `suffix` on the target described by `targetAttr`.
+static IREE::GPU::UKernelSpecAttr
+getUKernelSpec(StringRef name, StringRef suffix, MLIRContext *context,
+               IREE::HAL::ExecutableTargetAttr targetAttr) {
+  if (isROCMBackend(targetAttr)) {
+    auto nameAttr = StringAttr::get(
+        context, llvm::formatv("iree_uk_amdgpu_{}_{}", name, suffix));
+    auto defsAttr = DictionaryAttr::get(
+        context, {{StringAttr::get(context, "vm.import.module"),
+                   StringAttr::get(context, "rocm")}});
+    return IREE::GPU::UKernelSpecAttr::get(context, nameAttr, defsAttr);
+  }
+  return {};
+}
+
+} // namespace
+
+IREE::GPU::UKernelSpecAttr selectUKernelForArgmax(linalg::GenericOp op) {
+  if (failed(isArgmaxOp(op))) {
+    return {};
+  }
+  auto targetAttr = IREE::HAL::ExecutableTargetAttr::lookup(op);
+  const char ukernelName[] = "argmax";
+  if (!hasUkernel(targetAttr, ukernelName)) {
+    return {};
+  }
+  Value input = op.getDpsInputOperand(0)->get();
+  auto inputType = cast<ShapedType>(input.getType());
+  Value index = op.getDpsInitOperand(1)->get();
+  auto indexType = cast<ShapedType>(index.getType());
+  std::string suffix;
+  llvm::raw_string_ostream(suffix)
+      << inputType.getElementType() << indexType.getElementType();
+  MLIRContext *context = op->getContext();
+  IREE::GPU::UKernelSpecAttr ukernelSpec =
+      getUKernelSpec(ukernelName, suffix, context, targetAttr);
+  if (!ukernelSpec) {
+    return {};
+  }
+  auto execTarget = IREE::HAL::ExecutableTargetAttr::lookup(op);
+  ArrayAttr sourceExecutableObjects = lookUpExecutableObjects(op);
+  IREE::HAL::ExecutableObjectAttr bitcodeObject = getUKernelBitcode(
+      context, execTarget, sourceExecutableObjects, ukernelSpec.getName());
+  if (!bitcodeObject) {
+    return {};
+  }
+  op->setAttr(executableObjectsAttrName,
+              ArrayAttr::get(context, bitcodeObject));
+  return ukernelSpec;
+}
+
+} // namespace mlir::iree_compiler
diff --git a/compiler/src/iree/compiler/Codegen/LLVMGPU/Utils/LLVMGPUSelectUKernels.h b/compiler/src/iree/compiler/Codegen/LLVMGPU/Utils/LLVMGPUSelectUKernels.h
new file mode 100644
index 0000000..4ed251b
--- /dev/null
+++ b/compiler/src/iree/compiler/Codegen/LLVMGPU/Utils/LLVMGPUSelectUKernels.h
@@ -0,0 +1,15 @@
+// Copyright 2024 The IREE Authors
+//
+// Licensed under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+
+#include "iree/compiler/Codegen/Dialect/GPU/IR/IREEGPUAttrs.h"
+#include "iree/compiler/Dialect/HAL/IR/HALTypes.h"
+#include "mlir/Dialect/Linalg/IR/Linalg.h"
+
+namespace mlir::iree_compiler {
+
+IREE::GPU::UKernelSpecAttr selectUKernelForArgmax(linalg::GenericOp op);
+
+} // namespace mlir::iree_compiler