From 3fc57de44106afdd7d684ddf692f1f810a6ec7c3 Mon Sep 17 00:00:00 2001 From: Tzu-ping Chung Date: Tue, 31 Aug 2021 16:15:04 +0800 Subject: [PATCH] Allow custom timetable as a DAG argument (#17414) --- airflow/api/common/experimental/mark_tasks.py | 8 +- airflow/jobs/backfill_job.py | 2 +- airflow/models/dag.py | 92 ++++--- airflow/models/dagrun.py | 4 +- airflow/models/taskinstance.py | 7 +- airflow/serialization/schema.json | 12 +- airflow/serialization/serialized_objects.py | 102 ++++++-- airflow/timetables/base.py | 61 ++++- airflow/timetables/interval.py | 234 ++++++++++++++++-- airflow/timetables/schedules.py | 207 ---------------- airflow/timetables/simple.py | 56 +++-- airflow/utils/module_loading.py | 5 + airflow/www/views.py | 2 +- tests/jobs/test_backfill_job.py | 3 +- tests/models/test_dag.py | 10 +- .../aws/sensors/test_s3_keys_unchanged.py | 13 +- .../amazon/aws/transfers/test_s3_to_sftp.py | 11 +- .../amazon/aws/transfers/test_sftp_to_s3.py | 11 +- .../google/cloud/sensors/test_gcs.py | 12 +- tests/providers/sftp/operators/test_sftp.py | 11 +- tests/providers/ssh/operators/test_ssh.py | 12 +- tests/serialization/test_dag_serialization.py | 108 +++++++- tests/test_utils/timetables.py | 17 ++ 23 files changed, 626 insertions(+), 374 deletions(-) delete mode 100644 airflow/timetables/schedules.py diff --git a/airflow/api/common/experimental/mark_tasks.py b/airflow/api/common/experimental/mark_tasks.py index 30ecd4200c46d..08b7363b8a388 100644 --- a/airflow/api/common/experimental/mark_tasks.py +++ b/airflow/api/common/experimental/mark_tasks.py @@ -249,13 +249,13 @@ def get_execution_dates(dag, execution_date, future, past): else: start_date = execution_date start_date = execution_date if not past else start_date - if dag.schedule_interval == '@once': - dates = [start_date] - elif not dag.schedule_interval: - # If schedule_interval is None, need to look at existing DagRun if the user wants future or + if not dag.timetable.can_run: + # If the DAG never schedules, need to look at existing DagRun if the user wants future or # past runs. 
dag_runs = dag.get_dagruns_between(start_date=start_date, end_date=end_date) dates = sorted({d.execution_date for d in dag_runs}) + elif not dag.timetable.periodic: + dates = [start_date] else: dates = [ info.logical_date for info in dag.iter_dagrun_infos_between(start_date, end_date, align=False) diff --git a/airflow/jobs/backfill_job.py b/airflow/jobs/backfill_job.py index 8e120a999573a..76d11007481a6 100644 --- a/airflow/jobs/backfill_job.py +++ b/airflow/jobs/backfill_job.py @@ -298,7 +298,7 @@ def _get_dag_run(self, dagrun_info: DagRunInfo, dag: DAG, session: Session = Non run_date = dagrun_info.logical_date # consider max_active_runs but ignore when running subdags - respect_dag_max_active_limit = bool(dag.schedule_interval and not dag.is_subdag) + respect_dag_max_active_limit = bool(dag.timetable.can_run and not dag.is_subdag) current_active_dag_count = dag.get_num_active_runs(external_trigger=False) diff --git a/airflow/models/dag.py b/airflow/models/dag.py index c5994fdfca8b0..21f3d4eaf3a88 100644 --- a/airflow/models/dag.py +++ b/airflow/models/dag.py @@ -27,7 +27,7 @@ import traceback import warnings from collections import OrderedDict -from datetime import datetime, timedelta +from datetime import datetime, timedelta, tzinfo from inspect import signature from typing import ( TYPE_CHECKING, @@ -72,7 +72,6 @@ from airflow.stats import Stats from airflow.timetables.base import DagRunInfo, TimeRestriction, Timetable from airflow.timetables.interval import CronDataIntervalTimetable, DeltaDataIntervalTimetable -from airflow.timetables.schedules import Schedule from airflow.timetables.simple import NullTimetable, OnceTimetable from airflow.typing_compat import Literal, RePatternType from airflow.utils import timezone @@ -92,11 +91,34 @@ log = logging.getLogger(__name__) -ScheduleInterval = Union[str, timedelta, relativedelta] DEFAULT_VIEW_PRESETS = ['tree', 'graph', 'duration', 'gantt', 'landing_times'] ORIENTATION_PRESETS = ['LR', 'TB', 'RL', 'BT'] +ScheduleIntervalArgNotSet = type("ScheduleIntervalArgNotSet", (), {}) + DagStateChangeCallback = Callable[[Context], None] +ScheduleInterval = Union[str, timedelta, relativedelta] +ScheduleIntervalArg = Union[ScheduleInterval, None, Type[ScheduleIntervalArgNotSet]] + + +# Backward compatibility: If neither schedule_interval nor timetable is +# *provided by the user*, default to a one-day interval. 
+DEFAULT_SCHEDULE_INTERVAL = timedelta(days=1) + + +def create_timetable(interval: ScheduleIntervalArg, timezone: tzinfo) -> Timetable: + """Create a Timetable instance from a ``schedule_interval`` argument.""" + if interval is ScheduleIntervalArgNotSet: + return DeltaDataIntervalTimetable(DEFAULT_SCHEDULE_INTERVAL) + if interval is None: + return NullTimetable() + if interval == "@once": + return OnceTimetable() + if isinstance(interval, (timedelta, relativedelta)): + return DeltaDataIntervalTimetable(interval) + if isinstance(interval, str): + return CronDataIntervalTimetable(interval, timezone) + raise ValueError(f"{interval!r} is not a valid schedule_interval.") def get_last_dagrun(dag_id, session, include_externally_triggered=False): @@ -256,7 +278,8 @@ def __init__( self, dag_id: str, description: Optional[str] = None, - schedule_interval: Optional[ScheduleInterval] = timedelta(days=1), + schedule_interval: ScheduleIntervalArg = ScheduleIntervalArgNotSet, + timetable: Optional[Timetable] = None, start_date: Optional[datetime] = None, end_date: Optional[datetime] = None, full_filepath: Optional[str] = None, @@ -349,7 +372,18 @@ def __init__( if 'end_date' in self.default_args: self.default_args['end_date'] = timezone.convert_to_utc(self.default_args['end_date']) - self.schedule_interval = schedule_interval + # Calculate the DAG's timetable. + if timetable is None: + self.timetable = create_timetable(schedule_interval, self.timezone) + if schedule_interval is ScheduleIntervalArgNotSet: + schedule_interval = DEFAULT_SCHEDULE_INTERVAL + self.schedule_interval: ScheduleInterval = schedule_interval + elif schedule_interval is ScheduleIntervalArgNotSet: + self.timetable = timetable + self.schedule_interval = self.timetable.summary + else: + raise TypeError("cannot specify both 'schedule_interval' and 'timetable'") + if isinstance(template_searchpath, str): template_searchpath = [template_searchpath] self.template_searchpath = template_searchpath @@ -494,7 +528,7 @@ def is_fixed_time_schedule(self): stacklevel=2, ) try: - return not self.timetable._schedule._should_fix_dst + return not self.timetable._should_fix_dst except AttributeError: return True @@ -505,24 +539,25 @@ def following_schedule(self, dttm): :param dttm: utc datetime :return: utc datetime """ - current = pendulum.instance(dttm) - between = TimeRestriction(earliest=None, latest=None, catchup=True) - next_info = self.timetable.next_dagrun_info(current, between) + next_info = self.timetable.next_dagrun_info( + last_automated_dagrun=pendulum.instance(dttm), + restriction=TimeRestriction(earliest=None, latest=None, catchup=True), + ) if next_info is None: return None return next_info.data_interval.start def previous_schedule(self, dttm): + from airflow.timetables.interval import _DataIntervalTimetable + warnings.warn( "`DAG.previous_schedule()` is deprecated.", category=DeprecationWarning, stacklevel=2, ) - try: - schedule: Schedule = self.timetable._schedule - except AttributeError: + if not isinstance(self.timetable, _DataIntervalTimetable): return None - return schedule.get_prev(pendulum.instance(dttm)) + return self.timetable._get_prev(pendulum.instance(dttm)) def next_dagrun_info( self, @@ -551,8 +586,8 @@ def next_dagrun_info( # and someone is passing datetime.datetime into this function. We should # fix whatever is doing that. 
        return self.timetable.next_dagrun_info(
-            timezone.coerce_datetime(date_last_automated_dagrun),
-            self._time_restriction,
+            last_automated_dagrun=timezone.coerce_datetime(date_last_automated_dagrun),
+            restriction=self._time_restriction,
         )
 
     def next_dagrun_after_date(self, date_last_automated_dagrun: Optional[pendulum.DateTime]):
@@ -584,20 +619,6 @@ def _time_restriction(self) -> TimeRestriction:
             latest = None
         return TimeRestriction(earliest, latest, self.catchup)
 
-    @cached_property
-    def timetable(self) -> Timetable:
-        interval = self.schedule_interval
-        if interval is None:
-            return NullTimetable()
-        if interval == "@once":
-            return OnceTimetable()
-        if isinstance(interval, (timedelta, relativedelta)):
-            return DeltaDataIntervalTimetable(interval)
-        if isinstance(interval, str):
-            return CronDataIntervalTimetable(interval, self.timezone)
-        type_name = type(interval).__name__
-        raise TypeError(f"{type_name} is not a valid DAG.schedule_interval.")
-
     def iter_dagrun_infos_between(
         self,
         earliest: Optional[pendulum.DateTime],
@@ -637,7 +658,7 @@ def iter_dagrun_infos_between(
         if self.is_subdag:
             align = False
 
-        info = self.timetable.next_dagrun_info(None, restriction)
+        info = self.timetable.next_dagrun_info(last_automated_dagrun=None, restriction=restriction)
         if info is None:
             # No runs to be scheduled within the user-supplied timeframe. But
             # if align=False, "invent" a data interval for the timeframe itself.
@@ -653,7 +674,10 @@ def iter_dagrun_infos_between(
             # Generate naturally according to schedule.
             while info is not None:
                 yield info
-                info = self.timetable.next_dagrun_info(info.logical_date, restriction)
+                info = self.timetable.next_dagrun_info(
+                    last_automated_dagrun=info.logical_date,
+                    restriction=restriction,
+                )
 
     def get_run_dates(self, start_date, end_date=None):
         """
@@ -844,7 +868,7 @@ def owner(self) -> str:
 
     @property
     def allow_future_exec_dates(self) -> bool:
-        return settings.ALLOW_FUTURE_EXEC_DATES and self.schedule_interval is None
+        return settings.ALLOW_FUTURE_EXEC_DATES and not self.timetable.can_run
 
     @provide_session
     def get_concurrency_reached(self, session=None) -> bool:
@@ -2112,7 +2136,9 @@ def create_dagrun(
         )
 
         if run_type == DagRunType.MANUAL and data_interval is None and execution_date is not None:
-            data_interval = self.timetable.infer_data_interval(timezone.coerce_datetime(execution_date))
+            data_interval = self.timetable.infer_data_interval(
+                run_after=timezone.coerce_datetime(execution_date),
+            )
 
         run = DagRun(
             dag_id=self.dag_id,
diff --git a/airflow/models/dagrun.py b/airflow/models/dagrun.py
index cdc036edf63fc..797943441d42c 100644
--- a/airflow/models/dagrun.py
+++ b/airflow/models/dagrun.py
@@ -581,7 +581,7 @@ def _emit_true_scheduling_delay_stats_for_finished_state(self, finished_tis):
         This method will be used in the update_state method when the state of the DagRun
         is updated to a completed status (either success or failure). The method will find the first
         started task within the DAG and calculate the expected DagRun start time (based on
-        dag.execution_date & dag.schedule_interval), and minus these two values to get the delay.
+        dag.execution_date & dag.timetable), and subtract the two values to get the delay.
         The emitted data may contain outliers (e.g. when the first task was cleared, so
         the second task's start_date will be used), but we can get rid of the outliers
         on the stats side through the dashboards tooling built.
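
Taken together, the ``dag.py`` changes above give ``DAG`` two mutually exclusive scheduling arguments: the classic ``schedule_interval`` (converted internally via ``create_timetable()``) and the new ``timetable``. A minimal sketch of the resulting API — not part of the patch; the dag ids and dates are illustrative:

```python
import pendulum

from airflow import DAG
from airflow.timetables.interval import CronDataIntervalTimetable

# Classic spelling: the cron string is converted internally by
# create_timetable() into a CronDataIntervalTimetable.
dag1 = DAG(
    "by_schedule_interval",  # illustrative dag id
    schedule_interval="0 0 * * *",
    start_date=pendulum.datetime(2021, 1, 1, tz="UTC"),
)

# New spelling: pass a Timetable instance directly; DAG back-fills
# schedule_interval with the timetable's summary string.
dag2 = DAG(
    "by_timetable",  # illustrative dag id
    timetable=CronDataIntervalTimetable("0 0 * * *", pendulum.timezone("UTC")),
    start_date=pendulum.datetime(2021, 1, 1, tz="UTC"),
)

assert dag1.timetable.summary == dag2.timetable.summary == "0 0 * * *"
assert dag2.schedule_interval == "0 0 * * *"

# Supplying both arguments raises:
# TypeError: cannot specify both 'schedule_interval' and 'timetable'
```
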
@@ -598,7 +598,7 @@ def _emit_true_scheduling_delay_stats_for_finished_state(self, finished_tis): try: dag = self.get_dag() - if not self.dag.schedule_interval or self.dag.schedule_interval == "@once": + if not self.dag.timetable.periodic: # We can't emit this metric if there is no following schedule to calculate from! return diff --git a/airflow/models/taskinstance.py b/airflow/models/taskinstance.py index 39fa7dc04a4ec..56546e9c7a5c4 100644 --- a/airflow/models/taskinstance.py +++ b/airflow/models/taskinstance.py @@ -880,10 +880,11 @@ def get_previous_dagrun( dr.dag = dag - # We always ignore schedule in dagrun lookup when `state` is given or `schedule_interval is None`. - # For legacy reasons, when `catchup=True`, we use `get_previous_scheduled_dagrun` unless + # We always ignore schedule in dagrun lookup when `state` is given + # or the DAG is never scheduled. For legacy reasons, when + # `catchup=True`, we use `get_previous_scheduled_dagrun` unless # `ignore_schedule` is `True`. - ignore_schedule = state is not None or dag.schedule_interval is None + ignore_schedule = state is not None or not dag.timetable.can_run if dag.catchup is True and not ignore_schedule: last_dagrun = dr.get_previous_scheduled_dagrun(session=session) else: diff --git a/airflow/serialization/schema.json b/airflow/serialization/schema.json index d5ec8942525e1..b4a64b4e79459 100644 --- a/airflow/serialization/schema.json +++ b/airflow/serialization/schema.json @@ -48,7 +48,10 @@ } }, "timezone": { - "type": "string" + "anyOf": [ + { "type": "string" }, + { "type": "integer" } + ] }, "dict": { "description": "A python dictionary containing values of any type", @@ -87,6 +90,13 @@ { "$ref": "#/definitions/typed_relativedelta" } ] }, + "timetable": { + "type": "object", + "properties": { + "type": { "type": "string" }, + "value": { "$ref": "#/definitions/dict" } + } + }, "catchup": { "type": "boolean" }, "is_subdag": { "type": "boolean" }, "fileloc": { "type" : "string"}, diff --git a/airflow/serialization/serialized_objects.py b/airflow/serialization/serialized_objects.py index e1fd91341802d..4978736d15d57 100644 --- a/airflow/serialization/serialized_objects.py +++ b/airflow/serialization/serialized_objects.py @@ -33,20 +33,21 @@ from functools import lru_cache cache = lru_cache(maxsize=None) -from pendulum.tz.timezone import Timezone +from pendulum.tz.timezone import FixedTimezone, Timezone from airflow.configuration import conf from airflow.exceptions import AirflowException, SerializationError from airflow.models.baseoperator import BaseOperator, BaseOperatorLink from airflow.models.connection import Connection -from airflow.models.dag import DAG +from airflow.models.dag import DAG, create_timetable from airflow.providers_manager import ProvidersManager from airflow.serialization.enums import DagAttributeTypes as DAT, Encoding from airflow.serialization.helpers import serialize_template_field from airflow.serialization.json_schema import Validator, load_dag_schema from airflow.settings import json +from airflow.timetables.base import Timetable from airflow.utils.code_utils import get_python_source -from airflow.utils.module_loading import import_string +from airflow.utils.module_loading import as_importable_string, import_string from airflow.utils.task_group import TaskGroup try: @@ -87,6 +88,64 @@ def get_operator_extra_links(): return _OPERATOR_EXTRA_LINKS +def encode_relativedelta(var: relativedelta.relativedelta) -> Dict[str, Any]: + encoded = {k: v for k, v in var.__dict__.items() if not k.startswith("_") 
and v}
+    if var.weekday and var.weekday.n:
+        # Every n'th Friday for example
+        encoded['weekday'] = [var.weekday.weekday, var.weekday.n]
+    elif var.weekday:
+        encoded['weekday'] = [var.weekday.weekday]
+    return encoded
+
+
+def decode_relativedelta(var: Dict[str, Any]) -> relativedelta.relativedelta:
+    if 'weekday' in var:
+        var['weekday'] = relativedelta.weekday(*var['weekday'])  # type: ignore
+    return relativedelta.relativedelta(**var)
+
+
+def encode_timezone(var: Timezone) -> Union[str, int]:
+    """Encode a Pendulum Timezone for serialization.
+
+    Airflow only supports timezone objects that implement Pendulum's Timezone
+    interface. We try to keep as much information as possible to make conversion
+    round-tripping possible (see ``decode_timezone``). We need to special-case
+    UTC; Pendulum implements it as a FixedTimezone (i.e. it gets encoded as
+    0 without the special case), but passing 0 into ``pendulum.timezone`` does
+    not give us UTC (but ``+00:00``).
+    """
+    if isinstance(var, FixedTimezone):
+        if var.offset == 0:
+            return "UTC"
+        return var.offset
+    if isinstance(var, Timezone):
+        return var.name
+    raise ValueError(f"DAG timezone should be a pendulum.tz.Timezone, not {var!r}")
+
+
+def decode_timezone(var: Union[str, int]) -> Timezone:
+    """Decode a previously serialized Pendulum Timezone."""
+    return pendulum.timezone(var)
+
+
+def encode_timetable(var: Timetable) -> Dict[str, Any]:
+    """Encode a timetable instance.
+
+    This delegates most of the serialization work to the type, so the behavior
+    can be completely controlled by a custom subclass.
+    """
+    return {"type": as_importable_string(type(var)), "value": var.serialize()}
+
+
+def decode_timetable(var: Dict[str, Any]) -> Timetable:
+    """Decode a previously serialized timetable.
+
+    Most of the deserialization logic is delegated to the actual type, which
+    we import from string.
+ """ + return import_string(var["type"]).deserialize(var["value"]) + + class BaseSerialization: """BaseSerialization provides utils for serialization.""" @@ -188,6 +247,8 @@ def serialize_to_json( if key in decorated_fields: serialized_object[key] = cls._serialize(value) + elif key == "timetable": + serialized_object[key] = encode_timetable(value) else: value = cls._serialize(value) if isinstance(value, dict) and "__type" in value: @@ -228,15 +289,9 @@ def _serialize(cls, var: Any) -> Any: # Unfortunately there is no support for r elif isinstance(var, datetime.timedelta): return cls._encode(var.total_seconds(), type_=DAT.TIMEDELTA) elif isinstance(var, Timezone): - return cls._encode(str(var.name), type_=DAT.TIMEZONE) + return cls._encode(encode_timezone(var), type_=DAT.TIMEZONE) elif isinstance(var, relativedelta.relativedelta): - encoded = {k: v for k, v in var.__dict__.items() if not k.startswith("_") and v} - if var.weekday and var.weekday.n: - # Every n'th Friday for example - encoded['weekday'] = [var.weekday.weekday, var.weekday.n] - elif var.weekday: - encoded['weekday'] = [var.weekday.weekday] - return cls._encode(encoded, type_=DAT.RELATIVEDELTA) + return cls._encode(encode_relativedelta(var), type_=DAT.RELATIVEDELTA) elif callable(var): return str(get_python_source(var)) elif isinstance(var, set): @@ -284,11 +339,9 @@ def _deserialize(cls, encoded_var: Any) -> Any: elif type_ == DAT.TIMEDELTA: return datetime.timedelta(seconds=var) elif type_ == DAT.TIMEZONE: - return Timezone(var) + return decode_timezone(var) elif type_ == DAT.RELATIVEDELTA: - if 'weekday' in var: - var['weekday'] = relativedelta.weekday(*var['weekday']) # type: ignore - return relativedelta.relativedelta(**var) + return decode_relativedelta(var) elif type_ == DAT.SET: return {cls._deserialize(v) for v in var} elif type_ == DAT.TUPLE: @@ -678,6 +731,13 @@ def serialize_dag(cls, dag: DAG) -> dict: try: serialize_dag = cls.serialize_to_json(dag, cls._decorated_fields) + # If schedule_interval is backed by timetable, serialize only + # timetable; vice versa for a timetable backed by schedule_interval. + if dag.timetable.summary == dag.schedule_interval: + del serialize_dag["schedule_interval"] + else: + del serialize_dag["timetable"] + serialize_dag["tasks"] = [cls._serialize(task) for _, task in dag.task_dict.items()] serialize_dag["dag_dependencies"] = [ vars(t) @@ -716,19 +776,29 @@ def deserialize_dag(cls, encoded_dag: Dict[str, Any]) -> 'SerializedDAG': k = "task_dict" elif k == "timezone": v = cls._deserialize_timezone(v) - elif k in {"dagrun_timeout"}: + elif k == "dagrun_timeout": v = cls._deserialize_timedelta(v) elif k.endswith("_date"): v = cls._deserialize_datetime(v) elif k == "edge_info": # Value structure matches exactly pass + elif k == "timetable": + v = decode_timetable(v) elif k in cls._decorated_fields: v = cls._deserialize(v) # else use v as it is setattr(dag, k, v) + # A DAG is always serialized with only one of schedule_interval and + # timetable. This back-populates the other to ensure the two attributes + # line up correctly on the DAG instance. 
+ if "timetable" in encoded_dag: + dag.schedule_interval = dag.timetable.summary + else: + dag.timetable = create_timetable(dag.schedule_interval, dag.timezone) + # Set _task_group if "_task_group" in encoded_dag: diff --git a/airflow/timetables/base.py b/airflow/timetables/base.py index 70278511adaf3..2f1d0980f6416 100644 --- a/airflow/timetables/base.py +++ b/airflow/timetables/base.py @@ -15,7 +15,7 @@ # specific language governing permissions and limitations # under the License. -from typing import NamedTuple, Optional +from typing import Any, Dict, NamedTuple, Optional from pendulum import DateTime @@ -94,28 +94,79 @@ def logical_date(self) -> DateTime: class Timetable(Protocol): """Protocol that all Timetable classes are expected to implement.""" + periodic: bool = True + """Whether this timetable runs periodically. + + This defaults to and should generally be *True*, but some special setups + like ``schedule_interval=None`` and ``"@once"`` set it to *False*. + """ + + can_run: bool = True + """Whether this timetable can actually schedule runs. + + This defaults to and should generally be *True*, but ``NullTimetable`` sets + this to *False*. + """ + + @classmethod + def deserialize(cls, data: Dict[str, Any]) -> "Timetable": + """Deserialize a timetable from data. + + This is called when a serialized DAG is deserialized. ``data`` will be + whatever was returned by ``serialize`` during DAG serialization. The + default implementation constructs the timetable without any arguments. + """ + return cls() + + def serialize(self) -> Dict[str, Any]: + """Serialize the timetable for JSON encoding. + + This is called during DAG serialization to store timetable information + in the database. This should return a JSON-serializable dict that will + be fed into ``deserialize`` when the DAG is deserialized. The default + implementation returns an empty dict. + """ + return {} + def validate(self) -> None: """Validate the timetable is correctly specified. - This should raise AirflowTimetableInvalid on validation failure. + Override this method to provide run-time validation raised when a DAG + is put into a dagbag. The default implementation does nothing. + + :raises: AirflowTimetableInvalid on validation failure. """ - raise NotImplementedError() + pass - def infer_data_interval(self, run_after: DateTime) -> DataInterval: + @property + def summary(self) -> str: + """A short summary for the timetable. + + This is used to display the timetable in the web UI. A cron expression + timetable, for example, can use this to display the expression. The + default implementation returns the timetable's type name. + """ + return type(self).__name__ + + def infer_data_interval(self, *, run_after: DateTime) -> DataInterval: """When a DAG run is manually triggered, infer a data interval for it. This is used for e.g. manually-triggered runs, where ``run_after`` would - be when the user triggers the run. + be when the user triggers the run. The default implementation raises + ``NotImplementedError``. """ raise NotImplementedError() def next_dagrun_info( self, + *, last_automated_dagrun: Optional[DateTime], restriction: TimeRestriction, ) -> Optional[DagRunInfo]: """Provide information to schedule the next DagRun. + The default implementation raises ``NotImplementedError``. + :param last_automated_dagrun: The ``execution_date`` of the associated DAG's last scheduled or backfilled run (manual runs not considered). :param restriction: Restriction to apply when scheduling the DAG run. 
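
The protocol above is the entire surface a custom timetable must implement. For concreteness, here is a minimal hypothetical subclass — a sketch only; the class name and behaviour are invented for illustration — that schedules one run per UTC calendar day:

```python
from typing import Any, Dict, Optional

from pendulum import DateTime

from airflow.timetables.base import DagRunInfo, DataInterval, TimeRestriction, Timetable


class UtcDailyTimetable(Timetable):
    """Hypothetical timetable: one run per UTC calendar day (sketch only)."""

    def serialize(self) -> Dict[str, Any]:
        # Nothing to persist; the inherited deserialize() just calls cls().
        return {}

    @property
    def summary(self) -> str:
        return "utc-daily"

    def infer_data_interval(self, *, run_after: DateTime) -> DataInterval:
        # Manually triggered runs cover the previous full UTC day.
        end = run_after.start_of("day")
        return DataInterval(start=end.subtract(days=1), end=end)

    def next_dagrun_info(
        self,
        *,
        last_automated_dagrun: Optional[DateTime],
        restriction: TimeRestriction,
    ) -> Optional[DagRunInfo]:
        if last_automated_dagrun is None:
            if restriction.earliest is None:
                return None  # No start_date; nothing to schedule against.
            start = restriction.earliest.start_of("day")
        else:
            # last_automated_dagrun is the start of the last data interval.
            start = last_automated_dagrun.start_of("day").add(days=1)
        if restriction.latest is not None and start > restriction.latest:
            return None
        # A fuller implementation would also honor restriction.catchup, as
        # the interval timetables below do via _skip_to_latest().
        return DagRunInfo.interval(start=start, end=start.add(days=1))
```
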
diff --git a/airflow/timetables/interval.py b/airflow/timetables/interval.py
index 168e28f8ed29b..de8a566a6966c 100644
--- a/airflow/timetables/interval.py
+++ b/airflow/timetables/interval.py
@@ -16,12 +16,20 @@
 # under the License.
 
 import datetime
-from typing import Any, Optional
+from typing import Any, Dict, Optional, Union
 
+from croniter import CroniterBadCronError, CroniterBadDateError, croniter
+from dateutil.relativedelta import relativedelta
 from pendulum import DateTime
+from pendulum.tz.timezone import Timezone
 
+from airflow.compat.functools import cached_property
+from airflow.exceptions import AirflowTimetableInvalid
 from airflow.timetables.base import DagRunInfo, DataInterval, TimeRestriction, Timetable
-from airflow.timetables.schedules import CronSchedule, Delta, DeltaSchedule, Schedule
+from airflow.utils.dates import cron_presets
+from airflow.utils.timezone import convert_to_utc, make_aware, make_naive
+
+Delta = Union[datetime.timedelta, relativedelta]
 
 
 class _DataIntervalTimetable(Timetable):
@@ -32,59 +40,186 @@ class _DataIntervalTimetable(Timetable):
     instance), and schedule a DagRun at the end of each interval.
     """
 
-    _schedule: Schedule
+    def _skip_to_latest(self, earliest: Optional[DateTime]) -> DateTime:
+        """Bound the earliest time a run can be scheduled.
 
-    def __eq__(self, other: Any) -> bool:
-        """Delegate to the schedule."""
-        if not isinstance(other, _DataIntervalTimetable):
-            return NotImplemented
-        return self._schedule == other._schedule
+        This is called when ``catchup=False``. See docstring of subclasses for
+        exact skipping behaviour of a schedule.
+        """
+        raise NotImplementedError()
 
-    def validate(self) -> None:
-        self._schedule.validate()
+    def _align(self, current: DateTime) -> DateTime:
+        """Align the given time to the schedule.
+
+        For fixed schedules (e.g. every midnight), this finds the next time that
+        aligns to the declared time, if the given time does not align. If the
+        schedule is not fixed (e.g. every hour), the given time is returned.
+        """
+        raise NotImplementedError()
+
+    def _get_next(self, current: DateTime) -> DateTime:
+        """Get the first schedule after the current time."""
+        raise NotImplementedError()
+
+    def _get_prev(self, current: DateTime) -> DateTime:
+        """Get the last schedule before the current time."""
+        raise NotImplementedError()
 
     def next_dagrun_info(
         self,
+        *,
         last_automated_dagrun: Optional[DateTime],
         restriction: TimeRestriction,
     ) -> Optional[DagRunInfo]:
         earliest = restriction.earliest
         if not restriction.catchup:
-            earliest = self._schedule.skip_to_latest(earliest)
+            earliest = self._skip_to_latest(earliest)
         if last_automated_dagrun is None:
             # First run; schedule the run at the first available time matching
             # the schedule, and retrospectively create a data interval for it.
             if earliest is None:
                 return None
-            start = self._schedule.align(earliest)
+            start = self._align(earliest)
         else:
             # There's a previous run. Create a data interval starting from
             # the end of the previous interval.
-            start = self._schedule.get_next(last_automated_dagrun)
+            start = self._get_next(last_automated_dagrun)
         if restriction.latest is not None and start > restriction.latest:
             return None
-        end = self._schedule.get_next(start)
+        end = self._get_next(start)
         return DagRunInfo.interval(start=start, end=end)
 
 
+def _is_schedule_fixed(expression: str) -> bool:
+    """Figures out if the schedule has a fixed time (e.g. 3 AM every day).
+
+    :return: True if the schedule has a fixed time, False if not.
+
+    Detection is done by "peeking" the next two cron trigger times; if the
+    two times have the same minute and hour value, the schedule is fixed,
+    and we *don't* need to perform the DST fix.
+
+    This assumes DST happens on whole minute changes (e.g. 12:59 -> 12:00).
+    """
+    cron = croniter(expression)
+    next_a = cron.get_next(datetime.datetime)
+    next_b = cron.get_next(datetime.datetime)
+    return next_b.minute == next_a.minute and next_b.hour == next_a.hour
+
+
 class CronDataIntervalTimetable(_DataIntervalTimetable):
     """Timetable that schedules data intervals with a cron expression.
 
     This corresponds to ``schedule_interval=<cron>``, where ``<cron>`` is either
     a five/six-segment representation, or one of ``cron_presets``.
 
+    The implementation extends croniter to add timezone awareness. This is
+    because croniter works only with naive timestamps, and cannot consider DST
+    when determining the next/previous time.
+
     Don't pass ``@once`` in here; use ``OnceTimetable`` instead.
     """
 
-    def __init__(self, cron: str, timezone: datetime.tzinfo) -> None:
-        self._schedule = CronSchedule(cron, timezone)
+    def __init__(self, cron: str, timezone: Timezone) -> None:
+        self._expression = cron_presets.get(cron, cron)
+        self._timezone = timezone
 
-    def infer_data_interval(self, run_after: DateTime) -> DataInterval:
+    @classmethod
+    def deserialize(cls, data: Dict[str, Any]) -> "Timetable":
+        from airflow.serialization.serialized_objects import decode_timezone
+
+        return cls(data["expression"], decode_timezone(data["timezone"]))
+
+    def __eq__(self, other: Any) -> bool:
+        """Both expression and timezone should match.
+
+        This is only for testing purposes and should not be relied on otherwise.
+        """
+        if not isinstance(other, CronDataIntervalTimetable):
+            return NotImplemented
+        return self._expression == other._expression and self._timezone == other._timezone
+
+    @property
+    def summary(self) -> str:
+        return self._expression
+
+    def serialize(self) -> Dict[str, Any]:
+        from airflow.serialization.serialized_objects import encode_timezone
+
+        return {"expression": self._expression, "timezone": encode_timezone(self._timezone)}
+
+    def validate(self) -> None:
+        try:
+            croniter(self._expression)
+        except (CroniterBadCronError, CroniterBadDateError) as e:
+            raise AirflowTimetableInvalid(str(e))
+
+    @cached_property
+    def _should_fix_dst(self) -> bool:
+        # This is lazy so instantiating a schedule does not immediately raise
+        # an exception. Validity is checked with validate() during DAG-bagging.
+        return not _is_schedule_fixed(self._expression)
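
Because ``_is_schedule_fixed`` only peeks at croniter, the detection logic above can be exercised on its own; an illustrative standalone sketch:

```python
import datetime

from croniter import croniter


def is_schedule_fixed(expression: str) -> bool:
    # Peek the next two trigger times; same hour and minute => fixed time.
    cron = croniter(expression)
    next_a = cron.get_next(datetime.datetime)
    next_b = cron.get_next(datetime.datetime)
    return next_b.minute == next_a.minute and next_b.hour == next_a.hour


print(is_schedule_fixed("0 3 * * *"))     # True: 3 AM every day, no DST fix needed
print(is_schedule_fixed("0 * * * *"))     # False: hour varies, the DST fix applies
print(is_schedule_fixed("*/30 * * * *"))  # False: minute alternates between 0 and 30
```
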
+
+    def _get_next(self, current: DateTime) -> DateTime:
+        """Get the first schedule after specified time, with DST fixed."""
+        naive = make_naive(current, self._timezone)
+        cron = croniter(self._expression, start_time=naive)
+        scheduled = cron.get_next(datetime.datetime)
+        if not self._should_fix_dst:
+            return convert_to_utc(make_aware(scheduled, self._timezone))
+        delta = scheduled - naive
+        return convert_to_utc(current.in_timezone(self._timezone) + delta)
+
+    def _get_prev(self, current: DateTime) -> DateTime:
+        """Get the first schedule before specified time, with DST fixed."""
+        naive = make_naive(current, self._timezone)
+        cron = croniter(self._expression, start_time=naive)
+        scheduled = cron.get_prev(datetime.datetime)
+        if not self._should_fix_dst:
+            return convert_to_utc(make_aware(scheduled, self._timezone))
+        delta = naive - scheduled
+        return convert_to_utc(current.in_timezone(self._timezone) - delta)
+
+    def _align(self, current: DateTime) -> DateTime:
+        """Get the next scheduled time.
+
+        This is ``current + interval``, unless ``current`` falls exactly on an
+        interval boundary, in which case ``current`` is returned.
+        """
+        next_time = self._get_next(current)
+        if self._get_prev(next_time) != current:
+            return next_time
+        return current
+
+    def _skip_to_latest(self, earliest: Optional[DateTime]) -> DateTime:
+        """Bound the earliest time a run can be scheduled.
+
+        The logic is that we move start_date up until one period before, so the
+        current time is AFTER the period end, and the job can be created...
+
+        This is slightly different from the delta version at terminal values.
+        If the next schedule should start *right now*, we want the data interval
+        that starts right now, not the one that ends now.
+        """
+        current_time = DateTime.utcnow()
+        next_start = self._get_next(current_time)
+        last_start = self._get_prev(current_time)
+        if next_start == current_time:
+            new_start = last_start
+        elif next_start > current_time:
+            new_start = self._get_prev(last_start)
+        else:
+            raise AssertionError("next schedule shouldn't be earlier")
+        if earliest is None:
+            return new_start
+        return max(new_start, earliest)
+
+    def infer_data_interval(self, *, run_after: DateTime) -> DataInterval:
         # Get the last complete period before run_after, e.g. if a DAG run is
         # scheduled at each midnight, the data interval of a manually triggered
         # run at 1am 25th is between 0am 24th and 0am 25th.
-        end = self._schedule.get_prev(self._schedule.align(run_after))
-        return DataInterval(start=self._schedule.get_prev(end), end=end)
+        end = self._get_prev(self._align(run_after))
+        return DataInterval(start=self._get_prev(end), end=end)
 
 
 class DeltaDataIntervalTimetable(_DataIntervalTimetable):
@@ -96,7 +231,64 @@ class DeltaDataIntervalTimetable(_DataIntervalTimetable):
     """
 
     def __init__(self, delta: Delta) -> None:
-        self._schedule = DeltaSchedule(delta)
+        self._delta = delta
+
+    @classmethod
+    def deserialize(cls, data: Dict[str, Any]) -> "Timetable":
+        from airflow.serialization.serialized_objects import decode_relativedelta
+
+        delta = data["delta"]
+        if isinstance(delta, dict):
+            return cls(decode_relativedelta(delta))
+        return cls(datetime.timedelta(seconds=delta))
+
+    def __eq__(self, other: Any) -> bool:
+        """The offset should match.
+
+        This is only for testing purposes and should not be relied on otherwise.
+ """ + if not isinstance(other, DeltaDataIntervalTimetable): + return NotImplemented + return self._delta == other._delta + + @property + def summary(self) -> str: + return str(self._delta) + + def serialize(self) -> Dict[str, Any]: + from airflow.serialization.serialized_objects import encode_relativedelta + + if isinstance(self._delta, datetime.timedelta): + delta = self._delta.total_seconds() + else: + delta = encode_relativedelta(self._delta) + return {"delta": delta} + + def validate(self) -> None: + if self._delta.total_seconds() <= 0: + raise AirflowTimetableInvalid("schedule interval must be positive") + + def _get_next(self, current: DateTime) -> DateTime: + return convert_to_utc(current + self._delta) + + def _get_prev(self, current: DateTime) -> DateTime: + return convert_to_utc(current - self._delta) + + def _align(self, current: DateTime) -> DateTime: + return current + + def _skip_to_latest(self, earliest: Optional[DateTime]) -> DateTime: + """Bound the earliest time a run can be scheduled. + + The logic is that we move start_date up until one period before, so the + current time is AFTER the period end, and the job can be created... + + This is slightly different from the cron version at terminal values. + """ + new_start = self._get_prev(DateTime.utcnow()) + if earliest is None: + return new_start + return max(new_start, earliest) def infer_data_interval(self, run_after: DateTime) -> DataInterval: - return DataInterval(start=self._schedule.get_prev(run_after), end=run_after) + return DataInterval(start=self._get_prev(run_after), end=run_after) diff --git a/airflow/timetables/schedules.py b/airflow/timetables/schedules.py deleted file mode 100644 index 8129a8845bf68..0000000000000 --- a/airflow/timetables/schedules.py +++ /dev/null @@ -1,207 +0,0 @@ -# Licensed to the Apache Software Foundation (ASF) under one -# or more contributor license agreements. See the NOTICE file -# distributed with this work for additional information -# regarding copyright ownership. The ASF licenses this file -# to you under the Apache License, Version 2.0 (the -# "License"); you may not use this file except in compliance -# with the License. You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, -# software distributed under the License is distributed on an -# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY -# KIND, either express or implied. See the License for the -# specific language governing permissions and limitations -# under the License. - -import datetime -import typing - -from croniter import CroniterBadCronError, CroniterBadDateError, croniter -from dateutil.relativedelta import relativedelta -from pendulum import DateTime - -from airflow.compat.functools import cached_property -from airflow.exceptions import AirflowTimetableInvalid -from airflow.typing_compat import Protocol -from airflow.utils.dates import cron_presets -from airflow.utils.timezone import convert_to_utc, make_aware, make_naive - -Delta = typing.Union[datetime.timedelta, relativedelta] - - -class Schedule(Protocol): - """Base protocol for schedules.""" - - def skip_to_latest(self, earliest: typing.Optional[DateTime]) -> DateTime: - """Bound the earliest time a run can be scheduled. - - This is called when ``catchup=False``. See docstring of subclasses for - exact skipping behaviour of a schedule. 
- """ - raise NotImplementedError() - - def validate(self) -> None: - """Validate the timetable is correctly specified. - - This should raise AirflowTimetableInvalid on validation failure. - """ - raise NotImplementedError() - - def get_next(self, current: DateTime) -> DateTime: - """Get the first schedule after the current time.""" - raise NotImplementedError() - - def get_prev(self, current: DateTime) -> DateTime: - """Get the last schedule before the current time.""" - raise NotImplementedError() - - def align(self, current: DateTime) -> DateTime: - """Align given time to the scheduled. - - For fixed schedules (e.g. every midnight); this finds the next time that - aligns to the declared time, if the given time does not align. If the - schedule is not fixed (e.g. every hour), the given time is returned. - """ - raise NotImplementedError() - - -def _is_schedule_fixed(expression: str) -> bool: - """Figures out if the schedule has a fixed time (e.g. 3 AM every day). - - :return: True if the schedule has a fixed time, False if not. - - Detection is done by "peeking" the next two cron trigger time; if the - two times have the same minute and hour value, the schedule is fixed, - and we *don't* need to perform the DST fix. - - This assumes DST happens on whole minute changes (e.g. 12:59 -> 12:00). - """ - cron = croniter(expression) - next_a = cron.get_next(datetime.datetime) - next_b = cron.get_next(datetime.datetime) - return next_b.minute == next_a.minute and next_b.hour == next_a.hour - - -class CronSchedule(Schedule): - """Schedule things from a cron expression. - - The implementation extends on croniter to add timezone awareness. This is - because crontier works only with naive timestamps, and cannot consider DST - when determining the next/previous time. - """ - - def __init__(self, expression: str, timezone: datetime.tzinfo) -> None: - self._expression = expression = cron_presets.get(expression, expression) - self._timezone = timezone - - def __eq__(self, other: typing.Any) -> bool: - """Both expression and timezone should match.""" - if not isinstance(other, CronSchedule): - return NotImplemented - return self._expression == other._expression and self._timezone == other._timezone - - def validate(self) -> None: - try: - croniter(self._expression) - except (CroniterBadCronError, CroniterBadDateError) as e: - raise AirflowTimetableInvalid(str(e)) - - @cached_property - def _should_fix_dst(self) -> bool: - # This is lazy so instantiating a schedule does not immediately raise - # an exception. Validity is checked with validate() during DAG-bagging. 
- return not _is_schedule_fixed(self._expression) - - def get_next(self, current: DateTime) -> DateTime: - """Get the first schedule after specified time, with DST fixed.""" - naive = make_naive(current, self._timezone) - cron = croniter(self._expression, start_time=naive) - scheduled = cron.get_next(datetime.datetime) - if not self._should_fix_dst: - return convert_to_utc(make_aware(scheduled, self._timezone)) - delta = scheduled - naive - return convert_to_utc(current.in_timezone(self._timezone) + delta) - - def get_prev(self, current: DateTime) -> DateTime: - """Get the first schedule before specified time, with DST fixed.""" - naive = make_naive(current, self._timezone) - cron = croniter(self._expression, start_time=naive) - scheduled = cron.get_prev(datetime.datetime) - if not self._should_fix_dst: - return convert_to_utc(make_aware(scheduled, self._timezone)) - delta = naive - scheduled - return convert_to_utc(current.in_timezone(self._timezone) - delta) - - def align(self, current: DateTime) -> DateTime: - """Get the next scheduled time. - - This is ``current + interval``, unless ``current`` is first interval, - then ``current`` is returned. - """ - next_time = self.get_next(current) - if self.get_prev(next_time) != current: - return next_time - return current - - def skip_to_latest(self, earliest: typing.Optional[DateTime]) -> DateTime: - """Bound the earliest time a run can be scheduled. - - The logic is that we move start_date up until one period before, so the - current time is AFTER the period end, and the job can be created... - - This is slightly different from the delta version at terminal values. - If the next schedule should start *right now*, we want the data interval - that start right now now, not the one that ends now. - """ - current_time = DateTime.utcnow() - next_start = self.get_next(current_time) - last_start = self.get_prev(current_time) - if next_start == current_time: - new_start = last_start - elif next_start > current_time: - new_start = self.get_prev(last_start) - else: - raise AssertionError("next schedule shouldn't be earlier") - if earliest is None: - return new_start - return max(new_start, earliest) - - -class DeltaSchedule(Schedule): - """Schedule things on a fixed time delta.""" - - def __init__(self, delta: Delta) -> None: - self._delta = delta - - def __eq__(self, other: typing.Any) -> bool: - """The offset should match.""" - if not isinstance(other, DeltaSchedule): - return NotImplemented - return self._delta == other._delta - - def validate(self) -> None: - pass # TODO: Check the delta is positive? - - def get_next(self, current: DateTime) -> DateTime: - return convert_to_utc(current + self._delta) - - def get_prev(self, current: DateTime) -> DateTime: - return convert_to_utc(current - self._delta) - - def align(self, current: DateTime) -> DateTime: - return current - - def skip_to_latest(self, earliest: typing.Optional[DateTime]) -> DateTime: - """Bound the earliest time a run can be scheduled. - - The logic is that we move start_date up until one period before, so the - current time is AFTER the period end, and the job can be created... - - This is slightly different from the cron version at terminal values. 
-        """
-        new_start = self.get_prev(DateTime.utcnow())
-        if earliest is None:
-            return new_start
-        return max(new_start, earliest)
diff --git a/airflow/timetables/simple.py b/airflow/timetables/simple.py
index f9c062e3342fb..a71755d40d6d1 100644
--- a/airflow/timetables/simple.py
+++ b/airflow/timetables/simple.py
@@ -15,59 +15,71 @@
 # specific language governing permissions and limitations
 # under the License.
 
-from typing import Any, Optional
+from typing import Any, Dict, Optional
 
 from pendulum import DateTime
 
 from airflow.timetables.base import DagRunInfo, DataInterval, TimeRestriction, Timetable
 
 
-class NullTimetable(Timetable):
-    """Timetable that never schedules anything.
+class _TrivialTimetable(Timetable):
+    """Shared implementation for "trivial" timetables that have nothing complex."""
 
-    This corresponds to ``schedule_interval=None``.
-    """
+    periodic = False
+    can_run = False
+
+    @classmethod
+    def deserialize(cls, data: Dict[str, Any]) -> "Timetable":
+        return cls()
 
     def __eq__(self, other: Any) -> bool:
-        """As long as *other* is of the same type."""
-        if not isinstance(other, NullTimetable):
+        """As long as *other* is of the same type.
+
+        This is only for testing purposes and should not be relied on otherwise.
+        """
+        if not isinstance(other, type(self)):
             return NotImplemented
         return True
 
-    def validate(self) -> None:
-        pass
+    def serialize(self) -> Dict[str, Any]:
+        return {}
 
-    def infer_data_interval(self, run_after: DateTime) -> DataInterval:
+    def infer_data_interval(self, *, run_after: DateTime) -> DataInterval:
         return DataInterval.exact(run_after)
 
+
+class NullTimetable(_TrivialTimetable):
+    """Timetable that never schedules anything.
+
+    This corresponds to ``schedule_interval=None``.
+    """
+
+    @property
+    def summary(self) -> str:
+        return "None"
+
     def next_dagrun_info(
         self,
+        *,
         last_automated_dagrun: Optional[DateTime],
         restriction: TimeRestriction,
     ) -> Optional[DagRunInfo]:
         return None
 
 
-class OnceTimetable(Timetable):
+class OnceTimetable(_TrivialTimetable):
     """Timetable that schedules the execution once as soon as possible.
 
     This corresponds to ``schedule_interval="@once"``.
""" - def __eq__(self, other: Any) -> bool: - """As long as *other* is of the same type.""" - if not isinstance(other, OnceTimetable): - return NotImplemented - return True - - def validate(self) -> None: - pass - - def infer_data_interval(self, run_after: DateTime) -> DataInterval: - return DataInterval.exact(run_after) + @property + def summary(self) -> str: + return "@once" def next_dagrun_info( self, + *, last_automated_dagrun: Optional[DateTime], restriction: TimeRestriction, ) -> Optional[DagRunInfo]: diff --git a/airflow/utils/module_loading.py b/airflow/utils/module_loading.py index e863f8641894e..d1f3cb1659f14 100644 --- a/airflow/utils/module_loading.py +++ b/airflow/utils/module_loading.py @@ -35,3 +35,8 @@ def import_string(dotted_path): return getattr(module, class_name) except AttributeError: raise ImportError(f'Module "{module_path}" does not define a "{class_name}" attribute/class') + + +def as_importable_string(thing) -> str: + """Convert an attribute/class to a string importable by ``import_string``.""" + return f"{thing.__module__}.{thing.__name__}" diff --git a/airflow/www/views.py b/airflow/www/views.py index 17c39598a17e0..ae62fc49279ed 100644 --- a/airflow/www/views.py +++ b/airflow/www/views.py @@ -2641,7 +2641,7 @@ def landing_times(self, session=None): x_points[task_id] = [] for ti in tis: ts = ti.execution_date - if dag.schedule_interval and dag.following_schedule(ts): + if dag.following_schedule(ts): ts = dag.following_schedule(ts) if ti.end_date: dttm = wwwutils.epoch(ti.execution_date) diff --git a/tests/jobs/test_backfill_job.py b/tests/jobs/test_backfill_job.py index ba9e0da3fe134..31826a8205d5f 100644 --- a/tests/jobs/test_backfill_job.py +++ b/tests/jobs/test_backfill_job.py @@ -48,6 +48,7 @@ from airflow.utils.types import DagRunType from tests.test_utils.db import clear_db_dags, clear_db_pools, clear_db_runs, set_default_pool_slots from tests.test_utils.mock_executor import MockExecutor +from tests.test_utils.timetables import cron_timetable logger = logging.getLogger(__name__) @@ -1120,7 +1121,7 @@ def test_backfill_execute_subdag(self): subdag_op_task = dag.get_task('section-1') subdag = subdag_op_task.subdag - subdag.schedule_interval = '@daily' + subdag.timetable = cron_timetable('@daily') start_date = timezone.utcnow() executor = MockExecutor() diff --git a/tests/models/test_dag.py b/tests/models/test_dag.py index 567f57aa313e1..0ee0baaf24f2b 100644 --- a/tests/models/test_dag.py +++ b/tests/models/test_dag.py @@ -1137,8 +1137,7 @@ def test_schedule_dag_once(self): it is called, and not scheduled the second. 
""" dag_id = "test_schedule_dag_once" - dag = DAG(dag_id=dag_id) - dag.schedule_interval = '@once' + dag = DAG(dag_id=dag_id, schedule_interval="@once") assert isinstance(dag.timetable, OnceTimetable) dag.add_task(BaseOperator(task_id="faketastic", owner='Also fake', start_date=TEST_DATE)) @@ -1161,8 +1160,7 @@ def test_fractional_seconds(self): Tests if fractional seconds are stored in the database """ dag_id = "test_fractional_seconds" - dag = DAG(dag_id=dag_id) - dag.schedule_interval = '@once' + dag = DAG(dag_id=dag_id, schedule_interval="@once") dag.add_task(BaseOperator(task_id="faketastic", owner='Also fake', start_date=TEST_DATE)) start_date = timezone.utcnow() @@ -1261,11 +1259,9 @@ def test_get_paused_dag_ids(self): (datetime.timedelta(days=1), delta_timetable(datetime.timedelta(days=1))), ] ) - def test_timetable(self, schedule_interval, expected_timetable): + def test_timetable_from_schedule_interval(self, schedule_interval, expected_timetable): dag = DAG("test_schedule_interval", schedule_interval=schedule_interval) - assert dag.timetable == expected_timetable - assert dag.schedule_interval == schedule_interval def test_create_dagrun_run_id_is_generated(self): dag = DAG(dag_id="run_id_is_generated") diff --git a/tests/providers/amazon/aws/sensors/test_s3_keys_unchanged.py b/tests/providers/amazon/aws/sensors/test_s3_keys_unchanged.py index 9e1bbcc4d53dc..fc83b38d12564 100644 --- a/tests/providers/amazon/aws/sensors/test_s3_keys_unchanged.py +++ b/tests/providers/amazon/aws/sensors/test_s3_keys_unchanged.py @@ -32,14 +32,11 @@ class TestS3KeysUnchangedSensor(TestCase): def setUp(self): - args = { - 'owner': 'airflow', - 'start_date': DEFAULT_DATE, - } - dag = DAG(TEST_DAG_ID + 'test_schedule_dag_once', default_args=args) - dag.schedule_interval = '@once' - self.dag = dag - + self.dag = DAG( + TEST_DAG_ID + 'test_schedule_dag_once', + start_date=DEFAULT_DATE, + schedule_interval="@once", + ) self.sensor = S3KeysUnchangedSensor( task_id='sensor_1', bucket_name='test-bucket', diff --git a/tests/providers/amazon/aws/transfers/test_s3_to_sftp.py b/tests/providers/amazon/aws/transfers/test_s3_to_sftp.py index 539230c965293..e7255c328b4d9 100644 --- a/tests/providers/amazon/aws/transfers/test_s3_to_sftp.py +++ b/tests/providers/amazon/aws/transfers/test_s3_to_sftp.py @@ -51,12 +51,11 @@ def setUp(self): hook = SSHHook(ssh_conn_id='ssh_default') s3_hook = S3Hook('aws_default') hook.no_host_key_check = True - args = { - 'owner': 'airflow', - 'start_date': DEFAULT_DATE, - } - dag = DAG(TEST_DAG_ID + 'test_schedule_dag_once', default_args=args) - dag.schedule_interval = '@once' + dag = DAG( + TEST_DAG_ID + 'test_schedule_dag_once', + start_date=DEFAULT_DATE, + schedule_interval='@once', + ) self.hook = hook self.s3_hook = s3_hook diff --git a/tests/providers/amazon/aws/transfers/test_sftp_to_s3.py b/tests/providers/amazon/aws/transfers/test_sftp_to_s3.py index 30267d46eecb9..409809e059a85 100644 --- a/tests/providers/amazon/aws/transfers/test_sftp_to_s3.py +++ b/tests/providers/amazon/aws/transfers/test_sftp_to_s3.py @@ -50,12 +50,11 @@ def setUp(self): s3_hook = S3Hook('aws_default') hook.no_host_key_check = True - args = { - 'owner': 'airflow', - 'start_date': DEFAULT_DATE, - } - dag = DAG(TEST_DAG_ID + 'test_schedule_dag_once', default_args=args) - dag.schedule_interval = '@once' + dag = DAG( + TEST_DAG_ID + 'test_schedule_dag_once', + schedule_interval="@once", + start_date=DEFAULT_DATE, + ) self.hook = hook self.s3_hook = s3_hook diff --git 
a/tests/providers/google/cloud/sensors/test_gcs.py b/tests/providers/google/cloud/sensors/test_gcs.py index 66c7e0427339c..47ac70b828851 100644 --- a/tests/providers/google/cloud/sensors/test_gcs.py +++ b/tests/providers/google/cloud/sensors/test_gcs.py @@ -201,13 +201,11 @@ def test_execute_timeout(self, mock_hook): class TestGCSUploadSessionCompleteSensor(TestCase): def setUp(self): - args = { - 'owner': 'airflow', - 'start_date': DEFAULT_DATE, - } - dag = DAG(TEST_DAG_ID + 'test_schedule_dag_once', default_args=args) - dag.schedule_interval = '@once' - self.dag = dag + self.dag = DAG( + TEST_DAG_ID + 'test_schedule_dag_once', + schedule_interval="@once", + start_date=DEFAULT_DATE, + ) self.sensor = GCSUploadSessionCompleteSensor( task_id='sensor_1', diff --git a/tests/providers/sftp/operators/test_sftp.py b/tests/providers/sftp/operators/test_sftp.py index 478cd3614d92b..d286184544358 100644 --- a/tests/providers/sftp/operators/test_sftp.py +++ b/tests/providers/sftp/operators/test_sftp.py @@ -42,12 +42,11 @@ def setUp(self): hook = SSHHook(ssh_conn_id='ssh_default') hook.no_host_key_check = True - args = { - 'owner': 'airflow', - 'start_date': DEFAULT_DATE, - } - dag = DAG(TEST_DAG_ID + 'test_schedule_dag_once', default_args=args) - dag.schedule_interval = '@once' + dag = DAG( + TEST_DAG_ID + 'test_schedule_dag_once', + schedule_interval="@once", + start_date=DEFAULT_DATE, + ) self.hook = hook self.dag = dag self.test_dir = "/tmp" diff --git a/tests/providers/ssh/operators/test_ssh.py b/tests/providers/ssh/operators/test_ssh.py index e39b8fbb8c156..7083354301b70 100644 --- a/tests/providers/ssh/operators/test_ssh.py +++ b/tests/providers/ssh/operators/test_ssh.py @@ -43,12 +43,12 @@ def setUp(self): hook = SSHHook(ssh_conn_id='ssh_default') hook.no_host_key_check = True - args = { - 'owner': 'airflow', - 'start_date': DEFAULT_DATE, - } - dag = DAG(TEST_DAG_ID + 'test_schedule_dag_once', default_args=args) - dag.schedule_interval = '@once' + + dag = DAG( + TEST_DAG_ID + 'test_schedule_dag_once', + schedule_interval="@once", + start_date=DEFAULT_DATE, + ) self.hook = hook self.dag = dag diff --git a/tests/serialization/test_dag_serialization.py b/tests/serialization/test_dag_serialization.py index c3872165e09b1..450bed93f0b19 100644 --- a/tests/serialization/test_dag_serialization.py +++ b/tests/serialization/test_dag_serialization.py @@ -18,6 +18,7 @@ """Unit tests for stringified DAGs.""" +import copy import importlib import importlib.util import multiprocessing @@ -43,7 +44,7 @@ from airflow.serialization.serialized_objects import SerializedBaseOperator, SerializedDAG from airflow.timetables.simple import NullTimetable, OnceTimetable from tests.test_utils.mock_operators import CustomOperator, CustomOpLink, GoogleLink -from tests.test_utils.timetables import cron_timetable, delta_timetable +from tests.test_utils.timetables import CustomSerializationTimetable, cron_timetable, delta_timetable executor_config_pod = k8s.V1Pod( metadata=k8s.V1ObjectMeta(name="my-name"), @@ -137,6 +138,7 @@ 'label': 'custom_task', }, ], + "schedule_interval": {"__type": "timedelta", "__var": 86400.0}, "timezone": "UTC", "_access_control": { "__type": "dict", @@ -247,6 +249,14 @@ def collect_dags(dag_folder=None): return dags +def get_timetable_based_simple_dag(timetable): + """Create a simple_dag variant that uses timetable instead of schedule_interval.""" + dag = collect_dags(["airflow/example_dags"])["simple_dag"] + dag.timetable = timetable + dag.schedule_interval = timetable.summary + return dag + + def 
serialize_subprocess(queue, dag_folder): """Validate pickle in a subprocess.""" dags = collect_dags(dag_folder) @@ -288,6 +298,36 @@ def test_serialization(self): # Compares with the ground truth of JSON string. self.validate_serialized_dag(serialized_dags['simple_dag'], serialized_simple_dag_ground_truth) + @parameterized.expand( + [ + ( + cron_timetable("0 0 * * *"), + { + "type": "airflow.timetables.interval.CronDataIntervalTimetable", + "value": {"expression": "0 0 * * *", "timezone": "UTC"}, + }, + ), + ( + CustomSerializationTimetable("foo"), + { + "type": "tests.test_utils.timetables.CustomSerializationTimetable", + "value": {"value": "foo"}, + }, + ), + ], + ) + def test_dag_serialization_to_timetable(self, timetable, serialized_timetable): + """Verify a timetable-backed schedule_interval is excluded in serialization.""" + dag = get_timetable_based_simple_dag(timetable) + serialized_dag = SerializedDAG.to_dict(dag) + SerializedDAG.validate_schema(serialized_dag) + + expected = copy.deepcopy(serialized_simple_dag_ground_truth) + del expected["dag"]["schedule_interval"] + expected["dag"]["timetable"] = serialized_timetable + + self.validate_serialized_dag(serialized_dag, expected) + def validate_serialized_dag(self, json_dag, ground_truth_dag): """Verify serialized DAGs match the ground truth.""" assert json_dag['dag']['fileloc'].split('/')[-1] == 'test_dag_serialization.py' @@ -348,15 +388,23 @@ def test_roundtrip_provider_example_dags(self): serialized_dag = SerializedDAG.from_json(SerializedDAG.to_json(dag)) self.validate_deserialized_dag(serialized_dag, dag) + @parameterized.expand([(cron_timetable("0 0 * * *"),), (CustomSerializationTimetable("foo"),)]) + def test_dag_roundtrip_from_timetable(self, timetable): + """Verify a timetable-backed serialization can be deserialized.""" + dag = get_timetable_based_simple_dag(timetable) + roundtripped = SerializedDAG.from_json(SerializedDAG.to_json(dag)) + self.validate_deserialized_dag(roundtripped, dag) + def validate_deserialized_dag(self, serialized_dag, dag): """ Verify that all example DAGs work with DAG Serialization by checking fields between Serialized Dags & non-Serialized Dags """ fields_to_check = dag.get_serialized_fields() - { - # Doesn't implement __eq__ properly. Check manually + # Doesn't implement __eq__ properly. Check manually. + 'timetable', 'timezone', - # Need to check fields in it, to exclude functions + # Need to check fields in it, to exclude functions. 
'default_args', "_task_group", } @@ -375,6 +423,8 @@ def validate_deserialized_dag(self, serialized_dag, dag): v == serialized_dag.default_args[k] ), f'{dag.dag_id}.default_args[{k}] does not match' + assert serialized_dag.timetable.summary == dag.timetable.summary + assert serialized_dag.timetable.serialize() == dag.timetable.serialize() assert serialized_dag.timezone.name == dag.timezone.name for task_id in dag.task_ids: @@ -501,12 +551,51 @@ def test_deserialization_end_date(self, dag_end_date, task_end_date, expected_ta @parameterized.expand( [ - (None, None, NullTimetable()), - ("@weekly", "@weekly", cron_timetable("0 0 * * 0")), - ("@once", "@once", OnceTimetable()), + ({"type": "airflow.timetables.simple.NullTimetable", "value": {}}, NullTimetable()), + ( + { + "type": "airflow.timetables.interval.CronDataIntervalTimetable", + "value": {"expression": "@weekly", "timezone": "UTC"}, + }, + cron_timetable("0 0 * * 0"), + ), + ({"type": "airflow.timetables.simple.OnceTimetable", "value": {}}, OnceTimetable()), + ( + { + "type": "airflow.timetables.interval.DeltaDataIntervalTimetable", + "value": {"delta": 86400.0}, + }, + delta_timetable(timedelta(days=1)), + ), + ] + ) + def test_deserialization_timetable( + self, + serialized_timetable, + expected_timetable, + ): + serialized = { + "__version": 1, + "dag": { + "default_args": {"__type": "dict", "__var": {}}, + "_dag_id": "simple_dag", + "fileloc": __file__, + "tasks": [], + "timezone": "UTC", + "timetable": serialized_timetable, + }, + } + SerializedDAG.validate_schema(serialized) + dag = SerializedDAG.from_dict(serialized) + assert dag.timetable == expected_timetable + + @parameterized.expand( + [ + (None, NullTimetable()), + ("@weekly", cron_timetable("0 0 * * 0")), + ("@once", OnceTimetable()), ( {"__type": "timedelta", "__var": 86400.0}, - timedelta(days=1), delta_timetable(timedelta(days=1)), ), ] @@ -514,9 +603,9 @@ def test_deserialization_end_date(self, dag_end_date, task_end_date, expected_ta def test_deserialization_schedule_interval( self, serialized_schedule_interval, - expected_schedule_interval, expected_timetable, ): + """Test DAGs serialized before 2.2 can be correctly deserialized.""" serialized = { "__version": 1, "dag": { @@ -530,10 +619,7 @@ def test_deserialization_schedule_interval( } SerializedDAG.validate_schema(serialized) - dag = SerializedDAG.from_dict(serialized) - - assert dag.schedule_interval == expected_schedule_interval assert dag.timetable == expected_timetable @parameterized.expand( diff --git a/tests/test_utils/timetables.py b/tests/test_utils/timetables.py index c6db4c7394038..8a1253b437a7a 100644 --- a/tests/test_utils/timetables.py +++ b/tests/test_utils/timetables.py @@ -16,6 +16,7 @@ # under the License. from airflow import settings +from airflow.timetables.base import Timetable from airflow.timetables.interval import CronDataIntervalTimetable, DeltaDataIntervalTimetable @@ -25,3 +26,19 @@ def cron_timetable(expr: str) -> CronDataIntervalTimetable: def delta_timetable(delta) -> DeltaDataIntervalTimetable: return DeltaDataIntervalTimetable(delta) + + +class CustomSerializationTimetable(Timetable): + def __init__(self, value: str): + self.value = value + + @classmethod + def deserialize(cls, data): + return cls(data["value"]) + + def serialize(self): + return {"value": self.value} + + @property + def summary(self): + return f"{type(self).__name__}({self.value!r})"
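
To see how the helper above plugs into the serialization machinery added by this patch, here is a sketch mirroring the round-trip tests (the dag id and start date are illustrative):

```python
import datetime

from airflow import DAG
from airflow.serialization.serialized_objects import SerializedDAG
from tests.test_utils.timetables import CustomSerializationTimetable

dag = DAG(
    "roundtrip_example",  # illustrative dag id
    timetable=CustomSerializationTimetable("foo"),
    start_date=datetime.datetime(2021, 1, 1),
)

data = SerializedDAG.to_dict(dag)
# serialize_dag() drops schedule_interval when it is backed by the timetable,
# and stores the timetable as its importable path plus serialize() output.
assert "schedule_interval" not in data["dag"]
assert data["dag"]["timetable"] == {
    "type": "tests.test_utils.timetables.CustomSerializationTimetable",
    "value": {"value": "foo"},
}

restored = SerializedDAG.from_dict(data)
assert restored.timetable.summary == "CustomSerializationTimetable('foo')"
# deserialize_dag() back-populates schedule_interval from the summary.
assert restored.schedule_interval == restored.timetable.summary
```
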