[SPIRV] Enable small float support in SPIR-V pipeline. (#23391)
Add ConvertUnsupportedFloatArithPass and
ConvertUnsupportedFloatToIntBuffersPass to the SPIR-V lowering pipeline
to emulate bf16 and fp8 arithmetic and buffer types. Also add expansion
patterns for f4E2M1FN, and f8E8M0FNU in ConvertToSPIRVPass.
Enable small_float_arith and fp_to_subbyte e2e tests for vulkan-spirv.
The fp4_f32_conversion test is excluded because scaling_extf/truncf
lowering is not properly supported yet.
Fixes https://github.com/iree-org/iree/issues/15772
Signed-off-by: hanhanW <hanhan0912@gmail.com>
diff --git a/compiler/src/iree/compiler/Codegen/SPIRV/ConvertToSPIRVPass.cpp b/compiler/src/iree/compiler/Codegen/SPIRV/ConvertToSPIRVPass.cpp
index ca91dab..e675160 100644
--- a/compiler/src/iree/compiler/Codegen/SPIRV/ConvertToSPIRVPass.cpp
+++ b/compiler/src/iree/compiler/Codegen/SPIRV/ConvertToSPIRVPass.cpp
@@ -633,6 +633,8 @@
{
RewritePatternSet patterns(context);
arith::populateExpandBFloat16Patterns(patterns);
+ arith::populateExpandF4E2M1Patterns(patterns);
+ arith::populateExpandF8E8M0Patterns(patterns);
arith::BitcastOp::getCanonicalizationPatterns(patterns, context);
if (failed(applyPatternsGreedily(moduleOp, std::move(patterns)))) {
moduleOp.emitOpError() << "failed running bf16 extf/trunc patterns";
diff --git a/compiler/src/iree/compiler/Codegen/SPIRV/Passes.cpp b/compiler/src/iree/compiler/Codegen/SPIRV/Passes.cpp
index 4a38927..d94043d 100644
--- a/compiler/src/iree/compiler/Codegen/SPIRV/Passes.cpp
+++ b/compiler/src/iree/compiler/Codegen/SPIRV/Passes.cpp
@@ -200,6 +200,7 @@
// possible. In SPIR-V we don't use memref descriptor so it's not possible
// to handle subview ops.
.addPass(memref::createFoldMemRefAliasOpsPass)
+ .addPass(createConvertUnsupportedFloatArithPass)
.addPass(createEmulateNarrowTypePass)
.addPass(createCanonicalizerPass)
.addPass(createCSEPass)
@@ -254,16 +255,17 @@
.addPass(createSPIRVEmulateI64Pass)
.addPass(createConvertBf16ArithToF32Pass)
.addPass([]() {
- // Convert bf16 buffers to i16. Other float types are not yet
- // supported in the SPIR-V pipeline.
+ // Convert unsupported float buffer types to integer types.
+ // SPIR-V doesn't natively support bf16 or fp8 types, so we convert
+ // them to integer types of the same bit width for storage.
return createConvertUnsupportedFloatToIntBuffersPass(
ConvertUnsupportedFloatToIntBuffersPassOptions{
/*includeBf16=*/true,
- /*includeF8E5M2=*/false,
- /*includeF8E4M3FN=*/false,
- /*includeF8E5M2FNUZ=*/false,
- /*includeF8E4M3FNUZ=*/false,
- /*includeF8E8M0FNU=*/false,
+ /*includeF8E5M2=*/true,
+ /*includeF8E4M3FN=*/true,
+ /*includeF8E5M2FNUZ=*/true,
+ /*includeF8E4M3FNUZ=*/true,
+ /*includeF8E8M0FNU=*/true,
});
})
.addPass(createCanonicalizerPass)
diff --git a/tests/e2e/linalg/BUILD.bazel b/tests/e2e/linalg/BUILD.bazel
index 4b5f343..8929ea1 100644
--- a/tests/e2e/linalg/BUILD.bazel
+++ b/tests/e2e/linalg/BUILD.bazel
@@ -128,21 +128,23 @@
[
"argmax.mlir",
"conv2d.mlir",
+ "fp_to_subbyte.mlir",
"gather_like_ops.mlir",
"index.mlir",
"narrow_n_matmuls.mlir",
+ "small_float_arith.mlir",
"softmax.mlir",
"subbyte_to_fp.mlir",
],
include = ["*.mlir"],
exclude = [
- "fp_to_subbyte.mlir",
+ # TODO(#23387): Enable the test once scaling_extf/truncf lowering is
+ # supported.
"fp4_f32_conversion.mlir",
"large_linalg_matmul.mlir",
"pack.mlir",
"pack_dynamic_inner_tiles.mlir",
"pack_i8.mlir",
- "small_float_arith.mlir",
"unpack.mlir",
],
)
diff --git a/tests/e2e/linalg/CMakeLists.txt b/tests/e2e/linalg/CMakeLists.txt
index 40b05b2..3bc34d7 100644
--- a/tests/e2e/linalg/CMakeLists.txt
+++ b/tests/e2e/linalg/CMakeLists.txt
@@ -98,9 +98,11 @@
SRCS
"argmax.mlir"
"conv2d.mlir"
+ "fp_to_subbyte.mlir"
"gather_like_ops.mlir"
"index.mlir"
"narrow_n_matmuls.mlir"
+ "small_float_arith.mlir"
"softmax.mlir"
"subbyte_to_fp.mlir"
TARGET_BACKEND