Skip to content

Commit

Permalink
Feat: add support for parameterization of model names in unit tests (T…
Browse files Browse the repository at this point in the history
  • Loading branch information
georgesittas authored Jul 8, 2024
1 parent 3b88959 commit 7acaef3
Show file tree
Hide file tree
Showing 6 changed files with 83 additions and 9 deletions.
18 changes: 18 additions & 0 deletions docs/concepts/tests.md
Original file line number Diff line number Diff line change
Expand Up @@ -255,6 +255,24 @@ test_colors:
execution_time: "2023-01-01 12:05:03+02:00"
```

## Parameterized model names

Testing models with parameterized names, such as `@{gold}.some_schema.some_table`, is possible using Jinja:

```yaml linenums="1"
test_parameterized_model:
model: {{ var('gold') }}.some_schema.some_table
...
```

For example, assuming `gold` is a [config variable](../reference/configuration/#variables) with value `gold_db`, the above test would be rendered as:

```yaml linenums="1"
test_parameterized_model:
model: gold_db.some_schema.some_table
...
```

## Automatic test generation

Creating tests manually can be repetitive and error-prone, which is why SQLMesh also provides a way to automate this process using the [`create_test` command](../reference/cli.md#create_test).
Expand Down
1 change: 1 addition & 0 deletions sqlmesh/core/context.py
Original file line number Diff line number Diff line change
Expand Up @@ -1509,6 +1509,7 @@ def test(
path / c.TESTS,
patterns=match_patterns,
ignore_patterns=config.ignore_patterns,
variables=config.variables,
)
)

Expand Down
4 changes: 2 additions & 2 deletions sqlmesh/core/test/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -111,9 +111,9 @@ def run_model_tests(
path = pathlib.Path(filename)

if test_name:
loaded_tests.append(load_model_test_file(path)[test_name])
loaded_tests.append(load_model_test_file(path, variables=config.variables)[test_name])
else:
loaded_tests.extend(load_model_test_file(path).values())
loaded_tests.extend(load_model_test_file(path, variables=config.variables).values())

if patterns:
loaded_tests = filter_tests_by_patterns(loaded_tests, patterns)
Expand Down
19 changes: 14 additions & 5 deletions sqlmesh/core/test/discovery.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,9 @@ def __hash__(self) -> int:
return self.fully_qualified_test_name.__hash__()


def load_model_test_file(path: pathlib.Path) -> dict[str, ModelTestMetadata]:
def load_model_test_file(
path: pathlib.Path, variables: dict[str, t.Any] | None = None
) -> dict[str, ModelTestMetadata]:
"""Load a single model test file.
Args:
Expand All @@ -36,7 +38,7 @@ def load_model_test_file(path: pathlib.Path) -> dict[str, ModelTestMetadata]:
A list of ModelTestMetadata named tuples.
"""
model_test_metadata = {}
contents = yaml_load(path)
contents = yaml_load(path, variables=variables)

for test_name, value in contents.items():
model_test_metadata[test_name] = ModelTestMetadata(
Expand All @@ -46,7 +48,9 @@ def load_model_test_file(path: pathlib.Path) -> dict[str, ModelTestMetadata]:


def discover_model_tests(
path: pathlib.Path, ignore_patterns: list[str] | None = None
path: pathlib.Path,
ignore_patterns: list[str] | None = None,
variables: dict[str, t.Any] | None = None,
) -> Iterator[ModelTestMetadata]:
"""Discover model tests.
Expand All @@ -69,7 +73,9 @@ def discover_model_tests(
if yaml_file.match(ignore_pattern):
break
else:
for model_test_metadata in load_model_test_file(yaml_file).values():
for model_test_metadata in load_model_test_file(
yaml_file, variables=variables
).values():
yield model_test_metadata


Expand Down Expand Up @@ -97,9 +103,12 @@ def get_all_model_tests(
*paths: pathlib.Path,
patterns: list[str] | None = None,
ignore_patterns: list[str] | None = None,
variables: dict[str, t.Any] | None = None,
) -> list[ModelTestMetadata]:
model_test_metadatas = [
meta for path in paths for meta in discover_model_tests(pathlib.Path(path), ignore_patterns)
meta
for path in paths
for meta in discover_model_tests(pathlib.Path(path), ignore_patterns, variables=variables)
]
if patterns:
model_test_metadatas = filter_tests_by_patterns(model_test_metadatas, patterns)
Expand Down
11 changes: 9 additions & 2 deletions sqlmesh/utils/yaml.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,8 +7,9 @@

from ruamel import yaml

from sqlmesh.core.constants import VAR
from sqlmesh.utils.errors import SQLMeshError
from sqlmesh.utils.jinja import ENVIRONMENT
from sqlmesh.utils.jinja import ENVIRONMENT, create_var

JINJA_METHODS = {
"env_var": lambda key, default=None: getenv(key, default),
Expand All @@ -22,6 +23,7 @@ def load(
raise_if_empty: bool = True,
render_jinja: bool = True,
allow_duplicate_keys: bool = False,
variables: t.Optional[t.Dict[str, t.Any]] = None,
) -> t.Dict:
"""Loads a YAML object from either a raw string or a file."""
path: t.Optional[Path] = None
Expand All @@ -32,7 +34,12 @@ def load(
source = file.read()

if render_jinja:
source = ENVIRONMENT.from_string(source).render(JINJA_METHODS)
source = ENVIRONMENT.from_string(source).render(
{
**JINJA_METHODS,
VAR: create_var(variables or {}),
}
)

yaml = YAML()
yaml.allow_duplicate_keys = allow_duplicate_keys
Expand Down
39 changes: 39 additions & 0 deletions tests/core/test_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -1449,6 +1449,45 @@ def execute(context, start, end, execution_time, **kwargs):
)


def test_variable_usage(tmp_path: Path) -> None:
init_example_project(tmp_path, dialect="duckdb")

config = Config(
default_connection=DuckDBConnectionConfig(),
model_defaults=ModelDefaultsConfig(dialect="duckdb"),
variables={"gold": "gold_db", "silver": "silver_db"},
)
context = Context(paths=tmp_path, config=config)

parent = _create_model("SELECT 1 as id", meta="MODEL (name silver_db.sch.b)")
parent = t.cast(SqlModel, context.upsert_model(parent))

child = _create_model("SELECT id FROM silver_db.sch.b", meta="MODEL (name gold_db.sch.a)")
child = t.cast(SqlModel, context.upsert_model(child))

test_file = tmp_path / "tests" / "test_parameterized_model_names.yaml"
test_file.write_text(
"""
test_parameterized_model_names:
model: {{ var('gold') }}.sch.a
inputs:
{{ var('silver') }}.sch.b:
- id: 1
outputs:
query:
- id: 1
"""
)

results = context.test()

assert not results.failures
assert not results.errors

# The example project has one test and we added another one above
assert len(results.successes) == 2


def test_test_generation(tmp_path: Path) -> None:
init_example_project(tmp_path, dialect="duckdb")

Expand Down

0 comments on commit 7acaef3

Please sign in to comment.