[Codegen][LLVMGPU] Give ops same config irrespective of generalized/specialized (#21769)

This is part of deprecating the LLVMGPUWarpReduction pipeline. Before this
PR, generalized and specialized ops got different configs (confusingly!).
This PR changes the logic that checks whether a linalg op is a
linalg.generic so that it also accepts named ops (except, for now,
convolution ops). The two forms of the same computation are sketched below.

---------

Signed-off-by: James Newling <james.newling@gmail.com>
diff --git a/compiler/src/iree/compiler/Codegen/Common/GPU/GPUGeneralizeNamedOps.cpp b/compiler/src/iree/compiler/Codegen/Common/GPU/GPUGeneralizeNamedOps.cpp
index 5055446..5d37563 100644
--- a/compiler/src/iree/compiler/Codegen/Common/GPU/GPUGeneralizeNamedOps.cpp
+++ b/compiler/src/iree/compiler/Codegen/Common/GPU/GPUGeneralizeNamedOps.cpp
@@ -58,9 +58,11 @@
     FunctionOpInterface funcOp = getOperation();
     SmallVector<linalg::LinalgOp> namedOpCandidates;
     funcOp.walk([&](linalg::LinalgOp linalgOp) {
-      if (isa<linalg::BatchMatmulOp, linalg::MatmulOp, linalg::MatvecOp,
-              linalg::TransposeOp, linalg::VecmatOp>(linalgOp.getOperation()))
+      if (isa<linalg::BatchMatmulOp, linalg::DotOp, linalg::MatmulOp,
+              linalg::MatvecOp, linalg::TransposeOp, linalg::VecmatOp>(
+              linalgOp.getOperation())) {
         namedOpCandidates.push_back(linalgOp);
+      }
     });
 
     if (failed(generalizeCandidates(&getContext(), namedOpCandidates))) {
diff --git a/compiler/src/iree/compiler/Codegen/Common/GPU/Passes.td b/compiler/src/iree/compiler/Codegen/Common/GPU/Passes.td
index 904e59c..08cabc1 100644
--- a/compiler/src/iree/compiler/Codegen/Common/GPU/Passes.td
+++ b/compiler/src/iree/compiler/Codegen/Common/GPU/Passes.td
@@ -142,8 +142,10 @@
   let summary = "Convert named Linalg ops to linalg.generic ops";
 
   let description = [{
-    Convert a whitelisted set of named Linalg ops to linalg.generics. The whitelist
-    does not contain all named ops.
+    Convert a subset of named Linalg ops to linalg.generics. The subset does
+    not contain all named ops. The rule of thumb is that named ops whose
+    operand maps are projections are in the subset. For example, convolutions
+    and pooling ops are not generalized by this pass, but matmuls are.
   }];
 }
 
diff --git a/compiler/src/iree/compiler/Codegen/LLVMGPU/KernelConfig.cpp b/compiler/src/iree/compiler/Codegen/LLVMGPU/KernelConfig.cpp
index c33039b..886b176 100644
--- a/compiler/src/iree/compiler/Codegen/LLVMGPU/KernelConfig.cpp
+++ b/compiler/src/iree/compiler/Codegen/LLVMGPU/KernelConfig.cpp
@@ -19,12 +19,10 @@
 #include "iree/compiler/Codegen/Dialect/GPU/IR/IREEGPUEnums.h"
 #include "iree/compiler/Codegen/Dialect/GPU/TargetUtils/ConfigUtils.h"
 #include "iree/compiler/Codegen/Interfaces/PartitionableLoopsInterface.h"
-#include "iree/compiler/Codegen/Interfaces/UKernelOpInterface.h"
 #include "iree/compiler/Codegen/LLVMGPU/Passes.h"
 #include "iree/compiler/Codegen/Utils/GPUUtils.h"
 #include "iree/compiler/Codegen/Utils/LinalgOpInfo.h"
 #include "iree/compiler/Codegen/Utils/Utils.h"
-#include "iree/compiler/Dialect/HAL/IR/HALTypes.h"
 #include "iree/compiler/Dialect/LinalgExt/IR/LinalgExtOps.h"
 #include "iree/compiler/Dialect/LinalgExt/Utils/IndexingUtils.h"
 #include "iree/compiler/Dialect/LinalgExt/Utils/Utils.h"
@@ -292,15 +290,13 @@
 //====---------------------------------------------------------------------===//
 //
 
-static bool isMatmulLike(linalg::LinalgOp &linalgOp) {
+static bool isMatmulLike(linalg::LinalgOp linalgOp) {
   return linalg::isaContractionOpInterface(linalgOp) &&
          linalgOp.getNumParallelLoops() >= 1;
 };
 
-/// Check if `op` is a linalg.reduce or a linalg.generic that has at least one
-/// reduction iterator.
-static bool hasReductionIterator(linalg::LinalgOp &op) {
-  return isa<linalg::ReduceOp, linalg::GenericOp>(op) &&
+static bool hasReductionIterator(linalg::LinalgOp op) {
+  return !linalg::isaConvolutionOpInterface(op) &&
          llvm::any_of(op.getIteratorTypesArray(), linalg::isReductionIterator);
 }
 
diff --git a/compiler/src/iree/compiler/Codegen/LLVMGPU/test/config_matmul.mlir b/compiler/src/iree/compiler/Codegen/LLVMGPU/test/config_matmul.mlir
index 3c0e602..fff4944 100644
--- a/compiler/src/iree/compiler/Codegen/LLVMGPU/test/config_matmul.mlir
+++ b/compiler/src/iree/compiler/Codegen/LLVMGPU/test/config_matmul.mlir
@@ -100,20 +100,16 @@
 
 // -----
 
-// TODO(newling) specialized form should be the same as generalized form.
-
-//      GENERALIZED:         #iree_codegen.translation_info<pipeline = LLVMGPUVectorDistribute
-// GENERALIZED-SAME:         workgroup_size = [1024, 1, 1] subgroup_size = 64,
-// GENERALIZED-SAME:         {gpu_pipeline_options = #iree_gpu.pipeline_options<prefetch_shared_memory = false, no_reduce_shared_memory_bank_conflicts = false, use_igemm_convolution = false>}>
-
-//      SPECIALIZED:         #iree_codegen.translation_info<pipeline = LLVMGPUWarpReduction workgroup_size = [64, 1, 1] subgroup_size = 64>
+//      CHECK:         #iree_codegen.translation_info<pipeline = LLVMGPUVectorDistribute
+// CHECK-SAME:         workgroup_size = [1024, 1, 1] subgroup_size = 64,
+// CHECK-SAME:         {gpu_pipeline_options = #iree_gpu.pipeline_options<prefetch_shared_memory = false, no_reduce_shared_memory_bank_conflicts = false, use_igemm_convolution = false>}>
 !TA = tensor<?x4096xf32>
 !TB = tensor<4096x4096xf32>
 !TC = tensor<?x4096xf32>
 !DTC = !iree_tensor_ext.dispatch.tensor<readwrite:tensor<?x4096xf32>>
 func.func @matmul_DYN_4096_4096(%arg0: !TA, %arg1: !TB, %arg2: !TC, %arg3: !DTC, %arg4 : index) {
-//      GENERALIZED:         #iree_gpu.lowering_config<{lane_basis =  {{\[}}[1, 1, 64], [0, 1, 2]],
-// GENERALIZED-SAME:         partial_reduction = [0, 0, 4096], subgroup_basis =  {{\[}}[1, 1, 16], [0, 1, 2]], thread = [0, 0, 4], workgroup = [1, 1, 0]}
+//      CHECK:         #iree_gpu.lowering_config<{lane_basis =  {{\[}}[1, 1, 64], [0, 1, 2]],
+// CHECK-SAME:         partial_reduction = [0, 0, 4096], subgroup_basis =  {{\[}}[1, 1, 16], [0, 1, 2]], thread = [0, 0, 4], workgroup = [1, 1, 0]}
   %0 = linalg.matmul ins(%arg0, %arg1 : !TA, !TB) outs(%arg2 : !TC) -> !TC
   iree_tensor_ext.dispatch.tensor.store %0, %arg3, offsets = [0, 0], sizes = [%arg4, 4096], strides = [1, 1] : !TC -> !DTC{%arg4}
   return
@@ -190,12 +186,8 @@
 // Dynamic all
 // ============================================================================
 
-
-// TODO(newling) specialized form should follow generalized form.
-
-//     SPECIALIZED:      LLVMGPUWarpReduction
-//     GENERALIZED:      LLVMGPUVectorDistribute
-// GENERALIZED-SAME:     workgroup_size = [1024, 1, 1] subgroup_size = 64,
+//     CHECK:      LLVMGPUVectorDistribute
+// CHECK-SAME:     workgroup_size = [1024, 1, 1] subgroup_size = 64,
 
 
 !TA = tensor<?x?xf32>
@@ -204,9 +196,9 @@
 !DTC = !iree_tensor_ext.dispatch.tensor<readwrite:tensor<?x?xf32>>
 func.func @matmul_DYN_1_4096(%arg0: !TA, %arg1: !TB, %arg2: !TC, %arg3: !DTC, %arg4 : index, %arg5 : index, %arg6 : index) {
 
-  //      GENERALIZED:     {lane_basis = {{\[}}[1, 1, 64], [0, 1, 2]],
-  // GENERALIZED-SAME:     partial_reduction = [0, 0, 4096], subgroup_basis = {{\[}}[1, 1, 16], [0, 1, 2]],
-  // GENERALIZED-SAME:     thread = [0, 0, 4], workgroup = [1, 1, 0]}
+  //      CHECK:     {lane_basis = {{\[}}[1, 1, 64], [0, 1, 2]],
+  // CHECK-SAME:     partial_reduction = [0, 0, 4096], subgroup_basis = {{\[}}[1, 1, 16], [0, 1, 2]],
+  // CHECK-SAME:     thread = [0, 0, 4], workgroup = [1, 1, 0]}
   %0 = linalg.matmul ins(%arg0, %arg1 : !TA, !TB) outs(%arg2 : !TC) -> !TC
   iree_tensor_ext.dispatch.tensor.store %0, %arg3, offsets = [0, 0], sizes = [%arg4, %arg5], strides = [1, 1] : !TC -> !DTC{%arg4, %arg5}
   return
diff --git a/compiler/src/iree/compiler/Codegen/LLVMGPU/test/config_matvec.mlir b/compiler/src/iree/compiler/Codegen/LLVMGPU/test/config_matvec.mlir
index 0919ace..abb5c9d 100644
--- a/compiler/src/iree/compiler/Codegen/LLVMGPU/test/config_matvec.mlir
+++ b/compiler/src/iree/compiler/Codegen/LLVMGPU/test/config_matvec.mlir
@@ -24,11 +24,9 @@
 }
 
 
-// CHECK:     LLVMGPUWarpReduction
+// CHECK:     LLVMGPUVectorDistribute
 // CDNA3:     LLVMGPUTileAndFuse
 
-// We want to deprecate LLVMGPUWarpReduction. Currently LLVMGPUVectorDistribution is not chosen in setReductionVectorDistributionConfig because it fails in 'hasReductionIterator' (which doesn't check specialized ops). This might be an easy whitelisting fix, but I will return to this later (TODO(newling)).
-
 // -----
 
 #map = affine_map<(d0, d1, d2, d3) -> (d0, d1, d3)>
@@ -325,7 +323,7 @@
   #hal.pipeline.binding<storage_buffer>,
   #hal.pipeline.binding<storage_buffer>
 ]>
-func.func @skinny_mmt() {
+func.func @skinny_mmt_lhs_is_vector() {
   %c0 = arith.constant 0 : index
   %cst = arith.constant 0.000000e+00 : f16
   %0 = hal.interface.binding.subspan layout(#pipeline_layout) binding(0) alignment(64) offset(%c0) flags(ReadOnly) : !iree_tensor_ext.dispatch.tensor<readonly:tensor<2x4096xf16>>
@@ -347,16 +345,19 @@
   return
 }
 
-// CHECK-DAG: #[[$MA:.*]] = affine_map<(d0, d1, d2) -> (d0, d2)>
-// CHECK-DAG: #[[$MB:.*]] = affine_map<(d0, d1, d2) -> (d1, d2)>
-// CHECK-DAG: #[[$MC:.*]] = affine_map<(d0, d1, d2) -> (d0, d1)>
-
-//   CHECK-DAG: #[[$CONFIG:.+]] = #iree_codegen.lowering_config<tile_sizes = {{\[}}[1, 1], [0, 0, 512]{{\]}}>
-//       CHECK: #[[$TRANSLATION:.+]] = #iree_codegen.translation_info<pipeline = LLVMGPUWarpReduction workgroup_size = [64, 1, 1] subgroup_size = 64>
-//       CHECK: func.func @skinny_mmt()
-//  CHECK-SAME:     translation_info = #[[$TRANSLATION]]
-//       CHECK:   linalg.matmul indexing_maps = [#[[$MA]], #[[$MB]], #[[$MC]]]
-//  CHECK-SAME:       lowering_config = #[[$CONFIG]]
+//  CHECK-DAG: #[[$MA:.*]] = affine_map<(d0, d1, d2) -> (d0, d2)>
+//  CHECK-DAG: #[[$MB:.*]] = affine_map<(d0, d1, d2) -> (d1, d2)>
+//  CHECK-DAG: #[[$MC:.*]] = affine_map<(d0, d1, d2) -> (d0, d1)>
+//      CHECK: pipeline = LLVMGPUVectorDistribute workgroup_size = [64, 1, 1] subgroup_size = 64
+//      CHECK: linalg.fill
+//      CHECK: linalg.matmul
+// CHECK-SAME: indexing_maps = [#[[$MA]], #[[$MB]], #[[$MC]]]
+// CHECK-SAME: lowering_config = #iree_gpu.lowering_config<{
+// CHECK-SAME:       lane_basis =        {{\[}}[1, 1, 64], [0, 1, 2]],
+// CHECK-SAME:       partial_reduction =       [0, 0, 512],
+// CHECK-SAME:       subgroup_basis =    {{\[}}[1, 1, 1], [0, 1, 2]],
+// CHECK-SAME:       thread =                  [0, 0, 8],
+// CHECK-SAME:       workgroup =               [1, 1, 0]}>}
 
 // -----
 
@@ -367,7 +368,7 @@
   #hal.pipeline.binding<storage_buffer>,
   #hal.pipeline.binding<storage_buffer>
 ]>
-func.func @skinny_mmt() {
+func.func @skinny_mmt_lhs_is_matrix() {
   %c0 = arith.constant 0 : index
   %cst = arith.constant 0.000000e+00 : f16
   %0 = hal.interface.binding.subspan layout(#pipeline_layout) binding(0) alignment(64) offset(%c0) flags(ReadOnly) : !iree_tensor_ext.dispatch.tensor<readonly:tensor<2x4096xf16>>
@@ -389,12 +390,16 @@
   return
 }
 
-//   CHECK-DAG: #[[$CONFIG:.+]] = #iree_codegen.lowering_config<tile_sizes = {{\[}}[1, 1], [0, 0, 512]{{\]}}>
-//       CHECK: #[[$TRANSLATION:.+]] = #iree_codegen.translation_info<pipeline = LLVMGPUWarpReduction workgroup_size = [64, 1, 1] subgroup_size = 64>
-//       CHECK: func.func @skinny_mmt()
-//  CHECK-SAME:     translation_info = #[[$TRANSLATION]]
-//       CHECK:   linalg.matmul indexing_maps = [#[[$MA]], #[[$MB]], #[[$MC]]]
-//  CHECK-SAME:       lowering_config = #[[$CONFIG]]
+//      CHECK: pipeline = LLVMGPUVectorDistribute workgroup_size = [64, 1, 1] subgroup_size = 64
+//      CHECK: linalg.fill
+//      CHECK: linalg.matmul
+// CHECK-SAME: indexing_maps
+// CHECK-SAME: lowering_config = #iree_gpu.lowering_config<{
+// CHECK-SAME:       lane_basis =        {{\[}}[1, 1, 64], [0, 1, 2]],
+// CHECK-SAME:       partial_reduction =       [0, 0, 512],
+// CHECK-SAME:       subgroup_basis =    {{\[}}[1, 1, 1], [0, 1, 2]],
+// CHECK-SAME:       thread =                  [0, 0, 8],
+// CHECK-SAME:       workgroup =               [8, 1, 0]}>}
 
 // -----