From fcddf0f5c82b1b4130bfadfdf932a1c34d69248c Mon Sep 17 00:00:00 2001 From: Barry Hart Date: Mon, 16 Jan 2023 06:59:37 -0500 Subject: [PATCH] Due to performance and other issues, revert the osmosis implementation of the templater for now (#4273) * Revert osmosis dbt templater for now * Revert the osmosis implementation of the templater for now * Fix broken test * Fix coverage gaps * More coverage fixes * PR review * PR review Co-authored-by: Barry Hart Co-authored-by: Alan Cruickshank --- .../sqlfluff_templater_dbt/__init__.py | 14 +- .../sqlfluff_templater_dbt/osmosis/LICENSE | 201 --- .../sqlfluff_templater_dbt/osmosis/README.md | 4 - .../osmosis/__init__.py | 1422 ----------------- .../osmosis/exceptions.py | 14 - .../osmosis/log_controller.py | 75 - .../sqlfluff_templater_dbt/osmosis/patch.py | 35 - .../sqlfluff_templater_dbt/templater.py | 581 +++++-- .../test/fixtures/dbt/templater.py | 4 +- .../test/templater_test.py | 78 +- setup.cfg | 3 - 11 files changed, 512 insertions(+), 1919 deletions(-) delete mode 100644 plugins/sqlfluff-templater-dbt/sqlfluff_templater_dbt/osmosis/LICENSE delete mode 100644 plugins/sqlfluff-templater-dbt/sqlfluff_templater_dbt/osmosis/README.md delete mode 100644 plugins/sqlfluff-templater-dbt/sqlfluff_templater_dbt/osmosis/__init__.py delete mode 100644 plugins/sqlfluff-templater-dbt/sqlfluff_templater_dbt/osmosis/exceptions.py delete mode 100644 plugins/sqlfluff-templater-dbt/sqlfluff_templater_dbt/osmosis/log_controller.py delete mode 100644 plugins/sqlfluff-templater-dbt/sqlfluff_templater_dbt/osmosis/patch.py diff --git a/plugins/sqlfluff-templater-dbt/sqlfluff_templater_dbt/__init__.py b/plugins/sqlfluff-templater-dbt/sqlfluff_templater_dbt/__init__.py index c33b119c53c..f9f94e76f60 100644 --- a/plugins/sqlfluff-templater-dbt/sqlfluff_templater_dbt/__init__.py +++ b/plugins/sqlfluff-templater-dbt/sqlfluff_templater_dbt/__init__.py @@ -1,20 +1,10 @@ """Defines the hook endpoints for the dbt templater plugin.""" -from sqlfluff.core.plugin import hookimpl - -from sqlfluff_templater_dbt.osmosis import DbtProjectContainer from sqlfluff_templater_dbt.templater import DbtTemplater - - -dbt_project_container = DbtProjectContainer() +from sqlfluff.core.plugin import hookimpl @hookimpl def get_templaters(): """Get templaters.""" - - def create_templater(**kwargs): - return DbtTemplater(dbt_project_container=dbt_project_container, **kwargs) - - create_templater.name = DbtTemplater.name - return [create_templater] + return [DbtTemplater] diff --git a/plugins/sqlfluff-templater-dbt/sqlfluff_templater_dbt/osmosis/LICENSE b/plugins/sqlfluff-templater-dbt/sqlfluff_templater_dbt/osmosis/LICENSE deleted file mode 100644 index 261eeb9e9f8..00000000000 --- a/plugins/sqlfluff-templater-dbt/sqlfluff_templater_dbt/osmosis/LICENSE +++ /dev/null @@ -1,201 +0,0 @@ - Apache License - Version 2.0, January 2004 - http://www.apache.org/licenses/ - - TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION - - 1. Definitions. - - "License" shall mean the terms and conditions for use, reproduction, - and distribution as defined by Sections 1 through 9 of this document. - - "Licensor" shall mean the copyright owner or entity authorized by - the copyright owner that is granting the License. - - "Legal Entity" shall mean the union of the acting entity and all - other entities that control, are controlled by, or are under common - control with that entity. For the purposes of this definition, - "control" means (i) the power, direct or indirect, to cause the - direction or management of such entity, whether by contract or - otherwise, or (ii) ownership of fifty percent (50%) or more of the - outstanding shares, or (iii) beneficial ownership of such entity. - - "You" (or "Your") shall mean an individual or Legal Entity - exercising permissions granted by this License. - - "Source" form shall mean the preferred form for making modifications, - including but not limited to software source code, documentation - source, and configuration files. - - "Object" form shall mean any form resulting from mechanical - transformation or translation of a Source form, including but - not limited to compiled object code, generated documentation, - and conversions to other media types. - - "Work" shall mean the work of authorship, whether in Source or - Object form, made available under the License, as indicated by a - copyright notice that is included in or attached to the work - (an example is provided in the Appendix below). - - "Derivative Works" shall mean any work, whether in Source or Object - form, that is based on (or derived from) the Work and for which the - editorial revisions, annotations, elaborations, or other modifications - represent, as a whole, an original work of authorship. For the purposes - of this License, Derivative Works shall not include works that remain - separable from, or merely link (or bind by name) to the interfaces of, - the Work and Derivative Works thereof. - - "Contribution" shall mean any work of authorship, including - the original version of the Work and any modifications or additions - to that Work or Derivative Works thereof, that is intentionally - submitted to Licensor for inclusion in the Work by the copyright owner - or by an individual or Legal Entity authorized to submit on behalf of - the copyright owner. For the purposes of this definition, "submitted" - means any form of electronic, verbal, or written communication sent - to the Licensor or its representatives, including but not limited to - communication on electronic mailing lists, source code control systems, - and issue tracking systems that are managed by, or on behalf of, the - Licensor for the purpose of discussing and improving the Work, but - excluding communication that is conspicuously marked or otherwise - designated in writing by the copyright owner as "Not a Contribution." - - "Contributor" shall mean Licensor and any individual or Legal Entity - on behalf of whom a Contribution has been received by Licensor and - subsequently incorporated within the Work. - - 2. Grant of Copyright License. Subject to the terms and conditions of - this License, each Contributor hereby grants to You a perpetual, - worldwide, non-exclusive, no-charge, royalty-free, irrevocable - copyright license to reproduce, prepare Derivative Works of, - publicly display, publicly perform, sublicense, and distribute the - Work and such Derivative Works in Source or Object form. - - 3. Grant of Patent License. Subject to the terms and conditions of - this License, each Contributor hereby grants to You a perpetual, - worldwide, non-exclusive, no-charge, royalty-free, irrevocable - (except as stated in this section) patent license to make, have made, - use, offer to sell, sell, import, and otherwise transfer the Work, - where such license applies only to those patent claims licensable - by such Contributor that are necessarily infringed by their - Contribution(s) alone or by combination of their Contribution(s) - with the Work to which such Contribution(s) was submitted. If You - institute patent litigation against any entity (including a - cross-claim or counterclaim in a lawsuit) alleging that the Work - or a Contribution incorporated within the Work constitutes direct - or contributory patent infringement, then any patent licenses - granted to You under this License for that Work shall terminate - as of the date such litigation is filed. - - 4. Redistribution. You may reproduce and distribute copies of the - Work or Derivative Works thereof in any medium, with or without - modifications, and in Source or Object form, provided that You - meet the following conditions: - - (a) You must give any other recipients of the Work or - Derivative Works a copy of this License; and - - (b) You must cause any modified files to carry prominent notices - stating that You changed the files; and - - (c) You must retain, in the Source form of any Derivative Works - that You distribute, all copyright, patent, trademark, and - attribution notices from the Source form of the Work, - excluding those notices that do not pertain to any part of - the Derivative Works; and - - (d) If the Work includes a "NOTICE" text file as part of its - distribution, then any Derivative Works that You distribute must - include a readable copy of the attribution notices contained - within such NOTICE file, excluding those notices that do not - pertain to any part of the Derivative Works, in at least one - of the following places: within a NOTICE text file distributed - as part of the Derivative Works; within the Source form or - documentation, if provided along with the Derivative Works; or, - within a display generated by the Derivative Works, if and - wherever such third-party notices normally appear. The contents - of the NOTICE file are for informational purposes only and - do not modify the License. You may add Your own attribution - notices within Derivative Works that You distribute, alongside - or as an addendum to the NOTICE text from the Work, provided - that such additional attribution notices cannot be construed - as modifying the License. - - You may add Your own copyright statement to Your modifications and - may provide additional or different license terms and conditions - for use, reproduction, or distribution of Your modifications, or - for any such Derivative Works as a whole, provided Your use, - reproduction, and distribution of the Work otherwise complies with - the conditions stated in this License. - - 5. Submission of Contributions. Unless You explicitly state otherwise, - any Contribution intentionally submitted for inclusion in the Work - by You to the Licensor shall be under the terms and conditions of - this License, without any additional terms or conditions. - Notwithstanding the above, nothing herein shall supersede or modify - the terms of any separate license agreement you may have executed - with Licensor regarding such Contributions. - - 6. Trademarks. This License does not grant permission to use the trade - names, trademarks, service marks, or product names of the Licensor, - except as required for reasonable and customary use in describing the - origin of the Work and reproducing the content of the NOTICE file. - - 7. Disclaimer of Warranty. Unless required by applicable law or - agreed to in writing, Licensor provides the Work (and each - Contributor provides its Contributions) on an "AS IS" BASIS, - WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or - implied, including, without limitation, any warranties or conditions - of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A - PARTICULAR PURPOSE. You are solely responsible for determining the - appropriateness of using or redistributing the Work and assume any - risks associated with Your exercise of permissions under this License. - - 8. Limitation of Liability. In no event and under no legal theory, - whether in tort (including negligence), contract, or otherwise, - unless required by applicable law (such as deliberate and grossly - negligent acts) or agreed to in writing, shall any Contributor be - liable to You for damages, including any direct, indirect, special, - incidental, or consequential damages of any character arising as a - result of this License or out of the use or inability to use the - Work (including but not limited to damages for loss of goodwill, - work stoppage, computer failure or malfunction, or any and all - other commercial damages or losses), even if such Contributor - has been advised of the possibility of such damages. - - 9. Accepting Warranty or Additional Liability. While redistributing - the Work or Derivative Works thereof, You may choose to offer, - and charge a fee for, acceptance of support, warranty, indemnity, - or other liability obligations and/or rights consistent with this - License. However, in accepting such obligations, You may act only - on Your own behalf and on Your sole responsibility, not on behalf - of any other Contributor, and only if You agree to indemnify, - defend, and hold each Contributor harmless for any liability - incurred by, or claims asserted against, such Contributor by reason - of your accepting any such warranty or additional liability. - - END OF TERMS AND CONDITIONS - - APPENDIX: How to apply the Apache License to your work. - - To apply the Apache License to your work, attach the following - boilerplate notice, with the fields enclosed by brackets "[]" - replaced with your own identifying information. (Don't include - the brackets!) The text should be enclosed in the appropriate - comment syntax for the file format. We also recommend that a - file or class name and description of purpose be included on the - same "printed page" as the copyright notice for easier - identification within third-party archives. - - Copyright [yyyy] [name of copyright owner] - - Licensed under the Apache License, Version 2.0 (the "License"); - you may not use this file except in compliance with the License. - You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - - Unless required by applicable law or agreed to in writing, software - distributed under the License is distributed on an "AS IS" BASIS, - WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - See the License for the specific language governing permissions and - limitations under the License. diff --git a/plugins/sqlfluff-templater-dbt/sqlfluff_templater_dbt/osmosis/README.md b/plugins/sqlfluff-templater-dbt/sqlfluff_templater_dbt/osmosis/README.md deleted file mode 100644 index f68d01acd37..00000000000 --- a/plugins/sqlfluff-templater-dbt/sqlfluff_templater_dbt/osmosis/README.md +++ /dev/null @@ -1,4 +0,0 @@ -The code in this directory was vendored (i.e. copied) from -[dbt-osmosis](https://github.com/z3z1ma/dbt-osmosis). Eventually this code -is expected to be released as a separate package, at which time SQLFluff -should use that package instead of vendoring the code. diff --git a/plugins/sqlfluff-templater-dbt/sqlfluff_templater_dbt/osmosis/__init__.py b/plugins/sqlfluff-templater-dbt/sqlfluff_templater_dbt/osmosis/__init__.py deleted file mode 100644 index b93a9968cae..00000000000 --- a/plugins/sqlfluff-templater-dbt/sqlfluff_templater_dbt/osmosis/__init__.py +++ /dev/null @@ -1,1422 +0,0 @@ -# type: ignore -import dbt.adapters.factory - -# This is critical because `get_adapter` is all over dbt-core -# as they expect a singleton adapter instance per plugin, -# so dbt-niceDatabase will have one adapter instance named niceDatabase. -# This makes sense in dbt-land where we have a single Project/Profile -# combination executed in process from start to finish or a single tenant RPC -# This doesn't fit our paradigm of one adapter per DbtProject in a multitenant server, -# so we create an adapter instance **independent** of the FACTORY cache -# and attach it directly to our RuntimeConfig which is passed through -# anywhere dbt-core needs config including in all `get_adapter` calls -dbt.adapters.factory.get_adapter = lambda config: config.adapter - -import os -import re -import threading -import time -import uuid -from collections import OrderedDict, UserDict -from copy import copy -from enum import Enum -from functools import lru_cache, partial -from hashlib import md5 -from itertools import chain -from pathlib import Path -from typing import ( - Any, - Callable, - Dict, - Iterable, - Iterator, - List, - Mapping, - MutableMapping, - Optional, - Set, - Tuple, - TypeVar, - Union, -) - -import agate -from dbt.adapters.base import BaseRelation -from dbt.adapters.factory import Adapter, get_adapter_class_by_name -from dbt.clients import jinja # monkey-patched for perf -from dbt.config.runtime import RuntimeConfig -from dbt.context.providers import generate_runtime_model_context -from dbt.contracts.connection import AdapterResponse -from dbt.contracts.graph.manifest import ( - ManifestNode, - MaybeNonSource, - MaybeParsedSource, -) -from dbt.events.functions import fire_event # monkey-patched for perf -from dbt.exceptions import CompilationException, InternalException, RuntimeException -from dbt.flags import DEFAULT_PROFILES_DIR, set_from_args -from dbt.node_types import NodeType -from dbt.parser.manifest import ManifestLoader, process_node -from dbt.parser.sql import SqlBlockParser, SqlMacroParser -from dbt.task.sql import SqlCompileRunner, SqlExecuteRunner -from dbt.tracking import disable_tracking -from dbt.version import __version__ as dbt_version -from pydantic import BaseModel -from rich.progress import track -from ruamel.yaml import YAML - -from .exceptions import ( - InvalidOsmosisConfig, - MissingOsmosisConfig, - SanitizationRequired, -) -from .log_controller import logger -from .patch import write_manifest_for_partial_parse - -__all__ = [ - "DbtProject", - "DbtProjectContainer", - "DbtYamlManager", - "ConfigInterface", - "has_jinja", - "DbtOsmosis", # for compat -] - -CACHE = {} -CACHE_VERSION = 1 -SQL_CACHE_SIZE = 1024 - -MANIFEST_ARTIFACT = "manifest.json" -DBT_MAJOR_VER, DBT_MINOR_VER, DBT_PATCH_VER = ( - int(v) for v in re.search(r"^([0-9\.])+", dbt_version).group().split(".") -) -RAW_CODE = "raw_code" if DBT_MAJOR_VER >= 1 and DBT_MINOR_VER >= 3 else "raw_sql" -COMPILED_CODE = ( - "compiled_code" if DBT_MAJOR_VER >= 1 and DBT_MINOR_VER >= 3 else "compiled_sql" -) - -JINJA_CONTROL_SEQS = ["{{", "}}", "{%", "%}", "{#", "#}"] - -T = TypeVar("T") - - -def has_jinja(query: str) -> bool: - """Utility to check for jinja prior to certain compilation procedures""" - return any(seq in query for seq in JINJA_CONTROL_SEQS) - - -def memoize_get_rendered(function): - """Custom memoization function for dbt-core jinja interface""" - - def wrapper( - string: str, - ctx: Dict[str, Any], - node: ManifestNode = None, - capture_macros: bool = False, - native: bool = False, - ): - v = md5(string.strip().encode("utf-8")).hexdigest() - v += "__" + str(CACHE_VERSION) - if capture_macros == True and node is not None: - if node.is_ephemeral: - return function(string, ctx, node, capture_macros, native) - v += "__" + node.unique_id - rv = CACHE.get(v) - if rv is not None: - return rv - else: - rv = function(string, ctx, node, capture_macros, native) - CACHE[v] = rv - return rv - - return wrapper - - -# Performance hacks -# jinja.get_rendered = memoize_get_rendered(jinja.get_rendered) -disable_tracking() -fire_event = lambda e: None - - -class ConfigInterface: - """This mimic dbt-core args based interface for dbt-core - class instantiation""" - - def __init__( - self, - threads: Optional[int] = 1, - target: Optional[str] = None, - profiles_dir: Optional[str] = None, - project_dir: Optional[str] = None, - vars: Optional[str] = "{}", - ): - self.threads = threads - if target: - self.target = target # We don't want target in args context if it is None - self.profiles_dir = profiles_dir or DEFAULT_PROFILES_DIR - self.project_dir = project_dir - self.vars = vars # json.dumps str - self.dependencies = [] - self.single_threaded = threads == 1 - self.quiet = True - - @classmethod - def from_str(cls, arguments: str) -> "ConfigInterface": - import argparse - import shlex - - parser = argparse.ArgumentParser() - args = parser.parse_args(shlex.split(arguments)) - return cls( - threads=args.threads, - target=args.target, - profiles_dir=args.profiles_dir, - project_dir=args.project_dir, - ) - - -class YamlHandler(YAML): - """A `ruamel.yaml` wrapper to handle dbt YAML files with sane defaults""" - - def __init__(self, **kwargs) -> None: - super().__init__(**kwargs) - self.indent(mapping=2, sequence=4, offset=2) - self.width = 800 - self.preserve_quotes = True - self.default_flow_style = False - - -class ManifestProxy(UserDict): - """Proxy for manifest dictionary (`flat_graph`), if we need mutation then we should - create a copy of the dict or interface with the dbt-core manifest object instead""" - - def _readonly(self, *args, **kwargs): - raise RuntimeError("Cannot modify ManifestProxy") - - __setitem__ = _readonly - __delitem__ = _readonly - pop = _readonly - popitem = _readonly - clear = _readonly - update = _readonly - setdefault = _readonly - - -class DbtAdapterExecutionResult: - """Interface for execution results, this keeps us 1 layer removed from dbt interfaces which may change""" - - def __init__( - self, - adapter_response: AdapterResponse, - table: agate.Table, - raw_sql: str, - compiled_sql: str, - ) -> None: - self.adapter_response = adapter_response - self.table = table - self.raw_sql = raw_sql - self.compiled_sql = compiled_sql - - -class DbtAdapterCompilationResult: - """Interface for compilation results, this keeps us 1 layer removed from dbt interfaces which may change""" - - def __init__( - self, - raw_sql: str, - compiled_sql: str, - node: ManifestNode, - injected_sql: Optional[str] = None, - ) -> None: - self.raw_sql = raw_sql - self.compiled_sql = compiled_sql - self.node = node - self.injected_sql = injected_sql - - -class DbtProject: - """Container for a dbt project. The dbt attribute is the primary interface for - dbt-core. The adapter attribute is the primary interface for the dbt adapter""" - - ADAPTER_TTL = 3600 - - def __init__( - self, - target: Optional[str] = None, - profiles_dir: Optional[str] = None, - project_dir: Optional[str] = None, - threads: Optional[int] = 1, - vars: Optional[str] = "{}", - ): - self.args = ConfigInterface( - threads=threads, - target=target, - profiles_dir=profiles_dir, - project_dir=project_dir, - vars=vars, - ) - - self.parse_project(init=True) - - # Utilities - self._yaml_handler: Optional[YamlHandler] = None - self._sql_parser: Optional[SqlBlockParser] = None - self._macro_parser: Optional[SqlMacroParser] = None - self._sql_runner: Optional[SqlExecuteRunner] = None - self._sql_compiler: Optional[SqlCompileRunner] = None - - # Tracks internal state version - self._version: int = 1 - self.mutex = threading.Lock() if not self.args.single_threaded else None - # atexit.register(lambda dbt_project: dbt_project.adapter.connections.cleanup_all, self) - - def get_adapter(self): - """This inits a new Adapter which is fundamentally different than - the singleton approach in the core lib""" - adapter_name = self.config.credentials.type - return get_adapter_class_by_name(adapter_name)(self.config) - - def init_adapter(self): - """Initialize a dbt adapter.""" - if hasattr(self, "_adapter"): - self._adapter.connections.cleanup_all() - # The setter verifies connection, resets TTL, and updates adapter ref on config - self.adapter = self.get_adapter() - - @property - def adapter(self): - """dbt-core adapter with TTL and automatic reinstantiation""" - if time.time() - self._adapter_ttl > self.ADAPTER_TTL: - logger().info("TTL expired, reinitializing adapter!") - self.init_adapter() - return self._adapter - - @adapter.setter - def adapter(self, adapter: Adapter): - """Verify connection and reset TTL on adapter set, update adapter prop ref on config""" - self._adapter = self._verify_connection(adapter) - self._adapter_ttl = time.time() - self.config.adapter = self.adapter - - def parse_project(self, init: bool = False) -> None: - """Parses project on disk from `ConfigInterface` in args attribute, verifies connection - to adapters database, mutates config, adapter, and dbt attributes""" - if init: - set_from_args(self.args, self.args) - self.config = RuntimeConfig.from_args(self.args) - self.init_adapter() - - project_parser = ManifestLoader( - self.config, - self.config.load_dependencies(), - self.adapter.connections.set_query_header, - ) - # temporarily patched so we write partial parse to correct directory until its fixed in dbt core - project_parser.write_manifest_for_partial_parse = partial( - write_manifest_for_partial_parse, project_parser - ) - # endpatched (https://github.com/dbt-labs/dbt-core/blob/main/core/dbt/parser/manifest.py#L545) - self.dbt = project_parser.load() - self.dbt.build_flat_graph() - project_parser.save_macros_to_adapter(self.adapter) - self._sql_parser = None - self._macro_parser = None - self._sql_compiler = None - self._sql_runner = None - - @classmethod - def from_args(cls, args: ConfigInterface) -> "DbtProject": - """Instatiate the DbtProject directly from a ConfigInterface instance""" - return cls( - target=args.target, - profiles_dir=args.profiles_dir, - project_dir=args.project_dir, - threads=args.threads, - ) - - @property - def yaml_handler(self) -> YamlHandler: - """A YAML handler for loading and dumping yaml files on disk""" - if self._yaml_handler is None: - self._yaml_handler = YamlHandler() - return self._yaml_handler - - @property - def sql_parser(self) -> SqlBlockParser: - """A dbt-core SQL parser capable of parsing and adding nodes to the manifest via `parse_remote` which will - also return the added node to the caller. Note that post-parsing this still typically requires calls to - `_process_nodes_for_ref` and `_process_sources_for_ref` from `dbt.parser.manifest` - """ - if self._sql_parser is None: - self._sql_parser = SqlBlockParser(self.config, self.dbt, self.config) - return self._sql_parser - - @property - def macro_parser(self) -> SqlMacroParser: - """A dbt-core macro parser""" - if self._macro_parser is None: - self._macro_parser = SqlMacroParser(self.config, self.dbt) - return self._macro_parser - - @property - def sql_runner(self) -> SqlExecuteRunner: - """A runner which is used internally by the `execute_sql` function of `dbt.lib`. - The runners `node` attribute can be updated before calling `compile` or `compile_and_execute`. - """ - if self._sql_runner is None: - self._sql_runner = SqlExecuteRunner( - self.config, self.adapter, node=None, node_index=1, num_nodes=1 - ) - return self._sql_runner - - @property - def sql_compiler(self) -> SqlCompileRunner: - """A runner which is used internally by the `compile_sql` function of `dbt.lib`. - The runners `node` attribute can be updated before calling `compile` or `compile_and_execute`. - """ - if self._sql_compiler is None: - self._sql_compiler = SqlCompileRunner( - self.config, self.adapter, node=None, node_index=1, num_nodes=1 - ) - return self._sql_compiler - - def _verify_connection(self, adapter: Adapter) -> Adapter: - """Verification for adapter + profile. Used as a passthrough, - ie: `self.adapter = _verify_connection(get_adapter(...))` - This also seeds the master connection""" - try: - adapter.connections.set_connection_name() - adapter.debug_query() - except Exception as query_exc: - raise RuntimeException("Could not connect to Database") from query_exc - else: - return adapter - - def adapter_probe(self) -> bool: - """Check adapter connection, useful for long running processes such as the server or workbench""" - if not hasattr(self, "adapter") or self.adapter is None: - return False - try: - with self.adapter.connection_named("osmosis-heartbeat"): - self.adapter.debug_query() - except Exception: - # TODO: we can decide to reinit the Adapter here - return False - logger().info("Heartbeat received for %s", self.project_name) - return True - - def fn_threaded_conn( - self, fn: Callable[..., T], *args, **kwargs - ) -> Callable[..., T]: - """Used for jobs which are intended to be submitted to a thread pool, - the 'master' thread should always have an available connection for the duration of - typical program runtime by virtue of the `_verify_connection` method. - Threads however require singleton seeding""" - - def _with_conn() -> T: - self.adapter.connections.set_connection_name() - return fn(*args, **kwargs) - - return _with_conn - - def generate_runtime_model_context(self, node: ManifestNode): - """Wraps dbt context provider""" - return generate_runtime_model_context(node, self.config, self.dbt) - - @property - def project_name(self) -> str: - """dbt project name""" - return self.config.project_name - - @property - def project_root(self) -> str: - """dbt project root""" - return self.config.project_root - - @property - def manifest(self) -> ManifestProxy: - """dbt manifest dict""" - return ManifestProxy(self.dbt.flat_graph) - - def safe_parse_project(self, reinit: bool = False) -> None: - """This is used to reseed the DbtProject safely post-init. This is - intended for use by the osmosis server""" - if reinit: - self.clear_caches() - _config_pointer = copy(self.config) - try: - self.parse_project(init=reinit) - except Exception as parse_error: - self.config = _config_pointer - raise parse_error - self.write_manifest_artifact() - - def write_manifest_artifact(self) -> None: - """Write a manifest.json to disk""" - artifact_path = os.path.join( - self.config.project_root, self.config.target_path, MANIFEST_ARTIFACT - ) - self.dbt.write(artifact_path) - - def clear_caches(self) -> None: - """Clear least recently used caches and reinstantiable container objects""" - self.get_ref_node.cache_clear() - self.get_source_node.cache_clear() - self.get_macro_function.cache_clear() - self.compile_sql.cache_clear() - - @lru_cache(maxsize=10) - def get_ref_node(self, target_model_name: str) -> MaybeNonSource: - """Get a `ManifestNode` from a dbt project model name""" - return self.dbt.resolve_ref( - target_model_name=target_model_name, - target_model_package=None, - current_project=self.config.project_name, - node_package=self.config.project_name, - ) - - @lru_cache(maxsize=10) - def get_source_node( - self, target_source_name: str, target_table_name: str - ) -> MaybeParsedSource: - """Get a `ManifestNode` from a dbt project source name and table name""" - return self.dbt.resolve_source( - target_source_name=target_source_name, - target_table_name=target_table_name, - current_project=self.config.project_name, - node_package=self.config.project_name, - ) - - def get_server_node(self, sql: str, node_name="name"): - """Get a node for SQL execution against adapter""" - self._clear_node(node_name) - sql_node = self.sql_parser.parse_remote(sql, node_name) - process_node(self.config, self.dbt, sql_node) - return sql_node - - @lru_cache(maxsize=10) - def get_node_by_path(self, path: str): - """Find an existing node given relative file path.""" - for node in self.dbt.nodes.values(): - if node.original_file_path == path: - return node - return None - - @lru_cache(maxsize=100) - def get_macro_function(self, macro_name: str) -> Callable[[Dict[str, Any]], Any]: - """Get macro as a function which takes a dict via argument named `kwargs`, - ie: `kwargs={"relation": ...}` - - make_schema_fn = get_macro_function('make_schema')\n - make_schema_fn({'name': '__test_schema_1'})\n - make_schema_fn({'name': '__test_schema_2'})""" - return partial( - self.adapter.execute_macro, macro_name=macro_name, manifest=self.dbt - ) - - def adapter_execute( - self, sql: str, auto_begin: bool = False, fetch: bool = False - ) -> Tuple[AdapterResponse, agate.Table]: - """Wraps adapter.execute. Execute SQL against database""" - return self.adapter.execute(sql, auto_begin, fetch) - - def execute_macro( - self, - macro: str, - kwargs: Optional[Dict[str, Any]] = None, - ) -> Any: - """Wraps adapter execute_macro. Execute a macro like a function.""" - return self.get_macro_function(macro)(kwargs=kwargs) - - def execute_sql(self, raw_sql: str) -> DbtAdapterExecutionResult: - """Execute dbt SQL statement against database""" - # if no jinja chars then these are synonymous - compiled_sql = raw_sql - if has_jinja(raw_sql): - # jinja found, compile it - compiled_sql = self.compile_sql(raw_sql).compiled_sql - return DbtAdapterExecutionResult( - *self.adapter_execute(compiled_sql, fetch=True), - raw_sql, - compiled_sql, - ) - - def execute_node(self, node: ManifestNode) -> DbtAdapterExecutionResult: - """Execute dbt SQL statement against database from a ManifestNode""" - raw_sql: str = getattr(node, RAW_CODE) - compiled_sql: Optional[str] = getattr(node, COMPILED_CODE, None) - if compiled_sql: - # node is compiled, execute the SQL - return self.execute_sql(compiled_sql) - # node not compiled - if has_jinja(raw_sql): - # node has jinja in its SQL, compile it - compiled_sql = self.compile_node(node).compiled_sql - # execute the SQL - return self.execute_sql(compiled_sql or raw_sql) - - @lru_cache(maxsize=SQL_CACHE_SIZE) - def compile_sql(self, raw_sql: str, retry: int = 3) -> DbtAdapterCompilationResult: - """Creates a node with `get_server_node` method. Compile generated node. - Has a retry built in because even uuidv4 cannot gaurantee uniqueness at the speed - in which we can call this function concurrently. A retry significantly increases the stability - """ - temp_node_id = str(uuid.uuid4()) - try: - node = self.compile_node(self.get_server_node(raw_sql, temp_node_id)) - except Exception as exc: - if retry > 0: - return self.compile_sql(raw_sql, retry - 1) - raise exc - else: - return node - finally: - self._clear_node(temp_node_id) - - def compile_node(self, node: ManifestNode) -> DbtAdapterCompilationResult: - """Compiles existing node.""" - self.sql_compiler.node = node - # this is essentially a convenient wrapper to adapter.get_compiler - compiled_node = self.sql_compiler.compile(self.dbt) - return DbtAdapterCompilationResult( - getattr(compiled_node, RAW_CODE), - getattr(compiled_node, COMPILED_CODE), - compiled_node, - ) - - def _clear_node(self, name="name"): - """Removes the statically named node created by `execute_sql` and `compile_sql` in `dbt.lib`""" - self.dbt.nodes.pop(f"{NodeType.SqlOperation}.{self.project_name}.{name}", None) - - def get_relation( - self, database: str, schema: str, name: str - ) -> Optional[BaseRelation]: - """Wrapper for `adapter.get_relation`""" - return self.adapter.get_relation(database, schema, name) - - def create_relation(self, database: str, schema: str, name: str) -> BaseRelation: - """Wrapper for `adapter.Relation.create`""" - return self.adapter.Relation.create(database, schema, name) - - def create_relation_from_node(self, node: ManifestNode) -> BaseRelation: - """Wrapper for `adapter.Relation.create_from`""" - return self.adapter.Relation.create_from(self.config, node) - - def get_columns_in_relation(self, node: ManifestNode) -> List[str]: - """Wrapper for `adapter.get_columns_in_relation`""" - return self.adapter.get_columns_in_relation( - self.create_relation_from_node(node) - ) - - def get_or_create_relation( - self, database: str, schema: str, name: str - ) -> Tuple[BaseRelation, bool]: - """Get relation or create if not exists. Returns tuple of relation and - boolean result of whether it existed ie: (relation, did_exist)""" - ref = self.get_relation(database, schema, name) - return ( - (ref, True) - if ref - else (self.create_relation(database, schema, name), False) - ) - - def create_schema(self, node: ManifestNode): - """Create a schema in the database""" - return self.execute_macro( - "create_schema", - kwargs={"relation": self.create_relation_from_node(node)}, - ) - - def materialize( - self, node: ManifestNode, temporary: bool = True - ) -> Tuple[AdapterResponse, None]: - """Materialize a table in the database""" - return self.adapter_execute( - # Returns CTAS string so send to adapter.execute - self.execute_macro( - "create_table_as", - kwargs={ - "sql": getattr(node, COMPILED_CODE), - "relation": self.create_relation_from_node(node), - "temporary": temporary, - }, - ), - auto_begin=True, - ) - - -class DbtProjectContainer: - """This class manages multiple DbtProjects which each correspond - to a single dbt project on disk. This is mostly for osmosis server use""" - - def __init__(self): - self._projects: Dict[str, DbtProject] = OrderedDict() - self._default_project: Optional[str] = None - - def get_project(self, project_name: str) -> Optional[DbtProject]: - """Primary interface to get a project and execute code""" - return self._projects.get(project_name) - - @lru_cache(maxsize=10) - def get_project_by_root_dir(self, root_dir: str) -> Optional[DbtProject]: - """Get a project by its root directory.""" - root_dir = os.path.abspath(os.path.normpath(root_dir)) - for project in self._projects.values(): - if os.path.abspath(project.project_root) == root_dir: - return project - return None - - def get_default_project(self) -> Optional[DbtProject]: - """Gets the default project which at any given time is the - earliest project inserted into the container""" - return self._projects.get(self._default_project) - - def add_project( - self, - target: Optional[str] = None, - profiles_dir: Optional[str] = None, - project_dir: Optional[str] = None, - threads: Optional[int] = 1, - name_override: Optional[str] = "", - vars: Optional[str] = "{}", - ) -> DbtProject: - """Add a DbtProject with arguments""" - project = DbtProject(target, profiles_dir, project_dir, threads, vars) - project_name = name_override or project.config.project_name - if self._default_project is None: - self._default_project = project_name - self._projects[project_name] = project - return project - - def add_parsed_project(self, project: DbtProject) -> DbtProject: - """Add an already instantiated DbtProject""" - self._projects.setdefault(project.config.project_name, project) - return project - - def add_project_from_args(self, args: ConfigInterface) -> DbtProject: - """Add a DbtProject from a ConfigInterface""" - project = DbtProject.from_args(args) - self._projects.setdefault(project.config.project_name, project) - return project - - def drop_project(self, project_name: str) -> None: - """Drop a DbtProject""" - project = self.get_project(project_name) - if project is None: - return - project.clear_caches() - project.adapter.connections.cleanup_all() - self._projects.pop(project_name) - if self._default_project == project_name: - if len(self) > 0: - self._default_project = self._projects.keys()[0] - else: - self._default_project = None - - def drop_all_projects(self) -> None: - """Drop all DbtProjectContainers""" - self._default_project = None - for project in self._projects: - self.drop_project(project) - - def reparse_all_projects(self) -> None: - """Reparse all projects""" - for project in self: - project.safe_parse_project() - - def registered_projects(self) -> List[str]: - """Convenience to grab all registered project names""" - return list(self._projects.keys()) - - def __len__(self): - """Allows len(DbtProjectContainer)""" - return len(self._projects) - - def __getitem__(self, project: str): - """Allows DbtProjectContainer['jaffle_shop']""" - maybe_project = self.get_project(project) - if maybe_project is None: - raise KeyError(project) - return maybe_project - - def __delitem__(self, project: str): - """Allows del DbtProjectContainer['jaffle_shop']""" - self.drop_project(project) - - def __iter__(self): - """Allows project for project in DbtProjectContainer""" - for project in self._projects: - yield self.get_project(project) - - def __contains__(self, project): - """Allows 'jaffle_shop' in DbtProjectContainer""" - return project in self._projects - - def __repr__(self): - """Canonical string representation of DbtProjectContainer instance""" - return "\n".join( - f"Project: {project.project_name}, Dir: {project.project_root}" - for project in self - ) - - def __call__(self) -> "DbtProjectContainer": - """This allows the object to be used as a callable, primarily for FastAPI dependency injection - ```python - dbt_project_container = DbtProjectContainer() - def register(x_dbt_project: str = Header(default=None)): - dbt_project_container.add_project(...) - def compile(x_dbt_project: str = Header(default=None), dbt = Depends(dbt_project_container), request: fastapi.Request): - query = request.body() - dbt.get_project(x_dbt_project).compile(query) - ``` - """ - return self - - -DbtOsmosis = DbtProject - - -class SchemaFileOrganizationPattern(str, Enum): - SchemaYaml = "schema.yml" - FolderYaml = "folder.yml" - ModelYaml = "model.yml" - UnderscoreModelYaml = "_model.yml" - SchemaModelYaml = "schema/model.yml" - - -class SchemaFileLocation(BaseModel): - target: Path - current: Optional[Path] = None - - @property - def is_valid(self) -> bool: - return self.current == self.target - - -class SchemaFileMigration(BaseModel): - output: Dict[str, Any] = {} - supersede: Dict[Path, List[str]] = {} - - -class DbtYamlManager(DbtProject): - """The DbtYamlManager class handles developer automation tasks surrounding - schema yaml files organziation, documentation, and coverage.""" - - audit_report = """ - :white_check_mark: [bold]Audit Report[/bold] - ------------------------------- - - Database: [bold green]{database}[/bold green] - Schema: [bold green]{schema}[/bold green] - Table: [bold green]{table}[/bold green] - - Total Columns in Database: {total_columns} - Total Documentation Coverage: {coverage}% - - Action Log: - Columns Added to dbt: {n_cols_added} - Column Knowledge Inherited: {n_cols_doc_inherited} - Extra Columns Removed: {n_cols_removed} - """ - - # TODO: Let user supply a custom arg / config file / csv of strings which we - # consider placeholders which are not valid documentation, these are just my own - # We may well drop the placeholder concept too - placeholders = [ - "Pending further documentation", - "Pending further documentation.", - "No description for this column", - "No description for this column.", - "Not documented", - "Not documented.", - "Undefined", - "Undefined.", - "", - ] - - def __init__( - self, - target: Optional[str] = None, - profiles_dir: Optional[str] = None, - project_dir: Optional[str] = None, - threads: Optional[int] = 1, - fqn: Optional[str] = None, - dry_run: bool = False, - ): - super().__init__(target, profiles_dir, project_dir, threads) - self.fqn = fqn - self.dry_run = dry_run - - def _filter_model(self, node: ManifestNode) -> bool: - """Validates a node as being actionable. Validates both models and sources.""" - fqn = self.fqn or ".".join(node.fqn[1:]) - fqn_parts = fqn.split(".") - logger().debug("%s: %s -> %s", node.resource_type, fqn, node.fqn[1:]) - return ( - # Verify Resource Type - node.resource_type in (NodeType.Model, NodeType.Source) - # Verify Package == Current Project - and node.package_name == self.project_name - # Verify Materialized is Not Ephemeral if NodeType is Model [via short-circuit] - and ( - node.resource_type != NodeType.Model - or node.config.materialized != "ephemeral" - ) - # Verify FQN Length [Always true if no fqn was supplied] - and len(node.fqn[1:]) >= len(fqn_parts) - # Verify FQN Matches Parts [Always true if no fqn was supplied] - and all(left == right for left, right in zip(fqn_parts, node.fqn[1:])) - ) - - @staticmethod - def get_patch_path(node: ManifestNode) -> Optional[Path]: - if node is not None and node.patch_path: - return Path(node.patch_path.split("://")[-1]) - - def filtered_models( - self, subset: Optional[MutableMapping[str, ManifestNode]] = None - ) -> Iterator[Tuple[str, ManifestNode]]: - """Generates an iterator of valid models""" - for unique_id, dbt_node in ( - subset.items() - if subset - else chain(self.dbt.nodes.items(), self.dbt.sources.items()) - ): - if self._filter_model(dbt_node): - yield unique_id, dbt_node - - def get_osmosis_config( - self, node: ManifestNode - ) -> Optional[SchemaFileOrganizationPattern]: - """Validates a config string. If input is a source, we return the resource type str instead""" - if node.resource_type == NodeType.Source: - return None - osmosis_config = node.config.get("dbt-osmosis") - if not osmosis_config: - raise MissingOsmosisConfig( - f"Config not set for model {node.name}, we recommend setting the config at a directory level through the `dbt_project.yml`" - ) - try: - return SchemaFileOrganizationPattern(osmosis_config) - except ValueError as exc: - raise InvalidOsmosisConfig( - f"Invalid config for model {node.name}: {osmosis_config}" - ) from exc - - def get_schema_path(self, node: ManifestNode) -> Optional[Path]: - """Resolve absolute schema file path for a manifest node""" - schema_path = None - if node.resource_type == NodeType.Model and node.patch_path: - schema_path: str = node.patch_path.partition("://")[-1] - elif node.resource_type == NodeType.Source: - if hasattr(node, "source_name"): - schema_path: str = node.path - if schema_path: - return Path(self.project_root).joinpath(schema_path) - - def get_target_schema_path(self, node: ManifestNode) -> Path: - """Resolve the correct schema yml target based on the dbt-osmosis config for the model / directory""" - osmosis_config = self.get_osmosis_config(node) - if not osmosis_config: - return Path(node.root_path, node.original_file_path) - # Here we resolve file migration targets based on the config - if osmosis_config == SchemaFileOrganizationPattern.SchemaYaml: - schema = "schema" - elif osmosis_config == SchemaFileOrganizationPattern.FolderYaml: - schema = node.fqn[-2] - elif osmosis_config == SchemaFileOrganizationPattern.ModelYaml: - schema = node.name - elif osmosis_config == SchemaFileOrganizationPattern.SchemaModelYaml: - schema = "schema/" + node.name - elif osmosis_config == SchemaFileOrganizationPattern.UnderscoreModelYaml: - schema = "_" + node.name - else: - raise InvalidOsmosisConfig( - f"Invalid dbt-osmosis config for model: {node.fqn}" - ) - return Path(node.root_path, node.original_file_path).parent / Path( - f"{schema}.yml" - ) - - @staticmethod - def get_database_parts(node: ManifestNode) -> Tuple[str, str, str]: - return node.database, node.schema, getattr(node, "alias", node.name) - - def bootstrap_existing_model( - self, model_documentation: Dict[str, Any], node: ManifestNode - ) -> Dict[str, Any]: - """Injects columns from database into existing model if not found""" - model_columns: List[str] = [ - c["name"].lower() for c in model_documentation.get("columns", []) - ] - database_columns = self.get_columns(node) - for column in database_columns: - if column.lower() not in model_columns: - logger().info(":syringe: Injecting column %s into dbt schema", column) - model_documentation.setdefault("columns", []).append({"name": column}) - return model_documentation - - def get_columns(self, node: ManifestNode) -> List[str]: - """Get all columns in a list for a model""" - parts = self.get_database_parts(node) - table = self.adapter.get_relation(*parts) - columns = [] - if not table: - logger().info( - ":cross_mark: Relation %s.%s.%s does not exist in target database, cannot resolve columns", - *parts, - ) - return columns - try: - columns = [c.name for c in self.adapter.get_columns_in_relation(table)] - except CompilationException as error: - logger().info( - ":cross_mark: Could not resolve relation %s.%s.%s against database active tables during introspective query: %s", - *parts, - str(error), - ) - return columns - - @staticmethod - def assert_schema_has_no_sources(schema: Mapping) -> Mapping: - """Inline assertion ensuring that a schema does not have a source key""" - if schema.get("sources"): - raise SanitizationRequired( - "Found `sources:` block in a models schema file. We require you separate sources in order to organize your project." - ) - return schema - - def build_schema_folder_mapping( - self, - target_node_type: Optional[Union[NodeType.Model, NodeType.Source]] = None, - ) -> Dict[str, SchemaFileLocation]: - """Builds a mapping of models or sources to their existing and target schema file paths""" - if target_node_type == NodeType.Source: - # Source folder mapping is reserved for source importing - target_nodes = self.dbt.sources - elif target_node_type == NodeType.Model: - target_nodes = self.dbt.nodes - else: - target_nodes = {**self.dbt.nodes, **self.dbt.sources} - # Container for output - schema_map = {} - logger().info("...building project structure mapping in memory") - # Iterate over models and resolve current path vs declarative target path - for unique_id, dbt_node in self.filtered_models(target_nodes): - schema_path = self.get_schema_path(dbt_node) - osmosis_schema_path = self.get_target_schema_path(dbt_node) - schema_map[unique_id] = SchemaFileLocation( - target=osmosis_schema_path, current=schema_path - ) - return schema_map - - def draft_project_structure_update_plan(self) -> Dict[Path, SchemaFileMigration]: - """Build project structure update plan based on `dbt-osmosis:` configs set across dbt_project.yml and model files. - The update plan includes injection of undocumented models. Unless this plan is constructed and executed by the `commit_project_restructure` function, - dbt-osmosis will only operate on models it is aware of through the existing documentation. - - Returns: - MutableMapping: Update plan where dict keys consist of targets and contents consist of outputs which match the contents of the `models` to be output in the - target file and supersede lists of what files are superseded by a migration - """ - - # Container for output - blueprint: Dict[Path, SchemaFileMigration] = {} - logger().info( - ":chart_increasing: Searching project stucture for required updates and building action plan" - ) - with self.adapter.connection_named("dbt-osmosis"): - for unique_id, schema_file in self.build_schema_folder_mapping( - target_node_type=NodeType.Model - ).items(): - if not schema_file.is_valid: - blueprint.setdefault( - schema_file.target, - SchemaFileMigration( - output={"version": 2, "models": []}, supersede={} - ), - ) - node = self.dbt.nodes[unique_id] - if schema_file.current is None: - # Bootstrapping Undocumented Model - blueprint[schema_file.target].output["models"].append( - self.get_base_model(node) - ) - else: - # Model Is Documented but Must be Migrated - if not schema_file.current.exists(): - continue - # TODO: We avoid sources for complexity reasons but if we are opinionated, we don't have to - schema = self.assert_schema_has_no_sources( - self.yaml_handler.load(schema_file.current) - ) - models_in_file: Iterable[Dict[str, Any]] = schema.get( - "models", [] - ) - for documented_model in models_in_file: - if documented_model["name"] == node.name: - # Bootstrapping Documented Model - blueprint[schema_file.target].output["models"].append( - self.bootstrap_existing_model( - documented_model, node - ) - ) - # Target to supersede current - blueprint[schema_file.target].supersede.setdefault( - schema_file.current, [] - ).append(documented_model["name"]) - break - else: - ... # Model not found at patch path -- We should pass on this for now - else: - ... # Valid schema file found for model -- We will update the columns in the `Document` task - - return blueprint - - def commit_project_restructure_to_disk( - self, blueprint: Optional[Dict[Path, SchemaFileMigration]] = None - ) -> bool: - """Given a project restrucure plan of pathlib Paths to a mapping of output and supersedes which is in itself a mapping of Paths to model names, - commit changes to filesystem to conform project to defined structure as code fully or partially superseding existing models as needed. - - Args: - blueprint (Dict[Path, SchemaFileMigration]): Project restructure plan as typically created by `build_project_structure_update_plan` - - Returns: - bool: True if the project was restructured, False if no action was required - """ - - # Build blueprint if not user supplied - if not blueprint: - blueprint = self.draft_project_structure_update_plan() - - # Verify we have actions in the plan - if not blueprint: - logger().info(":1st_place_medal: Project structure approved") - return False - - # Print plan for user auditability - self.pretty_print_restructure_plan(blueprint) - - logger().info( - ":construction_worker: Executing action plan and conforming projecting schemas to defined structure" - ) - for target, structure in blueprint.items(): - if not target.exists(): - # Build File - logger().info(":construction: Building schema file %s", target.name) - if not self.dry_run: - target.parent.mkdir(exist_ok=True, parents=True) - target.touch() - self.yaml_handler.dump(structure.output, target) - - else: - # Update File - logger().info(":toolbox: Updating schema file %s", target.name) - target_schema: Optional[Dict[str, Any]] = self.yaml_handler.load(target) - if not target_schema: - target_schema = {"version": 2} - elif "version" not in target_schema: - target_schema["version"] = 2 - target_schema.setdefault("models", []).extend( - structure.output["models"] - ) - if not self.dry_run: - self.yaml_handler.dump(target_schema, target) - - # Clean superseded schema files - for dir, models in structure.supersede.items(): - preserved_models = [] - raw_schema: Dict[str, Any] = self.yaml_handler.load(dir) - models_marked_for_superseding = set(models) - models_in_schema = set( - map(lambda mdl: mdl["name"], raw_schema.get("models", [])) - ) - non_superseded_models = models_in_schema - models_marked_for_superseding - if len(non_superseded_models) == 0: - logger().info(":rocket: Superseded schema file %s", dir.name) - if not self.dry_run: - dir.unlink(missing_ok=True) - else: - for model in raw_schema["models"]: - if model["name"] in non_superseded_models: - preserved_models.append(model) - raw_schema["models"] = preserved_models - if not self.dry_run: - self.yaml_handler.dump(raw_schema, dir) - logger().info( - ":satellite: Model documentation migrated from %s to %s", - dir.name, - target.name, - ) - - return True - - @staticmethod - def pretty_print_restructure_plan( - blueprint: Dict[Path, SchemaFileMigration] - ) -> None: - logger().info( - list( - map( - lambda plan: (blueprint[plan].supersede or "CREATE", "->", plan), - blueprint.keys(), - ) - ) - ) - - def build_node_ancestor_tree( - self, - node: ManifestNode, - family_tree: Optional[Dict[str, List[str]]] = None, - members_found: Optional[List[str]] = None, - depth: int = 0, - ) -> Dict[str, List[str]]: - """Recursively build dictionary of parents in generational order""" - if family_tree is None: - family_tree = {} - if members_found is None: - members_found = [] - for parent in node.depends_on.nodes: - member = self.dbt.nodes.get(parent, self.dbt.sources.get(parent)) - if member and parent not in members_found: - family_tree.setdefault(f"generation_{depth}", []).append(parent) - members_found.append(parent) - # Recursion - family_tree = self.build_node_ancestor_tree( - member, family_tree, members_found, depth + 1 - ) - return family_tree - - def inherit_column_level_knowledge( - self, - family_tree: Dict[str, Any], - ) -> Dict[str, Dict[str, Any]]: - """Inherit knowledge from ancestors in reverse insertion order to ensure that the most recent ancestor is always the one to inherit from""" - knowledge: Dict[str, Dict[str, Any]] = {} - for generation in reversed(family_tree): - for ancestor in family_tree[generation]: - member: ManifestNode = self.dbt.nodes.get( - ancestor, self.dbt.sources.get(ancestor) - ) - if not member: - continue - for name, info in member.columns.items(): - knowledge.setdefault(name, {"progenitor": ancestor}) - deserialized_info = info.to_dict() - # Handle Info: - # 1. tags are additive - # 2. descriptions are overriden - # 3. meta is merged - # 4. tests are ignored until I am convinced those shouldn't be hand curated with love - if deserialized_info["description"] in self.placeholders: - deserialized_info.pop("description", None) - deserialized_info["tags"] = list( - set( - deserialized_info.pop("tags", []) - + knowledge[name].get("tags", []) - ) - ) - if not deserialized_info["tags"]: - deserialized_info.pop("tags") # poppin' tags like Macklemore - deserialized_info["meta"] = { - **knowledge[name].get("meta", {}), - **deserialized_info["meta"], - } - if not deserialized_info["meta"]: - deserialized_info.pop("meta") - knowledge[name].update(deserialized_info) - return knowledge - - def get_node_columns_with_inherited_knowledge( - self, - node: ManifestNode, - ) -> Dict[str, Dict[str, Any]]: - """Build a knowledgebase for the model based on iterating through ancestors""" - family_tree = self.build_node_ancestor_tree(node) - knowledge = self.inherit_column_level_knowledge(family_tree) - return knowledge - - @staticmethod - def get_column_sets( - database_columns: Iterable[str], - yaml_columns: Iterable[str], - documented_columns: Iterable[str], - ) -> Tuple[List[str], List[str], List[str]]: - """Returns: - missing_columns: Columns in database not in dbt -- will be injected into schema file - undocumented_columns: Columns missing documentation -- descriptions will be inherited and injected into schema file where prior knowledge exists - extra_columns: Columns in schema file not in database -- will be removed from schema file - """ - missing_columns = [ - x - for x in database_columns - if x.lower() not in (y.lower() for y in yaml_columns) - ] - undocumented_columns = [ - x - for x in database_columns - if x.lower() not in (y.lower() for y in documented_columns) - ] - extra_columns = [ - x - for x in yaml_columns - if x.lower() not in (y.lower() for y in database_columns) - ] - return missing_columns, undocumented_columns, extra_columns - - def propagate_documentation_downstream( - self, force_inheritance: bool = False - ) -> None: - schema_map = self.build_schema_folder_mapping() - with self.adapter.connection_named("dbt-osmosis"): - for unique_id, node in track(list(self.filtered_models())): - logger().info( - "\n:point_right: Processing model: [bold]%s[/bold] \n", unique_id - ) - # Get schema file path, must exist to propagate documentation - schema_path: Optional[SchemaFileLocation] = schema_map.get(unique_id) - if schema_path is None or schema_path.current is None: - logger().info( - ":bow: No valid schema file found for model %s", unique_id - ) # We can't take action - continue - - # Build Sets - database_columns: Set[str] = set(self.get_columns(node)) - yaml_columns: Set[str] = set(column for column in node.columns) - - if not database_columns: - logger().info( - ":safety_vest: Unable to resolve columns in database, falling back to using yaml columns as base column set\n" - ) - database_columns = yaml_columns - - # Get documentated columns - documented_columns: Set[str] = set( - column - for column, info in node.columns.items() - if info.description and info.description not in self.placeholders - ) - - # Queue - ( - missing_columns, - undocumented_columns, - extra_columns, - ) = self.get_column_sets( - database_columns, yaml_columns, documented_columns - ) - - if force_inheritance: - # Consider all columns "undocumented" so that inheritance is not selective - undocumented_columns = database_columns - - # Engage - n_cols_added = 0 - n_cols_doc_inherited = 0 - n_cols_removed = 0 - if ( - len(missing_columns) > 0 - or len(undocumented_columns) - or len(extra_columns) > 0 - ): - schema_file = self.yaml_handler.load(schema_path.current) - ( - n_cols_added, - n_cols_doc_inherited, - n_cols_removed, - ) = self.update_schema_file_and_node( - missing_columns, - undocumented_columns, - extra_columns, - node, - schema_file, - ) - if n_cols_added + n_cols_doc_inherited + n_cols_removed > 0: - # Dump the mutated schema file back to the disk - if not self.dry_run: - self.yaml_handler.dump(schema_file, schema_path.current) - logger().info(":sparkles: Schema file updated") - - # Print Audit Report - n_cols = float(len(database_columns)) - n_cols_documented = ( - float(len(documented_columns)) + n_cols_doc_inherited - ) - perc_coverage = ( - min(100.0 * round(n_cols_documented / n_cols, 3), 100.0) - if n_cols > 0 - else "Unable to Determine" - ) - logger().info( - self.audit_report.format( - database=node.database, - schema=node.schema, - table=node.name, - total_columns=n_cols, - n_cols_added=n_cols_added, - n_cols_doc_inherited=n_cols_doc_inherited, - n_cols_removed=n_cols_removed, - coverage=perc_coverage, - ) - ) - - @staticmethod - def remove_columns_not_in_database( - extra_columns: Iterable[str], - node: ManifestNode, - yaml_file_model_section: Dict[str, Any], - ) -> int: - """Removes columns found in dbt model that do not exist in database from both node and model simultaneously - THIS MUTATES THE NODE AND MODEL OBJECTS so that state is always accurate""" - changes_committed = 0 - for column in extra_columns: - node.columns.pop(column, None) - yaml_file_model_section["columns"] = [ - c for c in yaml_file_model_section["columns"] if c["name"] != column - ] - changes_committed += 1 - logger().info(":wrench: Removing column %s from dbt schema", column) - return changes_committed - - def update_schema_file_and_node( - self, - missing_columns: Iterable[str], - undocumented_columns: Iterable[str], - extra_columns: Iterable[str], - node: ManifestNode, - yaml_file: Dict[str, Any], - ) -> Tuple[int, int, int]: - """Take action on a schema file mirroring changes in the node.""" - # We can extrapolate this to a general func - noop = 0, 0, 0 - if node.resource_type == NodeType.Source: - KEY = "tables" - yaml_file_models = None - for src in yaml_file.get("sources", []): - if src["name"] == node.source_name: - # Scope our pointer to a specific portion of the object - yaml_file_models = src - else: - KEY = "models" - yaml_file_models = yaml_file - if yaml_file_models is None: - return noop - for yaml_file_model_section in yaml_file_models[KEY]: - if yaml_file_model_section["name"] == node.name: - logger().info(":microscope: Looking for actions") - n_cols_added = self.add_missing_cols_to_node_and_model( - missing_columns, node, yaml_file_model_section - ) - n_cols_doc_inherited = ( - self.update_undocumented_columns_with_prior_knowledge( - undocumented_columns, node, yaml_file_model_section - ) - ) - n_cols_removed = self.remove_columns_not_in_database( - extra_columns, node, yaml_file_model_section - ) - return n_cols_added, n_cols_doc_inherited, n_cols_removed - logger().info(":thumbs_up: No actions needed") - return noop diff --git a/plugins/sqlfluff-templater-dbt/sqlfluff_templater_dbt/osmosis/exceptions.py b/plugins/sqlfluff-templater-dbt/sqlfluff_templater_dbt/osmosis/exceptions.py deleted file mode 100644 index dadfbe03bcd..00000000000 --- a/plugins/sqlfluff-templater-dbt/sqlfluff_templater_dbt/osmosis/exceptions.py +++ /dev/null @@ -1,14 +0,0 @@ -class InvalidOsmosisConfig(Exception): - pass - - -class MissingOsmosisConfig(Exception): - pass - - -class MissingArgument(Exception): - pass - - -class SanitizationRequired(Exception): - pass diff --git a/plugins/sqlfluff-templater-dbt/sqlfluff_templater_dbt/osmosis/log_controller.py b/plugins/sqlfluff-templater-dbt/sqlfluff_templater_dbt/osmosis/log_controller.py deleted file mode 100644 index 22fa70428e7..00000000000 --- a/plugins/sqlfluff-templater-dbt/sqlfluff_templater_dbt/osmosis/log_controller.py +++ /dev/null @@ -1,75 +0,0 @@ -import logging -from functools import lru_cache -from logging.handlers import RotatingFileHandler -from pathlib import Path -from typing import Optional, Union - -from rich.logging import RichHandler - -# Log File Format -LOG_FILE_FORMAT = "%(asctime)s — %(name)s — %(levelname)s — %(message)s" - -# Log File Path -LOG_PATH = Path.home().absolute() / ".dbt-osmosis" / "logs" - -# Console Output Level -LOGGING_LEVEL = logging.INFO - - -def rotating_log_handler( - name: str, - path: Path, - formatter: str, -) -> RotatingFileHandler: - """This handler writes warning and higher level outputs to logs in a home .dbt-osmosis directory rotating them as needed""" - path.mkdir(parents=True, exist_ok=True) - handler = RotatingFileHandler( - str(path / "{log_name}.log".format(log_name=name)), - maxBytes=int(1e6), - backupCount=3, - ) - handler.setFormatter(logging.Formatter(formatter)) - handler.setLevel(logging.WARNING) - return handler - - -@lru_cache(maxsize=10) -def logger( - name: str = "dbt-osmosis", - level: Optional[Union[int, str]] = None, - path: Optional[Path] = None, - formatter: Optional[str] = None, -) -> logging.Logger: - """Builds and caches loggers. Can be configured with module level attributes or on a call by call basis. - Simplifies logger management without having to instantiate separate pointers in each module. - - Args: - name (str, optional): Logger name, also used for output log file name in `~/.dbt-osmosis/logs` directory. - level (Union[int, str], optional): Logging level, this is explicitly passed to console handler which effects what level of log messages make it to the console. Defaults to logging.INFO. - path (Path, optional): Path for output warning level+ log files. Defaults to `~/.dbt-osmosis/logs` - formatter (str, optional): Format for output log files. Defaults to a "time — name — level — message" format - - Returns: - logging.Logger: Prepared logger with rotating logs and console streaming. Can be executed directly from function. - """ - if isinstance(level, str): - level = getattr(logging, level, logging.INFO) - if level is None: - level = LOGGING_LEVEL - if path is None: - path = LOG_PATH - if formatter is None: - formatter = LOG_FILE_FORMAT - _logger = logging.getLogger(name) - _logger.setLevel(level) - _logger.addHandler(rotating_log_handler(name, path, formatter)) - _logger.addHandler( - RichHandler( - level=level, - rich_tracebacks=True, - markup=True, - show_time=False, - ) - ) - _logger.propagate = False - return _logger diff --git a/plugins/sqlfluff-templater-dbt/sqlfluff_templater_dbt/osmosis/patch.py b/plugins/sqlfluff-templater-dbt/sqlfluff_templater_dbt/osmosis/patch.py deleted file mode 100644 index 4ffce4adf8f..00000000000 --- a/plugins/sqlfluff-templater-dbt/sqlfluff_templater_dbt/osmosis/patch.py +++ /dev/null @@ -1,35 +0,0 @@ -"""A temporary patch for a method in dbt manifest loader we apply in osmosis.py - -""" - -import os - -from dbt.clients.system import make_directory -from dbt.events.functions import fire_event -from dbt.events.types import ManifestWrongMetadataVersion -from dbt.parser.manifest import PARTIAL_PARSE_FILE_NAME -from dbt.version import __version__ - - -def write_manifest_for_partial_parse(self): - # Patched this 👇 - # path = os.path.join(self.root_project.target_path, PARTIAL_PARSE_FILE_NAME) - path = os.path.join( - self.root_project.project_root, - self.root_project.target_path, - PARTIAL_PARSE_FILE_NAME, - ) - try: - # This shouldn't be necessary, but we have gotten bug reports (#3757) of the - # saved manifest not matching the code version. - if self.manifest.metadata.dbt_version != __version__: - fire_event( - ManifestWrongMetadataVersion(version=self.manifest.metadata.dbt_version) - ) - self.manifest.metadata.dbt_version = __version__ - manifest_msgpack = self.manifest.to_msgpack() - make_directory(os.path.dirname(path)) - with open(path, "wb") as fp: - fp.write(manifest_msgpack) - except Exception: - raise diff --git a/plugins/sqlfluff-templater-dbt/sqlfluff_templater_dbt/templater.py b/plugins/sqlfluff-templater-dbt/sqlfluff_templater_dbt/templater.py index 2f45c4e88ae..a192e826df8 100644 --- a/plugins/sqlfluff-templater-dbt/sqlfluff_templater_dbt/templater.py +++ b/plugins/sqlfluff-templater-dbt/sqlfluff_templater_dbt/templater.py @@ -1,41 +1,62 @@ -"""Defines the dbt templater (aka 'sqlfluff-templater-dbt' package). +"""Defines the templaters.""" -Parts of this file are based on dbt-osmosis' dbt templater. -(https://github.com/z3z1ma/dbt-osmosis/blob/main/src/dbt_osmosis/dbt_templater/templater.py) -That project uses the Apache 2.0 license: https://www.apache.org/licenses/LICENSE-2.0 -""" -import logging +from collections import deque +from contextlib import contextmanager +import os import os.path -from pathlib import Path -from typing import Iterator, List, Optional +import logging +from typing import List, Optional, Iterator, Tuple, Any, Dict, Deque + +from dataclasses import dataclass -from dbt.clients import jinja +from dbt.version import get_installed_version +from dbt.config import read_user_config +from dbt.config.runtime import RuntimeConfig as DbtRuntimeConfig +from dbt.adapters.factory import register_adapter, get_adapter +from dbt.compilation import Compiler as DbtCompiler from dbt.exceptions import ( - RuntimeException as DbtRuntimeException, + CompilationException as DbtCompilationException, + FailedToConnectException as DbtFailedToConnectException, ) -from dbt.flags import PROFILES_DIR -from dbt.version import get_installed_version +from dbt import flags +from jinja2 import Environment from jinja2_simple_tags import StandaloneTag from sqlfluff.cli.formatters import OutputStreamFormatter from sqlfluff.core import FluffConfig +from sqlfluff.core.cached_property import cached_property from sqlfluff.core.errors import SQLTemplaterError, SQLFluffSkipFile + from sqlfluff.core.templaters.base import TemplatedFile, large_file_check -from sqlfluff.core.templaters.jinja import JinjaTemplater -from sqlfluff_templater_dbt.osmosis import DbtProjectContainer +from sqlfluff.core.templaters.jinja import JinjaTemplater # Instantiate the templater logger templater_logger = logging.getLogger("sqlfluff.templater") + DBT_VERSION = get_installed_version() DBT_VERSION_STRING = DBT_VERSION.to_version_string() DBT_VERSION_TUPLE = (int(DBT_VERSION.major), int(DBT_VERSION.minor)) -COMPILED_SQL_ATTRIBUTE = ( - "compiled_code" if DBT_VERSION_TUPLE >= (1, 3) else "compiled_sql" -) -RAW_SQL_ATTRIBUTE = "raw_code" if DBT_VERSION_TUPLE >= (1, 3) else "raw_sql" +if DBT_VERSION_TUPLE >= (1, 3): + COMPILED_SQL_ATTRIBUTE = "compiled_code" + RAW_SQL_ATTRIBUTE = "raw_code" +else: # pragma: no cover + COMPILED_SQL_ATTRIBUTE = "compiled_sql" + RAW_SQL_ATTRIBUTE = "raw_sql" + + +@dataclass +class DbtConfigArgs: + """Arguments to load dbt runtime config.""" + + project_dir: Optional[str] = None + profiles_dir: Optional[str] = None + profile: Optional[str] = None + target: Optional[str] = None + single_threaded: bool = False + vars: str = "" class DbtTemplater(JinjaTemplater): @@ -43,18 +64,114 @@ class DbtTemplater(JinjaTemplater): name = "dbt" sequential_fail_limit = 3 + adapters = {} def __init__(self, **kwargs): self.sqlfluff_config = None self.formatter = None self.project_dir = None self.profiles_dir = None + self.working_dir = os.getcwd() self._sequential_fails = 0 - self.dbt_project_container: DbtProjectContainer = kwargs.pop( - "dbt_project_container" - ) super().__init__(**kwargs) + def config_pairs(self): # pragma: no cover TODO? + """Returns info about the given templater for output by the cli.""" + return [("templater", self.name), ("dbt", self.dbt_version)] + + @property + def dbt_version(self): # pragma: no cover + """Gets the dbt version.""" + return DBT_VERSION_STRING + + @cached_property + def dbt_config(self): + """Loads the dbt config.""" + # Here, we read flags.PROFILE_DIR directly, prior to calling + # set_from_args(). Apparently, set_from_args() sets PROFILES_DIR + # to a lowercase version of the value, and the profile wouldn't be + # found if the directory name contained uppercase letters. This fix + # was suggested and described here: + # https://github.com/sqlfluff/sqlfluff/issues/2253#issuecomment-1018722979 + user_config = read_user_config(flags.PROFILES_DIR) + flags.set_from_args( + DbtConfigArgs( + project_dir=self.project_dir, + profiles_dir=self.profiles_dir, + profile=self._get_profile(), + vars=self._get_cli_vars(), + ), + user_config, + ) + self.dbt_config = DbtRuntimeConfig.from_args( + DbtConfigArgs( + project_dir=self.project_dir, + profiles_dir=self.profiles_dir, + profile=self._get_profile(), + target=self._get_target(), + vars=self._get_cli_vars(), + ) + ) + register_adapter(self.dbt_config) + return self.dbt_config + + @cached_property + def dbt_compiler(self): + """Loads the dbt compiler.""" + self.dbt_compiler = DbtCompiler(self.dbt_config) + return self.dbt_compiler + + @cached_property + def dbt_manifest(self): + """Loads the dbt manifest.""" + # Set dbt not to run tracking. We don't load + # a full project and so some tracking routines + # may fail. + from dbt.tracking import do_not_track + + do_not_track() + + # dbt 0.20.* and onward + from dbt.parser.manifest import ManifestLoader + + old_cwd = os.getcwd() + try: + # Changing cwd temporarily as dbt is not using project_dir to + # read/write `target/partial_parse.msgpack`. This can be undone when + # https://github.com/dbt-labs/dbt-core/issues/6055 is solved. + os.chdir(self.project_dir) + self.dbt_manifest = ManifestLoader.get_full_manifest(self.dbt_config) + finally: + os.chdir(old_cwd) + return self.dbt_manifest + + @cached_property + def dbt_selector_method(self): + """Loads the dbt selector method.""" + if self.formatter: # pragma: no cover TODO? + self.formatter.dispatch_compilation_header( + "dbt templater", "Compiling dbt project..." + ) + + from dbt.graph.selector_methods import ( + MethodManager as DbtSelectorMethodManager, + MethodName as DbtMethodName, + ) + + selector_methods_manager = DbtSelectorMethodManager( + self.dbt_manifest, previous_state=None + ) + self.dbt_selector_method = selector_methods_manager.get_method( + DbtMethodName.Path, method_arguments=[] + ) + + if self.formatter: # pragma: no cover TODO? + self.formatter.dispatch_compilation_header( + "dbt templater", "Project Compiled." + ) + + return self.dbt_selector_method + def _get_profiles_dir(self): """Get the dbt profiles directory from the configuration. @@ -69,7 +186,7 @@ def _get_profiles_dir(self): self.sqlfluff_config.get_section( (self.templater_selector, self.name, "profiles_dir") ) - or PROFILES_DIR + or flags.PROFILES_DIR ) ) @@ -137,38 +254,73 @@ def sequence_files( self.project_dir = self._get_project_dir() if not self.profiles_dir: self.profiles_dir = self._get_profiles_dir() - yield from super().sequence_files(fnames, config, formatter) - def config_pairs(self): # pragma: no cover - """Returns info about the given templater for output by the cli.""" - return [ - ("templater", self.name), - ("dbt", get_installed_version().to_version_string()), - ] - - def _find_node(self, project, fname): - expected_node_path = os.path.relpath( - fname, start=os.path.abspath(project.args.project_dir) - ) - node = project.get_node_by_path(expected_node_path) - if node: - return node - skip_reason = self._find_skip_reason(project, expected_node_path) - if skip_reason: - raise SQLFluffSkipFile(f"Skipped file {fname} because it is {skip_reason}") - raise SQLFluffSkipFile( - f"File {fname} was not found in dbt project" - ) # pragma: no cover + # Populate full paths for selected files + full_paths: Dict[str, str] = {} + selected_files = set() + for fname in fnames: + fpath = os.path.join(self.working_dir, fname) + full_paths[fpath] = fname + selected_files.add(fpath) + + ephemeral_nodes: Dict[str, Tuple[str, Any]] = {} + + # Extract the ephemeral models + for key, node in self.dbt_manifest.nodes.items(): + if node.config.materialized == "ephemeral": + # The key is the full filepath. + # The value tuple, with the filepath and a list of dependent keys + ephemeral_nodes[key] = ( + os.path.join(self.project_dir, node.original_file_path), + node.depends_on.nodes, + ) + + # Yield ephemeral nodes first. We use a deque for efficient re-queuing. + # We iterate through the deque, yielding any nodes without dependents, + # or where those dependents have already yielded, first. The original + # mapping is still used to hold the metadata on each key. + already_yielded = set() + ephemeral_buffer: Deque[str] = deque(ephemeral_nodes.keys()) + while ephemeral_buffer: + key = ephemeral_buffer.popleft() + fpath, dependents = ephemeral_nodes[key] + + # If it's not in our selection, skip it + if fpath not in selected_files: + templater_logger.debug("- Purging unselected ephemeral: %r", fpath) + # If there are dependent nodes in the set, don't process it yet. + elif any( + dependent in ephemeral_buffer for dependent in dependents + ): # pragma: no cover + templater_logger.debug( + "- Requeuing ephemeral with dependents: %r", fpath + ) + # Requeue it for later + ephemeral_buffer.append(key) + # Otherwise yield it. + else: + templater_logger.debug("- Yielding Ephemeral: %r", fpath) + yield full_paths[fpath] + already_yielded.add(full_paths[fpath]) + + for fname in fnames: + if fname not in already_yielded: + yield fname + # Dedupe here so we don't yield twice + already_yielded.add(fname) + else: + templater_logger.debug( + "- Skipping yield of previously sequenced file: %r", fname + ) @large_file_check def process( self, *, - in_str: Optional[str] = None, fname: str, + in_str: Optional[str] = None, config: Optional[FluffConfig] = None, formatter: Optional[OutputStreamFormatter] = None, - **kwargs, ): """Compile a dbt model and return the compiled SQL. @@ -182,128 +334,249 @@ def process( # Stash the formatter if provided to use in cached methods. self.formatter = formatter self.sqlfluff_config = config + self.project_dir = self._get_project_dir() + self.profiles_dir = self._get_profiles_dir() + fname_absolute_path = os.path.abspath(fname) + try: - processsed_result = self._unsafe_process( - os.path.abspath(fname) if fname else None, in_str, config - ) + os.chdir(self.project_dir) + processed_result = self._unsafe_process(fname_absolute_path, in_str, config) # Reset the fail counter self._sequential_fails = 0 - return processsed_result - except DbtRuntimeException as e: + return processed_result + except DbtCompilationException as e: # Increment the counter self._sequential_fails += 1 - message = ( - f"dbt error on file '{e.node.original_file_path}', " f"{e.msg}" - if e.node - else f"dbt error: {e.msg}" - ) + if e.node: + return None, [ + SQLTemplaterError( + f"dbt compilation error on file '{e.node.original_file_path}', " + f"{e.msg}", + # It's fatal if we're over the limit + fatal=self._sequential_fails > self.sequential_fail_limit, + ) + ] + else: + raise # pragma: no cover + except DbtFailedToConnectException as e: return None, [ SQLTemplaterError( - message, - # It's fatal if we're over the limit - fatal=self._sequential_fails > self.sequential_fail_limit, + "dbt tried to connect to the database and failed: you could use " + "'execute' to skip the database calls. See" + "https://docs.getdbt.com/reference/dbt-jinja-functions/execute/ " + f"Error: {e.msg}", + fatal=True, ) ] # If a SQLFluff error is raised, just pass it through except SQLTemplaterError as e: # pragma: no cover return None, [e] + finally: + os.chdir(self.working_dir) + + def _find_node(self, fname, config=None): + if not config: # pragma: no cover + raise ValueError( + "For the dbt templater, the `process()` method " + "requires a config object." + ) + if not fname: # pragma: no cover + raise ValueError( + "For the dbt templater, the `process()` method requires a file name" + ) + elif fname == "stdin": # pragma: no cover + raise ValueError( + "The dbt templater does not support stdin input, provide a path instead" + ) + selected = self.dbt_selector_method.search( + included_nodes=self.dbt_manifest.nodes, + # Selector needs to be a relative path + selector=os.path.relpath(fname, start=os.getcwd()), + ) + results = [self.dbt_manifest.expect(uid) for uid in selected] + + if not results: + skip_reason = self._find_skip_reason(fname) + if skip_reason: + raise SQLFluffSkipFile( + f"Skipped file {fname} because it is {skip_reason}" + ) + raise SQLFluffSkipFile( + "File %s was not found in dbt project" % fname + ) # pragma: no cover + return results[0] - def _find_skip_reason(self, project, expected_node_path) -> Optional[str]: + def _find_skip_reason(self, fname) -> Optional[str]: """Return string reason if model okay to skip, otherwise None.""" # Scan macros. - for macro in project.dbt.macros.values(): - if macro.original_file_path == expected_node_path: + abspath = os.path.abspath(fname) + for macro in self.dbt_manifest.macros.values(): + if os.path.abspath(macro.original_file_path) == abspath: return "a macro" # Scan disabled nodes. - for nodes in project.dbt.disabled.values(): + for nodes in self.dbt_manifest.disabled.values(): for node in nodes: - if node.original_file_path == expected_node_path: + if os.path.abspath(node.original_file_path) == abspath: return "disabled" - return None # pragma: no cover - - def _unsafe_process( - self, fname: Optional[str], in_str: str, config: FluffConfig = None - ): - # Get project_dir from '.sqlfluff' config file - self.project_dir = ( - config.get_section((self.templater_selector, self.name, "project_dir")) - or os.getcwd() + return None + + def _unsafe_process(self, fname, in_str=None, config=None): + original_file_path = os.path.relpath(fname, start=os.getcwd()) + + # Below, we monkeypatch Environment.from_string() to intercept when dbt + # compiles (i.e. runs Jinja) to expand the "node" corresponding to fname. + # We do this to capture the Jinja context at the time of compilation, i.e.: + # - Jinja Environment object + # - Jinja "globals" dictionary + # + # This info is captured by the "make_template()" function, which in + # turn is used by our parent class' (JinjaTemplater) slice_file() + # function. + old_from_string = Environment.from_string + make_template = None + + def from_string(*args, **kwargs): + """Replaces (via monkeypatch) the jinja2.Environment function.""" + nonlocal make_template + # Is it processing the node corresponding to fname? + globals = kwargs.get("globals") + if globals: + model = globals.get("model") + if model: + if model.get("original_file_path") == original_file_path: + # Yes. Capture the important arguments and create + # a make_template() function. + env = args[0] + globals = args[2] if len(args) >= 3 else kwargs["globals"] + + def make_template(in_str): + env.add_extension(SnapshotExtension) + return env.from_string(in_str, globals=globals) + + return old_from_string(*args, **kwargs) + + node = self._find_node(fname, config) + templater_logger.debug( + "_find_node for path %r returned object of type %s.", fname, type(node) ) - # Get project - osmosis_dbt_project = self.dbt_project_container.get_project_by_root_dir( - self.project_dir + + save_ephemeral_nodes = dict( + (k, v) + for k, v in self.dbt_manifest.nodes.items() + if v.config.materialized == "ephemeral" + and not getattr(v, "compiled", False) ) - if not osmosis_dbt_project: - if not self.profiles_dir: - self.profiles_dir = self._get_profiles_dir() - assert self.project_dir - assert self.profiles_dir - osmosis_dbt_project = self.dbt_project_container.add_project( - project_dir=self.project_dir, - profiles_dir=self.profiles_dir, - vars=self._get_cli_vars(), + with self.connection(): + # Apply the monkeypatch. + Environment.from_string = from_string + try: + node = self.dbt_compiler.compile_node( + node=node, + manifest=self.dbt_manifest, + ) + except Exception as err: # pragma: no cover + templater_logger.exception( + "Fatal dbt compilation error on %s. This occurs most often " + "during incorrect sorting of ephemeral models before linting. " + "Please report this error on github at " + "https://github.com/sqlfluff/sqlfluff/issues, including " + "both the raw and compiled sql for the model affected.", + fname, + ) + # Additional error logging in case we get a fatal dbt error. + raise SQLFluffSkipFile( # pragma: no cover + f"Skipped file {fname} because dbt raised a fatal " + f"exception during compilation: {err!s}" + ) from err + finally: + # Undo the monkeypatch. + Environment.from_string = old_from_string + + if hasattr(node, "injected_sql"): + # If injected SQL is present, it contains a better picture + # of what will actually hit the database (e.g. with tests). + # However it's not always present. + compiled_sql = node.injected_sql # pragma: no cover + else: + compiled_sql = getattr(node, COMPILED_SQL_ATTRIBUTE) + + raw_sql = getattr(node, RAW_SQL_ATTRIBUTE) + + if not compiled_sql: # pragma: no cover + raise SQLTemplaterError( + "dbt templater compilation failed silently, check your " + "configuration by running `dbt compile` directly." + ) + source_dbt_sql = in_str + if not source_dbt_sql.rstrip().endswith("-%}"): + n_trailing_newlines = len(source_dbt_sql) - len( + source_dbt_sql.rstrip("\n") + ) + else: + # Source file ends with right whitespace stripping, so there's + # no need to preserve/restore trailing newlines, as they would + # have been removed regardless of dbt's + # keep_trailing_newlines=False behavior. + n_trailing_newlines = 0 + + templater_logger.debug( + " Trailing newline count in source dbt model: %r", + n_trailing_newlines, ) - - # If in_str not provided, use path if file is present. - fpath = Path(fname) - if fpath.exists() and not in_str: - in_str = fpath.read_text() - - self.dbt_config = osmosis_dbt_project.config - node = self._find_node(osmosis_dbt_project, fname) - node = osmosis_dbt_project.compile_node(node).node - # Generate context - ctx = osmosis_dbt_project.generate_runtime_model_context(node) - env = jinja.get_environment(node) - env.add_extension(SnapshotExtension) - if hasattr(node, "injected_sql"): - # If injected SQL is present, it contains a better picture - # of what will actually hit the database (e.g. with tests). - # However it's not always present. - compiled_sql = node.injected_sql # pragma: no cover - else: - compiled_sql = getattr(node, COMPILED_SQL_ATTRIBUTE) - - def make_template(_in_str): - return env.from_string(_in_str, globals=ctx) - - # Need compiled - if not compiled_sql: # pragma: no cover - raise SQLTemplaterError( - "dbt templater compilation failed silently, check your " - "configuration by running `dbt compile` directly." + templater_logger.debug(" Raw SQL before compile: %r", source_dbt_sql) + templater_logger.debug(" Node raw SQL: %r", raw_sql) + templater_logger.debug(" Node compiled SQL: %r", compiled_sql) + + # When using dbt-templater, trailing newlines are ALWAYS REMOVED during + # compiling. Unless fixed (like below), this will cause: + # 1. Assertion errors in TemplatedFile, when it sanity checks the + # contents of the sliced_file array. + # 2. L009 linting errors when running "sqlfluff lint foo_bar.sql" + # since the linter will use the compiled code with the newlines + # removed. + # 3. "No newline at end of file" warnings in Git/GitHub since + # sqlfluff uses the compiled SQL to write fixes back to the + # source SQL in the dbt model. + # + # The solution is (note that both the raw and compiled files have + # had trailing newline(s) removed by the dbt-templater. + # 1. Check for trailing newlines before compiling by looking at the + # raw SQL in the source dbt file. Remember the count of trailing + # newlines. + # 2. Set node.raw_sql/node.raw_code to the original source file contents. + # 3. Append the count from #1 above to compiled_sql. (In + # production, slice_file() does not usually use this string, + # but some test scenarios do. + setattr(node, RAW_SQL_ATTRIBUTE, source_dbt_sql) + compiled_sql = compiled_sql + "\n" * n_trailing_newlines + + # TRICKY: dbt configures Jinja2 with keep_trailing_newline=False. + # As documented (https://jinja.palletsprojects.com/en/3.0.x/api/), + # this flag's behavior is: "Preserve the trailing newline when + # rendering templates. The default is False, which causes a single + # newline, if present, to be stripped from the end of the template." + # + # Below, we use "append_to_templated" to effectively "undo" this. + raw_sliced, sliced_file, templated_sql = self.slice_file( + source_dbt_sql, + compiled_sql, + config=config, + make_template=make_template, + append_to_templated="\n" if n_trailing_newlines else "", ) - - # Whitespace - if not in_str.rstrip().endswith("-%}"): - n_trailing_newlines = len(in_str) - len(in_str.rstrip("\n")) - else: - # Source file ends with right whitespace stripping, so there's - # no need to preserve/restore trailing newlines. - n_trailing_newlines = 0 - - # LOG - templater_logger.debug( - " Trailing newline count in source dbt model: %r", - n_trailing_newlines, - ) - templater_logger.debug(" Raw SQL before compile: %r", in_str) - templater_logger.debug(" Node raw SQL: %r", in_str) - templater_logger.debug(" Node compiled SQL: %r", compiled_sql) - - # SLICE - raw_sliced, sliced_file, templated_sql = self.slice_file( - raw_str=in_str, - templated_str=compiled_sql + "\n" * n_trailing_newlines, - config=config, - make_template=make_template, - append_to_templated="\n" if n_trailing_newlines else "", - ) - + # :HACK: If calling compile_node() compiled any ephemeral nodes, + # restore them to their earlier state. This prevents a runtime error + # in the dbt "_inject_ctes_into_sql()" function that occurs with + # 2nd-level ephemeral model dependencies (e.g. A -> B -> C, where + # both B and C are ephemeral). Perhaps there is a better way to do + # this, but this seems good enough for now. + for k, v in save_ephemeral_nodes.items(): + if getattr(self.dbt_manifest.nodes[k], "compiled", False): + self.dbt_manifest.nodes[k] = v return ( TemplatedFile( - source_str=in_str, + source_str=source_dbt_sql, templated_str=templated_sql, fname=fname, sliced_file=sliced_file, @@ -313,11 +586,35 @@ def make_template(_in_str): [], ) + @contextmanager + def connection(self): + """Context manager that manages a dbt connection, if needed.""" + # We have to register the connection in dbt >= 1.0.0 ourselves + # In previous versions, we relied on the functionality removed in + # https://github.com/dbt-labs/dbt-core/pull/4062. + adapter = self.adapters.get(self.project_dir) + if adapter is None: + adapter = get_adapter(self.dbt_config) + self.adapters[self.project_dir] = adapter + adapter.acquire_connection("master") + adapter.set_relations_cache(self.dbt_manifest) + + yield + # :TRICKY: Once connected, we never disconnect. Making multiple + # connections during linting has proven to cause major performance + # issues. + class SnapshotExtension(StandaloneTag): """Dummy "snapshot" tags so raw dbt templates will parse. - For more context, see sqlfluff-templater-dbt. + Context: dbt snapshots + (https://docs.getdbt.com/docs/building-a-dbt-project/snapshots/#example) + use custom Jinja "snapshot" and "endsnapshot" tags. However, dbt does not + actually register those tags with Jinja. Instead, it finds and removes these + tags during a preprocessing step. However, DbtTemplater needs those tags to + actually parse, because JinjaTracer creates and uses Jinja to process + another template similar to the original one. """ tags = {"snapshot", "endsnapshot"} diff --git a/plugins/sqlfluff-templater-dbt/test/fixtures/dbt/templater.py b/plugins/sqlfluff-templater-dbt/test/fixtures/dbt/templater.py index d1e9f393230..40befeb897d 100644 --- a/plugins/sqlfluff-templater-dbt/test/fixtures/dbt/templater.py +++ b/plugins/sqlfluff-templater-dbt/test/fixtures/dbt/templater.py @@ -31,6 +31,4 @@ def project_dir(): @pytest.fixture() def dbt_templater(): """Returns an instance of the DbtTemplater.""" - templater = FluffConfig(overrides={"dialect": "ansi"}).get_templater("dbt") - templater.dbt_project_container.drop_all_projects() - return templater + return FluffConfig(overrides={"dialect": "ansi"}).get_templater("dbt") diff --git a/plugins/sqlfluff-templater-dbt/test/templater_test.py b/plugins/sqlfluff-templater-dbt/test/templater_test.py index 3b883bd0c19..84a32c05d5e 100644 --- a/plugins/sqlfluff-templater-dbt/test/templater_test.py +++ b/plugins/sqlfluff-templater-dbt/test/templater_test.py @@ -17,9 +17,7 @@ dbt_templater, project_dir, ) -from dbt.exceptions import ( - RuntimeException as DbtRuntimeException, -) +from sqlfluff_templater_dbt.templater import DbtFailedToConnectException, DbtTemplater def test__templater_dbt_missing(dbt_templater, project_dir): # noqa: F811 @@ -132,6 +130,56 @@ def _get_fixture_path(template_output_folder_path, fname): return fixture_path +@pytest.mark.parametrize( + "fnames_input, fnames_expected_sequence", + [ + [ + ( + Path("models") / "depends_on_ephemeral" / "a.sql", + Path("models") / "depends_on_ephemeral" / "b.sql", + Path("models") / "depends_on_ephemeral" / "d.sql", + ), + # c.sql is not present in the original list and should not appear here, + # even though b.sql depends on it. This test ensures that "out of scope" + # files, e.g. those ignored using ".sqlfluffignore" or in directories + # outside what was specified, are not inadvertently processed. + ( + Path("models") / "depends_on_ephemeral" / "a.sql", + Path("models") / "depends_on_ephemeral" / "b.sql", + Path("models") / "depends_on_ephemeral" / "d.sql", + ), + ], + [ + ( + Path("models") / "depends_on_ephemeral" / "a.sql", + Path("models") / "depends_on_ephemeral" / "b.sql", + Path("models") / "depends_on_ephemeral" / "c.sql", + Path("models") / "depends_on_ephemeral" / "d.sql", + ), + # c.sql should come before b.sql because b.sql depends on c.sql. + # It also comes first overall because ephemeral models come first. + ( + Path("models") / "depends_on_ephemeral" / "c.sql", + Path("models") / "depends_on_ephemeral" / "a.sql", + Path("models") / "depends_on_ephemeral" / "b.sql", + Path("models") / "depends_on_ephemeral" / "d.sql", + ), + ], + ], +) +def test__templater_dbt_sequence_files_ephemeral_dependency( + project_dir, dbt_templater, fnames_input, fnames_expected_sequence # noqa: F811 +): + """Test that dbt templater sequences files based on dependencies.""" + result = dbt_templater.sequence_files( + [str(Path(project_dir) / fn) for fn in fnames_input], + config=FluffConfig(configs=DBT_FLUFF_CONFIG), + ) + pd = Path(project_dir) + expected = [str(pd / fn) for fn in fnames_expected_sequence] + assert list(result) == expected + + @pytest.mark.parametrize( "raw_file,templated_file,result", [ @@ -307,7 +355,7 @@ def test__templater_dbt_templating_absolute_path( [ ( "compiler_error.sql", - "dbt error on file 'models/my_new_project/compiler_error.sql', " + "dbt compilation error on file 'models/my_new_project/compiler_error.sql', " "Unexpected end of template. " "Jinja was looking for the following tags: 'endfor'", ), @@ -317,6 +365,8 @@ def test__templater_dbt_handle_exceptions( project_dir, dbt_templater, fname, exception_msg # noqa: F811 ): """Test that exceptions during compilation are returned as violation.""" + from dbt.adapters.factory import get_adapter + src_fpath = "plugins/sqlfluff-templater-dbt/test/fixtures/dbt/error_models/" + fname target_fpath = os.path.abspath( os.path.join(project_dir, "models/my_new_project/", fname) @@ -332,17 +382,23 @@ def test__templater_dbt_handle_exceptions( ) finally: os.rename(target_fpath, src_fpath) + get_adapter(dbt_templater.dbt_config).connections.release() assert violations # NB: Replace slashes to deal with different platform paths being returned. assert violations[0].desc().replace("\\", "/").startswith(exception_msg) -@mock.patch("sqlfluff_templater_dbt.osmosis.DbtProjectContainer.add_project") +@mock.patch("dbt.adapters.postgres.impl.PostgresAdapter.set_relations_cache") def test__templater_dbt_handle_database_connection_failure( - add_project, project_dir, dbt_templater # noqa: F811 + set_relations_cache, project_dir, dbt_templater # noqa: F811 ): """Test the result of a failed database connection.""" - add_project.side_effect = DbtRuntimeException("dummy error") + from dbt.adapters.factory import get_adapter + + # Clear the adapter cache to force this test to create a new connection. + DbtTemplater.adapters.clear() + + set_relations_cache.side_effect = DbtFailedToConnectException("dummy error") src_fpath = ( "plugins/sqlfluff-templater-dbt/test/fixtures/dbt/error_models" @@ -368,9 +424,15 @@ def test__templater_dbt_handle_database_connection_failure( ) finally: os.rename(target_fpath, src_fpath) + get_adapter(dbt_templater.dbt_config).connections.release() assert violations # NB: Replace slashes to deal with different platform paths being returned. - assert violations[0].desc().replace("\\", "/").startswith("dbt error") + assert ( + violations[0] + .desc() + .replace("\\", "/") + .startswith("dbt tried to connect to the database") + ) def test__project_dir_does_not_exist_error(dbt_templater, caplog): # noqa: F811 diff --git a/setup.cfg b/setup.cfg index 9762f6e72e8..8468df93552 100644 --- a/setup.cfg +++ b/setup.cfg @@ -83,7 +83,6 @@ install_requires = Jinja2 # Used for .sqlfluffignore pathspec - pydantic # We provide a testing library for plugins in sqlfluff.utils.testing pytest # We require pyyaml >= 5.1 so that we can preserve the ordering of keys. @@ -94,8 +93,6 @@ install_requires = # This was introduced in https://github.com/sqlfluff/sqlfluff/pull/2027 # and further details can be found in that PR. regex - rich - ruamel.yaml # For returning exceptions from multiprocessing.Pool.map() tblib # For parsing pyproject.toml