Skip to content

Commit

Permalink
Add support for all new typing features of Python 3.13
Browse files Browse the repository at this point in the history
  • Loading branch information
zhPavel committed Dec 14, 2024
1 parent 138719c commit 981652f
Show file tree
Hide file tree
Showing 9 changed files with 336 additions and 88 deletions.
2 changes: 1 addition & 1 deletion benchmarks/benchmarks/pybench/matplotlib_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -110,7 +110,7 @@ def sign(x):

if lbl is None:
if isinstance(fmt, str):
lbl = cbook._auto_format_str(fmt, value) # type: ignore[attr-defined]
lbl = cbook._auto_format_str(fmt, value)
elif callable(fmt):
lbl = fmt(value)
else:
Expand Down
1 change: 1 addition & 0 deletions src/adaptix/_internal/feature_requirement.py
Original file line number Diff line number Diff line change
Expand Up @@ -166,6 +166,7 @@ def fail_reason(self) -> str:
HAS_TV_SYNTAX = HAS_PY_312

HAS_PY_313 = PythonVersionRequirement((3, 13))
HAS_TV_DEFAULT = HAS_PY_313

HAS_SUPPORTED_ATTRS_PKG = DistributionVersionRequirement("attrs", "21.3.0")
HAS_ATTRS_PKG = DistributionRequirement("attrs")
Expand Down
1 change: 1 addition & 0 deletions src/adaptix/_internal/type_tools/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@
from .normalize_type import (
AnyNormTypeVarLike,
BaseNormType,
NormParamSpec,
NormParamSpecMarker,
NormTV,
NormTVTuple,
Expand Down
23 changes: 18 additions & 5 deletions src/adaptix/_internal/type_tools/implicit_params.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,10 @@
import sys
import typing
from itertools import chain
from typing import Any, ForwardRef, TypeVar

from ..common import TypeHint, VarTuple
from ..feature_requirement import HAS_PARAM_SPEC, HAS_TV_TUPLE
from ..feature_requirement import HAS_PARAM_SPEC, HAS_TV_DEFAULT, HAS_TV_TUPLE
from .basic_utils import create_union, eval_forward_ref, is_user_defined_generic, strip_alias
from .constants import BUILTIN_ORIGIN_TO_TYPEVARS

Expand All @@ -14,7 +15,7 @@ def _process_limit_element(self, type_var: TypeVar, tp: TypeHint) -> TypeHint:
return eval_forward_ref(vars(sys.modules[type_var.__module__]), tp)
return tp

def _process_type_var(self, type_var) -> TypeHint:
def _derive_default(self, type_var) -> TypeHint:
if HAS_PARAM_SPEC and isinstance(type_var, typing.ParamSpec):
return ...
if HAS_TV_TUPLE and isinstance(type_var, typing.TypeVarTuple):
Expand All @@ -30,20 +31,32 @@ def _process_type_var(self, type_var) -> TypeHint:
return Any
return self._process_limit_element(type_var, type_var.__bound__)

def _get_default_tuple(self, type_var) -> VarTuple[TypeHint]:
    """Return the implicit substitution for *type_var* as a tuple of hints.

    A PEP 696 default on a plain TypeVar is wrapped into a 1-tuple; the
    defaults of TypeVarTuple/ParamSpec are already tuple-shaped and pass
    through unchanged.  Without an explicit default the value is derived
    from the kind of the type variable.
    """
    if not (HAS_TV_DEFAULT and type_var.has_default()):
        return (self._derive_default(type_var), )
    if isinstance(type_var, TypeVar):
        return (type_var.__default__, )  # type: ignore[attr-defined]
    return type_var.__default__

def get_implicit_params(self, origin) -> VarTuple[TypeHint]:
    """Collect the implicit type parameters for *origin*.

    User-defined generics expose their variables via ``__parameters__``;
    builtin origins are looked up in a static table.  Each variable's
    default tuple is flattened into one parameter sequence.
    """
    type_vars = (
        origin.__parameters__
        if is_user_defined_generic(origin) else
        BUILTIN_ORIGIN_TO_TYPEVARS.get(origin, ())
    )
    default_tuples = [self._get_default_tuple(tv) for tv in type_vars]
    return tuple(chain.from_iterable(default_tuples))


# Shared module-level instance: the getter is stateless, so one suffices.
_getter = ImplicitParamsGetter()


def fill_implicit_params(tp: TypeHint) -> TypeHint:
    """Subscript *tp* with its implicit (defaulted or derived) parameters.

    Raises:
        ValueError: if no implicit parameters can be derived for *tp*.
    """
    params = _getter.get_implicit_params(strip_alias(tp))
    if not params:
        raise ValueError(f"Can not derive implicit parameters for {tp}")
    return tp[params]
4 changes: 3 additions & 1 deletion src/adaptix/_internal/type_tools/norm_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,13 +2,15 @@
from dataclasses import InitVar
from typing import Annotated, ClassVar, Final, TypeVar

from ..feature_requirement import HAS_TYPED_DICT_REQUIRED
from ..feature_requirement import HAS_PY_313, HAS_TYPED_DICT_REQUIRED
from .normalize_type import BaseNormType

# Wrapper forms stripped from annotations by strip_tags; they decorate a
# type rather than carrying type information of their own.
_TYPE_TAGS = [Final, ClassVar, InitVar, Annotated]

if HAS_TYPED_DICT_REQUIRED:
    _TYPE_TAGS += [typing.Required, typing.NotRequired]
if HAS_PY_313:
    _TYPE_TAGS += [typing.ReadOnly, typing.TypeIs]  # type: ignore[attr-defined]


def strip_tags(norm: BaseNormType) -> BaseNormType:
Expand Down
123 changes: 68 additions & 55 deletions src/adaptix/_internal/type_tools/normalize_type.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,7 +31,9 @@
HAS_PARAM_SPEC,
HAS_PY_310,
HAS_PY_311,
HAS_PY_313,
HAS_SELF_TYPE,
HAS_TV_DEFAULT,
HAS_TV_SYNTAX,
HAS_TV_TUPLE,
HAS_TYPE_ALIAS,
Expand Down Expand Up @@ -215,21 +217,12 @@ class Constraints:
TypeVarLimit = Union[Bound, Constraints]


class NormTV(BaseNormType):
__slots__ = ("_var", "_limit", "_variance", "_source")
class _BaseNormTypeVarLike(BaseNormType):
__slots__ = ("_var", "_source")

def __init__(self, var: Any, limit: TypeVarLimit, *, source: TypeHint):
def __init__(self, var: Any, *, source: TypeHint):
self._var = var
self._source = source
self._limit = limit

if var.__covariant__:
self._variance = Variance.COVARIANT
if var.__contravariant__:
self._variance = Variance.CONTRAVARIANT
if getattr(var, "__infer_variance__", False):
self._variance = Variance.INFERRED
self._variance = Variance.INVARIANT

@property
def origin(self) -> Any:
Expand All @@ -247,63 +240,76 @@ def source(self) -> TypeHint:
def name(self) -> str:
return self._var.__name__

@property
def variance(self) -> Variance:
return self._variance

@property
def limit(self) -> TypeVarLimit:
return self._limit

def __repr__(self):
return f"<{type(self).__name__}({self._var})>"

def __hash__(self):
return hash(self._var)

def __eq__(self, other):
if isinstance(other, NormTV):
if isinstance(other, type(self)):
return self._var == other._var
if isinstance(other, BaseNormType):
return False
return NotImplemented


class NormTVTuple(BaseNormType):
__slots__ = ("_var", "_source")
class NormTV(_BaseNormTypeVarLike):
__slots__ = (*_BaseNormTypeVarLike.__slots__, "_limit", "_variance", "_default")

def __init__(self, var: Any, *, source: TypeHint):
self._var = var
self._source = source
def __init__(self, var: Any, limit: TypeVarLimit, *, source: TypeHint, default: Optional[BaseNormType]):
super().__init__(var, source=source)
self._limit = limit

if var.__covariant__:
self._variance = Variance.COVARIANT
if var.__contravariant__:
self._variance = Variance.CONTRAVARIANT
if getattr(var, "__infer_variance__", False):
self._variance = Variance.INFERRED
self._variance = Variance.INVARIANT
self._default = default

@property
def origin(self) -> Any:
return self._var
def variance(self) -> Variance:
return self._variance

@property
def args(self) -> tuple[()]:
return ()
def limit(self) -> TypeVarLimit:
return self._limit

@property
def source(self) -> TypeHint:
return self._source
def default(self) -> Optional[BaseNormType]:
return self._default


class NormTVTuple(_BaseNormTypeVarLike):
__slots__ = (*_BaseNormTypeVarLike.__slots__, "_default")

def __init__(self, var: Any, *, source: TypeHint, default: Optional[tuple[BaseNormType, ...]]):
super().__init__(var, source=source)
self._default = default

@property
def name(self) -> str:
return self._var.__name__
def default(self) -> Optional[tuple[BaseNormType, ...]]:
return self._default

def __repr__(self):
return f"<{type(self).__name__}({self._var})>"

def __hash__(self):
return hash(self._var)
class NormParamSpec(_BaseNormTypeVarLike):
__slots__ = (*_BaseNormTypeVarLike.__slots__, "_limit", "_default")

def __eq__(self, other):
if isinstance(other, NormTVTuple):
return self._var == other._var
if isinstance(other, BaseNormType):
return False
return NotImplemented
def __init__(self, var: Any, limit: TypeVarLimit, *, source: TypeHint, default: Optional[tuple[BaseNormType, ...]]):
super().__init__(var, source=source)
self._default = default
self._limit = limit

@property
def limit(self) -> TypeVarLimit:
return self._limit

@property
def default(self) -> Optional[tuple[BaseNormType, ...]]:
return self._default


class NormParamSpecMarker(BaseNormType, ABC):
Expand All @@ -314,7 +320,7 @@ def __init__(self, param_spec: Any, *, source: TypeHint):
self._source = source

@property
def param_spec(self) -> NormTV:
def param_spec(self) -> NormParamSpec:
return self._param_spec

@property
Expand Down Expand Up @@ -348,7 +354,7 @@ def origin(self) -> Any:
return typing.ParamSpecKwargs


AnyNormTypeVarLike = Union[NormTV, NormTVTuple]
AnyNormTypeVarLike = Union[NormTV, NormTVTuple, NormParamSpec]


class NormTypeAlias(BaseNormType):
Expand Down Expand Up @@ -394,7 +400,11 @@ def __hash__(self):
return hash(self._type_alias)


_PARAM_SPEC_MARKER_TYPES = (typing.ParamSpecArgs, typing.ParamSpecKwargs) if HAS_PARAM_SPEC else ()
# Type-variable-like objects and ParamSpec markers have dedicated Norm*
# wrapper classes, so make_norm_type rejects them with TypeError.
_SPECIAL_CONSTRUCTOR_TYPE = (
TypeVar,
*((typing.ParamSpecArgs, typing.ParamSpecKwargs, typing.ParamSpec) if HAS_PARAM_SPEC else ()),
*((typing.TypeVarTuple,) if HAS_TV_TUPLE else ()),
)


def make_norm_type(
Expand All @@ -413,11 +423,7 @@ def make_norm_type(
return _LiteralNormType(args, source=source)
if origin == Annotated:
return _AnnotatedNormType(args, source=source)
if isinstance(origin, TypeVar):
raise TypeError
if HAS_PARAM_SPEC and (
isinstance(origin, _PARAM_SPEC_MARKER_TYPES) or isinstance(source, _PARAM_SPEC_MARKER_TYPES)
):
if isinstance(origin, _SPECIAL_CONSTRUCTOR_TYPE) or isinstance(source, _SPECIAL_CONSTRUCTOR_TYPE):
raise TypeError
return _NormType(origin, args, source=source)

Expand Down Expand Up @@ -567,6 +573,8 @@ def _norm_iter(self, tps: Iterable[Any]) -> VarTuple[BaseNormType]:
MUST_SUBSCRIBED_ORIGINS.append(typing.TypeGuard)
if HAS_TYPED_DICT_REQUIRED:
MUST_SUBSCRIBED_ORIGINS.extend([typing.Required, typing.NotRequired])
if HAS_PY_313:
MUST_SUBSCRIBED_ORIGINS.extend([typing.ReadOnly, typing.TypeIs]) # type: ignore[attr-defined]

@_aspect_storage.add
def _check_bad_input(self, tp, origin, args):
Expand Down Expand Up @@ -611,12 +619,15 @@ def _norm_type_var(self, tp, origin, args):
if origin.__constraints__ else
namespaced._get_bound(origin)
)
return NormTV(var=origin, limit=limit, source=tp)
default = namespaced.normalize(origin.__default__) if HAS_TV_DEFAULT and origin.has_default() else None
return NormTV(var=origin, limit=limit, source=tp, default=default)

@_aspect_storage.add(condition=HAS_TV_TUPLE)
def _norm_type_var_tuple(self, tp, origin, args):
if isinstance(origin, typing.TypeVarTuple):
return NormTVTuple(var=origin, source=tp)
namespaced = self._with_module_namespace(origin.__module__)
default = namespaced._norm_iter(origin.__default__) if HAS_TV_DEFAULT and origin.has_default() else None
return NormTVTuple(var=origin, source=tp, default=default)

@_aspect_storage.add(condition=HAS_PARAM_SPEC)
def _norm_param_spec(self, tp, origin, args):
Expand All @@ -628,10 +639,12 @@ def _norm_param_spec(self, tp, origin, args):

if isinstance(origin, typing.ParamSpec):
namespaced = self._with_module_namespace(origin.__module__)
return NormTV(
default = namespaced._norm_iter(origin.__default__) if HAS_TV_DEFAULT and origin.has_default() else None
return NormParamSpec(
var=origin,
limit=namespaced._get_bound(origin),
source=tp,
default=default,
)

@_aspect_storage.add(condition=HAS_TV_SYNTAX)
Expand Down
51 changes: 51 additions & 0 deletions tests/unit/provider/shape_provider/test_generic_resolving.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,6 @@
import sys
import typing
from contextlib import nullcontext
from typing import Any, Dict, Generic, List, Tuple, TypeVar

import pytest
Expand All @@ -11,6 +13,7 @@
HAS_PY_312,
HAS_SELF_TYPE,
HAS_SUPPORTED_PYDANTIC_PKG,
HAS_TV_DEFAULT,
HAS_TV_TUPLE,
IS_PYPY,
DistributionVersionRequirement,
Expand Down Expand Up @@ -554,3 +557,51 @@ def b(self) -> T:

# a limitation of pydantic implementation
assert_distinct_fields_types(MyModel[T], input={"a": Any}, output={"a": Any, "b": Any, "_c": Any})


NOTHING_TYPEVAR_MAKER = lambda default: TypeVar("tv_int", default=default) # noqa: E731


@requires(HAS_TV_DEFAULT)
@pytest.mark.parametrize(
    "tv_maker",
    [
        pytest.param(NOTHING_TYPEVAR_MAKER, id="nothing"),
        pytest.param(lambda default: TypeVar("tv_int", default=default, bound=object), id="bound"),
        pytest.param(lambda default: TypeVar("tv_int", str, int, default=default), id="constraints"),
    ],
)
def test_tv_default(model_spec, tv_maker):
    """A PEP 696 TypeVar default is applied when the parameter is omitted."""
    # Pydantic raises NotImplementedError when a default is combined
    # with a bound or with constraints.
    expects_failure = model_spec.kind == ModelSpec.PYDANTIC and tv_maker != NOTHING_TYPEVAR_MAKER
    ctx = pytest.raises(NotImplementedError) if expects_failure else nullcontext()
    with ctx:
        tv_int = tv_maker(default=int)

        @model_spec.decorator
        class MyModel(*model_spec.bases, Generic[tv_int]):
            a: tv_int
            b: int

        assert_fields_types(MyModel, {"a": int, "b": int})
        assert_fields_types(MyModel[int], {"a": int, "b": int})
        assert_fields_types(MyModel[str], {"a": str, "b": int})
        assert_fields_types(MyModel[T], {"a": T, "b": int})


@requires(HAS_TV_DEFAULT)
@exclude_model_spec(ModelSpec.PYDANTIC)
def test_tv_tuple_default(model_spec):
    """A PEP 696 TypeVarTuple default expands when no parameters are given."""
    tvt = typing.TypeVarTuple("t1", default=(int, str))

    @model_spec.decorator
    class MyModel(*model_spec.bases, Generic[typing.Unpack[tvt]]):
        a: tuple[typing.Unpack[tvt]]
        b: int

    cases = [
        (MyModel, tuple[int, str]),
        (MyModel[str], tuple[str]),
        (MyModel[int, str], tuple[int, str]),
        (MyModel[int, str, bool], tuple[int, str, bool]),
        (MyModel[T], tuple[T]),
    ]
    for model, expected_a in cases:
        assert_fields_types(model, {"a": expected_a, "b": int})
Loading

0 comments on commit 981652f

Please sign in to comment.