[Codegen][Tuner] Make default and user-provided specs work

This is a fixup to the tuning spec materialization that makes default
and user-provided specs work e2e. As an example, a working spec for
`linalg.matmul_transpose_b` is provided for gfx942.

* Allow for tuning spec entry points that consume their argument op.
  This is so that tuning specs can use `transform.foreach_match`.
* Require all tuning spec entry points to return `any_op`, so that we
  can chain includes. This works for both consumed and readonly args.
* Add a test to show that user-provided tuning specs take precedence
  over default ones.
* Work around a transform interpreter bug when multiple named sequenced
  across different modules share the same symbol name.

Issue: https://github.com/iree-org/iree/issues/19214

Signed-off-by: Jakub Kuderski <jakub@nod-labs.com>
diff --git a/compiler/plugins/target/ROCM/builtins/tuning/iree_default_tuning_spec_gfx942.mlir b/compiler/plugins/target/ROCM/builtins/tuning/iree_default_tuning_spec_gfx942.mlir
index f0a8ca5..53f5688 100644
--- a/compiler/plugins/target/ROCM/builtins/tuning/iree_default_tuning_spec_gfx942.mlir
+++ b/compiler/plugins/target/ROCM/builtins/tuning/iree_default_tuning_spec_gfx942.mlir
@@ -7,9 +7,68 @@
 
 module @iree_default_tuning_spec_gfx942 attributes { transform.with_named_sequence } {
 
-transform.named_sequence @__kernel_config(%variant_op: !transform.any_op {transform.readonly}) -> ()
-  attributes { iree_codegen.tuning_spec_entrypoint } {
+transform.named_sequence @apply_op_config(%op: !transform.any_op {transform.readonly},
+                                        %config: !transform.any_param {transform.readonly}) {
+  // transform.print %op {name="Apply on"} : !transform.any_op
+  transform.annotate %op "compilation_info" = %config : !transform.any_op, !transform.any_param
+  // Add a dummy unit attribute to be sure that the tuning spec applied.
+  // Otherwise it would be difficult to tell if the lowering config attribute
+  // comes from our tuning spec or if the compiler heuristic happened to produce
+  // the same config as this script.
+  transform.annotate %op "__tuning_spec_applied__" : !transform.any_op
   transform.yield
 }
 
+transform.named_sequence @match_mmt_f16_f16_f32(%root: !transform.any_op {transform.readonly}) -> !transform.any_op {
+  transform.match.operation_name %root ["linalg.generic"] : !transform.any_op
+  // transform.print %root {name = "Generic"} : !transform.any_op
+  %ins, %outs = transform.iree.match.cast_compatible_dag_from_root %root {
+    ^bb0(%lhs: tensor<?x?xf16>, %rhs: tensor<?x?xf16>, %out: tensor<?x?xf32>):
+    %7 = linalg.generic {indexing_maps = [affine_map<(d0, d1, d2) -> (d0, d2)>,
+                                          affine_map<(d0, d1, d2) -> (d1, d2)>,
+                                          affine_map<(d0, d1, d2) -> (d0, d1)>],
+                          iterator_types = ["parallel", "parallel", "reduction"]}
+        ins(%lhs, %rhs : tensor<?x?xf16>, tensor<?x?xf16>) outs(%out : tensor<?x?xf32>) {
+      ^bb0(%in: f16, %in_0: f16, %acc: f32):
+        %8 = arith.extf %in : f16 to f32
+        %9 = arith.extf %in_0 : f16 to f32
+        %10 = arith.mulf %8, %9 : f32
+        %11 = arith.addf %acc, %10 : f32
+        linalg.yield %11 : f32
+      } -> tensor<?x?xf32>
+  } : (!transform.any_op) -> (!transform.any_value, !transform.any_value)
+  transform.yield %root : !transform.any_op
+}
+
+transform.named_sequence
+@match_mmt_2048x1280x5120_f16_f16_f32(%matmul: !transform.any_op {transform.readonly})
+  -> (!transform.any_op, !transform.any_param) {
+  %mmt = transform.include @match_mmt_f16_f16_f32 failures(propagate) (%matmul)
+    : (!transform.any_op) -> !transform.any_op
+  %lhs = transform.get_operand %matmul[0] : (!transform.any_op) -> !transform.any_value
+  %rhs = transform.get_operand %matmul[1] : (!transform.any_op) -> !transform.any_value
+  transform.iree.match.cast_compatible_type %lhs = tensor<2048x5120xf16> : !transform.any_value
+  transform.iree.match.cast_compatible_type %rhs = tensor<1280x5120xf16> : !transform.any_value
+  %config = transform.param.constant #iree_codegen.compilation_info<
+    lowering_config = #iree_gpu.lowering_config<{promote_operands = [0, 1],
+                                                 mma_kind = #iree_gpu.mma_layout<MFMA_F32_16x16x16_F16>,
+                                                 subgroup_m_count = 2, subgroup_n_count = 2,
+                                                 reduction = [0, 0, 64],
+                                                 workgroup = [64, 128, 0]}>,
+    translation_info = #iree_codegen.translation_info<pipeline = LLVMGPUVectorDistribute
+      workgroup_size = [256, 1, 1] subgroup_size = 64,
+      {gpu_pipeline_options = #iree_gpu.pipeline_options<prefetch_shared_memory = true>}>
+  > -> !transform.any_param
+  transform.yield %matmul, %config : !transform.any_op, !transform.any_param
+}
+
+transform.named_sequence
+@__kernel_config(%variant_op: !transform.any_op {transform.consumed}) -> !transform.any_op
+  attributes { iree_codegen.tuning_spec_entrypoint } {
+  %res = transform.foreach_match in %variant_op
+    @match_mmt_2048x1280x5120_f16_f16_f32 -> @apply_op_config
+    : (!transform.any_op) -> !transform.any_op
+  transform.yield %res : !transform.any_op
+}
+
 }
diff --git a/compiler/plugins/target/ROCM/builtins/tuning/test/spec_gfx942.mlir b/compiler/plugins/target/ROCM/builtins/tuning/test/spec_gfx942.mlir
index 51c0e01..b63ddc9 100644
--- a/compiler/plugins/target/ROCM/builtins/tuning/test/spec_gfx942.mlir
+++ b/compiler/plugins/target/ROCM/builtins/tuning/test/spec_gfx942.mlir
@@ -4,7 +4,12 @@
 // RUN:   --iree-codegen-notify-transform-strategy-application \
 // RUN:   --verify-diagnostics %s | FileCheck %s
 
-// CHECK-LABEL:      func.func @placeholder
+// Check that the default configuration for mmt_2048x1280x5120_f16_f16_f32
+// applies to the `linalg.matmul_transpose_b` below.
+
+// CHECK-LABEL:  func.func @mmt_2048x1280x5120_f16_f16_f32
+// CHECK:          linalg.generic
+// CHECK-SAME:       __tuning_spec_applied__
 
 #pipeline_layout = #hal.pipeline.layout<bindings = [
   #hal.pipeline.binding<storage_buffer>,
@@ -13,14 +18,27 @@
 ]>
 hal.executable public @main {
   hal.executable.variant public @rocm_hsaco_fb target(<"rocm", "rocm-hsaco-fb">) {
-    hal.executable.export public @placeholder ordinal(0) layout(#pipeline_layout) {
+    hal.executable.export public @matmul_transpose_b ordinal(0) layout(#pipeline_layout) {
     ^bb0(%arg0: !hal.device):
       %x, %y, %z = flow.dispatch.workgroup_count_from_slice
       hal.return %x, %y, %z : index, index, index
     }
     builtin.module {
       // expected-remark@+1 {{Applied transform configuration strategy @iree_default_tuning_spec_gfx942::@__kernel_config}}
-      func.func @placeholder() {
+      func.func @mmt_2048x1280x5120_f16_f16_f32() {
+        %cst = arith.constant 0.000000e+00 : f16
+        %c0 = arith.constant 0 : index
+        %0 = hal.interface.binding.subspan layout(#pipeline_layout) binding(0) alignment(64) offset(%c0) flags(ReadOnly) : !flow.dispatch.tensor<readonly:tensor<2048x5120xf16>>
+        %1 = hal.interface.binding.subspan layout(#pipeline_layout) binding(1) alignment(64) offset(%c0) flags(ReadOnly) : !flow.dispatch.tensor<readonly:tensor<1280x5120xf16>>
+        %2 = hal.interface.binding.subspan layout(#pipeline_layout) binding(2) alignment(64) offset(%c0) : !flow.dispatch.tensor<writeonly:tensor<2048x1280xf32>>
+        %3 = flow.dispatch.tensor.load %0, offsets = [0, 0], sizes = [2048, 5120], strides = [1, 1] : !flow.dispatch.tensor<readonly:tensor<2048x5120xf16>> -> tensor<2048x5120xf16>
+        %4 = flow.dispatch.tensor.load %1, offsets = [0, 0], sizes = [1280, 5120], strides = [1, 1] : !flow.dispatch.tensor<readonly:tensor<1280x5120xf16>> -> tensor<1280x5120xf16>
+        %5 = tensor.empty() : tensor<2048x1280xf32>
+        %6 = linalg.fill ins(%cst : f16) outs(%5 : tensor<2048x1280xf32>) -> tensor<2048x1280xf32>
+        %7 = linalg.matmul_transpose_b
+          ins(%3, %4 : tensor<2048x5120xf16>, tensor<1280x5120xf16>)
+          outs(%6 : tensor<2048x1280xf32>) -> tensor<2048x1280xf32>
+        flow.dispatch.tensor.store %7, %2, offsets = [0, 0], sizes = [2048, 1280], strides = [1, 1] : tensor<2048x1280xf32> -> !flow.dispatch.tensor<writeonly:tensor<2048x1280xf32>>
         return
       }
     }
diff --git a/compiler/plugins/target/ROCM/test/lowering_strategy_from_tuning_spec.mlir b/compiler/plugins/target/ROCM/test/lowering_strategy_from_tuning_spec.mlir
index 6f7cf09..e769d4d 100644
--- a/compiler/plugins/target/ROCM/test/lowering_strategy_from_tuning_spec.mlir
+++ b/compiler/plugins/target/ROCM/test/lowering_strategy_from_tuning_spec.mlir
@@ -4,13 +4,20 @@
 // RUN:   --iree-codegen-notify-transform-strategy-application \
 // RUN:   --verify-diagnostics %s | FileCheck %s
 
+// RUN: iree-opt --split-input-file --iree-gpu-test-target=gfx942 \
+// RUN:   --pass-pipeline="builtin.module(hal.executable(hal.executable.variant(iree-hal-configure-target-executable-variants{target=rocm})))" \
+// RUN:   --iree-codegen-tuning-spec-path=%p/tuning_spec_mmt_tile_and_fuse.mlir \
+// RUN:   --iree-codegen-enable-default-tuning-specs \
+// RUN:   --iree-codegen-notify-transform-strategy-application \
+// RUN:   --verify-diagnostics %s | FileCheck %s
+
 // Make sure we can apply the lowering strategy from the specified tuning spec.
 
 // CHECK:      #translation = #iree_codegen.translation_info<pipeline = LLVMGPUTileAndFuse workgroup_size = [128, 1, 1] subgroup_size = 64>
 // CHECK:      func.func @matmul_transpose_b
 // CHECK-SAME:   translation_info = #translation
 // CHECK:        linalg.generic
-// CHECK-SAME:     __tuning_spec_applied__
+// CHECK-SAME:     __custom_tuning_spec_applied__
 // CHECK-SAME:     lowering_config = #iree_gpu.lowering_config<
 
 #pipeline_layout = #hal.pipeline.layout<bindings = [
diff --git a/compiler/plugins/target/ROCM/test/tuning_spec_mmt_tile_and_fuse.mlir b/compiler/plugins/target/ROCM/test/tuning_spec_mmt_tile_and_fuse.mlir
index 24f0c3a..2c85c55 100644
--- a/compiler/plugins/target/ROCM/test/tuning_spec_mmt_tile_and_fuse.mlir
+++ b/compiler/plugins/target/ROCM/test/tuning_spec_mmt_tile_and_fuse.mlir
@@ -1,24 +1,33 @@
 // RUN: iree-opt %s
 
 module @mmt_tile_and_fuse_spec attributes { transform.with_named_sequence } {
-  transform.named_sequence @main(%arg0: !transform.any_op {transform.readonly}) -> ()
-    attributes { iree_codegen.tuning_spec_entrypoint } {
-    %mmt = transform.structured.match ops{["linalg.generic"]} in %arg0 : (!transform.any_op) -> !transform.any_op
-    // transform.print %mmt {name="MMT"} : !transform.any_op
-    %config = transform.param.constant #iree_codegen.compilation_info<
-      lowering_config = #iree_gpu.lowering_config<{workgroup = [64, 64, 0],
-                                                   reduction = [0, 0, 4],
-                                                   thread = [8, 4],
-                                                   promote_operands = [0, 1]}>,
-      translation_info = #iree_codegen.translation_info<pipeline = LLVMGPUTileAndFuse
-        workgroup_size = [128, 1, 1] subgroup_size = 64>
-    > -> !transform.any_param
-    transform.annotate %mmt "compilation_info" = %config : !transform.any_op, !transform.any_param
-    // Add a dummy unit attribute to be sure that the tuning spec applied.
-    // Otherwise it would be difficult to tell if the lowering config attribute
-    // comes from our tuning spec or if the compiler heuristic happened to produce
-    // the same config as this script.
-    transform.annotate %mmt "__tuning_spec_applied__" : !transform.any_op
-    transform.yield
-  }
+transform.named_sequence @apply_op_config(%op: !transform.any_op {transform.readonly},
+                                          %config: !transform.any_param {transform.readonly}) {
+  transform.annotate %op "compilation_info" = %config : !transform.any_op, !transform.any_param
+  transform.annotate %op "__custom_tuning_spec_applied__" : !transform.any_op
+  transform.yield
+}
+
+transform.named_sequence @match_mmt(%matmul: !transform.any_op {transform.readonly})
+  -> (!transform.any_op, !transform.any_param) {
+  transform.match.operation_name %matmul ["linalg.generic"] : !transform.any_op
+  %config = transform.param.constant #iree_codegen.compilation_info<
+    lowering_config = #iree_gpu.lowering_config<{workgroup = [64, 64, 0],
+                                                  reduction = [0, 0, 4],
+                                                  thread = [8, 4],
+                                                  promote_operands = [0, 1]}>,
+    translation_info = #iree_codegen.translation_info<pipeline = LLVMGPUTileAndFuse
+      workgroup_size = [128, 1, 1] subgroup_size = 64>
+  > -> !transform.any_param
+ transform.yield %matmul, %config : !transform.any_op, !transform.any_param
+}
+
+transform.named_sequence @main(%variant_op: !transform.any_op {transform.consumed}) -> (!transform.any_op)
+  attributes { iree_codegen.tuning_spec_entrypoint } {
+  transform.print %variant_op {name="Custom spec"} : !transform.any_op
+  %res = transform.foreach_match in %variant_op
+    @match_mmt -> @apply_op_config
+    : (!transform.any_op) -> !transform.any_op
+  transform.yield %res : !transform.any_op
+}
 }
diff --git a/compiler/src/iree/compiler/Codegen/Common/LinkTuningSpecsPass.cpp b/compiler/src/iree/compiler/Codegen/Common/LinkTuningSpecsPass.cpp
index 8f57104..9c7f6ad 100644
--- a/compiler/src/iree/compiler/Codegen/Common/LinkTuningSpecsPass.cpp
+++ b/compiler/src/iree/compiler/Codegen/Common/LinkTuningSpecsPass.cpp
@@ -8,8 +8,8 @@
 #include "iree/compiler/Codegen/Common/Passes.h"
 #include "iree/compiler/Codegen/Dialect/Codegen/IR/IREECodegenAttrs.h"
 #include "llvm/ADT/STLExtras.h"
-#include "llvm/ADT/SmallVector.h"
 #include "llvm/ADT/SmallVectorExtras.h"
+#include "llvm/Support/FormatVariadic.h"
 #include "mlir/Dialect/Transform/IR/TransformAttrs.h"
 #include "mlir/Dialect/Transform/IR/TransformDialect.h"
 #include "mlir/Dialect/Transform/IR/TransformOps.h"
@@ -31,6 +31,10 @@
 namespace {
 
 using mlir::transform::NamedSequenceOp;
+constexpr StringLiteral kArgConsumedAttrName =
+    mlir::transform::TransformDialect::kArgConsumedAttrName;
+constexpr StringLiteral kArgReadOnlyAttrName =
+    mlir::transform::TransformDialect::kArgReadOnlyAttrName;
 
 static SmallVector<ModuleOp>
 findNestedModulesWithNamedSequences(ModuleOp module) {
@@ -49,43 +53,53 @@
       });
 }
 
+// Returns true iff the entrypoint has the following signature:
+// ```
+// transform.named_sequence @name(%arg0: !transform.any_op) ->
+// (!transform.any_op)
+// ```
 static LogicalResult validateTuningSpec(NamedSequenceOp op) {
-  if (!op.getResultTypes().empty()) {
-    op->emitWarning() << "Tuning spec expected to have no results";
-    return failure();
+  ArrayRef<Type> resTypes = op.getFunctionType().getResults();
+  if (resTypes.size() != 1 || !isa<transform::AnyOpType>(resTypes[0])) {
+    return op.emitWarning()
+           << "Tuning spec entry point expected to return any_op";
   }
 
   ArrayRef<Type> argTypes = op.getArgumentTypes();
   if (argTypes.size() != 1 || !isa<transform::AnyOpType>(argTypes[0])) {
-    op->emitWarning() << "Tuning spec expected to have one argument of type "
-                         "'!transform.any_op'";
-    return failure();
-  }
-
-  if (!op.getArgAttr(0, transform::TransformDialect::kArgReadOnlyAttrName)) {
-    op->emitWarning() << "Tuning spec expected to have one readonly argument";
-    return failure();
+    return op.emitWarning() << "Tuning spec entry point expected to have a "
+                               "single any_op argument";
   }
 
   return success();
 }
 
+static bool consumesInputOp(NamedSequenceOp op) {
+  if (op.getArgAttr(0, kArgConsumedAttrName)) {
+    return true;
+  }
+  return false;
+}
+
 static NamedSequenceOp
 emitLinkedTuningSpec(ModuleOp module, ArrayRef<NamedSequenceOp> specsToLink) {
   OpBuilder builder(module->getContext());
   builder.setInsertionPointToEnd(module.getBody());
 
+  const bool hasConsumedSequences = llvm::any_of(specsToLink, consumesInputOp);
   Location loc = builder.getFusedLoc(llvm::map_to_vector(
       specsToLink, [](NamedSequenceOp op) { return op->getLoc(); }));
-  FunctionType specType = builder.getFunctionType(
-      TypeRange{builder.getType<transform::AnyOpType>()}, TypeRange{});
+  Type anyOpType = builder.getType<transform::AnyOpType>();
+  FunctionType specType =
+      builder.getFunctionType(TypeRange{anyOpType}, TypeRange{anyOpType});
   auto newSpec = builder.create<NamedSequenceOp>(
       loc, kKernelConfigSpecName, TypeAttr::get(specType),
       /*sym_visibility=*/StringAttr{},
       /*arg_attrs=*/ArrayAttr{},
       /*res_attrs*/ ArrayAttr{});
-  newSpec.setArgAttr(0, transform::TransformDialect::kArgReadOnlyAttrName,
-                     builder.getUnitAttr());
+  newSpec.setArgAttr(
+      0, hasConsumedSequences ? kArgConsumedAttrName : kArgReadOnlyAttrName,
+      builder.getUnitAttr());
   newSpec->setAttr(kTuningSpecEntrypointAttrName, builder.getUnitAttr());
 
   Region &region = newSpec.getRegion();
@@ -93,6 +107,13 @@
                                     newSpec.getArgumentTypes(), loc);
   builder.setInsertionPointToStart(body);
 
+  // Make sure spec names are unique to work around a transform dialect
+  // interpreter bug (`transform.include` does not handle name collisions
+  // correctly).
+  llvm::StringMap<unsigned> specNameCounts;
+  // Reserve the name for the outermost entrypoint.
+  specNameCounts[kKernelConfigSpecName] = 1;
+
   // Emit one `transform.include` op per child tuning spec. In the future,
   // we may want to switch to a custom transform op for this to perform
   // 'short-circuring' and apply at most one tuning spec.
@@ -102,17 +123,27 @@
     assert(parentModule);
     StringAttr parentSymbol = parentModule.getSymNameAttr();
     assert(parentSymbol);
+    StringRef specName = spec.getSymName();
+    unsigned specNameSeenCount = specNameCounts[specName]++;
+    if (specNameSeenCount > 0) {
+      spec.setSymName(
+          llvm::formatv("{}_{}", specName, specNameSeenCount).str());
+    }
+
     auto symbol = SymbolRefAttr::get(
         parentSymbol, FlatSymbolRefAttr::get(spec.getSymNameAttr()));
 
     // Surpress silenceable errors so that failures to match in child tuning
     // specs can be ignored.
-    builder.create<transform::IncludeOp>(
-        loc, TypeRange{}, symbol, transform::FailurePropagationMode::Suppress,
-        operand);
+    operand = builder
+                  .create<transform::IncludeOp>(
+                      loc, anyOpType, symbol,
+                      transform::FailurePropagationMode::Suppress, operand)
+                  .getResults()
+                  .front();
   }
 
-  builder.create<transform::YieldOp>(loc);
+  builder.create<transform::YieldOp>(loc, operand);
   return newSpec;
 }
 
@@ -145,6 +176,14 @@
     }
   }
 
+  size_t numConsumedSpecs = llvm::count_if(tuningSpecs, consumesInputOp);
+  if (numConsumedSpecs > 0 && numConsumedSpecs != tuningSpecs.size()) {
+    LDBG("Only " << numConsumedSpecs << " tuning specs out of "
+                 << tuningSpecs.size() << " total consume the input op");
+    return module.emitWarning() << "Expected the argument in all tuning specs "
+                                   "to be consistently readonly or consumed";
+  }
+
   if (tuningSpecs.empty()) {
     LDBG("No tuning specs found, exiting without linking");
     return NamedSequenceOp{};
diff --git a/compiler/src/iree/compiler/Codegen/Common/test/link_tuning_specs.mlir b/compiler/src/iree/compiler/Codegen/Common/test/link_tuning_specs.mlir
index 0df1b06..018107b 100644
--- a/compiler/src/iree/compiler/Codegen/Common/test/link_tuning_specs.mlir
+++ b/compiler/src/iree/compiler/Codegen/Common/test/link_tuning_specs.mlir
@@ -6,42 +6,42 @@
 // CHECK:         transform.named_sequence @outer_spec
 //
 // CHECK:         transform.named_sequence @__kernel_config
-// CHECK-SAME:      (%arg0: !transform.any_op {transform.readonly})
+// CHECK-SAME:      (%arg0: !transform.any_op {transform.readonly}) -> !transform.any_op
 // CHECK-SAME:      attributes {iree_codegen.tuning_spec_entrypoint}
-// CHECK:           transform.include @foo_module::@foo failures(suppress)
-// CHECK-NEXT:      transform.include @bar_module::@bar failures(suppress)
-// CHECK-NEXT:      transform.include @baz_module::@baz failures(suppress)
-// CHECK-NEXT:      transform.yield
+// CHECK:           %[[OP1:.+]] = transform.include @foo_module::@foo failures(suppress) (%arg0)
+// CHECK-NEXT:      %[[OP2:.+]] = transform.include @bar_module::@bar failures(suppress) (%[[OP1]])
+// CHECK-NEXT:      %[[OP3:.+]] = transform.include @baz_module::@baz failures(suppress) (%[[OP2]])
+// CHECK-NEXT:      transform.yield %[[OP3]] : !transform.any_op
 
 module @td_module_0 attributes { transform.with_named_sequence } {
   module @foo_module attributes { transform.with_named_sequence } {
-    transform.named_sequence @foo(%arg0: !transform.any_op {transform.readonly}) -> ()
+    transform.named_sequence @foo(%arg0: !transform.any_op {transform.readonly}) -> !transform.any_op
       attributes { iree_codegen.tuning_spec_entrypoint } {
       transform.print {name = "Foo", skip_regions}
-      transform.yield
+      transform.yield %arg0 : !transform.any_op
     }
   }
 
   module @bar_module attributes { transform.with_named_sequence } {
-    transform.named_sequence @bar(%arg0: !transform.any_op {transform.readonly}) -> ()
+    transform.named_sequence @bar(%arg0: !transform.any_op {transform.readonly}) -> !transform.any_op
       attributes { iree_codegen.tuning_spec_entrypoint } {
       transform.match.operation_name %arg0 ["func.func"] : !transform.any_op
       transform.print {name = "Bar", skip_regions}
-      transform.yield
+      transform.yield %arg0 : !transform.any_op
     }
   }
 
   module @baz_module attributes { transform.with_named_sequence } {
-    transform.named_sequence @baz(%arg0: !transform.any_op {transform.readonly}) -> ()
+    transform.named_sequence @baz(%arg0: !transform.any_op {transform.readonly}) -> !transform.any_op
       attributes { iree_codegen.tuning_spec_entrypoint } {
       transform.print {name = "Baz", skip_regions}
-      transform.yield
+      transform.yield %arg0 : !transform.any_op
     }
   }
 
-  transform.named_sequence @outer_spec(%module: !transform.any_op {transform.readonly}) -> ()
+  transform.named_sequence @outer_spec(%module: !transform.any_op {transform.readonly}) -> !transform.any_op
     attributes { iree_codegen.tuning_spec_entrypoint } {
-    transform.yield
+    transform.yield %module : !transform.any_op
   }
 }
 
@@ -52,19 +52,19 @@
 
 // CHECK-LABEL: module @td_module_1
 // CHECK:       @foo_module
-// CHECK:       @__kernel_config
-// CHECK-NOT      transform.include @foo_module::@foo failures(suppress) (%arg0) : (!transform.any_op) -> ()
-// CHECK:         transform.include @foo_module::@bar failures(suppress) (%arg0) : (!transform.any_op) -> ()
+// CHECK:       @__kernel_config(
+// CHECK-NOT      transform.include @foo_module::@foo failures(suppress) (%arg0) : (!transform.any_op) -> !transform.any_op
+// CHECK:         transform.include @foo_module::@bar failures(suppress) (%arg0) : (!transform.any_op) -> !transform.any_op
 // CHECK-NEXT:    transform.yield
 
 module @td_module_1 attributes { transform.with_named_sequence } {
   module @foo_module attributes { transform.with_named_sequence } {
-    transform.named_sequence @foo(%arg0: !transform.any_op {transform.readonly}) -> () {
-      transform.yield
+    transform.named_sequence @foo(%arg0: !transform.any_op {transform.readonly}) -> !transform.any_op {
+      transform.yield %arg0 : !transform.any_op
     }
-    transform.named_sequence @bar(%arg0: !transform.any_op {transform.readonly}) -> ()
+    transform.named_sequence @bar(%arg0: !transform.any_op {transform.readonly}) -> !transform.any_op
       attributes { iree_codegen.tuning_spec_entrypoint } {
-      transform.yield
+      transform.yield %arg0 : !transform.any_op
     }
     func.func @baz(%arg0: i32) -> () {
       return
@@ -91,9 +91,47 @@
 
 module @td_module_3 attributes { transform.with_named_sequence } {
   module attributes { transform.with_named_sequence } {
-    transform.named_sequence @foo(%arg0: !transform.any_op {transform.readonly}) -> ()
+    transform.named_sequence @foo(%arg0: !transform.any_op {transform.readonly}) -> !transform.any_op
       attributes { iree_codegen.tuning_spec_entrypoint } {
-      transform.yield
+      transform.yield %arg0 : !transform.any_op
+    }
+  }
+}
+
+// -----
+
+// Make sure that the names of all included specs and the outermost entrypoint
+// are kept unique.
+
+// CHECK-LABEL: module @td_module_4
+// CHECK:       @foo_module attributes
+// CHECK:       @bar_module attributes
+// CHECK:       @__kernel_config(
+// CHECK:         transform.include @foo_module::@foo failures(suppress) (%arg0) : (!transform.any_op) -> !transform.any_op
+// CHECK:         transform.include @foo_module::@__kernel_config_1 failures(suppress)
+// CHECK:         transform.include @bar_module::@foo_1 failures(suppress)
+// CHECK:         transform.include @bar_module::@__kernel_config_2 failures(suppress)
+// CHECK-NEXT:    transform.yield
+
+module @td_module_4 attributes { transform.with_named_sequence } {
+  module @foo_module attributes { transform.with_named_sequence } {
+    transform.named_sequence @foo(%arg0: !transform.any_op {transform.readonly}) -> !transform.any_op
+      attributes { iree_codegen.tuning_spec_entrypoint } {
+      transform.yield %arg0 : !transform.any_op
+    }
+    transform.named_sequence @__kernel_config(%arg0: !transform.any_op {transform.readonly}) -> !transform.any_op
+      attributes { iree_codegen.tuning_spec_entrypoint } {
+      transform.yield %arg0 : !transform.any_op
+    }
+  }
+  module @bar_module attributes { transform.with_named_sequence } {
+    transform.named_sequence @foo(%arg0: !transform.any_op {transform.readonly}) -> !transform.any_op
+      attributes { iree_codegen.tuning_spec_entrypoint } {
+      transform.yield %arg0 : !transform.any_op
+    }
+    transform.named_sequence @__kernel_config(%arg0: !transform.any_op {transform.readonly}) -> !transform.any_op
+      attributes { iree_codegen.tuning_spec_entrypoint } {
+      transform.yield %arg0 : !transform.any_op
     }
   }
 }
diff --git a/compiler/src/iree/compiler/Codegen/Common/test/tuning_spec.mlir b/compiler/src/iree/compiler/Codegen/Common/test/tuning_spec.mlir
index 24af073..78fddf7 100644
--- a/compiler/src/iree/compiler/Codegen/Common/test/tuning_spec.mlir
+++ b/compiler/src/iree/compiler/Codegen/Common/test/tuning_spec.mlir
@@ -1,9 +1,9 @@
 // RUN: iree-opt %s
 
 module @user_spec attributes { transform.with_named_sequence } {
-  transform.named_sequence @hello(%arg0: !transform.any_op {transform.readonly}) -> ()
+  transform.named_sequence @hello(%arg0: !transform.any_op {transform.readonly}) -> !transform.any_op
     attributes { iree_codegen.tuning_spec_entrypoint } {
     transform.print {name = "Hello Tuning Spec", skip_regions}
-    transform.yield
+    transform.yield %arg0 : !transform.any_op
   }
 }