diff --git a/bump_pydantic/codemods/add_default_none.py b/bump_pydantic/codemods/add_default_none.py index c06be7a..be0bf1f 100644 --- a/bump_pydantic/codemods/add_default_none.py +++ b/bump_pydantic/codemods/add_default_none.py @@ -56,7 +56,7 @@ def leave_ClassDef(self, original_node: cst.ClassDef, updated_node: cst.ClassDef self.inside_base_model = False return updated_node - def visit_AnnAssign(self, node: cst.AnnAssign) -> bool | None: + def visit_AnnAssign(self, node: cst.AnnAssign) -> None: if m.matches( node.annotation.annotation, m.Subscript(m.Name("Optional") | m.Attribute(m.Name("typing"), m.Name("Optional"))) @@ -75,11 +75,29 @@ def visit_AnnAssign(self, node: cst.AnnAssign) -> bool | None: | m.BinaryOperation(operator=m.BitOr(), right=m.Name("None")), ): self.should_add_none = True - return super().visit_AnnAssign(node) + return None def leave_AnnAssign(self, original_node: cst.AnnAssign, updated_node: cst.AnnAssign) -> cst.AnnAssign: - if self.inside_base_model and self.should_add_none and updated_node.value is None: - updated_node = updated_node.with_changes(value=cst.Name("None")) + if self.inside_base_model and self.should_add_none: + if updated_node.value is None: + updated_node = updated_node.with_changes(value=cst.Name("None")) + # TODO: Should accept `pydantic.Field` as well. + elif m.matches(updated_node.value, m.Call(func=m.Name("Field"))): + assert isinstance(updated_node.value, cst.Call) + if updated_node.value.args: + arg = updated_node.value.args[0] + if (arg.keyword is None or arg.keyword.value == "default") and m.matches(arg.value, m.Ellipsis()): + updated_node = updated_node.with_changes( + value=updated_node.value.with_changes( + args=[arg.with_changes(value=cst.Name("None")), *updated_node.value.args[1:]] + ) + ) + # This is the case where `Field` is called without any arguments e.g. `Field()`. + else: + updated_node = updated_node.with_changes( + value=updated_node.value.with_changes(args=[cst.Arg(value=cst.Name("None"))]) # type: ignore + ) + self.inside_an_assign = False self.should_add_none = False return updated_node diff --git a/tests/integration/cases/__init__.py b/tests/integration/cases/__init__.py index 81856d9..baaea54 100644 --- a/tests/integration/cases/__init__.py +++ b/tests/integration/cases/__init__.py @@ -4,8 +4,8 @@ from .add_none import cases as add_none_cases from .base_settings import cases as base_settings_cases from .config_to_model import cases as config_to_model_cases +from .field import cases as generic_model_cases from .folder_inside_folder import cases as folder_inside_folder_cases -from .generic_model import cases as generic_model_cases from .is_base_model import cases as is_base_model_cases from .replace_validator import cases as replace_validator_cases from .root_model import cases as root_model_cases diff --git a/tests/integration/cases/add_none.py b/tests/integration/cases/add_none.py index 90781df..ec52077 100644 --- a/tests/integration/cases/add_none.py +++ b/tests/integration/cases/add_none.py @@ -9,7 +9,7 @@ content=[ "from typing import Any, Dict, Optional, Union", "", - "from pydantic import BaseModel", + "from pydantic import BaseModel, Field", "", "", "class A(BaseModel):", @@ -18,6 +18,10 @@ " c: Union[int, None]", " d: Any", " e: Dict[str, str]", + " f: Optional[int] = Field(..., lt=10)", + " g: Optional[int] = Field()", + " h: Optional[int] = Field(...)", + " i: Optional[int] = Field(default_factory=lambda: None)", ], ), expected=File( @@ -25,7 +29,7 @@ content=[ "from typing import Any, Dict, Optional, Union", "", - "from pydantic import BaseModel", + "from pydantic import BaseModel, Field", "", "", "class A(BaseModel):", @@ -34,6 +38,10 @@ " c: Union[int, None] = None", " d: Any = None", " e: Dict[str, str]", + " f: Optional[int] = Field(None, lt=10)", + " g: Optional[int] = Field(None)", + " h: Optional[int] = Field(None)", + " i: Optional[int] = Field(default_factory=lambda: None)", ], ), ) diff --git a/tests/integration/cases/generic_model.py b/tests/integration/cases/field.py similarity index 100% rename from tests/integration/cases/generic_model.py rename to tests/integration/cases/field.py