[spirv] Enable vectorized codegen for i8 matmul (#12262)
The main goal is to exercise this codegen path before introducing
further optimizations and codegen changes.
Locally, I see a 2x speedup on MobileBERT and a 1.8x speedup on
EfficientNet with this patch on a Pixel 6.
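
The key change relaxes the element bit-width check in KernelConfig.cpp
so that 8-bit types pass the vectorization filter alongside 16- and
32-bit ones. A minimal standalone sketch of the new predicate (the
isVectorizableBitWidth wrapper and the main driver are illustrative,
not part of the patch):

#include "llvm/ADT/STLExtras.h"
#include <cassert>

// Illustrative wrapper around the check added in KernelConfig.cpp. The
// cast to int in the patch matters: getIntOrFloatBitWidth() returns
// unsigned, and llvm::is_contained({8, 16, 32}, x) needs a single
// deduced element type for its initializer_list overload.
static bool isVectorizableBitWidth(int elementBits) {
  return llvm::is_contained({8, 16, 32}, elementBits);
}

int main() {
  assert(isVectorizableBitWidth(8));    // newly enabled: i8 matmuls vectorize
  assert(isVectorizableBitWidth(16));   // i16/f16, as before
  assert(isVectorizableBitWidth(32));   // i32/f32, as before
  assert(!isVectorizableBitWidth(64));  // i64 still falls back to distribution
  return 0;
}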
diff --git a/compiler/src/iree/compiler/Codegen/SPIRV/KernelConfig.cpp b/compiler/src/iree/compiler/Codegen/SPIRV/KernelConfig.cpp
index 2a59748..16e75c6 100644
--- a/compiler/src/iree/compiler/Codegen/SPIRV/KernelConfig.cpp
+++ b/compiler/src/iree/compiler/Codegen/SPIRV/KernelConfig.cpp
@@ -541,8 +541,9 @@
auto lhsType = lhs->get().getType().cast<ShapedType>();
auto rhsType = rhs->get().getType().cast<ShapedType>();
- auto elementBits = lhsType.getElementType().getIntOrFloatBitWidth();
- if (elementBits != 16 && elementBits != 32) return success();
+ auto elementBits =
+ static_cast<int>(lhsType.getElementType().getIntOrFloatBitWidth());
+ if (!llvm::is_contained({8, 16, 32}, elementBits)) return success();
ArrayRef<int64_t> lhsShape = lhsType.getShape();
ArrayRef<int64_t> rhsShape = rhsType.getShape();
diff --git a/compiler/src/iree/compiler/Codegen/SPIRV/test/config_default_matmul.mlir b/compiler/src/iree/compiler/Codegen/SPIRV/test/config_default_matmul.mlir
index 8fbf54e..b71b594 100644
--- a/compiler/src/iree/compiler/Codegen/SPIRV/test/config_default_matmul.mlir
+++ b/compiler/src/iree/compiler/Codegen/SPIRV/test/config_default_matmul.mlir
@@ -55,7 +55,7 @@
// -----
-// Non-16 / non-32 bit types cannot be vectorized right now.
+// 8-bit integers can be vectorized.
#pipeline_layout = #hal.pipeline.layout<push_constants = 0, sets = [
#hal.descriptor_set.layout<0, bindings = [
@@ -64,7 +64,7 @@
#hal.descriptor_set.binding<2, storage_buffer>
]>
]>
-hal.executable private @matmul_64x16 {
+hal.executable private @matmul_64x16xi8 {
hal.executable.variant public @vulkan_spirv_fb, target = <"vulkan", "vulkan-spirv-fb", {
spirv.target_env = #spirv.target_env<#spirv.vce<v1.4, [Shader], []>, Unknown:IntegratedGPU, #spirv.resource_limits<
max_compute_shared_memory_size = 16384,
@@ -72,9 +72,9 @@
max_compute_workgroup_size = [128, 128, 64],
subgroup_size = 64>>
}> {
- hal.executable.export public @matmul_64x16 layout(#pipeline_layout)
+ hal.executable.export public @matmul_64x16xi8 layout(#pipeline_layout)
builtin.module {
- func.func @matmul_64x16() {
+ func.func @matmul_64x16xi8() {
%c0 = arith.constant 0 : index
%c16 = arith.constant 16 : index
%c64 = arith.constant 64 : index
@@ -98,12 +98,66 @@
}
}
+// CHECK-DAG: #[[CONFIG:.+]] = #iree_codegen.lowering_config<tile_sizes = {{\[}}[64, 16], [2, 8], [0, 0, 8]{{\]}}>
+// CHECK-DAG: #[[TRANSLATION:.+]] = #iree_codegen.translation_info<SPIRVBaseVectorize>
+// CHECK: hal.executable.export public @matmul_64x16xi8
+// CHECK-SAME: translation_info = #[[TRANSLATION]]
+// CHECK-SAME: workgroup_size = [2 : index, 32 : index, 1 : index]
+// CHECK: func.func @matmul_64x16xi8()
+// CHECK: linalg.matmul
+// CHECK-SAME: lowering_config = #[[CONFIG]]
+
+// -----
+
+// Non-8/16/32 bit types cannot be vectorized right now.
+
+#pipeline_layout = #hal.pipeline.layout<push_constants = 0, sets = [
+ #hal.descriptor_set.layout<0, bindings = [
+ #hal.descriptor_set.binding<0, storage_buffer>,
+ #hal.descriptor_set.binding<1, storage_buffer>,
+ #hal.descriptor_set.binding<2, storage_buffer>
+ ]>
+]>
+hal.executable private @matmul_64x16xi64 {
+ hal.executable.variant public @vulkan_spirv_fb, target = <"vulkan", "vulkan-spirv-fb", {
+ spirv.target_env = #spirv.target_env<#spirv.vce<v1.4, [Shader, Int64], []>, Unknown:IntegratedGPU, #spirv.resource_limits<
+ max_compute_shared_memory_size = 16384,
+ max_compute_workgroup_invocations = 128,
+ max_compute_workgroup_size = [128, 128, 64],
+ subgroup_size = 64>>
+ }> {
+ hal.executable.export public @matmul_64x16xi64 layout(#pipeline_layout)
+ builtin.module {
+ func.func @matmul_64x16xi64() {
+ %c0 = arith.constant 0 : index
+ %c16 = arith.constant 16 : index
+ %c64 = arith.constant 64 : index
+ %c0_i64 = arith.constant 0 : i64
+ %0 = hal.interface.binding.subspan set(0) binding(0) type(storage_buffer) : !flow.dispatch.tensor<readonly:tensor<64x32xi64>>
+ %1 = hal.interface.binding.subspan set(0) binding(1) type(storage_buffer) : !flow.dispatch.tensor<readonly:tensor<32x16xi64>>
+ %2 = hal.interface.binding.subspan set(0) binding(2) type(storage_buffer) : !flow.dispatch.tensor<writeonly:tensor<64x16xi64>>
+ %8 = flow.dispatch.tensor.load %0, offsets = [0, 0], sizes = [64, 32], strides = [1, 1]
+ : !flow.dispatch.tensor<readonly:tensor<64x32xi64>> -> tensor<64x32xi64>
+ %10 = flow.dispatch.tensor.load %1, offsets = [0, 0], sizes = [32, 16], strides = [1, 1]
+ : !flow.dispatch.tensor<readonly:tensor<32x16xi64>> -> tensor<32x16xi64>
+ %15 = tensor.empty() : tensor<64x16xi64>
+ %16 = linalg.fill ins(%c0_i64 : i64) outs(%15 : tensor<64x16xi64>) -> tensor<64x16xi64>
+ %17 = linalg.matmul {__internal_linalg_transform__ = "workgroup"}
+ ins(%8, %10 : tensor<64x32xi64>, tensor<32x16xi64>) outs(%16 : tensor<64x16xi64>) -> tensor<64x16xi64>
+ flow.dispatch.tensor.store %17, %2, offsets = [0, 0], sizes = [64, 16], strides = [1, 1]
+ : tensor<64x16xi64> -> !flow.dispatch.tensor<writeonly:tensor<64x16xi64>>
+ return
+ }
+ }
+ }
+}
+
// CHECK-DAG: #[[CONFIG:.+]] = #iree_codegen.lowering_config<tile_sizes = {{\[}}[4, 16], [1, 1]{{\]}}>
// CHECK-DAG: #[[TRANSLATION:.+]] = #iree_codegen.translation_info<SPIRVBaseDistribute>
-// CHECK: hal.executable.export public @matmul_64x16
+// CHECK: hal.executable.export public @matmul_64x16xi64
// CHECK-SAME: translation_info = #[[TRANSLATION]]
// CHECK-SAME: workgroup_size = [16 : index, 4 : index, 1 : index]
-// CHECK: func.func @matmul_64x16()
+// CHECK: func.func @matmul_64x16xi64()
// CHECK: linalg.matmul
// CHECK-SAME: lowering_config = #[[CONFIG]]
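
Note the contrast between the two tests: with this patch the i8 matmul
selects the SPIRVBaseVectorize pipeline with a [0, 0, 8] reduction tile
and [2, 8] per-thread tiles, while the i64 matmul still falls back to
SPIRVBaseDistribute with scalar [1, 1] thread tiles.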