[LLVMGPU] Improve how we distribute small inner shapes (#12368)

Improve the default LLVMGPU kernel configuration for cases where the inner
dimensions are smaller than the flat workgroup size we want to target, by
distributing the workgroup threads across several inner dimensions instead of
only the innermost one. Also fix a bug in TileAndDistributeToWorkgroups where
partitioned loops with a zero tile size (i.e. loops that are not actually
distributed) were still counted when computing the number of workgroups.
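
The core of the KernelConfig.cpp change folds the flat workgroup size across
several trailing dimensions when the innermost dimension is too small to give
every thread at least one element. Below is a minimal standalone C++ sketch of
that heuristic, separate from the patch itself; the `Distribution` struct, the
helper name, and the flat workgroup size of 64 are illustrative assumptions
rather than IREE APIs. The example shape 320x320x3x3 mirrors the new
elementwise_pipeline.mlir test and reproduces its expected
`workgroup_size = [3, 3, 7]`.

```cpp
// Standalone sketch of the new small-inner-shape distribution heuristic.
// In the real code this path is only taken when the innermost dimension is
// smaller than vectorSize * workgroupSize[0].
#include <cstdint>
#include <iostream>
#include <vector>

struct Distribution {
  std::vector<int64_t> workgroupSize;  // x, y, z thread counts
  int64_t skipInnerTiling;             // trailing dims that skip workgroup tiling
};

Distribution distributeSmallInnerShape(const std::vector<int64_t> &shape,
                                       int64_t flatWorkgroupSize) {
  Distribution result{{1, 1, 1}, 0};
  int64_t flatWG = flatWorkgroupSize;
  size_t id = 0;
  // Walk the shape from innermost to outermost, assigning threads.
  for (auto it = shape.rbegin(); it != shape.rend(); ++it) {
    int64_t dim = *it;
    if (dim < flatWG) {
      // The dimension is fully covered by threads, so it needs no
      // workgroup-level tiling (tile size 0 later on).
      result.skipInnerTiling++;
      result.workgroupSize[id] = dim;
    } else {
      // Remaining threads all go to this dimension.
      result.workgroupSize[id] = flatWG;
      break;
    }
    flatWG /= dim;
    ++id;
    if (flatWG <= 1 || id >= result.workgroupSize.size()) break;
  }
  return result;
}

int main() {
  // Shape 320x320x3x3 with 64 threads: two inner dims of size 3 absorb the
  // threads first, the rest go to the next dim, giving [3, 3, 7] and
  // skipInnerTiling = 2.
  Distribution d = distributeSmallInnerShape({320, 320, 3, 3}, 64);
  std::cout << "workgroup_size = [" << d.workgroupSize[0] << ", "
            << d.workgroupSize[1] << ", " << d.workgroupSize[2] << "]\n"
            << "skipInnerTiling = " << d.skipInnerTiling << "\n";
  return 0;
}
```

The `skipInnerTiling` count is then used to leave the corresponding innermost
partitioned loops untiled (tile size 0), which is why the
TileAndDistributeToWorkgroups fix is needed: loops with a zero tile size must
not consume one of the three workgroup dimensions when the workgroup count is
computed.
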
diff --git a/compiler/src/iree/compiler/Codegen/Common/TileAndDistributeToWorkgroupsPass.cpp b/compiler/src/iree/compiler/Codegen/Common/TileAndDistributeToWorkgroupsPass.cpp
index b9e0b8e..47df80d 100644
--- a/compiler/src/iree/compiler/Codegen/Common/TileAndDistributeToWorkgroupsPass.cpp
+++ b/compiler/src/iree/compiler/Codegen/Common/TileAndDistributeToWorkgroupsPass.cpp
@@ -203,6 +203,7 @@
     // slowest varying.
     SmallVector<Value> numWorkgroups;
     for (auto partitionedLoop : llvm::reverse(partitionedLoops)) {
+      if (isConstantIntValue(tileSizes[partitionedLoop], 0)) continue;
       Value numTileAlongDim = getValueOrCreateConstantIndexOp(
           rewriter, loc, numTiles[partitionedLoop]);
       if (numWorkgroups.size() == kNumMaxParallelDims) {
diff --git a/compiler/src/iree/compiler/Codegen/Common/test/tile_and_distribute_to_workgroups.mlir b/compiler/src/iree/compiler/Codegen/Common/test/tile_and_distribute_to_workgroups.mlir
index ab7d75a..5185b61 100644
--- a/compiler/src/iree/compiler/Codegen/Common/test/tile_and_distribute_to_workgroups.mlir
+++ b/compiler/src/iree/compiler/Codegen/Common/test/tile_and_distribute_to_workgroups.mlir
@@ -335,6 +335,80 @@
 //      CHECK:         %[[GENERIC:.+]] = linalg.generic
 //      CHECK:         flow.dispatch.tensor.store %[[GENERIC]], %{{.+}}, offsets = [%[[IV0]], %[[IV1]], %[[IV2]], %[[IV3]]]
 
+
+// -----
+
+#config = #iree_codegen.lowering_config<tile_sizes = [[2, 64, 0, 64], [1, 1, 1, 4], [0, 0, 0, 0]]>
+#pipeline_layout = #hal.pipeline.layout<push_constants = 0, sets = [
+  #hal.descriptor_set.layout<0, bindings = [
+    #hal.descriptor_set.binding<0, storage_buffer>,
+    #hal.descriptor_set.binding<1, storage_buffer>,
+    #hal.descriptor_set.binding<2, storage_buffer>,
+    #hal.descriptor_set.binding<3, storage_buffer>
+  ]>
+]>
+#executable_target_embedded_elf_x86_64_ = #hal.executable.target<"llvm-cpu", "embedded-elf-x86_64", {
+  data_layout = "e-m:e-p270:32:32-p271:32:32-p272:64:64-i64:64-f80:128-n8:16:32:64-S128",
+  native_vector_size = 16 : index,
+  target_triple = "x86_64-unknown-linux-gnu"}>
+#map = affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)>
+#translation = #iree_codegen.translation_info<CPUDoubleTilingExpert>
+hal.executable private @add_distribute4D_zero_tile_size {
+  hal.executable.variant public @llvm, target = #executable_target_embedded_elf_x86_64_ {
+    hal.executable.export public @add_distribute4D_zero_tile_size layout(#pipeline_layout) attributes {translation_info = #translation} {
+    ^bb0(%arg0: !hal.device, %arg1: index, %arg2 : index, %arg3 : index, %arg4 :index):
+      %x, %y, %z = flow.dispatch.workgroup_count_from_dag_root %arg1, %arg2, %arg3, %arg4
+      hal.return %x, %y, %z : index, index, index
+    }
+    builtin.module {
+      func.func @add_distribute4D_zero_tile_size() {
+        %0 = hal.interface.constant.load[0] : index
+        %1 = hal.interface.constant.load[1] : index
+        %2 = hal.interface.constant.load[2] : index
+        %3 = hal.interface.constant.load[3] : index
+        %4 = hal.interface.binding.subspan set(0) binding(0) type(storage_buffer) alignment(32)
+            : !flow.dispatch.tensor<readonly:tensor<?x?x?x?xf32>>{%0, %1, %2, %3}
+        %5 = hal.interface.binding.subspan set(0) binding(1) type(storage_buffer) alignment(32)
+            : !flow.dispatch.tensor<readonly:tensor<?x?x?x?xf32>>{%0, %1, %2, %3}
+        %6 = hal.interface.binding.subspan set(0) binding(2) type(storage_buffer) alignment(32)
+            : !flow.dispatch.tensor<writeonly:tensor<?x?x?x?xf32>>{%0, %1, %2, %3}
+        %7 = flow.dispatch.tensor.load %4, offsets = [0, 0, 0, 0], sizes = [%0, %1, %2, %3], strides = [1, 1, 1, 1]
+            : !flow.dispatch.tensor<readonly:tensor<?x?x?x?xf32>>{%0, %1, %2, %3} -> tensor<?x?x?x?xf32>
+        %8 = flow.dispatch.tensor.load %5, offsets = [0, 0, 0, 0], sizes = [%0, %1, %2, %3], strides = [1, 1, 1, 1]
+            : !flow.dispatch.tensor<readonly:tensor<?x?x?x?xf32>>{%0, %1, %2, %3} -> tensor<?x?x?x?xf32>
+        %9 = tensor.empty(%0, %1, %2, %3) : tensor<?x?x?x?xf32>
+        %10 = linalg.generic {
+            indexing_maps = [#map, #map, #map], iterator_types = ["parallel", "parallel", "parallel", "parallel"]}
+            ins(%7, %8 : tensor<?x?x?x?xf32>, tensor<?x?x?x?xf32>) outs(%9 : tensor<?x?x?x?xf32>) attrs =  {lowering_config = #config} {
+        ^bb0(%arg0: f32, %arg1: f32, %arg2: f32):
+          %11 = arith.addf %arg0, %arg1 : f32
+          linalg.yield %11 : f32
+        } -> tensor<?x?x?x?xf32>
+        flow.dispatch.tensor.store %10, %6, offsets = [0, 0, 0, 0], sizes = [%0, %1, %2, %3], strides = [1, 1, 1, 1]
+            : tensor<?x?x?x?xf32> -> !flow.dispatch.tensor<writeonly:tensor<?x?x?x?xf32>>{%0, %1, %2, %3}
+        return
+      }
+    }
+  }
+}
+//  CHECK-DAG: #[[MAP:.+]] = affine_map<()[s0] -> (s0 ceildiv 64)>
+//  CHECK-DAG: #[[MAP1:.+]] = affine_map<()[s0] -> (s0 ceildiv 2)>
+
+//  CHECK-DAG: #[[TRANSLATION:.+]] = #iree_codegen.translation_info<CPUDoubleTilingExpert>
+//      CHECK: hal.executable.export public @add_distribute4D_zero_tile_size
+// CHECK-SAME:   translation_info = #[[TRANSLATION]]
+// CHECK-NEXT:   (%[[DEVICE:.+]]: !hal.device,
+// CHECK-SAME:    %[[WORKLOAD_0:[a-zA-Z0-9_]+]]: index
+// CHECK-SAME:    %[[WORKLOAD_1:[a-zA-Z0-9_]+]]: index
+// CHECK-SAME:    %[[WORKLOAD_2:[a-zA-Z0-9_]+]]: index
+// CHECK-SAME:    %[[WORKLOAD_3:[a-zA-Z0-9_]+]]: index)
+//  CHECK-DAG:    %[[D0:.+]] = affine.apply #[[MAP1]]()[%[[WORKLOAD_0]]]
+//  CHECK-DAG:    %[[D1:.+]] = affine.apply #[[MAP]]()[%[[WORKLOAD_1]]]
+//  CHECK-DAG:    %[[D2:.+]] = affine.apply #[[MAP]]()[%[[WORKLOAD_3]]]
+//      CHECK:    hal.return %[[D2]], %[[D1]], %[[D0]] : index, index, index
+//      CHECK: func.func @add_distribute4D_zero_tile_size()
+
+
 // -----
 
 #config = #iree_codegen.lowering_config<tile_sizes = [[1, 64, 64, 0], [1, 16, 4, 0], [0, 0, 0, 64]]>
diff --git a/compiler/src/iree/compiler/Codegen/LLVMGPU/KernelConfig.cpp b/compiler/src/iree/compiler/Codegen/LLVMGPU/KernelConfig.cpp
index 731c2fc..d093133 100644
--- a/compiler/src/iree/compiler/Codegen/LLVMGPU/KernelConfig.cpp
+++ b/compiler/src/iree/compiler/Codegen/LLVMGPU/KernelConfig.cpp
@@ -428,7 +428,7 @@
       workgroupTileSizes[depth] = 0;
     }
   }
-
+  int64_t skipInnerTiling = 0;
   if (auto genericOp = dyn_cast<linalg::GenericOp>(op)) {
     for (auto [index, outputOperand] :
          llvm::enumerate(genericOp.getDpsInitOperands())) {
@@ -462,6 +462,26 @@
         vectorSize = 1;
         break;
       }
+      // If the inner dimension is too small to give every thread one element,
+      // distribute the flat workgroup size across more of the inner dimensions.
+      if (shape.back() < vectorSize * workgroupSize[0]) {
+        int64_t flatWG = workgroupSize[0];
+        vectorSize = 1;
+        int64_t id = 0;
+        for (int64_t dim : llvm::reverse(shape)) {
+          if (dim < flatWG) {
+            skipInnerTiling++;
+            workgroupSize[id] = dim;
+          } else {
+            workgroupSize[id] = flatWG;
+            break;
+          }
+          flatWG = flatWG / dim;
+          id++;
+          if (flatWG <= 1 || id >= workgroupSize.size()) break;
+        }
+        break;
+      }
     }
   }
 
@@ -480,13 +500,24 @@
         IREE::Codegen::DispatchLoweringPassPipeline::LLVMGPUVectorize;
   }
 
+  int64_t id = 0;
   // Set the inner most parallel loop to `lowerTs`.
   for (int64_t depth = numLoops; depth > 0; depth--) {
     if (partitionedLoopsSet.count(depth - 1)) {
-      workgroupTileSizes[depth - 1] = workgroupSize[0] * vectorSize;
+      if (skipInnerTiling > 0) {
+        // For dimensions that don't need to be distributed across blocks,
+        // skip workgroup tiling by setting the tile size to 0.
+        workgroupTileSizes[depth - 1] = 0;
+        skipInnerTiling--;
+        id++;
+        if (id >= workgroupSize.size()) break;
+        continue;
+      }
+      workgroupTileSizes[depth - 1] = workgroupSize[id] * vectorSize;
       break;
     }
   }
+
   if (linalgOp) {
     // Tile reduction dimension to 4 to allow doing load4 if the reduction size
     // is the most inner dimension.
diff --git a/compiler/src/iree/compiler/Codegen/LLVMGPU/test/BUILD b/compiler/src/iree/compiler/Codegen/LLVMGPU/test/BUILD
index a60227e..a80cc50 100644
--- a/compiler/src/iree/compiler/Codegen/LLVMGPU/test/BUILD
+++ b/compiler/src/iree/compiler/Codegen/LLVMGPU/test/BUILD
@@ -25,6 +25,7 @@
             "create_async_groups.mlir",
             "distribute_to_thread.mlir",
             "distribute_foreach.mlir",
+            "elementwise_pipeline.mlir",
             "gpu_set_num_workgroups.mlir",
             "nvvm_pipeline_test.mlir",
             "nvvm_mma_sync_pipeline_test.mlir",
diff --git a/compiler/src/iree/compiler/Codegen/LLVMGPU/test/CMakeLists.txt b/compiler/src/iree/compiler/Codegen/LLVMGPU/test/CMakeLists.txt
index 78b834c..49e37e8 100644
--- a/compiler/src/iree/compiler/Codegen/LLVMGPU/test/CMakeLists.txt
+++ b/compiler/src/iree/compiler/Codegen/LLVMGPU/test/CMakeLists.txt
@@ -21,6 +21,7 @@
     "create_async_groups.mlir"
     "distribute_foreach.mlir"
     "distribute_to_thread.mlir"
+    "elementwise_pipeline.mlir"
     "gpu_set_num_workgroups.mlir"
     "illegal_configuration.mlir"
     "legalize.mlir"
diff --git a/compiler/src/iree/compiler/Codegen/LLVMGPU/test/elementwise_pipeline.mlir b/compiler/src/iree/compiler/Codegen/LLVMGPU/test/elementwise_pipeline.mlir
new file mode 100644
index 0000000..af2fce0
--- /dev/null
+++ b/compiler/src/iree/compiler/Codegen/LLVMGPU/test/elementwise_pipeline.mlir
@@ -0,0 +1,36 @@
+
+// RUN: iree-opt --split-input-file --pass-pipeline="builtin.module(hal.executable(hal.executable.variant(iree-llvmgpu-lower-executable-target)))" %s | FileCheck %s
+
+hal.executable @warp_reduction_dispatch {
+hal.executable.variant public @cuda_nvptx_fb, target = <"cuda", "cuda-nvptx-fb", {target_arch = "sm_35"}> {
+  hal.executable.export public @forward_dispatch_0_generic_320x320x3x3 ordinal(0) layout(#hal.pipeline.layout<push_constants = 0, sets = [<0, bindings = [<0, storage_buffer, ReadOnly>, <1, storage_buffer>]>]>) {
+  ^bb0(%arg0: !hal.device, %arg1: index, %arg2: index, %arg3: index, %arg4: index):
+    %x, %y, %z = flow.dispatch.workgroup_count_from_dag_root %arg1, %arg2, %arg3, %arg4
+    hal.return %x, %y, %z : index, index, index
+  }
+  builtin.module {
+    func.func @forward_dispatch_0_generic_320x320x3x3() {
+      %c0 = arith.constant 0 : index
+      %cst = arith.constant 0.000000e+00 : f32
+      %0 = hal.interface.binding.subspan set(0) binding(0) type(storage_buffer) alignment(64) offset(%c0) flags(ReadOnly) : !flow.dispatch.tensor<readonly:tensor<3x320x320x3xf32>>
+      %1 = hal.interface.binding.subspan set(0) binding(1) type(storage_buffer) alignment(64) offset(%c0) : !flow.dispatch.tensor<writeonly:tensor<320x320x3x3xf32>>
+      %2 = flow.dispatch.tensor.load %0, offsets = [0, 0, 0, 0], sizes = [3, 320, 320, 3], strides = [1, 1, 1, 1] : !flow.dispatch.tensor<readonly:tensor<3x320x320x3xf32>> -> tensor<3x320x320x3xf32>
+      %3 = tensor.empty() : tensor<320x320x3x3xf32>
+      %4 = linalg.generic {indexing_maps = [affine_map<(d0, d1, d2, d3) -> (d2, d1, d0, d3)>, affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)>], iterator_types = ["parallel", "parallel", "parallel", "parallel"]} ins(%2 : tensor<3x320x320x3xf32>) outs(%3 : tensor<320x320x3x3xf32>) {
+      ^bb0(%in: f32, %out: f32):
+        %5 = arith.addf %in, %cst : f32
+        linalg.yield %5 : f32
+      } -> tensor<320x320x3x3xf32>
+      flow.dispatch.tensor.store %4, %1, offsets = [0, 0, 0, 0], sizes = [320, 320, 3, 3], strides = [1, 1, 1, 1] : tensor<320x320x3x3xf32> -> !flow.dispatch.tensor<writeonly:tensor<320x320x3x3xf32>>
+      return
+    }
+  }
+}
+}
+
+// CHECK-LABEL: hal.executable.export public @forward_dispatch_0_generic_320x320x3x3
+//     CHECK:     workgroup_size = [3 : index, 3 : index, 7 : index]}
+// CHECK-DAG:     %[[C46:.+]] = arith.constant 46 : index
+// CHECK-DAG:     %[[C320:.+]] = arith.constant 320 : index
+// CHECK-DAG:     %[[C1:.+]] = arith.constant 1 : index
+//     CHECK:     hal.return %[[C46]], %[[C320]], %[[C1]] : index, index, index