[Codegen] Add matmul and batched matmul to list of ops to generalize (#21720)

The changes in this PR are:

1) Carry 'compilation info' across from linalg.matmul (or another named
op) to linalg.generic when running the generalize pass; see, for example,
the test `lowering_matmul_promotion.mlir`. A schematic before/after sketch
follows this list.

2) Refactor the SPIRV configuration selection. It now fails if a
matmul/batch_matmul reaches it in specialized (named) form, so the affected
tests generalize named ops before selecting a strategy; see the RUN-line
example after this list.

3) One NVIDIA test now uses a different pipeline: after generalization the
op matches earlier in the selection process against something more
specialized. Looking at the history of the selected config, it appears a
more specialized config used to be chosen anyway:
https://github.com/iree-org/iree/blame/80af7ac6546a9d6975f2c5dbbb1d68e3c640565a/compiler/src/iree/compiler/Codegen/SPIRV/test/config_nvidia_matmul.mlir#L61
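
To make (1) concrete, here is a minimal before/after sketch of what the generalize pass is expected to preserve. The shapes, tile sizes, and the `@before`/`@after` function names are illustrative placeholders (only the attribute plumbing is the point); see `lowering_matmul_promotion.mlir` for the real test:

```mlir
// Placeholder compilation info: the tile sizes below are illustrative, the
// translation_info mirrors the form used in the SPIRV tests in this PR.
#compilation_info = #iree_codegen.compilation_info<
    lowering_config = #iree_codegen.lowering_config<tile_sizes = [[32, 128, 32]]>,
    translation_info = #iree_codegen.translation_info<pipeline = SPIRVMatmulPromoteVectorize
        workgroup_size = [32, 4, 1], {pipeline_depth = 1 : i64, store_stage = 1 : i64}>>

// Before iree-codegen-gpu-generalize-named-ops: user-set info on the named op.
func.func @before(%lhs: tensor<64x16xf32>, %rhs: tensor<16x64xf32>,
                  %acc: tensor<64x64xf32>) -> tensor<64x64xf32> {
  %0 = linalg.matmul {compilation_info = #compilation_info}
         ins(%lhs, %rhs : tensor<64x16xf32>, tensor<16x64xf32>)
         outs(%acc : tensor<64x64xf32>) -> tensor<64x64xf32>
  return %0 : tensor<64x64xf32>
}

// After the pass: the op is a linalg.generic, and the compilation_info
// (and any lowering_config) is carried across instead of being dropped.
func.func @after(%lhs: tensor<64x16xf32>, %rhs: tensor<16x64xf32>,
                 %acc: tensor<64x64xf32>) -> tensor<64x64xf32> {
  %0 = linalg.generic {indexing_maps = [affine_map<(d0, d1, d2) -> (d0, d2)>,
                                        affine_map<(d0, d1, d2) -> (d2, d1)>,
                                        affine_map<(d0, d1, d2) -> (d0, d1)>],
                       iterator_types = ["parallel", "parallel", "reduction"]}
         ins(%lhs, %rhs : tensor<64x16xf32>, tensor<16x64xf32>)
         outs(%acc : tensor<64x64xf32>)
         attrs = {compilation_info = #compilation_info} {
       ^bb0(%in: f32, %in_0: f32, %out: f32):
         %1 = arith.mulf %in, %in_0 : f32
         %2 = arith.addf %out, %1 : f32
         linalg.yield %2 : f32
       } -> tensor<64x64xf32>
  return %0 : tensor<64x64xf32>
}
```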
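
For (2), the affected SPIRV tests now run the generalization pass ahead of strategy selection. The updated RUN line from `config_default_matmul.mlir` (also in the diff below) shows the change:

```mlir
// Before: strategy selection saw named linalg.matmul / linalg.batch_matmul ops directly.
// RUN: iree-opt --split-input-file --pass-pipeline='builtin.module(iree-spirv-select-lowering-strategy-pass)' %s | FileCheck %s

// After: named ops are generalized first, so selection only ever sees linalg.generic.
// RUN: iree-opt --split-input-file --pass-pipeline='builtin.module(func.func(iree-codegen-gpu-generalize-named-ops),iree-spirv-select-lowering-strategy-pass)' %s | FileCheck %s
```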

---------

Signed-off-by: James Newling <james.newling@gmail.com>
diff --git a/compiler/src/iree/compiler/Codegen/Common/GPU/GPUGeneralizeNamedOps.cpp b/compiler/src/iree/compiler/Codegen/Common/GPU/GPUGeneralizeNamedOps.cpp
index d53f5a2..5055446 100644
--- a/compiler/src/iree/compiler/Codegen/Common/GPU/GPUGeneralizeNamedOps.cpp
+++ b/compiler/src/iree/compiler/Codegen/Common/GPU/GPUGeneralizeNamedOps.cpp
@@ -16,22 +16,24 @@
 #include "mlir/Dialect/Linalg/IR/Linalg.h"
 #include "mlir/Dialect/Linalg/Transforms/Transforms.h"
 #include "mlir/Interfaces/FunctionInterfaces.h"
-#include "mlir/Pass/Pass.h"
 
 namespace mlir::iree_compiler {
 
 #define GEN_PASS_DEF_GPUGENERALIZENAMEDOPSPASS
 #include "iree/compiler/Codegen/Common/GPU/Passes.h.inc"
 
-LogicalResult
+static LogicalResult
 generalizeCandidates(MLIRContext *context,
                      ArrayRef<linalg::LinalgOp> namedOpCandidates) {
   IRRewriter rewriter(context);
   for (auto linalgOp : namedOpCandidates) {
-    // Pass down lowering configuration. It can exist due to user set
-    // configuration from the input.
+    // Pass down lowering configuration and compilation info. These
+    // can exist due to user set configuration from the input.
     IREE::Codegen::LoweringConfigAttrInterface config =
         getLoweringConfig(linalgOp);
+    IREE::Codegen::CompilationInfoAttr compilationInfo =
+        getCompilationInfo(linalgOp);
+
     rewriter.setInsertionPoint(linalgOp);
     FailureOr<linalg::GenericOp> generalizedOp =
         linalg::generalizeNamedOp(rewriter, linalgOp);
@@ -42,6 +44,9 @@
     if (config) {
       setLoweringConfig(*generalizedOp, config);
     }
+    if (compilationInfo) {
+      setCompilationInfo(*generalizedOp, compilationInfo);
+    }
   }
   return success();
 }
@@ -53,9 +58,8 @@
     FunctionOpInterface funcOp = getOperation();
     SmallVector<linalg::LinalgOp> namedOpCandidates;
     funcOp.walk([&](linalg::LinalgOp linalgOp) {
-      if (isa<linalg::BatchMatmulTransposeBOp, linalg::MatmulTransposeBOp,
-              linalg::VecmatOp, linalg::MatvecOp, linalg::TransposeOp>(
-              linalgOp.getOperation()))
+      if (isa<linalg::BatchMatmulOp, linalg::MatmulOp, linalg::MatvecOp,
+              linalg::TransposeOp, linalg::VecmatOp>(linalgOp.getOperation()))
         namedOpCandidates.push_back(linalgOp);
     });
 
diff --git a/compiler/src/iree/compiler/Codegen/Common/GPU/Passes.td b/compiler/src/iree/compiler/Codegen/Common/GPU/Passes.td
index b186a70..904e59c 100644
--- a/compiler/src/iree/compiler/Codegen/Common/GPU/Passes.td
+++ b/compiler/src/iree/compiler/Codegen/Common/GPU/Passes.td
@@ -140,6 +140,11 @@
 def GPUGeneralizeNamedOpsPass :
     InterfacePass<"iree-codegen-gpu-generalize-named-ops", "mlir::FunctionOpInterface"> {
   let summary = "Convert named Linalg ops to linalg.generic ops";
+
+  let description = [{
+    Convert a whitelisted set of named Linalg ops to linalg.generics. The whitelist
+    does not contain all named ops.
+  }];
 }
 
 def GPUGreedilyDistributeToThreadsPass :
diff --git a/compiler/src/iree/compiler/Codegen/LLVMGPU/test/gpu_set_num_workgroups.mlir b/compiler/src/iree/compiler/Codegen/LLVMGPU/test/gpu_set_num_workgroups.mlir
index 5c0b875..6f25f7a 100644
--- a/compiler/src/iree/compiler/Codegen/LLVMGPU/test/gpu_set_num_workgroups.mlir
+++ b/compiler/src/iree/compiler/Codegen/LLVMGPU/test/gpu_set_num_workgroups.mlir
@@ -58,7 +58,7 @@
 //      CHECK: func.func @dot_dispatch_1
 // CHECK-SAME:     translation_info = #[[TRANSLATION]]
 //      CHECK:   linalg.fill
-//      CHECK:   linalg.matmul
+//      CHECK:   linalg.generic
 // CHECK-SAME:       lowering_config = #iree_gpu.lowering_config<{reduction = [0, 0, 4], thread = [2, 1, 0], workgroup = [4, 2, 1]}>
 
 // -----
@@ -85,7 +85,7 @@
 //      CHECK: func.func @unaligned_k
 // CHECK-SAME:     translation_info = #[[TRANSLATION]]
 //      CHECK:   linalg.fill
-//      CHECK:   linalg.matmul
+//      CHECK:   linalg.generic
 // CHECK-SAME:       lowering_config = #iree_gpu.lowering_config<{reduction = [0, 0, 2], thread = [1, 16, 0], workgroup = [32, 128, 1]}>
 
 // -----
@@ -273,7 +273,7 @@
 //      CHECK: func.func @_lowering_config_test_dispatch_1()
 // CHECK-SAME:     translation_info = #[[TRANSLATION]]
 //      CHECK: linalg.fill
-//      CHECK: linalg.matmul
+//      CHECK: linalg.generic
 // CHECK-SAME:     lowering_config = #[[CONFIG]]
 
 // -----
@@ -456,13 +456,15 @@
   iree_tensor_ext.dispatch.tensor.store %7, %2, offsets = [0, 0], sizes = [2560, 2048], strides = [1, 1] : tensor<2560x2048xf16> -> !iree_tensor_ext.dispatch.tensor<writeonly:tensor<2560x2048xf16>>
   return
 }
-//  SM80-DAG: #[[CONFIG:.+]] = #iree_codegen.lowering_config<tile_sizes = {{\[}}[128, 256, 32]{{\]}}
-//  SM80-DAG: #[[TRANSLATION:.+]] = #iree_codegen.translation_info<pipeline = LLVMGPUMatmulTensorCoreMmaSync workgroup_size = [128, 2, 1] subgroup_size = 32, {pipeline_depth = 3 : i64, store_stage = 1 : i64}>
-//      SM80: func.func @large_matmul_f16()
-// SM80-SAME:     translation_info = #[[TRANSLATION]]
-//      SM80: linalg.fill
-//      SM80: linalg.matmul
-// SM80-SAME:     lowering_config = #[[CONFIG]]
+
+//      SM80:   #config = #iree_codegen.lowering_config<tile_sizes = {{\[}}[128, 256, 32]]>
+//      SM80:   #translation = #iree_codegen.translation_info<pipeline = LLVMGPUMatmulTensorCoreMmaSync workgroup_size = [128, 2, 1]
+// SM80-SAME:   subgroup_size = 32, {pipeline_depth = 3 : i64, store_stage = 1 : i64}>
+//      SM80:   func.func @large_matmul_f16()
+// SM80-SAME:       translation_info = #translation
+//      SM80:   linalg.fill
+//      SM80:   linalg.generic
+// SM80-SAME:       lowering_config = #config
 
 // -----
 
@@ -493,7 +495,7 @@
 //      SM80: func.func @large_matmul_f32()
 // SM80-SAME:     translation_info = #[[TRANSLATION]]
 //      SM80: linalg.fill
-//      SM80: linalg.matmul
+//      SM80: linalg.generic
 // SM80-SAME:     lowering_config = #[[CONFIG]]
 
 // -----
diff --git a/compiler/src/iree/compiler/Codegen/LLVMGPU/test/linalg_transform.mlir b/compiler/src/iree/compiler/Codegen/LLVMGPU/test/linalg_transform.mlir
index 2703218..f05b19a 100644
--- a/compiler/src/iree/compiler/Codegen/LLVMGPU/test/linalg_transform.mlir
+++ b/compiler/src/iree/compiler/Codegen/LLVMGPU/test/linalg_transform.mlir
@@ -26,8 +26,10 @@
   %6 = linalg.fill ins(%cst : f32) outs(%5 : tensor<250x1020xf32>) -> tensor<250x1020xf32>
 
   //      CHECK: linalg.fill ins(%{{.*}} : f32) outs(%{{.*}} : memref<250x1020xf32, #hal.descriptor_type<storage_buffer>>)
-  // CHECK-NEXT: linalg.matmul{{.*}}ins(%{{.*}} : memref<250x500xf32, #hal.descriptor_type<storage_buffer>>, memref<500x1020xf32, #hal.descriptor_type<storage_buffer>>) outs(%{{.*}} : memref<250x1020xf32, #hal.descriptor_type<storage_buffer>>)
-  // CHECK-NEXT: return
+  //      CHECK: linalg.generic
+  // CHECK-SAME: ins(%{{.*}} : memref<250x500xf32, #hal.descriptor_type<storage_buffer>>, memref<500x1020xf32, #hal.descriptor_type<storage_buffer>>)
+  // CHECK-SAME: outs(%{{.*}} : memref<250x1020xf32, #hal.descriptor_type<storage_buffer>>)
+  //      CHECK: return
 
   // workgroup_size is explicitly set to [10, 11].
   // FOREACH-TO-GPU: #[[TRANSLATION:.+]] = #iree_codegen.translation_info<pipeline = None workgroup_size = [10, 11, 1] subgroup_size = 32>
@@ -64,7 +66,8 @@
   // FOREACH-TO-GPU-DAG:   affine.apply #{{.*}}()[%[[TIDY]]]
   // FOREACH-TO-GPU-DAG:   %[[svB:.*]] = memref.subview {{.*}} : memref<500x1020xf32{{.*}}> to memref<500x?xf32
   // FOREACH-TO-GPU-DAG:   %[[svC:.*]] = memref.subview {{.*}} : memref<250x1020xf32{{.*}}> to memref<?x?xf32
-  // FOREACH-TO-GPU:   linalg.matmul ins(%[[svA]], %[[svB]] : memref<?x500xf32{{.*}}>, memref<500x?xf32{{.*}}>) outs(%[[svC]] : memref<?x?xf32{{.*}}>)
+  // FOREACH-TO-GPU:   linalg.generic
+  // FOREACH-TO-GPU-SAME: ins(%[[svA]], %[[svB]] : memref<?x500xf32{{.*}}>, memref<500x?xf32{{.*}}>) outs(%[[svC]] : memref<?x?xf32{{.*}}>)
   // FOREACH-TO-GPU: }
   // FOREACH-TO-GPU: gpu.barrier
 
diff --git a/compiler/src/iree/compiler/Codegen/LLVMGPU/test/transform_dialect_codegen_foreach_to_gpu_spec.mlir b/compiler/src/iree/compiler/Codegen/LLVMGPU/test/transform_dialect_codegen_foreach_to_gpu_spec.mlir
index 9bf2e8f..6fc9222 100644
--- a/compiler/src/iree/compiler/Codegen/LLVMGPU/test/transform_dialect_codegen_foreach_to_gpu_spec.mlir
+++ b/compiler/src/iree/compiler/Codegen/LLVMGPU/test/transform_dialect_codegen_foreach_to_gpu_spec.mlir
@@ -6,7 +6,7 @@
     ( mapping = [#gpu.thread<y>, #gpu.thread<x>] )
     : (!transform.any_op) -> (!transform.any_op, !transform.any_op)
 
-    %1 = transform.structured.match ops{["linalg.matmul"]} in %variant_op : (!transform.any_op) -> !transform.any_op
+    %1 = transform.structured.match ops{["linalg.generic"]} in %variant_op : (!transform.any_op) -> !transform.any_op
     %forall_2, %tiled_matmul = transform.structured.tile_using_forall %1 num_threads [7, 9]
     ( mapping = [#gpu.thread<x>, #gpu.thread<y>] )
     : (!transform.any_op) -> (!transform.any_op, !transform.any_op)
diff --git a/compiler/src/iree/compiler/Codegen/SPIRV/KernelConfig.cpp b/compiler/src/iree/compiler/Codegen/SPIRV/KernelConfig.cpp
index b602099..403f33a 100644
--- a/compiler/src/iree/compiler/Codegen/SPIRV/KernelConfig.cpp
+++ b/compiler/src/iree/compiler/Codegen/SPIRV/KernelConfig.cpp
@@ -762,6 +762,24 @@
       CodeGenPipeline::SPIRVBaseVectorize, workgroupSize);
 }
 
+static LogicalResult setTilingAndMatmulOpConfig(linalg::LinalgOp op,
+                                                IREE::GPU::TargetAttr target) {
+  if (!isMatmulOrBatchMatmul(op)) {
+    return failure();
+  }
+  // Try to tile and vectorize first. It's common to see 32 threads
+  // per subgroup for GPUs.
+  std::array<int64_t, 2> workgroupXY = {32, 2};
+  std::array<int64_t, 3> threadMNK;
+  auto inputType = cast<ShapedType>(op->getOperand(0).getType());
+  if (IREE::Util::getTypeBitWidth(inputType.getElementType()) == 16) {
+    threadMNK = {8, 8, 8};
+  } else {
+    threadMNK = {8, 8, 4};
+  }
+  return detail::setMatmulOpConfig(target, op, workgroupXY, threadMNK);
+}
+
 } // namespace detail
 
 //===----------------------------------------------------------------------===//
@@ -1509,29 +1527,13 @@
   // Otherwise fallback to use a default configuration that tiles and
   // distributes/vectorizes.
   return TypeSwitch<Operation *, LogicalResult>(rootOp)
-      .Case<linalg::BatchMatmulOp, linalg::MatmulOp>([&](auto op) {
-        // Try to tile and vectorize first. It's common to see 32 threads
-        // per subgroup for GPUs.
-        std::array<int64_t, 2> workgroupXY = {32, 2};
-        std::array<int64_t, 3> threadMNK;
-        auto inputType = llvm::cast<ShapedType>(op.getInputs()[0].getType());
-        if (IREE::Util::getTypeBitWidth(inputType.getElementType()) == 16) {
-          threadMNK = {8, 8, 8};
-        } else {
-          threadMNK = {8, 8, 4};
-        }
-        auto result =
-            detail::setMatmulOpConfig(target, op, workgroupXY, threadMNK);
-        if (succeeded(result))
-          return success();
-
-        LLVM_DEBUG(llvm::dbgs()
-                   << "failed to set matmul op config, trying reduction\n");
-        if (succeeded(setReductionConfig(target, op)))
-          return success();
-
-        // If unsuccessful, try to tile and distribute.
-        return setDefaultOpConfig(target, op);
+      .Case<linalg::MatmulOp, linalg::BatchMatmulOp>([](auto op) {
+        // An assertion is better than returning failure here, to avoid
+        // unexpected configurations.
+        assert(false && "named matmul not supported here, pass expects it to "
+                        "be generalized first");
+        return op->emitOpError(
+            "named matmul not supported, expected to be generalized first");
       })
       .Case<linalg::ConvolutionOpInterface>([target](auto op) {
         // Use the result type in case of larger bitwidth for accumulators.
@@ -1546,21 +1548,31 @@
           if (succeeded(result))
             return success();
         }
-
         // If unsuccessful, try to tile and distribute/vectorize.
         return setDefaultOpConfig(target, op);
       })
       .Case<linalg::GenericOp>([&](linalg::GenericOp op) {
-        LLVM_DEBUG(llvm::dbgs() << "figuring configuration for generic op\n");
-        if (succeeded(setReductionConfig(target, op)))
+        LLVM_DEBUG(llvm::dbgs() << "configuring for generic op\n");
+        if (succeeded(detail::setTilingAndMatmulOpConfig(op, target))) {
           return success();
+        }
+        LLVM_DEBUG(llvm::dbgs()
+                   << "failed to set matmul op config, trying reduction\n");
+
+        if (succeeded(setReductionConfig(target, op))) {
+          return success();
+        }
+        LLVM_DEBUG(llvm::dbgs() << "failed to set reduction op config");
 
         // If a generic op has reduction iterator types, it can be treated as a
         // root op for configuration as well. Use the default configuration,
         // which will mark it as a root.
         if (op.getNumLoops() != op.getNumParallelLoops()) {
+          LLVM_DEBUG(llvm::dbgs() << "trying default config for generic");
           return setDefaultOpConfig(target, op);
         }
+
+        LLVM_DEBUG(llvm::dbgs() << "failed to set config of generic");
         return failure();
       })
       .Case<IREE::LinalgExt::FftOp>([target](IREE::LinalgExt::FftOp op) {
diff --git a/compiler/src/iree/compiler/Codegen/SPIRV/test/config_amd_matvec.mlir b/compiler/src/iree/compiler/Codegen/SPIRV/test/config_amd_matvec.mlir
index 41c6741..8f07d23 100644
--- a/compiler/src/iree/compiler/Codegen/SPIRV/test/config_amd_matvec.mlir
+++ b/compiler/src/iree/compiler/Codegen/SPIRV/test/config_amd_matvec.mlir
@@ -1,4 +1,6 @@
-// RUN: iree-opt --split-input-file --iree-gpu-test-target=cdna2@vulkan --pass-pipeline='builtin.module(iree-spirv-select-lowering-strategy-pass)' %s | FileCheck %s
+// RUN: iree-opt --split-input-file --iree-gpu-test-target=cdna2@vulkan --pass-pipeline='builtin.module(func.func(iree-codegen-gpu-generalize-named-ops),iree-spirv-select-lowering-strategy-pass)' %s | FileCheck %s
+
+// Note: above we generalize named ops before selecting the lowering strategy, as selection assumes that some named ops like linalg.matmul have been generalized.
 
 #pipeline_layout = #hal.pipeline.layout<bindings = [
   #hal.pipeline.binding<storage_buffer>,
@@ -432,5 +434,5 @@
 //   CHECK-DAG: #[[$TRANSLATION:.+]] = #iree_codegen.translation_info<pipeline = SPIRVSubgroupReduce workgroup_size = [64, 1, 1]>
 //       CHECK: func.func @dynamic_batch_matvec()
 //  CHECK-SAME:     translation_info = #[[$TRANSLATION]]
-//       CHECK:   linalg.batch_matmul
+//       CHECK:   linalg.generic
 //  CHECK-SAME:       lowering_config = #[[$CONFIG]]
diff --git a/compiler/src/iree/compiler/Codegen/SPIRV/test/config_default_matmul.mlir b/compiler/src/iree/compiler/Codegen/SPIRV/test/config_default_matmul.mlir
index be0add8..17f89e5 100644
--- a/compiler/src/iree/compiler/Codegen/SPIRV/test/config_default_matmul.mlir
+++ b/compiler/src/iree/compiler/Codegen/SPIRV/test/config_default_matmul.mlir
@@ -1,4 +1,4 @@
-// RUN: iree-opt --split-input-file --pass-pipeline='builtin.module(iree-spirv-select-lowering-strategy-pass)' %s | FileCheck %s
+// RUN: iree-opt --split-input-file --pass-pipeline='builtin.module(func.func(iree-codegen-gpu-generalize-named-ops),iree-spirv-select-lowering-strategy-pass)' %s | FileCheck %s
 
 // Odd K that forbids vectorization.
 
@@ -36,7 +36,7 @@
 //   CHECK-DAG: #[[$TRANSLATION:.+]] = #iree_codegen.translation_info<pipeline = SPIRVBaseDistribute workgroup_size = [32, 1, 1]>
 //       CHECK: func.func @batch_matmul_1x3x32()
 //  CHECK-SAME:     translation_info = #[[$TRANSLATION]]
-//       CHECK:   linalg.batch_matmul
+//       CHECK:   linalg.generic
 //  CHECK-SAME:       lowering_config = #[[$CONFIG]]
 
 // -----
@@ -76,7 +76,7 @@
 //   CHECK-DAG: #[[$TRANSLATION:.+]] = #iree_codegen.translation_info<pipeline = SPIRVBaseVectorize workgroup_size = [2, 32, 1]>
 //       CHECK: func.func @matmul_64x16xi8()
 //  CHECK-SAME:     translation_info = #[[$TRANSLATION]]
-//       CHECK:   linalg.matmul
+//       CHECK:   linalg.generic
 //  CHECK-SAME:       lowering_config = #[[$CONFIG]]
 
 // -----
@@ -116,7 +116,7 @@
 //   CHECK-DAG: #[[$TRANSLATION:.+]] = #iree_codegen.translation_info<pipeline = SPIRVBaseVectorize workgroup_size = [4, 16, 1]>
 //       CHECK: func.func @matmul_64x16xi64()
 //  CHECK-SAME:     translation_info = #[[$TRANSLATION]]
-//       CHECK:   linalg.matmul
+//       CHECK:   linalg.generic
 //  CHECK-SAME:       lowering_config = #[[$CONFIG]]
 
 // -----
@@ -167,8 +167,9 @@
 //   CHECK-DAG: #[[$TRANSLATION:.+]] = #iree_codegen.translation_info<pipeline = SPIRVBaseDistribute workgroup_size = [32, 2, 1]>
 //       CHECK: func.func @matmul_400x273()
 //  CHECK-SAME:     translation_info = #[[$TRANSLATION]]
-//       CHECK:   linalg.matmul
+//       CHECK:   linalg.generic
 //  CHECK-SAME:       lowering_config = #[[$CONFIG]]
+//       CHECK:   linalg.generic
 
 // -----
 
@@ -218,8 +219,9 @@
 //   CHECK-DAG: #[[$TRANSLATION:.+]] = #iree_codegen.translation_info<pipeline = SPIRVBaseDistribute workgroup_size = [2, 32, 1]>
 //       CHECK: func.func @matmul_25x546()
 //  CHECK-SAME:     translation_info = #[[$TRANSLATION]]
-//       CHECK:   linalg.matmul
+//       CHECK:   linalg.generic
 //  CHECK-SAME:       lowering_config = #[[$CONFIG]]
+//       CHECK:   linalg.generic
 
 // -----
 
@@ -272,5 +274,6 @@
 //   CHECK-DAG: #[[$TRANSLATION:.+]] = #iree_codegen.translation_info<pipeline = SPIRVBaseVectorize workgroup_size = [32, 2, 1]>
 //       CHECK: func.func @matmul_pointwise_256x1024()
 //  CHECK-SAME:     translation_info = #[[$TRANSLATION]]
-//       CHECK:   linalg.matmul
+//       CHECK:   linalg.generic
 //  CHECK-SAME:       lowering_config = #[[$CONFIG]]
+//       CHECK:   linalg.generic
diff --git a/compiler/src/iree/compiler/Codegen/SPIRV/test/config_nvidia_matmul.mlir b/compiler/src/iree/compiler/Codegen/SPIRV/test/config_nvidia_matmul.mlir
index c06031f..3620767 100644
--- a/compiler/src/iree/compiler/Codegen/SPIRV/test/config_nvidia_matmul.mlir
+++ b/compiler/src/iree/compiler/Codegen/SPIRV/test/config_nvidia_matmul.mlir
@@ -1,4 +1,4 @@
-// RUN: iree-opt --split-input-file --iree-gpu-test-target=pascal@vulkan --pass-pipeline='builtin.module(iree-spirv-select-lowering-strategy-pass)' %s | FileCheck %s
+// RUN: iree-opt --split-input-file --iree-gpu-test-target=pascal@vulkan --pass-pipeline='builtin.module(func.func(iree-codegen-gpu-generalize-named-ops),iree-spirv-select-lowering-strategy-pass)' %s | FileCheck %s
 
 #pipeline_layout = #hal.pipeline.layout<bindings = [
   #hal.pipeline.binding<storage_buffer>,
@@ -27,7 +27,7 @@
 //  CHECK-DAG: #[[TRANSLATION:.+]] = #iree_codegen.translation_info<pipeline = SPIRVMatmulPromoteVectorize workgroup_size = [32, 4, 1], {pipeline_depth = 1 : i64, store_stage = 1 : i64}>
 //      CHECK: func.func @matmul_4x4096x9216()
 // CHECK-SAME:     translation_info = #[[TRANSLATION]]
-//      CHECK:   linalg.matmul
+//      CHECK:   linalg.generic
 // CHECK-SAME:       lowering_config = #[[CONFIG]]
 
 // -----
@@ -57,11 +57,11 @@
   return
 }
 
-//  CHECK-DAG: #[[CONFIG:.+]] = #iree_codegen.lowering_config<tile_sizes = {{\[}}[1, 2048], [1, 8], [0, 0, 8]{{\]}}>
-//  CHECK-DAG: #[[TRANSLATION:.+]] = #iree_codegen.translation_info<pipeline = SPIRVBaseVectorize workgroup_size = [256, 1, 1]>
+//  CHECK-DAG: #[[CONFIG:.+]] = #iree_codegen.lowering_config<tile_sizes = {{\[}}[1, 1], [0, 0, 1024]]>
+//  CHECK-DAG: #[[TRANSLATION:.+]] = #iree_codegen.translation_info<pipeline = SPIRVSubgroupReduce workgroup_size = [256, 1, 1]>
 //      CHECK: func.func @matmul_1x4096x9216()
 // CHECK-SAME:     translation_info = #[[TRANSLATION]]
-//      CHECK:   linalg.matmul
+//      CHECK:   linalg.generic
 // CHECK-SAME:       lowering_config = #[[CONFIG]]
 
 // -----
diff --git a/samples/transform_dialect/example_module.mlir b/samples/transform_dialect/example_module.mlir
index 3f0902a..901e4f7 100644
--- a/samples/transform_dialect/example_module.mlir
+++ b/samples/transform_dialect/example_module.mlir
@@ -52,78 +52,92 @@
     ]> : !hal.device
   ]
 } {
-  hal.executable private @example_module_dispatch_0 {
-    hal.executable.variant public @vulkan_spirv_fb target(<"vulkan-spirv", "vulkan-spirv-fb", {iree_codegen.target_info = #target}>) {
-      hal.executable.export public @example_module_dispatch_0_generic_80_f32 ordinal(0) layout(#pipeline_layout_0) count(%arg0: !hal.device) -> (index, index, index) {
-        %x, %y, %z = iree_tensor_ext.dispatch.workgroup_count_from_slice()
-        hal.return %x, %y, %z : index, index, index
-      }
-      builtin.module {
-        func.func @example_module_dispatch_0_generic_80_f32() {
-          %c0 = arith.constant 0 : index
-          %0 = hal.interface.binding.subspan layout(#pipeline_layout_0) binding(0) alignment(64) offset(%c0) flags(ReadOnly) : !iree_tensor_ext.dispatch.tensor<readonly:tensor<80xf32>>
-          %1 = hal.interface.binding.subspan layout(#pipeline_layout_0) binding(1) alignment(64) offset(%c0) : !iree_tensor_ext.dispatch.tensor<writeonly:tensor<80xf32>>
-          %2 = iree_tensor_ext.dispatch.tensor.load %0, offsets = [0], sizes = [80], strides = [1] : !iree_tensor_ext.dispatch.tensor<readonly:tensor<80xf32>> -> tensor<80xf32>
-          %3 = tensor.empty() : tensor<80xf32>
-          %4 = linalg.generic {indexing_maps = [affine_map<(d0) -> (d0)>, affine_map<(d0) -> (d0)>], iterator_types = ["parallel"]} ins(%2 : tensor<80xf32>) outs(%3 : tensor<80xf32>) {
-          ^bb0(%in: f32, %out: f32):
-            %5 = arith.addf %in, %in : f32
-            linalg.yield %5 : f32
-          } -> tensor<80xf32>
-          iree_tensor_ext.dispatch.tensor.store %4, %1, offsets = [0], sizes = [80], strides = [1] : tensor<80xf32> -> !iree_tensor_ext.dispatch.tensor<writeonly:tensor<80xf32>>
-          return
-        }
+
+
+// The linalg.add (expressed as a linalg.generic).
+hal.executable private @example_module_dispatch_0 {
+  hal.executable.variant public @vulkan_spirv_fb target(<"vulkan-spirv", "vulkan-spirv-fb", {iree_codegen.target_info = #target}>) {
+    hal.executable.export public @example_module_dispatch_0_generic_80_f32 ordinal(0) layout(#pipeline_layout_0) count(%arg0: !hal.device) -> (index, index, index) {
+      %x, %y, %z = iree_tensor_ext.dispatch.workgroup_count_from_slice()
+      hal.return %x, %y, %z : index, index, index
+    }
+    builtin.module {
+      func.func @example_module_dispatch_0_generic_80_f32() {
+        %c0 = arith.constant 0 : index
+        %0 = hal.interface.binding.subspan layout(#pipeline_layout_0) binding(0) alignment(64) offset(%c0) flags(ReadOnly) : !iree_tensor_ext.dispatch.tensor<readonly:tensor<80xf32>>
+        %1 = hal.interface.binding.subspan layout(#pipeline_layout_0) binding(1) alignment(64) offset(%c0) : !iree_tensor_ext.dispatch.tensor<writeonly:tensor<80xf32>>
+        %2 = iree_tensor_ext.dispatch.tensor.load %0, offsets = [0], sizes = [80], strides = [1] : !iree_tensor_ext.dispatch.tensor<readonly:tensor<80xf32>> -> tensor<80xf32>
+        %3 = tensor.empty() : tensor<80xf32>
+        %4 = linalg.generic {indexing_maps = [affine_map<(d0) -> (d0)>, affine_map<(d0) -> (d0)>], iterator_types = ["parallel"]} ins(%2 : tensor<80xf32>) outs(%3 : tensor<80xf32>) {
+        ^bb0(%in: f32, %out: f32):
+          %5 = arith.addf %in, %in : f32
+          linalg.yield %5 : f32
+        } -> tensor<80xf32>
+        iree_tensor_ext.dispatch.tensor.store %4, %1, offsets = [0], sizes = [80], strides = [1] : tensor<80xf32> -> !iree_tensor_ext.dispatch.tensor<writeonly:tensor<80xf32>>
+        return
       }
     }
   }
-  hal.executable private @example_module_dispatch_1 {
-    hal.executable.variant public @vulkan_spirv_fb target(<"vulkan-spirv", "vulkan-spirv-fb", {iree_codegen.target_info = #target}>) {
-      hal.executable.export public @example_module_dispatch_1_matmul_16x16x5_f32 ordinal(0) layout(#pipeline_layout_1) count(%arg0: !hal.device) -> (index, index, index) {
-        %x, %y, %z = iree_tensor_ext.dispatch.workgroup_count_from_slice()
-        hal.return %x, %y, %z : index, index, index
-      }
-      builtin.module {
-        func.func @example_module_dispatch_1_matmul_16x16x5_f32() {
-          %c0 = arith.constant 0 : index
-          %0 = hal.interface.binding.subspan layout(#pipeline_layout_1) binding(0) alignment(64) offset(%c0) flags(ReadOnly) : !iree_tensor_ext.dispatch.tensor<readonly:tensor<16x5xf32>>
-          %1 = hal.interface.binding.subspan layout(#pipeline_layout_1) binding(1) alignment(64) offset(%c0) flags(ReadOnly) : !iree_tensor_ext.dispatch.tensor<readonly:tensor<5x16xf32>>
-          %2 = hal.interface.binding.subspan layout(#pipeline_layout_1) binding(2) alignment(64) offset(%c0) : !iree_tensor_ext.dispatch.tensor<readwrite:tensor<16x16xf32>>
-          %3 = iree_tensor_ext.dispatch.tensor.load %0, offsets = [0, 0], sizes = [16, 5], strides = [1, 1] : !iree_tensor_ext.dispatch.tensor<readonly:tensor<16x5xf32>> -> tensor<16x5xf32>
-          %4 = iree_tensor_ext.dispatch.tensor.load %1, offsets = [0, 0], sizes = [5, 16], strides = [1, 1] : !iree_tensor_ext.dispatch.tensor<readonly:tensor<5x16xf32>> -> tensor<5x16xf32>
-          %5 = iree_tensor_ext.dispatch.tensor.load %2, offsets = [0, 0], sizes = [16, 16], strides = [1, 1] : !iree_tensor_ext.dispatch.tensor<readwrite:tensor<16x16xf32>> -> tensor<16x16xf32>
-          %6 = linalg.matmul ins(%3, %4 : tensor<16x5xf32>, tensor<5x16xf32>) outs(%5 : tensor<16x16xf32>) -> tensor<16x16xf32>
-          iree_tensor_ext.dispatch.tensor.store %6, %2, offsets = [0, 0], sizes = [16, 16], strides = [1, 1] : tensor<16x16xf32> -> !iree_tensor_ext.dispatch.tensor<readwrite:tensor<16x16xf32>>
-          return
-        }
+}
+
+// The linalg.matmul (expressed as a linalg.generic).
+hal.executable private @example_module_dispatch_1 {
+  hal.executable.variant public @vulkan_spirv_fb target(<"vulkan-spirv", "vulkan-spirv-fb", {iree_codegen.target_info = #target}>) {
+    hal.executable.export public @example_module_dispatch_1_matmul_16x16x5_f32 ordinal(0) layout(#pipeline_layout_1) count(%arg0: !hal.device) -> (index, index, index) {
+      %x, %y, %z = iree_tensor_ext.dispatch.workgroup_count_from_slice()
+      hal.return %x, %y, %z : index, index, index
+    }
+    builtin.module {
+      func.func @example_module_dispatch_1_matmul_16x16x5_f32() {
+        %c0 = arith.constant 0 : index
+        %0 = hal.interface.binding.subspan layout(#pipeline_layout_1) binding(0) alignment(64) offset(%c0) flags(ReadOnly) : !iree_tensor_ext.dispatch.tensor<readonly:tensor<16x5xf32>>
+        %1 = hal.interface.binding.subspan layout(#pipeline_layout_1) binding(1) alignment(64) offset(%c0) flags(ReadOnly) : !iree_tensor_ext.dispatch.tensor<readonly:tensor<5x16xf32>>
+        %2 = hal.interface.binding.subspan layout(#pipeline_layout_1) binding(2) alignment(64) offset(%c0) : !iree_tensor_ext.dispatch.tensor<readwrite:tensor<16x16xf32>>
+        %3 = iree_tensor_ext.dispatch.tensor.load %0, offsets = [0, 0], sizes = [16, 5], strides = [1, 1] : !iree_tensor_ext.dispatch.tensor<readonly:tensor<16x5xf32>> -> tensor<16x5xf32>
+        %4 = iree_tensor_ext.dispatch.tensor.load %1, offsets = [0, 0], sizes = [5, 16], strides = [1, 1] : !iree_tensor_ext.dispatch.tensor<readonly:tensor<5x16xf32>> -> tensor<5x16xf32>
+        %5 = iree_tensor_ext.dispatch.tensor.load %2, offsets = [0, 0], sizes = [16, 16], strides = [1, 1] : !iree_tensor_ext.dispatch.tensor<readwrite:tensor<16x16xf32>> -> tensor<16x16xf32>
+        %6 = linalg.generic {indexing_maps = [affine_map<(d0, d1, d2) -> (d0, d2)>, affine_map<(d0, d1, d2) -> (d2, d1)>, affine_map<(d0, d1, d2) -> (d0, d1)>], iterator_types = ["parallel", "parallel", "reduction"]}
+        ins(%3, %4 : tensor<16x5xf32>, tensor<5x16xf32>) outs(%5 : tensor<16x16xf32>) {
+        ^bb0(%in1: f32, %in2: f32, %out: f32):
+          %7 = arith.mulf %in1, %in2 : f32
+          %8 = arith.addf %out, %7 : f32
+          linalg.yield %8 : f32
+        } -> tensor<16x16xf32>
+        iree_tensor_ext.dispatch.tensor.store %6, %2, offsets = [0, 0], sizes = [16, 16], strides = [1, 1] : tensor<16x16xf32> -> !iree_tensor_ext.dispatch.tensor<readwrite:tensor<16x16xf32>>
+        return
       }
     }
   }
-  hal.executable private @example_module_dispatch_2 {
-    hal.executable.variant public @vulkan_spirv_fb target(<"vulkan-spirv", "vulkan-spirv-fb", {iree_codegen.target_info = #target}>) {
-      hal.executable.export public @example_module_dispatch_2_generic_16x16_f32 ordinal(0) layout(#pipeline_layout_2) count(%arg0: !hal.device) -> (index, index, index) {
-        %x, %y, %z = iree_tensor_ext.dispatch.workgroup_count_from_slice()
-        hal.return %x, %y, %z : index, index, index
-      }
-      builtin.module {
-        func.func @example_module_dispatch_2_generic_16x16_f32() {
-          %c0 = arith.constant 0 : index
-          %0 = hal.interface.binding.subspan layout(#pipeline_layout_2) binding(0) alignment(64) offset(%c0) flags(ReadOnly) : !iree_tensor_ext.dispatch.tensor<readonly:tensor<16x16xf32>>
-          %1 = hal.interface.binding.subspan layout(#pipeline_layout_2) binding(1) alignment(64) offset(%c0) : !iree_tensor_ext.dispatch.tensor<writeonly:tensor<16xf32>>
-          %2 = iree_tensor_ext.dispatch.tensor.load %0, offsets = [0, 0], sizes = [16, 16], strides = [1, 1] : !iree_tensor_ext.dispatch.tensor<readonly:tensor<16x16xf32>> -> tensor<16x16xf32>
-          %3 = tensor.empty() : tensor<16xf32>
-          %4 = linalg.generic {indexing_maps = [affine_map<(d0, d1) -> (d0, d1)>, affine_map<(d0, d1) -> (d0)>], iterator_types = ["parallel", "reduction"]} ins(%2 : tensor<16x16xf32>) outs(%3 : tensor<16xf32>) {
-          ^bb0(%in: f32, %out: f32):
-            %5 = arith.addf %out, %in : f32
-            linalg.yield %5 : f32
-          } -> tensor<16xf32>
-          iree_tensor_ext.dispatch.tensor.store %4, %1, offsets = [0], sizes = [16], strides = [1] : tensor<16xf32> -> !iree_tensor_ext.dispatch.tensor<writeonly:tensor<16xf32>>
-          return
-        }
+}
+
+// The linalg.reduce (expressed as a linalg.generic).
+hal.executable private @example_module_dispatch_2 {
+  hal.executable.variant public @vulkan_spirv_fb target(<"vulkan-spirv", "vulkan-spirv-fb", {iree_codegen.target_info = #target}>) {
+    hal.executable.export public @example_module_dispatch_2_generic_16x16_f32 ordinal(0) layout(#pipeline_layout_2) count(%arg0: !hal.device) -> (index, index, index) {
+      %x, %y, %z = iree_tensor_ext.dispatch.workgroup_count_from_slice()
+      hal.return %x, %y, %z : index, index, index
+    }
+    builtin.module {
+      func.func @example_module_dispatch_2_generic_16x16_f32() {
+        %c0 = arith.constant 0 : index
+        %0 = hal.interface.binding.subspan layout(#pipeline_layout_2) binding(0) alignment(64) offset(%c0) flags(ReadOnly) : !iree_tensor_ext.dispatch.tensor<readonly:tensor<16x16xf32>>
+        %1 = hal.interface.binding.subspan layout(#pipeline_layout_2) binding(1) alignment(64) offset(%c0) : !iree_tensor_ext.dispatch.tensor<writeonly:tensor<16xf32>>
+        %2 = iree_tensor_ext.dispatch.tensor.load %0, offsets = [0, 0], sizes = [16, 16], strides = [1, 1] : !iree_tensor_ext.dispatch.tensor<readonly:tensor<16x16xf32>> -> tensor<16x16xf32>
+        %3 = tensor.empty() : tensor<16xf32>
+        %4 = linalg.generic {indexing_maps = [affine_map<(d0, d1) -> (d0, d1)>, affine_map<(d0, d1) -> (d0)>], iterator_types = ["parallel", "reduction"]} ins(%2 : tensor<16x16xf32>) outs(%3 : tensor<16xf32>) {
+        ^bb0(%in: f32, %out: f32):
+          %5 = arith.addf %out, %in : f32
+          linalg.yield %5 : f32
+        } -> tensor<16xf32>
+        iree_tensor_ext.dispatch.tensor.store %4, %1, offsets = [0], sizes = [16], strides = [1] : tensor<16xf32> -> !iree_tensor_ext.dispatch.tensor<writeonly:tensor<16xf32>>
+        return
       }
     }
   }
 }
 
+}
+
 /// We test first with threading off so that the printers are legible.
 // RUN: iree-compile %s \
 // RUN:   --iree-hal-target-device=vulkan \
diff --git a/samples/transform_dialect/transform_library.mlir b/samples/transform_dialect/transform_library.mlir
index 678708e..1a1e399 100644
--- a/samples/transform_dialect/transform_library.mlir
+++ b/samples/transform_dialect/transform_library.mlir
@@ -4,9 +4,9 @@
   // default IREE codegen.
   transform.named_sequence @custom_transform_strategy(
       %variant_op: !transform.any_op) {
-    // Step 1. Re-match the matmul
+    // Step 1. Re-match the matmul. It is now a linalg.generic because the SPIRV pipeline generalizes named ops.
     // ===========================================================================
-    %matmul = transform.structured.match ops{["linalg.matmul"]} in %variant_op : (!transform.any_op) -> !transform.any_op
+    %matmul = transform.structured.match ops{["linalg.generic"]} in %variant_op : (!transform.any_op) -> !transform.any_op
 
     // Step 2. Tile to grid
     // ===========================================================================
@@ -48,7 +48,6 @@
     transform.iree.forall_to_workgroup %func_7 : (!transform.any_op) -> ()
     transform.iree.map_nested_forall_to_gpu_threads %func_7
         workgroup_dims = [4, 8, 1] : (!transform.any_op) -> ()
-
     transform.print {name = "Ran custom_transform_strategy"}
     transform.yield
   }
@@ -79,9 +78,24 @@
   //===------------------------------------------------------===
   // Matchers
   //===------------------------------------------------------===
-  transform.named_sequence @match_matmul(%matmul: !transform.any_op {transform.readonly}) -> (!transform.any_op) {
-    transform.match.operation_name %matmul ["linalg.matmul"] : !transform.any_op
-    transform.yield %matmul : !transform.any_op
+
+  // Match for our matmul (a linalg.generic with certain properties).
+  transform.named_sequence @match_matmul(%root: !transform.any_op {transform.readonly}) -> !transform.any_op {
+    transform.match.operation_name %root ["linalg.generic"] : !transform.any_op
+    %ins, %outs = transform.iree.match.cast_compatible_dag_from_root %root {
+      ^bb0(%lhs: tensor<16x5xf32>, %rhs: tensor<5x16xf32>, %out: tensor<16x16xf32>):
+      %7 = linalg.generic {indexing_maps = [affine_map<(d0, d1, d2) -> (d0, d2)>,
+                                            affine_map<(d0, d1, d2) -> (d2, d1)>,
+                                            affine_map<(d0, d1, d2) -> (d0, d1)>],
+                            iterator_types = ["parallel", "parallel", "reduction"]}
+          ins(%lhs, %rhs : tensor<16x5xf32>, tensor<5x16xf32>) outs(%out : tensor<16x16xf32>) {
+        ^bb0(%in: f32, %in_0: f32, %acc: f32):
+          %10 = arith.mulf %in, %in_0 : f32
+          %11 = arith.addf %acc, %10 : f32
+          linalg.yield %11 : f32
+        } -> tensor<16x16xf32>
+    } : (!transform.any_op) -> (!transform.any_value, !transform.any_value)
+    transform.yield %root : !transform.any_op
   }
 
   transform.named_sequence @match_reduce(%reduce: !transform.any_op {transform.readonly}) -> (!transform.any_op) {
@@ -91,7 +105,6 @@
       %c2 = transform.param.constant 2 : i64 -> !transform.param<i64>
       %rank = transform.match.structured.rank %arg1 : (!transform.any_op) -> !transform.param<i64>
       transform.match.param.cmpi eq %rank, %c2 : !transform.param<i64>
-
       transform.match.structured.dim %arg1[-1] {reduction} : !transform.any_op
       transform.match.structured.yield %arg1 : !transform.any_op
     }