Skip to content

Commit

Permalink
✨ Add SettingsConfigDict to BaseSettings (pydantic#74)
Browse files Browse the repository at this point in the history
  • Loading branch information
Kludex authored Jul 10, 2023
1 parent a7b399f commit 1807872
Show file tree
Hide file tree
Showing 2 changed files with 18 additions and 4 deletions.
16 changes: 13 additions & 3 deletions bump_pydantic/codemods/replace_config.py
Original file line number Diff line number Diff line change
Expand Up @@ -117,10 +117,15 @@ def __init__(self, context: CodemodContext) -> None:
super().__init__(context)

self.inside_config_class = False
self.is_base_settings = False
self.invalid_config_class = False
self.inherited_config_class = False
self.config_args: List[cst.Arg] = []

@m.visit(m.ClassDef(bases=[m.ZeroOrMore(), m.Arg(value=m.Name("BaseSettings")), m.ZeroOrMore()]))
def visit_settings_with_config(self, node: cst.ClassDef) -> None:
self.is_base_settings = True

@m.visit(m.ClassDef(name=m.Name(value="Config")))
def visit_config_class(self, node: cst.ClassDef) -> None:
scope = self.get_metadata(ScopeProvider, node)
Expand Down Expand Up @@ -203,15 +208,19 @@ def leave_config_class_childless(self, original_node: cst.ClassDef, updated_node
if self.invalid_config_class:
self.invalid_config_class = False
return updated_node
AddImportsVisitor.add_needed_import(context=self.context, module="pydantic", obj="ConfigDict")
if self.is_base_settings:
needed_import = {"module": "pydantic_settings", "obj": "SettingsConfigDict"}
else:
needed_import = {"module": "pydantic", "obj": "ConfigDict"}
AddImportsVisitor.add_needed_import(context=self.context, **needed_import) # type: ignore[arg-type]
block = cst.ensure_type(updated_node.body, cst.IndentedBlock)
body = [
cst.SimpleStatementLine(
body=[
cst.Assign(
targets=[cst.AssignTarget(target=cst.Name("model_config"))],
value=cst.Call(
func=cst.Name("ConfigDict"),
func=cst.Name("SettingsConfigDict" if self.is_base_settings else "ConfigDict"),
args=self.config_args,
),
)
Expand All @@ -222,6 +231,7 @@ def leave_config_class_childless(self, original_node: cst.ClassDef, updated_node
else statement
for statement in block.body
]
self.is_base_settings = False
self.config_args = []
return updated_node.with_changes(body=updated_node.body.with_changes(body=body))

Expand Down Expand Up @@ -249,7 +259,7 @@ def _leading_lines_from_removed_keys(args: List[cst.Arg]) -> List[cst.EmptyLine]
"""
from pydantic import BaseModel
class A(BaseModel):
class A(BaseSettings):
a: str
# My comment
Expand Down
6 changes: 5 additions & 1 deletion tests/integration/cases/config_to_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -60,6 +60,9 @@
"class Settings(BaseSettings):",
" sentry_dsn: str",
"",
" class Config:",
" orm_mode = True",
"",
"",
"class A(BaseModel):",
" class Config:",
Expand All @@ -70,11 +73,12 @@
"config_dict_and_settings.py",
content=[
"from pydantic import ConfigDict, BaseModel",
"from pydantic_settings import BaseSettings",
"from pydantic_settings import BaseSettings, SettingsConfigDict",
"",
"",
"class Settings(BaseSettings):",
" sentry_dsn: str",
" model_config = SettingsConfigDict(from_attributes=True)",
"",
"",
"class A(BaseModel):",
Expand Down

0 comments on commit 1807872

Please sign in to comment.