Skip to content

Commit

Permalink
Fix: introspect function defaults to extract global refs for serializ…
Browse files Browse the repository at this point in the history
…ation (TobikoData#3028)
  • Loading branch information
georgesittas authored Aug 21, 2024
1 parent d8bbe88 commit 67a5798
Show file tree
Hide file tree
Showing 3 changed files with 33 additions and 12 deletions.
22 changes: 16 additions & 6 deletions sqlmesh/utils/metaprogramming.py
Original file line number Diff line number Diff line change
Expand Up @@ -72,11 +72,21 @@ def func_globals(func: t.Callable) -> t.Dict[str, t.Any]:
variables = {}

if hasattr(func, "__code__"):
code = func.__code__
root_node = parse_source(func)

func_args = next(node for node in ast.walk(root_node) if isinstance(node, ast.arguments))
arg_defaults = (d for d in func_args.defaults + func_args.kw_defaults if d is not None)

for var in list(_code_globals(code)) + decorators(func):
if var in func.__globals__:
ref = func.__globals__[var]
# ast.Name corresponds to variable references, such as foo or x.foo. The former is
# represented as Name(id=foo), and the latter as Attribute(value=Name(id=x) attr=foo)
arg_globals = [
n.id for default in arg_defaults for n in ast.walk(default) if isinstance(n, ast.Name)
]

code = func.__code__
for var in arg_globals + list(_code_globals(code)) + decorators(func, root_node=root_node):
ref = func.__globals__.get(var)
if ref:
variables[var] = ref

if func.__closure__:
Expand Down Expand Up @@ -177,9 +187,9 @@ def _decorator_name(decorator: ast.expr) -> str:
return ""


def decorators(func: t.Callable) -> t.List[str]:
def decorators(func: t.Callable, root_node: t.Optional[ast.Module] = None) -> t.List[str]:
"""Finds a list of all the decorators of a callable."""
root_node = parse_source(func)
root_node = root_node or parse_source(func)
decorators = []

for node in ast.walk(root_node):
Expand Down
5 changes: 3 additions & 2 deletions tests/core/test_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -627,8 +627,8 @@ def foo(**kwargs) -> None:
assert model.query == d.parse("SELECT 1 AS x")[0]

@macro()
def multiple_statements(evaluator):
return ["CREATE TABLE t1 AS SELECT 1 AS c", "CREATE TABLE t2 AS SELECT 2 AS c"]
def multiple_statements(evaluator, t1_value=exp.Literal.number(1)):
return [f"CREATE TABLE t1 AS SELECT {t1_value} AS c", "CREATE TABLE t2 AS SELECT 2 AS c"]

expressions = d.parse(
"""
Expand All @@ -645,6 +645,7 @@ def multiple_statements(evaluator):
'CREATE TABLE "t1" AS SELECT 1 AS "c"; CREATE TABLE "t2" AS SELECT 2 AS "c"'
)
assert model.render_post_statements() == expected_post
assert "exp" in model.python_env


def test_seed_hydration():
Expand Down
18 changes: 14 additions & 4 deletions tests/utils/test_metaprogramming.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,8 @@
import pytest
import sqlglot
from pytest_mock.plugin import MockerFixture
from sqlglot import exp
from sqlglot import exp as expressions
from sqlglot.expressions import to_table

import tests.utils.test_date as test_date
Expand Down Expand Up @@ -43,7 +45,7 @@ def test_print_exception(mocker: MockerFixture):

expected_message = f"""Traceback (most recent call last):
File "{__file__}", line 40, in test_print_exception
File "{__file__}", line 42, in test_print_exception
eval("test_fun()", env)
File "<string>", line 1, in <module>
Expand Down Expand Up @@ -104,7 +106,7 @@ def noop_metadata() -> None:
setattr(noop_metadata, c.SQLMESH_METADATA, True)


def main_func(y: int) -> int:
def main_func(y: int, foo=exp.true(), *, bar=expressions.Literal.number(1) + 2) -> int:
"""DOC STRING"""
sqlglot.parse_one("1")
MyClass()
Expand All @@ -128,6 +130,8 @@ def test_func_globals() -> None:
"normalize_model_name": normalize_model_name,
"other_func": other_func,
"sqlglot": sqlglot,
"exp": exp,
"expressions": exp,
}
assert func_globals(other_func) == {
"X": 1,
Expand All @@ -153,7 +157,8 @@ def closure() -> int:
def test_normalize_source() -> None:
assert (
normalize_source(main_func)
== """def main_func(y: int):
== """def main_func(y: int, foo=exp.true(), *, bar=expressions.Literal.number(1) + 2
):
sqlglot.parse_one('1')
MyClass()
DataClass(x=y)
Expand Down Expand Up @@ -194,7 +199,8 @@ def test_serialize_env() -> None:
name="main_func",
alias="MAIN",
path="test_metaprogramming.py",
payload="""def main_func(y: int):
payload="""def main_func(y: int, foo=exp.true(), *, bar=expressions.Literal.number(1) + 2
):
sqlglot.parse_one('1')
MyClass()
DataClass(x=y)
Expand Down Expand Up @@ -245,6 +251,10 @@ def baz(self):
),
"pd": Executable(payload="import pandas as pd", kind=ExecutableKind.IMPORT),
"sqlglot": Executable(kind=ExecutableKind.IMPORT, payload="import sqlglot"),
"exp": Executable(kind=ExecutableKind.IMPORT, payload="import sqlglot.expressions as exp"),
"expressions": Executable(
kind=ExecutableKind.IMPORT, payload="import sqlglot.expressions as expressions"
),
"my_lambda": Executable(
name="my_lambda",
path="test_metaprogramming.py",
Expand Down

0 comments on commit 67a5798

Please sign in to comment.