forked from ray-project/ray
-
Notifications
You must be signed in to change notification settings - Fork 0
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
[runtime env] plugin refactor[3/n]: support strong type by @DataClass (…
- Loading branch information
1 parent
b3878e2
commit 781c2a7
Showing
18 changed files
with
591 additions
and
4 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,21 @@ | ||
MIT License | ||
|
||
Copyright (c) 2018 Konrad Hałas | ||
|
||
Permission is hereby granted, free of charge, to any person obtaining a copy | ||
of this software and associated documentation files (the "Software"), to deal | ||
in the Software without restriction, including without limitation the rights | ||
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell | ||
copies of the Software, and to permit persons to whom the Software is | ||
furnished to do so, subject to the following conditions: | ||
|
||
The above copyright notice and this permission notice shall be included in all | ||
copies or substantial portions of the Software. | ||
|
||
THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR | ||
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, | ||
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE | ||
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER | ||
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, | ||
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE | ||
SOFTWARE. |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,3 @@ | ||
from .config import Config | ||
from .core import from_dict | ||
from .exceptions import * |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,12 @@ | ||
from dataclasses import dataclass, field | ||
from typing import Dict, Any, Callable, Optional, Type, List | ||
|
||
|
||
@dataclass | ||
class Config: | ||
type_hooks: Dict[Type, Callable[[Any], Any]] = field(default_factory=dict) | ||
cast: List[Type] = field(default_factory=list) | ||
forward_references: Optional[Dict[str, Any]] = None | ||
check_types: bool = True | ||
strict: bool = False | ||
strict_unions_match: bool = False |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,140 @@ | ||
import copy | ||
from dataclasses import is_dataclass | ||
from itertools import zip_longest | ||
from typing import TypeVar, Type, Optional, get_type_hints, Mapping, Any | ||
|
||
from .config import Config | ||
from .data import Data | ||
from .dataclasses import get_default_value_for_field, create_instance, DefaultValueNotFoundError, get_fields | ||
from .exceptions import ( | ||
ForwardReferenceError, | ||
WrongTypeError, | ||
DaciteError, | ||
UnionMatchError, | ||
MissingValueError, | ||
DaciteFieldError, | ||
UnexpectedDataError, | ||
StrictUnionMatchError, | ||
) | ||
from .types import ( | ||
is_instance, | ||
is_generic_collection, | ||
is_union, | ||
extract_generic, | ||
is_optional, | ||
transform_value, | ||
extract_origin_collection, | ||
is_init_var, | ||
extract_init_var, | ||
) | ||
|
||
T = TypeVar("T") | ||
|
||
|
||
def from_dict(data_class: Type[T], data: Data, config: Optional[Config] = None) -> T: | ||
"""Create a data class instance from a dictionary. | ||
:param data_class: a data class type | ||
:param data: a dictionary of a input data | ||
:param config: a configuration of the creation process | ||
:return: an instance of a data class | ||
""" | ||
init_values: Data = {} | ||
post_init_values: Data = {} | ||
config = config or Config() | ||
try: | ||
data_class_hints = get_type_hints(data_class, globalns=config.forward_references) | ||
except NameError as error: | ||
raise ForwardReferenceError(str(error)) | ||
data_class_fields = get_fields(data_class) | ||
if config.strict: | ||
extra_fields = set(data.keys()) - {f.name for f in data_class_fields} | ||
if extra_fields: | ||
raise UnexpectedDataError(keys=extra_fields) | ||
for field in data_class_fields: | ||
field = copy.copy(field) | ||
field.type = data_class_hints[field.name] | ||
try: | ||
try: | ||
field_data = data[field.name] | ||
transformed_value = transform_value( | ||
type_hooks=config.type_hooks, cast=config.cast, target_type=field.type, value=field_data | ||
) | ||
value = _build_value(type_=field.type, data=transformed_value, config=config) | ||
except DaciteFieldError as error: | ||
error.update_path(field.name) | ||
raise | ||
if config.check_types and not is_instance(value, field.type): | ||
raise WrongTypeError(field_path=field.name, field_type=field.type, value=value) | ||
except KeyError: | ||
try: | ||
value = get_default_value_for_field(field) | ||
except DefaultValueNotFoundError: | ||
if not field.init: | ||
continue | ||
raise MissingValueError(field.name) | ||
if field.init: | ||
init_values[field.name] = value | ||
else: | ||
post_init_values[field.name] = value | ||
|
||
return create_instance(data_class=data_class, init_values=init_values, post_init_values=post_init_values) | ||
|
||
|
||
def _build_value(type_: Type, data: Any, config: Config) -> Any: | ||
if is_init_var(type_): | ||
type_ = extract_init_var(type_) | ||
if is_union(type_): | ||
return _build_value_for_union(union=type_, data=data, config=config) | ||
elif is_generic_collection(type_) and is_instance(data, extract_origin_collection(type_)): | ||
return _build_value_for_collection(collection=type_, data=data, config=config) | ||
elif is_dataclass(type_) and is_instance(data, Data): | ||
return from_dict(data_class=type_, data=data, config=config) | ||
return data | ||
|
||
|
||
def _build_value_for_union(union: Type, data: Any, config: Config) -> Any: | ||
types = extract_generic(union) | ||
if is_optional(union) and len(types) == 2: | ||
return _build_value(type_=types[0], data=data, config=config) | ||
union_matches = {} | ||
for inner_type in types: | ||
try: | ||
# noinspection PyBroadException | ||
try: | ||
data = transform_value( | ||
type_hooks=config.type_hooks, cast=config.cast, target_type=inner_type, value=data | ||
) | ||
except Exception: # pylint: disable=broad-except | ||
continue | ||
value = _build_value(type_=inner_type, data=data, config=config) | ||
if is_instance(value, inner_type): | ||
if config.strict_unions_match: | ||
union_matches[inner_type] = value | ||
else: | ||
return value | ||
except DaciteError: | ||
pass | ||
if config.strict_unions_match: | ||
if len(union_matches) > 1: | ||
raise StrictUnionMatchError(union_matches) | ||
return union_matches.popitem()[1] | ||
if not config.check_types: | ||
return data | ||
raise UnionMatchError(field_type=union, value=data) | ||
|
||
|
||
def _build_value_for_collection(collection: Type, data: Any, config: Config) -> Any: | ||
data_type = data.__class__ | ||
if is_instance(data, Mapping): | ||
item_type = extract_generic(collection, defaults=(Any, Any))[1] | ||
return data_type((key, _build_value(type_=item_type, data=value, config=config)) for key, value in data.items()) | ||
elif is_instance(data, tuple): | ||
types = extract_generic(collection) | ||
if len(types) == 2 and types[1] == Ellipsis: | ||
return data_type(_build_value(type_=types[0], data=item, config=config) for item in data) | ||
return data_type( | ||
_build_value(type_=type_, data=item, config=config) for item, type_ in zip_longest(data, types) | ||
) | ||
item_type = extract_generic(collection, defaults=(Any,))[0] | ||
return data_type(_build_value(type_=item_type, data=item, config=config) for item in data) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,3 @@ | ||
from typing import Dict, Any | ||
|
||
Data = Dict[str, Any] |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,33 @@ | ||
from dataclasses import Field, MISSING, _FIELDS, _FIELD, _FIELD_INITVAR # type: ignore | ||
from typing import Type, Any, TypeVar, List | ||
|
||
from .data import Data | ||
from .types import is_optional | ||
|
||
T = TypeVar("T", bound=Any) | ||
|
||
|
||
class DefaultValueNotFoundError(Exception): | ||
pass | ||
|
||
|
||
def get_default_value_for_field(field: Field) -> Any: | ||
if field.default != MISSING: | ||
return field.default | ||
elif field.default_factory != MISSING: # type: ignore | ||
return field.default_factory() # type: ignore | ||
elif is_optional(field.type): | ||
return None | ||
raise DefaultValueNotFoundError() | ||
|
||
|
||
def create_instance(data_class: Type[T], init_values: Data, post_init_values: Data) -> T: | ||
instance = data_class(**init_values) | ||
for key, value in post_init_values.items(): | ||
setattr(instance, key, value) | ||
return instance | ||
|
||
|
||
def get_fields(data_class: Type[T]) -> List[Field]: | ||
fields = getattr(data_class, _FIELDS) | ||
return [f for f in fields.values() if f._field_type is _FIELD or f._field_type is _FIELD_INITVAR] |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,79 @@ | ||
from typing import Any, Type, Optional, Set, Dict | ||
|
||
|
||
def _name(type_: Type) -> str: | ||
return type_.__name__ if hasattr(type_, "__name__") else str(type_) | ||
|
||
|
||
class DaciteError(Exception): | ||
pass | ||
|
||
|
||
class DaciteFieldError(DaciteError): | ||
def __init__(self, field_path: Optional[str] = None): | ||
super().__init__() | ||
self.field_path = field_path | ||
|
||
def update_path(self, parent_field_path: str) -> None: | ||
if self.field_path: | ||
self.field_path = f"{parent_field_path}.{self.field_path}" | ||
else: | ||
self.field_path = parent_field_path | ||
|
||
|
||
class WrongTypeError(DaciteFieldError): | ||
def __init__(self, field_type: Type, value: Any, field_path: Optional[str] = None) -> None: | ||
super().__init__(field_path=field_path) | ||
self.field_type = field_type | ||
self.value = value | ||
|
||
def __str__(self) -> str: | ||
return ( | ||
f'wrong value type for field "{self.field_path}" - should be "{_name(self.field_type)}" ' | ||
f'instead of value "{self.value}" of type "{_name(type(self.value))}"' | ||
) | ||
|
||
|
||
class MissingValueError(DaciteFieldError): | ||
def __init__(self, field_path: Optional[str] = None): | ||
super().__init__(field_path=field_path) | ||
|
||
def __str__(self) -> str: | ||
return f'missing value for field "{self.field_path}"' | ||
|
||
|
||
class UnionMatchError(WrongTypeError): | ||
def __str__(self) -> str: | ||
return ( | ||
f'can not match type "{_name(type(self.value))}" to any type ' | ||
f'of "{self.field_path}" union: {_name(self.field_type)}' | ||
) | ||
|
||
|
||
class StrictUnionMatchError(DaciteFieldError): | ||
def __init__(self, union_matches: Dict[Type, Any], field_path: Optional[str] = None) -> None: | ||
super().__init__(field_path=field_path) | ||
self.union_matches = union_matches | ||
|
||
def __str__(self) -> str: | ||
conflicting_types = ", ".join(_name(type_) for type_ in self.union_matches) | ||
return f'can not choose between possible Union matches for field "{self.field_path}": {conflicting_types}' | ||
|
||
|
||
class ForwardReferenceError(DaciteError): | ||
def __init__(self, message: str) -> None: | ||
super().__init__() | ||
self.message = message | ||
|
||
def __str__(self) -> str: | ||
return f"can not resolve forward reference: {self.message}" | ||
|
||
|
||
class UnexpectedDataError(DaciteError): | ||
def __init__(self, keys: Set[str]) -> None: | ||
super().__init__() | ||
self.keys = keys | ||
|
||
def __str__(self) -> str: | ||
formatted_keys = ", ".join(f'"{key}"' for key in self.keys) | ||
return f"can not match {formatted_keys} to any data class field" |
Empty file.
Oops, something went wrong.