diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml index 3a80bf681..3057e365e 100644 --- a/.pre-commit-config.yaml +++ b/.pre-commit-config.yaml @@ -1,34 +1,20 @@ repos: - repo: local hooks: - - id: autoflake - name: autoflake - entry: autoflake - language: system - "types": [python] + - id: ruff + name: ruff + entry: ruff check --force-exclude --fix --ignore E721 --ignore E741 + language: python + types_or: [python, pyi] require_serial: true files: &files ^(sqlmesh/|tests/|web/|examples/|setup.py) - - id: isort - name: isort - entry: isort - language: system - "types": [python] - files: *files - require_serial: true - - id: black - name: black - language: system - args: - - --target-version - - py37 - - --line-length - - "100" - entry: black + - id: ruff-format + name: ruff-format + entry: ruff format --force-exclude --line-length 100 + language: python + types_or: [python, pyi] require_serial: true files: *files - types_or: - - python - - pyi - id: mypy name: mypy language: system diff --git a/Makefile b/Makefile index dc2afdd29..81fb69313 100644 --- a/Makefile +++ b/Makefile @@ -22,7 +22,7 @@ py-style: SKIP=prettier,eslint pre-commit run --all-files ui-style: - SKIP=autoflake,isort,black,mypy pre-commit run --all-files + SKIP=ruff,ruff-format,mypy pre-commit run --all-files doc-test: PYTEST_PLUGINS=tests.common_fixtures pytest --doctest-modules sqlmesh/core sqlmesh/utils diff --git a/examples/sushi/config.py b/examples/sushi/config.py index 386c2ae99..5d1831ae2 100644 --- a/examples/sushi/config.py +++ b/examples/sushi/config.py @@ -127,7 +127,7 @@ CATALOGS = { "in_memory": ":memory:", - "other_catalog": f":memory:", + "other_catalog": ":memory:", } local_catalogs = Config( diff --git a/examples/sushi/models/raw_marketing.py b/examples/sushi/models/raw_marketing.py index 3f47866d3..14cba26ca 100644 --- a/examples/sushi/models/raw_marketing.py +++ b/examples/sushi/models/raw_marketing.py @@ -4,7 +4,6 @@ import numpy as np import pandas as pd -from helper import iter_dates # type: ignore from sqlglot import exp from sqlmesh import ExecutionContext, model diff --git a/setup.cfg b/setup.cfg index 20f5780ef..f7e19479b 100644 --- a/setup.cfg +++ b/setup.cfg @@ -98,16 +98,3 @@ ignore_missing_imports = True [mypy-pydantic_core.*] ignore_missing_imports = True - -[autoflake] -in-place = True -expand-star-imports = True -remove-all-unused-imports = True -ignore-init-module-imports = True -remove-duplicate-keys = True -remove-unused-variables = True -quiet = True - -[isort] -profile=black -known_first_party=sqlmesh diff --git a/setup.py b/setup.py index 525bdcc4e..46f4420e9 100644 --- a/setup.py +++ b/setup.py @@ -59,10 +59,9 @@ ], "dev": [ f"apache-airflow=={os.environ.get('AIRFLOW_VERSION', '2.9.1')}", - "autoflake==1.7.7", "agate==1.7.1", "beautifulsoup4", - "black==24.4.2", + "ruff~=0.4.0", "cryptography~=42.0.4", "dbt-core", "dbt-duckdb>=1.7.1", @@ -70,8 +69,7 @@ "google-auth", "google-cloud-bigquery", "google-cloud-bigquery-storage", - "isort==5.10.1", - "mypy~=1.8.0", + "mypy~=1.10.0", "pre-commit", "pandas-stubs", "psycopg2-binary", diff --git a/sqlmesh/__init__.py b/sqlmesh/__init__.py index b65c39d21..ed33bbb30 100644 --- a/sqlmesh/__init__.py +++ b/sqlmesh/__init__.py @@ -1,3 +1,4 @@ +# ruff: noqa: E402 """ .. 
include:: ../README.md """ @@ -18,16 +19,19 @@ extend_sqlglot() from sqlmesh.core import constants as c -from sqlmesh.core.config import Config -from sqlmesh.core.context import Context, ExecutionContext -from sqlmesh.core.engine_adapter import EngineAdapter -from sqlmesh.core.macros import SQL, macro -from sqlmesh.core.model import Model, model -from sqlmesh.core.snapshot import Snapshot -from sqlmesh.utils import debug_mode_enabled, enable_debug_mode +from sqlmesh.core.config import Config as Config +from sqlmesh.core.context import Context as Context, ExecutionContext as ExecutionContext +from sqlmesh.core.engine_adapter import EngineAdapter as EngineAdapter +from sqlmesh.core.macros import SQL as SQL, macro as macro +from sqlmesh.core.model import Model as Model, model as model +from sqlmesh.core.snapshot import Snapshot as Snapshot +from sqlmesh.utils import ( + debug_mode_enabled as debug_mode_enabled, + enable_debug_mode as enable_debug_mode, +) try: - from sqlmesh._version import __version__, __version_tuple__ # type: ignore + from sqlmesh._version import __version__ as __version__, __version_tuple__ as __version_tuple__ except ImportError: pass diff --git a/sqlmesh/cli/__init__.py b/sqlmesh/cli/__init__.py index fd9c4e9e1..b0d59356f 100644 --- a/sqlmesh/cli/__init__.py +++ b/sqlmesh/cli/__init__.py @@ -17,7 +17,7 @@ def error_handler( - func: t.Callable[..., DECORATOR_RETURN_TYPE] + func: t.Callable[..., DECORATOR_RETURN_TYPE], ) -> t.Callable[..., DECORATOR_RETURN_TYPE]: @wraps(func) def wrapper(*args: t.List[t.Any], **kwargs: t.Any) -> DECORATOR_RETURN_TYPE: diff --git a/sqlmesh/core/audit/__init__.py b/sqlmesh/core/audit/__init__.py index 67b94e99f..4fd74c2e4 100644 --- a/sqlmesh/core/audit/__init__.py +++ b/sqlmesh/core/audit/__init__.py @@ -4,12 +4,12 @@ from sqlmesh.core.audit import builtin from sqlmesh.core.audit.definition import ( - Audit, - AuditResult, - ModelAudit, - StandaloneAudit, - load_audit, - load_multiple_audits, + Audit as Audit, + AuditResult as AuditResult, + ModelAudit as ModelAudit, + StandaloneAudit as StandaloneAudit, + load_audit as load_audit, + load_multiple_audits as load_multiple_audits, ) diff --git a/sqlmesh/core/config/__init__.py b/sqlmesh/core/config/__init__.py index 46bfee035..8b61d7da7 100644 --- a/sqlmesh/core/config/__init__.py +++ b/sqlmesh/core/config/__init__.py @@ -1,34 +1,37 @@ -from sqlmesh.core.config.categorizer import AutoCategorizationMode, CategorizerConfig -from sqlmesh.core.config.common import EnvironmentSuffixTarget +from sqlmesh.core.config.categorizer import ( + AutoCategorizationMode as AutoCategorizationMode, + CategorizerConfig as CategorizerConfig, +) +from sqlmesh.core.config.common import EnvironmentSuffixTarget as EnvironmentSuffixTarget from sqlmesh.core.config.connection import ( - BigQueryConnectionConfig, - ConnectionConfig, - DatabricksConnectionConfig, - DuckDBConnectionConfig, - GCPPostgresConnectionConfig, - MotherDuckConnectionConfig, - MSSQLConnectionConfig, - MySQLConnectionConfig, - PostgresConnectionConfig, - RedshiftConnectionConfig, - SnowflakeConnectionConfig, - SparkConnectionConfig, - parse_connection_config, + BigQueryConnectionConfig as BigQueryConnectionConfig, + ConnectionConfig as ConnectionConfig, + DatabricksConnectionConfig as DatabricksConnectionConfig, + DuckDBConnectionConfig as DuckDBConnectionConfig, + GCPPostgresConnectionConfig as GCPPostgresConnectionConfig, + MotherDuckConnectionConfig as MotherDuckConnectionConfig, + MSSQLConnectionConfig as MSSQLConnectionConfig, + 
MySQLConnectionConfig as MySQLConnectionConfig, + PostgresConnectionConfig as PostgresConnectionConfig, + RedshiftConnectionConfig as RedshiftConnectionConfig, + SnowflakeConnectionConfig as SnowflakeConnectionConfig, + SparkConnectionConfig as SparkConnectionConfig, + parse_connection_config as parse_connection_config, ) -from sqlmesh.core.config.gateway import GatewayConfig +from sqlmesh.core.config.gateway import GatewayConfig as GatewayConfig from sqlmesh.core.config.loader import ( - load_config_from_paths, - load_config_from_yaml, - load_configs, + load_config_from_paths as load_config_from_paths, + load_config_from_yaml as load_config_from_yaml, + load_configs as load_configs, ) -from sqlmesh.core.config.migration import MigrationConfig -from sqlmesh.core.config.model import ModelDefaultsConfig -from sqlmesh.core.config.plan import PlanConfig -from sqlmesh.core.config.root import Config -from sqlmesh.core.config.run import RunConfig +from sqlmesh.core.config.migration import MigrationConfig as MigrationConfig +from sqlmesh.core.config.model import ModelDefaultsConfig as ModelDefaultsConfig +from sqlmesh.core.config.plan import PlanConfig as PlanConfig +from sqlmesh.core.config.root import Config as Config +from sqlmesh.core.config.run import RunConfig as RunConfig from sqlmesh.core.config.scheduler import ( - AirflowSchedulerConfig, - BuiltInSchedulerConfig, - CloudComposerSchedulerConfig, - MWAASchedulerConfig, + AirflowSchedulerConfig as AirflowSchedulerConfig, + BuiltInSchedulerConfig as BuiltInSchedulerConfig, + CloudComposerSchedulerConfig as CloudComposerSchedulerConfig, + MWAASchedulerConfig as MWAASchedulerConfig, ) diff --git a/sqlmesh/core/config/loader.py b/sqlmesh/core/config/loader.py index 435dee7e8..78d0a9bb0 100644 --- a/sqlmesh/core/config/loader.py +++ b/sqlmesh/core/config/loader.py @@ -118,7 +118,7 @@ def load_config_from_paths( supported_model_defaults = ModelDefaultsConfig.all_fields() for default in non_python_config_dict.get("model_defaults", {}): - if not default in supported_model_defaults: + if default not in supported_model_defaults: raise ConfigError( f"'{default}' is not a valid model default configuration key. Please remove it from the `model_defaults` specification in your config file." 
) diff --git a/sqlmesh/core/console.py b/sqlmesh/core/console.py index a42a604b2..bd3db5573 100644 --- a/sqlmesh/core/console.py +++ b/sqlmesh/core/console.py @@ -464,7 +464,7 @@ def start_migration_progress(self, total_tasks: int) -> None: """Indicates that a new migration progress has begun.""" if self.migration_progress is None: self.migration_progress = Progress( - TextColumn(f"[bold blue]Migrating snapshots", justify="right"), + TextColumn("[bold blue]Migrating snapshots", justify="right"), BarColumn(bar_width=40), "[progress.percentage]{task.percentage:>3.1f}%", "•", @@ -476,7 +476,7 @@ def start_migration_progress(self, total_tasks: int) -> None: self.migration_progress.start() self.migration_task = self.migration_progress.add_task( - f"Migrating snapshots...", + "Migrating snapshots...", total=total_tasks, ) @@ -590,7 +590,7 @@ def _get_ignored_tree( environment_naming_info: EnvironmentNamingInfo, default_catalog: t.Optional[str], ) -> Tree: - ignored = Tree(f"[bold][ignored]Ignored Models (Expected Plan Start):") + ignored = Tree("[bold][ignored]Ignored Models (Expected Plan Start):") for s_id in ignored_snapshot_ids: snapshot = snapshots[s_id] ignored.add( @@ -642,7 +642,7 @@ def _show_summary_tree_for( tree = Tree(f"[bold]{header}:") if added_snapshot_ids: - added_tree = Tree(f"[bold][added]Added:") + added_tree = Tree("[bold][added]Added:") for s_id in added_snapshot_ids: snapshot = context_diff.snapshots[s_id] added_tree.add( @@ -650,7 +650,7 @@ def _show_summary_tree_for( ) tree.add(added_tree) if removed_snapshot_ids: - removed_tree = Tree(f"[bold][removed]Removed:") + removed_tree = Tree("[bold][removed]Removed:") for s_id in removed_snapshot_ids: snapshot_table_info = context_diff.removed_snapshots[s_id] removed_tree.add( @@ -658,9 +658,9 @@ def _show_summary_tree_for( ) tree.add(removed_tree) if modified_snapshot_ids: - direct = Tree(f"[bold][direct]Directly Modified:") - indirect = Tree(f"[bold][indirect]Indirectly Modified:") - metadata = Tree(f"[bold][metadata]Metadata Updated:") + direct = Tree("[bold][direct]Directly Modified:") + indirect = Tree("[bold][indirect]Indirectly Modified:") + metadata = Tree("[bold][metadata]Metadata Updated:") for s_id in modified_snapshot_ids: name = s_id.name display_name = context_diff.snapshots[s_id].display_name( @@ -745,7 +745,7 @@ def _prompt_categorize( for child_sid in sorted(plan.indirectly_modified.get(snapshot.snapshot_id, set())): child_snapshot = plan.context_diff.snapshots[child_sid] if not indirect_tree: - indirect_tree = Tree(f"[indirect]Indirectly Modified Children:") + indirect_tree = Tree("[indirect]Indirectly Modified Children:") tree.add(indirect_tree) indirect_tree.add( f"[indirect]{child_snapshot.display_name(plan.environment_naming_info, default_catalog)}" @@ -771,7 +771,7 @@ def _show_categorized_snapshots(self, plan: Plan, default_catalog: t.Optional[st for child_sid in sorted(plan.indirectly_modified.get(snapshot.snapshot_id, set())): child_snapshot = context_diff.snapshots[child_sid] if not indirect_tree: - indirect_tree = Tree(f"[indirect]Indirectly Modified Children:") + indirect_tree = Tree("[indirect]Indirectly Modified Children:") tree.add(indirect_tree) child_category_str = SNAPSHOT_CHANGE_CATEGORY_STR[child_snapshot.change_category] indirect_tree.add( @@ -868,7 +868,7 @@ def _prompt_backfill( def _prompt_promote(self, plan_builder: PlanBuilder) -> None: if self._confirm( - f"Apply - Virtual Update", + "Apply - Virtual Update", ): plan_builder.apply() @@ -1072,10 +1072,12 @@ def __init__( **kwargs: 
t.Any, ) -> None: import ipywidgets as widgets + from IPython import get_ipython from IPython.display import display as ipython_display super().__init__(console, **kwargs) - self.display = display or get_ipython().user_ns.get("display", ipython_display) # type: ignore + + self.display = display or get_ipython().user_ns.get("display", ipython_display) self.missing_dates_output = widgets.Output() self.dynamic_options_after_categorization_output = widgets.VBox() @@ -1424,7 +1426,7 @@ def show_model_difference_summary( } added_snapshot_models = {s for s in added_snapshots if s.is_model} if added_snapshot_models: - self._print(f"\n**Added Models:**") + self._print("\n**Added Models:**") for snapshot in sorted(added_snapshot_models): self._print( f"- `{snapshot.display_name(environment_naming_info, default_catalog)}`" @@ -1432,7 +1434,7 @@ def show_model_difference_summary( added_snapshot_audits = {s for s in added_snapshots if s.is_audit} if added_snapshot_audits: - self._print(f"\n**Added Standalone Audits:**") + self._print("\n**Added Standalone Audits:**") for snapshot in sorted(added_snapshot_audits): self._print( f"- `{snapshot.display_name(environment_naming_info, default_catalog)}`" @@ -1445,7 +1447,7 @@ def show_model_difference_summary( } removed_model_snapshot_table_infos = {s for s in removed_snapshot_table_infos if s.is_model} if removed_model_snapshot_table_infos: - self._print(f"\n**Removed Models:**") + self._print("\n**Removed Models:**") for snapshot_table_info in sorted(removed_model_snapshot_table_infos): self._print( f"- `{snapshot_table_info.display_name(environment_naming_info, default_catalog)}`" @@ -1453,7 +1455,7 @@ def show_model_difference_summary( removed_audit_snapshot_table_infos = {s for s in removed_snapshot_table_infos if s.is_audit} if removed_audit_snapshot_table_infos: - self._print(f"\n**Removed Standalone Audits:**") + self._print("\n**Removed Standalone Audits:**") for snapshot_table_info in sorted(removed_audit_snapshot_table_infos): self._print( f"- `{snapshot_table_info.display_name(environment_naming_info, default_catalog)}`" @@ -1476,7 +1478,7 @@ def show_model_difference_summary( elif context_diff.metadata_updated(snapshot.name): metadata_modified.append(snapshot) if directly_modified: - self._print(f"\n**Directly Modified:**") + self._print("\n**Directly Modified:**") for snapshot in sorted(directly_modified): self._print( f"- `{snapshot.display_name(environment_naming_info, default_catalog)}`" @@ -1484,19 +1486,19 @@ def show_model_difference_summary( if not no_diff: self._print(f"```diff\n{context_diff.text_diff(snapshot.name)}\n```") if indirectly_modified: - self._print(f"\n**Indirectly Modified:**") + self._print("\n**Indirectly Modified:**") for snapshot in sorted(indirectly_modified): self._print( f"- `{snapshot.display_name(environment_naming_info, default_catalog)}`" ) if metadata_modified: - self._print(f"\n**Metadata Updated:**") + self._print("\n**Metadata Updated:**") for snapshot in sorted(metadata_modified): self._print( f"- `{snapshot.display_name(environment_naming_info, default_catalog)}`" ) if ignored_snapshot_ids: - self._print(f"\n**Ignored Models (Expected Plan Start):**") + self._print("\n**Ignored Models (Expected Plan Start):**") for s_id in sorted(ignored_snapshot_ids): snapshot = context_diff.snapshots[s_id] self._print( @@ -1536,7 +1538,7 @@ def _show_categorized_snapshots(self, plan: Plan, default_catalog: t.Optional[st for child_sid in sorted(plan.indirectly_modified.get(snapshot.snapshot_id, set())): child_snapshot = 
context_diff.snapshots[child_sid] if not indirect_tree: - indirect_tree = Tree(f"[indirect]Indirectly Modified Children:") + indirect_tree = Tree("[indirect]Indirectly Modified Children:") tree.add(indirect_tree) child_category_str = SNAPSHOT_CHANGE_CATEGORY_STR[child_snapshot.change_category] indirect_tree.add( @@ -1651,7 +1653,7 @@ def start_creation_progress( ) -> None: """Indicates that a new creation progress has begun.""" self.model_creation_status = (0, total_tasks) - print(f"Starting Creating New Model Versions") + print("Starting Creating New Model Versions") def update_creation_progress(self, snapshot: SnapshotInfoLike) -> None: """Update the snapshot creation progress.""" @@ -1692,7 +1694,7 @@ def stop_promotion_progress(self, success: bool = True) -> None: def start_migration_progress(self, total_tasks: int) -> None: """Indicates that a new migration progress has begun.""" self.migration_status = (0, total_tasks) - print(f"Starting Migration") + print("Starting Migration") def update_migration_progress(self, num_tasks: int) -> None: """Update the migration progress.""" diff --git a/sqlmesh/core/dialect.py b/sqlmesh/core/dialect.py index ac31df419..0382b4ac8 100644 --- a/sqlmesh/core/dialect.py +++ b/sqlmesh/core/dialect.py @@ -384,7 +384,9 @@ def _parse_types( allow_identifiers: bool = True, ) -> t.Optional[exp.Expression]: start = self._curr - parsed_type = self.__parse_types(check_func=check_func, schema=schema, allow_identifiers=allow_identifiers) # type: ignore + parsed_type = self.__parse_types( # type: ignore + check_func=check_func, schema=schema, allow_identifiers=allow_identifiers + ) if schema and parsed_type: parsed_type.meta["sql"] = self._find_sql(start, self._prev) @@ -798,7 +800,8 @@ def extend_sqlglot() -> None: DColonCast: lambda self, e: f"{self.sql(e, 'this')}::{self.sql(e, 'to')}", Jinja: lambda self, e: e.name, JinjaQuery: lambda self, e: f"{JINJA_QUERY_BEGIN};\n{e.name}\n{JINJA_END};", - JinjaStatement: lambda self, e: f"{JINJA_STATEMENT_BEGIN};\n{e.name}\n{JINJA_END};", + JinjaStatement: lambda self, + e: f"{JINJA_STATEMENT_BEGIN};\n{e.name}\n{JINJA_END};", MacroDef: lambda self, e: f"@DEF({self.sql(e.this)}, {self.sql(e.expression)})", MacroFunc: _macro_func_sql, MacroStrReplace: lambda self, e: f"@{self.sql(e.this)}", diff --git a/sqlmesh/core/engine_adapter/base_postgres.py b/sqlmesh/core/engine_adapter/base_postgres.py index ecf72b218..7b49982da 100644 --- a/sqlmesh/core/engine_adapter/base_postgres.py +++ b/sqlmesh/core/engine_adapter/base_postgres.py @@ -171,7 +171,10 @@ def _get_data_objects( df = self.fetchdf(query) return [ DataObject( - catalog=catalog, schema=row.schema_name, name=row.name, type=DataObjectType.from_str(row.type) # type: ignore + catalog=catalog, + schema=row.schema_name, + name=row.name, + type=DataObjectType.from_str(row.type), # type: ignore ) for row in df.itertuples() ] diff --git a/sqlmesh/core/engine_adapter/bigquery.py b/sqlmesh/core/engine_adapter/bigquery.py index 0f01fc3ac..00ff69738 100644 --- a/sqlmesh/core/engine_adapter/bigquery.py +++ b/sqlmesh/core/engine_adapter/bigquery.py @@ -669,7 +669,6 @@ def _get_data_objects( # resort to using SQL instead. 
schema = to_schema(schema_name) catalog = schema.catalog or self.get_current_catalog() - schema_sql = schema.sql(dialect=self.dialect) query = exp.select( exp.column("table_catalog").as_("catalog"), exp.column("table_name").as_("name"), diff --git a/sqlmesh/core/engine_adapter/databricks.py b/sqlmesh/core/engine_adapter/databricks.py index 5bf2a7a41..65cf9bd5f 100644 --- a/sqlmesh/core/engine_adapter/databricks.py +++ b/sqlmesh/core/engine_adapter/databricks.py @@ -54,7 +54,7 @@ def can_access_spark_session(cls) -> bool: if RuntimeEnv.get().is_databricks: return True try: - from databricks.connect import DatabricksSession # type: ignore + from databricks.connect import DatabricksSession # noqa return True except ImportError: diff --git a/sqlmesh/core/engine_adapter/mssql.py b/sqlmesh/core/engine_adapter/mssql.py index 54bafcb5b..aec63283b 100644 --- a/sqlmesh/core/engine_adapter/mssql.py +++ b/sqlmesh/core/engine_adapter/mssql.py @@ -131,11 +131,13 @@ def drop_schema( # _get_data_objects is catalog-specific, so these can't accidentally drop view/tables in another catalog if obj.type == DataObjectType.VIEW: self.drop_view( - ".".join([obj.schema_name, obj.name]), ignore_if_not_exists=ignore_if_not_exists # type: ignore + ".".join([obj.schema_name, obj.name]), + ignore_if_not_exists=ignore_if_not_exists, ) else: self.drop_table( - ".".join([obj.schema_name, obj.name]), exists=ignore_if_not_exists # type: ignore + ".".join([obj.schema_name, obj.name]), + exists=ignore_if_not_exists, ) super().drop_schema(schema_name, ignore_if_not_exists=ignore_if_not_exists, cascade=False) @@ -175,7 +177,9 @@ def query_factory() -> Query: columns_to_types_create = columns_to_types.copy() self._convert_df_datetime(df, columns_to_types_create) self.create_table(temp_table, columns_to_types_create) - rows: t.List[t.Tuple[t.Any, ...]] = list(df.replace({np.nan: None}).itertuples(index=False, name=None)) # type: ignore + rows: t.List[t.Tuple[t.Any, ...]] = list( + df.replace({np.nan: None}).itertuples(index=False, name=None) + ) # type: ignore conn = self._connection_pool.get() conn.bulk_copy(temp_table.sql(dialect=self.dialect), rows) return exp.select(*self._casted_columns(columns_to_types)).from_(temp_table) # type: ignore diff --git a/sqlmesh/core/engine_adapter/mysql.py b/sqlmesh/core/engine_adapter/mysql.py index 51963ce0c..3642cfc09 100644 --- a/sqlmesh/core/engine_adapter/mysql.py +++ b/sqlmesh/core/engine_adapter/mysql.py @@ -93,7 +93,9 @@ def _get_data_objects( df = self.fetchdf(query) return [ DataObject( - schema=row.schema_name, name=row.name, type=DataObjectType.from_str(row.type) # type: ignore + schema=row.schema_name, + name=row.name, + type=DataObjectType.from_str(row.type), # type: ignore ) for row in df.itertuples() ] diff --git a/sqlmesh/core/engine_adapter/redshift.py b/sqlmesh/core/engine_adapter/redshift.py index e3b982bcc..9b269fc1e 100644 --- a/sqlmesh/core/engine_adapter/redshift.py +++ b/sqlmesh/core/engine_adapter/redshift.py @@ -282,7 +282,10 @@ def _get_data_objects( df = self.fetchdf(query) return [ DataObject( - catalog=catalog, schema=row.schema_name, name=row.name, type=DataObjectType.from_str(row.type) # type: ignore + catalog=catalog, + schema=row.schema_name, + name=row.name, + type=DataObjectType.from_str(row.type), # type: ignore ) for row in df.itertuples() ] diff --git a/sqlmesh/core/engine_adapter/snowflake.py b/sqlmesh/core/engine_adapter/snowflake.py index ae5972aa2..ea18aaca5 100644 --- a/sqlmesh/core/engine_adapter/snowflake.py +++ 
b/sqlmesh/core/engine_adapter/snowflake.py @@ -101,10 +101,14 @@ def query_factory() -> Query: if kind.is_type("date"): # type: ignore df[column] = pd.to_datetime(df[column]).dt.date # type: ignore elif getattr(df.dtypes[column], "tz", None) is not None: # type: ignore - df[column] = pd.to_datetime(df[column]).dt.strftime("%Y-%m-%d %H:%M:%S.%f%z") # type: ignore + df[column] = pd.to_datetime(df[column]).dt.strftime( + "%Y-%m-%d %H:%M:%S.%f%z" + ) # type: ignore # https://github.com/snowflakedb/snowflake-connector-python/issues/1677 else: # type: ignore - df[column] = pd.to_datetime(df[column]).dt.strftime("%Y-%m-%d %H:%M:%S.%f") # type: ignore + df[column] = pd.to_datetime(df[column]).dt.strftime( + "%Y-%m-%d %H:%M:%S.%f" + ) # type: ignore self.create_table(temp_table, columns_to_types, exists=False) write_pandas( self._connection_pool.get(), @@ -149,7 +153,6 @@ def _get_data_objects( schema = to_schema(schema_name) catalog_name = schema.catalog or self.get_current_catalog() - schema_sql = schema.sql(dialect=self.dialect) query = ( exp.select( exp.column("TABLE_CATALOG").as_("catalog"), diff --git a/sqlmesh/core/engine_adapter/spark.py b/sqlmesh/core/engine_adapter/spark.py index 3c781e39b..f83db1d0f 100644 --- a/sqlmesh/core/engine_adapter/spark.py +++ b/sqlmesh/core/engine_adapter/spark.py @@ -121,12 +121,12 @@ def spark_to_sqlglot_types(cls, input: spark_types.StructType) -> t.Dict[str, ex def spark_complex_to_sqlglot_complex( complex_type: t.Union[ spark_types.StructType, spark_types.ArrayType, spark_types.MapType - ] + ], ) -> exp.DataType: def get_fields( complex_type: t.Union[ spark_types.StructType, spark_types.ArrayType, spark_types.MapType - ] + ], ) -> t.Sequence[spark_types.DataType]: if isinstance(complex_type, spark_types.StructType): return complex_type.fields diff --git a/sqlmesh/core/engine_adapter/trino.py b/sqlmesh/core/engine_adapter/trino.py index 2a8e6fd47..fbeea496b 100644 --- a/sqlmesh/core/engine_adapter/trino.py +++ b/sqlmesh/core/engine_adapter/trino.py @@ -176,12 +176,15 @@ def _df_to_source_queries( batch_size: int, target_table: TableName, ) -> t.List[SourceQuery]: + assert isinstance(df, pd.DataFrame) + # Trino does not accept timestamps in ISOFORMAT that include the "T". `execution_time` is stored in # Pandas with that format, so we convert the column to a string with the proper format and CAST to # timestamp in Trino. 
for column, kind in (columns_to_types or {}).items(): - if is_datetime64_any_dtype(df.dtypes[column]) and getattr(df.dtypes[column], "tz", None) is not None: # type: ignore - df[column] = pd.to_datetime(df[column]).map(lambda x: x.isoformat(" ")) # type: ignore + dtype = df.dtypes[column] + if is_datetime64_any_dtype(dtype) and getattr(dtype, "tz", None) is not None: + df[column] = pd.to_datetime(df[column]).map(lambda x: x.isoformat(" ")) return super()._df_to_source_queries(df, columns_to_types, batch_size, target_table) diff --git a/sqlmesh/core/metric/__init__.py b/sqlmesh/core/metric/__init__.py index 6379a4b9d..e3dea9d8c 100644 --- a/sqlmesh/core/metric/__init__.py +++ b/sqlmesh/core/metric/__init__.py @@ -1,7 +1,7 @@ from sqlmesh.core.metric.definition import ( - Metric, - MetricMeta, - expand_metrics, - load_metric_ddl, + Metric as Metric, + MetricMeta as MetricMeta, + expand_metrics as expand_metrics, + load_metric_ddl as load_metric_ddl, ) -from sqlmesh.core.metric.rewriter import rewrite +from sqlmesh.core.metric.rewriter import rewrite as rewrite diff --git a/sqlmesh/core/model/__init__.py b/sqlmesh/core/model/__init__.py index cdd608292..3c1fda093 100644 --- a/sqlmesh/core/model/__init__.py +++ b/sqlmesh/core/model/__init__.py @@ -1,32 +1,35 @@ -from sqlmesh.core.model.cache import ModelCache, OptimizedQueryCache -from sqlmesh.core.model.decorator import model +from sqlmesh.core.model.cache import ( + ModelCache as ModelCache, + OptimizedQueryCache as OptimizedQueryCache, +) +from sqlmesh.core.model.decorator import model as model from sqlmesh.core.model.definition import ( - Model, - PythonModel, - SeedModel, - SqlModel, - create_external_model, - create_python_model, - create_seed_model, - create_sql_model, - load_sql_based_model, + Model as Model, + PythonModel as PythonModel, + SeedModel as SeedModel, + SqlModel as SqlModel, + create_external_model as create_external_model, + create_python_model as create_python_model, + create_seed_model as create_seed_model, + create_sql_model as create_sql_model, + load_sql_based_model as load_sql_based_model, ) from sqlmesh.core.model.kind import ( - EmbeddedKind, - ExternalKind, - FullKind, - IncrementalByTimeRangeKind, - IncrementalByUniqueKeyKind, - IncrementalUnmanagedKind, - ModelKind, - ModelKindMixin, - ModelKindName, - SCDType2ByColumnKind, - SCDType2ByTimeKind, - SeedKind, - TimeColumn, - ViewKind, - model_kind_validator, + EmbeddedKind as EmbeddedKind, + ExternalKind as ExternalKind, + FullKind as FullKind, + IncrementalByTimeRangeKind as IncrementalByTimeRangeKind, + IncrementalByUniqueKeyKind as IncrementalByUniqueKeyKind, + IncrementalUnmanagedKind as IncrementalUnmanagedKind, + ModelKind as ModelKind, + ModelKindMixin as ModelKindMixin, + ModelKindName as ModelKindName, + SCDType2ByColumnKind as SCDType2ByColumnKind, + SCDType2ByTimeKind as SCDType2ByTimeKind, + SeedKind as SeedKind, + TimeColumn as TimeColumn, + ViewKind as ViewKind, + model_kind_validator as model_kind_validator, ) -from sqlmesh.core.model.meta import ModelMeta -from sqlmesh.core.model.seed import Seed +from sqlmesh.core.model.meta import ModelMeta as ModelMeta +from sqlmesh.core.model.seed import Seed as Seed diff --git a/sqlmesh/core/model/decorator.py b/sqlmesh/core/model/decorator.py index 55b5f6c77..2ff24eb66 100644 --- a/sqlmesh/core/model/decorator.py +++ b/sqlmesh/core/model/decorator.py @@ -38,7 +38,7 @@ def __init__(self, name: str, is_sql: bool = False, **kwargs: t.Any) -> None: f"""Python model "{name}"'s `kind` argument was passed a SQLMesh 
`{type(kind).__name__}` object. This may result in unexpected behavior - provide a dictionary instead.""" ) elif isinstance(kind, dict): - if not "name" in kind or not isinstance(kind.get("name"), ModelKindName): + if "name" not in kind or not isinstance(kind.get("name"), ModelKindName): raise ConfigError( f"""Python model "{name}"'s `kind` dictionary must contain a `name` key with a valid ModelKindName enum value.""" ) diff --git a/sqlmesh/core/model/definition.py b/sqlmesh/core/model/definition.py index 297cccef5..a86d13bda 100644 --- a/sqlmesh/core/model/definition.py +++ b/sqlmesh/core/model/definition.py @@ -124,7 +124,7 @@ def render( end: t.Optional[TimeLike] = None, execution_time: t.Optional[TimeLike] = None, **kwargs: t.Any, - ) -> t.Generator[QueryOrDF, None, None]: + ) -> t.Iterator[QueryOrDF]: """Renders the content of this model in a form of either a SELECT query, executing which the data for this model can be fetched, or a dataframe object which contains the data itself. @@ -1142,7 +1142,7 @@ def render( end: t.Optional[TimeLike] = None, execution_time: t.Optional[TimeLike] = None, **kwargs: t.Any, - ) -> t.Generator[QueryOrDF, None, None]: + ) -> t.Iterator[QueryOrDF]: self._ensure_hydrated() date_columns = [] @@ -1163,12 +1163,15 @@ def render( # convert all date/time types to native pandas timestamp for column in [*date_columns, *datetime_columns]: df[column] = pd.to_datetime(df[column]) + # extract datetime.date from pandas timestamp for DATE columns for column in date_columns: df[column] = df[column].dt.date + df[bool_columns] = df[bool_columns].apply(lambda i: str_to_bool(str(i))) df.loc[:, string_columns] = df[string_columns].mask( - cond=lambda x: x.notna(), other=df[string_columns].astype(str) # type: ignore + cond=lambda x: x.notna(), # type: ignore + other=df[string_columns].astype(str), # type: ignore ) yield df @@ -1318,7 +1321,7 @@ def render( end: t.Optional[TimeLike] = None, execution_time: t.Optional[TimeLike] = None, **kwargs: t.Any, - ) -> t.Generator[QueryOrDF, None, None]: + ) -> t.Iterator[QueryOrDF]: env = prepare_env(self.python_env) start, end = make_inclusive(start or c.EPOCH, end or c.EPOCH) execution_time = to_datetime(execution_time or c.EPOCH) @@ -1940,11 +1943,11 @@ def _python_env( if macro_ref.package is None and macro_ref.name in macros: used_macros[macro_ref.name] = macros[macro_ref.name] - for name, macro in used_macros.items(): - if isinstance(macro, Executable): - serialized_env[name] = macro - elif not hasattr(macro, c.SQLMESH_BUILTIN): - build_env(macro.func, env=python_env, name=name, path=module_path) + for name, used_macro in used_macros.items(): + if isinstance(used_macro, Executable): + serialized_env[name] = used_macro + elif not hasattr(used_macro, c.SQLMESH_BUILTIN): + build_env(used_macro.func, env=python_env, name=name, path=module_path) serialized_env.update(serialize_env(python_env, path=module_path)) diff --git a/sqlmesh/core/model/meta.py b/sqlmesh/core/model/meta.py index b038b64c4..60f45f671 100644 --- a/sqlmesh/core/model/meta.py +++ b/sqlmesh/core/model/meta.py @@ -134,7 +134,9 @@ def _validate_value_or_tuple( cls, v: t.Dict[str, t.Any], values: t.Dict[str, t.Any], normalize: bool = False ) -> t.Any: dialect = values.get("dialect") - _normalize = lambda v: normalize_identifiers(v, dialect=dialect) if normalize else v + + def _normalize(value: t.Any) -> t.Any: + return normalize_identifiers(value, dialect=dialect) if normalize else value if isinstance(v, (exp.Tuple, exp.Array)): return [_normalize(e).name for e in 
v.expressions] diff --git a/sqlmesh/core/notification_target.py b/sqlmesh/core/notification_target.py index acf37c09d..55269c10b 100644 --- a/sqlmesh/core/notification_target.py +++ b/sqlmesh/core/notification_target.py @@ -273,7 +273,6 @@ def send( exc: t.Optional[str] = None, **kwargs: t.Any, ) -> None: - status_emoji = { NotificationStatus.PROGRESS: slack.SlackAlertIcon.START, NotificationStatus.SUCCESS: slack.SlackAlertIcon.SUCCESS, diff --git a/sqlmesh/core/plan/__init__.py b/sqlmesh/core/plan/__init__.py index e6cf01c55..139979e24 100644 --- a/sqlmesh/core/plan/__init__.py +++ b/sqlmesh/core/plan/__init__.py @@ -1,9 +1,13 @@ -from sqlmesh.core.plan.builder import PlanBuilder -from sqlmesh.core.plan.definition import Plan, PlanStatus, SnapshotIntervals +from sqlmesh.core.plan.builder import PlanBuilder as PlanBuilder +from sqlmesh.core.plan.definition import ( + Plan as Plan, + PlanStatus as PlanStatus, + SnapshotIntervals as SnapshotIntervals, +) from sqlmesh.core.plan.evaluator import ( - AirflowPlanEvaluator, - BuiltInPlanEvaluator, - MWAAPlanEvaluator, - PlanEvaluator, - update_intervals_for_new_snapshots, + AirflowPlanEvaluator as AirflowPlanEvaluator, + BuiltInPlanEvaluator as BuiltInPlanEvaluator, + MWAAPlanEvaluator as MWAAPlanEvaluator, + PlanEvaluator as PlanEvaluator, + update_intervals_for_new_snapshots as update_intervals_for_new_snapshots, ) diff --git a/sqlmesh/core/plan/builder.py b/sqlmesh/core/plan/builder.py index e14448882..e223754ca 100644 --- a/sqlmesh/core/plan/builder.py +++ b/sqlmesh/core/plan/builder.py @@ -171,7 +171,7 @@ def set_choice(self, snapshot: Snapshot, choice: SnapshotChangeCategory) -> Plan def apply(self) -> None: """Builds and applies the plan.""" if not self._apply: - raise SQLMeshError(f"Plan was not initialized with an applier.") + raise SQLMeshError("Plan was not initialized with an applier.") self._apply(self.build()) def build(self) -> Plan: @@ -543,9 +543,8 @@ def _set_choice( child_snapshot.categorize_as(SnapshotChangeCategory.INDIRECT_NON_BREAKING) for upstream_id in directly_modified: - if ( - upstream_id == snapshot.snapshot_id - or child_s_id not in indirectly_modified.get(upstream_id, set()) + if upstream_id == snapshot.snapshot_id or child_s_id not in indirectly_modified.get( + upstream_id, set() ): continue diff --git a/sqlmesh/core/plan/evaluator.py b/sqlmesh/core/plan/evaluator.py index c3e3b80ff..a4cdf66b8 100644 --- a/sqlmesh/core/plan/evaluator.py +++ b/sqlmesh/core/plan/evaluator.py @@ -343,7 +343,6 @@ def _apply_plan(self, plan: Plan, plan_request_id: str) -> None: class StateBasedAirflowPlanEvaluator(BaseAirflowPlanEvaluator): - backfill_concurrent_tasks: int ddl_concurrent_tasks: int notification_targets: t.Optional[t.List[NotificationTarget]] diff --git a/sqlmesh/core/snapshot/__init__.py b/sqlmesh/core/snapshot/__init__.py index de7385234..b3c65739d 100644 --- a/sqlmesh/core/snapshot/__init__.py +++ b/sqlmesh/core/snapshot/__init__.py @@ -1,29 +1,29 @@ -from sqlmesh.core.snapshot.categorizer import categorize_change +from sqlmesh.core.snapshot.categorizer import categorize_change as categorize_change from sqlmesh.core.snapshot.definition import ( - DeployabilityIndex, - Intervals, - Node, - QualifiedViewName, - Snapshot, - SnapshotChangeCategory, - SnapshotDataVersion, - SnapshotFingerprint, - SnapshotId, - SnapshotIdLike, - SnapshotInfoLike, - SnapshotIntervals, - SnapshotNameVersion, - SnapshotNameVersionLike, - SnapshotTableCleanupTask, - SnapshotTableInfo, - earliest_start_date, - fingerprint_from_node, - 
has_paused_forward_only, - merge_intervals, - missing_intervals, - snapshots_to_dag, - start_date, - table_name, - to_table_mapping, + DeployabilityIndex as DeployabilityIndex, + Intervals as Intervals, + Node as Node, + QualifiedViewName as QualifiedViewName, + Snapshot as Snapshot, + SnapshotChangeCategory as SnapshotChangeCategory, + SnapshotDataVersion as SnapshotDataVersion, + SnapshotFingerprint as SnapshotFingerprint, + SnapshotId as SnapshotId, + SnapshotIdLike as SnapshotIdLike, + SnapshotInfoLike as SnapshotInfoLike, + SnapshotIntervals as SnapshotIntervals, + SnapshotNameVersion as SnapshotNameVersion, + SnapshotNameVersionLike as SnapshotNameVersionLike, + SnapshotTableCleanupTask as SnapshotTableCleanupTask, + SnapshotTableInfo as SnapshotTableInfo, + earliest_start_date as earliest_start_date, + fingerprint_from_node as fingerprint_from_node, + has_paused_forward_only as has_paused_forward_only, + merge_intervals as merge_intervals, + missing_intervals as missing_intervals, + snapshots_to_dag as snapshots_to_dag, + start_date as start_date, + table_name as table_name, + to_table_mapping as to_table_mapping, ) -from sqlmesh.core.snapshot.evaluator import SnapshotEvaluator +from sqlmesh.core.snapshot.evaluator import SnapshotEvaluator as SnapshotEvaluator diff --git a/sqlmesh/core/snapshot/definition.py b/sqlmesh/core/snapshot/definition.py index 6e5f82c09..20912651c 100644 --- a/sqlmesh/core/snapshot/definition.py +++ b/sqlmesh/core/snapshot/definition.py @@ -740,8 +740,7 @@ def merge_intervals(self, other: t.Union[Snapshot, SnapshotIntervals]) -> None: if self.identifier == other.identifier or ( # Indirect Non-Breaking snapshots share the dev table with its previous version. # The same applies to migrated snapshots. - (self.is_indirect_non_breaking or self.migrated) - and other.snapshot_id in previous_ids + (self.is_indirect_non_breaking or self.migrated) and other.snapshot_id in previous_ids ): for start, end in other.dev_intervals: self.add_interval(start, end, is_dev=True) diff --git a/sqlmesh/core/snapshot/evaluator.py b/sqlmesh/core/snapshot/evaluator.py index 5bf0e393c..1ade05bba 100644 --- a/sqlmesh/core/snapshot/evaluator.py +++ b/sqlmesh/core/snapshot/evaluator.py @@ -30,6 +30,7 @@ from functools import reduce import pandas as pd +from pyspark.sql.dataframe import DataFrame as PySparkDataFrame from sqlglot import exp, select from sqlglot.executor import execute @@ -543,14 +544,20 @@ def apply(query_or_df: QueryOrDF, index: int = 0) -> None: if limit is not None: query_or_df = next(queries_or_dfs) - if isinstance(query_or_df, exp.Select): - existing_limit = query_or_df.args.get("limit") - if existing_limit: - limit = min( - limit, - execute(exp.select(existing_limit.expression)).rows[0][0], - ) - return query_or_df.head(limit) if hasattr(query_or_df, "head") else self.adapter._fetch_native_df(query_or_df.limit(limit)) # type: ignore + if isinstance(query_or_df, PySparkDataFrame): + return query_or_df.limit(limit) + if isinstance(query_or_df, pd.DataFrame): + return query_or_df.head(limit) + + assert isinstance(query_or_df, exp.Query) + + existing_limit = query_or_df.args.get("limit") + if existing_limit: + limit = min(limit, execute(exp.select(existing_limit.expression)).rows[0][0]) + assert limit is not None + + return self.adapter._fetch_native_df(query_or_df.limit(limit)) + # DataFrames, unlike SQL expressions, can provide partial results by yielding dataframes. 
As a result, # if the engine supports INSERT OVERWRITE or REPLACE WHERE and the snapshot is incremental by time range, we risk # having a partial result since each dataframe write can re-truncate partitions. To avoid this, we diff --git a/sqlmesh/core/state_sync/__init__.py b/sqlmesh/core/state_sync/__init__.py index 8662a1ddd..8b78e5749 100644 --- a/sqlmesh/core/state_sync/__init__.py +++ b/sqlmesh/core/state_sync/__init__.py @@ -14,7 +14,14 @@ adapter to read and write state to the underlying data store. """ -from sqlmesh.core.state_sync.base import StateReader, StateSync, Versions -from sqlmesh.core.state_sync.cache import CachingStateSync -from sqlmesh.core.state_sync.common import CommonStateSyncMixin, cleanup_expired_views -from sqlmesh.core.state_sync.engine_adapter import EngineAdapterStateSync +from sqlmesh.core.state_sync.base import ( + StateReader as StateReader, + StateSync as StateSync, + Versions as Versions, +) +from sqlmesh.core.state_sync.cache import CachingStateSync as CachingStateSync +from sqlmesh.core.state_sync.common import ( + CommonStateSyncMixin as CommonStateSyncMixin, + cleanup_expired_views as cleanup_expired_views, +) +from sqlmesh.core.state_sync.engine_adapter import EngineAdapterStateSync as EngineAdapterStateSync diff --git a/sqlmesh/core/state_sync/engine_adapter.py b/sqlmesh/core/state_sync/engine_adapter.py index b1bea4acb..797aeaaac 100644 --- a/sqlmesh/core/state_sync/engine_adapter.py +++ b/sqlmesh/core/state_sync/engine_adapter.py @@ -397,7 +397,7 @@ def nodes_exist(self, names: t.Iterable[str], exclude_external: bool = False) -> ) if exclude_external: query = query.where(exp.column("kind_name").neq(ModelKindName.EXTERNAL.value)) - return {name for name, in self._fetchall(query)} + return {name for (name,) in self._fetchall(query)} def reset(self, default_catalog: t.Optional[str]) -> None: """Resets the state store to the state when it was first initialized.""" @@ -1279,7 +1279,7 @@ def _snapshot_name_version_filter( batches = self._snapshot_batches(name_versions) if not name_versions: - return exp.false() + yield exp.false() elif self.engine_adapter.SUPPORTS_TUPLE_IN: for versions in batches: yield t.cast( diff --git a/sqlmesh/core/table_diff.py b/sqlmesh/core/table_diff.py index 7c06080b0..05d914684 100644 --- a/sqlmesh/core/table_diff.py +++ b/sqlmesh/core/table_diff.py @@ -341,7 +341,7 @@ def name(e: exp.Expression) -> str: s_sample = sample[(sample["s_exists"] == 1) & (sample["rows_joined"] == 0)][ [ *[f"s__{c}" for c in index_cols], - *[f"s__{c}" for c in self.source_schema if not c in index_cols], + *[f"s__{c}" for c in self.source_schema if c not in index_cols], ] ] s_sample.rename( @@ -351,7 +351,7 @@ def name(e: exp.Expression) -> str: t_sample = sample[(sample["t_exists"] == 1) & (sample["rows_joined"] == 0)][ [ *[f"t__{c}" for c in index_cols], - *[f"t__{c}" for c in self.target_schema if not c in index_cols], + *[f"t__{c}" for c in self.target_schema if c not in index_cols], ] ] t_sample.rename( diff --git a/sqlmesh/core/test/__init__.py b/sqlmesh/core/test/__init__.py index db5464cde..aa2bcf45d 100644 --- a/sqlmesh/core/test/__init__.py +++ b/sqlmesh/core/test/__init__.py @@ -6,14 +6,14 @@ from sqlmesh.core.engine_adapter import EngineAdapter from sqlmesh.core.model import Model -from sqlmesh.core.test.definition import ModelTest, generate_test +from sqlmesh.core.test.definition import ModelTest as ModelTest, generate_test as generate_test from sqlmesh.core.test.discovery import ( - ModelTestMetadata, - filter_tests_by_patterns, - 
get_all_model_tests, - load_model_test_file, + ModelTestMetadata as ModelTestMetadata, + filter_tests_by_patterns as filter_tests_by_patterns, + get_all_model_tests as get_all_model_tests, + load_model_test_file as load_model_test_file, ) -from sqlmesh.core.test.result import ModelTextTestResult +from sqlmesh.core.test.result import ModelTextTestResult as ModelTextTestResult from sqlmesh.utils import UniqueKeyDict if t.TYPE_CHECKING: diff --git a/sqlmesh/core/test/definition.py b/sqlmesh/core/test/definition.py index 930371d9f..5469a4ff7 100644 --- a/sqlmesh/core/test/definition.py +++ b/sqlmesh/core/test/definition.py @@ -645,8 +645,7 @@ def generate_test( # ruamel.yaml does not support pandas Timestamps, so we must convert them to python # datetime or datetime.date objects based on column type inputs = { - models[dep] - .name: pandas_timestamp_to_pydatetime( + models[dep].name: pandas_timestamp_to_pydatetime( engine_adapter.fetchdf(query).apply(lambda col: col.map(_normalize_df_value)), models[dep].columns_to_types, ) diff --git a/sqlmesh/dbt/__init__.py b/sqlmesh/dbt/__init__.py index c92667c7f..690b1f528 100644 --- a/sqlmesh/dbt/__init__.py +++ b/sqlmesh/dbt/__init__.py @@ -1 +1,4 @@ -from sqlmesh.dbt.builtin import create_builtin_filters, create_builtin_globals +from sqlmesh.dbt.builtin import ( + create_builtin_filters as create_builtin_filters, + create_builtin_globals as create_builtin_globals, +) diff --git a/sqlmesh/dbt/adapter.py b/sqlmesh/dbt/adapter.py index 0c5a328d2..463ecb067 100644 --- a/sqlmesh/dbt/adapter.py +++ b/sqlmesh/dbt/adapter.py @@ -327,7 +327,6 @@ def _map_table_name(self, table: exp.Table) -> exp.Table: return exp.to_table(physical_table_name, dialect=self.dialect) def _relation_to_table(self, relation: BaseRelation) -> exp.Table: - table = exp.to_table(relation.render(), dialect=self.dialect) return exp.to_table(relation.render(), dialect=self.dialect) def _table_to_relation(self, table: exp.Table) -> BaseRelation: diff --git a/sqlmesh/dbt/column.py b/sqlmesh/dbt/column.py index d2765a818..327f7cd53 100644 --- a/sqlmesh/dbt/column.py +++ b/sqlmesh/dbt/column.py @@ -11,7 +11,7 @@ def yaml_to_columns( - yaml: t.Dict[str, ColumnConfig] | t.List[t.Dict[str, ColumnConfig]] + yaml: t.Dict[str, ColumnConfig] | t.List[t.Dict[str, ColumnConfig]], ) -> t.Dict[str, ColumnConfig]: columns = {} mappings: t.List[t.Dict[str, ColumnConfig]] = ensure_list(yaml) diff --git a/sqlmesh/dbt/loader.py b/sqlmesh/dbt/loader.py index 8807472fe..71b3b28ee 100644 --- a/sqlmesh/dbt/loader.py +++ b/sqlmesh/dbt/loader.py @@ -51,7 +51,12 @@ def sqlmesh_config( return Config( default_gateway=profile.target_name, - gateways={profile.target_name: GatewayConfig(connection=profile.target.to_sqlmesh(**target_to_sqlmesh_args), state_connection=state_connection)}, # type: ignore + gateways={ + profile.target_name: GatewayConfig( + connection=profile.target.to_sqlmesh(**target_to_sqlmesh_args), + state_connection=state_connection, + ) + }, # type: ignore loader=DbtLoader, model_defaults=model_defaults, variables=variables or {}, diff --git a/sqlmesh/dbt/manifest.py b/sqlmesh/dbt/manifest.py index 19fc45d9a..e910d8090 100644 --- a/sqlmesh/dbt/manifest.py +++ b/sqlmesh/dbt/manifest.py @@ -125,9 +125,9 @@ def _load_sources(self) -> None: **_config(source), **source.to_dict(), ) - self._sources_per_package[source.package_name][ - source_config.config_name - ] = source_config + self._sources_per_package[source.package_name][source_config.config_name] = ( + source_config + ) def _load_macros(self) -> None: 
for macro in self._manifest.macros.values(): diff --git a/sqlmesh/integrations/github/cicd/controller.py b/sqlmesh/integrations/github/cicd/controller.py index 08804d959..669f10fcc 100644 --- a/sqlmesh/integrations/github/cicd/controller.py +++ b/sqlmesh/integrations/github/cicd/controller.py @@ -705,7 +705,7 @@ def update_required_approval_check( def conclusion_handler( conclusion: GithubCheckConclusion, ) -> t.Tuple[GithubCheckConclusion, str, t.Optional[str]]: - test_summary = f"**List of possible required approvers:**\n" + test_summary = "**List of possible required approvers:**\n" for user in self._required_approvers: test_summary += f"- `{user.github_username or user.username}`\n" diff --git a/sqlmesh/schedulers/airflow/plan.py b/sqlmesh/schedulers/airflow/plan.py index 1fbf03f7c..4cf23162f 100644 --- a/sqlmesh/schedulers/airflow/plan.py +++ b/sqlmesh/schedulers/airflow/plan.py @@ -154,7 +154,9 @@ def create_plan_dag_spec( if deployability_index_for_creation.is_representative(s) } if request.no_gaps and not request.is_dev - else None if request.no_gaps else set() + else None + if request.no_gaps + else set() ) return common.PlanDagSpec( diff --git a/sqlmesh/schedulers/airflow/plugin.py b/sqlmesh/schedulers/airflow/plugin.py index 3a12b546e..ac64d1163 100644 --- a/sqlmesh/schedulers/airflow/plugin.py +++ b/sqlmesh/schedulers/airflow/plugin.py @@ -39,7 +39,7 @@ def on_load(cls, *args: t.Any, **kwargs: t.Any) -> None: versions = state_sync.get_versions(validate=False) if versions.schema_version != 0: raise SQLMeshError( - f"Must define `default_catalog` when creating `SQLMeshAirflow` object. See docs for more info: https://sqlmesh.readthedocs.io/en/stable/integrations/airflow/#airflow-cluster-configuration" + "Must define `default_catalog` when creating `SQLMeshAirflow` object. See docs for more info: https://sqlmesh.readthedocs.io/en/stable/integrations/airflow/#airflow-cluster-configuration" ) logger.info("Migrating SQLMesh state ...") state_sync.migrate(default_catalog=default_catalog) diff --git a/sqlmesh/utils/__init__.py b/sqlmesh/utils/__init__.py index d57a1a4c6..f25dfb510 100644 --- a/sqlmesh/utils/__init__.py +++ b/sqlmesh/utils/__init__.py @@ -169,9 +169,7 @@ def sys_path(*paths: Path) -> t.Iterator[None]: def format_exception(exception: BaseException) -> t.List[str]: if sys.version_info < (3, 10): - return traceback.format_exception( - type(exception), exception, exception.__traceback__ - ) # type: ignore + return traceback.format_exception(type(exception), exception, exception.__traceback__) # type: ignore else: return traceback.format_exception(exception) # type: ignore @@ -302,7 +300,7 @@ def groupby( def columns_to_types_to_struct( - columns_to_types: t.Union[t.Dict[str, exp.DataType], t.Dict[str, str]] + columns_to_types: t.Union[t.Dict[str, exp.DataType], t.Dict[str, str]], ) -> exp.DataType: """ Converts a dict of column names to types to a struct. 
diff --git a/sqlmesh/utils/date.py b/sqlmesh/utils/date.py index a6d7396af..1f8bf421b 100644 --- a/sqlmesh/utils/date.py +++ b/sqlmesh/utils/date.py @@ -7,10 +7,6 @@ from pandas.api.types import is_datetime64_any_dtype # type: ignore -warnings.filterwarnings( - "ignore", - message="The localize method is no longer necessary, as this time zone supports the fold attribute", -) from datetime import date, datetime, timedelta, timezone import dateparser @@ -28,10 +24,16 @@ if t.TYPE_CHECKING: from sqlmesh.core.scheduler import Interval +warnings.filterwarnings( + "ignore", + message="The localize method is no longer necessary, as this time zone supports the fold attribute", +) + # The Freshness Date Data Parser doesn't support plural units so we add the `s?` to the expression freshness_date_parser_module.PATTERN = re.compile( - r"(\d+[.,]?\d*)\s*(%s)s?\b" % freshness_date_parser_module._UNITS, re.I | re.S | re.U # type: ignore + r"(\d+[.,]?\d*)\s*(%s)s?\b" % freshness_date_parser_module._UNITS, # type: ignore + re.I | re.S | re.U, # type: ignore ) DAY_SHORTCUT_EXPRESSIONS = {"today", "yesterday", "tomorrow"} TIME_UNITS = {"hours", "minutes", "seconds"} diff --git a/sqlmesh/utils/jinja.py b/sqlmesh/utils/jinja.py index a3b7d13ba..69aee6b8d 100644 --- a/sqlmesh/utils/jinja.py +++ b/sqlmesh/utils/jinja.py @@ -216,7 +216,7 @@ def _serialize_attribute_dict( ) -> t.Dict[str, t.Any]: # NOTE: This is called only when used with Pydantic V2. def _convert( - val: t.Union[t.Dict[str, JinjaGlobalAttribute], t.Dict[str, t.Any]] + val: t.Union[t.Dict[str, JinjaGlobalAttribute], t.Dict[str, t.Any]], ) -> t.Dict[str, t.Any]: return {k: _convert(v) if isinstance(v, AttributeDict) else v for k, v in val.items()} diff --git a/sqlmesh/utils/pydantic.py b/sqlmesh/utils/pydantic.py index fae6bd355..9032e359d 100644 --- a/sqlmesh/utils/pydantic.py +++ b/sqlmesh/utils/pydantic.py @@ -177,7 +177,9 @@ def parse_obj(cls: t.Type["Model"], obj: t.Any) -> "Model": @classmethod def parse_raw(cls: t.Type["Model"], b: t.Union[str, bytes], **kwargs: t.Any) -> "Model": return ( - super().model_validate_json(b, **kwargs) if PYDANTIC_MAJOR_VERSION >= 2 else super().parse_raw(b, **kwargs) # type: ignore + super().model_validate_json(b, **kwargs) # type: ignore + if PYDANTIC_MAJOR_VERSION >= 2 + else super().parse_raw(b, **kwargs) ) @classmethod @@ -200,7 +202,9 @@ def all_field_infos(cls: t.Type["PydanticModel"]) -> t.Dict[str, FieldInfo]: @classmethod def required_fields(cls: t.Type["PydanticModel"]) -> t.Set[str]: - return cls._fields(lambda field: field.is_required() if PYDANTIC_MAJOR_VERSION >= 2 else field.required) # type: ignore + return cls._fields( + lambda field: field.is_required() if PYDANTIC_MAJOR_VERSION >= 2 else field.required + ) # type: ignore @classmethod def _fields( diff --git a/sqlmesh/utils/yaml.py b/sqlmesh/utils/yaml.py index 21dc6c0a6..10cd4d97a 100644 --- a/sqlmesh/utils/yaml.py +++ b/sqlmesh/utils/yaml.py @@ -14,7 +14,7 @@ "env_var": lambda key, default=None: getenv(key, default), } -YAML = lambda: yaml.YAML(typ="safe") +YAML = lambda: yaml.YAML(typ="safe") # noqa: E731 def load( diff --git a/tests/conftest.py b/tests/conftest.py index 4eb5aa9c2..d400a88bb 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -217,7 +217,7 @@ def sushi_context_pre_scheduling(init_and_plan_context: t.Callable) -> Context: def sushi_context_fixed_date(init_and_plan_context: t.Callable) -> Context: context, plan = init_and_plan_context("examples/sushi") - for model in context.models.values(): + for model in 
context.models.values():  # noqa: F402
         if model.start:
             context.upsert_model(model.name, start="2022-01-01")
@@ -365,7 +365,7 @@ def ignore(src, names):
         return []

     def _make_function(
-        paths: t.Union[t.Union[str, Path], t.Collection[t.Union[str, Path]]]
+        paths: t.Union[t.Union[str, Path], t.Collection[t.Union[str, Path]]],
     ) -> t.List[Path]:
         paths = ensure_list(paths)
         temp_dirs = []
diff --git a/tests/core/engine_adapter/test_bigquery.py b/tests/core/engine_adapter/test_bigquery.py
index 38776934e..a2d6a2467 100644
--- a/tests/core/engine_adapter/test_bigquery.py
+++ b/tests/core/engine_adapter/test_bigquery.py
@@ -201,7 +201,7 @@ def temp_table_exists(table: exp.Table) -> bool:
     ]
     assert load_temp_table.kwargs["df"].equals(df)
     assert load_temp_table.kwargs["table"] == get_temp_bq_table.return_value
-    assert load_temp_table.kwargs["job_config"].write_disposition == None
+    assert load_temp_table.kwargs["job_config"].write_disposition is None
     assert (
         merge_sql.sql(dialect="bigquery")
         == "MERGE INTO test_table AS __MERGE_TARGET__ USING (SELECT `a`, `ds` FROM (SELECT `a`, `ds` FROM project.dataset.temp_table) AS _subquery WHERE ds BETWEEN '2022-01-01' AND '2022-01-05') AS __MERGE_SOURCE__ ON FALSE WHEN NOT MATCHED BY SOURCE AND ds BETWEEN '2022-01-01' AND '2022-01-05' THEN DELETE WHEN NOT MATCHED THEN INSERT (a, ds) VALUES (a, ds)"
@@ -415,7 +415,7 @@ def test_ctas_time_partition(
     sql_calls = _to_sql_calls(execute_mock)
     assert sql_calls == [
-        f"CREATE TABLE IF NOT EXISTS `test_table` PARTITION BY `ds` AS SELECT * FROM `a`",
+        "CREATE TABLE IF NOT EXISTS `test_table` PARTITION BY `ds` AS SELECT * FROM `a`",
     ]
diff --git a/tests/core/engine_adapter/test_integration.py b/tests/core/engine_adapter/test_integration.py
index 54eb5020b..2b12cc854 100644
--- a/tests/core/engine_adapter/test_integration.py
+++ b/tests/core/engine_adapter/test_integration.py
@@ -1873,8 +1873,7 @@ def validate_no_comments(
     # confirm view layer comments are registered in PROD
     if ctx.engine_adapter.COMMENT_CREATION_VIEW.is_supported:
-        prod_plan = context.plan(skip_tests=True, no_prompts=True, auto_apply=True)
-
+        context.plan(skip_tests=True, no_prompts=True, auto_apply=True)
         validate_comments("sushi", is_physical_layer=False)
diff --git a/tests/core/engine_adapter/test_mssql.py b/tests/core/engine_adapter/test_mssql.py
index 717c9c33f..6e16cc988 100644
--- a/tests/core/engine_adapter/test_mssql.py
+++ b/tests/core/engine_adapter/test_mssql.py
@@ -382,7 +382,7 @@ def temp_table_exists(table: exp.Table) -> bool:
     assert to_sql_calls(adapter) == [
         f"""IF NOT EXISTS (SELECT * FROM information_schema.tables WHERE table_name = '{temp_table_name}') EXEC('CREATE TABLE [{temp_table_name}] ([a] INTEGER, [b] INTEGER)');""",
-        f"MERGE INTO [test_table] AS [__MERGE_TARGET__] USING (SELECT CAST([a] AS INTEGER) AS [a], CAST([b] AS INTEGER) AS [b] FROM [__temp_test_table_abcdefgh]) AS [__MERGE_SOURCE__] ON (1 = 0) WHEN NOT MATCHED BY SOURCE THEN DELETE WHEN NOT MATCHED THEN INSERT ([a], [b]) VALUES ([a], [b]);",
+        "MERGE INTO [test_table] AS [__MERGE_TARGET__] USING (SELECT CAST([a] AS INTEGER) AS [a], CAST([b] AS INTEGER) AS [b] FROM [__temp_test_table_abcdefgh]) AS [__MERGE_SOURCE__] ON (1 = 0) WHEN NOT MATCHED BY SOURCE THEN DELETE WHEN NOT MATCHED THEN INSERT ([a], [b]) VALUES ([a], [b]);",
         f"DROP TABLE IF EXISTS [{temp_table_name}];",
     ]
diff --git a/tests/core/engine_adapter/test_mysql.py b/tests/core/engine_adapter/test_mysql.py
index e0f07de11..46d23bbcf 100644
--- a/tests/core/engine_adapter/test_mysql.py
+++ b/tests/core/engine_adapter/test_mysql.py
@@ -58,7 +58,7 @@ def test_comments(make_mocked_engine_adapter: t.Callable, mocker: MockerFixture)
         f"CREATE TABLE IF NOT EXISTS `test_table` (`a` INT COMMENT '{truncated_column_comment}', `b` INT) COMMENT='{truncated_table_comment}'",
         f"CREATE TABLE IF NOT EXISTS `test_table` COMMENT='{truncated_table_comment}' AS SELECT `a`, `b` FROM `source_table`",
         f"ALTER TABLE `test_table` MODIFY `a` INT COMMENT '{truncated_column_comment}'",
-        f"CREATE OR REPLACE VIEW `test_view` AS SELECT `a`, `b` FROM `source_table`",
+        "CREATE OR REPLACE VIEW `test_view` AS SELECT `a`, `b` FROM `source_table`",
         f"ALTER TABLE `test_table` COMMENT = '{truncated_table_comment}'",
         f"ALTER TABLE `test_table` MODIFY `a` INT COMMENT '{truncated_column_comment}'",
     ]
diff --git a/tests/core/engine_adapter/test_spark.py b/tests/core/engine_adapter/test_spark.py
index 815565ad0..9e94710bc 100644
--- a/tests/core/engine_adapter/test_spark.py
+++ b/tests/core/engine_adapter/test_spark.py
@@ -118,7 +118,9 @@ def test_replace_query_table_properties_exists(
 def test_create_view_properties(make_mocked_engine_adapter: t.Callable):
     adapter = make_mocked_engine_adapter(SparkEngineAdapter)

-    adapter.create_view("test_view", parse_one("SELECT a FROM tbl"), view_properties={"a": exp.convert(1)})  # type: ignore
+    adapter.create_view(
+        "test_view", parse_one("SELECT a FROM tbl"), view_properties={"a": exp.convert(1)}
+    )  # type: ignore
     adapter.cursor.execute.assert_called_once_with(
         "CREATE OR REPLACE VIEW test_view TBLPROPERTIES ('a'=1) AS SELECT a FROM tbl"
     )
@@ -323,7 +325,7 @@ def test_replace_query_self_ref_exists(
         # the schema for the table already exists
         "CREATE SCHEMA IF NOT EXISTS `db`",
         f"CREATE TABLE IF NOT EXISTS `db`.`temp_table_{temp_table_id}` AS SELECT `col` FROM `db`.`table`",
-        f"INSERT OVERWRITE TABLE `db`.`table` (`col`) SELECT `col` + 1 AS `col` FROM `db`.`temp_table_abcdefgh`",
+        "INSERT OVERWRITE TABLE `db`.`table` (`col`) SELECT `col` + 1 AS `col` FROM `db`.`temp_table_abcdefgh`",
         f"DROP TABLE IF EXISTS `db`.`temp_table_{temp_table_id}`",
     ]
diff --git a/tests/core/engine_adapter/test_trino.py b/tests/core/engine_adapter/test_trino.py
index 4d56dd890..b453b8899 100644
--- a/tests/core/engine_adapter/test_trino.py
+++ b/tests/core/engine_adapter/test_trino.py
@@ -164,7 +164,7 @@ def test_partitioned_by_iceberg_transforms(
     )

     expressions = d.parse(
-        f"""
+        """
         MODEL (
             name test_table,
             partitioned_by (day(cola), truncate(colb, 8), colc),
@@ -303,7 +303,7 @@ def test_comments_hive(mocker: MockerFixture, make_mocked_engine_adapter: t.Call
         f"""CREATE TABLE IF NOT EXISTS "test_table" ("a" INTEGER COMMENT '{truncated_column_comment}', "b" INTEGER) COMMENT '{truncated_table_comment}'""",
         f"""CREATE TABLE IF NOT EXISTS "test_table" COMMENT '{truncated_table_comment}' AS SELECT "a", "b" FROM "source_table\"""",
         f"""COMMENT ON COLUMN "test_table"."a" IS '{truncated_column_comment}'""",
-        f"""CREATE OR REPLACE VIEW test_view AS SELECT a, b FROM source_table""",
+        """CREATE OR REPLACE VIEW test_view AS SELECT a, b FROM source_table""",
         f"""COMMENT ON VIEW "test_view" IS '{truncated_table_comment}'""",
         f"""COMMENT ON TABLE "test_table" IS '{truncated_table_comment}'""",
         f"""COMMENT ON COLUMN "test_table"."a" IS '{truncated_column_comment}'""",
@@ -371,7 +371,7 @@ def test_comments_iceberg_delta(
         f"""CREATE TABLE IF NOT EXISTS "test_table" ("a" INTEGER COMMENT '{long_column_comment}', "b" INTEGER) COMMENT '{long_table_comment}'""",
         f"""CREATE TABLE IF NOT EXISTS "test_table" COMMENT '{long_table_comment}' AS SELECT "a", "b" FROM "source_table\"""",
         f"""COMMENT ON COLUMN "test_table"."a" IS '{long_column_comment}'""",
-        f"""CREATE OR REPLACE VIEW test_view AS SELECT a, b FROM source_table""",
+        """CREATE OR REPLACE VIEW test_view AS SELECT a, b FROM source_table""",
         f"""COMMENT ON VIEW "test_view" IS '{long_table_comment}'""",
         f"""COMMENT ON TABLE "test_table" IS '{long_table_comment}'""",
         f"""COMMENT ON COLUMN "test_table"."a" IS '{long_column_comment}'""",
diff --git a/tests/core/test_connection_config.py b/tests/core/test_connection_config.py
index 4b9c53a8f..10171fa98 100644
--- a/tests/core/test_connection_config.py
+++ b/tests/core/test_connection_config.py
@@ -546,7 +546,6 @@ def test_duckdb_attach_catalog(make_config):

 def test_duckdb_attach_options():
-
     options = DuckDBAttachOptions(
         type="postgres", path="dbname=postgres user=postgres host=127.0.0.1", read_only=True
     )
diff --git a/tests/core/test_context.py b/tests/core/test_context.py
index 7261520e4..a8c1add31 100644
--- a/tests/core/test_context.py
+++ b/tests/core/test_context.py
@@ -107,7 +107,7 @@ def test_render_sql_model(sushi_context, assert_exp_eq, copy_to_temp_path: t.Cal
             end=date(2021, 1, 1),
             expand=True,
         ),
-        f"""
+        """
         SELECT
           CAST("o"."waiter_id" AS INT) AS "waiter_id", /* Waiter id */
           CAST(SUM("oi"."quantity" * "i"."price") AS DOUBLE) AS "revenue", /* Revenue from orders taken by this waiter */
@@ -157,7 +157,7 @@ def test_render_sql_model(sushi_context, assert_exp_eq, copy_to_temp_path: t.Cal
     unpushed = Context(paths=copy_to_temp_path("examples/sushi"))
     assert_exp_eq(
         unpushed.render("sushi.waiter_revenue_by_day"),
-        f"""
+        """
         SELECT
           CAST("o"."waiter_id" AS INT) AS "waiter_id", /* Waiter id */
           CAST(SUM("oi"."quantity" * "i"."price") AS DOUBLE) AS "revenue", /* Revenue from orders taken by this waiter */
@@ -282,12 +282,12 @@ def test_ignore_files(mocker: MockerFixture, tmp_path: pathlib.Path):
     models_dir = pathlib.Path("models")
     macros_dir = pathlib.Path("macros")

-    ignore_model_file = create_temp_file(
+    create_temp_file(
         tmp_path,
         pathlib.Path(models_dir, "ignore", "ignore_model.sql"),
         "MODEL(name ignore.ignore_model); SELECT 1 AS cola",
     )
-    ignore_macro_file = create_temp_file(
+    create_temp_file(
         tmp_path,
         pathlib.Path(macros_dir, "macro_ignore.py"),
         """
@@ -298,7 +298,7 @@ def test():
     return "test"
 """,
     )
-    constant_ignore_model_file = create_temp_file(
+    create_temp_file(
         tmp_path,
         pathlib.Path(models_dir, ".ipynb_checkpoints", "ignore_model2.sql"),
         "MODEL(name ignore_model2); SELECT cola::bigint AS cola FROM db.other_table",
diff --git a/tests/core/test_dialect.py b/tests/core/test_dialect.py
index 90feaca18..955bbfc7e 100644
--- a/tests/core/test_dialect.py
+++ b/tests/core/test_dialect.py
@@ -237,9 +237,7 @@ def test_text_diff():
 -  *
 -FROM x
 +  1
-+FROM y""" in text_diff(
-        parse("SELECT * FROM x"), parse("SELECT 1 FROM y")
-    )
++FROM y""" in text_diff(parse("SELECT * FROM x"), parse("SELECT 1 FROM y"))


 def test_parse():
diff --git a/tests/core/test_integration.py b/tests/core/test_integration.py
index 970a11adf..61c9f6215 100644
--- a/tests/core/test_integration.py
+++ b/tests/core/test_integration.py
@@ -358,7 +358,7 @@ def test_hourly_model_with_lookback_no_backfill_in_dev(init_and_plan_context: t.
     top_waiters_model = add_projection_to_model(t.cast(SqlModel, top_waiters_model), literal=True)
     context.upsert_model(top_waiters_model)

-    snapshot = context.get_snapshot(model, raise_if_missing=True)
+    context.get_snapshot(model, raise_if_missing=True)
     top_waiters_snapshot = context.get_snapshot("sushi.top_waiters", raise_if_missing=True)

     with freeze_time(now() + timedelta(hours=2)):
diff --git a/tests/core/test_macros.py b/tests/core/test_macros.py
index 751483c14..40319e79d 100644
--- a/tests/core/test_macros.py
+++ b/tests/core/test_macros.py
@@ -78,7 +78,7 @@ def stamped(evaluator, query: exp.Select) -> exp.Subquery:

     return MacroEvaluator(
         "hive",
-        {"test": Executable(name="test", payload=f"def test(_):\n    return 'test'")},
+        {"test": Executable(name="test", payload="def test(_):\n    return 'test'")},
     )
diff --git a/tests/core/test_model.py b/tests/core/test_model.py
index d2736709b..21021add9 100644
--- a/tests/core/test_model.py
+++ b/tests/core/test_model.py
@@ -1,3 +1,4 @@
+# ruff: noqa: F811
 import json
 import logging
 import typing as t
@@ -3547,7 +3548,7 @@ def test_default_catalog_external_model():

 def test_user_cannot_set_default_catalog():
     expressions = d.parse(
-        f"""
+        """
         MODEL (
             name db.table,
             default_catalog some_catalog
@@ -3584,7 +3585,7 @@ def my_model(context, **kwargs):

 def test_end_date():
     expressions = d.parse(
-        f"""
+        """
         MODEL (
             name db.table,
             kind INCREMENTAL_BY_TIME_RANGE (
@@ -3606,7 +3607,7 @@ def test_end_date():
     with pytest.raises(ConfigError, match=".*Start date.+can't be greater than end date.*"):
         load_sql_based_model(
             d.parse(
-                f"""
+                """
                 MODEL (
                     name db.table,
                     kind INCREMENTAL_BY_TIME_RANGE (
@@ -3624,7 +3625,7 @@ def test_end_no_start():
     expressions = d.parse(
-        f"""
+        """
         MODEL (
             name db.table,
             kind INCREMENTAL_BY_TIME_RANGE (
@@ -3868,7 +3869,9 @@ def test_named_variables_python_model(mocker: MockerFixture) -> None:
     def model_with_named_variables(
         context, start: TimeLike, test_var_a: str, test_var_b: t.Optional[str] = None, **kwargs
     ):
-        return pd.DataFrame([{"a": test_var_a, "b": test_var_b, "start": start.strftime("%Y-%m-%d")}])  # type: ignore
+        return pd.DataFrame(
+            [{"a": test_var_a, "b": test_var_b, "start": start.strftime("%Y-%m-%d")}]  # type: ignore
+        )

     python_model = model.get_registry()["test_named_variables_python_model"].model(
         module_path=Path("."),
diff --git a/tests/core/test_scheduler.py b/tests/core/test_scheduler.py
index d61987539..40edded98 100644
--- a/tests/core/test_scheduler.py
+++ b/tests/core/test_scheduler.py
@@ -91,14 +91,11 @@ def test_run(sushi_context_fixed_date: Context, scheduler: Scheduler):
         "2022-01-30",
     )

-    assert (
-        adapter.fetchone(
-            f"""
+    assert adapter.fetchone(
+        f"""
         SELECT id, name, price FROM sqlmesh__sushi.sushi__items__{snapshot.version} ORDER BY event_date LIMIT 1
         """
-        )
-        == (0, "Hotate", 5.99)
-    )
+    ) == (0, "Hotate", 5.99)


 def test_incremental_by_unique_key_kind_dag(mocker: MockerFixture, make_snapshot):
diff --git a/tests/core/test_schema_loader.py b/tests/core/test_schema_loader.py
index 56c182027..1e758e60c 100644
--- a/tests/core/test_schema_loader.py
+++ b/tests/core/test_schema_loader.py
@@ -148,7 +148,11 @@ def test_missing_table(tmp_path: Path):
     logger = logging.getLogger("sqlmesh.core.schema_loader")
     with patch.object(logger, "warning") as mock_logger:
         create_schema_file(
-            schema_file, {"a": model}, context.engine_adapter, context.state_reader, ""  # type: ignore
+            schema_file,
+            {"a": model},  # type: ignore
+            context.engine_adapter,
+            context.state_reader,
+            "",
         )
         assert """Unable to get schema for '"tbl_source"'""" in mock_logger.call_args[0][0]
diff --git a/tests/core/test_selector.py b/tests/core/test_selector.py
index 295ada252..34fd97b88 100644
--- a/tests/core/test_selector.py
+++ b/tests/core/test_selector.py
@@ -48,7 +48,7 @@ def test_select_models(mocker: MockerFixture, make_snapshot, default_catalog: t.
         default_catalog=default_catalog,
     )
     standalone_audit = StandaloneAudit(
-        name="test_audit", query=d.parse_one(f"SELECT * FROM added_model WHERE a IS NULL")
+        name="test_audit", query=d.parse_one("SELECT * FROM added_model WHERE a IS NULL")
     )

     modified_model_v1_snapshot = make_snapshot(modified_model_v1)
diff --git a/tests/core/test_snapshot.py b/tests/core/test_snapshot.py
index 4aac9be09..ddcc3d8d9 100644
--- a/tests/core/test_snapshot.py
+++ b/tests/core/test_snapshot.py
@@ -622,7 +622,7 @@ def test_get_removal_intervals_full_history_restatement_model(make_snapshot):
         assert interval == (to_timestamp("2023-01-01"), execution_time)


-each_macro = lambda: "test"
+each_macro = lambda: "test"  # noqa: E731


 def test_fingerprint(model: Model, parent_model: Model):
diff --git a/tests/core/test_state_sync.py b/tests/core/test_state_sync.py
index b9752aa74..f9b7a033c 100644
--- a/tests/core/test_state_sync.py
+++ b/tests/core/test_state_sync.py
@@ -215,7 +215,6 @@ def test_snapshots_exists(state_sync: EngineAdapterStateSync, snapshots: t.List[
 @pytest.fixture
 def get_snapshot_intervals(state_sync) -> t.Callable[[Snapshot], t.Optional[SnapshotIntervals]]:
-
     def _get_snapshot_intervals(snapshot: Snapshot) -> t.Optional[SnapshotIntervals]:
         intervals = state_sync._get_snapshot_intervals([snapshot])[-1]
         return intervals[0] if intervals else None
diff --git a/tests/core/test_test.py b/tests/core/test_test.py
index 7d110cac4..4c1855ef2 100644
--- a/tests/core/test_test.py
+++ b/tests/core/test_test.py
@@ -601,7 +601,7 @@ def test_partially_inferred_schemas(sushi_context: Context, mocker: MockerFixtur
     _check_successful_or_raise(test.run())
     spy_execute.assert_any_call(
-        f'CREATE OR REPLACE VIEW "memory"."sqlmesh_test_jzngz56a"."memory__sushi__parent" ("s", "a", "b") AS '
+        'CREATE OR REPLACE VIEW "memory"."sqlmesh_test_jzngz56a"."memory__sushi__parent" ("s", "a", "b") AS '
         "SELECT "
         'CAST("s" AS STRUCT("d" DATE)) AS "s", '
         'CAST("a" AS INT) AS "a", '
@@ -1037,7 +1037,7 @@ def test_create_external_model_fixture(sushi_context: Context, mocker: MockerFix
     assert len(test._fixture_table_cache) == len(sushi_context.models) + 1
     for table in test._fixture_table_cache.values():
         assert table.catalog == "memory"
-        assert table.db == f"sqlmesh_test_jzngz56a"
+        assert table.db == "sqlmesh_test_jzngz56a"


 def test_runtime_stage() -> None:
@@ -1136,7 +1136,7 @@ def test_generate_input_data_using_sql(mocker: MockerFixture) -> None:
     _check_successful_or_raise(test.run())
     spy_execute.assert_any_call(
-        f'CREATE OR REPLACE VIEW "memory"."sqlmesh_test_jzngz56a"."foo" AS '
+        'CREATE OR REPLACE VIEW "memory"."sqlmesh_test_jzngz56a"."foo" AS '
         '''SELECT {'x': 1, 'n': {'y': 2}} AS "struct_value"'''
     )
diff --git a/tests/dbt/test_adapter.py b/tests/dbt/test_adapter.py
index 9105c43e7..644b2080a 100644
--- a/tests/dbt/test_adapter.py
+++ b/tests/dbt/test_adapter.py
@@ -50,16 +50,13 @@ def test_adapter_relation(sushi_test_project: Project, runtime_renderer: t.Calla
     assert renderer("{{ adapter.list_relations(database=None, schema='foo')|length }}") == "2"

-    assert (
-        renderer(
-            """
+    assert renderer(
+        """
         {%- set from = adapter.get_relation(database=None, schema='foo', identifier='bar') -%}
         {%- set to = adapter.get_relation(database=None, schema='foo', identifier='another') -%}
         {{ adapter.get_missing_columns(from, to) -}}
         """
-        )
-        == str([Column.from_description(name="baz", raw_data_type="INT")])
-    )
+    ) == str([Column.from_description(name="baz", raw_data_type="INT")])

     assert (
         renderer(
@@ -139,9 +136,6 @@ def test_normalization(
     adapter_mock.drop_table.reset_mock()
     renderer = runtime_renderer(context, engine_adapter=adapter_mock)

-    bla_id = exp.to_identifier("bla", quoted=True)
-    bob_id = exp.to_identifier("bob", quoted=True)
-
     # Ensures we'll pass lowercase names to the engine
     renderer(
         "{%- set relation = api.Relation.create(schema='bla', identifier='bob') -%}"
diff --git a/tests/dbt/test_config.py b/tests/dbt/test_config.py
index 8ca46f5fc..16e7bf167 100644
--- a/tests/dbt/test_config.py
+++ b/tests/dbt/test_config.py
@@ -140,8 +140,8 @@ def test_test_to_sqlmesh_fields():
     assert audit.name == "foo_test"
     assert audit.dialect == "duckdb"
-    assert audit.skip == False
-    assert audit.blocking == True
+    assert not audit.skip
+    assert audit.blocking
     assert sql in audit.query.sql()

     sql = "SELECT * FROM FOO WHERE NOT id IS NULL"
@@ -158,8 +158,8 @@ def test_test_to_sqlmesh_fields():
     assert audit.name == "foo_null_test"
     assert audit.dialect == "duckdb"
-    assert audit.skip == True
-    assert audit.blocking == False
+    assert audit.skip
+    assert not audit.blocking
     assert sql in audit.query.sql()
@@ -178,7 +178,7 @@ def test_singular_test_to_standalone_audit():
         dependencies=Dependencies(refs=["bar"]),
     )

-    assert test_config.is_standalone == True
+    assert test_config.is_standalone

     model = ModelConfig(schema="foo", name="bar", alias="bar")
diff --git a/tests/dbt/test_transformation.py b/tests/dbt/test_transformation.py
index 3e333943f..cd1843c3b 100644
--- a/tests/dbt/test_transformation.py
+++ b/tests/dbt/test_transformation.py
@@ -246,19 +246,19 @@ def test_model_kind():
         insert_overwrite=True, disable_restatement=True
     )

-    with pytest.raises(ConfigError) as exception:
+    with pytest.raises(ConfigError):
         ModelConfig(
             materialized=Materialization.INCREMENTAL,
             unique_key=["bar"],
             incremental_strategy="delete+insert",
         ).model_kind(context)
-    with pytest.raises(ConfigError) as exception:
+    with pytest.raises(ConfigError):
         ModelConfig(
             materialized=Materialization.INCREMENTAL,
             unique_key=["bar"],
             incremental_strategy="insert_overwrite",
         ).model_kind(context)
-    with pytest.raises(ConfigError) as exception:
+    with pytest.raises(ConfigError):
         ModelConfig(
             materialized=Materialization.INCREMENTAL,
             unique_key=["bar"],
diff --git a/tests/schedulers/airflow/test_client.py b/tests/schedulers/airflow/test_client.py
index fed35c9a5..7c11357f8 100644
--- a/tests/schedulers/airflow/test_client.py
+++ b/tests/schedulers/airflow/test_client.py
@@ -160,7 +160,6 @@ def test_apply_plan(mocker: MockerFixture, snapshot: Snapshot):
         "skip_backfill": False,
         "notification_targets": [],
         "request_id": request_id,
-        "restatements": {},
         "backfill_concurrent_tasks": 1,
         "ddl_concurrent_tasks": 1,
         "users": [],
diff --git a/tests/utils/test_metaprogramming.py b/tests/utils/test_metaprogramming.py
index 9b83a24ea..87aeda8e7 100644
--- a/tests/utils/test_metaprogramming.py
+++ b/tests/utils/test_metaprogramming.py
@@ -60,7 +60,7 @@ def test_fun():
 Y = 2
 Z = 3

-my_lambda = lambda: print("z")
+my_lambda = lambda: print("z")  # noqa: E731

 KLASS_X = 1
 KLASS_Y = 2
@@ -163,7 +163,7 @@ def closure(z: int):


 def test_serialize_env_error() -> None:
-    with pytest.raises(SQLMeshError) as e:
+    with pytest.raises(SQLMeshError):
         # pretend to be the module pandas
         serialize_env({"test_date": test_date}, path=Path("tests/utils"))
@@ -231,7 +231,7 @@ def baz(self):
         "my_lambda": Executable(
             name="my_lambda",
             path="test_metaprogramming.py",
-            payload=f"my_lambda = lambda : print('z')",
+            payload="my_lambda = lambda : print('z')",
         ),
         "other_func": Executable(
             name="other_func",
diff --git a/web/client/src/workers/sqlglot/sqlglot.py b/web/client/src/workers/sqlglot/sqlglot.py
index 4090261e6..998d46792 100644
--- a/web/client/src/workers/sqlglot/sqlglot.py
+++ b/web/client/src/workers/sqlglot/sqlglot.py
@@ -45,7 +45,7 @@ def validate(sql: str = "", read: DialectType = None) -> str:
         sqlglot.transpile(
             sql, read=read, pretty=False, unsupported_level=sqlglot.errors.ErrorLevel.IMMEDIATE
         )
-    except sqlglot.errors.ParseError as e:
+    except sqlglot.errors.ParseError:
         return json.dumps(False)
     return json.dumps(True)
diff --git a/web/server/watcher.py b/web/server/watcher.py
index ecfa7b984..1f67887a8 100644
--- a/web/server/watcher.py
+++ b/web/server/watcher.py
@@ -78,7 +78,7 @@ async def watch_project() -> None:
             except Exception:
                 error = ApiException(
                     message="Error updating file",
-                    origin=f"API -> watcher -> watch_project",
+                    origin="API -> watcher -> watch_project",
                     trigger=path_str,
                 ).to_dict()
                 api_console.log_event(event=models.EventName.WARNINGS, data=error)