[CUDA] Refactor op configuration (#7199)

Recommit the previously reverted changes.

Don't set the thread-level tile sizes in the configuration; instead,
calculate them from the workloadPerWorkgroup size and the workgroup
size. This makes it simpler to pick a configuration.
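
As a rough illustration, the thread-level tile sizes are now derived on
the fly; a minimal plain-C++ sketch of the computation (the helper name
is made up for this example, and the real pass reads the values from
the translation.info attribute on the entry point):

    #include <algorithm>
    #include <cstdint>
    #include <vector>

    // Derive thread-level tile sizes from workloadPerWorkgroup and the
    // workgroup size. workloadPerWorkgroup is listed fastest-varying
    // first (x, y, ...), while tile sizes are ordered per loop
    // dimension, hence the reverse at the end.
    std::vector<int64_t> computeThreadTileSizes(
        const std::vector<int64_t> &workloadPerWorkgroup,
        const std::vector<int64_t> &workgroupSize) {
      std::vector<int64_t> tileSizes;
      for (size_t i = 0; i < workloadPerWorkgroup.size(); ++i)
        tileSizes.push_back(workloadPerWorkgroup[i] / workgroupSize[i]);
      std::reverse(tileSizes.begin(), tileSizes.end());
      return tileSizes;
    }

For example, the dot_dispatch_0 test below uses
workloadPerWorkgroup = [256, 2] and workgroup_size = [64, 1, 1], which
gives thread tile sizes [2, 4], i.e. the values previously stored as
the third tileSizes entry.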

Explicitly disable the second level of tiling for the FFT op, as it is
not currently supported.
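
The opt-out is implemented as a filter predicate on the thread-level
tiling patterns (the addFilter call in LLVMGPUTileAndDistribute.cpp
below). A stand-alone sketch of the check, with a plain struct standing
in for the real MLIR operation type:

    #include <string>

    // Stand-in for an operation; the real filter checks
    // isa<linalg_ext::FftOp> on an mlir::Operation.
    struct Op { std::string name; };

    // Thread-level tiling is skipped for FFT because it does not
    // support a second level of tiling yet.
    bool allowThreadLevelTiling(const Op &op) {
      return op.name != "linalg_ext.fft";
    }
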
diff --git a/iree/compiler/Codegen/LLVMGPU/KernelConfig.cpp b/iree/compiler/Codegen/LLVMGPU/KernelConfig.cpp
index d0e796a..e839d2a 100644
--- a/iree/compiler/Codegen/LLVMGPU/KernelConfig.cpp
+++ b/iree/compiler/Codegen/LLVMGPU/KernelConfig.cpp
@@ -107,12 +107,6 @@
   // Tile all the reduction dimensions.
   ts.append(op.getNumReductionLoops(), tileK);
   tileSizes.push_back(ts);  // Workgroup level.
-  tileSizes.push_back({});  // Subgroup level.
-  // At the thread level only tile parallel loops.
-  SmallVector<int64_t, 4> invocationLevelTs(op.getNumParallelLoops() - 2, 1);
-  invocationLevelTs.append(
-      {tileX / workgroupSize[1], tileY / workgroupSize[0]});
-  tileSizes.push_back(invocationLevelTs);  // Thread level.
   return setOpConfigAndEntryPointFnTranslation(
       entryPoint, op, tileSizes, /*nativeVectorSizes=*/ArrayRef<int64_t>{},
       IREE::HAL::DispatchLoweringPassPipeline::LLVMGPUMatmulSimt,
@@ -163,15 +157,13 @@
 
   std::array<int64_t, 3> workgroupSize = {cudaWarpSize, 1, 1};
   unsigned vectorSize = 4;
-  SmallVector<int64_t, 4> workgroupTileSizes(numLoops, 1),
-      threadTileSizes(numLoops, 1);
+  SmallVector<int64_t, 4> workgroupTileSizes(numLoops, 1);
   // Set all non-parallel loops to zero tile size.
   llvm::DenseSet<unsigned> partitionedLoopsSet(partitionedLoops.begin(),
                                                partitionedLoops.end());
   for (auto depth : llvm::seq<int64_t>(0, numLoops)) {
     if (!partitionedLoopsSet.count(depth)) {
       workgroupTileSizes[depth] = 0;
-      threadTileSizes[depth] = 0;
     }
   }
 
@@ -206,7 +198,6 @@
   for (int64_t depth = numLoops; depth > 0; depth--) {
     if (partitionedLoopsSet.count(depth - 1)) {
       workgroupTileSizes[depth - 1] = cudaWarpSize * vectorSize;
-      threadTileSizes[depth - 1] = vectorSize;
       break;
     }
   }
@@ -217,8 +208,6 @@
     workgroupTileSizes.append(linalgOp.getNumReductionLoops(), 1);
   }
   tileSizes.emplace_back(std::move(workgroupTileSizes));  // Workgroup level
-  tileSizes.push_back({});                                // Subgroup level.
-  tileSizes.emplace_back(std::move(threadTileSizes));     // Thread level
   return setOpConfigAndEntryPointFnTranslation(
       entryPoint, op, tileSizes, /*nativeVectorSizes=*/ArrayRef<int64_t>{},
       IREE::HAL::DispatchLoweringPassPipeline::LLVMGPUVectorize, workgroupSize);
diff --git a/iree/compiler/Codegen/LLVMGPU/LLVMGPUTileAndDistribute.cpp b/iree/compiler/Codegen/LLVMGPU/LLVMGPUTileAndDistribute.cpp
index 3945bc3..165d711 100644
--- a/iree/compiler/Codegen/LLVMGPU/LLVMGPUTileAndDistribute.cpp
+++ b/iree/compiler/Codegen/LLVMGPU/LLVMGPUTileAndDistribute.cpp
@@ -69,11 +69,16 @@
 /// Patterns for thread level tiling.
 static void populateTilingToInvocationPatterns(
     MLIRContext *context, OwningRewritePatternList &patterns,
-    ArrayRef<int64_t> workgroupSize) {
+    ArrayRef<int64_t> workgroupSize, ArrayRef<int64_t> workloadPerWorkgroup) {
   linalg::TileSizeComputationFunction getInnerTileSizeFn =
-      [](OpBuilder &builder, Operation *operation) {
+      [&](OpBuilder &builder, Operation *operation) {
         SmallVector<Value, 4> tileSizesVal;
-        SmallVector<int64_t, 4> tileSizes = getTileSizes(operation, 2);
+        SmallVector<int64_t, 4> tileSizes;
+        for (auto workload : llvm::enumerate(workloadPerWorkgroup)) {
+          tileSizes.push_back(workload.value() /
+                              workgroupSize[workload.index()]);
+        }
+        std::reverse(tileSizes.begin(), tileSizes.end());
         if (tileSizes.empty()) return SmallVector<Value, 4>();
         SmallVector<unsigned> partitionedLoops = getPartitionedLoops(operation);
         llvm::DenseSet<unsigned> partitionedLoopsSet(partitionedLoops.begin(),
@@ -120,7 +125,11 @@
           {Identifier::get(getWorkgroupMarker(), context),
            Identifier::get(getWorkgroupKTiledMarker(), context),
            Identifier::get(getWorkgroupMemoryMarker(), context)},
-          Identifier::get(getVectorizeMarker(), context)));
+          Identifier::get(getVectorizeMarker(), context))
+          .addFilter([](Operation *op) {
+            // FFT doesn't support second level of tiling yet.
+            return success(!isa<linalg_ext::FftOp>(op));
+          }));
 }
 
 static LogicalResult copyToWorkgroupMemory(OpBuilder &b, Value src, Value dst) {
@@ -229,6 +238,12 @@
     auto workgroupSize = llvm::to_vector<4>(llvm::map_range(
         getEntryPoint(funcOp).workgroup_size().getValue(),
         [&](Attribute attr) { return attr.cast<IntegerAttr>().getInt(); }));
+    auto workloadPerWorkgroup = llvm::to_vector<4>(llvm::map_range(
+        getTranslationInfo(getEntryPoint(funcOp))
+            .workloadPerWorkgroup()
+            .getValue(),
+        [&](Attribute attr) { return attr.cast<IntegerAttr>().getInt(); }));
+
     int64_t flatWorkgroupSize =
         workgroupSize[0] * workgroupSize[1] * workgroupSize[2];
     // Only promote to workgroup size if there are multiple warps.
@@ -271,7 +286,7 @@
       // Apply last level of tiling and distribute to threads.
       OwningRewritePatternList threadLevelTilingPatterns(context);
       populateTilingToInvocationPatterns(context, threadLevelTilingPatterns,
-                                         workgroupSize);
+                                         workgroupSize, workloadPerWorkgroup);
       (void)applyPatternsAndFoldGreedily(funcOp,
                                          std::move(threadLevelTilingPatterns));
     }
diff --git a/iree/compiler/Codegen/LLVMGPU/test/distribute_to_thread.mlir b/iree/compiler/Codegen/LLVMGPU/test/distribute_to_thread.mlir
index b24f3ff..03cc0c5 100644
--- a/iree/compiler/Codegen/LLVMGPU/test/distribute_to_thread.mlir
+++ b/iree/compiler/Codegen/LLVMGPU/test/distribute_to_thread.mlir
@@ -1,6 +1,6 @@
 // RUN: iree-opt -split-input-file -pass-pipeline='hal.executable(hal.executable.variant(builtin.module(builtin.func(iree-llvmgpu-tile-and-distribute))))' %s | IreeFileCheck %s
 
-#config = {tileSizes = [[2, 256, 4], [], [2, 4]]}
+#config = {tileSizes = [[2, 256, 4]]}
 #executable_target_cuda_nvptx_fb = #hal.executable.target<"cuda", "cuda-nvptx-fb">
 #map0 = affine_map<()[s0] -> (s0 * 2)>
 #map1 = affine_map<()[s0] -> (s0 * 256)>
@@ -12,6 +12,7 @@
   hal.executable.entry_point @dot_dispatch_0 attributes {
     interface = @legacy_io,
     ordinal = 0 : index,
+    translation.info = {passPipeline = "LLVMGPUMatmulSimt" : i32, workloadPerWorkgroup = [256, 2]},
     workgroup_size = [64 : index, 1 : index, 1 : index]}
   builtin.module  {
     builtin.func @dot_dispatch_0() {
@@ -92,7 +93,7 @@
     hal.executable.entry_point @predict_dispatch_153 attributes {
       interface = @io,
       ordinal = 0 : index,
-      translation.info = {passPipeline = 2 : i32},
+      translation.info = {passPipeline = "LLVMGPUVectorize" : i32},
       workgroup_size = [1: index, 1: index, 1: index]}
     builtin.module  {
       builtin.func @predict_dispatch_153() {
diff --git a/iree/compiler/Codegen/LLVMGPU/test/gpu_set_num_workgroups.mlir b/iree/compiler/Codegen/LLVMGPU/test/gpu_set_num_workgroups.mlir
index 07cd247..d72539c 100644
--- a/iree/compiler/Codegen/LLVMGPU/test/gpu_set_num_workgroups.mlir
+++ b/iree/compiler/Codegen/LLVMGPU/test/gpu_set_num_workgroups.mlir
@@ -33,7 +33,7 @@
   }
 }
 
-//  CHECK-DAG: #[[CONFIG:.+]] = {tileSizes = {{\[}}[128], [], [4]{{\]}}}
+//  CHECK-DAG: #[[CONFIG:.+]] = {tileSizes = {{\[}}[128]{{\]}}}
 //  CHECK-DAG: #[[MAP0:.+]] = affine_map<()[s0] -> (s0 ceildiv 128)>
 //      CHECK: hal.executable.entry_point public @add_dispatch_0
 // CHECK-SAME:     passPipeline = "LLVMGPUVectorize"
@@ -92,7 +92,7 @@
     }
   }
 }
-//  CHECK-DAG: #[[CONFIG:.+]] = {tileSizes = {{\[}}[4, 2, 4], [], [1, 1]{{\]}}}
+//  CHECK-DAG: #[[CONFIG:.+]] = {tileSizes = {{\[}}[4, 2, 4]{{\]}}}
 //  CHECK-DAG: #[[MAP0:.+]] = affine_map<()[s0] -> (s0 ceildiv 2)>
 //  CHECK-DAG: #[[MAP1:.+]] = affine_map<()[s0] -> (s0 ceildiv 4)>
 //      CHECK: hal.executable.entry_point public @dot_dispatch_1
@@ -246,7 +246,7 @@
     }
   }
 }
-//  CHECK-DAG: #[[CONFIG:.+]] = {tileSizes = {{\[}}[1, 128], [], [1, 4]{{\]}}}
+//  CHECK-DAG: #[[CONFIG:.+]] = {tileSizes = {{\[}}[1, 128]{{\]}}}
 //  CHECK-DAG: #[[MAP:.+]] = affine_map<()[s0] -> (s0 ceildiv 128)>
 //      CHECK: hal.executable.entry_point public @tensor_insert_slice
 // CHECK-SAME:   translation.info = {passPipeline = "LLVMGPUVectorize", workloadPerWorkgroup = [128, 1]}
@@ -401,7 +401,7 @@
           %14 = affine.min affine_map<(d0)[s0] -> (-d0 + 1024, s0)>(%arg1)[%workgroup_size_x]
           %15 = linalg.init_tensor [%13, %14] : tensor<?x?xf32>
           %16 = linalg.fill(%cst, %15) : f32, tensor<?x?xf32> -> tensor<?x?xf32>
-          %17 = linalg.matmul {__internal_linalg_transform__ = "workgroup", lowering.config = {passPipeline = "LLVMGPUMatmulSimt", tileSizes = [[32, 256, 64], [], [4, 16]], workgroupSize = [16, 8, 1]}} ins(%8, %10 : tensor<?x256xf32>, tensor<256x?xf32>) outs(%16 : tensor<?x?xf32>) -> tensor<?x?xf32>
+          %17 = linalg.matmul {__internal_linalg_transform__ = "workgroup", lowering.config = {passPipeline = "LLVMGPUMatmulSimt", tileSizes = [[32, 256, 64]], workgroupSize = [16, 8, 1]}} ins(%8, %10 : tensor<?x256xf32>, tensor<256x?xf32>) outs(%16 : tensor<?x?xf32>) -> tensor<?x?xf32>
           flow.dispatch.tensor.store %17, %2, offsets = [%arg0, %arg1], sizes = [%11, %12], strides = [1, 1] : tensor<?x?xf32> -> !flow.dispatch.tensor<writeonly:128x1024xf32>
         }
       }
@@ -416,7 +416,7 @@
 }
 }
 
-//  CHECK-DAG: #[[CONFIG:.+]] = {{{.*}}tileSizes = {{\[}}[32, 256, 64], [], [4, 16]{{\]}}}
+//  CHECK-DAG: #[[CONFIG:.+]] = {{{.*}}tileSizes = {{\[}}[32, 256, 64]{{\]}}}
 //      CHECK: hal.executable.entry_point public @_lowering_config_test_dispatch_1
 // CHECK-SAME:     passPipeline = "LLVMGPUMatmulSimt"
 // CHECK-SAME:     workloadPerWorkgroup = [256, 32]