Drop unused `_none` values in ukernel enums (#13877)
This is a minor optimization and arguably a cleaner separation of
validation vs. implementation code, but the primary motivation was to
avoid a miscompile in LLVM/riscv32, as explained in this comment in
`mmt4d_internal.h`:
```
// This unreachable statement is not just an optimization, it also works
// around an LLVM/riscv32 miscompile.
// When we used to have an iree_uk_mmt4d_type_none value equal to 0 and
// were returning it here, that caused this whole switch statement to be
// miscompiled by LLVM/riscv32 as if it were UB. That value was passed to
// `iree_uk_type_bit_count(x)`, which evaluates to `1<<(x - 3)`, which is
// UB if x<3. So it was fair to treat that default: clause as UB, but
// LLVM/riscv32 was incorrectly treating the whole switch as UB.
```
diff --git a/runtime/src/iree/builtins/ukernel/mmt4d.c b/runtime/src/iree/builtins/ukernel/mmt4d.c
index 5ea0034..d7ccc44 100644
--- a/runtime/src/iree/builtins/ukernel/mmt4d.c
+++ b/runtime/src/iree/builtins/ukernel/mmt4d.c
@@ -14,8 +14,9 @@
IREE_UK_FLAG_MMT4D_ACCUMULATE |
IREE_UK_FLAG_MMT4D_PREFER_INTRINSICS;
IREE_UK_ASSERT(!(params->flags & ~allflags));
- iree_uk_mmt4d_type_t mmt4d_type = iree_uk_mmt4d_type(params->flags);
- IREE_UK_ASSERT(mmt4d_type != iree_uk_mmt4d_type_none);
+ iree_uk_uint32_t flags_type = params->flags & IREE_UK_FLAG_MMT4D_TYPE_MASK;
+ IREE_UK_ASSERT(flags_type == IREE_UK_FLAG_MMT4D_TYPE_F32F32F32 ||
+ flags_type == IREE_UK_FLAG_MMT4D_TYPE_I8I8I32);
// Some implementations may wish to avoid supporting absurdly wide types. For
// instance, K is the innermost (i.e. hottest) loop bound, so some 32bit
// targets may benefit from K being int32, not int64. We still let K be of
@@ -29,6 +30,7 @@
IREE_UK_ASSERT(IREE_UK_VALUE_IN_UNSIGNED_INT_RANGE(params->N0, 15));
IREE_UK_ASSERT(IREE_UK_VALUE_IN_UNSIGNED_INT_RANGE(params->K0, 15));
// Ensure iree_uk_mmt4d_tile_generic_max_bytes large enough for this tile.
+ iree_uk_mmt4d_type_t mmt4d_type = iree_uk_mmt4d_type(params->flags);
IREE_UK_ASSERT(params->M0 * params->N0 *
iree_uk_type_size(iree_uk_mmt4d_out_type(mmt4d_type)) <=
iree_uk_mmt4d_tile_generic_max_bytes);
diff --git a/runtime/src/iree/builtins/ukernel/mmt4d_internal.h b/runtime/src/iree/builtins/ukernel/mmt4d_internal.h
index 37cd87f..ba5acf5 100644
--- a/runtime/src/iree/builtins/ukernel/mmt4d_internal.h
+++ b/runtime/src/iree/builtins/ukernel/mmt4d_internal.h
@@ -10,7 +10,6 @@
#include "iree/builtins/ukernel/mmt4d.h"
typedef enum iree_uk_mmt4d_type_t {
- iree_uk_mmt4d_type_none = 0,
iree_uk_mmt4d_type_f32f32f32 =
IREE_UK_TIE_3_TYPES_LITERAL(FLOAT_32, FLOAT_32, FLOAT_32),
iree_uk_mmt4d_type_i8i8i32 =
@@ -24,7 +23,16 @@
case IREE_UK_FLAG_MMT4D_TYPE_I8I8I32:
return iree_uk_mmt4d_type_i8i8i32;
default:
- return iree_uk_mmt4d_type_none;
+ // This unreachable statement is not just an optimization, it also works
+ // around an LLVM/riscv32 miscompile.
+
+ // When we used to have an iree_uk_mmt4d_type_none value equal to 0 and
+ // were returning it here, that caused this whole switch statement to be
+ // miscompiled by LLVM/riscv32 as if it were UB. That value was passed to
+ // `iree_uk_type_bit_count(x)`, which evaluates to `1<<(x - 3)`, which is
+ // UB if x<3. So it was fair to treat that default: clause as UB, but
+ // LLVM/riscv32 was incorrectly treating the whole switch as UB.
+ IREE_UK_ASSUME_UNREACHABLE;
}
}
diff --git a/runtime/src/iree/builtins/ukernel/pack.c b/runtime/src/iree/builtins/ukernel/pack.c
index 3aaeca6..e7e7a25 100644
--- a/runtime/src/iree/builtins/ukernel/pack.c
+++ b/runtime/src/iree/builtins/ukernel/pack.c
@@ -62,8 +62,10 @@
IREE_UK_FLAG_PACK_TRANSPOSE_OUTER |
IREE_UK_FLAG_PACK_TYPE_MASK;
IREE_UK_ASSERT(!(params->flags & ~allflags));
- iree_uk_pack_type_t pack_type = iree_uk_pack_type(params->flags);
- IREE_UK_ASSERT(pack_type != iree_uk_pack_type_none);
+ iree_uk_uint32_t flags_type = params->flags & IREE_UK_FLAG_PACK_TYPE_MASK;
+ IREE_UK_ASSERT(flags_type == IREE_UK_FLAG_PACK_TYPE_F32F32 ||
+ flags_type == IREE_UK_FLAG_PACK_TYPE_I8I8 ||
+ flags_type == IREE_UK_FLAG_PACK_TYPE_I32I32);
IREE_UK_ASSERT(params->in_stride0 >= 0);
IREE_UK_ASSERT(params->out_stride0 >= 0);
IREE_UK_ASSERT(params->in_size0 >= 0);
@@ -96,6 +98,7 @@
// in the validation function so that the subsequent ukernel code can be
// treated as infallible.
iree_uk_pack_tmpbuf_helper_t padding_helper;
+ iree_uk_pack_type_t pack_type = iree_uk_pack_type(params->flags);
iree_uk_type_t elem_type = iree_uk_pack_in_type(pack_type);
iree_uk_ssize_t elem_size = iree_uk_type_size(elem_type);
iree_uk_pack_tmpbuf_helper_t_init(tile_size0, tile_size1, elem_size,
diff --git a/runtime/src/iree/builtins/ukernel/pack_internal.h b/runtime/src/iree/builtins/ukernel/pack_internal.h
index e583536..ce92874 100644
--- a/runtime/src/iree/builtins/ukernel/pack_internal.h
+++ b/runtime/src/iree/builtins/ukernel/pack_internal.h
@@ -10,7 +10,6 @@
#include "iree/builtins/ukernel/pack.h"
typedef enum iree_uk_pack_type_t {
- iree_uk_pack_type_none = 0,
iree_uk_pack_type_f32f32 = IREE_UK_TIE_2_TYPES_LITERAL(FLOAT_32, FLOAT_32),
iree_uk_pack_type_i8i8 = IREE_UK_TIE_2_TYPES_LITERAL(INT_8, INT_8),
iree_uk_pack_type_i32i32 = IREE_UK_TIE_2_TYPES_LITERAL(INT_32, INT_32),
@@ -25,7 +24,7 @@
case IREE_UK_FLAG_PACK_TYPE_I32I32:
return iree_uk_pack_type_i32i32;
default:
- return iree_uk_pack_type_none;
+ IREE_UK_ASSUME_UNREACHABLE;
}
}
diff --git a/runtime/src/iree/builtins/ukernel/unpack.c b/runtime/src/iree/builtins/ukernel/unpack.c
index cbeeeb0..4a3696e 100644
--- a/runtime/src/iree/builtins/ukernel/unpack.c
+++ b/runtime/src/iree/builtins/ukernel/unpack.c
@@ -42,8 +42,9 @@
IREE_UK_FLAG_UNPACK_TRANSPOSE_OUTER |
IREE_UK_FLAG_UNPACK_TYPE_MASK;
IREE_UK_ASSERT(!(params->flags & ~allflags));
- iree_uk_unpack_type_t unpack_type = iree_uk_unpack_type(params->flags);
- IREE_UK_ASSERT(unpack_type != iree_uk_unpack_type_none);
+ iree_uk_uint32_t flags_type = params->flags & IREE_UK_FLAG_UNPACK_TYPE_MASK;
+ IREE_UK_ASSERT(flags_type == IREE_UK_FLAG_UNPACK_TYPE_F32F32 ||
+ flags_type == IREE_UK_FLAG_UNPACK_TYPE_I32I32);
IREE_UK_ASSERT(params->in_stride0 >= 0);
IREE_UK_ASSERT(params->out_stride0 >= 0);
IREE_UK_ASSERT(params->out_size0 >= 0);
@@ -76,6 +77,7 @@
// in the validation function so that the subsequent ukernel code can be
// treated as infallible.
iree_uk_unpack_tmpbuf_helper_t helper;
+ iree_uk_unpack_type_t unpack_type = iree_uk_unpack_type(params->flags);
iree_uk_type_t elem_type = iree_uk_unpack_in_type(unpack_type);
iree_uk_ssize_t elem_size = iree_uk_type_size(elem_type);
iree_uk_unpack_tmpbuf_helper_init(tile_size0, tile_size1, elem_size, &helper);
diff --git a/runtime/src/iree/builtins/ukernel/unpack_internal.h b/runtime/src/iree/builtins/ukernel/unpack_internal.h
index 3ffac8a..3d48d82 100644
--- a/runtime/src/iree/builtins/ukernel/unpack_internal.h
+++ b/runtime/src/iree/builtins/ukernel/unpack_internal.h
@@ -10,7 +10,6 @@
#include "iree/builtins/ukernel/unpack.h"
typedef enum iree_uk_unpack_type_t {
- iree_uk_unpack_type_none = 0,
iree_uk_unpack_type_f32f32 = IREE_UK_TIE_2_TYPES_LITERAL(FLOAT_32, FLOAT_32),
iree_uk_unpack_type_i32i32 = IREE_UK_TIE_2_TYPES_LITERAL(INT_32, INT_32),
} iree_uk_unpack_type_t;
@@ -23,7 +22,7 @@
case IREE_UK_FLAG_UNPACK_TYPE_I32I32:
return iree_uk_unpack_type_i32i32;
default:
- return iree_uk_unpack_type_none;
+ IREE_UK_ASSUME_UNREACHABLE;
}
}