Fix Python dtype conversion for int64 on Windows. (#12880)
Fixes https://github.com/openxla/iree/issues/11080. The int64 and uint64
test cases here were failing on Windows as the element type mapping was
routing via the code `l`, which is a "C long int" - not an explicitly 64
bit type. This changes the mapping to always use the explicit "type
strings" (any string in `numpy.sctypeDict.keys()`, [shown in this
gist](https://gist.github.com/ScottTodd/ec1f7906e9c644eb47f74280d6c26229)).
Relates to https://github.com/openxla/iree/pull/12872
diff --git a/build_tools/cmake/ctest_all.sh b/build_tools/cmake/ctest_all.sh
index b255e77..cba6b09 100755
--- a/build_tools/cmake/ctest_all.sh
+++ b/build_tools/cmake/ctest_all.sh
@@ -90,8 +90,6 @@
"iree/tests/e2e/tensor_ops/check_vmvx_ukernel_local-task_unpack.mlir"
# TODO(#11070): Fix argument/result signature mismatch
"iree/tests/e2e/tosa_ops/check_vmvx_local-sync_microkernels_fully_connected.mlir"
- # TODO(#11080): Fix arrays not matching in test_variant_list_buffers
- "iree/runtime/bindings/python/vm_types_test"
)
elif [[ "$OSTYPE" =~ ^darwin ]]; then
excluded_tests+=(
diff --git a/runtime/bindings/python/hal.cc b/runtime/bindings/python/hal.cc
index 110f121..f970a70 100644
--- a/runtime/bindings/python/hal.cc
+++ b/runtime/bindings/python/hal.cc
@@ -411,54 +411,59 @@
namespace {
py::object MapElementTypeToDType(iree_hal_element_type_t element_type) {
- // See: https://docs.python.org/3/c-api/arg.html#numbers
- // TODO: Handle dtypes that do not map to a code (i.e. fp16).
- const char* dtype_code;
+ // See:
+ // * https://numpy.org/doc/stable/reference/arrays.dtypes.html
+ // * https://docs.python.org/3/c-api/arg.html#numbers
+ //
+ // Single letter codes can be ambiguous across platforms, so prefer explicit
+ // bit depth values, ("Type strings: Any string in numpy.sctypeDict.keys()").
+ // See https://github.com/pybind/pybind11/issues/1908
+ const char* dtype_string;
switch (element_type) {
case IREE_HAL_ELEMENT_TYPE_BOOL_8:
- dtype_code = "?";
+ dtype_string = "?";
break;
case IREE_HAL_ELEMENT_TYPE_INT_8:
case IREE_HAL_ELEMENT_TYPE_SINT_8:
- dtype_code = "b";
+ dtype_string = "int8";
break;
case IREE_HAL_ELEMENT_TYPE_UINT_8:
- dtype_code = "B";
+ dtype_string = "uint8";
break;
case IREE_HAL_ELEMENT_TYPE_INT_16:
case IREE_HAL_ELEMENT_TYPE_SINT_16:
- dtype_code = "h";
+ dtype_string = "int16";
break;
case IREE_HAL_ELEMENT_TYPE_UINT_16:
- dtype_code = "H";
+ dtype_string = "uint16";
break;
case IREE_HAL_ELEMENT_TYPE_INT_32:
case IREE_HAL_ELEMENT_TYPE_SINT_32:
- dtype_code = "i";
+ dtype_string = "int32";
break;
case IREE_HAL_ELEMENT_TYPE_UINT_32:
- dtype_code = "I";
+ dtype_string = "uint32";
break;
case IREE_HAL_ELEMENT_TYPE_INT_64:
case IREE_HAL_ELEMENT_TYPE_SINT_64:
- dtype_code = "l";
+ dtype_string = "int64";
break;
case IREE_HAL_ELEMENT_TYPE_UINT_64:
- dtype_code = "L";
+ dtype_string = "uint64";
break;
case IREE_HAL_ELEMENT_TYPE_FLOAT_16:
- dtype_code = "e";
+ dtype_string = "float16";
break;
case IREE_HAL_ELEMENT_TYPE_FLOAT_32:
- dtype_code = "f";
+ dtype_string = "float32";
break;
case IREE_HAL_ELEMENT_TYPE_FLOAT_64:
- dtype_code = "d";
+ dtype_string = "float64";
break;
default:
throw RaiseValueError("Unsupported VM Buffer -> numpy dtype mapping");
}
- return py::dtype(dtype_code);
+ return py::dtype(dtype_string);
}
} // namespace
diff --git a/runtime/bindings/python/tests/vm_types_test.py b/runtime/bindings/python/tests/vm_types_test.py
index 782142e..6710026 100644
--- a/runtime/bindings/python/tests/vm_types_test.py
+++ b/runtime/bindings/python/tests/vm_types_test.py
@@ -49,12 +49,18 @@
def test_variant_list_buffers(self):
device = rt.get_device("local-sync")
ET = rt.HalElementType
- for dt, et in ((np.int8, ET.SINT_8), (np.int16, ET.SINT_16),
- (np.int32, ET.SINT_32), (np.int64, ET.SINT_64),
- (np.uint8, ET.UINT_8), (np.uint16, ET.UINT_16),
- (np.uint32, ET.UINT_32), (np.uint64, ET.UINT_64),
- (np.float32, ET.FLOAT_32), (np.float64, ET.FLOAT_64)):
- # TODO: Unimplemented: (np.float16, ET.FLOAT_16)
+ for dt, et in (
+ (np.int8, ET.SINT_8), #
+ (np.int16, ET.SINT_16), #
+ (np.int32, ET.SINT_32), #
+ (np.int64, ET.SINT_64), #
+ (np.uint8, ET.UINT_8), #
+ (np.uint16, ET.UINT_16), #
+ (np.uint32, ET.UINT_32), #
+ (np.uint64, ET.UINT_64), #
+ (np.float16, ET.FLOAT_16), #
+ (np.float32, ET.FLOAT_32), #
+ (np.float64, ET.FLOAT_64)):
lst = rt.VmVariantList(5)
ary1 = np.asarray([1, 2, 3, 4], dtype=dt)
bv1 = device.allocator.allocate_buffer_copy(
@@ -65,7 +71,6 @@
lst.push_ref(bv1)
ary2 = rt.DeviceArray(device,
lst.get_as_object(0, rt.HalBufferView),
- override_dtype=dt,
implicit_host_transfer=True)
np.testing.assert_array_equal(ary1, ary2)
with self.assertRaises(IndexError):