Support 8-bit floats in the compiler. (#18886)

This is a step in a series of PRs adding support for 8-bit flows. It's
sandwiched between https://github.com/iree-org/iree/pull/18885 and
subsequent PRs that will actually do something useful with this.

---------

Signed-off-by: Benoit Jacob <jacob.benoit.1@gmail.com>
diff --git a/compiler/src/iree/compiler/Dialect/HAL/IR/HALAttrs.td b/compiler/src/iree/compiler/Dialect/HAL/IR/HALAttrs.td
index 976a4ca..3c1ebd7 100644
--- a/compiler/src/iree/compiler/Dialect/HAL/IR/HALAttrs.td
+++ b/compiler/src/iree/compiler/Dialect/HAL/IR/HALAttrs.td
@@ -336,6 +336,11 @@
 def HAL_CollectiveElementType_Float32 : I32EnumAttrCase<"Float32", 9, "f32">;
 def HAL_CollectiveElementType_Float64 : I32EnumAttrCase<"Float64", 10, "f64">;
 def HAL_CollectiveElementType_BFloat16 : I32EnumAttrCase<"BFloat16", 11, "bf16">;
+def HAL_CollectiveElementType_Float8E5M2 : I32EnumAttrCase<"Float8E5M2", 12, "f8E5M2">;
+def HAL_CollectiveElementType_Float8E4M3 : I32EnumAttrCase<"Float8E4M3", 13, "f8E4M3">;
+def HAL_CollectiveElementType_Float8E5M2FNUZ : I32EnumAttrCase<"Float8E5M2FNUZ", 14, "f8E5M2FNUZ">;
+def HAL_CollectiveElementType_Float8E4M3FNUZ : I32EnumAttrCase<"Float8E4M3FNUZ", 15, "f8E4M3FNUZ">;
+
 def HAL_CollectiveElementTypeAttr :
     I32EnumAttr<"CollectiveElementType", "valid CollectiveElementType", [
       HAL_CollectiveElementType_Sint8,
@@ -350,6 +355,10 @@
       HAL_CollectiveElementType_Float32,
       HAL_CollectiveElementType_Float64,
       HAL_CollectiveElementType_BFloat16,
+      HAL_CollectiveElementType_Float8E5M2,
+      HAL_CollectiveElementType_Float8E4M3,
+      HAL_CollectiveElementType_Float8E5M2FNUZ,
+      HAL_CollectiveElementType_Float8E4M3FNUZ
     ]> {
   let cppNamespace = "::mlir::iree_compiler::IREE::HAL";
 }
diff --git a/compiler/src/iree/compiler/Dialect/HAL/IR/HALOps.cpp b/compiler/src/iree/compiler/Dialect/HAL/IR/HALOps.cpp
index d06c2dc..f1a820f 100644
--- a/compiler/src/iree/compiler/Dialect/HAL/IR/HALOps.cpp
+++ b/compiler/src/iree/compiler/Dialect/HAL/IR/HALOps.cpp
@@ -878,6 +878,10 @@
   kFloatIEEE = kFloat | 0x01,
   kFloatBrain = kFloat | 0x02,
   kFloatComplex = kFloat | 0x03,
+  kFloat8E5M2 = kFloat | 0x04,
+  kFloat8E4M3 = kFloat | 0x05,
+  kFloat8E5M2FNUZ = kFloat | 0x06,
+  kFloat8E4M3FNUZ = kFloat | 0x07,
 };
 
 constexpr inline int32_t makeElementTypeValue(NumericalType numericalType,
@@ -905,7 +909,14 @@
     return makeElementTypeValue(numericalType, intType.getWidth());
   } else if (auto floatType = llvm::dyn_cast_if_present<FloatType>(type)) {
     switch (APFloat::SemanticsToEnum(floatType.getFloatSemantics())) {
+    case APFloat::S_Float8E5M2:
+      return makeElementTypeValue(NumericalType::kFloat8E5M2, 8);
+    case APFloat::S_Float8E4M3:
+      return makeElementTypeValue(NumericalType::kFloat8E4M3, 8);
+    case APFloat::S_Float8E5M2FNUZ:
+      return makeElementTypeValue(NumericalType::kFloat8E5M2FNUZ, 8);
     case APFloat::S_Float8E4M3FNUZ:
+      return makeElementTypeValue(NumericalType::kFloat8E4M3FNUZ, 8);
     case APFloat::S_IEEEhalf:
     case APFloat::S_IEEEsingle:
     case APFloat::S_IEEEdouble: