[spirv] Account for dynamic dimensions when computing parallelism (#15211)
We can simply ignore dynamic dimensions and only consider the static ones when deciding whether we have enough parallelism, instead of giving up on the heuristic as soon as one dimension is dynamic. This makes llama2 dispatches like `generic_Dx11008x32x128_f16` perform better.
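For reference, here is a minimal standalone sketch of the updated heuristic. The helper names, the `kDynamicSize` sentinel, and the starting group size of 512 are illustrative assumptions and not IREE's actual API; only the loop structure mirrors the diff below. The point is that a single dynamic batch dimension no longer disables the subgroup-count reduction.

```cpp
// Standalone sketch of the heuristic changed in this patch (names and the
// starting group size are assumptions for illustration).
#include <cstdint>
#include <iostream>
#include <vector>

// Stand-in for ShapedType::isDynamic(); the real sentinel value in MLIR is an
// implementation detail we do not rely on here.
constexpr int64_t kDynamicSize = -1;
static bool isDynamicDim(int64_t size) { return size == kDynamicSize; }

// Halve the workgroup size while there is enough (static) parallelism to keep
// the GPU busy with more, smaller workgroups.
int64_t shrinkGroupSize(const std::vector<int64_t> &parallelBounds,
                        int64_t groupSize, int64_t subgroupSize,
                        int64_t reductionSize, int64_t vectorSize) {
  // Multiply only the static parallel dimensions; a dynamic dimension no
  // longer forces us to keep the original group size.
  int64_t parallelSize = 1;
  for (int64_t bound : parallelBounds)
    if (!isDynamicDim(bound))
      parallelSize *= bound;

  const int64_t parallelThreshold = 256; // roughly 2x hardware compute units
  const int64_t targetVectorCount = 8;   // 128-bit vectors per thread
  while (parallelSize > parallelThreshold &&
         (groupSize / 2) % subgroupSize == 0 &&
         reductionSize / (groupSize * vectorSize) < targetVectorCount) {
    groupSize /= 2;     // fewer subgroups per workgroup...
    parallelSize /= 2;  // ...so more workgroups per hardware compute unit.
  }
  return groupSize;
}

int main() {
  // The llama2 matvec from above: parallel dims D (dynamic) and 11008,
  // reduction 32x128 over f16 (vectorSize 8 for 128-bit loads). The starting
  // group size of 512 is assumed for illustration.
  std::vector<int64_t> parallelBounds = {kDynamicSize, 11008};
  int64_t groupSize = shrinkGroupSize(parallelBounds, /*groupSize=*/512,
                                      /*subgroupSize=*/64,
                                      /*reductionSize=*/32 * 128,
                                      /*vectorSize=*/8);
  std::cout << groupSize << "\n"; // 64, matching the new test's workgroup_size
}
```

With the old code, the dynamic leading dimension set `parallelSize` to `std::nullopt` and the loop never ran; now the 11008 static parallel elements are enough to shrink the workgroup down to a single subgroup.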
diff --git a/compiler/src/iree/compiler/Codegen/SPIRV/KernelConfig.cpp b/compiler/src/iree/compiler/Codegen/SPIRV/KernelConfig.cpp
index 5013f8a..c02f09c 100644
--- a/compiler/src/iree/compiler/Codegen/SPIRV/KernelConfig.cpp
+++ b/compiler/src/iree/compiler/Codegen/SPIRV/KernelConfig.cpp
@@ -1275,26 +1275,23 @@
// memory cache behavior.
// Both mean we cannot use too large a workgroup size.
- std::optional<int64_t> parallelSize = 1;
+ int64_t parallelSize = 1;
for (int64_t dim : parallelDims) {
- if (ShapedType::isDynamic(bounds[dim])) {
- parallelSize = std::nullopt;
- break;
- }
- *parallelSize *= bounds[dim];
+ if (!ShapedType::isDynamic(bounds[dim]))
+ parallelSize *= bounds[dim];
}
// Total parallel size that can fill the GPU with enough workgroups.
// TODO: query from the target device; roughly 2x hardware compute unit.
int parallelThreshold = 256;
// How many 128-bit vectors each thread should at least read.
const int targetVectorCount = 8;
- while (parallelSize && *parallelSize > parallelThreshold &&
+ while (parallelSize > parallelThreshold &&
(groupSize / 2) % subgroupSize == 0 &&
reductionSize / (groupSize * vectorSize) < targetVectorCount) {
// Use fewer subgroups per workgroup...
groupSize /= 2;
// in order to host more workgroups per hardware compute unit.
- *parallelSize /= 2;
+ parallelSize /= 2;
}
// Current warp reduction pattern is a two step butterfly warp reduce.
diff --git a/compiler/src/iree/compiler/Codegen/SPIRV/test/config_default_matvec.mlir b/compiler/src/iree/compiler/Codegen/SPIRV/test/config_default_matvec.mlir
index 0a86ca1..0774abb 100644
--- a/compiler/src/iree/compiler/Codegen/SPIRV/test/config_default_matvec.mlir
+++ b/compiler/src/iree/compiler/Codegen/SPIRV/test/config_default_matvec.mlir
@@ -446,3 +446,106 @@
// CHECK: func.func @i4_dequant_matvec()
// CHECK: linalg.generic
// CHECK-SAME: lowering_config = #[[$CONFIG]]
+
+// -----
+
+#pipeline_layout = #hal.pipeline.layout<push_constants = 0, sets = [
+ #hal.descriptor_set.layout<0, bindings = [
+ #hal.descriptor_set.binding<0, storage_buffer>,
+ #hal.descriptor_set.binding<1, storage_buffer>,
+ #hal.descriptor_set.binding<2, storage_buffer>,
+ #hal.descriptor_set.binding<3, storage_buffer>,
+ #hal.descriptor_set.binding<4, storage_buffer>
+ ]>
+]>
+
+hal.executable @i4_dequant_matvec {
+ hal.executable.variant @vulkan_spirv_fb, target = <"vulkan-spirv", "vulkan-spirv-fb", {
+ spirv.target_env = #spirv.target_env<#spirv.vce<v1.4, [Shader, GroupNonUniform, GroupNonUniformShuffle], []>, Unknown:IntegratedGPU, #spirv.resource_limits<
+ max_compute_shared_memory_size = 32768,
+ max_compute_workgroup_invocations = 1024,
+ max_compute_workgroup_size = [1024, 1024, 1024],
+ subgroup_size = 64>>
+ }> {
+ hal.executable.export @i4_dequant_matvec layout(#pipeline_layout)
+ builtin.module {
+ func.func @i4_dequant_matvec() {
+ %c32_i64 = arith.constant 32 : i64
+ %cst = arith.constant 0.000000e+00 : f16
+ %c0 = arith.constant 0 : index
+ %0 = hal.interface.constant.load[0] : i32
+ %1 = hal.interface.constant.load[1] : i32
+ %2 = hal.interface.constant.load[2] : i32
+ %3 = hal.interface.constant.load[3] : i32
+ %4 = hal.interface.constant.load[4] : i32
+ %5 = hal.interface.constant.load[5] : i32
+ %6 = hal.interface.constant.load[6] : i32
+ %7 = arith.index_castui %0 : i32 to index
+ %8 = arith.index_castui %1 : i32 to index
+ %9 = arith.index_castui %2 : i32 to index
+ %10 = arith.extui %3 : i32 to i64
+ %11 = arith.extui %4 : i32 to i64
+ %12 = arith.shli %11, %c32_i64 : i64
+ %13 = arith.ori %10, %12 : i64
+ %14 = arith.index_castui %13 : i64 to index
+ %15 = arith.extui %5 : i32 to i64
+ %16 = arith.extui %6 : i32 to i64
+ %17 = arith.shli %16, %c32_i64 : i64
+ %18 = arith.ori %15, %17 : i64
+ %19 = arith.index_castui %18 : i64 to index
+ %20 = hal.interface.binding.subspan set(0) binding(0) type(storage_buffer) alignment(64) offset(%7) flags(ReadOnly) : !flow.dispatch.tensor<readonly:tensor<11008x32x128xi4>>
+ %21 = hal.interface.binding.subspan set(0) binding(0) type(storage_buffer) alignment(64) offset(%8) flags(ReadOnly) : !flow.dispatch.tensor<readonly:tensor<11008x32xf16>>
+ %22 = hal.interface.binding.subspan set(0) binding(0) type(storage_buffer) alignment(64) offset(%9) flags(ReadOnly) : !flow.dispatch.tensor<readonly:tensor<11008x32xf16>>
+ %23 = flow.dispatch.workload.ordinal %19, 0 : index
+ %24 = hal.interface.binding.subspan set(0) binding(1) type(storage_buffer) alignment(64) offset(%c0) flags(ReadOnly) : !flow.dispatch.tensor<readonly:tensor<?x32x128xf16>>{%23}
+ %25 = hal.interface.binding.subspan set(0) binding(2) type(storage_buffer) alignment(64) offset(%14) : !flow.dispatch.tensor<writeonly:tensor<?x11008xf16>>{%23}
+ %26 = flow.dispatch.tensor.load %20, offsets = [0, 0, 0], sizes = [11008, 32, 128], strides = [1, 1, 1] : !flow.dispatch.tensor<readonly:tensor<11008x32x128xi4>> -> tensor<11008x32x128xi4>
+ %27 = flow.dispatch.tensor.load %21, offsets = [0, 0], sizes = [11008, 32], strides = [1, 1] : !flow.dispatch.tensor<readonly:tensor<11008x32xf16>> -> tensor<11008x32xf16>
+ %28 = flow.dispatch.tensor.load %22, offsets = [0, 0], sizes = [11008, 32], strides = [1, 1] : !flow.dispatch.tensor<readonly:tensor<11008x32xf16>> -> tensor<11008x32xf16>
+ %29 = flow.dispatch.tensor.load %24, offsets = [0, 0, 0], sizes = [%23, 32, 128], strides = [1, 1, 1] : !flow.dispatch.tensor<readonly:tensor<?x32x128xf16>>{%23} -> tensor<?x32x128xf16>
+ %30 = tensor.empty() : tensor<11008x32x128xf16>
+ %31 = tensor.empty(%23) : tensor<?x11008xf16>
+ %32 = linalg.fill ins(%cst : f16) outs(%31 : tensor<?x11008xf16>) -> tensor<?x11008xf16>
+ %33 = linalg.generic {
+ indexing_maps = [
+ affine_map<(d0, d1, d2) -> (d0, d1, d2)>,
+ affine_map<(d0, d1, d2) -> (d0, d1)>,
+ affine_map<(d0, d1, d2) -> (d0, d1)>,
+ affine_map<(d0, d1, d2) -> (d0, d1, d2)>],
+ iterator_types = ["parallel", "parallel", "parallel"]}
+ ins(%26, %27, %28 : tensor<11008x32x128xi4>, tensor<11008x32xf16>, tensor<11008x32xf16>)
+ outs(%30 : tensor<11008x32x128xf16>) {
+ ^bb0(%in: i4, %in_0: f16, %in_1: f16, %out: f16):
+ %35 = arith.extui %in : i4 to i32
+ %36 = arith.uitofp %35 : i32 to f16
+ %37 = arith.subf %36, %in_1 : f16
+ %38 = arith.mulf %37, %in_0 : f16
+ linalg.yield %38 : f16
+ } -> tensor<11008x32x128xf16>
+ %34 = linalg.generic {
+ indexing_maps = [
+ affine_map<(d0, d1, d2, d3) -> (d0, d2, d3)>,
+ affine_map<(d0, d1, d2, d3) -> (d1, d2, d3)>,
+ affine_map<(d0, d1, d2, d3) -> (d0, d1)>],
+ iterator_types = ["parallel", "parallel", "reduction", "reduction"]}
+ ins(%29, %33 : tensor<?x32x128xf16>, tensor<11008x32x128xf16>) outs(%32 : tensor<?x11008xf16>) {
+ ^bb0(%in: f16, %in_0: f16, %out: f16):
+ %35 = arith.mulf %in, %in_0 : f16
+ %36 = arith.addf %35, %out : f16
+ linalg.yield %36 : f16
+ } -> tensor<?x11008xf16>
+ flow.dispatch.tensor.store %34, %25, offsets = [0, 0], sizes = [%23, 11008], strides = [1, 1] : tensor<?x11008xf16> -> !flow.dispatch.tensor<writeonly:tensor<?x11008xf16>>{%23}
+ return
+ }
+ }
+ }
+}
+
+// CHECK-DAG: #[[$CONFIG:.+]] = #iree_codegen.lowering_config<tile_sizes = {{\[}}[1, 1], [0, 0, 4, 128]{{\]}}>
+// CHECK-DAG: #[[$TRANSLATION:.+]] = #iree_codegen.translation_info<SPIRVSubgroupReduce>
+// CHECK-LABEL: hal.executable.export public @i4_dequant_matvec
+// CHECK-SAME: translation_info = #[[$TRANSLATION]]
+// CHECK-SAME: workgroup_size = [64 : index, 1 : index, 1 : index]
+// CHECK: func.func @i4_dequant_matvec()
+// CHECK: linalg.generic
+// CHECK-SAME: lowering_config = #[[$CONFIG]]