Support 8-bit floats in the compiler. (#18886)
This is a step in a series of PRs adding support for 8-bit flows. It's
sandwiched between https://github.com/iree-org/iree/pull/18885 and
subsequent PRs that will actually do something useful with this.
---------
Signed-off-by: Benoit Jacob <jacob.benoit.1@gmail.com>
diff --git a/compiler/src/iree/compiler/Dialect/HAL/IR/HALAttrs.td b/compiler/src/iree/compiler/Dialect/HAL/IR/HALAttrs.td
index 976a4ca..3c1ebd7 100644
--- a/compiler/src/iree/compiler/Dialect/HAL/IR/HALAttrs.td
+++ b/compiler/src/iree/compiler/Dialect/HAL/IR/HALAttrs.td
@@ -336,6 +336,11 @@
def HAL_CollectiveElementType_Float32 : I32EnumAttrCase<"Float32", 9, "f32">;
def HAL_CollectiveElementType_Float64 : I32EnumAttrCase<"Float64", 10, "f64">;
def HAL_CollectiveElementType_BFloat16 : I32EnumAttrCase<"BFloat16", 11, "bf16">;
+def HAL_CollectiveElementType_Float8E5M2 : I32EnumAttrCase<"Float8E5M2", 12, "f8E5M2">;
+def HAL_CollectiveElementType_Float8E4M3 : I32EnumAttrCase<"Float8E4M3", 13, "f8E4M3">;
+def HAL_CollectiveElementType_Float8E5M2FNUZ : I32EnumAttrCase<"Float8E5M2FNUZ", 14, "f8E5M2FNUZ">;
+def HAL_CollectiveElementType_Float8E4M3FNUZ : I32EnumAttrCase<"Float8E4M3FNUZ", 15, "f8E4M3FNUZ">;
+
def HAL_CollectiveElementTypeAttr :
I32EnumAttr<"CollectiveElementType", "valid CollectiveElementType", [
HAL_CollectiveElementType_Sint8,
@@ -350,6 +355,10 @@
HAL_CollectiveElementType_Float32,
HAL_CollectiveElementType_Float64,
HAL_CollectiveElementType_BFloat16,
+ HAL_CollectiveElementType_Float8E5M2,
+ HAL_CollectiveElementType_Float8E4M3,
+ HAL_CollectiveElementType_Float8E5M2FNUZ,
+ HAL_CollectiveElementType_Float8E4M3FNUZ
]> {
let cppNamespace = "::mlir::iree_compiler::IREE::HAL";
}
diff --git a/compiler/src/iree/compiler/Dialect/HAL/IR/HALOps.cpp b/compiler/src/iree/compiler/Dialect/HAL/IR/HALOps.cpp
index d06c2dc..f1a820f 100644
--- a/compiler/src/iree/compiler/Dialect/HAL/IR/HALOps.cpp
+++ b/compiler/src/iree/compiler/Dialect/HAL/IR/HALOps.cpp
@@ -878,6 +878,10 @@
kFloatIEEE = kFloat | 0x01,
kFloatBrain = kFloat | 0x02,
kFloatComplex = kFloat | 0x03,
+ kFloat8E5M2 = kFloat | 0x04,
+ kFloat8E4M3 = kFloat | 0x05,
+ kFloat8E5M2FNUZ = kFloat | 0x06,
+ kFloat8E4M3FNUZ = kFloat | 0x07,
};
constexpr inline int32_t makeElementTypeValue(NumericalType numericalType,
@@ -905,7 +909,14 @@
return makeElementTypeValue(numericalType, intType.getWidth());
} else if (auto floatType = llvm::dyn_cast_if_present<FloatType>(type)) {
switch (APFloat::SemanticsToEnum(floatType.getFloatSemantics())) {
+ case APFloat::S_Float8E5M2:
+ return makeElementTypeValue(NumericalType::kFloat8E5M2, 8);
+ case APFloat::S_Float8E4M3:
+ return makeElementTypeValue(NumericalType::kFloat8E4M3, 8);
+ case APFloat::S_Float8E5M2FNUZ:
+ return makeElementTypeValue(NumericalType::kFloat8E5M2FNUZ, 8);
case APFloat::S_Float8E4M3FNUZ:
+ return makeElementTypeValue(NumericalType::kFloat8E4M3FNUZ, 8);
case APFloat::S_IEEEhalf:
case APFloat::S_IEEEsingle:
case APFloat::S_IEEEdouble: