Skip to content

Commit

Permalink
Add missing attribute "name" and "group" for Asset and "group" for As…
Browse files Browse the repository at this point in the history
…setAlias in serialization, api and methods (apache#43774)

* test(tests/www/views/test_views_grid): extend Asset test cases to include both uri and name

* test(utils/test_json): extend Asset test cases to include both uri and name

* test(timetables/test_assets_timetable): extend Asset test cases to include both uri and name

* test(listeners/test_asset_listener): extend Asset test cases to include both uri and name

* test(jobs/test_scheduler_job): extend Asset test cases to include both uri and name

* test(providers/openlineage): extend Asset test cases to include both uri and name

* test(decorators/test_python): extend Asset test cases to include both uri and name

* test(models/test_dag): extend asset test cases to cover name, uri, group

* test(api_connexsion/schemas/dag_run): extend asset test cases to cover name, uri, group

* test(serialization/serialized_objects): extend asset test cases to cover name, uri, group and asset alias test cases to cover name and group

* test(serialization/dag_serialization): extend asset test cases to cover name, uri, group

* test(models/dag): extend asset test cases to cover name, uri, group

* test(serialization/serde): extend asset test cases to cover name, uri, group

* test(api_connexion/schemas/asset): extend asset test cases to cover name, uri, group

* test(api_connexion/schemas/asset): extend asset alias test cases to cover name, group

* test(api_connexsion/schemas/dag): extend asset test cases to cover name, uri, group

* test(api_connexsion/schemas/dag_run): extend asset test cases to cover name, uri, group

* test(dags/test_assets): extend asset test cases to cover name, uri, group

* test(dags/test_only_empty_tasks): extend asset test cases to cover name, uri, group

* test(api_fastapi): extend asset test cases to cover name, uri, group

* test(assets/manager): extend asset test cases to cover name, uri, group

* test(task_sdk/assets): extend asset test cases to cover name, uri, group

* test(api_connexion/endpoints/asset): extend asset test cases to cover name, uri, group

* test: add missing session

* test(www/views/asset): extend asset test cases to cover name, uri, group

* test(models/seraialized_dag): extend asset test cases to cover name, uri, group

* test(lineage/hook): extend asset test cases to cover name, uri, group

* test(io/path): extend asset test cases to cover name, uri, group

* test(jobs): enhance test_activate_referenced_assets_with_no_existing_warning to cover extra edge case

* fix(serialization): serialize both name, uri and group for Asset

* fix(assets): extend Asset as_expression methods to include name, group fields (also AssetAlias group field)

* fix(serialization/serialized_objects): fix missing AssetAlias.group serialization

* fix(serialization): change dependency_id to use name instead of uri

* feat(api_connexion/schemas/asset): add name, group to asset schema and group to asset alias schema

* feat(assets/manager): filter asset by name, uri, group instead of uri only

* style(assets/manager): rename argument asset in _add_asset_alias_association as asset_model

* fix(asset): use name to evalute instead of uri

* fix(api_connexion/endpoints/asset): fix how asset event is fetch in create asset event

* fix(api_fastapi/asset): fix how asset event is fetch in create asset event

* fix(lineage/hook): extend asset realted methods to include name and group

* fix(task_sdk/asset): change iter_assets to return ((name, uri), obj) instead of (uri, obj)

* fix(fastapi/asset): add missing group column to asset alias schema

* build: build autogen ts files

* feat(lineage/hook): make create_asset keyword only

* docs(newsfragments): add 43774.significant.rst

* refactor(task_sdk/asset): add from_asset_alias to AssetAliasCondition to remove duplicate code

* refactor(task_sdk/asset): add AssetUniqueKey.from_asset to reduce duplicate code

* Revert "fix(asset): use name to evalute instead of uri"

This reverts commit e812b8a.
  • Loading branch information
Lee-W authored Dec 2, 2024
1 parent 1e0a499 commit 06c3823
Show file tree
Hide file tree
Showing 39 changed files with 774 additions and 298 deletions.
8 changes: 4 additions & 4 deletions airflow/api_connexion/endpoints/asset_endpoint.py
Original file line number Diff line number Diff line change
Expand Up @@ -45,7 +45,6 @@
)
from airflow.assets.manager import asset_manager
from airflow.models.asset import AssetDagRunQueue, AssetEvent, AssetModel
from airflow.sdk.definitions.asset import Asset
from airflow.utils import timezone
from airflow.utils.api_migration import mark_fastapi_migration_done
from airflow.utils.db import get_query_count
Expand Down Expand Up @@ -341,15 +340,16 @@ def create_asset_event(session: Session = NEW_SESSION) -> APIResponse:
except ValidationError as err:
raise BadRequest(detail=str(err))

# TODO: handle name
uri = json_body["asset_uri"]
asset = session.scalar(select(AssetModel).where(AssetModel.uri == uri).limit(1))
if not asset:
asset_model = session.scalar(select(AssetModel).where(AssetModel.uri == uri).limit(1))
if not asset_model:
raise NotFound(title="Asset not found", detail=f"Asset with uri: '{uri}' not found")
timestamp = timezone.utcnow()
extra = json_body.get("extra", {})
extra["from_rest_api"] = True
asset_event = asset_manager.register_asset_change(
asset=Asset(uri=uri),
asset=asset_model.to_public(),
timestamp=timestamp,
extra=extra,
session=session,
Expand Down
3 changes: 3 additions & 0 deletions airflow/api_connexion/schemas/asset_schema.py
Original file line number Diff line number Diff line change
Expand Up @@ -70,6 +70,7 @@ class Meta:

id = auto_field()
name = auto_field()
group = auto_field()


class AssetSchema(SQLAlchemySchema):
Expand All @@ -82,6 +83,8 @@ class Meta:

id = auto_field()
uri = auto_field()
name = auto_field()
group = auto_field()
extra = JsonObjectField()
created_at = auto_field()
updated_at = auto_field()
Expand Down
1 change: 1 addition & 0 deletions airflow/api_fastapi/core_api/datamodels/assets.py
Original file line number Diff line number Diff line change
Expand Up @@ -47,6 +47,7 @@ class AssetAliasSchema(BaseModel):

id: int
name: str
group: str


class AssetResponse(BaseModel):
Expand Down
4 changes: 4 additions & 0 deletions airflow/api_fastapi/core_api/openapi/v1-generated.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -5786,10 +5786,14 @@ components:
name:
type: string
title: Name
group:
type: string
title: Group
type: object
required:
- id
- name
- group
title: AssetAliasSchema
description: Asset alias serializer for assets.
AssetCollectionResponse:
Expand Down
7 changes: 3 additions & 4 deletions airflow/api_fastapi/core_api/routes/public/assets.py
Original file line number Diff line number Diff line change
Expand Up @@ -51,7 +51,6 @@
from airflow.api_fastapi.core_api.openapi.exceptions import create_openapi_http_exception_doc
from airflow.assets.manager import asset_manager
from airflow.models.asset import AssetDagRunQueue, AssetEvent, AssetModel
from airflow.sdk.definitions.asset import Asset
from airflow.utils import timezone

assets_router = AirflowRouter(tags=["Asset"])
Expand Down Expand Up @@ -171,13 +170,13 @@ def create_asset_event(
session: SessionDep,
) -> AssetEventResponse:
"""Create asset events."""
asset = session.scalar(select(AssetModel).where(AssetModel.uri == body.uri).limit(1))
if not asset:
asset_model = session.scalar(select(AssetModel).where(AssetModel.uri == body.uri).limit(1))
if not asset_model:
raise HTTPException(status.HTTP_404_NOT_FOUND, f"Asset with uri: `{body.uri}` was not found")
timestamp = timezone.utcnow()

assets_event = asset_manager.register_asset_change(
asset=Asset(uri=body.uri),
asset=asset_model.to_public(),
timestamp=timestamp,
extra=body.extra,
session=session,
Expand Down
12 changes: 7 additions & 5 deletions airflow/assets/manager.py
Original file line number Diff line number Diff line change
Expand Up @@ -86,16 +86,16 @@ def _add_one(asset_alias: AssetAlias) -> AssetAliasModel:
def _add_asset_alias_association(
cls,
alias_names: Collection[str],
asset: AssetModel,
asset_model: AssetModel,
*,
session: Session,
) -> None:
already_related = {m.name for m in asset.aliases}
already_related = {m.name for m in asset_model.aliases}
existing_aliases = {
m.name: m
for m in session.scalars(select(AssetAliasModel).where(AssetAliasModel.name.in_(alias_names)))
}
asset.aliases.extend(
asset_model.aliases.extend(
existing_aliases.get(name, AssetAliasModel(name=name))
for name in alias_names
if name not in already_related
Expand All @@ -121,7 +121,7 @@ def register_asset_change(
"""
asset_model = session.scalar(
select(AssetModel)
.where(AssetModel.uri == asset.uri)
.where(AssetModel.name == asset.name, AssetModel.uri == asset.uri)
.options(
joinedload(AssetModel.aliases),
joinedload(AssetModel.consuming_dags).joinedload(DagScheduleAssetReference.dag),
Expand All @@ -131,7 +131,9 @@ def register_asset_change(
cls.logger().warning("AssetModel %s not found", asset)
return None

cls._add_asset_alias_association({alias.name for alias in aliases}, asset_model, session=session)
cls._add_asset_alias_association(
alias_names={alias.name for alias in aliases}, asset_model=asset_model, session=session
)

event_kwargs = {
"asset_id": asset_model.id,
Expand Down
40 changes: 32 additions & 8 deletions airflow/lineage/hook.py
Original file line number Diff line number Diff line change
Expand Up @@ -95,24 +95,40 @@ def _generate_key(self, asset: Asset, context: LineageContext) -> str:
return f"{asset.uri}_{extra_hash}_{id(context)}"

def create_asset(
self, scheme: str | None, uri: str | None, asset_kwargs: dict | None, asset_extra: dict | None
self,
*,
scheme: str | None = None,
uri: str | None = None,
name: str | None = None,
group: str | None = None,
asset_kwargs: dict | None = None,
asset_extra: dict | None = None,
) -> Asset | None:
"""
Create an asset instance using the provided parameters.
This method attempts to create an asset instance using the given parameters.
It first checks if a URI is provided and falls back to using the default asset factory
with the given URI if no other information is available.
It first checks if a URI or a name is provided and falls back to using the default asset factory
with the given URI or name if no other information is available.
If a scheme is provided but no URI, it attempts to find an asset factory that matches
If a scheme is provided but no URI or name, it attempts to find an asset factory that matches
the given scheme. If no such factory is found, it logs an error message and returns None.
If asset_kwargs is provided, it is used to pass additional parameters to the asset
factory. The asset_extra parameter is also passed to the factory as an ``extra`` parameter.
"""
if uri:
if uri or name:
# Fallback to default factory using the provided URI
return Asset(uri=uri, extra=asset_extra)
kwargs: dict[str, str | dict] = {}
if uri:
kwargs["uri"] = uri
if name:
kwargs["name"] = name
if group:
kwargs["group"] = group
if asset_extra:
kwargs["extra"] = asset_extra
return Asset(**kwargs) # type: ignore[call-overload]

if not scheme:
self.log.debug(
Expand All @@ -137,11 +153,15 @@ def add_input_asset(
context: LineageContext,
scheme: str | None = None,
uri: str | None = None,
name: str | None = None,
group: str | None = None,
asset_kwargs: dict | None = None,
asset_extra: dict | None = None,
):
"""Add the input asset and its corresponding hook execution context to the collector."""
asset = self.create_asset(scheme=scheme, uri=uri, asset_kwargs=asset_kwargs, asset_extra=asset_extra)
asset = self.create_asset(
scheme=scheme, uri=uri, name=name, group=group, asset_kwargs=asset_kwargs, asset_extra=asset_extra
)
if asset:
key = self._generate_key(asset, context)
if key not in self._inputs:
Expand All @@ -153,11 +173,15 @@ def add_output_asset(
context: LineageContext,
scheme: str | None = None,
uri: str | None = None,
name: str | None = None,
group: str | None = None,
asset_kwargs: dict | None = None,
asset_extra: dict | None = None,
):
"""Add the output asset and its corresponding hook execution context to the collector."""
asset = self.create_asset(scheme=scheme, uri=uri, asset_kwargs=asset_kwargs, asset_extra=asset_extra)
asset = self.create_asset(
scheme=scheme, uri=uri, name=name, group=group, asset_kwargs=asset_kwargs, asset_extra=asset_extra
)
if asset:
key = self._generate_key(asset, context)
if key not in self._outputs:
Expand Down
66 changes: 51 additions & 15 deletions airflow/serialization/serialized_objects.py
Original file line number Diff line number Diff line change
Expand Up @@ -44,7 +44,11 @@
from airflow.models.connection import Connection
from airflow.models.dag import DAG, DagModel
from airflow.models.dagrun import DagRun
from airflow.models.expandinput import EXPAND_INPUT_EMPTY, create_expand_input, get_map_type_key
from airflow.models.expandinput import (
EXPAND_INPUT_EMPTY,
create_expand_input,
get_map_type_key,
)
from airflow.models.mappedoperator import MappedOperator
from airflow.models.param import Param, ParamsDict
from airflow.models.taskinstance import SimpleTaskInstance, TaskInstance
Expand Down Expand Up @@ -213,7 +217,9 @@ def _get_registered_timetable(importable_string: str) -> type[Timetable] | None:
return None


def _get_registered_priority_weight_strategy(importable_string: str) -> type[PriorityWeightStrategy] | None:
def _get_registered_priority_weight_strategy(
importable_string: str,
) -> type[PriorityWeightStrategy] | None:
from airflow import plugins_manager

if importable_string in airflow_priority_weight_strategies:
Expand Down Expand Up @@ -256,13 +262,25 @@ def encode_asset_condition(var: BaseAsset) -> dict[str, Any]:
:meta private:
"""
if isinstance(var, Asset):
return {"__type": DAT.ASSET, "name": var.name, "uri": var.uri, "extra": var.extra}
return {
"__type": DAT.ASSET,
"name": var.name,
"uri": var.uri,
"group": var.group,
"extra": var.extra,
}
if isinstance(var, AssetAlias):
return {"__type": DAT.ASSET_ALIAS, "name": var.name}
return {"__type": DAT.ASSET_ALIAS, "name": var.name, "group": var.group}
if isinstance(var, AssetAll):
return {"__type": DAT.ASSET_ALL, "objects": [encode_asset_condition(x) for x in var.objects]}
return {
"__type": DAT.ASSET_ALL,
"objects": [encode_asset_condition(x) for x in var.objects],
}
if isinstance(var, AssetAny):
return {"__type": DAT.ASSET_ANY, "objects": [encode_asset_condition(x) for x in var.objects]}
return {
"__type": DAT.ASSET_ANY,
"objects": [encode_asset_condition(x) for x in var.objects],
}
raise ValueError(f"serialization not implemented for {type(var).__name__!r}")


Expand All @@ -274,13 +292,13 @@ def decode_asset_condition(var: dict[str, Any]) -> BaseAsset:
"""
dat = var["__type"]
if dat == DAT.ASSET:
return Asset(uri=var["uri"], name=var["name"], extra=var["extra"])
return Asset(name=var["name"], uri=var["uri"], group=var["group"], extra=var["extra"])
if dat == DAT.ASSET_ALL:
return AssetAll(*(decode_asset_condition(x) for x in var["objects"]))
if dat == DAT.ASSET_ANY:
return AssetAny(*(decode_asset_condition(x) for x in var["objects"]))
if dat == DAT.ASSET_ALIAS:
return AssetAlias(name=var["name"])
return AssetAlias(name=var["name"], group=var["group"])
raise ValueError(f"deserialization not implemented for DAT {dat!r}")


Expand Down Expand Up @@ -586,7 +604,9 @@ def _is_excluded(cls, var: Any, attrname: str, instance: Any) -> bool:

@classmethod
def serialize_to_json(
cls, object_to_serialize: BaseOperator | MappedOperator | DAG, decorated_fields: set
cls,
object_to_serialize: BaseOperator | MappedOperator | DAG,
decorated_fields: set,
) -> dict[str, Any]:
"""Serialize an object to JSON."""
serialized_object: dict[str, Any] = {}
Expand Down Expand Up @@ -653,7 +673,11 @@ def serialize(
return cls._encode(json_pod, type_=DAT.POD)
elif isinstance(var, OutletEventAccessors):
return cls._encode(
cls.serialize(var._dict, strict=strict, use_pydantic_models=use_pydantic_models), # type: ignore[attr-defined]
cls.serialize(
var._dict, # type: ignore[attr-defined]
strict=strict,
use_pydantic_models=use_pydantic_models,
),
type_=DAT.ASSET_EVENT_ACCESSORS,
)
elif isinstance(var, OutletEventAccessor):
Expand Down Expand Up @@ -696,15 +720,23 @@ def serialize(
elif isinstance(var, (KeyError, AttributeError)):
return cls._encode(
cls.serialize(
{"exc_cls_name": var.__class__.__name__, "args": [var.args], "kwargs": {}},
{
"exc_cls_name": var.__class__.__name__,
"args": [var.args],
"kwargs": {},
},
use_pydantic_models=use_pydantic_models,
strict=strict,
),
type_=DAT.BASE_EXC_SER,
)
elif isinstance(var, BaseTrigger):
return cls._encode(
cls.serialize(var.serialize(), use_pydantic_models=use_pydantic_models, strict=strict),
cls.serialize(
var.serialize(),
use_pydantic_models=use_pydantic_models,
strict=strict,
),
type_=DAT.BASE_TRIGGER,
)
elif callable(var):
Expand Down Expand Up @@ -1065,11 +1097,11 @@ def detect_task_dependencies(task: Operator) -> list[DagDependency]:
source=task.dag_id,
target="asset",
dependency_type="asset",
dependency_id=obj.uri,
dependency_id=obj.name,
)
)
elif isinstance(obj, AssetAlias):
cond = AssetAliasCondition(obj.name)
cond = AssetAliasCondition(name=obj.name, group=obj.group)

deps.extend(cond.iter_dag_dependencies(source=task.dag_id, target=""))
return deps
Expand Down Expand Up @@ -1298,7 +1330,11 @@ def populate_operator(cls, op: Operator, encoded_op: dict[str, Any]) -> None:
# The case for "If OperatorLinks are defined in the operator that is being Serialized"
# is handled in the deserialization loop where it matches k == "_operator_extra_links"
if op_extra_links_from_plugin and "_operator_extra_links" not in encoded_op:
setattr(op, "operator_extra_links", list(op_extra_links_from_plugin.values()))
setattr(
op,
"operator_extra_links",
list(op_extra_links_from_plugin.values()),
)

for k, v in encoded_op.items():
# python_callable_name only serves to detect function name changes
Expand Down
4 changes: 2 additions & 2 deletions airflow/timetables/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@
from collections.abc import Iterator, Sequence
from typing import TYPE_CHECKING, Any, NamedTuple

from airflow.sdk.definitions.asset import BaseAsset
from airflow.sdk.definitions.asset import AssetUniqueKey, BaseAsset
from airflow.typing_compat import Protocol, runtime_checkable

if TYPE_CHECKING:
Expand Down Expand Up @@ -55,7 +55,7 @@ def as_expression(self) -> Any:
def evaluate(self, statuses: dict[str, bool]) -> bool:
return False

def iter_assets(self) -> Iterator[tuple[str, Asset]]:
def iter_assets(self) -> Iterator[tuple[AssetUniqueKey, Asset]]:
return iter(())

def iter_asset_aliases(self) -> Iterator[tuple[str, AssetAlias]]:
Expand Down
2 changes: 1 addition & 1 deletion airflow/timetables/simple.py
Original file line number Diff line number Diff line change
Expand Up @@ -170,7 +170,7 @@ def __init__(self, assets: BaseAsset) -> None:
super().__init__()
self.asset_condition = assets
if isinstance(self.asset_condition, AssetAlias):
self.asset_condition = AssetAliasCondition(self.asset_condition.name)
self.asset_condition = AssetAliasCondition.from_asset_alias(self.asset_condition)

if not next(self.asset_condition.iter_assets(), False):
self._summary = AssetTriggeredTimetable.UNRESOLVED_ALIAS_SUMMARY
Expand Down
Loading

0 comments on commit 06c3823

Please sign in to comment.