diff --git a/bump_pydantic/codemods/replace_config.py b/bump_pydantic/codemods/replace_config.py index ffee6b4..f1c6ff7 100644 --- a/bump_pydantic/codemods/replace_config.py +++ b/bump_pydantic/codemods/replace_config.py @@ -117,10 +117,15 @@ def __init__(self, context: CodemodContext) -> None: super().__init__(context) self.inside_config_class = False + self.is_base_settings = False self.invalid_config_class = False self.inherited_config_class = False self.config_args: List[cst.Arg] = [] + @m.visit(m.ClassDef(bases=[m.ZeroOrMore(), m.Arg(value=m.Name("BaseSettings")), m.ZeroOrMore()])) + def visit_settings_with_config(self, node: cst.ClassDef) -> None: + self.is_base_settings = True + @m.visit(m.ClassDef(name=m.Name(value="Config"))) def visit_config_class(self, node: cst.ClassDef) -> None: scope = self.get_metadata(ScopeProvider, node) @@ -203,7 +208,11 @@ def leave_config_class_childless(self, original_node: cst.ClassDef, updated_node if self.invalid_config_class: self.invalid_config_class = False return updated_node - AddImportsVisitor.add_needed_import(context=self.context, module="pydantic", obj="ConfigDict") + if self.is_base_settings: + needed_import = {"module": "pydantic_settings", "obj": "SettingsConfigDict"} + else: + needed_import = {"module": "pydantic", "obj": "ConfigDict"} + AddImportsVisitor.add_needed_import(context=self.context, **needed_import) # type: ignore[arg-type] block = cst.ensure_type(updated_node.body, cst.IndentedBlock) body = [ cst.SimpleStatementLine( @@ -211,7 +220,7 @@ def leave_config_class_childless(self, original_node: cst.ClassDef, updated_node cst.Assign( targets=[cst.AssignTarget(target=cst.Name("model_config"))], value=cst.Call( - func=cst.Name("ConfigDict"), + func=cst.Name("SettingsConfigDict" if self.is_base_settings else "ConfigDict"), args=self.config_args, ), ) @@ -222,6 +231,7 @@ def leave_config_class_childless(self, original_node: cst.ClassDef, updated_node else statement for statement in block.body ] + self.is_base_settings = False self.config_args = [] return updated_node.with_changes(body=updated_node.body.with_changes(body=body)) @@ -249,7 +259,7 @@ def _leading_lines_from_removed_keys(args: List[cst.Arg]) -> List[cst.EmptyLine] """ from pydantic import BaseModel - class A(BaseModel): + class A(BaseSettings): a: str # My comment diff --git a/tests/integration/cases/config_to_model.py b/tests/integration/cases/config_to_model.py index c0c706a..e345818 100644 --- a/tests/integration/cases/config_to_model.py +++ b/tests/integration/cases/config_to_model.py @@ -60,6 +60,9 @@ "class Settings(BaseSettings):", " sentry_dsn: str", "", + " class Config:", + " orm_mode = True", + "", "", "class A(BaseModel):", " class Config:", @@ -70,11 +73,12 @@ "config_dict_and_settings.py", content=[ "from pydantic import ConfigDict, BaseModel", - "from pydantic_settings import BaseSettings", + "from pydantic_settings import BaseSettings, SettingsConfigDict", "", "", "class Settings(BaseSettings):", " sentry_dsn: str", + " model_config = SettingsConfigDict(from_attributes=True)", "", "", "class A(BaseModel):",