Skip to content

Commit

Permalink
Chore: Cosmetic changes to how we resolve session properties for a mo…
Browse files Browse the repository at this point in the history
  • Loading branch information
izeigerman authored May 13, 2024
1 parent 3f2643a commit 9e7d4ca
Show file tree
Hide file tree
Showing 3 changed files with 20 additions and 13 deletions.
1 change: 1 addition & 0 deletions docs/reference/model_configuration.md
Original file line number Diff line number Diff line change
Expand Up @@ -45,6 +45,7 @@ The SQLMesh project-level `model_defaults` key supports the following options, d
- owner
- start
- storage_format
- session_properties (on per key basis)

## Model kind properties

Expand Down
23 changes: 11 additions & 12 deletions sqlmesh/core/model/definition.py
Original file line number Diff line number Diff line change
Expand Up @@ -1813,8 +1813,8 @@ def _create_model(
) -> Model:
_validate_model_fields(klass, {"name", *kwargs} - {"grain", "table_properties"}, path)

kwargs["session_properties"] = _resolve_custom_session_properties(
defaults, kwargs.get("session_properties")
kwargs["session_properties"] = _resolve_session_properties(
(defaults or {}).get("session_properties"), kwargs.get("session_properties")
)

dialect = dialect or ""
Expand Down Expand Up @@ -1880,21 +1880,20 @@ def _split_sql_model_statements(
return query, expressions[:pos], expressions[pos + 1 :]


def _resolve_custom_session_properties(
defaults: t.Optional[t.Dict[str, t.Any]],
provided: t.Optional[exp.Expression] | t.Optional[t.Dict[str, t.Any]],
def _resolve_session_properties(
default: t.Optional[t.Dict[str, t.Any]],
provided: t.Optional[exp.Expression | t.Dict[str, t.Any]],
) -> t.Optional[exp.Expression]:
if isinstance(provided, dict):
session_properties = {k: exp.Literal.string(k).eq(v) for k, v in provided.items()}
elif provided:
session_properties = {expr.this.name: expr for expr in provided}
else:
session_properties = (
{expr.this.name: expr for expr in provided} if provided is not None else {}
)
session_properties = {}

if defaults and defaults.get("session_properties"):
for k, v in defaults["session_properties"].items():
if k not in session_properties:
session_properties[k] = exp.Literal.string(k).eq(v)
for k, v in (default or {}).items():
if k not in session_properties:
session_properties[k] = exp.Literal.string(k).eq(v)

if session_properties:
return exp.Tuple(expressions=list(session_properties.values()))
Expand Down
9 changes: 8 additions & 1 deletion tests/core/test_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -1554,7 +1554,7 @@ def my_model(context, **kwargs):
assert m.depends_on == {'"foo"."bar"'}


def test_python_model_with_session_props():
def test_python_model_with_session_properties():
@model(
name="python_model_prop",
kind="full",
Expand All @@ -1568,11 +1568,18 @@ def python_model_prop(context, **kwargs):
module_path=Path("."),
path=Path("."),
dialect="duckdb",
defaults={
"session_properties": {
"some_string": "default_string",
"default_value": "default_value",
}
},
)
assert m.session_properties == {
"some_string": "string_prop",
"some_bool": True,
"some_float": 1.0,
"default_value": "default_value",
}


Expand Down

0 comments on commit 9e7d4ca

Please sign in to comment.