[SPIRV] Enable small float support in SPIR-V pipeline. (#23391)

Add ConvertUnsupportedFloatArithPass and
ConvertUnsupportedFloatToIntBuffersPass to the SPIR-V lowering pipeline
to emulate bf16 and fp8 arithmetic and buffer types. Also add expansion
patterns for f4E2M1FN, and f8E8M0FNU in ConvertToSPIRVPass.

Enable small_float_arith and fp_to_subbyte e2e tests for vulkan-spirv.
The fp4_f32_conversion test is excluded because scaling_extf/truncf
lowering is not properly supported yet.

Fixes https://github.com/iree-org/iree/issues/15772

Signed-off-by: hanhanW <hanhan0912@gmail.com>
diff --git a/compiler/src/iree/compiler/Codegen/SPIRV/ConvertToSPIRVPass.cpp b/compiler/src/iree/compiler/Codegen/SPIRV/ConvertToSPIRVPass.cpp
index ca91dab..e675160 100644
--- a/compiler/src/iree/compiler/Codegen/SPIRV/ConvertToSPIRVPass.cpp
+++ b/compiler/src/iree/compiler/Codegen/SPIRV/ConvertToSPIRVPass.cpp
@@ -633,6 +633,8 @@
   {
     RewritePatternSet patterns(context);
     arith::populateExpandBFloat16Patterns(patterns);
+    arith::populateExpandF4E2M1Patterns(patterns);
+    arith::populateExpandF8E8M0Patterns(patterns);
     arith::BitcastOp::getCanonicalizationPatterns(patterns, context);
     if (failed(applyPatternsGreedily(moduleOp, std::move(patterns)))) {
       moduleOp.emitOpError() << "failed running bf16 extf/trunc patterns";
diff --git a/compiler/src/iree/compiler/Codegen/SPIRV/Passes.cpp b/compiler/src/iree/compiler/Codegen/SPIRV/Passes.cpp
index 4a38927..d94043d 100644
--- a/compiler/src/iree/compiler/Codegen/SPIRV/Passes.cpp
+++ b/compiler/src/iree/compiler/Codegen/SPIRV/Passes.cpp
@@ -200,6 +200,7 @@
       // possible. In SPIR-V we don't use memref descriptor so it's not possible
       // to handle subview ops.
       .addPass(memref::createFoldMemRefAliasOpsPass)
+      .addPass(createConvertUnsupportedFloatArithPass)
       .addPass(createEmulateNarrowTypePass)
       .addPass(createCanonicalizerPass)
       .addPass(createCSEPass)
@@ -254,16 +255,17 @@
       .addPass(createSPIRVEmulateI64Pass)
       .addPass(createConvertBf16ArithToF32Pass)
       .addPass([]() {
-        // Convert bf16 buffers to i16. Other float types are not yet
-        // supported in the SPIR-V pipeline.
+        // Convert unsupported float buffer types to integer types.
+        // SPIR-V doesn't natively support bf16 or fp8 types, so we convert
+        // them to integer types of the same bit width for storage.
         return createConvertUnsupportedFloatToIntBuffersPass(
             ConvertUnsupportedFloatToIntBuffersPassOptions{
                 /*includeBf16=*/true,
-                /*includeF8E5M2=*/false,
-                /*includeF8E4M3FN=*/false,
-                /*includeF8E5M2FNUZ=*/false,
-                /*includeF8E4M3FNUZ=*/false,
-                /*includeF8E8M0FNU=*/false,
+                /*includeF8E5M2=*/true,
+                /*includeF8E4M3FN=*/true,
+                /*includeF8E5M2FNUZ=*/true,
+                /*includeF8E4M3FNUZ=*/true,
+                /*includeF8E8M0FNU=*/true,
             });
       })
       .addPass(createCanonicalizerPass)
diff --git a/tests/e2e/linalg/BUILD.bazel b/tests/e2e/linalg/BUILD.bazel
index 4b5f343..8929ea1 100644
--- a/tests/e2e/linalg/BUILD.bazel
+++ b/tests/e2e/linalg/BUILD.bazel
@@ -128,21 +128,23 @@
     [
         "argmax.mlir",
         "conv2d.mlir",
+        "fp_to_subbyte.mlir",
         "gather_like_ops.mlir",
         "index.mlir",
         "narrow_n_matmuls.mlir",
+        "small_float_arith.mlir",
         "softmax.mlir",
         "subbyte_to_fp.mlir",
     ],
     include = ["*.mlir"],
     exclude = [
-        "fp_to_subbyte.mlir",
+        # TODO(#23387): Enable the test once scaling_extf/truncf lowering is
+        # supported.
         "fp4_f32_conversion.mlir",
         "large_linalg_matmul.mlir",
         "pack.mlir",
         "pack_dynamic_inner_tiles.mlir",
         "pack_i8.mlir",
-        "small_float_arith.mlir",
         "unpack.mlir",
     ],
 )
diff --git a/tests/e2e/linalg/CMakeLists.txt b/tests/e2e/linalg/CMakeLists.txt
index 40b05b2..3bc34d7 100644
--- a/tests/e2e/linalg/CMakeLists.txt
+++ b/tests/e2e/linalg/CMakeLists.txt
@@ -98,9 +98,11 @@
   SRCS
     "argmax.mlir"
     "conv2d.mlir"
+    "fp_to_subbyte.mlir"
     "gather_like_ops.mlir"
     "index.mlir"
     "narrow_n_matmuls.mlir"
+    "small_float_arith.mlir"
     "softmax.mlir"
     "subbyte_to_fp.mlir"
   TARGET_BACKEND