# blob: bc430c72e9eba418e2e4cfc2b2e3f343e0b0bbec [file] [log] [blame]
## Copyright 2022 The IREE Authors
#
# Licensed under the Apache License v2.0 with LLVM Exceptions.
# See https://llvm.org/LICENSE.txt for license information.
# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
"""Helpers to serialize/deserialize objects."""
from enum import Enum
from typing import Any, Dict, Optional, Sequence, Tuple, Type, TypeVar, Union
import dataclasses
import typing
# types.NoneType is only available in Python 3.10+, so derive it manually.
NONE_TYPE = type(None)
# Attribute names of the hooks that `serializable` installs on a class;
# `_serialize` and `_deserialize` look them up with getattr().
SERIALIZE_FUNC_NAME = "__serialize__"
DESERIALIZE_FUNC_NAME = "__deserialize__"
# Dict key types accepted by `_serialize`. The check uses the exact type()
# of the key, so subclasses of these types are rejected.
SUPPORTED_DICT_KEY_TYPES = {str, int, float, bool}
# Value types that `_serialize` returns unchanged (exact type() match).
SUPPORTED_PRIMITIVE_TYPES = {str, int, float, bool, NONE_TYPE}
def serialize_and_pack(obj,
                       root_obj_field_name="root_obj",
                       keyed_obj_map_field_name="keyed_obj_map"):
    """Serializes an object and bundles it with its keyed object map.

    Args:
      obj: object to be serialized.
      root_obj_field_name: key under which the serialized top-level object is
        stored in the returned dict.
      keyed_obj_map_field_name: key under which the map of keyed objects is
        stored in the returned dict.

    Returns:
      A serializable dict with exactly the two entries named above.

    Raises:
      ValueError: if both field names are identical.
    """
    if root_obj_field_name == keyed_obj_map_field_name:
        raise ValueError(
            f"root_obj and keyed_obj_map can't have the same field name.")

    # Filled in by _serialize as keyed objects are encountered.
    shared_objs = {}
    packed = {
        root_obj_field_name: _serialize(obj=obj, keyed_obj_map=shared_objs)
    }
    packed[keyed_obj_map_field_name] = shared_objs
    return packed
T = TypeVar("T")


def unpack_and_deserialize(data,
                           root_type: Type[T],
                           root_obj_field_name="root_obj",
                           keyed_obj_map_field_name="keyed_obj_map") -> T:
    """Rebuilds a typed object from a packed, serialized dict.

    Args:
      data: serialized data dict holding the root object and the keyed object
        map.
      root_type: type of the top-level object to reconstruct.
      root_obj_field_name: key of the top-level object in `data`.
      keyed_obj_map_field_name: key of the keyed object map in `data`.

    Returns:
      The deserialized object, cast to `root_type`.
    """
    packed_root = data[root_obj_field_name]
    shared_objs = data[keyed_obj_map_field_name]
    restored = _deserialize(data=packed_root,
                            obj_type=root_type,
                            keyed_obj_map=shared_objs)
    return typing.cast(root_type, restored)
def _serialize(obj, keyed_obj_map: Dict[str, Any]):
    """Recursively converts an object into a serializable form.

    Objects exposing a `__serialize__` hook are delegated to that hook. Lists
    and dicts are converted element by element, enums become their member
    name, and supported primitives are returned as-is.

    Args:
      obj: object to be serialized.
      keyed_obj_map: mutable container that collects the keyed serializable
        objects.

    Returns:
      A serializable object.

    Raises:
      ValueError: if a dict key or the object itself has an unsupported type.
    """
    custom_serializer = getattr(obj, SERIALIZE_FUNC_NAME, None)
    if custom_serializer is not None:
        return custom_serializer(keyed_obj_map)

    if isinstance(obj, list):
        return [_serialize(item, keyed_obj_map) for item in obj]

    if isinstance(obj, Enum):
        return obj.name

    if isinstance(obj, dict):
        converted = {}
        # Keys are validated lazily, right before their value is serialized.
        for key, value in obj.items():
            if type(key) not in SUPPORTED_DICT_KEY_TYPES:
                raise ValueError(f"Unsupported key {key} in the dict {obj}.")
            converted[key] = _serialize(value, keyed_obj_map)
        return converted

    if type(obj) not in SUPPORTED_PRIMITIVE_TYPES:
        raise ValueError(f"Unsupported object: {obj}.")
    return obj
def _deserialize(data,
                 obj_type: Type,
                 keyed_obj_map: Dict[str, Any],
                 obj_cache: Optional[Dict[str, Any]] = None):
    """Deserializes the data back to the typed object.

    Args:
      data: serialized data.
      obj_type: type of the data.
      keyed_obj_map: container of the keyed serializable object.
      obj_cache: cache of already deserialized keyed objects. When omitted, a
        fresh cache is created for this call, so independent top-level
        deserializations never share (possibly stale) objects.

    Returns:
      A deserialized object.

    Raises:
      ValueError: if `obj_type` is a union other than `Optional[X]`, or if
        `data` does not name a member of an enum `obj_type`.
    """
    # A `{}` default would be created once and shared by every call that omits
    # the argument, leaking cached objects between unrelated payloads.
    if obj_cache is None:
        obj_cache = {}

    deserialize_func = getattr(obj_type, DESERIALIZE_FUNC_NAME, None)
    if deserialize_func is not None:
        return deserialize_func(data, keyed_obj_map, obj_cache)

    origin = typing.get_origin(obj_type)
    if origin == list:
        subtype, = typing.get_args(obj_type)
        return [
            _deserialize(item, subtype, keyed_obj_map, obj_cache)
            for item in data
        ]

    if origin == dict:
        _, value_type = typing.get_args(obj_type)
        return {
            key: _deserialize(value, value_type, keyed_obj_map, obj_cache)
            for key, value in data.items()
        }

    if origin == Union:
        subtypes = typing.get_args(obj_type)
        if len(subtypes) != 2 or NONE_TYPE not in subtypes:
            raise ValueError(f"Unsupported union type: {obj_type}.")
        # None is serialized as None, so an absent optional value must come
        # back as None instead of being fed to the wrapped type's deserializer
        # (which would fail for lists, dicts, enums and serializable classes).
        if data is None:
            return None
        subtype = subtypes[0] if subtypes[1] == NONE_TYPE else subtypes[1]
        return _deserialize(data, subtype, keyed_obj_map, obj_cache)

    if issubclass(obj_type, Enum):
        # Enums are serialized by member name.
        for member in obj_type:
            if data == member.name:
                return member
        raise ValueError(f"Member {data} not found in the enum {obj_type}.")

    return data
