[spirv] Account for dynamic dimensions when computing parallelism (#15211)

When deciding whether a dispatch has enough parallelism, we can ignore dynamic
dimensions and count only the static ones; previously a single dynamic dimension
made the parallel size unknown and disabled the workgroup-size adjustment
entirely. This makes llama2 dispatches like `generic_Dx11008x32x128_f16`
perform better.
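For illustration, here is a minimal standalone sketch of the heuristic after this
change (not the actual `KernelConfig.cpp` code); the initial group size of 512,
the `vectorSize` of 8 f16 elements per 128-bit vector, and the dynamic-dimension
sentinel are assumptions made only for this example:

```cpp
// Minimal sketch: dynamic dimensions are skipped instead of disabling the
// workgroup-size shrinking loop.
#include <cstdint>
#include <cstdio>
#include <limits>
#include <vector>

// Stand-in for ShapedType::kDynamic; the exact sentinel is irrelevant here.
constexpr int64_t kDynamic = std::numeric_limits<int64_t>::min();

int64_t pickGroupSize(const std::vector<int64_t> &parallelBounds,
                      int64_t reductionSize, int64_t subgroupSize,
                      int64_t vectorSize, int64_t groupSize) {
  // Only static parallel dimensions contribute to the parallelism estimate.
  int64_t parallelSize = 1;
  for (int64_t bound : parallelBounds)
    if (bound != kDynamic) parallelSize *= bound;

  const int64_t parallelThreshold = 256; // roughly 2x hardware compute units
  const int64_t targetVectorCount = 8;   // 128-bit vectors per thread
  while (parallelSize > parallelThreshold &&
         (groupSize / 2) % subgroupSize == 0 &&
         reductionSize / (groupSize * vectorSize) < targetVectorCount) {
    groupSize /= 2;    // fewer subgroups per workgroup...
    parallelSize /= 2; // ...so more workgroups per hardware compute unit
  }
  return groupSize;
}

int main() {
  // generic_Dx11008x32x128_f16: parallel dims {D (dynamic), 11008},
  // reduction dims {32, 128} -> reductionSize = 4096. With the old code the
  // dynamic D made parallelSize nullopt and the loop never ran.
  std::vector<int64_t> parallelBounds = {kDynamic, 11008};
  int64_t groupSize =
      pickGroupSize(parallelBounds, /*reductionSize=*/32 * 128,
                    /*subgroupSize=*/64, /*vectorSize=*/8, /*groupSize=*/512);
  std::printf("group size: %lld\n", (long long)groupSize); // 512 -> 256 -> 128 -> 64
  return 0;
}
```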
diff --git a/compiler/src/iree/compiler/Codegen/SPIRV/KernelConfig.cpp b/compiler/src/iree/compiler/Codegen/SPIRV/KernelConfig.cpp
index 5013f8a..c02f09c 100644
--- a/compiler/src/iree/compiler/Codegen/SPIRV/KernelConfig.cpp
+++ b/compiler/src/iree/compiler/Codegen/SPIRV/KernelConfig.cpp
@@ -1275,26 +1275,23 @@
   //    memory cache behavior.
   // Both means we cannot use a too large workgroup size.
 
-  std::optional<int64_t> parallelSize = 1;
+  int64_t parallelSize = 1;
   for (int64_t dim : parallelDims) {
-    if (ShapedType::isDynamic(bounds[dim])) {
-      parallelSize = std::nullopt;
-      break;
-    }
-    *parallelSize *= bounds[dim];
+    if (!ShapedType::isDynamic(bounds[dim]))
+      parallelSize *= bounds[dim];
   }
   // Total parallel size that can fill the GPU with enough workgroups.
   // TODO: query from the target device; roughly 2x hardware compute unit.
   int parallelThreshold = 256;
   // How many 128-bit vectors each thread should at least read.
   const int targetVectorCount = 8;
-  while (parallelSize && *parallelSize > parallelThreshold &&
+  while (parallelSize > parallelThreshold &&
          (groupSize / 2) % subgroupSize == 0 &&
          reductionSize / (groupSize * vectorSize) < targetVectorCount) {
     // Use fewer subgroups per workgroup...
     groupSize /= 2;
     // in order to host more workgroups per hardware compute unit.
-    *parallelSize /= 2;
+    parallelSize /= 2;
   }
 
   // Current warp reduction pattern is a two step butterfly warp reduce.
diff --git a/compiler/src/iree/compiler/Codegen/SPIRV/test/config_default_matvec.mlir b/compiler/src/iree/compiler/Codegen/SPIRV/test/config_default_matvec.mlir
index 0a86ca1..0774abb 100644
--- a/compiler/src/iree/compiler/Codegen/SPIRV/test/config_default_matvec.mlir
+++ b/compiler/src/iree/compiler/Codegen/SPIRV/test/config_default_matvec.mlir
@@ -446,3 +446,106 @@
 //       CHECK: func.func @i4_dequant_matvec()
 //       CHECK:   linalg.generic
 //  CHECK-SAME:     lowering_config = #[[$CONFIG]]
+
+// -----
+
+#pipeline_layout = #hal.pipeline.layout<push_constants = 0, sets = [
+  #hal.descriptor_set.layout<0, bindings = [
+    #hal.descriptor_set.binding<0, storage_buffer>,
+    #hal.descriptor_set.binding<1, storage_buffer>,
+    #hal.descriptor_set.binding<2, storage_buffer>,
+    #hal.descriptor_set.binding<3, storage_buffer>,
+    #hal.descriptor_set.binding<4, storage_buffer>
+  ]>
+]>
+
+hal.executable @i4_dequant_matvec {
+  hal.executable.variant @vulkan_spirv_fb, target = <"vulkan-spirv", "vulkan-spirv-fb", {
+      spirv.target_env = #spirv.target_env<#spirv.vce<v1.4, [Shader, GroupNonUniform, GroupNonUniformShuffle], []>, Unknown:IntegratedGPU, #spirv.resource_limits<
+        max_compute_shared_memory_size = 32768,
+        max_compute_workgroup_invocations = 1024,
+        max_compute_workgroup_size = [1024, 1024, 1024],
+        subgroup_size = 64>>
+    }> {
+    hal.executable.export @i4_dequant_matvec layout(#pipeline_layout)
+    builtin.module {
+      func.func @i4_dequant_matvec() {
+        %c32_i64 = arith.constant 32 : i64
+        %cst = arith.constant 0.000000e+00 : f16
+        %c0 = arith.constant 0 : index
+        %0 = hal.interface.constant.load[0] : i32
+        %1 = hal.interface.constant.load[1] : i32
+        %2 = hal.interface.constant.load[2] : i32
+        %3 = hal.interface.constant.load[3] : i32
+        %4 = hal.interface.constant.load[4] : i32
+        %5 = hal.interface.constant.load[5] : i32
+        %6 = hal.interface.constant.load[6] : i32
+        %7 = arith.index_castui %0 : i32 to index
+        %8 = arith.index_castui %1 : i32 to index
+        %9 = arith.index_castui %2 : i32 to index
+        %10 = arith.extui %3 : i32 to i64
+        %11 = arith.extui %4 : i32 to i64
+        %12 = arith.shli %11, %c32_i64 : i64
+        %13 = arith.ori %10, %12 : i64
+        %14 = arith.index_castui %13 : i64 to index
+        %15 = arith.extui %5 : i32 to i64
+        %16 = arith.extui %6 : i32 to i64
+        %17 = arith.shli %16, %c32_i64 : i64
+        %18 = arith.ori %15, %17 : i64
+        %19 = arith.index_castui %18 : i64 to index
+        %20 = hal.interface.binding.subspan set(0) binding(0) type(storage_buffer) alignment(64) offset(%7) flags(ReadOnly) : !flow.dispatch.tensor<readonly:tensor<11008x32x128xi4>>
+        %21 = hal.interface.binding.subspan set(0) binding(0) type(storage_buffer) alignment(64) offset(%8) flags(ReadOnly) : !flow.dispatch.tensor<readonly:tensor<11008x32xf16>>
+        %22 = hal.interface.binding.subspan set(0) binding(0) type(storage_buffer) alignment(64) offset(%9) flags(ReadOnly) : !flow.dispatch.tensor<readonly:tensor<11008x32xf16>>
+        %23 = flow.dispatch.workload.ordinal %19, 0 : index
+        %24 = hal.interface.binding.subspan set(0) binding(1) type(storage_buffer) alignment(64) offset(%c0) flags(ReadOnly) : !flow.dispatch.tensor<readonly:tensor<?x32x128xf16>>{%23}
+        %25 = hal.interface.binding.subspan set(0) binding(2) type(storage_buffer) alignment(64) offset(%14) : !flow.dispatch.tensor<writeonly:tensor<?x11008xf16>>{%23}
+        %26 = flow.dispatch.tensor.load %20, offsets = [0, 0, 0], sizes = [11008, 32, 128], strides = [1, 1, 1] : !flow.dispatch.tensor<readonly:tensor<11008x32x128xi4>> -> tensor<11008x32x128xi4>
+        %27 = flow.dispatch.tensor.load %21, offsets = [0, 0], sizes = [11008, 32], strides = [1, 1] : !flow.dispatch.tensor<readonly:tensor<11008x32xf16>> -> tensor<11008x32xf16>
+        %28 = flow.dispatch.tensor.load %22, offsets = [0, 0], sizes = [11008, 32], strides = [1, 1] : !flow.dispatch.tensor<readonly:tensor<11008x32xf16>> -> tensor<11008x32xf16>
+        %29 = flow.dispatch.tensor.load %24, offsets = [0, 0, 0], sizes = [%23, 32, 128], strides = [1, 1, 1] : !flow.dispatch.tensor<readonly:tensor<?x32x128xf16>>{%23} -> tensor<?x32x128xf16>
+        %30 = tensor.empty() : tensor<11008x32x128xf16>
+        %31 = tensor.empty(%23) : tensor<?x11008xf16>
+        %32 = linalg.fill ins(%cst : f16) outs(%31 : tensor<?x11008xf16>) -> tensor<?x11008xf16>
+        %33 = linalg.generic {
+          indexing_maps = [
+            affine_map<(d0, d1, d2) -> (d0, d1, d2)>,
+            affine_map<(d0, d1, d2) -> (d0, d1)>,
+            affine_map<(d0, d1, d2) -> (d0, d1)>,
+            affine_map<(d0, d1, d2) -> (d0, d1, d2)>],
+          iterator_types = ["parallel", "parallel", "parallel"]}
+        ins(%26, %27, %28 : tensor<11008x32x128xi4>, tensor<11008x32xf16>, tensor<11008x32xf16>)
+        outs(%30 : tensor<11008x32x128xf16>) {
+        ^bb0(%in: i4, %in_0: f16, %in_1: f16, %out: f16):
+          %35 = arith.extui %in : i4 to i32
+          %36 = arith.uitofp %35 : i32 to f16
+          %37 = arith.subf %36, %in_1 : f16
+          %38 = arith.mulf %37, %in_0 : f16
+          linalg.yield %38 : f16
+        } -> tensor<11008x32x128xf16>
+        %34 = linalg.generic {
+          indexing_maps = [
+            affine_map<(d0, d1, d2, d3) -> (d0, d2, d3)>,
+            affine_map<(d0, d1, d2, d3) -> (d1, d2, d3)>,
+            affine_map<(d0, d1, d2, d3) -> (d0, d1)>],
+          iterator_types = ["parallel", "parallel", "reduction", "reduction"]}
+        ins(%29, %33 : tensor<?x32x128xf16>, tensor<11008x32x128xf16>) outs(%32 : tensor<?x11008xf16>) {
+        ^bb0(%in: f16, %in_0: f16, %out: f16):
+          %35 = arith.mulf %in, %in_0 : f16
+          %36 = arith.addf %35, %out : f16
+          linalg.yield %36 : f16
+        } -> tensor<?x11008xf16>
+        flow.dispatch.tensor.store %34, %25, offsets = [0, 0], sizes = [%23, 11008], strides = [1, 1] : tensor<?x11008xf16> -> !flow.dispatch.tensor<writeonly:tensor<?x11008xf16>>{%23}
+        return
+      }
+    }
+  }
+}
+
+//   CHECK-DAG: #[[$CONFIG:.+]] = #iree_codegen.lowering_config<tile_sizes = {{\[}}[1, 1], [0, 0, 4, 128]{{\]}}>
+//   CHECK-DAG: #[[$TRANSLATION:.+]] = #iree_codegen.translation_info<SPIRVSubgroupReduce>
+// CHECK-LABEL: hal.executable.export public @i4_dequant_matvec
+//  CHECK-SAME:   translation_info = #[[$TRANSLATION]]
+//  CHECK-SAME:   workgroup_size = [64 : index, 1 : index, 1 : index]
+//       CHECK: func.func @i4_dequant_matvec()
+//       CHECK:   linalg.generic
+//  CHECK-SAME:     lowering_config = #[[$CONFIG]]
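
For this test's shape, the static parallel extent is 11008, which exceeds the
256-workgroup threshold; assuming the configuration starts from a larger initial
workgroup size, the loop halves it until the subgroup-alignment check stops at
the 64-wide subgroup, matching the expected
`workgroup_size = [64 : index, 1 : index, 1 : index]` above. Before this change
the dynamic batch dimension made the parallel size unknown, so the loop never
reduced the workgroup size for this dispatch.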