Skip to content

Commit

Permalink
python model beta feature(dbt-labs#5421)
Browse files Browse the repository at this point in the history
* Python model beta version with update to manifest that renames `raw_sql` and `compiled_sql` to `raw_code` and `compiled_code`
Co-authored-by: Jeremy Cohen <[email protected]>
Co-authored-by: Ian Knox <[email protected]>
Co-authored-by: Stu Kilgore <[email protected]>
  • Loading branch information
ChenyuLInx authored Jul 28, 2022
1 parent 2547e4f commit a7ff003
Show file tree
Hide file tree
Showing 65 changed files with 1,249 additions and 338 deletions.
7 changes: 7 additions & 0 deletions .changes/unreleased/Features-20220510-165130.yaml
Original file line number Diff line number Diff line change
@@ -0,0 +1,7 @@
kind: Features
body: Python model initial version
time: 2022-05-10T16:51:30.245589-07:00
custom:
Author: ChenyuLInx
Issue: "0"
PR: "13"
7 changes: 7 additions & 0 deletions .changes/unreleased/Features-20220627-131042.yaml
Original file line number Diff line number Diff line change
@@ -0,0 +1,7 @@
kind: Features
body: Python models can support incremental logic
time: 2022-06-27T13:10:42.123303-05:00
custom:
Author: iknox-fa
Issue: "0"
PR: "35"
7 changes: 7 additions & 0 deletions .changes/unreleased/Under the Hood-20220713-124925.yaml
Original file line number Diff line number Diff line change
@@ -0,0 +1,7 @@
kind: Under the Hood
body: Added language to tracked fields in run_model event
time: 2022-07-13T12:49:25.362678-05:00
custom:
Author: stu-k
Issue: "0000"
PR: "5469"
7 changes: 7 additions & 0 deletions .changes/unreleased/Under the Hood-20220728-094536.yaml
Original file line number Diff line number Diff line change
@@ -0,0 +1,7 @@
kind: Under the Hood
body: Add python incremental materialization test
time: 2022-07-28T09:45:36.13608-05:00
custom:
Author: stu-k
Issue: "0000"
PR: "5571"
57 changes: 51 additions & 6 deletions core/dbt/adapters/base/impl.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@
from concurrent.futures import as_completed, Future
from contextlib import contextmanager
from datetime import datetime
import time
from itertools import chain
from typing import (
Optional,
Expand Down Expand Up @@ -42,7 +43,12 @@
from dbt.contracts.graph.parsed import ParsedSeedNode
from dbt.exceptions import warn_or_error
from dbt.events.functions import fire_event
from dbt.events.types import CacheMiss, ListRelations
from dbt.events.types import (
CacheMiss,
ListRelations,
CodeExecution,
CodeExecutionStatus,
)
from dbt.utils import filter_null_values, executor

from dbt.adapters.base.connections import Connection, AdapterResponse
Expand Down Expand Up @@ -284,7 +290,9 @@ def load_macro_manifest(self, base_macros_only=False) -> MacroManifest:
from dbt.parser.manifest import ManifestLoader

manifest = ManifestLoader.load_macros(
self.config, self.connections.set_query_header, base_macros_only=base_macros_only
self.config,
self.connections.set_query_header,
base_macros_only=base_macros_only,
)
# TODO CT-211
self._macro_manifest_lazy = manifest # type: ignore[assignment]
Expand All @@ -303,7 +311,11 @@ def _schema_is_cached(self, database: Optional[str], schema: str) -> bool:

if (database, schema) not in self.cache:
fire_event(
CacheMiss(conn_name=self.nice_connection_name(), database=database, schema=schema)
CacheMiss(
conn_name=self.nice_connection_name(),
database=database,
schema=schema,
)
)
return False
else:
Expand Down Expand Up @@ -381,7 +393,10 @@ def _relations_cache_for_schemas(
self.cache.update_schemas(cache_update)

def set_relations_cache(
self, manifest: Manifest, clear: bool = False, required_schemas: Set[BaseRelation] = None
self,
manifest: Manifest,
clear: bool = False,
required_schemas: Set[BaseRelation] = None,
) -> None:
"""Run a query that gets a populated cache of the relations in the
database and set the cache on this adapter.
Expand Down Expand Up @@ -670,15 +685,20 @@ def list_relations(self, database: Optional[str], schema: str) -> List[BaseRelat
return self.cache.get_relations(database, schema)

schema_relation = self.Relation.create(
database=database, schema=schema, identifier="", quote_policy=self.config.quoting
database=database,
schema=schema,
identifier="",
quote_policy=self.config.quoting,
).without_identifier()

# we can't build the relations cache because we don't have a
# manifest so we can't run any operations.
relations = self.list_relations_without_caching(schema_relation)
fire_event(
ListRelations(
database=database, schema=schema, relations=[_make_key(x) for x in relations]
database=database,
schema=schema,
relations=[_make_key(x) for x in relations],
)
)

Expand Down Expand Up @@ -1162,6 +1182,10 @@ def get_rows_different_sql(

return sql

@available.parse_none
def submit_python_job(self, parsed_model: dict, compiled_code: str):
    """Base stub for submitting a Python model's compiled code.

    This base implementation does no work and always raises
    ``NotImplementedException``.
    NOTE(review): presumably adapters that support Python models override
    this method -- confirm against the concrete adapters.
    """
    message = "`submit_python_job` is not implemented for this adapter!"
    raise NotImplementedException(message)

def valid_incremental_strategies(self):
"""The set of standard builtin strategies which this adapter supports out-of-the-box.
Not used to validate custom strategies defined by end users.
Expand Down Expand Up @@ -1250,3 +1274,24 @@ def catch_as_completed(
# exc is not None, derives from Exception, and isn't ctrl+c
exceptions.append(exc)
return merge_tables(tables), exceptions


def log_code_execution(code_execution_function):
    """Decorator that logs submitted code and how long its execution took.

    Wraps an adapter's ``submit_python_job`` method. Before the call it fires
    a ``CodeExecution`` event carrying the current thread's connection name
    and the code being submitted; after the call it fires a
    ``CodeExecutionStatus`` event with the response message and the elapsed
    wall-clock time in seconds, rounded to two decimals.

    Args:
        code_execution_function: the ``submit_python_job`` method to wrap.
            Its signature is expected to be
            ``(self, parsed_model, compiled_code)``.

    Returns:
        A wrapper with the same call signature and metadata that returns
        whatever the wrapped function returns.

    Raises:
        ValueError: if the decorated function is not named
            ``submit_python_job``.
    """
    # Local import so this block does not depend on the file's import header.
    import functools

    if code_execution_function.__name__ != "submit_python_job":
        raise ValueError("this should be only used to log submit_python_job now")

    # functools.wraps keeps __name__/__doc__ of the wrapped method intact.
    @functools.wraps(code_execution_function)
    def execution_with_log(*args, **kwargs):
        # args[0] is the adapter instance the method is bound to.
        self = args[0]
        # compiled_code is the third positional parameter, but callers may
        # also pass it by keyword; support both instead of failing on kwargs.
        if "compiled_code" in kwargs:
            code_content = kwargs["compiled_code"]
        else:
            code_content = args[2]
        connection_name = self.connections.get_thread_connection().name
        fire_event(CodeExecution(conn_name=connection_name, code_content=code_content))
        start_time = time.time()
        response = code_execution_function(*args, **kwargs)
        fire_event(
            CodeExecutionStatus(
                status=response._message, elapsed=round((time.time() - start_time), 2)
            )
        )
        return response

    return execution_with_log
57 changes: 38 additions & 19 deletions core/dbt/compilation.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,7 @@
from dbt.graph import Graph
from dbt.events.functions import fire_event
from dbt.events.types import FoundStats, CompilingNode, WritingInjectedSQLForNode
from dbt.node_types import NodeType
from dbt.node_types import NodeType, ModelLanguage
from dbt.events.format import pluralize
import dbt.tracking

Expand Down Expand Up @@ -271,7 +271,7 @@ def _recursively_prepend_ctes(
are rolled up into the models that refer to them by
inserting CTEs into the SQL.
"""
if model.compiled_sql is None:
if model.compiled_code is None:
raise RuntimeException("Cannot inject ctes into an unparsed node", model)
if model.extra_ctes_injected:
return (model, model.extra_ctes)
Expand Down Expand Up @@ -324,29 +324,28 @@ def _recursively_prepend_ctes(
_extend_prepended_ctes(prepended_ctes, new_prepended_ctes)

new_cte_name = self.add_ephemeral_prefix(cte_model.name)
rendered_sql = cte_model._pre_injected_sql or cte_model.compiled_sql
rendered_sql = cte_model._pre_injected_sql or cte_model.compiled_code
sql = f" {new_cte_name} as (\n{rendered_sql}\n)"

_add_prepended_cte(prepended_ctes, InjectedCTE(id=cte.id, sql=sql))

injected_sql = self._inject_ctes_into_sql(
model.compiled_sql,
model.compiled_code,
prepended_ctes,
)
model._pre_injected_sql = model.compiled_sql
model.compiled_sql = injected_sql
model._pre_injected_sql = model.compiled_code
model.compiled_code = injected_sql
model.extra_ctes_injected = True
model.extra_ctes = prepended_ctes
model.validate(model.to_dict(omit_none=True))

manifest.update_node(model)

return model, prepended_ctes

# creates a compiled_node from the ManifestNode passed in,
# creates a "context" dictionary for jinja rendering,
# and then renders the "compiled_sql" using the node, the
# raw_sql and the context.
# and then renders the "compiled_code" using the node, the
# raw_code and the context.
def _compile_node(
self,
node: ManifestNode,
Expand All @@ -362,20 +361,40 @@ def _compile_node(
data.update(
{
"compiled": False,
"compiled_sql": None,
"compiled_code": None,
"extra_ctes_injected": False,
"extra_ctes": [],
}
)
compiled_node = _compiled_type_for(node).from_dict(data)

context = self._create_node_context(compiled_node, manifest, extra_context)
if compiled_node.language == ModelLanguage.python:
# TODO could we also 'minify' this code at all? just aesthetic, not functional

compiled_node.compiled_sql = jinja.get_rendered(
node.raw_sql,
context,
node,
)
# quoting seems like something very specific to sql so far;
# none of the python implementations we are seeing use quoting.
# TODO try to find a better way to do this than temporarily overriding the config
original_quoting = self.config.quoting
self.config.quoting = {key: False for key in original_quoting.keys()}
context = self._create_node_context(compiled_node, manifest, extra_context)

postfix = jinja.get_rendered(
"{{ py_script_postfix(model) }}",
context,
node,
)
# we should NOT jinja render the python model's 'raw code'
compiled_node.compiled_code = f"{node.raw_code}\n\n{postfix}"
# restore quoting settings in the end since context is lazy evaluated
self.config.quoting = original_quoting

else:
context = self._create_node_context(compiled_node, manifest, extra_context)
compiled_node.compiled_code = jinja.get_rendered(
node.raw_code,
context,
node,
)

compiled_node.relation_name = self._get_relation_name(node)

Expand Down Expand Up @@ -487,15 +506,15 @@ def compile(self, manifest: Manifest, write=True, add_test_edges=False) -> Graph

return Graph(linker.graph)

# writes the "compiled_sql" into the target/compiled directory
# writes the "compiled_code" into the target/compiled directory
def _write_node(self, node: NonSourceCompiledNode) -> ManifestNode:
if not node.extra_ctes_injected or node.resource_type == NodeType.Snapshot:
return node
fire_event(WritingInjectedSQLForNode(unique_id=node.unique_id))

if node.compiled_sql:
if node.compiled_code:
node.compiled_path = node.write_node(
self.config.target_path, "compiled", node.compiled_sql
self.config.target_path, "compiled", node.compiled_code
)
return node

Expand Down
22 changes: 19 additions & 3 deletions core/dbt/context/providers.py
Original file line number Diff line number Diff line change
Expand Up @@ -61,7 +61,7 @@
disallow_secret_env_var,
)
from dbt.config import IsFQNResource
from dbt.node_types import NodeType
from dbt.node_types import NodeType, ModelLanguage

from dbt.utils import merge, AttrDict, MultiDict

Expand Down Expand Up @@ -1157,7 +1157,12 @@ def graph(self) -> Dict[str, Any]:

@contextproperty("model")
def ctx_model(self) -> Dict[str, Any]:
return self.model.to_dict(omit_none=True)
ret = self.model.to_dict(omit_none=True)
# Maintain direct use of compiled_sql
# TODO add deprecation logic [CT-934]
if "compiled_code" in ret:
ret["compiled_sql"] = ret["compiled_code"]
return ret

@contextproperty
def pre_hooks(self) -> Optional[List[Dict[str, Any]]]:
Expand Down Expand Up @@ -1278,9 +1283,20 @@ def post_hooks(self) -> List[Dict[str, Any]]:

@contextproperty
def sql(self) -> Optional[str]:
    # Kept for backward compatibility: only sql models expose their compiled
    # code under this name, and only once extra CTEs have been injected.
    if not getattr(self.model, "extra_ctes_injected", None):
        return None
    if self.model.language == ModelLanguage.sql:  # type: ignore[union-attr]
        # TODO CT-211
        return self.model.compiled_code  # type: ignore[union-attr]
    return None

@contextproperty
def compiled_code(self) -> Optional[str]:
if getattr(self.model, "extra_ctes_injected", None):
# TODO CT-211
return self.model.compiled_sql # type: ignore[union-attr]
return self.model.compiled_code # type: ignore[union-attr]
return None

@contextproperty
Expand Down
2 changes: 1 addition & 1 deletion core/dbt/contracts/graph/compiled.py
Original file line number Diff line number Diff line change
Expand Up @@ -41,7 +41,7 @@ class CompiledNodeMixin(dbtClassMixin):

@dataclass
class CompiledNode(ParsedNode, CompiledNodeMixin):
compiled_sql: Optional[str] = None
compiled_code: Optional[str] = None
extra_ctes_injected: bool = False
extra_ctes: List[InjectedCTE] = field(default_factory=list)
relation_name: Optional[str] = None
Expand Down
4 changes: 2 additions & 2 deletions core/dbt/contracts/graph/manifest.py
Original file line number Diff line number Diff line change
Expand Up @@ -1157,7 +1157,7 @@ def __init__(self, macros):


@dataclass
@schema_version("manifest", 6)
@schema_version("manifest", 7)
class WritableManifest(ArtifactMixin):
nodes: Mapping[UniqueID, ManifestNode] = field(
metadata=dict(description=("The nodes defined in the dbt project and its dependencies"))
Expand Down Expand Up @@ -1203,7 +1203,7 @@ class WritableManifest(ArtifactMixin):

@classmethod
def compatible_previous_versions(self):
return [("manifest", 4), ("manifest", 5)]
return [("manifest", 4), ("manifest", 5), ("manifest", 6)]

def __post_serialize__(self, dct):
for unique_id, node in dct["nodes"].items():
Expand Down
4 changes: 4 additions & 0 deletions core/dbt/contracts/graph/model_config.py
Original file line number Diff line number Diff line change
Expand Up @@ -461,6 +461,10 @@ class NodeConfig(NodeAndTestConfig):
grants: Dict[str, Any] = field(
default_factory=dict, metadata=MergeBehavior.DictKeyAppend.meta()
)
packages: List[str] = field(
default_factory=list,
metadata=MergeBehavior.Append.meta(),
)

@classmethod
def __pre_deserialize__(cls, data):
Expand Down
3 changes: 2 additions & 1 deletion core/dbt/contracts/graph/parsed.py
Original file line number Diff line number Diff line change
Expand Up @@ -234,6 +234,7 @@ def _serialize(self):
return self.to_dict()

def __post_serialize__(self, dct):
    # Let the parent class post-process the dict first, then drop the
    # "_event_status" key (if present) from the serialized output.
    serialized = super().__post_serialize__(dct)
    serialized.pop("_event_status", None)
    return serialized
Expand Down Expand Up @@ -281,7 +282,7 @@ def _persist_relation_docs(self) -> bool:
return False

def same_body(self: T, other: T) -> bool:
return self.raw_sql == other.raw_sql
return self.raw_code == other.raw_code

def same_persisted_description(self: T, other: T) -> bool:
# the check on configs will handle the case where we have different
Expand Down
Loading

0 comments on commit a7ff003

Please sign in to comment.