def serializable(cls=None,
                 type_key: Optional[str] = None,
                 id_field: str = "id"):
    """Decorator that adds serialization hooks to a dataclass.

    Args:
      cls: class to decorate; None when used as `@serializable(...)`.
      type_key: string that defines the object type and indicates that the
        class is a keyed object, which is unique per id and will only have one
        copy in the serialization per id.
      id_field: field name of the id field of a keyed object.

    Returns:
      The decorated class, or a decorator when `cls` is None.

    Raises:
      ValueError: if `type_key` contains ':', the class is not a dataclass, or
        a keyed class lacks the id field.

    Example:
      @serializable
      @dataclass
      class A(object):
        ...

      @serializable(type_key="obj_b")
      @dataclass
      class B(object):
        id: str
    """
    if type_key is not None and ":" in type_key:
        raise ValueError("':' is the reserved character in type_key.")

    # Keyed classes are stored once in the keyed object map and referenced by
    # id; non-keyed classes are serialized inline as a dict.
    is_keyed = type_key is not None

    def wrap(cls):
        if not dataclasses.is_dataclass(cls):
            raise ValueError(f"{cls} is not a dataclass.")

        fields = dataclasses.fields(cls)
        has_id_field = any(field.name == id_field for field in fields)
        if is_keyed and not has_id_field:
            raise ValueError(
                f'Id field "{id_field}" not found in the class {cls}.')

        def serialize(self, keyed_obj_map: Dict[str, Any]):
            if not is_keyed:
                return _fields_to_dict(self, fields, keyed_obj_map)

            obj_id = getattr(self, id_field)
            obj_key = f"{type_key}:{obj_id}"
            if obj_key not in keyed_obj_map:
                # Park a None placeholder while the fields are serialized so a
                # re-entrant visit of the same key can be recognized.
                keyed_obj_map[obj_key] = None
                keyed_obj_map[obj_key] = _fields_to_dict(
                    self, fields, keyed_obj_map)
            elif keyed_obj_map[obj_key] is None:
                # Placeholder still present: this object is being serialized
                # further up the call stack, i.e. there is a circular
                # reference.
                raise ValueError(
                    f"Circular reference is not supported: {obj_key}.")
            return obj_id

        def deserialize(data, keyed_obj_map: Dict[str, Any],
                        obj_cache: Dict[str, Any]):
            if not is_keyed:
                inline_kwargs = _dict_to_fields(data, fields, keyed_obj_map,
                                                obj_cache)
                return cls(**inline_kwargs)

            # For keyed objects the serialized data is just the object id.
            obj_key = f"{type_key}:{data}"
            if obj_key not in obj_cache:
                keyed_kwargs = _dict_to_fields(keyed_obj_map[obj_key], fields,
                                               keyed_obj_map, obj_cache)
                obj_cache[obj_key] = cls(**keyed_kwargs)
            return obj_cache[obj_key]

        setattr(cls, SERIALIZE_FUNC_NAME, serialize)
        setattr(cls, DESERIALIZE_FUNC_NAME, deserialize)
        return cls

    # Supports both `@serializable` (cls is given) and `@serializable(...)`
    # (cls is None, so the actual decorator has to be returned).
    return wrap if cls is None else wrap(cls)
def _fields_to_dict(obj, fields: Sequence[dataclasses.Field],
                    keyed_obj_map: Dict[str, Any]) -> Dict[str, Any]:
    """Serializes the given dataclass fields of `obj` into a dict.

    Args:
      obj: instance whose attributes are read.
      fields: dataclass fields to serialize.
      keyed_obj_map: mutable container passed through to `_serialize`.

    Returns:
      A dict mapping each field name to its serialized value.
    """
    return {
        field.name: _serialize(getattr(obj, field.name), keyed_obj_map)
        for field in fields
    }
def _dict_to_fields(obj_dict, fields: Sequence[dataclasses.Field],
                    keyed_obj_map: Dict[str, Any],
                    obj_cache: Dict[str, Any]) -> Dict[str, Any]:
    """Deserializes the entries of `obj_dict` for the given dataclass fields.

    Args:
      obj_dict: serialized dict indexed by field name.
      fields: dataclass fields to deserialize; each field's declared type
        drives the deserialization of its value.
      keyed_obj_map: container of the keyed serializable objects.
      obj_cache: cache passed through to `_deserialize`.

    Returns:
      A dict mapping each field name to its deserialized value.
    """
    field_values = {}
    for field in fields:
        field_values[field.name] = _deserialize(obj_dict[field.name],
                                                field.type, keyed_obj_map,
                                                obj_cache)
    return field_values